# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code contains the DoMINO model architecture.
The DoMINO class contains an architecture to model both surface and
volume quantities together as well as separately (controlled using
the config.yaml file)
"""
import math
from collections import defaultdict
from typing import Callable, Literal, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from physicsnemo.models.unet import UNet
from physicsnemo.utils.neighbors import radius_search
from physicsnemo.utils.profiling import profile
[docs]
def get_activation(activation: Literal["relu", "gelu"]) -> Callable:
"""
Return a PyTorch activation function corresponding to the given name.
"""
if activation == "relu":
return F.relu
elif activation == "gelu":
return F.gelu
else:
raise ValueError(f"Activation function {activation} not found")
[docs]
def fourier_encode(coords, num_freqs):
    """
    Compute Fourier features of ``coords``.

    Frequencies are ``exp(linspace(0, pi, num_freqs))``. The result
    concatenates, along the last axis, ``sin(coords * f)`` for every
    frequency followed by ``cos(coords * f)`` for every frequency.

    Args:
        coords: Tensor whose last axis holds the values to encode.
        num_freqs: Number of frequencies.

    Returns:
        Tensor with last axis of size ``2 * num_freqs * coords.shape[-1]``.
    """
    exponents = torch.linspace(0, math.pi, num_freqs, device=coords.device)
    frequencies = torch.exp(exponents)
    sines = []
    cosines = []
    for freq in frequencies:
        scaled = coords * freq
        sines.append(torch.sin(scaled))
        cosines.append(torch.cos(scaled))
    return torch.cat(sines + cosines, dim=-1)
[docs]
def fourier_encode_vectorized(coords, freqs):
    """
    Vectorized Fourier feature encoding.

    For coordinates of shape ``[..., D]`` and ``F`` frequencies, every
    coordinate is multiplied by every frequency. The result concatenates all
    sine terms followed by all cosine terms. Within each half the values are
    ordered frequency-major (``f0*d0, f0*d1, ..., f1*d0, ...``).

    Args:
        coords: Tensor of shape ``[..., D]`` (at least 1-D).
        freqs: 1-D tensor of ``F`` frequencies.

    Returns:
        Tensor of shape ``[..., 2 * F * D]``.
    """
    # Local names avoid shadowing the module-level ``F`` (torch.nn.functional).
    num_dims = coords.shape[-1]
    num_freqs = freqs.shape[0]
    lead_shape = coords.shape[:-1]
    # [..., 1, D] * [F, 1] -> [..., F, D]
    scaled = coords.unsqueeze(-2) * freqs[:, None]
    scaled = scaled.reshape(*lead_shape, num_freqs * num_dims)  # [..., F * D]
    return torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1)  # [..., 2 * F * D]
[docs]
def calculate_pos_encoding(nx, d=8):
    """
    Compute sinusoidal positional encodings of ``nx``.

    For ``k`` in ``range(int(d / 2))`` the list receives
    ``sin(nx / 10000 ** (2k / d))`` followed by ``cos(nx / 10000 ** (2k / d))``.

    Args:
        nx: Tensor of positions.
        d: Encoding size; ``int(d / 2)`` sin/cos pairs are produced.

    Returns:
        A Python list of ``2 * int(d / 2)`` tensors, each shaped like ``nx``.
    """
    encodings = []
    for k in range(int(d / 2)):
        denom = 10000 ** (2 * k / d)
        encodings.extend((torch.sin(nx / denom), torch.cos(nx / denom)))
    return encodings
[docs]
def scale_sdf(sdf: torch.Tensor) -> torch.Tensor:
    """
    Compress signed distance values so that near-surface regions stand out.

    Applies ``sdf / (|sdf| + 0.4)``. The sign is preserved and the magnitude
    is squashed into the open interval (-1, 1). Small ``|sdf|`` (points near
    the surface) therefore spans a relatively larger part of the output range.

    Args:
        sdf: Tensor of signed distance values.

    Returns:
        Tensor of the same shape with values in (-1, 1).
    """
    denominator = torch.abs(sdf) + 0.4
    return sdf / denominator
[docs]
class BQWarp(nn.Module):
    """
    Warp-based ball-query layer for finding neighboring points within a specified radius.

    Thin wrapper around ``radius_search`` with a fixed radius and neighbor
    cap. Only the first batch element (index 0) of the inputs is searched. The
    results are returned with a leading batch dimension of size 1.
    """

    def __init__(
        self,
        grid_resolution=None,
        radius: float = 0.25,
        neighbors_in_radius: int = 10,
    ):
        """
        Initialize the BQWarp layer.

        Args:
            grid_resolution: Grid size ``[nx, ny, nz]`` used to flatten
                ``p_grid`` in ``forward``; defaults to ``[256, 96, 64]``.
            radius: Radius passed to ``radius_search``.
            neighbors_in_radius: Maximum number of neighbors passed to
                ``radius_search``.
        """
        super().__init__()
        self.radius = radius
        self.neighbors_in_radius = neighbors_in_radius
        self.grid_resolution = (
            [256, 96, 64] if grid_resolution is None else grid_resolution
        )

    def forward(
        self, x: torch.Tensor, p_grid: torch.Tensor, reverse_mapping: bool = True
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run a ball query between the points in ``x`` and the grid ``p_grid``.

        Args:
            x: Point tensor whose first axis is the batch (only ``x[0]`` is
                used).
            p_grid: Grid coordinates reshapeable to
                ``(batch_size, nx * ny * nz, 3)``.
            reverse_mapping: Selects the argument order given to
                ``radius_search``. ``True`` (default) passes ``(x, p_grid)``
                and ``False`` passes ``(p_grid, x)``.
                NOTE(review): GeometryRep calls this with the default and
                passes the returned points to GeoConvOut, which treats axis 1
                as the flattened grid. The default therefore appears to
                return, for each grid point, neighbors drawn from ``x``.
                Earlier docs described the two modes the other way round, so
                confirm against ``radius_search``.

        Returns:
            tuple ``(mapping, outputs)`` as returned by ``radius_search`` with
            ``return_points=True`` (neighbor indices and neighbor point
            coordinates), each with a leading batch axis of size 1 added.
        """
        nx, ny, nz = self.grid_resolution
        flat_grid = torch.reshape(p_grid, (x.shape[0], nx * ny * nz, 3))
        # Only the first batch element takes part in the search.
        points0, grid0 = x[0], flat_grid[0]
        first, second = (points0, grid0) if reverse_mapping else (grid0, points0)
        mapping, outputs = radius_search(
            first,
            second,
            self.radius,
            self.neighbors_in_radius,
            return_points=True,
        )
        return mapping.unsqueeze(0), outputs.unsqueeze(0)
