# Source code for physicsnemo.diffusion.base

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Protocols and type hints for diffusion model interfaces."""

from typing import Any, Protocol, runtime_checkable

import torch
from jaxtyping import Float
from tensordict import TensorDict


@runtime_checkable
class DiffusionModel(Protocol):
    r"""
    Protocol defining the common interface for diffusion models.

    A diffusion model is any neural network or function that transforms a
    noisy state ``x`` at diffusion time (or noise level) ``t`` into a
    prediction. This protocol defines the standard interface that all
    diffusion models must satisfy. Any model or function that implements this
    interface can be used with preconditioners, losses, samplers, and other
    diffusion utilities.

    The interface is **prediction-agnostic**: whether your model predicts
    clean data (:math:`\mathbf{x}_0`), noise (:math:`\epsilon`), score
    (:math:`\nabla \log p`), or velocity (:math:`\mathbf{v}`), the signature
    remains the same.

    The interface supports both conditional and unconditional diffusion
    models. The ``condition`` argument supports different conditioning
    scenarios:

    - **torch.Tensor**: Use when there is a single conditioning tensor
      (e.g., a class embedding or a single image).
    - **TensorDict**: Use when multiple conditioning tensors are needed,
      possibly with different shapes. The string keys can be used to provide
      semantic information about each conditioning tensor.
    - **None**: Use for unconditional generation or specific scenarios like
      classifier-free guidance where the model should ignore conditioning.

    Notes
    -----
    The protocol is decorated with :func:`typing.runtime_checkable`, so
    ``isinstance(obj, DiffusionModel)`` only verifies that ``obj`` exposes a
    ``__call__`` member. The call signature, argument types and tensor shapes
    are **not** validated at runtime.

    Examples
    --------
    >>> import torch
    >>> import torch.nn.functional as F
    >>> from physicsnemo.diffusion import DiffusionModel
    >>>
    >>> class Model:
    ...     def __call__(self, x, t, condition=None, **kwargs):
    ...         return F.relu(x)
    ...
    >>> isinstance(Model(), DiffusionModel)
    True
    """

    def __call__(
        self,
        x: Float[torch.Tensor, " B *dims"],
        t: Float[torch.Tensor, " B"],
        condition: Float[torch.Tensor, " B *cond_dims"] | TensorDict | None = None,
        **model_kwargs: Any,
    ) -> Float[torch.Tensor, " B *dims"]:
        r"""
        Forward pass of the diffusion model.

        Parameters
        ----------
        x : torch.Tensor
            Noisy latent state of shape :math:`(B, *)` where :math:`B` is the
            batch size and :math:`*` denotes any number of additional
            dimensions (e.g., channels and spatial dimensions).
        t : torch.Tensor
            Diffusion time or noise level tensor of shape :math:`(B,)`.
        condition : torch.Tensor, TensorDict, or None, optional, default=None
            Conditioning information for the model. If a Tensor or a
            TensorDict is passed, it should have batch size :math:`B`
            matching that of ``x``. Pass ``None`` for an unconditional model.
        **model_kwargs : Any
            Additional keyword arguments specific to the model
            implementation.

        Returns
        -------
        torch.Tensor
            Model output with the same shape as ``x``.
        """
        ...
@runtime_checkable
class Predictor(Protocol):
    r"""
    Protocol defining a predictor interface for diffusion models.

    A predictor is any callable that takes a noisy state ``x`` and diffusion
    time ``t``, and returns a prediction about the clean data or the noise.
    Common types of predictors include x0-predictor (predicts the clean data
    :math:`\mathbf{x}_0`), score-predictor, noise-predictor (predicts the
    noise :math:`\boldsymbol{\epsilon}`), velocity-predictor etc.

    This protocol is **generic** and does not assume any specific type of
    prediction. A predictor can be a trained neural network, a guidance
    function (e.g., classifier-free guidance, DPS-style guidance), or any
    combination thereof. The exact meaning of the output depends on the
    predictor type and how it is used. Any callable that implements this
    interface can be used as a predictor in sampling utilities.

    This protocol is typically used during inference. For training, which
    often requires additional inputs like conditioning, use the more general
    :class:`DiffusionModel` protocol instead. A :class:`Predictor` can be
    obtained from a :class:`DiffusionModel` by partially applying the
    ``condition`` and any other keyword arguments using
    ``functools.partial``.

    **Relationship to Denoiser:**

    A :class:`Denoiser` is the update function used during sampling (e.g.,
    the right-hand side of an ODE/SDE). It is obtained from a
    :class:`Predictor` via the
    :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.get_denoiser`
    factory. A typical case is ODE/SDE-based sampling, where one solves:

    .. math::
        \frac{d\mathbf{x}}{dt} = D(\mathbf{x}, t;\, P(\mathbf{x}, t))

    where :math:`P` is the **predictor** and :math:`D` is the **denoiser**
    that wraps it. This equation captures the essence of how these two
    concepts are related in the framework.

    See Also
    --------
    :class:`Denoiser` : The interface for sampling update functions.
    :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.get_denoiser` :
        Factory to convert a predictor into a denoiser.

    Notes
    -----
    The protocol is decorated with :func:`typing.runtime_checkable`, so
    ``isinstance(obj, Predictor)`` only verifies that ``obj`` exposes a
    ``__call__`` member. The call signature is **not** validated at runtime.

    Examples
    --------
    **Example 1:** Convert a trained conditional model into a predictor using
    ``functools.partial``:

    >>> import torch
    >>> from functools import partial
    >>> from tensordict import TensorDict
    >>> from physicsnemo.diffusion import Predictor
    >>>
    >>> class MyModel:
    ...     def __call__(self, x, t, condition=None):
    ...         # x0-predictor: returns estimate of clean data
    ...         # (here assumes conditional normal distribution N(x|y))
    ...         t_bc = t.view(-1, *([1] * (x.ndim - 1)))
    ...         return x / (1 + t_bc**2) + condition["y"]
    ...
    >>> model = MyModel()
    >>> cond = TensorDict({"y": torch.randn(2, 4)}, batch_size=[2])
    >>> x0_predictor = partial(model, condition=cond)
    >>> isinstance(x0_predictor, Predictor)
    True

    **Example 2:** Convert the x0-predictor above into a score-predictor
    (using a simple EDM-like schedule where :math:`\sigma(t) = t` and
    :math:`\alpha(t) = 1`):

    >>> def x0_to_score(x0, x_t, t):
    ...     # Broadcast t over all non-batch dimensions of x_t
    ...     t_bc = t.view(-1, *([1] * (x_t.ndim - 1)))
    ...     sigma_sq = t_bc ** 2
    ...     return (x0 - x_t) / sigma_sq
    >>>
    >>> def score_predictor(x, t):
    ...     x0_pred = x0_predictor(x, t)
    ...     return x0_to_score(x0_pred, x, t)
    >>>
    >>> isinstance(score_predictor, Predictor)
    True
    """

    def __call__(
        self,
        x: Float[torch.Tensor, " B *dims"],
        t: Float[torch.Tensor, " B"],
    ) -> Float[torch.Tensor, " B *dims"]:
        r"""
        Forward pass of the predictor.

        Parameters
        ----------
        x : torch.Tensor
            Noisy latent state of shape :math:`(B, *)` where :math:`B` is the
            batch size and :math:`*` denotes any number of additional
            dimensions (e.g., channels and spatial dimensions).
        t : torch.Tensor
            Batched diffusion time tensor of shape :math:`(B,)`.

        Returns
        -------
        torch.Tensor
            Prediction output with the same shape as ``x``. The exact meaning
            depends on the predictor type (x0, score, noise, velocity, etc.).
        """
        ...
@runtime_checkable
class Denoiser(Protocol):
    r"""
    Protocol defining a denoiser interface for diffusion model sampling.

    A denoiser is the **update function** used during sampling. It takes a
    noisy state ``x`` and diffusion time ``t``, and returns the update term
    consumed by a :class:`~physicsnemo.diffusion.samplers.solvers.Solver`.
    For continuous-time methods this is typically the right-hand side of the
    ODE/SDE, but the interface is generic and can support other sampling
    methods as well.

    This is the interface used by
    :class:`~physicsnemo.diffusion.samplers.solvers.Solver` classes and the
    :func:`~physicsnemo.diffusion.samplers.sample` function. Any callable
    that implements this interface can be used as a denoiser.

    **Important distinction from Predictor:**

    - A :class:`Predictor` is any callable that outputs a raw prediction
      (e.g., clean data :math:`\mathbf{x}_0`, score, guidance signal, etc.).
    - A :class:`Denoiser` is the update function derived from one or more
      predictors, used directly by the solver during sampling.

    **Typical workflow:**

    1. Start with one or more :class:`Predictor` instances (e.g. trained
       model)
    2. Optionally combine predictors (e.g., conditional + guidance scores)
    3. Convert to a :class:`Denoiser` using
       :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.get_denoiser`
    4. Pass the denoiser to :func:`~physicsnemo.diffusion.samplers.sample`
       together with a
       :class:`~physicsnemo.diffusion.samplers.solvers.Solver`

    See Also
    --------
    :class:`Predictor` : The interface for raw predictions.
    :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.get_denoiser` :
        Factory to convert a predictor into a denoiser.
    :func:`~physicsnemo.diffusion.samplers.sample` :
        The sampling function that uses this denoiser interface.

    Notes
    -----
    The protocol is decorated with :func:`typing.runtime_checkable`, so
    ``isinstance(obj, Denoiser)`` only verifies that ``obj`` exposes a
    ``__call__`` member. The call signature is **not** validated at runtime,
    which also means a :class:`Predictor` and a :class:`Denoiser` cannot be
    told apart with ``isinstance``.

    Examples
    --------
    Manually creating a denoiser from an x0-predictor using a simple EDM-like
    schedule (:math:`\sigma(t)=t`, :math:`\alpha(t)=1`):

    >>> import torch
    >>> from physicsnemo.diffusion import Denoiser
    >>>
    >>> # Start from a predictor (x0-predictor)
    >>> def x0_predictor(x, t):
    ...     t_bc = t.view(-1, *([1] * (x.ndim - 1)))
    ...     return x / (1 + t_bc**2)
    >>>
    >>> # Build a denoiser (ODE RHS) from scratch:
    >>> # score = (x0 - x) / sigma^2, ODE RHS = -0.5 * g^2 * score
    >>> # For EDM: sigma = t, g^2 = 2*t, so
    >>> # RHS = -0.5 * 2*t * (x0 - x) / t^2 = (x - x0) / t
    >>> def my_denoiser(x, t):
    ...     x0 = x0_predictor(x, t)
    ...     t_bc = t.view(-1, *([1] * (x.ndim - 1)))
    ...     return (x - x0) / t_bc
    ...
    >>> isinstance(my_denoiser, Denoiser)
    True
    """

    def __call__(
        self,
        x: Float[torch.Tensor, " B *dims"],
        t: Float[torch.Tensor, " B"],
    ) -> Float[torch.Tensor, " B *dims"]:
        r"""
        Compute the denoising update at the given state and time.

        Parameters
        ----------
        x : torch.Tensor
            Noisy latent state of shape :math:`(B, *)` where :math:`B` is the
            batch size and :math:`*` denotes any number of additional
            dimensions (e.g., channels and spatial dimensions).
        t : torch.Tensor
            Batched diffusion time tensor of shape :math:`(B,)`. All batch
            elements in the latent state ``x`` typically share the same
            diffusion time values, but ``t`` is still required to be a
            batched tensor.

        Returns
        -------
        torch.Tensor
            Denoising update term with the same shape as ``x``.
        """
        ...