nemo_automodel.components.models.llama.model#

Custom Llama model implementation for NeMo Automodel.

This module provides a self-contained Llama implementation with combined QKV and gate_up projections for improved efficiency. It follows HuggingFace's implementation, with optimizations.

Example (YAML):

model:
  _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
  pretrained_model_name_or_path: meta-llama/Llama-3.3-70B-Instruct
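
The same configuration expressed in Python (a minimal sketch mirroring the YAML target above):

from nemo_automodel import NeMoAutoModelForCausalLM

# Instantiate the custom Llama implementation from a pretrained
# HuggingFace checkpoint, equivalent to the YAML config above.
model = NeMoAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.3-70B-Instruct"
)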

Module Contents#

Classes#

LlamaAttention

Multi-headed attention from the 'Attention Is All You Need' paper, with a combined QKV projection.

LlamaMLP

SwiGLU MLP with combined gate_up projection for efficiency.

LlamaDecoderLayer

Single Llama decoder layer with RMSNorm, attention, and MLP.

LlamaPreTrainedModel

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.

LlamaModel

Llama transformer model (embeddings + decoder layers + norm).

LlamaForCausalLM

Llama model with causal language modeling head.

Data#

check_model_inputs

ModelClass

API#

nemo_automodel.components.models.llama.model.check_model_inputs#

'get_check_model_inputs_decorator(…)'

class nemo_automodel.components.models.llama.model.LlamaAttention(config: transformers.LlamaConfig, layer_idx: int)#

Bases: nemo_automodel.components.models.common.combined_projection.CombinedQKVAttentionMixin, torch.nn.Module

Multi-headed attention from the 'Attention Is All You Need' paper, with a combined QKV projection.

Initialization

forward(
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_values: Optional[transformers.cache_utils.Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs],
) → tuple[torch.Tensor, torch.Tensor]#
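
The combined QKV projection replaces three separate linear layers with one fused GEMM whose output is split into query, key, and value states. A minimal sketch of the idea, using hypothetical Llama-like sizes (the actual logic lives in CombinedQKVAttentionMixin):

import torch
import torch.nn as nn

# Hypothetical sizes: 32 query heads, 8 KV heads (grouped-query attention).
hidden_size, num_heads, num_kv_heads, head_dim = 4096, 32, 8, 128

# One fused projection instead of separate q_proj / k_proj / v_proj.
qkv_proj = nn.Linear(hidden_size, (num_heads + 2 * num_kv_heads) * head_dim, bias=False)

x = torch.randn(2, 16, hidden_size)  # (batch, seq_len, hidden)
q, k, v = qkv_proj(x).split(
    [num_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim],
    dim=-1,
)
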
class nemo_automodel.components.models.llama.model.LlamaMLP(config: transformers.LlamaConfig)#

Bases: torch.nn.Module

SwiGLU MLP with combined gate_up projection for efficiency.

Initialization

forward(x: torch.Tensor) → torch.Tensor#
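
The fused gate_up projection computes the SwiGLU gate and up branches in a single GEMM and splits the result. A minimal sketch with hypothetical sizes (the real module reads hidden_size and intermediate_size from the config):

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, intermediate_size = 4096, 14336

# One fused projection instead of separate gate_proj / up_proj.
gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

x = torch.randn(2, 16, hidden_size)
gate, up = gate_up_proj(x).chunk(2, dim=-1)
y = down_proj(F.silu(gate) * up)  # SwiGLU: silu(gate) * up, then project down
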
class nemo_automodel.components.models.llama.model.LlamaDecoderLayer(config: transformers.LlamaConfig, layer_idx: int)#

Bases: transformers.modeling_layers.GradientCheckpointingLayer

Single Llama decoder layer with RMSNorm, attention, and MLP.

Inherits from GradientCheckpointingLayer for efficient activation checkpointing.

Initialization

forward(
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[transformers.cache_utils.Cache] = None,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs],
) → torch.Tensor#
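
A small smoke test of a single decoder layer (a sketch; it assumes a recent transformers release where LlamaRotaryEmbedding is built from the model config, and uses a deliberately tiny config for illustration):

import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
from nemo_automodel.components.models.llama.model import LlamaDecoderLayer

config = LlamaConfig(
    hidden_size=256, intermediate_size=512,
    num_attention_heads=8, num_key_value_heads=4,
    num_hidden_layers=2, vocab_size=1024,
)
layer = LlamaDecoderLayer(config, layer_idx=0)
rotary = LlamaRotaryEmbedding(config=config)

x = torch.randn(1, 16, config.hidden_size)
position_ids = torch.arange(16).unsqueeze(0)
cos, sin = rotary(x, position_ids)  # rotary position embeddings
out = layer(x, position_embeddings=(cos, sin))  # (1, 16, 256)
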
class nemo_automodel.components.models.llama.model.LlamaPreTrainedModel#

Bases: transformers.modeling_utils.PreTrainedModel

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.

config_class#

None

base_model_prefix#

'model'

supports_gradient_checkpointing#

True

_no_split_modules#

['LlamaDecoderLayer']

_skip_keys_device_placement#

['past_key_values']

_supports_flash_attn#

True

_supports_sdpa#

True

_supports_flex_attn#

True

_can_compile_fullgraph#

True

_supports_attention_backend#

True

_can_record_outputs#

None

class nemo_automodel.components.models.llama.model.LlamaModel(config: transformers.LlamaConfig)#

Bases: nemo_automodel.components.models.llama.model.LlamaPreTrainedModel

Llama transformer model (embeddings + decoder layers + norm).

Initialization

forward(
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[transformers.cache_utils.Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs],
) → transformers.modeling_outputs.BaseModelOutputWithPast#
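
A minimal forward pass through the backbone (a sketch; tiny config for illustration):

import torch
from transformers import LlamaConfig
from nemo_automodel.components.models.llama.model import LlamaModel

config = LlamaConfig(
    hidden_size=256, intermediate_size=512,
    num_attention_heads=8, num_key_value_heads=4,
    num_hidden_layers=2, vocab_size=1024,
)
model = LlamaModel(config)

input_ids = torch.randint(0, config.vocab_size, (1, 16))
out = model(input_ids=input_ids)
print(out.last_hidden_state.shape)  # torch.Size([1, 16, 256])
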
class nemo_automodel.components.models.llama.model.LlamaForCausalLM(config: transformers.LlamaConfig)#

Bases: nemo_automodel.components.models.llama.model.LlamaPreTrainedModel

Llama model with causal language modeling head.

Initialization

_tied_weights_keys#

['lm_head.weight']

_tp_plan#

None

_pp_plan#

None

save_pretrained_hf_format(save_directory: str, **kwargs)#

Save model in HuggingFace-compatible format by converting combined projections.

This method converts the custom model's combined projections (qkv_proj, gate_up_proj) back to HuggingFace's separate-projection format before saving, so the resulting checkpoint can be loaded with AutoModelForCausalLM.from_pretrained().

Parameters:
  • save_directory – Directory where the model will be saved

  • **kwargs – Additional arguments passed to config.save_pretrained and save_file
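
For example, exporting and reloading with the stock HuggingFace loader (a sketch; the directory path is illustrative):

from transformers import AutoModelForCausalLM

# Convert qkv_proj / gate_up_proj back to separate projections and save.
model.save_pretrained_hf_format("./llama-hf-export")

# The exported checkpoint is loadable with plain HuggingFace classes.
reloaded = AutoModelForCausalLM.from_pretrained("./llama-hf-export")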

get_input_embeddings()#
set_input_embeddings(value)#
get_output_embeddings()#
set_output_embeddings(new_embeddings)#
set_decoder(decoder)#
get_decoder()#
forward(
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[transformers.cache_utils.Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs],
) → transformers.modeling_outputs.CausalLMOutputWithPast#

Forward pass returning CausalLMOutputWithPast.

Parameters:
  • input_ids – (batch_size, seq_len)

  • attention_mask – Optional attention mask

  • position_ids – Optional position indices

  • past_key_values – Optional cached key/values

  • inputs_embeds – Optional pre-computed embeddings

  • labels – Optional labels for computing loss

  • use_cache – Whether to use KV caching

  • cache_position – Position in cache

  • logits_to_keep – Number of final logits to compute (0=all, N=last N tokens)

Returns:

CausalLMOutputWithPast with loss, logits, and past_key_values
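
A minimal usage sketch (tiny config for illustration) showing loss computation and the logits_to_keep shortcut:

import torch
from transformers import LlamaConfig
from nemo_automodel.components.models.llama.model import LlamaForCausalLM

config = LlamaConfig(
    hidden_size=256, intermediate_size=512,
    num_attention_heads=8, num_key_value_heads=4,
    num_hidden_layers=2, vocab_size=1024,
)
model = LlamaForCausalLM(config)
input_ids = torch.randint(0, config.vocab_size, (1, 16))

out = model(input_ids=input_ids, labels=input_ids)
print(out.loss)  # scalar next-token cross-entropy loss

out = model(input_ids=input_ids, logits_to_keep=1)
print(out.logits.shape)  # logits for the last position only: (1, 1, 1024)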

nemo_automodel.components.models.llama.model.ModelClass#

None