# Source code for physicsnemo.datapipes.readers.zarr

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
ZarrReader - Read data from Zarr arrays.

Supports reading from a directory of Zarr groups, one sample per group.
"""

from __future__ import annotations

from pathlib import Path
from typing import Any, Optional

import numpy as np
import torch

from physicsnemo.core.version_check import OptionalImport
from physicsnemo.datapipes.readers.base import Reader
from physicsnemo.datapipes.registry import register

zarr = OptionalImport("zarr")


@register()
class ZarrReader(Reader):
    """
    Read samples from Zarr groups.

    Zarr is a chunked, compressed array format ideal for large scientific
    datasets. Each Zarr group in the directory represents one sample.
    Supports loading both arrays and attributes from Zarr groups.

    Examples
    --------
    Basic usage:

    >>> # Directory with sample_0.zarr, sample_1.zarr, ...
    >>> # Each contains arrays like "positions", "features", etc.
    >>> reader = ZarrReader("data_dir/", group_pattern="sample_*.zarr")  # doctest: +SKIP
    >>> data, metadata = reader[0]  # Returns (TensorDict, dict) tuple  # doctest: +SKIP

    Load only specific fields:

    >>> reader = ZarrReader("data_dir/", fields=["positions", "velocity"])  # doctest: +SKIP
    >>> data, metadata = reader[0]  # doctest: +SKIP

    Load attributes from Zarr groups:

    >>> # If the Zarr group has attributes like "timestep" or "scale_factor",
    >>> # you can request them as fields:
    >>> reader = ZarrReader("data_dir/", fields=["positions", "timestep", "scale_factor"])  # doctest: +SKIP
    >>> data, metadata = reader[0]  # data["timestep"] contains the attribute value  # doctest: +SKIP

    With coordinated subsampling for large arrays:

    >>> reader = ZarrReader(  # doctest: +SKIP
    ...     "data_dir/",
    ...     coordinated_subsampling={
    ...         "n_points": 50000,
    ...         "target_keys": ["volume_coords", "volume_fields"],
    ...     }
    ... )
    >>> data, metadata = reader[0]  # doctest: +SKIP
    """

    def __init__(
        self,
        path: str | Path,
        *,
        fields: Optional[list[str]] = None,
        default_values: Optional[dict[str, torch.Tensor]] = None,
        group_pattern: str = "*.zarr",
        pin_memory: bool = False,
        include_index_in_metadata: bool = True,
        coordinated_subsampling: Optional[dict[str, Any]] = None,
        cache_stores: bool = True,
    ) -> None:
        """
        Initialize the Zarr reader.

        Parameters
        ----------
        path : str or Path
            Path to directory containing Zarr groups.
        fields : list[str], optional
            List of array or attribute names to load. If None, loads all
            available arrays from each group. When a field name matches an
            attribute key (and not an array), the attribute value will be
            converted to a tensor. Note: string attributes are not supported.
        default_values : dict[str, torch.Tensor], optional
            Dictionary mapping field names to default tensors. If a field in
            ``fields`` is not found in the file but has an entry here, the
            default tensor is used instead of raising an error. Useful for
            optional fields.
        group_pattern : str, default="*.zarr"
            Glob pattern for finding Zarr groups.
        pin_memory : bool, default=False
            If True, place tensors in pinned memory for faster GPU transfer.
        include_index_in_metadata : bool, default=True
            If True, include sample index in metadata.
        coordinated_subsampling : dict[str, Any], optional
            Optional dict to configure coordinated subsampling. If provided,
            must contain ``n_points`` (int) and ``target_keys`` (list of str).
        cache_stores : bool, default=True
            If True, cache opened zarr stores to avoid repeated opening and
            prevent executor shutdown errors. Set to False if memory is a
            concern with many groups.

        Raises
        ------
        ImportError
            If zarr is not installed.
        FileNotFoundError
            If path doesn't exist.
        ValueError
            If no Zarr groups found in directory.
        """
        if not zarr.available:
            # NOTE(review): the docstring advertises ImportError, but this
            # comment says RuntimeError — confirm which OptionalImport raises.
            zarr._get_module()  # Raises RuntimeError with install hint

        super().__init__(
            pin_memory=pin_memory,
            include_index_in_metadata=include_index_in_metadata,
            coordinated_subsampling=coordinated_subsampling,
        )

        self.path = Path(path).expanduser().resolve()
        self._user_fields = fields
        self.default_values = default_values or {}
        self.group_pattern = group_pattern
        self._cache_stores = cache_stores
        self._cached_stores: dict[Path, Any] = {}  # Cache for opened zarr stores

        if not self.path.exists():
            raise FileNotFoundError(f"Path not found: {self.path}")
        if not self.path.is_dir():
            raise ValueError(
                f"Path must be a directory containing Zarr groups: {self.path}"
            )

        # Detect mode: single group or directory of groups
        self._single_group_mode = self._is_zarr_group(self.path)

        if self._single_group_mode:
            # Single Zarr group - samples indexed along first dimension
            self._groups = [self.path]
            root = zarr.open(self.path, mode="r")
            if isinstance(root, zarr.Array):
                raise ValueError(
                    f"Expected Zarr group with named arrays, got single array at "
                    f"{self.path}. Path should be a Zarr group containing named arrays."
                )
            self._available_fields = list(root.array_keys())
            # Get length from first array's first dimension
            if not self._available_fields:
                raise ValueError(f"Zarr group {self.path} contains no arrays")
            first_array = root[self._available_fields[0]]
            self._length = first_array.shape[0]
        else:
            # Directory containing multiple Zarr groups
            self._groups = sorted(
                [p for p in self.path.glob(group_pattern) if self._is_zarr_group(p)]
            )
            if not self._groups:
                raise ValueError(
                    f"No Zarr groups matching '{group_pattern}' found in {self.path}"
                )
            self._length = len(self._groups)

            # Discover available fields from first group
            root = zarr.open(self._groups[0], mode="r")
            if isinstance(root, zarr.Array):
                raise ValueError(
                    f"Expected Zarr group with named arrays, got single array at "
                    f"{self._groups[0]}. Each sample should be a Zarr group containing "
                    f"named arrays."
                )
            self._available_fields = list(root.array_keys())

    @property
    def fields(self) -> list[str]:
        """Fields that will be loaded (user-specified or all available)."""
        if self._user_fields is not None:
            return self._user_fields
        return self._available_fields

    def _open_zarr_store(self, path: Path) -> Any:
        """
        Open a zarr store, using cache if enabled.

        This prevents the "cannot schedule new futures after shutdown" error
        by reusing opened stores instead of repeatedly calling zarr.open().

        Parameters
        ----------
        path : Path
            Path to the zarr group.

        Returns
        -------
        Any
            Opened zarr group.
        """
        if self._cache_stores:
            if path not in self._cached_stores:
                self._cached_stores[path] = zarr.open(path, mode="r")
            return self._cached_stores[path]
        else:
            return zarr.open(path, mode="r")

    def _is_zarr_group(self, path: Path) -> bool:
        """
        Check if a path is a Zarr group.

        A Zarr group is identified by the presence of a zarr.json file (v3)
        or .zgroup file (v2).
        """
        return (path / "zarr.json").exists() or (path / ".zgroup").exists()

    def _select_random_sections_from_slice(
        self,
        slice_start: int,
        slice_stop: int,
        n_points: int,
    ) -> slice:
        """
        Select a random contiguous slice from a range.

        Parameters
        ----------
        slice_start : int
            Start index of the available range.
        slice_stop : int
            Stop index of the available range (exclusive).
        n_points : int
            Number of points to sample.

        Returns
        -------
        slice
            A slice object representing the random contiguous section.

        Raises
        ------
        ValueError
            If the range is smaller than n_points.
        """
        total_points = slice_stop - slice_start
        if total_points < n_points:
            raise ValueError(
                f"Slice size {total_points} is less than the number of points "
                f"{n_points} requested for subsampling"
            )
        # +1 so a start of exactly slice_stop - n_points is reachable
        start = np.random.randint(slice_start, slice_stop - n_points + 1)
        return slice(start, start + n_points)

    def _load_sample(self, index: int) -> dict[str, torch.Tensor]:
        """Load a single sample from a Zarr group."""
        if self._single_group_mode:
            # Single group: index into first dimension of each array
            group_path = self._groups[0]
            root = self._open_zarr_store(group_path)
        else:
            # Directory mode: each group is one sample
            group_path = self._groups[index]
            root = self._open_zarr_store(group_path)

        data = {}
        fields_to_load = self.fields

        # Discover available arrays and attributes for this sample at runtime
        available_arrays = set(root.array_keys())
        available_attrs = set(root.attrs.keys()) if hasattr(root, "attrs") else set()
        available = available_arrays | available_attrs

        # Check for missing required fields (check both arrays and attributes)
        required_fields = set(fields_to_load) - set(self.default_values.keys())
        missing_fields = required_fields - available
        if missing_fields:
            raise KeyError(
                f"Required fields {missing_fields} not found in {group_path}. "
                f"Available arrays: {list(available_arrays)}, "
                f"Available attributes: {list(available_attrs)}"
            )

        # Determine subsample slice if coordinated subsampling is enabled
        subsample_slice = None
        target_keys_set = set()
        if self._coordinated_subsampling_config is not None:
            n_points = self._coordinated_subsampling_config["n_points"]
            target_keys_set = set(self._coordinated_subsampling_config["target_keys"])
            # Find slice from first available target key
            # NOTE(review): assumes all target keys share the same length along
            # the subsampled axis — confirm against dataset layout.
            for field in target_keys_set:
                if field in root:
                    if self._single_group_mode:
                        # In single group mode, subsample along dimensions after the first
                        array_shape = root[field].shape[1]
                    else:
                        array_shape = root[field].shape[0]
                    subsample_slice = self._select_random_sections_from_slice(
                        0, array_shape, n_points
                    )
                    break

        # Load each field
        for field in fields_to_load:
            if field in root:
                if self._single_group_mode:
                    # Single group mode: index into first dimension
                    if subsample_slice is not None and field in target_keys_set:
                        # Apply subsampling on dimensions after the first
                        data[field] = torch.from_numpy(
                            root[field][index, subsample_slice]
                        )
                    else:
                        data[field] = torch.from_numpy(root[field][index])
                else:
                    # Directory mode: load entire array or subsample
                    if subsample_slice is not None and field in target_keys_set:
                        data[field] = torch.from_numpy(root[field][subsample_slice])
                    else:
                        data[field] = torch.from_numpy(root[field][:])
            elif field in available_attrs:
                # Load from attributes (discovered at runtime for this sample)
                attr_value = root.attrs[field]
                data[field] = self._convert_attr_to_tensor(attr_value, field)
            elif field in self.default_values:
                # Clone so callers can mutate without corrupting the default
                data[field] = self.default_values[field].clone()

        return data

    def _convert_attr_to_tensor(self, value: Any, field_name: str) -> torch.Tensor:
        """
        Convert an attribute value to a torch.Tensor.

        Parameters
        ----------
        value : Any
            The attribute value to convert.
        field_name : str
            Name of the field (for error messages).

        Returns
        -------
        torch.Tensor
            A tensor containing the attribute value.

        Raises
        ------
        TypeError
            If the attribute value cannot be converted to a tensor.
        """
        try:
            match value:
                case np.ndarray():
                    return torch.from_numpy(value)
                case list() | tuple():
                    return torch.tensor(value)
                case int() | float() | bool():
                    return torch.tensor(value)
                case str():
                    raise TypeError(
                        f"Cannot convert string attribute '{field_name}' to tensor. "
                        f"String attributes are not supported."
                    )
                case _:
                    # Try to convert via numpy
                    return torch.from_numpy(np.asarray(value))
        except (TypeError, ValueError) as e:
            raise TypeError(
                f"Cannot convert attribute '{field_name}' of type {type(value).__name__} "
                f"to tensor: {e}"
            ) from e

    def __len__(self) -> int:
        """Return number of samples."""
        return self._length

    def _get_field_names(self) -> list[str]:
        """Return field names that will be loaded."""
        return self.fields

    def _get_sample_metadata(self, index: int) -> dict[str, Any]:
        """Return metadata for a sample including source info."""
        if self._single_group_mode:
            return {
                "source_file": str(self._groups[0]),
                "source_filename": self._groups[0].name,
                "sample_index": index,
            }
        else:
            return {
                "source_file": str(self._groups[index]),
                "source_filename": self._groups[index].name,
            }

    @property
    def _supports_coordinated_subsampling(self) -> bool:
        """Zarr reader supports coordinated subsampling."""
        return True

    def close(self) -> None:
        """Close resources and cached zarr stores."""
        # Clear cached stores to allow garbage collection
        # This helps prevent executor shutdown issues
        self._cached_stores.clear()
        super().close()

    def __repr__(self) -> str:
        subsample_info = ""
        if self._coordinated_subsampling_config is not None:
            cfg = self._coordinated_subsampling_config
            subsample_info = f", subsampling={cfg['n_points']} points"
        return (
            f"ZarrReader("
            f"path={self.path}, "
            f"len={len(self)}, "
            f"fields={self.fields}, "
            f"cache_stores={self._cache_stores}"
            f"{subsample_info})"
        )