Comparison Functions
Module: polygraphy.comparator
class OutputCompareResult(passed, max_absdiff, max_reldiff, mean_absdiff, mean_reldiff, median_absdiff, median_reldiff)
Bases: object
Represents the result of comparing a single output of a single iteration between two runners.
Records the required tolerances and other statistics gathered during comparison.
Parameters
passed (bool) – Whether the error was within acceptable limits.
max_absdiff (float) – The maximum absolute error between the outputs, which is the minimum absolute tolerance required to consider the outputs equivalent.
max_reldiff (float) – The maximum relative error between the outputs, which is the minimum relative tolerance required to consider the outputs equivalent.
mean_absdiff (float) – The mean absolute error between the outputs.
mean_reldiff (float) – The mean relative error between the outputs.
median_absdiff (float) – The median absolute error between the outputs.
median_reldiff (float) – The median relative error between the outputs.
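The statistics recorded in OutputCompareResult can be illustrated with a small, dependency-free sketch. It is not Polygraphy's implementation. Outputs are modeled as flat Python lists, the relative error is assumed to be taken with respect to the second output (with a hypothetical `eps` guarding against division by zero), and `OutputStats` and `summarize` are hypothetical names. The `passed` field uses the "elemwise" rule, under which an element mismatches only if it exceeds both tolerances.

```python
import statistics
from dataclasses import dataclass


@dataclass
class OutputStats:
    """Hypothetical stand-in with the same fields as OutputCompareResult."""

    passed: bool
    max_absdiff: float
    max_reldiff: float
    mean_absdiff: float
    mean_reldiff: float
    median_absdiff: float
    median_reldiff: float


def summarize(out0, out1, atol=1e-5, rtol=1e-5, eps=1e-12):
    """Compute comparison statistics for two equally sized flat outputs."""
    # Per-element absolute error |a - b|.
    absdiff = [abs(a - b) for a, b in zip(out0, out1)]
    # Per-element relative error, assumed relative to the second output;
    # eps (an assumption) avoids dividing by zero.
    reldiff = [d / max(abs(b), eps) for d, b in zip(absdiff, out1)]
    # "elemwise"-style verdict: an element mismatches only if it exceeds both tolerances.
    passed = not any(d > atol and r > rtol for d, r in zip(absdiff, reldiff))
    return OutputStats(
        passed=passed,
        max_absdiff=max(absdiff),
        max_reldiff=max(reldiff),
        mean_absdiff=statistics.mean(absdiff),
        mean_reldiff=statistics.mean(reldiff),
        median_absdiff=statistics.median(absdiff),
        median_reldiff=statistics.median(reldiff),
    )


stats = summarize([1.0, 2.0, 3.0], [1.0, 2.5, 3.0])
print(stats.max_absdiff, stats.max_reldiff, stats.passed)  # 0.5 0.2 False
```

Rerunning with `atol=stats.max_absdiff` (or `rtol=stats.max_reldiff`) makes every element pass in this sketch, which matches the description of the max statistics as the minimum tolerances needed for the outputs to be considered equivalent.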
class CompareFunc
Bases: object
Provides functions that can be used to compare two IterationResults.
static basic_compare_func(check_shapes=None, rtol=None, atol=None, fail_fast=None, find_output_func=None, check_error_stat=None)
Creates a function that compares two IterationResults and can be used as the compare_func argument of Comparator.compare_accuracy.
Parameters
check_shapes (bool) – Whether shapes must match exactly. If this is False, this function may permute or reshape outputs before comparison. Defaults to True.
rtol (Union[float, Dict[str, float]]) – The relative tolerance to use when checking accuracy. This can be provided on a per-output basis using a dictionary; in that case, use an empty string ("") as the key to specify the default tolerance for outputs not explicitly listed. Defaults to 1e-5.
atol (Union[float, Dict[str, float]]) – The absolute tolerance to use when checking accuracy. This can be provided on a per-output basis using a dictionary; in that case, use an empty string ("") as the key to specify the default tolerance for outputs not explicitly listed. Defaults to 1e-5.
fail_fast (bool) – Whether the function should exit immediately after the first failure. Defaults to False.
find_output_func (Callable(str, int, IterationResult) -> List[str]) – A callback that, given an output name and its index from one IterationResult along with another IterationResult, returns the names of the outputs in the latter to compare against. The comparison function always iterates over the output names of the first IterationResult and expects the callback to return names from the second. A return value of [] or None indicates that the output should be skipped.
check_error_stat (Union[str, Dict[str, str]]) –
The error statistic to check. Possible values are:
"elemwise": Checks each element of the output; an element counts as a mismatch only if it exceeds both of the specified tolerances.
"max": Checks the maximum absolute and relative errors against the respective tolerances. This is the strictest possible check.
"mean": Checks the mean absolute and relative errors against the respective tolerances.
"median": Checks the median absolute and relative errors against the respective tolerances.
This can be provided on a per-output basis using a dictionary; in that case, use an empty string ("") as the key to specify the default error statistic for outputs not explicitly listed. Defaults to "elemwise".
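The per-output dictionaries and the error statistics can be modeled in plain Python. This is an illustrative sketch under stated assumptions, not Polygraphy's code. `resolve` and `stat_passes` are hypothetical helpers, and the per-element errors are passed in precomputed. Falling back to the documented default when a dictionary contains neither the output's name nor the "" key is an assumption. For "max", "mean", and "median", the reduced statistic is assumed to fail only when it exceeds both tolerances, mirroring the "elemwise" rule.

```python
import statistics


def resolve(setting, output_name, fallback):
    """Pick the value for output_name from a scalar or a per-output dict.

    Follows the documented convention that the key "" holds the default for
    outputs that are not listed explicitly; `fallback` (an assumption) is used
    when neither key is present.
    """
    if not isinstance(setting, dict):
        return setting
    if output_name in setting:
        return setting[output_name]
    return setting.get("", fallback)


def stat_passes(absdiff, reldiff, atol, rtol, error_stat="elemwise"):
    """Hypothetical pass/fail rule for one output given per-element errors."""
    if error_stat == "elemwise":
        # Fail if any single element exceeds both tolerances.
        return not any(a > atol and r > rtol for a, r in zip(absdiff, reldiff))
    reducers = {"max": max, "mean": statistics.mean, "median": statistics.median}
    reduce_fn = reducers[error_stat]
    # Assumption: the reduced statistic fails only if it exceeds both tolerances.
    return not (reduce_fn(absdiff) > atol and reduce_fn(reldiff) > rtol)


per_output_rtol = {"": 1e-5, "logits": 1e-3}
print(resolve(per_output_rtol, "logits", 1e-5), resolve(per_output_rtol, "probs", 1e-5))  # 0.001 1e-05

errors_abs, errors_rel = [0.0, 0.5, 0.0], [0.0, 0.2, 0.0]
for stat in ("elemwise", "max", "mean", "median"):
    print(stat, stat_passes(errors_abs, errors_rel, atol=0.1, rtol=0.1, error_stat=stat))
# elemwise False, max False, mean True, median True
```

With identical per-element errors, "elemwise" and "max" reject the output here while "mean" and "median" accept it, so the choice of statistic can change the verdict for the same pair of outputs.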
Returns
A callable that returns a mapping of output names to OutputCompareResults, indicating whether the corresponding output matched.
Return type
Callable(IterationResult, IterationResult) -> OrderedDict[str, OutputCompareResult]
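The shape of the returned callable, and how find_output_func steers which outputs get compared, can be sketched with a heavily simplified, hypothetical analogue. In Polygraphy itself, the callable produced by CompareFunc.basic_compare_func is what gets passed as the compare_func argument of Comparator.compare_accuracy. In this sketch, `make_compare_func`, `outputs_match`, and `match_by_position` are hypothetical names. Iteration results are modeled as dicts of flat lists, and each output maps to a plain bool rather than an OutputCompareResult. Matching by identical name when no callback is given, and requiring every returned candidate to match, are assumptions. Elements are checked with the rule described for "elemwise".

```python
from collections import OrderedDict


def outputs_match(out0, out1, atol, rtol):
    # An element passes if it is within either tolerance, i.e. it is a
    # mismatch only when it exceeds both.
    return all(
        abs(a - b) <= atol or abs(a - b) <= rtol * abs(b) for a, b in zip(out0, out1)
    )


def make_compare_func(atol=1e-5, rtol=1e-5, find_output_func=None):
    """Hypothetical, heavily simplified analogue of basic_compare_func."""

    def compare(result0, result1):
        verdicts = OrderedDict()
        # Always iterate over the outputs of the first result.
        for index, name in enumerate(result0):
            if find_output_func is None:
                # Assumed default: look for the same name in the second result.
                candidates = [name] if name in result1 else []
            else:
                candidates = find_output_func(name, index, result1)
            if not candidates:
                continue  # [] or None: skip this output
            verdicts[name] = all(
                outputs_match(result0[name], result1[other], atol, rtol)
                for other in candidates
            )
        return verdicts

    return compare


def match_by_position(name, index, iteration_result):
    """Hypothetical callback: pair outputs by position when names differ."""
    return [list(iteration_result)[index]]


compare = make_compare_func(find_output_func=match_by_position)
verdicts = compare(
    {"out": [1.0, 2.0], "aux": [0.0]},
    {"output_0": [1.0, 2.0000001], "output_1": [5.0]},
)
print(dict(verdicts))  # {'out': True, 'aux': False}
```

Without a callback, the differently named outputs above would find no counterparts and be skipped, which is the situation a positional or renaming callback is meant to handle.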
static