Source code for physicsnemo.datapipes.readers.vtk

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
VTKReader - Read data from VTK format files (.stl, .vtp, .vtu).

Supports reading mesh data from directories containing VTK files using PyVista.
"""

from __future__ import annotations

import importlib
from pathlib import Path
from typing import Any, Optional

import numpy as np
import torch

from physicsnemo.core.version_check import check_version_spec
from physicsnemo.datapipes.readers.base import Reader
from physicsnemo.datapipes.registry import register

# Probe for the optional PyVista dependency without making the import of this
# module fail. VTKReader.__init__ checks this flag and raises ImportError with
# installation instructions when it is falsy.
# NOTE(review): hard_fail=False presumably makes check_version_spec return a
# falsy value instead of raising when pyvista is missing - confirm against
# physicsnemo.core.version_check.
PYVISTA_AVAILABLE = check_version_spec("pyvista", hard_fail=False)

# Import by name only when the probe succeeded; otherwise ``pv`` is left
# unbound at module level, so it must not be touched unless the flag is truthy.
if PYVISTA_AVAILABLE:
    pv = importlib.import_module("pyvista")


@register()
class VTKReader(Reader):
    r"""
    Read samples from VTK format files (.stl, .vtp, .vtu).

    This reader loads mesh data from directories where each subdirectory
    contains VTK files representing one sample. Supports STL (surface meshes),
    VTP (PolyData), and VTU (UnstructuredGrid) formats.

    Requires PyVista to be installed. If PyVista is not available, attempting
    to instantiate this reader will raise an ImportError with installation
    instructions.

    Examples
    --------
    >>> # Directory structure:
    >>> # data/
    >>> #   sample_0/
    >>> #     geometry.stl
    >>> #     surface.vtp
    >>> #   sample_1/
    >>> #     geometry.stl
    >>> #     surface.vtp
    >>> #   ...
    >>>
    >>> reader = VTKReader(  # doctest: +SKIP
    ...     "data/",
    ...     keys_to_read=["stl_coordinates", "stl_faces", "surface_normals"],
    ... )
    >>> data, metadata = reader[0]  # Returns (TensorDict, dict) tuple  # doctest: +SKIP
    >>> print(data["stl_coordinates"].shape)  # (N, 3)  # doctest: +SKIP

    Available Keys:
        From .stl files:
            - ``stl_coordinates``: Vertex coordinates, shape :math:`(N, 3)`
            - ``stl_faces``: Face indices (flattened), shape :math:`(M*3,)`
            - ``stl_centers``: Face centers, shape :math:`(M, 3)`
            - ``stl_areas``: Face areas, shape :math:`(M,)`
            - ``surface_normals``: Face normals, shape :math:`(M, 3)`

        From .vtp files:
            - ``surface_mesh_centers``: Cell centers
            - ``surface_normals``: Cell normals
            - ``surface_mesh_sizes``: Cell areas
            - Additional fields from the VTP file

    Note:
        VTK files are typically small enough to fit in memory, so coordinated
        subsampling is not supported. Use transforms for downsampling if needed.

    Note:
        Only STL reading is implemented at present. Requesting a key that can
        only come from a ``.vtp`` file (when such a file exists in the sample
        directory) or from a ``.vtu`` file raises ``NotImplementedError``.
        ``surface_normals`` is served from the STL file whenever one is
        available.
    """

    def __init__(
        self,
        path: str | Path,
        *,
        keys_to_read: Optional[list[str]] = None,
        exclude_patterns: Optional[list[str]] = None,
        pin_memory: bool = False,
        include_index_in_metadata: bool = True,
    ) -> None:
        """
        Initialize the VTK reader.

        Parameters
        ----------
        path : str or Path
            Path to directory containing subdirectories with VTK files.
        keys_to_read : list[str], optional
            List of keys to extract from VTK files. If None, extracts all
            available data.
        exclude_patterns : list[str], optional
            List of filename patterns to exclude (e.g., ["single_solid"]).
            A file is skipped when any pattern is a substring of its name.
            If None or empty, ``["single_solid"]`` is used.
        pin_memory : bool, default=False
            If True, place tensors in pinned memory for faster GPU transfer.
        include_index_in_metadata : bool, default=True
            If True, include sample index in metadata.

        Raises
        ------
        ImportError
            If PyVista is not installed.
        FileNotFoundError
            If path doesn't exist.
        ValueError
            If path is not a directory, or no valid VTK directories found.
        """
        if not PYVISTA_AVAILABLE:
            raise ImportError(
                "PyVista is required for VTKReader but is not installed.\n"
                "Install it with: pip install pyvista\n"
                "See https://docs.pyvista.org/getting-started/installation.html "
                "for more information."
            )

        super().__init__(
            pin_memory=pin_memory,
            include_index_in_metadata=include_index_in_metadata,
        )

        self.path = Path(path)
        self.keys_to_read = keys_to_read
        # An empty list is treated like None and falls back to the default.
        self.exclude_patterns = exclude_patterns or ["single_solid"]

        if not self.path.exists():
            raise FileNotFoundError(f"Path not found: {self.path}")

        if not self.path.is_dir():
            raise ValueError(f"Path must be a directory: {self.path}")

        # Find all subdirectories containing VTK files. Sorting gives every
        # sample a stable index regardless of filesystem enumeration order.
        self._directories = sorted(
            subdir
            for subdir in self.path.iterdir()
            if subdir.is_dir() and self._is_vtk_directory(subdir)
        )

        if not self._directories:
            raise ValueError(
                f"No directories containing VTK files found in {self.path}"
            )

        self._length = len(self._directories)

        # Supported keys grouped by the file type that can provide them.
        # "surface_normals" is intentionally listed for both STL and VTP.
        self._stl_keys = {
            "stl_coordinates",
            "stl_centers",
            "stl_faces",
            "stl_areas",
            "surface_normals",
        }
        self._vtp_keys = {
            "surface_mesh_centers",
            "surface_normals",
            "surface_mesh_sizes",
            "CpMeanTrim",
            "pMeanTrim",
            "wallShearStressMeanTrim",
        }
        self._vtu_keys = {
            "volume_mesh_centers",
            "volume_fields",
        }

    def _is_vtk_directory(self, directory: Path) -> bool:
        """Check if a directory contains VTK files."""
        vtk_extensions = {".stl", ".vtp", ".vtu", ".vtk"}
        return any(file.suffix in vtk_extensions for file in directory.iterdir())

    def _get_file_by_extension(self, directory: Path, extension: str) -> Optional[Path]:
        """
        Get the first file with the given extension, excluding patterns.

        Candidates are visited in sorted order so that the selected file is
        deterministic when several files share the extension
        (``Path.iterdir`` yields entries in arbitrary order).

        Parameters
        ----------
        directory : Path
            Sample directory to search (non-recursive).
        extension : str
            File suffix including the leading dot, e.g. ``".stl"``.

        Returns
        -------
        Path or None
            The first matching file whose name contains none of
            ``self.exclude_patterns``, or None if there is no such file.
        """
        for file in sorted(directory.iterdir()):
            if file.suffix != extension:
                continue
            # Skip files whose name contains any of the exclude patterns
            if any(pattern in file.name for pattern in self.exclude_patterns):
                continue
            return file
        return None

    def _read_stl_data(self, stl_path: Path) -> dict[str, torch.Tensor]:
        """
        Read data from an STL file.

        Parameters
        ----------
        stl_path : Path
            Path to the STL file.

        Returns
        -------
        dict[str, torch.Tensor]
            Tensors under the keys ``stl_faces``, ``stl_coordinates``,
            ``surface_normals``, ``stl_centers`` and ``stl_areas``.
        """
        mesh = pv.read(stl_path)

        data = {}

        # The flat face array stores each triangle as [3, i0, i1, i2];
        # reshape to rows of four and drop the leading vertex count.
        faces = mesh.faces.reshape(-1, 4)
        faces = faces[:, 1:]
        data["stl_faces"] = torch.from_numpy(faces.flatten())

        # Extract coordinates
        vertices = mesh.points
        data["stl_coordinates"] = torch.from_numpy(vertices)

        # Extract normals
        data["surface_normals"] = torch.from_numpy(mesh.cell_normals)

        # Face centers: mean of the three vertices of each triangle
        face_centers = vertices[faces].mean(axis=1)
        data["stl_centers"] = torch.from_numpy(face_centers)

        # Face areas: area of triangle = 0.5 * ||cross(v1-v0, v2-v0)||
        v0 = vertices[faces[:, 0]]
        v1 = vertices[faces[:, 1]]
        v2 = vertices[faces[:, 2]]
        cross_prod = np.cross(v1 - v0, v2 - v0)
        areas = 0.5 * np.linalg.norm(cross_prod, axis=1)
        data["stl_areas"] = torch.from_numpy(areas)

        return data

    def _read_vtp_data(self, vtp_path: Path) -> dict[str, torch.Tensor]:
        """
        Read data from a VTP file.

        Raises
        ------
        NotImplementedError
            Always; VTP reading is not implemented yet.
        """
        # VTP reading is not yet implemented in the original cae_dataset.py
        # Placeholder for future implementation
        raise NotImplementedError(
            "VTP file reading is not yet implemented. "
            "This will be added in a future update."
        )

    def _load_sample(self, index: int) -> dict[str, torch.Tensor]:
        """
        Load a single sample from a VTK directory.

        STL data is read first. VTP/VTU sources are only consulted for
        requested keys that the STL data did not already provide, so a key
        available from several file types (``surface_normals``) does not
        trigger the unimplemented VTP reader when an STL file can serve it.

        Parameters
        ----------
        index : int
            Position of the sample in the sorted list of sample directories.

        Returns
        -------
        dict[str, torch.Tensor]
            Loaded tensors, restricted to ``keys_to_read`` when it is set.
            Requested keys whose source file is absent are omitted.

        Raises
        ------
        NotImplementedError
            If a requested key still requires a ``.vtp`` file that exists in
            the directory, or requires a ``.vtu`` file.
        """
        directory = self._directories[index]
        requested = None if self.keys_to_read is None else set(self.keys_to_read)
        result = {}

        # Read STL data when everything is wanted or an STL key is requested
        if requested is None or requested & self._stl_keys:
            stl_path = self._get_file_by_extension(directory, ".stl")
            if stl_path:
                result.update(self._read_stl_data(stl_path))

        # With no explicit key list only STL data is loaded
        if requested is None:
            return result

        # Only the keys that are still unsatisfied decide whether further
        # file types have to be read.
        missing = {key for key in requested if key not in result}

        # Read VTP data if needed
        if missing & self._vtp_keys:
            vtp_path = self._get_file_by_extension(directory, ".vtp")
            if vtp_path:
                result.update(self._read_vtp_data(vtp_path))

        # Read VTU data if needed
        if missing & self._vtu_keys:
            raise NotImplementedError("VTU file reading is not yet implemented.")

        # Filter to requested keys
        return {k: v for k, v in result.items() if k in requested}

    def __len__(self) -> int:
        """Return number of samples."""
        return self._length

    def _get_field_names(self) -> list[str]:
        """
        Return field names.

        Returns a copy of ``keys_to_read`` when it is set, so callers cannot
        mutate the reader's configuration through the returned list. Otherwise
        the first sample is loaded to discover the available keys.
        """
        if self.keys_to_read is not None:
            return list(self.keys_to_read)

        # Load first sample to discover available keys
        if len(self) == 0:
            return []

        sample = self._load_sample(0)
        return list(sample.keys())

    def _get_sample_metadata(self, index: int) -> dict[str, Any]:
        """Return metadata for a sample including source directory info."""
        directory = self._directories[index]
        return {
            "source_file": str(directory),
            "source_filename": directory.name,
        }

    @property
    def _supports_coordinated_subsampling(self) -> bool:
        """VTK files don't support coordinated subsampling."""
        return False

    def __repr__(self) -> str:
        return f"VTKReader(path={self.path}, len={len(self)}, keys={self.keys_to_read})"