NeMo Command Line Interface#
Introduction#
NeMo provides a powerful command line interface (CLI) that makes it easy to train, fine-tune, evaluate, and deploy models. The CLI follows a consistent pattern that makes it straightforward to use once you understand the basic structure:
nemo [collection] [command] [options]
Where:
collection is the model collection (e.g., llm)
command is the action to perform (e.g., pretrain, finetune, generate)
options are additional parameters to customize the command
Note
Currently, the NeMo 2.0 CLI is only available for the LLM collection; more collections will be supported in future releases.
Basic Usage#
To see available commands within a collection, use the help flag:
$ nemo llm --help
Usage: nemo llm [OPTIONS] COMMAND [ARGS]...
[Module] llm
╭─ Options ────────────────────────────────────────────────────────────────╮
│ --help Show this message and exit. │
╰──────────────────────────────────────────────────────────────────────────╯
╭─ Commands ───────────────────────────────────────────────────────────────╮
│ train [Entrypoint] train │
│ pretrain [Entrypoint] pretrain │
│ finetune [Entrypoint] finetune │
│ validate [Entrypoint] validate │
│ prune [Entrypoint] prune │
│ distill [Entrypoint] distill │
│ ptq [Entrypoint] ptq │
│ deploy [Entrypoint] deploy │
│ import [Entrypoint] import │
│ export [Entrypoint] export │
│ generate [Entrypoint] generate │
╰──────────────────────────────────────────────────────────────────────────╯
Each command represents a different task you can perform with NeMo models.
Pre-training Models#
The pretrain command allows you to train language models from scratch using pre-configured recipes.
Listing Available Recipes#
To see all available pre-training recipes:
$ nemo llm pretrain --help
Usage: nemo llm pretrain [OPTIONS] [ARGUMENTS]
[Entrypoint] pretrain
Pretrains a model using the specified data and trainer, with optional logging, resuming, and optimization.
This function is a wrapper around the `train` function, specifically configured for pretraining tasks.
Note, by default it will use the tokenizer from the model.
╭─ Pre-loaded entrypoint factories, run with --factory ──────────────────────────────────────╮
│ baichuan2_7b nemo.collections.llm.recipes.baichuan2_7b.pr… line 142 │
│ baichuan2_7b_optimized nemo.collections.llm.recipes.baichuan2_7b.pr… line 190 │
│ bert_110m nemo.collections.llm.recipes.bert_110m.pretr… line 50 │
│ bert_340m nemo.collections.llm.recipes.bert_340m.pretr… line 50 │
│ chatglm3_6b nemo.collections.llm.recipes.chatglm3_6b.pre… line 142 │
│ chatglm3_6b_optimized nemo.collections.llm.recipes.chatglm3_6b.pre… line 190 │
│ deepseek_v2 nemo.collections.llm.recipes.deepseek_v2.pre… line 54 │
│ deepseek_v2_lite nemo.collections.llm.recipes.deepseek_v2_lit… line 54 │
│ gemma2_2b nemo.collections.llm.recipes.gemma2_2b.pretr… line 53 │
│ gemma2_9b nemo.collections.llm.recipes.gemma2_9b.pretr… line 53 │
│ llama3_8b nemo.collections.llm.recipes.llama3_8b.pretr… line 145 │
│ llama3_70b nemo.collections.llm.recipes.llama3_70b.pret… line 145 │
│ mixtral_8x7b nemo.collections.llm.recipes.mixtral_8x7b.pr… line 143 │
│ nemotron3_8b nemo.collections.llm.recipes.nemotron3_8b.pr… line 56 │
│ nemotron4_15b nemo.collections.llm.recipes.nemotron4_15b.p… line 55 │
│ ... (output truncated) │
╰────────────────────────────────────────────────────────────────────────────────────────────╯
Running Pre-training with Default Recipes#
To start pre-training with a default recipe:
$ nemo llm pretrain --factory llama3_8b
This command will configure and start pre-training a Llama 3 8B model using the default settings. The output will show a preview of the resolved configuration values before starting.
With the --dryrun flag, you can preview the configuration without starting the training:
$ nemo llm pretrain --factory llama3_8b --dryrun
Configuring global options
Dry run for task nemo.collections.llm.api:pretrain
Resolved Arguments
┏━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Argument Name ┃ Resolved Value ┃
┡━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ data │ MockDataModule(seq_length=8192, micro_batch_size=1, │
│ │ global_batch_size=512) │
│ model │ LlamaModel(config=Llama3Config8B()) │
│ trainer │ Trainer( │
│ │ accelerator='gpu', │
│ │ strategy=MegatronStrategy( │
│ │ tensor_model_parallel_size=1, │
│ │ pipeline_model_parallel_size=1, │
│ │ context_parallel_size=2, │
│ │ sequence_parallel=False, │
│ │ ), │
│ │ devices=8, │
│ │ num_nodes=1, │
│ │ max_steps=1168251, │
│ │ ) │
│ ... (output truncated for brevity) │
└──────────────────────┴──────────────────────────────────────────────────────────────┘
Customizing Factory Parameters#
You can pass parameters directly to the factory function:
$ nemo llm pretrain --factory "llama3_70b(num_nodes=128)"
This example configures the Llama 3 70B model to use 128 nodes for distributed training.
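The string passed to --factory is handed to the registered recipe factory, so the keyword arguments are forwarded to that factory function (for llama3_70b, the pretrain recipe module shown in the help output above). The snippet below is a rough, illustrative Python analogue of the command above, a sketch of the mapping rather than a reproduction of what the CLI does internally:
# Illustrative Python analogue of `--factory "llama3_70b(num_nodes=128)"`:
# the keyword arguments in the factory string are forwarded to the recipe factory.
from nemo.collections.llm.recipes import llama3_70b

recipe = llama3_70b.pretrain_recipe(num_nodes=128)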
Overriding Configuration Parameters#
The CLI supports overriding any configuration parameter using Hydra-style dot notation:
$ nemo llm pretrain --factory llama3_70b trainer.max_steps=2000
This syntax follows the pattern component.parameter=value, allowing you to navigate nested configurations. You can override multiple parameters at once by adding more space-separated overrides:
$ nemo llm pretrain --factory llama3_70b trainer.max_steps=2000 optim.config.lr=5e-5 data.global_batch_size=256
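These dot-notation paths mirror attributes of the resolved recipe configuration. As a rough, hedged Python analogue (assuming the recipe object returned by the llama3_70b pretrain factory exposes trainer, optim, and data attributes, as the REPL example below suggests), the same overrides could be written as:
# Rough Python analogue of the command-line overrides above; the attribute
# paths mirror the dot notation (the exact structure depends on the recipe).
from nemo.collections.llm.recipes import llama3_70b

recipe = llama3_70b.pretrain_recipe(num_nodes=1, num_gpus_per_node=8)
recipe.trainer.max_steps = 2000        # trainer.max_steps=2000
recipe.optim.config.lr = 5e-5          # optim.config.lr=5e-5
recipe.data.global_batch_size = 256    # data.global_batch_size=256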
Interactive Configuration with REPL Mode#
For interactive recipe modification, you can use the --repl flag:
$ nemo llm pretrain --factory llama3_70b --repl
This command opens an interactive Python REPL (Read-Eval-Print Loop) where you can:
Inspect the entire configuration
Modify parameters interactively
Test different settings before launching the training job
Execute custom Python code to set up your configuration
Inside the REPL, you’ll have access to all the components of the configuration (model, data, trainer, etc.) and can modify them directly:
# Example of what you might do in the REPL
# View the model config
print(model.config)
# Modify learning rate
optim.config.lr = 2e-5
# Change the number of training steps
trainer.max_steps = 5000
# Start the training when ready
run()
Fine-tuning Models#
Similar to pre-training, NeMo provides recipes for fine-tuning models:
$ nemo llm finetune --help
Usage: nemo llm finetune [OPTIONS] [ARGUMENTS]
[Entrypoint] finetune
Finetunes a model using the specified data and trainer, with optional logging, resuming, and PEFT.
Note, by default it will use the tokenizer from the model.
╭─ Pre-loaded entrypoint factories, run with --factory ──────────────────────────────────────╮
│ baichuan2_7b nemo.collections.llm.recipes.baichuan2_7b.fi… line 236 │
│ chatglm3_6b nemo.collections.llm.recipes.chatglm3_6b.fin… line 236 │
│ deepseek_v2 nemo.collections.llm.recipes.deepseek_v2.fin… line 108 │
│ deepseek_v2_lite nemo.collections.llm.recipes.deepseek_v2_lit… line 107 │
│ gemma2_2b nemo.collections.llm.recipes.gemma2_2b.finet… line 173 │
│ gemma2_9b nemo.collections.llm.recipes.gemma2_9b.finet… line 173 │
│ llama2_7b nemo.collections.llm.recipes.llama2_7b.finet… line 230 │
│ llama3_8b nemo.collections.llm.recipes.llama3_8b.finet… line 245 │
│ llama3_70b nemo.collections.llm.recipes.llama3_70b.fine… line 251 │
│ mixtral_8x7b nemo.collections.llm.recipes.mixtral_8x7b.fi… line 240 │
│ nemotron3_8b nemo.collections.llm.recipes.nemotron3_8b.fi… line 253 │
│ nemotron4_15b nemo.collections.llm.recipes.nemotron4_15b.f… line 227 │
│ ... (output truncated) │
╰────────────────────────────────────────────────────────────────────────────────────────────╯
The available models for fine-tuning include a wide range of architectures:
Llama 2 and Llama 3 family
Nemotron 3 and Nemotron 4 family
Mixtral and other mixture-of-experts models
Mamba2 models including SSM and hybrid architectures
Encoder-decoder models like T5
And many more
Fine-tuning recipes include support for Parameter-Efficient Fine-Tuning (PEFT) methods. Notice that the finetune command has an additional peft argument compared to the pretrain command.
To fine-tune a model:
$ nemo llm finetune --factory llama3_8b
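A fine-tuning recipe can also be wrapped in a Python script, in the same way as the custom pre-training recipes shown in the next section. The following is a minimal, hypothetical sketch: it assumes the llama3_8b finetune recipe factory accepts a peft_scheme argument (such as 'lora' or 'none'); check nemo llm finetune --help for the exact parameters your installed version supports.
# custom_finetune.py -- a minimal sketch; the `peft_scheme` argument is an
# assumption here, so verify the factory signature with `nemo llm finetune --help`.
import nemo_run as run
from nemo.collections import llm
from nemo.collections.llm.recipes import llama3_8b


def lora_llama3_8b():
    finetune = llama3_8b.finetune_recipe(
        num_nodes=1,
        num_gpus_per_node=8,
        peft_scheme="lora",  # assumed switch for LoRA-based PEFT
    )
    finetune.trainer.max_steps = 1000
    return finetune


if __name__ == "__main__":
    run.cli.main(llm.finetune, default_factory=lora_llama3_8b)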
Creating and Running Custom Recipes#
You can create custom recipes in Python scripts that use the same CLI interface. Here’s how a custom recipe might look:
# custom_recipe.py
import nemo_run as run
from nemo.collections import llm
from nemo.collections.llm.recipes import llama3_8b, llama3_70b


def custom_llama3_8b():
    pretrain = llama3_8b.pretrain_recipe(num_nodes=1, num_gpus_per_node=8)
    pretrain.trainer.val_check_interval = 400
    pretrain.log.ckpt.save_top_k = -1
    pretrain.log.ckpt.every_n_train_steps = 400
    pretrain.trainer.max_steps = 1000
    return pretrain


def custom_llama3_70b():
    pretrain = llama3_70b.pretrain_recipe(num_nodes=1, num_gpus_per_node=8)
    pretrain.trainer.val_check_interval = 400
    pretrain.log.ckpt.save_top_k = -1
    pretrain.log.ckpt.every_n_train_steps = 400
    pretrain.trainer.max_steps = 1000
    return pretrain


if __name__ == "__main__":
    # When running this file, it will run the `custom_llama3_8b` recipe.
    # To select the `custom_llama3_70b` recipe, use the following command:
    #   python custom_recipe.py --factory custom_llama3_70b
    # This will automatically call the custom_llama3_70b that's defined above.
    # Note that any parameter can be overwritten by using the following syntax:
    #   python custom_recipe.py trainer.max_steps=2000
    # You can even apply transformations when triggering the CLI as if it's Python code:
    #   python custom_recipe.py "trainer.max_steps*=2"
    run.cli.main(llm.pretrain, default_factory=custom_llama3_8b)
When you run custom_recipe.py, it executes the custom_llama3_8b recipe by default. However, you can select different recipes or modify parameters:
To select the custom_llama3_70b recipe:
python custom_recipe.py --factory custom_llama3_70b
To overwrite any parameter:
python custom_recipe.py trainer.max_steps=2000
You can even apply transformations:
python custom_recipe.py "trainer.max_steps*=2"  # Doubles the max_steps value
Text Generation#
NeMo provides a generate command for inference with trained models:
$ nemo llm generate
This command is used for text generation with trained NeMo LLM models. It takes a checkpoint path and a list of prompts, generates text based on the loaded model and parameters, and returns the generated text.
The command supports parameters like:
path: Path to the model checkpoint
trainer: NeMo trainer configuration
prompts: List of input prompts for generation
inference_params: Generation parameters like temperature, top_k, and number of tokens to generate
text_only: Whether to return only text or also metadata
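As with the other entrypoints, generate can be wrapped in a small Python script and driven from the command line. The sketch below is hypothetical and only illustrates the pattern: it assumes llm.generate can be passed to run.cli.main in the same way as llm.pretrain in the custom recipe example above, with the parameters listed above then available as command-line overrides.
# generate_cli.py -- a minimal, hypothetical sketch; assumes llm.generate can be
# exposed through run.cli.main like llm.pretrain in custom_recipe.py above.
import nemo_run as run
from nemo.collections import llm

if __name__ == "__main__":
    # Parameters such as `path` or `text_only` could then be set on the command
    # line, e.g.: python generate_cli.py path=/path/to/checkpoint text_only=True
    # (the override syntax for list-valued parameters such as `prompts` may vary)
    run.cli.main(llm.generate)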
Advanced Features#
Model Import and Export#
NeMo CLI provides commands for importing external models (like Hugging Face models) and exporting NeMo models:
$ nemo llm import --help # Import models from other frameworks
$ nemo llm export --help # Export NeMo models
Quantization and Pruning#
For model optimization, NeMo offers post-training quantization (PTQ) and pruning:
$ nemo llm ptq --help # Post-training quantization
$ nemo llm prune --help # Model pruning
Model Distillation#
For creating smaller, more efficient models:
$ nemo llm distill --help # Knowledge distillation
Integration with NeMo-Run#
NeMo seamlessly supports scaling to thousands of GPUs using NeMo-Run. For examples of launching large-scale experiments using NeMo-Run, refer to Quickstart with NeMo-Run.
The CLI allows you to specify custom execution environments by passing run.executor=..., which can be a factory for any of the executors supported by NeMo-Run. This powerful feature enables you to run your jobs in various environments like Docker containers or Slurm clusters without modifying your recipe code.
You can see what executors are available in your environment by using the --help flag with any command. The help output will show a section called “Registered executors” at the bottom:
$ nemo llm finetune --help
# ... other help output ...
╭─ Registered executors ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ torchrun nemo.collections.llm.recipes.run.executor.to… line 20 │
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
This shows that, in this case, there is a torchrun executor registered by default. You can reference it on the command line with run.executor=torchrun.
Here are some example executor factories you can define in your custom recipe:
# `run` is nemo_run, imported as in custom_recipe.py above; BASE_DIR, ACCOUNT,
# SLURM_PARTITION, USER, and SLURM_LOGIN_NODE are placeholder constants you
# define for your own environment.
import nemo_run as run


@run.cli.factory
@run.autoconvert
def docker() -> run.Executor:
    return run.DockerExecutor(
        container_image="nvcr.io/nvidia/nemo:dev",
        volumes=[
            f"{BASE_DIR}/opt/NeMo-Run:/opt/NeMo-Run",
            f"{BASE_DIR}/opt/NeMo:/opt/NeMo",
            f"{BASE_DIR}/opt/megatron-lm:/opt/Megatron-LM",
        ],
        env_vars={
            "HF_HOME": "/workspaces/models/hf",
            "NEMO_HOME": "/workspaces/models/nemo",
        },
    )


@run.cli.factory
@run.autoconvert
def slurm_cluster() -> run.Executor:
    return run.SlurmExecutor(
        account=ACCOUNT,
        partition=SLURM_PARTITION,
        job_name_prefix=f"{ACCOUNT}-nemo-ux:",
        job_dir=BASE_DIR,
        container_image="nvcr.io/nvidia/nemo:dev",
        container_mounts=[
            f"/home/{USER}:/home/{USER}",
            "/lustre:/lustre",
        ],
        time="4:00:00",
        gpus_per_node=8,
        tunnel=run.SSHTunnel(host=SLURM_LOGIN_NODE, user=USER, job_dir=BASE_DIR),
    )
With these executor factories defined, you can easily select which execution environment to use via the command line:
# Run in a Docker container
$ python custom_recipe.py run.executor=docker
# Run on a Slurm cluster
$ python custom_recipe.py run.executor=slurm_cluster
This approach provides tremendous flexibility, allowing you to develop recipes locally and then seamlessly deploy them to different computing environments without changing your code.
Learning More About the CLI#
If you’re interested in understanding the internals of the NeMo CLI and NeMo-Run CLI system, or want to create your own CLI entrypoints and experiments, you can find detailed examples and tutorials in the NeMo-Run entrypoint examples.
This repository includes:
Detailed explanations of CLI concepts like entrypoints, factories, and partials
Examples of creating single task entrypoints
Examples of creating experiment entrypoints for sequential and parallel execution
Advanced CLI features like Pythonic argument parsing and interactive configuration
Best practices for creating effective CLI interfaces
These examples provide deeper insight into how the NeMo CLI works and how you can leverage its features for your own custom workflows.
Summary#
The NeMo CLI provides a comprehensive interface for working with large language models:
Model Architecture Support: Supports a wide range of architectures including LLaMA, Mixtral, Nemotron, Mamba, T5, and many others
Training Options:
Pre-training from scratch
Fine-tuning existing models
Parameter-efficient fine-tuning (PEFT)
Configuration Flexibility:
Override any parameter using dot notation
Interactive configuration with REPL mode
Create custom recipes in Python
Scalability:
Supports training on thousands of GPUs
Integrates with NeMo-Run for cluster management
Deployment and Optimization:
Model export and import
Quantization
Pruning
Distillation
This design allows you to quickly experiment with different models and configurations without having to write custom training scripts from scratch.