Fine-Tuning NemotronOmni on CORD-v2 Receipts β€” End-to-End Guide#

A step-by-step guide for fine-tuning NemotronOmni (33B MoE) to extract structured receipt data from scanned images using NeMo Automodel. Covers both full SFT and LoRA PEFT.


What is NemotronOmni?#

NemotronOmni (NemotronH_Nano_Omni_Reasoning_V3) is a ~33B multimodal MoE model supporting image, video, and audio inputs.

Key architectural details:

  • LLM backbone: NemotronV3 hybrid Mamba2 + Attention + MoE, 52 layers, hidden dim 2688

  • Vision encoder: RADIO v2.5-H (ViT-Huge), 256 vision tokens per tile

  • Audio encoder: Parakeet FastConformer (1024-dim)

  • MoE: 128 experts per MoE layer, top-6 routing with sigmoid gating

  • Total parameters: 33B (31.5B trainable with frozen vision/audio towers)

Fine-Tune for Receipt Field Extraction#

We fine-tune NemotronOmni on the CORD-v2 (Consolidated Receipt Dataset) to extract structured fields from scanned receipts:

Field

Example

menu

Item names, quantities, prices

sub_total

Subtotal, tax, discount

total

Total price, cash paid, change

The base model produces free-form descriptions. After fine-tuning, it outputs structured XML-like token sequences matching the receipt fields.

Guide Overview#

Step

Description

Step 0

Environment setup

Step 1

Explore the CORD-v2 dataset

Step 2

Training configuration (SFT and LoRA)

Step 3

Launch fine-tuning

Step 4

Run inference on the fine-tuned model

Step 5

Compare SFT vs LoRA results

Hardware Requirements#

  • 8x H100 80 GB GPUs required (MoE with EP=8)

  • SFT memory: ~49 GiB per GPU

  • LoRA memory: ~30 GiB per GPU

  • Estimated training time: ~10 min on 8x H100 (400 steps, 800 training samples)


Step 0 β€” Set Up the Environment#

# Inside the NeMo AutoModel container (26.04+):
cd /opt/Automodel

# Or from a source checkout:
git clone -b nemotron-omni ssh://git@gitlab-master.nvidia.com:12051/huiyingl/automodel-omni.git
cd automodel-omni

Note

NemotronOmni requires mamba_ssm, causal_conv1d, and decord packages, which are included in the NeMo AutoModel container.


Step 1 β€” Explore the CORD-v2 Dataset#

CORD-v2 contains scanned receipts with structured ground-truth JSON labels.

import json
from datasets import load_dataset

dataset = load_dataset("naver-clova-ix/cord-v2")

print(f"Train      : {len(dataset['train'])} samples")
print(f"Validation : {len(dataset['validation'])} samples")
print(f"Test       : {len(dataset['test'])} samples")

# Inspect a sample
ex = dataset["train"][0]
gt = json.loads(ex["ground_truth"])["gt_parse"]
print(f"\nGround-truth keys: {list(gt.keys())}")

Expected output:

Train      : 800 samples
Validation : 100 samples
Test       : 100 samples

Ground-truth keys: ['menu', 'sub_total', 'total', 'void_menu']

Target Format: JSON-to-Token Conversion#

NeMo Automodel converts structured JSON into an XML-like token sequence using the json2token() function. This is the format the model is trained to produce:

<s_total><s_total_price>45,500</s_total_price><s_changeprice>4,500</s_changeprice>
<s_cashprice>50,000</s_cashprice></s_total><s_menu><s_price>16,500</s_price>
<s_nm>REAL GANACHE</s_nm><s_cnt>1</s_cnt><sep/><s_price>13,000</s_price>
<s_nm>EGG TART</s_nm><s_cnt>1</s_cnt></s_menu>

Step 2 β€” Training Configuration#

Full SFT Config#

Config file: examples/vlm_finetune/nemotron_omni/nemotron_omni_cord_v2.yaml

recipe: FinetuneRecipeForVLM

step_scheduler:
  global_batch_size: 8
  local_batch_size: 1
  ckpt_every_steps: 100
  val_every_steps: 200
  max_steps: 400

model:
  _target_: nemo_automodel.NeMoAutoModelForImageTextToText.from_pretrained
  pretrained_model_name_or_path: <path_to_nemotron_omni_v2.0>
  trust_remote_code: true
  torch_dtype: torch.bfloat16
  backend:
    _target_: nemo_automodel.components.models.common.BackendConfig
    attn: sdpa
    linear: torch
    rms_norm: torch_fp32
    rope_fusion: false
    enable_deepep: false
    fake_balanced_gate: false
    enable_hf_state_dict_adapter: true

distributed:
  strategy: fsdp2
  ep_size: 8            # 128 MoE experts across 8 GPUs

freeze_config:
  freeze_embeddings: true
  freeze_vision_tower: true
  freeze_audio_tower: true
  freeze_language_model: false

dataset:
  _target_: nemo_automodel.components.datasets.vlm.datasets.make_cord_v2_dataset
  path_or_dataset: naver-clova-ix/cord-v2
  split: train