[docs]
class GeoConvOut(nn.Module):
    """
    Geometry layer to project STL geometry data onto regular grids.

    For every grid point, the neighboring geometry points are embedded by a
    three-layer MLP (tanh on the last layer). The embeddings are summed over
    the neighbor slots that are not masked out as padding, which yields one
    feature vector per grid cell.
    """

    def __init__(
        self,
        input_features: int,
        model_parameters,
        grid_resolution=None,
    ):
        """
        Initialize the GeoConvOut layer.

        Args:
            input_features: Number of input feature dimensions per point
            model_parameters: Configuration providing ``base_neurons``,
                ``base_neurons_in``, ``fourier_features``, ``num_modes`` and
                ``activation``
            grid_resolution: Resolution of the output grid [nx, ny, nz];
                defaults to ``[256, 96, 64]``
        """
        super().__init__()
        if grid_resolution is None:
            grid_resolution = [256, 96, 64]
        base_neurons = model_parameters.base_neurons
        self.fourier_features = model_parameters.fourier_features
        self.num_modes = model_parameters.num_modes
        if self.fourier_features:
            # Raw inputs plus one sin and one cos term per input per mode.
            input_features_calculated = input_features * (1 + 2 * self.num_modes)
        else:
            input_features_calculated = input_features
        self.fc1 = nn.Linear(input_features_calculated, base_neurons)
        self.fc2 = nn.Linear(base_neurons, base_neurons // 2)
        self.fc3 = nn.Linear(base_neurons // 2, model_parameters.base_neurons_in)
        self.grid_resolution = grid_resolution
        self.activation = get_activation(model_parameters.activation)
        if self.fourier_features:
            self.register_buffer(
                "freqs", torch.exp(torch.linspace(0, math.pi, self.num_modes))
            )

    def forward(
        self,
        x: torch.Tensor,
        grid: torch.Tensor,
        radius: float = 0.025,
        neighbors_in_radius: int = 10,
    ) -> torch.Tensor:
        """
        Process and project geometric features onto a 3D grid.

        Args:
            x: Coordinates of the neighboring points for every grid point, of
                shape (batch_size, nx*ny*nz, n_points, input_features). The
                MLP acts on the last axis and the sum runs over axis 2.
            grid: Grid coordinates (batch_size, nx, ny, nz, 3). Not used by
                the computation; kept for interface compatibility.
            radius: Unused; kept for interface compatibility.
            neighbors_in_radius: Unused; kept for interface compatibility.

        Returns:
            Processed geometry features of shape
            (batch_size, base_neurons_in, nx, ny, nz)
        """
        nx, ny, nz = (
            self.grid_resolution[0],
            self.grid_resolution[1],
            self.grid_resolution[2],
        )
        # A neighbor slot counts as padding when its FIRST coordinate is
        # (numerically) zero; only that coordinate is inspected.
        # NOTE(review): a genuine neighbor lying on the x == 0 plane is also
        # dropped by this test; confirm this is acceptable.
        mask = torch.abs(x[..., 0:1]) > 1e-6
        if self.fourier_features:
            facets = torch.cat((x, fourier_encode_vectorized(x, self.freqs)), dim=-1)
        else:
            facets = x
        x = self.activation(self.fc1(facets))
        x = self.activation(self.fc2(x))
        x = torch.tanh(self.fc3(x))
        # Sum over neighbors; the (…, 1) mask broadcasts across channels.
        x = torch.sum(x * mask, 2)
        x = rearrange(x, "b (x y z) c -> b c x y z", x=nx, y=ny, z=nz)
        return x
[docs]
class GeoProcessor(nn.Module):
    """Geometry processing layer using CNNs.

    A 3D convolutional encoder-decoder with three pooling levels and skip
    connections, applied to grid-represented geometry.
    """

    def __init__(self, input_filters: int, output_filters: int, model_parameters):
        """
        Initialize the GeoProcessor network.

        Args:
            input_filters: Number of input channels
            output_filters: Number of output channels
            model_parameters: Configuration providing ``base_filters`` and
                ``activation``
        """
        super().__init__()
        nf = model_parameters.base_filters

        def conv3x3(in_ch, out_ch):
            # Every convolution is 3x3x3 with "same" padding (spatial size kept).
            return nn.Conv3d(in_ch, out_ch, kernel_size=3, padding="same")

        # Encoder: input -> nf -> 2nf -> 4nf
        self.conv1 = conv3x3(input_filters, nf)
        self.conv2 = conv3x3(nf, 2 * nf)
        self.conv3 = conv3x3(2 * nf, 4 * nf)
        # Bottleneck
        self.conv3_1 = conv3x3(4 * nf, 4 * nf)
        # Decoder; input widths include the concatenated skip tensors
        self.conv4 = conv3x3(4 * nf, 2 * nf)
        self.conv5 = conv3x3(4 * nf, nf)  # 2nf upsampled + 2nf skip
        self.conv6 = conv3x3(2 * nf, input_filters)  # nf upsampled + nf skip
        self.conv7 = conv3x3(2 * input_filters, input_filters)  # + raw input skip
        self.conv8 = conv3x3(input_filters, output_filters)
        # NOTE(review): avg_pool is never used by forward.
        self.avg_pool = torch.nn.AvgPool3d((2, 2, 2))
        self.max_pool = nn.MaxPool3d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
        self.activation = get_activation(model_parameters.activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the encoder-decoder over grid-represented geometry.

        Each encoder stage is conv -> activation -> 2x max-pool, and the
        stage input is kept as a skip. After a bottleneck conv, each decoder
        stage is conv -> activation -> 2x upsample, then concatenation with
        the matching skip. A final pair of convs maps to ``output_filters``
        channels. Spatial dims must be divisible by 8 so the upsampled tensors
        line up with their skips.

        Args:
            x: Tensor of shape (batch_size, input_filters, nx, ny, nz)

        Returns:
            Tensor of shape (batch_size, output_filters, nx, ny, nz)
        """
        skips = []
        for down_conv in (self.conv1, self.conv2, self.conv3):
            skips.append(x)
            x = self.max_pool(self.activation(down_conv(x)))

        x = self.activation(self.conv3_1(x))

        for up_conv, skip in zip((self.conv4, self.conv5, self.conv6), reversed(skips)):
            x = self.upsample(self.activation(up_conv(x)))
            x = torch.cat((x, skip), dim=1)

        x = self.activation(self.conv7(x))
        return self.conv8(x)
[docs]
class GeometryRep(nn.Module):
    """
    Geometry representation module that processes STL geometry data.

    This module constructs a multiscale representation of geometry by:
    1. Computing a multi-scale encoding of the geometry points (one branch
       per ball-query radius) for local and global context
    2. Processing signed distance field (SDF) data for surface information

    The combined encoding enables the model to reason about both local and
    global geometric properties.
    """

    def __init__(
        self,
        input_features: int,
        radii: Sequence[float],
        neighbors_in_radius,
        hops=1,
        model_parameters=None,
    ):
        """
        Initialize the GeometryRep module.

        Args:
            input_features: Number of input feature dimensions per geometry point
            radii: Ball-query radius of each multi-scale branch; one
                BQWarp / GeoConvOut / processor / output conv is built per entry
            neighbors_in_radius: Maximum neighbor count per branch, indexed in
                parallel with ``radii``
            hops: Number of residual processor steps per branch in ``forward``
            model_parameters: Configuration parameters for the model

        Raises:
            ValueError: If ``geometry_rep.geo_processor.processor_type`` is
                neither ``"unet"`` nor ``"conv"``.
        """
        super().__init__()
        geometry_rep = model_parameters.geometry_rep
        geo_conv = geometry_rep.geo_conv
        geo_processor = geometry_rep.geo_processor
        self.geo_encoding_type = model_parameters.geometry_encoding_type
        self.cross_attention = geo_processor.cross_attention
        self.self_attention = geo_processor.self_attention
        self.activation_conv = get_activation(geo_conv.activation)
        self.activation_processor = geo_processor.activation

        processor_type = geo_processor.processor_type
        if processor_type not in ("unet", "conv"):
            raise ValueError("Invalid prompt. Specify unet or conv ...")

        # Resolved for every processor type: the cross-attention UNet below
        # needs ``h`` even when processor_type == "conv" (previously only the
        # "unet" branches defined it, so that combination raised NameError).
        h = geo_processor.base_filters
        normalization_in_unet = "layernorm" if self.self_attention else None

        self.bq_warp = nn.ModuleList()
        self.geo_processors = nn.ModuleList()
        for j in range(len(radii)):
            self.bq_warp.append(
                BQWarp(
                    grid_resolution=model_parameters.interp_res,
                    radius=radii[j],
                    neighbors_in_radius=neighbors_in_radius[j],
                )
            )
            if processor_type == "unet":
                processor = self._build_unet(
                    geo_conv.base_neurons_in,
                    geo_conv.base_neurons_out,
                    h,
                    normalization=normalization_in_unet,
                    use_attn_gate=self.self_attention,
                )
            else:
                # The second stage consumes the first stage's output, so its
                # input width is base_neurons_out (as in the SDF branch below).
                processor = nn.Sequential(
                    GeoProcessor(
                        input_filters=geo_conv.base_neurons_in,
                        output_filters=geo_conv.base_neurons_out,
                        model_parameters=geo_processor,
                    ),
                    GeoProcessor(
                        input_filters=geo_conv.base_neurons_out,
                        output_filters=geo_conv.base_neurons_out,
                        model_parameters=geo_processor,
                    ),
                )
            self.geo_processors.append(processor)

        self.geo_conv_out = nn.ModuleList()
        self.geo_processor_out = nn.ModuleList()
        for _ in range(len(radii)):
            self.geo_conv_out.append(
                GeoConvOut(
                    input_features=input_features,
                    model_parameters=geo_conv,
                    grid_resolution=model_parameters.interp_res,
                )
            )
            self.geo_processor_out.append(
                nn.Conv3d(
                    geo_conv.base_neurons_out,
                    1,
                    kernel_size=3,
                    padding="same",
                )
            )

        # The SDF branch takes 6 channels: sdf, scaled sdf, binary sdf and the
        # three spatial gradients (see forward).
        if processor_type == "unet":
            self.geo_processor_sdf = self._build_unet(
                6,
                geo_conv.base_neurons_out,
                h,
                normalization=normalization_in_unet,
                use_attn_gate=self.self_attention,
            )
        else:
            self.geo_processor_sdf = nn.Sequential(
                GeoProcessor(
                    input_filters=6,
                    output_filters=geo_conv.base_neurons_out,
                    model_parameters=geo_processor,
                ),
                GeoProcessor(
                    input_filters=geo_conv.base_neurons_out,
                    output_filters=geo_conv.base_neurons_out,
                    model_parameters=geo_processor,
                ),
            )
        self.radii = radii
        self.hops = hops
        self.geo_processor_sdf_out = nn.Conv3d(
            geo_conv.base_neurons_out, 1, kernel_size=3, padding="same"
        )
        if self.cross_attention:
            # NOTE(review): 1 + len(radii) channels matches only the "both"
            # encoding type (len(radii) STL channels + 1 SDF channel).
            self.combined_unet = self._build_unet(
                1 + len(radii),
                1 + len(radii),
                h,
                normalization="layernorm",
                use_attn_gate=True,
            )

    def _build_unet(self, in_channels, out_channels, h, normalization, use_attn_gate):
        """Build a depth-3 UNet with the hyper-parameters shared by this module."""
        return UNet(
            in_channels=in_channels,
            out_channels=out_channels,
            model_depth=3,
            feature_map_channels=[
                h,
                2 * h,
                4 * h,
            ],
            num_conv_blocks=1,
            kernel_size=3,
            stride=1,
            conv_activation=self.activation_processor,
            padding=1,
            padding_mode="zeros",
            pooling_type="MaxPool3d",
            pool_size=2,
            normalization=normalization,
            use_attn_gate=use_attn_gate,
            attn_decoder_feature_maps=[4 * h, 2 * h],
            attn_feature_map_channels=[2 * h, h],
            attn_intermediate_channels=4 * h,
            gradient_checkpointing=True,
        )

    def forward(
        self, x: torch.Tensor, p_grid: torch.Tensor, sdf: torch.Tensor
    ) -> torch.Tensor:
        """
        Process geometry data to create a comprehensive representation.

        Args:
            x: Geometry point tensor passed to each BQWarp branch
            p_grid: Grid points, reshapeable to (batch, nx*ny*nz, 3)
            sdf: Signed distance field on the grid, shape (batch, nx, ny, nz)

        Returns:
            For ``geometry_encoding_type == "both"``, the per-radius STL
            encodings concatenated with the SDF encoding along channels. For
            ``"stl"``, the STL encodings only. For ``"sdf"``, the SDF encoding
            only. If cross attention is enabled, the result is then passed
            through ``combined_unet``.

        Raises:
            ValueError: If ``geometry_encoding_type`` is not "both", "stl" or
                "sdf".
        """
        encoding_type = self.geo_encoding_type
        if encoding_type not in ("both", "stl", "sdf"):
            raise ValueError(
                f"Invalid geometry_encoding_type {encoding_type!r}; "
                "expected 'both', 'stl' or 'sdf'"
            )

        if encoding_type in ("both", "stl"):
            # Multi-scale geometry encoding: one branch per ball-query radius.
            x_encoding = []
            for j in range(len(self.radii)):
                _, k_short = self.bq_warp[j](x, p_grid)
                x_encoding_inter = self.geo_conv_out[j](k_short, p_grid)
                # Propagate information through the grid with ``hops``
                # residual steps, each scaled by 1 / hops.
                for _ in range(self.hops):
                    dx = self.geo_processors[j](x_encoding_inter) / self.hops
                    x_encoding_inter = x_encoding_inter + dx
                x_encoding.append(self.geo_processor_out[j](x_encoding_inter))
            x_encoding = torch.cat(x_encoding, dim=1)

        if encoding_type in ("both", "sdf"):
            sdf = torch.unsqueeze(sdf, 1)  # add channel axis
            scaled_sdf = scale_sdf(sdf)  # emphasize near-surface values
            binary_sdf = torch.where(sdf >= 0, 0.0, 1.0)  # 1 where sdf < 0
            sdf_x, sdf_y, sdf_z = torch.gradient(sdf, dim=[2, 3, 4])
            sdf_features = torch.cat(
                (sdf, scaled_sdf, binary_sdf, sdf_x, sdf_y, sdf_z), 1
            )
            sdf_encoding = self.geo_processor_sdf_out(
                self.geo_processor_sdf(sdf_features)
            )

        if encoding_type == "both":
            encoding_g = torch.cat((x_encoding, sdf_encoding), 1)
        elif encoding_type == "sdf":
            encoding_g = sdf_encoding
        else:
            encoding_g = x_encoding

        if self.cross_attention:
            encoding_g = self.combined_unet(encoding_g)
        return encoding_g
[docs]
class NNBasisFunctions(nn.Module):
    """Basis function layer for point clouds.

    A three-layer MLP over the last axis (activation after the first two
    layers), optionally preceded by a Fourier feature expansion of the input.
    """

    def __init__(self, input_features: int, model_parameters=None):
        """
        Args:
            input_features: Size of the last axis of the input.
            model_parameters: Configuration providing ``base_layer``,
                ``fourier_features``, ``num_modes`` and ``activation``.
        """
        super().__init__()
        width = model_parameters.base_layer
        self.fourier_features = model_parameters.fourier_features
        self.num_modes = model_parameters.num_modes
        # Fourier expansion keeps the raw inputs and appends one sin and one
        # cos term per input feature per mode.
        in_dim = (
            input_features * (1 + 2 * self.num_modes)
            if self.fourier_features
            else input_features
        )
        self.fc1 = nn.Linear(in_dim, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.activation = get_activation(model_parameters.activation)
        if self.fourier_features:
            self.register_buffer(
                "freqs", torch.exp(torch.linspace(0, math.pi, self.num_modes))
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Map point features to basis function coefficients.

        Args:
            x: Input tensor whose last axis holds the point features

        Returns:
            Tensor whose last axis has size ``base_layer``
        """
        features = x
        if self.fourier_features:
            features = torch.cat((x, fourier_encode_vectorized(x, self.freqs)), dim=-1)
        for layer in (self.fc1, self.fc2):
            features = self.activation(layer(features))
        return self.fc3(features)
[docs]
class ParameterModel(nn.Module):
    """
    Neural network module to encode simulation parameters.

    Encodes global physical parameters into a learned latent representation
    that can be incorporated into the model's prediction process. The network
    is a three-layer MLP, optionally preceded by Fourier features.
    """

    def __init__(self, input_features: int, model_parameters=None):
        """
        Initialize the parameter encoding network.

        Args:
            input_features: Number of input parameters to encode
            model_parameters: Configuration providing ``fourier_features``,
                ``num_modes``, ``base_layer`` and ``activation``
        """
        super().__init__()
        self.fourier_features = model_parameters.fourier_features
        self.num_modes = model_parameters.num_modes
        in_dim = input_features
        if self.fourier_features:
            # Raw parameters plus one sin and one cos term per parameter per mode.
            in_dim = input_features * (1 + 2 * self.num_modes)
            self.register_buffer(
                "freqs", torch.exp(torch.linspace(0, math.pi, self.num_modes))
            )
        width = model_parameters.base_layer
        self.fc1 = nn.Linear(in_dim, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.activation = get_activation(model_parameters.activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode physical parameters into a latent representation.

        Args:
            x: Input tensor whose last axis holds the physical parameters
                (e.g., inlet velocity, air density)

        Returns:
            Tensor whose last axis has size ``base_layer``
        """
        if self.fourier_features:
            encoded = torch.cat((x, fourier_encode_vectorized(x, self.freqs)), dim=-1)
        else:
            encoded = x
        for layer in (self.fc1, self.fc2):
            encoded = self.activation(layer(encoded))
        return self.fc3(encoded)
[docs]
class AggregationModel(nn.Module):
    """
    Neural network module to aggregate local geometry encoding with basis functions.

    A five-layer MLP (activation after the first four layers) that maps the
    combined feature vector to the output quantities. It serves as the final
    prediction layer.
    """

    def __init__(
        self,
        input_features: int,
        output_features: int,
        model_parameters=None,
        new_change: bool = True,
    ):
        """
        Initialize the aggregation model.

        Args:
            input_features: Number of input feature dimensions
            output_features: Number of output feature dimensions
            model_parameters: Configuration providing ``base_layer`` and
                ``activation``
            new_change: Stored on the module; not read by ``forward``
        """
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.new_change = new_change
        width = model_parameters.base_layer
        self.fc1 = nn.Linear(self.input_features, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.fc4 = nn.Linear(width, width)
        self.fc5 = nn.Linear(width, self.output_features)
        self.activation = get_activation(model_parameters.activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predict output quantities from the combined input features.

        Args:
            x: Tensor whose last axis has size ``input_features`` (typically
                basis functions, geometry and parameter encodings combined)

        Returns:
            Tensor whose last axis has size ``output_features``
        """
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            hidden = self.activation(layer(hidden))
        return self.fc5(hidden)
[docs]
class LocalPointConv(nn.Module):
    """Layer for local geometry point kernel.

    A two-layer MLP over the last axis with an activation between the layers.
    """

    def __init__(
        self,
        input_features,
        base_layer,
        output_features,
        model_parameters=None,
    ):
        """
        Args:
            input_features: Size of the last axis of the input
            base_layer: Hidden width
            output_features: Size of the last axis of the output
            model_parameters: Configuration providing ``activation``
        """
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.fc1 = nn.Linear(self.input_features, base_layer)
        self.fc2 = nn.Linear(base_layer, self.output_features)
        self.activation = get_activation(model_parameters.activation)

    def forward(self, x):
        """Apply fc1 -> activation -> fc2 to ``x``."""
        hidden = self.activation(self.fc1(x))
        return self.fc2(hidden)
[docs]
class PositionEncoder(nn.Module):
    """Positional encoding of point clouds.

    A three-layer MLP over the last axis (activation after the first two
    layers), optionally preceded by a Fourier feature expansion of the input.
    """

    def __init__(self, input_features: int, model_parameters=None):
        """
        Args:
            input_features: Size of the last axis of the input (e.g. 3 for
                coordinates)
            model_parameters: Configuration providing ``base_neurons``,
                ``fourier_features``, ``num_modes`` and ``activation``
        """
        super().__init__()
        width = model_parameters.base_neurons
        self.fourier_features = model_parameters.fourier_features
        self.num_modes = model_parameters.num_modes
        # Raw inputs plus one sin and one cos term per input per mode.
        in_dim = (
            input_features * (1 + 2 * self.num_modes)
            if self.fourier_features
            else input_features
        )
        self.fc1 = nn.Linear(in_dim, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, width)
        self.activation = get_activation(model_parameters.activation)
        if self.fourier_features:
            self.register_buffer(
                "freqs", torch.exp(torch.linspace(0, math.pi, self.num_modes))
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode point positions into a learned feature vector.

        Args:
            x: Input tensor whose last axis holds the point features

        Returns:
            Tensor whose last axis has size ``base_neurons``
        """
        encoded = x
        if self.fourier_features:
            encoded = torch.cat((x, fourier_encode_vectorized(x, self.freqs)), dim=-1)
        for layer in (self.fc1, self.fc2):
            encoded = self.activation(layer(encoded))
        return self.fc3(encoded)
# @dataclass
# class MetaData(ModelMetaData):
# name: str = "DoMINO"
# # Optimization
# jit: bool = False
# cuda_graphs: bool = True
# amp: bool = True
# # Inference
# onnx_cpu: bool = True
# onnx_gpu: bool = True
# onnx_runtime: bool = True
# # Physics informed
# var_dim: int = 1
# func_torch: bool = False
# auto_grad: bool = False
[docs]
class DoMINO(nn.Module):
"""
DoMINO model architecture for predicting both surface and volume quantities.
The DoMINO (Deep Operational Modal Identification and Nonlinear Optimization) model
is designed to model both surface and volume physical quantities in aerodynamic
simulations. It can operate in three modes:
1. Surface-only: Predicting only surface quantities
2. Volume-only: Predicting only volume quantities
3. Combined: Predicting both surface and volume quantities
The model uses a combination of:
- Geometry representation modules
- Neural network basis functions
- Parameter encoding
- Local and global geometry processing
- Aggregation models for final prediction
Parameters
----------
input_features : int
Number of point input features
output_features_vol : int, optional
Number of output features in volume
output_features_surf : int, optional
Number of output features on surface
model_parameters
Model parameters controlled by config.yaml
Example
-------
>>> from physicsnemo.models.domino.model import DoMINO
>>> import torch, os
>>> from hydra import compose, initialize
>>> from omegaconf import OmegaConf
>>> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
>>> cfg = OmegaConf.register_new_resolver("eval", eval)
>>> with initialize(version_base="1.3", config_path="examples/cfd/external_aerodynamics/domino/src/conf"):
... cfg = compose(config_name="config")
>>> cfg.model.model_type = "combined"
>>> model = DoMINO(
... input_features=3,
... output_features_vol=5,
... output_features_surf=4,
... model_parameters=cfg.model
... ).to(device)
Warp ...
>>> bsize = 1
>>> nx, ny, nz = cfg.model.interp_res
>>> num_neigh = 7
>>> global_features = 2
>>> pos_normals_closest_vol = torch.randn(bsize, 100, 3).to(device)
>>> pos_normals_com_vol = torch.randn(bsize, 100, 3).to(device)
>>> pos_normals_com_surface = torch.randn(bsize, 100, 3).to(device)
>>> geom_centers = torch.randn(bsize, 100, 3).to(device)
>>> grid = torch.randn(bsize, nx, ny, nz, 3).to(device)
>>> surf_grid = torch.randn(bsize, nx, ny, nz, 3).to(device)
>>> sdf_grid = torch.randn(bsize, nx, ny, nz).to(device)
>>> sdf_surf_grid = torch.randn(bsize, nx, ny, nz).to(device)
>>> sdf_nodes = torch.randn(bsize, 100, 1).to(device)
>>> surface_coordinates = torch.randn(bsize, 100, 3).to(device)
>>> surface_neighbors = torch.randn(bsize, 100, num_neigh, 3).to(device)
>>> surface_normals = torch.randn(bsize, 100, 3).to(device)
>>> surface_neighbors_normals = torch.randn(bsize, 100, num_neigh, 3).to(device)
>>> surface_sizes = torch.rand(bsize, 100).to(device) + 1e-6 # Note this needs to be > 0.0
>>> surface_neighbors_areas = torch.rand(bsize, 100, num_neigh).to(device) + 1e-6
>>> volume_coordinates = torch.randn(bsize, 100, 3).to(device)
>>> vol_grid_max_min = torch.randn(bsize, 2, 3).to(device)
>>> surf_grid_max_min = torch.randn(bsize, 2, 3).to(device)
>>> global_params_values = torch.randn(bsize, global_features, 1).to(device)
>>> global_params_reference = torch.randn(bsize, global_features, 1).to(device)
>>> input_dict = {
... "pos_volume_closest": pos_normals_closest_vol,
... "pos_volume_center_of_mass": pos_normals_com_vol,
... "pos_surface_center_of_mass": pos_normals_com_surface,
... "geometry_coordinates": geom_centers,
... "grid": grid,
... "surf_grid": surf_grid,
... "sdf_grid": sdf_grid,
... "sdf_surf_grid": sdf_surf_grid,
... "sdf_nodes": sdf_nodes,
... "surface_mesh_centers": surface_coordinates,
... "surface_mesh_neighbors": surface_neighbors,
... "surface_normals": surface_normals,
... "surface_neighbors_normals": surface_neighbors_normals,
... "surface_areas": surface_sizes,
... "surface_neighbors_areas": surface_neighbors_areas,
... "volume_mesh_centers": volume_coordinates,
... "volume_min_max": vol_grid_max_min,
... "surface_min_max": surf_grid_max_min,
... "global_params_reference": global_params_values,
... "global_params_values": global_params_reference,
... }
>>> output = model(input_dict)
>>> print(f"{output[0].shape}, {output[1].shape}")
torch.Size([1, 100, 5]), torch.Size([1, 100, 4])
"""
def __init__(
self,
input_features: int,
output_features_vol: int | None = None,
output_features_surf: int | None = None,
global_features: int = 2,
model_parameters=None,
):
"""
Initialize the DoMINO model.
Args:
input_features: Number of input feature dimensions for point data
output_features_vol: Number of output features for volume quantities (None for surface-only mode)
output_features_surf: Number of output features for surface quantities (None for volume-only mode)
model_parameters: Configuration parameters for the model
Raises:
ValueError: If both output_features_vol and output_features_surf are None
"""
super().__init__()
self.input_features = input_features
self.output_features_vol = output_features_vol
self.output_features_surf = output_features_surf
self.num_sample_points_surface = model_parameters.num_neighbors_surface
self.num_sample_points_volume = model_parameters.num_neighbors_volume
self.combined_vol_surf = model_parameters.combine_volume_surface
self.activation_processor = (
model_parameters.geometry_rep.geo_processor.activation
)
if self.combined_vol_surf:
h = 8
in_channels = (
2
+ len(model_parameters.geometry_rep.geo_conv.volume_radii)
+ len(model_parameters.geometry_rep.geo_conv.surface_radii)
)
out_channels_surf = 1 + len(
model_parameters.geometry_rep.geo_conv.surface_radii
)
out_channels_vol = 1 + len(
model_parameters.geometry_rep.geo_conv.volume_radii
)
self.combined_unet_surf = UNet(
in_channels=in_channels,
out_channels=out_channels_surf,
model_depth=3,
feature_map_channels=[
h,
2 * h,
4 * h,
],
num_conv_blocks=1,
kernel_size=3,
stride=1,
conv_activation=self.activation_processor,
padding=1,
padding_mode="zeros",
pooling_type="MaxPool3d",
pool_size=2,
normalization="layernorm",
use_attn_gate=True,
attn_decoder_feature_maps=[4 * h, 2 * h],
attn_feature_map_channels=[2 * h, h],
attn_intermediate_channels=4 * h,
gradient_checkpointing=True,
)
self.combined_unet_vol = UNet(
in_channels=in_channels,
out_channels=out_channels_vol,
model_depth=3,
feature_map_channels=[
h,
2 * h,
4 * h,
],
num_conv_blocks=1,
kernel_size=3,
stride=1,
conv_activation=self.activation_processor,
padding=1,
padding_mode="zeros",
pooling_type="MaxPool3d",
pool_size=2,
normalization="layernorm",
use_attn_gate=True,
attn_decoder_feature_maps=[4 * h, 2 * h],
attn_feature_map_channels=[2 * h, h],
attn_intermediate_channels=4 * h,
gradient_checkpointing=True,
)
self.global_features = global_features
if self.output_features_vol is None and self.output_features_surf is None:
raise ValueError(
"At least one of `output_features_vol` or `output_features_surf` must be specified"
)
if hasattr(model_parameters, "solution_calculation_mode"):
if model_parameters.solution_calculation_mode not in [
"one-loop",
"two-loop",
]:
raise ValueError(
f"Invalid solution_calculation_mode: {model_parameters.solution_calculation_mode}, select 'one-loop' or 'two-loop'."
)
self.solution_calculation_mode = model_parameters.solution_calculation_mode
else:
self.solution_calculation_mode = "two-loop"
self.num_variables_vol = output_features_vol
self.num_variables_surf = output_features_surf
self.grid_resolution = model_parameters.interp_res
self.use_surface_normals = model_parameters.use_surface_normals
self.use_surface_area = model_parameters.use_surface_area
self.encode_parameters = model_parameters.encode_parameters
self.geo_encoding_type = model_parameters.geometry_encoding_type
if hasattr(model_parameters, "num_volume_neighbors"):
self.num_volume_neighbors = model_parameters.num_volume_neighbors
else:
self.num_volume_neighbors = 50
if hasattr(model_parameters, "return_volume_neighbors"):
self.return_volume_neighbors = model_parameters.return_volume_neighbors
if (
self.return_volume_neighbors
and self.solution_calculation_mode == "one-loop"
):
print(
"'one-loop' solution_calculation mode not supported when return_volume_neighbors is set to true"
)
print("Overwriting the solution_calculation mode to 'two-loop'")
self.solution_calculation_mode = "two-loop"
if self.use_surface_normals:
if not self.use_surface_area:
input_features_surface = input_features + 3
else:
input_features_surface = input_features + 4
else:
input_features_surface = input_features
if self.encode_parameters:
# Defining the parameter model
base_layer_p = model_parameters.parameter_model.base_layer
self.parameter_model = ParameterModel(
input_features=self.global_features,
model_parameters=model_parameters.parameter_model,
)
else:
base_layer_p = 0
self.geo_rep_volume = GeometryRep(
input_features=input_features,
radii=model_parameters.geometry_rep.geo_conv.volume_radii,
neighbors_in_radius=model_parameters.geometry_rep.geo_conv.volume_neighbors_in_radius,
hops=model_parameters.geometry_rep.geo_conv.volume_hops,
model_parameters=model_parameters,
)
self.geo_rep_surface = GeometryRep(
input_features=input_features,
radii=model_parameters.geometry_rep.geo_conv.surface_radii,
neighbors_in_radius=model_parameters.geometry_rep.geo_conv.surface_neighbors_in_radius,
hops=model_parameters.geometry_rep.geo_conv.surface_hops,
model_parameters=model_parameters,
)
self.geo_rep_surface1 = GeometryRep(
input_features=input_features,
radii=model_parameters.geometry_rep.geo_conv.volume_radii,
neighbors_in_radius=model_parameters.geometry_rep.geo_conv.volume_neighbors_in_radius,
model_parameters=model_parameters,
)
# Basis functions for surface and volume
base_layer_nn = model_parameters.nn_basis_functions.base_layer
if self.output_features_surf is not None:
self.nn_basis_surf = nn.ModuleList()
for _ in range(
self.num_variables_surf
): # Have the same basis function for each variable
self.nn_basis_surf.append(
NNBasisFunctions(
input_features=input_features_surface,
model_parameters=model_parameters.nn_basis_functions,
)
)
if self.output_features_vol is not None:
self.nn_basis_vol = nn.ModuleList()
for _ in range(
self.num_variables_vol
): # Have the same basis function for each variable
self.nn_basis_vol.append(
NNBasisFunctions(
input_features=input_features,
model_parameters=model_parameters.nn_basis_functions,
)
)
# Positional encoding
position_encoder_base_neurons = model_parameters.position_encoder.base_neurons
self.activation = get_activation(model_parameters.activation)
self.use_sdf_in_basis_func = model_parameters.use_sdf_in_basis_func
if self.output_features_vol is not None:
if model_parameters.positional_encoding:
inp_pos_vol = 25 if model_parameters.use_sdf_in_basis_func else 12
else:
inp_pos_vol = 7 if model_parameters.use_sdf_in_basis_func else 3
self.fc_p_vol = PositionEncoder(
inp_pos_vol, model_parameters.position_encoder
)
if self.output_features_surf is not None:
if model_parameters.positional_encoding:
inp_pos_surf = 12
else:
inp_pos_surf = 3
self.fc_p_surf = PositionEncoder(
inp_pos_surf, model_parameters.position_encoder
)
# BQ for surface
self.surface_neighbors_in_radius = (
model_parameters.geometry_local.surface_neighbors_in_radius
)
self.surface_radius = model_parameters.geometry_local.surface_radii
self.surface_bq_warp = nn.ModuleList()
self.surface_local_point_conv = nn.ModuleList()
for ct in range(len(self.surface_radius)):
if self.geo_encoding_type == "both":
total_neighbors_in_radius = self.surface_neighbors_in_radius[ct] * (
len(model_parameters.geometry_rep.geo_conv.surface_radii) + 1
)
elif self.geo_encoding_type == "stl":
total_neighbors_in_radius = self.surface_neighbors_in_radius[ct] * (
len(model_parameters.geometry_rep.geo_conv.surface_radii)
)
elif self.geo_encoding_type == "sdf":
total_neighbors_in_radius = self.surface_neighbors_in_radius[ct]
self.surface_bq_warp.append(
BQWarp(
grid_resolution=model_parameters.interp_res,
radius=self.surface_radius[ct],
neighbors_in_radius=self.surface_neighbors_in_radius[ct],
)
)
self.surface_local_point_conv.append(
LocalPointConv(
input_features=total_neighbors_in_radius,
base_layer=512,
output_features=self.surface_neighbors_in_radius[ct],
model_parameters=model_parameters.local_point_conv,
)
)
# BQ for volume
self.volume_neighbors_in_radius = (
model_parameters.geometry_local.volume_neighbors_in_radius
)
self.volume_radius = model_parameters.geometry_local.volume_radii
self.volume_bq_warp = nn.ModuleList()
self.volume_local_point_conv = nn.ModuleList()
for ct in range(len(self.volume_radius)):
if self.geo_encoding_type == "both":
total_neighbors_in_radius = self.volume_neighbors_in_radius[ct] * (
len(model_parameters.geometry_rep.geo_conv.volume_radii) + 1
)
elif self.geo_encoding_type == "stl":
total_neighbors_in_radius = self.volume_neighbors_in_radius[ct] * (
len(model_parameters.geometry_rep.geo_conv.volume_radii)
)
elif self.geo_encoding_type == "sdf":
total_neighbors_in_radius = self.volume_neighbors_in_radius[ct]
self.volume_bq_warp.append(
BQWarp(
grid_resolution=model_parameters.interp_res,
radius=self.volume_radius[ct],
neighbors_in_radius=self.volume_neighbors_in_radius[ct],
)
)
self.volume_local_point_conv.append(
LocalPointConv(
input_features=total_neighbors_in_radius,
base_layer=512,
output_features=self.volume_neighbors_in_radius[ct],
model_parameters=model_parameters.local_point_conv,
)
)
# Transmitting surface to volume
self.surf_to_vol_conv1 = nn.Conv3d(
len(model_parameters.geometry_rep.geo_conv.volume_radii) + 1,
16,
kernel_size=3,
padding="same",
)
self.surf_to_vol_conv2 = nn.Conv3d(
16,
len(model_parameters.geometry_rep.geo_conv.volume_radii) + 1,
kernel_size=3,
padding="same",
)
# Aggregation model
if self.output_features_surf is not None:
# Surface
base_layer_geo_surf = 0
for j in self.surface_neighbors_in_radius:
base_layer_geo_surf += j
self.agg_model_surf = nn.ModuleList()
for _ in range(self.num_variables_surf):
self.agg_model_surf.append(
AggregationModel(
input_features=position_encoder_base_neurons
+ base_layer_nn
+ base_layer_geo_surf
+ base_layer_p,
output_features=1,
model_parameters=model_parameters.aggregation_model,
)
)
if self.output_features_vol is not None:
# Volume
base_layer_geo_vol = 0
for j in self.volume_neighbors_in_radius:
base_layer_geo_vol += j
self.agg_model_vol = nn.ModuleList()
for _ in range(self.num_variables_vol):
self.agg_model_vol.append(
AggregationModel(
input_features=position_encoder_base_neurons
+ base_layer_nn
+ base_layer_geo_vol
+ base_layer_p,
output_features=1,
model_parameters=model_parameters.aggregation_model,
)
)
[docs]
def position_encoder(
self,
encoding_node: torch.Tensor,
eval_mode: Literal["surface", "volume"] = "volume",
) -> torch.Tensor:
"""
Compute positional encoding for input points.
Args:
encoding_node: Tensor containing node position information
eval_mode: Mode of evaluation, either "volume" or "surface"
Returns:
Tensor containing positional encoding features
"""
if eval_mode == "volume":
x = self.fc_p_vol(encoding_node)
elif eval_mode == "surface":
x = self.fc_p_surf(encoding_node)
else:
raise ValueError(
f"`eval_mode` must be 'surface' or 'volume', got {eval_mode=}"
)
return x
[docs]
def geo_encoding_local(
    self, encoding_g, volume_mesh_centers, p_grid, mode="volume"
):
    """Sample the global geometry encoding around each query point.

    For every configured local radius, the matching ``BQWarp`` module maps each
    query point to grid indices. Every channel of ``encoding_g`` is gathered at
    those indices, and entries whose index is 0 are zeroed (this code treats
    index 0 as "no neighbour"). The channels are concatenated and passed through
    the matching ``LocalPointConv``. Results for all radii are concatenated on
    the last axis.

    Args:
        encoding_g: Global geometry encoding; ``encoding_g[:, c]`` is a
            ``(batch, nx, ny, nz)`` grid for channel ``c``.
        volume_mesh_centers: Query points (batch on dim 0).
        p_grid: Grid coordinates, reshaped here to ``(batch, nx * ny * nz, 3)``.
        mode: ``"volume"`` or ``"surface"``; selects the local modules.

    Returns:
        Local geometry encoding concatenated over all local radii.

    Raises:
        ValueError: If ``mode`` is neither ``"volume"`` nor ``"surface"``.
    """
    if mode == "volume":
        bq_warp = self.volume_bq_warp
        point_conv = self.volume_local_point_conv
        radius = self.volume_radius
    elif mode == "surface":
        bq_warp = self.surface_bq_warp
        point_conv = self.surface_local_point_conv
        radius = self.surface_radius
    else:
        raise ValueError(f"`mode` must be 'surface' or 'volume', got {mode=}")

    batch_size = volume_mesh_centers.shape[0]
    nx, ny, nz = (
        self.grid_resolution[0],
        self.grid_resolution[1],
        self.grid_resolution[2],
    )
    # The grid does not depend on the radius, so flatten it once.
    p_grid = torch.reshape(p_grid, (batch_size, nx * ny * nz, 3))
    num_channels = encoding_g.shape[1]

    encoding_outer = []
    for p in range(len(radius)):
        mapping, _ = bq_warp[p](volume_mesh_centers, p_grid, reverse_mapping=False)
        mapping = mapping.type(torch.int64)
        mask = mapping != 0
        # One row of indices per batch element so each sample gathers from its
        # own grid (index_select on a flattened mapping only worked for batch 1).
        flat_mapping = mapping.reshape(batch_size, -1)
        encoding_g_inner = []
        for j in range(num_channels):
            geo_encoding = encoding_g[:, j].reshape(batch_size, -1)
            sampled = torch.gather(geo_encoding, 1, flat_mapping).reshape(mask.shape)
            encoding_g_inner.append(sampled * mask)
        encoding_g_inner = torch.cat(encoding_g_inner, dim=2)
        encoding_outer.append(point_conv[p](encoding_g_inner))
    return torch.cat(encoding_outer, dim=-1)
[docs]
def calculate_solution_with_neighbors(
    self,
    surface_mesh_centers,
    encoding_g,
    encoding_node,
    surface_mesh_neighbors,
    surface_normals,
    surface_neighbors_normals,
    surface_areas,
    surface_neighbors_areas,
    global_params_values,
    global_params_reference,
    num_sample_points=7,
):
    """Approximate the surface solution from each point and its mesh neighbours.

    For each surface variable ``f``, ``nn_basis_surf[f]`` is evaluated at the
    point itself and at each neighbour (shifted by 1e-6). Its output is
    concatenated with ``encoding_node``, ``encoding_g`` and, when
    ``encode_parameters`` is set, the parameter encoding, then passed through
    ``agg_model_surf[f]``. When ``num_sample_points > 1`` the result is
    ``0.5 * center + 0.5 * (inverse-distance weighted mean of neighbours)``.
    Both solution modes ("one-loop" batched, "two-loop" sequential) compute the
    same quantity.

    Args:
        surface_mesh_centers: Surface point coordinates.
        encoding_g: Local geometry encoding per point.
        encoding_node: Positional encoding per point.
        surface_mesh_neighbors: Neighbour coordinates; ``[:, :, k]`` is the k-th
            of ``num_sample_points - 1`` neighbours.
        surface_normals / surface_neighbors_normals: Normals, appended to the
            coordinates when ``use_surface_normals`` is set.
        surface_areas / surface_neighbors_areas: Areas, appended as
            ``log(area) / 10`` when ``use_surface_area`` is also set.
        global_params_values / global_params_reference: Global parameters,
            indexed ``[:, k, :]``, normalised as values / reference.
        num_sample_points: The point itself plus its neighbours.

    Returns:
        Tensor with one trailing channel per surface variable.

    Raises:
        ValueError: If ``solution_calculation_mode`` is not recognised.
    """
    num_variables = self.num_variables_surf
    nn_basis = self.nn_basis_surf
    agg_model = self.agg_model_surf

    param_encoding = None
    if self.encode_parameters:
        # Normalise each global parameter by its reference and broadcast it to
        # every surface point before encoding.
        num_points = surface_mesh_centers.shape[1]
        processed_parameters = [
            global_params_values[:, k : k + 1, :].expand(-1, num_points, -1)
            / global_params_reference[:, k : k + 1, :]
            for k in range(global_params_values.shape[1])
        ]
        param_encoding = self.parameter_model(torch.cat(processed_parameters, dim=-1))

    if self.use_surface_normals:
        center_terms = [surface_mesh_centers, surface_normals]
        if self.use_surface_area:
            center_terms.append(torch.log(surface_areas) / 10)
        surface_mesh_centers = torch.cat(center_terms, dim=-1)
        if num_sample_points > 1:
            neighbor_terms = [surface_mesh_neighbors, surface_neighbors_normals]
            if self.use_surface_area:
                neighbor_terms.append(torch.log(surface_neighbors_areas) / 10)
            surface_mesh_neighbors = torch.cat(neighbor_terms, dim=-1)

    if self.solution_calculation_mode == "one-loop":
        # All samples at once along a new axis (dim 2): index 0 is the point,
        # indices 1.. are its neighbours.
        shifted_neighbors = surface_mesh_neighbors + 1e-6
        centers_expanded = surface_mesh_centers.unsqueeze(2)
        # Inverse-distance weights do not depend on the variable: compute once.
        inv_dist = 1.0 / torch.norm(
            centers_expanded - shifted_neighbors, dim=-1, keepdim=True
        )
        inv_dist_sum = torch.sum(inv_dist, dim=2)
        encoding_list = [
            encoding_node.unsqueeze(2).expand(-1, -1, num_sample_points, -1),
            encoding_g.unsqueeze(2).expand(-1, -1, num_sample_points, -1),
        ]
        if param_encoding is not None:
            # The parameter encoding must be part of the aggregation input here
            # too, exactly as in the two-loop path.
            encoding_list.append(
                param_encoding.unsqueeze(2).expand(-1, -1, num_sample_points, -1)
            )
        basis_input = torch.cat((centers_expanded, shifted_neighbors), dim=2)
        outputs = []
        for f in range(num_variables):
            # Passing nn_basis[f](...) inline lets its output be freed right
            # after the concatenation.
            agg_output = agg_model[f](
                torch.cat((nn_basis[f](basis_input), *encoding_list), dim=-1)
            )
            output_center, output_neighbor = torch.split(
                agg_output, [1, num_sample_points - 1], dim=2
            )
            output_center = output_center.squeeze(2)
            if num_sample_points > 1:
                weighted = (output_neighbor * inv_dist).sum(2)
                outputs.append(0.5 * output_center + 0.5 * weighted / inv_dist_sum)
            else:
                outputs.append(output_center)
        return torch.cat(outputs, dim=-1)

    if self.solution_calculation_mode == "two-loop":
        # One sample at a time: lower peak memory than the batched path.
        outputs = []
        for f in range(num_variables):
            output_center = output_neighbor = dist_sum = None
            for p in range(num_sample_points):
                if p == 0:
                    sample_points = surface_mesh_centers
                else:
                    sample_points = surface_mesh_neighbors[:, :, p - 1] + 1e-6
                features = [nn_basis[f](sample_points), encoding_node, encoding_g]
                if param_encoding is not None:
                    features.append(param_encoding)
                prediction = agg_model[f](torch.cat(features, dim=-1))
                if p == 0:
                    output_center = prediction
                    continue
                inv_dist = 1.0 / torch.norm(
                    surface_mesh_centers - sample_points, dim=-1, keepdim=True
                )
                if p == 1:
                    output_neighbor = prediction * inv_dist
                    dist_sum = inv_dist
                else:
                    output_neighbor = output_neighbor + prediction * inv_dist
                    dist_sum = dist_sum + inv_dist
            if num_sample_points > 1:
                outputs.append(0.5 * output_center + 0.5 * output_neighbor / dist_sum)
            else:
                outputs.append(output_center)
        return torch.cat(outputs, dim=-1)

    raise ValueError(
        f"Invalid solution_calculation_mode: {self.solution_calculation_mode}, select 'one-loop' or 'two-loop'."
    )
[docs]
def sample_sphere(self, center, r, num_points):
    """Uniformly sample points inside a 3D ball around each center.

    Directions are drawn from an isotropic Gaussian and normalised to unit
    length. A single radius per sample is drawn as ``r * U**(1/3)`` with
    ``U ~ Uniform(0, 1)``, which makes the samples uniform in volume. The radial
    density grows like ``rho**2``, so larger distances from the center are
    sampled more often than smaller ones.

    Args:
        center: Tensor of shape (batch_size, num_centers, 3) containing center coordinates
        r: Radius of the ball
        num_points: Number of points to sample per center

    Returns:
        Tensor of shape (batch_size, num_centers, num_points, 3) containing
        the sampled points around each center
    """
    unsqueezed_center = center.unsqueeze(2).expand(-1, -1, num_points, -1)
    directions = torch.randn_like(unsqueezed_center)
    directions = directions / torch.norm(directions, dim=-1, keepdim=True)
    # Draw one radius per sample (trailing dim of size 1). A radius per
    # coordinate would stretch each direction anisotropically and break
    # uniformity.
    radii = r * torch.pow(torch.rand_like(unsqueezed_center[..., :1]), 1 / 3)
    return unsqueezed_center + directions * radii
[docs]
def sample_sphere_shell(self, center, r_inner, r_outer, num_points):
    """Uniformly sample points in a 3D spherical shell around each center.

    Directions are drawn from an isotropic Gaussian and normalised to unit
    length. A single radius per sample is drawn by inverse-CDF sampling of a
    density proportional to ``rho**2`` on ``[r_inner, r_outer]``, so samples are
    uniform in volume within the shell.

    Args:
        center: Tensor of shape (batch_size, num_centers, 3) containing center coordinates
        r_inner: Inner radius of the spherical shell
        r_outer: Outer radius of the spherical shell
        num_points: Number of points to sample per center

    Returns:
        Tensor of shape (batch_size, num_centers, num_points, 3) containing
        the sampled points within the spherical shell around each center
    """
    unsqueezed_center = center.unsqueeze(2).expand(-1, -1, num_points, -1)
    directions = torch.randn_like(unsqueezed_center)
    directions = directions / torch.norm(directions, dim=-1, keepdim=True)
    # One radius per sample (trailing dim of size 1): rho**3 is uniform on
    # [r_inner**3, r_outer**3].
    radii_cubed = (
        torch.rand_like(unsqueezed_center[..., :1]) * (r_outer**3 - r_inner**3)
        + r_inner**3
    )
    radii = torch.pow(radii_cubed, 1 / 3)
    return unsqueezed_center + directions * radii
[docs]
def calculate_solution(
    self,
    volume_mesh_centers,
    encoding_g,
    encoding_node,
    global_params_values,
    global_params_reference,
    eval_mode,
    num_sample_points=20,
    noise_intensity=50,
    return_volume_neighbors=False,
):
    """Approximate the solution by sampling random perturbations of each point.

    For each variable ``f``, ``nn_basis[f]`` is evaluated at the point and at
    randomly perturbed copies of it. Its output is concatenated with the node
    and geometry encodings (plus the parameter encoding when
    ``encode_parameters`` is set) and passed through ``agg_model[f]``. When
    ``num_sample_points > 1`` the result is
    ``0.5 * center + 0.5 * (inverse-distance weighted mean of perturbed predictions)``.

    Perturbations are drawn as follows:

    * "one-loop": uniform noise in the cube ``[-1/noise_intensity, 1/noise_intensity]^3``,
      redrawn for each variable.
    * "two-loop": ``self.sample_sphere`` with radius ``1/noise_intensity``, shared
      by all variables. With ``return_volume_neighbors``, two-hop children are
      also drawn in the shell ``[1, 2] / noise_intensity`` around each one-hop
      point, and every sample's prediction is returned.

    Args:
        volume_mesh_centers: Query point coordinates.
        encoding_g: Local geometry encoding per point.
        encoding_node: Positional encoding per point.
        global_params_values / global_params_reference: Global parameters,
            indexed ``[:, k, :]``, normalised as values / reference.
        eval_mode: ``"volume"`` or ``"surface"``; selects the basis/aggregation models.
        num_sample_points: Number of perturbed samples (one-loop: including the
            unperturbed point; two-loop: in addition to it).
        noise_intensity: Inverse scale of the perturbations.
        return_volume_neighbors: Two-loop only; also return sampled points,
            per-sample predictions and the hop structure.

    Returns:
        Tensor with one trailing channel per variable. In two-loop mode with
        ``return_volume_neighbors``, the tuple
        ``(output, sampled_points, field_neighbors, neighbors)``.

    Raises:
        ValueError: If ``eval_mode`` or ``solution_calculation_mode`` is invalid.
    """
    if eval_mode == "volume":
        num_variables = self.num_variables_vol
        nn_basis = self.nn_basis_vol
        agg_model = self.agg_model_vol
    elif eval_mode == "surface":
        num_variables = self.num_variables_surf
        nn_basis = self.nn_basis_surf
        agg_model = self.agg_model_surf
    else:
        raise ValueError(
            f"`eval_mode` must be 'surface' or 'volume', got {eval_mode=}"
        )

    param_encoding = None
    if self.encode_parameters:
        # Normalise each global parameter by its reference and broadcast it to
        # every point before encoding.
        num_points = volume_mesh_centers.shape[1]
        processed_parameters = [
            global_params_values[:, k : k + 1, :].expand(-1, num_points, -1)
            / global_params_reference[:, k : k + 1, :]
            for k in range(global_params_values.shape[1])
        ]
        param_encoding = self.parameter_model(torch.cat(processed_parameters, dim=-1))

    if self.solution_calculation_mode == "one-loop":

        def _over_samples(t):
            # Add a leading sample axis of length num_sample_points.
            return t.unsqueeze(0).expand(num_sample_points, -1, -1, -1)

        other_terms = [_over_samples(encoding_node), _over_samples(encoding_g)]
        if param_encoding is not None:
            # The parameter encoding must also be broadcast over the sample
            # axis; concatenating it as a 3-D tensor with 4-D terms fails.
            other_terms.append(_over_samples(param_encoding))
        centers_expanded = _over_samples(volume_mesh_centers)

        outputs = []
        for f in range(num_variables):
            # Noise has shape (num_sample_points, batch_size, num_points, 3).
            noise = 2 * (torch.rand_like(centers_expanded) - 0.5) / noise_intensity
            noise_dist = torch.norm(noise, dim=-1, keepdim=True)
            _, neighbor_dist = torch.split(
                noise_dist, [1, num_sample_points - 1], dim=0
            )
            # Sample 0 is the unperturbed point.
            noise[0] = torch.zeros_like(noise[0])
            perturbed = volume_mesh_centers + noise
            # Passing nn_basis[f](...) inline lets its output and the
            # concatenation be freed immediately.
            output = agg_model[f](
                torch.cat((nn_basis[f](perturbed), *other_terms), dim=-1)
            )
            output_center, output_neighbor = torch.split(
                output, [1, num_sample_points - 1], dim=0
            )
            # Drop the sample axis (dim 0); dim 1 is the batch axis.
            output_center = output_center.squeeze(0)
            if num_sample_points > 1:
                weighted = (output_neighbor / neighbor_dist).sum(0)
                dist_sum = torch.sum(1.0 / neighbor_dist, dim=0)
                outputs.append(0.5 * output_center + 0.5 * weighted / dist_sum)
            else:
                outputs.append(output_center)
        return torch.cat(outputs, dim=-1)

    if self.solution_calculation_mode != "two-loop":
        raise ValueError(
            f"Invalid solution_calculation_mode: {self.solution_calculation_mode}, select 'one-loop' or 'two-loop'."
        )

    perturbed_points = [volume_mesh_centers.unsqueeze(2)]
    if return_volume_neighbors:
        num_hop1 = num_sample_points
        # Children per one-hop node
        num_hop2 = num_sample_points // 2 if num_sample_points != 1 else 1
        neighbors = defaultdict(list)
        hop1 = self.sample_sphere(volume_mesh_centers, 1 / noise_intensity, num_hop1)
        for i in range(num_hop1):
            neighbors[0].append(len(perturbed_points))
            perturbed_points.append(hop1[:, :, i : i + 1, :])
        for i in range(num_hop1):
            parent_idx = i + 1  # index 0 is the original point
            children = self.sample_sphere_shell(
                perturbed_points[parent_idx].squeeze(2),
                1 / noise_intensity,
                2 / noise_intensity,
                num_hop2,
            )
            for c in range(num_hop2):
                neighbors[parent_idx].append(len(perturbed_points))
                perturbed_points.append(children[:, :, c : c + 1, :])
        neighbors = dict(neighbors)
        field_neighbors = {i: [] for i in range(num_variables)}
    else:
        perturbed_points.append(
            self.sample_sphere(volume_mesh_centers, 1 / noise_intensity, num_sample_points)
        )
    perturbed = torch.cat(perturbed_points, dim=2)

    outputs = []
    for f in range(num_variables):
        output_center = output_neighbor = dist_sum = None
        for p in range(perturbed.shape[2]):
            sample = perturbed[:, :, p, :]
            features = [nn_basis[f](sample), encoding_node, encoding_g]
            if param_encoding is not None:
                features.append(param_encoding)
            # Evaluate the aggregation model once per sample and reuse it.
            prediction = agg_model[f](torch.cat(features, dim=-1))
            if return_volume_neighbors:
                field_neighbors[f].append(prediction)
            if p == 0:
                output_center = prediction
                continue
            inv_dist = 1.0 / torch.norm(
                sample - volume_mesh_centers, dim=-1, keepdim=True
            )
            if p == 1:
                output_neighbor = prediction * inv_dist
                dist_sum = inv_dist
            else:
                output_neighbor = output_neighbor + prediction * inv_dist
                dist_sum = dist_sum + inv_dist
        if return_volume_neighbors:
            field_neighbors[f] = torch.stack(field_neighbors[f], dim=2)
        if num_sample_points > 1:
            # Blending applies to the main point only, not the perturbed ones.
            outputs.append(0.5 * output_center + 0.5 * output_neighbor / dist_sum)
        else:
            outputs.append(output_center)
    output_all = torch.cat(outputs, dim=-1)

    if return_volume_neighbors:
        field_neighbors = torch.cat(
            [field_neighbors[i] for i in range(num_variables)], dim=3
        )
        return output_all, perturbed, field_neighbors, neighbors
    return output_all
[docs]
@profile
def forward(self, data_dict, return_volume_neighbors=False):
    """Run the model on one batch of named input tensors.

    Args:
        data_dict: Mapping of input tensors. Always read:
            ``geometry_coordinates``, ``surf_grid``, ``sdf_surf_grid``,
            ``surface_min_max``, ``global_params_values``,
            ``global_params_reference``. Read only when volume outputs are
            enabled: ``grid``, ``sdf_grid``, ``volume_min_max``,
            ``sdf_nodes``, ``pos_volume_closest``,
            ``pos_volume_center_of_mass``, ``volume_mesh_centers``. Read only
            when surface outputs are enabled: ``pos_surface_center_of_mass``,
            ``surface_mesh_centers``, ``surface_normals``, ``surface_areas``,
            ``surface_mesh_neighbors``, ``surface_neighbors_normals``,
            ``surface_neighbors_areas``.
        return_volume_neighbors: Forwarded to ``calculate_solution`` for the
            volume branch.

    Returns:
        ``(output_vol, output_surf)``; an entry is ``None`` when that output is
        disabled. ``output_vol`` is whatever ``calculate_solution`` returns.
    """
    geo_centers = data_dict["geometry_coordinates"]
    s_grid = data_dict["surf_grid"]
    sdf_surf_grid = data_dict["sdf_surf_grid"]
    surf_max = data_dict["surface_min_max"][:, 1]
    surf_min = data_dict["surface_min_max"][:, 0]
    global_params_values = data_dict["global_params_values"]
    global_params_reference = data_dict["global_params_reference"]

    has_volume = self.output_features_vol is not None
    has_surface = self.output_features_surf is not None

    if has_volume:
        p_grid = data_dict["grid"]
        sdf_grid = data_dict["sdf_grid"]
        vol_max = data_dict["volume_min_max"][:, 1]
        vol_min = data_dict["volume_min_max"][:, 0]
        # Rescale the STL nodes to [-1, 1] over the volume bounding box.
        geo_centers_vol = 2.0 * (geo_centers - vol_min) / (vol_max - vol_min) - 1
        encoding_g_vol = self.geo_rep_volume(geo_centers_vol, p_grid, sdf_grid)

        sdf_nodes = data_dict["sdf_nodes"]
        pos_volume_closest = data_dict["pos_volume_closest"]
        pos_volume_center_of_mass = data_dict["pos_volume_center_of_mass"]
        node_features_vol = (
            torch.cat(
                (sdf_nodes, pos_volume_closest, pos_volume_center_of_mass), dim=-1
            )
            if self.use_sdf_in_basis_func
            else pos_volume_center_of_mass
        )
        encoding_node_vol = self.position_encoder(node_features_vol, eval_mode="volume")

    if has_surface:
        # Rescale the STL nodes to [-1, 1] over the surface bounding box.
        geo_centers_surf = 2.0 * (geo_centers - surf_min) / (surf_max - surf_min) - 1
        encoding_g_surf = self.geo_rep_surface(geo_centers_surf, s_grid, sdf_surf_grid)
        pos_surface_center_of_mass = data_dict["pos_surface_center_of_mass"]
        encoding_node_surf = self.position_encoder(
            pos_surface_center_of_mass, eval_mode="surface"
        )

    if has_surface and has_volume and self.combined_vol_surf:
        # Both UNets consume the concatenation of the two global encodings.
        stacked_encoding = torch.cat((encoding_g_vol, encoding_g_surf), dim=1)
        encoding_g_surf = self.combined_unet_surf(stacked_encoding)
        encoding_g_vol = self.combined_unet_vol(stacked_encoding)

    output_vol = None
    if has_volume:
        volume_mesh_centers = data_dict["volume_mesh_centers"]
        local_encoding_vol = self.geo_encoding_local(
            0.5 * encoding_g_vol, volume_mesh_centers, p_grid, mode="volume"
        )
        output_vol = self.calculate_solution(
            volume_mesh_centers,
            local_encoding_vol,
            encoding_node_vol,
            global_params_values,
            global_params_reference,
            eval_mode="volume",
            num_sample_points=self.num_sample_points_volume,
            return_volume_neighbors=return_volume_neighbors,
        )

    output_surf = None
    if has_surface:
        surface_mesh_centers = data_dict["surface_mesh_centers"]
        surface_normals = data_dict["surface_normals"]
        raw_surface_areas = data_dict["surface_areas"]
        surface_mesh_neighbors = data_dict["surface_mesh_neighbors"]
        surface_neighbors_normals = data_dict["surface_neighbors_normals"]
        raw_neighbors_areas = data_dict["surface_neighbors_areas"]
        surface_areas = torch.unsqueeze(raw_surface_areas, -1)
        surface_neighbors_areas = torch.unsqueeze(raw_neighbors_areas, -1)

        local_encoding_surf = self.geo_encoding_local(
            0.5 * encoding_g_surf, surface_mesh_centers, s_grid, mode="surface"
        )
        output_surf = self.calculate_solution_with_neighbors(
            surface_mesh_centers,
            local_encoding_surf,
            encoding_node_surf,
            surface_mesh_neighbors,
            surface_normals,
            surface_neighbors_normals,
            surface_areas,
            surface_neighbors_areas,
            global_params_values,
            global_params_reference,
            num_sample_points=self.num_sample_points_surface,
        )

    return output_vol, output_surf