Built-in Transforms
Transforms are the core data-manipulation tool in PhysicsNeMo datapipes. All
transforms inherit from the Transform base class.
- class physicsnemo.datapipes.transforms.base.Transform
Bases: ABC
Abstract base class for all transforms.
Transforms operate on a TensorDict and return a modified TensorDict. They are designed to run on GPU tensors for maximum performance. Metadata is not passed to transforms (handled separately by Dataset/DataLoader).
Subclasses must implement:
__call__(data: TensorDict) -> TensorDict
Optionally override:
extra_repr() -> str: For custom repr output
state_dict() -> dict: For serialization
load_state_dict(state_dict: dict): For deserialization
Examples
>>> class MyTransform(Transform):
...     def __init__(self, scale: float):
...         super().__init__()
...         self.scale = scale
...
...     def __call__(self, data: TensorDict) -> TensorDict:
...         # Apply transformation to all tensors
...         return data.apply(lambda x: x * self.scale)
- property device: device | None
The device this transform operates on.
- Returns:
The device, or None if not set.
- Return type:
torch.device or None
- extra_repr() str
Return extra information for repr.
Override this to add transform-specific info to the repr.
- Returns:
Extra representation string.
- Return type:
str
- load_state_dict(state_dict: dict[str, Any]) None
Load state from a state dictionary.
Override this to restore transform state.
- Parameters:
state_dict (dict[str, Any]) – State dictionary to load from.
- state_dict() dict[str, Any]
Return a dictionary containing the transform’s state.
Override this for transforms with learnable or configurable state.
- Returns:
State dictionary.
- Return type:
dict[str, Any]
- to(
- device: device | str,
Move any internal tensors to the specified device.
This default implementation automatically moves any tensor attributes found in self.__dict__ to the specified device. Override this method if your transform requires custom device handling.
- Parameters:
device (torch.device or str) – Target device.
- Returns:
Self for chaining.
- Return type:
To implement a new transform, override the __call__ method and add any
initialization or configuration your transform needs.
The input to a transform is mutable by default, so the order of transformations matters.
In general, transforms are transactional: they take input, manipulate it, return output, and almost never update internal state. Transforms should be device-agnostic and follow a compute-follows-data principle, operating on data on the device where it resides whenever possible.
By default, transforms accept and return tensordict objects, but this is not
strictly enforced.
If you implement custom transforms that return different data types, downstream
transforms should expect that data type. One example of this, found in the
minimal datapipe examples, is turning the tensordict objects into a PyTorch
Geometric graph object. This type of manipulation is perfectly valid, but
requires custom collation functions and prevents usage of tensordict-based
transforms downstream.
One special transform is Compose, which takes a list of transforms and
applies them in order as a single transform. Use Compose much as you would
use the torch.nn.Sequential container to stack PyTorch modules.
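The effect of ordering in a sequential pipeline can be illustrated without PhysicsNeMo installed. The following is only a sketch: `compose`, `scale_by`, and `shift_by` are hypothetical helpers that operate on plain dictionaries of lists, not the library's Compose or Transform classes.

```python
from typing import Callable, Dict, List

Sample = Dict[str, List[float]]
TransformFn = Callable[[Sample], Sample]


def scale_by(key: str, factor: float) -> TransformFn:
    """Hypothetical transform: multiply every value under `key` by `factor`."""
    def apply(sample: Sample) -> Sample:
        out = dict(sample)  # shallow copy, so the caller's dict is left untouched
        out[key] = [v * factor for v in sample[key]]
        return out
    return apply


def shift_by(key: str, offset: float) -> TransformFn:
    """Hypothetical transform: add `offset` to every value under `key`."""
    def apply(sample: Sample) -> Sample:
        out = dict(sample)
        out[key] = [v + offset for v in sample[key]]
        return out
    return apply


def compose(transforms: List[TransformFn]) -> TransformFn:
    """Apply transforms in order, feeding each output to the next transform."""
    def apply(sample: Sample) -> Sample:
        for transform in transforms:
            sample = transform(sample)
        return sample
    return apply


sample = {"pressure": [1.0, 2.0]}
scale_then_shift = compose([scale_by("pressure", 10.0), shift_by("pressure", 1.0)])
shift_then_scale = compose([shift_by("pressure", 1.0), scale_by("pressure", 10.0)])

print(scale_then_shift(sample)["pressure"])  # [11.0, 21.0]
print(shift_then_scale(sample)["pressure"])  # [20.0, 30.0]
```

The two pipelines contain the same steps but give different results, which is why the order passed to a composed pipeline matters.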
- class physicsnemo.datapipes.transforms.compose.Compose(
- transforms: Sequence[Transform],
Bases: Transform
Compose multiple transforms into a sequential pipeline.
Applies transforms in order, passing the output of each as input to the next.
- Parameters:
transforms (Sequence[Transform]) – Sequence of transforms to apply in order.
Examples
>>> import torch
>>> from physicsnemo.datapipes.transforms import Normalize, SubsamplePoints
>>> from tensordict import TensorDict
>>> sample = TensorDict({
...     "pressure": torch.tensor([101325.0, 102325.0, 100325.0]),
... })
>>> normalize = Normalize(
...     input_keys=["pressure"],
...     method="mean_std",
...     means={"pressure": 101325.0},
...     stds={"pressure": 1000.0},
... )
>>> subsample = SubsamplePoints(input_keys=["pressure"], n_points=1000)
>>> pipeline = Compose([normalize, subsample])
>>> transformed = pipeline(sample)
>>> transformed["pressure"]
tensor([ 0.,  1., -1.])
- append(
- transform: Transform,
Append a transform to the pipeline.
- Parameters:
transform (Transform) – Transform to append.
- Raises:
TypeError – If transform is not a Transform instance.
- extra_repr() str
Return extra information for repr.
- Returns:
Formatted string showing all transforms.
- Return type:
str
Below is a collection of commonly used transforms in PhysicsNeMo's datapipes.
Since the input and output are tensordict objects, a common configuration pattern
for a transform is to specify a list of input and output keys to operate on,
though this is not always the case.
Most transforms have no internal state. Those that do move their internal
tensors to and from the GPU when to() is called.
Normalization
- class physicsnemo.datapipes.transforms.normalize.Normalize(
- input_keys: list[str],
- means: dict[str, float | Tensor] | float | Tensor | None = None,
- stds: dict[str, float | Tensor] | float | Tensor | None = None,
- *,
- method: Literal['mean_std', 'min_max'] | None = None,
- mins: dict[str, float | Tensor] | float | Tensor | None = None,
- maxs: dict[str, float | Tensor] | float | Tensor | None = None,
- stats_file: str | Path | None = None,
- eps: float = 1e-08,
Bases: Transform
Normalize specified fields using mean-std or min-max scaling.
Supports two normalization methods:
mean_std: Applies (x - mean) / std for each specified field
min_max: Applies (x - center) / half_range, normalizing to [-1, 1], where center = (max + min) / 2 and half_range = (max - min) / 2
Parameters can be provided directly or loaded from a .npz file.
Examples
Mean-std scaling:
>>> import torch
>>> from tensordict import TensorDict
>>> sample = TensorDict({
...     "pressure": torch.tensor([101325.0, 102325.0, 100325.0]),
...     "velocity": torch.tensor([10.0, -10.0, 0.0]),
... })
>>> norm = Normalize(
...     input_keys=["pressure", "velocity"],
...     method="mean_std",
...     means={"pressure": 101325.0, "velocity": 0.0},
...     stds={"pressure": 1000.0, "velocity": 10.0},
... )
>>> normalized = norm(sample)
>>> normalized["pressure"]
tensor([ 0.,  1., -1.])
>>> normalized["velocity"]
tensor([ 1., -1.,  0.])
Min-max scaling (normalizes to [-1, 1]):
>>> sample = TensorDict({
...     "pressure": torch.tensor([100000.0, 105000.0, 110000.0]),
... })
>>> norm = Normalize(
...     input_keys=["pressure"],
...     method="min_max",
...     mins={"pressure": 100000.0},
...     maxs={"pressure": 110000.0},
... )
>>> normalized = norm(sample)
>>> normalized["pressure"]
tensor([-1.,  0.,  1.])
- extra_repr() str
Return extra information for repr.
Override this to add transform-specific info to the repr.
- Returns:
Extra representation string.
- Return type:
str
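The two scaling formulas documented for Normalize can be checked by hand with plain Python. This sketch reproduces only the arithmetic stated above; the function names are hypothetical, and the role of the eps parameter is deliberately left out because the formulas above do not say where it enters.

```python
from typing import List


def normalize_mean_std(values: List[float], mean: float, std: float) -> List[float]:
    """Apply (x - mean) / std element-wise."""
    return [(x - mean) / std for x in values]


def normalize_min_max(values: List[float], vmin: float, vmax: float) -> List[float]:
    """Apply (x - center) / half_range, which maps [vmin, vmax] onto [-1, 1]."""
    center = (vmax + vmin) / 2
    half_range = (vmax - vmin) / 2
    return [(x - center) / half_range for x in values]


print(normalize_mean_std([101325.0, 102325.0, 100325.0], 101325.0, 1000.0))
# [0.0, 1.0, -1.0]
print(normalize_min_max([100000.0, 105000.0, 110000.0], 100000.0, 110000.0))
# [-1.0, 0.0, 1.0]
```

The inputs are the same values used in the Normalize examples, so the printed lists match the tensors shown there.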
Subsampling
- class physicsnemo.datapipes.transforms.subsample.SubsamplePoints(
- input_keys: list[str],
- n_points: int,
- *,
- algorithm: Literal['poisson_fixed', 'uniform'] = 'poisson_fixed',
- weights_key: str | None = None,
Bases: Transform
Subsample points from large point clouds or meshes.
This transform applies coordinated subsampling to multiple tensor fields, ensuring that the same points are selected across all specified keys. Useful for downsampling large volumetric data or point clouds while maintaining correspondence between coordinates and field values.
Supports two sampling algorithms:
"poisson_fixed": Near-uniform sampling for very large datasets (> 2^24 points)
"uniform": Standard uniform sampling
Optionally supports weighted sampling (e.g., area-weighted for surface meshes) by providing a weights_key.
- Parameters:
input_keys (list[str]) – List of tensor keys to subsample. All must have the same first dimension size.
n_points (int) – Number of points to sample.
algorithm ({"poisson_fixed", "uniform"}, default="poisson_fixed") – Sampling algorithm to use.
weights_key (str, optional) – Optional key for sampling weights (e.g., "surface_areas" for area-weighted surface sampling). When provided, samples are drawn according to the weights distribution.
Examples
Near-uniform sampling (poisson_fixed algorithm, no weights):
>>> transform = SubsamplePoints(
...     input_keys=["volume_mesh_centers", "volume_fields"],
...     n_points=10000,
...     algorithm="poisson_fixed"
... )
>>> sample = TensorDict({
...     "volume_mesh_centers": torch.randn(100000, 3),
...     "volume_fields": torch.randn(100000, 5)
... })
>>> result = transform(sample)
>>> print(result["volume_mesh_centers"].shape)
torch.Size([10000, 3])
Weighted sampling:
>>> transform = SubsamplePoints(
...     input_keys=["surface_mesh_centers", "surface_fields", "surface_normals"],
...     n_points=5000,
...     algorithm="uniform",
...     weights_key="surface_areas"
... )
>>> sample = TensorDict({
...     "surface_mesh_centers": torch.randn(20000, 3),
...     "surface_fields": torch.randn(20000, 2),
...     "surface_normals": torch.randn(20000, 3),
...     "surface_areas": torch.rand(20000)
... })
>>> result = transform(sample)
>>> print(result["surface_mesh_centers"].shape)
torch.Size([5000, 3])
Notes
All specified keys must have the same size in their first dimension. The same indices are applied to all keys to maintain correspondence.
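The correspondence guarantee in the note above comes from drawing one set of indices and reusing it for every key. A rough pure-Python sketch of that idea follows. It is not the library's sampling algorithm: `subsample_together` is a hypothetical helper, it uses simple unweighted sampling without replacement, and returning the data unchanged when n_points is at least the number of rows is an assumption.

```python
import random
from typing import Dict, List, Sequence


def subsample_together(
    fields: Dict[str, Sequence],
    keys: List[str],
    n_points: int,
    seed: int = 0,
) -> Dict[str, Sequence]:
    """Pick one set of row indices and apply it to every key in `keys`."""
    sizes = {len(fields[k]) for k in keys}
    if len(sizes) != 1:
        raise ValueError("all keys must have the same first-dimension size")
    n_total = sizes.pop()
    if n_points >= n_total:
        # Assumption: nothing to drop, so return the data unchanged.
        return dict(fields)
    # Draw the indices once, without replacement, then reuse them for each key.
    indices = random.Random(seed).sample(range(n_total), n_points)
    out = dict(fields)
    for k in keys:
        out[k] = [fields[k][i] for i in indices]
    return out


sample = {
    "centers": [(float(i), 0.0, 0.0) for i in range(100)],
    "fields": [10.0 * i for i in range(100)],
}
result = subsample_together(sample, ["centers", "fields"], n_points=8)
print(len(result["centers"]), len(result["fields"]))  # 8 8
```

Because both lists are indexed with the same `indices`, row i of "centers" still describes the same point as row i of "fields" after subsampling.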
Geometric
- class physicsnemo.datapipes.transforms.geometric.ComputeSDF(
- input_keys: list[str],
- output_key: str,
- mesh_coords_key: str,
- mesh_faces_key: str,
- *,
- use_winding_number: bool = True,
- closest_points_key: str | None = None,
Bases: Transform
Compute signed distance field from a mesh.
Computes the signed distance from query points to the nearest point on a triangular mesh surface. Optionally returns the closest points on the mesh surface for each query point.
- Parameters:
input_keys (list[str]) – List of keys containing query points to compute SDF for. Each tensor should have shape \((N, 3)\).
output_key (str) – Key to store the computed SDF values.
mesh_coords_key (str) – Key for mesh vertex coordinates, shape \((M, 3)\).
mesh_faces_key (str) – Key for mesh face indices (flattened), shape \((F*3,)\).
use_winding_number (bool, default=True) – If True, use winding number for sign determination.
closest_points_key (str, optional) – Optional key to store closest points on mesh.
Examples
>>> transform = ComputeSDF(
...     input_keys=["volume_mesh_centers"],
...     output_key="sdf_nodes",
...     mesh_coords_key="stl_coordinates",
...     mesh_faces_key="stl_faces",
...     closest_points_key="closest_points"
... )
>>> sample = TensorDict({
...     "volume_mesh_centers": torch.randn(10000, 3),
...     "stl_coordinates": torch.randn(5000, 3),
...     "stl_faces": torch.randint(0, 5000, (10000,))
... })
>>> result = transform(sample)
>>> print(result["sdf_nodes"].shape)
torch.Size([10000, 1])
- class physicsnemo.datapipes.transforms.geometric.ComputeNormals(
- positions_key: str,
- closest_points_key: str,
- center_of_mass_key: str,
- output_key: str,
- *,
- handle_zero_distance: bool = True,
Bases: Transform
Compute normal vectors from closest points.
Computes normalized direction vectors from query points to their closest points on a surface. Handles zero-distance edge cases by falling back to center of mass direction.
- Parameters:
positions_key (str) – Key for position tensor, shape \((N, 3)\).
closest_points_key (str) – Key for closest points tensor, shape \((N, 3)\).
center_of_mass_key (str) – Key for center of mass, shape \((1, 3)\) or \((3,)\).
output_key (str) – Key to store computed normals.
handle_zero_distance (bool, default=True) – If True, use center_of_mass fallback for zero distances.
Examples
>>> transform = ComputeNormals(
...     positions_key="volume_mesh_centers",
...     closest_points_key="closest_points",
...     center_of_mass_key="center_of_mass",
...     output_key="volume_normals"
... )
- class physicsnemo.datapipes.transforms.geometric.Translate(
- input_keys: list[str],
- center_key_or_value: str | Tensor,
- *,
- subtract: bool = False,
Bases: Transform
Apply a translation by adding or subtracting a center point.
By default, this will ADD the translation. But you can also use the subtract mode: this is particularly useful when composed with CenterOfMass: you can compute the CoM and apply a translation as a CoM subtraction to center the data at the origin.
- Parameters:
input_keys (list[str]) – List of position tensor keys to translate.
center_key_or_value (str or torch.Tensor) – Either a key name (str) for a tensor in the sample, or a fixed tensor value to add/subtract.
subtract (bool, default=False) – If False (default), ADD the translation (data + center). If True, SUBTRACT the translation (data - center). Use subtract=True when centering data around a reference point like center of mass.
Examples
Add mode (default) - shift points by a fixed offset:
>>> transform = Translate(
...     input_keys=["positions"],
...     center_key_or_value=torch.tensor([1.0, 2.0, 3.0])
... )
>>> # result["positions"] = original + [1, 2, 3]
Subtract mode - center points by subtracting center of mass:
>>> transform = Translate(
...     input_keys=["volume_mesh_centers", "surface_mesh_centers"],
...     center_key_or_value="center_of_mass",
...     subtract=True
... )
>>> # result[key] = original - center_of_mass, for each key in input_keys
- class physicsnemo.datapipes.transforms.geometric.Scale(
- input_keys: list[str],
- scale: Tensor,
- *,
- divide: bool = False,
Bases: Transform
Apply a scale factor by multiplying or dividing by a reference scale.
By default, this will MULTIPLY by the scale factor. But you can also use the divide mode: this is particularly useful for normalizing data to make the representation scale invariant (e.g., dividing by a characteristic length scale).
- Parameters:
input_keys (list[str]) – List of position tensor keys to scale.
scale (torch.Tensor) – Scale factor tensor, shape \((1, D)\) or \((D,)\).
divide (bool, default=False) – If False (default), MULTIPLY by the scale (data * scale). If True, DIVIDE by the scale (data / scale). Use divide=True when normalizing data by a reference scale.
Examples
Multiply mode (default) - scale up positions by 2x:
>>> transform = Scale(
...     input_keys=["positions"],
...     scale=torch.tensor([2.0, 2.0, 2.0])
... )
>>> # result["positions"] = original * [2, 2, 2]
Divide mode - normalize by a reference scale:
>>> transform = Scale(
...     input_keys=["volume_mesh_centers", "geometry_coordinates"],
...     scale=torch.tensor([1.0, 1.0, 1.0]),
...     divide=True
... )
>>> # result[key] = original / scale, for each key in input_keys
Spatial
- class physicsnemo.datapipes.transforms.spatial.BoundingBoxFilter(
- input_keys: list[str],
- bbox_min: Tensor,
- bbox_max: Tensor,
- *,
- dependent_keys: list[str] | None = None,
Bases: Transform
Filter points outside a spatial bounding box.
Removes points that fall outside specified min/max bounds and applies the same filtering to dependent arrays to maintain correspondence. This is useful for focusing on specific regions of interest or removing outliers from simulation data.
- Parameters:
input_keys (list[str]) – List of coordinate tensor keys to filter.
bbox_min (torch.Tensor) – Minimum corner of bounding box, shape \((3,)\).
bbox_max (torch.Tensor) – Maximum corner of bounding box, shape \((3,)\).
dependent_keys (list[str], optional) – Optional list of keys to filter using the same mask. These maintain correspondence with the filtered coordinates.
Examples
>>> transform = BoundingBoxFilter(
...     input_keys=["volume_mesh_centers"],
...     bbox_min=torch.tensor([-1.0, -1.0, -1.0]),
...     bbox_max=torch.tensor([1.0, 1.0, 1.0]),
...     dependent_keys=["volume_fields", "sdf_nodes"]
... )
>>> sample = TensorDict({
...     "volume_mesh_centers": torch.randn(10000, 3) * 2,  # Some outside bbox
...     "volume_fields": torch.randn(10000, 4)
... })
>>> result = transform(sample)
>>> # Only points within bbox remain
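The filtering described for BoundingBoxFilter amounts to building one boolean mask from the coordinates and applying it to every dependent array. A minimal sketch on plain lists follows; `bbox_filter` is a hypothetical helper, it handles a single coordinate key, and treating points exactly on the boundary as inside is an assumption rather than documented behaviour.

```python
from typing import Dict, List, Sequence


def bbox_filter(
    sample: Dict[str, List],
    coords_key: str,
    bbox_min: Sequence[float],
    bbox_max: Sequence[float],
    dependent_keys: Sequence[str] = (),
) -> Dict[str, List]:
    """Keep rows whose coordinates lie inside the box; filter dependents alike."""
    # One mask entry per point: True when every coordinate is within bounds.
    mask = [
        all(lo <= c <= hi for c, lo, hi in zip(point, bbox_min, bbox_max))
        for point in sample[coords_key]
    ]
    out = dict(sample)
    for key in (coords_key, *dependent_keys):
        out[key] = [row for row, keep in zip(sample[key], mask) if keep]
    return out


sample = {
    "centers": [(0.0, 0.0, 0.0), (2.0, 0.0, 0.0), (0.5, -0.5, 0.9)],
    "fields": [10.0, 20.0, 30.0],
}
result = bbox_filter(
    sample, "centers", (-1.0, -1.0, -1.0), (1.0, 1.0, 1.0), dependent_keys=["fields"]
)
print(result["centers"])  # [(0.0, 0.0, 0.0), (0.5, -0.5, 0.9)]
print(result["fields"])   # [10.0, 30.0]
```

The second point lies outside the box in x, so both its coordinates and its field value are dropped together.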
- class physicsnemo.datapipes.transforms.spatial.CreateGrid(
- output_key: str,
- resolution: tuple[int, int, int],
- bbox_min: Tensor,
- bbox_max: Tensor,
Bases: Transform
Create a regular 3D spatial grid.
Generates a uniform grid spanning a bounding box, used for latent space representations, interpolation grids, or structured spatial queries.
- Parameters:
output_key (str) – Key to store the generated grid.
resolution (tuple[int, int, int]) – Grid resolution as (nx, ny, nz).
bbox_min (torch.Tensor) – Minimum corner of bounding box, shape \((3,)\).
bbox_max (torch.Tensor) – Maximum corner of bounding box, shape \((3,)\).
Examples
>>> transform = CreateGrid(
...     output_key="grid",
...     resolution=(64, 64, 64),
...     bbox_min=torch.tensor([-1.0, -1.0, -1.0]),
...     bbox_max=torch.tensor([1.0, 1.0, 1.0])
... )
>>> sample = TensorDict({})
>>> result = transform(sample)
>>> print(result["grid"].shape)
torch.Size([262144, 3])
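The output shape in the example follows from the resolution: a (nx, ny, nz) grid has nx * ny * nz points, and 64 * 64 * 64 = 262144. The sketch below builds such a flattened grid in plain Python. It is an illustration only: `create_grid` and `linspace` are hypothetical helpers, and both the inclusion of the bounding-box corners and the point ordering are assumptions, not documented behaviour of CreateGrid.

```python
from itertools import product
from typing import List, Sequence, Tuple


def linspace(lo: float, hi: float, n: int) -> List[float]:
    """Return n evenly spaced values from lo to hi, endpoints included."""
    if n == 1:
        return [lo]
    step = (hi - lo) / (n - 1)
    return [lo + i * step for i in range(n)]


def create_grid(
    resolution: Tuple[int, int, int],
    bbox_min: Sequence[float],
    bbox_max: Sequence[float],
) -> List[Tuple[float, ...]]:
    """Return a flattened list of (x, y, z) grid points spanning the box."""
    axes = [linspace(lo, hi, n) for n, lo, hi in zip(resolution, bbox_min, bbox_max)]
    return list(product(*axes))


grid = create_grid((3, 3, 3), (-1.0, -1.0, -1.0), (1.0, 1.0, 1.0))
print(len(grid))          # 27
print(grid[0], grid[-1])  # (-1.0, -1.0, -1.0) (1.0, 1.0, 1.0)
print(64 ** 3)            # 262144
```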
- class physicsnemo.datapipes.transforms.spatial.KNearestNeighbors(
- points_key: str,
- queries_key: str,
- k: int,
- *,
- output_prefix: str = 'neighbors',
- extract_keys: list[str] | None = None,
- drop_first_neighbor: bool = False,
Bases: Transform
Compute k-nearest neighbors in a point cloud.
Finds the k nearest neighbors for each query point and extracts corresponding coordinates and other attributes. Useful for local feature aggregation in mesh networks and spatial interpolation.
- Parameters:
points_key (str) – Key for reference points to search, shape \((N, 3)\).
queries_key (str) – Key for query points, shape \((M, 3)\).
k (int) – Number of nearest neighbors to find.
output_prefix (str, default="neighbors") – Prefix for output keys.
extract_keys (list[str], optional) – Optional list of keys to extract for neighbors (e.g., ["normals", "areas"]). If None, only extracts coordinates.
Examples
>>> transform = KNearestNeighbors(
...     points_key="surface_mesh_centers",
...     queries_key="surface_mesh_centers_subsampled",
...     k=11,
...     output_prefix="surface_neighbors",
...     extract_keys=["surface_normals", "surface_areas"]
... )
>>> sample = TensorDict({
...     "surface_mesh_centers": torch.randn(10000, 3),
...     "surface_mesh_centers_subsampled": torch.randn(1000, 3),
...     "surface_normals": torch.randn(10000, 3),
...     "surface_areas": torch.rand(10000)
... })
>>> result = transform(sample)
>>> # Creates: surface_neighbors_coords, surface_neighbors_normals, etc.
- class physicsnemo.datapipes.transforms.spatial.CenterOfMass(
- coords_key: str,
- output_key: str,
- *,
- areas_key: str | None = None,
Bases: Transform
Compute weighted center of mass for a point cloud.
Calculates the center of mass using area or mass weights, typically applied to mesh data where each point represents a cell with a specific area.
- Parameters:
coords_key (str) – Key for coordinates, shape \((N, 3)\).
areas_key (str) – Key for area weights, shape \((N,)\).
output_key (str) – Key to store the computed center of mass, shape \((1, 3)\).
Examples
>>> transform = CenterOfMass(
...     coords_key="stl_centers",
...     areas_key="stl_areas",
...     output_key="center_of_mass"
... )
>>> sample = TensorDict({
...     "stl_centers": torch.randn(5000, 3),
...     "stl_areas": torch.rand(5000)
... })
>>> result = transform(sample)
>>> print(result["center_of_mass"].shape)
torch.Size([3])
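A weighted center of mass is the weight-averaged position, sum(w_i * x_i) / sum(w_i). The sketch below computes it on plain lists and then subtracts it from the coordinates, which is the centering pattern described for Translate with subtract=True. The helper names are hypothetical, and falling back to an unweighted mean when no weights are given is an assumption about what an omitted areas_key means.

```python
from typing import List, Optional, Sequence


def center_of_mass(
    coords: Sequence[Sequence[float]],
    weights: Optional[Sequence[float]] = None,
) -> List[float]:
    """Return sum(w_i * x_i) / sum(w_i) per coordinate axis."""
    if weights is None:
        weights = [1.0] * len(coords)  # assumption: unweighted mean
    total = sum(weights)
    dim = len(coords[0])
    return [sum(w * p[d] for p, w in zip(coords, weights)) / total for d in range(dim)]


def translate(
    coords: Sequence[Sequence[float]],
    center: Sequence[float],
    subtract: bool = False,
) -> List[List[float]]:
    """Add the center to every point, or subtract it when subtract=True."""
    sign = -1.0 if subtract else 1.0
    return [[x + sign * c for x, c in zip(p, center)] for p in coords]


coords = [[0.0, 0.0, 0.0], [4.0, 0.0, 0.0]]
areas = [1.0, 3.0]

com = center_of_mass(coords, areas)
print(com)  # [3.0, 0.0, 0.0]

centered = translate(coords, com, subtract=True)
print(centered[0][0], centered[1][0])  # -3.0 1.0
```

After centering, the area-weighted center of mass of the points sits at the origin.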
Field Processing
- class physicsnemo.datapipes.transforms.field_slice.FieldSlice(
- slicing: dict[str, dict[int | str, list[int] | dict[str, int]]],
Bases: Transform
Select specific indices or slices from tensor dimensions.
This transform allows selecting subsets of data along any dimension of specified fields. It supports two modes:
Index selection: Provide a list of indices to select
Slice selection: Provide start/stop/step as a dict
- Parameters:
slicing (dict[str, dict[int | str, SliceSpec]]) – Dictionary mapping field names to dimension slicing specs. Format:
{ "field_name": { dim: indices_or_slice, ... }, ... }
Where:
dim is the dimension index (int, or str for Hydra like “-1”)
indices_or_slice is either:
A list of indices: [0, 2, 5]
A slice dict: {"start": 0, "stop": 5, "step": 1}
Examples
Index selection - select features 0, 2, 5 from last dimension:
>>> transform = FieldSlice({
...     "features": {-1: [0, 2, 5]},
... })
>>> # Input shape: (N, 10) -> Output shape: (N, 3)
Slice selection - select first 5 features:
>>> transform = FieldSlice({
...     "features": {-1: {"start": 0, "stop": 5}},
... })
>>> # Input shape: (N, 10) -> Output shape: (N, 5)
Multiple dimensions:
>>> transform = FieldSlice({
...     "grid": {
...         0: [0, 1, 2],  # First 3 indices of dim 0
...         -1: {"stop": 4},  # First 4 of last dim (slice)
...     },
... })
Hydra configuration example:
_target_: physicsnemo.datapipes.transforms.FieldSlice
slicing:
  features:
    "-1": [0, 2, 5]
  velocity:
    "-1":
      stop: 2
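One way to read the slicing spec is as a small translation step: string dimension keys such as "-1" become integers, index lists stay lists, and slice dicts become Python slice objects. The sketch below shows that reading on a plain list. The helpers `to_selector`, `parse_slicing`, and `select` are hypothetical and are not how FieldSlice is implemented.

```python
from typing import Dict, List, Union

SliceSpec = Union[List[int], Dict[str, int]]
Selector = Union[List[int], slice]


def to_selector(spec: SliceSpec) -> Selector:
    """Turn a slice dict into a slice object; pass index lists through."""
    if isinstance(spec, dict):
        return slice(spec.get("start"), spec.get("stop"), spec.get("step"))
    return list(spec)


def parse_slicing(field_spec: Dict[Union[int, str], SliceSpec]) -> Dict[int, Selector]:
    """Normalize dimension keys to int (Hydra/YAML keys arrive as strings)."""
    return {int(dim): to_selector(spec) for dim, spec in field_spec.items()}


def select(row: List[float], selector: Selector) -> List[float]:
    """Apply a selector along a single one-dimensional row."""
    if isinstance(selector, slice):
        return row[selector]
    return [row[i] for i in selector]


row = [float(i) for i in range(10)]

features = parse_slicing({"-1": [0, 2, 5]})
print(select(row, features[-1]))  # [0.0, 2.0, 5.0]

velocity = parse_slicing({"-1": {"stop": 2}})
print(velocity[-1])               # slice(None, 2, None)
print(select(row, velocity[-1]))  # [0.0, 1.0]
```

The two specs are the ones from the Hydra example: three selected indices for "features" and the first two entries for "velocity".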
- class physicsnemo.datapipes.transforms.field_processing.BroadcastGlobalFeatures(
- input_keys: list[str],
- n_points_key: str,
- output_key: str,
Bases: Transform
Broadcast global scalar/vector features to all spatial points.
Replicates global parameters (e.g., density, velocity) to match the number of spatial points, enabling concatenation with local features.
- Parameters:
input_keys (list[str]) – List of global feature keys to broadcast.
n_points_key (str) – Key of a tensor whose first dimension gives the number of points to broadcast to.
output_key (str) – Key to store the broadcasted features.
Examples
>>> transform = BroadcastGlobalFeatures(
...     input_keys=["air_density", "stream_velocity"],
...     n_points_key="embeddings",
...     output_key="fx"
... )
>>> data = TensorDict({
...     "air_density": torch.tensor(1.225),
...     "stream_velocity": torch.tensor(30.0),
...     "embeddings": torch.randn(10000, 7)
... })
>>> result = transform(data)
>>> print(result["fx"].shape)
torch.Size([10000, 2])
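The (10000, 2) shape in the example comes from repeating one row of two global scalars once per point. A tiny sketch of that replication on plain lists follows. It covers scalar features only, and `broadcast_globals` is a hypothetical name, not part of the library.

```python
from typing import List


def broadcast_globals(global_values: List[float], n_points: int) -> List[List[float]]:
    """Repeat one row of global scalars for each of `n_points` points."""
    return [list(global_values) for _ in range(n_points)]


air_density, stream_velocity = 1.225, 30.0
fx = broadcast_globals([air_density, stream_velocity], n_points=5)
print(len(fx), len(fx[0]))  # 5 2
print(fx[0])                # [1.225, 30.0]
```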
Feature Building
- class physicsnemo.datapipes.transforms.concat_fields.ConcatFields(
- input_keys: list[str],
- output_key: str,
- *,
- dim: int = -1,
- skip_missing: bool = False,
Bases: Transform
Concatenate multiple tensor fields along a specified dimension.
Combines specified fields into a single output tensor by concatenating along the feature dimension. Useful for building embeddings from multiple components like positions, normals, and signed distance fields.
All input tensors must have the same shape except for the concatenation dimension.
- Parameters:
input_keys (list[str]) – List of tensor keys to concatenate, in order.
output_key (str) – Key to store the concatenated result.
dim (int, default=-1) – Dimension along which to concatenate.
skip_missing (bool, default=False) – If True, skip keys that are not present in the data instead of raising an error. Useful for optional fields.
Examples
>>> transform = ConcatFields(
...     input_keys=["positions", "sdf", "normals"],
...     output_key="embeddings"
... )
>>> data = TensorDict({
...     "positions": torch.randn(10000, 3),
...     "sdf": torch.randn(10000, 1),
...     "normals": torch.randn(10000, 3)
... })
>>> result = transform(data)
>>> print(result["embeddings"].shape)
torch.Size([10000, 7])
- class physicsnemo.datapipes.transforms.concat_fields.NormalizeVectors(
- input_keys: list[str],
- *,
- dim: int = -1,
- eps: float = 1e-06,
Bases: Transform
Normalize vectors to unit length.
Divides vectors by their L2 norm along the specified dimension. Handles zero-length vectors by adding a small epsilon to prevent division by zero.
- Parameters:
input_keys (list[str]) – List of tensor keys to normalize.
dim (int, default=-1) – Dimension along which to compute norm.
eps (float, default=1e-6) – Small value to prevent division by zero.
Examples
>>> transform = NormalizeVectors(input_keys=["normals"])
>>> data = TensorDict({"normals": torch.randn(10000, 3)})
>>> result = transform(data)
>>> # Normals are now unit length
>>> norms = torch.norm(result["normals"], dim=-1)
>>> print(torch.allclose(norms, torch.ones_like(norms), atol=1e-5))
True
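The role of eps is easiest to see on a zero-length vector. The sketch below divides by (norm + eps), which is one plausible reading of "adding a small epsilon"; the exact placement of eps in the library is an assumption here, and `normalize_vector` is a hypothetical helper.

```python
import math
from typing import List


def normalize_vector(v: List[float], eps: float = 1e-6) -> List[float]:
    """Divide a vector by its L2 norm, with eps guarding against division by zero."""
    norm = math.sqrt(sum(c * c for c in v))
    return [c / (norm + eps) for c in v]


unit = normalize_vector([3.0, 4.0, 0.0])
length = math.sqrt(sum(c * c for c in unit))
print(abs(length - 1.0) < 1e-5)  # True

# A zero-length vector stays at zero instead of raising ZeroDivisionError.
print(normalize_vector([0.0, 0.0, 0.0]))  # [0.0, 0.0, 0.0]
```

With this placement of eps, the result is very slightly shorter than unit length, which is why the check uses a tolerance, as the example above does with atol=1e-5.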
Utility
- class physicsnemo.datapipes.transforms.utility.Rename(mapping: dict[str, str], *, strict: bool = True)
Bases: Transform
Rename keys in a TensorDict.
Replaces existing key names with new names according to a mapping. The tensor data is preserved, only the keys are changed.
Nested tensordicts can use this too. The keys are flattened with a ‘.’ separator: a[“b”][“d”] will map to a[“b.d”] for renaming. If you want to replace d, you’d provide {“a.d” : “a.c”} in the mapping file.
- Parameters:
mapping (dict[str, str]) – Dictionary mapping old key names to new key names. Keys are the original names, values are the new names.
strict (bool, default=True) – If True, raise an error if a key in the mapping is not found in the data. If False, silently skip missing keys.
Examples
>>> transform = Rename(mapping={"old_name": "new_name", "x": "positions"})
>>> data = TensorDict({
...     "old_name": torch.randn(100, 3),
...     "x": torch.randn(100, 3),
...     "other": torch.randn(100, 1)
... })
>>> result = transform(data)
>>> print(sorted(result.keys()))
['new_name', 'other', 'positions']
- class physicsnemo.datapipes.transforms.utility.Purge(
- *,
- keep_only: list[str] | None = None,
- drop_only: list[str] | None = None,
- strict: bool = True,
Bases: Transform
Remove keys and their associated tensors from a TensorDict.
Supports two mutually exclusive modes:
drop_only: Specify keys to remove (keep everything else)
keep_only: Specify keys to keep (remove everything else)
Only one mode can be active at a time. By default, drop_only=None means no keys are dropped (identity transform).
- Parameters:
keep_only (list[str], optional) – List of keys to keep. All other keys will be removed. Cannot be used together with drop_only.
drop_only (list[str], optional) – List of keys to remove. All other keys will be kept. Cannot be used together with keep_only. Default is None (drop nothing).
strict (bool, default=True) – If True, raise an error if a specified key is not found in the data. If False, silently skip missing keys.
Examples
Drop mode - remove specific keys:
>>> transform = Purge(drop_only=["temp", "debug_info"])
>>> data = TensorDict({
...     "positions": torch.randn(100, 3),
...     "temp": torch.randn(100, 1),
...     "debug_info": torch.randn(100, 10)
... })
>>> result = transform(data)
>>> print(list(result.keys()))
['positions']
Keep mode - keep only specific keys:
>>> transform = Purge(keep_only=["positions", "velocities"])
>>> data = TensorDict({
...     "positions": torch.randn(100, 3),
...     "velocities": torch.randn(100, 3),
...     "temp": torch.randn(100, 1)
... })
>>> result = transform(data)
>>> print(list(result.keys()))
['positions', 'velocities']
- Raises:
ValueError – If both keep_only and drop_only are specified.
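The two mutually exclusive modes and the strict flag can be sketched on a plain dictionary. This is an illustration under stated assumptions, not the library code: `purge` is a hypothetical function, and raising KeyError for a missing key in strict mode is a guess, since the text above only says that an error is raised.

```python
from typing import Any, Dict, List, Optional


def purge(
    sample: Dict[str, Any],
    keep_only: Optional[List[str]] = None,
    drop_only: Optional[List[str]] = None,
    strict: bool = True,
) -> Dict[str, Any]:
    """Keep or drop keys; the two modes cannot be combined."""
    if keep_only is not None and drop_only is not None:
        raise ValueError("keep_only and drop_only are mutually exclusive")
    requested = keep_only if keep_only is not None else (drop_only or [])
    missing = [k for k in requested if k not in sample]
    if strict and missing:
        # Assumption: the error type for missing keys is not documented.
        raise KeyError(f"keys not found: {missing}")
    if keep_only is not None:
        return {k: v for k, v in sample.items() if k in keep_only}
    drop = set(drop_only or [])
    return {k: v for k, v in sample.items() if k not in drop}


data = {"positions": [1.0], "velocities": [2.0], "temp": [3.0]}

print(list(purge(data, drop_only=["temp"])))                      # ['positions', 'velocities']
print(list(purge(data, keep_only=["positions"])))                 # ['positions']
print(list(purge(data)))                                          # ['positions', 'velocities', 'temp']
print(list(purge(data, drop_only=["debug_info"], strict=False)))  # ['positions', 'velocities', 'temp']
```

Calling `purge(data)` with neither argument returns every key, matching the identity behaviour described for drop_only=None.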
- class physicsnemo.datapipes.transforms.utility.ConstantField(
- reference_key: str,
- output_key: str,
- *,
- fill_value: float = 0.0,
- output_dim: int = 1,
Bases: Transform
Create a tensor filled with a constant value.
Creates a tensor where the first dimension matches a reference tensor and the last dimension is configurable. The tensor is filled with the specified constant value. Useful for creating placeholder tensors like zero SDF values for surface points, or indicator fields.
- Parameters:
reference_key (str) – Key for the tensor to use as shape reference. The first dimension of this tensor determines the number of rows in the output.
output_key (str) – Key to store the constant tensor.
fill_value (float, default=0.0) – The constant value to fill the tensor with.
output_dim (int, default=1) – Feature dimension for output tensor. Creates tensor with shape (N, output_dim) where N is the first dimension of the reference tensor.
Examples
Create zeros (default):
>>> transform = ConstantField(
...     reference_key="positions",
...     output_key="sdf",
...     output_dim=1
... )
>>> data = TensorDict({"positions": torch.randn(10000, 3)})
>>> result = transform(data)
>>> print(result["sdf"].shape)
torch.Size([10000, 1])
>>> print(result["sdf"][0, 0].item())
0.0
Create ones:
>>> transform = ConstantField(
...     reference_key="positions",
...     output_key="mask",
...     fill_value=1.0,
...     output_dim=1
... )
Create custom constant:
>>> transform = ConstantField(
...     reference_key="positions",
...     output_key="temperature",
...     fill_value=293.15,  # Room temperature in Kelvin
...     output_dim=1
... )