nemo_automodel.components.models.step3p5.layers

Module Contents

Classes

Name | Description
Step3p5Attention | Step3p5 attention with Q/K per-head RMSNorm, optional head-wise gate, and alternating attention patterns.
Step3p5MLP | Step3p5 MLP with SwiGLU activation and optional clamping.
Step3p5RMSNorm | RMSNorm with (weight + 1) scaling used by Step3p5.
Step3p5RotaryEmbedding | Rotary embedding for Step3p5 with per-layer theta and partial rotary factor support.

API

class nemo_automodel.components.models.step3p5.layers.Step3p5Attention(
config: typing.Any,
layer_idx: int,
backend: nemo_automodel.components.models.common.BackendConfig
)

Bases: Module

Step3p5 attention with Q/K per-head RMSNorm, optional head-wise gate, and alternating attention patterns.

Key features:

  • Q/K per-head normalization using Step3p5RMSNorm
  • Optional head-wise attention gate (g_proj + sigmoid)
  • Per-layer RoPE theta and partial_rotary_factors
  • Sliding window based on layer_types config
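
The head-wise gate listed above can be pictured with a small, dependency-free sketch. It assumes that g_proj produces one logit per attention head for each token and that the sigmoid of that logit scales the head's attention output; the function name `apply_head_wise_gate` and the list-based shapes are hypothetical, and the real module operates on torch tensors.

```python
import math


def sigmoid(z: float) -> float:
    """Logistic function, mapping any real logit into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def apply_head_wise_gate(head_outputs, gate_logits):
    """Scale each head's output vector by sigmoid(gate_logit) for that head.

    head_outputs: list of num_heads vectors, each of length head_dim.
    gate_logits:  list of num_heads floats (assumed to come from g_proj
                  for a single token; this is an assumption of the sketch).
    """
    if len(head_outputs) != len(gate_logits):
        raise ValueError("expected one gate logit per attention head")
    return [
        [sigmoid(g) * v for v in head]
        for head, g in zip(head_outputs, gate_logits)
    ]


# A logit of 0 gives a gate of 0.5, so every head output is halved.
gated = apply_head_wise_gate([[2.0, 4.0], [1.0, 1.0]], [0.0, 0.0])
```

Because the gate is a single scalar per head, it can damp or pass an entire head without changing the direction of that head's output vector.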
g_proj
head_dim
k_norm
k_proj
num_heads
num_kv_groups = self.num_heads // self.num_kv_heads
num_kv_heads
o_proj
q_norm
q_proj
rotary_emb = Step3p5RotaryEmbedding(config, layer_idx)
sliding_window = config.sliding_window
use_head_wise_attn_gate = getattr(config, 'use_head_wise_attn_gate', False)
use_rope = use_rope_layers[layer_idx]
v_proj
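
To illustrate what num_kv_groups represents, here is a minimal sketch of the usual grouped-query convention, in which consecutive query heads share one key/value head. The helper name `kv_head_for_query_head` is hypothetical, and the exact sharing layout used inside Step3p5Attention is assumed rather than confirmed by this page.

```python
def kv_head_for_query_head(q_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Return the index of the K/V head that a given query head reads from.

    Assumes the common grouped-query layout: query heads are split into
    num_kv_heads contiguous groups of size num_kv_groups.
    """
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    num_kv_groups = num_heads // num_kv_heads  # same formula as the attribute
    return q_head // num_kv_groups


# With 8 query heads and 2 K/V heads, each K/V head serves 4 query heads.
mapping = [kv_head_for_query_head(h, num_heads=8, num_kv_heads=2) for h in range(8)]
```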
nemo_automodel.components.models.step3p5.layers.Step3p5Attention.forward(
x: torch.Tensor,
freqs_cis: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
attn_kwargs: typing.Any = {}
) -> torch.Tensor
nemo_automodel.components.models.step3p5.layers.Step3p5Attention.init_weights(
buffer_device: torch.device,
init_std: float = 0.02
) -> None
class nemo_automodel.components.models.step3p5.layers.Step3p5MLP(
config: typing.Any,
backend: nemo_automodel.components.models.common.BackendConfig,
intermediate_size: int | None = None,
swiglu_limit: float | None = None
)

Bases: Module

Step3p5 MLP with SwiGLU activation and optional clamping.
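
As a rough illustration of SwiGLU with optional clamping, the sketch below computes the element-wise hidden activation that would be fed to down_proj. Where the clamp is applied is an assumption here (the activated gate is capped from above and the up branch is clamped to a symmetric range); the function name `swiglu_hidden` is hypothetical, and plain Python lists stand in for the outputs of gate_proj and up_proj.

```python
import math


def silu(z: float) -> float:
    """SiLU / swish activation: z * sigmoid(z)."""
    return z / (1.0 + math.exp(-z))


def swiglu_hidden(gate, up, swiglu_limit=None):
    """Element-wise silu(gate) * up, optionally clamped by swiglu_limit.

    gate, up: equal-length lists standing in for gate_proj(x) and up_proj(x).
    The clamp placement below is an assumption of this sketch.
    """
    out = []
    for g, u in zip(gate, up):
        activated = silu(g)
        if swiglu_limit is not None:
            activated = min(activated, swiglu_limit)           # cap gate branch
            u = max(-swiglu_limit, min(u, swiglu_limit))       # clamp up branch
        out.append(activated * u)
    return out


unclamped = swiglu_hidden([100.0], [10.0])
clamped = swiglu_hidden([100.0], [10.0], swiglu_limit=7.0)
```

With a limit of 7.0, both branches in the example saturate at 7.0, so the product is bounded regardless of how large the pre-activations grow.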

down_proj
gate_proj
hidden_size = config.hidden_size
intermediate_size = intermediate_size or config.intermediate_size
up_proj
nemo_automodel.components.models.step3p5.layers.Step3p5MLP.forward(
x: torch.Tensor
) -> torch.Tensor
nemo_automodel.components.models.step3p5.layers.Step3p5MLP.init_weights(
buffer_device: torch.device,
init_std: float = 0.02
) -> None
class nemo_automodel.components.models.step3p5.layers.Step3p5RMSNorm(
hidden_size: int,
eps: float = 1e-05
)

Bases: Module

RMSNorm with (weight + 1) scaling used by Step3p5.

Unlike standard RMSNorm, which computes x_normed * weight, Step3p5 computes x_normed * (weight + 1). The weight is initialized to zeros, so the initial scaling factor is 1.

Note: Transformer Engine's (TE) fused RMSNorm cannot be used here because the (weight + 1) adjustment cannot be intercepted.
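
The (weight + 1) scaling can be written out in a few lines. This is a minimal sketch on a single vector using plain Python; the function name `rmsnorm_plus_one` is hypothetical, and the placement of eps inside the square root follows the common RMSNorm convention, which is assumed rather than taken from this page.

```python
import math


def rmsnorm_plus_one(x, weight, eps=1e-5):
    """Normalize x by its root mean square, then scale by (weight + 1).

    x, weight: equal-length lists of floats (one hidden vector and its
    learned per-channel weight).
    """
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)  # eps placement is an assumption
    return [v * inv_rms * (w + 1.0) for v, w in zip(x, weight)]


x = [3.0, 4.0]
at_init = rmsnorm_plus_one(x, [0.0, 0.0])   # zero weight: scale factor is 1
doubled = rmsnorm_plus_one(x, [1.0, 1.0])   # weight of 1: scale factor is 2
```

With the weight at its zero initialization, the output is just the normalized input, which is why a freshly reset layer behaves like an unscaled RMSNorm.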

weight = nn.Parameter(torch.zeros(hidden_size))
nemo_automodel.components.models.step3p5.layers.Step3p5RMSNorm.forward(
x: torch.Tensor
) -> torch.Tensor
nemo_automodel.components.models.step3p5.layers.Step3p5RMSNorm.reset_parameters() -> None

Reset parameters to initial state (zeros).

class nemo_automodel.components.models.step3p5.layers.Step3p5RotaryEmbedding(
config: typing.Any,
layer_idx: int
)

Bases: Module

Rotary embedding for Step3p5 with per-layer theta and partial rotary factor support.

base = rope_theta[layer_idx]
head_dim
max_position_embeddings = config.max_position_embeddings
partial_rotary_factor = partial_rotary_factors[layer_idx]
rotary_dim = int(self.head_dim * self.partial_rotary_factor)
nemo_automodel.components.models.step3p5.layers.Step3p5RotaryEmbedding._apply(
fn
)
nemo_automodel.components.models.step3p5.layers.Step3p5RotaryEmbedding._compute_inv_freq(
device: torch.device | None = None
) -> torch.Tensor

Compute inverse frequencies for rotary embeddings.
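
For orientation, the standard RoPE inverse-frequency formula, restricted to the rotary_dim slice of each head, looks roughly as follows. This is a sketch under the assumption that _compute_inv_freq follows the usual formula 1 / base^(i / rotary_dim) for even i; the standalone function `compute_inv_freq` and its arguments are hypothetical stand-ins for the module's per-layer attributes.

```python
def compute_inv_freq(head_dim: int, partial_rotary_factor: float, base: float):
    """Return one inverse frequency per rotated channel pair.

    Only the first rotary_dim channels of each head are rotated, so the
    result has rotary_dim // 2 entries.
    """
    rotary_dim = int(head_dim * partial_rotary_factor)  # same formula as the attribute
    return [1.0 / (base ** (i / rotary_dim)) for i in range(0, rotary_dim, 2)]


# head_dim=8 with a factor of 0.5 rotates 4 channels, i.e. 2 frequency pairs.
inv_freq = compute_inv_freq(head_dim=8, partial_rotary_factor=0.5, base=10000.0)
```

Since base and partial_rotary_factor are both indexed by layer_idx, different layers can end up with different frequency tables and different numbers of rotated channels.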

nemo_automodel.components.models.step3p5.layers.Step3p5RotaryEmbedding.forward(
x: torch.Tensor,
position_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]

Compute cos and sin for rotary embeddings.

Parameters:

x
torch.Tensor

Input tensor (used for dtype and device).

position_ids
torch.Tensor

Position indices [batch_size, seq_len].

Returns: tuple[torch.Tensor, torch.Tensor]

Tuple of (cos, sin) tensors.
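
A minimal sketch of how (cos, sin) tables can be derived from position_ids is shown below. It assumes the common layout in which the per-pair angles are duplicated along the last dimension so that the tables span rotary_dim channels; the function `rotary_cos_sin` is hypothetical, and nested Python lists stand in for tensors of shape [batch_size, seq_len, rotary_dim].

```python
import math


def rotary_cos_sin(position_ids, inv_freq):
    """Build cos and sin tables for every (batch, position) entry.

    position_ids: nested list of shape [batch_size][seq_len].
    inv_freq:     list of rotary_dim // 2 inverse frequencies.
    Returns (cos, sin), each of shape [batch_size][seq_len][rotary_dim].
    """
    cos, sin = [], []
    for row in position_ids:
        cos_row, sin_row = [], []
        for pos in row:
            half = [pos * f for f in inv_freq]
            angles = half + half  # duplicated-halves layout (an assumption)
            cos_row.append([math.cos(a) for a in angles])
            sin_row.append([math.sin(a) for a in angles])
        cos.append(cos_row)
        sin.append(sin_row)
    return cos, sin


cos, sin = rotary_cos_sin([[0, 1]], [1.0, 0.01])
```

At position 0 every angle is zero, so the rotation is the identity (cos of 1, sin of 0) for all rotated channels.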