nemo_automodel.components.models.common.hf_checkpointing_mixin


HuggingFace-compatible checkpointing mixin for NeMo Automodel.

This module provides a mixin class that gives models HuggingFace-compatible save_pretrained() and from_pretrained() methods while using NeMo’s checkpointing infrastructure internally.

Key design principle: we do NOT override state_dict() or load_state_dict(). PyTorch's Distributed Checkpoint (DCP) expects these to behave like standard nn.Module methods. HF format conversions happen only in save_pretrained() and from_pretrained(), via Checkpointer.

The Checkpointer is passed in explicitly (dependency injection) rather than read from global state.
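
The two rules above can be illustrated with a minimal pure-Python sketch. It assumes a stand-in `Base` class in place of `torch.nn.Module`, a hypothetical `RecordingCheckpointer`, and a hypothetical `"model."` prefix rename standing in for real HF format conversion; none of these are NeMo Automodel APIs, and the fallback to `self._checkpointer` is omitted. Because the mixin only adds `save_pretrained()`, attribute lookup for `state_dict()` falls through the MRO to the base class unchanged, and the checkpointer arrives as an argument instead of from module-level state.

```python
from typing import Any, Dict


class Base:
    """Stand-in for torch.nn.Module (assumption: real models subclass nn.Module)."""

    def __init__(self) -> None:
        self.weight = [1.0, 2.0]

    def state_dict(self) -> Dict[str, Any]:
        # Plain, unconverted parameter names, as DCP expects.
        return {"weight": list(self.weight)}


class RecordingCheckpointer:
    """Hypothetical checkpointer that only records what it was asked to save."""

    def __init__(self) -> None:
        self.saved: Dict[str, Dict[str, Any]] = {}

    def save(self, state: Dict[str, Any], path: str) -> None:
        self.saved[path] = state


class CheckpointingMixin:
    """Illustrative mixin: adds save_pretrained() but never defines state_dict()."""

    def save_pretrained(self, save_directory: str, checkpointer: RecordingCheckpointer) -> None:
        # Hypothetical format conversion applied to a copy of the state dict,
        # so state_dict() itself keeps standard nn.Module semantics.
        hf_state = {f"model.{k}": v for k, v in self.state_dict().items()}
        # The checkpointer is injected by the caller; no global lookup.
        checkpointer.save(hf_state, save_directory)


class MyModel(CheckpointingMixin, Base):
    pass
```

Calling `MyModel().save_pretrained("out", RecordingCheckpointer())` stores the renamed keys in the checkpointer, while `MyModel.state_dict` is still exactly `Base.state_dict`.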

Module Contents

Classes

Name | Description
HFCheckpointingMixin | Mixin providing HF-compatible API using NeMo's checkpointing infrastructure.

Data

logger

API

class nemo_automodel.components.models.common.hf_checkpointing_mixin.HFCheckpointingMixin()

Mixin providing HF-compatible API using NeMo’s checkpointing infrastructure.

Provides save_pretrained() and from_pretrained() methods that use Checkpointer for unified distributed/async support with HF format conversion.

Key design: We do NOT override state_dict() or load_state_dict() because PyTorch’s DCP expects these to behave like standard nn.Module methods.

For PreTrainedModel subclasses:

  • super().from_pretrained() handles downloading, quantization config, and meta-device initialization
  • Checkpointer.load_base_model() handles the actual weight loading, with format conversion

For nn.Module subclasses (no parent from_pretrained):

  • Falls back to loading the config manually, then loading weights with Checkpointer
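
The split between the two cases can be sketched as follows. This is a hypothetical, pure-Python illustration of the control flow only: `ParentWithFromPretrained`, `PlainModule`, `load_config`, `FakeCheckpointer`, and the `load_base_model(model, path)` signature are stand-ins, not the real transformers or NeMo Automodel APIs. The mixin asks whether a class later in the MRO defines `from_pretrained`; if so it delegates setup to it, otherwise it builds the model from a manually loaded config, and in both cases the checkpointer loads the weights.

```python
from typing import Any, Dict, List


class FakeCheckpointer:
    """Hypothetical stand-in; the real Checkpointer.load_base_model signature may differ."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def load_base_model(self, model: Any, path: str) -> None:
        self.calls.append(f"load_base_model({type(model).__name__}, {path})")


def load_config(path: str) -> Dict[str, Any]:
    """Hypothetical manual config loader for parents without from_pretrained."""
    return {"path": path}


class LoadingMixin:
    @classmethod
    def from_pretrained(cls, path: str, checkpointer: FakeCheckpointer) -> Any:
        parent = super(LoadingMixin, cls)
        if hasattr(parent, "from_pretrained"):
            # PreTrainedModel-style parent: it handles download, quantization
            # config and meta-device init; only weight loading is left to us.
            model = parent.from_pretrained(path)
        else:
            # nn.Module-style parent: construct from a manually loaded config.
            model = cls(load_config(path))
        checkpointer.load_base_model(model, path)
        return model


class ParentWithFromPretrained:
    """Stand-in for transformers.PreTrainedModel (assumption)."""

    def __init__(self, config: Dict[str, Any]) -> None:
        self.config = config

    @classmethod
    def from_pretrained(cls, path: str) -> Any:
        return cls({"path": path, "via": "parent"})


class PlainModule:
    """Stand-in for torch.nn.Module, which has no from_pretrained."""

    def __init__(self, config: Dict[str, Any]) -> None:
        self.config = config


class HFStyleModel(LoadingMixin, ParentWithFromPretrained):
    pass


class PlainModel(LoadingMixin, PlainModule):
    pass
```

Probing with `super(LoadingMixin, cls)` skips the mixin's own `from_pretrained`, so the check only sees what the base classes provide.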

nemo_automodel.components.models.common.hf_checkpointing_mixin.HFCheckpointingMixin.save_pretrained(
save_directory: str,
checkpointer: typing.Optional[nemo_automodel.components.checkpoint.checkpointing.Checkpointer] = None,
tokenizer: typing.Optional[transformers.tokenization_utils.PreTrainedTokenizerBase] = None,
**kwargs
) -> None

Save model in HF-compatible format using Checkpointer infrastructure.

Supports distributed saving, sharding, and async checkpointing.

Parameters:

save_directory
str

Output directory path.

checkpointer
Optional[Checkpointer], defaults to None

Checkpointer instance. Uses self._checkpointer if not provided.

tokenizer
Optional[PreTrainedTokenizerBase], defaults to None

Optional tokenizer to save alongside the model.

**kwargs
Defaults to {}

Additional arguments, including peft_config and is_final_checkpoint. When save_pretrained() is called directly, outside a recipe's step-scheduler context, is_final_checkpoint defaults to False.
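
The parameter rules above can be summarized in a short sketch, assuming a hypothetical `resolve_save_args` helper and a `DummyModel` stand-in; the actual method body is not shown in this reference. An explicit `checkpointer` wins, `self._checkpointer` is the fallback, and `is_final_checkpoint` is `False` unless the caller supplies it. Raising `ValueError` when no checkpointer is available is an assumption, since this page does not document that case.

```python
from typing import Any, Dict, Optional


def resolve_save_args(model: Any, checkpointer: Optional[Any] = None, **kwargs: Any) -> Dict[str, Any]:
    """Hypothetical helper mirroring the documented save_pretrained() parameters."""
    # Explicit argument first, then the model's own _checkpointer attribute.
    ckpt = checkpointer if checkpointer is not None else getattr(model, "_checkpointer", None)
    if ckpt is None:
        # Assumption: behavior without any checkpointer is not documented here.
        raise ValueError("no checkpointer passed and model has no _checkpointer")
    return {
        "checkpointer": ckpt,
        "peft_config": kwargs.get("peft_config"),
        # Direct callers without step-scheduler context get False.
        "is_final_checkpoint": kwargs.get("is_final_checkpoint", False),
    }


class DummyModel:
    """Stand-in model that may carry an attached checkpointer."""

    def __init__(self, checkpointer: Optional[Any] = None) -> None:
        self._checkpointer = checkpointer
```

For example, `resolve_save_args(DummyModel("attached"))` picks the attached checkpointer and reports `is_final_checkpoint` as `False`.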

nemo_automodel.components.models.common.hf_checkpointing_mixin.logger = logging.getLogger(__name__)