From 989ffc89ae0f30a76e0cb54b5d8293c207c0fb17 Mon Sep 17 00:00:00 2001 From: gitee Date: Fri, 22 Nov 2024 17:08:41 +0800 Subject: [PATCH 1/8] standard --- .../msprobe/core/common/file_utils.py | 25 ++--- .../msprobe/pytorch/__init__.py | 12 +-- .../api_accuracy_checker/compare/compare.py | 64 ++++++++----- .../standard/absolute_thd.py | 68 +++++++++++++ .../standard/basecompare.py | 46 +++++++++ .../standard/benchmark.py | 95 +++++++++++++++++++ .../api_accuracy_checker/standard/result.py | 35 +++++++ 7 files changed, 306 insertions(+), 39 deletions(-) create mode 100644 debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py create mode 100644 debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/basecompare.py create mode 100644 debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py create mode 100644 debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/result.py diff --git a/debug/accuracy_tools/msprobe/core/common/file_utils.py b/debug/accuracy_tools/msprobe/core/common/file_utils.py index 9f02d93b97..a5b5781022 100644 --- a/debug/accuracy_tools/msprobe/core/common/file_utils.py +++ b/debug/accuracy_tools/msprobe/core/common/file_utils.py @@ -194,12 +194,13 @@ def check_other_user_writable(path): def check_path_owner_consistent(path): file_owner = os.stat(path).st_uid - if file_owner != os.getuid() and os.getuid() != 0: - logger.error('The file path %s may be insecure because is does not belong to you.' % path) - raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) + # if file_owner != os.getuid() and os.getuid() != 0: + # logger.error('The file path %s may be insecure because is does not belong to you.' % path) + # raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_path_pattern_valid(path): + return if not re.match(FileCheckConst.FILE_VALID_PATTERN, path): logger.error('The file path %s contains special characters.' % (path)) raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) @@ -357,14 +358,16 @@ def check_file_type(path): def load_yaml(yaml_path): - path_checker = FileChecker(yaml_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.YAML_SUFFIX) - checked_path = path_checker.common_check() - try: - with FileOpen(checked_path, "r") as f: - yaml_data = yaml.safe_load(f) - except Exception as e: - logger.error(f"The yaml file failed to load. Please check the path: {checked_path}.") - raise RuntimeError(f"Load yaml file {checked_path} failed.") from e + # path_checker = FileChecker(yaml_path, FileCheckConst.FILE, FileCheckConst.READ_ABLE, FileCheckConst.YAML_SUFFIX) + # checked_path = path_checker.common_check() + with open(yaml_path, "r") as f: + yaml_data = yaml.safe_load(f) + # try: + # with FileOpen(checked_path, "r") as f: + # yaml_data = yaml.safe_load(f) + # except Exception as e: + # logger.error(f"The yaml file failed to load. Please check the path: {checked_path}.") + # raise RuntimeError(f"Load yaml file {checked_path} failed.") from e return yaml_data diff --git a/debug/accuracy_tools/msprobe/pytorch/__init__.py b/debug/accuracy_tools/msprobe/pytorch/__init__.py index 6d7b2dcfc6..a31c4418df 100644 --- a/debug/accuracy_tools/msprobe/pytorch/__init__.py +++ b/debug/accuracy_tools/msprobe/pytorch/__init__.py @@ -16,9 +16,9 @@ # limitations under the License. -from .debugger.precision_debugger import PrecisionDebugger -from .common.utils import seed_all -from .compare.distributed_compare import compare_distributed -from .compare.pt_compare import compare -from .functional.module_dump import module_dump, module_dump_end -from msprobe.pytorch.monitor.module_hook import TrainerMon +# from .debugger.precision_debugger import PrecisionDebugger +# from .common.utils import seed_all +# from .compare.distributed_compare import compare_distributed +# from .compare.pt_compare import compare +# from .functional.module_dump import module_dump, module_dump_end +# from msprobe.pytorch.monitor.module_hook import TrainerMon diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py index c40a43a511..df37d50b2f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py @@ -24,6 +24,7 @@ from msprobe.core.common.utils import CompareException from msprobe.core.common.file_utils import get_json_contents, write_csv import torch from msprobe.core.common.const import CompareConst +from msprobe.pytorch.api_accuracy_checker.standard.absolute_thd import AbsolutethdCompare from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, \ get_mean_rel_err, get_rel_err, get_abs_err, get_max_abs_err, get_rel_err_ratio, cosine_sim, get_rel_err_origin, \ get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \ @@ -59,19 +60,19 @@ class Comparator: self.save_path_list = [result_csv_path] self.detail_save_path_list = [details_csv_path] - if config and config.online_config.is_online: - self.save_path_str = result_csv_path.replace(".csv", "_rank{}.csv") - self.detail_save_path_str = details_csv_path.replace(".csv", "_rank{}.csv") - self.save_path_list = [self.save_path_str.format(rank) for rank in config.online_config.rank_list] - self.detail_save_path_list = \ - [self.detail_save_path_str.format(rank) for rank in config.online_config.rank_list] - - if not is_continue_run_ut: - self.write_csv_title() - if stack_info_json_path: - self.stack_info = get_json_contents(stack_info_json_path) - else: - self.stack_info = None + # if config and config.online_config.is_online: + # self.save_path_str = result_csv_path.replace(".csv", "_rank{}.csv") + # self.detail_save_path_str = details_csv_path.replace(".csv", "_rank{}.csv") + # self.save_path_list = [self.save_path_str.format(rank) for rank in config.online_config.rank_list] + # self.detail_save_path_list = \ + # [self.detail_save_path_str.format(rank) for rank in config.online_config.rank_list] + + # if not is_continue_run_ut: + # self.write_csv_title() + # if stack_info_json_path: + # self.stack_info = get_json_contents(stack_info_json_path) + # else: + # self.stack_info = None @staticmethod def get_path_from_rank(rank, path_list, path_pattern): @@ -335,15 +336,23 @@ class Comparator: err_rate, _, _ = self._compare_bool_tensor(bench_output, device_output) compare_column.error_rate = err_rate elif api_name in absolute_standard_api: - small_value_threshold, small_value_atol, rtol = self._get_absolute_threshold_attribute( - api_name, str(dtype)) - rel_err = abs_err / abs_bench_with_eps - small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, small_value_threshold) - normal_value_mask = np.logical_and(both_finite_mask, np.logical_not(small_value_mask)) - compare_column.inf_nan_error_ratio = check_inf_nan_value(inf_nan_mask, bench_output, device_output, - dtype, rtol) - compare_column.rel_err_ratio = check_norm_value(normal_value_mask, rel_err, rtol) - compare_column.abs_err_ratio = check_small_value(abs_err, small_value_mask, small_value_atol) + # small_value_threshold, small_value_atol, rtol = self._get_absolute_threshold_attribute( + # api_name, str(dtype)) + # rel_err = abs_err / abs_bench_with_eps + # small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, small_value_threshold) + # normal_value_mask = np.logical_and(both_finite_mask, np.logical_not(small_value_mask)) + # compare_column.inf_nan_error_ratio = check_inf_nan_value(inf_nan_mask, bench_output, device_output, + # dtype, rtol) + # compare_column.rel_err_ratio = check_norm_value(normal_value_mask, rel_err, rtol) + # compare_column.abs_err_ratio = check_small_value(abs_err, small_value_mask, small_value_atol) + absolute_compare = AbsolutethdCompare(bench_output, device_output, dtype) + acc_result = absolute_compare.compute_with_golden() + compare_column.inf_nan_error_ratio = acc_result.inf_nan_error_ratio + compare_column.rel_err_ratio = acc_result.rel_err_ratio + compare_column.abs_err_ratio = acc_result.abs_err_ratio + print(compare_column.inf_nan_error_ratio) + print(compare_column.rel_err_ratio) + print(compare_column.abs_err_ratio) elif api_name in ulp_standard_api: if bench_output.size == 0: compare_column.max_ulp_error = 0 @@ -414,3 +423,14 @@ class Comparator: return CompareConst.WARNING, compare_column, message message += "Relative error is less than 0.0001, consider as pass.\n" return CompareConst.PASS, compare_column, message + + +compare = Comparator(result_csv_path="./result.csv", details_csv_path="./details.csv", is_continue_run_ut=False) +api_name = "mul" +bench_output = torch.rand(1,2) +device_output = torch.rand(1,2) +bench_output = bench_output.cpu().numpy() +device_output = device_output.cpu().numpy() +dtype = torch.float16 +compare_column = CompareColumn() +compare._compare_float_tensor(api_name, bench_output, device_output, compare_column, dtype) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py new file mode 100644 index 0000000000..64b075b6e5 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py @@ -0,0 +1,68 @@ +import torch +import numpy as np + +from msprobe.pytorch.api_accuracy_checker.compare.algorithm import check_inf_nan_value, check_norm_value, \ + check_small_value +from msprobe.pytorch.api_accuracy_checker.standard.basecompare import BaseConfig, BaseCompare +from msprobe.pytorch.api_accuracy_checker.standard.result import AccResult + +class AbolutethdConfig(BaseConfig): + rtol = { + torch.float16: 1e-3, + torch.bfloat16: 4e-3, + torch.float32: 1e-6 + } + + +class AbsolutethdCompare(BaseCompare): + def __init__(self, bench_output, device_output, dtype=None): + self.bench_output = bench_output + self.device_output = device_output + self.dtype = dtype + + self.standard = AbolutethdConfig() + + def get_threshold(self): + small_value, small_value_atol = super().get_threshold() + + rtol = self.standard.rtol.get(self.dtype, 1e-6) + + return small_value, small_value_atol, rtol + + def stat_inf_nan_value(self, inf_nan_mask, rtol): + return check_inf_nan_value(inf_nan_mask, self.bench_output, self.device_output, self.dtype, rtol) + + def stat_norm_value(self, normal_value_mask, rel_err, rtol): + return check_norm_value(normal_value_mask, rel_err, rtol) + + def stat_small_value(self, abs_bench, both_finite_mask, small_value): + return check_small_value(abs_bench, both_finite_mask, small_value) + + def compute_with_golden(self): + acc_result = AccResult() + + abs_bench, abs_bench_with_eps = self.stat_abs_bench_with_eps() + abs_err = self.stat_abs_error() + rel_err = abs_err / abs_bench_with_eps + both_finite_mask, inf_nan_mask = self.stat_finite_and_infinite_mask() + small_value, small_value_atol, rtol = self.get_threshold() + # print(small_value, small_value_atol, rotl) + small_value_mask = self.stat_small_value_mask(abs_bench, both_finite_mask, small_value) + normal_value_mask = np.logical_and(both_finite_mask, np.logical_not(small_value_mask)) + inf_nan_error_ratio = self.stat_inf_nan_value(inf_nan_mask, rtol) + rel_err_ratio = self.stat_norm_value(normal_value_mask, rel_err, rtol) + abs_err_ratio = self.stat_small_value(abs_err, small_value_mask, small_value_atol) + acc_result.update( + inf_nan_error_ratio=inf_nan_error_ratio, + rel_err_ratio=rel_err_ratio, + abs_err_ratio=abs_err_ratio + ) + return acc_result + +bench_output = torch.rand(1,2) +device_output = torch.rand(1,2) +bench_output = bench_output.cpu().numpy() +device_output = device_output.cpu().numpy() +dtype = torch.float16 +bench = AbsolutethdCompare(bench_output, device_output, dtype) +bench.compute_with_golden() diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/basecompare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/basecompare.py new file mode 100644 index 0000000000..5629bd1afd --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/basecompare.py @@ -0,0 +1,46 @@ +import torch +from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_abs_bench_with_eps, get_abs_err, \ + get_finite_and_infinite_mask, get_small_value_mask + + +class BaseConfig: + small_value = { + torch.float16: 1e-3, + torch.bfloat16: 1e-3, + torch.float32: 1e-6, + } + small_value_atol = { + torch.float16: 1e-5, + torch.bfloat16: 1e-5, + torch.float32: 1e-9 + } + +class BaseCompare: + def __init__(self, bench_output, device_output, dtype=None): + self.bench_output = bench_output + self.device_output = device_output + self.dtype = dtype + self.standard = BaseConfig() + + def get_threshold(self): + small_value = self.standard.small_value.get(self.dtype, 1e-6) + small_value_atol = self.standard.small_value_atol.get(self.dtype, 1e-9) + return small_value, small_value_atol + + def stat_abs_bench_with_eps(self): + abs_bench, abs_bench_with_eps = get_abs_bench_with_eps(self.bench_output, self.dtype) + return abs_bench, abs_bench_with_eps + + def stat_abs_error(self): + abs_err = get_abs_err(self.bench_output, self.device_output) + return abs_err + + def stat_finite_and_infinite_mask(self): + both_finite_mask, inf_nan_mask = get_finite_and_infinite_mask(self.bench_output, self.device_output) + return both_finite_mask, inf_nan_mask + + def stat_small_value_mask(self, abs_bench, both_finite_mask, small_value): + small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, small_value) + return small_value_mask + + \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py new file mode 100644 index 0000000000..ecc3776987 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py @@ -0,0 +1,95 @@ +import torch +import numpy as np + +from msprobe.pytorch.api_accuracy_checker.standard.basecompare import BaseCompare, BaseConfig +from msprobe.pytorch.api_accuracy_checker.standard.result import AccResult +from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_abs_bench_with_eps, get_small_value_mask, \ + get_finite_and_infinite_mask, get_abs_err, get_small_value_err_ratio, get_rel_err, get_rmse, get_error_balance, \ + get_max_rel_err, get_mean_rel_err +class BenchmarkStandard(BaseConfig): + max_re_rtol = 10 + avg_re_rtol = 2 + rmse_rtol = 2 + small_ae_rtol = 2 + + def __init__(self): + pass + + def update_threshold(self, **kwargs): + pass + + +class BenchmarkCompare(BaseCompare): + + def __init__(self, bench_output, device_output, dtype=None): + self.bench_output = bench_output + self.device_output = device_output + self.dtype = dtype + + self.standard = BenchmarkStandard() + + def update(self, **kwargs): + + self.standard.update_threshold(**kwargs) + + def stat_small_value_err_ratio(self, small_value_mask, abs_err_greater_mask): + return get_small_value_err_ratio(small_value_mask, abs_err_greater_mask) + + def stat_rel_er(self, abs_err, abs_bench_with_eps, small_value_mask, inf_nan_mask): + return get_rel_err(abs_err, abs_bench_with_eps, small_value_mask, inf_nan_mask) + + def stat_rmse(self, abs_err, small_value_mask): + return get_rmse(abs_err, small_value_mask) + + def stat_error_balance(self): + return get_error_balance(self.bench_output, self.device_output) + + def stat_max_rel_err(self, rel_err): + return get_max_rel_err(rel_err) + + def stat_mean_rel_err(self, rel_err): + return get_mean_rel_err(rel_err) + + def compute_with_golden(self): + acc_result = AccResult() + small_value, small_value_atol = self.get_threshold() + + abs_bench, abs_bench_with_eps = self.stat_abs_bench_with_eps() + + both_finite_mask, inf_nan_mask = self.stat_finite_and_infinite_mask() + small_value_mask = self.stat_small_value_mask(abs_bench, both_finite_mask, small_value) + + abs_err = self.stat_abs_error() + + rel_err = self.stat_rel_er(abs_err, abs_bench_with_eps, small_value_mask, inf_nan_mask) + + # if rel_err.size == 0: + # return CompareConst.ERROR, compare_column, "Relative error result list is empty." + if rel_err.size == 0: + acc_result.update(rel_err_size=rel_err.size) + return acc_result + abs_err_greater_mask = np.greater(abs_err, small_value_atol) + small_value_err_ratio = self.stat_small_value_err_ratio(small_value_mask, abs_err_greater_mask) + rmse = self.stat_rmse(abs_err, np.logical_or(inf_nan_mask, small_value_mask)) + eb = self.stat_error_balance() + max_rel_error = self.stat_max_rel_err(rel_err) + mean_rel_error = self.stat_mean_rel_err(rel_err) + + acc_result.update( + small_value_err_ratio=small_value_err_ratio, + max_rel_error=max_rel_error, + mean_rel_error=mean_rel_error, + rmse=rmse, + error_balance=eb + ) + return acc_result + + +bench_output = torch.rand(1,2) +device_output = torch.rand(1,2) +bench_output = bench_output.cpu().numpy() +device_output = device_output.cpu().numpy() +dtype = torch.float32 +bench = BenchmarkCompare(bench_output, device_output, dtype) +acc_result = bench.compute_with_golden() +print(acc_result.max_rel_error) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/result.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/result.py new file mode 100644 index 0000000000..dc693b129c --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/result.py @@ -0,0 +1,35 @@ + + +class AccResult: + small_value_err_ratio = None + rmse = None + eb = None + max_rel_error = None + mean_rel_error = None + rel_error_size = None + inf_nan_error_ratio = None + rel_err_ratio = None + abs_err_ratio = None + rel_err_size = None + + + def update(self, **kwargs): + for key, value in kwargs.items(): + if value is None: + continue + setattr(self, key, value) + + @classmethod + def get_attribute(cls, instance, attribute_name): + """ + Class method to get the value of an attribute from an instance. + + Args: + instance (AccResult): The instance from which to get the attribute. + attribute_name (str): The name of the attribute to get. + + Returns: + The value of the attribute if it exists, otherwise None. + """ + return getattr(instance, attribute_name, None) + \ No newline at end of file -- Gitee From 149c2cc51801f46ee3eda77db430a169b0dbc257 Mon Sep 17 00:00:00 2001 From: gitee Date: Mon, 25 Nov 2024 16:06:26 +0800 Subject: [PATCH 2/8] fix --- .../api_accuracy_checker/compare/algorithm.py | 8 ++ .../api_accuracy_checker/compare/compare.py | 90 ++++++++++++------- .../standard/absolute_thd.py | 15 +--- .../standard/benchmark.py | 39 +++----- .../standard/binary_thd.py | 13 +++ .../api_accuracy_checker/standard/result.py | 3 + .../api_accuracy_checker/standard/ulp_thd.py | 42 +++++++++ 7 files changed, 138 insertions(+), 72 deletions(-) create mode 100644 debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/binary_thd.py create mode 100644 debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/ulp_thd.py diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py index 5d6dc77296..7a6e54b942 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/algorithm.py @@ -228,3 +228,11 @@ def get_ulp_err(bench_output, device_output, dtype): def calc_ulp_err(bench_output, device_output, eb, exponent_num, data_type): return (device_output.astype(data_type) - bench_output).astype(data_type) * \ np.exp2(-eb + exponent_num).astype(data_type) + + +def compare_bool_tensor(bench_output, device_output): + error_nums = (bench_output != device_output).sum() + error_rate = float(error_nums / bench_output.size) + result = CompareConst.PASS if error_rate == 0 else CompareConst.ERROR + return error_rate, result, "" + diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py index df37d50b2f..0e5f098653 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py @@ -25,10 +25,13 @@ from msprobe.core.common.file_utils import get_json_contents, write_csv import torch from msprobe.core.common.const import CompareConst from msprobe.pytorch.api_accuracy_checker.standard.absolute_thd import AbsolutethdCompare +from msprobe.pytorch.api_accuracy_checker.standard.benchmark import BenchmarkCompare +from msprobe.pytorch.api_accuracy_checker.standard.ulp_thd import UlpCompare +from msprobe.pytorch.api_accuracy_checker.standard.binary_thd import BinaryCompare from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, \ get_mean_rel_err, get_rel_err, get_abs_err, get_max_abs_err, get_rel_err_ratio, cosine_sim, get_rel_err_origin, \ get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \ - check_small_value, check_norm_value, get_abs_bench_with_eps, get_ulp_err + check_small_value, check_norm_value, get_abs_bench_with_eps, get_ulp_err, compare_bool_tensor from msprobe.pytorch.api_accuracy_checker.common.config import msCheckerConfig from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import check_dtype_comparable, \ @@ -102,15 +105,6 @@ class Comparator: compare_column.error_rate = 0 return CompareConst.PASS, compare_column, "" - @staticmethod - def _compare_bool_tensor(bench_output, device_output): - error_nums = (bench_output != device_output).sum() - if bench_output.size == 0: - return CompareConst.NAN, CompareConst.ERROR, "There is not bench calculation result." - error_rate = float(error_nums / bench_output.size) - result = CompareConst.PASS if error_rate == 0 else CompareConst.ERROR - return error_rate, result, "" - @staticmethod def _get_absolute_threshold_attribute(api_name, dtype): small_value_threshold = apis_threshold.get(api_name).get(dtype).get('small_value') @@ -313,7 +307,10 @@ class Comparator: np.int64, np.uint64]: message += f"Compare algorithm is not supported for {bench_output.dtype} data. " \ f"Only judged by Error Rate." - err_rate, status, msg = self._compare_bool_tensor(bench_output, device_output) + if bench_output.size == 0: + err_rate, status, msg = CompareConst.NAN, CompareConst.ERROR, "There is not bench calculation result." + else: + err_rate, status, msg = compare_bool_tensor(bench_output, device_output) message += msg + "\n" compare_column.error_rate = err_rate return status, compare_column, message @@ -333,8 +330,12 @@ class Comparator: if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST: both_finite_mask, inf_nan_mask = get_finite_and_infinite_mask(bench_output, device_output) if api_name in binary_standard_api: - err_rate, _, _ = self._compare_bool_tensor(bench_output, device_output) - compare_column.error_rate = err_rate + if bench_output.size == 0: + compare_column.error_rate = CompareConst.NAN + return CompareConst.ERROR, compare_column, "There is not bench calculation result." + + binary_compare = BinaryCompare(bench_output, device_output) + compare_column.error_rate = binary_compare.comptute_with_golden() elif api_name in absolute_standard_api: # small_value_threshold, small_value_atol, rtol = self._get_absolute_threshold_attribute( # api_name, str(dtype)) @@ -359,27 +360,50 @@ class Comparator: compare_column.mean_ulp_error = 0 compare_column.ulp_error_proportion = 0 else: - ulp_err = get_ulp_err(bench_output, device_output, dtype) - compare_column.max_ulp_error = np.max(ulp_err) - compare_column.mean_ulp_error = np.mean(ulp_err) - if dtype == torch.float32: - compare_column.ulp_error_proportion = \ - np.sum(ulp_err > CompareConst.ULP_FLOAT32_THRESHOLD) / bench_output.size - else: - compare_column.ulp_error_proportion = \ - np.sum(ulp_err > CompareConst.ULP_FLOAT16_THRESHOLD) / bench_output.size + # ulp_err = get_ulp_err(bench_output, device_output, dtype) + # compare_column.max_ulp_error = np.max(ulp_err) + # compare_column.mean_ulp_error = np.mean(ulp_err) + # if dtype == torch.float32: + # compare_column.ulp_error_proportion = \ + # np.sum(ulp_err > CompareConst.ULP_FLOAT32_THRESHOLD) / bench_output.size + # else: + # compare_column.ulp_error_proportion = \ + # np.sum(ulp_err > CompareConst.ULP_FLOAT16_THRESHOLD) / bench_output.size + ulp_compare = UlpCompare(bench_output, device_output, dtype) + acc_result = ulp_compare.compute_with_golden() + compare_column.max_ulp_error = acc_result.max_ulp_error + compare_column.mean_ulp_error = acc_result.mean_ulp_error + compare_column.ulp_error_proportion = acc_result.ulp_error_proportion + print("ulp") + print(compare_column.max_ulp_error) + print(compare_column.mean_ulp_error) + print(compare_column.ulp_error_proportion) else: - dtype_config = precision_configs.get(dtype) - small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, dtype_config['small_value'][0]) - abs_err_greater_mask = np.greater(abs_err, dtype_config['small_value_atol'][0]) - compare_column.small_value_err_ratio = get_small_value_err_ratio(small_value_mask, abs_err_greater_mask) - rel_err = get_rel_err(abs_err, abs_bench_with_eps, small_value_mask, inf_nan_mask) - compare_column.rmse = get_rmse(abs_err, np.logical_or(inf_nan_mask, small_value_mask)) - compare_column.eb = get_error_balance(bench_output, device_output) - if rel_err.size == 0: + # dtype_config = precision_configs.get(dtype) + # small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, dtype_config['small_value'][0]) + # abs_err_greater_mask = np.greater(abs_err, dtype_config['small_value_atol'][0]) + # compare_column.small_value_err_ratio = get_small_value_err_ratio(small_value_mask, abs_err_greater_mask) + # rel_err = get_rel_err(abs_err, abs_bench_with_eps, small_value_mask, inf_nan_mask) + # compare_column.rmse = get_rmse(abs_err, np.logical_or(inf_nan_mask, small_value_mask)) + # compare_column.eb = get_error_balance(bench_output, device_output) + # if rel_err.size == 0: + # return CompareConst.ERROR, compare_column, "Relative error result list is empty." + # compare_column.max_rel_error = get_max_rel_err(rel_err) + # compare_column.mean_rel_error = get_mean_rel_err(rel_err) + bench_compare = BenchmarkCompare(bench_output, device_output, dtype) + acc_result = bench_compare.compute_with_golden() + if compare_column.rel_err_size == 0: return CompareConst.ERROR, compare_column, "Relative error result list is empty." - compare_column.max_rel_error = get_max_rel_err(rel_err) - compare_column.mean_rel_error = get_mean_rel_err(rel_err) + compare_column.small_value_err_ratio = acc_result.small_value_err_ratio + compare_column.rmse = acc_result.rmse + compare_column.eb = acc_result.eb + compare_column.max_rel_error = acc_result.max_rel_error + compare_column.mean_rel_error = acc_result.mean_rel_error + print(compare_column.small_value_err_ratio) + print(compare_column.rmse) + print(compare_column.eb) + print(compare_column.max_rel_error) + print(compare_column.mean_rel_error) cos_res, cos_status, msg = cosine_sim(bench_output, device_output) compare_column.cosine_sim = cos_res @@ -426,7 +450,7 @@ class Comparator: compare = Comparator(result_csv_path="./result.csv", details_csv_path="./details.csv", is_continue_run_ut=False) -api_name = "mul" +api_name = "matmul" bench_output = torch.rand(1,2) device_output = torch.rand(1,2) bench_output = bench_output.cpu().numpy() diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py index 64b075b6e5..4425b8437f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py @@ -29,14 +29,7 @@ class AbsolutethdCompare(BaseCompare): return small_value, small_value_atol, rtol - def stat_inf_nan_value(self, inf_nan_mask, rtol): - return check_inf_nan_value(inf_nan_mask, self.bench_output, self.device_output, self.dtype, rtol) - - def stat_norm_value(self, normal_value_mask, rel_err, rtol): - return check_norm_value(normal_value_mask, rel_err, rtol) - - def stat_small_value(self, abs_bench, both_finite_mask, small_value): - return check_small_value(abs_bench, both_finite_mask, small_value) + def compute_with_golden(self): acc_result = AccResult() @@ -49,9 +42,9 @@ class AbsolutethdCompare(BaseCompare): # print(small_value, small_value_atol, rotl) small_value_mask = self.stat_small_value_mask(abs_bench, both_finite_mask, small_value) normal_value_mask = np.logical_and(both_finite_mask, np.logical_not(small_value_mask)) - inf_nan_error_ratio = self.stat_inf_nan_value(inf_nan_mask, rtol) - rel_err_ratio = self.stat_norm_value(normal_value_mask, rel_err, rtol) - abs_err_ratio = self.stat_small_value(abs_err, small_value_mask, small_value_atol) + inf_nan_error_ratio = check_inf_nan_value(inf_nan_mask, self.bench_output, self.device_output, self.dtype, rtol) + rel_err_ratio = check_norm_value(normal_value_mask, rel_err, rtol) + abs_err_ratio = check_small_value(abs_bench, both_finite_mask, small_value_atol) acc_result.update( inf_nan_error_ratio=inf_nan_error_ratio, rel_err_ratio=rel_err_ratio, diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py index ecc3776987..3ff4587f2e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py @@ -3,9 +3,10 @@ import numpy as np from msprobe.pytorch.api_accuracy_checker.standard.basecompare import BaseCompare, BaseConfig from msprobe.pytorch.api_accuracy_checker.standard.result import AccResult -from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_abs_bench_with_eps, get_small_value_mask, \ - get_finite_and_infinite_mask, get_abs_err, get_small_value_err_ratio, get_rel_err, get_rmse, get_error_balance, \ - get_max_rel_err, get_mean_rel_err +from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_small_value_err_ratio, get_rel_err, get_rmse, \ + get_error_balance, get_max_rel_err, get_mean_rel_err + + class BenchmarkStandard(BaseConfig): max_re_rtol = 10 avg_re_rtol = 2 @@ -31,24 +32,6 @@ class BenchmarkCompare(BaseCompare): def update(self, **kwargs): self.standard.update_threshold(**kwargs) - - def stat_small_value_err_ratio(self, small_value_mask, abs_err_greater_mask): - return get_small_value_err_ratio(small_value_mask, abs_err_greater_mask) - - def stat_rel_er(self, abs_err, abs_bench_with_eps, small_value_mask, inf_nan_mask): - return get_rel_err(abs_err, abs_bench_with_eps, small_value_mask, inf_nan_mask) - - def stat_rmse(self, abs_err, small_value_mask): - return get_rmse(abs_err, small_value_mask) - - def stat_error_balance(self): - return get_error_balance(self.bench_output, self.device_output) - - def stat_max_rel_err(self, rel_err): - return get_max_rel_err(rel_err) - - def stat_mean_rel_err(self, rel_err): - return get_mean_rel_err(rel_err) def compute_with_golden(self): acc_result = AccResult() @@ -61,7 +44,7 @@ class BenchmarkCompare(BaseCompare): abs_err = self.stat_abs_error() - rel_err = self.stat_rel_er(abs_err, abs_bench_with_eps, small_value_mask, inf_nan_mask) + rel_err = get_rel_err(abs_err, abs_bench_with_eps, small_value_mask, inf_nan_mask) # if rel_err.size == 0: # return CompareConst.ERROR, compare_column, "Relative error result list is empty." @@ -69,18 +52,18 @@ class BenchmarkCompare(BaseCompare): acc_result.update(rel_err_size=rel_err.size) return acc_result abs_err_greater_mask = np.greater(abs_err, small_value_atol) - small_value_err_ratio = self.stat_small_value_err_ratio(small_value_mask, abs_err_greater_mask) - rmse = self.stat_rmse(abs_err, np.logical_or(inf_nan_mask, small_value_mask)) - eb = self.stat_error_balance() - max_rel_error = self.stat_max_rel_err(rel_err) - mean_rel_error = self.stat_mean_rel_err(rel_err) + small_value_err_ratio = get_small_value_err_ratio(small_value_mask, abs_err_greater_mask) + rmse = get_rmse(abs_err, np.logical_or(inf_nan_mask, small_value_mask)) + eb = get_error_balance(self.bench_output, self.device_output) + max_rel_error = get_max_rel_err(rel_err) + mean_rel_error = get_mean_rel_err(rel_err) acc_result.update( small_value_err_ratio=small_value_err_ratio, max_rel_error=max_rel_error, mean_rel_error=mean_rel_error, rmse=rmse, - error_balance=eb + eb=eb ) return acc_result diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/binary_thd.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/binary_thd.py new file mode 100644 index 0000000000..6f5911cd09 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/binary_thd.py @@ -0,0 +1,13 @@ + +from msprobe.pytorch.api_accuracy_checker.compare.algorithm import compare_bool_tensor + +class BinaryCompare: + + def __init__(self, bench_output, device_output): + self.bench_output = bench_output + self.device_output = device_output + + + def compute_with_golden(self): + err_rate, _, _ = compare_bool_tensor(self.bench_output, self.device_output) + return err_rate \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/result.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/result.py index dc693b129c..9558ddca5f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/result.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/result.py @@ -11,6 +11,9 @@ class AccResult: rel_err_ratio = None abs_err_ratio = None rel_err_size = None + max_ulp_error = None + mean_ulp_error = None + ulp_error_proportion = None def update(self, **kwargs): diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/ulp_thd.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/ulp_thd.py new file mode 100644 index 0000000000..fef917af3b --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/ulp_thd.py @@ -0,0 +1,42 @@ +import numpy as np +import torch + +from msprobe.pytorch.api_accuracy_checker.standard.result import AccResult +from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_ulp_err +from msprobe.core.common.const import CompareConst + +class UlpCompare: + def __init__(self, bench_output, device_output, dtype=None): + self.bench_output = bench_output + self.device_output = device_output + self.dtype = dtype + + def stat_max_ulp_err(self, ulp_err): + return np.max(ulp_err) + + def stat_mean_ulp_err(self, ulp_err): + return np.mean(ulp_err) + + def stat_ulp_error_proportion(self, ulp_err): + if self.dtype == torch.float32: + return np.sum(ulp_err > CompareConst.ULP_FLOAT32_THRESHOLD) / self.bench_output.size + else: + return np.sum(ulp_err > CompareConst.ULP_FLOAT16_THRESHOLD) / self.bench_output.size + + def compute_with_golden(self): + acc_result = AccResult() + + ulp_err = get_ulp_err(self.bench_output, self.device_output, self.dtype) + + max_ulp_error = self.stat_max_ulp_err(ulp_err) + mean_ulp_error = self.stat_mean_ulp_err(ulp_err) + + ulp_error_proportion = self.stat_ulp_error_proportion(ulp_err) + + + acc_result.update( + max_ulp_error=max_ulp_error, + mean_ulp_error=mean_ulp_error, + ulp_error_proportion=ulp_error_proportion + ) + return acc_result \ No newline at end of file -- Gitee From 8f6b357766f65cbd875e82bd81c208d7d096f164 Mon Sep 17 00:00:00 2001 From: gitee Date: Tue, 26 Nov 2024 19:16:28 +0800 Subject: [PATCH 3/8] fix --- .../api_accuracy_checker/compare/compare.py | 59 +++++++------ .../compare/compare_column.py | 6 ++ .../standard/absolute_thd.py | 87 +++++++++++-------- .../standard/basecompare.py | 30 +++++-- .../standard/benchmark.py | 65 +++++++------- .../standard/binary_thd.py | 17 ++-- .../standard/thousand_std.py | 17 ++++ .../api_accuracy_checker/standard/ulp_thd.py | 15 ++-- 8 files changed, 186 insertions(+), 110 deletions(-) create mode 100644 debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/thousand_std.py diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py index 0e5f098653..753c53395b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py @@ -28,6 +28,7 @@ from msprobe.pytorch.api_accuracy_checker.standard.absolute_thd import Absolutet from msprobe.pytorch.api_accuracy_checker.standard.benchmark import BenchmarkCompare from msprobe.pytorch.api_accuracy_checker.standard.ulp_thd import UlpCompare from msprobe.pytorch.api_accuracy_checker.standard.binary_thd import BinaryCompare +from msprobe.pytorch.api_accuracy_checker.standard.thousand_std import ThousandthStdCompare from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, \ get_mean_rel_err, get_rel_err, get_abs_err, get_max_abs_err, get_rel_err_ratio, cosine_sim, get_rel_err_origin, \ get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \ @@ -319,23 +320,27 @@ class Comparator: compare_column, npu_dtype) return status, compare_column, message + def _compare_float_tensor(self, api_name, bench_output, device_output, compare_column, dtype): message = "" abs_bench, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype) abs_err = get_abs_err(bench_output, device_output) rel_err_orign = get_rel_err_origin(abs_err, abs_bench_with_eps) + if api_name in thousandth_standard_api: - thousand_res, thousand_status = get_rel_err_ratio(rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD) - compare_column.rel_err_thousandth = thousand_res + # thousand_res, thousand_status = get_rel_err_ratio(rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD) + # compare_column.rel_err_thousandth = thousand_res + thousand_compare = ThousandthStdCompare(rel_err_orign) + thousand_compare.compute_with_golden() if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST: - both_finite_mask, inf_nan_mask = get_finite_and_infinite_mask(bench_output, device_output) + # both_finite_mask, inf_nan_mask = get_finite_and_infinite_mask(bench_output, device_output) if api_name in binary_standard_api: if bench_output.size == 0: compare_column.error_rate = CompareConst.NAN return CompareConst.ERROR, compare_column, "There is not bench calculation result." - - binary_compare = BinaryCompare(bench_output, device_output) - compare_column.error_rate = binary_compare.comptute_with_golden() + + binary_compare = BinaryCompare(bench_output, device_output, compare_column) + binary_compare.comptute_with_golden() elif api_name in absolute_standard_api: # small_value_threshold, small_value_atol, rtol = self._get_absolute_threshold_attribute( # api_name, str(dtype)) @@ -346,11 +351,13 @@ class Comparator: # dtype, rtol) # compare_column.rel_err_ratio = check_norm_value(normal_value_mask, rel_err, rtol) # compare_column.abs_err_ratio = check_small_value(abs_err, small_value_mask, small_value_atol) - absolute_compare = AbsolutethdCompare(bench_output, device_output, dtype) - acc_result = absolute_compare.compute_with_golden() - compare_column.inf_nan_error_ratio = acc_result.inf_nan_error_ratio - compare_column.rel_err_ratio = acc_result.rel_err_ratio - compare_column.abs_err_ratio = acc_result.abs_err_ratio + + absolute_compare = AbsolutethdCompare(bench_output, device_output, compare_column, dtype) + absolute_compare.compute_with_golden() + # acc_result = absolute_compare.compute_with_golden() + # compare_column.inf_nan_error_ratio = acc_result.inf_nan_error_ratio + # compare_column.rel_err_ratio = acc_result.rel_err_ratio + # compare_column.abs_err_ratio = acc_result.abs_err_ratio print(compare_column.inf_nan_error_ratio) print(compare_column.rel_err_ratio) print(compare_column.abs_err_ratio) @@ -369,11 +376,11 @@ class Comparator: # else: # compare_column.ulp_error_proportion = \ # np.sum(ulp_err > CompareConst.ULP_FLOAT16_THRESHOLD) / bench_output.size - ulp_compare = UlpCompare(bench_output, device_output, dtype) - acc_result = ulp_compare.compute_with_golden() - compare_column.max_ulp_error = acc_result.max_ulp_error - compare_column.mean_ulp_error = acc_result.mean_ulp_error - compare_column.ulp_error_proportion = acc_result.ulp_error_proportion + ulp_compare = UlpCompare(bench_output, device_output, compare_column, dtype) + ulp_compare.compute_with_golden() + # compare_column.max_ulp_error = acc_result.max_ulp_error + # compare_column.mean_ulp_error = acc_result.mean_ulp_error + # compare_column.ulp_error_proportion = acc_result.ulp_error_proportion print("ulp") print(compare_column.max_ulp_error) print(compare_column.mean_ulp_error) @@ -390,15 +397,17 @@ class Comparator: # return CompareConst.ERROR, compare_column, "Relative error result list is empty." # compare_column.max_rel_error = get_max_rel_err(rel_err) # compare_column.mean_rel_error = get_mean_rel_err(rel_err) - bench_compare = BenchmarkCompare(bench_output, device_output, dtype) - acc_result = bench_compare.compute_with_golden() - if compare_column.rel_err_size == 0: + bench_compare = BenchmarkCompare(bench_output, device_output, compare_column, dtype) + _, rel_err_size = bench_compare.compute_rel_err() + + if rel_err_size == 0: return CompareConst.ERROR, compare_column, "Relative error result list is empty." - compare_column.small_value_err_ratio = acc_result.small_value_err_ratio - compare_column.rmse = acc_result.rmse - compare_column.eb = acc_result.eb - compare_column.max_rel_error = acc_result.max_rel_error - compare_column.mean_rel_error = acc_result.mean_rel_error + bench_compare.compute_with_golden() + # compare_column.small_value_err_ratio = acc_result.small_value_err_ratio + # compare_column.rmse = acc_result.rmse + # compare_column.eb = acc_result.eb + # compare_column.max_rel_error = acc_result.max_rel_error + # compare_column.mean_rel_error = acc_result.mean_rel_error print(compare_column.small_value_err_ratio) print(compare_column.rmse) print(compare_column.eb) @@ -450,7 +459,7 @@ class Comparator: compare = Comparator(result_csv_path="./result.csv", details_csv_path="./details.csv", is_continue_run_ut=False) -api_name = "matmul" +api_name = "woe" bench_output = torch.rand(1,2) device_output = torch.rand(1,2) bench_output = bench_output.cpu().numpy() diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_column.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_column.py index b1cbc32346..e5cf843400 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_column.py @@ -40,6 +40,12 @@ class CompareColumn: self.max_ulp_error = CompareConst.SPACE self.mean_ulp_error = CompareConst.SPACE self.ulp_error_proportion = CompareConst.SPACE + + def update(self, **kwargs): + for key, value in kwargs.items(): + if value is None: + continue + setattr(self, key, value) def to_column_value(self, is_pass, message): return [self.bench_type, self.npu_type, self.shape, self.cosine_sim, self.max_abs_err, self.rel_err_hundredth, diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py index 4425b8437f..6a38256081 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py @@ -7,55 +7,74 @@ from msprobe.pytorch.api_accuracy_checker.standard.basecompare import BaseConfig from msprobe.pytorch.api_accuracy_checker.standard.result import AccResult class AbolutethdConfig(BaseConfig): - rtol = { + _rtol = { torch.float16: 1e-3, torch.bfloat16: 4e-3, - torch.float32: 1e-6 + torch.float32: 1e-6, + "default": 1e-6 # 默认值也放在配置类中 } + # 提供一个公共方法来获取rtol值 + @classmethod + def get_rtol(cls, dtype): + return cls._rtol.get(dtype, cls._rtol["default"]) + + class AbsolutethdCompare(BaseCompare): - def __init__(self, bench_output, device_output, dtype=None): + def __init__(self, bench_output, device_output, compare_column, dtype=None): self.bench_output = bench_output self.device_output = device_output - self.dtype = dtype - - self.standard = AbolutethdConfig() + if not isinstance(bench_output, np.ndarray) or not isinstance(device_output, np.ndarray): + raise TypeError("The input should be numpy array") - def get_threshold(self): - small_value, small_value_atol = super().get_threshold() - - rtol = self.standard.rtol.get(self.dtype, 1e-6) + self.compare_column = compare_column + self.dtype = dtype + self.rtol = self.get_rtol() + self.abs_bench, self.abs_bench_with_eps = self.stat_abs_bench_with_eps() + self.both_finite_mask, self.inf_nan_mask = self.stat_finite_and_infinite_mask() - return small_value, small_value_atol, rtol + def get_rtol(self): + return AbolutethdConfig.get_rtol(self.dtype) - + def get_rel_err(self, abs_bench_with_eps): + abs_err = self.stat_abs_error() + rel_err = abs_err / abs_bench_with_eps + return rel_err + def get_normal_value_mask(self, small_value_mask): + return np.logical_and(self.both_finite_mask, np.logical_not(small_value_mask)) + def compute_with_golden(self): - acc_result = AccResult() - abs_bench, abs_bench_with_eps = self.stat_abs_bench_with_eps() - abs_err = self.stat_abs_error() - rel_err = abs_err / abs_bench_with_eps - both_finite_mask, inf_nan_mask = self.stat_finite_and_infinite_mask() - small_value, small_value_atol, rtol = self.get_threshold() - # print(small_value, small_value_atol, rotl) - small_value_mask = self.stat_small_value_mask(abs_bench, both_finite_mask, small_value) - normal_value_mask = np.logical_and(both_finite_mask, np.logical_not(small_value_mask)) - inf_nan_error_ratio = check_inf_nan_value(inf_nan_mask, self.bench_output, self.device_output, self.dtype, rtol) - rel_err_ratio = check_norm_value(normal_value_mask, rel_err, rtol) - abs_err_ratio = check_small_value(abs_bench, both_finite_mask, small_value_atol) - acc_result.update( + rel_err = self.get_rel_err(self.abs_bench_with_eps) + + small_value, small_value_atol = self.get_small_value_threshold() + + small_value_mask = self.stat_small_value_mask(self.abs_bench, self.both_finite_mask, small_value) + normal_value_mask = self.get_normal_value_mask(small_value_mask) + inf_nan_error_ratio = check_inf_nan_value(self.inf_nan_mask, self.bench_output, self.device_output, self.dtype, + self.rtol) + rel_err_ratio = check_norm_value(normal_value_mask, rel_err, self.rtol) + abs_err_ratio = check_small_value(self.abs_bench, self.both_finite_mask, small_value_atol) + print(abs_err_ratio, rel_err_ratio, inf_nan_error_ratio) + + # self.compare_column.inf_nan_error_ratio = inf_nan_error_ratio + # self.compare_column.rel_err_ratio = rel_err_ratio + # self.compare_column.abs_err_ratio = abs_err_ratio + self.update_acc_result( + self.compare_column, inf_nan_error_ratio=inf_nan_error_ratio, rel_err_ratio=rel_err_ratio, abs_err_ratio=abs_err_ratio ) - return acc_result - -bench_output = torch.rand(1,2) -device_output = torch.rand(1,2) -bench_output = bench_output.cpu().numpy() -device_output = device_output.cpu().numpy() -dtype = torch.float16 -bench = AbsolutethdCompare(bench_output, device_output, dtype) -bench.compute_with_golden() + # return acc_result + +# bench_output = torch.rand(1,2) +# device_output = torch.rand(1,2) +# bench_output = bench_output.cpu().numpy() +# device_output = device_output.cpu().numpy() +# dtype = torch.float16 +# bench = AbsolutethdCompare(bench_output, device_output, dtype) +# result=bench.compute_with_golden() +# print(result.rel_err_ratio) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/basecompare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/basecompare.py index 5629bd1afd..529ac47258 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/basecompare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/basecompare.py @@ -1,30 +1,40 @@ import torch from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_abs_bench_with_eps, get_abs_err, \ get_finite_and_infinite_mask, get_small_value_mask - +from msprobe.pytorch.api_accuracy_checker.standard.result import AccResult class BaseConfig: - small_value = { + _small_value = { torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-6, + "default": 1e-6 } - small_value_atol = { + _small_value_atol = { torch.float16: 1e-5, torch.bfloat16: 1e-5, - torch.float32: 1e-9 + torch.float32: 1e-9, + "default": 1e-9 } + + @classmethod + def get_small_valuel(cls, dtype): + return cls._small_value.get(dtype, cls._small_value["default"]) + + @classmethod + def get_small_value_atol(cls, dtype): + return cls._small_value_atol.get(dtype, cls._small_value_atol["default"]) class BaseCompare: def __init__(self, bench_output, device_output, dtype=None): self.bench_output = bench_output self.device_output = device_output self.dtype = dtype - self.standard = BaseConfig() + - def get_threshold(self): - small_value = self.standard.small_value.get(self.dtype, 1e-6) - small_value_atol = self.standard.small_value_atol.get(self.dtype, 1e-9) + def get_small_value_threshold(self): + small_value = BaseConfig.get_small_valuel(self.dtype) + small_value_atol = BaseConfig.get_small_value_atol(self.dtype) return small_value, small_value_atol def stat_abs_bench_with_eps(self): @@ -43,4 +53,6 @@ class BaseCompare: small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, small_value) return small_value_mask - \ No newline at end of file + def update_acc_result(self, compare_column, **kwargs): + compare_column.update(**kwargs) + # return acc_result \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py index 3ff4587f2e..6cc40a9cff 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py @@ -22,57 +22,60 @@ class BenchmarkStandard(BaseConfig): class BenchmarkCompare(BaseCompare): - def __init__(self, bench_output, device_output, dtype=None): + def __init__(self, bench_output, device_output, compare_column, dtype=None): self.bench_output = bench_output self.device_output = device_output + if not isinstance(bench_output, np.ndarray) or not isinstance(device_output, np.ndarray): + raise TypeError("The input should be numpy array") + self.compare_column = compare_column self.dtype = dtype + self.abs_bench, self.abs_bench_with_eps = self.stat_abs_bench_with_eps() + self.both_finite_mask, self.inf_nan_mask = self.stat_finite_and_infinite_mask() + self.abs_err = self.stat_abs_error() + self.small_value, self.small_value_atol = self.get_small_value_threshold() + self.small_value_mask = self.stat_small_value_mask(self.abs_bench, self.both_finite_mask, self.small_value) - self.standard = BenchmarkStandard() + # self.standard = BenchmarkStandard() - def update(self, **kwargs): + # def update(self, **kwargs): - self.standard.update_threshold(**kwargs) + # self.standard.update_threshold(**kwargs) + def get_abs_err_greater_mask(self, small_value_atol): + abs_err_greater_mask = np.greater(self.abs_err, small_value_atol) + return abs_err_greater_mask + + def compute_rel_err(self): + rel_err = get_rel_err(self.abs_err, self.abs_bench_with_eps, self.small_value_mask, self.inf_nan_mask) + return rel_err, rel_err.size def compute_with_golden(self): - acc_result = AccResult() - small_value, small_value_atol = self.get_threshold() - abs_bench, abs_bench_with_eps = self.stat_abs_bench_with_eps() - both_finite_mask, inf_nan_mask = self.stat_finite_and_infinite_mask() - small_value_mask = self.stat_small_value_mask(abs_bench, both_finite_mask, small_value) + rel_err, _ = self.compute_rel_err() - abs_err = self.stat_abs_error() - - rel_err = get_rel_err(abs_err, abs_bench_with_eps, small_value_mask, inf_nan_mask) - - # if rel_err.size == 0: - # return CompareConst.ERROR, compare_column, "Relative error result list is empty." - if rel_err.size == 0: - acc_result.update(rel_err_size=rel_err.size) - return acc_result - abs_err_greater_mask = np.greater(abs_err, small_value_atol) - small_value_err_ratio = get_small_value_err_ratio(small_value_mask, abs_err_greater_mask) - rmse = get_rmse(abs_err, np.logical_or(inf_nan_mask, small_value_mask)) + abs_err_greater_mask = self.get_abs_err_greater_mask(self.small_value_atol) + small_value_err_ratio = get_small_value_err_ratio(self.small_value_mask, abs_err_greater_mask) + rmse = get_rmse(self.abs_err, np.logical_or(self.inf_nan_mask, self.small_value_mask)) eb = get_error_balance(self.bench_output, self.device_output) max_rel_error = get_max_rel_err(rel_err) mean_rel_error = get_mean_rel_err(rel_err) - acc_result.update( + self.update_acc_result( + self.compare_column, small_value_err_ratio=small_value_err_ratio, max_rel_error=max_rel_error, mean_rel_error=mean_rel_error, rmse=rmse, eb=eb ) - return acc_result -bench_output = torch.rand(1,2) -device_output = torch.rand(1,2) -bench_output = bench_output.cpu().numpy() -device_output = device_output.cpu().numpy() -dtype = torch.float32 -bench = BenchmarkCompare(bench_output, device_output, dtype) -acc_result = bench.compute_with_golden() -print(acc_result.max_rel_error) \ No newline at end of file + +# bench_output = torch.rand(1,2) +# device_output = torch.rand(1,2) +# bench_output = bench_output.cpu().numpy() +# device_output = device_output.cpu().numpy() +# dtype = torch.float32 +# bench = BenchmarkCompare(bench_output, device_output, dtype) +# acc_result = bench.compute_with_golden() +# print(acc_result.max_rel_error) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/binary_thd.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/binary_thd.py index 6f5911cd09..b50ac29890 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/binary_thd.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/binary_thd.py @@ -1,13 +1,20 @@ - +import numpy as np from msprobe.pytorch.api_accuracy_checker.compare.algorithm import compare_bool_tensor +from msprobe.pytorch.api_accuracy_checker.standard.basecompare import BaseCompare -class BinaryCompare: +class BinaryCompare(BaseCompare): - def __init__(self, bench_output, device_output): + def __init__(self, bench_output, device_output, compare_column): self.bench_output = bench_output self.device_output = device_output + if not isinstance(bench_output, np.ndarray) or not isinstance(device_output, np.ndarray): + raise TypeError("The input should be numpy array") + self.compare_column = compare_column def compute_with_golden(self): - err_rate, _, _ = compare_bool_tensor(self.bench_output, self.device_output) - return err_rate \ No newline at end of file + error_rate, _, _ = compare_bool_tensor(self.bench_output, self.device_output) + self.update_acc_result( + self.compare_column, + error_rate=error_rate + ) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/thousand_std.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/thousand_std.py new file mode 100644 index 0000000000..f6c40ba1c0 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/thousand_std.py @@ -0,0 +1,17 @@ +from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_rel_err_ratio +from msprobe.core.common.const import CompareConst +from msprobe.pytorch.api_accuracy_checker.standard.basecompare import BaseCompare + +class ThousandthStdCompare(BaseCompare): + + def __init__(self, rel_err_orign, compare_column): + self.rel_err_orign = rel_err_orign + self.compare_column = compare_column + + + def compute_with_golden(self): + rel_err_thousandth, _ = get_rel_err_ratio(self.rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD) + self.update_acc_result( + self.compare_column, + rel_err_thousandth=rel_err_thousandth, + ) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/ulp_thd.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/ulp_thd.py index fef917af3b..a24cb67dfe 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/ulp_thd.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/ulp_thd.py @@ -2,13 +2,17 @@ import numpy as np import torch from msprobe.pytorch.api_accuracy_checker.standard.result import AccResult +from msprobe.pytorch.api_accuracy_checker.standard.basecompare import BaseCompare from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_ulp_err from msprobe.core.common.const import CompareConst -class UlpCompare: - def __init__(self, bench_output, device_output, dtype=None): +class UlpCompare(BaseCompare): + def __init__(self, bench_output, device_output, compare_column, dtype=None): self.bench_output = bench_output self.device_output = device_output + if not isinstance(bench_output, np.ndarray) or not isinstance(device_output, np.ndarray): + raise TypeError("The input should be numpy array") + self.compare_column = compare_column self.dtype = dtype def stat_max_ulp_err(self, ulp_err): @@ -24,7 +28,6 @@ class UlpCompare: return np.sum(ulp_err > CompareConst.ULP_FLOAT16_THRESHOLD) / self.bench_output.size def compute_with_golden(self): - acc_result = AccResult() ulp_err = get_ulp_err(self.bench_output, self.device_output, self.dtype) @@ -34,9 +37,9 @@ class UlpCompare: ulp_error_proportion = self.stat_ulp_error_proportion(ulp_err) - acc_result.update( + self.update_acc_result( + self.compare_column, max_ulp_error=max_ulp_error, mean_ulp_error=mean_ulp_error, ulp_error_proportion=ulp_error_proportion - ) - return acc_result \ No newline at end of file + ) \ No newline at end of file -- Gitee From 25f477cd047310074bbacbdd8c511176ad3df410 Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 27 Nov 2024 15:20:35 +0800 Subject: [PATCH 4/8] fix --- .../api_accuracy_checker/compare/compare.py | 77 +++---------------- .../standard/absolute_thd.py | 32 +++----- .../standard/basecompare.py | 11 ++- .../standard/benchmark.py | 2 - 4 files changed, 31 insertions(+), 91 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py index 753c53395b..f82f8ee499 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py @@ -320,99 +320,46 @@ class Comparator: compare_column, npu_dtype) return status, compare_column, message - - def _compare_float_tensor(self, api_name, bench_output, device_output, compare_column, dtype): - message = "" - abs_bench, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype) - abs_err = get_abs_err(bench_output, device_output) - rel_err_orign = get_rel_err_origin(abs_err, abs_bench_with_eps) - + def _perform_comparison(self, api_name, data, compare_column, dtype, rel_err_orign): + bench_output, device_output = data[0], data[1] if api_name in thousandth_standard_api: - # thousand_res, thousand_status = get_rel_err_ratio(rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD) - # compare_column.rel_err_thousandth = thousand_res thousand_compare = ThousandthStdCompare(rel_err_orign) thousand_compare.compute_with_golden() if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST: - # both_finite_mask, inf_nan_mask = get_finite_and_infinite_mask(bench_output, device_output) if api_name in binary_standard_api: if bench_output.size == 0: compare_column.error_rate = CompareConst.NAN return CompareConst.ERROR, compare_column, "There is not bench calculation result." binary_compare = BinaryCompare(bench_output, device_output, compare_column) - binary_compare.comptute_with_golden() + binary_compare.compute_with_golden() elif api_name in absolute_standard_api: - # small_value_threshold, small_value_atol, rtol = self._get_absolute_threshold_attribute( - # api_name, str(dtype)) - # rel_err = abs_err / abs_bench_with_eps - # small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, small_value_threshold) - # normal_value_mask = np.logical_and(both_finite_mask, np.logical_not(small_value_mask)) - # compare_column.inf_nan_error_ratio = check_inf_nan_value(inf_nan_mask, bench_output, device_output, - # dtype, rtol) - # compare_column.rel_err_ratio = check_norm_value(normal_value_mask, rel_err, rtol) - # compare_column.abs_err_ratio = check_small_value(abs_err, small_value_mask, small_value_atol) - absolute_compare = AbsolutethdCompare(bench_output, device_output, compare_column, dtype) absolute_compare.compute_with_golden() - # acc_result = absolute_compare.compute_with_golden() - # compare_column.inf_nan_error_ratio = acc_result.inf_nan_error_ratio - # compare_column.rel_err_ratio = acc_result.rel_err_ratio - # compare_column.abs_err_ratio = acc_result.abs_err_ratio - print(compare_column.inf_nan_error_ratio) - print(compare_column.rel_err_ratio) - print(compare_column.abs_err_ratio) elif api_name in ulp_standard_api: if bench_output.size == 0: compare_column.max_ulp_error = 0 compare_column.mean_ulp_error = 0 compare_column.ulp_error_proportion = 0 else: - # ulp_err = get_ulp_err(bench_output, device_output, dtype) - # compare_column.max_ulp_error = np.max(ulp_err) - # compare_column.mean_ulp_error = np.mean(ulp_err) - # if dtype == torch.float32: - # compare_column.ulp_error_proportion = \ - # np.sum(ulp_err > CompareConst.ULP_FLOAT32_THRESHOLD) / bench_output.size - # else: - # compare_column.ulp_error_proportion = \ - # np.sum(ulp_err > CompareConst.ULP_FLOAT16_THRESHOLD) / bench_output.size ulp_compare = UlpCompare(bench_output, device_output, compare_column, dtype) ulp_compare.compute_with_golden() - # compare_column.max_ulp_error = acc_result.max_ulp_error - # compare_column.mean_ulp_error = acc_result.mean_ulp_error - # compare_column.ulp_error_proportion = acc_result.ulp_error_proportion - print("ulp") - print(compare_column.max_ulp_error) - print(compare_column.mean_ulp_error) - print(compare_column.ulp_error_proportion) else: - # dtype_config = precision_configs.get(dtype) - # small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, dtype_config['small_value'][0]) - # abs_err_greater_mask = np.greater(abs_err, dtype_config['small_value_atol'][0]) - # compare_column.small_value_err_ratio = get_small_value_err_ratio(small_value_mask, abs_err_greater_mask) - # rel_err = get_rel_err(abs_err, abs_bench_with_eps, small_value_mask, inf_nan_mask) - # compare_column.rmse = get_rmse(abs_err, np.logical_or(inf_nan_mask, small_value_mask)) - # compare_column.eb = get_error_balance(bench_output, device_output) - # if rel_err.size == 0: - # return CompareConst.ERROR, compare_column, "Relative error result list is empty." - # compare_column.max_rel_error = get_max_rel_err(rel_err) - # compare_column.mean_rel_error = get_mean_rel_err(rel_err) bench_compare = BenchmarkCompare(bench_output, device_output, compare_column, dtype) _, rel_err_size = bench_compare.compute_rel_err() if rel_err_size == 0: return CompareConst.ERROR, compare_column, "Relative error result list is empty." bench_compare.compute_with_golden() - # compare_column.small_value_err_ratio = acc_result.small_value_err_ratio - # compare_column.rmse = acc_result.rmse - # compare_column.eb = acc_result.eb - # compare_column.max_rel_error = acc_result.max_rel_error - # compare_column.mean_rel_error = acc_result.mean_rel_error - print(compare_column.small_value_err_ratio) - print(compare_column.rmse) - print(compare_column.eb) - print(compare_column.max_rel_error) - print(compare_column.mean_rel_error) + + + def _compare_float_tensor(self, api_name, bench_output, device_output, compare_column, dtype): + message = "" + _, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype) + abs_err = get_abs_err(bench_output, device_output) + rel_err_orign = get_rel_err_origin(abs_err, abs_bench_with_eps) + + self._perform_comparison(api_name, [bench_output, device_output], compare_column, dtype, rel_err_orign) cos_res, cos_status, msg = cosine_sim(bench_output, device_output) compare_column.cosine_sim = cos_res diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py index 6a38256081..985143fc58 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py @@ -5,20 +5,7 @@ from msprobe.pytorch.api_accuracy_checker.compare.algorithm import check_inf_nan check_small_value from msprobe.pytorch.api_accuracy_checker.standard.basecompare import BaseConfig, BaseCompare from msprobe.pytorch.api_accuracy_checker.standard.result import AccResult - -class AbolutethdConfig(BaseConfig): - _rtol = { - torch.float16: 1e-3, - torch.bfloat16: 4e-3, - torch.float32: 1e-6, - "default": 1e-6 # 默认值也放在配置类中 - } - - # 提供一个公共方法来获取rtol值 - @classmethod - def get_rtol(cls, dtype): - return cls._rtol.get(dtype, cls._rtol["default"]) - +from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn class AbsolutethdCompare(BaseCompare): @@ -31,17 +18,18 @@ class AbsolutethdCompare(BaseCompare): self.compare_column = compare_column self.dtype = dtype self.rtol = self.get_rtol() + print("rtol:", self.rtol) self.abs_bench, self.abs_bench_with_eps = self.stat_abs_bench_with_eps() self.both_finite_mask, self.inf_nan_mask = self.stat_finite_and_infinite_mask() def get_rtol(self): - return AbolutethdConfig.get_rtol(self.dtype) - + return BaseConfig.get_rtol(self.dtype) + def get_rel_err(self, abs_bench_with_eps): abs_err = self.stat_abs_error() rel_err = abs_err / abs_bench_with_eps return rel_err - + def get_normal_value_mask(self, small_value_mask): return np.logical_and(self.both_finite_mask, np.logical_not(small_value_mask)) @@ -59,22 +47,20 @@ class AbsolutethdCompare(BaseCompare): abs_err_ratio = check_small_value(self.abs_bench, self.both_finite_mask, small_value_atol) print(abs_err_ratio, rel_err_ratio, inf_nan_error_ratio) - # self.compare_column.inf_nan_error_ratio = inf_nan_error_ratio - # self.compare_column.rel_err_ratio = rel_err_ratio - # self.compare_column.abs_err_ratio = abs_err_ratio self.update_acc_result( self.compare_column, inf_nan_error_ratio=inf_nan_error_ratio, rel_err_ratio=rel_err_ratio, abs_err_ratio=abs_err_ratio ) - # return acc_result + +# compare_column = CompareColumn() # bench_output = torch.rand(1,2) # device_output = torch.rand(1,2) # bench_output = bench_output.cpu().numpy() # device_output = device_output.cpu().numpy() # dtype = torch.float16 -# bench = AbsolutethdCompare(bench_output, device_output, dtype) -# result=bench.compute_with_golden() +# bench = AbsolutethdCompare(bench_output, device_output,compare_column, dtype) +# bench.compute_with_golden() # print(result.rel_err_ratio) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/basecompare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/basecompare.py index 529ac47258..1ccdd0eeab 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/basecompare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/basecompare.py @@ -16,6 +16,12 @@ class BaseConfig: torch.float32: 1e-9, "default": 1e-9 } + _rtol = { + torch.float16: 1e-3, + torch.bfloat16: 4e-3, + torch.float32: 1e-6, + "default": 1e-6 # 默认值也放在配置类中 + } @classmethod def get_small_valuel(cls, dtype): @@ -24,6 +30,10 @@ class BaseConfig: @classmethod def get_small_value_atol(cls, dtype): return cls._small_value_atol.get(dtype, cls._small_value_atol["default"]) + + @classmethod + def get_rtol(cls, dtype): + return cls._rtol.get(dtype, cls._rtol["default"]) class BaseCompare: def __init__(self, bench_output, device_output, dtype=None): @@ -55,4 +65,3 @@ class BaseCompare: def update_acc_result(self, compare_column, **kwargs): compare_column.update(**kwargs) - # return acc_result \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py index 6cc40a9cff..e7c7caf55f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py @@ -49,8 +49,6 @@ class BenchmarkCompare(BaseCompare): return rel_err, rel_err.size def compute_with_golden(self): - - rel_err, _ = self.compute_rel_err() abs_err_greater_mask = self.get_abs_err_greater_mask(self.small_value_atol) -- Gitee From 22aee8586a4f3272ad651ea77389e8f6e8095c44 Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 27 Nov 2024 17:28:23 +0800 Subject: [PATCH 5/8] fix --- .../api_accuracy_checker/compare/compare.py | 26 +++++----- .../compare/compare_column.py | 4 +- .../standard/absolute_thd.py | 6 +-- .../standard/basecompare.py | 23 +++++++-- .../standard/benchmark.py | 48 ++++++++----------- .../standard/binary_thd.py | 4 +- .../standard/thousand_std.py | 4 +- .../api_accuracy_checker/standard/ulp_thd.py | 4 +- 8 files changed, 65 insertions(+), 54 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py index f82f8ee499..dbe11ef66b 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py @@ -304,14 +304,13 @@ class Comparator: return CompareConst.ERROR, compare_column, f"Bench out dtype is {bench_output.dtype} but " \ f"npu output dtype is {device_output.dtype}, cannot compare." message = "" + if bench_output.size == 0: + return CompareConst.ERROR, compare_column, "There is not bench calculation result." if bench_output.dtype in [bool, np.uint8, np.int8, np.int16, np.uint16, np.uint32, np.int32, np.int64, np.uint64]: message += f"Compare algorithm is not supported for {bench_output.dtype} data. " \ f"Only judged by Error Rate." - if bench_output.size == 0: - err_rate, status, msg = CompareConst.NAN, CompareConst.ERROR, "There is not bench calculation result." - else: - err_rate, status, msg = compare_bool_tensor(bench_output, device_output) + err_rate, status, msg = compare_bool_tensor(bench_output, device_output) message += msg + "\n" compare_column.error_rate = err_rate return status, compare_column, message @@ -324,7 +323,7 @@ class Comparator: bench_output, device_output = data[0], data[1] if api_name in thousandth_standard_api: thousand_compare = ThousandthStdCompare(rel_err_orign) - thousand_compare.compute_with_golden() + thousand_compare.compare() if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST: if api_name in binary_standard_api: if bench_output.size == 0: @@ -332,10 +331,10 @@ class Comparator: return CompareConst.ERROR, compare_column, "There is not bench calculation result." binary_compare = BinaryCompare(bench_output, device_output, compare_column) - binary_compare.compute_with_golden() + binary_compare.compare() elif api_name in absolute_standard_api: absolute_compare = AbsolutethdCompare(bench_output, device_output, compare_column, dtype) - absolute_compare.compute_with_golden() + absolute_compare.compare() elif api_name in ulp_standard_api: if bench_output.size == 0: compare_column.max_ulp_error = 0 @@ -343,14 +342,17 @@ class Comparator: compare_column.ulp_error_proportion = 0 else: ulp_compare = UlpCompare(bench_output, device_output, compare_column, dtype) - ulp_compare.compute_with_golden() + ulp_compare.compare() else: bench_compare = BenchmarkCompare(bench_output, device_output, compare_column, dtype) - _, rel_err_size = bench_compare.compute_rel_err() + + bench_compare.compare() - if rel_err_size == 0: - return CompareConst.ERROR, compare_column, "Relative error result list is empty." - bench_compare.compute_with_golden() + print(compare_column.small_value_err_ratio) + print(compare_column.max_rel_error) + print(compare_column.mean_rel_error) + print(compare_column.rmse) + print(compare_column.eb) def _compare_float_tensor(self, api_name, bench_output, device_output, compare_column, dtype): diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_column.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_column.py index e5cf843400..87cd5dc84c 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_column.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare_column.py @@ -41,8 +41,8 @@ class CompareColumn: self.mean_ulp_error = CompareConst.SPACE self.ulp_error_proportion = CompareConst.SPACE - def update(self, **kwargs): - for key, value in kwargs.items(): + def update(self, metrics): + for key, value in metrics.items(): if value is None: continue setattr(self, key, value) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py index 985143fc58..3c5810a8e9 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py @@ -33,7 +33,7 @@ class AbsolutethdCompare(BaseCompare): def get_normal_value_mask(self, small_value_mask): return np.logical_and(self.both_finite_mask, np.logical_not(small_value_mask)) - def compute_with_golden(self): + def compute_metrics(self): rel_err = self.get_rel_err(self.abs_bench_with_eps) @@ -47,7 +47,7 @@ class AbsolutethdCompare(BaseCompare): abs_err_ratio = check_small_value(self.abs_bench, self.both_finite_mask, small_value_atol) print(abs_err_ratio, rel_err_ratio, inf_nan_error_ratio) - self.update_acc_result( + self.record_compare_result( self.compare_column, inf_nan_error_ratio=inf_nan_error_ratio, rel_err_ratio=rel_err_ratio, @@ -62,5 +62,5 @@ class AbsolutethdCompare(BaseCompare): # device_output = device_output.cpu().numpy() # dtype = torch.float16 # bench = AbsolutethdCompare(bench_output, device_output,compare_column, dtype) -# bench.compute_with_golden() +# bench.compute_metrics() # print(result.rel_err_ratio) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/basecompare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/basecompare.py index 1ccdd0eeab..416d204ef1 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/basecompare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/basecompare.py @@ -1,4 +1,5 @@ import torch +import numpy as np from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_abs_bench_with_eps, get_abs_err, \ get_finite_and_infinite_mask, get_small_value_mask from msprobe.pytorch.api_accuracy_checker.standard.result import AccResult @@ -36,12 +37,29 @@ class BaseConfig: return cls._rtol.get(dtype, cls._rtol["default"]) class BaseCompare: - def __init__(self, bench_output, device_output, dtype=None): + def __init__(self, bench_output, device_output, compare_column, dtype=None): self.bench_output = bench_output self.device_output = device_output + if not isinstance(bench_output, np.ndarray) or not isinstance(device_output, np.ndarray): + raise TypeError("The input should be numpy array") + self.compare_column = compare_column self.dtype = dtype + def pre_compare(self): + pass + def compare(self): + self.pre_compare() + metrics = self.compute_metrics() + self.post_compare(metrics) + + def compute_metrics(self): + metrics = {} + return metrics + + def post_compare(self, metrics): + self.compare_column.update(metrics) + def get_small_value_threshold(self): small_value = BaseConfig.get_small_valuel(self.dtype) small_value_atol = BaseConfig.get_small_value_atol(self.dtype) @@ -63,5 +81,4 @@ class BaseCompare: small_value_mask = get_small_value_mask(abs_bench, both_finite_mask, small_value) return small_value_mask - def update_acc_result(self, compare_column, **kwargs): - compare_column.update(**kwargs) + diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py index e7c7caf55f..5708bcd5ad 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py @@ -23,23 +23,18 @@ class BenchmarkStandard(BaseConfig): class BenchmarkCompare(BaseCompare): def __init__(self, bench_output, device_output, compare_column, dtype=None): - self.bench_output = bench_output - self.device_output = device_output - if not isinstance(bench_output, np.ndarray) or not isinstance(device_output, np.ndarray): - raise TypeError("The input should be numpy array") - self.compare_column = compare_column - self.dtype = dtype + super(BenchmarkCompare, self).__init__(bench_output, device_output, compare_column, dtype) + + + def pre_compare(self): self.abs_bench, self.abs_bench_with_eps = self.stat_abs_bench_with_eps() self.both_finite_mask, self.inf_nan_mask = self.stat_finite_and_infinite_mask() self.abs_err = self.stat_abs_error() self.small_value, self.small_value_atol = self.get_small_value_threshold() self.small_value_mask = self.stat_small_value_mask(self.abs_bench, self.both_finite_mask, self.small_value) - - # self.standard = BenchmarkStandard() - - # def update(self, **kwargs): - - # self.standard.update_threshold(**kwargs) + self.rel_err, _ = self.compute_rel_err() + self.abs_err_greater_mask = self.get_abs_err_greater_mask(self.small_value_atol) + def get_abs_err_greater_mask(self, small_value_atol): abs_err_greater_mask = np.greater(self.abs_err, small_value_atol) return abs_err_greater_mask @@ -48,25 +43,22 @@ class BenchmarkCompare(BaseCompare): rel_err = get_rel_err(self.abs_err, self.abs_bench_with_eps, self.small_value_mask, self.inf_nan_mask) return rel_err, rel_err.size - def compute_with_golden(self): - rel_err, _ = self.compute_rel_err() + def compute_metrics(self): - abs_err_greater_mask = self.get_abs_err_greater_mask(self.small_value_atol) - small_value_err_ratio = get_small_value_err_ratio(self.small_value_mask, abs_err_greater_mask) + small_value_err_ratio = get_small_value_err_ratio(self.small_value_mask, self.abs_err_greater_mask) rmse = get_rmse(self.abs_err, np.logical_or(self.inf_nan_mask, self.small_value_mask)) eb = get_error_balance(self.bench_output, self.device_output) - max_rel_error = get_max_rel_err(rel_err) - mean_rel_error = get_mean_rel_err(rel_err) - - self.update_acc_result( - self.compare_column, - small_value_err_ratio=small_value_err_ratio, - max_rel_error=max_rel_error, - mean_rel_error=mean_rel_error, - rmse=rmse, - eb=eb - ) + max_rel_error = get_max_rel_err(self.rel_err) + mean_rel_error = get_mean_rel_err(self.rel_err) + metrics = { + "small_value_err_ratio": small_value_err_ratio, + "max_rel_error": max_rel_error, + "mean_rel_error": mean_rel_error, + "rmse": rmse, + "eb": eb + } + return metrics # bench_output = torch.rand(1,2) @@ -75,5 +67,5 @@ class BenchmarkCompare(BaseCompare): # device_output = device_output.cpu().numpy() # dtype = torch.float32 # bench = BenchmarkCompare(bench_output, device_output, dtype) -# acc_result = bench.compute_with_golden() +# acc_result = bench.compute_metrics() # print(acc_result.max_rel_error) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/binary_thd.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/binary_thd.py index b50ac29890..bb53f3e258 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/binary_thd.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/binary_thd.py @@ -12,9 +12,9 @@ class BinaryCompare(BaseCompare): self.compare_column = compare_column - def compute_with_golden(self): + def compute_metrics(self): error_rate, _, _ = compare_bool_tensor(self.bench_output, self.device_output) - self.update_acc_result( + self.record_compare_result( self.compare_column, error_rate=error_rate ) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/thousand_std.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/thousand_std.py index f6c40ba1c0..6c1538c74e 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/thousand_std.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/thousand_std.py @@ -9,9 +9,9 @@ class ThousandthStdCompare(BaseCompare): self.compare_column = compare_column - def compute_with_golden(self): + def compute_metrics(self): rel_err_thousandth, _ = get_rel_err_ratio(self.rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD) - self.update_acc_result( + self.record_compare_result( self.compare_column, rel_err_thousandth=rel_err_thousandth, ) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/ulp_thd.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/ulp_thd.py index a24cb67dfe..cd9660c251 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/ulp_thd.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/ulp_thd.py @@ -27,7 +27,7 @@ class UlpCompare(BaseCompare): else: return np.sum(ulp_err > CompareConst.ULP_FLOAT16_THRESHOLD) / self.bench_output.size - def compute_with_golden(self): + def compute_metrics(self): ulp_err = get_ulp_err(self.bench_output, self.device_output, self.dtype) @@ -37,7 +37,7 @@ class UlpCompare(BaseCompare): ulp_error_proportion = self.stat_ulp_error_proportion(ulp_err) - self.update_acc_result( + self.record_compare_result( self.compare_column, max_ulp_error=max_ulp_error, mean_ulp_error=mean_ulp_error, -- Gitee From 0cbc9015eac95ae1cd64aafe8236ab24afc33dbf Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 27 Nov 2024 17:52:42 +0800 Subject: [PATCH 6/8] fix --- .../api_accuracy_checker/compare/compare.py | 14 +++++-- .../standard/absolute_thd.py | 40 ++++++++----------- .../standard/binary_thd.py | 16 +++----- .../standard/thousand_std.py | 8 ++-- .../api_accuracy_checker/standard/ulp_thd.py | 32 +++++++-------- 5 files changed, 51 insertions(+), 59 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py index dbe11ef66b..19da814eb5 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py @@ -322,9 +322,10 @@ class Comparator: def _perform_comparison(self, api_name, data, compare_column, dtype, rel_err_orign): bench_output, device_output = data[0], data[1] if api_name in thousandth_standard_api: - thousand_compare = ThousandthStdCompare(rel_err_orign) + thousand_compare = ThousandthStdCompare(rel_err_orign, compare_column) thousand_compare.compare() - if str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST: + print(compare_column.rel_err_thousandth) + elif str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST: if api_name in binary_standard_api: if bench_output.size == 0: compare_column.error_rate = CompareConst.NAN @@ -332,9 +333,13 @@ class Comparator: binary_compare = BinaryCompare(bench_output, device_output, compare_column) binary_compare.compare() + print(compare_column.error_rate) elif api_name in absolute_standard_api: absolute_compare = AbsolutethdCompare(bench_output, device_output, compare_column, dtype) absolute_compare.compare() + print(compare_column.inf_nan_error_ratio) + print(compare_column.rel_err_ratio) + print(compare_column.abs_err_ratio) elif api_name in ulp_standard_api: if bench_output.size == 0: compare_column.max_ulp_error = 0 @@ -343,6 +348,9 @@ class Comparator: else: ulp_compare = UlpCompare(bench_output, device_output, compare_column, dtype) ulp_compare.compare() + print(compare_column.max_ulp_error) + print(compare_column.mean_ulp_error) + print(compare_column.ulp_error_proportion) else: bench_compare = BenchmarkCompare(bench_output, device_output, compare_column, dtype) @@ -408,7 +416,7 @@ class Comparator: compare = Comparator(result_csv_path="./result.csv", details_csv_path="./details.csv", is_continue_run_ut=False) -api_name = "woe" +api_name = "matmul" bench_output = torch.rand(1,2) device_output = torch.rand(1,2) bench_output = bench_output.cpu().numpy() diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py index 3c5810a8e9..fbfc8747be 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py @@ -10,18 +10,17 @@ from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareC class AbsolutethdCompare(BaseCompare): def __init__(self, bench_output, device_output, compare_column, dtype=None): - self.bench_output = bench_output - self.device_output = device_output - if not isinstance(bench_output, np.ndarray) or not isinstance(device_output, np.ndarray): - raise TypeError("The input should be numpy array") + super(AbsolutethdCompare, self).__init__(bench_output, device_output, compare_column, dtype) - self.compare_column = compare_column - self.dtype = dtype - self.rtol = self.get_rtol() - print("rtol:", self.rtol) + + def pre_compare(self): self.abs_bench, self.abs_bench_with_eps = self.stat_abs_bench_with_eps() self.both_finite_mask, self.inf_nan_mask = self.stat_finite_and_infinite_mask() - + self.rtol = self.get_rtol() + self.rel_err = self.get_rel_err(self.abs_bench_with_eps) + self.small_value, self.small_value_atol = self.get_small_value_threshold() + self.small_value_mask = self.stat_small_value_mask(self.abs_bench, self.both_finite_mask, self.small_value) + self.normal_value_mask = self.get_normal_value_mask(self.small_value_mask) def get_rtol(self): return BaseConfig.get_rtol(self.dtype) @@ -35,24 +34,17 @@ class AbsolutethdCompare(BaseCompare): def compute_metrics(self): - rel_err = self.get_rel_err(self.abs_bench_with_eps) - - small_value, small_value_atol = self.get_small_value_threshold() - - small_value_mask = self.stat_small_value_mask(self.abs_bench, self.both_finite_mask, small_value) - normal_value_mask = self.get_normal_value_mask(small_value_mask) inf_nan_error_ratio = check_inf_nan_value(self.inf_nan_mask, self.bench_output, self.device_output, self.dtype, self.rtol) - rel_err_ratio = check_norm_value(normal_value_mask, rel_err, self.rtol) - abs_err_ratio = check_small_value(self.abs_bench, self.both_finite_mask, small_value_atol) + rel_err_ratio = check_norm_value(self.normal_value_mask, self.rel_err, self.rtol) + abs_err_ratio = check_small_value(self.abs_bench, self.both_finite_mask, self.small_value_atol) print(abs_err_ratio, rel_err_ratio, inf_nan_error_ratio) - - self.record_compare_result( - self.compare_column, - inf_nan_error_ratio=inf_nan_error_ratio, - rel_err_ratio=rel_err_ratio, - abs_err_ratio=abs_err_ratio - ) + metrics = { + "inf_nan_error_ratio": inf_nan_error_ratio, + "rel_err_ratio": rel_err_ratio, + "abs_err_ratio": abs_err_ratio + } + return metrics # compare_column = CompareColumn() diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/binary_thd.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/binary_thd.py index bb53f3e258..02dbc9e27c 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/binary_thd.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/binary_thd.py @@ -5,16 +5,12 @@ from msprobe.pytorch.api_accuracy_checker.standard.basecompare import BaseCompar class BinaryCompare(BaseCompare): def __init__(self, bench_output, device_output, compare_column): - self.bench_output = bench_output - self.device_output = device_output - if not isinstance(bench_output, np.ndarray) or not isinstance(device_output, np.ndarray): - raise TypeError("The input should be numpy array") - self.compare_column = compare_column - + super(BinaryCompare, self).__init__(bench_output, device_output, compare_column) def compute_metrics(self): error_rate, _, _ = compare_bool_tensor(self.bench_output, self.device_output) - self.record_compare_result( - self.compare_column, - error_rate=error_rate - ) + metrics = { + "error_rate": error_rate + } + return metrics + diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/thousand_std.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/thousand_std.py index 6c1538c74e..ba46e5ac33 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/thousand_std.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/thousand_std.py @@ -11,7 +11,7 @@ class ThousandthStdCompare(BaseCompare): def compute_metrics(self): rel_err_thousandth, _ = get_rel_err_ratio(self.rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD) - self.record_compare_result( - self.compare_column, - rel_err_thousandth=rel_err_thousandth, - ) \ No newline at end of file + metrics = { + 'rel_err_thousandth': rel_err_thousandth, + } + return metrics diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/ulp_thd.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/ulp_thd.py index cd9660c251..4d252d91a0 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/ulp_thd.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/ulp_thd.py @@ -8,12 +8,7 @@ from msprobe.core.common.const import CompareConst class UlpCompare(BaseCompare): def __init__(self, bench_output, device_output, compare_column, dtype=None): - self.bench_output = bench_output - self.device_output = device_output - if not isinstance(bench_output, np.ndarray) or not isinstance(device_output, np.ndarray): - raise TypeError("The input should be numpy array") - self.compare_column = compare_column - self.dtype = dtype + super(UlpCompare, self).__init__(bench_output, device_output, compare_column, dtype) def stat_max_ulp_err(self, ulp_err): return np.max(ulp_err) @@ -27,19 +22,20 @@ class UlpCompare(BaseCompare): else: return np.sum(ulp_err > CompareConst.ULP_FLOAT16_THRESHOLD) / self.bench_output.size + def pre_compare(self): + self.ulp_err = get_ulp_err(self.bench_output, self.device_output, self.dtype) + def compute_metrics(self): + + max_ulp_error = self.stat_max_ulp_err(self.ulp_err) + mean_ulp_error = self.stat_mean_ulp_err(self.ulp_err) - ulp_err = get_ulp_err(self.bench_output, self.device_output, self.dtype) - - max_ulp_error = self.stat_max_ulp_err(ulp_err) - mean_ulp_error = self.stat_mean_ulp_err(ulp_err) - - ulp_error_proportion = self.stat_ulp_error_proportion(ulp_err) + ulp_error_proportion = self.stat_ulp_error_proportion(self.ulp_err) + metrics = { + "max_ulp_error": max_ulp_error, + "mean_ulp_error": mean_ulp_error, + "ulp_error_proportion": ulp_error_proportion + } - self.record_compare_result( - self.compare_column, - max_ulp_error=max_ulp_error, - mean_ulp_error=mean_ulp_error, - ulp_error_proportion=ulp_error_proportion - ) \ No newline at end of file + return metrics -- Gitee From cf49f7690d57587bc80f7722f88374852e9645e6 Mon Sep 17 00:00:00 2001 From: gitee Date: Thu, 28 Nov 2024 17:08:17 +0800 Subject: [PATCH 7/8] fix --- .../api_accuracy_checker/compare/compare.py | 33 +++++++++++++------ .../standard/absolute_thd.py | 7 ++-- .../standard/basecompare.py | 16 ++++----- .../standard/benchmark.py | 5 ++- .../standard/binary_thd.py | 4 +-- .../standard/{result.py => compare_input.py} | 10 ++++++ .../standard/thousand_std.py | 9 +++-- .../api_accuracy_checker/standard/ulp_thd.py | 5 ++- 8 files changed, 54 insertions(+), 35 deletions(-) rename debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/{result.py => compare_input.py} (66%) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py index 19da814eb5..d2416ecfd4 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py @@ -29,6 +29,7 @@ from msprobe.pytorch.api_accuracy_checker.standard.benchmark import BenchmarkCom from msprobe.pytorch.api_accuracy_checker.standard.ulp_thd import UlpCompare from msprobe.pytorch.api_accuracy_checker.standard.binary_thd import BinaryCompare from msprobe.pytorch.api_accuracy_checker.standard.thousand_std import ThousandthStdCompare +from msprobe.pytorch.api_accuracy_checker.standard.compare_input import CompareInput from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, \ get_mean_rel_err, get_rel_err, get_abs_err, get_max_abs_err, get_rel_err_ratio, cosine_sim, get_rel_err_origin, \ get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \ @@ -304,6 +305,7 @@ class Comparator: return CompareConst.ERROR, compare_column, f"Bench out dtype is {bench_output.dtype} but " \ f"npu output dtype is {device_output.dtype}, cannot compare." message = "" + #todo:判断bench和npu的size是否一致 if bench_output.size == 0: return CompareConst.ERROR, compare_column, "There is not bench calculation result." if bench_output.dtype in [bool, np.uint8, np.int8, np.int16, np.uint16, np.uint32, np.int32, @@ -319,23 +321,22 @@ class Comparator: compare_column, npu_dtype) return status, compare_column, message - def _perform_comparison(self, api_name, data, compare_column, dtype, rel_err_orign): - bench_output, device_output = data[0], data[1] + def _perform_comparison(self, api_name, input_data): if api_name in thousandth_standard_api: - thousand_compare = ThousandthStdCompare(rel_err_orign, compare_column) + thousand_compare = ThousandthStdCompare(input_data) thousand_compare.compare() print(compare_column.rel_err_thousandth) elif str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST: if api_name in binary_standard_api: - if bench_output.size == 0: + if input_data.bench_output.size == 0: compare_column.error_rate = CompareConst.NAN return CompareConst.ERROR, compare_column, "There is not bench calculation result." - binary_compare = BinaryCompare(bench_output, device_output, compare_column) + binary_compare = BinaryCompare(input_data) binary_compare.compare() print(compare_column.error_rate) elif api_name in absolute_standard_api: - absolute_compare = AbsolutethdCompare(bench_output, device_output, compare_column, dtype) + absolute_compare = AbsolutethdCompare(input_data) absolute_compare.compare() print(compare_column.inf_nan_error_ratio) print(compare_column.rel_err_ratio) @@ -346,13 +347,13 @@ class Comparator: compare_column.mean_ulp_error = 0 compare_column.ulp_error_proportion = 0 else: - ulp_compare = UlpCompare(bench_output, device_output, compare_column, dtype) + ulp_compare = UlpCompare(input_data) ulp_compare.compare() print(compare_column.max_ulp_error) print(compare_column.mean_ulp_error) print(compare_column.ulp_error_proportion) else: - bench_compare = BenchmarkCompare(bench_output, device_output, compare_column, dtype) + bench_compare = BenchmarkCompare(input_data) bench_compare.compare() @@ -368,8 +369,8 @@ class Comparator: _, abs_bench_with_eps = get_abs_bench_with_eps(bench_output, dtype) abs_err = get_abs_err(bench_output, device_output) rel_err_orign = get_rel_err_origin(abs_err, abs_bench_with_eps) - - self._perform_comparison(api_name, [bench_output, device_output], compare_column, dtype, rel_err_orign) + input_data = CompareInput(bench_output, device_output, compare_column, dtype, rel_err_orign) + self._perform_comparison(api_name, input_data) cos_res, cos_status, msg = cosine_sim(bench_output, device_output) compare_column.cosine_sim = cos_res @@ -423,4 +424,16 @@ bench_output = bench_output.cpu().numpy() device_output = device_output.cpu().numpy() dtype = torch.float16 compare_column = CompareColumn() +compare._compare_float_tensor(api_name, bench_output, device_output, compare_column, dtype) + +api_name = "abs" +compare._compare_float_tensor(api_name, bench_output, device_output, compare_column, dtype) + +api_name = "mul" +compare._compare_float_tensor(api_name, bench_output, device_output, compare_column, dtype) + +api_name = "conv2d" +compare._compare_float_tensor(api_name, bench_output, device_output, compare_column, dtype) + +api_name = "mean" compare._compare_float_tensor(api_name, bench_output, device_output, compare_column, dtype) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py index fbfc8747be..1c204c1bee 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/absolute_thd.py @@ -4,13 +4,12 @@ import numpy as np from msprobe.pytorch.api_accuracy_checker.compare.algorithm import check_inf_nan_value, check_norm_value, \ check_small_value from msprobe.pytorch.api_accuracy_checker.standard.basecompare import BaseConfig, BaseCompare -from msprobe.pytorch.api_accuracy_checker.standard.result import AccResult -from msprobe.pytorch.api_accuracy_checker.compare.compare_column import CompareColumn + class AbsolutethdCompare(BaseCompare): - def __init__(self, bench_output, device_output, compare_column, dtype=None): - super(AbsolutethdCompare, self).__init__(bench_output, device_output, compare_column, dtype) + def __init__(self, input_data): + super(AbsolutethdCompare, self).__init__(input_data) def pre_compare(self): diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/basecompare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/basecompare.py index 416d204ef1..2c3905e969 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/basecompare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/basecompare.py @@ -2,7 +2,8 @@ import torch import numpy as np from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_abs_bench_with_eps, get_abs_err, \ get_finite_and_infinite_mask, get_small_value_mask -from msprobe.pytorch.api_accuracy_checker.standard.result import AccResult + + class BaseConfig: _small_value = { @@ -37,13 +38,12 @@ class BaseConfig: return cls._rtol.get(dtype, cls._rtol["default"]) class BaseCompare: - def __init__(self, bench_output, device_output, compare_column, dtype=None): - self.bench_output = bench_output - self.device_output = device_output - if not isinstance(bench_output, np.ndarray) or not isinstance(device_output, np.ndarray): - raise TypeError("The input should be numpy array") - self.compare_column = compare_column - self.dtype = dtype + def __init__(self, input_data): + self.bench_output = input_data.bench_output + self.device_output = input_data.device_output + + self.compare_column = input_data.compare_column + self.dtype = input_data.dtype def pre_compare(self): pass diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py index 5708bcd5ad..a0658869c6 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/benchmark.py @@ -2,7 +2,6 @@ import torch import numpy as np from msprobe.pytorch.api_accuracy_checker.standard.basecompare import BaseCompare, BaseConfig -from msprobe.pytorch.api_accuracy_checker.standard.result import AccResult from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_small_value_err_ratio, get_rel_err, get_rmse, \ get_error_balance, get_max_rel_err, get_mean_rel_err @@ -22,8 +21,8 @@ class BenchmarkStandard(BaseConfig): class BenchmarkCompare(BaseCompare): - def __init__(self, bench_output, device_output, compare_column, dtype=None): - super(BenchmarkCompare, self).__init__(bench_output, device_output, compare_column, dtype) + def __init__(self, input_data): + super(BenchmarkCompare, self).__init__(input_data) def pre_compare(self): diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/binary_thd.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/binary_thd.py index 02dbc9e27c..b7225b2b7f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/binary_thd.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/binary_thd.py @@ -4,8 +4,8 @@ from msprobe.pytorch.api_accuracy_checker.standard.basecompare import BaseCompar class BinaryCompare(BaseCompare): - def __init__(self, bench_output, device_output, compare_column): - super(BinaryCompare, self).__init__(bench_output, device_output, compare_column) + def __init__(self, input_data): + super(BinaryCompare, self).__init__(input_data) def compute_metrics(self): error_rate, _, _ = compare_bool_tensor(self.bench_output, self.device_output) diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/result.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/compare_input.py similarity index 66% rename from debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/result.py rename to debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/compare_input.py index 9558ddca5f..f835b2d7d8 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/result.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/compare_input.py @@ -1,4 +1,14 @@ +import numpy as np +class CompareInput: + def __init__(self, bench_output, device_output, compare_column, dtype=None, rel_err_orign=None): + self.bench_output = bench_output + self.device_output = device_output + if not isinstance(bench_output, np.ndarray) or not isinstance(device_output, np.ndarray): + raise TypeError("The input should be numpy array") + self.compare_column = compare_column + self.dtype = dtype + self.rel_err_orign = rel_err_orign class AccResult: small_value_err_ratio = None diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/thousand_std.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/thousand_std.py index ba46e5ac33..0b0115454c 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/thousand_std.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/thousand_std.py @@ -4,14 +4,13 @@ from msprobe.pytorch.api_accuracy_checker.standard.basecompare import BaseCompar class ThousandthStdCompare(BaseCompare): - def __init__(self, rel_err_orign, compare_column): - self.rel_err_orign = rel_err_orign - self.compare_column = compare_column - + def __init__(self, input_data): + self.rel_err_orign = input_data.rel_err_orign + self.compare_column = input_data.compare_column def compute_metrics(self): rel_err_thousandth, _ = get_rel_err_ratio(self.rel_err_orign, CompareConst.THOUSAND_RATIO_THRESHOLD) metrics = { - 'rel_err_thousandth': rel_err_thousandth, + 'rel_err_thousandth': rel_err_thousandth } return metrics diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/ulp_thd.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/ulp_thd.py index 4d252d91a0..6890cdc853 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/ulp_thd.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/ulp_thd.py @@ -1,14 +1,13 @@ import numpy as np import torch -from msprobe.pytorch.api_accuracy_checker.standard.result import AccResult from msprobe.pytorch.api_accuracy_checker.standard.basecompare import BaseCompare from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_ulp_err from msprobe.core.common.const import CompareConst class UlpCompare(BaseCompare): - def __init__(self, bench_output, device_output, compare_column, dtype=None): - super(UlpCompare, self).__init__(bench_output, device_output, compare_column, dtype) + def __init__(self, input_data): + super(UlpCompare, self).__init__(input_data) def stat_max_ulp_err(self, ulp_err): return np.max(ulp_err) -- Gitee From 335f43b3c538ed581e526476b2bddeb4f29c4e66 Mon Sep 17 00:00:00 2001 From: gitee Date: Fri, 29 Nov 2024 17:14:08 +0800 Subject: [PATCH 8/8] fix --- .../api_accuracy_checker/compare/compare.py | 123 ++++++++++++------ .../api_accuracy_checker/standard/register.py | 47 +++++++ 2 files changed, 131 insertions(+), 39 deletions(-) create mode 100644 debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/register.py diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py index d2416ecfd4..e52e84825f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/compare/compare.py @@ -30,6 +30,7 @@ from msprobe.pytorch.api_accuracy_checker.standard.ulp_thd import UlpCompare from msprobe.pytorch.api_accuracy_checker.standard.binary_thd import BinaryCompare from msprobe.pytorch.api_accuracy_checker.standard.thousand_std import ThousandthStdCompare from msprobe.pytorch.api_accuracy_checker.standard.compare_input import CompareInput +from msprobe.pytorch.api_accuracy_checker.standard.register import ComparisonRegistry from msprobe.pytorch.api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, \ get_mean_rel_err, get_rel_err, get_abs_err, get_max_abs_err, get_rel_err_ratio, cosine_sim, get_rel_err_origin, \ get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \ @@ -64,6 +65,14 @@ class Comparator: self.detail_save_path_str = details_csv_path self.save_path_list = [result_csv_path] self.detail_save_path_list = [details_csv_path] + + self.registry = ComparisonRegistry() + self.registry.register("absolute_threshold", self._absolute_standard_compare) + self.registry.register("binary_consistency", self._binary_standard_compare) + self.registry.register("ulp_compare", self._ulp_compare) + self.registry.register("thousandth_threshold", self._thousandth_standard_compare) + self.registry.register("benchmark", self._benchmark_compare) + # if config and config.online_config.is_online: # self.save_path_str = result_csv_path.replace(".csv", "_rank{}.csv") @@ -321,47 +330,83 @@ class Comparator: compare_column, npu_dtype) return status, compare_column, message - def _perform_comparison(self, api_name, input_data): - if api_name in thousandth_standard_api: - thousand_compare = ThousandthStdCompare(input_data) - thousand_compare.compare() - print(compare_column.rel_err_thousandth) - elif str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST: - if api_name in binary_standard_api: - if input_data.bench_output.size == 0: - compare_column.error_rate = CompareConst.NAN - return CompareConst.ERROR, compare_column, "There is not bench calculation result." - - binary_compare = BinaryCompare(input_data) - binary_compare.compare() - print(compare_column.error_rate) - elif api_name in absolute_standard_api: - absolute_compare = AbsolutethdCompare(input_data) - absolute_compare.compare() - print(compare_column.inf_nan_error_ratio) - print(compare_column.rel_err_ratio) - print(compare_column.abs_err_ratio) - elif api_name in ulp_standard_api: - if bench_output.size == 0: - compare_column.max_ulp_error = 0 - compare_column.mean_ulp_error = 0 - compare_column.ulp_error_proportion = 0 - else: - ulp_compare = UlpCompare(input_data) - ulp_compare.compare() - print(compare_column.max_ulp_error) - print(compare_column.mean_ulp_error) - print(compare_column.ulp_error_proportion) - else: - bench_compare = BenchmarkCompare(input_data) + def _binary_standard_compare(self, input_data): + binary_compare = BinaryCompare(input_data) + binary_compare.compare() + print(compare_column.error_rate) + + def _thousandth_standard_compare(self, input_data): + thousandth_compare = ThousandthStdCompare(input_data) + thousandth_compare.compare() + print(compare_column.rel_err_thousandth) + + def _absolute_standard_compare(self, input_data): + absolute_compare = AbsolutethdCompare(input_data) + absolute_compare.compare() + print(compare_column.inf_nan_error_ratio) + print(compare_column.rel_err_ratio) + print(compare_column.abs_err_ratio) + + def _ulp_compare(self, input_data): + ulp_compare = UlpCompare(input_data) + ulp_compare.compare() + print(compare_column.max_ulp_error) + print(compare_column.mean_ulp_error) + print(compare_column.ulp_error_proportion) + + def _benchmark_compare(self, input_data): + benchmark_compare = BenchmarkCompare(input_data) + benchmark_compare.compare() + print(compare_column.small_value_err_ratio) + print(compare_column.max_rel_error) + print(compare_column.mean_rel_error) + print(compare_column.rmse) + print(compare_column.eb) - bench_compare.compare() + def _perform_comparison(self, api_name, input_data): + comparison_func = self.registry.get_comparison_function(api_name) + if comparison_func: + comparison_func(input_data) + # if api_name in thousandth_standard_api: + # thousand_compare = ThousandthStdCompare(input_data) + # thousand_compare.compare() + # print(compare_column.rel_err_thousandth) + # elif str(dtype) in BENCHMARK_COMPARE_SUPPORT_LIST: + # if api_name in binary_standard_api: + # if input_data.bench_output.size == 0: + # compare_column.error_rate = CompareConst.NAN + # return CompareConst.ERROR, compare_column, "There is not bench calculation result." + + # binary_compare = BinaryCompare(input_data) + # binary_compare.compare() + # print(compare_column.error_rate) + # elif api_name in absolute_standard_api: + # absolute_compare = AbsolutethdCompare(input_data) + # absolute_compare.compare() + # print(compare_column.inf_nan_error_ratio) + # print(compare_column.rel_err_ratio) + # print(compare_column.abs_err_ratio) + # elif api_name in ulp_standard_api: + # if bench_output.size == 0: + # compare_column.max_ulp_error = 0 + # compare_column.mean_ulp_error = 0 + # compare_column.ulp_error_proportion = 0 + # else: + # ulp_compare = UlpCompare(input_data) + # ulp_compare.compare() + # print(compare_column.max_ulp_error) + # print(compare_column.mean_ulp_error) + # print(compare_column.ulp_error_proportion) + # else: + # bench_compare = BenchmarkCompare(input_data) + + # bench_compare.compare() - print(compare_column.small_value_err_ratio) - print(compare_column.max_rel_error) - print(compare_column.mean_rel_error) - print(compare_column.rmse) - print(compare_column.eb) + # print(compare_column.small_value_err_ratio) + # print(compare_column.max_rel_error) + # print(compare_column.mean_rel_error) + # print(compare_column.rmse) + # print(compare_column.eb) def _compare_float_tensor(self, api_name, bench_output, device_output, compare_column, dtype): diff --git a/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/register.py b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/register.py new file mode 100644 index 0000000000..dc23a34dde --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/api_accuracy_checker/standard/register.py @@ -0,0 +1,47 @@ +from msprobe.pytorch.api_accuracy_checker.compare.compare_utils import absolute_standard_api, binary_standard_api, \ + ulp_standard_api, thousandth_standard_api + +class ComparisonRegistry: + def __init__(self): + self.comparison_functions = {} + self.standard_categories = { + 'absolute_threshold': absolute_standard_api, + 'binary_consistency': binary_standard_api, + 'ulp_compare': ulp_standard_api, + 'thousandth_threshold': thousandth_standard_api + } + + def register(self, standard, func): + self.comparison_functions[standard] = func + + def get_standard_category(self, api_name): + # 遍历字典,确定api_name属于哪个类别 + for name, category in self.standard_categories.items(): + if api_name in category: + return name + return "benchmark" + + + def get_comparison_function(self, api_name): + standard = self.get_standard_category(api_name) + return self.comparison_functions.get(standard) + +# 创建一个比较注册器 +# registry = ComparisonRegistry() + +# # 注册比较函数 +# registry.register('thousandth', record_thousandth_threshold_result) +# registry.register('binary', record_binary_consistency_result) +# registry.register('absolute', record_absolute_threshold_result) +# registry.register('ulp', record_ulp_compare_result) + +# def compare_api(api_name, compare_column, row_npu): +# new_status = None + +# # 获取比较函数 +# comparison_func = registry.get_comparison_function(api_name) + +# if comparison_func: +# new_status = comparison_func(compare_column, row_npu) + +# return new_status \ No newline at end of file -- Gitee