# Source code for nemo_automodel.training.rng
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import numpy as np
import torch
def init_all_rng(seed: int, ranked: bool = False):
    """Initialize RNGs for Python, NumPy, and PyTorch (incl. CUDA) with a seed.

    Args:
        seed (int): Base seed value. Must be a positive integer.
        ranked (bool): If True and a ``torch.distributed`` process group is
            available and initialized, add the process rank to ``seed`` so
            each rank gets a different random stream.

    Raises:
        AssertionError: If ``seed`` is not a positive ``int`` or ``ranked`` is
            not a ``bool``.
    """
    assert isinstance(seed, int) and seed > 0, "Seed must be a positive integer"
    assert isinstance(ranked, bool), "Ranked must be a boolean"
    if ranked:
        try:
            import torch.distributed as dist
        except ImportError:
            dist = None
        # On torch builds without distributed support, ``torch.distributed``
        # may not define ``is_initialized``; ``is_available()`` guards that.
        if dist is not None and dist.is_available() and dist.is_initialized():
            seed += dist.get_rank()
    random.seed(seed)
    # NumPy's legacy global seeding only accepts values in [0, 2**32 - 1].
    # Fold larger seeds (e.g. a large base seed plus a rank offset) into that
    # range instead of raising; seeds already in range are unchanged.
    np.random.seed(seed % 2**32)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
class StatefulRNG:
    """Context manager for reproducible RNG states across random, NumPy, and PyTorch.

    Constructing an instance seeds every RNG through ``init_all_rng``.
    Entering the context snapshots the current RNG states, and exiting
    restores that snapshot. Random draws made inside the ``with`` block
    therefore do not advance the streams seen by code after it. The states
    captured before seeding are restored when the object is finalized.
    """

    def __init__(self, seed: int, ranked: bool = False):
        """Initialize and optionally rank-adjust RNGs with a given seed.

        Args:
            seed (int): Base seed for RNGs.
            ranked (bool): Adjust seed based on process rank.
        """
        # Snapshot the pre-seeding states so ``__del__`` can put them back.
        self._init_state = self.state_dict()
        self._saved_state = None
        self.seed = seed
        self.ranked = ranked
        # Apply the seed to random, NumPy, and torch (incl. CUDA).
        init_all_rng(seed, ranked)

    def __del__(self):
        """Restore the RNG states captured before this object seeded them."""
        # ``__init__`` may have failed before ``_init_state`` was assigned (or
        # the instance was created without ``__init__``); nothing to restore.
        # NOTE(review): finalization timing is up to the garbage collector, so
        # this global restore can happen at an arbitrary point -- confirm intent.
        init_state = getattr(self, "_init_state", None)
        if init_state is not None:
            self.load_state_dict(init_state)

    def state_dict(self):
        """Get current RNG states.

        Returns:
            dict: RNG states for random, NumPy, and PyTorch. The keys are
            ``random_rng_state``, ``np_rng_state``, ``torch_rng_state`` and
            ``cuda_rng_state`` (a list with one entry per visible CUDA device).
        """
        return {
            "random_rng_state": random.getstate(),
            "np_rng_state": np.random.get_state(),
            "torch_rng_state": torch.get_rng_state(),
            "cuda_rng_state": torch.cuda.get_rng_state_all(),
        }

    def load_state_dict(self, state):  # pragma: no cover
        """Restore RNG states from a saved state.

        Args:
            state (dict): RNG states as returned by ``state_dict()``.
        """
        random.setstate(state["random_rng_state"])
        np.random.set_state(state["np_rng_state"])
        torch.set_rng_state(state["torch_rng_state"])
        torch.cuda.set_rng_state_all(state["cuda_rng_state"])

    def __enter__(self):
        """Save current RNG states and return ``self``.

        Raises:
            AssertionError: If the context is already active (not reentrant).
        """
        assert self._saved_state is None, "StatefulRNG context is not reentrant"
        self._saved_state = self.state_dict()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the RNG states saved by ``__enter__``.

        Exceptions raised inside the ``with`` block are not suppressed.
        """
        self.load_state_dict(self._saved_state)
        self._saved_state = None