Note: Dedicated Container for Gemma
For Gemma models, please use the nvcr.io/nvidia/nemo:24.05 container. Also check our Gemma playbooks.
Note: Dedicated Container for CodeGemma
For CodeGemma models, please use the nvcr.io/nvidia/nemo:24.03.codegemma container. Aside from file names, the same scripts and commands are valid for both Gemma and CodeGemma.
Checkpoint Conversion
NVIDIA provides scripts to convert external Gemma checkpoints from JAX, PyTorch, and HuggingFace formats to the .nemo format. The .nemo checkpoint is then used for SFT, PEFT, and inference.
Run the container using the following command:
docker run --gpus device=1 --shm-size=2g --net=host --ulimit memlock=-1 --rm -it -v ${PWD}:/workspace -w /workspace -v ${PWD}/results:/results nvcr.io/nvidia/nemo:24.05 bash
Option 1: Convert the JAX Gemma model to a .nemo model (clone the Google Gemma JAX repo to /path/to/google/gemma_jax):
pip install orbax jax flax jaxlib; \
export PYTHONPATH=/path/to/google/gemma_jax:$PYTHONPATH; \
python3 /opt/NeMo/scripts/checkpoint_converters/convert_gemma_jax_to_nemo.py \
--input_name_or_path /path/to/gemma/checkpoints/jax/7b \
--output_path /path/to/gemma-7b.nemo \
--tokenizer_path /path/to/tokenizer.model
Option 2: Convert the PyTorch Gemma model to a .nemo model (clone the Google Gemma PyTorch repo to /path/to/google/gemma_pytorch):
pip install fairscale==0.4.13 immutabledict==4.1.0 tensorstore==0.1.45; \
export PYTHONPATH=/path/to/google/gemma_pytorch:$PYTHONPATH; \
python3 /opt/NeMo/scripts/checkpoint_converters/convert_gemma_pyt_to_nemo.py \
--input_name_or_path /path/to/gemma/checkpoints/pyt/7b.ckpt \
--output_path /path/to/gemma-7b.nemo \
--tokenizer_path /path/to/tokenizer.model
Option 3: Convert the HuggingFace Gemma model to a .nemo model:
python3 /opt/NeMo/scripts/checkpoint_converters/convert_gemma_hf_to_nemo.py \
--input_name_or_path /path/to/gemma/checkpoints/hf/7b \
--output_path /path/to/gemma-7b.nemo \
--tokenizer_path /path/to/tokenizer.model
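Before running the HuggingFace conversion above, it can help to sanity-check that the checkpoint directory is complete. The sketch below is not part of the NeMo scripts; the expected file names are assumptions based on a typical HuggingFace Gemma export:

```python
import os

def check_hf_checkpoint(path):
    """Return a list of files missing from a typical HuggingFace Gemma export.

    Assumed layout: config.json and tokenizer.model at the top level,
    plus one or more weight shards (*.safetensors or *.bin).
    """
    expected = ["config.json", "tokenizer.model"]
    missing = [f for f in expected if not os.path.exists(os.path.join(path, f))]
    # Weight shards may be stored in either safetensors or PyTorch .bin format.
    has_weights = any(
        f.endswith((".safetensors", ".bin")) for f in os.listdir(path)
    )
    if not has_weights:
        missing.append("model weights (*.safetensors or *.bin)")
    return missing
```

An empty return value suggests the directory passed as `--input_name_or_path` has the pieces the converter will look for; a non-empty list points at what to re-download first.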
The generated gemma-7b.nemo file uses distributed checkpointing and can be loaded with any tensor-parallel (TP) or pipeline-parallel (PP) configuration without reshaping or splitting the checkpoint.