core.models.mimo.model.base#

Module Contents#

Classes#

MimoModel

Multimodal In/Out Model supporting arbitrary combinations of modalities.

Data#

API#

core.models.mimo.model.base.logger#

‘getLogger(…)’

class core.models.mimo.model.base.MimoModel(
mimo_config: megatron.core.models.mimo.config.MimoModelConfig,
cp_group=None,
tp_group=None,
)#

Bases: megatron.core.transformer.MegatronModule

Multimodal In/Out Model supporting arbitrary combinations of modalities.

Warning: This class is experimental and still under active development. Its API may change without notice; use at your own risk.

This model processes multiple modalities (e.g., vision, audio) alongside text, combining their embeddings before passing them through a language model.

Parameters:

mimo_config (MimoModelConfig) – Configuration for the model, including language model and modality submodules

Initialization

Initialize the multimodal model.

Example:

# Create a model from a MimoModelConfig
model = MimoModel(mimo_config)

sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#

Build sharded state dict, bypassing parallel_state global fallbacks.

Iterates modality_submodules manually (ModuleDict lacks sharded_state_dict) and injects dp_cp_group from each module’s pg_collection.

align_embeddings_by_token_positions(
modality_embeddings: Dict[str, torch.Tensor],
input_ids: torch.Tensor,
special_token_ids: Dict[str, int],
) torch.Tensor#

Align embeddings from different modalities based on special token positions in input_ids.

Parameters:
  • modality_embeddings – Dictionary mapping modality names to their embeddings. Each value is a tensor of shape (N, H), i.e. (num_tokens_for_modality, hidden_dim).

  • input_ids – Input token IDs. Shape: (B, S) or (S,). Contains special tokens that mark where each modality’s embeddings should go. The number of special tokens for each modality must exactly match the number of embeddings for that modality.

  • special_token_ids – Dictionary mapping modality names to their special token IDs

Returns:

Combined embeddings tensor. Shape: (S, B, H)

Return type:

torch.Tensor
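
The placement rule described for align_embeddings_by_token_positions can be illustrated with a minimal, framework-free sketch. It is not the Megatron implementation: nested lists stand in for tensors, the helper name `align_by_token_positions` is hypothetical, and both the row-major visiting order and the zero fill for positions not claimed by any modality are assumptions made for illustration.

```python
from typing import Dict, List

Vector = List[float]


def align_by_token_positions(
    modality_embeddings: Dict[str, List[Vector]],
    input_ids: List[List[int]],
    special_token_ids: Dict[str, int],
    hidden_dim: int,
) -> List[List[Vector]]:
    """Scatter flat (N, H) per-modality rows into an (S, B, H) nested list."""
    batch = len(input_ids)
    seq_len = len(input_ids[0])
    # Sequence-first output; unclaimed positions stay zero (assumption of this sketch).
    combined = [[[0.0] * hidden_dim for _ in range(batch)] for _ in range(seq_len)]

    for name, rows in modality_embeddings.items():
        token_id = special_token_ids[name]
        # Collect placeholder positions in row-major order (assumed ordering).
        positions = [
            (b, s)
            for b in range(batch)
            for s in range(seq_len)
            if input_ids[b][s] == token_id
        ]
        # Mirror the documented requirement: counts must match exactly.
        if len(positions) != len(rows):
            raise ValueError(
                f"{name}: {len(positions)} special tokens but {len(rows)} embeddings"
            )
        for (b, s), row in zip(positions, rows):
            combined[s][b] = list(row)
    return combined


# Token id 9 marks image positions (hypothetical id); batch of 2, sequence length 3.
ids = [[1, 9, 9], [9, 2, 3]]
image_rows = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
aligned = align_by_token_positions(
    {"images": image_rows}, ids, {"images": 9}, hidden_dim=2
)
```

Indexing the result as `aligned[s][b]` reflects the sequence-first (S, B, H) layout of the return value.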

_initialize_submodules() None#

Initialize modality submodules from the ModuleSpec configurations.

When role is set, only initializes submodules this rank participates in. Stage info is passed to from_spec() to conditionally skip projection.

_initialize_language_model() None#

Initialize the language model.

When role is set, only initializes if this rank participates in the language module.

set_input_tensor(input_tensor)#

Set input tensor for pipeline parallelism.

This method is required by Megatron’s pipeline parallel mechanism. It passes the output tensor from the previous stage as input to this stage.

Parameters:

input_tensor – Either:

  • Dict[str, Tensor]: Maps module names to their input tensors (for multi-module PP)

  • Tensor or List[Tensor]: Single tensor for the language model (kept for backward compatibility)

Returns:

None
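
One way to picture the two accepted forms of input_tensor is a helper that normalizes both into a single mapping. This is only a sketch under stated assumptions: `normalize_input_tensor` and the key `"language"` are hypothetical names, and treating a list as a one-element wrapper is an assumption, not documented behavior.

```python
from typing import Any, Dict, List, Union

LANGUAGE_KEY = "language"  # hypothetical module name, not taken from the library


def normalize_input_tensor(
    input_tensor: Union[Dict[str, Any], List[Any], Any],
) -> Dict[str, Any]:
    """Return a module-name -> tensor mapping for either accepted form."""
    if isinstance(input_tensor, dict):
        # Multi-module pipeline parallelism: already keyed by module name.
        return dict(input_tensor)
    if isinstance(input_tensor, list):
        # Assumed legacy form: a single tensor wrapped in a one-element list.
        if len(input_tensor) != 1:
            raise ValueError("expected exactly one tensor for the language model")
        input_tensor = input_tensor[0]
    # Backward-compatible form: one tensor destined for the language model.
    return {LANGUAGE_KEY: input_tensor}


# Strings stand in for tensors so the sketch runs without a framework.
from_dict = normalize_input_tensor({"images": "img_hidden", "language": "lm_hidden"})
from_list = normalize_input_tensor(["lm_hidden"])
from_single = normalize_input_tensor("lm_hidden")
```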

get_text_embeddings(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
special_token_ids: Dict[str, int],
) torch.Tensor#

Get embeddings for text tokens in the input.

Parameters:
  • input_ids – Input token IDs. Shape: (B, S). Contains text tokens and potentially special tokens for other modalities.

  • position_ids – Position IDs corresponding to input tokens, used for positional encoding. Shape: (B, S)

  • special_token_ids – Dictionary mapping modality names to their special token IDs. Used to identify non-text tokens in the input_ids.

Returns:

Embeddings for text tokens. Shape: (N, H), where N is the number of text tokens.

Return type:

torch.Tensor
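
The role of special_token_ids in get_text_embeddings can be sketched as a mask over input_ids: every position whose ID is not a modality placeholder counts as text. The helper `text_token_mask` and the token IDs below are hypothetical, and lists stand in for tensors.

```python
from typing import Dict, List


def text_token_mask(
    input_ids: List[List[int]], special_token_ids: Dict[str, int]
) -> List[List[bool]]:
    """Mark positions holding ordinary text tokens (True) vs. placeholders (False)."""
    special = set(special_token_ids.values())
    return [[token not in special for token in row] for row in input_ids]


# Hypothetical placeholder IDs for two modalities.
ids = [[5, 32000, 7], [32001, 8, 9]]
mask = text_token_mask(ids, {"images": 32000, "audio": 32001})

# N in the documented (N, H) return shape is the number of True entries.
num_text_tokens = sum(sum(row) for row in mask)
```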

forward(
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
loss_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
modality_inputs: Optional[Dict[str, Dict[str, Any]]] = None,
packing_kwargs: Optional[dict] = None,
)#

Forward pass through the multimodal model.

Parameters:
  • input_ids – Input token IDs. Shape: (B, S)

  • position_ids – Position IDs. Shape: (B, S)

  • attention_mask – Attention mask. Shape: (B, S)

  • loss_mask – Loss mask. Shape: (B, S)

  • labels – Labels for training. Shape: (B, S)

  • modality_inputs – Dictionary mapping modality names to encoder inputs. For example:

    {
        "images": {
            "clip_encoder": {"pixel_values": clip_images},
            "vit_encoder": {"images": vit_images},
        },
        "audio": {
            "whisper_encoder": {"input_features": whisper_features},
        },
    }

  • packing_kwargs – Optional dictionary of kwargs used to construct PackedSeqParams if packed_seq_params is not provided. For example:

    {
        "cu_seqlens_q": cu_seqlens,
        "cu_seqlens_kv": cu_seqlens,
        "cu_seqlens_q_padded": cu_seqlens_padded,
        "cu_seqlens_kv_padded": cu_seqlens_padded,
        "max_seqlen_q": torch.tensor(max(seqlens_padded), dtype=torch.int32),
        "max_seqlen_kv": torch.tensor(max(seqlens_padded), dtype=torch.int32),
    }

Returns:

(output, loss_mask), where the meaning of output depends on the rank’s role:

  • Encoder-only ranks: Dict[str, Tensor] of encoder outputs

  • Language module ranks: language model output (logits or loss)

  • No role (all modules colocated): language model output

Return type:

tuple
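
The packing_kwargs values are cumulative sequence lengths and their maxima. As a rough illustration of how they relate to per-sample lengths, the sketch below derives them with plain Python integers. The helper `build_packing_lengths` and its `pad_to_multiple` parameter are hypothetical; a real call would presumably need these values converted to tensors, as the example in the parameter description does for the maxima.

```python
from itertools import accumulate
from typing import Dict, List, Union


def build_packing_lengths(
    seqlens: List[int], pad_to_multiple: int = 1
) -> Dict[str, Union[List[int], int]]:
    """Derive cumulative and maximum lengths for a packed batch."""
    # Round each length up to the next multiple (ceiling division).
    padded = [-(-n // pad_to_multiple) * pad_to_multiple for n in seqlens]
    # Cumulative boundaries start at 0 and have len(seqlens) + 1 entries.
    cu_seqlens = [0, *accumulate(seqlens)]
    cu_seqlens_padded = [0, *accumulate(padded)]
    return {
        "cu_seqlens_q": cu_seqlens,
        "cu_seqlens_kv": cu_seqlens,
        "cu_seqlens_q_padded": cu_seqlens_padded,
        "cu_seqlens_kv_padded": cu_seqlens_padded,
        "max_seqlen_q": max(padded),
        "max_seqlen_kv": max(padded),
    }


# Three sequences packed together, each padded to a multiple of 4.
lengths = build_packing_lengths([3, 5, 2], pad_to_multiple=4)
```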

_forward_encoders(
input_ids: Optional[torch.Tensor],
modality_inputs: Optional[Dict[str, Dict[str, Any]]],
input_tensors: Optional[Dict[str, torch.Tensor]],
) Dict[str, torch.Tensor]#

Forward pass for encoder modules on this rank.

Parameters:
  • input_ids – Input token IDs; may be None

  • modality_inputs – Raw inputs for each modality (images, audio, etc.)

  • input_tensors – Hidden states from previous pipeline stages

Returns:

Dict mapping encoder names to their output tensors

_attach_modality_split_sizes(
output: torch.Tensor,
input_ids: Optional[torch.Tensor],
encoder_name: str,
) None#

Annotate flat modality outputs with per-sample split sizes for bridge fan-out.

Only attaches when per-sample token counts are non-uniform. Uniform counts give equal splits, which the bridge’s torch.tensor_split fallback already produces, so the metadata would be a no-op.

TODO(mimo): non-uniform per-sample counts in fan-in (encoder DP > LM DP) are not supported. Multiple encoder ranks contribute slices to a single LM peer, and the receiver-side torch.cat path in BridgeCommunicator has no metadata channel today, so per-sample boundaries are lost on the LM rank. Lift this by routing per-sample sizes through the bridge alongside the activations and adding a sample-aligned concat path.
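
The uniform versus non-uniform distinction can be sketched by counting placeholder tokens per sample. This is an illustration only: `per_sample_split_sizes` is a hypothetical helper, lists stand in for tensors, and returning None for the uniform case is this sketch's way of expressing "attach nothing".

```python
from typing import List, Optional


def per_sample_split_sizes(
    input_ids: List[List[int]], special_token_id: int
) -> Optional[List[int]]:
    """Return per-sample token counts, or None when every sample has the same count."""
    counts = [row.count(special_token_id) for row in input_ids]
    if len(set(counts)) <= 1:
        # Uniform counts: an equal split already gives the right boundaries.
        return None
    return counts


# Token id 9 is a hypothetical placeholder for one encoder's tokens.
uniform = per_sample_split_sizes([[9, 9, 1], [2, 9, 9]], 9)
ragged = per_sample_split_sizes([[9, 9, 9], [2, 9, 1]], 9)
```

In the ragged case the sizes sum to the number of rows in the flat encoder output, which is what lets a receiver cut that output back into per-sample pieces.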

_has_encoder_tokens(
input_ids: Optional[torch.Tensor],
encoder_name: str,
) bool#

Return whether the batch contains tokens for an encoder module.

_empty_encoder_output(encoder_name: str) torch.Tensor#

Return the bridge payload for text-only non-colocated batches.

_build_packed_seq_params(
packing_kwargs: Optional[dict],
) Optional[megatron.core.packed_seq_params.PackedSeqParams]#

Build THD PackedSeqParams from packing_kwargs (None if not packing).

_shard_language_inputs(
embeddings: Optional[torch.Tensor],
labels: Optional[torch.Tensor],
loss_mask: Optional[torch.Tensor],
packed_seq_params: Optional[megatron.core.packed_seq_params.PackedSeqParams] = None,
) Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[megatron.core.packed_seq_params.PackedSeqParams]]#

Apply CP/SP sharding via the partition adapter, or pass through if inactive.

embeddings is sequence-first with shape (S, B, H) (None on non-first PP stages) and is returned with shape (S/(cp*tp), B, H); labels and loss_mask have shape (B, S).
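
The shape change for embeddings is simple arithmetic on the sequence dimension. The sketch below computes only the resulting local shape; it does not model which tokens land on which rank. `local_embedding_shape` is a hypothetical helper, and the divisibility check is an assumption of this sketch.

```python
from typing import Tuple


def local_embedding_shape(
    shape: Tuple[int, int, int], cp_size: int, tp_size: int
) -> Tuple[int, int, int]:
    """Map a full (S, B, H) shape to the per-rank (S/(cp*tp), B, H) shape."""
    seq_len, batch, hidden = shape
    shards = cp_size * tp_size
    if seq_len % shards != 0:
        # Assumed requirement: the sequence must split evenly across shards.
        raise ValueError(f"sequence length {seq_len} is not divisible by {shards}")
    return (seq_len // shards, batch, hidden)


# Sequence length 4096 with cp=2 and tp=4 leaves 512 positions per rank.
local = local_embedding_shape((4096, 2, 1024), cp_size=2, tp_size=4)
```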

_forward_language_module(
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor],
loss_mask: Optional[torch.Tensor],
labels: Optional[torch.Tensor],
input_tensors: Optional[Dict[str, torch.Tensor]],
packing_kwargs: Optional[dict] = None,
) Tuple[Any, Optional[torch.Tensor]]#

Forward pass for language module on this rank.

Parameters:
  • input_ids – Token IDs

  • position_ids – Position IDs

  • attention_mask – Attention mask. Must be None under context parallelism (CP-local hidden states cannot line up with a dense mask); mask via a causal attn_mask_type or packed_seq_params instead.

  • loss_mask – Loss mask for per-token loss normalization

  • labels – Labels for loss computation

  • input_tensors – Hidden states or embeddings from previous stage

  • packing_kwargs – Optional kwargs to construct packed (THD) sequence params.

Returns:

Tuple of (language model output, possibly CP-sharded loss mask). The output is hidden states, logits, or loss depending on the stage.

_build_colocated_communicators()#

destroy() None#

Release process groups owned by this MimoModel.

_apply_colocated_comms(modality_embeddings)#

Transform encoder embeddings from encoder TP/DP to LLM TP/DP layout.

_forward_all_modules(
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor],
loss_mask: Optional[torch.Tensor],
labels: Optional[torch.Tensor],
modality_inputs: Optional[Dict[str, Dict[str, Any]]],
packing_kwargs: Optional[dict] = None,
)#

Forward pass when all modules are on all ranks (no multi-module PP).

This is the original behavior, preserved for backward compatibility.