bridge.models.common.base#

Module Contents#

Classes#

Serializable

Protocol for serializable configurations.

ModelConfig

Base class for model configurations.

ModelBuilder

Abstract base class for model builders.

Functions#

compose_hooks

Utility to compose pre/post-wrap hooks into a single function, preserving order.

Data#

API#

class bridge.models.common.base.Serializable#

Bases: typing.Protocol

Protocol for serializable configurations.

as_dict() dict[str, Any]#

Serialize to a dictionary, including a _target_ key for class identification.

classmethod from_dict(
data: dict[str, Any],
) bridge.models.common.base.Serializable#

Deserialize from a dictionary, using its _target_ key to identify the class.
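As a sketch of the protocol's contract, a minimal class satisfying it might look like the following (the class below is a hypothetical illustration, not part of the library; it only mirrors the _target_ convention described here):

```python
from dataclasses import asdict, dataclass
from typing import Any


@dataclass
class PointConfig:
    """Hypothetical config used only to illustrate the Serializable protocol."""

    x: int = 0
    y: int = 0

    def as_dict(self) -> dict[str, Any]:
        # Record the import path under _target_ so from_dict can find the class.
        target = f"{type(self).__module__}.{type(self).__qualname__}"
        return {"_target_": target, **asdict(self)}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "PointConfig":
        # Drop bookkeeping keys before constructing the instance.
        fields = {k: v for k, v in data.items() if not k.startswith("_")}
        return cls(**fields)
```

A round trip `PointConfig.from_dict(cfg.as_dict()) == cfg` then holds for any such config.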

class bridge.models.common.base.ModelConfig#

Base class for model configurations.

Each model type (GPT, T5, Mamba, etc.) defines a concrete subclass with its own model-specific parameters. This class is a pure data container - all model construction logic lives in the corresponding ModelBuilder subclass.

Subclasses must define:

  • builder: a ClassVar[str] with the full import path to the associated ModelBuilder (e.g. 'megatron.bridge.models.mamba.MambaModelBuilder').

Subclasses may also embed nested configs (e.g. TransformerConfig) and proxy attribute access to them via __getattr__/__setattr__ overrides.

Serialization: Use as_dict() to serialize to a plain dict (includes a _target_ key for class resolution and a _builder_ key for builder resolution). Use from_dict() to reconstruct an instance from such a dict.

Builder resolution: Call get_builder_cls() to dynamically import and return the builder class identified by the builder ClassVar.
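Putting the pieces together, a concrete subclass might look like the following sketch (a simplified stand-alone dataclass, not a real ModelConfig subclass; the parameter names are hypothetical):

```python
from dataclasses import dataclass
from typing import ClassVar


@dataclass
class MambaModelConfig:
    """Hypothetical illustration of the ModelConfig subclass pattern."""

    # Required: full import path to the associated ModelBuilder.
    builder: ClassVar[str] = "megatron.bridge.models.mamba.MambaModelBuilder"

    # Model-specific parameters (names chosen for illustration only).
    num_layers: int = 48
    hidden_size: int = 4096
```

Because builder is a ClassVar, it is shared by the class and is not treated as a dataclass field.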

builder: ClassVar[str]#

None

Class variable with full path to builder class (e.g., ‘megatron.bridge.builders.GPTModelBuilder’).

restore_modelopt_state: bool#

False

Restore ModelOpt quantization/sparsity state.

hf_model_id: str | None#

None

HuggingFace model identifier.

generation_config: Any | None#

None

Generation configuration.

pre_wrap_hooks: list[Callable[[list[megatron.core.transformer.MegatronModule]], list[megatron.core.transformer.MegatronModule]]]#

‘field(…)’

List of functions executed before the model is wrapped with DDP/FSDP. Each hook takes the list of model stages as its only argument and returns a (possibly new) list of stages.

post_wrap_hooks: list[Callable[[list[megatron.core.transformer.MegatronModule]], list[megatron.core.transformer.MegatronModule]]]#

‘field(…)’

List of functions executed after model initialization is complete. Each hook takes the list of model stages as its only argument and returns a (possibly new) list of stages.
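A hook therefore has the signature list-of-stages in, list-of-stages out. A minimal hypothetical example (real hooks receive MegatronModule instances; any objects stand in here):

```python
def log_stage_count(stages: list) -> list:
    """Hypothetical pre-wrap hook: report the stage count, return stages unchanged."""
    print(f"wrapping {len(stages)} pipeline stage(s)")
    return stages
```

Even a hook that only inspects the stages must return the list, since the returned value feeds the next hook in the chain.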

get_builder_cls() type#

Get the appropriate builder type for this config. Dynamically imports the builder from the string path.
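Resolving a class from its dotted string path is typically done with importlib; a sketch of the idea (not necessarily the library's exact implementation):

```python
import importlib


def resolve_class(path: str) -> type:
    """Resolve a dotted 'package.module.ClassName' path to the class object."""
    module_path, _, class_name = path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
```

Under this assumption, get_builder_cls() amounts to resolving the path stored in the builder ClassVar.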

as_dict() dict[str, Any]#

Serialize config to dictionary for saving.

Includes:

  • _target_: Full class path for deserialization

  • _builder_: Full builder class path (serialized from the builder ClassVar)

  • All dataclass fields, including nested dataclasses

classmethod from_dict(
data: dict[str, Any],
) bridge.models.common.base.ModelConfig#

Deserialize config from dictionary.

Uses the _target_ key to determine the correct class to instantiate. The builder is restored from the _builder_ key or from the class's builder ClassVar.

Parameters:

data – Dictionary with a _target_ key and config fields

Returns:

Instance of the appropriate ModelConfig subclass

bridge.models.common.base.ModelT#

‘TypeVar(…)’

bridge.models.common.base.BuildConfigT#

‘TypeVar(…)’

class bridge.models.common.base.ModelBuilder(model_config: bridge.models.common.base.ModelConfig)#

Bases: abc.ABC, typing.Generic[bridge.models.common.base.ModelT, bridge.models.common.base.BuildConfigT]

Abstract base class for model builders.

A builder takes a ModelConfig and produces distributed model instances - either a single pipeline stage via build_model(), or a list of stages wrapped for distributed training via build_distributed_models().

Each builder subclass should:

  1. Implement build_model() for the specific model type

  2. Implement build_distributed_models() to handle virtual pipeline parallelism, DDP/FSDP wrapping, and pre/post-wrap hook execution

  3. Be linked to its corresponding ModelConfig via the builder ClassVar

Builders are factory objects: any state stored at init time should be treated as immutable and used only to build the model.

Type Parameters:

  • ModelT – the type of model this builder produces (e.g., MCoreGPTModel)

  • BuildConfigT – the type of build config this builder accepts (e.g., GPTModelBuildConfig)
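The generic structure can be sketched in isolation (a hypothetical, simplified mirror of ModelBuilder, not the real class):

```python
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

ModelT = TypeVar("ModelT")
BuildConfigT = TypeVar("BuildConfigT")


class Builder(ABC, Generic[ModelT, BuildConfigT]):
    """Hypothetical simplified mirror of ModelBuilder."""

    def __init__(self, config: BuildConfigT) -> None:
        # Builders are factories: treat this state as read-only after init.
        self._config = config

    @abstractmethod
    def build_model(self) -> ModelT:
        """Construct and return one model stage."""
```

A concrete subclass then pins both type parameters and implements build_model().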

Initialization

abstractmethod build_model(
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
pre_process: bool | None = None,
post_process: bool | None = None,
vp_stage: int | None = None,
) bridge.models.common.base.ModelT#

Build a model from the provided configurations.

Parameters:
  • pg_collection – Process groups for distributed training

  • pre_process – Include embedding layer

  • post_process – Include output layer

  • vp_stage – Virtual pipeline stage

Returns:

The constructed model

abstractmethod build_distributed_models(
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
ddp_config: megatron.core.distributed.DistributedDataParallelConfig | None = None,
overlap_param_gather_with_optimizer_step: bool = False,
use_megatron_fsdp: bool = False,
use_torch_fsdp2: bool = False,
wrap_with_ddp: bool = True,
data_parallel_random_init: bool = True,
mixed_precision_wrapper: Callable[[Any, megatron.core.transformer.MegatronModule], megatron.core.transformer.MegatronModule] | None = Float16Module,
model_type: megatron.core.enums.ModelType = ModelType.encoder_or_decoder,
) list[bridge.models.common.base.ModelT]#

Build model stages and wrap for distributed training.

Parameters:
  • pg_collection – Model communication process groups.

  • ddp_config – DistributedDataParallel configuration

  • overlap_param_gather_with_optimizer_step – Whether to overlap parameter gather with optimizer step

  • use_megatron_fsdp – Whether to use Megatron FSDP

  • use_torch_fsdp2 – Whether to use Torch FSDP 2.0

  • wrap_with_ddp – Set to False to skip DDP wrapper

  • data_parallel_random_init – Whether to use data parallel random initialization

  • mixed_precision_wrapper – Mixed precision wrapper, e.g. Float16Module

  • model_type – Deprecated; retained only for backward compatibility.

Returns:

List of model stages. If the model does not support virtual pipeline parallelism, this function should still return a single-item list.

bridge.models.common.base.compose_hooks(
hooks: list[Callable[[list[megatron.core.transformer.MegatronModule]], list[megatron.core.transformer.MegatronModule]]],
) Callable[[list[megatron.core.transformer.MegatronModule]], list[megatron.core.transformer.MegatronModule]]#

Utility to compose pre/post-wrap hooks into a single function, preserving order.

If hooks is empty, the returned function is an identity operation.

Parameters:

hooks – the list of hooks.

Returns:

A single function that executes all functions in hooks.
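Semantically, the composition behaves like the following sketch (assumed behavior based on the description above, not the exact implementation):

```python
from functools import reduce
from typing import Callable

# Shape of a pre/post-wrap hook: list of stages in, list of stages out.
Hook = Callable[[list], list]


def compose(hooks: list[Hook]) -> Hook:
    """Chain hooks left to right; with no hooks, the result is the identity."""
    return lambda stages: reduce(lambda acc, hook: hook(acc), hooks, stages)
```

With an empty hooks list, reduce simply returns its initializer, so the composed function passes the stages through untouched.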