# Source code for physicsnemo.nn.module.transformer_layers
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
import torch
from timm.layers import to_2tuple
from timm.models.swin_transformer import SwinTransformerStage
from torch import nn
from physicsnemo.nn.module.utils import (
PatchEmbed2D,
PatchRecovery2D,
crop2d,
crop3d,
get_pad2d,
get_pad3d,
get_shift_window_mask,
window_partition,
window_reverse,
)
from .attention_layers import EarthAttention2D, EarthAttention3D
from .drop import DropPath
from .mlp_layers import Mlp
from .resample_layers import DownSample2D, UpSample2D
class Transformer3DBlock(nn.Module):
    """
    Revise from WeatherLearn https://github.com/lizhuoq/WeatherLearn

    3D Transformer Block: windowed self-attention over (pressure level,
    latitude, longitude) tokens with optional shifted windows (Swin-style),
    followed by an MLP. Both sub-layers are pre-norm with residual
    connections and stochastic depth.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution [pressure levels, latitude, longitude].
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): Window size [pressure levels, latitude, longitude]. Default: (2, 6, 12)
        shift_size (tuple[int]): Shift size for SW-MSA [pressure levels, latitude, longitude]. Default: (1, 3, 6)
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size=None,
        shift_size=None,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        window_size = (2, 6, 12) if window_size is None else window_size
        shift_size = (1, 3, 6) if shift_size is None else shift_size
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.norm1 = norm_layer(dim)
        # Pad (Pl, Lat, Lon) so every dimension is divisible by its window size.
        # nn.ZeroPad3d padding order is (left, right, top, bottom, front, back),
        # i.e. (lon_l, lon_r, lat_t, lat_b, pl_f, pl_b).
        padding = get_pad3d(input_resolution, window_size)
        self.pad = nn.ZeroPad3d(padding)
        pad_resolution = list(input_resolution)
        pad_resolution[0] += padding[-1] + padding[-2]  # pressure levels: front + back
        pad_resolution[1] += padding[2] + padding[3]  # latitude: top + bottom
        pad_resolution[2] += padding[0] + padding[1]  # longitude: left + right
        # Attention operates on the padded grid, hence pad_resolution here.
        self.attn = EarthAttention3D(
            dim=dim,
            input_resolution=pad_resolution,
            window_size=window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        shift_pl, shift_lat, shift_lon = self.shift_size
        # Truthy only when all three shift components are non-zero: this block
        # then acts as a shifted-window (SW-MSA) block and needs an attn mask.
        self.roll = shift_pl and shift_lon and shift_lat
        if self.roll:
            attn_mask = get_shift_window_mask(pad_resolution, window_size, shift_size)
        else:
            attn_mask = None
        # Buffer (not a parameter): moves with .to()/.cuda() but is not trained.
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x: torch.Tensor):
        """Apply (shifted-)window attention + MLP.

        Args:
            x: Tokens of shape (B, Pl*Lat*Lon, C).

        Returns:
            Tensor of the same shape (B, Pl*Lat*Lon, C).
        """
        Pl, Lat, Lon = self.input_resolution
        B, L, C = x.shape
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, Pl, Lat, Lon, C)
        # start pad: ZeroPad3d expects channels-first, so permute around it.
        x = self.pad(x.permute(0, 4, 1, 2, 3)).permute(0, 2, 3, 4, 1)
        _, Pl_pad, Lat_pad, Lon_pad, _ = x.shape
        shift_pl, shift_lat, shift_lon = self.shift_size
        if self.roll:
            # Cyclically shift so cross-window interactions are covered; the
            # attn_mask computed in __init__ hides wrapped-around positions.
            shifted_x = torch.roll(
                x, shifts=(-shift_pl, -shift_lat, -shift_lon), dims=(1, 2, 3)
            )
            x_windows = window_partition(shifted_x, self.window_size)
            # B*num_lon, num_pl*num_lat, win_pl, win_lat, win_lon, C
        else:
            shifted_x = x
            x_windows = window_partition(shifted_x, self.window_size)
            # B*num_lon, num_pl*num_lat, win_pl, win_lat, win_lon, C
        win_pl, win_lat, win_lon = self.window_size
        # Flatten each window to a token sequence for attention.
        x_windows = x_windows.view(
            x_windows.shape[0], x_windows.shape[1], win_pl * win_lat * win_lon, C
        )
        # B*num_lon, num_pl*num_lat, win_pl*win_lat*win_lon, C
        attn_windows = self.attn(
            x_windows, mask=self.attn_mask
        )  # B*num_lon, num_pl*num_lat, win_pl*win_lat*win_lon, C
        attn_windows = attn_windows.view(
            attn_windows.shape[0], attn_windows.shape[1], win_pl, win_lat, win_lon, C
        )
        if self.roll:
            shifted_x = window_reverse(
                attn_windows, self.window_size, Pl=Pl_pad, Lat=Lat_pad, Lon=Lon_pad
            )
            # B * Pl * Lat * Lon * C
            # Undo the cyclic shift applied before partitioning.
            x = torch.roll(
                shifted_x, shifts=(shift_pl, shift_lat, shift_lon), dims=(1, 2, 3)
            )
        else:
            shifted_x = window_reverse(
                attn_windows, self.window_size, Pl=Pl_pad, Lat=Lat_pad, Lon=Lon_pad
            )
            x = shifted_x
        # crop, end pad: remove the zero padding added at the start.
        x = crop3d(x.permute(0, 4, 1, 2, 3), self.input_resolution).permute(
            0, 2, 3, 4, 1
        )
        x = x.reshape(B, Pl * Lat * Lon, C)
        # Residual + stochastic depth around attention, then around the MLP.
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
class Transformer2DBlock(nn.Module):
    """
    Revise from WeatherLearn https://github.com/lizhuoq/WeatherLearn

    2D Transformer Block: windowed self-attention over (latitude, longitude)
    tokens with optional shifted windows (Swin-style), followed by an MLP.
    Both sub-layers are pre-norm with residual connections and stochastic
    depth.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution [latitude, longitude].
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): Window size [latitude, longitude]. Default: (6, 12)
        shift_size (tuple[int]): Shift size for SW-MSA [latitude, longitude]. Default: (3, 6)
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size=None,
        shift_size=None,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        window_size = (6, 12) if window_size is None else window_size
        shift_size = (3, 6) if shift_size is None else shift_size
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.norm1 = norm_layer(dim)
        # Pad (Lat, Lon) so both dimensions are divisible by the window size.
        # nn.ZeroPad2d padding order is (left, right, top, bottom),
        # i.e. (lon_l, lon_r, lat_t, lat_b).
        padding = get_pad2d(input_resolution, window_size)
        self.pad = nn.ZeroPad2d(padding)
        pad_resolution = list(input_resolution)
        pad_resolution[0] += padding[2] + padding[3]  # latitude: top + bottom
        pad_resolution[1] += padding[0] + padding[1]  # longitude: left + right
        # Attention operates on the padded grid, hence pad_resolution here.
        self.attn = EarthAttention2D(
            dim=dim,
            input_resolution=pad_resolution,
            window_size=window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        shift_lat, shift_lon = self.shift_size
        # Truthy only when both shift components are non-zero: this block then
        # acts as a shifted-window (SW-MSA) block and needs an attention mask.
        self.roll = shift_lon and shift_lat
        if self.roll:
            attn_mask = get_shift_window_mask(
                pad_resolution, window_size, shift_size, ndim=2
            )
        else:
            attn_mask = None
        # Buffer (not a parameter): moves with .to()/.cuda() but is not trained.
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x: torch.Tensor):
        """Apply (shifted-)window attention + MLP.

        Args:
            x: Tokens of shape (B, Lat*Lon, C).

        Returns:
            Tensor of the same shape (B, Lat*Lon, C).
        """
        Lat, Lon = self.input_resolution
        B, L, C = x.shape
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, Lat, Lon, C)
        # start pad: ZeroPad2d expects channels-first, so permute around it.
        x = self.pad(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        _, Lat_pad, Lon_pad, _ = x.shape
        shift_lat, shift_lon = self.shift_size
        if self.roll:
            # Cyclically shift so cross-window interactions are covered; the
            # attn_mask computed in __init__ hides wrapped-around positions.
            shifted_x = torch.roll(x, shifts=(-shift_lat, -shift_lon), dims=(1, 2))
            x_windows = window_partition(shifted_x, self.window_size, ndim=2)
            # B*num_lon, num_lat, win_lat, win_lon, C
        else:
            shifted_x = x
            x_windows = window_partition(shifted_x, self.window_size, ndim=2)
            # B*num_lon, num_lat, win_lat, win_lon, C
        win_lat, win_lon = self.window_size
        # Flatten each window to a token sequence for attention.
        x_windows = x_windows.view(
            x_windows.shape[0], x_windows.shape[1], win_lat * win_lon, C
        )
        # B*num_lon, num_lat, win_lat*win_lon, C
        attn_windows = self.attn(
            x_windows, mask=self.attn_mask
        )  # B*num_lon, num_lat, win_lat*win_lon, C
        attn_windows = attn_windows.view(
            attn_windows.shape[0], attn_windows.shape[1], win_lat, win_lon, C
        )
        if self.roll:
            shifted_x = window_reverse(
                attn_windows, self.window_size, Lat=Lat_pad, Lon=Lon_pad, ndim=2
            )
            # B * Lat * Lon * C
            # Undo the cyclic shift applied before partitioning.
            x = torch.roll(shifted_x, shifts=(shift_lat, shift_lon), dims=(1, 2))
        else:
            shifted_x = window_reverse(
                attn_windows, self.window_size, Lat=Lat_pad, Lon=Lon_pad, ndim=2
            )
            x = shifted_x
        # crop, end pad: remove the zero padding added at the start.
        x = crop2d(x.permute(0, 3, 1, 2), self.input_resolution).permute(0, 2, 3, 1)
        x = x.reshape(B, Lat * Lon, C)
        # Residual + stochastic depth around attention, then around the MLP.
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
class FuserLayer(nn.Module):
    """Revise from WeatherLearn https://github.com/lizhuoq/WeatherLearn

    One stage of stacked 3D transformer blocks. Even-indexed blocks use
    regular windows (zero shift); odd-indexed blocks use the default
    shifted windows, alternating W-MSA / SW-MSA as in Swin.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(
        self,
        dim,
        input_resolution,
        depth,
        num_heads,
        window_size,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth

        blocks = []
        for idx in range(depth):
            # drop_path may be a per-block schedule or one shared scalar rate.
            if isinstance(drop_path, Sequence):
                rate = drop_path[idx]
            else:
                rate = drop_path
            blocks.append(
                Transformer3DBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    # Alternate unshifted / shifted windows (None -> default shift).
                    shift_size=(0, 0, 0) if idx % 2 == 0 else None,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=rate,
                    norm_layer=norm_layer,
                )
            )
        self.blocks = nn.ModuleList(blocks)
class EncoderLayer(nn.Module):
    """A 2D Transformer Encoder Module for one stage.

    Patch-embeds the input image, runs `depth` transformer blocks at the
    full resolution (keeping the result as a skip connection), downsamples,
    and runs `depth_middle` more blocks at the reduced resolution.

    Args:
        img_size (tuple[int]): image size(Lat, Lon).
        patch_size (tuple[int]): Patch token size of Patch Embedding.
        in_chans (int): number of input channels of Patch Embedding.
        dim (int): Number of input channels of transformer.
        input_resolution (tuple[int]): Input resolution for transformer before downsampling.
        middle_resolution (tuple[int]): Input resolution for transformer after downsampling.
        depth (int): Number of blocks for transformer before downsampling.
        depth_middle (int): Number of blocks for transformer after downsampling.
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(
        self,
        img_size,
        patch_size,
        in_chans,
        dim,
        input_resolution,
        middle_resolution,
        depth,
        depth_middle,
        num_heads,
        window_size,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.in_chans = in_chans
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.depth_middle = depth_middle

        # Split a per-block drop-path schedule between the two sub-stages;
        # a scalar rate is shared by both.
        if isinstance(drop_path, Sequence):
            drop_path_middle = drop_path[depth:]
            drop_path = drop_path[:depth]
        else:
            drop_path_middle = drop_path

        # Heads may be given per sub-stage as (before, after) or as one value.
        if isinstance(num_heads, Sequence):
            num_heads_middle = num_heads[1]
            num_heads = num_heads[0]
        else:
            num_heads_middle = num_heads

        self.patchembed2d = PatchEmbed2D(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=dim,
        )

        def make_block(idx, block_dim, resolution, heads, rates):
            # One transformer block; even idx -> unshifted windows,
            # odd idx -> default shifted windows.
            return Transformer2DBlock(
                dim=block_dim,
                input_resolution=resolution,
                num_heads=heads,
                window_size=window_size,
                shift_size=(0, 0) if idx % 2 == 0 else None,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=rates[idx] if isinstance(rates, Sequence) else rates,
                norm_layer=norm_layer,
            )

        self.blocks = nn.ModuleList(
            make_block(i, dim, input_resolution, num_heads, drop_path)
            for i in range(depth)
        )
        self.downsample = DownSample2D(
            in_dim=dim,
            input_resolution=input_resolution,
            output_resolution=middle_resolution,
        )
        self.blocks_middle = nn.ModuleList(
            make_block(i, dim * 2, middle_resolution, num_heads_middle, drop_path_middle)
            for i in range(depth_middle)
        )

    def forward(self, x):
        """Encode an image into middle-resolution tokens plus a skip tensor.

        Returns:
            (x, skip): downsampled token tensor and the pre-downsample
            feature map reshaped to (B, Lat, Lon, C).
        """
        x = self.patchembed2d(x)
        B, C, Lat, Lon = x.shape
        # (B, C, Lat, Lon) -> (B, Lat*Lon, C) token layout.
        x = x.reshape(B, C, -1).transpose(1, 2)
        for block in self.blocks:
            x = block(x)
        skip = x.reshape(B, Lat, Lon, C)
        x = self.downsample(x)
        for block in self.blocks_middle:
            x = block(x)
        return x, skip
class DecoderLayer(nn.Module):
    """A 2D Transformer Decoder Module for one stage.

    Runs `depth_middle` transformer blocks at the middle resolution,
    upsamples, runs `depth` more blocks, concatenates the encoder skip
    connection on the channel axis, and recovers patches to image space.

    Args:
        img_size (tuple[int]): image size(Lat, Lon).
        patch_size (tuple[int]): Patch token size of Patch Recovery.
        out_chans (int): number of output channels of Patch Recovery.
        dim (int): Number of input channels of transformer.
        output_resolution (tuple[int]): Input resolution for transformer after upsampling.
        middle_resolution (tuple[int]): Input resolution for transformer before upsampling.
        depth (int): Number of blocks for transformer after upsampling.
        depth_middle (int): Number of blocks for transformer before upsampling.
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(
        self,
        img_size,
        patch_size,
        out_chans,
        dim,
        output_resolution,
        middle_resolution,
        depth,
        depth_middle,
        num_heads,
        window_size,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.out_chans = out_chans
        self.dim = dim
        self.output_resolution = output_resolution
        self.depth = depth
        self.depth_middle = depth_middle

        # Split a per-block drop-path schedule between the two sub-stages;
        # a scalar rate is shared by both.
        if isinstance(drop_path, Sequence):
            drop_path_middle = drop_path[depth:]
            drop_path = drop_path[:depth]
        else:
            drop_path_middle = drop_path

        # Heads may be given per sub-stage as (after, before) or as one value.
        if isinstance(num_heads, Sequence):
            num_heads_middle = num_heads[1]
            num_heads = num_heads[0]
        else:
            num_heads_middle = num_heads

        def make_block(idx, block_dim, resolution, heads, rates):
            # One transformer block; even idx -> unshifted windows,
            # odd idx -> default shifted windows.
            return Transformer2DBlock(
                dim=block_dim,
                input_resolution=resolution,
                num_heads=heads,
                window_size=window_size,
                shift_size=(0, 0) if idx % 2 == 0 else None,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=rates[idx] if isinstance(rates, Sequence) else rates,
                norm_layer=norm_layer,
            )

        self.blocks_middle = nn.ModuleList(
            make_block(i, dim * 2, middle_resolution, num_heads_middle, drop_path_middle)
            for i in range(depth_middle)
        )
        self.upsample = UpSample2D(
            in_dim=dim * 2,
            out_dim=dim,
            input_resolution=middle_resolution,
            output_resolution=output_resolution,
        )
        self.blocks = nn.ModuleList(
            make_block(i, dim, output_resolution, num_heads, drop_path)
            for i in range(depth)
        )
        # 2 * dim: decoder features are concatenated with the encoder skip.
        self.patchrecovery2d = PatchRecovery2D(img_size, patch_size, 2 * dim, out_chans)

    def forward(self, x, skip):
        """Decode middle-resolution tokens, fusing the encoder skip tensor.

        Args:
            x: Tokens at middle resolution.
            skip: Encoder features of shape (B, Lat, Lon, C).

        Returns:
            Recovered image-space tensor from PatchRecovery2D.
        """
        B, Lat, Lon, C = skip.shape
        for block in self.blocks_middle:
            x = block(x)
        x = self.upsample(x)
        for block in self.blocks:
            x = block(x)
        # Channel-concatenate with the skip tokens, then back to (B, C', Lat, Lon).
        merged = torch.concat([x, skip.reshape(B, -1, C)], dim=-1)
        merged = merged.transpose(1, 2).reshape(B, -1, Lat, Lon)
        return self.patchrecovery2d(merged)
class SwinTransformer(nn.Module):
    """Swin Transformer stage with window-size padding.

    Zero-pads the (Lat, Lon) grid so both dimensions are divisible by the
    window size, runs one timm ``SwinTransformerStage`` (no downsampling),
    and crops the output back to the original resolution.

    Args:
        embed_dim (int): Patch embedding dimension.
        input_resolution (tuple[int]): Lat, Lon.
        num_heads (int): Number of attention heads in different layers.
        window_size (int | tuple[int]): Window size.
        depth (int): Number of blocks.
    """

    def __init__(self, embed_dim, input_resolution, num_heads, window_size, depth):
        super().__init__()
        window_size = to_2tuple(window_size)
        # window_size is already a 2-tuple here (the original code wrapped it
        # in to_2tuple a second time, which was redundant).
        padding = get_pad2d(input_resolution, window_size)
        # nn.ZeroPad2d padding order: (left, right, top, bottom).
        padding_left, padding_right, padding_top, padding_bottom = padding
        self.padding = padding
        self.pad = nn.ZeroPad2d(padding)
        # The stage operates on the padded resolution.
        input_resolution = list(input_resolution)
        input_resolution[0] = input_resolution[0] + padding_top + padding_bottom
        input_resolution[1] = input_resolution[1] + padding_left + padding_right
        self.layer = SwinTransformerStage(
            dim=embed_dim,
            out_dim=embed_dim,
            input_resolution=input_resolution,
            depth=depth,
            downsample=None,  # keep resolution unchanged through the stage
            num_heads=num_heads,
            window_size=window_size,
        )

    def forward(self, x):
        """Run the Swin stage on a (B, C, Lat, Lon) tensor.

        Returns:
            Tensor of the same shape (B, C, Lat, Lon).
        """
        padding_left, padding_right, padding_top, padding_bottom = self.padding
        # pad to a window-size multiple
        x = self.pad(x)
        _, _, pad_lat, pad_lon = x.shape
        # timm's stage expects channels-last (B, Lat, Lon, C).
        x = x.permute(0, 2, 3, 1)
        x = self.layer(x)
        x = x.permute(0, 3, 1, 2)
        # crop back to the original (unpadded) resolution
        x = x[
            :,
            :,
            padding_top : pad_lat - padding_bottom,
            padding_left : pad_lon - padding_right,
        ]
        return x