core.process_groups_config#

Dataclasses for organizing model parallelism and gradient communication process groups.

Module Contents#

Classes#

ProcessGroupHelperMeta

Metaclass to protect virtual_pipeline_model_parallel_size from direct assignment.

ProcessGroupCollection

Unified process group collection for transformer model parallelism, gradient communication, and finalization.

MultiModuleProcessGroupCollection

Process group collection for multi-module pipelines.

API#

class core.process_groups_config.ProcessGroupHelperMeta#

Bases: type

Metaclass to protect virtual_pipeline_model_parallel_size from direct assignment.

__setattr__(name, value)#
class core.process_groups_config.ProcessGroupCollection(**kwargs)#

Unified process group collection for transformer model parallelism, gradient communication, and finalization.

Fields use init=False and must be set after instance creation.

Parameters:
  Model parallelism groups:

  • tp – Tensor parallel process group

  • pp – Pipeline parallel process group

  • mp – Model parallel group (tensor + pipeline)

  • embd – Embedding process group

  • pos_embd – Position embedding process group

  • cp – Context parallel process group

  • tp_cp – Tensor and context parallel group

  • hcp – Hierarchical context parallel groups

  • ep – Expert model parallel group

  • expt_tp – Expert tensor parallel group

  • tp_ep – Tensor and expert parallel group

  • tp_ep_pp – Tensor, expert, and pipeline parallel group

  Data parallelism and gradient communication groups:

  • dp – Data parallel process group

  • dp_cp – Data and context parallel group

  • expt_dp – Expert data parallel group

  • intra_dp_cp – Intra partial data parallel group

  • intra_expt_dp – Intra partial expert data parallel group

  • inter_dist_opt – Inter distributed optimizer instance group

Example

# Create instance and set needed process groups
pgs = ProcessGroupCollection()
pgs.tp = tp_group
pgs.pp = pp_group
pgs.dp = dp_group

# Pass to model components
model = TransformerModel(..., pg_collection=pgs)
ddp_model = DistributedDataParallel(..., pg_collection=pgs)
finalize_model_grads(..., pg_collection=pgs)

Initialization

tp: torch.distributed.ProcessGroup#

‘field(…)’

pp: torch.distributed.ProcessGroup#

‘field(…)’

mp: torch.distributed.ProcessGroup#

‘field(…)’

embd: torch.distributed.ProcessGroup#

‘field(…)’

pos_embd: torch.distributed.ProcessGroup#

‘field(…)’

cp: torch.distributed.ProcessGroup#

‘field(…)’

tp_cp: torch.distributed.ProcessGroup#

‘field(…)’

hcp: List[torch.distributed.ProcessGroup]#

‘field(…)’

ep: torch.distributed.ProcessGroup#

‘field(…)’

expt_tp: torch.distributed.ProcessGroup#

‘field(…)’

tp_ep: torch.distributed.ProcessGroup#

‘field(…)’

tp_ep_pp: torch.distributed.ProcessGroup#

‘field(…)’

tp_dp_cp: torch.distributed.ProcessGroup#

‘field(…)’

dp: torch.distributed.ProcessGroup#

‘field(…)’

dp_cp: torch.distributed.ProcessGroup#

‘field(…)’

expt_dp: torch.distributed.ProcessGroup#

‘field(…)’

intra_dp_cp: torch.distributed.ProcessGroup#

‘field(…)’

intra_expt_dp: torch.distributed.ProcessGroup#

‘field(…)’

inter_dist_opt: torch.distributed.ProcessGroup#

‘field(…)’

intra_dist_opt: torch.distributed.ProcessGroup#

‘field(…)’

__repr__()#

Return a concise representation showing which process groups exist and their sizes.

classmethod use_mpu_process_groups(
required_pgs: Optional[List[str]] = None,
)#

Use the default process groups from parallel_state.

Parameters:

required_pgs (List[str], optional) – List of process group names to initialize. If None, all default process groups are used. Each string should correspond to one of the dataclass process group attributes.
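A minimal usage sketch, assuming parallel_state has already been initialized; the group names passed below are illustrative:

# Pull only the process groups this component needs from parallel_state
pgs = ProcessGroupCollection.use_mpu_process_groups(required_pgs=["tp", "pp", "dp", "cp"])

# Or pull every default process group
all_pgs = ProcessGroupCollection.use_mpu_process_groups()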

static setup_process_groups_for_optimizer(
pg_collection: Optional[core.process_groups_config.ProcessGroupCollection],
model_chunks: List,
use_gloo_process_groups: bool = True,
)#

Helper method to set up process groups for optimizer and DDP with proper validation and fallbacks.

Parameters:
  • pg_collection – Optional process group collection. If None, uses parallel_state groups.

  • model_chunks – List of model chunks to extract configuration from.

  • use_gloo_process_groups – Whether to set up gloo process groups.

Returns:

  • dp_group: Data parallel group

  • dp_cp_group: Data parallel with context parallel group

  • intra_dp_cp_group: Intra data parallel with context parallel group

  • expt_dp_group: Expert data parallel group

  • intra_expt_dp_group: Intra expert data parallel group

  • mp_group: Model parallel group

  • expt_tp_pp_group: Expert tensor-model-pipeline parallel group

  • inter_dist_opt_group: Inter distributed optimizer group (may be None)

  • intra_dist_opt_group: Intra distributed optimizer group (may be None)

  • intra_dp_cp_group_gloo: Gloo version of intra_dp_cp_group (may be None)

  • intra_expt_dp_group_gloo: Gloo version of intra_expt_dp_group (may be None)

Return type:

Dictionary containing all required process groups
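A hedged sketch of how the returned dictionary might be consumed; the key names mirror the return entries listed above, and pgs and model_chunks are assumed to exist already:

pg_dict = ProcessGroupCollection.setup_process_groups_for_optimizer(
    pg_collection=pgs,          # or None to fall back to parallel_state groups
    model_chunks=model_chunks,  # model chunks to read configuration from
    use_gloo_process_groups=True,
)
dp_group = pg_dict["dp_group"]
intra_dp_cp_group = pg_dict["intra_dp_cp_group"]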

static setup_process_groups_for_ddp(
pg_collection: Optional[core.process_groups_config.ProcessGroupCollection],
config,
ddp_config,
)#

Helper method to set up process groups for DDP with proper validation and fallbacks.

Parameters:
  • pg_collection – Optional process group collection. If None, uses parallel_state groups.

  • config – Model config to extract context_parallel_size from.

  • ddp_config – DDP config to extract num_distributed_optimizer_instances from.

Returns:

Dictionary containing all required process groups for DDP.
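A minimal sketch, assuming model_config and ddp_config objects that carry the fields noted above:

ddp_pgs = ProcessGroupCollection.setup_process_groups_for_ddp(
    pg_collection=None,     # None falls back to parallel_state groups
    config=model_config,    # supplies context_parallel_size
    ddp_config=ddp_config,  # supplies num_distributed_optimizer_instances
)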

class core.process_groups_config.MultiModuleProcessGroupCollection#

Process group collection for multi-module pipelines.

Used when a rank participates in multiple modules (e.g., colocated encoder + LLM). The language_model_module_name identifies which module is the language model (used for CP size extraction, loss computation, and other LLM-specific operations).

Attributes:
  • module_pgs – Dict mapping module names to ProcessGroupCollection objects

  • language_model_module_name – Key identifying the language model module (None if no LLM on this rank)

Example

# Colocated rank with encoder and LLM
pg_collection = MultiModuleProcessGroupCollection(
    module_pgs={"encoder": encoder_pg, "llm": llm_pg},
    language_model_module_name="llm",
)

# Rank with dual encoders (no LLM)
pg_collection = MultiModuleProcessGroupCollection(
    module_pgs={"encoder_1": encoder_1_pg, "encoder_2": encoder_2_pg},
    language_model_module_name=None,
)

# Single module (can also use ProcessGroupCollection directly)
pg_collection = MultiModuleProcessGroupCollection(
    module_pgs={"llm": llm_pg},
    language_model_module_name="llm",
)

# Usage
cp_size = pg_collection.get_language_model_cp_size()
encoder_pg = pg_collection["encoder_1"]  # Dict-like access
has_llm = pg_collection.has_language_model()

module_pgs: Dict[str, core.process_groups_config.ProcessGroupCollection]#

None

language_model_module_name: Optional[str]#

None

__post_init__()#
get_language_model_collection() → core.process_groups_config.ProcessGroupCollection#

Get the language model’s process group collection.

Returns:

ProcessGroupCollection for the language model.

Raises:

ValueError – If no language model is specified for this collection.

get_language_model_cp_size() → int#

Get context parallel size for the language model.

Returns:

Context parallel size for the language model.

Raises:

ValueError – If no language model is specified for this collection.

has_language_model() → bool#

Check if this rank has a language model.

Returns:

True if this rank has a language model, False otherwise.
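A small sketch of the typical guard around LLM-specific work, assuming pg_collection was constructed as in the class example:

if pg_collection.has_language_model():
    llm_pgs = pg_collection.get_language_model_collection()
    cp_size = pg_collection.get_language_model_cp_size()
    # ... run loss computation and other LLM-specific operations with llm_pgs / cp_size ...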

get_module_collection(
module_name: str,
) → core.process_groups_config.ProcessGroupCollection#

Get process group collection for a specific module.

Parameters:

module_name – Name of the module.

Returns:

ProcessGroupCollection for the specified module.

Raises:

ValueError – If module_name is not found in collections.

__len__()#

Return the number of modules in this wrapper.

__getitem__(module_name: str)#

Get process group collection for a module using dict-like access.

__iter__()#

Iterate over all process group collections.

keys()#

Return module names.

values()#

Return process group collections.

items()#

Return (module_name, collection) pairs.
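The dict-style helpers above make it easy to walk every module on a rank; a brief sketch, assuming pg_collection was built as in the class example:

for module_name, module_pgs in pg_collection.items():
    # module_pgs is the ProcessGroupCollection for that module; fields may be unset
    tp = getattr(module_pgs, "tp", None)
    print(module_name, tp.size() if tp is not None else "tp not set")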

__repr__()#

Return a concise representation showing modules and their language model status.