nemo_automodel.components.models.common.hf_checkpointing_mixin#
HuggingFace-compatible checkpointing mixin for NeMo Automodel.
This module provides a mixin class that gives models HuggingFace-compatible save_pretrained() and from_pretrained() methods while using NeMo’s checkpointing infrastructure internally.
Key design principle: we do NOT override state_dict() or load_state_dict(), because PyTorch’s DCP (distributed checkpointing) expects these to behave like standard nn.Module methods. HF format conversions happen only in save_pretrained() and from_pretrained() via the Checkpointer.
The Checkpointer is passed explicitly (dependency injection); there is no global state.
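A minimal sketch of this dependency-injection pattern, assuming a hypothetical model class `MyModel` and helper `save_with_injected_checkpointer`; only `HFCheckpointingMixin`, `Checkpointer`, and the documented `save_pretrained()` entry point come from this module’s API.

```python
# A minimal sketch of the dependency-injection pattern described above.
# `MyModel` and `save_with_injected_checkpointer` are hypothetical names;
# the mixin, the Checkpointer class, and save_pretrained() are the documented API.
from transformers import PreTrainedModel

from nemo_automodel.components.checkpoint.checkpointing import Checkpointer
from nemo_automodel.components.models.common.hf_checkpointing_mixin import HFCheckpointingMixin


class MyModel(HFCheckpointingMixin, PreTrainedModel):
    """Hypothetical model combining the mixin with a HF PreTrainedModel."""


def save_with_injected_checkpointer(model: MyModel, checkpointer: Checkpointer) -> None:
    # The caller constructs the Checkpointer and passes it in explicitly;
    # the mixin keeps no module-level global checkpointing state.
    model.save_pretrained("out/hf_ckpt", checkpointer=checkpointer)
```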
Module Contents#
Classes#
- HFCheckpointingMixin: Mixin providing HF-compatible API using NeMo’s checkpointing infrastructure.
Data#
API#
- nemo_automodel.components.models.common.hf_checkpointing_mixin.logger#
‘getLogger(…)’
- class nemo_automodel.components.models.common.hf_checkpointing_mixin.HFCheckpointingMixin#
Mixin providing HF-compatible API using NeMo’s checkpointing infrastructure.
Provides save_pretrained() and from_pretrained() methods that use Checkpointer for unified distributed/async support with HF format conversion.
Key design: We do NOT override state_dict() or load_state_dict() because PyTorch’s DCP expects these to behave like standard nn.Module methods.
For PreTrainedModel subclasses:
- super().from_pretrained() handles downloads, quantization config, and meta-device initialization.
- Checkpointer.load_base_model() handles the actual weight loading with format conversion.
For nn.Module subclasses (no parent from_pretrained()):
- Falls back to manual config loading plus Checkpointer (see the sketch after this list).
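The branching above can be summarized with the following hedged sketch. It mirrors the documented behavior but is not the actual implementation; the `_init_from_config` helper and the exact `Checkpointer.load_base_model()` call signature are assumptions.

```python
# Simplified sketch of the from_pretrained() branching described in the list above.
# Not the actual implementation: `_init_from_config` is a hypothetical helper and
# the load_base_model() call signature is assumed.
class _SketchHFCheckpointingMixin:
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, checkpointer=None, **kwargs):
        parent_from_pretrained = getattr(super(), "from_pretrained", None)
        if parent_from_pretrained is not None:
            # PreTrainedModel path: the parent handles download, quantization
            # config, and meta-device initialization.
            model = parent_from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        else:
            # Plain nn.Module path: fall back to manual config loading.
            model = cls._init_from_config(pretrained_model_name_or_path)  # hypothetical helper
        # In both cases the Checkpointer performs the actual weight loading,
        # converting from HF format as needed (assumed call signature).
        checkpointer.load_base_model(model, pretrained_model_name_or_path)
        return model
```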
- save_pretrained(save_directory: str, checkpointer: Optional[nemo_automodel.components.checkpoint.checkpointing.Checkpointer] = None, tokenizer: Optional[transformers.tokenization_utils.PreTrainedTokenizerBase] = None, **kwargs)#
Save model in HF-compatible format using Checkpointer infrastructure.
Supports distributed saving, sharding, and async checkpointing.
- Parameters:
save_directory – Output directory path
checkpointer – Checkpointer instance. Uses self._checkpointer if not provided.
tokenizer – Optional tokenizer to save alongside the model
**kwargs – Additional keyword arguments
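A brief usage sketch of save_pretrained() with these parameters. The `export_hf_checkpoint` wrapper, the model id, and the way the Checkpointer is obtained are placeholders; the `checkpointer` and `tokenizer` keyword arguments are the documented ones.

```python
# Illustrative usage of save_pretrained() with the parameters documented above.
# `model` is assumed to be an instance of a class that mixes in HFCheckpointingMixin;
# the model id and output path are placeholders.
from transformers import AutoTokenizer

from nemo_automodel.components.checkpoint.checkpointing import Checkpointer


def export_hf_checkpoint(model, checkpointer: Checkpointer, save_directory: str) -> None:
    tokenizer = AutoTokenizer.from_pretrained("org/model-name")  # placeholder model id
    # Passing a checkpointer overrides self._checkpointer for this call;
    # omit it to fall back to the instance's own checkpointer.
    model.save_pretrained(save_directory, checkpointer=checkpointer, tokenizer=tokenizer)
```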