diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py
index 957de66ac475a28028f7a3f1348e57d848c6ac5e..d9c94b78404c28a51f4a4c81aa1632b25aecbaef 100644
--- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py
+++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py
@@ -40,10 +40,11 @@ class Comparator:
 
     def print_pretest_result(self):
         self.get_statistics_from_result_csv()
-        if self.test_result_cnt.get("total_num") != 0:
-            passing_rate = str(self.test_result_cnt.get("success_num") / self.test_result_cnt.get("total_num"))
+        total_tests = self.test_result_cnt.get("total_num", 0)
+        if total_tests != 0:
+            passing_rate = "{:.2%}".format(self.test_result_cnt.get("success_num", 0) / total_tests)
         else:
-            passing_rate = "0"
+            passing_rate = "0%"
 
         console = Console()
         table_total = Table(
@@ -51,26 +52,31 @@ class Comparator:
         )
         table_total.add_column("Result")
         table_total.add_column("Statistics")
-        table_total.add_row("[green]Pass[/green]", str(self.test_result_cnt.get("success_num")))
-        table_total.add_row("[red]Fail[/red]", str(self.test_result_cnt.get("forward_and_backward_fail_num") +
-                                                   self.test_result_cnt.get("forward_or_backward_fail_num")))
+        table_total.add_row("[green]Pass[/green]", str(self.test_result_cnt.get("success_num", 0)))
+        table_total.add_row("[yellow]Warning[/yellow]", str(self.test_result_cnt.get("warning_num", 0)))
+        table_total.add_row("[red]Error[/red]", str(self.test_result_cnt.get("error_num", 0)))
         table_total.add_row("Passing Rate", passing_rate)
+        table_total.add_row("Skip Tests", str(self.test_result_cnt.get("total_skip_num", 0)))
 
         table_detail = Table(
             show_header=True, title="Detail Statistics", show_lines=True, width=75
         )
         table_detail.add_column("Result")
         table_detail.add_column("Statistics")
-        table_detail.add_row("Only Forward Fail", str(self.test_result_cnt.get("forward_fail_num")))
-        table_detail.add_row("Only Backward Fail", str(self.test_result_cnt.get("backward_fail_num")))
-        table_detail.add_row(
-            "Both Forward & Backward Fail", str(self.test_result_cnt.get("forward_and_backward_fail_num")))
+        table_detail.add_row("Forward Error", str(self.test_result_cnt.get("forward_fail_num", 0)))
+        table_detail.add_row("Backward Error", str(self.test_result_cnt.get("backward_fail_num", 0)))
+        table_detail.add_row("Both Forward & Backward Error", str(self.test_result_cnt.get("forward_and_backward_fail_num", 0)))
 
         console.print(table_total)
         console.print(table_detail)
 
     def get_statistics_from_result_csv(self):
-        checklist = [CompareConst.TRUE, CompareConst.FALSE, CompareConst.NA, CompareConst.SKIP]
+        checklist = [CompareConst.PASS, CompareConst.ERROR, CompareConst.WARNING, CompareConst.NA, CompareConst.SKIP, "skip"]
+        self.test_result_cnt = {
+            "success_num": 0, "warning_num": 0, "error_num": 0,
+            "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0,
+            "total_num": 0, "total_skip_num": 0
+        }
         with FileOpen(self.save_path, 'r') as file:
             reader = csv.reader(file)
             result_csv_rows = [row for row in reader]
@@ -78,25 +84,29 @@ class Comparator:
         for item in result_csv_rows[1:]:
             if not isinstance(item, list) or len(item) < 3:
                 raise ValueError("The number of columns in %s is incorrect" % result_csv_name)
-            if not all(item[i] and item[i].upper() in checklist for i in (1, 2)):
+            if not all(item[i] and item[i] in checklist for i in (1, 2)):
                 raise ValueError(
-                    "The value in the 2nd or 3rd column of %s is wrong, it must be TRUE, FALSE, SKIP or N/A"
+                    "The value in the 2nd or 3rd column of %s is wrong, it must be pass, error, warning, skip, or N/A"
                     % result_csv_name)
-            column1 = item[1].upper()
-            column2 = item[2].upper()
-            if column1 == CompareConst.SKIP:
+            column1 = item[1]
+            column2 = item[2]
+            if column1.upper() == CompareConst.SKIP:
+                self.test_result_cnt["total_skip_num"] += 1
                 continue
             self.test_result_cnt["total_num"] += 1
-            if column1 == CompareConst.TRUE and column2 in [CompareConst.TRUE, 'N/A']:
+            if column1 == CompareConst.PASS and column2 in [CompareConst.PASS, CompareConst.NA]:
                 self.test_result_cnt['success_num'] += 1
-            elif column1 == CompareConst.FALSE and column2 == CompareConst.FALSE:
+            elif column1 == CompareConst.ERROR and column2 == CompareConst.ERROR:
                 self.test_result_cnt['forward_and_backward_fail_num'] += 1
-            elif column1 == CompareConst.FALSE:
+                self.test_result_cnt['error_num'] += 1
+            elif column1 == CompareConst.ERROR:
                 self.test_result_cnt['forward_fail_num'] += 1
-                self.test_result_cnt['forward_or_backward_fail_num'] += 1
-            else:
+                self.test_result_cnt['error_num'] += 1
+            elif column2 == CompareConst.ERROR:
                 self.test_result_cnt['backward_fail_num'] += 1
-                self.test_result_cnt['forward_or_backward_fail_num'] += 1
+                self.test_result_cnt['error_num'] += 1
+            elif column1 == CompareConst.WARNING or column2 == CompareConst.WARNING:
+                self.test_result_cnt['warning_num'] += 1
 
     def write_csv_title(self):
         summary_test_rows = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS,
@@ -146,28 +156,15 @@ class Comparator:
         self.write_detail_csv(args)
 
     def compare_output(self, api_name, bench_output, device_output, bench_grad=None, npu_grad=None):
-        if "dropout" in api_name:
-            is_fwd_success, fwd_compare_alg_results = self._compare_dropout(bench_output, device_output)
-        else:
-            is_fwd_success, fwd_compare_alg_results = self._compare_core_wrapper(bench_output, device_output)
-        if bench_grad and npu_grad:
-            if "dropout" in api_name:
-                is_bwd_success, bwd_compare_alg_results = self._compare_dropout(bench_grad[0], npu_grad[0])
-            else:
-                is_bwd_success, bwd_compare_alg_results = self._compare_core_wrapper(bench_grad, npu_grad)
-        else:
-            is_bwd_success, bwd_compare_alg_results = True, None
-        if is_bwd_success and bwd_compare_alg_results is None:
-            self.record_results(api_name, is_fwd_success, CompareConst.NA, fwd_compare_alg_results,
-                                bwd_compare_alg_results)
-        else:
-            self.record_results(api_name, is_fwd_success, is_bwd_success, fwd_compare_alg_results,
-                                bwd_compare_alg_results)
-        return is_fwd_success, is_bwd_success
+        compare_func = self._compare_dropout if "dropout" in api_name else self._compare_core_wrapper
+        fwd_success_status, fwd_compare_alg_results = compare_func(bench_output, device_output)
+        bwd_success_status, bwd_compare_alg_results = (CompareConst.PASS, []) if not (bench_grad and npu_grad) else compare_func(bench_grad[0], npu_grad[0]) if "dropout" in api_name else compare_func(bench_grad, npu_grad)
+        self.record_results(api_name, fwd_success_status, bwd_success_status if bwd_compare_alg_results is not None else CompareConst.NA, fwd_compare_alg_results, bwd_compare_alg_results)
+        return fwd_success_status == CompareConst.PASS, bwd_success_status == CompareConst.PASS
 
     def _compare_core_wrapper(self, bench_output, device_output):
         detailed_result_total = []
-        test_final_success = True
+        test_final_success = CompareConst.PASS
         if isinstance(bench_output, (list, tuple)):
             status, compare_result, message = [], [], []
             if len(bench_output) != len(device_output):
@@ -183,13 +180,17 @@ class Comparator:
             status, compare_result, message = self._compare_core(bench_output, device_output)
         if not isinstance(status, list):
             detailed_result_total.append(compare_result.to_column_value(status, message))
-            if status in [CompareConst.ERROR, CompareConst.WARNING]:
-                test_final_success = False
+            if status == CompareConst.ERROR:
+                test_final_success = CompareConst.ERROR
+            elif status == CompareConst.WARNING:
+                test_final_success = CompareConst.WARNING
         else:
             for item, item_status in enumerate(status):
                 detailed_result_total.append(compare_result[item].to_column_value(item_status, message[item]))
-                if item_status in [CompareConst.ERROR, CompareConst.WARNING]:
-                    test_final_success = False
+                if item_status == CompareConst.ERROR:
+                    test_final_success = CompareConst.ERROR
+                elif item_status == CompareConst.WARNING:
+                    test_final_success = CompareConst.WARNING
         return test_final_success, detailed_result_total
 
     def _compare_core(self, bench_output, device_output):
@@ -257,11 +258,11 @@ class Comparator:
         tensor_num = bench_output.numel()
         if tensor_num >= 100:
             if abs((bench_output == 0).sum() - (device_output == 0).cpu().sum()) / tensor_num < 0.1:
-                return True, 1
+                return CompareConst.PASS, 1
             else:
-                return False, 0
+                return CompareConst.ERROR, 0
         else:
-            return True, 1
+            return CompareConst.PASS, 1
 
     @staticmethod
     def _compare_builtin_type(bench_output, device_output, compare_column):
diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py
index 085f2f931677dbebce8c4800679e071c77881a60..038ed1e9f7ee693ecd611145628f786fcc5dac0a 100644
--- a/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py
+++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py
@@ -157,7 +157,7 @@ def prepare_config(args):
         print_info_log(f"UT task result will be saved in {result_csv_path}")
         print_info_log(f"UT task details will be saved in {details_csv_path}")
     else:
-        result_csv_path = get_validated_result_csv_path(args.result_csv_path)
+        result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result')
         details_csv_path = get_validated_details_csv_path(result_csv_path)
         print_info_log(f"UT task result will be saved in {result_csv_path}")
         print_info_log(f"UT task details will be saved in {details_csv_path}")
diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_multi_run_ut.py b/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_multi_run_ut.py
index 95683cda5e448c2f65d15d711a271af07ef17d31..18293a4bc1fc899191bde35252034962f8312f3c 100644
--- a/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_multi_run_ut.py
+++ b/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_multi_run_ut.py
@@ -20,7 +20,7 @@ class TestMultiRunUT(unittest.TestCase):
     def test_split_json_file(self, mock_FileOpen):
         mock_FileOpen.return_value.__enter__.return_value = mock_open(read_data=self.test_json_content).return_value
         num_splits = 2
-        split_files, total_items = split_json_file(self.test_json_file, num_splits)
+        split_files, total_items = split_json_file(self.test_json_file, num_splits, False)
         self.assertEqual(len(split_files), num_splits)
         self.assertEqual(total_items, len(self.test_data))
 
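Illustrative note (not part of the patch): the reworked get_statistics_from_result_csv() resets test_result_cnt, walks the result CSV once, counts SKIP rows separately so they no longer enter total_num, and buckets the remaining rows into pass / error / warning; print_pretest_result() then formats the passing rate with "{:.2%}". A minimal Python sketch of that counting, using placeholder status strings and a single collapsed error counter instead of the CompareConst-based, per-direction counters in the patch:

    # Hypothetical sample rows: (api name, forward status, backward status).
    rows = [
        ["Torch.relu.0", "pass", "pass"],     # both directions pass
        ["Torch.add.1", "error", "pass"],     # forward error
        ["Torch.mul.2", "pass", "warning"],   # warning only
        ["Torch.conv.3", "SKIP", "N/A"],      # skipped, excluded from the total
    ]
    cnt = {"total": 0, "skip": 0, "pass": 0, "error": 0, "warning": 0}
    for _, fwd, bwd in rows:
        if fwd.upper() == "SKIP":             # skipped rows are tallied but not rated
            cnt["skip"] += 1
            continue
        cnt["total"] += 1
        if fwd == "pass" and bwd in ("pass", "N/A"):
            cnt["pass"] += 1
        elif "error" in (fwd, bwd):           # an error in either direction
            cnt["error"] += 1
        elif "warning" in (fwd, bwd):         # no error, but at least one warning
            cnt["warning"] += 1
    passing_rate = "{:.2%}".format(cnt["pass"] / cnt["total"]) if cnt["total"] else "0%"
    print(cnt, passing_rate)  # {'total': 3, 'skip': 1, 'pass': 1, 'error': 1, 'warning': 1} 33.33%

Because skipped rows are excluded from total_num, the reported passing rate reflects only the APIs that were actually compared.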