Source code for physicsnemo.nn.module.transformer_layers

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Sequence

import torch
from timm.layers import to_2tuple
from timm.models.swin_transformer import SwinTransformerStage
from torch import nn

from physicsnemo.nn.module.utils import (
    PatchEmbed2D,
    PatchRecovery2D,
    crop2d,
    crop3d,
    get_pad2d,
    get_pad3d,
    get_shift_window_mask,
    window_partition,
    window_reverse,
)

from .attention_layers import EarthAttention2D, EarthAttention3D
from .drop import DropPath
from .mlp_layers import Mlp
from .resample_layers import DownSample2D, UpSample2D


class Transformer3DBlock(nn.Module):
    """
    Revised from WeatherLearn https://github.com/lizhuoq/WeatherLearn

    3D Transformer Block

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): Window size [pressure levels, latitude, longitude].
        shift_size (tuple[int]): Shift size for SW-MSA [pressure levels, latitude, longitude].
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size=None,
        shift_size=None,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        window_size = (2, 6, 12) if window_size is None else window_size
        shift_size = (1, 3, 6) if shift_size is None else shift_size

        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio

        self.norm1 = norm_layer(dim)
        padding = get_pad3d(input_resolution, window_size)
        self.pad = nn.ZeroPad3d(padding)

        pad_resolution = list(input_resolution)
        pad_resolution[0] += padding[-1] + padding[-2]
        pad_resolution[1] += padding[2] + padding[3]
        pad_resolution[2] += padding[0] + padding[1]

        self.attn = EarthAttention3D(
            dim=dim,
            input_resolution=pad_resolution,
            window_size=window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        shift_pl, shift_lat, shift_lon = self.shift_size
        self.roll = shift_pl and shift_lon and shift_lat

        if self.roll:
            attn_mask = get_shift_window_mask(pad_resolution, window_size, shift_size)
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x: torch.Tensor):
        Pl, Lat, Lon = self.input_resolution
        B, L, C = x.shape

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, Pl, Lat, Lon, C)

        # start pad
        x = self.pad(x.permute(0, 4, 1, 2, 3)).permute(0, 2, 3, 4, 1)

        _, Pl_pad, Lat_pad, Lon_pad, _ = x.shape

        shift_pl, shift_lat, shift_lon = self.shift_size
        if self.roll:
            shifted_x = torch.roll(
                x, shifts=(-shift_pl, -shift_lat, -shift_lon), dims=(1, 2, 3)
            )
            x_windows = window_partition(shifted_x, self.window_size)
            # B*num_lon, num_pl*num_lat, win_pl, win_lat, win_lon, C
        else:
            shifted_x = x
            x_windows = window_partition(shifted_x, self.window_size)
            # B*num_lon, num_pl*num_lat, win_pl, win_lat, win_lon, C

        win_pl, win_lat, win_lon = self.window_size
        x_windows = x_windows.view(
            x_windows.shape[0], x_windows.shape[1], win_pl * win_lat * win_lon, C
        )  # B*num_lon, num_pl*num_lat, win_pl*win_lat*win_lon, C

        attn_windows = self.attn(
            x_windows, mask=self.attn_mask
        )  # B*num_lon, num_pl*num_lat, win_pl*win_lat*win_lon, C

        attn_windows = attn_windows.view(
            attn_windows.shape[0], attn_windows.shape[1], win_pl, win_lat, win_lon, C
        )

        if self.roll:
            shifted_x = window_reverse(
                attn_windows, self.window_size, Pl=Pl_pad, Lat=Lat_pad, Lon=Lon_pad
            )  # B * Pl * Lat * Lon * C
            x = torch.roll(
                shifted_x, shifts=(shift_pl, shift_lat, shift_lon), dims=(1, 2, 3)
            )
        else:
            shifted_x = window_reverse(
                attn_windows, self.window_size, Pl=Pl_pad, Lat=Lat_pad, Lon=Lon_pad
            )
            x = shifted_x

        # crop, end pad
        x = crop3d(x.permute(0, 4, 1, 2, 3), self.input_resolution).permute(
            0, 2, 3, 4, 1
        )

        x = x.reshape(B, Pl * Lat * Lon, C)
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
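

# Usage sketch for Transformer3DBlock (illustrative only; the token grid below
# is an assumption, not a shape prescribed by the library). The block expects
# tokens flattened to (B, Pl * Lat * Lon, C) and returns the same shape. The
# default window is (2, 6, 12) with shift (1, 3, 6), so each extent here is
# chosen to be divisible by its window dimension:
#
#   block = Transformer3DBlock(dim=96, input_resolution=(8, 48, 96), num_heads=4)
#   x = torch.randn(1, 8 * 48 * 96, 96)  # (B, Pl*Lat*Lon, C)
#   y = block(x)                         # (1, 8*48*96, 96)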


class Transformer2DBlock(nn.Module):
    """
    Revised from WeatherLearn https://github.com/lizhuoq/WeatherLearn

    2D Transformer Block

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): Window size [latitude, longitude].
        shift_size (tuple[int]): Shift size for SW-MSA [latitude, longitude].
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size=None,
        shift_size=None,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        window_size = (6, 12) if window_size is None else window_size
        shift_size = (3, 6) if shift_size is None else shift_size

        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio

        self.norm1 = norm_layer(dim)
        padding = get_pad2d(input_resolution, window_size)
        self.pad = nn.ZeroPad2d(padding)

        pad_resolution = list(input_resolution)
        pad_resolution[0] += padding[2] + padding[3]
        pad_resolution[1] += padding[0] + padding[1]

        self.attn = EarthAttention2D(
            dim=dim,
            input_resolution=pad_resolution,
            window_size=window_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        shift_lat, shift_lon = self.shift_size
        self.roll = shift_lon and shift_lat

        if self.roll:
            attn_mask = get_shift_window_mask(
                pad_resolution, window_size, shift_size, ndim=2
            )
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x: torch.Tensor):
        Lat, Lon = self.input_resolution
        B, L, C = x.shape

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, Lat, Lon, C)

        # start pad
        x = self.pad(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

        _, Lat_pad, Lon_pad, _ = x.shape

        shift_lat, shift_lon = self.shift_size
        if self.roll:
            shifted_x = torch.roll(x, shifts=(-shift_lat, -shift_lon), dims=(1, 2))
            x_windows = window_partition(shifted_x, self.window_size, ndim=2)
            # B*num_lon, num_lat, win_lat, win_lon, C
        else:
            shifted_x = x
            x_windows = window_partition(shifted_x, self.window_size, ndim=2)
            # B*num_lon, num_lat, win_lat, win_lon, C

        win_lat, win_lon = self.window_size
        x_windows = x_windows.view(
            x_windows.shape[0], x_windows.shape[1], win_lat * win_lon, C
        )  # B*num_lon, num_lat, win_lat*win_lon, C

        attn_windows = self.attn(
            x_windows, mask=self.attn_mask
        )  # B*num_lon, num_lat, win_lat*win_lon, C

        attn_windows = attn_windows.view(
            attn_windows.shape[0], attn_windows.shape[1], win_lat, win_lon, C
        )

        if self.roll:
            shifted_x = window_reverse(
                attn_windows, self.window_size, Lat=Lat_pad, Lon=Lon_pad, ndim=2
            )  # B * Lat * Lon * C
            x = torch.roll(shifted_x, shifts=(shift_lat, shift_lon), dims=(1, 2))
        else:
            shifted_x = window_reverse(
                attn_windows, self.window_size, Lat=Lat_pad, Lon=Lon_pad, ndim=2
            )
            x = shifted_x

        # crop, end pad
        x = crop2d(x.permute(0, 3, 1, 2), self.input_resolution).permute(0, 2, 3, 1)

        x = x.reshape(B, Lat * Lon, C)
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
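

# Usage sketch for Transformer2DBlock (illustrative only; shapes are
# assumptions). The 2D variant drops the pressure-level axis: tokens are
# flattened to (B, Lat * Lon, C), and the default window is (6, 12) with
# shift (3, 6):
#
#   block = Transformer2DBlock(dim=96, input_resolution=(48, 96), num_heads=4)
#   x = torch.randn(1, 48 * 96, 96)  # (B, Lat*Lon, C)
#   y = block(x)                     # (1, 48*96, 96)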


class FuserLayer(nn.Module):
    """Revised from WeatherLearn https://github.com/lizhuoq/WeatherLearn

    A basic 3D Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(
        self,
        dim,
        input_resolution,
        depth,
        num_heads,
        window_size,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth

        self.blocks = nn.ModuleList(
            [
                Transformer3DBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=(0, 0, 0) if i % 2 == 0 else None,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i]
                    if isinstance(drop_path, Sequence)
                    else drop_path,
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        return x
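

# Usage sketch for FuserLayer (illustrative only; shapes are assumptions).
# Even-indexed blocks use no shift and odd-indexed blocks fall back to the
# default shift of Transformer3DBlock, giving the usual W-MSA / SW-MSA
# alternation:
#
#   fuser = FuserLayer(
#       dim=192, input_resolution=(8, 48, 96), depth=2, num_heads=6,
#       window_size=(2, 6, 12),
#   )
#   x = torch.randn(1, 8 * 48 * 96, 192)
#   y = fuser(x)  # shape preserved: (1, 8*48*96, 192)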


class EncoderLayer(nn.Module):
    """A 2D Transformer Encoder Module for one stage.

    Args:
        img_size (tuple[int]): Image size (Lat, Lon).
        patch_size (tuple[int]): Patch token size of Patch Embedding.
        in_chans (int): Number of input channels of Patch Embedding.
        dim (int): Number of input channels of transformer.
        input_resolution (tuple[int]): Input resolution for transformer before downsampling.
        middle_resolution (tuple[int]): Input resolution for transformer after downsampling.
        depth (int): Number of blocks for transformer before downsampling.
        depth_middle (int): Number of blocks for transformer after downsampling.
        num_heads (int | tuple[int]): Number of attention heads; a tuple gives the
            heads for the blocks before and after downsampling separately.
        window_size (tuple[int]): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(
        self,
        img_size,
        patch_size,
        in_chans,
        dim,
        input_resolution,
        middle_resolution,
        depth,
        depth_middle,
        num_heads,
        window_size,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.in_chans = in_chans
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.depth_middle = depth_middle

        if isinstance(drop_path, Sequence):
            drop_path_middle = drop_path[depth:]
            drop_path = drop_path[:depth]
        else:
            drop_path_middle = drop_path

        if isinstance(num_heads, Sequence):
            num_heads_middle = num_heads[1]
            num_heads = num_heads[0]
        else:
            num_heads_middle = num_heads

        self.patchembed2d = PatchEmbed2D(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=dim,
        )
        self.blocks = nn.ModuleList(
            [
                Transformer2DBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=(0, 0) if i % 2 == 0 else None,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i]
                    if isinstance(drop_path, Sequence)
                    else drop_path,
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )
        self.downsample = DownSample2D(
            in_dim=dim,
            input_resolution=input_resolution,
            output_resolution=middle_resolution,
        )
        self.blocks_middle = nn.ModuleList(
            [
                Transformer2DBlock(
                    dim=dim * 2,
                    input_resolution=middle_resolution,
                    num_heads=num_heads_middle,
                    window_size=window_size,
                    shift_size=(0, 0) if i % 2 == 0 else None,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path_middle[i]
                    if isinstance(drop_path_middle, Sequence)
                    else drop_path_middle,
                    norm_layer=norm_layer,
                )
                for i in range(depth_middle)
            ]
        )

    def forward(self, x):
        x = self.patchembed2d(x)
        B, C, Lat, Lon = x.shape
        x = x.reshape(B, C, -1).transpose(1, 2)
        for blk in self.blocks:
            x = blk(x)
        skip = x.reshape(B, Lat, Lon, C)
        x = self.downsample(x)
        for blk in self.blocks_middle:
            x = blk(x)
        return x, skip
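

# Usage sketch for EncoderLayer (illustrative only; every shape below is an
# assumption chosen so that img_size / patch_size matches input_resolution
# and the downsampled grid matches middle_resolution). The layer returns the
# downsampled tokens plus a skip tensor for the matching DecoderLayer:
#
#   enc = EncoderLayer(
#       img_size=(192, 384), patch_size=(4, 4), in_chans=3, dim=96,
#       input_resolution=(48, 96), middle_resolution=(24, 48),
#       depth=2, depth_middle=2, num_heads=(4, 8), window_size=(6, 12),
#   )
#   x = torch.randn(1, 3, 192, 384)  # (B, in_chans, Lat, Lon)
#   tokens, skip = enc(x)            # tokens: (1, 24*48, 192), skip: (1, 48, 96, 96)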


class DecoderLayer(nn.Module):
    """A 2D Transformer Decoder Module for one stage.

    Args:
        img_size (tuple[int]): Image size (Lat, Lon).
        patch_size (tuple[int]): Patch token size of Patch Recovery.
        out_chans (int): Number of output channels of Patch Recovery.
        dim (int): Number of input channels of transformer.
        output_resolution (tuple[int]): Resolution for transformer blocks after upsampling.
        middle_resolution (tuple[int]): Input resolution for transformer before upsampling.
        depth (int): Number of blocks for transformer after upsampling.
        depth_middle (int): Number of blocks for transformer before upsampling.
        num_heads (int | tuple[int]): Number of attention heads; a tuple gives the
            heads for the blocks after and before upsampling separately.
        window_size (tuple[int]): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(
        self,
        img_size,
        patch_size,
        out_chans,
        dim,
        output_resolution,
        middle_resolution,
        depth,
        depth_middle,
        num_heads,
        window_size,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.out_chans = out_chans
        self.dim = dim
        self.output_resolution = output_resolution
        self.depth = depth
        self.depth_middle = depth_middle

        if isinstance(drop_path, Sequence):
            drop_path_middle = drop_path[depth:]
            drop_path = drop_path[:depth]
        else:
            drop_path_middle = drop_path

        if isinstance(num_heads, Sequence):
            num_heads_middle = num_heads[1]
            num_heads = num_heads[0]
        else:
            num_heads_middle = num_heads

        self.blocks_middle = nn.ModuleList(
            [
                Transformer2DBlock(
                    dim=dim * 2,
                    input_resolution=middle_resolution,
                    num_heads=num_heads_middle,
                    window_size=window_size,
                    shift_size=(0, 0) if i % 2 == 0 else None,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path_middle[i]
                    if isinstance(drop_path_middle, Sequence)
                    else drop_path_middle,
                    norm_layer=norm_layer,
                )
                for i in range(depth_middle)
            ]
        )
        self.upsample = UpSample2D(
            in_dim=dim * 2,
            out_dim=dim,
            input_resolution=middle_resolution,
            output_resolution=output_resolution,
        )
        self.blocks = nn.ModuleList(
            [
                Transformer2DBlock(
                    dim=dim,
                    input_resolution=output_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=(0, 0) if i % 2 == 0 else None,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i]
                    if isinstance(drop_path, Sequence)
                    else drop_path,
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )
        self.patchrecovery2d = PatchRecovery2D(img_size, patch_size, 2 * dim, out_chans)

    def forward(self, x, skip):
        B, Lat, Lon, C = skip.shape
        for blk in self.blocks_middle:
            x = blk(x)
        x = self.upsample(x)
        for blk in self.blocks:
            x = blk(x)
        output = torch.concat([x, skip.reshape(B, -1, C)], dim=-1)
        output = output.transpose(1, 2).reshape(B, -1, Lat, Lon)
        output = self.patchrecovery2d(output)
        return output
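

# Usage sketch for DecoderLayer (illustrative only; shapes are assumptions
# that mirror the EncoderLayer sketch above). The decoder consumes the
# downsampled tokens and the encoder skip, and PatchRecovery2D projects the
# concatenated 2*dim channels back to out_chans at full image size:
#
#   dec = DecoderLayer(
#       img_size=(192, 384), patch_size=(4, 4), out_chans=3, dim=96,
#       output_resolution=(48, 96), middle_resolution=(24, 48),
#       depth=2, depth_middle=2, num_heads=(4, 8), window_size=(6, 12),
#   )
#   out = dec(tokens, skip)  # (1, 3, 192, 384), reusing the encoder outputs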


class SwinTransformer(nn.Module):
    """Swin Transformer

    Args:
        embed_dim (int): Patch embedding dimension.
        input_resolution (tuple[int]): Lat, Lon.
        num_heads (int): Number of attention heads in different layers.
        window_size (int | tuple[int]): Window size.
        depth (int): Number of blocks.
    """

    def __init__(self, embed_dim, input_resolution, num_heads, window_size, depth):
        super().__init__()
        window_size = to_2tuple(window_size)
        padding = get_pad2d(input_resolution, window_size)
        padding_left, padding_right, padding_top, padding_bottom = padding
        self.padding = padding
        self.pad = nn.ZeroPad2d(padding)
        input_resolution = list(input_resolution)
        input_resolution[0] = input_resolution[0] + padding_top + padding_bottom
        input_resolution[1] = input_resolution[1] + padding_left + padding_right
        self.layer = SwinTransformerStage(
            dim=embed_dim,
            out_dim=embed_dim,
            input_resolution=input_resolution,
            depth=depth,
            downsample=None,
            num_heads=num_heads,
            window_size=window_size,
        )

    def forward(self, x):
        B, C, Lat, Lon = x.shape
        padding_left, padding_right, padding_top, padding_bottom = self.padding

        # pad
        x = self.pad(x)
        _, _, pad_lat, pad_lon = x.shape

        x = x.permute(0, 2, 3, 1)  # B Lat Lon C
        x = self.layer(x)
        x = x.permute(0, 3, 1, 2)

        # crop
        x = x[
            :,
            :,
            padding_top : pad_lat - padding_bottom,
            padding_left : pad_lon - padding_right,
        ]

        return x
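

# Usage sketch for SwinTransformer (illustrative only; shapes are
# assumptions). The wrapper pads the (Lat, Lon) grid up to a multiple of the
# window size, runs one timm SwinTransformerStage, and crops back, so input
# and output shapes match:
#
#   model = SwinTransformer(
#       embed_dim=96, input_resolution=(48, 96), num_heads=4,
#       window_size=7, depth=2,
#   )
#   x = torch.randn(1, 96, 48, 96)  # (B, embed_dim, Lat, Lon)
#   y = model(x)                    # (1, 96, 48, 96)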