nemo_automodel.components.models.nemotron_v3.model
Module Contents
Classes
Data
API
Bases: CausalLMOutputWithPast
CausalLMOutputWithPast plus declared MTP (multi-token prediction) fields.
The MTP per-depth hidden states and scaling factor must be regular
dataclass fields (rather than dynamically set attributes) so that they survive
output-restructuring layers such as FSDP2’s mixed-precision output cast,
which rebuild ModelOutput instances from declared fields only.
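Plain dataclasses are enough to show why declared fields matter. This is a minimal sketch, not the real class. It assumes that a restructuring layer rebuilds the output by calling the class constructor with only the declared fields. The field names mtp_hidden_states and mtp_loss_scaling_factor are hypothetical. The real ModelOutput is an OrderedDict-backed dataclass, and this sketch does not reproduce that behaviour.

```python
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class BaseOutput:
    logits: Optional[list] = None


@dataclass
class OutputWithMTP(BaseOutput):
    # Declared fields are part of the dataclass schema (hypothetical names).
    mtp_hidden_states: Optional[tuple] = None
    mtp_loss_scaling_factor: Optional[float] = None


def rebuild_from_fields(out):
    """Mimic a layer that reconstructs an output from its declared fields only."""
    return type(out)(**{f.name: getattr(out, f.name) for f in fields(out)})


declared = OutputWithMTP(logits=[1.0], mtp_hidden_states=("h1",), mtp_loss_scaling_factor=0.1)
undeclared = BaseOutput(logits=[1.0])
undeclared.mtp_hidden_states = ("h1",)  # dynamically set attribute, not a field

rebuilt_declared = rebuild_from_fields(declared)
rebuilt_undeclared = rebuild_from_fields(undeclared)

print(rebuilt_declared.mtp_hidden_states)                # ('h1',)
print(hasattr(rebuilt_undeclared, "mtp_hidden_states"))  # False
```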
Bases: HFCheckpointingMixin, GenerationMixin, Module, MoEFSDPSyncMixin
NemotronV3 model with language modeling head.
Supports .generate() from transformers.generation.GenerationMixin with O(1)
per-step KV caching for attention layers and recurrent state caching for Mamba2 layers.
Return the device of the first model parameter (required by GenerationMixin).
Return the dtype of the first model parameter (used by cache construction).
Build the per-depth rolled-token embeddings on the first PP stage.
The first PP stage owns embed_tokens and is the only rank that can
produce the future-token embeddings consumed by the MTP head on the
final stage. The tuple flows alongside hidden_states through every
intermediate stage as additional positional outputs (see forward).
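The sketch below shows one possible form of per-depth rolled-token embeddings. It uses nested Python lists in place of tensors. It assumes that depth d (1-indexed) pairs position t with token t + d, as in DeepSeek-V3-style MTP. It also assumes the vacated tail is padded with a hypothetical pad_id. The real method may instead use a wrapping roll and rely on loss masking. roll_left and build_mtp_embed_inputs are illustrative names.

```python
from typing import List, Sequence, Tuple


def roll_left(ids: Sequence[int], shift: int, pad_id: int) -> List[int]:
    """Shift a token row left by `shift`, padding the vacated tail with pad_id."""
    return list(ids[shift:]) + [pad_id] * min(shift, len(ids))


def build_mtp_embed_inputs(
    input_ids: List[List[int]],      # [B, S] token ids
    embed_table: List[List[float]],  # [V, H] embedding table
    num_depths: int,
    pad_id: int = 0,
) -> Tuple[List[List[List[float]]], ...]:
    """Return one [B, S, H] embedding per MTP depth (depth d sees token t + d)."""
    outputs = []
    for depth in range(1, num_depths + 1):
        rolled = [roll_left(row, depth, pad_id) for row in input_ids]
        outputs.append([[embed_table[tok] for tok in row] for row in rolled])
    return tuple(outputs)


table = [[float(v), float(v) * 10] for v in range(5)]  # V=5, H=2
embeds = build_mtp_embed_inputs([[1, 2, 3, 4]], table, num_depths=2)
print(len(embeds))   # 2
print(embeds[0][0])  # [[2.0, 20.0], [3.0, 30.0], [4.0, 40.0], [0.0, 0.0]]
```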
Parameters:
Token ids [B, S] (int).
Returns: torch.Tensor
Tuple of length self.mtp_config.num_layers containing
True when this module instance has been trimmed to a PP stage subset.
Detection mirrors DeepseekV4ForCausalLM._is_pipeline_parallel_stage.
Any one of the following is sufficient, because the PP splitter sets these
attributes to None when trimming: (a) lm_head is None, (b) the inner
embed_tokens is None, or (c) the model.layers count differs from
config.num_hidden_layers.
The checks use hasattr to distinguish “the splitter nulled the
attribute” (attribute present, value is None) from “the caller replaced
self.model with a stub that doesn’t declare the attribute”
(attribute absent). Tests that swap in stub inner modules should therefore
not be misclassified as PP stages.
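The three checks and the hasattr distinction can be sketched with SimpleNamespace objects standing in for modules. is_pipeline_parallel_stage is a hypothetical free function, not the actual method. It assumes that layers supports len().

```python
from types import SimpleNamespace


def is_pipeline_parallel_stage(model) -> bool:
    """Return True if any trimmed-stage signal is present."""
    # (a) The splitter nulled lm_head.
    if hasattr(model, "lm_head") and model.lm_head is None:
        return True
    inner = getattr(model, "model", None)
    # (b) The splitter nulled the inner embed_tokens. A stub that never declared
    #     the attribute fails hasattr and is therefore not counted.
    if hasattr(inner, "embed_tokens") and inner.embed_tokens is None:
        return True
    # (c) The stage holds a different number of layers than the config declares.
    if hasattr(inner, "layers") and inner.layers is not None:
        return len(inner.layers) != model.config.num_hidden_layers
    return False


config = SimpleNamespace(num_hidden_layers=4)
full = SimpleNamespace(config=config, lm_head="head", model=SimpleNamespace(embed_tokens="emb", layers=[0, 1, 2, 3]))
first_stage = SimpleNamespace(config=config, lm_head=None, model=SimpleNamespace(embed_tokens="emb", layers=[0, 1]))
stubbed = SimpleNamespace(config=config, lm_head="head", model=SimpleNamespace())

print(is_pipeline_parallel_stage(full))         # False
print(is_pipeline_parallel_stage(first_stage))  # True
print(is_pipeline_parallel_stage(stubbed))      # False
```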
Build a 4D SDPA-compatible causal mask.
Prefill (query_len == kv_len): standard lower-triangular causal mask.
Decode (query_len == 1): a single all-zeros row, allowing attention to all cached positions.
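The following sketch covers both mask cases. It returns an additive [B, 1, Q, K] mask as nested lists, where 0.0 means attend and -inf means masked. This additive float form is one of the mask conventions that torch's scaled_dot_product_attention accepts. The sketch assumes a general rule: query i may attend to key positions up to i + (kv_len - query_len). Both documented cases follow from that rule. Padding is ignored, and build_causal_mask is a hypothetical name.

```python
import math
from typing import List


def build_causal_mask(batch: int, query_len: int, kv_len: int) -> List[List[List[List[float]]]]:
    """Additive [B, 1, Q, K] mask: 0.0 means attend, -inf means masked."""
    offset = kv_len - query_len  # cached positions that precede the first query
    rows = [
        [0.0 if k <= q + offset else -math.inf for k in range(kv_len)]
        for q in range(query_len)
    ]
    # Copy the rows per batch element so entries are not aliased.
    return [[[list(r) for r in rows]] for _ in range(batch)]


prefill = build_causal_mask(batch=1, query_len=3, kv_len=3)[0][0]  # lower-triangular
decode = build_causal_mask(batch=1, query_len=1, kv_len=4)[0][0]   # one all-zeros row
```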
Pin the MTP head to the last PP stage’s FQN list.
Called by split_model_into_stages (functional.py:494-502) after the
default per-stage FQN auto-generation. The auto-generator includes
embed_tokens on the first stage and norm/lm_head on the
last stage but doesn’t know about model.mtp; this hook appends it.
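The hook's effect on the per-stage FQN lists can be sketched as below. The function name, the example layer FQNs, and the duplicate check are assumptions added for illustration. Only the idea of appending model.mtp to the last stage comes from the description above.

```python
from typing import List


def pin_mtp_to_last_stage(stage_fqns: List[List[str]], mtp_fqn: str = "model.mtp") -> List[List[str]]:
    """Append the MTP head's FQN to the last stage's module list, at most once."""
    if not stage_fqns:
        return stage_fqns
    last = stage_fqns[-1]
    if mtp_fqn not in last:
        last.append(mtp_fqn)
    return stage_fqns


# Hypothetical auto-generated split: embed_tokens first, norm/lm_head last.
stages = [
    ["model.embed_tokens", "model.layers.0", "model.layers.1"],
    ["model.layers.2", "model.layers.3", "model.norm", "lm_head"],
]
pin_mtp_to_last_stage(stages)
print(stages[-1][-1])  # model.mtp
```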
Forward pass with optional loss computation.
Supports both BSHD format (input_ids shape [B, S]) and THD format
(input_ids shape [T] after squeeze_input_for_thd). When
kwargs["qkv_format"] == "thd" AND the attention backend is TE,
inputs are squeezed to THD before the base-model forward and logits
are unsqueezed back to [1, T, V] on exit. The SDPA and flex backends stay in BSHD.
Pipeline-parallel awareness: when run as a PP stage, input_ids is
the upstream stage’s hidden-state tensor on non-first stages, and
*mtp_embed_inputs carries num_nextn_predict_layers future-token
embeddings produced by the first stage. See the Returns section below
for the per-stage tuple contract. The single-rank (no-PP) path returns
NemotronHCausalLMOutputWithPast unchanged.
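The BSHD/THD dispatch described above reduces to a shape contract. The sketch below expresses it with shape tuples. The backend strings "te" and "sdpa" and the helper names are hypothetical. The sketch also ignores logits_to_keep and PP stages.

```python
from typing import Tuple


def use_thd(qkv_format: str, attn_backend: str) -> bool:
    # THD is used only when the format is "thd" AND the backend is TE.
    return qkv_format == "thd" and attn_backend == "te"


def forward_shapes(input_shape: Tuple[int, ...], vocab: int, qkv_format: str, attn_backend: str):
    """Return (shape seen by the base model, shape of the returned logits)."""
    if use_thd(qkv_format, attn_backend):
        b, t = input_shape
        assert b == 1, "packed THD input is expected as [1, T]"
        inner = (t,)            # squeezed to [T] before the base-model forward
        logits = (1, t, vocab)  # unsqueezed back to [1, T, V] on exit
    else:
        inner = input_shape     # SDPA / flex: BSHD passes through unchanged
        logits = (*input_shape, vocab)
    return inner, logits


print(forward_shapes((1, 10), 32, "thd", "te"))    # ((10,), (1, 10, 32))
print(forward_shapes((1, 10), 32, "thd", "sdpa"))  # ((1, 10), (1, 10, 32))
```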
Parameters:
Input token IDs. BSHD: [B, S]; THD: [1, T]
(squeezed internally). On non-first PP stages this slot
instead carries the upstream stage’s hidden-state tensor.
Pre-computed future-token embeddings produced by the first PP stage and forwarded between stages as positional args. Empty on the single-rank (no-PP) path.
2D padding mask [B, S].
Dict with precomputed 4D causal masks
(key "full_attention" is consumed).
Pre-computed input embeddings (optional).
Token IDs for loss computation [B, S] (optional;
under PP, loss is computed by PipelineCausalLMLoss).
Optional NemotronHybridCache for incremental decoding.
Whether to return past_key_values for subsequent steps.
Token position indices for cache updates.
Position IDs (forwarded into MTP sublayer kwargs).
Padding mask [B, S] used by the THD squeeze helper
and as the 2D mask source for the MoE and Mamba2 layers.
If > 0, only compute logits for the last
logits_to_keep token positions.
Whether to return hidden states.
Accepted for API compatibility (always returns a
NemotronHCausalLMOutputWithPast off-PP).
Additional arguments forwarded to the base model
(e.g. qkv_format, cu_seqlens, cu_seqlens_padded,
max_seqlen, seq_idx, cp_rank, cp_size,
_packed_seq_ids).
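logits_to_keep follows the common Hugging Face convention of projecting only the trailing positions to the vocabulary. During decoding, a value of 1 is enough because only the last position's logits are sampled. Below is a list-based sketch of the slicing. It assumes an integer argument, where 0 keeps every position, and the helper name is hypothetical.

```python
from typing import List


def select_positions(hidden: List[List[List[float]]], logits_to_keep: int) -> List[List[List[float]]]:
    """Keep only the last `logits_to_keep` positions of a [B, S, H] tensor; 0 keeps all."""
    if logits_to_keep > 0:
        return [row[-logits_to_keep:] for row in hidden]
    return hidden


hidden = [[[0.0], [1.0], [2.0], [3.0]]]  # B=1, S=4, H=1
print(select_positions(hidden, 1))       # [[[3.0]]]
```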
Returns: CausalLMOutputWithPast
Off-PP: NemotronHCausalLMOutputWithPast with logits,
Create model from config.
Parameters:
NemotronH config
Backend configuration
Additional arguments
Returns:
NemotronHForCausalLM instance
Load pretrained model.
Parameters:
Path or name of pretrained model
Additional positional arguments
Additional keyword arguments
Returns:
NemotronHForCausalLM instance
Return analytical (inputs_meta, outputs_meta) for a PP stage.
Inter-stage tensors are plain [B, S, H] (no HC stream). With MTP
enabled, every transfer carries 1 + D tensors (the hidden state plus one
embedding per MTP depth, D = num_nextn_predict_layers), so the variadic
forward signature is exercised on every microbatch.
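The 1 + D transfer layout can be written out as a list of shapes. This sketch uses shape tuples where the real method returns meta tensors. It covers only inter-stage transfers, not the first stage's token-id input or the last stage's logits. inter_stage_shapes is a hypothetical name.

```python
from typing import List, Tuple

Shape = Tuple[int, int, int]


def inter_stage_shapes(batch: int, seq: int, hidden: int, num_mtp_depths: int) -> List[Shape]:
    """Shapes sent between PP stages: the hidden state, then one embedding per MTP depth."""
    hidden_state = (batch, seq, hidden)
    mtp_embeds = [(batch, seq, hidden)] * num_mtp_depths  # all plain [B, S, H]
    return [hidden_state] + mtp_embeds


print(inter_stage_shapes(batch=2, seq=8, hidden=16, num_mtp_depths=0))       # [(2, 8, 16)]
print(len(inter_stage_shapes(batch=2, seq=8, hidden=16, num_mtp_depths=2)))  # 3
```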
Unwrap any checkpoint-wrapped blocks (inverse of gradient_checkpointing_enable).
Enable activation checkpointing on each transformer (and MTP) block.
Wraps every decoder block (and MTP block, when present) with a
non-reentrant checkpoint wrapper so that block activations are recomputed
during the backward pass instead of being stored. This is the single-GPU
entry point: FSDP2Manager.parallelize calls it when world_size == 1
(the expert-parallel path performs the equivalent wrapping inside the MoE
parallelizer’s apply_ac). Without it, the hybrid Mamba2/Attention MoE
keeps every block’s activations live, which is enough to push single-GPU LoRA
SFT past the memory of a single 80 GB device. The method is idempotent.
Parameters:
Accepted for HF API compatibility; currently unused (NO_REENTRANT wrapping is always used).
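Idempotence means that wrapping twice must not nest wrappers. The sketch below shows that property, together with the inverse unwrap, using a hypothetical CheckpointWrapper stand-in and plain callables. It does not recompute anything. In real PyTorch code, torch.distributed.algorithms._checkpoint.checkpoint_wrapper.checkpoint_wrapper with CheckpointImpl.NO_REENTRANT provides a non-reentrant wrapper. This page does not say which wrapper the model actually uses.

```python
class CheckpointWrapper:
    """Hypothetical stand-in for a non-reentrant activation-checkpoint wrapper."""

    def __init__(self, block):
        self.wrapped = block

    def __call__(self, *args):
        # A real wrapper would recompute `self.wrapped` during backward
        # instead of storing its activations; this stand-in just delegates.
        return self.wrapped(*args)


def enable_checkpointing(blocks: list) -> list:
    """Wrap each block once; already-wrapped blocks are left as-is (idempotent)."""
    return [b if isinstance(b, CheckpointWrapper) else CheckpointWrapper(b) for b in blocks]


def disable_checkpointing(blocks: list) -> list:
    """Inverse operation: unwrap any checkpoint-wrapped blocks."""
    return [b.wrapped if isinstance(b, CheckpointWrapper) else b for b in blocks]


blocks = [lambda x: x + 1, lambda x: x * 2]
twice = enable_checkpointing(enable_checkpointing(blocks))
print(twice[1](5))                                     # 10
print(disable_checkpointing(twice)[0] is blocks[0])    # True
```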
Initialize model weights.
PP-aware: skips lm_head and mtp initialization when those have
been trimmed to None on a non-owning stage. self.model itself
also internally guards embed_tokens and norm.
Parameters:
Device to use for buffer initialization
Target dtype for model weights
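The PP-aware guards in this method and in the base model's initializer both skip submodules that were nulled on the current stage. Below is a sketch with SimpleNamespace stand-ins. The dotted paths and the helper names are assumptions. The path model.mtp comes from the FQN mentioned for the MTP head.

```python
from types import SimpleNamespace
from typing import Callable, List, Optional

# Submodules that a trimmed PP stage may have set to None (assumed paths).
GUARDED_PATHS = ("model.embed_tokens", "model.norm", "lm_head", "model.mtp")


def resolve(root: object, dotted: str) -> Optional[object]:
    """Follow a dotted attribute path, returning None if any hop is missing or None."""
    obj = root
    for part in dotted.split("."):
        obj = getattr(obj, part, None)
        if obj is None:
            return None
    return obj


def init_owned_submodules(model: object, init_fn: Callable[[object], None]) -> List[str]:
    """Apply init_fn to every guarded submodule present on this stage."""
    touched = []
    for path in GUARDED_PATHS:
        module = resolve(model, path)
        if module is None:  # trimmed away on a non-owning stage
            continue
        init_fn(module)
        touched.append(path)
    return touched


def mark(module: SimpleNamespace) -> None:
    module.initialized = True


last_stage = SimpleNamespace(
    lm_head=SimpleNamespace(),
    model=SimpleNamespace(embed_tokens=None, norm=SimpleNamespace(), mtp=SimpleNamespace()),
)
middle_stage = SimpleNamespace(lm_head=None, model=SimpleNamespace(embed_tokens=None, norm=None, mtp=None))

print(init_owned_submodules(last_stage, mark))    # ['model.norm', 'lm_head', 'model.mtp']
print(init_owned_submodules(middle_stage, mark))  # []
```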
Prepare model inputs for each generation step.
On the first call (prefill), creates a NemotronHybridCache and
forwards the full prompt. On subsequent calls (decode), only the newly
generated token is forwarded.
Parameters:
Accumulated token ids [batch_size, current_seq_len].
Padding mask [batch_size, current_seq_len].
Pre-computed embeddings for the first step (optional).
NemotronHybridCache from the previous step (None on first call).
Token position indices.
Whether to use caching (default True).
Remaining model kwargs.
Returns: dict
Dict of keyword arguments to pass to forward.
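The prefill/decode split can be sketched as follows. HybridCacheStub is a hypothetical placeholder for NemotronHybridCache. The sketch also assumes that each decode step appends exactly one token. Computing cache_position as the absolute indices of the forwarded tokens is an assumption borrowed from the usual transformers convention, not something this page specifies.

```python
from typing import Any, Dict, List, Optional


class HybridCacheStub:
    """Hypothetical placeholder for NemotronHybridCache."""


def prepare_inputs(
    input_ids: List[List[int]],
    past_key_values: Optional[HybridCacheStub] = None,
    use_cache: bool = True,
) -> Dict[str, Any]:
    seq_len = len(input_ids[0])
    if past_key_values is None:
        # Prefill (or no caching): forward the whole prompt.
        if use_cache:
            past_key_values = HybridCacheStub()
        step_ids = input_ids
        cache_position = list(range(seq_len))
    else:
        # Decode: earlier tokens are already in the cache; forward only the newest.
        step_ids = [row[-1:] for row in input_ids]
        cache_position = [seq_len - 1]
    return {
        "input_ids": step_ids,
        "past_key_values": past_key_values,
        "cache_position": cache_position,
        "use_cache": use_cache,
    }


first = prepare_inputs([[5, 6, 7]])
second = prepare_inputs([[5, 6, 7, 8]], past_key_values=first["past_key_values"])
print(second["input_ids"], second["cache_position"])  # [[8]] [3]
```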
Bases: Module
NemotronV3 base model (without LM head).
This is a hybrid architecture with Mamba2, Attention, MLP, and MoE layers.
Forward pass through the model. Supports BSHD [B, S, H] and THD [T, H].
Pipeline-parallel awareness: when self.embed_tokens is None (non-first
PP stage), input_ids is interpreted as the upstream hidden-state
tensor and routed through the inputs_embeds branch. When
self.norm is None (non-last PP stage), the final norm is skipped.
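Below is a sketch of the stage-aware routing, with callables standing in for modules. base_forward is a hypothetical name. Attention masks, caches, and THD inputs are omitted.

```python
from types import SimpleNamespace


def base_forward(stage, input_ids=None, inputs_embeds=None):
    """Run one (possibly trimmed) PP stage of a toy base model."""
    if stage.embed_tokens is None:
        # Non-first stage: the input_ids slot actually holds upstream hidden states.
        inputs_embeds, input_ids = input_ids, None
    if inputs_embeds is None:
        inputs_embeds = stage.embed_tokens(input_ids)
    hidden = inputs_embeds
    for layer in stage.layers:
        hidden = layer(hidden)
    if stage.norm is not None:  # only the last stage owns the final norm
        hidden = stage.norm(hidden)
    return hidden


def add_one(h):
    return [x + 1.0 for x in h]


first_stage = SimpleNamespace(embed_tokens=lambda ids: [float(i) for i in ids], layers=[add_one], norm=None)
last_stage = SimpleNamespace(embed_tokens=None, layers=[add_one], norm=lambda h: [x * 2.0 for x in h])

upstream = base_forward(first_stage, [1, 2])  # [2.0, 3.0]
print(base_forward(last_stage, upstream))     # [6.0, 8.0]
```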
Initialize model weights according to NemotronV3 spec.
After PP splitting, embed_tokens may be None on non-first
stages and norm may be None on non-last stages; guard each.
Parameters:
Device to use for buffer initialization