bridge.data.mimo.collate
Collate functions for MIMO datasets.
Module Contents
Functions
mimo_collate_fn | Collate function for MIMO datasets.
API
bridge.data.mimo.collate.mimo_collate_fn(
    batch: List[Dict[str, Any]],
    modality_names: List[str],
)
Collate function for MIMO datasets.
Stacks the items in a batch and organizes their modality inputs into the structure expected by the MIMO model's forward pass.
- Parameters:
batch – List of examples from MimoDataset, each a dict containing:
  input_ids: token IDs, including placeholder tokens
  labels: labels for causal LM training
  attention_mask: attention mask
  position_ids: position indices
  modality_inputs: Dict[str, Dict[str, Any]] of preprocessed inputs, keyed by modality name
modality_names – List of modality names to collate.
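Because the text fields are stacked into `(batch, seq)` tensors, every example in a batch presumably needs sequences of one shared length and must carry each requested modality. The helper below is a minimal sketch of a pre-collation check under those assumptions. `validate_mimo_batch` is a hypothetical name, not part of `bridge.data.mimo.collate`, and the equal-length requirement is inferred from the stacking behavior described above rather than stated as a documented contract.

```python
from typing import Any, Dict, List

# Per-example text fields listed in the parameter description.
TEXT_KEYS = ("input_ids", "labels", "attention_mask", "position_ids")


def validate_mimo_batch(batch: List[Dict[str, Any]], modality_names: List[str]) -> None:
    """Hypothetical helper: raise ValueError if `batch` cannot be stacked as documented.

    Assumption: every text field is 1-D and all of them share a single length
    across the whole batch (needed to build (batch, seq) tensors).
    len() works for Python lists and for 1-D torch tensors alike.
    """
    if not batch:
        raise ValueError("batch is empty")
    expected_len = None
    for i, example in enumerate(batch):
        # Check presence and length of each text field.
        for key in TEXT_KEYS:
            if key not in example:
                raise ValueError(f"example {i} is missing {key!r}")
            length = len(example[key])
            if expected_len is None:
                expected_len = length
            elif length != expected_len:
                raise ValueError(
                    f"example {i} field {key!r} has length {length}, expected {expected_len}"
                )
        # Check that every requested modality has preprocessed inputs.
        modality_inputs = example.get("modality_inputs", {})
        for name in modality_names:
            if name not in modality_inputs:
                raise ValueError(f"example {i} has no inputs for modality {name!r}")
```

Running such a check before calling `mimo_collate_fn(batch, modality_names=[...])` would turn an opaque shape error during stacking into a message naming the offending example and field.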
- Returns:
A dict containing:
  input_ids: (batch, seq) stacked token IDs
  labels: (batch, seq) stacked labels
  attention_mask: (batch, seq) attention mask
  position_ids: (batch, seq) position indices
  modality_inputs: Dict[str, Dict[str, Tensor]] of batched modality tensors; each modality's tensors are stacked along the batch dimension.
- Return type:
Dict
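The following is a minimal sketch, written against the documented input and output layout rather than the library source, of how a collate function could produce this structure with `torch.stack`. `sketch_mimo_collate` is a hypothetical name. It assumes that all text sequences are already padded to one length and that every example carries every requested modality with identically shaped tensors; neither assumption is stated in this API reference.

```python
from typing import Any, Dict, List

import torch

# Per-example text fields; each is assumed to be a 1-D tensor of shape (seq,).
TEXT_KEYS = ("input_ids", "labels", "attention_mask", "position_ids")


def sketch_mimo_collate(
    batch: List[Dict[str, Any]], modality_names: List[str]
) -> Dict[str, Any]:
    """Hypothetical sketch of the documented output layout (not the library code).

    Assumes all text fields share one sequence length across the batch and that
    every example provides each requested modality with identically shaped tensors.
    """
    # Stack each (seq,) text field into a (batch, seq) tensor.
    collated: Dict[str, Any] = {
        key: torch.stack([example[key] for example in batch]) for key in TEXT_KEYS
    }
    # For each modality, stack every tensor key along a new leading batch dimension.
    modality_inputs: Dict[str, Dict[str, torch.Tensor]] = {}
    for name in modality_names:
        per_example = [example["modality_inputs"][name] for example in batch]
        modality_inputs[name] = {
            key: torch.stack([inputs[key] for inputs in per_example])
            for key in per_example[0]
        }
    collated["modality_inputs"] = modality_inputs
    return collated


if __name__ == "__main__":
    example = {
        "input_ids": torch.tensor([32000, 1, 2, 3]),
        "labels": torch.tensor([32000, 1, 2, 3]),
        "attention_mask": torch.ones(4),
        "position_ids": torch.arange(4),
        "modality_inputs": {"vision": {"pixel_values": torch.randn(3, 224, 224)}},
    }
    out = sketch_mimo_collate([example, example], modality_names=["vision"])
    print(out["input_ids"].shape)  # torch.Size([2, 4])
    print(out["modality_inputs"]["vision"]["pixel_values"].shape)  # torch.Size([2, 3, 224, 224])
```

Stacking adds a new leading dimension, so a per-example `pixel_values` of shape `(3, 224, 224)` becomes `(batch, 3, 224, 224)`. Examples whose sequences differ in length would have to be padded beforehand, because `torch.stack` requires all inputs to have the same shape.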
Example
>>> import torch
>>> from bridge.data.mimo.collate import mimo_collate_fn
>>> batch = [
...     {
...         "input_ids": torch.tensor([32000, 1, 2, 3]),
...         "labels": torch.tensor([32000, 1, 2, 3]),
...         "attention_mask": torch.ones(4),
...         "position_ids": torch.arange(4),
...         "modality_inputs": {
...             "vision": {"pixel_values": torch.randn(3, 224, 224)},
...         },
...     },
...     # ... more examples ...
... ]
>>> collated = mimo_collate_fn(batch, modality_names=["vision"])
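In a training loop, a collate function like this is usually handed to `torch.utils.data.DataLoader` through its `collate_fn` argument. `DataLoader` calls `collate_fn` with only the list of samples, so `modality_names` has to be bound beforehand; `functools.partial` is one common way to do that. The sketch below is self-contained and uses hypothetical stand-ins (`ToyMimoDataset`, `toy_collate`) for `MimoDataset` and `mimo_collate_fn`; with the real module one would presumably pass `partial(mimo_collate_fn, modality_names=["vision"])` instead.

```python
from functools import partial
from typing import Any, Dict, List

import torch
from torch.utils.data import DataLoader, Dataset


class ToyMimoDataset(Dataset):
    """Hypothetical stand-in for MimoDataset that yields fixed-length examples."""

    def __init__(self, num_examples: int, seq_len: int = 4) -> None:
        self.num_examples = num_examples
        self.seq_len = seq_len

    def __len__(self) -> int:
        return self.num_examples

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        return {
            "input_ids": torch.arange(self.seq_len),
            "modality_inputs": {"vision": {"pixel_values": torch.zeros(3, 32, 32)}},
        }


def toy_collate(batch: List[Dict[str, Any]], modality_names: List[str]) -> Dict[str, Any]:
    """Hypothetical stand-in sharing mimo_collate_fn's (batch, modality_names) signature.

    Only input_ids are stacked here, to keep the focus on how the function is wired in.
    """
    return {
        "input_ids": torch.stack([example["input_ids"] for example in batch]),
        "modality_names": list(modality_names),
    }


# DataLoader invokes collate_fn(list_of_samples) with one argument,
# so modality_names is bound ahead of time with functools.partial.
loader = DataLoader(
    ToyMimoDataset(num_examples=5),
    batch_size=2,
    collate_fn=partial(toy_collate, modality_names=["vision"]),
)

for step, collated in enumerate(loader):
    # With 5 examples and batch_size=2, the last batch holds a single example
    # because drop_last defaults to False.
    print(step, tuple(collated["input_ids"].shape), collated["modality_names"])
```

Binding the argument with `partial` keeps the collate function itself free of global configuration; a small wrapper function or lambda would work equally well, but `partial` objects remain picklable, which matters if the loader is later run with `num_workers > 0`.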