FourCastNet

This example reproduces FourCastNet [1] using Modulus Sym. FourCastNet, short for Fourier ForeCasting Neural Network, is a global data-driven weather forecasting model that provides accurate short- to medium-range global predictions at \(0.25^{\circ}\) resolution. FourCastNet generates a week-long forecast in less than 2 seconds, orders of magnitude faster than the ECMWF Integrated Forecasting System (IFS), a state-of-the-art Numerical Weather Prediction (NWP) model, with comparable or better accuracy. It is trained on a small subset of the ERA5 reanalysis dataset [2] from the ECMWF, which consists of hourly estimates of several atmospheric variables at a latitude and longitude resolution of \(0.25^{\circ}\). Given an initial condition from the ERA5 dataset as input, FourCastNet recursively applies an Adaptive Fourier Neural Operator (AFNO) network to predict the evolution of these variables at later time steps. In the current iteration, FourCastNet forecasts 20 atmospheric variables. These variables, listed in the table below, are sampled from the ERA5 dataset at a temporal resolution of 6 hours.

Table 4 FourCastNet modeled variables

Vertical level    Variables
Surface           U10, V10, T2M, SP, MSLP
1000 hPa          U, V, Z
850 hPa           T, U, V, Z, RH
500 hPa           T, U, V, Z, RH
50 hPa            Z
Integrated        TCWV
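As a quick consistency check, the table can be written down as a Python mapping from vertical level to variable names and counted. The lowercase channel names (for example z500) and their order are our own illustrative choice; the channel order actually used in the ERA5 HDF5 files is set by the preprocessing described in the original paper [1] and may differ.

```python
# Variables modeled by FourCastNet, grouped by vertical level (Table 4).
# NOTE: these flattened names and their order are illustrative assumptions,
# not necessarily the channel order stored in the ERA5 HDF5 files.
VARIABLES_BY_LEVEL = {
    "surface": ["u10", "v10", "t2m", "sp", "mslp"],
    "1000hPa": ["u1000", "v1000", "z1000"],
    "850hPa": ["t850", "u850", "v850", "z850", "rh850"],
    "500hPa": ["t500", "u500", "v500", "z500", "rh500"],
    "50hPa": ["z50"],
    "integrated": ["tcwv"],
}


def flatten_channels(by_level):
    """Return a flat list of channel names, preserving the level order."""
    return [name for names in by_level.values() for name in names]


channels = flatten_channels(VARIABLES_BY_LEVEL)
print(len(channels))  # 20
```

The count matches the custom.n_channels: 20 setting in the configuration shown later.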

In this tutorial, we will show you how to define, train and evaluate FourCastNet in Modulus Sym. The topics covered here are:

  1. How to load the ERA5 dataset into Modulus Sym

  2. How to define the FourCastNet architecture in Modulus Sym

  3. How to train FourCastNet

  4. How to generate weather forecasts and quantitatively assess performance

Note

AFNOs [5] are covered in detail in the Adaptive Fourier Neural Operator and Darcy Flow with Adaptive Fourier Neural Operator chapters, which we recommend reading first. Please also refer to the arXiv preprint [1] for more details on the original implementation.

Warning

The ERA5 dataset is very large (5 TB+), so we do not provide it as part of this tutorial. The ERA5 data [2] was downloaded from the Copernicus Climate Change Service (C3S) Climate Data Store [3], [4].

The goal of FourCastNet is to forecast the modeled variables on short time scales of up to 10 days. FourCastNet is initialized using an initial condition from the ERA5 reanalysis dataset. The figure below shows an overview of how FourCastNet works:

Fig. 98 FourCastNet overview. Figure reproduced with permission from [1].

To make a weather prediction, 20 different ERA5 variables at some starting time step \(t\), each defined on a regular latitude/longitude grid of dimension \(720\times 1440\) spanning the entire globe, are given as inputs to the model (bottom left of the figure). An AFNO architecture (middle left) then predicts these variables at a later time step \(t+\Delta t\) (the original paper uses a fixed time delta \(\Delta t\) of 6 hours). During inference, these predictions can be fed back into the AFNO recursively, which allows the model to predict multiple time steps ahead (bottom right). The network can be trained either on single-step predictions or by unrolling it over \(n\) steps and using a loss function that matches each predicted time step to the training data (top right). Typically, single-step prediction is used for the initial training, and two-step prediction, which is more expensive, is used only for fine-tuning.
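The recursive inference and the \(n\)-step training objective described above can be sketched in a few lines of NumPy. This is a toy illustration, not the Modulus Sym implementation: toy_model is a hypothetical stand-in for the trained AFNO, and the mean squared error is a placeholder for whichever loss is actually used.

```python
import numpy as np


def rollout(model, x0, n_steps):
    """Apply `model` autoregressively: each prediction is the next input.

    Returns a list with one array per predicted time step.
    """
    preds = []
    x = x0
    for _ in range(n_steps):
        x = model(x)
        preds.append(x)
    return preds


def unrolled_loss(model, x0, targets):
    """Sum of per-step mean squared errors over len(targets) unrolled steps.

    MSE is a placeholder loss chosen for this sketch.
    """
    preds = rollout(model, x0, len(targets))
    return sum(float(np.mean((p - t) ** 2)) for p, t in zip(preds, targets))


def toy_model(x):
    """Hypothetical stand-in for the trained AFNO: damps the state by 1%."""
    return 0.99 * x


# A 10-day forecast at a 6-hour time step needs 10 * 24 / 6 = 40 steps.
hours_per_step = 6
n_steps = 10 * 24 // hours_per_step

x0 = np.ones((20, 4, 8))  # (channels, lat, lon) on a tiny illustrative grid
forecast = rollout(toy_model, x0, n_steps)

# Two-step training: compare both unrolled predictions with their targets.
loss = unrolled_loss(toy_model, x0, [0.99 * x0, 0.99**2 * x0])
```

Because each prediction becomes the next input, a 10-day forecast requires 40 calls to the model, while two-step training only backpropagates through two of them.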

Note

The original paper also employs an additional precipitation model (middle right); here we only implement the AFNO “backbone” model.

To train FourCastNet, we use the ERA5 data over the years 1979 to 2015 (inclusive). To test its performance, we use out-of-sample ERA5 data from 2018. Please see the original paper [1] for a description of the 20 variables used and the preprocessing applied to the ERA5 dataset; the variables were specifically chosen to model important processes that influence low-level winds and precipitation. The data is stored using the following directory structure:


era5
├── train
│   ├── 1979.h5
│   ├── ...
│   └── 2015.h5
├── test
│   └── 2018.h5
└── stats
    ├── global_means.npy
    └── global_stds.npy

where each HDF5 file contains all of the variables for one year, stored in a dataset named fields with 1460 time steps spaced 6 hours apart (i.e. an array of shape (1460, 20, 720, 1440)).
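The shape above makes it easy to estimate the storage requirements. The sketch below assumes the fields are stored uncompressed as 32-bit floats, which the text does not state; under that assumption the training years alone take roughly 4.5 TB, in line with the multi-terabyte size quoted in the warning above.

```python
# Back-of-the-envelope size of the ERA5 training data described above.
# ASSUMPTION: fields are stored uncompressed as 32-bit floats.
BYTES_PER_VALUE = 4
steps_per_year = 365 * 24 // 6  # 6-hourly samples in a non-leap year
channels, height, width = 20, 720, 1440

bytes_per_sample = channels * height * width * BYTES_PER_VALUE
bytes_per_year = steps_per_year * bytes_per_sample
train_years = 2015 - 1979 + 1  # 1979 to 2015 inclusive

print(steps_per_year)                        # 1460
print(bytes_per_sample / 1e6)                # ~82.9 MB per sample
print(bytes_per_year / 1e9)                  # ~121.1 GB per year
print(train_years * bytes_per_year / 1e12)   # ~4.48 TB for the training years
```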

Note

All of the Python scripts for this example are in examples/fourcastnet/.

Configuration

The configuration file for FourCastNet is similar to the configuration file used in the Darcy Flow with Adaptive Fourier Neural Operator example and is shown below.


# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

defaults:
  - modulus_default
  - arch:
      - afno
  - scheduler: cosine_annealing
  - optimizer: adam
  - loss: sum
  - _self_

arch:
  afno:
    patch_size: 8
    embed_dim: 768
    depth: 12
    num_blocks: 8

optimizer:
  lr: 0.0005

scheduler:
  T_max: 80000

custom:
  n_channels: 20
  tstep: 1
  n_tsteps: 1
  train_dataset:
    kind: "default"
    data_path: "/era5/ngc_era5_data/train" # Training dataset path here
  test_dataset:
    kind: "default"
    data_path: "/era5/ngc_era5_data/test" # Test dataset path here
  num_workers:
    grid: 8
    validation: 8
  tag:

batch_size:
  grid: 2
  validation: 2

training:
  amp: true
  rec_constraint_freq: 100000 # Don't bother recording constraint here
  rec_results_freq: 5000
  save_network_freq: 5000
  print_stats_freq: 100
  summary_freq: 5000
  max_steps: 70000 # 80 epochs * (55k samples / 64 batch size)

In addition, the configuration defines the custom.tstep and custom.n_tsteps parameters: custom.tstep is the time delta between the AFNO’s input and output time steps (in multiples of 6 hours, typically set to 1), and custom.n_tsteps is the number of time steps FourCastNet is unrolled over during training (1 for the initial training phase, 2 for fine-tuning).
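The effect of these two parameters on the variable names and forecast lead times can be illustrated with two small helpers. output_keys mirrors the outvar_keys property of the dataset class shown in the next section; lead_times_hours is a hypothetical helper that assumes the 6-hour spacing of the stored data.

```python
from typing import List


def output_keys(tstep: int, n_tsteps: int) -> List[str]:
    """Output variable names, mirroring the dataset's outvar_keys property."""
    return [f"x_t{(i + 1) * tstep}" for i in range(n_tsteps)]


def lead_times_hours(tstep: int, n_tsteps: int, hours_per_index: int = 6) -> List[int]:
    """Lead time in hours of each output (hypothetical; assumes 6-hourly data)."""
    return [(i + 1) * tstep * hours_per_index for i in range(n_tsteps)]


print(output_keys(1, 1), lead_times_hours(1, 1))  # ['x_t1'] [6]
print(output_keys(1, 2), lead_times_hours(1, 2))  # ['x_t1', 'x_t2'] [6, 12]
```

Setting custom.n_tsteps=2 for fine-tuning therefore adds a second target, x_t2, twelve hours after the input.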

Loading Data

Modulus Sym FourCastNet currently has two options for loading the data:

  1. DALI-based dataloader which uses NVIDIA Data Loading Library (DALI) for accelerated data loading and processing.

  2. Standard PyTorch dataloader.

The DALI dataloader is the default option; to use the standard PyTorch dataloader instead, set the custom.train_dataset.kind (and custom.test_dataset.kind) option to pytorch.

Both dataloaders share an implementation that supports the ERA5 data format, defined in fourcastnet/src/dataset.py:


import logging
from pathlib import Path
from typing import List, Optional

import h5py
import numpy as np
from hydra.utils import to_absolute_path


class ERA5HDF5GridBaseDataset:
    """Lazy-loading ERA5 dataset.
    Provides common implementation that is used in map- or iterable-style datasets.

    Parameters
    ----------
    data_dir : str
        Directory where ERA5 data is stored
    chans : List[int]
        Defines which ERA5 variables to load
    tstep : int
        Defines the size of the timestep between the input and output variables
    n_tsteps : int, optional
        Defines how many timesteps are included in the output variables
        Default is 1
    patch_size : int, optional
        If specified, crops input and output variables so image dimensions are
        divisible by patch_size
        Default is None
    n_samples_per_year : int, optional
        If specified, randomly selects n_samples_per_year samples from each year
        rather than all of the samples per year
        Default is None
    stats_dir : str, optional
        Directory containing the numpy files with the global means and standard
        deviations of the data. Defaults to the "stats" directory next to data_dir
    """

    def __init__(
        self,
        data_dir: str,
        chans: List[int],
        tstep: int = 1,
        n_tsteps: int = 1,
        patch_size: Optional[int] = None,
        n_samples_per_year: Optional[int] = None,
        stats_dir: Optional[str] = None,
        **kwargs,
    ):
        self.data_dir = Path(to_absolute_path(data_dir))
        self.chans = chans
        self.nchans = len(self.chans)
        self.tstep = tstep
        self.n_tsteps = n_tsteps
        self.patch_size = patch_size
        self.n_samples_per_year = n_samples_per_year
        if stats_dir is None:
            self.stats_dir = self.data_dir.parent / "stats"
        else:
            self.stats_dir = Path(to_absolute_path(stats_dir))

        # check root directory exists
        assert (
            self.data_dir.is_dir()
        ), f"Error, data directory {self.data_dir} does not exist"
        assert (
            self.stats_dir.is_dir()
        ), f"Error, stats directory {self.stats_dir} does not exist"

        # get all input data files
        self.data_paths = sorted(self.data_dir.glob("????.h5"))
        for data_path in self.data_paths:
            logging.info(f"ERA5 file found: {data_path}")
        self.n_years = len(self.data_paths)
        logging.info(f"Number of years: {self.n_years}")

        # get total number of examples and image shape from the first file,
        # assuming other files have exactly the same format.
        logging.info(f"Getting file stats from {self.data_paths[0]}")
        with h5py.File(self.data_paths[0], "r") as f:
            self.n_samples_per_year_all = f["fields"].shape[0]
            self.img_shape = f["fields"].shape[2:]
            logging.info(f"Number of channels available: {f['fields'].shape[1]}")

        # get example indices to use
        if self.n_samples_per_year is None:
            self.n_samples_per_year = self.n_samples_per_year_all
            self.samples = [
                np.arange(self.n_samples_per_year) for _ in range(self.n_years)
            ]
        else:
            if self.n_samples_per_year > self.n_samples_per_year_all:
                raise ValueError(
                    f"n_samples_per_year ({self.n_samples_per_year}) > number of samples available ({self.n_samples_per_year_all})!"
                )
            self.samples = [
                np.random.choice(
                    np.arange(self.n_samples_per_year_all),
                    self.n_samples_per_year,
                    replace=False,
                )
                for _ in range(self.n_years)
            ]
        logging.info(f"Number of samples/year: {self.n_samples_per_year}")

        # get total length
        self.length = self.n_years * self.n_samples_per_year

        # adjust image shape if patch_size defined
        if self.patch_size is not None:
            self.img_shape = [s - s % self.patch_size for s in self.img_shape]
        logging.info(f"Input image shape: {self.img_shape}")

        # load normalisation values
        # has shape [1, C, 1, 1]
        self.mu = np.load(self.stats_dir / "global_means.npy")[:, self.chans]
        # has shape [1, C, 1, 1]
        self.sd = np.load(self.stats_dir / "global_stds.npy")[:, self.chans]
        assert (
            self.mu.shape == self.sd.shape == (1, self.nchans, 1, 1)
        ), "Error, normalisation arrays have wrong shape"

    @property
    def invar_keys(self):
        return ["x_t0"]

    @property
    def outvar_keys(self):
        return [f"x_t{(i+1)*self.tstep}" for i in range(self.n_tsteps)]
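The per-sample normalization and cropping that use self.mu, self.sd, and self.img_shape (applied in __getitem__, shown below) boil down to the following NumPy operations. In this sketch normalise_and_crop is a hypothetical helper name, and the statistics are computed from a toy sample for illustration, whereas the dataset loads global statistics from global_means.npy and global_stds.npy.

```python
import numpy as np


def normalise_and_crop(x, mu, sd, patch_size=None):
    """Normalise a [C, H, W] sample and crop H, W to multiples of patch_size.

    mu and sd have shape [1, C, 1, 1], so mu[0] and sd[0] broadcast over
    the spatial dimensions.
    """
    x = (x - mu[0]) / sd[0]
    if patch_size is not None:
        h, w = x.shape[-2:]
        x = x[..., : h - h % patch_size, : w - w % patch_size]
    return x


rng = np.random.default_rng(0)
C, H, W = 3, 21, 40  # tiny grid whose height is not divisible by 8
x = rng.normal(loc=5.0, scale=2.0, size=(C, H, W))
mu = x.mean(axis=(1, 2)).reshape(1, C, 1, 1)  # toy per-channel statistics
sd = x.std(axis=(1, 2)).reshape(1, C, 1, 1)
y = normalise_and_crop(x, mu, sd, patch_size=8)
print(y.shape)  # (3, 16, 40)
```

For this configuration the crop is a no-op, because 720 and 1440 are both divisible by patch_size = 8.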

Given an example index, the dataset’s __getitem__ method returns a single Modulus Sym input variable, x_t0, a tensor of shape (20, 720, 1440) containing the 20 ERA5 variables at a starting time step, together with one output variable of the same shape for each time step FourCastNet is unrolled over (x_t1, x_t2, … when tstep is 1). All variables are normalized using the global means and standard deviations:


import numpy as np
from modulus.sym.dataset import Dataset


class ERA5HDF5GridDataset(ERA5HDF5GridBaseDataset, Dataset):
    """Map-style ERA5 dataset."""

    def __getitem__(self, idx):
        # get local indices from global index
        year_idx = idx // self.n_samples_per_year
        local_idx = idx % self.n_samples_per_year
        in_idx = self.samples[year_idx][local_idx]

        # get output indices
        out_idxs = []
        for i in range(self.n_tsteps):
            out_idx = in_idx + (i + 1) * self.tstep
            # if at end of dataset, just learn identity instead
            if out_idx > (self.n_samples_per_year_all - 1):
                out_idx = in_idx
            out_idxs.append(out_idx)

        # get data
        # note: self.data_files (one open h5py.File per year) is set up
        # outside this excerpt
        xs = []
        for t_idx in [in_idx] + out_idxs:
            # get array
            # has shape [C, H, W]
            x = self.data_files[year_idx]["fields"][t_idx, self.chans]
            assert x.ndim == 3, f"Expected 3 dimensions, but got {x.shape}"

            # apply input / output normalisation (broadcasted operation)
            x = (x - self.mu[0]) / self.sd[0]

            # crop data if needed
            if self.patch_size is not None:
                x = x[..., : self.img_shape[0], : self.img_shape[1]]

            xs.append(x)

        # convert to tensor dicts
        assert len(self.invar_keys) == 1
        invar = {self.invar_keys[0]: xs[0]}

        assert len(self.outvar_keys) == len(xs) - 1
        outvar = {self.outvar_keys[i]: x for i, x in enumerate(xs[1:])}

        invar = Dataset._to_tensor_dict(invar)
        outvar = Dataset._to_tensor_dict(outvar)

        lambda_weighting = Dataset._to_tensor_dict(
            {k: np.ones_like(v) for k, v in outvar.items()}
        )

        return invar, outvar, lambda_weighting
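The index bookkeeping at the top of __getitem__ can be isolated into a pure function to see how a flat sample index maps onto a year file and its time steps. sample_indices is a hypothetical helper that reproduces that arithmetic (using divmod in place of the separate division and modulo), including the fallback to the identity when an output step would fall past the end of the year.

```python
def sample_indices(idx, n_samples_per_year, n_samples_per_year_all, samples,
                   tstep=1, n_tsteps=1):
    """Map a global sample index to (year_idx, input index, output indices).

    Outputs that would run past the end of the year fall back to the input
    index, so the model learns the identity for those samples.
    """
    year_idx, local_idx = divmod(idx, n_samples_per_year)
    in_idx = samples[year_idx][local_idx]
    out_idxs = []
    for i in range(n_tsteps):
        out_idx = in_idx + (i + 1) * tstep
        if out_idx > n_samples_per_year_all - 1:
            out_idx = in_idx
        out_idxs.append(out_idx)
    return year_idx, in_idx, out_idxs


# Two "years" of 1460 six-hourly samples each, all samples used in order.
samples = [list(range(1460)) for _ in range(2)]
print(sample_indices(0, 1460, 1460, samples))                 # (0, 0, [1])
print(sample_indices(1461, 1460, 1460, samples, n_tsteps=2))  # (1, 1, [2, 3])
print(sample_indices(1459, 1460, 1460, samples))              # (0, 1459, [1459])
```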

Inside the training script, fourcastnet/fcn_era5.py, the ERA5 datasets are initialized as follows:


train_dataset = _create_dataset(
    cfg.custom.train_dataset.kind,
    data_dir=cfg.custom.train_dataset.data_path,
    chans=channels,
    tstep=cfg.custom.tstep,
    n_tsteps=cfg.custom.n_tsteps,
    patch_size=cfg.arch.afno.patch_size,
    batch_size=cfg.batch_size.grid,
    num_workers=cfg.custom.num_workers.grid,
    shuffle=True,
)

test_dataset = _create_dataset(
    cfg.custom.test_dataset.kind,
    data_dir=cfg.custom.test_dataset.data_path,
    chans=channels,
    tstep=cfg.custom.tstep,
    n_tsteps=cfg.custom.n_tsteps,
    patch_size=cfg.arch.afno.patch_size,
    n_samples_per_year=20,
    batch_size=cfg.batch_size.validation,
    num_workers=cfg.custom.num_workers.validation,
)
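The test dataset passes n_samples_per_year=20, so only 20 initial conditions per year are drawn at random, without replacement, from the available time steps. The hypothetical helper below reproduces that selection logic from ERA5HDF5GridBaseDataset.__init__; unlike the original, which calls np.random.choice on NumPy's global random state, it uses a seeded numpy.random.Generator so that the selection is reproducible.

```python
import numpy as np


def choose_samples(n_years, n_samples_per_year_all, n_samples_per_year=None,
                   seed=None):
    """Per-year sample indices, following the dataset's selection logic.

    With n_samples_per_year=None every time step is used; otherwise that many
    distinct time steps are drawn at random (without replacement) per year.
    """
    if n_samples_per_year is None:
        return [np.arange(n_samples_per_year_all) for _ in range(n_years)]
    if n_samples_per_year > n_samples_per_year_all:
        raise ValueError("n_samples_per_year exceeds the available samples")
    rng = np.random.default_rng(seed)
    return [
        rng.choice(n_samples_per_year_all, n_samples_per_year, replace=False)
        for _ in range(n_years)
    ]


# Test set: one year (2018), 20 random initial conditions out of 1460.
test_samples = choose_samples(1, 1460, n_samples_per_year=20, seed=0)
print(len(test_samples[0]))  # 20
```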

FourCastNet Model

Next, we need to define FourCastNet as a custom Modulus Sym architecture. This model is found in fourcastnet/src/fourcastnet.py, which is a wrapper class around Modulus Sym’s AFNO model. FourCastNet has two training phases: the first uses single-step prediction and the second uses two-step prediction. This small wrapper allows the AFNO to be executed over any number of time steps, n_tsteps, using autoregressive forward passes.


import logging
from typing import Dict, List, Tuple

import torch
from torch import Tensor

from modulus.sym.key import Key
from modulus.sym.models.arch import Arch
from modulus.sym.models.afno.afno import AFNONet


class FourcastNetArch(Arch):
    "Defines the FourcastNet architecture"

    def __init__(
        self,
        input_keys: List[Key],
        output_keys: List[Key],
        img_shape: Tuple[int, int],
        detach_keys: List[Key] = [],
        patch_size: int = 16,
        embed_dim: int = 256,
        depth: int = 4,
        num_blocks: int = 4,
    ) -> None:
        """Fourcastnet model. This is a simple wrapper for Modulus' AFNO model.
        The only difference is that FourcastNet needs multi-step training. This class
        allows the model to auto-regressively predict multiple timesteps

        Parameters (Same as AFNO)
        ----------
        input_keys : List[Key]
            Input key list. The key dimension size should equal the variables channel dim.
        output_keys : List[Key]
            Output key list. The key dimension size should equal the variables channel dim.
        img_shape : Tuple[int, int]
            Input image dimensions (height, width)
        detach_keys : List[Key], optional
            List of keys to detach gradients, by default []
        patch_size : int, optional
            Size of image patches, by default 16
        embed_dim : int, optional
            Embedded channel size, by default 256
        depth : int, optional
            Number of AFNO layers, by default 4
        num_blocks : int, optional
            Number of blocks in the frequency weight matrices, by default 4
        """
        super().__init__(
            input_keys=input_keys,
            output_keys=output_keys,
            detach_keys=detach_keys,
        )

        # get number of timesteps steps to unroll
        assert (
            len(self.input_keys) == 1
        ), "Error, FourcastNet only accepts one input variable (x_t0)"
        self.n_tsteps = len(self.output_keys)
        logging.info(f"Unrolling FourcastNet over {self.n_tsteps} timesteps")

        # get number of input/output channels
        in_channels = self.input_keys[0].size
        out_channels = self.output_keys[0].size

        # initialise AFNO kernel
        self._impl = AFNONet(
            in_channels=in_channels,
            out_channels=out_channels,
            patch_size=(patch_size, patch_size),
            img_size=img_shape,
            embed_dim=embed_dim,
            depth=depth,
            num_blocks=num_blocks,
        )

    def forward(self, in_vars: Dict[str, Tensor]) -> Dict[str, Tensor]:
        # prepare input tensor
        x = self.prepare_input(
            input_variables=in_vars,
            mask=self.input_key_dict.keys(),
            detach_dict=self.detach_key_dict,
            dim=1,
            input_scales=self.input_scales,
        )

        # unroll model over multiple timesteps
        ys = []
        for _ in range(self.n_tsteps):
            x = self._impl(x)
            ys.append(x)
        y = torch.cat(ys, dim=1)

        # prepare output dict
        return self.prepare_output(
            output_tensor=y,
            output_var=self.output_key_dict,
            dim=1,
            output_scales=self.output_scales,
        )
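The unrolling in forward concatenates the per-step predictions along the channel dimension and relies on prepare_output to split them back into one output variable per key. The NumPy sketch below mimics that round trip with a toy step function; unroll_and_split is a hypothetical helper, and it assumes, rather than shows, that prepare_output splits dimension 1 in key order according to each key's size.

```python
import numpy as np


def unroll_and_split(step, x, output_keys, sizes):
    """Unroll `step` over len(output_keys) steps, concatenate the predictions
    along the channel axis, then split back into one array per output key."""
    ys = []
    for _ in output_keys:
        x = step(x)
        ys.append(x)
    y = np.concatenate(ys, axis=1)        # [B, n_tsteps * C, H, W]
    split_points = np.cumsum(sizes)[:-1]  # channel offsets between keys
    return dict(zip(output_keys, np.split(y, split_points, axis=1)))


def add_one(x):
    """Toy stand-in for one AFNO forward pass (hypothetical)."""
    return x + 1.0


# Batch of 2, 20 channels, 4 x 8 grid, two unrolled steps.
x0 = np.zeros((2, 20, 4, 8))
out = unroll_and_split(add_one, x0, ["x_t1", "x_t2"], [20, 20])
print({k: v.shape for k, v in out.items()})
```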

The FourCastNet model is initialized in the training script, fourcastnet/fcn_era5.py:


# make list of nodes to unroll graph on
model = FourcastNetArch(
    input_keys=input_keys,
    output_keys=output_keys,
    img_shape=test_dataset.img_shape,
    patch_size=cfg.arch.afno.patch_size,
    embed_dim=cfg.arch.afno.embed_dim,
    depth=cfg.arch.afno.depth,
    num_blocks=cfg.arch.afno.num_blocks,
)
nodes = [model.make_node(name="FCN")]
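With the configuration above (patch_size: 8, embed_dim: 768) and the 720 × 1440 input grid, some simple arithmetic shows how large the AFNO token grid is. This sketch assumes the ViT-style patch embedding used by AFNO [5], in which each non-overlapping patch_size × patch_size patch across all channels becomes one token.

```python
# Token-count arithmetic for the AFNO patch embedding, using the values from
# the configuration above (illustrative only).
img_shape = (720, 1440)
patch_size = 8
n_channels = 20
embed_dim = 768

# Image dimensions must be divisible by patch_size (the dataset crops if not).
assert all(s % patch_size == 0 for s in img_shape)

grid = tuple(s // patch_size for s in img_shape)  # patches along lat, lon
n_tokens = grid[0] * grid[1]                      # tokens per sample
values_per_patch = n_channels * patch_size * patch_size

print(grid)              # (90, 180)
print(n_tokens)          # 16200
print(values_per_patch)  # 1280 values per patch, embedded into 768 channels
```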

Adding Constraints

With the custom dataset for loading the ERA5 data and the FourCastNet model created, the next step is setting up the Modulus Sym training domain. The main training script, fourcastnet/fcn_era5.py, contains the standard steps needed to train a model in Modulus Sym. A standard data-driven grid constraint is created:


# make domain
domain = Domain()

# add constraints to domain
supervised = SupervisedGridConstraint(
    nodes=nodes,
    dataset=train_dataset,
    batch_size=cfg.batch_size.grid,
    loss=LpLoss(),
    num_workers=cfg.custom.num_workers.grid,
)
domain.add_constraint(supervised, "supervised")
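The constraint is given loss=LpLoss(). Its definition is not reproduced in this tutorial; the sketch below assumes it is a relative Lp loss, as used in the original FourCastNet and FNO code: the Lp norm of the error divided by the Lp norm of the target, computed per sample and averaged over the batch. relative_lp_loss is a hypothetical name.

```python
import numpy as np


def relative_lp_loss(pred, target, p=2, eps=1e-12):
    """Batch-averaged relative Lp error: ||pred - target||_p / ||target||_p.

    Each sample in the batch is flattened before the norms are taken;
    eps guards against division by zero for all-zero targets.
    """
    b = pred.shape[0]
    diff = (pred - target).reshape(b, -1)
    ref = target.reshape(b, -1)
    num = np.linalg.norm(diff, ord=p, axis=1)
    den = np.linalg.norm(ref, ord=p, axis=1)
    return float(np.mean(num / (den + eps)))


target = np.ones((2, 3, 4, 4))  # (batch, channels, lat, lon)
print(relative_lp_loss(1.1 * target, target))  # approximately 0.1
```

Normalizing by the target norm makes the loss insensitive to the overall magnitude of each sample, so every training example contributes on a comparable scale.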

A validator is also added to the training script:


# add validator
val = GridValidator(
    nodes,
    dataset=test_dataset,
    batch_size=cfg.batch_size.validation,
    plotter=GridValidatorPlotter(n_examples=5),
    num_workers=cfg.custom.num_workers.validation,
)
domain.add_validator(val, "test")

Training can now be started by running the Python script:


python fcn_era5.py

Results and Post-processing

With the trained model, the script fourcastnet/inferencer.py calculates the latitude-weighted root mean squared error (RMSE) and the latitude-weighted anomaly correlation coefficient (ACC). The inferencer script runs the trained model on multiple initial conditions provided in the test dataset. The figures below compare the ACC and RMSE of the model trained in Modulus Sym with the results of the original work [1], showing excellent agreement. Additionally, a 24-hour forecast of the total column water vapor (TCWV) predicted by Modulus Sym is compared with the target ERA5 data.
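For reference, the latitude-weighted metrics can be written as follows, based on our reading of the definitions in the FourCastNet paper [1]; fourcastnet/src/metrics.py remains the authoritative implementation. For a single variable at forecast lead time \(l\), let \(\hat{X}\) be the predicted field, \(X\) the ERA5 target, \(C\) the climatology, and \(\phi_j\) the latitude of grid row \(j\):

```latex
L(j) = \frac{N_{\mathrm{lat}} \cos\phi_j}{\sum_{j'=1}^{N_{\mathrm{lat}}} \cos\phi_{j'}}

\mathrm{RMSE}(l) = \sqrt{\frac{1}{N_{\mathrm{lat}} N_{\mathrm{lon}}}
  \sum_{j=1}^{N_{\mathrm{lat}}} \sum_{k=1}^{N_{\mathrm{lon}}}
  L(j) \left(\hat{X}_{jk}(l) - X_{jk}(l)\right)^2}

\mathrm{ACC}(l) = \frac{\sum_{j,k} L(j)\, \hat{A}_{jk}(l)\, A_{jk}(l)}
  {\sqrt{\sum_{j,k} L(j)\, \hat{A}_{jk}(l)^2 \; \sum_{j,k} L(j)\, A_{jk}(l)^2}},
\qquad \hat{A} = \hat{X} - C, \quad A = X - C
```

The weights \(L(j)\) average to 1 over the rows, so grid cells near the poles, which cover less area, contribute less to both metrics.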

Note

See the original arXiv paper [1] or fourcastnet/src/metrics.py for more details on how these metrics are calculated. Several dataset statistics (for example, the climatology used to compute anomalies for the ACC) are needed to calculate the metrics of interest properly.
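A minimal NumPy version of the latitude-weighted RMSE and ACC for a single 2D field is sketched below. It follows the FourCastNet paper's definitions as we understand them and is not a copy of src/metrics.py; the function names, the toy fields, and the latitude convention (720 rows starting at 90°N in 0.25° steps) are assumptions for illustration.

```python
import numpy as np


def latitude_weights(lat_deg):
    """L(j) = N_lat * cos(lat_j) / sum_j cos(lat_j); the weights average to 1."""
    w = np.cos(np.deg2rad(lat_deg))
    return len(lat_deg) * w / w.sum()


def weighted_rmse(pred, true, lat_deg):
    """Latitude-weighted RMSE of [H, W] fields (rows correspond to lat_deg)."""
    w = latitude_weights(lat_deg)[:, None]
    return float(np.sqrt(np.mean(w * (pred - true) ** 2)))


def weighted_acc(pred, true, clim, lat_deg):
    """Latitude-weighted anomaly correlation coefficient of [H, W] fields."""
    w = latitude_weights(lat_deg)[:, None]
    a_pred, a_true = pred - clim, true - clim
    num = np.sum(w * a_pred * a_true)
    den = np.sqrt(np.sum(w * a_pred**2) * np.sum(w * a_true**2))
    return float(num / den)


# Latitudes of a 0.25-degree grid with 720 rows, starting at 90N (assumption).
lat = 90.0 - 0.25 * np.arange(720)

# Toy fields on a small 9 x 16 grid.
rng = np.random.default_rng(0)
lat_small = np.linspace(-80.0, 80.0, 9)
clim = rng.normal(size=(9, 16))
truth = clim + rng.normal(size=(9, 16))
print(weighted_rmse(truth, truth, lat_small))       # 0.0
print(weighted_acc(truth, truth, clim, lat_small))  # ~1.0
```

A perfect forecast gives an RMSE of 0 and an ACC of 1, while a forecast whose anomalies have the opposite sign of the truth gives an ACC of -1.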

Fig. 99 Comparison of the anomaly correlation coefficient (ACC) of the predicted zonal (u) component of the 10 m wind (u10) and the geopotential at 500 hPa (z500) for the original FourCastNet model (Original) and the version trained in Modulus Sym.

Fig. 100 Comparison of the root mean squared error (RMSE) of the predictions for each variable between the original FourCastNet model (Original) and the version trained in Modulus Sym.

Fig. 101 24-hour prediction of the total column water vapor (TCWV) by the model trained in Modulus Sym, compared with the ground truth ERA5 dataset from ECMWF.

References

[1]

Pathak, Jaideep, et al. “FourCastNet: A global data-driven high-resolution weather model using adaptive Fourier neural operators.” arXiv preprint arXiv:2202.11214 (2022).

[2]

Hersbach, Hans, et al. “The ERA5 global reanalysis.” Quarterly Journal of the Royal Meteorological Society (2020).

[3]

Hersbach, Hans, et al. “ERA5 hourly data on pressure levels from 1959 to present.” Copernicus Climate Change Service (C3S) Climate Data Store (CDS), doi:10.24381/cds.bd0915c6 (2018).

[4]

Hersbach, Hans, et al. “ERA5 hourly data on single levels from 1959 to present.” Copernicus Climate Change Service (C3S) Climate Data Store (CDS), doi:10.24381/cds.adbb2d47 (2018).

[5]

Guibas, John, et al. “Adaptive Fourier neural operators: Efficient token mixers for transformers.” International Conference on Learning Representations (2022).

© Copyright 2023, NVIDIA Modulus Team. Last updated on Jan 25, 2024.