nemo_automodel.components.models.step3p5.layers#
Module Contents#
Classes#
| Class | Description |
|---|---|
| Step3p5RMSNorm | RMSNorm with (weight + 1) scaling used by Step3p5. |
| Step3p5RotaryEmbedding | Rotary embedding for Step3p5 with per-layer theta and partial rotary factor support. |
| Step3p5MLP | Step3p5 MLP with SwiGLU activation and optional clamping. |
| Step3p5Attention | Step3p5 attention with Q/K per-head RMSNorm, optional head-wise gate, and alternating attention patterns. |
API#
- class nemo_automodel.components.models.step3p5.layers.Step3p5RMSNorm(hidden_size: int, eps: float = 1e-05)#
Bases: torch.nn.Module

RMSNorm with (weight + 1) scaling used by Step3p5.

Unlike standard RMSNorm, which computes x_normed * weight, Step3p5 computes x_normed * (weight + 1). The weight is initialized to zeros, so the scaling factor starts at 1.

Note: TE's fused RMSNorm cannot be used here because the (weight + 1) adjustment cannot be intercepted.
Initialization
- reset_parameters() → None#
Reset parameters to initial state (zeros).
- forward(x: torch.Tensor) → torch.Tensor#
- class nemo_automodel.components.models.step3p5.layers.Step3p5RotaryEmbedding(config: Any, layer_idx: int)#
Bases: torch.nn.Module

Rotary embedding for Step3p5 with per-layer theta and partial rotary factor support.
Initialization
- _compute_inv_freq() → torch.Tensor#
Compute inverse frequencies for rotary embeddings.
- forward(
- x: torch.Tensor,
- position_ids: torch.Tensor,
) → tuple[torch.Tensor, torch.Tensor]#
Compute cos and sin for rotary embeddings.
- Parameters:
x – Input tensor (used for dtype and device).
position_ids – Position indices [batch_size, seq_len].
- Returns:
Tuple of (cos, sin) tensors.
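A rough sketch of how a per-layer theta and a partial rotary factor interact (the helper names and exact broadcasting here are assumptions, not this module's internals):

```python
import torch


def compute_inv_freq(head_dim: int, rope_theta: float, partial_rotary_factor: float) -> torch.Tensor:
    # With a partial rotary factor, only the first rotary_dim dimensions of
    # each head are rotated; the remainder pass through unchanged.
    rotary_dim = int(head_dim * partial_rotary_factor)
    return 1.0 / (rope_theta ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))


def rope_cos_sin(inv_freq: torch.Tensor, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # position_ids: [batch_size, seq_len] -> freqs: [batch_size, seq_len, rotary_dim // 2]
    freqs = position_ids[..., None].to(torch.float32) * inv_freq
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()
```

Because theta is taken per layer, each layer can build its own inv_freq table rather than sharing one module-wide.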
- class nemo_automodel.components.models.step3p5.layers.Step3p5MLP(
- config: Any,
- backend: nemo_automodel.components.models.common.BackendConfig,
- intermediate_size: int | None = None,
- swiglu_limit: float | None = None,
)#
Bases: torch.nn.Module

Step3p5 MLP with SwiGLU activation and optional clamping.
Initialization
- forward(x: torch.Tensor) → torch.Tensor#
- init_weights(
- buffer_device: torch.device,
- init_std: float = 0.02,
)#
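A minimal sketch of SwiGLU with optional clamping, assuming one common clamping convention; swiglu_limit mirrors the constructor argument above:

```python
import torch
import torch.nn.functional as F


def swiglu(gate: torch.Tensor, up: torch.Tensor, swiglu_limit: float | None = None) -> torch.Tensor:
    if swiglu_limit is not None:
        # Assumed convention: cap the gate pre-activation from above and
        # clamp the linear branch symmetrically before combining.
        gate = gate.clamp(max=swiglu_limit)
        up = up.clamp(min=-swiglu_limit, max=swiglu_limit)
    return F.silu(gate) * up
```

Clamping bounds the pre-activation magnitudes, which helps keep the product numerically stable in low-precision training.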
- class nemo_automodel.components.models.step3p5.layers.Step3p5Attention(
- config: Any,
- layer_idx: int,
- backend: nemo_automodel.components.models.common.BackendConfig,
)#
Bases: torch.nn.Module

Step3p5 attention with Q/K per-head RMSNorm, optional head-wise gate, and alternating attention patterns.
Key features:

- Q/K per-head normalization using Step3p5RMSNorm
- Optional head-wise attention gate (g_proj + sigmoid); see the sketch after this list
- Per-layer RoPE theta and partial_rotary_factors
- Sliding window attention selected by the layer_types config
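A hedged sketch of the head-wise gate (g_proj comes from the feature list above; the shapes and wiring are assumptions, not this class's exact implementation):

```python
import torch
import torch.nn as nn


class HeadwiseGate(nn.Module):
    """Illustrative head-wise attention gate; names and shapes are assumptions."""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        # One scalar gate per head, computed from the layer input.
        self.g_proj = nn.Linear(hidden_size, num_heads, bias=False)

    def forward(self, x: torch.Tensor, attn_out: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq, hidden]; attn_out: [batch, seq, num_heads, head_dim]
        gate = torch.sigmoid(self.g_proj(x))   # [batch, seq, num_heads]
        return attn_out * gate.unsqueeze(-1)   # scale each head's output
```

The sigmoid keeps each gate in (0, 1), letting the model softly suppress individual heads per token.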
Initialization
- forward(
- x: torch.Tensor,
- *,
- freqs_cis: torch.Tensor,
- attention_mask: torch.Tensor | None = None,
- position_ids: torch.Tensor | None = None,
- **attn_kwargs: Any,
)#
- init_weights(
- buffer_device: torch.device,
- init_std: float = 0.02,
)#