# Source code for modulus.metrics.general.ensemble_metrics
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from typing import List, Tuple, Union
import torch
import torch.distributed as dist
from modulus.distributed.manager import DistributedManager
Tensor = torch.Tensor
class EnsembleMetrics(ABC):
    """Abstract class for ensemble performance related metrics

    Can be helpful for distributed and sequential computations of metrics.

    Parameters
    ----------
    input_shape : Union[Tuple[int,...], List]
        Shape of input tensors without batched dimension.
    device : torch.device, optional
        Pytorch device model is on, by default 'cpu'
    dtype : torch.dtype, optional
        Standard dtype to initialize any tensor with
    """

    def __init__(
        self,
        input_shape: Union[Tuple[int, ...], List[int]],
        device: Union[str, torch.device] = "cpu",
        dtype: torch.dtype = torch.float32,
    ):
        # Normalised to a list / torch.device so subclasses can rely on the type.
        self.input_shape = list(input_shape)
        self.device = torch.device(device)
        self.dtype = dtype

    def _check_shape(self, inputs: Tensor) -> None:
        """
        Check input shapes for non-batched dimension.

        Dimension 0 of ``inputs`` is taken to be the batch dimension; the
        remaining dimensions are compared pairwise against ``input_shape``.

        Raises
        ------
        ValueError
            If any compared dimension differs.
        """
        trailing = inputs.shape[1:]
        # zip stops at the shorter sequence, so only the overlapping leading
        # dimensions are compared (unchanged from the original behaviour).
        if not all(i == s for i, s in zip(trailing, self.input_shape)):
            raise ValueError(
                "Expected new input to have compatible shape with existing "
                f"shapes but got {inputs.shape} and {self.input_shape}."
            )

    def __call__(self, *args):
        """
        Initial calculation for stored information. Will also overwrite previously stored data.
        """
        raise NotImplementedError("Class member must implement a __call__ method.")

    def update(self, *args):
        """
        Update initial or stored calculation with additional information.
        """
        raise NotImplementedError("Class member must implement an update method.")

    def finalize(self, *args):
        """
        Marks the end of the sequential calculation, used to finalize any computations.
        """
        raise NotImplementedError("Class member must implement a finalize method.")


def _update_mean(
    old_sum: Tensor,
    old_n: Union[int, Tensor],
    inputs: Tensor,
    batch_dim: Union[int, None] = 0,
) -> Tuple[Tensor, Union[int, Tensor]]:
    """Updated mean sufficient statistics given new data

    This method updates a running sum and number of samples with new (possibly batched)
    inputs

    Parameters
    ----------
    old_sum : Tensor
        Current, or old, running sum
    old_n : Union[int, Tensor]
        Current, or old, number of samples
    inputs : Tensor
        New input to add to current/old sum. May be batched, in which case the batched
        dimension must be flagged by passing an int to batch_dim.
    batch_dim : Union[int, None], optional
        Dimension along which the new inputs are batched. If None, ``inputs`` is
        treated as a single unbatched sample (the count grows by one), by default 0.

    Returns
    -------
    Tuple[Tensor, Union[int, Tensor]]
        Updated (rolling sum, number of samples)
    """
    if batch_dim is None:
        # Give the lone sample a batch axis of length one so the code below
        # handles both cases identically.
        inputs = torch.unsqueeze(inputs, 0)
        batch_dim = 0

    new_sum = old_sum + torch.sum(inputs, dim=batch_dim)
    new_n = old_n + inputs.shape[batch_dim]
    return new_sum, new_n
