# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Dict
import numpy as np
import torch
from torch import nn
from physicsnemo.nn.module.conv_layers import Conv2d
from physicsnemo.nn.module.group_norm import get_group_norm
from physicsnemo.nn.module.utils import get_earth_position_index, trunc_normal_
class AttentionOp(torch.autograd.Function):
    """
    Custom autograd op computing attention weights softmax(Q^T * K / sqrt(C)).

    All arithmetic is performed in FP32 for numerical stability, while the
    inputs, outputs, and gradients keep their original dtype to conserve
    memory.
    """

    @staticmethod
    def forward(ctx, q, k):
        # Scale keys by 1/sqrt(C) before the FP32 matmul, then softmax over
        # the key dimension; cast the result back to the input dtype.
        scale = torch.sqrt(torch.tensor(k.shape[1]))
        k_scaled = (k / scale).to(torch.float32)
        logits = torch.einsum("ncq,nck->nqk", q.to(torch.float32), k_scaled)
        w = logits.softmax(dim=2).to(q.dtype)
        ctx.save_for_backward(q, k, w)
        return w

    @staticmethod
    def backward(ctx, dw):
        q, k, w = ctx.saved_tensors
        # Backprop through the softmax in FP32 using the saved output.
        db = torch._softmax_backward_data(
            grad_output=dw.to(torch.float32),
            output=w.to(torch.float32),
            dim=2,
            input_dtype=torch.float32,
        )
        inv_scale = np.sqrt(k.shape[1])
        # Gradients of the scaled matmul, cast back to the input dtypes.
        dq = (
            torch.einsum("nck,nqk->ncq", k.to(torch.float32), db).to(q.dtype)
            / inv_scale
        )
        dk = (
            torch.einsum("ncq,nqk->nck", q.to(torch.float32), db).to(k.dtype)
            / inv_scale
        )
        return dq, dk
class UNetAttention(torch.nn.Module):
    """
    Self-attention block used in U-Net-style architectures, such as DDPM++,
    NCSN++, and ADM.

    The input is normalized with GroupNorm, projected to queries, keys, and
    values with a 1x1 convolution, run through multi-head scaled-dot-product
    attention over the flattened spatial grid, projected back with a
    zero-initialized 1x1 convolution, and added residually onto the input.

    Parameters
    ----------
    out_channels : int
        Number of channels :math:`C` in the input and output feature maps.
    num_heads : int
        Number of attention heads. Must be a positive integer that divides
        `out_channels`.
    eps : float, optional, default=1e-5
        Epsilon value for numerical stability in GroupNorm.
    init_zero : dict, optional, default={'init_weight': 0}
        Initialization parameters with zero weights for the output projection.
    init_attn : dict, optional, default=None
        Initialization parameters specific to the qkv projection. Falls back
        to `init` if not provided.
    init : dict, optional, default={}
        Initialization parameters for convolutional and linear layers.
    use_apex_gn : bool, optional, default=False
        Whether to use Apex GroupNorm for NHWC layout. Must be False on CPU.
    amp_mode : bool, optional, default=False
        Whether mixed-precision (AMP) training is enabled.
    fused_conv_bias : bool, optional, default=False
        Whether bias will be passed as a parameter of conv2d.

    Forward
    -------
    x : torch.Tensor
        Input tensor of shape :math:`(B, C, H, W)`, where :math:`B` is batch
        size, :math:`C` is `out_channels`, and :math:`H, W` are spatial
        dimensions.

    Outputs
    -------
    torch.Tensor
        Output tensor of the same shape as input: :math:`(B, C, H, W)`.
    """

    def __init__(
        self,
        *,
        out_channels: int,
        num_heads: int,
        eps: float = 1e-5,
        init_zero: Dict[str, Any] = dict(init_weight=0),
        init_attn: Any = None,
        init: Dict[str, Any] = dict(),
        use_apex_gn: bool = False,
        amp_mode: bool = False,
        fused_conv_bias: bool = False,
    ) -> None:
        super().__init__()
        # Validate the head configuration before building any submodules.
        if not isinstance(num_heads, int) or num_heads <= 0:
            raise ValueError(
                f"`num_heads` must be a positive integer, but got {num_heads}"
            )
        if out_channels % num_heads != 0:
            raise ValueError(
                f"`out_channels` must be divisible by `num_heads`, but got {out_channels} and {num_heads}"
            )
        self.num_heads = num_heads
        self.norm = get_group_norm(
            num_channels=out_channels,
            eps=eps,
            use_apex_gn=use_apex_gn,
            amp_mode=amp_mode,
        )
        # The qkv projection uses the attention-specific init when given.
        qkv_init = init if init_attn is None else init_attn
        self.qkv = Conv2d(
            in_channels=out_channels,
            out_channels=out_channels * 3,
            kernel=1,
            fused_conv_bias=fused_conv_bias,
            amp_mode=amp_mode,
            **qkv_init,
        )
        self.proj = Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel=1,
            fused_conv_bias=fused_conv_bias,
            amp_mode=amp_mode,
            **init_zero,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch, channels = x.shape[0], x.shape[1]
        head_dim = channels // self.num_heads
        qkv_maps: torch.Tensor = self.qkv(self.norm(x))
        # (B, 3C, H, W) -> (B, heads, head_dim, 3, H*W)
        #               -> (B, heads, H*W, 3, head_dim)
        stacked = qkv_maps.reshape(
            batch, self.num_heads, head_dim, 3, -1
        ).permute(0, 1, 4, 3, 2)
        q, k, v = stacked.unbind(dim=3)
        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, scale=1 / math.sqrt(head_dim)
        )
        # Back to channel-first layout, restore the spatial grid, and add the
        # residual in place on the projection output.
        out = out.transpose(-1, -2)
        return self.proj(out.reshape(*x.shape)).add_(x)
