Samplers and Solvers
The sampler is the main interface for generating new data from a trained diffusion model. Starting from pure noise \(\mathbf{x}_N\), the solver iteratively denoises the latent state through a sequence of time-steps until it reaches a clean sample \(\mathbf{x}_0\).
The central entry point is the sample() function. It takes a
Denoiser, an initial noisy latent
\(\mathbf{x}_N\), and a
NoiseScheduler, and iterates
the reverse process to produce samples.
Generic Sampling Process
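The sample() function supports any reverse process that can be written
in the form:
\[\mathbf{x}_{n-1} = \text{Step}\big(D, \mathbf{x}_n, t_n, t_{n-1}\big), \qquad D = \text{get\_denoiser}(P),\]
applied for \(n = N, \dots, 1\).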
This equation is the foundation of the sampling process in the framework. Every component described on this page maps to one of its three terms:

- \(P\) is the predictor: the Predictor that maps the noisy state and diffusion time to a prediction. This is where all model logic lives, including conditioning and guidance.
- \(D\) is the denoiser: the Denoiser derived from \(P\) via the noise scheduler's get_denoiser() factory.
- \(\text{Step}\) is the solver's numerical update rule.
This generic formulation covers standard ODE/SDE-based sampling, as well as
more advanced methods such as physics-informed posterior guidance with
diffusion posterior sampling (DPS), score-based data assimilation (SDA), and
others. Any method that can express its update step through this
denoiser/solver decomposition can be used with the sample() function.
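The decomposition can be illustrated with a small, self-contained NumPy sketch of the loop. This is not the library's sample() implementation. It assumes an EDM-style schedule with sigma(t) = t, and it assumes the denoiser \(D\) returns the probability-flow ODE direction (x - x0_hat) / t. That assumption matches how the custom-solver example on this page advances the state with x + h * D. The toy predictor is the exact posterior mean for Gaussian data N(0, s^2), so the result can be checked against the exact ODE solution. All function names here are hypothetical.

```python
import numpy as np

S_DATA = 0.5  # std of the toy Gaussian data distribution N(0, S_DATA^2)


def broadcast_time(t, x):
    """Reshape per-sample times of shape (B,) to broadcast against x of shape (B, *dims)."""
    return t.reshape(-1, *([1] * (x.ndim - 1)))


def toy_x0_predictor(x, t):
    """Predictor P: exact E[x_0 | x_t] for Gaussian data when x_t = x_0 + t * noise."""
    t_bc = broadcast_time(t, x)
    return S_DATA**2 / (S_DATA**2 + t_bc**2) * x


def make_denoiser(x0_predictor):
    """Denoiser D (assumed convention): the ODE direction dx/dt = (x - x0_hat) / t."""
    def denoiser(x, t):
        return (x - x0_predictor(x, t)) / broadcast_time(t, x)
    return denoiser


def euler_step(denoiser, x, t_cur, t_next):
    """Step: one explicit Euler update from t_cur to t_next."""
    h = broadcast_time(t_next - t_cur, x)
    return x + h * denoiser(x, t_cur)


def generic_sample(denoiser, x_N, time_steps, step):
    """Repeatedly apply x_{n-1} = Step(D, x_n, t_n, t_{n-1}) over a decreasing grid."""
    x = x_N
    batch = x.shape[0]
    for t_cur, t_next in zip(time_steps[:-1], time_steps[1:]):
        x = step(denoiser, x, np.full(batch, t_cur), np.full(batch, t_next))
    return x


rng = np.random.default_rng(0)
T = 80.0
time_steps = np.append(np.geomspace(T, 0.002, 400), 0.0)  # decreasing, ends at t = 0
x_N = T * rng.standard_normal((8, 3))  # approximately the marginal at t = T
samples = generic_sample(make_denoiser(toy_x0_predictor), x_N, time_steps, euler_step)

# For this toy problem the exact ODE solution is x(0) = x(T) * S_DATA / sqrt(S_DATA^2 + T^2)
exact = x_N * S_DATA / np.sqrt(S_DATA**2 + T**2)
```

Replacing euler_step with a higher-order rule changes only \(\text{Step}\). The predictor and denoiser stay the same, which is the point of the decomposition.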
Sampling Workflow
A complete sampling workflow involves these steps:

1. Load or reference a trained model satisfying the DiffusionModel interface (typically a backbone wrapped in a preconditioner).
2. Build a Predictor (\(P\) in the sampling equation) by binding the conditioning via functools.partial, which converts the three-argument DiffusionModel into a two-argument Predictor.
3. Convert it to a Denoiser (\(D\) in the equation). There are two paths:
   - Without guidance: pass the predictor directly to the noise scheduler's get_denoiser() factory (as an x0_predictor or score_predictor).
   - With guidance: first instantiate one or more DPSGuidance objects, then combine them with the predictor using DPSScorePredictor to obtain a guided score-predictor. Finally, pass this guided score-predictor to get_denoiser().
4. Initialize the noisy latent \(\mathbf{x}_N\) and the time-step schedule using the scheduler.
5. Optionally, configure a custom solver (\(\text{Step}\) in the equation) by instantiating a Solver (see Available Solvers), or simply pass a built-in string key (for example, "heun") to sample().
6. Call sample() to run the reverse diffusion loop.
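Step 2 needs nothing more than functools.partial. The minimal sketch below uses toy_diffusion_model, a hypothetical NumPy stand-in for a trained DiffusionModel. In practice, the bound callable is the preconditioned network.

```python
from functools import partial

import numpy as np


def toy_diffusion_model(x, t, condition=None):
    """Hypothetical stand-in for a three-argument DiffusionModel: model(x, t, condition)."""
    t_bc = t.reshape(-1, *([1] * (x.ndim - 1)))
    out = x / (1.0 + t_bc**2)
    if condition is not None:
        out = out + 0.1 * condition  # toy dependence on the conditioning
    return out


x = np.ones((2, 4))
t = np.array([1.0, 3.0])
low_res = np.full((2, 4), 2.0)

# Unconditional predictor: bind condition=None
x0_predictor = partial(toy_diffusion_model, condition=None)
# Conditional predictor: bind the conditioning tensor once, before sampling
cond_x0_predictor = partial(toy_diffusion_model, condition=low_res)

pred = x0_predictor(x, t)            # called with (x, t) only
cond_pred = cond_x0_predictor(x, t)  # the conditioning is applied implicitly
```

The sampler and solver only ever call the predictor with (x, t), so they do not need to know whether or how the model is conditioned.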
Example: Unconditional Image Generation
This example shows the full workflow for an unconditional image model trained
with the EDM formulation. It uses
SongUNet as the backbone,
wrapped with a thin adapter to match the
DiffusionModel interface, and
EDMPreconditioner for
preconditioning.
The noise scheduler must generally be consistent between training and
sampling—in particular, the same schedule family (for example, EDM, VP)
should be used. Schedule parameters (for example, sigma_min, rho) can
be adjusted at sampling time for experimentation, but the model was
optimized for the training schedule, so large deviations may degrade
sample quality.
import torch
from functools import partial
from physicsnemo.core import Module
from physicsnemo.models.diffusion_unets import SongUNet
from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
from physicsnemo.diffusion.preconditioners import EDMPreconditioner
from physicsnemo.diffusion.samplers import sample
# --- Backbone: wrap SongUNet to match the DiffusionModel interface ---
class UNetBackbone(Module):
    def __init__(self, img_resolution, channels, **kwargs):
        super().__init__()
        self.net = SongUNet(
            img_resolution=img_resolution,
            in_channels=channels,
            out_channels=channels,
            **kwargs,
        )

    def forward(self, x, t, condition=None):
        return self.net(x, noise_labels=t, class_labels=condition)


backbone = UNetBackbone(img_resolution=64, channels=3, model_channels=64,
                        channel_mult=[1, 2, 2], num_blocks=2)
# --- Preconditioner + training (sketch) ---
scheduler = EDMNoiseScheduler(sigma_min=0.002, sigma_max=80.0, rho=7)
precond = EDMPreconditioner(backbone, sigma_data=0.5)
# ... train with MSEDSMLoss(precond, scheduler) ...
# --- Sampling ---
precond.eval()
# Build predictor "P": bind condition=None for unconditional model
x0_predictor = partial(precond, condition=None)
# Convert to denoiser "D"
denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
# Initialize time-steps and noisy latent
num_steps = 50
tN = scheduler.timesteps(num_steps)[0].expand(8)
xN = scheduler.init_latents((3, 64, 64), tN)
# Run sampling loop with Heun solver ("Step")
samples = sample(denoiser, xN, scheduler, num_steps=num_steps, solver="heun")
# samples.shape: (8, 3, 64, 64)
It is also possible to adjust the schedule parameters at sampling time. For
instance, one might increase sigma_max or change rho to explore the
effect on sample quality:
sampling_scheduler = EDMNoiseScheduler(sigma_min=0.002, sigma_max=120.0, rho=5)
denoiser = sampling_scheduler.get_denoiser(x0_predictor=x0_predictor)
tN = sampling_scheduler.timesteps(num_steps)[0].expand(8)
xN = sampling_scheduler.init_latents((3, 64, 64), tN)
samples = sample(denoiser, xN, sampling_scheduler, num_steps=num_steps)
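To build intuition for what sigma_min, sigma_max and rho control, the sketch below computes the time-step grid proposed in the EDM paper (Karras et al., 2022): \(t_i = \big(\sigma_{max}^{1/\rho} + \tfrac{i}{N-1}(\sigma_{min}^{1/\rho} - \sigma_{max}^{1/\rho})\big)^{\rho}\), followed by a final \(t = 0\). We assume, without having checked the source, that EDMNoiseScheduler.timesteps follows this formula. The helper edm_time_steps is hypothetical.

```python
import numpy as np


def edm_time_steps(num_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """EDM-paper time-step grid (hypothetical helper), with a final t = 0 appended."""
    ramp = np.arange(num_steps) / (num_steps - 1)  # goes from 0 to 1
    max_inv_rho = sigma_max ** (1.0 / rho)
    min_inv_rho = sigma_min ** (1.0 / rho)
    t = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return np.append(t, 0.0)


default = edm_time_steps(50)                            # sigma_max=80, rho=7
explore = edm_time_steps(50, sigma_max=120.0, rho=5.0)  # the sampling-time variant above
rho5 = edm_time_steps(50, rho=5.0)                      # change rho only

# Fraction of the 50 non-zero time-steps that lie below t = 1
frac_rho7 = np.mean(default[:-1] < 1.0)
frac_rho5 = np.mean(rho5[:-1] < 1.0)
```

On this grid, 20 of the 50 non-zero steps (40%) fall below t = 1 with rho=7, compared with 17 (34%) with rho=5 at the same sigma_max. A larger rho therefore spends more of the step budget at low noise levels.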
Example: Vector-Space Diffusion (Non-Image Data)
The diffusion framework is not limited to image data. Any tensor-valued data
can be used, including 1D vectors. Here the backbone uses the
FullyConnected model from PhysicsNeMo,
wrapped with a thin adapter to match the
DiffusionModel interface.
import torch
from functools import partial
from physicsnemo.core import Module
from physicsnemo.models.mlp import FullyConnected
from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
from physicsnemo.diffusion.preconditioners import EDMPreconditioner
from physicsnemo.diffusion.samplers import sample
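# Backbone: wrap FullyConnected to match the DiffusionModel interface.
# The noise level t is appended as an extra input feature so that the
# network can tell how noisy its input is.
class FCBackbone(Module):
    def __init__(self, dim, hidden=256, num_layers=4):
        super().__init__()
        self.net = FullyConnected(
            in_features=dim + 1, layer_size=hidden,
            out_features=dim, num_layers=num_layers,
        )

    def forward(self, x, t, condition=None):
        # Reshape t to (B, 1) and concatenate it with x of shape (B, dim)
        t_feat = t.reshape(-1, 1).to(x.dtype).expand(x.shape[0], 1)
        return self.net(torch.cat([x, t_feat], dim=1))


data_dim = 32
backbone = FCBackbone(dim=data_dim)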
scheduler = EDMNoiseScheduler()
precond = EDMPreconditioner(backbone, sigma_data=1.0)
# ... train with MSEDSMLoss(precond, scheduler) ...
# Sampling
precond.eval()
x0_predictor = partial(precond, condition=None)
denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
num_steps = 50
tN = scheduler.timesteps(num_steps)[0].expand(16)
xN = scheduler.init_latents((data_dim,), tN) # 1D latent: shape (16, 32)
samples = sample(denoiser, xN, scheduler, num_steps=num_steps)
# samples.shape: (16, 32)
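init_latents takes the per-sample shape and the batch of initial times, so the same call works for images and vectors. The following sketch rests on two assumptions. First, it uses the EDM convention sigma(t) = t, the same one behind the x0-to-score conversion (x_0 - x) / t^2 in the guidance examples later on this page. Second, it assumes the latent is Gaussian noise scaled by sigma(t_N). The helper init_latents_sketch is hypothetical.

```python
import numpy as np


def init_latents_sketch(sample_shape, t_N, rng):
    """Hypothetical helper: x_N = sigma(t_N) * noise with sigma(t) = t, shape (B, *sample_shape)."""
    noise = rng.standard_normal((t_N.shape[0], *sample_shape))
    sigma = t_N.reshape(-1, *([1] * len(sample_shape)))  # broadcast over the data dims
    return sigma * noise


rng = np.random.default_rng(0)
t_N = np.full(16, 80.0)
x_vec = init_latents_sketch((32,), t_N, rng)            # 1D data, as in the example above
x_img = init_latents_sketch((3, 64, 64), t_N[:8], rng)  # images use the same code path
```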
Example: Conditional Sampling
For conditional generation (for example, super-resolution), the model backbone
processes both the noisy latent state and the conditioning input. A common
pattern is to concatenate the conditioning image along the channel dimension
inside a thin adapter. At sampling time, the conditioning is bound into the
predictor (\(P\)) via functools.partial.
import torch
from functools import partial
from physicsnemo.core import Module
from physicsnemo.models.diffusion_unets import SongUNet
from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
from physicsnemo.diffusion.preconditioners import EDMPreconditioner
from physicsnemo.diffusion.samplers import sample
C_x, C_cond, res = 3, 3, 64 # Image channels, conditioning channels, resolution
# Backbone: SongUNet wrapped with an adapter that concatenates the
# conditioning image along the channel dimension
class ConditionalUNet(Module):
    def __init__(self):
        super().__init__()
        self.net = SongUNet(
            img_resolution=res,
            in_channels=C_x + C_cond,
            out_channels=C_x,
            model_channels=64,
            channel_mult=[1, 2, 2],
            num_blocks=2,
        )

    def forward(self, x, t, condition=None):
        x_cat = torch.cat([x, condition], dim=1)
        return self.net(x_cat, noise_labels=t, class_labels=None)


backbone = ConditionalUNet()
scheduler = EDMNoiseScheduler()
precond = EDMPreconditioner(backbone, sigma_data=0.5)
# ... train with MSEDSMLoss(precond, scheduler, condition=...) ...
# --- Sampling ---
precond.eval()
# Conditioning image (for example, low-resolution input for super-resolution)
low_res = torch.randn(4, C_cond, res, res)
# Bind condition into the predictor "P"
x0_predictor = partial(precond, condition=low_res)
# Convert to denoiser "D" and sample
denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
tN = scheduler.timesteps(50)[0].expand(4)
xN = scheduler.init_latents((C_x, res, res), tN)
samples = sample(denoiser, xN, scheduler, num_steps=50)
# samples.shape: (4, 3, 64, 64)
Example: Conditional Sampling with DPS Guidance
DPS (Diffusion Posterior Sampling) guidance steers sampling toward
satisfying observation constraints. Guidance modifies the predictor
\(P\) in the sampling equation: the guidance objects are combined with the
x0-predictor into a guided score-predictor via DPSScorePredictor, which is
then converted to a denoiser \(D\) by the noise scheduler.
import torch
from functools import partial
from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
from physicsnemo.diffusion.guidance import DPSScorePredictor, DataConsistencyDPSGuidance
from physicsnemo.diffusion.samplers import sample
scheduler = EDMNoiseScheduler()
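# Build predictor "P" from a trained conditional model. `trained_model` (a
# DiffusionModel, for example a preconditioned backbone) and `condition` (its
# conditioning tensor) are placeholders assumed to exist.
x0_predictor = partial(trained_model, condition=condition)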
# Build guidance objects
mask = torch.zeros(4, 3, 64, 64, dtype=torch.bool)
mask[:, :, ::8, ::8] = True # Observe every 8th pixel
y_obs = torch.randn(4, 3, 64, 64)
guidance = DataConsistencyDPSGuidance(mask=mask, y=y_obs, std_y=0.1)
# Combine predictor + guidance into a guided score-predictor
guided_score_predictor = DPSScorePredictor(
x0_predictor=x0_predictor,
x0_to_score_fn=scheduler.x0_to_score,
guidances=guidance,
)
# Convert to denoiser "D" via the noise scheduler
denoiser = scheduler.get_denoiser(score_predictor=guided_score_predictor)
tN = scheduler.timesteps(50)[0].expand(4)
xN = scheduler.init_latents((3, 64, 64), tN)
samples = sample(denoiser, xN, scheduler, num_steps=50)
Example: Custom Solver and Custom Time-Steps
This example shows how to define a solver from scratch by implementing the
Solver protocol. Any object with a step(x, t_cur, t_next) method can serve
as \(\text{Step}\) in the sampling equation. Here we implement a simple
implicit trapezoidal rule (a second-order implicit Runge-Kutta method whose
implicit equation is solved with a few fixed-point iterations), and pair it
with custom time-steps and trajectory snapshots.
import torch
from functools import partial
from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
from physicsnemo.diffusion.samplers import sample
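# Custom solver: implicit trapezoidal rule (second-order)
class TrapezoidalSolver:
    def __init__(self, denoiser, num_inner_iters=3):
        self.denoiser = denoiser
        self.num_inner_iters = num_inner_iters

    def step(self, x, t_cur, t_next):
        # Reshape per-sample times (B,) to broadcast against x (B, *dims)
        t_cur_bc = t_cur.reshape(-1, *([1] * (x.ndim - 1)))
        t_next_bc = t_next.reshape(-1, *([1] * (x.ndim - 1)))
        h = t_next_bc - t_cur_bc
        d_cur = self.denoiser(x, t_cur)
        # Predictor: Euler step to get initial guess
        x_next = x + h * d_cur
        # On the final step to t = 0, stop at the Euler step: with EDM-style
        # schedules the ODE direction divides by t and is undefined at t = 0
        # (the EDM Heun sampler skips its correction there for the same reason)
        if torch.all(t_next == 0):
            return x_next
        # Corrector: fixed-point iterations for the implicit trapezoidal rule
        for _ in range(self.num_inner_iters):
            d_next = self.denoiser(x_next, t_next)
            x_next = x + 0.5 * h * (d_cur + d_next)
        return x_next


scheduler = EDMNoiseScheduler()
# `trained_model` is a placeholder for a trained DiffusionModel
x0_predictor = partial(trained_model, condition=None)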
denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
# Custom time-steps
custom_t = torch.tensor([80.0, 40.0, 20.0, 10.0, 5.0, 2.0, 1.0, 0.5, 0.1, 0.0])
tN = custom_t[0].expand(4)
xN = scheduler.init_latents((3, 64, 64), tN)
solver = TrapezoidalSolver(denoiser, num_inner_iters=3)
trajectory = sample(
denoiser, xN, scheduler,
num_steps=0, # Ignored when time_steps is provided
time_steps=custom_t,
solver=solver,
time_eval=[0, 4, 7], # Collect snapshots at steps 0, 4, 7
)
# trajectory is a list of 3 tensors, each of shape (4, 3, 64, 64)
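To see what the corrector loop in TrapezoidalSolver converges to, take the linear test problem dx/dt = lam * x. For this problem the implicit trapezoidal rule has the closed form x_next = x (1 + h lam / 2) / (1 - h lam / 2). The NumPy sketch below re-implements the same predictor/corrector loop outside the library, using scalar times for simplicity. It checks that more fixed-point iterations approach the closed form. For this problem the error shrinks by a factor h lam / 2 per iteration. The iteration therefore only converges when |h lam / 2| < 1, so very large steps may need a different nonlinear solver.

```python
import numpy as np


def trapezoidal_step(drift, x, t_cur, t_next, num_inner_iters):
    """Same predictor/corrector loop as TrapezoidalSolver.step, for scalar times."""
    h = t_next - t_cur
    d_cur = drift(x, t_cur)
    x_next = x + h * d_cur  # explicit Euler predictor
    for _ in range(num_inner_iters):
        x_next = x + 0.5 * h * (d_cur + drift(x_next, t_next))
    return x_next


lam = 0.8
drift = lambda x, t: lam * x  # linear test problem dx/dt = lam * x
x0 = np.array([1.0, -0.5])
t_cur, t_next = 1.0, 0.5
h = t_next - t_cur

closed_form = x0 * (1 + 0.5 * h * lam) / (1 - 0.5 * h * lam)
errors = [
    np.max(np.abs(trapezoidal_step(drift, x0, t_cur, t_next, k) - closed_form))
    for k in range(6)
]
```

Here h lam / 2 = -0.2, so each extra iteration cuts the error by a factor of five. This is why a handful of inner iterations (num_inner_iters=3 in the example above) is usually enough for moderate step sizes.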
Available Solvers
Solvers implement the \(\text{Step}\) operator in the sampling equation. At each iteration, the solver receives the denoiser output \(D(\mathbf{x}_n, t_n)\) and advances the latent state from time \(t_n\) to \(t_{n-1}\).
There are two ways to use solvers:

- Built-in solvers can be selected by passing a string key to sample():
  - "euler": EulerSolver. First-order. Fast (one denoiser evaluation per step) but lower quality.
  - "heun": HeunSolver. Second-order. Higher quality but twice as expensive per step.
  - "edm_stochastic_euler": EDMStochasticEulerSolver. First-order with configurable stochastic noise injection.
  - "edm_stochastic_heun": EDMStochasticHeunSolver. Second-order with configurable stochastic noise injection.
- Custom solvers can be defined by implementing the Solver protocol: any object with a step(x, t_cur, t_next) method. Pass the instance directly to sample() for full control over the integration method.
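The first-order versus second-order trade-off can be checked on a toy problem with a known solution. The sketch below uses plain NumPy and scalar times. It does not use the library's EulerSolver or HeunSolver. It assumes the denoiser returns the EDM probability-flow direction (x - x0_hat) / t and uses Gaussian data with std s, for which the exact solution is x(t) proportional to sqrt(s^2 + t^2). The Heun update follows the deterministic sampler of the EDM paper, including skipping the correction when t_next = 0. Both hypothetical classes expose only step(x, t_cur, t_next), mirroring the Solver protocol.

```python
import numpy as np

S_DATA = 1.0  # std of the toy Gaussian data


def denoiser(x, t):
    """Assumed EDM probability-flow direction dx/dt = (x - x0_hat) / t for Gaussian data."""
    x0_hat = S_DATA**2 / (S_DATA**2 + t**2) * x
    return (x - x0_hat) / t


class EulerSketch:
    def __init__(self, denoiser):
        self.denoiser = denoiser

    def step(self, x, t_cur, t_next):
        return x + (t_next - t_cur) * self.denoiser(x, t_cur)


class HeunSketch:
    def __init__(self, denoiser):
        self.denoiser = denoiser

    def step(self, x, t_cur, t_next):
        h = t_next - t_cur
        d_cur = self.denoiser(x, t_cur)
        x_euler = x + h * d_cur
        if t_next == 0:  # the direction is singular at t = 0: keep the Euler step
            return x_euler
        d_next = self.denoiser(x_euler, t_next)
        return x + 0.5 * h * (d_cur + d_next)


def integrate(solver, x, t_grid):
    for t_cur, t_next in zip(t_grid[:-1], t_grid[1:]):
        x = solver.step(x, t_cur, t_next)
    return x


T, T_MIN = 10.0, 0.1
x_T = np.array([1.0, -2.0, 3.0])
exact = x_T * np.sqrt((S_DATA**2 + T_MIN**2) / (S_DATA**2 + T**2))


def rel_error(solver, num_steps):
    t_grid = np.geomspace(T, T_MIN, num_steps + 1)
    approx = integrate(solver, x_T, t_grid)
    return np.max(np.abs(approx - exact) / np.abs(exact))


errors = {
    name: [rel_error(solver, n) for n in (100, 200)]
    for name, solver in [("euler", EulerSketch(denoiser)), ("heun", HeunSketch(denoiser))]
}
# Doubling the number of steps should cut the error about 2x for Euler and 4x for Heun
ratios = {name: e[0] / e[1] for name, e in errors.items()}
```

With the same step count, Heun costs about twice as many denoiser evaluations but is far more accurate here. That is why a second-order solver with fewer steps is often the better use of a fixed evaluation budget.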
Guidance
Guidance techniques modify the predictor \(P\) in the sampling equation to steer the generated samples toward desired properties, such as consistency with observed data, satisfaction of physical constraints, or other task-specific objectives.
In the framework, guidance operates at the Predictor level: you compose or
modify predictors before converting them into a Denoiser \(D\). Concretely,
guidance objects implement the DPSGuidance protocol and are combined with an
x0-predictor into a guided score-predictor via DPSScorePredictor. The
resulting score-predictor is then passed to the noise scheduler's
get_denoiser() factory to obtain a denoiser \(D\) for sampling. The sampler
and solver (\(\text{Step}\)) are unchanged: they only see the final denoiser.
The framework provides two ready-to-use guidance implementations:

- DataConsistencyDPSGuidance: for masked observations (inpainting, sparse probes, data assimilation).
- ModelConsistencyDPSGuidance: for generic (potentially nonlinear) observation operators.

Custom guidances can be defined by implementing the DPSGuidance protocol: any
callable with the signature (x, t, x_0) -> guidance_term. Multiple guidances
can be combined by passing a list to DPSScorePredictor.
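The composition rule described here can be sketched in a few lines of NumPy: the unconditional score from the x0-prediction, plus the sum of all guidance terms. This mirrors the formula documented for DPSScorePredictor in the API reference but is not its implementation. The EDM conversion score = (x_0 - x) / t^2 is the one used in the API examples. The two guidance classes are hypothetical and return closed-form terms, so no autograd is involved.

```python
import numpy as np


def x0_to_score(x_0, x, t):
    """EDM conversion from an x0-prediction to a score: (x_0 - x) / t^2."""
    t_bc = t.reshape(-1, *([1] * (x.ndim - 1)))
    return (x_0 - x) / t_bc**2


class MaskedObservationGuidance:
    """Hypothetical guidance pulling x_0 toward y_obs where mask is True."""
    def __init__(self, mask, y_obs, gamma=1.0):
        self.mask, self.y_obs, self.gamma = mask, y_obs, gamma

    def __call__(self, x, t, x_0):
        return -self.gamma * self.mask * (x_0 - self.y_obs)


class ShrinkGuidance:
    """Hypothetical regularizer pulling x_0 toward zero everywhere."""
    def __init__(self, gamma=0.1):
        self.gamma = gamma

    def __call__(self, x, t, x_0):
        return -self.gamma * x_0


def guided_score(x0_predictor, guidances, x, t):
    """Unconditional score plus the sum of all guidance terms."""
    x_0 = x0_predictor(x, t)
    score = x0_to_score(x_0, x, t)
    return score + sum(g(x, t, x_0) for g in guidances)


x0_predictor = lambda x, t: 0.9 * x  # toy x0-predictor
x = np.ones((2, 4))
t = np.array([1.0, 2.0])
mask = np.array([[True, False, True, False]] * 2)
y_obs = np.zeros((2, 4))
guidances = [MaskedObservationGuidance(mask, y_obs), ShrinkGuidance(gamma=0.1)]
out = guided_score(x0_predictor, guidances, x, t)
```

Because the terms are simply added, the relative weight of each guidance is controlled entirely by its own strength parameter (gamma here).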
API Reference
Sample Entry Point
sample
Solvers
Solver
EulerSolver
HeunSolver
EDMStochasticEulerSolver
EDMStochasticHeunSolver
Guidance
DPSGuidance
- class physicsnemo.diffusion.guidance.DPSGuidance(*args, **kwargs)
Protocol defining the interface for Diffusion Posterior Sampling (DPS) guidance.
A DPS guidance is a callable that computes a guidance term to steer the diffusion sampling process toward satisfying some observation constraint. It returns a quantity analogous to a likelihood score, which is typically added to the unconditional score during sampling.
The typical form is:
\[-\rho(t) \nabla_{\mathbf{x}} \ell\big(A(\hat{\mathbf{x}}_0) - \mathbf{y}\big)\]
where \(\rho(t)\) is a time-dependent (positive) guidance strength, \(A\) is a (potentially nonlinear) observation operator, \(\mathbf{y}\) is the observed data, and \(\ell\) is a scalar loss function. The minus sign makes the term point toward lower loss, like a likelihood score. However, variants are possible as long as the guidance produces a quantity similar to a score (e.g., a likelihood score).
This is the minimal interface for guidance, and any object that implements it can be used with DPSScorePredictor to build a guided score-predictor, which implements the Predictor interface.
See also
DPSScorePredictor: Combines an x0-predictor with one or more guidances.
Examples
Example 1: Minimal guidance for inpainting. Given a binary mask and observed pixels, guide the diffusion to match observations:
>>> import torch
>>> from physicsnemo.diffusion.guidance import DPSGuidance
>>>
>>> class InpaintingGuidance:
...     def __init__(self, mask, y_obs, gamma=1.0):
...         self.mask = mask  # Binary mask: 1 = observed, 0 = missing
...         self.y_obs = y_obs  # Observed pixel values
...         self.gamma = gamma
...
...     def __call__(self, x, t, x_0):
...         # Compute residual at observed locations
...         residual = self.mask * (x_0 - self.y_obs)
...         # Gradient of 0.5 * ||residual||^2 w.r.t. x_0 is the residual
...         # (simplified: assumes identity observation operator)
...         return -self.gamma * residual
...
>>> mask = torch.ones(1, 3, 8, 8)
>>> y_obs = torch.randn(1, 3, 8, 8)
>>> guidance = InpaintingGuidance(mask, y_obs)
>>> isinstance(guidance, DPSGuidance)
True
Example 2: Building a guided score predictor from scratch. A common pattern is to combine an x0-predictor with a guidance to create a score predictor that can be used for sampling. This shows the complete workflow:
>>> import torch
>>> from physicsnemo.diffusion.guidance import DPSGuidance
>>>
>>> # Define a guidance that pushes toward observed values
>>> class MyGuidance:
...     def __init__(self, y_obs, gamma=0.1):
...         self.y_obs = y_obs
...         self.gamma = gamma
...
...     def __call__(self, x, t, x_0):
...         return -self.gamma * (x_0 - self.y_obs)
...
>>> # Toy x0-predictor (in practice, a trained neural network)
>>> x0_predictor = lambda x, t: x * 0.9
>>> y_obs = torch.randn(1, 3, 8, 8)
>>> guidance = MyGuidance(y_obs, gamma=0.5)
>>>
>>> # Build a guided score predictor that combines x0-predictor + guidance
>>> def guided_score_predictor(x, t):
...     x_0 = x0_predictor(x, t)
...     guidance_term = guidance(x, t, x_0)
...     # Convert x0 to score (for EDM: score = (x_0 - x) / t^2)
...     expected_shape = (-1,) + (1,) * (x.ndim - 1)
...     t_bc = t.reshape(expected_shape)
...     score = (x_0 - x) / (t_bc ** 2)
...     return score + guidance_term
...
>>> # guided_score_predictor is now a Predictor (score predictor); pass it
>>> # to scheduler.get_denoiser(score_predictor=...) to obtain a Denoiser
>>> x = torch.randn(1, 3, 8, 8)
>>> t = torch.tensor([1.0])
>>> output = guided_score_predictor(x, t)
>>> output.shape
torch.Size([1, 3, 8, 8])
Note: DPSScorePredictor provides a convenient way to apply one or more guidances to an x0-predictor without manually implementing the above pattern.
DPSScorePredictor
- class physicsnemo.diffusion.guidance.DPSScorePredictor(
      x0_predictor: Predictor,
      x0_to_score_fn: Callable[[Float[Tensor, 'B *dims'], Float[Tensor, 'B *dims'], Float[Tensor, 'B']], Float[Tensor, 'B *dims']],
      guidances: DPSGuidance | Sequence[DPSGuidance],
  )
Bases: Predictor
Score predictor that combines an x0-predictor with DPS-style guidance.
This class transforms a Predictor (specifically, an x0-predictor) into a score Predictor by applying one or more DPS guidances. The resulting score predictor can be passed to get_denoiser() to obtain a Denoiser for sampling.
The output is the sum of the unconditional score (derived from the x0-prediction) and all guidance terms:
\[\nabla_{\mathbf{x}} \log p(\mathbf{x}) + \sum_i g_i(\mathbf{x}, t, \hat{\mathbf{x}}_0)\]
where \(g_i\) are the guidance terms implementing the DPSGuidance interface.
Each guidance must implement the DPSGuidance protocol, which is a callable with the following signature:
def guidance(
    x: Tensor,    # shape: (B, *dims)
    t: Tensor,    # shape: (B,)
    x_0: Tensor,  # shape: (B, *dims)
) -> Tensor:
    ...  # guidance term, shape: (B, *dims)
Important
When using multiple guidances that internally call torch.autograd.grad (e.g., ModelConsistencyDPSGuidance or DataConsistencyDPSGuidance), each guidance except the last must be constructed with retain_graph=True. Otherwise, the computational graph is destroyed after the first guidance computes its gradient, and subsequent guidances will fail. With a single guidance this is not needed.
- Parameters:
  - x0_predictor (Predictor) – A Predictor that takes (x, t) and returns an estimate of the clean data \(\hat{\mathbf{x}}_0\). Typically obtained from a trained DiffusionModel via functools.partial.
  - x0_to_score_fn (Callable[[Tensor, Tensor, Tensor], Tensor]) – Callback to convert an x0-prediction to a score. Signature: x0_to_score_fn(x_0, x, t) -> score. Typically obtained from a noise scheduler, e.g., x0_to_score().
  - guidances (DPSGuidance | Sequence[DPSGuidance]) – One or more guidance objects implementing the DPSGuidance interface.
See also
DPSGuidance: Protocol for guidance implementations.
Predictor: Protocol satisfied by this class.
get_denoiser(): Converts the score-predictor to a denoiser for sampling.
Examples
Example 1: Basic usage with a single guidance for inpainting:
>>> import torch
>>> from physicsnemo.diffusion.guidance import DPSScorePredictor, DPSGuidance
>>>
>>> # Toy x0-predictor (in practice, this is a trained neural network)
>>> x0_predictor = lambda x, t: x * 0.9
>>>
>>> # Simple x0_to_score function (for EDM: score = (x_0 - x) / t^2)
>>> def x0_to_score_fn(x_0, x, t):
...     expected_shape = (-1,) + (1,) * (x.ndim - 1)
...     t_bc = t.reshape(expected_shape)
...     return (x_0 - x) / (t_bc ** 2)
...
>>> # Simple inpainting guidance
>>> class InpaintGuidance:
...     def __init__(self, mask, y_obs, gamma=1.0):
...         self.mask = mask
...         self.y_obs = y_obs
...         self.gamma = gamma
...     def __call__(self, x, t, x_0):
...         return -self.gamma * self.mask * (x_0 - self.y_obs)
...
>>> mask = torch.ones(1, 3, 8, 8)
>>> y_obs = torch.randn(1, 3, 8, 8)
>>> guidance = InpaintGuidance(mask, y_obs)
>>>
>>> # Create DPS score predictor
>>> dps_score_pred = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=x0_to_score_fn,
...     guidances=guidance,
... )
>>>
>>> # Use in sampling
>>> x = torch.randn(1, 3, 8, 8)
>>> t = torch.tensor([1.0])
>>> output = dps_score_pred(x, t)
>>> output.shape
torch.Size([1, 3, 8, 8])
Example 2: Multiple guidances for multi-constraint problems:
>>> import torch
>>> from physicsnemo.diffusion.guidance import DPSScorePredictor
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>>
>>> # Use scheduler to get x0_to_score_fn
>>> scheduler = EDMNoiseScheduler()
>>> x0_predictor = lambda x, t: x * 0.9
>>>
>>> # Guidance 1: match observed values at specific locations
>>> class ObservationGuidance:
...     def __init__(self, mask, y_obs, gamma=1.0):
...         self.mask = mask
...         self.y_obs = y_obs
...         self.gamma = gamma
...     def __call__(self, x, t, x_0):
...         return -self.gamma * self.mask * (x_0 - self.y_obs)
...
>>> # Guidance 2: regularization toward zero mean
>>> class ZeroMeanGuidance:
...     def __init__(self, gamma=0.1):
...         self.gamma = gamma
...     def __call__(self, x, t, x_0):
...         return -self.gamma * x_0.mean() * torch.ones_like(x_0)
...
>>> mask = torch.ones(1, 3, 8, 8)
>>> y_obs = torch.randn(1, 3, 8, 8)
>>> guidance1 = ObservationGuidance(mask, y_obs)
>>> guidance2 = ZeroMeanGuidance()
>>>
>>> # Combine multiple guidances
>>> dps_score_pred = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=scheduler.x0_to_score,
...     guidances=[guidance1, guidance2],
... )
>>>
>>> x = torch.randn(2, 3, 8, 8)
>>> t = torch.tensor([1.0, 1.0])
>>> output = dps_score_pred(x, t)
>>> output.shape
torch.Size([2, 3, 8, 8])
Example 3: Multiple autograd-based guidances require retain_graph=True on all but the last:
>>> import torch
>>> from physicsnemo.diffusion.guidance import (
...     DPSScorePredictor,
...     DataConsistencyDPSGuidance,
... )
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>>
>>> scheduler = EDMNoiseScheduler()
>>> x0_predictor = lambda x, t: x * 0.9
>>>
>>> mask1 = torch.zeros(1, 3, 8, 8, dtype=torch.bool)
>>> mask1[:, :, 2, 3] = True
>>> mask2 = torch.zeros(1, 3, 8, 8, dtype=torch.bool)
>>> mask2[:, :, 5, 6] = True
>>> y_obs = torch.randn(1, 3, 8, 8)
>>>
>>> # First guidance retains the graph for the second one
>>> g1 = DataConsistencyDPSGuidance(
...     mask=mask1, y=y_obs, std_y=0.1, retain_graph=True,
... )
>>> # Last guidance does not need retain_graph
>>> g2 = DataConsistencyDPSGuidance(
...     mask=mask2, y=y_obs, std_y=0.1,
... )
>>>
>>> dps = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=scheduler.x0_to_score,
...     guidances=[g1, g2],
... )
>>> x = torch.randn(1, 3, 8, 8)
>>> t = torch.tensor([1.0])
>>> dps(x, t).shape
torch.Size([1, 3, 8, 8])
ModelConsistencyDPSGuidance
- class physicsnemo.diffusion.guidance.ModelConsistencyDPSGuidance(
      observation_operator: Callable[[Float[Tensor, 'B *dims']], Float[Tensor, 'B *obs_dims']],
      y: Float[Tensor, 'B *obs_dims'],
      std_y: float,
      norm: int | Callable[[Float[Tensor, 'B *obs_dims'], Float[Tensor, 'B *obs_dims']], Float[Tensor, 'B']] = 2,
      gamma: float = 0.0,
      sigma_fn: Callable[[Float[Tensor, '*shape']], Float[Tensor, '*shape']] | None = None,
      alpha_fn: Callable[[Float[Tensor, '*shape']], Float[Tensor, '*shape']] | None = None,
      retain_graph: bool = False,
      create_graph: bool = False,
  )
Bases: DPSGuidance
DPS guidance for generic observation models with Gaussian noise.
Implements the DPSGuidance interface for generic (possibly nonlinear) observation operators.
Computes the likelihood score assuming Gaussian measurement noise with standard deviation std_y. The guidance term is:
\[\nabla_{\mathbf{x}} \log p(\mathbf{y} | \mathbf{x}_t) = -\frac{1}{2 \left( \sigma_y^2 + \Gamma \frac{\sigma(t)^2}{\alpha(t)^2} \right)} \nabla_{\mathbf{x}} \| A\left(\hat{\mathbf{x}}_0\right) - \mathbf{y} \|^2\]
where \(A\) is the observation operator and the scaling incorporates a Score-Based Data Assimilation (SDA) correction through the parameter \(\Gamma\), which accounts for the covariance of the \(\hat{\mathbf{x}}_0(\mathbf{x}_t, t)\) estimate at different diffusion times. The L2 norm can be replaced by other Lp norms or custom loss functions via the norm parameter.
The observation_operator must be a differentiable callable with the following signature:
def observation_operator(
    x_0: Tensor,  # shape: (B, *dims)
) -> Tensor:
    ...  # predicted observations, shape: (B, *obs_dims)
When norm is a callable, it must have the following signature:
def norm(
    y_pred: Tensor,  # shape: (B, *obs_dims)
    y_true: Tensor,  # shape: (B, *obs_dims)
) -> Tensor:
    ...  # scalar loss per batch element, shape: (B,)
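Here is a worked instance of the formula above under simplifying assumptions: a linear operator \(A\) given as a matrix, a linear toy x0-predictor \(\hat{\mathbf{x}}_0 = c\,\mathbf{x}\) (so the chain rule through \(\hat{\mathbf{x}}_0\) is explicit), an unbatched 1-D state, and \(\sigma(t) = t\), \(\alpha(t) = 1\) as for EDM. In this case \(\nabla_{\mathbf{x}} \| A \hat{\mathbf{x}}_0 - \mathbf{y} \|^2 = 2c A^\top (A c \mathbf{x} - \mathbf{y})\). The sketch checks this against a central finite difference of \(\log p(\mathbf{y} | \mathbf{x}_t)\). It is a hand-derived NumPy sketch, not the class's autograd implementation, and the function names are hypothetical.

```python
import numpy as np


def model_consistency_guidance(x, t, A, y, c, std_y, gamma=0.0):
    """Likelihood-score guidance for x0_hat = c * x and a linear operator A (1-D state)."""
    residual = A @ (c * x) - y
    grad_sq_norm = 2.0 * c * A.T @ residual  # gradient of ||A x0_hat - y||^2 w.r.t. x
    sigma, alpha = t, 1.0                    # assumed EDM: sigma(t) = t, alpha(t) = 1
    variance = std_y**2 + gamma * sigma**2 / alpha**2
    return -grad_sq_norm / (2.0 * variance)


def log_likelihood(x, t, A, y, c, std_y, gamma=0.0):
    """Gaussian log p(y | x_t) up to a constant, with the same SDA-inflated variance."""
    variance = std_y**2 + gamma * t**2
    return -np.sum((A @ (c * x) - y) ** 2) / (2.0 * variance)


rng = np.random.default_rng(0)
A = rng.standard_normal((3, 5))  # 5-dim state, 3 observations
x = rng.standard_normal(5)
y = rng.standard_normal(3)
c, std_y, t = 0.9, 0.1, 2.0

analytic = model_consistency_guidance(x, t, A, y, c, std_y, gamma=0.05)

# Central finite differences of log p(y | x_t), one coordinate at a time
eps = 1e-6
numeric = np.array([
    (log_likelihood(x + eps * e, t, A, y, c, std_y, 0.05)
     - log_likelihood(x - eps * e, t, A, y, c, std_y, 0.05)) / (2 * eps)
    for e in np.eye(5)
])
classical_dps = model_consistency_guidance(x, t, A, y, c, std_y, gamma=0.0)
```

Setting gamma > 0 inflates the variance at large diffusion times, so the guidance is weaker than classical DPS there (compare analytic with classical_dps).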
- Parameters:
  - observation_operator (Callable[[Tensor], Tensor]) – Observation operator mapping the clean state to observations. Must be differentiable (supports torch.autograd).
  - y (Tensor) – Observed data of shape \((B, *obs\_dims)\) matching the output of A.
  - std_y (float) – Standard deviation of the measurement noise \(\sigma_y\).
  - norm (int | Callable[[Tensor, Tensor], Tensor], default=2) – Loss function used to compute the residual. An int value (default 2) uses the corresponding Lp norm. A callable receives (y_pred, y_true) and returns a scalar loss per batch element of shape \((B,)\).
  - gamma (float, default=0.0) – SDA covariance scaling factor \(\Gamma\). When gamma > 0, applies the SDA correction that accounts for the covariance of the \(\hat{\mathbf{x}}_0\) estimate at different noise levels. Set to 0 for classical DPS without SDA scaling.
  - sigma_fn (Callable[[Tensor], Tensor] | None, default=None) – Function mapping diffusion time to noise level \(\sigma(t)\). Required when gamma > 0. Typically obtained from a noise scheduler, e.g., sigma() for a linear-Gaussian noise schedule.
  - alpha_fn (Callable[[Tensor], Tensor] | None, default=None) – Function mapping diffusion time to signal coefficient \(\alpha(t)\). Optional; defaults to \(\alpha(t) = 1\) if not provided. Typically obtained from a noise scheduler, e.g., alpha() for a linear-Gaussian noise schedule.
  - retain_graph (bool, default=False) – If True, the computational graph is retained after computing gradients. Required when combining multiple autograd-based guidances in a single DPSScorePredictor: all guidances except the last must set this to True.
  - create_graph (bool, default=False) – If True, a graph of the derivative is constructed, allowing higher-order derivatives (e.g., differentiating through the entire sampling process).
See also
DataConsistencyDPSGuidance: Simplified guidance for masked observations.
DPSScorePredictor: Combines an x0-predictor with one or more guidances.
Examples
Example 1: Super-resolution with a blur + downsampling operator (linear in this case, although nonlinear operators are also supported):
>>> import torch
>>> import torch.nn.functional as F
>>> from physicsnemo.diffusion.guidance import (
...     ModelConsistencyDPSGuidance,
...     DPSScorePredictor,
... )
>>>
>>> # Observation operator: 3x3 box blur + 2x downsampling
>>> def blur_downsample(x):
...     # Apply a 3x3 box blur (uniform kernel), one filter per channel
...     kernel = torch.ones(1, 1, 3, 3, device=x.device) / 9
...     kernel = kernel.expand(x.shape[1], 1, 3, 3)
...     blurred = F.conv2d(x, kernel, padding=1, groups=x.shape[1])
...     # Downsample 2x
...     return F.avg_pool2d(blurred, kernel_size=2, stride=2)
...
>>> # Low-resolution observations (4x4 from 8x8 high-res)
>>> y_obs = torch.randn(1, 3, 4, 4)
>>>
>>> guidance = ModelConsistencyDPSGuidance(
...     observation_operator=blur_downsample,
...     y=y_obs,
...     std_y=0.1,
... )
>>>
>>> # Use in DPS sampling
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> t = torch.tensor([1.0])
>>> x_0 = x * 0.9  # Toy x0 estimate
>>> output = guidance(x, t, x_0)
>>> output.shape
torch.Size([1, 3, 8, 8])
>>>
>>> # Combine with DPSScorePredictor for complete sampling workflow
>>> x0_predictor = lambda x, t: x * 0.9
>>> def x0_to_score_fn(x_0, x, t):
...     expected_shape = (-1,) + (1,) * (x.ndim - 1)
...     t_bc = t.reshape(expected_shape)
...     return (x_0 - x) / (t_bc ** 2)
...
>>> dps_score_pred = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=x0_to_score_fn,
...     guidances=guidance,
... )
>>> score = dps_score_pred(x, t)
>>> score.shape
torch.Size([1, 3, 8, 8])
Example 2: With SDA scaling using noise scheduler methods:
>>> import torch
>>> from physicsnemo.diffusion.guidance import (
...     ModelConsistencyDPSGuidance,
...     DPSScorePredictor,
... )
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>>
>>> scheduler = EDMNoiseScheduler()
>>>
>>> # Linear observation operator (select first channel)
>>> A = lambda x: x[:, :1]
>>> y_obs = torch.randn(1, 1, 8, 8)
>>>
>>> # Enable SDA scaling with gamma > 0, providing sigma and alpha functions
>>> guidance = ModelConsistencyDPSGuidance(
...     observation_operator=A,
...     y=y_obs,
...     std_y=0.075,
...     gamma=0.05,  # Enable SDA scaling
...     sigma_fn=scheduler.sigma,
...     alpha_fn=scheduler.alpha,
... )
>>>
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> t = torch.tensor([1.0])
>>> x_0 = x * 0.9
>>> output = guidance(x, t, x_0)
>>> output.shape
torch.Size([1, 3, 8, 8])
>>>
>>> # Use with DPSScorePredictor and scheduler's x0_to_score
>>> x0_predictor = lambda x, t: x * 0.9
>>> dps_score_pred = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=scheduler.x0_to_score,
...     guidances=guidance,
... )
>>> score = dps_score_pred(x, t)
>>> score.shape
torch.Size([1, 3, 8, 8])
Example 3: With a custom loss function (Huber loss):
>>> import torch
>>> import torch.nn.functional as F
>>> from physicsnemo.diffusion.guidance import ModelConsistencyDPSGuidance
>>>
>>> # Wrap torch's Huber loss to return per-batch scalars
>>> def huber_loss(y_pred, y_true):
...     per_elem = F.huber_loss(y_pred, y_true, reduction="none")
...     return per_elem.reshape(y_pred.shape[0], -1).sum(dim=1)
...
>>> A = lambda x: x[:, :1]  # Select first channel
>>> y_obs = torch.randn(1, 1, 8, 8)
>>>
>>> guidance = ModelConsistencyDPSGuidance(
...     observation_operator=A,
...     y=y_obs,
...     std_y=0.1,
...     norm=huber_loss,  # Custom loss function
... )
>>>
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> t = torch.tensor([1.0])
>>> x_0 = x * 0.9
>>> output = guidance(x, t, x_0)
>>> output.shape
torch.Size([1, 3, 8, 8])
DataConsistencyDPSGuidance
- class physicsnemo.diffusion.guidance.DataConsistencyDPSGuidance(
      mask: Bool[Tensor, 'B *dims'],
      y: Float[Tensor, 'B *dims'],
      std_y: float,
      norm: int | Callable[[Float[Tensor, 'B *dims'], Float[Tensor, 'B *dims']], Float[Tensor, 'B']] = 2,
      gamma: float = 0.0,
      sigma_fn: Callable[[Float[Tensor, '*shape']], Float[Tensor, '*shape']] | None = None,
      alpha_fn: Callable[[Float[Tensor, '*shape']], Float[Tensor, '*shape']] | None = None,
      retain_graph: bool = False,
      create_graph: bool = False,
  )
Bases: DPSGuidance
DPS guidance for masked observations with Gaussian noise.
Implements the DPSGuidance interface for masked observation operators, as a simplified version of ModelConsistencyDPSGuidance. This is the typical setting for data assimilation tasks such as inpainting, outpainting, or sparse observations, where measurements are available only at specific locations.
Computes the likelihood score assuming Gaussian measurement noise with standard deviation std_y. The guidance term is:
\[\nabla_{\mathbf{x}} \log p(\mathbf{y} | \mathbf{x}_t) = -\frac{1}{2 \left( \sigma_y^2 + \Gamma \frac{\sigma(t)^2}{\alpha(t)^2} \right)} \nabla_{\mathbf{x}} \| \mathbf{M} \odot (\hat{\mathbf{x}}_0 - \mathbf{y}) \|^2\]
where \(\mathbf{M}\) is a binary mask (1 = observed, 0 = missing), \(\odot\) denotes element-wise multiplication, and the SDA correction enters through the parameter \(\Gamma\). The L2 norm can be replaced by other Lp norms or by custom loss functions via the norm parameter.
When norm is a callable, it must have the following signature:
def norm(
    y_pred: Tensor,  # shape: (B, *obs_dims)
    y_true: Tensor,  # shape: (B, *obs_dims)
) -> Tensor:
    ...  # scalar loss per batch element, shape: (B,)
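The guidance formula can be checked with a few lines of plain PyTorch. The sketch below is a hypothetical re-derivation and is not the implementation of DataConsistencyDPSGuidance. The function name masked_dps_guidance is illustrative, and the sketch makes three assumptions: the default squared-L2 residual, a per-sample (B,) reduction, and an x0 estimate recomputed from x inside the function rather than passed in as x_0.

```python
import torch


def masked_dps_guidance(x, t, x0_fn, mask, y, std_y, gamma=0.0,
                        sigma_fn=None, alpha_fn=None):
    """Hypothetical sketch of the masked DPS guidance formula (not library code)."""
    x = x.detach().requires_grad_(True)
    # The x0 estimate must be a differentiable function of x
    x_0 = x0_fn(x, t)
    # Per-sample squared L2 norm of the masked residual ||M * (x0_hat - y)||^2
    residual = mask.to(x.dtype) * (x_0 - y)
    sq_norm = residual.pow(2).reshape(x.shape[0], -1).sum(dim=1)  # shape (B,)
    # Effective variance: sigma_y^2 + Gamma * sigma(t)^2 / alpha(t)^2
    var = torch.full_like(sq_norm, std_y ** 2)
    if gamma > 0:
        sigma = sigma_fn(t)
        alpha = alpha_fn(t) if alpha_fn is not None else torch.ones_like(sigma)
        var = var + gamma * (sigma / alpha) ** 2
    # Summing over the batch is valid when each sample's loss depends only on its own x
    (grad,) = torch.autograd.grad(sq_norm.sum(), x)
    scale = -1.0 / (2.0 * var)
    # Broadcast the per-sample scale over the state dimensions
    return scale.reshape(-1, *([1] * (x.ndim - 1))) * grad


# Usage with the toy x0 estimate used in the examples on this page
x = torch.randn(1, 3, 8, 8)
t = torch.tensor([1.0])
mask = torch.zeros(1, 3, 8, 8, dtype=torch.bool)
mask[:, :, 2, 3] = True
y_obs = torch.randn(1, 3, 8, 8)
g = masked_dps_guidance(x, t, lambda x, t: 0.9 * x, mask, y_obs, std_y=0.1)
print(g.shape)  # torch.Size([1, 3, 8, 8])
```

For the toy estimate \(\hat{\mathbf{x}}_0 = 0.9\,\mathbf{x}\), the formula reduces to \(-0.9\, \mathbf{M} \odot (0.9\,\mathbf{x} - \mathbf{y}) / (\sigma_y^2 + \Gamma \sigma(t)^2/\alpha(t)^2)\). This closed form gives a quick check on any implementation.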
Parameters:
- mask (Tensor) – Boolean mask of shape \((B, *)\) matching the state shape. True for observed locations, False for missing.
- y (Tensor) – Observed data of shape \((B, *)\) matching the state shape. Values at unobserved locations (where mask is False) are ignored.
- std_y (float) – Standard deviation of the measurement noise \(\sigma_y\).
- norm (int | Callable[[Tensor, Tensor], Tensor], default=2) – Loss function used to compute the residual. An int value (default 2) uses the corresponding Lp norm. A callable receives (mask * x_0, mask * y) and returns a scalar loss per batch element of shape \((B,)\).
- gamma (float, default=0.0) – SDA covariance scaling factor \(\Gamma\). When gamma > 0, applies an SDA correction that accounts for the covariance of the \(\hat{\mathbf{x}}_0\) estimate at different noise levels. Set to 0 for classical DPS without SDA scaling.
- sigma_fn (Callable[[Tensor], Tensor] | None, default=None) – Function mapping diffusion time to noise level \(\sigma(t)\). Required when gamma > 0. Typically obtained from a noise scheduler; for example, use sigma() for a linear-Gaussian noise schedule.
- alpha_fn (Callable[[Tensor], Tensor] | None, default=None) – Function mapping diffusion time to signal coefficient \(\alpha(t)\). Optional; defaults to \(\alpha(t) = 1\) if not provided. For example, use alpha() for a linear-Gaussian noise schedule.
- retain_graph (bool, default=False) – If True, the computational graph is retained after computing gradients. Required when combining multiple autograd-based guidances in a single DPSScorePredictor; all guidances except the last must set this to True.
- create_graph (bool, default=False) – If True, a graph of the derivative is constructed, allowing higher-order derivatives (e.g., differentiating through the entire sampling process).
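The retain_graph requirement follows from standard PyTorch autograd behavior. When torch.autograd.grad backpropagates through a graph without retain_graph=True, it frees the intermediate buffers saved for the backward pass. A second guidance that differentiates through the same x0 estimate then fails.

The sketch below reproduces this with two masked residuals that share one x0 estimate. It uses plain PyTorch and does not call physicsnemo. The helper name two_guidance_grads and the tanh toy estimate are hypothetical.

```python
import torch


def two_guidance_grads(retain_first: bool):
    """Hypothetical helper: differentiate two masked residuals sharing one x0 estimate."""
    x = torch.randn(1, 3, 8, 8, requires_grad=True)
    # Shared toy x0 estimate; tanh saves its output for the backward pass
    x_0 = torch.tanh(x)
    y = torch.zeros_like(x)
    mask_a = torch.zeros(1, 3, 8, 8, dtype=torch.bool)
    mask_a[:, :, 2, 3] = True
    mask_b = torch.zeros(1, 3, 8, 8, dtype=torch.bool)
    mask_b[:, :, 5, 6] = True
    loss_a = ((mask_a.to(x.dtype) * (x_0 - y)) ** 2).sum()
    loss_b = ((mask_b.to(x.dtype) * (x_0 - y)) ** 2).sum()
    # First guidance: the shared graph survives only if retain_first is True
    (g_a,) = torch.autograd.grad(loss_a, x, retain_graph=retain_first)
    # Last guidance: allowed to free the graph
    (g_b,) = torch.autograd.grad(loss_b, x)
    return g_a, g_b


g_a, g_b = two_guidance_grads(retain_first=True)  # both gradients succeed

try:
    two_guidance_grads(retain_first=False)
except RuntimeError as err:
    # The tanh buffers were freed by the first call, so the second call fails
    print("second gradient failed:", type(err).__name__)
```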
Note
References:
See also
ModelConsistencyDPSGuidance: Guidance for general observation operators.
DPSScorePredictor: Combines an x0-predictor with one or more guidances.
Examples
Example 1: Sparse observations at probe locations:
>>> import torch
>>> from physicsnemo.diffusion.guidance import (
...     DataConsistencyDPSGuidance,
...     DPSScorePredictor,
... )
>>>
>>> # Boolean mask: only observe a few probe locations
>>> mask = torch.zeros(1, 3, 8, 8, dtype=torch.bool)
>>> mask[:, :, 2, 3] = True  # Probe at (2, 3)
>>> mask[:, :, 5, 6] = True  # Probe at (5, 6)
>>> mask[:, :, 1, 7] = True  # Probe at (1, 7)
>>> y_obs = torch.randn(1, 3, 8, 8)  # Observed values
>>>
>>> guidance = DataConsistencyDPSGuidance(
...     mask=mask,
...     y=y_obs,
...     std_y=0.1,
... )
>>>
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> t = torch.tensor([1.0])
>>> x_0 = x * 0.9  # Toy x0 estimate (must be computed from x)
>>> output = guidance(x, t, x_0)
>>> output.shape
torch.Size([1, 3, 8, 8])
>>>
>>> # Use with DPSScorePredictor for a complete sampling workflow
>>> x0_predictor = lambda x, t: x * 0.9
>>> # Toy x0-to-score conversion with sigma(t) = t: score = (x_0 - x) / t^2
>>> def x0_to_score_fn(x_0, x, t):
...     expected_shape = (-1,) + (1,) * (x.ndim - 1)
...     t_bc = t.reshape(expected_shape)
...     return (x_0 - x) / (t_bc ** 2)
...
>>> dps_score_pred = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=x0_to_score_fn,
...     guidances=guidance,
... )
>>> score = dps_score_pred(x, t)
>>> score.shape
torch.Size([1, 3, 8, 8])
Example 2: With SDA scaling and L1 norm using noise scheduler:
>>> import torch
>>> from physicsnemo.diffusion.guidance import (
...     DataConsistencyDPSGuidance,
...     DPSScorePredictor,
... )
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>>
>>> scheduler = EDMNoiseScheduler()
>>>
>>> # Same sparse probe locations as Example 1
>>> mask = torch.zeros(1, 3, 8, 8, dtype=torch.bool)
>>> mask[:, :, 2, 3] = True
>>> mask[:, :, 5, 6] = True
>>> mask[:, :, 1, 7] = True
>>> y_obs = torch.randn(1, 3, 8, 8)
>>>
>>> # Enable SDA scaling and use the L1 norm for robustness
>>> guidance = DataConsistencyDPSGuidance(
...     mask=mask,
...     y=y_obs,
...     std_y=0.075,
...     norm=1,  # L1 norm
...     gamma=1.0,  # Enable SDA scaling
...     sigma_fn=scheduler.sigma,
...     alpha_fn=scheduler.alpha,
... )
>>>
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> t = torch.tensor([1.0])
>>> x_0 = x * 0.9  # Must be computed from x
>>> output = guidance(x, t, x_0)
>>> output.shape
torch.Size([1, 3, 8, 8])
>>>
>>> # Use with DPSScorePredictor and the scheduler's x0_to_score
>>> x0_predictor = lambda x, t: x * 0.9
>>> dps_score_pred = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=scheduler.x0_to_score,
...     guidances=guidance,
... )
>>> score = dps_score_pred(x, t)
>>> score.shape
torch.Size([1, 3, 8, 8])
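Example 2 enables SDA scaling with gamma=1.0. To see what gamma changes, the sketch below evaluates only the scalar prefactor \(1 / \left(2(\sigma_y^2 + \Gamma \sigma(t)^2/\alpha(t)^2)\right)\) from the L2 guidance formula. It does not call EDMNoiseScheduler. Instead, it assumes a hypothetical EDM-style schedule with \(\sigma(t) = t\) and \(\alpha(t) = 1\) as a stand-in, because this page does not state the scheduler's exact conventions.

```python
def sda_guidance_weight(t, std_y, gamma, sigma_fn, alpha_fn=None):
    """Scalar prefactor 1 / (2 * (sigma_y^2 + Gamma * sigma(t)^2 / alpha(t)^2))."""
    # alpha(t) defaults to 1, matching the documented behavior of alpha_fn=None
    alpha = alpha_fn(t) if alpha_fn is not None else 1.0
    var = std_y ** 2 + gamma * (sigma_fn(t) / alpha) ** 2
    return 1.0 / (2.0 * var)


# Hypothetical stand-in schedule (assumption): sigma(t) = t, alpha(t) = 1
sigma = lambda t: t

times = [0.01, 0.1, 1.0, 10.0]
classical = [sda_guidance_weight(t, std_y=0.075, gamma=0.0, sigma_fn=sigma) for t in times]
with_sda = [sda_guidance_weight(t, std_y=0.075, gamma=1.0, sigma_fn=sigma) for t in times]

for t, w0, w1 in zip(times, classical, with_sda):
    print(f"t={t:>5}: gamma=0 -> {w0:8.3f}, gamma=1 -> {w1:8.3f}")
```

With gamma=0, the weight is constant in t. With gamma > 0, the weight shrinks at high noise levels, where the \(\hat{\mathbf{x}}_0\) estimate is least reliable, and it approaches the classical value as t goes to 0.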
Example 3: With a custom loss function (Huber loss):
>>> import torch
>>> import torch.nn.functional as F
>>> from physicsnemo.diffusion.guidance import DataConsistencyDPSGuidance
>>>
>>> # Wrap torch's Huber loss to return per-batch scalars
>>> def huber_loss(y_pred, y_true):
...     per_elem = F.huber_loss(y_pred, y_true, reduction="none")
...     return per_elem.reshape(y_pred.shape[0], -1).sum(dim=1)
...
>>> mask = torch.zeros(1, 3, 8, 8, dtype=torch.bool)
>>> mask[:, :, 2, 3] = True
>>> mask[:, :, 5, 6] = True
>>> y_obs = torch.randn(1, 3, 8, 8)
>>>
>>> guidance = DataConsistencyDPSGuidance(
...     mask=mask,
...     y=y_obs,
...     std_y=0.1,
...     norm=huber_loss,  # Custom loss function
... )
>>>
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> t = torch.tensor([1.0])
>>> x_0 = x * 0.9
>>> output = guidance(x, t, x_0)
>>> output.shape
torch.Size([1, 3, 8, 8])
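A custom norm that reduces over the batch, for example by using reduction="mean", breaks the documented \((B,)\) contract for the norm parameter. It can therefore help to test a callable before passing it in. check_norm_contract below is a hypothetical helper and is not part of physicsnemo. It checks only two things: that the output has shape \((B,)\), and that the loss is differentiable with respect to the prediction.

```python
import torch
import torch.nn.functional as F


def check_norm_contract(norm, obs_shape=(2, 1, 8, 8)):
    """Hypothetical helper: check a custom norm returns shape (B,) and is differentiable."""
    y_pred = torch.randn(obs_shape, requires_grad=True)
    y_true = torch.randn(obs_shape)
    loss = norm(y_pred, y_true)
    # Contract: one scalar loss per batch element, shape (B,)
    if loss.shape != (obs_shape[0],):
        return False
    # DPS guidance differentiates the loss, so gradients must reach y_pred
    (grad,) = torch.autograd.grad(loss.sum(), y_pred)
    return grad.shape == y_pred.shape


def huber_loss(y_pred, y_true):
    # Same per-batch Huber wrapper as in Example 3
    per_elem = F.huber_loss(y_pred, y_true, reduction="none")
    return per_elem.reshape(y_pred.shape[0], -1).sum(dim=1)


def mean_huber(y_pred, y_true):
    # Common mistake: reduction="mean" collapses the batch to a single scalar
    return F.huber_loss(y_pred, y_true, reduction="mean")


print(check_norm_contract(huber_loss))  # True
print(check_norm_contract(mean_huber))  # False
```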