bridge.models.common.unimodal#
Module Contents#
Functions#
| Function | Description |
|---|---|
| `unimodal_build_distributed_models` | Build model stages and wrap for distributed training. |
| `_print_num_params` | Print the number of parameters in the model on rank 0. |
| `_ddp_wrap` | Wrap model with Distributed Data Parallel (DDP) or Fully Sharded Data Parallel (FSDP). |
| `build_virtual_pipeline_stages` | Build virtual pipeline stages if using virtual pipeline parallelism. |
| `to_empty_if_meta_device` | Move tensors to device if not meta device; otherwise materialize with `empty_like()`. |
Data#
API#
- bridge.models.common.unimodal.logger#
getLogger(...)
- bridge.models.common.unimodal.unimodal_build_distributed_models(
- build_model_func: Callable,
- transformer_config: megatron.core.transformer.TransformerConfig,
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
- ddp_config: megatron.core.distributed.DistributedDataParallelConfig | None = None,
- overlap_param_gather_with_optimizer_step: bool = False,
- use_megatron_fsdp: bool = False,
- use_torch_fsdp2: bool = False,
- wrap_with_ddp: bool = True,
- data_parallel_random_init: bool = True,
- mixed_precision_wrapper: Callable[[Any, megatron.core.transformer.MegatronModule], megatron.core.transformer.MegatronModule] | None = Float16Module,
- pre_wrap_hook: Callable[[list[megatron.core.transformer.MegatronModule]], list[megatron.core.transformer.MegatronModule]] | None = None,
- model_type: megatron.core.enums.ModelType = ModelType.encoder_or_decoder,
Build model stages and wrap for distributed training.
Shared helper for unimodal models (GPT, Mamba, etc.) that share the same procedure for distributed model initialization. Performs the following steps in order:
1. Build virtual pipeline stages (one per VP rank, or a single stage if no VP)
2. Apply `pre_wrap_hook`
3. Set tensor model parallel attributes on all parameters
4. Move model to GPU (unless using FSDP2 or CPU/meta-device initialization)
5. Apply mixed precision wrapper (e.g. `Float16Module`)
6. Materialize meta-device tensors if `init_model_with_meta_device` is set
7. Optionally wrap with DDP/FSDP
- Parameters:
build_model_func – Callable that builds a single model stage (e.g. `ModelBuilder.build_model`).
transformer_config – TransformerConfig; used for VP size, precision, and device placement.
pg_collection – Model communication process groups.
ddp_config – DistributedDataParallel configuration. Required when `wrap_with_ddp=True`.
overlap_param_gather_with_optimizer_step – Whether to overlap parameter gather with optimizer step.
use_megatron_fsdp – Whether to use Megatron FSDP.
use_torch_fsdp2 – Whether to use Torch FSDP 2.0.
wrap_with_ddp – Set to False to skip the DDP/FSDP wrapper.
data_parallel_random_init – Whether to broadcast parameters from data-parallel rank 0.
mixed_precision_wrapper – Mixed precision wrapper applied per model stage, e.g. `Float16Module`. Pass `None` to skip.
pre_wrap_hook – Hook applied to the model stage list before any wrapping.
model_type – Deprecated flag, only used for backwards compatibility.
- Returns:
List of model stages, wrapped and ready for distributed training.
- bridge.models.common.unimodal._print_num_params(
- model: list[megatron.core.transformer.MegatronModule],
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
Print the number of parameters in the model on rank 0.
Only prints on data parallel rank 0 to avoid duplicate output. Shows parameter count per (tensor parallel, pipeline parallel) rank.
- Parameters:
model – List of model modules to count parameters from.
pg_collection – Model communication process groups.
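The core of such a helper is a sum over every stage in the model list. A minimal sketch, assuming PyTorch (the rank-0 guard and the (TP, PP) rank tagging of the real helper are omitted):

```python
import torch


def count_parameters(model: list[torch.nn.Module]) -> int:
    """Total parameter count across all stages in a model stage list."""
    return sum(p.numel() for stage in model for p in stage.parameters())


# In the real helper, this count is printed only on data parallel rank 0
# and is reported per (tensor parallel, pipeline parallel) rank.
```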
- bridge.models.common.unimodal._wrap_with_mp_wrapper(
- model_list: list[megatron.core.transformer.MegatronModule],
- transformer_config: megatron.core.transformer.TransformerConfig,
- mixed_precision_wrapper: Callable[[Any, megatron.core.transformer.MegatronModule], megatron.core.transformer.MegatronModule] | None = Float16Module,
- bridge.models.common.unimodal._ddp_wrap(
- model: list[megatron.core.transformer.MegatronModule],
- data_parallel_random_init: bool,
- ddp_config: megatron.core.distributed.DistributedDataParallelConfig,
- overlap_param_gather_with_optimizer_step: bool,
- use_megatron_fsdp: bool = False,
- use_torch_fsdp2: bool = False,
- *,
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
Wrap model with Distributed Data Parallel (DDP) or Fully Sharded Data Parallel (FSDP).
- Parameters:
model – List of model modules to wrap.
data_parallel_random_init – Whether to broadcast parameters from rank 0.
ddp_config – Configuration for distributed data parallel.
overlap_param_gather_with_optimizer_step – Whether to disable bucketing so that parameter gather can overlap with the optimizer step.
use_megatron_fsdp – Whether to use Megatron FSDP.
use_torch_fsdp2 – Whether to use PyTorch FSDP v2 instead of DDP.
pg_collection – Model communication process groups.
- Returns:
List of DDP/FSDP wrapped model modules
- Return type:
list[MegatronModule]
- bridge.models.common.unimodal.build_virtual_pipeline_stages(
- build_model_func: Callable,
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
- vp_size: int | None,
- model_type: megatron.core.enums.ModelType = ModelType.encoder_or_decoder,
Build virtual pipeline stages if using virtual pipeline parallelism.
- Parameters:
build_model_func – Function from `ModelBuilder` that builds a single stage of the model.
pg_collection – Model communication process groups.
vp_size – Virtual pipeline parallel size. If `None` or PP size is 1, a single stage is built.
model_type – Deprecated flag, only used for backwards compatibility.
- Returns:
List of model stages. Contains one entry per VP rank, or a single entry if VP is not enabled.
- bridge.models.common.unimodal.to_empty_if_meta_device(
- module: torch.nn.Module,
- *,
- device: torch.device,
- recurse=True,
Move tensors to device if not meta device; otherwise materialize with empty_like().
Officially, torch suggests to_empty() for meta-device materialization. Under the hood, torch.empty_like() is applied to every parameter and buffer (see _apply), which can accidentally overwrite buffers that were precomputed during construction. Since the goal is to materialize only the tensors that live on the meta device, this function checks each tensor's device first and performs a plain move to the destination when the tensor is not on the meta device.
- Parameters:
module – The target module to apply this transformation to.
device – The desired device of the parameters and buffers in this module.
recurse – Whether parameters and buffers of submodules should be recursively moved to the specified device.
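A minimal sketch of the device check described above, assuming PyTorch. It follows the same `_apply` mechanism the docstring mentions, but is a simplified stand-in, not the library's implementation:

```python
import torch


def to_empty_if_meta_device_sketch(
    module: torch.nn.Module, *, device: torch.device
) -> torch.nn.Module:
    """Move non-meta tensors to `device`; materialize meta tensors.

    Unlike module.to_empty(), tensors that already hold real data keep
    their values; only meta-device tensors get fresh empty storage.
    """

    def _move(t: torch.Tensor) -> torch.Tensor:
        if t.device.type == "meta":
            # Materialize: allocate uninitialized storage on the target device.
            return torch.empty_like(t, device=device)
        # Already materialized: a plain move preserves precomputed values.
        return t.to(device)

    return module._apply(_move)
```

This is why the function is safe for modules whose buffers are filled with precomputed values at construction time: those buffers take the `t.to(device)` branch and survive the move intact.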