nemo_automodel.components.models.llama.model


Custom Llama model implementation for NeMo Automodel.

This module provides a self-contained Llama implementation that follows the HuggingFace implementation. It uses separate q_proj/k_proj/v_proj and gate_proj/up_proj projections (HF-style).

Example (YAML):

model:
  _target_: nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained
  pretrained_model_name_or_path: meta-llama/Llama-3.3-70B-Instruct
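
In Hydra-style configs such as the one above, `_target_` names a dotted import path to a callable and the remaining keys are passed to it as keyword arguments. The sketch below illustrates that general convention only. The helpers `resolve_target` and `instantiate_from_config` are hypothetical names, not NeMo Automodel's actual loader, and the example config points at a standard-library class so it runs without downloading a checkpoint.

```python
import importlib


def resolve_target(dotted_path: str):
    """Import and return the object named by a dotted path such as 'pkg.mod.attr'."""
    parts = dotted_path.split(".")
    # Try the longest importable module prefix, then walk the remaining attributes.
    for i in range(len(parts), 0, -1):
        try:
            obj = importlib.import_module(".".join(parts[:i]))
        except ModuleNotFoundError:
            continue
        for attr in parts[i:]:
            obj = getattr(obj, attr)
        return obj
    raise ImportError(f"cannot resolve {dotted_path!r}")


def instantiate_from_config(cfg: dict):
    """Call cfg['_target_'] with every other key as a keyword argument."""
    kwargs = {k: v for k, v in cfg.items() if k != "_target_"}
    return resolve_target(cfg["_target_"])(**kwargs)


# Stand-in config (hypothetical): a stdlib class instead of a model loader.
example_cfg = {"_target_": "fractions.Fraction", "numerator": 3, "denominator": 6}
result = instantiate_from_config(example_cfg)
```

Under this convention, the YAML example would resolve `nemo_automodel.NeMoAutoModelForCausalLM.from_pretrained` and call it with `pretrained_model_name_or_path` as a keyword argument.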

Module Contents

Classes

| Name | Description |
| --- | --- |
| LlamaAttention | Multi-headed attention from the ‘Attention Is All You Need’ paper. |
| LlamaDecoderLayer | Single Llama decoder layer with RMSNorm, attention, and MLP. |
| LlamaForCausalLM | Llama model with causal language modeling head. |
| LlamaMLP | SwiGLU MLP with separate gate_proj and up_proj, identical to the HuggingFace default. |
| LlamaModel | Llama transformer model (embeddings + decoder layers + norm). |
| LlamaPreTrainedModel | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. |

Data

ModelClass

check_model_inputs

API

class nemo_automodel.components.models.llama.model.LlamaAttention(
config: transformers.LlamaConfig,
layer_idx: int,
backend: typing.Optional['BackendConfig'] = None
)

Bases: Module

Multi-headed attention from the ‘Attention Is All You Need’ paper.

Uses separate q_proj / k_proj / v_proj — identical to the default HuggingFace Llama implementation.
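
The attributes of this class can be related to one another with a small calculation. The sketch below assumes the derivations used by the HuggingFace Llama attention (head_dim taken as hidden_size // num_attention_heads when the config does not set it, and num_key_value_groups as the ratio of query heads to key/value heads). `ToyLlamaConfig` is a hypothetical stand-in for `transformers.LlamaConfig`, and its default values are illustrative.

```python
from dataclasses import dataclass


@dataclass
class ToyLlamaConfig:
    """Hypothetical stand-in holding only the config fields used below."""
    hidden_size: int = 4096
    num_attention_heads: int = 32
    num_key_value_heads: int = 8


def attention_geometry(cfg: ToyLlamaConfig) -> dict:
    head_dim = cfg.hidden_size // cfg.num_attention_heads
    return {
        "head_dim": head_dim,
        # Number of query heads that share one key/value head.
        "num_key_value_groups": cfg.num_attention_heads // cfg.num_key_value_heads,
        # Factor applied to the query-key scores before the softmax.
        "scaling": head_dim ** -0.5,
        # Output widths of the three separate projections.
        "q_proj_out": cfg.num_attention_heads * head_dim,
        "k_proj_out": cfg.num_key_value_heads * head_dim,
        "v_proj_out": cfg.num_key_value_heads * head_dim,
    }


geometry = attention_geometry(ToyLlamaConfig())
```

With these illustrative values, k_proj and v_proj are narrower than q_proj because several query heads share each key/value head.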

attention_dropout = config.attention_dropout
head_dim
k_proj
num_key_value_groups
o_proj
q_proj
rope_fusion = getattr(backend, 'rope_fusion', False)
scaling = self.head_dim ** -0.5
v_proj
nemo_automodel.components.models.llama.model.LlamaAttention.forward(
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: typing.Optional[torch.Tensor],
past_key_values: typing.Optional[transformers.cache_utils.Cache] = None,
cache_position: typing.Optional[torch.LongTensor] = None,
kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs] = {}
) -> tuple[torch.Tensor, torch.Tensor]
class nemo_automodel.components.models.llama.model.LlamaDecoderLayer(
config: transformers.LlamaConfig,
layer_idx: int,
backend: nemo_automodel.components.models.common.BackendConfig
)

Bases: GradientCheckpointingLayer

Single Llama decoder layer with RMSNorm, attention, and MLP.

Inherits from GradientCheckpointingLayer for efficient activation checkpointing.
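
As a rough illustration of how the pieces named above fit together, the sketch below implements RMSNorm and a pre-norm residual layer on plain Python lists. The ordering (normalize, apply the sublayer, add the result back to the residual stream) is assumed from the HuggingFace Llama decoder layer. The function names and the toy sublayers are hypothetical, and real attention and MLP modules are replaced by stand-ins.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Divide x by its root mean square, then scale by a learned weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [w * v * inv_rms for v, w in zip(x, weight)]


def decoder_layer(x, attn, mlp, w_in, w_post):
    """Pre-norm residual layer: two normalize-transform-add steps."""
    # Corresponds to input_layernorm followed by self_attn.
    x = [a + b for a, b in zip(x, attn(rms_norm(x, w_in)))]
    # Corresponds to post_attention_layernorm followed by mlp.
    x = [a + b for a, b in zip(x, mlp(rms_norm(x, w_post)))]
    return x


def zero_sublayer(h):
    """Toy sublayer that contributes nothing to the residual stream."""
    return [0.0] * len(h)


hidden = [3.0, 4.0]
ones = [1.0, 1.0]
normed = rms_norm(hidden, ones, eps=0.0)
out = decoder_layer(hidden, zero_sublayer, zero_sublayer, ones, ones)
```

Because each sublayer only adds to the residual stream, a sublayer that returns zeros leaves the hidden state unchanged.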

hidden_size = config.hidden_size
input_layernorm
mlp = LlamaMLP(config=config)
post_attention_layernorm
self_attn
nemo_automodel.components.models.llama.model.LlamaDecoderLayer.forward(
hidden_states: torch.Tensor,
attention_mask: typing.Optional[torch.Tensor] = None,
position_ids: typing.Optional[torch.LongTensor] = None,
past_key_values: typing.Optional[transformers.cache_utils.Cache] = None,
use_cache: typing.Optional[bool] = False,
cache_position: typing.Optional[torch.LongTensor] = None,
position_embeddings: typing.Optional[tuple[torch.Tensor, torch.Tensor]] = None,
kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs] = {}
) -> torch.Tensor
class nemo_automodel.components.models.llama.model.LlamaForCausalLM(
config: transformers.LlamaConfig,
backend: typing.Optional[nemo_automodel.components.models.common.BackendConfig] = None
)

Bases: HFCheckpointingMixin, LlamaPreTrainedModel

Llama model with causal language modeling head.

_pp_plan = {'lm_head': (['hidden_states'], ['logits'])}
_tied_weights_keys = {'lm_head.weight': 'model.embed_tokens.weight'}
_tp_plan = {'lm_head': 'colwise_rep'}
backend = backend or BackendConfig()
lm_head
model = LlamaModel(config=config, backend=(self.backend))
state_dict_adapter = LlamaStateDictAdapter(config=(self.config))
vocab_size = config.vocab_size
nemo_automodel.components.models.llama.model.LlamaForCausalLM.forward(
input_ids: typing.Optional[torch.LongTensor] = None,
attention_mask: typing.Optional[torch.Tensor] = None,
position_ids: typing.Optional[torch.LongTensor] = None,
past_key_values: typing.Optional[transformers.cache_utils.Cache] = None,
inputs_embeds: typing.Optional[torch.FloatTensor] = None,
labels: typing.Optional[torch.LongTensor] = None,
use_cache: typing.Optional[bool] = None,
output_attentions: typing.Optional[bool] = None,
output_hidden_states: typing.Optional[bool] = None,
return_dict: typing.Optional[bool] = None,
cache_position: typing.Optional[torch.LongTensor] = None,
logits_to_keep: typing.Union[int, torch.Tensor] = 0,
kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs] = {}
) -> transformers.modeling_outputs.CausalLMOutputWithPast

Forward pass returning CausalLMOutputWithPast.

Parameters:

input_ids
Optional[torch.LongTensor], defaults to None

Input token IDs of shape (batch_size, seq_len)

attention_mask
Optional[torch.Tensor], defaults to None

Optional attention mask

position_ids
Optional[torch.LongTensor], defaults to None

Optional position indices

past_key_values
Optional[Cache], defaults to None

Optional cached key/values

inputs_embeds
Optional[torch.FloatTensor], defaults to None

Optional pre-computed embeddings

labels
Optional[torch.LongTensor], defaults to None

Optional labels for computing loss

use_cache
Optional[bool], defaults to None

Whether to use KV caching

cache_position
Optional[torch.LongTensor], defaults to None

Position in cache

logits_to_keep
Union[int, torch.Tensor], defaults to 0

Number of final logits to compute (0 = all positions, N = last N tokens)

Returns: CausalLMOutputWithPast

CausalLMOutputWithPast with loss, logits, past_key_values
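
The `logits_to_keep` convention (0 = all, N = last N tokens) falls out of Python slicing, since `slice(-0, None)` is the same as `slice(0, None)`. The sketch below demonstrates this on a plain list standing in for the sequence dimension. The helper `select_positions` is hypothetical, and treating a non-integer value as explicit position indices is an assumption based on the HuggingFace behavior for a tensor argument.

```python
def select_positions(hidden_states, logits_to_keep=0):
    """Return the sequence positions whose logits would be computed."""
    if isinstance(logits_to_keep, int):
        # slice(-0, None) equals slice(0, None), so 0 keeps every position.
        return hidden_states[slice(-logits_to_keep, None)]
    # Assumed behavior: a sequence of indices selects those positions directly.
    return [hidden_states[i] for i in logits_to_keep]


positions = ["t0", "t1", "t2", "t3"]
all_positions = select_positions(positions, 0)
last_two = select_positions(positions, 2)
picked = select_positions(positions, [0, 3])
```

During generation only the last position is typically needed, so passing 1 avoids projecting every hidden state through lm_head.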

nemo_automodel.components.models.llama.model.LlamaForCausalLM.from_config(
config: transformers.LlamaConfig,
backend: typing.Optional[nemo_automodel.components.models.common.BackendConfig] = None,
kwargs = {}
)
classmethod
nemo_automodel.components.models.llama.model.LlamaForCausalLM.get_decoder()
nemo_automodel.components.models.llama.model.LlamaForCausalLM.get_input_embeddings()
nemo_automodel.components.models.llama.model.LlamaForCausalLM.get_output_embeddings()
nemo_automodel.components.models.llama.model.LlamaForCausalLM.set_decoder(
decoder
)
nemo_automodel.components.models.llama.model.LlamaForCausalLM.set_input_embeddings(
value
)
nemo_automodel.components.models.llama.model.LlamaForCausalLM.set_output_embeddings(
new_embeddings
)
nemo_automodel.components.models.llama.model.LlamaForCausalLM.tie_weights(
_args: object = (),
_kwargs: object = {}
) -> None
class nemo_automodel.components.models.llama.model.LlamaMLP(
config: transformers.LlamaConfig
)

Bases: Module

SwiGLU MLP with separate gate_proj and up_proj — identical to HuggingFace default.
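
A minimal sketch of the SwiGLU computation, down_proj(act_fn(gate_proj(x)) * up_proj(x)), on plain Python lists. It assumes `config.hidden_act` is SiLU (the usual Llama setting) and bias-free projections. The helper names and toy weights are illustrative and not taken from the module.

```python
import math


def silu(v):
    """SiLU (swish) activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def linear(x, weight):
    """Bias-free linear layer; weight has shape (out_features, in_features)."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def swiglu_mlp(x, gate_w, up_w, down_w):
    gate = [silu(g) for g in linear(x, gate_w)]   # act_fn(gate_proj(x))
    up = linear(x, up_w)                          # up_proj(x)
    hidden = [g * u for g, u in zip(gate, up)]    # elementwise gating
    return linear(hidden, down_w)                 # down_proj(...)


# Toy sizes: hidden_size=2, intermediate_size=3.
x = [1.0, -2.0]
gate_w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
up_w = [[1.0, 1.0], [2.0, 0.0], [0.0, -1.0]]
down_w = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]
y = swiglu_mlp(x, gate_w, up_w, down_w)
```

The gate and up projections both map hidden_size to intermediate_size, and down_proj maps the gated product back to hidden_size.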

act_fn = ACT2FN[config.hidden_act]
down_proj
gate_proj
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
up_proj
nemo_automodel.components.models.llama.model.LlamaMLP.forward(
x: torch.Tensor
) -> torch.Tensor
class nemo_automodel.components.models.llama.model.LlamaModel(
config: transformers.LlamaConfig,
backend: nemo_automodel.components.models.common.BackendConfig
)

Bases: LlamaPreTrainedModel

Llama transformer model (embeddings + decoder layers + norm).

embed_tokens
layers
norm
padding_idx = config.pad_token_id
rotary_emb
vocab_size = config.vocab_size
nemo_automodel.components.models.llama.model.LlamaModel.forward(
input_ids: typing.Optional[torch.LongTensor] = None,
attention_mask: typing.Optional[torch.Tensor] = None,
position_ids: typing.Optional[torch.LongTensor] = None,
past_key_values: typing.Optional[transformers.cache_utils.Cache] = None,
inputs_embeds: typing.Optional[torch.FloatTensor] = None,
use_cache: typing.Optional[bool] = None,
output_attentions: typing.Optional[bool] = None,
output_hidden_states: typing.Optional[bool] = None,
return_dict: typing.Optional[bool] = None,
cache_position: typing.Optional[torch.LongTensor] = None,
kwargs: transformers.processing_utils.Unpack[transformers.utils.TransformersKwargs] = {}
) -> transformers.modeling_outputs.BaseModelOutputWithPast
class nemo_automodel.components.models.llama.model.LlamaPreTrainedModel()

Bases: PreTrainedModel

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.

_can_record_outputs
_no_split_modules = ['LlamaDecoderLayer']
_skip_keys_device_placement = ['past_key_values']
base_model_prefix = 'model'
nemo_automodel.components.models.llama.model.ModelClass = LlamaForCausalLM
nemo_automodel.components.models.llama.model.check_model_inputs = get_check_model_inputs_decorator()