# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, Tuple
import torch
from jaxtyping import Float
from tensordict import TensorDict
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Replicate
from physicsnemo.core.meta import ModelMetaData
from physicsnemo.core.module import Module
def _ensure_plain_tensor(t: torch.Tensor) -> torch.Tensor:
"""Unwrap fully-replicated DTensors to plain tensors.
FSDP wraps registered buffers as DTensors with ``Replicate`` placement.
Arithmetic in ``compute_coefficients`` propagates that wrapping to the
resulting coefficient tensors. These must be converted back to plain
tensors before element-wise operations with non-DTensor inputs (plain
tensors or ``ShardTensor``).
Only DTensors whose *every* placement is ``Replicate`` are unwrapped;
any tensor with a ``Shard`` placement (e.g. a ``ShardTensor``) is left
untouched so its sharding metadata is preserved.
"""
if isinstance(t, DTensor) and all(isinstance(p, Replicate) for p in t.placements):
return t.to_local()
return t
def _replicate_on_mesh(t: torch.Tensor, mesh) -> torch.Tensor:
    """Wrap a plain tensor as a replicated ``DTensor`` on *mesh*.

    Intended to run after :func:`_ensure_plain_tensor` so that coefficients
    can be combined element-wise with ``ShardTensor`` data in domain-parallel
    training.

    Parameters
    ----------
    t : torch.Tensor
        Tensor to promote.
    mesh
        Device mesh forwarded to ``DTensor.from_local`` as ``device_mesh``.

    Returns
    -------
    torch.Tensor
        ``t`` unchanged when it already is a ``DTensor``; otherwise a
        ``DTensor`` built from ``t`` with a single ``Replicate`` placement.
    """
    already_distributed = isinstance(t, DTensor)
    if not already_distributed:
        t = DTensor.from_local(t, device_mesh=mesh, placements=[Replicate()])
    return t
