core.models.mimo.model.base

Module Contents

Classes

MimoModel

Multimodal In/Out Model supporting arbitrary combinations of modalities.

Data

API

core.models.mimo.model.base.logger

‘getLogger(…)’

class core.models.mimo.model.base.MimoModel(
mimo_config: megatron.core.models.mimo.config.MimoModelConfig,
)

Bases: megatron.core.transformer.MegatronModule

Multimodal In/Out Model supporting arbitrary combinations of modalities.

Warning: EXPERIMENTAL. This class is experimental, still under active development, and its API is subject to change without notice. Use at your own risk.

Note: This implementation is in development and may undergo API changes.

This model processes multiple modalities (e.g., vision, audio) alongside text, combining their embeddings before passing them through a language model.

Parameters:

mimo_config (MimoModelConfig) – Configuration for the model, including the language model and the modality submodules.

Initialization

Initialize the multimodal model.

Example:

# Create a model from a populated MimoModelConfig
model = MimoModel(mimo_config)

align_embeddings_by_token_positions(
modality_embeddings: Dict[str, torch.Tensor],
input_ids: torch.Tensor,
special_token_ids: Dict[str, int],
) → torch.Tensor

Align embeddings from different modalities based on special token positions in input_ids.

Parameters:
  • modality_embeddings – Dictionary mapping modality names to their embeddings. Each value is a tensor of shape [num_tokens_for_modality, hidden_dim].

  • input_ids – Input token IDs of shape [batch_size, seq_len] containing special tokens that mark where each modality’s embeddings should go. The number of special tokens for each modality should exactly match the number of embeddings for that modality.

  • special_token_ids – Dictionary mapping modality names to their special token IDs.

Returns:

Combined embeddings tensor of shape [seq_len, batch_size, hidden_dim]
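The alignment step can be sketched in plain PyTorch as shown below. This is a minimal, standalone sketch and not the library's implementation. It makes three assumptions: each modality's embeddings are ordered batch-first and then by sequence position, text embeddings (if present) are passed under the key "text" and fill every position that matches no special token, and the placeholder ID 99 in the usage lines is hypothetical.

```python
from typing import Dict

import torch


def align_embeddings_by_token_positions(
    modality_embeddings: Dict[str, torch.Tensor],
    input_ids: torch.Tensor,
    special_token_ids: Dict[str, int],
) -> torch.Tensor:
    # Assumption: text embeddings, if any, are keyed "text" and fill every
    # position that matches none of the special token IDs.
    if not modality_embeddings:
        raise ValueError("at least one modality's embeddings are required")
    reference = next(iter(modality_embeddings.values()))
    batch_size, seq_len = input_ids.shape
    combined = torch.zeros(
        batch_size, seq_len, reference.size(-1),
        dtype=reference.dtype, device=reference.device,
    )  # batch-first [batch_size, seq_len, hidden_dim]
    for name, emb in modality_embeddings.items():
        if name == "text":
            mask = torch.ones_like(input_ids, dtype=torch.bool)
            for token_id in special_token_ids.values():
                mask &= input_ids != token_id
        else:
            mask = input_ids == special_token_ids[name]
        num_slots = int(mask.sum())
        if num_slots != emb.size(0):
            raise ValueError(f"{name}: {num_slots} token slots but {emb.size(0)} embeddings")
        # Boolean-mask assignment fills slots in row-major (batch, then sequence) order.
        combined[mask] = emb.to(combined.dtype)
    # Return sequence-first [seq_len, batch_size, hidden_dim].
    return combined.transpose(0, 1).contiguous()


# Usage with a hypothetical image placeholder ID of 99.
input_ids = torch.tensor([[5, 99, 99, 7], [99, 8, 9, 10]])
text = torch.arange(10, dtype=torch.float32).reshape(5, 2)  # 5 text tokens
images = -torch.ones(3, 2)  # 3 image-token embeddings
out = align_embeddings_by_token_positions(
    {"text": text, "images": images}, input_ids, {"images": 99}
)
print(out.shape)  # torch.Size([4, 2, 2])
```

Because the count check runs before any scatter, a mismatch between placeholder tokens and embeddings surfaces as an explicit error rather than a silent misalignment.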

_initialize_submodules() → None

Initialize modality submodules from the ModuleSpec configurations.

Only modalities present in the config are instantiated. For each one, the corresponding submodule is built with from_spec.
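The pattern described above builds one submodule per configured modality through a from_spec constructor. A rough sketch of it follows. SimpleSpec and ToyImageSubmodule are hypothetical stand-ins for ModuleSpec and a real modality submodule. Only the loop structure is meant to mirror the description.

```python
from dataclasses import dataclass, field
from typing import Any, Dict

import torch.nn as nn


@dataclass
class SimpleSpec:
    """Hypothetical stand-in for a ModuleSpec: a module class plus constructor kwargs."""

    module: type  # class exposing a from_spec classmethod
    params: Dict[str, Any] = field(default_factory=dict)


class ToyImageSubmodule(nn.Module):
    """Hypothetical modality submodule that builds itself from a spec."""

    def __init__(self, in_dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(in_dim, hidden_dim)

    @classmethod
    def from_spec(cls, spec: SimpleSpec) -> "ToyImageSubmodule":
        return cls(**spec.params)


def initialize_submodules(specs: Dict[str, SimpleSpec]) -> nn.ModuleDict:
    # Only modalities present in the config are instantiated, one submodule each.
    submodules = nn.ModuleDict()
    for modality_name, spec in specs.items():
        submodules[modality_name] = spec.module.from_spec(spec)
    return submodules


# Usage: a config that declares only an "images" modality.
specs = {"images": SimpleSpec(ToyImageSubmodule, {"in_dim": 16, "hidden_dim": 8})}
modality_submodules = initialize_submodules(specs)
```

Storing the submodules in an nn.ModuleDict registers their parameters with the parent module, so they are visible to optimizers and state_dict.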

_initialize_language_model() → None

Initialize the language model.

set_input_tensor(input_tensor)

Set input tensor for pipeline parallelism.

This method is required by Megatron’s pipeline parallel mechanism. It passes the output tensor from the previous stage as input to this stage.

Parameters:

input_tensor – Tensor or list of tensors passed between pipeline stages

Returns:

None
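Below is a minimal sketch of one common way such a method is written. It assumes the model unwraps a one-element list and forwards the tensor to its language model. ToyLanguageModel and ToyMimoStage are hypothetical classes, and the real method may handle lists differently.

```python
from typing import List, Optional, Union

import torch
import torch.nn as nn


class ToyLanguageModel(nn.Module):
    """Hypothetical language model that records the tensor from the previous stage."""

    def __init__(self) -> None:
        super().__init__()
        self.input_tensor: Optional[torch.Tensor] = None

    def set_input_tensor(self, input_tensor: torch.Tensor) -> None:
        self.input_tensor = input_tensor


class ToyMimoStage(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.language_model = ToyLanguageModel()

    def set_input_tensor(self, input_tensor: Union[torch.Tensor, List[torch.Tensor]]) -> None:
        # Assumption: the schedule passes a single tensor or a one-element list.
        if isinstance(input_tensor, list):
            if len(input_tensor) != 1:
                raise ValueError("expected exactly one input tensor per stage")
            input_tensor = input_tensor[0]
        # Delegate to the language model.
        self.language_model.set_input_tensor(input_tensor)


# Usage: hand this stage the activations produced by the previous stage.
stage = ToyMimoStage()
hidden_states = torch.zeros(4, 2, 8)  # e.g. [seq_len, batch_size, hidden_dim]
stage.set_input_tensor([hidden_states])
```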

get_text_embeddings(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
special_token_ids: Dict[str, int],
) → torch.Tensor

Get embeddings for text tokens in the input.

Parameters:
  • input_ids – Input token IDs of shape [batch_size, seq_len] containing text tokens and potentially special tokens for other modalities.

  • position_ids – Position IDs corresponding to input tokens, used for positional encoding. Shape [batch_size, seq_len].

  • special_token_ids – Dictionary mapping modality names to their special token IDs. Used to identify non-text tokens in the input_ids.

Returns:

Embeddings for text tokens, shape [num_text_tokens, hidden_dim].

Return type:

torch.Tensor
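Text-embedding extraction can be sketched as follows. The sketch assumes learned word and position embedding tables. word_embeddings and position_embeddings are hypothetical parameters standing in for whatever embedding layer the model actually uses. Text tokens are those that match none of the special token IDs, and they are returned in batch-first order, giving shape [num_text_tokens, hidden_dim].

```python
from typing import Dict

import torch
import torch.nn as nn


def get_text_embeddings(
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    special_token_ids: Dict[str, int],
    word_embeddings: nn.Embedding,  # hypothetical token embedding table
    position_embeddings: nn.Embedding,  # hypothetical learned position table
) -> torch.Tensor:
    # True where the token is ordinary text, i.e. matches no modality placeholder.
    text_mask = torch.ones_like(input_ids, dtype=torch.bool)
    for token_id in special_token_ids.values():
        text_mask &= input_ids != token_id
    # Boolean indexing flattens the selected tokens in batch-first order.
    text_ids = input_ids[text_mask]
    text_positions = position_ids[text_mask]
    # [num_text_tokens, hidden_dim]
    return word_embeddings(text_ids) + position_embeddings(text_positions)


torch.manual_seed(0)
word_embeddings = nn.Embedding(100, 8)  # hypothetical vocab of 100, hidden_dim of 8
position_embeddings = nn.Embedding(16, 8)  # hypothetical max sequence length of 16
input_ids = torch.tensor([[5, 99, 99, 7], [99, 8, 9, 10]])  # 99: hypothetical image token
position_ids = torch.arange(4).unsqueeze(0).expand(2, -1)
text_emb = get_text_embeddings(
    input_ids, position_ids, {"images": 99}, word_embeddings, position_embeddings
)
print(text_emb.shape)  # torch.Size([5, 8])
```

Selecting the text tokens before the lookup means placeholder IDs never pass through the text embedding table.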

forward(
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
loss_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
modality_inputs: Optional[Dict[str, Dict[str, Any]]] = None,
)

Forward pass through the multimodal model.

Parameters:
  • input_ids – Input token IDs [batch_size, seq_length]

  • position_ids – Position IDs [batch_size, seq_length]

  • attention_mask – Attention mask [batch_size, seq_length]

  • loss_mask – Loss mask [batch_size, seq_length]

  • labels – Target labels used to compute the training loss

  • modality_inputs – Dictionary mapping modality names to encoder inputs. Each modality maps encoder names to the keyword arguments for that encoder. For example:

      {
          "images": {
              "clip_encoder": {"pixel_values": clip_images},
              "vit_encoder": {"images": vit_images},
          },
          "audio": {
              "whisper_encoder": {"input_features": whisper_features},
          },
      }

Returns:

Tuple containing model outputs and loss mask

Return type:

tuple
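The toy model below puts the pieces together to sketch the forward flow. It encodes each modality from its nested modality_inputs entry, embeds the text tokens, scatters both into one sequence, and runs a language model. Everything in it is hypothetical: ToyMimo, ToyImageEncoder, the linear lm_head standing in for the language model, placeholder ID 99, and the choice to concatenate multiple encoders' outputs along the token dimension. The sketch omits position IDs and the attention mask for brevity. It returns per-token loss when labels are given and logits otherwise; this also is an assumption, not the documented contract.

```python
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

IMAGE_TOKEN_ID = 99  # hypothetical placeholder token ID for images
HIDDEN = 8


class ToyImageEncoder(nn.Module):
    """Hypothetical encoder mapping [num_patches, 4] features to [num_patches, HIDDEN]."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(4, HIDDEN)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        return self.proj(pixel_values)


class ToyMimo(nn.Module):
    def __init__(self, vocab_size: int = 100) -> None:
        super().__init__()
        self.special_token_ids = {"images": IMAGE_TOKEN_ID}
        # modality name -> encoder name -> encoder module
        self.modality_submodules = nn.ModuleDict(
            {"images": nn.ModuleDict({"toy_encoder": ToyImageEncoder()})}
        )
        self.word_embeddings = nn.Embedding(vocab_size, HIDDEN)
        self.lm_head = nn.Linear(HIDDEN, vocab_size)  # stand-in for the language model

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        loss_mask: Optional[torch.Tensor] = None,
        modality_inputs: Optional[Dict[str, Dict[str, Any]]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        batch_size, seq_len = input_ids.shape
        combined = torch.zeros(batch_size, seq_len, HIDDEN)
        # Text positions: tokens that match no modality placeholder.
        text_mask = torch.ones_like(input_ids, dtype=torch.bool)
        for token_id in self.special_token_ids.values():
            text_mask &= input_ids != token_id
        combined[text_mask] = self.word_embeddings(input_ids[text_mask])
        # Encode each modality; concatenating multiple encoders' outputs is an assumption.
        for name, encoder_inputs in (modality_inputs or {}).items():
            outputs = [
                self.modality_submodules[name][enc_name](**kwargs)
                for enc_name, kwargs in encoder_inputs.items()
            ]
            combined[input_ids == self.special_token_ids[name]] = torch.cat(outputs, dim=0)
        hidden = combined.transpose(0, 1)  # sequence-first [seq_len, batch_size, HIDDEN]
        logits = self.lm_head(hidden).transpose(0, 1)  # [batch_size, seq_len, vocab_size]
        if labels is None:
            return logits, loss_mask
        # Per-token loss [batch_size, seq_len]; masking and reduction are left to the caller.
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)), labels.reshape(-1), reduction="none"
        ).view_as(labels)
        return loss, loss_mask


torch.manual_seed(0)
model = ToyMimo()
input_ids = torch.tensor([[5, 99, 99, 7]])  # two image placeholders
labels = torch.tensor([[1, 2, 3, 4]])
loss_mask = torch.tensor([[1.0, 0.0, 0.0, 1.0]])  # e.g. skip image positions
modality_inputs = {"images": {"toy_encoder": {"pixel_values": torch.randn(2, 4)}}}
loss, out_mask = model(
    input_ids, labels=labels, loss_mask=loss_mask, modality_inputs=modality_inputs
)
```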