nemo_automodel.components.models.nemotron_v3.layers
Module Contents
Classes
API
Bases: Module
GQA attention for NemotronV3 (no RoPE), compatible with TE/SDPA backends.
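As an illustration of the grouping idea behind GQA, the sketch below maps each query head to the key/value head it shares. It assumes the common convention of contiguous, equally sized groups; the function name and arguments are hypothetical and are not taken from this module.

```python
def kv_head_for_query_head(q_head: int, num_q_heads: int, num_kv_heads: int) -> int:
    """Return the index of the KV head shared by a given query head.

    Hypothetical helper: assumes contiguous groups of equal size, which is
    the usual GQA layout but is not confirmed for this module.
    """
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    if not 0 <= q_head < num_q_heads:
        raise IndexError("q_head out of range")
    group_size = num_q_heads // num_kv_heads  # query heads per KV head
    return q_head // group_size


# Eight query heads sharing two KV heads: four query heads per group.
mapping = [kv_head_for_query_head(h, 8, 2) for h in range(8)]
```

With `num_kv_heads == num_q_heads` this reduces to ordinary multi-head attention, and with `num_kv_heads == 1` to multi-query attention.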
Bases: Module
NemotronV3 decoder block (training-only, simplified).
Pre-norm architecture: norm → mixer → residual add.
Supports hybrid layer types: Mamba, Attention, MLP, MoE.
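The pre-norm pattern named above can be sketched on plain Python lists. This is a minimal illustration, not the block's implementation: the use of an unweighted RMS normalization and the names `rms_norm` and `pre_norm_block` are assumptions made for the example.

```python
import math
from typing import Callable, List

Vector = List[float]


def rms_norm(x: Vector, eps: float = 1e-5) -> Vector:
    """Unweighted RMS normalization (assumed here; the block's norm may differ)."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms for v in x]


def pre_norm_block(x: Vector, mixer: Callable[[Vector], Vector]) -> Vector:
    """norm -> mixer -> residual add, with the residual taken from the raw input."""
    mixed = mixer(rms_norm(x))
    return [xi + mi for xi, mi in zip(x, mixed)]


# A mixer that outputs zeros leaves the input unchanged through the residual path.
unchanged = pre_norm_block([1.0, 2.0], lambda h: [0.0 for _ in h])
```

The point of the pattern is that the residual stream itself is never normalized; only the mixer's input is, and the mixer can be any of the supported layer types.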
Map block_type to MoE parallelizer’s layer_type convention.
Return mixer for MoE blocks for compatibility with parallelizer.
Alias for mixer, for compatibility with MoE parallelizer.
Forward pass through the block.
Parameters:
Input tensor of shape (batch, seq_len, hidden_size)
Mask tensor - type depends on layer:
- For attention: 4D causal mask [batch, 1, seq_len, seq_len]
- For mamba: 2D padding mask [batch, seq_len]
- For mlp/moe: None
Optional NemotronHybridCache for KV/SSM caching.
Token position indices for cache updates.
Additional keyword arguments forwarded to attention layers only (e.g. cu_seqlens, cp_size, cp_rank for Context Parallelism).
Returns: torch.Tensor
Output tensor of shape (batch, seq_len, hidden_size)
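The per-layer mask convention described for the forward pass can be sketched as a small dispatch. The block-type strings, the function names, and the boolean "True means may attend" convention are assumptions for illustration; the actual mask dtype and values are not stated here.

```python
from typing import Any, List, Optional


def make_causal_mask(batch: int, seq_len: int) -> List[List[List[List[bool]]]]:
    """Build a nested-list mask shaped [batch, 1, seq_len, seq_len].

    Assumed convention: True where query position i may attend to key position j.
    """
    rows = [[j <= i for j in range(seq_len)] for i in range(seq_len)]
    return [[rows] for _ in range(batch)]


def select_mask(block_type: str, causal_mask_4d: Any, padding_mask_2d: Any) -> Optional[Any]:
    """Pick the mask a block expects, following the documented convention."""
    if block_type == "attention":
        return causal_mask_4d      # 4D causal mask
    if block_type == "mamba":
        return padding_mask_2d     # 2D padding mask [batch, seq_len]
    if block_type in ("mlp", "moe"):
        return None                # these layers take no mask
    raise ValueError(f"unknown block_type: {block_type!r}")


causal = make_causal_mask(batch=2, seq_len=3)
padding = [[True, True, False], [True, True, True]]
```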
Initialize block weights following NemotronV3 spec.
Parameters:
Device for buffer initialization (used by MLP/MoE)
Bases: Module
Mamba2 mixer for NemotronV3 (training-only, uses CUDA kernels).
This implementation uses the fused mamba_split_conv1d_scan_combined kernel for maximum training efficiency. It does not support inference caching.
Requires the mamba_ssm and causal_conv1d packages.
Forward pass with three code paths.
- Path A (training): past_key_values is None → fully-fused kernel.
- Path B (prefill): past_key_values present, seq_len > 1 → unfused scan + cache init.
- Path C (decode): past_key_values present, seq_len == 1, has_previous_state → single-step update.
Parameters:
Input tensor of shape (batch, seq_len, hidden_size)
Optional attention mask (applied to padding)
Optional NemotronHybridCache instance.
Token positions for cache updates.
Returns: torch.Tensor
Output tensor of shape (batch, seq_len, hidden_size)
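The selection among the three code paths described for this forward pass can be sketched as follows. The function name and return labels are hypothetical. The handling of a cache with seq_len == 1 but no previous state is not documented here, so the sketch treats that case as prefill, which is an assumption.

```python
from typing import Any, Optional


def select_mamba_path(past_key_values: Optional[Any], seq_len: int,
                      has_previous_state: bool = False) -> str:
    """Return which documented path would run: 'training', 'prefill', or 'decode'."""
    if past_key_values is None:
        return "training"   # Path A: fully-fused kernel
    if seq_len == 1 and has_previous_state:
        return "decode"     # Path C: single-step update
    # Path B: unfused scan + cache init.
    # ASSUMPTION: seq_len == 1 without a previous state also lands here.
    return "prefill"


cache = object()  # stand-in for a NemotronHybridCache instance
paths = [
    select_mamba_path(None, seq_len=128),
    select_mamba_path(cache, seq_len=128),
    select_mamba_path(cache, seq_len=1, has_previous_state=True),
]
```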
Initialize Mamba2Mixer weights following NemotronV3 spec.
Bases: Module
Gated RMSNorm for Mamba layers.
Uses the fused Triton kernel from mamba_ssm for efficiency.
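For intuition about what a gated RMSNorm computes, here is an unfused pure-Python reference. It assumes the gate is applied before normalization, y = RMSNorm(x · SiLU(z)) · weight, which is one common Mamba2 convention; the ordering, the `eps` default, and the function names are assumptions, and the fused kernel may differ (for example, by normalizing per group).

```python
import math
from typing import List


def silu(v: float) -> float:
    """SiLU activation: v * sigmoid(v)."""
    return v / (1.0 + math.exp(-v))


def gated_rms_norm(x: List[float], z: List[float], weight: List[float],
                   eps: float = 1e-5) -> List[float]:
    """Reference gated RMSNorm over one vector (gate first, then normalize)."""
    gated = [xi * silu(zi) for xi, zi in zip(x, z)]      # apply the gate
    mean_sq = sum(g * g for g in gated) / len(gated)     # mean of squares
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [g * inv_rms * w for g, w in zip(gated, weight)]


out = gated_rms_norm(
    x=[1.0, 2.0, 3.0, 4.0],
    z=[0.5, -1.0, 2.0, 0.0],
    weight=[1.0, 1.0, 1.0, 1.0],
)
```

With unit weights the output has a mean square close to 1, and any position whose gate input is 0 is zeroed out because SiLU(0) = 0.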