nemo_automodel.components.datasets.vlm.collate_fns#

Module Contents#

Functions#

make_robust_collate

Wrap collate_fn so that on failure the entire batch is re-sampled.

_find_pattern_indices

_extract_assistant_text

_decode_single_token

Decode a single token id across tokenizer implementations.

build_labels

Construct label and optional loss-mask tensors aligned to assistant responses.

_get_assistant_marker

Return the token-id sequence that introduces an assistant turn.

_get_stop_token_id

Return the token id of the turn-ending marker (<|im_end|>).

build_labels_from_template

Build training labels by scanning input_ids for chat-template role markers.

phi4_mm_collate_fn

Collate function for Phi-4 MM model audio input.

_extract_media_from_conversations

Extract image and video inputs from conversation content elements.

_count_media_per_sample

Count images and videos per sample from conversation structure.

qwen2_5_collate_fn

Collate function for Qwen2.5 VL model.

qwen3_omni_collate_fn

Collate function for Qwen3 Omni processors.

kimi_vl_collate_fn

Collate function for KimiVL processors.

_expand_image_tokens

Expand single image placeholder tokens to the correct number based on grid_thws.

kimi_k25_vl_collate_fn

Collate function for Kimi K2.5 VL processors with pre-expanded image tokens.

nemotron_parse_collate_fn

Collate function for NVIDIA Nemotron-Parse models.

_ensure_rgb

Convert any PIL images in conversations to RGB to handle RGBA/grayscale inputs.

_extract_image_config

Extract image processing config from processor for token estimation.

_estimate_media_tokens

Estimate expanded media token count from image/video dimensions.

_drop_overlong_samples

Drop conversations whose estimated token count exceeds max_length.

default_collate_fn

Default collate function for multimodal VLM datasets.

pad_collate_fn

Collate function for pre-tokenized samples (from :class:PreTokenizedDatasetWrapper).

neat_packed_vlm_collater

Collater for neat-packed VLM sequences.

_inject_thinking_prefix_tokens

Insert <|channel>thought\n<channel|> tokens after every <|turn>model\n marker.

gemma4_prefix_collate_fn

Collate function for Gemma4 models with thinking-channel prefix.

Data#

API#

nemo_automodel.components.datasets.vlm.collate_fns.logger#

‘getLogger(…)’

nemo_automodel.components.datasets.vlm.collate_fns.make_robust_collate(dataset, collate_fn, max_retries=10)#

Wrap collate_fn so that on failure the entire batch is re-sampled.

Parameters:
  • dataset – The dataset to re-sample from on failure.

  • collate_fn – The collate function to wrap.

  • max_retries – Maximum number of retry attempts.
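The retry-on-failure behaviour can be illustrated with a minimal sketch (not the actual implementation; the re-sampling strategy shown here is an assumption):

```python
import random

def make_robust_collate_sketch(dataset, collate_fn, max_retries=10):
    """Sketch: wrap collate_fn so a failing batch is replaced by a freshly sampled one."""
    def robust_collate(batch):
        for _ in range(max_retries):
            try:
                return collate_fn(batch)
            except Exception:
                # Re-sample an entirely new batch of the same size from the dataset.
                indices = [random.randrange(len(dataset)) for _ in batch]
                batch = [dataset[i] for i in indices]
        raise RuntimeError(f"collate failed after {max_retries} retries")
    return robust_collate
```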

nemo_automodel.components.datasets.vlm.collate_fns._find_pattern_indices(
template,
pattern,
search_start_index=0,
allow_first_token_mismatch=False,
)#
nemo_automodel.components.datasets.vlm.collate_fns._extract_assistant_text(message: Dict[str, Any]) str#
nemo_automodel.components.datasets.vlm.collate_fns._decode_single_token(tokenizer, token_id: int) str#

Decode a single token id across tokenizer implementations.

Some tokenizers accept an int token id, while others require a sequence of ids (e.g., List[int]). We try the common forms in order.
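The try-in-order logic can be sketched as follows (a simplified version; the real function may try additional forms):

```python
def decode_single_token_sketch(tokenizer, token_id):
    """Sketch: try the int form first, then fall back to a one-element list."""
    try:
        return tokenizer.decode(token_id)
    except (TypeError, ValueError):
        # Some tokenizers only accept a sequence of ids.
        return tokenizer.decode([token_id])
```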

nemo_automodel.components.datasets.vlm.collate_fns.build_labels(
input_ids_batch: torch.Tensor,
conversations: Sequence[Sequence[Dict[str, Any]]],
processor,
) torch.Tensor#

Construct label and optional loss-mask tensors aligned to assistant responses.

nemo_automodel.components.datasets.vlm.collate_fns._get_assistant_marker(tokenizer) Optional[List[int]]#

Return the token-id sequence that introduces an assistant turn.

For Qwen-family models the marker is [<|im_start|>, assistant, \n]. Returns None when the tokenizer does not use this convention.

nemo_automodel.components.datasets.vlm.collate_fns._get_stop_token_id(tokenizer) Optional[int]#

Return the token id of the turn-ending marker (<|im_end|>).

nemo_automodel.components.datasets.vlm.collate_fns._IMSTART_TEMPLATE_PROCESSORS#

‘frozenset(…)’

nemo_automodel.components.datasets.vlm.collate_fns.build_labels_from_template(
input_ids_batch: torch.Tensor,
conversations: Sequence[Sequence[Dict[str, Any]]],
processor,
) torch.Tensor#

Build training labels by scanning input_ids for chat-template role markers.

Instead of re-tokenizing assistant text and searching for it (fragile due to BPE context sensitivity), this function locates the structural markers that the chat template inserts around each assistant turn:

``<|im_start|>assistant\n`` … content … ``<|im_end|>``

Labels are set to the actual token ids for the content region (including <|im_end|>); everything else is -100.

Falls back to the old :func:build_labels for processor types that do not use the <|im_start|>/<|im_end|> convention (e.g. Kimi, Phi4, Nemotron-Parse).
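The marker-scanning approach can be sketched with plain Python lists, where `marker` and `stop_id` stand in for the tokenized ``<|im_start|>assistant\n`` sequence and the ``<|im_end|>`` id (a simplified sketch, not the actual implementation):

```python
IGNORE_INDEX = -100

def labels_from_markers(input_ids, marker, stop_id):
    """Sketch: unmask tokens between each assistant marker and the stop token
    (inclusive of the stop token); everything else stays -100."""
    labels = [IGNORE_INDEX] * len(input_ids)
    i = 0
    while i <= len(input_ids) - len(marker):
        if input_ids[i:i + len(marker)] == marker:
            j = i + len(marker)
            while j < len(input_ids) and input_ids[j] != stop_id:
                labels[j] = input_ids[j]
                j += 1
            if j < len(input_ids):
                labels[j] = input_ids[j]  # include the stop token itself
            i = j + 1
        else:
            i += 1
    return labels
```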

nemo_automodel.components.datasets.vlm.collate_fns.phi4_mm_collate_fn(examples, processor)#

Collate function for Phi-4 MM model audio input.

nemo_automodel.components.datasets.vlm.collate_fns._extract_media_from_conversations(conversations)#

Extract image and video inputs from conversation content elements.

Images are returned as-is (PIL Image or path string) for the image processor. Videos are returned as path strings so the video processor can read and sample them using its own fps / max_frames configuration.

Returns:

(images list | None, videos list | None)

Return type:

tuple

nemo_automodel.components.datasets.vlm.collate_fns._count_media_per_sample(conversations)#

Count images and videos per sample from conversation structure.

Returns two lists of length len(conversations) giving the number of image and video items in each conversation, respectively.
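A minimal sketch of the counting logic, assuming every message ``content`` is a list of typed dicts (the real conversation schema may allow plain-string content as well):

```python
def count_media_sketch(conversations):
    """Sketch: count image and video content items per conversation."""
    image_counts, video_counts = [], []
    for conv in conversations:
        imgs = vids = 0
        for message in conv:
            for item in message.get("content", []):
                if item.get("type") == "image":
                    imgs += 1
                elif item.get("type") == "video":
                    vids += 1
        image_counts.append(imgs)
        video_counts.append(vids)
    return image_counts, video_counts
```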

nemo_automodel.components.datasets.vlm.collate_fns.qwen2_5_collate_fn(
examples: list,
processor,
) dict[str, torch.Tensor]#

Collate function for Qwen2.5 VL model.

nemo_automodel.components.datasets.vlm.collate_fns.qwen3_omni_collate_fn(
examples: Sequence[Dict[str, Any]],
processor,
use_audio_in_video: bool = False,
) Dict[str, torch.Tensor]#

Collate function for Qwen3 Omni processors.

nemo_automodel.components.datasets.vlm.collate_fns.kimi_vl_collate_fn(
examples: Sequence[Dict[str, Any]],
processor,
max_length: Optional[int] = None,
) Dict[str, torch.Tensor]#

Collate function for KimiVL processors.

nemo_automodel.components.datasets.vlm.collate_fns._expand_image_tokens(
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
grid_thws: torch.Tensor,
media_token_id: int,
merge_kernel_size: Tuple[int, int] = (2, 2),
) Tuple[torch.Tensor, torch.Tensor]#

Expand single image placeholder tokens to the correct number based on grid_thws.

For pipeline parallelism (PP), this ensures the sequence length is fixed BEFORE the model forward pass, eliminating dynamic sequence expansion inside the model.

Assumes 1 image per sample (1 placeholder per sequence).

Parameters:
  • input_ids – (seq_len,) tensor with 1 media_token_id placeholder

  • attention_mask – (seq_len,) tensor

  • grid_thws – (1, 3) tensor with [t, h, w] for the single image

  • media_token_id – Token ID of the image placeholder

  • merge_kernel_size – Vision tower’s patch merge kernel, default (2, 2)

Returns:

(expanded_input_ids, expanded_attention_mask) – input IDs with the placeholder expanded to N tokens, and the attention mask expanded accordingly.

Return type:

Tuple[torch.Tensor, torch.Tensor]
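The expansion can be sketched with plain lists; N follows the ``t * (h // mh) * (w // mw)`` rule implied by the merge kernel (a simplified sketch, ignoring the attention-mask half of the real function):

```python
def expand_image_tokens_sketch(input_ids, grid_thw, media_token_id, merge=(2, 2)):
    """Sketch: replace the single placeholder with t * (h//mh) * (w//mw) copies."""
    t, h, w = grid_thw
    n = t * (h // merge[0]) * (w // merge[1])
    out = []
    for tok in input_ids:
        if tok == media_token_id:
            out.extend([media_token_id] * n)  # one placeholder becomes n tokens
        else:
            out.append(tok)
    return out
```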

nemo_automodel.components.datasets.vlm.collate_fns.kimi_k25_vl_collate_fn(
examples: Sequence[Dict[str, Any]],
processor,
max_length: Optional[int] = None,
) Dict[str, torch.Tensor]#

Collate function for Kimi K2.5 VL processors with pre-expanded image tokens.

For pipeline parallelism, this function:

  1. Processes each sample to get input_ids with 1 placeholder per image

  2. Pre-expands the placeholder to N tokens (N = (h//2)*(w//2) from grid_thws)

  3. Pads all sequences to a fixed max_length.

This ensures the model forward pass does not change sequence length dynamically.

nemo_automodel.components.datasets.vlm.collate_fns.nemotron_parse_collate_fn(
examples: Sequence[Dict[str, Any]],
processor,
task_prompt: str = '</s><s><predict_bbox><predict_classes><output_markdown>',
) Dict[str, torch.Tensor]#

Collate function for NVIDIA Nemotron-Parse models.

The Nemotron-Parse processor does not expose a chat template, so we build the prompt + answer string manually, mask the prompt tokens, and keep the image preprocessing handled by the processor.

nemo_automodel.components.datasets.vlm.collate_fns._ensure_rgb(conversations)#

Convert any PIL images in conversations to RGB to handle RGBA/grayscale inputs.

nemo_automodel.components.datasets.vlm.collate_fns._extract_image_config(processor)#

Extract image processing config from processor for token estimation.

nemo_automodel.components.datasets.vlm.collate_fns._estimate_media_tokens(conversation, processor)#

Estimate expanded media token count from image/video dimensions.

Returns total extra tokens beyond the single-placeholder-per-media count that tokenization produces. Only images with known dimensions (PIL Image objects or loadable paths) are estimated; unknown media items contribute 0 extra tokens (the placeholder is still counted in the base tokenization).
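For a single image, a common estimate divides each spatial dimension by the patch size and the merge factor; the values below (patch 14, merge 2) are illustrative assumptions, not necessarily the processor's actual config:

```python
def estimate_image_tokens_sketch(width, height, patch=14, merge=2):
    """Sketch: tokens ~= (H // patch // merge) * (W // patch // merge)."""
    return (height // patch // merge) * (width // patch // merge)
```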

nemo_automodel.components.datasets.vlm.collate_fns._drop_overlong_samples(conversations, processor, max_length)#

Drop conversations whose estimated token count exceeds max_length.

Returns (filtered_conversations, kept_indices) where kept_indices are the original positions that survived filtering. Raises ValueError when every sample in the batch is dropped (caught by robust_collate which re-samples).
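The filter-and-raise contract can be sketched as follows (estimate_fn stands in for the internal token estimator):

```python
def drop_overlong_sketch(conversations, estimate_fn, max_length):
    """Sketch: keep samples whose estimated length fits; raise if none survive."""
    kept, kept_indices = [], []
    for idx, conv in enumerate(conversations):
        if estimate_fn(conv) <= max_length:
            kept.append(conv)
            kept_indices.append(idx)
    if not kept:
        # Caught upstream by the robust collate wrapper, which re-samples.
        raise ValueError("all samples exceed max_length")
    return kept, kept_indices
```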

nemo_automodel.components.datasets.vlm.collate_fns.default_collate_fn(
examples: Sequence[Dict[str, Any]],
processor,
max_length: Optional[int] = None,
_post_tokenize_hook=None,
) Dict[str, torch.Tensor]#

Default collate function for multimodal VLM datasets.

Parameters:

_post_tokenize_hook – Optional callable (batch, processor) -> batch invoked right after apply_chat_template and before build_labels. Used by model-specific collate wrappers (e.g. Gemma4 thinking-channel injection) to transform the tokenized batch and the prefix tokens without duplicating the rest of the pipeline.
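The hook plumbing can be sketched with a stripped-down pipeline (a sketch only; the real function tokenizes with the processor and calls the label builders described above):

```python
def default_collate_sketch(batch, processor, post_tokenize_hook=None):
    """Sketch: apply the optional hook between tokenization and label building."""
    if post_tokenize_hook is not None:
        batch = post_tokenize_hook(batch, processor)
    # Stand-in for build_labels: copy ids so the hook's edits are reflected.
    batch["labels"] = [list(ids) for ids in batch["input_ids"]]
    return batch
```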

nemo_automodel.components.datasets.vlm.collate_fns.pad_collate_fn(
examples: Sequence[Dict[str, Any]],
processor,
max_length: Optional[int] = None,
) Dict[str, torch.Tensor]#

Collate function for pre-tokenized samples (from :class:PreTokenizedDatasetWrapper).

Each example is expected to carry at least input_ids, attention_mask, and labels as 1-D tensors, plus optional media tensors (pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw).

Fake image injection and vision-token masking are handled per-sample in :class:PreTokenizedDatasetWrapper.__getitem__, so this function only pads, stacks, and concatenates.

The function:

  1. Pads all sequence tensors to the same length (either max_length or the longest sequence in the batch).

  2. Concatenates media tensors across the batch.

  3. Applies the standard autoregressive shift (labels = labels[:, 1:], inputs truncated by one token).
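Steps 1 and 3 can be sketched with plain lists (a simplified sketch; the real function also handles media tensors and works on torch tensors):

```python
IGNORE_INDEX = -100

def pad_and_shift_sketch(samples, max_length=None, pad_id=0):
    """Sketch: right-pad every sample, then apply the autoregressive shift."""
    target = max_length or max(len(s["input_ids"]) for s in samples)
    batch_ids, batch_labels = [], []
    for s in samples:
        pad = target - len(s["input_ids"])
        ids = s["input_ids"] + [pad_id] * pad
        labels = s["labels"] + [IGNORE_INDEX] * pad
        # Shift: predict token t+1 from token t.
        batch_ids.append(ids[:-1])
        batch_labels.append(labels[1:])
    return batch_ids, batch_labels
```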

nemo_automodel.components.datasets.vlm.collate_fns.neat_packed_vlm_collater(
batch: list[dict],
padding_idx: int = 0,
max_length: int | None = None,
attn_implementation: str = 'sdpa',
) dict#

Collater for neat-packed VLM sequences.

Packs arrive with variable lengths (no pre-padding). This collater:

  1. Pads all text tensors to a common length.

  2. Converts the indexed attention_mask to the appropriate format:

    • flash_attention_2: keeps the indexed [B, S] mask (values 1, 2, … for documents, 0 for padding). The monkey-patched _get_unpad_data converts this to cu_seqlens for flash_attn_varlen_func.

    • sdpa / eager: converts to a 4D block-causal bool mask.

  3. Concatenates media tensors across the batch dimension.

No autoregressive shift — it was already applied during packing.

Parameters:
  • batch – List of packed sample dicts from PackedDatasetWrapper.

  • padding_idx – Token ID for padding input_ids (default 0).

  • max_length – If set, pad every batch to this fixed length. If None (default), pad to the longest pack in the batch. A fixed length avoids recompilation with torch.compile and ensures uniform tensor shapes across steps.

  • attn_implementation – Attention backend ("flash_attention_2", "sdpa", or "eager").

Returns:

Dict with batched tensors ready for model forward.
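The sdpa/eager conversion of the indexed mask can be sketched as a boolean predicate over position pairs (a sketch using nested lists; the real collater builds a 4D torch bool tensor):

```python
def block_causal_mask_sketch(doc_ids):
    """Sketch: doc_ids holds a per-position document index (1, 2, ... for
    documents, 0 for padding). Position i may attend to position j iff both
    belong to the same non-padding document and j <= i (causal)."""
    n = len(doc_ids)
    return [
        [doc_ids[i] != 0 and doc_ids[j] == doc_ids[i] and j <= i for j in range(n)]
        for i in range(n)
    ]
```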

nemo_automodel.components.datasets.vlm.collate_fns._GEMMA4_MODEL_TURN#

‘<|turn>model\n’

nemo_automodel.components.datasets.vlm.collate_fns._GEMMA4_THINKING_PREFIX = <Multiline-String>#
nemo_automodel.components.datasets.vlm.collate_fns._inject_thinking_prefix_tokens(
batch: Dict[str, torch.Tensor],
tokenizer,
) Dict[str, torch.Tensor]#

Insert <|channel>thought\n<channel|> tokens after every <|turn>model\n marker.

Gemma4 31B / 26B-A4B MoE instruction-tuned models always emit a thinking-channel prefix before the actual response. When this prefix is absent from training sequences, the model predicts <|channel> but the label says answer text, inflating the initial loss to ~9. Injecting the prefix (masked as -100 in labels) lets the model see its expected pattern and brings the initial loss down to ~3.

Modifies input_ids, attention_mask, and mm_token_type_ids (if present). Additionally, any other 2-D integer tensor whose second dimension matches input_ids is extended with zeros so that sequence lengths stay consistent (this is mostly future-proofing).
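The insertion itself can be sketched on a single id sequence, with `marker` and `prefix` standing in for the tokenized ``<|turn>model\n`` and thinking-prefix ids (a simplified sketch; the real function operates on batched tensors and extends the companion masks):

```python
def inject_prefix_sketch(input_ids, marker, prefix):
    """Sketch: insert prefix token ids immediately after every marker occurrence."""
    out, i = [], 0
    n = len(input_ids)
    while i < n:
        if input_ids[i:i + len(marker)] == marker:
            out.extend(marker)
            out.extend(prefix)      # injected tokens, masked -100 in labels
            i += len(marker)
        else:
            out.append(input_ids[i])
            i += 1
    return out
```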

nemo_automodel.components.datasets.vlm.collate_fns.gemma4_prefix_collate_fn(
examples: Sequence[Dict[str, Any]],
processor,
max_length: Optional[int] = None,
) Dict[str, torch.Tensor]#

Collate function for Gemma4 models with thinking-channel prefix.

Wraps default_collate_fn and injects <|channel>thought\n<channel|> after every <|turn>model\n marker before labels are built. The injected tokens are automatically masked to -100 by build_labels (which only unmasks tokens matching the assistant answer text), so the model sees its expected thinking prefix without being penalised for it.

nemo_automodel.components.datasets.vlm.collate_fns.COLLATE_FNS#

None