bridge.training.model_load_save#

Module Contents#

Functions#

torch_dtype_from_mcore_config

Convert Megatron-Core config dtype settings to torch dtype.

megatron_cpu_init_context

Context manager to temporarily force CPU initialization for Megatron models.

temporary_distributed_context

Context manager to temporarily initialize a minimal distributed environment.

load_tokenizer

Create a tokenizer from a training checkpoint.

load_model_config

Returns the model config saved in the checkpoint.

build_and_load_model

Load a Megatron model from a distributed checkpoint.

load_megatron_model

Load a Megatron model from a distributed checkpoint.

save_megatron_model

Save a Megatron model in native Megatron checkpoint format without optimizer state.

dtype_from_str

Convert a string representation of a dtype to a torch.dtype.

dtype_from_hf

Extract the torch.dtype from a Hugging Face PretrainedConfig object.

Data#

API#

bridge.training.model_load_save.logger#

'getLogger(...)'

bridge.training.model_load_save.torch_dtype_from_mcore_config(config: Any) → torch.dtype#

Convert Megatron-Core config dtype settings to torch dtype.

Parameters:

config – Megatron-Core configuration object with bf16/fp16 flags.

Returns:

The corresponding torch dtype.
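
Illustrative usage (a minimal sketch; the import path megatron.bridge.training.model_load_save is assumed from the fully qualified names on this page, and the SimpleNamespace below stands in for a real Megatron-Core config, since only the bf16/fp16 flags are read):

    import torch
    from types import SimpleNamespace

    from megatron.bridge.training.model_load_save import torch_dtype_from_mcore_config

    # Stand-in config: only the bf16/fp16 flags are consulted.
    cfg = SimpleNamespace(bf16=True, fp16=False)
    dtype = torch_dtype_from_mcore_config(cfg)  # expected: torch.bfloat16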

bridge.training.model_load_save.megatron_cpu_init_context(
config: Any,
) → Generator[None, None, None]#

Context manager to temporarily force CPU initialization for Megatron models.

This is useful when initializing a model on a system without GPUs or when memory constraints prevent GPU initialization.

Parameters:

config – The Megatron model configuration object (e.g., GPTConfig). Must have a use_cpu_initialization attribute.

Yields:

None. The context modifies the config in place.
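
Illustrative usage (sketch; build_gpt_model is a hypothetical builder standing in for however the model is constructed, and import path as above):

    from megatron.bridge.training.model_load_save import megatron_cpu_init_context

    # model_cfg is a Megatron model config (e.g., GPTConfig) with a
    # use_cpu_initialization attribute; it is forced to True inside the context.
    with megatron_cpu_init_context(model_cfg):
        model = build_gpt_model(model_cfg)  # hypothetical builder; weights are created on CPU
    # On exit, model_cfg.use_cpu_initialization is restored to its previous value.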

bridge.training.model_load_save.temporary_distributed_context(
backend: str = 'gloo',
) → Generator[None, None, None]#

Context manager to temporarily initialize a minimal distributed environment.

Sets up a single-process distributed backend, initializes Megatron model parallel state, yields control, and then cleans up the distributed environment. Useful for operations that require Megatron’s parallel state but should run standalone (e.g., loading distributed checkpoints).

Parameters:

backend – The distributed backend to use (“gloo” for CPU, “nccl” for GPU).

Yields:

None.
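
Illustrative usage (sketch; the body is whatever standalone operation needs Megatron's parallel state):

    from megatron.bridge.training.model_load_save import temporary_distributed_context

    # Single-process "gloo" world, suitable for CPU-only checkpoint work.
    with temporary_distributed_context(backend="gloo"):
        ...  # e.g., load or inspect a distributed checkpoint
    # The temporary process group and Megatron parallel state are torn down on exit.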

bridge.training.model_load_save.load_tokenizer(
checkpoint_path: str,
) → megatron.bridge.training.tokenizers.tokenizer.MegatronTokenizer#

Create a tokenizer from a training checkpoint.

Obtains the tokenizer configuration from the checkpoint and builds the tokenizer. The checkpoint should be in MCore distributed checkpoint format.

Parameters:

checkpoint_path – path to an MCore distributed checkpoint directory (e.g., /path/to/model/checkpoints/iter_0000001).
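
Illustrative usage (sketch; the checkpoint path is the placeholder from the parameter description, and the tokenize() call is assumed from the MegatronTokenizer interface):

    from megatron.bridge.training.model_load_save import load_tokenizer

    tokenizer = load_tokenizer("/path/to/model/checkpoints/iter_0000001")
    token_ids = tokenizer.tokenize("Hello world")  # assumed MegatronTokenizer method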

bridge.training.model_load_save.load_model_config(
checkpoint_path: str,
) → tuple[megatron.core.transformer.TransformerConfig, Optional[argparse.Namespace]]#

Returns the model config saved in the checkpoint.

Supports checkpoints saved with either Megatron Bridge or MegatronLM.

Parameters:

checkpoint_path – path to an MCore distributed checkpoint directory (e.g., /path/to/model/checkpoints/iter_0000001).

Returns:

  • The model config from the checkpoint. The object returned will be a model provider (e.g. GPTModelProvider) if using a Megatron Bridge checkpoint or a TransformerConfig if using a MegatronLM checkpoint.

  • The argparse.Namespace object if the checkpoint is from MegatronLM; otherwise None.
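
Illustrative usage (sketch; import path as above):

    from megatron.bridge.training.model_load_save import load_model_config

    model_cfg, megatron_args = load_model_config("/path/to/model/checkpoints/iter_0000001")
    if megatron_args is None:
        print("Megatron Bridge checkpoint, provider:", type(model_cfg).__name__)
    else:
        print("MegatronLM checkpoint, config:", type(model_cfg).__name__)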

bridge.training.model_load_save.build_and_load_model(
checkpoint_path: str,
model_cfg: megatron.core.transformer.TransformerConfig,
model_type: Optional[Literal['gpt', 'mamba']] = None,
megatron_args: Optional[argparse.Namespace] = None,
return_state_dict: bool = False,
use_cpu_init: bool = False,
skip_temp_dist_context: Optional[bool] = None,
) → Union[Any, dict[str, torch.Tensor]]#

Load a Megatron model from a distributed checkpoint.

Creates model instances and optionally a minimal distributed environment to load the model weights from checkpoint_path into the model. Automatically selects the appropriate distributed backend (Gloo for CPU, NCCL for GPU).

Parameters:
  • checkpoint_path – path to an MCore distributed checkpoint directory. Used only for model weights (e.g., /path/to/model/checkpoints/iter_0000001).

  • model_cfg – Model config from load_model_config(). Either a TransformerConfig or a model provider (e.g. GPTModelProvider), depending on the source of the checkpoint.

  • model_type – If the checkpoint is from MegatronLM, the model type is required. Currently, only GPT and Mamba models are supported.

  • megatron_args – If the checkpoint is from MegatronLM, this is required.

  • return_state_dict – If True, return the state dict instead of model instance. Default: False.

  • use_cpu_init – If True, use CPU initialization context for the model and Gloo backend. If False, use GPU initialization and NCCL backend. Default: False.

  • skip_temp_dist_context – If True, skip temporary distributed context setup. If None, automatically skip if distributed is already initialized. Default: None.

Returns:

The model instance with loaded weights if return_state_dict is False, otherwise returns a dictionary containing the full, unsharded model state_dict.
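
Illustrative usage (sketch; combines load_model_config() with a CPU-side load, and the model_type/megatron_args arguments are only needed for MegatronLM checkpoints):

    from megatron.bridge.training.model_load_save import (
        build_and_load_model,
        load_model_config,
    )

    ckpt = "/path/to/model/checkpoints/iter_0000001"
    model_cfg, megatron_args = load_model_config(ckpt)

    model = build_and_load_model(
        ckpt,
        model_cfg,
        model_type="gpt" if megatron_args is not None else None,  # assumes a GPT MegatronLM checkpoint
        megatron_args=megatron_args,
        use_cpu_init=True,  # CPU initialization, Gloo backend
    )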

bridge.training.model_load_save.load_megatron_model(
checkpoint_path: str,
model_type: Optional[Literal['gpt', 'mamba']] = None,
return_state_dict: bool = False,
use_cpu_init: bool = False,
skip_temp_dist_context: Optional[bool] = None,
mp_overrides: Optional[megatron.bridge.models.model_provider.ModelParallelKwargs] = None,
) → Union[Any, dict[str, torch.Tensor]]#

Load a Megatron model from a distributed checkpoint.

Wrapper around load_model_config() and build_and_load_model() for convenience.

Parameters:
  • checkpoint_path – path to an MCore distributed checkpoint directory (e.g., /path/to/model/checkpoints/iter_0000001).

  • model_type – If the checkpoint is from MegatronLM, the model type is required. Currently, only GPT and Mamba models are supported.

  • return_state_dict – If True, return the state dict instead of model instance. Default: False.

  • use_cpu_init – If True, use CPU initialization context for the model and Gloo backend. If False, use GPU initialization and NCCL backend. Default: False.

  • skip_temp_dist_context – If True, skip temporary distributed context setup. If None, automatically skip if distributed is already initialized. Default: None.

  • mp_overrides – Optional model-parallel overrides to apply to the loaded config. Only provided fields are overridden.

Returns:

The model instance with loaded weights if return_state_dict is False, otherwise returns a dictionary containing the full, unsharded model state_dict.
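
Illustrative usage (sketch; import path as above):

    from megatron.bridge.training.model_load_save import load_megatron_model

    # Load the model instance with weights, using CPU initialization and the Gloo backend.
    model = load_megatron_model(
        "/path/to/model/checkpoints/iter_0000001",
        use_cpu_init=True,
    )

    # Or fetch the full, unsharded weights as a state dict instead of a model instance.
    state_dict = load_megatron_model(
        "/path/to/model/checkpoints/iter_0000001",
        return_state_dict=True,
        use_cpu_init=True,
    )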

bridge.training.model_load_save.save_megatron_model(
model: list[megatron.core.transformer.MegatronModule],
path: Union[str, pathlib.Path],
ckpt_format: str = 'torch_dist',
hf_tokenizer_path: Optional[Union[str, pathlib.Path]] = None,
) → None#

Save a Megatron model in native Megatron checkpoint format without optimizer state.

This method saves the model in Megatron’s native checkpoint format, which can be loaded directly by Megatron for training or inference. The checkpoint includes the model configuration and weights, but no optimizer state or other training artifacts.

Parameters:
  • model – Megatron model instance or list of instances.

  • path – Directory path where the checkpoint will be saved.

  • ckpt_format – Checkpoint format to use (“torch_dist” or other supported formats).

  • hf_tokenizer_path – Optional HuggingFace model ID or path for tokenizer metadata. If provided, the tokenizer metadata will be included in the checkpoint.

Example

Save a model checkpoint:

    save_megatron_model(megatron_model, "./megatron_checkpoint")

Save a model checkpoint with tokenizer metadata:

    save_megatron_model(
        megatron_model,
        "./megatron_checkpoint",
        hf_tokenizer_path="meta-llama/Meta-Llama-3-8B",
    )

Note:

  • This method is collective and must be called by all ranks

  • The saved checkpoint can be loaded with Megatron’s checkpoint loading utilities

  • The checkpoint format follows Megatron’s standard structure for compatibility

bridge.training.model_load_save.dtype_from_str(dtype: str) → torch.dtype#

Convert a string representation of a dtype to a torch.dtype.

Handles common variations like ‘fp16’, ‘bf16-mixed’. Defaults to float32 for unrecognized strings.

Parameters:

dtype – The string representation (e.g., “bf16”, “fp16”, “float32”).

Returns:

The corresponding torch.dtype.
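
Illustrative usage (sketch; the example strings come from the description above):

    import torch

    from megatron.bridge.training.model_load_save import dtype_from_str

    dtype_from_str("bf16")         # torch.bfloat16
    dtype_from_str("fp16")         # torch.float16
    dtype_from_str("not-a-dtype")  # torch.float32 (documented fallback)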

bridge.training.model_load_save.dtype_from_hf(config: Any) → torch.dtype#

Extract the torch.dtype from a Hugging Face PretrainedConfig object.

Parameters:

config – A Hugging Face model config object (must have torch_dtype attribute).

Returns:

The corresponding torch.dtype.

Raises:
  • ValueError – If the torch_dtype attribute is not a recognized string or torch.dtype.

  • AttributeError – If the config object does not have a torch_dtype attribute.
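
Illustrative usage (sketch; PretrainedConfig here stands in for any Hugging Face model config carrying a torch_dtype attribute, and assumes the transformers package is installed):

    import torch
    from transformers import PretrainedConfig

    from megatron.bridge.training.model_load_save import dtype_from_hf

    hf_config = PretrainedConfig(torch_dtype="bfloat16")  # stand-in for a real model config
    dtype = dtype_from_hf(hf_config)  # expected: torch.bfloat16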