bridge.training.model_load_save#
Module Contents#
Functions#
| Function | Description |
|---|---|
| torch_dtype_from_mcore_config | Convert Megatron-Core config dtype settings to torch dtype. |
| megatron_cpu_init_context | Context manager to temporarily force CPU initialization for Megatron models. |
| temporary_distributed_context | Context manager to temporarily initialize a minimal distributed environment. |
| load_tokenizer | Create a tokenizer from a training checkpoint. |
| load_megatron_model | Load a Megatron model from a distributed checkpoint. |
| save_megatron_model | Save a Megatron model in native Megatron checkpoint format without optimizer state. |
| dtype_from_str | Convert a string representation of a dtype to a torch.dtype. |
| dtype_from_hf | Extract the torch.dtype from a Hugging Face PretrainedConfig object. |
Data#
API#
- bridge.training.model_load_save.logger#
'getLogger(...)'
- bridge.training.model_load_save.torch_dtype_from_mcore_config(config: Any) → torch.dtype#
Convert Megatron-Core config dtype settings to torch dtype.
- Parameters:
config – Megatron-Core configuration object with bf16/fp16 flags.
- Returns:
The corresponding torch dtype.
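As a rough illustration of the documented behavior (a sketch, not the library source), the flag-to-dtype mapping can be pictured as:

```python
import torch

# Illustrative sketch only: mirrors the documented bf16/fp16 flag handling.
def sketch_dtype_from_mcore_config(config) -> torch.dtype:
    if getattr(config, "bf16", False):
        return torch.bfloat16
    if getattr(config, "fp16", False):
        return torch.float16
    return torch.float32  # assumed default when neither flag is set
```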
- bridge.training.model_load_save.megatron_cpu_init_context(config: Any)#
Context manager to temporarily force CPU initialization for Megatron models.
This is useful when initializing a model on a system without GPUs or when memory constraints prevent GPU initialization.
- Parameters:
config – The Megatron model configuration object (e.g., GPTConfig). Must have a use_cpu_initialization attribute.
- Yields:
None. The context modifies the config in place.
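A hedged usage sketch; gpt_config and build_gpt_model are hypothetical stand-ins for your own config and model constructor, and the import path follows the module path documented here:

```python
from bridge.training.model_load_save import megatron_cpu_init_context

# gpt_config: hypothetical Megatron config (e.g., GPTConfig) with a
# use_cpu_initialization attribute, as the context manager requires.
with megatron_cpu_init_context(gpt_config):
    model = build_gpt_model(gpt_config)  # hypothetical constructor; weights are allocated on CPU
# On exit, gpt_config.use_cpu_initialization is restored to its previous value.
```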
- bridge.training.model_load_save.temporary_distributed_context(backend: str = 'gloo')#
Context manager to temporarily initialize a minimal distributed environment.
Sets up a single-process distributed backend, initializes Megatron model parallel state, yields control, and then cleans up the distributed environment. Useful for operations that require Megatron's parallel state but should run standalone (e.g., loading distributed checkpoints).
- Parameters:
backend – The distributed backend to use ('gloo' for CPU, 'nccl' for GPU).
- Yields:
None.
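For example, a standalone script (no torchrun) might wrap checkpoint work like this; run_checkpoint_op is a hypothetical stand-in:

```python
from bridge.training.model_load_save import temporary_distributed_context

# 'gloo' keeps the single-process group on CPU; use 'nccl' for GPU work.
with temporary_distributed_context(backend='gloo'):
    result = run_checkpoint_op()  # hypothetical stand-in for dist-checkpoint work
# The process group and Megatron parallel state are torn down on exit.
```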
- bridge.training.model_load_save.load_tokenizer(checkpoint_path: str)#
Create a tokenizer from a training checkpoint.
Obtains the tokenizer configuration from the checkpoint and builds the tokenizer. The checkpoint should be in MCore distributed checkpoint format.
- Parameters:
checkpoint_path – Path to an MCore distributed checkpoint directory (e.g., /path/to/model/checkpoints/iter_0000001).
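A minimal usage sketch; the path is illustrative, and the exact tokenizer API (e.g., a tokenize method) depends on the tokenizer type stored in the checkpoint:

```python
from bridge.training.model_load_save import load_tokenizer

tokenizer = load_tokenizer("/path/to/model/checkpoints/iter_0000001")
token_ids = tokenizer.tokenize("Hello world")  # method availability depends on tokenizer type
```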
- bridge.training.model_load_save.load_megatron_model(
- checkpoint_path: str,
- model_type: Optional[Literal['gpt', 'mamba']] = None,
- return_state_dict: bool = False,
- use_cpu_init: bool = True,
- skip_temp_dist_context: Optional[bool] = None,
)#
Load a Megatron model from a distributed checkpoint.
Creates a model instance and optionally a minimal distributed environment to load the model weights from checkpoint_path into the model. Automatically selects the appropriate distributed backend (Gloo for CPU, NCCL for GPU).
- Parameters:
checkpoint_path – Path to an MCore distributed checkpoint directory (e.g., /path/to/model/checkpoints/iter_0000001).
model_type – If the checkpoint is from MegatronLM, the model type is required. Currently, only GPT and Mamba models are supported.
return_state_dict – If True, return the state dict instead of the model instance. Default: False.
use_cpu_init – If True, use the CPU initialization context for the model and the Gloo backend. If False, use GPU initialization and the NCCL backend. Default: True.
skip_temp_dist_context – If True, skip the temporary distributed context setup. If None, automatically skip if distributed is already initialized. Default: None.
- Returns:
The model instance with loaded weights if return_state_dict is False, otherwise returns a dictionary containing the full, unsharded model state_dict.
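Two hedged usage sketches based on the parameters above; the paths are illustrative, and the import path follows the module path documented here:

```python
from bridge.training.model_load_save import load_megatron_model

# Load on CPU; a temporary Gloo-backed distributed context is created
# automatically when torch.distributed is not already initialized.
model = load_megatron_model(
    "/path/to/model/checkpoints/iter_0000001",
    use_cpu_init=True,
)

# For a MegatronLM checkpoint the model type must be given; request the
# full, unsharded state dict instead of a model instance:
state_dict = load_megatron_model(
    "/path/to/model/checkpoints/iter_0000001",
    model_type="gpt",
    return_state_dict=True,
)
```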
- bridge.training.model_load_save.save_megatron_model(
- model: list[megatron.core.transformer.MegatronModule],
- path: Union[str, pathlib.Path],
- ckpt_format: str = 'torch_dist',
)#
Save a Megatron model in native Megatron checkpoint format without optimizer state.
This method saves the model in Megatron's native checkpoint format, which can be loaded directly by Megatron for training or inference. The checkpoint includes the model configuration and weights, but no optimizer state or other training artifacts.
- Parameters:
model – Megatron model instance or list of instances.
path – Directory path where the checkpoint will be saved.
ckpt_format – Checkpoint format to use ('torch_dist' or other supported formats).
Example:
Save a model checkpoint:
save_megatron_model(megatron_model, "./megatron_checkpoint")
Note:
- This method is collective and must be called by all ranks.
- The saved checkpoint can be loaded with Megatron's checkpoint loading utilities.
- The checkpoint format follows Megatron's standard structure for compatibility.
- bridge.training.model_load_save.dtype_from_str(dtype: str) → torch.dtype#
Convert a string representation of a dtype to a torch.dtype.
Handles common variations like 'fp16' and 'bf16-mixed'. Defaults to float32 for unrecognized strings.
- Parameters:
dtype – The string representation (e.g., 'bf16', 'fp16', 'float32').
- Returns:
The corresponding torch.dtype.
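The documented behavior implies, for example:

```python
import torch
from bridge.training.model_load_save import dtype_from_str

assert dtype_from_str("bf16") == torch.bfloat16
assert dtype_from_str("fp16") == torch.float16
assert dtype_from_str("no-such-dtype") == torch.float32  # documented fallback
```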
- bridge.training.model_load_save.dtype_from_hf(config: Any) → torch.dtype#
Extract the torch.dtype from a Hugging Face PretrainedConfig object.
- Parameters:
config – A Hugging Face model config object (must have a torch_dtype attribute).
- Returns:
The corresponding torch.dtype.
- Raises:
ValueError – If the torch_dtype attribute is not a recognized string or torch.dtype.
AttributeError – If the config object does not have a torch_dtype attribute.
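A minimal sketch, assuming a config whose torch_dtype is stored as a string; configs loaded via transformers.AutoConfig.from_pretrained work the same way:

```python
import torch
from transformers import PretrainedConfig
from bridge.training.model_load_save import dtype_from_hf

cfg = PretrainedConfig()
cfg.torch_dtype = "bfloat16"  # HF configs may store a string or a torch.dtype
print(dtype_from_hf(cfg))  # torch.bfloat16
```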