nemo_automodel.components.speculative.eagle.draft_llama_v12
Minimal Llama-based draft model for EAGLE-1 / EAGLE-2 training.
Module Contents
Classes
EagleLlamaAttention | Standard Llama-style self-attention for the EAGLE-1/2 draft.
EagleLlamaMLP | Standard SwiGLU MLP used by the EAGLE-1/2 draft.
EagleLlamaDecoderLayer | Single decoder layer for the minimal EAGLE-1/2 draft model.
LlamaEagleDraftModel | Minimal Llama draft that predicts next-step hidden states.
Functions
_build_causal_mask | Build a standard causal + padding mask for eager attention.
API
- nemo_automodel.components.speculative.eagle.draft_llama_v12._build_causal_mask(
      attention_mask: torch.Tensor,
      dtype: torch.dtype,
  )
Build a standard causal + padding mask for eager attention.
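A causal + padding mask for eager attention is usually an additive tensor: 0 where a query may attend to a key, a very large negative value where it may not, added to the attention scores before the softmax. The sketch below is a hypothetical stand-in (`build_causal_mask_sketch`), not the module's `_build_causal_mask`; it assumes `attention_mask` has shape `[batch, seq_len]` with 1 for real tokens and 0 for padding, and that the output shape is `[batch, 1, seq_len, seq_len]`, following the common Hugging Face convention of `torch.finfo(dtype).min` as the "blocked" value.

```python
import torch


def build_causal_mask_sketch(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical stand-in for _build_causal_mask (shapes and conventions are assumptions).

    attention_mask: [batch, seq_len], 1 = real token, 0 = padding.
    Returns an additive mask [batch, 1, seq_len, seq_len]: 0 = allowed, finfo(dtype).min = blocked.
    """
    bsz, seq_len = attention_mask.shape
    device = attention_mask.device
    # Causal part: query position i may attend to key positions j <= i.
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
    # Padding part: key position j must be a real token.
    keys_ok = attention_mask.bool()[:, None, None, :]  # [bsz, 1, 1, seq_len]
    allowed = causal[None, None, :, :] & keys_ok  # broadcasts to [bsz, 1, seq_len, seq_len]
    mask = torch.zeros(bsz, 1, seq_len, seq_len, dtype=dtype, device=device)
    return mask.masked_fill(~allowed, torch.finfo(dtype).min)


# Usage: one sequence of length 3 whose last position is padding.
example_mask = build_causal_mask_sketch(torch.tensor([[1, 1, 0]]), torch.float32)
```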
- class nemo_automodel.components.speculative.eagle.draft_llama_v12.EagleLlamaAttention(config: transformers.LlamaConfig)
Bases: torch.nn.Module
Standard Llama-style self-attention for the EAGLE-1/2 draft.
Initialization
- _repeat_kv(tensor: torch.Tensor) → torch.Tensor
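In Llama-style grouped-query attention there are fewer key/value heads than query heads, so each KV head is repeated to line up with its group of query heads. The sketch below assumes the usual `[batch, num_kv_heads, seq_len, head_dim]` layout; `repeat_kv_sketch` is hypothetical, and it takes the repeat factor `n_rep` explicitly, whereas the real `_repeat_kv` takes only `tensor` and presumably derives the factor from the config (an assumption).

```python
import torch


def repeat_kv_sketch(tensor: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Hypothetical sketch: [b, n_kv, s, d] -> [b, n_kv * n_rep, s, d].

    Each KV head is repeated n_rep times consecutively, so query head h
    reads KV head h // n_rep (the standard grouped-query layout).
    """
    if n_rep == 1:
        return tensor
    bsz, n_kv, seq_len, head_dim = tensor.shape
    # expand() adds the repeat axis as a view; reshape() then materializes the copy.
    expanded = tensor[:, :, None, :, :].expand(bsz, n_kv, n_rep, seq_len, head_dim)
    return expanded.reshape(bsz, n_kv * n_rep, seq_len, head_dim)
```

The result is equivalent to `torch.repeat_interleave(tensor, n_rep, dim=1)`; the expand-then-reshape form is the one commonly seen in Llama implementations.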
- forward(
      hidden_states: torch.Tensor,
      attention_mask: torch.Tensor,
      position_ids: torch.Tensor,
  )
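The forward signature (hidden states, an additive mask, and position IDs) matches a standard Llama eager-attention pass: project to Q/K/V, apply rotary position embeddings (RoPE) driven by `position_ids`, repeat KV heads, then compute masked softmax attention. The class below is a hypothetical, self-contained sketch of that pattern, not this module's code; the projection names (`q_proj`, `k_proj`, `v_proj`, `o_proj`), the absence of biases, and `rope_theta=10000.0` follow common Llama conventions and are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Llama-style RoPE helper: (x1, x2) -> (-x2, x1) on the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


class TinyLlamaAttentionSketch(nn.Module):
    """Hypothetical eager attention with RoPE and grouped-query KV heads."""

    def __init__(self, hidden_size: int, num_heads: int, num_kv_heads: int, rope_theta: float = 10000.0):
        super().__init__()
        self.num_heads, self.num_kv_heads = num_heads, num_kv_heads
        self.head_dim = hidden_size // num_heads
        self.q_proj = nn.Linear(hidden_size, num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(num_heads * self.head_dim, hidden_size, bias=False)
        inv_freq = 1.0 / (rope_theta ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _rope(self, position_ids: torch.Tensor):
        # position_ids [b, s] -> cos/sin [b, 1, s, head_dim], broadcast over heads.
        freqs = position_ids[..., None].float() * self.inv_freq
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos()[:, None], emb.sin()[:, None]

    def forward(self, hidden_states, attention_mask, position_ids):
        bsz, seq_len, _ = hidden_states.shape
        q = self.q_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        cos, sin = self._rope(position_ids)
        q = q * cos + rotate_half(q) * sin
        k = k * cos + rotate_half(k) * sin
        # Grouped-query attention: repeat each KV head for its group of query heads.
        n_rep = self.num_heads // self.num_kv_heads
        k, v = k.repeat_interleave(n_rep, dim=1), v.repeat_interleave(n_rep, dim=1)
        scores = q @ k.transpose(-2, -1) / self.head_dim**0.5
        scores = scores + attention_mask  # additive mask, e.g. [b, 1, s, s]
        probs = F.softmax(scores.float(), dim=-1).to(q.dtype)
        out = (probs @ v).transpose(1, 2).reshape(bsz, seq_len, -1)
        return self.o_proj(out)
```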
- class nemo_automodel.components.speculative.eagle.draft_llama_v12.EagleLlamaMLP(config: transformers.LlamaConfig)
Bases: torch.nn.Module
Standard SwiGLU MLP used by the EAGLE-1/2 draft.
Initialization
- forward(hidden_states: torch.Tensor) → torch.Tensor
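The Llama SwiGLU feed-forward computes `down_proj(silu(gate_proj(x)) * up_proj(x))`: the SiLU-activated gate branch multiplies the linear "up" branch elementwise before projecting back to the hidden size. A minimal sketch, assuming the usual Llama layer names and bias-free projections (this is an illustration, not the class above):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SwiGLUMLPSketch(nn.Module):
    """Hypothetical Llama-style SwiGLU MLP: down(silu(gate(x)) * up(x))."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Gate branch goes through SiLU; the up branch stays linear.
        return self.down_proj(F.silu(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
```

In Llama configs the widths come from `hidden_size` and `intermediate_size` on `transformers.LlamaConfig`.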
- class nemo_automodel.components.speculative.eagle.draft_llama_v12.EagleLlamaDecoderLayer(config: transformers.LlamaConfig)
Bases: torch.nn.Module
Single decoder layer for the minimal EAGLE-1/2 draft model.
Initialization
- forward(
      hidden_states: torch.Tensor,
      attention_mask: torch.Tensor,
      position_ids: torch.Tensor,
  )
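A standard Llama decoder layer is pre-norm with two residual branches: `h = x + attn(norm1(x))`, then `out = h + mlp(norm2(h))`, where both norms are RMSNorm. The sketch below shows that wiring with the attention and MLP passed in as arguments; the names (`RMSNormSketch`, `PreNormDecoderLayerSketch`) are hypothetical, and whether this draft's layer keeps both norms exactly as shown is an assumption, not something stated here.

```python
import torch
import torch.nn as nn


class RMSNormSketch(nn.Module):
    """Llama-style RMSNorm: x / sqrt(mean(x^2) + eps) * weight (computed in float32)."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.float().pow(2).mean(-1, keepdim=True)
        return (x.float() * torch.rsqrt(variance + self.eps)).to(x.dtype) * self.weight


class PreNormDecoderLayerSketch(nn.Module):
    """Hypothetical pre-norm decoder layer; self_attn is called as (h, mask, pos)."""

    def __init__(self, hidden_size: int, self_attn, mlp, eps: float = 1e-6):
        super().__init__()
        self.input_layernorm = RMSNormSketch(hidden_size, eps)
        self.self_attn = self_attn
        self.post_attention_layernorm = RMSNormSketch(hidden_size, eps)
        self.mlp = mlp

    def forward(self, hidden_states, attention_mask, position_ids):
        # Attention branch with residual connection.
        residual = hidden_states
        attn_out = self.self_attn(self.input_layernorm(hidden_states), attention_mask, position_ids)
        hidden_states = residual + attn_out
        # MLP branch with residual connection.
        residual = hidden_states
        return residual + self.mlp(self.post_attention_layernorm(hidden_states))
```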
- class nemo_automodel.components.speculative.eagle.draft_llama_v12.LlamaEagleDraftModel(config: transformers.LlamaConfig)
Bases: transformers.PreTrainedModel
Minimal Llama draft that predicts next-step hidden states.
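In EAGLE-style drafting, the draft consumes both tokens and the target model's hidden states: each token embedding is concatenated with a target hidden state, projected from `2 * hidden_size` back to `hidden_size`, and passed through decoder layer(s) to predict the next-step hidden state. The sketch below is a hypothetical illustration of that fusion, matching only the `input_ids` / `target_hidden_states` / `attention_mask` inputs listed for `forward`; the `fc` name, the `[embedding, feature]` concatenation order, the default bias, and the assumption that the caller has already aligned tokens with hidden states are all assumptions, not this class's documented behavior.

```python
from typing import Callable

import torch
import torch.nn as nn


class EagleFusionSketch(nn.Module):
    """Hypothetical EAGLE-style draft step: fuse embeddings with target features, then one layer."""

    def __init__(self, vocab_size: int, hidden_size: int, layer: Callable[..., torch.Tensor]):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        # Project the concatenated [embedding, feature] vector (2 * hidden) back to hidden.
        self.fc = nn.Linear(2 * hidden_size, hidden_size)
        self.layer = layer  # e.g. a pre-norm decoder layer called as (h, mask, pos)

    def forward(self, input_ids, target_hidden_states, attention_mask, position_ids):
        inputs_embeds = self.embed_tokens(input_ids)  # [batch, seq, hidden]
        fused = self.fc(torch.cat((inputs_embeds, target_hidden_states), dim=-1))
        # Output is the draft's prediction of next-step hidden states, [batch, seq, hidden].
        return self.layer(fused, attention_mask, position_ids)
```

During training, such predictions are typically compared with the target model's hidden states shifted by one position; how this module defines the loss is not covered on this page.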
Initialization
- config_class = None
- main_input_name = 'input_ids'
- copy_embeddings_from_target(
      target_embeddings: torch.nn.Embedding,
  )
Copy the target model token embeddings into the draft embeddings.
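Copying embeddings from the target model is typically an in-place, no-grad parameter copy, so the draft keeps its own (independent) weight tensor initialized to the target's values. A minimal sketch, assuming both tables have the same `[vocab_size, hidden_size]` shape; `copy_embeddings_sketch` is a hypothetical helper, not this method's implementation.

```python
import torch
import torch.nn as nn


def copy_embeddings_sketch(draft_embed: nn.Embedding, target_embed: nn.Embedding) -> None:
    """Hypothetical helper: copy target token embeddings into the draft's table in place."""
    if draft_embed.weight.shape != target_embed.weight.shape:
        raise ValueError(
            f"shape mismatch: draft {tuple(draft_embed.weight.shape)} "
            f"vs target {tuple(target_embed.weight.shape)}"
        )
    # no_grad: this is an initialization step, not part of the autograd graph.
    # copy_ also converts dtype/device if the two tables differ.
    with torch.no_grad():
        draft_embed.weight.copy_(target_embed.weight)
```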
- freeze_embeddings() → None
Freeze draft token embeddings.
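Freezing embeddings in PyTorch usually means turning off `requires_grad` on the embedding weight, so the optimizer only receives the remaining trainable parameters. The toy module below is hypothetical (`TinyDraft`, `embed_tokens`, and `fc` are illustrative names), showing the effect on which parameters stay trainable.

```python
import torch
import torch.nn as nn


class TinyDraft(nn.Module):
    """Hypothetical stand-in with an embedding table and a fusion projection."""

    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.fc = nn.Linear(2 * hidden_size, hidden_size)

    def freeze_embeddings(self) -> None:
        # Stop gradient updates to the token embeddings.
        self.embed_tokens.weight.requires_grad_(False)


model = TinyDraft(16, 8)
model.freeze_embeddings()
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
# Pass only trainable parameters to the optimizer.
optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=1e-4)
```

Filtering by `requires_grad` keeps the frozen table out of the optimizer state entirely, which saves memory for large vocabularies.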
- forward(
      input_ids: torch.Tensor,
      target_hidden_states: torch.Tensor,
      attention_mask: torch.Tensor,