nemo_export.utils.model_loader

Module Contents

Classes

TarFileSystemReader

Reader that accepts either a Path or a TarPath as the checkpoint directory.

Functions

nemo_to_path

Creates a Path or TarPath object suitable for navigating inside the NeMo checkpoint.

load_sharded_metadata_torch_dist

Loads the model state dictionary from a torch_dist checkpoint.

load_sharded_pickle_extra_state_scale

Loads model extra states from the .pt shards.

contains_extra_states

Checks whether a zarr directory contains extra states.

load_sharded_metadata_zarr

Loads model dictionary from the zarr format.

nemo_weights_directory

Returns a Path pointing to the weights directory inside the NeMo checkpoint.

load_model_weights

Loads NeMo state dictionary.

Data

LOGGER

API

nemo_export.utils.model_loader.LOGGER = 'getLogger(...)'

nemo_export.utils.model_loader.nemo_to_path(
nemo_checkpoint: Union[pathlib.Path, str],
) -> Union[pathlib.Path, nemo_export.tarutils.TarPath]

Creates a Path or TarPath object suitable for navigating inside the NeMo checkpoint.

Parameters:

nemo_checkpoint (Path | str) – Path to the NeMo checkpoint.

Returns:

Suitable Path object for navigating through the checkpoint.

Return type:

Path | TarPath
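The Path / TarPath return type suggests that a checkpoint may be either an unpacked directory or a tar archive. The following is a minimal standard-library sketch of that distinction, not the library's implementation: list_checkpoint_entries is a hypothetical helper, and it uses tarfile directly instead of TarPath.

```python
import tarfile
from pathlib import Path
from typing import List, Union


def list_checkpoint_entries(checkpoint: Union[Path, str]) -> List[str]:
    """Return sorted file names for a directory or tar-archive checkpoint.

    Hypothetical helper for illustration; not part of nemo_export.
    """
    path = Path(checkpoint)
    if path.is_dir():
        # Unpacked checkpoint: walk the directory tree directly.
        return sorted(
            p.relative_to(path).as_posix() for p in path.rglob("*") if p.is_file()
        )
    # Packed checkpoint: read member names without extracting anything.
    with tarfile.open(path, "r:*") as archive:
        return sorted(m.name for m in archive.getmembers() if m.isfile())
```

Both branches return the same kind of listing, which is the property a caller relies on when it navigates a checkpoint without knowing how it is stored.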

class nemo_export.utils.model_loader.TarFileSystemReader(
path: Union[pathlib.Path, nemo_export.tarutils.TarPath],
)

Bases: torch.distributed.checkpoint.FileSystemReader

Reader that accepts either a Path or a TarPath as the checkpoint directory.

The FileSystemReader works with TarPath, but expects a pure Path. It is enough to skip the Path check in __init__.

Initialization

Makes sure that super().__init__ gets a pure path as expected.
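The pattern described above, satisfying a base-class type check and then restoring the original path object, can be sketched with stand-in classes. Everything here is hypothetical: ArchivePath stands in for TarPath, StrictReader for a base reader that insists on a real Path, and the actual class may be written differently.

```python
from pathlib import Path
from typing import Union


class ArchivePath:
    """Hypothetical stand-in for TarPath: path-like, but not a pathlib.Path."""

    def __init__(self, location: str) -> None:
        self.location = location

    def __str__(self) -> str:
        return self.location


class StrictReader:
    """Hypothetical base reader that only accepts a Path or a str."""

    def __init__(self, path: Union[Path, str]) -> None:
        if not isinstance(path, (Path, str)):
            raise TypeError(f"expected Path or str, got {type(path).__name__}")
        self.path = Path(path)


class ArchiveAwareReader(StrictReader):
    """Passes the base-class check, then restores the archive path."""

    def __init__(self, path: Union[Path, ArchivePath]) -> None:
        # Hand the base class a value it accepts.
        super().__init__(str(path) if isinstance(path, ArchivePath) else path)
        if isinstance(path, ArchivePath):
            # Overwrite the attribute that the base class just set.
            self.path = path
```

The subclass changes nothing else, so any reading logic in the base class keeps working as long as the archive path supports the operations that logic uses.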

nemo_export.utils.model_loader.load_sharded_metadata_torch_dist(
checkpoint_dir: Union[pathlib.Path, nemo_export.tarutils.TarPath],
load_extra_states: bool = False,
) -> Dict[str, Any]

Loads the model state dictionary from a torch_dist checkpoint.

Parameters:
  • checkpoint_dir (Path | TarPath) – Path to the model weights directory.

  • load_extra_states (bool) – If True, also loads the BytesIO objects that hold the extra states.

Returns:

Loaded model state dictionary (weights are stored in torch tensors).

Return type:

dict

nemo_export.utils.model_loader.load_sharded_pickle_extra_state_scale(
dir: Union[pathlib.Path, nemo_export.tarutils.TarPath],
) -> Dict[str, io.BytesIO]

Loads model extra states from the .pt shards.

Parameters:

dir (Path | TarPath) – Path to the directory with sharded extra states.

Returns:

State dictionary corresponding to the loaded extra states.

Return type:

dict
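As a rough illustration of the documented return type, Dict[str, io.BytesIO], the sketch below reads every .pt file in a directory into an in-memory buffer. It is written under assumptions: read_pt_shards is a hypothetical helper, the key format (directory name plus file stem) is invented for the example, and the real function may deserialize or consolidate shards rather than copy raw bytes.

```python
import io
from pathlib import Path
from typing import Dict


def read_pt_shards(directory: Path) -> Dict[str, io.BytesIO]:
    """Read every .pt file in `directory` into an in-memory buffer.

    Hypothetical helper for illustration; not part of nemo_export.
    """
    buffers: Dict[str, io.BytesIO] = {}
    for shard in sorted(directory.glob("*.pt")):
        # Assumed key format: "<directory name>/<file name without suffix>".
        key = f"{directory.name}/{shard.stem}"
        buffers[key] = io.BytesIO(shard.read_bytes())
    return buffers
```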

nemo_export.utils.model_loader.contains_extra_states(
subdir: Union[pathlib.Path, nemo_export.tarutils.TarPath],
) -> bool

Checks whether a zarr directory contains extra states.

Parameters:

subdir (Path | TarPath) – Directory inside the zarr checkpoint.

Returns:

True if the directory contains extra states.

Return type:

bool

nemo_export.utils.model_loader.load_sharded_metadata_zarr(
checkpoint_dir: Union[pathlib.Path, nemo_export.tarutils.TarPath],
load_extra_states: bool = False,
) -> Dict[str, Any]

Loads model dictionary from the zarr format.

Parameters:
  • checkpoint_dir (Path | TarPath) – Path to the NeMo checkpoint.

  • load_extra_states (bool) – If True, also loads the BytesIO objects that hold the extra states.

Returns:

Model state dictionary.

Return type:

dict

nemo_export.utils.model_loader.nemo_weights_directory(
nemo_path: Union[pathlib.Path, nemo_export.tarutils.TarPath],
) -> Union[pathlib.Path, nemo_export.tarutils.TarPath]

Returns a Path pointing to the weights directory inside the NeMo checkpoint.

Parameters:

nemo_path (Path | TarPath) – Path to the NeMo checkpoint.

Returns:

Path to the weights directory inside the model checkpoint.

Return type:

Path | TarPath

nemo_export.utils.model_loader.load_model_weights(
checkpoint_path: Union[str, pathlib.Path],
load_extra_states: bool = False,
) -> Dict[str, Any]

Loads NeMo state dictionary.

Weights are stored as torch.Tensor objects.

Parameters:
  • checkpoint_path (str | Path) – Path to the NeMo checkpoint.

  • load_extra_states (bool) – If True, also loads the BytesIO objects that hold the extra states.

Returns:

Model state dictionary.

Return type:

dict
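When load_extra_states is True, the returned dictionary mixes weight entries with BytesIO extra-state entries. A caller may want to separate the two; the sketch below shows one way to do that with a hypothetical helper, split_extra_states, that relies only on the value types documented above.

```python
import io
from typing import Any, Dict, Tuple


def split_extra_states(
    state_dict: Dict[str, Any],
) -> Tuple[Dict[str, Any], Dict[str, io.BytesIO]]:
    """Separate BytesIO extra states from all other entries.

    Hypothetical helper for illustration; not part of nemo_export.
    """
    weights: Dict[str, Any] = {}
    extra_states: Dict[str, io.BytesIO] = {}
    for key, value in state_dict.items():
        if isinstance(value, io.BytesIO):
            extra_states[key] = value
        else:
            weights[key] = value
    return weights, extra_states
```

With the library installed, the input would typically be the result of load_model_weights(checkpoint_path, load_extra_states=True).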