[docs]
class BaseAffinePreconditioner(Module, ABC):
r"""
Abstract base class for diffusion model preconditioners using an affine
transformation.
This class provides a standardized interface for implementing
preconditioners that use affine transformations of the model
input and output.
The preconditioner wraps a neural network model :math:`F` and applies
a preconditioning formula to transform the network output to produce
the preconditioned output :math:`D(\mathbf{x}, t)` according to:
.. math::
D(\mathbf{x}, t) = c_{\text{skip}}(t) \mathbf{x} +
c_{\text{out}}(t) F(c_{\text{in}}(t) \mathbf{x}, c_{\text{noise}}(t))
where:
- :math:`c_{\text{in}}(t)`: Input scaling coefficient
- :math:`c_{\text{noise}}(t)`: Noise conditioning value
- :math:`c_{\text{out}}(t)`: Output scaling coefficient
- :math:`c_{\text{skip}}(t)`: Skip connection scaling coefficient
and where :math:`\mathbf{x}` is the latent state and :math:`t` is the
diffusion time.
The wrapped model :math:`F` must be an instance of
:class:`~physicsnemo.core.Module` that satisfies the
:class:`~physicsnemo.diffusion.DiffusionModel` interface, with the
following signature:
.. code-block:: python
model(
x: torch.Tensor, # Shape: (B, *)
t: torch.Tensor, # Shape: (B,)
condition: torch.Tensor | TensorDict | None = None,
**model_kwargs: Any,
) -> torch.Tensor # Shape: (B, *)
The preconditioner is agnostic to the prediction target of the wrapped
model :math:`F`. The same preconditioning formula is applied regardless of
whether the model is an :math:`\mathbf{x}_0`-predictor, an
:math:`\epsilon`-predictor, a score predictor, or a
:math:`\mathbf{v}`-predictor.
.. note::
The preconditioner itself also satisfies the
:class:`~physicsnemo.diffusion.DiffusionModel` interface, meaning it
does not change the signature of the wrapped model :math:`F`, and it
can be used anywhere a diffusion model is expected.
Parameters
----------
model : physicsnemo.Module
The underlying neural network model :math:`F` to wrap with the
signature described above.
meta : ModelMetaData, optional
Meta data class for storing info regarding model, by default None.
Subclasses can pass their own metadata.
Forward
-------
x : torch.Tensor
Noisy latent state of shape :math:`(B, *)` where :math:`B` is the
batch size and :math:`*` denotes any number of additional dimensions.
t : torch.Tensor
Diffusion time tensor of shape :math:`(B,)`.
condition : torch.Tensor, TensorDict, or None, optional, default=None
Single Tensor or a TensorDict containing conditioning tensors with
batch size :math:`B` matching that of ``x``. Pass ``None`` for an
unconditional model.
**model_kwargs : Any
Additional keyword arguments passed to the underlying model.
Outputs
-------
torch.Tensor
Preconditioned model output with the same shape as the original model
output.
Note
----
To implement a new preconditioner, a subclass of
:class:`BaseAffinePreconditioner` must be defined, and some methods
have to be implemented:
- Subclasses must implement the :meth:`compute_coefficients` method to
define the specific preconditioning scheme.
- A :meth:`sigma` method can optionally be implemented.
If a subclass implements the :meth:`sigma` method, the diffusion time
:math:`t` is first transformed to a noise level :math:`\sigma(t)`
before being passed to :meth:`compute_coefficients`. This allows
implementing preconditioners for different time-to-noise-level
mappings while keeping the same preconditioning interface, in
particular for
preconditioning schemes based on noise level (that is
:math:`c_{\text{in}}(\sigma)`,
:math:`c_{\text{noise}}(\sigma)`, :math:`c_{\text{out}}(\sigma)`,
:math:`c_{\text{skip}}(\sigma)` instead of :math:`c_{\text{in}}(t)`,
:math:`c_{\text{noise}}(t)`, :math:`c_{\text{out}}(t)`,
:math:`c_{\text{skip}}(t)`).
- The ``forward`` method of the preconditioner *should not* be
overridden.
The argument ``t`` of the preconditioner forward method is always
assumed to be the diffusion time. For preconditioning schemes based
on noise level, the noise level :math:`\sigma(t)` is computed
internally using the :meth:`sigma` method.
Examples
--------
The following example shows how to implement a classical EDM
preconditioner. For EDM, there is no need to implement the :meth:`sigma`
method since :math:`\sigma(t) = t` (noise level and diffusion time are the
same).
We first define a simple model to wrap:
>>> import torch
>>> from tensordict import TensorDict
>>> from physicsnemo.nn import Module
>>> class SimpleModel(Module):
... def __init__(self, channels: int):
... super().__init__()
... self.channels = channels
... self.net = torch.nn.Conv2d(channels, channels, 1)
...
... def forward(self, x, t, condition=None):
... return self.net(x)
Now we define the EDM preconditioner:
>>> from physicsnemo.diffusion.preconditioners import (
... BaseAffinePreconditioner,
... )
>>> class SimpleEDMPreconditioner(BaseAffinePreconditioner):
... def __init__(self, model, sigma_data: float = 0.5):
... super().__init__(model)
... self.sigma_data = sigma_data
...
... def compute_coefficients(self, t: torch.Tensor):
... # For EDM sigma(t) = t, so the argument passed to
... # compute_coefficients is already sigma(t)
... sigma_data = self.sigma_data
... c_skip = sigma_data**2 / (t**2 + sigma_data**2)
... c_out = t * sigma_data / (t**2 + sigma_data**2).sqrt()
... c_in = 1 / (sigma_data**2 + t**2).sqrt()
... c_noise = t.log() / 4
... return c_in, c_noise, c_out, c_skip
...
>>> model = SimpleModel(channels=3)
>>> precond = SimpleEDMPreconditioner(model, sigma_data=0.5)
>>> x = torch.randn(2, 3, 16, 16)
>>> t = torch.rand(2)
>>> condition = TensorDict({}, batch_size=[2])
>>> out = precond(x, t, condition)
>>> out.shape
torch.Size([2, 3, 16, 16])
The following example shows how to override the :meth:`sigma` method to
implement a Variance Exploding (VE) preconditioner where
:math:`\sigma(t) = \sqrt{t}`.
>>> class VEPreconditioner(BaseAffinePreconditioner):
... def __init__(self, model):
... super().__init__(model)
...
... def sigma(self, t: torch.Tensor) -> torch.Tensor:
... # Override sigma for VE time-to-noise-level mapping
... return t.sqrt()
...
... def compute_coefficients(self, sigma: torch.Tensor):
... # Here the argument passed to compute_coefficients is
... # sigma(t) = sqrt(t) due to override of the sigma method
... c_skip = torch.ones_like(sigma)
... c_out = sigma
... c_in = torch.ones_like(sigma)
... c_noise = (0.5 * sigma).log()
... return c_in, c_noise, c_out, c_skip
...
>>> precond_ve = VEPreconditioner(model)
>>> out_ve = precond_ve(x, t, condition)
>>> out_ve.shape
torch.Size([2, 3, 16, 16])
**Wrapping existing models to satisfy the DiffusionModel interface**
Some models in PhysicsNeMo have signatures that differ from the
:class:`~physicsnemo.diffusion.DiffusionModel` interface. Below are
examples showing how to write thin wrappers to make them compatible
with preconditioners, including image-based conditioning via channel
concatenation.
**Example: Wrapping SongUNet**
The :class:`~physicsnemo.models.diffusion_unets.SongUNet` model has
the signature ``forward(x, noise_labels, class_labels, augment_labels)``.
We wrap it to match ``forward(x, t, condition)``, where ``condition``
contains both class labels (1D vector) and an image to concatenate
channel-wise:
>>> from physicsnemo.models.diffusion_unets import SongUNet
>>> from physicsnemo.diffusion import DiffusionModel
>>> from tensordict import TensorDict
>>> class SongUNetWrapper(Module):
... def __init__(self, img_channels, cond_channels, label_dim, **kwargs):
... super().__init__()
... # in_channels = img_channels + cond_channels for concatenation
... self.net = SongUNet(
... in_channels=img_channels + cond_channels,
... out_channels=img_channels,
... label_dim=label_dim,
... **kwargs,
... )
...
... def forward(self, x, t, condition):
... # Concatenate image condition "y" channel-wise to input
... y = condition["y"] # shape: (B, C_cond, H, W)
... x_cat = torch.cat([x, y], dim=1)
... # Extract 1D vector condition for class_labels
... class_labels = condition["class_labels"] # shape: (B, label_dim)
... return self.net(x_cat, noise_labels=t, class_labels=class_labels)
...
>>> wrapped = SongUNetWrapper(
... img_channels=2, cond_channels=1, label_dim=4, img_resolution=8
... )
>>> isinstance(wrapped, DiffusionModel)
True
>>> x = torch.rand(1, 2, 8, 8)
>>> t = torch.rand(1)
>>> condition = TensorDict({
... "y": torch.rand(1, 1, 8, 8), # image condition
... "class_labels": torch.rand(1, 4), # 1D vector condition
... }, batch_size=[1])
>>> out = wrapped(x, t, condition)
>>> out.shape
torch.Size([1, 2, 8, 8])
**Example: Wrapping DiT**
The :class:`~physicsnemo.models.dit.DiT` model has
the signature ``forward(x, t, condition, ...)``. We wrap it to support
both image conditioning (via channel concatenation) and vector
conditioning:
>>> from physicsnemo.models.dit import DiT
>>> class DiTWrapper(Module):
... def __init__(self, img_channels, cond_channels, cond_dim, **kwargs):
... super().__init__()
... # in_channels = img_channels + cond_channels for concatenation
... self.net = DiT(
... in_channels=img_channels + cond_channels,
... out_channels=img_channels,
... condition_dim=cond_dim,
... **kwargs,
... )
...
... def forward(self, x, t, condition):
... # Concatenate image condition "y" channel-wise to input
... y = condition["y"] # shape: (B, C_cond, H, W)
... x_cat = torch.cat([x, y], dim=1)
... # Extract 1D vector condition
... vec = condition["vec"] # shape: (B, cond_dim)
... return self.net(x_cat, t, condition=vec)
...
>>> wrapped_dit = DiTWrapper(
... img_channels=2, cond_channels=1, cond_dim=4,
... input_size=8, patch_size=4, attention_backend="timm",
... )
>>> isinstance(wrapped_dit, DiffusionModel)
True
>>> x = torch.rand(1, 2, 8, 8)
>>> t = torch.rand(1)
>>> condition = TensorDict({
... "y": torch.rand(1, 1, 8, 8), # image condition
... "vec": torch.rand(1, 4), # 1D vector condition
... }, batch_size=[1])
>>> out = wrapped_dit(x, t, condition)
>>> out.shape
torch.Size([1, 2, 8, 8])
**Example: Using ConcatConditionWrapper with a preconditioner**
The pattern in the previous example, where (spatially-varying) conditioning
is concatenated to the noised latent state (and possibly vector conditioning
is also passed as a separate argument) is common across several diffusion
use-cases.
Thus for convenience, we provide the wrapper class :class:`~physicsnemo.diffusion.utils.ConcatConditionWrapper`
to save you the trouble of writing your own wrapper for this common pattern:
>>> from physicsnemo.diffusion.preconditioners import EDMPreconditioner
>>> from physicsnemo.diffusion.utils import ConcatConditionWrapper
>>> base_model = SongUNet(img_resolution=8, in_channels=4, out_channels=3, label_dim=4)
>>> wrapped_model = ConcatConditionWrapper(base_model)
>>> precond = EDMPreconditioner(wrapped_model, sigma_data=0.5)
>>> x = torch.rand(1, 3, 8, 8)
>>> t = torch.rand(1)
>>> condition = TensorDict({
... "cond_concat": torch.rand(1, 1, 8, 8), # image condition
... "cond_vec": torch.rand(1, 4), # vector condition
... }, batch_size=[1])
>>> out = precond(x, t, condition)
>>> out.shape
torch.Size([1, 3, 8, 8])
The same wrapper can be used with :class:`~physicsnemo.models.dit.DiT`
backbones as well, with ``cond_vec`` passed to the model's ``condition`` argument.
"""
def __init__(
    self,
    model: Module,
    meta: ModelMetaData | None = None,
) -> None:
    """Store the wrapped model and the optional metadata on the instance.

    Parameters
    ----------
    model : Module
        Underlying network that :meth:`forward` calls as ``self.model``.
    meta : ModelMetaData or None, optional
        Metadata object kept as ``self.meta``, by default ``None``.
    """
    super().__init__()
    # The two attributes are independent of each other.
    self.model = model
    self.meta = meta
