bridge.training.comm_overlap#

Module Contents#

Classes#

TPOverlapCfg

Dataclass for linear layer TP overlap config.

PipelineOverlapCfg

Dataclass for pipeline TP overlap config.

RingExchangeOverlapCfg

Dataclass for ring exchange TP overlap config.

BulkOverlapCfg

Dataclass for bulk TP overlap config.

TransformerLayerTPOverlapCfg

Dataclass for transformer layer TP overlap config.

_CommOverlapConfig

CommOverlapConfig

Configuration for communication overlap optimizations in distributed training.

Data#

API#

class bridge.training.comm_overlap.TPOverlapCfg#

Dataclass for linear layer TP overlap config.

class bridge.training.comm_overlap.PipelineOverlapCfg#

Bases: bridge.training.comm_overlap.TPOverlapCfg

Dataclass for pipeline TP overlap config.

num_sm: int#

None

cga_size: int#

None

num_splits: int#

None

set_sm_margin: bool#

None

fp8_buf: bool#

False

atomic_gemm: bool#

False

method: str#

'pipeline'
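
A minimal construction sketch (field values are illustrative, and the import path `megatron.bridge.training.comm_overlap` is inferred from the qualified names elsewhere on this page):

```python
from megatron.bridge.training.comm_overlap import PipelineOverlapCfg

# Pipelined overlap splits one GEMM into chunks so each chunk's compute
# can hide the communication of its neighbor. Values are illustrative.
proj_fprop_overlap = PipelineOverlapCfg(
    num_sm=24,           # SMs reserved for the communication kernel
    cga_size=2,          # cooperative grid array size for the comm kernel
    num_splits=4,        # number of pipelined GEMM chunks
    set_sm_margin=True,  # leave an SM margin so comm runs alongside compute
)
# fp8_buf and atomic_gemm default to False; method is fixed to 'pipeline'.
```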

class bridge.training.comm_overlap.RingExchangeOverlapCfg#

Bases: bridge.training.comm_overlap.TPOverlapCfg

Dataclass for ring exchange TP overlap config.

aggregate: bool#

False

method: str#

'ring_exchange'

num_sm: int#

1

cga_size: int#

1

set_sm_margin: bool#

False

fp8_buf: bool#

False

atomic_gemm: bool#

False
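
Every field here carries a default, so the zero-argument form is already a complete config; a sketch under the same import-path assumption:

```python
from megatron.bridge.training.comm_overlap import RingExchangeOverlapCfg

# Ring exchange circulates P2P chunks around the tensor-parallel group,
# reserving only a single SM for communication by default.
qkv_fprop_overlap = RingExchangeOverlapCfg()

# Optional illustrative override: aggregate chunks per exchange step.
fc1_dgrad_overlap = RingExchangeOverlapCfg(aggregate=True)
```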

class bridge.training.comm_overlap.BulkOverlapCfg#

Bases: bridge.training.comm_overlap.TPOverlapCfg

Dataclass for bulk TP overlap config.

num_sm: int#

None

cga_size: int#

None

set_sm_margin: bool#

None

method: str#

'bulk'
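
Here num_sm, cga_size, and set_sm_margin carry no defaults, so all three must be supplied; a sketch with illustrative values:

```python
from megatron.bridge.training.comm_overlap import BulkOverlapCfg

# Bulk overlap runs a GEMM concurrently with an independent bulk
# all-gather or reduce-scatter rather than splitting the GEMM itself.
qkv_dgrad_overlap = BulkOverlapCfg(num_sm=4, cga_size=2, set_sm_margin=False)
```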

class bridge.training.comm_overlap.TransformerLayerTPOverlapCfg#

Dataclass for transformer layer TP overlap config.

qkv_dgrad: bridge.training.comm_overlap.TPOverlapCfg#

None

qkv_wgrad: bridge.training.comm_overlap.TPOverlapCfg#

None

fc1_dgrad: bridge.training.comm_overlap.TPOverlapCfg#

None

fc1_wgrad: bridge.training.comm_overlap.TPOverlapCfg#

None

qkv_fprop: bridge.training.comm_overlap.TPOverlapCfg#

None

proj_dgrad: bridge.training.comm_overlap.TPOverlapCfg#

None

fc1_fprop: bridge.training.comm_overlap.TPOverlapCfg#

None

fc2_dgrad: bridge.training.comm_overlap.TPOverlapCfg#

None

proj_fprop: bridge.training.comm_overlap.TPOverlapCfg#

None

fc2_fprop: bridge.training.comm_overlap.TPOverlapCfg#

None
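
The ten fields assign one overlap strategy to each TP-affected GEMM of a transformer layer (QKV, attention projection, and the two MLP linears, across fprop, dgrad, and wgrad). A composition sketch with illustrative choices, mirroring the bulk/ring-exchange/pipeline split used by the presets below:

```python
from megatron.bridge.training.comm_overlap import (
    BulkOverlapCfg,
    PipelineOverlapCfg,
    RingExchangeOverlapCfg,
    TransformerLayerTPOverlapCfg,
)

# One strategy per GEMM (illustrative): bulk overlap for dgrad/wgrad GEMMs,
# ring exchange for the forward all-gathers, and pipelined reduce-scatter
# for the two output projections.
layer_overlap = TransformerLayerTPOverlapCfg(
    qkv_dgrad=BulkOverlapCfg(num_sm=4, cga_size=2, set_sm_margin=False),
    qkv_wgrad=BulkOverlapCfg(num_sm=8, cga_size=2, set_sm_margin=False),
    fc1_dgrad=BulkOverlapCfg(num_sm=2, cga_size=2, set_sm_margin=False),
    fc1_wgrad=BulkOverlapCfg(num_sm=4, cga_size=2, set_sm_margin=False),
    qkv_fprop=RingExchangeOverlapCfg(),
    proj_dgrad=RingExchangeOverlapCfg(),
    fc1_fprop=RingExchangeOverlapCfg(),
    fc2_dgrad=RingExchangeOverlapCfg(),
    proj_fprop=PipelineOverlapCfg(
        num_sm=24, cga_size=2, num_splits=4, set_sm_margin=True
    ),
    fc2_fprop=PipelineOverlapCfg(
        num_sm=16, cga_size=2, num_splits=4, set_sm_margin=True
    ),
)
```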

bridge.training.comm_overlap.userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192#

'TransformerLayerTPOverlapCfg(…)'

bridge.training.comm_overlap.userbuffers_fp8_h100_h8192_tp4_mbs1_seqlen8192#

'TransformerLayerTPOverlapCfg(…)'

bridge.training.comm_overlap.userbuffers_bf16_b200_h8192_tp2_mbs1_seqlen8192#

'TransformerLayerTPOverlapCfg(…)'

bridge.training.comm_overlap.userbuffers_fp8_b200_h8192_tp2_mbs1_seqlen8192#

'TransformerLayerTPOverlapCfg(…)'

bridge.training.comm_overlap.userbuffers_bf16_h100_h16384_tp8_cp2_mbs1_seqlen8192#

'TransformerLayerTPOverlapCfg(…)'

bridge.training.comm_overlap.userbuffers_fp8_h100_h16384_tp8_cp2_mbs1_seqlen8192#

'TransformerLayerTPOverlapCfg(…)'

bridge.training.comm_overlap.userbuffers_bf16_b200_h16384_tp4_cp2_mbs1_seqlen8192#

'TransformerLayerTPOverlapCfg(…)'

bridge.training.comm_overlap.userbuffers_fp8_b200_h16384_tp4_cp2_mbs1_seqlen8192#

'TransformerLayerTPOverlapCfg(…)'

bridge.training.comm_overlap.userbuffers_fp8_h100_h8192_tp2_mbs1_seqlen4096_lora#

'TransformerLayerTPOverlapCfg(…)'

bridge.training.comm_overlap.userbuffers_fp8_h100_h16384_tp4_mbs1_seqlen2048_lora#

'TransformerLayerTPOverlapCfg(…)'

bridge.training.comm_overlap.userbuffers_bf16_h100_h6144_tp2_mbs2_seqlen2048#

'TransformerLayerTPOverlapCfg(…)'

bridge.training.comm_overlap.userbuffers_fp8_h100_h6144_tp2_mbs2_seqlen2048#

'TransformerLayerTPOverlapCfg(…)'

bridge.training.comm_overlap.userbuffers_bf16_h100_h12288_tp4_mbs1_seqlen2048#

'TransformerLayerTPOverlapCfg(…)'

bridge.training.comm_overlap.userbuffers_fp8_h100_h12288_tp4_mbs1_seqlen2048#

'TransformerLayerTPOverlapCfg(…)'

bridge.training.comm_overlap.userbuffers_bf16_b200_h12288_tp4_mbs1_seqlen2048#

'TransformerLayerTPOverlapCfg(…)'

bridge.training.comm_overlap.userbuffers_fp8_b200_h12288_tp4_mbs1_seqlen2048#

'TransformerLayerTPOverlapCfg(…)'

bridge.training.comm_overlap.userbuffers_bf16_b200_h6144_tp2_mbs1_seqlen4096#

'TransformerLayerTPOverlapCfg(…)'

bridge.training.comm_overlap.userbuffers_bf16_b200_h18432_tp8_mbs1_seqlen4096#

'TransformerLayerTPOverlapCfg(…)'

bridge.training.comm_overlap.userbuffers_fp8_b200_h18432_tp8_mbs1_seqlen4096#

'TransformerLayerTPOverlapCfg(…)'
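
Each preset name encodes the shape it was tuned for: precision (bf16/fp8), GPU (h100/b200), hidden size, tensor-parallel size, optional context-parallel size, micro-batch size, and sequence length. Selecting one is a plain import of a module-level constant; a sketch:

```python
from megatron.bridge.training.comm_overlap import (
    userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192,
)

# Ready-made TransformerLayerTPOverlapCfg tuned for BF16 on H100 with
# hidden size 8192, TP=4, micro-batch size 1, sequence length 8192.
tp_overlap_cfg = userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192
```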

class bridge.training.comm_overlap._CommOverlapConfig#

tp_comm_overlap: bool#

None

tp_comm_overlap_cfg: dict#

None

tp_comm_bootstrap_backend: str#

None

overlap_p2p_comm: bool#

None

batch_p2p_comm: bool#

None

overlap_grad_reduce: bool#

None

overlap_param_gather: bool#

None

overlap_param_gather_with_optimizer_step: bool#

None

align_param_gather: bool#

None

bucket_size: int#

None

defer_embedding_wgrad_compute: bool#

None

wgrad_deferral_limit: int#

None

class bridge.training.comm_overlap.CommOverlapConfig#

Configuration for communication overlap optimizations in distributed training.

This class manages tensor parallel, pipeline parallel, and data parallel communication overlap settings to improve training performance.

tp_comm_overlap: bool#

None

tp_comm_overlap_cfg: Optional[bridge.training.comm_overlap.TransformerLayerTPOverlapCfg]#

None

tp_comm_bootstrap_backend: Optional[str]#

None

overlap_p2p_comm: Optional[bool]#

None

batch_p2p_comm: Optional[bool]#

None

overlap_grad_reduce: Optional[bool]#

None

overlap_param_gather: Optional[bool]#

None

overlap_param_gather_with_optimizer_step: Optional[bool]#

None

align_param_gather: Optional[bool]#

None

bucket_size: Optional[int]#

None

defer_embedding_wgrad_compute: Optional[bool]#

None

wgrad_deferral_limit: Optional[int]#

None

data_parallel_size: Optional[int]#

None

__post_init__()#

_get_model_comm_overlap_cfgs(
model_cfg: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
) → bridge.training.comm_overlap._CommOverlapConfig#

_get_optimizer_overlap_cfgs(
model_cfg: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
) → bridge.training.comm_overlap._CommOverlapConfig#

_apply_cfgs(src_cfg, dest_cfg)#

_override_user_cfgs(comm_overlap_cfg)#

_set_num_cuda_device_max_connections(
model_cfg: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
)#

setup(
model_config: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider,
optimizer_config: megatron.core.optimizer.OptimizerConfig,
ddp_config: megatron.core.distributed.DistributedDataParallelConfig,
) → None#

Set up communication overlap configurations for the model, optimizer, and DDP.

Parameters:
  • model_config – Model configuration containing parallelism settings

  • optimizer_config – Optimizer configuration for gradient overlap settings

  • ddp_config – Distributed data parallel configuration
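
A usage sketch: `model_cfg`, `optimizer_cfg`, and `ddp_cfg` stand in for the provider and config objects built elsewhere in a training script (assumptions, not part of this module), and the chosen preset should match the run's parallel layout:

```python
from megatron.bridge.training.comm_overlap import (
    CommOverlapConfig,
    userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192,
)

comm_overlap = CommOverlapConfig(
    tp_comm_overlap=True,
    tp_comm_overlap_cfg=userbuffers_bf16_h100_h8192_tp4_mbs1_seqlen8192,
    overlap_grad_reduce=True,   # overlap gradient reduction with backward
    overlap_param_gather=True,  # overlap parameter all-gather with forward
)

# setup() resolves these flags into the model, optimizer, and DDP configs
# in place; model_cfg, optimizer_cfg, and ddp_cfg are assumed to come from
# the surrounding training setup.
comm_overlap.setup(
    model_config=model_cfg,
    optimizer_config=optimizer_cfg,
    ddp_config=ddp_cfg,
)
```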