Source code for physicsnemo.diffusion.samplers.solvers

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ODE/SDE solvers for diffusion model sampling."""

import math
from typing import Callable, Protocol, runtime_checkable

import torch
from jaxtyping import Float
from torch import Tensor

from physicsnemo.diffusion.base import Denoiser


[docs] @runtime_checkable class Solver(Protocol): r""" Protocol defining the interface for diffusion solvers. A solver implements a numerical method to integrate the diffusion process from a noisy state to a less noisy (or clean) state. Each call to :meth:`step` advances the state from time ``t_cur`` (:math:`t_n`) to ``t_next`` (:math:`t_{n-1}`). This is the minimal interface required for sampling from a diffusion model, and any object that implements this interface can be used as a solver in sampling utilities. The update rule applied by the sampler is roughly: .. math:: \mathbf{x}_{n-1} = \text{Step}(F(\mathbf{x}_n, t_n); \mathbf{x}_n, t_n, t_{n-1}) where :math:`F` is the denoiser (e.g. the right hand side in the case of ODE/SDE-based sampling, the denoised latent state in the case of discrete Markov chain-based sampling, etc.) and :math:`\text{Step}` is the update rule of the solver, implemented by the :meth:`step` method. See Also -------- :func:`~physicsnemo.diffusion.samplers.sample` : The sampling function that uses solvers to generate samples. Examples -------- >>> import torch >>> from physicsnemo.diffusion.samplers.solvers import Solver >>> >>> class SimpleEuler: ... def __init__(self, denoiser): ... self.denoiser = denoiser ... def step(self, x, t_cur, t_next): ... d = (x - self.denoiser(x, t_cur)) / t_cur ... return x + (t_next - t_cur) * d ... >>> denoiser = lambda x, t: x / (1 + t.view(-1, 1)**2) # Toy denoiser >>> solver = SimpleEuler(denoiser) >>> isinstance(solver, Solver) True """
[docs] def step( self, x: Float[Tensor, " B *dims"], t_cur: Float[Tensor, " B"], t_next: Float[Tensor, " B"], ) -> Float[Tensor, " B *dims"]: r""" Perform one integration step from ``t_cur`` to ``t_next``. Parameters ---------- x : Tensor Current noisy latent state :math:`\mathbf{x}_{n}` of shape :math:`(B, *)` where :math:`B` is the batch size. t_cur : Tensor Current diffusion time :math:`t_n` of shape :math:`(B,)`. t_next : Tensor Target diffusion time :math:`t_{n-1}` of shape :math:`(B,)`. Returns ------- Tensor Updated latent state :math:`\mathbf{x}_{n-1}` at time ``t_next``, same shape as ``x``. """ ...
[docs] class EulerSolver(Solver): r""" First-order Euler solver for diffusion ODEs. This is a fast solver with one denoiser evaluation per step, but typically produces lower quality samples compared to higher-order methods. Parameters ---------- denoiser : Denoiser A callable implementing the :class:`~physicsnemo.diffusion.Denoiser` interface. Here it is expected to return the right hand side of the ODE. Typically obtained via :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.get_denoiser`, but any callable with the correct signature can be used. Examples -------- >>> import torch >>> from physicsnemo.diffusion.samplers.solvers import EulerSolver >>> >>> denoiser = lambda x, t: x / (1 + t.view(-1, 1, 1, 1)**2) # Toy denoiser >>> solver = EulerSolver(denoiser) >>> x_t = torch.randn(1, 3, 8, 8) >>> t_cur = torch.tensor([1.0]) >>> t_next = torch.tensor([0.5]) >>> x_tm1 = solver.step(x_t, t_cur, t_next) >>> x_tm1.shape torch.Size([1, 3, 8, 8]) >>> isinstance(solver, Solver) True """ def __init__(self, denoiser: Denoiser) -> None: self.denoiser = denoiser
[docs] def step( self, x: Float[Tensor, " B *dims"], t_cur: Float[Tensor, " B"], t_next: Float[Tensor, " B"], ) -> Float[Tensor, " B *dims"]: r""" Perform one Euler integration step. Parameters ---------- x : Tensor Current noisy latent state :math:`\mathbf{x}_{n}` of shape :math:`(B, *)` where :math:`B` is the batch size. t_cur : Tensor Current diffusion time :math:`t_n` of shape :math:`(B,)`. t_next : Tensor Target diffusion time :math:`t_{n-1}` of shape :math:`(B,)`. Returns ------- Tensor Updated latent state :math:`\mathbf{x}_{n-1}` at time ``t_next``, same shape as ``x``. """ # Ensure contiguous strides so successive denoiser calls (across # sampling steps) present the same stride layout to torch.compile, # avoiding spurious recompilations / silently divergent traces. t_cur = t_cur.contiguous() t_next = t_next.contiguous() # Reshape t for broadcasting: (B,) -> (B, 1, ..., 1) expected_shape = (-1,) + (1,) * (x.ndim - 1) t_cur_bc = t_cur.reshape(expected_shape) t_next_bc = t_next.reshape(expected_shape) # RHS evaluation and step update d_cur = self.denoiser(x, t_cur) x_next = x + (t_next_bc - t_cur_bc) * d_cur return x_next
[docs] class HeunSolver(Solver): r""" Second-order Heun solver for diffusion ODEs. This method requires two denoiser evaluations per step but usually produces higher quality samples than :class:`EulerSolver`. Parameters ---------- denoiser : Denoiser A callable implementing the :class:`~physicsnemo.diffusion.Denoiser` interface. Here it is expected to return the right hand side of the ODE. Typically obtained via :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.get_denoiser`, but any callable with the correct signature can be used. alpha : float, optional Interpolation parameter for the corrector step, must be in (0, 1]. ``alpha=1`` gives the standard Heun method (trapezoidal rule), ``alpha=0.5`` gives the midpoint method. By default 1. Examples -------- >>> import torch >>> from physicsnemo.diffusion.samplers.solvers import HeunSolver >>> >>> denoiser = lambda x, t: x / (1 + t.view(-1, 1, 1, 1)**2) # Toy denoiser >>> solver = HeunSolver(denoiser) >>> x_t = torch.randn(1, 3, 8, 8) >>> t_cur = torch.tensor([1.0]) >>> t_next = torch.tensor([0.5]) >>> x_tm1 = solver.step(x_t, t_cur, t_next) >>> x_tm1.shape torch.Size([1, 3, 8, 8]) """ def __init__( self, denoiser: Denoiser, alpha: float = 1.0, ) -> None: self.denoiser = denoiser if not 0 < alpha <= 1: raise ValueError(f"alpha must be in (0, 1], got {alpha}") self.alpha = alpha
[docs] def step( self, x: Float[Tensor, " B *dims"], t_cur: Float[Tensor, " B"], t_next: Float[Tensor, " B"], ) -> Float[Tensor, " B *dims"]: r""" Perform one Heun integration step. Parameters ---------- x : Tensor Current noisy latent state :math:`\mathbf{x}_n` of shape :math:`(B, *)` where :math:`B` is the batch size. t_cur : Tensor Current diffusion time :math:`t_n` of shape :math:`(B,)`. t_next : Tensor Target diffusion time :math:`t_{n-1}` of shape :math:`(B,)`. Returns ------- Tensor Updated latent state :math:`\mathbf{x}_{n-1}` at time ``t_next``, same shape as ``x``. """ # Ensure contiguous strides so that both denoiser calls (at t_cur # and at the intermediate t_prime) present the same stride layout # to torch.compile, avoiding spurious recompilations. t_cur = t_cur.contiguous() t_next = t_next.contiguous() # Reshape t for broadcasting: (B,) -> (B, 1, ..., 1) expected_shape = (-1,) + (1,) * (x.ndim - 1) t_cur_bc = t_cur.reshape(expected_shape) t_next_bc = t_next.reshape(expected_shape) h_bc = t_next_bc - t_cur_bc # First RHS evaluation d_cur = self.denoiser(x, t_cur) # Predictor step to intermediate point t_prime_bc = t_cur_bc + self.alpha * h_bc x_prime = x + self.alpha * h_bc * d_cur # Mask for elements where t_next != 0 (need 2nd order correction) # Shape: (B, 1, ..., 1) for broadcasting mask_bc = (t_next_bc != 0).float() # Second RHS evaluation (compute everywhere, masked later) # Avoid division by zero in denoiser by using t_cur where t_prime is 0 t_prime = t_prime_bc.reshape(x.shape[0]) t_prime_safe = torch.where(t_prime == 0, t_cur, t_prime) d_prime = self.denoiser(x_prime, t_prime_safe) # Apply 2nd order correction only where t_next != 0 # Where t_next == 0, use first-order Euler step w_cur = 1 - 1 / (2 * self.alpha) w_prime = 1 / (2 * self.alpha) x_euler = x + h_bc * d_cur x_heun = x + h_bc * (w_cur * d_cur + w_prime * d_prime) x_next = mask_bc * x_heun + (1 - mask_bc) * x_euler return x_next
[docs] class EDMStochasticEulerSolver(Solver): r""" First-order stochastic Euler sampler from the EDM paper. Implements stochastic sampling with configurable noise injection controlled by the "churn" parameters. .. important:: This is **not** a true SDE solver. It performs ad-hoc noise injection ("churn") at each step to improve sample diversity, but the underlying integration is still an ODE step. Therefore, the denoiser should return the right-hand side of the **ODE**, not the SDE. By default, noise injection is performed directly in time-step space. For linear-Gaussian noise schedules where diffusion time and noise level are not equal (e.g., VP schedule), provide ``sigma_fn`` and ``sigma_inv_fn`` to apply churn in noise-level space rather than time-step space. Optionally provide ``diffusion_fn`` to control the time-dependent magnitude of the injected noise. .. code-block:: python def sigma_fn( t: Tensor, # shape: (B,) or broadcastable ) -> Tensor: ... # noise level, same shape as t def sigma_inv_fn( sigma: Tensor, # shape: (B,) or broadcastable ) -> Tensor: ... # diffusion time, same shape as sigma def diffusion_fn( x: Tensor, # shape: (B, *dims) t: Tensor, # shape: (B,) ) -> Tensor: ... # g^2(x, t), broadcastable to shape of x Parameters ---------- denoiser : Denoiser A callable implementing the :class:`~physicsnemo.diffusion.Denoiser` interface. Should return the right-hand side of the **ODE** (not the SDE, since the stochastic noise injection is handled internally by this solver). Typically obtained via :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.get_denoiser` with ``denoising_type="ode"``. S_churn : float, optional Controls the amount of noise added at each step. Higher values add more stochasticity. By default 0 (deterministic), in which case this solver is equivalent to the deterministic :class:`EulerSolver`. S_min : float, optional Minimum diffusion time (or noise level if ``sigma_fn`` and ``sigma_inv_fn`` are provided) for applying churn. By default 0. S_max : float, optional Maximum diffusion time (or noise level if ``sigma_fn`` and ``sigma_inv_fn`` are provided) for applying churn. By default ``float("inf")``. S_noise : float, optional Noise scaling factor. Large values add more noise to the latent state. By default 1. num_steps : int, optional Total number of sampling steps, used to scale churn. By default 18. sigma_fn : Callable[[Tensor], Tensor] | None, optional Maps time to noise level :math:`\sigma(t)`. Useful for linear-Gaussian schedules where :math:`\sigma(t) \neq t`. Typically :meth:`~physicsnemo.diffusion.noise_schedulers.LinearGaussianNoiseScheduler.sigma`. If provided, ``sigma_inv_fn`` must also be provided. By default ``None`` (identity mapping). sigma_inv_fn : Callable[[Tensor], Tensor] | None, optional Maps noise level back to time. Typically :meth:`~physicsnemo.diffusion.noise_schedulers.LinearGaussianNoiseScheduler.sigma_inv`. If provided, ``sigma_fn`` must also be provided. By default ``None`` (identity mapping). diffusion_fn : Callable[[Tensor, Tensor], Tensor] | None, optional Controls the time-dependent magnitude of the injected noise, in addition of the ``S_noise`` scaling factor. Typically the squared diffusion coefficient :math:`g^2(\mathbf{x}, t)` from the reverse SDE, obtained from :meth:`~physicsnemo.diffusion.noise_schedulers.LinearGaussianNoiseScheduler.diffusion`. By default ``None`` (:math:`g^2 = 2t`), which corresponds to an EDM-like noise schedule. Note ---- Reference: `Elucidating the Design Space of Diffusion-Based Generative Models <https://arxiv.org/abs/2206.00364>`_ Examples -------- Basic usage with default parameters (noise injection in time-step space): >>> import torch >>> from physicsnemo.diffusion.samplers.solvers import ( ... EDMStochasticEulerSolver, ... ) >>> denoiser = lambda x, t: x / (1 + t.view(-1, 1, 1, 1)**2) # Toy denoiser >>> solver = EDMStochasticEulerSolver(denoiser, S_churn=40, num_steps=18) >>> x_t = torch.randn(1, 3, 8, 8) >>> t_cur = torch.tensor([1.0]) >>> t_next = torch.tensor([0.5]) >>> x_tm1 = solver.step(x_t, t_cur, t_next) >>> x_tm1.shape torch.Size([1, 3, 8, 8]) Using noise scheduler methods for linear-Gaussian schedules where :math:`\sigma(t) \neq t` (e.g., VP schedule). The callbacks map between time and noise level, allowing the churn to be applied in noise-level space before converting back to time-step space: >>> from physicsnemo.diffusion.noise_schedulers import VPNoiseScheduler >>> scheduler = VPNoiseScheduler() >>> num_steps = 10 >>> solver = EDMStochasticEulerSolver( ... denoiser, ... S_churn=40, ... num_steps=num_steps, ... sigma_fn=scheduler.sigma, ... sigma_inv_fn=scheduler.sigma_inv, ... diffusion_fn=scheduler.diffusion, ... ) >>> x_tm1 = solver.step(x_t, t_cur, t_next) >>> x_tm1.shape torch.Size([1, 3, 8, 8]) """ def __init__( self, denoiser: Denoiser, S_churn: float = 0, S_min: float = 0, S_max: float = float("inf"), S_noise: float = 1, num_steps: int = 18, sigma_fn: Callable[[Float[Tensor, " *shape"]], Float[Tensor, " *shape"]] | None = None, sigma_inv_fn: Callable[[Float[Tensor, " *shape"]], Float[Tensor, " *shape"]] | None = None, diffusion_fn: Callable[ [Float[Tensor, " B *dims"], Float[Tensor, " B"]], Float[Tensor, " B *_"] ] | None = None, ) -> None: self.denoiser = denoiser self.S_churn = S_churn self.S_min = S_min self.S_max = S_max self.S_noise = S_noise self.num_steps = num_steps # Validate sigma_fn and sigma_inv_fn if (sigma_fn is None) != (sigma_inv_fn is None): raise ValueError( "sigma_fn and sigma_inv_fn must both be provided or both None." ) if sigma_fn is None and sigma_inv_fn is None: self.sigma_fn = lambda t: t self.sigma_inv_fn = lambda sigma: sigma self._use_noise_level_space = False else: self.sigma_fn = sigma_fn self.sigma_inv_fn = sigma_inv_fn self._use_noise_level_space = True if diffusion_fn is None: self.diffusion_fn = lambda x, t: 2 * t.reshape(-1, *([1] * (x.ndim - 1))) else: self.diffusion_fn = diffusion_fn
[docs] def step( self, x: Float[Tensor, " B *dims"], t_cur: Float[Tensor, " B"], t_next: Float[Tensor, " B"], ) -> Float[Tensor, " B *dims"]: r""" Perform one stochastic Euler sampling step. Parameters ---------- x : Tensor Current noisy latent state :math:`\mathbf{x}_n` of shape :math:`(B, *)` where :math:`B` is the batch size. t_cur : Tensor Current diffusion time :math:`t_n` of shape :math:`(B,)`. t_next : Tensor Target diffusion time :math:`t_{n-1}` of shape :math:`(B,)`. Returns ------- Tensor Updated latent state :math:`\mathbf{x}_{n-1}` at time ``t_next``, same shape as ``x``. """ # Ensure contiguous strides so successive denoiser calls (across # sampling steps) present the same stride layout to torch.compile, # avoiding spurious recompilations / silently divergent traces. t_cur = t_cur.contiguous() t_next = t_next.contiguous() # Reshape t for broadcasting: (B,) -> (B, 1, ..., 1) expected_shape = (-1,) + (1,) * (x.ndim - 1) t_cur_bc = t_cur.reshape(expected_shape) t_next_bc = t_next.reshape(expected_shape) gamma_base = min(self.S_churn / self.num_steps, math.sqrt(2) - 1) # Compute perturbed time t_hat with increased noise # NOTE: sigma_fn and sigma_inv_fn are identity if not provided (stays # in time-step space). diffusion_fn defaults to g^2 = 2t (EDM-like # noise schedule). sigma_cur_bc = self.sigma_fn(t_cur_bc) # Mask: apply churn only where S_min <= sigma <= S_max churn_mask = (sigma_cur_bc >= self.S_min) & (sigma_cur_bc <= self.S_max) gamma_bc = torch.where(churn_mask, gamma_base, 0.0) sigma_hat_bc = sigma_cur_bc + gamma_bc * sigma_cur_bc t_hat_bc = self.sigma_inv_fn(sigma_hat_bc) # Noise scale: sqrt(sigma_hat^2 - sigma_cur^2) * S_noise * g(x,t) / sqrt(2*t) g_sq_bc = self.diffusion_fn(x, t_cur) safe_t_cur_bc = torch.where(t_cur_bc == 0, torch.ones_like(t_cur_bc), t_cur_bc) noise_scale_bc = ( (t_hat_bc**2 - t_cur_bc**2).clamp(min=0).sqrt() * self.S_noise * (g_sq_bc / (2 * safe_t_cur_bc)).sqrt() ) noise_scale_bc = torch.where( t_cur_bc == 0, torch.zeros_like(noise_scale_bc), noise_scale_bc ) # Perturb latent with noise x_hat = x + noise_scale_bc * torch.randn_like(x) # Euler step from t_hat to t_next t_hat = t_hat_bc.reshape(x.shape[0]) d_cur = self.denoiser(x_hat, t_hat) x_next = x_hat + (t_next_bc - t_hat_bc) * d_cur return x_next
[docs] class EDMStochasticHeunSolver(Solver): r""" Second-order stochastic Heun sampler from the EDM paper. Implements stochastic sampling with configurable noise injection controlled by the "churn" parameters, using a second-order Heun correction step. .. important:: This is **not** a true SDE solver. It performs ad-hoc noise injection ("churn") at each step to improve sample diversity, but the underlying integration is still an ODE step. Therefore, the denoiser should return the right-hand side of the **ODE**, not the SDE. By default, noise injection is performed directly in time-step space. For linear-Gaussian noise schedules where diffusion time and noise level are not equal (e.g., VP schedule), provide ``sigma_fn`` and ``sigma_inv_fn`` to apply churn in noise-level space rather than time-step space. Optionally provide ``diffusion_fn`` to control the time-dependent magnitude of the injected noise. .. code-block:: python def sigma_fn( t: Tensor, # shape: (B,) or broadcastable ) -> Tensor: ... # noise level, same shape as t def sigma_inv_fn( sigma: Tensor, # shape: (B,) or broadcastable ) -> Tensor: ... # diffusion time, same shape as sigma def diffusion_fn( x: Tensor, # shape: (B, *dims) t: Tensor, # shape: (B,) ) -> Tensor: ... # g^2(x, t), broadcastable to shape of x Parameters ---------- denoiser : Denoiser A callable implementing the :class:`~physicsnemo.diffusion.Denoiser` interface. Should return the right-hand side of the **ODE** (not the SDE, since the stochastic noise injection is handled internally by this solver). Typically obtained via :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.get_denoiser` with ``denoising_type="ode"``. alpha : float, optional Interpolation parameter for the corrector step, must be in (0, 1]. ``alpha=1`` gives the standard Heun method (trapezoidal rule), ``alpha=0.5`` gives the midpoint method. By default 1. S_churn : float, optional Controls the amount of noise added at each step. Higher values add more stochasticity. By default 0 (deterministic), in which case this solver is equivalent to the deterministic :class:`HeunSolver`. S_min : float, optional Minimum diffusion time (or noise level if ``sigma_fn`` and ``sigma_inv_fn`` are provided) for applying churn. By default 0. S_max : float, optional Maximum diffusion time (or noise level if ``sigma_fn`` and ``sigma_inv_fn`` are provided) for applying churn. By default ``float("inf")``. S_noise : float, optional Noise scaling factor. Large values add more noise to the latent state. By default 1. num_steps : int, optional Total number of sampling steps, used to scale churn. By default 18. sigma_fn : Callable[[Tensor], Tensor] | None, optional Maps time to noise level :math:`\sigma(t)`. Useful for linear-Gaussian schedules where :math:`\sigma(t) \neq t`. Typically :meth:`~physicsnemo.diffusion.noise_schedulers.LinearGaussianNoiseScheduler.sigma`. If provided, ``sigma_inv_fn`` must also be provided. By default ``None`` (identity mapping). sigma_inv_fn : Callable[[Tensor], Tensor] | None, optional Maps noise level back to time. Typically :meth:`~physicsnemo.diffusion.noise_schedulers.LinearGaussianNoiseScheduler.sigma_inv`. If provided, ``sigma_fn`` must also be provided. By default ``None`` (identity mapping). diffusion_fn : Callable[[Tensor, Tensor], Tensor] | None, optional Controls the time-dependent magnitude of the injected noise, in addition of the ``S_noise`` scaling factor. Typically the squared diffusion coefficient :math:`g^2(\mathbf{x}, t)` from the reverse SDE, obtained from :meth:`~physicsnemo.diffusion.noise_schedulers.LinearGaussianNoiseScheduler.diffusion`. By default ``None`` (:math:`g^2 = 2t`), which corresponds to an EDM-like noise schedule. Note ---- Reference: `Elucidating the Design Space of Diffusion-Based Generative Models <https://arxiv.org/abs/2206.00364>`_ Examples -------- Basic usage with default parameters (noise injection in time-step space): >>> import torch >>> from physicsnemo.diffusion.samplers.solvers import ( ... EDMStochasticHeunSolver, ... ) >>> denoiser = lambda x, t: x / (1 + t.view(-1, 1, 1, 1)**2) # Toy denoiser >>> solver = EDMStochasticHeunSolver(denoiser, S_churn=40, num_steps=18) >>> x_t = torch.randn(1, 3, 8, 8) >>> t_cur = torch.tensor([1.0]) >>> t_next = torch.tensor([0.5]) >>> x_tm1 = solver.step(x_t, t_cur, t_next) >>> x_tm1.shape torch.Size([1, 3, 8, 8]) Using noise scheduler methods for linear-Gaussian schedules where :math:`\sigma(t) \neq t` (e.g., VP schedule). The callbacks map between time and noise level, allowing the churn to be applied in noise-level space before converting back to time-step space: >>> from physicsnemo.diffusion.noise_schedulers import VPNoiseScheduler >>> scheduler = VPNoiseScheduler() >>> num_steps = 10 >>> solver = EDMStochasticHeunSolver( ... denoiser, ... S_churn=40, ... num_steps=num_steps, ... sigma_fn=scheduler.sigma, ... sigma_inv_fn=scheduler.sigma_inv, ... diffusion_fn=scheduler.diffusion, ... ) >>> x_tm1 = solver.step(x_t, t_cur, t_next) >>> x_tm1.shape torch.Size([1, 3, 8, 8]) """ def __init__( self, denoiser: Denoiser, alpha: float = 1.0, S_churn: float = 0, S_min: float = 0, S_max: float = float("inf"), S_noise: float = 1, num_steps: int = 18, sigma_fn: Callable[[Float[Tensor, " *shape"]], Float[Tensor, " *shape"]] | None = None, sigma_inv_fn: Callable[[Float[Tensor, " *shape"]], Float[Tensor, " *shape"]] | None = None, diffusion_fn: Callable[ [Float[Tensor, " B *dims"], Float[Tensor, " B"]], Float[Tensor, " B *_"] ] | None = None, ) -> None: self.denoiser = denoiser if not 0 < alpha <= 1: raise ValueError(f"alpha must be in (0, 1], got {alpha}") self.alpha = alpha self.S_churn = S_churn self.S_min = S_min self.S_max = S_max self.S_noise = S_noise self.num_steps = num_steps # Validate sigma_fn and sigma_inv_fn if (sigma_fn is None) != (sigma_inv_fn is None): raise ValueError( "sigma_fn and sigma_inv_fn must both be provided or both None." ) if sigma_fn is None and sigma_inv_fn is None: self.sigma_fn = lambda t: t self.sigma_inv_fn = lambda sigma: sigma self._use_noise_level_space = False else: self.sigma_fn = sigma_fn self.sigma_inv_fn = sigma_inv_fn self._use_noise_level_space = True if diffusion_fn is None: self.diffusion_fn = lambda x, t: 2 * t.reshape(-1, *([1] * (x.ndim - 1))) else: self.diffusion_fn = diffusion_fn
[docs] def step( self, x: Float[Tensor, " B *dims"], t_cur: Float[Tensor, " B"], t_next: Float[Tensor, " B"], ) -> Float[Tensor, " B *dims"]: r""" Perform one stochastic Heun sampling step. Parameters ---------- x : Tensor Current noisy latent state :math:`\mathbf{x}_n` of shape :math:`(B, *)` where :math:`B` is the batch size. t_cur : Tensor Current diffusion time :math:`t_n` of shape :math:`(B,)`. t_next : Tensor Target diffusion time :math:`t_{n-1}` of shape :math:`(B,)`. Returns ------- Tensor Updated latent state :math:`\mathbf{x}_{n-1}` at time ``t_next``, same shape as ``x``. """ # Ensure contiguous strides so successive denoiser calls (across # sampling steps) present the same stride layout to torch.compile, # avoiding spurious recompilations / silently divergent traces. t_cur = t_cur.contiguous() t_next = t_next.contiguous() # Reshape t for broadcasting: (B,) -> (B, 1, ..., 1) expected_shape = (-1,) + (1,) * (x.ndim - 1) t_cur_bc = t_cur.reshape(expected_shape) t_next_bc = t_next.reshape(expected_shape) gamma_base = min(self.S_churn / self.num_steps, math.sqrt(2) - 1) # Compute perturbed time t_hat with increased noise # NOTE: sigma_fn and sigma_inv_fn are identity if not provided (stays # in time-step space). diffusion_fn defaults to g^2 = 2t (EDM-like # noise schedule). sigma_cur_bc = self.sigma_fn(t_cur_bc) # Mask: apply churn only where S_min <= sigma <= S_max churn_mask = (sigma_cur_bc >= self.S_min) & (sigma_cur_bc <= self.S_max) gamma_bc = torch.where(churn_mask, gamma_base, 0.0) sigma_hat_bc = sigma_cur_bc + gamma_bc * sigma_cur_bc t_hat_bc = self.sigma_inv_fn(sigma_hat_bc) # Noise scale: sqrt(sigma_hat^2 - sigma_cur^2) * S_noise * g(x,t) / sqrt(2*t) g_sq_bc = self.diffusion_fn(x, t_cur) safe_t_cur_bc = torch.where(t_cur_bc == 0, torch.ones_like(t_cur_bc), t_cur_bc) noise_scale_bc = ( (sigma_hat_bc**2 - sigma_cur_bc**2).clamp(min=0).sqrt() * self.S_noise * (g_sq_bc / (2 * safe_t_cur_bc)).sqrt() ) noise_scale_bc = torch.where( t_cur_bc == 0, torch.zeros_like(noise_scale_bc), noise_scale_bc ) # Perturb latent with noise x_hat = x + noise_scale_bc * torch.randn_like(x) # Euler step from t_hat to intermediate point (predictor) t_hat = t_hat_bc.reshape(x.shape[0]) h_bc = t_next_bc - t_hat_bc d_cur = self.denoiser(x_hat, t_hat) t_prime_bc = t_hat_bc + self.alpha * h_bc x_prime = x_hat + self.alpha * h_bc * d_cur # Mask for elements where t_next != 0 (need 2nd order correction) mask_bc = (t_next_bc != 0).float() # Second RHS evaluation (compute everywhere, masked later) t_prime = t_prime_bc.reshape(x.shape[0]) # Avoid issues by using t_hat where t_prime would be 0 t_prime_safe = torch.where(t_prime == 0, t_hat, t_prime) d_prime = self.denoiser(x_prime, t_prime_safe) # Apply 2nd order correction only where t_next != 0 w_cur = 1 - 1 / (2 * self.alpha) w_prime = 1 / (2 * self.alpha) x_euler = x_hat + h_bc * d_cur x_heun = x_hat + h_bc * (w_cur * d_cur + w_prime * d_prime) x_next = mask_bc * x_heun + (1 - mask_bc) * x_euler return x_next