NVIDIA Modulus Core (Latest Release)
Core (Latest Release)


Source code for modulus.utils.graphcast.loss

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import math
from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
from torch.autograd.function import once_differentiable

class CellAreaWeightedLossFunction(nn.Module):
    """Squared-error loss in which every grid cell is scaled by its area.

    Parameters
    ----------
    area : torch.Tensor
        Cell area with shape [H, W].
    """

    def __init__(self, area):
        super().__init__()
        # Kept as a plain attribute (not a registered buffer).
        self.area = area

    def forward(self, invar, outvar):
        """
        Compute the area-weighted loss of a prediction against its target.

        The squared error is averaged over the time and channel dimensions,
        scaled per cell by ``area`` and finally averaged over the grid.

        Parameters
        ----------
        invar : torch.Tensor
            prediction of shape [T, C, H, W].
        outvar : torch.Tensor
            target values of shape [T, C, H, W].

        Returns
        -------
        torch.Tensor
            Scalar loss.
        """
        squared_error = (invar - outvar).pow(2)
        per_cell = squared_error.mean(dim=(0, 1))  # [H, W]
        weighted = torch.mul(per_cell, self.area)
        return weighted.mean()
class CustomCellAreaWeightedLossAutogradFunction(torch.autograd.Function):
    """Autograd function for custom loss with cell area weighting.

    The forward pass computes the area-weighted squared error and, in the same
    pass, its analytic gradient. The backward pass then only rescales the
    saved tensor.
    """

    @staticmethod
    def forward(ctx, invar: torch.Tensor, outvar: torch.Tensor, area: torch.Tensor):
        """Forward of custom loss function with cell area weighting.

        Parameters
        ----------
        ctx
            Autograd context used to stash the precomputed gradient.
        invar : torch.Tensor
            prediction of shape [T, C, H, W].
        outvar : torch.Tensor
            target values of shape [T, C, H, W].
        area : torch.Tensor
            Cell area with shape [H, W].

        Returns
        -------
        torch.Tensor
            Scalar ``mean_{H,W}(area * mean_{T,C}((invar - outvar) ** 2))``.
        """
        diff = invar - outvar  # T x C x H x W
        loss = diff**2
        loss = loss.mean(dim=(0, 1))
        loss = torch.mul(loss, area)
        loss = loss.mean()

        # loss == sum(area * diff**2) / N with N = diff.numel(), hence
        # d loss / d invar = 2 * area * diff / N and d loss / d outvar is its
        # negation. diff.numel() equals prod(invar.shape) when the shapes match,
        # and stays correct if invar/outvar broadcast against each other.
        loss_grad = diff * (2.0 / diff.numel())
        loss_grad *= area.unsqueeze(0).unsqueeze(0)
        ctx.save_for_backward(loss_grad)
        # Input shapes are needed to reduce broadcast gradients in backward.
        ctx.invar_shape = invar.shape
        ctx.outvar_shape = outvar.shape
        return loss

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_loss: torch.Tensor):
        """Backward method of custom loss function with cell area weighting.

        Returns gradients for ``invar`` and ``outvar`` when autograd requests
        them. The gradient for ``area`` is always ``None``, so ``area`` is
        treated as a constant.
        """
        (loss_grad,) = ctx.saved_tensors
        # grad_loss should be 1, multiply nevertheless
        # to avoid issues with cases where this isn't the case
        grad = loss_grad * grad_loss

        grad_invar = None
        grad_outvar = None
        if ctx.needs_input_grad[0]:
            grad_invar = grad.sum_to_size(ctx.invar_shape)
        if ctx.needs_input_grad[1]:
            # Returning None here would make autograd treat the target's
            # gradient as zero, which is wrong when outvar requires grad.
            grad_outvar = (-grad).sum_to_size(ctx.outvar_shape)
        return grad_invar, grad_outvar, None
class CustomCellAreaWeightedLossFunction(CellAreaWeightedLossFunction):
    """Cell-area weighted loss evaluated through a hand-written autograd function.

    It inherits ``area`` handling from ``CellAreaWeightedLossFunction``.
    ``forward`` delegates to ``CustomCellAreaWeightedLossAutogradFunction``.

    Parameters
    ----------
    area : torch.Tensor
        Cell area with shape [H, W].
    """

    def __init__(self, area: torch.Tensor):
        super().__init__(area)

    def forward(self, invar: torch.Tensor, outvar: torch.Tensor) -> torch.Tensor:
        """
        Compute the area-weighted loss of a prediction against its target.

        Parameters
        ----------
        invar : torch.Tensor
            prediction of shape [T, C, H, W].
        outvar : torch.Tensor
            target values of shape [T, C, H, W].

        Returns
        -------
        torch.Tensor
            Scalar loss produced by the custom autograd function.
        """
        apply_loss = CustomCellAreaWeightedLossAutogradFunction.apply
        return apply_loss(invar, outvar, self.area)
