nemo_automodel.components.datasets.vlm.collate_fns


Module Contents

Functions

_build_labels_from_markers: Scan input_ids for assistant_marker ... stop_id spans and build labels.
_count_media_per_sample: Count images and videos per sample from conversation structure.
_decode_single_token: Decode a single token id across tokenizer implementations.
_derive_turn_markers: Derive the assistant-turn start marker and end-of-turn token id from the tokenizer's own chat template.
_drop_overlong_samples: Drop conversations whose estimated token count exceeds max_length.
_ensure_rgb: Convert any PIL images in conversations to RGB to handle RGBA/grayscale inputs.
_estimate_media_tokens: Estimate expanded media token count from image/video dimensions.
_expand_image_tokens: Expand image placeholder tokens to the correct patch counts based on grid_thws.
_extract_assistant_text: -
_extract_image_config: Extract image processing config from processor for token estimation.
_extract_media_from_conversations: Extract image and video inputs from conversation content elements.
_find_pattern_indices: -
_get_assistant_marker: Return the token-id sequence that introduces an assistant turn.
_get_stop_token_id: Return the token id of the turn-ending marker (<|im_end|>).
_inject_thinking_prefix_tokens: Insert <|channel>thought\n<channel|> tokens after every <|turn>model\n marker.
build_labels: Construct label and optional loss-mask tensors aligned to assistant responses.
build_labels_from_template: Build training labels by scanning input_ids for chat-template role markers.
default_collate_fn: Default collate function for multimodal VLM datasets.
gemma4_inject_thinking_prefix: Inject Gemma4's thinking-channel prefix after every assistant turn marker.
gemma4_prefix_collate_fn: Collate function for Gemma4 models with thinking-channel prefix.
kimi_k25_vl_collate_fn: Collate function for Kimi K2.5 VL processors with pre-expanded image tokens.
kimi_vl_collate_fn: Collate function for KimiVL processors.
llava_onevision_collate_fn: Collate function for LLaVA-OneVision-1.5 processors.
make_robust_collate: Wrap collate_fn so that on failure the entire batch is re-sampled.
neat_packed_vlm_collater: Collater for neat-packed VLM sequences.
nemotron_omni_collate_fn: Collate for NemotronOmni (image / video / audio).
nemotron_parse_collate_fn: Collate function for NVIDIA Nemotron-Parse models.
pad_collate_fn: Collate function for pre-tokenized samples (from PreTokenizedDatasetWrapper).
phi4_mm_collate_fn: Collate function for Phi-4 MM model audio input.
qwen2_5_collate_fn: Collate function for Qwen2.5 VL model.
qwen3_omni_collate_fn: Collate function for Qwen3 Omni processors.

Data

COLLATE_FNS

HAVE_QWEN_OMNI_UTILS

HAVE_QWEN_VL_UTILS

_DEFAULT_MERGE_KERNEL

_GEMMA4_MODEL_TURN

_GEMMA4_THINKING_PREFIX

_IMSTART_TEMPLATE_PROCESSORS

logger

API

nemo_automodel.components.datasets.vlm.collate_fns._build_labels_from_markers(
input_ids_batch: torch.Tensor,
assistant_marker: typing.List[int],
stop_id: int
) -> torch.Tensor

Scan input_ids for assistant_marker ... stop_id spans and build labels.

For each sequence in the batch, every token between the end of an assistant marker and the corresponding stop_id (inclusive) is copied into the labels tensor; all other positions are set to -100.

Parameters:

input_ids_batch
torch.Tensor

Token ids of shape (B, L).

assistant_marker
List[int]

Token-id sequence that opens an assistant turn (e.g. [<|im_start|>, assistant_id, newline_id] for Qwen or [<start_of_turn>, model_id, newline_id] for Gemma4).

stop_id
int

Single token id that closes a turn (e.g. <|im_end|> or <end_of_turn>).
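A minimal pure-Python sketch of the marker scan described above. It works on one sequence as a plain list of ints rather than a (B, L) torch.Tensor; the helper name label_assistant_spans and the toy token ids are hypothetical. How the real function handles a marker with no closing stop_id is not documented here, so labelling through to the end of the sequence is an assumption of this sketch.

```python
from typing import List

IGNORE_INDEX = -100  # label value for positions excluded from the loss


def label_assistant_spans(seq: List[int], marker: List[int], stop_id: int) -> List[int]:
    """Hypothetical helper: copy tokens after each marker up to and including stop_id."""
    if not marker:
        raise ValueError("marker must be non-empty")
    labels = [IGNORE_INDEX] * len(seq)
    m = len(marker)
    i = 0
    while i <= len(seq) - m:
        if seq[i:i + m] != marker:
            i += 1
            continue
        start = i + m  # first content token after the marker
        j = start
        while j < len(seq) and seq[j] != stop_id:
            j += 1  # advance to the turn-ending token
        end = min(j, len(seq) - 1)  # inclusive of stop_id when it was found
        for k in range(start, end + 1):
            labels[k] = seq[k]
        i = end + 1  # resume scanning after this assistant turn
    return labels


# Toy ids: marker [1, 2] opens an assistant turn and 9 closes it.
seq = [5, 1, 2, 7, 8, 9, 5, 1, 2, 6, 9, 0]
print(label_assistant_spans(seq, [1, 2], 9))
# [-100, -100, -100, 7, 8, 9, -100, -100, -100, 6, 9, -100]
```

Note that the marker tokens themselves stay at -100: only the response and its closing stop_id are supervised.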

nemo_automodel.components.datasets.vlm.collate_fns._count_media_per_sample(
conversations
)

Count images and videos per sample from conversation structure.

Returns two lists of length len(conversations) giving the number of image and video items in each conversation, respectively.
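A sketch of the per-sample counting, assuming the common chat layout in which each message's content is either a plain string or a list of dicts tagged with "type": "image" or "type": "video". The helper name count_media and that exact layout are assumptions, not taken from the source.

```python
from typing import Any, Dict, List, Tuple

Conversation = List[Dict[str, Any]]


def count_media(conversations: List[Conversation]) -> Tuple[List[int], List[int]]:
    """Hypothetical sketch: count image and video items in each conversation."""
    image_counts: List[int] = []
    video_counts: List[int] = []
    for conversation in conversations:
        n_images = n_videos = 0
        for message in conversation:
            content = message.get("content")
            if not isinstance(content, list):
                continue  # plain-string content carries no media
            for item in content:
                kind = item.get("type") if isinstance(item, dict) else None
                n_images += int(kind == "image")
                n_videos += int(kind == "video")
        image_counts.append(n_images)
        video_counts.append(n_videos)
    return image_counts, video_counts


convs = [
    [{"role": "user", "content": [{"type": "image", "image": "a.png"},
                                  {"type": "text", "text": "Describe."}]},
     {"role": "assistant", "content": "A cat."}],
    [{"role": "user", "content": [{"type": "video", "video": "clip.mp4"},
                                  {"type": "image", "image": "b.png"},
                                  {"type": "image", "image": "c.png"}]}],
]
print(count_media(convs))  # ([1, 2], [0, 1])
```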

nemo_automodel.components.datasets.vlm.collate_fns._decode_single_token(
tokenizer,
token_id: int
) -> str

Decode a single token id across tokenizer implementations.

Some tokenizers accept an int token id, while others require a sequence of ids (e.g., List[int]). We try the common forms in order.
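A sketch of the try-in-order idea using two toy tokenizers, each of which accepts only one input form. The helper name decode_single_token_sketch, the order of attempts (int first, then [int]), and the final ValueError are assumptions; the real helper may try other forms or fail differently.

```python
from typing import Any, List


def decode_single_token_sketch(tokenizer: Any, token_id: int) -> str:
    """Hypothetical sketch: try decode(int), then decode([int])."""
    errors: List[Exception] = []
    for arg in (token_id, [token_id]):
        try:
            text = tokenizer.decode(arg)
        except (TypeError, ValueError) as err:
            errors.append(err)  # this form is rejected; try the next one
            continue
        if isinstance(text, str):
            return text
    raise ValueError(f"could not decode token id {token_id}: {errors}")


class IntOnlyTokenizer:
    """Toy tokenizer whose decode() accepts only a bare int."""
    vocab = {7: "<|im_end|>"}

    def decode(self, ids):
        if not isinstance(ids, int):
            raise TypeError("expected int")
        return self.vocab[ids]


class ListOnlyTokenizer:
    """Toy tokenizer whose decode() accepts only a list of ints."""
    vocab = {7: "<|im_end|>"}

    def decode(self, ids):
        if not isinstance(ids, list):
            raise TypeError("expected list")
        return "".join(self.vocab[i] for i in ids)


print(decode_single_token_sketch(IntOnlyTokenizer(), 7))   # <|im_end|>
print(decode_single_token_sketch(ListOnlyTokenizer(), 7))  # <|im_end|>
```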

nemo_automodel.components.datasets.vlm.collate_fns._derive_turn_markers(
tokenizer
) -> typing.Tuple[typing.List[int], int]

Derive the assistant-turn start marker and end-of-turn token id from the tokenizer’s own chat template.

The function applies a minimal dummy conversation that contains a known sentinel string as the assistant reply, then locates the sentinel in the resulting token sequence. Everything between the end of the user turn and the start of the sentinel becomes the assistant marker; the first token after the sentinel becomes the end-of-turn id.

This approach is robust to BPE context-sensitivity and works for any model whose template wraps assistant turns with fixed token sequences — e.g. Gemma4’s <start_of_turn>model\n<end_of_turn>.

Note: apply_chat_template may return a transformers.BatchEncoding (a UserDict subclass, not a plain dict), so isinstance(result, dict) is False. The function therefore accesses result["input_ids"] directly, which works for both BatchEncoding and plain dict / list results.

Returns

tuple[list[int], int] (assistant_marker, end_of_turn_id)

Raises

ValueError If the sentinel cannot be located in the template output or if the resulting marker is empty.
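A sketch of the sentinel idea on already-tokenized ids. It assumes two template outputs are available: prompt_ids for the dummy conversation up to the end of the user turn, and full_ids for the same conversation plus the sentinel assistant reply. Locating the sentinel by exact id matching is a simplification of this sketch (the real function may search differently to stay robust to BPE context), and the names derive_turn_markers_sketch and find_subsequence are hypothetical.

```python
from typing import List, Tuple


def find_subsequence(haystack: List[int], needle: List[int]) -> int:
    """Return the start index of needle in haystack, or -1 if absent."""
    n = len(needle)
    for i in range(len(haystack) - n + 1):
        if haystack[i:i + n] == needle:
            return i
    return -1


def derive_turn_markers_sketch(prompt_ids: List[int], full_ids: List[int],
                               sentinel_ids: List[int]) -> Tuple[List[int], int]:
    """Hypothetical sketch: marker = ids between user turn and sentinel; EOT = id after sentinel."""
    start = find_subsequence(full_ids, sentinel_ids)
    end = start + len(sentinel_ids)
    if start < 0 or end >= len(full_ids):
        raise ValueError("sentinel not found, or no token follows it")
    marker = full_ids[len(prompt_ids):start]  # assistant-turn opener
    if not marker:
        raise ValueError("derived assistant marker is empty")
    return marker, full_ids[end]  # first token after the sentinel closes the turn


prompt_ids = [10, 11, 12, 13]                          # ...user turn
full_ids = [10, 11, 12, 13, 20, 21, 22, 99, 98, 30, 0]  # + marker, sentinel, end-of-turn
print(derive_turn_markers_sketch(prompt_ids, full_ids, [99, 98]))  # ([20, 21, 22], 30)
```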

nemo_automodel.components.datasets.vlm.collate_fns._drop_overlong_samples(
conversations,
processor,
max_length
)

Drop conversations whose estimated token count exceeds max_length.

Returns (filtered_conversations, kept_indices) where kept_indices are the original positions that survived filtering. Raises ValueError when every sample in the batch is dropped (caught by robust_collate which re-samples).
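A sketch of the filtering contract, with the processor-based estimate replaced by a caller-supplied estimate_tokens callable. That callable and the name drop_overlong_sketch are hypothetical; only the return shape and the all-dropped ValueError come from the description above.

```python
from typing import Any, Callable, List, Sequence, Tuple


def drop_overlong_sketch(conversations: Sequence[Any],
                         estimate_tokens: Callable[[Any], int],
                         max_length: int) -> Tuple[List[Any], List[int]]:
    """Hypothetical sketch: keep conversations whose estimated length fits max_length."""
    kept_indices = [i for i, conv in enumerate(conversations)
                    if estimate_tokens(conv) <= max_length]
    if not kept_indices:
        # Lets a robust-collate wrapper re-sample the whole batch.
        raise ValueError("every sample in the batch exceeds max_length")
    return [conversations[i] for i in kept_indices], kept_indices


# Toy usage: string length stands in for the token estimate.
print(drop_overlong_sketch(["aa", "aaaaaa", "a"], len, 3))  # (['aa', 'a'], [0, 2])
```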

nemo_automodel.components.datasets.vlm.collate_fns._ensure_rgb(
conversations
)

Convert any PIL images in conversations to RGB to handle RGBA/grayscale inputs.

nemo_automodel.components.datasets.vlm.collate_fns._estimate_media_tokens(
conversation,
processor
)

Estimate expanded media token count from image/video dimensions.

Returns total extra tokens beyond the single-placeholder-per-media count that tokenization produces. Only images with known dimensions (PIL Image objects or loadable paths) are estimated; unknown media items contribute 0 extra tokens (the placeholder is still counted in the base tokenization).
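A worked example of the arithmetic behind such an estimate, for a single image. It assumes a Qwen2-VL-style layout where each merged token covers patch_size * merge_size pixels per side (patch_size=14, merge_size=2 here); those values and the helper name are assumptions, and real processors also clamp images to min/max pixel budgets, which this sketch ignores.

```python
def estimate_extra_image_tokens(width: int, height: int,
                                patch_size: int = 14, merge_size: int = 2) -> int:
    """Hypothetical sketch: tokens one image adds beyond its single placeholder."""
    unit = patch_size * merge_size  # pixels covered by one merged token along each axis
    grid_h = max(1, round(height / unit))
    grid_w = max(1, round(width / unit))
    return grid_h * grid_w - 1  # the placeholder itself is already counted


print(estimate_extra_image_tokens(448, 448))  # 16 * 16 - 1 = 255
```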

nemo_automodel.components.datasets.vlm.collate_fns._expand_image_tokens(
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
grid_thws: torch.Tensor,
media_token_id: int,
merge_kernel_size: typing.Tuple[int, int] = _DEFAULT_MERGE_KERNEL
) -> typing.Tuple[torch.Tensor, torch.Tensor]

Expand image placeholder tokens to the correct patch counts based on grid_thws.

For pipeline parallelism (PP), this ensures the sequence length is fixed before the model forward pass, eliminating dynamic sequence expansion inside the model.

Supports both single-image and multi-image samples. Each placeholder token is expanded to (h // merge_h) * (w // merge_w) tokens using the corresponding entry in grid_thws. The number of placeholders must match grid_thws.shape[0].

Parameters:

input_ids
torch.Tensor

(seq_len,) tensor containing one media_token_id per image.

attention_mask
torch.Tensor

(seq_len,) tensor aligned with input_ids.

grid_thws
torch.Tensor

(N, 3) tensor with [t, h, w] for each of the N images.

media_token_id
int

Token ID used as the image placeholder.

merge_kernel_size
Tuple[int, int]Defaults to _DEFAULT_MERGE_KERNEL

Vision tower’s patch merge kernel, default (2, 2).

Returns: Tuple[torch.Tensor, torch.Tensor]

Input IDs and attention mask, each with every placeholder expanded to its patch count.

Raises:

  • ValueError: When the number of placeholders does not match grid_thws.shape[0].

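A list-based sketch of the per-placeholder expansion rule (h // merge_h) * (w // merge_w), including the placeholder-count check. It uses Python lists instead of 1-D tensors, and the name expand_image_tokens_sketch is hypothetical; copying the placeholder's attention-mask value to every expanded position is an assumption.

```python
from typing import List, Sequence, Tuple


def expand_image_tokens_sketch(input_ids: List[int], attention_mask: List[int],
                               grid_thws: Sequence[Sequence[int]], media_token_id: int,
                               merge_kernel_size: Tuple[int, int] = (2, 2)
                               ) -> Tuple[List[int], List[int]]:
    """Hypothetical sketch: replace each placeholder with its merged patch count."""
    merge_h, merge_w = merge_kernel_size
    n_placeholders = input_ids.count(media_token_id)
    if n_placeholders != len(grid_thws):
        raise ValueError(f"{n_placeholders} placeholders but {len(grid_thws)} grid_thws rows")
    new_ids: List[int] = []
    new_mask: List[int] = []
    image_idx = 0
    for tok, mask in zip(input_ids, attention_mask):
        if tok != media_token_id:
            new_ids.append(tok)
            new_mask.append(mask)
            continue
        _, h, w = grid_thws[image_idx]  # [t, h, w]; t is unused by this rule
        n_tokens = (h // merge_h) * (w // merge_w)
        new_ids.extend([media_token_id] * n_tokens)
        new_mask.extend([mask] * n_tokens)  # expanded tokens inherit the placeholder's mask
        image_idx += 1
    return new_ids, new_mask


ids, mask = expand_image_tokens_sketch([1, 7, 2, 7, 3], [1] * 5, [[1, 4, 4], [1, 2, 6]], 7)
print(ids)  # [1, 7, 7, 7, 7, 2, 7, 7, 7, 3]
```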
nemo_automodel.components.datasets.vlm.collate_fns._extract_assistant_text(
message: typing.Dict[str, typing.Any]
) -> str

nemo_automodel.components.datasets.vlm.collate_fns._extract_image_config(
processor
)

Extract image processing config from processor for token estimation.

nemo_automodel.components.datasets.vlm.collate_fns._extract_media_from_conversations(
conversations
)

Extract image and video inputs from conversation content elements.

Images are returned as-is (PIL Image or path string) for the image processor. Videos are returned as path strings so the video processor can read and sample them using its own fps / max_frames configuration.

Returns:

(images list | None, videos list | None)
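A sketch of the extraction, assuming media items are dicts of the form {"type": "image", "image": ...} and {"type": "video", "video": ...}. Those key names and the helper name extract_media_sketch are assumptions; returning None for an empty list mirrors the documented return type.

```python
from typing import Any, Dict, List, Optional, Tuple


def extract_media_sketch(conversations: List[List[Dict[str, Any]]]
                         ) -> Tuple[Optional[List[Any]], Optional[List[Any]]]:
    """Hypothetical sketch: gather images (as-is) and video paths across all conversations."""
    images: List[Any] = []
    videos: List[Any] = []
    for conversation in conversations:
        for message in conversation:
            content = message.get("content")
            if not isinstance(content, list):
                continue  # plain-string content carries no media
            for item in content:
                if not isinstance(item, dict):
                    continue
                if item.get("type") == "image":
                    images.append(item.get("image"))  # PIL Image or path, passed through unchanged
                elif item.get("type") == "video":
                    videos.append(item.get("video"))  # path; the video processor samples frames
    return images or None, videos or None


convs = [[{"role": "user", "content": [{"type": "image", "image": "a.png"},
                                       {"type": "video", "video": "clip.mp4"}]}]]
print(extract_media_sketch(convs))  # (['a.png'], ['clip.mp4'])
```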

nemo_automodel.components.datasets.vlm.collate_fns._find_pattern_indices(
template,
pattern,
search_start_index = 0,
allow_first_token_mismatch = False
)

nemo_automodel.components.datasets.vlm.collate_fns._get_assistant_marker(
tokenizer
) -> typing.Optional[typing.List[int]]

Return the token-id sequence that introduces an assistant turn.

For Qwen-family models the marker is [<|im_start|>, assistant, \n]. Returns None when the tokenizer does not use this convention.

nemo_automodel.components.datasets.vlm.collate_fns._get_stop_token_id(
tokenizer
) -> typing.Optional[int]

Return the token id of the turn-ending marker (<|im_end|>).

nemo_automodel.components.datasets.vlm.collate_fns._inject_thinking_prefix_tokens(
batch: typing.Dict[str, torch.Tensor],
tokenizer
) -> typing.Dict[str, torch.Tensor]

Insert <|channel>thought\n<channel|> tokens after every <|turn>model\n marker.

Modifies input_ids, attention_mask, and mm_token_type_ids (if present). Any other 2-D integer tensor whose second dimension matches input_ids is also extended with zeros so that sequence lengths stay consistent (mainly as future-proofing).
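A single-sequence sketch of the insertion step, with toy ids standing in for the <|turn>model\n marker and the <|channel>thought\n<channel|> prefix. The helper name inject_after_marker is hypothetical. In a padded batch each row can grow by a different amount, so the real function must also re-pad rows and extend the other tensors, which this sketch omits.

```python
from typing import List, Tuple


def inject_after_marker(ids: List[int], mask: List[int], marker: List[int],
                        prefix: List[int]) -> Tuple[List[int], List[int]]:
    """Hypothetical sketch: insert prefix ids after every occurrence of marker."""
    if not marker:
        raise ValueError("marker must be non-empty")
    m = len(marker)
    out_ids: List[int] = []
    out_mask: List[int] = []
    i = 0
    while i < len(ids):
        if ids[i:i + m] == marker:
            out_ids += ids[i:i + m] + prefix
            out_mask += mask[i:i + m] + [1] * len(prefix)  # injected tokens are attended to
            i += m
        else:
            out_ids.append(ids[i])
            out_mask.append(mask[i])
            i += 1
    return out_ids, out_mask


# Toy ids: marker [1, 2] plays <|turn>model\n, prefix [8, 9] plays the thinking prefix.
print(inject_after_marker([5, 1, 2, 3, 1, 2, 4], [1] * 7, [1, 2], [8, 9])[0])
# [5, 1, 2, 8, 9, 3, 1, 2, 8, 9, 4]
```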

nemo_automodel.components.datasets.vlm.collate_fns.build_labels(
input_ids_batch: torch.Tensor,
conversations: typing.Sequence[typing.Sequence[typing.Dict[str, typing.Any]]],
processor
) -> torch.Tensor

Construct label and optional loss-mask tensors aligned to assistant responses.

nemo_automodel.components.datasets.vlm.collate_fns.build_labels_from_template(
input_ids_batch: torch.Tensor,
conversations: typing.Sequence[typing.Sequence[typing.Dict[str, typing.Any]]],
processor
) -> torch.Tensor

Build training labels by scanning input_ids for chat-template role markers.

Instead of re-tokenizing assistant text and searching for it (fragile due to BPE context sensitivity), this function locates the structural markers that the chat template inserts around each assistant turn and sets labels only for the content region.

Two strategies are attempted in order:

  1. Fast path (_IMSTART_TEMPLATE_PROCESSORS): for Qwen-family models whose tokenizers expose <|im_start|> / <|im_end|> via convert_tokens_to_ids, the marker ids are resolved directly without applying any dummy conversation.

  2. General path (_derive_turn_markers): for all other processors (e.g. Gemma4), the assistant-turn markers are derived automatically by applying a minimal dummy conversation that contains a sentinel string. This handles models whose tokenizers do not reliably expose special-token ids via convert_tokens_to_ids or encode.

If both strategies fail, the function falls back to the legacy build_labels (BPE pattern matching) and logs a warning: that path is sensitive to tokenisation context and may produce num_label_tokens=0 (and therefore a NaN loss) on some samples.

nemo_automodel.components.datasets.vlm.collate_fns.default_collate_fn(
examples: typing.Sequence[typing.Dict[str, typing.Any]],
processor,
max_length: typing.Optional[int] = None,
drop_overlong: bool = False,
_post_tokenize_hook = None
) -> typing.Dict[str, torch.Tensor]

Default collate function for multimodal VLM datasets.

Parameters:

_post_tokenize_hook
Defaults to None

Optional callable (batch, processor) -> batch invoked right after apply_chat_template and before build_labels. Model-specific collate wrappers (e.g. Gemma4 thinking-channel injection) use it to transform the tokenized batch, for example by inserting prefix tokens, without duplicating the rest of the pipeline.

nemo_automodel.components.datasets.vlm.collate_fns.gemma4_inject_thinking_prefix(
batch: typing.Dict[str, torch.Tensor],
processor
) -> typing.Dict[str, torch.Tensor]

Inject Gemma4’s thinking-channel prefix after every assistant turn marker.

Gemma4 31B / 26B-A4B MoE instruction-tuned models always emit a thinking-channel prefix before the actual response. When this prefix is absent from training sequences, the model predicts <|channel> where the label expects answer text, inflating the initial loss to ~9. Injecting the prefix (masked as -100 in labels) lets the model see its expected pattern and brings the initial loss down to ~3.

Safe no-op for non-Gemma4 tokenizers.

nemo_automodel.components.datasets.vlm.collate_fns.gemma4_prefix_collate_fn(
examples: typing.Sequence[typing.Dict[str, typing.Any]],
processor,
max_length: typing.Optional[int] = None
) -> typing.Dict[str, torch.Tensor]

Collate function for Gemma4 models with thinking-channel prefix.

Wraps default_collate_fn and injects <|channel>thought\n<channel|> after every <|turn>model\n marker before labels are built. The injected tokens are automatically masked to -100 by build_labels_from_template (which only unmasks tokens inside assistant turns), so the model sees its expected thinking prefix without being penalised for it.

nemo_automodel.components.datasets.vlm.collate_fns.kimi_k25_vl_collate_fn(
examples: typing.Sequence[typing.Dict[str, typing.Any]],
processor,
max_length: typing.Optional[int] = None,
drop_overlong: bool = False
) -> typing.Dict[str, torch.Tensor]

Collate function for Kimi K2.5 VL processors with pre-expanded image tokens.

For pipeline parallelism, this function:

  1. Processes each sample to get input_ids with one placeholder per image.
  2. Pre-expands each placeholder to N tokens (N = (h//2)*(w//2), from grid_thws).
  3. Pads all sequences to a fixed max_length.

This ensures the model forward pass does not change the sequence length dynamically.

nemo_automodel.components.datasets.vlm.collate_fns.kimi_vl_collate_fn(
examples: typing.Sequence[typing.Dict[str, typing.Any]],
processor,
max_length: typing.Optional[int] = None
) -> typing.Dict[str, torch.Tensor]

Collate function for KimiVL processors.

nemo_automodel.components.datasets.vlm.collate_fns.llava_onevision_collate_fn(
examples: typing.Sequence[typing.Dict[str, typing.Any]],
processor
) -> typing.Dict[str, torch.Tensor]

Collate function for LLaVA-OneVision-1.5 processors.

Handles image and video inputs using the Qwen-style chat template and qwen_vl_utils for vision info processing.

nemo_automodel.components.datasets.vlm.collate_fns.make_robust_collate(
dataset,
collate_fn,
max_retries = 10
)

Wrap collate_fn so that on failure the entire batch is re-sampled.

Parameters:

dataset

The dataset to re-sample from on failure.

collate_fn

The collate function to wrap.

max_retries
Defaults to 10

Maximum number of retry attempts.
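A sketch of the retry-and-re-sample pattern. Which exceptions the real wrapper catches, how it picks replacement samples (uniformly at random here), and what it raises after the last retry are all assumptions; the name make_robust_collate_sketch is hypothetical.

```python
import random
from typing import Any, Callable, List, Sequence


def make_robust_collate_sketch(dataset: Sequence[Any],
                               collate_fn: Callable[[List[Any]], Any],
                               max_retries: int = 10) -> Callable[[List[Any]], Any]:
    """Hypothetical sketch: on failure, collate a freshly sampled batch of the same size."""

    def robust_collate(examples: List[Any]) -> Any:
        batch = list(examples)
        last_error: Exception = RuntimeError("collate never ran")
        for _ in range(max_retries + 1):  # first attempt plus max_retries retries
            try:
                return collate_fn(batch)
            except Exception as err:  # catching everything is an assumption of this sketch
                last_error = err
                batch = [dataset[random.randrange(len(dataset))] for _ in examples]
        raise RuntimeError(f"collate failed after {max_retries} retries") from last_error

    return robust_collate


calls = []


def flaky_collate(batch):
    """Toy collate that fails on its first call only."""
    calls.append(list(batch))
    if len(calls) == 1:
        raise ValueError("bad sample")
    return sum(batch)


robust = make_robust_collate_sketch(list(range(10)), flaky_collate, max_retries=3)
result = robust([1, 2])  # first call fails, second call sees a re-sampled batch
```

Re-sampling the whole batch, rather than only the failing sample, keeps the wrapper independent of how collate_fn fails, at the cost of discarding the healthy samples too.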

nemo_automodel.components.datasets.vlm.collate_fns.neat_packed_vlm_collater(
batch: list[dict],
padding_idx: int = 0,
max_length: int | None = None,
attn_implementation: str = 'sdpa'
) -> dict

Collater for neat-packed VLM sequences.

Packs arrive with variable lengths (no pre-padding). This collater:

  1. Pads all text tensors to a common length.
  2. Converts the indexed attention_mask to the appropriate format:
    • flash_attention_2: keeps the indexed [B, S] mask (values 1, 2, … for documents, 0 for padding). The monkey-patched _get_unpad_data converts this to cu_seqlens for flash_attn_varlen_func.
    • sdpa / eager: converts to a 4D block-causal bool mask.
  3. Concatenates media tensors across the batch dimension.

No autoregressive shift — it was already applied during packing.

Parameters:

batch
list[dict]

List of packed sample dicts from PackedDatasetWrapper.

padding_idx
intDefaults to 0

Token ID for padding input_ids (default 0).

max_length
int | NoneDefaults to None

If set, pad every batch to this fixed length. If None (default), pad to the longest pack in the batch. A fixed length avoids recompilation with torch.compile and ensures uniform tensor shapes across steps.

attn_implementation
strDefaults to 'sdpa'

Attention backend ("flash_attention_2", "sdpa", or "eager").

Returns: dict

Dict with batched tensors ready for model forward.
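A sketch of the two mask conversions for a single packed row, given an indexed mask (document ids 1, 2, ... and 0 for padding). It builds a 2-D boolean block-causal mask instead of the 4-D [B, 1, S, S] tensor used for sdpa / eager, plus the cumulative boundaries that flash_attn_varlen_func expects. Both helper names are hypothetical, and the sketch assumes each document occupies one contiguous run.

```python
from typing import List


def block_causal_mask(doc_ids: List[int]) -> List[List[bool]]:
    """mask[i][j] is True when query i may attend to key j (same document, j <= i)."""
    n = len(doc_ids)
    return [[doc_ids[i] != 0 and doc_ids[i] == doc_ids[j] and j <= i for j in range(n)]
            for i in range(n)]


def cu_seqlens(doc_ids: List[int]) -> List[int]:
    """Cumulative document boundaries for one packed row, skipping padding (0)."""
    boundaries = [0]
    prev = 0
    for d in doc_ids:
        if d == 0:
            continue
        if d != prev:
            boundaries.append(boundaries[-1])  # open a new document
            prev = d
        boundaries[-1] += 1
    return boundaries


row = [1, 1, 2, 2, 2, 0]  # two documents of length 2 and 3, then one pad token
print(cu_seqlens(row))    # [0, 2, 5]
```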

nemo_automodel.components.datasets.vlm.collate_fns.nemotron_omni_collate_fn(
examples: typing.Sequence[typing.Dict[str, typing.Any]],
processor,
max_length: typing.Optional[int] = None,
max_video_frames: int = 8
) -> typing.Dict[str, torch.Tensor]

Collate for NemotronOmni (image / video / audio).

Defers <image>/<video>/<audio> placeholder expansion to the processor; the collate only gathers media into processor kwargs, pads, builds labels, and stacks tensors.

nemo_automodel.components.datasets.vlm.collate_fns.nemotron_parse_collate_fn(
examples: typing.Sequence[typing.Dict[str, typing.Any]],
processor,
task_prompt: str = '</s><s><predict_bbox><pred...
) -> typing.Dict[str, torch.Tensor]

Collate function for NVIDIA Nemotron-Parse models.

The Nemotron-Parse processor does not expose a chat template, so we build the prompt + answer string manually, mask the prompt tokens, and keep the image preprocessing handled by the processor.

nemo_automodel.components.datasets.vlm.collate_fns.pad_collate_fn(
examples: typing.Sequence[typing.Dict[str, typing.Any]],
processor,
max_length: typing.Optional[int] = None
) -> typing.Dict[str, torch.Tensor]

Collate function for pre-tokenized samples (from PreTokenizedDatasetWrapper).

Each example is expected to carry at least input_ids, attention_mask, and labels as 1-D tensors, plus optional media tensors (pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw).

Fake image injection and vision-token masking are handled per-sample in PreTokenizedDatasetWrapper.__getitem__, so this function only pads, stacks, and concatenates.

The function:

  1. Pads all sequence tensors to the same length (either max_length or the longest sequence in the batch).
  2. Concatenates media tensors across the batch.
  3. Applies the standard autoregressive shift (labels = labels[:, 1:], inputs truncated by one token).

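A list-based sketch of steps 1 and 3 (media concatenation omitted). The fill values (pad_id for input_ids, 0 for attention_mask, -100 for labels), truncating over-long rows to max_length, and shifting attention_mask along with input_ids are assumptions of this sketch; the name pad_and_shift is hypothetical.

```python
from typing import Dict, List, Optional

IGNORE_INDEX = -100


def pad_and_shift(examples: List[Dict[str, List[int]]], pad_id: int = 0,
                  max_length: Optional[int] = None) -> Dict[str, List[List[int]]]:
    """Hypothetical sketch: right-pad to a common length, then shift for next-token loss."""
    target = max_length if max_length is not None else max(len(e["input_ids"]) for e in examples)
    fill_values = {"input_ids": pad_id, "attention_mask": 0, "labels": IGNORE_INDEX}
    padded: Dict[str, List[List[int]]] = {key: [] for key in fill_values}
    for example in examples:
        for key, fill in fill_values.items():
            row = example[key][:target]  # truncation is an assumption of this sketch
            padded[key].append(row + [fill] * (target - len(row)))
    return {
        "input_ids": [row[:-1] for row in padded["input_ids"]],            # drop last input
        "attention_mask": [row[:-1] for row in padded["attention_mask"]],
        "labels": [row[1:] for row in padded["labels"]],                   # predict next token
    }


batch = pad_and_shift([
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [-100, 2, 3]},
    {"input_ids": [4, 5], "attention_mask": [1, 1], "labels": [-100, 5]},
])
print(batch["labels"])  # [[2, 3], [5, -100]]
```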
nemo_automodel.components.datasets.vlm.collate_fns.phi4_mm_collate_fn(
examples,
processor
)

Collate function for Phi-4 MM model audio input.

nemo_automodel.components.datasets.vlm.collate_fns.qwen2_5_collate_fn(
examples: list,
processor
) -> dict[str, torch.Tensor]

Collate function for Qwen2.5 VL model.

nemo_automodel.components.datasets.vlm.collate_fns.qwen3_omni_collate_fn(
examples: typing.Sequence[typing.Dict[str, typing.Any]],
processor,
use_audio_in_video: bool = False
) -> typing.Dict[str, torch.Tensor]

Collate function for Qwen3 Omni processors.

nemo_automodel.components.datasets.vlm.collate_fns.COLLATE_FNS = {'Qwen2_5_VLProcessor': qwen2_5_collate_fn, 'Qwen3OmniMoeProcessor': qwen3_omni_...
nemo_automodel.components.datasets.vlm.collate_fns.HAVE_QWEN_OMNI_UTILS = True
nemo_automodel.components.datasets.vlm.collate_fns.HAVE_QWEN_VL_UTILS = True
nemo_automodel.components.datasets.vlm.collate_fns._DEFAULT_MERGE_KERNEL: Tuple[int, int] = (2, 2)
nemo_automodel.components.datasets.vlm.collate_fns._GEMMA4_MODEL_TURN = '<|turn>model\n'
nemo_automodel.components.datasets.vlm.collate_fns._GEMMA4_THINKING_PREFIX = '<|channel>thought\n<channel|>'
nemo_automodel.components.datasets.vlm.collate_fns._IMSTART_TEMPLATE_PROCESSORS = frozenset({'Qwen2VLProcessor', 'Qwen2_5_VLProcessor', 'Qwen2_5OmniProcessor', 'Q...
nemo_automodel.components.datasets.vlm.collate_fns.logger = logging.getLogger(__name__)