Source code for physicsnemo.diffusion.guidance.dps_guidance

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""DPS (Diffusion Posterior Sampling) guidance for diffusion models."""

from typing import Callable, Protocol, Sequence, runtime_checkable

import torch
from jaxtyping import Bool, Float
from torch import Tensor

from physicsnemo.diffusion.base import Predictor


[docs] @runtime_checkable class DPSGuidance(Protocol): r""" Protocol defining the interface for Diffusion Posterior Sampling (DPS) guidance. A DPS guidance is a callable that computes a guidance term to steer the diffusion sampling process toward satisfying some observation constraint. It returns a quantity analogous to a likelihood score, which is typically added to the unconditional score during sampling. The typical form is: .. math:: \rho(t) \nabla_{\mathbf{x}} \ell(A(\hat{\mathbf{x}}_0) - \mathbf{y}) where :math:`\rho(t)` is a time-dependent guidance strength, :math:`A` is a (potentially nonlinear) observation operator, :math:`\mathbf{y}` is the observed data, and :math:`\ell` is a scalar loss function. However, variants are possible as long as the guidance produces a quantity similar to a score (e.g., a likelihood score). This is the minimal interface for guidance, and any object that implements this interface can be used with :class:`DPSScorePredictor` to build a guided score-predictor, which implements the :class:`~physicsnemo.diffusion.Predictor` interface. See Also -------- :class:`DPSScorePredictor` : Combines an x0-predictor with one or more guidances. Examples -------- **Example 1:** Minimal guidance for inpainting. Given a binary mask and observed pixels, guide the diffusion to match observations: >>> import torch >>> from physicsnemo.diffusion.guidance import DPSGuidance >>> >>> class InpaintingGuidance: ... def __init__(self, mask, y_obs, gamma=1.0): ... self.mask = mask # Binary mask: 1 = observed, 0 = missing ... self.y_obs = y_obs # Observed pixel values ... self.gamma = gamma ... ... def __call__(self, x, t, x_0): ... # Compute residual at observed locations ... residual = self.mask * (x_0 - self.y_obs) ... # Gradient of L2 loss w.r.t. x_0 is just the residual ... # (simplified: assumes identity observation operator) ... return -self.gamma * residual ... >>> mask = torch.ones(1, 3, 8, 8) >>> y_obs = torch.randn(1, 3, 8, 8) >>> guidance = InpaintingGuidance(mask, y_obs) >>> isinstance(guidance, DPSGuidance) True **Example 2:** Building a guided score predictor from scratch. A common pattern is to combine an x0-predictor with a guidance to create a score predictor that can be used for sampling. This shows the complete workflow: >>> import torch >>> from physicsnemo.diffusion.guidance import DPSGuidance >>> >>> # Define a guidance that pushes toward observed values >>> class MyGuidance: ... def __init__(self, y_obs, gamma=0.1): ... self.y_obs = y_obs ... self.gamma = gamma ... ... def __call__(self, x, t, x_0): ... return -self.gamma * (x_0 - self.y_obs) ... >>> # Toy x0-predictor (in practice, a trained neural network) >>> x0_predictor = lambda x, t: x * 0.9 >>> y_obs = torch.randn(1, 3, 8, 8) >>> guidance = MyGuidance(y_obs, gamma=0.5) >>> >>> # Build a guided score predictor that combines x0-predictor + guidance >>> def guided_score_predictor(x, t): ... x_0 = x0_predictor(x, t) ... guidance_term = guidance(x, t, x_0) ... # Convert x0 to score (for EDM: score = (x_0 - x) / t^2) ... t_bc = t.reshape(-1, *([1] * (x.ndim - 1))) ... score = (x_0 - x) / (t_bc ** 2) ... return score + guidance_term ... >>> # guided_score_predictor is now a Predictor (score predictor); pass it >>> # to scheduler.get_denoiser(score_predictor=...) to obtain a Denoiser >>> x = torch.randn(1, 3, 8, 8) >>> t = torch.tensor([1.0]) >>> output = guided_score_predictor(x, t) >>> output.shape torch.Size([1, 3, 8, 8]) Note: :class:`DPSScorePredictor` provides a convenient way to apply one or more guidances to an x0-predictor without manually implementing the above pattern. """ def __call__( self, x: Float[Tensor, " B *dims"], t: Float[Tensor, " B"], x_0: Float[Tensor, " B *dims"], ) -> Float[Tensor, " B *dims"]: r""" Compute the guidance term. Parameters ---------- x : Tensor Noisy latent state at diffusion time ``t``, of shape :math:`(B, *)`. Typically used to compute gradients when the guidance requires backpropagation through the diffusion process, in which case it needs to have ``requires_grad=True``. t : Tensor Batched diffusion time of shape :math:`(B,)`. x_0 : Tensor Estimate of the clean latent state, of shape :math:`(B, *)`. Typically produced by an x0-predictor or clean data predictor. Returns ------- Tensor Guidance term of the same shape as ``x``. This is analogous to a likelihood score and is typically added to the unconditional score to guide the sampling process. """ ...
[docs] class DPSScorePredictor: r""" Score predictor that combines an x0-predictor with DPS-style guidance. This class transforms a :class:`~physicsnemo.diffusion.Predictor` (specifically, an **x0-predictor**) into a score :class:`~physicsnemo.diffusion.Predictor` by applying one or more DPS guidances. The resulting score predictor can be passed to :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.get_denoiser` to obtain a :class:`~physicsnemo.diffusion.Denoiser` for sampling. The output is the sum of the unconditional score (derived from the x0-prediction) and all guidance terms: .. math:: \nabla_{\mathbf{x}} \log p(\mathbf{x}) + \sum_i g_i(\mathbf{x}, t, \hat{\mathbf{x}}_0) where :math:`g_i` are the guidance terms implementing the :class:`DPSGuidance` interface. Each guidance must implement the :class:`DPSGuidance` protocol, which is a callable with the following signature: .. code-block:: python def guidance( x: Tensor, # shape: (B, *dims) t: Tensor, # shape: (B,) x_0: Tensor, # shape: (B, *dims) ) -> Tensor: ... # guidance term, shape: (B, *dims) .. important:: When using **multiple guidances** that internally call ``torch.autograd.grad`` (e.g., :class:`ModelConsistencyDPSGuidance` or :class:`DataConsistencyDPSGuidance`), each guidance except the last must be constructed with ``retain_graph=True``. Otherwise the computational graph is destroyed after the first guidance computes its gradient and subsequent guidances will fail. With a **single guidance** this is not needed. Parameters ---------- x0_predictor : Predictor A :class:`~physicsnemo.diffusion.Predictor` that takes ``(x, t)`` and returns an estimate of the clean data :math:`\hat{\mathbf{x}}_0`. Typically obtained from a trained :class:`~physicsnemo.diffusion.DiffusionModel` via ``functools.partial``. x0_to_score_fn : Callable[[Tensor, Tensor, Tensor], Tensor] Callback to convert x0-prediction to score. Signature: ``x0_to_score_fn(x_0, x, t) -> score``. Typically obtained from a noise scheduler, e.g., :meth:`~physicsnemo.diffusion.noise_schedulers.LinearGaussianNoiseScheduler.x0_to_score`. guidances : DPSGuidance | Sequence[DPSGuidance] One or more guidance objects implementing the :class:`DPSGuidance` interface. See Also -------- :class:`DPSGuidance` : Protocol for guidance implementations. :class:`~physicsnemo.diffusion.Predictor` : Protocol satisfied by this class. :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.get_denoiser` : Converts the score-predictor to a denoiser for sampling. Examples -------- **Example 1:** Basic usage with a single guidance for inpainting: >>> import torch >>> from physicsnemo.diffusion.guidance import DPSScorePredictor, DPSGuidance >>> >>> # Toy x0-predictor (in practice, this is a trained neural network) >>> x0_predictor = lambda x, t: x * 0.9 >>> >>> # Simple x0_to_score function (for EDM: score = (x_0 - x) / t^2) >>> def x0_to_score_fn(x_0, x, t): ... t_bc = t.reshape(-1, *([1] * (x.ndim - 1))) ... return (x_0 - x) / (t_bc ** 2) ... >>> # Simple inpainting guidance >>> class InpaintGuidance: ... def __init__(self, mask, y_obs, gamma=1.0): ... self.mask = mask ... self.y_obs = y_obs ... self.gamma = gamma ... def __call__(self, x, t, x_0): ... return -self.gamma * self.mask * (x_0 - self.y_obs) ... >>> mask = torch.ones(1, 3, 8, 8) >>> y_obs = torch.randn(1, 3, 8, 8) >>> guidance = InpaintGuidance(mask, y_obs) >>> >>> # Create DPS score predictor >>> dps_score_pred = DPSScorePredictor( ... x0_predictor=x0_predictor, ... x0_to_score_fn=x0_to_score_fn, ... guidances=guidance, ... ) >>> >>> # Use in sampling >>> x = torch.randn(1, 3, 8, 8) >>> t = torch.tensor([1.0]) >>> output = dps_score_pred(x, t) >>> output.shape torch.Size([1, 3, 8, 8]) **Example 2:** Multiple guidances for multi-constraint problems: >>> import torch >>> from physicsnemo.diffusion.guidance import DPSScorePredictor >>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler >>> >>> # Use scheduler to get x0_to_score_fn >>> scheduler = EDMNoiseScheduler() >>> x0_predictor = lambda x, t: x * 0.9 >>> >>> # Guidance 1: match observed values at specific locations >>> class ObservationGuidance: ... def __init__(self, mask, y_obs, gamma=1.0): ... self.mask = mask ... self.y_obs = y_obs ... self.gamma = gamma ... def __call__(self, x, t, x_0): ... return -self.gamma * self.mask * (x_0 - self.y_obs) ... >>> # Guidance 2: regularization toward zero mean >>> class ZeroMeanGuidance: ... def __init__(self, gamma=0.1): ... self.gamma = gamma ... def __call__(self, x, t, x_0): ... return -self.gamma * x_0.mean() * torch.ones_like(x_0) ... >>> mask = torch.ones(1, 3, 8, 8) >>> y_obs = torch.randn(1, 3, 8, 8) >>> guidance1 = ObservationGuidance(mask, y_obs) >>> guidance2 = ZeroMeanGuidance() >>> >>> # Combine multiple guidances >>> dps_score_pred = DPSScorePredictor( ... x0_predictor=x0_predictor, ... x0_to_score_fn=scheduler.x0_to_score, ... guidances=[guidance1, guidance2], ... ) >>> >>> x = torch.randn(2, 3, 8, 8) >>> t = torch.tensor([1.0, 1.0]) >>> output = dps_score_pred(x, t) >>> output.shape torch.Size([2, 3, 8, 8]) **Example 3:** Multiple autograd-based guidances require ``retain_graph=True`` on all but the last: >>> import torch >>> from physicsnemo.diffusion.guidance import ( ... DPSScorePredictor, ... DataConsistencyDPSGuidance, ... ) >>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler >>> >>> scheduler = EDMNoiseScheduler() >>> x0_predictor = lambda x, t: x * 0.9 >>> >>> mask1 = torch.zeros(1, 3, 8, 8, dtype=torch.bool) >>> mask1[:, :, 2, 3] = True >>> mask2 = torch.zeros(1, 3, 8, 8, dtype=torch.bool) >>> mask2[:, :, 5, 6] = True >>> y_obs = torch.randn(1, 3, 8, 8) >>> >>> # First guidance retains the graph for the second one >>> g1 = DataConsistencyDPSGuidance( ... mask=mask1, y=y_obs, std_y=0.1, retain_graph=True, ... ) >>> # Last guidance does not need retain_graph >>> g2 = DataConsistencyDPSGuidance( ... mask=mask2, y=y_obs, std_y=0.1, ... ) >>> >>> dps = DPSScorePredictor( ... x0_predictor=x0_predictor, ... x0_to_score_fn=scheduler.x0_to_score, ... guidances=[g1, g2], ... ) >>> x = torch.randn(1, 3, 8, 8) >>> t = torch.tensor([1.0]) >>> dps(x, t).shape torch.Size([1, 3, 8, 8]) """ def __init__( self, x0_predictor: Predictor, x0_to_score_fn: Callable[ [Float[Tensor, " B *dims"], Float[Tensor, " B *dims"], Float[Tensor, " B"]], Float[Tensor, " B *dims"], ], guidances: DPSGuidance | Sequence[DPSGuidance], ) -> None: self.x0_predictor = x0_predictor self.x0_to_score_fn = x0_to_score_fn # Normalize guidances to a list if isinstance(guidances, Sequence) and not isinstance(guidances, str): self.guidances = list(guidances) else: self.guidances = [guidances] def __call__( self, x: Float[Tensor, " B *dims"], t: Float[Tensor, " B"], ) -> Float[Tensor, " B *dims"]: r""" Compute the guided score for sampling. Parameters ---------- x : Tensor Noisy latent state at diffusion time ``t``, of shape :math:`(B, *)`. t : Tensor Batched diffusion time of shape :math:`(B,)`. Returns ------- Tensor Guided score of same shape :math:`(B, *)` as ``x``. Computed as the sum of the unconditional score and all guidance terms. """ if not torch.compiler.is_compiling() and torch.is_inference_mode_enabled(): raise RuntimeError( "DPSScorePredictor requires autograd but torch inference mode " "is enabled. Wrap the calling code with " "'with torch.inference_mode(False):' or 'with torch.no_grad():' " "instead." ) x = x.detach().requires_grad_(True) with torch.enable_grad(): x_0 = self.x0_predictor(x, t) guidance_sum = torch.zeros_like(x) for guidance in self.guidances: guidance_sum += guidance(x, t, x_0) score = self.x0_to_score_fn(x_0, x, t) return score + guidance_sum
[docs] class ModelConsistencyDPSGuidance(DPSGuidance): r""" DPS guidance for generic observation models with Gaussian noise. Implements the :class:`DPSGuidance` interface for generic (possibly nonlinear) observation operators. Computes the likelihood score assuming Gaussian measurement noise with standard deviation ``std_y``. The guidance term is: .. math:: \nabla_{\mathbf{x}} \log p(\mathbf{y} | \mathbf{x}_t) = -\frac{1}{2 \left( \sigma_y^2 + \Gamma \frac{\sigma(t)^2}{\alpha(t)^2} \right)} \nabla_{\mathbf{x}} \| A\left(\hat{\mathbf{x}}_0\right) - \mathbf{y} \|^2 where :math:`A` is the observation operator and the scaling incorporates a Score-Based Data Assimilation (SDA) correction through the parameter :math:`\Gamma` that accounts for the covariance of the :math:`\hat{\mathbf{x}}_0(\mathbf{x}_t, t)` estimate at different diffusion times. The L2 norm can be replaced by other Lp norms or custom loss functions via the ``norm`` parameter. The ``observation_operator`` must be a differentiable callable with the following signature: .. code-block:: python def observation_operator( x_0: Tensor, # shape: (B, *dims) ) -> Tensor: ... # predicted observations, shape: (B, *obs_dims) When ``norm`` is a callable, it must have the following signature: .. code-block:: python def norm( y_pred: Tensor, # shape: (B, *obs_dims) y_true: Tensor, # shape: (B, *obs_dims) ) -> Tensor: ... # scalar loss per batch element, shape: (B,) Parameters ---------- observation_operator : Callable[[Tensor], Tensor] Observation operator mapping clean state to observations. Must be differentiable (supports ``torch.autograd``). y : Tensor Observed data of shape :math:`(B, *obs\_dims)` matching the output of ``A``. std_y : float Standard deviation of the measurement noise :math:`\sigma_y`. norm : int | Callable[[Tensor, Tensor], Tensor], default=2 Loss function used to compute the residual. An ``int`` value (default ``2``) uses the corresponding Lp norm. A callable receives ``(y_pred, y_true)`` and returns a scalar loss per batch element of shape :math:`(B,)`. gamma : float, default=0.0 SDA covariance scaling factor :math:`\Gamma`. When ``gamma > 0``, applies SDA correction that accounts for the covariance of the :math:`\hat{\mathbf{x}}_0` estimate at different noise levels. Set to ``0`` for classical DPS without SDA scaling. sigma_fn : Callable[[Tensor], Tensor] | None, default=None Function mapping diffusion time to noise level :math:`\sigma(t)`. Required when ``gamma > 0``. Typically obtained from a noise scheduler, e.g., :meth:`~physicsnemo.diffusion.noise_schedulers.LinearGaussianNoiseScheduler.sigma` for a linear-Gaussian noise schedule. alpha_fn : Callable[[Tensor], Tensor] | None, default=None Function mapping diffusion time to signal coefficient :math:`\alpha(t)`. Optional; defaults to :math:`\alpha(t) = 1` if not provided. Typically obtained from a noise scheduler, e.g., :meth:`~physicsnemo.diffusion.noise_schedulers.LinearGaussianNoiseScheduler.alpha` for a linear-Gaussian noise schedule. retain_graph : bool, default=False If ``True``, the computational graph is retained after computing gradients. Required when combining multiple autograd-based guidances in a single :class:`DPSScorePredictor` — all guidances except the last must set this to ``True``. create_graph : bool, default=False If ``True``, a graph of the derivative is constructed, allowing higher-order derivatives (e.g., differentiating through the entire sampling process). Note ---- References: - DPS: `Diffusion Posterior Sampling for General Noisy Inverse Problems <https://arxiv.org/abs/2209.14687>`_ - SDA: `Score-based Data Assimilation <https://arxiv.org/abs/2306.10574>`_ See Also -------- :class:`DataConsistencyDPSGuidance` : Simplified guidance for masked observations. :class:`DPSScorePredictor` : Combines an x0-predictor with one or more guidances. Examples -------- **Example 1:** Super-resolution with a nonlinear blur + downsampling operator: >>> import torch >>> import torch.nn.functional as F >>> from physicsnemo.diffusion.guidance import ( ... ModelConsistencyDPSGuidance, ... DPSScorePredictor, ... ) >>> >>> # Observation operator: Gaussian blur + 2x downsampling >>> def blur_downsample(x): ... # Apply 3x3 Gaussian-like blur ... kernel = torch.ones(1, 1, 3, 3, device=x.device) / 9 ... kernel = kernel.expand(x.shape[1], 1, 3, 3) ... blurred = F.conv2d(x, kernel, padding=1, groups=x.shape[1]) ... # Downsample 2x ... return F.avg_pool2d(blurred, kernel_size=2, stride=2) ... >>> # Low-resolution observations (4x4 from 8x8 high-res) >>> y_obs = torch.randn(1, 3, 4, 4) >>> >>> guidance = ModelConsistencyDPSGuidance( ... observation_operator=blur_downsample, ... y=y_obs, ... std_y=0.1, ... ) >>> >>> # Use in DPS sampling >>> x = torch.randn(1, 3, 8, 8, requires_grad=True) >>> t = torch.tensor([1.0]) >>> x_0 = x * 0.9 # Toy x0 estimate >>> output = guidance(x, t, x_0) >>> output.shape torch.Size([1, 3, 8, 8]) >>> >>> # Combine with DPSScorePredictor for complete sampling workflow >>> x0_predictor = lambda x, t: x * 0.9 >>> def x0_to_score_fn(x_0, x, t): ... t_bc = t.reshape(-1, *([1] * (x.ndim - 1))) ... return (x_0 - x) / (t_bc ** 2) ... >>> dps_score_pred = DPSScorePredictor( ... x0_predictor=x0_predictor, ... x0_to_score_fn=x0_to_score_fn, ... guidances=guidance, ... ) >>> score = dps_score_pred(x, t) >>> score.shape torch.Size([1, 3, 8, 8]) **Example 2:** With SDA scaling using noise scheduler methods: >>> import torch >>> from physicsnemo.diffusion.guidance import ( ... ModelConsistencyDPSGuidance, ... DPSScorePredictor, ... ) >>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler >>> >>> scheduler = EDMNoiseScheduler() >>> >>> # Linear observation operator (select first channel) >>> A = lambda x: x[:, :1] >>> y_obs = torch.randn(1, 1, 8, 8) >>> >>> # Enable SDA scaling with gamma > 0, providing sigma and alpha functions >>> guidance = ModelConsistencyDPSGuidance( ... observation_operator=A, ... y=y_obs, ... std_y=0.075, ... gamma=0.05, # Enable SDA scaling ... sigma_fn=scheduler.sigma, ... alpha_fn=scheduler.alpha, ... ) >>> >>> x = torch.randn(1, 3, 8, 8, requires_grad=True) >>> t = torch.tensor([1.0]) >>> x_0 = x * 0.9 >>> output = guidance(x, t, x_0) >>> output.shape torch.Size([1, 3, 8, 8]) >>> >>> # Use with DPSScorePredictor and scheduler's x0_to_score >>> x0_predictor = lambda x, t: x * 0.9 >>> dps_score_pred = DPSScorePredictor( ... x0_predictor=x0_predictor, ... x0_to_score_fn=scheduler.x0_to_score, ... guidances=guidance, ... ) >>> score = dps_score_pred(x, t) >>> score.shape torch.Size([1, 3, 8, 8]) **Example 3:** With a custom loss function (Huber loss): >>> import torch >>> import torch.nn.functional as F >>> from physicsnemo.diffusion.guidance import ModelConsistencyDPSGuidance >>> >>> # Wrap torch's Huber loss to return per-batch scalars >>> def huber_loss(y_pred, y_true): ... per_elem = F.huber_loss(y_pred, y_true, reduction="none") ... return per_elem.reshape(y_pred.shape[0], -1).sum(dim=1) ... >>> A = lambda x: x[:, :1] # Select first channel >>> y_obs = torch.randn(1, 1, 8, 8) >>> >>> guidance = ModelConsistencyDPSGuidance( ... observation_operator=A, ... y=y_obs, ... std_y=0.1, ... norm=huber_loss, # Custom loss function ... ) >>> >>> x = torch.randn(1, 3, 8, 8, requires_grad=True) >>> t = torch.tensor([1.0]) >>> x_0 = x * 0.9 >>> output = guidance(x, t, x_0) >>> output.shape torch.Size([1, 3, 8, 8]) """ def __init__( self, observation_operator: Callable[ [Float[Tensor, " B *dims"]], Float[Tensor, " B *obs_dims"] ], y: Float[Tensor, " B *obs_dims"], std_y: float, norm: int | Callable[ [Float[Tensor, " B *obs_dims"], Float[Tensor, " B *obs_dims"]], Float[Tensor, " B"], ] = 2, gamma: float = 0.0, sigma_fn: Callable[[Float[Tensor, " *shape"]], Float[Tensor, " *shape"]] | None = None, alpha_fn: Callable[[Float[Tensor, " *shape"]], Float[Tensor, " *shape"]] | None = None, retain_graph: bool = False, create_graph: bool = False, ) -> None: if gamma > 0 and sigma_fn is None: raise ValueError("sigma_fn must be provided when gamma > 0") self.observation_operator = observation_operator self.y = y self.std_y = std_y self.norm = norm self.gamma = gamma self.sigma_fn = ( sigma_fn if sigma_fn is not None else lambda t: torch.zeros_like(t) ) self.alpha_fn = ( alpha_fn if alpha_fn is not None else lambda t: torch.ones_like(t) ) self.retain_graph = retain_graph self.create_graph = create_graph def __call__( self, x: Float[Tensor, " B *dims"], t: Float[Tensor, " B"], x_0: Float[Tensor, " B *dims"], ) -> Float[Tensor, " B *dims"]: r""" Compute the likelihood score guidance term. Parameters ---------- x : Tensor Noisy latent state :math:`\mathbf{x}_t`, of shape :math:`(B, *)`. Must have ``requires_grad=True`` and be part of a computational graph connecting to ``x_0``. Its ``dtype`` and ``device`` determine those of all internal computations. t : Tensor Batched diffusion time of shape :math:`(B,)`. x_0 : Tensor Estimate of the clean latent state :math:`\hat{\mathbf{x}}_0 (\mathbf{x}_t, t)`, with same shape as ``x``. Must be computed from ``x`` via an x0-predictor to allow gradient backpropagation. Returns ------- Tensor Likelihood score guidance term of same shape as ``x``. """ if not torch.compiler.is_compiling() and torch.is_inference_mode_enabled(): raise RuntimeError( "ModelConsistencyDPSGuidance requires autograd but torch " "inference mode is enabled. Wrap the calling code with " "'with torch.inference_mode(False):' or 'with torch.no_grad():' " "instead." ) y = self.y.to(dtype=x.dtype, device=x.device) with torch.enable_grad(): # Compute predicted observations y_pred = self.observation_operator(x_0) # Compute loss if callable(self.norm): loss = self.norm(y_pred, y) else: residual = (y_pred - y).reshape(y_pred.shape[0], -1) loss = residual.abs().pow(self.norm).sum(dim=1) # Compute gradient of loss w.r.t. x (backprop through x_0) grad_x = torch.autograd.grad( outputs=loss.sum(), inputs=x, retain_graph=self.retain_graph, create_graph=self.create_graph, )[0] # Compute scaling factor t_bc = t.reshape(-1, *([1] * (x.ndim - 1))) sigma_t = self.sigma_fn(t_bc) alpha_t = self.alpha_fn(t_bc) variance = self.std_y**2 + self.gamma * (sigma_t**2) / (alpha_t**2) return -grad_x / (2 * variance)
[docs] class DataConsistencyDPSGuidance(DPSGuidance): r""" DPS guidance for masked observations with Gaussian noise. Implements the :class:`DPSGuidance` interface for masked observation operators, a simplified version of :class:`ModelConsistencyDPSGuidance`. This is typical for data assimilation tasks like inpainting, outpainting, or sparse observations, where measurements are available at specific locations. Computes the likelihood score assuming Gaussian measurement noise with standard deviation ``std_y``. The guidance term is: .. math:: \nabla_{\mathbf{x}} \log p(\mathbf{y} | \mathbf{x}_t) = -\frac{1}{2 \left( \sigma_y^2 + \Gamma \frac{\sigma(t)^2}{\alpha(t)^2} \right)} \nabla_{\mathbf{x}} \| \mathbf{M} \odot (\hat{\mathbf{x}}_0 - \mathbf{y}) \|^2 where :math:`\mathbf{M}` is a binary mask (1 = observed, 0 = missing), :math:`\odot` denotes element-wise multiplication, and the scaling incorporates an SDA correction through the parameter :math:`\Gamma`. The L2 norm can be replaced by other Lp norms or custom loss functions via the ``norm`` parameter. When ``norm`` is a callable, it must have the following signature: .. code-block:: python def norm( y_pred: Tensor, # shape: (B, *obs_dims) y_true: Tensor, # shape: (B, *obs_dims) ) -> Tensor: ... # scalar loss per batch element, shape: (B,) Parameters ---------- mask : Tensor Boolean mask of shape :math:`(B, *)` matching the state shape. ``True`` for observed locations, ``False`` for missing. y : Tensor Observed data of shape :math:`(B, *)` matching the state shape. Values at unobserved locations (where ``mask=0``) are ignored. std_y : float Standard deviation of the measurement noise :math:`\sigma_y`. norm : int | Callable[[Tensor, Tensor], Tensor], default=2 Loss function used to compute the residual. An ``int`` value (default ``2``) uses the corresponding Lp norm. A callable receives ``(mask * x_0, mask * y)`` and returns a scalar loss per batch element of shape :math:`(B,)`. gamma : float, default=0.0 SDA covariance scaling factor :math:`\Gamma`. When ``gamma > 0``, applies SDA correction that accounts for the covariance of the :math:`\hat{\mathbf{x}}_0` estimate at different noise levels. Set to ``0`` for classical DPS without SDA scaling. sigma_fn : Callable[[Tensor], Tensor] | None, default=None Function mapping diffusion time to noise level :math:`\sigma(t)`. Required when ``gamma > 0``. Typically obtained from a noise scheduler. For example, use :meth:`~physicsnemo.diffusion.noise_schedulers.LinearGaussianNoiseScheduler.sigma` for a linear-Gaussian noise schedule. alpha_fn : Callable[[Tensor], Tensor] | None, default=None Function mapping diffusion time to signal coefficient :math:`\alpha(t)`. Optional; defaults to :math:`\alpha(t) = 1` if not provided. For example, use :meth:`~physicsnemo.diffusion.noise_schedulers.LinearGaussianNoiseScheduler.alpha` for a linear-Gaussian noise schedule. retain_graph : bool, default=False If ``True``, the computational graph is retained after computing gradients. Required when combining multiple autograd-based guidances in a single :class:`DPSScorePredictor` — all guidances except the last must set this to ``True``. create_graph : bool, default=False If ``True``, a graph of the derivative is constructed, allowing higher-order derivatives (e.g., differentiating through the entire sampling process). Note ---- References: - DPS: `Diffusion Posterior Sampling for General Noisy Inverse Problems <https://arxiv.org/abs/2209.14687>`_ - SDA: `Score-based Data Assimilation <https://arxiv.org/abs/2306.10574>`_ See Also -------- :class:`ModelConsistencyDPSGuidance` : Guidance for general observation operators. :class:`DPSScorePredictor` : Combines an x0-predictor with one or more guidances. Examples -------- **Example 1:** Sparse observations at probe locations: >>> import torch >>> from physicsnemo.diffusion.guidance import ( ... DataConsistencyDPSGuidance, ... DPSScorePredictor, ... ) >>> >>> # Boolean mask: only observe a few probe locations >>> mask = torch.zeros(1, 3, 8, 8, dtype=torch.bool) >>> mask[:, :, 2, 3] = True # Probe at (2, 3) >>> mask[:, :, 5, 6] = True # Probe at (5, 6) >>> mask[:, :, 1, 7] = True # Probe at (1, 7) >>> y_obs = torch.randn(1, 3, 8, 8) # Observed values >>> >>> guidance = DataConsistencyDPSGuidance( ... mask=mask, ... y=y_obs, ... std_y=0.1, ... ) >>> >>> x = torch.randn(1, 3, 8, 8, requires_grad=True) >>> t = torch.tensor([1.0]) >>> x_0 = x * 0.9 # Toy x0 estimate (must be computed from x) >>> output = guidance(x, t, x_0) >>> output.shape torch.Size([1, 3, 8, 8]) >>> >>> # Use with DPSScorePredictor for complete sampling workflow >>> x0_predictor = lambda x, t: x * 0.9 >>> def x0_to_score_fn(x_0, x, t): ... t_bc = t.reshape(-1, *([1] * (x.ndim - 1))) ... return (x_0 - x) / (t_bc ** 2) ... >>> dps_score_pred = DPSScorePredictor( ... x0_predictor=x0_predictor, ... x0_to_score_fn=x0_to_score_fn, ... guidances=guidance, ... ) >>> score = dps_score_pred(x, t) >>> score.shape torch.Size([1, 3, 8, 8]) **Example 2:** With SDA scaling and L1 norm using noise scheduler: >>> import torch >>> from physicsnemo.diffusion.guidance import ( ... DataConsistencyDPSGuidance, ... DPSScorePredictor, ... ) >>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler >>> >>> scheduler = EDMNoiseScheduler() >>> >>> # Same sparse probe locations as Example 1 >>> mask = torch.zeros(1, 3, 8, 8, dtype=torch.bool) >>> mask[:, :, 2, 3] = True >>> mask[:, :, 5, 6] = True >>> mask[:, :, 1, 7] = True >>> y_obs = torch.randn(1, 3, 8, 8) >>> >>> # Enable SDA scaling and use L1 norm for robustness >>> guidance = DataConsistencyDPSGuidance( ... mask=mask, ... y=y_obs, ... std_y=0.075, ... norm=1, # L1 norm ... gamma=1.0, # Enable SDA scaling ... sigma_fn=scheduler.sigma, ... alpha_fn=scheduler.alpha, ... ) >>> >>> x = torch.randn(1, 3, 8, 8, requires_grad=True) >>> t = torch.tensor([1.0]) >>> x_0 = x * 0.9 # Must be computed from x >>> output = guidance(x, t, x_0) >>> output.shape torch.Size([1, 3, 8, 8]) >>> >>> # Use with DPSScorePredictor and scheduler's x0_to_score >>> x0_predictor = lambda x, t: x * 0.9 >>> dps_score_pred = DPSScorePredictor( ... x0_predictor=x0_predictor, ... x0_to_score_fn=scheduler.x0_to_score, ... guidances=guidance, ... ) >>> score = dps_score_pred(x, t) >>> score.shape torch.Size([1, 3, 8, 8]) **Example 3:** With a custom loss function (Huber loss): >>> import torch >>> import torch.nn.functional as F >>> from physicsnemo.diffusion.guidance import DataConsistencyDPSGuidance >>> >>> # Wrap torch's Huber loss to return per-batch scalars >>> def huber_loss(y_pred, y_true): ... per_elem = F.huber_loss(y_pred, y_true, reduction="none") ... return per_elem.reshape(y_pred.shape[0], -1).sum(dim=1) ... >>> mask = torch.zeros(1, 3, 8, 8, dtype=torch.bool) >>> mask[:, :, 2, 3] = True >>> mask[:, :, 5, 6] = True >>> y_obs = torch.randn(1, 3, 8, 8) >>> >>> guidance = DataConsistencyDPSGuidance( ... mask=mask, ... y=y_obs, ... std_y=0.1, ... norm=huber_loss, # Custom loss function ... ) >>> >>> x = torch.randn(1, 3, 8, 8, requires_grad=True) >>> t = torch.tensor([1.0]) >>> x_0 = x * 0.9 >>> output = guidance(x, t, x_0) >>> output.shape torch.Size([1, 3, 8, 8]) """ def __init__( self, mask: Bool[Tensor, " B *dims"], y: Float[Tensor, " B *dims"], std_y: float, norm: int | Callable[ [Float[Tensor, "B *dims"], Float[Tensor, "B *dims"]], # noqa: F821 Float[Tensor, " B"], ] = 2, gamma: float = 0.0, sigma_fn: Callable[[Float[Tensor, " *shape"]], Float[Tensor, " *shape"]] | None = None, alpha_fn: Callable[[Float[Tensor, " *shape"]], Float[Tensor, " *shape"]] | None = None, retain_graph: bool = False, create_graph: bool = False, ) -> None: if gamma > 0 and sigma_fn is None: raise ValueError("sigma_fn must be provided when gamma > 0") self.mask = mask.float() self.y = y self.std_y = std_y self.norm = norm self.gamma = gamma self.sigma_fn = ( sigma_fn if sigma_fn is not None else lambda t: torch.zeros_like(t) ) self.alpha_fn = ( alpha_fn if alpha_fn is not None else lambda t: torch.ones_like(t) ) self.retain_graph = retain_graph self.create_graph = create_graph def __call__( self, x: Float[Tensor, " B *dims"], t: Float[Tensor, " B"], x_0: Float[Tensor, " B *dims"], ) -> Float[Tensor, " B *dims"]: r""" Compute the likelihood score guidance term. Parameters ---------- x : Tensor Noisy latent state :math:`\mathbf{x}_t`, of shape :math:`(B, *)`. Must have ``requires_grad=True`` and be part of a computational graph connecting to ``x_0``. Its ``dtype`` and ``device`` determine those of all internal computations. t : Tensor Batched diffusion time of shape :math:`(B,)`. x_0 : Tensor Estimate of the clean latent state :math:`\hat{\mathbf{x}}_0 (\mathbf{x}_t, t)`, with same shape as ``x``. Must be computed from ``x`` via an x0-predictor to allow gradient backpropagation. Returns ------- Tensor Likelihood score guidance term of same shape as ``x``. """ if not torch.compiler.is_compiling() and torch.is_inference_mode_enabled(): raise RuntimeError( "DataConsistencyDPSGuidance requires autograd but torch " "inference mode is enabled. Wrap the calling code with " "'with torch.inference_mode(False):' or 'with torch.no_grad():' " "instead." ) mask = self.mask.to(dtype=x.dtype, device=x.device) y = self.y.to(dtype=x.dtype, device=x.device) with torch.enable_grad(): # Compute masked predicted and observed values y_pred = mask * x_0 y_true = mask * y # Compute loss if callable(self.norm): loss = self.norm(y_pred, y_true) else: residual = (y_pred - y_true).reshape(x_0.shape[0], -1) loss = residual.abs().pow(self.norm).sum(dim=1) # Compute gradient of loss w.r.t. x (backprop through x_0) grad_x = torch.autograd.grad( outputs=loss.sum(), inputs=x, retain_graph=self.retain_graph, create_graph=self.create_graph, )[0] # Compute scaling factor t_bc = t.reshape(-1, *([1] * (x.ndim - 1))) sigma_t = self.sigma_fn(t_bc) alpha_t = self.alpha_fn(t_bc) variance = self.std_y**2 + self.gamma * (sigma_t**2) / (alpha_t**2) return -grad_x / (2 * variance)