dataloader:
  collate_fn:
    _target_: nemo_automodel.components.datasets.vlm.collate_fns.nemotron_omni_collate_fn
    max_length: 4096

optimizer:
  _target_: torch.optim.AdamW
  lr: 1e-4
  weight_decay: 0.01
  betas: [0.9, 0.95]

LoRA PEFT Config#

Config file: examples/vlm_finetune/nemotron_omni/nemotron_omni_cord_v2_peft.yaml

Adds a peft: block to apply LoRA to language model linear layers only:

peft:
  _target_: nemo_automodel.components._peft.lora.PeftConfig
  match_all_linear: false
  exclude_modules:
    - "*vision_tower*"
    - "*vision_model*"
    - "*audio*"
    - "*sound*"
    - "*lm_head*"
    - "*mlp1*"
  dim: 64
  alpha: 128
  use_triton: true

optimizer:
  _target_: torch.optim.AdamW
  lr: 1e-3

Collate function#

NemotronOmni uses InternVL-style image handling where each <image> token in the input is replaced by 256 vision embeddings during the model’s forward pass. The collate function:

  1. Extracts images from the conversation

  2. Applies the chat template (which adds <think></think> prefix for the assistant turn)

  3. Processes images through the NemotronOmni processor

  4. Builds image_flags tensors and creates training labels


Step 3 β€” Launch Fine-Tuning#

Full SFT#

torchrun --nproc-per-node=8 \
    examples/vlm_finetune/finetune.py \
    -c examples/vlm_finetune/nemotron_omni/nemotron_omni_cord_v2.yaml

LoRA PEFT#

torchrun --nproc-per-node=8 \
    examples/vlm_finetune/finetune.py \
    -c examples/vlm_finetune/nemotron_omni/nemotron_omni_cord_v2_peft.yaml

Training log β€” Full SFT#

Trainable parameters: 31,570,023,872
Trainable parameters percentage: 95.63%

step    0 | loss 0.6866 | grad_norm  7.57 | lr 1.00e-04 | mem 37.29 GiB | tps/gpu   33
step   10 | loss 0.0705 | grad_norm  1.00 | lr 1.00e-04 | mem 48.95 GiB | tps/gpu 2419
step   50 | loss 0.0173 | grad_norm  0.43 | lr 1.00e-04 | mem 48.72 GiB | tps/gpu 2615
step  100 | loss 0.0115 | grad_norm  0.37 | lr 1.00e-04 | mem 48.84 GiB | tps/gpu 2642
step  200 | loss 0.0099 | grad_norm  0.20 | lr 1.00e-04 | mem 48.76 GiB | tps/gpu 2616
step  300 | loss 0.0056 | grad_norm  0.15 | lr 1.00e-04 | mem 48.72 GiB | tps/gpu 2087
step  399 | loss 0.0039 | grad_norm  0.17 | lr 1.00e-04 | mem 48.79 GiB | tps/gpu 2616

Validation:
  step  99 | val_loss 0.0363
  step 199 | val_loss 0.0342  <-- LOWEST_VAL
  step 299 | val_loss 0.0414
  step 399 | val_loss 0.0425

Training log β€” LoRA PEFT#

Trainable parameters: 55,422,976
Trainable parameters percentage: 0.17%

step    0 | loss 0.6866 | grad_norm  1.92 | lr 1.00e-03 | mem 30.26 GiB | tps/gpu   34
step   10 | loss 0.0557 | grad_norm  0.30 | lr 1.00e-03 | mem 30.16 GiB | tps/gpu 2455
step   50 | loss 0.0392 | grad_norm  0.32 | lr 1.00e-03 | mem 30.16 GiB | tps/gpu 3352
step  100 | loss 0.0309 | grad_norm  0.27 | lr 1.00e-03 | mem 30.20 GiB | tps/gpu 2456
step  200 | loss 0.0280 | grad_norm  0.23 | lr 1.00e-03 | mem 30.34 GiB | tps/gpu 2477
step  300 | loss 0.0326 | grad_norm  0.31 | lr 1.00e-03 | mem 30.52 GiB | tps/gpu 2737
step  399 | loss 0.0171 | grad_norm  0.24 | lr 1.00e-03 | mem 30.33 GiB | tps/gpu 3258

Validation:
  step  99 | val_loss 0.0449  <-- LOWEST_VAL
  step 199 | val_loss 0.0524
  step 299 | val_loss 0.0482
  step 399 | val_loss 0.0566

Checkpoints saved#

checkpoint_dir/
  epoch_0_step_99/
  epoch_1_step_199/
  epoch_2_step_299/
  epoch_3_step_399/
    model/
      consolidated/          <-- HF-compatible checkpoint for inference
        config.json
        model.safetensors.index.json
        model-00001-of-00017.safetensors
        ...
    optim/
    rng/
    dataloader/
  LATEST -> epoch_3_step_399
  LOWEST_VAL -> epoch_1_step_199
  training.jsonl
  validation.jsonl

