Add a New AutoModel#
NeMo already ships with an AutoModel for text generation (HFAutoModelForCausalLM). The instructions below walk you through adding support for a new task (e.g., sequence classification) by creating your own AutoModel class and all required plumbing.
Where To Put the Class#
Create the file:

nemo/collections/<domain>/model/your_auto_model.py

Use HFAutoModelForCausalLM as a template.
Create Your AutoModel Class#
from typing import Callable, Optional

import torch
import lightning.pytorch as pl  # or: import pytorch_lightning as pl, depending on your Lightning version
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# io.IOMixin and fn.FNMixin are NeMo mixins; copy the imports that
# HFAutoModelForCausalLM uses for them.


class HFAutoModelForSequenceClassification(
    pl.LightningModule, io.IOMixin, fn.FNMixin
):
    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        loss_fn: Optional[Callable] = None,
        **hf_kwargs,
    ):
        super().__init__()
        self.save_hyperparameters()
        # Configure the HF model and tokenizer
        self.model, self.tokenizer = self.configure_model(
            model_name, tokenizer_name, **hf_kwargs
        )
        # Default loss (if the user did not supply one)
        self.loss_fn = loss_fn or torch.nn.CrossEntropyLoss()
``configure_model`` is described below.
Configure the Underlying HF Model#
def configure_model(
    self, model_name, tokenizer_name=None, **hf_kwargs
):
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        torch_dtype=torch.float16,  # choose dtype
        device_map="auto",          # choose device mapping
        **hf_kwargs,
    )
    return model, tokenizer
Add or modify default arguments that make sense for your task (e.g., num_labels, problem_type).
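For instance, task-specific defaults can be baked in while still letting callers override them through **hf_kwargs. A minimal sketch (the default values below are illustrative, not NeMo defaults):

def configure_model(
    self, model_name, tokenizer_name=None, **hf_kwargs
):
    # Illustrative task defaults; callers can override them via **hf_kwargs.
    hf_kwargs.setdefault("num_labels", 2)
    hf_kwargs.setdefault("problem_type", "single_label_classification")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        **hf_kwargs,
    )
    return model, tokenizer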
Training and Validation Loops#
def forward(self, **inputs):
    return self.model(**inputs)

def training_step(self, batch, batch_idx):
    outputs = self(**batch)
    loss = self.loss_fn(outputs.logits, batch["labels"])
    self.log("train_loss", loss)
    return loss

def validation_step(self, batch, batch_idx):
    outputs = self(**batch)
    loss = self.loss_fn(outputs.logits, batch["labels"])
    self.log("val_loss", loss, prog_bar=True)
    return loss
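The snippets above do not show optimizer setup, which Lightning requires before fit can run. A minimal sketch, assuming a plain AdamW optimizer with a fixed learning rate:

def configure_optimizers(self):
    # Simple default; swap in your preferred optimizer/scheduler setup.
    return torch.optim.AdamW(self.parameters(), lr=2e-5)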
Checkpointing Utilities#
Override the HF-style save/load helpers so that both Lightning and HF can re-instantiate your model correctly.
def save_pretrained(self, save_dir: str, **kwargs):
    self.model.save_pretrained(save_dir, **kwargs)
    self.tokenizer.save_pretrained(save_dir)

@classmethod
def load_pretrained(cls, load_dir: str, **kwargs):
    # Re-instantiate through __init__ so the LightningModule machinery is
    # set up; load_dir is a directory produced by save_pretrained and is
    # accepted anywhere Hugging Face accepts a model name or path.
    return cls(model_name=load_dir, **kwargs)
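A quick round-trip check might look like this (the paths are illustrative):

model.save_pretrained("./checkpoints/seq-cls")
restored = HFAutoModelForSequenceClassification.load_pretrained("./checkpoints/seq-cls")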
Checkpoint I/O Wrapper#
def make_checkpoint_io(self):
    return HFCheckpointIO(
        save_function=self.save_pretrained,
        load_function=self.load_pretrained,
    )
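Assuming HFCheckpointIO implements Lightning's CheckpointIO plugin interface, the wrapper can then be handed to the Trainer; a sketch:

trainer = pl.Trainer(max_steps=100, plugins=[model.make_checkpoint_io()])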
Create a LightningDataModule#
Duplicate HFAutoModelForCausalLMDataModule (or the PEFT demo) and adjust:

- Dataset loading (e.g., Hugging Face datasets.load_dataset).
- The .map preprocessing: ensure it yields the input_ids, attention_mask, and labels your new model expects (see the sketch below).
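A minimal sketch of such a data module, assuming an illustrative class name, the GLUE/SST-2 dataset, and a "sentence" text column (adapt all three to your task):

import lightning.pytorch as pl
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorWithPadding


class SequenceClassificationDataModule(pl.LightningDataModule):
    def __init__(self, model_name: str, batch_size: int = 16):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.batch_size = batch_size

    def setup(self, stage=None):
        # Illustrative dataset and columns; replace with your own task data.
        ds = load_dataset("glue", "sst2")
        ds = ds.map(
            lambda ex: self.tokenizer(ex["sentence"], truncation=True),
            batched=True,
        )
        ds = ds.rename_column("label", "labels")
        ds.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
        self.train_ds, self.val_ds = ds["train"], ds["validation"]
        # Pads each batch to a common length at collate time.
        self.collate_fn = DataCollatorWithPadding(self.tokenizer)

    def train_dataloader(self):
        return DataLoader(
            self.train_ds, batch_size=self.batch_size,
            shuffle=True, collate_fn=self.collate_fn,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_ds, batch_size=self.batch_size, collate_fn=self.collate_fn,
        )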
Quick Checklist#
[ ] New file placed under nemo/collections/<domain>/model
[ ] Class inherits pl.LightningModule, io.IOMixin, fn.FNMixin
[ ] configure_model builds the correct HF model and tokenizer
[ ] training_step / validation_step compute the task-specific loss
[ ] save_pretrained / load_pretrained work
[ ] make_checkpoint_io returns HFCheckpointIO
[ ] A matching LightningDataModule exists
Once these boxes are ticked, your task-specific AutoModel is ready for use exactly like the existing HFAutoModelForCausalLM:

from nemo.collections.<domain>.model import HFAutoModelForSequenceClassification

model = HFAutoModelForSequenceClassification(
    model_name="bert-base-uncased", num_labels=2
)
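Training end to end then follows the usual Lightning pattern; a sketch, reusing the illustrative SequenceClassificationDataModule from above:

import lightning.pytorch as pl

data = SequenceClassificationDataModule(model_name="bert-base-uncased")
trainer = pl.Trainer(max_epochs=1, plugins=[model.make_checkpoint_io()])
trainer.fit(model, datamodule=data)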