bridge.models.common.unimodal#

Module Contents#

Functions#

unimodal_build_distributed_models

Build model stages and wrap for distributed training.

_print_num_params

Print the number of parameters in the model on rank 0.

_wrap_with_mp_wrapper

_ddp_wrap

Wrap model with Distributed Data Parallel (DDP) or Fully Sharded Data Parallel (FSDP).

build_virtual_pipeline_stages

Build virtual pipeline stages if using virtual pipeline parallelism.

to_empty_if_meta_device

Move tensors to device if not meta device; otherwise materialize with empty_like().

Data#

API#

bridge.models.common.unimodal.logger#

'getLogger(…)'

bridge.models.common.unimodal.unimodal_build_distributed_models(
build_model_func: Callable,
transformer_config: megatron.core.transformer.TransformerConfig,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
ddp_config: megatron.core.distributed.DistributedDataParallelConfig | None = None,
overlap_param_gather_with_optimizer_step: bool = False,
use_megatron_fsdp: bool = False,
use_torch_fsdp2: bool = False,
wrap_with_ddp: bool = True,
data_parallel_random_init: bool = True,
mixed_precision_wrapper: Callable[[Any, megatron.core.transformer.MegatronModule], megatron.core.transformer.MegatronModule] | None = Float16Module,
pre_wrap_hook: Callable[[list[megatron.core.transformer.MegatronModule]], list[megatron.core.transformer.MegatronModule]] | None = None,
model_type: megatron.core.enums.ModelType = ModelType.encoder_or_decoder,
) list[megatron.core.transformer.MegatronModule]#

Build model stages and wrap for distributed training.

Shared helper for unimodal models (GPT, Mamba, etc.) that follow the same distributed model initialization procedure. Performs the following steps in order:

  1. Build virtual pipeline stages (one per VP rank, or a single stage if no VP)

  2. Apply pre_wrap_hook

  3. Set tensor model parallel attributes on all parameters

  4. Move model to GPU (unless using FSDP2 or CPU/meta-device initialization)

  5. Apply mixed precision wrapper (e.g. Float16Module)

  6. Materialize meta-device tensors if init_model_with_meta_device is set

  7. Optionally wrap with DDP/FSDP

Parameters:
  • build_model_func – Callable that builds a single model stage (e.g. ModelBuilder.build_model).

  • transformer_config – TransformerConfig; used for VP size, precision, and device placement.

  • pg_collection – Model communication process groups.

  • ddp_config – DistributedDataParallel configuration. Required when wrap_with_ddp=True.

  • overlap_param_gather_with_optimizer_step – Whether to overlap parameter gather with optimizer step.

  • use_megatron_fsdp – Whether to use Megatron FSDP.

  • use_torch_fsdp2 – Whether to use Torch FSDP 2.0.

  • wrap_with_ddp – Set to False to skip the DDP/FSDP wrapper.

  • data_parallel_random_init – Whether to broadcast parameters from data-parallel rank 0.

  • mixed_precision_wrapper – Mixed precision wrapper applied per model stage, e.g. Float16Module. Pass None to skip.

  • pre_wrap_hook – Hook applied to the model stage list before any wrapping.

  • model_type – Deprecated flag, only used for backwards compatibility.

Returns:

List of model stages, wrapped and ready for distributed training.
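The build/wrap ordering can be sketched in plain Python. This is a simplified illustration of steps 1, 2, 5, and 7 only (TP attribute setup, device placement, and meta-device materialization are omitted), and the helper names and signatures here are hypothetical, not the Megatron Bridge API:

```python
def build_and_wrap_stages(build_stage, vp_size=None, pre_wrap_hook=None,
                          mp_wrapper=None, ddp_wrap=None):
    """Hypothetical sketch of the stage build/wrap order."""
    # 1. One stage per VP rank, or a single stage when VP is off.
    stages = [build_stage(i) for i in range(vp_size or 1)]
    # 2. The pre-wrap hook sees the whole stage list before any wrapping.
    if pre_wrap_hook is not None:
        stages = pre_wrap_hook(stages)
    # 5. Mixed precision wrapper (e.g. Float16Module) applied per stage.
    if mp_wrapper is not None:
        stages = [mp_wrapper(s) for s in stages]
    # 7. DDP/FSDP wrapper applied last, per stage.
    if ddp_wrap is not None:
        stages = [ddp_wrap(s) for s in stages]
    return stages

# Dummy callables standing in for real model construction and wrappers:
stages = build_and_wrap_stages(
    build_stage=lambda i: {"stage": i},
    vp_size=2,
    mp_wrapper=lambda s: {**s, "fp16": True},
    ddp_wrap=lambda s: {**s, "ddp": True},
)
```

The key design point is that `pre_wrap_hook` receives the raw stage list before any wrapper is applied, so it can still see and mutate the unwrapped modules.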

bridge.models.common.unimodal._print_num_params(
model: list[megatron.core.transformer.MegatronModule],
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
) None#

Print the number of parameters in the model on rank 0.

Only prints on data parallel rank 0 to avoid duplicate output. Shows parameter count per (tensor parallel, pipeline parallel) rank.

Parameters:
  • model – List of model modules to count parameters from

  • pg_collection – Model communication process groups.
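The per-rank count reduces to a sum over all model chunks held by the rank. A minimal sketch, assuming plain `torch.nn` modules in place of `MegatronModule` and ignoring the data-parallel rank-0 gating and per-(TP, PP)-rank reporting:

```python
import torch.nn as nn

def num_params(model: list[nn.Module]) -> int:
    # Total parameter count across all model chunks held by this rank
    # (with virtual pipeline parallelism, the list has one chunk per VP rank).
    return sum(p.numel() for chunk in model for p in chunk.parameters())

chunks = [nn.Linear(4, 3), nn.Linear(3, 2)]  # (4*3 + 3) + (3*2 + 2) = 23
```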

bridge.models.common.unimodal._wrap_with_mp_wrapper(
model_list: list[megatron.core.transformer.MegatronModule],
transformer_config: megatron.core.transformer.TransformerConfig,
mixed_precision_wrapper: Callable[[Any, megatron.core.transformer.MegatronModule], megatron.core.transformer.MegatronModule] | None = Float16Module,
) list[megatron.core.transformer.MegatronModule]#
bridge.models.common.unimodal._ddp_wrap(
model: list[megatron.core.transformer.MegatronModule],
data_parallel_random_init: bool,
ddp_config: megatron.core.distributed.DistributedDataParallelConfig,
overlap_param_gather_with_optimizer_step: bool,
use_megatron_fsdp: bool = False,
use_torch_fsdp2: bool = False,
*,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
) list[megatron.core.transformer.MegatronModule]#

Wrap model with Distributed Data Parallel (DDP) or Fully Sharded Data Parallel (FSDP).

Parameters:
  • model – List of model modules to wrap

  • data_parallel_random_init – Whether to broadcast parameters from rank 0

  • ddp_config – Configuration for distributed data parallel

  • overlap_param_gather_with_optimizer_step – Whether to disable bucketing so that parameter gathering can overlap with the optimizer step

  • use_megatron_fsdp – Whether to use Megatron FSDP.

  • use_torch_fsdp2 – Whether to use PyTorch FSDP v2 instead of DDP

  • pg_collection – Model communication process groups.

Returns:

List of DDP/FSDP wrapped model modules

Return type:

list[MegatronModule]

bridge.models.common.unimodal.build_virtual_pipeline_stages(
build_model_func: Callable,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
vp_size: int | None,
model_type: megatron.core.enums.ModelType = ModelType.encoder_or_decoder,
) list[megatron.core.transformer.MegatronModule]#

Build virtual pipeline stages if using virtual pipeline parallelism.

Parameters:
  • build_model_func – Function from ModelBuilder that builds a single stage of the model.

  • pg_collection – Model communication process groups.

  • vp_size – Virtual pipeline parallel size. If None or PP size is 1, a single stage is built.

  • model_type – Deprecated flag, only used for backwards compatibility.

Returns:

List of model stages. Contains one entry per VP rank, or a single entry if VP is not enabled.
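The stage-count logic can be sketched as follows. The `vp_stage` keyword is an assumption for illustration, not the documented `build_model_func` signature:

```python
def build_stages(build_model_func, vp_size, pp_size):
    # VP only applies with pipeline parallelism and an explicit VP size.
    if vp_size is None or pp_size == 1:
        return [build_model_func(vp_stage=None)]
    # One model chunk per virtual-pipeline rank on this pipeline rank.
    return [build_model_func(vp_stage=i) for i in range(vp_size)]

# Dummy builder standing in for ModelBuilder.build_model:
single = build_stages(lambda vp_stage: {"vp": vp_stage}, vp_size=None, pp_size=4)
chunks = build_stages(lambda vp_stage: {"vp": vp_stage}, vp_size=3, pp_size=4)
```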

bridge.models.common.unimodal.to_empty_if_meta_device(
module: torch.nn.Module,
*,
device: torch.device,
recurse=True,
)#

Move tensors to device if not meta device; otherwise materialize with empty_like().

Officially, torch recommends to_empty() for materializing modules on the meta device. Under the hood, torch.empty_like() is applied to every parameter and buffer (see _apply), which can accidentally overwrite buffers whose values were precomputed during construction. Since the goal is to materialize only the tensors on the meta device, this function checks the device first and only moves a tensor to the destination if it is not on the meta device.

Parameters:
  • module – The target module to apply this transformation.

  • device – The desired device of the parameters and buffers in this module.

  • recurse – Whether parameters and buffers of submodules should be recursively moved to the specified device.
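The per-tensor rule can be sketched like this (a simplified single-tensor version; the real helper walks a module's parameters and buffers):

```python
import torch

def to_empty_if_meta(t: torch.Tensor, device: torch.device) -> torch.Tensor:
    if t.is_meta:
        # Meta tensors carry no storage: allocate uninitialized memory
        # of the same shape/dtype on the target device.
        return torch.empty_like(t, device=device)
    # Real tensors (e.g. buffers with precomputed values) keep their
    # contents; a plain .to() suffices.
    return t.to(device)

meta_t = torch.empty(4, device="meta")
buf = torch.arange(4.0)  # stands in for a precomputed buffer

materialized = to_empty_if_meta(meta_t, torch.device("cpu"))
moved = to_empty_if_meta(buf, torch.device("cpu"))
```

Unlike a blanket to_empty(), `buf` survives the move with its values intact, which is exactly the failure mode this helper is guarding against.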