nemo_export.utils.model_loader
Module Contents
Classes
- TarFileSystemReader: Reader that accepts both Path and TarPath checkpoint directories.
Functions
- nemo_to_path: Creates a Path / TarPath object suitable for navigating inside the NeMo checkpoint.
- load_sharded_metadata_torch_dist: Loads the model state dictionary from a torch_dist checkpoint.
- load_sharded_pickle_extra_state_scale: Loads model extra states from the .pt shards.
- contains_extra_states: Checks whether a zarr directory contains extra states.
- load_sharded_metadata_zarr: Loads the model state dictionary from the zarr format.
- nemo_weights_directory: Returns a Path pointing to the weights directory inside the NeMo checkpoint.
- load_model_weights: Loads the NeMo state dictionary.
Data
- LOGGER
API
- nemo_export.utils.model_loader.LOGGER = 'getLogger(...)'
- nemo_export.utils.model_loader.nemo_to_path(nemo_checkpoint: Union[pathlib.Path, str])
Creates a Path / TarPath object suitable for navigating inside the NeMo checkpoint.
- Parameters:
nemo_checkpoint (Path, str) – Path to the NeMo checkpoint.
- Returns:
Suitable Path object for navigating through the checkpoint.
- Return type:
Path | TarPath
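A packed .nemo file is a tar archive, while an unpacked checkpoint is an ordinary directory; nemo_to_path hides that difference behind one path-like object. The sketch below illustrates the same dispatch with the standard library only. It is not TarPath: the helper name list_top_level_entries is hypothetical, and it assumes tar member names may carry a leading "./" prefix.

```python
import os
import tarfile
from pathlib import Path, PurePosixPath
from typing import List, Union


def list_top_level_entries(nemo_checkpoint: Union[str, Path]) -> List[str]:
    """Hypothetical helper: list top-level entries of an unpacked or packed checkpoint."""
    if os.path.isdir(nemo_checkpoint):
        # Unpacked checkpoint: a plain pathlib.Path is enough to navigate it.
        return sorted(p.name for p in Path(nemo_checkpoint).iterdir())
    # Packed checkpoint: read member names from the archive without extracting it.
    with tarfile.open(nemo_checkpoint, "r") as tar:
        names = set()
        for member in tar.getmembers():
            parts = PurePosixPath(member.name).parts  # "./a/b" -> ("a", "b")
            if parts:  # the "." root entry has an empty parts tuple
                names.add(parts[0])
        return sorted(names)
```

Both branches return the same listing for the same content, which is the property a single Path / TarPath return type is meant to provide.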
- class nemo_export.utils.model_loader.TarFileSystemReader(path: Union[pathlib.Path, nemo_export.tarutils.TarPath])
Bases: torch.distributed.checkpoint.FileSystemReader
Reader that accepts both Path and TarPath checkpoint directories.
FileSystemReader works with TarPath but expects a pure Path, so it is enough to skip the Path check in __init__.
Initialization
Makes sure that super().__init__ receives a pure path, as it expects.
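The initialization trick can be illustrated in isolation: hand the base class a plain path so its validation passes, then put the original tar-aware object back. In this minimal sketch StrictReader, ArchivePath, and ArchiveAwareReader are hypothetical stand-ins for FileSystemReader, TarPath, and TarFileSystemReader; none of them is the real class.

```python
import os
from pathlib import Path
from typing import Union


class StrictReader:
    """Hypothetical stand-in for a base reader that insists on a filesystem path."""

    def __init__(self, path: Union[str, os.PathLike]) -> None:
        if not isinstance(path, (str, os.PathLike)):
            raise TypeError(f"expected a filesystem path, got {type(path).__name__}")
        self.path = Path(path)


class ArchivePath:
    """Hypothetical tar-aware path; it is not os.PathLike, so StrictReader rejects it."""

    def __init__(self, archive: str, inner: str = "") -> None:
        self.archive = archive
        self.inner = inner

    def __str__(self) -> str:
        return f"{self.archive}/{self.inner}" if self.inner else self.archive


class ArchiveAwareReader(StrictReader):
    def __init__(self, path: Union[str, os.PathLike, ArchivePath]) -> None:
        # Give the base class a plain path so its type check passes...
        super().__init__(str(path) if isinstance(path, ArchivePath) else path)
        # ...then restore the original object so later reads go through the archive.
        if isinstance(path, ArchivePath):
            self.path = path
```

Overwriting self.path after super().__init__ keeps the subclass small: the base class logic is reused unchanged, and only the stored path object differs.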
- nemo_export.utils.model_loader.load_sharded_metadata_torch_dist(checkpoint_dir: Union[pathlib.Path, nemo_export.tarutils.TarPath], load_extra_states: bool = False)
Loads the model state dictionary from a torch_dist checkpoint.
- Parameters:
checkpoint_dir (Path | TarPath) – Path to the model weights directory.
load_extra_states (bool) – If set to True, also loads the BytesIO objects related to the extra states.
- Returns:
Loaded model state dictionary (weights are stored in torch tensors).
- Return type:
dict
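In torch.distributed.checkpoint, each metadata entry is described either as a tensor (TensorStorageMetadata) or as opaque bytes (BytesStorageMetadata), and extra states belong to the second group. The sketch below shows, under that assumption, how a load_extra_states flag can decide which entries get requested. The plain-string metadata, the helper keys_to_load, and the example key names are hypothetical simplifications, not the function's actual code.

```python
from typing import Dict, List

# Hypothetical stand-in for checkpoint metadata: key -> storage kind.
# Real torch.distributed.checkpoint metadata holds TensorStorageMetadata or
# BytesStorageMetadata objects; plain strings keep this sketch dependency-free.
TENSOR = "tensor"
BYTES = "bytes"


def keys_to_load(metadata: Dict[str, str], load_extra_states: bool = False) -> List[str]:
    """Return the entries to request: tensors always, byte blobs only on demand."""
    selected = []
    for key, kind in metadata.items():
        if kind == TENSOR:
            selected.append(key)
        elif kind == BYTES and load_extra_states:
            # Extra states are serialized objects, surfaced as BytesIO when loaded.
            selected.append(key)
    return sorted(selected)
```

Skipping byte entries by default keeps the returned dictionary limited to weight tensors, which is what most export paths need.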
- nemo_export.utils.model_loader.load_sharded_pickle_extra_state_scale(dir: Union[pathlib.Path, nemo_export.tarutils.TarPath])
Loads model extra states from the .pt shards.
- Parameters:
dir (Path | TarPath) – Path to the directory with sharded extra states.
- Returns:
State dictionary corresponding to the loaded extra states.
- Return type:
dict
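A minimal sketch of loading sharded extra states from one directory, assuming shards follow a shard_<i>_<j>.pt naming pattern and that each result is keyed by "<directory name>/<shard stem>". Both conventions are assumptions, the helper name load_pickled_shards is hypothetical, and plain pickle stands in for torch.load so the example runs without PyTorch.

```python
import pickle
from pathlib import Path
from typing import Any, Dict


def load_pickled_shards(shard_dir: Path, pattern: str = "shard_*_*.pt") -> Dict[str, Any]:
    """Load every matching shard file, keyed as '<directory name>/<shard stem>'."""
    states: Dict[str, Any] = {}
    for shard in sorted(shard_dir.glob(pattern)):
        with shard.open("rb") as f:
            # Stand-in for torch.load: plain pickle keeps the sketch dependency-free.
            states[f"{shard_dir.name}/{shard.stem}"] = pickle.load(f)
    return states
```

Opening each shard through shard.open() rather than open(path) mirrors how a path-like object such as TarPath can expose file contents without extracting the archive.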
- nemo_export.utils.model_loader.contains_extra_states(subdir: Union[pathlib.Path, nemo_export.tarutils.TarPath])
Checks whether a zarr directory contains extra states.
- Parameters:
subdir (Path | TarPath) – Directory inside the zarr checkpoint.
- Returns:
True if the directory contains extra states, False otherwise.
- Return type:
bool
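One plausible way to implement such a check is to look for shard files by name. This sketch assumes extra-state directories contain files matching shard_0_*.pt; that naming convention and the helper name has_extra_state_shards are assumptions, not documented behavior.

```python
from pathlib import Path


def has_extra_state_shards(subdir: Path) -> bool:
    """True if the directory holds at least one shard_0_*.pt file (assumed naming)."""
    # glob() is lazy; any() stops at the first match instead of listing everything.
    return any(subdir.glob("shard_0_*.pt"))
```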
- nemo_export.utils.model_loader.load_sharded_metadata_zarr(checkpoint_dir: Union[pathlib.Path, nemo_export.tarutils.TarPath], load_extra_states: bool = False)
Loads the model state dictionary from the zarr format.
- Parameters:
checkpoint_dir (Path | TarPath) – Path to the NeMo checkpoint.
load_extra_states (bool) – If set to True, the function also loads BytesIO objects with the extra states.
- Returns:
Model state dictionary.
- Return type:
dict
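In a zarr checkpoint each tensor usually lives in its own subdirectory, and a zarr v2 array is marked by a .zarray metadata file. The sketch below only plans the dispatch, without reading any data: subdirectories with pickled shards are treated as extra states (when requested) and those with .zarray as arrays. The helper classify_subdirs and the shard_0_*.pt pattern are hypothetical; a real loader would go on to open the arrays with the zarr library.

```python
from pathlib import Path
from typing import Dict


def classify_subdirs(checkpoint_dir: Path, load_extra_states: bool = False) -> Dict[str, str]:
    """Map each subdirectory name to how it would be loaded ('array' or 'extra_state')."""
    plan: Dict[str, str] = {}
    for subdir in sorted(checkpoint_dir.iterdir()):
        if not subdir.is_dir():
            continue  # plain files (e.g. metadata.json) are not tensors
        if load_extra_states and any(subdir.glob("shard_0_*.pt")):
            plan[subdir.name] = "extra_state"  # pickled .pt shards
        elif (subdir / ".zarray").exists():
            plan[subdir.name] = "array"  # zarr v2 array: .zarray holds shape/dtype/chunks
    return plan
```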
- nemo_export.utils.model_loader.nemo_weights_directory(nemo_path: Union[pathlib.Path, nemo_export.tarutils.TarPath])
Returns a Path pointing to the weights directory inside the NeMo checkpoint.
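A minimal sketch of locating the weights directory, assuming it is named either model_weights or weights and that the checkpoint root itself is the fallback. The candidate names, their priority, and the helper name find_weights_dir are assumptions made for illustration.

```python
from pathlib import Path


def find_weights_dir(nemo_path: Path) -> Path:
    """Return the first known weights subdirectory, falling back to the root."""
    for name in ("model_weights", "weights"):  # assumed candidate names, in priority order
        candidate = nemo_path / name
        if candidate.exists():
            return candidate
    return nemo_path
```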
- nemo_export.utils.model_loader.load_model_weights(checkpoint_path: Union[str, pathlib.Path], load_extra_states: bool = False)
Loads the NeMo state dictionary.
Weights are stored as torch.Tensor objects.
- Parameters:
checkpoint_path (str | Path) – Path to the NeMo checkpoint.
load_extra_states (bool) – If True, also loads the BytesIO objects corresponding to the extra states.
- Returns:
Model state dictionary.
- Return type:
dict
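load_model_weights has to choose between the zarr and torch_dist loaders documented above. A minimal sketch of that choice, assuming the weights directory contains a metadata.json file whose sharded_backend field names the format (the layout Megatron-style distributed checkpoints use); the helper detect_sharded_backend is hypothetical.

```python
import json
from pathlib import Path


def detect_sharded_backend(weights_dir: Path) -> str:
    """Read metadata.json and return the backend name ('zarr' or 'torch_dist')."""
    with (weights_dir / "metadata.json").open("r") as f:
        backend = json.load(f).get("sharded_backend")
    if backend not in ("zarr", "torch_dist"):
        raise NotImplementedError(f"Unsupported sharded backend: {backend!r}")
    return backend
```

With the documented API, the full pipeline (resolve the path, find the weights directory, dispatch on the backend) is a single call, for example load_model_weights("/path/to/model.nemo", load_extra_states=True).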