Source code for polygraphy.tools.args.comparator.postprocess
## SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.# SPDX-License-Identifier: Apache-2.0## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.#frompolygraphyimportmodfrompolygraphy.loggerimportG_LOGGERfrompolygraphy.tools.argsimportutilasargs_utilfrompolygraphy.tools.args.baseimportBaseArgsfrompolygraphy.tools.scriptimportinline,safe
@mod.export()
class ComparatorPostprocessArgs(BaseArgs):
    """
    Comparator Postprocessing: applying postprocessing to outputs.
    """

    def add_parser_args_impl(self):
        """
        Registers the ``--postprocess`` option (alias ``--postprocess-func``) on this argument group.

        The option accepts one or more ``[<out_name>:]<func>`` entries and is stored under ``dest="postprocess"``.
        """
        self.group.add_argument(
            "--postprocess",
            "--postprocess-func",
            help="Apply post-processing on the specified outputs prior to comparison. "
            "Format: --postprocess [<out_name>:]<func>. If no output name is provided, the function is applied to all outputs. "
            "For example: `--postprocess out0:top-5 out1:top-3` or `--postprocess top-5`. "
            "Available post-processing functions are: {{top-<K>[,axis=<axis>]: Takes the indices of the K highest values along "
            "the specified axis (defaulting to the last axis), where K is an integer. "
            "For example: `--postprocess top-5` or `--postprocess top-5,axis=1`}}",
            nargs="+",
            default=None,
            dest="postprocess",
        )

    def parse_impl(self, args):
        """
        Parses command-line arguments and populates the following attributes:

        Attributes:
            postprocess (Dict[str, Dict[str, Any]]):
                    Maps postprocessing function names to dictionaries of output names mapped to parameters.
                    Each parameter is either ``K`` (an ``int``) or, when an axis was given, a ``(K, axis)`` tuple.
                    This is an empty dictionary when no postprocessing was requested.
                    For example, this could be something like:
                    ::

                        {"top_k": {"output1": 5, "output2": (6, 1)}}

        Raises:
            PolygraphyException (via ``G_LOGGER.critical``): If a function spec is not of the form
                    ``top-<K>[,axis=<axis>]`` with integer ``K >= 1`` and integer ``<axis>``.
        """
        # Maps output name -> raw function spec (e.g. "top-5,axis=1"), or None if the option was not provided.
        # NOTE(review): the key used for "apply to all outputs" comes from parse_arglist_to_dict's conventions.
        raw_postprocess = args_util.parse_arglist_to_dict(args_util.get(args, "postprocess"))

        postprocess = {}
        # Truthiness (not just `is not None`) so that an empty mapping does not produce a no-op `top_k({})` entry.
        if raw_postprocess:
            # inline(safe(...)) presumably makes the name render verbatim (unquoted) when it is formatted into
            # `PostprocessFunc.{func}` in add_to_script_impl.
            topk_key = inline(safe("top_k"))
            topk_args = {}

            for out_name, func_spec in raw_postprocess.items():
                # The parsed value is not guaranteed to be a string (NOTE(review): parse_arglist_to_dict may cast
                # values such as `5` to int), so check the type before calling startswith().
                if not isinstance(func_spec, str) or not func_spec.startswith("top-"):
                    G_LOGGER.critical(
                        f"Invalid post-processing function: {func_spec}. Note: Valid choices are: ['top-<K>']."
                    )

                # Slice off the exact prefixes: str.lstrip() strips a *set* of characters rather than a prefix,
                # which would, for example, silently turn `top--5` into K=5.
                k_str, _, axis_str = func_spec[len("top-") :].partition(",")

                try:
                    k = int(k_str)
                    axis = None
                    if axis_str:
                        # Accept both `axis=<N>` and a bare `<N>` after the comma.
                        if axis_str.startswith("axis="):
                            axis_str = axis_str[len("axis=") :]
                        axis = int(axis_str)
                except ValueError:
                    G_LOGGER.critical(
                        f"Invalid post-processing function: {func_spec}. "
                        "Note: Expected format: 'top-<K>[,axis=<axis>]', where K and <axis> are integers."
                    )

                if k < 1:
                    G_LOGGER.critical(
                        f"Invalid post-processing function: {func_spec}. Note: K must be a positive integer."
                    )

                # A bare K means "use the default (last) axis"; a tuple pins the axis explicitly.
                topk_args[out_name] = (k, axis) if axis is not None else k

            postprocess[topk_key] = topk_args

        self.postprocess = postprocess

    def add_to_script_impl(self, script, results_name):
        """
        Args:
            results_name (str): The name of the variable containing results from ``Comparator.run()``.

        Returns:
            str: The name of the variable containing the post-processed results.
                    This could be the same as the original name.
        """
        if self.postprocess:
            script.add_import(imports=["PostprocessFunc"], frm="polygraphy.comparator")

            # Post-processing is applied in place: the results variable is reassigned, so its name is unchanged.
            # NOTE(review): `Comparator` itself is not imported here; it is presumably imported by the
            # comparator argument group that produced `results_name` — confirm.
            for func, arg in self.postprocess.items():
                script.append_suffix(
                    safe(
                        "\n# Postprocessing\n"
                        "{results} = Comparator.postprocess({results}, PostprocessFunc.{func}({arg}))",
                        arg=arg,
                        func=func,
                        results=results_name,
                    )
                )

        return results_name