Samplers and Solvers

The sampler is the main interface for generating new data from a trained diffusion model. Starting from pure noise \(\mathbf{x}_N\), the solver iteratively denoises the latent state through a sequence of time-steps until it reaches a clean sample \(\mathbf{x}_0\).

The central entry point is the sample() function. It takes a Denoiser, an initial noisy latent \(\mathbf{x}_N\), and a NoiseScheduler, and iterates the reverse process to produce samples.

Generic Sampling Process

The sample() function supports any reverse process that can be written in the form:

\[\mathbf{x}_{n-1} = \text{Step}\bigl( D\!\bigl(\mathbf{x}_n, t_n;\, P(\mathbf{x}_n, t_n)\bigr);\; \mathbf{x}_n, t_n, t_{n-1}\bigr)\]

This equation is the foundation of the sampling process in the framework. Every component described on this page maps to one of its three terms:

  • \(P\) is the predictor — the Predictor that maps the noisy state and diffusion time to a prediction. This is where all model logic lives, including conditioning and guidance.

  • \(D\) is the denoiser — the Denoiser derived from \(P\) via the noise scheduler’s get_denoiser() factory.

  • \(\text{Step}\) is the solver’s numerical update rule.

This generic formulation covers not only standard ODE/SDE-based sampling but also more advanced methods such as physics-informed posterior guidance (DPS), score-based data assimilation (SDA), and others. Any method whose update step can be expressed through this denoiser/solver decomposition can be used with the sample() function.
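To make the decomposition concrete, here is a minimal, self-contained sketch of such a loop in plain PyTorch. It is not the implementation of sample(): the names generic_sample, make_denoiser and euler_step are hypothetical, the predictor is a toy stand-in for a trained model, and the denoiser assumes an EDM-style schedule with \(\sigma(t) = t\), for which the probability-flow ODE drift is \((\mathbf{x} - \hat{\mathbf{x}}_0)/t\).

```python
import torch

# Toy predictor P(x, t): stand-in for a trained x0-predictor (hypothetical)
def predictor(x, t):
    return 0.9 * x

# D, built from P. Assumption: EDM-style schedule with sigma(t) = t, so the
# probability-flow ODE drift is dx/dt = (x - x0_hat) / t.
def make_denoiser(p):
    def denoiser(x, t):
        t_bc = t.reshape(-1, *([1] * (x.ndim - 1)))  # (B,) -> (B, 1, ..., 1)
        return (x - p(x, t)) / t_bc
    return denoiser

# Step: explicit Euler update from t_cur to t_next
def euler_step(d, x, t_cur, t_next):
    h = (t_next - t_cur).reshape(-1, *([1] * (x.ndim - 1)))
    return x + h * d(x, t_cur)

def generic_sample(denoiser, step, x, time_steps):
    batch = x.shape[0]
    for t_cur, t_next in zip(time_steps[:-1], time_steps[1:]):
        # x_{n-1} = Step(D(x_n, t_n); x_n, t_n, t_{n-1})
        x = step(denoiser, x, t_cur.expand(batch), t_next.expand(batch))
    return x

time_steps = torch.tensor([80.0, 20.0, 5.0, 1.0, 0.0])
x_N = 80.0 * torch.randn(2, 3, 4, 4)  # pure noise at the largest noise level
x_0 = generic_sample(make_denoiser(predictor), euler_step, x_N, time_steps)
print(x_0.shape)  # torch.Size([2, 3, 4, 4])
```

Swapping euler_step for another update rule, or wrapping predictor with guidance before building the denoiser, changes only one term of the equation. The loop itself stays the same.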

Sampling Workflow

A complete sampling workflow involves these steps:

  1. Load or reference a trained model satisfying the DiffusionModel interface (typically a backbone wrapped in a preconditioner).

  2. Build a Predictor (\(P\) in the sampling equation) by binding the conditioning via functools.partial, converting the three-argument DiffusionModel into a two-argument Predictor.

  3. Convert to a Denoiser (\(D\) in the equation). There are two paths:

    • Without guidance — pass the predictor directly to the noise scheduler’s get_denoiser() factory (as an x0_predictor or score_predictor).

    • With guidance — first instantiate one or more DPSGuidance objects, then combine them with the predictor using DPSScorePredictor to obtain a guided score-predictor. Finally, pass this guided score-predictor to get_denoiser().

  4. Initialize the noisy latent \(\mathbf{x}_N\) and the time-step schedule using the scheduler.

  5. Optionally configure a custom solver (\(\text{Step}\) in the equation) by instantiating a Solver (see Available Solvers), or simply pass a built-in string key (for example, "heun") to sample().

  6. Call sample() to run the reverse diffusion loop.
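Step 2 uses only functools.partial. The sketch below binds the conditioning of a hypothetical toy model with the three-argument (x, t, condition) signature and shows that the result is an ordinary two-argument callable. The model's arithmetic is arbitrary.

```python
from functools import partial

import torch

# Hypothetical three-argument model standing in for a DiffusionModel
class ToyConditionalModel(torch.nn.Module):
    def forward(self, x, t, condition=None):
        t_bc = t.reshape(-1, *([1] * (x.ndim - 1)))
        out = x / (1.0 + t_bc ** 2)
        if condition is not None:
            out = out + condition
        return out

model = ToyConditionalModel().eval()
cond = torch.ones(2, 3, 8, 8)

# Bind the conditioning: the result is a two-argument Predictor P(x, t)
x0_predictor = partial(model, condition=cond)

x = torch.randn(2, 3, 8, 8)
t = torch.tensor([1.0, 2.0])
print(x0_predictor(x, t).shape)  # torch.Size([2, 3, 8, 8])
```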

Example: Unconditional Image Generation

This example shows the full workflow for an unconditional image model trained with the EDM formulation. It uses SongUNet as the backbone, wrapped with a thin adapter to match the DiffusionModel interface, and EDMPreconditioner for preconditioning.

The noise scheduler must generally be consistent between training and sampling. In particular, the same schedule family (for example, EDM or VP) should be used. Schedule parameters (for example, sigma_min or rho) can be adjusted at sampling time for experimentation, but the model was optimized for the training schedule, so large deviations may degrade sample quality.

import torch
from functools import partial
from physicsnemo.core import Module
from physicsnemo.models.diffusion_unets import SongUNet
from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
from physicsnemo.diffusion.preconditioners import EDMPreconditioner
from physicsnemo.diffusion.samplers import sample

# --- Backbone: wrap SongUNet to match the DiffusionModel interface ---
class UNetBackbone(Module):
    def __init__(self, img_resolution, channels, **kwargs):
        super().__init__()
        self.net = SongUNet(
            img_resolution=img_resolution,
            in_channels=channels,
            out_channels=channels,
            **kwargs,
        )
    def forward(self, x, t, condition=None):
        return self.net(x, noise_labels=t, class_labels=condition)

backbone = UNetBackbone(img_resolution=64, channels=3, model_channels=64,
                        channel_mult=[1, 2, 2], num_blocks=2)

# --- Preconditioner + training (sketch) ---
scheduler = EDMNoiseScheduler(sigma_min=0.002, sigma_max=80.0, rho=7)
precond = EDMPreconditioner(backbone, sigma_data=0.5)
# ... train with MSEDSMLoss(precond, scheduler) ...

# --- Sampling ---
precond.eval()

# Build predictor "P": bind condition=None for unconditional model
x0_predictor = partial(precond, condition=None)

# Convert to denoiser "D"
denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)

# Initialize time-steps and noisy latent
num_steps = 50
tN = scheduler.timesteps(num_steps)[0].expand(8)
xN = scheduler.init_latents((3, 64, 64), tN)

# Run sampling loop with Heun solver ("Step")
samples = sample(denoiser, xN, scheduler, num_steps=num_steps, solver="heun")
# samples.shape: (8, 3, 64, 64)

It is also possible to adjust the schedule parameters at sampling time. For instance, one might increase sigma_max or change rho to explore the effect on sample quality:

sampling_scheduler = EDMNoiseScheduler(sigma_min=0.002, sigma_max=120.0, rho=5)
denoiser = sampling_scheduler.get_denoiser(x0_predictor=x0_predictor)
tN = sampling_scheduler.timesteps(num_steps)[0].expand(8)
xN = sampling_scheduler.init_latents((3, 64, 64), tN)
samples = sample(denoiser, xN, sampling_scheduler, num_steps=num_steps)
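To see how sigma_max and rho change the time-steps, the sketch below computes the time discretization from the EDM paper (Karras et al., 2022), \(t_i = \bigl(\sigma_{\max}^{1/\rho} + \tfrac{i}{N-1}(\sigma_{\min}^{1/\rho} - \sigma_{\max}^{1/\rho})\bigr)^{\rho}\), followed by a final \(t = 0\). It is an assumption that EDMNoiseScheduler.timesteps() uses exactly this formula, and edm_time_steps is a hypothetical helper. Larger \(\rho\) places more of the steps near \(\sigma_{\min}\).

```python
import torch

def edm_time_steps(num_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """EDM time discretization (Karras et al., 2022) plus a final t = 0.

    Hypothetical helper; EDMNoiseScheduler.timesteps() may differ in details.
    """
    i = torch.arange(num_steps, dtype=torch.float64)
    a, b = sigma_max ** (1.0 / rho), sigma_min ** (1.0 / rho)
    t = (a + i / (num_steps - 1) * (b - a)) ** rho  # decreasing from sigma_max to sigma_min
    return torch.cat([t, torch.zeros(1, dtype=t.dtype)])

for sigma_max, rho in [(80.0, 7.0), (120.0, 5.0), (80.0, 3.0)]:
    t = edm_time_steps(50, sigma_max=sigma_max, rho=rho)
    n_low = int((t[:-1] < 1.0).sum())  # how many of the 50 levels lie below t = 1
    print(f"sigma_max={sigma_max}, rho={rho}: t[0]={t[0].item():.1f}, "
          f"{n_low} of 50 steps below t=1")
```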

Example: Vector-Space Diffusion (Non-Image Data)

The diffusion framework is not limited to image data. Any tensor-valued data can be used, including 1D vectors. Here the backbone uses the FullyConnected model from PhysicsNeMo, wrapped with a thin adapter to match the DiffusionModel interface.

import torch
from functools import partial
from physicsnemo.core import Module
from physicsnemo.models.mlp import FullyConnected
from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
from physicsnemo.diffusion.preconditioners import EDMPreconditioner
from physicsnemo.diffusion.samplers import sample

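# Backbone: wrap FullyConnected to match the DiffusionModel interface.
# The noise label t is appended as an extra input feature so the MLP
# knows how noisy its input is.
class FCBackbone(Module):
    def __init__(self, dim, hidden=256, num_layers=4):
        super().__init__()
        self.net = FullyConnected(
            in_features=dim + 1, layer_size=hidden,
            out_features=dim, num_layers=num_layers,
        )
    def forward(self, x, t, condition=None):
        # Per-sample noise label reshaped to (B, 1) and concatenated to x
        t_feat = t.reshape(-1, 1).expand(x.shape[0], 1).to(x.dtype)
        return self.net(torch.cat([x, t_feat], dim=1))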

data_dim = 32
backbone = FCBackbone(dim=data_dim)
scheduler = EDMNoiseScheduler()
precond = EDMPreconditioner(backbone, sigma_data=1.0)
# ... train with MSEDSMLoss(precond, scheduler) ...

# Sampling
precond.eval()
x0_predictor = partial(precond, condition=None)
denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
num_steps = 50
tN = scheduler.timesteps(num_steps)[0].expand(16)
xN = scheduler.init_latents((data_dim,), tN)  # 1D latent: shape (16, 32)
samples = sample(denoiser, xN, scheduler, num_steps=num_steps)
# samples.shape: (16, 32)
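The same preconditioner can wrap both the image backbone and the vector backbone because the per-sample noise level only has to broadcast over trailing dimensions. The sketch below implements EDM preconditioning from the formulas in Karras et al. (2022). edm_precondition is hypothetical, and it is an assumption that EDMPreconditioner matches it in every detail (for example, the choice of c_noise). sigma_data is meant to approximate the standard deviation of the training data, so it should be set from the (normalized) data rather than copied between examples.

```python
import torch

def edm_precondition(backbone, x, sigma, sigma_data=1.0, condition=None):
    """EDM preconditioning (Karras et al., 2022), written rank-agnostically.

    Hypothetical helper; EDMPreconditioner may differ in details.
    """
    # Broadcast per-sample sigma of shape (B,) against x of shape (B, *dims)
    s = sigma.reshape(-1, *([1] * (x.ndim - 1)))
    c_skip = sigma_data ** 2 / (s ** 2 + sigma_data ** 2)
    c_out = s * sigma_data / (s ** 2 + sigma_data ** 2).sqrt()
    c_in = 1.0 / (s ** 2 + sigma_data ** 2).sqrt()
    c_noise = sigma.log() / 4  # shape (B,), passed to the backbone as "t"
    return c_skip * x + c_out * backbone(c_in * x, c_noise, condition)

# A backbone that outputs zeros isolates the skip connection c_skip * x
zero_net = lambda x, t, condition=None: torch.zeros_like(x)

vec = torch.randn(16, 32)         # vector data, as in the example above
img = torch.randn(8, 3, 64, 64)   # image data
out_v = edm_precondition(zero_net, vec, torch.full((16,), 2.0))
out_i = edm_precondition(zero_net, img, torch.full((8,), 2.0), sigma_data=0.5)
print(out_v.shape, out_i.shape)  # torch.Size([16, 32]) torch.Size([8, 3, 64, 64])
```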

Example: Conditional Sampling

For conditional generation (for example, super-resolution), the model backbone processes both the noisy latent state and the conditioning input. A common pattern is to concatenate the conditioning image along the channel dimension inside a thin adapter. At sampling time, the conditioning is bound into the predictor (\(P\)) via functools.partial.

import torch
from functools import partial
from physicsnemo.core import Module
from physicsnemo.models.diffusion_unets import SongUNet
from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
from physicsnemo.diffusion.preconditioners import EDMPreconditioner
from physicsnemo.diffusion.samplers import sample

C_x, C_cond, res = 3, 3, 64   # Image channels, conditioning channels, resolution

# Backbone: SongUNet wrapped with an adapter that concatenates the
# conditioning image along the channel dimension
class ConditionalUNet(Module):
    def __init__(self):
        super().__init__()
        self.net = SongUNet(
            img_resolution=res,
            in_channels=C_x + C_cond,
            out_channels=C_x,
            model_channels=64,
            channel_mult=[1, 2, 2],
            num_blocks=2,
        )
    def forward(self, x, t, condition=None):
        x_cat = torch.cat([x, condition], dim=1)
        return self.net(x_cat, noise_labels=t, class_labels=None)

backbone = ConditionalUNet()
scheduler = EDMNoiseScheduler()
precond = EDMPreconditioner(backbone, sigma_data=0.5)
# ... train with MSEDSMLoss(precond, scheduler, condition=...) ...

# --- Sampling ---
precond.eval()

# Conditioning image (for example, a low-resolution input upsampled to the
# target resolution for super-resolution)
low_res = torch.randn(4, C_cond, res, res)

# Bind condition into the predictor "P"
x0_predictor = partial(precond, condition=low_res)

# Convert to denoiser "D" and sample
denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
tN = scheduler.timesteps(50)[0].expand(4)
xN = scheduler.init_latents((C_x, res, res), tN)
samples = sample(denoiser, xN, scheduler, num_steps=50)
# samples.shape: (4, 3, 64, 64)
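Because ConditionalUNet concatenates the conditioning along dim=1, the bound tensor must match the latent's batch size and spatial resolution. For super-resolution, one common choice (an assumption here, not something the framework prescribes) is to upsample the low-resolution input to the target grid before binding it. The sketch uses bilinear interpolation and a hypothetical toy_precond in place of the trained model.

```python
from functools import partial

import torch
import torch.nn.functional as F

C_x, C_cond, res = 3, 3, 64

# Hypothetical stand-in for the trained preconditioned model precond(x, t, condition)
def toy_precond(x, t, condition=None):
    x_cat = torch.cat([x, condition], dim=1)  # same concatenation as ConditionalUNet
    return 0.5 * x_cat[:, :C_x] + 0.5 * x_cat[:, C_x:]

# A 16x16 low-resolution observation for a 4x super-resolution task
low_res = torch.randn(4, C_cond, res // 4, res // 4)

# Upsample to the target grid so it can be concatenated along channels
cond = F.interpolate(low_res, size=(res, res), mode="bilinear", align_corners=False)

x0_predictor = partial(toy_precond, condition=cond)
x = torch.randn(4, C_x, res, res)
print(cond.shape, x0_predictor(x, torch.full((4,), 10.0)).shape)
```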

Example: Conditional Sampling with DPS Guidance

DPS (Diffusion Posterior Sampling) guidance steers the sampling toward satisfying observation constraints. Guidance modifies the predictor \(P\) in the sampling equation: the guidance objects are combined with the x0-predictor into a guided score-predictor via DPSScorePredictor, and then converted to a denoiser \(D\) via the noise scheduler.

import torch
from functools import partial
from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
from physicsnemo.diffusion.guidance import DPSScorePredictor, DataConsistencyDPSGuidance
from physicsnemo.diffusion.samplers import sample

scheduler = EDMNoiseScheduler()

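# Build predictor "P" from a trained conditional model
# (trained_model and condition are placeholders for a trained DiffusionModel
# and its conditioning tensor). DPS guidance differentiates through this
# predictor, so do not wrap sampling in torch.no_grad().
x0_predictor = partial(trained_model, condition=condition)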

# Build guidance objects
mask = torch.zeros(4, 3, 64, 64, dtype=torch.bool)
mask[:, :, ::8, ::8] = True  # Observe every 8th pixel
y_obs = torch.randn(4, 3, 64, 64)
guidance = DataConsistencyDPSGuidance(mask=mask, y=y_obs, std_y=0.1)

# Combine predictor + guidance into a guided score-predictor
guided_score_predictor = DPSScorePredictor(
    x0_predictor=x0_predictor,
    x0_to_score_fn=scheduler.x0_to_score,
    guidances=guidance,
)

# Convert to denoiser "D" via the noise scheduler
denoiser = scheduler.get_denoiser(score_predictor=guided_score_predictor)

tN = scheduler.timesteps(50)[0].expand(4)
xN = scheduler.init_latents((3, 64, 64), tN)
samples = sample(denoiser, xN, scheduler, num_steps=50)

Example: Custom Solver and Custom Time-Steps

This example shows how to define a solver from scratch by implementing the Solver protocol. Any object with a step(x, t_cur, t_next) method can serve as \(\text{Step}\) in the sampling equation. Here we implement the implicit trapezoidal rule (a second-order implicit Runge-Kutta method). Its implicit update is approximated with a few fixed-point iterations started from an Euler step. The solver is paired with custom time-steps and trajectory snapshots.

import torch
from functools import partial
from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
from physicsnemo.diffusion.samplers import sample

# Custom solver: implicit trapezoidal rule (second-order)
class TrapezoidalSolver:
    def __init__(self, denoiser, num_inner_iters=3):
        self.denoiser = denoiser
        self.num_inner_iters = num_inner_iters

    def step(self, x, t_cur, t_next):
        t_cur_bc = t_cur.reshape(-1, *([1] * (x.ndim - 1)))
        t_next_bc = t_next.reshape(-1, *([1] * (x.ndim - 1)))
        h = t_next_bc - t_cur_bc
        d_cur = self.denoiser(x, t_cur)
        # Predictor: Euler step to get initial guess
        x_next = x + h * d_cur
        # The EDM ODE drift is singular at t = 0, so the final step to
        # t_next = 0 stays explicit Euler (as in EDM's Heun sampler)
        if torch.all(t_next == 0):
            return x_next
        # Corrector: fixed-point iterations for the implicit trapezoidal rule
        for _ in range(self.num_inner_iters):
            d_next = self.denoiser(x_next, t_next)
            x_next = x + 0.5 * h * (d_cur + d_next)
        return x_next

scheduler = EDMNoiseScheduler()
# trained_model: a trained DiffusionModel (placeholder)
x0_predictor = partial(trained_model, condition=None)
denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)

# Custom time-steps
custom_t = torch.tensor([80.0, 40.0, 20.0, 10.0, 5.0, 2.0, 1.0, 0.5, 0.1, 0.0])
tN = custom_t[0].expand(4)
xN = scheduler.init_latents((3, 64, 64), tN)

solver = TrapezoidalSolver(denoiser, num_inner_iters=3)
trajectory = sample(
    denoiser, xN, scheduler,
    num_steps=0,            # Ignored when time_steps is provided
    time_steps=custom_t,
    solver=solver,
    time_eval=[0, 4, 7],    # Collect snapshots at steps 0, 4, 7
)
# trajectory is a list of 3 tensors, each of shape (4, 3, 64, 64)
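As a sanity check on this kind of solver, the sketch below reimplements the same trapezoidal step outside PhysicsNeMo and applies it to a toy drift with a known closed-form solution. With num_inner_iters=0 the step reduces to explicit Euler, so a single class provides both a first-order and a second-order method. The drift (the EDM form \((\mathbf{x} - \hat{\mathbf{x}}_0)/t\) with a toy \(\hat{\mathbf{x}}_0 = 0.9\,\mathbf{x}\)) and the time grid are illustrative assumptions. The grid stops at \(t = 1\) because the drift is singular at \(t = 0\).

```python
import torch

def bc(t, x):
    # Broadcast per-sample times of shape (B,) against x of shape (B, *dims)
    return t.reshape(-1, *([1] * (x.ndim - 1)))

class TrapezoidalStep:
    """Same update as the TrapezoidalSolver example; num_inner_iters=0 gives Euler."""

    def __init__(self, drift, num_inner_iters=3):
        self.drift = drift
        self.num_inner_iters = num_inner_iters

    def step(self, x, t_cur, t_next):
        h = bc(t_next - t_cur, x)
        d_cur = self.drift(x, t_cur)
        x_next = x + h * d_cur  # explicit Euler predictor
        for _ in range(self.num_inner_iters):
            x_next = x + 0.5 * h * (d_cur + self.drift(x_next, t_next))
        return x_next

# Toy EDM-style drift with x0_hat = 0.9 x:  dx/dt = 0.1 x / t,
# whose exact solution is x(t) = x(T) * (t / T) ** 0.1
drift = lambda x, t: (x - 0.9 * x) / bc(t, x)

def integrate(solver, x, ts):
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        x = solver.step(x, t_cur.expand(x.shape[0]), t_next.expand(x.shape[0]))
    return x

ts = torch.tensor([80.0, 40.0, 20.0, 10.0, 5.0, 2.0, 1.0], dtype=torch.float64)
x_T = torch.ones(2, 4, dtype=torch.float64)
exact = x_T * (ts[-1] / ts[0]) ** 0.1

err_euler = (integrate(TrapezoidalStep(drift, 0), x_T, ts) - exact).abs().max().item()
err_trap = (integrate(TrapezoidalStep(drift, 3), x_T, ts) - exact).abs().max().item()
print(f"Euler: {err_euler:.3e}  trapezoidal: {err_trap:.3e}")
```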

Available Solvers

Solvers implement the \(\text{Step}\) operator in the sampling equation. At each iteration, the solver receives the denoiser output \(D(\mathbf{x}_n, t_n)\) and advances the latent state from time \(t_n\) to \(t_{n-1}\).

There are two ways to use solvers:

Built-in solvers can be selected by passing a string key to sample():

  • "euler": EulerSolver. First-order. Fast (one denoiser evaluation per step) but lower quality.

  • "heun": HeunSolver. Second-order. Higher quality but twice as expensive per step.

  • "edm_stochastic_euler": EDMStochasticEulerSolver. First-order with configurable stochastic noise injection.

  • "edm_stochastic_heun": EDMStochasticHeunSolver. Second-order with configurable stochastic noise injection.
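The stochastic variants build on the stochastic sampler from the EDM paper (Karras et al., 2022). Before each deterministic step, the noise level is temporarily raised from \(t_n\) to \(\hat{t} = (1 + \gamma)\,t_n\) by adding fresh Gaussian noise. The sketch below shows one such Euler step under the step(x, t_cur, t_next) protocol. The parameter names (s_churn, s_tmin, s_tmax, s_noise) follow the paper and are not confirmed to be the constructor arguments of EDMStochasticEulerSolver. The default values are illustrative.

```python
import math

import torch

class StochasticEulerSketch:
    """EDM-style stochastic Euler step (hypothetical; illustrative defaults)."""

    def __init__(self, denoiser, num_steps, s_churn=40.0, s_tmin=0.05,
                 s_tmax=50.0, s_noise=1.003, generator=None):
        self.denoiser = denoiser  # returns the ODE drift dx/dt
        self.gamma = min(s_churn / num_steps, math.sqrt(2.0) - 1.0)
        self.s_tmin, self.s_tmax, self.s_noise = s_tmin, s_tmax, s_noise
        self.generator = generator

    def step(self, x, t_cur, t_next):
        shape = (-1,) + (1,) * (x.ndim - 1)
        # Raise the noise level only where t_cur lies in [s_tmin, s_tmax]
        in_range = (t_cur >= self.s_tmin) & (t_cur <= self.s_tmax)
        t_hat = torch.where(in_range, t_cur * (1.0 + self.gamma), t_cur)
        # Add noise with variance t_hat^2 - t_cur^2 to reach noise level t_hat
        eps = torch.randn(x.shape, generator=self.generator,
                          dtype=x.dtype, device=x.device)
        x_hat = x + self.s_noise * (t_hat ** 2 - t_cur ** 2).sqrt().reshape(shape) * eps
        # Deterministic Euler step from t_hat down to t_next
        h = (t_next - t_hat).reshape(shape)
        return x_hat + h * self.denoiser(x_hat, t_hat)

drift = lambda x, t: (x - 0.9 * x) / t.reshape(-1, *([1] * (x.ndim - 1)))
x = torch.randn(2, 3, 8, 8)
t_cur, t_next = torch.full((2,), 10.0), torch.full((2,), 5.0)

# With s_churn=0 no noise is injected and the step reduces to plain Euler
plain = StochasticEulerSketch(drift, num_steps=50, s_churn=0.0).step(x, t_cur, t_next)
euler = x + (t_next - t_cur).reshape(-1, 1, 1, 1) * drift(x, t_cur)
g = torch.Generator().manual_seed(0)
noisy = StochasticEulerSketch(drift, num_steps=50, generator=g).step(x, t_cur, t_next)
```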

Custom solvers can be defined by implementing the Solver protocol: any object with a step(x, t_cur, t_next) method. Pass the instance directly to sample() for full control over the integration method.
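The cost difference between the first-order and second-order solvers comes from the number of denoiser evaluations per step. The sketch below writes an Euler step and a Heun step against the step(x, t_cur, t_next) protocol and counts calls with a toy denoiser. The class names are hypothetical, and this is not the code of EulerSolver or HeunSolver. In particular, skipping the correction when stepping to \(t = 0\) follows the Heun sampler of the EDM paper and is an assumption about HeunSolver.

```python
import torch

def bc(t, x):
    return t.reshape(-1, *([1] * (x.ndim - 1)))

class EulerSketch:
    def __init__(self, denoiser):
        self.denoiser = denoiser

    def step(self, x, t_cur, t_next):
        # One denoiser evaluation per step
        return x + bc(t_next - t_cur, x) * self.denoiser(x, t_cur)

class HeunSketch(EulerSketch):
    def step(self, x, t_cur, t_next):
        h = bc(t_next - t_cur, x)
        d_cur = self.denoiser(x, t_cur)
        x_euler = x + h * d_cur
        # Skip the correction on the final step to t = 0 (EDM, Algorithm 1)
        if torch.all(t_next == 0):
            return x_euler
        d_next = self.denoiser(x_euler, t_next)  # second evaluation
        return x + 0.5 * h * (d_cur + d_next)

class CountingDenoiser:
    """Toy EDM-style drift (x - 0.9 x) / t that counts its evaluations."""

    def __init__(self):
        self.calls = 0

    def __call__(self, x, t):
        self.calls += 1
        return (x - 0.9 * x) / bc(t, x)

x = torch.randn(4, 3, 8, 8)
t_cur, t_mid, t_end = torch.full((4,), 10.0), torch.full((4,), 5.0), torch.zeros(4)

d_euler, d_heun = CountingDenoiser(), CountingDenoiser()
EulerSketch(d_euler).step(x, t_cur, t_mid)
HeunSketch(d_heun).step(x, t_cur, t_mid)
print(d_euler.calls, d_heun.calls)  # 1 2
HeunSketch(d_heun).step(x, t_mid, t_end)
print(d_heun.calls)  # 3: the final step to t = 0 costs one evaluation
```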

Guidance

Guidance techniques modify the predictor \(P\) in the sampling equation to steer the generated samples toward desired properties, such as consistency with observed data, satisfaction of physical constraints, or other task-specific objectives.

In the framework, guidance operates at the Predictor level: you compose or modify predictors before converting them into a Denoiser \(D\). Concretely, guidance objects implement the DPSGuidance protocol and are combined with an x0-predictor into a guided score-predictor via DPSScorePredictor. The resulting score-predictor is then passed to the noise scheduler’s get_denoiser() factory to obtain a denoiser \(D\) for sampling. The sampler and solver (\(\text{Step}\)) are unchanged—they only see the final denoiser.

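The framework provides two ready-to-use guidance implementations:

  • ModelConsistencyDPSGuidance: DPS guidance for generic (possibly nonlinear) observation operators with Gaussian measurement noise.

  • DataConsistencyDPSGuidance: a simplified variant for masked observations (for example, inpainting, outpainting, or sparse measurements).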

Custom guidances can be defined by implementing the DPSGuidance protocol—any callable with the signature (x, t, x_0) -> guidance_term. Multiple guidances can be combined by passing a list to DPSScorePredictor.
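Unlike the closed-form toy guidances in the API examples, a DPS guidance usually differentiates an observation loss with respect to \(\mathbf{x}\) through \(\hat{\mathbf{x}}_0\). The sketch below shows a custom guidance of that kind, built with torch.autograd.grad. The operator (a channel average), the class name and the toy predictor are hypothetical, and the code is not the library's implementation. With \(\Gamma = 0\) it returns \(-\frac{1}{2\sigma_y^2}\nabla_{\mathbf{x}}\|A(\hat{\mathbf{x}}_0) - \mathbf{y}\|^2\).

```python
import torch

class ChannelMeanGuidance:
    """Hypothetical guidance: the observation is the channel-averaged x_0.

    Returns -1 / (2 std_y^2) * grad_x ||mean_c(x_0) - y||^2, obtained by
    differentiating through x_0 (which must be a function of x).
    """

    def __init__(self, y, std_y):
        self.y = y          # shape (B, *dims[1:]) after averaging over channels
        self.std_y = std_y

    def __call__(self, x, t, x_0):
        residual = x_0.mean(dim=1) - self.y
        # Summing over the batch is safe: each sample's loss depends only on
        # its own x, so the per-sample gradients are unchanged
        loss = residual.pow(2).sum()
        (grad,) = torch.autograd.grad(loss, x)
        return -grad / (2 * self.std_y ** 2)

x = torch.randn(2, 3, 4, 4, requires_grad=True)
x_0 = 0.9 * x  # toy differentiable x0-predictor
y = torch.zeros(2, 4, 4)
g = ChannelMeanGuidance(y, std_y=0.5)(x, torch.ones(2), x_0)
print(g.shape)  # torch.Size([2, 3, 4, 4])
```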

API Reference

Sample Entry Point

sample

Solvers

Solver

EulerSolver

HeunSolver

EDMStochasticEulerSolver

EDMStochasticHeunSolver

Guidance

DPSGuidance

class physicsnemo.diffusion.guidance.DPSGuidance(*args, **kwargs)

Protocol defining the interface for Diffusion Posterior Sampling (DPS) guidance.

A DPS guidance is a callable that computes a guidance term to steer the diffusion sampling process toward satisfying some observation constraint. It returns a quantity analogous to a likelihood score, which is typically added to the unconditional score during sampling.

The typical form is:

\[\rho(t) \nabla_{\mathbf{x}} \ell(A(\hat{\mathbf{x}}_0) - \mathbf{y})\]

where \(\rho(t)\) is a time-dependent guidance strength, \(A\) is a (potentially nonlinear) observation operator, \(\mathbf{y}\) is the observed data, and \(\ell\) is a scalar loss function. However, variants are possible as long as the guidance produces a quantity similar to a score (e.g., a likelihood score).

This is the minimal interface for guidance, and any object that implements this interface can be used with DPSScorePredictor to build a guided score-predictor, which implements the Predictor interface.
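The isinstance check in the examples below suggests that DPSGuidance is a runtime-checkable typing.Protocol, although that is an assumption here. The sketch below defines a hypothetical stand-in to illustrate a practical consequence: a runtime protocol check only verifies that a __call__ method exists. It does not verify the (x, t, x_0) signature, so an object with the wrong arity still passes.

```python
from typing import Protocol, runtime_checkable

@runtime_checkable
class GuidanceLike(Protocol):
    """Hypothetical stand-in for a DPSGuidance-style protocol."""

    def __call__(self, x, t, x_0): ...

class Good:
    def __call__(self, x, t, x_0):
        return x_0

class NotCallable:
    pass

print(isinstance(Good(), GuidanceLike))         # True
print(isinstance(NotCallable(), GuidanceLike))  # False
print(isinstance(lambda x: x, GuidanceLike))    # True: arity is not checked
```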

See also

DPSScorePredictor

Combines an x0-predictor with one or more guidances.

Examples

Example 1: Minimal guidance for inpainting. Given a binary mask and observed pixels, guide the diffusion to match observations:

>>> import torch
>>> from physicsnemo.diffusion.guidance import DPSGuidance
>>>
>>> class InpaintingGuidance:
...     def __init__(self, mask, y_obs, gamma=1.0):
...         self.mask = mask  # Binary mask: 1 = observed, 0 = missing
...         self.y_obs = y_obs  # Observed pixel values
...         self.gamma = gamma
...
...     def __call__(self, x, t, x_0):
...         # Compute residual at observed locations
...         residual = self.mask * (x_0 - self.y_obs)
...         # Gradient of 0.5 * ||residual||^2 w.r.t. x_0 is the residual
...         # (simplified: identity observation operator, and the Jacobian
...         # of x_0 w.r.t. x is ignored)
...         return -self.gamma * residual
...
>>> mask = torch.ones(1, 3, 8, 8)
>>> y_obs = torch.randn(1, 3, 8, 8)
>>> guidance = InpaintingGuidance(mask, y_obs)
>>> isinstance(guidance, DPSGuidance)
True

Example 2: Building a guided score predictor from scratch. A common pattern is to combine an x0-predictor with a guidance to create a score predictor that can be used for sampling. This shows the complete workflow:

>>> import torch
>>> from physicsnemo.diffusion.guidance import DPSGuidance
>>>
>>> # Define a guidance that pushes toward observed values
>>> class MyGuidance:
...     def __init__(self, y_obs, gamma=0.1):
...         self.y_obs = y_obs
...         self.gamma = gamma
...
...     def __call__(self, x, t, x_0):
...         return -self.gamma * (x_0 - self.y_obs)
...
>>> # Toy x0-predictor (in practice, a trained neural network)
>>> x0_predictor = lambda x, t: x * 0.9
>>> y_obs = torch.randn(1, 3, 8, 8)
>>> guidance = MyGuidance(y_obs, gamma=0.5)
>>>
>>> # Build a guided score predictor that combines x0-predictor + guidance
>>> def guided_score_predictor(x, t):
...     x_0 = x0_predictor(x, t)
...     guidance_term = guidance(x, t, x_0)
...     # Convert x0 to score (for EDM: score = (x_0 - x) / t^2)
...     expected_shape = (-1,) + (1,) * (x.ndim - 1)
...     t_bc = t.reshape(expected_shape)
...     score = (x_0 - x) / (t_bc ** 2)
...     return score + guidance_term
...
>>> # guided_score_predictor is now a Predictor (score predictor); pass it
>>> # to scheduler.get_denoiser(score_predictor=...) to obtain a Denoiser
>>> x = torch.randn(1, 3, 8, 8)
>>> t = torch.tensor([1.0])
>>> output = guided_score_predictor(x, t)
>>> output.shape
torch.Size([1, 3, 8, 8])

Note: DPSScorePredictor provides a convenient way to apply one or more guidances to an x0-predictor without manually implementing the above pattern.

DPSScorePredictor

class physicsnemo.diffusion.guidance.DPSScorePredictor(
x0_predictor: Predictor,
x0_to_score_fn: Callable[[Float[Tensor, 'B *dims'], Float[Tensor, 'B *dims'], Float[Tensor, 'B']], Float[Tensor, 'B *dims']],
guidances: DPSGuidance | Sequence[DPSGuidance],
)

Bases: Predictor

Score predictor that combines an x0-predictor with DPS-style guidance.

This class transforms a Predictor (specifically, an x0-predictor) into a score Predictor by applying one or more DPS guidances. The resulting score predictor can be passed to get_denoiser() to obtain a Denoiser for sampling.

The output is the sum of the unconditional score (derived from the x0-prediction) and all guidance terms:

\[\nabla_{\mathbf{x}} \log p(\mathbf{x}) + \sum_i g_i(\mathbf{x}, t, \hat{\mathbf{x}}_0)\]

where \(g_i\) are the guidance terms implementing the DPSGuidance interface.

Each guidance must implement the DPSGuidance protocol, which is a callable with the following signature:

def guidance(
    x: Tensor,    # shape: (B, *dims)
    t: Tensor,    # shape: (B,)
    x_0: Tensor,  # shape: (B, *dims)
) -> Tensor: ...  # guidance term, shape: (B, *dims)

Important

When using multiple guidances that internally call torch.autograd.grad (e.g., ModelConsistencyDPSGuidance or DataConsistencyDPSGuidance), each guidance except the last must be constructed with retain_graph=True. Otherwise the computational graph is destroyed after the first guidance computes its gradient and subsequent guidances will fail. With a single guidance this is not needed.

Parameters:
  • x0_predictor (Predictor) – A Predictor that takes (x, t) and returns an estimate of the clean data \(\hat{\mathbf{x}}_0\). Typically obtained from a trained DiffusionModel via functools.partial.

  • x0_to_score_fn (Callable[[Tensor, Tensor, Tensor], Tensor]) – Callback to convert x0-prediction to score. Signature: x0_to_score_fn(x_0, x, t) -> score. Typically obtained from a noise scheduler, e.g., x0_to_score().

  • guidances (DPSGuidance | Sequence[DPSGuidance]) – One or more guidance objects implementing the DPSGuidance interface.

See also

DPSGuidance

Protocol for guidance implementations.

Predictor

Protocol satisfied by this class.

get_denoiser()

Converts the score-predictor to a denoiser for sampling.

Examples

Example 1: Basic usage with a single guidance for inpainting:

>>> import torch
>>> from physicsnemo.diffusion.guidance import DPSScorePredictor, DPSGuidance
>>>
>>> # Toy x0-predictor (in practice, this is a trained neural network)
>>> x0_predictor = lambda x, t: x * 0.9
>>>
>>> # Simple x0_to_score function (for EDM: score = (x_0 - x) / t^2)
>>> def x0_to_score_fn(x_0, x, t):
...     expected_shape = (-1,) + (1,) * (x.ndim - 1)
...     t_bc = t.reshape(expected_shape)
...     return (x_0 - x) / (t_bc ** 2)
...
>>> # Simple inpainting guidance
>>> class InpaintGuidance:
...     def __init__(self, mask, y_obs, gamma=1.0):
...         self.mask = mask
...         self.y_obs = y_obs
...         self.gamma = gamma
...     def __call__(self, x, t, x_0):
...         return -self.gamma * self.mask * (x_0 - self.y_obs)
...
>>> mask = torch.ones(1, 3, 8, 8)
>>> y_obs = torch.randn(1, 3, 8, 8)
>>> guidance = InpaintGuidance(mask, y_obs)
>>>
>>> # Create DPS score predictor
>>> dps_score_pred = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=x0_to_score_fn,
...     guidances=guidance,
... )
>>>
>>> # Use in sampling
>>> x = torch.randn(1, 3, 8, 8)
>>> t = torch.tensor([1.0])
>>> output = dps_score_pred(x, t)
>>> output.shape
torch.Size([1, 3, 8, 8])

Example 2: Multiple guidances for multi-constraint problems:

>>> import torch
>>> from physicsnemo.diffusion.guidance import DPSScorePredictor
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>>
>>> # Use scheduler to get x0_to_score_fn
>>> scheduler = EDMNoiseScheduler()
>>> x0_predictor = lambda x, t: x * 0.9
>>>
>>> # Guidance 1: match observed values at specific locations
>>> class ObservationGuidance:
...     def __init__(self, mask, y_obs, gamma=1.0):
...         self.mask = mask
...         self.y_obs = y_obs
...         self.gamma = gamma
...     def __call__(self, x, t, x_0):
...         return -self.gamma * self.mask * (x_0 - self.y_obs)
...
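>>> # Guidance 2: regularization toward a zero per-sample mean
>>> class ZeroMeanGuidance:
...     def __init__(self, gamma=0.1):
...         self.gamma = gamma
...     def __call__(self, x, t, x_0):
...         # Mean over all non-batch dims, so samples are not coupled
...         dims = tuple(range(1, x_0.ndim))
...         return -self.gamma * x_0.mean(dim=dims, keepdim=True) * torch.ones_like(x_0)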
...
>>> mask = torch.ones(1, 3, 8, 8)
>>> y_obs = torch.randn(1, 3, 8, 8)
>>> guidance1 = ObservationGuidance(mask, y_obs)
>>> guidance2 = ZeroMeanGuidance()
>>>
>>> # Combine multiple guidances
>>> dps_score_pred = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=scheduler.x0_to_score,
...     guidances=[guidance1, guidance2],
... )
>>>
>>> x = torch.randn(2, 3, 8, 8)
>>> t = torch.tensor([1.0, 1.0])
>>> output = dps_score_pred(x, t)
>>> output.shape
torch.Size([2, 3, 8, 8])

Example 3: Multiple autograd-based guidances require retain_graph=True on all but the last:

>>> import torch
>>> from physicsnemo.diffusion.guidance import (
...     DPSScorePredictor,
...     DataConsistencyDPSGuidance,
... )
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>>
>>> scheduler = EDMNoiseScheduler()
>>> x0_predictor = lambda x, t: x * 0.9
>>>
>>> mask1 = torch.zeros(1, 3, 8, 8, dtype=torch.bool)
>>> mask1[:, :, 2, 3] = True
>>> mask2 = torch.zeros(1, 3, 8, 8, dtype=torch.bool)
>>> mask2[:, :, 5, 6] = True
>>> y_obs = torch.randn(1, 3, 8, 8)
>>>
>>> # First guidance retains the graph for the second one
>>> g1 = DataConsistencyDPSGuidance(
...     mask=mask1, y=y_obs, std_y=0.1, retain_graph=True,
... )
>>> # Last guidance does not need retain_graph
>>> g2 = DataConsistencyDPSGuidance(
...     mask=mask2, y=y_obs, std_y=0.1,
... )
>>>
>>> dps = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=scheduler.x0_to_score,
...     guidances=[g1, g2],
... )
>>> x = torch.randn(1, 3, 8, 8)
>>> t = torch.tensor([1.0])
>>> dps(x, t).shape
torch.Size([1, 3, 8, 8])
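The retain_graph requirement comes from plain PyTorch autograd: the first call to torch.autograd.grad frees the buffers saved in the shared part of the graph, which is the computation of \(\hat{\mathbf{x}}_0\). The sketch below uses a toy tanh x0-predictor and two illustrative observation losses to show that the second gradient fails unless the first call retains the graph.

```python
import torch

def two_guidance_grads(retain_first):
    torch.manual_seed(0)
    x = torch.randn(1, 3, 8, 8, requires_grad=True)
    x_0 = torch.tanh(x)  # toy x0-predictor; tanh saves its output for backward
    # Two observation losses that share the same x_0 graph
    loss1 = (x_0[:, :, 2, 3] - 0.5).pow(2).sum()
    loss2 = (x_0[:, :, 5, 6] + 0.5).pow(2).sum()
    (g1,) = torch.autograd.grad(loss1, x, retain_graph=retain_first)
    (g2,) = torch.autograd.grad(loss2, x)  # needs tanh's saved output again
    return g1 + g2

combined = two_guidance_grads(retain_first=True)
print(combined.shape)  # torch.Size([1, 3, 8, 8])

try:
    two_guidance_grads(retain_first=False)
    failed = False
except RuntimeError:
    failed = True  # "Trying to backward through the graph a second time ..."
print(failed)  # True
```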

ModelConsistencyDPSGuidance

class physicsnemo.diffusion.guidance.ModelConsistencyDPSGuidance(
observation_operator: Callable[[Float[Tensor, 'B *dims']], Float[Tensor, 'B *obs_dims']],
y: Float[Tensor, 'B *obs_dims'],
std_y: float,
norm: int | Callable[[Float[Tensor, 'B *obs_dims'], Float[Tensor, 'B *obs_dims']], Float[Tensor, 'B']] = 2,
gamma: float = 0.0,
sigma_fn: Callable[[Float[Tensor, '*shape']], Float[Tensor, '*shape']] | None = None,
alpha_fn: Callable[[Float[Tensor, '*shape']], Float[Tensor, '*shape']] | None = None,
retain_graph: bool = False,
create_graph: bool = False,
)

Bases: DPSGuidance

DPS guidance for generic observation models with Gaussian noise.

Implements the DPSGuidance interface for generic (possibly nonlinear) observation operators.

Computes the likelihood score assuming Gaussian measurement noise with standard deviation std_y. The guidance term is:

\[\nabla_{\mathbf{x}} \log p(\mathbf{y} | \mathbf{x}_t) = -\frac{1}{2 \left( \sigma_y^2 + \Gamma \frac{\sigma(t)^2}{\alpha(t)^2} \right)} \nabla_{\mathbf{x}} \| A\left(\hat{\mathbf{x}}_0\right) - \mathbf{y} \|^2\]

where \(A\) is the observation operator and the scaling incorporates a Score-Based Data Assimilation (SDA) correction through the parameter \(\Gamma\) that accounts for the covariance of the \(\hat{\mathbf{x}}_0(\mathbf{x}_t, t)\) estimate at different diffusion times. The L2 norm can be replaced by other Lp norms or custom loss functions via the norm parameter.

The observation_operator must be a differentiable callable with the following signature:

def observation_operator(
    x_0: Tensor,  # shape: (B, *dims)
) -> Tensor: ...  # predicted observations, shape: (B, *obs_dims)

When norm is a callable, it must have the following signature:

def norm(
    y_pred: Tensor,  # shape: (B, *obs_dims)
    y_true: Tensor,  # shape: (B, *obs_dims)
) -> Tensor: ...    # scalar loss per batch element, shape: (B,)

Parameters:
  • observation_operator (Callable[[Tensor], Tensor]) – Observation operator mapping clean state to observations. Must be differentiable (supports torch.autograd).

  • y (Tensor) – Observed data of shape \((B, *obs\_dims)\) matching the output of A.

  • std_y (float) – Standard deviation of the measurement noise \(\sigma_y\).

  • norm (int | Callable[[Tensor, Tensor], Tensor], default=2) – Loss function used to compute the residual. An int value (default 2) uses the corresponding Lp norm. A callable receives (y_pred, y_true) and returns a scalar loss per batch element of shape \((B,)\).

  • gamma (float, default=0.0) – SDA covariance scaling factor \(\Gamma\). When gamma > 0, applies SDA correction that accounts for the covariance of the \(\hat{\mathbf{x}}_0\) estimate at different noise levels. Set to 0 for classical DPS without SDA scaling.

  • sigma_fn (Callable[[Tensor], Tensor] | None, default=None) – Function mapping diffusion time to noise level \(\sigma(t)\). Required when gamma > 0. Typically obtained from a noise scheduler, e.g., sigma() for a linear-Gaussian noise schedule.

  • alpha_fn (Callable[[Tensor], Tensor] | None, default=None) – Function mapping diffusion time to signal coefficient \(\alpha(t)\). Optional; defaults to \(\alpha(t) = 1\) if not provided. Typically obtained from a noise scheduler, e.g., alpha() for a linear-Gaussian noise schedule.

  • retain_graph (bool, default=False) – If True, the computational graph is retained after computing gradients. Required when combining multiple autograd-based guidances in a single DPSScorePredictor — all guidances except the last must set this to True.

  • create_graph (bool, default=False) – If True, a graph of the derivative is constructed, allowing higher-order derivatives (e.g., differentiating through the entire sampling process).

See also

DataConsistencyDPSGuidance

Simplified guidance for masked observations.

DPSScorePredictor

Combines an x0-predictor with one or more guidances.

Examples

Example 1: Super-resolution with a blur + downsampling observation operator (linear here, although any differentiable, possibly nonlinear, operator is supported):

>>> import torch
>>> import torch.nn.functional as F
>>> from physicsnemo.diffusion.guidance import (
...     ModelConsistencyDPSGuidance,
...     DPSScorePredictor,
... )
>>>
>>> # Observation operator: 3x3 box blur + 2x downsampling
>>> def blur_downsample(x):
...     # Apply 3x3 box (uniform) blur
...     kernel = torch.ones(1, 1, 3, 3, device=x.device) / 9
...     kernel = kernel.expand(x.shape[1], 1, 3, 3)
...     blurred = F.conv2d(x, kernel, padding=1, groups=x.shape[1])
...     # Downsample 2x
...     return F.avg_pool2d(blurred, kernel_size=2, stride=2)
...
>>> # Low-resolution observations (4x4 from 8x8 high-res)
>>> y_obs = torch.randn(1, 3, 4, 4)
>>>
>>> guidance = ModelConsistencyDPSGuidance(
...     observation_operator=blur_downsample,
...     y=y_obs,
...     std_y=0.1,
... )
>>>
>>> # Use in DPS sampling
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> t = torch.tensor([1.0])
>>> x_0 = x * 0.9  # Toy x0 estimate
>>> output = guidance(x, t, x_0)
>>> output.shape
torch.Size([1, 3, 8, 8])
>>>
>>> # Combine with DPSScorePredictor for complete sampling workflow
>>> x0_predictor = lambda x, t: x * 0.9
>>> def x0_to_score_fn(x_0, x, t):
...     expected_shape = (-1,) + (1,) * (x.ndim - 1)
...     t_bc = t.reshape(expected_shape)
...     return (x_0 - x) / (t_bc ** 2)
...
>>> dps_score_pred = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=x0_to_score_fn,
...     guidances=guidance,
... )
>>> score = dps_score_pred(x, t)
>>> score.shape
torch.Size([1, 3, 8, 8])

Example 2: With SDA scaling using noise scheduler methods:

>>> import torch
>>> from physicsnemo.diffusion.guidance import (
...     ModelConsistencyDPSGuidance,
...     DPSScorePredictor,
... )
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>>
>>> scheduler = EDMNoiseScheduler()
>>>
>>> # Linear observation operator (select first channel)
>>> A = lambda x: x[:, :1]
>>> y_obs = torch.randn(1, 1, 8, 8)
>>>
>>> # Enable SDA scaling with gamma > 0, providing sigma and alpha functions
>>> guidance = ModelConsistencyDPSGuidance(
...     observation_operator=A,
...     y=y_obs,
...     std_y=0.075,
...     gamma=0.05,  # Enable SDA scaling
...     sigma_fn=scheduler.sigma,
...     alpha_fn=scheduler.alpha,
... )
>>>
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> t = torch.tensor([1.0])
>>> x_0 = x * 0.9
>>> output = guidance(x, t, x_0)
>>> output.shape
torch.Size([1, 3, 8, 8])
>>>
>>> # Use with DPSScorePredictor and scheduler's x0_to_score
>>> x0_predictor = lambda x, t: x * 0.9
>>> dps_score_pred = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=scheduler.x0_to_score,
...     guidances=guidance,
... )
>>> score = dps_score_pred(x, t)
>>> score.shape
torch.Size([1, 3, 8, 8])

Example 3: With a custom loss function (Huber loss):

>>> import torch
>>> import torch.nn.functional as F
>>> from physicsnemo.diffusion.guidance import ModelConsistencyDPSGuidance
>>>
>>> # Wrap torch's Huber loss to return per-batch scalars
>>> def huber_loss(y_pred, y_true):
...     per_elem = F.huber_loss(y_pred, y_true, reduction="none")
...     return per_elem.reshape(y_pred.shape[0], -1).sum(dim=1)
...
>>> A = lambda x: x[:, :1]  # Select first channel
>>> y_obs = torch.randn(1, 1, 8, 8)
>>>
>>> guidance = ModelConsistencyDPSGuidance(
...     observation_operator=A,
...     y=y_obs,
...     std_y=0.1,
...     norm=huber_loss,  # Custom loss function
... )
>>>
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> t = torch.tensor([1.0])
>>> x_0 = x * 0.9
>>> output = guidance(x, t, x_0)
>>> output.shape
torch.Size([1, 3, 8, 8])

DataConsistencyDPSGuidance

class physicsnemo.diffusion.guidance.DataConsistencyDPSGuidance(
mask: Bool[Tensor, 'B *dims'],
y: Float[Tensor, 'B *dims'],
std_y: float,
norm: int | Callable[[Float[Tensor, 'B *dims'], Float[Tensor, 'B *dims']], Float[Tensor, 'B']] = 2,
gamma: float = 0.0,
sigma_fn: Callable[[Float[Tensor, '*shape']], Float[Tensor, '*shape']] | None = None,
alpha_fn: Callable[[Float[Tensor, '*shape']], Float[Tensor, '*shape']] | None = None,
retain_graph: bool = False,
create_graph: bool = False,
)

Bases: DPSGuidance

DPS guidance for masked observations with Gaussian noise.

Implements the DPSGuidance interface for masked observation operators. It is a simplified version of ModelConsistencyDPSGuidance in which the observation operator is an element-wise mask. Masked observations are typical of data assimilation tasks such as inpainting, outpainting, or sparse observations, where measurements are available only at specific locations.

Computes the likelihood score assuming Gaussian measurement noise with standard deviation std_y. The guidance term is:

\[\nabla_{\mathbf{x}} \log p(\mathbf{y} | \mathbf{x}_t) = -\frac{1}{2 \left( \sigma_y^2 + \Gamma \frac{\sigma(t)^2}{\alpha(t)^2} \right)} \nabla_{\mathbf{x}} \| \mathbf{M} \odot (\hat{\mathbf{x}}_0 - \mathbf{y}) \|^2\]

where \(\mathbf{M}\) is a binary mask (1 = observed, 0 = missing), \(\odot\) denotes element-wise multiplication, and the scaling incorporates an SDA correction through the parameter \(\Gamma\). The L2 norm can be replaced by other Lp norms or custom loss functions via the norm parameter.
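For the classical case \(\Gamma = 0\) with the squared L2 norm, the equation above can be evaluated directly with `torch.autograd.grad`. The sketch below is a minimal illustration, not the class's implementation; `masked_dps_guidance` is a hypothetical function name, and the toy estimate \(\hat{\mathbf{x}}_0 = 0.9\,\mathbf{x}\) mirrors the examples below.

```python
import torch

def masked_dps_guidance(x, x_0, mask, y, std_y):
    # Squared L2 residual on observed entries only, one scalar per sample
    m = mask.to(x_0.dtype)
    residual = (m * (x_0 - y)).reshape(x.shape[0], -1)
    loss = (residual ** 2).sum(dim=1)
    # Gradient w.r.t. x flows through x_0, which must be computed from x;
    # samples are independent, so differentiating the batch sum is enough
    grad = torch.autograd.grad(loss.sum(), x)[0]
    # Classical DPS prefactor (Gamma = 0): -1 / (2 * sigma_y^2)
    return -grad / (2.0 * std_y ** 2)

mask = torch.zeros(1, 3, 8, 8, dtype=torch.bool)
mask[:, :, 2, 3] = True
y = torch.randn(1, 3, 8, 8)
x = torch.randn(1, 3, 8, 8, requires_grad=True)
x_0 = x * 0.9  # toy x0 estimate
g = masked_dps_guidance(x, x_0, mask, y, std_y=0.1)
```

Because the mask zeroes the residual at unobserved locations, the guidance is exactly zero there.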

When norm is a callable, it must have the following signature:

def norm(
    y_pred: Tensor,  # shape: (B, *obs_dims)
    y_true: Tensor,  # shape: (B, *obs_dims)
) -> Tensor: ...    # scalar loss per batch element, shape: (B,)
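A minimal callable satisfying this contract, written to match the squared L2 residual in the equation above, could look like the following. The name `squared_l2` is hypothetical; the only requirement taken from this page is the input and output shapes.

```python
import torch
from torch import Tensor

def squared_l2(y_pred: Tensor, y_true: Tensor) -> Tensor:
    # Flatten all non-batch dimensions, then reduce per sample
    diff = (y_pred - y_true).reshape(y_pred.shape[0], -1)
    return (diff ** 2).sum(dim=1)  # shape: (B,)

y_pred = torch.randn(4, 3, 8, 8)
y_true = torch.randn(4, 3, 8, 8)
loss = squared_l2(y_pred, y_true)
```

Reducing over a flattened view keeps the callable independent of how many spatial dimensions the state has.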

Parameters:
  • mask (Tensor) – Boolean mask of shape \((B, *)\) matching the state shape. True for observed locations, False for missing.

  • y (Tensor) – Observed data of shape \((B, *)\) matching the state shape. Values at unobserved locations (where mask is False) are ignored.

  • std_y (float) – Standard deviation of the measurement noise \(\sigma_y\).

  • norm (int | Callable[[Tensor, Tensor], Tensor], default=2) – Loss function used to compute the residual. An int value (default 2) uses the corresponding Lp norm. A callable receives (mask * x_0, mask * y) and returns a scalar loss per batch element of shape \((B,)\).

  • gamma (float, default=0.0) – SDA covariance scaling factor \(\Gamma\). When gamma > 0, applies the SDA correction, which accounts for the covariance of the \(\hat{\mathbf{x}}_0\) estimate at different noise levels. Set to 0 for classical DPS without SDA scaling.

  • sigma_fn (Callable[[Tensor], Tensor] | None, default=None) – Function mapping diffusion time to noise level \(\sigma(t)\). Required when gamma > 0. Typically obtained from a noise scheduler. For example, use sigma() for a linear-Gaussian noise schedule.

  • alpha_fn (Callable[[Tensor], Tensor] | None, default=None) – Function mapping diffusion time to signal coefficient \(\alpha(t)\). Optional; defaults to \(\alpha(t) = 1\) if not provided. For example, use alpha() for a linear-Gaussian noise schedule.

  • retain_graph (bool, default=False) – If True, the computational graph is retained after computing gradients. Required when combining multiple autograd-based guidances in a single DPSScorePredictor — all guidances except the last must set this to True.

  • create_graph (bool, default=False) – If True, a graph of the derivative is constructed, allowing higher-order derivatives (e.g., differentiating through the entire sampling process).
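The retain_graph requirement comes from PyTorch itself: two gradients taken from losses that share one autograd graph need the graph kept alive after the first call. The plain-PyTorch sketch below illustrates this with two hypothetical residual losses standing in for two guidances; it does not reproduce DPSScorePredictor internals.

```python
import torch

x = torch.randn(2, 3, 8, 8, requires_grad=True)
x_0 = torch.tanh(x) * 0.9  # toy x0 estimate whose graph is shared by both losses

# Two hypothetical guidance losses built on the same x_0
loss_a = (x_0[:, :1] ** 2).sum()
loss_b = (x_0[:, 1:] - 1.0).abs().sum()

# First gradient: keep the shared graph for the next call
grad_a = torch.autograd.grad(loss_a, x, retain_graph=True)[0]
# Last gradient: the graph may be freed afterwards
grad_b = torch.autograd.grad(loss_b, x)[0]

# Summing per-loss gradients equals the gradient of the summed loss
combined = grad_a + grad_b
```

Without `retain_graph=True` on the first call, the buffers saved by `tanh` are freed and the second call fails with a RuntimeError.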

See also

ModelConsistencyDPSGuidance

Guidance for general observation operators.

DPSScorePredictor

Combines an x0-predictor with one or more guidances.

Examples

Example 1: Sparse observations at probe locations:

>>> import torch
>>> from physicsnemo.diffusion.guidance import (
...     DataConsistencyDPSGuidance,
...     DPSScorePredictor,
... )
>>>
>>> # Boolean mask: only observe a few probe locations
>>> mask = torch.zeros(1, 3, 8, 8, dtype=torch.bool)
>>> mask[:, :, 2, 3] = True  # Probe at (2, 3)
>>> mask[:, :, 5, 6] = True  # Probe at (5, 6)
>>> mask[:, :, 1, 7] = True  # Probe at (1, 7)
>>> y_obs = torch.randn(1, 3, 8, 8)  # Observed values
>>>
>>> guidance = DataConsistencyDPSGuidance(
...     mask=mask,
...     y=y_obs,
...     std_y=0.1,
... )
>>>
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> t = torch.tensor([1.0])
>>> x_0 = x * 0.9  # Toy x0 estimate (must be computed from x)
>>> output = guidance(x, t, x_0)
>>> output.shape
torch.Size([1, 3, 8, 8])
>>>
>>> # Use with DPSScorePredictor for complete sampling workflow
>>> x0_predictor = lambda x, t: x * 0.9
>>> def x0_to_score_fn(x_0, x, t):
...     expected_shape = (-1,) + (1,) * (x.ndim - 1)
...     t_bc = t.reshape(expected_shape)
...     return (x_0 - x) / (t_bc ** 2)
...
>>> dps_score_pred = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=x0_to_score_fn,
...     guidances=guidance,
... )
>>> score = dps_score_pred(x, t)
>>> score.shape
torch.Size([1, 3, 8, 8])

Example 2: With SDA scaling and L1 norm using noise scheduler:

>>> import torch
>>> from physicsnemo.diffusion.guidance import (
...     DataConsistencyDPSGuidance,
...     DPSScorePredictor,
... )
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>>
>>> scheduler = EDMNoiseScheduler()
>>>
>>> # Same sparse probe locations as Example 1
>>> mask = torch.zeros(1, 3, 8, 8, dtype=torch.bool)
>>> mask[:, :, 2, 3] = True
>>> mask[:, :, 5, 6] = True
>>> mask[:, :, 1, 7] = True
>>> y_obs = torch.randn(1, 3, 8, 8)
>>>
>>> # Enable SDA scaling and use L1 norm for robustness
>>> guidance = DataConsistencyDPSGuidance(
...     mask=mask,
...     y=y_obs,
...     std_y=0.075,
...     norm=1,  # L1 norm
...     gamma=1.0,  # Enable SDA scaling
...     sigma_fn=scheduler.sigma,
...     alpha_fn=scheduler.alpha,
... )
>>>
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> t = torch.tensor([1.0])
>>> x_0 = x * 0.9  # Must be computed from x
>>> output = guidance(x, t, x_0)
>>> output.shape
torch.Size([1, 3, 8, 8])
>>>
>>> # Use with DPSScorePredictor and scheduler's x0_to_score
>>> x0_predictor = lambda x, t: x * 0.9
>>> dps_score_pred = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=scheduler.x0_to_score,
...     guidances=guidance,
... )
>>> score = dps_score_pred(x, t)
>>> score.shape
torch.Size([1, 3, 8, 8])
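The effect of gamma in this example can be read off the prefactor of the guidance equation, \(1 / \bigl(2(\sigma_y^2 + \Gamma\,\sigma(t)^2/\alpha(t)^2)\bigr)\). The pure-Python sketch below evaluates its magnitude; the defaults \(\sigma(t) = t\) and \(\alpha(t) = 1\) are a hypothetical EDM-style choice, not necessarily what `scheduler.sigma` and `scheduler.alpha` return, and `sda_guidance_weight` is a hypothetical name.

```python
def sda_guidance_weight(t, std_y, gamma, sigma_fn=lambda t: t, alpha_fn=lambda t: 1.0):
    # Effective variance: sigma_y^2 + Gamma * sigma(t)^2 / alpha(t)^2
    var = std_y ** 2 + gamma * sigma_fn(t) ** 2 / alpha_fn(t) ** 2
    # Magnitude of the prefactor multiplying the residual gradient
    return 1.0 / (2.0 * var)

for t in (0.01, 0.1, 1.0, 10.0):
    classical = sda_guidance_weight(t, std_y=0.075, gamma=0.0)
    sda = sda_guidance_weight(t, std_y=0.075, gamma=1.0)
    print(f"t={t:5}: classical={classical:8.3f}  sda={sda:8.3f}")
```

With gamma = 0 the weight is constant in t. With gamma > 0 it shrinks at high noise levels, where the \(\hat{\mathbf{x}}_0\) estimate is least reliable, and approaches the classical value as \(t \to 0\).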

Example 3: With a custom loss function (Huber loss):

>>> import torch
>>> import torch.nn.functional as F
>>> from physicsnemo.diffusion.guidance import DataConsistencyDPSGuidance
>>>
>>> # Wrap torch's Huber loss to return per-batch scalars
>>> def huber_loss(y_pred, y_true):
...     per_elem = F.huber_loss(y_pred, y_true, reduction="none")
...     return per_elem.reshape(y_pred.shape[0], -1).sum(dim=1)
...
>>> mask = torch.zeros(1, 3, 8, 8, dtype=torch.bool)
>>> mask[:, :, 2, 3] = True
>>> mask[:, :, 5, 6] = True
>>> y_obs = torch.randn(1, 3, 8, 8)
>>>
>>> guidance = DataConsistencyDPSGuidance(
...     mask=mask,
...     y=y_obs,
...     std_y=0.1,
...     norm=huber_loss,  # Custom loss function
... )
>>>
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> t = torch.tensor([1.0])
>>> x_0 = x * 0.9
>>> output = guidance(x, t, x_0)
>>> output.shape
torch.Size([1, 3, 8, 8])
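Because the norm parameter description states that a callable receives (mask * x_0, mask * y), a custom loss does not need to handle the mask itself: unobserved entries are zero in both inputs and contribute nothing. The sketch below checks this for the Huber wrapper, assuming that masking behavior as documented above.

```python
import torch
import torch.nn.functional as F

def huber_loss(y_pred, y_true):
    per_elem = F.huber_loss(y_pred, y_true, reduction="none")
    return per_elem.reshape(y_pred.shape[0], -1).sum(dim=1)

mask = torch.zeros(2, 3, 8, 8, dtype=torch.bool)
mask[:, :, 2, 3] = True
mask[:, :, 5, 6] = True
x_0 = torch.randn(2, 3, 8, 8)
y = torch.randn(2, 3, 8, 8)

# Masked inputs, as described for the `norm` parameter
m = mask.to(x_0.dtype)
masked_loss = huber_loss(m * x_0, m * y)

# Same loss computed explicitly over observed entries only
observed_only = torch.stack([
    F.huber_loss(x_0[b][mask[b]], y[b][mask[b]], reduction="sum")
    for b in range(x_0.shape[0])
])
```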