nemo_rl.environments.metrics

Module Contents

Functions

calculate_pass_rate_per_prompt

Compute the fraction of prompts that have at least one correct answer (reward > 0).

API

nemo_rl.environments.metrics.calculate_pass_rate_per_prompt(prompts, is_correct)

Compute the fraction of prompts that have at least one correct answer (reward > 0).
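Read literally, the description above suggests the following definition in terms of the `prompts` and `is_correct` arguments. This rests on an assumption the source does not state: rows of `prompts` that are identical token-for-token count as one prompt, so several sampled answers can share a prompt. Here $b$ is the batch size and $\mathcal{P}$ is the set of distinct prompt rows. It is a sketch of the intended semantics, not a statement of how the library computes it.

```latex
\mathcal{P} = \{\, \mathrm{prompts}_i : i = 1,\dots,b \,\}, \qquad
\text{pass rate} = \frac{\left|\{\, p \in \mathcal{P} \;:\; \exists\, i \in \{1,\dots,b\},\ \mathrm{prompts}_i = p \ \wedge\ \mathrm{is\_correct}_i \,\}\right|}{|\mathcal{P}|}
```

For example, with three distinct prompts where two of them have at least one correct sample, the pass rate is $2/3$.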

Parameters:

prompts: tensor of shape (b, s). The prompts the model used, one row per batch element. May be on any device.
is_correct: tensor of shape (b,). Bool-valued correctness label for each batch element. May be on any device.

Returns: the pass rate, as a float.
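The following is a minimal pure-Python sketch of the behavior described above. It is not the library's implementation. It assumes, as the source does not state explicitly, that identical prompt rows denote the same prompt. The function name `pass_rate_per_prompt_sketch` is hypothetical, and it takes plain nested lists instead of tensors so that it runs without PyTorch. For tensors on any device, `prompts.tolist()` and `is_correct.tolist()` produce compatible inputs.

```python
from typing import Sequence


def pass_rate_per_prompt_sketch(
    prompts: Sequence[Sequence[int]],
    is_correct: Sequence[bool],
) -> float:
    """Hypothetical sketch: fraction of distinct prompts with >= 1 correct answer.

    Assumes identical token rows denote the same prompt. Not the library code.
    """
    if len(prompts) != len(is_correct):
        raise ValueError("prompts and is_correct must have the same batch size")
    if len(prompts) == 0:
        raise ValueError("empty batch: pass rate is undefined")

    # Map each distinct prompt (as a hashable tuple of token ids) to whether
    # any of its sampled answers was correct. `> 0` treats True as correct
    # and also accepts numeric rewards, matching "reward > 0".
    solved: dict[tuple[int, ...], bool] = {}
    for prompt, correct in zip(prompts, is_correct):
        key = tuple(prompt)
        solved[key] = solved.get(key, False) or correct > 0

    # Count prompts with at least one correct answer over all distinct prompts.
    return sum(solved.values()) / len(solved)


# Three distinct prompts, two sampled answers each.
prompts = [[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6], [7, 8, 9], [7, 8, 9]]
is_correct = [False, True, False, False, True, True]
print(pass_rate_per_prompt_sketch(prompts, is_correct))  # 0.6666666666666666
```

Grouping through a dictionary keyed on token tuples keeps the sketch linear in the batch size. A tensor implementation would more likely deduplicate rows on the device (for example with `torch.unique(..., dim=0)`), but that is a design choice this page does not document.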