core.tensor_parallel.random#

Module Contents#

Classes#

CudaRNGStatesTracker

Tracker for the cuda RNG states.

CheckpointFunction

Checkpoint Function

CheckpointWithoutOutputFunction

Checkpoint Function Helper for CheckpointWithoutOutput. Saves the context for recompute.

CheckpointWithoutOutput

Checkpoint a model or part of the model and release the output.

Functions#

_get_cuda_rng_state

Return the random number generator state of the specified GPU.

_set_cuda_rng_state

Sets the random number generator state of the current GPU.

get_expert_parallel_rng_tracker_name

Get the expert parallel rng tracker name

get_data_parallel_rng_tracker_name

Get the data parallel rng tracker name

initialize_rng_tracker

Create the RNG tracker. 'use_te_rng_tracker' determines whether to use Megatron's or TransformerEngine's implementation; TransformerEngine's implementation can be captured in a CUDA graph and supports FP8.

get_cuda_rng_tracker

Get cuda rng tracker.

get_all_rng_states

Returns all generator states used by the current CudaRNGStatesTracker.

model_parallel_cuda_manual_seed

Initialize model parallel cuda seed.

_get_all_rng_states

Get all the rng states.

_set_all_rng_states

Set all the rng states.

_fork_rng

Fork the rng state.

checkpoint

Checkpoint a model or part of the model. This has been directly copied from torch.utils.checkpoint.

Data#

API#

core.tensor_parallel.random._MODEL_PARALLEL_RNG_TRACKER_NAME#

‘model-parallel-rng’

core.tensor_parallel.random._EXPERT_PARALLEL_RNG_TRACKER_NAME#

‘expert-parallel-rng’

core.tensor_parallel.random._DATA_PARALLEL_RNG_TRACKER_NAME#

‘data-parallel-rng’

core.tensor_parallel.random._get_cuda_rng_state(
device: Union[int, str, torch.device] = 'cuda',
clone: bool = False,
graph_safe: bool = False,
) → torch.Tensor#

Return the random number generator state of the specified GPU.

Parameters:
  • device (Union[int, str, torch.device]) – The GPU whose RNG state to retrieve

  • clone (bool) – Whether to also clone the retrieved RNG state

  • graph_safe (bool) – Get the rng state in a graph safe manner.

This function is adapted from torch.cuda.random.get_rng_state()

core.tensor_parallel.random._set_cuda_rng_state(
new_state: torch.Tensor,
device: int = -1,
graph_safe: bool = False,
)#

Sets the random number generator state of the current GPU.

Parameters:
  • new_state (torch.ByteTensor) – The desired state

  • device (int) – The GPU on which to set the RNG state

  • graph_safe (bool) – Set the rng state in a graph safe manner.

This function is adapted from the PyTorch repo (torch.cuda.set_rng_state) with a single change: the input state is not cloned. Cloning caused major performance issues in cases with 4 or more GPUs.
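A minimal save/restore sketch using these two helpers with the signatures documented above; the import path is assumed from this page's module name and may differ in your installation.

import torch

from core.tensor_parallel.random import _get_cuda_rng_state, _set_cuda_rng_state

# Snapshot the current GPU RNG state. clone=True so later kernel launches
# cannot mutate the snapshot in place.
saved_state = _get_cuda_rng_state(device="cuda", clone=True)

noise = torch.rand(4, device="cuda")  # advances the generator

# Restore the snapshot; the next draw repeats the same sequence.
_set_cuda_rng_state(saved_state, device=torch.cuda.current_device())
same_noise = torch.rand(4, device="cuda")

assert torch.equal(noise, same_noise)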

core.tensor_parallel.random.get_expert_parallel_rng_tracker_name()#

Get the expert parallel rng tracker name

core.tensor_parallel.random.get_data_parallel_rng_tracker_name()#

Get the data parallel rng tracker name

class core.tensor_parallel.random.CudaRNGStatesTracker(
use_cudagraphable_rng=False,
is_inference_rng_tracker=False,
)#

Tracker for the cuda RNG states.

Using the add method, a CUDA RNG state is initialized based on the input seed and assigned to the given name. Later, by forking the RNG state, we can perform operations and return to our starting CUDA state.

Initialization

is_initialized()#

Checks if the internal RNG state has been set with set_states().

reset()#

Set to the initial state (no tracker).

get_states()#

Get rng states. Copy the dictionary so we have direct pointers to the states, not just a pointer to the dictionary.

set_states(states)#

Set the rng states. For efficiency purposes, we do not check the size of seed for compatibility.

add(name, seed)#

Track the rng state.

fork(name=_MODEL_PARALLEL_RNG_TRACKER_NAME)#

Fork the cuda rng state, perform operations, and exit with the original state.
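A minimal sketch of the add/fork pattern described above, assuming fork is used as a context manager (as its description implies); the import path follows this page's module name.

import torch

from core.tensor_parallel.random import CudaRNGStatesTracker

tracker = CudaRNGStatesTracker()
tracker.add("model-parallel-rng", 1234)  # seed a named CUDA RNG state

# Work inside the fork draws from the tracked "model-parallel-rng" state;
# on exit the default CUDA generator is restored to its pre-fork state.
with tracker.fork("model-parallel-rng"):
    dropout_mask = torch.rand(8, device="cuda")

# Outside the fork, the default generator continues where it left off.
other = torch.rand(8, device="cuda")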

core.tensor_parallel.random._CUDA_RNG_STATE_TRACKER#

None

core.tensor_parallel.random._CUDA_RNG_STATE_TRACKER_INITIALIZED#

False

core.tensor_parallel.random.initialize_rng_tracker(
use_te_rng_tracker: bool = False,
inference_rng_tracker: bool = False,
use_cudagraphable_rng: bool = False,
force_reset: bool = False,
)#

Create the RNG tracker. 'use_te_rng_tracker' determines whether to use Megatron's or TransformerEngine's implementation; TransformerEngine's implementation can be captured in a CUDA graph and supports FP8.

core.tensor_parallel.random.get_cuda_rng_tracker(
use_te_rng_tracker: bool = False,
inference_rng_tracker: bool = False,
use_cudagraphable_rng: bool = False,
)#

Get cuda rng tracker.

core.tensor_parallel.random.get_all_rng_states()#

Returns all generator states used by the current CudaRNGStatesTracker.

core.tensor_parallel.random.model_parallel_cuda_manual_seed(
seed: int,
te_rng_tracker: bool = False,
inference_rng_tracker: bool = False,
use_cudagraphable_rng: bool = False,
tp_rank: Optional[int] = None,
ep_rank: Optional[int] = None,
etp_rank: Optional[int] = None,
force_reset_rng: bool = False,
)#

Initialize model parallel cuda seed.

This function should be called after model parallel is initialized, and no torch.cuda.manual_seed should be called after it; it is effectively a replacement for that function. Three sets of RNG states are tracked:

  • default state: used for data parallelism. It is the same among a set of model parallel GPUs but different across model parallel groups, for example for dropout in non-tensor-model-parallel regions.

  • tensor-model-parallel state: different among a set of model parallel GPUs, but the same across data parallel groups, for example for dropout in model parallel regions.

  • expert-parallel-seed: used only for the expert layer of MoE models. It is different among expert-tensor and expert-model parallel GPUs, and the same across expert-data parallel groups.
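A minimal usage sketch of the call order described above, assuming model parallel groups have already been initialized elsewhere; import paths follow this page's module name and may differ per installation.

import torch

from core.tensor_parallel.random import (
    get_cuda_rng_tracker,
    model_parallel_cuda_manual_seed,
)

# Call once after model parallel initialization; this replaces
# torch.cuda.manual_seed, which should not be called afterwards.
model_parallel_cuda_manual_seed(1234)

# Dropout inside tensor-model-parallel regions runs under the
# tensor-model-parallel RNG state so that it differs across TP ranks.
with get_cuda_rng_tracker().fork():
    x = torch.randn(16, 16, device="cuda")
    x = torch.nn.functional.dropout(x, p=0.1, training=True)

# Code outside the fork uses the default (data-parallel) state.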

core.tensor_parallel.random._get_all_rng_states()#

Get all the rng states.

core.tensor_parallel.random._set_all_rng_states(
cpu_rng_state,
cuda_rng_state,
cuda_rng_state_tracker,
)#

Set all the rng states.

core.tensor_parallel.random._fork_rng()#

Fork the rng state.

class core.tensor_parallel.random.CheckpointFunction#

Bases: torch.autograd.Function

Checkpoint Function

This function is adapted from torch.utils.checkpoint with two main changes:

  1. torch.cuda.set_rng_state is replaced with _set_cuda_rng_state

  2. the states in the model parallel tracker are also properly tracked/set/reset.

static forward(ctx, run_function, distribute_saved_activations, *args)#

Forward pass.

static backward(ctx, *args)#

Backward pass.

core.tensor_parallel.random.checkpoint(function, distribute_saved_activations, *args)#

Checkpoint a model or part of the model. This has been directly copied from torch.utils.checkpoint.
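A minimal usage sketch of this entry point, assuming CUDA is available and the RNG tracker has been seeded; block is a hypothetical callable standing in for part of a model, and distribute_saved_activations=False keeps the saved input unsplit.

import torch

from core.tensor_parallel.random import checkpoint


def block(x):
    # Hypothetical sub-module whose activations are recomputed in backward
    # instead of being kept alive through the forward pass.
    return torch.nn.functional.gelu(x @ x.t())


inp = torch.randn(32, 32, device="cuda", requires_grad=True)
out = checkpoint(block, False, inp)  # (function, distribute_saved_activations, *args)
out.sum().backward()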

class core.tensor_parallel.random.CheckpointWithoutOutputFunction#

Bases: torch.autograd.Function

Checkpoint Function Helper for CheckpointWithoutOutput. Saves the context for recompute.

static forward(ctx, run_function, checkpoint_without_output_obj, *args)#

Forward pass.

static backward(ctx, *args)#

Backward pass.

class core.tensor_parallel.random.CheckpointWithoutOutput(fp8=False)#

Bases: object

Checkpoint a model or part of the model and release the output.

With the normal 'checkpoint' function, the output of the checkpointed function may be saved by downstream modules for their backward computation. However, that output is regenerated during recomputation anyway, so storing it is not strictly necessary. This class manually discards the output in the forward pass and restores it by recomputation in the backward pass, reducing memory usage.

Consequently, this method only saves memory if the discarded output tensors are exactly what the downstream modules would otherwise save directly for their backward computation.

Initialization

checkpoint(run_function, *args)#

Checkpoint function.

_recompute(_)#

Used as a hook to recompute the output.

discard_output_and_register_recompute(hook_tensor)#

Release the output tensor storages and register the recompute function as a grad hook of the hook_tensor.

Note: the caller should make sure that the output tensors are no longer used in the forward pass and the gradient of the hook_tensor is computed before the recomputed tensors are used.
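A minimal sketch of the discard/recompute flow described above. The norm and linear layers are hypothetical stand-ins; only the CheckpointWithoutOutput API comes from this page.

import torch

from core.tensor_parallel.random import CheckpointWithoutOutput

norm = torch.nn.LayerNorm(64).cuda()
linear = torch.nn.Linear(64, 64).cuda()
hidden = torch.randn(8, 64, device="cuda", requires_grad=True)

ckpt = CheckpointWithoutOutput()

# Run `norm` under the checkpoint; its output will be recomputed in backward.
normed = ckpt.checkpoint(norm, hidden)

# The downstream module saves `normed` for its own backward computation.
out = linear(normed)

# Free the storage of `normed` now, and re-create it via a grad hook on
# `out` just before the backward pass needs it.
ckpt.discard_output_and_register_recompute(out)

out.sum().backward()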