Add a New AutoModel

NeMo already ships with an AutoModel for text generation (HFAutoModelForCausalLM). The instructions below walk you through adding support for a new task (e.g., sequence classification) by creating your own AutoModel class and all of the required plumbing.

Where To Put the Class

Create the file:

nemo/collections/<domain>/model/your_auto_model.py

Use HFAutoModelForCausalLM as a template.

Create Your AutoModel Class

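from typing import Callable, Optional

import lightning.pytorch as pl
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from nemo.collections.llm import fn  # NeMo mixins (NeMo 2.0 module paths)
from nemo.lightning import io


class HFAutoModelForSequenceClassification(
    pl.LightningModule, io.IOMixin, fn.FNMixin
):
    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        loss_fn: Optional[Callable] = None,
        **hf_kwargs,
    ):
        super().__init__()
        # Store constructor args; loss_fn may be an nn.Module, which Lightning
        # already checkpoints as part of the module state, so skip it here
        self.save_hyperparameters(ignore=["loss_fn"])

        # Load the HF model and tokenizer
        self.model, self.tokenizer = self.configure_model(
            model_name, tokenizer_name, **hf_kwargs
        )

        # Default loss (if the user did not supply one)
        self.loss_fn = loss_fn or torch.nn.CrossEntropyLoss()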

``configure_model`` is described below.

Configure the Underlying HF Model

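def configure_model(
    self, model_name=None, tokenizer_name=None, **hf_kwargs
):
    # Lightning also calls configure_model() with no arguments before
    # fit/validate/test; the model was already built in __init__, so reuse it.
    if model_name is None:
        return self.model, self.tokenizer

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        torch_dtype=torch.float16,  # choose a dtype for your hardware
        # no device_map: the Lightning Trainer moves the module to its device
        **hf_kwargs,                # e.g. num_labels, problem_type
    )
    return model, tokenizer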

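Add or adjust the defaults that make sense for your task (e.g., num_labels, problem_type). Any extra keyword arguments passed to the constructor are forwarded to from_pretrained through hf_kwargs.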
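One way to keep task defaults in a single place is to merge them with whatever the caller passes before calling from_pretrained. The sketch below assumes a hypothetical TASK_DEFAULTS dict and merge_hf_kwargs helper (neither is part of NeMo); the problem_type values it accepts are the three that Hugging Face sequence-classification configs support.

```python
from typing import Any, Dict, Optional

# Problem types accepted by Hugging Face sequence-classification configs.
VALID_PROBLEM_TYPES = {
    "regression",
    "single_label_classification",
    "multi_label_classification",
}

# Hypothetical task defaults; adjust them for your dataset.
TASK_DEFAULTS: Dict[str, Any] = {
    "num_labels": 2,
    "problem_type": "single_label_classification",
}


def merge_hf_kwargs(
    user_kwargs: Dict[str, Any], defaults: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Return a new dict in which user-supplied values override the defaults."""
    merged = dict(TASK_DEFAULTS if defaults is None else defaults)
    merged.update(user_kwargs)
    problem_type = merged.get("problem_type")
    # None is allowed: HF then infers the problem type from num_labels and label dtype
    if problem_type is not None and problem_type not in VALID_PROBLEM_TYPES:
        raise ValueError(f"unsupported problem_type: {problem_type!r}")
    return merged
```

Inside __init__ you could then pass **merge_hf_kwargs(hf_kwargs) to configure_model instead of **hf_kwargs, so that a user-supplied num_labels=3 overrides the default of 2 without the defaults dict being mutated.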

Training and Validation Loops

def forward(self, **inputs):
    return self.model(**inputs)

def training_step(self, batch, batch_idx):
    outputs = self(**batch)
    loss = self.loss_fn(outputs.logits, batch["labels"])
    self.log("train_loss", loss)
    return loss

def validation_step(self, batch, batch_idx):
    outputs = self(**batch)
    loss = self.loss_fn(outputs.logits, batch["labels"])
    self.log("val_loss", loss, prog_bar=True)
    return loss
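Both steps hand outputs.logits, shaped (batch_size, num_labels), and integer class labels to torch.nn.CrossEntropyLoss. To make the logged train_loss and val_loss concrete, here is a dependency-free sketch of the same quantity on plain Python lists. It assumes the loss's defaults (mean reduction, no class weights, no label smoothing) and that no label equals ignore_index; it is meant for intuition, not as a replacement.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Mean cross-entropy of integer class labels given unnormalized logits.

    Mirrors torch.nn.CrossEntropyLoss with its defaults (reduction="mean",
    no class weights, no label smoothing), assuming no label equals ignore_index.
    """
    total = 0.0
    for row, label in zip(logits, labels):
        m = max(row)  # subtract the row max so exp() cannot overflow
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[label]  # -log softmax(row)[label]
    return total / len(labels)
```

For a single example with logits [2.0, 0.0] and label 0, the loss is log(1 + e^-2), about 0.127; uniform logits over two classes give log 2, about 0.693.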

Checkpointing Utilities

Add Hugging Face-style save/load helpers so that both Lightning and HF tooling can re-instantiate your model correctly.

def save_pretrained(self, save_dir: str, **kwargs):
    self.model.save_pretrained(save_dir, **kwargs)
    self.tokenizer.save_pretrained(save_dir)

@classmethod
def load_pretrained(cls, load_dir: str, **kwargs):
    # Build the instance through __init__: assigning submodules to a bare
    # cls.__new__(cls) object fails because nn.Module.__init__ never ran,
    # and it would also leave loss_fn unset. kwargs reach from_pretrained.
    return cls(model_name=load_dir, **kwargs)
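A save-then-reload test is a quick way to confirm the "save_pretrained / load_pretrained work" checklist item. The sketch uses a hypothetical ToyAutoModel whose "weights" are a JSON-serialized dict, so it runs without downloading a checkpoint; with the real class you would instead compare the reloaded model's config and its output on the same batch.

```python
import json
import os
import tempfile


class ToyAutoModel:
    """Hypothetical stand-in for an AutoModel: its 'weights' are a small dict
    saved as JSON, so the save/load round trip runs without any downloads."""

    def __init__(self, model_name: str, num_labels: int = 2):
        self.model_name = model_name
        self.num_labels = num_labels
        self.weights = {"bias": [0.0] * num_labels}

    def save_pretrained(self, save_dir: str) -> None:
        os.makedirs(save_dir, exist_ok=True)
        with open(os.path.join(save_dir, "toy_model.json"), "w") as f:
            json.dump({"num_labels": self.num_labels, "weights": self.weights}, f)

    @classmethod
    def load_pretrained(cls, load_dir: str) -> "ToyAutoModel":
        with open(os.path.join(load_dir, "toy_model.json")) as f:
            state = json.load(f)
        # Construct through __init__ so every attribute exists, then restore state
        inst = cls(model_name=load_dir, num_labels=state["num_labels"])
        inst.weights = state["weights"]
        return inst


def roundtrip(model: ToyAutoModel) -> ToyAutoModel:
    """Save to a temporary directory and load the result back."""
    with tempfile.TemporaryDirectory() as tmp:
        model.save_pretrained(tmp)
        return type(model).load_pretrained(tmp)
```

The loader deliberately goes through __init__ before restoring saved state, so the reloaded object has every attribute a freshly constructed one would have.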

Checkpoint I/O Wrapper

def make_checkpoint_io(self):
    return HFCheckpointIO(
        save_function=self.save_pretrained,
        load_function=self.load_pretrained,
    )
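Conceptually, the wrapper lets checkpointing delegate file I/O to the model's own save and load functions. The sketch below illustrates only that delegation idea with a hypothetical DelegatingCheckpointIO class. Its method names follow Lightning's CheckpointIO interface (which also requires remove_checkpoint), but it is not NeMo's HFCheckpointIO, whose actual handling of optimizer and trainer state may differ.

```python
from typing import Any, Callable, Dict, Optional


class DelegatingCheckpointIO:
    """Hypothetical checkpoint-I/O object that forwards to model-supplied
    callables. Method names follow Lightning's CheckpointIO interface, but
    this is not NeMo's HFCheckpointIO."""

    def __init__(
        self,
        save_function: Callable[[str], None],
        load_function: Callable[[str], Any],
    ):
        self.save_function = save_function
        self.load_function = load_function

    def save_checkpoint(
        self, checkpoint: Dict[str, Any], path: str, storage_options: Optional[Any] = None
    ) -> None:
        # Only the model writes files here; trainer state in `checkpoint`
        # (optimizer, epoch, ...) is ignored in this simplified sketch.
        self.save_function(path)

    def load_checkpoint(self, path: str, map_location: Optional[Any] = None) -> Any:
        # Whatever the model's load function returns is handed back unchanged
        return self.load_function(path)
```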

Create a LightningDataModule

Duplicate HFAutoModelForCausalLMDataModule (or the PEFT demo) and adjust:

  • Dataset loading (e.g., Hugging Face datasets.load_dataset).

  • The map preprocessing step: ensure it yields the input_ids, attention_mask, and labels fields that your new model expects.
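The preprocessing contract can be sketched as a batched map function. Everything here is illustrative: the "text" and "label" column names, the toy whitespace VOCAB, and max_length are assumptions, and a real pipeline would call the model's tokenizer instead.

```python
from typing import Dict, List

# Hypothetical toy vocabulary; a real pipeline would call the HF tokenizer.
VOCAB = {"[PAD]": 0, "[UNK]": 1, "good": 2, "bad": 3, "movie": 4}


def preprocess(batch: Dict[str, list], max_length: int = 8) -> Dict[str, list]:
    """Batched map function: converts {"text": [...], "label": [...]} columns
    into the input_ids / attention_mask / labels fields a classifier expects."""
    input_ids: List[List[int]] = []
    attention_mask: List[List[int]] = []
    for text in batch["text"]:
        # Tokenize by whitespace, map unknown words to [UNK], truncate
        ids = [VOCAB.get(tok, VOCAB["[UNK]"]) for tok in text.lower().split()][:max_length]
        pad = max_length - len(ids)
        input_ids.append(ids + [VOCAB["[PAD]"]] * pad)  # right-pad to max_length
        attention_mask.append([1] * len(ids) + [0] * pad)  # 1 = real token
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": list(batch["label"]),  # integer class ids for CrossEntropyLoss
    }
```

With Hugging Face datasets you would apply such a function via dataset.map(preprocess, batched=True). In practice, tokenizer(batch["text"], truncation=True, padding="max_length", max_length=...) already returns input_ids and attention_mask (plus token_type_ids for BERT-style tokenizers), so only labels needs to be added.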

Quick Checklist

  • [ ] New file placed under nemo/collections/<domain>/model

  • [ ] Class inherits pl.LightningModule, io.IOMixin, fn.FNMixin

  • [ ] configure_model builds the correct HF model and tokenizer

  • [ ] training_step / validation_step compute the task-specific loss

  • [ ] save_pretrained / load_pretrained round-trip a saved model

  • [ ] make_checkpoint_io returns HFCheckpointIO

  • [ ] Matching LightningDataModule exists

Once every box is ticked, your task-specific AutoModel can be used the same way as the existing HFAutoModelForCausalLM. For the import below to resolve, export the class from the package's __init__.py:

from nemo.collections.<domain>.model import HFAutoModelForSequenceClassification

model = HFAutoModelForSequenceClassification(
    model_name="bert-base-uncased", num_labels=2
)