nemo_rl.environments.math_environment

Module Contents

API
- class nemo_rl.environments.math_environment.MathEnvConfig[source]

  Bases: typing.TypedDict

  - num_workers: int
  - stop_strings: Optional[list[str]]
  - verifier_type: Optional[str]
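Because MathEnvConfig is a TypedDict, it is constructed as a plain dictionary. The sketch below mirrors the three documented fields locally so it runs standalone; in practice you would import MathEnvConfig from nemo_rl.environments.math_environment, and the "math" verifier value shown is an assumption rather than a documented option.

```python
from typing import Optional, TypedDict


# Local mirror of the documented fields, for illustration only.
# In practice: from nemo_rl.environments.math_environment import MathEnvConfig
class MathEnvConfig(TypedDict):
    num_workers: int
    stop_strings: Optional[list[str]]
    verifier_type: Optional[str]


config: MathEnvConfig = {
    "num_workers": 8,                 # number of verifier workers to spawn
    "stop_strings": ["</answer>"],    # optional generation stop strings
    "verifier_type": "math",          # hypothetical value; check the registry
}
```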
- class nemo_rl.environments.math_environment.HFVerifyWorker[source]

  - verify(pred_responses: list[str], ground_truths: list[str]) -> list[float]

    Verify the correctness of the predicted responses against the ground truth.

    Parameters:
    - pred_responses (list[str]): The predicted responses from the LLM.
    - ground_truths (list[str]): The ground truth responses.

    Returns:
    - list[float]: The rewards for each predicted response.
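The verify contract maps a batch of predictions and an equally sized batch of ground truths to one reward per element. The toy stand-in below uses exact string match after whitespace stripping; the real HFVerifyWorker uses a math verifier that also accepts mathematically equivalent but differently written answers.

```python
def verify(pred_responses: list[str], ground_truths: list[str]) -> list[float]:
    """Toy sketch of the verify contract: 1.0 for a correct answer, else 0.0.

    Exact match after whitespace normalization stands in for the real
    symbolic math comparison.
    """
    rewards = []
    for pred, truth in zip(pred_responses, ground_truths):
        rewards.append(1.0 if pred.strip() == truth.strip() else 0.0)
    return rewards


print(verify(["42", " 7 "], ["42", "8"]))  # -> [1.0, 0.0]
```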
- class nemo_rl.environments.math_environment.MultilingualMultichoiceVerifyWorker[source]

  - verify(pred_responses: list[str], ground_truths: list[str]) -> list[float]

    Verify the correctness of the predicted responses against the ground truth.

    Parameters:
    - pred_responses (list[str]): The predicted responses from the LLM.
    - ground_truths (list[str]): The ground truth responses.

    Returns:
    - list[float]: The rewards for each predicted response.
- class nemo_rl.environments.math_environment.EnglishMultichoiceVerifyWorker[source]

  - verify(pred_responses: list[str], ground_truths: list[str]) -> list[float]

    Verify the correctness of the predicted responses against the ground truth.

    Parameters:
    - pred_responses (list[str]): The predicted responses from the LLM.
    - ground_truths (list[str]): The ground truth responses.

    Returns:
    - list[float]: The rewards for each predicted response.
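For multiple-choice grading, verification typically reduces to extracting a choice letter from free-form model output and comparing it to the ground-truth letter. The sketch below is an assumed, simplified extraction rule (first standalone A-D letter, case-insensitive); the real workers' extraction logic may differ.

```python
import re


def verify_multichoice(
    pred_responses: list[str], ground_truths: list[str]
) -> list[float]:
    """Toy multichoice grader: match the first standalone choice letter."""

    def extract_choice(text: str) -> str:
        # Find the first single letter A-D that stands alone as a word.
        match = re.search(r"\b([A-Da-d])\b", text)
        return match.group(1).upper() if match else ""

    return [
        1.0 if extract_choice(pred) == truth.strip().upper() else 0.0
        for pred, truth in zip(pred_responses, ground_truths)
    ]


print(verify_multichoice(["The answer is B.", "c"], ["B", "D"]))  # -> [1.0, 0.0]
```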
- class nemo_rl.environments.math_environment.MathEnvironmentMetadata[source]

  Bases: typing.TypedDict

  - ground_truth: str
- class nemo_rl.environments.math_environment.MathEnvironment( )[source]

  Bases: nemo_rl.environments.interfaces.EnvironmentInterface[nemo_rl.environments.math_environment.MathEnvironmentMetadata]

  - step(message_log_batch: list[nemo_rl.data.interfaces.LLMMessageLogType], metadata: list[nemo_rl.environments.math_environment.MathEnvironmentMetadata])

    Runs a step in the math environment.

    Parameters:
    - message_log_batch (list[list[dict[str, str]]]): A batch of OpenAI-API-like message logs that represent interactions with the LLM.
    - metadata (list[MathEnvironmentMetadata]): The grader uses the "ground_truth" key to evaluate correctness.

    Returns:
    A tuple containing:
    - list[dict[str, str]]: Observations/responses batch
    - list[dict]: Updated metadata
    - list[str]: Stop strings for the next turn
    - Tensor: Rewards tensor
    - Tensor: Done flags tensor
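The inputs to step can be illustrated with plain Python structures: one OpenAI-style message log and one metadata dict per rollout in the batch, with the two lists aligned by index. The call itself is shown commented out because constructing a MathEnvironment requires the surrounding nemo_rl runtime.

```python
# One message log per rollout; each log is a list of role/content dicts.
message_log_batch = [
    [
        {"role": "user", "content": "What is 2 + 2? Answer with a number."},
        {"role": "assistant", "content": "4"},
    ],
]

# One MathEnvironmentMetadata-shaped dict per rollout, aligned by index;
# the grader reads the "ground_truth" key.
metadata = [
    {"ground_truth": "4"},
]

# env = MathEnvironment(...)  # construction arguments omitted in this doc
# observations, metadata, stop_strings, rewards, dones = env.step(
#     message_log_batch, metadata
# )
```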
- global_post_process_and_metrics( ) -> tuple[nemo_rl.distributed.batched_data_dict.BatchedDataDict[Any], dict[str, float | int]][source]

  Computes metrics for this environment given a global rollout batch.

  Every rank runs this function, so you are free to use distributed calculations for heavy metrics if you prefer.