nemo_automodel.components.datasets.vlm.collate_fns
Module Contents
Functions
Data
API
Scan input_ids for assistant_marker … stop_id and build labels.
For each sequence in the batch, every token between the end of an
assistant marker and the corresponding stop_id (inclusive) is copied
into the labels tensor; all other positions are set to -100.
Parameters
input_ids_batch:
Shape (B, L).
assistant_marker:
Token-id sequence that opens an assistant turn (e.g.
[<|im_start|>, assistant_id, newline_id] for Qwen or
[<start_of_turn>, model_id, newline_id] for Gemma4).
stop_id:
Single token id that closes a turn (e.g. <|im_end|> or
<end_of_turn>).
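The scan described above can be sketched with plain Python lists standing in for the `(B, L)` tensor. The function name `build_labels_from_markers` and the list-based types are assumptions made for illustration; this is not the library's implementation.

```python
IGNORE_INDEX = -100  # positions excluded from the loss


def build_labels_from_markers(input_ids_batch, assistant_marker, stop_id):
    """Label every token after an assistant marker, up to and including stop_id."""
    m = len(assistant_marker)
    if m == 0:
        raise ValueError("assistant_marker must not be empty")
    labels_batch = []
    for ids in input_ids_batch:
        labels = [IGNORE_INDEX] * len(ids)
        i = 0
        while i + m <= len(ids):
            if ids[i:i + m] == assistant_marker:
                j = i + m
                # Copy response tokens until the stop token (inclusive).
                while j < len(ids):
                    labels[j] = ids[j]
                    j += 1
                    if ids[j - 1] == stop_id:
                        break
                i = j
            else:
                i += 1
        labels_batch.append(labels)
    return labels_batch


# Toy ids: marker = [1, 2, 3], stop token = 9.
batch = [[1, 2, 3, 10, 11, 9, 1, 5, 9]]
labels = build_labels_from_markers(batch, [1, 2, 3], 9)
```

In the toy batch only the span `10, 11, 9` is labelled; the marker itself and the trailing non-assistant tokens stay at -100.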
Count images and videos per sample from conversation structure.
Returns two lists of length len(conversations) giving the number of
image and video items in each conversation, respectively.
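A minimal sketch of the counting step follows. It assumes each conversation is a list of message dicts whose `content` is either a string or a list of items carrying a `type` key; that schema and the name `count_media_per_sample` are assumptions, not confirmed by the source.

```python
def count_media_per_sample(conversations):
    """Return (image_counts, video_counts), one entry per conversation."""
    image_counts, video_counts = [], []
    for conversation in conversations:
        n_images = n_videos = 0
        for message in conversation:
            content = message.get("content")
            if not isinstance(content, list):
                continue  # plain-string content carries no media items
            for item in content:
                if not isinstance(item, dict):
                    continue
                if item.get("type") == "image":
                    n_images += 1
                elif item.get("type") == "video":
                    n_videos += 1
        image_counts.append(n_images)
        video_counts.append(n_videos)
    return image_counts, video_counts


conversations = [
    [
        {"role": "user", "content": [
            {"type": "image", "image": "a.png"},
            {"type": "image", "image": "b.png"},
            {"type": "text", "text": "Compare these."},
        ]},
        {"role": "assistant", "content": "They differ in colour."},
    ],
    [
        {"role": "user", "content": [{"type": "video", "video": "clip.mp4"}]},
    ],
]
image_counts, video_counts = count_media_per_sample(conversations)
```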
Decode a single token id across tokenizer implementations.
Some tokenizers accept an int token id, while others require a sequence of
ids (e.g., List[int]). We try the common forms in order.
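The "try the common forms in order" idea might look like the sketch below. The two toy tokenizer classes, the function name `decode_single_token`, and the choice of exceptions to catch are all assumptions for illustration.

```python
def decode_single_token(tokenizer, token_id):
    """Try decode(int) first, then decode([int])."""
    last_error = None
    for candidate in (token_id, [token_id]):
        try:
            return tokenizer.decode(candidate)
        except (TypeError, ValueError) as err:
            last_error = err
    raise TypeError(f"could not decode token id {token_id!r}") from last_error


class IntTokenizer:
    """Toy tokenizer that only accepts a bare int."""

    def decode(self, token_id):
        if not isinstance(token_id, int):
            raise TypeError("expects an int")
        return f"<{token_id}>"


class ListTokenizer:
    """Toy tokenizer that only accepts a sequence of ids."""

    def decode(self, token_ids):
        if isinstance(token_ids, int):
            raise TypeError("expects a sequence of ids")
        return "".join(f"<{t}>" for t in token_ids)


from_int = decode_single_token(IntTokenizer(), 7)
from_list = decode_single_token(ListTokenizer(), 7)
```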
Derive the assistant-turn start marker and end-of-turn token id from the tokenizer’s own chat template.
The function applies a minimal dummy conversation that contains a known sentinel string as the assistant reply, then locates the sentinel in the resulting token sequence. Everything between the end of the user turn and the start of the sentinel becomes the assistant marker; the first token after the sentinel becomes the end-of-turn id.
This approach is robust to BPE context-sensitivity and works for any model
whose template wraps assistant turns with fixed token sequences — e.g.
Gemma4’s <start_of_turn>model\n … <end_of_turn>.
.. note::
apply_chat_template may return a :class:~transformers.BatchEncoding
(a UserDict subclass, not a plain :class:dict), so
isinstance(result, dict) is False. We access result["input_ids"]
directly, which works for both BatchEncoding and plain dict / list.
Returns
tuple[list[int], int]
(assistant_marker, end_of_turn_id)
Raises
ValueError If the sentinel cannot be located in the template output or if the resulting marker is empty.
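The sentinel search can be illustrated on an already-rendered token sequence. The sketch assumes the end of the user turn is the last occurrence of the end-of-turn id before the sentinel; the source does not say how that boundary is found, so treat this rule, the toy token ids, and the name `derive_turn_markers_from_ids` as hypothetical.

```python
def find_sublist(haystack, needle):
    """Return the first index where needle occurs in haystack, or -1."""
    for i in range(len(haystack) - len(needle) + 1):
        if haystack[i:i + len(needle)] == needle:
            return i
    return -1


def derive_turn_markers_from_ids(ids, sentinel_ids):
    """Return (assistant_marker, end_of_turn_id) from a rendered dummy chat."""
    start = find_sublist(ids, sentinel_ids)
    if start < 0:
        raise ValueError("sentinel not found in template output")
    end = start + len(sentinel_ids)
    if end >= len(ids):
        raise ValueError("no token follows the sentinel")
    end_of_turn_id = ids[end]
    prefix = ids[:start]
    if end_of_turn_id not in prefix:
        raise ValueError("could not locate the end of the user turn")
    # Assumption: the user turn closes at the last end-of-turn token
    # that appears before the sentinel.
    user_end = len(prefix) - 1 - prefix[::-1].index(end_of_turn_id)
    marker = prefix[user_end + 1:]
    if not marker:
        raise ValueError("derived assistant marker is empty")
    return marker, end_of_turn_id


# Toy rendering: START=1, END=2, NEWLINE=3, user=20, assistant=21.
rendered = [1, 20, 3, 50, 2, 1, 21, 3, 9001, 9002, 2]
marker, eot = derive_turn_markers_from_ids(rendered, [9001, 9002])
```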
Drop conversations whose estimated token count exceeds max_length.
Returns (filtered_conversations, kept_indices) where kept_indices
are the original positions that survived filtering. Raises ValueError
when every sample in the batch is dropped (caught by robust_collate
which re-samples).
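A minimal sketch of the filtering contract, assuming a caller-supplied `estimate_tokens` callable (hypothetical; the source does not describe how the estimate is computed at this point):

```python
def drop_overlong(conversations, estimate_tokens, max_length):
    """Return (kept_conversations, kept_indices); raise if nothing survives."""
    kept, kept_indices = [], []
    for idx, conversation in enumerate(conversations):
        if estimate_tokens(conversation) <= max_length:
            kept.append(conversation)
            kept_indices.append(idx)
    if not kept:
        # A retrying wrapper can catch this and draw a fresh batch.
        raise ValueError("all samples in the batch exceed max_length")
    return kept, kept_indices


# Toy estimate: one token per whitespace-separated word.
word_count = lambda text: len(text.split())
kept, kept_indices = drop_overlong(
    ["short one", "this sample is far too long to keep", "ok"],
    word_count,
    max_length=3,
)
```

Returning the surviving indices lets the caller realign any per-sample metadata with the filtered batch.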
Convert any PIL images in conversations to RGB to handle RGBA/grayscale inputs.
Estimate expanded media token count from image/video dimensions.
Returns total extra tokens beyond the single-placeholder-per-media count that tokenization produces. Only images with known dimensions (PIL Image objects or loadable paths) are estimated; unknown media items contribute 0 extra tokens (the placeholder is still counted in the base tokenization).
Expand image placeholder tokens to the correct patch counts based on grid_thws.
For PP, this ensures the sequence length is fixed BEFORE the model forward pass, eliminating dynamic sequence expansion inside the model.
Supports both single-image and multi-image samples. Each placeholder token is
expanded to (h // merge_h) * (w // merge_w) tokens using the corresponding
entry in grid_thws. The number of placeholders must match grid_thws.shape[0].
Parameters:
(seq_len,) tensor containing one media_token_id per image.
(seq_len,) tensor aligned with input_ids.
(N, 3) tensor with [t, h, w] for each of the N images.
Token ID used as the image placeholder.
Vision tower’s patch merge kernel, default (2, 2).
Returns: torch.Tensor
Input IDs with each placeholder expanded to its patch count.
Raises:
ValueError: When the number of placeholders does not match grid_thws.shape[0].
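The expansion rule can be sketched with plain lists in place of tensors. The parameter names and the choice to return the expanded attention mask alongside the ids are assumptions; only the formula `(h // merge_h) * (w // merge_w)` and the mismatch check come from the description above.

```python
def expand_image_tokens(input_ids, attention_mask, grid_thws,
                        media_token_id, merge_kernel_size=(2, 2)):
    """Expand each placeholder to (h // merge_h) * (w // merge_w) tokens."""
    merge_h, merge_w = merge_kernel_size
    n_placeholders = sum(1 for tok in input_ids if tok == media_token_id)
    if n_placeholders != len(grid_thws):
        raise ValueError(
            f"{n_placeholders} placeholders but {len(grid_thws)} grid entries"
        )
    out_ids, out_mask = [], []
    image_idx = 0
    for tok, mask in zip(input_ids, attention_mask):
        if tok == media_token_id:
            _, h, w = grid_thws[image_idx]
            n_tokens = (h // merge_h) * (w // merge_w)
            out_ids.extend([media_token_id] * n_tokens)
            out_mask.extend([mask] * n_tokens)  # keep the mask aligned
            image_idx += 1
        else:
            out_ids.append(tok)
            out_mask.append(mask)
    return out_ids, out_mask


# Two images: a 4x4 grid gives 4 tokens, a 2x6 grid gives 3 tokens.
ids, mask = expand_image_tokens(
    [5, 99, 6, 99], [1, 1, 1, 1], [[1, 4, 4], [1, 2, 6]], media_token_id=99
)
```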
Extract image processing config from processor for token estimation.
Extract image and video inputs from conversation content elements.
Images are returned as-is (PIL Image or path string) for the image processor.
Videos are returned as path strings so the video processor can read and sample
them using its own fps / max_frames configuration.
Returns:
(images list | None, videos list | None)
Return the token-id sequence that introduces an assistant turn.
For Qwen-family models the marker is [<|im_start|>, assistant, \n].
Returns None when the tokenizer does not use this convention.
Return the token id of the turn-ending marker (<|im_end|>).
Insert <|channel>thought\n<channel|> tokens after every <|turn>model\n marker.
Modifies input_ids, attention_mask, and mm_token_type_ids
(if present). Additionally, any other 2-D integer tensor whose second
dimension matches input_ids is extended with zeros so that sequence
lengths stay consistent (this is mainly future-proofing).
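For a single sequence, the injection step might be sketched as follows. The token ids are placeholders, the name `inject_after_marker` is hypothetical, and setting the attention mask to 1 for injected tokens is an assumption; a batched version would also need to re-pad rows to a common length.

```python
def inject_after_marker(input_ids, attention_mask, marker, prefix):
    """Insert `prefix` after every occurrence of `marker` in one sequence."""
    m = len(marker)
    out_ids, out_mask = [], []
    i = 0
    while i < len(input_ids):
        if m and input_ids[i:i + m] == marker:
            out_ids.extend(marker)
            out_mask.extend(attention_mask[i:i + m])
            out_ids.extend(prefix)
            out_mask.extend([1] * len(prefix))  # assumption: injected tokens are attended
            i += m
        else:
            out_ids.append(input_ids[i])
            out_mask.append(attention_mask[i])
            i += 1
    return out_ids, out_mask


# marker = [1, 2] (turn opener), prefix = [30, 31] (thinking-channel tokens).
new_ids, new_mask = inject_after_marker(
    [7, 1, 2, 8, 1, 2, 9], [1] * 7, marker=[1, 2], prefix=[30, 31]
)
```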
Construct label and optional loss-mask tensors aligned to assistant responses.
Build training labels by scanning input_ids for chat-template role markers.
Instead of re-tokenizing assistant text and searching for it (fragile due to BPE context sensitivity), this function locates the structural markers that the chat template inserts around each assistant turn and sets labels only for the content region.
Two strategies are attempted in order:
- Fast path (_IMSTART_TEMPLATE_PROCESSORS): for Qwen-family models whose tokenizers expose <|im_start|> / <|im_end|> via :func:convert_tokens_to_ids, the marker ids are resolved directly without applying any dummy conversation.
- General path (_derive_turn_markers): for all other processors (e.g. Gemma4), the assistant-turn markers are derived automatically by applying a minimal dummy conversation that contains a sentinel string. This handles models whose tokenizers do not reliably expose special-token ids via convert_tokens_to_ids or encode.
If both strategies fail, the function falls back to the legacy
:func:build_labels (BPE pattern-matching), which logs a warning because
it is sensitive to tokenisation context and may produce num_label_tokens=0
/ nan loss on some samples.
Default collate function for multimodal VLM datasets.
Parameters:
Optional callable (batch, processor) -> batch
invoked right after apply_chat_template and before
build_labels. Used by model-specific collate wrappers
(e.g. Gemma4 thinking-channel injection) to transform the
tokenized batch and the prefix tokens without duplicating the rest of the pipeline.
Inject Gemma4’s thinking-channel prefix after every assistant turn marker.
Gemma4 31B / 26B-A4B MoE instruction-tuned models always emit a thinking-
channel prefix before the actual response. When this prefix is absent from
training sequences the model predicts <|channel> but the label says
answer text, inflating initial loss to ~9. Injecting the prefix (masked
as -100 in labels) lets the model see its expected pattern and brings
initial loss down to ~3.
Safe no-op for non-Gemma4 tokenizers.
Collate function for Gemma4 models with thinking-channel prefix.
Wraps default_collate_fn and injects <|channel>thought\n<channel|>
after every <|turn>model\n marker before labels are built. The injected
tokens are automatically masked to -100 by build_labels_from_template
(which only unmasks tokens inside assistant turns), so the model sees its
expected thinking prefix without being penalised for it.
Collate function for Kimi K2.5 VL processors with pre-expanded image tokens.
For pipeline parallelism, this function:
- Processes each sample to get input_ids with 1 placeholder per image
- Pre-expands the placeholder to N tokens (N = (h//2)*(w//2) from grid_thws)
- Pads all sequences to fixed max_length
This ensures the model forward pass doesn’t change sequence length dynamically.
Collate function for KimiVL processors.
Collate function for LLaVA-OneVision-1.5 processors.
Handles image and video inputs using the Qwen-style chat template and qwen_vl_utils for vision info processing.
Wrap collate_fn so that on failure the entire batch is re-sampled.
Parameters:
The dataset to re-sample from on failure.
The collate function to wrap.
Maximum number of retry attempts.
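A retrying wrapper of this kind could be sketched as below. Uniform random re-sampling, catching any `Exception`, and the extra `rng` argument are assumptions for illustration; the source only states that the whole batch is re-sampled on failure.

```python
import random


def robust_collate(dataset, collate_fn, max_retries=10, rng=None):
    """Wrap collate_fn so a failing batch is replaced by a re-sampled one."""
    rng = rng or random.Random()

    def wrapped(batch):
        current = batch
        for attempt in range(max_retries + 1):
            try:
                return collate_fn(current)
            except Exception:
                if attempt == max_retries:
                    raise  # give up after the final retry
                # Draw a fresh batch of the same size from the dataset.
                current = [
                    dataset[rng.randrange(len(dataset))]
                    for _ in range(len(batch))
                ]

    return wrapped


def strict_collate(batch):
    """Toy collate that rejects negative samples."""
    if any(x < 0 for x in batch):
        raise ValueError("bad sample")
    return list(batch)


dataset = [1, 2, 3]
collate = robust_collate(dataset, strict_collate, max_retries=3)
result = collate([-1, 2])  # first attempt fails, the re-sampled batch succeeds
```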
Collater for neat-packed VLM sequences.
Packs arrive with variable lengths (no pre-padding). This collater:
- Pads all text tensors to a common length.
- Converts the indexed attention_mask to the appropriate format:
  - flash_attention_2: keeps the indexed [B, S] mask (values 1, 2, … for documents, 0 for padding). The monkey-patched _get_unpad_data converts this to cu_seqlens for flash_attn_varlen_func.
  - sdpa / eager: converts to a 4D block-causal bool mask.
- Concatenates media tensors across the batch dimension.
No autoregressive shift — it was already applied during packing.
Parameters:
List of packed sample dicts from PackedDatasetWrapper.
Token ID for padding input_ids (default 0).
If set, pad every batch to this fixed length.
If None (default), pad to the longest pack in the batch.
A fixed length avoids recompilation with torch.compile
and ensures uniform tensor shapes across steps.
Attention backend ("flash_attention_2",
"sdpa", or "eager").
Returns: dict
Dict with batched tensors ready for model forward.
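The conversion from an indexed mask to a block-causal mask can be shown for a single row using nested lists. This is a sketch only: the real collater is described as producing a 4D tensor, and leaving padding rows entirely False is an assumption (an actual implementation may treat them differently to keep softmax well defined).

```python
def block_causal_mask(indexed_mask):
    """Build an S x S bool mask from per-token document indices.

    Query q may attend to key k only if both belong to the same document
    (same non-zero index) and k does not come after q.
    """
    size = len(indexed_mask)
    return [
        [
            indexed_mask[q] != 0
            and indexed_mask[q] == indexed_mask[k]
            and k <= q
            for k in range(size)
        ]
        for q in range(size)
    ]


# Two packed documents of length 2, followed by one padding position.
mask = block_causal_mask([1, 1, 2, 2, 0])
```

Each document forms its own lower-triangular block, so tokens never attend across pack boundaries.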
Collate for NemotronOmni (image / video / audio).
Defers <image>/<video>/<audio> placeholder expansion to the processor;
the collate only gathers media into processor kwargs, pads, builds labels, and
stacks tensors.
Collate function for NVIDIA Nemotron-Parse models.
The Nemotron-Parse processor does not expose a chat template, so we build the prompt + answer string manually, mask the prompt tokens, and keep the image preprocessing handled by the processor.
Collate function for pre-tokenized samples (from :class:PreTokenizedDatasetWrapper).
Each example is expected to carry at least input_ids, attention_mask,
and labels as 1-D tensors, plus optional media tensors (pixel_values,
image_grid_thw, pixel_values_videos, video_grid_thw).
Fake image injection and vision-token masking are handled per-sample in
:class:PreTokenizedDatasetWrapper.__getitem__, so this function only
pads, stacks, and concatenates.
The function:
- Pads all sequence tensors to the same length (either max_length or the longest sequence in the batch).
- Concatenates media tensors across the batch.
- Applies the standard autoregressive shift (labels = labels[:, 1:], inputs truncated by one token).
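The pad-then-shift steps can be sketched with lists in place of 1-D tensors. Padding labels with -100, padding the mask with 0, and truncating the attention mask together with the inputs are assumptions; media concatenation is omitted, and the name `pad_and_shift` is hypothetical.

```python
def pad_and_shift(samples, pad_token_id=0, max_length=None):
    """Pad to a common length, then apply the autoregressive shift."""
    target = max_length or max(len(s["input_ids"]) for s in samples)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for sample in samples:
        pad = target - len(sample["input_ids"])
        if pad < 0:
            raise ValueError("sample is longer than max_length")
        ids = sample["input_ids"] + [pad_token_id] * pad
        mask = sample["attention_mask"] + [0] * pad
        labels = sample["labels"] + [-100] * pad
        # Shift: the token at position t is trained to predict label t + 1.
        batch["input_ids"].append(ids[:-1])
        batch["attention_mask"].append(mask[:-1])
        batch["labels"].append(labels[1:])
    return batch


samples = [
    {"input_ids": [1, 2, 3, 4], "attention_mask": [1, 1, 1, 1],
     "labels": [-100, 2, 3, 4]},
    {"input_ids": [5, 6], "attention_mask": [1, 1], "labels": [-100, 6]},
]
batch = pad_and_shift(samples)
```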
Collate function for Phi-4 MM model audio input.
Collate function for Qwen2.5 VL model.
Collate function for Qwen3 Omni processors.