# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, TypedDict
import numpy as np
import torch
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import AutoTokenizer
from nemo_rl.algorithms.interfaces import LossFunction
from nemo_rl.algorithms.loss_functions import (
ClippedPGLossConfig,
ClippedPGLossDataDict,
ClippedPGLossFn,
)
from nemo_rl.algorithms.utils import calculate_baseline_and_std_per_prompt
from nemo_rl.data import DataConfig
from nemo_rl.data.datasets import AllTaskProcessedDataset, rl_collate_fn
from nemo_rl.data.interfaces import (
DatumSpec,
)
from nemo_rl.data.llm_message_utils import (
batched_message_log_to_flat_message,
)
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster
from nemo_rl.environments.interfaces import (
EnvironmentInterface,
)
from nemo_rl.experience.rollouts import run_multi_turn_rollout
from nemo_rl.models.generation.interfaces import (
GenerationInterface,
)
from nemo_rl.models.generation.vllm import VllmGeneration
from nemo_rl.models.interfaces import PolicyInterface
from nemo_rl.models.policy import PolicyConfig
from nemo_rl.models.policy.hf_policy import HfPolicy
from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager
from nemo_rl.utils.logger import (
Logger,
LoggerConfig,
print_message_log_samples,
)
from nemo_rl.utils.timer import Timer
# ===============================================================================
# Configuration
# ===============================================================================
class GRPOConfig(TypedDict):
num_prompts_per_step: int
num_generations_per_prompt: int
max_num_steps: int
normalize_rewards: bool
use_leave_one_out_baseline: bool
val_period: int
val_batch_size: int
val_at_start: bool
    checkpoint_dir: str
    max_rollout_turns: int
    max_val_samples: int
class GRPOSaveState(TypedDict):
step: int
val_reward: float
consumed_samples: int
def _default_grpo_save_state() -> GRPOSaveState:
return {
"step": 0,
"val_reward": -99999999.0,
"consumed_samples": 0,
}
class MasterConfig(TypedDict):
policy: PolicyConfig
loss_fn: ClippedPGLossConfig
env_configs: Dict[str, Any]
data: DataConfig
grpo: GRPOConfig
logger: LoggerConfig
cluster: ClusterConfig
checkpointing: CheckpointingConfig
# ===============================================================================
# Setup & Initialization
# ===============================================================================
def setup(
master_config: MasterConfig,
tokenizer: AutoTokenizer,
dataset: AllTaskProcessedDataset,
val_dataset: Optional[AllTaskProcessedDataset],
) -> Tuple[
PolicyInterface,
GenerationInterface,
RayVirtualCluster,
StatefulDataLoader,
Optional[StatefulDataLoader],
ClippedPGLossFn,
Logger,
CheckpointManager,
GRPOSaveState,
MasterConfig,
]:
"""Main entry point for running GRPO algorithm.
Returns:
Tuple of policy, cluster, dataloader, tokenizer, loss_fn, math_env, logger, master_config, val_dataloader
"""
# Extract individual configs for easier access
policy_config = master_config["policy"]
generation_config = master_config["policy"]["generation"]
loss_config = master_config["loss_fn"]
data_config = master_config["data"]
grpo_config = master_config["grpo"]
logger_config = master_config["logger"]
cluster_config = master_config["cluster"]
# ==========================
# Logger
# ==========================
logger = Logger(logger_config)
logger.log_hyperparams(master_config)
# ==========================
# Checkpointing
# ==========================
checkpointer = CheckpointManager(master_config["checkpointing"])
last_checkpoint_path = checkpointer.get_latest_checkpoint_path()
grpo_save_state: Optional[GRPOSaveState] = checkpointer.load_training_info(
last_checkpoint_path
)
if grpo_save_state is None:
grpo_save_state = _default_grpo_save_state()
# config validation checks
if master_config["checkpointing"]["enabled"]:
assert master_config["checkpointing"]["save_period"] > 0
assert (
master_config["checkpointing"]["save_period"]
% master_config["grpo"]["val_period"]
== 0
), (
f"Checkpointing save period {master_config['checkpointing']['save_period']} "
f"must be a multiple of validation period {master_config['grpo']['val_period']}"
f", or we won't know what metric to save!"
)
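        # e.g. save_period=100 with val_period=50 passes the check (100 % 50 == 0), while
        # save_period=75 would not; note the modulo also assumes val_period > 0.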
# ==========================
# Data
# ==========================
dataloader = StatefulDataLoader(
dataset,
batch_size=grpo_config["num_prompts_per_step"],
shuffle=False,
collate_fn=rl_collate_fn,
)
if last_checkpoint_path is not None:
dataloader_state_dict = torch.load(
os.path.join(last_checkpoint_path, "train_dataloader.pt")
)
dataloader.load_state_dict(dataloader_state_dict)
print(f" ✓ Training dataloader loaded with {len(dataset)} samples")
# Load validation dataset if provided
val_dataloader = None
# If validation is enabled, load the validation dataloader
if grpo_config["val_period"] > 0 or grpo_config["val_at_start"]:
val_dataloader = StatefulDataLoader(
val_dataset,
batch_size=grpo_config["val_batch_size"],
shuffle=False,
collate_fn=rl_collate_fn,
)
print(f" ✓ Validation dataloader loaded with {len(val_dataset)} samples")
# ==========================
# Cluster
# ==========================
print("\n▶ Setting up compute cluster...")
colocated_inference = generation_config["backend"] != "hf"
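    # e.g. 2 nodes x 8 GPUs -> bundle_ct_per_node_list=[8, 8]; when colocated_inference is True,
    # the training workers and the vLLM generation workers share the same GPU bundles.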
cluster = RayVirtualCluster(
name="grpo_policy_cluster",
bundle_ct_per_node_list=[cluster_config["gpus_per_node"]]
* cluster_config["num_nodes"],
use_gpus=True,
num_gpus_per_node=cluster_config["gpus_per_node"],
max_colocated_worker_groups=2 if colocated_inference else 1,
)
print(f" ✓ Ray cluster initialized with {cluster_config['num_nodes']} nodes")
# ==========================
# Training and Inference
# ==========================
print("\n▶ Setting up model and training...")
    # vLLM model loading prefers a clean environment, so initialize policy_generation before policy (#52 will fix this)
backend = generation_config["backend"]
generation_config["model_name"] = policy_config["model_name"] # Needed for vLLM
if backend == "hf":
policy_generation = None
print(f" ✓ Using HF backend for generation with {policy_config['model_name']}")
elif backend == "vllm":
policy_generation = VllmGeneration(cluster=cluster, config=generation_config)
        # Worker groups are not initialized until the first call that actually runs something on them.
        # vLLM 0.8 fails during initialization if it is called in the first training step, since it has
        # no clean view of the GPU memory (HF is sharing the same memory).
policy_generation.finish_generation()
print(
f" ✓ Using vLLM backend for generation with {policy_config['model_name']}"
)
policy = HfPolicy(
cluster=cluster,
config=policy_config,
tokenizer=tokenizer,
weights_path=Path(last_checkpoint_path) / "policy" / "weights"
if last_checkpoint_path
else None,
optimizer_path=Path(last_checkpoint_path) / "policy" / "optimizer"
if last_checkpoint_path
else None,
init_optimizer=True,
)
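    # weights_path/optimizer_path are only set when resuming from a checkpoint; a fresh run
    # starts from the pretrained model with a newly initialized optimizer state.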
loss_fn = ClippedPGLossFn(loss_config)
print("\n" + "=" * 60)
print(" " * 18 + "SETUP COMPLETE")
print("=" * 60 + "\n")
return (
policy,
policy_generation,
cluster,
dataloader,
val_dataloader,
loss_fn,
logger,
checkpointer,
grpo_save_state,
master_config,
)
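# Example wiring (illustrative sketch only; the config file name and the pre-built dataset /
# environment objects below are assumptions, not provided by this module):
#
#   master_config = yaml.safe_load(open("grpo_config.yaml"))  # hypothetical YAML matching MasterConfig
#   tokenizer = AutoTokenizer.from_pretrained(master_config["policy"]["model_name"])
#   (
#       policy, policy_generation, cluster, dataloader, val_dataloader,
#       loss_fn, logger, checkpointer, grpo_save_state, master_config,
#   ) = setup(master_config, tokenizer, dataset, val_dataset)
#   grpo_train(
#       policy, policy_generation, dataloader, val_dataloader, tokenizer, loss_fn,
#       task_to_env, val_task_to_env, logger, checkpointer, grpo_save_state, master_config,
#   )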
# ===============================================================================
# Core Algorithm Functions
# ===============================================================================
def refit_policy_generation(
policy: PolicyInterface,
policy_generation: GenerationInterface,
refit_buffer_size_gb: int, # GB
):
"""Refit the policy generation interface with the latest policy weights."""
policy.offload_before_refit()
policy_generation.prepare_for_generation(tags=["weights"])
    # Stream the weight update in chunks to save memory
state_dict_info = policy.prepare_weights_for_ipc()
# group keys to save time
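    # e.g. with a 4 GB buffer and parameter tensors of 3, 2, 2, and 1 GB, the loop below
    # produces groups [[3], [2, 2], [1]] (in GB), so each IPC transfer fits in the buffer.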
available_bytes = refit_buffer_size_gb * (1024**3)
split_keys, keys = [], []
for key, size_in_bytes in state_dict_info:
if size_in_bytes > available_bytes:
if keys:
split_keys.append(keys)
keys = []
available_bytes = refit_buffer_size_gb * (1024**3)
keys.append(key)
available_bytes -= size_in_bytes
if len(keys) > 0:
split_keys.append(keys)
# do update
for keys in split_keys:
ipc_handles = policy.get_weights_ipc_handles(keys)
if not policy_generation.update_weights(ipc_handles):
error_message = (
"❌ Error: Updating weights for the generation policy failed during refit.\n"
"This often indicates an issue with cuda-ipc or "
"a problem within the generation backend (e.g., vLLM worker).\n"
)
raise RuntimeError(error_message)
policy.offload_after_refit()
policy_generation.prepare_for_generation(tags=["kv_cache"])
# ===============================================================================
# Training & Validation
# ===============================================================================
def grpo_train(
policy: PolicyInterface,
policy_generation: Optional[GenerationInterface],
dataloader: StatefulDataLoader,
val_dataloader: Optional[StatefulDataLoader],
tokenizer,
loss_fn: LossFunction,
task_to_env: Dict[str, EnvironmentInterface],
val_task_to_env: Optional[Dict[str, EnvironmentInterface]],
logger: Logger,
checkpointer: CheckpointManager,
grpo_save_state: Optional[GRPOSaveState],
master_config: MasterConfig,
):
"""Run GRPO training algorithm."""
timer = Timer()
NEED_REFIT = True
# If policy_generation is None, use the policy as the generation interface (hf framework backend)
if policy_generation is None:
policy_generation = policy
NEED_REFIT = False
POLICY_GENERATION_STALE = True # tracks if generation needs a refit before running
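    # Refit protocol: policy.train() changes the weights, so the generation backend's copy goes
    # stale; when NEED_REFIT is set (separate generation backend), refit_policy_generation()
    # streams the fresh weights over before the next rollout or validation pass.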
    # common config/state items
step = grpo_save_state["step"]
consumed_samples = grpo_save_state["consumed_samples"]
val_period = master_config["grpo"]["val_period"]
val_at_start = master_config["grpo"]["val_at_start"]
refit_buffer_size_gb = master_config["policy"]["refit_buffer_size_gb"]
# Run validation at the start if configured
if val_at_start and step == 0:
print("\n🔍 Running initial validation...")
if NEED_REFIT and POLICY_GENERATION_STALE:
refit_policy_generation(policy, policy_generation, refit_buffer_size_gb)
POLICY_GENERATION_STALE = False
else:
policy_generation.prepare_for_generation()
val_metrics, validation_timings = validate(
policy_generation,
val_dataloader,
tokenizer,
val_task_to_env,
step=0,
master_config=master_config,
)
policy_generation.finish_generation()
logger.log_metrics(val_metrics, step, prefix="validation")
logger.log_metrics(validation_timings, step, prefix="timing/validation")
# Run grpo training (single-turn)
batch: BatchedDataDict[DatumSpec]
for batch in dataloader:
print(
f"\n{'=' * 25} Step {step + 1}/{min(len(dataloader), master_config['grpo']['max_num_steps'])} {'=' * 25}"
)
val_metrics, validation_timings = None, None
with timer.time("total_step_time"):
# Prepare batch
print("▶ Preparing batch...")
with timer.time("data_processing"):
# Repeat batch items
repeated_batch: BatchedDataDict[DatumSpec] = batch.repeat_interleave(
master_config["grpo"]["num_generations_per_prompt"]
)
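                # e.g. num_prompts_per_step=32 and num_generations_per_prompt=16 yield 512 rollouts
                # per step; the generations of each prompt form one group for the baseline below.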
# Convert LLMMessageLogType to FlatMessagesType for generation
batched_flat, input_lengths = batched_message_log_to_flat_message(
repeated_batch["message_log"],
pad_value_dict={"token_ids": tokenizer.pad_token_id},
)
input_ids = batched_flat["token_ids"]
# Generate responses - this updates the LLMMessageLogType in repeated_batch
print(f"▶ Generating responses for batch of size {repeated_batch.size}...")
with timer.time("prepare_for_generation"):
if NEED_REFIT and POLICY_GENERATION_STALE:
refit_policy_generation(
policy,
policy_generation,
refit_buffer_size_gb,
)
POLICY_GENERATION_STALE = False
else:
policy_generation.prepare_for_generation()
with timer.time("generation"):
repeated_batch, rollout_metrics = run_multi_turn_rollout(
policy_generation=policy_generation,
input_batch=repeated_batch,
tokenizer=tokenizer,
task_to_env=task_to_env,
max_seq_len=master_config["policy"]["max_total_sequence_length"],
max_rollout_turns=master_config["grpo"]["max_rollout_turns"],
greedy=False,
)
policy_generation.finish_generation()
# Calculate rewards & advantages
print("▶ Processing rewards...")
with timer.time("reward_calculation"):
# Extract rewards from final_batch
rewards = repeated_batch["total_reward"]
print("▶ Computing advantages...")
baseline, std = calculate_baseline_and_std_per_prompt(
input_ids,
rewards,
torch.ones_like(rewards),
leave_one_out_baseline=master_config["grpo"][
"use_leave_one_out_baseline"
],
)
advantages = (rewards - baseline).unsqueeze(-1)
if master_config["grpo"]["normalize_rewards"]:
                    # don't sharpen advantages for prompt groups with no reward variation
                    nonzero_std_mask = std > 0
                    advantages[nonzero_std_mask] = (
                        advantages[nonzero_std_mask] / std.unsqueeze(-1)[nonzero_std_mask]
                    )
with timer.time("data_processing"):
# Add loss mask and advantages to each message in LLMMessageLogType
for i, message_log in enumerate(repeated_batch["message_log"]):
for j, message in enumerate(message_log):
if message["role"] == "assistant":
message["token_loss_mask"] = torch.ones_like(
message["token_ids"]
)
else:
message["token_loss_mask"] = torch.zeros_like(
message["token_ids"]
)
if "generation_logprobs" not in message:
message["generation_logprobs"] = torch.zeros_like(
message["token_ids"], dtype=torch.float32
)
message["advantages"] = advantages[i].expand(
message["token_ids"].shape
)
# Convert updated LLMMessageLogType to FlatMessagesType for training
flat_messages, input_lengths = batched_message_log_to_flat_message(
repeated_batch["message_log"],
pad_value_dict={"token_ids": tokenizer.pad_token_id},
make_sequence_length_divisible_by=master_config["policy"][
"make_sequence_length_divisible_by"
],
)
# Create training data from flattened messages
train_data = BatchedDataDict[ClippedPGLossDataDict](
{
"input_ids": flat_messages["token_ids"],
"input_lengths": input_lengths,
"advantages": flat_messages["advantages"],
"generation_logprobs": flat_messages["generation_logprobs"],
"token_mask": flat_messages["token_loss_mask"],
"sample_mask": repeated_batch["loss_multiplier"],
}
)
train_data.to("cpu")
print("▶ Preparing for logprob inference...")
with timer.time("logprob_inference_prep"):
policy.prepare_for_lp_inference()
print("▶ Computing logprobs...")
with timer.time("policy_and_reference_logprobs"):
fprop_logprobs = policy.get_logprobs(train_data)["logprobs"]
reference_logprobs = policy.get_reference_policy_logprobs(train_data)[
"reference_logprobs"
]
train_data["prev_logprobs"] = fprop_logprobs
train_data["reference_policy_logprobs"] = reference_logprobs
print("▶ Preparing for training...")
with timer.time("training_prep"):
policy.prepare_for_training() # set model train and reload optim to GPU
POLICY_GENERATION_STALE = True
print("▶ Training policy...")
with timer.time("policy_training"):
train_results = policy.train(train_data, loss_fn)
is_last_step = step + 1 == min(
master_config["grpo"]["max_num_steps"], len(dataloader)
)
# Run validation if it's a validation step
if is_last_step or (val_period > 0 and (step + 1) % val_period == 0):
if NEED_REFIT and POLICY_GENERATION_STALE:
refit_policy_generation(
policy,
policy_generation,
refit_buffer_size_gb,
)
POLICY_GENERATION_STALE = False
else:
policy_generation.prepare_for_generation()
val_metrics, validation_timings = validate(
policy_generation,
val_dataloader,
tokenizer,
val_task_to_env,
step=step + 1,
master_config=master_config,
)
policy_generation.finish_generation()
logger.log_metrics(
validation_timings, step + 1, prefix="timing/validation"
)
logger.log_metrics(val_metrics, step + 1, prefix="validation")
## Checkpointing
consumed_samples += master_config["grpo"]["num_prompts_per_step"]
if master_config["checkpointing"]["enabled"] and (
is_last_step
or (step + 1) % master_config["checkpointing"]["save_period"] == 0
): # +1 because step is 0-indexed
policy.prepare_for_training()
grpo_save_state["step"] = step + 1
grpo_save_state["val_reward"] = val_metrics["accuracy"]
grpo_save_state["consumed_samples"] = consumed_samples
with timer.time("checkpointing"):
print(f"Saving checkpoint for step {step + 1}...")
checkpoint_path = checkpointer.init_tmp_checkpoint(
step + 1, grpo_save_state, master_config
)
policy.save_checkpoint(
weights_path=os.path.join(checkpoint_path, "policy", "weights"),
optimizer_path=os.path.join(
checkpoint_path, "policy", "optimizer"
),
tokenizer_path=os.path.join(
checkpoint_path, "policy", "tokenizer"
),
)
torch.save(
dataloader.state_dict(),
os.path.join(checkpoint_path, "train_dataloader.pt"),
)
checkpointer.finalize_checkpoint(checkpoint_path)
policy.offload_after_refit()
# Logging
# Log training data
log_data = {"content": flat_messages["content"]}
log_data["rewards"] = rewards.tolist()
log_data["generation_logprobs"] = train_data["generation_logprobs"].tolist()
log_data["prev_logprobs"] = train_data["prev_logprobs"].tolist()
log_data["input_lengths"] = input_lengths.tolist()
logger.log_batched_dict_as_jsonl(log_data, f"train_data_step{step}.jsonl")
print("\n📊 Training Results:")
metrics = {
"loss": train_results["loss"].numpy(),
"reward": rewards.numpy(),
"grad_norm": train_results["grad_norm"].numpy(),
}
metrics.update(train_results["all_mb_metrics"])
for k, v in metrics.items():
if k in {"lr", "reward", "global_valid_seqs", "global_valid_toks"}:
metrics[k] = np.mean(v).item()
else:
metrics[k] = np.sum(v).item()
metrics.update(rollout_metrics)
timing_metrics = timer.get_timing_metrics(reduction_op="sum")
print(f" • Loss: {metrics['loss']:.4f}")
print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}")
print(
f" • Mean Generation Length: {rollout_metrics['mean_gen_tokens_per_sample']:.4f}"
)
print("\n⏱️ Timing:")
# Display total time first, separately
total_time = timing_metrics.get("total_step_time", 0)
print(f" • Total step time: {total_time:.2f}s")
# Display all other timing metrics
for k, v in sorted(
timing_metrics.items(), key=lambda item: item[1], reverse=True
):
if k != "total_step_time":
percent = (v / total_time * 100) if total_time > 0 else 0
print(f" • {k}: {v:.2f}s ({percent:.1f}%)")
logger.log_metrics(metrics, step + 1, prefix="train")
logger.log_metrics(timing_metrics, step + 1, prefix="timing/train")
timer.reset()
step += 1
if step >= master_config["grpo"]["max_num_steps"]:
break
def validate(
policy_generation: GenerationInterface,
val_dataloader: StatefulDataLoader,
tokenizer,
val_task_to_env: Dict[str, EnvironmentInterface],
step: int,
master_config: MasterConfig,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""Run validation on the validation dataset."""
if val_dataloader is None:
print(" ⚠️ No validation dataloader provided, skipping validation")
        return {}, {}
timer = Timer()
with timer.time("total_validation_time"):
print(f"▶ Starting validation at step {step}...")
total_rewards = []
total_lengths = []
all_message_logs = [] # Collect all message logs
max_batches = (
master_config["grpo"]["max_val_samples"]
// master_config["grpo"]["val_batch_size"]
)
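        # e.g. max_val_samples=256 with val_batch_size=32 caps validation at 8 batches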
for batch_idx, val_batch in enumerate(val_dataloader):
if batch_idx >= max_batches:
break
            # Generate responses (updates the LLMMessageLogType in val_batch)
val_batch, gen_metrics = run_multi_turn_rollout(
policy_generation,
val_batch,
tokenizer,
val_task_to_env,
max_seq_len=master_config["policy"]["max_total_sequence_length"],
max_rollout_turns=master_config["grpo"]["max_rollout_turns"],
greedy=False,
)
rewards = val_batch["total_reward"]
total_rewards.extend(rewards.tolist())
total_lengths.append(gen_metrics["mean_gen_tokens_per_sample"])
all_message_logs.extend(val_batch["message_log"])
# Calculate validation metrics
accuracy = sum(total_rewards) / len(total_rewards)
avg_length = sum(total_lengths) / len(total_lengths)
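        # "accuracy" here is the mean total reward over validation samples (an exact accuracy
        # only when the environments emit binary 0/1 rewards).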
val_metrics = {
"accuracy": accuracy,
"avg_length": avg_length,
}
# Print sample conversations only once at the end of validation
try:
print_message_log_samples(
all_message_logs,
total_rewards,
num_samples=min(
master_config["logger"]["num_val_samples_to_print"],
len(all_message_logs),
),
step=step,
)
except Exception as e:
print(f"\n ⚠️ Error displaying message samples: {str(e)}")
print(" ⚠️ Continuing validation without displaying samples...")
# Get timing metrics
timing_metrics = timer.get_timing_metrics(reduction_op="sum")
validation_time = timing_metrics.get("total_validation_time", 0)
# Print summary of validation results
print("\n📊 Validation Results:")
print(f" • Accuracy: {accuracy:.4f}")
print(f" • Average response length: {avg_length:.1f} tokens")
print(f" • Samples processed: {len(total_rewards)}")
# Print timing information
print("\n ⏱️ Validation Timing:")
print(f" • Total validation time: {validation_time:.2f}s")
# Make sure to reset the timer after validation
timer.reset()
return val_metrics, timing_metrics