Important
You are viewing the NeMo 2.0 documentation. This release introduces significant changes to the API and a new library, NeMo Run. We are currently porting all features from NeMo 1.0 to 2.0. For documentation on previous versions or features not yet available in 2.0, please refer to the NeMo 24.07 documentation.
Metrics
- class nemo.collections.common.metrics.Perplexity(*args: Any, **kwargs: Any)
Bases: Metric
This class computes the mean perplexity of the distributions in the last dimension of its inputs. It is a wrapper around the torch.distributions.Categorical.perplexity method. You must provide either probs or logits to the update() method. The class computes the perplexities of the distributions passed to update() in the probs or logits argument and averages them. Results are reduced across all workers with SUM operations. See PyTorch Lightning Metrics for metric usage instructions.
- Parameters:
  - dist_sync_on_step – Synchronize metric state across processes at each forward() before returning the value at the step.
  - process_group – The process group on which synchronization is called. Default: None (which selects the entire world).
  - validate_args – If True, the values of the update() method parameters are checked: logits must not contain NaNs, and the last dimension of probs must be a valid probability distribution.
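The class summary and the validate_args parameter can be illustrated with torch.distributions.Categorical alone. This is a minimal sketch that assumes only PyTorch is installed. It does not import NeMo, the tensor shapes are illustrative, and the helper is_rejected is hypothetical. It also assumes that the metric forwards validate_args to Categorical. The wrapper description suggests this, but the text above does not state it.

The first part of the sketch shows that probs and logits describing the same distributions yield the same perplexities, each equal to exp(entropy), and then averages them. The second part shows inputs that validation rejects.

```python
import math

import torch
from torch.distributions import Categorical

# Illustrative batch: 2 sequences x 3 positions over a 4-token vocabulary.
torch.manual_seed(0)
logits = torch.randn(2, 3, 4)
probs = torch.softmax(logits, dim=-1)

# perplexity() returns exp(entropy) for every distribution in the last dim,
# so both results have shape (2, 3).
ppl_from_logits = Categorical(logits=logits).perplexity()
ppl_from_probs = Categorical(probs=probs).perplexity()
assert torch.allclose(ppl_from_logits, ppl_from_probs, atol=1e-5)

# Mean over all 6 distributions: the quantity the metric reports.
mean_ppl = ppl_from_logits.sum() / ppl_from_logits.numel()
print(f"mean perplexity: {mean_ppl.item():.4f}")


def is_rejected(**params) -> bool:
    """Hypothetical helper: True if Categorical raises ValueError with validation on."""
    try:
        Categorical(validate_args=True, **params)
    except ValueError:
        return True
    return False


# Inputs of the kinds validate_args is documented to catch.
print(is_rejected(probs=torch.tensor([0.5, -0.1, 0.6])))       # True: negative probability
print(is_rejected(logits=torch.tensor([0.0, math.nan, 1.0])))  # True: NaN logit
print(is_rejected(logits=torch.tensor([0.0, 0.5, 1.0])))       # False: valid logits
```

A distribution over 4 categories always has a perplexity between 1 (one-hot) and 4 (uniform), so the printed mean falls in that range.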
- compute()
Returns the perplexity across all workers and resets perplexities_sum and num_distributions to 0.
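The SUM reduction behind compute() works because both states are additive. The following pure-Python sketch uses made-up per-worker values, which are hypothetical and not produced by NeMo. It shows that summing perplexities_sum and num_distributions across workers and then dividing gives the mean over every distribution seen. Averaging the per-worker means would instead give a worker with few distributions the same weight as one with many.

```python
# Hypothetical per-worker states: (perplexities_sum, num_distributions).
worker_states = [
    (12.0, 4),  # worker 0: 4 distributions, local mean 3.0
    (30.0, 6),  # worker 1: 6 distributions, local mean 5.0
]

# SUM reduction: each state is summed across workers independently.
total_sum = sum(s for s, _ in worker_states)
total_count = sum(n for _, n in worker_states)

# Global mean perplexity over all 10 distributions.
global_ppl = total_sum / total_count
print(global_ppl)  # 4.2

# Mean of per-worker means ignores how many distributions each worker saw.
mean_of_means = sum(s / n for s, n in worker_states) / len(worker_states)
print(mean_of_means)  # 4.0
```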
- full_state_update = True
- update(probs=None, logits=None)
Updates perplexities_sum and num_distributions.
  - Parameters:
    - probs – A torch.Tensor whose innermost dimension is a valid probability distribution.
    - logits – A torch.Tensor without NaNs.
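The update() and compute() bookkeeping can be sketched without torchmetrics. PerplexitySketch is a hypothetical single-process stand-in written for illustration and is not the NeMo implementation. It keeps the two states named above, and each update() call adds one perplexity per distribution in the innermost dimension. Its compute() resets both states, as the compute() description states. Synchronization across workers is omitted.

```python
import torch
from torch.distributions import Categorical


class PerplexitySketch:
    """Hypothetical stand-in for the metric's state handling (no torchmetrics, no sync)."""

    def __init__(self, validate_args: bool = True):
        self.validate_args = validate_args
        self.perplexities_sum = torch.tensor(0.0, dtype=torch.float64)
        self.num_distributions = torch.tensor(0, dtype=torch.int64)

    def update(self, probs=None, logits=None):
        # Categorical requires exactly one of probs / logits.
        dist = Categorical(probs=probs, logits=logits, validate_args=self.validate_args)
        ppl = dist.perplexity()  # one value per distribution in the last dim
        self.perplexities_sum += ppl.sum().double()
        self.num_distributions += ppl.numel()

    def compute(self):
        # Mean perplexity; the division creates a new tensor, so resetting is safe.
        result = self.perplexities_sum / self.num_distributions
        self.perplexities_sum.zero_()
        self.num_distributions.zero_()
        return result


metric = PerplexitySketch()
metric.update(logits=torch.zeros(2, 5, 8))  # 10 uniform distributions, perplexity 8 each
metric.update(probs=torch.eye(4))           # 4 one-hot distributions, perplexity 1 each
# (10 * 8 + 4 * 1) / 14 = 6, up to float rounding
print(metric.compute().item())
```

A logits tensor of shape (2, 5, 8) holds 2 x 5 = 10 distributions, so a single update() call can add many entries to num_distributions. The state counts distributions, not calls.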