Metrics and Losses#

This module provides two categories of tools:

  • training losses for learning diffusion models

  • evaluation metrics for measuring the quality of generated samples

Training Losses#

The standard training objective for diffusion models is denoising score matching (DSM). The model is trained to recover clean data from a noisy version, with the noise scheduler handling time sampling, noise injection, and loss weighting.

MSEDSMLoss implements the MSE-based DSM loss and supports both x0-predictor and score-predictor training. WeightedMSEDSMLoss extends it with an element-wise weight tensor for masking specific spatial regions or channels (for example, land versus ocean in weather applications).

from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
from physicsnemo.diffusion.metrics.losses import MSEDSMLoss

scheduler = EDMNoiseScheduler()

# x0-predictor training (default); `model` is any callable that satisfies the
# DiffusionModel interface documented under MSEDSMLoss
loss_fn = MSEDSMLoss(model, scheduler)

# Score-predictor training
loss_fn_score = MSEDSMLoss(
    model, scheduler,
    prediction_type="score",
    score_to_x0_fn=scheduler.score_to_x0,
)

Evaluation Metrics#

The framework provides evaluation metrics for assessing the quality of generated samples. The Fréchet Inception Distance (FID) is provided by calculate_fid_from_inception_stats(), which computes it from precomputed Inception-v3 statistics (feature mean and covariance) of the generated and reference datasets.

API Reference#

MSEDSMLoss#

class physicsnemo.diffusion.metrics.losses.MSEDSMLoss(
model: DiffusionModel,
noise_scheduler: NoiseScheduler,
prediction_type: Literal['x0', 'score'] = 'x0',
score_to_x0_fn: Callable[[Tensor, Tensor, Tensor], Tensor] | None = None,
reduction: Literal['none', 'mean', 'sum'] = 'mean',
)[source]#

Mean-squared-error denoising score matching loss for training diffusion models.

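Implements the denoising score matching objective. Given clean data \(\mathbf{x}_0\), the noisy state \(\mathbf{x}_t\) at diffusion time \(t\), the model's clean-data estimate \(\hat{\mathbf{x}}_0(\mathbf{x}_t, t)\), and a time-dependent weight \(w(t)\), the loss is: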

\[\mathcal{L} = \mathbb{E}_{t, \boldsymbol{\epsilon}} \left[ w(t) \left\| \hat{\mathbf{x}}_0(\mathbf{x}_t, t) - \mathbf{x}_0 \right\|^2 \right]\]

The loss delegates every schedule-specific operation to a noise scheduler, which must implement the NoiseScheduler protocol. At each training step the noise scheduler provides:

  • Time sampling via sample_time(): draws random diffusion times \(t\).

  • Noise injection via add_noise(): produces the noisy state \(\mathbf{x}_t\) from clean data \(\mathbf{x}_0\).

  • Loss weighting via loss_weight(): returns the per-sample weight \(w(t)\).
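The three scheduler calls above combine into a single loss evaluation roughly as follows. This is a minimal sketch in plain PyTorch, not the library's implementation: ToyScheduler and dsm_loss are hypothetical names, the schedule (sigma(t) = t, log-uniform times, EDM-style weight with a data scale of 0.5) is borrowed from the custom scheduler in Example 4, and the mean over all elements is an assumed reduction.

```python
import torch


class ToyScheduler:
    """Hypothetical scheduler with sigma(t) = t and alpha(t) = 1."""

    def sample_time(self, N, *, device=None, dtype=None):
        # Log-uniform diffusion times in [0.002, 80]
        u = torch.rand(N, device=device, dtype=dtype)
        return 0.002 * (80 / 0.002) ** u

    def add_noise(self, x0, time):
        # x_t = x0 + sigma(t) * eps, with t broadcast over the non-batch dims
        sigma = time.view(-1, *([1] * (x0.ndim - 1)))
        return x0 + sigma * torch.randn_like(x0)

    def loss_weight(self, t):
        # EDM-style weight, assuming a data standard deviation of 0.5
        return (t**2 + 0.5**2) / (t * 0.5) ** 2


def dsm_loss(model, scheduler, x0):
    """Weighted MSE between the model's x0 estimate and the clean data."""
    t = scheduler.sample_time(x0.shape[0], device=x0.device, dtype=x0.dtype)
    x_t = scheduler.add_noise(x0, t)
    x0_hat = model(x_t, t)
    w = scheduler.loss_weight(t).view(-1, *([1] * (x0.ndim - 1)))
    return (w * (x0_hat - x0) ** 2).mean()


torch.manual_seed(0)
net = torch.nn.Conv2d(3, 3, 1)
model = lambda x, t: net(x)  # toy x0-predictor that ignores t
x0 = torch.randn(4, 3, 8, 8)

loss = dsm_loss(model, ToyScheduler(), x0)
loss.backward()  # gradients flow back to the network parameters
```

The result is a scalar tensor, so it can be passed straight to an optimizer step in an ordinary training loop.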

The model can be trained to either directly predict the clean data \(\hat{\mathbf{x}}_0\) (prediction_type="x0", default) or to predict the score, which is then converted to an \(\hat{\mathbf{x}}_0\) estimate via a user-provided score_to_x0_fn callback (prediction_type="score").

The model argument must satisfy the DiffusionModel interface:

model(
    x: torch.Tensor,       # Noisy state, shape: (B, *)
    t: torch.Tensor,       # Diffusion time, shape: (B,)
    condition: torch.Tensor | TensorDict | None = None, # Conditioning information, shape: (B, *cond_dims)
    **model_kwargs: Any,
) -> torch.Tensor          # Model prediction, shape: (B, *)

When prediction_type="score", you must also provide a score_to_x0_fn callback when instantiating the loss, with the following signature:

score_to_x0_fn(
    score: torch.Tensor,   # Predicted score, shape: (B, *)
    x_t: torch.Tensor,     # Noisy state, shape: (B, *)
    t: torch.Tensor,       # Diffusion time, shape: (B,)
) -> torch.Tensor          # Clean data estimate, shape: (B, *)

For LinearGaussianNoiseScheduler subclasses, the score_to_x0() method provides a ready-made score_to_x0_fn.
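To illustrate what such a callback computes, the sketch below uses the simple schedule from Example 4 (sigma(t) = t, alpha(t) = 1), for which Tweedie's formula gives \(\hat{\mathbf{x}}_0 = \mathbf{x}_t + t^2 \cdot \text{score}\). The standalone score_to_x0 function is a hypothetical stand-in, not the library method; other schedules need a different conversion. As a sanity check, feeding in the conditional score \(-(\mathbf{x}_t - \mathbf{x}_0)/t^2\) should recover the clean data.

```python
import torch


def score_to_x0(score, x_t, t):
    """Convert a score to an x0 estimate, assuming x_t = x0 + t * eps."""
    # Broadcast the per-sample time (B,) over the remaining dims of x_t
    t_b = t.view(-1, *([1] * (x_t.ndim - 1)))
    return x_t + t_b**2 * score


torch.manual_seed(0)
x0 = torch.randn(2, 3, 8, 8, dtype=torch.float64)
t = torch.tensor([0.5, 2.0], dtype=torch.float64)
t_b = t.view(-1, 1, 1, 1)

# Noisy state and the score of the Gaussian p(x_t | x0)
x_t = x0 + t_b * torch.randn_like(x0)
cond_score = -(x_t - x0) / t_b**2

x0_rec = score_to_x0(cond_score, x_t, t)
```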

Parameters:
  • model (DiffusionModel) – Diffusion model to train. Can be a plain neural network, or a model wrapped with a preconditioner (e.g., EDMPreconditioner). The output is interpreted according to prediction_type: as a clean-data estimate when "x0", or as a score when "score". Must satisfy the DiffusionModel protocol.

  • noise_scheduler (NoiseScheduler) – Noise scheduler implementing the NoiseScheduler protocol, providing the methods: sample_time(), add_noise(), and loss_weight().

  • prediction_type (Literal["x0", "score"], default="x0") – Type of prediction the model outputs. Use "x0" when the model directly predicts clean data (the most common case with standard preconditioners). Use "score" when the model predicts the score, in which case score_to_x0_fn must be provided.

  • score_to_x0_fn (Callable[[Tensor, Tensor, Tensor], Tensor], optional) – Callback to convert a score prediction to an \(\hat{\mathbf{x}}_0\) estimate. Required when prediction_type="score". See above for the expected signature.

  • reduction (Literal["none", "mean", "sum"], default="mean") – Reduction to apply to the output: "none" returns the per-element loss, "mean" returns the mean over all elements, "sum" returns the sum over all elements.
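The effect of the reduction parameter can be sketched on a stand-in per-element loss tensor. The reduce_loss helper is hypothetical and only mirrors the behavior described above; it is not taken from the library.

```python
import torch


def reduce_loss(per_element, reduction="mean"):
    """Apply one of the three documented reductions to a per-element loss."""
    if reduction == "none":
        return per_element
    if reduction == "mean":
        return per_element.mean()
    if reduction == "sum":
        return per_element.sum()
    raise ValueError(f"Unknown reduction: {reduction!r}")


# Stand-in for a per-element loss on a batch of shape (B, C, H, W)
per_element = torch.ones(4, 3, 8, 8)

loss_none = reduce_loss(per_element, "none")  # keeps the input shape
loss_mean = reduce_loss(per_element, "mean")  # scalar tensor
loss_sum = reduce_loss(per_element, "sum")    # scalar tensor
```

Keeping the per-element loss ("none") is useful when the caller wants to aggregate per channel or per sample before averaging.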

Raises:
  • ValueError – If prediction_type is not "x0" or "score".

  • ValueError – If prediction_type="score" and score_to_x0_fn is None.

Examples

Example 1: Standard unconditional x0-predictor training with EDM schedule and preconditioner:

>>> import torch
>>> from physicsnemo.core import Module
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>> from physicsnemo.diffusion.preconditioners import EDMPreconditioner
>>> from physicsnemo.diffusion.metrics.losses import MSEDSMLoss
>>>
>>> class UnconditionalModel(Module):
...     def __init__(self):
...         super().__init__()
...         self.net = torch.nn.Conv2d(3, 3, 1)
...     def forward(self, x, t, condition=None):
...         return self.net(x)
>>>
>>> model = UnconditionalModel()
>>> scheduler = EDMNoiseScheduler()
>>> precond = EDMPreconditioner(model)
>>> loss_fn = MSEDSMLoss(precond, scheduler)
>>> x0 = torch.randn(4, 3, 8, 8)
>>> loss = loss_fn(x0)
>>> loss.shape
torch.Size([])

Example 2: Conditional training with VP schedule. The model receives multiple conditioning tensors (an image and a vector) through a TensorDict:

>>> from physicsnemo.diffusion.noise_schedulers import VPNoiseScheduler
>>> from physicsnemo.diffusion.preconditioners import VPPreconditioner
>>> from tensordict import TensorDict
>>>
>>> class ConditionalModel(Module):
...     def __init__(self):
...         super().__init__()
...         self.img_net = torch.nn.Conv2d(6, 3, 1)
...         self.vec_net = torch.nn.Linear(4, 3 * 8 * 8)
...     def forward(self, x, t, condition=None):
...         y_img = condition["image"]
...         y_vec = self.vec_net(condition["vector"]).view_as(x)
...         return self.img_net(torch.cat([x, y_img], dim=1)) + y_vec
>>>
>>> cond_model = ConditionalModel()
>>> scheduler_vp = VPNoiseScheduler()
>>> precond_vp = VPPreconditioner(cond_model)
>>> loss_fn = MSEDSMLoss(precond_vp, scheduler_vp)
>>> condition = TensorDict({
...     "image": torch.randn(4, 3, 8, 8),
...     "vector": torch.randn(4, 4),
... }, batch_size=[4])
>>> loss = loss_fn(x0, condition=condition)
>>> loss.shape
torch.Size([])

Example 3: Training a score-predictor. The model outputs score predictions, and score_to_x0_fn converts them to clean data estimates for the loss computation. For linear-Gaussian noise schedulers, the method score_to_x0() provides a ready-made conversion:

>>> scheduler = EDMNoiseScheduler()
>>> loss_fn = MSEDSMLoss(
...     model=model,
...     noise_scheduler=scheduler,
...     prediction_type="score",
...     score_to_x0_fn=scheduler.score_to_x0,
... )
>>> loss = loss_fn(x0)
>>> loss.shape
torch.Size([])

Example 4: Bare-bones approach without any built-in scheduler or preconditioner. This shows how to plug custom components into MSEDSMLoss by implementing the NoiseScheduler and DiffusionModel protocols from scratch:

>>> # Custom noise scheduler (EDM-like, sigma(t)=t, alpha(t)=1)
>>> class MyScheduler:
...     def sample_time(self, N, *, device=None, dtype=None):
...         return (0.002 * (80 / 0.002) ** torch.rand(N, device=device, dtype=dtype))
...     def add_noise(self, x0, time):
...         return x0 + time.view(-1, 1, 1, 1) * torch.randn_like(x0)
...     def loss_weight(self, t):
...         return (t**2 + 0.5**2) / (t * 0.5) ** 2
...     def score_to_x0(self, score, x_t, t):
...         return x_t + t.view(-1, 1, 1, 1)**2 * score
...     def timesteps(self, n, *, device=None, dtype=None):
...         return torch.zeros(1)
...     def init_latents(self, s, tN, *, device=None, dtype=None):
...         return torch.zeros(1)
...     def get_denoiser(self, **kw):
...         return lambda x, t: x
>>>
>>> # Custom model with single-tensor conditioning
>>> class ConditionalModel:
...     def __init__(self):
...         self.w = torch.randn(3, 6, 1, 1) * 0.01
...     def __call__(self, x, t, condition=None, **kw):
...         return torch.nn.functional.conv2d(
...             torch.cat([x, condition], dim=1), self.w)
>>>
>>> my_scheduler = MyScheduler()
>>> cond_model = ConditionalModel()
>>> loss_fn = MSEDSMLoss(cond_model, my_scheduler)
>>> x0 = torch.randn(2, 3, 8, 8)
>>> cond = torch.randn(2, 3, 8, 8)  # single-tensor conditioning
>>> loss = loss_fn(x0, condition=cond)
>>> loss.shape
torch.Size([])
>>>
>>> # Also works with score prediction + custom conversion
>>> loss_fn_score = MSEDSMLoss(
...     cond_model, my_scheduler,
...     prediction_type="score",
...     score_to_x0_fn=my_scheduler.score_to_x0,
... )
>>> loss = loss_fn_score(x0, condition=cond)
>>> loss.shape
torch.Size([])

WeightedMSEDSMLoss#

class physicsnemo.diffusion.metrics.losses.WeightedMSEDSMLoss(
model: DiffusionModel,
noise_scheduler: NoiseScheduler,
prediction_type: Literal['x0', 'score'] = 'x0',
score_to_x0_fn: Callable[[Tensor, Tensor, Tensor], Tensor] | None = None,
reduction: Literal['none', 'mean', 'sum'] = 'mean',
)[source]#

Weighted mean-squared-error denoising score matching loss.

Identical to MSEDSMLoss, except that calling the loss accepts an additional weight argument that multiplies the per-element squared error.

\[\mathcal{L} = \mathbb{E}_{t, \boldsymbol{\epsilon}} \left[ w(t) \left\| \mathbf{m} \odot \left(\hat{\mathbf{x}}_0(\mathbf{x}_t, t) - \mathbf{x}_0\right) \right\|^2 \right]\]

where \(\mathbf{m}\) is the element-wise weight (e.g., a binary mask). A common use case is masking out certain spatial regions or channels of the state.
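The masking step can be sketched in isolation as follows. This is an illustrative stand-in rather than the library code: masked_squared_error is a hypothetical helper, the time-dependent weight \(w(t)\) and the final reduction are left out, and a random tensor plays the role of the model's estimate. With a binary mask, weighting the squared error and squaring the weighted error give the same result, which the last line checks.

```python
import torch


def masked_squared_error(x0_hat, x0, weight):
    """Per-element squared error scaled by an element-wise weight."""
    return weight * (x0_hat - x0) ** 2


torch.manual_seed(0)
x0 = torch.randn(4, 3, 8, 8)
x0_hat = torch.randn(4, 3, 8, 8)  # stand-in for the model's clean-data estimate

# Binary mask: zero out the left half of the spatial domain
mask = torch.ones(4, 3, 8, 8)
mask[:, :, :, :4] = 0.0

err = masked_squared_error(x0_hat, x0, mask)

# Alternative reading of the formula: mask applied before squaring
alt = (mask * (x0_hat - x0)) ** 2
```

For non-binary weights the two readings differ by a factor of the weight, so it is worth checking which one applies before using fractional weights.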

Note

The weight argument is not related to the time-dependent loss weight \(w(t)\) provided by the noise scheduler.

For more details on prediction types, expected signatures, and additional examples, see MSEDSMLoss.

Parameters:
  • model (DiffusionModel) – Diffusion model to train. Must satisfy the DiffusionModel protocol.

  • noise_scheduler (NoiseScheduler) – Noise scheduler implementing the NoiseScheduler protocol.

  • prediction_type (Literal["x0", "score"], default="x0") – Type of prediction the model outputs. See MSEDSMLoss.

  • score_to_x0_fn (Callable[[Tensor, Tensor, Tensor], Tensor], optional) – Callback to convert a score prediction to an \(\hat{\mathbf{x}}_0\) estimate. Required when prediction_type="score". See MSEDSMLoss for the expected signature.

  • reduction (Literal["none", "mean", "sum"], default="mean") – Reduction to apply to the output: "none" returns the per-element loss, "mean" returns the mean over all elements, "sum" returns the sum over all elements.

Examples

Apply a spatial mask so the loss is computed only over unmasked regions:

>>> import torch
>>> from physicsnemo.core import Module
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>> from physicsnemo.diffusion.preconditioners import EDMPreconditioner
>>> from physicsnemo.diffusion.metrics.losses import WeightedMSEDSMLoss
>>>
>>> class UnconditionalModel(Module):
...     def __init__(self):
...         super().__init__()
...         self.net = torch.nn.Conv2d(3, 3, 1)
...     def forward(self, x, t, condition=None):
...         return self.net(x)
>>>
>>> model = UnconditionalModel()
>>> scheduler = EDMNoiseScheduler()
>>> precond = EDMPreconditioner(model)
>>> loss_fn = WeightedMSEDSMLoss(precond, scheduler)
>>>
>>> x0 = torch.randn(4, 3, 8, 8)
>>> # Binary mask: zero out the left half of the spatial domain
>>> mask = torch.ones(4, 3, 8, 8)
>>> mask[:, :, :, :4] = 0.0
>>> loss = loss_fn(x0, weight=mask)
>>> loss.shape
torch.Size([])

calculate_fid_from_inception_stats#

physicsnemo.diffusion.metrics.fid.calculate_fid_from_inception_stats(
mu: Tensor,
sigma: Tensor,
mu_ref: Tensor,
sigma_ref: Tensor,
) → Tensor[source]#

Calculate the Fréchet Inception Distance (FID) between two sets of Inception statistics.

The Fréchet Inception Distance measures the similarity between two datasets by comparing the mean (mu) and covariance (sigma) of their Inception features. It is commonly used to evaluate the quality of images produced by generative models.
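For reference, the standard closed form of the Fréchet distance between two Gaussians is reproduced here in terms of the function's arguments, with \(\boldsymbol{\mu}, \boldsymbol{\Sigma}\) standing for mu, sigma and \(\boldsymbol{\mu}_{\mathrm{ref}}, \boldsymbol{\Sigma}_{\mathrm{ref}}\) for mu_ref, sigma_ref. The symbol names are chosen for this note; the page itself does not spell out the formula.

```latex
\[\mathrm{FID} = \left\| \boldsymbol{\mu} - \boldsymbol{\mu}_{\mathrm{ref}} \right\|^2 + \operatorname{Tr}\left( \boldsymbol{\Sigma} + \boldsymbol{\Sigma}_{\mathrm{ref}} - 2 \left( \boldsymbol{\Sigma} \boldsymbol{\Sigma}_{\mathrm{ref}} \right)^{1/2} \right)\]
```

The first term penalizes a shift in the feature means, and the trace term penalizes a mismatch in the feature covariances.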

Parameters:
  • mu (torch.Tensor) – Mean of Inception statistics for the generated dataset.

  • sigma (torch.Tensor) – Covariance matrix of Inception statistics for the generated dataset.

  • mu_ref (torch.Tensor) – Mean of Inception statistics for the reference dataset.

  • sigma_ref (torch.Tensor) – Covariance matrix of Inception statistics for the reference dataset.

Returns:

The Fréchet Inception Distance (FID) between the two datasets.

Return type:

float
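A minimal sketch of the computation, assuming the usual Fréchet distance between two Gaussians (squared distance between the means plus a trace term over the covariances). The fid_sketch function is a hypothetical illustration and may differ from the library's implementation in numerical details; it obtains the trace of the matrix square root from the eigenvalues of the covariance product, which are real and non-negative for positive semi-definite inputs.

```python
import torch


def fid_sketch(mu, sigma, mu_ref, sigma_ref):
    """Frechet distance between N(mu, sigma) and N(mu_ref, sigma_ref)."""
    diff = mu - mu_ref
    # Tr((sigma @ sigma_ref)^(1/2)) equals the sum of the square roots of the
    # eigenvalues of the product; clamp guards against tiny negative values
    eigvals = torch.linalg.eigvals(sigma @ sigma_ref)
    tr_sqrt = eigvals.real.clamp(min=0).sqrt().sum()
    return diff.dot(diff) + torch.trace(sigma) + torch.trace(sigma_ref) - 2 * tr_sqrt


eye = torch.eye(3, dtype=torch.float64)
zero = torch.zeros(3, dtype=torch.float64)
shift = torch.tensor([1.0, 0.0, 0.0], dtype=torch.float64)

# Identical statistics: the distance vanishes
same = fid_sketch(zero, eye, zero, eye)

# Shifted mean and a wider reference covariance
shifted = fid_sketch(shift, eye, zero, 4 * eye)
```

Double precision is used here because the trace term subtracts quantities of similar magnitude, which loses accuracy in single precision for high-dimensional features.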