Metrics and Losses#
This module provides two categories of tools:
- training losses for learning diffusion models
- evaluation metrics for measuring the quality of generated samples
Training Losses#
The standard training objective for diffusion models is denoising score matching (DSM). The model is trained to recover clean data from a noisy version, with the noise scheduler handling time sampling, noise injection, and loss weighting.
MSEDSMLoss implements the MSE-based DSM loss and supports both x0-predictor and score-predictor training. WeightedMSEDSMLoss extends it with an element-wise weight tensor for masking specific spatial regions or channels (for example, land versus ocean in weather applications).
from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
from physicsnemo.diffusion.metrics.losses import MSEDSMLoss

# `model` is assumed to be an existing object that satisfies the DiffusionModel interface.
scheduler = EDMNoiseScheduler()

# x0-predictor training (default)
loss_fn = MSEDSMLoss(model, scheduler)

# Score-predictor training
loss_fn_score = MSEDSMLoss(
    model, scheduler,
    prediction_type="score",
    score_to_x0_fn=scheduler.score_to_x0,
)
Evaluation Metrics#
The framework provides evaluation metrics for assessing the quality of generated samples. The Fréchet Inception Distance (FID) is available through calculate_fid_from_inception_stats(), which computes the FID from precomputed Inception-v3 statistics.
API Reference#
MSEDSMLoss#
- class physicsnemo.diffusion.metrics.losses.MSEDSMLoss(
- model: DiffusionModel,
- noise_scheduler: NoiseScheduler,
- prediction_type: Literal['x0', 'score'] = 'x0',
- score_to_x0_fn: Callable[[Tensor, Tensor, Tensor], Tensor] | None = None,
- reduction: Literal['none', 'mean', 'sum'] = 'mean',
- )
Mean-squared-error denoising score matching loss for training diffusion models.
Implements the denoising score matching objective. Given clean data \(\mathbf{x}_0\), the loss is:
\[\mathcal{L} = \mathbb{E}_{t, \boldsymbol{\epsilon}} \left[ w(t) \left\| \hat{\mathbf{x}}_0(\mathbf{x}_t, t) - \mathbf{x}_0 \right\|^2 \right]\]
All training functionality is centered around a noise scheduler that must implement the NoiseScheduler protocol. At each training step the noise scheduler provides:
- Time sampling via sample_time(): draws random diffusion times \(t\).
- Noise injection via add_noise(): produces the noisy state \(\mathbf{x}_t\) from clean data \(\mathbf{x}_0\).
- Loss weighting via loss_weight(): returns the per-sample weight \(w(t)\).
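The three scheduler responsibilities listed above can be illustrated with a framework-agnostic NumPy sketch. It is not the library's implementation: the function names mirror the protocol methods, but the EDM-like choices (sigma(t) = t, log-uniform times in [0.002, 80], sigma_data = 0.5), the way w(t) is broadcast per sample, and the use of the "mean" reduction are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)


def sample_time(n, t_min=0.002, t_max=80.0):
    """Draw n log-uniform diffusion times in [t_min, t_max) (assumed range)."""
    return t_min * (t_max / t_min) ** rng.random(n)


def add_noise(x0, t):
    """Noisy state x_t = x0 + sigma(t) * eps, with the assumed sigma(t) = t."""
    return x0 + t.reshape(-1, 1, 1, 1) * rng.standard_normal(x0.shape)


def loss_weight(t, sigma_data=0.5):
    """EDM-style per-sample weight w(t) (assumed form)."""
    return (t**2 + sigma_data**2) / (t * sigma_data) ** 2


def dsm_loss(denoiser, x0):
    """Single Monte Carlo estimate of the weighted MSE objective."""
    t = sample_time(x0.shape[0])          # time sampling
    x_t = add_noise(x0, t)                # noise injection
    x0_hat = denoiser(x_t, t)             # clean-data estimate
    per_element = loss_weight(t).reshape(-1, 1, 1, 1) * (x0_hat - x0) ** 2
    return per_element.mean()             # mean over all elements


x0 = rng.standard_normal((4, 3, 8, 8))
perfect = dsm_loss(lambda x_t, t: x0, x0)     # oracle that returns the clean data
identity = dsm_loss(lambda x_t, t: x_t, x0)   # returns the noisy input unchanged
```

An oracle denoiser yields a loss of exactly zero, while a denoiser that passes the noisy input through is penalised in proportion to the injected noise.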
The model can be trained to either directly predict the clean data \(\hat{\mathbf{x}}_0\) (prediction_type="x0", default) or to predict the score, which is then converted to an \(\hat{\mathbf{x}}_0\) estimate via a user-provided score_to_x0_fn callback (prediction_type="score").
The model argument must satisfy the DiffusionModel interface:
model(
    x: torch.Tensor,  # Noisy state, shape: (B, *)
    t: torch.Tensor,  # Diffusion time, shape: (B,)
    condition: torch.Tensor | TensorDict | None = None,  # Conditioning information, shape: (B, *cond_dims)
    **model_kwargs: Any,
) -> torch.Tensor  # Model prediction, shape: (B, *)
When prediction_type="score", you must also provide a score_to_x0_fn callback when instantiating the loss, with the following signature:
score_to_x0_fn(
    score: torch.Tensor,  # Predicted score, shape: (B, *)
    x_t: torch.Tensor,  # Noisy state, shape: (B, *)
    t: torch.Tensor,  # Diffusion time, shape: (B,)
) -> torch.Tensor  # Clean data estimate, shape: (B, *)
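To make the callback contract concrete, here is a small NumPy sketch of such a conversion for the simplest linear-Gaussian case, x_t = x0 + t * eps (alpha(t) = 1, sigma(t) = t), which is an assumed schedule rather than a statement about any particular scheduler class. The conversion x0 = x_t + sigma(t)^2 * score is checked against the score of the Gaussian perturbation kernel, for which it recovers the clean data exactly.

```python
import numpy as np

rng = np.random.default_rng(0)


def score_to_x0(score, x_t, t):
    """Convert a score to a clean-data estimate, assuming alpha(t)=1 and sigma(t)=t."""
    return x_t + t.reshape(-1, 1, 1, 1) ** 2 * score


x0 = rng.standard_normal((2, 3, 8, 8))
t = np.array([0.5, 2.0])                      # one diffusion time per sample
sigma = t.reshape(-1, 1, 1, 1)
x_t = x0 + sigma * rng.standard_normal(x0.shape)

# Score of the perturbation kernel N(x_t; x0, sigma^2 I) with respect to x_t
conditional_score = -(x_t - x0) / sigma**2

x0_rec = score_to_x0(conditional_score, x_t, t)
```

A trained score model only approximates the marginal score, so in practice the output is an estimate of the clean data rather than an exact reconstruction.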
For LinearGaussianNoiseScheduler subclasses, the score_to_x0() method provides a ready-made score_to_x0_fn.
- Parameters:
model (DiffusionModel) – Diffusion model to train. Can be a plain neural network, or a model wrapped with a preconditioner (e.g., EDMPreconditioner). The output is interpreted according to prediction_type: as a clean-data estimate when "x0", or as a score when "score". Must satisfy the DiffusionModel protocol.
noise_scheduler (NoiseScheduler) – Noise scheduler implementing the NoiseScheduler protocol, providing the methods: sample_time(), add_noise(), and loss_weight().
prediction_type (Literal["x0", "score"], default="x0") – Type of prediction the model outputs. Use "x0" when the model directly predicts clean data (the most common case with standard preconditioners). Use "score" when the model predicts the score, in which case score_to_x0_fn must be provided.
score_to_x0_fn (Callable[[Tensor, Tensor, Tensor], Tensor], optional) – Callback to convert a score prediction to an \(\hat{\mathbf{x}}_0\) estimate. Required when prediction_type="score". See above for the expected signature.
reduction (Literal["none", "mean", "sum"], default="mean") – Reduction to apply to the output: "none" returns the per-element loss, "mean" returns the mean over all elements, "sum" returns the sum over all elements.
- Raises:
ValueError – If prediction_type is not "x0" or "score".
ValueError – If prediction_type="score" and score_to_x0_fn is None.
Examples
Example 1: Standard unconditional x0-predictor training with EDM schedule and preconditioner:
>>> import torch
>>> from physicsnemo.core import Module
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>> from physicsnemo.diffusion.preconditioners import EDMPreconditioner
>>> from physicsnemo.diffusion.metrics.losses import MSEDSMLoss
>>>
>>> class UnconditionalModel(Module):
...     def __init__(self):
...         super().__init__()
...         self.net = torch.nn.Conv2d(3, 3, 1)
...     def forward(self, x, t, condition=None):
...         return self.net(x)
>>>
>>> model = UnconditionalModel()
>>> scheduler = EDMNoiseScheduler()
>>> precond = EDMPreconditioner(model)
>>> loss_fn = MSEDSMLoss(precond, scheduler)
>>> x0 = torch.randn(4, 3, 8, 8)
>>> loss = loss_fn(x0)
>>> loss.shape
torch.Size([])
Example 2: Conditional training with VP schedule. The model receives multiple conditioning tensors (an image and a vector) through a TensorDict:
>>> from physicsnemo.diffusion.noise_schedulers import VPNoiseScheduler
>>> from physicsnemo.diffusion.preconditioners import VPPreconditioner
>>> from tensordict import TensorDict
>>>
>>> class ConditionalModel(Module):
...     def __init__(self):
...         super().__init__()
...         self.img_net = torch.nn.Conv2d(6, 3, 1)
...         self.vec_net = torch.nn.Linear(4, 3 * 8 * 8)
...     def forward(self, x, t, condition=None):
...         y_img = condition["image"]
...         y_vec = self.vec_net(condition["vector"]).view_as(x)
...         return self.img_net(torch.cat([x, y_img], dim=1)) + y_vec
>>>
>>> cond_model = ConditionalModel()
>>> scheduler_vp = VPNoiseScheduler()
>>> precond_vp = VPPreconditioner(cond_model)
>>> loss_fn = MSEDSMLoss(precond_vp, scheduler_vp)
>>> condition = TensorDict({
...     "image": torch.randn(4, 3, 8, 8),
...     "vector": torch.randn(4, 4),
... }, batch_size=[4])
>>> loss = loss_fn(x0, condition=condition)
>>> loss.shape
torch.Size([])
Example 3: Training a score-predictor. The model outputs score predictions, and score_to_x0_fn converts them to clean data estimates for the loss computation. For linear-Gaussian noise schedulers, the method score_to_x0() provides a ready-made conversion:
>>> scheduler = EDMNoiseScheduler()
>>> loss_fn = MSEDSMLoss(
...     model=model,
...     noise_scheduler=scheduler,
...     prediction_type="score",
...     score_to_x0_fn=scheduler.score_to_x0,
... )
>>> loss = loss_fn(x0)
>>> loss.shape
torch.Size([])
Example 4: Bare-bones approach without any built-in scheduler or preconditioner. This shows how to plug custom components into MSEDSMLoss by implementing the NoiseScheduler and DiffusionModel protocols from scratch:
>>> import math
>>>
>>> # Custom noise scheduler (EDM-like, sigma(t)=t, alpha(t)=1)
>>> class MyScheduler:
...     def sample_time(self, N, *, device=None, dtype=None):
...         return (0.002 * (80 / 0.002) ** torch.rand(N, device=device, dtype=dtype))
...     def add_noise(self, x0, time):
...         return x0 + time.view(-1, 1, 1, 1) * torch.randn_like(x0)
...     def loss_weight(self, t):
...         return (t**2 + 0.5**2) / (t * 0.5) ** 2
...     def score_to_x0(self, score, x_t, t):
...         return x_t + t.view(-1, 1, 1, 1)**2 * score
...     def timesteps(self, n, *, device=None, dtype=None):
...         return torch.zeros(1)
...     def init_latents(self, s, tN, *, device=None, dtype=None):
...         return torch.zeros(1)
...     def get_denoiser(self, **kw):
...         return lambda x, t: x
>>>
>>> # Custom model with single-tensor conditioning
>>> class ConditionalModel:
...     def __init__(self):
...         self.w = torch.randn(3, 6, 1, 1) * 0.01
...     def __call__(self, x, t, condition=None, **kw):
...         return torch.nn.functional.conv2d(
...             torch.cat([x, condition], dim=1), self.w)
>>>
>>> my_scheduler = MyScheduler()
>>> cond_model = ConditionalModel()
>>> loss_fn = MSEDSMLoss(cond_model, my_scheduler)
>>> x0 = torch.randn(2, 3, 8, 8)
>>> cond = torch.randn(2, 3, 8, 8)  # single-tensor conditioning
>>> loss = loss_fn(x0, condition=cond)
>>> loss.shape
torch.Size([])
>>>
>>> # Also works with score prediction + custom conversion
>>> loss_fn_score = MSEDSMLoss(
...     cond_model, my_scheduler,
...     prediction_type="score",
...     score_to_x0_fn=my_scheduler.score_to_x0,
... )
>>> loss = loss_fn_score(x0, condition=cond)
>>> loss.shape
torch.Size([])
WeightedMSEDSMLoss#
- class physicsnemo.diffusion.metrics.losses.WeightedMSEDSMLoss(
- model: DiffusionModel,
- noise_scheduler: NoiseScheduler,
- prediction_type: Literal['x0', 'score'] = 'x0',
- score_to_x0_fn: Callable[[Tensor, Tensor, Tensor], Tensor] | None = None,
- reduction: Literal['none', 'mean', 'sum'] = 'mean',
- )
Weighted mean-squared-error denoising score matching loss.
Identical to MSEDSMLoss but accepts an additional weight argument that multiplies the per-element squared error.
\[\mathcal{L} = \mathbb{E}_{t, \boldsymbol{\epsilon}} \left[ w(t) \left\| \mathbf{m} \odot \left(\hat{\mathbf{x}}_0(\mathbf{x}_t, t) - \mathbf{x}_0\right) \right\|^2 \right]\]
where \(\mathbf{m}\) is the element-wise weight (e.g., a binary mask). A common use case is masking out certain spatial regions or channels of the state.
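The masking behaviour and the three reduction modes can be sketched in NumPy as follows. This is an illustration only: the values of w(t) and the stand-in prediction are made up, the weight is applied to the squared error as the prose describes, and averaging over all elements (including masked ones) under "mean" is an assumption. For a binary mask, weighting the squared error and weighting the residual inside the norm give the same result, which the last line checks.

```python
import numpy as np

rng = np.random.default_rng(0)

x0 = rng.standard_normal((4, 3, 8, 8))
x0_hat = x0 + 0.1 * rng.standard_normal(x0.shape)   # stand-in for a model prediction
w_t = np.array([0.5, 1.0, 2.0, 4.0])                # stand-in for scheduler weights w(t)

# Binary mask: zero out the left half of the spatial domain
mask = np.ones_like(x0)
mask[:, :, :, :4] = 0.0

sq_err = (x0_hat - x0) ** 2
weighted = w_t.reshape(-1, 1, 1, 1) * mask * sq_err  # w(t) and element-wise weight

loss_none = weighted          # reduction="none": per-element loss
loss_mean = weighted.mean()   # reduction="mean": mean over all elements
loss_sum = weighted.sum()     # reduction="sum": sum over all elements

# For a binary mask, m * err^2 equals (m * err)^2
same_for_binary_mask = np.allclose(mask * sq_err, (mask * (x0_hat - x0)) ** 2)
```

With a non-binary weight the two readings differ (m versus m squared), so it is worth checking which one the implementation applies before using fractional weights.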
Note
The weight argument is not related to the time-dependent loss weight \(w(t)\) provided by the noise scheduler.
For more details on prediction types, expected signatures, and additional examples, see MSEDSMLoss.
- Parameters:
model (DiffusionModel) – Diffusion model to train. Must satisfy the DiffusionModel protocol.
noise_scheduler (NoiseScheduler) – Noise scheduler implementing the NoiseScheduler protocol.
prediction_type (Literal["x0", "score"], default="x0") – Type of prediction the model outputs. See MSEDSMLoss.
score_to_x0_fn (callable, optional) – Callback to convert a score prediction to an \(\hat{\mathbf{x}}_0\) estimate. Required when prediction_type="score".
reduction ({"none", "mean", "sum"}, default="mean") – Reduction to apply to the output: "none" returns the per-element loss, "mean" the mean, "sum" the sum.
Examples
Apply a spatial mask so the loss is computed only over unmasked regions:
>>> import torch
>>> from physicsnemo.core import Module
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>> from physicsnemo.diffusion.preconditioners import EDMPreconditioner
>>> from physicsnemo.diffusion.metrics.losses import WeightedMSEDSMLoss
>>>
>>> class UnconditionalModel(Module):
...     def __init__(self):
...         super().__init__()
...         self.net = torch.nn.Conv2d(3, 3, 1)
...     def forward(self, x, t, condition=None):
...         return self.net(x)
>>>
>>> model = UnconditionalModel()
>>> scheduler = EDMNoiseScheduler()
>>> precond = EDMPreconditioner(model)
>>> loss_fn = WeightedMSEDSMLoss(precond, scheduler)
>>>
>>> x0 = torch.randn(4, 3, 8, 8)
>>> # Binary mask: zero out the left half of the spatial domain
>>> mask = torch.ones(4, 3, 8, 8)
>>> mask[:, :, :, :4] = 0.0
>>> loss = loss_fn(x0, weight=mask)
>>> loss.shape
torch.Size([])
calculate_fid_from_inception_stats#
- physicsnemo.diffusion.metrics.fid.calculate_fid_from_inception_stats(
- mu: Tensor,
- sigma: Tensor,
- mu_ref: Tensor,
- sigma_ref: Tensor,
- )
Calculate the Fréchet Inception Distance (FID) between two sets of Inception statistics.
The Fréchet Inception Distance is a measure of the similarity between two datasets based on their Inception features (mu and sigma). It is commonly used to evaluate the quality of generated images in generative models.
- Parameters:
mu (torch.Tensor) – Mean of Inception statistics for the generated dataset.
sigma (torch.Tensor) – Covariance matrix of Inception statistics for the generated dataset.
mu_ref (torch.Tensor) – Mean of Inception statistics for the reference dataset.
sigma_ref (torch.Tensor) – Covariance matrix of Inception statistics for the reference dataset.
- Returns:
The Fréchet Inception Distance (FID) between the two datasets.
- Return type:
float
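For reference, the standard closed-form Fréchet distance between two Gaussians with the statistics described above is \(\|\mu - \mu_{ref}\|^2 + \mathrm{Tr}(\Sigma + \Sigma_{ref} - 2(\Sigma \Sigma_{ref})^{1/2})\). The NumPy sketch below evaluates that formula on toy 2-D statistics. The helper name fid_from_stats is hypothetical, and whether calculate_fid_from_inception_stats() computes the matrix square root this way is not confirmed here.

```python
import numpy as np


def fid_from_stats(mu, sigma, mu_ref, sigma_ref):
    """Fréchet distance between N(mu, sigma) and N(mu_ref, sigma_ref)."""
    mean_term = np.sum((mu - mu_ref) ** 2)
    # Tr((sigma @ sigma_ref)^{1/2}) equals the sum of the square roots of the
    # eigenvalues of sigma @ sigma_ref, which are real and non-negative for
    # positive semi-definite inputs (up to round-off, hence the clip).
    eigvals = np.linalg.eigvals(sigma @ sigma_ref)
    trace_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return float(mean_term + np.trace(sigma) + np.trace(sigma_ref) - 2.0 * trace_sqrt)


# Toy statistics (real Inception-v3 features are 2048-dimensional)
mu, sigma = np.array([1.0, 0.0]), np.eye(2)
mu_ref, sigma_ref = np.zeros(2), 4.0 * np.eye(2)

same = fid_from_stats(mu, sigma, mu, sigma)               # identical statistics
different = fid_from_stats(mu, sigma, mu_ref, sigma_ref)  # shifted mean, wider covariance
```

Identical statistics give a distance of zero. In the second case the mean term contributes 1 and the covariance term contributes 2 + 8 - 2 * 4 = 2, for a total of 3.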