# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Import libraries
import torch
import logging
import numpy as np
from torch import nn
from typing import Dict, List, Optional, Callable, Union
# Import from Modulus
from modulus.sym.eq.derivatives import gradient
from modulus.sym.hydra import to_absolute_path, add_hydra_run_path
logger = logging.getLogger(__name__)

class Aggregator(nn.Module):
"""
Base class for loss aggregators
"""
def __init__(self, params, num_losses, weights):
super().__init__()
self.params: List[torch.Tensor] = list(params)
self.num_losses: int = num_losses
self.weights: Optional[Dict[str, float]] = weights
self.device: torch.device
self.device = list(set(p.device for p in self.params))[0]
self.init_loss: torch.Tensor = torch.tensor(0.0, device=self.device)
def weigh_losses_initialize(
weights: Optional[Dict[str, float]]
) -> Callable[
[Dict[str, torch.Tensor], Optional[Dict[str, float]]],
Dict[str, torch.Tensor],
]:
if weights is None:
def weigh_losses(
losses: Dict[str, torch.Tensor], weights: None
) -> Dict[str, torch.Tensor]:
return losses
else:
def weigh_losses(
losses: Dict[str, torch.Tensor], weights: Dict[str, float]
) -> Dict[str, torch.Tensor]:
for key in losses.keys():
if key not in weights.keys():
weights.update({key: 1.0})
losses = {key: weights[key] * losses[key] for key in losses.keys()}
return losses
return weigh_losses
self.weigh_losses = weigh_losses_initialize(self.weights)
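
# A sketch of how the weighting closure built above behaves (hypothetical
# names for illustration only: `model` is any torch.nn.Module, `l_int` and
# `l_bnd` are scalar loss tensors):
#
#   agg = Aggregator(model.parameters(), num_losses=2, weights={"interior": 10.0})
#   weighted = agg.weigh_losses({"interior": l_int, "boundary": l_bnd}, agg.weights)
#   # -> {"interior": 10.0 * l_int, "boundary": 1.0 * l_bnd}; keys missing
#   #    from `weights` are assigned a default weight of 1.0.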

class Sum(Aggregator):
"""
Loss aggregation by summation
"""
def __init__(self, params, num_losses, weights=None):
super().__init__(params, num_losses, weights)

    def forward(self, losses: Dict[str, torch.Tensor], step: int) -> torch.Tensor:
"""
Aggregates the losses by summation
Parameters
----------
losses : Dict[str, torch.Tensor]
A dictionary of losses.
step : int
Optimizer step.
Returns
-------
loss : torch.Tensor
Aggregated loss.
"""
# weigh losses
losses = self.weigh_losses(losses, self.weights)
# Initialize loss
loss: torch.Tensor = torch.zeros_like(self.init_loss)
# Add losses
for key in losses.keys():
loss += losses[key]
return loss
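
# A minimal usage sketch for `Sum` (hypothetical toy model and constraint
# names; in a Modulus run the solver constructs the aggregator and supplies
# the losses):
#
#   model = torch.nn.Linear(2, 1)
#   x = torch.randn(8, 2)
#   losses = {"interior": model(x).pow(2).mean(),
#             "boundary": (model(x) - 1.0).pow(2).mean()}
#   aggregator = Sum(model.parameters(), num_losses=2)
#   total = aggregator(losses, step=0)
#   total.backward()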

class GradNorm(Aggregator):
"""
GradNorm for loss aggregation
Reference: "Chen, Z., Badrinarayanan, V., Lee, C.Y. and Rabinovich, A., 2018, July.
Gradnorm: Gradient normalization for adaptive loss balancing in deep multitask networks.
In International Conference on Machine Learning (pp. 794-803). PMLR."
"""
def __init__(self, params, num_losses, alpha=1.0, weights=None):
super().__init__(params, num_losses, weights)
self.alpha: float = alpha
self.lmbda: torch.nn.Parameter = nn.Parameter(
torch.zeros(num_losses, device=self.device)
)
self.register_buffer(
"init_losses", torch.zeros(self.num_losses, device=self.device)
)

    def forward(self, losses: Dict[str, torch.Tensor], step: int) -> torch.Tensor:
"""
        Weights and aggregates the losses using the GradNorm algorithm
Parameters
----------
losses : Dict[str, torch.Tensor]
A dictionary of losses.
step : int
Optimizer step.
Returns
-------
loss : torch.Tensor
Aggregated loss.
"""
# weigh losses
losses = self.weigh_losses(losses, self.weights)
# get initial losses
if step == 0:
for i, key in enumerate(losses.keys()):
self.init_losses[i] = losses[key].clone().detach()
with torch.no_grad():
normalizer: torch.Tensor = self.num_losses / (torch.exp(self.lmbda).sum())
for i in range(self.num_losses):
self.lmbda[i] = self.lmbda[i].clone() + torch.log(
normalizer.detach()
) # c*exp(x) = exp(log(c)+x)
lmbda_exp: torch.Tensor = torch.exp(self.lmbda)
# compute relative losses, inverse rate, and grad coefficient
losses_stacked: torch.Tensor = torch.stack(list(losses.values()))
with torch.no_grad():
relative_losses: torch.Tensor = torch.div(losses_stacked, self.init_losses)
inverse_rate: torch.Tensor = relative_losses / (relative_losses.mean())
gradnorm_coef: torch.Tensor = torch.pow(inverse_rate, self.alpha)
# compute gradient norm and average gradient norm
grads_norm: torch.Tensor = torch.zeros_like(self.init_losses)
shared_params: torch.Tensor = self.params[-2] # TODO generalize this
for i, key in enumerate(losses.keys()):
grads: torch.Tensor = gradient(losses[key], [shared_params])[0]
grads_norm[i] = torch.norm(lmbda_exp[i] * grads.detach(), p=2)
avg_grad: torch.Tensor = grads_norm.detach().mean()
# compute gradnorm & model losses
loss_gradnorm: torch.Tensor = torch.abs(
grads_norm - avg_grad * gradnorm_coef
).sum()
loss_model: torch.Tensor = (lmbda_exp.detach() * losses_stacked).sum()
loss: torch.Tensor = loss_gradnorm + loss_model
return loss
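
# Usage sketch for `GradNorm` (hypothetical toy setup; each loss must be
# differentiable w.r.t. the shared parameters selected above as
# `self.params[-2]`, and the learnable `lmbda` should be exposed to the
# optimizer together with the model parameters):
#
#   model = torch.nn.Linear(2, 1)
#   aggregator = GradNorm(model.parameters(), num_losses=2, alpha=1.0)
#   optimizer = torch.optim.Adam(
#       list(model.parameters()) + list(aggregator.parameters()), lr=1e-3
#   )
#   x = torch.randn(8, 2)
#   losses = {"interior": model(x).pow(2).mean(),
#             "boundary": (model(x) - 1.0).pow(2).mean()}
#   loss = aggregator(losses, step=0)
#   loss.backward()
#   optimizer.step()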

class ResNorm(Aggregator):
"""
Residual normalization for loss aggregation
Contributors: T. Nandi, D. Van Essendelft, M. A. Nabian
"""
def __init__(self, params, num_losses, alpha=1.0, weights=None):
super().__init__(params, num_losses, weights)
self.alpha: float = alpha
self.lmbda: torch.nn.Parameter = nn.Parameter(
torch.zeros(num_losses, device=self.device)
)
self.register_buffer(
"init_losses", torch.zeros(self.num_losses, device=self.device)
)

    def forward(self, losses: Dict[str, torch.Tensor], step: int) -> torch.Tensor:
"""
Weights and aggregates the losses using the ResNorm algorithm
Parameters
----------
losses : Dict[str, torch.Tensor]
A dictionary of losses.
step : int
Optimizer step.
Returns
-------
loss : torch.Tensor
Aggregated loss.
"""
# weigh losses
losses = self.weigh_losses(losses, self.weights)
# get initial losses
if step == 0:
for i, key in enumerate(losses.keys()):
self.init_losses[i] = losses[key].clone().detach()
with torch.no_grad():
normalizer: torch.Tensor = self.num_losses / (torch.exp(self.lmbda).sum())
for i in range(self.num_losses):
self.lmbda[i] = self.lmbda[i].clone() + torch.log(
normalizer.detach()
) # c*exp(x) = exp(log(c)+x)
lmbda_exp: torch.Tensor = torch.exp(self.lmbda)
# compute relative losses, inverse rate, and grad coefficient
losses_stacked: torch.Tensor = torch.stack(list(losses.values()))
with torch.no_grad():
relative_losses: torch.Tensor = torch.div(losses_stacked, self.init_losses)
inverse_rate: torch.Tensor = relative_losses / (relative_losses.mean())
resnorm_coef: torch.Tensor = torch.pow(inverse_rate, self.alpha)
# compute residual norm and average residual norm
residuals: torch.Tensor = torch.zeros_like(self.init_losses)
for i, key in enumerate(losses.keys()):
residuals[i] = lmbda_exp[i] * losses[key].detach()
avg_residuals: torch.Tensor = losses_stacked.detach().mean()
# compute ResNorm & model losses
loss_resnorm: torch.Tensor = torch.abs(
residuals - avg_residuals * resnorm_coef
).sum()
loss_model: torch.Tensor = (lmbda_exp.detach() * losses_stacked).sum()
loss: torch.Tensor = loss_resnorm + loss_model
return loss
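
# A comment-level restatement of the ResNorm objective computed above, with
# L_i the current losses and L_i(0) the initial losses (same interface and
# optimizer handling as the GradNorm sketch above, but no per-parameter
# gradients are required):
#
#   residual_i = exp(lmbda_i) * L_i                        (detached)
#   coef_i     = ((L_i / L_i(0)) / mean_j(L_j / L_j(0)))**alpha
#   loss       = sum_i |residual_i - mean_j(L_j) * coef_i|
#                + sum_i exp(lmbda_i) * L_i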

class HomoscedasticUncertainty(Aggregator):
"""
Homoscedastic task uncertainty for loss aggregation
Reference: "Reference: Kendall, A., Gal, Y. and Cipolla, R., 2018.
Multi-task learning using uncertainty to weigh losses for scene geometry and semantics.
In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 7482-7491)."
"""
def __init__(self, params, num_losses, weights=None):
super().__init__(params, num_losses, weights)
self.log_var: torch.nn.Parameter = nn.Parameter(
torch.zeros(self.num_losses, device=self.device)
)

    def forward(self, losses: Dict[str, torch.Tensor], step: int) -> torch.Tensor:
"""
Weights and aggregates the losses using homoscedastic task uncertainty
Parameters
----------
losses : Dict[str, torch.Tensor]
A dictionary of losses.
step : int
Optimizer step.
Returns
-------
loss : torch.Tensor
Aggregated loss.
"""
# weigh losses
losses = self.weigh_losses(losses, self.weights)
# Initialize loss
loss: torch.Tensor = torch.zeros_like(self.init_loss)
# Compute precision
precision: torch.Tensor = torch.exp(-self.log_var)
# Aggregate losses
for i, key in enumerate(losses.keys()):
loss += precision[i] * losses[key]
loss += self.log_var.sum()
loss /= 2.0
return loss
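
# The forward pass above amounts to, with s_i = log_var[i] the learned
# log-variance of task i:
#
#   loss = 0.5 * (sum_i exp(-s_i) * L_i + sum_i s_i)
#
# Each task loss is scaled by its precision exp(-s_i), while the sum over s_i
# discourages trivially large uncertainties. Since `log_var` is an
# nn.Parameter, `aggregator.parameters()` should be included in the optimizer
# (hypothetical names), e.g.:
#
#   optimizer = torch.optim.Adam(
#       list(model.parameters()) + list(aggregator.parameters()), lr=1e-3
#   )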

class LRAnnealing(Aggregator):
"""
Learning rate annealing for loss aggregation
References: "Wang, S., Teng, Y. and Perdikaris, P., 2020.
Understanding and mitigating gradient pathologies in physics-informed
neural networks. arXiv preprint arXiv:2001.04536.", and
"Jin, X., Cai, S., Li, H. and Karniadakis, G.E., 2021.
NSFnets (Navier-Stokes flow nets): Physics-informed neural networks for the
incompressible Navier-Stokes equations. Journal of Computational Physics, 426, p.109951."
"""
def __init__(
self,
params,
num_losses,
update_freq=1,
alpha=0.01,
ref_key=None,
eps=1e-8,
weights=None,
):
super().__init__(params, num_losses, weights)
self.update_freq: int = update_freq
self.alpha: float = alpha
self.ref_key: Union[str, None] = ref_key
self.eps: float = eps
self.register_buffer(
"lmbda_ema", torch.ones(self.num_losses, device=self.device)
)

    def forward(self, losses: Dict[str, torch.Tensor], step: int) -> torch.Tensor:
"""
Weights and aggregates the losses using the learning rate annealing algorithm
Parameters
----------
losses : Dict[str, torch.Tensor]
A dictionary of losses.
step : int
Optimizer step.
Returns
-------
loss : torch.Tensor
Aggregated loss.
"""
# weigh losses
losses = self.weigh_losses(losses, self.weights)
# Initialize loss
loss: torch.Tensor = torch.zeros_like(self.init_loss)
# Determine reference loss
if self.ref_key is None:
ref_idx = 0
else:
for i, key in enumerate(losses.keys()):
if self.ref_key in key:
ref_idx = i
break
# Update loss weights and aggregate losses
if step % self.update_freq == 0:
grads_mean: List[torch.Tensor] = []
# Compute the mean of each loss gradients
for key in losses.keys():
grads: List[torch.Tensor] = gradient(losses[key], self.params)
grads_flattened: List[torch.Tensor] = []
for i in range(len(grads)):
if grads[i] is not None:
grads_flattened.append(torch.abs(torch.flatten(grads[i])))
grads_mean.append((torch.mean(torch.cat(grads_flattened))))
# Compute the exponential moving average of weights and aggregate losses
for i, key in enumerate(losses.keys()):
with torch.no_grad():
self.lmbda_ema[i] *= 1.0 - self.alpha
self.lmbda_ema[i] += (
self.alpha * grads_mean[ref_idx] / (grads_mean[i] + self.eps)
)
loss += self.lmbda_ema[i].clone() * losses[key]
# Aggregate losses without update to loss weights
else:
for i, key in enumerate(losses.keys()):
loss += self.lmbda_ema[i] * losses[key]
return loss
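
# Every `update_freq` steps the weights above follow an exponential moving
# average driven by the mean absolute gradients g_i of each loss w.r.t. all
# parameters:
#
#   lmbda_ema[i] <- (1 - alpha) * lmbda_ema[i] + alpha * g_ref / (g_i + eps)
#
# where g_ref belongs to the reference loss (index 0, or the first key
# containing `ref_key`). The aggregated loss is then sum_i lmbda_ema[i] * L_i.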

class SoftAdapt(Aggregator):
"""
SoftAdapt for loss aggregation
Reference: "Heydari, A.A., Thompson, C.A. and Mehmood, A., 2019.
Softadapt: Techniques for adaptive loss weighting of neural networks with multi-part loss functions.
    arXiv preprint arXiv:1912.12355."
"""
def __init__(self, params, num_losses, eps=1e-8, weights=None):
super().__init__(params, num_losses, weights)
self.eps: float = eps
self.register_buffer(
"prev_losses", torch.zeros(self.num_losses, device=self.device)
)

    def forward(self, losses: Dict[str, torch.Tensor], step: int) -> torch.Tensor:
"""
Weights and aggregates the losses using the original variant of the softadapt algorithm
Parameters
----------
losses : Dict[str, torch.Tensor]
A dictionary of losses.
step : int
Optimizer step.
Returns
-------
loss : torch.Tensor
Aggregated loss.
"""
# weigh losses
losses = self.weigh_losses(losses, self.weights)
# Initialize loss
loss: torch.Tensor = torch.zeros_like(self.init_loss)
# Aggregate losses by summation at step 0
if step == 0:
for i, key in enumerate(losses.keys()):
loss += losses[key]
self.prev_losses[i] = losses[key].clone().detach()
# Aggregate losses using SoftAdapt for step > 0
else:
lmbda: torch.Tensor = torch.ones_like(self.prev_losses)
lmbda_sum: torch.Tensor = torch.zeros_like(self.init_loss)
losses_stacked: torch.Tensor = torch.stack(list(losses.values()))
normalizer: torch.Tensor = (losses_stacked / self.prev_losses).max()
for i, key in enumerate(losses.keys()):
with torch.no_grad():
lmbda[i] = torch.exp(
losses[key] / (self.prev_losses[i] + self.eps) - normalizer
)
lmbda_sum += lmbda[i]
loss += lmbda[i].clone() * losses[key]
self.prev_losses[i] = losses[key].clone().detach()
loss *= self.num_losses / (lmbda_sum + self.eps)
return loss
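
# For step > 0 the weights above are a softmax-like function of the ratio of
# current to previous losses, shifted by `normalizer` for numerical stability:
#
#   lmbda_i = exp(L_i / (L_i_prev + eps) - max_j(L_j / L_j_prev))
#   loss    = (num_losses / (sum_j lmbda_j + eps)) * sum_i lmbda_i * L_i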

class Relobralo(Aggregator):
"""
Relative loss balancing with random lookback
Reference: "Bischof, R. and Kraus, M., 2021.
Multi-Objective Loss Balancing for Physics-Informed Deep Learning.
arXiv preprint arXiv:2110.09813."
"""
def __init__(
self, params, num_losses, alpha=0.95, beta=0.99, tau=1.0, eps=1e-8, weights=None
):
super().__init__(params, num_losses, weights)
self.alpha: float = alpha
self.beta: float = beta
self.tau: float = tau
self.eps: float = eps
self.register_buffer(
"init_losses", torch.zeros(self.num_losses, device=self.device)
)
self.register_buffer(
"prev_losses", torch.zeros(self.num_losses, device=self.device)
)
self.register_buffer(
"lmbda_ema", torch.ones(self.num_losses, device=self.device)
)

    def forward(self, losses: Dict[str, torch.Tensor], step: int) -> torch.Tensor:
"""
Weights and aggregates the losses using the ReLoBRaLo algorithm
Parameters
----------
losses : Dict[str, torch.Tensor]
A dictionary of losses.
step : int
Optimizer step.
Returns
-------
loss : torch.Tensor
Aggregated loss.
"""
# weigh losses
losses = self.weigh_losses(losses, self.weights)
# Initialize loss
loss: torch.Tensor = torch.zeros_like(self.init_loss)
# Aggregate losses by summation at step 0
if step == 0:
for i, key in enumerate(losses.keys()):
loss += losses[key]
self.init_losses[i] = losses[key].clone().detach()
self.prev_losses[i] = losses[key].clone().detach()
# Aggregate losses using ReLoBRaLo for step > 0
else:
losses_stacked: torch.Tensor = torch.stack(list(losses.values()))
normalizer_prev: torch.Tensor = (
losses_stacked / (self.tau * self.prev_losses)
).max()
normalizer_init: torch.Tensor = (
losses_stacked / (self.tau * self.init_losses)
).max()
rho: torch.Tensor = torch.bernoulli(torch.tensor(self.beta))
with torch.no_grad():
lmbda_prev: torch.Tensor = torch.exp(
losses_stacked / (self.tau * self.prev_losses + self.eps)
- normalizer_prev
)
lmbda_init: torch.Tensor = torch.exp(
losses_stacked / (self.tau * self.init_losses + self.eps)
- normalizer_init
)
lmbda_prev *= self.num_losses / (lmbda_prev.sum() + self.eps)
lmbda_init *= self.num_losses / (lmbda_init.sum() + self.eps)
# Compute the exponential moving average of weights and aggregate losses
for i, key in enumerate(losses.keys()):
with torch.no_grad():
self.lmbda_ema[i] = self.alpha * (
rho * self.lmbda_ema[i].clone() + (1.0 - rho) * lmbda_init[i]
)
self.lmbda_ema[i] += (1.0 - self.alpha) * lmbda_prev[i]
loss += self.lmbda_ema[i].clone() * losses[key]
self.prev_losses[i] = losses[key].clone().detach()
return loss
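
# A comment-level restatement of the ReLoBRaLo update above: with
# softmax-style weights computed against the previous losses (lmbda_prev) and
# the initial losses (lmbda_init), and a Bernoulli(beta) draw rho per step,
#
#   lmbda_ema[i] <- alpha * (rho * lmbda_ema[i] + (1 - rho) * lmbda_init[i])
#                   + (1 - alpha) * lmbda_prev[i]
#   loss          = sum_i lmbda_ema[i] * L_i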

class NTK(nn.Module):
    """
    Neural Tangent Kernel (NTK) based weighting of constraint losses.
    """
def __init__(self, run_per_step: int = 1000, save_name: Union[str, None] = None):
super(NTK, self).__init__()
self.run_per_step = run_per_step
self.if_csv_head = True
self.save_name = (
to_absolute_path(add_hydra_run_path(save_name)) if save_name else None
)
if self.save_name:
logger.warning(
"Cuda graphs does not work when saving NTK values to file! Set `cuda_graphs` to false."
)
def group_ntk(self, model, losses):
        # The items in `losses` should be scalar loss values (e.g., after an MSE reduction).
ntk_value = dict()
for key, loss in losses.items():
grad = torch.autograd.grad(
torch.sqrt(torch.abs(loss)),
model.parameters(),
retain_graph=True,
allow_unused=True,
)
ntk_value[key] = torch.sqrt(
torch.sum(
torch.stack(
[torch.sum(t.detach() ** 2) for t in grad if t is not None],
dim=0,
)
)
)
return ntk_value
def save_ntk(self, ntk_dict, step):
import pandas as pd # TODO: Remove
output_dict = {}
for key, value in ntk_dict.items():
output_dict[key] = value.cpu().numpy()
df = pd.DataFrame(output_dict, index=[step])
df.to_csv(self.save_name + ".csv", mode="a", header=self.if_csv_head)
self.if_csv_head = False

    def forward(self, constraints, ntk_weights, step):
losses = dict()
dict_constraint_losses = dict()
ntk_sum = 0
# Execute constraint forward passes
for key, constraint in constraints.items():
# TODO: Test streaming here
torch.cuda.nvtx.range_push(f"Running Constraint {key}")
constraint.forward()
torch.cuda.nvtx.range_pop()
for key, constraint in constraints.items():
# compute losses
constraint_losses = constraint.loss(step)
if (step % self.run_per_step == 0) and (step > 0):
ntk_dict = self.group_ntk(constraint.model, constraint_losses)
else:
ntk_dict = None
if ntk_dict is not None:
ntk_weights[key] = ntk_dict
if ntk_weights.get(key) is not None:
ntk_sum += torch.sum(
torch.stack(list(ntk_weights[key].values()), dim=0)
)
dict_constraint_losses[key] = constraint_losses
if step == 0: # May not work on restarts
ntk_sum = 1.0
if self.save_name and (step % self.run_per_step == 0) and (step > 0):
self.save_ntk(
{
d_key + "_" + k: v
for d_key, d in ntk_weights.items()
for k, v in d.items()
},
step,
)
for key, constraint_losses in dict_constraint_losses.items():
# add together losses of like kind
for loss_key, value in constraint_losses.items():
if (
ntk_weights.get(key) is None
or ntk_weights[key].get(loss_key) is None
):
ntk_weight = ntk_sum / 1.0
else:
ntk_weight = ntk_sum / ntk_weights[key][loss_key]
if loss_key not in list(losses.keys()):
losses[loss_key] = ntk_weight * value
else:
losses[loss_key] += ntk_weight * value
return losses, ntk_weights
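
# Usage sketch for `NTK` (hypothetical names; `constraints` is expected to be
# a dict of Modulus constraints exposing `.forward()`, `.loss(step)` and
# `.model` as used above, and `ntk_weights` is a dict carried across steps):
#
#   ntk = NTK(run_per_step=1000, save_name=None)
#   ntk_weights = {}
#   losses, ntk_weights = ntk(constraints, ntk_weights, step)
#   loss = aggregator(losses, step)  # e.g. one of the aggregators above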