Metrics

class nemo.collections.common.metrics.Perplexity(*args: Any, **kwargs: Any)

Bases: torchmetrics.Metric

This class computes the mean perplexity of the distributions stored in the last dimension of its inputs. It is a wrapper around the torch.distributions.Categorical.perplexity method. You must provide either probs or logits to the update() method. The class computes the perplexity of each distribution passed to update() in the probs or logits argument and averages these perplexities. Results are reduced across workers with SUM operations. See the PyTorch Lightning Metrics documentation for general metric usage instructions.
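For a categorical distribution, perplexity is exp(entropy), which is how torch.distributions defines perplexity(). The pure-Python sketch below illustrates the quantity being averaged. It is not NeMo's implementation: the helper names categorical_perplexity and mean_perplexity are hypothetical, and each inner list stands in for one distribution along the last dimension of a tensor.

```python
import math
from typing import Sequence


def categorical_perplexity(probs: Sequence[float]) -> float:
    """Perplexity of one categorical distribution, computed as exp(entropy).

    Zero-probability entries contribute nothing (0 * log 0 is taken as 0).
    """
    entropy = -sum(p * math.log(p) for p in probs if p > 0.0)
    return math.exp(entropy)


def mean_perplexity(batch: Sequence[Sequence[float]]) -> float:
    """Average perplexity over a batch of distributions (one per inner list)."""
    perplexities = [categorical_perplexity(row) for row in batch]
    return sum(perplexities) / len(perplexities)


# Two distributions over 4 classes: uniform (perplexity 4) and one-hot (perplexity 1).
batch = [[0.25, 0.25, 0.25, 0.25], [1.0, 0.0, 0.0, 0.0]]
print(mean_perplexity(batch))  # approximately 2.5
```

A uniform distribution over k classes has perplexity k, and a one-hot distribution has perplexity 1, so the metric ranges from 1 up to the vocabulary size.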

Parameters
  • dist_sync_on_step – Synchronize metric state across processes at each forward() before returning the value at the step.

  • process_group – Specify the process group on which synchronization is called. Default: None (which selects the entire world).

  • validate_args – If True, the arguments passed to the update() method are checked: logits must not contain NaNs, and the last dimension of probs must be a valid probability distribution.
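The checks enabled by validate_args can be pictured with the sketch below. It assumes they behave like PyTorch's constraints for Categorical (probs on the probability simplex, logits free of NaNs); the function names and the 1e-6 sum tolerance are illustrative assumptions, and each list represents a single distribution rather than a full tensor.

```python
import math
from typing import Sequence


def check_probs(probs: Sequence[float], tol: float = 1e-6) -> bool:
    """Hypothetical simplex check: all entries non-negative and summing to 1 within tol."""
    return all(p >= 0.0 for p in probs) and abs(sum(probs) - 1.0) < tol


def check_logits(logits: Sequence[float]) -> bool:
    """Hypothetical check that no logit is NaN."""
    return not any(math.isnan(x) for x in logits)


print(check_probs([0.2, 0.3, 0.5]))       # True
print(check_probs([0.6, 0.6, -0.2]))      # False: negative entry
print(check_logits([1.0, float("nan")]))  # False: contains NaN
```

Logits need no normalization check because they are converted to probabilities with a softmax over the last dimension; only NaNs make them invalid.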

compute()

Returns the mean perplexity across all workers and resets perplexities_sum and num_distributions to 0.

full_state_update = True

update(probs=None, logits=None)

Updates perplexities_sum and num_distributions.

Parameters
  • probs – A torch.Tensor whose innermost dimension is a valid probability distribution.

  • logits – A torch.Tensor without NaNs.
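The interplay between update(), compute(), and the SUM reduction can be sketched in plain Python. PerplexityState and compute_across_workers are hypothetical stand-ins, not NeMo or torchmetrics APIs; each PerplexityState plays the role of one worker's perplexities_sum and num_distributions, and the reset inside compute_across_workers follows the compute() description above.

```python
import math
from typing import List, Sequence


def _perplexity(probs: Sequence[float]) -> float:
    # exp(entropy); zero-probability terms contribute nothing
    return math.exp(-sum(p * math.log(p) for p in probs if p > 0.0))


class PerplexityState:
    """Hypothetical stand-in for one worker's metric states."""

    def __init__(self) -> None:
        self.perplexities_sum = 0.0
        self.num_distributions = 0

    def update(self, batch: Sequence[Sequence[float]]) -> None:
        # Accumulate per-distribution perplexities and how many were seen.
        for probs in batch:
            self.perplexities_sum += _perplexity(probs)
            self.num_distributions += 1

    def reset(self) -> None:
        self.perplexities_sum = 0.0
        self.num_distributions = 0


def compute_across_workers(workers: List[PerplexityState]) -> float:
    # SUM-reduce both states, divide, then reset every worker.
    total = sum(w.perplexities_sum for w in workers)
    count = sum(w.num_distributions for w in workers)
    for w in workers:
        w.reset()
    return total / count


worker_a, worker_b = PerplexityState(), PerplexityState()
worker_a.update([[0.25, 0.25, 0.25, 0.25]])  # perplexity 4
worker_b.update([[1.0, 0.0], [0.5, 0.5]])    # perplexities 1 and 2
result = compute_across_workers([worker_a, worker_b])
print(result)  # approximately 2.3333 (7 / 3)
```

Summing both the perplexities and the counts before dividing yields the mean over all distributions. Averaging each worker's own mean instead would give (4 + 1.5) / 2 = 2.75 here, which overweights the worker that saw fewer distributions.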

© Copyright 2023-2024, NVIDIA. Last updated on Apr 12, 2024.