Source code for physicsnemo.nn.module.attention_layers

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Any, Dict

import numpy as np
import torch
from torch import nn

from physicsnemo.nn.module.conv_layers import Conv2d
from physicsnemo.nn.module.group_norm import get_group_norm
from physicsnemo.nn.module.utils import get_earth_position_index, trunc_normal_


class AttentionOp(torch.autograd.Function):
    """
    Attention weight computation, i.e., softmax(Q^T * K).
    Performs all computation using FP32, but uses the original datatype for
    inputs/outputs/gradients to conserve memory.
    """

    @staticmethod
    def forward(ctx, q, k):
        w = (
            torch.einsum(
                "ncq,nck->nqk",
                q.to(torch.float32),
                (k / torch.sqrt(torch.tensor(k.shape[1]))).to(torch.float32),
            )
            .softmax(dim=2)
            .to(q.dtype)
        )
        ctx.save_for_backward(q, k, w)
        return w

    @staticmethod
    def backward(ctx, dw):
        q, k, w = ctx.saved_tensors
        db = torch._softmax_backward_data(
            grad_output=dw.to(torch.float32),
            output=w.to(torch.float32),
            dim=2,
            input_dtype=torch.float32,
        )
        dq = torch.einsum("nck,nqk->ncq", k.to(torch.float32), db).to(
            q.dtype
        ) / np.sqrt(k.shape[1])
        dk = torch.einsum("ncq,nqk->nck", q.to(torch.float32), db).to(
            k.dtype
        ) / np.sqrt(k.shape[1])
        return dq, dk
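

# Illustrative usage sketch only (not part of the original module): shows the
# tensor shapes AttentionOp expects, inferred from the "ncq,nck->nqk" einsum
# above. The sizes below are arbitrary assumptions.
def _example_attention_op() -> torch.Tensor:
    q = torch.randn(2, 16, 64)  # (batch * heads, channels, num_queries)
    k = torch.randn(2, 16, 64)  # (batch * heads, channels, num_keys)
    # Returns softmax(Q^T K / sqrt(C)) with shape (2, 64, 64)
    return AttentionOp.apply(q, k)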


class UNetAttention(torch.nn.Module):
    """
    Self-attention block used in U-Net-style architectures, such as DDPM++, NCSN++,
    and ADM. Applies GroupNorm followed by multi-head self-attention and a
    projection layer.

    Parameters
    ----------
    out_channels : int
        Number of channels :math:`C` in the input and output feature maps.
    num_heads : int
        Number of attention heads. Must be a positive integer.
    eps : float, optional, default=1e-5
        Epsilon value for numerical stability in GroupNorm.
    init_zero : dict, optional, default={'init_weight': 0}
        Initialization parameters with zero weights for certain layers.
    init_attn : dict, optional, default=None
        Initialization parameters specific to attention mechanism layers.
        Defaults to `init` if not provided.
    init : dict, optional, default={}
        Initialization parameters for convolutional and linear layers.
    use_apex_gn : bool, optional, default=False
        Whether to use Apex GroupNorm for the NHWC layout. Must be set to False
        when running on CPU.
    amp_mode : bool, optional, default=False
        Whether mixed-precision (AMP) training is enabled.
    fused_conv_bias : bool, optional, default=False
        Whether the bias is passed as a parameter of conv2d.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape :math:`(B, C, H, W)`, where :math:`B` is batch size,
        :math:`C` is `out_channels`, and :math:`H, W` are spatial dimensions.

    Outputs
    -------
    torch.Tensor
        Output tensor of the same shape as input: :math:`(B, C, H, W)`.
    """

    def __init__(
        self,
        *,
        out_channels: int,
        num_heads: int,
        eps: float = 1e-5,
        init_zero: Dict[str, Any] = dict(init_weight=0),
        init_attn: Any = None,
        init: Dict[str, Any] = dict(),
        use_apex_gn: bool = False,
        amp_mode: bool = False,
        fused_conv_bias: bool = False,
    ) -> None:
        super().__init__()

        # Parameters validation
        if not isinstance(num_heads, int) or num_heads <= 0:
            raise ValueError(
                f"`num_heads` must be a positive integer, but got {num_heads}"
            )
        if out_channels % num_heads != 0:
            raise ValueError(
                f"`out_channels` must be divisible by `num_heads`, "
                f"but got {out_channels} and {num_heads}"
            )

        self.num_heads = num_heads
        self.norm = get_group_norm(
            num_channels=out_channels,
            eps=eps,
            use_apex_gn=use_apex_gn,
            amp_mode=amp_mode,
        )
        self.qkv = Conv2d(
            in_channels=out_channels,
            out_channels=out_channels * 3,
            kernel=1,
            fused_conv_bias=fused_conv_bias,
            amp_mode=amp_mode,
            **(init_attn if init_attn is not None else init),
        )
        self.proj = Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel=1,
            fused_conv_bias=fused_conv_bias,
            amp_mode=amp_mode,
            **init_zero,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1: torch.Tensor = self.qkv(self.norm(x))
        # NOTE: V1.0.1 implementation
        # q, k, v = x1.reshape(
        #     x.shape[0] * self.num_heads, x.shape[1] // self.num_heads, 3, -1
        # ).unbind(2)
        # w = AttentionOp.apply(q, k)
        # attn = torch.einsum("nqk,nck->ncq", w, v)
        qkv = (
            x1.reshape(x.shape[0], self.num_heads, x.shape[1] // self.num_heads, 3, -1)
        ).permute(0, 1, 4, 3, 2)
        (q, k, v) = (qkv[..., i, :] for i in range(3))
        attn = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, scale=1 / math.sqrt(k.shape[-1])
        )
        attn = attn.transpose(-1, -2)
        x: torch.Tensor = self.proj(attn.reshape(*x.shape)).add_(x)
        return x
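

# Illustrative usage sketch only (not part of the original module): a minimal
# CPU forward pass through UNetAttention, so use_apex_gn is left at its default
# of False. The channel count, head count, and spatial sizes are arbitrary
# assumptions.
def _example_unet_attention() -> torch.Tensor:
    block = UNetAttention(out_channels=64, num_heads=4)
    x = torch.randn(2, 64, 32, 32)  # (B, C, H, W)
    return block(x)  # output has the same shape as x: (2, 64, 32, 32)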


class EarthAttention3D(nn.Module):
    """
    Revised from WeatherLearn https://github.com/lizhuoq/WeatherLearn

    3D window attention with earth position bias.
    It supports both shifted and non-shifted windows.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): [pressure levels, latitude, longitude]
        window_size (tuple[int]): [pressure levels, latitude, longitude]
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value.
            Default: True
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set.
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        input_resolution,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wpl, Wlat, Wlon
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.type_of_windows = (input_resolution[0] // window_size[0]) * (
            input_resolution[1] // window_size[1]
        )

        self.earth_position_bias_table = nn.Parameter(
            torch.zeros(
                (window_size[0] ** 2)
                * (window_size[1] ** 2)
                * (window_size[2] * 2 - 1),
                self.type_of_windows,
                num_heads,
            )
        )  # Wpl**2 * Wlat**2 * Wlon*2-1, Npl//Wpl * Nlat//Wlat, nH

        earth_position_index = get_earth_position_index(
            window_size
        )  # Wpl*Wlat*Wlon, Wpl*Wlat*Wlon
        self.register_buffer("earth_position_index", earth_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.earth_position_bias_table = trunc_normal_(
            self.earth_position_bias_table, std=0.02
        )

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor, mask=None):
        """
        Args:
            x: input features with shape of (B * num_lon, num_pl*num_lat, N, C)
            mask: (0/-inf) mask with shape of
                (num_lon, num_pl*num_lat, Wpl*Wlat*Wlon, Wpl*Wlat*Wlon)
        """
        B_, nW_, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B_, nW_, N, 3, self.num_heads, C // self.num_heads)
            .permute(3, 0, 4, 1, 2, 5)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        earth_position_bias = self.earth_position_bias_table[
            self.earth_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1] * self.window_size[2],
            self.window_size[0] * self.window_size[1] * self.window_size[2],
            self.type_of_windows,
            -1,
        )  # Wpl*Wlat*Wlon, Wpl*Wlat*Wlon, num_pl*num_lat, nH
        earth_position_bias = earth_position_bias.permute(
            3, 2, 0, 1
        ).contiguous()  # nH, num_pl*num_lat, Wpl*Wlat*Wlon, Wpl*Wlat*Wlon
        attn = attn + earth_position_bias.unsqueeze(0)

        if mask is not None:
            nLon = mask.shape[0]
            attn = attn.view(
                B_ // nLon, nLon, self.num_heads, nW_, N, N
            ) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, nW_, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).permute(0, 2, 3, 1, 4).reshape(B_, nW_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
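

# Illustrative usage sketch only (not part of the original module): a minimal
# EarthAttention3D forward pass without a shift mask. The resolution, window
# size, and channel counts below are arbitrary assumptions, chosen so that the
# expected input shape (B * num_lon, num_pl * num_lat, Wpl * Wlat * Wlon, C)
# is consistent with the constructor arguments.
def _example_earth_attention_3d() -> torch.Tensor:
    attn = EarthAttention3D(
        dim=96,
        input_resolution=(2, 4, 8),  # (pressure levels, latitude, longitude)
        window_size=(2, 2, 4),       # (Wpl, Wlat, Wlon)
        num_heads=4,
    )
    # B=1, 2 longitude windows, 2 pressure*latitude windows, 16 tokens per window
    x = torch.randn(2, 2, 16, 96)
    return attn(x)  # output has the same shape as x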


class EarthAttention2D(nn.Module):
    """
    Revised from WeatherLearn https://github.com/lizhuoq/WeatherLearn

    2D window attention with earth position bias.
    It supports both shifted and non-shifted windows.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): [latitude, longitude]
        window_size (tuple[int]): [latitude, longitude]
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value.
            Default: True
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set.
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        input_resolution,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wlat, Wlon
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.type_of_windows = input_resolution[0] // window_size[0]

        self.earth_position_bias_table = nn.Parameter(
            torch.zeros(
                (window_size[0] ** 2) * (window_size[1] * 2 - 1),
                self.type_of_windows,
                num_heads,
            )
        )  # Wlat**2 * Wlon*2-1, Nlat//Wlat, nH

        earth_position_index = get_earth_position_index(
            window_size, ndim=2
        )  # Wlat*Wlon, Wlat*Wlon
        self.register_buffer("earth_position_index", earth_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.earth_position_bias_table = trunc_normal_(
            self.earth_position_bias_table, std=0.02
        )

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor, mask=None):
        """
        Args:
            x: input features with shape of (B * num_lon, num_lat, N, C)
            mask: (0/-inf) mask with shape of (num_lon, num_lat, Wlat*Wlon, Wlat*Wlon)
        """
        B_, nW_, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B_, nW_, N, 3, self.num_heads, C // self.num_heads)
            .permute(3, 0, 4, 1, 2, 5)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        earth_position_bias = self.earth_position_bias_table[
            self.earth_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1],
            self.type_of_windows,
            -1,
        )  # Wlat*Wlon, Wlat*Wlon, num_lat, nH
        earth_position_bias = earth_position_bias.permute(
            3, 2, 0, 1
        ).contiguous()  # nH, num_lat, Wlat*Wlon, Wlat*Wlon
        attn = attn + earth_position_bias.unsqueeze(0)

        if mask is not None:
            nLon = mask.shape[0]
            attn = attn.view(
                B_ // nLon, nLon, self.num_heads, nW_, N, N
            ) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, nW_, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).permute(0, 2, 3, 1, 4).reshape(B_, nW_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
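

# Illustrative usage sketch only (not part of the original module): a minimal
# EarthAttention2D forward pass without a shift mask, mirroring the 3D example
# above. All sizes are arbitrary assumptions consistent with the expected input
# shape (B * num_lon, num_lat, Wlat * Wlon, C).
def _example_earth_attention_2d() -> torch.Tensor:
    attn = EarthAttention2D(
        dim=96,
        input_resolution=(4, 8),  # (latitude, longitude)
        window_size=(2, 4),       # (Wlat, Wlon)
        num_heads=4,
    )
    # B=1, 2 longitude windows, 2 latitude windows, 8 tokens per window
    x = torch.randn(2, 2, 8, 96)
    return attn(x)  # output has the same shape as x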