bridge.training.utils.pg_utils

Module Contents

Classes

DistTrainProcessGroupCollection

Process group collection for dist train.

Functions

get_pg_collection

Return the ProcessGroupCollection from a model or list of model chunks.

API

bridge.training.utils.pg_utils.get_pg_collection(
model: Union[megatron.core.transformer.MegatronModule, list[megatron.core.transformer.MegatronModule]],
) → megatron.core.process_groups_config.ProcessGroupCollection

Return the ProcessGroupCollection from a model or list of model chunks.

This mirrors utility accessors such as get_model_config, but retrieves the communication process group collection from the model wrapper.

Parameters:

model – A MegatronModule or a list of MegatronModule chunks.

Returns:

The model’s process group collection.

Return type:

ProcessGroupCollection

class bridge.training.utils.pg_utils.DistTrainProcessGroupCollection(
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
language_model_module_name: Optional[str] = None,
)

Bases: megatron.core.process_groups_config.ProcessGroupCollection

Process group collection for dist train.

Initialization

Initialize the dist train process group collection.

Parameters:
  • pg_collection – The process group collection.

  • language_model_module_name – The name of the language model module.

get_language_model_cp_size() → int

Get context parallel size for the language model.

Returns:

Context parallel size for the language model.

Raises:

ValueError – If no language model is specified for this collection.
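The documented contract of get_language_model_cp_size (return the language model's context parallel size, or raise ValueError when no language model module name was given) can be mirrored by a small stand-in. This is a hedged sketch only: the `cp_sizes` mapping from module name to context parallel size, the class name `LanguageModelCPSketch`, and the module names `"vision_encoder"` and `"language_model"` are hypothetical and do not describe the real internals of DistTrainProcessGroupCollection.

```python
from typing import Dict, Optional


class LanguageModelCPSketch:
    """Hypothetical stand-in mirroring only the documented behavior.

    `cp_sizes` (module name -> context parallel size) is an invented
    representation chosen for illustration.
    """

    def __init__(
        self,
        cp_sizes: Dict[str, int],
        language_model_module_name: Optional[str] = None,
    ) -> None:
        self.cp_sizes = cp_sizes
        self.language_model_module_name = language_model_module_name

    def get_language_model_cp_size(self) -> int:
        # No language model was named for this collection: raise, as documented.
        if self.language_model_module_name is None:
            raise ValueError("No language model specified for this collection.")
        return self.cp_sizes[self.language_model_module_name]


# Usage with hypothetical module names.
pgs = LanguageModelCPSketch({"vision_encoder": 1, "language_model": 2}, "language_model")
assert pgs.get_language_model_cp_size() == 2
```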

has_language_model() → bool

Check if this rank has a language model.

Returns:

True if this rank has a language model, False otherwise.