nemo_automodel.components.models.mimo_v2_flash.model
Module Contents
Classes
| MiMoV2FlashRotaryEmbedding | Rotary embedding module matching MiMo-V2-Flash partial-RoPE behavior. |
| MiMoV2RMSNorm | RMSNorm used by MiMo-V2-Flash decoder blocks. |
| MiMoV2FlashAttention | MiMo-V2-Flash attention with full and sliding-window variants. |
| MiMoV2FlashBlock | Decoder block that alternates dense MLP and routed-MoE layers. |
| MiMoV2FlashModel | Backbone model for Xiaomi MiMo-V2-Flash. |
| MiMoV2FlashForCausalLM | Causal LM wrapper for MiMo-V2-Flash with Automodel checkpoint adapters. |
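Functions
| _rotate_half | |
| _apply_rotary_pos_emb | |
| _repeat_kv | |
| _convert_bool_4d_mask_to_additive | |
| _derive_padding_mask | |
| _fallback_additive_mask | |
| _ensure_additive_mask | |
| _eager_attention_forward | |
Data
| ModelClass | |
API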
- nemo_automodel.components.models.mimo_v2_flash.model._rotate_half(x: torch.Tensor) -> torch.Tensor
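The body of `_rotate_half` is not shown on this page. As a minimal sketch, assuming it follows the rotate-half convention common to Hugging Face style RoPE implementations, the operation splits the last dimension into two halves and returns the negated second half followed by the first. The example below uses a plain Python list in place of a tensor so that it runs without PyTorch; the name `rotate_half` is illustrative.

```python
def rotate_half(x: list[float]) -> list[float]:
    """Return [-x2, x1], where x1 and x2 are the first and second halves of x.

    Assumption: this mirrors the usual rotate-half convention; the actual
    `_rotate_half` operates on the last dimension of a torch.Tensor.
    """
    half = len(x) // 2
    x1, x2 = x[:half], x[half:]
    return [-v for v in x2] + x1


print(rotate_half([1.0, 2.0, 3.0, 4.0]))  # [-3.0, -4.0, 1.0, 2.0]
```

Applying the function twice negates the input, which is the behavior expected of a 90-degree rotation in each (x1, x2) plane.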
- nemo_automodel.components.models.mimo_v2_flash.model._apply_rotary_pos_emb(
- q: torch.Tensor,
- k: torch.Tensor,
- cos: torch.Tensor,
- sin: torch.Tensor,
- nemo_automodel.components.models.mimo_v2_flash.model._repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor
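A helper with this signature usually expands key/value heads for grouped-query attention so that each KV head is shared by `n_rep` query heads. The sketch below assumes that behavior, which this page does not confirm, and models the head axis as a plain Python list. The name `repeat_kv` and the string placeholders are illustrative.

```python
def repeat_kv(kv_heads: list, n_rep: int) -> list:
    """Repeat each KV head n_rep times, keeping copies of a head adjacent.

    Assumption: the real helper does the equivalent on the head dimension
    of a (batch, num_kv_heads, seq_len, head_dim) tensor.
    """
    if n_rep == 1:
        return kv_heads
    return [head for head in kv_heads for _ in range(n_rep)]


print(repeat_kv(["kv0", "kv1"], 3))
# ['kv0', 'kv0', 'kv0', 'kv1', 'kv1', 'kv1']
```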
- nemo_automodel.components.models.mimo_v2_flash.model._convert_bool_4d_mask_to_additive(
- attention_mask: torch.Tensor,
- dtype: torch.dtype,
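The conversion from a boolean mask to an additive mask can be sketched as follows. This assumes that `True` means "may attend" and that masked positions receive a very large negative value, which is a common convention but is not stated on this page. The function name and the `min_value` parameter are hypothetical; a tensor implementation would typically use the minimum finite value of `dtype`.

```python
def bool_mask_to_additive(
    mask: list[list[bool]], min_value: float = float("-inf")
) -> list[list[float]]:
    """Map True -> 0.0 (visible) and False -> min_value (masked).

    Assumption: True marks positions that may be attended to.
    """
    return [[0.0 if visible else min_value for visible in row] for row in mask]


causal = [[True, False], [True, True]]
print(bool_mask_to_additive(causal))  # [[0.0, -inf], [0.0, 0.0]]
```

Adding the result to the attention scores before the softmax drives the weight of masked positions to zero.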
- nemo_automodel.components.models.mimo_v2_flash.model._derive_padding_mask(attention_mask: torch.Tensor) -> torch.Tensor
- nemo_automodel.components.models.mimo_v2_flash.model._fallback_additive_mask(
- batch_size: int,
- seq_len: int,
- dtype: torch.dtype,
- device: torch.device,
- attention_mask: torch.Tensor | None = None,
- sliding_window: int | None = None,
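The `sliding_window` parameter suggests that the fallback mask is causal, optionally restricted to a local window. The sketch below builds such a mask for a single sequence with nested lists. It assumes a window of size `W` lets a query see itself and the `W - 1` preceding positions, and it ignores padding; both are assumptions, and the function name is illustrative.

```python
from typing import Optional


def causal_additive_mask(
    seq_len: int,
    sliding_window: Optional[int] = None,
    min_value: float = float("-inf"),
) -> list[list[float]]:
    """Build a [seq_len][seq_len] additive mask (rows = queries, cols = keys)."""
    mask = []
    for q in range(seq_len):
        row = []
        for k in range(seq_len):
            visible = k <= q  # causal: never attend to future positions
            if sliding_window is not None:
                # Assumed window rule: keep only the last `sliding_window` keys.
                visible = visible and (q - k) < sliding_window
            row.append(0.0 if visible else min_value)
        mask.append(row)
    return mask


for row in causal_additive_mask(4, sliding_window=2):
    print(row)
```

With `sliding_window=None` the result is an ordinary lower-triangular causal mask, which matches the full-attention variant described for `MiMoV2FlashAttention`.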
- nemo_automodel.components.models.mimo_v2_flash.model._ensure_additive_mask(
- mask: torch.Tensor | None,
- *,
- batch_size: int,
- seq_len: int,
- dtype: torch.dtype,
- device: torch.device,
- attention_mask: torch.Tensor | None,
- sliding_window: int | None,
- nemo_automodel.components.models.mimo_v2_flash.model._eager_attention_forward(
- module: torch.nn.Module,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: torch.Tensor | None,
- scaling: float,
- dropout: float = 0.0,
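For orientation, the core of an eager (non-fused) attention forward pass is scaled dot-product attention with an additive mask. The single-head sketch below uses nested lists and omits dropout, multi-head batching, and any attention-sink term. It is an assumption about the general technique, not a transcription of `_eager_attention_forward`, and the function name is illustrative.

```python
import math
from typing import Optional


def eager_attention(
    query: list[list[float]],
    key: list[list[float]],
    value: list[list[float]],
    scaling: float,
    mask: Optional[list[list[float]]] = None,
) -> list[list[float]]:
    """softmax(scaling * Q K^T + mask) V for one head, row by row."""
    output = []
    for qi, q in enumerate(query):
        scores = [scaling * sum(a * b for a, b in zip(q, k)) for k in key]
        if mask is not None:
            scores = [s + m for s, m in zip(scores, mask[qi])]
        peak = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        output.append(
            [sum(w * v[j] for w, v in zip(weights, value)) for j in range(len(value[0]))]
        )
    return output


out = eager_attention(
    query=[[1.0, 0.0]],
    key=[[1.0, 0.0], [0.0, 1.0]],
    value=[[1.0, 0.0], [0.0, 1.0]],
    scaling=1.0,
    mask=[[0.0, float("-inf")]],
)
print(out)  # [[1.0, 0.0]]
```

Because the second key is masked out, the output equals the first value row exactly.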
- class nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashRotaryEmbedding(
- *,
- rope_theta: float,
- head_dim: int,
- partial_rotary_factor: float = 1.0,
- dtype: torch.dtype = torch.bfloat16,
Bases: torch.nn.Module
Rotary embedding module matching MiMo-V2-Flash partial-RoPE behavior.
Initialization
- inv_freq: torch.Tensor
None
- forward(
- x: torch.Tensor,
- position_ids: torch.Tensor,
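The constructor arguments `rope_theta`, `head_dim`, and `partial_rotary_factor` point to partial RoPE, in which only a leading fraction of each head's dimensions is rotated. The sketch below computes the inverse frequencies under the standard RoPE formula, assuming the rotary width is `int(head_dim * partial_rotary_factor)`. The exact rounding and layout used by `MiMoV2FlashRotaryEmbedding` are not shown on this page, and the function name is illustrative.

```python
def partial_rope_inv_freq(
    rope_theta: float, head_dim: int, partial_rotary_factor: float = 1.0
) -> list[float]:
    """Return one inverse frequency per rotated pair of dimensions.

    Assumption: inv_freq[i] = 1 / rope_theta ** (2i / rotary_dim), where
    rotary_dim = int(head_dim * partial_rotary_factor).
    """
    rotary_dim = int(head_dim * partial_rotary_factor)
    return [1.0 / (rope_theta ** (i / rotary_dim)) for i in range(0, rotary_dim, 2)]


# With head_dim=8 and a factor of 0.5, only 4 dimensions (2 pairs) are rotated;
# the remaining 4 dimensions would pass through unchanged.
print(len(partial_rope_inv_freq(10000.0, 8, 0.5)))  # 2
print(len(partial_rope_inv_freq(10000.0, 8)))       # 4
```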
- class nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2RMSNorm(
- hidden_size: int,
- eps: float = 1e-06,
- dtype: torch.dtype = torch.bfloat16,
Bases: torch.nn.Module
RMSNorm used by MiMo-V2-Flash decoder blocks.
Initialization
- reset_parameters() -> None
- forward(hidden_states: torch.Tensor) -> torch.Tensor
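RMSNorm scales each vector by the reciprocal of its root mean square and then applies a learned per-channel weight. The sketch below shows that arithmetic for one hidden vector using plain Python. It assumes the standard formulation with `eps` added inside the square root; details such as the precision in which `MiMoV2RMSNorm` computes the statistic are not shown on this page.

```python
import math


def rms_norm(x: list[float], weight: list[float], eps: float = 1e-6) -> list[float]:
    """y_i = weight_i * x_i / sqrt(mean(x^2) + eps)."""
    mean_square = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_square + eps)
    return [w * v * scale for v, w in zip(x, weight)]


y = rms_norm([3.0, 4.0], weight=[1.0, 1.0])
print(y)  # roughly [0.849, 1.131]; the mean of the squares is close to 1
```

Unlike LayerNorm, no mean is subtracted and no bias is added, so the direction of the input vector is preserved when the weight is uniform.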
- class nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashAttention(
- config: nemo_automodel.components.models.mimo_v2_flash.config.MiMoV2FlashConfig,
- backend: nemo_automodel.components.models.common.BackendConfig,
- is_swa: bool,
- layer_idx: int,
Bases: torch.nn.Module
MiMo-V2-Flash attention with full and sliding-window variants.
Initialization
- forward(
- hidden_states: torch.Tensor,
- position_embeddings: tuple[torch.Tensor, torch.Tensor],
- attention_mask: torch.Tensor | None = None,
- **kwargs: Any,
- init_weights(
- buffer_device: torch.device,
- init_std: float = 0.02,
- class nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashBlock(
- layer_idx: int,
- config: nemo_automodel.components.models.mimo_v2_flash.config.MiMoV2FlashConfig,
- moe_config: nemo_automodel.components.moe.config.MoEConfig,
- backend: nemo_automodel.components.models.common.BackendConfig,
Bases: torch.nn.Module
Decoder block that alternates dense MLP and routed-MoE layers.
Initialization
- forward(
- hidden_states: torch.Tensor,
- *,
- attention_mask: torch.Tensor | None = None,
- position_embeddings: tuple[torch.Tensor, torch.Tensor],
- padding_mask: torch.Tensor | None = None,
- **kwargs: Any,
- init_weights(buffer_device: torch.device) -> None
- class nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashModel(
- config: nemo_automodel.components.models.mimo_v2_flash.config.MiMoV2FlashConfig,
- backend: nemo_automodel.components.models.common.BackendConfig,
- *,
- moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
- moe_overrides: dict | None = None,
Bases: torch.nn.Module
Backbone model for Xiaomi MiMo-V2-Flash.
Initialization
- _build_causal_mask_mapping(
- inputs_embeds: torch.Tensor,
- attention_mask: torch.Tensor | dict[str, torch.Tensor] | None,
- position_ids: torch.Tensor,
- cache_position: torch.Tensor,
- forward(
- input_ids: torch.Tensor | None = None,
- *,
- inputs_embeds: torch.Tensor | None = None,
- position_ids: torch.Tensor | None = None,
- attention_mask: torch.Tensor | dict[str, torch.Tensor] | None = None,
- padding_mask: torch.Tensor | None = None,
- cache_position: torch.Tensor | None = None,
- **kwargs: Any,
- init_weights(buffer_device: torch.device | None = None) -> None
- class nemo_automodel.components.models.mimo_v2_flash.model.MiMoV2FlashForCausalLM(
- config: nemo_automodel.components.models.mimo_v2_flash.config.MiMoV2FlashConfig,
- moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
- backend: nemo_automodel.components.models.common.BackendConfig | None = None,
- **kwargs,
Bases: nemo_automodel.components.models.common.hf_checkpointing_mixin.HFCheckpointingMixin, torch.nn.Module, nemo_automodel.components.moe.fsdp_mixin.MoEFSDPSyncMixin
Causal LM wrapper for MiMo-V2-Flash with Automodel checkpoint adapters.
Initialization
- _keep_in_fp32_modules_strict
['mlp.gate.e_score_correction_bias', 'attention_sink_bias']
- _pp_keep_self_forward
True
- _skip_init_weights_on_load
True
- classmethod from_config(
- config: nemo_automodel.components.models.mimo_v2_flash.config.MiMoV2FlashConfig,
- moe_config: nemo_automodel.components.moe.config.MoEConfig | None = None,
- backend: nemo_automodel.components.models.common.BackendConfig | None = None,
- **kwargs,
- classmethod from_pretrained(
- pretrained_model_name_or_path: str,
- *model_args,
- **kwargs,
- get_input_embeddings()
- set_input_embeddings(value)
- get_output_embeddings()
- set_output_embeddings(new_embeddings)
- forward(
- input_ids: torch.Tensor | None = None,
- *,
- inputs_embeds: torch.Tensor | None = None,
- position_ids: torch.Tensor | None = None,
- attention_mask: torch.Tensor | dict[str, torch.Tensor] | None = None,
- padding_mask: torch.Tensor | None = None,
- logits_to_keep: int | torch.Tensor = 0,
- **kwargs: Any,
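The `logits_to_keep: int | torch.Tensor = 0` parameter matches a convention used by Hugging Face causal LM heads, where `0` keeps logits for every position, a positive integer keeps only the last positions, and a tensor selects explicit positions. The sketch below assumes this forward method follows the same convention, which this page does not confirm. It models the sequence axis as a list, and the function name is illustrative.

```python
from typing import Union


def select_positions(hidden_states: list, logits_to_keep: Union[int, list[int]] = 0) -> list:
    """Pick the sequence positions that would be projected to logits.

    Assumption: an int k maps to slice(-k, None), so 0 keeps everything
    (slice(-0, None) is slice(0, None)); a list of indices is used directly.
    """
    if isinstance(logits_to_keep, int):
        return hidden_states[slice(-logits_to_keep, None)]
    return [hidden_states[i] for i in logits_to_keep]


tokens = ["h0", "h1", "h2", "h3"]
print(select_positions(tokens))          # ['h0', 'h1', 'h2', 'h3']
print(select_positions(tokens, 1))       # ['h3']
print(select_positions(tokens, [0, 2]))  # ['h0', 'h2']
```

Keeping only the last position is the typical choice during generation, since it avoids materializing logits over the full vocabulary for every prompt token.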
- customize_pipeline_stage_modules(
- module_names_per_stage: list[list[str]],
- *,
- layers_prefix: str,
- text_model: torch.nn.Module | None = None,
Keep the sliding-window attention (SWA) rotary embedding on every pipeline-parallel (PP) stage.
- initialize_weights(
- buffer_device: torch.device | None = None,
- dtype: torch.dtype = torch.bfloat16,
- nemo_automodel.components.models.mimo_v2_flash.model.ModelClass
None