# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Collation utilities - Batch multiple (TensorDict, metadata) tuples.
Collators combine multiple (TensorDict, dict) tuples from Dataset into a single
batched output suitable for model consumption. By default, returns just the
batched TensorDict for PyTorch DataLoader compatibility. When collate_metadata=True,
returns a tuple of (TensorDict, list[dict]).
The default collator stacks TensorDicts along batch dimension using TensorDict.stack().
Metadata collation is optional and disabled by default.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Sequence
import torch
from tensordict import TensorDict
def _collate_metadata(metadata_list: Sequence[dict[str, Any]]) -> list[dict[str, Any]]:
    """
    Gather per-sample metadata into a fresh list.

    No merging or transformation happens: the result holds the very same
    dict objects, in the same order, one entry per sample of the batch.

    Parameters
    ----------
    metadata_list : Sequence[dict[str, Any]]
        Metadata dicts, one per sample.

    Returns
    -------
    list[dict[str, Any]]
        New list containing the given metadata dicts (shallow copy).
    """
    # Unpacking into a list literal makes a shallow copy of the sequence.
    return [*metadata_list]
[docs]
class Collator(ABC):
    """
    Interface shared by all collators.

    A collator is a callable that receives a sequence of
    (TensorDict, metadata dict) samples and merges them into one batch.
    Implementations return either the batched TensorDict alone (the
    DataLoader-friendly form) or a (TensorDict, list[dict]) tuple when they
    also hand back the per-sample metadata.

    Subclasses must override ``__call__``; the base class cannot be
    instantiated on its own.

    Examples
    --------
    >>> class MyCollator(Collator):
    ...     def __call__(
    ...         self,
    ...         samples: Sequence[tuple[TensorDict, dict]]
    ...     ) -> TensorDict:
    ...         # Custom batching logic
    ...         ...
    """

    @abstractmethod
    def __call__(
        self, samples: Sequence[tuple[TensorDict, dict[str, Any]]]
    ) -> TensorDict | tuple[TensorDict, list[dict[str, Any]]]:
        """
        Merge a sequence of samples into a single batch.

        Parameters
        ----------
        samples : Sequence[tuple[TensorDict, dict[str, Any]]]
            The (TensorDict, metadata dict) pairs to batch together.

        Returns
        -------
        TensorDict or tuple[TensorDict, list[dict[str, Any]]]
            The batched TensorDict, optionally paired with the list of
            per-sample metadata dicts.

        Raises
        ------
        NotImplementedError
            Always, when the base implementation itself is invoked.
        """
        raise NotImplementedError()
