nemo_automodel.components.training.rng
Module Contents
Classes
StatefulRNG | RNG manager for reproducible RNG states across random, NumPy, and PyTorch.
ScopedRNG | Context manager for reproducible RNG states across random, NumPy, and PyTorch.
Functions
init_all_rng | Initialize RNGs for Python, NumPy, and PyTorch (incl. CUDA) with a seed.
_get_rng_state | Get current RNG states.
_restore_rng_state | Restore RNG states from a saved state.
API
- nemo_automodel.components.training.rng.init_all_rng(seed: int, ranked: bool = False)
Initialize RNGs for Python, NumPy, and PyTorch (incl. CUDA) with a seed.
- Parameters:
seed (int) – Base seed value.
ranked (bool) – If True, adjust the seed by the process rank. Defaults to False.
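The seeding behavior described above can be sketched roughly as follows. This is a minimal stand-in, not the library's implementation: the helper name `seed_everything` is hypothetical, and both reading the rank from the `RANK` environment variable and adding it to the seed are assumptions about what "adjust seed by process rank" means. NumPy and PyTorch are seeded only if they happen to be installed.

```python
import os
import random


def seed_everything(seed: int, ranked: bool = False) -> int:
    """Hypothetical stand-in for init_all_rng; returns the seed actually used."""
    if ranked:
        # Assumption: the rank comes from the RANK environment variable
        # (as set by torchrun) and is simply added to the base seed.
        seed += int(os.environ.get("RANK", "0"))

    random.seed(seed)

    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed; skip it

    try:
        import torch
        torch.manual_seed(seed)           # seeds the CPU and CUDA generators
        torch.cuda.manual_seed_all(seed)  # silently ignored without CUDA
    except ImportError:
        pass  # PyTorch not installed; skip it

    return seed


seed_everything(1234)
first_draw = random.random()
seed_everything(1234)
second_draw = random.random()  # same seed, so the same value as first_draw
```

With `ranked=True`, each process in a distributed job would end up with a different seed under this assumption, which is typically what is wanted for per-rank data augmentation or dropout.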
- class nemo_automodel.components.training.rng.RNGState
- random_rng_state: tuple
None
- np_rng_state: tuple
None
- torch_rng_state: torch.Tensor
None
- cuda_rng_state: torch.Tensor
None
- nemo_automodel.components.training.rng._get_rng_state()
Get current RNG states.
- Returns:
RNG states for random, NumPy, and PyTorch.
- Return type:
dict
- nemo_automodel.components.training.rng._restore_rng_state(state)
Restore RNG states from a saved state.
- Parameters:
state (dict) – RNG states as returned by _get_rng_state() or StatefulRNG.state_dict().
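The get/restore pair above follows a capture-and-replay pattern, sketched here for Python's `random` module only. The function names are hypothetical, and the dictionary key is borrowed from the `RNGState` field name as an assumption about the layout; the actual contents of the returned dict are not documented here. NumPy (`np.random.get_state()` / `np.random.set_state()`) and PyTorch (`torch.get_rng_state()` / `torch.set_rng_state()`) expose the same kind of pair.

```python
import random


def capture_rng_state() -> dict:
    """Hypothetical stand-in for _get_rng_state, covering Python's RNG only."""
    # Assumption: the key mirrors the RNGState field name.
    return {"random_rng_state": random.getstate()}


def restore_rng_state(state: dict) -> None:
    """Hypothetical stand-in for _restore_rng_state."""
    random.setstate(state["random_rng_state"])


random.seed(7)
snapshot = capture_rng_state()
before = [random.random() for _ in range(3)]

restore_rng_state(snapshot)
after = [random.random() for _ in range(3)]  # replays the same three draws
```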
- class nemo_automodel.components.training.rng.StatefulRNG(seed: int, ranked: bool = False)
RNG manager for reproducible RNG states across random, NumPy, and PyTorch.
Initialization
Initialize and optionally rank-adjust RNGs with a given seed.
- Parameters:
seed (int) – Base seed for RNGs.
ranked (bool) – If True, adjust the seed based on the process rank. Defaults to False.
- state_dict()
Get current RNG states.
- Returns:
RNG states for random, NumPy, and PyTorch.
- Return type:
dict
- load_state_dict(state)
Restore RNG states from a saved state.
- Parameters:
state (dict) – RNG states as returned by state_dict().
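A `state_dict()` / `load_state_dict()` pair like this is typically used to carry RNG state through a checkpoint so that a resumed run continues the same random stream. The sketch below illustrates the idea with a hypothetical `TinyStatefulRNG` class that manages Python's `random` module only; `pickle` stands in for whatever checkpoint format is actually used, and none of this is the library's implementation.

```python
import pickle
import random


class TinyStatefulRNG:
    """Hypothetical, stdlib-only illustration of the StatefulRNG interface."""

    def __init__(self, seed: int):
        random.seed(seed)

    def state_dict(self) -> dict:
        return {"random_rng_state": random.getstate()}

    def load_state_dict(self, state: dict) -> None:
        random.setstate(state["random_rng_state"])


rng = TinyStatefulRNG(seed=42)
_ = [random.random() for _ in range(5)]      # "training" advances the RNG

checkpoint = pickle.dumps(rng.state_dict())  # stand-in for saving a checkpoint
expected_next = random.random()              # what the original run draws next

resumed = TinyStatefulRNG(seed=0)            # fresh process with a different seed
resumed.load_state_dict(pickle.loads(checkpoint))
actual_next = random.random()                # continues where the first run left off
```

Loading the saved state overrides whatever seed the resumed object was constructed with, which is why the two draws match.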
- class nemo_automodel.components.training.rng.ScopedRNG(seed: int = 95050, ranked: bool = False)
Context manager for reproducible RNG states across random, NumPy, and PyTorch.
Initialization
Initialize and optionally rank-adjust RNGs with a given seed.
- Parameters:
seed (int) – Base seed for RNGs. Defaults to 95050.
ranked (bool) – If True, adjust the seed based on the process rank. Defaults to False.
- __enter__()
Save current RNG states.
- __exit__(exc_type, exc_value, traceback)
Restore RNG states on context exit.
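The save-on-enter, restore-on-exit pattern described for `ScopedRNG` can be sketched with a small context manager. The name `scoped_seed` is hypothetical and the sketch covers only Python's `random` module; the exact order in which the real class saves state and reseeds is an assumption here.

```python
import random
from contextlib import contextmanager


@contextmanager
def scoped_seed(seed: int = 95050):
    """Hypothetical illustration: seed inside the block, restore afterwards."""
    saved = random.getstate()   # remember the caller's RNG state
    random.seed(seed)           # deterministic stream inside the block
    try:
        yield
    finally:
        random.setstate(saved)  # put the caller's stream back, even on error


random.seed(0)
first = random.random()
with scoped_seed():
    inner_a = random.random()   # drawn from the scoped stream
second = random.random()        # outer stream is unaffected by the block
with scoped_seed():
    inner_b = random.random()   # same scoped seed, so equal to inner_a
```

This is useful for steps that must be reproducible on their own, such as a fixed validation shuffle, without disturbing the random stream of the surrounding training loop.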