bridge.data.vlm_datasets.collate#

Collation utilities for building VLM training batches from conversation examples.

Module Contents#

Functions#

_gather_assistant_text_segments

Extract assistant text segments from the structured conversation example.

create_multiturn_loss_mask_by_search

Tokenizer-agnostic masking via substring search of assistant texts.

phi4_mm_collate_fn

Collate function for Phi-4 MM model audio inputs.

qwen2_5_collate_fn

Collate function for Qwen2.5 VL model.

nemotron_nano_v2_vl_collate_fn

Collate function for Nemotron Nano V2 VL model.

ministral3_collate_fn

Collate function for Ministral 3 VL model.

glm4v_collate_fn

Collate function for GLM-4.5V model.

default_collate_fn

Default collate function for VLM models.

qwen2_audio_collate_fn

Collate function for Qwen2-Audio model.

_expand_image_tokens

Expand image placeholder tokens to the correct count based on grid_thws.

kimi_k25_vl_collate_fn

Collate function for Kimi K2.5 VL processors with pre-expanded image tokens.

Data#

API#

bridge.data.vlm_datasets.collate.MISSING_QWEN_VL_UTILS_MSG#

‘qwen_vl_utils is required for Qwen2.5 VL processing. Please pip install qwen-vl-utils or provide c…’

bridge.data.vlm_datasets.collate._gather_assistant_text_segments(example: dict) list[str]#

Extract assistant text segments from the structured conversation example.

The example schema is expected to be {"conversation": [{"role": ..., "content": [...]}, ...]}, where content is a list of items like {"type": "text" | "image" | ..., "text": "..."}. Returns a list of concatenated text strings, one per assistant turn.
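The extraction under this schema can be sketched as follows. This is a minimal illustration, not the module's actual implementation; in particular, joining the text items of a turn with an empty separator is an assumption.

```python
# Sketch of assistant-text extraction under the schema described above.
# Joining text items with "" is an assumption; the real code may differ.

def gather_assistant_text_segments(example: dict) -> list[str]:
    """Return one concatenated text string per assistant turn."""
    segments = []
    for turn in example.get("conversation", []):
        if turn.get("role") != "assistant":
            continue
        texts = [item.get("text", "") for item in turn.get("content", [])
                 if item.get("type") == "text"]
        segments.append("".join(texts))
    return segments

example = {
    "conversation": [
        {"role": "user", "content": [{"type": "image"},
                                     {"type": "text", "text": "What is this?"}]},
        {"role": "assistant", "content": [{"type": "text", "text": "A cat."}]},
    ]
}
print(gather_assistant_text_segments(example))
# -> ['A cat.']
```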

bridge.data.vlm_datasets.collate.create_multiturn_loss_mask_by_search(...)#

Tokenizer-agnostic masking via substring search of assistant texts.

  • The full conversation has already been tokenized by the processor, yielding input_ids

  • Extract the assistant text strings from the structured example

  • For each assistant text, tokenize it without special tokens and search for the resulting token subsequence sequentially in input_ids

  • On success, unmask the matched span; otherwise leave it masked
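The search-and-unmask steps above can be sketched on plain token-ID lists (the real implementation operates on the processor's tokenizer output):

```python
# Sketch of the sequential substring-search masking described above.
# Operates on plain token-ID lists; illustrative only.

def loss_mask_by_search(input_ids: list[int],
                        assistant_spans: list[list[int]]) -> list[int]:
    """Unmask each assistant token span found by forward sequential search."""
    mask = [0] * len(input_ids)
    cursor = 0
    for span in assistant_spans:
        n = len(span)
        found = -1
        for i in range(cursor, len(input_ids) - n + 1):
            if input_ids[i:i + n] == span:
                found = i
                break
        if found >= 0:
            mask[found:found + n] = [1] * n
            cursor = found + n  # resume the search after the matched span
        # on failure, the span simply stays masked
    return mask

# toy example: tokens 10..13 are one assistant turn
print(loss_mask_by_search([1, 2, 10, 11, 12, 13, 3, 4], [[10, 11, 12, 13]]))
# -> [0, 0, 1, 1, 1, 1, 0, 0]
```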

bridge.data.vlm_datasets.collate.phi4_mm_collate_fn(examples, processor)#

Collate function for Phi-4 MM model audio inputs.

bridge.data.vlm_datasets.collate.qwen2_5_collate_fn(
examples: list,
processor,
) dict[str, torch.Tensor]#

Collate function for Qwen2.5 VL model.

bridge.data.vlm_datasets.collate.nemotron_nano_v2_vl_collate_fn(
examples: list,
processor,
start_of_response_token=None,
) dict[str, torch.Tensor]#

Collate function for Nemotron Nano V2 VL model.

bridge.data.vlm_datasets.collate.ministral3_collate_fn(
examples: list,
processor,
) dict[str, torch.Tensor]#

Collate function for Ministral 3 VL model.

bridge.data.vlm_datasets.collate.glm4v_collate_fn(examples: list, processor) dict[str, torch.Tensor]#

Collate function for GLM-4.5V model.

GLM-4.5V requires mm_token_type_ids to distinguish image (1) and video (2) tokens from text (0) when computing 3D MRoPE positions. The processor returns this field by default (return_mm_token_type_ids=True in the Glm4vProcessor defaults). We wrap all visual tensors, including mm_token_type_ids, in GenericVisualInputs so they flow through vlm_step.py to the model.
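The type-ID convention (0 = text, 1 = image, 2 = video) can be illustrated as follows. The token IDs below are made-up placeholders, not GLM-4.5V's real vocabulary IDs, and this helper is not part of the module.

```python
# Illustrative only: IMAGE_TOKEN_ID / VIDEO_TOKEN_ID are hypothetical
# placeholders, not GLM-4.5V's actual vocabulary IDs.
IMAGE_TOKEN_ID = 1001
VIDEO_TOKEN_ID = 1002

def make_mm_token_type_ids(input_ids: list[int]) -> list[int]:
    """0 = text, 1 = image, 2 = video, per the convention described above."""
    type_ids = []
    for tok in input_ids:
        if tok == IMAGE_TOKEN_ID:
            type_ids.append(1)
        elif tok == VIDEO_TOKEN_ID:
            type_ids.append(2)
        else:
            type_ids.append(0)
    return type_ids

print(make_mm_token_type_ids([7, 1001, 1001, 8, 1002]))
# -> [0, 1, 1, 0, 2]
```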

bridge.data.vlm_datasets.collate.default_collate_fn(
examples: list,
processor,
) dict[str, torch.Tensor]#

Default collate function for VLM models.

bridge.data.vlm_datasets.collate.qwen2_audio_collate_fn(
examples: list,
processor,
) dict[str, torch.Tensor]#

Collate function for Qwen2-Audio model.

Uses HF-compatible label construction:

  • Backward search for assistant text spans (matching HF Trainer convention)

  • No skipped_tokens masking on labels (model learns to predict EOS/im_end)

  • Loss mask derived directly from active label positions
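The backward-search label construction above can be sketched on plain lists. The -100 ignore index follows the HF convention; everything else here is illustrative, not the module's actual code.

```python
IGNORE_INDEX = -100  # HF convention for positions excluded from the loss

def build_labels_backward(input_ids: list[int],
                          assistant_spans: list[list[int]]):
    """Backward search for assistant spans; labels stay IGNORE_INDEX elsewhere."""
    labels = [IGNORE_INDEX] * len(input_ids)
    end = len(input_ids)
    # walk the assistant turns last-to-first, searching from the end
    for span in reversed(assistant_spans):
        n = len(span)
        for start in range(end - n, -1, -1):
            if input_ids[start:start + n] == span:
                labels[start:start + n] = span
                end = start
                break
    # loss mask derived directly from the active label positions
    loss_mask = [0 if lab == IGNORE_INDEX else 1 for lab in labels]
    return labels, loss_mask

# backward search matches the LAST occurrence of [7, 8]
labels, mask = build_labels_backward([5, 7, 8, 9, 7, 8], [[7, 8]])
print(mask)  # -> [0, 0, 0, 0, 1, 1]
```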

bridge.data.vlm_datasets.collate._expand_image_tokens(
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
grid_thws: torch.Tensor,
media_token_id: int,
merge_kernel_size: tuple[int, int] = (2, 2),
) tuple[torch.Tensor, torch.Tensor]#

Expand image placeholder tokens to the correct count based on grid_thws.

For pipeline parallelism (PP), this ensures the sequence length is fixed before the model forward pass, eliminating dynamic sequence expansion inside the model.

Parameters:
  • input_ids – (seq_len,) tensor with one placeholder per image

  • attention_mask – (seq_len,) tensor

  • grid_thws – (num_images, 3) tensor with [t, h, w] for each image

  • media_token_id – Token ID of the image placeholder

  • merge_kernel_size – Vision tower’s patch merge kernel, default (2, 2)

Returns:

  • expanded_input_ids – input IDs with each placeholder expanded to N tokens

  • expanded_attention_mask – attention mask expanded accordingly

Return type:

tuple[torch.Tensor, torch.Tensor]
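The expansion can be sketched on plain lists (the real function operates on torch tensors); the count N = t * (h // mh) * (w // mw) mirrors the merge-kernel division described above:

```python
# Sketch of placeholder expansion; plain lists stand in for tensors.

def expand_image_tokens(input_ids, attention_mask, grid_thws,
                        media_token_id, merge_kernel_size=(2, 2)):
    """Replace each image placeholder with N repeated placeholder tokens."""
    mh, mw = merge_kernel_size
    out_ids, out_mask = [], []
    image_idx = 0
    for tok, m in zip(input_ids, attention_mask):
        if tok == media_token_id:
            t, h, w = grid_thws[image_idx]
            n = t * (h // mh) * (w // mw)  # visual tokens after patch merging
            out_ids.extend([media_token_id] * n)
            out_mask.extend([m] * n)
            image_idx += 1
        else:
            out_ids.append(tok)
            out_mask.append(m)
    return out_ids, out_mask

# one image with grid (t=1, h=4, w=4) -> 1 * (4//2) * (4//2) = 4 visual tokens
ids, mask = expand_image_tokens([1, 99, 2], [1, 1, 1], [(1, 4, 4)], 99)
print(ids)  # -> [1, 99, 99, 99, 99, 2]
```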

bridge.data.vlm_datasets.collate.kimi_k25_vl_collate_fn(
examples: list[dict[str, Any]],
processor,
max_length: int | None = None,
) dict[str, torch.Tensor]#

Collate function for Kimi K2.5 VL processors with pre-expanded image tokens.

For pipeline parallelism, this function:

  1. Processes each sample to get input_ids with 1 placeholder per image

  2. Pre-expands each placeholder to N tokens (N = t*(h//2)*(w//2) from grid_thws)

  3. Pads all sequences to a fixed max_length

This ensures the model forward pass doesn’t change sequence length dynamically.
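Step 3 can be sketched as a simple right-padding helper. The pad token ID and the truncation behavior here are assumptions, not the module's documented choices.

```python
# Sketch of fixed-length padding for static-shape batches (step 3 above).
# pad_token_id=0 and right-truncation are assumptions.

def pad_to_fixed_length(input_ids: list[int], max_length: int,
                        pad_token_id: int = 0):
    """Right-pad (or truncate) to max_length so every batch has static shape."""
    kept = input_ids[:max_length]
    pad = max_length - len(kept)
    padded = kept + [pad_token_id] * pad
    attention_mask = [1] * len(kept) + [0] * pad
    return padded, attention_mask

print(pad_to_fixed_length([5, 6, 7], 5))
# -> ([5, 6, 7, 0, 0], [1, 1, 1, 0, 0])
```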

bridge.data.vlm_datasets.collate.COLLATE_FNS#

None
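COLLATE_FNS is presumably a registry mapping model or processor identifiers to the collate functions above. A minimal sketch of that dispatch pattern follows; the keys and the get_collate_fn helper are illustrative, not the module's actual API.

```python
# Illustrative registry pattern; keys and get_collate_fn are hypothetical.

def default_collate_fn(examples, processor):
    raise NotImplementedError  # stand-in for the real default collate fn

def qwen2_5_collate_fn(examples, processor):
    raise NotImplementedError  # stand-in for the Qwen2.5 VL collate fn

COLLATE_FNS = {
    "default": default_collate_fn,
    "Qwen2_5_VLProcessor": qwen2_5_collate_fn,  # illustrative key
}

def get_collate_fn(processor_type: str):
    """Fall back to the default when no model-specific collate fn is registered."""
    return COLLATE_FNS.get(processor_type, COLLATE_FNS["default"])

print(get_collate_fn("UnknownProcessor") is default_collate_fn)  # -> True
```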