[docs]
class DefaultCollator(Collator):
    """
    Collator that stacks the samples' TensorDicts along a new dimension.

    The TensorDicts are combined with ``torch.stack``, so every field of the
    result has shape [batch_size, ...original_shape] when ``stack_dim`` is 0.

    For stacking to succeed the samples are expected to share:

    - The same tensor keys
    - Matching tensor shapes (per key)
    - The same device

    The batched TensorDict alone is returned unless ``collate_metadata=True``,
    in which case the per-sample metadata dicts are returned alongside it.

    Examples
    --------
    >>> data1 = TensorDict({"x": torch.randn(10, 3)}, device="cpu")
    >>> data2 = TensorDict({"x": torch.randn(10, 3)}, device="cpu")
    >>> samples = [
    ...     (data1, {"file": "a.h5"}),
    ...     (data2, {"file": "b.h5"}),
    ... ]
    >>> collator = DefaultCollator()
    >>> batched_data = collator(samples)
    >>> batched_data["x"].shape
    torch.Size([2, 10, 3])

    With metadata collation enabled:

    >>> collator = DefaultCollator(collate_metadata=True)
    >>> batched_data, metadata_list = collator(samples)
    >>> metadata_list
    [{'file': 'a.h5'}, {'file': 'b.h5'}]
    """

    def __init__(
        self,
        *,
        stack_dim: int = 0,
        keys: Optional[list[str]] = None,
        collate_metadata: bool = False,
    ) -> None:
        """
        Configure the collator.

        Parameters
        ----------
        stack_dim : int, default=0
            Position of the new batch dimension created by stacking.
        keys : list[str], optional
            When given, each sample is reduced to these keys before
            stacking; ``None`` keeps every key.
        collate_metadata : bool, default=False
            When True, ``__call__`` also returns the list of metadata dicts.
            Off by default so the output is a plain TensorDict, as expected
            by PyTorch DataLoader.
        """
        self.collate_metadata = collate_metadata
        self.keys = keys
        self.stack_dim = stack_dim

    def __call__(
        self, samples: Sequence[tuple[TensorDict, dict[str, Any]]]
    ) -> TensorDict | tuple[TensorDict, list[dict[str, Any]]]:
        """
        Stack the TensorDicts of ``samples`` into one batch.

        Parameters
        ----------
        samples : Sequence[tuple[TensorDict, dict[str, Any]]]
            The (TensorDict, metadata) pairs to batch.

        Returns
        -------
        TensorDict or tuple[TensorDict, list[dict[str, Any]]]
            The stacked TensorDict when ``collate_metadata`` is False
            (default); otherwise a (stacked TensorDict, list of metadata
            dicts) tuple.

        Raises
        ------
        ValueError
            If ``samples`` is empty. Errors raised while selecting keys or
            stacking incompatible samples propagate unchanged.
        """
        if not samples:
            raise ValueError("Cannot collate empty sequence of samples")

        # Pull the TensorDict out of every (TensorDict, metadata) pair.
        tensordicts = [tensordict for tensordict, _meta in samples]

        # Optionally narrow each sample to the configured subset of keys.
        wanted = self.keys
        if wanted is not None:
            tensordicts = [tensordict.select(*wanted) for tensordict in tensordicts]

        stacked = torch.stack(tensordicts, dim=self.stack_dim)

        # Metadata is only gathered when the caller asked for it.
        if not self.collate_metadata:
            return stacked
        per_sample_meta = [meta for _tensordict, meta in samples]
        return stacked, _collate_metadata(per_sample_meta)
