bridge.data.mimo.dataset#

Dataset wrapper for MIMO multi-encoder models.

Module Contents#

Classes#

MimoDataset

Dataset for MIMO models with per-modality preprocessing.

API#

class bridge.data.mimo.dataset.MimoDataset(
examples: Any,
processors: Dict[str, Any],
tokenizer: Any,
seq_length: int,
special_token_ids: Dict[str, int],
encoder_seq_lengths: Dict[str, int],
modality_columns: Dict[str, str],
text_column: str = 'text',
max_samples: Optional[int] = None,
preprocess_fn: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
)#

Bases: torch.utils.data.Dataset

Dataset for MIMO models with per-modality preprocessing.

Wraps a data source (HuggingFace dataset or list of examples) and applies per-modality processors to convert raw inputs (images, audio, etc.) into preprocessed tensors (pixel_values, input_features) that encoders consume during the forward pass.

Parameters:
  • examples – Data source: either a HuggingFace Dataset or a list of dicts.

  • processors – Dict mapping modality name to HF processor, e.g., {"vision": AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")}.

  • tokenizer – Tokenizer for text processing.

  • seq_length – Total sequence length for the model (encoder placeholders + text tokens). Must be greater than sum(encoder_seq_lengths.values()) to leave room for text. Text is truncated to fit: max_text_tokens = seq_length - total_encoder_tokens.

  • special_token_ids – Per-encoder placeholder token IDs, e.g., {"vision": 32000}.

  • encoder_seq_lengths – Per-encoder output sequence lengths, e.g., {"vision": 577}. Determines how many placeholder tokens to insert for each modality. For CLIP ViT-L/14 at 336x336 resolution, this is 577 (576 patches + 1 CLS token).

  • modality_columns – Dict mapping modality name to column name in the dataset, e.g., {"vision": "image", "audio": "audio_path"}.

  • text_column – Column name for text/conversation data. Default: "text".

  • max_samples – Optional limit on dataset size for debugging.

  • preprocess_fn – Optional function to preprocess each example before modality processing.
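To make the seq_length budget described above concrete, a short arithmetic sketch (the encoder lengths below are illustrative assumptions, not MimoDataset defaults):

```python
# Illustrative values only; these are not MimoDataset defaults.
seq_length = 2048
encoder_seq_lengths = {"vision": 577, "audio": 750}

# Placeholder tokens reserved for all encoders combined.
total_encoder_tokens = sum(encoder_seq_lengths.values())

# Remaining budget for text; the dataset truncates text to fit it.
max_text_tokens = seq_length - total_encoder_tokens
assert max_text_tokens > 0, "seq_length must exceed total encoder tokens"
print(max_text_tokens)  # 721
```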

.. rubric:: Example

```python
from datasets import load_dataset
from transformers import AutoProcessor, AutoTokenizer

# Using a HuggingFace Dataset
hf_ds = load_dataset("liuhaotian/LLaVA-Instruct-150K", split="train")
processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

dataset = MimoDataset(
    examples=hf_ds,
    processors={"vision": processor},
    tokenizer=tokenizer,
    seq_length=2048,
    special_token_ids={"vision": 32000},
    encoder_seq_lengths={"vision": 577},  # 576 patches + 1 CLS
    modality_columns={"vision": "image"},
)

# Or use a simple list of dicts for testing/prototyping
examples = [
    {"text": "Describe this image.", "image": "img1.jpg"},
    {"text": "What do you see?", "image": "img2.jpg"},
]
dataset = MimoDataset(
    examples=examples,
    processors={"vision": processor},
    tokenizer=tokenizer,
    seq_length=2048,
    special_token_ids={"vision": 32000},
    encoder_seq_lengths={"vision": 577},
    modality_columns={"vision": "image"},
)
```

Initialization

__len__() → int#
__getitem__(idx: int) → Dict[str, Any]#

Get a single example with preprocessed modality inputs.

Returns:

Dict containing:

  • input_ids: Tokenized text with placeholder tokens inserted

  • labels: Same as input_ids (for causal LM training)

  • attention_mask: Attention mask

  • position_ids: Position indices

  • modality_inputs: Per-modality preprocessed inputs, e.g., {"vision": {"pixel_values": tensor, ...}}

Return type:

Dict[str, Any]
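A sanity check one might run on a returned item is to confirm that each modality contributes exactly its encoder output length in placeholder tokens. The sample below is a hand-built stand-in with fabricated token IDs and lengths, not real MimoDataset output:

```python
# Hand-built stand-in for a returned item; IDs/lengths are fabricated.
special_token_ids = {"vision": 32000}
encoder_seq_lengths = {"vision": 3}

sample = {
    "input_ids": [32000, 32000, 32000, 101, 102, 103],
    "modality_inputs": {"vision": {"pixel_values": object()}},
}

# Each modality present should contribute exactly its encoder output
# length in placeholder tokens, enabling 1:1 embedding replacement.
for name in sample["modality_inputs"]:
    count = sample["input_ids"].count(special_token_ids[name])
    assert count == encoder_seq_lengths[name]
print("placeholder counts match")
```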

_tokenize_with_placeholders(
text: str,
modality_inputs: Dict[str, Dict[str, Any]],
) → torch.Tensor#

Tokenize text and insert placeholder tokens for each modality.

For each modality present, inserts N placeholder tokens at the beginning of the sequence, where N = encoder_seq_lengths[modality_name]. This matches the number of embeddings the encoder will produce, enabling 1:1 replacement during the model forward pass.

Parameters:
  • text – Raw text to tokenize.

  • modality_inputs – Dict of preprocessed modality inputs.

Returns:

Token IDs tensor with placeholder tokens inserted.
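The insertion scheme can be sketched as follows. This is a simplified stand-in for _tokenize_with_placeholders that works on plain lists rather than tensors; the token IDs and lengths are assumptions for illustration, not the actual implementation:

```python
from typing import Any, Dict, List

# Assumed configuration; a real vision encoder length would be
# e.g. 577 for CLIP ViT-L/14 at 336x336.
special_token_ids = {"vision": 32000}
encoder_seq_lengths = {"vision": 4}

def insert_placeholders(
    text_ids: List[int],
    modality_inputs: Dict[str, Dict[str, Any]],
) -> List[int]:
    """Prepend N placeholder tokens per modality, where N is the
    encoder's output length, so embeddings can replace them 1:1."""
    ids: List[int] = []
    for name in modality_inputs:
        ids.extend([special_token_ids[name]] * encoder_seq_lengths[name])
    return ids + text_ids

token_ids = insert_placeholders([101, 102], {"vision": {"pixel_values": None}})
print(token_ids)  # [32000, 32000, 32000, 32000, 101, 102]
```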