Source code for polygraphy.comparator.compare

#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import functools
from collections import OrderedDict

from polygraphy import mod, util
from polygraphy.comparator import util as comp_util
from polygraphy.datatype import DataType
from polygraphy.logger import G_LOGGER, LogMode

np = mod.lazy_import("numpy")


@mod.export()
class OutputCompareResult:
    """
    Represents the result of comparing a single output of a single iteration
    between two runners.
    """

    def __init__(
        self,
        passed,
        max_absdiff,
        max_reldiff,
        mean_absdiff,
        mean_reldiff,
        median_absdiff,
        median_reldiff,
        quantile_absdiff,
        quantile_reldiff,
    ):
        """
        Records the required tolerances and other statistics gathered during comparison.

        Args:
            passed (bool):
                    Whether the error was within acceptable limits.
            max_absdiff (float):
                    The minimum required absolute tolerance to consider the outputs equivalent.
            max_reldiff (float):
                    The minimum required relative tolerance to consider the outputs equivalent.
            mean_absdiff (float):
                    The mean absolute error between the outputs.
            mean_reldiff (float):
                    The mean relative error between the outputs.
            median_absdiff (float):
                    The median absolute error between the outputs.
            median_reldiff (float):
                    The median relative error between the outputs.
            quantile_absdiff (float):
                    The q-th quantile absolute error between the outputs.
            quantile_reldiff (float):
                    The q-th quantile relative error between the outputs.
        """
        self.passed = passed
        self.max_absdiff = max_absdiff
        self.max_reldiff = max_reldiff
        self.mean_absdiff = mean_absdiff
        self.mean_reldiff = mean_reldiff
        self.median_absdiff = median_absdiff
        self.median_reldiff = median_reldiff
        self.quantile_absdiff = quantile_absdiff
        self.quantile_reldiff = quantile_reldiff

    def __bool__(self):
        """
        Whether the output matched.

        Returns:
            bool
        """
        return self.passed

    def __str__(self):
        return f"(atol={self.max_absdiff}, rtol={self.max_reldiff})"
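
# Illustrative sketch (assumption, not part of the original module): each compared output yields an
# OutputCompareResult, which is truthy when the output matched and exposes the gathered statistics.
# The numeric values below are hypothetical.
#
#   result = OutputCompareResult(True, 1e-6, 1e-4, 5e-7, 5e-5, 4e-7, 4e-5, None, None)
#   if result:                      # equivalent to result.passed
#       print(result.max_absdiff)   # minimum absolute tolerance that would make the outputs match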


def default_find_output_func(output_name, index, iter_result, base_iter_result):
    found_name = util.find_str_in_iterable(output_name, iter_result.keys(), index)
    if found_name is None:
        return None
    elif found_name != output_name:
        exact_match = util.find_str_in_iterable(found_name, base_iter_result.keys())
        if exact_match == found_name:
            G_LOGGER.verbose(
                f"Will not compare {found_name} with {output_name}, since the former already has an exact match: {exact_match}"
            )
            return None  # If the found output is being compared against another output already, skip this non-exact match
        G_LOGGER.warning(
            f"Output names did not match exactly. Assuming {iter_result.runner_name} output: {found_name} corresponds to output: {output_name}"
        )
    return [found_name]


def run_comparison(func, fail_fast, iter_result0, iter_result1, find_output_func):
    """
    Iterates over all the generated outputs and runs `func` to compare them.
    """
    output_status = OrderedDict()  # OrderedDict[str, bool] Maps output names to whether they matched.

    for index, (out0_name, output0) in enumerate(iter_result0.items()):
        out1_names = util.default(find_output_func(out0_name, index, iter_result1), [])

        if len(out1_names) > 1:
            G_LOGGER.info(
                f"Will attempt to compare output: '{out0_name}' [{iter_result0.runner_name}] with multiple outputs: '{list(out1_names)}' [{iter_result1.runner_name}]"
            )

        for out1_name in out1_names:
            if out1_name is None or out1_name not in iter_result1:
                G_LOGGER.warning(
                    f"For output: '{out0_name}' [{iter_result0.runner_name}], skipping corresponding output: '{out1_name}' [{iter_result1.runner_name}], since the output was not found"
                )
                continue

            output1 = iter_result1[out1_name]

            G_LOGGER.start(
                f"Comparing Output: '{out0_name}' (dtype={util.array.dtype(output0)}, shape={util.array.shape(output0)}) with '{out1_name}' (dtype={util.array.dtype(output1)}, shape={util.array.shape(output1)})"
            )
            with G_LOGGER.indent():
                output_status[out0_name] = func(out0_name, output0, out1_name, output1)
                if fail_fast and not output_status[out0_name]:
                    return output_status

    mismatched_output_names = [name for name, matched in output_status.items() if not matched]
    if mismatched_output_names:
        G_LOGGER.error(f"FAILED | Mismatched outputs: {mismatched_output_names}")
    else:
        G_LOGGER.finish(f"PASSED | All outputs matched | Outputs: {list(output_status.keys())}")

    # This is useful for catching cases where Polygraphy does something wrong with the runner output buffers
    if not output_status and (bool(iter_result0.keys()) or bool(iter_result1.keys())):
        r0_name = iter_result0.runner_name
        r0_outs = list(iter_result0.keys())
        r1_name = iter_result1.runner_name
        r1_outs = list(iter_result1.keys())
        G_LOGGER.critical(
            f"All outputs were skipped, no common outputs found! Note:\n{r0_name} outputs: {r0_outs}\n{r1_name} outputs: {r1_outs}"
        )

    return output_status


# Provides functions to compare two IterationResults
@mod.export()
class CompareFunc:
    """
    Provides functions that can be used to compare two ``IterationResult`` s.
    """

    @staticmethod
    def simple(
        check_shapes=None,
        rtol=None,
        atol=None,
        fail_fast=None,
        find_output_func=None,
        check_error_stat=None,
        infinities_compare_equal=None,
        save_heatmaps=None,
        show_heatmaps=None,
        save_error_metrics_plot=None,
        show_error_metrics_plot=None,
        error_quantile=None,
    ):
        """
        Creates a function that compares two IterationResults, and can be used as the `compare_func` argument
        in ``Comparator.compare_accuracy``.

        Args:
            check_shapes (bool):
                    Whether shapes must match exactly. If this is False, this function may
                    permute or reshape outputs before comparison.
                    Defaults to True.
            rtol (Union[float, Dict[str, float]]):
                    The relative tolerance to use when checking accuracy.
                    This is expressed as a percentage of the second set of output values.
                    For example, a value of 0.01 would check that the first set of outputs is
                    within 1% of the second.

                    This can be provided on a per-output basis using a dictionary. In that case,
                    use an empty string ("") as the key to specify default tolerance for outputs not explicitly listed.
                    Defaults to 1e-5.
            atol (Union[float, Dict[str, float]]):
                    The absolute tolerance to use when checking accuracy.

                    This can be provided on a per-output basis using a dictionary. In that case,
                    use an empty string ("") as the key to specify default tolerance for outputs not explicitly listed.
                    Defaults to 1e-5.
            fail_fast (bool):
                    Whether the function should exit immediately after the first failure.
                    Defaults to False.
            find_output_func (Callable(str, int, IterationResult) -> List[str]):
                    A callback that returns a list of output names to compare against from the provided
                    IterationResult, given an output name and index from another IterationResult.

                    The comparison function will always iterate over the output names of the
                    first IterationResult, expecting names from the second. A return value of
                    `[]` or `None` indicates that the output should be skipped.
            check_error_stat (Union[str, Dict[str, str]]):
                    The error statistic to check. Possible values are:

                    - "elemwise": Checks each element in the output to determine if it exceeds both tolerances specified.
                      The minimum required tolerances displayed in this mode are only applicable when just one type of
                      tolerance is set. Because of the nature of the check, when both absolute/relative tolerance are
                      specified, the required minimum tolerances may be lower.
                    - "max": Checks the maximum absolute/relative errors against the respective tolerances.
                      This is the strictest possible check.
                    - "mean": Checks the mean absolute/relative errors against the respective tolerances.
                    - "median": Checks the median absolute/relative errors against the respective tolerances.
                    - "quantile": Checks the quantile absolute/relative errors against the respective tolerances.

                    This can be provided on a per-output basis using a dictionary. In that case,
                    use an empty string ("") as the key to specify default error stat for outputs not explicitly listed.
                    Defaults to "elemwise".
            infinities_compare_equal (bool):
                    If True, then matching +-inf values in the output have an absdiff of 0.
                    If False, then matching +-inf values in the output have an absdiff of NaN.
                    Defaults to False.
            save_heatmaps (str):
                    [EXPERIMENTAL] Path to a directory in which to save figures of heatmaps of the absolute and relative error.
                    Defaults to None.
            show_heatmaps (bool):
                    [EXPERIMENTAL] Whether to display heatmaps of the absolute and relative error.
                    Defaults to False.
            save_error_metrics_plot (str):
                    [EXPERIMENTAL] Path to a directory in which to save the error metrics plots.
                    Defaults to None.
            show_error_metrics_plot (bool):
                    [EXPERIMENTAL] Whether to display the error metrics plot.
            error_quantile (Union[float, Dict[str, float]]):
                    Quantile error to compute when checking accuracy.
                    This is expressed as a float in range [0, 1]. For example, error_quantile=0.5 is the median.
                    Defaults to 0.99.

        Returns:
            Callable(IterationResult, IterationResult) -> OrderedDict[str, OutputCompareResult]:
                A callable that returns a mapping of output names to `OutputCompareResult` s, indicating
                whether the corresponding output matched.
        """
        check_shapes = util.default(check_shapes, True)
        default_rtol = 1e-5
        default_atol = 1e-5
        default_quantile = 0.99
        rtol = util.default(rtol, default_rtol)
        atol = util.default(atol, default_atol)
        error_quantile = util.default(error_quantile, default_quantile)
        fail_fast = util.default(fail_fast, False)
        default_error_stat = "elemwise"
        check_error_stat = util.default(check_error_stat, default_error_stat)
        infinities_compare_equal = util.default(infinities_compare_equal, False)
        show_heatmaps = util.default(show_heatmaps, False)
        show_error_metrics_plot = util.default(show_error_metrics_plot, False)

        def check_outputs_match(
            out0,
            out0_name,
            out1,
            out1_name,
            per_out_rtol,
            per_out_atol,
            per_out_err_stat,
            runner0_name,
            runner1_name,
            per_out_quantile,
        ):
            """
            Checks whether two outputs matched.

            Args:
                out0 (Union[np.array, torch.Tensor]): The first output.
                out0_name (str): The name of the first output.
                out1 (Union[np.array, torch.Tensor]): The second output.
                out1_name (str): The name of the second output.
                per_out_rtol (float): The relative tolerance to use for comparison.
                per_out_atol (float): The absolute tolerance to use for comparison.
                per_out_err_stat (str): The error statistic to check. See the docstring of ``simple`` for details.
                runner0_name (str): The name of the runner that generated the first output.
                runner1_name (str): The name of the runner that generated the second output.
                per_out_quantile (float): The quantile value to use for quantile comparison.

            Returns:
                OutputCompareResult: Details on whether the outputs matched.
            """
            VALID_CHECK_ERROR_STATS = ["max", "mean", "median", "elemwise", "quantile"]
            if per_out_err_stat not in VALID_CHECK_ERROR_STATS:
                G_LOGGER.critical(
                    f"Invalid choice for check_error_stat: {per_out_err_stat}.\nNote: Valid choices are: {VALID_CHECK_ERROR_STATS}"
                )

            G_LOGGER.super_verbose(
                f"{runner0_name:35} | Output: {out0_name} (dtype={util.array.dtype(out0)}, shape={util.array.shape(out0)}):\n{util.indent_block(out0)}"
            )
            G_LOGGER.super_verbose(
                f"{runner1_name:35} | Output: {out1_name} (dtype={util.array.dtype(out1)}, shape={util.array.shape(out1)}):\n{util.indent_block(out1)}"
            )

            # Check difference vs. tolerances
            if util.array.dtype(out0) == DataType.BOOL and util.array.dtype(out1) == DataType.BOOL:
                absdiff = util.array.logical_xor(out0, out1)
            else:
                absdiff = util.array.abs(util.array.subtract(comp_util.cast_up(out0), comp_util.cast_up(out1)))
                if infinities_compare_equal:
                    out0_infinite = util.array.isinf(out0)
                    cond = util.array.logical_and(out0_infinite, out0 == out1)
                    absdiff = util.array.where(cond, 0, absdiff)

            # Add a small epsilon (2e-16) to zero values in the array to prevent NaN in relative error.
            out1_with_eps = copy.copy(comp_util.cast_up(out1))
            if util.array.dtype(out1_with_eps).is_floating:
                if util.array.any(out1_with_eps == 0):
                    G_LOGGER.warning(
                        f"{runner1_name:35} | Output: {out1_name}: Some values are 0. "
                        f"Will add a small epsilon quantity to these when computing relative difference. "
                        f"Note that this may cause some relative differences to be extremely high. ",
                        mode=LogMode.ONCE,
                    )
                    EPSILON = 2.220446049250313e-16
                    out1_with_eps[out1_with_eps == 0] += EPSILON

            # TODO: Only evaluate this if actually needed like we do for quantile_*.
            reldiff = util.array.divide(absdiff, util.array.abs(out1_with_eps))
            min_reldiff = comp_util.compute_min(reldiff)
            max_reldiff = comp_util.compute_max(reldiff)
            mean_reldiff = comp_util.compute_mean(reldiff)
            median_reldiff = comp_util.compute_median(reldiff)
            quantile_reldiff = None

            min_absdiff = comp_util.compute_min(absdiff)
            max_absdiff = comp_util.compute_max(absdiff)
            mean_absdiff = comp_util.compute_mean(absdiff)
            median_absdiff = comp_util.compute_median(absdiff)
            quantile_absdiff = None

            def stat_failed(diff, tol):
                return util.array.isnan(diff) or diff > tol

            if per_out_err_stat == "mean":
                failed = stat_failed(mean_absdiff, per_out_atol) and stat_failed(mean_reldiff, per_out_rtol)
            elif per_out_err_stat == "median":
                failed = stat_failed(median_absdiff, per_out_atol) and stat_failed(median_reldiff, per_out_rtol)
            elif per_out_err_stat == "max":
                failed = stat_failed(max_absdiff, per_out_atol) and stat_failed(max_reldiff, per_out_rtol)
            elif per_out_err_stat == "quantile":
                quantile_reldiff = comp_util.compute_quantile(reldiff, per_out_quantile)
                quantile_absdiff = comp_util.compute_quantile(absdiff, per_out_quantile)
                failed = stat_failed(quantile_absdiff, per_out_atol) and stat_failed(quantile_reldiff, per_out_rtol)
            else:
                assert (
                    per_out_err_stat == "elemwise"
                ), "This branch should be unreachable unless per_out_err_stat is 'elemwise'"

                mismatches = (util.array.greater(absdiff, per_out_atol) | util.array.isnan(absdiff)) & (
                    util.array.greater(reldiff, per_out_rtol) | util.array.isnan(reldiff)
                )

                failed = util.array.any(mismatches)
                try:
                    with G_LOGGER.indent():
                        G_LOGGER.super_verbose(lambda: f"Mismatched indices:\n{util.array.argwhere(mismatches)}")
                        G_LOGGER.extra_verbose(lambda: f"{runner0_name:35} | Mismatched values:\n{out0[mismatches]}")
                        G_LOGGER.extra_verbose(lambda: f"{runner1_name:35} | Mismatched values:\n{out1[mismatches]}")
                except Exception as err:
                    G_LOGGER.warning(f"Failed to log mismatches.\nNote: Error was: {err}")

            # Log information about the outputs
            hist_bin_range = (
                min(comp_util.compute_min(out0), comp_util.compute_min(out1)),
                max(comp_util.compute_max(out0), comp_util.compute_max(out1)),
            )
            comp_util.log_output_stats(out0, failed, f"{runner0_name}: {out0_name}", hist_range=hist_bin_range)
            comp_util.log_output_stats(out1, failed, f"{runner1_name}: {out1_name}", hist_range=hist_bin_range)

            G_LOGGER.info(f"Error Metrics: {out0_name}")
            with G_LOGGER.indent():

                def req_tol(mean_diff, median_diff, max_diff, quantile_diff):
                    return {
                        "mean": mean_diff,
                        "median": median_diff,
                        "max": max_diff,
                        "elemwise": max_diff,
                        "quantile": quantile_diff,
                    }[per_out_err_stat]

                msg = f"Minimum Required Tolerance: {per_out_err_stat} error | [abs={req_tol(mean_absdiff, median_absdiff, max_absdiff, quantile_absdiff):.5g}] OR [rel={req_tol(mean_reldiff, median_reldiff, max_reldiff, quantile_reldiff):.5g}]"
                if per_out_err_stat == "elemwise":
                    msg += " (requirements may be lower if both abs/rel tolerances are set)"
                G_LOGGER.info(msg)

                if save_error_metrics_plot or show_error_metrics_plot:
                    with G_LOGGER.indent():
                        comp_util.scatter_plot_error_magnitude(
                            absdiff,
                            reldiff,
                            comp_util.cast_up(out1),
                            min_reldiff,
                            max_reldiff,
                            runner0_name,
                            runner1_name,
                            out0_name,
                            out1_name,
                            save_dir=save_error_metrics_plot,
                            show=show_error_metrics_plot,
                        )

                def build_heatmaps(diff, min_diff, max_diff, prefix, use_lognorm=None):
                    if save_heatmaps or show_heatmaps:
                        with G_LOGGER.indent():
                            comp_util.build_heatmaps(
                                diff,
                                min_diff,
                                max_diff,
                                prefix=f"{prefix} Error | {out0_name}",
                                save_dir=save_heatmaps,
                                show=show_heatmaps,
                                use_lognorm=use_lognorm,
                            )

                comp_util.log_output_stats(absdiff, failed, "Absolute Difference")
                build_heatmaps(absdiff, min_absdiff, max_absdiff, "Absolute")

                comp_util.log_output_stats(reldiff, failed, "Relative Difference")
                build_heatmaps(reldiff, min_reldiff, max_reldiff, "Relative", use_lognorm=True)

            G_LOGGER.extra_verbose(
                lambda: f"Finished comparing: '{out0_name}' (dtype={util.array.dtype(out0)}, shape={util.array.shape(out0)}) [{runner0_name}] and '{out1_name}' (dtype={util.array.dtype(out1)}, shape={util.array.shape(out1)}) [{runner1_name}]"
            )
            return OutputCompareResult(
                not failed,
                max_absdiff,
                max_reldiff,
                mean_absdiff,
                mean_reldiff,
                median_absdiff,
                median_reldiff,
                quantile_absdiff,
                quantile_reldiff,
            )

        def compare_output(iter_result0, iter_result1):
            """
            Compare the outputs of two runners from a single iteration.

            This function will always iterate over the output names of the first IterationResult,
            and attempt to find corresponding output names in the second.
            If no corresponding output name is found, the output is skipped.
            If all output names are skipped, then this function raises an error.

            Args:
                iter_result0 (IterationResult): The result of the first runner.
                iter_result1 (IterationResult): The result of the second runner.

            Returns:
                OrderedDict[str, OutputCompareResult]:
                        The name of the outputs compared, derived from the first IterationResult,
                        and whether they matched. If an output name is not found, it is omitted from this dictionary.

            Raises:
                PolygraphyException: If all output names are skipped, and thus no outputs are compared.
            """

            def check_dict(dct, dict_name):
                if isinstance(dct, dict):
                    util.check_sequence_contains(
                        dct.keys(),
                        set(iter_result0.keys()) | set(iter_result1.keys()) | {""},
                        name=dict_name,
                        log_func=G_LOGGER.warning,
                        check_missing=False,
                    )

            check_dict(rtol, "the rtol dictionary")
            check_dict(atol, "the atol dictionary")
            check_dict(check_error_stat, "the check_error_stat dictionary")
            check_dict(error_quantile, "the quantile dictionary")

            if not check_shapes:
                G_LOGGER.info("Strict shape checking disabled. Will attempt to match output shapes before comparisons")

            def match(out0_name, output0, out1_name, output1):
                per_out_atol = util.value_or_from_dict(atol, out0_name, default_atol)
                per_out_rtol = util.value_or_from_dict(rtol, out0_name, default_rtol)
                per_out_err_stat = util.value_or_from_dict(check_error_stat, out0_name, default_error_stat)
                per_out_quantile = util.value_or_from_dict(error_quantile, out0_name, default_quantile)

                G_LOGGER.info(
                    f"Tolerance: [abs={per_out_atol:.5g}, rel={per_out_rtol:.5g}] | Checking {per_out_err_stat} error"
                )
                G_LOGGER.extra_verbose(f"Note: Comparing {iter_result0.runner_name} vs. {iter_result1.runner_name}")

                if check_shapes and util.array.shape(output0) != util.array.shape(output1):
                    G_LOGGER.error(
                        f"FAILED | Output: `{out0_name}` | Will not compare outputs of different shapes.\n"
                        f"Note: Output shapes are {util.array.shape(output0)} and {util.array.shape(output1)}."
                    )
                    G_LOGGER.error(
                        "Note: Use --no-shape-check or set check_shapes=False to attempt to compare values anyway.",
                        mode=LogMode.ONCE,
                    )
                    return False

                output1 = util.try_match_shape(output1, util.array.shape(output0))
                output0 = util.array.view(
                    output0,
                    DataType.from_dtype(util.array.dtype(output0)),
                    util.array.shape(output1),
                )

                outputs_matched = check_outputs_match(
                    output0,
                    out0_name,
                    output1,
                    out1_name,
                    per_out_rtol=per_out_rtol,
                    per_out_atol=per_out_atol,
                    per_out_err_stat=per_out_err_stat,
                    runner0_name=iter_result0.runner_name,
                    runner1_name=iter_result1.runner_name,
                    per_out_quantile=per_out_quantile,
                )

                # Finally show summary.
                if not outputs_matched:
                    G_LOGGER.error(
                        f"FAILED | Output: '{out0_name}' | Difference exceeds tolerance (rel={per_out_rtol}, abs={per_out_atol})"
                    )
                else:
                    G_LOGGER.finish(
                        f"PASSED | Output: '{out0_name}' | Difference is within tolerance (rel={per_out_rtol}, abs={per_out_atol})"
                    )

                return outputs_matched

            nonlocal find_output_func
            find_output_func = util.default(
                find_output_func,
                functools.partial(default_find_output_func, base_iter_result=iter_result0),
            )
            return run_comparison(match, fail_fast, iter_result0, iter_result1, find_output_func)

        return compare_output
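
    # Illustrative usage sketch (assumed; `run_results`, output names, and tolerance values are hypothetical):
    # `simple()` builds the `compare_func` passed to ``Comparator.compare_accuracy``, with per-output
    # settings supplied as dictionaries keyed by output name ("" sets the default).
    #
    #   from polygraphy.comparator import Comparator, CompareFunc
    #
    #   compare_func = CompareFunc.simple(
    #       rtol={"output0": 1e-3, "": 1e-5},   # looser relative tolerance for "output0" only
    #       atol=1e-4,
    #       check_error_stat="median",          # compare median error instead of elementwise
    #   )
    #   Comparator.compare_accuracy(run_results, compare_func=compare_func)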

    @staticmethod
    def indices(index_tolerance=None, fail_fast=None):
        """
        Creates a function that compares two IterationResults containing indices, and can be used as the
        `compare_func` argument in ``Comparator.compare_accuracy``.
        This can be useful to compare, for example, the outputs of a Top-K operation.

        Outputs with more than one dimension are treated like multiple batches of values.
        For example, an output of shape (3, 4, 5, 10) would be treated like 60 batches (3 x 4 x 5)
        of 10 values each.

        Args:
            index_tolerance (Union[int, Dict[str, int]]):
                    The tolerance to use when comparing indices. This is an integer indicating the maximum distance
                    between values before it is considered a mismatch. For example, consider two outputs:
                    ::

                        output0 = [0, 1, 2]
                        output1 = [1, 0, 2]

                    With an index tolerance of 0, this would be considered a mismatch, since the positions of `0` and `1`
                    are flipped between the two outputs. However, with an index tolerance of 1, it would pass since
                    the mismatched values are only 1 spot apart. If instead the outputs were:
                    ::

                        output0 = [0, 1, 2]
                        output1 = [1, 2, 0]

                    Then we would require an index tolerance of 2, since the `0` value in the two outputs is 2 spots apart.

                    When this value is set, the final 'index_tolerance' number of values are ignored for each batch.
                    For example, with an index tolerance of 1, mismatches in the final element are not considered.
                    If used with a Top-K output, you can compensate for this by instead using a Top-(K + index_tolerance).

                    This can be provided on a per-output basis using a dictionary. In that case,
                    use an empty string ("") as the key to specify default tolerance for outputs not explicitly listed.
            fail_fast (bool):
                    Whether the function should exit immediately after the first failure.
                    Defaults to False.

        Returns:
            Callable(IterationResult, IterationResult) -> OrderedDict[str, bool]:
                A callable that returns a mapping of output names to `bool` s, indicating
                whether the corresponding output matched.
        """
        index_tolerance = util.default(index_tolerance, 0)
        fail_fast = util.default(fail_fast, False)

        def compare_output(iter_result0, iter_result1):
            """
            Compare the outputs of two runners from a single iteration.

            This function will always iterate over the output names of the first IterationResult,
            and attempt to find corresponding output names in the second.
            If no corresponding output name is found, the output is skipped.
            If all output names are skipped, then this function raises an error.

            Args:
                iter_result0 (IterationResult): The result of the first runner.
                iter_result1 (IterationResult): The result of the second runner.

            Returns:
                OrderedDict[str, bool]:
                        The name of the outputs compared, derived from the first IterationResult,
                        and whether they matched. If an output name is not found, it is omitted from this dictionary.

            Raises:
                PolygraphyException: If all output names are skipped, and thus no outputs are compared.
            """

            def match(out0_name, output0, out1_name, output1):
                per_out_index_tol = util.value_or_from_dict(index_tolerance, out0_name, 0)

                if util.array.shape(output0) != util.array.shape(output1):
                    G_LOGGER.error("Cannot compare outputs of different shapes.")
                    return False

                passed = True
                for batch in np.ndindex(util.array.shape(output0)[:-1]):
                    out0_vals = output0[batch]
                    if per_out_index_tol > 0:
                        out0_vals = out0_vals[:-per_out_index_tol]
                    out1_vals = output1[batch]

                    for index0, val0 in enumerate(out0_vals):
                        if val0 == out1_vals[index0]:
                            continue

                        index1 = util.array.ravel(util.array.argwhere(out1_vals == val0))
                        if util.array.size(index1) < 1:
                            G_LOGGER.error(f"FAILED | Value: {val0} not found in output")
                            passed = False
                            if fail_fast:
                                return False
                            continue

                        index1 = index1[0]

                        if abs(index1 - index0) > per_out_index_tol:
                            G_LOGGER.error(f"FAILED | Difference exceeds index tolerance ({per_out_index_tol})")
                            passed = False
                            if fail_fast:
                                return False
                            continue

                # Log information about the outputs
                hist_bin_range = (
                    min(comp_util.compute_min(output0), comp_util.compute_min(output1)),
                    max(comp_util.compute_max(output0), comp_util.compute_max(output1)),
                )
                comp_util.log_output_stats(
                    output0, not passed, f"{iter_result0.runner_name}: {out0_name}", hist_range=hist_bin_range
                )
                comp_util.log_output_stats(
                    output1, not passed, f"{iter_result1.runner_name}: {out1_name}", hist_range=hist_bin_range
                )

                if passed:
                    G_LOGGER.finish(f"PASSED | Difference is within index tolerance ({per_out_index_tol})")
                return passed

            return run_comparison(
                match,
                fail_fast,
                iter_result0,
                iter_result1,
                functools.partial(default_find_output_func, base_iter_result=iter_result0),
            )

        return compare_output