class Mean(EnsembleMetrics):
    """Utility class that computes the mean over a batched or ensemble dimension

    This is particularly useful for distributed environments and sequential computation.

    Parameters
    ----------
    input_shape : Union[Tuple, List]
        Shape of broadcasted dimensions
    """

    def __init__(self, input_shape: Union[Tuple, List], **kwargs):
        super().__init__(input_shape, **kwargs)
        # Running sufficient statistics: element-wise sum and sample count.
        self.sum = torch.zeros(self.input_shape, dtype=self.dtype, device=self.device)
        self.n = torch.zeros([1], dtype=torch.int32, device=self.device)

    def __call__(self, inputs: Tensor, dim: int = 0) -> Tensor:
        """Calculate an initial mean

        Overwrites any previously stored sum and sample count.

        Parameters
        ----------
        inputs : Tensor
            Input data
        dim : Int
            Dimension of batched data

        Returns
        -------
        Tensor
            Mean value

        Raises
        ------
        AssertionError
            If ``inputs`` is not on the same device as this metric.
        """
        if inputs.device != self.device:
            raise AssertionError(
                f"Input device, {inputs.device}, and Module device, {self.device}, must be the same."
            )
        self.sum = torch.sum(inputs, dim=dim)
        self.n = torch.as_tensor([inputs.shape[dim]], device=self.device)
        # TODO(Dallas) Move distributed calls into finalize.
        if (
            DistributedManager.is_initialized() and dist.is_initialized()
        ):  # pragma: no cover
            dist.all_reduce(self.sum, op=dist.ReduceOp.SUM)
            dist.all_reduce(self.n, op=dist.ReduceOp.SUM)
        return self.sum / self.n

    def update(self, inputs: Tensor, dim: int = 0) -> Tensor:
        """Update current mean and essential statistics with new data

        Parameters
        ----------
        inputs : Tensor
            Inputs tensor
        dim : int
            Dimension of batched data

        Returns
        -------
        Tensor
            Current mean value

        Raises
        ------
        AssertionError
            If ``inputs`` is not on the same device as this metric.
        """
        if inputs.device != self.device:
            raise AssertionError(
                f"Input device, {inputs.device}, and Module device, {self.device}, must be the same."
            )
        # TODO(Dallas) Move distributed calls into finalize.
        if (
            DistributedManager.is_initialized() and dist.is_initialized()
        ):  # pragma: no cover
            # Collect local sums, n
            sums = torch.sum(inputs, dim=dim)
            n = torch.as_tensor([inputs.shape[dim]], device=self.device)

            # Reduce
            dist.all_reduce(sums, op=dist.ReduceOp.SUM)
            dist.all_reduce(n, op=dist.ReduceOp.SUM)

            # Update
            self.sum += sums
            self.n += n
        else:
            # The two branches are mutually exclusive: running both would add
            # the same batch to the accumulators twice.
            self.sum, self.n = _update_mean(self.sum, self.n, inputs, batch_dim=dim)
        return self.sum / self.n

    def finalize(self) -> Tensor:
        """Compute and store final mean

        Returns
        -------
        Tensor
            Final mean value
        """
        self.mean = self.sum / self.n
        return self.mean


def _update_var(
    old_sum: Tensor,
    old_sum2: Tensor,
    old_n: Union[int, Tensor],
    inputs: Tensor,
    batch_dim: Union[int, None] = 0,
) -> Tuple[Tensor, Tensor, Union[int, Tensor]]:
    """Updated variance sufficient statistics given new data

    This method updates a running sum, squared sum, and number of samples with
    new (possibly batched) inputs

    Parameters
    ----------
    old_sum : Tensor
        Current, or old, running sum
    old_sum2 : Tensor
        Current, or old, running squared sum (sum of squared deviations from
        the running mean)
    old_n : Union[int, Tensor]
        Current, or old, number of samples
    inputs : Tensor
        New input to add to current/old sum. May be batched, in which case the batched
        dimension must be flagged by passing an int to batch_dim.
    batch_dim : Union[int, None], optional
        Dimension along which the new inputs are batched. If None, ``inputs`` is
        treated as a single unbatched sample, by default 0.

    Returns
    -------
    Tuple[Tensor, Tensor, Union[int, Tensor]]
        Updated (rolling sum, rolling squared sum, number of samples)

    Note
    ----
    See "Updating Formulae and a Pairwise Algorithm for Computing Sample Variances"
    by Chan et al.
    for details.
    """
    if batch_dim is None:
        inputs = torch.unsqueeze(inputs, 0)
        batch_dim = 0

    temp_n = inputs.shape[batch_dim]
    if temp_n == 0:
        # An empty batch carries no information; the formulas below would
        # divide by temp_n and poison the statistics with NaN.
        return old_sum, old_sum2, old_n

    temp_sum = torch.sum(inputs, dim=batch_dim)
    # Re-insert the reduced axis so the batch mean broadcasts against inputs
    # for any batch_dim, not only batch_dim == 0.
    batch_mean = torch.unsqueeze(temp_sum, batch_dim) / temp_n
    temp_sum2 = torch.sum((inputs - batch_mean) ** 2, dim=batch_dim)

    # With no prior samples old_sum / old_n is 0 / 0. Clamping the divisor to
    # one keeps delta finite; the correction term is multiplied by old_n (== 0)
    # and therefore vanishes, as it should.
    if torch.is_tensor(old_n):
        safe_old_n = torch.clamp(old_n, min=1)
    else:
        safe_old_n = max(old_n, 1)
    delta = old_sum * temp_n / safe_old_n - temp_sum

    new_sum = old_sum + temp_sum
    new_sum2 = old_sum2 + temp_sum2
    new_sum2 += old_n / temp_n / (old_n + temp_n) * delta**2
    new_n = old_n + temp_n
    return new_sum, new_sum2, new_n
