Llama Nemotron Nano VL 8B#
Llama-Nemotron-Nano-VL-8B-V1 is a leading document intelligence vision language model (VLM) that enables querying and summarizing images and video from the physical or virtual world. Llama-Nemotron-Nano-VL-8B-V1 is deployable in the data center, in the cloud, and at the edge, including on Jetson Orin and laptops via AWQ 4-bit quantization through the TinyChat framework. We find that: (1) image-text pairs alone are not enough; interleaved image-text data is essential; (2) unfreezing the LLM during interleaved image-text pre-training enables in-context learning; (3) re-blending text-only instruction data is crucial to boost both VLM and text-only performance.
This model was trained on commercial images and videos across all three training stages and supports single-image and video inference.
Note

Please use the custom container nvcr.io/nvidia/nemo:25.04.01.llama_nemotron_nano_vl when working with Llama Nemotron Nano VL 8B.
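For example, a typical way to start this container (illustrative; assumes Docker with the NVIDIA Container Toolkit and a workspace mount of your choosing):

docker run --gpus all -it --rm \
    -v /path/to/workspace:/workspace \
    nvcr.io/nvidia/nemo:25.04.01.llama_nemotron_nano_vl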
Import from Hugging Face to NeMo 2.0#
To import a Hugging Face model checkpoint and convert it to NeMo 2.0 format, run the following command. This step only needs to be performed once:
Log in to Hugging Face:
huggingface-cli login
from nemo.collections.llm import import_ckpt
from nemo.collections import vlm

if __name__ == '__main__':
    # Hugging Face model id or a local path
    ckpt_path = "nvidia/Llama-3.1-Nemotron-Nano-VL-8B-V1"

    # Import the model and convert it to NeMo 2.0 format
    import_ckpt(
        model=vlm.LlamaNemotronVLModel(vlm.LlamaNemotronNanoVLConfig8B()),  # Model configuration
        source=f"hf://{ckpt_path}",
    )
The converted file will be saved in the NeMo cache directory (default: ~/.cache/nemo). To change this location, set the environment variable NEMO_MODELS_CACHE before running the script.
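For example, assuming the import snippet above is saved as import_llama_nemotron_vl.py (a hypothetical file name), you can redirect the cache like this:

export NEMO_MODELS_CACHE=/path/to/nemo_models_cache
python import_llama_nemotron_vl.py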
Model Generation Example#
Load from a converted (or fine-tuned) local NeMo 2.0 checkpoint:
python scripts/vlm/llama_nemotron_nano_vl/llama_nemotron_nano_vl_generate.py \
--local_model_path=/path/to/nemo2_ckpt
You can pass in your own image and text prompts with the image_url and prompt arguments; a default image and prompt are supplied for you.
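For instance, assuming the generation script exposes these as --image_url and --prompt flags (matching the argument names above):

python scripts/vlm/llama_nemotron_nano_vl/llama_nemotron_nano_vl_generate.py \
    --local_model_path=/path/to/nemo2_ckpt \
    --image_url=<path_or_url_to_image> \
    --prompt="Describe this image in detail."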
NeMo 2.0 Fine-Tuning Recipes#
Refer to the example script scripts/vlm/llama_nemotron_nano_vl/llama_nemotron_nano_vl_8b_finetune.py. On a single node, you can run:
torchrun --nproc_per_node=8 scripts/vlm/llama_nemotron_nano_vl/llama_nemotron_nano_vl_8b_finetune.py \
--devices=8 --tp=2 --data_type=mock
By default, it uses the MockDataModule for the data argument. We also support the llava dataset type and the llava dataset converted to Energon format.
To load a local NeMo 2.0 model, set --restore_path=<local_model_path>.
Here’s an example for full SFT (Supervised Fine-Tuning) using the LLaVA dataset:
torchrun --nproc_per_node=8 scripts/vlm/llama_nemotron_nano_vl/llama_nemotron_nano_vl_8b_finetune.py \
--data_path "/path/to/dataset/llava_v1_5_mix665k.json" \
--image_folder "/path/to/dataset/images" \
--data_type llava \
--num_nodes 1 \
--devices=8 \
--projector_type=mcore_mlp \
--tp_size 2 --pp_size 1 \
--gbs 128 --mbs 2 \
--wandb_project=llama_nemotron_vl_demo \
--name=llama_nemotron_vl_finetune \
--log_dir "/path/to/experiments/llama_nemotron_vl_finetune" \
--restore_path "/path/to/experiments/llama_nemotron_vl_pretrain_checkpoint"
By default, all components are unfrozen for SFT. You can modify freezing behavior in the script:
from nemo.collections import vlm

neva_config = vlm.LlamaNemotronVLConfig(
    ...
    freeze_language_model=False,
    freeze_vision_model=False,
    freeze_vision_projection=False,
)
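For instance, here is a minimal sketch that freezes the vision encoder while keeping the language model and projector trainable, assuming the 8B config used in the import example accepts the same freeze flags:

from nemo.collections import vlm

# Assumption: LlamaNemotronNanoVLConfig8B accepts the same freeze flags as
# LlamaNemotronVLConfig shown above.
config = vlm.LlamaNemotronNanoVLConfig8B(
    freeze_language_model=False,     # train the language model
    freeze_vision_model=True,        # keep the vision encoder frozen
    freeze_vision_projection=False,  # train the projector
)
model = vlm.LlamaNemotronVLModel(config)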
NeMo 2.0 PEFT Recipes#
For Parameter-Efficient Fine-Tuning (PEFT), use the same script with the additional flag --peft='lora'.
torchrun --nproc_per_node=8 scripts/vlm/llama_nemotron_nano_vl/llama_nemotron_nano_vl_8b_finetune.py \
--data_path "/path/to/dataset/llava_v1_5_mix665k.json" \
--image_folder "/path/to/dataset/images" \
--data_type llava \
--num_nodes 1 \
--devices=8 \
--projector_type=mcore_mlp \
--tp_size 2 --pp_size 1 \
--gbs 128 --mbs 2 \
--wandb_project=llama_nemotron_vl_demo \
--name=llama_nemotron_vl_finetune \
--log_dir "/path/to/experiments/llama_nemotron_vl_finetune" \
--restore_path "/path/to/experiments/llama_nemotron_vl_pretrain_checkpoint" \
--peft='lora'
By default, LoRA is applied to all linear layers in the language model, while the vision model remains unfrozen.
To specify target modules for LoRA, set peft.target_modules. For example: peft.target_modules=["*.language_model.*.linear_qkv"].
To freeze the vision model, set peft.freeze_vision_model=True.
from nemo.collections import vlm

peft = vlm.peft.LoRA(
    target_modules=[
        "*.language_model.*.linear_qkv",
        "*.language_model.*.linear_proj",
        "*.language_model.*.linear_fc1",
        "*.language_model.*.linear_fc2",
    ],
    freeze_language_model=True,
    freeze_vision_model=False,
    freeze_vision_projection=False,
    dim=32,
)
Merge LoRA Adapters to Base Model#
To run inference on a LoRA checkpoint or export it to Hugging Face format, you must first merge it with the base model to create a full checkpoint.
from nemo.collections import vlm

vlm.peft.merge_lora(
    lora_checkpoint_path=<lora_checkpoint_path>,
    output_path=<output_path>,
)
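As with the other snippets on this page, a runnable variant wraps the call in a main guard; the paths below are illustrative and should point at the checkpoint produced by your PEFT run and the desired output location:

from nemo.collections import vlm

if __name__ == '__main__':
    # Illustrative paths: point these at your LoRA checkpoint and the
    # directory where the merged model should be written.
    vlm.peft.merge_lora(
        lora_checkpoint_path="/path/to/lora_checkpoint",
        output_path="/path/to/merged_checkpoint",
    )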
Export to Hugging Face from NeMo 2.0#
To export a NeMo 2.0 model to Hugging Face format, run the following command.
from nemo.collections import llm
from pathlib import Path

if __name__ == '__main__':
    ckpt = "/path/to/nemo2_ckpt/"
    output_path = "/path/to/output_hf"
    output = llm.export_ckpt(
        path=Path(ckpt),
        target="hf",
        output_path=Path(output_path),
        overwrite=True,
    )
Deploy NeMo Models#
For scenarios requiring optimized performance, the model can leverage TensorRT. Convert the NeMo model into a format compatible with TensorRT using the nemo.export module:
from nemo.export.tensorrt_mm_exporter import TensorRTMMExporter

if __name__ == '__main__':
    exporter = TensorRTMMExporter(model_dir="/path/to/trt_engines")
    exporter.export(
        visual_checkpoint_path="/path/to/nemo_ckpt",
        model_type="llama_nemotron",
        vision_max_batch_size=13,  # equal to the maximum number of image tiles
        max_batch_size=1,
        max_multimodal_len=4096,
        max_input_len=4096,
    )
After converting to TensorRT, you can generate outputs with nemo.export:
from nemo.export.tensorrt_mm_exporter import TensorRTMMExporter

if __name__ == '__main__':
    exporter = TensorRTMMExporter(model_dir="/path/to/trt_engines", load_model=True)
    img_path = "/path/to/image.jpg"
    prompt = "<img><image></img>\nPlease describe the image shortly."
    output = exporter.forward(
        prompt,
        img_path,
        batch_size=1,
        max_output_len=128,
    )
    print(output)
You can also deploy the model on the Triton server:
python scripts/deploy/multimodal/deploy_triton.py \
--triton_model_repository /path/to/trt_engines \
--model_type llama_nemotron \
--llm_model_type llama \
--triton_model_name llama_nemotron
After deploying the model, you can send a query to the server:
python scripts/deploy/multimodal/query.py \
--model_name llama_nemotron \
--model_type llama_nemotron \
--input_media /path/to/image.jpg \
--input_text "<img><image></img>\nPlease describe the image shortly."
Supported Features#
Sequence Packing: Please refer to Run SFT/PEFT with Packed Sequences in NeVA for more information.
Additional Model Parallelisms: Check out our example fine-tuning script for full examples of configuring additional model parallelisms.
Sequence Parallel: Add sequence_parallel=True in MegatronStrategy:

from nemo import lightning as nl

strategy = nl.MegatronStrategy(
    tensor_model_parallel_size=tp_size,
    pipeline_model_parallel_size=pp_size,
    ...
    sequence_parallel=True,
)
Context Parallel: Add context_parallel_size=cp_size in MegatronStrategy:

from nemo import lightning as nl

strategy = nl.MegatronStrategy(
    tensor_model_parallel_size=tp_size,
    pipeline_model_parallel_size=pp_size,
    ...
    context_parallel_size=cp_size,
)
Context Parallel currently applies only to the language model. The vision encoder will be duplicated on those ranks. We are working on a better strategy for this.
Virtual Pipeline Parallel: Add virtual_pipeline_model_parallel_size=vpp_size in MegatronStrategy:

from nemo import lightning as nl

strategy = nl.MegatronStrategy(
    tensor_model_parallel_size=tp_size,
    pipeline_model_parallel_size=pp_size,
    ...
    virtual_pipeline_model_parallel_size=vpp_size,
)
FP8 Training: To enable FP8 training, you need to configure the MegatronMixedPrecision plugin and set appropriate FP8 arguments. For more details, check the Transformer Engine User Guide.
from nemo import lightning as nl
import torch

trainer = nl.Trainer(
    num_nodes=1,
    devices=8,
    ...
    plugins=nl.MegatronMixedPrecision(
        precision="bf16-mixed",
        params_dtype=torch.bfloat16,
        fp8='hybrid',
        fp8_amax_history_len=16,
        fp8_amax_compute_algo="max",
    ),
)
Below is a comprehensive list of pretraining recipes that we currently support or plan to support soon: