bridge.data.mimo.dataset#
Dataset wrapper for MIMO multi-encoder models.
Module Contents#
Classes#
MimoDataset: Dataset for MIMO models with per-modality preprocessing.
API#
- class bridge.data.mimo.dataset.MimoDataset(
- examples: Any,
- processors: Dict[str, Any],
- tokenizer: Any,
- seq_length: int,
- special_token_ids: Dict[str, int],
- encoder_seq_lengths: Dict[str, int],
- modality_columns: Dict[str, str],
- text_column: str = 'text',
- max_samples: Optional[int] = None,
- preprocess_fn: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
- )#
Bases:
torch.utils.data.Dataset

Dataset for MIMO models with per-modality preprocessing.
Wraps a data source (HuggingFace dataset or list of examples) and applies per-modality processors to convert raw inputs (images, audio, etc.) into preprocessed tensors (pixel_values, input_features) that encoders consume during the forward pass.
- Parameters:
examples – Data source: either a HuggingFace Dataset or a list of dicts.
processors – Dict mapping modality name to HF processor, e.g., {"vision": AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")}.
tokenizer – Tokenizer for text processing.
seq_length – Total sequence length for the model (encoder placeholders + text tokens). Must be greater than sum(encoder_seq_lengths.values()) to leave room for text. Text is truncated to fit: max_text_tokens = seq_length - total_encoder_tokens.
special_token_ids – Per-encoder placeholder token IDs, e.g., {"vision": 32000}.
encoder_seq_lengths – Per-encoder output sequence lengths, e.g., {"vision": 577}. Determines how many placeholder tokens to insert for each modality. For CLIP ViT-L/14 at 336x336 resolution this is 577 (576 patches + 1 CLS token); at 224x224 it would be 257 (256 patches + 1 CLS token).
modality_columns – Dict mapping modality name to column name in the dataset, e.g., {"vision": "image", "audio": "audio_path"}.
text_column – Column name for text/conversation data. Default: "text".
max_samples – Optional limit on dataset size for debugging.
preprocess_fn – Optional function to preprocess each example before modality processing.
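The seq_length constraint above can be checked with simple arithmetic. A sketch, using the 577-token vision example (577 = (336 // 14) ** 2 patches + 1 CLS token for CLIP ViT-L/14 at 336px):

```python
# Sketch: the text-token budget implied by seq_length and encoder_seq_lengths.
seq_length = 2048
encoder_seq_lengths = {"vision": 577}  # (336 // 14) ** 2 patches + 1 CLS token

total_encoder_tokens = sum(encoder_seq_lengths.values())
assert seq_length > total_encoder_tokens, "no room left for text tokens"

max_text_tokens = seq_length - total_encoder_tokens
print(max_text_tokens)  # 1471
```

Adding more modalities shrinks the text budget accordingly, since every encoder's placeholder tokens count against the same seq_length.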
Example#

```python
from datasets import load_dataset
from transformers import AutoProcessor, AutoTokenizer

# Using a HuggingFace Dataset
hf_ds = load_dataset("liuhaotian/LLaVA-Instruct-150K", split="train")
processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

dataset = MimoDataset(
    examples=hf_ds,
    processors={"vision": processor},
    tokenizer=tokenizer,
    seq_length=2048,
    special_token_ids={"vision": 32000},
    encoder_seq_lengths={"vision": 577},  # CLIP ViT-L/14 output tokens
    modality_columns={"vision": "image"},
)

# Or using a simple list of dicts for testing/prototyping
examples = [
    {"text": "Describe this image.", "image": "img1.jpg"},
    {"text": "What do you see?", "image": "img2.jpg"},
]
dataset = MimoDataset(
    examples=examples,
    processors={"vision": processor},
    tokenizer=tokenizer,
    seq_length=2048,
    special_token_ids={"vision": 32000},
    encoder_seq_lengths={"vision": 577},
    modality_columns={"vision": "image"},
)
```
Initialization
- __len__() → int#
- __getitem__(idx: int) → Dict[str, Any]#
Get a single example with preprocessed modality inputs.
- Returns:
Dict containing:
input_ids: Tokenized text with placeholder tokens inserted.
labels: Same as input_ids (for causal LM training).
attention_mask: Attention mask.
position_ids: Position indices.
modality_inputs: Dict[str, Any] with preprocessed inputs per modality, e.g., {"vision": {"pixel_values": tensor, ...}}.
- Return type:
Dict[str, Any]
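For reference, a minimal mock of this return structure (plain lists stand in for the tensors the real class returns; all values are illustrative):

```python
# Mock of the documented __getitem__ return structure; token IDs are
# illustrative stand-ins, and plain lists replace the real tensors.
seq_length = 6
input_ids = [32000, 32000, 32000, 5, 6, 7]  # placeholder tokens, then text tokens
item = {
    "input_ids": input_ids,
    "labels": list(input_ids),            # same as input_ids for causal LM training
    "attention_mask": [1] * seq_length,
    "position_ids": list(range(seq_length)),
    "modality_inputs": {"vision": {"pixel_values": "<preprocessed image tensor>"}},
}
assert item["labels"] == item["input_ids"]
```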
- _tokenize_with_placeholders(
- text: str,
- modality_inputs: Dict[str, Dict[str, Any]],
- )#
Tokenize text and insert placeholder tokens for each modality.
For each modality present, inserts N placeholder tokens at the beginning of the sequence, where N = encoder_seq_lengths[modality_name]. This matches the number of embeddings the encoder will produce, enabling 1:1 replacement during the model forward pass.
- Parameters:
text – Raw text to tokenize.
modality_inputs – Dict of preprocessed modality inputs.
- Returns:
Token IDs tensor with placeholder tokens inserted.
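The insertion scheme described above can be sketched in plain Python. This is an illustrative reimplementation, not the actual method (the real one returns a tensor; the function and variable names here are assumptions):

```python
def tokenize_with_placeholders_sketch(text_ids, modality_inputs,
                                      special_token_ids, encoder_seq_lengths,
                                      seq_length):
    """Prepend N placeholder tokens per present modality, then truncate text to fit."""
    placeholder_ids = []
    for name in modality_inputs:
        # N copies of the modality's placeholder token,
        # one per embedding the encoder will produce
        placeholder_ids += [special_token_ids[name]] * encoder_seq_lengths[name]
    max_text_tokens = seq_length - len(placeholder_ids)
    return placeholder_ids + text_ids[:max_text_tokens]

# Toy example: a 3-token "vision" encoder with placeholder id 32000
ids = tokenize_with_placeholders_sketch(
    text_ids=[5, 6, 7, 8],
    modality_inputs={"vision": {"pixel_values": "..."}},
    special_token_ids={"vision": 32000},
    encoder_seq_lengths={"vision": 3},
    seq_length=6,
)
print(ids)  # [32000, 32000, 32000, 5, 6, 7]
```

During the model forward pass, the three placeholder positions are replaced 1:1 by the three encoder output embeddings, which is why the placeholder count must match encoder_seq_lengths exactly.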