Preconditioners#

Preconditioning is an optional but very common technique that improves the stability and convergence of diffusion model training. The core idea is that the raw inputs and outputs of a neural network span very different scales depending on the noise level \(\sigma(t)\). A preconditioner wraps the backbone with an affine rescaling so that the effective input and output have unit variance across all noise levels, making the learning problem uniformly well-conditioned.
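To make the scale problem concrete, here is a small standalone sketch in plain Python (it does not use the framework). It assumes the common setup in which the noisy input is x = x0 + sigma * n, with data of standard deviation sigma_data and independent unit-variance noise n. The helper names are illustrative only.

```python
import math


def noisy_input_std(sigma: float, sigma_data: float = 0.5) -> float:
    """Std of x = x0 + sigma * n for independent x0 (std sigma_data) and n (std 1)."""
    return math.sqrt(sigma_data**2 + sigma**2)


def rescaled_input_std(sigma: float, sigma_data: float = 0.5) -> float:
    """Std of the same input after multiplying by 1 / sqrt(sigma_data^2 + sigma^2)."""
    c_in = 1.0 / math.sqrt(sigma_data**2 + sigma**2)
    return c_in * noisy_input_std(sigma, sigma_data)


# The raw input scale varies by orders of magnitude; the rescaled one does not.
for sigma in (0.01, 1.0, 80.0):
    raw = noisy_input_std(sigma)
    scaled = rescaled_input_std(sigma)
    print(f"sigma={sigma}: raw std={raw:.3f}, rescaled std={scaled:.3f}")
```

The same idea applies to the output side, where a second coefficient rescales what the network has to predict.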

Three Approaches#

Depending on how much customization you need, there are three ways to use preconditioning in the framework.

1. Ready-to-use preconditioners. The framework ships preconditioners that pair with each built-in noise scheduler. These work out of the box with no additional implementation:

from physicsnemo.diffusion.preconditioners import EDMPreconditioner

precond = EDMPreconditioner(backbone_model, sigma_data=0.5)

2. Subclass the abstract base class. For a custom affine preconditioning scheme, subclass BaseAffinePreconditioner and implement compute_coefficients(). Optionally override sigma() if \(\sigma(t) \neq t\). The forward method should not be overridden.

from physicsnemo.diffusion.preconditioners import BaseAffinePreconditioner

class MyPreconditioner(BaseAffinePreconditioner):
    def compute_coefficients(self, sigma):
        c_skip = 1 / (sigma**2 + 1)
        c_out = sigma / (sigma**2 + 1).sqrt()
        c_in = 1 / (sigma**2 + 1).sqrt()
        c_noise = sigma.log() / 4
        return c_in, c_noise, c_out, c_skip

3. Implement preconditioning directly in a Module. If the affine formula does not fit your use case, you can implement preconditioning directly in a Module that satisfies the DiffusionModel protocol. This gives complete freedom over the preconditioning logic.

from physicsnemo.core import Module

class MyPreconditionedModel(Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, x, t, condition=None):
        # Custom preconditioning logic
        x_scaled = x / (1 + t.view(-1, 1, 1, 1)**2).sqrt()
        out = self.backbone(x_scaled, t, condition)
        return x + t.view(-1, 1, 1, 1) * out

How Preconditioners Fit in the Pipeline#

A preconditioner itself satisfies the DiffusionModel interface, so it can be used anywhere a plain model is expected—in training losses, in the denoiser factory, and in sampling:

from physicsnemo.diffusion.preconditioners import EDMPreconditioner
from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
from physicsnemo.diffusion.metrics.losses import MSEDSMLoss

scheduler = EDMNoiseScheduler()
precond = EDMPreconditioner(backbone_model, sigma_data=0.5)
loss_fn = MSEDSMLoss(precond, scheduler)

# Training: the loss sees `precond` as the model
loss = loss_fn(x0, condition=condition)

# Sampling: the preconditioner is used as the predictor
from functools import partial
x0_predictor = partial(precond, condition=condition)
denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
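
The role of functools.partial in the sampling step can be illustrated without the framework. The sketch below uses a hypothetical scalar toy_model in place of a preconditioner; it only shows how binding condition turns a three-argument model into a two-argument predictor.

```python
from functools import partial


def toy_model(x, t, condition=None):
    """Hypothetical stand-in for a model with the (x, t, condition) signature."""
    offset = 0.0 if condition is None else condition
    return x * (1.0 - t) + offset


# Bind the condition once; the result is called with (x, t) only.
x0_predictor = partial(toy_model, condition=10.0)
prediction = x0_predictor(2.0, 0.5)
```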

Important

All built-in preconditioners are designed so that the preconditioned output is an \(\mathbf{x}_0\)-prediction (clean data estimate). They are intended for use with MSEDSMLoss with prediction_type="x0" (the default).

See the BaseAffinePreconditioner docstring for additional examples, including how to write thin wrappers to adapt backbones with non-standard signatures (for example, SongUNet, DiT) to the DiffusionModel interface.

API Reference#

BaseAffinePreconditioner#

class physicsnemo.diffusion.preconditioners.BaseAffinePreconditioner(*args, **kwargs)[source]#

Bases: Module, ABC

Abstract base class for diffusion model preconditioners using an affine transformation.

This class provides a standardized interface for implementing preconditioners that use affine transformations of the model input and output.

The preconditioner wraps a neural network model \(F\) and combines its output with the input through an affine formula to produce the preconditioned output \(D(\mathbf{x}, t)\):

\[D(\mathbf{x}, t) = c_{\text{skip}}(t) \mathbf{x} + c_{\text{out}}(t) F(c_{\text{in}}(t) \mathbf{x}, c_{\text{noise}}(t))\]

where:

  • \(c_{\text{in}}(t)\): Input scaling coefficient

  • \(c_{\text{noise}}(t)\): Noise conditioning value

  • \(c_{\text{out}}(t)\): Output scaling coefficient

  • \(c_{\text{skip}}(t)\): Skip connection scaling coefficient

and where \(\mathbf{x}\) is the latent state and \(t\) is the diffusion time.
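
As a rough illustration of the formula, the following plain-Python sketch evaluates it for a single sample stored as a list of floats. The function affine_precondition and the toy networks are hypothetical and are not the framework's implementation, which operates on batched tensors.

```python
from typing import Callable, List, Sequence, Tuple

Coeffs = Tuple[float, float, float, float]  # (c_in, c_noise, c_out, c_skip)


def affine_precondition(
    F: Callable[[Sequence[float], float], Sequence[float]],
    x: Sequence[float],
    coeffs: Coeffs,
) -> List[float]:
    """Evaluate D(x, t) = c_skip * x + c_out * F(c_in * x, c_noise) for one sample."""
    c_in, c_noise, c_out, c_skip = coeffs
    f_out = F([c_in * xi for xi in x], c_noise)
    return [c_skip * xi + c_out * fi for xi, fi in zip(x, f_out)]


def zero_net(x_in, c_noise):
    """Toy network returning zeros, so D reduces to the skip path c_skip * x."""
    return [0.0 for _ in x_in]


def identity_net(x_in, c_noise):
    """Toy network returning its (already rescaled) input unchanged."""
    return list(x_in)


d_skip_only = affine_precondition(zero_net, [2.0, -4.0], (1.0, 0.0, 3.0, 0.5))
d_identity = affine_precondition(identity_net, [1.0], (2.0, 0.0, 3.0, 0.5))
```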

The wrapped model \(F\) must be an instance of Module that satisfies the DiffusionModel interface, with the following signature:

model(
    x: torch.Tensor,  # Shape: (B, *)
    t: torch.Tensor,  # Shape: (B,)
    condition: torch.Tensor | TensorDict | None = None,
    **model_kwargs: Any,
) -> torch.Tensor  # Shape: (B, *)

The preconditioner is agnostic to the prediction target of the wrapped model \(F\). The same preconditioning formula is applied regardless of whether the model is an \(\mathbf{x}_0\)-predictor, an \(\epsilon\)-predictor, a score predictor, or a \(\mathbf{v}\)-predictor.

Note

The preconditioner itself also satisfies the DiffusionModel interface, meaning it does not change the signature of the wrapped model \(F\), and it can be used anywhere a diffusion model is expected.

Parameters:
  • model (physicsnemo.Module) – The underlying neural network model \(F\) to wrap with the signature described above.

  • meta (ModelMetaData, optional) – Metadata class storing information about the model, by default None. Subclasses can pass their own metadata.

Forward:
  • x (torch.Tensor) – Noisy latent state of shape \((B, *)\) where \(B\) is the batch size and \(*\) denotes any number of additional dimensions.

  • t (torch.Tensor) – Diffusion time tensor of shape \((B,)\).

  • condition (torch.Tensor, TensorDict, or None, optional, default=None) – Single Tensor or a TensorDict containing conditioning tensors with batch size \(B\) matching that of x. Pass None for an unconditional model.

  • **model_kwargs (Any) – Additional keyword arguments passed to the underlying model.

Outputs:

torch.Tensor – Preconditioned model output with the same shape as the original model output.

Note

To implement a new preconditioner, define a subclass of BaseAffinePreconditioner and follow these rules:

  • Subclasses must implement the compute_coefficients() method to define the specific preconditioning scheme.

  • A sigma() method can optionally be implemented. If a subclass implements the sigma() method, the diffusion time \(t\) is first transformed to a noise level \(\sigma(t)\) before being passed to compute_coefficients(). This allows implementing preconditioners for different time-to-noise-level mappings while keeping the same preconditioning interface, in particular for preconditioning schemes based on noise level (that is \(c_{\text{in}}(\sigma)\), \(c_{\text{noise}}(\sigma)\), \(c_{\text{out}}(\sigma)\), \(c_{\text{skip}}(\sigma)\) instead of \(c_{\text{in}}(t)\), \(c_{\text{noise}}(t)\), \(c_{\text{out}}(t)\), \(c_{\text{skip}}(t)\)).

  • The forward method of the preconditioner should not be overridden.

The argument t of the preconditioner forward method is always assumed to be the diffusion time. For preconditioning schemes based on noise level, the noise level \(\sigma(t)\) is computed internally using the sigma() method.
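
The division of labor described in this note (a fixed forward, plus sigma() and compute_coefficients() as hooks) can be sketched with a scalar toy class. ToyAffinePreconditioner and ToySqrtTime are hypothetical stand-ins written for illustration; they are not the framework classes and they ignore batching and conditioning.

```python
import math
from abc import ABC, abstractmethod


class ToyAffinePreconditioner(ABC):
    """Scalar stand-in: forward is fixed, subclasses supply the two hooks."""

    def sigma(self, t: float) -> float:
        # Default time-to-noise-level mapping: identity.
        return t

    @abstractmethod
    def compute_coefficients(self, sigma: float):
        """Return (c_in, c_noise, c_out, c_skip)."""

    def forward(self, F, x: float, t: float) -> float:
        # 1) map time to noise level, 2) get coefficients, 3) apply the formula.
        c_in, c_noise, c_out, c_skip = self.compute_coefficients(self.sigma(t))
        return c_skip * x + c_out * F(c_in * x, c_noise)


class ToySqrtTime(ToyAffinePreconditioner):
    def sigma(self, t: float) -> float:
        return math.sqrt(t)  # sigma(t) = sqrt(t)

    def compute_coefficients(self, sigma: float):
        return 1.0, math.log(0.5 * sigma), sigma, 1.0


seen_c_noise = []


def recording_net(x_in, c_noise):
    seen_c_noise.append(c_noise)  # record what the network is conditioned on
    return 0.0


# The caller passes the diffusion time t; the hooks see sigma(t) = 2.0.
out = ToySqrtTime().forward(recording_net, x=3.0, t=4.0)
```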

Examples

The following example shows how to implement a classical EDM preconditioner. For EDM, there is no need to implement the sigma() method since \(\sigma(t) = t\) (noise level and diffusion time are the same).

We first define a simple model to wrap:

>>> import torch
>>> from tensordict import TensorDict
>>> from physicsnemo.nn import Module
>>> class SimpleModel(Module):
...     def __init__(self, channels: int):
...         super().__init__()
...         self.channels = channels
...         self.net = torch.nn.Conv2d(channels, channels, 1)
...
...     def forward(self, x, t, condition=None):
...         return self.net(x)

Now we define the EDM preconditioner:

>>> from physicsnemo.diffusion.preconditioners import (
...     BaseAffinePreconditioner,
... )
>>> class SimpleEDMPreconditioner(BaseAffinePreconditioner):
...     def __init__(self, model, sigma_data: float = 0.5):
...         super().__init__(model)
...         self.sigma_data = sigma_data
...
...     def compute_coefficients(self, t: torch.Tensor):
...         # For EDM sigma(t) = t, so the argument passed to
...         # compute_coefficients is already sigma(t)
...         sigma_data = self.sigma_data
...         c_skip = sigma_data**2 / (t**2 + sigma_data**2)
...         c_out = t * sigma_data / (t**2 + sigma_data**2).sqrt()
...         c_in = 1 / (sigma_data**2 + t**2).sqrt()
...         c_noise = t.log() / 4
...         return c_in, c_noise, c_out, c_skip
...
>>> model = SimpleModel(channels=3)
>>> precond = SimpleEDMPreconditioner(model, sigma_data=0.5)
>>> x = torch.randn(2, 3, 16, 16)
>>> t = torch.rand(2)
>>> condition = TensorDict({}, batch_size=[2])
>>> out = precond(x, t, condition)
>>> out.shape
torch.Size([2, 3, 16, 16])

The following example shows how to override the sigma() method to implement a Variance Exploding (VE) preconditioner where \(\sigma(t) = \sqrt{t}\).

>>> class VEPreconditioner(BaseAffinePreconditioner):
...     def __init__(self, model):
...         super().__init__(model)
...
...     def sigma(self, t: torch.Tensor) -> torch.Tensor:
...         # Override sigma for VE time-to-noise-level mapping
...         return t.sqrt()
...
...     def compute_coefficients(self, sigma: torch.Tensor):
...         # Here the argument passed to compute_coefficients is
...         # sigma(t) = sqrt(t) due to override of the sigma method
...         c_skip = torch.ones_like(sigma)
...         c_out = sigma
...         c_in = torch.ones_like(sigma)
...         c_noise = (0.5 * sigma).log()
...         return c_in, c_noise, c_out, c_skip
...
>>> precond_ve = VEPreconditioner(model)
>>> out_ve = precond_ve(x, t, condition)
>>> out_ve.shape
torch.Size([2, 3, 16, 16])

Wrapping existing models to satisfy the DiffusionModel interface

Some models in PhysicsNeMo have signatures that differ from the DiffusionModel interface. Below are examples showing how to write thin wrappers to make them compatible with preconditioners, including image-based conditioning via channel concatenation.

Example: Wrapping SongUNet

The SongUNet model has the signature forward(x, noise_labels, class_labels, augment_labels). We wrap it to match forward(x, t, condition), where condition contains both class labels (1D vector) and an image to concatenate channel-wise:

>>> from physicsnemo.models.diffusion_unets import SongUNet
>>> from physicsnemo.diffusion import DiffusionModel
>>> from tensordict import TensorDict
>>> class SongUNetWrapper(Module):
...     def __init__(self, img_channels, cond_channels, label_dim, **kwargs):
...         super().__init__()
...         # in_channels = img_channels + cond_channels for concatenation
...         self.net = SongUNet(
...             in_channels=img_channels + cond_channels,
...             out_channels=img_channels,
...             label_dim=label_dim,
...             **kwargs,
...         )
...
...     def forward(self, x, t, condition):
...         # Concatenate image condition "y" channel-wise to input
...         y = condition["y"]  # shape: (B, C_cond, H, W)
...         x_cat = torch.cat([x, y], dim=1)
...         # Extract 1D vector condition for class_labels
...         class_labels = condition["class_labels"]  # shape: (B, label_dim)
...         return self.net(x_cat, noise_labels=t, class_labels=class_labels)
...
>>> wrapped = SongUNetWrapper(
...     img_channels=2, cond_channels=1, label_dim=4, img_resolution=8
... )
>>> isinstance(wrapped, DiffusionModel)
True
>>> x = torch.rand(1, 2, 8, 8)
>>> t = torch.rand(1)
>>> condition = TensorDict({
...     "y": torch.rand(1, 1, 8, 8),           # image condition
...     "class_labels": torch.rand(1, 4),      # 1D vector condition
... }, batch_size=[1])
>>> out = wrapped(x, t, condition)
>>> out.shape
torch.Size([1, 2, 8, 8])

Example: Wrapping DiT

The DiT model has the signature forward(x, t, condition, ...). We wrap it to support both image conditioning (via channel concatenation) and vector conditioning:

>>> from physicsnemo.models.dit import DiT
>>> class DiTWrapper(Module):
...     def __init__(self, img_channels, cond_channels, cond_dim, **kwargs):
...         super().__init__()
...         # in_channels = img_channels + cond_channels for concatenation
...         self.net = DiT(
...             in_channels=img_channels + cond_channels,
...             out_channels=img_channels,
...             condition_dim=cond_dim,
...             **kwargs,
...         )
...
...     def forward(self, x, t, condition):
...         # Concatenate image condition "y" channel-wise to input
...         y = condition["y"]  # shape: (B, C_cond, H, W)
...         x_cat = torch.cat([x, y], dim=1)
...         # Extract 1D vector condition
...         vec = condition["vec"]  # shape: (B, cond_dim)
...         return self.net(x_cat, t, condition=vec)
...
>>> wrapped_dit = DiTWrapper(
...     img_channels=2, cond_channels=1, cond_dim=4,
...     input_size=8, patch_size=4, attention_backend="timm",
... )
>>> isinstance(wrapped_dit, DiffusionModel)
True
>>> x = torch.rand(1, 2, 8, 8)
>>> t = torch.rand(1)
>>> condition = TensorDict({
...     "y": torch.rand(1, 1, 8, 8),  # image condition
...     "vec": torch.rand(1, 4),       # 1D vector condition
... }, batch_size=[1])
>>> out = wrapped_dit(x, t, condition)
>>> out.shape
torch.Size([1, 2, 8, 8])

Example: Using ConcatConditionWrapper with a preconditioner

The pattern in the previous examples, where spatially varying conditioning is concatenated channel-wise to the noised latent state (optionally with vector conditioning passed as a separate argument), is common across diffusion use cases. For convenience, the ConcatConditionWrapper class implements this pattern so that you do not have to write your own wrapper:

>>> from physicsnemo.diffusion.preconditioners import EDMPreconditioner
>>> from physicsnemo.diffusion.utils import ConcatConditionWrapper
>>> base_model = SongUNet(img_resolution=8, in_channels=4, out_channels=3, label_dim=4)
>>> wrapped_model = ConcatConditionWrapper(base_model)
>>> precond = EDMPreconditioner(wrapped_model, sigma_data=0.5)
>>> x = torch.rand(1, 3, 8, 8)
>>> t = torch.rand(1)
>>> condition = TensorDict({
...     "cond_concat": torch.rand(1, 1, 8, 8),  # image condition
...     "cond_vec": torch.rand(1, 4),           # vector condition
... }, batch_size=[1])
>>> out = precond(x, t, condition)
>>> out.shape
torch.Size([1, 3, 8, 8])

The same wrapper can be used with DiT backbones as well, with cond_vec passed to the model’s condition argument.

abstractmethod compute_coefficients(t: Tensor, /) → Tuple[Tensor, Tensor, Tensor, Tensor][source]#

Compute the preconditioning coefficients for a given diffusion time \(t\) or noise level \(\sigma\).

This abstract method must be implemented by subclasses to define the specific preconditioning scheme.

Parameters:

t (torch.Tensor) – Diffusion time (or noise level if sigma() is implemented) tensor of shape \((B, 1, ..., 1)\) where \(B\) is the batch size and the trailing singleton dimensions match the spatial dimensions of the latent state x for broadcasting.

Returns:

  • c_in (torch.Tensor) – Input scaling coefficient of shape \((B, 1, ..., 1)\).

  • c_noise (torch.Tensor) – Noise conditioning value of shape \((B, 1, ..., 1)\).

  • c_out (torch.Tensor) – Output scaling coefficient of shape \((B, 1, ..., 1)\).

  • c_skip (torch.Tensor) – Skip connection scaling coefficient of shape \((B, 1, ..., 1)\).
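
The \((B, 1, ..., 1)\) shape convention can be made concrete with a tiny helper. This is a plain-Python sketch on shape tuples only; coefficient_shape is a hypothetical name, not a framework function.

```python
def coefficient_shape(x_shape: tuple) -> tuple:
    """Shape (B, 1, ..., 1) that broadcasts against a latent state of shape x_shape."""
    return (x_shape[0],) + (1,) * (len(x_shape) - 1)


# A batch of 2 images with 3 channels needs coefficients of shape (2, 1, 1, 1).
image_coeff_shape = coefficient_shape((2, 3, 16, 16))
# A batch of 5 flat vectors needs coefficients of shape (5, 1).
vector_coeff_shape = coefficient_shape((5, 128))
```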

sigma(t: Tensor) → Tensor[source]#

Map diffusion time \(t\) to noise level \(\sigma(t)\).

By default, this is the identity function \(\sigma(t) = t\). Subclasses can override this to implement preconditioners for different time-to-noise-level mappings.

When overridden, the output of this method is passed to compute_coefficients() instead of the raw time t.

Parameters:

t (torch.Tensor) – Diffusion time tensor of shape \((B,)\) where \(B\) is the batch size.

Returns:

Noise level \(\sigma(t)\) of shape \((B,)\).

Return type:

torch.Tensor

EDMPreconditioner#

class physicsnemo.diffusion.preconditioners.EDMPreconditioner(*args, **kwargs)[source]#

Bases: BaseAffinePreconditioner

EDM preconditioner.

Implements the improved preconditioning scheme proposed in the EDM paper.

For EDM, the time-to-noise-level mapping is the identity: \(\sigma(t) = t\).

The preconditioning coefficients are:

\[\begin{split}c_{\text{skip}} &= \frac{\sigma_{\text{data}}^2} {\sigma^2 + \sigma_{\text{data}}^2} \\ c_{\text{out}} &= \frac{\sigma \cdot \sigma_{\text{data}}} {\sqrt{\sigma^2 + \sigma_{\text{data}}^2}} \\ c_{\text{in}} &= \frac{1} {\sqrt{\sigma_{\text{data}}^2 + \sigma^2}} \\ c_{\text{noise}} &= \frac{\log(\sigma)}{4}\end{split}\]

With these coefficients, the preconditioned model output is expected to be an \(\mathbf{x}_0\)-prediction (clean data estimate). This preconditioner is not directly compatible with score-prediction training or other prediction targets.

For training, it is usually paired with MSEDSMLoss (prediction_type="x0") and EDMNoiseScheduler.
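
The coefficient formulas above can be checked numerically with a few lines of plain Python. The sketch assumes the noisy input is x = x0 + sigma * n with independent x0 (variance sigma_data^2) and n (unit variance); under that assumption the regression target seen by the wrapped network, (x0 - c_skip * x) / c_out, has unit variance at every noise level. The function names are illustrative, not framework API.

```python
import math


def edm_coefficients(sigma: float, sigma_data: float = 0.5):
    """Return (c_in, c_noise, c_out, c_skip) from the EDM formulas."""
    total = sigma**2 + sigma_data**2
    c_skip = sigma_data**2 / total
    c_out = sigma * sigma_data / math.sqrt(total)
    c_in = 1.0 / math.sqrt(total)
    c_noise = math.log(sigma) / 4.0
    return c_in, c_noise, c_out, c_skip


def network_target_variance(sigma: float, sigma_data: float = 0.5) -> float:
    """Variance of (x0 - c_skip * x) / c_out for x = x0 + sigma * n."""
    _, _, c_out, c_skip = edm_coefficients(sigma, sigma_data)
    # x0 - c_skip * x = (1 - c_skip) * x0 - c_skip * sigma * n
    var_residual = (1.0 - c_skip) ** 2 * sigma_data**2 + c_skip**2 * sigma**2
    return var_residual / c_out**2


for sigma in (0.002, 0.5, 80.0):
    print(f"sigma={sigma}: target variance={network_target_variance(sigma):.6f}")
```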

Parameters:
  • model (physicsnemo.Module) – The underlying neural network model to wrap with signature described in BaseAffinePreconditioner.

  • sigma_data (float, optional) – Expected standard deviation of the training data, by default 0.5.

Forward:
  • x (torch.Tensor) – Noisy latent state of shape \((B, *)\) where \(B\) is the batch size and \(*\) denotes any number of additional dimensions.

  • t (torch.Tensor) – Diffusion time tensor of shape \((B,)\).

  • condition (torch.Tensor, TensorDict, or None, optional, default=None) – Single Tensor or a TensorDict containing conditioning tensors with batch size \(B\) matching that of x. Pass None for an unconditional model.

  • **model_kwargs (Any) – Additional keyword arguments passed to the underlying model.

Outputs:

torch.Tensor – Preconditioned model output with the same shape as the original model output.

Examples

>>> import torch
>>> from physicsnemo.core import Module
>>> # Define a simple model satisfying the diffusion model interface
>>> class SimpleModel(Module):
...     def __init__(self, channels: int):
...         super().__init__()
...         self.net = torch.nn.Conv2d(channels, channels, 1)
...     def forward(self, x, t, condition=None):
...         return self.net(x)
>>> model = SimpleModel(channels=3)
>>> precond = EDMPreconditioner(model, sigma_data=0.5)
>>> x = torch.randn(2, 3, 16, 16)  # batch of 2 images
>>> t = torch.rand(2)              # diffusion time for each sample
>>> out = precond(x, t, condition=None)
>>> out.shape
torch.Size([2, 3, 16, 16])

compute_coefficients(t: Tensor) → Tuple[Tensor, Tensor, Tensor, Tensor][source]#

Compute EDM preconditioning coefficients.

Parameters:

t (torch.Tensor) – Diffusion time (or noise level, since they are identical for EDM) of shape \((B, 1, ..., 1)\).

Returns:

Preconditioning coefficients (\(c_{\text{in}}\), \(c_{\text{noise}}\), \(c_{\text{out}}\), \(c_{\text{skip}}\)) of shape \((B, 1, ..., 1)\).

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

VEPreconditioner#

class physicsnemo.diffusion.preconditioners.VEPreconditioner(*args, **kwargs)[source]#

Bases: BaseAffinePreconditioner

Variance Exploding (VE) preconditioner.

Implements the preconditioning scheme from the VE formulation of score-based generative models.

For VE, the time-to-noise-level mapping is the identity: \(\sigma(t) = t\).

The preconditioning coefficients are:

\[\begin{split}c_{\text{skip}} &= 1 \\ c_{\text{out}} &= \sigma \\ c_{\text{in}} &= 1 \\ c_{\text{noise}} &= \log(0.5 \cdot \sigma)\end{split}\]

With these coefficients, the preconditioned model output is expected to be an \(\mathbf{x}_0\)-prediction (clean data estimate). This preconditioner is not directly compatible with score-prediction training or other prediction targets.

For training, it is usually paired with MSEDSMLoss (prediction_type="x0") and VENoiseScheduler.
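
A short scalar sketch shows why these coefficients yield an \(\mathbf{x}_0\)-prediction. It assumes the forward process x = x0 + sigma * n; under that assumption, a network that outputs exactly the negated noise makes the preconditioned output equal to x0. The names below are illustrative and the "ideal network output" is a thought experiment, not a trained model.

```python
import math


def ve_coefficients(sigma: float):
    """Return (c_in, c_noise, c_out, c_skip) from the VE formulas."""
    c_skip = 1.0
    c_out = sigma
    c_in = 1.0
    c_noise = math.log(0.5 * sigma)
    return c_in, c_noise, c_out, c_skip


x0, noise, sigma = 1.5, -0.4, 2.0
x = x0 + sigma * noise  # noisy sample

c_in, c_noise, c_out, c_skip = ve_coefficients(sigma)
ideal_network_output = -noise  # what F would output if it were perfect
x0_estimate = c_skip * x + c_out * ideal_network_output
```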

Parameters:

model (physicsnemo.Module) – The underlying neural network model to wrap with signature described in BaseAffinePreconditioner.

Forward:
  • x (torch.Tensor) – Noisy latent state of shape \((B, *)\) where \(B\) is the batch size and \(*\) denotes any number of additional dimensions.

  • t (torch.Tensor) – Diffusion time tensor of shape \((B,)\).

  • condition (torch.Tensor, TensorDict, or None, optional, default=None) – Single Tensor or a TensorDict containing conditioning tensors with batch size \(B\) matching that of x. Pass None for an unconditional model.

  • **model_kwargs (Any) – Additional keyword arguments passed to the underlying model.

Outputs:

torch.Tensor – Preconditioned model output with the same shape as the original model output.

Examples

>>> import torch
>>> from physicsnemo.core import Module
>>> # Define a simple model satisfying the diffusion model interface
>>> class SimpleModel(Module):
...     def __init__(self, channels: int):
...         super().__init__()
...         self.net = torch.nn.Conv2d(channels, channels, 1)
...     def forward(self, x, t, condition=None):
...         return self.net(x)
>>> model = SimpleModel(channels=3)
>>> precond = VEPreconditioner(model)
>>> x = torch.randn(2, 3, 16, 16)  # batch of 2 images
>>> t = torch.rand(2)              # diffusion time for each sample
>>> out = precond(x, t, condition=None)
>>> out.shape
torch.Size([2, 3, 16, 16])

compute_coefficients(t: Tensor) → Tuple[Tensor, Tensor, Tensor, Tensor][source]#

Compute VE preconditioning coefficients.

Parameters:

t (torch.Tensor) – Diffusion time tensor of shape \((B, 1, ..., 1)\).

Returns:

Preconditioning coefficients (\(c_{\text{in}}\), \(c_{\text{noise}}\), \(c_{\text{out}}\), \(c_{\text{skip}}\)) of shape \((B, 1, ..., 1)\).

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

VPPreconditioner#

class physicsnemo.diffusion.preconditioners.VPPreconditioner(*args, **kwargs)[source]#

Bases: BaseAffinePreconditioner

Variance Preserving (VP) preconditioner.

Implements the preconditioning scheme from the VP formulation of score-based generative models.

The time-to-noise-level mapping is:

\[\sigma(t) = \sqrt{\exp\left(\frac{\beta_d}{2} t^2 + \beta_{\min} t\right) - 1}\]

The preconditioning coefficients are:

\[\begin{split}c_{\text{skip}} &= 1 \\ c_{\text{out}} &= -\sigma \\ c_{\text{in}} &= \frac{1}{\sqrt{\sigma^2 + 1}} \\ c_{\text{noise}} &= (M - 1) \cdot \sigma^{-1}(\sigma)\end{split}\]

With these coefficients, the preconditioned model output is expected to be an \(\mathbf{x}_0\)-prediction (clean data estimate). This preconditioner is not directly compatible with score-prediction training or other prediction targets.

For training, it is usually paired with MSEDSMLoss (prediction_type="x0") and VPNoiseScheduler.
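
The mapping \(\sigma(t)\) and the inverse \(\sigma^{-1}\) used in \(c_{\text{noise}}\) can be written out in plain Python. The inverse below is obtained by solving the quadratic \(\frac{\beta_d}{2} t^2 + \beta_{\min} t = \log(1 + \sigma^2)\) for its positive root; the function names are illustrative and this scalar sketch is not the framework's implementation.

```python
import math


def vp_sigma(t: float, beta_d: float = 19.9, beta_min: float = 0.1) -> float:
    """sigma(t) = sqrt(exp(beta_d / 2 * t^2 + beta_min * t) - 1)."""
    return math.sqrt(math.exp(0.5 * beta_d * t**2 + beta_min * t) - 1.0)


def vp_sigma_inv(sigma: float, beta_d: float = 19.9, beta_min: float = 0.1) -> float:
    """Positive root t of beta_d / 2 * t^2 + beta_min * t = log(1 + sigma^2)."""
    disc = beta_min**2 + 2.0 * beta_d * math.log(1.0 + sigma**2)
    return (math.sqrt(disc) - beta_min) / beta_d


def vp_coefficients(sigma: float, M: int = 1000,
                    beta_d: float = 19.9, beta_min: float = 0.1):
    """Return (c_in, c_noise, c_out, c_skip) from the VP formulas."""
    c_skip = 1.0
    c_out = -sigma
    c_in = 1.0 / math.sqrt(sigma**2 + 1.0)
    c_noise = (M - 1) * vp_sigma_inv(sigma, beta_d, beta_min)
    return c_in, c_noise, c_out, c_skip


t = 0.3
sigma = vp_sigma(t)
c_in, c_noise, c_out, c_skip = vp_coefficients(sigma)
```

Because \(\sigma^{-1}(\sigma(t)) = t\), the noise conditioning value is simply the diffusion time rescaled to the range \([0, M - 1]\).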

Parameters:
  • model (physicsnemo.Module) – The underlying neural network model to wrap with signature described in BaseAffinePreconditioner.

  • beta_d (float, optional) – Extent of the noise level schedule, by default 19.9.

  • beta_min (float, optional) – Initial slope of the noise level schedule, by default 0.1.

  • M (int, optional) – Number of discretization steps in the DDPM formulation, by default 1000.

Forward:
  • x (torch.Tensor) – Noisy latent state of shape \((B, *)\) where \(B\) is the batch size and \(*\) denotes any number of additional dimensions.

  • t (torch.Tensor) – Diffusion time tensor of shape \((B,)\).

  • condition (torch.Tensor, TensorDict, or None, optional, default=None) – Single Tensor or a TensorDict containing conditioning tensors with batch size \(B\) matching that of x. Pass None for an unconditional model.

  • **model_kwargs (Any) – Additional keyword arguments passed to the underlying model.

Outputs:

torch.Tensor – Preconditioned model output with the same shape as the original model output.

Examples

>>> import torch
>>> from physicsnemo.core import Module
>>> # Define a simple model satisfying the diffusion model interface
>>> class SimpleModel(Module):
...     def __init__(self, channels: int):
...         super().__init__()
...         self.net = torch.nn.Conv2d(channels, channels, 1)
...     def forward(self, x, t, condition=None):
...         return self.net(x)
>>> model = SimpleModel(channels=3)
>>> precond = VPPreconditioner(model, beta_d=19.9, beta_min=0.1, M=1000)
>>> x = torch.randn(2, 3, 16, 16)  # batch of 2 images
>>> t = torch.rand(2)              # diffusion time for each sample
>>> out = precond(x, t, condition=None)
>>> out.shape
torch.Size([2, 3, 16, 16])

compute_coefficients(sigma: Tensor) → Tuple[Tensor, Tensor, Tensor, Tensor][source]#

Compute VP preconditioning coefficients.

Parameters:

sigma (torch.Tensor) – Noise level tensor of shape \((B, 1, ..., 1)\).

Returns:

Preconditioning coefficients (\(c_{\text{in}}\), \(c_{\text{noise}}\), \(c_{\text{out}}\), \(c_{\text{skip}}\)) of shape \((B, 1, ..., 1)\).

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]

sigma(t: Tensor) → Tensor[source]#

Compute \(\sigma(t)\) for the VP formulation.

Parameters:

t (torch.Tensor) – Diffusion time tensor of shape \((B,)\).

Returns:

Noise level \(\sigma(t)\) of shape \((B,)\).

Return type:

torch.Tensor

IDDPMPreconditioner#

class physicsnemo.diffusion.preconditioners.IDDPMPreconditioner(*args, **kwargs)[source]#

Bases: BaseAffinePreconditioner

Improved DDPM (iDDPM) preconditioner.

Implements the preconditioning scheme from the improved DDPM formulation.

The preconditioning coefficients are:

\[\begin{split}c_{\text{skip}} &= 1 \\ c_{\text{out}} &= -\sigma \\ c_{\text{in}} &= \frac{1}{\sqrt{\sigma^2 + 1}} \\ c_{\text{noise}} &= M - 1 - \underset{j}{\text{argmin}}\,|\sigma - u_j|\end{split}\]

where \(u_j, j = 0, ..., M\) are the precomputed noise levels from the iDDPM discretization.

With these coefficients, the preconditioned model output is expected to be an \(\mathbf{x}_0\)-prediction (clean data estimate). This preconditioner is not directly compatible with score-prediction training or other prediction targets.

For training, it is usually paired with MSEDSMLoss (prediction_type="x0") and IDDPMNoiseScheduler.
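
The page does not spell out how the noise levels \(u_j\) are obtained. The plain-Python sketch below uses the recurrence given for iDDPM in the EDM paper, with \(u_M = 0\) and a cosine-style \(\bar{\alpha}_j\); whether IDDPMPreconditioner precomputes its table in exactly this way is an assumption. The function names are illustrative.

```python
import math


def iddpm_noise_levels(M: int = 1000, C_1: float = 0.001, C_2: float = 0.008):
    """Noise levels u_0 > u_1 > ... > u_M = 0 (recurrence assumed from the EDM paper)."""

    def alpha_bar(j: int) -> float:
        return math.sin(0.5 * math.pi * j / (M * (C_2 + 1.0))) ** 2

    u = [0.0] * (M + 1)
    for j in range(M, 0, -1):
        # The ratio is clipped from below by C_1 to bound the step at low j.
        ratio = max(alpha_bar(j - 1) / alpha_bar(j), C_1)
        u[j - 1] = math.sqrt((u[j] ** 2 + 1.0) / ratio - 1.0)
    return u


def iddpm_c_noise(sigma: float, u) -> int:
    """c_noise = M - 1 - argmin_j |sigma - u_j|."""
    M = len(u) - 1
    nearest = min(range(len(u)), key=lambda j: abs(sigma - u[j]))
    return M - 1 - nearest


# A small table (M = 10) keeps the example easy to inspect.
u = iddpm_noise_levels(M=10)
c_noise = iddpm_c_noise(u[3], u)
```

The lookup maps a continuous noise level to the index of the nearest discrete level, which is then flipped so that larger noise corresponds to a larger conditioning value.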

Parameters:
  • model (physicsnemo.Module) – The underlying neural network model to wrap with signature described in BaseAffinePreconditioner.

  • C_1 (float, optional) – Timestep adjustment at low noise levels, by default 0.001.

  • C_2 (float, optional) – Timestep adjustment at high noise levels, by default 0.008.

  • M (int, optional) – Number of discretization steps in the DDPM formulation, by default 1000.

Forward:
  • x (torch.Tensor) – Noisy latent state of shape \((B, *)\) where \(B\) is the batch size and \(*\) denotes any number of additional dimensions.

  • t (torch.Tensor) – Diffusion time tensor of shape \((B,)\).

  • condition (torch.Tensor, TensorDict, or None, optional, default=None) – Single Tensor or a TensorDict containing conditioning tensors with batch size \(B\) matching that of x. Pass None for an unconditional model.

  • **model_kwargs (Any) – Additional keyword arguments passed to the underlying model.

Outputs:

torch.Tensor – Preconditioned model output with the same shape as the original model output.

Examples

>>> import torch
>>> from physicsnemo.core import Module
>>> # Define a simple model satisfying the diffusion model interface
>>> class SimpleModel(Module):
...     def __init__(self, channels: int):
...         super().__init__()
...         self.net = torch.nn.Conv2d(channels, channels, 1)
...     def forward(self, x, t, condition=None):
...         return self.net(x)
>>> model = SimpleModel(channels=3)
>>> precond = IDDPMPreconditioner(model, C_1=0.001, C_2=0.008, M=1000)
>>> x = torch.randn(2, 3, 16, 16)  # batch of 2 images
>>> t = torch.rand(2)              # diffusion time for each sample
>>> out = precond(x, t, condition=None)
>>> out.shape
torch.Size([2, 3, 16, 16])

compute_coefficients(t: Tensor) → Tuple[Tensor, Tensor, Tensor, Tensor][source]#

Compute iDDPM preconditioning coefficients.

Parameters:

t (torch.Tensor) – Diffusion time tensor of shape \((B, 1, ..., 1)\).

Returns:

Preconditioning coefficients (\(c_{\text{in}}\), \(c_{\text{noise}}\), \(c_{\text{out}}\), \(c_{\text{skip}}\)) of shape \((B, 1, ..., 1)\).

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]