# 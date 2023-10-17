
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Import libraries
import torch
import logging
import numpy as np
from torch import nn
from typing import Dict, List, Optional, Callable, Union

# Import from Modulus
from modulus.sym.eq.derivatives import gradient
from modulus.sym.hydra import to_absolute_path, add_hydra_run_path

logger = logging.getLogger(__name__)


class Aggregator(nn.Module):
    """
    Base class for loss aggregators

    Parameters
    ----------
    params : iterable of torch.Tensor
        Parameters the aggregator may differentiate the losses against.
        The device of the first parameter is used for internal tensors.
    num_losses : int
        Number of individual losses that will be aggregated.
    weights : Optional[Dict[str, float]]
        Optional static multipliers keyed by loss name. Losses whose key is
        missing from the dictionary are left unscaled (multiplier 1.0).

    Raises
    ------
    ValueError
        If ``params`` is empty, since no device can be inferred.
    """

    def __init__(self, params, num_losses, weights):
        super().__init__()
        self.params: List[torch.Tensor] = list(params)
        self.num_losses: int = num_losses
        self.weights: Optional[Dict[str, float]] = weights
        if not self.params:
            raise ValueError(
                "Aggregator needs at least one parameter to infer the device"
            )
        self.device: torch.device
        # Use the first parameter's device: deterministic even if the
        # parameters happen to live on more than one device.
        self.device = self.params[0].device
        self.init_loss: torch.Tensor = torch.tensor(0.0, device=self.device)

        def weigh_losses_initialize(
            weights: Optional[Dict[str, float]]
        ) -> Callable[
            [Dict[str, torch.Tensor], Optional[Dict[str, float]]],
            Dict[str, torch.Tensor],
        ]:
            # Pick the weighting function once, so forward() does not need to
            # branch on whether static weights were supplied.
            if weights is None:

                def weigh_losses(
                    losses: Dict[str, torch.Tensor], weights: None
                ) -> Dict[str, torch.Tensor]:
                    return losses

            else:

                def weigh_losses(
                    losses: Dict[str, torch.Tensor], weights: Dict[str, float]
                ) -> Dict[str, torch.Tensor]:
                    # Missing keys default to 1.0 without modifying the
                    # dictionary that the caller handed in.
                    return {
                        key: weights.get(key, 1.0) * value
                        for key, value in losses.items()
                    }

            return weigh_losses

        self.weigh_losses = weigh_losses_initialize(self.weights)


class Sum(Aggregator):
    """
    Loss aggregation by summation
    """

    def __init__(self, params, num_losses, weights=None):
        super().__init__(params, num_losses, weights)

    def forward(self, losses: Dict[str, torch.Tensor], step: int) -> torch.Tensor:
        """
        Aggregates the losses by summation

        Parameters
        ----------
        losses : Dict[str, torch.Tensor]
            A dictionary of losses.
        step : int
            Optimizer step.

        Returns
        -------
        loss : torch.Tensor
            Aggregated loss.
        """

        # weigh losses
        losses = self.weigh_losses(losses, self.weights)

        # Initialize loss
        loss: torch.Tensor = torch.zeros_like(self.init_loss)

        # Add losses
        for value in losses.values():
            loss += value
        return loss


