# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
from modulus.utils.sfno.distributed import comm
from modulus.utils.sfno.distributed.mappings import (
reduce_from_parallel_region,
copy_to_parallel_region,
)
class Preprocessor2D(nn.Module):
"""
Preprocessing methods to flatten image history, add static features, and
convert the data format from NCHW to NHWC.
"""
def __init__(self, params): # pragma: no cover
super(Preprocessor2D, self).__init__()
self.n_history = params.n_history
self.transform_to_nhwc = params.enable_nhwc
self.history_normalization_mode = params.history_normalization_mode
if self.history_normalization_mode == "exponential":
self.history_normalization_decay = params.history_normalization_decay
# inverse ordering, since first element is oldest
history_normalization_weights = torch.exp(
(-self.history_normalization_decay)
* torch.arange(
start=self.n_history, end=-1, step=-1, dtype=torch.float32
)
)
history_normalization_weights = history_normalization_weights / torch.sum(
history_normalization_weights
)
history_normalization_weights = torch.reshape(
history_normalization_weights, (1, -1, 1, 1, 1)
)
elif self.history_normalization_mode == "mean":
            # uniform weighting: 1 / (n_history + 1) for every step
            history_normalization_weights = torch.tensor(
                1.0 / float(self.n_history + 1), dtype=torch.float32
            )
history_normalization_weights = torch.reshape(
history_normalization_weights, (1, -1, 1, 1, 1)
)
else:
history_normalization_weights = torch.ones(
self.n_history + 1, dtype=torch.float32
)
self.register_buffer(
"history_normalization_weights",
history_normalization_weights,
persistent=False,
)
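        # Weight sketch (illustrative, not from the original code): with
        # history_normalization_mode="exponential", n_history=2 and decay=0.5,
        # the arange above runs [2, 1, 0], so the raw weights are
        # exp(-1.0), exp(-0.5), exp(0.0) ~= [0.368, 0.607, 1.0], i.e. the
        # oldest step (stored first) gets the smallest weight. After
        # normalization they sum to one: ~[0.186, 0.307, 0.506].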
self.history_mean = None
self.history_std = None
self.history_diff_mean = None
self.history_diff_var = None
self.history_eps = 1e-6
self.img_shape = [params.img_shape_x, params.img_shape_y]
# unpredicted input channels:
self.unpredicted_inp_train = None
self.unpredicted_tar_train = None
self.unpredicted_inp_eval = None
self.unpredicted_tar_eval = None
# process static features
static_features = None
# needed for sharding
start_x = params.img_local_offset_x
end_x = min(start_x + params.img_local_shape_x, params.img_shape_x)
pad_x = params.img_local_shape_x - (end_x - start_x)
start_y = params.img_local_offset_y
end_y = min(start_y + params.img_local_shape_y, params.img_shape_y)
pad_y = params.img_local_shape_y - (end_y - start_y)
# set up grid
if params.add_grid:
with torch.no_grad():
tx = torch.linspace(0, 1, params.img_shape_x + 1, dtype=torch.float32)[
0:-1
]
ty = torch.linspace(0, 1, params.img_shape_y + 1, dtype=torch.float32)[
0:-1
]
x_grid, y_grid = torch.meshgrid(tx, ty, indexing="ij")
x_grid, y_grid = x_grid.unsqueeze(0).unsqueeze(0), y_grid.unsqueeze(
0
).unsqueeze(0)
grid = torch.cat([x_grid, y_grid], dim=1)
# shard spatially:
grid = grid[:, :, start_x:end_x, start_y:end_y]
# pad if needed
grid = F.pad(grid, [0, pad_y, 0, pad_x])
# transform if requested
if params.gridtype == "sinusoidal":
num_freq = 1
if hasattr(params, "grid_num_frequencies"):
num_freq = int(params.grid_num_frequencies)
singrid = None
for freq in range(1, num_freq + 1):
if singrid is None:
singrid = torch.sin(grid)
else:
singrid = torch.cat(
[singrid, torch.sin(freq * grid)], dim=1
)
static_features = singrid
else:
static_features = grid
if params.add_orography:
from utils.conditioning_inputs import get_orography
oro = torch.tensor(
get_orography(params.orography_path), dtype=torch.float32
)
oro = torch.reshape(oro, (1, 1, oro.shape[0], oro.shape[1]))
# shard
oro = oro[:, :, start_x:end_x, start_y:end_y]
# pad if needed
oro = F.pad(oro, [0, pad_y, 0, pad_x])
if static_features is None:
static_features = oro
else:
static_features = torch.cat([static_features, oro], dim=1)
if params.add_landmask:
from utils.conditioning_inputs import get_land_mask
lsm = torch.tensor(get_land_mask(params.landmask_path), dtype=torch.long)
# one hot encode and move channels to front:
lsm = torch.permute(torch.nn.functional.one_hot(lsm), (2, 0, 1)).to(
torch.float32
)
lsm = torch.reshape(lsm, (1, lsm.shape[0], lsm.shape[1], lsm.shape[2]))
# shard
lsm = lsm[:, :, start_x:end_x, start_y:end_y]
# pad if needed
lsm = F.pad(lsm, [0, pad_y, 0, pad_x])
if static_features is None:
static_features = lsm
else:
static_features = torch.cat([static_features, lsm], dim=1)
self.do_add_static_features = False
if static_features is not None:
self.do_add_static_features = True
self.register_buffer("static_features", static_features, persistent=False)
    def flatten_history(self, x): # pragma: no cover
"""Flatten input so that history is included as part of channels"""
if x.dim() == 5:
b_, t_, c_, h_, w_ = x.shape
x = torch.reshape(x, (b_, t_ * c_, h_, w_))
return x
    def expand_history(self, x, nhist): # pragma: no cover
"""Expand history from flattened data"""
if x.dim() == 4:
b_, ct_, h_, w_ = x.shape
x = torch.reshape(x, (b_, nhist, ct_ // nhist, h_, w_))
return x
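    # Shape sketch (illustrative, not part of the original module): with a
    # hypothetical Preprocessor2D instance `pre`, a batch of 4 samples,
    # n_history=2 (i.e. 3 time steps) and 8 channels on a 32 x 64 grid, the two
    # helpers above are inverses of each other:
    #
    #   x5 = torch.randn(4, 3, 8, 32, 64)       # (B, T, C, H, W)
    #   x4 = pre.flatten_history(x5)            # -> (4, 24, 32, 64)
    #   x5b = pre.expand_history(x4, nhist=3)   # -> (4, 3, 8, 32, 64)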
    def add_static_features(self, x): # pragma: no cover
"""Adds static features to the input"""
if self.do_add_static_features:
            # replicate the static features for each batch sample:
static = torch.tile(self.static_features, dims=(x.shape[0], 1, 1, 1))
x = torch.cat([x, static], dim=1)
return x
    def remove_static_features(self, x): # pragma: no cover
        """
        Removes the static features from the input,
        but only if something was added in the first place
        """
if self.do_add_static_features:
nfeat = self.static_features.shape[1]
x = x[:, : x.shape[1] - nfeat, :, :]
return x
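    # Static-feature sketch (illustrative): if the static_features buffer holds
    # S channels (grid, orography, land mask), add_static_features maps an
    # input of shape (B, C, H, W) to (B, C + S, H, W), and
    # remove_static_features maps it back to (B, C, H, W).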
    def append_history(self, x1, x2, step): # pragma: no cover
        """
        Appends the latest prediction to the history window.
        Without history, just returns the second tensor (x2).
        """
        # take care of unpredicted features first:
        # this is necessary in order to copy the target's unpredicted features
        # (such as zenith angle) into the input's unpredicted features,
        # so that they can be fed forward in the next autoregressive step.
        # update the unpredicted input
if self.training:
if (self.unpredicted_tar_train is not None) and (
step < self.unpredicted_tar_train.shape[1]
):
utar = self.unpredicted_tar_train[:, step : (step + 1), :, :, :]
if self.n_history == 0:
self.unpredicted_inp_train.copy_(utar)
else:
self.unpredicted_inp_train.copy_(
torch.cat(
[self.unpredicted_inp_train[:, 1:, :, :, :], utar], dim=1
)
)
else:
if (self.unpredicted_tar_eval is not None) and (
step < self.unpredicted_tar_eval.shape[1]
):
utar = self.unpredicted_tar_eval[:, step : (step + 1), :, :, :]
if self.n_history == 0:
self.unpredicted_inp_eval.copy_(utar)
else:
self.unpredicted_inp_eval.copy_(
torch.cat(
[self.unpredicted_inp_eval[:, 1:, :, :, :], utar], dim=1
)
)
# without history, just return the second tensor
if self.n_history > 0:
            # roll the history window: drop the oldest step and append the new one
x1 = self.expand_history(x1, nhist=self.n_history + 1)
x2 = self.expand_history(x2, nhist=1)
# append
res = torch.cat([x1[:, 1:, :, :, :], x2], dim=1)
# flatten again
res = self.flatten_history(res)
else:
res = x2
return res
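    # Rolling-window sketch (illustrative): with n_history=1 the flattened
    # input x1 holds the two steps [t-1, t] as 2*C channels. Appending the new
    # prediction x2 for step t+1 drops the oldest step and keeps [t, t+1]:
    #
    #   x1  = flattened history [t-1, t]            # (B, 2*C, H, W)
    #   x2  = prediction for step t+1               # (B, C, H, W)
    #   res = pre.append_history(x1, x2, step)      # flattened [t, t+1]
    #
    # The same rolling update is applied to the cached unpredicted features.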
    def append_channels(self, x, xc): # pragma: no cover
        """Appends extra channels (e.g. unpredicted features) to each history step"""
xdim = x.dim()
x = self.expand_history(x, self.n_history + 1)
xc = self.expand_history(xc, self.n_history + 1)
# concatenate
xo = torch.cat([x, xc], dim=2)
        # flatten again if the input was flattened
if xdim == 4:
xo = self.flatten_history(xo)
return xo
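    # Channel-append sketch (illustrative): for n_history=1, x of shape
    # (B, 2*C, H, W) and xc of shape (B, 2*Cz, H, W) are expanded to
    # (B, 2, C, H, W) and (B, 2, Cz, H, W), concatenated along dim=2, and
    # flattened back to (B, 2*(C + Cz), H, W).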
    def history_compute_stats(self, x): # pragma: no cover
"""Compute stats from history timesteps"""
if self.history_normalization_mode == "none":
self.history_mean = torch.zeros(
(1, 1, 1, 1), dtype=torch.float32, device=x.device
)
self.history_std = torch.ones(
(1, 1, 1, 1), dtype=torch.float32, device=x.device
)
elif self.history_normalization_mode == "timediff":
# reshaping
xdim = x.dim()
if xdim == 4:
b_, c_, h_, w_ = x.shape
xr = torch.reshape(
x, (b_, (self.n_history + 1), c_ // (self.n_history + 1), h_, w_)
)
else:
xshape = x.shape
xr = x
# time difference mean:
            self.history_diff_mean = torch.mean(
                torch.sum(xr[:, 1:, ...] - xr[:, 0:-1, ...], dim=(3, 4)), dim=(1, 2)
            )
# reduce across gpus
if comm.get_size("spatial") > 1:
self.history_diff_mean = reduce_from_parallel_region(
self.history_diff_mean, "spatial"
)
self.history_diff_mean = self.history_diff_mean / float(
self.img_shape[0] * self.img_shape[1]
)
# time difference std
            self.history_diff_var = torch.mean(
                torch.sum(
                    torch.square(
                        (xr[:, 1:, ...] - xr[:, 0:-1, ...])
                        - self.history_diff_mean.reshape(-1, 1, 1, 1, 1)
                    ),
                    dim=(3, 4),
                ),
                dim=(1, 2),
            )
# reduce across gpus
if comm.get_size("spatial") > 1:
self.history_diff_var = reduce_from_parallel_region(
self.history_diff_var, "spatial"
)
self.history_diff_var = self.history_diff_var / float(
self.img_shape[0] * self.img_shape[1]
)
            # copy the time-difference stats to all ranks in the spatial parallel region
self.history_diff_mean = copy_to_parallel_region(
self.history_diff_mean, "spatial"
)
self.history_diff_var = copy_to_parallel_region(
self.history_diff_var, "spatial"
)
else:
xdim = x.dim()
if xdim == 4:
b_, c_, h_, w_ = x.shape
xr = torch.reshape(
x, (b_, (self.n_history + 1), c_ // (self.n_history + 1), h_, w_)
)
else:
xshape = x.shape
xr = x
# mean
# compute weighted mean over dim 1, but sum over dim=3,4
self.history_mean = torch.sum(
xr * self.history_normalization_weights, dim=(1, 3, 4), keepdim=True
)
# reduce across gpus
if comm.get_size("spatial") > 1:
self.history_mean = reduce_from_parallel_region(
self.history_mean, "spatial"
)
self.history_mean = self.history_mean / float(
self.img_shape[0] * self.img_shape[1]
)
# compute std
self.history_std = torch.sum(
torch.square(xr - self.history_mean)
* self.history_normalization_weights,
dim=(1, 3, 4),
keepdim=True,
)
# reduce across gpus
if comm.get_size("spatial") > 1:
self.history_std = reduce_from_parallel_region(
self.history_std, "spatial"
)
self.history_std = torch.sqrt(
self.history_std / float(self.img_shape[0] * self.img_shape[1])
)
# squeeze
self.history_mean = torch.squeeze(self.history_mean, dim=1)
self.history_std = torch.squeeze(self.history_std, dim=1)
# copy to parallel region
self.history_mean = copy_to_parallel_region(self.history_mean, "spatial")
self.history_std = copy_to_parallel_region(self.history_std, "spatial")
return
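    # Statistics sketch (illustrative): in the weighted branch above, with
    # normalized weights w_t over the T = n_history + 1 steps and a global grid
    # of size H x W, the buffered stats are, per batch element and channel,
    #
    #   mean[b, c] = sum_{t, h, w} w_t * x[b, t, c, h, w] / (H * W)
    #   std[b, c]  = sqrt(sum_{t, h, w} w_t * (x[b, t, c, h, w] - mean[b, c])**2 / (H * W))
    #
    # The reduce/copy calls only handle spatial model parallelism.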
    def history_normalize(self, x, target=False): # pragma: no cover
"""Normalize history"""
if self.history_normalization_mode in ["none", "timediff"]:
return x
xdim = x.dim()
if xdim == 4:
b_, c_, h_, w_ = x.shape
xr = torch.reshape(
x, (b_, (self.n_history + 1), c_ // (self.n_history + 1), h_, w_)
)
else:
xshape = x.shape
xr = x
x = self.flatten_history(x)
# normalize
if target:
# strip off the unpredicted channels
xn = (x - self.history_mean[:, : x.shape[1], :, :]) / self.history_std[
:, : x.shape[1], :, :
]
else:
# tile to include history
hm = torch.tile(self.history_mean, (1, self.n_history + 1, 1, 1))
hs = torch.tile(self.history_std, (1, self.n_history + 1, 1, 1))
xn = (x - hm) / hs
if xdim == 5:
xn = torch.reshape(xn, xshape)
return xn
    def history_denormalize(self, xn, target=False): # pragma: no cover
"""Denormalize history"""
if self.history_normalization_mode in ["none", "timediff"]:
return xn
assert self.history_mean is not None
assert self.history_std is not None
xndim = xn.dim()
if xndim == 5:
xnshape = xn.shape
xn = self.flatten_history(xn)
# de-normalize
if target:
# strip off the unpredicted channels
x = (
xn * self.history_std[:, : xn.shape[1], :, :]
+ self.history_mean[:, : xn.shape[1], :, :]
)
else:
# tile to include history
hm = torch.tile(self.history_mean, (1, self.n_history + 1, 1, 1))
hs = torch.tile(self.history_std, (1, self.n_history + 1, 1, 1))
x = xn * hs + hm
if xndim == 5:
x = torch.reshape(x, xnshape)
return x
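    # Round-trip note (illustrative): once history_compute_stats(x) has been
    # called, history_denormalize(history_normalize(x)) recovers x up to
    # floating-point error, since both use the same cached mean/std buffers.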
    def cache_unpredicted_features(
self, x, y=None, xz=None, yz=None
): # pragma: no cover
"""Caches features not predicted by the model (such as zenith angle)"""
if self.training:
if (self.unpredicted_inp_train is not None) and (xz is not None):
self.unpredicted_inp_train.copy_(xz)
else:
self.unpredicted_inp_train = xz
if (self.unpredicted_tar_train is not None) and (yz is not None):
self.unpredicted_tar_train.copy_(yz)
else:
self.unpredicted_tar_train = yz
else:
if (self.unpredicted_inp_eval is not None) and (xz is not None):
self.unpredicted_inp_eval.copy_(xz)
else:
self.unpredicted_inp_eval = xz
if (self.unpredicted_tar_eval is not None) and (yz is not None):
self.unpredicted_tar_eval.copy_(yz)
else:
self.unpredicted_tar_eval = yz
return x, y
    def append_unpredicted_features(self, inp): # pragma: no cover
        """Appends features not predicted by the model (such as zenith angle) to the input"""
if self.training:
if self.unpredicted_inp_train is not None:
inp = self.append_channels(inp, self.unpredicted_inp_train)
else:
if self.unpredicted_inp_eval is not None:
inp = self.append_channels(inp, self.unpredicted_inp_eval)
return inp
    def remove_unpredicted_features(self, inp): # pragma: no cover
"""Removes features not predicted by the model (such as zenith angle) from the input"""
if self.training:
if self.unpredicted_inp_train is not None:
inpf = self.expand_history(inp, nhist=self.n_history + 1)
inpc = inpf[
:, :, : inpf.shape[2] - self.unpredicted_inp_train.shape[2], :, :
]
inp = self.flatten_history(inpc)
else:
if self.unpredicted_inp_eval is not None:
inpf = self.expand_history(inp, nhist=self.n_history + 1)
inpc = inpf[
:, :, : inpf.shape[2] - self.unpredicted_inp_eval.shape[2], :, :
]
inp = self.flatten_history(inpc)
return inp
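    # Unpredicted-feature flow (illustrative summary): channels the model does
    # not predict (such as the zenith angle) are cached via
    # cache_unpredicted_features, appended to the model input with
    # append_unpredicted_features before the forward pass, and stripped from
    # the prediction again with remove_unpredicted_features.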
def get_preprocessor(params): # pragma: no cover
"""Returns the preprocessor module"""
return Preprocessor2D(params)
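

# Minimal usage sketch (illustrative only, not part of the original module).
# It builds a hypothetical params namespace with just the attributes read by
# Preprocessor2D.__init__ (the values below are assumptions for a single,
# unsharded rank with all static features disabled) and exercises the shape
# helpers, which do not require a distributed setup.
if __name__ == "__main__":
    from types import SimpleNamespace

    params = SimpleNamespace(
        n_history=1,
        enable_nhwc=False,
        history_normalization_mode="none",
        img_shape_x=32,
        img_shape_y=64,
        img_local_shape_x=32,
        img_local_shape_y=64,
        img_local_offset_x=0,
        img_local_offset_y=0,
        add_grid=False,
        add_orography=False,
        add_landmask=False,
    )
    pre = get_preprocessor(params)

    # two history steps (n_history + 1) with 8 channels each, flattened to 16
    x = torch.randn(2, 2 * 8, params.img_shape_x, params.img_shape_y)
    xe = pre.expand_history(x, nhist=params.n_history + 1)
    xf = pre.flatten_history(xe)
    print(xe.shape, xf.shape)  # (2, 2, 8, 32, 64) and (2, 16, 32, 64)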