nemo_automodel.components.models.llama.model#
Custom Llama model implementation for NeMo Automodel.
This module provides a self-contained Llama implementation with combined QKV and gate_up projections for improved efficiency. It follows HuggingFace's implementation, with additional optimizations.
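The combined projections fuse the separate q/k/v (and gate/up) linear layers into single matmuls whose outputs are split afterwards. A minimal sketch of the idea, with illustrative shapes and names rather than the module's actual code:

```python
import torch
import torch.nn as nn

# Illustrative only: one fused GEMM replaces the three separate q/k/v projections.
hidden_size, num_heads, num_kv_heads, head_dim = 4096, 32, 8, 128
q_dim = num_heads * head_dim
kv_dim = num_kv_heads * head_dim

qkv_proj = nn.Linear(hidden_size, q_dim + 2 * kv_dim, bias=False)

x = torch.randn(2, 16, hidden_size)                           # (batch, seq, hidden)
q, k, v = qkv_proj(x).split([q_dim, kv_dim, kv_dim], dim=-1)  # one matmul, three outputs

# The MLP applies the same trick: a single gate_up projection whose output is
# chunked into the SwiGLU gate and up halves (see LlamaMLP below).
```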
Example (YAML):

model:
  _target_: nemo_automodel.components.models.llama.build_llama_model
  pretrained_model_name_or_path: meta-llama/Llama-3.3-70B-Instruct
Module Contents#
Classes#
| Class | Description |
|---|---|
| LlamaAttention | Multi-headed attention from "Attention Is All You Need" paper with combined QKV projection. |
| LlamaMLP | SwiGLU MLP with combined gate_up projection for efficiency. |
| LlamaDecoderLayer | Single Llama decoder layer with RMSNorm, attention, and MLP. |
| LlamaPreTrainedModel | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. |
| LlamaModel | Llama transformer model (embeddings + decoder layers + norm). |
| LlamaForCausalLM | Llama model with causal language modeling head. |
Functions#
| Function | Description |
|---|---|
| build_llama_model | Build a custom Llama model with combined projections for efficiency. |
Data#

__all__
API#
- nemo_automodel.components.models.llama.model.__all__#
['build_llama_model', 'LlamaForCausalLM']
- class nemo_automodel.components.models.llama.model.LlamaAttention(config: transformers.LlamaConfig, layer_idx: int)#
Bases:
nemo_automodel.components.models.common.combined_projection.CombinedQKVAttentionMixin, torch.nn.Module

Multi-headed attention from "Attention Is All You Need" paper with combined QKV projection.
Initialization
- forward(
- hidden_states: torch.Tensor,
- position_embeddings: tuple[torch.Tensor, torch.Tensor],
- attention_mask: Optional[torch.Tensor],
- past_key_values: Optional[transformers.cache_utils.Cache] = None,
- cache_position: Optional[torch.LongTensor] = None,
- **kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs],
- )#
- class nemo_automodel.components.models.llama.model.LlamaMLP(config: transformers.LlamaConfig)#
Bases:
torch.nn.Module

SwiGLU MLP with combined gate_up projection for efficiency.
Initialization
- forward(x: torch.Tensor) torch.Tensor#
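The forward computation is a SwiGLU in which both branches come out of one fused gate_up matmul. A rough, hedged sketch of that computation (attribute names such as gate_up_proj and down_proj are assumptions, not verified against the actual implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CombinedSwiGLUMLP(nn.Module):
    """Illustrative SwiGLU MLP with a fused gate_up projection."""

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        # One GEMM produces both the gate and up branches.
        self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)
```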
- class nemo_automodel.components.models.llama.model.LlamaDecoderLayer(config: transformers.LlamaConfig, layer_idx: int)#
Bases:
transformers.modeling_layers.GradientCheckpointingLayer

Single Llama decoder layer with RMSNorm, attention, and MLP.
Inherits from GradientCheckpointingLayer for efficient activation checkpointing.
Initialization
- forward(
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[transformers.cache_utils.Cache] = None,
- use_cache: Optional[bool] = False,
- cache_position: Optional[torch.LongTensor] = None,
- position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
- **kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs],
- )#
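For orientation, a Llama decoder layer of this shape typically applies pre-norm residual blocks around attention and the MLP. A schematic of the expected data flow, written as a sketch under that assumption rather than the layer's exact code (submodule names follow the usual Llama naming):

```python
def decoder_layer_flow(hidden_states, self_attn, mlp, input_layernorm, post_attention_layernorm, **attn_kwargs):
    # Schematic pre-norm residual flow of a Llama decoder layer (illustrative).
    residual = hidden_states
    hidden_states = input_layernorm(hidden_states)            # RMSNorm before attention
    hidden_states = self_attn(hidden_states, **attn_kwargs)   # combined-QKV self-attention
    hidden_states = residual + hidden_states                  # first residual add

    residual = hidden_states
    hidden_states = post_attention_layernorm(hidden_states)   # RMSNorm before MLP
    hidden_states = mlp(hidden_states)                        # SwiGLU MLP
    return residual + hidden_states                           # second residual add
```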
- class nemo_automodel.components.models.llama.model.LlamaPreTrainedModel#
Bases:
transformers.modeling_utils.PreTrainedModel

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.
- config_class#
None
- base_model_prefix#
'model'
- supports_gradient_checkpointing#
True
- _no_split_modules#
['LlamaDecoderLayer']
- _skip_keys_device_placement#
['past_key_values']
- _supports_flash_attn#
True
- _supports_sdpa#
True
- _supports_flex_attn#
True
- _can_compile_fullgraph#
True
- _supports_attention_backend#
True
- _can_record_outputs#
None
- class nemo_automodel.components.models.llama.model.LlamaModel(config: transformers.LlamaConfig)#
Bases:
nemo_automodel.components.models.llama.model.LlamaPreTrainedModel

Llama transformer model (embeddings + decoder layers + norm).
Initialization
- forward(
- input_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[transformers.cache_utils.Cache] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- use_cache: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- **kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs],
- )#
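A small usage sketch: instantiate the backbone directly from a transformers.LlamaConfig and run a forward pass. The tiny config values are illustrative, and the output is assumed to follow the usual HuggingFace convention of exposing last_hidden_state:

```python
import torch
from transformers import LlamaConfig

from nemo_automodel.components.models.llama.model import LlamaModel

# Tiny illustrative config so the sketch runs quickly on CPU (random weights).
config = LlamaConfig(
    vocab_size=1024,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
)

model = LlamaModel(config)
input_ids = torch.randint(0, config.vocab_size, (1, 16))

with torch.no_grad():
    outputs = model(input_ids=input_ids)
# outputs.last_hidden_state is expected to have shape (1, 16, config.hidden_size).
```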
- class nemo_automodel.components.models.llama.model.LlamaForCausalLM(
- config: transformers.LlamaConfig,
- backend: Optional[nemo_automodel.components.moe.utils.BackendConfig] = None,
- )#
Bases:
nemo_automodel.components.models.llama.model.LlamaPreTrainedModel

Llama model with causal language modeling head.
Initialization
- _tied_weights_keys#
['lm_head.weight']
- _tp_plan#
None
- _pp_plan#
None
- save_pretrained_hf_format(save_directory: str, **kwargs)#
Save model in HuggingFace-compatible format by converting combined projections.
This method converts the custom model's combined projections (qkv_proj, gate_up_proj) back to HuggingFace's separate projections format before saving, making the checkpoint loadable with AutoModelForCausalLM.from_pretrained().
- Parameters:
save_directory – Directory where the model will be saved
**kwargs – Additional arguments passed to config.save_pretrained and save_file
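A hedged usage sketch of the round trip described above; the model name, the reduced layer count, and the output directory are illustrative:

```python
from transformers import AutoModelForCausalLM

from nemo_automodel.components.models.llama.model import build_llama_model

# Build a small custom model (fewer layers keeps the sketch lightweight).
model = build_llama_model("meta-llama/Meta-Llama-3-70B", num_hidden_layers=2)

# Convert the combined qkv_proj / gate_up_proj weights back to HuggingFace's
# separate projections and write a standard checkpoint.
model.save_pretrained_hf_format("./llama-hf-export")

# The exported checkpoint is intended to load with vanilla transformers.
reloaded = AutoModelForCausalLM.from_pretrained("./llama-hf-export")
```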
- get_input_embeddings()#
- set_input_embeddings(value)#
- get_output_embeddings()#
- set_output_embeddings(new_embeddings)#
- set_decoder(decoder)#
- get_decoder()#
- forward(
- input_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[transformers.cache_utils.Cache] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- cache_position: Optional[torch.LongTensor] = None,
- logits_to_keep: Union[int, torch.Tensor] = 0,
- **kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs],
- )#
Forward pass returning CausalLMOutputWithPast.
- Parameters:
input_ids – (batch_size, seq_len)
attention_mask – Optional attention mask
position_ids – Optional position indices
past_key_values – Optional cached key/values
inputs_embeds – Optional pre-computed embeddings
labels – Optional labels for computing loss
use_cache – Whether to use KV caching
cache_position – Position in cache
logits_to_keep – Number of final logits to compute (0=all, N=last N tokens)
- Returns:
CausalLMOutputWithPast with loss, logits, past_key_values
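A sketch of two call patterns, assuming a model instance built with build_llama_model (documented below): a training-style call where labels make the model compute the LM loss, and an inference-style call that uses logits_to_keep to materialize only the last position's logits:

```python
import torch

# Training-style call: passing labels asks the model to compute the LM loss.
input_ids = torch.randint(0, model.config.vocab_size, (2, 32))
out = model(input_ids=input_ids, labels=input_ids)
loss = out.loss        # scalar language-modeling loss
logits = out.logits    # (2, 32, vocab_size)

# Inference-style call: keep only the final position's logits to save memory.
out = model(input_ids=input_ids, logits_to_keep=1)
next_token = out.logits[:, -1].argmax(dim=-1)
```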
- nemo_automodel.components.models.llama.model.build_llama_model(
- pretrained_model_name_or_path: str,
- **kwargs: Any,
- )#
Build a custom Llama model with combined projections for efficiency.
This function loads the config from a HuggingFace model card and builds a custom Llama model with combined QKV and gate_up projections for improved efficiency.
- Parameters:
pretrained_model_name_or_path – HuggingFace model card name (e.g., "meta-llama/Meta-Llama-3-70B")
**kwargs –
Override config parameters. Common parameters include:
vocab_size: Vocabulary size
hidden_size: Hidden dimension size
num_hidden_layers: Number of transformer layers (useful for testing)
num_attention_heads: Number of attention heads
num_key_value_heads: Number of key/value heads for GQA
intermediate_size: MLP intermediate size
max_position_embeddings: Maximum sequence length
rms_norm_eps: RMSNorm epsilon
rope_theta: RoPE base frequency
attention_dropout: Attention dropout probability
pad_token_id: Padding token ID
attn_implementation: Attention backend ("eager", "sdpa", "flash_attention_2")
torch_dtype: Model dtype (default: bfloat16)
- Returns:
LlamaForCausalLM model instance with combined projections
Example:

# Load with default settings (combined projections, bfloat16)
model = build_llama_model("meta-llama/Meta-Llama-3-70B")

# Use SDPA for faster attention
model = build_llama_model("meta-llama/Meta-Llama-3-70B", attn_implementation="sdpa")

# Override for testing with fewer layers
model = build_llama_model("meta-llama/Meta-Llama-3-70B", num_hidden_layers=4)