class GradNorm(Aggregator):
    """
    GradNorm for loss aggregation
    Reference: "Chen, Z., Badrinarayanan, V., Lee, C.Y. and Rabinovich, A., 2018, July.
    Gradnorm: Gradient normalization for adaptive loss balancing in deep multitask networks.
    In International Conference on Machine Learning (pp. 794-803). PMLR."

    Parameters
    ----------
    params : iterable of torch.Tensor
        Model parameters; ``params[-2]`` is the tensor the per-loss gradient
        norms are measured against.
    num_losses : int
        Number of losses.
    alpha : float
        Exponent applied to the inverse training rate.
    weights : Optional[Dict[str, float]]
        Optional static multipliers applied before aggregation.
    """

    def __init__(self, params, num_losses, alpha=1.0, weights=None):
        super().__init__(params, num_losses, weights)
        self.alpha: float = alpha
        # Learnable log-weights; exp(lmbda) is the effective loss weight.
        self.lmbda: torch.nn.Parameter = nn.Parameter(
            torch.zeros(num_losses, device=self.device)
        )
        self.register_buffer(
            "init_losses", torch.zeros(self.num_losses, device=self.device)
        )

    def forward(self, losses: Dict[str, torch.Tensor], step: int) -> torch.Tensor:
        """
        Weights and aggregates the losses using the gradNorm algorithm

        Parameters
        ----------
        losses : Dict[str, torch.Tensor]
            A dictionary of losses.
        step : int
            Optimizer step.

        Returns
        -------
        loss : torch.Tensor
            Aggregated loss.
        """

        # weigh losses
        losses = self.weigh_losses(losses, self.weights)

        # get initial losses
        if step == 0:
            for i, key in enumerate(losses.keys()):
                self.init_losses[i] = losses[key].clone().detach()

        # Renormalize so that exp(lmbda) sums to num_losses.
        with torch.no_grad():
            normalizer: torch.Tensor = self.num_losses / (torch.exp(self.lmbda).sum())
            # c*exp(x) = exp(log(c)+x)
            self.lmbda += torch.log(normalizer)
        lmbda_exp: torch.Tensor = torch.exp(self.lmbda)

        # compute relative losses, inverse rate, and grad coefficient
        losses_stacked: torch.Tensor = torch.stack(list(losses.values()))
        with torch.no_grad():
            relative_losses: torch.Tensor = torch.div(losses_stacked, self.init_losses)
            inverse_rate: torch.Tensor = relative_losses / (relative_losses.mean())
            gradnorm_coef: torch.Tensor = torch.pow(inverse_rate, self.alpha)

        # compute gradient norm and average gradient norm
        grads_norm: torch.Tensor = torch.zeros_like(self.init_losses)
        shared_params: torch.Tensor = self.params[-2]  # TODO generalize this
        for i, key in enumerate(losses.keys()):
            grads: torch.Tensor = gradient(losses[key], [shared_params])[0]
            # grads are detached so only lmbda receives gradient from this term
            grads_norm[i] = torch.norm(lmbda_exp[i] * grads.detach(), p=2)
        avg_grad: torch.Tensor = grads_norm.detach().mean()

        # compute gradnorm & model losses
        loss_gradnorm: torch.Tensor = torch.abs(
            grads_norm - avg_grad * gradnorm_coef
        ).sum()
        loss_model: torch.Tensor = (lmbda_exp.detach() * losses_stacked).sum()
        loss: torch.Tensor = loss_gradnorm + loss_model
        return loss


