nemo_rl.experience.rollouts
Module Contents
Functions
generate_responses | Generate responses from policy.
calculate_rewards | Calculate rewards for generated responses and get environment feedback.
run_multi_turn_rollout | Runs a multi-turn rollout loop, interacting with the environment.
API
- nemo_rl.experience.rollouts.generate_responses(
    policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
    generation_input_data: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.models.generation.interfaces.GenerationDatumSpec],
    batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec],
    tokenizer: transformers.AutoTokenizer,
    input_lengths: torch.Tensor,
    include_logprobs: bool = True,
    greedy: bool = False,
  )
Generate responses from policy.
- nemo_rl.experience.rollouts.calculate_rewards(
    batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec],
    task_to_env: Dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
  )
Calculate rewards for generated responses and get environment feedback.
- Parameters:
batch – Batch containing message_log (LLMMessageLogType) with generated responses
task_to_env – Dictionary mapping task names to their corresponding environments
- Returns:
An EnvironmentReturn namedtuple containing:
observations: List of observations from the environment for the next turn.
metadata: List of extracted metadata from the environment.
next_stop_strings: List of stop strings for the next generation step.
rewards: Tensor of rewards for the last turn.
terminateds: Tensor of booleans indicating if an episode ended naturally.
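The per-task dispatch described above can be illustrated with a small, self-contained sketch. It does not import or call nemo_rl: `ToyEnvReturn`, `ExactMatchEnv`, `toy_calculate_rewards`, and the batch keys `task_name` and `extra_env_info` are all hypothetical stand-ins chosen for illustration, and plain Python lists are used where the documented function returns tensors.

```python
from collections import namedtuple
from typing import Dict, List

# Hypothetical stand-in that mirrors the documented return fields.
# It is NOT the real nemo_rl EnvironmentReturn class.
ToyEnvReturn = namedtuple(
    "ToyEnvReturn",
    ["observations", "metadata", "next_stop_strings", "rewards", "terminateds"],
)


class ExactMatchEnv:
    """Toy environment: reward 1.0 when the last message equals the expected answer."""

    def step(self, message_logs: List[List[dict]], metadata: List[dict]) -> ToyEnvReturn:
        observations, rewards, terminateds = [], [], []
        for log, meta in zip(message_logs, metadata):
            correct = log[-1]["content"].strip() == meta["answer"]
            observations.append(
                {"role": "environment", "content": "correct" if correct else "incorrect"}
            )
            rewards.append(1.0 if correct else 0.0)
            terminateds.append(True)  # single-turn task: every episode ends here
        return ToyEnvReturn(observations, metadata, [None] * len(rewards), rewards, terminateds)


def toy_calculate_rewards(batch: Dict[str, list], task_to_env: Dict[str, ExactMatchEnv]) -> ToyEnvReturn:
    """Group samples by task, call each environment once, then restore batch order."""
    n = len(batch["task_name"])
    indices_by_task: Dict[str, List[int]] = {}
    for i, task in enumerate(batch["task_name"]):
        indices_by_task.setdefault(task, []).append(i)

    merged = {field: [None] * n for field in ToyEnvReturn._fields}
    for task, indices in indices_by_task.items():
        result = task_to_env[task].step(
            [batch["message_log"][i] for i in indices],
            [batch["extra_env_info"][i] for i in indices],
        )
        # Scatter each environment's results back to the original sample positions.
        for field in ToyEnvReturn._fields:
            for i, value in zip(indices, getattr(result, field)):
                merged[field][i] = value
    return ToyEnvReturn(**merged)


batch = {
    "task_name": ["math", "trivia", "math"],
    "message_log": [
        [{"role": "user", "content": "2+2?"}, {"role": "assistant", "content": "4"}],
        [{"role": "user", "content": "Capital of France?"}, {"role": "assistant", "content": "Paris"}],
        [{"role": "user", "content": "3+3?"}, {"role": "assistant", "content": "7"}],
    ],
    "extra_env_info": [{"answer": "4"}, {"answer": "Paris"}, {"answer": "6"}],
}
result = toy_calculate_rewards(batch, {"math": ExactMatchEnv(), "trivia": ExactMatchEnv()})
print(result.rewards)
```

Grouping by task lets each environment score its samples in one call, while the scatter step keeps the returned lists aligned with the input batch order.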
- nemo_rl.experience.rollouts.run_multi_turn_rollout(
    policy_generation: nemo_rl.models.generation.interfaces.GenerationInterface,
    input_batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[nemo_rl.data.interfaces.DatumSpec],
    tokenizer: transformers.AutoTokenizer,
    task_to_env: Dict[str, nemo_rl.environments.interfaces.EnvironmentInterface],
    max_seq_len: int,
    max_rollout_turns: int = 999999,
    greedy: bool = False,
  )
Runs a multi-turn rollout loop, interacting with the environment.
- Parameters:
policy_generation – The generation interface (policy).
input_batch – The starting batch containing initial message logs.
tokenizer – The tokenizer.
task_to_env – Dictionary mapping task names to environment instances.
max_rollout_turns – Maximum number of agent-environment interaction turns.
max_seq_len – Maximum sequence length allowed.
greedy – Whether to use greedy decoding.
- Returns:
A tuple containing:
BatchedDataDict with the full interaction history and accumulated rewards.
Dictionary of rollout metrics.
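The general shape of a multi-turn rollout loop (generate, step the environment, accumulate rewards, stop on termination, turn limit, or length limit) can be sketched as follows. This is a simplified illustration under stated assumptions, not the nemo_rl implementation: `toy_multi_turn_rollout`, `toy_generate`, `toy_env_step`, `make_logs`, and the metric names are hypothetical, samples are processed one at a time rather than batched, and sequence length is counted in characters instead of tokens.

```python
from typing import Callable, Dict, List, Tuple


def toy_multi_turn_rollout(
    generate: Callable[[List[dict]], str],
    env_step: Callable[[List[dict]], Tuple[str, float, bool]],
    message_logs: List[List[dict]],
    max_seq_len: int,
    max_rollout_turns: int = 999999,
) -> Tuple[Dict[str, list], Dict[str, float]]:
    """Alternate policy generation and environment feedback until every sample stops."""
    total_rewards = [0.0] * len(message_logs)
    active = list(range(len(message_logs)))
    turns = 0
    while active and turns < max_rollout_turns:
        turns += 1
        still_active = []
        for i in active:
            log = message_logs[i]
            log.append({"role": "assistant", "content": generate(log)})
            observation, reward, terminated = env_step(log)
            total_rewards[i] += reward
            log.append({"role": "environment", "content": observation})
            # Assumption: length measured in characters here; a real loop would count tokens.
            too_long = sum(len(m["content"]) for m in log) >= max_seq_len
            if not terminated and not too_long:
                still_active.append(i)
        active = still_active
    metrics = {
        "turns": turns,
        "mean_reward": sum(total_rewards) / max(len(total_rewards), 1),
    }
    return {"message_log": message_logs, "total_reward": total_rewards}, metrics


def toy_generate(log: List[dict]) -> str:
    """Toy policy: counts upward, one number per assistant turn."""
    return str(sum(1 for m in log if m["role"] == "assistant") + 1)


def toy_env_step(log: List[dict]) -> Tuple[str, float, bool]:
    """Toy environment: the episode ends with reward 1.0 once the policy says "3"."""
    done = log[-1]["content"] == "3"
    return ("done" if done else "continue", 1.0 if done else 0.0, done)


def make_logs() -> List[List[dict]]:
    """Two samples; the second starts with one completed turn already in its history."""
    return [
        [{"role": "user", "content": "count to 3"}],
        [
            {"role": "user", "content": "count to 3"},
            {"role": "assistant", "content": "1"},
            {"role": "environment", "content": "continue"},
        ],
    ]


final_batch, rollout_metrics = toy_multi_turn_rollout(
    toy_generate, toy_env_step, make_logs(), max_seq_len=1000
)
print(final_batch["total_reward"], rollout_metrics)
```

Samples that terminate early drop out of the active set, so later turns only spend generation on unfinished episodes; lowering `max_rollout_turns` or `max_seq_len` cuts the remaining episodes off before they earn their final reward.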