Source code for physicsnemo.diffusion.preconditioners.preconditioners

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, Tuple

import torch
from jaxtyping import Float
from tensordict import TensorDict
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Replicate

from physicsnemo.core.meta import ModelMetaData
from physicsnemo.core.module import Module


def _ensure_plain_tensor(t: torch.Tensor) -> torch.Tensor:
    """Unwrap fully-replicated DTensors to plain tensors.

    FSDP wraps registered buffers as DTensors with ``Replicate`` placement.
    Arithmetic in ``compute_coefficients`` propagates that wrapping to the
    resulting coefficient tensors.  These must be converted back to plain
    tensors before element-wise operations with non-DTensor inputs (plain
    tensors or ``ShardTensor``).

    Only DTensors whose *every* placement is ``Replicate`` are unwrapped;
    any tensor with a ``Shard`` placement (e.g. a ``ShardTensor``) is left
    untouched so its sharding metadata is preserved.
    """
    if isinstance(t, DTensor) and all(isinstance(p, Replicate) for p in t.placements):
        return t.to_local()
    return t


def _replicate_on_mesh(t: torch.Tensor, mesh) -> torch.Tensor:
    """Promote a plain tensor to a replicated ``DTensor`` on *mesh*.

    Already-distributed tensors are returned unchanged.  Used after
    :func:`_ensure_plain_tensor` to re-wrap coefficients for element-wise
    arithmetic with ``ShardTensor`` data in domain-parallel training.
    """
    if isinstance(t, DTensor):
        return t
    return DTensor.from_local(t, device_mesh=mesh, placements=[Replicate()])


