# Source code for physicsnemo.diffusion.samplers.samplers

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Diffusion model sampling interface."""

from typing import Any, Dict, List, Literal

import torch.distributed as dist
from jaxtyping import Float
from torch import Tensor
from torch.distributed.tensor.placement_types import Replicate

from physicsnemo.diffusion.base import Denoiser
from physicsnemo.diffusion.noise_schedulers import NoiseScheduler
from physicsnemo.domain_parallel.shard_tensor import scatter_tensor

from .solvers import (
    EDMStochasticEulerSolver,
    EDMStochasticHeunSolver,
    EulerSolver,
    HeunSolver,
    Solver,
)

# Registry of built-in solver classes, keyed by the string names accepted by
# the ``solver`` argument of ``sample``. Each class is instantiated there as
# ``cls(denoiser, **solver_options)``.
SOLVERS: Dict[str, type[Solver]] = dict(
    euler=EulerSolver,
    heun=HeunSolver,
    edm_stochastic_euler=EDMStochasticEulerSolver,
    edm_stochastic_heun=EDMStochasticHeunSolver,
)


def _maybe_replicate_timesteps(
    t_steps: Float[Tensor, " N_plus_1"],
    xN: Float[Tensor, " B *dims"],
) -> Float[Tensor, " N_plus_1"]:
    """Put ``t_steps`` on the device mesh of ``xN`` if that is required.

    Solver arithmetic mixes latents with time-step scalars. When ``xN`` is
    attached to a device mesh (for instance a ``ShardTensor`` used for domain
    parallelism) while ``t_steps`` is not, the two would not be
    type-compatible. In that case ``t_steps`` is scattered onto the mesh of
    ``xN`` with a ``Replicate()`` placement, keeping its shape and dtype.

    Parameters
    ----------
    t_steps : Tensor
        1D tensor of diffusion time-steps.
    xN : Tensor
        Latent state whose ``device_mesh`` attribute (if any) is inspected.

    Returns
    -------
    Tensor
        ``t_steps`` itself when ``xN`` has no device mesh or when ``t_steps``
        already exposes a ``device_mesh`` attribute; otherwise a replicated
        distributed tensor living on the mesh of ``xN``.
    """
    mesh = getattr(xN, "device_mesh", None)
    already_compatible = mesh is None or hasattr(t_steps, "device_mesh")
    if already_compatible:
        return t_steps

    # The scatter source is local rank 0 of the mesh's process group,
    # translated into its global rank.
    # NOTE(review): ``get_group()`` without a mesh dim and a single
    # ``Replicate()`` placement presumably assume a 1-D mesh -- confirm
    # before using with multi-dimensional meshes.
    src_rank = dist.get_global_rank(mesh.get_group(), 0)
    replicate_all = (Replicate(),)
    return scatter_tensor(
        t_steps,
        src_rank,
        mesh,
        placements=replicate_all,
        global_shape=t_steps.shape,
        dtype=t_steps.dtype,
    )


def sample(
    denoiser: Denoiser,
    xN: Float[Tensor, " B *dims"],
    noise_scheduler: NoiseScheduler,
    num_steps: int,
    solver: Literal["euler", "heun", "edm_stochastic_euler", "edm_stochastic_heun"] | Solver = "heun",
    time_steps: Float[Tensor, " N_plus_1"] | None = None,
    solver_options: Dict[str, Any] | None = None,
    time_eval: list[int] | None = None,
) -> Float[Tensor, " B *dims"] | List[Float[Tensor, " B *dims"]]:
    r"""
    Generate batched samples from a diffusion model.

    This interface is quite generic and can be used to generate samples from
    any reverse diffusion process of the form:

    .. math::
        \mathbf{x}_{n-1} = G (\mathbf{x}_{i \geq n}, t_{i \geq n-1})

    This covers both ODE/SDE-based sampling (e.g. VP, VE, EDM) and discrete
    Markov chain-based sampling (e.g. DDPM). The exact expression of the
    operator :math:`G` depends on the combination of:

    - The ``solver``, which determines the numerical method to update the
      latent state :math:`\mathbf{x}_n` at each time-step.
    - The ``denoiser``, which can be the right hand side for ODE/SDE-based
      sampling, the denoised latent state for discrete Markov chain-based
      sampling, etc.

    Typically, the update applied is roughly:

    .. math::
        \mathbf{x}_{n-1} = \text{Step}(D(\mathbf{x}_n, t_n); \mathbf{x}_n, t_n, t_{n-1})

    where :math:`D` is the ``denoiser`` and :math:`\text{Step}` is the update
    rule of the solver, implemented by the
    :meth:`~physicsnemo.diffusion.samplers.solvers.Solver.step` method.
    Variants are possible by passing more complex solvers and denoisers.

    The ``solver`` can be specified as a string key (with optional
    ``solver_options``), or as a pre-configured object implementing the
    :class:`~physicsnemo.diffusion.samplers.solvers.Solver` interface (in
    which case ``solver_options`` must be ``None`` or empty). The solver must
    implement a ``step`` method with the following signature:

    .. code-block:: python

        def step(
            self,
            x: Tensor,       # shape: (B, *dims)
            t_cur: Tensor,   # shape: (B,)
            t_next: Tensor,  # shape: (B,)
        ) -> Tensor:
            ...              # updated x, shape: (B, *dims)

    Any object that implements the
    :class:`~physicsnemo.diffusion.samplers.solvers.Solver` interface can be
    used as a solver.

    The ``denoiser`` must implement the :class:`~physicsnemo.diffusion.Denoiser`
    interface, with the following signature:

    .. code-block:: python

        def denoiser(
            x: Tensor,  # Noisy latent state, shape (B, *dims)
            t: Tensor,  # Diffusion time, shape (B,)
        ) -> Tensor:    # ODE/SDE RHS, same shape (B, *dims) as x
            ...

    Any object that implements the :class:`~physicsnemo.diffusion.Denoiser`
    interface can be used as a denoiser. A denoiser is typically obtained from
    a :class:`~physicsnemo.diffusion.Predictor` using the noise scheduler's
    :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.get_denoiser`
    factory.

    Time-steps are generated by the ``noise_scheduler`` using its
    :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.timesteps`
    method with the provided ``num_steps``. To use custom time-steps, pass a
    1D tensor to ``time_steps``, which will override the schedule's
    time-steps.

    Parameters
    ----------
    denoiser : Denoiser
        A callable that takes ``(x, t)`` and returns the denoising update term
        with the same shape as the latent state ``xN``. See
        :class:`~physicsnemo.diffusion.Denoiser` for the expected interface.
        Typically obtained via the
        :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.get_denoiser`
        factory, which converts a :class:`~physicsnemo.diffusion.Predictor`
        (e.g., score-predictor, x0-predictor) into a denoiser. It is only
        forwarded to the constructor of a built-in solver when ``solver`` is a
        string key; it is ignored when ``solver`` is a :class:`Solver`
        instance.
    xN : Tensor
        Initial noisy latent state :math:`\mathbf{x}_N` of shape
        :math:`(B, *)` where :math:`B` is the batch size. All batch elements
        share the same diffusion time values. The ``dtype`` and ``device`` of
        ``xN`` determine the ``dtype`` and ``device`` of the generated samples
        and any internally created tensors. Can usually be obtained by using
        :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.init_latents`
        from a noise scheduler (typically the same noise scheduler instance
        passed as the ``noise_scheduler`` argument, but can be different if
        desired). If ``xN`` lives on a device mesh (e.g. a ``ShardTensor``
        used for domain parallelism), the time-steps are replicated onto the
        same mesh.
    noise_scheduler : NoiseScheduler
        The noise scheduler instance used for generating time-steps. The
        schedule's
        :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.timesteps`
        method is called with ``num_steps`` to produce the diffusion time
        values, unless ``time_steps`` is provided to override them.
    num_steps : int
        Number of sampling steps. Passed to the noise scheduler's
        :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.timesteps`
        method. Ignored when ``time_steps`` is provided.
    solver : str | Solver, default="heun"
        The numerical solver to use. Supports three levels of customizability:

        **Basic**: Pass a string key to use a built-in solver with default
        settings.

        **Moderately advanced**: Pass a string key plus ``solver_options`` to
        override default solver parameters.

        **Advanced**: Pass a custom :class:`Solver` instance implementing the
        :class:`~physicsnemo.diffusion.samplers.solvers.Solver` interface. In
        this case, ``solver_options`` must be ``None`` or empty, and
        ``denoiser`` is not used.

        Available string keys:

        * ``"euler"``: First-order Euler method. Fast but lower quality. See
          :class:`~physicsnemo.diffusion.samplers.solvers.EulerSolver`.
        * ``"heun"``: Second-order Heun method. Higher quality but requires
          two denoiser evaluations per step. See
          :class:`~physicsnemo.diffusion.samplers.solvers.HeunSolver`.
        * ``"edm_stochastic_euler"``: First-order stochastic sampler from the
          EDM paper with configurable noise injection. See
          :class:`~physicsnemo.diffusion.samplers.solvers.EDMStochasticEulerSolver`.
        * ``"edm_stochastic_heun"``: Second-order stochastic sampler from the
          EDM paper with configurable noise injection. See
          :class:`~physicsnemo.diffusion.samplers.solvers.EDMStochasticHeunSolver`.
    time_steps : Tensor | None, default=None
        Optional 1D tensor of shape :math:`(N + 1,)` containing explicit
        diffusion time values :math:`t_N, t_{N-1}, ..., t_0` in decreasing
        order. If provided, overrides the time-steps from ``noise_scheduler``,
        and ``num_steps`` is ignored. It is cast to the ``device`` and
        ``dtype`` of ``xN``. To produce a fully denoised latent state
        :math:`\mathbf{x}_0`, the last element must be :math:`t_0 = 0`.
    solver_options : Dict[str, Any] | None, default=None
        Additional keyword options passed to the solver constructor. Only used
        when ``solver`` is a string; must be ``None`` or empty when ``solver``
        is a :class:`Solver` instance. See individual solver classes for
        available options.
    time_eval : List[int] | None, default=None
        Indices of solver steps after which to return intermediate samples.
        Must contain values in ``range(0, num_steps)`` (or
        ``range(0, len(time_steps) - 1)`` when ``time_steps`` is provided).
        Index ``i`` selects the latent state produced by the ``i``-th step,
        i.e. at diffusion time ``t_steps[i + 1]``. The last valid index
        therefore yields the same state as the final output. If provided, a
        list of tensors is returned. If ``None``, only the final latent state
        is returned.

    Returns
    -------
    Tensor | List[Tensor]
        If ``time_eval`` is ``None``, returns the final denoised latent state
        :math:`\mathbf{x}_0` of shape :math:`(B, *)`. Otherwise, returns a
        list of tensors of shape :math:`(B, *)` containing the latent states
        at the step indices in ``time_eval``. The list is ordered by
        increasing step index regardless of the order of ``time_eval``, and a
        repeated index yields a single entry.

    Raises
    ------
    ValueError
        If ``solver`` is an unknown string key, if ``solver_options`` is
        non-empty while ``solver`` is a :class:`Solver` instance, or if
        ``time_eval`` contains out-of-range indices.

    See Also
    --------
    :mod:`~physicsnemo.diffusion.samplers.solvers` : Available ODE/SDE solvers.
    :mod:`~physicsnemo.diffusion.noise_schedulers` : Available noise schedules.

    Examples
    --------
    **Example 1:** Minimal usage. Just provide a denoiser, initial noise, a
    scheduler, and the number of steps.

    >>> import torch
    >>> from physicsnemo.diffusion.samplers import sample
    >>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
    >>>
    >>> # Toy denoiser (in practice, this would be a trained neural network)
    >>> denoiser = lambda x, t: x / (1 + t.view(-1, *([1] * (x.ndim - 1)))**2)  # Toy denoiser
    >>> scheduler = EDMNoiseScheduler()
    >>> xN = torch.randn(2, 3, 8, 8) * 80  # Initial noise scaled by sigma_max
    >>> x0 = sample(denoiser, xN, scheduler, num_steps=10)
    >>> x0.shape
    torch.Size([2, 3, 8, 8])

    **Example 2:** Standard pattern using scheduler methods. Use
    ``init_latents`` to generate initial noise and ``get_denoiser`` to convert
    a predictor to a denoiser for sampling.

    >>> import torch
    >>> from physicsnemo.diffusion.samplers import sample
    >>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
    >>>
    >>> scheduler = EDMNoiseScheduler()
    >>> t_steps = scheduler.timesteps(10)
    >>> tN = t_steps[0].expand(2)  # Initial time for batch of 2
    >>>
    >>> # Use scheduler to generate initial latents at time tN
    >>> xN = scheduler.init_latents((3, 8, 8), tN)
    >>>
    >>> # Convert x0-predictor to denoiser (score conversion is automatic)
    >>> x0_predictor = lambda x, t: x / (1 + t.view(-1, *([1] * (x.ndim - 1)))**2)  # Toy x0-predictor
    >>> denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
    >>>
    >>> x0 = sample(denoiser, xN, scheduler, num_steps=10)
    >>> x0.shape
    torch.Size([2, 3, 8, 8])

    **Example 3:** Custom time-steps and solver. Same as Example 2, but using
    explicit time-steps and the faster (but lower quality) Euler solver.

    >>> import torch
    >>> from physicsnemo.diffusion.samplers import sample
    >>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
    >>>
    >>> scheduler = EDMNoiseScheduler()
    >>>
    >>> # Custom time-steps (fewer steps for faster sampling)
    >>> custom_t = torch.tensor([80.0, 40.0, 20.0, 10.0, 5.0, 0.0])
    >>> tN = custom_t[0].expand(2)
    >>> xN = scheduler.init_latents((3, 8, 8), tN)
    >>>
    >>> # Same denoiser setup as Example 2
    >>> x0_predictor = lambda x, t: x / (1 + t.view(-1, *([1] * (x.ndim - 1)))**2)  # Toy x0-predictor
    >>> denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
    >>>
    >>> # Use custom time-steps and Euler solver (num_steps ignored)
    >>> x0 = sample(denoiser, xN, scheduler, num_steps=0, time_steps=custom_t,
    ...             solver="euler")
    >>> x0.shape
    torch.Size([2, 3, 8, 8])

    **Example 4:** Bare-bone custom scheduler. Define a scheduler from scratch
    implementing the :class:`NoiseScheduler` protocol, without importing any
    built-in scheduler class.

    >>> import torch
    >>> from physicsnemo.diffusion.samplers import sample
    >>>
    >>> # Define a minimal EDM-like scheduler from scratch
    >>> class MinimalScheduler:
    ...     def timesteps(self, num_steps, *, device=None, dtype=None):
    ...         return torch.linspace(1.0, 0.0, num_steps + 1,
    ...                               device=device, dtype=dtype)
    ...     def sample_time(self, N, *, device=None, dtype=None):
    ...         return torch.rand(N, device=device, dtype=dtype)
    ...     def add_noise(self, x0, time):
    ...         return x0 + time.view(-1, 1, 1, 1) * torch.randn_like(x0)
    ...     def init_latents(self, spatial_shape, tN, *, device=None,
    ...                      dtype=None):
    ...         return tN.view(-1, 1, 1, 1) * torch.randn(
    ...             tN.shape[0], *spatial_shape, device=device, dtype=dtype)
    ...     def get_denoiser(self, *, x0_predictor=None, **kwargs):
    ...         # EDM-like: sigma=t, alpha=1, g^2=2t
    ...         # score = (x0 - x) / t^2, ODE RHS = (x0 - x) / t
    ...         def _denoiser(x, t):
    ...             x0 = x0_predictor(x, t)
    ...             t_bc = t.view(-1, *([1] * (x.ndim - 1)))
    ...             return (x0 - x) / t_bc
    ...         return _denoiser
    >>>
    >>> scheduler = MinimalScheduler()
    >>> tN = torch.tensor([1.0, 1.0])
    >>> xN = scheduler.init_latents((3, 8, 8), tN)
    >>>
    >>> # x0-predictor -> denoiser via the scheduler factory
    >>> x0_predictor = lambda x, t: x / (1 + t.view(-1, *([1] * (x.ndim - 1)))**2)  # Toy x0-predictor
    >>> denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
    >>> x0 = sample(denoiser, xN, scheduler, num_steps=10, solver="euler")
    >>> x0.shape
    torch.Size([2, 3, 8, 8])
    """
    if solver_options is None:
        solver_options = {}

    # Validate and instantiate solver. Built-in solvers are constructed around
    # ``denoiser``; a pre-built solver is used as-is, so ``denoiser`` is unused.
    if isinstance(solver, str):
        if solver not in SOLVERS:
            available = ", ".join(f'"{k}"' for k in SOLVERS.keys())
            raise ValueError(
                f"Unknown solver '{solver}'. Available solvers: {available}."
            )
        solver_cls = SOLVERS[solver]
        solver_ = solver_cls(denoiser, **solver_options)
    else:
        # Assume solver is a Solver-like object with a step method. Its
        # constructor already ran, so options cannot be applied; reject them
        # instead of silently dropping them.
        if solver_options:
            raise ValueError(
                "solver_options must be None when solver is a Solver instance."
            )
        solver_ = solver

    # Generate time-steps from noise_scheduler or use provided ones, always
    # matching the device/dtype of the latents.
    if time_steps is not None:
        t_steps = time_steps.to(device=xN.device, dtype=xN.dtype)
    else:
        t_steps = noise_scheduler.timesteps(
            num_steps, device=xN.device, dtype=xN.dtype
        )

    # When xN is a distributed tensor (e.g. ShardTensor for domain
    # parallelism) but t_steps is a plain tensor, replicate t_steps on the
    # same mesh so that solver arithmetic between latents and timesteps is
    # type-compatible.
    t_steps = _maybe_replicate_timesteps(t_steps, xN)

    # N + 1 time boundaries t_N, ..., t_0 define N solver steps. The last
    # boundary is usually 0, but user-provided time_steps may end elsewhere.
    n_steps = len(t_steps) - 1

    if time_eval is not None:
        out_of_range = [i for i in time_eval if i < 0 or i >= n_steps]
        if out_of_range:
            raise ValueError(
                f"time_eval contains out-of-range indices {out_of_range}; "
                f"valid indices are in range(0, {n_steps})."
            )

    # Main sampling loop
    samples: List[Tensor] = []
    x = xN
    # All batch elements share the same diffusion time, so each scalar
    # time-step is expanded to shape (B,). The batch size is loop-invariant.
    batch_size = xN.shape[0]
    for i in range(n_steps):
        t_cur_batch = t_steps[i].expand(batch_size)
        t_next_batch = t_steps[i + 1].expand(batch_size)

        # Perform one solver step: t_steps[i] -> t_steps[i + 1]
        x = solver_.step(x, t_cur_batch, t_next_batch)

        # Index i in time_eval selects the state after this step. Clone so
        # that later (possibly in-place) solver steps cannot alter the
        # stored snapshot.
        if time_eval is not None and i in time_eval:
            samples.append(x.clone())

    # Return based on time_eval
    if time_eval is not None:
        return samples
    return x