# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Noise schedulers for diffusion models."""
import math
from abc import ABC, abstractmethod
from typing import Any, Literal, Protocol, Tuple, runtime_checkable
import numpy as np
import torch
from jaxtyping import Float
from torch import Tensor
from physicsnemo.diffusion.base import Denoiser, Predictor
[docs]
@runtime_checkable
class NoiseScheduler(Protocol):
    r"""
    Protocol defining the minimal interface for noise schedulers.

    A noise scheduler defines methods for training (adding noise, sampling
    diffusion time, weighting the loss) and for sampling (generating
    diffusion time-steps, initializing latent state, obtaining a denoiser).
    This interface is generic and does not assume any specific form of noise
    schedule.

    Any object that implements this interface can be used with the diffusion
    training and sampling utilities. The protocol is decorated with
    ``runtime_checkable``, so ``isinstance(obj, NoiseScheduler)`` works, but
    such a check only verifies that the methods exist; signatures and return
    types are not validated.

    **Training methods:**

    - :meth:`sample_time`: Sample diffusion time values for training
    - :meth:`add_noise`: Add noise to clean data at given diffusion time
    - :meth:`loss_weight`: Compute per-sample loss weight for training

    **Sampling methods:**

    - :meth:`timesteps`: Generate discrete time-steps for sampling
    - :meth:`init_latents`: Initialize noisy latent state :math:`\mathbf{x}_N`
    - :meth:`get_denoiser`: Convert a predictor (e.g. a model that predicts
      clean data, score, etc.) to a sampling-compatible denoiser

    See Also
    --------
    :class:`LinearGaussianNoiseScheduler` : base abstract class for
        linear-Gaussian schedules. Implements the NoiseScheduler protocol.
    :func:`~physicsnemo.diffusion.samplers.sample` : sampling function for
        generating data samples from a diffusion model.

    Examples
    --------
    >>> import torch
    >>> from physicsnemo.diffusion.noise_schedulers import NoiseScheduler
    >>>
    >>> class MyScheduler:
    ...     def sample_time(self, N, *, device=None, dtype=None):
    ...         return torch.rand(N, device=device, dtype=dtype)
    ...     def add_noise(self, x0, time):
    ...         return x0 + time.view(-1, 1) * torch.randn_like(x0)
    ...     def timesteps(self, num_steps, *, device=None, dtype=None):
    ...         return torch.linspace(
    ...             1, 0, num_steps + 1, device=device, dtype=dtype)
    ...     def init_latents(self, spatial_shape, tN, *, device=None, dtype=None):
    ...         return torch.randn(
    ...             tN.shape[0], *spatial_shape, device=device, dtype=dtype)
    ...     def get_denoiser(self, x0_predictor=None, score_predictor=None, **kwargs):
    ...         def denoiser(x, t):
    ...             if x0_predictor is not None:
    ...                 return (x - x0_predictor(x, t)) / (t.view(-1, 1))
    ...             elif score_predictor is not None:
    ...                 return -score_predictor(x, t) * t.view(-1, 1)
    ...         return denoiser
    ...     def loss_weight(self, t):
    ...         return 1 / t**2
    ...
    >>> scheduler = MyScheduler()
    >>> isinstance(scheduler, NoiseScheduler)
    True
    """
