Modulus Metrics

Modulus provides several general and domain-specific metric calculations you can leverage in your custom training and inference workflows. These metrics are optimized to operate on PyTorch tensors.

General Metrics and Statistical Methods

The following general-purpose statistical methods and metrics are available:

Metric – Description

modulus.metrics.general.mse.mse – Mean squared error between two tensors
modulus.metrics.general.mse.rmse – Root mean squared error between two tensors
modulus.metrics.general.histogram.histogram – Histogram of a set of tensors over the leading dimension
modulus.metrics.general.histogram.cdf – Cumulative distribution function of a set of tensors over the leading dimension
modulus.metrics.general.histogram.normal_cdf – Cumulative distribution function of a normal variable with given mean and standard deviation
modulus.metrics.general.histogram.normal_pdf – Probability density function of a normal variable with given mean and standard deviation
modulus.metrics.general.calibration.find_rank – Rank of an observation with respect to the given counts and bins
modulus.metrics.general.calibration.rank_probability_score – Rank Probability Score for the passed ranks
modulus.metrics.general.entropy.entropy_from_counts – Statistical entropy of a random variable, computed from a histogram
modulus.metrics.general.entropy.relative_entropy_from_counts – Relative statistical entropy (KL divergence) of two random variables, computed from their histograms
modulus.metrics.general.crps.crps – Local Continuous Ranked Probability Score (CRPS) of ensemble predictions, computed with the kernel definition or from a histogram and CDF of the predictions
modulus.metrics.general.wasserstein.wasserstein_from_cdf – 1-Wasserstein distance between two discrete CDFs
modulus.metrics.general.reduction.WeightedMean – Weighted mean
modulus.metrics.general.reduction.WeightedStatistic – Base class for weighted statistics
modulus.metrics.general.reduction.WeightedVariance – Weighted variance

The following examples show how to use these metrics in your own workflows.

To compute the RMSE metric:

>>> import torch
>>> from modulus.metrics.general.mse import rmse
>>> pred_tensor = torch.randn(16, 32)
>>> targ_tensor = torch.randn(16, 32)
>>> rmse(pred_tensor, targ_tensor)
tensor(1.4781)

To compute the histogram of samples:

>>> import torch
>>> from modulus.metrics.general import histogram
>>> x = torch.randn(1_000)
>>> bins, counts = histogram.histogram(x, bins=10)
>>> bins
tensor([-3.7709, -3.0633, -2.3556, -1.6479, -0.9403, -0.2326, 0.4751, 1.1827, 1.8904, 2.5980, 3.3057])
>>> counts
tensor([ 3, 9, 43, 150, 227, 254, 206, 81, 24, 3])

To compute the cumulative distribution function (CDF):

>>> bins, cdf = histogram.cdf(x, bins=10)
>>> bins
tensor([-3.7709, -3.0633, -2.3556, -1.6479, -0.9403, -0.2326, 0.4751, 1.1827, 1.8904, 2.5980, 3.3057])
>>> cdf
tensor([0.0030, 0.0120, 0.0550, 0.2050, 0.4320, 0.6860, 0.8920, 0.9730, 0.9970, 1.0000])
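The CDF above is the running sum of the histogram counts divided by the total number of samples, evaluated at each bin's upper edge. As a sanity check, this plain-PyTorch sketch recovers those values from the counts printed in the histogram example. It illustrates the relationship only and is not the Modulus implementation:

```python
import torch

# Bin counts copied from the histogram example above (1_000 samples, 10 bins).
counts = torch.tensor([3, 9, 43, 150, 227, 254, 206, 81, 24, 3])

# Empirical CDF at each bin's upper edge: cumulative count / total count.
cdf = torch.cumsum(counts, dim=0).float() / counts.sum()
print(cdf)  # same values as the cdf tensor printed by histogram.cdf above
```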

To use the histogram for statistical entropy calculations:

>>> from modulus.metrics.general import entropy
>>> entropy.entropy_from_counts(counts, bins)
tensor(0.4146)

Many of the functions operate over batches. For example, if you have a collection of two-dimensional data, you can compute the histogram over the whole collection, which yields one histogram per position:

>>> import torch
>>> from modulus.metrics.general import histogram, entropy
>>> x = torch.randn((1_000, 3, 3))
>>> bins, counts = histogram.histogram(x, bins=10)
>>> bins.shape, counts.shape
(torch.Size([11, 3, 3]), torch.Size([10, 3, 3]))
>>> entropy.entropy_from_counts(counts, bins)
tensor([[0.5162, 0.4821, 0.3976],
        [0.5099, 0.5309, 0.4519],
        [0.4580, 0.4290, 0.5121]])

Additional metrics compare distributions: the Continuous Ranked Probability Score (CRPS), ranks, and the Wasserstein metric.

CRPS:

>>> from modulus.metrics.general import crps
>>> x = torch.randn((1_000, 1))
>>> y = torch.randn((1,))
>>> crps.crps(x, y)
tensor([0.8023])

Ranks:

>>> from modulus.metrics.general import histogram, calibration
>>> x = torch.randn((1_000, 1))
>>> y = torch.randn((1,))
>>> bins, counts = histogram.histogram(x, bins=10)
>>> ranks = calibration.find_rank(bins, counts, y)
>>> ranks
tensor([0.1920])

Wasserstein Metric:

>>> from modulus.metrics.general import wasserstein, histogram
>>> x = torch.randn((1_000, 1))
>>> y = torch.randn((1_000, 1))
>>> bins, cdf_x = histogram.cdf(x)
>>> bins, cdf_y = histogram.cdf(y, bins=bins)
>>> wasserstein.wasserstein_from_cdf(bins, cdf_x, cdf_y)
tensor([0.0459])

Weighted Reductions

Modulus currently offers classes for weighted mean and variance reductions.

>>> from modulus.metrics.general import reduction
>>> x = torch.randn((1_000,))
>>> weights = torch.cos(torch.linspace(-torch.pi/4, torch.pi/4, 1_000))
>>> wm = reduction.WeightedMean(weights)
>>> wm(x, dim=0)
tensor(0.0365)
>>> wv = reduction.WeightedVariance(weights)
>>> wv(x, dim=0)
tensor(1.0148)

Online Statistical Methods

Modulus currently offers routines for computing online (streaming, or out-of-core) means, variances, and histograms, so that statistics can be accumulated over data that does not fit in memory at once.

>>> import torch
>>> from modulus.metrics.general import ensemble_metrics as em
>>> x = torch.randn((1_000, 2))  # Interpret as 1_000 members of size (2,).
>>> torch.mean(x, dim=0)  # Compute mean of entire data.
tensor([-0.0545, 0.0267])
>>> x0, x1 = x[:500], x[500:]  # Split data into two.
>>> M = em.Mean(input_shape=(2,))  # Must pass shape of data.
>>> M(x0)  # Compute mean of initial batch.
tensor([-0.0722, 0.0414])
>>> M.update(x1)  # Update with second batch.
tensor([-0.0545, 0.0267])

To compute the Anomaly Correlation Coefficient (ACC), a metric widely used in weather and climate science:

>>> import torch
>>> import numpy as np
>>> from modulus.metrics.climate.acc import acc
>>> channels = 1
>>> img_shape = (32, 64)
>>> time_means = np.pi / 2 * np.ones((channels, img_shape[0], img_shape[1]), dtype=np.float32)
>>> x = np.linspace(-180, 180, img_shape[1], dtype=np.float32)
>>> y = np.linspace(-90, 90, img_shape[0], dtype=np.float32)
>>> xv, yv = np.meshgrid(x, y)
>>> pred_tensor_np = np.cos(2 * np.pi * yv / (180))
>>> targ_tensor_np = np.cos(np.pi * yv / (180))
>>> pred_tensor = torch.from_numpy(pred_tensor_np).expand(channels, -1, -1)
>>> targ_tensor = torch.from_numpy(targ_tensor_np).expand(channels, -1, -1)
>>> means_tensor = torch.from_numpy(time_means)
>>> lat = torch.from_numpy(y)
>>> acc(pred_tensor, targ_tensor, means_tensor, lat)
tensor([0.9841])

modulus.metrics.general.mse.mse(pred: Tensor, target: Tensor, dim: Optional[int] = None) → Union[Tensor, float]

Calculates the mean squared error between two tensors.

Parameters
  • pred (Tensor) – Input prediction tensor

  • target (Tensor) – Target tensor

  • dim (int, optional) – Reduction dimension. When None, the squared errors are averaged over all elements, by default None

Returns

Mean squared error value(s)

Return type

Union[Tensor, float]

modulus.metrics.general.mse.rmse(pred: Tensor, target: Tensor, dim: Optional[int] = None) → Union[Tensor, float]

Calculates the root mean squared error between two tensors.

Parameters
  • pred (Tensor) – Input prediction tensor

  • target (Tensor) – Target tensor

  • dim (int, optional) – Reduction dimension. When None, the squared errors are averaged over all elements before taking the square root, by default None

Returns

Root mean squared error value(s)

Return type

Union[Tensor, float]
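Both functions reduce the element-wise squared error. The following is a minimal plain-PyTorch sketch of that reduction. It assumes the errors are averaged (not summed) when dim is None, and mse_sketch and rmse_sketch are hypothetical names, not Modulus functions:

```python
from typing import Optional

import torch


def mse_sketch(pred: torch.Tensor, target: torch.Tensor, dim: Optional[int] = None) -> torch.Tensor:
    """Mean of the squared error, over all elements (dim=None) or along `dim`."""
    sq_err = (pred - target) ** 2
    return sq_err.mean() if dim is None else sq_err.mean(dim=dim)


def rmse_sketch(pred: torch.Tensor, target: torch.Tensor, dim: Optional[int] = None) -> torch.Tensor:
    """Square root of the mean squared error."""
    return torch.sqrt(mse_sketch(pred, target, dim))


pred = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
targ = torch.tensor([[1.0, 0.0], [3.0, 0.0]])
print(mse_sketch(pred, targ))          # squared errors 0, 4, 0, 16 -> mean 5
print(rmse_sketch(pred, targ, dim=0))  # per column: sqrt(0) and sqrt(10)
```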

class modulus.metrics.general.histogram.Histogram(input_shape: Tuple[int], bins: Union[int, Tensor] = 10, tol: float = 0.01, **kwargs)

Bases: EnsembleMetrics

Convenience class for computing histograms, possibly as part of a distributed or iterative computation

Parameters
  • input_shape (Tuple[int]) – Input data shape

  • bins (Union[int, Tensor], optional) – Initial bin edges or number of bins to use, by default 10

  • tol (float, optional) – Bin edge tolerance, by default 1e-2

finalize(cdf: bool = False) → Tuple[Tensor, Tensor]

Finalize the histogram and compute either its PDF or its CDF

Parameters

cdf (bool, optional) – If True, compute the cumulative distribution function; otherwise compute the probability density function, by default False

Returns

The calculated (bin edges [N+1, …], PDF or CDF [N, …]) tensors

Return type

Tuple[Tensor, Tensor]

update(input: Tensor) → Tuple[Tensor, Tensor]

Update current histogram with new data

Parameters

input (Tensor) – Input data tensor [B, …]

Returns

The calculated (bin edges [N+1, …], counts [N, …]) tensors

Return type

Tuple[Tensor, Tensor]

modulus.metrics.general.histogram.cdf(*inputs: Tensor, bins: Union[int, Tensor] = 10, counts: Union[None, Tensor] = None, verbose: bool = False) → Tuple[Tensor, Tensor]

Computes the cumulative distribution function of a set of tensors over the leading dimension

This function computes the bin edges and CDF of the given input tensors. If existing bins or count tensors are supplied, this function updates these existing statistics with the new data.

Parameters
  • inputs ((Tensor ...)) – Input data tensor(s) [B, …]

  • bins (Union[int, Tensor], optional) – Either the number of bins, or a tensor of bin edges with dimension [N+1, …] where N is the number of bins. If counts is passed, then bins is interpreted to be the bin edges for the counts tensor, by default 10

  • counts (Union[None, Tensor], optional) – Existing count tensor to combine results with. Must have dimensions [N, …] where N is the number of bins. Passing a tensor may also require recomputing the passed bins to make sure inputs and bins are compatible, by default None

  • verbose (bool, optional) – Verbose printing, by default False

Returns

The calculated (bin edges [N+1, …], cdf [N, …]) tensors

Return type

Tuple[Tensor, Tensor]

modulus.metrics.general.histogram.histogram(*inputs: Tensor, bins: Union[int, Tensor] = 10, counts: Union[None, Tensor] = None, verbose: bool = False) → Tuple[Tensor, Tensor]

Computes the histogram of a set of tensors over the leading dimension

This function will compute bin edges and bin counts of given input tensors. If existing bin edges or count tensors are supplied, this function will update these existing statistics with the new data.

Parameters
  • inputs ((Tensor ...)) – Input data tensor(s) [B, …]

  • bins (Union[int, Tensor], optional) – Either the number of bins, or a tensor of bin edges with dimension [N+1, …] where N is the number of bins. If counts is passed, then bins is interpreted to be the bin edges for the counts tensor, by default 10

  • counts (Union[None, Tensor], optional) – Existing count tensor to combine results with. Must have dimensions [N, …] where N is the number of bins. Passing a tensor may also require recomputing the passed bins to make sure inputs and bins are compatible, by default None

  • verbose (bool, optional) – Verbose printing, by default False

Returns

The calculated (bin edges [N+1, …], count [N, …]) tensors

Return type

Tuple[Tensor, Tensor]
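Counts over fixed bin edges are additive, so a histogram can be built batch by batch and summed. This is what makes the iterative use described above possible. The sketch below shows this in plain PyTorch for data of shape [B, 3, 3]. Note that counts_for_edges is a hypothetical helper, and its half-open bin convention is an assumption. Unlike Modulus, which may recompute the bins when new data does not fit them, this sketch never re-grids the edges:

```python
import torch


def counts_for_edges(x: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
    """Count samples of x [B, ...] per bin of edges [N+1, ...]; returns [N, ...].

    Bins are half-open [e_i, e_{i+1}), except the last bin, which also
    includes its right edge so the largest value is counted.
    """
    lo = edges[:-1].unsqueeze(1)  # [N, 1, ...]
    hi = edges[1:].unsqueeze(1)   # [N, 1, ...]
    xb = x.unsqueeze(0)           # [1, B, ...]
    inside = (xb >= lo) & (xb < hi)  # [N, B, ...]
    inside[-1] |= xb[0] == hi[-1]    # close the right-most bin
    return inside.sum(dim=1)


torch.manual_seed(0)
x = torch.randn(1_000, 3, 3)

# One set of 11 edges spanning the data (slightly padded), shared by all 3x3 positions.
edges_1d = torch.linspace(x.min().item() - 1e-3, x.max().item() + 1e-3, 11)
edges = edges_1d.view(-1, 1, 1).expand(-1, 3, 3)

counts_full = counts_for_edges(x, edges)

# Streaming: with fixed edges, per-batch counts simply add up.
counts_stream = counts_for_edges(x[:400], edges) + counts_for_edges(x[400:], edges)
print(torch.equal(counts_full, counts_stream))  # True
```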

modulus.metrics.general.histogram.normal_cdf(mean: Tensor, std: Tensor, bin_edges: Tensor, grid: str = 'midpoint') → Tensor

Computes the cumulative distribution function of a normal variable with given mean and standard deviation. By default (grid='midpoint'), the CDF is evaluated at the midpoints of bin_edges.

This function uses the standard formula:

\[\frac{1}{2}\left(1 + \operatorname{erf}\left(\frac{x - \mathrm{mean}}{\mathrm{std}\sqrt{2}}\right)\right) \tag{1}\]

where erf is the error function.

Parameters
  • mean (Tensor) – mean tensor

  • std (Tensor) – standard deviation tensor

  • bin_edges (Tensor) – The tensor of bin edges with dimension [N+1, …] where N is the number of bins.

  • grid (str) – Where in each bin the CDF is evaluated. Can be one of {“midpoint”, “left”, “right”}, by default “midpoint”.

Returns

The calculated cdf tensor with dimension [N, …]

Return type

Tensor
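On the default midpoint grid, the formula above reduces to a few tensor operations. The following minimal sketch assumes that the midpoint of each bin is the average of its two edges. It does not cover the left/right grid options, and normal_cdf_at_midpoints is a hypothetical name:

```python
import math

import torch


def normal_cdf_at_midpoints(mean: torch.Tensor, std: torch.Tensor, bin_edges: torch.Tensor) -> torch.Tensor:
    """0.5 * (1 + erf((x - mean) / (std * sqrt(2)))) at the bin midpoints x; returns [N, ...]."""
    mids = 0.5 * (bin_edges[1:] + bin_edges[:-1])
    return 0.5 * (1.0 + torch.erf((mids - mean) / (std * math.sqrt(2.0))))


edges = torch.tensor([-1.0, -0.5, 0.5, 1.0])  # midpoints -0.75, 0.0, 0.75
cdf = normal_cdf_at_midpoints(torch.tensor(0.0), torch.tensor(1.0), edges)
print(cdf)  # approximately [0.2266, 0.5000, 0.7734]
```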

modulus.metrics.general.histogram.normal_pdf(mean: Tensor, std: Tensor, bin_edges: Tensor, grid: str = 'midpoint') → Tensor

Computes the probability density function of a normal variable with given mean and standard deviation. By default (grid='midpoint'), the PDF is evaluated at the midpoints of bin_edges.

This function uses the standard formula:

\[\frac{1}{\sqrt{2\pi}\,\mathrm{std}} \exp\left(-\frac{1}{2}\left(\frac{x - \mathrm{mean}}{\mathrm{std}}\right)^2\right) \tag{2}\]

Parameters
  • mean (Tensor) – mean tensor

  • std (Tensor) – standard deviation tensor

  • bin_edges (Tensor) – The tensor of bin edges with dimension [N+1, …] where N is the number of bins.

  • grid (str) – Where in each bin the PDF is evaluated. Can be one of {“midpoint”, “left”, “right”}, by default “midpoint”.

Returns

The calculated pdf tensor with dimension [N, …]

Return type

Tensor
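The PDF formula above can be sketched the same way. This minimal version (hypothetical name normal_pdf_at_midpoints, midpoint grid only) also checks that the binned density integrates to approximately one:

```python
import math

import torch


def normal_pdf_at_midpoints(mean: torch.Tensor, std: torch.Tensor, bin_edges: torch.Tensor) -> torch.Tensor:
    """exp(-0.5 * ((x - mean) / std)**2) / (sqrt(2*pi) * std) at the bin midpoints x."""
    mids = 0.5 * (bin_edges[1:] + bin_edges[:-1])
    z = (mids - mean) / std
    return torch.exp(-0.5 * z**2) / (math.sqrt(2.0 * math.pi) * std)


edges = torch.linspace(-8.0, 8.0, 1601, dtype=torch.float64)  # 1600 bins of width 0.01
zero = torch.tensor(0.0, dtype=torch.float64)
one = torch.tensor(1.0, dtype=torch.float64)
pdf = normal_pdf_at_midpoints(zero, one, edges)
mass = (pdf * (edges[1:] - edges[:-1])).sum()
print(mass)  # close to 1: the density integrates to one
```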

modulus.metrics.general.entropy.entropy_from_counts(p: Tensor, bin_edges: Tensor, normalized=True) → Tensor

Computes the Statistical Entropy of a random variable using a histogram.

Uses the formula:

\[\mathrm{Entropy}(X) = -\int p(x) \log p(x) \, dx \tag{3}\]

Parameters
  • p (Tensor) – Tensor [N, …] containing counts/pdf, defined over bins. The non-zeroth dimensions of bin_edges and p must be compatible.

  • bin_edges (Tensor) – Tensor [N+1, …] containing bin edges. The leading dimension must represent the N+1 bin edges.

  • normalized (bool, optional) – Whether the returned statistical entropy is normalized. For a distribution with compact, bounded support, the entropy lies between that of a pseudo-Dirac distribution, ent_min, and that of a uniform distribution, ent_max. The normalization maps the entropy from [ent_min, ent_max] to [0, 1], by default True

Returns

Tensor containing the Information/Statistical Entropy

Return type

Tensor
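The differential entropy of a histogram follows from the formula above once the counts are turned into a piecewise-constant density. This sketch computes the unnormalized value only and does not reproduce the [ent_min, ent_max] normalization applied when normalized=True. entropy_from_hist is a hypothetical name:

```python
import torch


def entropy_from_hist(counts: torch.Tensor, bin_edges: torch.Tensor) -> torch.Tensor:
    """Unnormalized entropy -∫ p log p dx of the density implied by a histogram.

    counts: [N, ...], bin_edges: [N+1, ...]; reduces over the leading (bin) dimension.
    """
    widths = bin_edges[1:] - bin_edges[:-1]
    counts = counts.to(widths.dtype)
    pdf = counts / (counts.sum(dim=0, keepdim=True) * widths)  # piecewise-constant density
    # xlogy(p, p) = p * log(p), defined as 0 where p == 0 (empty bins).
    return -(torch.xlogy(pdf, pdf) * widths).sum(dim=0)


counts = torch.tensor([5, 5, 5, 5])
print(entropy_from_hist(counts, torch.linspace(0.0, 1.0, 5)))  # uniform on [0, 1]: 0
print(entropy_from_hist(counts, torch.linspace(0.0, 2.0, 5)))  # uniform on [0, 2]: log 2
```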

modulus.metrics.general.entropy.relative_entropy_from_counts(p: Tensor, q: Tensor, bin_edges: Tensor) → Tensor

Computes the Relative Statistical Entropy, or KL Divergence of two random variables using their histograms.

Uses the formula:

\[D_{KL}(p \,\|\, q) = \int p(x) \log\frac{p(x)}{q(x)} \, dx \tag{4}\]

Parameters
  • p (Tensor) – Tensor [N, …] containing counts/pdf, defined over bins. The non-zeroth dimensions of bin_edges and p must be compatible.

  • q (Tensor) – Tensor [N, …] containing counts/pdf, defined over bins. The non-zeroth dimensions of bin_edges and q must be compatible.

  • bin_edges (Tensor) – Tensor [N+1, …] containing bin edges. The leading dimension must represent the N+1 bin edges.

Returns

Map of relative entropy (KL divergence)

Return type

Tensor
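The same density construction gives a sketch of the KL divergence in the formula above. Both histograms must share bin_edges. kl_from_hist is a hypothetical name. The sketch does not special-case bins where q is zero but p is not, so the divergence is then infinite:

```python
import torch


def kl_from_hist(p_counts: torch.Tensor, q_counts: torch.Tensor, bin_edges: torch.Tensor) -> torch.Tensor:
    """∫ p log(p / q) dx for two histograms sharing bin_edges; reduces over the bin dimension."""
    widths = bin_edges[1:] - bin_edges[:-1]
    p = p_counts.to(widths.dtype) / (p_counts.sum(dim=0, keepdim=True) * widths)
    q = q_counts.to(widths.dtype) / (q_counts.sum(dim=0, keepdim=True) * widths)
    # xlogy terms vanish where p == 0, so empty bins of p contribute nothing.
    return ((torch.xlogy(p, p) - torch.xlogy(p, q)) * widths).sum(dim=0)


edges = torch.tensor([0.0, 1.0, 2.0])
print(kl_from_hist(torch.tensor([1, 0]), torch.tensor([1, 1]), edges))  # log 2
```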

modulus.metrics.general.calibration.find_rank(bin_edges: Tensor, counts: Tensor, obs: Union[Tensor, ndarray]) → Tensor

Finds the rank of the observation with respect to the given counts and bins.

Parameters
  • bin_edges (Tensor) – Tensor [N+1, …] containing bin edges. The leading dimension must represent the N+1 bin edges.

  • counts (Tensor) – Tensor [N, …] containing counts, defined over bins. The non-zeroth dimensions of bin_edges and counts must be compatible.

  • obs (Union[Tensor, np.ndarray]) – Tensor or array containing the observation with respect to which the rank is computed.

Returns

Tensor of ranks for each of the batched dimensions […]

Return type

Tensor
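Conceptually, the rank is the fraction of the ensemble that lies below the observation, so it falls in [0, 1]. The following 1-D sketch estimates it from a histogram by interpolating the CDF linearly inside the bin that contains the observation. The interpolation scheme is an assumption and not necessarily what find_rank does. rank_from_hist is a hypothetical name:

```python
import torch


def rank_from_hist(bin_edges: torch.Tensor, counts: torch.Tensor, obs: torch.Tensor) -> torch.Tensor:
    """Fraction of a 1-D histogram's mass below obs, interpolating linearly inside a bin."""
    cdf = torch.cumsum(counts.to(bin_edges.dtype), dim=0) / counts.sum()
    cdf_at_edges = torch.cat([torch.zeros(1, dtype=cdf.dtype), cdf])  # [N+1], starts at 0
    obs = obs.reshape(1).clamp(bin_edges[0].item(), bin_edges[-1].item())
    # Index of the bin containing obs (the last bin also contains its right edge).
    i = torch.searchsorted(bin_edges, obs, right=True).clamp(1, len(bin_edges) - 1) - 1
    frac = (obs - bin_edges[i]) / (bin_edges[i + 1] - bin_edges[i])
    return cdf_at_edges[i] + frac * (cdf_at_edges[i + 1] - cdf_at_edges[i])


edges = torch.tensor([0.0, 1.0, 2.0, 3.0])
counts = torch.tensor([1, 2, 1])
print(rank_from_hist(edges, counts, torch.tensor(1.5)))  # 0.25 + 0.5 * 0.5 = 0.5
```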

modulus.metrics.general.calibration.rank_probability_score(ranks: Tensor) → Tensor

Computes the Rank Probability Score for the passed ranks. Internally, this creates a histogram for the ranks and computes the Rank Probability Score (RPS) using the histogram.

With the histogram the RPS is computed as

\[\int_0^1 \left(F_X(x) - F_U(x)\right)^2 dx \tag{5}\]

where F represents a cumulative distribution function, X represents the rank distribution and U represents a Uniform distribution.

To compute the ranks, use find_rank.

Parameters

ranks (Tensor) – Tensor [B, …] containing ranks, where the leading dimension represents the batch, or ensemble, dimension. The non-zeroth dimensions are batched over.

Returns

Tensor of RPS for each of the batched dimensions […]

Return type

Tensor
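This 1-D sketch of the integral above evaluates the empirical CDF of the ranks on a fine grid, whereas rank_probability_score builds a histogram first. rps_sketch and n_grid are hypothetical, and the grid-based quadrature is an assumption. Well-calibrated ranks, which are spread uniformly over [0, 1], give a score near zero:

```python
import torch


def rps_sketch(ranks: torch.Tensor, n_grid: int = 1_000) -> torch.Tensor:
    """∫_0^1 (F_X(x) - x)^2 dx for 1-D ranks in [0, 1], by the midpoint rule on n_grid points."""
    x = (torch.arange(n_grid, dtype=torch.float64) + 0.5) / n_grid  # grid on (0, 1)
    # Empirical CDF of the ranks at each grid point: shape [n_grid].
    F_X = (ranks.to(torch.float64).unsqueeze(1) <= x).to(torch.float64).mean(dim=0)
    return ((F_X - x) ** 2).mean()  # F_U(x) = x; the mean over a unit interval is the integral


perfect = (torch.arange(1_000, dtype=torch.float64) + 0.5) / 1_000  # uniformly spread ranks
biased = torch.full((1_000,), 0.5)                                  # every observation at the median
print(rps_sketch(perfect), rps_sketch(biased))  # about 0 and about 1/12
```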

modulus.metrics.general.crps.crps(pred: Tensor, obs: Union[Tensor, ndarray], dim: int = 0, method: str = 'kernel') → Tensor

Computes the local Continuous Ranked Probability Score (CRPS) by either computing a histogram and CDF of the predictions, or using the kernel definition.

Creates a map of CRPS and does not accumulate over lat/lon regions.

Computes:

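\[\mathrm{CRPS}(X, y) = \begin{cases} \mathbb{E}\left|X - y\right| - \frac{1}{2}\,\mathbb{E}\left|X - X'\right| & \text{if method = kernel} \\ \int \left(F(x) - \mathbb{1}[x \geq y]\right)^2 dx & \text{if method = histogram} \end{cases} \tag{6}\]

where F is the empirical CDF of X and X' is an independent copy of X.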

Parameters
  • pred (Tensor) – Tensor containing the ensemble predictions.

  • obs (Union[Tensor, np.ndarray]) – Tensor or array containing an observation over which the CRPS is computed with respect to.

  • dim (int, optional) – The ensemble dimension over which to calculate the CRPS, by default 0.

  • method (str, optional) – The method used to calculate the CRPS. Can be either “kernel” or “histogram”, by default “kernel”.

Returns

Map of CRPS

Return type

Tensor
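The kernel definition, E|X - y| - 0.5 E|X - X'|, needs only pairwise absolute differences. The following minimal sketch is for illustration only. crps_kernel is a hypothetical name. Averaging over all B * B member pairs, rather than excluding self-pairs as some "fair" CRPS estimators do, is an assumption and may differ from crps. For a single-member ensemble the result reduces to the absolute error |x - y|:

```python
import torch


def crps_kernel(ens: torch.Tensor, obs: torch.Tensor) -> torch.Tensor:
    """E|X - y| - 0.5 * E|X - X'| for ensemble members along dim 0 of ens [B, ...] and obs [...].

    Averages over all B * B member pairs (including each member with itself).
    Memory grows as B**2, so this is only suitable for moderate ensemble sizes.
    """
    skill = (ens - obs).abs().mean(dim=0)
    spread = (ens.unsqueeze(0) - ens.unsqueeze(1)).abs().mean(dim=(0, 1))
    return skill - 0.5 * spread


ens = torch.tensor([[0.0], [1.0]])  # two members, one grid point
obs = torch.tensor([0.0])
print(crps_kernel(ens, obs))  # 0.5 - 0.5 * 0.5 = 0.25
```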

class modulus.metrics.general.ensemble_metrics.EnsembleMetrics(input_shape: Union[Tuple[int, ...], List[int]], device: Union[str, device] = 'cpu', dtype: dtype = torch.float32)

Bases: ABC

Abstract class for ensemble performance related metrics

Can be helpful for distributed and sequential computations of metrics.

Parameters
  • input_shape (Union[Tuple[int,...], List]) – Shape of input tensors without batched dimension.

  • device (torch.device, optional) – PyTorch device to use, by default ‘cpu’

  • dtype (torch.dtype, optional) – Default dtype used to initialize tensors, by default torch.float32

finalize(*args)

Marks the end of the sequential calculation, used to finalize any computations.

update(*args)

Update initial or stored calculation with additional information.

class modulus.metrics.general.ensemble_metrics.Mean(input_shape: Union[Tuple, List], **kwargs)

Bases: EnsembleMetrics

Utility class that computes the mean over a batched or ensemble dimension

This is particularly useful for distributed environments and sequential computation.

Parameters

input_shape (Union[Tuple, List]) – Shape of broadcasted dimensions

finalize() → Tensor

Compute and store final mean

Returns

Final mean value

Return type

Tensor

update(inputs: Tensor, dim: int = 0) → Tensor

Update current mean and essential statistics with new data

Parameters
  • inputs (Tensor) – Inputs tensor

  • dim (int) – Dimension of batched data

Returns

Current mean value

Return type

Tensor
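For the mean, a running sum and a sample count are enough to merge batches exactly. In a distributed run, the same two quantities could be summed across processes before dividing. A minimal sketch follows. RunningMean is hypothetical, and Mean may track its statistics differently:

```python
import torch


class RunningMean:
    """Mean over the leading (batch) dimension, accumulated one batch at a time."""

    def __init__(self, input_shape):
        self.total = torch.zeros(input_shape, dtype=torch.float64)  # running sum
        self.n = 0                                                # samples seen so far

    def update(self, inputs: torch.Tensor) -> torch.Tensor:
        self.total += inputs.sum(dim=0, dtype=torch.float64)
        self.n += inputs.shape[0]
        return (self.total / self.n).to(inputs.dtype)


torch.manual_seed(0)
x = torch.randn(1_000, 2)
m = RunningMean(input_shape=(2,))
for batch in torch.split(x, 300):  # batches of 300, 300, 300, 100
    current = m.update(batch)
print(torch.allclose(current, x.mean(dim=0), atol=1e-6))  # True
```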

class modulus.metrics.general.ensemble_metrics.Variance(input_shape: Union[Tuple, List], **kwargs)

Bases: EnsembleMetrics

Utility class that computes the variance over a batched or ensemble dimension

This is particularly useful for distributed environments and sequential computation.

Parameters

input_shape (Union[Tuple, List]) – Shape of broadcasted dimensions

Note

See “Updating Formulae and a Pairwise Algorithm for Computing Sample Variances” by Chan et al. http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf for details.

finalize(std: bool = False) → Tuple[Tensor, Tensor]

Compute and store final mean and unbiased variance / std

Parameters

std (bool, optional) – Compute standard deviation, by default False

Returns

Final (mean, variance/std) value

Return type

Tuple[Tensor, Tensor]

property mean: Tensor

Mean value

update(inputs: Tensor) → Tensor

Update current variance value and essential statistics with new data

Parameters

inputs (Tensor) – Input data

Returns

Unbiased variance tensor

Return type

Tensor
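The pairwise update from the Chan et al. report cited above merges the count, mean, and sum of squared deviations (M2) of two data sets. Here is a minimal sketch of that update. RunningVariance is hypothetical and is not the Variance implementation:

```python
import torch


class RunningVariance:
    """Unbiased variance over the leading dimension, merged batch by batch (Chan et al.)."""

    def __init__(self, input_shape):
        self.n = 0
        self.mean = torch.zeros(input_shape, dtype=torch.float64)
        self.m2 = torch.zeros(input_shape, dtype=torch.float64)  # sum of squared deviations

    def update(self, inputs: torch.Tensor) -> torch.Tensor:
        x = inputs.to(torch.float64)
        n_b = x.shape[0]
        mean_b = x.mean(dim=0)
        m2_b = ((x - mean_b) ** 2).sum(dim=0)
        n = self.n + n_b
        delta = mean_b - self.mean
        # Pairwise combination of (n, mean, M2) from the stored and the new data.
        self.mean = self.mean + delta * n_b / n
        self.m2 = self.m2 + m2_b + delta**2 * self.n * n_b / n
        self.n = n
        return self.m2 / (self.n - 1)  # undefined until at least two samples are seen


torch.manual_seed(0)
x = torch.randn(1_000, 2)
v = RunningVariance(input_shape=(2,))
for batch in torch.split(x, 250):
    var = v.update(batch)
print(torch.allclose(var, x.double().var(dim=0)))  # True
```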

class modulus.metrics.general.reduction.WeightedMean(weights: Tensor)

Bases: WeightedStatistic

Compute weighted mean of some input.

Parameters

weights (Tensor) – Weight tensor

class modulus.metrics.general.reduction.WeightedStatistic(weights: Tensor)

Bases: ABC

A convenience class for computing weighted statistics of some input

Parameters

weights (Tensor) – Weight tensor

class modulus.metrics.general.reduction.WeightedVariance(weights: Tensor)

Bases: WeightedStatistic

Compute weighted variance of some input.

Parameters

weights (Tensor) – Weight tensor
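This plain-PyTorch sketch uses the usual weighted mean and weighted (population) variance. This reference does not state whether WeightedVariance divides by the sum of the weights, as here, or applies a bias correction, so treat the normalization as an assumption. weighted_mean and weighted_variance are hypothetical names:

```python
import torch


def weighted_mean(x: torch.Tensor, w: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """sum(w * x) / sum(w) along dim."""
    return (w * x).sum(dim=dim) / w.sum(dim=dim)


def weighted_variance(x: torch.Tensor, w: torch.Tensor, dim: int = 0) -> torch.Tensor:
    """sum(w * (x - mean)^2) / sum(w) along dim (no small-sample bias correction)."""
    mean = weighted_mean(x, w, dim).unsqueeze(dim)
    return (w * (x - mean) ** 2).sum(dim=dim) / w.sum(dim=dim)


x = torch.tensor([1.0, 2.0, 3.0])
w = torch.tensor([1.0, 1.0, 2.0])
print(weighted_mean(x, w))      # (1 + 2 + 6) / 4 = 2.25
print(weighted_variance(x, w))  # (1.5625 + 0.0625 + 2 * 0.5625) / 4 = 0.6875
```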

modulus.metrics.general.wasserstein.wasserstein_from_cdf(bin_edges: Tensor, cdf_x: Tensor, cdf_y: Tensor) → Tensor

1-Wasserstein distance between two discrete CDF functions

This metric is typically used to compare two different forecast ensembles (for X and Y). It creates a map of distances and does not accumulate over lat/lon regions. It computes

\[W(F_X, F_Y) = \int \left|F_X(x) - F_Y(x)\right| dx \tag{7}\]

where F_X is the empirical cdf of X and F_Y is the empirical cdf of Y.

Parameters
  • bin_edges (Tensor) – Tensor containing bin edges. The leading dimension must represent the N+1 bin edges.

  • cdf_x (Tensor) – Tensor containing the first CDF, defined over bins. The non-zeroth dimensions of bin_edges and cdf_x must be compatible.

  • cdf_y (Tensor) – Tensor containing the second CDF, defined over bins. Must be compatible with cdf_x in terms of bins and shape.

Returns

The 1-Wasserstein distance between cdf_x and cdf_y

Return type

Tensor
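Discretized on the shared bins, the integral above becomes a width-weighted sum of absolute CDF differences. This minimal sketch assumes each CDF value is held constant across its bin. w1_from_cdf is a hypothetical name:

```python
import torch


def w1_from_cdf(bin_edges: torch.Tensor, cdf_x: torch.Tensor, cdf_y: torch.Tensor) -> torch.Tensor:
    """Sum over bins of |F_X - F_Y| * bin width; reduces over the leading (bin) dimension."""
    widths = bin_edges[1:] - bin_edges[:-1]
    return ((cdf_x - cdf_y).abs() * widths).sum(dim=0)


edges = torch.tensor([0.0, 1.0, 2.0, 3.0])
cdf_x = torch.tensor([1.0, 1.0, 1.0])  # all mass in the first bin
cdf_y = torch.tensor([0.0, 1.0, 1.0])  # all mass in the second bin
print(w1_from_cdf(edges, cdf_x, cdf_y))  # 1.0: the mass moves by one bin width
```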

modulus.metrics.general.wasserstein.wasserstein_from_normal(mu0: Tensor, sigma0: Tensor, mu1: Tensor, sigma1: Tensor) → Tensor

Computes the Wasserstein distance between two (possibly multivariate) normal distributions.

Parameters
  • mu0 (Tensor [B (optional), d1]) – The mean of distribution 0. Can optionally have a batched first dimension.

  • sigma0 (Tensor [B (optional), d1, d2 (optional)]) – The variance or covariance of distribution 0. If mu0 has a batched dimension, then so must sigma0. If sigma0 is two-dimensional (ignoring the optional batch dimension), it is assumed to be a covariance matrix and must be symmetric positive definite.

  • mu1 (Tensor [B (optional), d1]) – The mean of distribution 1. Can optionally have a batched first dimension.

  • sigma1 (Tensor [B (optional), d1, d2 (optional)]) – The variance or covariance of distribution 1. If mu1 has a batched dimension, then so must sigma1. If sigma1 is two-dimensional (ignoring the optional batch dimension), it is assumed to be a covariance matrix and must be symmetric positive definite.

Returns

The Wasserstein distance between N(mu0, sigma0) and N(mu1, sigma1)

Return type

Tensor [B]
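For one-dimensional normals, the 2-Wasserstein distance has a well-known closed form: the square root of the squared mean difference plus the squared difference of the standard deviations. This reference does not state which Wasserstein order wasserstein_from_normal returns, or whether it returns a squared value, so this sketch is for intuition only. w2_univariate_normal is a hypothetical name, and it takes sigma to be a variance, as in the parameter list above:

```python
import torch


def w2_univariate_normal(mu0: torch.Tensor, var0: torch.Tensor, mu1: torch.Tensor, var1: torch.Tensor) -> torch.Tensor:
    """2-Wasserstein distance between N(mu0, var0) and N(mu1, var1) in one dimension."""
    return torch.sqrt((mu0 - mu1) ** 2 + (var0.sqrt() - var1.sqrt()) ** 2)


mu0, var0 = torch.tensor([0.0, 0.0]), torch.tensor([1.0, 1.0])
mu1, var1 = torch.tensor([3.0, 0.0]), torch.tensor([1.0, 4.0])
print(w2_univariate_normal(mu0, var0, mu1, var1))  # [3.0, 1.0]: a mean shift, then std 1 vs 2
```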

modulus.metrics.general.wasserstein.wasserstein_from_samples(x: Tensor, y: Tensor, bins: int = 10)

1-Wasserstein distances between two sets of samples, computed using the discrete CDF.

Parameters
  • x (Tensor [S, ...]) – Tensor containing the first set of samples. The Wasserstein metric is computed over the first dimension of the data.

  • y (Tensor [S, ...]) – Tensor containing the second set of samples. The Wasserstein metric is computed over the first dimension of the data. The shapes of x and y must be compatible.

  • bins (int, optional) – Number of bins to use in the empirical CDF, by default 10.

Returns

The 1-Wasserstein distance between the samples x and y.

Return type

Tensor
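In one dimension and with equal sample sizes, the empirical 1-Wasserstein distance can also be computed exactly by sorting both samples, with no binning. This sketch (hypothetical name w1_exact_1d) is a useful reference point. Because wasserstein_from_samples works from a discrete CDF with the given number of bins, the two results can differ by discretization error:

```python
import torch


def w1_exact_1d(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Exact empirical 1-Wasserstein distance between equal-size samples, along dim 0."""
    # In 1-D the optimal transport plan matches order statistics.
    return (torch.sort(x, dim=0).values - torch.sort(y, dim=0).values).abs().mean(dim=0)


x = torch.tensor([0.0, 1.0, 2.0])
y = torch.tensor([3.0, 1.0, 2.0])
print(w1_exact_1d(x, y))  # sorted y is x + 1, so the distance is 1.0
```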

modulus.metrics.climate.acc.acc(pred: Tensor, target: Tensor, climatology: Tensor, lat: Tensor) → Tensor

Calculates the Anomaly Correlation Coefficient

Parameters
  • pred (Tensor) – […, H, W] Predicted tensor on a lat/long grid

  • target (Tensor) – […, H, W] Target tensor on a lat/long grid

  • climatology (Tensor) – […, H, W] climatology tensor

  • lat (Tensor) – [H] latitude tensor

Returns

ACC values for each field

Return type

Tensor
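The ACC is the correlation between forecast and observed anomalies, that is, departures from climatology. Here each grid row is weighted by the cosine of its latitude, which is an assumed weighting. The sketch below is the uncentered form. This reference does not state whether acc also removes the spatial mean of each anomaly field (a centered ACC) or normalizes its weights differently, so treat those details as assumptions. acc_sketch is a hypothetical name, and lat is taken to be in degrees, as in the ACC example near the top of this page:

```python
import torch


def acc_sketch(pred, target, climatology, lat_deg):
    """Latitude-weighted (uncentered) anomaly correlation over the last two dims [..., H, W]."""
    w = torch.cos(torch.deg2rad(lat_deg)).unsqueeze(-1)  # [H, 1], broadcast over longitude
    a = pred - climatology     # forecast anomaly
    b = target - climatology   # observed anomaly
    num = (w * a * b).sum(dim=(-2, -1))
    den = torch.sqrt((w * a**2).sum(dim=(-2, -1)) * (w * b**2).sum(dim=(-2, -1)))
    return num / den


torch.manual_seed(0)
lat = torch.linspace(-80.0, 80.0, 9)
clim = torch.randn(1, 9, 16)
target = clim + torch.randn(1, 9, 16)
print(acc_sketch(target, target, clim, lat))             # perfect forecast: 1
print(acc_sketch(2 * clim - target, target, clim, lat))  # sign-flipped anomaly: -1
```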

modulus.metrics.climate.efi.efi(bin_edges: Tensor, counts: Tensor, quantiles: Tensor) → Tensor

Compute the Extreme Forecast Index for the given histogram.

The histogram is assumed to correspond with the given quantiles. That is, the bin midpoints must align with the quantiles.

Parameters
  • bin_edges (Tensor) – The bin edges of the histogram over which the data distribution is defined. Assumed to be monotonically increasing, but not necessarily evenly spaced.

  • counts (Tensor) – The counts of the histogram over which the data distribution is defined. Not assumed to be normalized.

  • quantiles (Tensor) – The quantiles of the climatological or reference distribution. The quantiles must match the midpoints of the histogram bins.

See modulus/metrics/climate/efi for more details.
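A widely used definition of the EFI (the one used operationally at ECMWF) is EFI = (2/pi) ∫_0^1 (p - F_f(p)) / sqrt(p (1 - p)) dp. Here F_f(p) is the fraction of the forecast that lies below the climatological p-quantile. The index ranges from -1 to 1 and is 0 when the forecast matches climatology. The following midpoint-rule sketch assumes this definition. efi_sketch and its arguments are hypothetical, and efi itself works from a histogram as described above:

```python
import math

import torch


def efi_sketch(forecast_cdf_at_q: torch.Tensor, levels: torch.Tensor) -> torch.Tensor:
    """(2 / pi) * ∫_0^1 (p - F_f(q_p)) / sqrt(p (1 - p)) dp by the midpoint rule.

    levels: uniformly spaced midpoint probabilities (k + 0.5) / K, shape [K].
    forecast_cdf_at_q: forecast CDF evaluated at the climate quantiles q_p, shape [K, ...].
    """
    p = levels.reshape(-1, *([1] * (forecast_cdf_at_q.dim() - 1)))  # broadcast over [...]
    integrand = (p - forecast_cdf_at_q) / torch.sqrt(p * (1.0 - p))
    return (2.0 / math.pi) * integrand.mean(dim=0)


K = 10_000
levels = (torch.arange(K, dtype=torch.float64) + 0.5) / K
print(efi_sketch(levels, levels))                               # forecast = climate: 0
print(efi_sketch(torch.zeros(K, dtype=torch.float64), levels))  # all members above: about 1
print(efi_sketch(torch.ones(K, dtype=torch.float64), levels))   # all members below: about -1
```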

modulus.metrics.climate.efi.efi_gaussian(pred_cdf: Tensor, bin_edges: Tensor, climatology_mean: Tensor, climatology_std: Tensor) → Tensor

Calculates the Extreme Forecast Index (EFI) for an ensemble forecast against a climatological distribution.

Parameters
  • pred_cdf (Tensor) – Cumulative distribution function of predictions of shape [N, …] where N is the number of bins. This cdf must be defined over the passed bin_edges.

  • bin_edges (Tensor) – Tensor of bin edges with shape [N+1, …] where N is the number of bins.

  • climatology_mean (Tensor) – Tensor of climatological mean with shape […]

  • climatology_std (Tensor) – Tensor of climatological std with shape […]

Returns

EFI values of each of the batched dimensions.

Return type

Tensor

modulus.metrics.climate.efi.normalized_entropy(pred_pdf: Tensor, bin_edges: Tensor, climatology_pdf: Tensor) → Tensor

Calculates the relative entropy, or surprise, of using the prediction distribution with respect to the climatology distribution.

Parameters
  • pred_pdf (Tensor) – Probability density function of predictions, of shape [N, …] where N is the number of bins. This pdf must be defined over the passed bin_edges.

  • bin_edges (Tensor) – Tensor of bin edges with shape [N+1, …] where N is the number of bins.

  • climatology_pdf (Tensor) – Tensor of the climatological probability density function, of shape [N, …]

Returns

Relative Entropy values of each of the batched dimensions.

Return type

Tensor

modulus.metrics.climate.reduction.global_mean(x: Tensor, lat: Tensor, keepdims: bool = False) → Tensor

Computes global mean

This function computes the global mean of a lat/lon grid by weighting over the latitude direction and then averaging over longitude

Parameters
  • x (Tensor) – The lat/lon tensor […, H, W] over which the mean will be computed

  • lat (Tensor) – A one-dimensional tensor [H] of the latitudes used to compute the weights

  • keepdims (bool, optional) – Keep aggregated dimension, by default False

Returns

Global mean tensor

Return type

Tensor

modulus.metrics.climate.reduction.global_var(x: Tensor, lat: Tensor, std: bool = False, keepdims: bool = False) → Tensor

Computes global variance

This function computes the global variance of a lat/lon grid by weighting over the latitude direction and then averaging over longitude

Parameters
  • x (Tensor) – The lat/lon tensor […, H, W] over which the variance will be computed

  • lat (Tensor) – A one-dimensional tensor [H] of the latitudes used to compute the weights

  • std (bool, optional) – Return global standard deviation, by default False

  • keepdims (bool, optional) – Keep aggregated dimension, by default False

Returns

Global variance tensor

Return type

Tensor
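On a regular lat/lon grid, latitude weighting compensates for grid cells that shrink toward the poles. The following minimal sketch of a global mean and variance assumes cosine-of-latitude weights and latitudes in degrees. lat_weights, global_mean_sketch, and global_var_sketch are hypothetical names, and their details may differ from global_mean and global_var:

```python
import torch


def lat_weights(lat_deg: torch.Tensor) -> torch.Tensor:
    """cos(latitude) weights rescaled to mean 1, shape [H, 1] for broadcasting over longitude."""
    w = torch.cos(torch.deg2rad(lat_deg))
    return (w / w.mean()).unsqueeze(-1)


def global_mean_sketch(x: torch.Tensor, lat_deg: torch.Tensor) -> torch.Tensor:
    """Area-weighted mean over the last two dims [..., H, W]."""
    return (lat_weights(lat_deg) * x).mean(dim=(-2, -1))


def global_var_sketch(x: torch.Tensor, lat_deg: torch.Tensor) -> torch.Tensor:
    """Area-weighted mean of squared deviations from the global mean."""
    mean = global_mean_sketch(x, lat_deg)[..., None, None]
    return global_mean_sketch((x - mean) ** 2, lat_deg)


lat = torch.tensor([-60.0, 0.0, 60.0])                  # raw weights 0.5, 1, 0.5
x = torch.tensor([[0.0, 0.0], [1.0, 1.0], [0.0, 0.0]])  # 1 on the equator row, 0 elsewhere
print(global_mean_sketch(x, lat))  # (0.5*0 + 1*1 + 0.5*0) / 2 = 0.5
print(global_var_sketch(x, lat))   # every squared deviation is 0.25
```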

modulus.metrics.climate.reduction.zonal_mean(x: Tensor, lat: Tensor, dim: int = -2, keepdims: bool = False) → Tensor

Computes zonal mean, weighting over the latitude direction that is specified by dim

Parameters
  • x (Tensor) – The tensor […, H, W] over which the mean will be computed

  • lat (Tensor) – A one-dimensional tensor of the latitudes used to compute the weights

  • dim (int, optional) – The dimension of x over which the reduction occurs, by default -2

  • keepdims (bool, optional) – Keep aggregated dimension, by default False

Returns

Zonal mean tensor of x over the latitude dimension

Return type

Tensor

modulus.metrics.climate.reduction.zonal_var(x: Tensor, lat: Tensor, std: bool = False, dim: int = -2, keepdims: bool = False) → Tensor

Computes zonal variance, weighting over the latitude direction

Parameters
  • x (Tensor) – The tensor […, H, W] over which the variance will be computed

  • lat (Tensor) – A one-dimensional tensor [H] of the latitudes used to compute the weights

  • std (bool, optional) – Return zonal standard deviation, by default False

  • dim (int, optional) – The int specifying which dimension of x the reduction will occur, by default -2

  • keepdims (bool, optional) – Keep aggregated dimension, by default False

Returns

The variance (or standard deviation) of x over the latitude dimension

Return type

Tensor

© Copyright 2023, NVIDIA Modulus Team. Last updated on Jan 25, 2024.