From a4282636e3de672eede98893cf3187b5b535ba68 Mon Sep 17 00:00:00 2001 From: gitee Date: Tue, 31 Dec 2024 17:28:20 +0800 Subject: [PATCH] add accumulative error compare --- .../msprobe/core/common/const.py | 2 + .../api_accuracy_checker/compare/algorithm.py | 16 +-- .../compare/api_precision_compare.py | 53 ++++++++- .../compare/api_precision_standard.yaml | 4 +- .../api_accuracy_checker/compare/compare.py | 17 ++- .../compare/compare_utils.py | 1 + .../precision_standard/absolute_threshold.py | 9 +- .../accumulative_error_compare.py | 108 ++++++++++++++++++ .../precision_standard/base_standard.py | 12 +- .../precision_standard/standard_config.py | 30 ++++- .../precision_standard/standard_register.py | 5 +- .../test_standard_config.py | 8 +- 12 files changed, 227 insertions(+), 38 deletions(-) create mode 100644 debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index baf609377d..9bc40c978a 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -283,12 +283,14 @@ class CompareConst: BINARY_CONSISTENCY_ALGORITHM_NAME = "二进制一致法" ABSOLUTE_THRESHOLD_ALGORITHM_NAME = "绝对阈值法" THOUSANDTH_STANDARD_ALGORITHM_NAME = "双千指标法" + ACCUMULATIVE_ERROR_COMPARE_ALGORITHM_NAME = "累积误差比对法" ABSOLUTE_THRESHOLD = 'absolute_threshold' BINARY_CONSISTENCY = 'binary_consistency' ULP_COMPARE = 'ulp_compare' THOUSANDTH_STANDARD = 'thousandth_threshold' BENCHMARK = 'benchmark' + ACCUMULATIVE_ERROR_COMPARE = 'accumulative_error_compare' SMALL_VALUE_ERR_RATIO = "small_value_err_ratio" RMSE_RATIO = "rmse_ratio" diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py index d3c35dd265..ddee254c2b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py @@ -181,13 +181,13 @@ def check_inf_nan_value(inf_nan_mask, bench_output, device_output, dtype, rtol): def check_small_value(abs_err, small_value_mask, small_value_atol): ''' - 新精度标准的相对阈值法中,检查npu和golden小值域输出的相对误差是否满足阈值 + 新精度标准的绝对阈值法中,检查npu和golden正常值输出的绝对误差是否满足阈值 输入: - rel_err:npu输出和golden输出的相对误差 + abs_err:npu输出和golden输出的绝对误差 normal_value_mask:npu输出和golden输出的正常值mask - rtol:相对误差的阈值 + atol:绝对误差的阈值 输出: - rel_err_ratio:npu输出和golden输出的相对误差不满足阈值的比例 + abs_err_ratio:npu输出和golden输出的绝对误差不满足阈值的比例 ''' greater_mask = np.greater(abs_err, small_value_atol) err_mask = np.logical_and(greater_mask, small_value_mask) @@ -197,13 +197,13 @@ def check_small_value(abs_err, small_value_mask, small_value_atol): def check_norm_value(normal_value_mask, rel_err, rtol): ''' - 新精度标准的绝对阈值法中,检查npu和golden正常值输出的绝对误差是否满足阈值 + 新精度标准的相对阈值法中,检查npu和golden小值域输出的相对误差是否满足阈值 输入: - abs_err:npu输出和golden输出的绝对误差 + rel_err:npu输出和golden输出的相对误差 normal_value_mask:npu输出和golden输出的正常值mask - atol:绝对误差的阈值 + rtol:相对误差的阈值 输出: - abs_err_ratio:npu输出和golden输出的绝对误差不满足阈值的比例 + rel_err_ratio:npu输出和golden输出的相对误差不满足阈值的比例 ''' err_mask = np.greater(rel_err, rtol) err_mask = np.logical_and(err_mask, normal_value_mask) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py index d1c7806e28..0e94d62a67 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_compare.py @@ -34,6 +34,7 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_input import Precision from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_register import StandardRegistry from msprobe.pytorch.api_accuracy_checker.precision_standard.ulp_compare import UlpPrecisionCompare from msprobe.pytorch.api_accuracy_checker.precision_standard.benchmark_compare import BenchmarkPrecisionCompare +from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig from msprobe.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn from msprobe.pytorch.api_accuracy_checker.run_ut.run_ut_utils import get_validated_result_csv_path from msprobe.pytorch.api_accuracy_checker.common.utils import extract_detailed_api_segments, extract_basic_api_segments @@ -81,11 +82,12 @@ def write_detail_csv(content, save_path): def register_compare_func(): registry = StandardRegistry() - registry.register("absolute_threshold", record_absolute_threshold_result) - registry.register("binary_consistency", record_binary_consistency_result) - registry.register("ulp_compare", record_ulp_compare_result) - registry.register("thousandth_threshold", record_thousandth_threshold_result) - registry.register("benchmark", record_benchmark_compare_result) + registry.register(CompareConst.ABSOLUTE_THRESHOLD, record_absolute_threshold_result) + registry.register(CompareConst.BINARY_CONSISTENCY, record_binary_consistency_result) + registry.register(CompareConst.ULP_COMPARE, record_ulp_compare_result) + registry.register(CompareConst.THOUSANDTH_STANDARD, record_thousandth_threshold_result) + registry.register(CompareConst.BENCHMARK, record_benchmark_compare_result) + registry.register(CompareConst.ACCUMULATIVE_ERROR_COMPARE, record_accumulative_error_compare_result) return registry @@ -361,6 +363,47 @@ def record_ulp_compare_result(input_data): return compare_result +def record_accumulative_error_compare_result(input_data): + row_npu = input_data.row_npu + compare_column = input_data.compare_column + absolute_threshold_result = get_absolute_threshold_result(row_npu) + threshold_result = absolute_threshold_result.get("absolute_threshold_result") + eb, eb_result = check_eb(row_npu) + accumulative_error_compare_result = CompareConst.PASS + if CompareConst.ERROR in [threshold_result, eb_result]: + accumulative_error_compare_result = CompareConst.ERROR + + compare_column.inf_nan_error_ratio = absolute_threshold_result.get("inf_nan_error_ratio") + compare_column.inf_nan_error_ratio_status = absolute_threshold_result.get("inf_nan_result") + compare_column.rel_err_ratio = absolute_threshold_result.get("rel_err_ratio") + compare_column.rel_err_ratio_status = absolute_threshold_result.get("rel_err_result") + compare_column.abs_err_ratio = absolute_threshold_result.get("abs_err_ratio") + compare_column.abs_err_ratio_status = absolute_threshold_result.get("abs_err_result") + compare_column.eb_ratio = eb + compare_column.eb_status = eb_result + compare_column.compare_result = accumulative_error_compare_result + compare_column.compare_algorithm = CompareConst.ACCUMULATIVE_ERROR_COMPARE_ALGORITHM_NAME + message = '' + if compare_column.inf_nan_error_ratio_status == CompareConst.ERROR: + message += "ERROR: inf/nan错误率超过阈值\n" + if compare_column.rel_err_ratio_status == CompareConst.ERROR: + message += "ERROR: 相对误差错误率超过阈值\n" + if compare_column.abs_err_ratio_status == CompareConst.ERROR: + message += "ERROR: 绝对误差错误率超过阈值\n" + if compare_column.eb_status == CompareConst.ERROR: + message += "ERROR: 误差均衡性超过阈值\n" + compare_column.compare_message = message + return compare_column.compare_result + + +def check_eb(row_npu): + eb = convert_str_to_float(row_npu[ApiPrecisionCompareColumn.EB]) + dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] + eb_threshold = StandardConfig.get_accumulative_error_eb_threshold(dtype) + eb_result = CompareConst.PASS if eb <= eb_threshold else CompareConst.ERROR + return eb, eb_result + + def check_thousandth_rate(thousandth_rate): return CompareConst.PASS if convert_str_to_float(thousandth_rate) >= CompareConst.THOUSANDTH_PASS_VALUE \ else CompareConst.ERROR diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml index c627e84072..1175c1ed42 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml @@ -131,4 +131,6 @@ ULPStandard: ThousandthStandard: - conv1d - conv2d - \ No newline at end of file + +AccumulativeErrorStandard: + - test_api diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py index 6357036c00..cf5928e509 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py @@ -30,6 +30,7 @@ from msprobe.pytorch.api_accuracy_checker.precision_standard.benchmark_compare i from msprobe.pytorch.api_accuracy_checker.precision_standard.ulp_compare import UlpCompare from msprobe.pytorch.api_accuracy_checker.precision_standard.binary_consistency import BinaryCompare from msprobe.pytorch.api_accuracy_checker.precision_standard.thousandth_standard import ThousandthStdCompare +from msprobe.pytorch.api_accuracy_checker.precision_standard.accumulative_error_compare import AccumulativeErrorCompare from msprobe.pytorch.api_accuracy_checker.compare.compare_input import CompareInput from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_abs_err, get_max_abs_err, get_rel_err_ratio, \ cosine_sim, get_rel_err_origin, get_abs_bench_with_eps, compare_bool_tensor @@ -158,6 +159,11 @@ class Comparator: def _benchmark_compare(input_data): benchmark_compare = BenchmarkCompare(input_data) benchmark_compare.compare() + + @staticmethod + def _accumulative_error_compare(input_data): + accumulative_error_compare = AccumulativeErrorCompare(input_data) + accumulative_error_compare.compare() def write_csv_title(self): summary_test_rows = [ @@ -247,11 +253,12 @@ class Comparator: def _register_compare_func(self): registry = StandardRegistry() - registry.register("absolute_threshold", self._absolute_standard_compare) - registry.register("binary_consistency", self._binary_standard_compare) - registry.register("ulp_compare", self._ulp_compare) - registry.register("thousandth_threshold", self._thousandth_standard_compare) - registry.register("benchmark", self._benchmark_compare) + registry.register(CompareConst.ABSOLUTE_THRESHOLD, self._absolute_standard_compare) + registry.register(CompareConst.BINARY_CONSISTENCY, self._binary_standard_compare) + registry.register(CompareConst.ULP_COMPARE, self._ulp_compare) + registry.register(CompareConst.THOUSANDTH_STANDARD, self._thousandth_standard_compare) + registry.register(CompareConst.BENCHMARK, self._benchmark_compare) + registry.register(CompareConst.ACCUMULATIVE_ERROR_COMPARE, self._accumulative_error_compare) return registry def _compare_core_wrapper(self, api_name, bench_output, device_output): diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py index 2f891c7603..549230d0a9 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_utils.py @@ -43,6 +43,7 @@ absolute_standard_api = apis.get('AbsoluteThreshStandard') binary_standard_api = apis.get('BinaryCompareStandard') ulp_standard_api = apis.get('ULPStandard') thousandth_standard_api = apis.get('ThousandthStandard') +accumulative_error_standard_api = apis.get('AccumulativeErrorStandard') DETAIL_TEST_ROWS = [ diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py index 3e716cd035..fd6ac7dd41 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/absolute_threshold.py @@ -71,13 +71,6 @@ class AbsolutethdCompare(BaseCompare): def _get_rtol(self): return StandardConfig.get_rtol(self.dtype) - def _get_rel_err(self, abs_err, abs_bench_with_eps): - rel_err = abs_err / abs_bench_with_eps - return rel_err - - def _get_normal_value_mask(self, small_value_mask): - return np.logical_and(self.both_finite_mask, np.logical_not(small_value_mask)) - def _pre_compare(self): """ Prepares the comparison by calculating various metrics. @@ -99,7 +92,7 @@ class AbsolutethdCompare(BaseCompare): self.rel_err = self._get_rel_err(self.abs_err, self.abs_bench_with_eps) self.small_value, self.small_value_atol = self.get_small_value_threshold() self.small_value_mask = self.stat_small_value_mask(self.abs_bench, self.both_finite_mask, self.small_value) - self.normal_value_mask = self._get_normal_value_mask(self.small_value_mask) + self.normal_value_mask = self._get_normal_value_mask(self.both_finite_mask, self.small_value_mask) def _compute_metrics(self): inf_nan_error_ratio = check_inf_nan_value(self.inf_nan_mask, self.bench_output, self.device_output, self.dtype, diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py new file mode 100644 index 0000000000..3fd3bd970d --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/accumulative_error_compare.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +from msprobe.pytorch.api_accuracy_checker.compare.algorithm import check_inf_nan_value, check_norm_value, \ + check_small_value, get_error_balance +from msprobe.pytorch.api_accuracy_checker.precision_standard.base_standard import BaseCompare +from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config import StandardConfig +from msprobe.core.common.const import CompareConst + + + +class AccumulativeErrorCompare(BaseCompare): + """ + Absolute threshold compare class. + + This class is used to compare the absolute threshold of benchmark outputs and device outputs. + It calculates various metrics such as inf_nan_error_ratio, rel_err_ratio, and abs_err_ratio + to determine the accuracy of the device output compared to the benchmark output. + + Attributes: + bench_output (np.ndarray): The output from the benchmark. + device_output (np.ndarray): The output from the device. + dtype (torch.dtype): The data type of the outputs. + abs_bench (np.ndarray): The absolute value of the benchmark output. + abs_bench_with_eps (np.ndarray): The absolute value of the benchmark output with epsilon. + both_finite_mask (np.ndarray): A mask indicating where both outputs are finite. + inf_nan_mask (np.ndarray): A mask indicating where either output is infinite or NaN. + bound (float): The tolerance for comparison. + rel_err (np.ndarray): The relative error between the benchmark and device outputs. + small_value (float): The small value threshold for comparison. + small_value_atol (float): The absolute tolerance for small values. + small_value_mask (np.ndarray): A mask indicating where values are small. + normal_value_mask (np.ndarray): A mask indicating where values are normal. + + Methods: + _get_rtol(): Gets the relative tolerance based on the data type. + _get_rel_err(abs_bench_with_eps): Calculates the relative error. + _get_normal_value_mask(small_value_mask): Gets the mask for normal values. + _pre_compare(): Prepares the comparison by calculating various metrics. + _compute_metrics(): Computes the comparison metrics. + + Note: + This class assumes that the input data is a dictionary containing 'bench_output', 'device_output', + 'compare_column' and 'dtype'. + The 'dtype' should be a PyTorch data type. + + See Also: + BaseCompare: The base class for comparison classes. + StandardConfig: The class containing standard configuration values. + """ + def __init__(self, input_data): + super(AccumulativeErrorCompare, self).__init__(input_data) + self.compare_algorithm = CompareConst.ACCUMULATIVE_ERROR_COMPARE + + def _get_bound(self): + return StandardConfig.get_accumulative_error_bound(self.dtype) + + def _pre_compare(self): + """ + Prepares the comparison by calculating various metrics. + + This method performs the following steps: + 1. Calculates the absolute benchmark values and their epsilon-adjusted versions. + 2. Determines masks for finite and infinite/NaN values in the outputs. + 3. Computes the absolute error between benchmark and device outputs. + 4. Retrieves the tolerance based on the data type. + 5. Calculates the relative error using the absolute error and epsilon-adjusted benchmark values. + 6. Determines the small value threshold and its absolute tolerance. + 7. Creates a mask for small values based on the benchmark values and finite mask. + 8. Creates a mask for normal values by excluding small values from the finite mask. + """ + self.abs_bench, self.abs_bench_with_eps = self.stat_abs_bench_with_eps() + self.both_finite_mask, self.inf_nan_mask = self.stat_finite_and_infinite_mask() + self.abs_err = self.stat_abs_error() + self.bound = self._get_bound() + self.rel_err = self._get_rel_err(self.abs_err, self.abs_bench_with_eps) + self.small_value, self.small_value_atol = self.get_small_value_threshold() + self.small_value_mask = self.stat_small_value_mask(self.abs_bench, self.both_finite_mask, self.small_value) + self.normal_value_mask = self._get_normal_value_mask(self.both_finite_mask, self.small_value_mask) + + def _compute_metrics(self): + inf_nan_error_ratio = check_inf_nan_value(self.inf_nan_mask, self.bench_output, self.device_output, self.dtype, + self.bound) + rel_err_ratio = check_norm_value(self.normal_value_mask, self.rel_err, self.bound) + abs_err_ratio = check_small_value(self.abs_err, self.small_value_mask, self.bound) + eb = get_error_balance(self.bench_output, self.device_output) + return { + "inf_nan_error_ratio": inf_nan_error_ratio, + "rel_err_ratio": rel_err_ratio, + "abs_err_ratio": abs_err_ratio, + "eb": eb + } diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py index 13d1eabf03..e3ff663758 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/base_standard.py @@ -16,6 +16,7 @@ # limitations under the License. from abc import ABC, abstractmethod +import numpy as np from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import convert_str_to_float from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_abs_bench_with_eps, get_abs_err, \ get_finite_and_infinite_mask, get_small_value_mask @@ -67,13 +68,22 @@ class BaseCompare(ABC): def stat_small_value_mask(abs_bench, both_finite_mask, small_value): small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, small_value) return small_value_mask + + @staticmethod + def _get_rel_err(abs_err, abs_bench_with_eps): + rel_err = abs_err / abs_bench_with_eps + return rel_err + + @staticmethod + def _get_normal_value_mask(both_finite_mask, small_value_mask): + return np.logical_and(both_finite_mask, np.logical_not(small_value_mask)) @abstractmethod def _pre_compare(self): raise NotImplementedError def get_small_value_threshold(self): - small_value = StandardConfig.get_small_valuel(self.dtype) + small_value = StandardConfig.get_small_value(self.dtype, self.compare_algorithm) small_value_atol = StandardConfig.get_small_value_atol(self.dtype, self.compare_algorithm) return small_value, small_value_atol diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py index afa6fda715..11a99e044a 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_config.py @@ -76,7 +76,12 @@ class StandardConfig: torch.float32: 2**-20, "default": 2**-20 } - + _accumulative_error_bound = { + torch.float16: 2**-8, + torch.bfloat16: 2**-7, + torch.float32: 2**-11, + "default": 2**-11 + } _small_value_threshold = { 'error_threshold': 2, 'warning_threshold': 1, @@ -102,20 +107,29 @@ class StandardConfig: 'warning_threshold': 1, "default": 1 } - minmum_err = { + _minmum_err = { 'torch.float16': 2**-11, 'torch.bfloat16': 2**-8, 'torch.float32': 2**-14, 'default': 2**-14 } + _accumulative_error_eb_threshold = { + 'torch.float16': 2**-20, + 'torch.bfloat16': 2**-7, + 'torch.float32': 2**-14, + 'default': 2**-14 + } _fp32_mean_ulp_err_threshold = 64 ulp_err_proportion_ratio = 1 _fp32_ulp_err_proportion = 0.05 _fp16_ulp_err_proportion = 0.001 + _special_samll_value = 1 @classmethod - def get_small_valuel(cls, dtype): + def get_small_value(cls, dtype, standard): + if standard == CompareConst.ACCUMULATIVE_ERROR_COMPARE: + return cls._special_samll_value return cls._small_value.get(dtype, cls._small_value["default"]) @classmethod @@ -193,4 +207,12 @@ class StandardConfig: @classmethod def get_minmum_err(cls, dtype): - return cls.minmum_err.get(dtype, cls.minmum_err["default"]) + return cls._minmum_err.get(dtype, cls._minmum_err["default"]) + + @classmethod + def get_accumulative_error_bound(cls, dtype): + return cls._accumulative_error_bound.get(dtype, cls._accumulative_error_bound["default"]) + + @classmethod + def get_accumulative_error_eb_threshold(cls, dtype): + return cls._accumulative_error_eb_threshold.get(dtype, cls._accumulative_error_eb_threshold["default"]) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py index e70fee8a20..c1d72e8d1f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/precision_standard/standard_register.py @@ -17,7 +17,7 @@ from typing import Callable from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import absolute_standard_api, binary_standard_api, \ - ulp_standard_api, thousandth_standard_api, BINARY_COMPARE_UNSUPPORT_LIST + ulp_standard_api, thousandth_standard_api, accumulative_error_standard_api, BINARY_COMPARE_UNSUPPORT_LIST from msprobe.core.common.const import CompareConst class StandardRegistry: @@ -50,7 +50,8 @@ class StandardRegistry: CompareConst.ABSOLUTE_THRESHOLD: absolute_standard_api, CompareConst.BINARY_CONSISTENCY: binary_standard_api, CompareConst.ULP_COMPARE: ulp_standard_api, - CompareConst.THOUSANDTH_STANDARD: thousandth_standard_api + CompareConst.THOUSANDTH_STANDARD: thousandth_standard_api, + CompareConst.ACCUMULATIVE_ERROR_COMPARE: accumulative_error_standard_api } def register(self, standard: str, func: Callable) -> None: diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/precision_standard/test_standard_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/precision_standard/test_standard_config.py index 6817919958..094994dbce 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/precision_standard/test_standard_config.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/api_accuracy_checker/precision_standard/test_standard_config.py @@ -5,12 +5,12 @@ from msprobe.pytorch.api_accuracy_checker.precision_standard.standard_config imp class TestStandardConfig(unittest.TestCase): def test_get_small_value(self): # 测试已定义的数据类型 - self.assertEqual(StandardConfig.get_small_valuel(torch.float16), 2**-10) - self.assertEqual(StandardConfig.get_small_valuel(torch.bfloat16), 2**-10) - self.assertEqual(StandardConfig.get_small_valuel(torch.float32), 2**-20) + self.assertEqual(StandardConfig.get_small_value(torch.float16), 2**-10) + self.assertEqual(StandardConfig.get_small_value(torch.bfloat16), 2**-10) + self.assertEqual(StandardConfig.get_small_value(torch.float32), 2**-20) # 测试未定义的数据类型(应返回默认值) - self.assertEqual(StandardConfig.get_small_valuel(torch.int32), 2**-20) + self.assertEqual(StandardConfig.get_small_value(torch.int32), 2**-20) def test_get_small_value_atol(self): standard = 'absolute_threshold' -- Gitee