nemo_automodel.components.models.mimo_v2_flash.model

Module Contents

Classes

MiMoV2FlashRotaryEmbedding

Rotary embedding module matching MiMo-V2-Flash partial-RoPE behavior.

MiMoV2RMSNorm

RMSNorm used by MiMo-V2-Flash decoder blocks.

MiMoV2FlashAttention

MiMo-V2-Flash attention with full and sliding-window variants.

MiMoV2FlashBlock

Decoder block that alternates dense MLP and routed-MoE layers.

MiMoV2FlashModel

Backbone model for Xiaomi MiMo-V2-Flash.

MiMoV2FlashForCausalLM

Causal LM wrapper for MiMo-V2-Flash with Automodel checkpoint adapters.

Functions

Data

API

nemo_automodel.components.models.mimo_v2_flash.model._rotate_half(x: torch.Tensor) -> torch.Tensor
nemo_automodel.components.models.mimo_v2_flash.model._apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]
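
The two rotary helpers are listed without docstrings. The following plain-Python sketch shows the rotate-half formulation of RoPE that their names and signatures suggest, assuming the convention `x * cos + rotate_half(x) * sin` on a single feature vector. The names `rotate_half` and `apply_rotary` are illustrative only; the module's tensor implementation is not shown on this page and may differ.

```python
import math


def rotate_half(x: list[float]) -> list[float]:
    """Split the vector into halves (x1, x2) and return (-x2, x1)."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def apply_rotary(vec: list[float], cos: list[float], sin: list[float]) -> list[float]:
    """Assumed RoPE convention: vec * cos + rotate_half(vec) * sin, elementwise."""
    rotated = rotate_half(vec)
    return [v * c + r * s for v, c, r, s in zip(vec, cos, rotated, sin)]


# Rotating the unit vector [1, 0] by a quarter turn lands on [0, 1].
theta = math.pi / 2
out = apply_rotary([1.0, 0.0], [math.cos(theta)] * 2, [math.sin(theta)] * 2)
```

Under this convention each pair `(x[i], x[i + half])` is rotated by the same angle, so the vector norm is preserved.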
nemo_automodel.components.models.mimo_v2_flash.model._repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor
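
As a rough illustration of what a `repeat_kv`-style helper usually does in grouped-query attention, the sketch below repeats each key/value head `n_rep` times in place so that the head count matches the query heads. It works on a plain list of per-head objects; the function name and the interleaved ordering are assumptions, not taken from this module's source.

```python
def repeat_kv_heads(kv_heads: list, n_rep: int) -> list:
    """Repeat every KV head n_rep times, keeping copies of a head adjacent."""
    if n_rep == 1:
        return list(kv_heads)
    return [head for head in kv_heads for _ in range(n_rep)]


# Two KV heads shared by six query heads (n_rep = 3).
expanded = repeat_kv_heads(["kv0", "kv1"], 3)
```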
nemo_automodel.components.models.mimo_v2_flash.model._convert_bool_4d_mask_to_additive(
    attention_mask: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor
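
A boolean-to-additive conversion can be sketched as follows, assuming `True` means "may attend" and that blocked positions receive a large negative bias added to the attention scores. The 2-D list layout, the function name, and the `-1e9` fill value are illustrative assumptions; a tensor implementation would more likely use the minimum value of the target `dtype`.

```python
def bool_mask_to_additive(mask: list[list[bool]], min_value: float = -1e9) -> list[list[float]]:
    """Map True -> 0.0 (keep the score) and False -> min_value (suppress it)."""
    return [[0.0 if keep else min_value for keep in row] for row in mask]


additive = bool_mask_to_additive([[True, False], [True, True]])
```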
nemo_automodel.components.models.mimo_v2_flash.model._derive_padding_mask(attention_mask: torch.Tensor) -> torch.Tensor
nemo_automodel.components.models.mimo_v2_flash.model._fallback_additive_mask(
    batch_size: int,
    seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
    attention_mask: torch.Tensor | None = None,
    sliding_window: int | None = None,
) -> torch.Tensor
nemo_automodel.components.models.mimo_v2_flash.model._ensure_additive_mask(
    mask: torch.Tensor | None,
    *,
    batch_size: int,
    seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
    attention_mask: torch.Tensor | None,
    sliding_window: int | None,
) -> torch.Tensor
nemo_automodel.components.models.mimo_v2_flash.model._eager_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float,
    dropout: float = 0.0,
) -> tuple[torch.Tensor, torch.Tensor]
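
For orientation, an eager attention path generally computes `softmax(scaling * Q K^T + mask) V` and returns both the output and the attention weights, which matches the two-tensor return type above. The single-head, plain-Python sketch below illustrates that computation only. It omits dropout, grouped-query head repetition, and any model-specific terms such as attention-sink biases, and the function name is hypothetical.

```python
import math


def eager_attention(query, key, value, mask, scaling):
    """Single-head scaled dot-product attention over lists of row vectors."""
    weights = []
    for qi, q in enumerate(query):
        scores = [scaling * sum(a * b for a, b in zip(q, k)) for k in key]
        if mask is not None:
            scores = [s + m for s, m in zip(scores, mask[qi])]  # additive mask
        top = max(scores)  # subtract the max for a numerically stable softmax
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights.append([e / total for e in exps])
    dim = len(value[0])
    output = [[sum(w * v[d] for w, v in zip(row, value)) for d in range(dim)] for row in weights]
    return output, weights


eye = [[1.0, 0.0], [0.0, 1.0]]
causal = [[0.0, -1e9], [0.0, 0.0]]
out, attn = eager_attention(eye, eye, eye, causal, scaling=1 / math.sqrt(2))
```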
class nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashRotaryEmbedding(
    *,
    rope_theta: float,
    head_dim: int,
    partial_rotary_factor: float = 1.0,
    dtype: torch.dtype = torch.bfloat16,
)

Bases: torch.nn.Module

Rotary embedding module matching MiMo-V2-Flash partial-RoPE behavior.
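
Partial RoPE typically rotates only the first `head_dim * partial_rotary_factor` features of each head and passes the rest through unchanged. The sketch below derives inverse frequencies and per-position `cos`/`sin` values under that assumption, using the standard RoPE formula `1 / rope_theta ** (2i / rotary_dim)`. The helper names are hypothetical, and the exact rounding and layout used by this class are not documented here.

```python
import math


def partial_rope_inv_freq(rope_theta: float, head_dim: int, partial_rotary_factor: float = 1.0) -> list[float]:
    """Inverse frequencies for the rotated slice of the head dimension."""
    rotary_dim = int(head_dim * partial_rotary_factor)  # assumed rounding
    return [1.0 / (rope_theta ** (2 * i / rotary_dim)) for i in range(rotary_dim // 2)]


def cos_sin_for_position(inv_freq: list[float], position: int) -> tuple[list[float], list[float]]:
    """cos/sin of length rotary_dim, duplicated to cover both rotate-half halves."""
    angles = [position * f for f in inv_freq]
    angles = angles + angles
    return [math.cos(a) for a in angles], [math.sin(a) for a in angles]


# head_dim=8 with factor 0.5 rotates 4 features, giving 2 frequencies.
inv_freq = partial_rope_inv_freq(10000.0, head_dim=8, partial_rotary_factor=0.5)
cos, sin = cos_sin_for_position(inv_freq, position=0)
```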

Initialization

inv_freq: torch.Tensor

None

forward(
    x: torch.Tensor,
    position_ids: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]
class nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2RMSNorm(
    hidden_size: int,
    eps: float = 1e-06,
    dtype: torch.dtype = torch.bfloat16,
)

Bases: torch.nn.Module

RMSNorm used by MiMo-V2-Flash decoder blocks.
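
For reference, RMSNorm scales each hidden vector by the reciprocal of its root mean square and then applies a learned per-feature weight. The plain-Python sketch below shows the standard formula with the documented default `eps`. It is not this class's implementation, which operates on tensors and may handle dtype casting differently.

```python
import math


def rms_norm(hidden: list[float], weight: list[float], eps: float = 1e-6) -> list[float]:
    """y_i = weight_i * x_i / sqrt(mean(x^2) + eps)."""
    mean_square = sum(v * v for v in hidden) / len(hidden)
    inv_rms = 1.0 / math.sqrt(mean_square + eps)
    return [w * v * inv_rms for w, v in zip(weight, hidden)]


normed = rms_norm([3.0, 4.0], weight=[1.0, 1.0])
```

With unit weights, the output has a mean square of approximately 1.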

Initialization

reset_parameters() -> None
forward(hidden_states: torch.Tensor) -> torch.Tensor
class nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashAttention(
    config: nemo_automodel.components.models.mimo_v2_flash.config.MiMoV2FlashConfig,
    backend: nemo_automodel.components.models.common.BackendConfig,
    is_swa: bool,
    layer_idx: int,
)

Bases: torch.nn.Module

MiMo-V2-Flash attention with full and sliding-window variants.

Initialization

forward(
    hidden_states: torch.Tensor,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    attention_mask: torch.Tensor | None = None,
    **kwargs: Any,
) -> tuple[torch.Tensor, torch.Tensor]
init_weights(
    buffer_device: torch.device,
    init_std: float = 0.02,
) -> None
class nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashBlock(
    layer_idx: int,
    config: nemo_automodel.components.models.mimo_v2_flash.config.MiMoV2FlashConfig,
    moe_config: nemo_automodel.components.moe.config.MoEConfig,
    backend: nemo_automodel.components.models.common.BackendConfig,
)

Bases: torch.nn.Module

Decoder block that alternates dense MLP and routed-MoE layers.

Initialization

forward(
    hidden_states: torch.Tensor,
    *,
    attention_mask: torch.Tensor | None = None,
    position_embeddings: tuple[torch.Tensor, torch.Tensor],
    padding_mask: torch.Tensor | None = None,
    **kwargs: Any,
) -> torch.Tensor
init_weights(buffer_device: torch.device) -> None
class nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashModel(
    config: nemo_automodel.components.models.mimo_v2_flash.config.MiMoV2FlashConfig,
    backend: nemo_automodel.components.models.common.BackendConfig,
    *,
    moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
    moe_overrides: dict | None = None,
)

Bases: torch.nn.Module

Backbone model for Xiaomi MiMo-V2-Flash.

Initialization

_build_causal_mask_mapping(
    inputs_embeds: torch.Tensor,
    attention_mask: torch.Tensor | dict[str, torch.Tensor] | None,
    position_ids: torch.Tensor,
    cache_position: torch.Tensor,
) -> dict[str, torch.Tensor]
forward(
    input_ids: torch.Tensor | None = None,
    *,
    inputs_embeds: torch.Tensor | None = None,
    position_ids: torch.Tensor | None = None,
    attention_mask: torch.Tensor | dict[str, torch.Tensor] | None = None,
    padding_mask: torch.Tensor | None = None,
    cache_position: torch.Tensor | None = None,
    **kwargs: Any,
) -> torch.Tensor
init_weights(buffer_device: torch.device | None = None) -> None
class nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashForCausalLM(
    config: nemo_automodel.components.models.mimo_v2_flash.config.MiMoV2FlashConfig,
    moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
    backend: nemo_automodel.components.models.common.BackendConfig | None = None,
    **kwargs,
)

Bases: nemo_automodel.components.models.common.hf_checkpointing_mixin.HFCheckpointingMixin, torch.nn.Module, nemo_automodel.components.moe.fsdp_mixin.MoEFSDPSyncMixin

Causal LM wrapper for MiMo-V2-Flash with Automodel checkpoint adapters.

Initialization

_keep_in_fp32_modules_strict

['mlp.gate.e_score_correction_bias', 'attention_sink_bias']

_pp_keep_self_forward

True

_skip_init_weights_on_load

True

classmethod from_config(
    config: nemo_automodel.components.models.mimo_v2_flash.config.MiMoV2FlashConfig,
    moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
    backend: nemo_automodel.components.models.common.BackendConfig | None = None,
    **kwargs,
)
classmethod from_pretrained(
    pretrained_model_name_or_path: str,
    *model_args,
    **kwargs,
)
get_input_embeddings()
set_input_embeddings(value)
get_output_embeddings()
set_output_embeddings(new_embeddings)
forward(
    input_ids: torch.Tensor | None = None,
    *,
    inputs_embeds: torch.Tensor | None = None,
    position_ids: torch.Tensor | None = None,
    attention_mask: torch.Tensor | dict[str, torch.Tensor] | None = None,
    padding_mask: torch.Tensor | None = None,
    logits_to_keep: int | torch.Tensor = 0,
    **kwargs: Any,
) -> torch.Tensor
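
The `logits_to_keep` parameter is not described on this page. Assuming it follows the Hugging Face Transformers convention, an integer `n` keeps the last `n` sequence positions (with `0` meaning all positions) and a tensor selects explicit position indices. The sketch below shows that selection rule on a plain list; `select_positions` is a hypothetical helper, and the model's actual handling should be checked in the source.

```python
def select_positions(hidden_states: list, logits_to_keep=0) -> list:
    """Pick the sequence positions for which logits would be computed."""
    if isinstance(logits_to_keep, int):
        # slice(-0, None) equals slice(0, None), so 0 keeps every position.
        return hidden_states[slice(-logits_to_keep, None)]
    return [hidden_states[i] for i in logits_to_keep]  # explicit indices


states = ["h0", "h1", "h2"]
last_only = select_positions(states, 1)
```

Keeping only the last position is the usual choice during generation, since earlier logits are not needed.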
customize_pipeline_stage_modules(
    module_names_per_stage: list[list[str]],
    *,
    layers_prefix: str,
    text_model: torch.nn.Module | None = None,
) -> list[list[str]]

Keep the sliding-window attention (SWA) rotary embedding on every pipeline-parallel (PP) stage.
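
One way to picture this hook: given the per-stage module name lists, make sure a shared module name appears in every stage. The sketch below does that for an arbitrary name. Both the helper and the module name `model.swa_rotary_emb` are hypothetical placeholders, since the real attribute name is not given on this page.

```python
def keep_on_every_stage(module_names_per_stage: list[list[str]], shared_module: str) -> list[list[str]]:
    """Return new per-stage lists that all contain shared_module exactly once."""
    return [
        list(names) if shared_module in names else [*names, shared_module]
        for names in module_names_per_stage
    ]


stages = [["model.embed_tokens", "model.layers.0"], ["model.layers.1", "lm_head"]]
patched = keep_on_every_stage(stages, "model.swa_rotary_emb")  # hypothetical name
```

Every stage that runs sliding-window layers needs the rotary module locally, which is why it would be replicated rather than assigned to a single stage.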

initialize_weights(
    buffer_device: torch.device | None = None,
    dtype: torch.dtype = torch.bfloat16,
) -> None
nemo_automodel.components.models.mimo_v2_flash.model.ModelClass

None