Built-in Transforms

PhysicsNeMo transforms are the core data-manipulation tool in datapipes. All transforms inherit from the Transform base class.

class physicsnemo.datapipes.transforms.base.Transform

Bases: ABC

Abstract base class for all transforms.

Transforms operate on a TensorDict and return a modified TensorDict. They are designed to run on GPU tensors for maximum performance. Metadata is not passed to transforms; the Dataset/DataLoader handles it separately.

Subclasses must implement:

  • __call__(data: TensorDict) -> TensorDict

Optionally override:

  • extra_repr() -> str: For custom repr output

  • state_dict() -> dict: For serialization

  • load_state_dict(state_dict: dict): For deserialization

Examples

>>> from tensordict import TensorDict
>>> from physicsnemo.datapipes.transforms.base import Transform
>>> class MyTransform(Transform):
...     def __init__(self, scale: float):
...         super().__init__()
...         self.scale = scale
...
...     def __call__(self, data: TensorDict) -> TensorDict:
...         # Scale every tensor in the TensorDict
...         return data.apply(lambda x: x * self.scale)

property device: device | None

The device this transform operates on.

Returns:

The device, or None if not set.

Return type:

torch.device or None

extra_repr() -> str

Return extra information for repr.

Override this to add transform-specific info to the repr.

Returns:

Extra representation string.

Return type:

str

load_state_dict(state_dict: dict[str, Any]) -> None

Load state from a state dictionary.

Override this to restore transform state.

Parameters:

state_dict (dict[str, Any]) – State dictionary to load from.

set_epoch(epoch: int) -> None

Reseed the generator for a new epoch.

Reseeds self._generator with initial_seed() + epoch so each epoch produces a different but deterministic random sequence. No-op for deterministic transforms or when no generator has been assigned.

Parameters:

epoch (int) – Current epoch number.
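The per-epoch reseeding rule can be illustrated with a plain torch.Generator. This is a minimal sketch, not PhysicsNeMo's implementation: EpochSeededSampler is a hypothetical class. It records the base seed once, because after manual_seed() the value returned by initial_seed() is the most recent seed; whether the library guards against that drift the same way is not stated here.

```python
import torch


class EpochSeededSampler:
    """Hypothetical stochastic transform showing per-epoch reseeding."""

    def __init__(self, n_points: int, seed: int = 0) -> None:
        self.n_points = n_points
        self._generator = torch.Generator().manual_seed(seed)
        # Record the base seed once: after manual_seed(), initial_seed()
        # reports the most recent seed, so re-reading it would drift.
        self._base_seed = self._generator.initial_seed()

    def set_epoch(self, epoch: int) -> None:
        # Different but deterministic random stream for every epoch.
        self._generator.manual_seed(self._base_seed + epoch)

    def __call__(self, points: torch.Tensor) -> torch.Tensor:
        # Draw a random subset of rows with the seeded generator.
        perm = torch.randperm(points.shape[0], generator=self._generator)
        return points[perm[: self.n_points]]


points = torch.arange(100.0).unsqueeze(-1)
sampler = EpochSeededSampler(n_points=5, seed=42)

sampler.set_epoch(0)
epoch0 = sampler(points)
sampler.set_epoch(1)
epoch1 = sampler(points)
sampler.set_epoch(0)
epoch0_again = sampler(points)  # identical to epoch0
```

Returning to an earlier epoch reproduces that epoch's draw exactly, which is what makes resumed training runs repeatable.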

set_generator(generator: Generator) -> None

Assign a torch.Generator for reproducible random sampling.

Only takes effect on stochastic transforms (those that declare self._generator). Deterministic transforms silently ignore the call.

Parameters:

generator (torch.Generator) – Generator to use for all subsequent random draws.

state_dict() -> dict[str, Any]

Return a dictionary containing the transform’s state.

Override this for transforms with learnable or configurable state.

Returns:

State dictionary.

Return type:

dict[str, Any]

property stochastic: bool

Whether this transform uses random sampling.

Returns True if the instance has a _generator attribute (set by stochastic subclasses such as SubsamplePoints). Deterministic transforms return False.
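Both set_generator and stochastic key off whether a transform declares a _generator attribute. The sketch below mimics that convention with a stripped-down, hypothetical base class (SketchTransform) and two hypothetical subclasses; it is not the PhysicsNeMo Transform class, only an illustration of the documented behavior. It is CPU-only for simplicity.

```python
import torch


class SketchTransform:
    """Hypothetical stand-in base class mimicking the _generator convention."""

    @property
    def stochastic(self) -> bool:
        # Stochastic subclasses declare self._generator; deterministic ones do not.
        return hasattr(self, "_generator")

    def set_generator(self, generator: torch.Generator) -> None:
        # Deterministic transforms silently ignore the call.
        if self.stochastic:
            self._generator = generator


class Scale(SketchTransform):
    def __init__(self, factor: float) -> None:
        self.factor = factor

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.factor


class AddNoise(SketchTransform):
    def __init__(self, std: float) -> None:
        self.std = std
        self._generator = torch.Generator()  # declaring it marks the transform stochastic

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        noise = torch.randn(x.shape, generator=self._generator, dtype=x.dtype)
        return x + self.std * noise


scale, noise = Scale(2.0), AddNoise(0.1)
scale.set_generator(torch.Generator().manual_seed(0))  # no-op: Scale is deterministic
noise.set_generator(torch.Generator().manual_seed(0))
first = noise(torch.zeros(4))
noise.set_generator(torch.Generator().manual_seed(0))
second = noise(torch.zeros(4))  # same generator seed, same noise
```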

to(device: device | str) -> Transform

Move any internal tensors, generators, and distributions to device.

torch.Generator objects cannot be moved in-place, so a new generator is created on device and seeded with initial_seed() from the original.

torch.distributions.Distribution objects are reconstructed with their parameter tensors moved to device, using arg_constraints to discover parameter names generically.

Override this method if your transform requires custom device handling.

Parameters:

device (torch.device or str) – Target device.

Returns:

Self for chaining.

Return type:

Transform
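The two device-moving strategies described above can be sketched with standard PyTorch calls. The helper names move_generator and move_distribution are hypothetical. The distribution helper assumes every arg_constraints key can be passed back to the constructor together, which holds for distributions such as Normal and Uniform but not for Categorical, whose arg_constraints list both probs and logits (which cannot be passed at once). The custom handling the library actually uses is not shown here.

```python
from __future__ import annotations

import torch
from torch.distributions import Distribution, Normal, Uniform


def move_generator(gen: torch.Generator, device: torch.device | str) -> torch.Generator:
    # A generator cannot change device in place, so build a new one on the
    # target device and seed it from the original. Its position in the random
    # sequence is reset to the start of that seed's stream, not copied.
    new_gen = torch.Generator(device=device)
    new_gen.manual_seed(gen.initial_seed())
    return new_gen


def move_distribution(dist: Distribution, device: torch.device | str) -> Distribution:
    # arg_constraints names the parameters, e.g. loc/scale for Normal or
    # low/high for Uniform; rebuild the same distribution type from moved copies.
    params = {name: getattr(dist, name).to(device) for name in dist.arg_constraints}
    return type(dist)(**params)


gen = torch.Generator().manual_seed(123)
moved_gen = move_generator(gen, "cpu")  # use "cuda" on a GPU machine

normal = move_distribution(Normal(torch.tensor(0.0), torch.tensor(2.0)), "cpu")
uniform = move_distribution(Uniform(torch.tensor(-1.0), torch.tensor(1.0)), "cpu")
```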

To implement a new transform, override the __call__ method and, where needed, __init__ for any initialization or configuration.

The input to a transform is mutable by default, so a transform may modify it in place; as a result, the order of transformations matters.

In general, transforms are transactional: they take input in, manipulate it, and return output, and they almost never update internal state. Transforms should be device-agnostic and follow a compute-follows-data principle, operating on data on the device where it resides whenever possible.
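A minimal sketch of a device-agnostic transform follows. It assumes a plain dict of tensors as a stand-in for a TensorDict to keep the example dependency-light, and AddRadialDistance and its key names are hypothetical. Any new tensor is created with the device and dtype of the input, so the transform runs wherever the data already lives.

```python
import torch


class AddRadialDistance:
    """Hypothetical transform: stores each point's distance from a fixed center."""

    def __init__(self, center: tuple[float, float, float]) -> None:
        # Configuration stays as plain Python data, so there is nothing to move.
        self.center = center

    def __call__(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        pos = data["positions"]
        # Compute follows data: create the constant on the input's device and dtype.
        center = torch.tensor(self.center, device=pos.device, dtype=pos.dtype)
        # The input dict is updated in place and returned.
        data["radius"] = torch.linalg.vector_norm(pos - center, dim=-1, keepdim=True)
        return data


sample = {"positions": torch.tensor([[3.0, 4.0, 0.0], [0.0, 0.0, 0.0]])}
out = AddRadialDistance(center=(0.0, 0.0, 0.0))(sample)
```

Because the input is modified in place, any transform placed later in a pipeline sees the new "radius" entry, while one placed earlier does not.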

By default, transforms accept and return TensorDict objects, but this is not strictly enforced. If you implement a custom transform that returns a different data type, every downstream transform must expect that type. One example, found in the minimal datapipe examples, converts TensorDict objects into a PyTorch Geometric graph object. This kind of conversion is perfectly valid, but it requires a custom collation function and rules out TensorDict-based transforms downstream.
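To see why a type-changing transform needs its own collation, consider the sketch below. GraphSample is a hypothetical stand-in for a PyTorch Geometric Data object (so no torch_geometric install is needed), ToGraph is a hypothetical final transform that connects points within a radius, and collate_graphs batches graphs as one disjoint graph by offsetting edge indices and recording a per-node batch vector, mirroring the approach PyTorch Geometric uses.

```python
from __future__ import annotations

from dataclasses import dataclass

import torch


@dataclass
class GraphSample:
    """Hypothetical stand-in for a PyTorch Geometric Data object."""

    pos: torch.Tensor  # (N, 3) node positions
    edge_index: torch.Tensor  # (2, E) source and target node indices


class ToGraph:
    """Hypothetical final transform: connect points closer than `radius`."""

    def __init__(self, radius: float) -> None:
        self.radius = radius

    def __call__(self, data: dict[str, torch.Tensor]) -> GraphSample:
        pos = data["positions"]
        n = pos.shape[0]
        close = torch.cdist(pos, pos) < self.radius
        # Drop self-loops, then list (source, target) pairs as a (2, E) tensor.
        close &= ~torch.eye(n, dtype=torch.bool, device=pos.device)
        return GraphSample(pos=pos, edge_index=close.nonzero().T)


def collate_graphs(samples: list[GraphSample]) -> dict[str, torch.Tensor]:
    # Batch as one disjoint graph: shift each graph's edge indices by the number
    # of nodes before it, and record which graph every node belongs to.
    pos, edges, batch, offset = [], [], [], 0
    for i, s in enumerate(samples):
        n = s.pos.shape[0]
        pos.append(s.pos)
        edges.append(s.edge_index + offset)
        batch.append(torch.full((n,), i, dtype=torch.long, device=s.pos.device))
        offset += n
    return {"pos": torch.cat(pos), "edge_index": torch.cat(edges, dim=1), "batch": torch.cat(batch)}


to_graph = ToGraph(radius=1.5)
g1 = to_graph({"positions": torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])})
g2 = to_graph({"positions": torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [5.0, 0.0, 0.0]])})
batched = collate_graphs([g1, g2])
```

With a plain PyTorch DataLoader, a function like collate_graphs would be passed as collate_fn; how PhysicsNeMo's own loader accepts a custom collation function is not covered by this sketch.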

Compose is a special transform: it takes a list of transforms and applies them in order as a single transform. Use Compose the way you would use the torch.nn.Sequential container to stack PyTorch Modules.

class physicsnemo.datapipes.transforms.compose.Compose(transforms: Sequence[Transform])

Bases: Transform

Compose multiple transforms into a sequential pipeline.

Applies transforms in order, passing the output of each as input to the next.

Parameters:

transforms (Sequence[Transform]) – Sequence of transforms to apply in order.

Examples

>>> import torch
>>> from tensordict import TensorDict
>>> from physicsnemo.datapipes.transforms import Normalize, SubsamplePoints
>>> from physicsnemo.datapipes.transforms.compose import Compose
>>> sample = TensorDict({
...     "pressure": torch.tensor([101325.0, 102325.0, 100325.0]),
... })
>>> normalize = Normalize(input_keys=["pressure"], method="mean_std", means={"pressure": 101325.0}, stds={"pressure": 1000.0})
>>> subsample = SubsamplePoints(input_keys=["pressure"], n_points=1000)
>>> pipeline = Compose([normalize, subsample])
>>> transformed = pipeline(sample)
>>> transformed["pressure"]
tensor([ 0.,  1., -1.])

append(transform: Transform) -> None

Append a transform to the pipeline.

Parameters:

transform (Transform) – Transform to append.

Raises:

TypeError – If transform is not a Transform instance.

extra_repr() -> str

Return extra information for repr.

Returns:

Formatted string showing all transforms.

Return type:

str

set_epoch(epoch: int) -> None

Propagate epoch to every child transform.

Parameters:

epoch (int) – Current epoch number.

set_generator(generator: Generator) -> None

Fork generator and distribute one child per transform.

Parameters:

generator (torch.Generator) – Parent generator to fork from.
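How Compose forks the parent generator is not documented here. One common approach, sketched below purely as an assumption, is to draw one seed per child from the parent and seed a fresh generator with each; fork_generator is a hypothetical helper, not a PhysicsNeMo function.

```python
import torch


def fork_generator(parent: torch.Generator, n_children: int) -> list[torch.Generator]:
    # Draw one seed per child from the parent: children are reproducible from
    # the parent's seed but give each transform its own independent stream.
    seeds = torch.randint(0, 2**32, (n_children,), generator=parent, device=parent.device)
    return [torch.Generator(device=parent.device).manual_seed(int(s)) for s in seeds]


children_a = fork_generator(torch.Generator().manual_seed(7), n_children=3)
children_b = fork_generator(torch.Generator().manual_seed(7), n_children=3)
draws_a = [torch.rand(2, generator=g) for g in children_a]
draws_b = [torch.rand(2, generator=g) for g in children_b]
```

Giving each child its own generator, rather than sharing one, means a transform's random stream does not shift when a sibling transform changes how many random numbers it draws.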

state_dict() -> dict[str, Any]

Return state of all transforms.

Returns:

Dictionary containing transform states and types.

Return type:

dict[str, Any]

property stochastic: bool

True if any child transform is stochastic.

to(device: device | str) -> Compose

Move all transforms to the specified device.

Parameters:

device (torch.device or str) – Target device.

Returns:

Self for chaining.

Return type:

Compose

Below is a collection of commonly used transforms in PhysicsNeMo’s datapipes. Because transforms take and return TensorDict objects, a common configuration pattern is to specify lists of input and output keys to operate on, though not every transform follows it.

Most transforms have no internal state; those that do move their internal tensors between devices (for example, to and from the GPU) when to() is called, as expected.
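The key-based configuration pattern and device-aware internal state can be sketched together. ScaleFields is a hypothetical transform, a plain dict stands in for a TensorDict, and the class does not inherit from the PhysicsNeMo Transform base; it only illustrates the pattern described above.

```python
from __future__ import annotations

import torch


class ScaleFields:
    """Hypothetical transform: writes data[out_key] = data[in_key] * scale."""

    def __init__(self, input_keys: list[str], output_keys: list[str], scale: list[float]) -> None:
        if len(input_keys) != len(output_keys):
            raise ValueError("input_keys and output_keys must have the same length")
        self.input_keys = input_keys
        self.output_keys = output_keys
        # Per-feature scales held as a tensor: internal state that to() must move.
        self.scale = torch.tensor(scale)

    def to(self, device: torch.device | str) -> ScaleFields:
        self.scale = self.scale.to(device)
        return self  # return self for chaining

    def __call__(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        # Only the configured keys are read or written; other entries pass through.
        for in_key, out_key in zip(self.input_keys, self.output_keys):
            data[out_key] = data[in_key] * self.scale
        return data


transform = ScaleFields(["velocity"], ["velocity_scaled"], scale=[0.5, 2.0]).to("cpu")
out = transform({"velocity": torch.tensor([[2.0, 1.0], [4.0, 3.0]]), "mask": torch.tensor([True, False])})
```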

Normalization

Subsampling

Geometric

Spatial

Field Processing

class physicsnemo.datapipes.transforms.field_slice.FieldSlice(slicing: dict[str, dict[int | str, list[int] | dict[str, int]]])

Bases: Transform

Select specific indices or slices from tensor dimensions.

This transform allows selecting subsets of data along any dimension of specified fields. It supports two modes:

  1. Index selection: Provide a list of indices to select

  2. Slice selection: Provide start/stop/step as a dict

Parameters:

slicing (dict[str, dict[int | str, SliceSpec]]) –

Dictionary mapping field names to dimension slicing specs. Format:

{
    "field_name": {
        dim: indices_or_slice,
        ...
    },
    ...
}

Where:

  • dim is the dimension index (an int, or a str such as “-1” for Hydra configs)

  • indices_or_slice is either:
    • A list of indices: [0, 2, 5]

    • A slice dict: {"start": 0, "stop": 5, "step": 1}

Examples

Index selection: select features 0, 2, and 5 from the last dimension:

>>> from physicsnemo.datapipes.transforms import FieldSlice
>>> transform = FieldSlice({
...     "features": {-1: [0, 2, 5]},
... })
>>> # Input shape: (N, 10) -> Output shape: (N, 3)

Slice selection: select the first 5 features:

>>> transform = FieldSlice({
...     "features": {-1: {"start": 0, "stop": 5}},
... })
>>> # Input shape: (N, 10) -> Output shape: (N, 5)

Multiple dimensions:

>>> transform = FieldSlice({
...     "grid": {
...         0: [0, 1, 2],      # First 3 indices of dim 0
...         -1: {"stop": 4},   # First 4 of last dim (slice)
...     },
... })

Hydra configuration example:

_target_: physicsnemo.datapipes.transforms.FieldSlice
slicing:
  features:
    "-1": [0, 2, 5]
  velocity:
    "-1":
      stop: 2

extra_repr() -> str

Return extra information for repr.

Returns:

String with transform parameters.

Return type:

str
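The following sketch shows how a FieldSlice spec maps onto ordinary tensor indexing. It is an assumption about the behavior, not FieldSlice's source: apply_slicing is a hypothetical helper that converts string dimension keys with int(), turns index lists into index_select, and turns slice dicts into Python slices (missing start/stop/step default to None).

```python
from __future__ import annotations

import torch


def apply_slicing(tensor: torch.Tensor, spec: dict[int | str, list[int] | dict[str, int]]) -> torch.Tensor:
    for dim, sel in spec.items():
        dim = int(dim)  # accept Hydra-style string keys such as "-1"
        if isinstance(sel, dict):
            # Slice selection: {"start": ..., "stop": ..., "step": ...}, all optional.
            index = [slice(None)] * tensor.ndim
            index[dim] = slice(sel.get("start"), sel.get("stop"), sel.get("step"))
            tensor = tensor[tuple(index)]
        else:
            # Index selection: gather the listed positions along `dim`.
            idx = torch.as_tensor(sel, dtype=torch.long, device=tensor.device)
            tensor = tensor.index_select(dim, idx)
    return tensor


features = torch.arange(20.0).reshape(2, 10)
picked = apply_slicing(features, {-1: [0, 2, 5]})
first5 = apply_slicing(features, {"-1": {"start": 0, "stop": 5}})
grid = apply_slicing(torch.zeros(8, 6, 7), {0: [0, 1, 2], -1: {"stop": 4}})
```

Neither operation removes a dimension, so negative dimension keys stay valid when several dimensions are sliced in sequence.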

class physicsnemo.datapipes.transforms.field_processing.BroadcastGlobalFeatures(input_keys: list[str], n_points_key: str, output_key: str)

Bases: Transform

Broadcast global scalar/vector features to all spatial points.

Replicates global parameters (e.g., density, velocity) to match the number of spatial points, enabling concatenation with local features.

Parameters:
  • input_keys (list[str]) – List of global feature keys to broadcast.

  • n_points_key (str) – Key of a tensor whose first dimension gives the number of points to broadcast to.

  • output_key (str) – Key to store the broadcasted features.

Examples

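>>> import torch
>>> from tensordict import TensorDict
>>> from physicsnemo.datapipes.transforms.field_processing import BroadcastGlobalFeatures
>>> transform = BroadcastGlobalFeatures(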
...     input_keys=["air_density", "stream_velocity"],
...     n_points_key="embeddings",
...     output_key="fx"
... )
>>> data = TensorDict({
...     "air_density": torch.tensor(1.225),
...     "stream_velocity": torch.tensor(30.0),
...     "embeddings": torch.randn(10000, 7)
... })
>>> result = transform(data)
>>> print(result["fx"].shape)
torch.Size([10000, 2])
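The broadcasting step can be sketched as follows: flatten each global to 1-D, concatenate the results into F features, then expand to (N, F). The helper broadcast_globals is hypothetical, and it assumes vector-valued globals contribute one column per component; the example above shows only scalars.

```python
import torch


def broadcast_globals(data: dict[str, torch.Tensor], input_keys: list[str], n_points_key: str) -> torch.Tensor:
    n_points = data[n_points_key].shape[0]
    # Flatten each global (scalar or vector) to 1-D and join them: shape (F,).
    feats = torch.cat([data[k].reshape(-1) for k in input_keys])
    # expand() returns an (N, F) view without copying; clone() before writing to it.
    return feats.unsqueeze(0).expand(n_points, -1)


data = {
    "air_density": torch.tensor(1.225),
    "stream_velocity": torch.tensor(30.0),
    "embeddings": torch.randn(10000, 7),
}
fx = broadcast_globals(data, ["air_density", "stream_velocity"], "embeddings")
```

Using expand() keeps memory flat until the broadcast features are concatenated with local features, at which point torch.cat materializes a full copy.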

Feature Building

class physicsnemo.datapipes.transforms.concat_fields.ConcatFields(input_keys: list[str], output_key: str, *, dim: int = -1, skip_missing: bool = False)

Bases: Transform

Concatenate multiple tensor fields along a specified dimension.

Combines the specified fields into a single output tensor by concatenating them along dim (the last, feature dimension by default). Useful for building embeddings from multiple components such as positions, normals, and signed distance fields.

All input tensors must have the same shape except for the concatenation dimension.

Parameters:
  • input_keys (list[str]) – List of tensor keys to concatenate, in order.

  • output_key (str) – Key to store the concatenated result.

  • dim (int, default=-1) – Dimension along which to concatenate.

  • skip_missing (bool, default=False) – If True, skip keys that are not present in the data instead of raising an error. Useful for optional fields.

Examples

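>>> import torch
>>> from tensordict import TensorDict
>>> from physicsnemo.datapipes.transforms.concat_fields import ConcatFields
>>> transform = ConcatFields(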
...     input_keys=["positions", "sdf", "normals"],
...     output_key="embeddings"
... )
>>> data = TensorDict({
...     "positions": torch.randn(10000, 3),
...     "sdf": torch.randn(10000, 1),
...     "normals": torch.randn(10000, 3)
... })
>>> result = transform(data)
>>> print(result["embeddings"].shape)
torch.Size([10000, 7])

extra_repr() -> str

Return extra information for repr.

Returns:

String with transform parameters.

Return type:

str
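The core of ConcatFields, including skip_missing, can be sketched in a few lines. concat_fields is a hypothetical helper and a plain dict stands in for a TensorDict. The exception raised for a missing key is not documented above, so the KeyError here is an assumption.

```python
import torch


def concat_fields(
    data: dict[str, torch.Tensor],
    input_keys: list[str],
    output_key: str,
    dim: int = -1,
    skip_missing: bool = False,
) -> dict[str, torch.Tensor]:
    parts = []
    for key in input_keys:
        if key in data:
            parts.append(data[key])
        elif not skip_missing:
            raise KeyError(f"missing field {key!r}")
        # With skip_missing=True an absent optional field is simply left out.
    if not parts:
        raise ValueError("no input fields present to concatenate")
    # Output order follows input_keys; all dims except `dim` must match.
    data[output_key] = torch.cat(parts, dim=dim)
    return data


data = {"positions": torch.randn(100, 3), "normals": torch.randn(100, 3)}
out = concat_fields(data, ["positions", "sdf", "normals"], "embeddings", skip_missing=True)

try:
    concat_fields(dict(data), ["positions", "sdf"], "embeddings")
    raised = False
except KeyError:
    raised = True
```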

class physicsnemo.datapipes.transforms.concat_fields.NormalizeVectors(input_keys: list[str], *, dim: int = -1, eps: float = 1e-06)

Bases: Transform

Normalize vectors to unit length.

Divides vectors by their L2 norm along the specified dimension. Zero-length vectors are handled by adding a small epsilon to the norm, which prevents division by zero.

Parameters:
  • input_keys (list[str]) – List of tensor keys to normalize.

  • dim (int, default=-1) – Dimension along which to compute norm.

  • eps (float, default=1e-6) – Small value to prevent division by zero.

Examples

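>>> import torch
>>> from tensordict import TensorDict
>>> from physicsnemo.datapipes.transforms.concat_fields import NormalizeVectors
>>> transform = NormalizeVectors(input_keys=["normals"])
>>> data = TensorDict({"normals": torch.randn(10000, 3)})
>>> result = transform(data)
>>> # Normals are now unit length, up to the epsilon guard
>>> norms = torch.norm(result["normals"], dim=-1)
>>> print(torch.allclose(norms, torch.ones_like(norms), atol=1e-3))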
True

extra_repr() -> str

Return extra information for repr.

Returns:

String with transform parameters.

Return type:

str
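The epsilon guard can be sketched as dividing by norm + eps, which maps a zero vector to zero instead of NaN. normalize_vectors is a hypothetical helper, and it assumes eps is added to the norm, as the description above states.

```python
import torch


def normalize_vectors(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
    # keepdim=True keeps the norm broadcastable against x.
    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
    return x / (norm + eps)


vectors = torch.tensor([[3.0, 4.0, 0.0], [0.0, 0.0, 0.0]])
unit = normalize_vectors(vectors)
```

Adding eps shrinks every vector slightly, by a factor of norm / (norm + eps), which is most noticeable for very short vectors. torch.nn.functional.normalize instead divides by max(norm, eps), so any vector longer than eps comes out at exactly unit length.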

Utility