[docs]
class ConcatCollator(Collator):
    """
    Collator that concatenates tensors along an existing dimension.

    Unlike DefaultCollator, which creates a new batch dimension, this
    collator joins the samples along an existing dimension. Useful for point
    clouds or other variable-length data where all points are combined.

    Optionally adds a 1-D tensor of batch indices recording which sample each
    position along the concatenated dimension came from. The per-sample sizes
    used for these indices are taken from the first collated key.

    By default, returns just the batched TensorDict for PyTorch DataLoader
    compatibility. Set collate_metadata=True to also return metadata.

    Examples
    --------
    >>> data1 = TensorDict({"points": torch.randn(100, 3)})
    >>> data2 = TensorDict({"points": torch.randn(150, 3)})
    >>> samples = [
    ...     (data1, {"file": "a.h5"}),
    ...     (data2, {"file": "b.h5"}),
    ... ]
    >>> collator = ConcatCollator(dim=0, add_batch_idx=True)
    >>> batched_data = collator(samples)
    >>> batched_data["points"].shape
    torch.Size([250, 3])
    >>> batched_data["batch_idx"].shape
    torch.Size([250])

    With metadata collation enabled:

    >>> collator = ConcatCollator(dim=0, add_batch_idx=True, collate_metadata=True)
    >>> batched_data, metadata_list = collator(samples)
    >>> metadata_list
    [{'file': 'a.h5'}, {'file': 'b.h5'}]
    """

    def __init__(
        self,
        *,
        dim: int = 0,
        add_batch_idx: bool = True,
        batch_idx_key: str = "batch_idx",
        keys: Optional[list[str]] = None,
        collate_metadata: bool = False,
    ) -> None:
        """
        Initialize the collator.

        Parameters
        ----------
        dim : int, default=0
            Dimension along which to concatenate.
        add_batch_idx : bool, default=True
            If True, add a tensor of batch indices.
        batch_idx_key : str, default="batch_idx"
            Key under which the batch index tensor is stored. It is written
            after the data keys, so it replaces a data key of the same name.
        keys : list[str], optional
            If provided (and non-empty), only collate these tensor keys.
            Otherwise the keys of the first sample are used.
        collate_metadata : bool, default=False
            If True, collate metadata into lists. Default is False for
            compatibility with PyTorch DataLoader.
        """
        self.dim = dim
        self.add_batch_idx = add_batch_idx
        self.batch_idx_key = batch_idx_key
        self.keys = keys
        self.collate_metadata = collate_metadata

    def __call__(
        self, samples: Sequence[tuple[TensorDict, dict[str, Any]]]
    ) -> TensorDict | tuple[TensorDict, list[dict[str, Any]]]:
        """
        Collate samples by concatenating tensors.

        Parameters
        ----------
        samples : Sequence[tuple[TensorDict, dict[str, Any]]]
            Sequence of (TensorDict, metadata) tuples to batch.

        Returns
        -------
        TensorDict or tuple[TensorDict, list[dict[str, Any]]]
            Batched TensorDict if collate_metadata=False (default),
            or tuple of (batched TensorDict, list of metadata dicts)
            if collate_metadata=True.

        Raises
        ------
        ValueError
            If samples is empty, or a sample lacks one of the collated keys.
        """
        if not samples:
            raise ValueError("Cannot collate empty sequence of samples")

        # Separate data from metadata.
        data_list = [data for data, _ in samples]
        first_data = data_list[0]
        keys = self.keys if self.keys else list(first_data.keys())
        device = first_data.device

        batched_tensors = {}
        sizes: list[int] = []  # Per-sample extent along `dim`, for batch indices
        index_device = device
        for position, key in enumerate(keys):
            tensors = []
            for data in data_list:
                if key not in data.keys():
                    raise ValueError(f"Data missing key '{key}'")
                tensors.append(data[key])
            if position == 0:
                # Record sizes exactly once, keyed on position rather than on
                # the key's value, so a repeated key in `keys` cannot append
                # the sizes a second time and corrupt the batch indices.
                sizes = [tensor.shape[self.dim] for tensor in tensors]
                if index_device is None:
                    # The TensorDict carries no device: follow the data
                    # instead of silently using torch's default device.
                    index_device = tensors[0].device
            batched_tensors[key] = torch.cat(tensors, dim=self.dim)

        # Add batch indices: sample i contributes `sizes[i]` entries equal to i.
        if self.add_batch_idx:
            batch_indices = [
                torch.full((size,), i, dtype=torch.long, device=index_device)
                for i, size in enumerate(sizes)
            ]
            batched_tensors[self.batch_idx_key] = torch.cat(batch_indices, dim=0)

        # Create batched TensorDict
        batched_data = TensorDict(batched_tensors, device=device)

        # Collate metadata only if requested
        if self.collate_metadata:
            metadata_list = [meta for _, meta in samples]
            return batched_data, _collate_metadata(metadata_list)
        return batched_data
[docs]
class FunctionCollator(Collator):
    """
    Adapter that exposes a plain function through the Collator interface.

    Lets any suitable callable act as a collator without writing a
    subclass; calling the collator simply forwards to the function.

    Examples
    --------
    >>> def my_collate(samples):
    ...     # Custom logic
    ...     data_list = [d for d, _ in samples]
    ...     metadata_list = [m for _, m in samples]
    ...     return torch.stack(data_list), metadata_list
    >>> collator = FunctionCollator(my_collate)
    """

    def __init__(
        self,
        fn: Callable[
            [Sequence[tuple[TensorDict, dict[str, Any]]]],
            tuple[TensorDict, list[dict[str, Any]]],
        ],
    ) -> None:
        """
        Store the function to delegate to.

        Parameters
        ----------
        fn : Callable
            Collation function; it is called with the sequence of
            (TensorDict, dict) samples and its result is returned as-is.
        """
        self.fn = fn

    def __call__(
        self, samples: Sequence[tuple[TensorDict, dict[str, Any]]]
    ) -> TensorDict | tuple[TensorDict, list[dict[str, Any]]]:
        """Forward ``samples`` to the wrapped function and return its result."""
        collate = self.fn
        return collate(samples)
