nemo_rl.environments.interfaces#
Module Contents#
Classes#
| Class | Description |
|---|---|
| EnvironmentReturn | Standard batched return type for environment step methods. |
| EnvironmentInterface | Abstract base class for environment implementations. |
API#
- class nemo_rl.environments.interfaces.EnvironmentReturn[source]#
Bases: typing.NamedTuple
Standard batched return type for environment step methods.
All elements are batched.

- observations: New observations from the environment. Each is a (batched) "message": a dict with keys "role" and "content".
- metadata: Updated metadata from the environment.
- next_stop_strings: The stop strings for the next turn. If your environment is a game or similar, you may want to return a list of stop strings that are valid actions for the next turn; this field lets you control that per turn.
- rewards: The rewards for this turn.
- terminateds: Whether each episode ended this turn.
- observations: List[Dict[str, str]]#
None
- metadata: List[Optional[dict]]#
None
- next_stop_strings: List[Optional[List[str]]]#
None
- rewards: torch.Tensor#
None
- terminateds: torch.Tensor#
None
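To make the field layout concrete, here is a minimal sketch that builds one batched return value for two episodes. A local `NamedTuple` stand-in is defined so the snippet runs on its own; real code would import `EnvironmentReturn` from `nemo_rl.environments.interfaces` instead. The episode contents (a solved math problem and an in-progress one) are illustrative.

```python
from typing import Dict, List, NamedTuple, Optional

import torch

# Stand-in mirroring nemo_rl.environments.interfaces.EnvironmentReturn.
# In real code, import EnvironmentReturn from nemo_rl instead.
class EnvironmentReturn(NamedTuple):
    observations: List[Dict[str, str]]
    metadata: List[Optional[dict]]
    next_stop_strings: List[Optional[List[str]]]
    rewards: torch.Tensor
    terminateds: torch.Tensor

# A batch of two episodes: the first answered correctly and terminates,
# the second continues and constrains the next turn with a stop string.
ret = EnvironmentReturn(
    observations=[
        {"role": "user", "content": "Correct!"},
        {"role": "user", "content": "Try again."},
    ],
    metadata=[None, {"attempts": 1}],
    next_stop_strings=[None, ["</answer>"]],
    rewards=torch.tensor([1.0, 0.0]),
    terminateds=torch.tensor([True, False]),
)
```

Because `EnvironmentReturn` is a `NamedTuple`, fields are accessible both by name (`ret.rewards`) and by position, and the whole value unpacks cleanly in rollout loops.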
- class nemo_rl.environments.interfaces.EnvironmentInterface[source]#
Bases: abc.ABC
- abstractmethod step(
- message_log_batch: List[nemo_rl.data.interfaces.LLMMessageLogType],
- metadata: List[Optional[dict]],
- *args,
- **kwargs,
- ) → EnvironmentReturn[source]#
Runs a step in the environment. This function is a Ray remote, so it allows for (but does not require) asynchrony with remote servers.
- message_log_batch: Batch of OpenAI-API-like message logs that represent interactions with the LLM. Each element is a `List[Dict[str, Union[str, torch.Tensor]]]`. For example, in a math environment the message log would be:

      [
          {"role": "user", "content": "problem"},
          {"role": "assistant", "content": "response"},
      ]

  whereas in a code environment with feedback it would be:

      [
          {"role": "user", "content": "problem"},
          {"role": "assistant", "content": "response"},
          {"role": "user", "content": "code result"},
          {"role": "assistant", "content": "model response"},
      ]

- metadata: Batch of whatever state the environment needs to track, e.g. math solutions, code unit tests, or agent states. Can be None if the episode has terminated.
Returns:
EnvironmentReturn NamedTuple containing observations, metadata, next_stop_strings, rewards, and terminateds flags.
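The contract above can be sketched with a toy single-turn math environment. This is a minimal illustration, not the library's own implementation: the `EnvironmentReturn` stand-in and the metadata key `"answer"` are assumptions, and a real subclass would inherit from `nemo_rl.environments.interfaces.EnvironmentInterface` and typically run as a Ray remote.

```python
from typing import Dict, List, NamedTuple, Optional

import torch

# Stand-in for nemo_rl.environments.interfaces.EnvironmentReturn so the
# sketch is self-contained; real code imports it from nemo_rl.
class EnvironmentReturn(NamedTuple):
    observations: List[Dict[str, str]]
    metadata: List[Optional[dict]]
    next_stop_strings: List[Optional[List[str]]]
    rewards: torch.Tensor
    terminateds: torch.Tensor

class ToyMathEnvironment:
    """Grades the assistant's last message against a ground-truth answer
    carried in per-sample metadata (the 'answer' key is illustrative)."""

    def step(self, message_log_batch, metadata):
        observations, rewards, terminateds = [], [], []
        for log, meta in zip(message_log_batch, metadata):
            response = log[-1]["content"].strip()  # last assistant turn
            correct = meta is not None and response == meta["answer"]
            observations.append(
                {"role": "user", "content": "Correct!" if correct else "Incorrect."}
            )
            rewards.append(1.0 if correct else 0.0)
            terminateds.append(True)  # single-turn task: always terminate
        return EnvironmentReturn(
            observations=observations,
            metadata=metadata,
            next_stop_strings=[None] * len(message_log_batch),
            rewards=torch.tensor(rewards),
            terminateds=torch.tensor(terminateds),
        )

env = ToyMathEnvironment()
out = env.step(
    [[{"role": "user", "content": "2+2?"}, {"role": "assistant", "content": "4"}]],
    [{"answer": "4"}],
)
```

Note that every field of the returned `EnvironmentReturn` has one entry per element of `message_log_batch`, which is what "all elements are batched" requires.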
- abstractmethod global_post_process_and_metrics( ... ) → Tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict, dict][source]#
Post-processing hook that runs after all rollouts for the batch are complete; returns the processed batch along with a dict of metrics.
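The parameter list of this method is truncated in the rendered page, so the following is a loosely hedged sketch only: it assumes the hook receives the rollout batch (modeled here as a plain dict standing in for `BatchedDataDict`, with an assumed `"rewards"` key) and returns the possibly-modified batch plus a metrics dict.

```python
from typing import Dict, List, Tuple

# Hedged sketch: a plain dict stands in for BatchedDataDict, and the
# "rewards" key is an assumption for illustration.
def global_post_process_and_metrics(
    batch: Dict[str, List[float]],
) -> Tuple[Dict[str, List[float]], dict]:
    rewards = batch.get("rewards", [])
    metrics = {
        "mean_reward": sum(rewards) / len(rewards) if rewards else 0.0,
        "num_samples": len(rewards),
    }
    # The batch may also be filtered or transformed here before returning.
    return batch, metrics

processed, metrics = global_post_process_and_metrics({"rewards": [1.0, 0.0, 1.0]})
```

The two-value return matches the documented `Tuple[BatchedDataDict, dict]` shape: the (possibly transformed) batch first, aggregate metrics second.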