class EarthAttention3D(nn.Module):
    """
    Revise from WeatherLearn https://github.com/lizhuoq/WeatherLearn

    3D window attention with earth position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): [pressure levels, latitude, longitude]
        window_size (tuple[int]): [pressure levels, latitude, longitude]
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim: int,
        input_resolution,
        window_size,
        num_heads: int,
        qkv_bias: bool = True,
        qk_scale=None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wpl, Wlat, Wlon
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Attention scale: user override, else the standard 1/sqrt(head_dim).
        self.scale = qk_scale or head_dim**-0.5
        # Number of distinct window positions along (pressure, latitude); each
        # position gets its own slice of the bias table below.
        self.type_of_windows = (input_resolution[0] // window_size[0]) * (
            input_resolution[1] // window_size[1]
        )
        # Learnable position-bias table. The pl/lat axes use absolute pair
        # indexing (Wpl**2 and Wlat**2 entries) while the lon axis uses
        # 2*Wlon - 1 relative offsets.
        self.earth_position_bias_table = nn.Parameter(
            torch.zeros(
                (window_size[0] ** 2)
                * (window_size[1] ** 2)
                * (window_size[2] * 2 - 1),
                self.type_of_windows,
                num_heads,
            )
        )  # Wpl**2 * Wlat**2 * Wlon*2-1, Npl//Wpl * Nlat//Wlat, nH
        # Precomputed lookup mapping every (query, key) token pair within a
        # window to a row of the bias table.
        earth_position_index = get_earth_position_index(
            window_size
        )  # Wpl*Wlat*Wlon, Wpl*Wlat*Wlon
        self.register_buffer("earth_position_index", earth_position_index)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # NOTE(review): assumes trunc_normal_ fills in place and returns its
        # input, so this reassignment keeps the same nn.Parameter — confirm.
        self.earth_position_bias_table = trunc_normal_(
            self.earth_position_bias_table, std=0.02
        )
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
        """
        Args:
            x: input features with shape of (B * num_lon, num_pl*num_lat, N, C)
            mask: (0/-inf) mask with shape of (num_lon, num_pl*num_lat, Wpl*Wlat*Wlon, Wpl*Wlat*Wlon)
        """
        B_, nW_, N, C = x.shape
        # Joint q/k/v projection, then split heads:
        # (B_, nW_, N, 3C) -> (3, B_, nH, nW_, N, C // nH)
        qkv = (
            self.qkv(x)
            .reshape(B_, nW_, N, 3, self.num_heads, C // self.num_heads)
            .permute(3, 0, 4, 1, 2, 5)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        # Gather the bias for every (query, key) token pair in the window.
        earth_position_bias = self.earth_position_bias_table[
            self.earth_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1] * self.window_size[2],
            self.window_size[0] * self.window_size[1] * self.window_size[2],
            self.type_of_windows,
            -1,
        )  # Wpl*Wlat*Wlon, Wpl*Wlat*Wlon, num_pl*num_lat, nH
        earth_position_bias = earth_position_bias.permute(
            3, 2, 0, 1
        ).contiguous()  # nH, num_pl*num_lat, Wpl*Wlat*Wlon, Wpl*Wlat*Wlon
        attn = attn + earth_position_bias.unsqueeze(0)
        if mask is not None:
            # Expose the longitude dimension so the shift mask broadcasts over
            # batch and heads, then fold it back before the softmax.
            nLon = mask.shape[0]
            attn = attn.view(
                B_ // nLon, nLon, self.num_heads, nW_, N, N
            ) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, nW_, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        # Weighted sum of values; merge heads back into the channel dim.
        x = (attn @ v).permute(0, 2, 3, 1, 4).reshape(B_, nW_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class EarthAttention2D(nn.Module):
    """
    Revise from WeatherLearn https://github.com/lizhuoq/WeatherLearn

    2D window attention with earth position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): [latitude, longitude]
        window_size (tuple[int]): [latitude, longitude]
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim: int,
        input_resolution,
        window_size,
        num_heads: int,
        qkv_bias: bool = True,
        qk_scale=None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wlat, Wlon
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # Attention scale: user override, else the standard 1/sqrt(head_dim).
        self.scale = qk_scale or head_dim**-0.5
        # Number of distinct window positions along latitude; each position
        # gets its own slice of the bias table below.
        self.type_of_windows = input_resolution[0] // window_size[0]
        # Learnable position-bias table. The lat axis uses absolute pair
        # indexing (Wlat**2 entries) while the lon axis uses 2*Wlon - 1
        # relative offsets.
        self.earth_position_bias_table = nn.Parameter(
            torch.zeros(
                (window_size[0] ** 2) * (window_size[1] * 2 - 1),
                self.type_of_windows,
                num_heads,
            )
        )  # Wlat**2 * Wlon*2-1, Nlat//Wlat, nH
        # Precomputed lookup mapping every (query, key) token pair within a
        # window to a row of the bias table.
        earth_position_index = get_earth_position_index(
            window_size, ndim=2
        )  # Wlat*Wlon, Wlat*Wlon
        self.register_buffer("earth_position_index", earth_position_index)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # NOTE(review): assumes trunc_normal_ fills in place and returns its
        # input, so this reassignment keeps the same nn.Parameter — confirm.
        self.earth_position_bias_table = trunc_normal_(
            self.earth_position_bias_table, std=0.02
        )
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
        """
        Args:
            x: input features with shape of (B * num_lon, num_lat, N, C)
            mask: (0/-inf) mask with shape of (num_lon, num_lat, Wlat*Wlon, Wlat*Wlon)
        """
        B_, nW_, N, C = x.shape
        # Joint q/k/v projection, then split heads:
        # (B_, nW_, N, 3C) -> (3, B_, nH, nW_, N, C // nH)
        qkv = (
            self.qkv(x)
            .reshape(B_, nW_, N, 3, self.num_heads, C // self.num_heads)
            .permute(3, 0, 4, 1, 2, 5)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        # Gather the bias for every (query, key) token pair in the window.
        earth_position_bias = self.earth_position_bias_table[
            self.earth_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1],
            self.type_of_windows,
            -1,
        )  # Wlat*Wlon, Wlat*Wlon, num_lat, nH
        earth_position_bias = earth_position_bias.permute(
            3, 2, 0, 1
        ).contiguous()  # nH, num_lat, Wlat*Wlon, Wlat*Wlon
        attn = attn + earth_position_bias.unsqueeze(0)
        if mask is not None:
            # Expose the longitude dimension so the shift mask broadcasts over
            # batch and heads, then fold it back before the softmax.
            nLon = mask.shape[0]
            attn = attn.view(
                B_ // nLon, nLon, self.num_heads, nW_, N, N
            ) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, nW_, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        # Weighted sum of values; merge heads back into the channel dim.
        x = (attn @ v).permute(0, 2, 3, 1, 4).reshape(B_, nW_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x