[docs]
@abstractmethod
def compute_coefficients(
    self, t: torch.Tensor, /
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Compute the preconditioning coefficients for a given diffusion time
    :math:`t` or noise level :math:`\sigma`.

    This abstract method must be implemented by subclasses to define
    the specific preconditioning scheme. :meth:`forward` calls it with the
    output of :meth:`sigma`, reshaped to :math:`(B, 1, ..., 1)`, and unpacks
    the result in the order ``(c_in, c_noise, c_out, c_skip)``, so
    implementations must return the four coefficients in exactly that order.

    Parameters
    ----------
    t : torch.Tensor
        Diffusion time (or noise level if :meth:`sigma` is
        implemented) tensor of shape :math:`(B, 1, ..., 1)` where
        :math:`B` is the batch size and the trailing singleton
        dimensions match the spatial dimensions of the latent state
        ``x`` for broadcasting. The argument is positional-only.

    Returns
    -------
    c_in : torch.Tensor
        Input scaling coefficient of shape :math:`(B, 1, ..., 1)`;
        multiplied with ``x`` before the wrapped model is called.
    c_noise : torch.Tensor
        Noise conditioning value of shape :math:`(B, 1, ..., 1)`;
        flattened and passed to the wrapped model as its time argument.
    c_out : torch.Tensor
        Output scaling coefficient of shape :math:`(B, 1, ..., 1)`;
        multiplied with the wrapped model's output.
    c_skip : torch.Tensor
        Skip connection scaling coefficient of shape
        :math:`(B, 1, ..., 1)`; multiplied with ``x`` in the skip path.
    """
    ...
[docs]
def sigma(self, t: torch.Tensor) -> torch.Tensor:
    r"""
    Convert diffusion time :math:`t` into a noise level :math:`\sigma(t)`.

    The base implementation is the identity, :math:`\sigma(t) = t`, and
    returns its argument as-is. Subclasses override this method to plug in
    a different time-to-noise-level mapping; whatever it returns is what
    :meth:`compute_coefficients` receives in place of the raw time ``t``.

    Parameters
    ----------
    t : torch.Tensor
        Diffusion time tensor of shape :math:`(B,)` where
        :math:`B` is the batch size.

    Returns
    -------
    torch.Tensor
        Noise level :math:`\sigma(t)` of shape :math:`(B,)`.
    """
    noise_level = t
    return noise_level
def _validate_input(self, x: torch.Tensor) -> None:
    """Hook for additional, subclass-specific validation of ``x``.

    :meth:`forward` invokes this hook inside its
    ``torch.compiler.is_compiling()`` guard, so the hook does not run while
    ``torch.compile`` is tracing. The base implementation accepts every
    input and does nothing.

    Parameters
    ----------
    x : torch.Tensor
        The latent state passed to :meth:`forward`.
    """
    return None
def forward(
    self,
    x: torch.Tensor,
    t: torch.Tensor,
    condition: torch.Tensor | TensorDict | None = None,
    **model_kwargs: Any,
) -> torch.Tensor:
    r"""
    Evaluate the preconditioned model
    :math:`D(x, t) = c_{\text{skip}} x + c_{\text{out}} F(c_{\text{in}} x,
    c_{\text{noise}})`.

    Parameters
    ----------
    x : torch.Tensor
        Latent state; its first dimension is taken as the batch size ``B``.
    t : torch.Tensor
        Diffusion time; must have shape ``(B,)``.
    condition : torch.Tensor, TensorDict, or None, optional
        Conditioning forwarded to the wrapped model as the ``condition``
        keyword. When ``None`` the keyword is not passed at all.
    **model_kwargs : Any
        Extra keyword arguments forwarded to the wrapped model.

    Returns
    -------
    torch.Tensor
        ``c_skip * x + c_out * F_x`` where ``F_x`` is the wrapped model's
        output.

    Raises
    ------
    ValueError
        If ``t`` does not have shape ``(B,)`` or the batch size of
        ``condition`` differs from ``B``. These checks (and
        :meth:`_validate_input`) are skipped while ``torch.compile`` is
        tracing.
    """
    # Eager-mode-only input validation.
    if not torch.compiler.is_compiling():
        B = x.shape[0]
        if t.shape != (B,):
            raise ValueError(
                f"Expected t to have shape ({B},) matching batch size of "
                f"x, but got {t.shape}."
            )
        if isinstance(condition, TensorDict):
            # An empty batch_size carries no batch dimension to compare.
            if condition.batch_size and condition.batch_size[0] != B:
                raise ValueError(
                    f"Condition TensorDict has batch size {condition.batch_size[0]} "
                    f"but expected {B} to match x."
                )
        elif isinstance(condition, torch.Tensor):
            if condition.shape[0] != B:
                raise ValueError(
                    f"Condition tensor has batch size {condition.shape[0]} "
                    f"but expected {B} to match x."
                )
        self._validate_input(x)

    # Noise level sigma(t), reshaped to (B, 1, ..., 1) so it broadcasts
    # against x, then unwrapped so that arithmetic with the coefficients
    # (sigma may be Replicated in domain-parallel scenarios) is
    # type-compatible.
    broadcast_shape = (-1,) + (1,) * (x.ndim - 1)
    sigma_t = _ensure_plain_tensor(self.sigma(t).reshape(broadcast_shape))

    c_in, c_noise, c_out, c_skip = self.compute_coefficients(sigma_t)

    # FSDP may turn model buffers (e.g. sigma_data) into DTensors, which
    # then propagate through compute_coefficients. Unwrap first; if x lives
    # on a device mesh, re-promote the coefficients to replicated DTensors
    # on that mesh so that coefficient/data arithmetic is type-compatible.
    coefficients = [
        _ensure_plain_tensor(c) for c in (c_in, c_noise, c_out, c_skip)
    ]
    x_mesh = getattr(x, "device_mesh", None)
    if x_mesh is not None:
        coefficients = [_replicate_on_mesh(c, x_mesh) for c in coefficients]
    c_in, c_noise, c_out, c_skip = coefficients

    # The ``condition`` keyword is only supplied when a condition was given.
    call_kwargs = dict(model_kwargs)
    if condition is not None:
        call_kwargs = {"condition": condition, **model_kwargs}
    F_x = self.model(c_in * x, c_noise.flatten(), **call_kwargs)

    return c_skip * x + c_out * F_x
[docs]
class VPPreconditioner(BaseAffinePreconditioner):
r"""
Variance Preserving (VP) preconditioner.
Implements the preconditioning scheme from the VP formulation of
score-based generative models.
The time-to-noise-level mapping is:
.. math::
\sigma(t) = \sqrt{\exp\left(\frac{\beta_d}{2} t^2
+ \beta_{\min} t\right) - 1}
The preconditioning coefficients are:
.. math::
c_{\text{skip}} &= 1 \\
c_{\text{out}} &= -\sigma \\
c_{\text{in}} &= \frac{1}{\sqrt{\sigma^2 + 1}} \\
c_{\text{noise}} &= (M - 1) \cdot \sigma^{-1}(\sigma)
With these coefficients, the preconditioned model output is expected to be
an :math:`\mathbf{x}_0`-prediction (clean data estimate).
This preconditioner is not directly compatible for score-prediction
training or others.
For training, it is usually paired with
:class:`~physicsnemo.diffusion.metrics.losses.MSEDSMLoss`
(``prediction_type="x0"``) and
:class:`~physicsnemo.diffusion.noise_schedulers.VPNoiseScheduler`.
Parameters
----------
model : physicsnemo.Module
The underlying neural network model to wrap with signature described in
:class:`BaseAffinePreconditioner`.
beta_d : float, optional
Extent of the noise level schedule, by default 19.9.
beta_min : float, optional
Initial slope of the noise level schedule, by default 0.1.
M : int, optional
Number of discretization steps in the DDPM formulation,
by default 1000.
Forward
-------
x : torch.Tensor
Noisy latent state of shape :math:`(B, *)` where :math:`B` is the
batch size and :math:`*` denotes any number of additional dimensions.
t : torch.Tensor
Diffusion time tensor of shape :math:`(B,)`.
condition : torch.Tensor, TensorDict, or None, optional, default=None
Single Tensor or a TensorDict containing conditioning tensors with
batch size :math:`B` matching that of ``x``. Pass ``None`` for an
unconditional model.
**model_kwargs : Any
Additional keyword arguments passed to the underlying model.
Outputs
-------
torch.Tensor
Preconditioned model output with the same shape as the original model
output.
Note
----
Reference: `Score-Based Generative Modeling through Stochastic
Differential Equations <https://arxiv.org/abs/2011.13456>`_
Examples
--------
>>> import torch
>>> from physicsnemo.core import Module
>>> # Define a simple model satisfying the diffusion model interface
>>> class SimpleModel(Module):
... def __init__(self, channels: int):
... super().__init__()
... self.net = torch.nn.Conv2d(channels, channels, 1)
... def forward(self, x, t, condition=None):
... return self.net(x)
>>> model = SimpleModel(channels=3)
>>> precond = VPPreconditioner(model, beta_d=19.9, beta_min=0.1, M=1000)
>>> x = torch.randn(2, 3, 16, 16) # batch of 2 images
>>> t = torch.rand(2) # diffusion time for each sample
>>> out = precond(x, t, condition=None)
>>> out.shape
torch.Size([2, 3, 16, 16])
"""
def __init__(
    self,
    model: Module,
    beta_d: float = 19.9,
    beta_min: float = 0.1,
    M: int = 1000,
) -> None:
    """Wrap ``model`` and register the VP schedule constants as buffers.

    Parameters
    ----------
    model : Module
        Underlying network, forwarded to the base-class constructor.
    beta_d : float, optional
        Registered as buffer ``beta_d``, by default 19.9.
    beta_min : float, optional
        Registered as buffer ``beta_min``, by default 0.1.
    M : int, optional
        Registered as buffer ``M``, by default 1000.
    """
    super().__init__(model)
    # Each constant becomes a 0-d tensor buffer, registered in this order.
    schedule_constants = (("beta_d", beta_d), ("beta_min", beta_min), ("M", M))
    for buffer_name, value in schedule_constants:
        self.register_buffer(buffer_name, torch.tensor(value))
[docs]
def sigma(self, t: torch.Tensor) -> torch.Tensor:
    r"""
    Compute :math:`\sigma(t)` for the VP formulation.

    .. math::
        \sigma(t) = \sqrt{\exp\left(\frac{\beta_d}{2} t^2
        + \beta_{\min} t\right) - 1}

    Parameters
    ----------
    t : torch.Tensor
        Diffusion time tensor of shape :math:`(B,)`.

    Returns
    -------
    torch.Tensor
        Noise level :math:`\sigma(t)` of shape :math:`(B,)`.
    """
    exponent = 0.5 * self.beta_d * (t**2) + self.beta_min * t
    # Use expm1 rather than ``exp(x) - 1``: for small diffusion times the
    # exponent is tiny and the explicit subtraction cancels most significant
    # digits in float32 (about 2% error in sigma at t = 1e-5).
    return exponent.expm1().sqrt()
[docs]
def compute_coefficients(
    self, sigma: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Compute VP preconditioning coefficients.

    Parameters
    ----------
    sigma : torch.Tensor
        Noise level tensor of shape :math:`(B, 1, ..., 1)`.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
        Preconditioning coefficients (:math:`c_{\text{in}}`,
        :math:`c_{\text{noise}}`, :math:`c_{\text{out}}`,
        :math:`c_{\text{skip}}`) of shape :math:`(B, 1, ..., 1)`.
    """
    c_skip = torch.ones_like(sigma)
    c_out = -sigma
    c_in = 1 / (sigma**2 + 1).sqrt()
    # Invert the time-to-noise mapping: t = sigma_inv(sigma).
    # log1p(sigma^2) replaces log(1 + sigma^2): for small noise levels
    # 1 + sigma^2 rounds to (almost) 1 in float32, which would corrupt the
    # recovered time by several percent.
    log_term = (sigma**2).log1p()
    t = (
        (self.beta_min**2 + 2 * self.beta_d * log_term).sqrt() - self.beta_min
    ) / self.beta_d
    c_noise = (self.M - 1) * t
    return c_in, c_noise, c_out, c_skip
[docs]
class VEPreconditioner(BaseAffinePreconditioner):
r"""
Variance Exploding (VE) preconditioner.
Implements the preconditioning scheme from the VE formulation of
score-based generative models.
For VE, the time-to-noise-level mapping is the identity:
:math:`\sigma(t) = t`.
The preconditioning coefficients are:
.. math::
c_{\text{skip}} &= 1 \\
c_{\text{out}} &= \sigma \\
c_{\text{in}} &= 1 \\
c_{\text{noise}} &= \log(0.5 \cdot \sigma)
With these coefficients, the preconditioned model output is expected to be an
:math:`\mathbf{x}_0`-prediction (clean data estimate). This preconditioner
is not directly compatible for score-prediction training or others.
For training, it is usually paired with
:class:`~physicsnemo.diffusion.metrics.losses.MSEDSMLoss`
(``prediction_type="x0"``) and
:class:`~physicsnemo.diffusion.noise_schedulers.VENoiseScheduler`.
Parameters
----------
model : physicsnemo.Module
The underlying neural network model to wrap with signature described in
:class:`BaseAffinePreconditioner`.
Forward
-------
x : torch.Tensor
Noisy latent state of shape :math:`(B, *)` where :math:`B` is the
batch size and :math:`*` denotes any number of additional dimensions.
t : torch.Tensor
Diffusion time tensor of shape :math:`(B,)`.
condition : torch.Tensor, TensorDict, or None, optional, default=None
Single Tensor or a TensorDict containing conditioning tensors with
batch size :math:`B` matching that of ``x``. Pass ``None`` for an
unconditional model.
**model_kwargs : Any
Additional keyword arguments passed to the underlying model.
Outputs
-------
torch.Tensor
Preconditioned model output with the same shape as the original model
output.
Note
----
Reference: `Score-Based Generative Modeling through Stochastic
Differential Equations <https://arxiv.org/abs/2011.13456>`_
Examples
--------
>>> import torch
>>> from physicsnemo.core import Module
>>> # Define a simple model satisfying the diffusion model interface
>>> class SimpleModel(Module):
... def __init__(self, channels: int):
... super().__init__()
... self.net = torch.nn.Conv2d(channels, channels, 1)
... def forward(self, x, t, condition=None):
... return self.net(x)
>>> model = SimpleModel(channels=3)
>>> precond = VEPreconditioner(model)
>>> x = torch.randn(2, 3, 16, 16) # batch of 2 images
>>> t = torch.rand(2) # diffusion time for each sample
>>> out = precond(x, t, condition=None)
>>> out.shape
torch.Size([2, 3, 16, 16])
"""
def __init__(self, model: Module) -> None:
    """Wrap ``model``; the VE scheme registers no extra state.

    Parameters
    ----------
    model : Module
        Underlying network, forwarded unchanged to the base-class
        constructor.
    """
    super().__init__(model)
[docs]
def compute_coefficients(
    self, t: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Evaluate the VE preconditioning coefficients.

    Parameters
    ----------
    t : torch.Tensor
        Diffusion time tensor of shape :math:`(B, 1, ..., 1)`.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
        Preconditioning coefficients (:math:`c_{\text{in}}`,
        :math:`c_{\text{noise}}`, :math:`c_{\text{out}}`,
        :math:`c_{\text{skip}}`) of shape :math:`(B, 1, ..., 1)`:
        ``c_in`` and ``c_skip`` are all-ones tensors shaped like ``t``,
        ``c_out`` is ``t`` itself and ``c_noise`` is ``log(0.5 * t)``.
    """
    half_t = 0.5 * t
    c_noise = torch.log(half_t)
    # c_in and c_skip are separate all-ones tensors; c_out aliases t.
    return torch.ones_like(t), c_noise, t, torch.ones_like(t)
[docs]
class IDDPMPreconditioner(BaseAffinePreconditioner):
r"""
Improved DDPM (iDDPM) preconditioner.
Implements the preconditioning scheme from the improved DDPM
formulation.
The preconditioning coefficients are:
.. math::
c_{\text{skip}} &= 1 \\
c_{\text{out}} &= -\sigma \\
c_{\text{in}} &= \frac{1}{\sqrt{\sigma^2 + 1}} \\
c_{\text{noise}} &= M - 1 - \text{argmin}|\sigma - u_j|
where :math:`u_j, j = 0, ..., M` are the precomputed noise levels from
the iDDPM discretization.
With these coefficients, the preconditioned model output is expected to be
an :math:`\mathbf{x}_0`-prediction (clean data estimate). This
preconditioner is not directly compatible for score-prediction training or
others.
For training, it is usually paired with
:class:`~physicsnemo.diffusion.metrics.losses.MSEDSMLoss`
(``prediction_type="x0"``) and
:class:`~physicsnemo.diffusion.noise_schedulers.IDDPMNoiseScheduler`.
Parameters
----------
model : physicsnemo.Module
The underlying neural network model to wrap with signature described in
:class:`BaseAffinePreconditioner`.
C_1 : float, optional
Timestep adjustment at low noise levels, by default 0.001.
C_2 : float, optional
Timestep adjustment at high noise levels, by default 0.008.
M : int, optional
Number of discretization steps in the DDPM formulation,
by default 1000.
Forward
-------
x : torch.Tensor
Noisy latent state of shape :math:`(B, *)` where :math:`B` is the
batch size and :math:`*` denotes any number of additional dimensions.
t : torch.Tensor
Diffusion time tensor of shape :math:`(B,)`.
condition : torch.Tensor, TensorDict, or None, optional, default=None
Single Tensor or a TensorDict containing conditioning tensors with
batch size :math:`B` matching that of ``x``. Pass ``None`` for an
unconditional model.
**model_kwargs : Any
Additional keyword arguments passed to the underlying model.
Outputs
-------
torch.Tensor
Preconditioned model output with the same shape as the original model
output.
Note
----
Reference: `Improved Denoising Diffusion Probabilistic Models
<https://arxiv.org/abs/2102.09672>`_
Examples
--------
>>> import torch
>>> from physicsnemo.core import Module
>>> # Define a simple model satisfying the diffusion model interface
>>> class SimpleModel(Module):
... def __init__(self, channels: int):
... super().__init__()
... self.net = torch.nn.Conv2d(channels, channels, 1)
... def forward(self, x, t, condition=None):
... return self.net(x)
>>> model = SimpleModel(channels=3)
>>> precond = IDDPMPreconditioner(model, C_1=0.001, C_2=0.008, M=1000)
>>> x = torch.randn(2, 3, 16, 16) # batch of 2 images
>>> t = torch.rand(2) # diffusion time for each sample
>>> out = precond(x, t, condition=None)
>>> out.shape
torch.Size([2, 3, 16, 16])
"""
def __init__(
    self,
    model: Module,
    C_1: float = 0.001,
    C_2: float = 0.008,
    M: int = 1000,
) -> None:
    """Wrap ``model`` and precompute the iDDPM noise-level schedule.

    Parameters
    ----------
    model : Module
        Underlying network, forwarded to the base-class constructor.
    C_1 : float, optional
        Lower clamp applied to the ratio of consecutive ``alpha_bar``
        values; also registered as buffer ``C_1``. By default 0.001.
    C_2 : float, optional
        Offset used in the ``alpha_bar`` angle; also registered as buffer
        ``C_2``. By default 0.008.
    M : int, optional
        Number of discretization steps; also registered as buffer ``M``.
        By default 1000.
    """
    super().__init__(model)
    for buffer_name, value in (("C_1", C_1), ("C_2", C_2), ("M", M)):
        self.register_buffer(buffer_name, torch.tensor(value))

    def alpha_bar(step: int) -> float:
        # sin^2 of the step-dependent angle.
        return math.sin(0.5 * math.pi * step / M / (C_2 + 1)) ** 2

    # Noise levels u_j, j = 0, ..., M, filled backwards from u_M = 0:
    # each u_{j-1} is derived from u_j.
    schedule = torch.zeros(M + 1)
    for j in reversed(range(1, M + 1)):
        ratio = alpha_bar(j - 1) / alpha_bar(j)
        schedule[j - 1] = ((schedule[j] ** 2 + 1) / max(ratio, C_1) - 1).sqrt()
    self.register_buffer("u", schedule)
[docs]
def compute_coefficients(
    self, t: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Compute iDDPM preconditioning coefficients.

    Parameters
    ----------
    t : torch.Tensor
        Diffusion time tensor of shape :math:`(B, 1, ..., 1)`.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
        Preconditioning coefficients (:math:`c_{\text{in}}`,
        :math:`c_{\text{noise}}`, :math:`c_{\text{out}}`,
        :math:`c_{\text{skip}}`) of shape :math:`(B, 1, ..., 1)`.
        :math:`c_{\text{noise}}` is an integer tensor equal to
        ``M - 1 - idx`` where ``idx`` is the index of the entry of the
        precomputed schedule ``u`` closest to each element of ``t``.
    """
    c_skip = torch.ones_like(t)
    c_out = -t
    c_in = 1 / (t**2 + 1).sqrt()
    # Round sigma to nearest index in precomputed schedule u
    u: torch.Tensor = self.u  # type: ignore[assignment]
    # torch.cdist needs both operands in one floating dtype. Promote to the
    # wider of the two, but never below float32, so float64 or low-precision
    # inputs work against the float32 schedule while matching-dtype
    # float32/float64 calls behave exactly as before.
    dist_dtype = torch.promote_types(
        torch.promote_types(t.dtype, u.dtype), torch.float32
    )
    t_flat = t.reshape(1, -1, 1).to(dist_dtype)
    u_reshaped = u.reshape(1, -1, 1).to(dist_dtype)
    idx = torch.cdist(t_flat, u_reshaped).argmin(2).reshape(t.shape)
    c_noise = self.M - 1 - idx
    return c_in, c_noise, c_out, c_skip
[docs]
class EDMPreconditioner(BaseAffinePreconditioner):
r"""
EDM preconditioner.
Implements the improved preconditioning scheme proposed in the EDM
paper.
For EDM, the time-to-noise-level mapping is the identity:
:math:`\sigma(t) = t`.
The preconditioning coefficients are:
.. math::
c_{\text{skip}} &= \frac{\sigma_{\text{data}}^2}
{\sigma^2 + \sigma_{\text{data}}^2} \\
c_{\text{out}} &= \frac{\sigma \cdot \sigma_{\text{data}}}
{\sqrt{\sigma^2 + \sigma_{\text{data}}^2}} \\
c_{\text{in}} &= \frac{1}
{\sqrt{\sigma_{\text{data}}^2 + \sigma^2}} \\
c_{\text{noise}} &= \frac{\log(\sigma)}{4}
With these coefficients, the preconditioned model output is expected to be an
:math:`\mathbf{x}_0`-prediction (clean data estimate). This preconditioner
is not directly compatible for score-prediction training or others.
For training, it is usually paired with
:class:`~physicsnemo.diffusion.metrics.losses.MSEDSMLoss`
(``prediction_type="x0"``) and
:class:`~physicsnemo.diffusion.noise_schedulers.EDMNoiseScheduler`.
Parameters
----------
model : physicsnemo.Module
The underlying neural network model to wrap with signature described in
:class:`BaseAffinePreconditioner`.
sigma_data : float or Sequence[float] or torch.Tensor, optional
Expected standard deviation of the training data, by default 0.5.
When a scalar ``float`` is given, the same value is applied to all
channels. When a ``Sequence[float]`` or 1-D ``Tensor`` of length
:math:`C` is given, each output channel receives its own
preconditioning and :math:`c_{\text{skip}}`, :math:`c_{\text{out}}`,
:math:`c_{\text{in}}` become per-channel while
:math:`c_{\text{noise}}` remains scalar. The per-channel form should
be paired with an
:class:`~physicsnemo.diffusion.noise_schedulers.EDMNoiseScheduler`
constructed with the same per-channel ``sigma_data``.
Forward
-------
x : torch.Tensor
Noisy latent state of shape :math:`(B, *)` where :math:`B` is the
batch size and :math:`*` denotes any number of additional dimensions.
t : torch.Tensor
Diffusion time tensor of shape :math:`(B,)`.
condition : torch.Tensor, TensorDict, or None, optional, default=None
Single Tensor or a TensorDict containing conditioning tensors with
batch size :math:`B` matching that of ``x``. Pass ``None`` for an
unconditional model.
**model_kwargs : Any
Additional keyword arguments passed to the underlying model.
Outputs
-------
torch.Tensor
Preconditioned model output with the same shape as the original model
output.
Note
----
Reference: `Elucidating the Design Space of Diffusion-Based
Generative Models <https://arxiv.org/abs/2206.00364>`_
Examples
--------
Scalar ``sigma_data`` (default):
>>> import torch
>>> from physicsnemo.core import Module
>>> # Define a simple model satisfying the diffusion model interface
>>> class SimpleModel(Module):
... def __init__(self, channels: int):
... super().__init__()
... self.net = torch.nn.Conv2d(channels, channels, 1)
... def forward(self, x, t, condition=None):
... return self.net(x)
>>> model = SimpleModel(channels=3)
>>> precond = EDMPreconditioner(model, sigma_data=0.5)
>>> x = torch.randn(2, 3, 16, 16) # batch of 2 images
>>> t = torch.rand(2) # diffusion time for each sample
>>> out = precond(x, t, condition=None)
>>> out.shape
torch.Size([2, 3, 16, 16])
Per-channel ``sigma_data`` for heterogeneous channels (e.g. weather
variables with different scales):
>>> precond_ch = EDMPreconditioner(model, sigma_data=[0.3, 0.5, 0.7])
>>> out_ch = precond_ch(x, t, condition=None)
>>> out_ch.shape
torch.Size([2, 3, 16, 16])
"""
def __init__(
    self,
    model: Module,
    sigma_data: float | Sequence[float] | Float[torch.Tensor, " C"] = 0.5,
) -> None:
    """Wrap ``model`` and register ``sigma_data`` as a buffer.

    Parameters
    ----------
    model : Module
        Underlying network, forwarded to the base-class constructor.
    sigma_data : float or Sequence[float] or torch.Tensor, optional
        Data standard deviation, by default 0.5. A tensor is detached,
        cast to ``float32`` and flattened; a non-string sequence is
        converted to a ``float32`` tensor; anything else goes through
        ``float(...)``. A single-element sequence/tensor is collapsed to a
        0-d (scalar) buffer.

    Raises
    ------
    ValueError
        If ``sigma_data`` is an empty sequence or tensor.
    """
    super().__init__(model)
    if isinstance(sigma_data, torch.Tensor):
        sd = sigma_data.detach().to(dtype=torch.float32).reshape(-1)
    elif isinstance(sigma_data, Sequence) and not isinstance(sigma_data, str):
        sd = torch.as_tensor(list(sigma_data), dtype=torch.float32)
    else:
        sd = torch.as_tensor(float(sigma_data))
    # An empty sigma_data would be classified as "scalar" below while
    # actually having shape (0,), which only surfaces later as an opaque
    # broadcasting error; reject it here with a clear message instead.
    if sd.numel() == 0:
        raise ValueError(
            "sigma_data must contain at least one value, but got an empty "
            "sequence or tensor."
        )
    if sd.ndim > 0 and sd.numel() == 1:
        sd = sd.squeeze()
    self.register_buffer("sigma_data", sd)
    # Store the number of channels for per-channel sigma_data so that
    # validation in forward() can check against x without reading
    # sigma_data.ndim at runtime (which could cause a graph break).
    # 0 means "scalar sigma_data".
    self._sigma_data_channels: int = sd.numel() if sd.ndim > 0 else 0
def _validate_input(self, x: torch.Tensor) -> None:
    """Check the channel count of ``x`` against per-channel ``sigma_data``.

    Nothing is checked when ``sigma_data`` is scalar
    (``_sigma_data_channels == 0``) or when ``x`` has fewer than two
    dimensions.

    Raises
    ------
    ValueError
        If ``x.shape[1]`` differs from the number of ``sigma_data``
        channels.
    """
    C = self._sigma_data_channels
    # No channel axis to compare against.
    if C <= 0 or x.ndim < 2:
        return
    if x.shape[1] == C:
        return
    raise ValueError(
        f"EDMPreconditioner has per-channel sigma_data with "
        f"{C} channels, but input x has {x.shape[1]} channels "
        f"(shape {tuple(x.shape)})."
    )
def _reshape_sigma_data(self, t: torch.Tensor) -> torch.Tensor:
    """Return ``sigma_data`` in a shape that broadcasts against *t*.

    Scalar ``sigma_data`` is returned as-is. Per-channel ``sigma_data`` of
    shape ``(C,)`` is viewed as ``(1, C, 1, ..., 1)`` with ``t.ndim - 2``
    trailing singleton dimensions, matching the spatial dims of *t*.
    """
    if self._sigma_data_channels <= 0:
        return self.sigma_data
    trailing_ones = (1,) * (t.ndim - 2)
    return self.sigma_data.view(1, -1, *trailing_ones)
[docs]
def compute_coefficients(
    self, t: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""
    Evaluate the EDM preconditioning coefficients.

    Parameters
    ----------
    t : torch.Tensor
        Diffusion time (or noise level, since they are identical for EDM)
        of shape :math:`(B, 1, ..., 1)`.

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
        Preconditioning coefficients (:math:`c_{\text{in}}`,
        :math:`c_{\text{noise}}`, :math:`c_{\text{out}}`,
        :math:`c_{\text{skip}}`). When ``sigma_data`` is scalar all
        coefficients have shape :math:`(B, 1, ..., 1)`. When
        ``sigma_data`` is per-channel, :math:`c_{\text{in}}`,
        :math:`c_{\text{out}}`, and :math:`c_{\text{skip}}` have shape
        :math:`(B, C, 1, ..., 1)` while :math:`c_{\text{noise}}`
        remains :math:`(B, 1, ..., 1)`.
    """
    sd = self._reshape_sigma_data(t)
    # sigma^2 + sigma_data^2 and its square root appear in three of the
    # four coefficients.
    total_var = t**2 + sd**2
    total_std = total_var.sqrt()
    c_in = 1 / total_std
    c_out = t * sd / total_std
    c_skip = sd**2 / total_var
    c_noise = t.log() / 4
    return c_in, c_noise, c_out, c_skip