# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from abc import ABC, abstractmethod
from typing import Any, Tuple
import torch
from tensordict import TensorDict
from physicsnemo.core.meta import ModelMetaData
from physicsnemo.core.module import Module
class BaseAffinePreconditioner(Module, ABC):
    r"""
    Abstract base class for diffusion model preconditioners using an affine
    transformation.

    This class provides a standardized interface for implementing
    preconditioners that use affine transformations of the model
    input and output.

    The preconditioner wraps a neural network model :math:`F` and applies
    a preconditioning formula to transform the network output to produce
    the preconditioned output :math:`D(\mathbf{x}, t)` according to:

    .. math::
        D(\mathbf{x}, t) = c_{\text{skip}}(t) \mathbf{x} +
        c_{\text{out}}(t) F(c_{\text{in}}(t) \mathbf{x}, c_{\text{noise}}(t))

    where:

    - :math:`c_{\text{in}}(t)`: Input scaling coefficient
    - :math:`c_{\text{noise}}(t)`: Noise conditioning value
    - :math:`c_{\text{out}}(t)`: Output scaling coefficient
    - :math:`c_{\text{skip}}(t)`: Skip connection scaling coefficient

    and where :math:`\mathbf{x}` is the latent state and :math:`t` is the
    diffusion time.

    The wrapped model :math:`F` must be an instance of
    :class:`~physicsnemo.core.Module` that satisfies the
    :class:`~physicsnemo.diffusion.DiffusionModel` interface, with the
    following signature:

    .. code-block:: python

        model(
            x: torch.Tensor,  # Shape: (B, *)
            t: torch.Tensor,  # Shape: (B,)
            condition: torch.Tensor | TensorDict | None = None,
            **model_kwargs: Any,
        ) -> torch.Tensor  # Shape: (B, *)

    Note that the ``t`` argument the wrapped model receives is the flattened
    noise conditioning value :math:`c_{\text{noise}}`, not the raw diffusion
    time.

    The preconditioner is agnostic to the prediction target of the wrapped
    model :math:`F`. The same preconditioning formula is applied regardless of
    whether the model is an :math:`\mathbf{x}_0`-predictor, an
    :math:`\epsilon`-predictor, a score predictor, or a
    :math:`\mathbf{v}`-predictor.

    .. note::
        The preconditioner itself also satisfies the
        :class:`~physicsnemo.diffusion.DiffusionModel` interface, meaning it
        does not change the signature of the wrapped model :math:`F`, and it
        can be used anywhere a diffusion model is expected.

    Parameters
    ----------
    model : physicsnemo.Module
        The underlying neural network model :math:`F` to wrap with the
        signature described above.
    meta : ModelMetaData, optional
        Metadata describing the model, by default None. It is forwarded to
        the :class:`~physicsnemo.core.Module` constructor. Subclasses can
        pass their own metadata.

    Forward
    -------
    x : torch.Tensor
        Noisy latent state of shape :math:`(B, *)` where :math:`B` is the
        batch size and :math:`*` denotes any number of additional dimensions.
    t : torch.Tensor
        Diffusion time tensor of shape :math:`(B,)`.
    condition : torch.Tensor, TensorDict, or None, optional, default=None
        Single Tensor or a TensorDict containing conditioning tensors with
        batch size :math:`B` matching that of ``x``. Pass ``None`` for an
        unconditional model.
    **model_kwargs : Any
        Additional keyword arguments passed to the underlying model.

    Outputs
    -------
    torch.Tensor
        Preconditioned model output with the same shape as the original model
        output.

    Note
    ----
    To implement a new preconditioner, define a subclass of
    :class:`BaseAffinePreconditioner` and implement the following methods:

    - Subclasses must implement the :meth:`compute_coefficients` method to
      define the specific preconditioning scheme.
    - A :meth:`sigma` method can optionally be implemented.
      If a subclass implements the :meth:`sigma` method, the diffusion time
      :math:`t` is first transformed to a noise level :math:`\sigma(t)`
      before being passed to :meth:`compute_coefficients`. This allows
      implementing preconditioners for different time-to-noise-level
      mappings while keeping the same preconditioning interface. In
      particular, it supports preconditioning schemes based on the noise
      level, that is :math:`c_{\text{in}}(\sigma)`,
      :math:`c_{\text{noise}}(\sigma)`, :math:`c_{\text{out}}(\sigma)` and
      :math:`c_{\text{skip}}(\sigma)` instead of :math:`c_{\text{in}}(t)`,
      :math:`c_{\text{noise}}(t)`, :math:`c_{\text{out}}(t)` and
      :math:`c_{\text{skip}}(t)`.
    - The ``forward`` method of the preconditioner *should not* be
      overridden.

    The argument ``t`` of the preconditioner forward method is always
    assumed to be the diffusion time. For preconditioning schemes based
    on noise level, the noise level :math:`\sigma(t)` is computed
    internally using the :meth:`sigma` method.

    Examples
    --------
    The following example shows how to implement a classical EDM
    preconditioner. For EDM, there is no need to implement the :meth:`sigma`
    method since :math:`\sigma(t) = t` (noise level and diffusion time are the
    same).

    We first define a simple model to wrap:

    >>> import torch
    >>> from tensordict import TensorDict
    >>> from physicsnemo.core import Module
    >>> class SimpleModel(Module):
    ...     def __init__(self, channels: int):
    ...         super().__init__()
    ...         self.channels = channels
    ...         self.net = torch.nn.Conv2d(channels, channels, 1)
    ...
    ...     def forward(self, x, t, condition=None):
    ...         return self.net(x)

    Now we define the EDM preconditioner:

    >>> from physicsnemo.diffusion.preconditioners import (
    ...     BaseAffinePreconditioner,
    ... )
    >>> class SimpleEDMPreconditioner(BaseAffinePreconditioner):
    ...     def __init__(self, model, sigma_data: float = 0.5):
    ...         super().__init__(model)
    ...         self.sigma_data = sigma_data
    ...
    ...     def compute_coefficients(self, t: torch.Tensor):
    ...         # For EDM sigma(t) = t, so the argument passed to
    ...         # compute_coefficients is already sigma(t)
    ...         sigma_data = self.sigma_data
    ...         c_skip = sigma_data**2 / (t**2 + sigma_data**2)
    ...         c_out = t * sigma_data / (t**2 + sigma_data**2).sqrt()
    ...         c_in = 1 / (sigma_data**2 + t**2).sqrt()
    ...         c_noise = t.log() / 4
    ...         return c_in, c_noise, c_out, c_skip
    ...
    >>> model = SimpleModel(channels=3)
    >>> precond = SimpleEDMPreconditioner(model, sigma_data=0.5)
    >>> x = torch.randn(2, 3, 16, 16)
    >>> t = torch.rand(2)
    >>> condition = TensorDict({}, batch_size=[2])
    >>> out = precond(x, t, condition)
    >>> out.shape
    torch.Size([2, 3, 16, 16])

    The following example shows how to override the :meth:`sigma` method to
    implement a Variance Exploding (VE) preconditioner where
    :math:`\sigma(t) = \sqrt{t}`.

    >>> class VEPreconditioner(BaseAffinePreconditioner):
    ...     def __init__(self, model):
    ...         super().__init__(model)
    ...
    ...     def sigma(self, t: torch.Tensor) -> torch.Tensor:
    ...         # Override sigma for VE time-to-noise-level mapping
    ...         return t.sqrt()
    ...
    ...     def compute_coefficients(self, sigma: torch.Tensor):
    ...         # Here the argument passed to compute_coefficients is
    ...         # sigma(t) = sqrt(t) due to override of the sigma method
    ...         c_skip = torch.ones_like(sigma)
    ...         c_out = sigma
    ...         c_in = torch.ones_like(sigma)
    ...         c_noise = (0.5 * sigma).log()
    ...         return c_in, c_noise, c_out, c_skip
    ...
    >>> precond_ve = VEPreconditioner(model)
    >>> out_ve = precond_ve(x, t, condition)
    >>> out_ve.shape
    torch.Size([2, 3, 16, 16])

    **Wrapping existing models to satisfy the DiffusionModel interface**

    Some models in PhysicsNeMo have signatures that differ from the
    :class:`~physicsnemo.diffusion.DiffusionModel` interface. Below are
    examples showing how to write thin wrappers to make them compatible
    with preconditioners, including image-based conditioning via channel
    concatenation.

    **Example: Wrapping SongUNet**

    The :class:`~physicsnemo.models.diffusion_unets.SongUNet` model has
    the signature ``forward(x, noise_labels, class_labels, augment_labels)``.
    We wrap it to match ``forward(x, t, condition)``, where ``condition``
    contains both class labels (1D vector) and an image to concatenate
    channel-wise:

    >>> from physicsnemo.models.diffusion_unets import SongUNet
    >>> from physicsnemo.diffusion import DiffusionModel
    >>> from tensordict import TensorDict
    >>> class SongUNetWrapper(Module):
    ...     def __init__(self, img_channels, cond_channels, label_dim, **kwargs):
    ...         super().__init__()
    ...         # in_channels = img_channels + cond_channels for concatenation
    ...         self.net = SongUNet(
    ...             in_channels=img_channels + cond_channels,
    ...             out_channels=img_channels,
    ...             label_dim=label_dim,
    ...             **kwargs,
    ...         )
    ...
    ...     def forward(self, x, t, condition):
    ...         # Concatenate image condition "y" channel-wise to input
    ...         y = condition["y"]  # shape: (B, C_cond, H, W)
    ...         x_cat = torch.cat([x, y], dim=1)
    ...         # Extract 1D vector condition for class_labels
    ...         class_labels = condition["class_labels"]  # shape: (B, label_dim)
    ...         return self.net(x_cat, noise_labels=t, class_labels=class_labels)
    ...
    >>> wrapped = SongUNetWrapper(
    ...     img_channels=2, cond_channels=1, label_dim=4, img_resolution=8
    ... )
    >>> isinstance(wrapped, DiffusionModel)
    True
    >>> x = torch.rand(1, 2, 8, 8)
    >>> t = torch.rand(1)
    >>> condition = TensorDict({
    ...     "y": torch.rand(1, 1, 8, 8),  # image condition
    ...     "class_labels": torch.rand(1, 4),  # 1D vector condition
    ... }, batch_size=[1])
    >>> out = wrapped(x, t, condition)
    >>> out.shape
    torch.Size([1, 2, 8, 8])

    **Example: Wrapping DiT**

    The :class:`~physicsnemo.models.dit.DiT` model has
    the signature ``forward(x, t, condition, ...)``. We wrap it to support
    both image conditioning (via channel concatenation) and vector
    conditioning:

    >>> from physicsnemo.models.dit import DiT
    >>> class DiTWrapper(Module):
    ...     def __init__(self, img_channels, cond_channels, cond_dim, **kwargs):
    ...         super().__init__()
    ...         # in_channels = img_channels + cond_channels for concatenation
    ...         self.net = DiT(
    ...             in_channels=img_channels + cond_channels,
    ...             out_channels=img_channels,
    ...             condition_dim=cond_dim,
    ...             **kwargs,
    ...         )
    ...
    ...     def forward(self, x, t, condition):
    ...         # Concatenate image condition "y" channel-wise to input
    ...         y = condition["y"]  # shape: (B, C_cond, H, W)
    ...         x_cat = torch.cat([x, y], dim=1)
    ...         # Extract 1D vector condition
    ...         vec = condition["vec"]  # shape: (B, cond_dim)
    ...         return self.net(x_cat, t, condition=vec)
    ...
    >>> wrapped_dit = DiTWrapper(
    ...     img_channels=2, cond_channels=1, cond_dim=4,
    ...     input_size=8, patch_size=4, attention_backend="timm",
    ... )
    >>> isinstance(wrapped_dit, DiffusionModel)
    True
    >>> x = torch.rand(1, 2, 8, 8)
    >>> t = torch.rand(1)
    >>> condition = TensorDict({
    ...     "y": torch.rand(1, 1, 8, 8),  # image condition
    ...     "vec": torch.rand(1, 4),  # 1D vector condition
    ... }, batch_size=[1])
    >>> out = wrapped_dit(x, t, condition)
    >>> out.shape
    torch.Size([1, 2, 8, 8])

    **Example: Using ConcatConditionWrapper with a preconditioner**

    The pattern in the previous examples is common across several diffusion
    use-cases. In that pattern, spatially-varying conditioning is
    concatenated to the noised latent state, and vector conditioning may
    also be passed as a separate argument.
    For convenience, we provide the wrapper class
    :class:`~physicsnemo.diffusion.utils.ConcatConditionWrapper`, which saves
    you the trouble of writing your own wrapper for this pattern:

    >>> from physicsnemo.diffusion.preconditioners import EDMPreconditioner
    >>> from physicsnemo.diffusion.utils import ConcatConditionWrapper
    >>> base_model = SongUNet(img_resolution=8, in_channels=4, out_channels=3, label_dim=4)
    >>> wrapped_model = ConcatConditionWrapper(base_model)
    >>> precond = EDMPreconditioner(wrapped_model, sigma_data=0.5)
    >>> x = torch.rand(1, 3, 8, 8)
    >>> t = torch.rand(1)
    >>> condition = TensorDict({
    ...     "cond_concat": torch.rand(1, 1, 8, 8),  # image condition
    ...     "cond_vec": torch.rand(1, 4),  # vector condition
    ... }, batch_size=[1])
    >>> out = precond(x, t, condition)
    >>> out.shape
    torch.Size([1, 3, 8, 8])

    The same wrapper can be used with :class:`~physicsnemo.models.dit.DiT`
    backbones as well, with ``cond_vec`` passed to the model's ``condition``
    argument.
    """

    def __init__(
        self,
        model: Module,
        meta: ModelMetaData | None = None,
    ) -> None:
        # Hand ``meta`` to the Module base class rather than assigning
        # ``self.meta = meta`` after ``super().__init__()``. The old pattern
        # clobbered the base class's ``meta`` attribute with ``None`` whenever
        # no metadata was supplied (the base class presumably installs a
        # default ModelMetaData in that case — confirm against physicsnemo).
        super().__init__(meta=meta)
        self.model = model

    @abstractmethod
    def compute_coefficients(
        self, t: torch.Tensor, /
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""
        Compute the preconditioning coefficients for a given diffusion time
        :math:`t` or noise level :math:`\sigma`.

        This abstract method must be implemented by subclasses to define
        the specific preconditioning scheme.

        Parameters
        ----------
        t : torch.Tensor
            Diffusion time tensor, or noise level tensor if :meth:`sigma` is
            overridden. It has shape :math:`(B, 1, ..., 1)`, where :math:`B`
            is the batch size. There is one trailing singleton dimension for
            each non-batch dimension of the latent state ``x``, so that the
            coefficients broadcast against ``x``.

        Returns
        -------
        c_in : torch.Tensor
            Input scaling coefficient of shape :math:`(B, 1, ..., 1)`.
        c_noise : torch.Tensor
            Noise conditioning value of shape :math:`(B, 1, ..., 1)`. It is
            flattened to shape :math:`(B,)` before being passed to the
            wrapped model.
        c_out : torch.Tensor
            Output scaling coefficient of shape :math:`(B, 1, ..., 1)`.
        c_skip : torch.Tensor
            Skip connection scaling coefficient of shape
            :math:`(B, 1, ..., 1)`.
        """
        ...

    def sigma(self, t: torch.Tensor) -> torch.Tensor:
        r"""
        Map diffusion time :math:`t` to noise level :math:`\sigma(t)`.

        By default, this is the identity function :math:`\sigma(t) = t`.
        Subclasses can override this to implement preconditioners for different
        time-to-noise-level mappings.
        The output of this method, not the raw time ``t``, is what gets passed
        to :meth:`compute_coefficients`.

        Parameters
        ----------
        t : torch.Tensor
            Diffusion time tensor of shape :math:`(B,)` where
            :math:`B` is the batch size.

        Returns
        -------
        torch.Tensor
            Noise level :math:`\sigma(t)` of shape :math:`(B,)`.
        """
        return t

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        condition: torch.Tensor | TensorDict | None = None,
        **model_kwargs: Any,
    ) -> torch.Tensor:
        r"""
        Evaluate the preconditioned model
        :math:`D(\mathbf{x}, t) = c_{\text{skip}} \mathbf{x} + c_{\text{out}}
        F(c_{\text{in}} \mathbf{x}, c_{\text{noise}})`.

        Parameters
        ----------
        x : torch.Tensor
            Noisy latent state of shape :math:`(B, *)`.
        t : torch.Tensor
            Diffusion time tensor of shape :math:`(B,)`.
        condition : torch.Tensor, TensorDict, or None, optional
            Conditioning whose leading (batch) dimension must equal :math:`B`.
            When ``None``, the wrapped model is called without a
            ``condition`` keyword argument.
        **model_kwargs : Any
            Extra keyword arguments forwarded to the wrapped model.

        Returns
        -------
        torch.Tensor
            Preconditioned output, shaped like the wrapped model's output.

        Raises
        ------
        ValueError
            If ``t`` is not of shape :math:`(B,)`, or if ``condition`` does
            not have a batch dimension of size :math:`B`. These checks are
            skipped while ``torch.compile`` is tracing.
        """
        # Input validation only runs in eager mode (not while compiling).
        if not torch.compiler.is_compiling():
            B = x.shape[0]
            if t.shape != (B,):
                raise ValueError(
                    f"Expected t to have shape ({B},) matching batch size of "
                    f"x, but got {t.shape}."
                )
            if isinstance(condition, TensorDict):
                # An empty batch_size means the TensorDict declares no batch
                # dimension, so there is nothing to compare against.
                if condition.batch_size and condition.batch_size[0] != B:
                    raise ValueError(
                        f"Condition TensorDict has batch size {condition.batch_size[0]} "
                        f"but expected {B} to match x."
                    )
            elif isinstance(condition, torch.Tensor):
                # A 0-d tensor has no batch dimension; report it explicitly
                # instead of failing with an IndexError on ``shape[0]``.
                if condition.ndim == 0:
                    raise ValueError(
                        "Condition tensor must have a leading batch dimension "
                        f"of size {B}, but got a 0-d tensor."
                    )
                if condition.shape[0] != B:
                    raise ValueError(
                        f"Condition tensor has batch size {condition.shape[0]} "
                        f"but expected {B} to match x."
                    )

        # Map diffusion time to noise level, then reshape to (B, 1, ..., 1)
        # so the coefficients broadcast against x.
        sigma_t = self.sigma(t).reshape(-1, *([1] * (x.ndim - 1)))
        c_in, c_noise, c_out, c_skip = self.compute_coefficients(sigma_t)

        # The wrapped model receives the noise conditioning value as a flat
        # (B,) tensor in place of the raw diffusion time.
        if condition is not None:
            F_x = self.model(
                c_in * x,
                c_noise.flatten(),
                condition=condition,
                **model_kwargs,
            )
        else:
            F_x = self.model(
                c_in * x,
                c_noise.flatten(),
                **model_kwargs,
            )
        return c_skip * x + c_out * F_x
class VPPreconditioner(BaseAffinePreconditioner):
    r"""
    Variance Preserving (VP) preconditioner.

    Implements the preconditioning scheme from the VP formulation of
    score-based generative models.

    The time-to-noise-level mapping is:

    .. math::
        \sigma(t) = \sqrt{\exp\left(\frac{\beta_d}{2} t^2
        + \beta_{\min} t\right) - 1}

    The preconditioning coefficients are:

    .. math::
        c_{\text{skip}} &= 1 \\
        c_{\text{out}} &= -\sigma \\
        c_{\text{in}} &= \frac{1}{\sqrt{\sigma^2 + 1}} \\
        c_{\text{noise}} &= (M - 1) \cdot \sigma^{-1}(\sigma)

    With these coefficients, the preconditioned model output is expected to be
    an :math:`\mathbf{x}_0`-prediction (clean data estimate).
    This preconditioner is not directly compatible with score-prediction
    training or other prediction targets.
    For training, it is usually paired with
    :class:`~physicsnemo.diffusion.metrics.losses.MSEDSMLoss`
    (``prediction_type="x0"``) and
    :class:`~physicsnemo.diffusion.noise_schedulers.VPNoiseScheduler`.

    Parameters
    ----------
    model : physicsnemo.Module
        The underlying neural network model to wrap with signature described in
        :class:`BaseAffinePreconditioner`.
    beta_d : float, optional
        Extent of the noise level schedule, by default 19.9.
    beta_min : float, optional
        Initial slope of the noise level schedule, by default 0.1.
    M : int, optional
        Number of discretization steps in the DDPM formulation,
        by default 1000.

    Forward
    -------
    x : torch.Tensor
        Noisy latent state of shape :math:`(B, *)` where :math:`B` is the
        batch size and :math:`*` denotes any number of additional dimensions.
    t : torch.Tensor
        Diffusion time tensor of shape :math:`(B,)`.
    condition : torch.Tensor, TensorDict, or None, optional, default=None
        Single Tensor or a TensorDict containing conditioning tensors with
        batch size :math:`B` matching that of ``x``. Pass ``None`` for an
        unconditional model.
    **model_kwargs : Any
        Additional keyword arguments passed to the underlying model.

    Outputs
    -------
    torch.Tensor
        Preconditioned model output with the same shape as the original model
        output.

    Note
    ----
    Reference: `Score-Based Generative Modeling through Stochastic
    Differential Equations <https://arxiv.org/abs/2011.13456>`_

    Examples
    --------
    >>> import torch
    >>> from physicsnemo.core import Module
    >>> # Define a simple model satisfying the diffusion model interface
    >>> class SimpleModel(Module):
    ...     def __init__(self, channels: int):
    ...         super().__init__()
    ...         self.net = torch.nn.Conv2d(channels, channels, 1)
    ...     def forward(self, x, t, condition=None):
    ...         return self.net(x)
    >>> model = SimpleModel(channels=3)
    >>> precond = VPPreconditioner(model, beta_d=19.9, beta_min=0.1, M=1000)
    >>> x = torch.randn(2, 3, 16, 16)  # batch of 2 images
    >>> t = torch.rand(2)  # diffusion time for each sample
    >>> out = precond(x, t, condition=None)
    >>> out.shape
    torch.Size([2, 3, 16, 16])
    """

    def __init__(
        self,
        model: Module,
        beta_d: float = 19.9,
        beta_min: float = 0.1,
        M: int = 1000,
    ) -> None:
        super().__init__(model)
        # Stored as buffers so they follow the module across devices and are
        # saved in the state dict.
        self.register_buffer("beta_d", torch.tensor(beta_d))
        self.register_buffer("beta_min", torch.tensor(beta_min))
        self.register_buffer("M", torch.tensor(M))

    def sigma(self, t: torch.Tensor) -> torch.Tensor:
        r"""
        Compute :math:`\sigma(t)` for the VP formulation.

        Parameters
        ----------
        t : torch.Tensor
            Diffusion time tensor of shape :math:`(B,)`.

        Returns
        -------
        torch.Tensor
            Noise level :math:`\sigma(t)` of shape :math:`(B,)`.
        """
        exponent = 0.5 * self.beta_d * (t**2) + self.beta_min * t
        # expm1(x) is accurate for small x, where exp(x) - 1 suffers from
        # cancellation (t close to 0 gives a tiny exponent).
        return torch.expm1(exponent).sqrt()

    def compute_coefficients(
        self, sigma: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""
        Compute VP preconditioning coefficients.

        Parameters
        ----------
        sigma : torch.Tensor
            Noise level tensor of shape :math:`(B, 1, ..., 1)`.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
            Preconditioning coefficients (:math:`c_{\text{in}}`,
            :math:`c_{\text{noise}}`, :math:`c_{\text{out}}`,
            :math:`c_{\text{skip}}`) of shape :math:`(B, 1, ..., 1)`.
        """
        c_skip = torch.ones_like(sigma)
        c_out = -sigma
        c_in = 1 / (sigma**2 + 1).sqrt()
        # Invert sigma(t) in closed form (positive root of the quadratic
        # beta_d/2 t^2 + beta_min t = log(1 + sigma^2)). log1p keeps precision
        # for small sigma.
        t = (
            (self.beta_min**2 + 2 * self.beta_d * torch.log1p(sigma**2)).sqrt()
            - self.beta_min
        ) / self.beta_d
        c_noise = (self.M - 1) * t
        return c_in, c_noise, c_out, c_skip
class VEPreconditioner(BaseAffinePreconditioner):
    r"""
    Variance Exploding (VE) preconditioner.

    Implements the preconditioning scheme from the VE formulation of
    score-based generative models.
    For VE, the time-to-noise-level mapping is the identity:
    :math:`\sigma(t) = t`.

    The preconditioning coefficients are:

    .. math::
        c_{\text{skip}} &= 1 \\
        c_{\text{out}} &= \sigma \\
        c_{\text{in}} &= 1 \\
        c_{\text{noise}} &= \log(0.5 \cdot \sigma)

    With these coefficients, the preconditioned model output is expected to be an
    :math:`\mathbf{x}_0`-prediction (clean data estimate). This preconditioner
    is not directly compatible with score-prediction training or other
    prediction targets.
    For training, it is usually paired with
    :class:`~physicsnemo.diffusion.metrics.losses.MSEDSMLoss`
    (``prediction_type="x0"``) and
    :class:`~physicsnemo.diffusion.noise_schedulers.VENoiseScheduler`.

    Parameters
    ----------
    model : physicsnemo.Module
        The underlying neural network model to wrap with signature described in
        :class:`BaseAffinePreconditioner`.

    Forward
    -------
    x : torch.Tensor
        Noisy latent state of shape :math:`(B, *)` where :math:`B` is the
        batch size and :math:`*` denotes any number of additional dimensions.
    t : torch.Tensor
        Diffusion time tensor of shape :math:`(B,)`.
    condition : torch.Tensor, TensorDict, or None, optional, default=None
        Single Tensor or a TensorDict containing conditioning tensors with
        batch size :math:`B` matching that of ``x``. Pass ``None`` for an
        unconditional model.
    **model_kwargs : Any
        Additional keyword arguments passed to the underlying model.

    Outputs
    -------
    torch.Tensor
        Preconditioned model output with the same shape as the original model
        output.

    Note
    ----
    Reference: `Score-Based Generative Modeling through Stochastic
    Differential Equations <https://arxiv.org/abs/2011.13456>`_

    Examples
    --------
    >>> import torch
    >>> from physicsnemo.core import Module
    >>> # Define a simple model satisfying the diffusion model interface
    >>> class SimpleModel(Module):
    ...     def __init__(self, channels: int):
    ...         super().__init__()
    ...         self.net = torch.nn.Conv2d(channels, channels, 1)
    ...     def forward(self, x, t, condition=None):
    ...         return self.net(x)
    >>> model = SimpleModel(channels=3)
    >>> precond = VEPreconditioner(model)
    >>> x = torch.randn(2, 3, 16, 16)  # batch of 2 images
    >>> t = torch.rand(2)  # diffusion time for each sample
    >>> out = precond(x, t, condition=None)
    >>> out.shape
    torch.Size([2, 3, 16, 16])
    """

    def __init__(self, model: Module) -> None:
        # VE has no hyperparameters; sigma(t) keeps the base-class identity.
        super().__init__(model)

    def compute_coefficients(
        self, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""
        Compute VE preconditioning coefficients.

        Parameters
        ----------
        t : torch.Tensor
            Diffusion time tensor of shape :math:`(B, 1, ..., 1)`. It equals
            the noise level, since :math:`\sigma(t) = t` for VE.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
            Preconditioning coefficients (:math:`c_{\text{in}}`,
            :math:`c_{\text{noise}}`, :math:`c_{\text{out}}`,
            :math:`c_{\text{skip}}`) of shape :math:`(B, 1, ..., 1)`.
        """
        c_skip = torch.ones_like(t)
        c_out = t
        c_in = torch.ones_like(t)
        c_noise = (0.5 * t).log()
        return c_in, c_noise, c_out, c_skip
class IDDPMPreconditioner(BaseAffinePreconditioner):
    r"""
    Improved DDPM (iDDPM) preconditioner.

    Implements the preconditioning scheme from the improved DDPM
    formulation.

    The preconditioning coefficients are:

    .. math::
        c_{\text{skip}} &= 1 \\
        c_{\text{out}} &= -\sigma \\
        c_{\text{in}} &= \frac{1}{\sqrt{\sigma^2 + 1}} \\
        c_{\text{noise}} &= M - 1 - \text{argmin}_j |\sigma - u_j|

    where :math:`u_j, j = 0, ..., M` are the precomputed noise levels from
    the iDDPM discretization.

    With these coefficients, the preconditioned model output is expected to be
    an :math:`\mathbf{x}_0`-prediction (clean data estimate). This
    preconditioner is not directly compatible with score-prediction training
    or other prediction targets.
    For training, it is usually paired with
    :class:`~physicsnemo.diffusion.metrics.losses.MSEDSMLoss`
    (``prediction_type="x0"``) and
    :class:`~physicsnemo.diffusion.noise_schedulers.IDDPMNoiseScheduler`.

    Parameters
    ----------
    model : physicsnemo.Module
        The underlying neural network model to wrap with signature described in
        :class:`BaseAffinePreconditioner`.
    C_1 : float, optional
        Timestep adjustment at low noise levels, by default 0.001.
    C_2 : float, optional
        Timestep adjustment at high noise levels, by default 0.008.
    M : int, optional
        Number of discretization steps in the DDPM formulation,
        by default 1000.

    Forward
    -------
    x : torch.Tensor
        Noisy latent state of shape :math:`(B, *)` where :math:`B` is the
        batch size and :math:`*` denotes any number of additional dimensions.
    t : torch.Tensor
        Diffusion time tensor of shape :math:`(B,)`.
    condition : torch.Tensor, TensorDict, or None, optional, default=None
        Single Tensor or a TensorDict containing conditioning tensors with
        batch size :math:`B` matching that of ``x``. Pass ``None`` for an
        unconditional model.
    **model_kwargs : Any
        Additional keyword arguments passed to the underlying model.

    Outputs
    -------
    torch.Tensor
        Preconditioned model output with the same shape as the original model
        output.

    Note
    ----
    Reference: `Improved Denoising Diffusion Probabilistic Models
    <https://arxiv.org/abs/2102.09672>`_

    Examples
    --------
    >>> import torch
    >>> from physicsnemo.core import Module
    >>> # Define a simple model satisfying the diffusion model interface
    >>> class SimpleModel(Module):
    ...     def __init__(self, channels: int):
    ...         super().__init__()
    ...         self.net = torch.nn.Conv2d(channels, channels, 1)
    ...     def forward(self, x, t, condition=None):
    ...         return self.net(x)
    >>> model = SimpleModel(channels=3)
    >>> precond = IDDPMPreconditioner(model, C_1=0.001, C_2=0.008, M=1000)
    >>> x = torch.randn(2, 3, 16, 16)  # batch of 2 images
    >>> t = torch.rand(2)  # diffusion time for each sample
    >>> out = precond(x, t, condition=None)
    >>> out.shape
    torch.Size([2, 3, 16, 16])
    """

    def __init__(
        self,
        model: Module,
        C_1: float = 0.001,
        C_2: float = 0.008,
        M: int = 1000,
    ) -> None:
        super().__init__(model)
        self.register_buffer("C_1", torch.tensor(C_1))
        self.register_buffer("C_2", torch.tensor(C_2))
        self.register_buffer("M", torch.tensor(M))

        # Precompute the noise level schedule u_j, j = 0, ..., M, backwards
        # from u_M = 0. Each step uses the ratio alpha_bar(j-1) / alpha_bar(j),
        # with alpha_bar(j) = sin^2(pi/2 * j / M / (C_2 + 1)); the ratio is
        # clamped from below at C_1.
        u = torch.zeros(M + 1)
        for j in range(M, 0, -1):
            angle_j = 0.5 * math.pi * j / M / (C_2 + 1)
            angle_jm1 = 0.5 * math.pi * (j - 1) / M / (C_2 + 1)
            alpha_bar_j = math.sin(angle_j) ** 2
            alpha_bar_jm1 = math.sin(angle_jm1) ** 2
            alpha_ratio = alpha_bar_jm1 / alpha_bar_j
            u[j - 1] = ((u[j] ** 2 + 1) / max(alpha_ratio, C_1) - 1).sqrt()
        self.register_buffer("u", u)

    def compute_coefficients(
        self, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""
        Compute iDDPM preconditioning coefficients.

        Parameters
        ----------
        t : torch.Tensor
            Diffusion time tensor of shape :math:`(B, 1, ..., 1)`. It is used
            as the noise level, since :meth:`sigma` is not overridden.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
            Preconditioning coefficients (:math:`c_{\text{in}}`,
            :math:`c_{\text{noise}}`, :math:`c_{\text{out}}`,
            :math:`c_{\text{skip}}`) of shape :math:`(B, 1, ..., 1)`.
            :math:`c_{\text{noise}}` is returned as a float32 tensor holding
            integer-valued schedule indices.
        """
        c_skip = torch.ones_like(t)
        c_out = -t
        c_in = 1 / (t**2 + 1).sqrt()

        # Round sigma to the nearest entry of the precomputed schedule u.
        u: torch.Tensor = self.u  # type: ignore[assignment]
        # Exact 1-D distances |sigma_b - u_j| with shape (B, M + 1). Casting
        # to u's dtype avoids dtype-mismatch errors (e.g. float64 or half
        # inputs against the schedule buffer). The explicit abs-difference
        # also avoids the matmul-based Euclidean shortcut torch.cdist may
        # take for large inputs, which can pick the wrong neighbour.
        dist = (t.reshape(-1, 1).to(u.dtype) - u.reshape(1, -1)).abs()
        idx = dist.argmin(dim=1).reshape(t.shape)
        # self.M and idx are integer tensors. Cast to float so the model
        # receives floating-point noise labels (as in the reference EDM
        # implementation) rather than int64 values.
        c_noise = (self.M - 1 - idx).to(torch.float32)
        return c_in, c_noise, c_out, c_skip
class EDMPreconditioner(BaseAffinePreconditioner):
    r"""
    EDM preconditioner.

    Implements the improved preconditioning scheme proposed in the EDM
    paper.
    For EDM, the time-to-noise-level mapping is the identity:
    :math:`\sigma(t) = t`.

    The preconditioning coefficients are:

    .. math::
        c_{\text{skip}} &= \frac{\sigma_{\text{data}}^2}
        {\sigma^2 + \sigma_{\text{data}}^2} \\
        c_{\text{out}} &= \frac{\sigma \cdot \sigma_{\text{data}}}
        {\sqrt{\sigma^2 + \sigma_{\text{data}}^2}} \\
        c_{\text{in}} &= \frac{1}
        {\sqrt{\sigma_{\text{data}}^2 + \sigma^2}} \\
        c_{\text{noise}} &= \frac{\log(\sigma)}{4}

    With these coefficients, the preconditioned model output is expected to be an
    :math:`\mathbf{x}_0`-prediction (clean data estimate). This preconditioner
    is not directly compatible with score-prediction training or other
    prediction targets.
    For training, it is usually paired with
    :class:`~physicsnemo.diffusion.metrics.losses.MSEDSMLoss`
    (``prediction_type="x0"``) and
    :class:`~physicsnemo.diffusion.noise_schedulers.EDMNoiseScheduler`.

    Parameters
    ----------
    model : physicsnemo.Module
        The underlying neural network model to wrap with signature described in
        :class:`BaseAffinePreconditioner`.
    sigma_data : float, optional
        Expected standard deviation of the training data, by default 0.5.

    Forward
    -------
    x : torch.Tensor
        Noisy latent state of shape :math:`(B, *)` where :math:`B` is the
        batch size and :math:`*` denotes any number of additional dimensions.
    t : torch.Tensor
        Diffusion time tensor of shape :math:`(B,)`.
    condition : torch.Tensor, TensorDict, or None, optional, default=None
        Single Tensor or a TensorDict containing conditioning tensors with
        batch size :math:`B` matching that of ``x``. Pass ``None`` for an
        unconditional model.
    **model_kwargs : Any
        Additional keyword arguments passed to the underlying model.

    Outputs
    -------
    torch.Tensor
        Preconditioned model output with the same shape as the original model
        output.

    Note
    ----
    Reference: `Elucidating the Design Space of Diffusion-Based
    Generative Models <https://arxiv.org/abs/2206.00364>`_

    Examples
    --------
    >>> import torch
    >>> from physicsnemo.core import Module
    >>> # Define a simple model satisfying the diffusion model interface
    >>> class SimpleModel(Module):
    ...     def __init__(self, channels: int):
    ...         super().__init__()
    ...         self.net = torch.nn.Conv2d(channels, channels, 1)
    ...     def forward(self, x, t, condition=None):
    ...         return self.net(x)
    >>> model = SimpleModel(channels=3)
    >>> precond = EDMPreconditioner(model, sigma_data=0.5)
    >>> x = torch.randn(2, 3, 16, 16)  # batch of 2 images
    >>> t = torch.rand(2)  # diffusion time for each sample
    >>> out = precond(x, t, condition=None)
    >>> out.shape
    torch.Size([2, 3, 16, 16])
    """

    def __init__(
        self,
        model: Module,
        sigma_data: float = 0.5,
    ) -> None:
        super().__init__(model)
        self.register_buffer("sigma_data", torch.tensor(sigma_data))

    def compute_coefficients(
        self, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""
        Compute EDM preconditioning coefficients.

        Parameters
        ----------
        t : torch.Tensor
            Diffusion time (or noise level, since they are identical for EDM)
            of shape :math:`(B, 1, ..., 1)`.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
            Preconditioning coefficients (:math:`c_{\text{in}}`,
            :math:`c_{\text{noise}}`, :math:`c_{\text{out}}`,
            :math:`c_{\text{skip}}`) of shape :math:`(B, 1, ..., 1)`.
        """
        sd = self.sigma_data
        sd_sq = sd**2
        # sigma^2 + sigma_data^2 is shared by every coefficient. Floating
        # point addition is commutative, so one value serves both orderings
        # used in the formulas.
        total_var = t**2 + sd_sq
        total_std = total_var.sqrt()
        c_skip = sd_sq / total_var
        c_out = t * sd / total_std
        c_in = 1 / total_std
        c_noise = t.log() / 4
        return c_in, c_noise, c_out, c_skip