[docs]
def sample_time(
    self,
    N: int,
    *,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> Float[Tensor, " N"]:
    r"""
    Draw ``N`` random diffusion times for a training step.

    Training code calls this to choose the diffusion time of each sample,
    typically inside the denoising score matching loss.

    Parameters
    ----------
    N : int
        How many diffusion times to draw.
    device : torch.device, optional
        Device on which the returned tensor should live.
    dtype : torch.dtype, optional
        Data type of the returned tensor.

    Returns
    -------
    Tensor
        The sampled diffusion times, shape :math:`(N,)`.
    """
    # Interface stub: implementations supply the actual sampling rule.
    ...
[docs]
def add_noise(
    self,
    x0: Float[Tensor, " B *dims"],
    time: Float[Tensor, " B"],
) -> Float[Tensor, " B *dims"]:
    r"""
    Corrupt clean data with noise at the requested diffusion times.

    Training code uses this to turn a clean batch into the noisy batch that
    is fed to the model.

    Parameters
    ----------
    x0 : Tensor
        Clean latent state, shape :math:`(B, *)`.
    time : Tensor
        One diffusion time per batch element, shape :math:`(B,)`.

    Returns
    -------
    Tensor
        Noisy latent state, shape :math:`(B, *)`.
    """
    # Interface stub: the forward (noising) process is implementation-defined.
    ...
[docs]
def timesteps(
    self,
    num_steps: int,
    *,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> Float[Tensor, " N+1"]:
    r"""
    Build the discrete sequence of diffusion times visited while sampling.

    Parameters
    ----------
    num_steps : int
        Number of sampling steps :math:`N`.
    device : torch.device, optional
        Device on which the returned tensor should live.
    dtype : torch.dtype, optional
        Data type of the returned tensor.

    Returns
    -------
    Tensor
        Time-steps of shape :math:`(N + 1,)`, sorted in decreasing order and
        terminated by 0.
    """
    # Interface stub: the spacing of the time-steps is implementation-defined.
    ...
[docs]
def init_latents(
    self,
    spatial_shape: Tuple[int, ...],
    tN: Float[Tensor, " B"],
    *,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> Float[Tensor, " B *spatial_shape"]:
    r"""
    Create the starting noisy latent :math:`\mathbf{x}_N` of a sampling run.

    The sampler calls this once to obtain the initial condition at diffusion
    time ``tN``.

    Parameters
    ----------
    spatial_shape : Tuple[int, ...]
        Per-sample shape of the latent state, e.g. ``(C, H, W)``.
    tN : Tensor
        Initial diffusion time of each batch element, shape :math:`(B,)`.
        It sets the noise level of the initial latent state.
    device : torch.device, optional
        Device on which the returned tensor should live.
    dtype : torch.dtype, optional
        Data type of the returned tensor.

    Returns
    -------
    Tensor
        Initial noisy latent state, shape :math:`(B, *spatial\_shape)`.
    """
    # Interface stub: the initial distribution is implementation-defined.
    ...
[docs]
def get_denoiser(
    self,
    **kwargs: Any,
) -> Denoiser:
    r"""
    Build a sampling denoiser out of a predictor.

    The sampler uses this factory to turn a :class:`Predictor` (for example
    an x0-predictor or a score-predictor) into a :class:`Denoiser` whose
    output is the update term expected by the solver. How the conversion is
    carried out is up to the concrete noise scheduler.

    Parameters
    ----------
    **kwargs : Any
        Implementation-specific keyword arguments. Concrete schedulers
        typically take the predictor through a keyword-only argument such
        as ``score_predictor`` or ``x0_predictor``; see the docstring of the
        concrete class (e.g.
        :meth:`LinearGaussianNoiseScheduler.get_denoiser`).

    Returns
    -------
    Denoiser
        A callable implementing the
        :class:`~physicsnemo.diffusion.Denoiser` interface, suitable for
        solvers and for
        :func:`~physicsnemo.diffusion.samplers.sample`.
    """
    # Interface stub: the predictor-to-denoiser mapping is
    # implementation-defined.
    ...
[docs]
def loss_weight(
    self,
    t: Float[Tensor, " N"],
) -> Float[Tensor, " N"] | Float[Tensor, " N C"]:
    r"""
    Return the weight applied to the denoising score matching loss.

    Training code uses the result to weight the per-sample loss in
    :class:`~physicsnemo.diffusion.metrics.losses.MSEDSMLoss`.

    Parameters
    ----------
    t : Tensor
        Diffusion times, shape :math:`(N,)`.

    Returns
    -------
    Tensor
        Loss weight whose leading dimension is :math:`N`: shape
        :math:`(N,)` for a scalar ``sigma_data``, or :math:`(N, C)` when
        the scheduler uses a per-channel ``sigma_data`` (see
        :class:`EDMNoiseScheduler`).
    """
    # Interface stub: the weighting rule is implementation-defined.
    ...
[docs]
class LinearGaussianNoiseScheduler(ABC, NoiseScheduler):
    r"""
    Abstract base class for linear-Gaussian noise schedules.

    It implements the :class:`NoiseScheduler` interface and it can be
    subclassed to define custom linear-Gaussian noise schedules of the form:

    .. math::
        \mathbf{x}(t) = \alpha(t) \mathbf{x}_0
        + \sigma(t) \boldsymbol{\epsilon}

    where :math:`\boldsymbol{\epsilon} \sim \mathcal{N}(0, \mathbf{I})` is
    standard Gaussian noise, :math:`\alpha(t)` is the signal coefficient, and
    :math:`\sigma(t)` is the noise level.

    **Training:**

    The :meth:`add_noise` method implements the forward diffusion process using
    the formula above. The :meth:`sample_time` method samples diffusion times.

    **Sampling:**

    For ODE-based sampling, the reverse process follows the probability flow
    ODE:

    .. math::
        \frac{d\mathbf{x}}{dt} = f(\mathbf{x}, t)
        - \frac{1}{2} g^2(\mathbf{x}, t) \nabla_{\mathbf{x}} \log p(\mathbf{x})

    For SDE-based sampling:

    .. math::
        d\mathbf{x} = \left[ f(\mathbf{x}, t)
        - g^2(\mathbf{x}, t) \nabla_{\mathbf{x}} \log p(\mathbf{x}) \right] dt
        + g(\mathbf{x}, t) d\mathbf{W}

    The :meth:`get_denoiser` factory converts a predictor (a score-predictor,
    an x0-predictor, or an epsilon-predictor) into the appropriate ODE/SDE
    right-hand side.

    **Abstract methods (must be implemented by subclasses):**

    - :meth:`sigma`: Map time to noise level :math:`\sigma(t)`
    - :meth:`sigma_inv`: Map noise level back to time
    - :meth:`sigma_dot`: Time derivative :math:`\dot{\sigma}(t)`
    - :meth:`alpha`: Compute the signal coefficient :math:`\alpha(t)`
    - :meth:`alpha_dot`: Time derivative :math:`\dot{\alpha}(t)`
    - :meth:`timesteps`: Generate discrete time-steps for sampling
    - :meth:`sample_time`: Sample diffusion times for training
    - :meth:`loss_weight`: Compute loss weight for training

    **Concrete methods (have default implementations, but can be overridden for
    custom behavior):**

    - :meth:`drift`: Drift term :math:`f(\mathbf{x}, t)` for ODE/SDE
    - :meth:`diffusion`: Squared diffusion term :math:`g^2(\mathbf{x}, t)`
    - :meth:`x0_to_score`: Convert x0-prediction to score
    - :meth:`score_to_x0`: Convert score to x0-prediction
    - :meth:`epsilon_to_score`: Convert epsilon-prediction to score
    - :meth:`score_to_epsilon`: Convert score to epsilon-prediction
    - :meth:`epsilon_to_x0`: Convert epsilon-prediction to x0-prediction
    - :meth:`x0_to_epsilon`: Convert x0-prediction to epsilon-prediction
    - :meth:`add_noise`: Add noise to clean data (training)
    - :meth:`init_latents`: Initialize latent state (sampling)
    - :meth:`get_denoiser`: Get ODE/SDE RHS (sampling)

    Examples
    --------
    **Example 1:** A minimal EDM-like noise schedule. Only the abstract methods
    need to be implemented since defaults work for EDM:

    >>> import torch
    >>> from physicsnemo.diffusion.noise_schedulers import (
    ...     LinearGaussianNoiseScheduler,
    ... )
    >>>
    >>> class SimpleEDMScheduler(LinearGaussianNoiseScheduler):
    ...     def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    ...         self.sigma_min = sigma_min
    ...         self.sigma_max = sigma_max
    ...         self.rho = rho
    ...
    ...     def sigma(self, t): return t
    ...     def sigma_inv(self, sigma): return sigma
    ...     def sigma_dot(self, t): return torch.ones_like(t)
    ...     def alpha(self, t): return torch.ones_like(t)
    ...     def alpha_dot(self, t): return torch.zeros_like(t)
    ...
    ...     def timesteps(self, num_steps, *, device=None, dtype=None):
    ...         i = torch.arange(num_steps, device=device, dtype=dtype)
    ...         smax_rho = self.sigma_max**(1/self.rho)
    ...         smin_rho = self.sigma_min**(1/self.rho)
    ...         # max(..., 1) avoids a 0/0 when a single step is requested
    ...         frac = i / max(num_steps - 1, 1)
    ...         t = (smax_rho + frac * (smin_rho - smax_rho))**self.rho
    ...         return torch.cat(
    ...             [t, torch.zeros(1, device=device, dtype=dtype)])
    ...
    ...     def sample_time(self, N, *, device=None, dtype=None):
    ...         u = torch.rand(N, device=device, dtype=dtype)
    ...         return self.sigma_min * (self.sigma_max/self.sigma_min)**u
    ...     def loss_weight(self, t):
    ...         return 1 / t**2
    ...
    >>> scheduler = SimpleEDMScheduler()
    >>> t_steps = scheduler.timesteps(10)
    >>> t_steps.shape
    torch.Size([11])

    **Example 2:** Customizing behavior by overriding concrete methods. This
    shows how to override the drift term for a custom diffusion process:

    >>> class CustomDriftScheduler(SimpleEDMScheduler):
    ...     def drift(self, x, t):
    ...         # Custom drift: f(x, t) = -0.5 * x (Ornstein-Uhlenbeck style)
    ...         return -0.5 * x
    ...
    >>> custom = CustomDriftScheduler()
    >>>
    >>> # The custom drift is used internally by get_denoiser
    >>> score_pred = lambda x, t: -x / (1 + t.view(-1, 1)**2)  # Toy score predictor
    >>> denoiser = custom.get_denoiser(score_predictor=score_pred)
    >>> x = torch.randn(2, 4)
    >>> t = torch.tensor([1.0, 1.0])
    >>> out = denoiser(x, t)  # Uses custom drift in ODE RHS computation
    >>> out.shape
    torch.Size([2, 4])
    """
[docs]
@abstractmethod
def sigma(
    self,
    t: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
    r"""
    Noise level :math:`\sigma(t)` at diffusion time ``t``.

    Needed both for training and for sampling.

    Parameters
    ----------
    t : Tensor
        Diffusion times; any shape is accepted.

    Returns
    -------
    Tensor
        :math:`\sigma(t)`, with the same shape as ``t``.
    """
    # Abstract: every concrete schedule defines its own sigma(t).
    ...
[docs]
@abstractmethod
def sigma_inv(
    self,
    sigma: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
    r"""
    Diffusion time that corresponds to a given noise level.

    Inverse mapping of :meth:`sigma`. Needed both for training and for
    sampling.

    Parameters
    ----------
    sigma : Tensor
        Noise levels; any shape is accepted.

    Returns
    -------
    Tensor
        Diffusion times, with the same shape as ``sigma``.
    """
    # Abstract: every concrete schedule defines its own inverse mapping.
    ...
[docs]
@abstractmethod
def sigma_dot(
    self,
    t: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
    r"""
    Time derivative of the noise level, :math:`\dot{\sigma}(t)`.

    Needed for sampling.

    Parameters
    ----------
    t : Tensor
        Diffusion times; any shape is accepted.

    Returns
    -------
    Tensor
        :math:`\dot{\sigma}(t)`, with the same shape as ``t``.
    """
    # Abstract: every concrete schedule defines its own derivative.
    ...
[docs]
@abstractmethod
def alpha(
    self,
    t: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
    r"""
    Signal coefficient :math:`\alpha(t)` at diffusion time ``t``.

    Needed both for training and for sampling.

    Parameters
    ----------
    t : Tensor
        Diffusion times; any shape is accepted.

    Returns
    -------
    Tensor
        :math:`\alpha(t)`, with the same shape as ``t``.
    """
    # Abstract: every concrete schedule defines its own alpha(t).
    ...
[docs]
@abstractmethod
def alpha_dot(
    self,
    t: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
    r"""
    Time derivative of the signal coefficient, :math:`\dot{\alpha}(t)`.

    Needed for sampling.

    Parameters
    ----------
    t : Tensor
        Diffusion times; any shape is accepted.

    Returns
    -------
    Tensor
        :math:`\dot{\alpha}(t)`, with the same shape as ``t``.
    """
    # Abstract: every concrete schedule defines its own derivative.
    ...
[docs]
@abstractmethod
def timesteps(
    self,
    num_steps: int,
    *,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> Float[Tensor, " N+1"]:
    r"""
    Build the discrete sequence of diffusion times visited while sampling.

    The result has shape :math:`(N + 1,)`, is sorted in decreasing order,
    and ends with 0.

    Parameters
    ----------
    num_steps : int
        Number of sampling steps :math:`N`.
    device : torch.device, optional
        Device on which the returned tensor should live.
    dtype : torch.dtype, optional
        Data type of the returned tensor.

    Returns
    -------
    Tensor
        Time-steps of shape :math:`(N + 1,)`.
    """
    # Abstract: the spacing of the time-steps is schedule-specific.
    ...
[docs]
@abstractmethod
def sample_time(
    self,
    N: int,
    *,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> Float[Tensor, " N"]:
    r"""
    Draw ``N`` random diffusion times for a training step.

    Training code calls this to choose the diffusion times used in the
    denoising score matching loss.

    Parameters
    ----------
    N : int
        How many diffusion times to draw.
    device : torch.device, optional
        Device on which the returned tensor should live.
    dtype : torch.dtype, optional
        Data type of the returned tensor.

    Returns
    -------
    Tensor
        The sampled diffusion times, shape :math:`(N,)`.
    """
    # Abstract: the training-time distribution is schedule-specific.
    ...
[docs]
@abstractmethod
def loss_weight(
    self,
    t: Float[Tensor, " N"],
) -> Float[Tensor, " N"] | Float[Tensor, " N C"]:
    r"""
    Return the weight applied to the denoising score matching loss.

    Training code uses the result to weight the per-sample loss in
    :class:`~physicsnemo.diffusion.metrics.losses.MSEDSMLoss`. The weight
    is meant for training an x0-predictor (clean data predictor).

    To train a score-predictor instead, additionally pass a
    ``score_to_x0_fn`` callback to
    :class:`~physicsnemo.diffusion.metrics.losses.MSEDSMLoss`.

    Parameters
    ----------
    t : Tensor
        Diffusion times, shape :math:`(N,)`.

    Returns
    -------
    Tensor
        Loss weight whose leading dimension is :math:`N`: shape
        :math:`(N,)` for a scalar ``sigma_data``, or :math:`(N, C)` when
        the scheduler uses a per-channel ``sigma_data`` (see
        :class:`EDMNoiseScheduler`).
    """
    # Abstract: the weighting rule is schedule-specific.
    ...
[docs]
def drift(
    self,
    x: Float[Tensor, " B *dims"],
    t: Float[Tensor, " B"],
) -> Float[Tensor, " B *dims"]:
    r"""
    Evaluate the drift :math:`f(\mathbf{x}, t)` used in ODE/SDE sampling.

    :meth:`get_denoiser` calls this while assembling the ODE/SDE right-hand
    side. The default is

    .. math::
        f(\mathbf{x}, t) = \frac{\dot{\alpha}(t)}{\alpha(t)} \mathbf{x}

    Override this method to use a different drift term.

    Parameters
    ----------
    x : Tensor
        Latent state, shape :math:`(B, *)`.
    t : Tensor
        Diffusion time, shape :math:`(B,)`.

    Returns
    -------
    Tensor
        Drift term, same shape as ``x``.
    """
    # Keep the batch axis and append one singleton axis per trailing axis of
    # ``x`` so the schedule coefficients broadcast against ``x``.
    t_col = t.reshape(-1, *([1] * (x.ndim - 1)))
    signal = self.alpha(t_col)
    signal_rate = self.alpha_dot(t_col)
    return (signal_rate / signal) * x
[docs]
def diffusion(
    self,
    x: Float[Tensor, " B *dims"],
    t: Float[Tensor, " B"],
) -> Float[Tensor, " B *_"]:
    r"""
    Evaluate the squared diffusion coefficient :math:`g^2(\mathbf{x}, t)`.

    :meth:`get_denoiser` calls this while assembling the ODE/SDE right-hand
    side. The default is

    .. math::
        g^2 = 2 \dot{\sigma} \sigma
        - 2 \frac{\dot{\alpha}}{\alpha} \sigma^2

    Override this method to use a different diffusion term.

    Parameters
    ----------
    x : Tensor
        Latent state, shape :math:`(B, *)`. Only its number of dimensions
        is used, to shape the result for broadcasting.
    t : Tensor
        Diffusion time, shape :math:`(B,)`.

    Returns
    -------
    Tensor
        Squared diffusion term, broadcastable to the shape of ``x``.
    """
    # Batch axis first, then one singleton axis per trailing axis of ``x``.
    t_col = t.reshape([-1] + [1] * (x.ndim - 1))
    noise_level = self.sigma(t_col)
    noise_rate = self.sigma_dot(t_col)
    signal = self.alpha(t_col)
    signal_rate = self.alpha_dot(t_col)
    growth_term = 2 * noise_rate * noise_level
    decay_term = 2 * (signal_rate / signal) * noise_level**2
    return growth_term - decay_term
[docs]
def x0_to_score(
    self,
    x0: Float[Tensor, " B *dims"],
    x_t: Float[Tensor, " B *dims"],
    t: Float[Tensor, " B"],
) -> Float[Tensor, " B *dims"]:
    r"""
    Turn an x0-prediction into the corresponding score.

    :meth:`get_denoiser` applies this conversion automatically when it is
    given an ``x0_predictor``; it can also be called directly. The score is

    .. math::
        \nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t)
        = \frac{\alpha(t) \hat{\mathbf{x}}_0 - \mathbf{x}_t}{\sigma^2(t)}

    This helper usually does not need to be overridden in subclasses.

    Parameters
    ----------
    x0 : Tensor
        Predicted clean data :math:`\hat{\mathbf{x}}_0`, shape
        :math:`(B, *)`.
    x_t : Tensor
        Current noisy state :math:`\mathbf{x}_t`, shape :math:`(B, *)`.
    t : Tensor
        Diffusion time, shape :math:`(B,)`.

    Returns
    -------
    Tensor
        Score, same shape as ``x0``.

    Examples
    --------
    >>> scheduler = EDMNoiseScheduler()
    >>> # Manual conversion of an x0-predictor into a score-predictor
    >>> # (get_denoiser does this automatically):
    >>> def x0_predictor(x, t):
    ...     t_bc = t.view(-1, *([1] * (x.ndim - 1)))
    ...     return x / (1 + t_bc**2)
    >>> def score_predictor(x, t):
    ...     x0_pred = x0_predictor(x, t)
    ...     return scheduler.x0_to_score(x0_pred, x, t)
    >>> # Or simply: scheduler.get_denoiser(x0_predictor=x0_predictor)
    """
    # Shape ``t`` as (B, 1, ..., 1) so the coefficients broadcast over ``x0``.
    t_col = t.reshape([-1] + [1] * (x0.ndim - 1))
    signal = self.alpha(t_col)
    noise_level = self.sigma(t_col)
    residual = signal * x0 - x_t
    return residual / (noise_level**2)
[docs]
def score_to_x0(
    self,
    score: Float[Tensor, " B *dims"],
    x_t: Float[Tensor, " B *dims"],
    t: Float[Tensor, " B"],
) -> Float[Tensor, " B *dims"]:
    r"""
    Turn a score prediction into the corresponding x0-prediction.

    Inverse of :meth:`x0_to_score`. From a score prediction
    :math:`s(\mathbf{x}_t, t)` and the noisy state :math:`\mathbf{x}_t`, the
    clean-data estimate is recovered as

    .. math::
        \hat{\mathbf{x}}_0 = \frac{\mathbf{x}_t
        + \sigma^2(t) \nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t)}
        {\alpha(t)}

    A common use is training a score-predictor with
    :class:`~physicsnemo.diffusion.metrics.losses.MSEDSMLoss`: pass this
    method as its ``score_to_x0_fn`` argument together with
    ``prediction_type="score"``.

    This helper usually does not need to be overridden in subclasses.

    Parameters
    ----------
    score : Tensor
        Predicted score :math:`s(\mathbf{x}_t, t)`, shape :math:`(B, *)`.
    x_t : Tensor
        Current noisy state :math:`\mathbf{x}_t`, same shape as ``score``.
    t : Tensor
        Diffusion time, shape :math:`(B,)`.

    Returns
    -------
    Tensor
        Estimated clean data :math:`\hat{\mathbf{x}}_0`, same shape as
        ``score``.

    Examples
    --------
    >>> scheduler = EDMNoiseScheduler()
    >>> # Convert the output of a score-predictor to x0 for the DSM loss:
    >>> def score_predictor(x, t):
    ...     return -x / (1 + t.view(-1, *([1] * (x.ndim - 1)))**2)
    >>> x_t = torch.randn(2, 4)
    >>> t = torch.tensor([1.0, 1.0])
    >>> score = score_predictor(x_t, t)
    >>> x0_est = scheduler.score_to_x0(score, x_t, t)
    >>> x0_est.shape
    torch.Size([2, 4])
    """
    # Shape ``t`` as (B, 1, ..., 1) so the coefficients broadcast over ``score``.
    t_col = t.reshape(-1, *([1] * (score.ndim - 1)))
    signal = self.alpha(t_col)
    noise_level = self.sigma(t_col)
    scaled_mean = x_t + noise_level**2 * score
    return scaled_mean / signal
[docs]
def epsilon_to_score(
    self,
    epsilon: Float[Tensor, " B *dims"],
    t: Float[Tensor, " B"],
) -> Float[Tensor, " B *dims"]:
    r"""
    Turn an epsilon (noise) prediction into the corresponding score.

    Under the linear-Gaussian forward process
    :math:`\mathbf{x}_t = \alpha(t)\mathbf{x}_0
    + \sigma(t)\boldsymbol{\epsilon}`, score and epsilon are related by

    .. math::
        \nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t)
        = -\frac{\boldsymbol{\epsilon}}{\sigma(t)}

    Parameters
    ----------
    epsilon : Tensor
        Predicted noise :math:`\hat{\boldsymbol{\epsilon}}`, shape
        :math:`(B, *)`.
    t : Tensor
        Diffusion time, shape :math:`(B,)`.

    Returns
    -------
    Tensor
        Score, same shape as ``epsilon``.

    Examples
    --------
    >>> scheduler = EDMNoiseScheduler()
    >>> eps = torch.randn(2, 4)
    >>> t = torch.tensor([1.0, 1.0])
    >>> score = scheduler.epsilon_to_score(eps, t)
    >>> score.shape
    torch.Size([2, 4])
    """
    # Shape ``t`` as (B, 1, ..., 1) so sigma broadcasts over ``epsilon``.
    t_col = t.reshape([-1] + [1] * (epsilon.ndim - 1))
    noise_level = self.sigma(t_col)
    return -epsilon / noise_level
[docs]
def score_to_epsilon(
    self,
    score: Float[Tensor, " B *dims"],
    t: Float[Tensor, " B"],
) -> Float[Tensor, " B *dims"]:
    r"""
    Turn a score into the corresponding epsilon (noise) prediction.

    Inverse of :meth:`epsilon_to_score`:

    .. math::
        \boldsymbol{\epsilon}
        = -\sigma(t) \nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t)

    Parameters
    ----------
    score : Tensor
        Score :math:`\nabla_{\mathbf{x}_t} \log p(\mathbf{x}_t)`, shape
        :math:`(B, *)`.
    t : Tensor
        Diffusion time, shape :math:`(B,)`.

    Returns
    -------
    Tensor
        Epsilon, same shape as ``score``.

    Examples
    --------
    >>> scheduler = EDMNoiseScheduler()
    >>> score = torch.randn(2, 4)
    >>> t = torch.tensor([1.0, 1.0])
    >>> eps = scheduler.score_to_epsilon(score, t)
    >>> eps.shape
    torch.Size([2, 4])
    """
    # Shape ``t`` as (B, 1, ..., 1) so sigma broadcasts over ``score``.
    t_col = t.reshape(-1, *([1] * (score.ndim - 1)))
    noise_level = self.sigma(t_col)
    return -noise_level * score
[docs]
def epsilon_to_x0(
    self,
    epsilon: Float[Tensor, " B *dims"],
    x_t: Float[Tensor, " B *dims"],
    t: Float[Tensor, " B"],
) -> Float[Tensor, " B *dims"]:
    r"""
    Turn an epsilon (noise) prediction into an x0-prediction.

    Solving :math:`\mathbf{x}_t = \alpha(t)\mathbf{x}_0
    + \sigma(t)\boldsymbol{\epsilon}` for the clean data gives

    .. math::
        \hat{\mathbf{x}}_0 = \frac{\mathbf{x}_t
        - \sigma(t)\hat{\boldsymbol{\epsilon}}}{\alpha(t)}

    Parameters
    ----------
    epsilon : Tensor
        Predicted noise :math:`\hat{\boldsymbol{\epsilon}}`, shape
        :math:`(B, *)`.
    x_t : Tensor
        Current noisy state :math:`\mathbf{x}_t`, same shape as
        ``epsilon``.
    t : Tensor
        Diffusion time, shape :math:`(B,)`.

    Returns
    -------
    Tensor
        Estimated clean data :math:`\hat{\mathbf{x}}_0`, same shape as
        ``epsilon``.

    Examples
    --------
    >>> scheduler = EDMNoiseScheduler()
    >>> eps = torch.randn(2, 4)
    >>> x_t = torch.randn(2, 4)
    >>> t = torch.tensor([1.0, 1.0])
    >>> x0_est = scheduler.epsilon_to_x0(eps, x_t, t)
    >>> x0_est.shape
    torch.Size([2, 4])
    """
    # Shape ``t`` as (B, 1, ..., 1) so the coefficients broadcast over
    # ``epsilon``.
    t_col = t.reshape([-1] + [1] * (epsilon.ndim - 1))
    signal = self.alpha(t_col)
    noise_level = self.sigma(t_col)
    denoised_scaled = x_t - noise_level * epsilon
    return denoised_scaled / signal
[docs]
def x0_to_epsilon(
    self,
    x0: Float[Tensor, " B *dims"],
    x_t: Float[Tensor, " B *dims"],
    t: Float[Tensor, " B"],
) -> Float[Tensor, " B *dims"]:
    r"""
    Turn an x0-prediction into an epsilon (noise) prediction.

    Inverse of :meth:`epsilon_to_x0`:

    .. math::
        \hat{\boldsymbol{\epsilon}} = \frac{\mathbf{x}_t
        - \alpha(t)\hat{\mathbf{x}}_0}{\sigma(t)}

    Parameters
    ----------
    x0 : Tensor
        Predicted clean data :math:`\hat{\mathbf{x}}_0`, shape
        :math:`(B, *)`.
    x_t : Tensor
        Current noisy state :math:`\mathbf{x}_t`, same shape as ``x0``.
    t : Tensor
        Diffusion time, shape :math:`(B,)`.

    Returns
    -------
    Tensor
        Epsilon, same shape as ``x0``.

    Examples
    --------
    >>> scheduler = EDMNoiseScheduler()
    >>> x0 = torch.randn(2, 4)
    >>> x_t = torch.randn(2, 4)
    >>> t = torch.tensor([1.0, 1.0])
    >>> eps = scheduler.x0_to_epsilon(x0, x_t, t)
    >>> eps.shape
    torch.Size([2, 4])
    """
    # Shape ``t`` as (B, 1, ..., 1) so the coefficients broadcast over ``x0``.
    t_col = t.reshape(-1, *([1] * (x0.ndim - 1)))
    signal = self.alpha(t_col)
    noise_level = self.sigma(t_col)
    noise_part = x_t - signal * x0
    return noise_part / noise_level
[docs]
def get_denoiser(
    self,
    *,
    score_predictor: Predictor | None = None,
    x0_predictor: Predictor | None = None,
    epsilon_predictor: Predictor | None = None,
    denoising_type: Literal["ode", "sde"] = "ode",
    **kwargs: Any,
) -> Denoiser:
    r"""
    Build a sampling denoiser from a predictor.

    Exactly one of a **score-predictor**, an **x0-predictor**, or an
    **epsilon-predictor** must be given. The returned denoiser evaluates
    the right-hand side of the reverse ODE or SDE.

    With ``denoising_type="ode"``:

    .. math::
        \frac{d\mathbf{x}}{dt} = f(\mathbf{x}, t) - \frac{1}{2} g^2(t)
        s(\mathbf{x}, t)

    With ``denoising_type="sde"``:

    .. math::
        d\mathbf{x} = \left[ f(\mathbf{x}, t) - g^2(t) s(\mathbf{x}, t)
        \right] dt + g(t) d\mathbf{W}

    Here :math:`s(\mathbf{x}, t)` is the score. A score-predictor is used
    as is; the output of an x0-predictor is converted with
    :meth:`x0_to_score`; the output of an epsilon-predictor is converted
    with :meth:`epsilon_to_score`. *Note:* the stochastic term
    :math:`g(t) d\mathbf{W}` is not returned by the denoiser; it is left to
    the solver.

    Parameters
    ----------
    score_predictor : Predictor, optional
        Callable taking ``(x_t, t)`` and returning a score (e.g.
        :math:`\nabla_{\mathbf{x}} \log p(\mathbf{x}_t)`). It may be
        unconditional, conditional, guidance-augmented, etc. Mutually
        exclusive with ``x0_predictor`` and ``epsilon_predictor``.
    x0_predictor : Predictor, optional
        Callable taking ``(x_t, t)`` and returning an estimate of the clean
        data :math:`\hat{\mathbf{x}}_0`. Mutually exclusive with
        ``score_predictor`` and ``epsilon_predictor``.
    epsilon_predictor : Predictor, optional
        Callable taking ``(x_t, t)`` and returning an estimate of the noise
        :math:`\hat{\boldsymbol{\epsilon}}`. Mutually exclusive with
        ``score_predictor`` and ``x0_predictor``.
    denoising_type : {"ode", "sde"}, default="ode"
        Kind of reverse process: ``"ode"`` for deterministic sampling,
        ``"sde"`` for stochastic sampling.
    **kwargs : Any
        Ignored.

    Returns
    -------
    Denoiser
        Callable computing the RHS of the reverse ODE/SDE. Implements the
        :class:`~physicsnemo.diffusion.Denoiser` interface.

    Raises
    ------
    ValueError
        If not exactly one of ``score_predictor``, ``x0_predictor``, or
        ``epsilon_predictor`` is provided, or if ``denoising_type`` is
        neither ``"ode"`` nor ``"sde"``.

    Examples
    --------
    Generate ODE RHS from a score-predictor:

    >>> import torch
    >>> scheduler = EDMNoiseScheduler()
    >>> score_pred = lambda x, t: -x / t.view(-1, 1, 1, 1)**2  # Toy score-predictor
    >>> denoiser = scheduler.get_denoiser(
    ...     score_predictor=score_pred, denoising_type="ode")
    >>> x = torch.randn(2, 3, 8, 8)
    >>> t = torch.tensor([1.0, 1.0])
    >>> dx_dt = denoiser(x, t)  # Returns ODE RHS for sampling
    >>> dx_dt.shape
    torch.Size([2, 3, 8, 8])

    Generate ODE RHS from an x0-predictor (score conversion is done internally):

    >>> x0_pred = lambda x, t: x / (1 + t.view(-1, 1, 1, 1)**2)  # Toy x0-predictor
    >>> denoiser = scheduler.get_denoiser(
    ...     x0_predictor=x0_pred, denoising_type="ode")
    >>> dx_dt = denoiser(x, t)  # Returns ODE RHS for sampling
    >>> dx_dt.shape
    torch.Size([2, 3, 8, 8])

    Generate ODE RHS from an epsilon-predictor:

    >>> eps_pred = lambda x, t: x * 0.1  # Toy epsilon-predictor
    >>> denoiser = scheduler.get_denoiser(
    ...     epsilon_predictor=eps_pred, denoising_type="ode")
    >>> dx_dt = denoiser(x, t)  # Returns ODE RHS for sampling
    >>> dx_dt.shape
    torch.Size([2, 3, 8, 8])
    """
    # Guard: the three predictor arguments are mutually exclusive and one of
    # them is mandatory.
    supplied = [
        p
        for p in (score_predictor, x0_predictor, epsilon_predictor)
        if p is not None
    ]
    if len(supplied) != 1:
        raise ValueError(
            "Exactly one of 'score_predictor', 'x0_predictor', or "
            "'epsilon_predictor' must be provided."
        )

    # Resolve the bound methods once, at factory time; the closures below
    # then call exactly these callables.
    drift_fn = self.drift
    g_sq_fn = self.diffusion

    # Normalize whichever predictor was given into a score function. The
    # guard above ensures exactly one branch can apply.
    if score_predictor is not None:
        score_fn = score_predictor
    elif x0_predictor is not None:
        score_from_x0 = self.x0_to_score

        def _score(
            x: Float[Tensor, " B *dims"],
            t: Float[Tensor, " B"],
        ) -> Float[Tensor, " B *dims"]:
            return score_from_x0(x0_predictor(x, t), x, t)

        score_fn = _score
    else:
        score_from_eps = self.epsilon_to_score

        def _score_from_eps(
            x: Float[Tensor, " B *dims"],
            t: Float[Tensor, " B"],
        ) -> Float[Tensor, " B *dims"]:
            return score_from_eps(epsilon_predictor(x, t), t)

        score_fn = _score_from_eps

    if denoising_type == "ode":

        def ode_denoiser(
            x: Float[Tensor, " B *dims"],
            t: Float[Tensor, " B"],
        ) -> Float[Tensor, " B *dims"]:
            # Probability-flow ODE: the score is weighted by g^2 / 2.
            s = score_fn(x, t)
            f = drift_fn(x, t)
            g_sq = g_sq_fn(x, t)
            return f - 0.5 * g_sq * s

        return ode_denoiser

    if denoising_type == "sde":

        def sde_denoiser(
            x: Float[Tensor, " B *dims"],
            t: Float[Tensor, " B"],
        ) -> Float[Tensor, " B *dims"]:
            # Deterministic part of the reverse SDE only; the g(t)*dW term
            # is left to the solver.
            s = score_fn(x, t)
            f = drift_fn(x, t)
            g_sq = g_sq_fn(x, t)
            return f - g_sq * s

        return sde_denoiser

    raise ValueError(
        f"denoising_type must be 'ode' or 'sde', got '{denoising_type}'"
    )
[docs]
def add_noise(
    self,
    x0: Float[Tensor, " B *dims"],
    time: Float[Tensor, " B"],
) -> Float[Tensor, " B *dims"]:
    r"""
    Corrupt clean data with noise at the requested diffusion times.

    Training code uses this to create noisy samples from clean data. The
    forward process implemented here is

    .. math::
        \mathbf{x}(t) = \alpha(t) \mathbf{x}_0
        + \sigma(t) \boldsymbol{\epsilon}

    with :math:`\boldsymbol{\epsilon}` drawn by ``torch.randn_like(x0)``.
    Subclasses rarely need to override this method: overriding
    :meth:`alpha` and :meth:`sigma` is sufficient for most use cases.

    Parameters
    ----------
    x0 : Tensor
        Clean latent state, shape :math:`(B, *)`.
    time : Tensor
        Diffusion times, shape :math:`(B,)`.

    Returns
    -------
    Tensor
        Noisy latent state, shape :math:`(B, *)`.
    """
    # Shape ``time`` as (B, 1, ..., 1) so the coefficients broadcast over
    # ``x0``.
    t_col = time.reshape([-1] + [1] * (x0.ndim - 1))
    signal_scale = self.alpha(t_col)
    noise_scale = self.sigma(t_col)
    eps = torch.randn_like(x0)
    return signal_scale * x0 + noise_scale * eps
[docs]
def init_latents(
    self,
    spatial_shape: Tuple[int, ...],
    tN: Float[Tensor, " B"],
    *,
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> Float[Tensor, " B *spatial_shape"]:
    r"""
    Initialize the noisy latent state :math:`\mathbf{x}_N` for sampling.

    Generates:

    .. math::
        \mathbf{x}_N = \sigma(t_N) \cdot \boldsymbol{\epsilon}

    where :math:`\boldsymbol{\epsilon} \sim \mathcal{N}(0, \mathbf{I})`.

    Parameters
    ----------
    spatial_shape : Tuple[int, ...]
        Spatial shape of the latent state, e.g., ``(C, H, W)``.
    tN : Tensor
        Initial diffusion time of shape :math:`(B,)`.
    device : torch.device, optional
        Device to place the tensor on. When omitted, the device of ``tN``
        is used so that the noise and :math:`\sigma(t_N)` always live on
        the same device.
    dtype : torch.dtype, optional
        Data type of the tensor. When given, the returned latent has exactly
        this dtype, even if ``tN`` is stored in a wider type. When omitted,
        the noise uses the default dtype and the usual type promotion with
        :math:`\sigma(t_N)` applies.

    Returns
    -------
    Tensor
        Initial noisy latent of shape :math:`(B, *spatial\_shape)`.
    """
    # Without an explicit device, follow ``tN`` rather than the global
    # default; otherwise the product below mixes devices when ``tN`` lives
    # elsewhere.
    if device is None:
        device = tN.device
    B = tN.shape[0]
    noise = torch.randn(B, *spatial_shape, device=device, dtype=dtype)
    # Shape ``tN`` as (B, 1, ..., 1) so sigma broadcasts over the latent.
    expected_shape = (-1,) + (1,) * len(spatial_shape)
    sigma_tN_bc = self.sigma(tN.reshape(expected_shape))
    # Align sigma with the noise. ``dtype=None`` leaves sigma's dtype as is,
    # so the promotion behavior is unchanged when no dtype was requested.
    sigma_tN_bc = sigma_tN_bc.to(device=noise.device, dtype=dtype)
    return sigma_tN_bc * noise
# =============================================================================
# Concrete noise schedule implementations
# =============================================================================
[docs]
class EDMNoiseScheduler(LinearGaussianNoiseScheduler):
r"""
EDM noise scheduler.
The EDM formulation uses :math:`\alpha(t) = 1` (no signal attenuation)
and :math:`\sigma(t) = t` (identity mapping between time and noise level).
**Sampling time-steps** are computed with polynomial spacing:
.. math::
t_i = \left(\sigma_{\max}^{1/\rho} + \frac{i}{N-1}
\left(\sigma_{\min}^{1/\rho} - \sigma_{\max}^{1/\rho}\right)
\right)^{\rho}
**Training times** are sampled from a log-normal distribution with
mean :math:`P_{\text{mean}}` and standard deviation :math:`P_{\text{std}}`.
Parameters
----------
sigma_min : float, optional
Minimum noise level for sampling time-steps, by default 0.002.
sigma_max : float, optional
Maximum noise level for sampling time-steps, by default 80.
rho : float, optional
Exponent controlling time-step spacing. Larger values concentrate more
steps at lower noise levels (better for fine details). By default 7.
sigma_data : float or Tensor, optional
Expected standard deviation of the training data, by default 0.5.
Used by :meth:`loss_weight` to compute the per-sample loss weight.
When a scalar ``float`` is given, it is stored as a 0-D tensor and
the same value is applied to all channels. When a 1-D ``Tensor``
of shape :math:`(C,)` is given, each channel receives its own
weight and :meth:`loss_weight` returns shape :math:`(N, C)` instead
of :math:`(N,)`.
P_mean : float, optional
Mean of the log-normal distribution used to sample training times,
by default -1.2.
P_std : float, optional
Standard deviation of the log-normal distribution used to sample
training times, by default 1.2.
Note
----
Reference: `Elucidating the Design Space of Diffusion-Based
Generative Models <https://arxiv.org/abs/2206.00364>`_
Examples
--------
Basic training and sampling workflow using the EDM noise scheduler:
>>> import torch
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>>
>>> scheduler = EDMNoiseScheduler(sigma_min=0.002, sigma_max=80.0, rho=7)
>>>
>>> # Training: sample times and add noise
>>> x0 = torch.randn(4, 3, 8, 8) # Clean data
>>> t = scheduler.sample_time(4) # Sample diffusion times
>>> x_t = scheduler.add_noise(x0, t) # Create noisy samples
>>> x_t.shape
torch.Size([4, 3, 8, 8])
>>>
>>> # Sampling: generate timesteps and initial latents
>>> t_steps = scheduler.timesteps(10)
>>> tN = t_steps[0].expand(4) # Initial time for batch of 4
>>> xN = scheduler.init_latents((3, 8, 8), tN) # Initial noise
>>> xN.shape
torch.Size([4, 3, 8, 8])
>>>
>>> # Convert x0-predictor to denoiser for sampling
>>> x0_predictor = lambda x, t: x / (1 + t.view(-1, 1, 1, 1)**2) # Toy x0-predictor
>>> denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
>>> denoiser(xN, tN).shape # ODE RHS for sampling
torch.Size([4, 3, 8, 8])
Per-channel ``sigma_data`` for heterogeneous channels:
>>> sigma_per_ch = torch.tensor([0.3, 0.5, 0.7])
>>> scheduler_ch = EDMNoiseScheduler(sigma_data=sigma_per_ch)
>>> t = scheduler_ch.sample_time(4)
>>> w = scheduler_ch.loss_weight(t)
>>> w.shape
torch.Size([4, 3])
"""
def __init__(
self,
sigma_min: float = 0.002,
sigma_max: float = 80.0,
rho: float = 7.0,
sigma_data: float | Float[Tensor, " C"] = 0.5,
P_mean: float = -1.2,
P_std: float = 1.2,
) -> None:
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.rho = rho
self.sigma_data: Tensor = (
sigma_data
if isinstance(sigma_data, Tensor)
else torch.as_tensor(sigma_data, dtype=torch.float32)
)
self._per_channel: bool = self.sigma_data.ndim > 0
self.P_mean = P_mean
self.P_std = P_std
[docs]
def sigma(
self,
t: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""Identity mapping: :math:`\sigma(t) = t`."""
return t
[docs]
def sigma_inv(
self,
sigma: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""Identity mapping: :math:`t = \sigma`."""
return sigma
[docs]
def sigma_dot(
self,
t: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""Constant derivative: :math:`\dot{\sigma}(t) = 1`."""
return torch.ones_like(t)
[docs]
def alpha(
self,
t: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""Constant signal coefficient: :math:`\alpha(t) = 1`."""
return torch.ones_like(t)
[docs]
def alpha_dot(
self,
t: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""Zero derivative: :math:`\dot{\alpha}(t) = 0`."""
return torch.zeros_like(t)
[docs]
def timesteps(
self,
num_steps: int,
*,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> Float[Tensor, " N+1"]:
r"""
Generate EDM time-steps with polynomial spacing.
Parameters
----------
num_steps : int
Number of sampling steps.
device : torch.device, optional
Device to place the tensor on.
dtype : torch.dtype, optional
Data type of the tensor.
Returns
-------
torch.Tensor
Time-steps tensor of shape :math:`(N + 1,)` where :math:`N` is
``num_steps``.
"""
step_indices = torch.arange(num_steps, dtype=dtype, device=device)
smax_inv_rho = self.sigma_max ** (1 / self.rho)
smin_inv_rho = self.sigma_min ** (1 / self.rho)
frac = step_indices / (num_steps - 1)
interp = smax_inv_rho + frac * (smin_inv_rho - smax_inv_rho)
t_steps = interp**self.rho
zero = torch.zeros(1, dtype=dtype, device=device)
return torch.cat([t_steps, zero])
[docs]
def sample_time(
self,
N: int,
*,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> Float[Tensor, " N"]:
r"""
Sample N diffusion times from a log-normal distribution:
:math:`\ln(\sigma) \sim \mathcal{N}(P_{\text{mean}},
P_{\text{std}}^2)`.
Parameters
----------
N : int
Number of time values to sample.
device : torch.device, optional
Device to place the tensor on.
dtype : torch.dtype, optional
Data type of the tensor.
Returns
-------
Tensor
Sampled diffusion times of shape :math:`(N,)`.
"""
rnd_normal = torch.randn(N, device=device, dtype=dtype)
return (rnd_normal * self.P_std + self.P_mean).exp()
[docs]
def loss_weight(
self,
t: Float[Tensor, " N"],
) -> Float[Tensor, " N"] | Float[Tensor, " N C"]:
r"""
Compute EDM loss weight.
.. math::
w(t) = \frac{\sigma(t)^2 + \sigma_{\text{data}}^2}
{\left(\sigma(t) \cdot \sigma_{\text{data}}\right)^2}
.. important::
This loss weight is designed for training an x0-predictor
(clean data predictor) wrapped with
:class:`~physicsnemo.diffusion.preconditioners.EDMPreconditioner`.
It is not suitable for training a score-predictor, or a model
without a pre-conditioner.
Parameters
----------
t : Tensor
Diffusion time values of shape :math:`(N,)`.
Returns
-------
Tensor
Loss weight of shape :math:`(N,)` when ``sigma_data`` is a
scalar, or :math:`(N, C)` when ``sigma_data`` is per-channel.
"""
sigma = self.sigma(t)
sd = self.sigma_data.to(device=sigma.device, dtype=sigma.dtype)
if self._per_channel:
# Per-channel: sigma (N,) → (N, 1); sd (C,) → (1, C)
sigma = sigma.unsqueeze(-1)
sd = sd.unsqueeze(0)
return (sigma**2 + sd**2) / (sigma * sd) ** 2
[docs]
class VENoiseScheduler(LinearGaussianNoiseScheduler):
r"""
Variance Exploding (VE) noise scheduler.
Implements the VE formulation with :math:`\sigma(t) = \sqrt{t}` and
:math:`\alpha(t) = 1` (no signal attenuation).
**Sampling time-steps** use geometric spacing in :math:`\sigma^2` space:
.. math::
\sigma_i^2 = \sigma_{\max}^2 \cdot
\left(\frac{\sigma_{\min}^2}{\sigma_{\max}^2}\right)^{i/(N-1)}
**Training times** are sampled log-uniformly between ``sigma_min`` and
``sigma_max``, then mapped to time via :math:`t = \sigma^2`.
Parameters
----------
sigma_min : float, optional
Minimum noise level, by default 0.02.
sigma_max : float, optional
Maximum noise level, by default 100.
Note
----
Reference: `Score-Based Generative Modeling through Stochastic
Differential Equations <https://arxiv.org/abs/2011.13456>`_
Examples
--------
Basic training and sampling workflow using the VE noise scheduler:
>>> import torch
>>> from physicsnemo.diffusion.noise_schedulers import VENoiseScheduler
>>>
>>> scheduler = VENoiseScheduler(sigma_min=0.02, sigma_max=100.0)
>>>
>>> # Training: sample times and add noise
>>> x0 = torch.randn(4, 3, 8, 8) # Clean data
>>> t = scheduler.sample_time(4) # Sample diffusion times
>>> x_t = scheduler.add_noise(x0, t) # Create noisy samples
>>> x_t.shape
torch.Size([4, 3, 8, 8])
>>>
>>> # Sampling: generate timesteps and initial latents
>>> t_steps = scheduler.timesteps(10)
>>> tN = t_steps[0].expand(4) # Initial time for batch of 4
>>> xN = scheduler.init_latents((3, 8, 8), tN) # Initial noise
>>> xN.shape
torch.Size([4, 3, 8, 8])
>>>
>>> # Convert x0-predictor to denoiser for sampling
>>> x0_predictor = lambda x, t: x / (1 + t.view(-1, 1, 1, 1)**2) # Toy x0-predictor
>>> denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
>>> denoiser(xN, tN).shape # ODE RHS for sampling
torch.Size([4, 3, 8, 8])
"""
def __init__(
self,
sigma_min: float = 0.02,
sigma_max: float = 100.0,
) -> None:
self.sigma_min = sigma_min
self.sigma_max = sigma_max
[docs]
def sigma(
self,
t: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""VE noise coefficient: :math:`\sigma(t) = \sqrt{t}`."""
return t.sqrt()
[docs]
def sigma_inv(
self,
sigma: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""Inverse VE mapping: :math:`t = \sigma^2`."""
return sigma**2
[docs]
def sigma_dot(
self,
t: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""Time derivative: :math:`\dot{\sigma}(t) = 1/(2\sqrt{t})`."""
return 0.5 / t.sqrt()
[docs]
def alpha(
self,
t: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""Constant signal coefficient: :math:`\alpha(t) = 1`."""
return torch.ones_like(t)
[docs]
def alpha_dot(
self,
t: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""Zero derivative: :math:`\dot{\alpha}(t) = 0`."""
return torch.zeros_like(t)
[docs]
def timesteps(
self,
num_steps: int,
*,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> Float[Tensor, " N+1"]:
r"""
Generate VE time-steps with geometric spacing in :math:`\sigma^2`.
Parameters
----------
num_steps : int
Number of sampling steps.
device : torch.device, optional
Device to place the tensor on.
dtype : torch.dtype, optional
Data type of the tensor.
Returns
-------
torch.Tensor
Time-steps tensor of shape :math:`(N + 1,)`.
"""
step_indices = torch.arange(num_steps, dtype=dtype, device=device)
ratio = self.sigma_min**2 / self.sigma_max**2
exponent = step_indices / (num_steps - 1)
t_steps = (self.sigma_max**2) * (ratio**exponent)
zero = torch.zeros(1, dtype=dtype, device=device)
return torch.cat([t_steps, zero])
[docs]
def sample_time(
self,
N: int,
*,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> Float[Tensor, " N"]:
r"""
Sample N diffusion times log-uniformly in sigma space, mapped to time.
Parameters
----------
N : int
Number of time values to sample.
device : torch.device, optional
Device to place the tensor on.
dtype : torch.dtype, optional
Data type of the tensor.
Returns
-------
Tensor
Sampled diffusion times of shape :math:`(N,)`.
"""
u = torch.rand(N, device=device, dtype=dtype)
log_ratio = math.log(self.sigma_max / self.sigma_min)
sigma = self.sigma_min * torch.exp(u * log_ratio)
return self.sigma_inv(sigma)
[docs]
def loss_weight(
self,
t: Float[Tensor, " N"],
) -> Float[Tensor, " N"]:
r"""
Compute VE loss weight: :math:`w(t) = 1 / \sigma(t)^2`.
.. important::
This loss weight is designed for training an x0-predictor
(clean data predictor) wrapped with
:class:`~physicsnemo.diffusion.preconditioners.VEPreconditioner`.
It is not suitable for training a score-predictor, or a model
without a pre-conditioner.
Parameters
----------
t : Tensor
Diffusion time values of shape :math:`(N,)`.
Returns
-------
Tensor
Loss weight of shape :math:`(N,)`.
"""
return 1 / self.sigma(t) ** 2
[docs]
class IDDPMNoiseScheduler(LinearGaussianNoiseScheduler):
r"""
Improved DDPM (iDDPM) noise scheduler.
Uses identity mappings :math:`\sigma(t) = t` and :math:`\alpha(t) = 1`.
The key feature is a precomputed noise level schedule derived from a
cosine schedule, providing improved sample quality in comparison to
original DDPM.
**Sampling time-steps** are selected from a precomputed schedule of
:math:`M` discrete noise levels, subsampled to ``num_steps``.
**Training times** are sampled uniformly from the precomputed schedule.
Parameters
----------
sigma_min : float, optional
Minimum noise level for filtering, by default 0.002.
sigma_max : float, optional
Maximum noise level for filtering, by default 81.
C_1 : float, optional
Clipping threshold for alpha ratio, by default 0.001.
C_2 : float, optional
Cosine schedule parameter, by default 0.008.
M : int, optional
Number of precomputed discretization steps, by default 1000.
Note
----
Reference: `Improved Denoising Diffusion Probabilistic Models
<https://arxiv.org/abs/2102.09672>`_
Examples
--------
Basic training and sampling workflow using the iDDPM noise scheduler:
>>> import torch
>>> from physicsnemo.diffusion.noise_schedulers import IDDPMNoiseScheduler
>>>
>>> scheduler = IDDPMNoiseScheduler(C_1=0.001, C_2=0.008, M=1000)
>>>
>>> # Training: sample times and add noise
>>> x0 = torch.randn(4, 3, 8, 8) # Clean data
>>> t = scheduler.sample_time(4) # Sample diffusion times
>>> x_t = scheduler.add_noise(x0, t) # Create noisy samples
>>> x_t.shape
torch.Size([4, 3, 8, 8])
>>>
>>> # Sampling: generate timesteps and initial latents
>>> t_steps = scheduler.timesteps(10)
>>> tN = t_steps[0].expand(4) # Initial time for batch of 4
>>> xN = scheduler.init_latents((3, 8, 8), tN) # Initial noise
>>> xN.shape
torch.Size([4, 3, 8, 8])
>>>
>>> # Convert x0-predictor to denoiser for sampling
>>> x0_predictor = lambda x, t: x / (1 + t.view(-1, 1, 1, 1)**2) # Toy x0-predictor
>>> denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
>>> denoiser(xN, tN).shape # ODE RHS for sampling
torch.Size([4, 3, 8, 8])
"""
def __init__(
self,
sigma_min: float = 0.002,
sigma_max: float = 81.0,
C_1: float = 0.001,
C_2: float = 0.008,
M: int = 1000,
) -> None:
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.C_1 = C_1
self.C_2 = C_2
self.M = M
# Precompute the noise level schedule u_j, j = 0, ..., M
self._u = self._compute_u_schedule()
def _compute_u_schedule(self) -> Tensor:
"""Precompute the iDDPM noise level schedule."""
u = torch.zeros(self.M + 1)
for j in range(self.M, 0, -1):
angle_j = 0.5 * math.pi * j / self.M / (self.C_2 + 1)
angle_jm1 = 0.5 * math.pi * (j - 1) / self.M / (self.C_2 + 1)
alpha_bar_j = math.sin(angle_j) ** 2
alpha_bar_jm1 = math.sin(angle_jm1) ** 2
alpha_ratio = alpha_bar_jm1 / alpha_bar_j
val = (u[j] ** 2 + 1) / max(alpha_ratio, self.C_1) - 1
u[j - 1] = val.sqrt()
return u
[docs]
def sigma(
self,
t: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""For iDDPM, :math:`\sigma(t) = t` (identity mapping)."""
return t
[docs]
def sigma_inv(
self,
sigma: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""For iDDPM, :math:`t = \sigma` (identity mapping)."""
return sigma
[docs]
def sigma_dot(
self,
t: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""Constant derivative: :math:`\dot{\sigma}(t) = 1`."""
return torch.ones_like(t)
[docs]
def alpha(
self,
t: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""Constant signal coefficient: :math:`\alpha(t) = 1`."""
return torch.ones_like(t)
[docs]
def alpha_dot(
self,
t: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""Zero derivative: :math:`\dot{\alpha}(t) = 0`."""
return torch.zeros_like(t)
[docs]
def timesteps(
self,
num_steps: int,
*,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> Float[Tensor, " N+1"]:
r"""
Generate iDDPM time-steps from precomputed schedule.
Subsamples ``num_steps`` values from the precomputed schedule of
:math:`M` noise levels.
Parameters
----------
num_steps : int
Number of sampling steps.
device : torch.device, optional
Device to place the tensor on.
dtype : torch.dtype, optional
Data type of the tensor.
Returns
-------
torch.Tensor
Time-steps tensor of shape :math:`(N + 1,)`.
"""
u = self._u.to(device=device, dtype=dtype)
# Filter to valid sigma range
in_range = torch.logical_and(u >= self.sigma_min, u <= self.sigma_max)
u_filtered = u[in_range]
step_indices = torch.arange(num_steps, dtype=dtype, device=device)
scale = (len(u_filtered) - 1) / (num_steps - 1)
indices = (scale * step_indices).round().to(torch.int64)
sigma_steps = u_filtered[indices]
zero = torch.zeros(1, dtype=dtype, device=device)
return torch.cat([sigma_steps, zero])
[docs]
def sample_time(
self,
N: int,
*,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> Float[Tensor, " N"]:
r"""
Sample N diffusion times uniformly from precomputed schedule.
Parameters
----------
N : int
Number of time values to sample.
device : torch.device, optional
Device to place the tensor on.
dtype : torch.dtype, optional
Data type of the tensor.
Returns
-------
Tensor
Sampled diffusion times of shape :math:`(N,)`.
"""
u = self._u.to(device=device, dtype=dtype)
in_range = torch.logical_and(u >= self.sigma_min, u <= self.sigma_max)
u_filtered = u[in_range]
# Sample random indices
indices = torch.randint(0, len(u_filtered), (N,), device=device)
return u_filtered[indices]
[docs]
def loss_weight(
self,
t: Float[Tensor, " N"],
) -> Float[Tensor, " N"]:
r"""
Compute iDDPM loss weight: :math:`w(t) = 1 / \sigma(t)^2`.
.. important::
This loss weight is designed for training an x0-predictor
(clean data predictor) wrapped with
:class:`~physicsnemo.diffusion.preconditioners.IDDPMPreconditioner`.
It is not suitable for training a score-predictor, or a model
without a pre-conditioner.
Parameters
----------
t : Tensor
Diffusion time values of shape :math:`(N,)`.
Returns
-------
Tensor
Loss weight of shape :math:`(N,)`.
"""
return 1 / self.sigma(t) ** 2
[docs]
class VPNoiseScheduler(LinearGaussianNoiseScheduler):
r"""
Variance Preserving (VP) noise scheduler.
Implements the VP formulation where the total variance is preserved:
:math:`\alpha(t)^2 + \sigma(t)^2 = 1`. This is based on a linear beta
schedule: :math:`\beta(t) = \beta_{\min} + t \cdot \beta_d`.
The noise and signal coefficients are:
.. math::
\alpha(t) = \exp\left(-\frac{1}{2}
\left(\frac{\beta_d}{2} t^2 + \beta_{\min} t\right)\right)
.. math::
\sigma(t) = \sqrt{1 - \alpha(t)^2}
= \sqrt{1 - \exp\left(-\frac{\beta_d}{2} t^2
- \beta_{\min} t\right)}
**Sampling time-steps** are linearly spaced from ``t_max`` (usually 1) to
``epsilon_s`` (small positive value to avoid singularities).
**Training times** are sampled uniformly between ``epsilon_s`` and
``t_max``.
Parameters
----------
beta_min : float, optional
Minimum beta value for the linear schedule, by default 0.1.
beta_d : float, optional
Beta slope (delta) for the linear schedule, by default 19.1.
epsilon_s : float, optional
Small positive value for minimum time, by default 1e-3.
t_max : float, optional
Maximum diffusion time, by default 1.0.
Note
----
Reference: `Score-Based Generative Modeling through Stochastic
Differential Equations <https://arxiv.org/abs/2011.13456>`_
Examples
--------
Basic training and sampling workflow using the VP noise scheduler:
>>> import torch
>>> from physicsnemo.diffusion.noise_schedulers import VPNoiseScheduler
>>>
>>> scheduler = VPNoiseScheduler(beta_min=0.1, beta_d=19.1)
>>>
>>> # Training: sample times and add noise
>>> x0 = torch.randn(4, 3, 8, 8) # Clean data
>>> t = scheduler.sample_time(4) # Sample diffusion times
>>> x_t = scheduler.add_noise(x0, t) # Create noisy samples
>>> x_t.shape
torch.Size([4, 3, 8, 8])
>>>
>>> # Sampling: generate timesteps and initial latents
>>> t_steps = scheduler.timesteps(10)
>>> tN = t_steps[0].expand(4) # Initial time for batch of 4
>>> xN = scheduler.init_latents((3, 8, 8), tN) # Initial noise
>>> xN.shape
torch.Size([4, 3, 8, 8])
>>>
>>> # Convert x0-predictor to denoiser for sampling
>>> x0_predictor = lambda x, t: x * 0.9 # Toy x0-predictor
>>> denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
>>> denoiser(xN, tN).shape # ODE RHS for sampling
torch.Size([4, 3, 8, 8])
"""
def __init__(
self,
beta_min: float = 0.1,
beta_d: float = 19.1,
epsilon_s: float = 1e-3,
t_max: float = 1.0,
) -> None:
self.beta_min = beta_min
self.beta_d = beta_d
self.epsilon_s = epsilon_s
self.t_max = t_max
[docs]
def alpha(
self,
t: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""Signal coefficient: :math:`\alpha(t) = \exp(-a(t)/2)`."""
return torch.exp(-0.5 * (0.5 * self.beta_d * t**2 + self.beta_min * t))
[docs]
def alpha_dot(
self,
t: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""Derivative: :math:`\dot{\alpha}(t) = -\frac{\beta(t)}{2} \alpha(t)`."""
beta_t = self.beta_min + self.beta_d * t
return -0.5 * beta_t * self.alpha(t)
[docs]
def sigma(
self,
t: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""Noise level: :math:`\sigma(t) = \sqrt{1 - \alpha(t)^2}`."""
alpha_sq = self.alpha(t) ** 2
return torch.sqrt(1 - alpha_sq)
[docs]
def sigma_dot(
self,
t: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""Derivative: :math:`\dot{\sigma}(t) = -\alpha(t) \dot{\alpha}(t) / \sigma(t)`.""" # noqa: E501
alpha_t = self.alpha(t)
sigma_t = self.sigma(t)
alpha_dot_t = self.alpha_dot(t)
# d/dt sqrt(1 - alpha^2) = -alpha * alpha_dot / sqrt(1 - alpha^2)
return -alpha_t * alpha_dot_t / sigma_t
[docs]
def sigma_inv(
self,
sigma: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""
Inverse mapping from sigma to time.
Solves: :math:`\sigma^2 = 1 - \exp(-a(t))` for :math:`t`.
"""
# sigma^2 = 1 - exp(-a) => a = -log(1 - sigma^2)
# a = beta_d/2 * t^2 + beta_min * t
# Quadratic: beta_d * t^2 + 2*beta_min * t + 2*log(1-sigma^2) = 0
log_term = torch.log(1 - sigma**2 + 1e-8) # small eps for stability
discriminant = self.beta_min**2 - 2 * self.beta_d * log_term
return (-self.beta_min + torch.sqrt(discriminant.clamp(min=0))) / self.beta_d
[docs]
def timesteps(
self,
num_steps: int,
*,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> Float[Tensor, " N+1"]:
r"""
Generate VP time-steps with linear spacing.
Parameters
----------
num_steps : int
Number of sampling steps.
device : torch.device, optional
Device to place the tensor on.
dtype : torch.dtype, optional
Data type of the tensor.
Returns
-------
torch.Tensor
Time-steps tensor of shape :math:`(N + 1,)`.
"""
# Linear spacing from t_max to epsilon_s
step_indices = torch.arange(num_steps, dtype=dtype, device=device)
frac = step_indices / (num_steps - 1)
t_steps = self.t_max + frac * (self.epsilon_s - self.t_max)
zero = torch.zeros(1, dtype=dtype, device=device)
return torch.cat([t_steps, zero])
[docs]
def sample_time(
self,
N: int,
*,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> Float[Tensor, " N"]:
r"""
Sample N diffusion times uniformly in :math:`[\epsilon_s, t_{max}]`.
Parameters
----------
N : int
Number of time values to sample.
device : torch.device, optional
Device to place the tensor on.
dtype : torch.dtype, optional
Data type of the tensor.
Returns
-------
Tensor
Sampled diffusion times of shape :math:`(N,)`.
"""
u = torch.rand(N, device=device, dtype=dtype)
return self.epsilon_s + u * (self.t_max - self.epsilon_s)
[docs]
def loss_weight(
self,
t: Float[Tensor, " N"],
) -> Float[Tensor, " N"]:
r"""
Compute VP loss weight: :math:`w(t) = \alpha(t)^2 / \sigma(t)^2`.
.. important::
This loss weight is designed for training an x0-predictor
(clean data predictor) wrapped with
:class:`~physicsnemo.diffusion.preconditioners.VPPreconditioner`.
It is not suitable for training a score-predictor, or a model
without a pre-conditioner.
Parameters
----------
t : Tensor
Diffusion time values of shape :math:`(N,)`.
Returns
-------
Tensor
Loss weight of shape :math:`(N,)`.
"""
return (self.alpha(t) / self.sigma(t)) ** 2
[docs]
class StudentTEDMNoiseScheduler(LinearGaussianNoiseScheduler):
r"""
Student-t EDM noise scheduler for heavy-tailed diffusion models.
This scheduler is a variant of :class:`EDMNoiseScheduler` that uses
Student-t noise instead of Gaussian noise. It is useful for modeling
heavy-tailed distributions and can improve sample quality for certain
data types.
.. important::
Despite inheriting from :class:`LinearGaussianNoiseScheduler`, this
scheduler is **not truly Gaussian**. It uses the same linear structure
(identity mappings :math:`\sigma(t) = t` and :math:`\alpha(t) = 1`) but
replaces Gaussian noise with Student-t noise. The "Linear" part of
:class:`LinearGaussianNoiseScheduler` still applies, but the "Gaussian"
part does not.
This scheduler uses a non-gaussian forward process:
.. math::
\mathbf{x}(t) = \mathbf{x}_0 + \sigma(t) \mathbf{n}, \quad
\mathbf{n} \sim \text{Student-}t(\nu)
The marginal distribution :math:`p(\mathbf{x}_t | \mathbf{x}_0)` is
therefore a scaled Student-t distribution, not Gaussian.
**Comparison with EDMNoiseScheduler:**
This scheduler shares the same time-to-noise mappings as
:class:`EDMNoiseScheduler`.
The only differences are in :meth:`add_noise` and :meth:`init_latents`,
which use Student-t noise instead of Gaussian noise.
Parameters
----------
sigma_min : float, optional
Minimum noise level for sampling time-steps, by default 0.002.
sigma_max : float, optional
Maximum noise level for sampling time-steps, by default 80.
rho : float, optional
Exponent controlling time-step spacing. Larger values concentrate more
steps at lower noise levels (better for fine details). By default 7.
nu : int, optional
Degrees of freedom for Student-t distribution. Must be > 2.
As ``nu`` increases, the distribution approaches Gaussian. Lower values
produce heavier tails. By default 10.
sigma_data : float or Tensor, optional
Expected standard deviation of the training data, by default 0.5.
Used by :meth:`loss_weight` to compute the per-sample loss weight.
When a 1-D ``Tensor`` of shape :math:`(C,)` is given, each channel
receives its own weight and :meth:`loss_weight` returns shape
:math:`(N, C)` instead of :math:`(N,)`.
P_mean : float, optional
Mean of the log-normal distribution used to sample training times,
by default -1.2.
P_std : float, optional
Standard deviation of the log-normal distribution used to sample
training times, by default 1.2.
Note
----
Reference: `Heavy-Tailed Diffusion Models
<https://arxiv.org/abs/2410.14171>`_
Examples
--------
Basic training and sampling workflow with Student-t noise:
>>> import torch
>>> from physicsnemo.diffusion.noise_schedulers import (
... StudentTEDMNoiseScheduler,
... )
>>>
>>> scheduler = StudentTEDMNoiseScheduler(nu=10)
>>>
>>> # Training: sample times and add Student-t noise
>>> x0 = torch.randn(4, 3, 8, 8) # Clean data
>>> t = scheduler.sample_time(4) # Sample diffusion times
>>> x_t = scheduler.add_noise(x0, t) # Adds Student-t noise
>>> x_t.shape
torch.Size([4, 3, 8, 8])
>>>
>>> # Sampling: generate timesteps and Student-t initial latents
>>> t_steps = scheduler.timesteps(10)
>>> tN = t_steps[0].expand(4)
>>> xN = scheduler.init_latents((3, 8, 8), tN) # Student-t latents
>>> xN.shape
torch.Size([4, 3, 8, 8])
"""
def __init__(
self,
sigma_min: float = 0.002,
sigma_max: float = 80.0,
rho: float = 7.0,
nu: int = 10,
sigma_data: float | Float[Tensor, " C"] = 0.5,
P_mean: float = -1.2,
P_std: float = 1.2,
) -> None:
if nu <= 2:
raise ValueError(f"nu must be > 2, got {nu}")
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.rho = rho
self.nu = nu
self.sigma_data: Tensor = (
sigma_data
if isinstance(sigma_data, Tensor)
else torch.as_tensor(sigma_data, dtype=torch.float32)
)
self._per_channel: bool = self.sigma_data.ndim > 0
self.P_mean = P_mean
self.P_std = P_std
[docs]
def sigma(
self,
t: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""Identity mapping: :math:`\sigma(t) = t`."""
return t
[docs]
def sigma_inv(
self,
sigma: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""Identity mapping: :math:`t = \sigma`."""
return sigma
[docs]
def sigma_dot(
self,
t: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""Constant derivative: :math:`\dot{\sigma}(t) = 1`."""
return torch.ones_like(t)
[docs]
def alpha(
self,
t: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""Constant signal coefficient: :math:`\alpha(t) = 1`."""
return torch.ones_like(t)
[docs]
def alpha_dot(
self,
t: Float[Tensor, " *shape"],
) -> Float[Tensor, " *shape"]:
r"""Zero derivative: :math:`\dot{\alpha}(t) = 0`."""
return torch.zeros_like(t)
[docs]
def timesteps(
self,
num_steps: int,
*,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> Float[Tensor, " N+1"]:
r"""
Generate EDM time-steps with polynomial spacing.
Parameters
----------
num_steps : int
Number of sampling steps.
device : torch.device, optional
Device to place the tensor on.
dtype : torch.dtype, optional
Data type of the tensor.
Returns
-------
torch.Tensor
Time-steps tensor of shape :math:`(N + 1,)` where :math:`N` is
``num_steps``.
"""
step_indices = torch.arange(num_steps, dtype=dtype, device=device)
smax_inv_rho = self.sigma_max ** (1 / self.rho)
smin_inv_rho = self.sigma_min ** (1 / self.rho)
frac = step_indices / (num_steps - 1)
interp = smax_inv_rho + frac * (smin_inv_rho - smax_inv_rho)
t_steps = interp**self.rho
zero = torch.zeros(1, dtype=dtype, device=device)
return torch.cat([t_steps, zero])
[docs]
def sample_time(
self,
N: int,
*,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> Float[Tensor, " N"]:
r"""
Sample N diffusion times from a log-normal distribution:
:math:`\ln(\sigma) \sim \mathcal{N}(P_{\text{mean}},
P_{\text{std}}^2)`.
Parameters
----------
N : int
Number of time values to sample.
device : torch.device, optional
Device to place the tensor on.
dtype : torch.dtype, optional
Data type of the tensor.
Returns
-------
Tensor
Sampled diffusion times of shape :math:`(N,)`.
"""
rnd_normal = torch.randn(N, device=device, dtype=dtype)
return (rnd_normal * self.P_std + self.P_mean).exp()
[docs]
def loss_weight(
self,
t: Float[Tensor, " N"],
) -> Float[Tensor, " N"] | Float[Tensor, " N C"]:
r"""
Compute Student-t EDM loss weight: :math:`w(t) = \frac{\tilde{\sigma}(t)^2 + \sigma_{\text{data}}^2}
{\left(\tilde{\sigma}(t) \cdot \sigma_{\text{data}}\right)^2}`
where :math:`\tilde{\sigma}(t) = \sigma(t) \cdot \sqrt{\frac{\nu}{\nu -
2}}` is the scaled noise level.
.. important::
This loss weight is designed for training an x0-predictor
(clean data predictor) wrapped with
:class:`~physicsnemo.diffusion.preconditioners.EDMPreconditioner`.
It is not suitable for training a score-predictor, or a model
without a pre-conditioner.
Parameters
----------
t : Tensor
Diffusion time values of shape :math:`(N,)`.
Returns
-------
Tensor
Loss weight of shape :math:`(N,)` when ``sigma_data`` is a
scalar, or :math:`(N, C)` when ``sigma_data`` is per-channel.
"""
sigma = self.sigma(t) * np.sqrt(self.nu / (self.nu - 2))
sd = self.sigma_data.to(device=sigma.device, dtype=sigma.dtype)
if self._per_channel:
# Per-channel: sigma (N,) → (N, 1); sd (C,) → (1, C)
sigma = sigma.unsqueeze(-1)
sd = sd.unsqueeze(0)
return (sigma**2 + sd**2) / (sigma * sd) ** 2
def _sample_student_t(
self,
*shape: int,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> Tensor:
r"""
Sample from standard Student-t distribution.
Student-t samples are generated as: :math:`X / \sqrt{V / \nu}` where
:math:`X \sim \mathcal{N}(0, 1)` and :math:`V \sim \chi^2(\nu)`.
Parameters
----------
*shape : int
Shape of the output tensor.
device : torch.device, optional
Device to place the tensor on.
dtype : torch.dtype, optional
Data type of the tensor.
Returns
-------
Tensor
Student-t samples of the specified shape.
"""
normal = torch.randn(*shape, device=device, dtype=dtype)
nu = torch.tensor(self.nu, device=device, dtype=dtype)
chi2_dist = torch.distributions.Chi2(df=nu)
chi2_samples = chi2_dist.sample((shape[0], *([1] * (len(shape) - 1))))
kappa = chi2_samples / nu
return normal / torch.sqrt(kappa)
[docs]
def add_noise(
self,
x0: Float[Tensor, " B *dims"],
time: Float[Tensor, " B"],
) -> Float[Tensor, " B *dims"]:
r"""
Add Student-t noise to clean data at the given diffusion times.
Unlike the Gaussian case in :class:`LinearGaussianNoiseScheduler`,
this method uses Student-t noise:
.. math::
\mathbf{x}(t) = \mathbf{x}_0 + \sigma(t) \mathbf{n}, \quad
\mathbf{n} \sim \text{Student-}t(\nu)
Parameters
----------
x0 : Tensor
Clean latent state of shape :math:`(B, *)`.
time : Tensor
Diffusion time values of shape :math:`(B,)`.
Returns
-------
Tensor
Noisy latent state of shape :math:`(B, *)`.
"""
expected_shape = (-1,) + (1,) * (x0.ndim - 1)
t_bc = time.reshape(expected_shape)
sigma_t_bc = self.sigma(t_bc)
noise = self._sample_student_t(*x0.shape, device=x0.device, dtype=x0.dtype)
return x0 + sigma_t_bc * noise
[docs]
def init_latents(
self,
spatial_shape: Tuple[int, ...],
tN: Float[Tensor, " B"],
*,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> Float[Tensor, " B *spatial_shape"]:
r"""
Initialize noisy latent state with Student-t noise.
Unlike the Gaussian case in :class:`LinearGaussianNoiseScheduler`,
this method uses Student-t noise:
.. math::
\mathbf{x}_N = \sigma(t_N) \cdot \mathbf{n}, \quad
\mathbf{n} \sim \text{Student-}t(\nu)
Parameters
----------
spatial_shape : Tuple[int, ...]
Spatial shape of the latent state, e.g., ``(C, H, W)``.
tN : Tensor
Initial diffusion time of shape :math:`(B,)`.
device : torch.device, optional
Device to place the tensor on.
dtype : torch.dtype, optional
Data type of the tensor.
Returns
-------
Tensor
Initial noisy latent of shape :math:`(B, *spatial\_shape)`.
"""
B = tN.shape[0]
noise = self._sample_student_t(B, *spatial_shape, device=device, dtype=dtype)
expected_shape = (-1,) + (1,) * len(spatial_shape)
tN_bc = tN.reshape(expected_shape)
sigma_tN_bc = self.sigma(tN_bc)
return sigma_tN_bc * noise