Samplers and Solvers#

The sampler is the main interface for generating new data from a trained diffusion model. Starting from pure noise \(\mathbf{x}_N\), the solver iteratively denoises the latent state through a sequence of time-steps until it reaches a clean sample \(\mathbf{x}_0\).

The central entry point is the sample() function. It takes a Denoiser, an initial noisy latent \(\mathbf{x}_N\), and a NoiseScheduler, and iterates the reverse process to produce samples.

Generic Sampling Process#

The sample() function supports any reverse process that can be written in the form:

\[\mathbf{x}_{n-1} = \text{Step}\bigl( D\!\bigl(\mathbf{x}_n, t_n;\, P(\mathbf{x}_n, t_n)\bigr);\; \mathbf{x}_n, t_n, t_{n-1}\bigr)\]

This equation is the foundation of the sampling process in the framework. Every component described on this page maps to one of the three terms:

  • \(P\) is the predictor: a Predictor that maps the noisy state and diffusion time to a prediction. This is where all model logic lives, including conditioning and guidance.

  • \(D\) is the denoiser: a Denoiser derived from \(P\) via the noise scheduler’s get_denoiser() factory.

  • \(\text{Step}\) is the solver’s numerical update rule.

This generic formulation encompasses standard ODE/SDE-based sampling, but also more advanced methods such as physics-informed posterior guidance (DPS), score-based data assimilation (SDA), and others. Any method that can express its update step through this denoiser/solver decomposition can be used with the sample() function.
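The decomposition can be illustrated with a minimal, dependency-free sketch in which plain Python floats stand in for tensors. It assumes an EDM-like schedule (\(\sigma(t) = t\)) and a toy Gaussian data distribution, for which the ideal x0-predictor is known in closed form. The names predictor, make_denoiser, euler_step, and toy_sample are hypothetical and are not part of the framework's API.

```python
import math

SIGMA_DATA = 0.5  # Assumed std of the toy data distribution N(0, SIGMA_DATA^2)


def predictor(x, t):
    """P: ideal x0 estimate E[x0 | x, t] for Gaussian data when sigma(t) = t."""
    return x * SIGMA_DATA**2 / (SIGMA_DATA**2 + t**2)


def make_denoiser(x0_predictor):
    """D: wrap an x0-predictor into the probability-flow ODE right-hand side."""
    def denoiser(x, t):
        return (x - x0_predictor(x, t)) / t
    return denoiser


def euler_step(denoiser, x, t_cur, t_next):
    """Step: one explicit Euler update from t_cur to t_next."""
    return x + (t_next - t_cur) * denoiser(x, t_cur)


def toy_sample(denoiser, x_N, time_steps):
    """Apply Step along a decreasing time grid t_N > ... > t_0."""
    x = x_N
    for t_cur, t_next in zip(time_steps[:-1], time_steps[1:]):
        x = euler_step(denoiser, x, t_cur, t_next)
    return x


t_max, num_steps = 80.0, 2000
time_steps = [t_max * (1 - i / num_steps) for i in range(num_steps + 1)]
x_N = 80.0  # One fixed "noisy" starting value at t = t_max
x_0 = toy_sample(make_denoiser(predictor), x_N, time_steps)

# Closed-form solution of this toy ODE at t = 0, for comparison
exact = x_N * SIGMA_DATA / math.sqrt(SIGMA_DATA**2 + t_max**2)
```

The denoiser is only ever evaluated at t_cur, so the final step to \(t_0 = 0\) never divides by zero. Swapping euler_step for another update rule changes \(\text{Step}\) without touching \(P\) or \(D\).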

Sampling Workflow#

A complete sampling workflow involves these steps:

  1. Load or reference a trained model satisfying the DiffusionModel interface (typically a backbone wrapped in a preconditioner).

  2. Build a Predictor (\(P\) in the sampling equation) by binding the conditioning via functools.partial, converting the three-argument DiffusionModel into a two-argument Predictor.

  3. Convert to a Denoiser (\(D\) in the equation). There are two paths:

    • Without guidance — pass the predictor directly to the noise scheduler’s get_denoiser() factory (as an x0_predictor or score_predictor).

    • With guidance — first instantiate one or more DPSGuidance objects, then combine them with the predictor using DPSScorePredictor to obtain a guided score-predictor. Finally, pass this guided score-predictor to get_denoiser().

  4. Initialize the noisy latent \(\mathbf{x}_N\) and the time-step schedule using the scheduler.

  5. Optionally configure a custom solver (\(\text{Step}\) in the equation) by instantiating a Solver (see Available Solvers), or simply pass a built-in string key (for example, "heun") to sample().

  6. Call sample() to run the reverse diffusion loop.
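Step 2 of the workflow above can be sketched with the standard library alone. Here toy_model is a hypothetical three-argument callable standing in for a DiffusionModel, and floats stand in for tensors; only functools.partial is taken from the workflow itself.

```python
from functools import partial


def toy_model(x, t, condition=None):
    """Hypothetical three-argument model: (x, t, condition) -> prediction."""
    shift = 0.0 if condition is None else condition
    return x / (1.0 + t**2) + shift


# Binding the condition leaves a two-argument callable (x, t) -> prediction
conditional_predictor = partial(toy_model, condition=0.25)
unconditional_predictor = partial(toy_model, condition=None)

y_cond = conditional_predictor(2.0, 1.0)
y_uncond = unconditional_predictor(2.0, 1.0)
```

Because the condition is frozen inside the predictor, everything downstream (denoiser, solver, sampler) only ever deals with the two-argument form.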

Example: Unconditional Image Generation#

This example shows the full workflow for an unconditional image model trained with the EDM formulation. It uses SongUNet as the backbone, wrapped with a thin adapter to match the DiffusionModel interface, and EDMPreconditioner for preconditioning.

The noise scheduler must generally be consistent between training and sampling—in particular, the same schedule family (for example, EDM, VP) should be used. Schedule parameters (for example, sigma_min, rho) can be adjusted at sampling time for experimentation, but the model was optimized for the training schedule, so large deviations may degrade sample quality.

import torch
from functools import partial
from physicsnemo.core import Module
from physicsnemo.models.diffusion_unets import SongUNet
from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
from physicsnemo.diffusion.preconditioners import EDMPreconditioner
from physicsnemo.diffusion.samplers import sample

# --- Backbone: wrap SongUNet to match the DiffusionModel interface ---
class UNetBackbone(Module):
    def __init__(self, img_resolution, channels, **kwargs):
        super().__init__()
        self.net = SongUNet(
            img_resolution=img_resolution,
            in_channels=channels,
            out_channels=channels,
            **kwargs,
        )
    def forward(self, x, t, condition=None):
        return self.net(x, noise_labels=t, class_labels=condition)

backbone = UNetBackbone(img_resolution=64, channels=3, model_channels=64,
                        channel_mult=[1, 2, 2], num_blocks=2)

# --- Preconditioner + training (sketch) ---
scheduler = EDMNoiseScheduler(sigma_min=0.002, sigma_max=80.0, rho=7)
precond = EDMPreconditioner(backbone, sigma_data=0.5)
# ... train with MSEDSMLoss(precond, scheduler) ...

# --- Sampling ---
precond.eval()

# Build predictor "P": bind condition=None for unconditional model
x0_predictor = partial(precond, condition=None)

# Convert to denoiser "D"
denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)

# Initialize time-steps and noisy latent
num_steps = 50
tN = scheduler.timesteps(num_steps)[0].expand(8)
xN = scheduler.init_latents((3, 64, 64), tN)

# Run sampling loop with Heun solver ("Step")
samples = sample(denoiser, xN, scheduler, num_steps=num_steps, solver="heun")
# samples.shape: (8, 3, 64, 64)

It is also possible to adjust the schedule parameters at sampling time. For instance, one might increase sigma_max or change rho to explore the effect on sample quality:

sampling_scheduler = EDMNoiseScheduler(sigma_min=0.002, sigma_max=120.0, rho=5)
denoiser = sampling_scheduler.get_denoiser(x0_predictor=x0_predictor)
tN = sampling_scheduler.timesteps(num_steps)[0].expand(8)
xN = sampling_scheduler.init_latents((3, 64, 64), tN)
samples = sample(denoiser, xN, sampling_scheduler, num_steps=num_steps)
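To build intuition for what sigma_min, sigma_max, and rho control, the sketch below implements the time-step discretization proposed in the EDM paper, \(\sigma_i = \bigl(\sigma_{\max}^{1/\rho} + \tfrac{i}{N-1}(\sigma_{\min}^{1/\rho} - \sigma_{\max}^{1/\rho})\bigr)^{\rho}\), followed by a final zero. Whether EDMNoiseScheduler.timesteps() reproduces this formula exactly is an assumption; edm_sigma_steps is a hypothetical helper written for illustration only.

```python
def edm_sigma_steps(num_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise levels from the EDM paper's rho-schedule, plus a final 0.

    Requires num_steps >= 2. Returns num_steps + 1 decreasing values.
    """
    inv_rho = 1.0 / rho
    hi, lo = sigma_max**inv_rho, sigma_min**inv_rho
    steps = [
        (hi + i / (num_steps - 1) * (lo - hi)) ** rho
        for i in range(num_steps)
    ]
    return steps + [0.0]


default_grid = edm_sigma_steps(11)           # rho = 7, as in the example above
linear_grid = edm_sigma_steps(11, rho=1.0)   # rho = 1 gives uniform spacing
```

Larger rho shortens the steps near sigma_min at the expense of longer steps near sigma_max, so the midpoint of default_grid sits at a much lower noise level than the midpoint of linear_grid.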

Example: Vector-Space Diffusion (Non-Image Data)#

The diffusion framework is not limited to image data. Any tensor-valued data can be used, including 1D vectors. Here the backbone uses the FullyConnected model from PhysicsNeMo, wrapped with a thin adapter to match the DiffusionModel interface.

import torch
from functools import partial
from physicsnemo.core import Module
from physicsnemo.models.mlp import FullyConnected
from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
from physicsnemo.diffusion.preconditioners import EDMPreconditioner
from physicsnemo.diffusion.samplers import sample

# Backbone: wrap FullyConnected to match the DiffusionModel interface
class FCBackbone(Module):
    def __init__(self, dim, hidden=256, num_layers=4):
        super().__init__()
        self.net = FullyConnected(
            in_features=dim, layer_size=hidden,
            out_features=dim, num_layers=num_layers,
        )
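    def forward(self, x, t, condition=None):
        # This toy backbone ignores both t and condition; a practical
        # backbone would also embed the diffusion time t
        return self.net(x)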

data_dim = 32
backbone = FCBackbone(dim=data_dim)
scheduler = EDMNoiseScheduler()
precond = EDMPreconditioner(backbone, sigma_data=1.0)
# ... train with MSEDSMLoss(precond, scheduler) ...

# Sampling
precond.eval()
x0_predictor = partial(precond, condition=None)
denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
num_steps = 50
tN = scheduler.timesteps(num_steps)[0].expand(16)
xN = scheduler.init_latents((data_dim,), tN)  # 1D latent: shape (16, 32)
samples = sample(denoiser, xN, scheduler, num_steps=num_steps)
# samples.shape: (16, 32)

Example: Conditional Sampling#

For conditional generation (for example, super-resolution), the model backbone processes both the noisy latent state and the conditioning input. A common pattern is to concatenate the conditioning image along the channel dimension inside a thin adapter. At sampling time, the conditioning is bound into the predictor (\(P\)) via functools.partial.

import torch
from functools import partial
from physicsnemo.core import Module
from physicsnemo.models.diffusion_unets import SongUNet
from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
from physicsnemo.diffusion.preconditioners import EDMPreconditioner
from physicsnemo.diffusion.samplers import sample

C_x, C_cond, res = 3, 3, 64   # Image channels, conditioning channels, resolution

# Backbone: SongUNet wrapped with an adapter that concatenates the
# conditioning image along the channel dimension
class ConditionalUNet(Module):
    def __init__(self):
        super().__init__()
        self.net = SongUNet(
            img_resolution=res,
            in_channels=C_x + C_cond,
            out_channels=C_x,
            model_channels=64,
            channel_mult=[1, 2, 2],
            num_blocks=2,
        )
    def forward(self, x, t, condition=None):
        x_cat = torch.cat([x, condition], dim=1)
        return self.net(x_cat, noise_labels=t, class_labels=None)

backbone = ConditionalUNet()
scheduler = EDMNoiseScheduler()
precond = EDMPreconditioner(backbone, sigma_data=0.5)
# ... train with MSEDSMLoss(precond, scheduler, condition=...) ...

# --- Sampling ---
precond.eval()

# Conditioning image (for example, low-resolution input for super-resolution)
low_res = torch.randn(4, C_cond, res, res)

# Bind condition into the predictor "P"
x0_predictor = partial(precond, condition=low_res)

# Convert to denoiser "D" and sample
denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
tN = scheduler.timesteps(50)[0].expand(4)
xN = scheduler.init_latents((C_x, res, res), tN)
samples = sample(denoiser, xN, scheduler, num_steps=50)
# samples.shape: (4, 3, 64, 64)

Example: Conditional Sampling with DPS Guidance#

DPS (Diffusion Posterior Sampling) guidance steers the sampling toward satisfying observation constraints. Guidance modifies the predictor \(P\) in the sampling equation: the guidance objects are combined with the x0-predictor into a guided score-predictor via DPSScorePredictor, and then converted to a denoiser \(D\) via the noise scheduler.

import torch
from functools import partial
from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
from physicsnemo.diffusion.guidance import DPSScorePredictor, DataConsistencyDPSGuidance
from physicsnemo.diffusion.samplers import sample

scheduler = EDMNoiseScheduler()

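# Build predictor "P" from a trained conditional model.
# `trained_model` and `condition` are assumed to exist already (for example,
# `precond` and `low_res` from the conditional sampling example)
x0_predictor = partial(trained_model, condition=condition)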

# Build guidance objects
mask = torch.zeros(4, 3, 64, 64, dtype=torch.bool)
mask[:, :, ::8, ::8] = True  # Observe every 8th pixel
y_obs = torch.randn(4, 3, 64, 64)
guidance = DataConsistencyDPSGuidance(mask=mask, y=y_obs, std_y=0.1)

# Combine predictor + guidance into a guided score-predictor
guided_score_predictor = DPSScorePredictor(
    x0_predictor=x0_predictor,
    x0_to_score_fn=scheduler.x0_to_score,
    guidances=guidance,
)

# Convert to denoiser "D" via the noise scheduler
denoiser = scheduler.get_denoiser(score_predictor=guided_score_predictor)

tN = scheduler.timesteps(50)[0].expand(4)
xN = scheduler.init_latents((3, 64, 64), tN)
samples = sample(denoiser, xN, scheduler, num_steps=50)

Example: Custom Solver and Custom Time-Steps#

This example shows how to define a solver from scratch by implementing the Solver protocol. Any object with a step(x, t_cur, t_next) method can serve as \(\text{Step}\) in the sampling equation. Here we implement the implicit trapezoidal rule (a second-order method), solved approximately with a few fixed-point iterations, and pair it with custom time-steps and trajectory snapshots.

import torch
from functools import partial
from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
from physicsnemo.diffusion.samplers import sample

# Custom solver: implicit trapezoidal rule (second-order)
class TrapezoidalSolver:
    def __init__(self, denoiser, num_inner_iters=3):
        self.denoiser = denoiser
        self.num_inner_iters = num_inner_iters

    def step(self, x, t_cur, t_next):
        t_cur_bc = t_cur.reshape(-1, *([1] * (x.ndim - 1)))
        t_next_bc = t_next.reshape(-1, *([1] * (x.ndim - 1)))
        h = t_next_bc - t_cur_bc
        d_cur = self.denoiser(x, t_cur)
        # Predictor: Euler step to get initial guess
        x_next = x + h * d_cur
        # Corrector: fixed-point iterations for the implicit trapezoidal rule
        for _ in range(self.num_inner_iters):
            d_next = self.denoiser(x_next, t_next)
            x_next = x + 0.5 * h * (d_cur + d_next)
        return x_next

scheduler = EDMNoiseScheduler()
# `trained_model` is assumed to be a trained DiffusionModel defined elsewhere
x0_predictor = partial(trained_model, condition=None)
denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)

# Custom time-steps
custom_t = torch.tensor([80.0, 40.0, 20.0, 10.0, 5.0, 2.0, 1.0, 0.5, 0.1, 0.0])
tN = custom_t[0].expand(4)
xN = scheduler.init_latents((3, 64, 64), tN)

solver = TrapezoidalSolver(denoiser, num_inner_iters=3)
trajectory = sample(
    denoiser, xN, scheduler,
    num_steps=0,            # Ignored when time_steps is provided
    time_steps=custom_t,
    solver=solver,
    time_eval=[0, 4, 7],    # Collect snapshots at steps 0, 4, 7
)
# trajectory is a list of 3 tensors, each of shape (4, 3, 64, 64)

Available Solvers#

Solvers implement the \(\text{Step}\) operator in the sampling equation. At each iteration, the solver evaluates the denoiser \(D(\mathbf{x}_n, t_n)\) (more than once for higher-order methods) and advances the latent state from time \(t_n\) to \(t_{n-1}\).

There are two ways to use solvers:

Built-in solvers can be selected by passing a string key to sample():

  • "euler": EulerSolver. First-order. Fast (one denoiser evaluation per step) but lower quality.

  • "heun": HeunSolver. Second-order. Higher quality but twice as expensive per step.

  • "edm_stochastic_euler": EDMStochasticEulerSolver. First-order with configurable stochastic noise injection.

  • "edm_stochastic_heun": EDMStochasticHeunSolver. Second-order with configurable stochastic noise injection.

Custom solvers can be defined by implementing the Solver protocol: any object with a step(x, t_cur, t_next) method. Pass the instance directly to sample() for full control over the integration method.
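The cost/accuracy trade-off between first-order and second-order solvers can be checked on a toy problem with a known solution. The sketch below is not the framework's EulerSolver or HeunSolver: it uses plain floats, a hypothetical CountingDenoiser for a Gaussian toy distribution under an EDM-like schedule (\(\sigma(t) = t\)), and a time grid that stops at a small positive time so that the second evaluation in Heun's method never hits \(t = 0\).

```python
import math

SIGMA_DATA = 0.5  # Assumed std of the toy Gaussian data distribution


class CountingDenoiser:
    """Toy ODE right-hand side (x - x0) / t that counts its evaluations."""

    def __init__(self):
        self.num_evals = 0

    def __call__(self, x, t):
        self.num_evals += 1
        x0 = x * SIGMA_DATA**2 / (SIGMA_DATA**2 + t**2)
        return (x - x0) / t


def euler_step(denoiser, x, t_cur, t_next):
    return x + (t_next - t_cur) * denoiser(x, t_cur)


def heun_step(denoiser, x, t_cur, t_next):
    h = t_next - t_cur
    d_cur = denoiser(x, t_cur)
    x_pred = x + h * d_cur              # Euler predictor
    d_next = denoiser(x_pred, t_next)   # Slope at the predicted point
    return x + 0.5 * h * (d_cur + d_next)


def integrate(step_fn, x, time_steps):
    denoiser = CountingDenoiser()
    for t_cur, t_next in zip(time_steps[:-1], time_steps[1:]):
        x = step_fn(denoiser, x, t_cur, t_next)
    return x, denoiser.num_evals


# Geometric grid from t = 10 down to t = 0.1 (kept positive on purpose)
t_start, t_end, num_steps = 10.0, 0.1, 20
time_steps = [
    t_start * (t_end / t_start) ** (i / num_steps) for i in range(num_steps + 1)
]

x_start = 10.0
# Closed-form solution of the toy ODE at t_end
exact = x_start * math.sqrt(
    (SIGMA_DATA**2 + t_end**2) / (SIGMA_DATA**2 + t_start**2)
)
x_euler, euler_evals = integrate(euler_step, x_start, time_steps)
x_heun, heun_evals = integrate(heun_step, x_start, time_steps)
euler_err, heun_err = abs(x_euler - exact), abs(x_heun - exact)
```

On this grid Heun uses twice as many denoiser evaluations as Euler and ends closer to the exact solution, which is the trade-off summarized in the list above.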

Guidance#

Guidance techniques modify the predictor \(P\) in the sampling equation to steer the generated samples toward desired properties, such as consistency with observed data, satisfaction of physical constraints, or other task-specific objectives.

In the framework, guidance operates at the Predictor level: you compose or modify predictors before converting them into a Denoiser \(D\). Concretely, guidance objects implement the DPSGuidance protocol and are combined with an x0-predictor into a guided score-predictor via DPSScorePredictor. The resulting score-predictor is then passed to the noise scheduler’s get_denoiser() factory to obtain a denoiser \(D\) for sampling. The sampler and solver (\(\text{Step}\)) are unchanged—they only see the final denoiser.

The framework provides two ready-to-use guidance implementations, including DataConsistencyDPSGuidance (used in the DPS guidance example above).

Custom guidances can be defined by implementing the DPSGuidance protocol—any callable with the signature (x, t, x_0) -> guidance_term. Multiple guidances can be combined by passing a list to DPSScorePredictor.
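As background, the idea behind DPS can be written down for a Gaussian observation model \(\mathbf{y} = A(\mathbf{x}_0) + \boldsymbol{\epsilon}\) with \(\boldsymbol{\epsilon} \sim \mathcal{N}(0, \sigma_y^2 \mathbf{I})\). The symbols \(A\), \(\sigma_y\), and \(\hat{\mathbf{x}}_0\) are notation introduced here for illustration, and the exact scaling used by the framework's guidance classes is not specified on this page. By Bayes' rule, the posterior score splits into the unconditional score and a likelihood term:

\[\nabla_{\mathbf{x}_n} \log p(\mathbf{x}_n \mid \mathbf{y}) = \nabla_{\mathbf{x}_n} \log p(\mathbf{x}_n) + \nabla_{\mathbf{x}_n} \log p(\mathbf{y} \mid \mathbf{x}_n)\]

DPS approximates the intractable likelihood term by evaluating the observation model at the x0-prediction \(\hat{\mathbf{x}}_0(\mathbf{x}_n, t_n)\):

\[\nabla_{\mathbf{x}_n} \log p(\mathbf{y} \mid \mathbf{x}_n) \approx -\frac{1}{2\sigma_y^2}\, \nabla_{\mathbf{x}_n} \bigl\| \mathbf{y} - A\bigl(\hat{\mathbf{x}}_0(\mathbf{x}_n, t_n)\bigr) \bigr\|^2\]

In this reading, the first term comes from the x0-predictor (converted to a score), the second is the guidance term, and their sum plays the role of the guided score-predictor. For masked observations, \(A\) would simply select the observed entries.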

API Reference#

Sample Entry Point#

sample#

physicsnemo.diffusion.samplers.sample(
denoiser: Denoiser,
xN: Float[Tensor, 'B *dims'],
noise_scheduler: NoiseScheduler,
num_steps: int,
solver: Literal['euler', 'heun', 'edm_stochastic_euler', 'edm_stochastic_heun'] | Solver = 'heun',
time_steps: Float[Tensor, 'N_plus_1'] | None = None,
solver_options: Dict[str, Any] | None = None,
time_eval: list[int] | None = None,
) -> Float[Tensor, 'B *dims'] | List[Float[Tensor, 'B *dims']][source]#

Generate batched samples from a diffusion model.

This interface is quite generic and can be used to generate samples from any reverse diffusion process of the form:

\[\mathbf{x}_{n-1} = G (\mathbf{x}_{i \geq n}, t_{i \geq n-1})\]

This covers both ODE/SDE-based sampling (e.g. VP, VE, EDM) and discrete Markov chain-based sampling (e.g. DDPM). The exact expression of the operator \(G\) depends on the combination of:

  • The solver, which determines the numerical method to update the latent state \(\mathbf{x}_n\) at each time-step.

  • The denoiser, which can be the right-hand side for ODE/SDE-based sampling, the denoised latent state for discrete Markov chain-based sampling, etc.

Typically, the update applied is roughly:

\[\mathbf{x}_{n-1} = \text{Step}(D(\mathbf{x}_n, t_n); \mathbf{x}_n, t_n, t_{n-1})\]

where \(D\) is the denoiser and \(\text{Step}\) is the update rule of the solver, implemented by the step() method. Variants are possible by passing more complex solvers and denoisers.

The solver can be specified as a string key (with optional solver_options), or as a pre-configured object implementing the Solver interface (in which case solver_options must be None). The solver must implement a step method with the following signature:

def step(
    self,
    x: Tensor,      # shape: (B, *dims)
    t_cur: Tensor,   # shape: (B,)
    t_next: Tensor,  # shape: (B,)
) -> Tensor: ...  # updated x, shape: (B, *dims)

Any object that implements the Solver interface can be used as a solver.

The denoiser must implement the Denoiser interface, with the following signature:

def denoiser(
    x: Tensor,  # Noisy latent state, shape (B, *dims)
    t: Tensor,  # Diffusion time, shape (B,)
) -> Tensor: ...  # ODE/SDE RHS, same shape (B, *dims) as x

Any object that implements the Denoiser interface can be used as a denoiser. A denoiser is typically obtained from a Predictor using the noise scheduler’s get_denoiser() factory.

Time-steps are generated by the noise_scheduler using its timesteps() method with the provided num_steps. To use custom time-steps, pass a 1D tensor as time_steps; it overrides the scheduler’s time-steps.

Parameters:
  • denoiser (Denoiser) – A callable that takes (x, t) and returns the denoising update term with the same shape as the latent state xN. See Denoiser for the expected interface. Typically obtained via the get_denoiser() factory, which converts a Predictor (e.g., score-predictor, x0-predictor) into a denoiser.

  • xN (Tensor) – Initial noisy latent state \(\mathbf{x}_N\) of shape \((B, *)\) where \(B\) is the batch size. All batch elements share the same diffusion time values. The dtype and device of xN determine the dtype and device of the generated samples and any internally created tensors. Can usually be obtained by using init_latents() from a noise scheduler (typically from the same noise scheduler instance as the noise_scheduler argument, but can be different if desired).

  • noise_scheduler (NoiseScheduler) – The noise scheduler instance used for generating time-steps. The schedule’s timesteps() method is called with num_steps to produce the diffusion time values, unless time_steps is provided to override them.

  • num_steps (int) – Number of sampling steps. Passed to the noise scheduler’s timesteps() method. Ignored when time_steps is provided.

  • solver (str | Solver, default="heun") –

    The numerical solver to use. Supports three levels of customizability:

    Basic: Pass a string key to use a built-in solver with default settings.

    Moderately advanced: Pass a string key plus solver_options to override default solver parameters.

    Advanced: Pass a custom Solver instance implementing the Solver interface. In this case, solver_options must be None.

    Available string keys:

    • "euler": First-order Euler method. Fast but lower quality. See EulerSolver.

    • "heun": Second-order Heun method. Higher quality but requires two denoiser evaluations per step. See HeunSolver.

    • "edm_stochastic_euler": First-order stochastic sampler from the EDM paper with configurable noise injection. See EDMStochasticEulerSolver.

    • "edm_stochastic_heun": Second-order stochastic sampler from the EDM paper with configurable noise injection. See EDMStochasticHeunSolver.

  • time_steps (Tensor | None, default=None) – Optional 1D tensor of shape \((N + 1,)\) containing explicit diffusion time values \(t_N, t_{N-1}, ..., t_0\) in decreasing order. If provided, overrides the time-steps from noise_scheduler and num_steps is ignored. To produce a fully denoised latent state \(\mathbf{x}_0\), the last element must be \(t_0 = 0\).

  • solver_options (Dict[str, Any] | None, default=None) – Additional options passed to the solver constructor. Only used when solver is a string; must be None when solver is a Solver instance. See individual solver classes for available options.

  • time_eval (List[int] | None, default=None) – Indices of time-steps at which to return intermediate samples. If provided, returns a list of tensors. If None, returns only the final denoised latent state \(\mathbf{x}_0\).

Returns:

If time_eval is None, returns the final denoised latent state \(\mathbf{x}_0\) of shape \((B, *)\). Otherwise, returns a list of tensors \(\mathbf{x}_t\) of shape \((B, *)\) containing latent states at time-step indices specified in time_eval.

Return type:

Tensor | List[Tensor]

See also

solvers

Available ODE/SDE solvers.

noise_schedulers

Available noise schedules.

Examples

Example 1: Minimal usage. Just provide a denoiser, initial noise, a scheduler, and the number of steps.

>>> import torch
>>> from physicsnemo.diffusion.samplers import sample
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>>
>>> # Toy denoiser (in practice, this would be a trained neural network)
>>> denoiser = lambda x, t: x / (1 + t.view(-1, *([1] * (x.ndim - 1)))**2)  # Toy denoiser
>>> scheduler = EDMNoiseScheduler()
>>> xN = torch.randn(2, 3, 8, 8) * 80  # Initial noise scaled by sigma_max
>>> x0 = sample(denoiser, xN, scheduler, num_steps=10)
>>> x0.shape
torch.Size([2, 3, 8, 8])

Example 2: Standard pattern using scheduler methods. Use init_latents to generate initial noise and get_denoiser to convert a predictor to a denoiser for sampling.

>>> import torch
>>> from physicsnemo.diffusion.samplers import sample
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>>
>>> scheduler = EDMNoiseScheduler()
>>> t_steps = scheduler.timesteps(10)
>>> tN = t_steps[0].expand(2)  # Initial time for batch of 2
>>>
>>> # Use scheduler to generate initial latents at time tN
>>> xN = scheduler.init_latents((3, 8, 8), tN)
>>>
>>> # Convert x0-predictor to denoiser (score conversion is automatic)
>>> x0_predictor = lambda x, t: x / (1 + t.view(-1, *([1] * (x.ndim - 1)))**2)  # Toy x0-predictor
>>> denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
>>>
>>> x0 = sample(denoiser, xN, scheduler, num_steps=10)
>>> x0.shape
torch.Size([2, 3, 8, 8])

Example 3: Custom time-steps and solver. Same as Example 2, but using explicit time-steps and the faster (but lower quality) Euler solver.

>>> import torch
>>> from physicsnemo.diffusion.samplers import sample
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>>
>>> scheduler = EDMNoiseScheduler()
>>>
>>> # Custom time-steps (fewer steps for faster sampling)
>>> custom_t = torch.tensor([80.0, 40.0, 20.0, 10.0, 5.0, 0.0])
>>> tN = custom_t[0].expand(2)
>>> xN = scheduler.init_latents((3, 8, 8), tN)
>>>
>>> # Same denoiser setup as Example 2
>>> x0_predictor = lambda x, t: x / (1 + t.view(-1, *([1] * (x.ndim - 1)))**2)  # Toy x0-predictor
>>> denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
>>>
>>> # Use custom time-steps and Euler solver (num_steps ignored)
>>> x0 = sample(denoiser, xN, scheduler, num_steps=0, time_steps=custom_t,
...             solver="euler")
>>> x0.shape
torch.Size([2, 3, 8, 8])

Example 4: Bare-bones custom scheduler. Define a scheduler from scratch implementing the NoiseScheduler protocol, without importing any built-in scheduler class.

>>> import torch
>>> from physicsnemo.diffusion.samplers import sample
>>>
>>> # Define a minimal EDM-like scheduler from scratch
>>> class MinimalScheduler:
...     def timesteps(self, num_steps, *, device=None, dtype=None):
...         return torch.linspace(1.0, 0.0, num_steps + 1,
...                               device=device, dtype=dtype)
...     def sample_time(self, N, *, device=None, dtype=None):
...         return torch.rand(N, device=device, dtype=dtype)
...     def add_noise(self, x0, time):
...         return x0 + time.view(-1, 1, 1, 1) * torch.randn_like(x0)
...     def init_latents(self, spatial_shape, tN, *, device=None,
...                      dtype=None):
...         return tN.view(-1, 1, 1, 1) * torch.randn(
...             tN.shape[0], *spatial_shape, device=device, dtype=dtype)
...     def get_denoiser(self, *, x0_predictor=None, **kwargs):
...         # EDM-like: sigma=t, alpha=1, g^2=2t
...         # score = (x0 - x) / t^2, ODE RHS = -0.5 * g^2 * score = (x - x0) / t
...         def _denoiser(x, t):
...             x0 = x0_predictor(x, t)
...             t_bc = t.view(-1, *([1] * (x.ndim - 1)))
...             return (x - x0) / t_bc
...         return _denoiser
>>>
>>> scheduler = MinimalScheduler()
>>> tN = torch.tensor([1.0, 1.0])
>>> xN = scheduler.init_latents((3, 8, 8), tN)
>>>
>>> # x0-predictor -> denoiser via the scheduler factory
>>> x0_predictor = lambda x, t: x / (1 + t.view(-1, *([1] * (x.ndim - 1)))**2)  # Toy x0-predictor
>>> denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
>>> x0 = sample(denoiser, xN, scheduler, num_steps=10, solver="euler")
>>> x0.shape
torch.Size([2, 3, 8, 8])

Solvers#

Solver#

class physicsnemo.diffusion.samplers.solvers.Solver(*args, **kwargs)[source]#

Protocol defining the interface for diffusion solvers.

A solver implements a numerical method to integrate the diffusion process from a noisy state to a less noisy (or clean) state. Each call to step() advances the state from time t_cur (\(t_n\)) to t_next (\(t_{n-1}\)).

This is the minimal interface required for sampling from a diffusion model, and any object that implements this interface can be used as a solver in sampling utilities.

The update rule applied by the sampler is roughly:

\[\mathbf{x}_{n-1} = \text{Step}(D(\mathbf{x}_n, t_n); \mathbf{x}_n, t_n, t_{n-1})\]

where \(D\) is the denoiser (e.g. the right-hand side in the case of ODE/SDE-based sampling, the denoised latent state in the case of discrete Markov chain-based sampling, etc.) and \(\text{Step}\) is the update rule of the solver, implemented by the step() method.
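The structural nature of such an interface can be sketched with the standard library. The example assumes that Solver is declared as a runtime-checkable typing.Protocol, which is what makes isinstance checks work without inheritance; StepProtocol, ScalarEuler, and NotASolver are hypothetical names, and floats stand in for tensors.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class StepProtocol(Protocol):
    """Hypothetical stand-in for Solver: anything with a step() method."""

    def step(self, x, t_cur, t_next): ...


class ScalarEuler:
    """Satisfies the protocol structurally, without inheriting from it."""

    def __init__(self, denoiser):
        self.denoiser = denoiser

    def step(self, x, t_cur, t_next):
        return x + (t_next - t_cur) * self.denoiser(x, t_cur)


class NotASolver:
    """Has no step() method, so the runtime check rejects it."""

    def advance(self, x):
        return x


solver = ScalarEuler(lambda x, t: x / t)
is_solver = isinstance(solver, StepProtocol)
is_not_solver = isinstance(NotASolver(), StepProtocol)
x_next = solver.step(2.0, 1.0, 0.5)
```

A runtime protocol check only verifies that a step attribute exists; it does not validate the argument list or the shapes of the returned tensors, so a custom solver still needs its own tests.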

See also

sample()

The sampling function that uses solvers to generate samples.

Examples

>>> import torch
>>> from physicsnemo.diffusion.samplers.solvers import Solver
>>>
>>> class SimpleEuler:
...     def __init__(self, denoiser):
...         self.denoiser = denoiser
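...     def step(self, x, t_cur, t_next):
...         # The denoiser returns the ODE right-hand side; broadcast the
...         # (B,) step size against x of shape (B, *dims)
...         h = (t_next - t_cur).view(-1, *([1] * (x.ndim - 1)))
...         return x + h * self.denoiser(x, t_cur)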
...
>>> denoiser = lambda x, t: x / (1 + t.view(-1, 1)**2)  # Toy denoiser
>>> solver = SimpleEuler(denoiser)
>>> isinstance(solver, Solver)
True

step(
x: Float[Tensor, 'B *dims'],
t_cur: Float[Tensor, 'B'],
t_next: Float[Tensor, 'B'],
) -> Float[Tensor, 'B *dims'][source]#

Perform one integration step from t_cur to t_next.

Parameters:
  • x (Tensor) – Current noisy latent state \(\mathbf{x}_{n}\) of shape \((B, *)\) where \(B\) is the batch size.

  • t_cur (Tensor) – Current diffusion time \(t_n\) of shape \((B,)\).

  • t_next (Tensor) – Target diffusion time \(t_{n-1}\) of shape \((B,)\).

Returns:

Updated latent state \(\mathbf{x}_{n-1}\) at time t_next, same shape as x.

Return type:

Tensor

EulerSolver#

class physicsnemo.diffusion.samplers.solvers.EulerSolver(denoiser: Denoiser)[source]#

Bases: Solver

First-order Euler solver for diffusion ODEs.

This is a fast solver with one denoiser evaluation per step, but typically produces lower quality samples compared to higher-order methods.

Parameters:

denoiser (Denoiser) – A callable implementing the Denoiser interface. Here it is expected to return the right hand side of the ODE. Typically obtained via get_denoiser(), but any callable with the correct signature can be used.

Examples

>>> import torch
>>> from physicsnemo.diffusion.samplers.solvers import EulerSolver, Solver
>>>
>>> denoiser = lambda x, t: x / (1 + t.view(-1, 1, 1, 1)**2)  # Toy denoiser
>>> solver = EulerSolver(denoiser)
>>> x_t = torch.randn(1, 3, 8, 8)
>>> t_cur = torch.tensor([1.0])
>>> t_next = torch.tensor([0.5])
>>> x_tm1 = solver.step(x_t, t_cur, t_next)
>>> x_tm1.shape
torch.Size([1, 3, 8, 8])
>>> isinstance(solver, Solver)
True

step(
x: Float[Tensor, 'B *dims'],
t_cur: Float[Tensor, 'B'],
t_next: Float[Tensor, 'B'],
) -> Float[Tensor, 'B *dims'][source]#

Perform one Euler integration step.

Parameters:
  • x (Tensor) – Current noisy latent state \(\mathbf{x}_{n}\) of shape \((B, *)\) where \(B\) is the batch size.

  • t_cur (Tensor) – Current diffusion time \(t_n\) of shape \((B,)\).

  • t_next (Tensor) – Target diffusion time \(t_{n-1}\) of shape \((B,)\).

Returns:

Updated latent state \(\mathbf{x}_{n-1}\) at time t_next, same shape as x.

Return type:

Tensor

HeunSolver#

class physicsnemo.diffusion.samplers.solvers.HeunSolver(
denoiser: Denoiser,
alpha: float = 1.0,
)[source]#

Bases: Solver

Second-order Heun solver for diffusion ODEs.

This method requires two denoiser evaluations per step but usually produces higher quality samples than EulerSolver.

Parameters:
  • denoiser (Denoiser) – A callable implementing the Denoiser interface. Here it is expected to return the right hand side of the ODE. Typically obtained via get_denoiser(), but any callable with the correct signature can be used.

  • alpha (float, optional) – Interpolation parameter for the corrector step, must be in (0, 1]. alpha=1 gives the standard Heun method (trapezoidal rule), alpha=0.5 gives the midpoint method. By default 1.
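The role of alpha can be sketched with the generalized second-order step from the EDM paper, where the second slope is evaluated at \(t_n + \alpha h\) and the two slopes are blended with weights \(1 - \tfrac{1}{2\alpha}\) and \(\tfrac{1}{2\alpha}\). That HeunSolver follows exactly this formula is an assumption based on the parameter description above; second_order_step is a hypothetical helper operating on floats.

```python
def second_order_step(rhs, x, t_cur, t_next, alpha=1.0):
    """Generalized second-order step; rhs(x, t) is the ODE right-hand side."""
    h = t_next - t_cur
    d_cur = rhs(x, t_cur)
    # Intermediate evaluation at t_cur + alpha * h
    x_mid = x + alpha * h * d_cur
    d_mid = rhs(x_mid, t_cur + alpha * h)
    w = 1.0 / (2.0 * alpha)
    return x + h * ((1.0 - w) * d_cur + w * d_mid)


# Toy ODE dx/dt = t, integrated from t = 1 to t = 0.5 starting at x = 0
heun = second_order_step(lambda x, t: t, 0.0, 1.0, 0.5, alpha=1.0)
midpoint = second_order_step(lambda x, t: t, 0.0, 1.0, 0.5, alpha=0.5)
exact = (0.5**2 - 1.0**2) / 2.0
```

With alpha=1 the weights are (1/2, 1/2) and the second slope is taken at t_next, which is Heun's trapezoidal rule; with alpha=0.5 the weights are (0, 1) and the slope is taken at the midpoint. Both are exact on this toy problem because the right-hand side is linear in t.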

Examples

>>> import torch
>>> from physicsnemo.diffusion.samplers.solvers import HeunSolver
>>>
>>> denoiser = lambda x, t: x / (1 + t.view(-1, 1, 1, 1)**2)  # Toy denoiser
>>> solver = HeunSolver(denoiser)
>>> x_t = torch.randn(1, 3, 8, 8)
>>> t_cur = torch.tensor([1.0])
>>> t_next = torch.tensor([0.5])
>>> x_tm1 = solver.step(x_t, t_cur, t_next)
>>> x_tm1.shape
torch.Size([1, 3, 8, 8])

step(
x: Float[Tensor, 'B *dims'],
t_cur: Float[Tensor, 'B'],
t_next: Float[Tensor, 'B'],
) -> Float[Tensor, 'B *dims'][source]#

Perform one Heun integration step.

Parameters:
  • x (Tensor) – Current noisy latent state \(\mathbf{x}_n\) of shape \((B, *)\) where \(B\) is the batch size.

  • t_cur (Tensor) – Current diffusion time \(t_n\) of shape \((B,)\).

  • t_next (Tensor) – Target diffusion time \(t_{n-1}\) of shape \((B,)\).

Returns:

Updated latent state \(\mathbf{x}_{n-1}\) at time t_next, same shape as x.

Return type:

Tensor

EDMStochasticEulerSolver#

class physicsnemo.diffusion.samplers.solvers.EDMStochasticEulerSolver(
denoiser: Denoiser,
S_churn: float = 0,
S_min: float = 0,
S_max: float = inf,
S_noise: float = 1,
num_steps: int = 18,
sigma_fn: Callable[[Float[Tensor, '*shape']], Float[Tensor, '*shape']] | None = None,
sigma_inv_fn: Callable[[Float[Tensor, '*shape']], Float[Tensor, '*shape']] | None = None,
diffusion_fn: Callable[[Float[Tensor, 'B *dims'], Float[Tensor, 'B']], Float[Tensor, 'B *_']] | None = None,
)[source]#

Bases: Solver

First-order stochastic Euler sampler from the EDM paper.

Implements stochastic sampling with configurable noise injection controlled by the “churn” parameters.

Important

This is not a true SDE solver. It performs ad-hoc noise injection (“churn”) at each step to improve sample diversity, but the underlying integration is still an ODE step. Therefore, the denoiser should return the right-hand side of the ODE, not the SDE.

By default, noise injection is performed directly in time-step space. For linear-Gaussian noise schedules where diffusion time and noise level are not equal (e.g., VP schedule), provide sigma_fn and sigma_inv_fn to apply churn in noise-level space rather than time-step space. Optionally provide diffusion_fn to control the time-dependent magnitude of the injected noise.

def sigma_fn(
    t: Tensor,  # shape: (B,) or broadcastable
) -> Tensor: ...  # noise level, same shape as t

def sigma_inv_fn(
    sigma: Tensor,  # shape: (B,) or broadcastable
) -> Tensor: ...  # diffusion time, same shape as sigma

def diffusion_fn(
    x: Tensor,  # shape: (B, *dims)
    t: Tensor,  # shape: (B,)
) -> Tensor: ...  # g^2(x, t), broadcastable to shape of x

Parameters:
  • denoiser (Denoiser) – A callable implementing the Denoiser interface. Should return the right-hand side of the ODE (not the SDE, since the stochastic noise injection is handled internally by this solver). Typically obtained via get_denoiser() with denoising_type="ode".

  • S_churn (float, optional) – Controls the amount of noise added at each step. Higher values add more stochasticity. By default 0 (deterministic), in which case this solver is equivalent to the deterministic EulerSolver.

  • S_min (float, optional) – Minimum diffusion time (or noise level if sigma_fn and sigma_inv_fn are provided) for applying churn. By default 0.

  • S_max (float, optional) – Maximum diffusion time (or noise level if sigma_fn and sigma_inv_fn are provided) for applying churn. By default float("inf").

  • S_noise (float, optional) – Noise scaling factor. Larger values add more noise to the latent state. By default 1.

  • num_steps (int, optional) – Total number of sampling steps, used to scale churn. By default 18.

  • sigma_fn (Callable[[Tensor], Tensor] | None, optional) – Maps time to noise level \(\sigma(t)\). Useful for linear-Gaussian schedules where \(\sigma(t) \neq t\). Typically sigma(). If provided, sigma_inv_fn must also be provided. By default None (identity mapping).

  • sigma_inv_fn (Callable[[Tensor], Tensor] | None, optional) – Maps noise level back to time. Typically sigma_inv(). If provided, sigma_fn must also be provided. By default None (identity mapping).

  • diffusion_fn (Callable[[Tensor, Tensor], Tensor] | None, optional) – Controls the time-dependent magnitude of the injected noise, in addition to the S_noise scaling factor. Typically the squared diffusion coefficient \(g^2(\mathbf{x}, t)\) from the reverse SDE, obtained from diffusion(). By default None (\(g^2 = 2t\)), which corresponds to an EDM-like noise schedule.
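
As a rough illustration of how the churn parameters above interact, the following scalar sketch follows the stochastic sampler described in the EDM paper. It is not the library's implementation: churned_time and stochastic_euler_step are hypothetical helper names, the full step assumes the identity mapping \(\sigma(t) = t\), diffusion_fn is left out, and the noise sample eps is passed in explicitly so the result is reproducible.

```python
import math


def churned_time(t_cur, S_churn=0.0, S_min=0.0, S_max=float("inf"),
                 num_steps=18, sigma_fn=None, sigma_inv_fn=None):
    """Return (t_hat, noise_std): the raised time and the std of the added noise."""
    # Identity mappings reproduce the default "time-step space" behaviour.
    sigma_fn = sigma_fn or (lambda t: t)
    sigma_inv_fn = sigma_inv_fn or (lambda s: s)
    sigma = sigma_fn(t_cur)
    # Churn is only active inside [S_min, S_max] and is capped at sqrt(2) - 1.
    in_range = S_min <= sigma <= S_max
    gamma = min(S_churn / num_steps, math.sqrt(2.0) - 1.0) if in_range else 0.0
    sigma_hat = sigma * (1.0 + gamma)
    noise_std = math.sqrt(max(sigma_hat**2 - sigma**2, 0.0))
    return sigma_inv_fn(sigma_hat), noise_std


def stochastic_euler_step(x, t_cur, t_next, ode_rhs, eps, S_noise=1.0, **churn):
    """One churn + Euler step for a scalar state, assuming sigma(t) = t."""
    t_hat, noise_std = churned_time(t_cur, **churn)
    x_hat = x + S_noise * noise_std * eps  # ad-hoc noise injection
    return x_hat + (t_next - t_hat) * ode_rhs(x_hat, t_hat)  # plain ODE Euler step


ode_rhs = lambda x, t: x / t  # toy ODE right-hand side (hypothetical)
t_hat, noise_std = churned_time(1.0, S_churn=40, num_steps=18)
x_det = stochastic_euler_step(2.0, 1.0, 0.5, ode_rhs, eps=0.3)
x_sto = stochastic_euler_step(2.0, 1.0, 0.5, ode_rhs, eps=0.3, S_churn=40)
print(t_hat, noise_std, x_det, x_sto)
```

With S_churn=0 the step reduces to a deterministic Euler update. With S_churn=40 and 18 steps the cap \(\sqrt{2} - 1\) applies, so the time is first raised from 1 to about 1.414 and the state is perturbed before the ODE step brings it down to t_next. Passing a pair such as sigma_fn=lambda t: t * t and sigma_inv_fn=math.sqrt (a toy mapping, not a real schedule) applies the same scaling in noise-level space and maps the result back to time.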

Examples

Basic usage with default parameters (noise injection in time-step space):

>>> import torch
>>> from physicsnemo.diffusion.samplers.solvers import (
...     EDMStochasticEulerSolver,
... )
>>> denoiser = lambda x, t: x / (1 + t.view(-1, 1, 1, 1)**2)  # Toy denoiser
>>> solver = EDMStochasticEulerSolver(denoiser, S_churn=40, num_steps=18)
>>> x_t = torch.randn(1, 3, 8, 8)
>>> t_cur = torch.tensor([1.0])
>>> t_next = torch.tensor([0.5])
>>> x_tm1 = solver.step(x_t, t_cur, t_next)
>>> x_tm1.shape
torch.Size([1, 3, 8, 8])

Using noise scheduler methods for linear-Gaussian schedules where \(\sigma(t) \neq t\) (e.g., VP schedule). The callbacks map between time and noise level, allowing the churn to be applied in noise-level space before converting back to time-step space:

>>> from physicsnemo.diffusion.noise_schedulers import VPNoiseScheduler
>>> scheduler = VPNoiseScheduler()
>>> num_steps = 10
>>> solver = EDMStochasticEulerSolver(
...     denoiser,
...     S_churn=40,
...     num_steps=num_steps,
...     sigma_fn=scheduler.sigma,
...     sigma_inv_fn=scheduler.sigma_inv,
...     diffusion_fn=scheduler.diffusion,
... )
>>> x_tm1 = solver.step(x_t, t_cur, t_next)
>>> x_tm1.shape
torch.Size([1, 3, 8, 8])
step(
x: Float[Tensor, 'B *dims'],
t_cur: Float[Tensor, 'B'],
t_next: Float[Tensor, 'B'],
) -> Float[Tensor, 'B *dims'][source]#

Perform one stochastic Euler sampling step.

Parameters:
  • x (Tensor) – Current noisy latent state \(\mathbf{x}_n\) of shape \((B, *)\) where \(B\) is the batch size.

  • t_cur (Tensor) – Current diffusion time \(t_n\) of shape \((B,)\).

  • t_next (Tensor) – Target diffusion time \(t_{n-1}\) of shape \((B,)\).

Returns:

Updated latent state \(\mathbf{x}_{n-1}\) at time t_next, same shape as x.

Return type:

Tensor

EDMStochasticHeunSolver#

class physicsnemo.diffusion.samplers.solvers.EDMStochasticHeunSolver(
denoiser: Denoiser,
alpha: float = 1.0,
S_churn: float = 0,
S_min: float = 0,
S_max: float = inf,
S_noise: float = 1,
num_steps: int = 18,
sigma_fn: Callable[[Float[Tensor, '*shape']], Float[Tensor, '*shape']] | None = None,
sigma_inv_fn: Callable[[Float[Tensor, '*shape']], Float[Tensor, '*shape']] | None = None,
diffusion_fn: Callable[[Float[Tensor, 'B *dims'], Float[Tensor, 'B']], Float[Tensor, 'B *_']] | None = None,
)[source]#

Bases: Solver

Second-order stochastic Heun sampler from the EDM paper.

Implements stochastic sampling with configurable noise injection controlled by the “churn” parameters, using a second-order Heun correction step.

Important

This is not a true SDE solver. It performs ad-hoc noise injection (“churn”) at each step to improve sample diversity, but the underlying integration is still an ODE step. Therefore, the denoiser should return the right-hand side of the ODE, not the SDE.

By default, noise injection is performed directly in time-step space. For linear-Gaussian noise schedules where diffusion time and noise level are not equal (e.g., VP schedule), provide sigma_fn and sigma_inv_fn to apply churn in noise-level space rather than time-step space. Optionally provide diffusion_fn to control the time-dependent magnitude of the injected noise.

def sigma_fn(
    t: Tensor,  # shape: (B,) or broadcastable
) -> Tensor: ...  # noise level, same shape as t

def sigma_inv_fn(
    sigma: Tensor,  # shape: (B,) or broadcastable
) -> Tensor: ...  # diffusion time, same shape as sigma

def diffusion_fn(
    x: Tensor,  # shape: (B, *dims)
    t: Tensor,  # shape: (B,)
) -> Tensor: ...  # g^2(x, t), broadcastable to shape of x
Parameters:
  • denoiser (Denoiser) – A callable implementing the Denoiser interface. Should return the right-hand side of the ODE (not the SDE, since the stochastic noise injection is handled internally by this solver). Typically obtained via get_denoiser() with denoising_type="ode".

  • alpha (float, optional) – Interpolation parameter for the corrector step, must be in (0, 1]. alpha=1 gives the standard Heun method (trapezoidal rule), alpha=0.5 gives the midpoint method. By default 1.

  • S_churn (float, optional) – Controls the amount of noise added at each step. Higher values add more stochasticity. By default 0 (deterministic), in which case this solver is equivalent to the deterministic HeunSolver.

  • S_min (float, optional) – Minimum diffusion time (or noise level if sigma_fn and sigma_inv_fn are provided) for applying churn. By default 0.

  • S_max (float, optional) – Maximum diffusion time (or noise level if sigma_fn and sigma_inv_fn are provided) for applying churn. By default float("inf").

  • S_noise (float, optional) – Noise scaling factor. Larger values add more noise to the latent state. By default 1.

  • num_steps (int, optional) – Total number of sampling steps, used to scale churn. By default 18.

  • sigma_fn (Callable[[Tensor], Tensor] | None, optional) – Maps time to noise level \(\sigma(t)\). Useful for linear-Gaussian schedules where \(\sigma(t) \neq t\). Typically sigma(). If provided, sigma_inv_fn must also be provided. By default None (identity mapping).

  • sigma_inv_fn (Callable[[Tensor], Tensor] | None, optional) – Maps noise level back to time. Typically sigma_inv(). If provided, sigma_fn must also be provided. By default None (identity mapping).

  • diffusion_fn (Callable[[Tensor, Tensor], Tensor] | None, optional) – Controls the time-dependent magnitude of the injected noise, in addition to the S_noise scaling factor. Typically the squared diffusion coefficient \(g^2(\mathbf{x}, t)\) from the reverse SDE, obtained from diffusion(). By default None (\(g^2 = 2t\)), which corresponds to an EDM-like noise schedule.
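
To make the role of alpha concrete, the sketch below writes the generalized second-order update from the EDM paper for a scalar state, with churn switched off. It is an illustration under stated assumptions rather than the library's code: heun_step is a hypothetical name, and ode_rhs stands in for a denoiser that returns the right-hand side of the ODE.

```python
import math


def heun_step(x, t_cur, t_next, ode_rhs, alpha=1.0):
    """Generalized second-order step for a scalar state (no churn)."""
    h = t_next - t_cur
    d_cur = ode_rhs(x, t_cur)
    # Predictor: Euler move over a fraction alpha of the full step.
    x_mid, t_mid = x + alpha * h * d_cur, t_cur + alpha * h
    d_mid = ode_rhs(x_mid, t_mid)
    # Corrector: blend the two slopes; the weights (1 - w) and w sum to 1.
    w = 1.0 / (2.0 * alpha)
    return x + h * ((1.0 - w) * d_cur + w * d_mid)


ode_rhs = lambda x, t: t * x  # toy ODE with solution x(t) = x(1) * exp((t^2 - 1) / 2)
x_heun = heun_step(1.0, 1.0, 0.5, ode_rhs, alpha=1.0)      # trapezoidal rule
x_midpoint = heun_step(1.0, 1.0, 0.5, ode_rhs, alpha=0.5)  # midpoint rule
x_exact = math.exp((0.5**2 - 1.0**2) / 2.0)
print(x_heun, x_midpoint, x_exact)
```

For alpha=1 the two slopes are averaged with equal weight; for alpha=0.5 only the slope at the midpoint is used. On this toy problem both are closer to the exact value than a single Euler step, which would give 0.5.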

Examples

Basic usage with default parameters (noise injection in time-step space):

>>> import torch
>>> from physicsnemo.diffusion.samplers.solvers import (
...     EDMStochasticHeunSolver,
... )
>>> denoiser = lambda x, t: x / (1 + t.view(-1, 1, 1, 1)**2)  # Toy denoiser
>>> solver = EDMStochasticHeunSolver(denoiser, S_churn=40, num_steps=18)
>>> x_t = torch.randn(1, 3, 8, 8)
>>> t_cur = torch.tensor([1.0])
>>> t_next = torch.tensor([0.5])
>>> x_tm1 = solver.step(x_t, t_cur, t_next)
>>> x_tm1.shape
torch.Size([1, 3, 8, 8])

Using noise scheduler methods for linear-Gaussian schedules where \(\sigma(t) \neq t\) (e.g., VP schedule). The callbacks map between time and noise level, allowing the churn to be applied in noise-level space before converting back to time-step space:

>>> from physicsnemo.diffusion.noise_schedulers import VPNoiseScheduler
>>> scheduler = VPNoiseScheduler()
>>> num_steps = 10
>>> solver = EDMStochasticHeunSolver(
...     denoiser,
...     S_churn=40,
...     num_steps=num_steps,
...     sigma_fn=scheduler.sigma,
...     sigma_inv_fn=scheduler.sigma_inv,
...     diffusion_fn=scheduler.diffusion,
... )
>>> x_tm1 = solver.step(x_t, t_cur, t_next)
>>> x_tm1.shape
torch.Size([1, 3, 8, 8])
step(
x: Float[Tensor, 'B *dims'],
t_cur: Float[Tensor, 'B'],
t_next: Float[Tensor, 'B'],
) -> Float[Tensor, 'B *dims'][source]#

Perform one stochastic Heun sampling step.

Parameters:
  • x (Tensor) – Current noisy latent state \(\mathbf{x}_n\) of shape \((B, *)\) where \(B\) is the batch size.

  • t_cur (Tensor) – Current diffusion time \(t_n\) of shape \((B,)\).

  • t_next (Tensor) – Target diffusion time \(t_{n-1}\) of shape \((B,)\).

Returns:

Updated latent state \(\mathbf{x}_{n-1}\) at time t_next, same shape as x.

Return type:

Tensor

Guidance#

DPSGuidance#

class physicsnemo.diffusion.guidance.DPSGuidance(*args, **kwargs)[source]#

Protocol defining the interface for Diffusion Posterior Sampling (DPS) guidance.

A DPS guidance is a callable that computes a guidance term to steer the diffusion sampling process toward satisfying some observation constraint. It returns a quantity analogous to a likelihood score, which is typically added to the unconditional score during sampling.

The typical form is:

\[\rho(t) \nabla_{\mathbf{x}} \ell(A(\hat{\mathbf{x}}_0) - \mathbf{y})\]

where \(\rho(t)\) is a time-dependent guidance strength, \(A\) is a (potentially nonlinear) observation operator, \(\mathbf{y}\) is the observed data, and \(\ell\) is a scalar loss function. However, variants are possible as long as the guidance produces a quantity similar to a score (e.g., a likelihood score).
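
As a small sanity check of the typical form above, the sketch below uses plain Python lists, an identity observation operator restricted by a mask, and the loss \(\ell(\mathbf{r}) = \tfrac{1}{2}\|\mathbf{r}\|^2\) with a constant strength gamma. All function names are hypothetical, and the gradient is taken with respect to \(\hat{\mathbf{x}}_0\) only, a simplification that ignores the dependence of \(\hat{\mathbf{x}}_0\) on \(\mathbf{x}\). It compares the closed-form guidance term against a finite-difference gradient.

```python
def masked_l2_loss(x_0, y_obs, mask, gamma):
    """0.5 * gamma * || mask * (x_0 - y_obs) ||^2 for flat lists of floats."""
    return 0.5 * gamma * sum(m * (a - b) ** 2 for a, b, m in zip(x_0, y_obs, mask))


def guidance_term(x_0, y_obs, mask, gamma):
    """Closed-form negative gradient of masked_l2_loss with respect to x_0."""
    return [-gamma * m * (a - b) for a, b, m in zip(x_0, y_obs, mask)]


def numerical_grad(f, x, eps=1e-6):
    """Central finite differences, one coordinate at a time."""
    grad = []
    for i in range(len(x)):
        hi = x[:i] + [x[i] + eps] + x[i + 1:]
        lo = x[:i] + [x[i] - eps] + x[i + 1:]
        grad.append((f(hi) - f(lo)) / (2.0 * eps))
    return grad


x_0 = [0.9, -0.4, 1.5]
y_obs = [1.0, 0.0, 0.0]
mask = [1.0, 1.0, 0.0]  # the last entry is unobserved
gamma = 0.5
g = guidance_term(x_0, y_obs, mask, gamma)
fd = numerical_grad(lambda v: masked_l2_loss(v, y_obs, mask, gamma), x_0)
print(g, fd)
```

The guidance term is the negative of the numerical gradient at observed locations and zero where the mask is zero, so adding it to the score moves the estimate toward the observations.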

This is the minimal interface for guidance, and any object that implements this interface can be used with DPSScorePredictor to build a guided score-predictor, which implements the Predictor interface.

See also

DPSScorePredictor

Combines an x0-predictor with one or more guidances.

Examples

Example 1: Minimal guidance for inpainting. Given a binary mask and observed pixels, guide the diffusion to match observations:

>>> import torch
>>> from physicsnemo.diffusion.guidance import DPSGuidance
>>>
>>> class InpaintingGuidance:
...     def __init__(self, mask, y_obs, gamma=1.0):
...         self.mask = mask  # Binary mask: 1 = observed, 0 = missing
...         self.y_obs = y_obs  # Observed pixel values
...         self.gamma = gamma
...
...     def __call__(self, x, t, x_0):
...         # Compute residual at observed locations
...         residual = self.mask * (x_0 - self.y_obs)
...         # Gradient of 0.5 * ||residual||^2 w.r.t. x_0 is the residual
...         # (simplified: identity observation operator, no backprop to x)
...         return -self.gamma * residual
...
>>> mask = torch.ones(1, 3, 8, 8)
>>> y_obs = torch.randn(1, 3, 8, 8)
>>> guidance = InpaintingGuidance(mask, y_obs)
>>> isinstance(guidance, DPSGuidance)
True

Example 2: Building a guided score predictor from scratch. A common pattern is to combine an x0-predictor with a guidance to create a score predictor that can be used for sampling. This shows the complete workflow:

>>> import torch
>>> from physicsnemo.diffusion.guidance import DPSGuidance
>>>
>>> # Define a guidance that pushes toward observed values
>>> class MyGuidance:
...     def __init__(self, y_obs, gamma=0.1):
...         self.y_obs = y_obs
...         self.gamma = gamma
...
...     def __call__(self, x, t, x_0):
...         return -self.gamma * (x_0 - self.y_obs)
...
>>> # Toy x0-predictor (in practice, a trained neural network)
>>> x0_predictor = lambda x, t: x * 0.9
>>> y_obs = torch.randn(1, 3, 8, 8)
>>> guidance = MyGuidance(y_obs, gamma=0.5)
>>>
>>> # Build a guided score predictor that combines x0-predictor + guidance
>>> def guided_score_predictor(x, t):
...     x_0 = x0_predictor(x, t)
...     guidance_term = guidance(x, t, x_0)
...     # Convert x0 to score (for EDM: score = (x_0 - x) / t^2)
...     t_bc = t.reshape(-1, *([1] * (x.ndim - 1)))
...     score = (x_0 - x) / (t_bc ** 2)
...     return score + guidance_term
...
>>> # guided_score_predictor is now a Predictor (score predictor); pass it
>>> # to scheduler.get_denoiser(score_predictor=...) to obtain a Denoiser
>>> x = torch.randn(1, 3, 8, 8)
>>> t = torch.tensor([1.0])
>>> output = guided_score_predictor(x, t)
>>> output.shape
torch.Size([1, 3, 8, 8])

Note: DPSScorePredictor provides a convenient way to apply one or more guidances to an x0-predictor without manually implementing the above pattern.

DPSScorePredictor#

class physicsnemo.diffusion.guidance.DPSScorePredictor(
x0_predictor: Predictor,
x0_to_score_fn: Callable[[Float[Tensor, 'B *dims'], Float[Tensor, 'B *dims'], Float[Tensor, 'B']], Float[Tensor, 'B *dims']],
guidances: DPSGuidance | Sequence[DPSGuidance],
)[source]#

Bases: object

Score predictor that combines an x0-predictor with DPS-style guidance.

This class transforms a Predictor (specifically, an x0-predictor) into a score Predictor by applying one or more DPS guidances. The resulting score predictor can be passed to get_denoiser() to obtain a Denoiser for sampling.

The output is the sum of the unconditional score (derived from the x0-prediction) and all guidance terms:

\[\nabla_{\mathbf{x}} \log p(\mathbf{x}) + \sum_i g_i(\mathbf{x}, t, \hat{\mathbf{x}}_0)\]

where \(g_i\) are the guidance terms implementing the DPSGuidance interface.
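
The sum above can be sketched for a scalar state as follows. This is a toy stand-in and not the class itself: guided_score and x0_to_score are hypothetical names, and the conversion assumes a linear-Gaussian perturbation \(\mathbf{x} = \alpha(t)\,\mathbf{x}_0 + \sigma(t)\,\boldsymbol{\epsilon}\), for which the score is \((\alpha(t)\,\hat{\mathbf{x}}_0 - \mathbf{x}) / \sigma(t)^2\). The defaults \(\alpha(t) = 1\) and \(\sigma(t) = t\) correspond to an EDM-like schedule.

```python
def x0_to_score(x_0, x, t, sigma_fn=lambda t: t, alpha_fn=lambda t: 1.0):
    """Score of a linear-Gaussian perturbation x = alpha(t) * x_0 + sigma(t) * eps."""
    return (alpha_fn(t) * x_0 - x) / sigma_fn(t) ** 2


def guided_score(x, t, x0_predictor, guidances):
    """Unconditional score plus the sum of all guidance terms."""
    x_0 = x0_predictor(x, t)
    score = x0_to_score(x_0, x, t)
    return score + sum(g(x, t, x_0) for g in guidances)


x0_predictor = lambda x, t: 0.9 * x              # toy x0-predictor
pull_to_obs = lambda x, t, x_0: -0.5 * (x_0 - 1.0)  # pulls x_0 toward y = 1
shrink = lambda x, t, x_0: -0.1 * x_0               # mild pull toward zero
total = guided_score(2.0, 1.0, x0_predictor, [pull_to_obs, shrink])
print(total)
```

Each guidance receives the same (x, t, x_0) triple, so the x0-predictor only needs to be evaluated once per call regardless of how many guidances are combined.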

Each guidance must implement the DPSGuidance protocol, which is a callable with the following signature:

def guidance(
    x: Tensor,    # shape: (B, *dims)
    t: Tensor,    # shape: (B,)
    x_0: Tensor,  # shape: (B, *dims)
) -> Tensor: ...  # guidance term, shape: (B, *dims)

Important

When using multiple guidances that internally call torch.autograd.grad (e.g., ModelConsistencyDPSGuidance or DataConsistencyDPSGuidance), each guidance except the last must be constructed with retain_graph=True. Otherwise the computational graph is destroyed after the first guidance computes its gradient and subsequent guidances will fail. With a single guidance this is not needed.

Parameters:
  • x0_predictor (Predictor) – A Predictor that takes (x, t) and returns an estimate of the clean data \(\hat{\mathbf{x}}_0\). Typically obtained from a trained DiffusionModel via functools.partial.

  • x0_to_score_fn (Callable[[Tensor, Tensor, Tensor], Tensor]) – Callback to convert x0-prediction to score. Signature: x0_to_score_fn(x_0, x, t) -> score. Typically obtained from a noise scheduler, e.g., x0_to_score().

  • guidances (DPSGuidance | Sequence[DPSGuidance]) – One or more guidance objects implementing the DPSGuidance interface.

See also

DPSGuidance

Protocol for guidance implementations.

Predictor

Protocol satisfied by this class.

get_denoiser()

Converts the score-predictor to a denoiser for sampling.

Examples

Example 1: Basic usage with a single guidance for inpainting:

>>> import torch
>>> from physicsnemo.diffusion.guidance import DPSScorePredictor, DPSGuidance
>>>
>>> # Toy x0-predictor (in practice, this is a trained neural network)
>>> x0_predictor = lambda x, t: x * 0.9
>>>
>>> # Simple x0_to_score function (for EDM: score = (x_0 - x) / t^2)
>>> def x0_to_score_fn(x_0, x, t):
...     t_bc = t.reshape(-1, *([1] * (x.ndim - 1)))
...     return (x_0 - x) / (t_bc ** 2)
...
>>> # Simple inpainting guidance
>>> class InpaintGuidance:
...     def __init__(self, mask, y_obs, gamma=1.0):
...         self.mask = mask
...         self.y_obs = y_obs
...         self.gamma = gamma
...     def __call__(self, x, t, x_0):
...         return -self.gamma * self.mask * (x_0 - self.y_obs)
...
>>> mask = torch.ones(1, 3, 8, 8)
>>> y_obs = torch.randn(1, 3, 8, 8)
>>> guidance = InpaintGuidance(mask, y_obs)
>>>
>>> # Create DPS score predictor
>>> dps_score_pred = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=x0_to_score_fn,
...     guidances=guidance,
... )
>>>
>>> # Use in sampling
>>> x = torch.randn(1, 3, 8, 8)
>>> t = torch.tensor([1.0])
>>> output = dps_score_pred(x, t)
>>> output.shape
torch.Size([1, 3, 8, 8])

Example 2: Multiple guidances for multi-constraint problems:

>>> import torch
>>> from physicsnemo.diffusion.guidance import DPSScorePredictor
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>>
>>> # Use scheduler to get x0_to_score_fn
>>> scheduler = EDMNoiseScheduler()
>>> x0_predictor = lambda x, t: x * 0.9
>>>
>>> # Guidance 1: match observed values at specific locations
>>> class ObservationGuidance:
...     def __init__(self, mask, y_obs, gamma=1.0):
...         self.mask = mask
...         self.y_obs = y_obs
...         self.gamma = gamma
...     def __call__(self, x, t, x_0):
...         return -self.gamma * self.mask * (x_0 - self.y_obs)
...
>>> # Guidance 2: regularization toward zero mean
>>> class ZeroMeanGuidance:
...     def __init__(self, gamma=0.1):
...         self.gamma = gamma
...     def __call__(self, x, t, x_0):
...         return -self.gamma * x_0.mean() * torch.ones_like(x_0)
...
>>> mask = torch.ones(1, 3, 8, 8)
>>> y_obs = torch.randn(1, 3, 8, 8)
>>> guidance1 = ObservationGuidance(mask, y_obs)
>>> guidance2 = ZeroMeanGuidance()
>>>
>>> # Combine multiple guidances
>>> dps_score_pred = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=scheduler.x0_to_score,
...     guidances=[guidance1, guidance2],
... )
>>>
>>> x = torch.randn(2, 3, 8, 8)
>>> t = torch.tensor([1.0, 1.0])
>>> output = dps_score_pred(x, t)
>>> output.shape
torch.Size([2, 3, 8, 8])

Example 3: Multiple autograd-based guidances require retain_graph=True on all but the last:

>>> import torch
>>> from physicsnemo.diffusion.guidance import (
...     DPSScorePredictor,
...     DataConsistencyDPSGuidance,
... )
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>>
>>> scheduler = EDMNoiseScheduler()
>>> x0_predictor = lambda x, t: x * 0.9
>>>
>>> mask1 = torch.zeros(1, 3, 8, 8, dtype=torch.bool)
>>> mask1[:, :, 2, 3] = True
>>> mask2 = torch.zeros(1, 3, 8, 8, dtype=torch.bool)
>>> mask2[:, :, 5, 6] = True
>>> y_obs = torch.randn(1, 3, 8, 8)
>>>
>>> # First guidance retains the graph for the second one
>>> g1 = DataConsistencyDPSGuidance(
...     mask=mask1, y=y_obs, std_y=0.1, retain_graph=True,
... )
>>> # Last guidance does not need retain_graph
>>> g2 = DataConsistencyDPSGuidance(
...     mask=mask2, y=y_obs, std_y=0.1,
... )
>>>
>>> dps = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=scheduler.x0_to_score,
...     guidances=[g1, g2],
... )
>>> x = torch.randn(1, 3, 8, 8)
>>> t = torch.tensor([1.0])
>>> dps(x, t).shape
torch.Size([1, 3, 8, 8])

ModelConsistencyDPSGuidance#

class physicsnemo.diffusion.guidance.ModelConsistencyDPSGuidance(
observation_operator: Callable[[Float[Tensor, 'B *dims']], Float[Tensor, 'B *obs_dims']],
y: Float[Tensor, 'B *obs_dims'],
std_y: float,
norm: int | Callable[[Float[Tensor, 'B *obs_dims'], Float[Tensor, 'B *obs_dims']], Float[Tensor, 'B']] = 2,
gamma: float = 0.0,
sigma_fn: Callable[[Float[Tensor, '*shape']], Float[Tensor, '*shape']] | None = None,
alpha_fn: Callable[[Float[Tensor, '*shape']], Float[Tensor, '*shape']] | None = None,
retain_graph: bool = False,
create_graph: bool = False,
)[source]#

Bases: DPSGuidance

DPS guidance for generic observation models with Gaussian noise.

Implements the DPSGuidance interface for generic (possibly nonlinear) observation operators.

Computes the likelihood score assuming Gaussian measurement noise with standard deviation std_y. The guidance term is:

\[\nabla_{\mathbf{x}} \log p(\mathbf{y} | \mathbf{x}_t) = -\frac{1}{2 \left( \sigma_y^2 + \Gamma \frac{\sigma(t)^2}{\alpha(t)^2} \right)} \nabla_{\mathbf{x}} \| A\left(\hat{\mathbf{x}}_0\right) - \mathbf{y} \|^2\]

where \(A\) is the observation operator and the scaling incorporates a Score-Based Data Assimilation (SDA) correction through the parameter \(\Gamma\) that accounts for the covariance of the \(\hat{\mathbf{x}}_0(\mathbf{x}_t, t)\) estimate at different diffusion times. The L2 norm can be replaced by other Lp norms or custom loss functions via the norm parameter.
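
To see what the SDA correction does to the scaling, the sketch below evaluates only the scalar prefactor \(1 / \bigl(2(\sigma_y^2 + \Gamma\,\sigma(t)^2/\alpha(t)^2)\bigr)\) from the formula above. The helper name likelihood_weight is hypothetical, and the schedule \(\sigma(t) = t\), \(\alpha(t) = 1\) is an assumed EDM-like choice used purely for illustration.

```python
def likelihood_weight(t, std_y, gamma=0.0, sigma_fn=None, alpha_fn=None):
    """Scalar factor 1 / (2 * (std_y^2 + gamma * sigma(t)^2 / alpha(t)^2))."""
    variance = std_y ** 2
    if gamma > 0.0:
        if sigma_fn is None:
            raise ValueError("sigma_fn is required when gamma > 0")
        alpha = alpha_fn(t) if alpha_fn is not None else 1.0
        variance += gamma * (sigma_fn(t) / alpha) ** 2
    return 1.0 / (2.0 * variance)


sigma_fn = lambda t: t  # assumed EDM-like schedule
times = (0.01, 1.0, 80.0)
classical = [likelihood_weight(t, std_y=0.1) for t in times]
with_sda = [likelihood_weight(t, std_y=0.1, gamma=0.05, sigma_fn=sigma_fn) for t in times]
print(classical)
print(with_sda)
```

With gamma=0 the weight is the same at every diffusion time. With gamma > 0 it shrinks as the noise level grows, so the guidance is weaker early in sampling, when \(\hat{\mathbf{x}}_0\) is least reliable, and approaches the classical DPS weight as \(t \to 0\).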

The observation_operator must be a differentiable callable with the following signature:

def observation_operator(
    x_0: Tensor,  # shape: (B, *dims)
) -> Tensor: ...  # predicted observations, shape: (B, *obs_dims)

When norm is a callable, it must have the following signature:

def norm(
    y_pred: Tensor,  # shape: (B, *obs_dims)
    y_true: Tensor,  # shape: (B, *obs_dims)
) -> Tensor: ...    # scalar loss per batch element, shape: (B,)
Parameters:
  • observation_operator (Callable[[Tensor], Tensor]) – Observation operator mapping clean state to observations. Must be differentiable (supports torch.autograd).

  • y (Tensor) – Observed data of shape \((B, *obs\_dims)\) matching the output of observation_operator.

  • std_y (float) – Standard deviation of the measurement noise \(\sigma_y\).

  • norm (int | Callable[[Tensor, Tensor], Tensor], default=2) – Loss function used to compute the residual. An int value (default 2) uses the corresponding Lp norm. A callable receives (y_pred, y_true) and returns a scalar loss per batch element of shape \((B,)\).

  • gamma (float, default=0.0) – SDA covariance scaling factor \(\Gamma\). When gamma > 0, applies SDA correction that accounts for the covariance of the \(\hat{\mathbf{x}}_0\) estimate at different noise levels. Set to 0 for classical DPS without SDA scaling.

  • sigma_fn (Callable[[Tensor], Tensor] | None, default=None) – Function mapping diffusion time to noise level \(\sigma(t)\). Required when gamma > 0. Typically obtained from a noise scheduler, e.g., sigma() for a linear-Gaussian noise schedule.

  • alpha_fn (Callable[[Tensor], Tensor] | None, default=None) – Function mapping diffusion time to signal coefficient \(\alpha(t)\). Optional; defaults to \(\alpha(t) = 1\) if not provided. Typically obtained from a noise scheduler, e.g., alpha() for a linear-Gaussian noise schedule.

  • retain_graph (bool, default=False) – If True, the computational graph is retained after computing gradients. Required when combining multiple autograd-based guidances in a single DPSScorePredictor — all guidances except the last must set this to True.

  • create_graph (bool, default=False) – If True, a graph of the derivative is constructed, allowing higher-order derivatives (e.g., differentiating through the entire sampling process).

See also

DataConsistencyDPSGuidance

Simplified guidance for masked observations.

DPSScorePredictor

Combines an x0-predictor with one or more guidances.

Examples

Example 1: Super-resolution with a linear blur + downsampling operator:

>>> import torch
>>> import torch.nn.functional as F
>>> from physicsnemo.diffusion.guidance import (
...     ModelConsistencyDPSGuidance,
...     DPSScorePredictor,
... )
>>>
>>> # Observation operator: 3x3 box blur + 2x downsampling
>>> def blur_downsample(x):
...     # Apply a uniform 3x3 (box) blur, depthwise per channel
...     kernel = torch.ones(1, 1, 3, 3, device=x.device) / 9
...     kernel = kernel.expand(x.shape[1], 1, 3, 3)
...     blurred = F.conv2d(x, kernel, padding=1, groups=x.shape[1])
...     # Downsample 2x
...     return F.avg_pool2d(blurred, kernel_size=2, stride=2)
...
>>> # Low-resolution observations (4x4 from 8x8 high-res)
>>> y_obs = torch.randn(1, 3, 4, 4)
>>>
>>> guidance = ModelConsistencyDPSGuidance(
...     observation_operator=blur_downsample,
...     y=y_obs,
...     std_y=0.1,
... )
>>>
>>> # Use in DPS sampling
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> t = torch.tensor([1.0])
>>> x_0 = x * 0.9  # Toy x0 estimate
>>> output = guidance(x, t, x_0)
>>> output.shape
torch.Size([1, 3, 8, 8])
>>>
>>> # Combine with DPSScorePredictor for complete sampling workflow
>>> x0_predictor = lambda x, t: x * 0.9
>>> def x0_to_score_fn(x_0, x, t):
...     t_bc = t.reshape(-1, *([1] * (x.ndim - 1)))
...     return (x_0 - x) / (t_bc ** 2)
...
>>> dps_score_pred = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=x0_to_score_fn,
...     guidances=guidance,
... )
>>> score = dps_score_pred(x, t)
>>> score.shape
torch.Size([1, 3, 8, 8])

Example 2: With SDA scaling using noise scheduler methods:

>>> import torch
>>> from physicsnemo.diffusion.guidance import (
...     ModelConsistencyDPSGuidance,
...     DPSScorePredictor,
... )
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>>
>>> scheduler = EDMNoiseScheduler()
>>>
>>> # Linear observation operator (select first channel)
>>> A = lambda x: x[:, :1]
>>> y_obs = torch.randn(1, 1, 8, 8)
>>>
>>> # Enable SDA scaling with gamma > 0, providing sigma and alpha functions
>>> guidance = ModelConsistencyDPSGuidance(
...     observation_operator=A,
...     y=y_obs,
...     std_y=0.075,
...     gamma=0.05,  # Enable SDA scaling
...     sigma_fn=scheduler.sigma,
...     alpha_fn=scheduler.alpha,
... )
>>>
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> t = torch.tensor([1.0])
>>> x_0 = x * 0.9
>>> output = guidance(x, t, x_0)
>>> output.shape
torch.Size([1, 3, 8, 8])
>>>
>>> # Use with DPSScorePredictor and scheduler's x0_to_score
>>> x0_predictor = lambda x, t: x * 0.9
>>> dps_score_pred = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=scheduler.x0_to_score,
...     guidances=guidance,
... )
>>> score = dps_score_pred(x, t)
>>> score.shape
torch.Size([1, 3, 8, 8])

Example 3: With a custom loss function (Huber loss):

>>> import torch
>>> import torch.nn.functional as F
>>> from physicsnemo.diffusion.guidance import ModelConsistencyDPSGuidance
>>>
>>> # Wrap torch's Huber loss to return per-batch scalars
>>> def huber_loss(y_pred, y_true):
...     per_elem = F.huber_loss(y_pred, y_true, reduction="none")
...     return per_elem.reshape(y_pred.shape[0], -1).sum(dim=1)
...
>>> A = lambda x: x[:, :1]  # Select first channel
>>> y_obs = torch.randn(1, 1, 8, 8)
>>>
>>> guidance = ModelConsistencyDPSGuidance(
...     observation_operator=A,
...     y=y_obs,
...     std_y=0.1,
...     norm=huber_loss,  # Custom loss function
... )
>>>
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> t = torch.tensor([1.0])
>>> x_0 = x * 0.9
>>> output = guidance(x, t, x_0)
>>> output.shape
torch.Size([1, 3, 8, 8])

DataConsistencyDPSGuidance#

class physicsnemo.diffusion.guidance.DataConsistencyDPSGuidance(
mask: Bool[Tensor, 'B *dims'],
y: Float[Tensor, 'B *dims'],
std_y: float,
norm: int | Callable[[Float[Tensor, 'B *dims'], Float[Tensor, 'B *dims']], Float[Tensor, 'B']] = 2,
gamma: float = 0.0,
sigma_fn: Callable[[Float[Tensor, '*shape']], Float[Tensor, '*shape']] | None = None,
alpha_fn: Callable[[Float[Tensor, '*shape']], Float[Tensor, '*shape']] | None = None,
retain_graph: bool = False,
create_graph: bool = False,
)[source]#

Bases: DPSGuidance

DPS guidance for masked observations with Gaussian noise.

Implements the DPSGuidance interface for masked observation operators, a simplified version of ModelConsistencyDPSGuidance. This is typical for data assimilation tasks like inpainting, outpainting, or sparse observations, where measurements are available at specific locations.

Computes the likelihood score assuming Gaussian measurement noise with standard deviation std_y. The guidance term is:

\[\nabla_{\mathbf{x}} \log p(\mathbf{y} | \mathbf{x}_t) = -\frac{1}{2 \left( \sigma_y^2 + \Gamma \frac{\sigma(t)^2}{\alpha(t)^2} \right)} \nabla_{\mathbf{x}} \| \mathbf{M} \odot (\hat{\mathbf{x}}_0 - \mathbf{y}) \|^2\]

where \(\mathbf{M}\) is a binary mask (1 = observed, 0 = missing), \(\odot\) denotes element-wise multiplication, and the scaling incorporates an SDA correction through the parameter \(\Gamma\). The L2 norm can be replaced by other Lp norms or custom loss functions via the norm parameter.
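
The relationship to a general observation operator can be illustrated with plain Python lists. In the sketch below (all names hypothetical), the masked residual is compared with a generic residual whose observation operator simply selects the observed entries. It is a toy check of the squared L2 case, not a description of how either class is implemented.

```python
def squared_l2(y_pred, y_true):
    """Squared L2 distance between two flat lists."""
    return sum((p - q) ** 2 for p, q in zip(y_pred, y_true))


def data_consistency_loss(mask, x_0, y):
    # The loss sees (mask * x_0, mask * y), so unobserved entries contribute zero.
    return squared_l2([m * a for m, a in zip(mask, x_0)],
                      [m * b for m, b in zip(mask, y)])


def model_consistency_loss(observation_operator, x_0, y):
    # Generic form: compare A(x_0) with the observations y.
    return squared_l2(observation_operator(x_0), y)


mask = [True, False, True]
x_0 = [0.9, -0.4, 1.5]
y = [1.0, 123.0, 0.0]  # the value at the unobserved location is arbitrary
select = lambda v: [a for m, a in zip(mask, v) if m]  # operator keeping observed entries

masked = data_consistency_loss(mask, x_0, y)
generic = model_consistency_loss(select, x_0, select(y))
print(masked, generic)
```

Both losses agree, and changing y at a location where the mask is False leaves the masked loss unchanged, which is why those values can be filled with anything.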

When norm is a callable, it must have the following signature:

def norm(
    y_pred: Tensor,  # shape: (B, *dims)
    y_true: Tensor,  # shape: (B, *dims)
) -> Tensor: ...    # scalar loss per batch element, shape: (B,)
Parameters:
  • mask (Tensor) – Boolean mask of shape \((B, *)\) matching the state shape. True for observed locations, False for missing.

  • y (Tensor) – Observed data of shape \((B, *)\) matching the state shape. Values at unobserved locations (where mask is False) are ignored.

  • std_y (float) – Standard deviation of the measurement noise \(\sigma_y\).

  • norm (int | Callable[[Tensor, Tensor], Tensor], default=2) – Loss function used to compute the residual. An int value (default 2) uses the corresponding Lp norm. A callable receives (mask * x_0, mask * y) and returns a scalar loss per batch element of shape \((B,)\).

  • gamma (float, default=0.0) – SDA covariance scaling factor \(\Gamma\). When gamma > 0, applies SDA correction that accounts for the covariance of the \(\hat{\mathbf{x}}_0\) estimate at different noise levels. Set to 0 for classical DPS without SDA scaling.

  • sigma_fn (Callable[[Tensor], Tensor] | None, default=None) – Function mapping diffusion time to noise level \(\sigma(t)\). Required when gamma > 0. Typically obtained from a noise scheduler. For example, use sigma() for a linear-Gaussian noise schedule.

  • alpha_fn (Callable[[Tensor], Tensor] | None, default=None) – Function mapping diffusion time to signal coefficient \(\alpha(t)\). Optional; defaults to \(\alpha(t) = 1\) if not provided. For example, use alpha() for a linear-Gaussian noise schedule.

  • retain_graph (bool, default=False) – If True, the computational graph is retained after computing gradients. Required when combining multiple autograd-based guidances in a single DPSScorePredictor — all guidances except the last must set this to True.

  • create_graph (bool, default=False) – If True, a graph of the derivative is constructed, allowing higher-order derivatives (e.g., differentiating through the entire sampling process).

See also

ModelConsistencyDPSGuidance

Guidance for general observation operators.

DPSScorePredictor

Combines an x0-predictor with one or more guidances.

Examples

Example 1: Sparse observations at probe locations:

>>> import torch
>>> from physicsnemo.diffusion.guidance import (
...     DataConsistencyDPSGuidance,
...     DPSScorePredictor,
... )
>>>
>>> # Boolean mask: only observe a few probe locations
>>> mask = torch.zeros(1, 3, 8, 8, dtype=torch.bool)
>>> mask[:, :, 2, 3] = True  # Probe at (2, 3)
>>> mask[:, :, 5, 6] = True  # Probe at (5, 6)
>>> mask[:, :, 1, 7] = True  # Probe at (1, 7)
>>> y_obs = torch.randn(1, 3, 8, 8)  # Observed values
>>>
>>> guidance = DataConsistencyDPSGuidance(
...     mask=mask,
...     y=y_obs,
...     std_y=0.1,
... )
>>>
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> t = torch.tensor([1.0])
>>> x_0 = x * 0.9  # Toy x0 estimate (must be computed from x)
>>> output = guidance(x, t, x_0)
>>> output.shape
torch.Size([1, 3, 8, 8])
>>>
>>> # Use with DPSScorePredictor for complete sampling workflow
>>> x0_predictor = lambda x, t: x * 0.9
>>> def x0_to_score_fn(x_0, x, t):
...     t_bc = t.reshape(-1, *([1] * (x.ndim - 1)))
...     return (x_0 - x) / (t_bc ** 2)
...
>>> dps_score_pred = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=x0_to_score_fn,
...     guidances=guidance,
... )
>>> score = dps_score_pred(x, t)
>>> score.shape
torch.Size([1, 3, 8, 8])

Example 2: With SDA scaling and L1 norm using noise scheduler:

>>> import torch
>>> from physicsnemo.diffusion.guidance import (
...     DataConsistencyDPSGuidance,
...     DPSScorePredictor,
... )
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>>
>>> scheduler = EDMNoiseScheduler()
>>>
>>> # Same sparse probe locations as Example 1
>>> mask = torch.zeros(1, 3, 8, 8, dtype=torch.bool)
>>> mask[:, :, 2, 3] = True
>>> mask[:, :, 5, 6] = True
>>> mask[:, :, 1, 7] = True
>>> y_obs = torch.randn(1, 3, 8, 8)
>>>
>>> # Enable SDA scaling and use L1 norm for robustness
>>> guidance = DataConsistencyDPSGuidance(
...     mask=mask,
...     y=y_obs,
...     std_y=0.075,
...     norm=1,  # L1 norm
...     gamma=1.0,  # Enable SDA scaling
...     sigma_fn=scheduler.sigma,
...     alpha_fn=scheduler.alpha,
... )
>>>
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> t = torch.tensor([1.0])
>>> x_0 = x * 0.9  # Must be computed from x
>>> output = guidance(x, t, x_0)
>>> output.shape
torch.Size([1, 3, 8, 8])
>>>
>>> # Use with DPSScorePredictor and scheduler's x0_to_score
>>> x0_predictor = lambda x, t: x * 0.9
>>> dps_score_pred = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=scheduler.x0_to_score,
...     guidances=guidance,
... )
>>> score = dps_score_pred(x, t)
>>> score.shape
torch.Size([1, 3, 8, 8])
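To give a feel for what SDA scaling does, the sketch below follows the form used in score-based data assimilation, where the observation variance is inflated by a term proportional to (sigma(t) / alpha(t))^2. Whether DataConsistencyDPSGuidance applies exactly this expression is an assumption here. The helper sda_effective_std is hypothetical, and the schedule sigma(t) = t, alpha(t) = 1 is an assumed EDM-like choice, not a statement about EDMNoiseScheduler.

```python
import torch


def sda_effective_std(std_y, gamma, sigma_t, alpha_t):
    """Hypothetical helper: SDA-style inflated observation standard deviation."""
    return torch.sqrt(std_y ** 2 + gamma * (sigma_t / alpha_t) ** 2)


# Assumed EDM-like schedule: sigma(t) = t, alpha(t) = 1.
t = torch.tensor([0.0, 0.1, 1.0, 10.0])
eff = sda_effective_std(
    std_y=0.075,
    gamma=1.0,
    sigma_t=t,
    alpha_t=torch.ones_like(t),
)
```

Under these assumptions the effective standard deviation equals std_y at t = 0 and grows with the noise level, so observations pull weakly on very noisy states and more strongly as sampling approaches the clean sample.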

Example 3: With a custom loss function (Huber loss):

>>> import torch
>>> import torch.nn.functional as F
>>> from physicsnemo.diffusion.guidance import DataConsistencyDPSGuidance
>>>
>>> # Wrap torch's Huber loss to return per-batch scalars
>>> def huber_loss(y_pred, y_true):
...     per_elem = F.huber_loss(y_pred, y_true, reduction="none")
...     return per_elem.reshape(y_pred.shape[0], -1).sum(dim=1)
...
>>> mask = torch.zeros(1, 3, 8, 8, dtype=torch.bool)
>>> mask[:, :, 2, 3] = True
>>> mask[:, :, 5, 6] = True
>>> y_obs = torch.randn(1, 3, 8, 8)
>>>
>>> guidance = DataConsistencyDPSGuidance(
...     mask=mask,
...     y=y_obs,
...     std_y=0.1,
...     norm=huber_loss,  # Custom loss function
... )
>>>
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> t = torch.tensor([1.0])
>>> x_0 = x * 0.9
>>> output = guidance(x, t, x_0)
>>> output.shape
torch.Size([1, 3, 8, 8])
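A custom norm such as the Huber wrapper in Example 3 can be checked in isolation before it is passed to the guidance. The sketch below assumes, based on the comment in Example 3, that the callable should return one scalar per batch element. The test tensors are arbitrary.

```python
import torch
import torch.nn.functional as F


def huber_loss(y_pred, y_true):
    # Element-wise Huber loss, then summed over all non-batch dimensions.
    per_elem = F.huber_loss(y_pred, y_true, reduction="none")
    return per_elem.reshape(y_pred.shape[0], -1).sum(dim=1)


torch.manual_seed(0)
y_pred = torch.randn(4, 3, 8, 8)

# Small residuals (|r| = 0.5 < delta = 1.0): quadratic regime, 0.5 * r**2 each.
loss_small = huber_loss(y_pred, y_pred + 0.5)

# Large residuals (|r| = 3.0 > delta = 1.0): linear regime, |r| - 0.5 each.
loss_large = huber_loss(y_pred, y_pred + 3.0)
```

Each result has shape (4,), one value per batch element. The linear growth for large residuals is what makes this loss less sensitive to outlying observations than a squared error.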