# Source code for modulus.metrics.general.reduction

from abc import ABC

import torch
from torch import Tensor

class WeightedStatistic(ABC):
    """A convenience class for computing weighted statistics of some input

    The weights are validated to be strictly positive and normalized to sum to
    one at construction time. Calling an instance returns the weights shaped so
    that they broadcast against an input along a chosen dimension.

    Parameters
    ----------
    weights : Tensor
        Weight tensor

    Raises
    ------
    ValueError
        If any weight is not strictly positive.
    """

    def __init__(self, weights: Tensor):
        # `weights > 0.0` is also False for NaN, so NaN weights are rejected too.
        if not torch.all(weights > 0.0).item():
            raise ValueError("Expected all weights to be positive.")
        self.weights = self._normalize(weights)

    def __call__(self, x: Tensor, dim: int) -> Tensor:
        """Return the weights shaped to broadcast against ``x`` along ``dim``.

        Parameters
        ----------
        x : Tensor
            Input data the weights will be applied to
        dim : int
            Reduction dimension of ``x``; negative values count from the end

        Returns
        -------
        Tensor
            For 1-D weights of length ``n``, a tensor with ``x.ndim`` dimensions
            that are all of size one except ``dim``, which has size ``n``.
            Otherwise, the stored weights unchanged.

        Raises
        ------
        ValueError
            If the weights do not match ``x`` along ``dim`` (or, for
            multi-dimensional weights, have a different number of dimensions).
        IndexError
            If ``dim`` is out of range for ``x``.
        """
        w = self.weights
        if w.ndim == 1:
            size = x.shape[dim]
            if size != len(w):
                raise ValueError(
                    "Expected inputs and weights to have the same size along the "
                    f"reduction dimension but have dimensions {size} and {len(w)}."
                )
            # Singleton axes everywhere except `dim`, so `w` broadcasts against `x`.
            # `dim` is known to be valid here because `x.shape[dim]` succeeded.
            shape = [1] * x.ndim
            shape[dim] = len(w)
            return w.reshape(shape)
        if not (x.ndim == w.ndim and x.shape[dim] == w.shape[dim]):
            raise ValueError("Expected inputs and weights to have compatible shapes.")
        return w

    def _normalize(self, weights: Tensor) -> Tensor:
        """Normalize unnormalized weights, for convenience

        Parameters
        ----------
        weights : Tensor
            Unnormalized weights

        Returns
        -------
        Tensor
            Weights divided by their total sum over all elements
        """
        return weights / torch.sum(weights)

class WeightedMean(WeightedStatistic):
    """
    Compute weighted mean of some input.

    Parameters
    ----------
    weights : Tensor
        Weight tensor; must be strictly positive and is normalized to sum to
        one by ``WeightedStatistic``
    """

    def __init__(self, weights: Tensor):
        super().__init__(weights)

    def __call__(self, x: Tensor, dim: int, keepdims: bool = False) -> Tensor:
        """Compute weighted mean

        Parameters
        ----------
        x : Tensor
            Input data
        dim : int
            Dimension to take aggregate
        keepdims : bool, optional
            Keep aggregated dimension, by default False

        Returns
        -------
        Tensor
            Weighted mean of ``x`` along ``dim``
        """
        w = super().__call__(x, dim)
        # The weights were normalized to sum to one at construction, so the
        # weighted sum along `dim` is the weighted mean.
        # NOTE(review): for multi-dimensional weights the normalization is over
        # all elements, not per slice along `dim`; confirm callers only pass
        # weights that vary along `dim` alone (e.g. shape (1, n, 1)).
        return torch.sum(w * x, dim=dim, keepdim=keepdims)

class WeightedVariance(WeightedStatistic):
    """
    Compute weighted variance of some input.

    Parameters
    ----------
    weights : Tensor
        Weight tensor
    """

    def __init__(self, weights: Tensor):
        super().__init__(weights)
        # Built from the already-normalized weights.
        self.wm = WeightedMean(self.weights)

    def __call__(self, x: Tensor, dim: int, keepdims: bool = False) -> Tensor:
        """Compute weighted variance

        The biased weighted variance ``sum(w * (x - mean) ** 2)`` along ``dim``
        is rescaled by ``M / (M - 1)``, where ``M`` is the number of non-zero
        entries in the broadcast weight tensor.

        Parameters
        ----------
        x : Tensor
            Input data
        dim : int
            Dimension to take aggregate
        keepdims : bool, optional
            Keep aggregated dimension, by default False

        Returns
        -------
        Tensor
            Weighted variance
        """
        w = super().__call__(x, dim)

        # Weighted mean with a singleton `dim`, so it broadcasts against `x`.
        wm = self.wm(x, dim, keepdims=True)

        # Bias correction: dividing by (M - 1) / M multiplies the biased variance
        # by M / (M - 1). With M == 1 the scale is zero and the result is not
        # finite. NOTE(review): M counts every element of the broadcast weights,
        # which equals the size along `dim` only for 1-D weights — confirm intent
        # for multi-dimensional weights.
        number_of_non_zero_weights = torch.sum(w > 0.0)
        scale = (number_of_non_zero_weights - 1.0) / number_of_non_zero_weights
        return torch.sum(w * (x - wm) ** 2, dim=dim, keepdim=keepdims) / scale

