nemo_automodel.components.models.common.hf_checkpointing_mixin#

HuggingFace-compatible checkpointing mixin for NeMo Automodel.

This module provides a mixin class that gives models HuggingFace-compatible save_pretrained() and from_pretrained() methods while using NeMo’s checkpointing infrastructure internally.

Key design principle: We do NOT override state_dict() or load_state_dict(). PyTorch’s DCP expects these to behave like standard nn.Module methods. HF format conversions happen only in save_pretrained() and from_pretrained() via Checkpointer.

The Checkpointer is passed in explicitly (dependency injection); there is no global state.
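The dependency-injection principle can be sketched in plain Python. Note that `FakeCheckpointer`, `InjectionMixin`, `TinyModel`, and the `save_model` method are illustrative stand-ins for this sketch, not the real NeMo Automodel API:

```python
class FakeCheckpointer:
    """Illustrative stand-in for nemo_automodel's Checkpointer."""

    def __init__(self):
        self.saved_paths = []

    def save_model(self, model, path):  # hypothetical method name
        self.saved_paths.append(path)


class InjectionMixin:
    """Illustrative mixin: the checkpointer is injected, never read from a global."""

    def save_pretrained(self, save_directory, checkpointer=None):
        # An explicit argument wins; otherwise fall back to the
        # checkpointer injected onto the instance.
        ckpt = checkpointer if checkpointer is not None else getattr(self, "_checkpointer", None)
        if ckpt is None:
            raise ValueError("no checkpointer was injected or passed")
        ckpt.save_model(self, save_directory)


class TinyModel(InjectionMixin):
    def __init__(self, checkpointer):
        self._checkpointer = checkpointer


ckpt = FakeCheckpointer()
model = TinyModel(ckpt)
model.save_pretrained("/tmp/ckpt-a")                      # uses injected checkpointer
other = FakeCheckpointer()
model.save_pretrained("/tmp/ckpt-b", checkpointer=other)  # explicit argument overrides
print(ckpt.saved_paths, other.saved_paths)
```

Because the checkpointer arrives as a constructor argument or a call argument, tests can substitute any stand-in without patching module-level state.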

Module Contents#

Classes#

HFCheckpointingMixin

Mixin providing HF-compatible API using NeMo’s checkpointing infrastructure.

Data#

API#

nemo_automodel.components.models.common.hf_checkpointing_mixin.logger#

‘getLogger(…)’

class nemo_automodel.components.models.common.hf_checkpointing_mixin.HFCheckpointingMixin#

Mixin providing HF-compatible API using NeMo’s checkpointing infrastructure.

Provides save_pretrained() and from_pretrained() methods that use Checkpointer for unified distributed/async support with HF format conversion.

Key design: We do NOT override state_dict() or load_state_dict() because PyTorch’s DCP expects these to behave like standard nn.Module methods.

For PreTrainedModel subclasses:

  • super().from_pretrained() handles: downloads, quantization config, meta device init

  • Checkpointer.load_base_model() handles: actual weight loading with format conversion

For nn.Module subclasses (no parent from_pretrained):

  • Falls back to manual config loading + Checkpointer
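The two load paths above can be sketched with minimal stand-in classes. Everything here (`FakePreTrainedModel`, `FakeCheckpointer`, `MixinSketch`) is hypothetical scaffolding to illustrate the dispatch, not the real transformers or NeMo implementation:

```python
class FakePreTrainedModel:
    """Illustrative stand-in for transformers.PreTrainedModel."""

    @classmethod
    def from_pretrained(cls, path, **kwargs):
        # The real method also handles downloads, quantization config,
        # and meta-device init.
        model = cls()
        model.init_path = path
        return model


class FakeCheckpointer:
    """Illustrative stand-in for nemo_automodel's Checkpointer."""

    def load_base_model(self, model, path):
        # The real version loads weights with HF format conversion.
        model.weights_loaded_from = path
        return model


class MixinSketch:
    @classmethod
    def from_pretrained(cls, path, checkpointer, **kwargs):
        if issubclass(cls, FakePreTrainedModel):
            # PreTrainedModel path: the parent class handles setup...
            model = super(MixinSketch, cls).from_pretrained(path, **kwargs)
        else:
            # Plain nn.Module path: manual init (real code loads the config here)...
            model = cls()
        # ...then the checkpointer performs the actual weight loading.
        return checkpointer.load_base_model(model, path)


class HFModel(MixinSketch, FakePreTrainedModel):
    pass


class PlainModel(MixinSketch):
    pass


ckpt = FakeCheckpointer()
m1 = HFModel.from_pretrained("some/repo", checkpointer=ckpt)
m2 = PlainModel.from_pretrained("some/repo", checkpointer=ckpt)
print(m1.init_path, m1.weights_loaded_from)
print(m2.weights_loaded_from)
```

In both branches the checkpointer owns the final weight-loading step, which is what keeps format conversion out of `state_dict()`/`load_state_dict()`.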

save_pretrained(
save_directory: str,
checkpointer: Optional[nemo_automodel.components.checkpoint.checkpointing.Checkpointer] = None,
tokenizer: Optional[transformers.tokenization_utils.PreTrainedTokenizerBase] = None,
**kwargs,
) → None#

Save model in HF-compatible format using Checkpointer infrastructure.

Supports distributed saving, sharding, and async checkpointing.

Parameters:
  • save_directory – Output path

  • checkpointer – Checkpointer instance. Uses self._checkpointer if not provided.

  • tokenizer – Optional tokenizer to save alongside model

  • **kwargs – Additional arguments
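A minimal sketch of how the documented parameters interact, in particular the optional `tokenizer` being saved alongside the model. The stub classes and the `save_model` method are assumptions for illustration; only the `save_pretrained(save_directory, checkpointer=..., tokenizer=...)` shape comes from the documentation above:

```python
class StubCheckpointer:
    """Hypothetical stand-in for the real Checkpointer."""

    def __init__(self):
        self.model_saves = []

    def save_model(self, model, path):  # hypothetical method name
        self.model_saves.append(path)


class StubTokenizer:
    """Hypothetical stand-in; HF tokenizers expose save_pretrained(save_directory)."""

    def __init__(self):
        self.saved_to = None

    def save_pretrained(self, save_directory):
        self.saved_to = save_directory


class SaveSketch:
    def save_pretrained(self, save_directory, checkpointer=None, tokenizer=None, **kwargs):
        # Resolve the checkpointer: explicit argument, else self._checkpointer.
        ckpt = checkpointer if checkpointer is not None else getattr(self, "_checkpointer", None)
        ckpt.save_model(self, save_directory)
        # The tokenizer, when provided, is written to the same directory.
        if tokenizer is not None:
            tokenizer.save_pretrained(save_directory)


ckpt = StubCheckpointer()
tok = StubTokenizer()
model = SaveSketch()
model.save_pretrained("/tmp/hf-out", checkpointer=ckpt, tokenizer=tok)
print(ckpt.model_saves, tok.saved_to)
```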