nemo_rl.models.megatron.pipeline_parallel#

Pipeline parallel utilities for Megatron models.

Module Contents#

Functions#

broadcast_obj_from_pp_rank

Broadcast an object across pipeline parallel ranks.

broadcast_loss_metrics_from_last_stage

Broadcast loss metrics from the last pipeline stage to all stages.

broadcast_tensors_from_last_stage

Broadcast multiple tensors from the last pipeline stage to all stages.

API#

nemo_rl.models.megatron.pipeline_parallel.broadcast_obj_from_pp_rank(obj: Any) → Any#

Broadcast an object across pipeline parallel ranks.

This utility broadcasts an object from the rank that owns it to every other pipeline parallel rank: if only one rank holds a non-None object, that object is broadcast to all other ranks.

Parameters:

obj – The object to broadcast. Can be None on ranks that don’t own it.

Returns:

The object on all ranks (either the original or the broadcast copy).

Raises:

ValueError – If the object doesn’t exist on any pipeline parallel rank.
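The ownership contract can be sketched in a single process, with a plain list standing in for the per-rank view of the pipeline parallel group. The helper name below is illustrative and not part of the nemo_rl API; a real implementation would use torch.distributed collectives.

```python
# Single-process sketch of the broadcast_obj_from_pp_rank contract:
# one rank owns a non-None object, and after the call every rank
# holds that object. A list models the per-rank view.

def broadcast_obj_from_owner(objs_per_rank):
    """Return the sole non-None object to every simulated rank."""
    owned = [o for o in objs_per_rank if o is not None]
    if not owned:
        # Mirrors the documented ValueError: no rank owns the object.
        raise ValueError("object does not exist on any pipeline parallel rank")
    return [owned[0] for _ in objs_per_rank]

# Rank 2 owns the object; ranks 0, 1, and 3 receive the broadcast copy.
print(broadcast_obj_from_owner([None, None, {"step": 3}, None]))
```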

nemo_rl.models.megatron.pipeline_parallel.broadcast_loss_metrics_from_last_stage(
loss_metrics: Optional[list] = None,
) → list#

Broadcast loss metrics from the last pipeline stage to all stages.

This utility handles the common pattern where loss computation happens on the last pipeline stage and needs to be broadcast to all other stages.

Parameters:

loss_metrics – List of loss metrics if on the last stage, None otherwise.

Returns:

List of loss metrics on all ranks
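The last-stage pattern this helper supports can be sketched in a single process, with one list entry per pipeline stage. The function and variable names below are illustrative stand-ins, not part of nemo_rl.

```python
# Single-process sketch of the last-stage broadcast pattern: only the
# final pipeline stage computes loss metrics; every other stage passes
# None and receives a copy of the last stage's list.

def broadcast_loss_metrics_from_last(metrics_per_stage):
    """Return the last stage's metrics list on every simulated stage."""
    last = metrics_per_stage[-1]
    if last is None:
        raise ValueError("last pipeline stage must provide loss metrics")
    return [list(last) for _ in metrics_per_stage]

# Stages 0..2 forward activations only; stage 3 computes the loss.
per_stage = [None, None, None, [0.42, 1024]]  # e.g. [loss, num_tokens]
print(broadcast_loss_metrics_from_last(per_stage))
```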

nemo_rl.models.megatron.pipeline_parallel.broadcast_tensors_from_last_stage(
tensors: dict[str, Optional[torch.Tensor]],
) → dict[str, torch.Tensor]#

Broadcast multiple tensors from the last pipeline stage to all stages.

Parameters:

tensors – Dictionary mapping tensor names to tensors (None on non-last stages)

pp_group – Pipeline parallel group (auto-detected if None)

Returns:

Dictionary of broadcast tensors on all ranks
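The multi-tensor variant follows the same pattern, keyed by name. The sketch below uses plain lists as stand-ins for torch.Tensor so the control flow is visible in a single process; all names are illustrative, not part of the nemo_rl API.

```python
# Single-process sketch of the multi-tensor broadcast: the last stage
# owns a dict of named "tensors" (lists stand in for torch.Tensor),
# and every stage ends up with a copy of that dict.

def broadcast_tensor_dict_from_last(dicts_per_stage):
    """Return the last stage's name -> tensor dict on every stage."""
    last = dicts_per_stage[-1]
    if last is None or any(v is None for v in last.values()):
        raise ValueError("last stage must own every tensor being broadcast")
    return [dict(last) for _ in dicts_per_stage]

owned = {"logprobs": [0.1, 0.2], "entropy": [1.3]}
stages = [None, None, owned]  # stages 0 and 1 hold nothing
print(broadcast_tensor_dict_from_last(stages))
```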