Source code for physicsnemo.diffusion.utils.utils

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Iterator, Sequence

import numpy as np
import torch


[docs] class StackedRandomGenerator: """ Wrapper for ``torch.Generator`` that allows specifying a different random seed for each sample in a minibatch. Parameters ---------- device : torch.device Device to use for the random number generator. seeds : Sequence[int] Sequence (e.g. list or tuple) of random seeds for each sample in the minibatch. Its length defines the batch size of generated samples. """ def __init__(self, device: torch.device, seeds: Sequence[int]): super().__init__() self.generators = [ torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds ]
[docs] def randn( self, size: torch.Size | Sequence[int], **kwargs: Any, ) -> torch.Tensor: """ Generate stacked samples from a standard normal distribution. Each sample is generated using a different random seed. Parameters ---------- size : Sequence[int] | torch.Size Size of the output tensor. Accepts any sequence of integers or a ``torch.Size`` instance. First dimension must match the number of random seeds. **kwargs : Any Additional arguments to pass to ``torch.randn``. Returns ------- torch.Tensor Stacked samples from a standard normal distribution. Shape matches ``size``. """ if size[0] != len(self.generators): raise ValueError( f"Expected first dimension of size {len(self.generators)}, got {size[0]}" ) return torch.stack( [torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators] )
[docs] def randt( self, nu: int, size: torch.Size | Sequence[int], **kwargs: Any, ) -> torch.Tensor: """ Generate stacked samples from a standard Student-t distribution with ``nu`` degrees of freedom. This is useful when sampling from heavy-tailed diffusion models. Parameters ---------- nu : int Degrees of freedom for the Student-t distribution. Must be > 2. size : Sequence[int] | torch.Size Size of the output tensor. Accepts any sequence of integers or a ``torch.Size`` instance. First dimension must match the number of random seeds. **kwargs : Any Additional arguments to pass to ``torch.randn``. Returns ------- torch.Tensor Stacked samples from a standard Student-t distribution. Shape matches ``size``. """ # Size validation if size[0] != len(self.generators): raise ValueError( f"Expected first dimension of size {len(self.generators)}, got {size[0]}" ) # Validation for nu if nu <= 2: raise ValueError(f"Expected nu > 2, but got {nu}.") # Generate samples from Student-t distribution chi_dist = torch.distributions.Chi2(nu) kappa = ( (chi_dist.sample((len(self.generators),)) / nu) .view(-1, *([1] * len(size[1:]))) .to(self.generators[0].device) ) eps = torch.stack( [torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators] ) return eps / torch.sqrt(kappa)
[docs] def randn_like(self, input: torch.Tensor) -> torch.Tensor: """ Generate stacked samples from a standard normal distribution with the same shape and data type as the input tensor. Parameters ---------- input : torch.Tensor Input tensor to match the shape, data type, memory layout, and device of. Returns ------- torch.Tensor Stacked samples from a standard normal distribution. Shape matches ``input.shape``. """ return self.randn( input.shape, dtype=input.dtype, layout=input.layout, device=input.device )
[docs] def randint( self, *args: Any, size: torch.Size | Sequence[int], **kwargs: Any, ) -> torch.Tensor: """ Generate stacked samples from a uniform distribution over the integers. Parameters ---------- *args : Any Required positional arguments to pass to ``torch.randint``. size : Sequence[int] | torch.Size Size of the output tensor. Accepts any sequence of integers or a ``torch.Size`` instance. First dimension must match the number of random seeds. **kwargs : Any Additional keyword arguments to pass to ``torch.randint``. Returns ------- torch.Tensor Stacked samples from a uniform distribution over the integers. Shape matches ``size``. """ if size[0] != len(self.generators): raise ValueError( f"Expected first dimension of size {len(self.generators)}, got {size[0]}" ) return torch.stack( [ torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators ] )
[docs] class InfiniteSampler(torch.utils.data.Sampler[int]): """Sampler for torch.utils.data.DataLoader that loops over the dataset indefinitely. This sampler yields indices indefinitely, optionally shuffling items as it goes. It can also perform distributed sampling when `rank` and `num_replicas` are specified. Parameters ---------- dataset : torch.utils.data.Dataset The dataset to sample from rank : int, default=0 The rank of the current process within num_replicas processes num_replicas : int, default=1 The number of processes participating in distributed sampling shuffle : bool, default=True Whether to shuffle the indices seed : int, default=0 Random seed for reproducibility when shuffling window_size : float, default=0.5 Fraction of dataset to use as window for shuffling. Must be between 0 and 1. A larger window means more thorough shuffling but slower iteration. start_idx : int, default=0 The initial index to use for the sampler. This is used for resuming training. """ def __init__( self, dataset: torch.utils.data.Dataset, rank: int = 0, num_replicas: int = 1, shuffle: bool = True, seed: int = 0, window_size: float = 0.5, start_idx: int = 0, ): if not len(dataset) > 0: raise ValueError("Dataset must contain at least one item") if not num_replicas > 0: raise ValueError("num_replicas must be positive") if not 0 <= rank < num_replicas: raise ValueError("rank must be non-negative and less than num_replicas") if not 0 <= window_size <= 1: raise ValueError("window_size must be between 0 and 1") super().__init__() self.dataset = dataset self.rank = rank self.num_replicas = num_replicas self.shuffle = shuffle self.seed = seed self.window_size = window_size self.start_idx = start_idx def __iter__(self) -> Iterator[int]: order = np.arange(len(self.dataset)) rnd = None window = 0 if self.shuffle: rnd = np.random.RandomState(self.seed) rnd.shuffle(order) window = int(np.rint(order.size * self.window_size)) idx = self.start_idx while True: i = idx % order.size if idx % self.num_replicas == self.rank: yield order[i] if window >= 2: j = (i - rnd.randint(window)) % order.size order[i], order[j] = order[j], order[i] idx += 1