nemo_automodel.components.models.mimo_v2_flash.model

Module Contents

Classes

| Name | Description |
| --- | --- |
| MiMoV2FlashAttention | MiMo-V2-Flash attention with full and sliding-window variants. |
| MiMoV2FlashBlock | Decoder block that alternates dense MLP and routed-MoE layers. |
| MiMoV2FlashForCausalLM | Causal LM wrapper for MiMo-V2-Flash with Automodel checkpoint adapters. |
| MiMoV2FlashModel | Backbone model for Xiaomi MiMo-V2-Flash. |
| MiMoV2FlashRotaryEmbedding | Rotary embedding module matching MiMo-V2-Flash partial-RoPE behavior. |
| MiMoV2RMSNorm | RMSNorm used by MiMo-V2-Flash decoder blocks. |

Functions

Data

ModelClass

API

class nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashAttention(
config: nemo_automodel.components.models.mimo_v2_flash.config.MiMoV2FlashConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
is_swa: bool,
layer_idx: int
)

Bases: Module

MiMo-V2-Flash attention with full and sliding-window variants.

attention_dropout = float(config.attention_dropout or 0.0)
head_dim = config.swa_head_dim
k_proj
num_attention_heads = config.swa_num_attention_heads
num_key_value_groups
num_key_value_heads = config.swa_num_key_value_heads
o_proj
q_proj
rope_dim = self.rope_dim - self.rope_dim % 2
scaling = self.head_dim ** -0.5
v_head_dim = config.swa_v_head_dim
v_proj
v_scale = getattr(config, 'attention_value_scale', None)
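
The derived attributes above are plain arithmetic on the configured sizes. The following is a minimal sketch, not the module's code: `derive_attention_dims` is a hypothetical helper, the input numbers are made up, and two steps are assumptions not stated in this listing (that `rope_dim` starts as `int(head_dim * partial_rotary_factor)` and that `num_key_value_groups` is the ratio of query heads to key/value heads).

```python
def derive_attention_dims(head_dim, num_attention_heads, num_key_value_heads,
                          partial_rotary_factor):
    """Hypothetical helper illustrating the derived attention attributes."""
    # Assumption: the rotary width starts as a fraction of the head dimension.
    rope_dim = int(head_dim * partial_rotary_factor)
    # From the listing: round down to an even width so channels can be paired.
    rope_dim = rope_dim - rope_dim % 2
    # From the listing: the usual 1/sqrt(head_dim) attention scaling.
    scaling = head_dim ** -0.5
    # Assumption: each key/value head is shared by a group of query heads.
    num_key_value_groups = num_attention_heads // num_key_value_heads
    return {
        "rope_dim": rope_dim,
        "scaling": scaling,
        "num_key_value_groups": num_key_value_groups,
    }


# Made-up sizes: 64-wide heads, 8 query heads, 2 key/value heads.
dims = derive_attention_dims(64, 8, 2, partial_rotary_factor=0.33)
print(dims)
```

With these inputs the unrounded rotary width is 21, so the even-rounding step is what brings it down to 20.
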
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashAttention.forward(
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: torch.Tensor | None = None,
**kwargs: typing.Any
) -> tuple[torch.Tensor, torch.Tensor]
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashAttention.init_weights(
buffer_device: torch.device,
init_std: float = 0.02
) -> None
class nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashBlock(
layer_idx: int,
config: nemo_automodel.components.models.mimo_v2_flash.config.MiMoV2FlashConfig,
moe_config: nemo_automodel.components.moe.config.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig
)

Bases: Module

Decoder block that alternates dense MLP and routed-MoE layers.

attention_type
input_layernorm
mlp = MoE(moe_config, backend)
post_attention_layernorm
self_attn
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashBlock.forward(
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
padding_mask: torch.Tensor | None = None,
**kwargs: typing.Any
) -> torch.Tensor
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashBlock.init_weights(
buffer_device: torch.device
) -> None
class nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashForCausalLM(
config: nemo_automodel.components.models.mimo_v2_flash.config.MiMoV2FlashConfig,
moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
backend: nemo_automodel.components.models.common.BackendConfig | None = None,
**kwargs
)

Bases: HFCheckpointingMixin, Module, MoEFSDPSyncMixin

Causal LM wrapper for MiMo-V2-Flash with Automodel checkpoint adapters.

_keep_in_fp32_modules_strict
backend = backend or BackendConfig()
lm_head
model
state_dict_adapter
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashForCausalLM.customize_pipeline_stage_modules(
module_names_per_stage: list[list[str]],
layers_prefix: str,
text_model: torch.nn.Module | None = None
) -> list[list[str]]

Keep the sliding-window attention (SWA) rotary embedding on every pipeline-parallel (PP) stage.

nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashForCausalLM.forward(
input_ids: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | dict[str, torch.Tensor] | None = None,
padding_mask: torch.Tensor | None = None,
logits_to_keep: typing.Union[int, torch.Tensor] = 0,
output_hidden_states: typing.Optional[bool] = None,
**kwargs: typing.Any
) -> transformers.modeling_outputs.CausalLMOutputWithPast

Forward pass producing text logits.

Parameters:

input_ids
torch.Tensor | None. Defaults to None.

Input token IDs, shaped [B, S] (or THD-packed as [T] or [1, T]).

inputs_embeds
torch.Tensor | None. Defaults to None.

Pre-computed input embeddings (optional).

position_ids
torch.Tensor | None. Defaults to None.

Optional position indices.

attention_mask
torch.Tensor | dict[str, torch.Tensor] | None. Defaults to None.

2D padding mask, 4D additive mask, or per-type dict.

padding_mask
torch.Tensor | None. Defaults to None.

Optional MoE padding mask.

logits_to_keep
Union[int, torch.Tensor]. Defaults to 0.

If 0, compute logits for all positions (training default); otherwise compute logits only for the last logits_to_keep positions.

output_hidden_states
Optional[bool]. Defaults to None.

When set, the returned output carries the final hidden states (the input to lm_head) in hidden_states.

**kwargs
Any.

Additional arguments forwarded to the base model.

Returns: CausalLMOutputWithPast

A transformers.modeling_outputs.CausalLMOutputWithPast with the text logits and, when output_hidden_states is set, the final hidden states in hidden_states.
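
The logits_to_keep rule for integers can be illustrated without any tensors. This is a sketch of the selection logic only, assuming the model slices the sequence axis with the common `slice(-logits_to_keep, None)` idiom; `positions_with_logits` is a hypothetical name, and the tensor-valued form of logits_to_keep is not covered.

```python
def positions_with_logits(seq_len: int, logits_to_keep: int = 0) -> list[int]:
    """Return the sequence positions for which logits would be computed."""
    # slice(-0, None) is slice(0, None), so 0 naturally selects every position;
    # any positive k selects only the last k positions.
    keep = slice(-logits_to_keep, None)
    return list(range(seq_len))[keep]


print(positions_with_logits(5))      # training default: all positions
print(positions_with_logits(5, 2))   # generation-style: last two positions
```

Keeping only the trailing positions avoids projecting every hidden state through lm_head when only the next-token logits are needed.
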

nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashForCausalLM.from_config(
config: nemo_automodel.components.models.mimo_v2_flash.config.MiMoV2FlashConfig,
moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
backend: nemo_automodel.components.models.common.BackendConfig | None = None,
**kwargs
)
classmethod
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashForCausalLM.from_pretrained(
pretrained_model_name_or_path: str,
*model_args,
**kwargs
)
classmethod
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashForCausalLM.get_input_embeddings()
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashForCausalLM.get_output_embeddings()
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashForCausalLM.initialize_weights(
buffer_device: torch.device | None = None,
dtype: torch.dtype = torch.bfloat16
) -> None
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashForCausalLM.set_input_embeddings(
value
)
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashForCausalLM.set_output_embeddings(
new_embeddings
)
class nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashModel(
config: nemo_automodel.components.models.mimo_v2_flash.config.MiMoV2FlashConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
moe_overrides: dict | None = None
)

Bases: Module

Backbone model for Xiaomi MiMo-V2-Flash.

embed_tokens
has_sliding_layers
layers
moe_config = moe_config or MoEConfig(**moe_defaults)
norm
rotary_emb
swa_rotary_emb
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashModel._build_causal_mask_mapping(
inputs_embeds: torch.Tensor,
attention_mask: torch.Tensor | dict[str, torch.Tensor] | None,
position_ids: torch.Tensor,
cache_position: torch.Tensor
) -> dict[str, torch.Tensor]
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashModel.forward(
input_ids: torch.Tensor | None = None,
inputs_embeds: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
attention_mask: torch.Tensor | dict[str, torch.Tensor] | None = None,
padding_mask: torch.Tensor | None = None,
cache_position: torch.Tensor | None = None,
**kwargs: typing.Any
) -> torch.Tensor
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashModel.init_weights(
buffer_device: torch.device | None = None
) -> None
class nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashRotaryEmbedding(
rope_theta: float,
head_dim: int,
partial_rotary_factor: float = 1.0,
dtype: torch.dtype = torch.bfloat16
)

Bases: Module

Rotary embedding module matching MiMo-V2-Flash partial-RoPE behavior.

attention_scaling = 1.0
inv_freq: Tensor
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashRotaryEmbedding.forward(
x: torch.Tensor,
position_ids: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]
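
As an illustration of what a partial-RoPE module of this shape typically computes, here is a framework-free sketch. It assumes the conventional inverse-frequency formula and the Hugging Face habit of duplicating the angle vector so cos and sin span the full rotary width; `rope_tables` is a hypothetical function and the real module operates on torch tensors.

```python
import math


def rope_tables(positions, rope_theta, head_dim, partial_rotary_factor=1.0):
    """Build per-position cos/sin tables covering only the rotary channels."""
    # Assumption: only a fraction of each head is rotated, rounded to even.
    rope_dim = int(head_dim * partial_rotary_factor)
    rope_dim -= rope_dim % 2
    # One inverse frequency per channel pair: theta^(-2i / rope_dim).
    inv_freq = [1.0 / rope_theta ** (2 * i / rope_dim)
                for i in range(rope_dim // 2)]
    cos, sin = [], []
    for pos in positions:
        angles = [pos * freq for freq in inv_freq]
        angles = angles + angles  # duplicate so the table is rope_dim wide
        cos.append([math.cos(a) for a in angles])
        sin.append([math.sin(a) for a in angles])
    return cos, sin


cos, sin = rope_tables([0, 1], rope_theta=10000.0, head_dim=8,
                       partial_rotary_factor=0.5)
print(len(cos[0]))  # rotary width, half of head_dim in this example
```
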
class nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2RMSNorm(
hidden_size: int,
eps: float = 1e-06,
dtype: torch.dtype = torch.bfloat16
)

Bases: Module

RMSNorm used by MiMo-V2-Flash decoder blocks.

weight = nn.Parameter(torch.ones(hidden_size, dtype=dtype))
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2RMSNorm.forward(
hidden_states: torch.Tensor
) -> torch.Tensor
nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2RMSNorm.reset_parameters() -> None
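
For orientation, the standard RMSNorm computation looks like the sketch below. This page does not show the module's forward body, so treat the formula as the textbook definition rather than a transcription; `rms_norm` is a hypothetical pure-Python stand-in that normalizes a single vector.

```python
import math


def rms_norm(x, weight, eps=1e-6):
    """Scale x by the reciprocal of its root-mean-square, then by weight."""
    mean_square = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_square + eps)
    return [w * v * inv_rms for v, w in zip(x, weight)]


# With unit weights (the initial value of `weight`), the output has RMS ~= 1.
out = rms_norm([3.0, 4.0], [1.0, 1.0])
print(out)
```
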
nemo_automodel.components.models.mimo_v2_flash.model._apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]
nemo_automodel.components.models.mimo_v2_flash.model._convert_bool_4d_mask_to_additive(
attention_mask: torch.Tensor,
dtype: torch.dtype
) -> torch.Tensor
nemo_automodel.components.models.mimo_v2_flash.model._derive_padding_mask(
attention_mask: torch.Tensor
) -> torch.Tensor
nemo_automodel.components.models.mimo_v2_flash.model._eager_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None,
scaling: float,
dropout: float = 0.0
) -> tuple[torch.Tensor, torch.Tensor]
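
The signature above matches the usual eager attention path: scaled dot products, an additive mask, softmax, then a weighted sum of values, returning both the output and the attention weights. The sketch below shows that computation for a single head in plain Python. It is an assumption about the helper's behavior based on its name and signature, `eager_attention` is a hypothetical name, and dropout is omitted.

```python
import math


def eager_attention(query, key, value, mask, scaling):
    """Single-head attention over lists of vectors; returns (output, weights)."""
    outputs, weights = [], []
    for qi, q in enumerate(query):
        scores = []
        for ki, k in enumerate(key):
            score = scaling * sum(a * b for a, b in zip(q, k))
            if mask is not None:
                score += mask[qi][ki]  # additive mask: 0 keeps, -inf hides
            scores.append(score)
        peak = max(scores)  # subtract the max for a numerically stable softmax
        exps = [math.exp(s - peak) for s in scores]
        total = sum(exps)
        probs = [e / total for e in exps]
        weights.append(probs)
        outputs.append([sum(p * v[d] for p, v in zip(probs, value))
                        for d in range(len(value[0]))])
    return outputs, weights


neg = float("-inf")
q = k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
causal = [[0.0, neg], [0.0, 0.0]]
out, attn = eager_attention(q, k, v, causal, scaling=1.0)
print(out[0])  # the first token can only see itself
```
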
nemo_automodel.components.models.mimo_v2_flash.model._ensure_additive_mask(
mask: torch.Tensor | None,
batch_size: int,
seq_len: int,
dtype: torch.dtype,
device: torch.device,
attention_mask: torch.Tensor | None,
sliding_window: int | None
) -> torch.Tensor
nemo_automodel.components.models.mimo_v2_flash.model._fallback_additive_mask(
batch_size: int,
seq_len: int,
dtype: torch.dtype,
device: torch.device,
attention_mask: torch.Tensor | None = None,
sliding_window: int | None = None
) -> torch.Tensor
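
A causal additive mask with an optional sliding window can be sketched as follows. This is an illustration of the concept suggested by the parameter names, not the helper's actual logic: `causal_additive_mask` is hypothetical, the window convention (a query sees itself and the previous `sliding_window - 1` keys) is an assumption, and real implementations usually fill hidden entries with the dtype's minimum value rather than negative infinity.

```python
def causal_additive_mask(seq_len, sliding_window=None, hidden=float("-inf")):
    """Return a seq_len x seq_len mask: 0.0 where visible, `hidden` elsewhere."""
    mask = []
    for q in range(seq_len):
        row = []
        for k in range(seq_len):
            visible = k <= q  # causal: no attending to future positions
            if sliding_window is not None:
                # Assumed convention: keep only the most recent keys.
                visible = visible and (q - k) < sliding_window
            row.append(0.0 if visible else hidden)
        mask.append(row)
    return mask


for row in causal_additive_mask(4, sliding_window=2):
    print(row)
```
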
nemo_automodel.components.models.mimo_v2_flash.model._repeat_kv(
hidden_states: torch.Tensor,
n_rep: int
) -> torch.Tensor
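
Helpers with this name and signature conventionally expand grouped key/value heads so that each one lines up with the query heads that share it. A minimal sketch of that ordering, assuming the Hugging Face convention in which copies of the same head stay adjacent (`repeat_kv` here is a hypothetical list-based stand-in for the tensor version):

```python
def repeat_kv(heads, n_rep):
    """Repeat each key/value head n_rep times, keeping copies adjacent."""
    return [head for head in heads for _ in range(n_rep)]


# Two KV heads shared by groups of three query heads each.
print(repeat_kv(["kv0", "kv1"], 3))
```
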
nemo_automodel.components.models.mimo_v2_flash.model._rotate_half(
x: torch.Tensor
) -> torch.Tensor
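
The half-rotation helper is the building block of the rotary update `x * cos + rotate_half(x) * sin`. The sketch below assumes the common convention of splitting the last dimension into two halves and returning `[-second, first]`; both functions are hypothetical list-based stand-ins, and the partial-RoPE pass-through of unrotated channels is not shown.

```python
import math


def rotate_half(x):
    """Map [first_half, second_half] to [-second_half, first_half]."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def apply_rotary(x, cos, sin):
    """Rotate x channel-pair-wise using precomputed cos/sin tables."""
    rotated = rotate_half(x)
    return [a * c + r * s for a, r, c, s in zip(x, rotated, cos, sin)]


# A quarter-turn on a single channel pair sends (1, 0) to roughly (0, 1).
angle = math.pi / 2
out = apply_rotary([1.0, 0.0], [math.cos(angle)] * 2, [math.sin(angle)] * 2)
print(out)
```

Because each channel pair undergoes a plain 2D rotation, the vector's norm is unchanged, which is why the update can encode position without rescaling queries or keys.
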
nemo_automodel.components.models.mimo_v2_flash.model.ModelClass = MiMoV2FlashForCausalLM