nemo_automodel.components.models.step3p5.model


Module Contents

Classes

| Name | Description |
|---|---|
| Block | Step3p5 transformer block with attention, MLP/MoE, and shared experts. |
| Step3p5ForCausalLM | Step3p5 model for causal language modeling. |
| Step3p5Model | Step3p5 transformer model. |

Functions

| Name | Description |
|---|---|
| _keep_step_router_bias_fp32 | Keep Step router correction bias in fp32 after module-wide dtype casts. |
| parse_moe_layers_enum | Parse moe_layers_enum into the set of MoE layer indices. |

Data

ModelClass

API

class nemo_automodel.components.models.step3p5.model.Block(
layer_idx: int,
config: typing.Any,
moe_config: nemo_automodel.components.moe.config.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig
)

Bases: Module

Step3p5 transformer block with attention, MLP/MoE, and shared experts.

attention_type
input_layernorm
is_moe_layer = layer_idx in moe_layers
moe = MoE(layer_moe_config, backend)
post_attention_layernorm
self_attn = Step3p5Attention(config, layer_idx, backend)
share_expert
nemo_automodel.components.models.step3p5.model.Block.forward(
x: torch.Tensor,
freqs_cis: torch.Tensor,
attention_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
**attn_kwargs: typing.Any
) -> torch.Tensor
nemo_automodel.components.models.step3p5.model.Block.init_weights(
buffer_device: torch.device
) -> None
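The attribute names listed for Block (input_layernorm, self_attn, post_attention_layernorm, moe, share_expert) suggest a pre-norm residual layout. The sketch below shows that wiring with plain Python callables standing in for the submodules and scalars standing in for tensors. It is an assumption, not the confirmed Block.forward: in particular, summing the shared-expert output with the MLP/MoE output is a guess, and the helper name block_forward_sketch is hypothetical.

```python
from typing import Callable, Optional

Fn = Callable[[float], float]


def block_forward_sketch(
    x: float,
    input_layernorm: Fn,
    self_attn: Fn,
    post_attention_layernorm: Fn,
    mlp_or_moe: Fn,
    share_expert: Optional[Fn] = None,
) -> float:
    """Hypothetical pre-norm residual block; scalars stand in for tensors."""
    # Attention sub-layer: normalise, attend, then add the residual.
    h = x + self_attn(input_layernorm(x))
    # Feed-forward sub-layer runs on the post-attention norm of h.
    normed = post_attention_layernorm(h)
    ffn_out = mlp_or_moe(normed)
    # Assumption: the shared expert sees the same normed input and its
    # output is summed with the routed MLP/MoE output.
    if share_expert is not None:
        ffn_out = ffn_out + share_expert(normed)
    return h + ffn_out


identity: Fn = lambda v: v
out = block_forward_sketch(
    1.0,
    input_layernorm=identity,
    self_attn=lambda v: 2 * v,
    post_attention_layernorm=identity,
    mlp_or_moe=lambda v: v + 1,
    share_expert=lambda v: 10 * v,
)
print(out)  # 37.0
```

Keeping both norms outside the residual path is what makes this "pre-norm": the skip connection carries the un-normalised activations forward.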
class nemo_automodel.components.models.step3p5.model.Step3p5ForCausalLM(
config: typing.Any,
moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
backend: nemo_automodel.components.models.common.BackendConfig | None = None,
**kwargs
)

Bases: HFCheckpointingMixin, Module, MoEFSDPSyncMixin

Step3p5 model for causal language modeling.

_keep_in_fp32_modules = ['rotary_emb']
backend = backend or BackendConfig()
lm_head
model
state_dict_adapter
nemo_automodel.components.models.step3p5.model.Step3p5ForCausalLM.forward(
input_ids: torch.Tensor,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
logits_to_keep: typing.Union[int, torch.Tensor] = 0,
output_hidden_states: typing.Optional[bool] = None,
**attn_kwargs: typing.Any
) -> transformers.modeling_outputs.CausalLMOutputWithPast

Forward pass returning a transformers.modeling_outputs.CausalLMOutputWithPast.

Supports both BSHD format (input_ids of shape [B, S]) and THD format (input_ids of shape [1, T]). When attn_kwargs["qkv_format"] == "thd", the inputs are squeezed to THD layout before the base-model forward pass, and on exit the logits (and the final hidden_states) are unsqueezed back to restore the leading batch dimension.
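A shape-only sketch of the BSHD/THD contract described above, using tuples instead of tensors. The function name thd_io_shapes and its vocab_size argument are hypothetical; the sketch assumes logits are computed for every position (logits_to_keep=0) and ignores any padding handling the real squeeze helper may perform.

```python
def thd_io_shapes(input_shape, qkv_format, vocab_size):
    """Return (base_model_input_shape, logits_shape) for a hypothetical call."""
    if qkv_format == "thd":
        batch, tokens = input_shape
        if batch != 1:
            raise ValueError("THD inputs are packed as [1, T]")
        # The leading batch dim is squeezed away before the base model...
        base_input = (tokens,)
        # ...and restored on the logits on the way out.
        logits = (1, tokens, vocab_size)
    else:
        # BSHD: shapes pass through with the batch dim intact.
        batch, seq = input_shape
        base_input = (batch, seq)
        logits = (batch, seq, vocab_size)
    return base_input, logits


print(thd_io_shapes((1, 10), "thd", 32))   # ((10,), (1, 10, 32))
print(thd_io_shapes((2, 5), "bshd", 32))   # ((2, 5), (2, 5, 32))
```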

Parameters:

input_ids
torch.Tensor

Input token IDs.

position_ids
torch.Tensor | None, defaults to None

Optional position indices.

attention_mask
torch.Tensor | None, defaults to None

Optional 2D padding mask.

padding_mask
torch.Tensor | None, defaults to None

Optional padding mask, used by the THD squeeze helper.

logits_to_keep
Union[int, torch.Tensor], defaults to 0

If 0 (the default), logits are computed for every position. If > 0 (or a tensor), logits are computed only for the last logits_to_keep positions, which avoids materialising the full logit matrix during generation or fused cross-entropy.

output_hidden_states
Optional[bool], defaults to None

Whether to include the final hidden states in the output.

**attn_kwargs
Any, defaults to {}

Additional keyword arguments forwarded to the base model.
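The logits_to_keep rule can be illustrated with the slicing idiom used in Hugging Face transformers models: slice(-0, None) equals slice(0, None), so 0 keeps every position. This is a sketch on a plain list standing in for the sequence axis. That Step3p5ForCausalLM uses the same idiom, and that a non-int value is treated as explicit positions to keep (the transformers convention), are assumptions; keep_positions is a hypothetical name.

```python
from typing import Sequence, Union


def keep_positions(hidden: Sequence, logits_to_keep: Union[int, Sequence[int]] = 0) -> list:
    """Select the positions whose logits would be computed (hypothetical helper)."""
    if isinstance(logits_to_keep, int):
        # -0 == 0, so logits_to_keep=0 yields slice(0, None): all positions.
        return list(hidden[slice(-logits_to_keep, None)])
    # Assumption: a non-int value lists the positions to keep explicitly.
    return [hidden[i] for i in logits_to_keep]


seq = ["t0", "t1", "t2", "t3"]
print(keep_positions(seq))          # ['t0', 't1', 't2', 't3']
print(keep_positions(seq, 1))       # ['t3']
print(keep_positions(seq, [0, 2]))  # ['t0', 't2']
```

During generation only the last position is needed, which is why logits_to_keep=1 is the common non-default value.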

Returns: CausalLMOutputWithPast

A transformers.modeling_outputs.CausalLMOutputWithPast carrying the logits and, when output_hidden_states is set, the final hidden states.

nemo_automodel.components.models.step3p5.model.Step3p5ForCausalLM.from_config(
config: typing.Any,
moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
backend: nemo_automodel.components.models.common.BackendConfig | None = None,
**kwargs
)
classmethod
nemo_automodel.components.models.step3p5.model.Step3p5ForCausalLM.from_pretrained(
pretrained_model_name_or_path: str,
*model_args,
**kwargs
)
classmethod
nemo_automodel.components.models.step3p5.model.Step3p5ForCausalLM.get_input_embeddings()
nemo_automodel.components.models.step3p5.model.Step3p5ForCausalLM.get_output_embeddings()
nemo_automodel.components.models.step3p5.model.Step3p5ForCausalLM.initialize_weights(
buffer_device: torch.device | None = None,
dtype: torch.dtype = torch.bfloat16
) -> None
nemo_automodel.components.models.step3p5.model.Step3p5ForCausalLM.set_input_embeddings(
value
)
nemo_automodel.components.models.step3p5.model.Step3p5ForCausalLM.set_output_embeddings(
new_embeddings
)
class nemo_automodel.components.models.step3p5.model.Step3p5Model(
config: typing.Any,
backend: nemo_automodel.components.models.common.BackendConfig,
moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
moe_overrides: dict | None = None
)

Bases: Module

Step3p5 transformer model.

embed_tokens
has_sliding_layers = 'sliding_attention' in layer_types
head_dim
layers = torch.nn.ModuleDict()
max_seq_len = config.max_position_embeddings
moe_config = moe_config or MoEConfig(**moe_defaults)
norm
rotary_emb
nemo_automodel.components.models.step3p5.model.Step3p5Model._apply(
fn
)
nemo_automodel.components.models.step3p5.model.Step3p5Model.forward(
input_ids: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
padding_mask: torch.Tensor | None = None,
**attn_kwargs: typing.Any
) -> torch.Tensor
nemo_automodel.components.models.step3p5.model.Step3p5Model.init_weights(
buffer_device: torch.device | None = None
) -> None
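The attribute values listed for Step3p5Model and Block (has_sliding_layers = 'sliding_attention' in layer_types, is_moe_layer = layer_idx in moe_layers, layers = torch.nn.ModuleDict()) imply a per-layer plan like the one sketched below. Reading each layer's attention type from config.layer_types by index is an assumption, as are the helper name layer_plan and the "full_attention" example value; string keys are used because torch.nn.ModuleDict only accepts string keys.

```python
def layer_plan(layer_types, moe_layers):
    """Map str(layer_idx) -> (attention_type, is_moe_layer); hypothetical helper."""
    plan = {}
    for layer_idx, attention_type in enumerate(layer_types):
        # Mirrors Block.is_moe_layer = layer_idx in moe_layers.
        plan[str(layer_idx)] = (attention_type, layer_idx in moe_layers)
    # Mirrors Step3p5Model.has_sliding_layers.
    has_sliding_layers = "sliding_attention" in layer_types
    return plan, has_sliding_layers


plan, has_sliding = layer_plan(
    ["full_attention", "sliding_attention", "sliding_attention"], moe_layers={1, 2}
)
print(has_sliding)  # True
print(plan["0"])    # ('full_attention', False)
print(plan["1"])    # ('sliding_attention', True)
```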
nemo_automodel.components.models.step3p5.model._keep_step_router_bias_fp32(
module: torch.nn.Module
) -> None

Keep Step router correction bias in fp32 after module-wide dtype casts.

nemo_automodel.components.models.step3p5.model.parse_moe_layers_enum(
moe_layers_enum: str | int | tuple | list | None,
num_hidden_layers: int
) -> set[int]

Parse moe_layers_enum into the set of MoE layer indices.

Parameters:

moe_layers_enum
str | int | tuple | list | None

A tuple or list of layer indices, an integer, a comma-separated string, or None. The HF Step-3.5-Flash config uses the tuple format, e.g. (3, 4, 5, …, 44).

num_hidden_layers
int

Total number of hidden layers.

Returns: set[int]

Set of layer indices that should be MoE layers.
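A minimal sketch of how the documented input forms could be normalised into a set. The int and None behaviours are assumptions (an int is read as a single layer index; None as "every layer is MoE"), as is the range check against num_hidden_layers; consult the real function before relying on them. The name parse_moe_layers_sketch is hypothetical.

```python
def parse_moe_layers_sketch(moe_layers_enum, num_hidden_layers):
    """Normalise moe_layers_enum into a set of MoE layer indices (sketch)."""
    if moe_layers_enum is None:
        # Assumption: no explicit list means every layer is an MoE layer.
        return set(range(num_hidden_layers))
    if isinstance(moe_layers_enum, int):
        # Assumption: a bare int names a single MoE layer.
        layers = {moe_layers_enum}
    elif isinstance(moe_layers_enum, str):
        # Comma-separated string, e.g. "3,4,5"; int() tolerates surrounding spaces.
        layers = {int(tok) for tok in moe_layers_enum.split(",") if tok.strip()}
    elif isinstance(moe_layers_enum, (tuple, list)):
        # Tuple/list form, as used by the HF Step-3.5-Flash config.
        layers = {int(i) for i in moe_layers_enum}
    else:
        raise TypeError(f"unsupported moe_layers_enum type: {type(moe_layers_enum).__name__}")
    # Assumption: indices outside [0, num_hidden_layers) are rejected.
    bad = sorted(i for i in layers if not 0 <= i < num_hidden_layers)
    if bad:
        raise ValueError(f"layer indices out of range: {bad}")
    return layers


print(sorted(parse_moe_layers_sketch((3, 4, 5), 6)))  # [3, 4, 5]
print(sorted(parse_moe_layers_sketch("3, 4,5", 6)))   # [3, 4, 5]
print(sorted(parse_moe_layers_sketch(None, 3)))       # [0, 1, 2]
```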

nemo_automodel.components.models.step3p5.model.ModelClass = Step3p5ForCausalLM