nemo_rl.distributed.collectives#

Module Contents#

Functions#

rebalance_nd_tensor

Takes tensors with variable leading sizes (at dim=0) and stacks them into a single tensor.

gather_jagged_object_lists

Gathers jagged lists of picklable objects from all ranks and flattens them into a single list.

API#

nemo_rl.distributed.collectives.rebalance_nd_tensor(
tensor: torch.Tensor,
group: Optional[torch.distributed.ProcessGroup] = None,
)[source]#

Takes tensors with variable leading sizes (at dim=0) and stacks them into a single tensor.

This function handles the case where different GPUs have tensors with different batch sizes and combines them into a single balanced tensor across all ranks.

For example, with 3 GPUs:

  • GPU0: tensor of shape [3, D]

  • GPU1: tensor of shape [5, D]

  • GPU2: tensor of shape [2, D]

After rebalancing, all GPUs hold the same tensor of shape [10, D] (3 + 5 + 2 = 10).

NOTE: assumes all dimensions other than dim 0 are equal across ranks.
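A minimal usage sketch (not from the source): it assumes torch.distributed has already been initialized (e.g., via torchrun), that each rank has a CUDA device, and that the function returns the stacked tensor as the example above describes.

```python
import torch
import torch.distributed as dist

from nemo_rl.distributed.collectives import rebalance_nd_tensor

# Assumes the process group was initialized elsewhere, e.g.:
#   dist.init_process_group(backend="nccl")
rank = dist.get_rank()

# Each rank holds a different number of rows; the trailing dimension (16) matches on every rank.
local_batch = torch.randn(rank + 2, 16, device="cuda")

# Every rank receives the same tensor containing all rows from all ranks stacked along dim 0.
full_batch = rebalance_nd_tensor(local_batch)
# With 3 ranks, local sizes are 2, 3, and 4 rows -> full_batch.shape == (9, 16) on every rank.
```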

nemo_rl.distributed.collectives.gather_jagged_object_lists(
local_objects: list,
group: Optional[torch.distributed.ProcessGroup] = None,
)[source]#

Gathers jagged lists of picklable objects from all ranks and flattens them into a single list.

This function handles the case where different GPUs have lists of different lengths and combines them into a single list containing all objects from all ranks.

For example, with 3 GPUs:

  • GPU0: [obj0, obj1]

  • GPU1: [obj2, obj3, obj4]

  • GPU2: [obj5]

After gathering, all GPUs will have: [obj0, obj1, obj2, obj3, obj4, obj5]

WARNING: This is a synchronous (blocking) collective; every rank in the process group must call it.

Parameters:
  • local_objects – List of picklable objects contributed by the current rank

  • group – Optional process group to gather over; the default process group is used if not provided

Returns:

Flattened list of all objects from all ranks in order [rank0, rank1, …]
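A minimal usage sketch (hypothetical, not from the source), assuming torch.distributed has already been initialized and the default process group is used when group is omitted:

```python
import torch.distributed as dist

from nemo_rl.distributed.collectives import gather_jagged_object_lists

# Assumes the process group was initialized elsewhere, e.g. via torchrun.
rank = dist.get_rank()

# Each rank contributes a list of a different length; the objects only need to be picklable.
local_records = [{"rank": rank, "index": i} for i in range(rank + 1)]

# Blocking collective: every rank must call it and receives the flattened,
# rank-ordered list (rank 0's objects first, then rank 1's, and so on).
all_records = gather_jagged_object_lists(local_records)
```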