nemo_rl.models.megatron.pipeline_parallel#
Pipeline parallel utilities for Megatron models.
Module Contents#
Functions#
broadcast_obj_from_pp_rank – Broadcast an object across pipeline parallel ranks.
broadcast_loss_metrics_from_last_stage – Broadcast loss metrics from the last pipeline stage to all stages.
broadcast_tensors_from_last_stage – Broadcast multiple tensors from the last pipeline stage to all stages.
API#
- nemo_rl.models.megatron.pipeline_parallel.broadcast_obj_from_pp_rank(obj: Any) -> Any#
Broadcast an object across pipeline parallel ranks.
This utility function handles broadcasting an object from the rank that owns it to all other pipeline parallel ranks. If only one rank has the object (non-None), it will be broadcast to all other ranks.
- Parameters:
obj – The object to broadcast. Can be None on ranks that don’t own it.
- Returns:
The object on all ranks (either the original or the broadcast copy).
- Raises:
ValueError – If the object doesn’t exist on any pipeline parallel rank.
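The ownership semantics can be sketched in plain Python without an initialized torch.distributed group: each pipeline rank holds either the object or None, and after the broadcast every rank holds the single non-None value. The helper and data below are illustrative, not part of the module's API.

```python
from typing import Any, Optional

def simulate_broadcast_obj(per_rank_objs: list[Optional[Any]]) -> list[Any]:
    """Pure-Python model of broadcast_obj_from_pp_rank: one rank owns the
    object (non-None); its value is returned on every rank."""
    owners = [o for o in per_rank_objs if o is not None]
    if not owners:
        # Mirrors the documented ValueError when no PP rank owns the object.
        raise ValueError("Object does not exist on any pipeline parallel rank.")
    # The owning rank's object is broadcast; all ranks see the same value.
    return [owners[0]] * len(per_rank_objs)

# Rank 2 of 4 owns a config dict; after the broadcast every rank sees it.
result = simulate_broadcast_obj([None, None, {"lr": 1e-4}, None])
```

In the real utility each rank calls `broadcast_obj_from_pp_rank(obj)` with its own local value (None on non-owning ranks) and receives the broadcast copy back.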
- nemo_rl.models.megatron.pipeline_parallel.broadcast_loss_metrics_from_last_stage(
- loss_metrics: Optional[list] = None,
- ) -> list#
Broadcast loss metrics from the last pipeline stage to all stages.
This utility handles the common pattern where loss computation happens on the last pipeline stage and needs to be broadcast to all other stages.
- Parameters:
loss_metrics – List of loss metrics if on last stage, None otherwise
- Returns:
List of loss metrics on all ranks
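The pattern this covers can be modeled without a distributed runtime: only the last pipeline stage holds the metrics list, and every earlier stage receives a copy. The simulation below is a sketch; stage count and metric values are made up.

```python
from typing import Optional

def simulate_broadcast_loss_metrics(
    per_stage_metrics: list[Optional[list]],
) -> list[list]:
    """Model of broadcast_loss_metrics_from_last_stage: loss is computed on
    the last PP stage and broadcast so every stage holds the same list."""
    last = per_stage_metrics[-1]  # only the last stage computed the metrics
    return [list(last) for _ in per_stage_metrics]

# Four pipeline stages; only the last one has loss metrics before the call.
stages = [None, None, None, [0.42, 0.38]]
broadcast = simulate_broadcast_loss_metrics(stages)
```

In the real utility each stage passes its local value (the metrics list on the last stage, None elsewhere) and gets the full list back.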
- nemo_rl.models.megatron.pipeline_parallel.broadcast_tensors_from_last_stage(
- tensors: dict[str, Optional[torch.Tensor]],
- pp_group=None,
- ) -> dict[str, torch.Tensor]#
Broadcast multiple tensors from the last pipeline stage to all stages.
- Parameters:
tensors – Dictionary mapping tensor names to tensors (None on non-last stages)
pp_group – Pipeline parallel group (auto-detected if None)
- Returns:
Dictionary of broadcasted tensors on all ranks
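The per-key broadcast can likewise be sketched in plain Python. Plain lists stand in for torch.Tensor, and the tensor names and stage layout below are illustrative assumptions.

```python
from typing import Optional

def simulate_broadcast_tensors(
    per_stage: list[dict[str, Optional[list]]],
) -> list[dict[str, list]]:
    """Model of broadcast_tensors_from_last_stage: the last PP stage owns the
    real tensors; every stage ends up with a dict holding copies of them.
    (Plain lists stand in for torch.Tensor in this sketch.)"""
    last = per_stage[-1]  # the last stage holds the non-None values
    return [{name: list(val) for name, val in last.items()} for _ in per_stage]

per_stage = [
    {"logprobs": None, "entropy": None},               # earlier stage: None
    {"logprobs": [0.1, 0.2], "entropy": [1.5, 1.3]},   # last stage owns data
]
out = simulate_broadcast_tensors(per_stage)
```

In the real utility each stage passes the same set of keys (values None on non-last stages) and receives a fully populated dict of tensors.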