Important

NeMo 2.0 is an experimental feature and is currently released only in the dev container: nvcr.io/nvidia/nemo:dev. Please refer to the Migration Guide for information on getting started.

Checkpoint Conversion

Obtain the Checkpoints from Hugging Face

To obtain the checkpoint you want from Hugging Face, go to one of the following repositories:

  1. Repository for the Mamba2 and Mamba2-Hybrid models by NVIDIA. In this repository, the checkpoint is located in the Files tab under release/mp_rank_00/model_optim_rng.pt, and the tokenizer is in the same Files tab under the name mt_nlg_plus_multilingual_ja_zh_the_stack_frac_015_256k.model. You need both files for the conversion to a .nemo checkpoint. A sketch for downloading these files is shown after this list.

  2. Repository for the Mamba2 models from the Transformers are SSMs paper.

    For checkpoints from this repository, run the following Python script to convert the PyTorch checkpoint (pytorch_model.bin in the Hugging Face model card) to a format similar to that of the 8b models:

    import torch
    import os

    # Path to the Hugging Face PyTorch checkpoint
    ckpt_path = "/path/to/pytorch_model.bin"

    # Load the state dict on CPU so that no GPU is required for this step
    pyt_checkpoint = torch.load(ckpt_path, map_location="cpu")
    new_ckpt_path = os.path.join(os.path.dirname(ckpt_path), f"wrapped_{os.path.basename(ckpt_path)}")

    # Save the new checkpoint, which will be used as the input to the conversion script
    torch.save({"model": pyt_checkpoint}, new_ckpt_path)
    

    You will use this wrapped_pytorch_model.bin for the conversion to .nemo in the next step.
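
If you prefer to download these files from a script rather than the Files tab, the sketch below uses the huggingface_hub library. The repository IDs are placeholders, not the actual repository names; substitute them with the repositories linked above.

from huggingface_hub import hf_hub_download

# Placeholder repository ID -- replace with the NVIDIA Mamba2 or Mamba2-Hybrid repository linked above
repo_id = "<nvidia-mamba2-or-mamba2-hybrid-repo>"

# Checkpoint and tokenizer locations inside the repository, as described in item 1
ckpt_path = hf_hub_download(repo_id=repo_id, filename="release/mp_rank_00/model_optim_rng.pt")
tokenizer_path = hf_hub_download(repo_id=repo_id, filename="mt_nlg_plus_multilingual_ja_zh_the_stack_frac_015_256k.model")

# For the models from the "Transformers are SSMs" paper, download pytorch_model.bin instead
# pyt_path = hf_hub_download(repo_id="<paper-model-repo>", filename="pytorch_model.bin")

print(ckpt_path, tokenizer_path)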

Convert the PyTorch Checkpoint to a NeMo Checkpoint

  1. Authenticate with NVIDIA NGC: generate an API key from NGC and add it to your credentials by following the instructions in this guide.

  2. Get into the NVIDIA NeMo dev container, nvcr.io/nvidia/nemo:dev, from NGC, or into the 24.07 container once it is released.

  3. Run the conversion script located at /opt/NeMo/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py. For this script, you need to provide the PyTorch state dictionary of the model as the input_name_or_path argument. Note that this argument only accepts a single state_dict.

CUDA_VISIBLE_DEVICES="0" python /opt/NeMo/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py \
                                --input_name_or_path <path to the source pytorch model> \
                                --output_path <path to target .nemo model> \
                                --mamba_ssm_ngroups 8 \
                                --precision bf16 \
                                --tokenizer_model_dir=<path to tokenizer.model> # Remove this line (or set it to None) for 130m, 370m, 780m, 1.3b, and 2.7b models.

Note

The mamba_ssm_ngroups parameter should be set to 1 for the Mamba2 models from the Transformers are SSMs paper (130m, 370m, 780m, 1.3b, and 2.7b) and to 8 for the Mamba2 and Mamba2-Hybrid models by NVIDIA (both 8b).
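
After the conversion completes, you can optionally inspect the output file. The sketch below assumes the converted checkpoint is a standard tar archive, which is how .nemo files are packaged; it simply lists the archive members.

import tarfile

nemo_path = "/path/to/target.nemo"  # the --output_path passed to the conversion script

# A .nemo checkpoint is a tar archive; listing its members is a quick sanity check
with tarfile.open(nemo_path) as archive:
    for member in archive.getnames():
        print(member)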

Run Tensor Parallelism (TP) for 8b Models

Note

Distributed checkpointing for the Mamba2 and Mamba2-Hybrid models will be implemented soon. In the meantime, use the method below to convert to Tensor Parallel (TP) of different sizes.

The Hugging Face checkpoint for the 8b model is configured for a TP size of 1, as is the .nemo checkpoint obtained in the previous step. To shard the model weights for a larger TP size, use the mamba_change_num_partition.py script from the NeMo repository, as shown below.

python /opt/NeMo/examples/nlp/language_modeling/mamba_change_num_partition.py \
       --model_file=<path to source .nemo model> \
       --target_file=<path to target .nemo model> \
       --tensor_model_parallel_size=1 \
       --target_tensor_model_parallel_size=4 \
       --precision=bf16 \
       --tokenizer_path=<path to tokenizer.model>

After running this script, a .nemo model and the corresponding number of tensor-parallel rank folders (4 in this example) will be generated in the target path. The folder for each rank is named mp_rank_00 through mp_rank_03 in this example.
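
As an optional check (a minimal sketch, assuming the rank folders are written into the same target directory as the .nemo file), you can verify that all rank folders were created:

import os

target_dir = "/path/to/target"  # directory containing the TP-converted model
tp_size = 4                     # the --target_tensor_model_parallel_size used above

# Expect one folder per tensor-parallel rank, e.g. mp_rank_00 ... mp_rank_03 for TP size 4
expected = [f"mp_rank_{rank:02d}" for rank in range(tp_size)]
missing = [name for name in expected if not os.path.isdir(os.path.join(target_dir, name))]
print("missing rank folders:", missing or "none")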

Note

You can use Tensor Parallelism only for the 8b models by NVIDIA (Mamba2 8b and Mamba2-Hybrid 8b). This is because the mamba_ssm_ngroups parameter in the model architecture must be divisible by the TP size. The mamba_ssm_ngroups parameter is 8 for the NVIDIA models and 1 for the other models in the list, so the latter only support a TP size of 1.
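
As a minimal illustration of this constraint:

# Valid TP sizes are those that evenly divide mamba_ssm_ngroups
mamba_ssm_ngroups = 8  # NVIDIA Mamba2 8b and Mamba2-Hybrid 8b
print([tp for tp in (1, 2, 4, 8) if mamba_ssm_ngroups % tp == 0])  # [1, 2, 4, 8]

mamba_ssm_ngroups = 1  # models from the "Transformers are SSMs" paper
print([tp for tp in (1, 2, 4, 8) if mamba_ssm_ngroups % tp == 0])  # [1]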