Add New Models#
This guide outlines how to integrate and validate a new model within NeMo RL. Each new model must pass a standard set of compatibility tests before being considered ready to be used in RL pipelines. The guide also details diagnostic scripts to help identify and resolve common issues during model integration.
Importance of Log Probability Consistency in Training and Inference#
In on-policy RL, we sample tokens (actions) from the latest version of the policy. This means the sampling distribution of token probabilities produced by the inference framework must closely match those from the training framework. If the inference framework produces significantly different probabilities, we effectively sample from a different distribution, leading to errors in the loss estimation.
As an example, we would see errors in naive KL estimation:
When summed/integrated, replacing the \(x \sim \pi\) with \(x \sim \pi_{\text{wrong}}\) leads to an error of:
So, to verify correctness, we calculate:
as a measure of multiplicative probability error for sampled tokens, where samples are drawn as \(x \sim \pi_{\text{inference-framework}}\).
Note that this is not exhaustive (the inference framework could lack distribution support and we wouldn’t catch it here, as \(x \sim \pi_{\text{inference-framework}}\)). To get a much stricter guarantee on correctness, you should run this metric twice and average the results, where in the second run, you sample \(x \sim \pi_{\text{training-framework}}\). In practice, we use just the former in our tests and find it sufficient.
Understand Discrepancies Between Backends#
When validating models across different backends, you may encounter discrepancies in log probabilities. These differences can stem from various sources with effects ranging from negligible to significant:
Numerical precision differences: Training and inference backends may differ in precision formats (FP32, FP16, BF16, FP8).
Training may use mixed precision, while the inference backend may not.
High-precision training with FP8 inference may not be numerically stable for certain models.
Differences can occur at the layer level, with some layers in FP32, while others use lower precision.
Implementation variations: Subtle differences in how layer implementations like softmax, layer normalization, or attention mechanisms are implemented.
Attention/Norm layers (which could be fused) in TransformerEngine may not be bit-wise identical to implementations in inference backends.
Inference backends may re-implement kernels (e.g., for SSM layers) leading to differences.
Softmax in training frameworks may be calculated differently than in inference backends for numerical stability.
KV/Prefill cache handling: Differences in how key-value/prefill caches are managed during autoregressive generation.
In some cases, disabling the inference backend cache can resolve discrepancies.
Parallelism effects: Parallelisms like Tensor parallelism may introduce small variations.
Inherent non-determinism: Some neural network operations are inherently non-deterministic (e.g.,
torch.cumsum).Prefill/Decoding kernel mismatch: Different kernels for prefill and decoding phases may produce different log probabilities.
Training frameworks typically use prefill kernels, while inference backends may use both prefill kernels and specialized decoding kernels.
Imperfect Refit: Weight conversion from the training framework to the inference backend may be incomplete or data formats may be incorrect.
If weights are reshaped or reordered incorrectly, generations tend to be very wrong.
In some cases, if some weights in the inference backend are not refit after each training step, the error between training and inference log probabilities can diverge as training progresses.
Batch size: In some cases,
batch_size>1may produce larger errors thanbatch_size=1
When investigating discrepancies beyond the acceptable threshold, focus on these areas and determine whether the differences appear systematically or only in specific contexts.
1. Hugging Face–Based Models#
Validation Workflow#
When validating Hugging Face-based models, perform the following checks:
Compare log probabilities Ensure the generation log probabilities from inference backends like vLLM match those computed by Hugging Face. This comparison helps diagnose potential mismatches.
Test parallelism Verify consistency with other parallelism settings.
Variance Repeat tests multiple times (e.g., 10 runs) to confirm that behavior is deterministic or within acceptable variance.
Check sequence lengths Perform inference on sequence lengths of 100, 1,000, and 10,000 tokens. Ensure the model behaves consistently at each length.
Use real and dummy data
Real data: Tokenize and generate from actual text samples.
Dummy data: Simple numeric sequences to test basic generation.
Vary sampling parameters Test both greedy and sampling generation modes. Adjust temperature and top-p to confirm output consistency across backends.
Test different batch sizes Try with batch sizes of 1, 8, and 32 to ensure consistent behavior across different batch configurations.
2. Megatron Models#
Additional Validation#
Compare Megatron outputs Ensure the Megatron forward pass aligns with Hugging Face and the generation log probabilities from inference backends like vLLM.
Parallel settings Match the same parallelism configurations used for the HuggingFace-based tests. Confirm outputs remain consistent across repeated runs.
3. Expected Error Threshold#
When comparing log probabilities between training and inference backends, we use an error threshold of 1.05 to determine acceptable variance (for equal precision). An error of 1.0 indicates a perfect match, and values exceeding 1.05 require further investigation.
When validating your model, you should analyze the results across different configurations. Your analysis should include:
Sequence Length |
Data Type |
Generation Method |
Batch Size |
HF vs VLLM |
Megatron vs VLLM |
|---|---|---|---|---|---|
100 |
Real |
Greedy |
1 |
1.02 |
1.01 |
100 |
Real |
Sampling |
8 |
1.03 |
1.02 |
100 |
Synthetic |
Greedy |
1 |
1.01 |
1.02 |
1,000 |
Real |
Greedy |
32 |
1.04 |
1.03 |
… |
… |
… |
… |
… |
… |
By following these validation steps and ensuring your model’s outputs remain consistent across backends, you can confirm that your new model meets the requirements of NeMo RL.
Model Diagnostics#
We also maintain a set of standalone scripts that can be used to diagnose issues related to correctness that we have encountered before.
1.max_model_len_respected.py#
Test if a new model respects the max_model_len passed to vllm:
# Run that is expected to pass
uv run --extra vllm tools/model_diagnostics/1.max_model_len_respected.py Qwen/Qwen2.5-1.5B
# ...
# Prompt tokens: 8
# Generated tokens: 12
# Total tokens: 20
# [Qwen/Qwen2.5-1.5B] ALL GOOD!
2.long_generation_decode_vs_prefill#
Test that vLLM yields near-identical token log-probabilities when comparing decoding with a single prefill pass across multiple prompts.
# Run that is expected to pass
uv run --extra vllm tools/model_diagnostics/2.long_generation_decode_vs_prefill.py --model Qwen/Qwen2.5-1.5B
# ...
# [Qwen/Qwen2.5-1.5B] ALL GOOD!
3.check_hf_model_embeddings_untrained.py#
Detects untrained or improperly initialized Hugging Face model embeddings by scanning for near-zero rows and rows with near-identical values in both input and output embeddings. The script also reports whether word embeddings are tied and summarizes basic statistics.
# Example run
uv run --extra mcore tools/model_diagnostics/3.check_hf_model_embeddings_untrained.py --model nvidia/Nemotron-H-8B-Base-8K
# ....
#================================================================================
#EMBEDDING SUMMARIES
#================================================================================
#
#--- Input Embeddings Summary ---
#Shape: torch.Size([131072, 4096]), Dtype: torch.bfloat16
#Near-zero embeddings (abs < 1.00e-10): 1039/131072 (0.8%)
# Indices: 0-1,3-999,1192-1193,1245-1255,55014,77579,81772,81819,82312,82500,82725,82737,82977,84020,84121,84521,84794,85015,86409,87411,89412,90320,91368,94485,96385,104097,108262,112147,112327,112497,114755
#Identical embeddings (std < 1.00e-08): 1041/131072 (0.8%)
# Indices: 0-1,3-999,1192-1193,1245-1255,55014,77579,81772,81819,82312,82500,82725,82737,82977,83855,84020,84121,84521,84794,85015,86409,87411,89412,90320,91368,94485,96385,101707,104097,108262,112147,112327,112497,114755
#Statistics: mean_abs=0.007874, max_abs=0.196289, std_range=[0.000000, 0.015442]
#⚠️ POTENTIAL ISSUES: 1039 near-zero embeddings, 1041 identical embeddings
#
#--- Output Embeddings Summary (Tied: False) ---
#Shape: torch.Size([131072, 4096]), Dtype: torch.bfloat16
#Near-zero embeddings (abs < 1.00e-10): 0/131072 (0.0%)
#Identical embeddings (std < 1.00e-08): 0/131072 (0.0%)
#Statistics: mean_abs=0.006775, max_abs=0.200195, std_range=[0.004089, 0.021240]
#✅ No obvious untrained patterns detected
#
#=== Final Summary ===
#Model: nvidia/Nemotron-H-8B-Base-8K
#Analysis complete.
Thresholds can be adjusted via flags:
--near-zero-threshold(default:1e-10)--identical-threshold(default:1e-8)
If any near-zero or identical rows are reported, the model may have issues of numerical instability (e.g., inf grad norms) during post-training if any of these problematic tokens are encountered. We have observed this happening when special tokens are reserved in the tokenizer and embedding, but none are encountered during pre-training. It may help to initialize these embeddings similar to how they were initialize during pre-training.
4.vllm_precision_compilation_test.py#
Tests vLLM precision compilation by comparing log probabilities across different compilation modes and configurations. This script helps diagnose numerical precision issues that commonly arise when using different vLLM compilation settings. Note that this is not a strict pass/fail test - it’s designed to help you understand and investigate numerical discrepancies.
# Example run
uv run --extra vllm tools/model_diagnostics/4.vllm_precision_compilation_test.py --model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
# Typical output shows mixed results:
# Eager and cuda graph mode lps: FAILED - Arrays are different
...
# Eager and cuda graph mode lps with torch inductor precision flag: FAILED - Arrays are different
...
# Eager and cuda graph mode lps with backend=eager: PASSED - Arrays are close within tolerance (atol=0.001, rtol=0.001)
See example for model deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
====================================================================================================
Eager and cuda graph mode lps (prompt lps): FAILED - Arrays are different
Detailed error:
Not equal to tolerance rtol=0.001, atol=0.001
Mismatched elements: 96 / 515 (18.6%)
Max absolute difference among violations: 0.3885002
Max relative difference among violations: 0.20179409
ACTUAL: array([[-1.424489e+01, -3.924684e-01, -3.135911e+00, -4.258007e-01,
-3.443364e-04, nan, nan, nan,
nan, nan, nan, nan,...
DESIRED: array([[-1.420929e+01, -3.619126e-01, -3.241854e+00, -4.308376e-01,
-3.047717e-04, nan, nan, nan,
nan, nan, nan, nan,...
====================================================================================================
====================================================================================================
Eager and cuda graph mode lps (generation lps): FAILED - Arrays are different
Detailed error:
Not equal to tolerance rtol=0.001, atol=0.001
nan location mismatch:
ACTUAL: array([[-1.231834e+01, -1.411233e-01, -3.764260e-01, ..., nan,
nan, nan],
[-8.567932e+00, -1.066314e+01, -4.463661e-01, ..., nan,...
DESIRED: array([[-1.226752e+01, -1.508305e-01, -4.024158e-01, ..., nan,
nan, nan],
[-8.610202e+00, -1.067061e+01, -4.593382e-01, ..., -1.060957e-05,...
====================================================================================================
...
====================================================================================================
Eager and cuda graph mode lps with torch inductor precision flag (prompt lps): FAILED - Arrays are different
Detailed error:
Not equal to tolerance rtol=0.001, atol=0.001
Mismatched elements: 96 / 515 (18.6%)
Max absolute difference among violations: 0.3885002
Max relative difference among violations: 0.20179409
ACTUAL: array([[-1.424489e+01, -3.924684e-01, -3.135911e+00, -4.258007e-01,
-3.443364e-04, nan, nan, nan,
nan, nan, nan, nan,...
DESIRED: array([[-1.420929e+01, -3.619126e-01, -3.241854e+00, -4.308376e-01,
-3.047717e-04, nan, nan, nan,
nan, nan, nan, nan,...
====================================================================================================
====================================================================================================
Eager and cuda graph mode lps with torch inductor precision flag (generation lps): FAILED - Arrays are different
Detailed error:
Not equal to tolerance rtol=0.001, atol=0.001
nan location mismatch:
ACTUAL: array([[-1.231834e+01, -1.411233e-01, -3.764260e-01, ..., nan,
nan, nan],
[-8.567932e+00, -1.066314e+01, -4.463661e-01, ..., nan,...
DESIRED: array([[-1.226752e+01, -1.508305e-01, -4.024158e-01, ..., nan,
nan, nan],
[-8.610202e+00, -1.067061e+01, -4.593382e-01, ..., -1.060957e-05,...
====================================================================================================
...
Eager and cuda graph mode lps with backend=eager (prompt lps): PASSED - Arrays are close within tolerance (atol=0.001, rtol=0.001)
Eager and cuda graph mode lps with backend=eager (generation lps): PASSED - Arrays are close within tolerance (atol=0.001, rtol=0.001)
What this script tests:
The script is to compare both prompt and generation logprobs under the following setups:
Eager vs CUDA Graph Mode: Compares log probabilities between eager execution (ground truth) and CUDA graph compilation mode
⚠️ Commonly fails: This comparison often shows discrepancies due to compilation optimizations
Torch Inductor Precision: Tests with
TORCHINDUCTOR_EMULATE_PRECISION_CASTS=1environment variable⚠️ May help: This flag may help but typically doesn’t resolve all the numerical differences
Inductor Disabled: Verifies that disabling Torch Inductor compilation (
backend="eager") maintains output consistency✅ Usually works well: This configuration often produces results very close to eager mode
Note:
backend="eager"disables Inductor compilation but keeps CUDA graph capture active for compatible operations
Performance vs Accuracy Trade-offs:
The different compilation modes offer distinct trade-offs between accuracy and performance:
Eager Mode (
enforce_eager=True): Highest accuracy (ground truth) but slowest executionCUDA Graph Mode with Inductor Disabled (
enforce_eager=Falseandcompilation_config={"backend": "eager"}): Near-eager accuracy with significant speedup from CUDA graph optimizationCUDA Graph Mode with Inductor Enabled (
enforce_eager=Falseandcompilation_config={"backend": "inductor"}or default): Potentially fastest execution with custom Triton kernels (since Triton is the current backend of Inductor), but may introduce numerical differences. For accuracy improvement, try the torch inductor precision flag:export TORCHINDUCTOR_EMULATE_PRECISION_CASTS=1
Note: Performance characteristics vary by model. For example, deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B shows similar speed performance between backend="inductor" and backend="eager", making the accuracy-preserving option preferable.
Why this matters:
Debugging: Helps identify which compilation settings cause numerical differences
Configuration: Shows which settings work best for your model
Understanding: Reveals how compilation affects model outputs
When to use:
Model integration - understand numerical behavior across vLLM configurations
Debugging - investigate differences between development and production
Research - study compilation strategy impacts on precision
Interpreting results:
Eager vs CUDA Graph failures are normal - don’t panic if this fails
Focus on patterns - some models are more sensitive than others
Use as guidance - helps choose reliable compilation settings
Balance precision vs performance - choose what works for your use case
5.prefix_caching_nan.py#
Tests that prefix caching doesn’t produce NaN logprobs when prior generation is rolled back into the prompt (the standard RL / multi-turn pattern). In vLLM >= 0.14, the second request can return all-NaN logprobs with token_id=0 (<unk>) for every token after the first.
# Single version (requires 2+ GPUs for TP=2)
uv run --no-project --with "vllm==0.14.0" tools/model_diagnostics/5.prefix_caching_nan.py
# Test across multiple vLLM versions:
for ver in 0.11.2 0.13.0 0.14.0 0.15.0 0.15.1; do
uv run --no-project --with "vllm==$ver" tools/model_diagnostics/5.prefix_caching_nan.py 2>&1 | tee "prefix_caching_nan_vllm_${ver}.log"
done
Expected pass output (vLLM 0.13.0):
Iteration 1 — prompt length: 13990 chars
tokens: 2048, finish_reason: length
text (first 100): '3001 3002 3003 3004 3005 3006 3007 3008 3009 3010 3011 3012 3013 3014 3015 3016 3017 3018 3019 3020 '
Iteration 2 — prompt length: 16038 chars
tokens: 2048, finish_reason: length
text (first 100): '1 3412 3413 3414 3415 3416 3417 3418 3419 3420 3421 3422 3423 3424 3425 3426 3427 3428 3429 3430 343'
[nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16] ALL GOOD!
Expected fail output (vLLM 0.14.0):
Iteration 1 — prompt length: 13990 chars
tokens: 2048, finish_reason: length
text (first 100): '3000\n\nAssistant: 2600 2601 2602 2603 2604 2605 2606 2607 2609 2610 ...'
Iteration 2 — prompt length: 16047 chars
tokens: 2048, finish_reason: length
text (first 100): '3'
Sample logprobs from iteration 2:
token[0] id=1051: logprob=-0.0005862186080776155 decoded='3'
token[1] id=0: logprob=nan decoded='<unk>'
token[2] id=0: logprob=nan decoded='<unk>'
token[2047] id=0: logprob=nan decoded='<unk>'
AssertionError: FAIL: 2047/2048 logprobs are NaN on iteration 2 (prefix caching is broken in vLLM 0.14.0)
Note: the ERROR ... Engine core proc EngineCore_DP0 died unexpectedly message that may appear after the assertion is just vLLM’s engine shutting down ungracefully after the process exits — it is not a separate issue.
The script generates from a counting prompt, appends the output back into the prompt, and generates again. On the second generation, prefix caching reuses the KV cache from the first request’s prefix. The bug causes the cached prefix to produce corrupted activations, resulting in token_id=0 (<unk>) with logprob=nan for all tokens after the first.