nemo_automodel.components.training.rng

Module Contents

Classes

RNGState

Snapshot of Python, NumPy, Torch, and CUDA RNG states.

StatefulRNG

RNG manager for reproducible RNG states across Python's random module, NumPy, and PyTorch.

ScopedRNG

Context manager for reproducible RNG states across Python's random module, NumPy, and PyTorch.

Functions

init_all_rng

Initialize RNGs for Python, NumPy, and PyTorch (incl. CUDA) with a seed.

_get_rng_state

Get current RNG states.

_restore_rng_state

Restore RNG states from a saved state.

API

nemo_automodel.components.training.rng.init_all_rng(seed: int, ranked: bool = False)

Initialize RNGs for Python, NumPy, and PyTorch (incl. CUDA) with a seed.

Parameters:
  • seed (int) – Base seed value.

  • ranked (bool) – If True, adjust the seed by the process rank.
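A minimal sketch of what seeding every RNG involves, assuming the standard calls random.seed, numpy.random.seed, torch.manual_seed and torch.cuda.manual_seed_all. The name init_all_rng_sketch is hypothetical, NumPy and PyTorch are treated as optional so the snippet runs without them, and the rank adjustment is assumed to be seed + rank with the rank read from the RANK environment variable; the real init_all_rng may obtain and apply the rank differently.

```python
import os
import random

try:
    import numpy as np  # optional: seeded only if installed
except ImportError:
    np = None

try:
    import torch  # optional: seeded only if installed
except ImportError:
    torch = None


def init_all_rng_sketch(seed: int, ranked: bool = False) -> int:
    """Hypothetical stand-in for init_all_rng; returns the effective seed."""
    if ranked:
        # Assumption: offset by the RANK env var (set by launchers such as torchrun).
        seed += int(os.environ.get("RANK", "0"))
    random.seed(seed)
    if np is not None:
        np.random.seed(seed)
    if torch is not None:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            # Explicit CUDA seeding, mirroring the "incl. CUDA" note above.
            torch.cuda.manual_seed_all(seed)
    return seed
```

With ranked=True each process still gets a reproducible stream, but ranks no longer share one, which is typically what you want for per-worker randomness such as dropout masks or data augmentation.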

class nemo_automodel.components.training.rng.RNGState

Snapshot of Python, NumPy, Torch, and CUDA RNG states.

random_rng_state: tuple

np_rng_state: tuple

torch_rng_state: torch.Tensor

cuda_rng_state: torch.Tensor
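The four fields line up with the standard getter calls: random.getstate() returns a tuple, numpy.random.get_state() returns a tuple by default, and torch.get_rng_state() and torch.cuda.get_rng_state() return byte tensors. The hypothetical RNGStateSketch below mirrors that layout as a frozen dataclass; it is not the library's class and, to stay dependency-free, only fills in the Python field.

```python
import random
from dataclasses import dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class RNGStateSketch:
    """Hypothetical mirror of RNGState; only random_rng_state is captured here."""

    random_rng_state: tuple                 # random.getstate()
    np_rng_state: Optional[tuple] = None    # numpy.random.get_state(), if NumPy is used
    torch_rng_state: Optional[Any] = None   # torch.get_rng_state(), a ByteTensor
    cuda_rng_state: Optional[Any] = None    # torch.cuda.get_rng_state(), current device

    @classmethod
    def capture(cls) -> "RNGStateSketch":
        # Snapshot the Python generator; the other fields stay None in this sketch.
        return cls(random_rng_state=random.getstate())

    def restore(self) -> None:
        # Rewind the Python generator to the captured point.
        random.setstate(self.random_rng_state)
```

Freezing the dataclass keeps a snapshot from being mutated after capture, which suits a value that exists only to be restored later.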

nemo_automodel.components.training.rng._get_rng_state()

Get current RNG states.

Returns:

RNG states for Python's random module, NumPy, and PyTorch.

Return type:

dict

nemo_automodel.components.training.rng._restore_rng_state(state)

Restore RNG states from a saved state.

Parameters:

state (dict) – RNG states as returned by StatefulRNG.state_dict().
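A rough sketch of the get/restore pair, assuming the dictionary keys reuse the RNGState field names (an assumption; the real keys may differ). get_rng_state_sketch and restore_rng_state_sketch are hypothetical names, and NumPy and PyTorch are only touched when they are installed.

```python
import random

try:
    import numpy as np
except ImportError:
    np = None

try:
    import torch
except ImportError:
    torch = None


def get_rng_state_sketch() -> dict:
    """Hypothetical analogue of _get_rng_state: snapshot every available RNG."""
    state = {"random_rng_state": random.getstate()}
    if np is not None:
        state["np_rng_state"] = np.random.get_state()
    if torch is not None:
        state["torch_rng_state"] = torch.get_rng_state()
        if torch.cuda.is_available():
            state["cuda_rng_state"] = torch.cuda.get_rng_state()  # current device only
    return state


def restore_rng_state_sketch(state: dict) -> None:
    """Hypothetical analogue of _restore_rng_state: put every saved RNG back."""
    random.setstate(state["random_rng_state"])
    if np is not None and "np_rng_state" in state:
        np.random.set_state(state["np_rng_state"])
    if torch is not None and "torch_rng_state" in state:
        torch.set_rng_state(state["torch_rng_state"])
        if "cuda_rng_state" in state and torch.cuda.is_available():
            torch.cuda.set_rng_state(state["cuda_rng_state"])
```

Restoring rewinds each generator to exactly the point of capture, so the same sequence of draws replays.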

class nemo_automodel.components.training.rng.StatefulRNG(seed: int, ranked: bool = False)

RNG manager for reproducible RNG states across Python's random module, NumPy, and PyTorch.

Initialization

Initialize RNGs with a given seed, optionally adjusted by process rank.

Parameters:
  • seed (int) – Base seed for RNGs.

  • ranked (bool) – If True, adjust the seed based on the process rank.

state_dict()

Get current RNG states.

Returns:

RNG states for Python's random module, NumPy, and PyTorch.

Return type:

dict

load_state_dict(state)

Restore RNG states from a saved state.

Parameters:

state (dict) – RNG states as returned by state_dict().
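StatefulRNG exposes state_dict() and load_state_dict(), the same convention PyTorch modules and optimizers use, which suggests saving RNG state in a training checkpoint next to model and optimizer state. The hypothetical, Python-only StatefulRNGSketch below illustrates that resume pattern under this assumption: the draws made after reloading a checkpoint match the draws that followed the save.

```python
import random


class StatefulRNGSketch:
    """Hypothetical, Python-only stand-in for StatefulRNG."""

    def __init__(self, seed: int) -> None:
        random.seed(seed)

    def state_dict(self) -> dict:
        return {"random_rng_state": random.getstate()}

    def load_state_dict(self, state: dict) -> None:
        random.setstate(state["random_rng_state"])


rng = StatefulRNGSketch(seed=42)
warmup = [random.random() for _ in range(10)]       # "train" for a while
checkpoint = {"rng": rng.state_dict()}              # saved alongside model/optimizer state
after_ckpt = [random.random() for _ in range(3)]    # draws made after the save

# Simulated restart: a fresh object re-seeds, then the checkpointed state is reloaded.
resumed = StatefulRNGSketch(seed=42)
resumed.load_state_dict(checkpoint["rng"])
resumed_draws = [random.random() for _ in range(3)]
```

Without the load_state_dict call, the restarted run would replay the warm-up draws instead of continuing where the checkpoint left off.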

class nemo_automodel.components.training.rng.ScopedRNG(seed: int = 95050, ranked: bool = False)

Context manager for reproducible RNG states across Python's random module, NumPy, and PyTorch. Entering the context saves the current RNG states, and exiting it restores them.

Initialization

Initialize RNGs with a given seed, optionally adjusted by process rank.

Parameters:
  • seed (int) – Base seed for RNGs (default 95050).

  • ranked (bool) – If True, adjust the seed based on the process rank.

__enter__()

Save current RNG states.

__exit__(exc_type, exc_value, traceback)

Restore RNG states on context exit.
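A minimal sketch of the scoped pattern, assuming the fixed seed is applied on entry after the caller's state has been saved (the documentation above does not pin down exactly when seeding happens). ScopedRNGSketch is hypothetical and only handles Python's random module; a fuller version would save and restore NumPy and PyTorch state the same way.

```python
import random


class ScopedRNGSketch:
    """Hypothetical, Python-only stand-in for ScopedRNG.

    Assumption: the fixed seed is applied in __enter__, after saving the caller's state.
    """

    def __init__(self, seed: int = 95050) -> None:
        self.seed = seed
        self._saved = None

    def __enter__(self) -> "ScopedRNGSketch":
        self._saved = random.getstate()  # remember the caller's stream
        random.seed(self.seed)           # deterministic draws inside the block
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool:
        random.setstate(self._saved)     # resume the caller's stream, even on error
        return False                     # do not swallow exceptions
```

Because the outer state is restored in __exit__, code inside the block sees the same draws every time, while the caller's stream continues as if the block never ran. Returning False lets any exception propagate after the state has been restored.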