Samplers and Solvers#
The sampler is the main interface for generating new data from a trained diffusion model. Starting from pure noise \(\mathbf{x}_N\), the solver iteratively denoises the latent state through a sequence of time-steps until it reaches a clean sample \(\mathbf{x}_0\).
The central entry point is the sample() function. It takes a
Denoiser, an initial noisy latent
\(\mathbf{x}_N\), and a
NoiseScheduler, and iterates
the reverse process to produce samples.
Generic Sampling Process#
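The sample() function supports any reverse process that can be written
in the form:
\[\mathbf{x}_{n-1} = \text{Step}(D(\mathbf{x}_n, t_n); \mathbf{x}_n, t_n, t_{n-1})\]
where the denoiser \(D\) is itself built from a predictor \(P\).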
This equation is the foundation of the sampling process in the framework. Every component described in this page maps to one of the three terms:
\(P\) is the predictor: the Predictor that maps the noisy state and diffusion time to a prediction. This is where all model logic lives, including conditioning and guidance.
\(D\) is the denoiser: the Denoiser derived from \(P\) via the noise scheduler’s get_denoiser() factory.
\(\text{Step}\) is the solver’s numerical update rule.
This generic formulation encompasses standard ODE/SDE-based sampling, but
also more advanced methods such as physics-informed posterior guidance
(DPS), score-based data assimilation
(SDA), and others. Any
method that can express its update step through this denoiser/solver
decomposition can be used with the sample() function.
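The decomposition above can be illustrated with a minimal, framework-free sketch on scalars. Everything here is hypothetical and is not the library's implementation: the closed-form x0_predictor (the posterior mean for unit-variance Gaussian data under EDM-style noise), the make_denoiser conversion, and the toy_sample loop are stand-ins chosen so that the exact answer is known. It assumes the denoiser returns the ODE right-hand side \((\mathbf{x} - \mathbf{x}_0)/t\).

```python
import math

def x0_predictor(x, t):
    # P (hypothetical): exact E[x0 | x_t] for x0 ~ N(0, 1) and x_t = x0 + t * eps
    return x / (1.0 + t * t)

def make_denoiser(predictor):
    # D: turn an x0-prediction into the ODE right-hand side dx/dt = (x - x0) / t
    def denoiser(x, t):
        return (x - predictor(x, t)) / t
    return denoiser

def euler_step(d, x, t_cur, t_next):
    # Step: explicit Euler update from t_cur to t_next
    return x + (t_next - t_cur) * d

def toy_sample(denoiser, x_N, time_steps):
    # Generic loop: x_{n-1} = Step(D(x_n, t_n); x_n, t_n, t_{n-1})
    x = x_N
    for t_cur, t_next in zip(time_steps[:-1], time_steps[1:]):
        x = euler_step(denoiser(x, t_cur), x, t_cur, t_next)
    return x

t_N, num_steps = 80.0, 1000
time_steps = [t_N * (1.0 - i / num_steps) for i in range(num_steps + 1)]
denoiser = make_denoiser(x0_predictor)
x_0 = toy_sample(denoiser, x_N=80.0, time_steps=time_steps)

# For this toy ODE the exact solution is x(t) = C * sqrt(1 + t^2)
exact = 80.0 / math.sqrt(1.0 + t_N * t_N)
print(f"Euler: {x_0:.4f}  exact: {exact:.4f}")
```

Swapping euler_step for another update rule, or x0_predictor for a guided predictor, leaves the loop itself unchanged, which is the point of the decomposition.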
Sampling Workflow#
A complete sampling workflow involves these steps:
1. Load or reference a trained model satisfying the DiffusionModel interface (typically a backbone wrapped in a preconditioner).
2. Build a Predictor (\(P\) in the sampling equation) by binding the conditioning via functools.partial, converting the three-argument DiffusionModel into a two-argument Predictor.
3. Convert to a Denoiser (\(D\) in the equation). There are two paths:
   - Without guidance: pass the predictor directly to the noise scheduler’s get_denoiser() factory (as an x0_predictor or score_predictor).
   - With guidance: first instantiate one or more DPSGuidance objects, then combine them with the predictor using DPSScorePredictor to obtain a guided score-predictor. Finally, pass this guided score-predictor to get_denoiser().
4. Initialize the noisy latent \(\mathbf{x}_N\) and the time-step schedule using the scheduler.
5. Optionally configure a custom solver (\(\text{Step}\) in the equation) by instantiating a Solver (see Available Solvers), or simply pass a built-in string key (for example, "heun") to sample().
6. Call sample() to run the reverse diffusion loop.
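The binding in step 2 relies only on functools.partial from the standard library. The sketch below uses a hypothetical scalar diffusion_model (not a PhysicsNeMo class) to show how a three-argument callable becomes a two-argument predictor.

```python
from functools import partial

def diffusion_model(x, t, condition=None):
    # Hypothetical three-argument model: shrink x, then shift by the condition
    shift = 0.0 if condition is None else condition
    return x / (1.0 + t * t) + shift

# Two-argument predictors (x, t) -> prediction, with the condition frozen in
unconditional = partial(diffusion_model, condition=None)
conditional = partial(diffusion_model, condition=0.5)

print(unconditional(2.0, 1.0))  # 1.0
print(conditional(2.0, 1.0))    # 1.5
```

Because the condition is captured when the predictor is built, changing the conditioning input means building a new predictor (and a new denoiser from it).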
Example: Unconditional Image Generation#
This example shows the full workflow for an unconditional image model trained
with the EDM formulation. It uses
SongUNet as the backbone,
wrapped with a thin adapter to match the
DiffusionModel interface, and
EDMPreconditioner for
preconditioning.
The noise scheduler must generally be consistent between training and
sampling—in particular, the same schedule family (for example, EDM, VP)
should be used. Schedule parameters (for example, sigma_min, rho) can
be adjusted at sampling time for experimentation, but the model was
optimized for the training schedule, so large deviations may degrade
sample quality.
import torch
from functools import partial
from physicsnemo.core import Module
from physicsnemo.models.diffusion_unets import SongUNet
from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
from physicsnemo.diffusion.preconditioners import EDMPreconditioner
from physicsnemo.diffusion.samplers import sample
# --- Backbone: wrap SongUNet to match the DiffusionModel interface ---
class UNetBackbone(Module):
    def __init__(self, img_resolution, channels, **kwargs):
        super().__init__()
        self.net = SongUNet(
            img_resolution=img_resolution,
            in_channels=channels,
            out_channels=channels,
            **kwargs,
        )

    def forward(self, x, t, condition=None):
        return self.net(x, noise_labels=t, class_labels=condition)

backbone = UNetBackbone(img_resolution=64, channels=3, model_channels=64,
                        channel_mult=[1, 2, 2], num_blocks=2)
# --- Preconditioner + training (sketch) ---
scheduler = EDMNoiseScheduler(sigma_min=0.002, sigma_max=80.0, rho=7)
precond = EDMPreconditioner(backbone, sigma_data=0.5)
# ... train with MSEDSMLoss(precond, scheduler) ...
# --- Sampling ---
precond.eval()
# Build predictor "P": bind condition=None for unconditional model
x0_predictor = partial(precond, condition=None)
# Convert to denoiser "D"
denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
# Initialize time-steps and noisy latent
num_steps = 50
tN = scheduler.timesteps(num_steps)[0].expand(8)
xN = scheduler.init_latents((3, 64, 64), tN)
# Run sampling loop with Heun solver ("Step")
samples = sample(denoiser, xN, scheduler, num_steps=num_steps, solver="heun")
# samples.shape: (8, 3, 64, 64)
It is also possible to adjust the schedule parameters at sampling time. For
instance, one might increase sigma_max or change rho to explore the
effect on sample quality:
sampling_scheduler = EDMNoiseScheduler(sigma_min=0.002, sigma_max=120.0, rho=5)
denoiser = sampling_scheduler.get_denoiser(x0_predictor=x0_predictor)
tN = sampling_scheduler.timesteps(num_steps)[0].expand(8)
xN = sampling_scheduler.init_latents((3, 64, 64), tN)
samples = sample(denoiser, xN, sampling_scheduler, num_steps=num_steps)
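To get a feel for what sigma_min, sigma_max, and rho control, the sketch below evaluates the time-step discretization from the EDM paper (Karras et al.) in plain Python. It is an assumption that EDMNoiseScheduler.timesteps() follows this formula and appends a final zero; the helper name edm_time_steps is hypothetical.

```python
def edm_time_steps(num_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # EDM discretization: interpolate linearly in sigma^(1/rho), then raise to rho.
    # Requires num_steps >= 2. A final t = 0 is appended (assumed convention).
    inv_rho = 1.0 / rho
    hi, lo = sigma_max ** inv_rho, sigma_min ** inv_rho
    steps = [(hi + i / (num_steps - 1) * (lo - hi)) ** rho for i in range(num_steps)]
    return steps + [0.0]

steps_rho7 = edm_time_steps(10, rho=7.0)
steps_rho1 = edm_time_steps(10, rho=1.0)  # rho = 1 gives uniform spacing

# Count how many non-zero steps fall below noise level 1.0
low_noise_rho7 = sum(t < 1.0 for t in steps_rho7[:-1])
low_noise_rho1 = sum(t < 1.0 for t in steps_rho1[:-1])
print(low_noise_rho7, low_noise_rho1)
```

Larger rho concentrates more steps at low noise levels, which is one reason a change of rho at sampling time can visibly affect fine detail.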
Example: Vector-Space Diffusion (Non-Image Data)#
The diffusion framework is not limited to image data. Any tensor-valued data
can be used, including 1D vectors. Here the backbone uses the
FullyConnected model from PhysicsNeMo,
wrapped with a thin adapter to match the
DiffusionModel interface.
import torch
from functools import partial
from physicsnemo.core import Module
from physicsnemo.models.mlp import FullyConnected
from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
from physicsnemo.diffusion.preconditioners import EDMPreconditioner
from physicsnemo.diffusion.samplers import sample
# Backbone: wrap FullyConnected to match the DiffusionModel interface
class FCBackbone(Module):
    def __init__(self, dim, hidden=256, num_layers=4):
        super().__init__()
        self.net = FullyConnected(
            in_features=dim, layer_size=hidden,
            out_features=dim, num_layers=num_layers,
        )

    def forward(self, x, t, condition=None):
        return self.net(x)

data_dim = 32
backbone = FCBackbone(dim=data_dim)
scheduler = EDMNoiseScheduler()
precond = EDMPreconditioner(backbone, sigma_data=1.0)
# ... train with MSEDSMLoss(precond, scheduler) ...
# Sampling
precond.eval()
x0_predictor = partial(precond, condition=None)
denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
num_steps = 50
tN = scheduler.timesteps(num_steps)[0].expand(16)
xN = scheduler.init_latents((data_dim,), tN) # 1D latent: shape (16, 32)
samples = sample(denoiser, xN, scheduler, num_steps=num_steps)
# samples.shape: (16, 32)
Example: Conditional Sampling#
For conditional generation (for example, super-resolution), the model backbone
processes both the noisy latent state and the conditioning input. A common
pattern is to concatenate the conditioning image along the channel dimension
inside a thin adapter. At sampling time, the conditioning is bound into the
predictor (\(P\)) via functools.partial.
import torch
from functools import partial
from physicsnemo.core import Module
from physicsnemo.models.diffusion_unets import SongUNet
from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
from physicsnemo.diffusion.preconditioners import EDMPreconditioner
from physicsnemo.diffusion.samplers import sample
C_x, C_cond, res = 3, 3, 64 # Image channels, conditioning channels, resolution
# Backbone: SongUNet wrapped with an adapter that concatenates the
# conditioning image along the channel dimension
class ConditionalUNet(Module):
    def __init__(self):
        super().__init__()
        self.net = SongUNet(
            img_resolution=res,
            in_channels=C_x + C_cond,
            out_channels=C_x,
            model_channels=64,
            channel_mult=[1, 2, 2],
            num_blocks=2,
        )

    def forward(self, x, t, condition=None):
        x_cat = torch.cat([x, condition], dim=1)
        return self.net(x_cat, noise_labels=t, class_labels=None)

backbone = ConditionalUNet()
scheduler = EDMNoiseScheduler()
precond = EDMPreconditioner(backbone, sigma_data=0.5)
# ... train with MSEDSMLoss(precond, scheduler, condition=...) ...
# --- Sampling ---
precond.eval()
# Conditioning image (for example, low-resolution input for super-resolution)
low_res = torch.randn(4, C_cond, res, res)
# Bind condition into the predictor "P"
x0_predictor = partial(precond, condition=low_res)
# Convert to denoiser "D" and sample
denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
tN = scheduler.timesteps(50)[0].expand(4)
xN = scheduler.init_latents((C_x, res, res), tN)
samples = sample(denoiser, xN, scheduler, num_steps=50)
# samples.shape: (4, 3, 64, 64)
Example: Conditional Sampling with DPS Guidance#
DPS (Diffusion Posterior Sampling) guidance steers the sampling toward
satisfying observation constraints. Guidance modifies the predictor
\(P\) in the sampling equation:
the guidance objects are combined with the x0-predictor into a guided
score-predictor via DPSScorePredictor,
and then converted to a denoiser \(D\) via the noise scheduler.
import torch
from functools import partial
from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
from physicsnemo.diffusion.guidance import DPSScorePredictor, DataConsistencyDPSGuidance
from physicsnemo.diffusion.samplers import sample
scheduler = EDMNoiseScheduler()
# Build predictor "P" from trained conditional model
x0_predictor = partial(trained_model, condition=condition)
# Build guidance objects
mask = torch.zeros(4, 3, 64, 64, dtype=torch.bool)
mask[:, :, ::8, ::8] = True # Observe every 8th pixel
y_obs = torch.randn(4, 3, 64, 64)
guidance = DataConsistencyDPSGuidance(mask=mask, y=y_obs, std_y=0.1)
# Combine predictor + guidance into a guided score-predictor
guided_score_predictor = DPSScorePredictor(
    x0_predictor=x0_predictor,
    x0_to_score_fn=scheduler.x0_to_score,
    guidances=guidance,
)
# Convert to denoiser "D" via the noise scheduler
denoiser = scheduler.get_denoiser(score_predictor=guided_score_predictor)
tN = scheduler.timesteps(50)[0].expand(4)
xN = scheduler.init_latents((3, 64, 64), tN)
samples = sample(denoiser, xN, scheduler, num_steps=50)
Example: Custom Solver and Custom Time-Steps#
This example shows how to define a solver from scratch by implementing the
Solver protocol. Any object
with a step(x, t_cur, t_next) method can serve as \(\text{Step}\) in
the sampling equation. Here we implement
a simple implicit trapezoidal rule (second-order Runge-Kutta), and pair it with
custom time-steps and trajectory snapshots.
import torch
from functools import partial
from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
from physicsnemo.diffusion.samplers import sample
# Custom solver: implicit trapezoidal rule (second-order)
class TrapezoidalSolver:
    def __init__(self, denoiser, num_inner_iters=3):
        self.denoiser = denoiser
        self.num_inner_iters = num_inner_iters

    def step(self, x, t_cur, t_next):
        # Reshape (B,) times so they broadcast against x of shape (B, *dims)
        t_cur_bc = t_cur.reshape(-1, *([1] * (x.ndim - 1)))
        t_next_bc = t_next.reshape(-1, *([1] * (x.ndim - 1)))
        h = t_next_bc - t_cur_bc
        d_cur = self.denoiser(x, t_cur)
        # Predictor: Euler step to get initial guess
        x_next = x + h * d_cur
        # Corrector: fixed-point iterations for the implicit trapezoidal rule
        for _ in range(self.num_inner_iters):
            d_next = self.denoiser(x_next, t_next)
            x_next = x + 0.5 * h * (d_cur + d_next)
        return x_next

scheduler = EDMNoiseScheduler()
x0_predictor = partial(trained_model, condition=None)
denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
# Custom time-steps
custom_t = torch.tensor([80.0, 40.0, 20.0, 10.0, 5.0, 2.0, 1.0, 0.5, 0.1, 0.0])
tN = custom_t[0].expand(4)
xN = scheduler.init_latents((3, 64, 64), tN)
solver = TrapezoidalSolver(denoiser, num_inner_iters=3)
trajectory = sample(
    denoiser, xN, scheduler,
    num_steps=0,  # Ignored when time_steps is provided
    time_steps=custom_t,
    solver=solver,
    time_eval=[0, 4, 7],  # Collect snapshots at steps 0, 4, 7
)
# trajectory is a list of 3 tensors, each of shape (4, 3, 64, 64)
Available Solvers#
Solvers implement the \(\text{Step}\) operator in the sampling equation. At each iteration, the solver receives the denoiser output \(D(\mathbf{x}_n, t_n)\) and advances the latent state from time \(t_n\) to \(t_{n-1}\).
There are two ways to use solvers:
Built-in solvers can be selected by passing a string key to sample():
"euler": EulerSolver. First-order. Fast (one denoiser evaluation per step) but lower quality.
"heun": HeunSolver. Second-order. Higher quality but twice as expensive per step.
"edm_stochastic_euler": EDMStochasticEulerSolver. First-order with configurable stochastic noise injection.
"edm_stochastic_heun": EDMStochasticHeunSolver. Second-order with configurable stochastic noise injection.
Custom solvers can be defined by implementing the
Solver protocol: any object
with a step(x, t_cur, t_next) method. Pass the instance directly to
sample() for full control over the integration method.
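The cost/accuracy trade-off between first- and second-order solvers can be seen on a toy problem with a known solution. The sketch below is self-contained and does not use the library's EulerSolver or HeunSolver; the right-hand side rhs is a hypothetical scalar ODE whose exact solution is \(x(t) = C\sqrt{1 + t^2}\), and the step functions are textbook Euler and Heun updates.

```python
import math

def rhs(x, t):
    # Toy ODE right-hand side dx/dt = x * t / (1 + t^2)
    return x * t / (1.0 + t * t)

def euler_step(f, x, t_cur, t_next):
    # One evaluation of f per step
    return x + (t_next - t_cur) * f(x, t_cur)

def heun_step(f, x, t_cur, t_next):
    # Two evaluations of f per step: Euler predictor, trapezoidal corrector
    h = t_next - t_cur
    d_cur = f(x, t_cur)
    d_next = f(x + h * d_cur, t_next)
    return x + 0.5 * h * (d_cur + d_next)

def integrate(step, num_steps, t_max=10.0):
    # Start on the exact solution so that the exact endpoint is x(0) = 1
    x = math.sqrt(1.0 + t_max * t_max)
    ts = [t_max * (1.0 - i / num_steps) for i in range(num_steps + 1)]
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        x = step(rhs, x, t_cur, t_next)
    return x

errors = {}
for n in (40, 80):
    errors[("euler", n)] = abs(integrate(euler_step, n) - 1.0)
    errors[("heun", n)] = abs(integrate(heun_step, n) - 1.0)
    print(n, errors[("euler", n)], errors[("heun", n)])
```

On this problem, doubling the number of steps roughly halves the Euler error, while the Heun error shrinks considerably faster; at equal step counts Heun costs twice as many right-hand-side evaluations.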
Guidance#
Guidance techniques modify the predictor \(P\) in the sampling equation to steer the generated samples toward desired properties, such as consistency with observed data, satisfaction of physical constraints, or other task-specific objectives.
In the framework, guidance operates at the
Predictor level: you compose or modify
predictors before converting them into a
Denoiser \(D\). Concretely, guidance
objects implement the DPSGuidance
protocol and are combined with an x0-predictor into a guided score-predictor
via DPSScorePredictor. The resulting
score-predictor is then passed to the noise scheduler’s
get_denoiser()
factory to obtain a denoiser \(D\) for sampling. The sampler and solver
(\(\text{Step}\)) are unchanged—they only see the final denoiser.
The framework provides two ready-to-use guidance implementations:
DataConsistencyDPSGuidance: For masked observations (inpainting, sparse probes, data assimilation).
ModelConsistencyDPSGuidance: For generic (potentially nonlinear) observation operators.
Custom guidances can be defined by implementing the
DPSGuidance protocol—any callable
with the signature (x, t, x_0) -> guidance_term. Multiple guidances can
be combined by passing a list to
DPSScorePredictor.
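The composition described above can be sketched without the framework, using plain lists as "tensors". This is only an illustration under stated assumptions: it assumes guidance terms are added to the unguided score, it uses the EDM-style conversion score = (x0 - x) / t^2, and its data-consistency term is a simplified Gaussian likelihood gradient taken with respect to x_0 (a full DPS implementation differentiates through the predictor with respect to x). All function names are hypothetical.

```python
def x0_to_score(x0, x, t):
    # EDM-style conversion from an x0-prediction to a score, element-wise
    return [(a - b) / (t * t) for a, b in zip(x0, x)]

def make_data_consistency_guidance(mask, y, std_y):
    # Guidance callable with the (x, t, x_0) -> guidance_term signature.
    # Simplification: gradient of log N(y; x_0, std_y^2) w.r.t. x_0 on observed entries.
    def guidance(x, t, x_0):
        return [(yi - x0i) / std_y ** 2 if m else 0.0
                for m, yi, x0i in zip(mask, y, x_0)]
    return guidance

def make_guided_score_predictor(x0_predictor, guidances):
    # Assumed composition: unguided score plus the sum of all guidance terms
    def score_predictor(x, t):
        x_0 = x0_predictor(x, t)
        score = x0_to_score(x_0, x, t)
        for g in guidances:
            score = [s + term for s, term in zip(score, g(x, t, x_0))]
        return score
    return score_predictor

x0_predictor = lambda x, t: [xi / (1.0 + t * t) for xi in x]  # toy predictor
guidance = make_data_consistency_guidance(mask=[True, False], y=[3.0, 0.0], std_y=1.0)
guided = make_guided_score_predictor(x0_predictor, [guidance])

print(guided([2.0, 2.0], 1.0))  # [1.0, -1.0]
```

Only the observed entry (mask True) is pulled toward its observation; the unobserved entry keeps its unguided score.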
API Reference#
Sample Entry Point#
sample#
physicsnemo.diffusion.samplers.sample(
    denoiser: Denoiser,
    xN: Float[Tensor, 'B *dims'],
    noise_scheduler: NoiseScheduler,
    num_steps: int,
    solver: Literal['euler', 'heun', 'edm_stochastic_euler', 'edm_stochastic_heun'] | Solver = 'heun',
    time_steps: Float[Tensor, 'N_plus_1'] | None = None,
    solver_options: Dict[str, Any] | None = None,
    time_eval: list[int] | None = None,
) -> Tensor | List[Tensor]
Generate batched samples from a diffusion model.
This interface is quite generic and can be used to generate samples from any reverse diffusion process of the form:
\[\mathbf{x}_{n-1} = G (\mathbf{x}_{i \geq n}, t_{i \geq n-1})\]
This covers both ODE/SDE-based sampling (e.g. VP, VE, EDM) and discrete Markov chain-based sampling (e.g. DDPM). The exact expression of the operator \(G\) depends on the combination of:
The solver, which determines the numerical method to update the latent state \(\mathbf{x}_n\) at each time-step.
The denoiser, which can be the right hand side for ODE/SDE-based sampling, the denoised latent state for discrete Markov chain-based sampling, etc.
Typically, the update applied is roughly:
\[\mathbf{x}_{n-1} = \text{Step}(D(\mathbf{x}_n, t_n); \mathbf{x}_n, t_n, t_{n-1})\]
where \(D\) is the denoiser and \(\text{Step}\) is the update rule of the solver, implemented by the step() method. Variants are possible by passing more complex solvers and denoisers.
The solver can be specified as a string key (with optional solver_options), or as a pre-configured object implementing the Solver interface (in which case solver_options must be None). The solver must implement a step method with the following signature:
def step(
    self,
    x: Tensor,       # shape: (B, *dims)
    t_cur: Tensor,   # shape: (B,)
    t_next: Tensor,  # shape: (B,)
) -> Tensor: ...     # updated x, shape: (B, *dims)
Any object that implements the Solver interface can be used as a solver.
The denoiser must implement the Denoiser interface, with the following signature:
def denoiser(
    x: Tensor,  # Noisy latent state, shape (B, *dims)
    t: Tensor,  # Diffusion time, shape (B,)
) -> Tensor: ...  # ODE/SDE RHS, same shape (B, *dims) as x
Any object that implements the Denoiser interface can be used as a denoiser. A denoiser is typically obtained from a Predictor using the noise scheduler’s get_denoiser() factory.
Time-steps are generated by the noise_scheduler using its timesteps() method with the provided num_steps. To use custom time-steps, pass a 1D tensor to time_steps, which overrides the schedule’s time-steps.
Parameters:
denoiser (Denoiser) – A callable that takes (x, t) and returns the denoising update term with the same shape as the latent state xN. See Denoiser for the expected interface. Typically obtained via the get_denoiser() factory, which converts a Predictor (e.g., score-predictor, x0-predictor) into a denoiser.
xN (Tensor) – Initial noisy latent state \(\mathbf{x}_N\) of shape \((B, *)\) where \(B\) is the batch size. All batch elements share the same diffusion time values. The dtype and device of xN determine the dtype and device of the generated samples and any internally created tensors. Can usually be obtained by using init_latents() from a noise scheduler (typically from the same noise scheduler instance as the noise_scheduler argument, but can be different if desired).
noise_scheduler (NoiseScheduler) – The noise scheduler instance used for generating time-steps. The schedule’s timesteps() method is called with num_steps to produce the diffusion time values, unless time_steps is provided to override them.
num_steps (int) – Number of sampling steps. Passed to the noise scheduler’s timesteps() method. Ignored when time_steps is provided.
solver (str | Solver, default="heun") – The numerical solver to use. Supports three levels of customizability:
    Basic: Pass a string key to use a built-in solver with default settings.
    Moderately advanced: Pass a string key plus solver_options to override default solver parameters.
    Advanced: Pass a custom Solver instance implementing the Solver interface. In this case, solver_options must not be set.
  Available string keys:
    "euler": First-order Euler method. Fast but lower quality. See EulerSolver.
    "heun": Second-order Heun method. Higher quality but requires two denoiser evaluations per step. See HeunSolver.
    "edm_stochastic_euler": First-order stochastic sampler from the EDM paper with configurable noise injection. See EDMStochasticEulerSolver.
    "edm_stochastic_heun": Second-order stochastic sampler from the EDM paper with configurable noise injection. See EDMStochasticHeunSolver.
time_steps (Tensor | None, default=None) – Optional 1D tensor of shape \((N + 1,)\) containing explicit diffusion time values \(t_N, t_{N-1}, ..., t_0\) in decreasing order. If provided, overrides the time-steps from noise_scheduler and num_steps is ignored. To produce a fully denoised latent state \(\mathbf{x}_0\), the last element must be \(t_0 = 0\).
solver_options (Dict[str, Any] | None, default=None) – Additional options passed to the solver constructor. Only used when solver is a string; must not be set when solver is a Solver instance. See individual solver classes for available options.
time_eval (List[int] | None, default=None) – Indices of time-steps at which to return intermediate samples. If provided, returns a list of tensors. If None, returns only the final denoised latent state \(\mathbf{x}_0\).
Returns:
If time_eval is None, returns the final denoised latent state \(\mathbf{x}_0\) of shape \((B, *)\). Otherwise, returns a list of tensors \(\mathbf{x}_t\) of shape \((B, *)\) containing latent states at the time-step indices specified in time_eval.
Return type:
Tensor | List[Tensor]
See also
solvers: Available ODE/SDE solvers.
noise_schedulers: Available noise schedules.
Examples
Example 1: Minimal usage. Just provide a denoiser, initial noise, a scheduler, and the number of steps.
>>> import torch
>>> from physicsnemo.diffusion.samplers import sample
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>>
>>> # Toy denoiser (in practice, this would be a trained neural network)
>>> denoiser = lambda x, t: x / (1 + t.view(-1, *([1] * (x.ndim - 1)))**2)
>>> scheduler = EDMNoiseScheduler()
>>> xN = torch.randn(2, 3, 8, 8) * 80  # Initial noise scaled by sigma_max
>>> x0 = sample(denoiser, xN, scheduler, num_steps=10)
>>> x0.shape
torch.Size([2, 3, 8, 8])
Example 2: Standard pattern using scheduler methods. Use init_latents to generate initial noise and get_denoiser to convert a predictor to a denoiser for sampling.
>>> import torch
>>> from physicsnemo.diffusion.samplers import sample
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>>
>>> scheduler = EDMNoiseScheduler()
>>> t_steps = scheduler.timesteps(10)
>>> tN = t_steps[0].expand(2)  # Initial time for batch of 2
>>>
>>> # Use scheduler to generate initial latents at time tN
>>> xN = scheduler.init_latents((3, 8, 8), tN)
>>>
>>> # Convert x0-predictor to denoiser (score conversion is automatic)
>>> x0_predictor = lambda x, t: x / (1 + t.view(-1, *([1] * (x.ndim - 1)))**2)  # Toy x0-predictor
>>> denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
>>>
>>> x0 = sample(denoiser, xN, scheduler, num_steps=10)
>>> x0.shape
torch.Size([2, 3, 8, 8])
Example 3: Custom time-steps and solver. Same as Example 2, but using explicit time-steps and the faster (but lower quality) Euler solver.
>>> import torch
>>> from physicsnemo.diffusion.samplers import sample
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>>
>>> scheduler = EDMNoiseScheduler()
>>>
>>> # Custom time-steps (fewer steps for faster sampling)
>>> custom_t = torch.tensor([80.0, 40.0, 20.0, 10.0, 5.0, 0.0])
>>> tN = custom_t[0].expand(2)
>>> xN = scheduler.init_latents((3, 8, 8), tN)
>>>
>>> # Same denoiser setup as Example 2
>>> x0_predictor = lambda x, t: x / (1 + t.view(-1, *([1] * (x.ndim - 1)))**2)  # Toy x0-predictor
>>> denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
>>>
>>> # Use custom time-steps and Euler solver (num_steps ignored)
>>> x0 = sample(denoiser, xN, scheduler, num_steps=0, time_steps=custom_t,
...             solver="euler")
>>> x0.shape
torch.Size([2, 3, 8, 8])
Example 4: Bare-bones custom scheduler. Define a scheduler from scratch implementing the NoiseScheduler protocol, without importing any built-in scheduler class.
>>> import torch
>>> from physicsnemo.diffusion.samplers import sample
>>>
>>> # Define a minimal EDM-like scheduler from scratch
>>> class MinimalScheduler:
...     def timesteps(self, num_steps, *, device=None, dtype=None):
...         return torch.linspace(1.0, 0.0, num_steps + 1,
...                               device=device, dtype=dtype)
...     def sample_time(self, N, *, device=None, dtype=None):
...         return torch.rand(N, device=device, dtype=dtype)
...     def add_noise(self, x0, time):
...         return x0 + time.view(-1, 1, 1, 1) * torch.randn_like(x0)
...     def init_latents(self, spatial_shape, tN, *, device=None,
...                      dtype=None):
...         return tN.view(-1, 1, 1, 1) * torch.randn(
...             tN.shape[0], *spatial_shape, device=device, dtype=dtype)
...     def get_denoiser(self, *, x0_predictor=None, **kwargs):
...         # EDM-like: sigma=t, alpha=1, g^2=2t
...         # score = (x0 - x) / t^2, ODE RHS = (x0 - x) / t
...         def _denoiser(x, t):
...             x0 = x0_predictor(x, t)
...             t_bc = t.view(-1, *([1] * (x.ndim - 1)))
...             return (x0 - x) / t_bc
...         return _denoiser
>>>
>>> scheduler = MinimalScheduler()
>>> tN = torch.tensor([1.0, 1.0])
>>> xN = scheduler.init_latents((3, 8, 8), tN)
>>>
>>> # x0-predictor -> denoiser via the scheduler factory
>>> x0_predictor = lambda x, t: x / (1 + t.view(-1, *([1] * (x.ndim - 1)))**2)  # Toy x0-predictor
>>> denoiser = scheduler.get_denoiser(x0_predictor=x0_predictor)
>>> x0 = sample(denoiser, xN, scheduler, num_steps=10, solver="euler")
>>> x0.shape
torch.Size([2, 3, 8, 8])
Solvers#
Solver#
class physicsnemo.diffusion.samplers.solvers.Solver(*args, **kwargs)
Protocol defining the interface for diffusion solvers.
A solver implements a numerical method to integrate the diffusion process from a noisy state to a less noisy (or clean) state. Each call to step() advances the state from time t_cur (\(t_n\)) to t_next (\(t_{n-1}\)).
This is the minimal interface required for sampling from a diffusion model, and any object that implements this interface can be used as a solver in sampling utilities.
The update rule applied by the sampler is roughly:
\[\mathbf{x}_{n-1} = \text{Step}(D(\mathbf{x}_n, t_n); \mathbf{x}_n, t_n, t_{n-1})\]
where \(D\) is the denoiser (e.g. the right hand side in the case of ODE/SDE-based sampling, the denoised latent state in the case of discrete Markov chain-based sampling, etc.) and \(\text{Step}\) is the update rule of the solver, implemented by the step() method.
See also
sample(): The sampling function that uses solvers to generate samples.
Examples
>>> import torch
>>> from physicsnemo.diffusion.samplers.solvers import Solver
>>>
>>> class SimpleEuler:
...     def __init__(self, denoiser):
...         self.denoiser = denoiser
...     def step(self, x, t_cur, t_next):
...         d = (x - self.denoiser(x, t_cur)) / t_cur
...         return x + (t_next - t_cur) * d
...
>>> denoiser = lambda x, t: x / (1 + t.view(-1, 1)**2)  # Toy denoiser
>>> solver = SimpleEuler(denoiser)
>>> isinstance(solver, Solver)
True
step(
    x: Float[Tensor, 'B *dims'],
    t_cur: Float[Tensor, 'B'],
    t_next: Float[Tensor, 'B'],
) -> Tensor
Perform one integration step from t_cur to t_next.
Parameters:
x (Tensor) – Current noisy latent state \(\mathbf{x}_{n}\) of shape \((B, *)\) where \(B\) is the batch size.
t_cur (Tensor) – Current diffusion time \(t_n\) of shape \((B,)\).
t_next (Tensor) – Target diffusion time \(t_{n-1}\) of shape \((B,)\).
Returns:
Updated latent state \(\mathbf{x}_{n-1}\) at time t_next, same shape as x.
Return type:
Tensor
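The isinstance check in the example above works for a class that never inherits from Solver, which is the behavior of a runtime-checkable structural protocol. The sketch below reproduces that mechanism with the standard library only; it assumes (but does not confirm) that the library's Solver is declared along these lines, and the names StepSolver, ScaleSolver, and NotASolver are hypothetical.

```python
from typing import Protocol, runtime_checkable

@runtime_checkable
class StepSolver(Protocol):
    # Structural interface: anything with a step(x, t_cur, t_next) method
    def step(self, x, t_cur, t_next): ...

class ScaleSolver:
    # Does not inherit from StepSolver, but has a matching method name
    def step(self, x, t_cur, t_next):
        return x * (t_next / t_cur)

class NotASolver:
    def advance(self, x):
        return x

print(isinstance(ScaleSolver(), StepSolver))  # True
print(isinstance(NotASolver(), StepSolver))   # False
```

Note that a runtime protocol check only verifies that the method exists; it does not validate argument names, shapes, or return types.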
EulerSolver#
class physicsnemo.diffusion.samplers.solvers.EulerSolver(denoiser: Denoiser)
Bases: Solver
First-order Euler solver for diffusion ODEs.
This is a fast solver with one denoiser evaluation per step, but typically produces lower quality samples compared to higher-order methods.
Parameters:
denoiser (Denoiser) – A callable implementing the Denoiser interface. Here it is expected to return the right hand side of the ODE. Typically obtained via get_denoiser(), but any callable with the correct signature can be used.
Examples
>>> import torch
>>> from physicsnemo.diffusion.samplers.solvers import EulerSolver, Solver
>>>
>>> denoiser = lambda x, t: x / (1 + t.view(-1, 1, 1, 1)**2)  # Toy denoiser
>>> solver = EulerSolver(denoiser)
>>> x_t = torch.randn(1, 3, 8, 8)
>>> t_cur = torch.tensor([1.0])
>>> t_next = torch.tensor([0.5])
>>> x_tm1 = solver.step(x_t, t_cur, t_next)
>>> x_tm1.shape
torch.Size([1, 3, 8, 8])
>>> isinstance(solver, Solver)
True
step(
    x: Float[Tensor, 'B *dims'],
    t_cur: Float[Tensor, 'B'],
    t_next: Float[Tensor, 'B'],
) -> Tensor
Perform one Euler integration step.
Parameters:
x (Tensor) – Current noisy latent state \(\mathbf{x}_{n}\) of shape \((B, *)\) where \(B\) is the batch size.
t_cur (Tensor) – Current diffusion time \(t_n\) of shape \((B,)\).
t_next (Tensor) – Target diffusion time \(t_{n-1}\) of shape \((B,)\).
Returns:
Updated latent state \(\mathbf{x}_{n-1}\) at time t_next, same shape as x.
Return type:
Tensor
HeunSolver#
class physicsnemo.diffusion.samplers.solvers.HeunSolver(
    denoiser: Denoiser,
    alpha: float = 1.0,
)
Bases: Solver
Second-order Heun solver for diffusion ODEs.
This method requires two denoiser evaluations per step but usually produces higher quality samples than EulerSolver.
Parameters:
denoiser (Denoiser) – A callable implementing the Denoiser interface. Here it is expected to return the right hand side of the ODE. Typically obtained via get_denoiser(), but any callable with the correct signature can be used.
alpha (float, optional) – Interpolation parameter for the corrector step, must be in (0, 1]. alpha=1 gives the standard Heun method (trapezoidal rule), alpha=0.5 gives the midpoint method. By default 1.
Examples
>>> import torch
>>> from physicsnemo.diffusion.samplers.solvers import HeunSolver
>>>
>>> denoiser = lambda x, t: x / (1 + t.view(-1, 1, 1, 1)**2)  # Toy denoiser
>>> solver = HeunSolver(denoiser)
>>> x_t = torch.randn(1, 3, 8, 8)
>>> t_cur = torch.tensor([1.0])
>>> t_next = torch.tensor([0.5])
>>> x_tm1 = solver.step(x_t, t_cur, t_next)
>>> x_tm1.shape
torch.Size([1, 3, 8, 8])
step(
    x: Float[Tensor, 'B *dims'],
    t_cur: Float[Tensor, 'B'],
    t_next: Float[Tensor, 'B'],
) -> Tensor
Perform one Heun integration step.
Parameters:
x (Tensor) – Current noisy latent state \(\mathbf{x}_n\) of shape \((B, *)\) where \(B\) is the batch size.
t_cur (Tensor) – Current diffusion time \(t_n\) of shape \((B,)\).
t_next (Tensor) – Target diffusion time \(t_{n-1}\) of shape \((B,)\).
Returns:
Updated latent state \(\mathbf{x}_{n-1}\) at time t_next, same shape as x.
Return type:
Tensor
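The role of alpha can be illustrated with the generic two-stage, second-order Runge-Kutta family from the EDM paper, in which alpha=1 recovers Heun's method and alpha=0.5 the midpoint method. The sketch below is a scalar, framework-free version; it is an assumption that HeunSolver parameterizes its corrector in exactly this way, and second_order_step is a hypothetical name.

```python
def second_order_step(f, x, t_cur, t_next, alpha=1.0):
    # Generic second-order update with an intermediate evaluation at t_cur + alpha * h
    h = t_next - t_cur
    d_cur = f(x, t_cur)
    d_mid = f(x + alpha * h * d_cur, t_cur + alpha * h)
    w = 1.0 / (2.0 * alpha)  # weight of the second evaluation
    return x + h * ((1.0 - w) * d_cur + w * d_mid)

f = lambda x, t: x  # toy right-hand side dx/dt = x

heun = second_order_step(f, 1.0, 0.0, 0.1, alpha=1.0)      # weights (1/2, 1/2)
midpoint = second_order_step(f, 1.0, 0.0, 0.1, alpha=0.5)  # weights (0, 1)
print(heun, midpoint)
```

For this linear, autonomous right-hand side both choices reduce to the same update, 1 + h + h^2/2; on time-dependent or nonlinear problems they generally differ in their error constants while both remaining second-order.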
EDMStochasticEulerSolver#
- class physicsnemo.diffusion.samplers.solvers.EDMStochasticEulerSolver(
- denoiser: Denoiser,
- S_churn: float = 0,
- S_min: float = 0,
- S_max: float = inf,
- S_noise: float = 1,
- num_steps: int = 18,
- sigma_fn: Callable[[Float[Tensor, '*shape']], Float[Tensor, '*shape']] | None = None,
- sigma_inv_fn: Callable[[Float[Tensor, '*shape']], Float[Tensor, '*shape']] | None = None,
- diffusion_fn: Callable[[Float[Tensor, 'B *dims'], Float[Tensor, 'B']], Float[Tensor, 'B *_']] | None = None,
Bases:
SolverFirst-order stochastic Euler sampler from the EDM paper.
Implements stochastic sampling with configurable noise injection controlled by the “churn” parameters.
Important
This is not a true SDE solver. It performs ad-hoc noise injection (“churn”) at each step to improve sample diversity, but the underlying integration is still an ODE step. Therefore, the denoiser should return the right-hand side of the ODE, not the SDE.
By default, noise injection is performed directly in time-step space. For linear-Gaussian noise schedules where diffusion time and noise level are not equal (e.g., VP schedule), provide
sigma_fnandsigma_inv_fnto apply churn in noise-level space rather than time-step space. Optionally providediffusion_fnto control the time-dependent magnitude of the injected noise.def sigma_fn( t: Tensor, # shape: (B,) or broadcastable ) -> Tensor: ... # noise level, same shape as t def sigma_inv_fn( sigma: Tensor, # shape: (B,) or broadcastable ) -> Tensor: ... # diffusion time, same shape as sigma def diffusion_fn( x: Tensor, # shape: (B, *dims) t: Tensor, # shape: (B,) ) -> Tensor: ... # g^2(x, t), broadcastable to shape of x
- Parameters:
denoiser (Denoiser) – A callable implementing the Denoiser interface. Should return the right-hand side of the ODE (not the SDE, since the stochastic noise injection is handled internally by this solver). Typically obtained via get_denoiser() with denoising_type="ode".
S_churn (float, optional) – Controls the amount of noise added at each step. Higher values add more stochasticity. By default 0 (deterministic), in which case this solver is equivalent to the deterministic EulerSolver.
S_min (float, optional) – Minimum diffusion time (or noise level if sigma_fn and sigma_inv_fn are provided) for applying churn. By default 0.
S_max (float, optional) – Maximum diffusion time (or noise level if sigma_fn and sigma_inv_fn are provided) for applying churn. By default float("inf").
S_noise (float, optional) – Noise scaling factor. Larger values add more noise to the latent state. By default 1.
num_steps (int, optional) – Total number of sampling steps, used to scale churn. By default 18.
sigma_fn (Callable[[Tensor], Tensor] | None, optional) – Maps time to noise level \(\sigma(t)\). Useful for linear-Gaussian schedules where \(\sigma(t) \neq t\). Typically sigma(). If provided, sigma_inv_fn must also be provided. By default None (identity mapping).
sigma_inv_fn (Callable[[Tensor], Tensor] | None, optional) – Maps noise level back to time. Typically sigma_inv(). If provided, sigma_fn must also be provided. By default None (identity mapping).
diffusion_fn (Callable[[Tensor, Tensor], Tensor] | None, optional) – Controls the time-dependent magnitude of the injected noise, in addition to the S_noise scaling factor. Typically the squared diffusion coefficient \(g^2(\mathbf{x}, t)\) from the reverse SDE, obtained from diffusion(). By default None (\(g^2 = 2t\)), which corresponds to an EDM-like noise schedule.
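As a rough illustration of how the churn parameters interact, the sketch below follows the noise-injection rule of Algorithm 2 in the EDM paper for the default case \(\sigma(t) = t\), using scalar floats. The helper name churn_step is hypothetical and the cap at \(\sqrt{2} - 1\) is taken from the paper; the solver's internal implementation may differ in its details.

```python
import math
import random


def churn_step(x, t_cur, S_churn=0.0, S_min=0.0, S_max=float("inf"),
               S_noise=1.0, num_steps=18, rng=random):
    """Return (x_hat, t_hat): the state and time after noise injection."""
    # Churn is only active inside [S_min, S_max]; gamma is capped at sqrt(2) - 1.
    if S_min <= t_cur <= S_max:
        gamma = min(S_churn / num_steps, math.sqrt(2.0) - 1.0)
    else:
        gamma = 0.0
    t_hat = t_cur + gamma * t_cur  # temporarily raised noise level
    # Add just enough fresh noise to move from level t_cur to level t_hat.
    eps = S_noise * rng.gauss(0.0, 1.0)
    x_hat = x + math.sqrt(t_hat**2 - t_cur**2) * eps
    return x_hat, t_hat


# With S_churn=0 the step is deterministic: nothing is injected.
x_hat, t_hat = churn_step(1.5, 1.0, S_churn=0.0)
# With S_churn=40 and 18 steps, gamma hits the sqrt(2) - 1 cap.
_, t_hat_churn = churn_step(1.5, 1.0, S_churn=40.0)
```

After the injection, the ODE step is taken from t_hat (not t_cur) down to t_next, which is why the denoiser should return the ODE right-hand side.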
Examples
Basic usage with default parameters (noise injection in time-step space):
>>> import torch
>>> from physicsnemo.diffusion.samplers.solvers import (
...     EDMStochasticEulerSolver,
... )
>>> denoiser = lambda x, t: x / (1 + t.view(-1, 1, 1, 1)**2)  # Toy denoiser
>>> solver = EDMStochasticEulerSolver(denoiser, S_churn=40, num_steps=18)
>>> x_t = torch.randn(1, 3, 8, 8)
>>> t_cur = torch.tensor([1.0])
>>> t_next = torch.tensor([0.5])
>>> x_tm1 = solver.step(x_t, t_cur, t_next)
>>> x_tm1.shape
torch.Size([1, 3, 8, 8])
Using noise scheduler methods for linear-Gaussian schedules where \(\sigma(t) \neq t\) (e.g., VP schedule). The callbacks map between time and noise level, allowing the churn to be applied in noise-level space before converting back to time-step space:
>>> from physicsnemo.diffusion.noise_schedulers import VPNoiseScheduler
>>> scheduler = VPNoiseScheduler()
>>> num_steps = 10
>>> solver = EDMStochasticEulerSolver(
...     denoiser,
...     S_churn=40,
...     num_steps=num_steps,
...     sigma_fn=scheduler.sigma,
...     sigma_inv_fn=scheduler.sigma_inv,
...     diffusion_fn=scheduler.diffusion,
... )
>>> x_tm1 = solver.step(x_t, t_cur, t_next)
>>> x_tm1.shape
torch.Size([1, 3, 8, 8])
- step(
- x: Float[Tensor, 'B *dims'],
- t_cur: Float[Tensor, 'B'],
- t_next: Float[Tensor, 'B'],
Perform one stochastic Euler sampling step.
- Parameters:
x (Tensor) – Current noisy latent state \(\mathbf{x}_n\) of shape \((B, *)\) where \(B\) is the batch size.
t_cur (Tensor) – Current diffusion time \(t_n\) of shape \((B,)\).
t_next (Tensor) – Target diffusion time \(t_{n-1}\) of shape \((B,)\).
- Returns:
Updated latent state \(\mathbf{x}_{n-1}\) at time t_next, same shape as x.
- Return type:
Tensor
EDMStochasticHeunSolver#
- class physicsnemo.diffusion.samplers.solvers.EDMStochasticHeunSolver(
- denoiser: Denoiser,
- alpha: float = 1.0,
- S_churn: float = 0,
- S_min: float = 0,
- S_max: float = inf,
- S_noise: float = 1,
- num_steps: int = 18,
- sigma_fn: Callable[[Float[Tensor, '*shape']], Float[Tensor, '*shape']] | None = None,
- sigma_inv_fn: Callable[[Float[Tensor, '*shape']], Float[Tensor, '*shape']] | None = None,
- diffusion_fn: Callable[[Float[Tensor, 'B *dims'], Float[Tensor, 'B']], Float[Tensor, 'B *_']] | None = None,
Bases: Solver
Second-order stochastic Heun sampler from the EDM paper.
Implements stochastic sampling with configurable noise injection controlled by the “churn” parameters, using a second-order Heun correction step.
Important
This is not a true SDE solver. It performs ad-hoc noise injection (“churn”) at each step to improve sample diversity, but the underlying integration is still an ODE step. Therefore, the denoiser should return the right-hand side of the ODE, not the SDE.
By default, noise injection is performed directly in time-step space. For linear-Gaussian noise schedules where diffusion time and noise level are not equal (e.g., VP schedule), provide sigma_fn and sigma_inv_fn to apply churn in noise-level space rather than time-step space. Optionally provide diffusion_fn to control the time-dependent magnitude of the injected noise.

def sigma_fn(
    t: Tensor,  # shape: (B,) or broadcastable
) -> Tensor: ...  # noise level, same shape as t

def sigma_inv_fn(
    sigma: Tensor,  # shape: (B,) or broadcastable
) -> Tensor: ...  # diffusion time, same shape as sigma

def diffusion_fn(
    x: Tensor,  # shape: (B, *dims)
    t: Tensor,  # shape: (B,)
) -> Tensor: ...  # g^2(x, t), broadcastable to shape of x
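One plausible reading of "churn in noise-level space" is sketched below: the current time is mapped to a noise level, the noise level is inflated by the churn factor, and the result is mapped back to a time. The function churned_time and the toy schedule \(\sigma(t) = t^2\) are hypothetical illustrations, not the solver's actual code.

```python
import math


def churned_time(t_cur, gamma, sigma_fn=None, sigma_inv_fn=None):
    """Inflate the current time by a churn factor gamma >= 0."""
    if sigma_fn is None and sigma_inv_fn is None:
        # Default: time and noise level coincide, so churn acts on t directly.
        return t_cur * (1.0 + gamma)
    if sigma_fn is None or sigma_inv_fn is None:
        raise ValueError("sigma_fn and sigma_inv_fn must be provided together")
    sigma_hat = sigma_fn(t_cur) * (1.0 + gamma)  # churn applied to sigma
    return sigma_inv_fn(sigma_hat)  # convert back to time-step space


# Toy schedule with sigma(t) = t**2, so sigma_inv(s) = sqrt(s).
t_hat_time = churned_time(1.0, 0.21)
t_hat_sigma = churned_time(1.0, 0.21, lambda t: t**2, math.sqrt)
```

With the same churn factor, the two spaces give different target times (1.21 versus 1.1 here), which is why the callbacks matter whenever \(\sigma(t) \neq t\).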
- Parameters:
denoiser (Denoiser) – A callable implementing the Denoiser interface. Should return the right-hand side of the ODE (not the SDE, since the stochastic noise injection is handled internally by this solver). Typically obtained via get_denoiser() with denoising_type="ode".
alpha (float, optional) – Interpolation parameter for the corrector step, must be in (0, 1]. alpha=1 gives the standard Heun method (trapezoidal rule), alpha=0.5 gives the midpoint method. By default 1.
S_churn (float, optional) – Controls the amount of noise added at each step. Higher values add more stochasticity. By default 0 (deterministic), in which case this solver is equivalent to the deterministic HeunSolver.
S_min (float, optional) – Minimum diffusion time (or noise level if sigma_fn and sigma_inv_fn are provided) for applying churn. By default 0.
S_max (float, optional) – Maximum diffusion time (or noise level if sigma_fn and sigma_inv_fn are provided) for applying churn. By default float("inf").
S_noise (float, optional) – Noise scaling factor. Larger values add more noise to the latent state. By default 1.
num_steps (int, optional) – Total number of sampling steps, used to scale churn. By default 18.
sigma_fn (Callable[[Tensor], Tensor] | None, optional) – Maps time to noise level \(\sigma(t)\). Useful for linear-Gaussian schedules where \(\sigma(t) \neq t\). Typically sigma(). If provided, sigma_inv_fn must also be provided. By default None (identity mapping).
sigma_inv_fn (Callable[[Tensor], Tensor] | None, optional) – Maps noise level back to time. Typically sigma_inv(). If provided, sigma_fn must also be provided. By default None (identity mapping).
diffusion_fn (Callable[[Tensor, Tensor], Tensor] | None, optional) – Controls the time-dependent magnitude of the injected noise, in addition to the S_noise scaling factor. Typically the squared diffusion coefficient \(g^2(\mathbf{x}, t)\) from the reverse SDE, obtained from diffusion(). By default None (\(g^2 = 2t\)), which corresponds to an EDM-like noise schedule.
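To make the role of alpha concrete, the sketch below implements the generalized second-order update from Algorithm 3 of the EDM paper for a scalar ODE \(dx/dt = f(x, t)\). The function second_order_step and the test ODE are hypothetical; the solver itself operates on batched tensors and obtains \(f\) from the denoiser.

```python
import math


def second_order_step(f, x, t_cur, t_next, alpha=1.0):
    """One step of the alpha-parameterized second-order scheme."""
    if not 0.0 < alpha <= 1.0:
        raise ValueError("alpha must be in (0, 1]")
    h = t_next - t_cur
    d_cur = f(x, t_cur)  # slope at the current point
    # Predictor: Euler step to the intermediate time t_cur + alpha * h.
    x_mid = x + alpha * h * d_cur
    d_mid = f(x_mid, t_cur + alpha * h)
    # Corrector: blend the two slopes; alpha=1 is Heun, alpha=0.5 is midpoint.
    w = 1.0 / (2.0 * alpha)
    return x + h * ((1.0 - w) * d_cur + w * d_mid)


def integrate(alpha, num_steps=10):
    """Integrate dx/dt = x from t=0 to t=1, whose exact solution is e."""
    x, h = 1.0, 1.0 / num_steps
    for n in range(num_steps):
        x = second_order_step(lambda x, t: x, x, n * h, (n + 1) * h, alpha)
    return x


heun, midpoint = integrate(1.0), integrate(0.5)
```

Both settings are second-order accurate, so with ten steps each lands within about 0.005 of \(e\), whereas a plain Euler step would be off by more than 0.1.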
Examples
Basic usage with default parameters (noise injection in time-step space):
>>> import torch
>>> from physicsnemo.diffusion.samplers.solvers import (
...     EDMStochasticHeunSolver,
... )
>>> denoiser = lambda x, t: x / (1 + t.view(-1, 1, 1, 1)**2)  # Toy denoiser
>>> solver = EDMStochasticHeunSolver(denoiser, S_churn=40, num_steps=18)
>>> x_t = torch.randn(1, 3, 8, 8)
>>> t_cur = torch.tensor([1.0])
>>> t_next = torch.tensor([0.5])
>>> x_tm1 = solver.step(x_t, t_cur, t_next)
>>> x_tm1.shape
torch.Size([1, 3, 8, 8])
Using noise scheduler methods for linear-Gaussian schedules where \(\sigma(t) \neq t\) (e.g., VP schedule). The callbacks map between time and noise level, allowing the churn to be applied in noise-level space before converting back to time-step space:
>>> from physicsnemo.diffusion.noise_schedulers import VPNoiseScheduler
>>> scheduler = VPNoiseScheduler()
>>> num_steps = 10
>>> solver = EDMStochasticHeunSolver(
...     denoiser,
...     S_churn=40,
...     num_steps=num_steps,
...     sigma_fn=scheduler.sigma,
...     sigma_inv_fn=scheduler.sigma_inv,
...     diffusion_fn=scheduler.diffusion,
... )
>>> x_tm1 = solver.step(x_t, t_cur, t_next)
>>> x_tm1.shape
torch.Size([1, 3, 8, 8])
- step(
- x: Float[Tensor, 'B *dims'],
- t_cur: Float[Tensor, 'B'],
- t_next: Float[Tensor, 'B'],
Perform one stochastic Heun sampling step.
- Parameters:
x (Tensor) – Current noisy latent state \(\mathbf{x}_n\) of shape \((B, *)\) where \(B\) is the batch size.
t_cur (Tensor) – Current diffusion time \(t_n\) of shape \((B,)\).
t_next (Tensor) – Target diffusion time \(t_{n-1}\) of shape \((B,)\).
- Returns:
Updated latent state \(\mathbf{x}_{n-1}\) at time t_next, same shape as x.
- Return type:
Tensor
Guidance#
DPSGuidance#
- class physicsnemo.diffusion.guidance.DPSGuidance(*args, **kwargs)
Protocol defining the interface for Diffusion Posterior Sampling (DPS) guidance.
A DPS guidance is a callable that computes a guidance term to steer the diffusion sampling process toward satisfying some observation constraint. It returns a quantity analogous to a likelihood score, which is typically added to the unconditional score during sampling.
The typical form is:
\[\rho(t) \nabla_{\mathbf{x}} \ell(A(\hat{\mathbf{x}}_0) - \mathbf{y})\]
where \(\rho(t)\) is a time-dependent guidance strength, \(A\) is a (potentially nonlinear) observation operator, \(\mathbf{y}\) is the observed data, and \(\ell\) is a scalar loss function. However, variants are possible as long as the guidance produces a quantity similar to a score (e.g., a likelihood score).
This is the minimal interface for guidance, and any object that implements this interface can be used with DPSScorePredictor to build a guided score-predictor, which implements the Predictor interface.
See also
DPSScorePredictor – Combines an x0-predictor with one or more guidances.
Examples
Example 1: Minimal guidance for inpainting. Given a binary mask and observed pixels, guide the diffusion to match observations:
>>> import torch
>>> from physicsnemo.diffusion.guidance import DPSGuidance
>>>
>>> class InpaintingGuidance:
...     def __init__(self, mask, y_obs, gamma=1.0):
...         self.mask = mask  # Binary mask: 1 = observed, 0 = missing
...         self.y_obs = y_obs  # Observed pixel values
...         self.gamma = gamma
...
...     def __call__(self, x, t, x_0):
...         # Compute residual at observed locations
...         residual = self.mask * (x_0 - self.y_obs)
...         # Gradient of the L2 loss 0.5 * ||residual||^2 w.r.t. x_0 is
...         # the residual itself (simplified: ignores the dependence of
...         # x_0 on x)
...         return -self.gamma * residual
...
>>> mask = torch.ones(1, 3, 8, 8)
>>> y_obs = torch.randn(1, 3, 8, 8)
>>> guidance = InpaintingGuidance(mask, y_obs)
>>> isinstance(guidance, DPSGuidance)
True
Example 2: Building a guided score predictor from scratch. A common pattern is to combine an x0-predictor with a guidance to create a score predictor that can be used for sampling. This shows the complete workflow:
>>> import torch
>>> from physicsnemo.diffusion.guidance import DPSGuidance
>>>
>>> # Define a guidance that pushes toward observed values
>>> class MyGuidance:
...     def __init__(self, y_obs, gamma=0.1):
...         self.y_obs = y_obs
...         self.gamma = gamma
...
...     def __call__(self, x, t, x_0):
...         return -self.gamma * (x_0 - self.y_obs)
...
>>> # Toy x0-predictor (in practice, a trained neural network)
>>> x0_predictor = lambda x, t: x * 0.9
>>> y_obs = torch.randn(1, 3, 8, 8)
>>> guidance = MyGuidance(y_obs, gamma=0.5)
>>>
>>> # Build a guided score predictor that combines x0-predictor + guidance
>>> def guided_score_predictor(x, t):
...     x_0 = x0_predictor(x, t)
...     guidance_term = guidance(x, t, x_0)
...     # Convert x0 to score (for EDM: score = (x_0 - x) / t^2)
...     t_bc = t.reshape(-1, *([1] * (x.ndim - 1)))
...     score = (x_0 - x) / (t_bc ** 2)
...     return score + guidance_term
...
>>> # guided_score_predictor is now a Predictor (score predictor); pass it
>>> # to scheduler.get_denoiser(score_predictor=...) to obtain a Denoiser
>>> x = torch.randn(1, 3, 8, 8)
>>> t = torch.tensor([1.0])
>>> output = guided_score_predictor(x, t)
>>> output.shape
torch.Size([1, 3, 8, 8])
Note: DPSScorePredictor provides a convenient way to apply one or more guidances to an x0-predictor without manually implementing the above pattern.
DPSScorePredictor#
- class physicsnemo.diffusion.guidance.DPSScorePredictor(
- x0_predictor: Predictor,
- x0_to_score_fn: Callable[[Float[Tensor, 'B *dims'], Float[Tensor, 'B *dims'], Float[Tensor, 'B']], Float[Tensor, 'B *dims']],
- guidances: DPSGuidance | Sequence[DPSGuidance],
Bases: object
Score predictor that combines an x0-predictor with DPS-style guidance.
This class transforms a Predictor (specifically, an x0-predictor) into a score Predictor by applying one or more DPS guidances. The resulting score predictor can be passed to get_denoiser() to obtain a Denoiser for sampling.
The output is the sum of the unconditional score (derived from the x0-prediction) and all guidance terms:
\[\nabla_{\mathbf{x}} \log p(\mathbf{x}) + \sum_i g_i(\mathbf{x}, t, \hat{\mathbf{x}}_0)\]
where \(g_i\) are the guidance terms implementing the DPSGuidance interface.
Each guidance must implement the DPSGuidance protocol, which is a callable with the following signature:

def guidance(
    x: Tensor,  # shape: (B, *dims)
    t: Tensor,  # shape: (B,)
    x_0: Tensor,  # shape: (B, *dims)
) -> Tensor: ...  # guidance term, shape: (B, *dims)
Important
When using multiple guidances that internally call torch.autograd.grad (e.g., ModelConsistencyDPSGuidance or DataConsistencyDPSGuidance), each guidance except the last must be constructed with retain_graph=True. Otherwise the computational graph is destroyed after the first guidance computes its gradient and subsequent guidances will fail. With a single guidance this is not needed.
- Parameters:
x0_predictor (Predictor) – A Predictor that takes (x, t) and returns an estimate of the clean data \(\hat{\mathbf{x}}_0\). Typically obtained from a trained DiffusionModel via functools.partial.
x0_to_score_fn (Callable[[Tensor, Tensor, Tensor], Tensor]) – Callback to convert x0-prediction to score. Signature: x0_to_score_fn(x_0, x, t) -> score. Typically obtained from a noise scheduler, e.g., x0_to_score().
guidances (DPSGuidance | Sequence[DPSGuidance]) – One or more guidance objects implementing the DPSGuidance interface.
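The composition performed by this class can be sketched in a few lines. The version below uses scalar floats instead of tensors and a hypothetical class name, SimpleGuidedScore; it only illustrates the sum of the x0-derived score and the guidance terms and is not the library's implementation.

```python
from typing import Callable, Sequence


class SimpleGuidedScore:
    """Toy scalar analogue of a guided score predictor."""

    def __init__(self, x0_predictor: Callable[[float, float], float],
                 x0_to_score_fn: Callable[[float, float, float], float],
                 guidances: Sequence[Callable[[float, float, float], float]]):
        self.x0_predictor = x0_predictor
        self.x0_to_score_fn = x0_to_score_fn
        self.guidances = list(guidances)

    def __call__(self, x: float, t: float) -> float:
        x_0 = self.x0_predictor(x, t)  # estimate of the clean sample
        score = self.x0_to_score_fn(x_0, x, t)  # unconditional score
        # Every guidance sees the same (x, t, x_0) and adds its own term.
        return score + sum(g(x, t, x_0) for g in self.guidances)


predictor = SimpleGuidedScore(
    x0_predictor=lambda x, t: 0.9 * x,
    x0_to_score_fn=lambda x_0, x, t: (x_0 - x) / t**2,  # EDM-style conversion
    guidances=[lambda x, t, x_0: -0.5 * (x_0 - 1.0)],  # pull x_0 toward 1.0
)
value = predictor(2.0, 1.0)  # score of about -0.2 plus guidance of about -0.4
```

Because the result again takes (x, t) and returns a score, it has the same two-argument shape as a Predictor.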
See also
DPSGuidance – Protocol for guidance implementations.
Predictor – Protocol satisfied by this class.
get_denoiser() – Converts the score-predictor to a denoiser for sampling.
Examples
Example 1: Basic usage with a single guidance for inpainting:
>>> import torch
>>> from physicsnemo.diffusion.guidance import DPSScorePredictor, DPSGuidance
>>>
>>> # Toy x0-predictor (in practice, this is a trained neural network)
>>> x0_predictor = lambda x, t: x * 0.9
>>>
>>> # Simple x0_to_score function (for EDM: score = (x_0 - x) / t^2)
>>> def x0_to_score_fn(x_0, x, t):
...     t_bc = t.reshape(-1, *([1] * (x.ndim - 1)))
...     return (x_0 - x) / (t_bc ** 2)
...
>>> # Simple inpainting guidance
>>> class InpaintGuidance:
...     def __init__(self, mask, y_obs, gamma=1.0):
...         self.mask = mask
...         self.y_obs = y_obs
...         self.gamma = gamma
...     def __call__(self, x, t, x_0):
...         return -self.gamma * self.mask * (x_0 - self.y_obs)
...
>>> mask = torch.ones(1, 3, 8, 8)
>>> y_obs = torch.randn(1, 3, 8, 8)
>>> guidance = InpaintGuidance(mask, y_obs)
>>>
>>> # Create DPS score predictor
>>> dps_score_pred = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=x0_to_score_fn,
...     guidances=guidance,
... )
>>>
>>> # Use in sampling
>>> x = torch.randn(1, 3, 8, 8)
>>> t = torch.tensor([1.0])
>>> output = dps_score_pred(x, t)
>>> output.shape
torch.Size([1, 3, 8, 8])
Example 2: Multiple guidances for multi-constraint problems:
>>> import torch
>>> from physicsnemo.diffusion.guidance import DPSScorePredictor
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>>
>>> # Use scheduler to get x0_to_score_fn
>>> scheduler = EDMNoiseScheduler()
>>> x0_predictor = lambda x, t: x * 0.9
>>>
>>> # Guidance 1: match observed values at specific locations
>>> class ObservationGuidance:
...     def __init__(self, mask, y_obs, gamma=1.0):
...         self.mask = mask
...         self.y_obs = y_obs
...         self.gamma = gamma
...     def __call__(self, x, t, x_0):
...         return -self.gamma * self.mask * (x_0 - self.y_obs)
...
>>> # Guidance 2: regularization toward zero mean
>>> class ZeroMeanGuidance:
...     def __init__(self, gamma=0.1):
...         self.gamma = gamma
...     def __call__(self, x, t, x_0):
...         return -self.gamma * x_0.mean() * torch.ones_like(x_0)
...
>>> mask = torch.ones(1, 3, 8, 8)
>>> y_obs = torch.randn(1, 3, 8, 8)
>>> guidance1 = ObservationGuidance(mask, y_obs)
>>> guidance2 = ZeroMeanGuidance()
>>>
>>> # Combine multiple guidances
>>> dps_score_pred = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=scheduler.x0_to_score,
...     guidances=[guidance1, guidance2],
... )
>>>
>>> x = torch.randn(2, 3, 8, 8)
>>> t = torch.tensor([1.0, 1.0])
>>> output = dps_score_pred(x, t)
>>> output.shape
torch.Size([2, 3, 8, 8])
Example 3: Multiple autograd-based guidances require retain_graph=True on all but the last:
>>> import torch
>>> from physicsnemo.diffusion.guidance import (
...     DPSScorePredictor,
...     DataConsistencyDPSGuidance,
... )
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>>
>>> scheduler = EDMNoiseScheduler()
>>> x0_predictor = lambda x, t: x * 0.9
>>>
>>> mask1 = torch.zeros(1, 3, 8, 8, dtype=torch.bool)
>>> mask1[:, :, 2, 3] = True
>>> mask2 = torch.zeros(1, 3, 8, 8, dtype=torch.bool)
>>> mask2[:, :, 5, 6] = True
>>> y_obs = torch.randn(1, 3, 8, 8)
>>>
>>> # First guidance retains the graph for the second one
>>> g1 = DataConsistencyDPSGuidance(
...     mask=mask1, y=y_obs, std_y=0.1, retain_graph=True,
... )
>>> # Last guidance does not need retain_graph
>>> g2 = DataConsistencyDPSGuidance(
...     mask=mask2, y=y_obs, std_y=0.1,
... )
>>>
>>> dps = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=scheduler.x0_to_score,
...     guidances=[g1, g2],
... )
>>> x = torch.randn(1, 3, 8, 8)
>>> t = torch.tensor([1.0])
>>> dps(x, t).shape
torch.Size([1, 3, 8, 8])
ModelConsistencyDPSGuidance#
- class physicsnemo.diffusion.guidance.ModelConsistencyDPSGuidance(
- observation_operator: Callable[[Float[Tensor, 'B *dims']], Float[Tensor, 'B *obs_dims']],
- y: Float[Tensor, 'B *obs_dims'],
- std_y: float,
- norm: int | Callable[[Float[Tensor, 'B *obs_dims'], Float[Tensor, 'B *obs_dims']], Float[Tensor, 'B']] = 2,
- gamma: float = 0.0,
- sigma_fn: Callable[[Float[Tensor, '*shape']], Float[Tensor, '*shape']] | None = None,
- alpha_fn: Callable[[Float[Tensor, '*shape']], Float[Tensor, '*shape']] | None = None,
- retain_graph: bool = False,
- create_graph: bool = False,
Bases: DPSGuidance
DPS guidance for generic observation models with Gaussian noise.
Implements the DPSGuidance interface for generic (possibly nonlinear) observation operators.
Computes the likelihood score assuming Gaussian measurement noise with standard deviation std_y. The guidance term is:
\[\nabla_{\mathbf{x}} \log p(\mathbf{y} | \mathbf{x}_t) = -\frac{1}{2 \left( \sigma_y^2 + \Gamma \frac{\sigma(t)^2}{\alpha(t)^2} \right)} \nabla_{\mathbf{x}} \| A\left(\hat{\mathbf{x}}_0\right) - \mathbf{y} \|^2\]
where \(A\) is the observation operator and the scaling incorporates a Score-Based Data Assimilation (SDA) correction through the parameter \(\Gamma\) that accounts for the covariance of the \(\hat{\mathbf{x}}_0(\mathbf{x}_t, t)\) estimate at different diffusion times. The L2 norm can be replaced by other Lp norms or custom loss functions via the norm parameter.
The observation_operator must be a differentiable callable with the following signature:

def observation_operator(
    x_0: Tensor,  # shape: (B, *dims)
) -> Tensor: ...  # predicted observations, shape: (B, *obs_dims)

When norm is a callable, it must have the following signature:

def norm(
    y_pred: Tensor,  # shape: (B, *obs_dims)
    y_true: Tensor,  # shape: (B, *obs_dims)
) -> Tensor: ...  # scalar loss per batch element, shape: (B,)
- Parameters:
observation_operator (Callable[[Tensor], Tensor]) – Observation operator mapping clean state to observations. Must be differentiable (supports torch.autograd).
y (Tensor) – Observed data of shape \((B, *obs\_dims)\) matching the output of A.
std_y (float) – Standard deviation of the measurement noise \(\sigma_y\).
norm (int | Callable[[Tensor, Tensor], Tensor], default=2) – Loss function used to compute the residual. An int value (default 2) uses the corresponding Lp norm. A callable receives (y_pred, y_true) and returns a scalar loss per batch element of shape \((B,)\).
gamma (float, default=0.0) – SDA covariance scaling factor \(\Gamma\). When gamma > 0, applies SDA correction that accounts for the covariance of the \(\hat{\mathbf{x}}_0\) estimate at different noise levels. Set to 0 for classical DPS without SDA scaling.
sigma_fn (Callable[[Tensor], Tensor] | None, default=None) – Function mapping diffusion time to noise level \(\sigma(t)\). Required when gamma > 0. Typically obtained from a noise scheduler, e.g., sigma() for a linear-Gaussian noise schedule.
alpha_fn (Callable[[Tensor], Tensor] | None, default=None) – Function mapping diffusion time to signal coefficient \(\alpha(t)\). Optional; defaults to \(\alpha(t) = 1\) if not provided. Typically obtained from a noise scheduler, e.g., alpha() for a linear-Gaussian noise schedule.
retain_graph (bool, default=False) – If True, the computational graph is retained after computing gradients. Required when combining multiple autograd-based guidances in a single DPSScorePredictor: all guidances except the last must set this to True.
create_graph (bool, default=False) – If True, a graph of the derivative is constructed, allowing higher-order derivatives (e.g., differentiating through the entire sampling process).
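The effect of gamma on the guidance strength can be seen by evaluating the prefactor of the guidance formula on its own. The helper dps_prefactor below is a hypothetical scalar sketch of that prefactor, not a function exposed by the library.

```python
def dps_prefactor(std_y, gamma=0.0, sigma=0.0, alpha=1.0):
    """Return 1 / (2 * (std_y^2 + gamma * sigma^2 / alpha^2))."""
    return 1.0 / (2.0 * (std_y**2 + gamma * sigma**2 / alpha**2))


# Classical DPS (gamma = 0): the prefactor ignores the noise level.
classic = [dps_prefactor(0.1, gamma=0.0, sigma=s) for s in (0.1, 1.0, 10.0)]
# SDA scaling (gamma > 0): guidance is damped when the noise level is high,
# i.e. when the x_0 estimate is least reliable.
sda = [dps_prefactor(0.1, gamma=0.05, sigma=s) for s in (0.1, 1.0, 10.0)]
```

With std_y = 0.1 the classical prefactor stays near 50 at every noise level, while the SDA-scaled one shrinks monotonically as sigma grows.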
See also
DataConsistencyDPSGuidance – Simplified guidance for masked observations.
DPSScorePredictor – Combines an x0-predictor with one or more guidances.
Examples
Example 1: Super-resolution with a blur + downsampling operator:
>>> import torch
>>> import torch.nn.functional as F
>>> from physicsnemo.diffusion.guidance import (
...     ModelConsistencyDPSGuidance,
...     DPSScorePredictor,
... )
>>>
>>> # Observation operator: 3x3 box blur + 2x downsampling
>>> def blur_downsample(x):
...     # Apply a 3x3 box (mean) blur, one kernel per channel
...     kernel = torch.ones(1, 1, 3, 3, device=x.device) / 9
...     kernel = kernel.expand(x.shape[1], 1, 3, 3)
...     blurred = F.conv2d(x, kernel, padding=1, groups=x.shape[1])
...     # Downsample 2x
...     return F.avg_pool2d(blurred, kernel_size=2, stride=2)
...
>>> # Low-resolution observations (4x4 from 8x8 high-res)
>>> y_obs = torch.randn(1, 3, 4, 4)
>>>
>>> guidance = ModelConsistencyDPSGuidance(
...     observation_operator=blur_downsample,
...     y=y_obs,
...     std_y=0.1,
... )
>>>
>>> # Use in DPS sampling
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> t = torch.tensor([1.0])
>>> x_0 = x * 0.9  # Toy x0 estimate
>>> output = guidance(x, t, x_0)
>>> output.shape
torch.Size([1, 3, 8, 8])
>>>
>>> # Combine with DPSScorePredictor for complete sampling workflow
>>> x0_predictor = lambda x, t: x * 0.9
>>> def x0_to_score_fn(x_0, x, t):
...     t_bc = t.reshape(-1, *([1] * (x.ndim - 1)))
...     return (x_0 - x) / (t_bc ** 2)
...
>>> dps_score_pred = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=x0_to_score_fn,
...     guidances=guidance,
... )
>>> score = dps_score_pred(x, t)
>>> score.shape
torch.Size([1, 3, 8, 8])
Example 2: With SDA scaling using noise scheduler methods:
>>> import torch
>>> from physicsnemo.diffusion.guidance import (
...     ModelConsistencyDPSGuidance,
...     DPSScorePredictor,
... )
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>>
>>> scheduler = EDMNoiseScheduler()
>>>
>>> # Linear observation operator (select first channel)
>>> A = lambda x: x[:, :1]
>>> y_obs = torch.randn(1, 1, 8, 8)
>>>
>>> # Enable SDA scaling with gamma > 0, providing sigma and alpha functions
>>> guidance = ModelConsistencyDPSGuidance(
...     observation_operator=A,
...     y=y_obs,
...     std_y=0.075,
...     gamma=0.05,  # Enable SDA scaling
...     sigma_fn=scheduler.sigma,
...     alpha_fn=scheduler.alpha,
... )
>>>
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> t = torch.tensor([1.0])
>>> x_0 = x * 0.9
>>> output = guidance(x, t, x_0)
>>> output.shape
torch.Size([1, 3, 8, 8])
>>>
>>> # Use with DPSScorePredictor and scheduler's x0_to_score
>>> x0_predictor = lambda x, t: x * 0.9
>>> dps_score_pred = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=scheduler.x0_to_score,
...     guidances=guidance,
... )
>>> score = dps_score_pred(x, t)
>>> score.shape
torch.Size([1, 3, 8, 8])
Example 3: With a custom loss function (Huber loss):
>>> import torch
>>> import torch.nn.functional as F
>>> from physicsnemo.diffusion.guidance import ModelConsistencyDPSGuidance
>>>
>>> # Wrap torch's Huber loss to return per-batch scalars
>>> def huber_loss(y_pred, y_true):
...     per_elem = F.huber_loss(y_pred, y_true, reduction="none")
...     return per_elem.reshape(y_pred.shape[0], -1).sum(dim=1)
...
>>> A = lambda x: x[:, :1]  # Select first channel
>>> y_obs = torch.randn(1, 1, 8, 8)
>>>
>>> guidance = ModelConsistencyDPSGuidance(
...     observation_operator=A,
...     y=y_obs,
...     std_y=0.1,
...     norm=huber_loss,  # Custom loss function
... )
>>>
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> t = torch.tensor([1.0])
>>> x_0 = x * 0.9
>>> output = guidance(x, t, x_0)
>>> output.shape
torch.Size([1, 3, 8, 8])
DataConsistencyDPSGuidance#
- class physicsnemo.diffusion.guidance.DataConsistencyDPSGuidance(
- mask: Bool[Tensor, 'B *dims'],
- y: Float[Tensor, 'B *dims'],
- std_y: float,
- norm: int | Callable[[Float[Tensor, 'B *dims'], Float[Tensor, 'B *dims']], Float[Tensor, 'B']] = 2,
- gamma: float = 0.0,
- sigma_fn: Callable[[Float[Tensor, '*shape']], Float[Tensor, '*shape']] | None = None,
- alpha_fn: Callable[[Float[Tensor, '*shape']], Float[Tensor, '*shape']] | None = None,
- retain_graph: bool = False,
- create_graph: bool = False,
Bases: DPSGuidance
DPS guidance for masked observations with Gaussian noise.
Implements the DPSGuidance interface for masked observation operators, a simplified version of ModelConsistencyDPSGuidance. This is typical for data assimilation tasks like inpainting, outpainting, or sparse observations, where measurements are available at specific locations.
Computes the likelihood score assuming Gaussian measurement noise with standard deviation std_y. The guidance term is:
\[\nabla_{\mathbf{x}} \log p(\mathbf{y} | \mathbf{x}_t) = -\frac{1}{2 \left( \sigma_y^2 + \Gamma \frac{\sigma(t)^2}{\alpha(t)^2} \right)} \nabla_{\mathbf{x}} \| \mathbf{M} \odot (\hat{\mathbf{x}}_0 - \mathbf{y}) \|^2\]
where \(\mathbf{M}\) is a binary mask (1 = observed, 0 = missing), \(\odot\) denotes element-wise multiplication, and the scaling incorporates an SDA correction through the parameter \(\Gamma\). The L2 norm can be replaced by other Lp norms or custom loss functions via the norm parameter.
When norm is a callable, it must have the following signature:

def norm(
    y_pred: Tensor,  # shape: (B, *dims)
    y_true: Tensor,  # shape: (B, *dims)
) -> Tensor: ...  # scalar loss per batch element, shape: (B,)
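For the default L2 case, the gradient with respect to the clean estimate has a closed form, which may help when sanity-checking outputs. The derivation below assumes a binary mask (so \(\mathbf{M} \odot \mathbf{M} = \mathbf{M}\)) and writes \(\mathbf{J}\) for the Jacobian of \(\hat{\mathbf{x}}_0\) with respect to \(\mathbf{x}_t\), a symbol introduced here for illustration only.

```latex
\begin{aligned}
\nabla_{\hat{\mathbf{x}}_0} \| \mathbf{M} \odot (\hat{\mathbf{x}}_0 - \mathbf{y}) \|^2
  &= 2\, \mathbf{M} \odot (\hat{\mathbf{x}}_0 - \mathbf{y}), \\
\nabla_{\mathbf{x}} \log p(\mathbf{y} | \mathbf{x}_t)
  &= -\frac{1}{\sigma_y^2 + \Gamma \frac{\sigma(t)^2}{\alpha(t)^2}}\,
     \mathbf{J}^\top \left[ \mathbf{M} \odot (\hat{\mathbf{x}}_0 - \mathbf{y}) \right],
  \qquad \mathbf{J} = \frac{\partial \hat{\mathbf{x}}_0}{\partial \mathbf{x}_t}.
\end{aligned}
```

The factor of 2 from the squared norm cancels the 1/2 in the prefactor. For a toy estimate such as \(\hat{\mathbf{x}}_0 = 0.9\,\mathbf{x}\), \(\mathbf{J}\) is simply 0.9 times the identity.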
- Parameters:
mask (Tensor) – Boolean mask of shape \((B, *)\) matching the state shape. True for observed locations, False for missing.
y (Tensor) – Observed data of shape \((B, *)\) matching the state shape. Values at unobserved locations (where mask is False) are ignored.
std_y (float) – Standard deviation of the measurement noise \(\sigma_y\).
norm (int | Callable[[Tensor, Tensor], Tensor], default=2) – Loss function used to compute the residual. An int value (default 2) uses the corresponding Lp norm. A callable receives (mask * x_0, mask * y) and returns a scalar loss per batch element of shape \((B,)\).
gamma (float, default=0.0) – SDA covariance scaling factor \(\Gamma\). When gamma > 0, applies SDA correction that accounts for the covariance of the \(\hat{\mathbf{x}}_0\) estimate at different noise levels. Set to 0 for classical DPS without SDA scaling.
sigma_fn (Callable[[Tensor], Tensor] | None, default=None) – Function mapping diffusion time to noise level \(\sigma(t)\). Required when gamma > 0. Typically obtained from a noise scheduler. For example, use sigma() for a linear-Gaussian noise schedule.
alpha_fn (Callable[[Tensor], Tensor] | None, default=None) – Function mapping diffusion time to signal coefficient \(\alpha(t)\). Optional; defaults to \(\alpha(t) = 1\) if not provided. For example, use alpha() for a linear-Gaussian noise schedule.
retain_graph (bool, default=False) – If True, the computational graph is retained after computing gradients. Required when combining multiple autograd-based guidances in a single DPSScorePredictor: all guidances except the last must set this to True.
create_graph (bool, default=False) – If True, a graph of the derivative is constructed, allowing higher-order derivatives (e.g., differentiating through the entire sampling process).
See also
ModelConsistencyDPSGuidance – Guidance for general observation operators.
DPSScorePredictor – Combines an x0-predictor with one or more guidances.
Examples
Example 1: Sparse observations at probe locations:
>>> import torch
>>> from physicsnemo.diffusion.guidance import (
...     DataConsistencyDPSGuidance,
...     DPSScorePredictor,
... )
>>>
>>> # Boolean mask: only observe a few probe locations
>>> mask = torch.zeros(1, 3, 8, 8, dtype=torch.bool)
>>> mask[:, :, 2, 3] = True  # Probe at (2, 3)
>>> mask[:, :, 5, 6] = True  # Probe at (5, 6)
>>> mask[:, :, 1, 7] = True  # Probe at (1, 7)
>>> y_obs = torch.randn(1, 3, 8, 8)  # Observed values
>>>
>>> guidance = DataConsistencyDPSGuidance(
...     mask=mask,
...     y=y_obs,
...     std_y=0.1,
... )
>>>
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> t = torch.tensor([1.0])
>>> x_0 = x * 0.9  # Toy x0 estimate (must be computed from x)
>>> output = guidance(x, t, x_0)
>>> output.shape
torch.Size([1, 3, 8, 8])
>>>
>>> # Use with DPSScorePredictor for complete sampling workflow
>>> x0_predictor = lambda x, t: x * 0.9
>>> def x0_to_score_fn(x_0, x, t):
...     t_bc = t.reshape(-1, *([1] * (x.ndim - 1)))
...     return (x_0 - x) / (t_bc ** 2)
...
>>> dps_score_pred = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=x0_to_score_fn,
...     guidances=guidance,
... )
>>> score = dps_score_pred(x, t)
>>> score.shape
torch.Size([1, 3, 8, 8])
Example 2: With SDA scaling and L1 norm using noise scheduler:
>>> import torch
>>> from physicsnemo.diffusion.guidance import (
...     DataConsistencyDPSGuidance,
...     DPSScorePredictor,
... )
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>>
>>> scheduler = EDMNoiseScheduler()
>>>
>>> # Same sparse probe locations as Example 1
>>> mask = torch.zeros(1, 3, 8, 8, dtype=torch.bool)
>>> mask[:, :, 2, 3] = True
>>> mask[:, :, 5, 6] = True
>>> mask[:, :, 1, 7] = True
>>> y_obs = torch.randn(1, 3, 8, 8)
>>>
>>> # Enable SDA scaling and use L1 norm for robustness
>>> guidance = DataConsistencyDPSGuidance(
...     mask=mask,
...     y=y_obs,
...     std_y=0.075,
...     norm=1,  # L1 norm
...     gamma=1.0,  # Enable SDA scaling
...     sigma_fn=scheduler.sigma,
...     alpha_fn=scheduler.alpha,
... )
>>>
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> t = torch.tensor([1.0])
>>> x_0 = x * 0.9  # Must be computed from x
>>> output = guidance(x, t, x_0)
>>> output.shape
torch.Size([1, 3, 8, 8])
>>>
>>> # Use with DPSScorePredictor and scheduler's x0_to_score
>>> x0_predictor = lambda x, t: x * 0.9
>>> dps_score_pred = DPSScorePredictor(
...     x0_predictor=x0_predictor,
...     x0_to_score_fn=scheduler.x0_to_score,
...     guidances=guidance,
... )
>>> score = dps_score_pred(x, t)
>>> score.shape
torch.Size([1, 3, 8, 8])
Example 3: With a custom loss function (Huber loss):
>>> import torch
>>> import torch.nn.functional as F
>>> from physicsnemo.diffusion.guidance import DataConsistencyDPSGuidance
>>>
>>> # Wrap torch's Huber loss to return per-batch scalars
>>> def huber_loss(y_pred, y_true):
...     per_elem = F.huber_loss(y_pred, y_true, reduction="none")
...     return per_elem.reshape(y_pred.shape[0], -1).sum(dim=1)
...
>>> mask = torch.zeros(1, 3, 8, 8, dtype=torch.bool)
>>> mask[:, :, 2, 3] = True
>>> mask[:, :, 5, 6] = True
>>> y_obs = torch.randn(1, 3, 8, 8)
>>>
>>> guidance = DataConsistencyDPSGuidance(
...     mask=mask,
...     y=y_obs,
...     std_y=0.1,
...     norm=huber_loss,  # Custom loss function
... )
>>>
>>> x = torch.randn(1, 3, 8, 8, requires_grad=True)
>>> t = torch.tensor([1.0])
>>> x_0 = x * 0.9
>>> output = guidance(x, t, x_0)
>>> output.shape
torch.Size([1, 3, 8, 8])