Add a New AutoModel#
NeMo already ships with an AutoModel for text generation (HFAutoModelForCausalLM). The instructions below walk you through adding support for a new task (e.g., sequence classification) by creating your own AutoModel class and the required plumbing.
Where To Put the Class#
Create the file:
nemo/collections/<domain>/model/your_auto_model.py
Use HFAutoModelForCausalLM as a template.
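So that the import shown at the end of this guide works, the new class also needs to be exported from the collection's model package. A minimal sketch, assuming the file and class names used in this guide (replace ``<domain>`` with your collection):

# nemo/collections/<domain>/model/__init__.py  (path mirrors the file created above)
from nemo.collections.<domain>.model.your_auto_model import HFAutoModelForSequenceClassification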
Create Your AutoModel Class#
# Imports assumed by this guide; the exact paths for the ``io`` and ``fn`` mixins
# may differ between NeMo versions (shown here as used by the llm collection).
from typing import Callable, Optional

import lightning.pytorch as pl  # or ``import pytorch_lightning as pl`` on older stacks
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from nemo.collections.llm import fn
from nemo.lightning import io


class HFAutoModelForSequenceClassification(
    pl.LightningModule, io.IOMixin, fn.FNMixin
):
    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        loss_fn: Optional[Callable] = None,
        **hf_kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        # Configure the HF model and tokenizer
        self.model, self.tokenizer = self.configure_model(
            model_name, tokenizer_name, **hf_kwargs
        )
        # Default loss (if the user did not supply one)
        self.loss_fn = loss_fn or torch.nn.CrossEntropyLoss()
``configure_model`` is described below.
Configure the Underlying HF Model#
    def configure_model(
        self, model_name, tokenizer_name=None, **hf_kwargs
    ):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name)
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            torch_dtype=torch.float16,  # choose dtype
            device_map="auto",          # choose device mapping
            **hf_kwargs,
        )
        return model, tokenizer
Add or modify default arguments that make sense for your task (e.g., num_labels, problem_type).
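For instance, a sequence-classification variant might set task defaults before calling ``from_pretrained`` while still letting callers override them through ``**hf_kwargs``. A small sketch of that fragment (the default values are only illustrative):

        # Inside configure_model, before calling from_pretrained:
        hf_kwargs.setdefault("num_labels", 2)  # e.g., binary classification by default
        hf_kwargs.setdefault("problem_type", "single_label_classification")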
Training and Validation Loops#
    def forward(self, **inputs):
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        loss = self.loss_fn(outputs.logits, batch["labels"])
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        outputs = self(**batch)
        loss = self.loss_fn(outputs.logits, batch["labels"])
        self.log("val_loss", loss, prog_bar=True)
        return loss
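If you do not wire an optimizer in through NeMo's optimizer plumbing, Lightning still needs ``configure_optimizers`` for these loops to run. A minimal sketch, assuming plain AdamW and an illustrative learning rate:

    def configure_optimizers(self):
        # Replace with your preferred optimizer/scheduler setup.
        return torch.optim.AdamW(self.parameters(), lr=2e-5)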
Checkpointing Utilities#
Override the HF-style save/load helpers so that both Lightning and HF can re-instantiate your model correctly.
    def save_pretrained(self, save_dir: str, **kwargs):
        self.model.save_pretrained(save_dir, **kwargs)
        self.tokenizer.save_pretrained(save_dir)

    @classmethod
    def load_pretrained(cls, load_dir: str, **kwargs):
        # __new__ skips __init__, so initialize the Module machinery before
        # assigning submodules, and restore any extra attributes (e.g., loss_fn)
        # if your training code needs them.
        inst = cls.__new__(cls)
        pl.LightningModule.__init__(inst)
        inst.model = AutoModelForSequenceClassification.from_pretrained(
            load_dir, **kwargs
        )
        inst.tokenizer = AutoTokenizer.from_pretrained(load_dir)
        return inst
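A quick round trip, assuming ``model`` is an instance created as in the example at the end of this guide (the path is illustrative):

model.save_pretrained("/tmp/seq-cls-ckpt")
restored = HFAutoModelForSequenceClassification.load_pretrained("/tmp/seq-cls-ckpt")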
Checkpoint I/O Wrapper#
    def make_checkpoint_io(self):
        return HFCheckpointIO(
            save_function=self.save_pretrained,
            load_function=self.load_pretrained,
        )
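Lightning picks up a checkpoint-I/O object through the trainer's ``plugins`` argument; assuming ``HFCheckpointIO`` implements Lightning's ``CheckpointIO`` interface, wiring it up looks roughly like this:

trainer = pl.Trainer(
    max_epochs=1,
    plugins=[model.make_checkpoint_io()],  # assumes HFCheckpointIO is a CheckpointIO plugin
)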
Create a LightningDataModule#
Duplicate HFAutoModelForCausalLMDataModule (or the PEFT demo) and adjust:
- Dataset loading (e.g., Hugging Face ``datasets.load_dataset``).
- ``map`` preprocessing – ensure it yields the ``input_ids``, ``attention_mask``, and ``labels`` expected by your new model (see the sketch below).
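A minimal sketch of such a data module, assuming the GLUE SST-2 dataset and a tokenizer that matches the model; all names here are illustrative:

import lightning.pytorch as pl
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorWithPadding


class SequenceClassificationDataModule(pl.LightningDataModule):
    def __init__(self, tokenizer_name="bert-base-uncased", batch_size=16):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.collator = DataCollatorWithPadding(self.tokenizer)
        self.batch_size = batch_size

    def setup(self, stage=None):
        ds = load_dataset("glue", "sst2")
        # Tokenize and keep only the columns the model expects
        ds = ds.map(
            lambda ex: self.tokenizer(ex["sentence"], truncation=True),
            batched=True,
        )
        ds = ds.rename_column("label", "labels")
        ds = ds.remove_columns(["sentence", "idx"])
        self.train_ds, self.val_ds = ds["train"], ds["validation"]

    def train_dataloader(self):
        return DataLoader(
            self.train_ds, batch_size=self.batch_size,
            collate_fn=self.collator, shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_ds, batch_size=self.batch_size, collate_fn=self.collator,
        )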
Quick Checklist#
- [ ] New file placed under ``nemo/collections/<domain>/model``
- [ ] Class inherits ``pl.LightningModule, io.IOMixin, fn.FNMixin``
- [ ] ``configure_model`` builds the correct HF model & tokenizer
- [ ] ``training_step`` / ``validation_step`` compute the task-specific loss
- [ ] ``save_pretrained`` / ``load_pretrained`` work
- [ ] ``make_checkpoint_io`` returns ``HFCheckpointIO``
- [ ] A matching ``LightningDataModule`` exists
Once these boxes are ticked, your task-specific AutoModel is ready for use exactly like the existing HFAutoModelForCausalLM:
from nemo.collections.<domain>.model import HFAutoModelForSequenceClassification

model = HFAutoModelForSequenceClassification(
    model_name="bert-base-uncased", num_labels=2
)
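Training then follows the usual Lightning pattern; a minimal sketch using the illustrative data module from above:

datamodule = SequenceClassificationDataModule(tokenizer_name="bert-base-uncased")
trainer = pl.Trainer(max_epochs=1, devices=1)
trainer.fit(model, datamodule=datamodule)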