core.models.mimo.model.base
Module Contents
Classes
MimoModel – Multimodal In/Out Model supporting arbitrary combinations of modalities.
Data
API
- core.models.mimo.model.base.logger
'getLogger(…)'
- class core.models.mimo.model.base.MimoModel(
- mimo_config: megatron.core.models.mimo.config.MimoModelConfig,
- )
Bases:
megatron.core.transformer.MegatronModule
Multimodal In/Out Model supporting arbitrary combinations of modalities.
.. warning:: EXPERIMENTAL: This class is experimental, still under active development, and the API is subject to change without notice. Use at your own risk.
.. note:: This implementation is in development and may undergo API changes.
This model processes multiple modalities (e.g., vision, audio) alongside text, combining their embeddings before passing them through a language model.
- Parameters:
mimo_config (MimoModelConfig) – Configuration for the model, including language model and modality submodules
Initialization
Initialize the multimodal model.
.. rubric:: Example

.. code-block:: python

    # Create a model with default configuration
    model = MimoModel(mimo_config)
- align_embeddings_by_token_positions(
- modality_embeddings: Dict[str, torch.Tensor],
- input_ids: torch.Tensor,
- special_token_ids: Dict[str, int],
- )
Align embeddings from different modalities based on special token positions in input_ids.
- Parameters:
modality_embeddings – Dictionary mapping modality names to their embeddings. For all modalities: tensor of shape [num_tokens_for_modality, hidden_dim]
input_ids – Input token IDs of shape [batch_size, seq_len] containing special tokens that mark where each modality’s embeddings should go. The number of special tokens for each modality should exactly match the number of embeddings for that modality.
special_token_ids – Dictionary mapping modality names to their special token IDs
- Returns:
Combined embeddings tensor of shape [seq_len, batch_size, hidden_dim]
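A shape-contract sketch (not the internal implementation): each modality tensor must supply exactly one embedding row per occurrence of that modality's special token in input_ids. Here model is assumed to be an already-constructed MimoModel and IMAGE_TOKEN a hypothetical special token id; how text-token embeddings enter the dictionary is left to the caller (see get_text_embeddings).

.. code-block:: python

    import torch

    hidden_dim = 64
    IMAGE_TOKEN = 32000  # hypothetical special token id

    # Two IMAGE_TOKEN positions in a six-token sequence.
    input_ids = torch.tensor([[5, IMAGE_TOKEN, 17, IMAGE_TOKEN, 9, 2]])

    # The contract: one embedding row per special-token occurrence.
    num_image_tokens = int((input_ids == IMAGE_TOKEN).sum())  # 2
    image_embeddings = torch.randn(num_image_tokens, hidden_dim)

    combined = model.align_embeddings_by_token_positions(
        modality_embeddings={"images": image_embeddings},
        input_ids=input_ids,
        special_token_ids={"images": IMAGE_TOKEN},
    )
    # combined has shape [seq_len, batch_size, hidden_dim] == [6, 1, 64]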
- _initialize_submodules() -> None
Initialize modality submodules from the ModuleSpec configurations.
Only modalities present in the config will be instantiated. For each modality in the config, builds the corresponding submodule using from_spec.
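A rough sketch of this spec-driven construction; apart from from_spec, which is named above, the attribute names below (mimo_config.modality_submodules and the module field on each spec) are assumptions rather than the actual implementation.

.. code-block:: python

    import torch

    def _initialize_submodules(self) -> None:
        # Hypothetical sketch: only modalities present in the config
        # are instantiated, each via its submodule class's from_spec.
        self.modality_submodules = torch.nn.ModuleDict()
        for name, spec in self.mimo_config.modality_submodules.items():
            self.modality_submodules[name] = spec.module.from_spec(spec)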
- _initialize_language_model() -> None
Initialize the language model.
- set_input_tensor(input_tensor)
Set input tensor for pipeline parallelism.
This method is required by Megatron’s pipeline parallel mechanism. It passes the output tensor from the previous stage as input to this stage.
- Parameters:
input_tensor – Tensor or list of tensors passed between pipeline stages
- Returns:
None
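This is normally invoked by Megatron's pipeline-parallel schedule rather than by user code. A sketch of the handoff, with prev_stage_output standing in for the tensor received from the previous stage:

.. code-block:: python

    # On a non-first pipeline stage, the schedule injects the previous
    # stage's output before running this stage's forward pass.
    model.set_input_tensor(prev_stage_output)
    outputs, loss_mask = model(input_ids=input_ids, attention_mask=attention_mask)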
- get_text_embeddings(
- input_ids: torch.Tensor,
- position_ids: torch.Tensor,
- special_token_ids: Dict[str, int],
- )
Get embeddings for text tokens in the input.
- Parameters:
input_ids – Input token IDs of shape [batch_size, seq_len] containing text tokens and potentially special tokens for other modalities.
position_ids – Position IDs corresponding to input tokens, used for positional encoding. Shape [batch_size, seq_len].
special_token_ids – Dictionary mapping modality names to their special token IDs. Used to identify non-text tokens in the input_ids.
- Returns:
Embeddings for text tokens, shape [num_text_tokens, hidden_dim].
- Return type:
torch.Tensor
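A call sketch, assuming a constructed model and a hypothetical image token id 32000; two of the six positions are image tokens, so four text rows come back:

.. code-block:: python

    import torch

    input_ids = torch.tensor([[5, 32000, 17, 32000, 9, 2]])
    position_ids = torch.arange(input_ids.size(1)).unsqueeze(0)

    text_embeddings = model.get_text_embeddings(
        input_ids=input_ids,
        position_ids=position_ids,
        special_token_ids={"images": 32000},
    )
    # text_embeddings has shape [num_text_tokens, hidden_dim] == [4, hidden_dim]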
- forward(
- input_ids: torch.Tensor,
- position_ids: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- loss_mask: Optional[torch.Tensor] = None,
- labels: Optional[torch.Tensor] = None,
- modality_inputs: Optional[Dict[str, Dict[str, Any]]] = None,
- )
Forward pass through the multimodal model.
- Parameters:
input_ids – Input token IDs [batch_size, seq_length]
position_ids – Position IDs [batch_size, seq_length]
attention_mask – Attention mask [batch_size, seq_length]
loss_mask – Loss mask [batch_size, seq_length]
labels – Labels for training
modality_inputs – Dictionary mapping modality names to encoder inputs, for example: {"images": {"clip_encoder": {"pixel_values": clip_images}, "vit_encoder": {"images": vit_images}}, "audio": {"whisper_encoder": {"input_features": whisper_features}}}
- Returns:
Tuple containing model outputs and loss mask
- Return type:
tuple
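An end-to-end call sketch. The encoder names and per-encoder input keys mirror the modality_inputs example above; the actual keys must match the submodules configured in mimo_config:

.. code-block:: python

    outputs, loss_mask = model(
        input_ids=input_ids,
        position_ids=position_ids,
        attention_mask=attention_mask,
        labels=labels,
        modality_inputs={
            "images": {"clip_encoder": {"pixel_values": clip_images}},
            "audio": {"whisper_encoder": {"input_features": whisper_features}},
        },
    )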