nemo_rl.environments.math_environment#

Module Contents#

Classes#

Functions#

API#

class nemo_rl.environments.math_environment.MathEnvConfig[source]#

Bases: typing.TypedDict

num_workers: int#

stop_strings: Optional[list[str]]#

verifier_type: Optional[str]#
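To make the config shape concrete, here is an illustrative re-declaration of the documented `TypedDict` together with a sample instance. The field names come from the docs above; the specific values are hypothetical placeholders, not defaults from the library.

```python
from typing import Optional, TypedDict


# Re-declaration of the documented config shape, for illustration only.
class MathEnvConfig(TypedDict):
    num_workers: int
    stop_strings: Optional[list[str]]
    verifier_type: Optional[str]


# Hypothetical values; consult your training recipe for real settings.
cfg: MathEnvConfig = {
    "num_workers": 8,
    "stop_strings": ["</answer>"],
    "verifier_type": None,
}
```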

nemo_rl.environments.math_environment._mute_output()[source]#
class nemo_rl.environments.math_environment.HFVerifyWorker[source]#

Initialization

verify(
pred_responses: list[str],
ground_truths: list[str],
) → list[float][source]#

Verify the correctness of the predicted responses against the ground truth.

Parameters:
  • pred_responses – list[str]. The predicted responses from the LLM.

  • ground_truths – list[str]. The ground truth responses.

Returns:

list[float]. The rewards for each predicted response.
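As a sketch of the verify contract, the toy function below mirrors the documented signature using plain exact-match scoring. The real `HFVerifyWorker` applies math-aware checking; this stand-in only illustrates the input and output shapes.

```python
# Toy exact-match verifier mirroring the documented signature.
# The real worker performs math-aware equivalence checking.
def verify(pred_responses: list[str], ground_truths: list[str]) -> list[float]:
    return [
        1.0 if pred.strip() == truth.strip() else 0.0
        for pred, truth in zip(pred_responses, ground_truths)
    ]


rewards = verify(["42", "3.14"], ["42", "2.71"])
# rewards == [1.0, 0.0]
```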

class nemo_rl.environments.math_environment.MultilingualMultichoiceVerifyWorker[source]#
verify(
pred_responses: list[str],
ground_truths: list[str],
) → list[float][source]#

Verify the correctness of the predicted responses against the ground truth.

Parameters:
  • pred_responses – list[str]. The predicted responses from the LLM.

  • ground_truths – list[str]. The ground truth responses.

Returns:

list[float]. The rewards for each predicted response.

class nemo_rl.environments.math_environment.EnglishMultichoiceVerifyWorker[source]#
verify(
pred_responses: list[str],
ground_truths: list[str],
) → list[float][source]#

Verify the correctness of the predicted responses against the ground truth.

Parameters:
  • pred_responses – list[str]. The predicted responses from the LLM.

  • ground_truths – list[str]. The ground truth responses.

Returns:

list[float]. The rewards for each predicted response.

class nemo_rl.environments.math_environment.MathEnvironmentMetadata[source]#

Bases: typing.TypedDict

ground_truth: str#

class nemo_rl.environments.math_environment.MathEnvironment(
cfg: nemo_rl.environments.math_environment.MathEnvConfig,
)[source]#

Bases: nemo_rl.environments.interfaces.EnvironmentInterface[nemo_rl.environments.math_environment.MathEnvironmentMetadata]

shutdown() → None[source]#
step(
message_log_batch: list[nemo_rl.data.interfaces.LLMMessageLogType],
metadata: list[nemo_rl.environments.math_environment.MathEnvironmentMetadata],
) → nemo_rl.environments.interfaces.EnvironmentReturn[nemo_rl.environments.math_environment.MathEnvironmentMetadata][source]#

Runs a step in the math environment.

Parameters:
  • message_log_batch – list[list[dict[str, str]]]. A batch of OpenAI-API-like message logs that represent interactions with the LLM.

  • metadata – list[MathEnvironmentMetadata]. The grader will use the 'ground_truth' key to evaluate correctness.

Returns:

A tuple containing:
  • list[dict[str, str]]: Observations/responses batch

  • list[dict]: Updated metadata

  • list[str]: Stop strings for the next turn

  • Tensor: Rewards tensor

  • Tensor: Done flags tensor

Return type:

EnvironmentReturn
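To illustrate the input shapes `step` expects, the sketch below builds a one-sample batch by hand. The structures follow the signature and parameter descriptions above; the actual call requires a constructed `MathEnvironment`, which is not shown here.

```python
# Hypothetical inputs for MathEnvironment.step, matching the documented shapes.
# One message log per sample; each log is an OpenAI-API-like list of messages.
message_log_batch = [
    [
        {"role": "user", "content": "What is 2 + 3?"},
        {"role": "assistant", "content": "The answer is 5."},
    ]
]

# Each metadata entry carries the ground truth the grader compares against.
metadata = [{"ground_truth": "5"}]

# With an environment instance `env` (construction not shown), the call
# would be: result = env.step(message_log_batch, metadata)
```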

global_post_process_and_metrics(
batch: nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any],
) → tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any], dict[str, float | int]][source]#

Computes metrics for this environment given a global rollout batch.

Every rank runs this function, so you are free to use distributed calculations for heavy metrics if you prefer.
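As a sketch of the kind of aggregation this hook performs, here is a hypothetical post-processing function over a plain dict standing in for the `BatchedDataDict`. The metric names and the reward-keyed layout are illustrative assumptions, not the library's actual schema.

```python
from typing import Any


def post_process_and_metrics(
    batch: dict[str, Any],
) -> tuple[dict[str, Any], dict[str, float]]:
    # Aggregate per-sample rewards into batch-level metrics.
    rewards = batch["rewards"]
    metrics = {
        "accuracy": sum(1.0 for r in rewards if r > 0) / len(rewards),
        "mean_reward": sum(rewards) / len(rewards),
    }
    return batch, metrics


_, metrics = post_process_and_metrics({"rewards": [1.0, 0.0, 1.0, 1.0]})
# metrics["accuracy"] == 0.75
```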