Source code for physicsnemo.nn.module.embedding_layers

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Literal

import numpy as np
import torch
from jaxtyping import Float
from torch import Tensor

from physicsnemo.core.module import Module
from physicsnemo.nn.module.utils.utils import _validate_amp


class FourierEmbedding(torch.nn.Module):
    """
    Generates Fourier embeddings for timesteps, primarily used in the NCSN++
    architecture.

    This class generates embeddings by computing the outer product of the input
    tensor `x` with internally stored random frequencies, then concatenating the
    cosine and sine of the result.

    Parameters
    ----------
    num_channels : int
        The number of channels in the embedding. ``num_channels // 2`` random
        frequencies are stored, and concatenating their cosine and sine
        components yields an embedding of size ``num_channels``.
    scale : int, optional
        A scale factor applied to the random frequencies, controlling their
        range and thereby the frequency of oscillations in the embedding space.
        By default 16.
    amp_mode : bool, optional
        A boolean flag indicating whether mixed-precision (AMP) training is
        enabled. Defaults to False.
    """

    def __init__(self, num_channels: int, scale: int = 16, amp_mode: bool = False):
        super().__init__()
        self.register_buffer("freqs", torch.randn(num_channels // 2) * scale)
        self.amp_mode = amp_mode
    def forward(self, x):
        freqs = self.freqs
        _validate_amp(self.amp_mode)
        # Outside of AMP, match the frequency buffer's dtype to the input.
        if not self.amp_mode and x.dtype != self.freqs.dtype:
            freqs = self.freqs.to(x.dtype)
        # Outer product with angular frequencies, then concatenate cos and sin:
        # (B,) -> (B, num_channels // 2) -> (B, num_channels).
        x = x.ger(2 * np.pi * freqs)
        x = torch.cat([x.cos(), x.sin()], dim=1)
        return x
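
# Usage sketch (illustrative, not part of the module's API): the output width
# equals num_channels, since each of the num_channels // 2 random frequencies
# contributes one cosine and one sine channel. Variable names are hypothetical.
#
#   emb = FourierEmbedding(num_channels=128)
#   t = torch.rand(16)       # (B,) batch of timesteps
#   out = emb(t)             # (16, 128): cosines of 64 freqs, then sines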
class PositionalEmbedding(torch.nn.Module):
    """
    A module for generating positional embeddings based on timesteps.
    This embedding technique is employed in the DDPM++ and ADM architectures.

    Parameters
    ----------
    num_channels : int
        Number of channels for the embedding.
    max_positions : int, optional
        Maximum number of positions for the embeddings, by default 10000.
    endpoint : bool, optional
        If True, the geometric frequency grid includes its endpoint, so the
        smallest frequency is exactly 1 / max_positions. By default False.
    amp_mode : bool, optional
        A boolean flag indicating whether mixed-precision (AMP) training is
        enabled. Defaults to False.
    learnable : bool, optional
        If True, the frequency embedding is passed through a learnable MLP.
        Defaults to False.
    freq_embed_dim : int, optional
        The dimension of the frequency embedding. Defaults to None, in which
        case it is set to num_channels.
    mlp_hidden_dim : int, optional
        The dimension of the hidden layer in the MLP. Defaults to None, in
        which case it is set to 2 * num_channels. Only applicable if learnable
        is True; otherwise this parameter is ignored.
    embed_fn : Literal["cos_sin", "np_sin_cos"], optional
        The function used to embed into sin/cos features (allows swapping the
        order of sin/cos). Defaults to "cos_sin". Options:

        - "cos_sin": uses torch to compute frequency embeddings and returns
          them in the order (cos, sin)
        - "np_sin_cos": uses numpy to compute frequency embeddings and returns
          them in the order (sin, cos)
    """

    def __init__(
        self,
        num_channels: int,
        max_positions: int = 10000,
        endpoint: bool = False,
        amp_mode: bool = False,
        learnable: bool = False,
        freq_embed_dim: int | None = None,
        mlp_hidden_dim: int | None = None,
        embed_fn: Literal["cos_sin", "np_sin_cos"] = "cos_sin",
    ):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint
        self.amp_mode = amp_mode
        self.learnable = learnable
        self.embed_fn = embed_fn
        if freq_embed_dim is None:
            freq_embed_dim = num_channels
        self.freq_embed_dim = freq_embed_dim
        if learnable:
            if mlp_hidden_dim is None:
                mlp_hidden_dim = 2 * num_channels
            self.mlp = torch.nn.Sequential(
                torch.nn.Linear(freq_embed_dim, mlp_hidden_dim, bias=True),
                torch.nn.SiLU(),
                torch.nn.Linear(mlp_hidden_dim, num_channels, bias=True),
            )
        freqs = torch.arange(
            start=0, end=self.freq_embed_dim // 2, dtype=torch.float32
        )
        freqs = freqs / (self.freq_embed_dim // 2 - (1 if self.endpoint else 0))
        freqs = (1 / self.max_positions) ** freqs
        self.register_buffer("freqs", freqs, persistent=False)
    def forward(self, x):
        # Outer product of timesteps with the geometric frequencies:
        # (B,) -> (B, freq_embed_dim // 2).
        x = torch.outer(x, self.freqs)
        if self.embed_fn == "cos_sin":
            x = torch.cat([x.cos(), x.sin()], dim=1)
        elif self.embed_fn == "np_sin_cos":
            x = torch.cat([x.sin(), x.cos()], dim=1)
        # Optionally refine the frequency embedding with the learnable MLP.
        if self.learnable:
            x = self.mlp(x)
        return x
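
# Usage sketch (illustrative names, not part of the module): with learnable=True
# the frequency embedding of width freq_embed_dim is mapped through the SiLU MLP
# back to num_channels, so the output shape is (B, num_channels) in both modes.
#
#   emb = PositionalEmbedding(num_channels=128, learnable=True)
#   t = torch.rand(16)       # (B,) timesteps
#   out = emb(t)             # (16, 128): frequency embedding passed through MLP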
class SinusoidalTimestepEmbedding(Module):
    r"""Sinusoidal embedding for timesteps (e.g. for modulation / diffusion).

    For input :math:`x` (timestep) and :math:`D =` ``num_channels`` (even), the
    frequencies are :math:`\omega_k = k\pi` for :math:`k = 1, \ldots, D/2`, and
    the output is the concatenation of cosines and sines:

    .. math::

        \mathrm{embed}(x) = \big[
            \cos(x \omega_1), \ldots, \cos(x \omega_{D/2}),
            \sin(x \omega_1), \ldots, \sin(x \omega_{D/2})
        \big] \in \mathbb{R}^D.

    This is a simpler scheme than :class:`PositionalEmbedding` (which uses
    geometric frequencies and an optional learnable MLP) and works well for
    timestep conditioning in the ModAFNO model.

    Parameters
    ----------
    num_channels : int
        Number of output channels. Must be even.

    Forward
    -------
    x : torch.Tensor
        Input tensor, shape ``B ...`` (e.g. :math:`(B,)` or :math:`(B, 1)`),
        containing timesteps.

    Outputs
    -------
    torch.Tensor
        Output tensor, shape ``B D`` with :math:`D =` ``num_channels``.

    Examples
    --------
    >>> import torch
    >>> from physicsnemo.nn.module.embedding_layers import SinusoidalTimestepEmbedding
    >>> emb = SinusoidalTimestepEmbedding(num_channels=64)
    >>> t = torch.tensor([0.0, 0.5, 1.0])
    >>> out = emb(t)
    >>> out.shape
    torch.Size([3, 64])

    See Also
    --------
    :class:`PositionalEmbedding` : DDPM/ADM-style positional embedding
        (geometric frequencies, optional MLP).
    """

    def __init__(self, num_channels: int):
        super().__init__()
        self.num_channels = num_channels
        freqs = torch.pi * torch.arange(
            start=1, end=self.num_channels // 2 + 1, dtype=torch.float32
        )
        self.register_buffer("freqs", freqs)
    def forward(self, x: Float[Tensor, "B ..."]) -> Float[Tensor, "B D"]:
        r"""Forward pass computing sinusoidal embeddings."""
        x = x.view(-1).outer(self.freqs.to(x.dtype))
        x = torch.cat([x.cos(), x.sin()], dim=1)
        return x
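
# Sanity-check sketch (hypothetical values, not part of the module): since the
# frequencies are k*pi, evaluating at x = 1.0 makes the cosine half alternate
# as cos(k*pi) = (-1)^k.
#
#   emb = SinusoidalTimestepEmbedding(num_channels=8)
#   out = emb(torch.tensor([1.0]))
#   # out[0, :4] is [cos(pi), cos(2*pi), cos(3*pi), cos(4*pi)] = [-1, 1, -1, 1]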
class OneHotEmbedding(Module):
    r"""Soft one-hot embedding for normalized timesteps in :math:`[0, 1]`.

    The :math:`i`-th channel is :math:`\max(0, 1 - |t(D-1) - i|)`, giving a
    soft one-hot vector of dimension :math:`D`. Used for timestep conditioning
    in ModAFNO when ``method="onehot"``.

    Parameters
    ----------
    num_channels : int
        Number of channels (embedding dimension).

    Forward
    -------
    t : torch.Tensor
        Input tensor, shape ``B ...``, with normalized timesteps in ``[0, 1]``.

    Outputs
    -------
    torch.Tensor
        Output tensor, shape ``B D``.

    Examples
    --------
    >>> import torch
    >>> from physicsnemo.nn.module.embedding_layers import OneHotEmbedding
    >>> emb = OneHotEmbedding(num_channels=64)
    >>> t = torch.tensor([[0.0], [0.5], [1.0]])
    >>> out = emb(t)
    >>> out.shape
    torch.Size([3, 64])
    """

    def __init__(self, num_channels: int):
        super().__init__()
        self.num_channels = num_channels
        ind = torch.arange(num_channels, dtype=torch.float32)
        self.register_buffer("indices", ind.view(1, -1))
    def forward(self, t: Float[Tensor, "B ..."]) -> Float[Tensor, "B D"]:
        r"""Forward pass computing soft one-hot embeddings."""
        ind = (t.view(-1, 1) * (self.num_channels - 1)).to(self.indices.dtype)
        return torch.clamp(1 - torch.abs(ind - self.indices), min=0)
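
# Worked example (hypothetical values): with num_channels = 5, t = 0.6 maps to
# the fractional index t * (D - 1) = 2.4, so the unit of mass is split between
# channels 2 and 3 in proportion 0.6 / 0.4.
#
#   emb = OneHotEmbedding(num_channels=5)
#   emb(torch.tensor([0.6]))   # tensor([[0.0, 0.0, 0.6, 0.4, 0.0]])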