nemo_rl.environments.metrics
Module Contents
Functions
calculate_pass_rate_per_prompt | Compute the fraction of prompts that have at least one correct answer (reward > 0).
API
nemo_rl.environments.metrics.calculate_pass_rate_per_prompt(
    prompts: torch.Tensor,
    is_correct: torch.Tensor,
) → float
Compute the fraction of prompts that have at least one correct answer (reward > 0).
prompts: tensor of shape (b, s) holding the prompts the model used. May be on any device.
is_correct: tensor of shape (b,) holding bool-valued correctness labels. May be on any device.
Returns: pass rate (float), the fraction of prompts with at least one correct answer.
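To make the metric concrete, here is a minimal pure-Python sketch of the computation described above. It is not the library's implementation: the helper name `pass_rate_per_prompt_sketch` is hypothetical, and it assumes that rows of `prompts` with identical token ids count as the same prompt and that an empty batch yields 0.0. It accepts either nested lists or objects with a `tolist()` method, such as `torch.Tensor`.

```python
from collections import defaultdict


def pass_rate_per_prompt_sketch(prompts, is_correct):
    """Fraction of unique prompts with at least one correct sample.

    Hypothetical reference sketch, not the nemo_rl implementation.
    Assumption: rows with identical token ids are the same prompt.
    """
    # Accept tensors (anything exposing .tolist()) or plain Python lists.
    rows = prompts.tolist() if hasattr(prompts, "tolist") else prompts
    flags = is_correct.tolist() if hasattr(is_correct, "tolist") else is_correct

    # For each unique prompt, record whether any of its samples was correct.
    solved = defaultdict(bool)
    for row, ok in zip(rows, flags):
        solved[tuple(row)] |= bool(ok)

    # Assumption: an empty batch has a pass rate of 0.0.
    if not solved:
        return 0.0
    return sum(solved.values()) / len(solved)


# Two unique prompts, each sampled twice; only the first has a correct sample.
example_prompts = [[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]
example_is_correct = [False, True, False, False]
example_rate = pass_rate_per_prompt_sketch(example_prompts, example_is_correct)
print(example_rate)  # 0.5
```

In this example the batch holds four samples but only two distinct prompts, so the result is 1/2 rather than the per-sample accuracy of 1/4.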