Hugging Face Integration#
The io.ConnectorMixin class can be used to make a NeMo model compatible with Hugging Face. io.ConnectorMixin makes it possible to load Hugging Face models into NeMo and save NeMo models in Hugging Face format. The GPTModel class below shows how to implement this (we can ignore the other mixins here):

class GPTModel(L.LightningModule, io.IOMixin, io.ConnectorMixin, fn.FNMixin):
    ...
The generic base class is extended by several models to provide two-way integration with Hugging Face. These models include:
GemmaModel
LlamaModel
MistralModel
MixtralModel
Fine-Tune a Model using a Hugging Face Checkpoint#
To fine-tune a model, use the following script:
import nemo_run as run
from nemo.collections import llm
from nemo import lightning as nl


@run.factory
def mistral():
    return llm.MistralModel()


@run.factory
def trainer(devices=2) -> nl.Trainer:
    strategy = nl.MegatronStrategy(tensor_model_parallel_size=devices)
    return nl.Trainer(
        devices=devices,
        max_steps=100,
        accelerator="gpu",
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
    )


resume = nl.AutoResume(import_path="hf://mistralai/Mistral-7B-v0.1")
sft = run.Partial(llm.finetune, model=mistral, data=llm.squad, trainer=trainer, resume=resume)
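At this point, sft is only a configuration. To actually launch the fine-tuning job, pass it to NeMo Run together with an executor. A minimal sketch, assuming a single node with two GPUs and the local torchrun-based executor (adapt the executor to your environment):

# Launch the configured fine-tuning task; the local executor here is an assumption
run.run(sft, executor=run.LocalExecutor(ntasks_per_node=2, launcher="torchrun"))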
The script will try to load a model_importer with the name "hf" on the MistralModel. It then loads the following class inside nemo/collections/llm/gpt/model/mistral.py to perform the conversion:
@io.model_importer(MistralModel, "hf")
class HFMistralImporter(io.ModelConnector["MistralForCausalLM", MistralModel]):
    ...
Note that this conversion only occurs once. Afterwards, the converted checkpoint will be loaded from the $NEMO_HOME directory.
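If you want to run the conversion ahead of time, so that subsequent training jobs only read the cached NeMo checkpoint, the same importer can also be invoked directly through llm.import_ckpt. A minimal sketch, assuming the Mistral model from the example above (argument names may vary slightly between NeMo versions; check the llm.import_ckpt signature in your installation):

from nemo.collections import llm

llm.import_ckpt(
    model=llm.MistralModel(),                 # target NeMo model; its "hf" importer handles the conversion
    source="hf://mistralai/Mistral-7B-v0.1",  # Hugging Face checkpoint to convert
)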
Create a Model Importer#
To implement a custom model importer, you can follow the structure of the HFMistralImporter class. Here’s a step-by-step explanation of how to create a custom model importer.
Define a new class that inherits from io.ModelConnector:

@io.model_importer(YourModel, "source_format")
class CustomImporter(io.ModelConnector["SourceModel", YourModel]):
    # Implementation here
Replace YourModel with your target model class, "source_format" with the format you're importing from, and "SourceModel" with the source model type. You can choose "source_format" to be any string. In the Mistral example, we use the "hf" string to indicate that we are importing from Hugging Face.

Implement the required methods:
class CustomImporter(io.ModelConnector["SourceModel", YourModel]):
    def init(self) -> YourModel:
        # Initialize and return your target model
        return YourModel(self.config, tokenizer=self.tokenizer)

    def apply(self, output_path: Path) -> Path:
        # Load source model, convert state, and save target model
        source = SourceModel.from_pretrained(str(self))
        target = self.init()
        trainer = self.nemo_setup(target)
        self.convert_state(source, target)
        self.nemo_save(output_path, trainer)
        # Clean up and return output path
        teardown(trainer, target)
        return output_path

    def convert_state(self, source, target):
        # Define mapping between source and target model states
        mapping = {
            "source_key1": "target_key1",
            "source_key2": "target_key2",
            # ... more mappings ...
        }
        return io.apply_transforms(source, target, mapping=mapping, transforms=[])

    @property
    def tokenizer(self) -> "YourTokenizer":
        # Return the appropriate tokenizer for your model
        return YourTokenizer(str(self))

    @property
    def config(self) -> YourModelConfig:
        # Load source config and convert to target config
        source_config = SourceConfig.from_pretrained(str(self))
        return YourModelConfig(
            # Set appropriate parameters based on source_config
        )
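For the config property, the usual pattern is to read the source (Hugging Face) config and translate its fields into the corresponding NeMo config fields. The sketch below is loosely modeled on the Mistral importer and is illustrative only; the helper name and the exact field mapping are assumptions, so refer to HFMistralImporter in nemo/collections/llm/gpt/model/mistral.py for the authoritative version:

from transformers import MistralConfig as HFMistralConfig
from nemo.collections.llm import MistralConfig7B

def convert_config(hf_model_name_or_path: str) -> MistralConfig7B:
    # Illustrative sketch: map a subset of Hugging Face config fields onto the NeMo config
    source = HFMistralConfig.from_pretrained(hf_model_name_or_path)
    return MistralConfig7B(
        num_layers=source.num_hidden_layers,
        hidden_size=source.hidden_size,
        ffn_hidden_size=source.intermediate_size,
        num_attention_heads=source.num_attention_heads,
        seq_length=source.max_position_embeddings,
    )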
Implement custom state transforms:

The @io.state_transform decorator is a powerful tool for defining custom transformations between source and target model states. It allows you to specify complex mappings that go beyond simple key renaming.

@io.state_transform(
    source_key=("source.key1", "source.key2"),
    target_key="target.key",
)
def _custom_transform(ctx: io.TransformCTX, source1, source2):
    # Implement custom transformation logic
    return transformed_data
The following list describes the key aspects of the state_transform decorator:

Source and Target Keys:
    source_key: Specifies the key(s) in the source model state. Can be a single string or a tuple of strings.
    target_key: Specifies the key in the target model state where the transformed data will be stored.
    Wildcard *: Used to apply the transform across multiple layers or components.

Transform Function:
    The decorated function receives the source tensor(s) as arguments. It should return the transformed tensor(s) for the target model.

Context Object:
    The first argument, ctx, is a TransformCTX object. It provides access to both the source and target models and their configs.

Multiple Source Keys:
    When multiple source keys are specified, the transform function receives multiple tensors as arguments.

Flexible Transformations:
    You can perform arbitrary operations on the tensors, including reshaping, concatenating, splitting, or applying mathematical operations.
The following example shows a more complex transform using wildcards:
@io.state_transform(
    source_key=(
        "model.layers.*.self_attn.q_proj.weight",
        "model.layers.*.self_attn.k_proj.weight",
        "model.layers.*.self_attn.v_proj.weight",
    ),
    target_key="decoder.layers.*.self_attention.qkv.weight",
)
def _combine_qkv_weights(ctx: io.TransformCTX, q, k, v):
    # Combine separate Q, K, V weights into a single QKV tensor
    return torch.cat([q, k, v], dim=0)
This transform combines separate Q, K, and V weight matrices from the source model into a single QKV weight matrix for the target model. The use of * in the keys is crucial:

In source_key, model.layers.* matches all layers in the source model.
In target_key, decoder.layers.* corresponds to all layers in the target model.
The wildcard ensures that this transform is applied to each layer of the model automatically. Without it, you’d need to write separate transforms for each layer manually. This makes the code more concise and easier to maintain, especially for models with many layers.
The transform function itself (_combine_qkv_weights) will be called once for each layer, with q, k, and v containing the weights for that specific layer.

Add these transforms to the convert_state method:

def convert_state(self, source, target):
    mapping = {
        # ... existing mappings ...
    }
    return io.apply_transforms(source, target, mapping=mapping, transforms=[_custom_transform])
By following this structure, you can create a custom model importer that converts models from a source format to your target format, handling state mapping and any necessary transformations.
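Exporting in the other direction, that is, saving a NeMo model in Hugging Face format, follows the same pattern with the @io.model_exporter decorator. As a signature-only sketch, the Mistral implementation registers an exporter alongside the importer (see HFMistralExporter in nemo/collections/llm/gpt/model/mistral.py for the full implementation):

@io.model_exporter(MistralModel, "hf")
class HFMistralExporter(io.ModelConnector[MistralModel, "MistralForCausalLM"]):
    ...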