nemo_automodel.components.models.step3p7.mtp
Step3 Multi-Token Prediction (MTP) blocks.
Step checkpoints store MTP depths after the main decoder layers as
model.layers.{num_hidden_layers + depth}.*. Each depth has the same
decoder block structure plus fusion modules (enorm, hnorm, eh_proj)
and an MTP-local shared head under transformer.shared_head.
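The key layout above can be illustrated with a minimal sketch. The helpers `split_mtp_key` and `mtp_key` below are hypothetical and not part of this module. They map a checkpoint key to its MTP depth and back, assuming `num_hidden_layers` is known and that per-layer keys follow exactly the `model.layers.{num_hidden_layers + depth}.*` pattern. The example key suffixes are illustrative only.

```python
from __future__ import annotations

import re

# Per-layer keys look like "model.layers.<idx>.<rest>" (assumed exact prefix).
_LAYER_KEY = re.compile(r"^model\.layers\.(\d+)\.(.+)$")


def split_mtp_key(key: str, num_hidden_layers: int) -> tuple[int, str] | None:
    """Hypothetical helper: return (depth, rest) if `key` belongs to an MTP depth, else None."""
    match = _LAYER_KEY.match(key)
    if match is None:
        return None  # not a per-layer key, e.g. embeddings or the final norm
    layer_idx = int(match.group(1))
    if layer_idx < num_hidden_layers:
        return None  # an ordinary decoder layer
    return layer_idx - num_hidden_layers, match.group(2)


def mtp_key(depth: int, rest: str, num_hidden_layers: int) -> str:
    """Hypothetical inverse mapping: the checkpoint key for `rest` at MTP `depth`."""
    return f"model.layers.{num_hidden_layers + depth}.{rest}"


if __name__ == "__main__":
    n = 4  # toy num_hidden_layers
    print(split_mtp_key("model.layers.3.mlp.gate_proj.weight", n))  # None
    print(split_mtp_key("model.layers.5.eh_proj.weight", n))  # (1, 'eh_proj.weight')
    print(mtp_key(0, "enorm.weight", n))  # model.layers.4.enorm.weight
```

Because depth indices are offset by `num_hidden_layers`, the same key pattern serves both the main decoder and the MTP depths. The only difference is whether the layer index falls below or above that boundary.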
Module Contents
Classes
- Per-depth Step MTP prediction head.
- Step3p5MTPBlock: One Step MTP prediction depth.
- Step3p5MTPModule: Stack of Step MTP depths.
Functions
- _make_mtp_block_config: Return a shallow config copy patched for a dense sliding-attention MTP layer.
- build_mtp_config_from_hf: Build the Step MTP runtime config from HF-style config fields.
- build_step3p5_mtp: Construct Step MTP depths.
API
- nemo_automodel.components.models.step3p7.mtp._get_indexed_value(values: Any, index: int, default: Any)
- nemo_automodel.components.models.step3p7.mtp._ensure_indexed(values: Any, index: int, value: Any)
- nemo_automodel.components.models.step3p7.mtp._make_mtp_block_config(config: Any, layer_idx: int, depth: int)
Return a shallow config copy patched for a dense sliding-attention MTP layer.
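The patching idea can be sketched in plain Python. This is an illustration under stated assumptions, not the module's implementation. The per-layer fields `layer_types` and `moe_layer_flags` and the helpers `get_indexed`, `with_index_set`, and `make_mtp_layer_config` are hypothetical stand-ins, loosely modeled on what the names `_get_indexed_value` and `_ensure_indexed` suggest. The key point is that a shallow copy shares nested lists with the original, so patched lists are rebuilt rather than mutated in place.

```python
from __future__ import annotations

import copy
from types import SimpleNamespace
from typing import Any


def get_indexed(values: Any, index: int, default: Any) -> Any:
    """Hypothetical: return values[index] for a long-enough list/tuple, else `default`."""
    if isinstance(values, (list, tuple)) and 0 <= index < len(values):
        return values[index]
    return default


def with_index_set(values: Any, index: int, value: Any) -> list:
    """Hypothetical: return a NEW list extended with `value` up to `index`, with out[index] = value."""
    out = list(values) if isinstance(values, (list, tuple)) else []
    if len(out) <= index:
        # Indices past the old end are also MTP layers in this layout, so pad with `value`.
        out.extend([value] * (index + 1 - len(out)))
    out[index] = value
    return out


def make_mtp_layer_config(config: Any, layer_idx: int) -> Any:
    """Shallow-copy `config` and mark `layer_idx` as a dense sliding-attention layer."""
    patched = copy.copy(config)  # shallow: nested lists are still shared with `config`
    # Reassign fresh lists so the original config's lists are left untouched.
    patched.layer_types = with_index_set(
        getattr(config, "layer_types", None), layer_idx, "sliding_attention"
    )
    patched.moe_layer_flags = with_index_set(
        getattr(config, "moe_layer_flags", None), layer_idx, False
    )
    return patched


if __name__ == "__main__":
    base = SimpleNamespace(layer_types=["full_attention"] * 2, moe_layer_flags=[True, True])
    mtp = make_mtp_layer_config(base, layer_idx=2)  # first MTP depth when num_hidden_layers == 2
    print(mtp.layer_types, mtp.moe_layer_flags)
    print(base.layer_types)  # unchanged
```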
- config: Any,
- backend: nemo_automodel.components.models.common.BackendConfig,
- dtype: torch.dtype,
Bases: torch.nn.Module
Per-depth Step MTP prediction head.
- class nemo_automodel.components.models.step3p7.mtp.Step3p5MTPBlock(config: Any, layer_idx: int, depth: int, moe_config: nemo_automodel.components.moe.config.MoEConfig, backend: nemo_automodel.components.models.common.BackendConfig, dtype: torch.dtype)
Bases: nemo_automodel.components.models.step3p5.model.Block
One Step MTP prediction depth.
- forward(hidden_states: torch.Tensor, *, embed_input: torch.Tensor, freqs_cis: torch.Tensor, attention_mask: torch.Tensor | None = None, padding_mask: torch.Tensor | None = None, position_ids: torch.Tensor | None = None, **attn_kwargs: Any)
- init_weights(buffer_device: torch.device) -> None
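The fusion step of one depth can be sketched per token position with plain Python lists. `enorm` normalizes the token embedding, `hnorm` normalizes the incoming hidden state, the two are concatenated, and `eh_proj` maps the 2·d vector back to width d before the result enters the decoder block. The sketch assumes RMSNorm for both norms, a bias-free `eh_proj`, and embedding-first concatenation as in DeepSeek-V3-style MTP. This page does not confirm any of these details for Step, and the function names are hypothetical.

```python
from __future__ import annotations

import math


def rms_norm(x: list[float], weight: list[float], eps: float = 1e-6) -> list[float]:
    """RMSNorm: scale x by 1/sqrt(mean(x^2) + eps), then by a per-channel weight."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for v, w in zip(x, weight)]


def linear(x: list[float], weight: list[list[float]]) -> list[float]:
    """Bias-free y = W x, with W stored as (out_features, in_features) like torch.nn.Linear."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in weight]


def mtp_fuse(
    hidden: list[float],
    embed: list[float],
    enorm_w: list[float],
    hnorm_w: list[float],
    eh_proj_w: list[list[float]],
) -> list[float]:
    """Fuse one position: eh_proj(concat(enorm(embed), hnorm(hidden)))."""
    fused = rms_norm(embed, enorm_w) + rms_norm(hidden, hnorm_w)  # list concat: width 2*d
    return linear(fused, eh_proj_w)  # project back to width d


if __name__ == "__main__":
    d = 2
    ones = [1.0] * d
    # Toy eh_proj that picks embed[0] and hidden[0] out of the concatenated vector.
    pick = [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]
    print(mtp_fuse([3.0, 4.0], [1.0, 1.0], ones, ones, pick))
```

Normalizing both inputs before the projection keeps the embedding and the (typically larger-magnitude) hidden state on a comparable scale, so `eh_proj` does not have to learn that rebalancing itself.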
- class nemo_automodel.components.models.step3p7.mtp.Step3p5MTPModule(config: Any, mtp_config: nemo_automodel.components.models.common.mtp.MTPConfig, backend: nemo_automodel.components.models.common.BackendConfig, moe_config: nemo_automodel.components.moe.config.MoEConfig, dtype: torch.dtype)
Bases: torch.nn.Module
Stack of Step MTP depths.
- property num_depths: int
- forward(hidden_states: torch.Tensor, *, freqs_cis: torch.Tensor, input_ids: torch.LongTensor | None = None, embed_fn=None, embed_inputs: tuple[torch.Tensor, ...] | list[torch.Tensor] | None = None, position_ids: torch.LongTensor | None = None, **block_kwargs: Any)
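Chaining across depths can be sketched as follows. The sketch assumes the DeepSeek-V3-style convention in which depth `d` (0-indexed) embeds the input tokens shifted left by `d + 1` and consumes the previous depth's output hidden states. `shift_left`, `run_mtp_depths`, and the toy block and embedding callables are hypothetical. The signature above also suggests that precomputed per-depth embeddings can be supplied through `embed_inputs` instead of being derived from `input_ids` with `embed_fn`.

```python
from __future__ import annotations

from typing import Callable, Sequence


def shift_left(ids: Sequence[int], k: int, pad_id: int = 0) -> list[int]:
    """Drop the first k tokens and right-pad with pad_id so the length is unchanged."""
    return list(ids[k:]) + [pad_id] * min(k, len(ids))


def run_mtp_depths(
    hidden: list[float],
    input_ids: Sequence[int],
    depths: Sequence[Callable[[list[float], list[float]], list[float]]],
    embed_fn: Callable[[list[int]], list[float]],
) -> list[list[float]]:
    """Run depths in order; depth d sees tokens shifted by d + 1 and depth d-1's output."""
    outputs = []
    for d, block in enumerate(depths):
        embed_input = embed_fn(shift_left(input_ids, d + 1))
        hidden = block(hidden, embed_input)  # this output feeds the next depth
        outputs.append(hidden)
    return outputs


if __name__ == "__main__":
    # Toy setup: one scalar "hidden state" per position, embeddings = token ids.
    def toy_block(h, e):
        return [a + b for a, b in zip(h, e)]

    def toy_embed(ids):
        return [float(i) for i in ids]

    print(run_mtp_depths([0.0] * 4, [10, 11, 12, 13], [toy_block, toy_block], toy_embed))
```

Returning one output per depth (rather than only the last) lets each depth be scored against its own, further-shifted target.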
- nemo_automodel.components.models.step3p7.mtp.build_mtp_config_from_hf(config: Any, *, loss_scaling_factor: float = 0.1)
Build Step MTP runtime config from HF-style config fields.
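A sketch of how such a builder and the `loss_scaling_factor` default of 0.1 might fit together. `ToyMTPConfig` stands in for `MTPConfig`, whose fields are not listed on this page. The HF field name `num_nextn_predict_layers` is an assumption borrowed from DeepSeek-V3 configs. The mean-over-depths weighting follows the DeepSeek-V3 MTP loss and is not a confirmed Step recipe.

```python
from __future__ import annotations

from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any, Sequence


@dataclass
class ToyMTPConfig:
    """Hypothetical stand-in for MTPConfig; the real field names may differ."""

    num_layers: int
    loss_scaling_factor: float


def toy_build_mtp_config(config: Any, loss_scaling_factor: float = 0.1) -> ToyMTPConfig:
    # Assumed HF field name (DeepSeek-V3 convention); missing or None -> MTP disabled.
    num_layers = int(getattr(config, "num_nextn_predict_layers", 0) or 0)
    return ToyMTPConfig(num_layers=num_layers, loss_scaling_factor=loss_scaling_factor)


def combine_losses(main_loss: float, mtp_losses: Sequence[float], cfg: ToyMTPConfig) -> float:
    """main + factor * mean(per-depth MTP losses), a DeepSeek-V3-style weighting."""
    if not mtp_losses:
        return main_loss
    return main_loss + cfg.loss_scaling_factor * sum(mtp_losses) / len(mtp_losses)


if __name__ == "__main__":
    cfg = toy_build_mtp_config(SimpleNamespace(num_nextn_predict_layers=2))
    print(cfg)
    print(combine_losses(2.0, [1.0, 3.0], cfg))  # about 2.2
```

A small factor such as 0.1 keeps the auxiliary MTP objective from dominating the main next-token loss.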
- nemo_automodel.components.models.step3p7.mtp.build_step3p5_mtp(config: Any, mtp_config: nemo_automodel.components.models.common.mtp.MTPConfig, backend: nemo_automodel.components.models.common.BackendConfig, moe_config: nemo_automodel.components.moe.config.MoEConfig, dtype: torch.dtype)
Construct Step MTP depths.