class ResNorm(Aggregator):
    """
    Residual normalization for loss aggregation
    Contributors: T. Nandi, D. Van Essendelft, M. A. Nabian

    Parameters
    ----------
    params : iterable of torch.Tensor
        Model parameters (used to infer the device).
    num_losses : int
        Number of losses.
    alpha : float
        Exponent applied to the inverse training rate.
    weights : Optional[Dict[str, float]]
        Optional static multipliers applied before aggregation.
    """

    def __init__(self, params, num_losses, alpha=1.0, weights=None):
        super().__init__(params, num_losses, weights)
        self.alpha: float = alpha
        # Learnable log-weights; exp(lmbda) is the effective loss weight.
        self.lmbda: torch.nn.Parameter = nn.Parameter(
            torch.zeros(num_losses, device=self.device)
        )
        self.register_buffer(
            "init_losses", torch.zeros(self.num_losses, device=self.device)
        )

    def forward(self, losses: Dict[str, torch.Tensor], step: int) -> torch.Tensor:
        """
        Weights and aggregates the losses using the ResNorm algorithm

        Parameters
        ----------
        losses : Dict[str, torch.Tensor]
            A dictionary of losses.
        step : int
            Optimizer step.

        Returns
        -------
        loss : torch.Tensor
            Aggregated loss.
        """

        # weigh losses
        losses = self.weigh_losses(losses, self.weights)

        # get initial losses
        if step == 0:
            for i, key in enumerate(losses.keys()):
                self.init_losses[i] = losses[key].clone().detach()

        # Renormalize so that exp(lmbda) sums to num_losses.
        with torch.no_grad():
            normalizer: torch.Tensor = self.num_losses / (torch.exp(self.lmbda).sum())
            # c*exp(x) = exp(log(c)+x)
            self.lmbda += torch.log(normalizer)
        lmbda_exp: torch.Tensor = torch.exp(self.lmbda)

        # compute relative losses, inverse rate, and grad coefficient
        losses_stacked: torch.Tensor = torch.stack(list(losses.values()))
        with torch.no_grad():
            relative_losses: torch.Tensor = torch.div(losses_stacked, self.init_losses)
            inverse_rate: torch.Tensor = relative_losses / (relative_losses.mean())
            resnorm_coef: torch.Tensor = torch.pow(inverse_rate, self.alpha)

        # compute residual norm and average residual norm
        residuals: torch.Tensor = torch.zeros_like(self.init_losses)
        for i, key in enumerate(losses.keys()):
            # losses are detached so only lmbda receives gradient from this term
            residuals[i] = lmbda_exp[i] * losses[key].detach()
        avg_residuals: torch.Tensor = losses_stacked.detach().mean()

        # compute ResNorm & model losses
        loss_resnorm: torch.Tensor = torch.abs(
            residuals - avg_residuals * resnorm_coef
        ).sum()
        loss_model: torch.Tensor = (lmbda_exp.detach() * losses_stacked).sum()
        loss: torch.Tensor = loss_resnorm + loss_model
        return loss


class HomoscedasticUncertainty(Aggregator):
    """
    Homoscedastic task uncertainty for loss aggregation
    Reference: "Reference: Kendall, A., Gal, Y. and Cipolla, R., 2018.
    Multi-task learning using uncertainty to weigh losses for scene geometry and semantics.
    In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 7482-7491)."
    """

    def __init__(self, params, num_losses, weights=None):
        super().__init__(params, num_losses, weights)
        # One learnable log-variance per loss, initialized to zero.
        self.log_var: torch.nn.Parameter = nn.Parameter(
            torch.zeros(self.num_losses, device=self.device)
        )

    def forward(self, losses: Dict[str, torch.Tensor], step: int) -> torch.Tensor:
        """
        Weights and aggregates the losses using homoscedastic task uncertainty

        Parameters
        ----------
        losses : Dict[str, torch.Tensor]
            A dictionary of losses.
        step : int
            Optimizer step.

        Returns
        -------
        loss : torch.Tensor
            Aggregated loss.
        """

        # weigh losses
        losses = self.weigh_losses(losses, self.weights)

        # Initialize loss
        loss: torch.Tensor = torch.zeros_like(self.init_loss)

        # Compute precision
        precision: torch.Tensor = torch.exp(-self.log_var)

        # Aggregate losses
        for i, key in enumerate(losses.keys()):
            loss += precision[i] * losses[key]

        loss += self.log_var.sum()
        loss /= 2.0
        return loss


