GPT-OSS#
GPT-OSS is a family of open-weight models released by OpenAI, providing transparent and accessible large language models. The models are built on a Mixture-of-Experts (MoE) transformer decoder architecture with Sink Attention and alternating Sliding-Window Attention (SWA). The family includes two variants, GPT-OSS 20B and GPT-OSS 120B, designed to serve different computational requirements while maintaining high-quality text generation. Both models target agentic workflows, with strong instruction following, tool use such as web search and Python code execution, and reasoning capabilities, including an adjustable reasoning effort for tasks that do not require complex reasoning.
We provide pre-defined recipes for finetuning GPT-OSS models in two sizes, 20B and 120B, using NeMo 2.0 and NeMo-Run. These recipes configure a run.Partial for one of the nemo.collections.llm API functions introduced in NeMo 2.0.
Note
Please use the custom container nvcr.io/nvidia/nemo:25.07.gpt_oss when working with GPT-OSS, and make sure you update to the latest version of transformers.
NeMo 2.0 Finetuning Recipes#
Note
The finetuning recipes use the SquadDataModule for the data argument. You can replace the SquadDataModule with your custom dataset. Note that this model is a reasoning model with a specific chat template, so it is best to use a chat dataset with the use_hf_tokenizer_chat_template=True argument when finetuning.
To import the Hugging Face model and convert it to NeMo 2.0 format, first download the checkpoint and then run the import (this only needs to be done once):
cd <HF_MODEL_DIR>
apt-get update && apt-get install git-lfs
git lfs install
git clone https://huggingface.co/openai/gpt-oss-20b
git clone https://huggingface.co/openai/gpt-oss-120b
from nemo.collections import llm
# For GPT-OSS 20B
llm.import_ckpt(model=llm.GPTOSSModel(llm.GPTOSSConfig20B()), source='hf:///<HF_MODEL_DIR>/gpt-oss-20b')
# For GPT-OSS 120B
# llm.import_ckpt(model=llm.GPTOSSModel(llm.GPTOSSConfig120B()), source='hf:///<HF_MODEL_DIR>/gpt-oss-120b')
To import the original OpenAI checkpoint and convert to NeMo 2.0 format, run the following command (this only needs to be done once):
from nemo.collections import llm
# For GPT-OSS 20B
llm.import_ckpt(model=llm.GPTOSSModel(llm.GPTOSSConfig20B()), source='openai:///path/to/gpt-oss-20b')
# For GPT-OSS 120B
# llm.import_ckpt(model=llm.GPTOSSModel(llm.GPTOSSConfig120B()), source='openai:///path/to/gpt-oss-120b')
We provide an example below of how to invoke the default recipe and override the data argument:
from nemo.collections import llm
# For GPT-OSS 20B
recipe = llm.gpt_oss_20b.finetune_recipe(
    name="gpt_oss_20b_finetuning",
    dir="/path/to/checkpoints",
    num_nodes=1,
    num_gpus_per_node=8,
    peft_scheme='lora',  # 'lora', 'none'
)
# For GPT-OSS 120B
# recipe = llm.gpt_oss_120b.finetune_recipe(
#     name="gpt_oss_120b_finetuning",
#     dir="/path/to/checkpoints",
#     num_nodes=4,
#     num_gpus_per_node=8,
#     peft_scheme='lora',  # 'lora', 'none'
# )
# # To override the data argument
# dataloader = a_function_that_configures_your_custom_dataset(
#     gbs=gbs,
#     mbs=mbs,
#     seq_length=recipe.model.config.seq_length,
#     use_hf_tokenizer_chat_template=True,
# )
# recipe.data = dataloader
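As a concrete sketch of the placeholder a_function_that_configures_your_custom_dataset above, one option is to return a run.Config for a chat data module. The class name (llm.ChatDataModule), the dataset_root path, and the constructor arguments below are assumptions for illustration; only the use_hf_tokenizer_chat_template flag comes from the note earlier on this page, so verify the module and its arguments against your NeMo version.
import nemo_run as run
from nemo.collections import llm

def a_function_that_configures_your_custom_dataset(gbs, mbs, seq_length, use_hf_tokenizer_chat_template=True):
    # Hypothetical helper: build a run.Config for a chat-style finetuning data module.
    # llm.ChatDataModule and the argument names are illustrative; substitute the data
    # module that matches your dataset format.
    return run.Config(
        llm.ChatDataModule,
        dataset_root="/path/to/your/chat/dataset",  # directory with your training/validation files
        seq_length=seq_length,
        global_batch_size=gbs,
        micro_batch_size=mbs,
        # Apply the Hugging Face tokenizer's chat template, as recommended above.
        use_hf_tokenizer_chat_template=use_hf_tokenizer_chat_template,
    )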
By default, the finetuning recipe runs LoRA finetuning, with LoRA applied to the linear layers in the attention blocks of the language model. To finetune the entire model without LoRA, set peft_scheme='none' in the recipe argument.
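For example, a full-parameter finetuning run of the 20B model can be configured as follows. Keep in mind that full finetuning needs substantially more GPU memory than LoRA, so you may need to increase num_nodes or adjust parallelism for your hardware:
from nemo.collections import llm

# Full-parameter finetuning (no LoRA) of GPT-OSS 20B
recipe = llm.gpt_oss_20b.finetune_recipe(
    name="gpt_oss_20b_full_finetuning",
    dir="/path/to/checkpoints",
    num_nodes=1,
    num_gpus_per_node=8,
    peft_scheme='none',  # train all parameters instead of LoRA adapters
)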
Note
The configuration in the recipes is done using the NeMo-Run run.Config and run.Partial configuration objects. Please review the NeMo-Run documentation to learn more about its configuration and execution system.
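Because the recipe is a run.Partial whose fields are themselves run.Config objects, you can override nested attributes before launching. The attribute paths below are typical of NeMo 2.0 recipes but are shown only as a sketch; confirm them against your recipe before relying on them:
# Example overrides on the configured recipe (attribute paths are typical for
# NeMo 2.0 recipes; verify them for your version before running).
recipe.trainer.max_steps = 100          # shorten the run for a quick sanity check
recipe.trainer.val_check_interval = 50  # validate more often
recipe.optim.config.lr = 1e-4           # adjust the peak learning rate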
Once you have your final configuration ready, you can execute it on any of the NeMo-Run supported executors. The simplest is the local executor, which runs the finetuning locally in a separate process. You can use it as follows:
import nemo_run as run
run.run(recipe, executor=run.LocalExecutor())
Additionally, you can also run it directly in the same Python process as follows:
run.run(recipe, direct=True)
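For multi-node jobs such as the 120B recipe, you would typically launch through a cluster executor instead of the local one. The sketch below uses NeMo-Run's SlurmExecutor; the account, partition, tunnel settings, and paths are placeholders, and the exact executor arguments should be checked against the NeMo-Run documentation for your version:
import nemo_run as run

# Sketch of a Slurm launch; fill in values for your cluster.
executor = run.SlurmExecutor(
    account="<your_account>",
    partition="<your_partition>",
    nodes=4,                # match the recipe's num_nodes
    ntasks_per_node=8,      # one task per GPU
    time="04:00:00",
    container_image="nvcr.io/nvidia/nemo:25.07.gpt_oss",
    tunnel=run.SSHTunnel(host="<login_node>", user="<user>", job_dir="/path/to/remote/job/dir"),
)

run.run(recipe, executor=executor)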
Inference#
To run inference with GPT-OSS models, you can use the following command:
# For GPT-OSS 20B
torchrun --nproc-per-node=1 /opt/NeMo/scripts/llm/generate.py \
--model_path=<PATH_TO_NEMO2_MODEL> \
--devices=1 \
--num_tokens_to_generate=512 \
--temperature=0.0 \
--top_p=0.0 \
--top_k=1 \
--disable_flash_decode
# For GPT-OSS 120B
# torchrun --nproc-per-node=8 /opt/NeMo/scripts/llm/generate.py \
# --model_path=<PATH_TO_NEMO2_MODEL> \
# --ep=4 \
# --pp=2 \
# --devices=8 \
# --num_tokens_to_generate=512 \
# --temperature=0.0 \
# --top_p=0.0 \
# --top_k=1 \
# --disable_flash_decode
Export to HF#
After training or finetuning your GPT-OSS model, you can export it to Hugging Face format for easy sharing and deployment:
from pathlib import Path

from nemo.collections import llm

# Export NeMo checkpoint to Hugging Face format
llm.export_ckpt(
    target="hf",
    path=Path("<path_to_nemo_checkpoint>"),
    output_path=Path("<path_to_output_hf_model>"),
)
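Once the export completes, you can sanity-check the result by loading it with the transformers library; the directory below is the output_path used in the export call:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the exported checkpoint to confirm it is a valid Hugging Face model
tokenizer = AutoTokenizer.from_pretrained("<path_to_output_hf_model>")
model = AutoModelForCausalLM.from_pretrained("<path_to_output_hf_model>", torch_dtype="auto")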
Note
Ensure you have sufficient disk space and appropriate permissions when exporting large models. The export process may take some time depending on the model size and your storage setup.