nemo_automodel.components.datasets.vlm.collate_fns#

Module Contents#

Functions#

make_robust_collate

Wrap collate_fn so that on failure the entire batch is re-sampled.

_find_pattern_indices

_extract_assistant_text

_decode_single_token

Decode a single token id across tokenizer implementations.

build_labels

Construct label and optional loss-mask tensors aligned to assistant responses.

_get_assistant_marker

Return the token-id sequence that introduces an assistant turn.

_get_stop_token_id

Return the token id of the turn-ending marker (<|im_end|>).

_derive_turn_markers

Derive the assistant-turn start marker and end-of-turn token id from the tokenizer’s own chat template.

_build_labels_from_markers

Scan input_ids for assistant_marker … stop_id and build labels.

build_labels_from_template

Build training labels by scanning input_ids for chat-template role markers.

phi4_mm_collate_fn

Collate function for Phi-4 MM model audio input.

_extract_media_from_conversations

Extract image and video inputs from conversation content elements.

_count_media_per_sample

Count images and videos per sample from conversation structure.

qwen2_5_collate_fn

Collate function for Qwen2.5 VL model.

qwen3_omni_collate_fn

Collate function for Qwen3 Omni processors.

_extract_audios_from_conversation

Walk a Qwen3-Omni-style conversation and collect audio payloads in order.

_validate_and_coerce_audio_payload

Coerce an audio payload to a 1-D float32 np.ndarray or raise.

_conversation_ends_with_assistant_text

Return True iff the last turn is an assistant turn with non-empty text content.

qwen3_omni_asr_collate_fn

Collate Qwen3-Omni ASR conversations into model inputs without qwen_omni_utils.

qwen2_5_omni_asr_collate_fn

Collate Qwen2.5-Omni ASR conversations.

kimi_vl_collate_fn

Collate function for KimiVL processors.

_expand_image_tokens

Expand image placeholder tokens to the correct patch counts based on grid_thws.

kimi_k25_vl_collate_fn

Collate function for Kimi K2.5 VL processors with pre-expanded image tokens.

nemotron_parse_collate_fn

Collate function for NVIDIA Nemotron-Parse models.

_ensure_rgb

Convert any PIL images in conversations to RGB to handle RGBA/grayscale inputs.

_extract_image_config

Extract image processing config from processor for token estimation.

_estimate_media_tokens

Estimate expanded media token count from image/video dimensions.

_drop_overlong_samples

Drop conversations whose estimated token count exceeds max_length.

default_collate_fn

Default collate function for multimodal VLM datasets.

pad_collate_fn

Collate function for pre-tokenized samples (from :class:PreTokenizedDatasetWrapper).

neat_packed_vlm_collater

Collater for neat-packed VLM sequences.

nemotron_omni_collate_fn

Collate for NemotronOmni (image / video / audio).

_inject_thinking_prefix_tokens

Insert <|channel>thought\n<channel|> tokens after every <|turn>model\n marker.

gemma4_inject_thinking_prefix

Inject Gemma4’s thinking-channel prefix after every assistant turn marker.

gemma4_prefix_collate_fn

Collate function for Gemma4 models with thinking-channel prefix.

llava_onevision_collate_fn

Collate function for LLaVA-OneVision-1.5 processors.

Data#

API#

nemo_automodel.components.datasets.vlm.collate_fns.logger#

‘getLogger(…)’

nemo_automodel.components.datasets.vlm.collate_fns._DEFAULT_MERGE_KERNEL: Tuple[int, int]#

(2, 2)

nemo_automodel.components.datasets.vlm.collate_fns.make_robust_collate(dataset, collate_fn, max_retries=10)#

Wrap collate_fn so that on failure the entire batch is re-sampled.

Parameters:
  • dataset – The dataset to re-sample from on failure.

  • collate_fn – The collate function to wrap.

  • max_retries – Maximum number of retry attempts.
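As a rough illustration of the retry-and-re-sample idea, the sketch below wraps a collate function and draws a fresh batch whenever it raises. It is not the library's implementation: the name make_robust_collate_sketch, the uniform random re-sampling, and the choice to re-raise after the last attempt are all assumptions.

```python
import random


def make_robust_collate_sketch(dataset, collate_fn, max_retries=10):
    """Hypothetical wrapper: on failure, retry with a freshly sampled batch."""

    def robust(batch):
        for attempt in range(max_retries + 1):
            try:
                return collate_fn(batch)
            except Exception:
                if attempt == max_retries:
                    raise  # assumption: surface the last error once retries run out
                # Assumption: replace the whole batch with uniformly sampled items
                # of the same batch size.
                indices = [random.randrange(len(dataset)) for _ in batch]
                batch = [dataset[i] for i in indices]

    return robust
```

The wrapped callable can be passed wherever the original collate function was used, for example as the collate_fn argument of a data loader.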

nemo_automodel.components.datasets.vlm.collate_fns._find_pattern_indices(
template,
pattern,
search_start_index=0,
allow_first_token_mismatch=False,
)#
nemo_automodel.components.datasets.vlm.collate_fns._extract_assistant_text(message: Dict[str, Any]) str#
nemo_automodel.components.datasets.vlm.collate_fns._decode_single_token(tokenizer, token_id: int) str#

Decode a single token id across tokenizer implementations.

Some tokenizers accept an int token id, while others require a sequence of ids (e.g., List[int]). We try the common forms in order.
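The "try the common forms in order" idea might look roughly like the following. This is a sketch under assumptions: the helper name, the two toy tokenizer classes, and the exception types that trigger the fallback are illustrative and not taken from the module.

```python
def decode_single_token_sketch(tokenizer, token_id):
    """Try decode(int) first, then decode([int]); raise if neither works."""
    for candidate in (token_id, [token_id]):
        try:
            text = tokenizer.decode(candidate)
        except (TypeError, ValueError):
            continue  # assumption: these errors mean "wrong argument form"
        if isinstance(text, str):
            return text
    raise TypeError(f"could not decode token id {token_id!r}")


class IntOnlyTokenizer:
    """Toy tokenizer that only accepts a bare int."""

    def decode(self, token_id):
        if not isinstance(token_id, int):
            raise TypeError("expected int")
        return f"<{token_id}>"


class ListOnlyTokenizer:
    """Toy tokenizer that only accepts a sequence of ids."""

    def decode(self, ids):
        if isinstance(ids, int):
            raise TypeError("expected a sequence of ids")
        return "".join(f"<{i}>" for i in ids)
```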

nemo_automodel.components.datasets.vlm.collate_fns.build_labels(
input_ids_batch: torch.Tensor,
conversations: Sequence[Sequence[Dict[str, Any]]],
processor,
) torch.Tensor#

Construct label and optional loss-mask tensors aligned to assistant responses.

nemo_automodel.components.datasets.vlm.collate_fns._get_assistant_marker(tokenizer) Optional[List[int]]#

Return the token-id sequence that introduces an assistant turn.

For Qwen-family models the marker is [<|im_start|>, assistant, \n]. Returns None when the tokenizer does not use this convention.

nemo_automodel.components.datasets.vlm.collate_fns._get_stop_token_id(tokenizer) Optional[int]#

Return the token id of the turn-ending marker (<|im_end|>).

nemo_automodel.components.datasets.vlm.collate_fns._IMSTART_TEMPLATE_PROCESSORS#

‘frozenset(…)’

nemo_automodel.components.datasets.vlm.collate_fns._derive_turn_markers(tokenizer) Tuple[List[int], int]#

Derive the assistant-turn start marker and end-of-turn token id from the tokenizer’s own chat template.

The function applies a minimal dummy conversation that contains a known sentinel string as the assistant reply, then locates the sentinel in the resulting token sequence. Everything between the end of the user turn and the start of the sentinel becomes the assistant marker; the first token after the sentinel becomes the end-of-turn id.

This approach is robust to BPE context-sensitivity and works for any model whose template wraps assistant turns with fixed token sequences — e.g. Gemma4’s <start_of_turn>model\n … <end_of_turn>.

Note: apply_chat_template may return a :class:transformers.BatchEncoding (a UserDict subclass, not a plain :class:dict), so isinstance(result, dict) is False. We access result["input_ids"] directly, which works for both BatchEncoding and plain dict / list.

Returns#

tuple[list[int], int] (assistant_marker, end_of_turn_id)

Raises#

ValueError If the sentinel cannot be located in the template output or if the resulting marker is empty.
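The sentinel technique described above can be sketched with a toy template that works on token ids directly. Everything here is hypothetical: the token ids, toy_chat_template, and the assumption that the user-only rendering is an exact prefix of the full rendering (used to find the end of the user turn) are for illustration and may differ from the real function.

```python
SENTINEL = [901, 902]  # hypothetical token ids of the sentinel assistant reply


def toy_chat_template(messages):
    """Toy template: <turn>=1, user=10, model=11, newline=2, <end>=3."""
    ids = []
    for message in messages:
        role_id = 10 if message["role"] == "user" else 11
        ids += [1, role_id, 2] + message["ids"] + [3]
    return ids


def find_sublist(haystack, needle):
    """Return the first index where needle occurs in haystack, or -1."""
    for i in range(len(haystack) - len(needle) + 1):
        if haystack[i:i + len(needle)] == needle:
            return i
    return -1


def derive_turn_markers_sketch(template):
    user_turn = {"role": "user", "ids": [50]}
    user_only = template([user_turn])
    full = template([user_turn, {"role": "assistant", "ids": SENTINEL}])
    pos = find_sublist(full, SENTINEL)
    if pos < 0 or pos + len(SENTINEL) >= len(full):
        raise ValueError("sentinel or end-of-turn token not found")
    if full[:len(user_only)] != user_only:
        raise ValueError("user turn is not a prefix of the full rendering")
    # Tokens between the end of the user turn and the sentinel form the marker.
    marker = full[len(user_only):pos]
    if not marker:
        raise ValueError("empty assistant marker")
    # The first token after the sentinel is taken as the end-of-turn id.
    return marker, full[pos + len(SENTINEL)]
```

With the toy template, the derived marker is the three-token sequence for the model turn header and the end-of-turn id is 3.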

nemo_automodel.components.datasets.vlm.collate_fns._build_labels_from_markers(
input_ids_batch: torch.Tensor,
assistant_marker: List[int],
stop_id: int,
) torch.Tensor#

Scan input_ids for assistant_marker … stop_id and build labels.

For each sequence in the batch, every token between the end of an assistant marker and the corresponding stop_id (inclusive) is copied into the labels tensor; all other positions are set to -100.

Parameters#

  • input_ids_batch – Shape (B, L).

  • assistant_marker – Token-id sequence that opens an assistant turn (e.g. [<|im_start|>, assistant_id, newline_id] for Qwen or [<start_of_turn>, model_id, newline_id] for Gemma4).

  • stop_id – Single token id that closes a turn (e.g. <|im_end|> or <end_of_turn>).
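A minimal sketch of the marker scan, using plain Python lists in place of a (B, L) tensor so it runs without torch. The function name is hypothetical, and the handling of an assistant turn with no closing stop_id (labels are copied through to the end of the sequence) is an assumption.

```python
IGNORE_INDEX = -100


def build_labels_from_markers_sketch(input_ids_batch, assistant_marker, stop_id):
    """Copy tokens after each assistant marker, up to and including stop_id."""
    marker_len = len(assistant_marker)
    labels_batch = []
    for ids in input_ids_batch:
        labels = [IGNORE_INDEX] * len(ids)
        i = 0
        while i + marker_len <= len(ids):
            if ids[i:i + marker_len] != assistant_marker:
                i += 1
                continue
            j = i + marker_len  # first token of the assistant content
            while j < len(ids):
                labels[j] = ids[j]
                j += 1
                if ids[j - 1] == stop_id:
                    break  # the stop token itself is kept as a label
            i = j
        labels_batch.append(labels)
    return labels_batch
```

For example, with marker [1, 11, 2] and stop id 3, only the assistant reply and its closing token receive labels; the user turn, the marker itself, and trailing padding stay at -100.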

nemo_automodel.components.datasets.vlm.collate_fns.build_labels_from_template(
input_ids_batch: torch.Tensor,
conversations: Sequence[Sequence[Dict[str, Any]]],
processor,
) torch.Tensor#

Build training labels by scanning input_ids for chat-template role markers.

Instead of re-tokenizing assistant text and searching for it (fragile due to BPE context sensitivity), this function locates the structural markers that the chat template inserts around each assistant turn and sets labels only for the content region.

Two strategies are attempted in order:

  1. Fast path (_IMSTART_TEMPLATE_PROCESSORS): for Qwen-family models whose tokenizers expose <|im_start|> / <|im_end|> via :func:convert_tokens_to_ids, the marker ids are resolved directly without applying any dummy conversation.

  2. General path (_derive_turn_markers): for all other processors (e.g. Gemma4), the assistant-turn markers are derived automatically by applying a minimal dummy conversation that contains a sentinel string. This handles models whose tokenizers do not reliably expose special-token ids via convert_tokens_to_ids or encode.

If both strategies fail, the function falls back to the legacy :func:build_labels (BPE pattern-matching), which logs a warning because it is sensitive to tokenisation context and may produce num_label_tokens=0 / nan loss on some samples.

nemo_automodel.components.datasets.vlm.collate_fns.phi4_mm_collate_fn(examples, processor)#

Collate function for Phi-4 MM model audio input.

nemo_automodel.components.datasets.vlm.collate_fns._extract_media_from_conversations(conversations)#

Extract image and video inputs from conversation content elements.

Images are returned as-is (PIL Image or path string) for the image processor. Videos are returned as path strings so the video processor can read and sample them using its own fps / max_frames configuration.

Returns:

(images list | None, videos list | None)

Return type:

tuple

nemo_automodel.components.datasets.vlm.collate_fns._count_media_per_sample(conversations)#

Count images and videos per sample from conversation structure.

Returns two lists of length len(conversations) giving the number of image and video items in each conversation, respectively.

nemo_automodel.components.datasets.vlm.collate_fns.qwen2_5_collate_fn(
examples: list,
processor,
) dict[str, torch.Tensor]#

Collate function for Qwen2.5 VL model.

nemo_automodel.components.datasets.vlm.collate_fns.qwen3_omni_collate_fn(
examples: Sequence[Dict[str, Any]],
processor,
use_audio_in_video: bool = False,
) Dict[str, torch.Tensor]#

Collate function for Qwen3 Omni processors.

nemo_automodel.components.datasets.vlm.collate_fns._extract_audios_from_conversation(
conversation: Sequence[Dict[str, Any]],
) List[Any]#

Walk a Qwen3-Omni-style conversation and collect audio payloads in order.

The returned list contains the raw audio objects (typically 1-D np.ndarray waveforms) attached to {"type": "audio", "audio": ...} items in any message’s content list. Used by :func:qwen3_omni_asr_collate_fn to feed the processor’s audio= kwarg without going through qwen_omni_utils.
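The walk over content items might be sketched as follows. The item shape {"type": "audio", "audio": ...} comes from the description above; the helper name and the decision to skip messages whose content is a plain string are assumptions.

```python
def extract_audios_sketch(conversation):
    """Collect audio payloads from every message's content list, in order."""
    audios = []
    for message in conversation:
        content = message.get("content")
        if not isinstance(content, list):
            continue  # assumption: plain-string content carries no audio items
        for item in content:
            if isinstance(item, dict) and item.get("type") == "audio" and "audio" in item:
                audios.append(item["audio"])
    return audios
```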

nemo_automodel.components.datasets.vlm.collate_fns._validate_and_coerce_audio_payload(
payload: Any,
sample_index: int,
) numpy.ndarray#

Coerce an audio payload to a 1-D float32 np.ndarray or raise.

The single rule:

  • Convert any numeric np.ndarray / torch.Tensor to np.float32.

  • The result must be exactly 1-D after conversion (mono waveform).

  • Anything else raises ValueError naming the sample index, observed shape, and observed dtype so the caller can pinpoint the bad sample.

Parameters:
  • payload – Audio object pulled from a conversation content item.

  • sample_index – Index of the offending sample within the batch (for error messages).

Returns:

A 1-D np.float32 np.ndarray.

Raises:

ValueError – When the payload is not a numeric array or is not 1-D.

nemo_automodel.components.datasets.vlm.collate_fns._conversation_ends_with_assistant_text(
conversation: Sequence[Dict[str, Any]],
) bool#

Return True iff the last turn is an assistant turn with non-empty text content.
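One way such a check could be written is shown below. It is a sketch, not the module's code: the helper name, the accepted content shapes (a plain string or a list of {"type": "text", "text": ...} items), and treating whitespace-only text as empty are assumptions.

```python
def conversation_ends_with_assistant_text_sketch(conversation):
    """True only if the final turn is an assistant turn carrying some text."""
    if not conversation:
        return False
    last = conversation[-1]
    if last.get("role") != "assistant":
        return False
    content = last.get("content")
    if isinstance(content, str):
        return bool(content.strip())
    if isinstance(content, list):
        return any(
            isinstance(item, dict)
            and item.get("type") == "text"
            and bool(str(item.get("text", "")).strip())
            for item in content
        )
    return False
```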

nemo_automodel.components.datasets.vlm.collate_fns.qwen3_omni_asr_collate_fn(
examples: Sequence[Dict[str, Any]],
processor,
) Dict[str, torch.Tensor]#

Collate Qwen3-Omni ASR conversations into model inputs without qwen_omni_utils.

Unlike :func:qwen3_omni_collate_fn, this collate is intended for environments that lack qwen_omni_utils and torchcodec. It assumes audio waveforms are already attached to the conversation as 1-D np.ndarray items of the form {"type": "audio", "audio": waveform} (see :func:nemo_automodel.components.datasets.vlm.datasets.make_hf_audio_asr_dataset) and passes them directly to the processor’s audio= kwarg, which routes to the bundled WhisperFeatureExtractor.

Label masking is delegated to :func:build_labels_from_template, which uses the marker-based fast path that already supports Qwen3OmniMoeProcessor via _IMSTART_TEMPLATE_PROCESSORS. The collate produces pre-shifted labels (labels[:, 1:]) and slices same-shape tensors to [:, :-1] so the downstream loss (MaskedCrossEntropy/FusedLinearCrossEntropy) consumes them without a second internal shift.

Parameters:
  • examples – Iterable of dicts each containing a conversation key, where the last turn MUST be an assistant turn with non-empty text.

  • processor – A Qwen3OmniMoeProcessor instance (or compatible mock).

Returns:

Dict with input_ids, attention_mask, input_features, feature_attention_mask, and labels plus any other tensors the processor returns, all aligned along the batch dimension.

Raises:

ValueError – If any conversation lacks a non-empty assistant turn at the end (the marker-based labeler would otherwise produce all -100 labels and a NaN loss).

nemo_automodel.components.datasets.vlm.collate_fns.qwen2_5_omni_asr_collate_fn(
examples: Sequence[Dict[str, Any]],
processor,
) Dict[str, torch.Tensor]#

Collate Qwen2.5-Omni ASR conversations.

Thin alias over :func:qwen3_omni_asr_collate_fn: the body is processor-agnostic (it only depends on the processor exposing apply_chat_template and the audio= kwarg, both of which Qwen2_5OmniProcessor provides), so the entire Qwen3-Omni-ASR path works unchanged here. We expose a separate symbol so YAML configs can pick the right collate via _target_ without users having to know about the Qwen3-Omni name.

nemo_automodel.components.datasets.vlm.collate_fns.kimi_vl_collate_fn(
examples: Sequence[Dict[str, Any]],
processor,
max_length: Optional[int] = None,
) Dict[str, torch.Tensor]#

Collate function for KimiVL processors.

nemo_automodel.components.datasets.vlm.collate_fns._expand_image_tokens(
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
grid_thws: torch.Tensor,
media_token_id: int,
merge_kernel_size: Tuple[int, int] = _DEFAULT_MERGE_KERNEL,
) Tuple[torch.Tensor, torch.Tensor]#

Expand image placeholder tokens to the correct patch counts based on grid_thws.

For PP, this ensures the sequence length is fixed BEFORE the model forward pass, eliminating dynamic sequence expansion inside the model.

Supports both single-image and multi-image samples. Each placeholder token is expanded to (h // merge_h) * (w // merge_w) tokens using the corresponding entry in grid_thws. The number of placeholders must match grid_thws.shape[0].

Parameters:
  • input_ids – (seq_len,) tensor containing one media_token_id per image.

  • attention_mask – (seq_len,) tensor aligned with input_ids.

  • grid_thws – (N, 3) tensor with [t, h, w] for each of the N images.

  • media_token_id – Token ID used as the image placeholder.

  • merge_kernel_size – Vision tower’s patch merge kernel, default (2, 2).

Returns:

A pair (expanded_input_ids, expanded_attention_mask): the input IDs with each placeholder expanded to its patch count, and the attention mask expanded accordingly.

Return type:

Tuple[torch.Tensor, torch.Tensor]

Raises:

ValueError – When the number of placeholders does not match grid_thws.shape[0].
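The expansion rule (h // merge_h) * (w // merge_w) can be illustrated with plain lists standing in for the tensors. This is a sketch only: the function name is hypothetical, and copying the placeholder's own attention-mask value onto the expanded positions is an assumption.

```python
def expand_image_tokens_sketch(input_ids, attention_mask, grid_thws,
                               media_token_id, merge_kernel_size=(2, 2)):
    """Replace each placeholder with (h // merge_h) * (w // merge_w) copies."""
    merge_h, merge_w = merge_kernel_size
    n_placeholders = sum(1 for token in input_ids if token == media_token_id)
    if n_placeholders != len(grid_thws):
        raise ValueError(
            f"{n_placeholders} placeholders but {len(grid_thws)} grid entries"
        )
    out_ids, out_mask = [], []
    image_index = 0
    for token, mask in zip(input_ids, attention_mask):
        if token == media_token_id:
            _, h, w = grid_thws[image_index]
            n_tokens = (h // merge_h) * (w // merge_w)
            out_ids.extend([media_token_id] * n_tokens)
            out_mask.extend([mask] * n_tokens)  # assumption: reuse placeholder mask
            image_index += 1
        else:
            out_ids.append(token)
            out_mask.append(mask)
    return out_ids, out_mask
```

With grids [1, 4, 4] and [1, 2, 6] and the default (2, 2) kernel, the two placeholders grow to 4 and 3 tokens respectively, so a 4-token input becomes 9 tokens.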

nemo_automodel.components.datasets.vlm.collate_fns.kimi_k25_vl_collate_fn(
examples: Sequence[Dict[str, Any]],
processor,
max_length: Optional[int] = None,
drop_overlong: bool = False,
) Dict[str, torch.Tensor]#

Collate function for Kimi K2.5 VL processors with pre-expanded image tokens.

For pipeline parallelism, this function:

  1. Processes each sample to get input_ids with 1 placeholder per image

  2. Pre-expands the placeholder to N tokens (N = (h//2)*(w//2) from grid_thws)

  3. Pads all sequences to fixed max_length

This ensures the model forward pass doesn’t change sequence length dynamically.

nemo_automodel.components.datasets.vlm.collate_fns.nemotron_parse_collate_fn(
examples: Sequence[Dict[str, Any]],
processor,
task_prompt: str = '</s><s><predict_bbox><predict_classes><output_markdown>',
) Dict[str, torch.Tensor]#

Collate function for NVIDIA Nemotron-Parse models.

The Nemotron-Parse processor does not expose a chat template, so we build the prompt + answer string manually, mask the prompt tokens, and keep the image preprocessing handled by the processor.

nemo_automodel.components.datasets.vlm.collate_fns._ensure_rgb(conversations)#

Convert any PIL images in conversations to RGB to handle RGBA/grayscale inputs.

nemo_automodel.components.datasets.vlm.collate_fns._extract_image_config(processor)#

Extract image processing config from processor for token estimation.

nemo_automodel.components.datasets.vlm.collate_fns._estimate_media_tokens(conversation, processor)#

Estimate expanded media token count from image/video dimensions.

Returns total extra tokens beyond the single-placeholder-per-media count that tokenization produces. Only images with known dimensions (PIL Image objects or loadable paths) are estimated; unknown media items contribute 0 extra tokens (the placeholder is still counted in the base tokenization).

nemo_automodel.components.datasets.vlm.collate_fns._drop_overlong_samples(conversations, processor, max_length)#

Drop conversations whose estimated token count exceeds max_length.

Returns (filtered_conversations, kept_indices) where kept_indices are the original positions that survived filtering. Raises ValueError when every sample in the batch is dropped (caught by robust_collate which re-samples).
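The filtering contract can be sketched as below. The estimate_tokens callable is a hypothetical stand-in for the real token estimate (base tokenization plus estimated media tokens), and the function name is likewise an assumption.

```python
def drop_overlong_samples_sketch(conversations, estimate_tokens, max_length):
    """Keep conversations whose estimated length is within max_length."""
    kept_conversations, kept_indices = [], []
    for index, conversation in enumerate(conversations):
        if estimate_tokens(conversation) <= max_length:
            kept_conversations.append(conversation)
            kept_indices.append(index)
    if not kept_indices:
        # Raising lets a retrying wrapper re-sample the whole batch.
        raise ValueError("every sample in the batch exceeds max_length")
    return kept_conversations, kept_indices
```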

nemo_automodel.components.datasets.vlm.collate_fns.default_collate_fn(
examples: Sequence[Dict[str, Any]],
processor,
max_length: Optional[int] = None,
drop_overlong: bool = False,
_post_tokenize_hook=None,
) Dict[str, torch.Tensor]#

Default collate function for multimodal VLM datasets.

Parameters:
  • _post_tokenize_hook – Optional callable (batch, processor) -> batch invoked right after apply_chat_template and before build_labels. Used by model-specific collate wrappers (e.g. Gemma4 thinking-channel injection) to transform the tokenized batch and the prefix tokens without duplicating the rest of the pipeline.

nemo_automodel.components.datasets.vlm.collate_fns.pad_collate_fn(
examples: Sequence[Dict[str, Any]],
processor,
max_length: Optional[int] = None,
) Dict[str, torch.Tensor]#

Collate function for pre-tokenized samples (from :class:PreTokenizedDatasetWrapper).

Each example is expected to carry at least input_ids, attention_mask, and labels as 1-D tensors, plus optional media tensors (pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw).

Fake image injection and vision-token masking are handled per-sample in :class:PreTokenizedDatasetWrapper.__getitem__, so this function only pads, stacks, and concatenates.

The function:

  1. Pads all sequence tensors to the same length (either max_length or the longest sequence in the batch).

  2. Concatenates media tensors across the batch.

  3. Applies the standard autoregressive shift (labels = labels[:, 1:], inputs truncated by one token).
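Steps 1 and 3 above (padding and the autoregressive shift) can be sketched with plain lists in place of tensors. The padding values (pad_token_id for input_ids, 0 for attention_mask, -100 for labels) and the function name are assumptions, and the sketch assumes no sequence is longer than the target length; media concatenation is omitted.

```python
def pad_and_shift_sketch(examples, pad_token_id=0, max_length=None):
    """Right-pad to a common length, then shift labels left by one position."""
    target = max_length or max(len(ex["input_ids"]) for ex in examples)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for ex in examples:
        pad = target - len(ex["input_ids"])  # assumed to be >= 0
        ids = ex["input_ids"] + [pad_token_id] * pad
        mask = ex["attention_mask"] + [0] * pad
        labels = ex["labels"] + [-100] * pad
        # Position t of the inputs is paired with the label of token t + 1.
        batch["input_ids"].append(ids[:-1])
        batch["attention_mask"].append(mask[:-1])
        batch["labels"].append(labels[1:])
    return batch
```

Because the shift happens here, a downstream loss that consumes these tensors should not shift the labels again.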

nemo_automodel.components.datasets.vlm.collate_fns.neat_packed_vlm_collater(
batch: list[dict],
padding_idx: int = 0,
max_length: int | None = None,
attn_implementation: str = 'sdpa',
) dict#

Collater for neat-packed VLM sequences.

Packs arrive with variable lengths (no pre-padding). This collater:

  1. Pads all text tensors to a common length.

  2. Converts the indexed attention_mask to the appropriate format:

    • flash_attention_2: keeps the indexed [B, S] mask (values 1, 2, … for documents, 0 for padding). The monkey-patched _get_unpad_data converts this to cu_seqlens for flash_attn_varlen_func.

    • sdpa / eager: converts to a 4D block-causal bool mask.

  3. Concatenates media tensors across the batch dimension.

No autoregressive shift — it was already applied during packing.

Parameters:
  • batch – List of packed sample dicts from PackedDatasetWrapper.

  • padding_idx – Token ID for padding input_ids (default 0).

  • max_length – If set, pad every batch to this fixed length. If None (default), pad to the longest pack in the batch. A fixed length avoids recompilation with torch.compile and ensures uniform tensor shapes across steps.

  • attn_implementation – Attention backend ("flash_attention_2", "sdpa", or "eager").

Returns:

Dict with batched tensors ready for model forward.
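To make the sdpa / eager conversion concrete, the sketch below turns one pack's indexed mask into a 2-D block-causal boolean mask, where True means "may attend". It covers a single [S, S] slice rather than the full 4D tensor, uses nested lists instead of torch, and leaves padding rows entirely False; how the real collater treats padding rows is not stated above, so that part is an assumption.

```python
def block_causal_mask_sketch(indexed_mask):
    """indexed_mask holds 1, 2, ... per document and 0 for padding."""
    size = len(indexed_mask)
    return [
        [
            indexed_mask[q] != 0                      # padding attends to nothing
            and indexed_mask[q] == indexed_mask[k]    # same packed document only
            and k <= q                                # causal within the document
            for k in range(size)
        ]
        for q in range(size)
    ]
```

For the indexed mask [1, 1, 2, 0], the second token can see the first, the third token (a new document) sees only itself, and the padding position sees nothing.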

nemo_automodel.components.datasets.vlm.collate_fns.nemotron_omni_collate_fn(
examples: Sequence[Dict[str, Any]],
processor,
max_length: Optional[int] = None,
max_video_frames: int = 8,
) Dict[str, torch.Tensor]#

Collate for NemotronOmni (image / video / audio).

Defers <image>/<video>/<audio> placeholder expansion to the processor; the collate only gathers media into processor kwargs, pads, builds labels, and stacks tensors.

nemo_automodel.components.datasets.vlm.collate_fns._GEMMA4_MODEL_TURN#

‘<|turn>model\n’

nemo_automodel.components.datasets.vlm.collate_fns._GEMMA4_THINKING_PREFIX = <Multiline-String>#
nemo_automodel.components.datasets.vlm.collate_fns._inject_thinking_prefix_tokens(
batch: Dict[str, torch.Tensor],
tokenizer,
) Dict[str, torch.Tensor]#

Insert <|channel>thought\n<channel|> tokens after every <|turn>model\n marker.

Modifies input_ids, attention_mask, and mm_token_type_ids (if present). Additionally, any other 2-D integer tensor whose second dimension matches input_ids is extended with zeros so that sequence lengths stay consistent (this is mostly future-proofing).
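The core insertion step can be sketched for a single sequence using plain lists. The token ids in the usage note are made up, the function name is hypothetical, and giving the inserted tokens an attention-mask value of 1 is an assumption; re-padding a batch whose rows gain different numbers of tokens is not shown.

```python
def inject_after_marker_sketch(input_ids, attention_mask, marker_ids, prefix_ids):
    """Insert prefix_ids after every occurrence of marker_ids."""
    out_ids, out_mask = [], []
    marker_len = len(marker_ids)
    i = 0
    while i < len(input_ids):
        if marker_len and input_ids[i:i + marker_len] == marker_ids:
            out_ids.extend(marker_ids + prefix_ids)
            out_mask.extend(attention_mask[i:i + marker_len])
            out_mask.extend([1] * len(prefix_ids))  # assumption: prefix is attended
            i += marker_len
        else:
            out_ids.append(input_ids[i])
            out_mask.append(attention_mask[i])
            i += 1
    return out_ids, out_mask
```

For instance, with marker [1, 2] and prefix [30, 31], the sequence [7, 1, 2, 8, 1, 2, 9] grows by two tokens per marker occurrence.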

nemo_automodel.components.datasets.vlm.collate_fns.gemma4_inject_thinking_prefix(
batch: Dict[str, torch.Tensor],
processor,
) Dict[str, torch.Tensor]#

Inject Gemma4’s thinking-channel prefix after every assistant turn marker.

Gemma4 31B / 26B-A4B MoE instruction-tuned models always emit a thinking-channel prefix before the actual response. When this prefix is absent from training sequences the model predicts <|channel> but the label says answer text, inflating initial loss to ~9. Injecting the prefix (masked as -100 in labels) lets the model see its expected pattern and brings initial loss down to ~3.

Safe no-op for non-Gemma4 tokenizers.

nemo_automodel.components.datasets.vlm.collate_fns.gemma4_prefix_collate_fn(
examples: Sequence[Dict[str, Any]],
processor,
max_length: Optional[int] = None,
) Dict[str, torch.Tensor]#

Collate function for Gemma4 models with thinking-channel prefix.

Wraps default_collate_fn and injects <|channel>thought\n<channel|> after every <|turn>model\n marker before labels are built. The injected tokens are automatically masked to -100 by build_labels_from_template (which only unmasks tokens inside assistant turns), so the model sees its expected thinking prefix without being penalised for it.

nemo_automodel.components.datasets.vlm.collate_fns.llava_onevision_collate_fn(
examples: Sequence[Dict[str, Any]],
processor,
) Dict[str, torch.Tensor]#

Collate function for LLaVA-OneVision-1.5 processors.

Handles image and video inputs using the Qwen-style chat template and qwen_vl_utils for vision info processing.

nemo_automodel.components.datasets.vlm.collate_fns.COLLATE_FNS#

None