Metrics
- class nemo.collections.common.metrics.Perplexity(*args: Any, **kwargs: Any)
Bases: Metric
This class computes the mean perplexity of distributions in the last dimension of inputs. It is a wrapper around the torch.distributions.Categorical.perplexity method. You have to provide either probs or logits to the update() method. The class computes perplexities for the distributions passed to update() in the probs or logits arguments and averages the perplexities. Results are reduced across all workers via SUM operations. See the TorchMetrics in PyTorch Lightning guide for the metric usage instructions.
- Parameters:
dist_sync_on_step – Synchronize metric state across processes at each forward() before returning the value at the step.
process_group – Specify the process group on which synchronization is called. Default: None (which selects the entire world).
validate_args – If True, the values of the update() method parameters are checked: logits must not contain NaNs, and the last dimension of probs must be a valid probability distribution.
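The averaging described above can be illustrated with plain Python. This is a minimal sketch of the arithmetic only, assuming perplexity is the exponential of the entropy in nats, as in torch.distributions; it is not NeMo's implementation. The helper names softmax, perplexity, and mean_perplexity are hypothetical and do not exist in NeMo.

```python
import math


def softmax(logits):
    """Turn one row of logits into probabilities (numerically stable)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def perplexity(probs):
    """Perplexity of one categorical distribution: exp(entropy in nats)."""
    # Zero-probability entries contribute nothing to the entropy.
    entropy = -sum(p * math.log(p) for p in probs if p > 0.0)
    return math.exp(entropy)


def mean_perplexity(rows, from_logits=False):
    """Average the per-distribution perplexities over all rows."""
    if from_logits:
        rows = [softmax(r) for r in rows]
    values = [perplexity(r) for r in rows]
    return sum(values) / len(values)


uniform = [0.25, 0.25, 0.25, 0.25]  # perplexity close to 4
one_hot = [1.0, 0.0, 0.0, 0.0]      # perplexity 1
result = mean_perplexity([uniform, one_hot])
print(result)  # close to 2.5, up to floating-point rounding
```

A uniform distribution over four outcomes has perplexity 4 and a one-hot distribution has perplexity 1, so the mean over the two rows is about 2.5. Passing rows of equal logits with from_logits=True gives the same value as passing uniform probabilities.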
- compute()
Returns perplexity across all workers and resets perplexities_sum and num_distributions to 0.
- full_state_update = True
- update(probs=None, logits=None)
Updates perplexities_sum and num_distributions.
- Parameters:
probs – A torch.Tensor whose innermost dimension is a valid probability distribution.
logits – A torch.Tensor without NaNs.
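The interplay between update(), compute(), and the SUM reduction can be sketched without any framework. The example below is an illustration under stated assumptions, not the NeMo class: PerplexityState and reduce_and_compute are hypothetical names, each PerplexityState object stands in for one worker, and the cross-worker reduction is simulated by summing the two state values in ordinary Python.

```python
import math


class PerplexityState:
    """Hypothetical stand-in for the per-worker state of the metric."""

    def __init__(self):
        self.perplexities_sum = 0.0
        self.num_distributions = 0

    def update(self, probs):
        """Accumulate the perplexity of each distribution (row) in probs."""
        for row in probs:
            entropy = -sum(p * math.log(p) for p in row if p > 0.0)
            self.perplexities_sum += math.exp(entropy)
            self.num_distributions += 1

    def reset(self):
        """Set both state values back to 0."""
        self.perplexities_sum = 0.0
        self.num_distributions = 0


def reduce_and_compute(workers):
    """Simulate a SUM reduction over workers, then average and reset."""
    total = sum(w.perplexities_sum for w in workers)
    count = sum(w.num_distributions for w in workers)
    for w in workers:
        w.reset()
    return total / count


worker_a = PerplexityState()
worker_b = PerplexityState()
worker_a.update([[0.5, 0.5], [0.5, 0.5]])  # two distributions, perplexity ~2 each
worker_b.update([[1.0, 0.0]])              # one distribution, perplexity 1
mean = reduce_and_compute([worker_a, worker_b])
print(mean)  # close to 5/3, up to floating-point rounding
```

Summing both the running perplexity total and the distribution count before dividing gives the mean over all distributions seen by all workers, which differs from averaging per-worker means when workers see different numbers of distributions.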