class LRAnnealing(Aggregator):
    """
    Learning rate annealing for loss aggregation
    References: "Wang, S., Teng, Y. and Perdikaris, P., 2020.
    Understanding and mitigating gradient pathologies in physics-informed
    neural networks. arXiv preprint arXiv:2001.04536.", and
    "Jin, X., Cai, S., Li, H. and Karniadakis, G.E., 2021.
    NSFnets (Navier-Stokes flow nets): Physics-informed neural networks for the
    incompressible Navier-Stokes equations. Journal of Computational Physics, 426, p.109951."

    Parameters
    ----------
    params : iterable of torch.Tensor
        Model parameters the loss gradients are taken against.
    num_losses : int
        Number of losses.
    update_freq : int
        The loss weights are refreshed when ``step % update_freq == 0``.
    alpha : float
        Mixing factor of the exponential moving average of the weights.
    ref_key : Union[str, None]
        Substring identifying the reference loss; the first loss key that
        contains it is used. ``None`` selects the first loss.
    eps : float
        Small constant added to the gradient statistic in the denominator.
    weights : Optional[Dict[str, float]]
        Optional static multipliers applied before aggregation.
    """

    def __init__(
        self,
        params,
        num_losses,
        update_freq=1,
        alpha=0.01,
        ref_key=None,
        eps=1e-8,
        weights=None,
    ):
        super().__init__(params, num_losses, weights)
        self.update_freq: int = update_freq
        self.alpha: float = alpha
        self.ref_key: Union[str, None] = ref_key
        self.eps: float = eps
        self.register_buffer(
            "lmbda_ema", torch.ones(self.num_losses, device=self.device)
        )

    def forward(self, losses: Dict[str, torch.Tensor], step: int) -> torch.Tensor:
        """
        Weights and aggregates the losses using the learning rate annealing algorithm

        Parameters
        ----------
        losses : Dict[str, torch.Tensor]
            A dictionary of losses.
        step : int
            Optimizer step.

        Returns
        -------
        loss : torch.Tensor
            Aggregated loss.

        Raises
        ------
        ValueError
            If the weights are due for an update and ``ref_key`` is not
            contained in any loss key.
        """

        # weigh losses
        losses = self.weigh_losses(losses, self.weights)

        # Initialize loss
        loss: torch.Tensor = torch.zeros_like(self.init_loss)

        # Determine reference loss
        ref_idx: Optional[int] = None
        if self.ref_key is None:
            ref_idx = 0
        else:
            for i, key in enumerate(losses.keys()):
                if self.ref_key in key:
                    ref_idx = i
                    break

        # Update loss weights and aggregate losses
        if step % self.update_freq == 0:
            # The reference index is only needed on update steps; fail with a
            # clear message instead of an unbound-variable error.
            if ref_idx is None:
                raise ValueError(
                    f"ref_key '{self.ref_key}' does not match any loss key: "
                    f"{list(losses.keys())}"
                )

            grads_mean: List[torch.Tensor] = []

            # Compute the mean of each loss gradients
            for key in losses.keys():
                grads: List[torch.Tensor] = gradient(losses[key], self.params)
                # Entries may be None for parameters the loss does not touch.
                grads_flattened: List[torch.Tensor] = [
                    torch.abs(torch.flatten(g)) for g in grads if g is not None
                ]
                grads_mean.append(torch.mean(torch.cat(grads_flattened)))

            # Compute the exponential moving average of weights and aggregate losses
            for i, key in enumerate(losses.keys()):
                with torch.no_grad():
                    self.lmbda_ema[i] *= 1.0 - self.alpha
                    self.lmbda_ema[i] += (
                        self.alpha * grads_mean[ref_idx] / (grads_mean[i] + self.eps)
                    )
                # clone so later in-place buffer updates cannot invalidate the graph
                loss += self.lmbda_ema[i].clone() * losses[key]

        # Aggregate losses without update to loss weights
        else:
            for i, key in enumerate(losses.keys()):
                loss += self.lmbda_ema[i] * losses[key]
        return loss


