From 536104b45c39c48ad521d752bad02cf15af7eed3 Mon Sep 17 00:00:00 2001 From: louyujing Date: Mon, 13 Nov 2023 03:16:07 +0000 Subject: [PATCH 1/4] update debug/accuracy_tools/api_accuracy_checker/compare/compare.py. Signed-off-by: louyujing --- .../api_accuracy_checker/compare/compare.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index 525efd8553..1ff8bd9726 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -148,13 +148,13 @@ class Comparator: else: is_bwd_success, bwd_compare_alg_results = self._compare_core_wrapper(bench_grad, npu_grad) else: - is_bwd_success, bwd_compare_alg_results = True, None + is_bwd_success, bwd_compare_alg_results = CompareConst.NA, None self.record_results(api_name, is_fwd_success, is_bwd_success, fwd_compare_alg_results, bwd_compare_alg_results) - if is_fwd_success and is_bwd_success: + if is_fwd_success != CompareConst.FALSE and is_bwd_success != CompareConst.FALSE: self.test_result_cnt['success_num'] += 1 - elif not is_fwd_success and not is_bwd_success: + elif is_fwd_success == CompareConst.FALSE and is_bwd_success == CompareConst.FALSE: self.test_result_cnt['forward_and_backward_fail_num'] += 1 - elif not is_fwd_success: + elif is_fwd_success == CompareConst.FALSE: self.test_result_cnt['forward_fail_num'] += 1 self.test_result_cnt['forward_or_backward_fail_num'] += 1 else: @@ -212,7 +212,8 @@ class Comparator: except IndexError as error: print_error_log(f"There is index error.\n{str(error)}") raise CompareException(CompareException.INVALID_DATA_ERROR) from error - test_final_success = False if CompareConst.ERROR in test_all_result or CompareConst.WARNING in test_all_result else True + test_final_success = CompareConst.FALSE \ + if CompareConst.ERROR in test_all_result or CompareConst.WARNING in test_all_result else CompareConst.TRUE return test_final_success, detailed_result_total @staticmethod @@ -220,8 +221,8 @@ class Comparator: tensor_num = bench_out.numel() if tensor_num >= 100: if abs((bench_out == 0).sum() - (npu_out == 0).cpu().sum()) / tensor_num < 0.1: - return True, 1 + return CompareConst.TRUE, 1 else: - return False, 0 + return CompareConst.FALSE, 0 else: - return True, 1 + return CompareConst.TRUE, 1 -- Gitee From 2a91b54709355e3fe6a7534469df3451910da310 Mon Sep 17 00:00:00 2001 From: louyujing Date: Mon, 13 Nov 2023 03:18:15 +0000 Subject: [PATCH 2/4] update debug/accuracy_tools/api_accuracy_checker/compare/compare_utils.py. Signed-off-by: louyujing --- .../api_accuracy_checker/compare/compare_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare_utils.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare_utils.py index 0bb80fbce9..93001414e3 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare_utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare_utils.py @@ -8,6 +8,8 @@ class CompareConst: PASS = 'pass' WARNING = 'warning' ERROR = 'error' + TRUE = 'True' + FALSE = 'False' def check_dtype_comparable(x, y): -- Gitee From f364880673b3bbcc2a070f1458b4d57fa5e37ece Mon Sep 17 00:00:00 2001 From: louyujing Date: Mon, 13 Nov 2023 09:06:07 +0000 Subject: [PATCH 3/4] update debug/accuracy_tools/api_accuracy_checker/compare/compare.py. Signed-off-by: louyujing --- .../api_accuracy_checker/compare/compare.py | 25 +++++++++++-------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index 1ff8bd9726..d5de8ce021 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -148,13 +148,18 @@ class Comparator: else: is_bwd_success, bwd_compare_alg_results = self._compare_core_wrapper(bench_grad, npu_grad) else: - is_bwd_success, bwd_compare_alg_results = CompareConst.NA, None - self.record_results(api_name, is_fwd_success, is_bwd_success, fwd_compare_alg_results, bwd_compare_alg_results) - if is_fwd_success != CompareConst.FALSE and is_bwd_success != CompareConst.FALSE: + is_bwd_success, bwd_compare_alg_results = True, None + if is_bwd_success and bwd_compare_alg_results is None: + self.record_results(api_name, is_fwd_success, CompareConst.NA, fwd_compare_alg_results, + bwd_compare_alg_results) + else: + self.record_results(api_name, is_fwd_success, is_bwd_success, fwd_compare_alg_results, + bwd_compare_alg_results) + if is_fwd_success and is_bwd_success: self.test_result_cnt['success_num'] += 1 - elif is_fwd_success == CompareConst.FALSE and is_bwd_success == CompareConst.FALSE: + elif not is_fwd_success and not is_bwd_success: self.test_result_cnt['forward_and_backward_fail_num'] += 1 - elif is_fwd_success == CompareConst.FALSE: + elif not is_fwd_success: self.test_result_cnt['forward_fail_num'] += 1 self.test_result_cnt['forward_or_backward_fail_num'] += 1 else: @@ -212,8 +217,8 @@ class Comparator: except IndexError as error: print_error_log(f"There is index error.\n{str(error)}") raise CompareException(CompareException.INVALID_DATA_ERROR) from error - test_final_success = CompareConst.FALSE \ - if CompareConst.ERROR in test_all_result or CompareConst.WARNING in test_all_result else CompareConst.TRUE + test_final_success = False if CompareConst.ERROR in test_all_result or CompareConst.WARNING in test_all_result \ + else True return test_final_success, detailed_result_total @staticmethod @@ -221,8 +226,8 @@ class Comparator: tensor_num = bench_out.numel() if tensor_num >= 100: if abs((bench_out == 0).sum() - (npu_out == 0).cpu().sum()) / tensor_num < 0.1: - return CompareConst.TRUE, 1 + return True, 1 else: - return CompareConst.FALSE, 0 + return False, 0 else: - return CompareConst.TRUE, 1 + return True, 1 -- Gitee From 24137776f212da51c569b5cbc48eea09b6bc4e50 Mon Sep 17 00:00:00 2001 From: louyujing Date: Mon, 13 Nov 2023 09:07:01 +0000 Subject: [PATCH 4/4] update debug/accuracy_tools/api_accuracy_checker/compare/compare_utils.py. Signed-off-by: louyujing --- .../api_accuracy_checker/compare/compare_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare_utils.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare_utils.py index 93001414e3..0bb80fbce9 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare_utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare_utils.py @@ -8,8 +8,6 @@ class CompareConst: PASS = 'pass' WARNING = 'warning' ERROR = 'error' - TRUE = 'True' - FALSE = 'False' def check_dtype_comparable(x, y): -- Gitee