nemo_rl.models.megatron.common#

Module Contents#

Functions#

_round_up_to_multiple

broadcast_tensor

Broadcasts a tensor from src_rank to all ranks in the group using broadcast_object_list for metadata.

get_moe_metrics

Returns Mixture of Experts (MoE) auxiliary-loss metrics.

API#

nemo_rl.models.megatron.common._round_up_to_multiple(value: int, multiple: int) → int#
nemo_rl.models.megatron.common.broadcast_tensor(
tensor: torch.Tensor | None,
src_rank: int,
group: torch.distributed.ProcessGroup,
) → torch.Tensor#

Broadcasts a tensor from src_rank to all ranks in the group using broadcast_object_list for metadata.

Handles the case where the input tensor might be None on non-source ranks. If the input tensor is provided on non-source ranks, it must have the correct shape and dtype matching the tensor on the source rank.

Parameters:
  • tensor – The tensor to broadcast on the source rank. Can be None on non-source ranks (will be created with correct shape/dtype). If not None on non-source ranks, it’s used as the buffer for the broadcast and must match the source tensor’s metadata.

  • src_rank (int) – The global rank of the source process.

  • group – The process group for communication.

Returns:

The broadcasted tensor. On non-source ranks, this will be the tensor received from the source.

Return type:

torch.Tensor

Raises:
  • ValueError – If the tensor is None on the source rank, or if a tensor provided on a non-source rank has mismatched shape/dtype/device.

  • TypeError – If broadcasting metadata fails (e.g., due to pickling issues).

nemo_rl.models.megatron.common.get_moe_metrics(
loss_scale: float,
total_loss_dict: Optional[dict] = None,
per_layer_logging: bool = False,
) → dict[str, Any]#

Returns Mixture of Experts (MoE) auxiliary-loss metrics.

This function reduces MoE auxiliary losses across ranks, aggregates them, and returns a dictionary of metrics.

Parameters:
  • loss_scale – Scale factor to apply to each auxiliary loss (e.g., 1/num_microbatches).

  • total_loss_dict – If provided, accumulate means into this dict (by name).

  • per_layer_logging – If True, include per-layer values in the returned dict.

Returns:

A flat dict of aggregated metrics. For each aux loss name, the mean value is returned under the same key (e.g., “load_balancing_loss”). If per_layer_logging is True, per-layer values are returned under keys of the form “moe/{name}_layer_{i}”.

Return type:

dict[str, Any]
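The aggregation described above can be illustrated with a standalone sketch. The real function pulls the auxiliary losses from Megatron's internal MoE tracker and reduces them across ranks; here a hypothetical `per_layer_losses` dict stands in for that input, and the key format follows the docstring:

```python
from typing import Any, Optional

import torch


def aggregate_moe_metrics(
    per_layer_losses: dict[str, torch.Tensor],  # name -> per-layer loss values
    loss_scale: float,
    total_loss_dict: Optional[dict] = None,
    per_layer_logging: bool = False,
) -> dict[str, Any]:
    # Mirrors the documented output of get_moe_metrics: one mean per aux
    # loss name, optional per-layer entries, optional accumulation into
    # total_loss_dict.
    metrics: dict[str, Any] = {}
    for name, layer_vals in per_layer_losses.items():
        scaled = layer_vals * loss_scale  # e.g. loss_scale = 1/num_microbatches
        mean_val = scaled.mean().item()
        metrics[name] = mean_val
        if total_loss_dict is not None:
            # Accumulate means by name across calls.
            total_loss_dict[name] = total_loss_dict.get(name, 0.0) + mean_val
        if per_layer_logging:
            for i, v in enumerate(scaled.tolist()):
                metrics[f"moe/{name}_layer_{i}"] = v
    return metrics
```

With two layers reporting load-balancing losses of 2.0 and 4.0 and `loss_scale=0.5`, the returned dict would hold a mean of 1.5 under `"load_balancing_loss"`, plus the individual scaled values when per-layer logging is enabled.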