# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Generate rollouts for arbitrary environments
# Supports multi-turn rollouts and many simultaneous environments (e.g., you can train on math, code, multi-turn games, and more at once)
import asyncio
import copy
from typing import Any
import ray
import torch
from transformers import PreTrainedTokenizerBase
from nemo_rl.data.interfaces import (
DatumSpec,
FlatMessagesType,
LLMMessageLogType,
)
from nemo_rl.data.llm_message_utils import (
batched_message_log_to_flat_message,
get_keys_from_message_log,
)
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.environments.interfaces import (
EnvironmentInterface,
EnvironmentReturn,
)
from nemo_rl.models.generation.interfaces import (
GenerationDatumSpec,
GenerationInterface,
GenerationOutputSpec,
)
TokenizerType = PreTrainedTokenizerBase
def generate_responses(
policy_generation: GenerationInterface,
generation_input_data: BatchedDataDict[GenerationDatumSpec],
batch: BatchedDataDict[DatumSpec],
tokenizer: TokenizerType,
input_lengths: torch.Tensor,
include_logprobs: bool = True,
greedy: bool = False,
) -> tuple[BatchedDataDict[DatumSpec], list[torch.Tensor], dict[str, float | int]]:
"""Generate responses from policy using synchronous generation."""
# Add stop_strings to generation_input_data if present in the batch
if "stop_strings" in batch:
generation_input_data["stop_strings"] = batch["stop_strings"]
else:
# Ensure the key exists even if it's None, matching GenerationDatumSpec
generation_input_data["stop_strings"] = [None] * len(input_lengths)
# Always use synchronous generation
generation_outputs = policy_generation.generate(
generation_input_data, greedy=greedy
)
# Extract everything we need from the generation outputs
output_ids = generation_outputs["output_ids"]
generation_lengths = generation_outputs["generation_lengths"]
unpadded_sequence_lengths = generation_outputs["unpadded_sequence_lengths"]
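    # Each row of output_ids is laid out as
    #   [ prompt tokens (input_len) | generated tokens | padding ]
    # so the generated slice for sample i is output_ids[i, input_len:total_length].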
# Extract generated parts
generated_ids = []
for i in range(len(input_lengths)):
input_len = input_lengths[i].item()
total_length = unpadded_sequence_lengths[i].item()
full_output = output_ids[i]
generated_part = full_output[input_len:total_length]
generated_ids.append(generated_part)
generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
# Append to message log
for i, (text, input_length, total_length) in enumerate(
zip(generated_texts, input_lengths, unpadded_sequence_lengths)
):
assistant_message = {
"role": "assistant",
"content": text,
"token_ids": output_ids[i, input_length:total_length],
}
if include_logprobs and "logprobs" in generation_outputs:
assistant_message["generation_logprobs"] = generation_outputs["logprobs"][
i, input_length:total_length
]
batch["message_log"][i].append(assistant_message)
# Generation metrics
gen_metrics = {
"mean_generation_length": generation_lengths.float().mean().item(),
"total_generated_tokens": generation_lengths.sum().item(),
}
return batch, generated_ids, gen_metrics
async def generate_responses_async(
policy_generation: GenerationInterface,
generation_input_data: BatchedDataDict[GenerationDatumSpec],
batch: BatchedDataDict[DatumSpec],
tokenizer: TokenizerType,
input_lengths: torch.Tensor,
include_logprobs: bool = True,
greedy: bool = False,
) -> tuple[BatchedDataDict[DatumSpec], list[torch.Tensor], dict[str, float | int]]:
"""Async version of generate_responses that properly calls generate_async."""
# Add stop_strings to generation_input_data if present in the batch
if "stop_strings" in batch:
generation_input_data["stop_strings"] = batch["stop_strings"]
else:
# Ensure the key exists even if it's None, matching GenerationDatumSpec
generation_input_data["stop_strings"] = [None] * len(input_lengths)
# Check if this is vLLM with async_engine enabled
use_async_generation = (
hasattr(policy_generation, "cfg")
and "vllm_cfg" in policy_generation.cfg
and policy_generation.cfg["vllm_cfg"]["async_engine"]
and hasattr(policy_generation, "generate_async")
)
assert use_async_generation, (
"Async generation is not enabled. Please enable async generation by setting async_engine=True in the vllm_cfg section of the policy config."
)
# Use async generation with per-sample streaming
collected_indexed_outputs: list[
tuple[int, BatchedDataDict[GenerationOutputSpec]]
] = []
async for original_idx, single_item_output in policy_generation.generate_async(
generation_input_data, greedy=greedy
):
collected_indexed_outputs.append((original_idx, single_item_output))
# Sort by original_idx to ensure order matches generation_input_data
collected_indexed_outputs.sort(key=lambda x: x[0])
# Extract in correct order
ordered_batched_data_dicts = [item for _, item in collected_indexed_outputs]
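    # For example, with a batch of 3 the stream may yield outputs as
    # [(2, out2), (0, out0), (1, out1)]; after sorting we recover
    # [out0, out1, out2], so downstream indexing matches the input order.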
assert ordered_batched_data_dicts, (
"Generation returned no outputs for a non-empty batch."
)
pad_token_id = policy_generation.cfg.get("pad_token_id", tokenizer.pad_token_id)
generation_outputs = BatchedDataDict.from_batches(
ordered_batched_data_dicts,
pad_value_dict={"output_ids": pad_token_id, "logprobs": 0.0},
)
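    # Each streamed output was padded independently, so from_batches re-pads the
    # per-sample tensors to a common length: output_ids with pad_token_id and
    # logprobs with 0.0.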
# Extract everything we need from the generation outputs
output_ids = generation_outputs["output_ids"]
generation_lengths = generation_outputs["generation_lengths"]
unpadded_sequence_lengths = generation_outputs["unpadded_sequence_lengths"]
# Extract generated parts
generated_ids = []
for i in range(len(input_lengths)):
input_len = input_lengths[i].item()
total_length = unpadded_sequence_lengths[i].item()
full_output = output_ids[i]
generated_part = full_output[input_len:total_length]
generated_ids.append(generated_part)
generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
# Append to message log
for i, (text, input_length, total_length) in enumerate(
zip(generated_texts, input_lengths, unpadded_sequence_lengths)
):
assistant_message = {
"role": "assistant",
"content": text,
"token_ids": output_ids[i, input_length:total_length],
}
if include_logprobs and "logprobs" in generation_outputs:
assistant_message["generation_logprobs"] = generation_outputs["logprobs"][
i, input_length:total_length
]
batch["message_log"][i].append(assistant_message)
# Generation metrics
gen_metrics = {
"mean_generation_length": generation_lengths.float().mean().item(),
"total_generated_tokens": generation_lengths.sum().item(),
}
return batch, generated_ids, gen_metrics
def calculate_rewards(
batch: BatchedDataDict[DatumSpec],
task_to_env: dict[str, EnvironmentInterface],
) -> EnvironmentReturn:
"""Calculate rewards for generated responses and get environment feedback.
Args:
batch: Batch containing message_log (LLMMessageLogType) with generated responses
task_to_env: Dictionary mapping task names to their corresponding environments
Returns:
EnvironmentReturn namedtuple containing:
- observations: List of observations from the environment for the next turn.
- metadata: List of extracted metadata from the environment.
- next_stop_strings: List of stop strings for the next generation step.
- rewards: Tensor of rewards for the last turn.
- terminateds: Tensor of booleans indicating if an episode ended naturally.
"""
    # Extract role/content-only copies of each message log to send to the environment
to_env = [
get_keys_from_message_log(batch["message_log"][i], ["role", "content"])
for i in range(len(batch["message_log"]))
]
task_names = batch["task_name"]
# Group messages by task type
task_groups: dict[str, list[tuple[int, LLMMessageLogType]]] = {}
for i, task_name in enumerate(task_names):
if task_name not in task_groups:
task_groups[task_name] = []
task_groups[task_name].append((i, to_env[i]))
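    # e.g. for task_names ["math", "code", "math"] this yields (illustrative):
    #   {"math": [(0, to_env[0]), (2, to_env[2])], "code": [(1, to_env[1])]}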
# Calculate rewards for each task group concurrently
futures = []
future_to_indices = {} # Map future to its corresponding indices
for task_name, group in task_groups.items():
if task_name not in task_to_env:
raise ValueError(f"No environment found for task type: {task_name}")
# Extract indices and messages for this group
indices = [idx for idx, _ in group]
messages = [msg for _, msg in group]
# Get corresponding environment info
env_info = [batch["extra_env_info"][i] for i in indices]
# Submit task to environment and store future
future = task_to_env[task_name].step.remote(messages, env_info) # type: ignore # ray actor call
futures.append(future)
future_to_indices[future] = indices
results = ray.get(futures)
all_rewards = []
all_env_observations = []
all_terminateds = []
all_next_stop_strings = []
all_metadata = [] # Store extracted metadata
all_indices_order = []
for future, result in zip(futures, results):
indices = future_to_indices[future]
# Environment step returns: EnvironmentReturn
env_observations, metadata, next_stop_strings, task_rewards, terminateds = (
result
)
if next_stop_strings is None:
next_stop_strings = [None] * len(task_rewards)
# Store results with their original indices
for i, idx in enumerate(indices):
all_indices_order.append(idx)
all_rewards.append(task_rewards[i])
all_env_observations.append(env_observations[i])
all_terminateds.append(terminateds[i])
all_next_stop_strings.append(next_stop_strings[i])
all_metadata.append(metadata[i])
# Sort results by original index to maintain order
sorted_indices = sorted(
range(len(all_indices_order)), key=lambda k: all_indices_order[k]
)
rewards = torch.tensor([all_rewards[i] for i in sorted_indices])
env_observations = [all_env_observations[i] for i in sorted_indices]
terminateds = torch.tensor([all_terminateds[i] for i in sorted_indices])
next_stop_strings = [all_next_stop_strings[i] for i in sorted_indices]
metadata = [all_metadata[i] for i in sorted_indices] # Sort metadata
return EnvironmentReturn(
observations=env_observations,
metadata=metadata,
next_stop_strings=next_stop_strings,
rewards=rewards,
terminateds=terminateds,
)
def run_multi_turn_rollout(
policy_generation: GenerationInterface,
input_batch: BatchedDataDict[DatumSpec],
tokenizer: TokenizerType,
task_to_env: dict[str, EnvironmentInterface],
max_seq_len: int,
max_rollout_turns: int = 999999,
greedy: bool = False,
) -> tuple[BatchedDataDict[DatumSpec], dict[str, Any]]:
"""Runs a multi-turn rollout loop, interacting with the environment.
Args:
policy_generation: The generation interface (policy).
input_batch: The starting batch containing initial message logs.
tokenizer: The tokenizer.
task_to_env: Dictionary mapping task names to environment instances.
max_rollout_turns: Maximum number of agent-environment interaction turns.
max_seq_len: Maximum sequence length allowed.
greedy: Whether to use greedy decoding.
Returns:
Tuple containing:
- BatchedDataDict with the full interaction history and accumulated rewards
- Dictionary of rollout metrics
"""
current_batch = input_batch.copy() # Work on a copy
batch_size = len(current_batch["message_log"])
active_indices = torch.arange(batch_size)
total_rewards = torch.zeros(batch_size, dtype=torch.float32)
# Initialize stop_strings from the initial batch if present
current_stop_strings = current_batch.get("stop_strings", [None] * batch_size)
# Tracking metrics for each sample
sample_turn_counts = torch.zeros(batch_size, dtype=torch.int32)
sample_token_counts = torch.zeros(batch_size, dtype=torch.int32)
sample_assistant_token_counts = torch.zeros(batch_size, dtype=torch.int32)
sample_env_token_counts = torch.zeros(batch_size, dtype=torch.int32)
sample_terminated = torch.zeros(batch_size, dtype=torch.bool)
sample_truncated = torch.zeros(batch_size, dtype=torch.bool)
sample_max_turns_reached = torch.zeros(batch_size, dtype=torch.bool)
# Tracking per-turn metrics
total_gen_tokens_per_turn = []
active_samples_per_turn = []
for turn in range(max_rollout_turns):
if len(active_indices) == 0:
break
active_samples_per_turn.append(len(active_indices))
# Convert LLMMessageLogType to FlatMessagesType for generation
active_batch = current_batch.select_indices(active_indices)
active_stop_strings = [current_stop_strings[i] for i in active_indices.tolist()]
active_flat_messages: BatchedDataDict[FlatMessagesType]
active_flat_messages, active_input_lengths = (
batched_message_log_to_flat_message(
active_batch["message_log"],
pad_value_dict={"token_ids": tokenizer.pad_token_id},
)
)
# Extract input_ids and lengths from the flat messages
active_input_ids = active_flat_messages["token_ids"]
generation_input_data = BatchedDataDict[GenerationDatumSpec](
{
"input_ids": active_input_ids,
"input_lengths": active_input_lengths,
"stop_strings": active_stop_strings,
}
)
# generate_responses updates active_batch["message_log"] in-place
active_batch, generated_ids, gen_metrics = generate_responses(
policy_generation,
generation_input_data,
active_batch,
tokenizer,
input_lengths=active_input_lengths,
greedy=greedy,
)
# Record token usage - assistant
for i, global_idx in enumerate(active_indices.tolist()):
sample_assistant_token_counts[global_idx] += len(generated_ids[i])
sample_token_counts[global_idx] += len(generated_ids[i])
# Track total generated tokens this turn
total_gen_tokens_per_turn.append(sum(len(ids) for ids in generated_ids))
# Calculate rewards and get environment feedback
env_output: EnvironmentReturn = calculate_rewards(active_batch, task_to_env)
total_rewards[active_indices] += env_output.rewards
# Update message log for ALL active samples with env observation
# This must happen BEFORE filtering based on done flags
truncation_mask = torch.zeros_like(env_output.terminateds, dtype=torch.bool)
for i, global_idx in enumerate(active_indices.tolist()):
env_obs_content = env_output.observations[i]["content"]
# Tokenize the raw content from the environment
# TODO @sahilj: handle if we want these subsequent messages to have a chat template
tokenized_obs = tokenizer(
env_obs_content, return_tensors="pt", add_special_tokens=False
).input_ids[0]
# check if new message overflows max_seq_len
if (
len(tokenized_obs) + len(generated_ids[i]) + active_input_lengths[i]
>= max_seq_len
):
tokens_left_for_obs = max_seq_len - (
len(generated_ids[i]) + active_input_lengths[i]
)
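                # Worked example (illustrative numbers): with max_seq_len=4096,
                # a 3900-token prompt and 150 generated tokens,
                # tokens_left_for_obs = 4096 - (150 + 3900) = 46, so only the
                # first 46 observation tokens are kept.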
assert tokens_left_for_obs >= 0, (
f"tokens_left_for_obs={tokens_left_for_obs} should not be negative. This should not happen if the inference engine respects the max sequence length."
)
# truncate
tokenized_obs = tokenized_obs[:tokens_left_for_obs]
truncation_mask[i] = True
# Record truncation
sample_truncated[active_indices[i]] = True
tokenized_env_obs_message = {
"role": env_output.observations[i]["role"],
"content": env_obs_content,
"token_ids": tokenized_obs,
}
current_batch["message_log"][global_idx].append(tokenized_env_obs_message)
# Record token usage - environment
sample_env_token_counts[global_idx] += len(tokenized_obs)
sample_token_counts[global_idx] += len(tokenized_obs)
# Increment turn count
sample_turn_counts[global_idx] += 1
# Determine done samples and update active set
terminateds = env_output.terminateds.bool()
done = truncation_mask | terminateds
sample_terminated[active_indices] |= done
# Update active indices for the next iteration
active_indices_local_next = torch.where(~done)[0]
active_indices = active_indices[active_indices_local_next]
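        # Index bookkeeping example (illustrative): if active_indices was
        # [0, 2, 5] and done = [False, True, False], then the local keep mask
        # selects positions [0, 2] and the new active_indices becomes [0, 5].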
continuing_indices_global = active_indices # Indices relative to original batch
# Get next stop strings and infos corresponding to the indices that are *continuing*
continuing_next_stops = [
env_output.next_stop_strings[i] for i in active_indices_local_next.tolist()
]
        # Get metadata corresponding to the continuing indices
continuing_metadata = [
env_output.metadata[i] for i in active_indices_local_next.tolist()
]
for i, global_idx in enumerate(continuing_indices_global.tolist()):
# Update stop strings for the next turn
current_stop_strings[global_idx] = continuing_next_stops[i]
# Update metadata (extra_env_info) using info from environment
if continuing_metadata[i] is not None:
current_batch["extra_env_info"][global_idx] = continuing_metadata[i]
# Record samples that reached max turns
sample_max_turns_reached[active_indices] = True
# Add total rewards to the final batch
current_batch["total_reward"] = total_rewards
# Calculate aggregate metrics
rollout_metrics = {
# Overall metrics
"total_turns": int(sample_turn_counts.sum().item()),
"avg_turns_per_sample": float(sample_turn_counts.float().mean().item()),
"max_turns_per_sample": int(sample_turn_counts.max().item()),
"natural_termination_rate": float(sample_terminated.float().mean().item()),
"truncation_rate": float(sample_truncated.float().mean().item()),
"max_turns_reached_rate": float(sample_max_turns_reached.float().mean().item()),
# Token usage metrics
"mean_total_tokens_per_sample": float(
sample_token_counts.float().mean().item()
),
"mean_gen_tokens_per_sample": float(
sample_assistant_token_counts.float().mean().item()
),
"mean_env_tokens_per_sample": float(
sample_env_token_counts.float().mean().item()
),
}
return current_batch, rollout_metrics
async def async_generate_response_for_sample_turn(
policy_generation: GenerationInterface,
sample_message_log: list[dict],
sample_stop_strings: list[str] | None,
tokenizer: TokenizerType,
max_seq_len: int,
greedy: bool = False,
) -> tuple[list[dict], torch.Tensor, torch.Tensor, dict[str, float]]:
"""Generate a response for a single sample's turn using async generation.
Args:
policy_generation: The generation interface to use
sample_message_log: Message log for a single sample
sample_stop_strings: Stop strings for this sample
tokenizer: Tokenizer to use
max_seq_len: Maximum sequence length
greedy: Whether to use greedy decoding
Returns:
Tuple of (updated_message_log, generated_tokens, input_lengths, generation_metrics)
"""
# Convert single sample to batch format
batch_message_logs = [sample_message_log]
# Convert to flat format for generation
flat_messages, input_lengths = batched_message_log_to_flat_message(
batch_message_logs,
pad_value_dict={"token_ids": tokenizer.pad_token_id},
)
# Create generation input
generation_input_data = BatchedDataDict[GenerationDatumSpec](
{
"input_ids": flat_messages["token_ids"],
"input_lengths": input_lengths,
"stop_strings": [sample_stop_strings],
}
)
# Create a dummy batch for generate_responses_async
dummy_batch = BatchedDataDict[DatumSpec](
{
"message_log": batch_message_logs,
"stop_strings": [sample_stop_strings],
}
)
# Generate response using the async version
updated_batch, generated_ids, gen_metrics = await generate_responses_async(
policy_generation,
generation_input_data,
dummy_batch,
tokenizer,
input_lengths=input_lengths,
include_logprobs=True,
greedy=greedy,
)
# Extract results for the single sample
updated_message_log = updated_batch["message_log"][0]
generated_tokens = generated_ids[0] if generated_ids else torch.empty(0)
return updated_message_log, generated_tokens, input_lengths, gen_metrics
async def run_sample_multi_turn_rollout(
sample_idx: int,
initial_sample_state: dict,
policy_generation: GenerationInterface,
tokenizer: TokenizerType,
task_to_env: dict[str, EnvironmentInterface],
max_seq_len: int,
max_rollout_turns: int = 999999,
greedy: bool = False,
) -> tuple[dict, dict[str, Any]]:
"""Run a multi-turn rollout for a single sample.
This function manages the complete lifecycle of one sample's interaction.
Async generation is used internally when available.
Args:
sample_idx: Index of this sample in the original batch
initial_sample_state: Initial state containing message_log, extra_env_info, etc.
policy_generation: The generation interface
tokenizer: Tokenizer to use
task_to_env: Environment mapping
max_seq_len: Maximum sequence length
max_rollout_turns: Maximum number of turns
greedy: Whether to use greedy decoding
Returns:
Tuple of (final_sample_state, sample_metrics)
"""
# Initialize sample state
current_message_log = copy.deepcopy(initial_sample_state["message_log"])
current_extra_env_info = copy.deepcopy(initial_sample_state["extra_env_info"])
current_stop_strings = initial_sample_state.get("stop_strings", None)
task_name = initial_sample_state["task_name"]
# Sample-level metrics
total_reward = 0.0
turn_count = 0
token_count = 0
assistant_token_count = 0
env_token_count = 0
terminated = False
truncated = False
max_turns_reached = False
# Track per-turn metrics
turn_gen_tokens = []
for turn in range(max_rollout_turns):
if terminated or truncated:
break
turn_count += 1
# Generate response for this sample using async generation
try:
(
updated_message_log,
generated_tokens,
input_lengths,
gen_metrics,
) = await async_generate_response_for_sample_turn(
policy_generation,
current_message_log,
current_stop_strings,
tokenizer,
max_seq_len,
greedy=greedy,
)
current_message_log = updated_message_log
# Update token counts
gen_token_count = len(generated_tokens)
assistant_token_count += gen_token_count
token_count += gen_token_count
turn_gen_tokens.append(gen_token_count)
except Exception as e:
print(f"Error generating response for sample {sample_idx}: {e}")
break
# Create single-sample batch for environment interaction
sample_batch = BatchedDataDict[DatumSpec](
{
"message_log": [current_message_log],
"extra_env_info": [current_extra_env_info],
"task_name": [task_name],
}
)
# Get environment feedback
env_output = calculate_rewards(sample_batch, task_to_env)
# Update total reward
total_reward += env_output.rewards[0].item()
# Check termination
terminated = env_output.terminateds[0].item()
env_obs_content = env_output.observations[0]["content"]
# Tokenize environment response
tokenized_obs = tokenizer(
env_obs_content, return_tensors="pt", add_special_tokens=False
).input_ids[0]
# Check for sequence length overflow
if input_lengths + gen_token_count + len(tokenized_obs) >= max_seq_len:
# Truncate environment observation
max_env_tokens = max_seq_len - input_lengths - gen_token_count
if max_env_tokens > 0:
tokenized_obs = tokenized_obs[:max_env_tokens]
else:
tokenized_obs = torch.empty(0, dtype=tokenized_obs.dtype)
truncated = True
env_message = {
"role": env_output.observations[0]["role"],
"content": env_obs_content,
"token_ids": tokenized_obs,
}
current_message_log.append(env_message)
# Update token counts
env_token_count += len(tokenized_obs)
token_count += len(tokenized_obs)
# Update sample state for next turn
if not terminated and not truncated:
if env_output.next_stop_strings[0] is not None:
current_stop_strings = env_output.next_stop_strings[0]
if env_output.metadata[0] is not None:
current_extra_env_info = env_output.metadata[0]
# Check if max turns reached
if turn_count >= max_rollout_turns:
max_turns_reached = True
# Prepare final sample state
final_sample_state = {
"message_log": current_message_log,
"extra_env_info": current_extra_env_info,
"task_name": task_name,
"total_reward": torch.tensor(total_reward),
"stop_strings": current_stop_strings,
"idx": sample_idx,
}
# Sample metrics
sample_metrics = {
"turn_count": turn_count,
"total_tokens": token_count,
"assistant_tokens": assistant_token_count,
"env_tokens": env_token_count,
"terminated": terminated,
"truncated": truncated,
"max_turns_reached": max_turns_reached,
"total_reward": total_reward,
"turn_gen_tokens": turn_gen_tokens,
}
return final_sample_state, sample_metrics
def run_async_multi_turn_rollout(
policy_generation: GenerationInterface,
input_batch: BatchedDataDict[DatumSpec],
tokenizer: TokenizerType,
task_to_env: dict[str, EnvironmentInterface],
max_seq_len: int,
max_rollout_turns: int = 999999,
greedy: bool = False,
) -> tuple[BatchedDataDict[DatumSpec], dict[str, Any]]:
"""Run multi-turn rollouts with sample-level processing.
Each sample in the batch proceeds through its interaction independently.
Async generation is used internally when available but the function is synchronous.
Args:
policy_generation: The generation interface (policy)
input_batch: The starting batch containing initial message logs
tokenizer: The tokenizer
task_to_env: Dictionary mapping task names to environment instances
max_seq_len: Maximum sequence length allowed
max_rollout_turns: Maximum number of agent-environment interaction turns
greedy: Whether to use greedy decoding
Returns:
Tuple containing:
- BatchedDataDict with the full interaction history and accumulated rewards
- Dictionary of rollout metrics
"""
async def _async_rollout_implementation():
"""Internal async implementation."""
batch_size = len(input_batch["message_log"])
# Prepare initial states for each sample
sample_initial_states = []
for i in range(batch_size):
sample_state = {
"message_log": input_batch["message_log"][i],
"extra_env_info": input_batch["extra_env_info"][i],
"task_name": input_batch["task_name"][i],
"stop_strings": input_batch.get("stop_strings", [None] * batch_size)[i],
"idx": input_batch.get("idx", list(range(batch_size)))[i],
}
sample_initial_states.append(sample_state)
# Run all samples concurrently
async def run_single_sample_with_error_handling(i, sample_state):
"""Wrapper to handle errors for individual sample rollouts."""
try:
result = await run_sample_multi_turn_rollout(
sample_idx=i,
initial_sample_state=sample_state,
policy_generation=policy_generation,
tokenizer=tokenizer,
task_to_env=task_to_env,
max_seq_len=max_seq_len,
max_rollout_turns=max_rollout_turns,
greedy=greedy,
)
return result
except Exception as e:
raise RuntimeError(f"Error in sample {i} rollout: {e}") from e
# Create tasks for all samples and run them concurrently
sample_tasks = [
run_single_sample_with_error_handling(i, sample_state)
for i, sample_state in enumerate(sample_initial_states)
]
# Execute all sample rollouts concurrently
sample_results = await asyncio.gather(*sample_tasks, return_exceptions=False)
# Process results
final_sample_states = []
all_sample_metrics = []
for final_state, sample_metrics in sample_results:
final_sample_states.append(final_state)
all_sample_metrics.append(sample_metrics)
# Reconstruct batch from sample results
batch_size = len(final_sample_states)
final_batch = BatchedDataDict[DatumSpec](
{
"message_log": [state["message_log"] for state in final_sample_states],
"extra_env_info": [
state["extra_env_info"] for state in final_sample_states
],
"task_name": [state["task_name"] for state in final_sample_states],
"total_reward": torch.stack(
[state["total_reward"] for state in final_sample_states]
),
"idx": [
state.get("idx", i) for i, state in enumerate(final_sample_states)
],
}
)
# Preserve additional fields from the original input_batch
for key in input_batch.keys():
if key not in final_batch:
final_batch[key] = input_batch[key]
# Aggregate metrics across all samples
rollout_metrics = {
# Overall metrics
"total_turns": sum(m["turn_count"] for m in all_sample_metrics),
"avg_turns_per_sample": sum(m["turn_count"] for m in all_sample_metrics)
/ batch_size,
"max_turns_per_sample": max(m["turn_count"] for m in all_sample_metrics),
"natural_termination_rate": sum(m["terminated"] for m in all_sample_metrics)
/ batch_size,
"truncation_rate": sum(m["truncated"] for m in all_sample_metrics)
/ batch_size,
"max_turns_reached_rate": sum(
m["max_turns_reached"] for m in all_sample_metrics
)
/ batch_size,
# Token usage metrics
"mean_total_tokens_per_sample": sum(
m["total_tokens"] for m in all_sample_metrics
)
/ batch_size,
"mean_gen_tokens_per_sample": sum(
m["assistant_tokens"] for m in all_sample_metrics
)
/ batch_size,
"mean_env_tokens_per_sample": sum(
m["env_tokens"] for m in all_sample_metrics
)
/ batch_size,
# Reward metrics
"mean_total_reward": sum(m["total_reward"] for m in all_sample_metrics)
/ batch_size,
"max_total_reward": max(m["total_reward"] for m in all_sample_metrics),
"min_total_reward": min(m["total_reward"] for m in all_sample_metrics),
}
return final_batch, rollout_metrics
return asyncio.run(_async_rollout_implementation())