core.pipeline_parallel.utils#
Module Contents#
Classes#
| Class | Description |
|---|---|
| NoopScheduleNode | A placeholder node in the computation graph that simply passes through inputs and outputs. |
| ScheduleNode | Base node for fine-grained scheduling. |
| AbstractSchedulePlan | To use combined 1F1B, a model must implement build_schedule_plan, which takes the same signature as the model forward but returns an instance of AbstractSchedulePlan. |
Functions#
| Function | Description |
|---|---|
| is_pp_first_stage | Return True if in the first pipeline model-parallel stage, False otherwise. |
| is_pp_last_stage | Return True if in the last pipeline model-parallel stage, False otherwise. |
| is_vp_first_stage | Return True if in the first virtual pipeline model-parallel stage, False otherwise. |
| is_vp_last_stage | Return True if in the last virtual pipeline model-parallel stage, False otherwise. |
| get_pp_first_rank | Return the global rank of the first rank in the pipeline parallel group. |
| get_pp_last_rank | Return the global rank of the last rank in the pipeline parallel group. |
| get_pp_next_rank | Return the global rank of the next rank in the pipeline parallel group, or None if in the last stage. |
| get_pp_prev_rank | Return the global rank of the previous rank in the pipeline parallel group, or None if in the first stage. |
| make_viewless | Utility that makes a tensor viewless. |
| stream_acquire_context | Context manager that acquires a stream, synchronized via an event. |
| set_streams | Set the streams for computation and communication. |
| get_comp_stream | Get the stream for computation. |
| get_comm_stream | Get the stream for communication. |
Data#
API#
- core.pipeline_parallel.utils.is_pp_first_stage(pp_group: torch.distributed.ProcessGroup)#
Return True if in the first pipeline model-parallel stage, False otherwise.
- core.pipeline_parallel.utils.is_pp_last_stage(pp_group: torch.distributed.ProcessGroup)#
Return True if in the last pipeline model-parallel stage, False otherwise.
- core.pipeline_parallel.utils.is_vp_first_stage(vp_stage: int, vp_size: int | None)#
Return True if in the first virtual pipeline model-parallel stage, False otherwise.
- core.pipeline_parallel.utils.is_vp_last_stage(vp_stage: int, vp_size: int | None)#
Return True if in the last virtual pipeline model-parallel stage, False otherwise.
- core.pipeline_parallel.utils.get_pp_first_rank(pp_group: torch.distributed.ProcessGroup)#
Return the global rank of the first rank in the pipeline parallel group.
- core.pipeline_parallel.utils.get_pp_last_rank(pp_group: torch.distributed.ProcessGroup)#
Return the global rank of the last rank in the pipeline parallel group.
- core.pipeline_parallel.utils.get_pp_next_rank(pp_group: torch.distributed.ProcessGroup)#
Return the global rank of the next rank in the pipeline parallel group, or None if in the last stage.
- core.pipeline_parallel.utils.get_pp_prev_rank(pp_group: torch.distributed.ProcessGroup)#
Return the global rank of the previous rank in the pipeline parallel group, or None if in the first stage.
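A hedged usage sketch of these helpers in a point-to-point pipeline hop. `forward_hop` is hypothetical; it assumes torch.distributed is initialized and that `pp_group` is this rank's pipeline-parallel group.

```python
import torch
import torch.distributed as dist

from core.pipeline_parallel.utils import (
    get_pp_next_rank,
    get_pp_prev_rank,
    is_pp_first_stage,
    is_pp_last_stage,
)


def forward_hop(activation: torch.Tensor, pp_group: dist.ProcessGroup) -> torch.Tensor:
    """Hypothetical one-hop relay: recv from the previous stage, send to the next."""
    if not is_pp_first_stage(pp_group):
        # Non-first stages receive their input activation from the previous rank.
        dist.recv(activation, src=get_pp_prev_rank(pp_group))
    if not is_pp_last_stage(pp_group):
        # Non-last stages forward their output activation to the next rank.
        dist.send(activation, dst=get_pp_next_rank(pp_group))
    return activation
```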
- core.pipeline_parallel.utils.make_viewless(e)#
Utility that makes a tensor viewless.
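As an illustration only (not necessarily this module's implementation): a "viewless" tensor shares its storage but drops the view chain back to the larger base buffer it was sliced from, so that buffer can be freed.

```python
import torch


def make_viewless_sketch(e: torch.Tensor) -> torch.Tensor:
    """Sketch: return a tensor sharing e's storage but with ._base set to None."""
    if e._base is None:
        return e  # already viewless; nothing to detach from
    out = torch.empty((1,), dtype=e.dtype, device=e.device, requires_grad=e.requires_grad)
    out.data = e.data  # repoint at the same storage without recording a view
    return out
```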
- core.pipeline_parallel.utils.stream_acquire_context(stream, event)#
Context manager that acquires a stream for a block of work, synchronized via an event.
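A minimal sketch of what such a context typically does, assuming the event-based ordering described for ScheduleNode below: wait on the event, run the body on the stream, then re-record the event.

```python
import contextlib

import torch


@contextlib.contextmanager
def stream_acquire_context_sketch(stream: torch.cuda.Stream, event: torch.cuda.Event):
    stream.wait_event(event)         # order this node after previously recorded work
    with torch.cuda.stream(stream):  # launch the body's kernels on `stream`
        yield
    event.record(stream)             # publish completion for downstream waiters
```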
- class core.pipeline_parallel.utils.NoopScheduleNode#
A placeholder node in the computation graph that simply passes through inputs and outputs.
This class is used as a no-op node in the scheduling system when a real computation node is not needed but the interface must be maintained (e.g., dense layer doesn’t need moe_dispatch and moe_combine). It simply returns its inputs unchanged in both forward and backward passes.
- forward(inputs)#
Passes through inputs unchanged in the forward pass.
- backward(outgrads)#
Passes through gradients unchanged in the backward pass.
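A usage sketch, assuming forward and backward accept the tensors directly as the signatures above suggest:

```python
import torch

from core.pipeline_parallel.utils import NoopScheduleNode

# A dense layer can keep the MoE node layout by substituting no-op nodes for
# the dispatch/combine steps it does not need.
dispatch = NoopScheduleNode()

x = torch.randn(4, 8)
y = dispatch.forward(x)                    # forward pass-through: y is x, unchanged
g = dispatch.backward(torch.ones_like(x))  # backward pass-through as well
```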
- class core.pipeline_parallel.utils.ScheduleNode(
- forward_func: Callable,
- stream: torch.cuda.Stream,
- event: torch.cuda.Event,
- backward_func: Optional[Callable] = None,
- free_input: bool = False,
- name: str = 'schedule_node',
- )#
Base node for fine-grained scheduling.
This class represents a computational node in the pipeline schedule. It handles the execution of forward and backward operations on a stream.
Initialization
Initialize a schedule node.
- Parameters:
forward_func (callable) – Function to execute during the forward pass.
stream (torch.cuda.Stream) –
The CUDA stream for this node’s computation. This can be either a ‘compute’ stream or a ‘communicate’ stream.
’compute’ stream: Used for computational nodes like attention and experts.
’communicate’ stream: Used for nodes that handle token communication, such as token dispatch and combine operations in MoE layers.
event (torch.cuda.Event) – The CUDA event used for synchronization. Each microbatch within a model chunk shares the same event, which is used to manage dependencies between nodes operating on different streams.
backward_func (callable, optional) – Function for the backward pass.
free_input (bool) – Flag to indicate if the input should be freed after the forward pass.
name (str) – Name of the node for debugging purposes.
- default_backward_func(outputs, output_grad)#
Default backward function, used when no backward_func is provided.
- forward(inputs=())#
Run the node's forward function on its stream.
- _forward(*inputs)#
- get_output()#
Get the forward output
- backward(output_grad)#
Run the node's backward function on its stream.
- _backward(*output_grad)#
- get_grad()#
Return the gradients of the inputs.
- _release_state()#
Clear the state of the node
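A hedged construction sketch based on the parameters above. The forward callables are hypothetical stand-ins, and it assumes forward() returns the node's output (see get_output()).

```python
import torch

from core.pipeline_parallel.utils import ScheduleNode

comp_stream = torch.cuda.Stream()
comm_stream = torch.cuda.Stream()
event = torch.cuda.Event()  # shared by the nodes of one microbatch in a chunk


def attn_forward(x):      # hypothetical compute step
    return x * 2.0


def dispatch_forward(x):  # hypothetical token-communication step
    return x


attn = ScheduleNode(attn_forward, comp_stream, event, name='attn')
dispatch = ScheduleNode(
    dispatch_forward,
    comm_stream,
    event,               # shared event orders comm work after compute work
    free_input=True,     # inputs may be freed once the forward pass is done
    name='moe_dispatch',
)

x = torch.randn(4, 8, device='cuda', requires_grad=True)
h = attn.forward((x,))        # runs on the compute stream
out = dispatch.forward((h,))  # runs on the comm stream, ordered via `event`
```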
- class core.pipeline_parallel.utils.AbstractSchedulePlan#
Bases:
abc.ABC
To use combined 1F1B, a model must implement build_schedule_plan, which takes the same signature as the model forward but returns an instance of AbstractSchedulePlan.
- abstractmethod static run(
- f_schedule_plan,
- b_schedule_plan,
- grad=None,
- pre_forward=None,
- pre_backward=None,
- post_forward=None,
- post_backward=None,
- )#
run() is the protocol between the schedule logic and the model: it schedules the execution of the forward and backward schedule plans for the model.
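A hedged sketch of this contract: the model exposes build_schedule_plan with the same signature as its forward, and a concrete plan class implements run() to interleave one forward plan with one backward plan. All names below are illustrative.

```python
import torch

from core.pipeline_parallel.utils import AbstractSchedulePlan


class MySchedulePlan(AbstractSchedulePlan):
    @staticmethod
    def run(f_schedule_plan, b_schedule_plan, grad=None,
            pre_forward=None, pre_backward=None,
            post_forward=None, post_backward=None):
        # Interleave the nodes of the forward plan with those of the backward
        # plan here, firing the pre/post hooks around each half when provided.
        ...


class MyModel(torch.nn.Module):
    def forward(self, hidden_states):
        ...

    def build_schedule_plan(self, hidden_states):
        # Same signature as forward, but returns a plan instead of computing.
        return MySchedulePlan()
```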
- core.pipeline_parallel.utils._COMP_STREAM#
None
- core.pipeline_parallel.utils._COMM_STREAM#
None
- core.pipeline_parallel.utils.set_streams(comp_stream=None, comm_stream=None)#
Set the streams for communication and computation
- core.pipeline_parallel.utils.get_comp_stream()#
Get the stream for computation
- core.pipeline_parallel.utils.get_comm_stream()#
Get the stream for communication
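Typical setup sketch: create the two streams once at startup, register them with set_streams, then fetch them wherever schedule nodes are built.

```python
import torch

from core.pipeline_parallel.utils import (
    get_comm_stream,
    get_comp_stream,
    set_streams,
)

# Register one compute stream and one communication stream process-wide.
set_streams(comp_stream=torch.cuda.Stream(), comm_stream=torch.cuda.Stream())

comp = get_comp_stream()  # used by computational nodes (e.g., attention, experts)
comm = get_comm_stream()  # used by token dispatch/combine nodes
```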