nemo_automodel.components.models.common.hf_checkpointing_mixin
HuggingFace-compatible checkpointing mixin for NeMo Automodel.
This module provides a mixin class that gives models HuggingFace-compatible save_pretrained() and from_pretrained() methods while using NeMo’s checkpointing infrastructure internally.
Key design principle: We do NOT override state_dict() or load_state_dict(). PyTorch’s DCP expects these to behave like standard nn.Module methods. HF format conversions happen only in save_pretrained() and from_pretrained() via Checkpointer.
The Checkpointer is passed explicitly (dependency injection); there is no global state.
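The design principle above can be illustrated with a minimal, dependency-free sketch. Every name in it (`ToyModule`, `ToyCheckpointer`, `ToyHFMixin`, the `save` method, the `save_directory` parameter, and the `"model."` key prefix) is a hypothetical stand-in, not the library's API. It only shows the shape of the idea: `state_dict()` stays untouched, and key conversion happens inside the injected checkpointer.

```python
from typing import Callable, Dict


class ToyModule:
    """Hypothetical stand-in for nn.Module: state_dict() returns native keys."""

    def __init__(self) -> None:
        self._params: Dict[str, float] = {"layers.0.weight": 1.0, "layers.0.bias": 0.5}

    def state_dict(self) -> Dict[str, float]:
        return dict(self._params)


class ToyCheckpointer:
    """Hypothetical stand-in for Checkpointer: it owns the format conversion."""

    def __init__(self, to_hf_key: Callable[[str], str]) -> None:
        self.to_hf_key = to_hf_key
        self.saved: Dict[str, Dict[str, float]] = {}

    def save(self, model: ToyModule, path: str) -> None:
        # Conversion to HF-style keys happens here, not in state_dict().
        self.saved[path] = {self.to_hf_key(k): v for k, v in model.state_dict().items()}


class ToyHFMixin:
    """Adds save_pretrained() and deliberately does not override state_dict()."""

    def save_pretrained(self, save_directory: str, checkpointer: ToyCheckpointer) -> None:
        # The checkpointer is an explicit argument: no module-level singleton.
        checkpointer.save(self, save_directory)


class ToyModel(ToyHFMixin, ToyModule):
    pass


model = ToyModel()
ckpt = ToyCheckpointer(to_hf_key=lambda key: "model." + key)  # assumed mapping
model.save_pretrained("out", checkpointer=ckpt)
```

Because the conversion lives in the checkpointer, code that calls `state_dict()` directly still sees the native key layout.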
Module Contents
Classes
Data
API
Mixin providing HF-compatible API using NeMo’s checkpointing infrastructure.
Provides save_pretrained() and from_pretrained() methods that use Checkpointer for unified distributed/async support with HF format conversion.
Key design: We do NOT override state_dict() or load_state_dict() because PyTorch’s DCP expects these to behave like standard nn.Module methods.
For PreTrainedModel subclasses:
- super().from_pretrained() handles: downloads, quantization config, meta device init
- Checkpointer.load_base_model() handles: actual weight loading with format conversion
For nn.Module subclasses (no parent from_pretrained):
- Falls back to manual config loading + Checkpointer
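The two loading paths above could be dispatched roughly as follows. This is a sketch under stated assumptions, not the actual implementation: the `hasattr(super(), ...)` check, the `init_path` and `loaded_from` attributes, and all `Toy*` classes are hypothetical. Only the method name `load_base_model` comes from the text.

```python
from typing import Any


class ToyCheckpointer:
    """Hypothetical stand-in; only the name load_base_model is from the text."""

    def load_base_model(self, model: Any, path: str) -> None:
        model.loaded_from = path  # pretend weights were loaded and converted


class ToyPreTrainedModel:
    """Hypothetical parent that already provides from_pretrained."""

    @classmethod
    def from_pretrained(cls, path: str, **kwargs: Any):
        model = cls()
        model.init_path = "parent"  # downloads, quantization config, meta init
        return model


class ToyPlainModule:
    """Hypothetical parent with no from_pretrained (like a bare nn.Module)."""


class ToyHFMixin:
    @classmethod
    def from_pretrained(cls, path: str, checkpointer: ToyCheckpointer, **kwargs: Any):
        parent = super()
        if hasattr(parent, "from_pretrained"):
            # A class later in the MRO handles setup; reuse it.
            model = parent.from_pretrained(path, **kwargs)
        else:
            # Fallback: build the model manually (config loading elided here).
            model = cls()
            model.init_path = "manual"
        # Either way, the actual weights go through the checkpointer.
        checkpointer.load_base_model(model, path)
        return model


class HFStyleModel(ToyHFMixin, ToyPreTrainedModel):
    pass


class PlainModel(ToyHFMixin, ToyPlainModule):
    pass


hf_model = HFStyleModel.from_pretrained("some/path", checkpointer=ToyCheckpointer())
plain_model = PlainModel.from_pretrained("some/path", checkpointer=ToyCheckpointer())
```

In both branches the weight loading step is identical. Only the model construction differs.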
Save model in HF-compatible format using Checkpointer infrastructure.
Supports distributed saving, sharding, and async checkpointing.
Parameters:
Output path
Checkpointer instance. Uses self._checkpointer if not provided.
Optional tokenizer to save alongside model
Additional arguments, including peft_config and is_final_checkpoint. For direct callers that have no recipe step-scheduler context, is_final_checkpoint defaults to False.
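The parameter handling described above might look like the sketch below. The parameter names in the signature, the `save` method on the checkpointer, and both classes are assumptions made for illustration. The source confirms only the fallback to `self._checkpointer`, the `peft_config` and `is_final_checkpoint` keyword arguments, and the `False` default.

```python
from typing import Any, Dict, List, Optional


class RecordingCheckpointer:
    """Hypothetical stand-in that records what it was asked to save."""

    def __init__(self) -> None:
        self.calls: List[Dict[str, Any]] = []

    def save(self, model: Any, path: str, **options: Any) -> None:
        self.calls.append({"path": path, **options})


class ToyModel:
    def __init__(self, checkpointer: Optional[RecordingCheckpointer] = None) -> None:
        self._checkpointer = checkpointer

    def save_pretrained(
        self,
        save_directory: str,  # assumed name for the output path
        checkpointer: Optional[RecordingCheckpointer] = None,
        tokenizer: Any = None,
        **kwargs: Any,
    ) -> None:
        # Prefer the explicit argument, then fall back to self._checkpointer.
        active = checkpointer if checkpointer is not None else self._checkpointer
        if active is None:
            raise ValueError("No checkpointer provided and none attached to the model.")
        active.save(
            self,
            save_directory,
            tokenizer=tokenizer,
            peft_config=kwargs.get("peft_config"),
            # Direct callers have no step-scheduler context, so default to False.
            is_final_checkpoint=kwargs.get("is_final_checkpoint", False),
        )


attached = RecordingCheckpointer()
explicit = RecordingCheckpointer()
model = ToyModel(checkpointer=attached)

model.save_pretrained("step_100")  # uses the attached checkpointer
model.save_pretrained("final", checkpointer=explicit, is_final_checkpoint=True)
```

Passing a checkpointer per call keeps the choice visible at the call site, while the attached one covers the common case.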