# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterator, Sequence, TypeVar
import numpy as np
import torch
from jaxtyping import Float
from torch import Tensor
_T = TypeVar("_T")
def _unwrap_module(model: torch.nn.Module, target_cls: type[_T]) -> _T:
"""Peel off DDP / torch.compile wrappers to reach the underlying module.
The unwrapping order handles arbitrary nesting of
``DistributedDataParallel`` (``model.module``) and ``torch.compile``
(``OptimizedModule._orig_mod``).
Parameters
----------
model : torch.nn.Module
Possibly-wrapped module to unwrap.
target_cls : type
The class to unwrap to.
Returns
-------
module : target_cls
The unwrapped module.
Raises
------
TypeError
If ``target_cls`` is not found before reaching a leaf module.
"""
m: Any = model
while not isinstance(m, target_cls):
if isinstance(m, torch._dynamo.eval_frame.OptimizedModule):
m = m._orig_mod
elif hasattr(m, "module"):
m = m.module
else:
raise TypeError(
f"Could not unwrap a {target_cls.__name__} from "
f"{type(model).__name__}. Found leaf type "
f"{type(m).__name__}."
)
return m
def apply_loss_weight(
w: Float[Tensor, " N"] | Float[Tensor, " N C"],
data_ndim: int,
) -> Tensor:
"""Reshape a loss weight tensor for broadcasting with per-element loss.
Handles both scalar weights of shape ``(N,)`` and per-channel weights of
shape ``(N, C)`` produced by schedulers with per-channel ``sigma_data``.
Trailing spatial dimensions are padded with size-1 so that the result
broadcasts correctly against loss tensors of shape ``(B, C, ...)``.
The output shape depends on the input:
- ``(N,)`` -> ``(N, 1, ..., 1)`` with ``data_ndim - 1`` trailing ones
- ``(N, C)`` -> ``(N, C, 1, ..., 1)`` with ``data_ndim - 2`` trailing ones
Parameters
----------
w : Tensor
Loss weight tensor of shape :math:`(N,)` or :math:`(N, C)`.
data_ndim : int
Number of dimensions in the data tensor (e.g. 4 for ``(B, C, H, W)``).
Returns
-------
Tensor
Reshaped weight tensor with trailing size-1 dimensions appended so
that ``w.ndim == data_ndim``.
"""
return w.reshape(*w.shape, *([1] * (data_ndim - w.ndim)))
[docs]
class StackedRandomGenerator:
"""
Wrapper for ``torch.Generator`` that allows specifying a different random
seed for each sample in a minibatch.
Parameters
----------
device : torch.device
Device to use for the random number generator.
seeds : Sequence[int]
Sequence (e.g. list or tuple) of random seeds for each sample in the
minibatch. Its length defines the batch size of generated samples.
"""
def __init__(self, device: torch.device, seeds: Sequence[int]):
super().__init__()
self.generators = [
torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds
]
[docs]
def randn(
self,
size: torch.Size | Sequence[int],
**kwargs: Any,
) -> torch.Tensor:
"""
Generate stacked samples from a standard normal distribution. Each sample is
generated using a different random seed.
Parameters
----------
size : Sequence[int] | torch.Size
Size of the output tensor. Accepts any sequence of integers or a
``torch.Size`` instance. First dimension must match the number of
random seeds.
**kwargs : Any
Additional arguments to pass to ``torch.randn``.
Returns
-------
torch.Tensor
Stacked samples from a standard normal distribution. Shape matches
``size``.
"""
if size[0] != len(self.generators):
raise ValueError(
f"Expected first dimension of size {len(self.generators)}, got {size[0]}"
)
return torch.stack(
[torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators]
)
[docs]
def randt(
self,
nu: int,
size: torch.Size | Sequence[int],
**kwargs: Any,
) -> torch.Tensor:
"""
Generate stacked samples from a standard Student-t distribution with
``nu`` degrees of freedom. This is useful when sampling from heavy-tailed
diffusion models.
Parameters
----------
nu : int
Degrees of freedom for the Student-t distribution. Must be > 2.
size : Sequence[int] | torch.Size
Size of the output tensor. Accepts any sequence of integers or a
``torch.Size`` instance. First dimension must match the number of
random seeds.
**kwargs : Any
Additional arguments to pass to ``torch.randn``.
Returns
-------
torch.Tensor
Stacked samples from a standard Student-t distribution. Shape matches
``size``.
"""
# Size validation
if size[0] != len(self.generators):
raise ValueError(
f"Expected first dimension of size {len(self.generators)}, got {size[0]}"
)
# Validation for nu
if nu <= 2:
raise ValueError(f"Expected nu > 2, but got {nu}.")
# Generate samples from Student-t distribution
chi_dist = torch.distributions.Chi2(nu)
kappa = (
(chi_dist.sample((len(self.generators),)) / nu)
.view(-1, *([1] * len(size[1:])))
.to(self.generators[0].device)
)
eps = torch.stack(
[torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators]
)
return eps / torch.sqrt(kappa)
[docs]
def randn_like(self, input: torch.Tensor) -> torch.Tensor:
"""
Generate stacked samples from a standard normal distribution with the same
shape and data type as the input tensor.
Parameters
----------
input : torch.Tensor
Input tensor to match the shape, data type, memory layout, and
device of.
Returns
-------
torch.Tensor
Stacked samples from a standard normal distribution. Shape matches
``input.shape``.
"""
return self.randn(
input.shape, dtype=input.dtype, layout=input.layout, device=input.device
)
[docs]
def randint(
self,
*args: Any,
size: torch.Size | Sequence[int],
**kwargs: Any,
) -> torch.Tensor:
"""
Generate stacked samples from a uniform distribution over the integers.
Parameters
----------
*args : Any
Required positional arguments to pass to ``torch.randint``.
size : Sequence[int] | torch.Size
Size of the output tensor. Accepts any sequence of integers or a
``torch.Size`` instance. First dimension must match the number of
random seeds.
**kwargs : Any
Additional keyword arguments to pass to ``torch.randint``.
Returns
-------
torch.Tensor
Stacked samples from a uniform distribution over the integers. Shape
matches ``size``.
"""
if size[0] != len(self.generators):
raise ValueError(
f"Expected first dimension of size {len(self.generators)}, got {size[0]}"
)
return torch.stack(
[
torch.randint(*args, size=size[1:], generator=gen, **kwargs)
for gen in self.generators
]
)
[docs]
class InfiniteSampler(torch.utils.data.Sampler[int]):
"""Sampler for torch.utils.data.DataLoader that loops over the dataset indefinitely.
This sampler yields indices indefinitely, optionally shuffling items as it goes.
It can also perform distributed sampling when `rank` and `num_replicas` are
specified.
Parameters
----------
dataset : torch.utils.data.Dataset
The dataset to sample from
rank : int, default=0
The rank of the current process within num_replicas processes
num_replicas : int, default=1
The number of processes participating in distributed sampling
shuffle : bool, default=True
Whether to shuffle the indices
seed : int, default=0
Random seed for reproducibility when shuffling
window_size : float, default=0.5
Fraction of dataset to use as window for shuffling. Must be between 0 and 1.
A larger window means more thorough shuffling but slower iteration.
start_idx : int, default=0
The initial index to use for the sampler. This is used for resuming training.
"""
def __init__(
self,
dataset: torch.utils.data.Dataset,
rank: int = 0,
num_replicas: int = 1,
shuffle: bool = True,
seed: int = 0,
window_size: float = 0.5,
start_idx: int = 0,
):
if not len(dataset) > 0:
raise ValueError("Dataset must contain at least one item")
if not num_replicas > 0:
raise ValueError("num_replicas must be positive")
if not 0 <= rank < num_replicas:
raise ValueError("rank must be non-negative and less than num_replicas")
if not 0 <= window_size <= 1:
raise ValueError("window_size must be between 0 and 1")
super().__init__()
self.dataset = dataset
self.rank = rank
self.num_replicas = num_replicas
self.shuffle = shuffle
self.seed = seed
self.window_size = window_size
self.start_idx = start_idx
def __iter__(self) -> Iterator[int]:
order = np.arange(len(self.dataset))
rnd = None
window = 0
if self.shuffle:
rnd = np.random.RandomState(self.seed)
rnd.shuffle(order)
window = int(np.rint(order.size * self.window_size))
idx = self.start_idx
while True:
i = idx % order.size
if idx % self.num_replicas == self.rank:
yield order[i]
if window >= 2:
j = (i - rnd.randint(window)) % order.size
order[i], order[j] = order[j], order[i]
idx += 1