# Source code for physicsnemo.datapipes.dataloader

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
DataLoader - Batched iteration over datasets with prefetching.

The DataLoader orchestrates efficient batch loading by leveraging
the Dataset's prefetching capabilities with CUDA streams.
By default, returns batched TensorDict for PyTorch DataLoader compatibility.
When collate_metadata=True, returns (TensorDict, list[dict]) tuples.
"""

from __future__ import annotations

from typing import Any, Callable, Iterator, Optional, Sequence

import torch
from tensordict import TensorDict
from torch.utils.data import RandomSampler, Sampler, SequentialSampler

from physicsnemo.datapipes._rng import fork_generator
from physicsnemo.datapipes.collate import Collator, get_collator
from physicsnemo.datapipes.protocols import DatasetBase
from physicsnemo.datapipes.registry import register


@register()
class DataLoader:
    """
    Batched iteration over a Dataset with stream-based prefetching.

    Unlike PyTorch's DataLoader which uses CPU multiprocessing, this
    DataLoader uses CUDA streams to overlap data loading, preprocessing,
    and collation. This is more efficient for SciML workloads where:

    - Datasets are huge
    - Batches are small
    - Preprocessing benefits from GPU acceleration

    Features:

    - Stream-based parallelism (samples are assigned round-robin to a
      fixed pool of ``num_streams`` CUDA streams)
    - Toggleable prefetching for debugging
    - Compatible with PyTorch samplers (DistributedSampler, etc.)
    - Familiar torch DataLoader interface

    Examples
    --------
    >>> from physicsnemo.datapipes import DataLoader, Dataset, HDF5Reader, Normalize
    >>>
    >>> dataset = Dataset(  # doctest: +SKIP
    ...     HDF5Reader("data.h5", fields=["input", "target"]),
    ...     transforms=Normalize(["input"], method="mean_std", means={"input": 0.0}, stds={"input": 1.0}),
    ...     device="cuda",  # Automatic GPU transfer
    ... )
    >>> loader = DataLoader(dataset, batch_size=16, shuffle=True)  # doctest: +SKIP
    >>>
    >>> for batch in loader:  # doctest: +SKIP
    ...     output = model(batch["input"])

    With DistributedSampler:

    >>> from torch.utils.data.distributed import DistributedSampler
    >>> sampler = DistributedSampler(dataset)  # doctest: +SKIP
    >>> loader = DataLoader(dataset, batch_size=16, sampler=sampler)  # doctest: +SKIP
    """

    def __init__(
        self,
        dataset: DatasetBase,
        *,
        batch_size: int = 1,
        shuffle: bool = False,
        sampler: Optional[Sampler] = None,
        drop_last: bool = False,
        collate_fn: Optional[
            Collator
            | Callable[
                [Sequence[tuple[TensorDict, dict[str, Any]]]],
                tuple[TensorDict, list[dict[str, Any]]],
            ]
        ] = None,
        collate_metadata: bool = False,
        prefetch_factor: int = 2,
        num_streams: int = 4,
        use_streams: bool = True,
        seed: int | None = None,
    ) -> None:
        """
        Initialize the DataLoader.

        Parameters
        ----------
        dataset : DatasetBase
            Dataset to load from. Any subclass of :class:`DatasetBase`
            (e.g. :class:`Dataset`, :class:`MeshDataset`).
        batch_size : int, default=1
            Number of samples per batch.
        shuffle : bool, default=False
            If True, shuffle indices each epoch. Ignored if sampler provided.
        sampler : Sampler, optional
            Custom sampler for index generation. If provided, shuffle is
            ignored and ``seed`` is not applied to it.
        drop_last : bool, default=False
            If True, drop the last incomplete batch.
        collate_fn : Collator or Callable, optional
            Function to collate samples into batches. Defaults to stacking.
        collate_metadata : bool, default=False
            If True, collate metadata into a list of dicts. Set to False for
            compatibility with PyTorch DataLoader. Only used when collate_fn
            is None (uses default collator).
        prefetch_factor : int, default=2
            Number of batches to prefetch ahead. Set to 0 (or any value
            <= 0) to disable prefetching.
        num_streams : int, default=4
            Number of CUDA streams for prefetching. Samples are assigned to
            streams round-robin. If no streams are created (e.g. 0), iteration
            falls back to synchronous loading.
        use_streams : bool, default=True
            If True, use CUDA streams for overlap. Set False for debugging
            or CPU-only operation. Forced to False when CUDA is unavailable.
        seed : int, optional
            Master seed for pipeline randomness. When set, a master generator
            is forked into one generator for the internally created
            ``RandomSampler`` (only when ``shuffle=True`` and no sampler is
            passed) and one handed to the dataset via ``set_generator`` (if
            the dataset defines it). Use :meth:`set_epoch` to vary the random
            sequence across epochs while staying deterministic.

        Raises
        ------
        ValueError
            If batch_size < 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.prefetch_factor = prefetch_factor
        self.num_streams = num_streams
        # Streams need CUDA; silently fall back to synchronous loading otherwise.
        self.use_streams = use_streams and torch.cuda.is_available()
        self._seed = seed

        # Build master generator and fork it: child 0 -> sampler, child 1 -> dataset.
        sampler_generator: torch.Generator | None = None
        if seed is not None:
            master = torch.Generator()
            master.manual_seed(seed)
            forks = fork_generator(master, 2)
            sampler_generator = forks[0]
            if hasattr(dataset, "set_generator"):
                dataset.set_generator(forks[1])

        if sampler is not None:
            # A user-supplied sampler is used as-is. Samplers such as
            # DistributedSampler take their seed as a constructor argument
            # (read-only afterwards), so users should pass the seed there.
            self.sampler = sampler
        elif shuffle:
            self.sampler = RandomSampler(dataset, generator=sampler_generator)
        else:
            self.sampler = SequentialSampler(dataset)

        self.collate_fn = get_collator(collate_fn, collate_metadata=collate_metadata)

        # CUDA streams used for prefetching (empty when streams are disabled).
        self._streams: list[torch.cuda.Stream] = (
            self._make_streams() if self.use_streams else []
        )

    def _make_streams(self) -> list[torch.cuda.Stream]:
        """Create ``self.num_streams`` fresh CUDA streams."""
        return [torch.cuda.Stream() for _ in range(self.num_streams)]

    def __len__(self) -> int:
        """
        Return the number of batches.

        If the sampler does not define ``__len__``, ``len(self.dataset)`` is
        used as the sample count instead.

        Returns
        -------
        int
            Number of batches in the dataloader.
        """
        n_samples = (
            len(self.sampler) if hasattr(self.sampler, "__len__") else len(self.dataset)
        )
        if self.drop_last:
            return n_samples // self.batch_size
        return (n_samples + self.batch_size - 1) // self.batch_size

    def _generate_batches(self) -> Iterator[list[int]]:
        """
        Generate batches of indices from the sampler.

        Yields
        ------
        list[int]
            List of sample indices for each batch. The trailing partial batch
            is yielded unless ``drop_last`` is True.
        """
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if batch and not self.drop_last:
            yield batch

    def __iter__(
        self,
    ) -> Iterator[TensorDict | tuple[TensorDict, list[dict[str, Any]]]]:
        """
        Iterate over batches.

        Uses stream-based prefetching when ``prefetch_factor > 0``, streams
        are enabled, and at least one stream exists; otherwise loads samples
        synchronously.

        Yields
        ------
        TensorDict or tuple[TensorDict, list[dict[str, Any]]]
            Batched TensorDict if collate_metadata=False (default),
            or tuple of (batched TensorDict, list of metadata dicts)
            if collate_metadata=True.
        """
        # The ``self._streams`` check guards against an empty stream pool
        # (e.g. num_streams=0), which would otherwise fail on the modulo.
        if self.prefetch_factor > 0 and self.use_streams and self._streams:
            yield from self._iter_prefetch()
        else:
            yield from self._iter_simple()

    def _iter_simple(
        self,
    ) -> Iterator[TensorDict | tuple[TensorDict, list[dict[str, Any]]]]:
        """
        Simple synchronous iteration without prefetching.

        Yields
        ------
        TensorDict or tuple[TensorDict, list[dict[str, Any]]]
            Collated batch.
        """
        for batch_indices in self._generate_batches():
            samples = [self.dataset[idx] for idx in batch_indices]
            yield self.collate_fn(samples)

    def _prefetch_batch(self, batch_indices: list[int], stream_idx: int) -> int:
        """Issue prefetches for one batch round-robin over streams; return the next stream index."""
        n_streams = len(self._streams)
        for sample_idx in batch_indices:
            stream = self._streams[stream_idx % n_streams]
            self.dataset.prefetch(sample_idx, stream=stream)
            stream_idx += 1
        return stream_idx

    def _iter_prefetch(
        self,
    ) -> Iterator[TensorDict | tuple[TensorDict, list[dict[str, Any]]]]:
        """
        Iteration with stream-based prefetching.

        Strategy:

        1. Prefetch ``prefetch_factor`` batches worth of samples.
        2. As batches are yielded, prefetch one more to keep the pipeline full.
        3. Samples are assigned round-robin to the CUDA stream pool.

        Any outstanding prefetch state is cancelled when iteration ends,
        including when the consumer stops early or an exception is raised.

        Yields
        ------
        TensorDict or tuple[TensorDict, list[dict[str, Any]]]
            Collated batch.
        """
        # Collect all batches upfront for prefetch planning.
        all_batches = list(self._generate_batches())
        if not all_batches:
            return

        num_batches = len(all_batches)
        stream_idx = 0

        try:
            # Prime the pipeline.
            prefetched_up_to = min(self.prefetch_factor, num_batches)
            for batch_indices in all_batches[:prefetched_up_to]:
                stream_idx = self._prefetch_batch(batch_indices, stream_idx)

            for batch_indices in all_batches:
                # Collect samples (uses prefetched results if available).
                samples = [self.dataset[idx] for idx in batch_indices]
                batch = self.collate_fn(samples)

                # Top up the pipeline with the next not-yet-prefetched batch.
                if prefetched_up_to < num_batches:
                    stream_idx = self._prefetch_batch(
                        all_batches[prefetched_up_to], stream_idx
                    )
                    prefetched_up_to += 1

                yield batch
        finally:
            # Runs on exhaustion, early break (generator close), or error, so
            # in-flight prefetches are never left dangling in the dataset.
            self.dataset.cancel_prefetch()

    def set_epoch(self, epoch: int) -> None:
        """
        Set the epoch for the sampler and the dataset.

        Forwards ``epoch`` to the sampler if it defines ``set_epoch`` (e.g.
        :class:`~torch.utils.data.distributed.DistributedSampler`) and to the
        dataset if it defines ``set_epoch``. Note that an internally created
        ``RandomSampler`` (``shuffle=True``, no sampler passed) has no
        ``set_epoch`` and is therefore not reseeded here.

        Parameters
        ----------
        epoch : int
            Current epoch number.
        """
        if hasattr(self.sampler, "set_epoch"):
            self.sampler.set_epoch(epoch)
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)

    def enable_prefetch(self) -> None:
        """
        Enable stream-based prefetching.

        Creates the stream pool if it does not exist yet. Prefetching still
        only takes effect when ``prefetch_factor > 0``.

        Raises
        ------
        RuntimeError
            If CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise RuntimeError(
                "CUDA is not available, cannot enable stream prefetching"
            )
        if not self._streams:
            self._streams = self._make_streams()
        self.use_streams = True

    def disable_prefetch(self) -> None:
        """Disable prefetching (useful for debugging) and cancel pending prefetches."""
        self.use_streams = False
        self.dataset.cancel_prefetch()

    def __repr__(self) -> str:
        """
        Return string representation.

        Returns
        -------
        str
            String representation of the DataLoader.
        """
        return (
            f"DataLoader(\n"
            f"  dataset={self.dataset},\n"
            f"  batch_size={self.batch_size},\n"
            f"  shuffle={self.shuffle},\n"
            f"  drop_last={self.drop_last},\n"
            f"  prefetch_factor={self.prefetch_factor},\n"
            f"  num_streams={self.num_streams},\n"
            f"  use_streams={self.use_streams}\n"
            f")"
        )