# Source code for modulus.domain.validator.validator

import torch

class Validator:
    """Base class for validators.

    This base class does not define ``self.model``; subclasses are expected
    to provide it (a callable applied to the input variables) and to
    override ``save_results``.
    """

    def forward_grad(self, invar):
        """Run ``self.model`` on ``invar`` without touching autograd state.

        Args:
            invar: Passed unchanged to ``self.model`` (presumably a mapping of
                named tensors — confirm against subclasses).

        Returns:
            The value returned by ``self.model(invar)``.
        """
        return self.model(invar)

    def forward_nograd(self, invar):
        """Run ``self.model`` on ``invar`` with gradient tracking disabled.

        Args:
            invar: Passed unchanged to ``self.model``.

        Returns:
            The value returned by ``self.model(invar)``, computed under
            ``torch.no_grad()``.
        """
        # The no_grad context is exited (restoring the previous grad mode)
        # before the result is handed back to the caller.
        with torch.no_grad():
            return self.model(invar)

    def save_results(self, name, results_dir, writer, save_filetypes, step):
        """Persist validation results; subclasses must override this.

        Raises:
            NotImplementedError: Always, when called on the base class.
        """
        raise NotImplementedError("Subclass of Validator needs to implement this")

    @staticmethod
    def _l2_relative_error(true_var, pred_var):
        """Compute a variance-normalised RMS error for every reference key.

        For each ``key`` in ``true_var`` the value is
        ``sqrt(mean((true - pred) ** 2) / var(true))``, where ``var`` is
        ``torch.var`` with its default correction.

        Args:
            true_var: Mapping of reference tensors.
            pred_var: Mapping of predicted tensors; it must contain every key
                of ``true_var``. Tensors are assumed broadcast-compatible with
                their references — TODO confirm with callers.

        Returns:
            A new dict mapping ``"l2_relative_error_" + str(key)`` to the
            error tensor, in ``true_var`` iteration order. A constant reference
            (zero variance) yields a non-finite value.
        """
        # TODO replace with metric classes
        def _normalized_rms(target, estimate):
            mean_sq_err = torch.mean(torch.square(target - estimate))
            return torch.sqrt(mean_sq_err / torch.var(target))

        return {
            "l2_relative_error_" + str(key): _normalized_rms(
                true_var[key], pred_var[key]
            )
            for key in true_var
        }
# © Copyright 2021-2022, NVIDIA. Last updated on Apr 26, 2023.