Preconditioners
Preconditioning is an optional but very common technique that improves the stability and convergence of diffusion model training. The core idea is that the raw inputs and outputs of a neural network span very different scales depending on the noise level \(\sigma(t)\). A preconditioner wraps the backbone with an affine rescaling so that the effective input and output have unit variance across all noise levels, making the learning problem uniformly well-conditioned.
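As a rough numerical illustration of the unit-variance claim, the sketch below uses only the Python standard library. It assumes the EDM-style input scaling \(c_{\text{in}} = 1 / \sqrt{\sigma_{\text{data}}^2 + \sigma^2}\) documented for EDMPreconditioner in the API reference, scalar Gaussian data with standard deviation sigma_data, and a hypothetical helper named scaled_input_variance that is not part of the framework.

```python
import random
import statistics


def scaled_input_variance(sigma, sigma_data=0.5, n=20000, seed=0):
    """Empirical variance of c_in * (x0 + sigma * eps) for scalar Gaussian data."""
    rng = random.Random(seed)
    c_in = 1.0 / (sigma_data**2 + sigma**2) ** 0.5
    samples = []
    for _ in range(n):
        x0 = rng.gauss(0.0, sigma_data)  # clean sample
        eps = rng.gauss(0.0, 1.0)        # unit Gaussian noise
        samples.append(c_in * (x0 + sigma * eps))
    return statistics.pvariance(samples)


for sigma in (0.01, 1.0, 80.0):
    raw_var = 0.5**2 + sigma**2  # variance of the unscaled noisy input
    scaled_var = scaled_input_variance(sigma)
    print(f"sigma={sigma}: raw variance={raw_var:.4g}, scaled variance={scaled_var:.3f}")
```

The raw variance ranges over several orders of magnitude, while the scaled variance stays close to 1 at every noise level.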
Three Approaches
Depending on how much customization you need, there are three ways to use preconditioning in the framework.
1. Ready-to-use preconditioners. The framework ships preconditioners that pair with each built-in noise scheduler. These work out of the box with no additional implementation:
from physicsnemo.diffusion.preconditioners import EDMPreconditioner
precond = EDMPreconditioner(backbone_model, sigma_data=0.5)
2. Subclass the abstract base class. For a custom affine preconditioning scheme, subclass BaseAffinePreconditioner and implement compute_coefficients(). Optionally override sigma() if \(\sigma(t) \neq t\). The forward method should not be overridden.
from physicsnemo.diffusion.preconditioners import BaseAffinePreconditioner

class MyPreconditioner(BaseAffinePreconditioner):
    def compute_coefficients(self, sigma):
        # sigma has shape (B, 1, ..., 1) and broadcasts against x
        c_skip = 1 / (sigma**2 + 1)
        c_out = sigma / (sigma**2 + 1).sqrt()
        c_in = 1 / (sigma**2 + 1).sqrt()
        c_noise = sigma.log() / 4
        return c_in, c_noise, c_out, c_skip
3. Implement preconditioning directly in a Module. If the affine formula does not fit your use case, you can implement preconditioning directly in a Module that satisfies the DiffusionModel protocol. This gives complete freedom over the preconditioning logic.
from physicsnemo.core import Module

class MyPreconditionedModel(Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone

    def forward(self, x, t, condition=None):
        # Custom preconditioning logic; the view assumes 4D inputs (B, C, H, W)
        x_scaled = x / (1 + t.view(-1, 1, 1, 1)**2).sqrt()
        out = self.backbone(x_scaled, t, condition)
        return x + t.view(-1, 1, 1, 1) * out
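As a sanity check on approach 2, the coefficients used in MyPreconditioner appear to coincide with the EDM coefficients listed in the API reference when \(\sigma_{\text{data}} = 1\). The sketch below compares the two with plain floats instead of tensors; edm_coefficients and my_coefficients are illustrative names, not framework functions.

```python
import math


def edm_coefficients(sigma, sigma_data):
    """EDM coefficients as documented for EDMPreconditioner (scalar version)."""
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / math.sqrt(sigma**2 + sigma_data**2)
    c_in = 1 / math.sqrt(sigma_data**2 + sigma**2)
    c_noise = math.log(sigma) / 4
    return c_in, c_noise, c_out, c_skip


def my_coefficients(sigma):
    """Scalar version of MyPreconditioner.compute_coefficients."""
    c_skip = 1 / (sigma**2 + 1)
    c_out = sigma / math.sqrt(sigma**2 + 1)
    c_in = 1 / math.sqrt(sigma**2 + 1)
    c_noise = math.log(sigma) / 4
    return c_in, c_noise, c_out, c_skip


for sigma in (0.1, 1.0, 10.0):
    custom = my_coefficients(sigma)
    edm = edm_coefficients(sigma, sigma_data=1.0)
    matches = all(math.isclose(a, b) for a, b in zip(custom, edm))
    print(f"sigma={sigma}: matches EDM with sigma_data=1 -> {matches}")
```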
How Preconditioners Fit in the Pipeline
A preconditioner itself satisfies the DiffusionModel interface, so it can be used anywhere a plain model is expected, including training losses, the denoiser factory, and sampling:
from physicsnemo.diffusion.preconditioners import EDMPreconditioner
from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
from physicsnemo.diffusion.metrics.losses import MSEDSMLoss
scheduler = EDMNoiseScheduler()
precond = EDMPreconditioner(backbone_model, sigma_data=0.5)
loss_fn = MSEDSMLoss(precond, scheduler)
# Training: the loss sees `precond` as the model
loss = loss_fn(x0, condition=condition)
# Sampling: the preconditioner is used as the predictor
from functools import partial
x0_predictor = partial(precond, condition=condition)
denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
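The sampling step above relies on functools.partial to bind the condition argument, leaving a callable that takes only the state and the time. The sketch below shows that mechanism in isolation with a hypothetical stand-in toy_model; it assumes, based on the snippet above, that the predictor is then called as x0_predictor(x, t).

```python
from functools import partial


def toy_model(x, t, condition=None):
    """Stand-in for a preconditioned model with signature (x, t, condition)."""
    shift = 0.0 if condition is None else condition
    return [xi - t + shift for xi in x]


# Bind the condition once; the result only needs (x, t).
x0_predictor = partial(toy_model, condition=10.0)

out = x0_predictor([1.0, 2.0], 0.5)
print(out)  # [10.5, 11.5]
```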
Important
All built-in preconditioners are designed so that the preconditioned output is an \(\mathbf{x}_0\)-prediction (clean data estimate). They are intended for use with MSEDSMLoss with prediction_type="x0" (the default).
See the BaseAffinePreconditioner docstring for additional examples, including how to write thin wrappers to adapt backbones with non-standard signatures (for example, SongUNet, DiT) to the DiffusionModel interface.
API Reference
BaseAffinePreconditioner
- class physicsnemo.diffusion.preconditioners.BaseAffinePreconditioner(*args, **kwargs)
Bases: Module, ABC
Abstract base class for diffusion model preconditioners using an affine transformation.
This class provides a standardized interface for implementing preconditioners that use affine transformations of the model input and output.
The preconditioner wraps a neural network model \(F\) and applies a preconditioning formula to transform the network output to produce the preconditioned output \(D(\mathbf{x}, t)\) according to:
\[D(\mathbf{x}, t) = c_{\text{skip}}(t) \mathbf{x} + c_{\text{out}}(t) F(c_{\text{in}}(t) \mathbf{x}, c_{\text{noise}}(t))\]
where:
\(c_{\text{in}}(t)\): Input scaling coefficient
\(c_{\text{noise}}(t)\): Noise conditioning value
\(c_{\text{out}}(t)\): Output scaling coefficient
\(c_{\text{skip}}(t)\): Skip connection scaling coefficient
and where \(\mathbf{x}\) is the latent state and \(t\) is the diffusion time.
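The formula can be read as a small recipe: scale the input, call the network with the noise conditioning value, then combine the result with a skip connection. The following is a minimal, dependency-free sketch of that recipe on Python lists; affine_precondition, toy_network and toy_coefficients are hypothetical names, and the real class operates on tensors.

```python
def affine_precondition(F, x, t, coefficients):
    """Evaluate D(x, t) = c_skip * x + c_out * F(c_in * x, c_noise) elementwise."""
    c_in, c_noise, c_out, c_skip = coefficients(t)
    f_out = F([c_in * xi for xi in x], c_noise)
    return [c_skip * xi + c_out * fi for xi, fi in zip(x, f_out)]


def toy_network(x_scaled, c_noise):
    # Hypothetical network: doubles its (already scaled) input.
    return [2.0 * xi for xi in x_scaled]


def toy_coefficients(t):
    # Hypothetical coefficients, returned as (c_in, c_noise, c_out, c_skip).
    return 0.5, 0.0, t, 1.0


d = affine_precondition(toy_network, [1.0, -2.0], 3.0, toy_coefficients)
print(d)  # [4.0, -8.0]
```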
The wrapped model \(F\) must be an instance of Module that satisfies the DiffusionModel interface, with the following signature:
model(
    x: torch.Tensor,  # Shape: (B, *)
    t: torch.Tensor,  # Shape: (B,)
    condition: torch.Tensor | TensorDict | None = None,
    **model_kwargs: Any,
) -> torch.Tensor  # Shape: (B, *)
The preconditioner is agnostic to the prediction target of the wrapped model \(F\). The same preconditioning formula is applied regardless of whether the model is an \(\mathbf{x}_0\)-predictor, an \(\epsilon\)-predictor, a score predictor, or a \(\mathbf{v}\)-predictor.
Note
The preconditioner itself also satisfies the DiffusionModel interface, meaning it does not change the signature of the wrapped model \(F\), and it can be used anywhere a diffusion model is expected.
- Parameters:
model (physicsnemo.Module) – The underlying neural network model \(F\) to wrap with the signature described above.
meta (ModelMetaData, optional) – Meta data class for storing info regarding model, by default None. Subclasses can pass their own metadata.
- Forward:
x (torch.Tensor) – Noisy latent state of shape \((B, *)\) where \(B\) is the batch size and \(*\) denotes any number of additional dimensions.
t (torch.Tensor) – Diffusion time tensor of shape \((B,)\).
condition (torch.Tensor, TensorDict, or None, optional, default=None) – Single Tensor or a TensorDict containing conditioning tensors with batch size \(B\) matching that of x. Pass None for an unconditional model.
**model_kwargs (Any) – Additional keyword arguments passed to the underlying model.
- Outputs:
torch.Tensor – Preconditioned model output with the same shape as the original model output.
Note
To implement a new preconditioner, define a subclass of BaseAffinePreconditioner and follow these rules:
- Subclasses must implement the compute_coefficients() method to define the specific preconditioning scheme.
- A sigma() method can optionally be implemented. If a subclass implements the sigma() method, the diffusion time \(t\) is first transformed to a noise level \(\sigma(t)\) before being passed to compute_coefficients(). This allows implementing preconditioners for different time-to-noise-level mappings while keeping the same preconditioning interface, in particular for preconditioning schemes based on noise level (that is, \(c_{\text{in}}(\sigma)\), \(c_{\text{noise}}(\sigma)\), \(c_{\text{out}}(\sigma)\), \(c_{\text{skip}}(\sigma)\) instead of \(c_{\text{in}}(t)\), \(c_{\text{noise}}(t)\), \(c_{\text{out}}(t)\), \(c_{\text{skip}}(t)\)).
- The forward method of the preconditioner should not be overridden.
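The contract above resembles a template-method pattern: forward is fixed, and subclasses only supply compute_coefficients() and, optionally, sigma(). The scalar-only sketch below mirrors that contract using the standard library; AffinePreconditionerSketch and SqrtTimeSketch are hypothetical classes, and the internals of the real BaseAffinePreconditioner may differ.

```python
import math
from abc import ABC, abstractmethod


class AffinePreconditionerSketch(ABC):
    """Toy, scalar-only analogue of the subclassing contract."""

    def __init__(self, model):
        self.model = model

    def sigma(self, t):
        # Default time-to-noise-level mapping: the identity.
        return t

    @abstractmethod
    def compute_coefficients(self, sigma):
        """Return (c_in, c_noise, c_out, c_skip)."""

    def forward(self, x, t):
        # Fixed pipeline: t -> sigma(t) -> coefficients -> affine formula.
        c_in, c_noise, c_out, c_skip = self.compute_coefficients(self.sigma(t))
        return c_skip * x + c_out * self.model(c_in * x, c_noise)


class SqrtTimeSketch(AffinePreconditionerSketch):
    def sigma(self, t):
        return math.sqrt(t)

    def compute_coefficients(self, sigma):
        return 1.0, math.log(0.5 * sigma), sigma, 1.0


sketch = SqrtTimeSketch(model=lambda x_in, c_noise: -1.0)
print(sketch.forward(x=3.0, t=4.0))  # 3.0 + sqrt(4.0) * (-1.0) = 1.0
```

Because forward always calls sigma() first, a subclass written in terms of noise level never needs to know how the diffusion time is parameterized.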
The argument t of the preconditioner forward method is always assumed to be the diffusion time. For preconditioning schemes based on noise level, the noise level \(\sigma(t)\) is computed internally using the sigma() method.
Examples
The following example shows how to implement a classical EDM preconditioner. For EDM, there is no need to implement the sigma() method since \(\sigma(t) = t\) (noise level and diffusion time are the same).
We first define a simple model to wrap:
>>> import torch
>>> from tensordict import TensorDict
>>> from physicsnemo.nn import Module
>>> class SimpleModel(Module):
...     def __init__(self, channels: int):
...         super().__init__()
...         self.channels = channels
...         self.net = torch.nn.Conv2d(channels, channels, 1)
...
...     def forward(self, x, t, condition=None):
...         return self.net(x)
Now we define the EDM preconditioner:
>>> from physicsnemo.diffusion.preconditioners import (
...     BaseAffinePreconditioner,
... )
>>> class SimpleEDMPreconditioner(BaseAffinePreconditioner):
...     def __init__(self, model, sigma_data: float = 0.5):
...         super().__init__(model)
...         self.sigma_data = sigma_data
...
...     def compute_coefficients(self, t: torch.Tensor):
...         # For EDM sigma(t) = t, so the argument passed to
...         # compute_coefficients is already sigma(t)
...         sigma_data = self.sigma_data
...         c_skip = sigma_data**2 / (t**2 + sigma_data**2)
...         c_out = t * sigma_data / (t**2 + sigma_data**2).sqrt()
...         c_in = 1 / (sigma_data**2 + t**2).sqrt()
...         c_noise = t.log() / 4
...         return c_in, c_noise, c_out, c_skip
...
>>> model = SimpleModel(channels=3)
>>> precond = SimpleEDMPreconditioner(model, sigma_data=0.5)
>>> x = torch.randn(2, 3, 16, 16)
>>> t = torch.rand(2)
>>> condition = TensorDict({}, batch_size=[2])
>>> out = precond(x, t, condition)
>>> out.shape
torch.Size([2, 3, 16, 16])
The following example shows how to override the sigma() method to implement a Variance Exploding (VE) preconditioner where \(\sigma(t) = \sqrt{t}\).
>>> class VEPreconditioner(BaseAffinePreconditioner):
...     def __init__(self, model):
...         super().__init__(model)
...
...     def sigma(self, t: torch.Tensor) -> torch.Tensor:
...         # Override sigma for VE time-to-noise-level mapping
...         return t.sqrt()
...
...     def compute_coefficients(self, sigma: torch.Tensor):
...         # Here the argument passed to compute_coefficients is
...         # sigma(t) = sqrt(t) due to override of the sigma method
...         c_skip = torch.ones_like(sigma)
...         c_out = sigma
...         c_in = torch.ones_like(sigma)
...         c_noise = (0.5 * sigma).log()
...         return c_in, c_noise, c_out, c_skip
...
>>> precond_ve = VEPreconditioner(model)
>>> out_ve = precond_ve(x, t, condition)
>>> out_ve.shape
torch.Size([2, 3, 16, 16])
Wrapping existing models to satisfy the DiffusionModel interface
Some models in PhysicsNeMo have signatures that differ from the DiffusionModel interface. Below are examples showing how to write thin wrappers to make them compatible with preconditioners, including image-based conditioning via channel concatenation.
Example: Wrapping SongUNet
The SongUNet model has the signature forward(x, noise_labels, class_labels, augment_labels). We wrap it to match forward(x, t, condition), where condition contains both class labels (1D vector) and an image to concatenate channel-wise:
>>> from physicsnemo.models.diffusion_unets import SongUNet
>>> from physicsnemo.diffusion import DiffusionModel
>>> from tensordict import TensorDict
>>> class SongUNetWrapper(Module):
...     def __init__(self, img_channels, cond_channels, label_dim, **kwargs):
...         super().__init__()
...         # in_channels = img_channels + cond_channels for concatenation
...         self.net = SongUNet(
...             in_channels=img_channels + cond_channels,
...             out_channels=img_channels,
...             label_dim=label_dim,
...             **kwargs,
...         )
...
...     def forward(self, x, t, condition):
...         # Concatenate image condition "y" channel-wise to input
...         y = condition["y"]  # shape: (B, C_cond, H, W)
...         x_cat = torch.cat([x, y], dim=1)
...         # Extract 1D vector condition for class_labels
...         class_labels = condition["class_labels"]  # shape: (B, label_dim)
...         return self.net(x_cat, noise_labels=t, class_labels=class_labels)
...
>>> wrapped = SongUNetWrapper(
...     img_channels=2, cond_channels=1, label_dim=4, img_resolution=8
... )
>>> isinstance(wrapped, DiffusionModel)
True
>>> x = torch.rand(1, 2, 8, 8)
>>> t = torch.rand(1)
>>> condition = TensorDict({
...     "y": torch.rand(1, 1, 8, 8),  # image condition
...     "class_labels": torch.rand(1, 4),  # 1D vector condition
... }, batch_size=[1])
>>> out = wrapped(x, t, condition)
>>> out.shape
torch.Size([1, 2, 8, 8])
Example: Wrapping DiT
The DiT model has the signature forward(x, t, condition, ...). We wrap it to support both image conditioning (via channel concatenation) and vector conditioning:
>>> from physicsnemo.models.dit import DiT
>>> class DiTWrapper(Module):
...     def __init__(self, img_channels, cond_channels, cond_dim, **kwargs):
...         super().__init__()
...         # in_channels = img_channels + cond_channels for concatenation
...         self.net = DiT(
...             in_channels=img_channels + cond_channels,
...             out_channels=img_channels,
...             condition_dim=cond_dim,
...             **kwargs,
...         )
...
...     def forward(self, x, t, condition):
...         # Concatenate image condition "y" channel-wise to input
...         y = condition["y"]  # shape: (B, C_cond, H, W)
...         x_cat = torch.cat([x, y], dim=1)
...         # Extract 1D vector condition
...         vec = condition["vec"]  # shape: (B, cond_dim)
...         return self.net(x_cat, t, condition=vec)
...
>>> wrapped_dit = DiTWrapper(
...     img_channels=2, cond_channels=1, cond_dim=4,
...     input_size=8, patch_size=4, attention_backend="timm",
... )
>>> isinstance(wrapped_dit, DiffusionModel)
True
>>> x = torch.rand(1, 2, 8, 8)
>>> t = torch.rand(1)
>>> condition = TensorDict({
...     "y": torch.rand(1, 1, 8, 8),  # image condition
...     "vec": torch.rand(1, 4),  # 1D vector condition
... }, batch_size=[1])
>>> out = wrapped_dit(x, t, condition)
>>> out.shape
torch.Size([1, 2, 8, 8])
Example: Using ConcatConditionWrapper with a preconditioner
The pattern in the previous examples, where spatially varying conditioning is concatenated to the noised latent state (and vector conditioning is optionally passed as a separate argument), is common across diffusion use cases. For convenience, the wrapper class ConcatConditionWrapper implements this pattern so that you do not have to write your own wrapper:
>>> from physicsnemo.diffusion.preconditioners import EDMPreconditioner
>>> from physicsnemo.diffusion.utils import ConcatConditionWrapper
>>> base_model = SongUNet(img_resolution=8, in_channels=4, out_channels=3, label_dim=4)
>>> wrapped_model = ConcatConditionWrapper(base_model)
>>> precond = EDMPreconditioner(wrapped_model, sigma_data=0.5)
>>> x = torch.rand(1, 3, 8, 8)
>>> t = torch.rand(1)
>>> condition = TensorDict({
...     "cond_concat": torch.rand(1, 1, 8, 8),  # image condition
...     "cond_vec": torch.rand(1, 4),  # vector condition
... }, batch_size=[1])
>>> out = precond(x, t, condition)
>>> out.shape
torch.Size([1, 3, 8, 8])
The same wrapper can be used with DiT backbones as well, with cond_vec passed to the model’s condition argument.
- abstractmethod compute_coefficients(t: Tensor, /)
Compute the preconditioning coefficients for a given diffusion time \(t\) or noise level \(\sigma\).
This abstract method must be implemented by subclasses to define the specific preconditioning scheme.
- Parameters:
t (torch.Tensor) – Diffusion time (or noise level if sigma() is implemented) tensor of shape \((B, 1, ..., 1)\) where \(B\) is the batch size and the trailing singleton dimensions match the spatial dimensions of the latent state x for broadcasting.
- Returns:
c_in (torch.Tensor) – Input scaling coefficient of shape \((B, 1, ..., 1)\).
c_noise (torch.Tensor) – Noise conditioning value of shape \((B, 1, ..., 1)\).
c_out (torch.Tensor) – Output scaling coefficient of shape \((B, 1, ..., 1)\).
c_skip (torch.Tensor) – Skip connection scaling coefficient of shape \((B, 1, ..., 1)\).
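The \((B, 1, ..., 1)\) shape keeps the batch dimension and adds one singleton per remaining dimension of x, so each coefficient broadcasts against the latent state. A minimal sketch of how such a shape can be derived from the shape of x is given below; coefficient_shape is a hypothetical helper, and how the framework reshapes t internally is an assumption here.

```python
def coefficient_shape(x_shape):
    """Shape (B, 1, ..., 1) with one singleton per non-batch dimension of x."""
    return (x_shape[0],) + (1,) * (len(x_shape) - 1)


print(coefficient_shape((2, 3, 16, 16)))  # (2, 1, 1, 1)
print(coefficient_shape((8, 64)))         # (8, 1)
```

With PyTorch, a time tensor of shape \((B,)\) could then be reshaped with t.reshape(coefficient_shape(x.shape)) before the coefficients are computed.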
- sigma(t: Tensor) → Tensor
Map diffusion time \(t\) to noise level \(\sigma(t)\).
By default, this is the identity function \(\sigma(t) = t\). Subclasses can override this to implement preconditioners for different time-to-noise-level mappings.
When overridden, the output of this method is passed to compute_coefficients() instead of the raw time t.
- Parameters:
t (torch.Tensor) – Diffusion time tensor of shape \((B,)\) where \(B\) is the batch size.
- Returns:
Noise level \(\sigma(t)\) of shape \((B,)\).
- Return type:
torch.Tensor
EDMPreconditioner
- class physicsnemo.diffusion.preconditioners.EDMPreconditioner(*args, **kwargs)
Bases: BaseAffinePreconditioner
EDM preconditioner.
Implements the improved preconditioning scheme proposed in the EDM paper.
For EDM, the time-to-noise-level mapping is the identity: \(\sigma(t) = t\).
The preconditioning coefficients are:
\[\begin{split}c_{\text{skip}} &= \frac{\sigma_{\text{data}}^2} {\sigma^2 + \sigma_{\text{data}}^2} \\ c_{\text{out}} &= \frac{\sigma \cdot \sigma_{\text{data}}} {\sqrt{\sigma^2 + \sigma_{\text{data}}^2}} \\ c_{\text{in}} &= \frac{1} {\sqrt{\sigma_{\text{data}}^2 + \sigma^2}} \\ c_{\text{noise}} &= \frac{\log(\sigma)}{4}\end{split}\]
With these coefficients, the preconditioned model output is expected to be an \(\mathbf{x}_0\)-prediction (clean data estimate). This preconditioner is not directly compatible with score-prediction training or with other prediction types.
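One way to see why these coefficients are well-conditioned: for \(\mathbf{x} = \mathbf{x}_0 + \sigma \epsilon\), the quantity the wrapped network must output, \((\mathbf{x}_0 - c_{\text{skip}} \mathbf{x}) / c_{\text{out}}\), has unit variance at every noise level. The scalar sketch below checks this in closed form, assuming \(\mathbf{x}_0\) has variance \(\sigma_{\text{data}}^2\) and is independent of unit-variance noise; the function names are illustrative.

```python
import math


def edm_coefficients(sigma, sigma_data=0.5):
    """Scalar version of the EDM coefficients listed above."""
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / math.sqrt(sigma**2 + sigma_data**2)
    c_in = 1 / math.sqrt(sigma_data**2 + sigma**2)
    c_noise = math.log(sigma) / 4
    return c_in, c_noise, c_out, c_skip


def target_variance(sigma, sigma_data=0.5):
    """Variance of (x0 - c_skip * x) / c_out for x = x0 + sigma * eps."""
    _, _, c_out, c_skip = edm_coefficients(sigma, sigma_data)
    # x0 - c_skip * x = (1 - c_skip) * x0 - c_skip * sigma * eps
    numerator = (1 - c_skip) ** 2 * sigma_data**2 + c_skip**2 * sigma**2
    return numerator / c_out**2


c_in, c_noise, c_out, c_skip = edm_coefficients(0.5)
print(f"c_in={c_in:.4f}, c_noise={c_noise:.4f}, c_out={c_out:.4f}, c_skip={c_skip:.4f}")
for sigma in (0.01, 0.5, 80.0):
    print(f"sigma={sigma}: target variance={target_variance(sigma):.6f}")
```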
For training, it is usually paired with MSEDSMLoss (with prediction_type="x0") and EDMNoiseScheduler.
- Parameters:
model (physicsnemo.Module) – The underlying neural network model to wrap with signature described in BaseAffinePreconditioner.
sigma_data (float, optional) – Expected standard deviation of the training data, by default 0.5.
- Forward:
x (torch.Tensor) – Noisy latent state of shape \((B, *)\) where \(B\) is the batch size and \(*\) denotes any number of additional dimensions.
t (torch.Tensor) – Diffusion time tensor of shape \((B,)\).
condition (torch.Tensor, TensorDict, or None, optional, default=None) – Single Tensor or a TensorDict containing conditioning tensors with batch size \(B\) matching that of x. Pass None for an unconditional model.
**model_kwargs (Any) – Additional keyword arguments passed to the underlying model.
- Outputs:
torch.Tensor – Preconditioned model output with the same shape as the original model output.
Examples
>>> import torch
>>> from physicsnemo.core import Module
>>> from physicsnemo.diffusion.preconditioners import EDMPreconditioner
>>> # Define a simple model satisfying the diffusion model interface
>>> class SimpleModel(Module):
...     def __init__(self, channels: int):
...         super().__init__()
...         self.net = torch.nn.Conv2d(channels, channels, 1)
...     def forward(self, x, t, condition=None):
...         return self.net(x)
>>> model = SimpleModel(channels=3)
>>> precond = EDMPreconditioner(model, sigma_data=0.5)
>>> x = torch.randn(2, 3, 16, 16)  # batch of 2 images
>>> t = torch.rand(2)  # diffusion time for each sample
>>> out = precond(x, t, condition=None)
>>> out.shape
torch.Size([2, 3, 16, 16])
- compute_coefficients(t: Tensor)
Compute EDM preconditioning coefficients.
- Parameters:
t (torch.Tensor) – Diffusion time (or noise level, since they are identical for EDM) of shape \((B, 1, ..., 1)\).
- Returns:
Preconditioning coefficients (\(c_{\text{in}}\), \(c_{\text{noise}}\), \(c_{\text{out}}\), \(c_{\text{skip}}\)) of shape \((B, 1, ..., 1)\).
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
VEPreconditioner
- class physicsnemo.diffusion.preconditioners.VEPreconditioner(*args, **kwargs)
Bases: BaseAffinePreconditioner
Variance Exploding (VE) preconditioner.
Implements the preconditioning scheme from the VE formulation of score-based generative models.
For VE, the time-to-noise-level mapping is the identity: \(\sigma(t) = t\).
The preconditioning coefficients are:
\[\begin{split}c_{\text{skip}} &= 1 \\ c_{\text{out}} &= \sigma \\ c_{\text{in}} &= 1 \\ c_{\text{noise}} &= \log(0.5 \cdot \sigma)\end{split}\]
With these coefficients, the preconditioned model output is expected to be an \(\mathbf{x}_0\)-prediction (clean data estimate). This preconditioner is not directly compatible with score-prediction training or with other prediction types.
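With these coefficients the formula reduces to \(D = \mathbf{x} + \sigma F\), so for \(\mathbf{x} = \mathbf{x}_0 + \sigma \epsilon\) the output equals \(\mathbf{x}_0\) exactly when the wrapped network returns \(-\epsilon\). The scalar sketch below illustrates this with a hypothetical oracle network that already knows the noise; ve_coefficients and ve_denoise are illustrative names.

```python
import math


def ve_coefficients(sigma):
    """Scalar version of the VE coefficients listed above."""
    c_skip = 1.0
    c_out = sigma
    c_in = 1.0
    c_noise = math.log(0.5 * sigma)
    return c_in, c_noise, c_out, c_skip


def ve_denoise(network, x, sigma):
    c_in, c_noise, c_out, c_skip = ve_coefficients(sigma)
    return c_skip * x + c_out * network(c_in * x, c_noise)


x0, eps, sigma = 1.5, -0.4, 2.0
x = x0 + sigma * eps  # noisy sample


def oracle(x_in, c_noise):
    # Hypothetical ideal network output: the negated noise.
    return -eps


recovered = ve_denoise(oracle, x, sigma)
print(recovered)
```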
For training, it is usually paired with MSEDSMLoss (with prediction_type="x0") and VENoiseScheduler.
- Parameters:
model (physicsnemo.Module) – The underlying neural network model to wrap with signature described in BaseAffinePreconditioner.
- Forward:
x (torch.Tensor) – Noisy latent state of shape \((B, *)\) where \(B\) is the batch size and \(*\) denotes any number of additional dimensions.
t (torch.Tensor) – Diffusion time tensor of shape \((B,)\).
condition (torch.Tensor, TensorDict, or None, optional, default=None) – Single Tensor or a TensorDict containing conditioning tensors with batch size \(B\) matching that of x. Pass None for an unconditional model.
**model_kwargs (Any) – Additional keyword arguments passed to the underlying model.
- Outputs:
torch.Tensor – Preconditioned model output with the same shape as the original model output.
Examples
>>> import torch
>>> from physicsnemo.core import Module
>>> from physicsnemo.diffusion.preconditioners import VEPreconditioner
>>> # Define a simple model satisfying the diffusion model interface
>>> class SimpleModel(Module):
...     def __init__(self, channels: int):
...         super().__init__()
...         self.net = torch.nn.Conv2d(channels, channels, 1)
...     def forward(self, x, t, condition=None):
...         return self.net(x)
>>> model = SimpleModel(channels=3)
>>> precond = VEPreconditioner(model)
>>> x = torch.randn(2, 3, 16, 16)  # batch of 2 images
>>> t = torch.rand(2)  # diffusion time for each sample
>>> out = precond(x, t, condition=None)
>>> out.shape
torch.Size([2, 3, 16, 16])
- compute_coefficients(t: Tensor)
Compute VE preconditioning coefficients.
- Parameters:
t (torch.Tensor) – Diffusion time tensor of shape \((B, 1, ..., 1)\).
- Returns:
Preconditioning coefficients (\(c_{\text{in}}\), \(c_{\text{noise}}\), \(c_{\text{out}}\), \(c_{\text{skip}}\)) of shape \((B, 1, ..., 1)\).
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
VPPreconditioner
- class physicsnemo.diffusion.preconditioners.VPPreconditioner(*args, **kwargs)
Bases: BaseAffinePreconditioner
Variance Preserving (VP) preconditioner.
Implements the preconditioning scheme from the VP formulation of score-based generative models.
The time-to-noise-level mapping is:
\[\sigma(t) = \sqrt{\exp\left(\frac{\beta_d}{2} t^2 + \beta_{\min} t\right) - 1}\]
The preconditioning coefficients are:
\[\begin{split}c_{\text{skip}} &= 1 \\ c_{\text{out}} &= -\sigma \\ c_{\text{in}} &= \frac{1}{\sqrt{\sigma^2 + 1}} \\ c_{\text{noise}} &= (M - 1) \cdot \sigma^{-1}(\sigma)\end{split}\]
Here \(\sigma^{-1}\) denotes the inverse of the time-to-noise-level mapping above, so \(\sigma^{-1}(\sigma)\) recovers the diffusion time. With these coefficients, the preconditioned model output is expected to be an \(\mathbf{x}_0\)-prediction (clean data estimate). This preconditioner is not directly compatible with score-prediction training or with other prediction types.
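The noise conditioning value depends on the inverse of the time-to-noise-level mapping, which has a closed form obtained by solving the quadratic \(\frac{\beta_d}{2} t^2 + \beta_{\min} t = \log(1 + \sigma^2)\) for \(t \geq 0\). The scalar sketch below implements the mapping and this inverse under the documented default parameters; vp_sigma, vp_sigma_inv and vp_c_noise are illustrative names rather than framework functions.

```python
import math


def vp_sigma(t, beta_d=19.9, beta_min=0.1):
    """Noise level sigma(t) for the VP mapping above."""
    return math.sqrt(math.exp(0.5 * beta_d * t**2 + beta_min * t) - 1)


def vp_sigma_inv(sigma, beta_d=19.9, beta_min=0.1):
    """Non-negative root of (beta_d / 2) t^2 + beta_min t = log(1 + sigma^2)."""
    log_term = math.log(1 + sigma**2)
    return (math.sqrt(beta_min**2 + 2 * beta_d * log_term) - beta_min) / beta_d


def vp_c_noise(sigma, M=1000, beta_d=19.9, beta_min=0.1):
    """c_noise = (M - 1) * sigma^{-1}(sigma)."""
    return (M - 1) * vp_sigma_inv(sigma, beta_d, beta_min)


for t in (0.01, 0.5, 1.0):
    sigma = vp_sigma(t)
    print(f"t={t}: sigma={sigma:.4f}, recovered t={vp_sigma_inv(sigma):.4f}, "
          f"c_noise={vp_c_noise(sigma):.2f}")
```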
For training, it is usually paired with MSEDSMLoss (with prediction_type="x0") and VPNoiseScheduler.
- Parameters:
model (physicsnemo.Module) – The underlying neural network model to wrap with signature described in BaseAffinePreconditioner.
beta_d (float, optional) – Extent of the noise level schedule, by default 19.9.
beta_min (float, optional) – Initial slope of the noise level schedule, by default 0.1.
M (int, optional) – Number of discretization steps in the DDPM formulation, by default 1000.
- Forward:
x (torch.Tensor) – Noisy latent state of shape \((B, *)\) where \(B\) is the batch size and \(*\) denotes any number of additional dimensions.
t (torch.Tensor) – Diffusion time tensor of shape \((B,)\).
condition (torch.Tensor, TensorDict, or None, optional, default=None) – Single Tensor or a TensorDict containing conditioning tensors with batch size \(B\) matching that of x. Pass None for an unconditional model.
**model_kwargs (Any) – Additional keyword arguments passed to the underlying model.
- Outputs:
torch.Tensor – Preconditioned model output with the same shape as the original model output.
Examples
>>> import torch
>>> from physicsnemo.core import Module
>>> from physicsnemo.diffusion.preconditioners import VPPreconditioner
>>> # Define a simple model satisfying the diffusion model interface
>>> class SimpleModel(Module):
...     def __init__(self, channels: int):
...         super().__init__()
...         self.net = torch.nn.Conv2d(channels, channels, 1)
...     def forward(self, x, t, condition=None):
...         return self.net(x)
>>> model = SimpleModel(channels=3)
>>> precond = VPPreconditioner(model, beta_d=19.9, beta_min=0.1, M=1000)
>>> x = torch.randn(2, 3, 16, 16)  # batch of 2 images
>>> t = torch.rand(2)  # diffusion time for each sample
>>> out = precond(x, t, condition=None)
>>> out.shape
torch.Size([2, 3, 16, 16])
- compute_coefficients(sigma: Tensor)
Compute VP preconditioning coefficients.
- Parameters:
sigma (torch.Tensor) – Noise level tensor of shape \((B, 1, ..., 1)\).
- Returns:
Preconditioning coefficients (\(c_{\text{in}}\), \(c_{\text{noise}}\), \(c_{\text{out}}\), \(c_{\text{skip}}\)) of shape \((B, 1, ..., 1)\).
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
IDDPMPreconditioner
- class physicsnemo.diffusion.preconditioners.IDDPMPreconditioner(*args, **kwargs)
Bases: BaseAffinePreconditioner
Improved DDPM (iDDPM) preconditioner.
Implements the preconditioning scheme from the improved DDPM formulation.
The preconditioning coefficients are:
\[\begin{split}c_{\text{skip}} &= 1 \\ c_{\text{out}} &= -\sigma \\ c_{\text{in}} &= \frac{1}{\sqrt{\sigma^2 + 1}} \\ c_{\text{noise}} &= M - 1 - \text{argmin}|\sigma - u_j|\end{split}\]
where \(u_j, j = 0, ..., M\) are the precomputed noise levels from the iDDPM discretization.
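The noise conditioning value is therefore an integer index: the noise level is snapped to the nearest precomputed \(u_j\), and the index \(j\) is counted from the other end of the table. The sketch below shows only this lookup step. The table u is made up for illustration with \(M = 4\) and does not contain the real iDDPM noise levels, and iddpm_c_noise is a hypothetical helper.

```python
def iddpm_c_noise(sigma, u):
    """c_noise = M - 1 - argmin_j |sigma - u_j| for a table u of length M + 1."""
    M = len(u) - 1
    j = min(range(len(u)), key=lambda k: abs(sigma - u[k]))
    return M - 1 - j


# Made-up noise levels for illustration only (M = 4).
u = [8.0, 4.0, 2.0, 1.0, 0.0]

print(iddpm_c_noise(3.7, u))  # nearest entry is u[1] = 4.0, so 4 - 1 - 1 = 2
print(iddpm_c_noise(8.0, u))  # nearest entry is u[0] = 8.0, so 4 - 1 - 0 = 3
```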
With these coefficients, the preconditioned model output is expected to be an \(\mathbf{x}_0\)-prediction (clean data estimate). This preconditioner is not directly compatible with score-prediction training or with other prediction types.
For training, it is usually paired with MSEDSMLoss (with prediction_type="x0") and IDDPMNoiseScheduler.
- Parameters:
model (physicsnemo.Module) – The underlying neural network model to wrap with signature described in BaseAffinePreconditioner.
C_1 (float, optional) – Timestep adjustment at low noise levels, by default 0.001.
C_2 (float, optional) – Timestep adjustment at high noise levels, by default 0.008.
M (int, optional) – Number of discretization steps in the DDPM formulation, by default 1000.
- Forward:
x (torch.Tensor) – Noisy latent state of shape \((B, *)\) where \(B\) is the batch size and \(*\) denotes any number of additional dimensions.
t (torch.Tensor) – Diffusion time tensor of shape \((B,)\).
condition (torch.Tensor, TensorDict, or None, optional, default=None) – Single Tensor or a TensorDict containing conditioning tensors with batch size \(B\) matching that of x. Pass None for an unconditional model.
**model_kwargs (Any) – Additional keyword arguments passed to the underlying model.
- Outputs:
torch.Tensor – Preconditioned model output with the same shape as the original model output.
Note
Reference: Improved Denoising Diffusion Probabilistic Models
Examples
>>> import torch
>>> from physicsnemo.core import Module
>>> from physicsnemo.diffusion.preconditioners import IDDPMPreconditioner
>>> # Define a simple model satisfying the diffusion model interface
>>> class SimpleModel(Module):
...     def __init__(self, channels: int):
...         super().__init__()
...         self.net = torch.nn.Conv2d(channels, channels, 1)
...     def forward(self, x, t, condition=None):
...         return self.net(x)
>>> model = SimpleModel(channels=3)
>>> precond = IDDPMPreconditioner(model, C_1=0.001, C_2=0.008, M=1000)
>>> x = torch.randn(2, 3, 16, 16)  # batch of 2 images
>>> t = torch.rand(2)  # diffusion time for each sample
>>> out = precond(x, t, condition=None)
>>> out.shape
torch.Size([2, 3, 16, 16])
- compute_coefficients(t: Tensor)
Compute iDDPM preconditioning coefficients.
- Parameters:
t (torch.Tensor) – Diffusion time tensor of shape \((B, 1, ..., 1)\).
- Returns:
Preconditioning coefficients (\(c_{\text{in}}\), \(c_{\text{noise}}\), \(c_{\text{out}}\), \(c_{\text{skip}}\)) of shape \((B, 1, ..., 1)\).
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]