nemo_rl.environments.interfaces#

Module Contents#

Classes#

EnvironmentReturn

Standard batched return type for environment step methods.

EnvironmentInterface

API#

class nemo_rl.environments.interfaces.EnvironmentReturn[source]#

Bases: typing.NamedTuple

Standard batched return type for environment step methods.

All elements are batched.

  • observations: new observation from the environment. Each element is a 'message'-type dict with keys 'role' and 'content'.

  • metadata: updated metadata from the environment.

  • next_stop_strings: the stop strings for the next turn. If your environment is a game or similar, you may want to return a list of stop strings that are valid actions for the next turn. This field lets you control stop strings per turn.

  • rewards: the rewards for this turn.

  • terminateds: whether the episode ended this turn.

observations: List[Dict[str, str]]#

None

metadata: List[Optional[dict]]#

None

next_stop_strings: List[Optional[List[str]]]#

None

rewards: torch.Tensor#

None

terminateds: torch.Tensor#

None
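To make the field layout concrete, here is a minimal sketch of constructing a batched return value for a batch of size one. Since `nemo_rl` may not be installed, the `NamedTuple` is mirrored locally, and plain Python lists stand in for the `torch.Tensor` fields `rewards` and `terminateds`; in real usage you would import `EnvironmentReturn` from `nemo_rl.environments.interfaces` and pass tensors.

```python
from typing import Dict, List, NamedTuple, Optional


# Local mirror of EnvironmentReturn for illustration only. The real class
# lives in nemo_rl.environments.interfaces and uses torch.Tensor for the
# rewards and terminateds fields.
class EnvironmentReturn(NamedTuple):
    observations: List[Dict[str, str]]
    metadata: List[Optional[dict]]
    next_stop_strings: List[Optional[List[str]]]
    rewards: list        # torch.Tensor in the real class
    terminateds: list    # torch.Tensor in the real class


# A batch of size one: one observation message, one metadata dict,
# no per-turn stop-string override, one reward, one terminated flag.
ret = EnvironmentReturn(
    observations=[{"role": "environment", "content": "correct"}],
    metadata=[{"solution": "42"}],
    next_stop_strings=[None],
    rewards=[1.0],
    terminateds=[True],
)

# Fields are accessible by name, like any NamedTuple.
print(ret.rewards[0])
```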

class nemo_rl.environments.interfaces.EnvironmentInterface[source]#

Bases: abc.ABC

abstractmethod step(
message_log_batch: List[nemo_rl.data.interfaces.LLMMessageLogType],
metadata: List[Optional[dict]],
*args,
**kwargs,
) → nemo_rl.environments.interfaces.EnvironmentReturn[source]#

Runs a step in the environment. This function is a Ray remote, so it allows for asynchrony with remote servers, but asynchrony is not required.

Parameters:

  • message_log_batch: batch of OpenAI-API-like message logs that represent interactions with the LLM. Each element is a List[Dict[str, Union[str, torch.Tensor]]]. For example, in a math environment the message log would be:

    [
      {"role": "user", "content": "problem"},
      {"role": "assistant", "content": "response"},
    ]

    whereas in a code environment with feedback it would be:

    [
      {"role": "user", "content": "problem"},
      {"role": "assistant", "content": "response"},
      {"role": "user", "content": "code result"},
      {"role": "assistant", "content": "model response"},
    ]

  • metadata: batch of whatever the environment needs to keep track of, e.g. math solutions, code unit tests, or agent states. Can be None if the episode has terminated.

Returns:

  • EnvironmentReturn NamedTuple containing observations, metadata, next_stop_strings, rewards, and terminateds flags.
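The contract above can be sketched with a hypothetical single-turn math environment. Everything here is an assumption for illustration: the `ToyMathEnvironment` class, its grading rule (exact string match against a `"solution"` key in the metadata), and the local stand-in for `EnvironmentReturn` (plain lists instead of `torch.Tensor`, since `nemo_rl` and `torch` may not be installed). A real environment would subclass `EnvironmentInterface` and be deployed as a Ray remote.

```python
from typing import Dict, List, NamedTuple, Optional


class EnvironmentReturn(NamedTuple):
    """Local stand-in for nemo_rl.environments.interfaces.EnvironmentReturn."""
    observations: List[Dict[str, str]]
    metadata: List[Optional[dict]]
    next_stop_strings: List[Optional[List[str]]]
    rewards: list        # torch.Tensor in the real class
    terminateds: list    # torch.Tensor in the real class


class ToyMathEnvironment:
    """Hypothetical environment that grades the assistant's final answer."""

    def step(self, message_log_batch, metadata, *args, **kwargs):
        observations, rewards, terminateds = [], [], []
        for log, meta in zip(message_log_batch, metadata):
            # Grade the assistant's last message against the stored solution.
            answer = log[-1]["content"].strip()
            correct = meta is not None and answer == meta["solution"]
            rewards.append(1.0 if correct else 0.0)
            observations.append(
                {"role": "environment",
                 "content": "correct" if correct else "incorrect"}
            )
            # Single-turn environment: every episode terminates immediately.
            terminateds.append(True)
        return EnvironmentReturn(
            observations=observations,
            metadata=metadata,
            next_stop_strings=[None] * len(message_log_batch),
            rewards=rewards,
            terminateds=terminateds,
        )


env = ToyMathEnvironment()
out = env.step(
    [[{"role": "user", "content": "What is 1 + 1?"},
      {"role": "assistant", "content": "2"}]],
    [{"solution": "2"}],
)
```

Note that the batch dimension is preserved end to end: one message log in, one observation, reward, and terminated flag out.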

abstractmethod global_post_process_and_metrics(
batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict,
) → Tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict, dict][source]#

Post-processing function run after all rollouts for the batch are complete. Returns the processed batch along with a dict of metrics.
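As a sketch of what such a post-processing hook might compute, the hypothetical function below aggregates per-episode rewards into summary metrics. The metric names (`mean_reward`, `fraction_correct`) and the use of a plain dict of lists in place of a real `BatchedDataDict` are assumptions for illustration, not the library's API.

```python
from typing import Dict, List, Tuple


def global_post_process_and_metrics(
    batch: Dict[str, List[float]],
) -> Tuple[Dict[str, List[float]], dict]:
    """Hypothetical post-processing hook: the batch passes through
    unchanged and summary metrics are computed from its rewards."""
    rewards = batch["rewards"]
    metrics = {
        "mean_reward": sum(rewards) / len(rewards),
        "fraction_correct": sum(1 for r in rewards if r > 0) / len(rewards),
    }
    return batch, metrics


# Rewards collected across four finished rollouts.
batch = {"rewards": [1.0, 0.0, 1.0, 1.0]}
processed, metrics = global_post_process_and_metrics(batch)
```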