nemo_automodel.components.checkpoint.state_dict_adapter#

Module Contents#

Classes#

StateDictAdapter

Abstract base class for state dict transformations.

API#

class nemo_automodel.components.checkpoint.state_dict_adapter.StateDictAdapter#

Bases: abc.ABC

Abstract base class for state dict transformations.

This class defines the interface for converting between the native model state dict format and other state dict formats, such as HuggingFace.
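As a rough illustration of how a concrete adapter might implement this interface, the sketch below defines a local stand-in for the abstract class (so the snippet runs without nemo_automodel or torch installed) and a hypothetical `PrefixRenameAdapter`. The `model.` key prefix, the class name, and the use of plain Python values in place of tensors are all assumptions made for the example, not behavior documented for any real adapter.

```python
from abc import ABC, abstractmethod
from typing import Any, Optional


class StateDictAdapter(ABC):
    """Local stand-in mirroring the documented interface (not the library class)."""

    @abstractmethod
    def to_hf(self, state_dict: dict[str, Any], **kwargs) -> dict[str, Any]:
        """Convert a native state dict to HuggingFace format."""

    @abstractmethod
    def from_hf(
        self,
        hf_state_dict: dict[str, Any],
        device_mesh: Optional[Any] = None,  # a DeviceMesh in the real signature
        **kwargs,
    ) -> dict[str, Any]:
        """Convert a HuggingFace-format state dict to native format."""


class PrefixRenameAdapter(StateDictAdapter):
    """Hypothetical adapter: assumes HF keys are native keys prefixed with 'model.'."""

    HF_PREFIX = "model."

    def to_hf(self, state_dict: dict[str, Any], **kwargs) -> dict[str, Any]:
        # Add the assumed HF prefix to every native key.
        return {self.HF_PREFIX + key: value for key, value in state_dict.items()}

    def from_hf(
        self,
        hf_state_dict: dict[str, Any],
        device_mesh: Optional[Any] = None,
        **kwargs,
    ) -> dict[str, Any]:
        # Strip the assumed HF prefix; keys without it pass through unchanged.
        # device_mesh is accepted for interface compatibility but unused here.
        n = len(self.HF_PREFIX)
        return {
            (key[n:] if key.startswith(self.HF_PREFIX) else key): value
            for key, value in hf_state_dict.items()
        }


adapter = PrefixRenameAdapter()
native = {"layers.0.weight": [1.0, 2.0]}
hf = adapter.to_hf(native)
restored = adapter.from_hf(hf)
```

In this sketch `to_hf` and `from_hf` are inverses of each other, which is a convenient property to check when writing an adapter, although the interface itself does not state that it is required.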

abstractmethod to_hf(
    state_dict: dict[str, Any],
    **kwargs,
) -> dict[str, Any]#

Convert from native model state dict to HuggingFace format.

Parameters:

state_dict – The native model state dict

Returns:

The converted HuggingFace format state dict

abstractmethod from_hf(
    hf_state_dict: dict[str, Any],
    device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
    **kwargs,
) -> dict[str, Any]#

Convert from HuggingFace format to native model state dict.

Parameters:
  • hf_state_dict – The HuggingFace format state dict

  • device_mesh – Optional device mesh for DTensor expert parallelism. If provided, only the experts needed by the current rank are loaded.

Returns:

The converted native model state dict
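The device_mesh behavior described above (loading only the experts the current rank needs) could be approximated by filtering keys before conversion. The following is a minimal sketch under several assumptions that are not taken from the library: expert weights are identified by a `.experts.<index>.` segment in the key, experts are split into equal contiguous blocks across ranks, and the expert-parallel rank and size are passed in as plain integers rather than derived from a `DeviceMesh`. The helper names are hypothetical.

```python
import re
from typing import Any

# Assumed key layout: "...experts.<index>...." marks a per-expert weight.
_EXPERT_KEY = re.compile(r"\.experts\.(\d+)\.")


def local_expert_range(num_experts: int, ep_rank: int, ep_size: int) -> range:
    """Return the contiguous block of expert indices assigned to ep_rank (assumed sharding)."""
    if num_experts % ep_size != 0:
        raise ValueError("num_experts must be divisible by ep_size")
    per_rank = num_experts // ep_size
    start = ep_rank * per_rank
    return range(start, start + per_rank)


def keep_local_experts(
    hf_state_dict: dict[str, Any],
    num_experts: int,
    ep_rank: int,
    ep_size: int,
) -> dict[str, Any]:
    """Drop expert weights owned by other ranks; keep all non-expert entries."""
    local = local_expert_range(num_experts, ep_rank, ep_size)
    kept = {}
    for key, value in hf_state_dict.items():
        match = _EXPERT_KEY.search(key)
        if match is None or int(match.group(1)) in local:
            kept[key] = value
    return kept


# Four experts split across two expert-parallel ranks; this is rank 1.
hf_state = {
    "model.embed.weight": "E",
    "model.layers.0.mlp.experts.0.w": "w0",
    "model.layers.0.mlp.experts.1.w": "w1",
    "model.layers.0.mlp.experts.2.w": "w2",
    "model.layers.0.mlp.experts.3.w": "w3",
}
rank1_state = keep_local_experts(hf_state, num_experts=4, ep_rank=1, ep_size=2)
```

Here rank 1 keeps the shared embedding plus experts 2 and 3. A real `from_hf` implementation would obtain the rank and size from the mesh it is given and may shard experts differently.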