Adding Custom Models
Learn how to integrate custom models into NeMo Curator stages.
The NeMo Curator container includes a set of default models, but you can add your own for specialized tasks.
Before You Start
Before you begin adding a custom model, make sure that you have:
- Reviewed the pipeline concepts and diagrams.
- Set up a working NeMo Curator development environment.
- (Optional) Prepared a container image that includes your model dependencies.
- (Optional) Created a custom environment to support your new model.
How to Add a Custom Model
Review the Model Interface
In NeMo Curator, models inherit from nemo_curator.models.base.ModelInterface and must implement two members: the model_id_names property and the setup method.
Create a New Model
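The shape of that contract can be sketched with Python's abc module. The class below is a local stand-in that mirrors the two members described above so the example runs without NeMo Curator installed; in a real project, import ModelInterface from nemo_curator.models.base instead. IncompleteModel and its weight ID are hypothetical and exist only to show that a subclass missing setup cannot be instantiated.

```python
from abc import ABC, abstractmethod


# Local stand-in mirroring the interface described above, so this sketch runs
# without NeMo Curator. In a real project, import the actual base class:
#   from nemo_curator.models.base import ModelInterface
class ModelInterface(ABC):
    @property
    @abstractmethod
    def model_id_names(self) -> list[str]:
        """Return the IDs of the weights this model needs."""

    @abstractmethod
    def setup(self) -> None:
        """Load weights and build the underlying network."""


class IncompleteModel(ModelInterface):
    # Hypothetical model that implements model_id_names but omits setup().
    @property
    def model_id_names(self) -> list[str]:
        return ["example-org/example-model"]


try:
    IncompleteModel()
except TypeError as err:
    # ABC refuses to instantiate a class with unimplemented abstract members.
    print(f"Cannot instantiate: {err}")
```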
For this tutorial, we’ll sketch a minimal model for demonstration.
The following sections walk through each part of the model in turn.
Define the PyTorch Model
Provide a model ID (for example, a Hugging Face model ID) if you plan to cache or fetch weights. If your model class provides a class method for downloading weights, the pipeline can use it to download them before setup() runs.
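A minimal sketch of the core network and its weight ID follows. MyCore and MY_MODEL_ID are hypothetical names. To keep the sketch runnable without PyTorch, MyCore is a plain Python class with a toy "forward pass"; in a real model it would subclass torch.nn.Module and load the weights identified by the model ID.

```python
# Hypothetical weight ID; replace it with your real repository name
# (for example, a Hugging Face model ID).
MY_MODEL_ID = "example-org/my-model"


class MyCore:
    """Stand-in for the core network that performs inference.

    In practice this would subclass torch.nn.Module and implement forward();
    a plain class is used here so the sketch runs without PyTorch installed.
    """

    def __init__(self, scale: float = 1.0) -> None:
        self.scale = scale

    def __call__(self, values: list[float]) -> list[float]:
        # Toy "inference": scale every input value.
        return [v * self.scale for v in values]
```

Keeping the network in its own class separates the inference logic from the bookkeeping (weight IDs, setup) that the model interface handles.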
Implement the Model Interface
Your model class implements ModelInterface: it declares its weight identifiers and initializes the underlying core network.
The setup method instantiates MyCore, the underlying class that performs model inference.
The model_id_names property returns a list of weight IDs. These typically correspond to model repository names but do not have to.
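Putting the two members together, here is a sketch of a model that implements the interface. ModelInterface is again a local stand-in for nemo_curator.models.base.ModelInterface, and MyCore, MyModel, the weight ID, and the model_dir/weight-ID path layout are hypothetical assumptions for illustration. The sketch defers building MyCore to setup(), so constructing the model stays cheap.

```python
from abc import ABC, abstractmethod
from typing import Optional


class ModelInterface(ABC):
    # Local stand-in for nemo_curator.models.base.ModelInterface.
    @property
    @abstractmethod
    def model_id_names(self) -> list[str]: ...

    @abstractmethod
    def setup(self) -> None: ...


class MyCore:
    # Stand-in for the core network (a torch.nn.Module in practice).
    def __init__(self, weights_path: str) -> None:
        self.weights_path = weights_path

    def __call__(self, values: list[float]) -> list[float]:
        return [v * 2 for v in values]


class MyModel(ModelInterface):
    def __init__(self, model_dir: str) -> None:
        self.model_dir = model_dir
        self._core: Optional[MyCore] = None  # built in setup(), not here

    @property
    def model_id_names(self) -> list[str]:
        # Hypothetical ID; usually a repository name, but any string works.
        return ["example-org/my-model"]

    def setup(self) -> None:
        # Assumed layout: weights live at <model_dir>/<weight ID>.
        weights_path = f"{self.model_dir}/{self.model_id_names[0]}"
        self._core = MyCore(weights_path)

    def __call__(self, values: list[float]) -> list[float]:
        if self._core is None:
            raise RuntimeError("call setup() before running inference")
        return self._core(values)
```

For example, `MyModel("/models")` followed by `setup()` builds a MyCore pointed at `/models/example-org/my-model`; calling the model before setup() raises a RuntimeError.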
If your stage requires a specific environment, configure it through the stage's container image and its Resources (for example, gpu_memory_gb, entire_gpu, or gpus) rather than on the model. GPU allocation in particular is managed at the stage level, not by the model.
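To illustrate where that configuration lives, the sketch below uses a stand-in Resources dataclass limited to the three fields named above, plus a hypothetical MyModelStage class. Neither is NeMo Curator's actual API. In a real pipeline, use the library's own Resources class and stage base class, whose full set of fields and defaults may differ.

```python
from dataclasses import dataclass


@dataclass
class Resources:
    # Local stand-in limited to the fields named above; use NeMo Curator's
    # own Resources class in a real pipeline.
    gpu_memory_gb: float = 0.0
    entire_gpu: bool = False
    gpus: float = 0.0


class MyModelStage:
    """Hypothetical stage: it, not the model, declares the hardware it needs."""

    # Request a slice of GPU memory for this stage (example value only).
    # The model class itself carries no GPU settings.
    resources = Resources(gpu_memory_gb=16.0)

    def __init__(self, model_dir: str) -> None:
        self.model_dir = model_dir
```

Because the request lives on the stage, the same model class can be reused in stages with different GPU budgets.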
Manage Model Weights
Provide your model with a model_dir where weights are stored. Your stage should ensure that any required weights are available at runtime (for example, by mounting them into the container or downloading them prior to execution).
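Here is a minimal sketch of a weight check that a stage might run before calling setup(). ensure_weights and its download callback are hypothetical, not part of NeMo Curator. The sketch also assumes that each weight ID maps to a subdirectory of model_dir.

```python
from pathlib import Path
from typing import Callable, Optional


def ensure_weights(
    model_dir: str,
    model_ids: list[str],
    download: Optional[Callable[[Path, str], None]] = None,
) -> list[Path]:
    """Return the local path for each weight ID, fetching any that are missing.

    Hypothetical helper: `download` stands in for whatever your stage uses to
    fetch weights. If it is omitted, the weights must already be present,
    for example mounted into the container.
    """
    paths = []
    for model_id in model_ids:
        path = Path(model_dir) / model_id
        if not path.exists() and download is not None:
            download(path, model_id)
        if not path.exists():
            raise FileNotFoundError(f"weights for {model_id!r} not found at {path}")
        paths.append(path)
    return paths
```

Passing no download callback models the mounted-weights case. The helper then fails fast with FileNotFoundError instead of letting setup() fail later inside the model.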
Next Steps
Now that you have created a custom model, you can create a custom stage that uses your code.