# nemo_automodel.components.models.step3p7.mtp

Step3 Multi-Token Prediction blocks.

Step checkpoints store the MTP depths after the main decoder layers, under
`model.layers.{num_hidden_layers + depth}.*`. Each depth has the same
decoder block structure as a main decoder layer, plus fusion modules (`enorm`, `hnorm`, `eh_proj`)
and an MTP-local shared head under `transformer.shared_head`.
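As an illustration of this layout, the hypothetical helper below (`split_mtp_keys` is not part of `nemo_automodel`) separates main-decoder keys from MTP keys and groups the MTP keys by depth. It assumes `depth` is 0-based, so the first MTP depth sits at index `num_hidden_layers`, directly after the last main decoder layer. The key names in the usage example are illustrative, not taken from a real checkpoint.

```python
import re
from collections import defaultdict

# Hypothetical helper (not part of nemo_automodel). Assumes the layout
# model.layers.{num_hidden_layers + depth}.* with a 0-based depth.
_LAYER_KEY = re.compile(r"^model\.layers\.(\d+)\.(.+)$")


def split_mtp_keys(keys, num_hidden_layers):
    main_keys = []
    mtp_keys = defaultdict(list)  # depth -> key suffixes after the layer index
    for key in keys:
        match = _LAYER_KEY.match(key)
        if match is None:
            main_keys.append(key)  # embeddings, final norm, lm_head, ...
            continue
        layer_idx = int(match.group(1))
        if layer_idx < num_hidden_layers:
            main_keys.append(key)  # a regular decoder layer
        else:
            # Indices past the main stack belong to MTP depths.
            mtp_keys[layer_idx - num_hidden_layers].append(match.group(2))
    return main_keys, dict(mtp_keys)


# Illustrative usage with made-up key names and a 2-layer main stack.
keys = [
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.2.enorm.weight",
    "model.layers.3.eh_proj.weight",
]
main, mtp = split_mtp_keys(keys, num_hidden_layers=2)
# mtp == {0: ["enorm.weight"], 1: ["eh_proj.weight"]}
```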

## Module Contents

### Classes

| Name                                                                                         | Description                         |
| -------------------------------------------------------------------------------------------- | ----------------------------------- |
| [`Step3p5MTPBlock`](#nemo_automodel-components-models-step3p7-mtp-Step3p5MTPBlock)           | One Step MTP prediction depth.      |
| [`Step3p5MTPModule`](#nemo_automodel-components-models-step3p7-mtp-Step3p5MTPModule)         | Stack of Step MTP depths.           |
| [`Step3p5MTPSharedHead`](#nemo_automodel-components-models-step3p7-mtp-Step3p5MTPSharedHead) | Per-depth Step MTP prediction head. |

### Functions

| Name                                                                                                 | Description                                                                   |
| ---------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------- |
| [`_ensure_indexed`](#nemo_automodel-components-models-step3p7-mtp-_ensure_indexed)                   | -                                                                             |
| [`_get_indexed_value`](#nemo_automodel-components-models-step3p7-mtp-_get_indexed_value)             | -                                                                             |
| [`_make_mtp_block_config`](#nemo_automodel-components-models-step3p7-mtp-_make_mtp_block_config)     | Return a shallow config copy patched for a dense sliding-attention MTP layer. |
| [`build_mtp_config_from_hf`](#nemo_automodel-components-models-step3p7-mtp-build_mtp_config_from_hf) | Build Step MTP runtime config from HF-style config fields.                    |
| [`build_step3p5_mtp`](#nemo_automodel-components-models-step3p7-mtp-build_step3p5_mtp)               | Construct Step MTP depths.                                                    |

### API

```python
class nemo_automodel.components.models.step3p7.mtp.Step3p5MTPBlock(
    config: typing.Any,
    layer_idx: int,
    depth: int,
    moe_config: nemo_automodel.components.moe.config.MoEConfig,
    backend: nemo_automodel.components.models.common.BackendConfig,
    dtype: torch.dtype
)
```

**Bases:** [Block](/nemo-automodel/nemo_automodel/components/models/step3p5/model#nemo_automodel-components-models-step3p5-model-Block)

One Step MTP prediction depth.

```python
nemo_automodel.components.models.step3p7.mtp.Step3p5MTPBlock.forward(
    hidden_states: torch.Tensor,
    embed_input: torch.Tensor,
    freqs_cis: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    padding_mask: torch.Tensor | None = None,
    position_ids: torch.Tensor | None = None,
    attn_kwargs: typing.Any = {}
) -> tuple[torch.Tensor, torch.Tensor]
```
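The fusion step that runs inside a depth before its decoder block can be sketched as follows, assuming the DeepSeek-V3-style MTP recipe: `enorm` normalizes the embedding input (`embed_input`), `hnorm` normalizes the incoming `hidden_states`, and `eh_proj` maps their concatenation back to the hidden size. The RMSNorm choice and the [embedding, hidden] concatenation order are assumptions, the function names are hypothetical, and plain Python lists stand in for tensors.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    # RMSNorm: divide by the root-mean-square of x, then apply a per-channel weight.
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


def linear(x, weight):
    # Bias-free projection; weight is laid out [out_features][in_features] like nn.Linear.
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def fuse_mtp_inputs(hidden, embed, enorm_w, hnorm_w, eh_proj_w):
    # Assumed order: [enorm(embedding), hnorm(hidden)] gives 2*d features,
    # and eh_proj projects them back to d features for the decoder block.
    concat = rms_norm(embed, enorm_w) + rms_norm(hidden, hnorm_w)
    return linear(concat, eh_proj_w)
```

Normalizing the two streams separately before the projection keeps one input from dominating simply because its activations have a larger scale.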

```python
nemo_automodel.components.models.step3p7.mtp.Step3p5MTPBlock.init_weights(
    buffer_device: torch.device
) -> None
```

```python
class nemo_automodel.components.models.step3p7.mtp.Step3p5MTPModule(
    config: typing.Any,
    mtp_config: nemo_automodel.components.models.common.mtp.MTPConfig,
    backend: nemo_automodel.components.models.common.BackendConfig,
    moe_config: nemo_automodel.components.moe.config.MoEConfig,
    dtype: torch.dtype
)
```

**Bases:** `Module`

Stack of Step MTP depths.

```python
nemo_automodel.components.models.step3p7.mtp.Step3p5MTPModule.forward(
    hidden_states: torch.Tensor,
    freqs_cis: torch.Tensor,
    input_ids: torch.LongTensor | None = None,
    embed_fn = None,
    embed_inputs: tuple[torch.Tensor, ...] | list[torch.Tensor] | None = None,
    position_ids: torch.LongTensor | None = None,
    block_kwargs: typing.Any = {}
) -> list[torch.Tensor]
```
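`Step3p5MTPModule.forward` accepts either `input_ids` with an `embed_fn` or precomputed `embed_inputs`, and returns a list of tensors, presumably one per depth. Each depth predicts one token further ahead than the previous one. Under the DeepSeek-V3-style alignment, which is assumed here and not confirmed for the Step implementation, depth `d` (0-based) pairs position `i` with the embedding of token `i + d + 1` and is trained to predict token `i + d + 2`. The hypothetical helper below builds those per-depth embedding inputs and labels for one sequence; the fill values `pad_token_id=0` and `ignore_index=-100` are assumptions.

```python
def shift_left(ids, k, fill):
    # Drop the first k tokens and right-pad with `fill` so the length is unchanged.
    k = min(k, len(ids))
    return ids[k:] + [fill] * k


def mtp_inputs_and_labels(input_ids, num_depths, pad_token_id=0, ignore_index=-100):
    # Hypothetical helper: one (embed_ids, labels) pair per MTP depth.
    per_depth = []
    for d in range(num_depths):
        embed_ids = shift_left(input_ids, d + 1, pad_token_id)  # token i + d + 1
        labels = shift_left(input_ids, d + 2, ignore_index)  # target i + d + 2
        per_depth.append((embed_ids, labels))
    return per_depth
```

Padding the labels with an ignore index means positions whose target falls past the end of the sequence contribute nothing to the MTP loss.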

```python
class nemo_automodel.components.models.step3p7.mtp.Step3p5MTPSharedHead(
    config: typing.Any,
    backend: nemo_automodel.components.models.common.BackendConfig,
    dtype: torch.dtype
)
```

**Bases:** `Module`

Per-depth Step MTP prediction head.

```python
nemo_automodel.components.models.step3p7.mtp.Step3p5MTPSharedHead.forward(
    hidden_states: torch.Tensor
) -> torch.Tensor
```

```python
nemo_automodel.components.models.step3p7.mtp.Step3p5MTPSharedHead.init_weights(
    buffer_device: torch.device
) -> None
```

```python
nemo_automodel.components.models.step3p7.mtp._ensure_indexed(
    values: typing.Any,
    index: int,
    value: typing.Any
) -> list[typing.Any]
```

```python
nemo_automodel.components.models.step3p7.mtp._get_indexed_value(
    values: typing.Any,
    index: int,
    default: typing.Any
) -> typing.Any
```

```python
nemo_automodel.components.models.step3p7.mtp._make_mtp_block_config(
    config: typing.Any,
    layer_idx: int,
    depth: int
) -> typing.Any
```

Return a shallow config copy patched for a dense sliding-attention MTP layer.
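The docstring implies two things: the copy is shallow, and layer-indexed fields must be patched so that the MTP layer index reads as dense (no MoE) with sliding-window attention. A minimal sketch of that pattern follows, together with plausible semantics for the undocumented `_ensure_indexed` and `_get_indexed_value` helpers. All functions here are hypothetical re-implementations, the field names `layer_types` and `moe_layers` are assumptions rather than the real Step config schema, and the real `_make_mtp_block_config` also takes a `depth` argument that this sketch ignores.

```python
import copy


def get_indexed_value(values, index, default):
    # Hypothetical: a per-layer field may be a list/tuple or a single scalar.
    if isinstance(values, (list, tuple)):
        return values[index] if 0 <= index < len(values) else default
    return default if values is None else values


def ensure_indexed(values, index, value):
    # Hypothetical: return a NEW list long enough to hold `index`, with
    # out[index] = value. Gaps are filled with `value` as well.
    out = list(values) if isinstance(values, (list, tuple)) else []
    if len(out) <= index:
        out.extend([value] * (index + 1 - len(out)))
    out[index] = value
    return out


def make_mtp_block_config(config, layer_idx):
    # Shallow copy: attributes not reassigned below are shared with `config`.
    block_config = copy.copy(config)
    # Assumed field names: sliding-window attention and a dense (non-MoE) MLP.
    block_config.layer_types = ensure_indexed(config.layer_types, layer_idx, "sliding_attention")
    block_config.moe_layers = ensure_indexed(config.moe_layers, layer_idx, False)
    return block_config
```

Returning a fresh list from `ensure_indexed` is what makes the shallow copy safe: the patched config gets its own per-layer lists, while every other attribute is still shared with the main model's config.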

```python
nemo_automodel.components.models.step3p7.mtp.build_mtp_config_from_hf(
    config: typing.Any,
    loss_scaling_factor: float = 0.1
) -> nemo_automodel.components.models.common.mtp.MTPConfig
```

Build Step MTP runtime config from HF-style config fields.
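`build_mtp_config_from_hf` reads MTP settings from an HF-style config and carries a `loss_scaling_factor` (default `0.1`). The sketch below is hypothetical: `SketchMTPConfig` stands in for `MTPConfig`, whose real fields are not documented here, and `num_nextn_predict_layers` is the DeepSeek-V3 field name, assumed rather than confirmed for Step configs. It also shows one common way such a factor is applied, the DeepSeek-V3 weighting of the scaling factor times the mean of the per-depth MTP losses; whether NeMo AutoModel averages or sums over depths is an assumption.

```python
from dataclasses import dataclass


@dataclass
class SketchMTPConfig:
    # Hypothetical stand-in for MTPConfig; field names are assumed.
    num_layers: int
    loss_scaling_factor: float


def build_mtp_config_sketch(hf_config, loss_scaling_factor=0.1):
    # Assumed HF field name (DeepSeek-V3 convention); missing or None means no MTP depths.
    num_layers = int(getattr(hf_config, "num_nextn_predict_layers", 0) or 0)
    return SketchMTPConfig(num_layers=num_layers, loss_scaling_factor=loss_scaling_factor)


def combine_losses(main_loss, mtp_losses, cfg):
    # DeepSeek-V3-style weighting: scaling factor times the mean per-depth MTP loss.
    if cfg.num_layers == 0 or not mtp_losses:
        return main_loss
    return main_loss + cfg.loss_scaling_factor * sum(mtp_losses) / len(mtp_losses)
```

Averaging over depths keeps the MTP contribution on the same scale regardless of how many depths are configured, so a single factor such as `0.1` stays meaningful.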

```python
nemo_automodel.components.models.step3p7.mtp.build_step3p5_mtp(
    config: typing.Any,
    mtp_config: nemo_automodel.components.models.common.mtp.MTPConfig,
    backend: nemo_automodel.components.models.common.BackendConfig,
    moe_config: nemo_automodel.components.moe.config.MoEConfig,
    dtype: torch.dtype
) -> nemo_automodel.components.models.step3p7.mtp.Step3p5MTPModule
```

Construct Step MTP depths.