core.pipeline_parallel.utils#

Module Contents#

Classes#

NoopScheduleNode

A placeholder node in the computation graph that simply passes through inputs and outputs.

ScheduleNode

Base node for fine-grained scheduling.

AbstractSchedulePlan

To use combined 1F1B, the model must implement build_schedule_plan, which takes the same signature as the model’s forward method but returns an instance of AbstractSchedulePlan.

Functions#

is_pp_first_stage

Return True if in the first pipeline model-parallel stage, False otherwise.

is_pp_last_stage

Return True if in the last pipeline model-parallel stage, False otherwise.

is_vp_first_stage

Return True if in the first virtual pipeline model-parallel stage, False otherwise.

is_vp_last_stage

Return True if in the last virtual pipeline model-parallel stage, False otherwise.

get_pp_first_rank

Return the global rank of the first rank in the pipeline parallel group.

get_pp_last_rank

Return the global rank of the last rank in the pipeline parallel group.

get_pp_next_rank

Return the global rank of the next rank in the pipeline parallel group, or None if last stage.

get_pp_prev_rank

Return the global rank of the previous rank in the pipeline parallel group, or None if first stage.

make_viewless

Utility that returns a viewless version of a tensor, i.e., one that no longer retains a reference to the tensor it was viewed from.

stream_acquire_context

Context manager that runs the enclosed work on a stream, synchronizing with other streams via a CUDA event.

set_streams

Set the streams for communication and computation.

get_comp_stream

Get the stream for computation.

get_comm_stream

Get the stream for communication.

Data#

API#

core.pipeline_parallel.utils.is_pp_first_stage(pp_group: torch.distributed.ProcessGroup)#

Return True if in the first pipeline model-parallel stage, False otherwise.

core.pipeline_parallel.utils.is_pp_last_stage(pp_group: torch.distributed.ProcessGroup)#

Return True if in the last pipeline model-parallel stage, False otherwise.

core.pipeline_parallel.utils.is_vp_first_stage(vp_stage: int, vp_size: int | None)#

Return True if in the first virtual pipeline model-parallel stage, False otherwise.

core.pipeline_parallel.utils.is_vp_last_stage(vp_stage: int, vp_size: int | None)#

Return True if in the last virtual pipeline model-parallel stage, False otherwise.
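The virtual-stage checks are typically used to decide which chunk-level responsibilities a model chunk carries. A minimal usage sketch; the concrete vp_stage/vp_size values and the treatment of vp_size=None are assumptions for illustration:

```python
from core.pipeline_parallel.utils import is_vp_first_stage, is_vp_last_stage

vp_stage, vp_size = 0, 4  # illustrative values
if is_vp_first_stage(vp_stage, vp_size):
    print("first virtual stage: e.g., this chunk owns the embedding")
if is_vp_last_stage(vp_stage, vp_size):
    print("last virtual stage: e.g., this chunk owns the output head")
# With vp_size=None (no virtual pipelining), the single chunk is presumably
# treated as both first and last.
```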

core.pipeline_parallel.utils.get_pp_first_rank(pp_group: torch.distributed.ProcessGroup)#

Return the global rank of the first rank in the pipeline parallel group.

core.pipeline_parallel.utils.get_pp_last_rank(pp_group: torch.distributed.ProcessGroup)#

Return the global rank of the last rank in the pipeline parallel group.

core.pipeline_parallel.utils.get_pp_next_rank(pp_group: torch.distributed.ProcessGroup)#

Return the global rank of the next rank in the pipeline parallel group, or None if last stage.

core.pipeline_parallel.utils.get_pp_prev_rank(pp_group: torch.distributed.ProcessGroup)#

Return the global rank of the previous rank in the pipeline parallel group, or None if first stage.
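Together with the stage checks above, the rank helpers cover the common point-to-point pattern of passing activations down the pipeline. A hedged sketch; it assumes torch.distributed is initialized, and send_forward/recv_forward are illustrative names rather than library functions:

```python
import torch
import torch.distributed as dist

from core.pipeline_parallel.utils import (
    get_pp_next_rank,
    get_pp_prev_rank,
    is_pp_first_stage,
    is_pp_last_stage,
)

def send_forward(activations: torch.Tensor, pp_group: dist.ProcessGroup) -> None:
    # The last stage has no successor (get_pp_next_rank returns None there).
    if not is_pp_last_stage(pp_group):
        dist.send(activations, dst=get_pp_next_rank(pp_group))

def recv_forward(buffer: torch.Tensor, pp_group: dist.ProcessGroup) -> None:
    # The first stage has no predecessor (get_pp_prev_rank returns None there).
    if not is_pp_first_stage(pp_group):
        dist.recv(buffer, src=get_pp_prev_rank(pp_group))
```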

core.pipeline_parallel.utils.make_viewless(e)#

Utility that returns a viewless version of a tensor, i.e., one that no longer retains a reference to the tensor it was viewed from.
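A view tensor keeps its source buffer alive through its ._base reference; a viewless tensor drops that reference. A small illustration; the assertion about the result is an assumption about the utility’s effect, not confirmed behavior:

```python
import torch

from core.pipeline_parallel.utils import make_viewless

x = torch.empty(1024, 1024)
v = x.view(-1)
assert v._base is x  # the view pins the full 1024x1024 buffer in memory

flat = make_viewless(v)
# Assumed effect: `flat` carries the same data but holds no ._base reference,
# so the source buffer can be freed once `x` and `v` go out of scope.
```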

core.pipeline_parallel.utils.stream_acquire_context(stream, event)#

Context manager that runs the enclosed work on a stream, synchronizing with other streams via a CUDA event.
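The signature suggests a wait-then-record pattern around a region of stream work. A minimal sketch of the assumed semantics, not the library implementation:

```python
import contextlib

import torch

@contextlib.contextmanager
def stream_acquire_context_sketch(stream: torch.cuda.Stream, event: torch.cuda.Event):
    # Wait for whichever stream recorded the event last, run the enclosed
    # work on `stream`, then record the event so the next stream can wait.
    stream.wait_event(event)
    with torch.cuda.stream(stream):
        yield
    event.record(stream)
```

Used this way, nodes on the compute and communication streams that share one event execute in a well-defined order without blocking the host.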

class core.pipeline_parallel.utils.NoopScheduleNode#

A placeholder node in the computation graph that simply passes through inputs and outputs.

This class is used as a no-op node in the scheduling system when a real computation node is not needed but the interface must be maintained (e.g., a dense layer doesn’t need moe_dispatch and moe_combine). It simply returns its inputs unchanged in both forward and backward passes.

forward(inputs)#

Passes through inputs unchanged in the forward pass.

backward(outgrads)#

Passes through gradients unchanged in the backward pass.
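Since the node is a pure pass-through, a dense layer can slot it in wherever an MoE layer would place its dispatch/combine nodes, and the scheduler treats both layer types uniformly. A tiny check of the assumed identity behavior:

```python
import torch

from core.pipeline_parallel.utils import NoopScheduleNode

node = NoopScheduleNode()
x = torch.randn(4, 8)
assert torch.equal(node.forward(x), x)   # forward passes inputs through
assert torch.equal(node.backward(x), x)  # backward passes gradients through
```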

class core.pipeline_parallel.utils.ScheduleNode(
forward_func: Callable,
stream: torch.cuda.Stream,
event: torch.cuda.Event,
backward_func: Optional[Callable] = None,
free_input: bool = False,
name: str = 'schedule_node',
)#

Base node for fine-grained scheduling.

This class represents a computational node in the pipeline schedule. It handles the execution of forward and backward operations on a stream.

Initialization

Initialize a schedule node.

Parameters:
  • forward_func (callable) – Function to execute during the forward pass.

  • stream (torch.cuda.Stream) –

    The CUDA stream for this node’s computation. This can be either a ‘compute’ stream or a ‘communicate’ stream.

    • ‘compute’ stream: Used for computational nodes like attention and experts.

    • ‘communicate’ stream: Used for nodes that handle token communication, such as token dispatch and combine operations in MoE layers.

  • event (torch.cuda.Event) – The CUDA event used for synchronization. Each microbatch within a model chunk shares the same event, which is used to manage dependencies between nodes operating on different streams.

  • backward_func (callable, optional) – Function for the backward pass.

  • free_input (bool) – Flag to indicate if the input should be freed after the forward pass.

  • name (str) – Name of the node for debugging purposes.

default_backward_func(outputs, output_grad)#

Default backward function, used when no backward_func is supplied.

forward(inputs=())#

Execute the node’s forward pass on its stream.

_forward(*inputs)#

get_output()#

Get the forward output.

backward(output_grad)#

Execute the node’s backward pass.

_backward(*output_grad)#

get_grad()#

Get the gradients with respect to the node’s inputs.

_release_state()#

Clear the state of the node.
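A hedged end-to-end sketch of one node’s lifecycle; it assumes a CUDA device, and the exact tuple conventions of forward/backward/get_grad are read off the signatures above and may differ in detail:

```python
import torch

from core.pipeline_parallel.utils import ScheduleNode

comp_stream = torch.cuda.Stream()
event = torch.cuda.Event()  # shared by all nodes of one microbatch

def mlp_forward(x):
    return x * 2.0

node = ScheduleNode(mlp_forward, comp_stream, event, name="mlp")

x = torch.randn(4, 8, device="cuda", requires_grad=True)
out = node.forward((x,))                        # runs mlp_forward on comp_stream
node.backward(torch.ones(4, 8, device="cuda"))  # default autograd-based backward
grad = node.get_grad()                          # gradient w.r.t. the node's inputs
```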

class core.pipeline_parallel.utils.AbstractSchedulePlan#

Bases: abc.ABC

To use combined 1F1B, the model must implement build_schedule_plan, which takes the same signature as the model’s forward method but returns an instance of AbstractSchedulePlan.

abstractmethod static run(
f_schedule_plan,
b_schedule_plan,
grad=None,
pre_forward=None,
pre_backward=None,
post_forward=None,
post_backward=None,
)#

run() is the protocol between our scheduling logic and the model; it is used to execute the forward and backward schedule plans for the models.
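A hedged sketch of the contract: the model mirrors its forward signature in build_schedule_plan and returns a plan; the scheduler later pairs one forward plan with one backward plan through run(). All class names and bodies below are illustrative placeholders, not library code:

```python
import torch

from core.pipeline_parallel.utils import AbstractSchedulePlan

class MySchedulePlan(AbstractSchedulePlan):
    @staticmethod
    def run(f_schedule_plan, b_schedule_plan, grad=None,
            pre_forward=None, pre_backward=None,
            post_forward=None, post_backward=None):
        # Interleave the fine-grained forward nodes of f_schedule_plan with
        # the backward nodes of b_schedule_plan; the hooks let the pipeline
        # engine inject communication around each half.
        ...

class MyModel(torch.nn.Module):
    def forward(self, hidden_states, attention_mask):
        ...

    def build_schedule_plan(self, hidden_states, attention_mask):
        # Same signature as forward(), but returns a schedule plan instead
        # of eagerly computing outputs.
        return MySchedulePlan()
```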

core.pipeline_parallel.utils._COMP_STREAM#

None

core.pipeline_parallel.utils._COMM_STREAM#

None

core.pipeline_parallel.utils.set_streams(comp_stream=None, comm_stream=None)#

Set the streams for communication and computation.

core.pipeline_parallel.utils.get_comp_stream()#

Get the stream for computation.

core.pipeline_parallel.utils.get_comm_stream()#

Get the stream for communication.
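A typical setup registers the two streams once at startup and retrieves them wherever nodes are constructed. A short sketch, assuming CUDA is available:

```python
import torch

from core.pipeline_parallel.utils import (
    get_comm_stream,
    get_comp_stream,
    set_streams,
)

set_streams(comp_stream=torch.cuda.Stream(), comm_stream=torch.cuda.Stream())

comp = get_comp_stream()  # for computational nodes (attention, experts)
comm = get_comm_stream()  # for communication nodes (dispatch, combine)
```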