core.tensor_parallel.random#
Module Contents#
Classes#
| Class | Description |
|---|---|
| CudaRNGStatesTracker | Tracker for the cuda RNG states. |
| CheckpointFunction | Checkpoint Function |
| CheckpointWithoutOutputFunction | Checkpoint Function helper for CheckpointWithoutOutput. Save context for recompute. |
| CheckpointWithoutOutput | Checkpoint a model or part of the model and release the output. |
Functions#
| Function | Description |
|---|---|
| _get_cuda_rng_state | Return the random number generator state of the specified GPU. |
| _set_cuda_rng_state | Sets the random number generator state of the current GPU. |
| get_expert_parallel_rng_tracker_name | Get the expert parallel rng tracker name. |
| get_data_parallel_rng_tracker_name | Get the data parallel rng tracker name. |
| initialize_rng_tracker | Create the RNG tracker. `use_te_rng_tracker` determines whether to use Megatron's or TransformerEngine's implementation; TransformerEngine's implementation is cudagraphable and supports FP8. |
| get_cuda_rng_tracker | Get cuda rng tracker. |
| get_all_rng_states | Returns all generator states used by the current CudaRNGStatesTracker. |
| model_parallel_cuda_manual_seed | Initialize model parallel cuda seed. |
| _get_all_rng_states | Get all the rng states. |
| _set_all_rng_states | Set all the rng states. |
| _fork_rng | Fork the rng state. |
| checkpoint | Checkpoint a model or part of the model. This has been directly copied from torch.utils.checkpoint. |
Data#
API#
- core.tensor_parallel.random._MODEL_PARALLEL_RNG_TRACKER_NAME#
'model-parallel-rng'
- core.tensor_parallel.random._EXPERT_PARALLEL_RNG_TRACKER_NAME#
'expert-parallel-rng'
- core.tensor_parallel.random._DATA_PARALLEL_RNG_TRACKER_NAME#
'data-parallel-rng'
- core.tensor_parallel.random._get_cuda_rng_state(
- device: Union[int, str, torch.device] = 'cuda',
- clone: bool = False,
- graph_safe: bool = False,
Return the random number generator state of the specified GPU.
- Parameters:
device (Union[int, str, torch.device]) – The GPU whose RNG state to retrieve
clone (bool) – Whether to also clone the retrieved RNG state
graph_safe (bool) – Get the RNG state in a graph-safe manner.
This function is adapted from torch.cuda.random.get_rng_state()
- core.tensor_parallel.random._set_cuda_rng_state(
- new_state: torch.Tensor,
- device: int = -1,
- graph_safe: bool = False,
Sets the random number generator state of the current GPU.
- Parameters:
new_state (torch.ByteTensor) – The desired state
device (int) – The GPU whose RNG state to set
graph_safe (bool) – Set the RNG state in a graph-safe manner.
This function is adapted from the PyTorch repo (torch.cuda.set_rng_state) with a single change: the input state is not cloned. Cloning caused major performance issues for runs with 4 or more GPUs.
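A minimal sketch of how these two helpers pair up to snapshot and restore the GPU generator; it assumes a CUDA device and that the module is importable as megatron.core.tensor_parallel.random (adjust the import path to your installation).

```python
import torch
from megatron.core.tensor_parallel.random import (
    _get_cuda_rng_state,
    _set_cuda_rng_state,
)

state = _get_cuda_rng_state(clone=True)  # snapshot of the current GPU RNG state
a = torch.rand(4, device="cuda")         # advances the GPU generator
_set_cuda_rng_state(state)               # roll the generator back to the snapshot
b = torch.rand(4, device="cuda")
assert torch.equal(a, b)                 # same state, same draw
```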
- core.tensor_parallel.random.get_expert_parallel_rng_tracker_name()#
Get the expert parallel rng tracker name
- core.tensor_parallel.random.get_data_parallel_rng_tracker_name()#
Get the data parallel rng tracker name
- class core.tensor_parallel.random.CudaRNGStatesTracker(
- use_cudagraphable_rng=False,
- is_inference_rng_tracker=False,
Tracker for the cuda RNG states.
Using the `add` method, a CUDA RNG state is initialized based on the input `seed` and is assigned to `name`. Later, by forking the RNG state, we can perform operations and return to our starting CUDA state.

Initialization
- is_initialized()#
Checks if the internal RNG state has been set with set_states().
- reset()#
Set to the initial state (no tracker).
- get_states()#
Get rng states. Copy the dictionary so we have direct pointers to the states, not just a pointer to the dictionary.
- set_states(states)#
Set the rng states. For efficiency purposes, we do not check the size of seed for compatibility.
- add(name, seed)#
Track the rng state.
- fork(name=_MODEL_PARALLEL_RNG_TRACKER_NAME)#
Fork the cuda rng state, perform operations, and exit with the original state.
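A minimal sketch of the tracker API described above, assuming a CUDA device; the state name here matches the module's default tracker name, but any string may be used.

```python
import torch
from megatron.core.tensor_parallel.random import CudaRNGStatesTracker

tracker = CudaRNGStatesTracker()
tracker.add("model-parallel-rng", seed=1234)  # seed and register a named state

with tracker.fork("model-parallel-rng"):
    mask = torch.rand(8, device="cuda")       # drawn from the tracked state

# after exiting the block, the default CUDA RNG state is exactly as it was before
```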
- core.tensor_parallel.random._CUDA_RNG_STATE_TRACKER#
None
- core.tensor_parallel.random._CUDA_RNG_STATE_TRACKER_INITIALIZED#
False
- core.tensor_parallel.random.initialize_rng_tracker(
- use_te_rng_tracker: bool = False,
- inference_rng_tracker: bool = False,
- use_cudagraphable_rng: bool = False,
- force_reset: bool = False,
Create the RNG tracker. `use_te_rng_tracker` determines whether to use Megatron's or TransformerEngine's implementation. In particular, TransformerEngine's implementation is cudagraphable and supports FP8.
- core.tensor_parallel.random.get_cuda_rng_tracker(
- use_te_rng_tracker: bool = False,
- inference_rng_tracker: bool = False,
- use_cudagraphable_rng: bool = False,
Get cuda rng tracker.
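A minimal sketch of the common call pattern in model code, assuming the tracker has already been seeded via model_parallel_cuda_manual_seed (see below); forking with no name uses the tensor-model-parallel state.

```python
import torch
from megatron.core.tensor_parallel.random import get_cuda_rng_tracker

with get_cuda_rng_tracker().fork():
    # dropout inside a tensor-parallel region draws from the
    # tensor-model-parallel RNG state rather than the default state
    y = torch.nn.functional.dropout(torch.randn(8, 8, device="cuda"), p=0.1)
```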
- core.tensor_parallel.random.get_all_rng_states()#
Returns all generator states used by the current CudaRNGStatesTracker.
- core.tensor_parallel.random.model_parallel_cuda_manual_seed(
- seed: int,
- te_rng_tracker: bool = False,
- inference_rng_tracker: bool = False,
- use_cudagraphable_rng: bool = False,
- tp_rank: Optional[int] = None,
- ep_rank: Optional[int] = None,
- etp_rank: Optional[int] = None,
- force_reset_rng: bool = False,
Initialize model parallel cuda seed.
This function should be called after model parallelism is initialized. No torch.cuda.manual_seed should be called after this function; this function is essentially a replacement for it. Three sets of RNG states are tracked (see the sketch after this list):
- default state: This is for data parallelism and is the same among a set of model parallel GPUs but different across different model parallel groups. This is used, for example, for dropout in the non-tensor-model-parallel regions.
- tensor-model-parallel state: This state is different among a set of model parallel GPUs, but the same across data parallel groups. This is used, for example, for dropout in model parallel regions.
- expert-parallel-seed: This state is only used for the expert layer of MoE models. It is different among expert-tensor and expert-model parallel GPUs, and the same across expert-data parallel groups.
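A minimal sketch of the expected call order, assuming torch.distributed is already initialized and that model parallel groups are created via megatron.core.parallel_state (the tensor_model_parallel_size value is illustrative).

```python
from megatron.core import parallel_state
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed

parallel_state.initialize_model_parallel(tensor_model_parallel_size=2)
model_parallel_cuda_manual_seed(1234)  # seeds the default, tensor-model-parallel,
                                       # and expert-parallel states on this rank

# do not call torch.cuda.manual_seed(...) after this point; the tracker owns the seeds
```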
- core.tensor_parallel.random._get_all_rng_states()#
Get all the rng states.
- core.tensor_parallel.random._set_all_rng_states(
- cpu_rng_state,
- cuda_rng_state,
- cuda_rng_state_tracker,
Set all the rng states.
- core.tensor_parallel.random._fork_rng()#
Fork the rng state.
- class core.tensor_parallel.random.CheckpointFunction#
Bases: torch.autograd.Function

Checkpoint Function
This function is adapted from torch.utils.checkpoint with two main changes:
1. torch.cuda.set_rng_state is replaced with _set_cuda_rng_state.
2. The states in the model parallel tracker are also properly tracked/set/reset.
- static forward(ctx, run_function, distribute_saved_activations, *args)#
Forward pass.
- static backward(ctx, *args)#
Backward pass.
- core.tensor_parallel.random.checkpoint(function, distribute_saved_activations, *args)#
Checkpoint a model or part of the model. This has been directly copied from torch.utils.checkpoint.
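A minimal sketch of activation checkpointing with this helper; `block` is a hypothetical activation-heavy sub-function, and the example assumes a CUDA device and an initialized RNG tracker.

```python
import torch
from megatron.core.tensor_parallel.random import checkpoint

def block(x):
    # sub-graph whose activations are recomputed in the backward pass
    return torch.nn.functional.gelu(x @ x.transpose(0, 1))

x = torch.randn(16, 16, device="cuda", requires_grad=True)
# distribute_saved_activations=False keeps the saved input whole on this rank
y = checkpoint(block, False, x)
y.sum().backward()
```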
- class core.tensor_parallel.random.CheckpointWithoutOutputFunction#
Bases: torch.autograd.Function

Checkpoint Function helper for CheckpointWithoutOutput. Save context for recompute.
- static forward(ctx, run_function, checkpoint_without_output_obj, *args)#
Forward pass.
- static backward(ctx, *args)#
Backward pass.
- class core.tensor_parallel.random.CheckpointWithoutOutput(fp8=False)#
Bases: object

Checkpoint a model or part of the model and release the output.
For the normal `checkpoint` function, its output may be saved by the following modules for their backward computation. However, the output of the checkpointed function is re-generated during recomputation, so storing it is not strictly necessary. This method manually discards the output in the forward pass and restores it by recomputation in the backward pass to reduce memory usage.
For this to actually save memory, the caller should make sure that the discarded output tensors are exactly the tensors that the following modules save for their backward computation.
Initialization
- checkpoint(run_function, *args)#
Checkpoint function.
- _recompute(_)#
Used as a hook to recompute the output.
- discard_output_and_register_recompute(hook_tensor)#
Release the output tensor storages and register the recompute function as a grad hook of the hook_tensor.
Note: the caller should make sure that the output tensors are no longer used in the forward pass and that the gradient of the hook_tensor is computed before the recomputed tensors are used.
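A minimal sketch of the intended pattern, using hypothetical `norm` and `linear` modules: the checkpointed output feeds the next module (which saves it for backward), then its storage is discarded and recomputation is registered on a tensor produced by that consumer.

```python
import torch
from megatron.core.tensor_parallel.random import CheckpointWithoutOutput

norm = torch.nn.LayerNorm(64).cuda()
linear = torch.nn.Linear(64, 64).cuda()
x = torch.randn(8, 64, device="cuda", requires_grad=True)

ckpt = CheckpointWithoutOutput()
normed = ckpt.checkpoint(norm, x)   # run `norm` under checkpointing
out = linear(normed)                # `linear` saves `normed` for its backward
# free `normed`'s storage now; it is recomputed when `out`'s grad hook fires
ckpt.discard_output_and_register_recompute(out)

out.sum().backward()
```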