class BaseAffinePreconditioner(Module, ABC):
    r"""
    Abstract base class for diffusion model preconditioners using an affine
    transformation.

    This class provides a standardized interface for implementing
    preconditioners that use affine transformations of the model input and
    output. The preconditioner wraps a neural network model :math:`F` and
    applies a preconditioning formula to transform the network output to
    produce the preconditioned output :math:`D(\mathbf{x}, t)` according to:

    .. math::
        D(\mathbf{x}, t) = c_{\text{skip}}(t) \mathbf{x}
        + c_{\text{out}}(t) F(c_{\text{in}}(t) \mathbf{x},
        c_{\text{noise}}(t))

    where:

    - :math:`c_{\text{in}}(t)`: Input scaling coefficient
    - :math:`c_{\text{noise}}(t)`: Noise conditioning value
    - :math:`c_{\text{out}}(t)`: Output scaling coefficient
    - :math:`c_{\text{skip}}(t)`: Skip connection scaling coefficient

    and where :math:`\mathbf{x}` is the latent state and :math:`t` is the
    diffusion time.

    The wrapped model :math:`F` must be an instance of
    :class:`~physicsnemo.core.Module` that satisfies the
    :class:`~physicsnemo.diffusion.DiffusionModel` interface, with the
    following signature:

    .. code-block:: python

        model(
            x: torch.Tensor,  # Shape: (B, *)
            t: torch.Tensor,  # Shape: (B,)
            condition: torch.Tensor | TensorDict | None = None,
            **model_kwargs: Any,
        ) -> torch.Tensor  # Shape: (B, *)

    The preconditioner is agnostic to the prediction target of the wrapped
    model :math:`F`. The same preconditioning formula is applied regardless
    of whether the model is an :math:`\mathbf{x}_0`-predictor, an
    :math:`\epsilon`-predictor, a score predictor, or a
    :math:`\mathbf{v}`-predictor.

    .. note::

        The preconditioner itself also satisfies the
        :class:`~physicsnemo.diffusion.DiffusionModel` interface, meaning it
        does not change the signature of the wrapped model :math:`F`, and it
        can be used anywhere a diffusion model is expected.

    Parameters
    ----------
    model : physicsnemo.Module
        The underlying neural network model :math:`F` to wrap with the
        signature described above.
    meta : ModelMetaData, optional
        Meta data class for storing info regarding model, by default None.
        Subclasses can pass their own metadata.

    Forward
    -------
    x : torch.Tensor
        Noisy latent state of shape :math:`(B, *)` where :math:`B` is the
        batch size and :math:`*` denotes any number of additional dimensions.
    t : torch.Tensor
        Diffusion time tensor of shape :math:`(B,)`.
    condition : torch.Tensor, TensorDict, or None, optional, default=None
        Single Tensor or a TensorDict containing conditioning tensors with
        batch size :math:`B` matching that of ``x``. Pass ``None`` for an
        unconditional model.
    **model_kwargs : Any
        Additional keyword arguments passed to the underlying model.

    Outputs
    -------
    torch.Tensor
        Preconditioned model output with the same shape as the original
        model output.

    Note
    ----
    To implement a new preconditioner, a subclass of
    :class:`BaseAffinePreconditioner` must be defined, and some methods have
    to be implemented:

    - Subclasses must implement the :meth:`compute_coefficients` method to
      define the specific preconditioning scheme.

    - A :meth:`sigma` method can optionally be implemented. If a subclass
      implements the :meth:`sigma` method, the diffusion time :math:`t` is
      first transformed to a noise level :math:`\sigma(t)` before being
      passed to :meth:`compute_coefficients`. This allows implementing
      preconditioners for different time-to-noise-level mappings while
      keeping the same preconditioning interface, in particular for
      preconditioning schemes based on noise level (that is
      :math:`c_{\text{in}}(\sigma)`, :math:`c_{\text{noise}}(\sigma)`,
      :math:`c_{\text{out}}(\sigma)`, :math:`c_{\text{skip}}(\sigma)` instead
      of :math:`c_{\text{in}}(t)`, :math:`c_{\text{noise}}(t)`,
      :math:`c_{\text{out}}(t)`, :math:`c_{\text{skip}}(t)`).

    - The ``forward`` method of the preconditioner *should not* be
      overridden. The argument ``t`` of the preconditioner forward method is
      always assumed to be the diffusion time. For preconditioning schemes
      based on noise level, the noise level :math:`\sigma(t)` is computed
      internally using the :meth:`sigma` method.

    Examples
    --------
    The following example shows how to implement a classical EDM
    preconditioner. For EDM, there is no need to implement the :meth:`sigma`
    method since :math:`\sigma(t) = t` (noise level and diffusion time are
    the same).

    We first define a simple model to wrap:

    >>> import torch
    >>> from tensordict import TensorDict
    >>> from physicsnemo.nn import Module
    >>> class SimpleModel(Module):
    ...     def __init__(self, channels: int):
    ...         super().__init__()
    ...         self.channels = channels
    ...         self.net = torch.nn.Conv2d(channels, channels, 1)
    ...
    ...     def forward(self, x, t, condition=None):
    ...         return self.net(x)

    Now we define the EDM preconditioner:

    >>> from physicsnemo.diffusion.preconditioners import (
    ...     BaseAffinePreconditioner,
    ... )
    >>> class SimpleEDMPreconditioner(BaseAffinePreconditioner):
    ...     def __init__(self, model, sigma_data: float = 0.5):
    ...         super().__init__(model)
    ...         self.sigma_data = sigma_data
    ...
    ...     def compute_coefficients(self, t: torch.Tensor):
    ...         # For EDM sigma(t) = t, so the argument passed to
    ...         # compute_coefficients is already sigma(t)
    ...         sigma_data = self.sigma_data
    ...         c_skip = sigma_data**2 / (t**2 + sigma_data**2)
    ...         c_out = t * sigma_data / (t**2 + sigma_data**2).sqrt()
    ...         c_in = 1 / (sigma_data**2 + t**2).sqrt()
    ...         c_noise = t.log() / 4
    ...         return c_in, c_noise, c_out, c_skip
    ...
    >>> model = SimpleModel(channels=3)
    >>> precond = SimpleEDMPreconditioner(model, sigma_data=0.5)
    >>> x = torch.randn(2, 3, 16, 16)
    >>> t = torch.rand(2)
    >>> condition = TensorDict({}, batch_size=[2])
    >>> out = precond(x, t, condition)
    >>> out.shape
    torch.Size([2, 3, 16, 16])

    The following example shows how to override the :meth:`sigma` method to
    implement a Variance Exploding (VE) preconditioner where
    :math:`\sigma(t) = \sqrt{t}`.

    >>> class VEPreconditioner(BaseAffinePreconditioner):
    ...     def __init__(self, model):
    ...         super().__init__(model)
    ...
    ...     def sigma(self, t: torch.Tensor) -> torch.Tensor:
    ...         # Override sigma for VE time-to-noise-level mapping
    ...         return t.sqrt()
    ...
    ...     def compute_coefficients(self, sigma: torch.Tensor):
    ...         # Here the argument passed to compute_coefficients is
    ...         # sigma(t) = sqrt(t) due to override of the sigma method
    ...         c_skip = torch.ones_like(sigma)
    ...         c_out = sigma
    ...         c_in = torch.ones_like(sigma)
    ...         c_noise = (0.5 * sigma).log()
    ...         return c_in, c_noise, c_out, c_skip
    ...
    >>> precond_ve = VEPreconditioner(model)
    >>> out_ve = precond_ve(x, t, condition)
    >>> out_ve.shape
    torch.Size([2, 3, 16, 16])

    **Wrapping existing models to satisfy the DiffusionModel interface**

    Some models in PhysicsNeMo have signatures that differ from the
    :class:`~physicsnemo.diffusion.DiffusionModel` interface. Below are
    examples showing how to write thin wrappers to make them compatible with
    preconditioners, including image-based conditioning via channel
    concatenation.

    **Example: Wrapping SongUNet**

    The :class:`~physicsnemo.models.diffusion_unets.SongUNet` model has the
    signature ``forward(x, noise_labels, class_labels, augment_labels)``. We
    wrap it to match ``forward(x, t, condition)``, where ``condition``
    contains both class labels (1D vector) and an image to concatenate
    channel-wise:

    >>> from physicsnemo.models.diffusion_unets import SongUNet
    >>> from physicsnemo.diffusion import DiffusionModel
    >>> from tensordict import TensorDict
    >>> class SongUNetWrapper(Module):
    ...     def __init__(self, img_channels, cond_channels, label_dim, **kwargs):
    ...         super().__init__()
    ...         # in_channels = img_channels + cond_channels for concatenation
    ...         self.net = SongUNet(
    ...             in_channels=img_channels + cond_channels,
    ...             out_channels=img_channels,
    ...             label_dim=label_dim,
    ...             **kwargs,
    ...         )
    ...
    ...     def forward(self, x, t, condition):
    ...         # Concatenate image condition "y" channel-wise to input
    ...         y = condition["y"]  # shape: (B, C_cond, H, W)
    ...         x_cat = torch.cat([x, y], dim=1)
    ...         # Extract 1D vector condition for class_labels
    ...         class_labels = condition["class_labels"]  # shape: (B, label_dim)
    ...         return self.net(x_cat, noise_labels=t, class_labels=class_labels)
    ...
    >>> wrapped = SongUNetWrapper(
    ...     img_channels=2, cond_channels=1, label_dim=4, img_resolution=8
    ... )
    >>> isinstance(wrapped, DiffusionModel)
    True
    >>> x = torch.rand(1, 2, 8, 8)
    >>> t = torch.rand(1)
    >>> condition = TensorDict({
    ...     "y": torch.rand(1, 1, 8, 8),  # image condition
    ...     "class_labels": torch.rand(1, 4),  # 1D vector condition
    ... }, batch_size=[1])
    >>> out = wrapped(x, t, condition)
    >>> out.shape
    torch.Size([1, 2, 8, 8])

    **Example: Wrapping DiT**

    The :class:`~physicsnemo.models.dit.DiT` model has the signature
    ``forward(x, t, condition, ...)``. We wrap it to support both image
    conditioning (via channel concatenation) and vector conditioning:

    >>> from physicsnemo.models.dit import DiT
    >>> class DiTWrapper(Module):
    ...     def __init__(self, img_channels, cond_channels, cond_dim, **kwargs):
    ...         super().__init__()
    ...         # in_channels = img_channels + cond_channels for concatenation
    ...         self.net = DiT(
    ...             in_channels=img_channels + cond_channels,
    ...             out_channels=img_channels,
    ...             condition_dim=cond_dim,
    ...             **kwargs,
    ...         )
    ...
    ...     def forward(self, x, t, condition):
    ...         # Concatenate image condition "y" channel-wise to input
    ...         y = condition["y"]  # shape: (B, C_cond, H, W)
    ...         x_cat = torch.cat([x, y], dim=1)
    ...         # Extract 1D vector condition
    ...         vec = condition["vec"]  # shape: (B, cond_dim)
    ...         return self.net(x_cat, t, condition=vec)
    ...
    >>> wrapped_dit = DiTWrapper(
    ...     img_channels=2, cond_channels=1, cond_dim=4,
    ...     input_size=8, patch_size=4, attention_backend="timm",
    ... )
    >>> isinstance(wrapped_dit, DiffusionModel)
    True
    >>> x = torch.rand(1, 2, 8, 8)
    >>> t = torch.rand(1)
    >>> condition = TensorDict({
    ...     "y": torch.rand(1, 1, 8, 8),  # image condition
    ...     "vec": torch.rand(1, 4),  # 1D vector condition
    ... }, batch_size=[1])
    >>> out = wrapped_dit(x, t, condition)
    >>> out.shape
    torch.Size([1, 2, 8, 8])

    **Example: Using ConcatConditionWrapper with a preconditioner**

    The pattern in the previous example, where (spatially-varying)
    conditioning is concatenated to the noised latent state (and possibly
    vector conditioning is also passed as a separate argument) is common
    across several diffusion use-cases. Thus for convenience, we provide the
    wrapper class :class:`~physicsnemo.diffusion.utils.ConcatConditionWrapper`
    to save you the trouble of writing your own wrapper for this common
    pattern:

    >>> from physicsnemo.diffusion.preconditioners import EDMPreconditioner
    >>> from physicsnemo.diffusion.utils import ConcatConditionWrapper
    >>> base_model = SongUNet(img_resolution=8, in_channels=4, out_channels=3, label_dim=4)
    >>> wrapped_model = ConcatConditionWrapper(base_model)
    >>> precond = EDMPreconditioner(wrapped_model, sigma_data=0.5)
    >>> x = torch.rand(1, 3, 8, 8)
    >>> t = torch.rand(1)
    >>> condition = TensorDict({
    ...     "cond_concat": torch.rand(1, 1, 8, 8),  # image condition
    ...     "cond_vec": torch.rand(1, 4),  # vector condition
    ... }, batch_size=[1])
    >>> out = precond(x, t, condition)
    >>> out.shape
    torch.Size([1, 3, 8, 8])

    The same wrapper can be used with :class:`~physicsnemo.models.dit.DiT`
    backbones as well, with ``cond_vec`` passed to the model's ``condition``
    argument.
    """

    def __init__(
        self,
        model: Module,
        meta: ModelMetaData | None = None,
    ) -> None:
        super().__init__()
        self.meta = meta
        self.model = model

    @abstractmethod
    def compute_coefficients(
        self, t: torch.Tensor, /
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""
        Compute the preconditioning coefficients for a given diffusion time
        :math:`t` or noise level :math:`\sigma`.

        This abstract method must be implemented by subclasses to define the
        specific preconditioning scheme.

        Parameters
        ----------
        t : torch.Tensor
            Diffusion time (or noise level if :meth:`sigma` is implemented)
            tensor of shape :math:`(B, 1, ..., 1)` where :math:`B` is the
            batch size and the trailing singleton dimensions match the
            spatial dimensions of the latent state ``x`` for broadcasting.

        Returns
        -------
        c_in : torch.Tensor
            Input scaling coefficient of shape :math:`(B, 1, ..., 1)`.
        c_noise : torch.Tensor
            Noise conditioning value of shape :math:`(B, 1, ..., 1)`.
        c_out : torch.Tensor
            Output scaling coefficient of shape :math:`(B, 1, ..., 1)`.
        c_skip : torch.Tensor
            Skip connection scaling coefficient of shape
            :math:`(B, 1, ..., 1)`.
        """
        ...

    def sigma(self, t: torch.Tensor) -> torch.Tensor:
        r"""
        Map diffusion time :math:`t` to noise level :math:`\sigma(t)`.

        By default, this is the identity function :math:`\sigma(t) = t`.
        Subclasses can override this to implement preconditioners for
        different time-to-noise-level mappings. When overridden, the output
        of this method is passed to :meth:`compute_coefficients` instead of
        the raw time ``t``.

        Parameters
        ----------
        t : torch.Tensor
            Diffusion time tensor of shape :math:`(B,)` where :math:`B` is
            the batch size.

        Returns
        -------
        torch.Tensor
            Noise level :math:`\sigma(t)` of shape :math:`(B,)`.
        """
        return t

    def _validate_input(self, x: torch.Tensor) -> None:
        """Subclass hook for additional input validation.

        Called inside the ``torch.compiler.is_compiling()`` guard in
        :meth:`forward`, so it is skipped during ``torch.compile`` tracing.
        The default implementation is a no-op.
        """

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        condition: torch.Tensor | TensorDict | None = None,
        **model_kwargs: Any,
    ) -> torch.Tensor:
        # Eager-only sanity checks (skipped while torch.compile is tracing).
        if not torch.compiler.is_compiling():
            batch = x.shape[0]
            if t.shape != (batch,):
                raise ValueError(
                    f"Expected t to have shape ({batch},) matching batch size of "
                    f"x, but got {t.shape}."
                )
            if isinstance(condition, TensorDict):
                cond_batch = condition.batch_size
                if cond_batch and cond_batch[0] != batch:
                    raise ValueError(
                        f"Condition TensorDict has batch size {cond_batch[0]} "
                        f"but expected {batch} to match x."
                    )
            elif isinstance(condition, torch.Tensor):
                cond_batch = condition.shape[0]
                if cond_batch != batch:
                    raise ValueError(
                        f"Condition tensor has batch size {cond_batch} "
                        f"but expected {batch} to match x."
                    )
            self._validate_input(x)

        # Map the diffusion time to a noise level and give it trailing
        # singleton dimensions so that it broadcasts against x.
        bcast_shape = (-1,) + (1,) * (x.ndim - 1)
        noise_level = self.sigma(t).reshape(bcast_shape)

        # Unwrap to a plain tensor so that element-wise arithmetic between
        # the coefficients and the sampled sigma (which may be Replicated in
        # domain-parallel scenarios) is type-compatible.
        noise_level = _ensure_plain_tensor(noise_level)
        c_in, c_noise, c_out, c_skip = self.compute_coefficients(noise_level)

        # FSDP may convert model buffers (e.g. sigma_data) to DTensors, and
        # that wrapping propagates through compute_coefficients.  Unwrap the
        # coefficients first; when x carries a device mesh (ShardTensor),
        # promote them again to replicated DTensors on that mesh so that the
        # arithmetic between coefficients and data is type-compatible.
        c_in, c_noise, c_out, c_skip = [
            _ensure_plain_tensor(coeff) for coeff in (c_in, c_noise, c_out, c_skip)
        ]
        mesh = getattr(x, "device_mesh", None)
        if mesh is not None:
            c_in, c_noise, c_out, c_skip = [
                _replicate_on_mesh(coeff, mesh)
                for coeff in (c_in, c_noise, c_out, c_skip)
            ]

        # Forward through the underlying model; ``condition`` is forwarded
        # as a keyword only when one was supplied.
        cond_kwargs = {} if condition is None else {"condition": condition}
        F_x = self.model(
            c_in * x,
            c_noise.flatten(),
            **cond_kwargs,
            **model_kwargs,
        )

        return c_skip * x + c_out * F_x
class VPPreconditioner(BaseAffinePreconditioner):
    r"""
    Variance Preserving (VP) preconditioner.

    Implements the preconditioning scheme from the VP formulation of
    score-based generative models. The time-to-noise-level mapping is:

    .. math::
        \sigma(t) = \sqrt{\exp\left(\frac{\beta_d}{2} t^2
        + \beta_{\min} t\right) - 1}

    The preconditioning coefficients are:

    .. math::
        c_{\text{skip}} &= 1 \\
        c_{\text{out}} &= -\sigma \\
        c_{\text{in}} &= \frac{1}{\sqrt{\sigma^2 + 1}} \\
        c_{\text{noise}} &= (M - 1) \cdot \sigma^{-1}(\sigma)

    With these coefficients, the preconditioned model output is expected to
    be an :math:`\mathbf{x}_0`-prediction (clean data estimate). This
    preconditioner is not directly compatible for score-prediction training
    or others. For training, it is usually paired with
    :class:`~physicsnemo.diffusion.metrics.losses.MSEDSMLoss`
    (``prediction_type="x0"``) and
    :class:`~physicsnemo.diffusion.noise_schedulers.VPNoiseScheduler`.

    Parameters
    ----------
    model : physicsnemo.Module
        The underlying neural network model to wrap with signature described
        in :class:`BaseAffinePreconditioner`.
    beta_d : float, optional
        Extent of the noise level schedule, by default 19.9.
    beta_min : float, optional
        Initial slope of the noise level schedule, by default 0.1.
    M : int, optional
        Number of discretization steps in the DDPM formulation, by default
        1000.

    Forward
    -------
    x : torch.Tensor
        Noisy latent state of shape :math:`(B, *)` where :math:`B` is the
        batch size and :math:`*` denotes any number of additional dimensions.
    t : torch.Tensor
        Diffusion time tensor of shape :math:`(B,)`.
    condition : torch.Tensor, TensorDict, or None, optional, default=None
        Single Tensor or a TensorDict containing conditioning tensors with
        batch size :math:`B` matching that of ``x``. Pass ``None`` for an
        unconditional model.
    **model_kwargs : Any
        Additional keyword arguments passed to the underlying model.

    Outputs
    -------
    torch.Tensor
        Preconditioned model output with the same shape as the original
        model output.

    Note
    ----
    Reference: `Score-Based Generative Modeling through Stochastic
    Differential Equations <https://arxiv.org/abs/2011.13456>`_

    Examples
    --------
    >>> import torch
    >>> from physicsnemo.core import Module
    >>> # Define a simple model satisfying the diffusion model interface
    >>> class SimpleModel(Module):
    ...     def __init__(self, channels: int):
    ...         super().__init__()
    ...         self.net = torch.nn.Conv2d(channels, channels, 1)
    ...     def forward(self, x, t, condition=None):
    ...         return self.net(x)
    >>> model = SimpleModel(channels=3)
    >>> precond = VPPreconditioner(model, beta_d=19.9, beta_min=0.1, M=1000)
    >>> x = torch.randn(2, 3, 16, 16)  # batch of 2 images
    >>> t = torch.rand(2)  # diffusion time for each sample
    >>> out = precond(x, t, condition=None)
    >>> out.shape
    torch.Size([2, 3, 16, 16])
    """

    def __init__(
        self,
        model: Module,
        beta_d: float = 19.9,
        beta_min: float = 0.1,
        M: int = 1000,
    ) -> None:
        super().__init__(model)
        # Schedule constants are stored as buffers of the module.
        self.register_buffer("beta_d", torch.tensor(beta_d))
        self.register_buffer("beta_min", torch.tensor(beta_min))
        self.register_buffer("M", torch.tensor(M))

    def sigma(self, t: torch.Tensor) -> torch.Tensor:
        r"""
        Compute :math:`\sigma(t)` for the VP formulation.

        Parameters
        ----------
        t : torch.Tensor
            Diffusion time tensor of shape :math:`(B,)`.

        Returns
        -------
        torch.Tensor
            Noise level :math:`\sigma(t)` of shape :math:`(B,)`.
        """
        exponent = 0.5 * self.beta_d * (t**2) + self.beta_min * t
        # expm1 evaluates exp(x) - 1 without the cancellation that the
        # explicit subtraction suffers from when the exponent is close to
        # zero (i.e. for small diffusion times).
        return torch.expm1(exponent).sqrt()

    def compute_coefficients(
        self, sigma: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""
        Compute VP preconditioning coefficients.

        Parameters
        ----------
        sigma : torch.Tensor
            Noise level tensor of shape :math:`(B, 1, ..., 1)`.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
            Preconditioning coefficients (:math:`c_{\text{in}}`,
            :math:`c_{\text{noise}}`, :math:`c_{\text{out}}`,
            :math:`c_{\text{skip}}`) of shape :math:`(B, 1, ..., 1)`.
        """
        c_skip = torch.ones_like(sigma)
        c_out = -sigma
        c_in = 1 / (sigma**2 + 1).sqrt()
        # Compute t = sigma_inv(sigma).  log1p keeps log(1 + sigma^2)
        # accurate when sigma is small, matching the expm1 used in sigma().
        t = (
            (self.beta_min**2 + 2 * self.beta_d * torch.log1p(sigma**2)).sqrt()
            - self.beta_min
        ) / self.beta_d
        c_noise = (self.M - 1) * t
        return c_in, c_noise, c_out, c_skip
class VEPreconditioner(BaseAffinePreconditioner):
    r"""
    Variance Exploding (VE) preconditioner.

    Implements the preconditioning scheme from the VE formulation of
    score-based generative models. For VE, the time-to-noise-level mapping is
    the identity: :math:`\sigma(t) = t`.

    The preconditioning coefficients are:

    .. math::
        c_{\text{skip}} &= 1 \\
        c_{\text{out}} &= \sigma \\
        c_{\text{in}} &= 1 \\
        c_{\text{noise}} &= \log(0.5 \cdot \sigma)

    With these coefficients, the preconditioned model output is expected to
    be an :math:`\mathbf{x}_0`-prediction (clean data estimate). This
    preconditioner is not directly compatible for score-prediction training
    or others. For training, it is usually paired with
    :class:`~physicsnemo.diffusion.metrics.losses.MSEDSMLoss`
    (``prediction_type="x0"``) and
    :class:`~physicsnemo.diffusion.noise_schedulers.VENoiseScheduler`.

    Parameters
    ----------
    model : physicsnemo.Module
        The underlying neural network model to wrap with signature described
        in :class:`BaseAffinePreconditioner`.

    Forward
    -------
    x : torch.Tensor
        Noisy latent state of shape :math:`(B, *)` where :math:`B` is the
        batch size and :math:`*` denotes any number of additional dimensions.
    t : torch.Tensor
        Diffusion time tensor of shape :math:`(B,)`.
    condition : torch.Tensor, TensorDict, or None, optional, default=None
        Single Tensor or a TensorDict containing conditioning tensors with
        batch size :math:`B` matching that of ``x``. Pass ``None`` for an
        unconditional model.
    **model_kwargs : Any
        Additional keyword arguments passed to the underlying model.

    Outputs
    -------
    torch.Tensor
        Preconditioned model output with the same shape as the original
        model output.

    Note
    ----
    Reference: `Score-Based Generative Modeling through Stochastic
    Differential Equations <https://arxiv.org/abs/2011.13456>`_

    Examples
    --------
    >>> import torch
    >>> from physicsnemo.core import Module
    >>> # Define a simple model satisfying the diffusion model interface
    >>> class SimpleModel(Module):
    ...     def __init__(self, channels: int):
    ...         super().__init__()
    ...         self.net = torch.nn.Conv2d(channels, channels, 1)
    ...     def forward(self, x, t, condition=None):
    ...         return self.net(x)
    >>> model = SimpleModel(channels=3)
    >>> precond = VEPreconditioner(model)
    >>> x = torch.randn(2, 3, 16, 16)  # batch of 2 images
    >>> t = torch.rand(2)  # diffusion time for each sample
    >>> out = precond(x, t, condition=None)
    >>> out.shape
    torch.Size([2, 3, 16, 16])
    """

    def __init__(self, model: Module) -> None:
        super().__init__(model)

    def compute_coefficients(
        self, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""
        Compute VE preconditioning coefficients.

        Parameters
        ----------
        t : torch.Tensor
            Diffusion time tensor of shape :math:`(B, 1, ..., 1)`.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
            Preconditioning coefficients (:math:`c_{\text{in}}`,
            :math:`c_{\text{noise}}`, :math:`c_{\text{out}}`,
            :math:`c_{\text{skip}}`) of shape :math:`(B, 1, ..., 1)`.
        """
        # Input and skip scalings are both identically one (kept as two
        # separate tensors); the output scaling is the noise level itself.
        skip_scale = torch.ones_like(t)
        in_scale = torch.ones_like(t)
        half_sigma = 0.5 * t
        noise_cond = half_sigma.log()
        return in_scale, noise_cond, t, skip_scale
class IDDPMPreconditioner(BaseAffinePreconditioner):
    r"""
    Improved DDPM (iDDPM) preconditioner.

    Implements the preconditioning scheme from the improved DDPM formulation.

    The preconditioning coefficients are:

    .. math::
        c_{\text{skip}} &= 1 \\
        c_{\text{out}} &= -\sigma \\
        c_{\text{in}} &= \frac{1}{\sqrt{\sigma^2 + 1}} \\
        c_{\text{noise}} &= M - 1 - \text{argmin}|\sigma - u_j|

    where :math:`u_j, j = 0, ..., M` are the precomputed noise levels from
    the iDDPM discretization.

    With these coefficients, the preconditioned model output is expected to
    be an :math:`\mathbf{x}_0`-prediction (clean data estimate). This
    preconditioner is not directly compatible for score-prediction training
    or others. For training, it is usually paired with
    :class:`~physicsnemo.diffusion.metrics.losses.MSEDSMLoss`
    (``prediction_type="x0"``) and
    :class:`~physicsnemo.diffusion.noise_schedulers.IDDPMNoiseScheduler`.

    Parameters
    ----------
    model : physicsnemo.Module
        The underlying neural network model to wrap with signature described
        in :class:`BaseAffinePreconditioner`.
    C_1 : float, optional
        Timestep adjustment at low noise levels, by default 0.001.
    C_2 : float, optional
        Timestep adjustment at high noise levels, by default 0.008.
    M : int, optional
        Number of discretization steps in the DDPM formulation, by default
        1000.

    Forward
    -------
    x : torch.Tensor
        Noisy latent state of shape :math:`(B, *)` where :math:`B` is the
        batch size and :math:`*` denotes any number of additional dimensions.
    t : torch.Tensor
        Diffusion time tensor of shape :math:`(B,)`.
    condition : torch.Tensor, TensorDict, or None, optional, default=None
        Single Tensor or a TensorDict containing conditioning tensors with
        batch size :math:`B` matching that of ``x``. Pass ``None`` for an
        unconditional model.
    **model_kwargs : Any
        Additional keyword arguments passed to the underlying model.

    Outputs
    -------
    torch.Tensor
        Preconditioned model output with the same shape as the original
        model output.

    Note
    ----
    Reference: `Improved Denoising Diffusion Probabilistic Models
    <https://arxiv.org/abs/2102.09672>`_

    Examples
    --------
    >>> import torch
    >>> from physicsnemo.core import Module
    >>> # Define a simple model satisfying the diffusion model interface
    >>> class SimpleModel(Module):
    ...     def __init__(self, channels: int):
    ...         super().__init__()
    ...         self.net = torch.nn.Conv2d(channels, channels, 1)
    ...     def forward(self, x, t, condition=None):
    ...         return self.net(x)
    >>> model = SimpleModel(channels=3)
    >>> precond = IDDPMPreconditioner(model, C_1=0.001, C_2=0.008, M=1000)
    >>> x = torch.randn(2, 3, 16, 16)  # batch of 2 images
    >>> t = torch.rand(2)  # diffusion time for each sample
    >>> out = precond(x, t, condition=None)
    >>> out.shape
    torch.Size([2, 3, 16, 16])
    """

    def __init__(
        self,
        model: Module,
        C_1: float = 0.001,
        C_2: float = 0.008,
        M: int = 1000,
    ) -> None:
        super().__init__(model)
        self.register_buffer("C_1", torch.tensor(C_1))
        self.register_buffer("C_2", torch.tensor(C_2))
        self.register_buffer("M", torch.tensor(M))

        # Precompute the noise level schedule u_j, j = 0, ..., M.  The
        # recurrence runs backwards from u_M = 0: each u_{j-1} is derived
        # from u_j and the ratio of consecutive alpha_bar values, with the
        # ratio clamped from below by C_1.
        u = torch.zeros(M + 1)
        for j in range(M, 0, -1):
            angle_j = 0.5 * math.pi * j / M / (C_2 + 1)
            angle_jm1 = 0.5 * math.pi * (j - 1) / M / (C_2 + 1)
            alpha_bar_j = math.sin(angle_j) ** 2
            alpha_bar_jm1 = math.sin(angle_jm1) ** 2
            alpha_ratio = alpha_bar_jm1 / alpha_bar_j
            u[j - 1] = ((u[j] ** 2 + 1) / max(alpha_ratio, C_1) - 1).sqrt()
        self.register_buffer("u", u)

    def compute_coefficients(
        self, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""
        Compute iDDPM preconditioning coefficients.

        The noise conditioning value is obtained by rounding each entry of
        ``t`` to the index of the nearest entry of the precomputed schedule
        ``u``. For that lookup ``t`` is cast to the dtype of ``u``.

        Parameters
        ----------
        t : torch.Tensor
            Diffusion time tensor of shape :math:`(B, 1, ..., 1)`.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
            Preconditioning coefficients (:math:`c_{\text{in}}`,
            :math:`c_{\text{noise}}`, :math:`c_{\text{out}}`,
            :math:`c_{\text{skip}}`) of shape :math:`(B, 1, ..., 1)`.
        """
        c_skip = torch.ones_like(t)
        c_out = -t
        c_in = 1 / (t**2 + 1).sqrt()
        # Round sigma to nearest index in precomputed schedule u
        u: torch.Tensor = self.u  # type: ignore[assignment]
        # Cast to the schedule's dtype so that both cdist operands share a
        # dtype even when t arrives in another floating-point precision.
        t_flat = t.reshape(1, -1, 1).to(u.dtype)
        u_reshaped = u.reshape(1, -1, 1)
        idx = torch.cdist(t_flat, u_reshaped).argmin(2).reshape(t.shape)
        c_noise = self.M - 1 - idx
        return c_in, c_noise, c_out, c_skip
class EDMPreconditioner(BaseAffinePreconditioner):
    r"""
    EDM preconditioner.

    Implements the improved preconditioning scheme proposed in the EDM
    paper. For EDM, the time-to-noise-level mapping is the identity:
    :math:`\sigma(t) = t`.

    The preconditioning coefficients are:

    .. math::
        c_{\text{skip}} &= \frac{\sigma_{\text{data}}^2}
        {\sigma^2 + \sigma_{\text{data}}^2} \\
        c_{\text{out}} &= \frac{\sigma \cdot \sigma_{\text{data}}}
        {\sqrt{\sigma^2 + \sigma_{\text{data}}^2}} \\
        c_{\text{in}} &= \frac{1}
        {\sqrt{\sigma_{\text{data}}^2 + \sigma^2}} \\
        c_{\text{noise}} &= \frac{\log(\sigma)}{4}

    With these coefficients, the preconditioned model output is expected to
    be an :math:`\mathbf{x}_0`-prediction (clean data estimate). This
    preconditioner is not directly compatible for score-prediction training
    or others. For training, it is usually paired with
    :class:`~physicsnemo.diffusion.metrics.losses.MSEDSMLoss`
    (``prediction_type="x0"``) and
    :class:`~physicsnemo.diffusion.noise_schedulers.EDMNoiseScheduler`.

    Parameters
    ----------
    model : physicsnemo.Module
        The underlying neural network model to wrap with signature described
        in :class:`BaseAffinePreconditioner`.
    sigma_data : float or Sequence[float] or torch.Tensor, optional
        Expected standard deviation of the training data, by default 0.5.
        When a scalar ``float`` is given, the same value is applied to all
        channels. When a ``Sequence[float]`` or 1-D ``Tensor`` of length
        :math:`C` is given, each output channel receives its own
        preconditioning and :math:`c_{\text{skip}}`,
        :math:`c_{\text{out}}`, :math:`c_{\text{in}}` become per-channel
        while :math:`c_{\text{noise}}` remains scalar. The per-channel form
        should be paired with an
        :class:`~physicsnemo.diffusion.noise_schedulers.EDMNoiseScheduler`
        constructed with the same per-channel ``sigma_data``.

    Forward
    -------
    x : torch.Tensor
        Noisy latent state of shape :math:`(B, *)` where :math:`B` is the
        batch size and :math:`*` denotes any number of additional dimensions.
    t : torch.Tensor
        Diffusion time tensor of shape :math:`(B,)`.
    condition : torch.Tensor, TensorDict, or None, optional, default=None
        Single Tensor or a TensorDict containing conditioning tensors with
        batch size :math:`B` matching that of ``x``. Pass ``None`` for an
        unconditional model.
    **model_kwargs : Any
        Additional keyword arguments passed to the underlying model.

    Outputs
    -------
    torch.Tensor
        Preconditioned model output with the same shape as the original
        model output.

    Note
    ----
    Reference: `Elucidating the Design Space of Diffusion-Based Generative
    Models <https://arxiv.org/abs/2206.00364>`_

    Examples
    --------
    Scalar ``sigma_data`` (default):

    >>> import torch
    >>> from physicsnemo.core import Module
    >>> # Define a simple model satisfying the diffusion model interface
    >>> class SimpleModel(Module):
    ...     def __init__(self, channels: int):
    ...         super().__init__()
    ...         self.net = torch.nn.Conv2d(channels, channels, 1)
    ...     def forward(self, x, t, condition=None):
    ...         return self.net(x)
    >>> model = SimpleModel(channels=3)
    >>> precond = EDMPreconditioner(model, sigma_data=0.5)
    >>> x = torch.randn(2, 3, 16, 16)  # batch of 2 images
    >>> t = torch.rand(2)  # diffusion time for each sample
    >>> out = precond(x, t, condition=None)
    >>> out.shape
    torch.Size([2, 3, 16, 16])

    Per-channel ``sigma_data`` for heterogeneous channels (e.g. weather
    variables with different scales):

    >>> precond_ch = EDMPreconditioner(model, sigma_data=[0.3, 0.5, 0.7])
    >>> out_ch = precond_ch(x, t, condition=None)
    >>> out_ch.shape
    torch.Size([2, 3, 16, 16])
    """

    def __init__(
        self,
        model: Module,
        sigma_data: float | Sequence[float] | Float[torch.Tensor, " C"] = 0.5,
    ) -> None:
        super().__init__(model)

        # Normalise the three accepted input forms to a tensor.
        if isinstance(sigma_data, torch.Tensor):
            sd = sigma_data.detach().to(dtype=torch.float32).reshape(-1)
        else:
            is_sequence = isinstance(sigma_data, Sequence) and not isinstance(
                sigma_data, str
            )
            if is_sequence:
                sd = torch.as_tensor(list(sigma_data), dtype=torch.float32)
            else:
                sd = torch.as_tensor(float(sigma_data))

        # A single-element vector is treated like a scalar.
        if sd.ndim > 0 and sd.numel() == 1:
            sd = sd.squeeze()
        self.register_buffer("sigma_data", sd)

        # Store the number of channels for per-channel sigma_data so that
        # validation in forward() can check against x without reading
        # sigma_data.ndim at runtime (which could cause a graph break).
        # Zero means "scalar sigma_data".
        self._sigma_data_channels: int = sd.numel() if sd.ndim > 0 else 0

    def _validate_input(self, x: torch.Tensor) -> None:
        """Check that the channel dim of *x* matches per-channel sigma_data.

        Nothing is checked for scalar ``sigma_data`` or when *x* has fewer
        than two dimensions.
        """
        n_channels = self._sigma_data_channels
        if n_channels <= 0 or x.ndim < 2 or x.shape[1] == n_channels:
            return
        raise ValueError(
            f"EDMPreconditioner has per-channel sigma_data with "
            f"{n_channels} channels, but input x has {x.shape[1]} channels "
            f"(shape {tuple(x.shape)})."
        )

    def _reshape_sigma_data(self, t: torch.Tensor) -> torch.Tensor:
        """Reshape ``sigma_data`` for broadcasting against *t*.

        For scalar ``sigma_data`` this is a no-op. For per-channel
        ``sigma_data`` of shape ``(C,)`` the result is ``(1, C, 1, ..., 1)``
        matching the spatial dims of *t*.
        """
        if self._sigma_data_channels <= 0:
            return self.sigma_data
        trailing = (1,) * (t.ndim - 2)
        return self.sigma_data.view(1, -1, *trailing)

    def compute_coefficients(
        self, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        r"""
        Compute EDM preconditioning coefficients.

        Parameters
        ----------
        t : torch.Tensor
            Diffusion time (or noise level, since they are identical for
            EDM) of shape :math:`(B, 1, ..., 1)`.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
            Preconditioning coefficients (:math:`c_{\text{in}}`,
            :math:`c_{\text{noise}}`, :math:`c_{\text{out}}`,
            :math:`c_{\text{skip}}`). When ``sigma_data`` is scalar all
            coefficients have shape :math:`(B, 1, ..., 1)`. When
            ``sigma_data`` is per-channel, :math:`c_{\text{in}}`,
            :math:`c_{\text{out}}`, and :math:`c_{\text{skip}}` have shape
            :math:`(B, C, 1, ..., 1)` while :math:`c_{\text{noise}}` remains
            :math:`(B, 1, ..., 1)`.
        """
        sd = self._reshape_sigma_data(t)
        # sigma^2 + sigma_data^2 appears in three of the four coefficients.
        total_var = t**2 + sd**2
        total_std = total_var.sqrt()
        c_in = 1 / total_std
        c_out = t * sd / total_std
        c_skip = sd**2 / total_var
        c_noise = t.log() / 4
        return c_in, c_noise, c_out, c_skip