nemo_rl.distributed.collectives#
Module Contents#
Functions#
- `rebalance_nd_tensor` – Takes tensors with variable leading sizes (at dim=0) and stacks them into a single tensor.
- `gather_jagged_object_lists` – Gathers jagged lists of picklable objects from all ranks and flattens them into a single list.
API#
- nemo_rl.distributed.collectives.rebalance_nd_tensor(tensor: torch.Tensor, group: Optional[torch.distributed.ProcessGroup] = None)
Takes tensors with variable leading sizes (at dim=0) and stacks them into a single tensor.
This function handles the case where different GPUs have tensors with different batch sizes and combines them into a single balanced tensor across all ranks.
For example, with 3 GPUs:
- GPU0: tensor of shape [3, D]
- GPU1: tensor of shape [5, D]
- GPU2: tensor of shape [2, D]
After rebalancing: All GPUs will have the same tensor of shape [10, D] (3+5+2=10)
NOTE: assumes all dimensions other than dim=0 are equal across ranks.
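A minimal usage sketch, assuming the module is importable as `nemo_rl.distributed.collectives` and the script is launched with `torchrun` so `torch.distributed` can initialize; the tensor shapes and the `gloo` backend are illustrative only:

```python
import torch
import torch.distributed as dist

from nemo_rl.distributed.collectives import rebalance_nd_tensor

# Join the default (world) process group; use "nccl" for CUDA tensors.
dist.init_process_group(backend="gloo")
rank = dist.get_rank()

# Each rank holds a different number of rows: the leading (dim=0) size varies,
# while the trailing dimension (8 here) is identical across ranks, as the NOTE requires.
local = torch.randn(rank + 2, 8)

# After the call, every rank holds the same stacked tensor of shape [sum(rows), 8].
full = rebalance_nd_tensor(local, group=None)
print(f"rank {rank}: local {tuple(local.shape)} -> rebalanced {tuple(full.shape)}")

dist.destroy_process_group()
```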
- nemo_rl.distributed.collectives.gather_jagged_object_lists(local_objects: list, group: Optional[torch.distributed.ProcessGroup] = None)
Gathers jagged lists of picklable objects from all ranks and flattens them into a single list.
This function handles the case where different GPUs have lists of different lengths and combines them into a single list containing all objects from all ranks.
For example, with 3 GPUs:
- GPU0: [obj0, obj1]
- GPU1: [obj2, obj3, obj4]
- GPU2: [obj5]
After gathering: All GPUs will have: [obj0, obj1, obj2, obj3, obj4, obj5]
WARNING: this is a synchronous (blocking) collective; every rank in the group must call it.
- Parameters:
local_objects – List of objects to gather from the current rank
group – Optional process group
- Returns:
Flattened list of all objects from all ranks in order [rank0, rank1, …]
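A minimal usage sketch, under the same assumptions as above (launched with `torchrun`, `nemo_rl.distributed.collectives` importable); the per-rank objects are illustrative and only need to be picklable:

```python
import torch.distributed as dist

from nemo_rl.distributed.collectives import gather_jagged_object_lists

dist.init_process_group(backend="gloo")
rank = dist.get_rank()

# Ranks contribute lists of different lengths (rank 0 has 1 item, rank 1 has 2, ...).
local_objects = [{"rank": rank, "item": i} for i in range(rank + 1)]

# Every rank receives the same flattened list, ordered by rank: rank0's items first, then rank1's, ...
all_objects = gather_jagged_object_lists(local_objects, group=None)
print(f"rank {rank}: gathered {len(all_objects)} objects: {all_objects}")

dist.destroy_process_group()
```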