FourCastNet

Introduction

This example reproduces FourCastNet [1] using Modulus. FourCastNet, short for Fourier ForeCasting Neural Network, is a global data-driven weather forecasting model that provides accurate short- to medium-range global predictions at \(0.25^{\circ}\) resolution. FourCastNet generates a week-long forecast in less than 2 seconds, orders of magnitude faster than the ECMWF Integrated Forecasting System (IFS), a state-of-the-art Numerical Weather Prediction (NWP) model, with comparable or better accuracy. It is trained on a small subset of the ERA5 reanalysis dataset [2] from the ECMWF, which consists of hourly estimates of several atmospheric variables at a latitude and longitude resolution of \(0.25^{\circ}\). Given an initial condition from the ERA5 dataset as input, FourCastNet recursively applies an Adaptive Fourier Neural Operator (AFNO) network [5] to predict how these variables evolve over later time steps. In the current iteration, FourCastNet forecasts 20 atmospheric variables. These variables, listed in the table below, are sampled from the ERA5 dataset at a temporal resolution of 6 hours.

Table 4 FourCastNet modeled variables

Vertical Level    Variables
Surface           U10, V10, T2M, SP, MSLP
1000 hPa          U, V, Z
850 hPa           T, U, V, Z, RH
500 hPa           T, U, V, Z, RH
50 hPa            Z
Integrated        TCWV

In this tutorial, we will show you how to define, train and evaluate FourCastNet in Modulus. The topics covered here are:

  1. How to load the ERA5 dataset into Modulus

  2. How to define the FourCastNet architecture in Modulus

  3. How to train FourCastNet

  4. How to generate weather forecasts and quantitatively assess performance

Note

AFNOs are covered in detail in the Adaptive Fourier Neural Operator and Darcy Flow with Adaptive Fourier Neural Operator chapters, and we recommend reading these chapters first. Please also refer to the arXiv preprint [1] for more details on the original implementation.

Warning

The ERA5 dataset is very large (5 TB+) and is not provided as part of this tutorial. The ERA5 data [2] was downloaded from the Copernicus Climate Change Service (C3S) Climate Data Store [3], [4].

Problem Description

The goal of FourCastNet is to forecast the modeled variables on a short time scale of up to 10 days. FourCastNet is initialized using an initial condition from the ERA5 reanalysis dataset. The figure below gives an overview of how FourCastNet works:


Fig. 98 FourCastNet overview. Figure reproduced with permission from [1].

To make a weather prediction, 20 different ERA5 variables, each defined on a regular latitude/longitude grid of dimension \(720\times 1440\) spanning the entire globe, are given as inputs to the model at some starting time step \(t\) (bottom left of the figure). An AFNO architecture (middle left) then predicts these variables at a later time step \(t+\Delta t\) (the original paper uses a fixed time delta \(\Delta t\) of 6 hours). During inference, these predictions can be recursively fed back into the AFNO, which allows the model to predict multiple time steps ahead (bottom right). The network can be trained either with single-step prediction, or by unrolling it over \(n\) steps and using a loss function that matches each predicted time step to the training data (top right). Because unrolled training is more expensive, single-step prediction is typically used for the initial training, and two-step prediction is then used for fine-tuning.
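The recursive rollout and the \(n\)-step training loss described above can be illustrated with a small NumPy sketch. It assumes a generic one-step model and a per-step mean squared error summed over the unrolled steps; `rollout`, `unrolled_loss` and the toy linear `step` model are hypothetical names used only for illustration, not the Modulus implementation.

```python
import numpy as np


def rollout(step_fn, x0, n_steps):
    """Recursively apply a one-step model: x_{t+1} = step_fn(x_t)."""
    preds = []
    x = x0
    for _ in range(n_steps):
        x = step_fn(x)      # the prediction becomes the next input
        preds.append(x)
    return preds


def unrolled_loss(step_fn, x0, targets):
    """Sum of per-step mean squared errors over an n-step unroll."""
    preds = rollout(step_fn, x0, len(targets))
    return sum(float(np.mean((p - y) ** 2)) for p, y in zip(preds, targets))


# Toy one-step "model": a fixed linear map on a 3-channel state
# (a hypothetical stand-in for the AFNO).
A = np.diag([0.9, 1.0, 1.1])


def step(x):
    return A @ x


x0 = np.array([1.0, 2.0, 3.0])
targets = rollout(step, x0, 2)            # "ground truth" from the same dynamics
print(unrolled_loss(step, x0, targets))   # 0.0
```

Setting the number of targets to 1 recovers single-step training; using two targets corresponds to the two-step fine-tuning phase.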

Note

The original paper employs an additional precipitation model (middle right), although we only implement the AFNO “backbone” model here.

Case Setup

To train FourCastNet, we use the ERA5 data over the years 1979 to 2015 (inclusive). To test its performance, we use out-of-sample ERA5 data from 2018. Please see the original paper for a description of the 20 variables used and of the preprocessing applied to the ERA5 dataset; the variables were specifically chosen to model important processes that influence low-level winds and precipitation. The data is stored using the following directory structure:

era5
├── train
│   ├── 1979.h5
│   ├── ...
│   └── 2015.h5
├── test
│   └── 2018.h5
└── stats
    ├── global_means.npy
    └── global_stds.npy

where each HDF5 file contains all 20 variables for one year, stored in a dataset named fields, over 1460 time steps with 6-hour time deltas (i.e. fields has shape (1460, 20, 720, 1440)).
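The array shape above explains the size warning earlier in this tutorial. The following back-of-the-envelope calculation assumes the values are stored as 4-byte float32, which this tutorial does not state; it also ignores the extra time steps of leap years.

```python
# Rough storage estimate for the ERA5 layout described above.
# Assumption: values are stored as 4-byte float32.
steps_per_year = 365 * 4          # 6-hourly samples -> 1460
n_vars, n_lat, n_lon = 20, 720, 1440
bytes_per_value = 4               # float32 (assumed)

bytes_per_year = steps_per_year * n_vars * n_lat * n_lon * bytes_per_value
train_years = 2015 - 1979 + 1     # 1979..2015 inclusive
train_samples = train_years * steps_per_year

print(f"{bytes_per_year / 1e9:.1f} GB per year")                 # 121.1 GB per year
print(f"{train_years * bytes_per_year / 1e12:.2f} TB training")  # 4.48 TB training
print(f"{train_samples} training samples")                       # 54020 training samples
```

The sample count is consistent with the "55k samples" comment in the configuration file below.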

Note

All of the Python scripts for this example are in examples/fourcastnet/.

Configuration

The configuration file for FourCastNet is similar to the configuration file used in the Darcy Flow with Adaptive Fourier Neural Operator example and is shown below.

defaults :
  - modulus_default
  - arch:
      - afno
  - scheduler: cosine_annealing
  - optimizer: adam
  - loss: sum
  - _self_

arch:
  afno:
    patch_size: 8
    embed_dim: 768
    depth: 12
    num_blocks: 8

optimizer:
  lr: 0.0005

scheduler:
  T_max: 80000

custom:
  n_channels: 20
  tstep: 1
  n_tsteps: 1
  training_data_path: "/era5/ngc_era5_data/train" # Training dataset path here
  test_data_path: "/era5/ngc_era5_data/test" # Test dataset path here
  num_workers:
    grid: 8
    validation: 8
  tag:

batch_size:
  grid: 2
  validation: 2

training:
  amp: true
  rec_constraint_freq: 100000 # Don't bother recording constraint here
  rec_results_freq : 5000
  save_network_freq: 5000
  print_stats_freq: 100
  summary_freq: 5000
  max_steps : 70000 # 80 epochs * (55k samples / 64 batch size)

In addition to the standard settings, the custom.tstep and custom.n_tsteps parameters define, respectively, the time delta between the AFNO’s input and output time steps (in multiples of 6 hours, typically set to 1) and the number of time steps FourCastNet is unrolled over during training.
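To make the roles of custom.tstep and custom.n_tsteps concrete, the following plain-Python sketch reproduces the output-key naming and target-index logic of the ERA5HDF5GridDataset class in src/dataset.py. The helper names `output_keys` and `target_indices` are hypothetical; with tstep=1 each step corresponds to 6 hours.

```python
from typing import List


def output_keys(tstep: int, n_tsteps: int) -> List[str]:
    """Output variable names, mirroring ERA5HDF5GridDataset.outvar_keys."""
    return [f"x_t{(i + 1) * tstep}" for i in range(n_tsteps)]


def target_indices(in_idx: int, tstep: int, n_tsteps: int, n_per_year: int) -> List[int]:
    """Time indices of the targets for one input index, mirroring __getitem__:
    a target past the end of the year falls back to the input index (identity)."""
    out = []
    for i in range(n_tsteps):
        out_idx = in_idx + (i + 1) * tstep
        if out_idx > n_per_year - 1:
            out_idx = in_idx
        out.append(out_idx)
    return out


print(output_keys(tstep=1, n_tsteps=2))    # ['x_t1', 'x_t2']
print(target_indices(100, 1, 2, 1460))     # [101, 102]
print(target_indices(1459, 1, 2, 1460))    # [1459, 1459]
```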

Loading Data

We load the ERA5 data into Modulus by defining a custom modulus.dataset.Dataset in fourcastnet/src/dataset.py:


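import logging
from pathlib import Path
from typing import List

import h5py
import numpy as np
from modulus.dataset import Dataset
from modulus.hydra import to_absolute_path


class ERA5HDF5GridDataset(Dataset):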
    """Lazy-loading ERA5 dataset.

    Parameters
    ----------
    data_dir : str
        Directory where ERA5 data is stored
    chans : List[int]
        Defines which ERA5 variables to load
    tstep : int
        Defines the size of the timestep between the input and output variables
    n_tsteps : int, optional
        Defines how many timesteps are included in the output variables
        Default is 1
    patch_size : int, optional
        If specified, crops input and output variables so image dimensions are
        divisible by patch_size
        Default is None
    n_samples_per_year : int, optional
        If specified, randomly selects n_samples_per_year samples from each year
        rather than all of the samples per year
        Default is None
    stats_dir : str, optional
        Directory containing the statistics files global_means.npy and
        global_stds.npy. Default is the "stats" directory next to data_dir
    """

    def __init__(
        self,
        data_dir: str,
        chans: List[int],
        tstep: int = 1,
        n_tsteps: int = 1,
        patch_size: int = None,
        n_samples_per_year: int = None,
        stats_dir: str = None,
    ):
        self.data_dir = Path(to_absolute_path(data_dir))
        self.chans = chans
        self.nchans = len(self.chans)
        self.tstep = tstep
        self.n_tsteps = n_tsteps
        self.patch_size = patch_size
        self.n_samples_per_year = n_samples_per_year

        if stats_dir is None:
            self.stats_dir = self.data_dir.parent / "stats"
        else:
            self.stats_dir = Path(to_absolute_path(stats_dir))

        # check root directory exists
        assert (
            self.data_dir.is_dir()
        ), f"Error, data directory {self.data_dir} does not exist"
        assert (
            self.stats_dir.is_dir()
        ), f"Error, stats directory {self.stats_dir} does not exist"

        # get all input data files
        self.data_paths = sorted(self.data_dir.glob("????.h5"))
        for data_path in self.data_paths:
            logging.info(f"ERA5 file found: {data_path}")
        self.n_years = len(self.data_paths)
        logging.info(f"Number of years: {self.n_years}")

        # get total number of examples and image shape from the first file,
        # assuming other files have exactly the same format.
        logging.info(f"Getting file stats from {self.data_paths[0]}")
        with h5py.File(self.data_paths[0], "r") as f:
            self.n_samples_per_year_all = f["fields"].shape[0]
            self.img_shape = f["fields"].shape[2:]
            logging.info(f"Number of channels available: {f['fields'].shape[1]}")

        # get example indices to use
        if self.n_samples_per_year is None:
            self.n_samples_per_year = self.n_samples_per_year_all
            self.samples = [
                np.arange(self.n_samples_per_year) for _ in range(self.n_years)
            ]
        else:
            if self.n_samples_per_year > self.n_samples_per_year_all:
                raise ValueError(
                    f"n_samples_per_year ({self.n_samples_per_year}) > number of samples available ({self.n_samples_per_year_all})!"
                )
            self.samples = [
                np.random.choice(
                    np.arange(self.n_samples_per_year_all),
                    self.n_samples_per_year,
                    replace=False,
                )
                for _ in range(self.n_years)
            ]
        logging.info(f"Number of samples/year: {self.n_samples_per_year}")

        # get total length
        self.length = self.n_years * self.n_samples_per_year

        # adjust image shape if patch_size defined
        if self.patch_size is not None:
            self.img_shape = [s - s % self.patch_size for s in self.img_shape]
        logging.info(f"Input image shape: {self.img_shape}")

        # load normalisation values
        # has shape [1, C, 1, 1]
        self.mu = np.load(self.stats_dir / "global_means.npy")[:, self.chans]
        # has shape [1, C, 1, 1]
        self.sd = np.load(self.stats_dir / "global_stds.npy")[:, self.chans]
        assert (
            self.mu.shape == self.sd.shape == (1, self.nchans, 1, 1)
        ), "Error, normalisation arrays have wrong shape"

    def worker_init_fn(self, iworker):
        super().worker_init_fn(iworker)

        # open all year files at once on worker thread
        self.data_files = [h5py.File(path, "r") for path in self.data_paths]

    @property
    def invar_keys(self):
        return ["x_t0"]

    @property
    def outvar_keys(self):
        return [f"x_t{(i+1)*self.tstep}" for i in range(self.n_tsteps)]

    def __getitem__(self, idx):

        # get local indices from global index
        year_idx, local_idx = divmod(int(idx), self.n_samples_per_year)
        in_idx = self.samples[year_idx][local_idx]

        # get output indices
        out_idxs = []
        for i in range(self.n_tsteps):
            out_idx = in_idx + (i + 1) * self.tstep
            # if at end of dataset, just learn identity instead
            if out_idx > (self.n_samples_per_year_all - 1):
                out_idx = in_idx
            out_idxs.append(out_idx)

        # get data
        xs = []
        for idx in [in_idx] + out_idxs:
            # get array
            # has shape [C, H, W]
            x = self.data_files[year_idx]["fields"][idx, self.chans]
            assert x.ndim == 3, f"Expected 3 dimensions, but got {x.shape}"

            # apply input / output normalisation (broadcasted operation)
            x = (x - self.mu[0]) / self.sd[0]

            # crop data if needed
            if self.patch_size is not None:
                x = x[..., : self.img_shape[0], : self.img_shape[1]]

            xs.append(x)

        # convert to tensor dicts
        invar = {"x_t0": xs[0]}
        outvar = {f"x_t{(i+1)*self.tstep}": x for i, x in enumerate(xs[1:])}
        invar = Dataset._to_tensor_dict(invar)
        outvar = Dataset._to_tensor_dict(outvar)
        # TODO: get rid of lambda weighting
        lambda_weighting = Dataset._to_tensor_dict(
            {k: np.ones_like(v) for k, v in outvar.items()}
        )
        # lambda_weighting = Dataset._to_tensor_dict(
        #     {k: np.array([1]) for k, v in outvar.items()}
        # )

        return invar, outvar, lambda_weighting

    def __len__(self):
        return self.length

Given an example index, the dataset’s __getitem__ method returns three dictionaries: a single Modulus input variable, x_t0, which is a tensor of shape (20, 720, 1440) containing the 20 ERA5 variables at the starting time step; one output variable of the same shape for each time step FourCastNet is unrolled over (x_t1, x_t2, … when tstep is 1); and per-output loss weights, which are all ones.
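The normalisation in __getitem__ relies on NumPy broadcasting between a [C, H, W] sample and statistics of shape [1, C, 1, 1]. The sketch below uses a small random array in place of ERA5 and computes the per-channel mean and standard deviation over time, latitude and longitude. How global_means.npy and global_stds.npy were actually produced is not described in this tutorial, so that reduction is an assumption.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-in for one year of data with layout (time, C, H, W); the real
# ERA5 files have shape (1460, 20, 720, 1440).
offsets = np.array([5.0, -2.0, 10.0]).reshape(1, 3, 1, 1)
fields = 3.0 * rng.standard_normal((8, 3, 4, 6)) + offsets

# Per-channel statistics with shape [1, C, 1, 1], the layout the dataset asserts.
# Assumption: reduced over time, latitude and longitude.
mu = fields.mean(axis=(0, 2, 3), keepdims=True)
sd = fields.std(axis=(0, 2, 3), keepdims=True)
print(mu.shape)                 # (1, 3, 1, 1)

# A single sample has shape [C, H, W]; mu[0] and sd[0] have shape [C, 1, 1]
# and broadcast over H and W, as in __getitem__.
x = fields[0]
x_norm = (x - mu[0]) / sd[0]

# Crop H and W down to multiples of patch_size, as when patch_size is set.
patch_size = 4
h, w = (s - s % patch_size for s in x_norm.shape[1:])
x_crop = x_norm[..., :h, :w]
print(x_crop.shape)             # (3, 4, 4)
```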

Inside the training script, fourcastnet/era5_FCN.py, the ERA5 datasets are initialized as follows:

    # load training/ test data
    channels = list(range(cfg.custom.n_channels))
    train_dataset = ERA5HDF5GridDataset(
        cfg.custom.training_data_path,
        chans=channels,
        tstep=cfg.custom.tstep,
        n_tsteps=cfg.custom.n_tsteps,
        patch_size=cfg.arch.afno.patch_size,
    )
    test_dataset = ERA5HDF5GridDataset(
        cfg.custom.test_data_path,
        chans=channels,
        tstep=cfg.custom.tstep,
        n_tsteps=cfg.custom.n_tsteps,
        patch_size=cfg.arch.afno.patch_size,
        n_samples_per_year=20,
    )

FourCastNet Model

Next, we need to define FourCastNet as a custom Modulus architecture. This model is found in fourcastnet/src/fourcastnet.py, which is a wrapper class around Modulus’ AFNO model. FourCastNet has two training phases: the first uses single-step prediction and the second uses two-step prediction. This small wrapper allows AFNO to be executed for any number of time steps, n_tsteps, using autoregressive forward passes.

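import logging
from typing import Dict, List, Tuple

import torch
from torch import Tensor
from modulus.key import Key
from modulus.models.arch import Arch
from modulus.models.afno.afno import AFNONet


class FourcastNetArch(Arch):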
    "Defines the FourcastNet architecture"

    def __init__(
        self,
        input_keys: List[Key],
        output_keys: List[Key],
        img_shape: Tuple[int, int],
        detach_keys: List[Key] = [],
        patch_size: int = 16,
        embed_dim: int = 256,
        depth: int = 4,
        num_blocks: int = 4,
    ) -> None:
        """Fourcastnet model. This is a simple wrapper for Modulus' AFNO model.
        The only difference is that FourcastNet needs multi-step training. This class
        allows the model to auto-regressively predict multiple timesteps

        Parameters (Same as AFNO)
        ----------
        input_keys : List[Key]
            Input key list. The key dimension size should equal the variables channel dim.
        output_keys : List[Key]
            Output key list. The key dimension size should equal the variables channel dim.
        img_shape : Tuple[int, int]
            Input image dimensions (height, width)
        detach_keys : List[Key], optional
            List of keys to detach gradients, by default []
        patch_size : int, optional
            Size of image patches, by default 16
        embed_dim : int, optional
            Embedded channel size, by default 256
        depth : int, optional
            Number of AFNO layers, by default 4
        num_blocks : int, optional
            Number of blocks in the frequency weight matrices, by default 4
        """
        super().__init__(
            input_keys=input_keys,
            output_keys=output_keys,
            detach_keys=detach_keys,
        )

        # get number of timesteps to unroll
        assert (
            len(self.input_keys) == 1
        ), "Error, FourcastNet only accepts one input variable (x_t0)"
        self.n_tsteps = len(self.output_keys)
        logging.info(f"Unrolling FourcastNet over {self.n_tsteps} timesteps")

        # get number of input/output channels
        in_channels = self.input_keys[0].size
        out_channels = self.output_keys[0].size

        # initialise AFNO kernel
        self._impl = AFNONet(
            in_channels=in_channels,
            out_channels=out_channels,
            patch_size=(patch_size, patch_size),
            img_size=img_shape,
            embed_dim=embed_dim,
            depth=depth,
            num_blocks=num_blocks,
        )

    def forward(self, in_vars: Dict[str, Tensor]) -> Dict[str, Tensor]:

        # prepare input tensor
        x = self.prepare_input(
            input_variables=in_vars,
            mask=self.input_key_dict.keys(),
            detach_dict=self.detach_key_dict,
            dim=1,
            input_scales=self.input_scales,
        )

        # unroll model over multiple timesteps
        ys = []
        for t in range(self.n_tsteps):
            x = self._impl(x)
            ys.append(x)
        y = torch.cat(ys, dim=1)

        # prepare output dict
        return self.prepare_output(
            output_tensor=y,
            output_var=self.output_key_dict,
            dim=1,
            output_scales=self.output_scales,
        )
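The forward pass above stacks the per-step predictions along the channel dimension, so an unroll over n_tsteps steps of C channels produces n_tsteps × C output channels that are then split back into x_t1, x_t2, and so on. The NumPy sketch below illustrates this shape bookkeeping with a hypothetical `toy_step` in place of the AFNO; np.concatenate plays the role of torch.cat, and the final split reflects our understanding of how prepare_output maps channels to output keys of size C.

```python
import numpy as np

C, H, W = 20, 8, 16      # toy grid; FourCastNet uses 720 x 1440
n_tsteps = 2


def toy_step(x):
    """Hypothetical stand-in for the AFNO: maps [B, C, H, W] -> [B, C, H, W]."""
    return 0.5 * x


x = np.ones((1, C, H, W))
ys = []
for _ in range(n_tsteps):
    x = toy_step(x)                 # feed each prediction back in
    ys.append(x)
y = np.concatenate(ys, axis=1)      # plays the role of torch.cat(ys, dim=1)
print(y.shape)                      # (1, 40, 8, 16)

# Split the channel axis back into one C-channel block per output key.
keys = [f"x_t{i + 1}" for i in range(n_tsteps)]
outputs = dict(zip(keys, np.split(y, n_tsteps, axis=1)))
print({k: v.shape for k, v in outputs.items()})
# {'x_t1': (1, 20, 8, 16), 'x_t2': (1, 20, 8, 16)}
```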

The FourCastNet model is initialized in the training script, fourcastnet/era5_FCN.py:

    # define input/output keys
    input_keys = [Key(k, size=train_dataset.nchans) for k in train_dataset.invar_keys]
    output_keys = [Key(k, size=train_dataset.nchans) for k in train_dataset.outvar_keys]

    # make list of nodes to unroll graph on
    model = FourcastNetArch(
        input_keys=input_keys,
        output_keys=output_keys,
        img_shape=test_dataset.img_shape,
        patch_size=cfg.arch.afno.patch_size,
        embed_dim=cfg.arch.afno.embed_dim,
        depth=cfg.arch.afno.depth,
        num_blocks=cfg.arch.afno.num_blocks,
    )
    nodes = [model.make_node(name="FCN")]

Adding Constraints

With the custom dataset for loading the ERA5 data and the FourCastNet model created, the next step is setting up the Modulus training domain. The main training script, fourcastnet/era5_FCN.py, contains the standard steps needed to train a model in Modulus. A standard data-driven grid constraint is created:

    # add constraints to domain
    supervised = SupervisedGridConstraint(
        nodes=nodes,
        dataset=train_dataset,
        batch_size=cfg.batch_size.grid,
        loss=LpLoss(),
        num_workers=cfg.custom.num_workers.grid,
    )
    domain.add_constraint(supervised, "supervised")
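The constraint above uses LpLoss(), whose exact definition lives in the example's source and is not reproduced in this tutorial. To our understanding, the original FourCastNet training code uses a relative L2 error, so the sketch below shows that kind of loss in NumPy. Treat it as an assumption about what LpLoss computes, not as the Modulus implementation; `relative_lp_loss` is a hypothetical name.

```python
import numpy as np


def relative_lp_loss(pred, target, p=2):
    """Batch mean of ||pred - target||_p / ||target||_p, with the norm taken
    over all non-batch dimensions (channels, latitude, longitude)."""
    b = pred.shape[0]
    diff = np.linalg.norm((pred - target).reshape(b, -1), ord=p, axis=1)
    ref = np.linalg.norm(target.reshape(b, -1), ord=p, axis=1)
    return float(np.mean(diff / ref))


target = np.ones((2, 20, 8, 16))
print(relative_lp_loss(target, target))                    # 0.0
print(round(relative_lp_loss(1.1 * target, target), 6))    # 0.1
```

Normalising by the target norm makes the loss insensitive to the overall magnitude of each sample, which is one reason relative errors are popular for neural operators.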

A validator is also added to the training script:

    # add validator
    val = GridValidator(
        nodes,
        dataset=test_dataset,
        batch_size=cfg.batch_size.validation,
        plotter=GridValidatorPlotter(n_examples=5),
        num_workers=cfg.custom.num_workers.validation,
    )
    domain.add_validator(val, "test")

Training the Model

Training can now be started by executing the Python script:

python fcn_era5.py

Results and Post-processing

With the trained model, fourcastnet/inferencer.py is used to calculate the latitude-weighted Root Mean Squared Error (RMSE) and the latitude-weighted Anomaly Correlation Coefficient (ACC). The inferencer script runs the trained model on multiple initial conditions provided in the test dataset. Below, the ACC and RMSE values of the model trained in Modulus are compared with those of the original work [1], showing close agreement. Additionally, a 24-hour forecast of the vertically integrated column of atmospheric water vapor (TCWV) predicted by Modulus is compared with the target ERA5 data.

Note

See the original arXiv paper or src/metrics.py for more details on how these metrics are calculated. Several dataset statistics are needed to properly calculate the metrics of interest.
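For reference, the latitude-weighted RMSE and ACC as defined in the FourCastNet paper [1] can be sketched for a single variable on a regular latitude/longitude grid. The weights are \(L(j) = N_{lat}\cos\phi_j / \sum_j \cos\phi_j\), and the ACC is computed on anomalies relative to a climatology. This NumPy version is illustrative and is not the code in src/metrics.py; the equally spaced latitude grid and the `clim` field are assumptions.

```python
import numpy as np


def lat_weights(n_lat):
    """L(j) = n_lat * cos(lat_j) / sum_j cos(lat_j); the weights average to 1.
    Assumption: equally spaced latitudes from +90 to -90 degrees."""
    cos_lat = np.cos(np.deg2rad(np.linspace(90.0, -90.0, n_lat)))
    return n_lat * cos_lat / cos_lat.sum()


def weighted_rmse(pred, true):
    """Latitude-weighted RMSE for one variable on a [lat, lon] grid."""
    w = lat_weights(pred.shape[0])[:, None]
    return float(np.sqrt(np.mean(w * (pred - true) ** 2)))


def weighted_acc(pred, true, clim):
    """Latitude-weighted ACC of anomalies with respect to a climatology clim."""
    w = lat_weights(pred.shape[0])[:, None]
    a, b = pred - clim, true - clim
    return float(np.sum(w * a * b) / np.sqrt(np.sum(w * a**2) * np.sum(w * b**2)))


rng = np.random.default_rng(0)
clim = rng.standard_normal((9, 16))          # hypothetical climatology field
true = clim + rng.standard_normal((9, 16))   # truth = climatology + anomaly
print(weighted_rmse(true, true))                             # 0.0
print(round(weighted_acc(true, true, clim), 6))              # 1.0
print(round(weighted_acc(2 * clim - true, true, clim), 6))   # -1.0
```

A perfect forecast gives an RMSE of 0 and an ACC of 1, while a forecast whose anomalies have the opposite sign gives an ACC of -1.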


Fig. 99 Comparison of the anomaly correlation coefficient (ACC) of the predicted 10 m u-component of wind (u10) and geopotential height at 500 hPa (z500) between the original FourCastNet model (Original) and the version trained in Modulus.


Fig. 100 Comparison of the root mean squared error (RMSE) of each predicted variable between the original FourCastNet model (Original) and the version trained in Modulus.


Fig. 101 24 hour prediction of the integrated vertical column of atmospheric water vapor predicted by Modulus compared to the ground truth ERA5 dataset from ECMWF.

References

[1] Pathak, Jaideep, et al. “FourCastNet: A global data-driven high-resolution weather model using adaptive Fourier neural operators.” arXiv preprint arXiv:2202.11214 (2022).

[2] Hersbach, Hans, et al. “The ERA5 global reanalysis.” Quarterly Journal of the Royal Meteorological Society (2020).

[3] Hersbach, Hans, et al. “ERA5 hourly data on pressure levels from 1959 to present.” Copernicus Climate Change Service (C3S) Climate Data Store (CDS), doi:10.24381/cds.bd0915c6 (2018).

[4] Hersbach, Hans, et al. “ERA5 hourly data on single levels from 1959 to present.” Copernicus Climate Change Service (C3S) Climate Data Store (CDS), doi:10.24381/cds.adbb2d47 (2018).

[5] Guibas, John, et al. “Adaptive Fourier neural operators: Efficient token mixers for transformers.” International Conference on Learning Representations (2022).