# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Reader base class - Abstract interface for data sources.
Readers are simple, transactional data loaders. They load data from sources
and return TensorDict instances with CPU tensors plus separate metadata dicts.
Device transfers and threading are handled elsewhere (Dataset and DataLoader).
"""
from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from typing import Any, Iterator
import torch
from tensordict import TensorDict
# Module-level logger; Reader.__iter__ uses it to report samples that fail to load.
logger = logging.getLogger(__name__)
class Reader(ABC):
    """
    Abstract base class for data readers.

    Readers are intentionally simple and transactional:

    - Load data from a source (file, database, etc.)
    - Return (TensorDict, metadata_dict) tuples with CPU tensors
    - No threading, no prefetching, no device transfers

    This design makes custom readers easy to implement. Users only need to:

    1. Implement ``_load_sample(index)`` to load raw data
    2. Implement ``__len__()`` to return dataset size

    Device transfers are handled automatically by Dataset (if device
    parameter set). Threading/prefetching is handled by the DataLoader.

    Examples
    --------
    Custom reader implementation:

    >>> class MyReader(Reader):  # doctest: +SKIP
    ...     def __init__(self, path: str, **kwargs):
    ...         super().__init__(**kwargs)
    ...         self.data = load_my_data(path)
    ...
    ...     def _load_sample(self, index: int) -> dict[str, torch.Tensor]:
    ...         return {"x": torch.from_numpy(self.data[index])}
    ...
    ...     def __len__(self) -> int:
    ...         return len(self.data)

    Subclasses must implement:

    - ``_load_sample(index: int) -> dict[str, torch.Tensor]``
    - ``__len__() -> int``

    Optionally override:

    - ``_get_field_names() -> list[str]``
    - ``_get_sample_metadata(index: int) -> dict[str, Any]``
    - ``close()``
    """

    def __init__(
        self,
        *,
        pin_memory: bool = False,
        include_index_in_metadata: bool = True,
        coordinated_subsampling: dict[str, Any] | None = None,
    ) -> None:
        """
        Initialize the reader.

        Parameters
        ----------
        pin_memory : bool, default=False
            If True, place tensors in pinned (page-locked) memory.
            This enables faster async CPU→GPU transfers later.
            Only use if you plan to move data to GPU.
        include_index_in_metadata : bool, default=True
            If True, include sample index in metadata.
        coordinated_subsampling : dict[str, Any], optional
            Optional dict to configure coordinated subsampling at construction
            time. If provided, must contain:

            - ``n_points``: Number of points to read from each target tensor
            - ``target_keys``: List of tensor keys to apply subsampling to

            This allows configuration via Hydra. Readers that don't support
            coordinated subsampling will ignore this parameter.
        """
        self.pin_memory = pin_memory
        self.include_index_in_metadata = include_index_in_metadata
        # Stored as-is; this base class does not validate or interpret it.
        self._coordinated_subsampling_config = coordinated_subsampling

    @abstractmethod
    def _load_sample(self, index: int) -> dict[str, torch.Tensor]:
        """
        Load raw data for a single sample.

        This is the main method to implement. Load data from your source
        and return it as a dictionary of CPU tensors.

        Parameters
        ----------
        index : int
            Sample index (0 to len-1).

        Returns
        -------
        dict[str, torch.Tensor]
            Dictionary mapping field names to CPU tensors.

        Raises
        ------
        IndexError
            If index is out of range.
        """
        raise NotImplementedError

    @abstractmethod
    def __len__(self) -> int:
        """
        Return the number of samples in the dataset.

        Returns
        -------
        int
            Number of samples.
        """
        raise NotImplementedError

    def _get_field_names(self) -> list[str]:
        """
        Return the list of field names in samples.

        Override this to provide field names without loading a sample.
        Default implementation loads sample 0 and extracts keys.

        Returns
        -------
        list[str]
            List of field names available in samples. Empty when the
            reader has no samples.
        """
        if len(self) == 0:
            return []
        data = self._load_sample(0)
        return list(data.keys())

    def _get_sample_metadata(self, index: int) -> dict[str, Any]:
        """
        Return metadata for a sample.

        Override this to provide source-specific metadata (filenames, etc.).
        Default implementation returns empty dict (index added separately).

        Parameters
        ----------
        index : int
            Sample index.

        Returns
        -------
        dict[str, Any]
            Dictionary of metadata (not tensors).
        """
        return {}

    @property
    def _supports_coordinated_subsampling(self) -> bool:
        """
        Return True if this reader supports coordinated subsampling.

        Override this property in subclasses that implement coordinated
        subsampling.

        Returns
        -------
        bool
            True if coordinated subsampling is supported.
        """
        return False

    @property
    def field_names(self) -> list[str]:
        """
        List of field names available in samples.

        Returns
        -------
        list[str]
            Field names.
        """
        return self._get_field_names()

    def __getitem__(self, index: int) -> tuple[TensorDict, dict[str, Any]]:
        """
        Load and return a single sample.

        Parameters
        ----------
        index : int
            Sample index. Supports negative indexing.

        Returns
        -------
        tuple[TensorDict, dict[str, Any]]
            Tuple of (TensorDict with CPU tensors, metadata dict).

        Raises
        ------
        IndexError
            If index is out of range. The message reports the index exactly
            as the caller passed it.
        """
        n_samples = len(self)

        # Handle negative indexing, but remember what the caller asked for so
        # the error message is not confusing (e.g. "-10", not "-7").
        requested = index
        if index < 0:
            index = n_samples + index
        if index < 0 or index >= n_samples:
            raise IndexError(
                f"Index {requested} out of range for reader with {n_samples} samples"
            )

        # Load data
        data_dict = self._load_sample(index)

        # Build metadata. A new dict is created when adding the index so that
        # a dict owned (e.g. cached) by the subclass is never mutated here.
        metadata = self._get_sample_metadata(index)
        if self.include_index_in_metadata:
            metadata = {**metadata, "index": index}

        # Pin memory if requested
        if self.pin_memory:
            data_dict = {k: v.pin_memory() for k, v in data_dict.items()}

        # Create TensorDict
        data = TensorDict(data_dict, device=torch.device("cpu"))
        return data, metadata

    def __iter__(self) -> Iterator[tuple[TensorDict, dict[str, Any]]]:
        """
        Iterate over all samples.

        Yields
        ------
        tuple[TensorDict, dict[str, Any]]
            Tuple of (TensorDict with CPU tensors, metadata dict) for each sample.

        Raises
        ------
        RuntimeError
            If a sample fails to load, wrapping the original exception with
            context about which sample failed.
        """
        for i in range(len(self)):
            try:
                sample = self[i]
            except Exception as e:
                error_msg = f"Sample {i} failed with exception: {type(e).__name__}: {e}"
                logger.error(error_msg)
                raise RuntimeError(error_msg) from e
            # Yield outside the try block: an exception thrown into the
            # generator by the consumer is not a sample-loading failure and
            # must not be re-labelled as one.
            yield sample

    def set_generator(self, generator: torch.Generator) -> None:
        """Assign a ``torch.Generator`` for reproducible random sampling.

        Override in subclasses that use randomness (e.g. subsampling).
        The default implementation is a no-op.

        Parameters
        ----------
        generator : torch.Generator
            Generator to use for random draws.
        """

    def set_epoch(self, epoch: int) -> None:
        """Reseed the reader's RNG for a new epoch.

        Override in subclasses that use randomness.
        The default implementation is a no-op.

        Parameters
        ----------
        epoch : int
            Current epoch number.
        """

    def close(self) -> None:
        """
        Clean up resources (file handles, connections, etc.).

        Override this in subclasses that hold open resources.
        The default implementation is a no-op.
        """

    def __enter__(self) -> "Reader":
        """Context manager entry; returns the reader itself."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: Any,
    ) -> None:
        """Context manager exit; calls ``close()`` and does not suppress exceptions."""
        self.close()

    def __repr__(self) -> str:
        """
        Return string representation.

        Returns
        -------
        str
            String representation of the reader.
        """
        return (
            f"{self.__class__.__name__}(len={len(self)}, pin_memory={self.pin_memory})"
        )