# Source code for physicsnemo.datapipes.multi_dataset

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
MultiDataset - Compose multiple Dataset instances behind a single dataset-like interface.

MultiDataset presents a single index space (concatenation of all constituent datasets)
and delegates __getitem__, prefetch, and close to the appropriate sub-dataset.
Each sub-dataset can have its own Reader and transforms. Optional output strictness
validates that all sub-datasets produce the same TensorDict keys (outputs) so default
collation works.
"""

from __future__ import annotations

import bisect
from typing import Any, Iterator, Optional, Sequence

import torch
from tensordict import TensorDict

from physicsnemo.datapipes._rng import fork_generator
from physicsnemo.datapipes.protocols import DatasetBase
from physicsnemo.datapipes.registry import register

# Metadata key added by MultiDataset to identify which sub-dataset produced the sample.
DATASET_INDEX_METADATA_KEY = "dataset_index"


def _validate_strict_outputs(datasets: Sequence[DatasetBase]) -> list[str]:
    """
    Verify that every non-empty dataset yields the same output keys.

    The first sample (index 0) of each non-empty dataset is loaded and the
    sorted keys of its data object are compared against those of the first
    non-empty dataset encountered. Empty datasets are skipped.

    Parameters
    ----------
    datasets : Sequence[DatasetBase]
        Datasets whose output schema should agree.

    Returns
    -------
    list[str]
        Sorted output keys shared by all non-empty datasets. If the sequence
        is empty, an empty list. If every dataset is empty, the first
        dataset's ``field_names`` (or an empty list when it has none).

    Raises
    ------
    ValueError
        If a non-empty dataset's keys differ from the reference keys.
    """
    if not datasets:
        return []

    expected: Optional[list[str]] = None
    expected_from: Optional[int] = None

    for position, dataset in enumerate(datasets):
        if len(dataset) == 0:
            # Nothing to sample; an empty dataset cannot contradict the schema.
            continue

        sample, _ = dataset[0]
        found = list(sample.keys())
        found.sort()

        if expected is None:
            # First non-empty dataset defines the reference schema.
            expected, expected_from = found, position
            continue

        if found != expected:
            raise ValueError(
                "output_strict=True requires identical output keys (TensorDict keys) "
                f"across datasets: dataset {expected_from} has {expected}, dataset {position} has {found}"
            )

    if expected is not None:
        return list(expected)

    # Every dataset was empty: fall back to the first dataset's declared names.
    head = datasets[0]
    if not hasattr(head, "field_names"):
        return []
    return list(head.field_names)


@register()
class MultiDataset:
    r"""
    A dataset that composes multiple :class:`DatasetBase` instances behind one index space.

    Accepts both :class:`Dataset` (TensorDict pipelines) and :class:`MeshDataset`
    (Mesh pipelines) as sub-datasets. Global indices are mapped to
    (dataset_index, local_index) by concatenation: indices 0..len0-1 come from the
    first dataset, len0..len0+len1-1 from the second, and so on. Each constituent
    can have its own Reader and transforms. Metadata is enriched with
    ``dataset_index`` so batches can identify the source.

    Parameters
    ----------
    *datasets : DatasetBase
        One or more Dataset or MeshDataset instances passed as positional
        arguments (Reader + transforms each). Order defines index mapping:
        first dataset occupies 0..len(ds0)-1, etc.
    output_strict : bool, default=True
        If True, require all datasets to produce the same TensorDict keys
        (output keys after transforms) so :class:`DefaultCollator` can stack
        batches. If False, no check is done; use a custom collator when keys
        or shapes differ. Note that ``output_strict=True`` will load the first
        instance of all datasets upon construction. Think of it as a debugging
        parameter: if you are sure that your datasets are working properly, and
        want to defer loading, you can safely disable this.

    Raises
    ------
    ValueError
        If no datasets are provided or if ``output_strict=True`` and output
        keys differ.
    TypeError
        If any positional argument is not a :class:`DatasetBase` instance.

    Notes
    -----
    MultiDataset implements the same interface as :class:`Dataset`
    (``__len__``, ``__getitem__``, ``prefetch``, ``cancel_prefetch``, ``close``,
    ``field_names``) and can be passed to :class:`DataLoader` in place of a
    single dataset. Prefetch and close are delegated to the sub-dataset that
    owns the index.

    When ``output_strict=True``, validation checks that each dataset's *output*
    TensorDict (after transforms) has the same keys, not the reader's
    field_names.

    When ``output_strict=False``, :attr:`field_names` returns the first
    dataset's field names; with heterogeneous datasets, prefer a custom
    collator and use metadata ``dataset_index`` to group or pad by source.

    Shuffling and sampling
    ----------------------
    The DataLoader sees one linear index space of size
    :math:`\sum_k \text{len}(\text{datasets}[k])`. With ``shuffle=True``, the
    default :class:`RandomSampler` shuffles these global indices, so each batch
    is a random mix of samples from all sub-datasets. There is no per-dataset
    balancing: if one dataset is much larger, its samples will appear more
    often. For balanced or stratified sampling, use a custom
    :class:`torch.utils.data.Sampler` (e.g. weighted or one sample per dataset
    per batch) and pass it to the DataLoader.

    Metadata
    --------
    Every sample returned by :meth:`__getitem__` has its metadata dict extended
    with the key :const:`DATASET_INDEX_METADATA_KEY` (``"dataset_index"``), the
    integer index of the sub-dataset that produced the sample (0 for the first
    dataset, 1 for the second, etc.). Sub-dataset–specific metadata (e.g. file
    path, index within that dataset) is unchanged. When using the DataLoader
    with ``collate_metadata=True``, each batch yields a list of metadata dicts
    aligned with the batch dimension; each dict includes ``dataset_index`` so
    you can filter, weight, or aggregate by source in the training loop.

    Examples
    --------
    >>> from physicsnemo.datapipes import Dataset, MultiDataset, HDF5Reader, Normalize
    >>> ds_a = Dataset(HDF5Reader("a.h5", fields=["x", "y"]), transforms=None)  # doctest: +SKIP
    >>> ds_b = Dataset(HDF5Reader("b.h5", fields=["x", "y"]), transforms=None)  # doctest: +SKIP
    >>> multi = MultiDataset(ds_a, ds_b, output_strict=True)  # doctest: +SKIP
    >>> len(multi) == len(ds_a) + len(ds_b)  # doctest: +SKIP
    True
    >>> data, meta = multi[0]  # from ds_a  # doctest: +SKIP
    >>> meta["dataset_index"]  # 0  # doctest: +SKIP
    """

    def __init__(
        self,
        *datasets: DatasetBase,
        output_strict: bool = True,
    ) -> None:
        if not datasets:
            raise ValueError(
                f"MultiDataset requires at least one dataset, got {len(datasets)}"
            )
        for i, ds in enumerate(datasets):
            if not isinstance(ds, DatasetBase):
                raise TypeError(
                    f"datasets[{i}] must be a Dataset or MeshDataset instance, "
                    f"got {type(ds).__name__}"
                )
        self._datasets = list(datasets)
        self._output_strict = output_strict

        # Cumulative lengths: cumul[k] = sum(len(datasets[j]) for j in range(k)).
        # Index i is in dataset k when cumul[k] <= i < cumul[k+1], local = i - cumul[k].
        cumul = [0]
        for ds in self._datasets:
            cumul.append(cumul[-1] + len(ds))
        self._cumul = cumul

        if output_strict:
            # Loads sample 0 of every non-empty sub-dataset to compare output keys.
            self._field_names = _validate_strict_outputs(self._datasets)
        else:
            first = self._datasets[0]
            if hasattr(first, "field_names"):
                self._field_names = list(first.field_names)
            else:
                self._field_names = []

    def _index_to_dataset_and_local(self, index: int) -> tuple[int, int]:
        """
        Map global index to (dataset_index, local_index).

        Parameters
        ----------
        index : int
            Global index in [-len(self), len(self)). Negative values count
            from the end.

        Returns
        -------
        tuple[int, int]
            (dataset_index, local_index).

        Raises
        ------
        IndexError
            If index is out of range. The message reports the index exactly
            as it was passed by the caller.
        """
        mapped = self._index_to_dataset_and_local_optional(index)
        if mapped is None:
            raise IndexError(
                f"Index {index} out of range for MultiDataset with {len(self)} samples"
            )
        return mapped

    def _index_to_dataset_and_local_optional(
        self, index: int
    ) -> Optional[tuple[int, int]]:
        """
        Map global index to (dataset_index, local_index), or None if out of range.

        Used by cancel_prefetch to match Dataset behavior (no-op for invalid
        index), and by :meth:`_index_to_dataset_and_local` as the single
        lookup implementation.

        Parameters
        ----------
        index : int
            Global index. Negative values count from the end.

        Returns
        -------
        tuple[int, int] or None
            (dataset_index, local_index), or None when the index is outside
            [-len(self), len(self)).
        """
        n = len(self)
        position = index + n if index < 0 else index
        if position < 0 or position >= n:
            return None
        # bisect_right returns the insertion point *after* any entries equal to
        # ``position``, so k is the last k with cumul[k] <= position. Repeated
        # boundaries produced by empty sub-datasets are therefore skipped and
        # the owner is always a dataset with cumul[k] <= position < cumul[k+1].
        k = bisect.bisect_right(self._cumul, position) - 1
        return k, position - self._cumul[k]

    def __len__(self) -> int:
        """Return the total number of samples (sum of all sub-dataset lengths)."""
        return self._cumul[-1]

    # ------------------------------------------------------------------
    # RNG management
    # ------------------------------------------------------------------

    def set_generator(self, generator: torch.Generator) -> None:
        """Fork *generator* and distribute one child per sub-dataset.

        Sub-datasets that do not define ``set_generator`` are skipped.

        Parameters
        ----------
        generator : torch.Generator
            Parent generator (typically forked from the DataLoader's master
            generator).
        """
        children = fork_generator(generator, len(self._datasets))
        for child, ds in zip(children, self._datasets):
            if hasattr(ds, "set_generator"):
                ds.set_generator(child)

    def set_epoch(self, epoch: int) -> None:
        """Propagate epoch to every sub-dataset.

        Sub-datasets that do not define ``set_epoch`` are skipped.

        Parameters
        ----------
        epoch : int
            Current epoch number.
        """
        for ds in self._datasets:
            if hasattr(ds, "set_epoch"):
                ds.set_epoch(epoch)

    def __getitem__(self, index: int) -> tuple[TensorDict, dict[str, Any]]:
        """
        Return the transformed sample and metadata for the given global index.

        Metadata is enriched with ``dataset_index``
        (key :const:`DATASET_INDEX_METADATA_KEY`).

        Parameters
        ----------
        index : int
            Global sample index. Supports negative indexing.

        Returns
        -------
        tuple[TensorDict, dict[str, Any]]
            (TensorDict, metadata dict) from the owning sub-dataset.

        Raises
        ------
        IndexError
            If index is out of range.
        """
        ds_id, local_i = self._index_to_dataset_and_local(index)
        data, metadata = self._datasets[ds_id][local_i]
        # Copy so the sub-dataset's own metadata dict is not mutated.
        metadata = dict(metadata)
        metadata[DATASET_INDEX_METADATA_KEY] = ds_id
        return data, metadata

    def prefetch(
        self,
        index: int,
        stream: Optional[Any] = None,
    ) -> None:
        """
        Start prefetching the sample at the given global index.

        Delegates to the sub-dataset that owns that index.

        Parameters
        ----------
        index : int
            Global sample index to prefetch.
        stream : object, optional
            Optional CUDA stream for the sub-dataset prefetch.

        Raises
        ------
        IndexError
            If index is out of range.
        """
        ds_id, local_i = self._index_to_dataset_and_local(index)
        self._datasets[ds_id].prefetch(local_i, stream=stream)

    def cancel_prefetch(self, index: Optional[int] = None) -> None:
        """
        Cancel prefetch for the given index or all sub-datasets.

        When index is provided, only cancels if it is in range; out-of-range
        indices are ignored to match :class:`Dataset` behavior.

        Parameters
        ----------
        index : int, optional
            Global index to cancel, or None to cancel all.
        """
        if index is None:
            for ds in self._datasets:
                ds.cancel_prefetch(None)
            return
        mapped = self._index_to_dataset_and_local_optional(index)
        if mapped is not None:
            ds_id, local_i = mapped
            self._datasets[ds_id].cancel_prefetch(local_i)

    def __iter__(self) -> Iterator[tuple[TensorDict, dict[str, Any]]]:
        """Iterate over all samples in global index order."""
        for i in range(len(self)):
            yield self[i]

    @property
    def field_names(self) -> list[str]:
        """
        Field names in samples.

        With ``output_strict=True``, returns the common output keys (TensorDict
        keys after transforms). With ``output_strict=False``, returns the first
        dataset's field names. A fresh list is returned on every access.
        """
        return list(self._field_names)

    def close(self) -> None:
        """
        Close all sub-datasets and release resources.

        Every sub-dataset is closed even if an earlier one fails, so a single
        failing ``close`` does not leave the remaining sub-datasets open.

        Raises
        ------
        Exception
            The first exception raised by a sub-dataset's ``close``, re-raised
            after all sub-datasets have been given the chance to close.
        """
        first_error: Optional[Exception] = None
        for ds in self._datasets:
            try:
                ds.close()
            except Exception as exc:  # keep closing the remaining datasets
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error

    def __enter__(self) -> "MultiDataset":
        """Context manager entry."""
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Context manager exit."""
        self.close()

    def __repr__(self) -> str:
        """Return a multi-line summary listing each sub-dataset's repr."""
        parts = [f" ({i}): {ds}" for i, ds in enumerate(self._datasets)]
        return (
            f"MultiDataset(\n output_strict={self._output_strict},\n datasets=[\n"
            + ",\n".join(parts)
            + "\n ]\n)"
        )