# Source code for nemo_automodel.components.datasets.vlm.collate_fns

# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock

import torch

from nemo_automodel.components.datasets.vlm.utils import extract_skipped_token_ids
from nemo_automodel.shared.import_utils import MISSING_QWEN_VL_UTILS_MSG

# qwen_vl_utils is an optional dependency: record whether it could be imported so the
# collate functions in this module can raise ImportError(MISSING_QWEN_VL_UTILS_MSG)
# when it is missing.
try:
    from qwen_vl_utils import process_vision_info

    HAVE_QWEN_VL_UTILS = True
except ImportError:
    HAVE_QWEN_VL_UTILS = False
    # Placeholder binding so the name still exists at module level without the package.
    process_vision_info = MagicMock()


def create_loss_mask_with_start_of_response_token(input_ids, processor, start_of_response_token=None):
    r"""
    Create loss mask by finding start of turn token positions, similar to squad.py approach.

    Only the first token id of the tokenized ``start_of_response_token`` is used as the
    marker. When that id occurs at least twice in ``input_ids``, every position before its
    second occurrence is masked out. Otherwise nothing is masked by position. Positions
    holding the pad token id (``0`` when the tokenizer has none) are always masked out,
    except when ``start_of_response_token`` is ``None``: then an all-ones mask is returned
    and padding is left unmasked.

    Args:
        input_ids: List or tensor of token IDs for a single example. Anything exposing
            ``tolist()`` is converted through it; other iterables are copied into a list.
        processor: Processor/tokenizer to convert token string to ID. Its ``tokenizer``
            attribute is used when present, otherwise the object itself.
        start_of_response_token: String token that marks the start of turns
            (e.g., "<start_of_turn>model\n")

    Returns:
        loss_mask: List of 0/1 flags where 0 = masked (prompt), 1 = unmasked (response)
    """
    tokenizer = getattr(processor, "tokenizer", processor)

    # The contract allows plain lists as well as tensors; only call tolist() when it exists.
    if hasattr(input_ids, "tolist"):
        input_ids = input_ids.tolist()
    else:
        input_ids = list(input_ids)

    if start_of_response_token is None:
        return [1] * len(input_ids)

    response_start = 0
    if isinstance(start_of_response_token, str):
        marker_ids = tokenizer(start_of_response_token, add_special_tokens=False)["input_ids"]
        # An empty tokenization gives no marker to search for; fall back to masking nothing.
        if marker_ids:
            marker_id = marker_ids[0]
            if input_ids.count(marker_id) >= 2:
                first_marker = input_ids.index(marker_id)
                response_start = input_ids.index(marker_id, first_marker + 1)

    pad_token_id = getattr(tokenizer, "pad_token_id", 0)
    if pad_token_id is None:
        pad_token_id = 0

    return [
        0 if position < response_start or token_id == pad_token_id else 1
        for position, token_id in enumerate(input_ids)
    ]
def phi4_mm_collate_fn(examples, processor):
    """Collate function for Phi-4 MM model audio input.

    Renders every conversation through ``processor.apply_chat_template``, runs the processor
    over the rendered texts and audio clips (``padding=True``, ``truncation=True``,
    ``max_length=1024``), and attaches ``labels`` and ``loss_mask`` to the resulting batch.

    Args:
        examples: Samples providing a ``"conversation"`` and an ``"audio"`` entry. A dict
            audio entry is turned into an ``(array, sampling_rate)`` pair; anything else is
            passed to the processor unchanged.
        processor: Object exposing ``apply_chat_template``, ``__call__`` and ``tokenizer``.

    Returns:
        The processor's batch with ``loss_mask`` and ``labels`` added and the keys
        ``input_image_embeds``, ``image_sizes`` and ``image_attention_mask`` removed.
    """
    conversations = [sample["conversation"] for sample in examples]
    raw_audios = [sample["audio"] for sample in examples]

    prompts = [processor.apply_chat_template(turns, tokenize=False) for turns in conversations]

    audio_inputs = []
    for clip in raw_audios:
        if isinstance(clip, dict):
            clip = (clip["array"], clip["sampling_rate"])
        audio_inputs.append(clip)

    batch = processor(
        text=prompts,
        audios=audio_inputs,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024,
    )

    # Next-token targets: drop the first column and close each row with an ignore index.
    shifted = batch["input_ids"].clone()[:, 1:]
    labels = torch.cat([shifted, torch.full_like(shifted[:, :1], -100)], dim=1)

    loss_masks = []
    for row, turns in enumerate(conversations):
        row_ids = batch["input_ids"][row].tolist()
        # The second message of the conversation is treated as the assistant reply.
        reply_ids = processor.tokenizer(turns[1]["content"], add_special_tokens=False)["input_ids"]
        reply_len = len(reply_ids)

        row_mask = [0] * len(row_ids)
        # Only the first place where the reply tokens appear verbatim is unmasked.
        match_at = next(
            (
                start
                for start in range(len(row_ids) - reply_len + 1)
                if row_ids[start : start + reply_len] == reply_ids
            ),
            None,
        )
        if match_at is not None:
            row_mask[match_at : match_at + reply_len] = [1] * reply_len
        loss_masks.append(row_mask)

    widest = max(len(row_mask) for row_mask in loss_masks)
    padded_loss_masks = [row_mask + [0] * (widest - len(row_mask)) for row_mask in loss_masks]
    batch["loss_mask"] = torch.tensor(padded_loss_masks, dtype=torch.float)

    labels[batch["loss_mask"] == 0] = -100
    batch["labels"] = labels

    # Remove specified batch features if present
    for unwanted in ("input_image_embeds", "image_sizes", "image_attention_mask"):
        if unwanted in batch:
            del batch[unwanted]

    return batch
