nemo_rl.experience.rollouts#
Module Contents#
Functions#
- generate_responses – Generate responses from policy using synchronous generation.
- generate_responses_async – Async version of generate_responses that properly calls generate_async.
- calculate_rewards – Calculate rewards for generated responses and get environment feedback.
- run_multi_turn_rollout – Runs a multi-turn rollout loop, interacting with the environment.
- async_generate_response_for_sample_turn – Generate a response for a single sample's turn using async generation.
- run_sample_multi_turn_rollout – Run a multi-turn rollout for a single sample.
- run_async_multi_turn_rollout – Run multi-turn rollouts with sample-level processing.
Data#
API#
- nemo_rl.experience.rollouts.TokenizerType#
None
- nemo_rl.experience.rollouts.generate_responses(
- policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
- generation_input_data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec],
- tokenizer: nemo_rl.experience.rollouts.TokenizerType,
- input_lengths: torch.Tensor,
- include_logprobs: bool = True,
- greedy: bool = False,
- )
Generate responses from policy using synchronous generation.
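Synchronous generation requires every prompt in the batch to share one width, so variable-length prompts are padded and their true lengths tracked separately (the `input_lengths` argument above). A minimal sketch of that padding step, using plain lists in place of the real `BatchedDataDict` tensors and a hypothetical `PAD_ID`:

```python
# Illustrative only: right-pad variable-length prompts to one batch width
# and record the true lengths, as the input_lengths argument requires.
PAD_ID = 0  # hypothetical pad token id

def pad_batch(token_lists: list[list[int]]) -> tuple[list[list[int]], list[int]]:
    """Right-pad each prompt to the longest length; also return true lengths."""
    input_lengths = [len(t) for t in token_lists]
    max_len = max(input_lengths)
    padded = [t + [PAD_ID] * (max_len - len(t)) for t in token_lists]
    return padded, input_lengths

padded, lengths = pad_batch([[5, 6, 7], [9]])
```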
- async nemo_rl.experience.rollouts.generate_responses_async(
- policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
- generation_input_data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
- batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec],
- tokenizer: nemo_rl.experience.rollouts.TokenizerType,
- input_lengths: torch.Tensor,
- include_logprobs: bool = True,
- greedy: bool = False,
- )
Async version of generate_responses that properly calls generate_async.
- nemo_rl.experience.rollouts.calculate_rewards(
- batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec],
- task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
- )
Calculate rewards for generated responses and get environment feedback.
- Parameters:
batch – Batch containing message_log (LLMMessageLogType) with generated responses
task_to_env – Dictionary mapping task names to their corresponding environments
- Returns:
observations: List of observations from the environment for the next turn.
metadata: List of extracted metadata from the environment.
next_stop_strings: List of stop strings for the next generation step.
rewards: Tensor of rewards for the last turn.
terminateds: Tensor of booleans indicating if an episode ended naturally.
- Return type:
EnvironmentReturn namedtuple containing
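The pattern the return values describe can be sketched with a stand-in `EnvironmentReturn` namedtuple: route each sample to the environment registered for its task, collect per-sample results, and return them as parallel fields. The `step` method and field shapes here are illustrative assumptions, not the real `EnvironmentInterface`:

```python
from collections import namedtuple

# Illustrative stand-in; field names follow the documented return values.
EnvironmentReturn = namedtuple(
    "EnvironmentReturn",
    ["observations", "metadata", "next_stop_strings", "rewards", "terminateds"],
)

def calculate_rewards_sketch(batch, task_to_env):
    """Route each sample to its task's environment and merge the results."""
    observations, metadata, stops, rewards, terminateds = [], [], [], [], []
    for sample in batch:
        env = task_to_env[sample["task_name"]]   # pick the env for this task
        result = env.step(sample["message_log"])  # hypothetical env API
        observations.append(result["observation"])
        metadata.append(result.get("metadata"))
        stops.append(result.get("stop_strings"))
        rewards.append(result["reward"])
        terminateds.append(result["terminated"])
    return EnvironmentReturn(observations, metadata, stops, rewards, terminateds)

class EchoEnv:
    """Toy environment: reward 1.0 when the last message contains 'done'."""
    def step(self, message_log):
        last = message_log[-1]["content"]
        return {
            "observation": {"role": "environment", "content": "ok"},
            "reward": 1.0 if "done" in last else 0.0,
            "terminated": "done" in last,
        }

out = calculate_rewards_sketch(
    [{"task_name": "echo",
      "message_log": [{"role": "assistant", "content": "done"}]}],
    {"echo": EchoEnv()},
)
```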
- nemo_rl.experience.rollouts.run_multi_turn_rollout(
- policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
- input_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec],
- tokenizer: nemo_rl.experience.rollouts.TokenizerType,
- task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
- max_seq_len: int,
- max_rollout_turns: int = 999999,
- greedy: bool = False,
- )
Runs a multi-turn rollout loop, interacting with the environment.
- Parameters:
policy_generation – The generation interface (policy).
input_batch – The starting batch containing initial message logs.
tokenizer – The tokenizer.
task_to_env – Dictionary mapping task names to environment instances.
max_seq_len – Maximum sequence length allowed.
max_rollout_turns – Maximum number of agent-environment interaction turns.
greedy – Whether to use greedy decoding.
- Returns:
BatchedDataDict with the full interaction history and accumulated rewards
Dictionary of rollout metrics
- Return type:
Tuple containing
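The loop alternates policy generation and environment feedback until an episode terminates naturally or a turn/length budget is hit. A schematic version for a single dialog, with toy `generate` and `env_step` callables standing in for the policy and environment interfaces (all names illustrative):

```python
def run_multi_turn_rollout_sketch(generate, env_step, message_log,
                                  max_seq_len, max_rollout_turns=999999):
    """Alternate policy generation and environment feedback for one dialog."""
    total_reward = 0.0
    for _turn in range(max_rollout_turns):
        response = generate(message_log)                  # policy produces a message
        message_log.append({"role": "assistant", "content": response})
        obs, reward, terminated = env_step(message_log)   # environment responds
        total_reward += reward
        if terminated:
            break
        message_log.append(obs)                           # feed observation back
        if sum(len(m["content"]) for m in message_log) >= max_seq_len:
            break                                         # truncated, not terminated
    return message_log, total_reward

# Toy policy/environment: the env ends the episode once the log reaches 5 messages.
log, reward = run_multi_turn_rollout_sketch(
    generate=lambda log: f"step {len(log)}",
    env_step=lambda log: ({"role": "environment", "content": "go"},
                          1.0, len(log) >= 5),
    message_log=[{"role": "user", "content": "start"}],
    max_seq_len=10_000,
)
```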
- async nemo_rl.experience.rollouts.async_generate_response_for_sample_turn(
- policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
- sample_message_log: list[dict],
- sample_stop_strings: list[str] | None,
- tokenizer: nemo_rl.experience.rollouts.TokenizerType,
- max_seq_len: int,
- greedy: bool = False,
- )
Generate a response for a single sample's turn using async generation.
- Parameters:
policy_generation – The generation interface to use
sample_message_log – Message log for a single sample
sample_stop_strings – Stop strings for this sample
tokenizer – Tokenizer to use
max_seq_len – Maximum sequence length
greedy – Whether to use greedy decoding
- Returns:
Tuple of (updated_message_log, generated_tokens, input_lengths, generation_metrics)
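Per-sample async generation means one awaited engine call per turn, with the result appended to that sample's own log. A minimal sketch of the shape, with a fake async engine in place of the real generation interface (the `stop` handling and metrics dict are illustrative assumptions):

```python
import asyncio

async def generate_for_sample_turn(generate_async, sample_message_log,
                                   sample_stop_strings=None):
    """Await one generation for a single sample and append it to its log."""
    text = await generate_async(sample_message_log, stop=sample_stop_strings)
    updated = sample_message_log + [{"role": "assistant", "content": text}]
    return updated, {"generated_chars": len(text)}

async def fake_generate(message_log, stop=None):
    # Stand-in for an async engine call; truncates at the first stop string.
    text = "hello world"
    for s in stop or []:
        text = text.split(s)[0]
    return text

updated, metrics = asyncio.run(
    generate_for_sample_turn(fake_generate,
                             [{"role": "user", "content": "hi"}],
                             sample_stop_strings=[" world"]))
```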
- async nemo_rl.experience.rollouts.run_sample_multi_turn_rollout(
- sample_idx: int,
- initial_sample_state: dict,
- policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
- tokenizer: nemo_rl.experience.rollouts.TokenizerType,
- task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
- max_seq_len: int,
- max_rollout_turns: int = 999999,
- greedy: bool = False,
- )
Run a multi-turn rollout for a single sample.
This function manages the complete lifecycle of one sample's interaction. Async generation is used internally when available.
- Parameters:
sample_idx – Index of this sample in the original batch
initial_sample_state – Initial state containing message_log, extra_env_info, etc.
policy_generation – The generation interface
tokenizer – Tokenizer to use
task_to_env – Environment mapping
max_seq_len – Maximum sequence length
max_rollout_turns – Maximum number of turns
greedy – Whether to use greedy decoding
- Returns:
Tuple of (final_sample_state, sample_metrics)
- nemo_rl.experience.rollouts.run_async_multi_turn_rollout(
- policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
- input_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec],
- tokenizer: nemo_rl.experience.rollouts.TokenizerType,
- task_to_env: dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
- max_seq_len: int,
- max_rollout_turns: int = 999999,
- greedy: bool = False,
- )
Run multi-turn rollouts with sample-level processing.
Each sample in the batch proceeds through its interaction independently. Async generation is used internally when available, but the function itself is synchronous.
- Parameters:
policy_generation – The generation interface (policy)
input_batch – The starting batch containing initial message logs
tokenizer – The tokenizer
task_to_env – Dictionary mapping task names to environment instances
max_seq_len – Maximum sequence length allowed
max_rollout_turns – Maximum number of agent-environment interaction turns
greedy – Whether to use greedy decoding
- Returns:
BatchedDataDict with the full interaction history and accumulated rewards
Dictionary of rollout metrics
- Return type:
Tuple containing
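Sample-level processing can be pictured as one independent coroutine per sample, gathered concurrently, with a synchronous wrapper driving the event loop so callers never await anything. A toy sketch of that shape (the per-sample rollout here is a single fake generation call; all names are illustrative, not the real API):

```python
import asyncio

async def rollout_one_sample(sample_idx, state, generate_async):
    """Run one sample's (toy, single-turn) rollout independently."""
    text = await generate_async(state["message_log"])
    state["message_log"].append({"role": "assistant", "content": text})
    return sample_idx, state

def run_async_rollout_sketch(input_batch, generate_async):
    """Synchronous wrapper: drive all per-sample rollouts concurrently."""
    async def _run():
        return await asyncio.gather(
            *(rollout_one_sample(i, s, generate_async)
              for i, s in enumerate(input_batch)))
    results = asyncio.run(_run())
    results.sort(key=lambda r: r[0])        # keep results in batch order
    return [state for _, state in results]

async def fake_generate(message_log):
    await asyncio.sleep(0)                  # yield control, as a real engine would
    return f"reply to: {message_log[-1]['content']}"

batch = [{"message_log": [{"role": "user", "content": f"q{i}"}]}
         for i in range(3)]
out = run_async_rollout_sketch(batch, fake_generate)
```

The sample index returned by each coroutine lets results be re-associated with their batch positions regardless of completion order, which is why the real function can report a `BatchedDataDict` aligned with the input batch.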