core.parallel_state#
Model and data parallel groups.
Module Contents#
Classes#
| Class | Description |
|---|---|
| `RankGenerator` | A class for generating rank groups for different modes of parallelism. |
Functions#
| Function | Description |
|---|---|
| `get_nccl_options` | Set the NCCL process group options. |
| `update_pg_timeout` | Update the timeout for all process groups or a specific process group. |
| `create_group` | Creates a ProcessGroup. |
| `generate_masked_orthogonal_rank_groups` | Generate orthogonal parallel groups based on the parallel size and mask. |
| `create_hierarchical_groups` | Create hierarchical groups for a set of ranks. |
| `default_embedding_ranks` | Return the default ranks that constitute the stages on which the word embeddings live. For most models, these are the first and last pipeline stages. |
| `default_position_embedding_ranks` | Return the default ranks that constitute the stages on which the position embeddings live. For most models, this is only the first pipeline stage. |
| `overwrite_nccl_comm_cfgs` | Overwrite the nccl_comm_cfgs for the given pg_name with the given key_value_pair. |
| `initialize_model_parallel` | Initialize model data parallel groups. |
| `is_initialized` | Useful for code segments that may be accessed with or without mpu initialization. |
| `is_unitialized` | Check if parallel state has been initialized. |
| `model_parallel_is_initialized` | Check if model- and data-parallel groups are initialized. |
| `get_model_parallel_group` | Get the model-parallel group the caller rank belongs to. |
| `get_tensor_model_parallel_group` | Get the tensor-model-parallel group the caller rank belongs to. |
| `get_pipeline_model_parallel_group` | Get the pipeline-model-parallel group the caller rank belongs to. |
| `get_data_parallel_group` | Get the data-parallel group the caller rank belongs to. |
| `get_data_parallel_group_gloo` | Get the Gloo data-parallel group the caller rank belongs to. |
| `get_context_parallel_group` | Get the context-parallel group the caller rank belongs to. |
| `get_context_parallel_global_ranks` | Get all global ranks of the context-parallel group that the caller rank belongs to. |
| `get_hierarchical_context_parallel_groups` | Get the inner ring of context parallel group the caller rank belongs to. |
| `get_embedding_group` | Get the embedding group the caller rank belongs to. |
| `get_position_embedding_group` | Get the position embedding group the caller rank belongs to. |
| `get_amax_reduction_group` | Get the FP8 amax reduction group the caller rank belongs to. |
| `get_tensor_and_data_parallel_group` | Get the tensor- and data-parallel group the caller rank belongs to. |
| `get_tensor_and_context_parallel_group` | Get the tensor- and context-parallel group the caller rank belongs to. |
| `set_tensor_model_parallel_world_size` | Set the tensor-model-parallel size. |
| `set_pipeline_model_parallel_world_size` | Set the pipeline-model-parallel size. |
| `set_virtual_pipeline_model_parallel_world_size` | Set the virtual pipeline-model-parallel size. |
| `get_tensor_model_parallel_world_size` | Return world size for the tensor-model-parallel group. |
| `get_pipeline_model_parallel_world_size` | Return world size for the pipeline-model-parallel group. |
| `set_tensor_model_parallel_rank` | Set tensor-model-parallel rank. |
| `set_pipeline_model_parallel_rank` | Set pipeline-model-parallel rank. |
| `get_tensor_model_parallel_rank` | Return caller’s rank for the tensor-model-parallel group. |
| `get_pipeline_model_parallel_rank` | Return caller’s rank for the pipeline-model-parallel group. |
| `is_pipeline_first_stage` | Return True if in the first pipeline model-parallel stage, False otherwise. |
| `is_pipeline_last_stage` | Return True if in the last pipeline-model-parallel stage, False otherwise. |
| `is_rank_in_embedding_group` | Return True if the current rank is in the embedding group, False otherwise. |
| `is_rank_in_position_embedding_group` | Return True if the current rank is in the position embedding group, False otherwise. |
| `get_virtual_pipeline_model_parallel_rank` | Return the virtual pipeline-parallel rank. |
| `set_virtual_pipeline_model_parallel_rank` | Set the virtual pipeline-parallel rank. |
| `get_virtual_pipeline_model_parallel_world_size` | Return the virtual pipeline-parallel world size. |
| `get_tensor_model_parallel_src_rank` | Calculate the global rank corresponding to the first local rank in the tensor model parallel group. |
| `get_model_parallel_src_rank` | Calculate the global rank corresponding to the first local rank in the model parallel group. |
| `get_data_parallel_src_rank` | Calculate the global rank corresponding to the first local rank in the data parallel group. |
| `get_pipeline_model_parallel_first_rank` | Return the global rank of the first stage in the current rank’s pipeline. |
| `get_pipeline_model_parallel_last_rank` | Return the global rank of the last stage in the current rank’s pipeline. |
| `get_pipeline_model_parallel_next_rank` | Return the global rank that follows the caller in the pipeline. |
| `get_pipeline_model_parallel_prev_rank` | Return the global rank that precedes the caller in the pipeline. |
| `get_data_parallel_world_size` | Return world size for the data parallel group. |
| `set_data_parallel_rank` | Set the data-parallel rank. |
| `get_data_parallel_rank` | Return caller’s rank in the data-parallel group. |
| `get_context_parallel_world_size` | Return world size for the context parallel group. |
| `get_context_parallel_rank` | Return caller’s rank in the context-parallel group. |
| `get_tensor_and_context_parallel_world_size` | Return world size for the tensor and context-parallel group. |
| `get_tensor_and_context_parallel_rank` | Return caller’s rank in the joint tensor-model-parallel and context-parallel group. |
| `get_expert_model_parallel_group` | Get the expert-model-parallel group the caller rank belongs to. |
| `get_expert_model_parallel_src_rank` | Calculate the global rank corresponding to the first local rank in the expert model parallel group. |
| `get_expert_model_parallel_world_size` | Return world size for the expert-model-parallel group. |
| `set_expert_model_parallel_world_size` | Sets the expert-model-parallel world size. |
| `get_expert_model_parallel_rank` | Return caller’s rank in the expert-model-parallel group. |
| `set_expert_model_parallel_rank` | Set expert-model-parallel rank. |
| `get_expert_tensor_parallel_group` | Get the expert-tensor-parallel group the caller rank belongs to. |
| `get_expert_tensor_parallel_world_size` | Return world size for the expert tensor parallel group. |
| `set_expert_tensor_parallel_world_size` | Set the expert tensor parallel world size. |
| `get_expert_tensor_parallel_rank` | Return caller’s rank for the expert tensor parallel group. |
| `set_expert_tensor_parallel_rank` | Set the expert tensor parallel rank. |
| `get_expert_tensor_and_model_parallel_group` | Get the expert-tensor and expert-model group the caller rank belongs to. |
| `get_expert_tensor_and_model_parallel_world_size` | Return world size for the expert model parallel group times expert tensor parallel group. |
| `get_expert_tensor_and_model_parallel_rank` | Return caller’s rank in the joint tensor- and expert-model-parallel group. |
| `get_expert_tensor_model_pipeline_parallel_group` | Get expert tensor-model-pipeline parallel group. |
| `get_expert_data_parallel_group` | Get expert data parallel group. |
| `get_data_modulo_expert_parallel_group` | [Deprecated] Get expert data parallel group. |
| `get_expert_data_parallel_group_gloo` | Get expert data parallel group-gloo. |
| `get_expert_data_parallel_rank` | Return caller’s rank in the expert data parallel group. |
| `get_expert_data_parallel_world_size` | Return world size for the expert data parallel group. |
| `get_intra_distributed_optimizer_instance_group` | Get the group of all GPUs in a distributed optimizer instance. |
| `get_inter_distributed_optimizer_instance_group` | Get the group spanning the different distributed optimizer instances. |
| `_set_global_memory_buffer` | Initialize the global memory buffer. |
| `_set_global_symmetric_memory_buffer` | Initialize the global symmetric memory buffer. |
| `get_global_memory_buffer` | Return the global GlobalMemoryBuffer object. |
| `get_global_symmetric_memory_buffer` | Return the global GlobalSymmetricMemoryBuffer object. |
| `destroy_global_memory_buffer` | Sets the global memory buffer to None. |
| `destroy_global_symmetric_memory_buffer` | Sets the global symmetric memory buffer to None. |
| `get_all_ranks` | Get caller’s rank in tensor-model-parallel, data-parallel, context-parallel, pipeline-model-parallel and expert-model-parallel groups. |
| `destroy_model_parallel` | Set the groups to none. |
Data#
API#
- core.parallel_state.logger#
‘getLogger(…)’
- core.parallel_state._TENSOR_MODEL_PARALLEL_GROUP#
None
- core.parallel_state._PIPELINE_MODEL_PARALLEL_GROUP#
None
- core.parallel_state._MODEL_PARALLEL_GROUP#
None
- core.parallel_state._EMBEDDING_GROUP#
None
- core.parallel_state._POSITION_EMBEDDING_GROUP#
None
- core.parallel_state._DATA_PARALLEL_GROUP#
None
- core.parallel_state._DATA_PARALLEL_GROUP_GLOO#
None
- core.parallel_state._TENSOR_AND_DATA_PARALLEL_GROUP#
None
- core.parallel_state._EXPERT_MODEL_PARALLEL_GROUP#
None
- core.parallel_state._EXPERT_TENSOR_PARALLEL_GROUP#
None
- core.parallel_state._EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP#
None
- core.parallel_state._EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP#
None
- core.parallel_state._EXPERT_DATA_PARALLEL_GROUP#
None
- core.parallel_state._EXPERT_DATA_PARALLEL_GROUP_GLOO#
None
- core.parallel_state._INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP#
None
- core.parallel_state._INTRA_PARTIAL_EXPERT_DATA_PARALLEL_GROUP_GLOO#
None
- core.parallel_state._INTER_PARTIAL_EXPERT_DATA_PARALLEL_GROUP#
None
- core.parallel_state._MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE#
None
- core.parallel_state._MPU_EXPERT_MODEL_PARALLEL_RANK#
None
- core.parallel_state._MPU_EXPERT_TENSOR_PARALLEL_WORLD_SIZE#
None
- core.parallel_state._MPU_EXPERT_TENSOR_PARALLEL_RANK#
None
- core.parallel_state._VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK#
None
- core.parallel_state._VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE#
None
- core.parallel_state._MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE#
None
- core.parallel_state._MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE#
None
- core.parallel_state._MPU_DATA_PARALLEL_WORLD_SIZE#
None
- core.parallel_state._MPU_DATA_PARALLEL_RANK#
None
- core.parallel_state._MPU_TENSOR_MODEL_PARALLEL_RANK#
None
- core.parallel_state._MPU_PIPELINE_MODEL_PARALLEL_RANK#
None
- core.parallel_state._EMBEDDING_GLOBAL_RANKS#
None
- core.parallel_state._POSITION_EMBEDDING_GLOBAL_RANKS#
None
- core.parallel_state._PIPELINE_GLOBAL_RANKS#
None
- core.parallel_state._DATA_PARALLEL_GLOBAL_RANKS#
None
- core.parallel_state._TENSOR_MODEL_PARALLEL_GLOBAL_RANKS#
None
- core.parallel_state._EXPERT_MODEL_PARALLEL_RANKS#
None
- core.parallel_state._MODEL_PARALLEL_GLOBAL_RANKS#
None
- core.parallel_state._CONTEXT_PARALLEL_GROUP#
None
- core.parallel_state._CONTEXT_PARALLEL_GLOBAL_RANKS#
None
- core.parallel_state._HIERARCHICAL_CONTEXT_PARALLEL_GROUPS#
None
- core.parallel_state._DATA_PARALLEL_GROUP_WITH_CP#
None
- core.parallel_state._DATA_PARALLEL_GROUP_WITH_CP_GLOO#
None
- core.parallel_state._DATA_PARALLEL_GLOBAL_RANKS_WITH_CP#
None
- core.parallel_state._INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP#
None
- core.parallel_state._INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP_GLOO#
None
- core.parallel_state._TENSOR_AND_CONTEXT_PARALLEL_GROUP#
None
- core.parallel_state._TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP#
None
- core.parallel_state._INTRA_DISTRIBUTED_OPTIMIZER_INSTANCE_GROUP#
None
- core.parallel_state._GLOBAL_MEMORY_BUFFER#
None
- core.parallel_state._GLOBAL_SYMMETRIC_MEMORY_BUFFER#
None
- core.parallel_state._global_process_group_list#
None
- core.parallel_state.get_nccl_options(pg_name, nccl_comm_cfgs)#
Set the NCCL process group options.
- Parameters:
pg_name (str) – process group name
nccl_comm_cfgs (dict) – nccl communicator configurations
When an option (e.g., max_ctas) is not found in the config, use the NCCL default setting.
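For context, nccl_comm_cfgs is typically parsed from the YAML file passed as nccl_communicator_config_path to initialize_model_parallel (documented below). A hypothetical configuration, expressed as the Python dict such a file would parse into; the group names and values here are illustrative assumptions, not a complete schema:
```python
# Hypothetical nccl_comm_cfgs dict, keyed by process-group name. The option
# names min_ctas / max_ctas / cga_cluster_size come from the
# nccl_communicator_config_path description below; the group names ("tp",
# "dp") and the values are illustrative assumptions only.
nccl_comm_cfgs = {
    "tp": {"min_ctas": 4, "max_ctas": 8, "cga_cluster_size": 2},
    "dp": {"max_ctas": 16},  # unspecified options fall back to NCCL defaults
}
```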
- core.parallel_state.update_pg_timeout(
- timeout: datetime.timedelta,
- pg: Optional[torch._C._distributed_c10d.ProcessGroup] = None,
Update the timeout for all process groups or a specific process group. Synchronize the process groups before updating the timeout.
- Parameters:
timeout (datetime.timedelta) – The timeout to set for the process group(s)
pg (Optional[torch._C._distributed_c10d.ProcessGroup], default=None) – The process group to update the timeout for. If None, all process groups are updated.
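A short usage sketch based on the signature above; the only assumptions are that distributed state is already initialized and that the module is importable as megatron.core.parallel_state:
```python
from datetime import timedelta

from megatron.core import parallel_state

# Raise the timeout for every known process group, e.g. before a long
# checkpoint save; pass pg=... to target a single group instead.
parallel_state.update_pg_timeout(timeout=timedelta(minutes=60))
```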
- core.parallel_state.create_group(
- ranks=None,
- timeout=None,
- backend=None,
- pg_options=None,
- use_local_synchronization=False,
- group_desc=None,
Creates a ProcessGroup.
- core.parallel_state.generate_masked_orthogonal_rank_groups(
- world_size: int,
- parallel_size: List[int],
- mask: List[bool],
Generate orthogonal parallel groups based on the parallel size and mask.
- Parameters:
world_size (int) – world size
parallel_size (List[int]) – The parallel size of each orthogonal parallel type. For example, if tensor_parallel_size = 2, pipeline_model_parallel_size = 3, data_parallel_size = 4, and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4].
mask (List[bool]) – The mask controls which parallel methods the generated groups represent. If mask[i] is True, the generated group contains the i-th parallelism method. For example, if parallel_size = [tp_size, pp_size, dp_size] and mask = [True, False, True], then the generated group is the tp-dp group; if mask = [False, True, False], then the generated group is the pp group.
Algorithm: For orthogonal parallelism, such as tp/dp/pp/cp, the global rank and the local ranks satisfy the following equation:

global_rank = tp_rank + dp_rank * tp_size + pp_rank * tp_size * dp_size    (1)

with tp_rank in [0, tp_size), dp_rank in [0, dp_size), pp_rank in [0, pp_size).

Suppose we want the dp_group (tp_size * pp_size groups of dp_size ranks each; for example, if the GPU count is 8, the order is 'tp-pp-dp', and the sizes are '2-2-2', the dp groups are [[0, 4], [1, 5], [2, 6], [3, 7]]). The tp_rank and pp_rank are combined to form the dp_group_index:

dp_group_index = tp_rank + pp_rank * tp_size    (2)

Given tp_rank and pp_rank satisfying equation (2), and dp_rank ranging over range(0, dp_size), the ranks in dp_group[dp_group_index] satisfy equation (1). This function solves exactly that problem. For example, if parallel_size = [tp_size, dp_size, pp_size] = [2, 3, 4] and mask = [False, True, False], then:

dp_group_index(0) = tp_rank(0) + pp_rank(0) * 2
dp_group_index(1) = tp_rank(1) + pp_rank(0) * 2
…
dp_group_index(7) = tp_rank(1) + pp_rank(3) * 2

dp_group[0] = 0 + range(0, 3) * 2 + 0 = [0, 2, 4]
dp_group[1] = 1 + range(0, 3) * 2 + 0 = [1, 3, 5]
…
dp_group[7] = 1 + range(0, 3) * 2 + 3 * 2 * 3 = [19, 21, 23]
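As a concrete illustration of this decomposition, the following stand-alone Python sketch enumerates masked rank groups from scratch. It is written for clarity, not taken from the module; the helper name masked_rank_groups is invented here, and the ordering of the returned groups may differ from the library's:
```python
from itertools import product
from typing import List


def masked_rank_groups(parallel_size: List[int], mask: List[bool]) -> List[List[int]]:
    """Enumerate rank groups for the masked (True) dimensions.

    Ranks are laid out so that the first entry of `parallel_size` varies
    fastest, matching equation (1) above. Illustrative sketch only.
    """
    # Stride of each dimension in the flattened global rank.
    strides, acc = [], 1
    for size in parallel_size:
        strides.append(acc)
        acc *= size

    masked = [i for i, m in enumerate(mask) if m]        # dims inside each group
    unmasked = [i for i, m in enumerate(mask) if not m]  # dims indexing the groups

    groups = []
    for fixed in product(*(range(parallel_size[i]) for i in unmasked)):
        base = sum(r * strides[i] for r, i in zip(fixed, unmasked))
        group = [
            base + sum(r * strides[i] for r, i in zip(free, masked))
            for free in product(*(range(parallel_size[i]) for i in masked))
        ]
        groups.append(sorted(group))
    return groups


# Reproduces the dp groups from the docstring example:
# parallel_size = [tp_size, dp_size, pp_size] = [2, 3, 4], mask selects dp only.
print(masked_rank_groups([2, 3, 4], [False, True, False])[0])   # [0, 2, 4]
print(masked_rank_groups([2, 3, 4], [False, True, False])[-1])  # [19, 21, 23]
```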
- core.parallel_state.create_hierarchical_groups(
- rank,
- ranks,
- hierarchical_group_sizes,
- create_gloo_process_groups=False,
- pg_options=None,
- timeout=None,
- group_desc=None,
Create hierarchical groups for a set of ranks.

Taking a group size of 16 as an example, we have a total of 16 GPUs denoted by g0 … g15. If the hierarchical group sizes are [2, 2, 4], we use 2 GPUs in the first and second levels of sub-groups, and 4 GPUs in the last level of sub-groups. The present function will create 8 level-1 sub-groups, 8 level-2 sub-groups, and 4 level-3 sub-groups:

8 level-1 sub-groups: [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
8 level-2 sub-groups: [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
4 level-3 sub-groups: [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
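The listing above follows the same stride logic as the orthogonal decomposition: a level-i sub-group collects ranks that differ only in the i-th factor of the hierarchical sizes. A small stand-alone sketch that reproduces the listing (the helper hierarchical_rank_groups is hypothetical; the real function also creates the torch.distributed process groups):
```python
import math
from itertools import product
from typing import List


def hierarchical_rank_groups(ranks: List[int], level_sizes: List[int]) -> List[List[List[int]]]:
    """For each level i, group ranks that differ only in the i-th factor.

    Illustrative sketch only: it just enumerates the rank lists.
    """
    assert len(ranks) == math.prod(level_sizes)
    strides, acc = [], 1
    for s in level_sizes:
        strides.append(acc)
        acc *= s

    levels = []
    for i, size in enumerate(level_sizes):
        other = [d for d in range(len(level_sizes)) if d != i]
        groups = []
        # Iterate the fixed dimensions with the lowest dimension varying fastest,
        # so the groups come out in the same order as the docstring example.
        for fixed in product(*(range(level_sizes[d]) for d in reversed(other))):
            base = sum(r * strides[d] for r, d in zip(fixed, reversed(other)))
            groups.append([ranks[base + k * strides[i]] for k in range(size)])
        levels.append(groups)
    return levels


levels = hierarchical_rank_groups(list(range(16)), [2, 2, 4])
print(levels[0][0], levels[0][1])  # [0, 1] [2, 3]
print(levels[1][0], levels[1][1])  # [0, 2] [1, 3]
print(levels[2][0], levels[2][1])  # [0, 4, 8, 12] [1, 5, 9, 13]
```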
- class core.parallel_state.RankGenerator(
- tp: int,
- ep: int,
- dp: int,
- pp: int,
- cp: int,
- order: str,
- rank_offset: int = 0,
Bases: object

A class for generating rank groups for different modes of parallelism.
Initialization
- get_mask(order: str, token: str)#
Create a mask for the specified tokens based on the given order.
- Parameters:
order (str) – The order of parallelism types (e.g., ‘tp-dp-pp’).
token (str) – The specific parallelism types to include in the mask, separated by hyphens (e.g., ‘tp-dp’).
- get_ranks(token)#
Get rank group by input token.
- Parameters:
token (str) – The type of ranks to get. To obtain multiple parallel types, separate them with a hyphen ‘-’. For example, to obtain the TP_DP group, the token should be ‘tp-dp’.
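A hedged usage sketch based on the constructor and get_ranks() documented above; the import path megatron.core.parallel_state is assumed, and the exact group layout depends on the order string and library version:
```python
# Sketch of how RankGenerator might be used, based on the documented
# constructor and get_ranks(). The import path is an assumption; group
# ordering and contents depend on the `order` string and library version.
from megatron.core.parallel_state import RankGenerator

# 16 ranks total: tp=2, cp=1, ep=1, dp=4, pp=2.
gen = RankGenerator(tp=2, ep=1, dp=4, pp=2, cp=1, order="tp-cp-ep-dp-pp")

print(gen.get_ranks("tp"))     # 8 tensor-model-parallel groups of 2 ranks each
print(gen.get_ranks("dp"))     # 4 data-parallel groups of 4 ranks each
print(gen.get_ranks("tp-dp"))  # 2 joint tensor- and data-parallel groups of 8 ranks
```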
- core.parallel_state.default_embedding_ranks(pp_ranks)#
Return the default ranks that constitute the stages on which the word embeddings live. For most models, these are the first and last pipeline stages.
- core.parallel_state.default_position_embedding_ranks(pp_ranks)#
Return the default ranks that constitute the stages on which the position embeddings live. For most models, this is only the first pipeline stage.
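Based only on the two descriptions above, a minimal hypothetical reimplementation of these defaults (not the module's code; the real signatures may take additional arguments):
```python
# Hypothetical reimplementations matching the documented behaviour: word
# embeddings live on the first and last pipeline stages, position embeddings
# only on the first. Not the module's actual code.
from typing import List


def default_embedding_ranks(pp_ranks: List[int]) -> List[int]:
    if len(pp_ranks) == 1:
        return [pp_ranks[0]]
    return [pp_ranks[0], pp_ranks[-1]]


def default_position_embedding_ranks(pp_ranks: List[int]) -> List[int]:
    return [pp_ranks[0]]


print(default_embedding_ranks([0, 4, 8, 12]))           # [0, 12]
print(default_position_embedding_ranks([0, 4, 8, 12]))  # [0]
```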
- core.parallel_state.overwrite_nccl_comm_cfgs(nccl_comm_cfgs, pg_name, key_value_pair)#
Overwrite the nccl_comm_cfgs for the given pg_name with the given key_value_pair.
- core.parallel_state.initialize_model_parallel(
- tensor_model_parallel_size: int = 1,
- pipeline_model_parallel_size: int = 1,
- virtual_pipeline_model_parallel_size: Optional[int] = None,
- pipeline_model_parallel_comm_backend: Optional[str] = None,
- use_sharp: bool = False,
- context_parallel_size: int = 1,
- hierarchical_context_parallel_sizes: Optional[List[int]] = None,
- expert_model_parallel_size: int = 1,
- num_distributed_optimizer_instances: int = 1,
- expert_tensor_parallel_size: Optional[int] = None,
- nccl_communicator_config_path: Optional[str] = None,
- distributed_timeout_minutes: int = 30,
- order: str = 'tp-cp-ep-dp-pp',
- get_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None,
- get_position_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None,
- create_gloo_process_groups: bool = True,
- high_priority_stream_groups: Optional[List[str]] = None,
- sharp_enabled_group: Optional[str] = None,
Initialize model data parallel groups.
- Parameters:
tensor_model_parallel_size (int, default = 1) – The number of GPUs to split individual tensors across.
pipeline_model_parallel_size (int, default = 1) – The number of tensor parallel GPU groups to split the Transformer layers across. For example, if tensor_model_parallel_size is 4 and pipeline_model_parallel_size is 2, the model will be split into 2 groups of 4 GPUs.
virtual_pipeline_model_parallel_size (int, optional) –
The number of stages that each pipeline group will have, interleaving as necessary. If None, no interleaving is performed. For example, if tensor_model_parallel_size is 1, pipeline_model_parallel_size is 4, virtual_pipeline_model_parallel_size is 2, and there are 16 transformer layers in the model, the model will be split into 8 stages with two layers each and each GPU would get 2 stages as such (layer number starting with 1):
GPU 0: [1, 2] [9, 10]
GPU 1: [3, 4] [11, 12]
GPU 2: [5, 6] [13, 14]
GPU 3: [7, 8] [15, 16]
pipeline_model_parallel_comm_backend (str, optional) – The backend to use for pipeline parallel communication. If None, the default backend will be used.
use_sharp (bool, default = False) – Set the use of SHARP for the collective communications of data-parallel process groups. When True, run barrier within each data-parallel process group, which specifies the SHARP application target groups.
context_parallel_size (int, default = 1) –
The number of tensor parallel GPU groups to split the network input sequence length across. Compute of attention module requires tokens of full sequence length, so GPUs in a context parallel group need to communicate with each other to exchange information of other sequence chunks. Each GPU and its counterparts in other tensor parallel groups compose a context parallel group.
For example, assume we have 8 GPUs; if the tensor model parallel size is 4 and the context parallel size is 2, the network input will be split into two sequence chunks, which are processed by 2 different groups of 4 GPUs. One chunk is processed by GPU0-3, the other chunk is processed by GPU4-7. Four groups are built to do context parallel communications: [GPU0, GPU4], [GPU1, GPU5], [GPU2, GPU6], and [GPU3, GPU7].
Context parallelism partitions sequence length, so it has no impact on weights, which means weights are duplicated among GPUs in a context parallel group. Hence, weight gradients all-reduce is required in backward. For simplicity, we piggyback GPUs of context parallelism on data parallel group for weight gradient all-reduce.
expert_model_parallel_size (int, default = 1) – The number of Mixture of Experts parallel GPUs in each expert parallel group.
num_distributed_optimizer_instances (int, default = 1) – The number of distributed optimizer replicas across the data-parallel domain.
expert_tensor_parallel_size (int, default = tp_size) – The number of GPUs to split individual tensors of expert.
nccl_communicator_config_path (str, default = None) – Path to the yaml file of NCCL communicator configurations.
min_ctas, max_ctas, and cga_cluster_size can be set for each communicator.
distributed_timeout_minutes (int, default = 30) – Timeout, in minutes, for operations executed against distributed process groups. See PyTorch documentation at https://pytorch.org/docs/stable/distributed.html for caveats.
order (str, default=tp-dp-pp) – The rank initialization order of parallelism. Now we support tp-dp-pp and tp-pp-dp orders.
get_embedding_ranks (Callable[[List[int], Optional[int]], List[int]], optional, default=None) – A function that takes in a list of ranks for a pipeline group and returns those ranks that should have embeddings.
get_position_embedding_ranks (Callable[[List[int], Optional[int]], List[int]], optional, default=None) – A function that takes in a list of ranks for a pipeline group, and returns those ranks that should have position embeddings.
create_gloo_process_groups (bool, default = True) – Create Gloo process groups if set to True. If set to False, Gloo process groups are not created and calls to get Gloo process groups will result in assertion errors.
high_priority_stream_groups (List[str], default = None) – Specify which communicator groups should use high priority streams during creation. Assigning high priority to communication streams ensures that communication kernels are scheduled with higher priority, minimizing the exposed communication when it is overlapped with other computation kernels. Example: initialize_model_parallel(…, high_priority_stream_groups=['dp_cp', 'ep_dp'])
sharp_enabled_group (str, default = None) – Specify which communicator group should use SHARP communication. This option is only valid when use_sharp is True. By default (None), it is enabled from dp group. Available options (choose one): [dp, dp_replica]
Let’s say we have a total of 16 GPUs denoted by g0 … g15, and we use 2 GPUs to parallelize the model tensor and 4 GPUs to parallelize the model pipeline. The present function will create 8 tensor model-parallel groups, 4 pipeline model-parallel groups, and 8 data-parallel groups as:

8 data-parallel groups: [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
8 tensor model-parallel groups: [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
4 pipeline model-parallel groups: [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]

Note that for efficiency, the caller should make sure adjacent ranks are on the same DGX box. For example, if we are using 2 DGX-1 boxes with a total of 16 GPUs, ranks 0 to 7 belong to the first box and ranks 8 to 15 belong to the second box.
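A minimal end-to-end usage sketch, assuming the module is importable as megatron.core.parallel_state, the script is launched with torchrun so that RANK/WORLD_SIZE are set, and there is one GPU per process:
```python
# Minimal launch sketch, e.g. `torchrun --nproc_per_node=8 this_script.py`,
# assuming the package is importable as megatron.core.parallel_state.
import torch

from megatron.core import parallel_state

# torch.distributed reads RANK / WORLD_SIZE / MASTER_ADDR from the environment.
torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())

# 8 GPUs split as tp=2 x pp=2 x dp=2 (the dp size is inferred from the world size).
parallel_state.initialize_model_parallel(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=2,
)

print(
    "tp rank", parallel_state.get_tensor_model_parallel_rank(),
    "pp rank", parallel_state.get_pipeline_model_parallel_rank(),
    "dp rank", parallel_state.get_data_parallel_rank(),
)

parallel_state.destroy_model_parallel()
torch.distributed.destroy_process_group()
```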
- core.parallel_state.is_initialized()#
Useful for code segments that may be accessed with or without mpu initialization
- core.parallel_state.is_unitialized() → bool#
Check if parallel state has been initialized
Deprecated. Use is_initialized instead.
- core.parallel_state.model_parallel_is_initialized()#
Check if model- and data-parallel groups are initialized.
- core.parallel_state.get_model_parallel_group(check_initialized=True)#
Get the model-parallel group the caller rank belongs to.
- core.parallel_state.get_tensor_model_parallel_group(check_initialized=True)#
Get the tensor-model-parallel group the caller rank belongs to.
- core.parallel_state.get_pipeline_model_parallel_group(check_initialized=True)#
Get the pipeline-model-parallel group the caller rank belongs to.
- core.parallel_state.get_data_parallel_group(
- with_context_parallel=False,
- partial_data_parallel=False,
Get the data-parallel group the caller rank belongs to.
- core.parallel_state.get_data_parallel_group_gloo(
- with_context_parallel=False,
- partial_data_parallel=False,
Get the Gloo data-parallel group the caller rank belongs to.
- core.parallel_state.get_context_parallel_group(check_initialized=True)#
Get the context-parallel group the caller rank belongs to.
- core.parallel_state.get_context_parallel_global_ranks(check_initialized=True)#
Get all global ranks of the context-parallel group that the caller rank belongs to.
- core.parallel_state.get_hierarchical_context_parallel_groups(check_initialized=True)#
Get the inner ring of context parallel group the caller rank belongs to.
- core.parallel_state.get_embedding_group(check_initialized=True)#
Get the embedding group the caller rank belongs to.
- core.parallel_state.get_position_embedding_group(check_initialized=True)#
Get the position embedding group the caller rank belongs to.
- core.parallel_state.get_amax_reduction_group(
- with_context_parallel=False,
- tp_only_amax_red=False,
Get the FP8 amax reduction group the caller rank belongs to.
- core.parallel_state.get_tensor_and_data_parallel_group(
- check_initialized=True,
- with_context_parallel=False,
Get the tensor- and data-parallel group the caller rank belongs to.
- core.parallel_state.get_tensor_and_context_parallel_group(check_initialized=True)#
Get the tensor- and context-parallel group the caller rank belongs to.
- core.parallel_state.set_tensor_model_parallel_world_size(world_size)#
Set the tensor-model-parallel size
- core.parallel_state.set_pipeline_model_parallel_world_size(world_size)#
Set the pipeline-model-parallel size
- core.parallel_state.set_virtual_pipeline_model_parallel_world_size(world_size)#
Set the virtual pipeline-model-parallel size.
- core.parallel_state.get_tensor_model_parallel_world_size()#
Return world size for the tensor-model-parallel group.
- core.parallel_state.get_pipeline_model_parallel_world_size()#
Return world size for the pipeline-model-parallel group.
- core.parallel_state.set_tensor_model_parallel_rank(rank)#
Set tensor-model-parallel rank.
- core.parallel_state.set_pipeline_model_parallel_rank(rank)#
Set pipeline-model-parallel rank.
- core.parallel_state.get_tensor_model_parallel_rank()#
Return caller’s rank for the tensor-model-parallel group.
- core.parallel_state.get_pipeline_model_parallel_rank()#
Return caller’s rank for the pipeline-model-parallel group.
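The rank and world-size getters above are what sharded layers typically use to decide which slice of a tensor the caller owns. An illustrative sketch (the partition arithmetic is generic, not code from this module):
```python
# Illustrative: split a weight's output dimension across the
# tensor-model-parallel group using the getters documented above.
from megatron.core import parallel_state


def local_output_slice(global_output_dim: int) -> slice:
    tp_size = parallel_state.get_tensor_model_parallel_world_size()
    tp_rank = parallel_state.get_tensor_model_parallel_rank()
    assert global_output_dim % tp_size == 0, "dimension must divide evenly"
    per_rank = global_output_dim // tp_size
    return slice(tp_rank * per_rank, (tp_rank + 1) * per_rank)
```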
- core.parallel_state.is_pipeline_first_stage(ignore_virtual=True, vp_stage=None)#
Return True if in the first pipeline model-parallel stage, False otherwise.
- core.parallel_state.is_pipeline_last_stage(ignore_virtual=True, vp_stage=None)#
Return True if in the last pipeline-model-parallel stage, False otherwise.
- core.parallel_state.is_rank_in_embedding_group(ignore_virtual=True, vp_stage=None)#
Return True if the current rank is in the embedding group, False otherwise.
- core.parallel_state.is_rank_in_position_embedding_group()#
Return True if the current rank is in the position embedding group, False otherwise.
- core.parallel_state.get_virtual_pipeline_model_parallel_rank()#
Return the virtual pipeline-parallel rank.
- core.parallel_state.set_virtual_pipeline_model_parallel_rank(rank)#
Set the virtual pipeline-parallel rank.
- core.parallel_state.get_virtual_pipeline_model_parallel_world_size()#
Return the virtual pipeline-parallel world size.
- core.parallel_state.get_tensor_model_parallel_src_rank()#
Calculate the global rank corresponding to the first local rank in the tensor model parallel group.
- core.parallel_state.get_model_parallel_src_rank()#
Calculate the global rank corresponding to the first local rank in the model parallel group.
- core.parallel_state.get_data_parallel_src_rank(with_context_parallel=False)#
Calculate the global rank corresponding to the first local rank in the data parallel group.
- core.parallel_state.get_pipeline_model_parallel_first_rank()#
Return the global rank of the first stage in the current rank’s pipeline.
- core.parallel_state.get_pipeline_model_parallel_last_rank()#
Return the global rank of the last stage in the current rank’s pipeline.
- core.parallel_state.get_pipeline_model_parallel_next_rank()#
Return the global rank that follows the caller in the pipeline.
- core.parallel_state.get_pipeline_model_parallel_prev_rank()#
Return the global rank that precedes the caller in the pipeline.
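The pipeline-neighbor getters are normally paired with point-to-point communication. A simplified sketch of a forward handoff between adjacent stages (real schedules also exchange tensor shapes and overlap communication):
```python
# Simplified forward-pass handoff between adjacent pipeline stages.
import torch

from megatron.core import parallel_state


def send_forward(output: torch.Tensor) -> None:
    if not parallel_state.is_pipeline_last_stage():
        torch.distributed.send(
            output, dst=parallel_state.get_pipeline_model_parallel_next_rank()
        )


def recv_forward(buffer: torch.Tensor) -> None:
    if not parallel_state.is_pipeline_first_stage():
        torch.distributed.recv(
            buffer, src=parallel_state.get_pipeline_model_parallel_prev_rank()
        )
```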
- core.parallel_state.get_data_parallel_world_size(
- with_context_parallel=False,
- partial_data_parallel=False,
Return world size for the data parallel group.
- core.parallel_state.set_data_parallel_rank(rank)#
Set the data-parallel rank.
- core.parallel_state.get_data_parallel_rank(
- with_context_parallel=False,
- partial_data_parallel=False,
Return caller’s rank in the data-parallel group.
- core.parallel_state.get_context_parallel_world_size()#
Return world size for the context parallel group.
- core.parallel_state.get_context_parallel_rank()#
Return caller’s rank in the context-parallel group.
- core.parallel_state.get_tensor_and_context_parallel_world_size()#
Return world size for the tensor and context-parallel group.
- core.parallel_state.get_tensor_and_context_parallel_rank()#
Return caller’s rank in the joint tensor-model-parallel and context-parallel group.
- core.parallel_state.get_expert_model_parallel_group(check_initialized=True)#
Get the expert-model-parallel group the caller rank belongs to.
- core.parallel_state.get_expert_model_parallel_src_rank()#
Calculate the global rank corresponding to the first local rank in the expert model parallel group.
- core.parallel_state.get_expert_model_parallel_world_size()#
Return world size for the expert-model-parallel group.
- core.parallel_state.set_expert_model_parallel_world_size(world_size)#
Sets the expert-model-parallel world size.
- core.parallel_state.get_expert_model_parallel_rank()#
Return caller’s rank in the expert-model-parallel group.
- core.parallel_state.set_expert_model_parallel_rank(rank)#
Set expert-model-parallel rank.
- core.parallel_state.get_expert_tensor_parallel_group(check_initialized=True)#
Get the expert-tensor-parallel group the caller rank belongs to.
- core.parallel_state.get_expert_tensor_parallel_world_size()#
Return world size for the expert tensor parallel group.
- core.parallel_state.set_expert_tensor_parallel_world_size(world_size)#
Set the expert tensor parallel world size.
- core.parallel_state.get_expert_tensor_parallel_rank()#
Return caller’s rank for the expert tensor parallel group.
- core.parallel_state.set_expert_tensor_parallel_rank(rank)#
Set the expert tensor parallel rank.
- core.parallel_state.get_expert_tensor_and_model_parallel_group(check_initialized=True)#
Get the expert-tensor and expert-model group the caller rank belongs to.
- core.parallel_state.get_expert_tensor_and_model_parallel_world_size()#
Return world size for the expert model parallel group times expert tensor parallel group.
- core.parallel_state.get_expert_tensor_and_model_parallel_rank()#
Return caller’s rank in the joint tensor- and expert-model-parallel group.
- core.parallel_state.get_expert_tensor_model_pipeline_parallel_group(
- check_initialized=True,
Get expert tensor-model-pipeline parallel group.
- core.parallel_state.get_expert_data_parallel_group(
- check_initialized=True,
- partial_expert_data_parallel=False,
Get expert data parallel group.
- core.parallel_state.get_data_modulo_expert_parallel_group(
- partial_expert_data_parallel=False,
[Deprecated] Get expert data parallel group.
- core.parallel_state.get_expert_data_parallel_group_gloo(
- partial_expert_data_parallel=False,
Get expert data parallel group-gloo.
- core.parallel_state.get_expert_data_parallel_rank(partial_expert_data_parallel=False)#
Return caller’s rank in the expert data parallel group.
- core.parallel_state.get_expert_data_parallel_world_size(
- partial_expert_data_parallel=False,
Return world size for the expert data parallel group.
- core.parallel_state.get_intra_distributed_optimizer_instance_group(check_initialized=True)#
Get the group of all GPUs in a distributed optimizer instance.
- core.parallel_state.get_inter_distributed_optimizer_instance_group(check_initialized=True)#
Get the group spanning the different distributed optimizer instances. Attention and MLP/Expert share the same inter-instance group, so only the inter_partial_expert_data_parallel_group is built, and it is returned here.
- core.parallel_state._set_global_memory_buffer()#
Initialize the global memory buffer.
- core.parallel_state._set_global_symmetric_memory_buffer()#
Initialize the global symmetric memory buffer.
- core.parallel_state.get_global_memory_buffer()#
Return the global GlobalMemoryBuffer object
- core.parallel_state.get_global_symmetric_memory_buffer()#
Return the global GlobalSymmetricMemoryBuffer object
- core.parallel_state.destroy_global_memory_buffer()#
Sets the global memory buffer to None
- core.parallel_state.destroy_global_symmetric_memory_buffer()#
Sets the global symmetric memory buffer to None
- core.parallel_state.get_all_ranks()#
Get caller’s rank in tensor-model-parallel, data-parallel, context-parallel, pipeline-model-parallel and expert-model-parallel groups.
- core.parallel_state.destroy_model_parallel()#
Set the groups to none.