bridge.models.model_provider#
Module Contents#
Classes#
ModelProviderMixin | A mixin that implements the ModelProvider pattern for Megatron-Hub.
GetModelKwargs | Keyword arguments for the provide_distributed_model method.
Functions#
get_model | Create and configure a model for distributed training.
_create_model | Create model instances with appropriate pipeline parallel configuration.
_ddp_wrap | Wrap model with Distributed Data Parallel (DDP) or Fully Sharded Data Parallel (FSDP).
_print_num_params | Print the number of parameters in the model on rank 0.
Data#
ModelT
API#
- bridge.models.model_provider.ModelT#
'TypeVar(...)'
- class bridge.models.model_provider.ModelProviderMixin#
Bases: abc.ABC, typing.Generic[bridge.models.model_provider.ModelT]
A mixin that implements the ModelProvider pattern for Megatron-Hub.
The ModelProvider pattern solves ecosystem fragmentation by providing a standardized way to instantiate models. This mixin provides a consistent provide_distributed_model() method that handles the complexity of distributed training setup, along with HuggingFace-inspired .from_hf_pretrained() and .save_hf_pretrained() for interoperability.
For advanced customization, multiple hooks can be registered via register_pre_wrap_hook and register_post_wrap_hook. These hooks allow modifying the model before and after it's wrapped for distributed training (e.g., freezing layers, logging). The composed hooks can be accessed via the pre_wrap_hook and post_wrap_hook properties.
Subclasses must implement the provide method to define the model architecture.
- CONFIG_NAME#
'mhub_model.json'
- DEFAULT_CONFIG_FORMAT#
'json'
- abstractmethod provide(
- pre_process=None,
- post_process=None,
- vp_stage=None,
Abstract method to provide the model instance.
Subclasses must implement this method to return the specific Megatron model (e.g., GPTModel) with its configuration. This method is called by get_model to obtain the base model before it is wrapped for distributed training.
- Parameters:
pre_process (bool, optional) – Whether this pipeline stage includes the pre-processing (embedding) layers.
post_process (bool, optional) – Whether this pipeline stage includes the post-processing (output) layers.
vp_stage (int, optional) – The virtual pipeline stage of the model.
- Returns:
The Megatron model instance.
- Return type:
ModelT
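A minimal sketch of a concrete provider, shown only for illustration: the class name, its fields, and the use of TransformerConfig, get_gpt_layer_local_spec, and GPTModel from megatron.core are assumptions, and real providers in the library may be structured differently (for example, by inheriting from a Megatron config class).

```python
# Hypothetical sketch of a provider subclass; names and constructor arguments are
# assumptions, not the library's actual providers.
from dataclasses import dataclass

from megatron.core.models.gpt import GPTModel
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from megatron.core.transformer import TransformerConfig

from megatron.bridge.models.model_provider import ModelProviderMixin  # assumed import path


@dataclass
class TinyGPTProvider(ModelProviderMixin[GPTModel]):
    num_layers: int = 2
    hidden_size: int = 128
    num_attention_heads: int = 4
    vocab_size: int = 1024
    seq_length: int = 256

    def provide(self, pre_process=None, post_process=None, vp_stage=None) -> GPTModel:
        # Build the base Megatron model for this pipeline stage; pre_process/post_process
        # control whether the embedding and output layers are included.
        config = TransformerConfig(
            num_layers=self.num_layers,
            hidden_size=self.hidden_size,
            num_attention_heads=self.num_attention_heads,
        )
        return GPTModel(
            config=config,
            transformer_layer_spec=get_gpt_layer_local_spec(),
            vocab_size=self.vocab_size,
            max_sequence_length=self.seq_length,
            pre_process=True if pre_process is None else pre_process,
            post_process=True if post_process is None else post_process,
        )
```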
- provide_distributed_model(
- ddp_config: megatron.core.distributed.DistributedDataParallelConfig | None = None,
- model_type=ModelType.encoder_or_decoder,
- overlap_param_gather_with_optimizer_step: bool = False,
- fp16: bool | None = None,
- bf16: bool | None = None,
- use_torch_fsdp2: bool = False,
- wrap_with_ddp: bool = True,
- data_parallel_random_init: bool = True,
- use_cpu_initialization: None | bool = False,
- init_model_with_meta_device: bool | None = None,
- pre_wrap_hook: Union[Callable[[list[megatron.core.transformer.module.MegatronModule]], list[megatron.core.transformer.module.MegatronModule]], list[Callable[[list[megatron.core.transformer.module.MegatronModule]], list[megatron.core.transformer.module.MegatronModule]]]] | None = None,
- post_wrap_hook: Callable[[list[megatron.core.transformer.module.MegatronModule]], list[megatron.core.transformer.module.MegatronModule]] | None = None,
Instantiate and wrap the model for distributed training.
This method retrieves the model from provide and sets up the distributed environment, including data-parallel and model-parallel configurations. It's the primary entry point for creating a model that's ready for use in the Megatron ecosystem.
- Parameters:
ddp_config – Configuration for distributed data parallel.
model_type – Type of model (encoder, decoder, or both).
overlap_param_gather_with_optimizer_step – Whether to overlap param gathering.
fp16 – Override FP16 setting.
bf16 – Override BF16 setting.
use_torch_fsdp2 – Use PyTorch FSDP2 instead of custom DDP.
wrap_with_ddp – Whether to wrap model with DDP.
data_parallel_random_init – Initialize parameters randomly across data parallel ranks.
use_cpu_initialization – Initialize model on CPU.
init_model_with_meta_device – Initialize model on meta device.
pre_wrap_hook – A single callable or list of callables to modify the model before it's wrapped. If provided, this will override all hooks registered via register_pre_wrap_hook. If a list is provided, hooks will be executed in order.
post_wrap_hook – A single callable to modify the model after it's wrapped. If provided, this will override all hooks registered via register_post_wrap_hook.
- Returns:
A list containing the wrapped model instance.
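A hedged usage sketch, assuming torch.distributed and Megatron's model-parallel state are already initialized (see initialize_model_parallel below) and reusing the hypothetical TinyGPTProvider from the earlier sketch:

```python
from megatron.core.distributed import DistributedDataParallelConfig

provider = TinyGPTProvider()  # hypothetical provider from the earlier sketch
model_chunks = provider.provide_distributed_model(
    ddp_config=DistributedDataParallelConfig(),
    wrap_with_ddp=True,
    bf16=True,
)
# A list of wrapped modules; more than one entry only under virtual pipeline parallelism.
```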
- initialize_model_parallel(
- seed: int | None = None,
- seed_kwargs: dict | None = None,
- **model_parallel_kwargs,
Initializes model parallelism and sets the random seed.
This is a convenience method that sets up tensor, pipeline, and other forms of model parallelism based on the attributes of the provider instance.
- Parameters:
seed – The random seed for model parallel RNG.
seed_kwargs – Additional arguments for model_parallel_cuda_manual_seed.
**model_parallel_kwargs – Additional arguments for parallel_state.initialize_model_parallel.
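For example (a sketch; it assumes a torch.distributed process group is available and that the provider carries the relevant parallelism attributes):

```python
import torch.distributed as dist

# Model-parallel setup must happen after the process group exists.
if not dist.is_initialized():
    dist.init_process_group(backend="nccl")

provider.initialize_model_parallel(seed=1234)  # also seeds the model-parallel RNG
```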
- property meta_model: list[bridge.models.model_provider.ModelT]#
Returns the model instantiated on the meta device for inspection.
This is useful for examining the model architecture without allocating GPU memory.
- property pre_wrap_hook: Callable[[list[megatron.core.transformer.module.MegatronModule]], list[megatron.core.transformer.module.MegatronModule]] | None#
A composed callable of all registered pre-wrap hooks.
This read-only property returns a single function that executes all registered pre-wrap hooks in order. The hook is applied before the model is passed to the DDP wrapper and can be used for tasks like freezing layers or altering model structure.
Use register_pre_wrap_hook to add a hook to the execution chain.
- Returns:
A callable that executes all registered pre-wrap hooks in order, or None if no hooks are registered.
- register_pre_wrap_hook(
- hook: Callable[[list[megatron.core.transformer.module.MegatronModule]], list[megatron.core.transformer.module.MegatronModule]],
- prepend: bool = False,
Registers a hook to be executed before the model is wrapped.
The hook should be a callable that accepts a list of MegatronModule instances and returns a (potentially modified) list of MegatronModule instances.
- Parameters:
hook – The hook to register.
prepend – If True, the hook is inserted at the beginning of the execution chain. Otherwise, it is appended to the end.
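For instance, a pre-wrap hook can freeze a subset of parameters before DDP wrapping (a sketch; the name-based filter is an assumption, not a library convention):

```python
def freeze_embeddings(modules):
    # Receives the list of MegatronModule instances and must return a list.
    for module in modules:
        for name, param in module.named_parameters():
            if "embedding" in name:  # illustrative filter
                param.requires_grad = False
    return modules

provider.register_pre_wrap_hook(freeze_embeddings)
```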
- property post_wrap_hook: Callable[[list[megatron.core.transformer.module.MegatronModule]], list[megatron.core.transformer.module.MegatronModule]] | None#
A composed callable of all registered post-wrap hooks.
This read-only property returns a single function that executes all registered post-wrap hooks in order. The hook is applied after the model has been wrapped by DDP and is useful for tasks like logging or attaching custom attributes.
Use register_post_wrap_hook to add a hook to the execution chain.
- Returns:
A callable that executes all registered post-wrap hooks in order, or None if no hooks are registered.
- register_post_wrap_hook(
- hook: Callable[[list[megatron.core.transformer.module.MegatronModule]], list[megatron.core.transformer.module.MegatronModule]],
- prepend: bool = False,
Registers a hook to be executed after the model is wrapped.
The hook should be a callable that accepts a list of MegatronModule instances and returns a (potentially modified) list of MegatronModule instances.
- Parameters:
hook – The hook to register.
prepend – If True, the hook is inserted at the beginning of the execution chain. Otherwise, it is appended to the end.
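A post-wrap hook runs on the already-wrapped modules, e.g. for logging (sketch):

```python
def log_wrapped_modules(modules):
    # Runs after DDP/FSDP wrapping; return the list unchanged (or modified).
    for i, module in enumerate(modules):
        print(f"model chunk {i}: wrapped as {type(module).__name__}")
    return modules

provider.register_post_wrap_hook(log_wrapped_modules)
```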
- classmethod from_hf_pretrained(
- pretrained_model_name_or_path: str | pathlib.Path,
- trust_remote_code: bool = False,
- mode: megatron.bridge.utils.instantiate_utils.InstantiationMode | None = None,
- config_name: str | None = None,
- **kwargs,
Load a pretrained model configuration from a directory or HuggingFace Hub.
This method provides a HuggingFace-inspired interface for loading model configurations, enabling interoperability.
- Parameters:
pretrained_model_name_or_path – The path to a local directory or a HuggingFace model identifier.
trust_remote_code – Whether to trust remote code when loading.
mode – The instantiation mode (e.g., LENIENT).
config_name – The name of the configuration file (without extension).
**kwargs – Additional keyword arguments for from_hf_pretrained.
- Returns:
An instance of the model provider with the loaded configuration.
- save_hf_pretrained(
- save_directory: str | pathlib.Path,
- config_format: str | None = None,
- config_name: str | None = None,
- **kwargs,
Save the model configuration to a directory.
This method provides a HuggingFace-inspired interface for saving model configurations, enabling interoperability.
- Parameters:
save_directory – The directory where the configuration will be saved.
config_format – The format for the configuration file (e.g., json or yaml).
config_name – The name of the configuration file (without extension).
**kwargs – Additional keyword arguments for save_hf_pretrained.
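A round-trip sketch (the directory path and the TinyGPTProvider class are illustrative):

```python
# Save the provider's configuration, then rebuild an equivalent provider from it.
provider.save_hf_pretrained("./tiny_gpt_provider", config_format="yaml")
restored = TinyGPTProvider.from_hf_pretrained("./tiny_gpt_provider")
```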
- class bridge.models.model_provider.GetModelKwargs#
Bases: typing.TypedDict
Keyword arguments for the provide_distributed_model method.
ddp_config – Configuration for distributed data parallel.
model_type – Type of model (encoder, decoder, or both).
overlap_param_gather_with_optimizer_step – Whether to overlap param gathering.
fp16 – Override FP16 setting.
bf16 – Override BF16 setting.
use_torch_fsdp2 – Use PyTorch FSDP2 instead of custom DDP.
wrap_with_ddp – Whether to wrap model with DDP.
data_parallel_random_init – Initialize parameters randomly across data parallel ranks.
use_cpu_initialization – Initialize model on CPU.
init_model_with_meta_device – Initialize model on meta device.
pre_wrap_hook – A single callable or list of callables that overrides all registered pre-wrap hooks.
post_wrap_hook – A single callable that overrides all registered post-wrap hooks.
- ddp_config: megatron.core.distributed.DistributedDataParallelConfig | None#
None
- model_type: megatron.core.enums.ModelType#
None
- overlap_param_gather_with_optimizer_step: bool#
None
- fp16: bool | None#
None
- bf16: bool | None#
None
- use_torch_fsdp2: bool#
None
- wrap_with_ddp: bool#
None
- data_parallel_random_init: bool#
None
- use_cpu_initialization: bool | None#
None
- init_model_with_meta_device: bool | None#
None
- pre_wrap_hook: Union[Callable[[list[megatron.core.transformer.module.MegatronModule]], list[megatron.core.transformer.module.MegatronModule]], list[Callable[[list[megatron.core.transformer.module.MegatronModule]], list[megatron.core.transformer.module.MegatronModule]]]] | None#
None
- post_wrap_hook: Callable[[list[megatron.core.transformer.module.MegatronModule]], list[megatron.core.transformer.module.MegatronModule]] | None#
None
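A sketch of using the TypedDict to type keyword arguments that are later forwarded to provide_distributed_model; the import path is an assumption, and the partial dictionary assumes the TypedDict accepts missing keys:

```python
from megatron.bridge.models.model_provider import GetModelKwargs  # assumed import path

kwargs: GetModelKwargs = {"wrap_with_ddp": True, "bf16": True}  # assumes non-total semantics
model_chunks = provider.provide_distributed_model(**kwargs)
```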
- bridge.models.model_provider.get_model(
- model_provider: bridge.models.model_provider.ModelProviderMixin,
- ddp_config: megatron.core.distributed.DistributedDataParallelConfig,
- model_type=ModelType.encoder_or_decoder,
- overlap_param_gather_with_optimizer_step: bool = False,
- fp16: bool | None = None,
- bf16: bool | None = None,
- use_torch_fsdp2: bool = False,
- wrap_with_ddp: bool = True,
- data_parallel_random_init: bool = True,
- use_cpu_initialization: None | bool = False,
- init_model_with_meta_device: bool | None = None,
- pre_wrap_hook: Union[Callable[[list[megatron.core.transformer.module.MegatronModule]], list[megatron.core.transformer.module.MegatronModule]], list[Callable[[list[megatron.core.transformer.module.MegatronModule]], list[megatron.core.transformer.module.MegatronModule]]]] | None = None,
Create and configure a model for distributed training.
This function handles the complete model creation pipeline, including:
- Model instantiation with proper pipeline parallel configuration
- GPU memory allocation
- Mixed precision (FP16/BF16) wrapping
- Float8 tensor correction
- Distributed Data Parallel (DDP) wrapping
- Parameters:
model_provider – ModelProviderMixin instance that creates the model. Its provide() method is called with optional pre_process (bool), post_process (bool), and vp_stage (int) arguments for pipeline parallelism.
ddp_config – Configuration for distributed data parallel training.
model_type – Type of model (encoder, decoder, or encoder_and_decoder).
overlap_param_gather_with_optimizer_step – Whether to overlap parameter gathering with the optimizer step for performance optimization.
fp16 – Enable FP16 mixed precision training. If None, uses the model config.
bf16 – Enable BF16 mixed precision training. If None, uses the model config.
use_torch_fsdp2 – Use PyTorch's Fully Sharded Data Parallel v2.
wrap_with_ddp – Whether to wrap the model with DDP.
data_parallel_random_init – Whether to use random initialization for data parallel ranks (vs. broadcasting from rank 0).
use_cpu_initialization – Whether to initialize the model on CPU to save GPU memory.
init_model_with_meta_device – Whether to initialize the model on the meta device.
pre_wrap_hook – A callable or list of callables that takes a list of MegatronModule and returns a modified list, or None to clear the hook. If a list is provided, hooks will be executed in order.
- Returns:
List of model modules. Contains multiple modules when using virtual pipeline parallelism, otherwise a single module
- Return type:
list[MegatronModule]
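A functional-style sketch, equivalent in spirit to calling provide_distributed_model on the provider itself (the bridge import path is an assumption):

```python
from megatron.core.distributed import DistributedDataParallelConfig

from megatron.bridge.models.model_provider import get_model  # assumed import path

model_chunks = get_model(
    model_provider=provider,          # any ModelProviderMixin instance
    ddp_config=DistributedDataParallelConfig(),
    bf16=True,
)
```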
- bridge.models.model_provider._create_model(
- model_provider: bridge.models.model_provider.ModelProviderMixin,
- model_type: megatron.core.enums.ModelType,
Create model instances with appropriate pipeline parallel configuration.
Handles virtual pipeline parallelism (VPP) by creating multiple model instances when needed. Sets pre_process and post_process flags based on pipeline parallel rank.
- Parameters:
model_provider – ModelProviderMixin instance that creates the model
model_type – ModelType enum indicating encoder, decoder, or both
- Returns:
List of model instances. Multiple instances for VPP, otherwise single
- Return type:
list
- bridge.models.model_provider._ddp_wrap(
- model: list[megatron.core.transformer.module.MegatronModule],
- use_torch_fsdp2: bool,
- data_parallel_random_init: bool,
- ddp_config: megatron.core.distributed.DistributedDataParallelConfig,
- overlap_param_gather_with_optimizer_step: bool,
Wrap model with Distributed Data Parallel (DDP) or Fully Sharded Data Parallel (FSDP).
- Parameters:
model – List of model modules to wrap
use_torch_fsdp2 – Whether to use PyTorch FSDP v2 instead of DDP
data_parallel_random_init – Whether to broadcast parameters from rank 0
ddp_config – Configuration for distributed data parallel
overlap_param_gather_with_optimizer_step – Whether to disable bucketing for overlapping parameter gathering with optimizer step
- Returns:
List of DDP/FSDP wrapped model modules
- Return type:
list[MegatronModule]
- bridge.models.model_provider._print_num_params(
- model: list[megatron.core.transformer.module.MegatronModule],
Print the number of parameters in the model on rank 0.
Only prints on data parallel rank 0 to avoid duplicate output. Shows parameter count per (tensor parallel, pipeline parallel) rank.
- Parameters:
model – List of model modules to count parameters from