Source code for pytorch_quantization.utils.reduce_amax

# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Function to get the absolute maximum of a tensor.

Follows the NumPy convention, which is more general than PyTorch's.
"""

import torch

def reduce_amax(input, axis=None, keepdims=True):
    """Compute the absolute maximum value of a tensor.

    Reduces ``input`` along the dimensions given in ``axis``. Unless
    ``keepdims`` is true, the rank of the tensor is reduced by 1 for each
    distinct entry in ``axis``. If ``keepdims`` is true, the reduced
    dimensions are retained with length 1.

    .. note::
        Gradient computation is disabled (the whole reduction runs under
        ``torch.no_grad()``); this function is not meant to take part in
        learning.

    Args:
        input: Input tensor.
        axis: The dimensions to reduce. None, an int, or a tuple of ints.
            If None (the default), all dimensions are reduced and a 0-dim
            tensor is returned regardless of ``keepdims``. Entries must be
            in the range [-rank(input), rank(input)).
        keepdims: A boolean. If true, retains reduced dimensions with
            length 1. Default True. When ``axis`` is a sequence and the
            result holds a single element, it is always collapsed to a
            0-dim tensor, even if ``keepdims`` is true.

    Returns:
        The reduced tensor.

    Raises:
        ValueError: If ``axis`` is a tuple with more entries than ``input``
            has dimensions.
    """
    with torch.no_grad():
        output = input.abs()

        if axis is None:
            # Global reduction: torch.max without `dim` yields a 0-dim tensor.
            return torch.max(output)

        if isinstance(axis, int):
            output, _ = torch.max(output, dim=axis, keepdim=keepdims)
            return output

        if isinstance(axis, tuple) and len(axis) > input.dim():
            raise ValueError("Cannot reduce more axes than tensor's dim.")

        # Materialize once so the axes can be walked twice (reduce, then squeeze).
        axes = tuple(axis)

        # Reduce with keepdim=True so later axis indices still refer to the
        # original dimension positions.
        for dim in axes:
            output, _ = torch.max(output, dim=dim, keepdim=True)

        if output.numel() == 1:
            # A single remaining element is always returned as a 0-dim scalar.
            output.squeeze_()
        elif not keepdims:
            # Drop only the dimensions that were actually reduced. A blanket
            # squeeze would also remove untouched dimensions of length 1.
            rank = output.dim()
            reduced_dims = {int(dim) % rank for dim in axes}
            # Highest index first so earlier removals do not shift later ones.
            for dim in sorted(reduced_dims, reverse=True):
                output = output.squeeze(dim)

        return output