bridge.training.utils.pg_utils
Module Contents
Classes
DistTrainProcessGroupCollection | Process group collection for dist train.
Functions
get_pg_collection | Return the ProcessGroupCollection from a model or list of model chunks.
API
- bridge.training.utils.pg_utils.get_pg_collection(model: Union[megatron.core.transformer.MegatronModule, list[megatron.core.transformer.MegatronModule]]) → ProcessGroupCollection
Return the ProcessGroupCollection from a model or list of model chunks.
This mirrors the style of utility accessors such as get_model_config, but retrieves the communication process group collection from the model wrapper.
- Parameters:
model – A MegatronModule or a list of MegatronModule chunks.
- Returns:
The model’s process group collection.
- Return type:
ProcessGroupCollection
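The accessor pattern described above can be sketched in plain Python. This is a minimal sketch, not the library's implementation. It rests on three assumptions:

- The collection is stored as a `pg_collection` attribute on the underlying module.
- Wrappers expose the wrapped module through `.module`, as DDP-style wrappers commonly do.
- All chunks in a list share one collection, so the first chunk is used.

`FakeModule`, `FakeWrapper`, and `get_pg_collection_sketch` are hypothetical stand-ins.

```python
from typing import Any, Union


class FakeModule:
    """Hypothetical stand-in for a MegatronModule that carries a pg_collection."""

    def __init__(self, pg_collection: Any) -> None:
        self.pg_collection = pg_collection


class FakeWrapper:
    """Hypothetical stand-in for a wrapper (e.g. DDP) that exposes `.module`."""

    def __init__(self, module: Any) -> None:
        self.module = module


def get_pg_collection_sketch(model: Union[Any, list]) -> Any:
    # Assumption: every model chunk shares one collection, so use the first.
    if isinstance(model, list):
        if not model:
            raise ValueError("empty list of model chunks")
        model = model[0]
    # Peel wrapper layers until a module carrying `pg_collection` is found.
    while not hasattr(model, "pg_collection"):
        if not hasattr(model, "module"):
            raise AttributeError("no pg_collection found on model or its wrappers")
        model = model.module
    return model.pg_collection


collection = object()  # placeholder for a ProcessGroupCollection
chunks = [FakeWrapper(FakeModule(collection)), FakeWrapper(FakeModule(collection))]
print(get_pg_collection_sketch(chunks) is collection)  # True
```

Accepting either a single module or a list lets callers pass virtual-pipeline model chunks unchanged, in the same way `get_model_config` is used.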
- class bridge.training.utils.pg_utils.DistTrainProcessGroupCollection(pg_collection: megatron.core.process_groups_config.ProcessGroupCollection, language_model_module_name: Optional[str] = None)
Bases: megatron.core.process_groups_config.ProcessGroupCollection
Process group collection for dist train.
Initialization
Initialize the dist train process group collection.
- Parameters:
pg_collection – The process group collection.
language_model_module_name – The name of the language model module.
- get_language_model_cp_size() → int
Get the context parallel (CP) size of the language model.
- Returns:
The context parallel size of the language model.
- Raises:
ValueError – If no language model is specified for this collection.
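The documented contract of `get_language_model_cp_size()` can be illustrated with a minimal sketch that makes three assumptions:

- The constructor simply keeps the wrapped collection and the module name.
- The CP size is the size of a `cp` group on the wrapped collection.
- The `ValueError` fires when no language-model module name was given.

These internals are assumptions, not confirmed by this page. `FakeGroup`, `FakeCollection`, `DistTrainPGSketch`, and the module name `"language_model"` are hypothetical. `FakeGroup.size()` mimics the `size()` method of a `torch.distributed` process group.

```python
from typing import Optional


class FakeGroup:
    """Hypothetical stand-in for a torch.distributed ProcessGroup."""

    def __init__(self, world_size: int) -> None:
        self._world_size = world_size

    def size(self) -> int:
        return self._world_size


class FakeCollection:
    """Hypothetical stand-in for ProcessGroupCollection holding only a `cp` group."""

    def __init__(self, cp: FakeGroup) -> None:
        self.cp = cp


class DistTrainPGSketch:
    """Hypothetical sketch of the documented behavior, not the real class."""

    def __init__(self, pg_collection: FakeCollection,
                 language_model_module_name: Optional[str] = None) -> None:
        self.pg_collection = pg_collection
        self.language_model_module_name = language_model_module_name

    def get_language_model_cp_size(self) -> int:
        # Mirrors the documented contract: fail loudly when no LM is attached.
        if self.language_model_module_name is None:
            raise ValueError("No language model specified for this collection")
        # Assumption: the LM's CP size equals the size of the wrapped `cp` group.
        return self.pg_collection.cp.size()


pgs = DistTrainPGSketch(FakeCollection(cp=FakeGroup(2)), "language_model")
print(pgs.get_language_model_cp_size())  # 2
```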
- has_language_model() → bool
Check whether this rank has a language model.
- Returns:
True if this rank has a language model, False otherwise.
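Because `get_language_model_cp_size()` raises `ValueError` when no language model is specified, a caller can check `has_language_model()` first. In the sketch below, `has_language_model()` is assumed to check only whether a language-model module name was supplied; the real method may decide differently. `StubCollection` and `lm_cp_size_or_none` are hypothetical names, not part of the API.

```python
from typing import Optional


class StubCollection:
    """Hypothetical stand-in exposing the two documented methods."""

    def __init__(self, language_model_module_name: Optional[str], cp_size: int = 1) -> None:
        self.language_model_module_name = language_model_module_name
        self._cp_size = cp_size

    def has_language_model(self) -> bool:
        # Assumption: "has a language model" means a module name was supplied.
        return self.language_model_module_name is not None

    def get_language_model_cp_size(self) -> int:
        if not self.has_language_model():
            raise ValueError("No language model specified for this collection")
        return self._cp_size


def lm_cp_size_or_none(pgs: StubCollection) -> Optional[int]:
    # Guard first so ranks without a language model skip the call
    # instead of having to catch a ValueError.
    if not pgs.has_language_model():
        return None
    return pgs.get_language_model_cp_size()


print(lm_cp_size_or_none(StubCollection("language_model", cp_size=4)))  # 4
print(lm_cp_size_or_none(StubCollection(None)))  # None
```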