nemo_automodel.components.training.rng

Module Contents

Classes

RNGState

StatefulRNG

RNG manager for reproducible RNG states across random, NumPy, and PyTorch.

ScopedRNG

Context manager for reproducible RNG states across random, NumPy, and PyTorch.

Functions

init_all_rng

Initialize RNGs for Python, NumPy, and PyTorch (incl. CUDA) with a seed.

_get_rng_state

Get current RNG states.

_restore_rng_state

Restore RNG states from a saved state.

API

nemo_automodel.components.training.rng.init_all_rng(seed: int, ranked: bool = False)

Initialize RNGs for Python, NumPy, and PyTorch (including CUDA) with a seed.

Parameters:
  • seed (int) – Base seed value.

  • ranked (bool) – If True, adjust the seed by the process rank so that each rank is seeded differently.
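The seeding step can be pictured with the minimal sketch below. seed_everything is a hypothetical stand-in, not the library's function, and it assumes that ranked=True adds the process rank (read here from the RANK environment variable that launchers such as torchrun set) to the base seed; the real rank lookup and offset in init_all_rng may differ. NumPy and PyTorch are imported optionally so the sketch runs without them.

```python
import os
import random

try:  # NumPy and PyTorch are optional so the sketch runs without them
    import numpy as np
except ImportError:
    np = None
try:
    import torch
except ImportError:
    torch = None


def seed_everything(seed: int, ranked: bool = False) -> int:
    """Hypothetical stand-in for init_all_rng; returns the seed actually applied."""
    if ranked:
        # Assumption: the rank comes from the RANK environment variable
        # (set by launchers such as torchrun) and is added to the base seed.
        seed += int(os.environ.get("RANK", "0"))
    random.seed(seed)
    if np is not None:
        np.random.seed(seed)
    if torch is not None:
        # torch.manual_seed seeds the CPU generator and every CUDA device.
        torch.manual_seed(seed)
    return seed


# Re-seeding with the same value replays the same sequence of draws.
seed_everything(1234)
draws = [random.random() for _ in range(3)]
seed_everything(1234)
assert draws == [random.random() for _ in range(3)]
```

An additive rank offset is the simplest way to give every rank a distinct but still reproducible stream; it is shown only as one plausible choice.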

class nemo_automodel.components.training.rng.RNGState
random_rng_state: tuple

State of Python's random module.

np_rng_state: tuple

State of NumPy's global RNG.

torch_rng_state: torch.Tensor

State of PyTorch's CPU RNG.

cuda_rng_state: torch.Tensor

State of PyTorch's CUDA RNG.

nemo_automodel.components.training.rng._get_rng_state()

Get current RNG states.

Returns:

RNG states for random, NumPy, and PyTorch.

Return type:

dict

nemo_automodel.components.training.rng._restore_rng_state(state)

Restore RNG states from a saved state.

Parameters:

state (dict) – RNG states as returned by _get_rng_state().
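The capture/restore pair can be sketched with the standard getter and setter of each library. get_rng_state and restore_rng_state below are hypothetical helpers, not the module's private functions; the dictionary keys are assumed to mirror the RNGState field names, and the CUDA entry assumes a single tensor for the current device, matching the torch.Tensor type documented for cuda_rng_state.

```python
import random

try:  # optional dependencies; the sketch degrades gracefully without them
    import numpy as np
except ImportError:
    np = None
try:
    import torch
except ImportError:
    torch = None


def get_rng_state() -> dict:
    """Hypothetical helper: snapshot Python, NumPy, and PyTorch RNG states."""
    state = {"random_rng_state": random.getstate()}
    if np is not None:
        state["np_rng_state"] = np.random.get_state()
    if torch is not None:
        state["torch_rng_state"] = torch.get_rng_state()
        if torch.cuda.is_available():
            # State tensor of the current CUDA device only (assumption).
            state["cuda_rng_state"] = torch.cuda.get_rng_state()
    return state


def restore_rng_state(state: dict) -> None:
    """Hypothetical helper: restore a snapshot produced by get_rng_state."""
    random.setstate(state["random_rng_state"])
    if np is not None and "np_rng_state" in state:
        np.random.set_state(state["np_rng_state"])
    if torch is not None and "torch_rng_state" in state:
        torch.set_rng_state(state["torch_rng_state"])
        if "cuda_rng_state" in state:
            torch.cuda.set_rng_state(state["cuda_rng_state"])


# Round trip: draws made after restoring repeat the draws made after the snapshot.
snapshot = get_rng_state()
first = [random.random() for _ in range(3)]
restore_rng_state(snapshot)
second = [random.random() for _ in range(3)]
assert first == second
```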

class nemo_automodel.components.training.rng.StatefulRNG(seed: int, ranked: bool = False)

RNG manager for reproducible RNG states across random, NumPy, and PyTorch.

Initialization

Initialize and optionally rank-adjust RNGs with a given seed.

Parameters:
  • seed (int) – Base seed for RNGs.

  • ranked (bool) – If True, adjust the seed based on the process rank.

state_dict()

Get current RNG states.

Returns:

RNG states for random, NumPy, and PyTorch.

Return type:

dict

load_state_dict(state)

Restore RNG states from a saved state.

Parameters:

state (dict) – RNG states as returned by state_dict().
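A minimal sketch of how a StatefulRNG-style object fits into checkpointing: state_dict() is saved with the rest of the checkpoint, and load_state_dict() replays it on resume. StatefulRNGSketch is hypothetical. For brevity it tracks only Python's random module and NumPy (PyTorch would follow the same pattern with torch.get_rng_state and torch.set_rng_state), it omits the ranked option, and pickle stands in for whatever checkpoint format the trainer actually uses.

```python
import pickle
import random

try:  # NumPy is optional so the sketch runs without it
    import numpy as np
except ImportError:
    np = None


class StatefulRNGSketch:
    """Hypothetical, simplified analogue of StatefulRNG (Python + NumPy only)."""

    def __init__(self, seed: int):
        # Seed every tracked RNG once, at construction time.
        random.seed(seed)
        if np is not None:
            np.random.seed(seed)

    def state_dict(self) -> dict:
        # Capture the current global states without consuming any draws.
        state = {"random_rng_state": random.getstate()}
        if np is not None:
            state["np_rng_state"] = np.random.get_state()
        return state

    def load_state_dict(self, state: dict) -> None:
        # Put every tracked RNG back into the captured state.
        random.setstate(state["random_rng_state"])
        if np is not None and "np_rng_state" in state:
            np.random.set_state(state["np_rng_state"])


# Save the RNG state with the checkpoint, keep training, then resume.
rng = StatefulRNGSketch(seed=1234)
checkpoint = pickle.dumps({"rng": rng.state_dict()})
draw_after_save = random.random()
rng.load_state_dict(pickle.loads(checkpoint)["rng"])
assert random.random() == draw_after_save
```

Restoring the saved states, rather than re-seeding, is what lets a resumed run continue the exact random stream of the interrupted one.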

class nemo_automodel.components.training.rng.ScopedRNG(seed: int = 95050, ranked: bool = False)

Context manager for reproducible RNG states across random, NumPy, and PyTorch.

Initialization

Initialize and optionally rank-adjust RNGs with a given seed.

Parameters:
  • seed (int) – Base seed for RNGs.

  • ranked (bool) – If True, adjust the seed based on the process rank.

__enter__()

Save current RNG states.

__exit__(exc_type, exc_value, traceback)

Restore the RNG states saved by __enter__() on context exit.
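The scoping behaviour can be illustrated with a minimal sketch: code inside the with block sees freshly seeded RNGs, and the caller's RNG stream continues afterwards as if the block never ran. ScopedRNGSketch is hypothetical; it assumes seeding happens on entry, after the caller's states are saved (the library may order these steps differently), tracks only Python's random module and NumPy, and omits the ranked option.

```python
import random

try:  # NumPy is optional so the sketch runs without it
    import numpy as np
except ImportError:
    np = None


class ScopedRNGSketch:
    """Hypothetical analogue of ScopedRNG: seeded inside, caller state restored on exit."""

    def __init__(self, seed: int = 95050):
        self.seed = seed
        self._saved = None

    def __enter__(self):
        # Save the caller's states first, then seed for the scoped block.
        self._saved = (
            random.getstate(),
            np.random.get_state() if np is not None else None,
        )
        random.seed(self.seed)
        if np is not None:
            np.random.seed(self.seed)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Runs on normal exit and on exceptions alike.
        py_state, np_state = self._saved
        random.setstate(py_state)
        if np_state is not None:
            np.random.set_state(np_state)
        return False  # never swallow exceptions


# Same seed, same draws inside the scope.
with ScopedRNGSketch(seed=42):
    inside_a = random.random()
with ScopedRNGSketch(seed=42):
    inside_b = random.random()
assert inside_a == inside_b
```

Because the outer states are restored in __exit__, a scoped block (for example, a deterministic evaluation pass) does not perturb the training run's random stream.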