# Source: physicsnemo.datapipes.dataset

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Dataset - Combines a Reader with a transform pipeline.

The Dataset is the primary interface for accessing preprocessed data.
It wraps a Reader and applies transforms to produce ready-to-use TensorDicts.
Supports prefetching with CUDA streams for overlapped IO and computation,
and automatic device transfer when device parameter is specified.
"""

from __future__ import annotations

from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from typing import Any, Iterator, Optional, Sequence

import torch
from tensordict import TensorDict

from physicsnemo.datapipes.readers.base import Reader
from physicsnemo.datapipes.registry import register
from physicsnemo.datapipes.transforms.base import Transform
from physicsnemo.datapipes.transforms.compose import Compose
from physicsnemo.distributed import DistributedManager


@dataclass
class _PrefetchResult:
    """Result of a prefetch operation."""

    index: int
    data: Optional[TensorDict] = None
    metadata: Optional[dict[str, Any]] = None
    error: Optional[Exception] = None
    event: Optional[torch.cuda.Event] = None  # For stream sync


[docs] @register() class Dataset: """ A dataset combining a Reader with a transform pipeline. The Dataset provides a torch-like interface for accessing data: - Indexing: dataset[i] returns transformed sample i - Iteration: for sample in dataset - Length: len(dataset) - Prefetching: dataset.prefetch(i, stream) for async loading The pipeline is: Reader → Transforms → Sample Prefetching Model ----------------- The dataset supports prefetching samples using a thread pool. When a CUDA stream is provided, GPU operations (device transfer, GPU transforms) happen on that stream, allowing overlap with other computation. >>> # Start prefetching >>> dataset.prefetch(0, stream=stream0) # doctest: +SKIP >>> dataset.prefetch(1, stream=stream1) # doctest: +SKIP >>> >>> # Retrieve results (waits if not ready) >>> sample_0 = dataset[0] # Uses prefetched result # doctest: +SKIP Examples -------- >>> from physicsnemo.datapipes import Dataset, HDF5Reader, Normalize >>> >>> reader = HDF5Reader("data.h5", fields=["pressure", "velocity"]) # doctest: +SKIP >>> transforms = Normalize( # doctest: +SKIP ... ["pressure"], ... method="mean_std", ... means={"pressure": 0.0}, # doctest: +SKIP ... stds={"pressure": 1.0}, # doctest: +SKIP ... ) >>> >>> dataset = Dataset(reader, transforms=transforms, device="cuda") # doctest: +SKIP >>> sample, metadata = dataset[0] # doctest: +SKIP """ def __init__( self, reader: Reader, *, transforms: Optional[Transform | Sequence[Transform]] = None, device: Optional[str | torch.device] = None, num_workers: int = 2, ) -> None: """ Initialize the dataset. Parameters ---------- reader : Reader Data reader providing raw samples. transforms : Transform or Sequence[Transform], optional Transform or sequence of transforms to apply. If a sequence, they are composed in order. device : str or torch.device, optional Target device for automatic transfer (e.g., "cuda", "cuda:0"). If None, no automatic transfer is performed (data stays on CPU). 
When specified, data is transferred to this device before transforms. If device is "auto", will select the device with distributed manager. Auto device falls back to CPU. num_workers : int, default=2 Number of worker threads for prefetching. Raises ------ TypeError If reader is not a Reader instance. """ if not isinstance(reader, Reader): raise TypeError( f"reader must be a Reader instance, got {type(reader).__name__}" ) self.reader = reader self.num_workers = num_workers if device == "auto": if torch.cuda.is_available(): if DistributedManager.is_initialized(): device = DistributedManager().device else: device = "cuda:0" else: device = "cpu" # Now, instantiate the device if not already done: match device: case torch.device(): self.target_device = device case str(): self.target_device = torch.device(device) case None: self.target_device = None # Handle transforms if transforms is None: self.transforms: Optional[Transform] = None elif isinstance(transforms, Transform): self.transforms = transforms elif isinstance(transforms, Sequence): if len(transforms) == 0: self.transforms = None elif len(transforms) == 1: self.transforms = transforms[0] else: self.transforms = Compose(transforms) else: raise TypeError( f"transforms must be Transform, Sequence[Transform], or None, " f"got {type(transforms).__name__}" ) # Share device with transforms so their internal state is on the right device if self.target_device is not None and self.transforms is not None: self.transforms.to(self.target_device) # Prefetch state - using thread-safe dict for results # Key: index, Value: Future[_PrefetchResult] self._prefetch_futures: dict[int, Future[_PrefetchResult]] = {} self._executor: Optional[ThreadPoolExecutor] = None def _ensure_executor(self) -> ThreadPoolExecutor: """ Lazily create the thread pool executor. Returns ------- ThreadPoolExecutor The thread pool executor for prefetching. 
""" if self._executor is None: self._executor = ThreadPoolExecutor( max_workers=self.num_workers, thread_name_prefix="datapipe_prefetch", ) return self._executor def _load_and_transform( self, index: int, stream: Optional[torch.cuda.Stream] = None, ) -> _PrefetchResult: """ Load a sample and apply transforms. Called by worker threads. Parameters ---------- index : int Sample index. stream : torch.cuda.Stream, optional Optional CUDA stream for GPU operations. Returns ------- _PrefetchResult PrefetchResult with data, metadata, or error. """ result = _PrefetchResult(index=index) try: # Load from reader (CPU, potentially slow IO) data, metadata = self.reader[index] # Auto-transfer to target device if specified if self.target_device is not None: if stream is not None: with torch.cuda.stream(stream): data = data.to(self.target_device, non_blocking=True) else: data = data.to(self.target_device, non_blocking=True) # Apply transforms (data is now on target device if specified) if self.transforms is not None: if stream is not None: with torch.cuda.stream(stream): data = self.transforms(data) # Record event for synchronization result.event = torch.cuda.Event() result.event.record(stream) else: data = self.transforms(data) result.data = data result.metadata = metadata except Exception as e: result.error = e return result
[docs] def prefetch( self, index: int, stream: Optional[torch.cuda.Stream] = None, ) -> None: """ Start prefetching a sample asynchronously. The sample will be loaded in a background thread. If a CUDA stream is provided, GPU operations happen on that stream. Call __getitem__ to retrieve the result (it will wait if needed). Parameters ---------- index : int Sample index to prefetch. stream : torch.cuda.Stream, optional Optional CUDA stream for GPU operations. """ # Don't prefetch if already in flight if index in self._prefetch_futures: return executor = self._ensure_executor() future = executor.submit(self._load_and_transform, index, stream) self._prefetch_futures[index] = future
[docs] def prefetch_batch( self, indices: Sequence[int], streams: Optional[Sequence[torch.cuda.Stream]] = None, ) -> None: """ Start prefetching multiple samples. Parameters ---------- indices : Sequence[int] Sample indices to prefetch. streams : Sequence[torch.cuda.Stream], optional Optional CUDA streams, one per index. If shorter than indices, streams are cycled. If None, no streams used. """ for i, idx in enumerate(indices): stream = None if streams: stream = streams[i % len(streams)] self.prefetch(idx, stream=stream)
def __getitem__(self, index: int) -> tuple[TensorDict, dict[str, Any]]: """ Get a transformed sample by index. If the index was prefetched, returns the prefetched result (waiting for completion if necessary). Otherwise loads synchronously. Parameters ---------- index : int Sample index. Returns ------- tuple[TensorDict, dict[str, Any]] Tuple of (TensorDict with transformed data, metadata dict). Raises ------ IndexError If index is out of range. Exception If prefetch failed, re-raises the error. """ # Check if prefetched future = self._prefetch_futures.pop(index, None) if future is not None: # Wait for prefetch to complete result = future.result() if result.error is not None: raise result.error # Sync stream if needed if result.event is not None: result.event.synchronize() return result.data, result.metadata # Not prefetched, load synchronously data, metadata = self.reader[index] # Auto-transfer to target device if specified if self.target_device is not None: data = data.to(self.target_device, non_blocking=True) # Apply transforms if self.transforms is not None: data = self.transforms(data) return data, metadata
[docs] def cancel_prefetch(self, index: Optional[int] = None) -> None: """ Cancel prefetch requests. Note: Already-running tasks will complete, but results are discarded. Parameters ---------- index : int, optional Specific index to cancel. If None, cancels all. """ if index is None: # Cancel all - just clear the dict, let futures complete self._prefetch_futures.clear() else: self._prefetch_futures.pop(index, None)
def __len__(self) -> int: """ Return the number of samples in the dataset. Returns ------- int Number of samples. """ return len(self.reader) def __iter__(self) -> Iterator[tuple[TensorDict, dict[str, Any]]]: """ Iterate over all samples. Note: This does NOT automatically prefetch. For prefetched iteration, use the DataLoader which manages prefetching strategy. Yields ------ tuple[TensorDict, dict[str, Any]] Tuple of (transformed data, metadata) for each sample. """ for i in range(len(self)): yield self[i] @property def field_names(self) -> list[str]: """ List of field names in samples (from reader). Returns ------- list[str] Field names available in samples. """ return self.reader.field_names @property def prefetch_count(self) -> int: """ Number of items currently being prefetched. Returns ------- int Count of in-flight prefetch operations. """ return len(self._prefetch_futures)
[docs] def close(self) -> None: """ Close the dataset and stop prefetching. Waits for any in-flight prefetch tasks to complete before shutdown. This prevents "cannot schedule new futures after shutdown" errors from libraries like zarr that use async I/O internally. """ # Wait for any in-flight prefetch tasks to complete before shutdown. # This prevents "cannot schedule new futures after shutdown" errors # from libraries like zarr that use async I/O internally. for future in self._prefetch_futures.values(): try: future.result(timeout=30.0) # Wait up to 30s per task except Exception: # noqa: BLE001, S110 pass # Ignore errors during shutdown self._prefetch_futures.clear() if self._executor is not None: self._executor.shutdown(wait=True) self._executor = None self.reader.close()
def __enter__(self) -> "Dataset": """Context manager entry.""" return self def __exit__(self, exc_type, exc_val, exc_tb) -> None: """Context manager exit.""" self.close() def __repr__(self) -> str: """ Return string representation. Returns ------- str String representation of the Dataset. """ transform_str = repr(self.transforms) if self.transforms else "None" return f"Dataset(\n reader={self.reader},\n transforms={transform_str}\n)"