core.process_groups_config#
Dataclasses for organizing model parallelism and gradient communication process groups.
Module Contents#
Classes#
- ProcessGroupHelperMeta: Metaclass to protect virtual_pipeline_model_parallel_size from direct assignment.
- ProcessGroupCollection: Unified process group collection for transformer model parallelism, gradient communication, and finalization.
- MultiModuleProcessGroupCollection: Process group collection for multi-module pipelines.
API#
- class core.process_groups_config.ProcessGroupHelperMeta#
Bases: type
Metaclass to protect virtual_pipeline_model_parallel_size from direct assignment.
- __setattr__(name, value)#
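The protection described above can be illustrated with a minimal sketch (not the Megatron-Core implementation): a metaclass whose `__setattr__` rejects direct class-level assignment of a guarded attribute. The class names below are hypothetical stand-ins.

```python
# Sketch of a metaclass that blocks direct assignment to a protected
# class attribute, mirroring how ProcessGroupHelperMeta guards
# virtual_pipeline_model_parallel_size. Names here are illustrative.


class ProtectedAttrMeta(type):
    """Reject direct class-level assignment of a protected attribute."""

    _PROTECTED = "virtual_pipeline_model_parallel_size"

    def __setattr__(cls, name, value):
        if name == cls._PROTECTED:
            raise AttributeError(
                f"{name} cannot be assigned directly; "
                "use the dedicated setter instead."
            )
        super().__setattr__(name, value)


class Helper(metaclass=ProtectedAttrMeta):
    # Set at class-creation time, which bypasses the metaclass __setattr__.
    virtual_pipeline_model_parallel_size = None
```

With this sketch, `Helper.some_flag = 1` succeeds, while `Helper.virtual_pipeline_model_parallel_size = 2` raises AttributeError.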
- class core.process_groups_config.ProcessGroupCollection(**kwargs)#
Unified process group collection for transformer model parallelism, gradient communication, and finalization.
Fields use init=False and must be set after instance creation.
- Parameters:
Model Parallelism Groups
tp – Tensor parallel process group
pp – Pipeline parallel process group
mp – Model parallel group (tensor + pipeline)
embd – Embedding process group
pos_embd – Position embedding process group
cp – Context parallel process group
tp_cp – Tensor and context parallel group
hcp – Hierarchical context parallel groups
ep – Expert model parallel group
expt_tp – Expert tensor parallel group
tp_ep – Tensor and expert parallel group
tp_ep_pp – Tensor, expert, and pipeline parallel group
Data Parallelism Groups
dp – Data parallel process group
dp_cp – Data and context parallel group
expt_dp – Expert data parallel group
intra_dp_cp – Intra partial data parallel group
intra_expt_dp – Intra partial expert data parallel group
inter_dist_opt – Inter distributed optimizer instance group
Example
Create an instance and set the needed process groups:

pgs = ProcessGroupCollection()
pgs.tp = tp_group
pgs.pp = pp_group
pgs.dp = dp_group

Pass it to model components:

model = TransformerModel(..., pg_collection=pgs)
ddp_model = DistributedDataParallel(..., pg_collection=pgs)
finalize_model_grads(..., pg_collection=pgs)
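The "fields use init=False and must be set after instance creation" behavior can be sketched with a plain dataclass. This is an illustrative stand-in, not the Megatron-Core class; the field names and types are simplified assumptions.

```python
# Sketch of the "set fields after creation" pattern: dataclass fields
# declared with init=False are not constructor arguments and are assigned
# on the instance afterwards. GroupCollectionSketch is a hypothetical name.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class GroupCollectionSketch:
    # Not settable via __init__; populated after instantiation.
    tp: Optional[object] = field(init=False, default=None)
    dp: Optional[object] = field(init=False, default=None)


pgs = GroupCollectionSketch()   # no constructor arguments accepted
pgs.tp = "tp_group_handle"      # assigned after creation
pgs.dp = "dp_group_handle"
```

Declaring the fields with init=False keeps construction trivial while still giving each attribute a documented type annotation.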
Initialization
- tp: torch.distributed.ProcessGroup#
‘field(…)’
- pp: torch.distributed.ProcessGroup#
‘field(…)’
- mp: torch.distributed.ProcessGroup#
‘field(…)’
- embd: torch.distributed.ProcessGroup#
‘field(…)’
- pos_embd: torch.distributed.ProcessGroup#
‘field(…)’
- cp: torch.distributed.ProcessGroup#
‘field(…)’
- tp_cp: torch.distributed.ProcessGroup#
‘field(…)’
- hcp: List[torch.distributed.ProcessGroup]#
‘field(…)’
- ep: torch.distributed.ProcessGroup#
‘field(…)’
- expt_tp: torch.distributed.ProcessGroup#
‘field(…)’
- tp_ep: torch.distributed.ProcessGroup#
‘field(…)’
- tp_ep_pp: torch.distributed.ProcessGroup#
‘field(…)’
- tp_dp_cp: torch.distributed.ProcessGroup#
‘field(…)’
- dp: torch.distributed.ProcessGroup#
‘field(…)’
- dp_cp: torch.distributed.ProcessGroup#
‘field(…)’
- expt_dp: torch.distributed.ProcessGroup#
‘field(…)’
- intra_dp_cp: torch.distributed.ProcessGroup#
‘field(…)’
- intra_expt_dp: torch.distributed.ProcessGroup#
‘field(…)’
- inter_dist_opt: torch.distributed.ProcessGroup#
‘field(…)’
- intra_dist_opt: torch.distributed.ProcessGroup#
‘field(…)’
- __repr__()#
Return a concise representation showing which process groups exist and their sizes.
- classmethod use_mpu_process_groups(required_pgs: Optional[List[str]] = None)#
Use the default process groups from parallel_state.
- Parameters:
required_pgs (List[str], optional) – Names of the process groups to initialize. If None, all default process groups are pulled. Each string should correspond to one of the dataclass process group attributes.
- static setup_process_groups_for_optimizer(pg_collection: Optional[core.process_groups_config.ProcessGroupCollection], model_chunks: List, use_gloo_process_groups: bool = True)#
Helper method to set up process groups for the optimizer and DDP, with proper validation and fallbacks.
- Parameters:
pg_collection – Optional process group collection. If None, uses parallel_state groups.
model_chunks – List of model chunks to extract configuration from.
use_gloo_process_groups – Whether to set up gloo process groups.
- Returns:
dp_group: Data parallel group
dp_cp_group: Data parallel with context parallel group
intra_dp_cp_group: Intra data parallel with context parallel group
expt_dp_group: Expert data parallel group
intra_expt_dp_group: Intra expert data parallel group
mp_group: Model parallel group
expt_tp_pp_group: Expert tensor-model-pipeline parallel group
inter_dist_opt_group: Inter distributed optimizer group (may be None)
intra_dist_opt_group: Intra distributed optimizer group (may be None)
intra_dp_cp_group_gloo: Gloo version of intra_dp_cp_group (may be None)
intra_expt_dp_group_gloo: Gloo version of intra_expt_dp_group (may be None)
- Return type:
Dictionary containing all required process groups
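The "if None, uses parallel_state groups" fallback that both setup helpers describe can be sketched generically. None of the names below are the actual Megatron-Core API; this is a hypothetical illustration of the resolution order.

```python
# Hypothetical sketch of the validation-and-fallback pattern: each requested
# group is taken from the supplied collection when present, otherwise from a
# default provider (standing in for parallel_state).


def resolve_groups(pg_collection, names, defaults):
    """Return {name: group}, preferring pg_collection, else defaults."""
    resolved = {}
    for name in names:
        group = getattr(pg_collection, name, None) if pg_collection else None
        if group is None:
            # Fall back to the default provider for this group name.
            group = defaults[name]
        resolved[name] = group
    return resolved
```

Resolving per name (rather than all-or-nothing) lets a caller override only the groups it cares about, which matches the "set only the needed process groups" usage shown earlier.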
- static setup_process_groups_for_ddp(pg_collection: Optional[core.process_groups_config.ProcessGroupCollection], config, ddp_config)#
Helper method to set up process groups for DDP with proper validation and fallbacks.
- Parameters:
pg_collection – Optional process group collection. If None, uses parallel_state groups.
config – Model config to extract context_parallel_size from.
ddp_config – DDP config to extract num_distributed_optimizer_instances from.
- Returns:
Dictionary containing all required process groups for DDP.
- class core.process_groups_config.MultiModuleProcessGroupCollection#
Process group collection for multi-module pipelines.
Used when a rank participates in multiple modules (e.g., colocated encoder + LLM). The language_model_module_name identifies which module is the language model (used for CP size extraction, loss computation, and other LLM-specific operations).
- Attributes:
module_pgs – Dict mapping module names to ProcessGroupCollection objects.
language_model_module_name – Key identifying the language model module (None if no LLM on this rank).
Example
Colocated rank with an encoder and an LLM:

pg_collection = MultiModuleProcessGroupCollection(
    module_pgs={"encoder": encoder_pg, "llm": llm_pg},
    language_model_module_name="llm",
)

Rank with dual encoders (no LLM):

pg_collection = MultiModuleProcessGroupCollection(
    module_pgs={"encoder_1": encoder_1_pg, "encoder_2": encoder_2_pg},
    language_model_module_name=None,
)

Single module (can also use ProcessGroupCollection directly):

pg_collection = MultiModuleProcessGroupCollection(
    module_pgs={"llm": llm_pg},
    language_model_module_name="llm",
)

Usage:

cp_size = pg_collection.get_language_model_cp_size()
encoder_pg = pg_collection["encoder_1"]  # dict-like access
has_llm = pg_collection.has_language_model()
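The dict-like behavior documented below (`__getitem__`, `__len__`, `keys()`, `has_language_model()`) can be sketched in a few lines. This is an illustrative stand-in, not the Megatron-Core class; the validation in `__init__` is an assumption consistent with the ValueError behavior documented for the getters.

```python
# Sketch of the dict-like multi-module wrapper: collections are stored per
# module name, and one optional key marks the language model.
# MultiModuleSketch is a hypothetical name.
from typing import Dict, Optional


class MultiModuleSketch:
    def __init__(
        self,
        module_pgs: Dict[str, object],
        language_model_module_name: Optional[str] = None,
    ):
        if (language_model_module_name is not None
                and language_model_module_name not in module_pgs):
            raise ValueError("language model module not found in module_pgs")
        self.module_pgs = module_pgs
        self.language_model_module_name = language_model_module_name

    def has_language_model(self) -> bool:
        return self.language_model_module_name is not None

    def __getitem__(self, module_name: str):
        # Dict-like access: pg_collection["encoder"]
        return self.module_pgs[module_name]

    def __len__(self):
        return len(self.module_pgs)

    def keys(self):
        return self.module_pgs.keys()
```

Exposing the mapping protocol directly means call sites that iterate over modules do not need to know whether a language model is present.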
- module_pgs: Dict[str, core.process_groups_config.ProcessGroupCollection]#
None
- language_model_module_name: Optional[str]#
None
- __post_init__()#
- get_language_model_collection() → core.process_groups_config.ProcessGroupCollection#
Get the language model’s process group collection.
- Returns:
ProcessGroupCollection for the language model.
- Raises:
ValueError – If no language model is specified for this collection.
- get_language_model_cp_size() → int#
Get context parallel size for the language model.
- Returns:
Context parallel size for the language model.
- Raises:
ValueError – If no language model is specified for this collection.
- has_language_model() → bool#
Check if this rank has a language model.
- Returns:
True if this rank has a language model, False otherwise.
- get_module_collection(
- module_name: str,
Get process group collection for a specific module.
- Parameters:
module_name – Name of the module.
- Returns:
ProcessGroupCollection for the specified module.
- Raises:
ValueError – If module_name is not found in collections.
- __len__()#
Return the number of modules in this wrapper.
- __getitem__(module_name: str)#
Get process group collection for a module using dict-like access.
- __iter__()#
Iterate over all process group collections.
- keys()#
Return module names.
- values()#
Return process group collections.
- items()#
Return (module_name, collection) pairs.
- __repr__()#
Return a concise representation showing modules and their language model status.