def qwen2_5_collate_fn(
    examples: list, processor, start_of_response_token="<|im_start|>assistant\n"
) -> dict[str, torch.Tensor]:
    """Collate function for Qwen2.5 VL model.

    Renders each conversation to text with the processor's chat template, extracts the image
    inputs with ``process_vision_info``, and runs the processor over both with padding. The
    batch is then extended with next-token ``labels`` and a per-token ``loss_mask``.

    Args:
        examples: Samples, each providing a ``"conversation"`` entry.
        processor: Processor used for templating and tokenization; also handed to
            ``extract_skipped_token_ids`` and the loss-mask helper.
        start_of_response_token: Marker string forwarded to
            ``create_loss_mask_with_start_of_response_token``.

    Returns:
        The processor's batch with ``labels`` and ``loss_mask`` added.

    Raises:
        ImportError: If ``qwen_vl_utils`` could not be imported.
    """
    if not HAVE_QWEN_VL_UTILS:
        raise ImportError(MISSING_QWEN_VL_UTILS_MSG)

    skipped_tokens = extract_skipped_token_ids(processor)

    prompts = [processor.apply_chat_template(sample["conversation"], tokenize=False) for sample in examples]
    # process_vision_info returns a sequence; only its first element (the images) is used.
    images = [process_vision_info(sample["conversation"])[0] for sample in examples]

    batch = processor(
        text=prompts,
        images=images,
        padding=True,
        return_tensors="pt",
    )

    # Next-token targets: drop the first column and close each row with an ignore index.
    shifted = batch["input_ids"].clone()[:, 1:]
    labels = torch.cat([shifted, torch.full_like(shifted[:, :1], -100)], dim=1)
    ignored = torch.isin(labels, skipped_tokens)
    labels[ignored] = -100
    batch["labels"] = labels

    per_row_masks = []
    for row_ids in batch["input_ids"]:
        per_row_masks.append(
            create_loss_mask_with_start_of_response_token(row_ids, processor, start_of_response_token)
        )
    batch["loss_mask"] = torch.tensor(per_row_masks, dtype=torch.float, device=batch["input_ids"].device)

    return batch
def default_collate_fn(examples: list, processor, start_of_response_token=None) -> dict[str, torch.Tensor]:
    """Default collate function for VLM models.

    Tokenizes the conversations of all examples with a single
    ``processor.apply_chat_template`` call, then adds ``position_ids`` (only when the
    processor did not return them), next-token ``labels`` and a per-token ``loss_mask``.

    Args:
        examples: Samples, each providing a ``"conversation"`` entry.
        processor: Processor exposing ``apply_chat_template``; also handed to
            ``extract_skipped_token_ids`` and the loss-mask helper.
        start_of_response_token: Marker string forwarded to
            ``create_loss_mask_with_start_of_response_token``.

    Returns:
        The processor's batch with ``position_ids``, ``labels`` and ``loss_mask`` present.
        ``pixel_values`` is cast to ``torch.bfloat16`` when the batch contains it.

    Raises:
        ImportError: If ``qwen_vl_utils`` could not be imported.
    """
    if not HAVE_QWEN_VL_UTILS:
        raise ImportError(MISSING_QWEN_VL_UTILS_MSG)

    skipped_tokens = extract_skipped_token_ids(processor)

    batch = processor.apply_chat_template(
        [example["conversation"] for example in examples],
        tokenize=True,
        add_generation_prompt=False,
        return_tensors="pt",
        return_dict=True,
    )

    if "position_ids" not in batch:
        batch_size, seq_len = batch["input_ids"].shape
        batch["position_ids"] = (
            torch.arange(seq_len, device=batch["input_ids"].device).unsqueeze(0).expand(batch_size, -1)
        )

    # A batch without images has no "pixel_values" entry; skip the cast instead of
    # failing with KeyError.
    if "pixel_values" in batch:
        batch["pixel_values"] = batch["pixel_values"].to(torch.bfloat16)

    # Next-token targets: drop the first column and close each row with an ignore index.
    labels = batch["input_ids"].clone()[:, 1:]
    labels = torch.cat([labels, -100 * torch.ones_like(labels[:, :1])], dim=1)
    labels[torch.isin(labels, skipped_tokens)] = -100
    batch["labels"] = labels

    loss_masks = [
        create_loss_mask_with_start_of_response_token(input_ids, processor, start_of_response_token)
        for input_ids in batch["input_ids"]
    ]
    batch["loss_mask"] = torch.tensor(loss_masks, dtype=torch.float, device=batch["input_ids"].device)

    return batch
# Mapping of processor types to their collate functions.
# Keys are processor class names; "default" names the generic collate function.
COLLATE_FNS = {
    "Qwen2_5_VLProcessor": qwen2_5_collate_fn,
    "default": default_collate_fn,
}