nemo_automodel.components.models.nemotron_v3.layers#

Module Contents#

Classes#

NemotronV3Attention

Multi-headed attention for NemotronV3 (Nano-v3).

NemotronV3MambaRMSNormGated

Gated RMSNorm for Mamba layers.

NemotronV3Mamba2Mixer

Mamba2 mixer for NemotronV3 (training-only, uses CUDA kernels).

NemotronV3Block

NemotronV3 decoder block (training-only, simplified).

API#

class nemo_automodel.components.models.nemotron_v3.layers.NemotronV3Attention(config)#

Bases: torch.nn.Module

Multi-headed attention for NemotronV3 (Nano-v3).

This is a standard GQA attention module following the NemotronH architecture. Uses PyTorch’s scaled_dot_product_attention (SDPA) for the attention computation. Note: RoPE is not applied in this module, matching the HF NemotronHAttention implementation.
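As a rough illustration of what GQA with SDPA looks like, here is a minimal sketch (shapes and group sizes are hypothetical, not taken from the NemotronV3 config): key/value heads are shared across groups of query heads by expanding them before calling `scaled_dot_product_attention`.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes illustrating grouped-query attention (GQA):
# 8 query heads share 2 key/value heads (group size 4).
batch, seq, n_q_heads, n_kv_heads, head_dim = 2, 16, 8, 2, 32

q = torch.randn(batch, n_q_heads, seq, head_dim)
k = torch.randn(batch, n_kv_heads, seq, head_dim)
v = torch.randn(batch, n_kv_heads, seq, head_dim)

# Expand KV heads so each group of query heads attends to its shared KV head.
repeat = n_q_heads // n_kv_heads
k = k.repeat_interleave(repeat, dim=1)
v = v.repeat_interleave(repeat, dim=1)

# SDPA applies the causal mask internally when is_causal=True.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 16, 32])
```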

Initialization

forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
past_key_values=None,
layer_idx: int | None = None,
) → torch.Tensor#

init_weights(
num_hidden_layers: int,
rescale_prenorm_residual: bool = True,
buffer_device: torch.device | None = None,
) → None#

Initialize attention weights following NemotronV3 spec.
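`rescale_prenorm_residual` usually refers to the GPT-2-style scheme of down-scaling residual-branch output projections by a factor that grows with depth, so activations on the residual stream do not blow up as layers accumulate. A minimal sketch of that scheme, assuming the common `1 / sqrt(2 * num_hidden_layers)` factor (the exact constant used here is not shown in this page):

```python
import math
import torch
import torch.nn as nn

def rescale_out_proj(linear: nn.Linear, num_hidden_layers: int) -> None:
    # Assumed GPT-2-style prenorm residual rescaling: shrink the
    # residual-branch output projection so the residual stream's
    # variance stays bounded as depth increases.
    with torch.no_grad():
        linear.weight.div_(math.sqrt(2 * num_hidden_layers))

proj = nn.Linear(64, 64)
before = proj.weight.norm().item()
rescale_out_proj(proj, num_hidden_layers=8)
after = proj.weight.norm().item()
print(after < before)  # True
```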

class nemo_automodel.components.models.nemotron_v3.layers.NemotronV3MambaRMSNormGated(
hidden_size: int,
group_size: int,
eps: float = 1e-05,
)#

Bases: torch.nn.Module

Gated RMSNorm for Mamba layers.

Uses the fused triton kernel from mamba_ssm for efficiency.

Initialization

forward(
hidden_states: torch.Tensor,
gate: torch.Tensor | None = None,
) → torch.Tensor#
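The fused triton kernel is opaque, so an unfused reference form can help. A common gated-RMSNorm formulation for Mamba2 (assumed here; the kernel's exact ordering may differ) applies the gate via SiLU first, then RMS-normalizes per group of `group_size` channels:

```python
import torch
import torch.nn.functional as F

def gated_rms_norm(x, gate, weight, group_size, eps=1e-5):
    """Unfused sketch: gate via SiLU, then group-wise RMS normalization."""
    if gate is not None:
        x = x * F.silu(gate)
    b, s, h = x.shape
    # Normalize each contiguous group of `group_size` channels independently.
    xg = x.view(b, s, h // group_size, group_size)
    var = xg.pow(2).mean(dim=-1, keepdim=True)
    xg = xg * torch.rsqrt(var + eps)
    return xg.view(b, s, h) * weight

hidden = 64
x = torch.randn(2, 8, hidden)
gate = torch.randn(2, 8, hidden)
weight = torch.ones(hidden)
y = gated_rms_norm(x, gate, weight, group_size=16)
print(y.shape)  # torch.Size([2, 8, 64])
```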
class nemo_automodel.components.models.nemotron_v3.layers.NemotronV3Mamba2Mixer(config, layer_idx: int)#

Bases: torch.nn.Module

Mamba2 mixer for NemotronV3 (training-only, uses CUDA kernels).

This implementation uses the fused mamba_split_conv1d_scan_combined kernel for maximum training efficiency. Does not support inference caching.

Requires mamba_ssm and causal_conv1d packages.
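Since both packages are optional, compiled dependencies, a guarded import is a reasonable pattern for checking availability up front (the kernel's module path below is an assumption based on typical `mamba_ssm` layouts; verify against your installed version):

```python
# Guarded imports: mamba_ssm and causal_conv1d are optional compiled
# dependencies, so record availability instead of failing at import time.
try:
    from mamba_ssm.ops.triton.ssd_combined import (  # noqa: F401
        mamba_split_conv1d_scan_combined,
    )
    import causal_conv1d  # noqa: F401
    HAVE_FUSED_KERNELS = True
except ImportError:
    HAVE_FUSED_KERNELS = False

print(type(HAVE_FUSED_KERNELS))  # <class 'bool'>
```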

Initialization

forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
past_key_values=None,
cache_position: torch.LongTensor | None = None,
) → torch.Tensor#

Forward pass with three code paths.

Path A (training): past_key_values is None → fully-fused kernel.
Path B (prefill): past_key_values present, seq_len > 1 → unfused scan + cache init.
Path C (decode): past_key_values present, seq_len == 1, has_previous_state → single-step update.

Parameters:
  • hidden_states – Input tensor of shape (batch, seq_len, hidden_size)

  • attention_mask – Optional attention mask (applied to padding)

  • past_key_values – Optional NemotronHybridCache instance.

  • cache_position – Token positions for cache updates.

Returns:

Output tensor of shape (batch, seq_len, hidden_size)
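The three-path dispatch can be summarized as a small selector function (names and the fallback for a single-token call without previous state are illustrative, not lifted from the implementation):

```python
def select_mamba_path(past_key_values, seq_len, has_previous_state=False):
    """Mirror the three-path dispatch described above (illustrative only)."""
    if past_key_values is None:
        # Path A: pure training, no cache -> fully-fused kernel.
        return "A"
    if seq_len > 1:
        # Path B: prefill -> unfused scan, initializes the cache.
        return "B"
    if has_previous_state:
        # Path C: decode -> single-step state update.
        return "C"
    # Assumed fallback: single token but no prior state is treated as prefill.
    return "B"

print(select_mamba_path(None, 128))  # A
```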

init_weights(
num_hidden_layers: int,
rescale_prenorm_residual: bool = True,
buffer_device: torch.device | None = None,
) → None#

Initialize Mamba2Mixer weights following NemotronV3 spec.

class nemo_automodel.components.models.nemotron_v3.layers.NemotronV3Block(config, layer_idx: int, moe_config=None, backend=None)#

Bases: torch.nn.Module

NemotronV3 decoder block (training-only, simplified).

Pre-norm architecture: norm → mixer → residual add.
Supports hybrid layer types: Mamba, Attention, MLP, MoE.
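The pre-norm residual pattern is compact enough to sketch directly. This toy version uses `nn.LayerNorm` as a stand-in for the block's RMSNorm and an arbitrary mixer, just to show the norm → mixer → residual-add wiring:

```python
import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    """Minimal pre-norm residual block: norm -> mixer -> residual add."""

    def __init__(self, hidden_size: int, mixer: nn.Module):
        super().__init__()
        # LayerNorm stands in for the RMSNorm used by the real block.
        self.norm = nn.LayerNorm(hidden_size)
        self.mixer = mixer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize first, mix, then add the unnormalized residual back.
        return x + self.mixer(self.norm(x))

block = PreNormBlock(32, nn.Linear(32, 32))
y_block = block(torch.randn(2, 4, 32))
print(y_block.shape)  # torch.Size([2, 4, 32])
```

The residual add on the raw input (rather than the normalized one) is what makes this "pre-norm": the residual stream bypasses both the norm and the mixer.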

Initialization

Initialize NemotronV3Block.

Parameters:
  • config – Model configuration with layers_block_type attribute

  • layer_idx – Index of this layer in the model

  • moe_config – MoE configuration (required for MoE layers)

  • backend – Backend configuration (optional)

property mlp#

Return mixer for MoE blocks for compatibility with parallelizer.

forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
past_key_values=None,
cache_position: torch.LongTensor | None = None,
) → torch.Tensor#

Forward pass through the block.

Parameters:
  • hidden_states – Input tensor of shape (batch, seq_len, hidden_size)

  • attention_mask –

    Mask tensor - type depends on layer:

    • For attention: 4D causal mask [batch, 1, seq_len, seq_len]

    • For mamba: 2D padding mask [batch, seq_len]

    • For mlp/moe: None

  • past_key_values – Optional NemotronHybridCache for KV/SSM caching.

  • cache_position – Token position indices for cache updates.

Returns:

Output tensor of shape (batch, seq_len, hidden_size)

init_weights(buffer_device: torch.device | None = None) → None#

Initialize block weights following NemotronV3 spec.

Parameters:

buffer_device – Device for buffer initialization (used by MLP/MoE)