nemo_rl.models.megatron.common
Module Contents
Functions
broadcast_tensor | Broadcasts a tensor from src_rank to all ranks in the group using broadcast_object_list for metadata. |
get_moe_metrics | Returns Mixture of Experts (MoE) auxiliary-loss metrics. |
API
- nemo_rl.models.megatron.common._round_up_to_multiple(value: int, multiple: int) -> int
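This helper has no docstring. Judging only from its name and signature, it most likely returns the smallest multiple of `multiple` that is greater than or equal to `value`. The sketch below implements that assumed behavior with integer ceiling division; it is not taken from the NeMo RL source, and the name `round_up_to_multiple_sketch` and the positivity check are hypothetical.

```python
def round_up_to_multiple_sketch(value: int, multiple: int) -> int:
    """Assumed behavior: smallest multiple of `multiple` that is >= `value`."""
    if multiple <= 0:
        # Hypothetical guard; the documented helper's handling is unknown.
        raise ValueError("multiple must be a positive integer")
    # -(-a // b) equals ceil(a / b) for b > 0, using only integer arithmetic.
    return -(-value // multiple) * multiple
```

For example, `round_up_to_multiple_sketch(10, 8)` returns 16, while a value that is already a multiple, such as 16, is returned unchanged. Integer ceiling division avoids the rounding error that `math.ceil(value / multiple)` can introduce for very large integers.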
- nemo_rl.models.megatron.common.broadcast_tensor(tensor: torch.Tensor | None, src_rank: int, group: torch.distributed.ProcessGroup) -> torch.Tensor
Broadcasts a tensor from src_rank to all ranks in the group using broadcast_object_list for metadata.
Handles the case where the input tensor is None on non-source ranks. If a tensor is provided on a non-source rank, its shape and dtype must match those of the tensor on the source rank.
- Parameters:
tensor – The tensor to broadcast on the source rank. Can be None on non-source ranks (will be created with correct shape/dtype). If not None on non-source ranks, it’s used as the buffer for the broadcast and must match the source tensor’s metadata.
src_rank (int) – The global rank of the source process.
group – The process group for communication.
- Returns:
The broadcasted tensor. On non-source ranks, this will be the tensor received from the source.
- Return type:
torch.Tensor
- Raises:
ValueError – If the tensor is None on the source rank, or if a tensor provided on a non-source rank has mismatched shape/dtype/device.
TypeError – If broadcasting metadata fails (e.g., due to pickling issues).
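The two-phase pattern described above (send the shape and dtype as picklable metadata with `torch.distributed.broadcast_object_list`, then broadcast the payload into a matching buffer) can be sketched as follows. This is a minimal illustration, not NeMo RL's implementation: the names `broadcast_tensor_sketch` and `run_single_process_demo` and the `device` parameter are hypothetical, and the sketch assumes CPU tensors with the gloo backend (with NCCL, the receive buffer would have to live on the rank's CUDA device). Unlike the documented function, it does not check devices or convert pickling failures into `TypeError`.

```python
from __future__ import annotations

import os
import tempfile

import torch
import torch.distributed as dist


def broadcast_tensor_sketch(
    tensor: torch.Tensor | None,
    src_rank: int,
    group: dist.ProcessGroup,
    device: torch.device | str = "cpu",
) -> torch.Tensor:
    """Hypothetical two-phase broadcast: metadata first, then the payload.

    Assumes torch.distributed is initialized and `src_rank` is a global rank
    in `group`. `device` (hypothetical) is where receive buffers are allocated.
    """
    is_src = dist.get_rank() == src_rank
    if is_src:
        if tensor is None:
            raise ValueError("tensor must not be None on the source rank")
        meta = [(tuple(tensor.shape), tensor.dtype)]
    else:
        meta = [None]

    # Phase 1: every rank learns the shape and dtype (pickled Python objects).
    dist.broadcast_object_list(meta, src=src_rank, group=group)
    shape, dtype = meta[0]

    if not is_src:
        if tensor is None:
            # Allocate a receive buffer that matches the source metadata.
            tensor = torch.empty(shape, dtype=dtype, device=device)
        elif tuple(tensor.shape) != shape or tensor.dtype != dtype:
            raise ValueError("receive buffer does not match source shape/dtype")

    # Phase 2: broadcast the data in place into each rank's buffer.
    dist.broadcast(tensor, src=src_rank, group=group)
    return tensor


def run_single_process_demo() -> torch.Tensor:
    """Runs the sketch in a one-rank gloo group where rank 0 is the source."""
    fd, path = tempfile.mkstemp()
    os.close(fd)
    dist.init_process_group(
        "gloo", init_method=f"file://{path}", rank=0, world_size=1
    )
    try:
        src = torch.arange(6, dtype=torch.float32).reshape(2, 3)
        return broadcast_tensor_sketch(src, src_rank=0, group=dist.group.WORLD)
    finally:
        dist.destroy_process_group()
```

Sending metadata first lets receivers allocate a correctly sized buffer without knowing the shape in advance. A receiver that raises after the metadata step leaves the other ranks blocked in the payload broadcast, so mismatches are best avoided rather than relied on as recoverable errors.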
- nemo_rl.models.megatron.common.get_moe_metrics(loss_scale: float, total_loss_dict: Optional[dict] = None, per_layer_logging: bool = False) -> dict[str, Any]
Returns Mixture of Experts (MoE) auxiliary-loss metrics.
This function reduces MoE auxiliary losses across ranks, aggregates them, and returns a dictionary of metrics.
- Parameters:
loss_scale – Scale factor to apply to each auxiliary loss (e.g., 1/num_microbatches).
total_loss_dict – If provided, the mean of each auxiliary loss is accumulated into this dict under the loss name.
per_layer_logging – If True, include per-layer values in the returned dict.
- Returns:
A flat dict of aggregated metrics. For each aux loss name, the mean value is returned under the same key (e.g., “load_balancing_loss”). If per_layer_logging is True, per-layer values are returned under keys of the form “moe/{name}_layer_{i}”.
- Return type:
dict[str, Any]
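Once the per-layer losses have been reduced across ranks, the remaining aggregation (scaling, averaging, optional accumulation, and optional per-layer keys) can be sketched in plain Python. This illustrates the described behavior under assumptions and is not the library's code: `moe_metrics_sketch` and its `per_layer_losses` input (one already-reduced value per MoE layer, keyed by loss name) are hypothetical. The sketch also assumes that the mean is taken over layers after scaling by `loss_scale`, and that per-layer keys follow the `moe/{name}_layer_{i}` pattern.

```python
from __future__ import annotations

from typing import Any, Optional


def moe_metrics_sketch(
    per_layer_losses: dict[str, list[float]],
    loss_scale: float,
    total_loss_dict: Optional[dict] = None,
    per_layer_logging: bool = False,
) -> dict[str, Any]:
    """Hypothetical post-reduction step of MoE aux-loss metric aggregation.

    `per_layer_losses` is assumed to be already reduced across ranks.
    """
    metrics: dict[str, Any] = {}
    for name, layer_values in per_layer_losses.items():
        # Apply the scale factor (e.g. 1/num_microbatches) to every layer.
        scaled = [v * loss_scale for v in layer_values]
        # Mean over layers; an empty list is reported as 0.0 in this sketch.
        mean = sum(scaled) / len(scaled) if scaled else 0.0
        metrics[name] = mean
        if total_loss_dict is not None:
            # Accumulate the mean under the loss name across calls.
            total_loss_dict[name] = total_loss_dict.get(name, 0.0) + mean
        if per_layer_logging:
            for i, value in enumerate(scaled):
                metrics[f"moe/{name}_layer_{i}"] = value
    return metrics
```

Passing the same `total_loss_dict` on every call lets a training loop keep a running sum per loss name, while the returned dict holds only the current call's values.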