nemo_rl.experience.rollouts

Module Contents

Functions

generate_responses

Generate responses from the policy.

calculate_rewards

Calculate rewards for generated responses and get environment feedback.

run_multi_turn_rollout

Runs a multi-turn rollout loop, interacting with the environment.

API

nemo_rl.experience.rollouts.generate_responses(
policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
generation_input_data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec],
tokenizer: transformers.AutoTokenizer,
input_lengths: torch.Tensor,
include_logprobs: bool = True,
greedy: bool = False,
) → Tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], List[torch.Tensor], dict]

Generate responses from the policy.

nemo_rl.experience.rollouts.calculate_rewards(
batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec],
task_to_env: Dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
) → nemo_rl.environments.interfaces.EnvironmentReturn

Calculate rewards for generated responses and get environment feedback.

Parameters:
  • batch – Batch containing message_log (LLMMessageLogType) with generated responses

  • task_to_env – Dictionary mapping task names to their corresponding environments

Returns:

EnvironmentReturn namedtuple containing:

  • observations: List of observations from the environment for the next turn.

  • metadata: List of extracted metadata from the environment.

  • next_stop_strings: List of stop strings for the next generation step.

  • rewards: Tensor of rewards for the last turn.

  • terminateds: Tensor of booleans indicating if an episode ended naturally.

Return type:

nemo_rl.environments.interfaces.EnvironmentReturn
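The per-task dispatch implied by task_to_env can be illustrated with a minimal, self-contained sketch. It is not the library's implementation: the EnvironmentReturn stand-in only mirrors the field names listed above, KeywordEnv and calculate_rewards_sketch are hypothetical, the batch keys "task_name" and "metadata" are assumptions, and plain Python lists are used where the real function returns tensors.

```python
from collections import defaultdict, namedtuple
from typing import Any, Dict, List

# Stand-in for EnvironmentReturn; field names follow the documented return values.
EnvironmentReturn = namedtuple(
    "EnvironmentReturn",
    ["observations", "metadata", "next_stop_strings", "rewards", "terminateds"],
)


class KeywordEnv:
    """Hypothetical environment: reward 1.0 if the last message contains a keyword."""

    def __init__(self, keyword: str):
        self.keyword = keyword

    def step(self, message_logs: List[List[dict]], metadata: List[Any]) -> EnvironmentReturn:
        rewards = [1.0 if self.keyword in log[-1]["content"] else 0.0 for log in message_logs]
        observations = [{"role": "environment", "content": "done"} for _ in message_logs]
        stop_strings = [None] * len(message_logs)
        terminateds = [True] * len(message_logs)  # every episode ends after one turn
        return EnvironmentReturn(observations, metadata, stop_strings, rewards, terminateds)


def calculate_rewards_sketch(
    batch: Dict[str, list], task_to_env: Dict[str, KeywordEnv]
) -> EnvironmentReturn:
    # Group sample indices by task so each environment is called once.
    indices_by_task = defaultdict(list)
    for i, task in enumerate(batch["task_name"]):  # "task_name" key is an assumption
        indices_by_task[task].append(i)

    n = len(batch["task_name"])
    merged = {field: [None] * n for field in EnvironmentReturn._fields}
    for task, indices in indices_by_task.items():
        result = task_to_env[task].step(
            [batch["message_log"][i] for i in indices],
            [batch["metadata"][i] for i in indices],  # "metadata" key is an assumption
        )
        # Scatter per-task results back into the original sample order.
        for field in EnvironmentReturn._fields:
            for pos, i in enumerate(indices):
                merged[field][i] = getattr(result, field)[pos]
    return EnvironmentReturn(**merged)


batch = {
    "task_name": ["math", "chat", "math"],
    "message_log": [
        [{"role": "assistant", "content": "2+2=4"}],
        [{"role": "assistant", "content": "goodbye"}],
        [{"role": "assistant", "content": "2+2=5"}],
    ],
    "metadata": [{}, {}, {}],
}
result = calculate_rewards_sketch(
    batch, {"math": KeywordEnv("4"), "chat": KeywordEnv("hello")}
)
print(result.rewards)  # [1.0, 0.0, 0.0]
```

Grouping by task before calling each environment keeps the number of environment calls equal to the number of distinct tasks, and scattering by index preserves the batch order that downstream code is likely to rely on.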

nemo_rl.experience.rollouts.run_multi_turn_rollout(
policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
input_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec],
tokenizer: transformers.AutoTokenizer,
task_to_env: Dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
max_seq_len: int,
max_rollout_turns: int = 999999,
greedy: bool = False,
) → Tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec], Dict[str, Any]]

Runs a multi-turn rollout loop, interacting with the environment.

Parameters:
  • policy_generation – The generation interface (policy).

  • input_batch – The starting batch containing initial message logs.

  • tokenizer – The tokenizer.

  • task_to_env – Dictionary mapping task names to environment instances.

  • max_rollout_turns – Maximum number of agent-environment interaction turns.

  • max_seq_len – Maximum sequence length allowed.

  • greedy – Whether to use greedy decoding.

Returns:

Tuple containing:

  • BatchedDataDict with the full interaction history and accumulated rewards

  • Dictionary of rollout metrics

Return type:

Tuple[BatchedDataDict[DatumSpec], Dict[str, Any]]
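The general shape of a multi-turn rollout loop, bounded by max_rollout_turns and max_seq_len, can be sketched as follows. This is an illustrative toy under stated assumptions, not the library's code: toy_policy, CountToEnv, and rollout_sketch are hypothetical, sequence length is approximated by whitespace-separated words rather than tokens, and the metric names are invented for the example.

```python
from typing import Dict, List, Tuple


def toy_policy(message_log: List[dict]) -> str:
    """Hypothetical policy: counts up by one on each assistant turn."""
    turns = sum(1 for m in message_log if m["role"] == "assistant")
    return str(turns + 1)


class CountToEnv:
    """Hypothetical environment: the episode ends when the target number is said."""

    def __init__(self, target: int):
        self.target = target

    def step(self, message_log: List[dict]) -> Tuple[dict, float, bool]:
        done = message_log[-1]["content"] == str(self.target)
        return {"role": "environment", "content": "keep going"}, float(done), done


def rollout_sketch(
    message_logs: List[List[dict]],
    task_names: List[str],
    task_to_env: Dict[str, CountToEnv],
    max_seq_len: int,
    max_rollout_turns: int = 999999,
) -> Tuple[dict, dict]:
    n = len(message_logs)
    total_rewards = [0.0] * n
    turns_taken = [0] * n
    truncated = [False] * n
    active = list(range(n))

    for _ in range(max_rollout_turns):
        if not active:
            break  # every sample has terminated or been truncated
        still_active = []
        for i in active:
            log = message_logs[i]
            log.append({"role": "assistant", "content": toy_policy(log)})
            turns_taken[i] += 1
            observation, reward, done = task_to_env[task_names[i]].step(log)
            total_rewards[i] += reward
            if done:
                continue
            log.append(observation)
            # Word count stands in for token count in this sketch.
            if sum(len(m["content"].split()) for m in log) >= max_seq_len:
                truncated[i] = True
                continue
            still_active.append(i)
        active = still_active

    metrics = {
        "mean_turns": sum(turns_taken) / n,
        "truncation_rate": sum(truncated) / n,
    }
    return {"message_log": message_logs, "total_reward": total_rewards}, metrics


def make_logs() -> List[List[dict]]:
    return [[{"role": "user", "content": "count"}] for _ in range(2)]


envs = {"to_two": CountToEnv(2), "to_three": CountToEnv(3)}
final_batch, metrics = rollout_sketch(make_logs(), ["to_two", "to_three"], envs, max_seq_len=100)
print(final_batch["total_reward"], metrics["mean_turns"])
```

Samples leave the active set as soon as their environment reports termination or their history reaches the length limit, so later turns only generate for samples that still need them.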