class SoftAdapt(Aggregator):
    """
    SoftAdapt for loss aggregation
    Reference: "Heydari, A.A., Thompson, C.A. and Mehmood, A., 2019.
    Softadapt: Techniques for adaptive loss weighting of neural networks with multi-part loss functions.
    arXiv preprint arXiv: 1912.12355."

    Parameters
    ----------
    params : iterable of torch.Tensor
        Model parameters (used to infer the device).
    num_losses : int
        Number of losses.
    eps : float
        Small constant guarding the divisions by the previous losses and by
        the sum of the weights.
    weights : Optional[Dict[str, float]]
        Optional static multipliers applied before aggregation.
    """

    def __init__(self, params, num_losses, eps=1e-8, weights=None):
        super().__init__(params, num_losses, weights)
        self.eps: float = eps
        self.register_buffer(
            "prev_losses", torch.zeros(self.num_losses, device=self.device)
        )

    def forward(self, losses: Dict[str, torch.Tensor], step: int) -> torch.Tensor:
        """
        Weights and aggregates the losses using the original variant of the softadapt algorithm

        Parameters
        ----------
        losses : Dict[str, torch.Tensor]
            A dictionary of losses.
        step : int
            Optimizer step.

        Returns
        -------
        loss : torch.Tensor
            Aggregated loss.
        """

        # weigh losses
        losses = self.weigh_losses(losses, self.weights)

        # Initialize loss
        loss: torch.Tensor = torch.zeros_like(self.init_loss)

        # Aggregate losses by summation at step 0
        if step == 0:
            for i, key in enumerate(losses.keys()):
                loss += losses[key]
                self.prev_losses[i] = losses[key].clone().detach()

        # Aggregate losses using SoftAdapt for step > 0
        else:
            lmbda: torch.Tensor = torch.ones_like(self.prev_losses)
            lmbda_sum: torch.Tensor = torch.zeros_like(self.init_loss)
            losses_stacked: torch.Tensor = torch.stack(list(losses.values()))
            with torch.no_grad():
                # The normalizer must be the max of the very same eps-guarded
                # ratios that are exponentiated; otherwise a zero previous
                # loss yields inf/nan and wipes out every weight.
                ratios: torch.Tensor = losses_stacked / (self.prev_losses + self.eps)
                normalizer: torch.Tensor = ratios.max()
            for i, key in enumerate(losses.keys()):
                with torch.no_grad():
                    lmbda[i] = torch.exp(ratios[i] - normalizer)
                    lmbda_sum += lmbda[i]
                loss += lmbda[i].clone() * losses[key]
                self.prev_losses[i] = losses[key].clone().detach()
            loss *= self.num_losses / (lmbda_sum + self.eps)
        return loss


