Source code for

# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from polygraphy import mod
from polygraphy.logger import G_LOGGER
from import util as args_util
from import BaseArgs
from import inline, safe

@mod.export()
class ComparatorPostprocessArgs(BaseArgs):
    """
    Comparator Postprocessing: applying postprocessing to outputs.
    """

    @staticmethod
    def _strip_prefix(text, prefix):
        # Removes `prefix` from the start of `text` if present, otherwise returns `text` unchanged.
        # Unlike `str.lstrip`, which strips any leading characters drawn from a *set*, this only
        # ever removes the exact prefix (equivalent to `str.removeprefix` from Python 3.9+).
        if text.startswith(prefix):
            return text[len(prefix) :]
        return text

    def add_parser_args_impl(self):
        # Registers the `--postprocess` / `--postprocess-func` option. It accepts one or more
        # `[<out_name>:]<func>` entries which are stored, unparsed, under `args.postprocess`.
        # NOTE(review): the call target was lost when this source was extracted; only the argument
        # list and closing parenthesis survived. `self.group.add_argument` is assumed here -
        # confirm against `BaseArgs`.
        self.group.add_argument(
            "--postprocess",
            "--postprocess-func",
            help="Apply post-processing on the specified outputs prior to comparison. "
            "Format: --postprocess [<out_name>:]<func>. If no output name is provided, the function is applied to all outputs. "
            "For example: `--postprocess out0:top-5 out1:top-3` or `--postprocess top-5`. "
            "Available post-processing functions are: {{top-<K>[,axis=<axis>]: Takes the indices of the K highest values along "
            "the specified axis (defaulting to the last axis), where K is an integer, e.g. top-5 or top-5,axis=1}}",
            nargs="+",
            default=None,
            dest="postprocess",
        )

    def parse_impl(self, args):
        """
        Parses command-line arguments and populates the following attributes:

        Attributes:
            postprocess (Dict[str, Dict[str, Any]]):
                    Maps postprocessing function names to dictionaries of output names mapped to parameters.
                    For ``top_k``, the parameter is either ``K`` or, when an axis was supplied, a ``(K, axis)`` tuple.
                    The mapping is empty when no postprocessing was requested.
                    For example, this could be something like:
                    ::

                        {"top_k": {"output1": 5, "output2": (6, 1)}}
        """
        top_k_prefix = "top-"
        axis_prefix = "axis="

        # Presumably maps output names to raw `<func>` strings, or is None when the option
        # was not provided - verify against `args_util.parse_dict_with_default`.
        requested = args_util.parse_dict_with_default(args_util.get(args, "postprocess"))

        postprocess = {}
        topk_key = inline(safe("top_k"))
        if requested is not None:
            postprocess[topk_key] = {}
            for key, val in requested.items():
                if not val.startswith(top_k_prefix):
                    # No explicit `continue`/`return` follows: as in the original code, the critical
                    # log call is relied upon to abort processing of the invalid value.
                    G_LOGGER.critical(f"Invalid post-processing function: {val}. Note: Valid choices are: ['top-<K>'].")

                # `val` looks like `top-<K>` or `top-<K>,axis=<axis>`.
                k_text, _, axis_text = val.partition(",")
                k = int(self._strip_prefix(k_text, top_k_prefix))
                if axis_text:
                    # The `axis=` prefix is removed only when present, so a bare integer is still accepted.
                    axis = int(self._strip_prefix(axis_text, axis_prefix))
                    postprocess[topk_key][key] = (k, axis)
                else:
                    postprocess[topk_key][key] = k
        self.postprocess = postprocess

    def add_to_script_impl(self, script, results_name):
        """
        Args:
            script: The script being generated. An import of ``PostprocessFunc`` and one
                    postprocessing statement per configured function are added to it.
            results_name (str):
                    The name of the variable containing the comparator results to post-process
                    (presumably the output of ``Comparator.run``; the reference was lost from the
                    original docstring - confirm).

        Returns:
            str: The name of the variable containing the post-processed results.
                    This could be the same as the original name.
        """
        if self.postprocess:
            script.add_import(imports=["PostprocessFunc"], frm="polygraphy.comparator")
            for func, arg in self.postprocess.items():
                # The emitted statement reassigns the results variable in place, which is why
                # the returned name is identical to `results_name`.
                script.append_suffix(
                    safe(
                        "\n# Postprocessing\n"
                        "{results} = Comparator.postprocess({results}, PostprocessFunc.{func}({arg}))",
                        arg=arg,
                        func=func,
                        results=results_name,
                    )
                )
        return results_name