# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Geometric transformations for simplicial meshes.
This module implements linear and affine transformations with intelligent
cache handling. By default, all caches are invalidated; transformations
explicitly opt-in to preserve/transform specific cache fields.
Cached fields handled:
- areas: point_data and cell_data
- normals: point_data and cell_data
- centroids: cell_data only
"""
from collections.abc import Sequence
from typing import TYPE_CHECKING, Literal
import torch
import torch.nn.functional as F
from jaxtyping import Float
from tensordict import TensorDict
if TYPE_CHECKING:
from physicsnemo.mesh.mesh import Mesh
### User Data Transformation ###
def _transform_tensordict(
data: TensorDict,
matrix: Float[torch.Tensor, "new_n_spatial_dims n_spatial_dims"],
n_spatial_dims: int,
field_type: str,
mask: TensorDict | None = None,
) -> TensorDict:
"""Transform vector/tensor fields in a TensorDict.
When ``mask`` is ``None``, all fields with compatible shapes are
transformed. When ``mask`` is a ``TensorDict`` of scalar bool leaves,
only fields whose corresponding mask value is ``True`` are transformed;
fields absent from the mask are left unchanged.
Parameters
----------
data : TensorDict
TensorDict with cache already stripped.
matrix : Float[torch.Tensor, "new_n_spatial_dims n_spatial_dims"]
Transformation matrix, shape :math:`(S', S)`.
n_spatial_dims : int
Expected spatial dimensionality.
field_type : str
Description for error messages (e.g., ``"point_data"``, ``"global_data"``).
mask : TensorDict or None, optional
Parallel TensorDict with scalar ``bool`` tensor leaves. When
provided, only keys whose mask is ``True`` are transformed.
Keys absent from the mask default to ``False`` (not transformed).
Returns
-------
TensorDict
TensorDict with transformed fields (modified in place).
"""
batch_size = data.batch_size
has_batch_dim = len(batch_size) > 0
def transform_field(key: str, value: torch.Tensor) -> torch.Tensor:
"""Transform a single vector or tensor field."""
shape = value.shape[len(batch_size) :]
### Scalars are invariant under linear transformations
if len(shape) == 0:
return value
### Validate spatial dimension compatibility
if shape[0] != n_spatial_dims:
raise ValueError(
f"Cannot transform {field_type} field {key!r} with shape {value.shape}. "
f"First spatial dimension must be {n_spatial_dims}, but got {shape[0]}. "
f"Use a dict to select specific fields, e.g. "
f'transform_{field_type}={{"field_name": True}}.'
)
### Vector field: v' = v @ M^T
if len(shape) == 1:
return value @ matrix.T
### Rank-2 tensor field: T' = M @ T @ M^T (e.g., stress tensors)
if shape == (n_spatial_dims, n_spatial_dims):
if has_batch_dim:
return torch.einsum("ij,bjk,lk->bil", matrix, value, matrix)
else:
return torch.einsum("ij,jk,lk->il", matrix, value, matrix)
### Higher-rank tensor field: apply transformation to each spatial index
if all(s == n_spatial_dims for s in shape):
result = value
# Index chars for einsum (skip 'b' for batch and 'z' for contraction)
chars = "acdefghijklmnopqrstuvwxy"
batch_prefix = "b" if has_batch_dim else ""
for dim_idx in range(len(shape)):
input_indices = "".join(
chars[i].upper()
if i < dim_idx
else "z"
if i == dim_idx
else chars[i]
for i in range(len(shape))
)
output_indices = "".join(
chars[i].upper() if i <= dim_idx else chars[i]
for i in range(len(shape))
)
einsum_str = f"{chars[dim_idx].upper()}z,{batch_prefix}{input_indices}->{batch_prefix}{output_indices}"
result = torch.einsum(einsum_str, matrix, result)
return result
raise ValueError(
f"Cannot transform {field_type} field {key!r} with shape {value.shape}. "
f"Expected all spatial dimensions to be {n_spatial_dims}, but got {shape}"
)
if mask is None:
transformed = data.named_apply(transform_field, batch_size=batch_size)
else:
def selective_transform(
key: str, value: torch.Tensor, should_transform: torch.Tensor
) -> torch.Tensor:
if not should_transform.item():
return value
return transform_field(key, value)
transformed = data.named_apply(
selective_transform,
mask,
default=torch.tensor(False),
batch_size=batch_size,
)
data.update(transformed)
return data
### Rotation Matrix Construction ###
def _build_rotation_matrix(
angle: float | Float[torch.Tensor, ""],
axis: Float[torch.Tensor, " n_spatial_dims"] | None,
device: torch.device,
) -> Float[torch.Tensor, "n_spatial_dims n_spatial_dims"]:
"""Build rotation matrix for 2D or 3D.
Parameters
----------
angle : float or Float[torch.Tensor, ""]
Rotation angle in radians.
axis : Float[torch.Tensor, " n_spatial_dims"] or None
Rotation axis vector. None for 2D, shape :math:`(3,)` for 3D.
device : device
Target device for the output matrix.
Returns
-------
Float[torch.Tensor, "n_spatial_dims n_spatial_dims"]
Rotation matrix: :math:`(2, 2)` if axis is None,
:math:`(3, 3)` if axis has shape :math:`(3,)`.
"""
angle = torch.as_tensor(angle, device=device)
c, s = torch.cos(angle), torch.sin(angle)
if axis is None:
### 2D rotation matrix: [[c, -s], [s, c]]
return torch.stack([torch.stack([c, -s]), torch.stack([s, c])])
### 3D rotation using Rodrigues' formula: R = cI + s[u]_× + (1-c)(u⊗u)
axis = torch.as_tensor(axis, device=device, dtype=angle.dtype)
if axis.shape != (3,):
raise NotImplementedError(
f"Rotation only supported for 2D (axis=None) or 3D (axis shape (3,)). "
f"Got axis with shape {axis.shape}."
)
if axis.norm() < 1e-10:
raise ValueError(f"Axis vector has near-zero length: {axis.norm()=}")
u = F.normalize(axis, dim=0, eps=0.0)
ux, uy, uz = u
zero = torch.zeros((), device=device, dtype=u.dtype)
# Skew-symmetric cross-product matrix [u]_×
u_cross = torch.stack(
[
torch.stack([zero, -uz, uy]),
torch.stack([uz, zero, -ux]),
torch.stack([-uy, ux, zero]),
]
)
identity = torch.eye(3, device=device, dtype=u.dtype)
return c * identity + s * u_cross + (1 - c) * u.outer(u)
### Axis Resolution ###
def _resolve_rotation_axis(
axis: Float[torch.Tensor, " n_spatial_dims"]
| Sequence[float]
| Literal["x", "y", "z"]
| None,
n_spatial_dims: int,
device: torch.device,
) -> Float[torch.Tensor, " n_spatial_dims"] | None:
"""Normalize an axis specification into a tensor or None.
Parameters
----------
axis : Float[torch.Tensor, " n_spatial_dims"] or Sequence[float] or {"x", "y", "z"} or None
Rotation axis. ``None`` for 2D, tensor/sequence/string for 3D.
n_spatial_dims : int
Number of spatial dimensions (used for validation).
device : torch.device
Target device for the output tensor.
Returns
-------
Float[torch.Tensor, " n_spatial_dims"] or None
Normalized axis tensor with shape :math:`(3,)` and dtype
``float32``, or ``None`` for 2D rotation.
"""
if isinstance(axis, str):
axis_map = {"x": 0, "y": 1, "z": 2}
if axis not in axis_map:
raise ValueError(f"axis must be 'x', 'y', or 'z', got {axis!r}")
idx = axis_map[axis]
if idx >= n_spatial_dims:
raise ValueError(
f"axis={axis!r} is invalid for mesh with "
f"n_spatial_dims={n_spatial_dims}"
)
resolved = torch.zeros(n_spatial_dims, device=device)
resolved[idx] = 1.0
return resolved
if axis is not None:
axis = torch.as_tensor(axis, device=device, dtype=torch.float32)
expected_dims = 2 if axis is None else 3
if n_spatial_dims != expected_dims:
raise ValueError(
f"axis={'None' if axis is None else 'provided'} implies "
f"{expected_dims}D rotation, but mesh has "
f"n_spatial_dims={n_spatial_dims}"
)
return axis
### Matrix Construction Helpers ###
[docs]
def rotation_matrix(
angle: float,
axis: Float[torch.Tensor, " n_spatial_dims"]
| Sequence[float]
| Literal["x", "y", "z"]
| None,
n_spatial_dims: int,
device: torch.device,
dtype: torch.dtype,
) -> Float[torch.Tensor, "n_spatial_dims n_spatial_dims"]:
r"""Build a rotation matrix from angle and axis.
Parameters
----------
angle : float
Rotation angle in radians (counterclockwise, right-hand rule).
axis : Float[torch.Tensor, " n_spatial_dims"] or Sequence[float] or {"x", "y", "z"} or None
Rotation axis. ``None`` for 2D, tensor/sequence/string for 3D.
n_spatial_dims : int
Number of spatial dimensions.
device : torch.device
Target device for the output matrix.
dtype : torch.dtype
Target dtype for the output matrix.
Returns
-------
Float[torch.Tensor, "n_spatial_dims n_spatial_dims"]
Rotation matrix, shape :math:`(S, S)`.
"""
resolved = _resolve_rotation_axis(axis, n_spatial_dims, device)
return _build_rotation_matrix(angle=angle, axis=resolved, device=device).to(
dtype=dtype
)
[docs]
def scale_matrix(
factor: float | Float[torch.Tensor, " n_spatial_dims"] | Sequence[float],
n_spatial_dims: int,
device: torch.device,
dtype: torch.dtype,
) -> Float[torch.Tensor, "n_spatial_dims n_spatial_dims"]:
r"""Build a diagonal scale matrix from a factor specification.
Parameters
----------
factor : float or Float[torch.Tensor, " n_spatial_dims"] or Sequence[float]
Scale factor(s). Scalar for uniform, vector for non-uniform.
n_spatial_dims : int
Number of spatial dimensions.
device : torch.device
Target device for the output matrix.
dtype : torch.dtype
Target dtype for the output matrix.
Returns
-------
Float[torch.Tensor, "n_spatial_dims n_spatial_dims"]
Diagonal scale matrix, shape :math:`(S, S)`.
Raises
------
ValueError
If ``factor`` is a vector whose length does not match
``n_spatial_dims``.
"""
factor_t = torch.as_tensor(factor, device=device, dtype=dtype)
if factor_t.ndim == 0:
factor_t = factor_t.expand(n_spatial_dims)
elif not torch.compiler.is_compiling() and factor_t.shape[-1] != n_spatial_dims:
raise ValueError(
f"factor must be scalar or shape ({n_spatial_dims},), got {factor_t.shape}"
)
return torch.diag(factor_t)
### Transform Mask Normalization ###
def _normalize_transform_mask(
spec: bool | dict | TensorDict,
) -> TensorDict | None:
"""Convert a transform spec to a TensorDict mask, or None.
Parameters
----------
spec : bool or dict or TensorDict
Transform specification. ``True`` returns ``None`` (transform all).
A ``dict`` of bools is recursively converted to a ``TensorDict``
with scalar ``bool`` tensor leaves. A ``TensorDict`` is used
directly.
Returns
-------
TensorDict or None
Mask TensorDict with scalar bool leaves, or ``None`` to
transform all fields.
"""
if spec is True:
return None
if isinstance(spec, TensorDict):
return spec
if isinstance(spec, dict):
return TensorDict(
{
k: torch.tensor(v)
if isinstance(v, bool)
else _normalize_transform_mask(v)
for k, v in spec.items()
},
batch_size=[],
)
raise TypeError(f"Expected bool, dict, or TensorDict, got {type(spec)!r}")
def _maybe_transform_data(
data: TensorDict,
spec: bool | TensorDict,
matrix: torch.Tensor,
n_spatial_dims: int,
label: str,
) -> TensorDict:
"""Clone and transform a data TensorDict if spec is not False."""
if spec is False:
return data
cloned = data.clone()
_transform_tensordict(
cloned,
matrix,
n_spatial_dims,
label,
mask=_normalize_transform_mask(spec),
)
return cloned
### Public API ###
[docs]
def translate(
mesh: "Mesh",
offset: Float[torch.Tensor, " n_spatial_dims"] | Sequence[float],
) -> "Mesh":
"""Apply a translation to the mesh.
Translation only affects point positions and centroids. Vector/tensor fields
are unchanged by translation (they represent directions, not positions).
Parameters
----------
mesh : Mesh
Input mesh to translate.
offset : Float[torch.Tensor, " n_spatial_dims"] or Sequence[float]
Translation vector, shape :math:`(S,)`.
Returns
-------
Mesh
New Mesh with translated geometry.
Notes
-----
Cache Handling:
- areas: Unchanged
- centroids: Translated
- normals: Unchanged
"""
offset = torch.as_tensor(offset, device=mesh.points.device, dtype=mesh.points.dtype)
if not torch.compiler.is_compiling():
if offset.shape[-1] != mesh.n_spatial_dims:
raise ValueError(
f"offset must have shape ({mesh.n_spatial_dims},), got {offset.shape}"
)
new_points = mesh.points + offset
device = mesh.points.device
new_cache = TensorDict(
{
"cell": TensorDict({}, batch_size=[mesh.n_cells]),
"point": TensorDict({}, batch_size=[mesh.n_points]),
"topology": mesh._cache.get("topology", TensorDict({})),
},
device=device,
)
### Areas and normals are unchanged by translation
for category in ("cell", "point"):
for key in ("areas", "normals"):
if (v := mesh._cache.get((category, key), None)) is not None:
new_cache[category, key] = v
### Centroids are translated
if (v := mesh._cache.get(("cell", "centroids"), None)) is not None:
new_cache["cell", "centroids"] = v + offset
from physicsnemo.mesh.mesh import Mesh
return Mesh(
points=new_points,
cells=mesh.cells,
point_data=mesh.point_data,
cell_data=mesh.cell_data,
global_data=mesh.global_data,
_cache=new_cache,
)
[docs]
def rotate(
mesh: "Mesh",
angle: float,
axis: Float[torch.Tensor, " n_spatial_dims"]
| Sequence[float]
| Literal["x", "y", "z"]
| None = None,
center: Float[torch.Tensor, " n_spatial_dims"] | Sequence[float] | None = None,
transform_point_data: bool | TensorDict = False,
transform_cell_data: bool | TensorDict = False,
transform_global_data: bool | TensorDict = False,
) -> "Mesh":
"""Rotate the mesh about an axis by a specified angle.
Parameters
----------
mesh : Mesh
Input mesh to rotate.
angle : float
Rotation angle in radians (counterclockwise, right-hand rule).
axis : Float[torch.Tensor, " n_spatial_dims"] or Sequence[float] or {"x", "y", "z"} or None
Rotation axis vector. ``None`` for 2D, shape :math:`(3,)` for 3D.
String literals ``"x"``, ``"y"``, ``"z"`` are converted to unit
vectors ``(1,0,0)``, ``(0,1,0)``, ``(0,0,1)`` respectively.
center : Float[torch.Tensor, " n_spatial_dims"] or Sequence[float] or None
Center point for rotation. If ``None``, rotates about the origin.
transform_point_data : bool or TensorDict
Controls transformation of ``point_data`` fields. See
:func:`transform` for full semantics.
transform_cell_data : bool or TensorDict
Same semantics as ``transform_point_data``, for ``cell_data``.
transform_global_data : bool or TensorDict
Same semantics as ``transform_point_data``, for ``global_data``.
Returns
-------
Mesh
New Mesh with rotated geometry.
Notes
-----
Cache Handling:
- areas: Unchanged (rotation preserves volumes)
- centroids: Rotated
- normals: Rotated
"""
R = rotation_matrix(
angle=angle,
axis=axis,
n_spatial_dims=mesh.n_spatial_dims,
device=mesh.points.device,
dtype=mesh.points.dtype,
)
### Handle center by translate-rotate-translate
if center is not None:
center = torch.as_tensor(
center, device=mesh.points.device, dtype=mesh.points.dtype
)
return translate(
rotate(
translate(mesh, -center),
angle,
axis,
center=None,
transform_point_data=transform_point_data,
transform_cell_data=transform_cell_data,
transform_global_data=transform_global_data,
),
center,
)
return transform(
mesh,
matrix=R,
transform_point_data=transform_point_data,
transform_cell_data=transform_cell_data,
transform_global_data=transform_global_data,
assume_invertible=True,
)
[docs]
def scale(
mesh: "Mesh",
factor: float | Float[torch.Tensor, " n_spatial_dims"] | Sequence[float],
center: Float[torch.Tensor, " n_spatial_dims"] | Sequence[float] | None = None,
transform_point_data: bool | TensorDict = False,
transform_cell_data: bool | TensorDict = False,
transform_global_data: bool | TensorDict = False,
assume_invertible: bool | None = None,
) -> "Mesh":
"""Scale the mesh by specified factor(s).
Parameters
----------
mesh : Mesh
Input mesh to scale.
factor : float or Float[torch.Tensor, " n_spatial_dims"] or Sequence[float]
Scale factor(s). Scalar for uniform, vector for non-uniform.
center : Float[torch.Tensor, " n_spatial_dims"] or Sequence[float] or None
Center point for scaling. If ``None``, scales about the origin.
transform_point_data : bool or TensorDict
Controls transformation of ``point_data`` fields. See
:func:`transform` for full semantics.
transform_cell_data : bool or TensorDict
Same semantics as ``transform_point_data``, for ``cell_data``.
transform_global_data : bool or TensorDict
Same semantics as ``transform_point_data``, for ``global_data``.
assume_invertible : bool or None
Controls cache propagation:
- True: Assume all factors are non-zero, propagate caches (compile-safe)
- False: Assume some factor is zero, skip cache propagation (compile-safe)
- None: Check determinant at runtime (may cause graph breaks under torch.compile)
Returns
-------
Mesh
New Mesh with scaled geometry.
Notes
-----
Cache Handling:
- areas: Scaled correctly. For non-isotropic transforms of codimension-1
embedded manifolds, per-element scaling is computed using normals.
- centroids: Scaled
- normals: Transformed by inverse-transpose (direction adjusted, magnitude normalized)
"""
M = scale_matrix(
factor=factor,
n_spatial_dims=mesh.n_spatial_dims,
device=mesh.points.device,
dtype=mesh.points.dtype,
)
### Handle center by translate-scale-translate
if center is not None:
center = torch.as_tensor(
center, device=mesh.points.device, dtype=mesh.points.dtype
)
return translate(
scale(
translate(mesh, -center),
factor,
center=None,
transform_point_data=transform_point_data,
transform_cell_data=transform_cell_data,
transform_global_data=transform_global_data,
assume_invertible=assume_invertible,
),
center,
)
return transform(
mesh,
matrix=M,
transform_point_data=transform_point_data,
transform_cell_data=transform_cell_data,
transform_global_data=transform_global_data,
assume_invertible=assume_invertible,
)