# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Denoising score matching losses for diffusion model training."""
from __future__ import annotations
from typing import Any, Callable, Literal
import torch
from jaxtyping import Float
from tensordict import TensorDict
from torch import Tensor
from physicsnemo.diffusion.base import DiffusionModel, PredictorType
from physicsnemo.diffusion.noise_schedulers import NoiseScheduler
from physicsnemo.diffusion.utils.utils import apply_loss_weight
def _check_domain_parallel_scheduler(
    x0: torch.Tensor, scheduler: NoiseScheduler
) -> None:
    """Ensure domain-sharded data is paired with a domain-parallel scheduler.

    ``x0`` is considered domain-parallel when it exposes a non-``None``
    ``device_mesh`` attribute; inputs without one pass the check
    unconditionally.

    Parameters
    ----------
    x0 : torch.Tensor
        Clean data that is about to be noised.
    scheduler : NoiseScheduler
        Noise scheduler that will be used together with ``x0``.

    Raises
    ------
    ValueError
        If ``x0`` carries a device mesh and ``scheduler`` is not a
        ``DomainParallelNoiseScheduler`` instance.
    """
    if getattr(x0, "device_mesh", None) is None:
        # No device mesh on the data: nothing to validate.
        return

    # Imported here rather than at module level, so the domain-parallel
    # module is only loaded once sharded data is actually encountered.
    from physicsnemo.diffusion.noise_schedulers.domain_parallel import (
        DomainParallelNoiseScheduler,
    )

    if isinstance(scheduler, DomainParallelNoiseScheduler):
        return

    raise ValueError(
        "x0 is a ShardTensor (domain-parallel) but the noise scheduler "
        "is not a DomainParallelNoiseScheduler. Wrap your scheduler with "
        "DomainParallelNoiseScheduler before passing it to the loss. "
        "See physicsnemo.diffusion.noise_schedulers.DomainParallelNoiseScheduler."
    )
def _check_weight_mesh(weight: torch.Tensor, x0: torch.Tensor) -> None:
    """Raise if *x0* is a DTensor but *weight* is not on the same mesh.

    ``x0`` is treated as a DTensor when it exposes a non-``None``
    ``device_mesh`` attribute. In that case ``weight`` must expose a
    ``device_mesh`` that compares equal to the one of ``x0``. When ``x0``
    has no device mesh, no constraint is placed on ``weight``.

    Parameters
    ----------
    weight : torch.Tensor
        Element-wise weight that will multiply the squared error.
    x0 : torch.Tensor
        Clean data; its device mesh (if any) is the reference.

    Raises
    ------
    ValueError
        If ``x0`` has a device mesh and ``weight`` has none.
    ValueError
        If both have a device mesh but the two meshes differ.
    """
    mesh = getattr(x0, "device_mesh", None)
    if mesh is None:
        return

    weight_mesh = getattr(weight, "device_mesh", None)
    if weight_mesh is None:
        raise ValueError(
            "x0 is a DTensor (domain-parallel) but weight is a plain tensor. "
            "weight must be a DTensor on the same device mesh as x0. "
            "Shard or replicate weight to match x0 before passing it to the loss."
        )

    # A distributed weight is not enough: it must live on the *same* mesh,
    # otherwise the element-wise product with the loss cannot be formed.
    if weight_mesh != mesh:
        raise ValueError(
            "x0 and weight are DTensors on different device meshes "
            f"(x0: {mesh}, weight: {weight_mesh}). "
            "weight must be a DTensor on the same device mesh as x0. "
            "Shard or replicate weight to match x0 before passing it to the loss."
        )
