bridge.data.mimo.collate

Collate functions for MIMO datasets.

Module Contents

Functions

mimo_collate_fn

Collate function for MIMO datasets.

API

bridge.data.mimo.collate.mimo_collate_fn(
batch: List[Dict[str, Any]],
modality_names: List[str],
) → Dict[str, Any]

Collate function for MIMO datasets.

Stacks the tensors of the batch items and organizes the modality inputs into a structure suitable for the MIMO model's forward pass.
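To make "stacking" concrete, the sketch below re-implements the described behaviour under simplifying assumptions: all examples share one sequence length (no padding is done), and every example carries every listed modality with identically shaped tensors. `sketch_mimo_collate` is a hypothetical name, not the library's code; the real function may pad, validate, or treat missing modalities differently.

```python
from typing import Any, Dict, List

import torch

# Per-token fields named in the parameter docs.
TEXT_KEYS = ("input_ids", "labels", "attention_mask", "position_ids")


def sketch_mimo_collate(
    batch: List[Dict[str, Any]], modality_names: List[str]
) -> Dict[str, Any]:
    """Hypothetical sketch of the documented stacking; not the library's code.

    Assumes equal-length sequences (no padding) and that every example carries
    every listed modality with identically shaped tensors.
    """
    # Stack each (seq,) per-token tensor into a (batch, seq) tensor.
    collated: Dict[str, Any] = {
        key: torch.stack([example[key] for example in batch]) for key in TEXT_KEYS
    }
    modality_inputs: Dict[str, Dict[str, torch.Tensor]] = {}
    for name in modality_names:
        per_example = [example["modality_inputs"][name] for example in batch]
        # Stack each named field (e.g. "pixel_values") along a new batch dim 0.
        modality_inputs[name] = {
            field: torch.stack([inputs[field] for inputs in per_example])
            for field in per_example[0]
        }
    collated["modality_inputs"] = modality_inputs
    return collated
```

Stacking along a new leading dimension is what turns per-example `(seq,)` tensors into the `(batch, seq)` shapes listed under Returns.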

Parameters:
  • batch –

    List of examples from MimoDataset, each containing:

    • input_ids: Token IDs with placeholder tokens

    • labels: Labels for causal LM training

    • attention_mask: Attention mask

    • position_ids: Position indices

    • modality_inputs: Dict[str, Dict[str, Any]] with preprocessed inputs

  • modality_names – List of modality names to collate.
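The expected layout of a single `batch` item can be checked before collation. The helper below is a hypothetical sketch, not part of this module. It assumes that the four per-token fields share one sequence length and that every name in `modality_names` appears under `modality_inputs`; the documentation above does not state either requirement explicitly.

```python
from typing import Any, Dict, List

import torch

# Per-token fields listed under ``batch`` in the parameter docs.
TEXT_KEYS = ("input_ids", "labels", "attention_mask", "position_ids")


def check_example(example: Dict[str, Any], modality_names: List[str]) -> None:
    """Hypothetical pre-collate check for one MimoDataset-style example."""
    missing = [k for k in (*TEXT_KEYS, "modality_inputs") if k not in example]
    if missing:
        raise KeyError(f"example is missing keys: {missing}")
    # Assumption: the per-token fields all have the same sequence length.
    lengths = {k: example[k].shape[0] for k in TEXT_KEYS}
    if len(set(lengths.values())) != 1:
        raise ValueError(f"sequence lengths differ: {lengths}")
    # Assumption: every requested modality is present in this example.
    for name in modality_names:
        if name not in example["modality_inputs"]:
            raise KeyError(f"modality {name!r} not found in modality_inputs")


example = {
    "input_ids": torch.tensor([32000, 1, 2, 3]),
    "labels": torch.tensor([32000, 1, 2, 3]),
    "attention_mask": torch.ones(4),
    "position_ids": torch.arange(4),
    "modality_inputs": {"vision": {"pixel_values": torch.randn(3, 224, 224)}},
}
check_example(example, modality_names=["vision"])  # passes silently
```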

Returns:

  A dict containing:

  • input_ids: (batch, seq) stacked token IDs

  • labels: (batch, seq) stacked labels

  • attention_mask: (batch, seq) attention mask

  • position_ids: (batch, seq) position indices

  • modality_inputs: Dict[str, Dict[str, Tensor]] with batched modality tensors. Each modality’s tensors are stacked along the batch dimension.

Return type:

Dict[str, Any]
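To see the nested return layout concretely, the hypothetical helper below walks a collated dict and reports each tensor's shape, recursing into `modality_inputs`. Its input is hand-built to match the documented shapes for a batch of 2 sequences of length 4 with one `vision` modality. It is not output captured from `mimo_collate_fn`.

```python
from typing import Any, Dict

import torch


def describe_shapes(collated: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: map every tensor in a nested dict to its shape tuple."""
    shapes: Dict[str, Any] = {}
    for key, value in collated.items():
        if isinstance(value, torch.Tensor):
            shapes[key] = tuple(value.shape)
        elif isinstance(value, dict):
            # modality_inputs holds one dict of tensors per modality; recurse.
            shapes[key] = describe_shapes(value)
    return shapes


# Hand-built dict shaped like the documented return value (batch=2, seq=4).
collated = {
    "input_ids": torch.zeros(2, 4, dtype=torch.long),
    "labels": torch.zeros(2, 4, dtype=torch.long),
    "attention_mask": torch.ones(2, 4),
    "position_ids": torch.arange(4).expand(2, 4),
    "modality_inputs": {"vision": {"pixel_values": torch.zeros(2, 3, 224, 224)}},
}
print(describe_shapes(collated))
# {'input_ids': (2, 4), 'labels': (2, 4), 'attention_mask': (2, 4),
#  'position_ids': (2, 4), 'modality_inputs': {'vision': {'pixel_values': (2, 3, 224, 224)}}}
```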

Example:

    >>> import torch
    >>> from bridge.data.mimo.collate import mimo_collate_fn
    >>> batch = [
    ...     {
    ...         "input_ids": torch.tensor([32000, 1, 2, 3]),
    ...         "labels": torch.tensor([32000, 1, 2, 3]),
    ...         "attention_mask": torch.ones(4),
    ...         "position_ids": torch.arange(4),
    ...         "modality_inputs": {
    ...             "vision": {"pixel_values": torch.randn(3, 224, 224)},
    ...         },
    ...     },
    ...     # ... more examples ...
    ... ]
    >>> collated = mimo_collate_fn(batch, modality_names=["vision"])
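Because `mimo_collate_fn` takes `modality_names` in addition to `batch`, it cannot be passed as-is to a PyTorch `DataLoader`, whose `collate_fn` is called with the list of examples only. One common pattern, sketched below, is to bind the extra argument with `functools.partial`. To keep the snippet runnable without this package, `stand_in_collate` is a hypothetical stand-in with the same `(batch, modality_names)` signature; in real use you would pass `partial(mimo_collate_fn, modality_names=["vision"])` instead.

```python
from functools import partial
from typing import Any, Dict, List

import torch
from torch.utils.data import DataLoader

TEXT_KEYS = ("input_ids", "labels", "attention_mask", "position_ids")


def stand_in_collate(
    batch: List[Dict[str, Any]], modality_names: List[str]
) -> Dict[str, Any]:
    """Hypothetical stand-in sharing mimo_collate_fn's (batch, modality_names) signature."""
    out: Dict[str, Any] = {k: torch.stack([ex[k] for ex in batch]) for k in TEXT_KEYS}
    out["modality_inputs"] = {
        name: {
            field: torch.stack([ex["modality_inputs"][name][field] for ex in batch])
            for field in batch[0]["modality_inputs"][name]
        }
        for name in modality_names
    }
    return out


# Four toy examples; a plain list is a valid map-style dataset for DataLoader.
examples = [
    {
        "input_ids": torch.tensor([32000, 1, 2, 3]),
        "labels": torch.tensor([32000, 1, 2, 3]),
        "attention_mask": torch.ones(4),
        "position_ids": torch.arange(4),
        "modality_inputs": {"vision": {"pixel_values": torch.randn(3, 16, 16)}},
    }
    for _ in range(4)
]

# DataLoader calls collate_fn(list_of_examples), so bind modality_names up front.
loader = DataLoader(
    examples,
    batch_size=2,
    collate_fn=partial(stand_in_collate, modality_names=["vision"]),
)

for collated in loader:
    print(collated["input_ids"].shape)  # torch.Size([2, 4])
    print(collated["modality_inputs"]["vision"]["pixel_values"].shape)  # torch.Size([2, 3, 16, 16])
```

Binding with `partial` keeps the collate function at module level, which also keeps it picklable if the loader later uses worker processes.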