# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Function to get the absolute maximum of a tensor.
Follows the numpy fashion, which is more generic than PyTorch's.
[docs]def reduce_amax(input, axis=None, keepdims=True):
"""Compute the absolute maximum value of a tensor.
Reduces input_tensor along the dimensions given in axis. Unless keepdims is true,
the rank of the tensor is reduced by 1 for each entry in axis. If keepdims is true,
the reduced dimensions are retained with length 1.
Gradient computeation is disabled as this function is never meant learning reduces amax
input: Input tensor
axis: The dimensions to reduce. None or int or tuple of ints. If None (the default),
reduces all dimensions. Must be in the range [-rank(input_tensor), rank(input_tensor)).
keepdims: A boolean. If true, retains reduced dimensions with length 1. Default True
granularity: DEPRECTED. specifies if the statistic has to be calculated at tensor or channel granularity
The reduced tensor.
ValueError: Any axis which doesn't make sense or is not supported
ValueError: If unknown granularity is passed in.
output = input.abs()
if axis is None:
output = torch.max(output)
if isinstance(axis, int):
output, _ = torch.max(output, dim=axis, keepdim=keepdims)
if isinstance(axis, tuple) and len(axis) > input.dim():
raise ValueError("Cannot reduce more axes than tensor's dim.")
for i in axis:
output, _ = torch.max(output, dim=i, keepdim=True)
if not keepdims or output.numel() == 1: