deeplearning/modulus/modulus-core/_modules/modulus/metrics/general/wasserstein.html

Source code for modulus.metrics.general.wasserstein

# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from warnings import warn

import torch

from modulus.metrics.general import histogram

Tensor = torch.Tensor


def _symmetric_psd_sqrt(matrix: Tensor) -> Tensor:
    """Return the symmetric square root of a (batched) symmetric PSD matrix."""
    eigvals, eigvecs = torch.linalg.eigh(matrix)
    if torch.any(eigvals < 0.0):
        warn(
            "Warning! Some eigenvalues are less than zero and matrix is not positive definite."
        )
        # Clip negative eigenvalues so the square root below stays real.
        eigvals = torch.nn.functional.relu(eigvals)
    # V * diag(sqrt(lambda)) * V^T
    return torch.matmul(
        torch.matmul(eigvecs, torch.diag_embed(torch.sqrt(eigvals))),
        eigvecs.transpose(-2, -1),
    )


def wasserstein_from_normal(
    mu0: Tensor, sigma0: Tensor, mu1: Tensor, sigma1: Tensor
) -> Tensor:
    """Compute the squared 2-Wasserstein distance between two normal distributions.

    For the univariate case this evaluates, element-wise,
    ``(mu0 - mu1)^2 + sigma0 + sigma1 - 2 * sqrt(sigma0 * sigma1)``.
    For the multivariate case it evaluates
    ``||mu0 - mu1||^2 + tr(sig0 + sig1 - 2 * (sig0^1/2 sig1 sig0^1/2)^1/2)``.
    No final square root is taken, so the value returned is the *squared*
    distance.

    Parameters
    ----------
    mu0 : Tensor [B (optional), d1]
        The mean of distribution 0. Can optionally have batched leading
        dimension(s).
    sigma0 : Tensor [B (optional), d1, d2 (optional)]
        The variance or covariance of distribution 0. If mu0 has batched
        dimension(s), then so must sigma0. If sigma0 has one more dimension
        than mu0, it is treated as a covariance matrix and must be symmetric
        positive definite.
    mu1 : Tensor [B (optional), d1]
        The mean of distribution 1. Can optionally have batched leading
        dimension(s).
    sigma1 : Tensor [B (optional), d1, d2 (optional)]
        The variance or covariance of distribution 1. If mu1 has batched
        dimension(s), then so must sigma1. If sigma1 has one more dimension
        than mu1, it is treated as a covariance matrix and must be symmetric
        positive definite.

    Returns
    -------
    Tensor [B]
        The squared 2-Wasserstein distance between N(mu0, sigma0) and
        N(mu1, sigma1). In the univariate case the result has the same shape
        as the (broadcast) inputs.

    Warns
    -----
    UserWarning
        If an eigenvalue of sigma0 or of ``sig0^1/2 sig1 sig0^1/2`` is
        negative; such eigenvalues are clipped to zero.
    """
    if sigma0.ndim == mu0.ndim:
        # Univariate normal distribution: sigma holds variances.
        return (mu0 - mu1) ** 2 + (sigma0 + sigma1 - 2 * torch.sqrt(sigma0 * sigma1))

    # Multivariate normal distribution.
    sqrt_sig0 = _symmetric_psd_sqrt(sigma0)

    # C = sig0^1/2 * sig1 * sig0^1/2, then its symmetric square root.
    cross = torch.matmul(torch.matmul(sqrt_sig0, sigma1), sqrt_sig0)
    sqrt_cross = _symmetric_psd_sqrt(cross)

    # T = tr(sig0 + sig1 - 2 * sqrt(C)). Summing the diagonal over the last
    # two axes handles any number of leading batch dimensions, whereas
    # vmap(torch.trace) only handles exactly one.
    trace_term = torch.diagonal(
        sigma0 + sigma1 - 2 * sqrt_cross, dim1=-2, dim2=-1
    ).sum(dim=-1)

    # Squared Euclidean distance between the means, computed directly rather
    # than squaring a norm.
    mean_term = torch.sum((mu0 - mu1) ** 2, dim=-1)
    return mean_term + trace_term
def wasserstein_from_samples(x: Tensor, y: Tensor, bins: int = 10):
    """Estimate the 1-Wasserstein distance between two sample sets.

    Both sample sets are converted to discrete CDFs over a shared set of bin
    edges, and the distance is then evaluated from those CDFs.

    Parameters
    ----------
    x : Tensor [S, ...]
        Tensor containing one set of samples. The wasserstein metric will be
        computed over the first dimension of the data.
    y : Tensor [S, ...]
        Tensor containing the second set of samples. The wasserstein metric
        will be computed over the first dimension of the data. The shapes of
        x and y must be compatible.
    bins : int, Optional.
        Optional number of bins to use in the empirical CDF. Defaults to 10.

    Returns
    -------
    Tensor
        The 1-Wasserstein distance between the samples x and y.
    """
    # The edges produced while binning x are reused for y so that both CDFs
    # are defined over the same bins.
    edges, x_cdf = histogram.cdf(x, bins=bins)
    y_cdf = histogram.cdf(y, bins=edges)[1]
    return wasserstein_from_cdf(edges, x_cdf, y_cdf)
def wasserstein_from_cdf(bin_edges: Tensor, cdf_x: Tensor, cdf_y: Tensor) -> Tensor:
    r"""1-Wasserstein distance between two discrete CDF functions

    This norm is typically used to compare two different forecast ensembles
    (for X and Y). Creates a map of distance and does not accumulate over
    lat/lon regions. Computes

    .. math::

        W(F_X, F_Y) = \int |F_X(x) - F_Y(x)| \, dx

    where F_X is the empirical cdf of X and F_Y is the empirical cdf of Y.
    The integral is approximated by summing, over the leading (bin)
    dimension, the absolute CDF difference in each bin times that bin's own
    width, so non-uniformly spaced bin edges are handled correctly.

    Parameters
    ----------
    bin_edges : Tensor
        Tensor containing bin edges. The leading dimension must represent the
        N+1 bin edges.
    cdf_x : Tensor
        Tensor containing a CDF one, defined over bins. The non-zeroth
        dimensions of bins and cdf must be compatible.
    cdf_y : Tensor
        Tensor containing a CDF two, defined over bins. Must be compatible
        with cdf_x in terms of bins and shape.

    Returns
    -------
    Tensor
        The 1-Wasserstein distance between cdf_x and cdf_y

    Raises
    ------
    ValueError
        If the leading dimension of the CDFs does not equal the number of
        bins implied by ``bin_edges``.
    """
    cdf_gap = torch.abs(cdf_x - cdf_y)

    # Width of every bin, not just the first one.
    widths = bin_edges[1:] - bin_edges[:-1]
    if widths.shape[0] != cdf_gap.shape[0]:
        raise ValueError(
            f"Expected CDFs with {widths.shape[0]} bins along dim 0 "
            f"(from {bin_edges.shape[0]} bin edges), got {cdf_gap.shape[0]}."
        )

    # When the edges have fewer dimensions than the CDFs, keep the bin axis
    # first and right-align the remaining axes, so trailing dimensions
    # broadcast the same way a single per-location width would.
    missing = cdf_gap.ndim - widths.ndim
    if missing > 0:
        widths = widths.reshape(widths.shape[0], *([1] * missing), *widths.shape[1:])

    return torch.sum(cdf_gap * widths, dim=0)
© Copyright 2023, NVIDIA Modulus Team. Last updated on Jan 25, 2024.