nemo_automodel.components.models.llama.model#
Custom Llama model implementation for NeMo Automodel.
This module provides a self-contained Llama implementation with combined QKV and gate_up projections for improved efficiency. It follows HuggingFace's implementation, with optimizations.
Example (YAML):

```yaml
model:
  _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
  pretrained_model_name_or_path: meta-llama/Llama-3.3-70B-Instruct
```
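The same configuration can be written directly in Python. A minimal sketch, assuming the `_target_` path above resolves to an importable `NeMoAutoModelForCausalLM` class:

```python
from nemo_automodel import NeMoAutoModelForCausalLM

# Mirrors the YAML config above; the import path assumes NeMoAutoModelForCausalLM
# is exposed at the package root, as the _target_ value suggests.
model = NeMoAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.3-70B-Instruct"
)
```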
Module Contents#
Classes#
LlamaAttention | Multi-headed attention from ‘Attention Is All You Need’ paper with combined QKV projection.
LlamaMLP | SwiGLU MLP with combined gate_up projection for efficiency.
LlamaDecoderLayer | Single Llama decoder layer with RMSNorm, attention, and MLP.
LlamaPreTrainedModel | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
LlamaModel | Llama transformer model (embeddings + decoder layers + norm).
LlamaForCausalLM | Llama model with causal language modeling head.
Data#
API#
- nemo_automodel.components.models.llama.model.check_model_inputs#
'get_check_model_inputs_decorator(…)'
- class nemo_automodel.components.models.llama.model.LlamaAttention(config: transformers.LlamaConfig, layer_idx: int)#
Bases: nemo_automodel.components.models.common.combined_projection.CombinedQKVAttentionMixin, torch.nn.Module

Multi-headed attention from ‘Attention Is All You Need’ paper with combined QKV projection.
Initialization
- forward(
- hidden_states: torch.Tensor,
- position_embeddings: tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
- past_key_values: Optional[transformers.cache_utils.Cache] = None,
- cache_position: Optional[torch.LongTensor] = None,
- **kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs],
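The combined QKV projection mentioned in the docstring can be illustrated in isolation. A sketch only: the shapes and the qkv_proj name are assumptions based on the descriptions in this module (see save_pretrained_hf_format below), not the exact attribute layout of LlamaAttention.

```python
import torch
import torch.nn as nn

# Illustrative combined QKV projection: one GEMM produces Q, K, and V together,
# then the output is split. Sizes are examples, not read from the real module.
hidden_size, num_heads, num_kv_heads, head_dim = 4096, 32, 8, 128
qkv_proj = nn.Linear(
    hidden_size,
    (num_heads + 2 * num_kv_heads) * head_dim,
    bias=False,
)

x = torch.randn(2, 16, hidden_size)  # (batch, seq_len, hidden)
q, k, v = qkv_proj(x).split(
    [num_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim],
    dim=-1,
)  # each is then reshaped per-head before attention
```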
- class nemo_automodel.components.models.llama.model.LlamaMLP(config: transformers.LlamaConfig)#
Bases: torch.nn.Module

SwiGLU MLP with combined gate_up projection for efficiency.
Initialization
- forward(x: torch.Tensor) -> torch.Tensor#
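For reference, a standalone sketch of SwiGLU with a combined gate_up projection. The sizes and the gate_up_proj / down_proj names are illustrative assumptions, not the module's exact attributes:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative SwiGLU MLP: one GEMM produces the gate and up halves at once.
hidden_size, intermediate_size = 4096, 14336
gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

x = torch.randn(2, 16, hidden_size)
gate, up = gate_up_proj(x).chunk(2, dim=-1)
out = down_proj(F.silu(gate) * up)  # SwiGLU: silu(gate) * up
```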
- class nemo_automodel.components.models.llama.model.LlamaDecoderLayer(config: transformers.LlamaConfig, layer_idx: int)#
Bases: transformers.modeling_layers.GradientCheckpointingLayer

Single Llama decoder layer with RMSNorm, attention, and MLP.
Inherits from GradientCheckpointingLayer for efficient activation checkpointing.
Initialization
- forward(
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[transformers.cache_utils.Cache] = None,
- use_cache: Optional[bool] = False,
- cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
- **kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs],
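The layer composes RMSNorm, attention, and the MLP in the usual pre-norm residual pattern. The following is a schematic sketch of that data flow; the attribute names (input_layernorm, self_attn, post_attention_layernorm, mlp) are borrowed from the standard Llama layout rather than read from this module:

```python
def decoder_layer_sketch(layer, hidden_states, **attn_kwargs):
    # Attention sub-block: pre-norm, attention, residual add.
    residual = hidden_states
    hidden_states = layer.input_layernorm(hidden_states)
    hidden_states = layer.self_attn(hidden_states, **attn_kwargs)
    hidden_states = residual + hidden_states

    # MLP sub-block: pre-norm, SwiGLU MLP, residual add.
    residual = hidden_states
    hidden_states = layer.post_attention_layernorm(hidden_states)
    hidden_states = layer.mlp(hidden_states)
    return residual + hidden_states
```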
- class nemo_automodel.components.models.llama.model.LlamaPreTrainedModel#
Bases: transformers.modeling_utils.PreTrainedModel

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
- config_class#
None
- base_model_prefix#
'model'
- supports_gradient_checkpointing#
True
- _no_split_modules#
['LlamaDecoderLayer']
- _skip_keys_device_placement#
['past_key_values']
- _supports_flash_attn#
True
- _supports_sdpa#
True
- _supports_flex_attn#
True
- _can_compile_fullgraph#
True
- _supports_attention_backend#
True
- _can_record_outputs#
None
- class nemo_automodel.components.models.llama.model.LlamaModel(config: transformers.LlamaConfig)#
Bases: nemo_automodel.components.models.llama.model.LlamaPreTrainedModel

Llama transformer model (embeddings + decoder layers + norm).
Initialization
- forward(
- input_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[transformers.cache_utils.Cache] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- **kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs],
- class nemo_automodel.components.models.llama.model.LlamaForCausalLM(config: transformers.LlamaConfig)#
Bases: nemo_automodel.components.models.llama.model.LlamaPreTrainedModel

Llama model with causal language modeling head.
Initialization
- _tied_weights_keys#
['lm_head.weight']
- _tp_plan#
None
- _pp_plan#
None
- save_pretrained_hf_format(save_directory: str, **kwargs)#
Save model in HuggingFace-compatible format by converting combined projections.
This method converts the custom model's combined projections (qkv_proj, gate_up_proj) back to HuggingFace's separate-projection format before saving, making the checkpoint loadable with AutoModelForCausalLM.from_pretrained().
- Parameters:
save_directory – Directory where the model will be saved
**kwargs – Additional arguments passed to config.save_pretrained and save_file
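A usage sketch of the round trip this enables; `model` is assumed to be an instantiated LlamaForCausalLM and the directory path is illustrative:

```python
from transformers import AutoModelForCausalLM

# Convert combined projections back to separate q/k/v and gate/up weights,
# then save in HuggingFace layout (illustrative path).
model.save_pretrained_hf_format("./llama-hf-checkpoint")

# The resulting checkpoint loads with the stock HuggingFace class.
reloaded = AutoModelForCausalLM.from_pretrained("./llama-hf-checkpoint")
```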
- get_input_embeddings()#
- set_input_embeddings(value)#
- get_output_embeddings()#
- set_output_embeddings(new_embeddings)#
- set_decoder(decoder)#
- get_decoder()#
- forward(
- input_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[transformers.cache_utils.Cache] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- logits_to_keep: Union[int, torch.Tensor] = 0,
- **kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs],
Forward pass returning CausalLMOutputWithPast.
- Parameters:
input_ids – (batch_size, seq_len)
attention_mask – Optional attention mask
position_ids – Optional position indices
past_key_values – Optional cached key/values
inputs_embeds – Optional pre-computed embeddings
labels – Optional labels for computing loss
use_cache – Whether to use KV caching
cache_position – Position in cache
logits_to_keep – Number of final logits to compute (0=all, N=last N tokens)
- Returns:
CausalLMOutputWithPast with loss, logits, past_key_values
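Two illustrative calls, assuming `model` is an instantiated LlamaForCausalLM and using a made-up vocabulary size for the random inputs:

```python
import torch

input_ids = torch.randint(0, 32000, (2, 16))  # illustrative (batch, seq_len) batch

# Training-style call: passing labels makes the forward pass return a loss.
outputs = model(input_ids=input_ids, labels=input_ids)
loss = outputs.loss

# Inference-style call: logits_to_keep=1 computes logits only for the last
# position, avoiding the full (batch, seq_len, vocab) logits tensor.
outputs = model(input_ids=input_ids, logits_to_keep=1, use_cache=False)
next_token_logits = outputs.logits[:, -1, :]
```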
- nemo_automodel.components.models.llama.model.ModelClass#
None