# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility transforms for key management and tensor generation.
Provides transforms for renaming keys, removing (purging) keys from TensorDicts,
and creating constant-filled tensors.
"""
from __future__ import annotations
from typing import Optional
import torch
from tensordict import TensorDict
from physicsnemo.datapipes.registry import register
from physicsnemo.datapipes.transforms.base import Transform
[docs]
@register()
class Rename(Transform):
r"""
Rename keys in a TensorDict.
Replaces existing key names with new names according to a mapping.
The tensor data is preserved, only the keys are changed.
Nested tensordicts can use this too. The keys are flattened with a '.'
separator: a["b"]["d"] will map to a["b.d"] for renaming. If you want to
replace d, you'd provide {"a.d" : "a.c"} in the mapping file.
Parameters
----------
mapping : dict[str, str]
Dictionary mapping old key names to new key names.
Keys are the original names, values are the new names.
strict : bool, default=True
If True, raise an error if a key in the mapping is not found
in the data. If False, silently skip missing keys.
Examples
--------
>>> transform = Rename(mapping={"old_name": "new_name", "x": "positions"})
>>> data = TensorDict({
... "old_name": torch.randn(100, 3),
... "x": torch.randn(100, 3),
... "other": torch.randn(100, 1)
... })
>>> result = transform(data)
>>> print(sorted(result.keys()))
['new_name', 'other', 'positions']
"""
def __init__(
self,
mapping: dict[str, str],
*,
strict: bool = True,
) -> None:
"""
Initialize the rename transform.
Parameters
----------
mapping : dict[str, str]
Dictionary mapping old key names to new key names.
Keys are the original names, values are the new names.
strict : bool, default=True
If True, raise an error if a key in the mapping is not found
in the data. If False, silently skip missing keys.
"""
super().__init__()
self.mapping = mapping
self.strict = strict
def __call__(self, data: TensorDict) -> TensorDict:
"""
Rename keys according to the mapping.
Parameters
----------
data : TensorDict
Input TensorDict with keys to rename.
Returns
-------
TensorDict
TensorDict with renamed keys.
Raises
------
KeyError
If strict=True and a key in the mapping is not found.
ValueError
If a new key name already exists in the data.
"""
# Flatten keys to handle nested TensorDicts
data_f = data.flatten_keys(separator=".")
data_keys = set(str(k) for k in data_f.keys())
# Check for missing keys if strict mode
if self.strict:
missing_keys = set(self.mapping.keys()) - data_keys
if missing_keys:
raise KeyError(
f"Keys not found in data: {missing_keys}. "
f"Available keys: {list(data_f.keys())}"
)
# Check for conflicts with new names
keys_to_rename = set(self.mapping.keys()) & data_keys
new_names = {self.mapping[k] for k in keys_to_rename}
keys_not_renamed = data_keys - keys_to_rename
conflicts = new_names & keys_not_renamed
if conflicts:
raise ValueError(f"New key names conflict with existing keys: {conflicts}")
# Rename keys in-place on flattened data
for old_key, new_key in self.mapping.items():
if old_key in data_f.keys():
data_f.rename_key_(old_key, new_key)
return data_f.unflatten_keys(separator=".")
[docs]
@register()
class Purge(Transform):
r"""
Remove keys and their associated tensors from a TensorDict.
Supports two mutually exclusive modes:
- drop_only: Specify keys to remove (keep everything else)
- keep_only: Specify keys to keep (remove everything else)
Only one mode can be active at a time. By default, drop_only=None means
no keys are dropped (identity transform).
Parameters
----------
keep_only : list[str], optional
List of keys to keep. All other keys will be removed.
Cannot be used together with drop_only.
drop_only : list[str], optional
List of keys to remove. All other keys will be kept.
Cannot be used together with keep_only. Default is None
(drop nothing).
strict : bool, default=True
If True, raise an error if a specified key is not found
in the data. If False, silently skip missing keys.
Examples
--------
Drop mode - remove specific keys:
>>> transform = Purge(drop_only=["temp", "debug_info"])
>>> data = TensorDict({
... "positions": torch.randn(100, 3),
... "temp": torch.randn(100, 1),
... "debug_info": torch.randn(100, 10)
... })
>>> result = transform(data)
>>> print(list(result.keys()))
['positions']
Keep mode - keep only specific keys:
>>> transform = Purge(keep_only=["positions", "velocities"])
>>> data = TensorDict({
... "positions": torch.randn(100, 3),
... "velocities": torch.randn(100, 3),
... "temp": torch.randn(100, 1)
... })
>>> result = transform(data)
>>> print(list(result.keys()))
['positions', 'velocities']
Raises
------
ValueError
If both keep_only and drop_only are specified.
"""
def __init__(
self,
*,
keep_only: Optional[list[str]] = None,
drop_only: Optional[list[str]] = None,
strict: bool = True,
) -> None:
"""
Initialize the purge transform.
Parameters
----------
keep_only : list[str], optional
List of keys to keep. All other keys will be removed.
Cannot be used together with drop_only.
drop_only : list[str], optional
List of keys to remove. All other keys will be kept.
Cannot be used together with keep_only. Default is None
(drop nothing).
strict : bool, default=True
If True, raise an error if a specified key is not found
in the data. If False, silently skip missing keys.
Raises
------
ValueError
If both keep_only and drop_only are specified.
"""
super().__init__()
if keep_only is not None and drop_only is not None:
raise ValueError(
"Cannot specify both 'keep_only' and 'drop_only'. "
"Use only one option at a time."
)
self.keep_only = keep_only
self.drop_only = drop_only
self.strict = strict
@staticmethod
def _to_nested_key(key: str) -> str | tuple[str, ...]:
"""
Convert a dot-separated key string to a tuple for nested access.
Parameters
----------
key : str
Key string, possibly with dots for nested access.
Returns
-------
str or tuple[str, ...]
Original string if no dots, otherwise tuple of parts.
"""
if "." in key:
return tuple(key.split("."))
return key
def __call__(self, data: TensorDict) -> TensorDict:
"""
Remove or keep specified keys from the TensorDict.
Parameters
----------
data : TensorDict
Input TensorDict.
Returns
-------
TensorDict
TensorDict with keys removed according to the configuration.
Raises
------
KeyError
If strict=True and a specified key is not found.
"""
# Get all keys including nested ones (as tuples for nested, strings for flat)
available_keys = set(data.keys(include_nested=True, leaves_only=True))
if self.keep_only is not None:
# Keep only mode: use TensorDict.select
# Convert dot-separated strings to tuples for nested key access
keys_to_keep = [self._to_nested_key(k) for k in self.keep_only]
if self.strict:
missing_keys = set(keys_to_keep) - available_keys
if missing_keys:
raise KeyError(
f"Keys specified in 'keep_only' not found in data: {missing_keys}. "
f"Available keys: {list(available_keys)}"
)
# Use select which handles nested keys properly
# strict=False to allow missing keys in non-strict mode
return data.select(*keys_to_keep, strict=self.strict)
elif self.drop_only is not None:
# Drop only mode: use TensorDict.exclude
# Convert dot-separated strings to tuples for nested key access
keys_to_drop = [self._to_nested_key(k) for k in self.drop_only]
if self.strict:
missing_keys = set(keys_to_drop) - available_keys
if missing_keys:
raise KeyError(
f"Keys specified in 'drop_only' not found in data: {missing_keys}. "
f"Available keys: {list(available_keys)}"
)
# Use exclude which handles nested keys properly
return data.exclude(*keys_to_drop)
else:
# Default: drop nothing, keep everything
return data.clone()
[docs]
@register()
class ConstantField(Transform):
r"""
Create a tensor filled with a constant value.
Creates a tensor where the first dimension matches a reference tensor
and the last dimension is configurable. The tensor is filled with the
specified constant value. Useful for creating placeholder tensors like
zero SDF values for surface points, or indicator fields.
Parameters
----------
reference_key : str
Key for the tensor to use as shape reference.
The first dimension of this tensor determines
the number of rows in the output.
output_key : str
Key to store the constant tensor.
fill_value : float, default=0.0
The constant value to fill the tensor with.
output_dim : int, default=1
Feature dimension for output tensor. Creates tensor with
shape (N, output_dim) where N is the first dimension of
the reference tensor.
Examples
--------
Create zeros (default):
>>> transform = ConstantField(
... reference_key="positions",
... output_key="sdf",
... output_dim=1
... )
>>> data = TensorDict({"positions": torch.randn(10000, 3)})
>>> result = transform(data)
>>> print(result["sdf"].shape)
torch.Size([10000, 1])
>>> print(result["sdf"][0, 0].item())
0.0
Create ones:
>>> transform = ConstantField(
... reference_key="positions",
... output_key="mask",
... fill_value=1.0,
... output_dim=1
... )
Create custom constant:
>>> transform = ConstantField(
... reference_key="positions",
... output_key="temperature",
... fill_value=293.15, # Room temperature in Kelvin
... output_dim=1
... )
"""
def __init__(
self,
reference_key: str,
output_key: str,
*,
fill_value: float = 0.0,
output_dim: int = 1,
) -> None:
"""
Initialize the constant field creation transform.
Parameters
----------
reference_key : str
Key for the tensor to use as shape reference.
The first dimension of this tensor determines
the number of rows in the output.
output_key : str
Key to store the constant tensor.
fill_value : float, default=0.0
The constant value to fill the tensor with.
output_dim : int, default=1
Feature dimension for output tensor. Creates tensor with
shape (N, output_dim) where N is the first dimension of
the reference tensor.
"""
super().__init__()
self.reference_key = reference_key
self.output_key = output_key
self.fill_value = fill_value
self.output_dim = output_dim
def __call__(self, data: TensorDict) -> TensorDict:
"""
Create constant-filled tensor matching reference shape.
Parameters
----------
data : TensorDict
Input TensorDict containing the reference tensor.
Returns
-------
TensorDict
TensorDict with the constant tensor added.
Raises
------
KeyError
If the reference key is not found in the data.
"""
if self.reference_key not in data.keys():
raise KeyError(
f"Reference key '{self.reference_key}' not found in data. "
f"Available keys: {list(data.keys())}"
)
reference = data[self.reference_key]
n_points = reference.shape[0]
constant_tensor = torch.full(
(n_points, self.output_dim),
self.fill_value,
dtype=reference.dtype,
device=reference.device,
)
return data.update({self.output_key: constant_tensor})