bridge.data.megatron_mimo.collate
Collate functions for MegatronMIMO datasets.
Module Contents
Functions
| megatron_mimo_collate_fn | Collate function for MegatronMIMO datasets. |
API
bridge.data.megatron_mimo.collate.megatron_mimo_collate_fn(batch: List[Dict[str, Any]], modality_names: List[str])
Collate function for MegatronMIMO datasets.
Stacks batch items and organizes modality inputs into a structure suitable for the MegatronMIMO model's forward pass.
- Parameters:
batch – List of examples from MegatronMIMODataset, each containing:
  input_ids: Token IDs with placeholder tokens
  labels: Labels for causal LM training
  loss_mask: Per-token loss mask
  attention_mask: Attention mask
  position_ids: Position indices
  modality_inputs: Dict[str, Dict[str, Any]] mapping each modality name to its preprocessed inputs
modality_names – List of modality names to collate.
- Returns:
Dict containing:
  input_ids: (batch, seq) stacked token IDs
  labels: (batch, seq) stacked labels
  loss_mask: (batch, seq) stacked per-token loss mask
  attention_mask: (batch, seq) attention mask
  position_ids: (batch, seq) position indices
  modality_inputs: Dict[str, Dict[str, Tensor]] with batched modality tensors. Each modality's tensors are stacked along the batch dimension.
- Return type:
Dict
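The stacking described above can be sketched in a few lines of PyTorch. This is a minimal, hypothetical re-implementation (`sketch_collate` is not part of the module), written only to make the documented shapes concrete. It assumes every per-example tensor has the same shape across the batch, as `torch.stack` requires, and that each example carries every modality listed in `modality_names`.

```python
from typing import Any, Dict, List

import torch

# Per-token fields documented for each example; each is assumed to be shape (seq,).
TEXT_KEYS = ("input_ids", "labels", "loss_mask", "attention_mask", "position_ids")


def sketch_collate(
    batch: List[Dict[str, Any]], modality_names: List[str]
) -> Dict[str, Any]:
    """Hypothetical sketch of the documented collate behavior (not the library code)."""
    # Stack each (seq,) text field into a (batch, seq) tensor.
    out: Dict[str, Any] = {
        key: torch.stack([example[key] for example in batch]) for key in TEXT_KEYS
    }
    # For each requested modality, stack every tensor field along a new leading
    # batch dimension, e.g. (3, 224, 224) -> (batch, 3, 224, 224).
    out["modality_inputs"] = {}
    for name in modality_names:
        # Assumption: all examples share the field names of the first example.
        fields = batch[0]["modality_inputs"][name].keys()
        out["modality_inputs"][name] = {
            field: torch.stack([ex["modality_inputs"][name][field] for ex in batch])
            for field in fields
        }
    return out


example = {
    "input_ids": torch.tensor([32000, 1, 2, 3]),
    "labels": torch.tensor([32000, 1, 2, 3]),
    "loss_mask": torch.ones(4),
    "attention_mask": torch.ones(4),
    "position_ids": torch.arange(4),
    "modality_inputs": {"vision": {"pixel_values": torch.randn(3, 224, 224)}},
}
collated = sketch_collate([example, example], modality_names=["vision"])
print(collated["input_ids"].shape)  # torch.Size([2, 4])
print(collated["modality_inputs"]["vision"]["pixel_values"].shape)  # torch.Size([2, 3, 224, 224])
```

Reading field names from the first example is a simplifying assumption of this sketch; the real function may validate inputs or handle missing modalities differently.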
Example
```python
import torch

from bridge.data.megatron_mimo.collate import megatron_mimo_collate_fn

batch = [
    {
        "input_ids": torch.tensor([32000, 1, 2, 3]),
        "labels": torch.tensor([32000, 1, 2, 3]),
        "loss_mask": torch.ones(4),
        "attention_mask": torch.ones(4),
        "position_ids": torch.arange(4),
        # Preprocessed inputs, keyed by modality name
        "modality_inputs": {
            "vision": {"pixel_values": torch.randn(3, 224, 224)},
        },
    },
    # ... more examples ...
]
collated = megatron_mimo_collate_fn(batch, modality_names=["vision"])
```
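The collate function is typically passed to a PyTorch `DataLoader`. Because it takes a second argument, `modality_names`, one common pattern is to bind that argument with `functools.partial` so the loader can call `collate_fn(batch)` directly. The sketch below assumes that pattern; to stay self-contained it uses a hypothetical stand-in, `toy_collate`, with the same call signature, and a plain list as the dataset.

```python
import functools
from typing import Any, Dict, List

import torch
from torch.utils.data import DataLoader


def toy_collate(batch: List[Dict[str, Any]], modality_names: List[str]) -> Dict[str, Any]:
    # Hypothetical stand-in with the same signature as megatron_mimo_collate_fn;
    # it only stacks input_ids and each modality's pixel_values.
    return {
        "input_ids": torch.stack([ex["input_ids"] for ex in batch]),
        "modality_inputs": {
            name: {
                "pixel_values": torch.stack(
                    [ex["modality_inputs"][name]["pixel_values"] for ex in batch]
                )
            }
            for name in modality_names
        },
    }


# A list is a valid map-style dataset: DataLoader only needs __len__ and __getitem__.
dataset = [
    {
        "input_ids": torch.tensor([32000, 1, 2, 3]),
        "modality_inputs": {"vision": {"pixel_values": torch.randn(3, 32, 32)}},
    }
    for _ in range(6)
]

# Bind modality_names so DataLoader can call collate_fn(batch) with one argument.
# In real code, pass megatron_mimo_collate_fn here instead of toy_collate.
loader = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=functools.partial(toy_collate, modality_names=["vision"]),
)
batches = list(loader)
print([b["input_ids"].shape[0] for b in batches])  # [4, 2]
```

A `partial` over a module-level function stays picklable, which matters when the loader uses `num_workers > 0`.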