core.models.mimo.model.base#
Module Contents#
Classes#
- MimoModel: Multimodal In/Out Model supporting arbitrary combinations of modalities.
Data#
API#
- core.models.mimo.model.base.logger#
'getLogger(...)'
- class core.models.mimo.model.base.MimoModel(
- mimo_config: megatron.core.models.mimo.config.MimoModelConfig,
- cp_group=None,
- tp_group=None,
- )#
Bases:
megatron.core.transformer.MegatronModule
Multimodal In/Out Model supporting arbitrary combinations of modalities.
.. warning:: EXPERIMENTAL: This class is experimental, still under active development, and the API is subject to change without notice. Use at your own risk.
.. note:: This implementation is in development and may undergo API changes.
This model processes multiple modalities (e.g., vision, audio) alongside text, combining their embeddings before passing them through a language model.
- Parameters:
mimo_config (MimoModelConfig) – Configuration for the model, including language model and modality submodules
Initialization
Initialize the multimodal model.
.. rubric:: Example
# Create a model with default configuration
model = MimoModel(mimo_config)
- sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#
Build sharded state dict, bypassing parallel_state global fallbacks.
Iterates modality_submodules manually (ModuleDict lacks sharded_state_dict) and injects dp_cp_group from each module’s pg_collection.
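A hedged sketch of how the resulting sharded state dict might be consumed by Megatron's distributed checkpointing; the checkpoint path is a placeholder, and the exact `dist_checkpointing.save` signature should be verified against your Megatron-Core version.

```python
# Sketch only: saving a MimoModel through Megatron's distributed checkpointing.
# Assumes `model` is an initialized MimoModel and that
# megatron.core.dist_checkpointing is available in this installation.
from megatron.core import dist_checkpointing

sharded_sd = model.sharded_state_dict(prefix='')
dist_checkpointing.save(
    sharded_state_dict=sharded_sd,
    checkpoint_dir='/path/to/mimo_ckpt',  # placeholder path
)
```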
- align_embeddings_by_token_positions(
- modality_embeddings: Dict[str, torch.Tensor],
- input_ids: torch.Tensor,
- special_token_ids: Dict[str, int],
- )#
Align embeddings from different modalities based on special token positions in input_ids.
- Parameters:
modality_embeddings – Dictionary mapping modality names to their embeddings. Each value is a tensor of shape (N, H), i.e. (num_tokens_for_modality, hidden_dim).
input_ids – Input token IDs. Shape: (B, S) or (S,) Contains special tokens that mark where each modality’s embeddings should go. The number of special tokens for each modality should exactly match the number of embeddings for that modality.
special_token_ids – Dictionary mapping modality names to their special token IDs
- Returns:
Combined embeddings tensor. Shape: (S, B, H)
- Return type:
torch.Tensor
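To make the alignment contract concrete, the following is a toy, self-contained PyTorch sketch of the same idea; the special token id, shapes, and random tensors are illustrative assumptions, not the method's actual implementation.

```python
import torch

B, S, H = 2, 6, 8
special_token_ids = {"images": 32000}                            # assumed token id
input_ids = torch.tensor([[1, 32000, 32000, 5, 6, 7],
                          [1, 2, 32000, 32000, 4, 5]])           # (B, S)

image_mask = input_ids == special_token_ids["images"]            # (B, S), 4 True entries
image_embeddings = torch.randn(int(image_mask.sum()), H)         # (N, H), N must equal 4
text_embeddings = torch.randn(int((~image_mask).sum()), H)       # stand-in text embeddings

combined = torch.empty(B, S, H)
combined[image_mask] = image_embeddings     # scatter modality embeddings at special tokens
combined[~image_mask] = text_embeddings     # text embeddings fill the remaining positions
combined = combined.transpose(0, 1).contiguous()                 # (S, B, H), as returned here
```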
- _initialize_submodules() None#
Initialize modality submodules from the ModuleSpec configurations.
When role is set, only initializes submodules this rank participates in. Stage info is passed to from_spec() to conditionally skip projection.
- _initialize_language_model() None#
Initialize the language model.
When role is set, only initializes if this rank participates in language module.
- set_input_tensor(input_tensor)#
Set input tensor for pipeline parallelism.
This method is required by Megatron’s pipeline parallel mechanism. It passes the output tensor from the previous stage as input to this stage.
- Parameters:
input_tensor –
Either:
Dict[str, Tensor]: Maps module names to their input tensors (for multi-module PP)
Tensor or List[Tensor]: Single tensor for language model (backward compat)
- Returns:
None
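A brief sketch of the two accepted argument forms; the module-name keys and hidden-state tensors below are placeholders rather than guaranteed names.

```python
# Multi-module pipeline parallelism: map module names to their incoming activations.
# The keys shown are placeholders; use the module names from your MimoModelConfig.
model.set_input_tensor({"vision_encoder": vision_hidden, "language_module": lm_hidden})

# Backward-compatible form: a single tensor (or list of tensors) for the language model.
model.set_input_tensor(lm_hidden)
```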
- get_text_embeddings(
- input_ids: torch.Tensor,
- position_ids: torch.Tensor,
- special_token_ids: Dict[str, int],
- )#
Get embeddings for text tokens in the input.
- Parameters:
input_ids – Input token IDs. Shape: (B, S) Contains text tokens and potentially special tokens for other modalities.
position_ids – Position IDs corresponding to input tokens, used for positional encoding. Shape: (B, S)
special_token_ids – Dictionary mapping modality names to their special token IDs. Used to identify non-text tokens in the input_ids.
- Returns:
Embeddings for text tokens. Shape: (N, H), where N is the number of text tokens.
- Return type:
torch.Tensor
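An illustrative call, assuming a single image modality whose special token id is 32000; the tensors and the id are made up for the example.

```python
import torch

input_ids = torch.tensor([[1, 32000, 32000, 5, 6, 7]])        # (B=1, S=6)
position_ids = torch.arange(input_ids.size(1)).unsqueeze(0)   # (1, S)

text_emb = model.get_text_embeddings(
    input_ids=input_ids,
    position_ids=position_ids,
    special_token_ids={"images": 32000},
)
# Expected shape: (N, H), with N == (input_ids != 32000).sum() == 4 text tokens.
```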
- forward(
- input_ids: torch.Tensor,
- position_ids: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- loss_mask: Optional[torch.Tensor] = None,
- labels: Optional[torch.Tensor] = None,
- modality_inputs: Optional[Dict[str, Dict[str, Any]]] = None,
- packing_kwargs: Optional[dict] = None,
- )#
Forward pass through the multimodal model.
- Parameters:
input_ids – Input token IDs. Shape: (B, S)
position_ids – Position IDs. Shape: (B, S)
attention_mask – Attention mask. Shape: (B, S)
loss_mask – Loss mask. Shape: (B, S)
labels – Labels for training. Shape: (B, S)
modality_inputs – Dictionary mapping modality names to encoder inputs. For example: {"images": {"clip_encoder": {"pixel_values": clip_images}, "vit_encoder": {"images": vit_images}}, "audio": {"whisper_encoder": {"input_features": whisper_features}}}
packing_kwargs – Optional dictionary of kwargs used to construct PackedSeqParams if packed_seq_params is not provided. For example: {"cu_seqlens_q": cu_seqlens, "cu_seqlens_kv": cu_seqlens, "cu_seqlens_q_padded": cu_seqlens_padded, "cu_seqlens_kv_padded": cu_seqlens_padded, "max_seqlen_q": torch.tensor(max(seqlens_padded), dtype=torch.int32), "max_seqlen_kv": torch.tensor(max(seqlens_padded), dtype=torch.int32)}
- Returns:
(output, loss_mask), where output semantics depend on role:
- Encoder-only ranks: Dict[str, Tensor] of encoder outputs
- Language module ranks: language model output (logits or loss)
- No role (all modules colocated): language model output
- Return type:
tuple
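A hedged end-to-end usage sketch for the colocated (no role) case; the encoder key "clip_encoder", the pixel tensor, and all shapes are illustrative assumptions drawn from the parameter examples above.

```python
import torch

B, S = 2, 128
input_ids = torch.randint(0, 32000, (B, S))     # in a real run, must contain one image
                                                # special token per image embedding
position_ids = torch.arange(S).unsqueeze(0).expand(B, S)
attention_mask = torch.ones(B, S, dtype=torch.bool)
labels = input_ids.clone()
loss_mask = torch.ones(B, S)
clip_images = torch.randn(B, 3, 224, 224)       # placeholder pixel values

output, loss_mask = model(
    input_ids=input_ids,
    position_ids=position_ids,
    attention_mask=attention_mask,
    loss_mask=loss_mask,
    labels=labels,
    modality_inputs={"images": {"clip_encoder": {"pixel_values": clip_images}}},
)
# With all modules colocated, `output` is the language model output
# (per-token loss here, since labels were provided).
```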
- _forward_encoders(
- modality_inputs: Optional[Dict[str, Dict[str, Any]]],
- input_tensors: Optional[Dict[str, torch.Tensor]],
- )#
Forward pass for encoder modules on this rank.
- Parameters:
modality_inputs – Raw inputs for each modality (images, audio, etc.)
input_tensors – Hidden states from previous pipeline stages
- Returns:
Dict mapping encoder names to their output tensors
- _forward_language_module(
- input_ids: torch.Tensor,
- position_ids: Optional[torch.Tensor],
- attention_mask: Optional[torch.Tensor],
- labels: Optional[torch.Tensor],
- input_tensors: Optional[Dict[str, torch.Tensor]],
- )#
Forward pass for language module on this rank.
- Parameters:
input_ids – Token IDs
position_ids – Position IDs
attention_mask – Attention mask
labels – Labels for loss computation
input_tensors – Hidden states or embeddings from previous stage
- Returns:
Language model output (hidden states, logits, or loss depending on stage)
- _forward_all_modules(
- input_ids: torch.Tensor,
- position_ids: Optional[torch.Tensor],
- attention_mask: Optional[torch.Tensor],
- loss_mask: Optional[torch.Tensor],
- labels: Optional[torch.Tensor],
- modality_inputs: Optional[Dict[str, Dict[str, Any]]],
- packing_kwargs: Optional[dict] = None,
- )#
Forward pass when all modules are on all ranks (no multi-module PP).
This is the original behavior, preserved for backward compatibility.