# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Generate rollouts for arbitrary environments.
# Supports multi-turn rollouts and many simultaneous environments (e.g., you can
# train on math, code, multi-turn games, and more at once).
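#
# Illustrative sketch of how this module is typically driven (hypothetical names:
# ``policy`` is any GenerationInterface, ``dataloader_batch`` a
# BatchedDataDict[DatumSpec] with tokenized message logs, and ``envs`` a dict of
# Ray environment actors; none of these are defined in this file):
#
#     final_batch, rollout_metrics = run_multi_turn_rollout(
#         policy_generation=policy,
#         input_batch=dataloader_batch,
#         tokenizer=tokenizer,
#         task_to_env=envs,
#         max_seq_len=4096,
#         max_rollout_turns=8,
#     )
#     per_sample_rewards = final_batch["total_reward"]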
from typing import Any, Dict, List, Tuple
import ray
import torch
from transformers import AutoTokenizer
from nemo_rl.data.interfaces import (
    DatumSpec,
    FlatMessagesType,
)
from nemo_rl.data.llm_message_utils import (
    batched_message_log_to_flat_message,
    get_keys_from_message_log,
)
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.environments.interfaces import (
    EnvironmentInterface,
    EnvironmentReturn,
)
from nemo_rl.models.generation.interfaces import (
    GenerationDatumSpec,
    GenerationInterface,
)
def generate_responses(
    policy_generation: GenerationInterface,
    generation_input_data: BatchedDataDict[GenerationDatumSpec],
    batch: BatchedDataDict[DatumSpec],
    tokenizer: AutoTokenizer,
    input_lengths: torch.Tensor,
    include_logprobs: bool = True,
    greedy: bool = False,
) -> Tuple[BatchedDataDict[DatumSpec], List[torch.Tensor], dict]:
    """Generate responses from the policy and append them to each sample's message log.

    Returns the updated batch (``message_log`` is modified in place), the list of
    generated token tensors (one per sample, prompt tokens excluded), and a dict of
    generation metrics.
    """
    # Add stop_strings to generation_input_data if present in the batch
    if "stop_strings" in batch:
        generation_input_data["stop_strings"] = batch["stop_strings"]
    else:
        # Ensure the key exists even if it's None, matching GenerationDatumSpec
        generation_input_data["stop_strings"] = [None] * len(input_lengths)

    # Generate responses
    generation_outputs = policy_generation.generate(
        generation_input_data, greedy=greedy
    )

    # Extract generated tokens
    generated_ids = []
    unpadded_sequence_lengths = generation_outputs["unpadded_sequence_lengths"]
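    # Note: ``output_ids`` from the generation backend is assumed to contain the
    # prompt tokens followed by the completion (right-padded), so slicing
    # ``[input_length:total_length]`` below isolates only the newly generated tokens.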
    for output_ids, input_length, total_length in zip(
        generation_outputs["output_ids"], input_lengths, unpadded_sequence_lengths
    ):
        generated_ids.append(output_ids[input_length:total_length])
    generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # Append to message log
    for i, (text, input_length, total_length) in enumerate(
        zip(generated_texts, input_lengths, unpadded_sequence_lengths)
    ):
        message = {
            "role": "assistant",
            "content": text,
            "token_ids": generation_outputs["output_ids"][i, input_length:total_length],
        }
        if include_logprobs and "logprobs" in generation_outputs:
            message["generation_logprobs"] = generation_outputs["logprobs"][
                i, input_length:total_length
            ]
        batch["message_log"][i].append(message)

    metrics = {
        "mean_generation_length": (
            torch.sum(unpadded_sequence_lengths) - torch.sum(input_lengths)
        ).item()
        / len(unpadded_sequence_lengths),
        "max_seqlen": torch.max(unpadded_sequence_lengths).item(),
    }
    return batch, generated_ids, metrics
def calculate_rewards(
    batch: BatchedDataDict[DatumSpec],
    task_to_env: Dict[str, EnvironmentInterface],
) -> EnvironmentReturn:
    """Calculate rewards for generated responses and get environment feedback.

    Args:
        batch: Batch containing message_log (LLMMessageLogType) with generated responses
        task_to_env: Dictionary mapping task names to their corresponding environments

    Returns:
        EnvironmentReturn namedtuple containing:
            - observations: List of observations from the environment for the next turn.
            - metadata: List of extracted metadata from the environment.
            - next_stop_strings: List of stop strings for the next generation step.
            - rewards: Tensor of rewards for the last turn.
            - terminateds: Tensor of booleans indicating if an episode ended naturally.
    """
    # Extract message logs for environment (most recent interaction)
    to_env = [
        get_keys_from_message_log(batch["message_log"][i], ["role", "content"])
        for i in range(len(batch["message_log"]))
    ]
    task_names = batch["task_name"]

    # Group messages by task type
    task_groups = {}
    for i, task_name in enumerate(task_names):
        if task_name not in task_groups:
            task_groups[task_name] = []
        task_groups[task_name].append((i, to_env[i]))

    # Calculate rewards for each task group concurrently
    futures = []
    future_to_indices = {}  # Map future to its corresponding indices
    for task_name, group in task_groups.items():
        if task_name not in task_to_env:
            raise ValueError(f"No environment found for task type: {task_name}")
        # Extract indices and messages for this group
        indices = [idx for idx, _ in group]
        messages = [msg for _, msg in group]
        # Get corresponding environment info
        env_info = [batch["extra_env_info"][i] for i in indices]
        # Submit task to environment and store future
        future = task_to_env[task_name].step.remote(messages, env_info)
        futures.append(future)
        future_to_indices[future] = indices
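    # ray.get blocks until every environment actor has finished its step();
    # results come back in the same order the futures were submitted.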
    results = ray.get(futures)

    all_rewards = []
    all_env_observations = []
    all_terminateds = []
    all_next_stop_strings = []
    all_metadata = []  # Store extracted metadata
    all_indices_order = []
    for future, result in zip(futures, results):
        indices = future_to_indices[future]
        # Environment step returns: EnvironmentReturn
        env_observations, metadata, next_stop_strings, task_rewards, terminateds = (
            result
        )
        if next_stop_strings is None:
            next_stop_strings = [None] * len(task_rewards)
        # Store results with their original indices
        for i, idx in enumerate(indices):
            all_indices_order.append(idx)
            all_rewards.append(task_rewards[i])
            all_env_observations.append(env_observations[i])
            all_terminateds.append(terminateds[i])
            all_next_stop_strings.append(next_stop_strings[i])
            all_metadata.append(metadata[i])

    # Sort results by original index to maintain order
    sorted_indices = sorted(
        range(len(all_indices_order)), key=lambda k: all_indices_order[k]
    )
    rewards = torch.tensor([all_rewards[i] for i in sorted_indices])
    env_observations = [all_env_observations[i] for i in sorted_indices]
    terminateds = torch.tensor([all_terminateds[i] for i in sorted_indices])
    next_stop_strings = [all_next_stop_strings[i] for i in sorted_indices]
    metadata = [all_metadata[i] for i in sorted_indices]  # Sort metadata

    return EnvironmentReturn(
        observations=env_observations,
        metadata=metadata,
        next_stop_strings=next_stop_strings,
        rewards=rewards,
        terminateds=terminateds,
    )
def run_multi_turn_rollout(
    policy_generation: GenerationInterface,
    input_batch: BatchedDataDict[DatumSpec],
    tokenizer: AutoTokenizer,
    task_to_env: Dict[str, EnvironmentInterface],
    max_seq_len: int,
    max_rollout_turns: int = 999999,
    greedy: bool = False,
) -> Tuple[BatchedDataDict[DatumSpec], Dict[str, Any]]:
    """Runs a multi-turn rollout loop, interacting with the environment.

    Args:
        policy_generation: The generation interface (policy).
        input_batch: The starting batch containing initial message logs.
        tokenizer: The tokenizer.
        task_to_env: Dictionary mapping task names to environment instances.
        max_seq_len: Maximum sequence length allowed.
        max_rollout_turns: Maximum number of agent-environment interaction turns.
        greedy: Whether to use greedy decoding.

    Returns:
        Tuple containing:
            - BatchedDataDict with the full interaction history and accumulated rewards
            - Dictionary of rollout metrics
    """
    current_batch = input_batch.copy()  # Work on a copy
    batch_size = len(current_batch["message_log"])
    active_indices = torch.arange(batch_size)
    total_rewards = torch.zeros(batch_size, dtype=torch.float32)

    # Initialize stop_strings from the initial batch if present
    current_stop_strings = current_batch.get("stop_strings", [None] * batch_size)

    # Tracking metrics for each sample
    sample_turn_counts = torch.zeros(batch_size, dtype=torch.int32)
    sample_token_counts = torch.zeros(batch_size, dtype=torch.int32)
    sample_assistant_token_counts = torch.zeros(batch_size, dtype=torch.int32)
    sample_env_token_counts = torch.zeros(batch_size, dtype=torch.int32)
    sample_terminated = torch.zeros(batch_size, dtype=torch.bool)
    sample_truncated = torch.zeros(batch_size, dtype=torch.bool)
    sample_max_turns_reached = torch.zeros(batch_size, dtype=torch.bool)

    # Tracking per-turn metrics
    total_gen_tokens_per_turn = []
    active_samples_per_turn = []

    for turn in range(max_rollout_turns):
        if len(active_indices) == 0:
            break
        active_samples_per_turn.append(len(active_indices))

        # Convert LLMMessageLogType to FlatMessagesType for generation
        active_batch = current_batch.select_indices(active_indices)
        active_stop_strings = [current_stop_strings[i] for i in active_indices.tolist()]
        active_flat_messages: FlatMessagesType
        active_flat_messages, active_input_lengths = (
            batched_message_log_to_flat_message(
                active_batch["message_log"],
                pad_value_dict={"token_ids": tokenizer.pad_token_id},
            )
        )
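        # ``token_ids`` here is assumed to be a right-padded [num_active, max_len]
        # tensor, with ``active_input_lengths`` giving each row's true length.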
        # Extract input_ids and lengths from the flat messages
        active_input_ids = active_flat_messages["token_ids"]
        generation_input_data = BatchedDataDict[GenerationDatumSpec](
            {
                "input_ids": active_input_ids,
                "input_lengths": active_input_lengths,
                "stop_strings": active_stop_strings,
            }
        )

        # generate_responses updates active_batch["message_log"] in-place
        active_batch, generated_ids, gen_metrics = generate_responses(
            policy_generation,
            generation_input_data,
            active_batch,
            tokenizer,
            active_input_lengths,
            greedy=greedy,
        )

        # Record token usage - assistant
        for i, global_idx in enumerate(active_indices.tolist()):
            sample_assistant_token_counts[global_idx] += len(generated_ids[i])
            sample_token_counts[global_idx] += len(generated_ids[i])

        # Track total generated tokens this turn
        total_gen_tokens_per_turn.append(sum(len(ids) for ids in generated_ids))

        # Calculate rewards and get environment feedback
        env_output: EnvironmentReturn = calculate_rewards(active_batch, task_to_env)
        total_rewards[active_indices] += env_output.rewards

        # Update message log for ALL active samples with env observation
        # This must happen BEFORE filtering based on done flags
        truncation_mask = torch.zeros_like(env_output.terminateds, dtype=torch.bool)
        for i, global_idx in enumerate(active_indices.tolist()):
            env_obs_content = env_output.observations[i]["content"]
            # Tokenize the raw content from the environment
            # TODO @sahilj: handle if we want these subsequent messages to have a chat template
            tokenized_obs = tokenizer(
                env_obs_content, return_tensors="pt", add_special_tokens=False
            )["input_ids"][0]
            # Check if the new message overflows max_seq_len
            if (
                len(tokenized_obs) + len(generated_ids[i]) + active_input_lengths[i]
                >= max_seq_len
            ):
                # Truncate
                tokenized_obs = tokenized_obs[
                    : max_seq_len - (len(generated_ids[i]) + active_input_lengths[i])
                ]
                truncation_mask[i] = True
                # Record truncation
                sample_truncated[active_indices[i]] = True
            tokenized_env_obs_message = {
                "role": env_output.observations[i]["role"],
                "content": env_obs_content,
                "token_ids": tokenized_obs,
            }
            current_batch["message_log"][global_idx].append(tokenized_env_obs_message)
            # Record token usage - environment
            sample_env_token_counts[global_idx] += len(tokenized_obs)
            sample_token_counts[global_idx] += len(tokenized_obs)
            # Increment turn count
            sample_turn_counts[global_idx] += 1

        # Determine done samples and update active set
        terminateds = env_output.terminateds.bool()
        done = truncation_mask | terminateds
        sample_terminated[active_indices] |= done

        # Update active indices for the next iteration
        active_indices_local_next = torch.where(~done)[0]
        active_indices = active_indices[active_indices_local_next]
        continuing_indices_global = active_indices  # Indices relative to original batch
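        # Note: ``active_indices_local_next`` indexes into this turn's env outputs
        # (one entry per previously active sample), while ``continuing_indices_global``
        # indexes into the original batch.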
        # Get next stop strings and infos corresponding to the indices that are *continuing*
        continuing_next_stops = [
            env_output.next_stop_strings[i] for i in active_indices_local_next.tolist()
        ]
        # Get metadata corresponding to continuing indices, using the correct field name
        continuing_metadata = [
            env_output.metadata[i] for i in active_indices_local_next.tolist()
        ]
        for i, global_idx in enumerate(continuing_indices_global.tolist()):
            # Update stop strings for the next turn
            current_stop_strings[global_idx] = continuing_next_stops[i]
            # Update metadata (extra_env_info) using info from environment
            if continuing_metadata[i] is not None:
                current_batch["extra_env_info"][global_idx] = continuing_metadata[i]

    # Record samples that reached max turns
    sample_max_turns_reached[active_indices] = True

    # Add total rewards to the final batch
    current_batch["total_reward"] = total_rewards

    # Calculate aggregate metrics
    rollout_metrics = {
        # Overall metrics
        "total_turns": int(sample_turn_counts.sum().item()),
        "avg_turns_per_sample": float(sample_turn_counts.float().mean().item()),
        "max_turns_per_sample": int(sample_turn_counts.max().item()),
        "natural_termination_rate": float(sample_terminated.float().mean().item()),
        "truncation_rate": float(sample_truncated.float().mean().item()),
        "max_turns_reached_rate": float(sample_max_turns_reached.float().mean().item()),
        # Token usage metrics
        "mean_gen_tokens_per_sample": float(sample_token_counts.float().mean().item()),
        "mean_env_tokens_per_sample": float(
            sample_env_token_counts.float().mean().item()
        ),
    }
    return current_batch, rollout_metrics