nemo_automodel.components.models.step3p5.layers#

Module Contents#

Classes#

Step3p5RMSNorm

RMSNorm with (weight + 1) scaling used by Step3p5.

Step3p5RotaryEmbedding

Rotary embedding for Step3p5 with per-layer theta and partial rotary factor support.

Step3p5MLP

Step3p5 MLP with SwiGLU activation and optional clamping.

Step3p5Attention

Step3p5 attention with Q/K per-head RMSNorm, optional head-wise gate, and alternating attention patterns.

API#

class nemo_automodel.components.models.step3p5.layers.Step3p5RMSNorm(hidden_size: int, eps: float = 1e-05)#

Bases: torch.nn.Module

RMSNorm with (weight + 1) scaling used by Step3p5.

Unlike standard RMSNorm, which scales as x_normed * weight, Step3p5 scales as x_normed * (weight + 1). The weight is initialized to zeros, so the effective scaling factor starts at 1.

Note: Transformer Engine's (TE) fused RMSNorm cannot be used here because the (weight + 1) adjustment cannot be intercepted inside the fused kernel.
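As an illustration, a minimal sketch of the (weight + 1) scaling described above; the class name and the float32 upcast are assumptions for clarity, not taken from the library source:

```python
import torch
from torch import nn

class RMSNormPlusOne(nn.Module):
    """Sketch of RMSNorm with (weight + 1) scaling; illustrative only."""

    def __init__(self, hidden_size: int, eps: float = 1e-5):
        super().__init__()
        # Zero-initialized weight => effective scale (weight + 1) starts at 1.
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype
        x = x.float()  # normalize in float32 for stability (assumed convention)
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x * (self.weight.float() + 1.0)).to(dtype)
```

Because the parameter is zero-initialized, resetting it to zeros restores the identity scale, which is what reset_parameters() below does.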

Initialization

reset_parameters() → None#

Reset parameters to initial state (zeros).

forward(x: torch.Tensor) → torch.Tensor#

class nemo_automodel.components.models.step3p5.layers.Step3p5RotaryEmbedding(config: Any, layer_idx: int)#

Bases: torch.nn.Module

Rotary embedding for Step3p5 with per-layer theta and partial rotary factor support.

Initialization

_compute_inv_freq() → torch.Tensor#

Compute inverse frequencies for rotary embeddings.

forward(
x: torch.Tensor,
position_ids: torch.Tensor,
) → tuple[torch.Tensor, torch.Tensor]#

Compute cos and sin for rotary embeddings.

Parameters:
  • x – Input tensor (used for dtype and device).

  • position_ids – Position indices [batch_size, seq_len].

Returns:

Tuple of (cos, sin) tensors.
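For intuition, here is a hedged sketch of how a per-layer theta and a partial rotary factor interact when building the cos/sin tables; the helper name and the cat-style frequency duplication are assumptions, not the library's actual code:

```python
import torch

def rope_cos_sin_sketch(
    position_ids: torch.Tensor,  # [batch_size, seq_len]
    head_dim: int,
    rope_theta: float,  # Step3p5 allows this base to vary per layer
    partial_rotary_factor: float = 1.0,
    dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: cos/sin tables with a partial rotary factor."""
    # With a factor < 1, only the first rotary_dim channels of each head
    # are rotated; the remaining channels pass through unrotated.
    rotary_dim = int(head_dim * partial_rotary_factor)
    inv_freq = 1.0 / (
        rope_theta ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim)
    )
    # Angles: [batch, seq_len, rotary_dim // 2], then duplicated (cat-style).
    angles = position_ids[..., None].float() * inv_freq
    emb = torch.cat((angles, angles), dim=-1)
    return emb.cos().to(dtype), emb.sin().to(dtype)
```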

class nemo_automodel.components.models.step3p5.layers.Step3p5MLP(
config: Any,
backend: nemo_automodel.components.models.common.BackendConfig,
intermediate_size: int | None = None,
swiglu_limit: float | None = None,
)#

Bases: torch.nn.Module

Step3p5 MLP with SwiGLU activation and optional clamping.
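A minimal sketch of SwiGLU with an optional clamp, assuming the common pattern in which swiglu_limit bounds the pre-activation values; the projection names and exact clamping points are assumptions, so consult the source for Step3p5's actual scheme:

```python
import torch
import torch.nn.functional as F
from torch import nn

class SwiGLUWithLimit(nn.Module):
    """Sketch of a SwiGLU MLP with optional clamping; illustrative only."""

    def __init__(self, hidden_size: int, intermediate_size: int,
                 swiglu_limit: float | None = None):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.swiglu_limit = swiglu_limit

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_proj(x), self.up_proj(x)
        if self.swiglu_limit is not None:
            # Assumed clamp scheme: cap the gate from above, bound up both ways.
            gate = gate.clamp(max=self.swiglu_limit)
            up = up.clamp(min=-self.swiglu_limit, max=self.swiglu_limit)
        return self.down_proj(F.silu(gate) * up)
```

Bounding the pre-activations keeps the product F.silu(gate) * up from growing without limit, which helps keep low-precision training numerically stable.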

Initialization

forward(x: torch.Tensor) → torch.Tensor#

init_weights(
buffer_device: torch.device,
init_std: float = 0.02,
) → None#

class nemo_automodel.components.models.step3p5.layers.Step3p5Attention(
config: Any,
layer_idx: int,
backend: nemo_automodel.components.models.common.BackendConfig,
)#

Bases: torch.nn.Module

Step3p5 attention with Q/K per-head RMSNorm, optional head-wise gate, and alternating attention patterns.

Key features (see the sketch after this list):

  • Q/K per-head normalization using Step3p5RMSNorm

  • Optional head-wise attention gate (g_proj + sigmoid)

  • Per-layer RoPE theta and partial_rotary_factors

  • Per-layer sliding-window attention, selected by the layer_types config
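A rough sketch of how the per-head norms and the optional gate compose inside a forward pass; the tensor shapes, module names, and the exact point where the gate is applied are assumptions rather than Step3p5's actual wiring, and the sliding-window logic is not shown:

```python
import torch
from torch import nn

def attention_norm_and_gate_sketch(
    q: torch.Tensor,          # [batch, seq, n_heads, head_dim]
    k: torch.Tensor,          # [batch, seq, n_kv_heads, head_dim]
    attn_out: torch.Tensor,   # [batch, seq, n_heads, head_dim]
    q_norm: nn.Module,        # per-head RMSNorm over head_dim
    k_norm: nn.Module,
    g_proj: nn.Module | None = None,  # hypothetical gate projection
    x: torch.Tensor | None = None,    # [batch, seq, hidden] layer input
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Hypothetical flow: per-head Q/K norm, then a sigmoid head gate."""
    # Per-head normalization acts on the trailing head_dim axis of each head.
    q, k = q_norm(q), k_norm(k)
    if g_proj is not None and x is not None:
        # Head-wise gate: sigmoid(g_proj(x)) yields one scale per head,
        # assuming g_proj maps hidden -> n_heads.
        batch, seq, n_heads, _ = attn_out.shape
        gate = torch.sigmoid(g_proj(x)).view(batch, seq, n_heads, 1)
        attn_out = attn_out * gate
    return q, k, attn_out
```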

Initialization

forward(
x: torch.Tensor,
*,
freqs_cis: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
**attn_kwargs: Any,
) → torch.Tensor#

init_weights(
buffer_device: torch.device,
init_std: float = 0.02,
) → None#