def _maybe_promote_to_mesh(t: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    """Return *t* as a replicated ``DTensor`` on *ref*'s device mesh if needed.

    When ``ref`` is a ``ShardTensor`` (or any ``DTensor`` with a device
    mesh), plain-tensor operands have to be promoted to replicated
    ``DTensor``s on that same mesh before element-wise arithmetic; otherwise
    DTensor dispatch raises a mixed-type error.

    Parameters
    ----------
    t : torch.Tensor
        Operand to promote.
    ref : torch.Tensor
        Reference whose ``device_mesh`` attribute (if any) is the target.

    Returns
    -------
    torch.Tensor
        ``t`` itself when ``ref`` has no device mesh or when ``t`` is already
        a ``DTensor``; otherwise a ``DTensor`` built from ``t`` with a single
        ``Replicate`` placement on ``ref``'s mesh.
    """
    target_mesh = getattr(ref, "device_mesh", None)
    if target_mesh is None:
        # Non-distributed reference: leave the operand untouched.
        return t

    # Distributed imports are only needed on the domain-parallel path.
    from torch.distributed.tensor import DTensor
    from torch.distributed.tensor.placement_types import Replicate

    # NOTE(review): ``t`` is wrapped as-is as the local tensor, which
    # presumably assumes it holds the same values on every rank - confirm.
    return (
        t
        if isinstance(t, DTensor)
        else DTensor.from_local(t, device_mesh=target_mesh, placements=[Replicate()])
    )
class MSEDSMLoss:
    r"""
    Denoising score matching loss with a mean-squared-error objective, used
    to train diffusion models.

    For clean data :math:`\mathbf{x}_0`, the objective reads:

    .. math::
        \mathcal{L} = \mathbb{E}_{t, \boldsymbol{\epsilon}}
        \left[ w(t) \left\| \hat{\mathbf{x}}_0(\mathbf{x}_t, t)
        - \mathbf{x}_0 \right\|^2 \right]

    Everything the loss needs at training time is delegated to a **noise
    scheduler** that must implement the
    :class:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler` protocol.
    On every call the scheduler is used for:

    - **Time sampling** via
      :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.sample_time`:
      draws random diffusion times :math:`t` (only when the caller does not
      pass ``t`` explicitly).
    - **Noise injection** via
      :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.add_noise`:
      builds the noisy state :math:`\mathbf{x}_t` from the clean data
      :math:`\mathbf{x}_0`.
    - **Loss weighting** via
      :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.loss_weight`:
      returns the per-sample weight :math:`w(t)`.

    How the model output is turned into the estimate
    :math:`\hat{\mathbf{x}}_0` depends on ``prediction_type``:

    - ``"x0"`` (default): the model output is used directly as the
      clean-data estimate.
    - ``"score"``: the model output is a score, converted through the
      user-provided ``score_to_x0_fn`` callback.
    - ``"epsilon"``: the model output is a noise prediction, converted
      through the user-provided ``epsilon_to_x0_fn`` callback.

    .. warning::

        For domain-parallel training where ``x0`` is a ``ShardTensor``,
        the scheduler **must** be wrapped with
        :class:`~physicsnemo.diffusion.noise_schedulers.DomainParallelNoiseScheduler`
        so that sampled diffusion times are broadcast across spatial
        shards. Passing a plain scheduler with sharded data will raise a
        ``ValueError`` at runtime.

    The ``model`` argument must satisfy the
    :class:`~physicsnemo.diffusion.DiffusionModel` interface:

    .. code-block:: python

        model(
            x: torch.Tensor,  # Noisy state, shape: (B, *)
            t: torch.Tensor,  # Diffusion time, shape: (B,)
            condition: torch.Tensor | TensorDict | None = None,  # Conditioning information, shape: (B, *cond_dims)
            **model_kwargs: Any,
        ) -> torch.Tensor  # Model prediction, shape: (B, *)

    When ``prediction_type="score"``, a ``score_to_x0_fn`` callback has to be
    supplied at construction time, with the following signature:

    .. code-block:: python

        score_to_x0_fn(
            score: torch.Tensor,  # Predicted score, shape: (B, *)
            x_t: torch.Tensor,  # Noisy state, shape: (B, *)
            t: torch.Tensor,  # Diffusion time, shape: (B,)
        ) -> torch.Tensor  # Clean data estimate, shape: (B, *)

    When ``prediction_type="epsilon"``, ``epsilon_to_x0_fn`` is invoked with
    the same three positional arguments, the first one being the model's
    noise prediction instead of a score.

    For
    :class:`~physicsnemo.diffusion.noise_schedulers.LinearGaussianNoiseScheduler`
    subclasses, the
    :meth:`~physicsnemo.diffusion.noise_schedulers.LinearGaussianNoiseScheduler.score_to_x0`
    method provides a ready-made ``score_to_x0_fn``.

    Parameters
    ----------
    model : DiffusionModel
        Diffusion model to train. Can be a plain neural network, or a
        model wrapped with a preconditioner (e.g.,
        :class:`~physicsnemo.diffusion.preconditioners.EDMPreconditioner`).
        Its output is interpreted according to ``prediction_type``.
        Must satisfy the
        :class:`~physicsnemo.diffusion.DiffusionModel` protocol.
    noise_scheduler : NoiseScheduler
        Noise scheduler implementing the
        :class:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler`
        protocol, providing the methods:
        :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.sample_time`,
        :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.add_noise`, and
        :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.loss_weight`.
    prediction_type : PredictorType, default="x0"
        Type of prediction the model outputs. Use ``"x0"`` when the model
        directly predicts clean data (the most common case with standard
        preconditioners). Use ``"score"`` when the model predicts the score,
        in which case ``score_to_x0_fn`` must be provided. Use ``"epsilon"``
        when the model predicts the noise, in which case ``epsilon_to_x0_fn``
        must be provided.
    score_to_x0_fn : Callable[[Tensor, Tensor, Tensor], Tensor], optional
        Callback converting a score prediction to an
        :math:`\hat{\mathbf{x}}_0` estimate. Required when
        ``prediction_type="score"``; unused otherwise.
    epsilon_to_x0_fn : Callable[[Tensor, Tensor, Tensor], Tensor], optional
        Callback converting an epsilon (noise) prediction to an
        :math:`\hat{\mathbf{x}}_0` estimate. Required when
        ``prediction_type="epsilon"``; unused otherwise.
    reduction : Literal["none", "mean", "sum"], default="mean"
        Reduction to apply to the output: ``"none"`` returns the
        per-element loss, ``"mean"`` returns the mean over all elements,
        ``"sum"`` returns the sum over all elements.

    Raises
    ------
    ValueError
        If ``prediction_type`` is not ``"x0"``, ``"score"``, or ``"epsilon"``.
    ValueError
        If ``prediction_type="score"`` and ``score_to_x0_fn`` is ``None``.
    ValueError
        If ``prediction_type="epsilon"`` and ``epsilon_to_x0_fn`` is ``None``.
    ValueError
        If ``reduction`` is not ``"none"``, ``"mean"``, or ``"sum"``.

    Examples
    --------
    **Example 1:** Standard unconditional x0-predictor training with EDM
    schedule and preconditioner:

    >>> import torch
    >>> from physicsnemo.core import Module
    >>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
    >>> from physicsnemo.diffusion.preconditioners import EDMPreconditioner
    >>> from physicsnemo.diffusion.metrics.losses import MSEDSMLoss
    >>>
    >>> class UnconditionalModel(Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.net = torch.nn.Conv2d(3, 3, 1)
    ...     def forward(self, x, t, condition=None):
    ...         return self.net(x)
    >>>
    >>> model = UnconditionalModel()
    >>> scheduler = EDMNoiseScheduler()
    >>> precond = EDMPreconditioner(model)
    >>> loss_fn = MSEDSMLoss(precond, scheduler)
    >>> x0 = torch.randn(4, 3, 8, 8)
    >>> loss = loss_fn(x0)
    >>> loss.shape
    torch.Size([])

    **Example 2:** Conditional training with VP schedule. The model receives
    multiple conditioning tensors (an image and a vector) through a
    ``TensorDict``:

    >>> from physicsnemo.diffusion.noise_schedulers import VPNoiseScheduler
    >>> from physicsnemo.diffusion.preconditioners import VPPreconditioner
    >>> from tensordict import TensorDict
    >>>
    >>> class ConditionalModel(Module):
    ...     def __init__(self):
    ...         super().__init__()
    ...         self.img_net = torch.nn.Conv2d(6, 3, 1)
    ...         self.vec_net = torch.nn.Linear(4, 3 * 8 * 8)
    ...     def forward(self, x, t, condition=None):
    ...         y_img = condition["image"]
    ...         y_vec = self.vec_net(condition["vector"]).view_as(x)
    ...         return self.img_net(torch.cat([x, y_img], dim=1)) + y_vec
    >>>
    >>> cond_model = ConditionalModel()
    >>> scheduler_vp = VPNoiseScheduler()
    >>> precond_vp = VPPreconditioner(cond_model)
    >>> loss_fn = MSEDSMLoss(precond_vp, scheduler_vp)
    >>> condition = TensorDict({
    ...     "image": torch.randn(4, 3, 8, 8),
    ...     "vector": torch.randn(4, 4),
    ... }, batch_size=[4])
    >>> loss = loss_fn(x0, condition=condition)
    >>> loss.shape
    torch.Size([])

    **Example 3:** Training a score-predictor. The model outputs score
    predictions, and ``score_to_x0_fn`` converts them to clean data estimates
    for the loss computation. For linear-Gaussian noise schedulers, the method
    :meth:`~physicsnemo.diffusion.noise_schedulers.LinearGaussianNoiseScheduler.score_to_x0`
    provides a ready-made conversion:

    >>> scheduler = EDMNoiseScheduler()
    >>> loss_fn = MSEDSMLoss(
    ...     model=model,
    ...     noise_scheduler=scheduler,
    ...     prediction_type="score",
    ...     score_to_x0_fn=scheduler.score_to_x0,
    ... )
    >>> loss = loss_fn(x0)
    >>> loss.shape
    torch.Size([])

    **Example 4:** Bare-bones approach without any built-in scheduler or
    preconditioner. This shows how to plug custom components into
    :class:`MSEDSMLoss` by implementing the
    :class:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler` and
    :class:`~physicsnemo.diffusion.DiffusionModel` protocols from scratch.
    It also demonstrates passing externally sampled diffusion times via the
    ``t`` argument for per-sigma-bin loss tracking:

    >>> import math
    >>>
    >>> # Custom noise scheduler (EDM-like, sigma(t)=t, alpha(t)=1)
    >>> class MyScheduler:
    ...     def sample_time(self, N, *, device=None, dtype=None):
    ...         return (0.002 * (80 / 0.002) ** torch.rand(N, device=device, dtype=dtype))
    ...     def add_noise(self, x0, time):
    ...         return x0 + time.view(-1, 1, 1, 1) * torch.randn_like(x0)
    ...     def loss_weight(self, t):
    ...         return (t**2 + 0.5**2) / (t * 0.5) ** 2
    ...     def score_to_x0(self, score, x_t, t):
    ...         return x_t + t.view(-1, 1, 1, 1)**2 * score
    ...     def timesteps(self, n, *, device=None, dtype=None):
    ...         return torch.zeros(1)
    ...     def init_latents(self, s, tN, *, device=None, dtype=None):
    ...         return torch.zeros(1)
    ...     def get_denoiser(self, **kw):
    ...         return lambda x, t: x
    >>>
    >>> # Custom model with single-tensor conditioning
    >>> class ConditionalModel:
    ...     def __init__(self):
    ...         self.w = torch.randn(3, 6, 1, 1) * 0.01
    ...     def __call__(self, x, t, condition=None, **kw):
    ...         return torch.nn.functional.conv2d(
    ...             torch.cat([x, condition], dim=1), self.w)
    >>>
    >>> my_scheduler = MyScheduler()
    >>> cond_model = ConditionalModel()
    >>> loss_fn = MSEDSMLoss(cond_model, my_scheduler)
    >>> x0 = torch.randn(2, 3, 8, 8)
    >>> cond = torch.randn(2, 3, 8, 8)  # single-tensor conditioning
    >>>
    >>> # Sample times externally for per-sigma-bin loss tracking
    >>> t = my_scheduler.sample_time(x0.shape[0], device=x0.device, dtype=x0.dtype)
    >>> loss = loss_fn(x0, condition=cond, t=t)
    >>> loss.shape
    torch.Size([])
    >>> t.shape  # t is available for diagnostics after the loss call
    torch.Size([2])
    >>>
    >>> # Also works with score prediction + custom conversion
    >>> loss_fn_score = MSEDSMLoss(
    ...     cond_model, my_scheduler,
    ...     prediction_type="score",
    ...     score_to_x0_fn=my_scheduler.score_to_x0,
    ... )
    >>> loss = loss_fn_score(x0, condition=cond)
    >>> loss.shape
    torch.Size([])
    """

    def __init__(
        self,
        model: DiffusionModel,
        noise_scheduler: NoiseScheduler,
        prediction_type: PredictorType = "x0",
        score_to_x0_fn: Callable[
            [torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor
        ]
        | None = None,
        epsilon_to_x0_fn: Callable[
            [torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor
        ]
        | None = None,
        reduction: Literal["none", "mean", "sum"] = "mean",
    ) -> None:
        self.model = model
        self.noise_scheduler = noise_scheduler

        # Pick the callable mapping (prediction, x_t, t) to an x0 estimate.
        if prediction_type == "x0":
            # The model already predicts x0: pass its output through.
            self._to_x0 = lambda prediction, x_t, t: prediction
        elif prediction_type == "score":
            if score_to_x0_fn is None:
                raise ValueError(
                    "score_to_x0_fn must be provided when prediction_type='score'."
                )
            self._to_x0 = score_to_x0_fn
        elif prediction_type == "epsilon":
            if epsilon_to_x0_fn is None:
                raise ValueError(
                    "epsilon_to_x0_fn must be provided when prediction_type='epsilon'."
                )
            self._to_x0 = epsilon_to_x0_fn
        else:
            raise ValueError(
                f"prediction_type must be 'x0', 'score', or 'epsilon', "
                f"got '{prediction_type}'."
            )

        # Resolve the reduction applied to the per-element weighted loss.
        reducers = {
            "none": lambda per_element: per_element,
            "mean": lambda per_element: per_element.mean(),
            "sum": lambda per_element: per_element.sum(),
        }
        reducer = reducers.get(reduction)
        if reducer is None:
            raise ValueError(
                f"reduction must be 'none', 'mean', or 'sum', got '{reduction}'."
            )
        self._reduce = reducer

    def __call__(
        self,
        x0: Float[Tensor, " B *dims"],
        t: Float[Tensor, " B"] | None = None,
        condition: Float[Tensor, " B *cond_dims"] | TensorDict | None = None,
        **model_kwargs: Any,
    ) -> Float[Tensor, " B *dims"] | Float[Tensor, ""]:
        r"""
        Evaluate the denoising score matching loss on a batch of clean data.

        Parameters
        ----------
        x0 : Tensor
            Clean data of shape :math:`(B, *)` where :math:`B` is the batch
            size and :math:`*` denotes any number of additional dimensions.
        t : Tensor or None, optional, default=None
            Pre-sampled diffusion time values of shape :math:`(B,)`. When
            ``None`` (the default), times are sampled internally via
            :meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.sample_time`.
            Passing explicit times is useful when the caller needs access
            to the sampled values for diagnostics (e.g., per-sigma-bin
            loss tracking).
        condition : Tensor, TensorDict, or None, optional, default=None
            Conditioning information passed to the model. See
            :class:`~physicsnemo.diffusion.DiffusionModel` for details.
        **model_kwargs : Any
            Additional keyword arguments forwarded to the model.

        Returns
        -------
        Tensor
            If ``reduction="none"``, the per-element weighted loss with same
            shape :math:`(B, *)` as ``x0``. If ``reduction="mean"``, or
            ``reduction="sum"``, a scalar tensor.
        """
        scheduler = self.noise_scheduler

        # Domain-parallel validation is skipped while torch.compile traces.
        if not torch.compiler.is_compiling():
            _check_domain_parallel_scheduler(x0, scheduler)

        batch_size = x0.shape[0]
        if t is None:
            # One diffusion time per sample, on the data's device and dtype.
            t = scheduler.sample_time(batch_size, device=x0.device, dtype=x0.dtype)

        noisy_state = scheduler.add_noise(x0, t)
        raw_output = self.model(noisy_state, t, condition=condition, **model_kwargs)
        x0_estimate = self._to_x0(raw_output, noisy_state, t)
        squared_error = (x0_estimate - x0) ** 2

        # Time-dependent weight w(t), reshaped by apply_loss_weight using
        # the rank of x0 and promoted to the loss' device mesh when sharded.
        time_weight = apply_loss_weight(scheduler.loss_weight(t), x0.ndim)
        time_weight = _maybe_promote_to_mesh(time_weight, squared_error)

        return self._reduce(time_weight * squared_error)
class WeightedMSEDSMLoss:
r"""
Weighted mean-squared-error denoising score matching loss.
Identical to :class:`MSEDSMLoss` but accepts an additional ``weight``
argument that multiplies the per-element squared error.
.. math::
\mathcal{L} = \mathbb{E}_{t, \boldsymbol{\epsilon}}
\left[ w(t) \left\| \mathbf{m} \odot
\left(\hat{\mathbf{x}}_0(\mathbf{x}_t, t)
- \mathbf{x}_0\right) \right\|^2 \right]
where :math:`\mathbf{m}` is the element-wise weight (e.g., a binary
mask). A common use case is masking out certain spatial regions or
channels of the state.
.. note::
The ``weight`` argument is **not** related to the time-dependent loss
weight :math:`w(t)` provided by the noise scheduler.
.. warning::
For domain-parallel training where ``x0`` is a ``DTensor``
(e.g., a :class:`~physicsnemo.domain_parallel.ShardTensor`), ``weight`` must also be a ``DTensor``
on the same device mesh. The loss function does **not**
automatically promote ``weight``; callers are responsible for
sharding or replicating it to match ``x0``. Passing a plain
tensor ``weight`` with a sharded ``x0`` will raise a
``ValueError`` at runtime.
For more details on prediction types, expected signatures, and
additional examples, see :class:`MSEDSMLoss`.
Parameters
----------
model : DiffusionModel
Diffusion model to train. Must satisfy the
:class:`~physicsnemo.diffusion.DiffusionModel` protocol.
noise_scheduler : NoiseScheduler
Noise scheduler implementing the
:class:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler`
protocol.
prediction_type : PredictorType, default="x0"
Type of prediction the model outputs. See :class:`MSEDSMLoss`.
score_to_x0_fn : callable, optional
Callback to convert a score prediction to an
:math:`\hat{\mathbf{x}}_0` estimate. Required when
``prediction_type="score"``.
epsilon_to_x0_fn : callable, optional
Callback to convert an epsilon (noise) prediction to an
:math:`\hat{\mathbf{x}}_0` estimate. Required when
``prediction_type="epsilon"``.
reduction : {"none", "mean", "sum"}, default="mean"
Reduction to apply to the output: ``"none"`` returns the
per-element loss, ``"mean"`` the mean, ``"sum"`` the sum.
Examples
--------
Apply a spatial mask so the loss is computed only over unmasked regions:
>>> import torch
>>> from physicsnemo.core import Module
>>> from physicsnemo.diffusion.noise_schedulers import EDMNoiseScheduler
>>> from physicsnemo.diffusion.preconditioners import EDMPreconditioner
>>> from physicsnemo.diffusion.metrics.losses import WeightedMSEDSMLoss
>>>
>>> class UnconditionalModel(Module):
... def __init__(self):
... super().__init__()
... self.net = torch.nn.Conv2d(3, 3, 1)
... def forward(self, x, t, condition=None):
... return self.net(x)
>>>
>>> model = UnconditionalModel()
>>> scheduler = EDMNoiseScheduler()
>>> precond = EDMPreconditioner(model)
>>> loss_fn = WeightedMSEDSMLoss(precond, scheduler)
>>>
>>> x0 = torch.randn(4, 3, 8, 8)
>>> # Binary mask: zero out the left half of the spatial domain
>>> mask = torch.ones(4, 3, 8, 8)
>>> mask[:, :, :, :4] = 0.0
>>> loss = loss_fn(x0, weight=mask)
>>> loss.shape
torch.Size([])
"""
def __init__(
self,
model: DiffusionModel,
noise_scheduler: NoiseScheduler,
prediction_type: PredictorType = "x0",
score_to_x0_fn: Callable[
[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor
]
| None = None,
epsilon_to_x0_fn: Callable[
[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor
]
| None = None,
reduction: Literal["none", "mean", "sum"] = "mean",
) -> None:
self.model = model
self.noise_scheduler = noise_scheduler
match prediction_type:
case "x0":
self._to_x0 = lambda prediction, x_t, t: prediction
case "score":
if score_to_x0_fn is None:
raise ValueError(
"score_to_x0_fn must be provided when prediction_type='score'."
)
self._to_x0 = score_to_x0_fn
case "epsilon":
if epsilon_to_x0_fn is None:
raise ValueError(
"epsilon_to_x0_fn must be provided when prediction_type='epsilon'."
)
self._to_x0 = epsilon_to_x0_fn
case _:
raise ValueError(
f"prediction_type must be 'x0', 'score', or 'epsilon', "
f"got '{prediction_type}'."
)
# Define the reduction callbacks
_reductions = {
"none": lambda x: x,
"mean": lambda x: x.mean(),
"sum": lambda x: x.sum(),
}
if reduction not in _reductions:
raise ValueError(
f"reduction must be 'none', 'mean', or 'sum', got '{reduction}'."
)
self._reduce = _reductions[reduction]
def __call__(
self,
x0: Float[Tensor, " B *dims"],
weight: Float[Tensor, " B *dims"],
t: Float[Tensor, " B"] | None = None,
condition: Float[Tensor, " B *cond_dims"] | TensorDict | None = None,
**model_kwargs: Any,
) -> Float[Tensor, " B *dims"] | Float[Tensor, ""]:
r"""
Compute the weighted denoising score matching loss.
Parameters
----------
x0 : Tensor
Clean data of shape :math:`(B, *)` where :math:`B` is the batch
size and :math:`*` denotes any number of additional dimensions.
weight : Tensor
Per-element weight of shape :math:`(B, *)`, same shape as
``x0``. For binary masking, use 0 for masked elements and 1
for active elements. When ``x0`` is a ``DTensor``
(domain-parallel), ``weight`` must also be a ``DTensor`` on
the same device mesh.
t : Tensor or None, optional, default=None
Pre-sampled diffusion time values of shape :math:`(B,)`. When
``None`` (the default), times are sampled internally via
:meth:`~physicsnemo.diffusion.noise_schedulers.NoiseScheduler.sample_time`.
Passing explicit times is useful when the caller needs access
to the sampled values for diagnostics (e.g., per-sigma-bin
loss tracking).
condition : Tensor, TensorDict, or None, optional, default=None
Conditioning information passed to the model. See
:class:`~physicsnemo.diffusion.DiffusionModel` for details.
**model_kwargs : Any
Additional keyword arguments forwarded to the model
Returns
-------
Tensor
If ``reduction="none"``, the per-element weighted loss with same
shape :math:`(B, *)` as ``x0``. If ``reduction="mean"``, or
``reduction="sum"``, a scalar tensor.
"""
if not torch.compiler.is_compiling():
# Validation checks for domain-parallel training
_check_domain_parallel_scheduler(x0, self.noise_scheduler)
_check_weight_mesh(weight, x0)
B = x0.shape[0]
if t is None:
t = self.noise_scheduler.sample_time(B, device=x0.device, dtype=x0.dtype)
x_t = self.noise_scheduler.add_noise(x0, t)
prediction = self.model(x_t, t, condition=condition, **model_kwargs)
x0_pred = self._to_x0(prediction, x_t, t)
loss = weight * (x0_pred - x0) ** 2
w = self.noise_scheduler.loss_weight(t)
w = apply_loss_weight(w, x0.ndim)
w = _maybe_promote_to_mesh(w, loss)
loss = w * loss
return self._reduce(loss)