class Relobralo(Aggregator):
    """
    Relative loss balancing with random lookback
    Reference: "Bischof, R. and Kraus, M., 2021.
    Multi-Objective Loss Balancing for Physics-Informed Deep Learning.
    arXiv preprint arXiv:2110.09813."

    Parameters
    ----------
    params : iterable of torch.Tensor
        Model parameters (used to infer the device).
    num_losses : int
        Number of losses.
    alpha : float
        Mixing factor of the exponential moving average of the weights.
    beta : float
        Probability of the Bernoulli draw ``rho``; ``rho == 1`` keeps the
        running weights, ``rho == 0`` looks back to the initial losses.
    tau : float
        Factor multiplying the reference losses in the loss ratios.
    eps : float
        Small constant guarding the divisions.
    weights : Optional[Dict[str, float]]
        Optional static multipliers applied before aggregation.
    """

    def __init__(
        self, params, num_losses, alpha=0.95, beta=0.99, tau=1.0, eps=1e-8, weights=None
    ):
        super().__init__(params, num_losses, weights)
        self.alpha: float = alpha
        self.beta: float = beta
        self.tau: float = tau
        self.eps: float = eps
        self.register_buffer(
            "init_losses", torch.zeros(self.num_losses, device=self.device)
        )
        self.register_buffer(
            "prev_losses", torch.zeros(self.num_losses, device=self.device)
        )
        self.register_buffer(
            "lmbda_ema", torch.ones(self.num_losses, device=self.device)
        )

    def forward(self, losses: Dict[str, torch.Tensor], step: int) -> torch.Tensor:
        """
        Weights and aggregates the losses using the ReLoBRaLo algorithm

        Parameters
        ----------
        losses : Dict[str, torch.Tensor]
            A dictionary of losses.
        step : int
            Optimizer step.

        Returns
        -------
        loss : torch.Tensor
            Aggregated loss.
        """

        # weigh losses
        losses = self.weigh_losses(losses, self.weights)

        # Initialize loss
        loss: torch.Tensor = torch.zeros_like(self.init_loss)

        # Aggregate losses by summation at step 0
        if step == 0:
            for i, key in enumerate(losses.keys()):
                loss += losses[key]
                self.init_losses[i] = losses[key].clone().detach()
                self.prev_losses[i] = losses[key].clone().detach()

        # Aggregate losses using ReLoBRaLo for step > 0
        else:
            losses_stacked: torch.Tensor = torch.stack(list(losses.values()))
            rho: torch.Tensor = torch.bernoulli(torch.tensor(self.beta))
            with torch.no_grad():
                # The normalizers are the max of the same eps-guarded ratios
                # that are exponentiated, so a zero reference loss cannot
                # produce inf/nan weights.
                ratio_prev: torch.Tensor = losses_stacked / (
                    self.tau * self.prev_losses + self.eps
                )
                ratio_init: torch.Tensor = losses_stacked / (
                    self.tau * self.init_losses + self.eps
                )
                lmbda_prev: torch.Tensor = torch.exp(ratio_prev - ratio_prev.max())
                lmbda_init: torch.Tensor = torch.exp(ratio_init - ratio_init.max())
                lmbda_prev *= self.num_losses / (lmbda_prev.sum() + self.eps)
                lmbda_init *= self.num_losses / (lmbda_init.sum() + self.eps)

            # Compute the exponential moving average of weights and aggregate losses
            for i, key in enumerate(losses.keys()):
                with torch.no_grad():
                    self.lmbda_ema[i] = self.alpha * (
                        rho * self.lmbda_ema[i].clone() + (1.0 - rho) * lmbda_init[i]
                    )
                    self.lmbda_ema[i] += (1.0 - self.alpha) * lmbda_prev[i]
                loss += self.lmbda_ema[i].clone() * losses[key]
                self.prev_losses[i] = losses[key].clone().detach()
        return loss


