From d806308f2edf4ab7c7405489875abf70b6871457 Mon Sep 17 00:00:00 2001 From: h00613304 Date: Sat, 5 Aug 2023 15:18:10 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E6=AF=94=E5=AF=B9=E7=BB=93=E6=9E=9Cdtype?= =?UTF-8?q?=E4=B8=8D=E7=9B=B8=E5=90=8C=EF=BC=8CUT=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E4=B8=8D=E9=80=9A=E8=BF=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index a79125d832b..2ab8735a312 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -30,7 +30,7 @@ def get_max_rel_err(n_value, b_value): print_warn_log("Max rel err only support numpy array!") raise ValueError("Max rel err only support numpy array!") if n_value.dtype != b_value.dtype: - raise ValueError("npu and bench value dtype is different.") + return CompareConst.NA, False if n_value.dtype in Const.FLOAT_TYPE: rel_err = np.abs((n_value - b_value) / (b_value + np.finfo(b_value.dtype).eps)).max() return rel_err, rel_err < 0.001 -- Gitee From a770a03d669782072fc5d6d91c9d979fdea2a318 Mon Sep 17 00:00:00 2001 From: h00613304 Date: Sat, 5 Aug 2023 16:07:25 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E9=92=88=E5=AF=B9dropout?= =?UTF-8?q?=E7=B1=BB=E8=BF=90=E7=AE=97=E7=9A=84=E6=AF=94=E5=AF=B9=E6=96=B9?= =?UTF-8?q?=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../api_accuracy_checker/compare/compare.py | 28 +++++++++++++++---- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index b877bc50d4d..edf0dee8864 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -23,7 +23,8 @@ class Comparator: self.compare_alg_names = [] self.register_compare_algorithm("Cosine Similarity", cosine_sim, cosine_standard) self.test_results = [] - self.test_result_cnt = {"forward_fail_num":0, "backward_fail_num":0, "forward_and_backward_fail_num":0, "success_num":0} + self.test_result_cnt = {"forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0, + "success_num": 0} def print_pretest_result(self): res_dict = { @@ -34,7 +35,7 @@ class Comparator: } tb = PrettyTable() tb.add_column("Category", list(res_dict.keys())) - tb.add_column("statistics",list(res_dict.values())) + tb.add_column("statistics", list(res_dict.values())) info_tb = str(tb) print_info_log(info_tb) @@ -62,9 +63,15 @@ class Comparator: self.compare_alg_names.append(name) def compare_output(self, api_name, bench_out, npu_out, bench_grad=None, npu_grad=None): - is_fwd_success, fwd_compare_alg_results = self._compare_core_wrapper(bench_out, npu_out) + if "dropout" in api_name: + is_fwd_success, fwd_compare_alg_results = self._compare_dropout(bench_out, npu_out) + else: + is_fwd_success, fwd_compare_alg_results = self._compare_core_wrapper(bench_out, npu_out) if bench_grad and npu_grad: - is_bwd_success, bwd_compare_alg_results = self._compare_core_wrapper(bench_grad, npu_grad) + if "dropout" in api_name: + is_bwd_success, bwd_compare_alg_results = self._compare_dropout(bench_grad, npu_grad) + else: + is_bwd_success, bwd_compare_alg_results = self._compare_core_wrapper(bench_grad, npu_grad) else: is_bwd_success, bwd_compare_alg_results = CompareConst.NA, None self.record_results(api_name, is_fwd_success, is_bwd_success, fwd_compare_alg_results, bwd_compare_alg_results) @@ -80,4 +87,15 @@ class Comparator: def _compare_core_wrapper(self, bench_out, npu_out): name = self.compare_alg_names[0] detailed_result, test_success = compare_core(bench_out, npu_out, self.compare_alg[name][0]) - return test_success, detailed_result \ No newline at end of file + return test_success, detailed_result + + @staticmethod + def _compare_dropout(bench_out, npu_out): + tensor_num = bench_out.numel() + if tensor_num.numel() >= 100: + if abs((bench_out == 0).sum() - (npu_out == 0).sum()) / tensor_num < 0.1: + return True, 1 + else: + return False, 0 + else: + return True, 1 -- Gitee From 4ce02b5ad8b57fe43ee421fb4233fffd76942da3 Mon Sep 17 00:00:00 2001 From: h00613304 Date: Sat, 5 Aug 2023 16:52:35 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E4=BF=AE=E5=A4=8Ddropout=E7=B1=BB=E8=BF=90?= =?UTF-8?q?=E7=AE=97=E7=9A=84=E6=AF=94=E5=AF=B9=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/api_accuracy_checker/compare/compare.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index edf0dee8864..5cb0777e43f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -69,7 +69,7 @@ class Comparator: is_fwd_success, fwd_compare_alg_results = self._compare_core_wrapper(bench_out, npu_out) if bench_grad and npu_grad: if "dropout" in api_name: - is_bwd_success, bwd_compare_alg_results = self._compare_dropout(bench_grad, npu_grad) + is_bwd_success, bwd_compare_alg_results = self._compare_dropout(bench_grad[0], npu_grad[0]) else: is_bwd_success, bwd_compare_alg_results = self._compare_core_wrapper(bench_grad, npu_grad) else: @@ -92,7 +92,7 @@ class Comparator: @staticmethod def _compare_dropout(bench_out, npu_out): tensor_num = bench_out.numel() - if tensor_num.numel() >= 100: + if tensor_num >= 100: if abs((bench_out == 0).sum() - (npu_out == 0).sum()) / tensor_num < 0.1: return True, 1 else: -- Gitee