nemo_automodel.components.training.rng


Module Contents

Classes

Name | Description
--- | ---
RNGState | Snapshot of Python, NumPy, Torch, and CUDA RNG states.
ScopedRNG | Context manager for reproducible RNG states across random, NumPy, and PyTorch.
StatefulRNG | RNG manager for reproducible RNG states across random, NumPy, and PyTorch.

Functions

Name | Description
--- | ---
_get_rng_state | Get current RNG states.
_restore_rng_state | Restore RNG states from a saved state.
init_all_rng | Initialize RNGs for Python, NumPy, and PyTorch (incl. CUDA) with a seed.

API

class nemo_automodel.components.training.rng.RNGState(
random_rng_state: tuple,
np_rng_state: tuple,
torch_rng_state: torch.Tensor,
cuda_rng_state: torch.Tensor
)
Dataclass

Snapshot of Python, NumPy, Torch, and CUDA RNG states.

Attributes:

cuda_rng_state
torch.Tensor

np_rng_state
tuple

random_rng_state
tuple

torch_rng_state
torch.Tensor

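Each `RNGState` field corresponds to the state getter of one library. The minimal sketch below records the current state of every RNG. It assumes a hypothetical `capture_rng_state()` helper and a look-alike `Snapshot` dataclass, neither of which is part of this module. Because `torch.cuda.get_rng_state()` needs a CUDA device, the sketch stores `None` when CUDA is unavailable. This deviates from the documented `torch.Tensor` type.

```python
import random
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch


@dataclass
class Snapshot:
    """Hypothetical stand-in mirroring the fields documented for RNGState."""
    random_rng_state: tuple
    np_rng_state: tuple
    torch_rng_state: torch.Tensor
    cuda_rng_state: Optional[torch.Tensor]  # None when no CUDA device is present (sketch-only choice)


def capture_rng_state() -> Snapshot:
    """Hypothetical helper: read the current state of every RNG."""
    return Snapshot(
        random_rng_state=random.getstate(),    # Python's Mersenne Twister state
        np_rng_state=np.random.get_state(),    # legacy global NumPy RandomState
        torch_rng_state=torch.get_rng_state(), # CPU generator state as a uint8 tensor
        cuda_rng_state=torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
    )


snap = capture_rng_state()
print(type(snap.random_rng_state).__name__)  # tuple
print(snap.np_rng_state[0])                  # MT19937
print(snap.torch_rng_state.dtype)            # torch.uint8
```

You can write each saved state back with the matching setter: `random.setstate`, `np.random.set_state`, `torch.set_rng_state`, or `torch.cuda.set_rng_state`.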
class nemo_automodel.components.training.rng.ScopedRNG(
seed: int = 95050,
ranked: bool = False
)

Context manager for reproducible RNG states across random, NumPy, and PyTorch.

nemo_automodel.components.training.rng.ScopedRNG.__enter__()

Save current RNG states.

nemo_automodel.components.training.rng.ScopedRNG.__exit__(
exc_type,
exc_value,
traceback
)

Restore RNG states on context exit.
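The save-on-enter, restore-on-exit pattern documented for `ScopedRNG` can be sketched with `contextlib`. This is an illustration, not the class's source. The hypothetical `scoped_rng` helper reseeds every RNG with `seed` inside the block. The `seed` parameter suggests this behavior, but the description above does not state it. The helper also omits the `ranked` option.

```python
import random
from contextlib import contextmanager
from typing import Iterator

import numpy as np
import torch


@contextmanager
def scoped_rng(seed: int = 95050) -> Iterator[None]:
    """Hypothetical helper: run the block under `seed`, then restore the prior RNG states."""
    # Save the current states (the step documented for ScopedRNG.__enter__).
    saved_py = random.getstate()
    saved_np = np.random.get_state()
    saved_torch = torch.get_rng_state()
    saved_cuda = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None

    # Assumption: reseed every RNG so the body of the block is reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # also seeds all CUDA devices
    try:
        yield
    finally:
        # Restore on normal exit and on exceptions (the step documented for __exit__).
        random.setstate(saved_py)
        np.random.set_state(saved_np)
        torch.set_rng_state(saved_torch)
        if saved_cuda is not None:
            torch.cuda.set_rng_state_all(saved_cuda)


# Usage: values drawn inside the block depend only on the scoped seed,
# and the outer streams resume exactly where they left off.
random.seed(0)
with scoped_rng(1234):
    noise = torch.randn(3)
outer = random.random()  # same value as the first draw after random.seed(0)
```

The restore step sits in a `finally` clause, so an exception raised inside the block still leaves the outer RNG streams untouched.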

class nemo_automodel.components.training.rng.StatefulRNG(
seed: int,
ranked: bool = False
)

RNG manager for reproducible RNG states across random, NumPy, and PyTorch.

nemo_automodel.components.training.rng.StatefulRNG.load_state_dict(
state
)

Restore RNG states from a saved state.

Parameters:

state
dict

RNG states as returned by state_dict().

nemo_automodel.components.training.rng.StatefulRNG.state_dict()

Get current RNG states.

Returns:

dict

RNG states for random, NumPy, and PyTorch, in the format accepted by load_state_dict().
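The `state_dict()` / `load_state_dict()` pair follows the usual PyTorch checkpointing convention: save the dict alongside model weights, then load it on resume so random draws continue as if training had never stopped. The sketch below mimics that contract with a hypothetical `ToyStatefulRNG` class. The seeding calls and the dictionary keys are assumptions, because the real class's keys and seeding logic are not documented here. CUDA state is omitted for brevity.

```python
import random

import numpy as np
import torch


class ToyStatefulRNG:
    """Hypothetical stand-in illustrating the state_dict/load_state_dict contract."""

    def __init__(self, seed: int) -> None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    def state_dict(self) -> dict:
        # Key names are assumptions made for this sketch.
        return {
            "random": random.getstate(),
            "numpy": np.random.get_state(),
            "torch": torch.get_rng_state(),
        }

    def load_state_dict(self, state: dict) -> None:
        random.setstate(state["random"])
        np.random.set_state(state["numpy"])
        torch.set_rng_state(state["torch"])


def sample() -> tuple:
    """Draw one value from each RNG."""
    return random.random(), float(np.random.rand()), torch.rand(1).item()


# Simulate saving a checkpoint mid-run and resuming from it.
rng = ToyStatefulRNG(seed=42)
sample()                                   # step 0
checkpoint = {"rng": rng.state_dict()}
expected = sample()                        # step 1 in the original run
rng.load_state_dict(checkpoint["rng"])
resumed = sample()                         # step 1 after resuming
assert resumed == expected
```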

nemo_automodel.components.training.rng._get_rng_state()

Get current RNG states.

Returns:

RNG states for random, NumPy, and PyTorch.

nemo_automodel.components.training.rng._restore_rng_state(
state
)

Restore RNG states from a saved state.

Parameters:

state
dict

RNG states as returned by StatefulRNG.state_dict().

nemo_automodel.components.training.rng.init_all_rng(
seed: int,
ranked: bool = False
)

Initialize RNGs for Python, NumPy, and PyTorch (incl. CUDA) with a seed.

Parameters:

seed
int

Base seed value.

ranked
bool, defaults to False

If True, adjust the seed by the process rank so that each rank is seeded differently.
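A seeding routine of this shape might look like the sketch below. The hypothetical `seed_everything` helper makes two assumptions: the rank comes from `torch.distributed.get_rank()` when a process group is initialized, and the adjustment is a simple `seed + rank`. The documentation above does not specify how `init_all_rng` derives the per-rank seed.

```python
import random

import numpy as np
import torch
import torch.distributed as dist


def seed_everything(seed: int, ranked: bool = False) -> int:
    """Hypothetical helper: seed Python, NumPy, and PyTorch (including CUDA).

    Returns the seed actually applied, which differs from `seed` only when
    `ranked` is True and a process group is initialized.
    """
    if ranked and dist.is_available() and dist.is_initialized():
        seed += dist.get_rank()  # assumed adjustment: one distinct seed per rank
    random.seed(seed)
    np.random.seed(seed)     # NumPy's legacy seeding accepts values in [0, 2**32 - 1]
    torch.manual_seed(seed)  # seeds the CPU generator and every CUDA device
    return seed


# Usage: reseeding with the same value reproduces the same draws.
seed_everything(1234)
a = torch.rand(2)
seed_everything(1234)
b = torch.rand(2)
assert torch.equal(a, b)
```

Offsetting by rank keeps runs reproducible while giving each data-parallel worker a different stream, which matters for things like dropout masks or random augmentations.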