# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Diffusion model sampling interface."""
from typing import Any, Dict, List, Literal
import torch.distributed as dist
from jaxtyping import Float
from torch import Tensor
from torch.distributed.tensor.placement_types import Replicate
from physicsnemo.diffusion.base import Denoiser
from physicsnemo.diffusion.noise_schedulers import NoiseScheduler
from physicsnemo.domain_parallel.shard_tensor import scatter_tensor
from .solvers import (
EDMStochasticEulerSolver,
EDMStochasticHeunSolver,
EulerSolver,
HeunSolver,
Solver,
)
# Registry of built-in solvers, keyed by the string accepted for the
# ``solver`` argument of ``sample``. Each value is a solver class that is
# instantiated as ``cls(denoiser, **solver_options)``. Insertion order is the
# order in which the keys are listed in the "Unknown solver" error message.
SOLVERS: Dict[str, type[Solver]] = {
    "euler": EulerSolver,
    "heun": HeunSolver,
    "edm_stochastic_euler": EDMStochasticEulerSolver,
    "edm_stochastic_heun": EDMStochasticHeunSolver,
}
def _maybe_replicate_timesteps(
    t_steps: Float[Tensor, " N_plus_1"],
    xN: Float[Tensor, " B *dims"],
) -> Float[Tensor, " N_plus_1"]:
    """Put ``t_steps`` on the device mesh of ``xN`` if only ``xN`` has one.

    Solver arithmetic mixes the latent state with time-step values. When the
    latents are a distributed tensor (e.g. a ``ShardTensor`` used for domain
    parallelism) and the time-steps are a plain tensor, the two are not
    type-compatible. In that situation the time-steps are wrapped as a
    replicated distributed tensor on the mesh of ``xN``.

    Parameters
    ----------
    t_steps : Tensor
        Diffusion time values. Returned untouched when it already exposes a
        ``device_mesh`` attribute.
    xN : Tensor
        Latent state whose ``device_mesh`` attribute (if present and not
        ``None``) selects the target mesh.

    Returns
    -------
    Tensor
        ``t_steps`` itself (same object) when ``xN`` has no mesh or
        ``t_steps`` is already on one; otherwise the result of
        ``scatter_tensor`` with a single ``Replicate()`` placement.
    """
    mesh = getattr(xN, "device_mesh", None)

    # Wrapping is only required when the latents are distributed and the
    # time-steps are not; every other combination is a no-op.
    needs_wrapping = mesh is not None and not hasattr(t_steps, "device_mesh")
    if not needs_wrapping:
        return t_steps

    # ``scatter_tensor`` takes a global rank as its source: translate rank 0
    # of the mesh's process group into the corresponding global rank.
    root_rank = dist.get_global_rank(mesh.get_group(), 0)

    # Keep the shape and dtype of the plain tensor on the replicated copy.
    replicate_kwargs = {
        "placements": (Replicate(),),
        "global_shape": t_steps.shape,
        "dtype": t_steps.dtype,
    }
    return scatter_tensor(t_steps, root_rank, mesh, **replicate_kwargs)
