nemo_automodel.components.models.llama.model#

Custom Llama model implementation for NeMo Automodel.

This module provides a self-contained Llama implementation with combined QKV and gate_up projections for improved efficiency. It follows HuggingFace's implementation, adding these projection optimizations on top.

Example (YAML):

model:
  _target_: nemo_automodel.components.models.llama.build_llama_model
  pretrained_model_name_or_path: meta-llama/Llama-3.3-70B-Instruct
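
The same model can also be constructed programmatically. A minimal sketch, assuming the model card above is accessible (the import path mirrors this module):

from nemo_automodel.components.models.llama.model import build_llama_model

model = build_llama_model("meta-llama/Llama-3.3-70B-Instruct")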

Module Contents#

Classes#

LlamaAttention

Multi-headed attention from the ‘Attention Is All You Need’ paper, with a combined QKV projection.

LlamaMLP

SwiGLU MLP with combined gate_up projection for efficiency.

LlamaDecoderLayer

Single Llama decoder layer with RMSNorm, attention, and MLP.

LlamaPreTrainedModel

An abstract class that handles weight initialization and provides a simple interface for downloading and loading pretrained models.

LlamaModel

Llama transformer model (embeddings + decoder layers + norm).

LlamaForCausalLM

Llama model with causal language modeling head.

Functions#

build_llama_model

Build a custom Llama model with combined projections for efficiency.

Data#

API#

nemo_automodel.components.models.llama.model.__all__#

['build_llama_model', 'LlamaForCausalLM']

class nemo_automodel.components.models.llama.model.LlamaAttention(config: transformers.LlamaConfig, layer_idx: int)#

Bases: nemo_automodel.components.models.common.combined_projection.CombinedQKVAttentionMixin, torch.nn.Module

Multi-headed attention from the ‘Attention Is All You Need’ paper, with a combined QKV projection.

Initialization

forward(
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_values: Optional[transformers.cache_utils.Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs],
) → tuple[torch.Tensor, torch.Tensor]#
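
Conceptually, the combined QKV projection performs one matrix multiply and then splits the result into query, key, and value, instead of running three separate linear layers. A standalone sketch of the idea (the attribute name qkv_proj is taken from the checkpoint-conversion notes below; the shapes are illustrative):

import torch
import torch.nn as nn

hidden_size, num_heads, num_kv_heads, head_dim = 4096, 32, 8, 128
qkv_proj = nn.Linear(hidden_size, (num_heads + 2 * num_kv_heads) * head_dim, bias=False)

x = torch.randn(2, 16, hidden_size)            # (batch, seq, hidden)
qkv = qkv_proj(x)                              # single fused matmul
q, k, v = qkv.split(
    [num_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim], dim=-1
)                                              # q: (2, 16, 4096), k and v: (2, 16, 1024)
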
class nemo_automodel.components.models.llama.model.LlamaMLP(config: transformers.LlamaConfig)#

Bases: torch.nn.Module

SwiGLU MLP with combined gate_up projection for efficiency.

Initialization

forward(x: torch.Tensor) → torch.Tensor#
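
The SwiGLU computation with a fused gate_up projection can be sketched as follows (gate_up_proj comes from the checkpoint-conversion notes below, down_proj is standard Llama naming, and the split order is an assumption):

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, intermediate_size = 4096, 14336
gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)  # fused gate and up
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

x = torch.randn(2, 16, hidden_size)
gate, up = gate_up_proj(x).chunk(2, dim=-1)    # split the fused output
y = down_proj(F.silu(gate) * up)               # SwiGLU: silu(gate) * up, then project back down
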
class nemo_automodel.components.models.llama.model.LlamaDecoderLayer(config: transformers.LlamaConfig, layer_idx: int)#

Bases: transformers.modeling_layers.GradientCheckpointingLayer

Single Llama decoder layer with RMSNorm, attention, and MLP.

Inherits from GradientCheckpointingLayer for efficient activation checkpointing.

Initialization

forward(
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[transformers.cache_utils.Cache] = None,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs],
) → torch.Tensor#
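
The layer follows the usual Llama pre-norm residual pattern; a schematic sketch rather than the exact implementation:

# Schematic of one decoder layer (pre-norm residual connections)
def decoder_layer(hidden_states, self_attn, mlp, input_norm, post_attn_norm, **attn_kwargs):
    residual = hidden_states
    attn_out, _ = self_attn(input_norm(hidden_states), **attn_kwargs)  # normalized input into attention
    hidden_states = residual + attn_out                                # first residual connection
    residual = hidden_states
    hidden_states = residual + mlp(post_attn_norm(hidden_states))      # second residual connection
    return hidden_states
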
class nemo_automodel.components.models.llama.model.LlamaPreTrainedModel#

Bases: transformers.modeling_utils.PreTrainedModel

An abstract class that handles weight initialization and provides a simple interface for downloading and loading pretrained models.

config_class#

None

base_model_prefix#

'model'

supports_gradient_checkpointing#

True

_no_split_modules#

['LlamaDecoderLayer']

_skip_keys_device_placement#

['past_key_values']

_supports_flash_attn#

True

_supports_sdpa#

True

_supports_flex_attn#

True

_can_compile_fullgraph#

True

_supports_attention_backend#

True

_can_record_outputs#

None

class nemo_automodel.components.models.llama.model.LlamaModel(config: transformers.LlamaConfig)#

Bases: nemo_automodel.components.models.llama.model.LlamaPreTrainedModel

Llama transformer model (embeddings + decoder layers + norm).

Initialization

forward(
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[transformers.cache_utils.Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs],
) → transformers.modeling_outputs.BaseModelOutputWithPast#
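
A hedged usage sketch for the backbone; the reduced layer count is purely for illustration, and accessing the backbone through .model relies on the base_model_prefix shown above:

import torch
from nemo_automodel.components.models.llama.model import build_llama_model

lm = build_llama_model("meta-llama/Meta-Llama-3-70B", num_hidden_layers=2)  # reduced depth for illustration
input_ids = torch.randint(0, lm.config.vocab_size, (1, 8))
outputs = lm.model(input_ids=input_ids)      # LlamaModel backbone, exposed via base_model_prefix "model"
hidden_states = outputs.last_hidden_state    # (batch, seq, hidden_size)
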
class nemo_automodel.components.models.llama.model.LlamaForCausalLM(
config: transformers.LlamaConfig,
backend: Optional[nemo_automodel.components.moe.utils.BackendConfig] = None,
)#

Bases: nemo_automodel.components.models.llama.model.LlamaPreTrainedModel

Llama model with causal language modeling head.

Initialization

_tied_weights_keys#

['lm_head.weight']

_tp_plan#

None

_pp_plan#

None

save_pretrained_hf_format(save_directory: str, **kwargs)#

Save model in HuggingFace-compatible format by converting combined projections.

This method converts the custom model's combined projections (qkv_proj, gate_up_proj) back to HuggingFace's separate projection layout before saving, so the checkpoint can be loaded with AutoModelForCausalLM.from_pretrained().

Parameters:
  • save_directory – Directory where the model will be saved

  • **kwargs – Additional arguments passed to config.save_pretrained and save_file
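
A hedged round-trip sketch, assuming a model built with build_llama_model (the directory path and reduced layer count are illustrative):

from transformers import AutoModelForCausalLM
from nemo_automodel.components.models.llama.model import build_llama_model

model = build_llama_model("meta-llama/Meta-Llama-3-70B", num_hidden_layers=2)
model.save_pretrained_hf_format("./llama_hf_checkpoint")                    # writes separate q/k/v and gate/up projections
hf_model = AutoModelForCausalLM.from_pretrained("./llama_hf_checkpoint")    # standard HuggingFace loading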

get_input_embeddings()#
set_input_embeddings(value)#
get_output_embeddings()#
set_output_embeddings(new_embeddings)#
set_decoder(decoder)#
get_decoder()#
forward(
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[transformers.cache_utils.Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs],
) → transformers.modeling_outputs.CausalLMOutputWithPast#

Forward pass returning CausalLMOutputWithPast.

Parameters:
  • input_ids – Input token IDs of shape (batch_size, seq_len)

  • attention_mask – Optional attention mask

  • position_ids – Optional position indices

  • past_key_values – Optional cached key/values

  • inputs_embeds – Optional pre-computed embeddings

  • labels – Optional labels for computing loss

  • use_cache – Whether to use KV caching

  • cache_position – Position in cache

  • logits_to_keep – Number of final logits to compute (0=all, N=last N tokens)

Returns:

CausalLMOutputWithPast containing the loss (when labels are provided), logits, and past_key_values
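
A hedged usage sketch of the forward pass (model construction and input values are illustrative):

import torch
from nemo_automodel.components.models.llama.model import build_llama_model

model = build_llama_model("meta-llama/Meta-Llama-3-70B", num_hidden_layers=2)
input_ids = torch.randint(0, model.config.vocab_size, (1, 16))

out = model(input_ids=input_ids, labels=input_ids.clone())  # loss is computed internally when labels are given
print(out.loss, out.logits.shape)

out = model(input_ids=input_ids, logits_to_keep=1)          # keep only the last token's logits
next_token = out.logits[:, -1].argmax(dim=-1)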

nemo_automodel.components.models.llama.model.build_llama_model(
pretrained_model_name_or_path: str,
**kwargs: Any,
) → torch.nn.Module#

Build a custom Llama model with combined projections for efficiency.

This function loads the config from a HuggingFace model card and builds a custom Llama model with combined QKV and gate_up projections for improved efficiency.

Parameters:
  • pretrained_model_name_or_path – HuggingFace model card name (e.g., "meta-llama/Meta-Llama-3-70B")

  • **kwargs –

    Override config parameters. Common parameters include:

    • vocab_size: Vocabulary size

    • hidden_size: Hidden dimension size

    • num_hidden_layers: Number of transformer layers (useful for testing)

    • num_attention_heads: Number of attention heads

    • num_key_value_heads: Number of key/value heads for GQA

    • intermediate_size: MLP intermediate size

    • max_position_embeddings: Maximum sequence length

    • rms_norm_eps: RMSNorm epsilon

    • rope_theta: RoPE base frequency

    • attention_dropout: Attention dropout probability

    • pad_token_id: Padding token ID

    • attn_implementation: Attention backend ("eager", "sdpa", "flash_attention_2")

    • torch_dtype: Model dtype (default: bfloat16)

Returns:

LlamaForCausalLM model instance with combined projections

Example#

# Load with default settings (combined projections, bfloat16)
model = build_llama_model("meta-llama/Meta-Llama-3-70B")

# Use SDPA for faster attention
model = build_llama_model("meta-llama/Meta-Llama-3-70B", attn_implementation="sdpa")

# Override for testing with fewer layers
model = build_llama_model("meta-llama/Meta-Llama-3-70B", num_hidden_layers=4)