NeMo Framework End-to-End Workflow Example

This workflow provides a full end-to-end example of preparing a dataset and training a foundation model based on Llama-3.1-8B using the redesigned NeMo 2.0 API of the NeMo Framework. The guide is split into sub-sections that describe each part in detail.

NeMo 2.0 now uses a Pythonic API which allows it to be integrated with IDEs such as Visual Studio Code (VS Code) and supports type checking.

While this guide demonstrates pre-training a Llama-3.1-8B model from scratch, it can be modified to train any supported model with NeMo 2.0. For more information about NeMo 2.0, including the latest list of supported models, see the NVIDIA NeMo Framework User Guide.

Requirements

The following is a list of requirements to follow this complete workflow:

  • An NVIDIA DGX Cloud Lepton cluster with at least 2x A100 or newer GPU nodes with 8 GPUs each.
  • VS Code installed on a local machine (available from https://code.visualstudio.com/).
  • Python 3.10 or newer with pip installed on a local machine.
  • A shared filesystem with read/write access which is mountable in jobs.
  • A Hugging Face account with an API token (setup steps in the following section).
  • A Weights & Biases account with an API token (setup steps in the following section).

Initial Setup

This guide uses two external services to simplify the LLM development process: Hugging Face and Weights & Biases.

Hugging Face contains resources for many of the most popular language models and datasets in the community. We can leverage these resources while training the model to minimize deployment steps and be consistent with community model assumptions.

This workflow walks through training a Llama-3.1-8B model from scratch. The dataset we use needs to be tokenized with the model's tokenizer. Meta, the company that produces the Llama models, published the tokenizer for the Llama models on Hugging Face. To use the tokenizer, we need to create a Hugging Face account and accept the Llama-3.1-8B license on the model repository page. The following steps walk you through that process.

Create a Hugging Face Account

If you don't have a Hugging Face account already, create one by going to https://huggingface.co/join and signing up with your corporate email account.

Once your account is set up, go to https://huggingface.co/settings/tokens while logged in to create a personal access token. Create a new token with Read access and give it a memorable name. Save the generated token in a safe place, as it won't be viewable again for security reasons.
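
If you want to confirm locally that the token works before using it in jobs, the following optional check with the huggingface_hub Python package (not part of the workflow scripts, and assuming huggingface_hub is installed locally) prints your account details when the token is valid; the token value shown is a placeholder:

from huggingface_hub import whoami

# Optional sanity check: prints your account details if the token is valid.
# Replace the placeholder with your personal access token.
print(whoami(token="hf_xxxxxxxxxxxxxxxxxx"))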

Accept the Llama-3.1-8B License

As mentioned earlier, this example uses the official Llama-3.1-8B tokenizer available on Hugging Face, which requires agreeing to their license on their model page. To do so, navigate to https://huggingface.co/meta-llama/Llama-3.1-8B while logged in. Read the privacy policy at the top of the model card page, then click the Agree and access repository button towards the top of the page to accept the license. Now, you can download resources from this repository using your personal access token.

Create a Weights & Biases Account

Weights & Biases is a tool that allows developers to easily track experiments for AI applications. NeMo Framework natively supports logging many values such as training loss, learning rate, and gradient norm as well as resource utilization with Weights & Biases. Weights & Biases is highly recommended for tracking NeMo Framework jobs.

To get started with Weights & Biases, navigate to https://wandb.ai in a web browser and click the Sign Up button in the top right to create a free account. Once logged in, navigate to https://wandb.ai/settings and scroll to the bottom to create a new API key. This API key will be used while launching workflows to automatically log to Weights & Biases.
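
As an optional local check (not part of the workflow scripts, and assuming the wandb package is installed locally), you can verify the key with the wandb Python client before adding it to job environment variables; the key value shown is a placeholder:

import wandb

# Optional sanity check: returns True when the API key is accepted.
# Replace the placeholder with your W&B API key.
wandb.login(key="xxxxxxxxxxxxxxxxxx")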

Set Up VS Code Locally

With VS Code installed on your local machine, run the application and open a new directory to save the scripts you'll use for launching jobs on the DGX Cloud cluster.

In VS Code, open a terminal window by clicking the Terminal > New Terminal button in the menu. Next, create a Python virtual environment and install the dependencies required for running NeMo 2.0 and Lepton using the following commands in the new terminal:

python3 -m venv env
source env/bin/activate
pip3 install "nemo_toolkit[nlp]" git+https://github.com/NVIDIA/nemo-run megatron-core opencc==1.1.6 leptonai

Once dependencies are installed, the data preparation and training scripts can be defined using VS Code.

Note

The source env/bin/activate command above activates a Python virtual environment with the dependencies installed. If you need to leave the virtual environment, you can run deactivate. To activate it again, navigate back to the directory where the virtual environment named env was saved and run source env/bin/activate again. If you run into ModuleNotFoundError exceptions, the environment likely needs to be re-activated.

Authenticate with DGX Cloud Lepton

NeMo Framework on DGX Cloud Lepton leverages the Lepton Python SDK to upload data to the cluster and schedule jobs. To use the Python SDK, users need to authenticate with the cluster using the Lepton CLI tool installed in the previous step. The authentication credentials can be found in the DGX Cloud Lepton UI by opening the Settings > Tokens page. This page shows a command to authenticate with your workspace that looks similar to the following:

lep login -c xxxxxx:************************

Copy the code shown in the UI and run it locally in your terminal in VS Code to authenticate with the cluster. Once authenticated, the Python SDK will be connected to your cluster for all future commands.

Prepare the Data

NeMo Framework supports processing custom text-based datasets for pre-training new models. The data preprocessor requires datasets to be cleansed, excluding any sensitive or improperly formatted data that is unsuitable for use during pre-training. Each file in the dataset must be in .json or, ideally, .jsonl format. Datasets can be downloaded from external sources or uploaded directly to the remote filesystem.
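
For reference, each line of a .jsonl file is one standalone JSON record, and the preprocessing step later in this workflow reads the text field of each record by default. A minimal, hypothetical record written from Python looks like this:

import json

# Hypothetical example of a single pre-training record; SlimPajama stores the raw
# document under the "text" key, which the preprocessor tokenizes.
record = {"text": "NeMo Framework is used to train large language models at scale."}

with open("example.jsonl", "a") as fn:
    fn.write(json.dumps(record) + "\n")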

The following example walks through downloading, extracting, concatenating, and preprocessing the SlimPajama dataset, which includes a large corpus of text from several domains and has been deduplicated and cleaned, making it a great candidate for pre-training LLMs. While the remainder of the document is based on the SlimPajama dataset, this general process can be followed for most custom datasets, with guidance on how to adapt it as needed.

Set Up the Script

We will leverage four different scripts to prepare the SlimPajama dataset for pre-training a Llama-3.1-8B-based LLM. These scripts will be automatically copied to the remote filesystem. First, create a new sub-directory locally to save all of the files using this command:

mkdir -p data_prep

The four scripts that need to be created are as follows:

Download

The first script downloads the entire SlimPajama-627B training dataset from Hugging Face to the remote filesystem. The dataset is spread across nearly 60,000 individual shards, each of which must be downloaded separately. To make the process faster, the job leverages PyTorch distributed communication to spread the download equally amongst all workers in the cluster. Using the local VS Code session created previously, save the following file in the local directory at data_prep/download.py.

Note

The dataset is evenly divided amongst ten chunks on Hugging Face, each being its own subdirectory of files. The download.py script below has a CHUNKS = 10 variable at the top of the file to download all ten chunks. If desired, this value can be reduced to only download the first N chunks of the dataset. This is useful for quick workload validations that don't rely on a complete dataset. The remainder of this document will assume all ten chunks are pulled from the dataset, but the steps will still work if using fewer chunks.

import os
import requests
import time

CHUNKS = 10
SHARDS = 6000

def download_shard(url, filename, retry=False):
    if os.path.exists(filename):
        return

    response = requests.get(url)

    # In case of getting rate-limited, wait 3 seconds and retry the
    # download once.
    if response.status_code == 429 and not retry:
        time.sleep(3)
        download_shard(url, filename, retry=True)
        return

    if response.status_code != 200:
        return

    with open(filename, 'wb') as fn:
        fn.write(response.content)

def split_shards(wsize):
    shards = []
    shards_to_download = list(range(SHARDS))

    for shard in range(wsize):
        idx_start = (shard * SHARDS) // wsize
        idx_end = ((shard + 1) * SHARDS) // wsize
        shards.append(shards_to_download[idx_start:idx_end])
    return shards

def download(directory=""):
    wrank = int(os.environ.get('RANK', 0))
    wsize = int(os.environ.get('WORLD_SIZE', 1))  # Default to a single worker if WORLD_SIZE is unset

    # Every rank ensures the target directory exists to avoid a race at startup
    os.makedirs(directory, exist_ok=True)

    for chunk in range(1, CHUNKS + 1):
        shards_to_download = split_shards(wsize)

        for shard in shards_to_download[wrank]:
            filename = f'example_train_chunk{chunk}_shard{shard}.jsonl.zst'
            filename = os.path.join(directory, filename)
            url = f'https://huggingface.co/datasets/cerebras/SlimPajama-627B/resolve/main/train/chunk{chunk}/example_train_{shard}.jsonl.zst'
            download_shard(url, filename)

Extract

The individual dataset shards are compressed in the Zstandard or .zst format and must be decompressed. The following script distributes the downloaded files across all ranks, decompresses the shards, and then removes the compressed downloads to keep the filesystem clean. Using the local VS Code session, save the script in the local directory as data_prep/extract.py.

import os
from glob import glob
import zstandard as zstd


def split_shards(wsize, dataset):
    shards = []

    for shard in range(wsize):
        idx_start = (shard * len(dataset)) // wsize
        idx_end = ((shard + 1) * len(dataset)) // wsize
        shards.append(dataset[idx_start:idx_end])
    return shards

def extract_shard(shard):
    extracted_filename = shard.replace(".zst", "")

    # Very rare scenario where another rank has already processed a shard
    if not os.path.exists(shard):
        return

    with open(shard, "rb") as in_file, open(extracted_filename, "wb") as out_file:
        dctx = zstd.ZstdDecompressor(max_window_size=2**27)
        reader = dctx.stream_reader(in_file)

        while True:
            chunk = reader.read(4096)
            if not chunk:
                break
            out_file.write(chunk)

    os.remove(shard)

def extract(directory=""):
    wrank = int(os.environ.get("RANK", 0))
    wsize = int(os.environ.get("WORLD_SIZE", 1))  # Default to a single worker if WORLD_SIZE is unset

    dataset = sorted(glob(os.path.join(directory, "example_train*zst")))
    shards_to_extract = split_shards(wsize, dataset)

    for shard in shards_to_extract[wrank]:
        extract_shard(shard)

Concatenate

Given the SlimPajama dataset contains nearly 60,000 files, it is helpful to concatenate them into fewer, larger files. Processing a smaller number of large files is much faster than handling a large number of small files, which eliminates potential data bottlenecks during the pre-training stage.

The following script takes 1,200 individual shards at a time and combines them into one large file, repeating for the entire dataset. Each rank concatenates a unique subsection of the dataset and deletes the individual shards in the end. Using the local VS Code session, save the script in the local directory as data_prep/concat.sh.

Note

The script combines 1,200 individual shards by default into a single file. For the complete dataset, this will yield 50 larger combined files representing the data, each being approximately 51 GB in size. To change how many shards are used in each file, increase or decrease the shards_per_file variable below. A larger number will result in fewer files that are larger in size. A smaller number will result in more files that are smaller in size.

#!/bin/bash
directory=$1
shards_per_file=1200
num_files=`find ${directory} -name 'example_train_chunk*.jsonl' | wc -l`
files=(${directory}/example_train_chunk*.jsonl)
rank=$RANK
world_size=$WORLD_SIZE

# Find the ceiling of the result
shards=$(((num_files+shards_per_file-1)/shards_per_file ))

echo "Creating ${shards} combined chunk(s) comprising ${shards_per_file} files each"

for ((i=0; i<$shards; i++)); do
  if (( (( $i - $rank )) % $world_size )) ; then
    continue
  fi
  file_start=$((i*shards_per_file))

  if [[ $(((i+1)*shards_per_file)) -ge ${#files[@]} ]]; then
    file_stop=$((${#files[@]}-1))
  else
    file_stop=$(((i+1)*shards_per_file))
  fi

  echo "  Building chunk $i with files $file_start to $file_stop"
  for file in "${files[@]:$file_start:$shards_per_file}"; do
    cat "$file" >> "${directory}/slim_pajama_${i}.jsonl"
  done
  rm ${files[@]:$file_start:$shards_per_file}
done

Preprocess

Once all of the files have been concatenated, it is time to preprocess the dataset. The preprocessing phase tokenizes each dataset file using the Llama-3.1-8B tokenizer, which is downloaded from Hugging Face, and creates .bin and .idx files for each concatenated file. As with the other scripts, this one divides the work amongst all available workers to speed up preprocessing. Using the local VS Code session, save the following script in the local directory as data_prep/preprocess.py.

Note

As mentioned, this script uses the Llama-3.1-8B tokenizer because the intent is to use this data for pre-training a Llama-3.1-8B model. However, the tokenizer can be swapped out for a different one available on Hugging Face if pre-training a different model is desired.

For example, the Mixtral-8x7B tokenizer from MistralAI can be used instead. To do this, replace both references to meta-llama/Meta-Llama-3.1-8B in the script with the repo ID of the Mixtral-8x7B model, mistralai/Mixtral-8x7B-v0.1. Additionally, update the tokenizer filename and path to match that model repo, where the tokenizer is stored at the repository root as filename=tokenizer.model.

Be sure to accept any applicable licenses on the model repository page.
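
As a reference for that swap, the following is a minimal sketch of the changed tokenizer download call; the repo ID and filename come from the note above, while the local directory is an illustrative path on the remote filesystem:

from huggingface_hub import hf_hub_download

# Sketch of the tokenizer download when targeting Mixtral-8x7B instead of Llama-3.1-8B.
# The local_dir value below is an illustrative choice.
hf_hub_download(
    repo_id="mistralai/Mixtral-8x7B-v0.1",
    filename="tokenizer.model",  # Stored at the repository root for this model
    local_dir="/nemo-workspace/tokenizers/mixtral-8x7b"
)

The --tokenizer-type and --tokenizer-model arguments in the preprocessing command would need to change to match. The full preprocess.py script follows.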

import os
import subprocess
from glob import glob

from huggingface_hub import hf_hub_download


def split_shards(wsize, dataset):
    shards = []

    for shard in range(wsize):
        idx_start = (shard * len(dataset)) // wsize
        idx_end = ((shard + 1) * len(dataset)) // wsize
        shards.append(dataset[idx_start:idx_end])
    return shards

def preprocess(directory=""):
    wrank = int(os.environ.get("RANK", 0))
    wsize = int(os.environ.get("WORLD_SIZE", 1))

    dataset = sorted(glob(os.path.join(directory, "slim_pajama*jsonl")))
    shards_to_extract = split_shards(wsize, dataset)

    if wrank == 0:
        # Download a specific file from a repository
        hf_hub_download(
            repo_id="meta-llama/Meta-Llama-3.1-8B",
            filename="original/tokenizer.model",
            local_dir="/nemo-workspace/tokenizers/llama-3.1-8b"
        )

    for num, shard in enumerate(shards_to_extract[wrank]):
        shard_num = wrank + (num * wsize)  # Counter for which file is processed
        output_path = os.path.join(directory, f"llama-slim-pajama-{shard_num}")
        command = (
            "python3 /opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py "
            f"--input {shard} "
            f"--output-prefix {output_path} "
            f"--dataset-impl mmap "
            f"--tokenizer-type meta-llama/Meta-Llama-3.1-8B "
            f"--tokenizer-library huggingface "
            f"--tokenizer-model /nemo-workspace/tokenizers/llama-3.1-8b/original/tokenizer.model "
            f"--workers 80"
        )
        subprocess.run([command], shell=True)

Data Prep

A final script needs to be written to launch all of the data preparation jobs on the cluster. This script uses NeMo-Run to connect to the DGX Cloud Lepton cluster and run distributed PyTorch jobs directly on it. The jobs are launched sequentially in the order they are added to the experiment. Using the local VS Code session, save the following script locally as data-prep.py.

Several lines in the script below will need to be modified to reflect your cluster. The lines are as follows:

  • resource_shape="gpu.a100-80gb": Replace gpu.a100-80gb with the desired resource shape. This is the GPU type and configuration to use for the job; for example, gpu.8xh100-80gb refers to a pod with 8x H100 GPUs.
  • node_group="xxxxx": Replace xxxxx with the node group to run in. The list of available node groups can be found in the Nodes tab in the UI.
  • "HF_TOKEN": "xxxxxxxxxxxxxxxxxx": Add your Hugging Face authentication token between the quotation marks.
  • executor = lepton_executor(nodes=8, devices=1): The example runs on 8 pods with 1 process per node. If more nodes/processes are required, specify the amount here.
  • "from": "local:nfs": If using remote shared storage, enter the name of the storage to mount in all jobs. This can be found in the UI while creating a job and selecting a storage option.

import nemo_run as run

from data_prep.download import download
from data_prep.extract import extract
from data_prep.preprocess import preprocess

def lepton_executor(nodes: int = 1, devices: int = 1) -> run.LeptonExecutor:
    mounts = [
        {
            "path": "/",  # Directory to mount from the remote filesystem
            "mount_path": "/nemo-workspace"  # Where to mount the directory in pods
            "from": "local:nfs"  # (Optional) Which remote storage resource to mount
        }
    ]

    return run.LeptonExecutor(
        resource_shape="gpu.a100-80gb",  # Replace with the resource shape for the node group
        container_image="nvcr.io/nvidia/nemo:25.02",  # Which container to deploy
        nemo_run_dir="/nemo-workspace/nemo-run",  # Specify the NeMo-Run directory to copy experiments to in the remote filesystem
        mounts=mounts,  # Which directories to mount from the remote filesystem
        node_group="xxxxx",  # Replace with the name of the node group available in the cluster
        nodes=nodes,  # Number of nodes to run on
        nprocs_per_node=devices,  # Number of processes per node to use
        env_vars={
            "HF_TOKEN": "xxxxxxxxxxxxxxxxxx",  # Add your Hugging Face API token here
            "TORCH_HOME": "/nemo-workspace/.cache"  # Save downloaded models and tokenizers to the remote storage cache
        },
        launcher="torchrun",  # Use torchrun to launch the processes
        packager=run.PatternPackager(  # Copy the data prep scripts to the filesystem for execution
            include_pattern="data_prep/*",
            relative_path=""
        )
    )

def prepare_slim_pajama():
    executor = lepton_executor(nodes=8, devices=1)

    # Create a NeMo-Run experiment which runs all sub-steps sequentially
    with run.Experiment("slim-pajama-data-prep") as exp:
        exp.add(run.Partial(download, "/nemo-workspace/data"), name="download", executor=executor)
        exp.add(run.Partial(extract, "/nemo-workspace/data"), name="extract", executor=executor)
        exp.add(run.Script("/nemo_run/code/data_prep/concat.sh", args=["/nemo-workspace/data"]), name="concat", executor=executor)
        exp.add(run.Partial(preprocess, "/nemo-workspace/data"), name="preprocess", executor=executor)

        # Launch the experiment on the cluster
        exp.run(sequential=True)

if __name__ == "__main__":
    prepare_slim_pajama()

Launch Data Preparation

Once all the scripts are saved in the specified location, it is time to launch the preprocessing job. NeMo-Run will launch the job automatically on the cluster, so starting data preparation is as simple as running a Python command. Launch data preparation with the following command in the terminal of your local VS Code session:

chmod +x data_prep/concat.sh
python3 data-prep.py

After the data preparation job is created, a pod for the primary and each worker will be scheduled and started once resources become available on the cluster. The process can be monitored by viewing the logs in the DGX Cloud Lepton UI. The /nemo-workspace/data directory will evolve throughout the process, with the following changes at the end of each stage (a quick progress check is sketched after the list):

  • After downloading, there will be 59,166 compressed data shards named example_train_chunkX_shardY.jsonl.zst where X is the chunk number from 1-10 and Y is the individual shard number within that chunk. Each file is approximately 15 MB in size.
  • After extraction, there will be 59,166 decompressed data shards named example_train_chunkX_shardY.jsonl and all of the compressed .zst files will be removed. Each file is approximately 44 MB in size.
  • After concatenation, there will be 50 large, combined files named slim_pajama_N.jsonl where N ranges from 0-49. Each file will be approximately 51 GB in size. It is normal for the last file to be smaller in size as it doesn't contain an even 1,200 shards. All of the individual example_train* files will be removed.
  • After preprocessing, there will be 50 .bin files and 50 .idx files named llama-slim-pajama-N_text_document, where N corresponds to the combined data file number. Each .bin file should be approximately 26 GB in size and .idx files should be 229 MB.
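
The following optional snippet, run from any environment that mounts the workspace at /nemo-workspace, counts the files present for each stage; it is a convenience check and not part of the workflow scripts:

from glob import glob

# Optional progress check, assuming the shared filesystem is mounted at /nemo-workspace
data_dir = "/nemo-workspace/data"

print("compressed shards  :", len(glob(f"{data_dir}/example_train_chunk*.jsonl.zst")))
print("extracted shards   :", len(glob(f"{data_dir}/example_train_chunk*.jsonl")))
print("concatenated files :", len(glob(f"{data_dir}/slim_pajama_*.jsonl")))
print("preprocessed files :", len(glob(f"{data_dir}/llama-slim-pajama-*.bin")))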

Once all 50 files have been preprocessed, it is time to begin pre-training the model.

Pre-Train the Model

NeMo Framework contains many predefined configurations for various models, including the Llama 3.1-8B model. This section will demonstrate how to initiate training a Llama 3.1-8B model on DGX Cloud Lepton using the preprocessed SlimPajama dataset.

Pre-training is the most compute-intensive phase of the LLM training process as the model is typically trained for hundreds of billions to trillions of tokens while it learns the vocabulary and word pairings of the underlying dataset. Depending on the size of the dataset and model as well as the amount of compute resources available to train the model, this process can take anywhere from several days to a few months to finish. Therefore it is strongly recommended to leverage as much of your available compute power as possible for pre-training the model.

Set Up the Environment

Now the training job can be defined. The following script is used to launch pre-training of a Llama 3.1-8B model for 627B tokens using the SlimPajama dataset that was prepared. Save the script to llama-pretrain.py locally using your VS Code session. Note, as with data preparation earlier, several lines will need to be modified to reflect your cluster. These lines are as follows:

  • resource_shape="gpu.8xh100-80gb": Replace gpu.8xh100-80gb with the desired resource shape. This is the GPU type and configuration to use for the job; for example, gpu.8xh100-80gb refers to a pod with 8x H100 GPUs.
  • node_group="xxxxx": Replace xxxxx with the node group to run in. The list of available node groups can be found in the Nodes tab in the UI.
  • "HF_TOKEN": "xxxxxxxxxxxxxxxxxx": Add your Hugging Face authentication token between the quotation marks.
  • "WANDB_API_KEY": "xxxxxxxxxxxxxxxxxx": Add your Weights & Biases authentication token between the quotation marks.
  • "from": "local:nfs": If using remote shared storage, enter the name of the storage to mount in all jobs. This can be found in the UI while creating a job and selecting a storage option.

import os
import nemo_run as run

from nemo.collections import llm
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule
from nemo.collections.llm.recipes.log.default import default_log, wandb_logger
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.utils.exp_manager import TimingCallback

from scripts.convert import convert_checkpoint  # Checkpoint conversion helper defined in scripts/convert.py below


def configure_recipe(nodes: int = 1, gpus_per_node: int = 2, dir=None, name="nemo", data_dir="/nemo-workspace/data"):
    # Build the list of preprocessed SlimPajama files created during data preparation
    paths = [os.path.join(data_dir, f"llama-slim-pajama-{num}_text_document") for num in range(30)]
    tokenizer = run.Config(AutoTokenizer, pretrained_model_name="meta-llama/Llama-3.1-8B")

    data=run.Config(
        PreTrainingDataModule,
        paths=paths,
        seq_length=8192,  # Use a sequence length or context window of 8K tokens
        global_batch_size=512,  # Batch size of 512
        micro_batch_size=1,
        tokenizer=tokenizer
    )

    wandb = wandb_logger(
        project="llama-3.1",  # Specify the Weights & Biases project name
        name="llama-3.1-8b"  # Specify the name of the training run to be displayed on Weights & Biases
    )

    recipe = run.Partial(
        llm.pretrain,  # Specify that we want to use the Pre-train method
        model=llm.llama31_8b.model(),  # Use the existing Llama 3.1-8B model config default settings
        trainer=llm.llama31_8b.trainer(
            num_nodes=nodes,
            num_gpus_per_node=gpus_per_node,
            max_steps=150000,  # Train for 150,000 steps - equal to 150,000 * batch size (512) * sequence length (8192) = 629B tokens
            callbacks=[run.Config(TimingCallback)],
        ),
        data=data,
        optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4),
        log=default_log(dir=dir, name=name, wandb_logger=wandb),
    )

    recipe.trainer.val_check_interval = 2000  # Run evaluation and save a checkpoint every 2,000 steps
    recipe.trainer.strategy.tensor_model_parallel_size = 4  # Set the Tensor Parallelism size to 4
    return recipe

def lepton_executor(nodes: int = 1, devices: int = 1) -> run.LeptonExecutor:
    mounts = [
        {
            "path": "/nemo-workspace",  # Directory to mount from the remote filesystem
            "mount_path": "/nemo-workspace"  # Where to mount the directory in pods
            "from": "local:nfs"  # (Optional) Which remote storage resource to mount
        }
    ]

    return run.LeptonExecutor(
        resource_shape="gpu.8xh100-80gb",  # Replace with the resource shape for the node group
        container_image="nvcr.io/nvidia/nemo:25.02",  # Which container to deploy
        nemo_run_dir="/nemo-workspace/nemo-run",  # Specify the NeMo-Run directory to copy experiments to in the remote filesystem
        mounts=mounts,  # Which directories to mount from the remote filesystem
        node_group="xxxxx",  # Replace with the name of the node group available in the cluster
        nodes=nodes,  # Number of nodes to run on
        nprocs_per_node=devices,  # Number of processes per node to use
        env_vars={
            "PYTHONPATH": "/nemo-workspace/nemo-run:$PYTHONPATH",  # Add the NeMo-Run directory to the PYTHONPATH
            "TORCH_HOME": "/nemo-workspace/.cache",  # Save downloaded models and tokenizers to the remote storage cache
            "HF_TOKEN": "xxxxxxxxxxxxxxxxxx",  # Add your Hugging Face API token here
            "WANDB_API_KEY": "xxxxxxxxxxxxxxxxxx"  # Add your Weights & Biases API token here
        },
        launcher="torchrun",  # Use torchrun to launch the processes
        packager=run.PatternPackager(  # Copy the data prep and conversion scripts to the filesystem for execution
            include_pattern=["data_prep/*", "scripts/*"],
            relative_path=["", ""]
        )
    )

def run_pretraining():
    recipe = configure_recipe(nodes=8, gpus_per_node=8, dir="/nemo-workspace/llama-3.1-8b", name="llama-3.1-8b")
    executor = lepton_executor(nodes=recipe.trainer.num_nodes, devices=recipe.trainer.devices)

    run.run(recipe, executor=executor)

    # Re-initialize the executor as only a single GPU is needed for conversion
    executor = lepton_executor(nodes=1, devices=1)

    # Convert the final pre-trained checkpoint to Hugging Face format on the cluster
    run.run(run.Partial(convert_checkpoint, "/nemo-workspace/llama-3.1-8b"), name="convert-model", executor=executor)

if __name__ == "__main__":
    run_pretraining()

Depending on how many resources you have available, you can also change the number of nodes used for pre-training by modifying this line:

recipe = configure_recipe(nodes=8, gpus_per_node=8, dir="/nemo-workspace/llama-3.1-8b", name="llama-3.1-8b")

Update nodes=8 to the desired number of nodes to train with. Keep gpus_per_node at 8, as using all GPUs in each node allows optimal multi-node communication over NCCL.

Additionally, a Python script needs to be created which converts the model to Hugging Face format once training finishes. Create a new directory named scripts using:

mkdir -p scripts

Copy and save the following Python script to scripts/convert.py:

import os

from nemo.collections import llm


def last_checkpoint(directory=""):
    checkpoints = []

    for root, dirs, _ in os.walk(directory):
        for dir in dirs:
            if dir.endswith("-last"):
                checkpoints.append(os.path.join(root, dir))
    # Return the most recent checkpoint found
    return max(checkpoints, key=os.path.getmtime)

def convert_checkpoint(dir=""):
    checkpoint = last_checkpoint(dir)

    llm.export_ckpt(
        path=checkpoint,
        target="hf",
        overwrite=True,
        output_path=f"{dir}/huggingface"
    )

This script runs after the model completes pre-training: it finds the final checkpoint in the training directory and converts it to Hugging Face format so it can be used for downstream tasks.
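
If the conversion ever needs to be re-run on its own, for example against an older training directory, a minimal hypothetical invocation inside a NeMo container that mounts the workspace is:

from scripts.convert import convert_checkpoint

# Hypothetical standalone invocation; assumes the workspace is mounted at /nemo-workspace
# and a single GPU is available for the export.
convert_checkpoint(dir="/nemo-workspace/llama-3.1-8b")
# The converted model is written to /nemo-workspace/llama-3.1-8b/huggingface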

Launch the Pre-Training Job

After modifying and saving the llama-pretrain.py script locally and saving the conversion script, launch the pre-training job from the terminal in your local VS Code session using the following command:

python3 llama-pretrain.py

Note

Make sure your Python virtual environment is activated before running this command.

The job will be scheduled with DGX Cloud Lepton and will launch once resources become available. After submission, the job will appear in the DGX Cloud Lepton Batch Jobs page.

NeMo Framework is fully integrated with Weights & Biases and logs multiple metrics that can be viewed on the W&B website. If the W&B API key was provided in the executor environment variables, a new W&B project will automatically be created and metrics will be uploaded there. Viewing logs on W&B is the recommended way to monitor training progress.

View the Project Dashboard on Weights & Biases

To view your charts, navigate to https://wandb.ai. You should see a link to the newly created project on your home page. Clicking the link will take you to your project dashboard which should look similar to the following. Note that the figure below includes training results for two different runs where the second run is a continuation of the first.

Weights & Biases

Two of the most important charts to monitor during pre-training are the reduced_train_loss and val_loss charts which show how the model is learning over time. In general, these charts should have an exponential decay shape.

The job will take around four weeks to complete on 8 nodes. Since NeMo Framework pre-training scales linearly, doubling the number of nodes should halve the amount of time required to pre-train the model.

While the model trains, a checkpoint will be saved every 2,000 steps to the mounted remote filesystem. Per the configuration above, the checkpoints will be saved in the /nemo-workspace/llama-3.1-8b/llama-3.1-8b/<date>/checkpoints directory, where <date> is a timestamp of when the job was launched. Only the 10 checkpoints with the best val_loss values, as well as the latest checkpoint, will be saved. These checkpoints will be used for future fine-tuning runs.

After pre-training finishes, another task will begin to convert the final pre-trained model checkpoint to the Hugging Face format. This spins up another pod with a single GPU which is required for conversion. The final Hugging Face model will be saved at /nemo-workspace/llama-3.1-8b/huggingface. The converted Hugging Face model can be deployed as a NIM for inference.
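
As a quick sanity check, the exported model should load with the standard Hugging Face transformers APIs. The following is a minimal sketch, assuming the transformers package is installed and the workspace is mounted at /nemo-workspace:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Minimal sanity check of the exported checkpoint (sketch, not part of the workflow scripts).
model_dir = "/nemo-workspace/llama-3.1-8b/huggingface"

# If tokenizer files were not exported alongside the weights, load the tokenizer
# from the meta-llama/Llama-3.1-8B repository instead.
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir)

inputs = tokenizer("The quick brown fox", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))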

Copyright © 2025, NVIDIA Corporation.