class GraphCastLossFunction(nn.Module):
    """Loss function as specified in GraphCast.

    The squared error is handled in four steps:

    1. Divide by the squared per-channel time-difference standard deviation.
    2. Multiply by a per-variable / per-pressure-level weight.
    3. Average over time and channels.
    4. Weight by cell area and average over the grid.

    Parameters
    ----------
    area : torch.Tensor
        Cell area with shape [H, W].
    channels_list : list of int
        Indices into ``coords.channel`` of the metadata JSON, selecting the
        channels in the order they appear along the channel dimension of the
        predictions. The same order is used to index the time-difference std
        and to order ``variable_weights``.
    dataset_metadata_path : str
        Path to the dataset metadata JSON file.
    time_diff_std_path : str or None
        Path to a ``.npy`` file with time-difference standard deviations
        (channel axis 1). If ``None``, a std of 1 is used for every channel.
    """

    def __init__(self, area, channels_list, dataset_metadata_path, time_diff_std_path):
        super().__init__()
        self.area = area
        # Channel names in channels_list order; used to align per-variable
        # weights with the channel dimension of the predictions.
        self.channel_names = self._read_channel_names(
            dataset_metadata_path, channels_list
        )
        self.channel_dict = self.get_channel_dict(dataset_metadata_path, channels_list)
        self.variable_weights = self.assign_variable_weights()
        self.time_diff_std = self.get_time_diff_std(time_diff_std_path, channels_list)

    def forward(self, invar, outvar):
        """
        Implicit forward function which computes the loss given a prediction
        and the corresponding targets.

        Parameters
        ----------
        invar : torch.Tensor
            prediction of shape [T, C, H, W].
        outvar : torch.Tensor
            target values of shape [T, C, H, W].

        Returns
        -------
        torch.Tensor
            Scalar loss.
        """
        loss = (invar - outvar) ** 2  # [T,C,H,W]
        # weighted by inverse variance
        time_diff_std = self.time_diff_std.view(1, -1, 1, 1).to(loss.device)
        loss = loss * 1.0 / torch.square(time_diff_std)
        # weighted by variables
        variable_weights = self.variable_weights.view(1, -1, 1, 1).to(loss.device)
        loss = loss * variable_weights  # [T,C,H,W]
        # weighted by area; moved to the loss device like the weights above
        loss = loss.mean(dim=(0, 1))
        area = torch.as_tensor(self.area, device=loss.device)
        loss = torch.mul(loss, area)
        return loss.mean()

    def get_time_diff_std(self, time_diff_std_path, channels_list):
        """Gets the time difference standard deviation.

        Returns the array loaded from ``time_diff_std_path`` restricted to
        ``channels_list`` along axis 1, or ``tensor([1.0])`` if no path is given.
        """
        if time_diff_std_path is None:
            return torch.tensor([1.0], dtype=torch.float)
        time_diff_np = np.load(time_diff_std_path)
        time_diff_np = time_diff_np[:, channels_list, ...]
        return torch.tensor(time_diff_np, dtype=torch.float)

    @staticmethod
    def _read_channel_names(dataset_metadata_path, channels_list):
        """Return the metadata channel names selected by ``channels_list``, in order."""
        with open(dataset_metadata_path, "r", encoding="utf-8") as f:
            data_json = json.load(f)
        all_channels = data_json["coords"]["channel"]
        return [all_channels[c] for c in channels_list]

    def get_channel_dict(self, dataset_metadata_path, channels_list):
        """Gets lists of surface and atmospheric channels.

        A channel whose name ends in a digit (e.g. ``z500``) is treated as
        atmospheric; every other channel is treated as surface.
        """
        channel_dict = {"surface": [], "atmosphere": []}
        for name in self._read_channel_names(dataset_metadata_path, channels_list):
            key = "atmosphere" if name[-1].isdigit() else "surface"
            channel_dict[key].append(name)
        return channel_dict

    def parse_variable(self, variable_list):
        """Parse a variable name into its prefix and numeric level.

        Splits at the first digit: ``"z500"`` -> ``("z", 500)``. Returns
        ``None`` if the name contains no digit.
        """
        for i, char in enumerate(variable_list):
            if char.isdigit():
                return variable_list[:i], int(variable_list[i:])
        return None

    def calculate_linear_weights(self, variables):
        """Calculate weights for each variable group.

        Variables are grouped by the prefix before their first digit. Within a
        group, each variable's weight is its level divided by the sum of the
        group's levels.
        """
        groups = defaultdict(list)
        for variable in variables:
            prefix, number = self.parse_variable(variable)
            groups[prefix].append((variable, number))

        weights = {}
        for values in groups.values():
            total = sum(number for _, number in values)
            for variable, number in values:
                weights[variable] = number / total
        return weights

    def assign_surface_weights(self):
        """Assigns weights to surface variables: 1 for ``t2m``, 0.1 otherwise."""
        surface_weights = {i: 0.1 for i in self.channel_dict["surface"]}
        if "t2m" in surface_weights:
            surface_weights["t2m"] = 1
        return surface_weights

    def assign_atmosphere_weights(self):
        """Assigns weights to atmospheric variables."""
        return self.calculate_linear_weights(self.channel_dict["atmosphere"])

    def assign_variable_weights(self):
        """Assigns per-variable per-pressure level weights.

        The returned vector has shape [num_channel] and is ordered like
        ``channels_list``, so that it lines up with the channel dimension used
        in ``forward``.
        """
        surface_weights = self.assign_surface_weights()
        atmosphere_weights = self.assign_atmosphere_weights()
        channel_names = getattr(self, "channel_names", None)
        if channel_names is None:
            # No ordering information available (e.g. instance built without
            # __init__): fall back to surface-then-atmosphere order.
            ordered = list(surface_weights.values()) + list(atmosphere_weights.values())
        else:
            lookup = {**surface_weights, **atmosphere_weights}
            ordered = [lookup[name] for name in channel_names]
        return torch.tensor(ordered, dtype=torch.float)  # [num_channel]
© Copyright 2023, NVIDIA Modulus Team. Last updated on Jul 25, 2024.