deeplearning/modulus/modulus-sym/_modules/modulus/sym/hydra/training.html

Source code for modulus.sym.hydra.training

# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Supported modulus training paradigms
"""

import torch

from dataclasses import dataclass
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING, II
from typing import Any

from .loss import NTKConf


[docs]@dataclass class TrainingConf: max_steps: int = MISSING grad_agg_freq: int = MISSING rec_results_freq: int = MISSING rec_validation_freq: int = MISSING rec_inference_freq: int = MISSING rec_monitor_freq: int = MISSING rec_constraint_freq: int = MISSING save_network_freq: int = MISSING print_stats_freq: int = MISSING summary_freq: int = MISSING amp: bool = MISSING amp_dtype: str = MISSING
[docs]@dataclass class DefaultTraining(TrainingConf): max_steps: int = 10000 grad_agg_freq: int = 1 rec_results_freq: int = 1000 rec_validation_freq: int = II("training.rec_results_freq") rec_inference_freq: int = II("training.rec_results_freq") rec_monitor_freq: int = II("training.rec_results_freq") rec_constraint_freq: int = II("training.rec_results_freq") save_network_freq: int = 1000 print_stats_freq: int = 100 summary_freq: int = 1000 amp: bool = False amp_dtype: str = "float16" ntk: NTKConf = NTKConf()
[docs]@dataclass class VariationalTraining(DefaultTraining): test_function: str = MISSING use_quadratures: bool = False
[docs]@dataclass class StopCriterionConf: metric: Any = MISSING min_delta: Any = MISSING patience: int = MISSING mode: str = MISSING freq: int = MISSING strict: bool = MISSING
[docs]@dataclass class DefaultStopCriterion(StopCriterionConf): metric: Any = None min_delta: Any = None patience: int = 50000 mode: str = "min" freq: int = 1000 strict: bool = False
[docs]def register_training_configs() -> None: cs = ConfigStore.instance() cs.store( group="training", name="default_training", node=DefaultTraining, ) cs.store( group="training", name="variational_training", node=VariationalTraining, ) cs.store( group="stop_criterion", name="default_stop_criterion", node=DefaultStopCriterion, )
© Copyright 2023, NVIDIA Modulus Team. Last updated on Jan 25, 2024.