bridge.models.model_provider#

Module Contents#

Classes#

ModelProviderMixin

A mixin that implements the ModelProvider pattern for Megatron-Hub.

GetModelKwargs

Keyword arguments for the provide_distributed_model method.

Functions#

get_model

Create and configure a model for distributed training.

_create_model

Create model instances with appropriate pipeline parallel configuration.

_ddp_wrap

Wrap model with Distributed Data Parallel (DDP) or Fully Sharded Data Parallel (FSDP).

_print_num_params

Print the number of parameters in the model on rank 0.

Data#

API#

bridge.models.model_provider.ModelT#

'TypeVar(…)'

class bridge.models.model_provider.ModelProviderMixin#

Bases: abc.ABC, typing.Generic[bridge.models.model_provider.ModelT]

A mixin that implements the ModelProvider pattern for Megatron-Hub.

The ModelProvider pattern solves ecosystem fragmentation by providing a standardized way to instantiate models. This mixin provides a consistent provide_distributed_model() method that handles the complexity of distributed training setup, along with HuggingFace-inspired .from_hf_pretrained() and .save_hf_pretrained() for interoperability.

For advanced customization, multiple hooks can be registered via register_pre_wrap_hook and register_post_wrap_hook. These hooks allow modifying the model before and after it’s wrapped for distributed training (e.g., freezing layers, logging). The composed hooks can be accessed via the pre_wrap_hook and post_wrap_hook properties.

Subclasses must implement the provide method to define the model architecture.
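
A minimal sketch of such a subclass is shown below. The import path is assumed to be megatron.bridge.models.model_provider, and MyModelProvider and build_my_megatron_model are hypothetical names; a real provider would construct a Megatron model such as GPTModel from its configuration attributes.

```python
from dataclasses import dataclass

from megatron.bridge.models.model_provider import ModelProviderMixin


@dataclass
class MyModelProvider(ModelProviderMixin):
    """Hypothetical provider that stores its model configuration as attributes."""

    hidden_size: int = 1024
    num_layers: int = 12

    def provide(self, pre_process=None, post_process=None, vp_stage=None):
        # Build and return the Megatron model for this pipeline stage.
        # `build_my_megatron_model` is a hypothetical helper standing in for the
        # code that constructs a GPTModel (or similar) from these attributes.
        return build_my_megatron_model(
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            pre_process=pre_process,
            post_process=post_process,
            vp_stage=vp_stage,
        )
```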

CONFIG_NAME#

'mhub_model.json'

DEFAULT_CONFIG_FORMAT#

'json'

abstractmethod provide(
pre_process=None,
post_process=None,
vp_stage=None,
) bridge.models.model_provider.ModelT#

Abstract method to provide the model instance.

Subclasses must implement this method to return the specific Megatron model (e.g., GPTModel) with its configuration. This method is called by get_model to obtain the base model before it is wrapped for distributed training.

Parameters:
  • pre_process (bool, optional) – Whether this pipeline stage includes the preprocessing (embedding) layers.

  • post_process (bool, optional) – Whether this pipeline stage includes the postprocessing (output) layers.

  • vp_stage (int, optional) – The virtual pipeline stage of the model.

Returns:

The Megatron model instance.

Return type:

ModelT

provide_distributed_model(
ddp_config: megatron.core.distributed.DistributedDataParallelConfig | None = None,
model_type=ModelType.encoder_or_decoder,
overlap_param_gather_with_optimizer_step: bool = False,
fp16: bool | None = None,
bf16: bool | None = None,
use_torch_fsdp2: bool = False,
wrap_with_ddp: bool = True,
data_parallel_random_init: bool = True,
use_cpu_initialization: None | bool = False,
init_model_with_meta_device: bool | None = None,
pre_wrap_hook: Union[Callable[[list[megatron.core.transformer.module.MegatronModule]], list[megatron.core.transformer.module.MegatronModule]], list[Callable[[list[megatron.core.transformer.module.MegatronModule]], list[megatron.core.transformer.module.MegatronModule]]]] | None = None,
post_wrap_hook: Callable[[list[megatron.core.transformer.module.MegatronModule]], list[megatron.core.transformer.module.MegatronModule]] | None = None,
) list[bridge.models.model_provider.ModelT]#

Instantiate and wrap the model for distributed training.

This method retrieves the model from provide and sets up the distributed environment, including data-parallel and model-parallel configurations. It’s the primary entry point for creating a model that’s ready for use in the Megatron ecosystem.

Parameters:
  • ddp_config – Configuration for distributed data parallel.

  • model_type – Type of model (encoder, decoder, or both).

  • overlap_param_gather_with_optimizer_step – Whether to overlap param gathering.

  • fp16 – Override FP16 setting.

  • bf16 – Override BF16 setting.

  • use_torch_fsdp2 – Use PyTorch FSDP2 instead of custom DDP.

  • wrap_with_ddp – Whether to wrap model with DDP.

  • data_parallel_random_init – Initialize parameters randomly across data parallel ranks.

  • use_cpu_initialization – Initialize model on CPU.

  • init_model_with_meta_device – Initialize model on meta device.

  • pre_wrap_hook – A single callable or list of callables to modify the model before it’s wrapped. If provided, this will override all hooks registered via register_pre_wrap_hook. If a list is provided, hooks will be executed in order.

  • post_wrap_hook – A single callable to modify the model after it’s wrapped. If provided, this will override all hooks registered via register_post_wrap_hook.

Returns:

A list containing the wrapped model instance.
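
A usage sketch, assuming torch.distributed and model parallelism are already initialized and provider is an instance of a ModelProviderMixin subclass (such as the hypothetical MyModelProvider above):

```python
from megatron.core.distributed import DistributedDataParallelConfig

# Instantiate the model and wrap each chunk for distributed data parallelism.
model = provider.provide_distributed_model(
    ddp_config=DistributedDataParallelConfig(),
    wrap_with_ddp=True,
    bf16=True,
)
# `model` is a list of wrapped modules; it has more than one entry only when
# virtual pipeline parallelism is enabled.
```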

initialize_model_parallel(
seed: int | None = None,
seed_kwargs: dict | None = None,
**model_parallel_kwargs,
) None#

Initializes model parallelism and sets the random seed.

This is a convenience method that sets up tensor, pipeline, and other forms of model parallelism based on the attributes of the provider instance.

Parameters:
  • seed – The random seed for model parallel RNG.

  • seed_kwargs – Additional arguments for model_parallel_cuda_manual_seed.

  • **model_parallel_kwargs – Additional arguments for parallel_state.initialize_model_parallel.
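
For example (a sketch that assumes the process group is launched via torchrun or similar):

```python
import torch.distributed as dist

if not dist.is_initialized():
    dist.init_process_group(backend="nccl")

# Tensor/pipeline parallel sizes are taken from the provider's own attributes;
# extra keyword arguments would be forwarded to
# parallel_state.initialize_model_parallel.
provider.initialize_model_parallel(seed=1234)
```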

property meta_model: list[bridge.models.model_provider.ModelT]#

Returns the model instantiated on the meta device for inspection.

This is useful for examining the model architecture without allocating GPU memory.
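
For instance (a sketch; provider is any ModelProviderMixin instance):

```python
# Parameters live on the meta device, so no GPU memory is allocated while
# inspecting the architecture or counting parameters.
for chunk in provider.meta_model:
    num_params = sum(p.numel() for p in chunk.parameters())
    print(type(chunk).__name__, f"{num_params:,} parameters")
```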

property pre_wrap_hook: Callable[[list[megatron.core.transformer.module.MegatronModule]], list[megatron.core.transformer.module.MegatronModule]] | None#

A composed callable of all registered pre-wrap hooks.

This read-only property returns a single function that executes all registered pre-wrap hooks in order. The hook is applied before the model is passed to the DDP wrapper and can be used for tasks like freezing layers or altering model structure.

Use register_pre_wrap_hook to add a hook to the execution chain.

Returns:

A callable that executes all registered pre-wrap hooks in order, or None if no hooks are registered.

register_pre_wrap_hook(
hook: Callable[[list[megatron.core.transformer.module.MegatronModule]], list[megatron.core.transformer.module.MegatronModule]],
prepend: bool = False,
) None#

Registers a hook to be executed before the model is wrapped.

The hook should be a callable that accepts a list of MegatronModule instances and returns a (potentially modified) list of MegatronModule instances.

Parameters:
  • hook – The hook to register.

  • prepend – If True, the hook is inserted at the beginning of the execution chain. Otherwise, it is appended to the end.
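
For example, a pre-wrap hook that freezes every parameter before wrapping might look like the following sketch:

```python
from megatron.core.transformer.module import MegatronModule


def freeze_all(modules: list[MegatronModule]) -> list[MegatronModule]:
    # Runs on the list of model chunks before they are wrapped for DDP/FSDP.
    for module in modules:
        for param in module.parameters():
            param.requires_grad = False
    return modules


provider.register_pre_wrap_hook(freeze_all)
# provider.pre_wrap_hook now returns a single callable that executes all
# registered pre-wrap hooks in order.
```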

property post_wrap_hook: Callable[[list[megatron.core.transformer.module.MegatronModule]], list[megatron.core.transformer.module.MegatronModule]] | None#

A composed callable of all registered post-wrap hooks.

This read-only property returns a single function that executes all registered post-wrap hooks in order. The hook is applied after the model has been wrapped by DDP and is useful for tasks like logging or attaching custom attributes.

Use register_post_wrap_hook to add a hook to the execution chain.

Returns:

A callable that executes all registered post-wrap hooks in order, or None if no hooks are registered.

register_post_wrap_hook(
hook: Callable[[list[megatron.core.transformer.module.MegatronModule]], list[megatron.core.transformer.module.MegatronModule]],
prepend: bool = False,
) None#

Registers a hook to be executed after the model is wrapped.

The hook should be a callable that accepts a list of MegatronModule instances and returns a (potentially modified) list of MegatronModule instances.

Parameters:
  • hook – The hook to register.

  • prepend – If True, the hook is inserted at the beginning of the execution chain. Otherwise, it is appended to the end.
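
For example, a post-wrap hook used purely for logging (a sketch):

```python
def log_wrapped_types(modules):
    # Runs after DDP/FSDP wrapping; must return the (possibly modified) list.
    for i, module in enumerate(modules):
        print(f"model chunk {i} wrapped as {type(module).__name__}")
    return modules


provider.register_post_wrap_hook(log_wrapped_types)
```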

classmethod from_hf_pretrained(
pretrained_model_name_or_path: str | pathlib.Path,
trust_remote_code: bool = False,
mode: megatron.bridge.utils.instantiate_utils.InstantiationMode | None = None,
config_name: str | None = None,
**kwargs,
)#

Load a pretrained model configuration from a directory or HuggingFace Hub.

This method provides a HuggingFace-inspired interface for loading model configurations, enabling interoperability.

Parameters:
  • pretrained_model_name_or_path – The path to a local directory or a HuggingFace model identifier.

  • trust_remote_code – Whether to trust remote code when loading.

  • mode – The instantiation mode (e.g., LENIENT).

  • config_name – The name of the configuration file (without extension).

  • **kwargs – Additional keyword arguments forwarded to the underlying configuration loading.

Returns:

An instance of the model provider with the loaded configuration.
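
A usage sketch (the checkpoint path and provider class are hypothetical):

```python
# Load a provider configuration previously saved with save_hf_pretrained.
provider = MyModelProvider.from_hf_pretrained("/checkpoints/my-model")
print(provider)
```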

save_hf_pretrained(
save_directory: str | pathlib.Path,
config_format: str | None = None,
config_name: str | None = None,
**kwargs,
)#

Save the model configuration to a directory.

This method provides a HuggingFace-inspired interface for saving model configurations, enabling interoperability.

Parameters:
  • save_directory – The directory where the configuration will be saved.

  • config_format – The format for the configuration file (e.g., json or yaml).

  • config_name – The name of the configuration file (without extension).

  • **kwargs – Additional keyword arguments forwarded to the underlying configuration saving.
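
The corresponding save is a one-liner (a sketch; the directory is hypothetical):

```python
# Persists the provider's configuration, not the model weights.
provider.save_hf_pretrained("/checkpoints/my-model", config_format="yaml")
```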

class bridge.models.model_provider.GetModelKwargs#

Bases: typing.TypedDict

Keyword arguments for the provide_distributed_model method.

Attributes:
  • ddp_config – Configuration for distributed data parallel.

  • model_type – Type of model (encoder, decoder, or both).

  • overlap_param_gather_with_optimizer_step – Whether to overlap param gathering.

  • fp16 – Override FP16 setting.

  • bf16 – Override BF16 setting.

  • use_torch_fsdp2 – Use PyTorch FSDP2 instead of custom DDP.

  • wrap_with_ddp – Whether to wrap model with DDP.

  • data_parallel_random_init – Initialize parameters randomly across data parallel ranks.

  • use_cpu_initialization – Initialize model on CPU.

  • init_model_with_meta_device – Initialize model on meta device.

  • pre_wrap_hook – A single callable or list of callables that overrides all registered pre-wrap hooks.

  • post_wrap_hook – A single callable that overrides all registered post-wrap hooks.

ddp_config: megatron.core.distributed.DistributedDataParallelConfig | None#

None

model_type: megatron.core.enums.ModelType#

None

overlap_param_gather_with_optimizer_step: bool#

None

fp16: bool | None#

None

bf16: bool | None#

None

use_torch_fsdp2: bool#

None

wrap_with_ddp: bool#

None

data_parallel_random_init: bool#

None

use_cpu_initialization: bool | None#

None

init_model_with_meta_device: bool | None#

None

pre_wrap_hook: Union[Callable[[list[megatron.core.transformer.module.MegatronModule]], list[megatron.core.transformer.module.MegatronModule]], list[Callable[[list[megatron.core.transformer.module.MegatronModule]], list[megatron.core.transformer.module.MegatronModule]]]] | None#

None

post_wrap_hook: Callable[[list[megatron.core.transformer.module.MegatronModule]], list[megatron.core.transformer.module.MegatronModule]] | None#

None
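
As a sketch, GetModelKwargs can type-check a reusable bundle of keyword arguments, assuming the import path megatron.bridge.models.model_provider and that the TypedDict permits partial dictionaries:

```python
from megatron.bridge.models.model_provider import GetModelKwargs

kwargs: GetModelKwargs = {
    "wrap_with_ddp": True,
    "bf16": True,
    "use_torch_fsdp2": False,
}
model = provider.provide_distributed_model(**kwargs)
```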

bridge.models.model_provider.get_model(
model_provider: bridge.models.model_provider.ModelProviderMixin,
ddp_config: megatron.core.distributed.DistributedDataParallelConfig,
model_type=ModelType.encoder_or_decoder,
overlap_param_gather_with_optimizer_step: bool = False,
fp16: bool | None = None,
bf16: bool | None = None,
use_torch_fsdp2: bool = False,
wrap_with_ddp: bool = True,
data_parallel_random_init: bool = True,
use_cpu_initialization: None | bool = False,
init_model_with_meta_device: bool | None = None,
pre_wrap_hook: Union[Callable[[list[megatron.core.transformer.module.MegatronModule]], list[megatron.core.transformer.module.MegatronModule]], list[Callable[[list[megatron.core.transformer.module.MegatronModule]], list[megatron.core.transformer.module.MegatronModule]]]] | None = None,
) list[megatron.core.transformer.module.MegatronModule]#

Create and configure a model for distributed training.

This function handles the complete model creation pipeline including:

  • Model instantiation with proper pipeline parallel configuration

  • GPU memory allocation

  • Mixed precision (FP16/BF16) wrapping

  • Float8 tensor correction

  • Distributed Data Parallel (DDP) wrapping

Parameters:
  • model_provider – ModelProviderMixin instance that creates the model. Its provide() method is called with pre_process (bool), post_process (bool), and vp_stage (int) arguments for pipeline parallelism.

  • ddp_config – Configuration for distributed data parallel training

  • model_type – Type of model (encoder, decoder, or encoder_and_decoder)

  • overlap_param_gather_with_optimizer_step – Whether to overlap parameter gathering with optimizer step for performance optimization

  • fp16 – Enable FP16 mixed precision training. If None, uses model config

  • bf16 – Enable BF16 mixed precision training. If None, uses model config

  • use_torch_fsdp2 – Use PyTorch’s Fully Sharded Data Parallel v2

  • wrap_with_ddp – Whether to wrap the model with DDP

  • data_parallel_random_init – Whether to use random initialization for data parallel ranks (vs broadcasting from rank 0)

  • use_cpu_initialization – Whether to initialize model on CPU to save GPU memory

  • init_model_with_meta_device – Whether to initialize the model on the meta device

  • pre_wrap_hook – A callable or list of callables that takes a list of MegatronModule and returns a modified list, or None to clear the hook. If a list is provided, hooks will be executed in order.

Returns:

List of model modules. Contains multiple modules when using virtual pipeline parallelism, otherwise a single module

Return type:

list[MegatronModule]
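
A usage sketch, assuming get_model is importable from megatron.bridge.models.model_provider and that model parallelism has already been initialized:

```python
from megatron.bridge.models.model_provider import get_model
from megatron.core.distributed import DistributedDataParallelConfig

model = get_model(
    model_provider=provider,  # a ModelProviderMixin instance
    ddp_config=DistributedDataParallelConfig(),
    bf16=True,
    wrap_with_ddp=True,
)
```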

bridge.models.model_provider._create_model(
model_provider: bridge.models.model_provider.ModelProviderMixin,
model_type: megatron.core.enums.ModelType,
) list[megatron.core.transformer.module.MegatronModule]#

Create model instances with appropriate pipeline parallel configuration.

Handles virtual pipeline parallelism (VPP) by creating multiple model instances when needed. Sets pre_process and post_process flags based on pipeline parallel rank.

Parameters:
  • model_provider – ModelProviderMixin instance that creates the model

  • model_type – ModelType enum indicating encoder, decoder, or both

Returns:

List of model instances: multiple instances when using virtual pipeline parallelism, otherwise a single instance.

Return type:

list

bridge.models.model_provider._ddp_wrap(
model: list[megatron.core.transformer.module.MegatronModule],
use_torch_fsdp2: bool,
data_parallel_random_init: bool,
ddp_config: megatron.core.distributed.DistributedDataParallelConfig,
overlap_param_gather_with_optimizer_step: bool,
) list[megatron.core.transformer.module.MegatronModule]#

Wrap model with Distributed Data Parallel (DDP) or Fully Sharded Data Parallel (FSDP).

Parameters:
  • model – List of model modules to wrap

  • use_torch_fsdp2 – Whether to use PyTorch FSDP v2 instead of DDP

  • data_parallel_random_init – Whether to broadcast parameters from rank 0

  • ddp_config – Configuration for distributed data parallel

  • overlap_param_gather_with_optimizer_step – Whether to disable bucketing so that parameter gathering can be overlapped with the optimizer step

Returns:

List of DDP/FSDP wrapped model modules

Return type:

list[MegatronModule]

bridge.models.model_provider._print_num_params(
model: list[megatron.core.transformer.module.MegatronModule],
) None#

Print the number of parameters in the model on rank 0.

Only prints on data parallel rank 0 to avoid duplicate output. Shows parameter count per (tensor parallel, pipeline parallel) rank.

Parameters:

model – List of model modules to count parameters from