Source code for nemo_automodel.datasets.vlm.collate_fns
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock

import torch

from nemo_automodel.datasets.vlm.utils import extract_skipped_token_ids
from nemo_automodel.shared.import_utils import MISSING_QWEN_VL_UTILS_MSG
try:
    from qwen_vl_utils import process_vision_info

    HAVE_QWEN_VL_UTILS = True
except ImportError:
    HAVE_QWEN_VL_UTILS = False
    # Fallback stub so the name resolves when qwen_vl_utils is not installed;
    # the collate functions raise ImportError(MISSING_QWEN_VL_UTILS_MSG) before it is used.
    process_vision_info = MagicMock()
def qwen2_5_collate_fn(examples: list, processor) -> dict[str, torch.Tensor]:
    """Collate function for Qwen2.5 VL models."""
    if not HAVE_QWEN_VL_UTILS:
        raise ImportError(MISSING_QWEN_VL_UTILS_MSG)
    skipped_tokens = extract_skipped_token_ids(processor)

    # Render each conversation to text and extract its image inputs.
    texts = [processor.apply_chat_template(example["conversation"], tokenize=False) for example in examples]
    image_inputs = [process_vision_info(example["conversation"])[0] for example in examples]

    batch = processor(
        text=texts,
        images=image_inputs,
        padding=True,
        return_tensors="pt",
    )

    # Build next-token labels: shift input_ids left by one position, pad the
    # final step with the ignore index -100, and mask special/vision tokens so
    # they do not contribute to the loss.
    labels = batch["input_ids"].clone()[:, 1:]
    labels = torch.cat([labels, -100 * torch.ones_like(labels[:, :1])], dim=1)
    labels[torch.isin(labels, skipped_tokens)] = -100
    batch["labels"] = labels
    return batch
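
# Illustrative sketch (not part of the original module) of the label construction
# above, using hypothetical token ids where 3 is a skipped pad token:
#
#   input_ids = [[5, 7, 9, 3]]
#   shifted   = [[7, 9, 3, -100]]     # drop position 0, append the ignore index
#   labels    = [[7, 9, -100, -100]]  # skipped tokens are also masked to -100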
def default_collate_fn(examples: list, processor) -> dict[str, torch.Tensor]:
    """Default collate function for VLM models."""
    if not HAVE_QWEN_VL_UTILS:
        raise ImportError(MISSING_QWEN_VL_UTILS_MSG)
    skipped_tokens = extract_skipped_token_ids(processor)

    # Tokenize the conversations directly through the processor's chat template.
    batch = processor.apply_chat_template(
        [example["conversation"] for example in examples],
        tokenize=True,
        add_generation_prompt=False,
        return_tensors="pt",
        return_dict=True,
    )
    # Cast image features to bfloat16.
    batch["pixel_values"] = batch["pixel_values"].to(torch.bfloat16)

    # Same next-token label construction as in qwen2_5_collate_fn: shift left,
    # pad with -100, and mask special/vision tokens.
    labels = batch["input_ids"].clone()[:, 1:]
    labels = torch.cat([labels, -100 * torch.ones_like(labels[:, :1])], dim=1)
    labels[torch.isin(labels, skipped_tokens)] = -100
    batch["labels"] = labels
    return batch
# Mapping of processor types to their collate functions
COLLATE_FNS = {
    "Qwen2_5_VLProcessor": qwen2_5_collate_fn,
    "default": default_collate_fn,
}
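
# Example usage (illustrative only; "dataset" stands in for any map-style dataset
# whose examples carry a "conversation" field, and the checkpoint name is an
# assumption): look up the collate function by the processor's class name,
# falling back to the default, then hand it to a torch DataLoader.
#
#   from functools import partial
#   from torch.utils.data import DataLoader
#   from transformers import AutoProcessor
#
#   processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
#   collate_fn = COLLATE_FNS.get(type(processor).__name__, COLLATE_FNS["default"])
#   loader = DataLoader(dataset, batch_size=2, collate_fn=partial(collate_fn, processor=processor))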