# Shared module-level DefaultCollator built with its default arguments
# (stack_dim=0, no key filter, collate_metadata=False); default_collate
# delegates to this instance.
_default_collator = DefaultCollator()
def default_collate(
    samples: Sequence[tuple[TensorDict, dict[str, Any]]],
    *,
    collate_metadata: bool = False,
) -> TensorDict | tuple[TensorDict, list[dict[str, Any]]]:
    """
    Default collation function using stacking.

    Convenience function that uses DefaultCollator. By default only the
    batched TensorDict is returned, so the function can be passed directly
    as a PyTorch DataLoader ``collate_fn``; metadata is returned only when
    explicitly requested.

    Parameters
    ----------
    samples : Sequence[tuple[TensorDict, dict[str, Any]]]
        Sequence of (TensorDict, metadata) tuples to batch.
    collate_metadata : bool, default=False
        If True, also return the per-sample metadata as a list of dicts.

    Returns
    -------
    TensorDict or tuple[TensorDict, list[dict[str, Any]]]
        Batched TensorDict if collate_metadata=False (default),
        or tuple of (batched TensorDict, list of metadata dicts)
        if collate_metadata=True.

    Raises
    ------
    ValueError
        If samples is empty.
    """
    if collate_metadata:
        # The shared instance is configured without metadata collation, so a
        # dedicated collator is needed for the metadata-returning variant.
        return DefaultCollator(collate_metadata=True)(samples)
    return _default_collator(samples)
def concat_collate(
    samples: Sequence[tuple[TensorDict, dict[str, Any]]],
    dim: int = 0,
    add_batch_idx: bool = True,
    *,
    collate_metadata: bool = False,
) -> TensorDict | tuple[TensorDict, list[dict[str, Any]]]:
    """
    Collation function using concatenation.

    Convenience function that uses ConcatCollator. By default only the
    batched TensorDict is returned, for PyTorch DataLoader compatibility;
    metadata is returned only when explicitly requested.

    Parameters
    ----------
    samples : Sequence[tuple[TensorDict, dict[str, Any]]]
        Sequence of (TensorDict, metadata) tuples to batch.
    dim : int, default=0
        Dimension along which to concatenate.
    add_batch_idx : bool, default=True
        If True, add batch index tensor.
    collate_metadata : bool, default=False
        If True, also return the per-sample metadata as a list of dicts.

    Returns
    -------
    TensorDict or tuple[TensorDict, list[dict[str, Any]]]
        Batched TensorDict if collate_metadata=False (default),
        or tuple of (batched TensorDict, list of metadata dicts)
        if collate_metadata=True.

    Raises
    ------
    ValueError
        If samples is empty.
    """
    collator = ConcatCollator(
        dim=dim,
        add_batch_idx=add_batch_idx,
        collate_metadata=collate_metadata,
    )
    return collator(samples)
def get_collator(
    collate_fn: Collator
    | Callable[
        [Sequence[tuple[TensorDict, dict[str, Any]]]],
        tuple[TensorDict, list[dict[str, Any]]],
    ]
    | None = None,
    *,
    collate_metadata: bool = False,
) -> Collator:
    """
    Normalize the accepted collate inputs into a Collator instance.

    Parameters
    ----------
    collate_fn : Collator or Callable, optional
        An existing Collator (returned unchanged), a plain callable (wrapped
        in a FunctionCollator), or None to build a DefaultCollator.
    collate_metadata : bool, default=False
        Forwarded to the DefaultCollator created when ``collate_fn`` is
        None; ignored in every other case. False by default for
        compatibility with PyTorch DataLoader.

    Returns
    -------
    Collator
        The resolved collator.

    Raises
    ------
    TypeError
        If ``collate_fn`` is neither a Collator, a callable, nor None.
    """
    # Nothing supplied: fall back to stacking.
    if collate_fn is None:
        return DefaultCollator(collate_metadata=collate_metadata)

    # Already a collator: hand it back untouched.
    if isinstance(collate_fn, Collator):
        return collate_fn

    # Anything left must at least be callable to be wrapped.
    if not callable(collate_fn):
        raise TypeError(
            f"collate_fn must be Collator, callable, or None, "
            f"got {type(collate_fn).__name__}"
        )
    return FunctionCollator(collate_fn)