class Variance(EnsembleMetrics):
    """Utility class that computes the variance over a batched or ensemble dimension

    This is particularly useful for distributed environments and sequential computation.

    Parameters
    ----------
    input_shape : Union[Tuple, List]
        Shape of broadcasted dimensions

    Note
    ----
    See "Updating Formulae and a Pairwise Algorithm for Computing Sample Variances"
    by Chan et al.
    for details.
    """

    def __init__(self, input_shape: Union[Tuple, List], **kwargs):
        super().__init__(input_shape, **kwargs)
        # Running sufficient statistics: sample count, element-wise sum and
        # element-wise sum of squared deviations from the mean.
        self.n = torch.zeros([1], dtype=torch.int32, device=self.device)
        self.sum = torch.zeros(self.input_shape, dtype=self.dtype, device=self.device)
        self.sum2 = torch.zeros(self.input_shape, dtype=self.dtype, device=self.device)

    def __call__(self, inputs: Tensor, dim: int = 0) -> Tensor:
        """Calculate an initial variance

        Overwrites any previously stored statistics.

        Parameters
        ----------
        inputs : Tensor
            Input data
        dim : Int
            Dimension of batched data

        Returns
        -------
        Tensor
            Unbiased variance values (the raw squared-deviation sum when fewer
            than two samples are available)

        Raises
        ------
        AssertionError
            If ``inputs`` is not on the same device as this metric.
        """
        if inputs.device != self.device:
            raise AssertionError(
                f"Input device, {inputs.device}, and Module device, {self.device}, must be the same."
            )
        distributed = DistributedManager.is_initialized() and dist.is_initialized()

        self.sum = torch.sum(inputs, dim=dim)
        # Count samples along the batched dimension actually being reduced.
        self.n = torch.as_tensor([inputs.shape[dim]], device=self.device)
        if distributed:  # pragma: no cover
            # Compute mean and send around.
            dist.all_reduce(self.sum, op=dist.ReduceOp.SUM)
            dist.all_reduce(self.n, op=dist.ReduceOp.SUM)

        # Re-insert the reduced axis so the mean broadcasts for any ``dim``.
        mean = torch.unsqueeze(self.sum, dim) / self.n
        self.sum2 = torch.sum((inputs - mean) ** 2, dim=dim)
        if distributed:  # pragma: no cover
            dist.all_reduce(self.sum2, op=dist.ReduceOp.SUM)

        if self.n < 2.0:
            return self.sum2
        return self.sum2 / (self.n - 1.0)

    def update(self, inputs: Tensor, dim: int = 0) -> Tensor:
        """Update current variance value and essential statistics with new data

        Parameters
        ----------
        inputs : Tensor
            Input data
        dim : int, optional
            Dimension of batched data, by default 0

        Returns
        -------
        Tensor
            Unbiased variance tensor (the raw squared-deviation sum when fewer
            than two samples are available)

        Raises
        ------
        AssertionError
            If ``inputs`` is not on the same device as this metric.
        """
        if inputs.device != self.device:
            raise AssertionError(
                f"Input device, {inputs.device}, and Module device, {self.device}, must be the same."
            )
        distributed = DistributedManager.is_initialized() and dist.is_initialized()

        new_n = torch.as_tensor([inputs.shape[dim]], device=self.device)
        new_sum = torch.sum(inputs, dim=dim)
        # TODO(Dallas) Move distributed calls into finalize.
        if distributed:  # pragma: no cover
            dist.all_reduce(new_n, op=dist.ReduceOp.SUM)
            dist.all_reduce(new_sum, op=dist.ReduceOp.SUM)

        # Calculate new statistics
        batch_mean = torch.unsqueeze(new_sum, dim) / new_n
        new_sum2 = torch.sum((inputs - batch_mean) ** 2, dim=dim)
        if distributed:  # pragma: no cover
            dist.all_reduce(new_sum2, op=dist.ReduceOp.SUM)

        # On a fresh accumulator self.n is 0 and self.sum / self.n would be
        # 0 / 0. Clamping the divisor keeps delta finite; the correction term
        # below is scaled by self.n (== 0) and so contributes nothing.
        delta = self.sum * new_n / torch.clamp(self.n, min=1) - new_sum

        # Update
        self.sum += new_sum
        self.sum2 += new_sum2
        self.sum2 += self.n / new_n / (self.n + new_n) * (delta) ** 2
        self.n += new_n

        if self.n < 2.0:
            return self.sum2
        return self.sum2 / (self.n - 1.0)

    @property
    def mean(self) -> Tensor:
        """Mean value"""
        return self.sum / self.n

    def finalize(self, std: bool = False) -> Tensor:
        """Compute and store final unbiased variance / std

        Parameters
        ----------
        std : bool, optional
            Compute standard deviation, by default False

        Returns
        -------
        Tensor
            Final variance, or standard deviation when ``std`` is True

        Raises
        ------
        ValueError
            If fewer than two samples have been accumulated.
        """
        if not (self.n > 1.0):
            raise ValueError(
                "Error! In order to finalize, there needs to be at least 2 samples."
            )
        self.var = self.sum2 / (self.n - 1.0)
        if std:
            self.std = torch.sqrt(self.var)
            return self.std
        return self.var