NeVA (LLaVA)#

Originating from LLaVA v1.5 (Large Language and Vision Assistant), NeVA is a key addition to the NeMo Multimodal ecosystem. The model connects a large language model (such as Nemotron, Llama 3, or Mistral) to a vision encoder and is trained on machine-generated multimodal language-image instruction-following data. Building on the foundation set by LLaVA, NeVA further enhances training by leveraging features of the NeMo LLM framework such as model parallelism, sequence parallelism, activation checkpointing, AMP O2, cuDNN/Flash Attention, and more.

Import from Hugging Face to NeMo 2.0#

To import the Hugging Face (HF) model and convert it to NeMo 2.0 format, run the following Python script. This step only needs to be performed once:

from nemo.collections.llm import import_ckpt
from nemo.collections import vlm

if __name__ == '__main__':
    # Specify the Hugging Face model ID
    hf_model_id = "llava-hf/llava-1.5-7b-hf"

    # Import the model and convert to NeMo 2.0 format
    import_ckpt(
        model=vlm.LlavaModel(vlm.Llava15Config7B()),  # Model configuration
        source=f"hf://{hf_model_id}",  # Hugging Face model source
    )

The script above saves the converted checkpoint in the NeMo cache folder, located at ~/.cache/nemo.

If needed, you can change the default cache directory by setting the NEMO_CACHE_DIR environment variable before running the script.

NeMo 2.0 Fine-Tuning Recipes#

We provide predefined recipes for fine-tuning LLaVA v1.5 using NeMo 2.0 and NeMo-Run. These recipes configure a run.Partial for one of the nemo.collections.llm API functions introduced in NeMo 2.0. The recipes are hosted in the llava15_7b and llava15_13b files and use a mock dataset for training.

Note

The recipes use the MockDataModule for the data argument. You are expected to replace the MockDataModule with your custom dataset.

By default, the non-instruct version of the model is loaded. To load a different model, set finetune.resume.restore_config.path=nemo://<hf_model_id> or finetune.resume.restore_config.path=<local_model_path>.

The example below shows how to invoke the default recipe. Overriding the data argument is covered in Bring Your Own Data.

from nemo.collections import vlm

finetune = vlm.llava15_7b.finetune_recipe(
    name="llava15_7b_finetune",
    dir="/path/to/checkpoints",  # Directory for saving logs and checkpoints
    num_nodes=1,
    num_gpus_per_node=8,
    peft_scheme='lora',  # 'lora', 'none'
)

By default, the fine-tuning recipe applies LoRA to all linear layers in the language model and keeps the vision model unfrozen.

  • To choose which layers LoRA is applied to: Set finetune.peft.target_modules. For example, to apply LoRA only to the self-attention qkv projection layers, set finetune.peft.target_modules=["*.language_model.*.linear_qkv"].

  • To freeze the vision model: Set finetune.peft.freeze_vision_model=True.

  • To fine-tune the entire model without LoRA: Set peft_scheme='none' in the recipe argument.
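The target_modules entries are wildcard patterns that are matched against fully qualified module names. The sketch below shows which names a pattern such as "*.language_model.*.linear_qkv" would select. It uses Python's fnmatch as a stand-in matcher, and the module names are hypothetical. NeMo's own matching logic and real module names may differ, so use this only to reason about how patterns behave.

```python
from fnmatch import fnmatchcase

# Hypothetical fully qualified module names, for illustration only;
# real names depend on the NeMo model definition.
MODULE_NAMES = [
    "module.language_model.decoder.layers.0.self_attention.linear_qkv",
    "module.language_model.decoder.layers.0.self_attention.linear_proj",
    "module.language_model.decoder.layers.0.mlp.linear_fc1",
    "module.vision_model.encoder.layers.0.self_attention.linear_qkv",
]


def select_modules(names, patterns):
    """Return the names that match at least one wildcard pattern.

    In fnmatch, '*' matches any run of characters, including dots,
    so one '*' can span several nesting levels of the module tree.
    """
    return [n for n in names if any(fnmatchcase(n, p) for p in patterns)]


if __name__ == "__main__":
    # Prints only the language-model qkv projection
    for name in select_modules(MODULE_NAMES, ["*.language_model.*.linear_qkv"]):
        print(name)
```

The vision encoder's qkv layer is excluded here only because its name lacks the ".language_model." segment. A broader pattern such as "*.linear_qkv" would select it as well.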

Note

The configuration in the recipes is done using the NeMo-Run run.Config and run.Partial configuration objects. Please review the NeMo-Run documentation to learn more about its configuration and execution system.

Once your final configuration is ready, you can execute it on any of the NeMo-Run supported executors. The simplest is the local executor, which runs the fine-tuning job locally in a separate process:

import nemo_run as run

run.run(finetune, executor=run.LocalExecutor())

Alternatively, you can run it directly in the same Python process:

run.run(finetune, direct=True)

Bring Your Own Data#

Replace the MockDataModule in the default recipes with your custom dataset. The example below uses a LLaVA-style dataset:

from nemo.collections import vlm

# Define the fine-tuning recipe
finetune = vlm.llava15_7b.finetune_recipe(
    name="llava15_7b_finetune",
    dir="/path/to/checkpoints",
    num_nodes=1,
    num_gpus_per_node=8,
    peft_scheme='lora',  # 'lora', 'none'
)

# The following is an example of a custom dataset configuration.
data_config = vlm.ImageDataConfig(
    image_folder="/path/to/images",
    conv_template="v1",  # Customize based on your dataset needs
)

# Data module setup
custom_data = vlm.NevaPreloadedDataModule(
    paths="/path/to/dataset.json",  # Path to your dataset
    data_config=data_config,
    seq_length=2048,
    global_batch_size=16,  # Global batch size
    micro_batch_size=1,  # Micro batch size
    tokenizer=None,  # Define your tokenizer if needed
    image_processor=None,  # Add an image processor if required
    num_workers=8,  # Number of workers for data loading
)

# Assign custom data to the fine-tuning recipe
finetune.data = custom_data
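The paths argument of NevaPreloadedDataModule points to a JSON annotation file. The sketch below writes such a file, assuming it follows the conversation format of the original LLaVA instruction-tuning data. In that format the file is a JSON list of records with id, image, and conversations fields, and the human turn carries an <image> placeholder. The field names come from the upstream LLaVA format and are an assumption here, and make_record and write_dataset are hypothetical helpers. Check the expected schema against your NeMo version.

```python
import json
from pathlib import Path


def make_record(sample_id, image_file, question, answer):
    """Build one LLaVA-style record (assumed format, see the lead-in)."""
    return {
        "id": sample_id,
        # Assumed to be relative to ImageDataConfig.image_folder
        "image": image_file,
        "conversations": [
            # The <image> placeholder marks where image features are inserted
            {"from": "human", "value": f"<image>\n{question}"},
            {"from": "gpt", "value": answer},
        ],
    }


def write_dataset(records, out_path):
    """Write all records as a single JSON list and return the file path."""
    out_path = Path(out_path)
    out_path.write_text(json.dumps(records, indent=2), encoding="utf-8")
    return out_path


if __name__ == "__main__":
    records = [make_record("0001", "cat.jpg", "What animal is shown?", "A cat.")]
    print(write_dataset(records, "dataset.json"))
```

Multi-turn samples would append further alternating human/gpt entries to conversations. Typically only the first human turn carries the <image> placeholder.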

Use the Energon Dataloader#

The Energon Dataloader can be used with NeVA to handle multimodal datasets for training. This section explains how to set up and customize the dataloader, highlighting key components such as the task_encoder and multimodal_sample_config. For details on preparing data for use with the data module, refer to the data preparation section.

Example Code#

Below is an example of how to use the Energon Dataloader with NeVA for training:

from nemo.collections.multimodal.data.energon import (
    ImageToken,
    MultiModalSampleConfig,
    EnergonMultiModalDataModule,
)
from transformers import AutoProcessor

# Load processor, tokenizer, and image processor from pre-trained model
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
tokenizer = processor.tokenizer
image_processor = processor.image_processor

# Define dataset path
dataset_path = "<path_to_dataset>"

# Configure multimodal samples
config = MultiModalSampleConfig(
    image_token=ImageToken(token_str="<image>", token_id=-200),
    ignore_place_holder=-100
)

# Initialize the data module
data_module = EnergonMultiModalDataModule(
    path=dataset_path,
    tokenizer=tokenizer,
    image_processor=image_processor,
    seq_length=2048,
    micro_batch_size=1,
    global_batch_size=16,
    num_workers=0,
    multimodal_sample_config=config,
)

# Note: `EnergonMultiModalDataModule` defaults to `MultiModalTaskEncoder` if no custom task encoder is provided.

Explanation of Parameters#

  • path: Path to the dataset.

  • tokenizer: Tokenizer used for processing text data.

  • image_processor: Image processor used to preprocess images and prepare them for input to the vision model.

  • seq_length: Maximum sequence length for tokenized text (default: 2048).

  • micro_batch_size: Batch size for each GPU or process (default: 1).

  • global_batch_size: Total batch size across all GPUs or processes (default: 1).

  • num_workers: Number of workers for data loading (default: 1).

  • multimodal_sample_config: Configuration for multimodal samples, allowing customization of image tokens, placeholder values, and template configurations.

  • task_encoder: A custom task encoder can be provided if needed. By default, EnergonMultiModalDataModule uses MultiModalTaskEncoder.
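The two MultiModalSampleConfig values in the example above work as follows in LLaVA-style training. A negative placeholder ID (here -200) is spliced into input_ids where the image features will go. Every label position that should not contribute to the loss is set to -100, which is also the default ignore_index of PyTorch's cross-entropy loss. The sketch below is a simplified, hypothetical illustration with toy token IDs. It is not the MultiModalTaskEncoder implementation.

```python
from typing import List, Tuple

IMAGE_TOKEN_ID = -200  # matches ImageToken(token_id=-200) in the example config
IGNORE_INDEX = -100    # matches ignore_place_holder=-100


def build_inputs_and_labels(
    prompt_before_image: List[int],
    prompt_after_image: List[int],
    answer: List[int],
) -> Tuple[List[int], List[int]]:
    """Splice the image placeholder into the prompt and mask non-answer labels.

    Token IDs are toy values; a real pipeline gets them from the tokenizer.
    Labels are shown unshifted; many pipelines shift them by one position
    for next-token prediction.
    """
    prompt = prompt_before_image + [IMAGE_TOKEN_ID] + prompt_after_image
    input_ids = prompt + answer
    # Only answer tokens are supervised; prompt and image positions are ignored
    labels = [IGNORE_INDEX] * len(prompt) + answer
    return input_ids, labels


if __name__ == "__main__":
    ids, labels = build_inputs_and_labels([1, 2], [3], [7, 8])
    print(ids)     # [1, 2, -200, 3, 7, 8]
    print(labels)  # [-100, -100, -100, -100, 7, 8]
```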

Key Components#

  1. Task Encoder: The MultiModalTaskEncoder is a flexible encoder that handles various multimodal sample types such as VQA, captioning, interleaved, and similarity interleaved samples. If no custom task encoder is provided, task_encoder defaults to MultiModalTaskEncoder. This makes it easy to get started while still allowing customization for advanced use cases.

    Additional custom encoders can be registered for new sample types. The encoder processes samples into batches and prepares them for input to the NeVA model. For more details on sample types, refer to the Megatron-Energon documentation on sample types.

    This modular design ensures that it can be adapted to a wide variety of multimodal training scenarios.

  2. MultiModalSampleConfig: The MultiModalSampleConfig defines the configuration for multimodal samples. It includes the following default values, which can be customized as needed:

    • image_token: The default token configuration for images. By default, the image placeholder string is '<image>' and its token ID is -200.

    • ignore_place_holder: The default value is -100. This is the label value assigned to positions that are ignored during loss computation.

    • conversation_template_config: The default value is LLaVATemplateConfig. This configuration is for multimodal conversation templates and is used to apply a prompt template to the input text before tokenization. If conversation_template_config is provided, it will be used to generate the conversation prompt. If not, and the tokenizer has a chat template defined (tokenizer.chat_template), the tokenizer’s chat template will be used. If neither the tokenizer nor the conversation_template_config has a chat template defined, a ValueError will be raised for VQA samples.

    • image_following_text: A boolean indicating if image tokens should follow text tokens. It defaults to True.

    Below is an example of conversation template configuration:

    from dataclasses import dataclass, field
    from typing import List, Optional

    @dataclass
    class BaseConversationTemplateConfig:
        """Conversation template config related parameters"""

        system: Optional[str] = (
            "A chat between a curious user and artificial assistant agent. The assistant gives helpful, detailed, and polite answers to user's questions."
        )
        roles: List[str] = field(default_factory=lambda: ['user', 'assistant'])
        stop_string: str = "</s>"
        chat_template = """
        {%- for message in messages %}
            {%- if message['role'] == 'system' %}
                {{- message['content'].strip() + ' ' -}}
            {%- elif message['role'] == 'user' %}
                {{- 'USER: ' -}} {{- message['content'].strip() + ' ' -}}
            {%- elif message['role'] == 'assistant' %}
                {{- 'ASSISTANT: ' -}} {{- message['content'].strip() -}}
                {{- '</s>' -}}
            {%- endif %}
        {%- endfor -%}
        """
    

    This configuration includes the following parameters:

    • system: A string that defines the system’s description or purpose, such as “A chat between a curious user and artificial assistant agent.”

    • roles: A list of roles in the conversation (default: ['user', 'assistant']).

    • stop_string: A string to indicate the end of a conversation (default: </s>).

    • chat_template: A Jinja2 template used to format the conversation into a sequence of messages for input tokenization.

    The chat template is chosen in the following order of precedence:

    1. If conversation_template_config is provided, it takes precedence and is used to format the conversation.

    2. If conversation_template_config is not provided, but tokenizer.chat_template exists, the tokenizer’s template will be used.

    3. If neither conversation_template_config nor tokenizer.chat_template exists, a ValueError will be raised for VQA samples.
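To illustrate the rules above, the sketch below has two parts. render_llava_v1 reproduces in plain Python the string that the example chat_template produces, including its whitespace control: contents are stripped, "USER: " and "ASSISTANT: " prefixes are added, and "</s>" follows each assistant turn. resolve_chat_template applies the stated precedence. Both function names and the SimpleNamespace stand-ins for the config and tokenizer are hypothetical. NeMo applies these rules internally and renders the actual Jinja2 template.

```python
from types import SimpleNamespace
from typing import Dict, List


def render_llava_v1(messages: List[Dict[str, str]]) -> str:
    """Plain-Python equivalent of the example Jinja2 chat_template's output."""
    parts = []
    for m in messages:
        text = m["content"].strip()
        if m["role"] == "system":
            parts.append(text + " ")
        elif m["role"] == "user":
            parts.append("USER: " + text + " ")
        elif m["role"] == "assistant":
            parts.append("ASSISTANT: " + text + "</s>")
        # Any other role produces no output, as in the template
    return "".join(parts)


def resolve_chat_template(conversation_template_config, tokenizer) -> str:
    """Config template first, then tokenizer.chat_template, else ValueError."""
    config_template = getattr(conversation_template_config, "chat_template", None)
    if config_template:
        return config_template
    tokenizer_template = getattr(tokenizer, "chat_template", None)
    if tokenizer_template:
        return tokenizer_template
    raise ValueError("No chat template available for VQA samples")


if __name__ == "__main__":
    messages = [
        {"role": "system", "content": "A chat between a user and an assistant."},
        {"role": "user", "content": "What is shown?"},
        {"role": "assistant", "content": "A cat."},
    ]
    print(render_llava_v1(messages))
    tokenizer = SimpleNamespace(chat_template="{{ tokenizer template }}")
    print(resolve_chat_template(None, tokenizer))
```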

Supported Features#

  1. Sequence Packing: Please refer to Run SFT/PEFT with Packed Sequences in NeVA for more information.

  2. Additional Model Parallelisms: See our example fine-tuning script for complete examples of configuring additional model parallelism.

    • Sequence Parallel: Add sequence_parallel=True in MegatronStrategy:

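      from nemo import lightning as nl

      tp_size, pp_size = 2, 1  # example values

      strategy = nl.MegatronStrategy(
          tensor_model_parallel_size=tp_size,
          pipeline_model_parallel_size=pp_size,
          # ... other strategy arguments ...
          sequence_parallel=True,  # splits activations across the tensor-parallel group
      )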
      
    • Context Parallel: Add context_parallel_size=cp_size in MegatronStrategy:

      from nemo import lightning as nl

      tp_size, pp_size, cp_size = 2, 1, 2  # example values

      strategy = nl.MegatronStrategy(
          tensor_model_parallel_size=tp_size,
          pipeline_model_parallel_size=pp_size,
          # ... other strategy arguments ...
          context_parallel_size=cp_size,
      )
      

      Context Parallel currently applies only to the language model; the vision encoder is replicated on the context-parallel ranks. We are working on a better strategy for this.

    • Virtual Pipeline Parallel: Add virtual_pipeline_model_parallel_size=vpp_size in MegatronStrategy:

      from nemo import lightning as nl

      tp_size, pp_size, vpp_size = 1, 4, 2  # example values

      strategy = nl.MegatronStrategy(
          tensor_model_parallel_size=tp_size,
          pipeline_model_parallel_size=pp_size,
          # ... other strategy arguments ...
          virtual_pipeline_model_parallel_size=vpp_size,
      )
      
  3. Supported Vision Encoders: NeMo supports the following Vision Encoders:

    • CLIPViT Vision Encoder (Hugging Face and Megatron Core backend)

    • SigLIPViT Vision Encoder (Megatron Core backend)

    • InternViT Vision Encoder (Megatron Core backend)

    Define the corresponding vision transformer config and pass it to NevaConfig as vision_transformer_config. Check out our example fine-tuning script for full examples.

    For Megatron Core backend encoders, you must first convert the weights to NeMo format and then load them with vision_model_from_pretrained; otherwise, the vision encoder is initialized with random weights. Below is an example for converting InternViT; similar steps apply for CLIP and SigLIP.

    from nemo.collections import vlm
    from nemo.collections.llm import import_ckpt
    
    if __name__ == '__main__':
        model_id = "OpenGVLab/InternViT-300M-448px-V2_5"
        model = vlm.InternViTModel(vlm.InternViT_300M_448px_Config())
        import_ckpt(
            model=model,
            source=f"hf://{model_id}",
        )
    

    The script above saves the converted checkpoint in the NeMo cache folder, located at ~/.cache/nemo.

    # CLIP ViT, Hugging Face backend
    from nemo.collections import vlm
    vision_transformer_config = vlm.HFCLIPVisionConfig(
        pretrained_model_name_or_path="openai/clip-vit-large-patch14-336"  # Change model ID here
    )
    neva_config = vlm.NevaConfig(
        vision_transformer_config=vision_transformer_config,
        # ... other NevaConfig arguments ...
    )
    
    # CLIP ViT, Megatron Core backend (weights converted to NeMo format first)
    from nemo.collections import vlm
    vision_transformer_config = vlm.CLIPViTL_14_336_Config()
    neva_config = vlm.NevaConfig(
        vision_transformer_config=vision_transformer_config,
        vision_model_from_pretrained="/path/to/converted/clip_vit_model",
        # ... other NevaConfig arguments ...
    )
    
    # SigLIP ViT, Megatron Core backend (weights converted to NeMo format first)
    from nemo.collections import vlm
    vision_transformer_config = vlm.SigLIPViT400M_14_384_Config()
    neva_config = vlm.NevaConfig(
        vision_transformer_config=vision_transformer_config,
        vision_model_from_pretrained="/path/to/converted/siglip_vit_model",
        # ... other NevaConfig arguments ...
    )
    
    # InternViT, Megatron Core backend (weights converted to NeMo format first)
    from nemo.collections import vlm
    vision_transformer_config = vlm.InternViT_300M_448px_Config()
    neva_config = vlm.NevaConfig(
        vision_transformer_config=vision_transformer_config,
        vision_model_from_pretrained="/path/to/converted/intern_vit_model",
        # ... other NevaConfig arguments ...
    )
    
  4. FP8 Training: To enable FP8 training, you need to configure the MegatronMixedPrecision plugin and set appropriate FP8 arguments. For more details, check the Transformer Engine User Guide.

from nemo import lightning as nl
import torch

trainer = nl.Trainer(
    num_nodes=1,
    devices=8,
    # ... other Trainer arguments ...
    plugins=nl.MegatronMixedPrecision(
        precision="bf16-mixed",
        params_dtype=torch.bfloat16,
        fp8='hybrid',
        fp8_amax_history_len=16,
        fp8_amax_compute_algo="max",
    ),
)
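When combining these parallelism settings, note that in Megatron-style training the GPUs remaining after tensor, pipeline, and context parallelism form the data-parallel group. The global batch is then split into micro-batches across that group. The sketch below works through this arithmetic for the 8-GPU, global_batch_size=16, micro_batch_size=1 setup used earlier. compute_layout and ParallelLayout are hypothetical names, and NeMo performs its own validation of these settings.

```python
from dataclasses import dataclass


@dataclass
class ParallelLayout:
    data_parallel_size: int
    num_micro_batches: int  # micro-batches per optimizer step on each data-parallel rank


def compute_layout(world_size, tp, pp, cp, global_batch_size, micro_batch_size):
    """Derive the data-parallel size and micro-batch count (hypothetical helper)."""
    model_parallel = tp * pp * cp
    if world_size % model_parallel:
        raise ValueError("world_size must be divisible by tp * pp * cp")
    dp = world_size // model_parallel
    samples_per_micro_step = micro_batch_size * dp
    if global_batch_size % samples_per_micro_step:
        raise ValueError("global_batch_size must be divisible by micro_batch_size * dp")
    return ParallelLayout(dp, global_batch_size // samples_per_micro_step)


if __name__ == "__main__":
    # 1 node x 8 GPUs with TP=2, PP=1, CP=1 as example values
    print(compute_layout(8, tp=2, pp=1, cp=1, global_batch_size=16, micro_batch_size=1))
```

With TP=2 on 8 GPUs, four data-parallel ranks remain, so each optimizer step processes 16 / (1 x 4) = 4 micro-batches per rank.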

Below is the list of fine-tuning recipes that we currently support or plan to support soon:

Recipe                           | Status
-------------------------------- | ------
LLaVA 1.5 7B LoRA                | Yes
LLaVA 1.5 7B Full fine-tuning    | Yes
LLaVA 1.5 13B LoRA               | Yes
LLaVA 1.5 13B Full fine-tuning   | Yes