# Public sampling entry point.
def sample(
    denoiser: Denoiser,
    xN: Float[Tensor, " B *dims"],
    noise_scheduler: NoiseScheduler,
    num_steps: int,
    solver: Literal["euler", "heun", "edm_stochastic_euler", "edm_stochastic_heun"]
    | Solver = "heun",
    time_steps: Float[Tensor, " N_plus_1"] | None = None,
    solver_options: Dict[str, Any] | None = None,
    time_eval: list[int] | None = None,
) -> Float[Tensor, " B *dims"] | List[Float[Tensor, " B *dims"]]:
    r"""
    Generate batched samples from a diffusion model.

    Starting from the noisy latent state ``xN``, the sampler walks through a
    sequence of diffusion times and lets a solver update the latent state once
    per time interval. The interface is deliberately generic and can be used
    with any reverse diffusion process of the form:

    .. math::
        \mathbf{x}_{n-1} = G (\mathbf{x}_{i \geq n}, t_{i \geq n-1})

    This covers both ODE/SDE-based sampling (e.g. VP, VE, EDM) and discrete
    Markov chain-based sampling (e.g. DDPM). The exact expression of the
    operator :math:`G` results from the combination of:

    - The ``solver``, which determines the numerical method used to update
      the latent state :math:`\mathbf{x}_n` at each time-step.
    - The ``denoiser``, which can be the right hand side for ODE/SDE-based
      sampling, the denoised latent state for discrete Markov chain-based
      sampling, etc.

    Typically, the update applied is roughly:

    .. math::
        \mathbf{x}_{n-1} = \text{Step}(D(\mathbf{x}_n, t_n);
        \mathbf{x}_n, t_n, t_{n-1})

    where :math:`D` is the ``denoiser`` and :math:`\text{Step}` is the
    update rule of the solver, implemented by the
    :meth:`~physicsnemo.diffusion.samplers.solvers.Solver.step` method.
    Variants are possible by passing more complex solvers and denoisers.

    The ``solver`` is given either as a string key (optionally combined with
    ``solver_options``) or as an already configured object implementing the
    :class:`~physicsnemo.diffusion.samplers.solvers.Solver` interface (in
    which case ``solver_options`` must be ``None`` or empty). The solver must
    provide a ``step`` method with the following signature:

    .. code-block:: python

        def step(
            self,
            x: Tensor,  # shape: (B, *dims)
            t_cur: Tensor,  # shape: (B,)
            t_next: Tensor,  # shape: (B,)
        ) -> Tensor: ...  # updated x, shape: (B, *dims)

    Any object that implements the
    :class:`~physicsnemo.diffusion.samplers.solvers.Solver` interface can be
    used as a solver.

    The ``denoiser`` must implement the
    :class:`~physicsnemo.diffusion.Denoiser` interface, with the following
    signature:

    .. code-block:: python

        def denoiser(
            x: Tensor,  # Noisy latent state, shape (B, *dims)
            t: Tensor,  # Diffusion time, shape (B,)
        ) -> Tensor:  # ODE/SDE RHS, same shape (B, *dims) as x

    Any object that implements the :class:`~physicsnemo.diffusion.Denoiser`
    interface can be used as a denoiser. A denoiser is typically obtained from
    a :class:`~physicsnemo.diffusion.Predictor` using the noise scheduler's
    :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.get_denoiser`
    factory.

    Time-steps are generated by the ``noise_scheduler`` using its
    :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.timesteps`
    method with the provided ``num_steps``. To use custom time-steps, pass a
    1D tensor to ``time_steps``, which overrides the schedule's time-steps.

    Parameters
    ----------
    denoiser : Denoiser
        A callable that takes ``(x, t)`` and returns the denoising update
        term with the same shape as the latent state ``xN``. See
        :class:`~physicsnemo.diffusion.Denoiser` for the expected interface.
        Typically obtained via the
        :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.get_denoiser`
        factory, which converts a :class:`~physicsnemo.diffusion.Predictor`
        (e.g., score-predictor, x0-predictor) into a denoiser. It is only
        forwarded to the solver constructor when ``solver`` is a string key.
    xN : Tensor
        Initial noisy latent state :math:`\mathbf{x}_N` of shape :math:`(B, *)`
        where :math:`B` is the batch size. All batch elements share the same
        diffusion time values. The ``dtype`` and ``device`` of ``xN`` determine
        the ``dtype`` and ``device`` of the generated samples and any
        internally created tensors. Can usually be obtained by using
        :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.init_latents`
        from a noise scheduler (typically the same instance that is passed as
        the ``noise_scheduler`` argument, but it can be a different one if
        desired).
    noise_scheduler : NoiseScheduler
        The noise scheduler instance used for generating time-steps. The
        schedule's
        :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.timesteps`
        method is called with ``num_steps`` to produce the diffusion time
        values, unless ``time_steps`` is provided to override them.
    num_steps : int
        Number of sampling steps. Passed to the noise scheduler's
        :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.timesteps`
        method. Ignored when ``time_steps`` is provided.
    solver : str | Solver, default="heun"
        The numerical solver to use. Supports three levels of customizability:

        - **Basic**: Pass a string key to use a built-in solver
          with default settings.
        - **Moderately advanced**: Pass a string key plus
          ``solver_options`` to override default solver parameters.
        - **Advanced**: Pass a custom :class:`Solver` instance
          implementing the
          :class:`~physicsnemo.diffusion.samplers.solvers.Solver` interface.
          In this case, ``solver_options`` must be ``None`` or empty.

        Available string keys:

        * ``"euler"``: First-order Euler method. Fast but lower quality.
          See :class:`~physicsnemo.diffusion.samplers.solvers.EulerSolver`.
        * ``"heun"``: Second-order Heun method. Higher quality but requires
          two denoiser evaluations per step.
          See :class:`~physicsnemo.diffusion.samplers.solvers.HeunSolver`.
        * ``"edm_stochastic_euler"``: First-order stochastic sampler from
          the EDM paper with configurable noise injection. See
          :class:`~physicsnemo.diffusion.samplers.solvers.EDMStochasticEulerSolver`.
        * ``"edm_stochastic_heun"``: Second-order stochastic sampler from
          the EDM paper with configurable noise injection. See
          :class:`~physicsnemo.diffusion.samplers.solvers.EDMStochasticHeunSolver`.
    time_steps : Tensor | None, default=None
        Optional 1D tensor of shape :math:`(N + 1,)` containing explicit
        diffusion time values :math:`t_N, t_{N-1}, ..., t_0` in decreasing
        order. If provided, overrides the time-steps from ``noise_scheduler``
        and ``num_steps`` is ignored. The tensor is moved to the ``device``
        and ``dtype`` of ``xN``. To produce a fully denoised latent state
        :math:`\mathbf{x}_0`, the last element must be :math:`t_0 = 0`.
    solver_options : Dict[str, Any] | None, default=None
        Additional keyword options passed to the solver constructor. ``None``
        is treated as an empty dictionary. Only used when ``solver`` is a
        string; must be ``None`` or empty when ``solver`` is a
        :class:`Solver` instance. See individual solver classes for available
        options.
    time_eval : List[int] | None, default=None
        Indices of the solver steps after which the latent state is
        recorded. Must contain values in ``range(0, num_steps)`` (or
        ``range(0, len(time_steps) - 1)`` when ``time_steps`` is provided).
        Index ``i`` selects the state obtained right after the ``i``-th
        solver step. If provided, a list of tensors is returned; the states
        appear in increasing step order, once per distinct index, regardless
        of the ordering of ``time_eval``. If ``None``, only the final
        denoised latent state :math:`\mathbf{x}_0` is returned.

    Returns
    -------
    Tensor | List[Tensor]
        If ``time_eval`` is ``None``, returns the final denoised latent state
        :math:`\mathbf{x}_0` of shape :math:`(B, *)`. Otherwise, returns a list
        of tensors :math:`\mathbf{x}_t` of shape :math:`(B, *)` containing
        (cloned) latent states at the step indices specified in ``time_eval``.

    Raises
    ------
    ValueError
        If ``solver`` is a string that is not a known solver key, if
        ``solver_options`` is non-empty while ``solver`` is a solver
        instance, or if ``time_eval`` contains an index outside the valid
        range.

    See Also
    --------
    :mod:`~physicsnemo.diffusion.samplers.solvers` : Available ODE/SDE solvers.
    :mod:`~physicsnemo.diffusion.noise_schedulers` : Available noise schedules.

    Examples
    --------
    **Example 1:** Minimal usage. Just provide a denoiser, initial noise, a
    scheduler, and the number of steps.

    >>> import torch
    >>> from physicsnemo.diffusion.samplers import sample
    >>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
    >>>
    >>> # Toy denoiser (in practice, this would be a trained neural network)
    >>> denoiser = lambda x, t: x / (1 + t.view(-1, *([1] * (x.ndim - 1)))**2)  # Toy denoiser
    >>> scheduler = EDMNoiseScheduler()
    >>> xN = torch.randn(2, 3, 8, 8) * 80  # Initial noise scaled by sigma_max
    >>> x0 = sample(denoiser, xN, scheduler, num_steps=10)
    >>> x0.shape
    torch.Size([2, 3, 8, 8])

    **Example 2:** Standard pattern using scheduler methods. Use
    ``init_latents`` to generate initial noise and ``get_denoiser`` to convert
    a predictor to a denoiser for sampling.

    >>> import torch
    >>> from physicsnemo.diffusion.samplers import sample
    >>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
    >>>
    >>> scheduler = EDMNoiseScheduler()
    >>> t_steps = scheduler.timesteps(10)
    >>> tN = t_steps[0].expand(2)  # Initial time for batch of 2
    >>>
    >>> # Use scheduler to generate initial latents at time tN
    >>> xN = scheduler.init_latents((3, 8, 8), tN)
    >>>
    >>> # Convert x0-predictor to denoiser (score conversion is automatic)
    >>> x0_predictor = lambda x, t: x / (1 + t.view(-1, *([1] * (x.ndim - 1)))**2)  # Toy x0-predictor
    >>> denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
    >>>
    >>> x0 = sample(denoiser, xN, scheduler, num_steps=10)
    >>> x0.shape
    torch.Size([2, 3, 8, 8])

    **Example 3:** Custom time-steps and solver. Same as Example 2, but using
    explicit time-steps and the faster (but lower quality) Euler solver.

    >>> import torch
    >>> from physicsnemo.diffusion.samplers import sample
    >>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
    >>>
    >>> scheduler = EDMNoiseScheduler()
    >>>
    >>> # Custom time-steps (fewer steps for faster sampling)
    >>> custom_t = torch.tensor([80.0, 40.0, 20.0, 10.0, 5.0, 0.0])
    >>> tN = custom_t[0].expand(2)
    >>> xN = scheduler.init_latents((3, 8, 8), tN)
    >>>
    >>> # Same denoiser setup as Example 2
    >>> x0_predictor = lambda x, t: x / (1 + t.view(-1, *([1] * (x.ndim - 1)))**2)  # Toy x0-predictor
    >>> denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
    >>>
    >>> # Use custom time-steps and Euler solver (num_steps ignored)
    >>> x0 = sample(denoiser, xN, scheduler, num_steps=0, time_steps=custom_t,
    ...             solver="euler")
    >>> x0.shape
    torch.Size([2, 3, 8, 8])

    **Example 4:** Bare-bone custom scheduler. Define a scheduler from scratch
    implementing the :class:`NoiseScheduler` protocol, without importing any
    built-in scheduler class.

    >>> import torch
    >>> from physicsnemo.diffusion.samplers import sample
    >>>
    >>> # Define a minimal EDM-like scheduler from scratch
    >>> class MinimalScheduler:
    ...     def timesteps(self, num_steps, *, device=None, dtype=None):
    ...         return torch.linspace(1.0, 0.0, num_steps + 1,
    ...                               device=device, dtype=dtype)
    ...     def sample_time(self, N, *, device=None, dtype=None):
    ...         return torch.rand(N, device=device, dtype=dtype)
    ...     def add_noise(self, x0, time):
    ...         return x0 + time.view(-1, 1, 1, 1) * torch.randn_like(x0)
    ...     def init_latents(self, spatial_shape, tN, *, device=None,
    ...                      dtype=None):
    ...         return tN.view(-1, 1, 1, 1) * torch.randn(
    ...             tN.shape[0], *spatial_shape, device=device, dtype=dtype)
    ...     def get_denoiser(self, *, x0_predictor=None, **kwargs):
    ...         # EDM-like: sigma=t, alpha=1, g^2=2t
    ...         # score = (x0 - x) / t^2, ODE RHS = (x0 - x) / t
    ...         def _denoiser(x, t):
    ...             x0 = x0_predictor(x, t)
    ...             t_bc = t.view(-1, *([1] * (x.ndim - 1)))
    ...             return (x0 - x) / t_bc
    ...         return _denoiser
    >>>
    >>> scheduler = MinimalScheduler()
    >>> tN = torch.tensor([1.0, 1.0])
    >>> xN = scheduler.init_latents((3, 8, 8), tN)
    >>>
    >>> # x0-predictor -> denoiser via the scheduler factory
    >>> x0_predictor = lambda x, t: x / (1 + t.view(-1, *([1] * (x.ndim - 1)))**2)  # Toy x0-predictor
    >>> denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
    >>> x0 = sample(denoiser, xN, scheduler, num_steps=10, solver="euler")
    >>> x0.shape
    torch.Size([2, 3, 8, 8])
    """
    # A missing options mapping behaves like an empty one.
    if solver_options is None:
        solver_options = {}

    # --- Resolve ``solver`` into an object exposing ``step`` ---------------
    if not isinstance(solver, str):
        # Anything that is not a string is taken to be a ready-made,
        # Solver-like object; it cannot be reconfigured from here.
        if solver_options:
            raise ValueError(
                "solver_options must be None when solver is a Solver instance."
            )
        stepper = solver
    elif solver in SOLVERS:
        # Built-in solver: construct it around the denoiser.
        stepper = SOLVERS[solver](denoiser, **solver_options)
    else:
        available = ", ".join(f'"{k}"' for k in SOLVERS.keys())
        raise ValueError(
            f"Unknown solver '{solver}'. Available solvers: {available}."
        )

    # --- Build the time grid on the device/dtype of the latents ------------
    if time_steps is None:
        t_steps = noise_scheduler.timesteps(
            num_steps, device=xN.device, dtype=xN.dtype
        )
    else:
        # Explicit time-steps take precedence over the scheduler.
        t_steps = time_steps.to(device=xN.device, dtype=xN.dtype)

    # When xN is a distributed tensor (e.g. ShardTensor for domain
    # parallelism) but t_steps is a plain tensor, t_steps is replicated on
    # the same mesh so that arithmetic between the two is type-compatible.
    t_steps = _maybe_replicate_timesteps(t_steps, xN)

    # N + 1 time values delimit N solver steps.
    n_steps = len(t_steps) - 1

    # --- Validate the requested snapshot indices up front -------------------
    collect = time_eval is not None
    if collect:
        out_of_range = [i for i in time_eval if i < 0 or i >= n_steps]
        if out_of_range:
            raise ValueError(
                f"time_eval contains out-of-range indices {out_of_range}; "
                f"valid indices are in range(0, {n_steps})."
            )

    # --- Main sampling loop -------------------------------------------------
    samples: List[Tensor] = []
    x = xN
    for i in range(n_steps):
        # Broadcast the scalar times of this interval to one value per
        # batch element: scalar -> (B,)
        batch_size = x.shape[0]
        t_cur_batch = t_steps[i].expand(batch_size)
        t_next_batch = t_steps[i + 1].expand(batch_size)

        # Advance the latent state across the interval.
        x = stepper.step(x, t_cur_batch, t_next_batch)

        # Keep a copy of the state if this step index was requested.
        if collect and i in time_eval:
            samples.append(x.clone())

    # A list of snapshots when time_eval was given, else the final state.
    return samples if collect else x