class NTK(nn.Module):
    """
    Loss weighting based on per-loss gradient norms

    Every ``run_per_step`` steps (for ``step > 0``) the norm of the gradient
    of ``sqrt(|loss|)`` with respect to the constraint model's parameters is
    computed for each loss. Each loss is then scaled by the sum of all stored
    norms divided by its own norm.

    Parameters
    ----------
    run_per_step : int
        Interval, in steps, between recomputations of the norms.
    save_name : Union[str, None]
        If given, the norms are appended to ``<save_name>.csv`` each time
        they are recomputed.
    """

    def __init__(self, run_per_step: int = 1000, save_name: Union[str, None] = None):
        super(NTK, self).__init__()
        self.run_per_step = run_per_step
        # The CSV header is written only with the first appended row.
        self.if_csv_head = True
        self.save_name = (
            to_absolute_path(add_hydra_run_path(save_name)) if save_name else None
        )
        if self.save_name:
            logger.warning(
                "Cuda graphs does not work when saving NTK values to file! Set `cuda_graphs` to false."
            )

    def group_ntk(self, model, losses):
        """
        Computes, for each loss, the L2 norm of the gradient of
        ``sqrt(|loss|)`` with respect to ``model.parameters()``

        Parameters
        ----------
        model : object exposing ``parameters()``
            Model whose parameters the gradients are taken against.
        losses : Dict[str, torch.Tensor]
            Scalar loss values (e.g. after MSE), keyed by name.

        Returns
        -------
        ntk_value : Dict[str, torch.Tensor]
            Gradient norm for each loss key.
        """
        # The item in this losses should scalar loss values after MSE, etc.
        ntk_value = dict()
        for key, loss in losses.items():
            grad = torch.autograd.grad(
                torch.sqrt(torch.abs(loss)),
                model.parameters(),
                retain_graph=True,
                allow_unused=True,
            )
            # Parameters unused by this loss yield None and are skipped.
            ntk_value[key] = torch.sqrt(
                torch.sum(
                    torch.stack(
                        [torch.sum(t.detach() ** 2) for t in grad if t is not None],
                        dim=0,
                    )
                )
            )
        return ntk_value

    def save_ntk(self, ntk_dict, step):
        """
        Appends one row of values, indexed by ``step``, to ``<save_name>.csv``

        Parameters
        ----------
        ntk_dict : Dict[str, torch.Tensor]
            Values to write, one column per key.
        step : int
            Optimizer step, used as the row index.
        """
        import pandas as pd  # TODO: Remove

        output_dict = {key: value.cpu().numpy() for key, value in ntk_dict.items()}
        df = pd.DataFrame(output_dict, index=[step])
        df.to_csv(self.save_name + ".csv", mode="a", header=self.if_csv_head)
        self.if_csv_head = False

    def forward(self, constraints, ntk_weights, step):
        """
        Runs all constraints and returns their losses scaled by the stored weights

        Parameters
        ----------
        constraints : Dict[str, constraint]
            Constraints exposing ``forward()``, ``loss(step)`` and ``model``.
        ntk_weights : Dict[str, Dict[str, torch.Tensor]]
            Stored norms per constraint and loss key. Updated in place when
            the norms are recomputed.
        step : int
            Optimizer step.

        Returns
        -------
        losses : Dict[str, torch.Tensor]
            Weighted losses, summed over constraints per loss key.
        ntk_weights : Dict[str, Dict[str, torch.Tensor]]
            The (possibly updated) stored norms.
        """
        losses = dict()
        dict_constraint_losses = dict()
        ntk_sum = 0
        has_ntk = False
        recompute = (step % self.run_per_step == 0) and (step > 0)

        # Execute constraint forward passes
        for key, constraint in constraints.items():
            # TODO: Test streaming here
            torch.cuda.nvtx.range_push(f"Running Constraint {key}")
            constraint.forward()
            torch.cuda.nvtx.range_pop()

        for key, constraint in constraints.items():
            # compute losses
            constraint_losses = constraint.loss(step)
            if recompute:
                ntk_weights[key] = self.group_ntk(constraint.model, constraint_losses)
            if ntk_weights.get(key) is not None:
                ntk_sum += torch.sum(
                    torch.stack(list(ntk_weights[key].values()), dim=0)
                )
                has_ntk = True
            dict_constraint_losses[key] = constraint_losses

        # Without any stored norms (e.g. a restart between two recomputations)
        # the sum would stay 0 and silently zero out every loss, so fall back
        # to unit weighting in that case as well as at step 0.
        if step == 0 or not has_ntk:  # May not work on restarts
            ntk_sum = 1.0

        if self.save_name and recompute:
            self.save_ntk(
                {
                    d_key + "_" + k: v
                    for d_key, d in ntk_weights.items()
                    for k, v in d.items()
                },
                step,
            )

        for key, constraint_losses in dict_constraint_losses.items():
            # add together losses of like kind
            for loss_key, value in constraint_losses.items():
                if (
                    ntk_weights.get(key) is None
                    or ntk_weights[key].get(loss_key) is None
                ):
                    ntk_weight = ntk_sum
                else:
                    ntk_weight = ntk_sum / ntk_weights[key][loss_key]
                if loss_key not in losses:
                    losses[loss_key] = ntk_weight * value
                else:
                    losses[loss_key] += ntk_weight * value
        return losses, ntk_weights