For LoRA, the checkpoint saves adapter weights instead:

  model/
    adapter_model.safetensors   (~27 MB)
    adapter_config.json

Tip: LOWEST_VAL symlink points to the checkpoint with the best validation loss.


Step 4 β€” Run Inference on the Fine-Tuned Model#

Load the consolidated checkpoint and run inference on a handful of validation samples to spot-check structured output.

import torch
import json
from transformers import AutoConfig, AutoModel, AutoProcessor
from datasets import load_dataset
from nemo_automodel.components.datasets.vlm.utils import json2token

CKPT = "<checkpoint_dir>/LOWEST_VAL/model/consolidated"

# Load processor
processor = AutoProcessor.from_pretrained(CKPT, trust_remote_code=True)
tokenizer = processor.tokenizer

# Resolve the trust_remote_code model class via from_config, then load weights.
# Using AutoModel.from_pretrained directly can mis-route on v3 dumps.
config = AutoConfig.from_pretrained(CKPT, trust_remote_code=True)
model_class = type(AutoModel.from_config(config, trust_remote_code=True))
if not hasattr(model_class, "all_tied_weights_keys"):
    model_class.all_tied_weights_keys = {}
model = model_class.from_pretrained(CKPT, trust_remote_code=True, torch_dtype=torch.bfloat16)

# Reset RADIO's `summary_idxs` (non-persistent buffer; can be a meta tensor after load)
if hasattr(model, "vision_model") and hasattr(model.vision_model, "radio_model"):
    model.vision_model.radio_model.summary_idxs = None

model = model.cuda().eval()

# Load dataset
dataset = load_dataset("naver-clova-ix/cord-v2")

# v3 processor returns extra placeholder-expansion metadata that is NOT a generate() kwarg.
PROCESSOR_METADATA_KEYS = ("num_patches", "num_tokens", "imgs_sizes")

# Run inference on the first 5 validation samples
for i in range(5):
    sample = dataset["validation"][i]
    image = sample["image"].convert("RGB")
    gt = json.loads(sample["ground_truth"])["gt_parse"]
    gt_text = json2token(gt, sort_json_key=True)

    # Build prompt β€” enable_thinking=False for structured output
    messages = [{"role": "user", "content": "<image>\nDescribe this image."}]
    text = tokenizer.apply_chat_template(
        messages, tokenize=False,
        add_generation_prompt=True, enable_thinking=False,
    )
    inputs = processor(text=text, images=[image], return_tensors="pt")
    for k in PROCESSOR_METADATA_KEYS:
        inputs.pop(k, None)
    inputs = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=1024, do_sample=False)

    generated = tokenizer.decode(
        output_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True,
    ).strip()

    print(f"\n=== Sample {i} ===")
    print(f"Ground truth: {gt_text}")
    print(f"Prediction:   {generated}")

Step 5 β€” Results Comparison#

Evaluation on 5 CORD-v2 Validation Samples#

Full SFT (lr=1e-4, 400 steps, epoch_3_step_399)#

Sample

Ground Truth

Prediction

Match

1

<s_total>...<s_nm>REAL GANACHE</s_nm>...<s_nm>EGG TART</s_nm>...<s_nm>PIZZA TOAST</s_nm>...

Exact match

100%

2

<s_total>...<s_nm>JAMUR</s_nm>...<s_nm>TAHU</s_nm>...

Exact match

100%

3

<s_total>...<s_nm>Gojek Chicken Chilli Sauce H</s_nm>...

Correct values, slight name segmentation diff

33%

4

<s_total>...<s_nm>VANILLA CHOCO HEART CAKE</s_nm>...

Exact match

100%

5

<s_total>...<s_nm>Sate Padang</s_nm>...

Correct, extra <s_unitprice> field

~0%

3/5 exact matches. All samples produce correct structured output.

LoRA PEFT (rank=64, lr=1e-3, 400 steps, epoch_0_step_99)#

Sample

Ground Truth

Prediction

Match

1

<s_total>...<s_nm>REAL GANACHE</s_nm>...

Exact match

100%

2

<s_total>...<s_nm>JAMUR</s_nm>...<s_nm>TAHU</s_nm>...

Exact match

100%

3

<s_total>...<s_nm>Gojek Chicken Chilli Sauce H</s_nm>...

Correct values, slight name segmentation diff

33%

4

<s_total>...<s_nm>VANILLA CHOCO HEART CAKE</s_nm>...

Exact match

100%

5

<s_total>...<s_nm>Sate Padang</s_nm>...

Exact match

100%

4/5 exact matches. All samples produce correct structured output.

Summary#

Full SFT

LoRA PEFT

Trainable params

31.5B (95.63%)

55M (0.17%)

Learning rate

1e-4

1e-3

GPU memory

~49 GiB

~30 GiB

Training time (8x H100)

~10 min

~6 min

Best val loss

0.034 (step 199)

0.045 (step 99)

Final train loss

0.004

0.017

Checkpoint size

~64 GB

~27 MB

Exact matches (5 val)

3/5

4/5