core.models.mimo.model.base#

Module Contents#

Classes#

MimoModel

Multimodal In/Out Model supporting arbitrary combinations of modalities.

Data#

API#

core.models.mimo.model.base.logger#

'getLogger(...)'

class core.models.mimo.model.base.MimoModel(
mimo_config: megatron.core.models.mimo.config.MimoModelConfig,
cp_group=None,
tp_group=None,
)#

Bases: megatron.core.transformer.MegatronModule

Multimodal In/Out Model supporting arbitrary combinations of modalities.

.. warning:: EXPERIMENTAL: This class is experimental, still under active development, and the API is subject to change without notice. Use at your own risk.

This model processes multiple modalities (e.g., vision, audio) alongside text, combining their embeddings before passing them through a language model.

Parameters:
  • mimo_config (MimoModelConfig) – Configuration for the model, including the language model and modality submodules

  • cp_group – Optional context-parallel process group

  • tp_group – Optional tensor-parallel process group

Initialization

Initialize the multimodal model.

.. rubric:: Example

# Create a model with default configuration
model = MimoModel(mimo_config)

sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#

Build sharded state dict, bypassing parallel_state global fallbacks.

Iterates over modality_submodules manually (ModuleDict does not implement sharded_state_dict) and injects dp_cp_group from each module's pg_collection.
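
A minimal sketch of collecting the sharded state dict for distributed checkpointing; the prefix value here is illustrative:

# Build a sharded state dict suitable for Megatron distributed checkpointing
sharded_sd = model.sharded_state_dict(prefix='module.')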

align_embeddings_by_token_positions(
modality_embeddings: Dict[str, torch.Tensor],
input_ids: torch.Tensor,
special_token_ids: Dict[str, int],
) → torch.Tensor#

Align embeddings from different modalities based on special token positions in input_ids.

Parameters:
  • modality_embeddings – Dictionary mapping modality names to their embeddings. Each value is a tensor of shape (N, H), where N is the number of tokens for that modality and H is the hidden dimension.

  • input_ids – Input token IDs. Shape: (B, S) or (S,). Contains special tokens that mark where each modality's embeddings should go. The number of special tokens for each modality must exactly match the number of embeddings for that modality.

  • special_token_ids – Dictionary mapping modality names to their special token IDs

Returns:

Combined embeddings tensor. Shape: (S, B, H)

Return type:

torch.Tensor
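
A minimal usage sketch. The token ID 32000 and the "text" embedding key are illustrative assumptions, not part of the documented API:

import torch

hidden_size = 4096  # illustrative hidden dimension

# One sequence of five tokens: <bos>, three image placeholders, <eos>
input_ids = torch.tensor([[1, 32000, 32000, 32000, 2]])  # (B=1, S=5)
image_embeddings = torch.randn(3, hidden_size)  # (N=3, H), one row per image token
text_embeddings = torch.randn(2, hidden_size)   # embeddings for the two text tokens

combined = model.align_embeddings_by_token_positions(
    modality_embeddings={"images": image_embeddings, "text": text_embeddings},
    input_ids=input_ids,
    special_token_ids={"images": 32000},
)
# combined has shape (S, B, H) = (5, 1, hidden_size)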

_initialize_submodules() → None#

Initialize modality submodules from the ModuleSpec configurations.

When a role is set, only the submodules this rank participates in are initialized. Stage info is passed to from_spec() so that the projection can be skipped conditionally.

_initialize_language_model() → None#

Initialize the language model.

When a role is set, the language model is initialized only if this rank participates in the language module.

set_input_tensor(input_tensor)#

Set input tensor for pipeline parallelism.

This method is required by Megatron’s pipeline parallel mechanism. It passes the output tensor from the previous stage as input to this stage.

Parameters:

input_tensor

Either:

  • Dict[str, Tensor]: Maps module names to their input tensors (for multi-module pipeline parallelism)

  • Tensor or List[Tensor]: Single tensor for the language model (backward compatibility)

Returns:

None
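
Illustrative sketches of both accepted forms; the "language_module" key is a hypothetical module name that depends on your configuration:

# Multi-module pipeline parallelism: route tensors to named modules
model.set_input_tensor({"language_module": prev_stage_hidden})

# Single-tensor form for the language model (backward compatibility)
model.set_input_tensor(prev_stage_hidden)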

get_text_embeddings(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
special_token_ids: Dict[str, int],
) → torch.Tensor#

Get embeddings for text tokens in the input.

Parameters:
  • input_ids – Input token IDs. Shape: (B, S). Contains text tokens and potentially special tokens for other modalities.

  • position_ids – Position IDs corresponding to input tokens, used for positional encoding. Shape: (B, S)

  • special_token_ids – Dictionary mapping modality names to their special token IDs. Used to identify non-text tokens in the input_ids.

Returns:

Embeddings for text tokens. Shape: (N, H), where N is the number of text tokens.

Return type:

torch.Tensor
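
A short sketch, assuming input_ids and position_ids of shape (B, S) and the hypothetical image token ID from the example above:

text_embeddings = model.get_text_embeddings(
    input_ids=input_ids,        # (B, S), text tokens plus modality special tokens
    position_ids=position_ids,  # (B, S)
    special_token_ids={"images": 32000},
)
# text_embeddings has shape (N, H), where N is the number of text tokens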

forward(
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
loss_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
modality_inputs: Optional[Dict[str, Dict[str, Any]]] = None,
packing_kwargs: Optional[dict] = None,
)#

Forward pass through the multimodal model.

Parameters:
  • input_ids – Input token IDs. Shape: (B, S)

  • position_ids – Position IDs. Shape: (B, S)

  • attention_mask – Attention mask. Shape: (B, S)

  • loss_mask – Loss mask. Shape: (B, S)

  • labels – Labels for training. Shape: (B, S)

  • modality_inputs – Dictionary mapping modality names to encoder inputs. For example:

    {
        "images": {
            "clip_encoder": {"pixel_values": clip_images},
            "vit_encoder": {"images": vit_images},
        },
        "audio": {
            "whisper_encoder": {"input_features": whisper_features},
        },
    }

  • packing_kwargs – Optional dictionary of kwargs used to construct PackedSeqParams when packed_seq_params is not provided. For example:

    {
        "cu_seqlens_q": cu_seqlens,
        "cu_seqlens_kv": cu_seqlens,
        "cu_seqlens_q_padded": cu_seqlens_padded,
        "cu_seqlens_kv_padded": cu_seqlens_padded,
        "max_seqlen_q": torch.tensor(max(seqlens_padded), dtype=torch.int32),
        "max_seqlen_kv": torch.tensor(max(seqlens_padded), dtype=torch.int32),
    }

Returns:

(output, loss_mask), where the semantics of output depend on role:

  • Encoder-only ranks: Dict[str, Tensor] of encoder outputs

  • Language module ranks: language model output (logits or loss)

  • No role (all modules colocated): language model output

Return type:

tuple
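
A hedged end-to-end sketch, reusing the encoder naming from the modality_inputs example above; actual encoder names come from your MimoModelConfig:

output, loss_mask = model(
    input_ids=input_ids,        # (B, S)
    position_ids=position_ids,  # (B, S)
    labels=labels,              # (B, S), enables loss computation
    loss_mask=loss_mask,        # (B, S)
    modality_inputs={
        "images": {"clip_encoder": {"pixel_values": clip_images}},
    },
)
# On language-module ranks, output is the language model output (logits or loss);
# on encoder-only ranks it is a Dict[str, Tensor] of encoder outputs.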

_forward_encoders(
modality_inputs: Optional[Dict[str, Dict[str, Any]]],
input_tensors: Optional[Dict[str, torch.Tensor]],
) → Dict[str, torch.Tensor]#

Forward pass for encoder modules on this rank.

Parameters:
  • modality_inputs – Raw inputs for each modality (images, audio, etc.)

  • input_tensors – Hidden states from previous pipeline stages

Returns:

Dict mapping encoder names to their output tensors

_forward_language_module(
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor],
labels: Optional[torch.Tensor],
input_tensors: Optional[Dict[str, torch.Tensor]],
) → torch.Tensor#

Forward pass for language module on this rank.

Parameters:
  • input_ids – Token IDs

  • position_ids – Position IDs

  • attention_mask – Attention mask

  • labels – Labels for loss computation

  • input_tensors – Hidden states or embeddings from previous stage

Returns:

Language model output (hidden states, logits, or loss depending on stage)

_forward_all_modules(
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor],
attention_mask: Optional[torch.Tensor],
loss_mask: Optional[torch.Tensor],
labels: Optional[torch.Tensor],
modality_inputs: Optional[Dict[str, Dict[str, Any]]],
packing_kwargs: Optional[dict] = None,
)#

Forward pass when all modules are on all ranks (no multi-module PP).

This is the original behavior, preserved for backward compatibility.