Source code for physicsnemo.datapipes.gnn.drivaernet_dataset

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from pathlib import Path
from typing import Iterable

import pandas as pd
import torch
import yaml
from torch import Tensor
from torch.utils.data import Dataset

from physicsnemo.core.version_check import OptionalImport
from physicsnemo.datapipes.datapipe import Datapipe
from physicsnemo.datapipes.meta import DatapipeMetaData
from physicsnemo.nn.module.gnn_layers.utils import PyGData

# Lazy imports for optional dependencies
pyg = OptionalImport("torch_geometric")
pv = OptionalImport("pyvista")
vtk = OptionalImport("vtk")


[docs] @dataclass class MetaData(DatapipeMetaData): name: str = "DrivAerNet" # Optimization auto_device: bool = True cuda_graphs: bool = False # Parallel ddp_sharding: bool = True
[docs] class DrivAerNetDataset(Dataset, Datapipe): """ DrivAerNet dataset. Note: DrivAerNetDataset caches graphs in __getitem__ call which helps to avoid long initialization delay but increases first epoch time. Parameters ---------- data_dir: str The directory where the data is stored. split: str, optional The dataset split. Can be 'train', 'validation', or 'test', by default 'train'. num_samples: int, optional The number of samples to use, by default 10. coeff_filename: str, optional DrivAerNet coefficients file name, default is from the dataset location. invar_keys: Iterable[str], optional The input node features to consider. Default includes 'pos'. outvar_keys: Iterable[str], optional The output features to consider. Default includes 'p' and 'wallShearStress'. normalize_keys Iterable[str], optional The features to normalize. Default includes 'p' and 'wallShearStress'. cache_dir: str, optional Path to the cache directory to store graphs in PyG format for fast loading. Default is ./cache/. name: str, optional The name of the dataset, by default 'dataset'. force_reload: bool, optional If True, forces a reload of the cached data, by default False. """ def __init__( self, data_dir: str | Path, split: str = "train", num_samples: int = 10, coeff_filename: str = "AeroCoefficients_DrivAerNet_FilteredCorrected.csv", invar_keys: Iterable[str] = ("pos",), outvar_keys: Iterable[str] = ("p", "wallShearStress"), normalize_keys: Iterable[str] = ("p", "wallShearStress"), cache_dir: str | Path = "./cache/", name: str = "dataset", force_reload: bool = False, **kwargs, ) -> None: Datapipe.__init__(self, meta=MetaData()) self.name = name self.data_dir = Path(data_dir) if not self.data_dir.is_dir(): raise ValueError( f"Path {self.data_dir} does not exist or is not a directory." ) self.p_vtk_dir = self.data_dir / "SurfacePressureVTK" self.wss_vtk_dir = self.data_dir / "WallShearStressVTK" self.split = split.lower() if split not in (splits := ["train", "val", "test"]): raise ValueError(f"{split = } is not supported, must be one of {splits}.") self.force_reload = force_reload self.num_samples = num_samples self.input_keys = list(invar_keys) self.output_keys = list(outvar_keys) self.normalize_keys = list(normalize_keys) self.cache_dir = ( self._get_cache_dir(self.data_dir, Path(cache_dir)) if cache_dir is not None else None ) # Load split design ids used to select a corresponding data split. design_ids = pd.read_csv( self.data_dir / f"{split}_design_ids.txt", header=None, index_col=0 ) # Read coefficients file which contains Cd, Cl etc. coeffs = pd.read_csv(self.data_dir / coeff_filename, index_col="Design") coeffs = coeffs.join(design_ids, how="inner") # Read projected areas file which is in YAML-like format with entries that look like: # combined_DrivAer_F_D_WM_WW_1234.stl: 2.574603830871618 with open(self.data_dir / "projected_areas.txt", encoding="utf-8") as f: y = yaml.safe_load(f) proj_areas = pd.DataFrame.from_dict( {k.removeprefix("combined_").removesuffix(".stl"): v for k, v in y.items()}, orient="index", columns=["proj_area_x"], ) # TODO(akamenev): # DrivAerNet issue #1: there are 10 entries missing in # projected_areas.txt: # train: DrivAer_F_D_WM_WW_0132, 0797, 1118, 1421, 1556, 1891, 2353, 2459. # val: DrivAer_F_D_WM_WW_0603, 3199. # # DrivAerNet issue #2: there are 2 entries for which WSS vtk files are empty. # # Filter both of them out (can do it via join but this is more explicit). missing_ids = { "DrivAer_F_D_WM_WW_0132", "DrivAer_F_D_WM_WW_0603", "DrivAer_F_D_WM_WW_0797", "DrivAer_F_D_WM_WW_1118", "DrivAer_F_D_WM_WW_1421", "DrivAer_F_D_WM_WW_1556", "DrivAer_F_D_WM_WW_1891", "DrivAer_F_D_WM_WW_2353", "DrivAer_F_D_WM_WW_2459", "DrivAer_F_D_WM_WW_3199", } empty_wss = { "DrivAer_F_D_WM_WW_0978", "DrivAer_F_D_WM_WW_3641", } coeffs = coeffs.drop(missing_ids | empty_wss, errors="ignore") # Merge projected areas into the coeffs dataframe. coeffs = coeffs.join(proj_areas, how="inner") if self.num_samples > len(coeffs): raise ValueError( f"Number of available {self.split} dataset entries " f"({len(coeffs)}) is less than the number of samples " f"({self.num_samples})" ) coeffs.sort_index(inplace=True) self.coeffs = coeffs.iloc[: self.num_samples] # TODO(akamenev): these are estimates from small sample, need to compute from full data. self.nstats = { k: {"mean": v[0], "std": v[1]} for k, v in { "p": (-94.50448, 117.25317), "wallShearStress": ( torch.tensor([-0.56926626, 0.0027714, -0.07354721]), torch.tensor([0.82198745, 0.45956784, 0.7490267]), ), }.items() } self.estats = { "x": { "mean": torch.tensor([0, 0, 0, 0.01338306]), "std": torch.tensor([0.00512953, 0.00953013, 0.00923065, 0.00482016]), } } def __len__(self) -> int: return len(self.coeffs) def __getitem__(self, idx: int) -> PyGData: if not 0 <= idx < len(self): raise IndexError(f"Invalid {idx = }, must be in [0, {len(self)})") coeffs = self.coeffs.iloc[idx] gname = coeffs.name if self.cache_dir is None: # Caching is disabled - create the graph. graph = self._create_graph(gname) else: cached_graph_filename = self.cache_dir / (gname + ".pt") if not self.force_reload and cached_graph_filename.is_file(): graph = torch.load(cached_graph_filename, weights_only=False) else: graph = self._create_graph(gname) Path.mkdir(self.cache_dir, parents=True, exist_ok=True) torch.save(graph, cached_graph_filename) # Set graph inputs/outputs. graph.x = torch.cat([graph[k] for k in self.input_keys], dim=-1) graph.y = torch.cat([graph[k] for k in self.output_keys], dim=-1) return { "name": gname, "graph": graph, "c_d": torch.tensor(coeffs["Average Cd"], dtype=torch.float32), } @staticmethod def _get_cache_dir(data_dir, cache_dir): if not cache_dir.is_absolute(): cache_dir = data_dir / cache_dir return cache_dir.resolve() def _create_graph( self, name: str, to_bidirected: bool = True, ) -> PyGData: """Creates a PyG graph from DrivAerNet VTK data. Parameters ---------- name : str Name of the graph in DrivAerNet. to_bidirected : bool, optional Whether to make the graph bidirected. Default is True. Returns ------- PyGData The PyG graph. """ def extract_edges(mesh: "pv.PolyData") -> list[tuple[int, int]]: # Extract connectivity information from the mesh. # Traversal API is faster comparing to iterating over mesh.cell. polys = mesh.GetPolys() if polys is None: raise ValueError("Failed to get polygons from the mesh.") polys.InitTraversal() edge_list = [] for _ in range(polys.GetNumberOfCells()): id_list = vtk.vtkIdList() polys.GetNextCell(id_list) num_ids = id_list.GetNumberOfIds() for j in range(num_ids - 1): edge_list.append( # noqa: PERF401 (id_list.GetId(j), id_list.GetId(j + 1)) ) # Add the final edge between the last and the first vertices. edge_list.append((id_list.GetId(num_ids - 1), id_list.GetId(0))) return edge_list def permute_mesh(p_vtk_path: Path, wss_vtk_path: Path) -> Tensor: # The issue with DrivAerNet dataset is pressure and WSS meshes # are stored in different files. Even though each file contains # the same mesh coordinates, the nodes are permuted (order does not match) # which makes it impossible to do simple point_data assignment. # This method permutes WSS mesh by using vtkProbeFilter. p_reader = vtk.vtkPolyDataReader() p_reader.SetFileName(p_vtk_path) p_reader.Update() p_out = p_reader.GetOutput() wss_reader = vtk.vtkPolyDataReader() wss_reader.SetFileName(wss_vtk_path) wss_reader.Update() wss_out = wss_reader.GetOutput() probe = vtk.vtkProbeFilter() # p mesh is the input for which corresponding values from # wss mesh are retrieved. probe.SetInputData(p_out) probe.SetSourceData(wss_out) probe.Update() probe_out = probe.GetOutput() wss_arr = probe_out.GetPointData().GetArray("wallShearStress") num_points = p_out.GetNumberOfPoints() wss = torch.empty((num_points, 3), dtype=torch.float32) for i in range(num_points): x, y, z = wss_arr.GetTuple3(i) wss[i, 0] = x wss[i, 1] = y wss[i, 2] = z return wss # Load the pressure mesh even if p is not selected. # The p and wss meshes contain the same mesh nodes, # so use nodes from p for simplicity. p_vtk_path = self.p_vtk_dir / (name + ".vtk") p_mesh = pv.read(p_vtk_path) edge_list = extract_edges(p_mesh) # Create PyG graph using the connectivity information edges = torch.tensor(edge_list).t() if to_bidirected: edges = pyg.utils.to_undirected(edges) graph = pyg.data.Data(edge_index=edges) # Assign node features using the vertex data graph.pos = torch.tensor(p_mesh.points, dtype=torch.float32) if (k := "p") in self.output_keys: graph[k] = torch.tensor(p_mesh.point_data[k], dtype=torch.float32) if (k := "wallShearStress") in self.output_keys: wss_vtk_path = self.wss_vtk_dir / (name + ".vtk") graph[k] = permute_mesh(p_vtk_path, wss_vtk_path) # Normalize nodes. for k in self.input_keys + self.output_keys: if k not in self.normalize_keys: continue v = (graph[k] - self.nstats[k]["mean"]) / self.nstats[k]["std"] graph[k] = v.unsqueeze(-1) if v.ndim == 1 else v # Add edge features which contain relative edge nodes displacement and # displacement norm. Stored as `x` in the graph edge data. u, v = graph.edge_index pos = graph.pos disp = pos[u] - pos[v] disp_norm = torch.linalg.norm(disp, dim=-1, keepdim=True) graph.edge_attr = torch.cat((disp, disp_norm), dim=-1) # Normalize edges. graph.edge_attr = (graph.edge_attr - self.estats["x"]["mean"]) / self.estats[ "x" ]["std"] return graph
[docs] @torch.no_grad def denormalize( self, pred: Tensor, gt: Tensor, device: torch.device ) -> tuple[Tensor, Tensor]: """Denormalizes the inputs using previously collected statistics.""" def denorm(x: Tensor, name: str): stats = self.nstats[name] mean = torch.as_tensor(stats["mean"]).to(device) std = torch.as_tensor(stats["std"]).to(device) return x * std + mean pred_d = [] gt_d = [] pred_d.append(denorm(pred[:, :1], "p")) gt_d.append(denorm(gt[:, :1], "p")) if (k := "wallShearStress") in self.output_keys: pred_d.append(denorm(pred[:, 1:4], k)) gt_d.append(denorm(gt[:, 1:4], k)) return torch.cat(pred_d, dim=-1), torch.cat(gt_d, dim=-1)