
# nemo_automodel.components.models.mimo_v2_flash.model

## Module Contents

### Classes

| Name                                                                                                             | Description                                                             |
| ---------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------- |
| [`MiMoV2FlashAttention`](#nemo_automodel-components-models-mimo_v2_flash-model-MiMoV2FlashAttention)             | MiMo-V2-Flash attention with full and sliding-window variants.          |
| [`MiMoV2FlashBlock`](#nemo_automodel-components-models-mimo_v2_flash-model-MiMoV2FlashBlock)                     | Decoder block that alternates dense MLP and routed-MoE layers.          |
| [`MiMoV2FlashForCausalLM`](#nemo_automodel-components-models-mimo_v2_flash-model-MiMoV2FlashForCausalLM)         | Causal LM wrapper for MiMo-V2-Flash with Automodel checkpoint adapters. |
| [`MiMoV2FlashModel`](#nemo_automodel-components-models-mimo_v2_flash-model-MiMoV2FlashModel)                     | Backbone model for Xiaomi MiMo-V2-Flash.                                |
| [`MiMoV2FlashRotaryEmbedding`](#nemo_automodel-components-models-mimo_v2_flash-model-MiMoV2FlashRotaryEmbedding) | Rotary embedding module matching MiMo-V2-Flash partial-RoPE behavior.   |
| [`MiMoV2RMSNorm`](#nemo_automodel-components-models-mimo_v2_flash-model-MiMoV2RMSNorm)                           | RMSNorm used by MiMo-V2-Flash decoder blocks.                           |

### Functions

| Name                                                                                                                           | Description |
| ------------------------------------------------------------------------------------------------------------------------------ | ----------- |
| [`_apply_rotary_pos_emb`](#nemo_automodel-components-models-mimo_v2_flash-model-_apply_rotary_pos_emb)                         | -           |
| [`_convert_bool_4d_mask_to_additive`](#nemo_automodel-components-models-mimo_v2_flash-model-_convert_bool_4d_mask_to_additive) | -           |
| [`_derive_padding_mask`](#nemo_automodel-components-models-mimo_v2_flash-model-_derive_padding_mask)                           | -           |
| [`_eager_attention_forward`](#nemo_automodel-components-models-mimo_v2_flash-model-_eager_attention_forward)                   | -           |
| [`_ensure_additive_mask`](#nemo_automodel-components-models-mimo_v2_flash-model-_ensure_additive_mask)                         | -           |
| [`_fallback_additive_mask`](#nemo_automodel-components-models-mimo_v2_flash-model-_fallback_additive_mask)                     | -           |
| [`_repeat_kv`](#nemo_automodel-components-models-mimo_v2_flash-model-_repeat_kv)                                               | -           |
| [`_rotate_half`](#nemo_automodel-components-models-mimo_v2_flash-model-_rotate_half)                                           | -           |

### Data

[`ModelClass`](#nemo_automodel-components-models-mimo_v2_flash-model-ModelClass)

### API

```python
class nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashAttention(
    config: nemo_automodel.components.models.mimo_v2_flash.config.MiMoV2FlashConfig,
    backend: nemo_automodel.components.models.common.BackendConfig,
    is_swa: bool,
    layer_idx: int
)
```

**Bases:** `Module`

MiMo-V2-Flash attention with full and sliding-window variants.

```python
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashAttention.forward(
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: torch.Tensor | None = None,
    **kwargs: typing.Any
) -> tuple[torch.Tensor, torch.Tensor]
```

```python
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashAttention.init_weights(
    buffer_device: torch.device,
    init_std: float = 0.02
) -> None
```

```python
class nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashBlock(
    layer_idx: int,
    config: nemo_automodel.components.models.mimo_v2_flash.config.MiMoV2FlashConfig,
    moe_config: nemo_automodel.components.moe.config.MoEConfig,
    backend: nemo_automodel.components.models.common.BackendConfig
)
```

**Bases:** `Module`

Decoder block that alternates dense MLP and routed-MoE layers.

```python
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashBlock.forward(
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    padding_mask: torch.Tensor | None = None,
    **kwargs: typing.Any
) -> torch.Tensor
```

```python
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashBlock.init_weights(
    buffer_device: torch.device
) -> None
```
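The block description above can be read as the usual pre-norm residual layout with a per-layer choice of feed-forward module. The sketch below assumes that layout. `SketchBlock`, `dense_mlp`, and the `moe_layer_pattern` list are hypothetical. `nn.LayerNorm` stands in for `MiMoV2RMSNorm`, a single `nn.Linear` stands in for attention, and another `nn.Linear` stands in for the routed-MoE layer. How the real block decides which layers are MoE comes from the config and is not shown here.

```python
import torch
from torch import nn


def dense_mlp(hidden: int) -> nn.Module:
    # Hypothetical dense feed-forward: up-projection, SiLU, down-projection.
    return nn.Sequential(nn.Linear(hidden, 4 * hidden), nn.SiLU(), nn.Linear(4 * hidden, hidden))


class SketchBlock(nn.Module):
    """Pre-norm residual block: h = x + attn(norm(x)); out = h + ffn(norm(h))."""

    def __init__(self, hidden: int, ffn: nn.Module):
        super().__init__()
        self.input_norm = nn.LayerNorm(hidden)      # stand-in for MiMoV2RMSNorm
        self.post_attn_norm = nn.LayerNorm(hidden)  # stand-in for MiMoV2RMSNorm
        self.attn = nn.Linear(hidden, hidden)       # stand-in for MiMoV2FlashAttention
        self.ffn = ffn                              # dense MLP or (stand-in) routed MoE

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = x + self.attn(self.input_norm(x))        # attention sub-layer + residual
        return h + self.ffn(self.post_attn_norm(h))  # feed-forward sub-layer + residual


hidden = 16
moe_layer_pattern = [0, 1, 0, 1]  # hypothetical: 1 marks a routed-MoE layer, 0 a dense MLP
blocks = nn.ModuleList(
    SketchBlock(hidden, nn.Linear(hidden, hidden) if is_moe else dense_mlp(hidden))
    for is_moe in moe_layer_pattern
)

x = torch.randn(2, 3, hidden)  # [batch, seq, hidden]
for block in blocks:
    x = block(x)
print(x.shape)  # torch.Size([2, 3, 16])
```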

```python
class nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashForCausalLM(
    config: nemo_automodel.components.models.mimo_v2_flash.config.MiMoV2FlashConfig,
    moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
    backend: nemo_automodel.components.models.common.BackendConfig | None = None,
    **kwargs
)
```

**Bases:** [HFCheckpointingMixin](/nemo-automodel/nemo_automodel/components/models/common/hf_checkpointing_mixin#nemo_automodel-components-models-common-hf_checkpointing_mixin-HFCheckpointingMixin), `Module`, [MoEFSDPSyncMixin](/nemo-automodel/nemo_automodel/components/moe/fsdp_mixin#nemo_automodel-components-moe-fsdp_mixin-MoEFSDPSyncMixin)

Causal LM wrapper for MiMo-V2-Flash with Automodel checkpoint adapters.

```python
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashForCausalLM.customize_pipeline_stage_modules(
    module_names_per_stage: list[list[str]],
    layers_prefix: str,
    text_model: torch.nn.Module | None = None
) -> list[list[str]]
```

Keep the sliding-window-attention (SWA) rotary embedding on every pipeline-parallel (PP) stage.
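Assuming each pipeline stage keeps only the modules named in its entry of `module_names_per_stage`, a module that every stage needs has to appear in every stage's list. A minimal sketch of that bookkeeping follows. The helper `keep_on_every_stage` and the module names (including `model.rotary_emb_swa`) are hypothetical and not taken from the implementation.

```python
from typing import List


def keep_on_every_stage(module_names_per_stage: List[List[str]], module_name: str) -> List[List[str]]:
    """Return new per-stage lists in which `module_name` appears on every stage."""
    result = []
    for names in module_names_per_stage:
        names = list(names)  # copy, so the caller's lists are not mutated
        if module_name not in names:
            names.append(module_name)
        result.append(names)
    return result


stages = [
    ["model.embed_tokens", "model.rotary_emb", "model.layers.0"],
    ["model.layers.1", "model.norm", "lm_head"],
]
new_stages = keep_on_every_stage(stages, "model.rotary_emb_swa")
print(new_stages[1])  # ['model.layers.1', 'model.norm', 'lm_head', 'model.rotary_emb_swa']
```

Returning fresh lists keeps the function free of side effects, and the membership check makes it safe to call repeatedly.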

```python
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashForCausalLM.forward(
    input_ids: torch.Tensor | None = None,
    inputs_embeds: torch.Tensor | None = None,
    position_ids: torch.Tensor | None = None,
    attention_mask: torch.Tensor | dict[str, torch.Tensor] | None = None,
    padding_mask: torch.Tensor | None = None,
    logits_to_keep: typing.Union[int, torch.Tensor] = 0,
    output_hidden_states: typing.Optional[bool] = None,
    **kwargs: typing.Any
) -> transformers.modeling_outputs.CausalLMOutputWithPast
```

Forward pass producing text logits.

**Parameters:**

- `input_ids`: Input token IDs `[B, S]` (or THD-packed `[T]`/`[1, T]`).
- `inputs_embeds`: Pre-computed input embeddings (optional).
- `position_ids`: Optional position indices.
- `attention_mask`: A 2D padding mask, a 4D additive mask, or a dict of per-attention-type masks.
- `padding_mask`: Optional MoE padding mask.
- `logits_to_keep`: If 0, compute logits for all positions (the training default); otherwise compute logits only for the last `logits_to_keep` positions.
- `output_hidden_states`: When set, the returned output carries the final hidden states (the input to `lm_head`) in `hidden_states`.
- `kwargs`: Additional arguments forwarded to the base model.

**Returns:** `CausalLMOutputWithPast`

A `transformers.modeling_outputs.CausalLMOutputWithPast` carrying the text logits and, when `output_hidden_states` is set, the final hidden states.
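The `logits_to_keep` rule can be sketched with the slicing convention used by many Hugging Face causal-LM heads: an integer `k` keeps the last `k` positions, and a tensor is treated as explicit sequence indices. Treat the tensor branch and the helper name `positions_for_logits` as assumptions; the description above only covers the integer case.

```python
from typing import Union

import torch


def positions_for_logits(hidden_states: torch.Tensor, logits_to_keep: Union[int, torch.Tensor]) -> torch.Tensor:
    """Select the sequence positions whose logits should be computed.

    hidden_states: [batch, seq, hidden]
    """
    if isinstance(logits_to_keep, int):
        # slice(-0, None) is slice(0, None), so 0 keeps every position
        index = slice(-logits_to_keep, None)
    else:
        index = logits_to_keep  # explicit 1D tensor of sequence indices
    return hidden_states[:, index, :]


h = torch.randn(2, 5, 8)
print(positions_for_logits(h, 0).shape)  # torch.Size([2, 5, 8])
print(positions_for_logits(h, 1).shape)  # torch.Size([2, 1, 8])
```

Keeping only the last position during generation avoids projecting every hidden state through the (large) vocabulary head.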

```python
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashForCausalLM.from_config(
    config: nemo_automodel.components.models.mimo_v2_flash.config.MiMoV2FlashConfig,
    moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
    backend: nemo_automodel.components.models.common.BackendConfig | None = None,
    **kwargs
)
```

*Class method.*

```python
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashForCausalLM.from_pretrained(
    pretrained_model_name_or_path: str,
    *model_args,
    **kwargs
)
```

*Class method.*

```python
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashForCausalLM.get_input_embeddings()
```

```python
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashForCausalLM.get_output_embeddings()
```

```python
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashForCausalLM.initialize_weights(
    buffer_device: torch.device | None = None,
    dtype: torch.dtype = torch.bfloat16
) -> None
```

```python
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashForCausalLM.set_input_embeddings(
    value
)
```

```python
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashForCausalLM.set_output_embeddings(
    new_embeddings
)
```

```python
class nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashModel(
    config: nemo_automodel.components.models.mimo_v2_flash.config.MiMoV2FlashConfig,
    backend: nemo_automodel.components.models.common.BackendConfig,
    moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
    moe_overrides: dict | None = None
)
```

**Bases:** `Module`

Backbone model for Xiaomi MiMo-V2-Flash.

```python
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashModel._build_causal_mask_mapping(
    inputs_embeds: torch.Tensor,
    attention_mask: torch.Tensor | dict[str, torch.Tensor] | None,
    position_ids: torch.Tensor,
    cache_position: torch.Tensor
) -> dict[str, torch.Tensor]
```
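`_build_causal_mask_mapping` returns a dictionary of masks, and `forward` also accepts such a dictionary as `attention_mask`. The sketch below shows how a layer might pick its mask from either form. The keys `"full_attention"` and `"sliding_attention"` follow a Hugging Face convention and are assumptions here, as is the helper name `mask_for_layer`.

```python
from typing import Dict, Optional, Union

import torch

MaskInput = Optional[Union[torch.Tensor, Dict[str, torch.Tensor]]]


def mask_for_layer(attention_mask: MaskInput, is_swa: bool) -> Optional[torch.Tensor]:
    """Pick the mask for one attention layer (hypothetical helper)."""
    if isinstance(attention_mask, dict):
        # Per-type mapping: sliding-window and full-attention layers get different masks.
        return attention_mask["sliding_attention" if is_swa else "full_attention"]
    return attention_mask  # a single tensor (or None) is shared by every layer


mapping = {
    "full_attention": torch.zeros(1, 1, 4, 4),
    "sliding_attention": torch.full((1, 1, 4, 4), -1.0),
}
print(mask_for_layer(mapping, is_swa=True)[0, 0, 0, 0].item())  # -1.0
```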

```python
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashModel.forward(
    input_ids: torch.Tensor | None = None,
    inputs_embeds: torch.Tensor | None = None,
    position_ids: torch.Tensor | None = None,
    attention_mask: torch.Tensor | dict[str, torch.Tensor] | None = None,
    padding_mask: torch.Tensor | None = None,
    cache_position: torch.Tensor | None = None,
    **kwargs: typing.Any
) -> torch.Tensor
```

```python
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashModel.init_weights(
    buffer_device: torch.device | None = None
) -> None
```

```python
class nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashRotaryEmbedding(
    rope_theta: float,
    head_dim: int,
    partial_rotary_factor: float = 1.0,
    dtype: torch.dtype = torch.bfloat16
)
```

**Bases:** `Module`

Rotary embedding module matching MiMo-V2-Flash partial-RoPE behavior.

```python
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashRotaryEmbedding.forward(
    x: torch.Tensor,
    position_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]
```
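A sketch of the standard RoPE cos/sin computation with a partial rotary factor follows. It assumes, as in common Hugging Face implementations, that only the first `int(head_dim * partial_rotary_factor)` channels are rotated and that the frequencies follow `rope_theta ** (-2i / rotary_dim)`. The function `rotary_cos_sin` is hypothetical. The real module receives `rope_theta`, `head_dim`, `partial_rotary_factor`, and `dtype` in its constructor, and `x` and `position_ids` in `forward`.

```python
from typing import Tuple

import torch


def rotary_cos_sin(
    position_ids: torch.Tensor,
    head_dim: int,
    rope_theta: float,
    partial_rotary_factor: float = 1.0,
    dtype: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
    rotary_dim = int(head_dim * partial_rotary_factor)  # only this many channels are rotated
    exponents = torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim
    inv_freq = 1.0 / (rope_theta ** exponents)  # [rotary_dim // 2]
    freqs = position_ids[..., None].float() * inv_freq  # [batch, seq, rotary_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)  # [batch, seq, rotary_dim]
    return emb.cos().to(dtype), emb.sin().to(dtype)


pos = torch.arange(4).unsqueeze(0)  # [1, 4]
cos, sin = rotary_cos_sin(pos, head_dim=64, rope_theta=10000.0, partial_rotary_factor=0.5)
print(cos.shape)  # torch.Size([1, 4, 32])
```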

```python
class nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2RMSNorm(
    hidden_size: int,
    eps: float = 1e-06,
    dtype: torch.dtype = torch.bfloat16
)
```

**Bases:** `Module`

RMSNorm used by MiMo-V2-Flash decoder blocks.

```python
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2RMSNorm.forward(
    hidden_states: torch.Tensor
) -> torch.Tensor
```
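A sketch of the usual RMSNorm computation (the Llama-style formulation in Hugging Face Transformers) follows. It assumes `MiMoV2RMSNorm` uses a plain multiplicative weight initialized to ones; whether it instead uses an offset variant such as `(1 + weight)` is not stated here. `RMSNormSketch` is a hypothetical name.

```python
import torch
from torch import nn


class RMSNormSketch(nn.Module):
    """y = weight * x / sqrt(mean(x^2) + eps), computed in float32."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # scale starts at 1
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        x = hidden_states.float()  # upcast for a numerically stable mean of squares
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return self.weight * x.to(input_dtype)


norm = RMSNormSketch(8)
out = norm(torch.randn(2, 3, 8) * 5)
print(out.pow(2).mean(-1))  # close to 1 everywhere
```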

```python
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2RMSNorm.reset_parameters() -> None
```

```python
nemo_automodel.components.models.mimo_v2_flash.model._apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]
```
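A sketch of the standard rotate-half RoPE application follows, extended so that when `cos`/`sin` cover fewer channels than the head dimension, only the leading channels are rotated and the rest pass through unchanged. That partial-RoPE handling, the tensor layouts, and the helper names are assumptions for illustration; whether the real `_apply_rotary_pos_emb` or its caller performs the split is not documented here.

```python
from typing import Tuple

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim into halves (x1, x2) and return (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """q, k: [batch, heads, seq, head_dim]; cos, sin: [batch, seq, rotary_dim]."""
    cos = cos.unsqueeze(1)  # broadcast over the heads dimension
    sin = sin.unsqueeze(1)
    rotary_dim = cos.shape[-1]

    def rotate(t: torch.Tensor) -> torch.Tensor:
        t_rot, t_pass = t[..., :rotary_dim], t[..., rotary_dim:]
        t_rot = t_rot * cos + rotate_half(t_rot) * sin
        return torch.cat((t_rot, t_pass), dim=-1)  # untouched channels pass through

    return rotate(q), rotate(k)


angles = torch.rand(1, 4, 8)  # [batch, seq, rotary_dim // 2]
emb = torch.cat((angles, angles), dim=-1)
q = torch.randn(1, 2, 4, 32)  # 32-dim heads, only the first 16 channels rotated
k = torch.randn(1, 2, 4, 32)
q_rot, k_rot = apply_rotary_pos_emb(q, k, emb.cos(), emb.sin())
```

Because each channel pair `(i, i + rotary_dim/2)` is rotated by the same angle, the rotation preserves vector norms.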

```python
nemo_automodel.components.models.mimo_v2_flash.model._convert_bool_4d_mask_to_additive(
    attention_mask: torch.Tensor,
    dtype: torch.dtype
) -> torch.Tensor
```
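A boolean 4D mask becomes additive by mapping allowed positions to 0 and blocked positions to a very large negative number, so adding it to the attention scores before the softmax drives the blocked weights to zero. The sketch assumes `True` means "may attend" (the common Hugging Face convention) and uses `torch.finfo(dtype).min` rather than `-inf`, a common choice that keeps every score finite. Both are assumptions about `_convert_bool_4d_mask_to_additive`, and `bool_mask_to_additive` is a hypothetical name.

```python
import torch


def bool_mask_to_additive(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """True (may attend) -> 0.0; False (blocked) -> the most negative finite value of `dtype`."""
    zero = torch.zeros((), dtype=dtype, device=mask.device)
    blocked = torch.full((), torch.finfo(dtype).min, dtype=dtype, device=mask.device)
    return torch.where(mask, zero, blocked)


mask = torch.tensor([True, True, False]).view(1, 1, 1, 3)  # [batch, heads, q_len, k_len]
print(bool_mask_to_additive(mask, torch.float32))
```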

```python
nemo_automodel.components.models.mimo_v2_flash.model._derive_padding_mask(
    attention_mask: torch.Tensor
) -> torch.Tensor
```
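Assuming the 2D attention mask uses the usual 1 = real token, 0 = padding convention, and that the MoE padding mask marks padded positions with `True`, deriving one from the other is a single logical inversion. Both conventions are assumptions about `_derive_padding_mask`, and `derive_padding_mask` is a hypothetical name.

```python
import torch


def derive_padding_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    """[batch, seq] 1/0 attention mask -> [batch, seq] bool mask, True at padding."""
    return attention_mask.bool().logical_not()


attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
print(derive_padding_mask(attention_mask))
```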

```python
nemo_automodel.components.models.mimo_v2_flash.model._eager_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0
) -> tuple[torch.Tensor, torch.Tensor]
```
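A sketch of eager (non-fused) scaled dot-product attention that returns the `(output, weights)` pair described by the signature above. It assumes `key` and `value` already have as many heads as `query`; the real function also takes `module`, presumably to read attributes such as the key/value grouping, and may expand heads with `_repeat_kv` first. The name `eager_attention` and the `training` flag are hypothetical.

```python
import torch
import torch.nn.functional as F


def eager_attention(query, key, value, attention_mask, scaling, dropout=0.0, training=False):
    """query/key/value: [batch, heads, seq, dim]; attention_mask: additive, broadcastable
    to [batch, heads, q_len, k_len], or None.

    Returns (attn_output [batch, seq, heads, dim], attn_weights [batch, heads, q_len, k_len]).
    """
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask  # blocked positions get a huge negative bias
    # softmax in float32 for stability, then cast back to the activation dtype
    attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = F.dropout(attn_weights, p=dropout, training=training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights


q, k, v = (torch.randn(2, 4, 5, 16) for _ in range(3))
out, weights = eager_attention(q, k, v, None, scaling=16 ** -0.5)
```

With no mask and no dropout, this matches `torch.nn.functional.scaled_dot_product_attention` up to the final transpose, which is a convenient way to sanity-check an eager path.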

```python
nemo_automodel.components.models.mimo_v2_flash.model._ensure_additive_mask(
    mask: torch.Tensor | None,
    batch_size: int,
    seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
    attention_mask: torch.Tensor | None,
    sliding_window: int | None
) -> torch.Tensor
```

```python
nemo_automodel.components.models.mimo_v2_flash.model._fallback_additive_mask(
    batch_size: int,
    seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
    attention_mask: torch.Tensor | None = None,
    sliding_window: int | None = None
) -> torch.Tensor
```
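A sketch of what a fallback additive mask typically contains: a causal constraint, an optional sliding-window limit, and optional blocking of padded key positions. The window convention (a query sees itself and the previous `sliding_window - 1` tokens), the `[batch, 1, seq, seq]` layout, and the 1 = real token padding convention are assumptions. The parameters mirror the documented `_fallback_additive_mask` signature, but `sketch_fallback_additive_mask` is a hypothetical illustration, not the library's code.

```python
from typing import Optional

import torch


def sketch_fallback_additive_mask(
    batch_size: int,
    seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
    attention_mask: Optional[torch.Tensor] = None,
    sliding_window: Optional[int] = None,
) -> torch.Tensor:
    """Build an additive [batch, 1, seq, seq] causal mask (0 = attend, finfo.min = blocked)."""
    i = torch.arange(seq_len, device=device)[:, None]  # query positions
    j = torch.arange(seq_len, device=device)[None, :]  # key positions
    allowed = j <= i  # causal: never look ahead
    if sliding_window is not None:
        allowed = allowed & (i - j < sliding_window)  # only the last `sliding_window` keys
    allowed = allowed[None, None].expand(batch_size, 1, seq_len, seq_len)
    if attention_mask is not None:
        # attention_mask: [batch, seq] with 1 = real token; block padded keys
        allowed = allowed & attention_mask.bool()[:, None, None, :]
    mask = torch.zeros(batch_size, 1, seq_len, seq_len, dtype=dtype, device=device)
    return mask.masked_fill(~allowed, torch.finfo(dtype).min)


m = sketch_fallback_additive_mask(1, 5, torch.float32, torch.device("cpu"), sliding_window=2)
print((m[0, 0] == 0).int())
```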

```python
nemo_automodel.components.models.mimo_v2_flash.model._repeat_kv(
    hidden_states: torch.Tensor,
    n_rep: int
) -> torch.Tensor
```
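For grouped-query attention, each key/value head is shared by `n_rep` query heads, so keys and values are repeated along the head axis before the score computation. Below is a sketch of the widely used Hugging Face formulation, which is equivalent to `torch.repeat_interleave(hidden_states, n_rep, dim=1)`. It assumes `_repeat_kv` follows it with a `[batch, num_kv_heads, seq, head_dim]` layout.

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """[batch, num_kv_heads, seq, dim] -> [batch, num_kv_heads * n_rep, seq, dim]."""
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # insert a repeat axis, broadcast along it without copying, then fold it into the head axis
    expanded = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)


kv = torch.randn(2, 2, 5, 8)  # 2 key/value heads
print(repeat_kv(kv, 4).shape)  # torch.Size([2, 8, 5, 8])
```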

```python
nemo_automodel.components.models.mimo_v2_flash.model._rotate_half(
    x: torch.Tensor
) -> torch.Tensor
```

```python
nemo_automodel.components.models.mimo_v2_flash.model.ModelClass = MiMoV2FlashForCausalLM
```