bridge.training.mlm_compat.model#

Module Contents#

Functions#

_get_transformer_layer_spec

Get transformer layer specification based on configuration.

_gpt_provider

Provide the GPTModel exactly as Megatron-LM (MLM) does, using an argparse args object.

_mamba_provider

Provide the MambaModel exactly as Megatron-LM (MLM) does, using an argparse args object.

API#

bridge.training.mlm_compat.model._get_transformer_layer_spec(
args: argparse.Namespace,
use_te: bool,
use_kitchen: bool,
) → megatron.core.transformer.ModuleSpec#

Get transformer layer specification based on configuration.

Parameters:
  • args – Training arguments

  • use_te – Whether to use Transformer Engine

  • use_kitchen – Whether to use kitchen extension

Returns:

The transformer layer specification

Return type:

megatron.core.transformer.ModuleSpec
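
A minimal sketch of calling the function, assuming args comes from Megatron-LM's argument parser; the specific Namespace fields shown are illustrative assumptions, not the function's documented inputs:

```python
# Hedged sketch: the args fields below (num_experts, normalization) are
# assumptions for illustration only; in a real run, args is produced by
# Megatron-LM's argument parser.
import argparse

from bridge.training.mlm_compat.model import _get_transformer_layer_spec

args = argparse.Namespace(
    num_experts=None,           # assumed field: dense (non-MoE) layers
    normalization="LayerNorm",  # assumed field
)

spec = _get_transformer_layer_spec(args, use_te=True, use_kitchen=False)
# spec is a megatron.core.transformer.ModuleSpec describing one transformer layer
```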

bridge.training.mlm_compat.model._gpt_provider(
args: argparse.Namespace,
config: Optional[megatron.core.transformer.TransformerConfig] = None,
pre_process: bool = True,
post_process: bool = True,
vp_stage: Optional[int] = None,
) → megatron.core.models.gpt.GPTModel#

Provide the GPTModel exactly as Megatron-LM (MLM) does, using an argparse args object.

The args and config arguments may need to be bound with functools.partial.
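
A minimal sketch of that functools.partial pattern; the exact provider interface expected downstream (per-stage keyword arguments) is an assumption here:

```python
# Hedged sketch: bind args and config so _gpt_provider can be called later
# with only the per-pipeline-stage arguments. args/config are assumed to
# exist in the surrounding training script.
from functools import partial

from bridge.training.mlm_compat.model import _gpt_provider

gpt_provider = partial(_gpt_provider, args, config=None)

# Called per pipeline stage; returns a megatron.core.models.gpt.GPTModel.
model = gpt_provider(pre_process=True, post_process=True, vp_stage=None)
```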

bridge.training.mlm_compat.model._mamba_provider(
args: argparse.Namespace,
config: Optional[megatron.core.transformer.TransformerConfig] = None,
pre_process: bool = True,
post_process: bool = True,
) → megatron.core.models.mamba.MambaModel#

Provide the MambaModel exactly as Megatron-LM (MLM) does, using an argparse args object.

The args and config arguments may need to be bound with functools.partial.
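
The same pattern applies here; note that, unlike _gpt_provider, this signature has no vp_stage parameter. A minimal sketch, with args and config assumed to exist in the surrounding script:

```python
# Hedged sketch mirroring the _gpt_provider example above.
from functools import partial

from bridge.training.mlm_compat.model import _mamba_provider

mamba_provider = partial(_mamba_provider, args, config=None)

# Returns a megatron.core.models.mamba.MambaModel.
model = mamba_provider(pre_process=True, post_process=True)
```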