core.parallel_state

Model and data parallel groups.

Module Contents

Classes

RankGenerator

A class for generating rank groups for different modes of parallelism.

Functions

get_nccl_options

Set the NCCL process group options.

update_pg_timeout

Update the timeout for all process groups or a specific process group. Synchronize the process groups before updating the timeout.

create_group

Creates a ProcessGroup.

generate_masked_orthogonal_rank_groups

Generate orthogonal parallel groups based on the parallel size and mask.

create_hierarchical_groups

Create hierarchical groups for a set of ranks. Take a group size of 16 as an example, i.e. 16 GPUs denoted by g0 … g15. If the hierarchical group sizes are [2, 2, 4], the first- and second-level sub-groups use 2 GPUs each and the last-level sub-groups use 4 GPUs. The function then creates 8 level-1 sub-groups, 8 level-2 sub-groups, and 4 level-3 sub-groups:
  • 8 level-1 sub-groups: [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
  • 8 level-2 sub-groups: [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
  • 4 level-3 sub-groups: [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]

default_embedding_ranks

Return the default ranks that constitute the stages on which the word embeddings live. For most models, these are the first and last pipeline stages.

default_position_embedding_ranks

Return the default ranks that constitute the stages on which the position embeddings live. For most models, this is only the first pipeline stage.

overwrite_nccl_comm_cfgs

Overwrite the nccl_comm_cfgs for the given pg_name with the given key_value_pair.

initialize_model_parallel

Initialize model- and data-parallel groups.

is_initialized

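Check whether the parallel state has been initialized. Useful for code segments that may be accessed with or without mpu initialization.

is_unitialized

Check whether the parallel state is uninitialized (deprecated; use is_initialized instead).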

model_parallel_is_initialized

Check if model- and data-parallel groups are initialized.

get_model_parallel_group

Get the model-parallel group the caller rank belongs to.

get_tensor_model_parallel_group

Get the tensor-model-parallel group the caller rank belongs to.

get_pipeline_model_parallel_group

Get the pipeline-model-parallel group the caller rank belongs to.

get_data_parallel_group

Get the data-parallel group the caller rank belongs to.

get_data_parallel_group_gloo

Get the Gloo data-parallel group the caller rank belongs to.

get_context_parallel_group

Get the context-parallel group the caller rank belongs to.

get_context_parallel_global_ranks

Get all global ranks of the context-parallel group that the caller rank belongs to.

get_hierarchical_context_parallel_groups

Get the inner ring of the context-parallel group the caller rank belongs to.

get_embedding_group

Get the embedding group the caller rank belongs to.

get_position_embedding_group

Get the position embedding group the caller rank belongs to.

get_amax_reduction_group

Get the FP8 amax reduction group the caller rank belongs to.

get_tensor_and_data_parallel_group

Get the tensor- and data-parallel group the caller rank belongs to.

get_tensor_and_context_parallel_group

Get the tensor- and context-parallel group the caller rank belongs to.

set_tensor_model_parallel_world_size

Set the tensor-model-parallel size.

set_pipeline_model_parallel_world_size

Set the pipeline-model-parallel size.

set_virtual_pipeline_model_parallel_world_size

Set the virtual pipeline-model-parallel size.

get_tensor_model_parallel_world_size

Return world size for the tensor-model-parallel group.

get_pipeline_model_parallel_world_size

Return world size for the pipeline-model-parallel group.

set_tensor_model_parallel_rank

Set tensor-model-parallel rank.

set_pipeline_model_parallel_rank

Set pipeline-model-parallel rank.

get_tensor_model_parallel_rank

Return caller’s rank for the tensor-model-parallel group.

get_pipeline_model_parallel_rank

Return caller’s rank for the pipeline-model-parallel group.

is_pipeline_first_stage

Return True if in the first pipeline-model-parallel stage, False otherwise.

is_pipeline_last_stage

Return True if in the last pipeline-model-parallel stage, False otherwise.

is_rank_in_embedding_group

Return True if the current rank is in the embedding group, False otherwise.

is_rank_in_position_embedding_group

Return True if the current rank is in the position embedding group, False otherwise.

get_virtual_pipeline_model_parallel_rank

Return the virtual pipeline-parallel rank.

set_virtual_pipeline_model_parallel_rank

Set the virtual pipeline-parallel rank.

get_virtual_pipeline_model_parallel_world_size

Return the virtual pipeline-parallel world size.

get_tensor_model_parallel_src_rank

Calculate the global rank corresponding to the first local rank in the tensor model parallel group.

get_model_parallel_src_rank

Calculate the global rank corresponding to the first local rank in the model parallel group.

get_data_parallel_src_rank

Calculate the global rank corresponding to the first local rank in the data parallel group.

get_pipeline_model_parallel_first_rank

Return the global rank of the first stage in the current rank’s pipeline.

get_pipeline_model_parallel_last_rank

Return the global rank of the last stage in the current rank’s pipeline.

get_pipeline_model_parallel_next_rank

Return the global rank that follows the caller in the pipeline.

get_pipeline_model_parallel_prev_rank

Return the global rank that precedes the caller in the pipeline.

get_data_parallel_world_size

Return world size for the data parallel group.

set_data_parallel_rank

Set data-parallel rank.

get_data_parallel_rank

Return caller’s rank in the data-parallel group.

get_context_parallel_world_size

Return world size for the context parallel group.

get_context_parallel_rank

Return caller’s rank in the context-parallel group.

get_tensor_and_context_parallel_world_size

Return world size for the tensor and context-parallel group.

get_tensor_and_context_parallel_rank

Return caller’s rank in the joint tensor-model-parallel and context-parallel group.

get_expert_model_parallel_group

Get the expert-model-parallel group the caller rank belongs to.

get_expert_model_parallel_src_rank

Calculate the global rank corresponding to the first local rank in the expert model parallel group.

get_expert_model_parallel_world_size

Return world size for the expert-model-parallel group.

set_expert_model_parallel_world_size

Sets the expert-model-parallel world size.

get_expert_model_parallel_rank

Return caller’s rank in the expert-model-parallel group.

set_expert_model_parallel_rank

Set expert-model-parallel rank.

get_expert_tensor_parallel_group

Get the expert-tensor-parallel group the caller rank belongs to.

get_expert_tensor_parallel_world_size

Return world size for the expert tensor parallel group.

set_expert_tensor_parallel_world_size

Set expert tensor model parallel size.

get_expert_tensor_parallel_rank

Return caller’s rank in the expert tensor parallel group.

set_expert_tensor_parallel_rank

Set expert tensor model parallel rank.

get_expert_tensor_and_model_parallel_group

Get the expert-tensor and expert-model group the caller rank belongs to.

get_expert_tensor_and_model_parallel_world_size

Return world size for the expert model parallel group times expert tensor parallel group.

get_expert_tensor_and_model_parallel_rank

Return caller’s rank in the joint tensor- and expert-model-parallel group.

get_expert_tensor_model_pipeline_parallel_group

Get expert tensor-model-pipeline parallel group.

get_expert_data_parallel_group

Get expert data parallel group.

get_data_modulo_expert_parallel_group

[Deprecated] Get expert data parallel group.

get_expert_data_parallel_group_gloo

Get the Gloo expert data-parallel group.

get_expert_data_parallel_rank

Return caller’s rank in the expert data parallel group.

get_expert_data_parallel_world_size

Return world size for the expert data parallel group.

get_intra_distributed_optimizer_instance_group

Get the group of all GPUs in a distributed optimizer instance.

get_inter_distributed_optimizer_instance_group

Get the group spanning the different distributed optimizer instances. Attention and MLP/Expert layers share the same inter-instance group, so only _INTER_PARTIAL_EXPERT_DATA_PARALLEL_GROUP is built, and it is returned here.

_set_global_memory_buffer

Initialize global memory buffer.

_set_global_symmetric_memory_buffer

Initialize global symmetric memory buffer.

get_global_memory_buffer

Return the global GlobalMemoryBuffer object.

get_global_symmetric_memory_buffer

Return the global GlobalSymmetricMemoryBuffer object.

destroy_global_memory_buffer

Sets the global memory buffer to None.

destroy_global_symmetric_memory_buffer

Sets the global symmetric memory buffer to None.

get_all_ranks

Get caller’s rank in tensor-model-parallel, data-parallel, context-parallel, pipeline-model-parallel and expert-model-parallel groups.

destroy_model_parallel

Set the groups to None.

Data

logger

_TENSOR_MODEL_PARALLEL_GROUP

_PIPELINE_MODEL_PARALLEL_GROUP

_MODEL_PARALLEL_GROUP

_EMBEDDING_GROUP

_POSITION_EMBEDDING_GROUP

_DATA_PARALLEL_GROUP

_DATA_PARALLEL_GROUP_GLOO

_TENSOR_AND_DATA_PARALLEL_GROUP

_EXPERT_MODEL_PARALLEL_GROUP

_EXPERT_TENSOR_PARALLEL_GROUP

_EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP

_EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP

_EXPERT_DATA_PARALLEL_GROUP

_EXPERT_DATA_PARALLEL_GROUP_GLOO

_INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP

_INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP_GLOO

_INTER_PARTIAL_EXPERT_DATA_PARALLEL_GROUP

_MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE

_MPU_EXPERT_MODEL_PARALLEL_RANK

_MPU_EXPERT_TENSOR_PARALLEL_WORLD_SIZE

_MPU_EXPERT_TENSOR_PARALLEL_RANK

_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK

_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE

_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE

_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE

_MPU_DATA_PARALLEL_WORLD_SIZE

_MPU_DATA_PARALLEL_RANK

_MPU_TENSOR_MODEL_PARALLEL_RANK

_MPU_PIPELINE_MODEL_PARALLEL_RANK

_EMBEDDING_GLOBAL_RANKS

_POSITION_EMBEDDING_GLOBAL_RANKS

_PIPELINE_GLOBAL_RANKS

_DATA_PARALLEL_GLOBAL_RANKS

_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS

_EXPERT_MODEL_PARALLEL_RANKS

_MODEL_PARALLEL_GLOBAL_RANKS

_CONTEXT_PARALLEL_GROUP

_CONTEXT_PARALLEL_GLOBAL_RANKS

_HIERARCHICAL_CONTEXT_PARALLEL_GROUPS

_DATA_PARALLEL_GROUP_WITH_CP

_DATA_PARALLEL_GROUP_WITH_CP_GLOO

_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP

_INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP

_INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP_GLOO

_TENSOR_AND_CONTEXT_PARALLEL_GROUP

_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP

_INTRA_DISTRIBUTED_OPTIMIZER_INSTANCE_GROUP

_GLOBAL_MEMORY_BUFFER

_GLOBAL_SYMMETRIC_MEMORY_BUFFER

_global_process_group_list

API

core.parallel_state.logger

‘getLogger(…)’

core.parallel_state._TENSOR_MODEL_PARALLEL_GROUP

None

core.parallel_state._PIPELINE_MODEL_PARALLEL_GROUP

None

core.parallel_state._MODEL_PARALLEL_GROUP

None

core.parallel_state._EMBEDDING_GROUP

None

core.parallel_state._POSITION_EMBEDDING_GROUP

None

core.parallel_state._DATA_PARALLEL_GROUP

None

core.parallel_state._DATA_PARALLEL_GROUP_GLOO

None

core.parallel_state._TENSOR_AND_DATA_PARALLEL_GROUP

None

core.parallel_state._EXPERT_MODEL_PARALLEL_GROUP

None

core.parallel_state._EXPERT_TENSOR_PARALLEL_GROUP

None

core.parallel_state._EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP

None

core.parallel_state._EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP

None

core.parallel_state._EXPERT_DATA_PARALLEL_GROUP

None

core.parallel_state._EXPERT_DATA_PARALLEL_GROUP_GLOO

None

core.parallel_state._INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP

None

core.parallel_state._INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP_GLOO

None

core.parallel_state._INTER_PARTIAL_EXPERT_DATA_PARALLEL_GROUP

None

core.parallel_state._MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE

None

core.parallel_state._MPU_EXPERT_MODEL_PARALLEL_RANK

None

core.parallel_state._MPU_EXPERT_TENSOR_PARALLEL_WORLD_SIZE

None

core.parallel_state._MPU_EXPERT_TENSOR_PARALLEL_RANK

None

core.parallel_state._VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK

None

core.parallel_state._VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE

None

core.parallel_state._MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE

None

core.parallel_state._MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE

None

core.parallel_state._MPU_DATA_PARALLEL_WORLD_SIZE

None

core.parallel_state._MPU_DATA_PARALLEL_RANK

None

core.parallel_state._MPU_TENSOR_MODEL_PARALLEL_RANK

None

core.parallel_state._MPU_PIPELINE_MODEL_PARALLEL_RANK

None

core.parallel_state._EMBEDDING_GLOBAL_RANKS

None

core.parallel_state._POSITION_EMBEDDING_GLOBAL_RANKS

None

core.parallel_state._PIPELINE_GLOBAL_RANKS

None

core.parallel_state._DATA_PARALLEL_GLOBAL_RANKS

None

core.parallel_state._TENSOR_MODEL_PARALLEL_GLOBAL_RANKS

None

core.parallel_state._EXPERT_MODEL_PARALLEL_RANKS

None

core.parallel_state._MODEL_PARALLEL_GLOBAL_RANKS

None

core.parallel_state._CONTEXT_PARALLEL_GROUP

None

core.parallel_state._CONTEXT_PARALLEL_GLOBAL_RANKS

None

core.parallel_state._HIERARCHICAL_CONTEXT_PARALLEL_GROUPS

None

core.parallel_state._DATA_PARALLEL_GROUP_WITH_CP

None

core.parallel_state._DATA_PARALLEL_GROUP_WITH_CP_GLOO

None

core.parallel_state._DATA_PARALLEL_GLOBAL_RANKS_WITH_CP

None

core.parallel_state._INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP

None

core.parallel_state._INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP_GLOO

None

core.parallel_state._TENSOR_AND_CONTEXT_PARALLEL_GROUP

None

core.parallel_state._TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP

None

core.parallel_state._INTRA_DISTRIBUTED_OPTIMIZER_INSTANCE_GROUP

None

core.parallel_state._GLOBAL_MEMORY_BUFFER

None

core.parallel_state._GLOBAL_SYMMETRIC_MEMORY_BUFFER

None

core.parallel_state._global_process_group_list

None

core.parallel_state.get_nccl_options(pg_name, nccl_comm_cfgs)

Set the NCCL process group options.

Parameters:
  • pg_name (str) – process group name

  • nccl_comm_cfgs (dict) – nccl communicator configurations

When an option (e.g., max_ctas) is not found in the config, use the NCCL default setting.
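The option lookup described above can be sketched in plain Python. This is a hypothetical helper, not Megatron-Core's implementation (which builds a torch.distributed NCCL options object): `resolve_nccl_options` returns a dict holding only the options present for pg_name, so any missing option is left to NCCL's default. The option names come from the nccl_communicator_config_path description; returning None for a pg_name with no entry is an assumption.

```python
from typing import Dict, Optional

# Per-communicator options named in the nccl_communicator_config_path description.
KNOWN_OPTIONS = ("min_ctas", "max_ctas", "cga_cluster_size")


def resolve_nccl_options(
    pg_name: str, nccl_comm_cfgs: Dict[str, Dict[str, int]]
) -> Optional[Dict[str, int]]:
    """Hypothetical helper: collect the configured options for one process group.

    Options missing from the config are left out, meaning "use the NCCL default".
    Returning None when pg_name has no entry at all is an assumption.
    """
    if pg_name not in nccl_comm_cfgs:
        return None
    cfg = nccl_comm_cfgs[pg_name]
    return {key: cfg[key] for key in KNOWN_OPTIONS if key in cfg}


cfgs = {"dp": {"max_ctas": 16, "min_ctas": 2}}
print(resolve_nccl_options("dp", cfgs))  # {'min_ctas': 2, 'max_ctas': 16}
print(resolve_nccl_options("tp", cfgs))  # None
```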

core.parallel_state.update_pg_timeout(
timeout: datetime.timedelta,
pg: Optional[torch._C._distributed_c10d.ProcessGroup] = None,
)

Update the timeout for all process groups or a specific process group. Synchronize the process groups before updating the timeout.

Parameters:
  • timeout (datetime.timedelta) – The timeout to set for the process group(s)

  • pg (Optional[torch._C._distributed_c10d.ProcessGroup], default=None) – The process group to update the timeout for. If None, all process groups are updated.

core.parallel_state.create_group(
ranks=None,
timeout=None,
backend=None,
pg_options=None,
use_local_synchronization=False,
group_desc=None,
)

Creates a ProcessGroup.

core.parallel_state.generate_masked_orthogonal_rank_groups(
world_size: int,
parallel_size: List[int],
mask: List[bool],
) -> List[List[int]]

Generate orthogonal parallel groups based on the parallel size and mask.

Parameters:
  • world_size (int) – world size

  • parallel_size (List[int]) – The parallel size of each orthogonal parallel type. For example, if tensor_parallel_size = 2, pipeline_model_parallel_size = 3, data_parallel_size = 4, and the parallel mapping order is tp-pp-dp, then parallel_size = [2, 3, 4].

  • mask (List[bool]) – Controls which parallel methods the generated groups represent. If mask[i] is True, the generated group contains the i-th parallelism method. For example, if parallel_size = [tp_size, pp_size, dp_size] and mask = [True, False, True], the generated group is the tp-dp group; if mask = [False, True, False], the generated group is the pp group.

Algorithm: For orthogonal parallelism such as tp/dp/pp/cp, the global rank and the local ranks satisfy the following equation (written here for the tp-dp-pp order):

    global_rank = tp_rank + dp_rank * tp_size + pp_rank * tp_size * dp_size        (1)

where tp_rank ∈ [0, tp_size), dp_rank ∈ [0, dp_size), and pp_rank ∈ [0, pp_size).

Suppose we want the `dp_group`, i.e. tp_size * pp_size groups of dp_size ranks each. (For example, with 8 GPUs, order 'tp-pp-dp', and sizes 2-2-2, the dp groups are [[0, 4], [1, 5], [2, 6], [3, 7]].) The tp_rank and pp_rank are combined to form the `dp_group_index`:

    dp_group_index = tp_rank + pp_rank * tp_size        (2)

So, given tp_rank and pp_rank satisfying equation (2), and dp_rank in range(0, dp_size), the ranks in dp_group[dp_group_index] satisfy equation (1). This function solves that problem.

For example, if parallel_size = [tp_size, dp_size, pp_size] = [2, 3, 4] and mask = [False, True, False], then:

    dp_group_index(0) = tp_rank(0) + pp_rank(0) * 2
    dp_group_index(1) = tp_rank(1) + pp_rank(0) * 2
    ...
    dp_group_index(7) = tp_rank(1) + pp_rank(3) * 2

    dp_group[0] = 0 + range(0, 3) * 2 + 0 = [0, 2, 4]
    dp_group[1] = 1 + range(0, 3) * 2 + 0 = [1, 3, 5]
    ...
    dp_group[7] = 1 + range(0, 3) * 2 + 3 * 2 * 3 = [19, 21, 23]
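The derivation above can be checked with a short standalone sketch. It re-derives the grouping from the stride layout of equation (1) rather than reproducing Megatron-Core's source; `masked_orthogonal_groups`, `prefix_product`, and `decompose` are hypothetical names, and the first entry of parallel_size is assumed to vary fastest across global ranks.

```python
from typing import List


def prefix_product(sizes: List[int]) -> List[int]:
    """Strides of a mixed-radix layout: [1, s0, s0*s1, ...]."""
    out = [1]
    for s in sizes:
        out.append(out[-1] * s)
    return out


def decompose(index: int, shape: List[int]) -> List[int]:
    """Split a flat index into per-dimension coordinates, first dimension fastest."""
    coords = []
    for s in shape:
        coords.append(index % s)
        index //= s
    return coords


def masked_orthogonal_groups(
    world_size: int, parallel_size: List[int], mask: List[bool]
) -> List[List[int]]:
    stride = prefix_product(parallel_size)
    masked_shape = [s for s, m in zip(parallel_size, mask) if m]
    masked_stride = [d for d, m in zip(stride, mask) if m]
    unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m]
    unmasked_stride = [d for d, m in zip(stride, mask) if not m]

    group_size = prefix_product(masked_shape)[-1]
    groups = []
    for group_index in range(world_size // group_size):
        # Fixed coordinates of the dimensions outside the group (e.g. tp_rank, pp_rank).
        fixed = decompose(group_index, unmasked_shape)
        base = sum(c * d for c, d in zip(fixed, unmasked_stride))
        group = []
        for local_rank in range(group_size):
            # Coordinates that vary inside the group (e.g. dp_rank).
            varying = decompose(local_rank, masked_shape)
            group.append(base + sum(c * d for c, d in zip(varying, masked_stride)))
        groups.append(group)
    return groups


groups = masked_orthogonal_groups(24, [2, 3, 4], [False, True, False])
print(groups[0], groups[1], groups[7])  # [0, 2, 4] [1, 3, 5] [19, 21, 23]
```

Decomposing group_index over the unmasked shape (tp first, then pp) is exactly equation (2), which is why group 7 corresponds to tp_rank 1 and pp_rank 3.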
core.parallel_state.create_hierarchical_groups(
rank,
ranks,
hierarchical_group_sizes,
create_gloo_process_groups=False,
pg_options=None,
timeout=None,
group_desc=None,
)

Create hierarchical groups for a set of ranks. Take a group size of 16 as an example, i.e. 16 GPUs denoted by g0 … g15. If the hierarchical group sizes are [2, 2, 4], the first- and second-level sub-groups use 2 GPUs each and the last-level sub-groups use 4 GPUs. The function then creates 8 level-1 sub-groups, 8 level-2 sub-groups, and 4 level-3 sub-groups:
  • 8 level-1 sub-groups: [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
  • 8 level-2 sub-groups: [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
  • 4 level-3 sub-groups: [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
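The hierarchical split can be reproduced with plain index arithmetic. This sketch only computes the rank lists (it creates no process groups and ignores the Gloo, pg_options, timeout, and group_desc arguments); `hierarchical_rank_groups` is a hypothetical name, and reading each position in ranks as mixed-radix digits with the level-1 digit varying fastest is an assumption that matches the example above.

```python
from typing import List


def hierarchical_rank_groups(ranks: List[int], sizes: List[int]) -> List[List[List[int]]]:
    """Hypothetical helper: the rank lists at every level of a hierarchical split.

    Position p in `ranks` is read as mixed-radix digits with the level-1 digit
    varying fastest; a level-i sub-group holds the positions that differ only
    in digit i.
    """
    total = 1
    for s in sizes:
        total *= s
    if total != len(ranks):
        raise ValueError("the product of hierarchical sizes must equal len(ranks)")

    levels = []
    stride = 1
    for size in sizes:
        groups = []
        for p in range(len(ranks)):
            # p starts a sub-group when its digit at this level is 0.
            if (p // stride) % size == 0:
                groups.append([ranks[p + k * stride] for k in range(size)])
        levels.append(groups)
        stride *= size
    return levels


levels = hierarchical_rank_groups(list(range(16)), [2, 2, 4])
print(levels[2])  # [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]
```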

class core.parallel_state.RankGenerator(
tp: int,
ep: int,
dp: int,
pp: int,
cp: int,
order: str,
rank_offset: int = 0,
)

Bases: object

A class for generating rank groups for different modes of parallelism.

Initialization

get_mask(order: str, token: str)

Create a mask for the specified tokens based on the given order.

Parameters:
  • order (str) – The order of parallelism types (e.g., ‘tp-dp-pp’).

  • token (str) – The specific parallelism types to include in the mask, separated by hyphens (e.g., ‘tp-dp’).

get_ranks(token)

Get the rank groups for the given token.

Parameters:
  • token (str) – The parallelism type(s) whose rank groups to return. To obtain a group spanning multiple parallelism types, separate them with a hyphen ‘-’. For example, to obtain the TP-DP group, the token should be ‘tp-dp’.
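A minimal sketch of what get_mask and get_ranks compute, written as free functions rather than as the RankGenerator class. Assumptions: the sizes are passed as a dict keyed by the tokens in order (the real constructor takes tp, ep, dp, pp, and cp separately), the first type in order varies fastest, and rank_offset is added to every rank. `build_mask` and `rank_groups` are hypothetical stand-ins, not the class methods.

```python
from typing import Dict, List


def build_mask(order: str, token: str) -> List[bool]:
    """True at each position of `order` whose parallelism type appears in `token`."""
    ordered = order.split("-")
    mask = [False] * len(ordered)
    for t in token.split("-"):
        mask[ordered.index(t)] = True  # ValueError if t is not in order
    return mask


def rank_groups(
    order: str, sizes: Dict[str, int], token: str, rank_offset: int = 0
) -> List[List[int]]:
    """Hypothetical stand-in for RankGenerator.get_ranks (first type in order varies fastest)."""
    shape = [sizes[name] for name in order.split("-")]
    mask = build_mask(order, token)
    strides, world_size = [], 1
    for dim in shape:
        strides.append(world_size)
        world_size *= dim
    in_dims = [i for i, m in enumerate(mask) if m]
    out_dims = [i for i, m in enumerate(mask) if not m]

    def offset(index: int, dims: List[int]) -> int:
        # Expand a flat index over `dims` into coordinates and map them to a rank offset.
        total = 0
        for i in dims:
            total += (index % shape[i]) * strides[i]
            index //= shape[i]
        return total

    group_size = 1
    for i in in_dims:
        group_size *= shape[i]
    return [
        [rank_offset + offset(g, out_dims) + offset(k, in_dims) for k in range(group_size)]
        for g in range(world_size // group_size)
    ]


print(build_mask("tp-dp-pp", "tp-dp"))  # [True, True, False]
print(rank_groups("tp-dp-pp", {"tp": 2, "dp": 2, "pp": 4}, "pp"))
# [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]
```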

core.parallel_state.default_embedding_ranks(pp_ranks)

Return the default ranks that constitute the stages on which the word embeddings live. For most models, these are the first and last pipeline stages.

core.parallel_state.default_position_embedding_ranks(pp_ranks)

Return the default ranks that constitute the stages on which the position embeddings live. For most models, this is only the first pipeline stage.
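The two defaults can be sketched as follows. Listing a one-stage pipeline's single rank once (rather than as both first and last stage) is an assumption; the `_sketch` suffixes mark these as hypothetical stand-ins for the library functions.

```python
from typing import List


def default_embedding_ranks_sketch(pp_ranks: List[int]) -> List[int]:
    """Word embeddings live on the first and last pipeline stages."""
    if len(pp_ranks) == 1:
        # One-stage pipeline: first and last stage are the same rank (assumed listed once).
        return [pp_ranks[0]]
    return [pp_ranks[0], pp_ranks[-1]]


def default_position_embedding_ranks_sketch(pp_ranks: List[int]) -> List[int]:
    """Position embeddings live only on the first pipeline stage."""
    return [pp_ranks[0]]


print(default_embedding_ranks_sketch([1, 5, 9, 13]))  # [1, 13]
print(default_position_embedding_ranks_sketch([1, 5, 9, 13]))  # [1]
```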

core.parallel_state.overwrite_nccl_comm_cfgs(nccl_comm_cfgs, pg_name, key_value_pair)

Overwrite the nccl_comm_cfgs for the given pg_name with the given key_value_pair.
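A minimal sketch of the in-place update, assuming nccl_comm_cfgs maps each pg_name to a dict of options, key_value_pair is a (key, value) tuple, and a missing pg_name entry is created on demand. `overwrite_cfg` is a hypothetical name.

```python
from typing import Any, Dict, Tuple


def overwrite_cfg(
    nccl_comm_cfgs: Dict[str, Dict[str, Any]], pg_name: str, key_value_pair: Tuple[str, Any]
) -> None:
    """Hypothetical helper: set one option for one process group, in place."""
    key, value = key_value_pair
    # Create the per-group dict on first use, then overwrite (or add) the option.
    nccl_comm_cfgs.setdefault(pg_name, {})[key] = value


cfgs = {"dp": {"max_ctas": 32}}
overwrite_cfg(cfgs, "dp", ("max_ctas", 8))
overwrite_cfg(cfgs, "tp", ("min_ctas", 1))
print(cfgs)  # {'dp': {'max_ctas': 8}, 'tp': {'min_ctas': 1}}
```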

core.parallel_state.initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
virtual_pipeline_model_parallel_size: Optional[int] = None,
pipeline_model_parallel_comm_backend: Optional[str] = None,
use_sharp: bool = False,
context_parallel_size: int = 1,
hierarchical_context_parallel_sizes: Optional[List[int]] = None,
expert_model_parallel_size: int = 1,
num_distributed_optimizer_instances: int = 1,
expert_tensor_parallel_size: Optional[int] = None,
nccl_communicator_config_path: Optional[str] = None,
distributed_timeout_minutes: int = 30,
order: str = 'tp-cp-ep-dp-pp',
get_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None,
get_position_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None,
create_gloo_process_groups: bool = True,
high_priority_stream_groups: Optional[List[str]] = None,
sharp_enabled_group: Optional[str] = None,
) -> None

Initialize model- and data-parallel groups.

Parameters:
  • tensor_model_parallel_size (int, default = 1) – The number of GPUs to split individual tensors across.

  • pipeline_model_parallel_size (int, default = 1) – The number of tensor parallel GPU groups to split the Transformer layers across. For example, if tensor_model_parallel_size is 4 and pipeline_model_parallel_size is 2, the model will be split into 2 groups of 4 GPUs.

  • virtual_pipeline_model_parallel_size (int, optional) –

    The number of stages that each pipeline group will have, interleaving as necessary. If None, no interleaving is performed. For example, if tensor_model_parallel_size is 1, pipeline_model_parallel_size is 4, virtual_pipeline_model_parallel_size is 2, and there are 16 transformer layers in the model, the model will be split into 8 stages with two layers each, and each GPU gets 2 stages as follows (layer numbers start at 1):

    GPU 0: [1, 2] [9, 10]
    GPU 1: [3, 4] [11, 12]
    GPU 2: [5, 6] [13, 14]
    GPU 3: [7, 8] [15, 16]

  • pipeline_model_parallel_comm_backend (str, optional) – The backend to use for pipeline parallel communication. If None, the default backend will be used.

  • use_sharp (bool, default = False) – Whether to use SHARP for the collective communications of data-parallel process groups. When True, a barrier is run within each data-parallel process group, which specifies the SHARP application target groups.

  • context_parallel_size (int, default = 1) –

    The number of tensor parallel GPU groups to split the network input sequence length across. The attention module needs tokens of the full sequence length, so GPUs in a context parallel group must communicate with each other to exchange the other sequence chunks. Each GPU and its counterparts in the other tensor parallel groups form a context parallel group.

    For example, assume we have 8 GPUs, a tensor model parallel size of 4, and a context parallel size of 2. The network input is split into two sequence chunks, which are processed by 2 different groups of 4 GPUs: one chunk by GPU0-3 and the other by GPU4-7. Four groups are built for context parallel communication: [GPU0, GPU4], [GPU1, GPU5], [GPU2, GPU6], and [GPU3, GPU7].

    Context parallelism partitions the sequence length, so it has no impact on weights: weights are duplicated among the GPUs in a context parallel group. Hence, an all-reduce of weight gradients is required in the backward pass. For simplicity, the context parallel GPUs piggyback on the data parallel group for the weight gradient all-reduce.

  • expert_model_parallel_size (int, default = 1) – The number of Mixture of Experts parallel GPUs in each expert parallel group.

  • num_distributed_optimizer_instances (int, default = 1) – The number of distributed optimizer replicas across the data-parallel domain.

  • expert_tensor_parallel_size (int, default = tp_size) – The number of GPUs to split individual tensors of experts across.

  • nccl_communicator_config_path (str, default = None) – Path to the yaml file of NCCL communicator configurations. min_ctas, max_ctas, and cga_cluster_size can be set for each communicator.

  • distributed_timeout_minutes (int, default = 30) – Timeout, in minutes, for operations executed against distributed process groups. See PyTorch documentation at https://pytorch.org/docs/stable/distributed.html for caveats.

  • order (str, default = 'tp-cp-ep-dp-pp') – The rank initialization order of parallelism, given as hyphen-separated parallelism types (for example, tp-dp-pp or tp-pp-dp).

  • get_embedding_ranks (Callable[[List[int], Optional[int]], List[int]], optional, default=None) – A function that takes in a list of ranks for a pipeline group and returns those ranks that should have embeddings.

  • get_position_embedding_ranks (Callable[[List[int], Optional[int]], List[int]], optional, default=None) – A function that takes in a list of ranks for a pipeline group, and returns those ranks that should have position embeddings.

  • create_gloo_process_groups (bool, default = True) – Create Gloo process groups if set to True. If set to False, Gloo process groups are not created and calls to get Gloo process groups will result in assertion errors.

  • high_priority_stream_groups (List[str], default = None) – Specify which communicator groups should use high priority streams during creation. Assigning high priority to communication streams ensures that communication kernels are scheduled with higher priority, minimizing the exposed communication when it is overlapped with other computation kernels. Example: initialize_model_parallel(…, high_priority_stream_groups=['dp_cp', 'ep_dp'])

  • sharp_enabled_group (str, default = None) – Specify which communicator group should use SHARP communication. This option is only valid when use_sharp is True. By default (None), SHARP is enabled for the dp group. Available options (choose one): [dp, dp_replica]
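The interleaved layout in the virtual_pipeline_model_parallel_size example can be reproduced with a round-robin assignment of model chunks to pipeline ranks. This is a sketch under the assumption that the layer count divides evenly into pp_size * vpp_size chunks; `interleaved_layer_assignment` is a hypothetical helper, not part of parallel_state.

```python
from typing import Dict, List


def interleaved_layer_assignment(
    num_layers: int, pp_size: int, vpp_size: int
) -> Dict[int, List[List[int]]]:
    """Map each pipeline rank to its model chunks, using 1-based layer numbers."""
    num_chunks = pp_size * vpp_size
    if num_layers % num_chunks != 0:
        raise ValueError("num_layers must divide evenly into pp_size * vpp_size chunks")
    per_chunk = num_layers // num_chunks
    layout = {}
    for pp_rank in range(pp_size):
        chunks = []
        for vp_rank in range(vpp_size):
            # Chunks are dealt round-robin: chunk ids pp_rank, pp_rank + pp_size, ...
            chunk_id = vp_rank * pp_size + pp_rank
            first = chunk_id * per_chunk + 1
            chunks.append(list(range(first, first + per_chunk)))
        layout[pp_rank] = chunks
    return layout


for rank, chunks in interleaved_layer_assignment(16, 4, 2).items():
    print(f"GPU {rank}:", *chunks)  # first line: GPU 0: [1, 2] [9, 10]
```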

Let’s say we have a total of 16 GPUs denoted by g0 … g15, and we use 2 GPUs to parallelize the model tensor and 4 GPUs to parallelize the model pipeline. The present function will create 8 tensor model-parallel groups, 4 pipeline model-parallel groups, and 8 data-parallel groups:
  • 8 data-parallel groups: [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
  • 8 tensor model-parallel groups: [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
  • 4 pipeline model-parallel groups: [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]

Note that for efficiency, the caller should make sure adjacent ranks are on the same DGX box. For example, if we are using 2 DGX-1 boxes with a total of 16 GPUs, ranks 0 to 7 belong to the first box and ranks 8 to 15 belong to the second box.
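The 16-GPU layout above (tensor size 2, pipeline size 4, hence data-parallel size 2) can be regenerated with the sketch below. It assumes the tp-dp-pp order with the first type varying fastest, and assumes the data-parallel size is whatever is left over, world_size // (tp * pp * cp). `orthogonal_groups` is a hypothetical helper; the same function with tensor size 4 and context size 2 also reproduces the context-parallel example from the parameter list.

```python
from typing import Dict, List


def orthogonal_groups(sizes: Dict[str, int], order: str, kind: str) -> List[List[int]]:
    """Hypothetical helper: rank groups of one parallelism type (first type in order fastest)."""
    names = order.split("-")
    shape = [sizes[n] for n in names]
    axis = names.index(kind)
    stride = 1
    for s in shape[:axis]:
        stride *= s
    world_size = 1
    for s in shape:
        world_size *= s
    size = shape[axis]
    # Rank r starts a group when its coordinate along `axis` is 0.
    return [
        [r + k * stride for k in range(size)]
        for r in range(world_size)
        if (r // stride) % size == 0
    ]


world_size, tp, pp, cp = 16, 2, 4, 1
dp = world_size // (tp * pp * cp)  # assumed: data parallelism takes the remaining GPUs
sizes = {"tp": tp, "dp": dp, "pp": pp}
print(orthogonal_groups(sizes, "tp-dp-pp", "dp"))
# [[0, 2], [1, 3], [4, 6], [5, 7], [8, 10], [9, 11], [12, 14], [13, 15]]
```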

core.parallel_state.is_initialized()

Check whether the parallel state has been initialized. Useful for code segments that may be accessed with or without mpu initialization.

core.parallel_state.is_unitialized() -> bool

Check whether the parallel state is uninitialized.

Deprecated. Use is_initialized instead.

core.parallel_state.model_parallel_is_initialized()

Check if model- and data-parallel groups are initialized.

core.parallel_state.get_model_parallel_group(check_initialized=True)

Get the model-parallel group the caller rank belongs to.

core.parallel_state.get_tensor_model_parallel_group(check_initialized=True)

Get the tensor-model-parallel group the caller rank belongs to.

core.parallel_state.get_pipeline_model_parallel_group(check_initialized=True)

Get the pipeline-model-parallel group the caller rank belongs to.
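The getters above take check_initialized=True and read module-level globals that start as None (see the Data entries). A minimal sketch of that pattern, assuming (consistently with the create_gloo_process_groups note about assertion errors) that an uninitialized group triggers an assertion when check_initialized is True. `_FakeGroup`, `init_tensor_group`, `get_tensor_group`, and `destroy_groups` are hypothetical; no real process groups are created.

```python
from typing import List, Optional


class _FakeGroup:
    """Hypothetical stand-in for a torch.distributed ProcessGroup."""

    def __init__(self, ranks: List[int]):
        self.ranks = list(ranks)


# Module-level global, None until initialized (like the _*_GROUP data entries).
_TENSOR_GROUP: Optional[_FakeGroup] = None


def init_tensor_group(ranks: List[int]) -> None:
    global _TENSOR_GROUP
    _TENSOR_GROUP = _FakeGroup(ranks)


def get_tensor_group(check_initialized: bool = True) -> Optional[_FakeGroup]:
    # With check_initialized=True an uninitialized group is an error;
    # with False the caller receives None and must handle it.
    if check_initialized:
        assert _TENSOR_GROUP is not None, "tensor model parallel group is not initialized"
    return _TENSOR_GROUP


def destroy_groups() -> None:
    # Counterpart of destroy_model_parallel: reset the global to None.
    global _TENSOR_GROUP
    _TENSOR_GROUP = None
```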

core.parallel_state.get_data_parallel_group(
with_context_parallel=False,
partial_data_parallel=False,
)

Get the data-parallel group the caller rank belongs to.

core.parallel_state.get_data_parallel_group_gloo(
with_context_parallel=False,
partial_data_parallel=False,
)

Get the Gloo data-parallel group the caller rank belongs to.

core.parallel_state.get_context_parallel_group(check_initialized=True)

Get the context-parallel group the caller rank belongs to.

core.parallel_state.get_context_parallel_global_ranks(check_initialized=True)

Get all global ranks of the context-parallel group that the caller rank belongs to.

core.parallel_state.get_hierarchical_context_parallel_groups(check_initialized=True)

Get the inner ring of the context-parallel group the caller rank belongs to.

core.parallel_state.get_embedding_group(check_initialized=True)

Get the embedding group the caller rank belongs to.

core.parallel_state.get_position_embedding_group(check_initialized=True)

Get the position embedding group the caller rank belongs to.

core.parallel_state.get_amax_reduction_group(
with_context_parallel=False,
tp_only_amax_red=False,
)

Get the FP8 amax reduction group the caller rank belongs to.

core.parallel_state.get_tensor_and_data_parallel_group(
check_initialized=True,
with_context_parallel=False,
)

Get the tensor- and data-parallel group the caller rank belongs to.

core.parallel_state.get_tensor_and_context_parallel_group(check_initialized=True)

Get the tensor- and context-parallel group the caller rank belongs to.

core.parallel_state.set_tensor_model_parallel_world_size(world_size)

Set the tensor-model-parallel size.

core.parallel_state.set_pipeline_model_parallel_world_size(world_size)

Set the pipeline-model-parallel size.

core.parallel_state.set_virtual_pipeline_model_parallel_world_size(world_size)

Set the virtual pipeline-model-parallel size.

core.parallel_state.get_tensor_model_parallel_world_size()

Return world size for the tensor-model-parallel group.

core.parallel_state.get_pipeline_model_parallel_world_size()

Return world size for the pipeline-model-parallel group.

core.parallel_state.set_tensor_model_parallel_rank(rank)

Set tensor-model-parallel rank.

core.parallel_state.set_pipeline_model_parallel_rank(rank)

Set pipeline-model-parallel rank.

core.parallel_state.get_tensor_model_parallel_rank()

Return caller’s rank for the tensor-model-parallel group.

core.parallel_state.get_pipeline_model_parallel_rank()

Return caller’s rank for the pipeline-model-parallel group.

core.parallel_state.is_pipeline_first_stage(ignore_virtual=True, vp_stage=None)

Return True if in the first pipeline-model-parallel stage, False otherwise.

core.parallel_state.is_pipeline_last_stage(ignore_virtual=True, vp_stage=None)

Return True if in the last pipeline-model-parallel stage, False otherwise.
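How ignore_virtual and vp_stage might combine can be sketched with pure functions that take the ranks explicitly instead of reading module state. The semantics are an assumption, not confirmed by this page: virtual stages are consulted only when ignore_virtual is False and a virtual pipeline size is set, in which case only model chunk 0 can be first and only the last chunk can be last. `is_first_stage` and `is_last_stage` are hypothetical names.

```python
from typing import Optional


def is_first_stage(
    pp_rank: int,
    vp_size: Optional[int],
    ignore_virtual: bool = True,
    vp_stage: Optional[int] = None,
) -> bool:
    if not ignore_virtual and vp_size is not None:
        assert vp_stage is not None, "vp_stage is required when virtual stages are considered"
        if vp_stage != 0:
            return False  # assumed: only model chunk 0 holds the first layers
    return pp_rank == 0


def is_last_stage(
    pp_rank: int,
    pp_size: int,
    vp_size: Optional[int],
    ignore_virtual: bool = True,
    vp_stage: Optional[int] = None,
) -> bool:
    if not ignore_virtual and vp_size is not None:
        assert vp_stage is not None, "vp_stage is required when virtual stages are considered"
        if vp_stage != vp_size - 1:
            return False  # assumed: only the last model chunk holds the final layers
    return pp_rank == pp_size - 1
```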

core.parallel_state.is_rank_in_embedding_group(ignore_virtual=True, vp_stage=None)

Return True if the current rank is in the embedding group, False otherwise.

core.parallel_state.is_rank_in_position_embedding_group()

Return True if the current rank is in the position embedding group, False otherwise.

core.parallel_state.get_virtual_pipeline_model_parallel_rank()

Return the virtual pipeline-parallel rank.

core.parallel_state.set_virtual_pipeline_model_parallel_rank(rank)

Set the virtual pipeline-parallel rank.

core.parallel_state.get_virtual_pipeline_model_parallel_world_size()

Return the virtual pipeline-parallel world size.

core.parallel_state.get_tensor_model_parallel_src_rank()

Calculate the global rank corresponding to the first local rank in the tensor model parallel group.

core.parallel_state.get_model_parallel_src_rank()

Calculate the global rank corresponding to the first local rank in the model parallel group.

core.parallel_state.get_data_parallel_src_rank(with_context_parallel=False)

Calculate the global rank corresponding to the first local rank in the data parallel group.
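The *_src_rank functions return the global rank of local rank 0 in a group, which is typically used as the source rank for a broadcast within that group. A sketch with hypothetical helpers: the general form takes the group's global ranks, and the shortcut assumes the group occupies a contiguous, aligned block of ranks (true for tensor-parallel groups when tp is the fastest-varying type in the order).

```python
from typing import List


def group_src_rank(group_global_ranks: List[int]) -> int:
    """Global rank of local rank 0, i.e. the first member of the group's rank list."""
    return group_global_ranks[0]


def contiguous_src_rank(global_rank: int, group_size: int) -> int:
    """Shortcut for groups that occupy a contiguous, aligned block of global ranks."""
    return (global_rank // group_size) * group_size


# 16 GPUs, tp=2, order tp-dp-pp: rank 5 is in tp group [4, 5] and dp group [5, 7].
print(contiguous_src_rank(5, 2))  # 4
print(group_src_rank([5, 7]))  # 5
```

The general form matters for data-parallel groups, whose members are strided rather than contiguous in this layout.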

core.parallel_state.get_pipeline_model_parallel_first_rank()

Return the global rank of the first stage in the current rank’s pipeline.

core.parallel_state.get_pipeline_model_parallel_last_rank()

Return the global rank of the last stage in the current rank’s pipeline.

core.parallel_state.get_pipeline_model_parallel_next_rank()

Return the global rank that follows the caller in the pipeline.

core.parallel_state.get_pipeline_model_parallel_prev_rank()

Return the global rank that precedes the caller in the pipeline.
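Given the ordered global ranks of one pipeline, the four lookups above reduce to list indexing. In this sketch, `pipeline_neighbors` is hypothetical, and wrapping around at the ends (the last stage's next rank is the first stage, and vice versa) is an assumption.

```python
from typing import List, Tuple


def pipeline_neighbors(pipeline_ranks: List[int], global_rank: int) -> Tuple[int, int, int, int]:
    """Return (first, last, next, prev) global ranks for `global_rank` in its pipeline."""
    i = pipeline_ranks.index(global_rank)  # the caller's pipeline-parallel rank
    n = len(pipeline_ranks)
    first, last = pipeline_ranks[0], pipeline_ranks[-1]
    nxt = pipeline_ranks[(i + 1) % n]  # assumed: the last stage wraps to the first
    prev = pipeline_ranks[(i - 1) % n]  # assumed: the first stage wraps to the last
    return first, last, nxt, prev


# Pipeline [g1, g5, g9, g13] from the initialize_model_parallel example.
print(pipeline_neighbors([1, 5, 9, 13], 5))  # (1, 13, 9, 1)
```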

core.parallel_state.get_data_parallel_world_size(
with_context_parallel=False,
partial_data_parallel=False,
)#

Return world size for the data parallel group.

core.parallel_state.set_data_parallel_rank(rank)#

Return world size for the data parallel group.

core.parallel_state.get_data_parallel_rank(
with_context_parallel=False,
partial_data_parallel=False,
)#

Return caller’s rank in the data-parallel group.

core.parallel_state.get_context_parallel_world_size()#

Return world size for the context parallel group.

core.parallel_state.get_context_parallel_rank()#

Return caller’s rank in the context-parallel group.

core.parallel_state.get_tensor_and_context_parallel_world_size()#

Return world size for the joint tensor-model-parallel and context-parallel group.

core.parallel_state.get_tensor_and_context_parallel_rank()#

Return caller’s rank in the joint tensor-model-parallel and context-parallel group.
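The joint group's size is the product of the two sizes. If tensor-parallel ranks vary fastest in the rank layout (an assumption here), the caller's joint rank combines its two coordinates in CP-major, TP-minor order. Both helpers are hypothetical.

```python
def tp_cp_world_size(tp: int, cp: int) -> int:
    """Size of the joint TP x CP group."""
    return tp * cp


def tp_cp_rank(tp_rank: int, cp_rank: int, tp: int) -> int:
    """Joint rank, assuming TP varies fastest (CP-major, TP-minor order)."""
    return cp_rank * tp + tp_rank
```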

core.parallel_state.get_expert_model_parallel_group(check_initialized=True)#

Get the expert-model-parallel group the caller rank belongs to.

core.parallel_state.get_expert_model_parallel_src_rank()#

Calculate the global rank corresponding to the first local rank in the expert model parallel group.

core.parallel_state.get_expert_model_parallel_world_size()#

Return world size for the expert-model-parallel group.

core.parallel_state.set_expert_model_parallel_world_size(world_size)#

Set the expert-model-parallel world size.

core.parallel_state.get_expert_model_parallel_rank()#

Return caller’s rank in the expert-model-parallel group.

core.parallel_state.set_expert_model_parallel_rank(rank)#

Set expert-model-parallel rank.

core.parallel_state.get_expert_tensor_parallel_group(check_initialized=True)#

Get the expert-tensor-parallel group the caller rank belongs to.

core.parallel_state.get_expert_tensor_parallel_world_size()#

Return world size for the expert tensor parallel group.

core.parallel_state.set_expert_tensor_parallel_world_size(world_size)#

Set the expert-tensor-parallel world size.

core.parallel_state.get_expert_tensor_parallel_rank()#

Return caller’s rank in the expert-tensor-parallel group.

core.parallel_state.set_expert_tensor_parallel_rank(rank)#

Set the expert-tensor-parallel rank.

core.parallel_state.get_expert_tensor_and_model_parallel_group(check_initialized=True)#

Get the expert-tensor and expert-model group the caller rank belongs to.

core.parallel_state.get_expert_tensor_and_model_parallel_world_size()#

Return world size for the joint expert-tensor and expert-model parallel group, i.e. the expert-model-parallel world size times the expert-tensor-parallel world size.

core.parallel_state.get_expert_tensor_and_model_parallel_rank()#

Return caller’s rank in the joint expert-tensor- and expert-model-parallel group.

core.parallel_state.get_expert_tensor_model_pipeline_parallel_group(
check_initialized=True,
)#

Get expert tensor-model-pipeline parallel group.

core.parallel_state.get_expert_data_parallel_group(
check_initialized=True,
partial_expert_data_parallel=False,
)#

Get expert data parallel group.

core.parallel_state.get_data_modulo_expert_parallel_group(
partial_expert_data_parallel=False,
)#

[Deprecated] Get expert data parallel group. Use get_expert_data_parallel_group instead.

core.parallel_state.get_expert_data_parallel_group_gloo(
partial_expert_data_parallel=False,
)#

Get the expert data parallel group that uses the Gloo backend.

core.parallel_state.get_expert_data_parallel_rank(partial_expert_data_parallel=False)#

Return caller’s rank in the expert data parallel group.

core.parallel_state.get_expert_data_parallel_world_size(
partial_expert_data_parallel=False,
)#

Return world size for the expert data parallel group.
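Expert layers can use their own tensor-parallel size (ETP) and an expert-model-parallel size (EP), while sharing the pipeline size (PP) with dense layers. Under that assumption, the expert group sizes above follow from the world size. The expert data-parallel size is what remains after dividing by ETP * EP * PP, so it can differ from the dense data-parallel size. expert_group_sizes is a hypothetical helper.

```python
from typing import Dict


def expert_group_sizes(world_size: int, etp: int, ep: int, pp: int) -> Dict[str, int]:
    """Expert group sizes, assuming world_size = etp * ep * pp * expert_dp."""
    tmp = etp * ep * pp  # expert tensor-model-pipeline group size
    if world_size % tmp != 0:
        raise ValueError("world_size must be divisible by etp * ep * pp")
    return {
        "expert_tensor_and_model": etp * ep,
        "expert_tensor_model_pipeline": tmp,
        "expert_data": world_size // tmp,
    }
```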

core.parallel_state.get_intra_distributed_optimizer_instance_group(check_initialized=True)#

Get the group of all GPUs in a distributed optimizer instance.

core.parallel_state.get_inter_distributed_optimizer_instance_group(check_initialized=True)#

Get the group spanning the different distributed optimizer instances. Attention and MLP/expert layers share the same inter-instance group, so only inter_partial_expert_data_parallel_group is built, and it is returned here.
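One way to picture the intra- and inter-instance groups is to split a data-parallel rank list into equal instances. The sketch assumes each instance is a contiguous chunk of that list, and that the inter-instance groups connect the ranks at the same position in every chunk. optimizer_instance_groups is a hypothetical helper, not the module's implementation.

```python
from typing import List, Sequence, Tuple


def optimizer_instance_groups(
    dp_ranks: Sequence[int], num_instances: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Split DP ranks into intra-instance chunks and strided inter-instance groups."""
    if len(dp_ranks) % num_instances != 0:
        raise ValueError("len(dp_ranks) must be divisible by num_instances")
    intra_size = len(dp_ranks) // num_instances
    ranks = list(dp_ranks)
    # Intra-instance: contiguous chunks, one per optimizer instance.
    intra = [ranks[i * intra_size:(i + 1) * intra_size] for i in range(num_instances)]
    # Inter-instance: the j-th rank of every chunk, taken with stride intra_size.
    inter = [ranks[j::intra_size] for j in range(intra_size)]
    return intra, inter
```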

core.parallel_state._set_global_memory_buffer()#

Initialize the global memory buffer.

core.parallel_state._set_global_symmetric_memory_buffer()#

Initialize the global symmetric memory buffer.

core.parallel_state.get_global_memory_buffer()#

Return the global GlobalMemoryBuffer object.

core.parallel_state.get_global_symmetric_memory_buffer()#

Return the global GlobalSymmetricMemoryBuffer object.

core.parallel_state.destroy_global_memory_buffer()#

Set the global memory buffer to None.

core.parallel_state.destroy_global_symmetric_memory_buffer()#

Set the global symmetric memory buffer to None.
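The set/get/destroy functions above manage a single module-level buffer object. Here is a pure-Python sketch of that lifecycle. It assumes the buffer hands out views into storage keyed by (name, element type), and reallocates only when a request needs more elements. SketchMemoryBuffer and its functions are hypothetical stand-ins. They use the standard array module rather than tensors, and the symmetric buffer is not modeled.

```python
from array import array
from math import prod
from typing import Dict, Optional, Sequence, Tuple


class SketchMemoryBuffer:
    """Hypothetical stand-in: reuse one flat buffer per (name, typecode)."""

    def __init__(self) -> None:
        self._buffers: Dict[Tuple[str, str], array] = {}

    def get(self, shape: Sequence[int], typecode: str, name: str) -> memoryview:
        required = prod(shape)
        key = (name, typecode)
        buf = self._buffers.get(key)
        if buf is None or len(buf) < required:
            # Allocate a new zero-filled buffer large enough for this request.
            buf = array(typecode, bytes(required * array(typecode).itemsize))
            self._buffers[key] = buf
        # Return a flat view of the first `required` elements.
        return memoryview(buf)[:required]


_GLOBAL_BUFFER: Optional[SketchMemoryBuffer] = None


def set_global_buffer() -> None:
    global _GLOBAL_BUFFER
    _GLOBAL_BUFFER = SketchMemoryBuffer()


def get_global_buffer() -> Optional[SketchMemoryBuffer]:
    return _GLOBAL_BUFFER


def destroy_global_buffer() -> None:
    global _GLOBAL_BUFFER
    _GLOBAL_BUFFER = None
```

Reusing storage this way avoids repeated allocations for temporaries of similar size. The trade-off is that views from the same (name, typecode) key alias each other.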

core.parallel_state.get_all_ranks()#

Get caller’s rank in tensor-model-parallel, data-parallel, context-parallel, pipeline-model-parallel and expert-model-parallel groups.
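For the dense (non-expert) groups, a global rank can be decomposed into per-dimension ranks by mixed-radix arithmetic. The sketch assumes the rank order is tensor-parallel fastest, then context-parallel, then data-parallel, then pipeline-parallel. The expert-model-parallel rank is not modeled. decompose_rank is a hypothetical helper.

```python
from typing import Dict


def decompose_rank(global_rank: int, tp: int, cp: int, dp: int, pp: int) -> Dict[str, int]:
    """Split a global rank into (tp, cp, dp, pp) coordinates, TP varying fastest."""
    if not 0 <= global_rank < tp * cp * dp * pp:
        raise ValueError("global_rank out of range")
    return {
        "tp": global_rank % tp,
        "cp": (global_rank // tp) % cp,
        "dp": (global_rank // (tp * cp)) % dp,
        "pp": global_rank // (tp * cp * dp),
    }
```

With every size equal to 2, global rank 13 maps to tp=1, cp=0, dp=1, pp=1, since 1 + 2 * (0 + 2 * (1 + 2 * 1)) = 13.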

core.parallel_state.destroy_model_parallel()#

Set the groups to None.