Built-in Transforms

PhysicsNeMo transforms are the core data-manipulation tool for datapipes. All transforms inherit from the Transform base class.

class physicsnemo.datapipes.transforms.base.Transform

Bases: ABC

Abstract base class for all transforms.

Transforms operate on a TensorDict and return a modified TensorDict. They are designed to run on GPU tensors for maximum performance. Metadata is not passed to transforms (handled separately by Dataset/DataLoader).

Subclasses must implement:

  • __call__(data: TensorDict) -> TensorDict

Optionally override:

  • extra_repr() -> str: For custom repr output

  • state_dict() -> dict: For serialization

  • load_state_dict(state_dict: dict): For deserialization

Examples

>>> from tensordict import TensorDict
>>> from physicsnemo.datapipes.transforms.base import Transform
>>> class MyTransform(Transform):
...     def __init__(self, scale: float):
...         super().__init__()
...         self.scale = scale
...
...     def __call__(self, data: TensorDict) -> TensorDict:
...         # Scale every tensor in the TensorDict
...         return data.apply(lambda x: x * self.scale)

property device: device | None

The device this transform operates on.

Returns:

The device, or None if not set.

Return type:

torch.device or None

extra_repr() str

Return extra information for repr.

Override this to add transform-specific info to the repr.

Returns:

Extra representation string.

Return type:

str

load_state_dict(state_dict: dict[str, Any]) None

Load state from a state dictionary.

Override this to restore transform state.

Parameters:

state_dict (dict[str, Any]) – State dictionary to load from.

state_dict() dict[str, Any]

Return a dictionary containing the transform’s state.

Override this for transforms with learnable or configurable state.

Returns:

State dictionary.

Return type:

dict[str, Any]

to(
device: device | str,
) Transform

Move any internal tensors to the specified device.

This default implementation automatically moves any tensor attributes found in self.__dict__ to the specified device. Override this method if your transform requires custom device handling.

Parameters:

device (torch.device or str) – Target device.

Returns:

Self for chaining.

Return type:

Transform
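The subclassing contract above (an abstract __call__, plus optional state_dict and load_state_dict for stateful transforms) can be illustrated without PhysicsNeMo installed. In this minimal sketch, plain dicts of float lists stand in for TensorDicts, and both Transform and the hypothetical ScaleBy class are re-creations for illustration only, not the library's code.

```python
from abc import ABC, abstractmethod
from typing import Any

Sample = dict[str, list[float]]  # stand-in for a TensorDict


class Transform(ABC):
    """Hypothetical stand-in for the documented Transform base class."""

    @abstractmethod
    def __call__(self, data: Sample) -> Sample: ...

    def state_dict(self) -> dict[str, Any]:
        return {}  # stateless by default

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        pass


class ScaleBy(Transform):
    """Hypothetical stateful transform that multiplies every value by a factor."""

    def __init__(self, scale: float) -> None:
        self.scale = scale

    def __call__(self, data: Sample) -> Sample:
        # Build a new sample instead of mutating the input.
        return {key: [x * self.scale for x in values] for key, values in data.items()}

    def state_dict(self) -> dict[str, Any]:
        return {"scale": self.scale}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.scale = state_dict["scale"]


saved = ScaleBy(2.0).state_dict()
restored = ScaleBy(1.0)
restored.load_state_dict(saved)
print(restored({"pressure": [1.0, 2.5]}))  # {'pressure': [2.0, 5.0]}
```

Only __call__ is abstract, so a stateless transform needs nothing else; the state_dict round trip matters only when a transform carries configuration that must survive serialization.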

To implement a new transform, override the __call__ method, along with any initialization or configuration your transform needs.

The input to a transform is mutable by default, so the order in which transformations are applied matters.

In general, transforms are transactional: they take input, manipulate it, and return output, and they almost never update internal state. Transforms should be device-agnostic and follow a compute-follows-data principle, operating on data on the device where it already resides whenever possible.

By default, transforms accept and return TensorDict objects, but this is not strictly enforced. If you implement a custom transform that returns a different data type, every downstream transform must expect that type. One example, found in the minimal datapipe examples, converts the TensorDict into a PyTorch Geometric graph object. This kind of conversion is perfectly valid, but it requires a custom collation function and prevents the use of TensorDict-based transforms downstream.

One special transformation is Compose, which takes a list of transformations and applies them in order as a single transformation. Use Compose much like the torch.nn.Sequential container, which stacks PyTorch modules together.

class physicsnemo.datapipes.transforms.compose.Compose(
transforms: Sequence[Transform],
)

Bases: Transform

Compose multiple transforms into a sequential pipeline.

Applies transforms in order, passing the output of each as input to the next.

Parameters:

transforms (Sequence[Transform]) – Sequence of transforms to apply in order.
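The sequential behavior of Compose, and why step order matters, can be shown with a minimal sketch. MiniCompose, shift, and double are hypothetical names; dicts of float lists stand in for TensorDicts, and unlike the documented append, which raises TypeError for anything that is not a Transform instance, this sketch only checks that the new step is callable.

```python
from typing import Callable, Sequence

Sample = dict[str, list[float]]  # stand-in for a TensorDict
Step = Callable[[Sample], Sample]


class MiniCompose:
    """Hypothetical stand-in for Compose: applies steps in order as one transform."""

    def __init__(self, transforms: Sequence[Step]) -> None:
        self.transforms = list(transforms)

    def append(self, transform: Step) -> None:
        if not callable(transform):
            raise TypeError(f"expected a callable, got {type(transform).__name__}")
        self.transforms.append(transform)

    def __call__(self, data: Sample) -> Sample:
        for transform in self.transforms:
            data = transform(data)  # each output feeds the next step
        return data


def shift(data: Sample) -> Sample:
    return {k: [x - 1.0 for x in v] for k, v in data.items()}


def double(data: Sample) -> Sample:
    return {k: [x * 2.0 for x in v] for k, v in data.items()}


sample = {"p": [1.0, 3.0]}
pipeline = MiniCompose([shift])
pipeline.append(double)
print(pipeline(sample))                      # {'p': [0.0, 4.0]}
print(MiniCompose([double, shift])(sample))  # {'p': [1.0, 5.0]}
```

Swapping the two steps changes the result, which is the same reason normalization and subsampling must be composed in a deliberate order.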

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from physicsnemo.datapipes.transforms import Normalize, SubsamplePoints
>>> from physicsnemo.datapipes.transforms.compose import Compose
>>> sample = TensorDict({
...     "pressure": 101325.0 + 1000.0 * torch.randn(10000),
... })
>>> normalize = Normalize(input_keys=["pressure"], method="mean_std", means={"pressure": 101325.0}, stds={"pressure": 1000.0})
>>> subsample = SubsamplePoints(input_keys=["pressure"], n_points=1000, algorithm="uniform")
>>> pipeline = Compose([normalize, subsample])
>>> transformed = pipeline(sample)
>>> transformed["pressure"].shape
torch.Size([1000])

append(
transform: Transform,
) None

Append a transform to the pipeline.

Parameters:

transform (Transform) – Transform to append.

Raises:

TypeError – If transform is not a Transform instance.

extra_repr() str

Return extra information for repr.

Returns:

Formatted string showing all transforms.

Return type:

str

state_dict() dict[str, Any]

Return state of all transforms.

Returns:

Dictionary containing transform states and types.

Return type:

dict[str, Any]

to(
device: device | str,
) Compose

Move all transforms to the specified device.

Parameters:

device (torch.device or str) – Target device.

Returns:

Self for chaining.

Return type:

Compose

Below is a collection of commonly used transforms in PhysicsNeMo's datapipes. Because transforms take and return TensorDict objects, a common configuration pattern is to specify the input and output keys to operate on, though not every transform follows it.

Most transforms are stateless. Those that do hold internal tensors move them between devices when you call to(), as expected.

Normalization

class physicsnemo.datapipes.transforms.normalize.Normalize(
input_keys: list[str],
means: dict[str, float | Tensor] | float | Tensor | None = None,
stds: dict[str, float | Tensor] | float | Tensor | None = None,
*,
method: Literal['mean_std', 'min_max'] | None = None,
mins: dict[str, float | Tensor] | float | Tensor | None = None,
maxs: dict[str, float | Tensor] | float | Tensor | None = None,
stats_file: str | Path | None = None,
eps: float = 1e-08,
)

Bases: Transform

Normalize specified fields using mean-std or min-max scaling.

Supports two normalization methods:

  • mean_std: Applies (x - mean) / std to each specified field.

  • min_max: Applies (x - center) / half_range, normalizing to [-1, 1], where center = (max + min) / 2 and half_range = (max - min) / 2.

Parameters can be provided directly or loaded from a .npz file via stats_file.

Examples

Mean-std scaling:

>>> import torch
>>> from tensordict import TensorDict
>>> from physicsnemo.datapipes.transforms import Normalize
>>> sample = TensorDict({
...     "pressure": torch.tensor([101325.0, 102325.0, 100325.0]),
...     "velocity": torch.tensor([10.0, -10.0, 0.0]),
... })
>>> norm = Normalize(
...     input_keys=["pressure", "velocity"],
...     method="mean_std",
...     means={"pressure": 101325.0, "velocity": 0.0},
...     stds={"pressure": 1000.0, "velocity": 10.0},
... )
>>> normalized = norm(sample)
>>> normalized["pressure"]
tensor([ 0.,  1., -1.])
>>> normalized["velocity"]
tensor([ 1., -1.,  0.])

Min-max scaling (normalizes to [-1, 1]):

>>> sample = TensorDict({
...     "pressure": torch.tensor([100000.0, 105000.0, 110000.0]),
... })
>>> norm = Normalize(
...     input_keys=["pressure"],
...     method="min_max",
...     mins={"pressure": 100000.0},
...     maxs={"pressure": 110000.0},
... )
>>> normalized = norm(sample)
>>> normalized["pressure"]
tensor([-1.,  0.,  1.])

extra_repr() str

Return extra information for repr.

Override this to add transform-specific info to the repr.

Returns:

Extra representation string.

Return type:

str

inverse(
data: TensorDict,
) TensorDict

Apply inverse normalization (denormalize).

For the mean_std method: x * std + mean

For the min_max method: x * half_range + center

Parameters:

data (TensorDict) – Normalized TensorDict.

Returns:

Denormalized TensorDict.

Return type:

TensorDict
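The forward and inverse formulas above are one affine map and its inverse. This minimal sketch applies them to plain Python floats and reproduces the min-max example; the helper names are hypothetical, and adding eps to the denominator is an assumption about how the eps parameter is used, not documented behavior.

```python
import math


def _shift_and_scale(method, *, mean=0.0, std=1.0, vmin=0.0, vmax=1.0):
    """Return the (shift, scale) pair of the affine map for each method."""
    if method == "mean_std":
        return mean, std
    if method == "min_max":
        return (vmax + vmin) / 2, (vmax - vmin) / 2  # center, half_range
    raise ValueError(f"unknown method: {method!r}")


def normalize(values, method, eps=1e-8, **stats):
    shift, scale = _shift_and_scale(method, **stats)
    # Assumption: eps is added to the denominator to avoid division by zero.
    return [(x - shift) / (scale + eps) for x in values]


def denormalize(values, method, eps=1e-8, **stats):
    shift, scale = _shift_and_scale(method, **stats)
    # Exact algebraic inverse of normalize() above.
    return [x * (scale + eps) + shift for x in values]


pressure = [100000.0, 105000.0, 110000.0]
z = normalize(pressure, "min_max", vmin=100000.0, vmax=110000.0)
print([round(x, 6) for x in z])  # [-1.0, 0.0, 1.0]
back = denormalize(z, "min_max", vmin=100000.0, vmax=110000.0)
print(all(math.isclose(a, b) for a, b in zip(back, pressure)))  # True
```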

load_state_dict(state_dict: dict[str, Any]) None

Load normalization parameters.

state_dict() dict[str, Any]

Return normalization parameters.

to(
device: device | str,
) Normalize

Move normalization parameters to the specified device.

Subsampling

class physicsnemo.datapipes.transforms.subsample.SubsamplePoints(
input_keys: list[str],
n_points: int,
*,
algorithm: Literal['poisson_fixed', 'uniform'] = 'poisson_fixed',
weights_key: str | None = None,
)

Bases: Transform

Subsample points from large point clouds or meshes.

This transform applies coordinated subsampling to multiple tensor fields, ensuring that the same points are selected across all specified keys. Useful for downsampling large volumetric data or point clouds while maintaining correspondence between coordinates and field values.

Supports two sampling algorithms:

  • "poisson_fixed": Near-uniform sampling for very large datasets (> 2^24 points)

  • "uniform": Standard uniform sampling

Optionally supports weighted sampling (e.g., area-weighted for surface meshes) by providing a weights_key.

Parameters:
  • input_keys (list[str]) – List of tensor keys to subsample. All must have the same first dimension size.

  • n_points (int) – Number of points to sample.

  • algorithm ({"poisson_fixed", "uniform"}, default="poisson_fixed") – Sampling algorithm to use.

  • weights_key (str, optional) – Optional key for sampling weights (e.g., "surface_areas" for area-weighted surface sampling). When provided, samples are drawn according to the weights distribution.

Examples

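Unweighted sampling with the default poisson_fixed algorithm:

>>> import torch
>>> from tensordict import TensorDict
>>> from physicsnemo.datapipes.transforms import SubsamplePoints
>>> transform = SubsamplePoints(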
...     input_keys=["volume_mesh_centers", "volume_fields"],
...     n_points=10000,
...     algorithm="poisson_fixed"
... )
>>> sample = TensorDict({
...     "volume_mesh_centers": torch.randn(100000, 3),
...     "volume_fields": torch.randn(100000, 5)
... })
>>> result = transform(sample)
>>> print(result["volume_mesh_centers"].shape)
torch.Size([10000, 3])

Weighted sampling:

>>> transform = SubsamplePoints(
...     input_keys=["surface_mesh_centers", "surface_fields", "surface_normals"],
...     n_points=5000,
...     algorithm="uniform",
...     weights_key="surface_areas"
... )
>>> sample = TensorDict({
...     "surface_mesh_centers": torch.randn(20000, 3),
...     "surface_fields": torch.randn(20000, 2),
...     "surface_normals": torch.randn(20000, 3),
...     "surface_areas": torch.rand(20000)
... })
>>> result = transform(sample)
>>> print(result["surface_mesh_centers"].shape)
torch.Size([5000, 3])

Notes

All specified keys must have the same size in their first dimension. The same indices are applied to all keys to maintain correspondence.
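The correspondence guarantee in the note above (one index set applied to every key) can be sketched in plain Python. This covers only unweighted sampling, uniform without replacement; it does not implement poisson_fixed or weights_key, and subsample_points is a hypothetical helper rather than the library's implementation.

```python
import random


def subsample_points(data, input_keys, n_points, seed=0):
    n_total = len(data[input_keys[0]])
    if any(len(data[key]) != n_total for key in input_keys):
        raise ValueError("all input_keys must share the same first dimension")
    # One index set, drawn uniformly without replacement, shared by every key.
    idx = random.Random(seed).sample(range(n_total), n_points)
    out = dict(data)  # keys not listed in input_keys pass through unchanged
    for key in input_keys:
        out[key] = [data[key][i] for i in idx]
    return out


sample = {
    "centers": [(float(i), 0.0, 0.0) for i in range(100)],
    "fields": [10.0 * i for i in range(100)],
}
result = subsample_points(sample, ["centers", "fields"], n_points=10)
print(len(result["centers"]), len(result["fields"]))  # 10 10
# Each kept field value still belongs to its coordinate.
print(all(f == 10.0 * c[0] for c, f in zip(result["centers"], result["fields"])))  # True
```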

Geometric

class physicsnemo.datapipes.transforms.geometric.ComputeSDF(
input_keys: list[str],
output_key: str,
mesh_coords_key: str,
mesh_faces_key: str,
*,
use_winding_number: bool = True,
closest_points_key: str | None = None,
)

Bases: Transform

Compute signed distance field from a mesh.

Computes the signed distance from query points to the nearest point on a triangular mesh surface. Optionally returns the closest points on the mesh surface for each query point.

Parameters:
  • input_keys (list[str]) – List of keys containing query points to compute the SDF for. Each tensor should have shape (N, 3).

  • output_key (str) – Key to store the computed SDF values.

  • mesh_coords_key (str) – Key for mesh vertex coordinates, shape (M, 3).

  • mesh_faces_key (str) – Key for mesh face indices (flattened), shape (F*3,).

  • use_winding_number (bool, default=True) – If True, use winding number for sign determination.

  • closest_points_key (str, optional) – Optional key to store closest points on mesh.

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from physicsnemo.datapipes.transforms.geometric import ComputeSDF
>>> transform = ComputeSDF(
...     input_keys=["volume_mesh_centers"],
...     output_key="sdf_nodes",
...     mesh_coords_key="stl_coordinates",
...     mesh_faces_key="stl_faces",
...     closest_points_key="closest_points"
... )
>>> sample = TensorDict({
...     "volume_mesh_centers": torch.randn(10000, 3),
...     "stl_coordinates": torch.randn(5000, 3),
...     "stl_faces": torch.randint(0, 5000, (3 * 2000,))  # 2000 triangles, flattened
... })
>>> result = transform(sample)
>>> print(result["sdf_nodes"].shape)
torch.Size([10000, 1])

class physicsnemo.datapipes.transforms.geometric.ComputeNormals(
positions_key: str,
closest_points_key: str,
center_of_mass_key: str,
output_key: str,
*,
handle_zero_distance: bool = True,
)

Bases: Transform

Compute normal vectors from closest points.

Computes normalized direction vectors from query points to their closest points on a surface. Handles zero-distance edge cases by falling back to center of mass direction.

Parameters:
  • positions_key (str) – Key for position tensor, shape (N, 3).

  • closest_points_key (str) – Key for closest points tensor, shape (N, 3).

  • center_of_mass_key (str) – Key for center of mass, shape (1, 3) or (3,).

  • output_key (str) – Key to store computed normals.

  • handle_zero_distance (bool, default=True) – If True, use center_of_mass fallback for zero distances.
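The direction computation and zero-distance fallback described above can be sketched per point as follows. compute_normals is a hypothetical helper; the fallback direction (from the query point toward the center of mass) and the tolerance for "zero distance" are assumptions, since the exact fallback rule is not documented here.

```python
import math


def _unit(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm > 0.0 else [0.0] * len(v)


def compute_normals(positions, closest_points, center_of_mass, tol=1e-12):
    normals = []
    for p, c in zip(positions, closest_points):
        d = [ci - pi for pi, ci in zip(p, c)]  # query point -> closest surface point
        if math.sqrt(sum(x * x for x in d)) <= tol:
            # Assumed fallback: direction from the query point toward the center of mass.
            d = [mi - pi for pi, mi in zip(p, center_of_mass)]
        normals.append(_unit(d))
    return normals


positions = [[0.0, 0.0, 2.0], [1.0, 0.0, 0.0]]
closest = [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]  # the second point lies on the surface
print(compute_normals(positions, closest, [0.0, 0.0, 0.0]))
# [[0.0, 0.0, -1.0], [-1.0, 0.0, 0.0]]
```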

Examples

>>> transform = ComputeNormals(
...     positions_key="volume_mesh_centers",
...     closest_points_key="closest_points",
...     center_of_mass_key="center_of_mass",
...     output_key="volume_normals"
... )

class physicsnemo.datapipes.transforms.geometric.Translate(
input_keys: list[str],
center_key_or_value: str | Tensor,
*,
subtract: bool = False,
)

Bases: Transform

Apply a translation by adding or subtracting a center point.

By default, the translation is added. Set subtract=True to subtract it instead. Subtraction is particularly useful when composed with CenterOfMass: compute the center of mass, then subtract it to center the data at the origin.

Parameters:
  • input_keys (list[str]) – List of position tensor keys to translate.

  • center_key_or_value (str or torch.Tensor) – Either a key name (str) for a tensor in the sample, or a fixed tensor value to add/subtract.

  • subtract (bool, default=False) – If False (default), ADD the translation (data + center). If True, SUBTRACT the translation (data - center). Use subtract=True when centering data around a reference point like center of mass.
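Both modes reduce to broadcasting one (D,) offset over every point, with the center either looked up by key or given as a value. This plain-Python sketch uses a hypothetical translate helper with lists standing in for tensors; it mirrors the documented behavior but is not the library's code.

```python
def translate(data, input_keys, center_key_or_value, *, subtract=False):
    center = center_key_or_value
    if isinstance(center, str):
        center = data[center]  # look the center up in the sample
    sign = -1.0 if subtract else 1.0
    out = dict(data)
    for key in input_keys:
        # The same (D,) offset is broadcast over every point.
        out[key] = [[x + sign * c for x, c in zip(point, center)] for point in data[key]]
    return out


sample = {"positions": [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]], "center_of_mass": [2.0, 3.0, 4.0]}
print(translate(sample, ["positions"], "center_of_mass", subtract=True)["positions"])
# [[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]]
```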

Examples

Add mode (default) - shift points by a fixed offset:

>>> transform = Translate(
...     input_keys=["positions"],
...     center_key_or_value=torch.tensor([1.0, 2.0, 3.0])
... )
>>> # result["positions"] = original + [1, 2, 3]

Subtract mode - center points by subtracting center of mass:

>>> transform = Translate(
...     input_keys=["volume_mesh_centers", "surface_mesh_centers"],
...     center_key_or_value="center_of_mass",
...     subtract=True
... )
>>> # result["volume_mesh_centers"] = original - center_of_mass (likewise for surface_mesh_centers)

class physicsnemo.datapipes.transforms.geometric.Scale(
input_keys: list[str],
scale: Tensor,
*,
divide: bool = False,
)

Bases: Transform

Apply a scale factor by multiplying or dividing by a reference scale.

By default, the data is multiplied by the scale factor. Set divide=True to divide by it instead, which is particularly useful for making the representation scale invariant (e.g., dividing by a characteristic length scale).

Parameters:
  • input_keys (list[str]) – List of position tensor keys to scale.

  • scale (torch.Tensor) – Scale factor tensor, shape (1, D) or (D,).

  • divide (bool, default=False) – If False (default), MULTIPLY by the scale (data * scale). If True, DIVIDE by the scale (data / scale). Use divide=True when normalizing data by a reference scale.
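The two modes are inverses of each other, which is what makes divide mode useful for normalizing by a reference length. This plain-Python sketch uses a hypothetical scale helper, with nested lists standing in for an (N, D) tensor and a per-axis factor standing in for the (D,) scale tensor.

```python
def scale(data, input_keys, factor, *, divide=False):
    out = dict(data)
    for key in input_keys:
        if divide:
            out[key] = [[x / s for x, s in zip(point, factor)] for point in data[key]]
        else:
            out[key] = [[x * s for x, s in zip(point, factor)] for point in data[key]]
    return out


sample = {"positions": [[2.0, 4.0, 8.0]]}
length = [2.0, 2.0, 2.0]  # e.g. a characteristic length per axis
shrunk = scale(sample, ["positions"], length, divide=True)
print(shrunk["positions"])  # [[1.0, 2.0, 4.0]]
print(scale(shrunk, ["positions"], length)["positions"])  # [[2.0, 4.0, 8.0]]
```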

Examples

Multiply mode (default) - scale up positions by 2x:

>>> transform = Scale(
...     input_keys=["positions"],
...     scale=torch.tensor([2.0, 2.0, 2.0])
... )
>>> # result["positions"] = original * [2, 2, 2]

Divide mode - normalize by a reference scale:

>>> transform = Scale(
...     input_keys=["volume_mesh_centers", "geometry_coordinates"],
...     scale=torch.tensor([1.0, 1.0, 1.0]),
...     divide=True
... )
>>> # result["volume_mesh_centers"] = original / scale (likewise for geometry_coordinates)

Spatial

class physicsnemo.datapipes.transforms.spatial.BoundingBoxFilter(
input_keys: list[str],
bbox_min: Tensor,
bbox_max: Tensor,
*,
dependent_keys: list[str] | None = None,
)

Bases: Transform

Filter points outside a spatial bounding box.

Removes points that fall outside specified min/max bounds and applies the same filtering to dependent arrays to maintain correspondence. This is useful for focusing on specific regions of interest or removing outliers from simulation data.

Parameters:
  • input_keys (list[str]) – List of coordinate tensor keys to filter.

  • bbox_min (torch.Tensor) – Minimum corner of bounding box, shape (3,).

  • bbox_max (torch.Tensor) – Maximum corner of bounding box, shape (3,).

  • dependent_keys (list[str], optional) – Optional list of keys to filter using the same mask. These maintain correspondence with the filtered coordinates.
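The filter amounts to one boolean mask built from the coordinates and reused on every dependent key. This plain-Python sketch uses a hypothetical bbox_filter helper that takes a single coordinate key; treating the bounds as inclusive is an assumption.

```python
def bbox_filter(data, coords_key, bbox_min, bbox_max, dependent_keys=()):
    # Boolean mask: keep a point only if every coordinate lies within [min, max].
    keep = [
        all(lo <= x <= hi for x, lo, hi in zip(point, bbox_min, bbox_max))
        for point in data[coords_key]
    ]
    out = dict(data)
    for key in [coords_key, *dependent_keys]:
        out[key] = [value for value, kept in zip(data[key], keep) if kept]
    return out


sample = {
    "centers": [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [0.5, -0.5, 1.0]],
    "fields": [10.0, 20.0, 30.0],
}
result = bbox_filter(sample, "centers", [-1.0] * 3, [1.0] * 3, dependent_keys=["fields"])
print(result["centers"], result["fields"])  # [[0.0, 0.0, 0.0], [0.5, -0.5, 1.0]] [10.0, 30.0]
```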

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from physicsnemo.datapipes.transforms.spatial import BoundingBoxFilter
>>> transform = BoundingBoxFilter(
...     input_keys=["volume_mesh_centers"],
...     bbox_min=torch.tensor([-1.0, -1.0, -1.0]),
...     bbox_max=torch.tensor([1.0, 1.0, 1.0]),
...     dependent_keys=["volume_fields", "sdf_nodes"]
... )
>>> sample = TensorDict({
...     "volume_mesh_centers": torch.randn(10000, 3) * 2,  # Some points fall outside the bbox
...     "volume_fields": torch.randn(10000, 4),
...     "sdf_nodes": torch.randn(10000, 1)
... })
>>> result = transform(sample)
>>> # Only points within the bbox remain, in every listed key

class physicsnemo.datapipes.transforms.spatial.CreateGrid(
output_key: str,
resolution: tuple[int, int, int],
bbox_min: Tensor,
bbox_max: Tensor,
)

Bases: Transform

Create a regular 3D spatial grid.

Generates a uniform grid spanning a bounding box, used for latent space representations, interpolation grids, or structured spatial queries.
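The grid is the Cartesian product of three evenly spaced axes, flattened to (nx * ny * nz, 3); for the (64, 64, 64) example below that gives 262144 points. This plain-Python sketch uses a hypothetical create_grid helper; including both endpoints and letting x vary slowest are assumptions about the real layout.

```python
def create_grid(resolution, bbox_min, bbox_max):
    axes = []
    for n, lo, hi in zip(resolution, bbox_min, bbox_max):
        # n evenly spaced samples from lo to hi, both endpoints included (like linspace).
        axes.append([lo + i * (hi - lo) / (n - 1) for i in range(n)] if n > 1 else [lo])
    xs, ys, zs = axes
    # Flatten to (nx * ny * nz, 3); x varies slowest and z fastest (an assumed ordering).
    return [(x, y, z) for x in xs for y in ys for z in zs]


grid = create_grid((5, 3, 2), (-1.0, -1.0, -1.0), (1.0, 1.0, 1.0))
print(len(grid), grid[0], grid[-1])  # 30 (-1.0, -1.0, -1.0) (1.0, 1.0, 1.0)
```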

Parameters:
  • output_key (str) – Key to store the generated grid.

  • resolution (tuple[int, int, int]) – Grid resolution as (nx, ny, nz).

  • bbox_min (torch.Tensor) – Minimum corner of bounding box, shape (3,).

  • bbox_max (torch.Tensor) – Maximum corner of bounding box, shape (3,).

Examples

>>> transform = CreateGrid(
...     output_key="grid",
...     resolution=(64, 64, 64),
...     bbox_min=torch.tensor([-1.0, -1.0, -1.0]),
...     bbox_max=torch.tensor([1.0, 1.0, 1.0])
... )
>>> sample = TensorDict({})
>>> result = transform(sample)
>>> print(result["grid"].shape)
torch.Size([262144, 3])

class physicsnemo.datapipes.transforms.spatial.KNearestNeighbors(
points_key: str,
queries_key: str,
k: int,
*,
output_prefix: str = 'neighbors',
extract_keys: list[str] | None = None,
drop_first_neighbor: bool = False,
)

Bases: Transform

Compute k-nearest neighbors in a point cloud.

Finds the k nearest neighbors for each query point and extracts corresponding coordinates and other attributes. Useful for local feature aggregation in mesh networks and spatial interpolation.

Parameters:
  • points_key (str) – Key for reference points to search, shape (N, 3).

  • queries_key (str) – Key for query points, shape (M, 3).

  • k (int) – Number of nearest neighbors to find.

  • output_prefix (str, default="neighbors") – Prefix for output keys.

  • extract_keys (list[str], optional) – Optional list of keys to extract for neighbors (e.g., ["normals", "areas"]). If None, only extracts coordinates.
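Conceptually, the transform finds k neighbor indices per query and then gathers coordinates and any extra attributes with those same indices. This brute-force sketch (O(N log N) per query) is for illustration only; k_nearest_neighbors is a hypothetical helper and says nothing about the search structure the library uses.

```python
import math


def k_nearest_neighbors(points, queries, k):
    """Brute-force k-NN: indices of the k closest reference points per query."""
    neighbors = []
    for q in queries:
        order = sorted(range(len(points)), key=lambda i: math.dist(points[i], q))
        neighbors.append(order[:k])
    return neighbors


points = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (2.0, 0.0, 0.0), (10.0, 0.0, 0.0)]
areas = [0.1, 0.2, 0.3, 0.4]
idx = k_nearest_neighbors(points, [(0.9, 0.0, 0.0)], k=2)
print(idx)  # [[1, 0]]
# Gather neighbor coordinates and an extra attribute with the same indices.
print([[points[i] for i in row] for row in idx], [[areas[i] for i in row] for row in idx])
```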

Examples

>>> transform = KNearestNeighbors(
...     points_key="surface_mesh_centers",
...     queries_key="surface_mesh_centers_subsampled",
...     k=11,
...     output_prefix="surface_neighbors",
...     extract_keys=["surface_normals", "surface_areas"]
... )
>>> sample = TensorDict({
...     "surface_mesh_centers": torch.randn(10000, 3),
...     "surface_mesh_centers_subsampled": torch.randn(1000, 3),
...     "surface_normals": torch.randn(10000, 3),
...     "surface_areas": torch.rand(10000)
... })
>>> result = transform(sample)
>>> # Creates: surface_neighbors_coords, surface_neighbors_normals, etc.

class physicsnemo.datapipes.transforms.spatial.CenterOfMass(
coords_key: str,
output_key: str,
*,
areas_key: str | None = None,
)

Bases: Transform

Compute weighted center of mass for a point cloud.

Calculates the center of mass using area or mass weights, typically applied to mesh data where each point represents a cell with a specific area.

Parameters:
  • coords_key (str) – Key for coordinates, shape (N, 3).

  • output_key (str) – Key to store the computed center of mass, shape (1, 3).

  • areas_key (str, optional) – Key for area weights, shape (N,).

Examples

>>> transform = CenterOfMass(
...     coords_key="stl_centers",
...     areas_key="stl_areas",
...     output_key="center_of_mass"
... )
>>> sample = TensorDict({
...     "stl_centers": torch.randn(5000, 3),
...     "stl_areas": torch.rand(5000)
... })
>>> result = transform(sample)
>>> print(result["center_of_mass"].shape)
torch.Size([3])
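The weighted center of mass is sum(a_i * x_i) / sum(a_i), computed per axis. This plain-Python sketch uses a hypothetical center_of_mass helper; falling back to equal weights when no areas are given is an assumption, since the behavior for areas_key=None is not described here.

```python
def center_of_mass(coords, areas=None):
    """Area-weighted mean of point coordinates: sum(a_i * x_i) / sum(a_i)."""
    if areas is None:
        areas = [1.0] * len(coords)  # assumption: equal weights when no areas are given
    total = sum(areas)
    dims = len(coords[0])
    return [sum(a * p[d] for p, a in zip(coords, areas)) / total for d in range(dims)]


print(center_of_mass([[0.0, 0.0, 0.0], [4.0, 0.0, 0.0]], [1.0, 3.0]))  # [3.0, 0.0, 0.0]
```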

Field Processing

class physicsnemo.datapipes.transforms.field_slice.FieldSlice(
slicing: dict[str, dict[int | str, list[int] | dict[str, int]]],
)

Bases: Transform

Select specific indices or slices from tensor dimensions.

This transform allows selecting subsets of data along any dimension of specified fields. It supports two modes:

  1. Index selection: Provide a list of indices to select

  2. Slice selection: Provide start/stop/step as a dict

Parameters:

slicing (dict[str, dict[int | str, SliceSpec]]) –

Dictionary mapping field names to dimension slicing specs. Format:

{
    "field_name": {
        dim: indices_or_slice,
        ...
    },
    ...
}

Where:

  • dim is the dimension index (an int, or a str such as “-1” when configured through Hydra)

  • indices_or_slice is either:
    • A list of indices: [0, 2, 5]

    • A slice dict: {"start": 0, "stop": 5, "step": 1}
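The two spec forms map directly onto Python indexing: a list becomes index selection and a start/stop/step dict becomes a slice, while string dims from YAML are coerced to ints. The helpers to_index and parse_slicing are hypothetical; they only interpret the spec format described above and do not apply it to tensors.

```python
def to_index(spec):
    """Turn one slicing spec into a list of indices or a Python slice."""
    if isinstance(spec, dict):
        # Missing start/stop/step default to None, as in ordinary slicing.
        return slice(spec.get("start"), spec.get("stop"), spec.get("step"))
    return list(spec)


def parse_slicing(slicing):
    """Normalize dims to ints; YAML/Hydra mapping keys such as "-1" arrive as strings."""
    return {
        field: {int(dim): to_index(spec) for dim, spec in dims.items()}
        for field, dims in slicing.items()
    }


spec = parse_slicing({"features": {"-1": [0, 2, 5]}, "velocity": {-1: {"stop": 2}}})
print(spec)  # {'features': {-1: [0, 2, 5]}, 'velocity': {-1: slice(None, 2, None)}}

row = list(range(10))  # one row of an (N, 10) field, indexed along its last dim
print([row[i] for i in spec["features"][-1]], row[spec["velocity"][-1]])  # [0, 2, 5] [0, 1]
```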

Examples

Index selection - select features 0, 2, 5 from last dimension:

>>> transform = FieldSlice({
...     "features": {-1: [0, 2, 5]},
... })
>>> # Input shape: (N, 10) -> Output shape: (N, 3)

Slice selection - select first 5 features:

>>> transform = FieldSlice({
...     "features": {-1: {"start": 0, "stop": 5}},
... })
>>> # Input shape: (N, 10) -> Output shape: (N, 5)

Multiple dimensions:

>>> transform = FieldSlice({
...     "grid": {
...         0: [0, 1, 2],      # First 3 indices of dim 0
...         -1: {"stop": 4},   # First 4 of last dim (slice)
...     },
... })

Hydra configuration example:

_target_: physicsnemo.datapipes.transforms.FieldSlice
slicing:
  features:
    "-1": [0, 2, 5]
  velocity:
    "-1":
      stop: 2

extra_repr() str

Return extra information for repr.

Returns:

String with transform parameters.

Return type:

str

class physicsnemo.datapipes.transforms.field_processing.BroadcastGlobalFeatures(
input_keys: list[str],
n_points_key: str,
output_key: str,
)

Bases: Transform

Broadcast global scalar/vector features to all spatial points.

Replicates global parameters (e.g., density, velocity) to match the number of spatial points, enabling concatenation with local features.

Parameters:
  • input_keys (list[str]) – List of global feature keys to broadcast.

  • n_points_key (str) – Key of a tensor whose first dimension gives the number of points to broadcast to.

  • output_key (str) – Key to store the broadcasted features.
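Broadcasting stacks the global values into one row and repeats it once per point, giving shape (N, number of globals). This plain-Python sketch uses a hypothetical broadcast_global_features helper and handles only scalar globals; vector-valued globals are out of its scope.

```python
def broadcast_global_features(data, input_keys, n_points_key):
    n_points = len(data[n_points_key])  # first dimension of the reference tensor
    row = [float(data[key]) for key in input_keys]  # one column per scalar global
    return [list(row) for _ in range(n_points)]  # shape (n_points, len(input_keys))


data = {
    "air_density": 1.225,
    "stream_velocity": 30.0,
    "embeddings": [[0.0] * 7 for _ in range(4)],
}
fx = broadcast_global_features(data, ["air_density", "stream_velocity"], "embeddings")
print(len(fx), len(fx[0]), fx[0])  # 4 2 [1.225, 30.0]
```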

Examples

>>> transform = BroadcastGlobalFeatures(
...     input_keys=["air_density", "stream_velocity"],
...     n_points_key="embeddings",
...     output_key="fx"
... )
>>> data = TensorDict({
...     "air_density": torch.tensor(1.225),
...     "stream_velocity": torch.tensor(30.0),
...     "embeddings": torch.randn(10000, 7)
... })
>>> result = transform(data)
>>> print(result["fx"].shape)
torch.Size([10000, 2])

Feature Building

class physicsnemo.datapipes.transforms.concat_fields.ConcatFields(
input_keys: list[str],
output_key: str,
*,
dim: int = -1,
skip_missing: bool = False,
)

Bases: Transform

Concatenate multiple tensor fields along a specified dimension.

Combines the specified fields into a single output tensor by concatenating along dim (by default the last, feature dimension). Useful for building embeddings from multiple components such as positions, normals, and signed distance fields.

All input tensors must have the same shape except for the concatenation dimension.

Parameters:
  • input_keys (list[str]) – List of tensor keys to concatenate, in order.

  • output_key (str) – Key to store the concatenated result.

  • dim (int, default=-1) – Dimension along which to concatenate.

  • skip_missing (bool, default=False) – If True, skip keys that are not present in the data instead of raising an error. Useful for optional fields.
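Row by row, concatenation along the last dimension joins each field's features in input_keys order, and skip_missing simply drops absent keys from that order. This plain-Python sketch uses a hypothetical concat_fields helper and a hypothetical "curvature" key to show the skip; raising KeyError for a missing key is this sketch's choice of error.

```python
from itertools import chain


def concat_fields(data, input_keys, output_key, *, skip_missing=False):
    fields = []
    for key in input_keys:
        if key in data:
            fields.append(data[key])
        elif not skip_missing:
            raise KeyError(f"missing input key: {key}")
    if not fields:
        raise ValueError("no input fields to concatenate")
    n = len(fields[0])
    if any(len(f) != n for f in fields):
        raise ValueError("all fields must share every dimension except the last")
    out = dict(data)
    # Row i of the output is row i of each field, joined in input_keys order.
    out[output_key] = [list(chain.from_iterable(f[i] for f in fields)) for i in range(n)]
    return out


data = {"positions": [[1.0, 2.0, 3.0]], "sdf": [[0.5]], "normals": [[0.0, 0.0, 1.0]]}
emb = concat_fields(data, ["positions", "sdf", "curvature", "normals"], "embeddings", skip_missing=True)
print(emb["embeddings"])  # [[1.0, 2.0, 3.0, 0.5, 0.0, 0.0, 1.0]]
```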

Examples

>>> transform = ConcatFields(
...     input_keys=["positions", "sdf", "normals"],
...     output_key="embeddings"
... )
>>> data = TensorDict({
...     "positions": torch.randn(10000, 3),
...     "sdf": torch.randn(10000, 1),
...     "normals": torch.randn(10000, 3)
... })
>>> result = transform(data)
>>> print(result["embeddings"].shape)
torch.Size([10000, 7])

extra_repr() str

Return extra information for repr.

Returns:

String with transform parameters.

Return type:

str

class physicsnemo.datapipes.transforms.concat_fields.NormalizeVectors(
input_keys: list[str],
*,
dim: int = -1,
eps: float = 1e-06,
)

Bases: Transform

Normalize vectors to unit length.

Divides vectors by their L2 norm along the specified dimension. Handles zero-length vectors by adding a small epsilon to prevent division by zero.

Parameters:
  • input_keys (list[str]) – List of tensor keys to normalize.

  • dim (int, default=-1) – Dimension along which to compute norm.

  • eps (float, default=1e-6) – Small value to prevent division by zero.
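Unit normalization divides each vector by its L2 norm, with eps keeping zero-length vectors finite. In this plain-Python sketch (normalize_vectors is a hypothetical helper), eps is added to the norm; that is an assumption, and the library may instead clamp the norm from below.

```python
import math


def normalize_vectors(vectors, eps=1e-6):
    out = []
    for v in vectors:
        norm = math.sqrt(sum(x * x for x in v))  # L2 norm along the last dimension
        # Assumption: eps is added to the norm, so a zero vector stays zero instead of NaN.
        out.append([x / (norm + eps) for x in v])
    return out


unit = normalize_vectors([[3.0, 4.0, 0.0], [0.0, 0.0, 0.0]])
print([[round(x, 4) for x in v] for v in unit])  # [[0.6, 0.8, 0.0], [0.0, 0.0, 0.0]]
```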

Examples

>>> transform = NormalizeVectors(input_keys=["normals"])
>>> data = TensorDict({"normals": torch.randn(10000, 3)})
>>> result = transform(data)
>>> # Normals are now unit length
>>> norms = torch.norm(result["normals"], dim=-1)
>>> print(torch.allclose(norms, torch.ones_like(norms), atol=1e-5))
True

extra_repr() str

Return extra information for repr.

Returns:

String with transform parameters.

Return type:

str

Utility

class physicsnemo.datapipes.transforms.utility.Rename(mapping: dict[str, str], *, strict: bool = True)

Bases: Transform

Rename keys in a TensorDict.

Replaces existing key names with new names according to a mapping. The tensor data is preserved, only the keys are changed.

Nested TensorDicts are supported as well. Keys are flattened with a ‘.’ separator, so for a TensorDict a, the entry a[“b”][“d”] is addressed as “b.d” when renaming. To rename d to c, provide {“b.d”: “b.c”} in the mapping.

Parameters:
  • mapping (dict[str, str]) – Dictionary mapping old key names to new key names. Keys are the original names, values are the new names.

  • strict (bool, default=True) – If True, raise an error if a key in the mapping is not found in the data. If False, silently skip missing keys.
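The flattened-key convention for nested data can be sketched with plain nested dicts: flatten with a "." separator, rename, then rebuild the nesting. The helpers flatten, unflatten, and rename are hypothetical and only illustrate the documented naming scheme; raising KeyError in strict mode is this sketch's choice.

```python
def flatten(d, prefix=""):
    flat = {}
    for key, value in d.items():
        name = f"{prefix}{key}"
        if isinstance(value, dict):
            flat.update(flatten(value, name + "."))  # nested keys become "parent.child"
        else:
            flat[name] = value
    return flat


def unflatten(flat):
    nested = {}
    for name, value in flat.items():
        *parents, leaf = name.split(".")
        node = nested
        for parent in parents:
            node = node.setdefault(parent, {})
        node[leaf] = value
    return nested


def rename(data, mapping, *, strict=True):
    flat = flatten(data)
    missing = [key for key in mapping if key not in flat]
    if strict and missing:
        raise KeyError(f"keys not found: {missing}")
    return unflatten({mapping.get(k, k): v for k, v in flat.items()})


a = {"b": {"d": 1}, "x": 2}
print(rename(a, {"b.d": "b.c", "x": "positions"}))  # {'b': {'c': 1}, 'positions': 2}
```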

Examples

>>> transform = Rename(mapping={"old_name": "new_name", "x": "positions"})
>>> data = TensorDict({
...     "old_name": torch.randn(100, 3),
...     "x": torch.randn(100, 3),
...     "other": torch.randn(100, 1)
... })
>>> result = transform(data)
>>> print(sorted(result.keys()))
['new_name', 'other', 'positions']

extra_repr() str

Return extra information for repr.

Returns:

String with transform parameters.

Return type:

str

class physicsnemo.datapipes.transforms.utility.Purge(
*,
keep_only: list[str] | None = None,
drop_only: list[str] | None = None,
strict: bool = True,
)

Bases: Transform

Remove keys and their associated tensors from a TensorDict.

Supports two mutually exclusive modes:

  • drop_only: Specify keys to remove (keep everything else)

  • keep_only: Specify keys to keep (remove everything else)

Only one mode can be active at a time. If neither keep_only nor drop_only is given, no keys are removed (identity transform).

Parameters:
  • keep_only (list[str], optional) – List of keys to keep. All other keys will be removed. Cannot be used together with drop_only.

  • drop_only (list[str], optional) – List of keys to remove. All other keys will be kept. Cannot be used together with keep_only. Default is None (drop nothing).

  • strict (bool, default=True) – If True, raise an error if a specified key is not found in the data. If False, silently skip missing keys.
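The two mutually exclusive modes, and the identity behavior when neither is given, can be sketched over a plain dict. purge is a hypothetical helper; raising ValueError when both modes are set matches the documented Raises section, while KeyError for missing keys in strict mode is this sketch's choice.

```python
def purge(data, *, keep_only=None, drop_only=None, strict=True):
    if keep_only is not None and drop_only is not None:
        raise ValueError("keep_only and drop_only are mutually exclusive")
    named = keep_only if keep_only is not None else (drop_only or [])
    missing = [key for key in named if key not in data]
    if strict and missing:
        raise KeyError(f"keys not found: {missing}")  # error type chosen for this sketch
    if keep_only is not None:
        return {k: v for k, v in data.items() if k in keep_only}
    return {k: v for k, v in data.items() if k not in named}


data = {"positions": 1, "velocities": 2, "temp": 3}
print(list(purge(data, drop_only=["temp"])))       # ['positions', 'velocities']
print(list(purge(data, keep_only=["positions"])))  # ['positions']
print(purge(data) == data)  # True: neither mode given, nothing removed
```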

Examples

Drop mode - remove specific keys:

>>> transform = Purge(drop_only=["temp", "debug_info"])
>>> data = TensorDict({
...     "positions": torch.randn(100, 3),
...     "temp": torch.randn(100, 1),
...     "debug_info": torch.randn(100, 10)
... })
>>> result = transform(data)
>>> print(list(result.keys()))
['positions']

Keep mode - keep only specific keys:

>>> transform = Purge(keep_only=["positions", "velocities"])
>>> data = TensorDict({
...     "positions": torch.randn(100, 3),
...     "velocities": torch.randn(100, 3),
...     "temp": torch.randn(100, 1)
... })
>>> result = transform(data)
>>> print(list(result.keys()))
['positions', 'velocities']

Raises:

ValueError – If both keep_only and drop_only are specified.

extra_repr() str

Return extra information for repr.

Returns:

String with transform parameters.

Return type:

str

class physicsnemo.datapipes.transforms.utility.ConstantField(
reference_key: str,
output_key: str,
*,
fill_value: float = 0.0,
output_dim: int = 1,
)

Bases: Transform

Create a tensor filled with a constant value.

Creates a tensor where the first dimension matches a reference tensor and the last dimension is configurable. The tensor is filled with the specified constant value. Useful for creating placeholder tensors like zero SDF values for surface points, or indicator fields.

Parameters:
  • reference_key (str) – Key for the tensor to use as shape reference. The first dimension of this tensor determines the number of rows in the output.

  • output_key (str) – Key to store the constant tensor.

  • fill_value (float, default=0.0) – The constant value to fill the tensor with.

  • output_dim (int, default=1) – Feature dimension for output tensor. Creates tensor with shape (N, output_dim) where N is the first dimension of the reference tensor.
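The output shape rule (first dimension from the reference, last dimension from output_dim) can be sketched with nested lists. constant_field is a hypothetical helper, not the library's implementation.

```python
def constant_field(reference, fill_value=0.0, output_dim=1):
    n_rows = len(reference)  # first dimension of the reference tensor
    return [[fill_value] * output_dim for _ in range(n_rows)]  # shape (N, output_dim)


positions = [[0.0, 0.0, 0.0]] * 3
print(constant_field(positions))  # [[0.0], [0.0], [0.0]]
print(constant_field(positions, fill_value=293.15, output_dim=2)[0])  # [293.15, 293.15]
```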

Examples

Create zeros (default):

>>> transform = ConstantField(
...     reference_key="positions",
...     output_key="sdf",
...     output_dim=1
... )
>>> data = TensorDict({"positions": torch.randn(10000, 3)})
>>> result = transform(data)
>>> print(result["sdf"].shape)
torch.Size([10000, 1])
>>> print(result["sdf"][0, 0].item())
0.0

Create ones:

>>> transform = ConstantField(
...     reference_key="positions",
...     output_key="mask",
...     fill_value=1.0,
...     output_dim=1
... )

Create custom constant:

>>> transform = ConstantField(
...     reference_key="positions",
...     output_key="temperature",
...     fill_value=293.15,  # Room temperature in Kelvin
...     output_dim=1
... )

extra_repr() str

Return extra information for repr.

Returns:

String with transform parameters.

Return type:

str