Source code for physicsnemo.nn.module.gumbel_softmax

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
from jaxtyping import Float


[docs] def gumbel_softmax( logits: Float[torch.Tensor, "... num_categories"], tau: torch.Tensor | float = 1.0, ) -> Float[torch.Tensor, "... num_categories"]: r""" Implementation of Gumbel Softmax from Transolver++. Applies a differentiable approximation to sampling from a categorical distribution using the Gumbel-Softmax trick. Original code: https://github.com/thuml/Transolver_plus/blob/main/models/Transolver_plus.py#L69 Parameters ---------- logits : torch.Tensor Input logits tensor of shape :math:`(*, K)` where :math:`K` is the number of categories. tau : torch.Tensor | float, optional, default=1.0 Temperature parameter. Lower values make the distribution more concentrated. Returns ------- torch.Tensor Gumbel-Softmax output of the same shape as ``logits``. """ # Sample Gumbel noise u = torch.rand_like(logits) gumbel_noise = -torch.log(-torch.log(u + 1e-8) + 1e-8) # Add noise and apply temperature-scaled softmax y = logits + gumbel_noise y = y / tau y = torch.nn.functional.softmax(y, dim=-1) return y
[docs] class GumbelSoftmax(nn.Module): r"""Gumbel-Softmax module for differentiable categorical sampling. This module wraps the :func:`gumbel_softmax` function as an ``nn.Module``, allowing it to be used as a layer in neural network architectures. The Gumbel-Softmax trick provides a differentiable approximation to sampling from a categorical distribution, enabling end-to-end training of models with discrete latent variables. Parameters ---------- tau : float, optional, default=1.0 Initial temperature parameter. Lower values make the distribution more concentrated (closer to one-hot). Can be modified after initialization. learnable : bool, optional, default=False If ``True``, the temperature parameter is registered as a learnable ``nn.Parameter``. If ``False``, it is a fixed buffer. Examples -------- >>> import torch >>> gs = GumbelSoftmax(tau=0.5) >>> logits = torch.randn(2, 10) # batch_size=2, num_categories=10 >>> probs = gs(logits) >>> probs.shape torch.Size([2, 10]) >>> torch.allclose(probs.sum(dim=-1), torch.ones(2)) # Each row sums to 1 True >>> # With learnable temperature >>> gs_learnable = GumbelSoftmax(tau=1.0, learnable=True) >>> gs_learnable.tau.requires_grad True See Also -------- :func:`gumbel_softmax` : Functional implementation of Gumbel-Softmax. """ def __init__(self, tau: float = 1.0, learnable: bool = False): super().__init__() if learnable: self.tau = nn.Parameter(torch.tensor(tau)) else: self.register_buffer("tau", torch.tensor(tau))
[docs] def forward( self, logits: Float[torch.Tensor, "... num_categories"] ) -> Float[torch.Tensor, "... num_categories"]: r"""Apply Gumbel-Softmax to input logits. Parameters ---------- logits : torch.Tensor Input logits tensor of shape :math:`(*, K)` where :math:`K` is the number of categories. Returns ------- torch.Tensor Gumbel-Softmax output of the same shape as ``logits``. """ return gumbel_softmax(logits, tau=self.tau)