From 060ddb4e187cc88f7d2710383d5eae4ef4bcdc60 Mon Sep 17 00:00:00 2001 From: l30036321 Date: Tue, 12 Sep 2023 14:27:13 +0800 Subject: [PATCH 1/4] Optimization result printing --- .../api_accuracy_checker/compare/compare.py | 47 +++++++++++++------ 1 file changed, 32 insertions(+), 15 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index a584405c00..b59901b997 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -1,6 +1,7 @@ # 进行比对及结果展示 import os -from prettytable import PrettyTable +from rich.table import Column, Table +from rich.console import Console from api_accuracy_checker.compare.algorithm import compare_core, cosine_sim, cosine_standard, get_max_rel_err, get_max_abs_err, \ compare_builtin_type, get_rel_err_ratio_thousandth, get_rel_err_ratio_ten_thousandth from api_accuracy_checker.common.utils import get_json_contents, print_info_log, write_csv @@ -33,25 +34,40 @@ class Comparator: self.register_compare_algorithm("Default: isEqual", compare_builtin_type, None) self.test_result_cnt = { - "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0, "success_num": 0 + "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0, "success_num": 0, + "total_num": 0 } self.result_save_path = result_save_path self.write_csv_title() def print_pretest_result(self): - res_dict = { - "forward_not_pass": self.test_result_cnt['forward_fail_num'], - "backward_not_pass": self.test_result_cnt['backward_fail_num'], - "forward_and_backward_not_pass": self.test_result_cnt['forward_and_backward_fail_num'], - "pass": self.test_result_cnt['success_num'] - } - tb = PrettyTable() - tb.add_column("Category", list(res_dict.keys())) - tb.add_column("statistics", list(res_dict.values())) - info_tb = str(tb) - print_info_log(info_tb) - - + if self.test_result_cnt.get("total_num") != 0: + passing_rate = str(self.test_result_cnt.get("success_num") / self.test_result_cnt.get("total_num")) + else: + passing_rate = "0" + + console = Console() + table_total = Table( + show_header=True, header_style="bold magenta", title_style="Overall Statistics", show_lines=True, width=75 + ) + table_total.add_column("Result") + table_total.add_column("Statistics") + table_total.add_row("[green]Total Pass[/green]", str(self.test_result_cnt.get("success_num"))) + table_total.add_row("[red]Total Fail[/red]", str(self.test_result_cnt.get("forward_and_backward_fail_num"))) + table_total.add_row("Passing Rate", passing_rate) + + table_detail = Table( + show_header=True, header_style="bold magenta", title_style="Detail Statistics", show_lines=True, width=75 + ) + table_detail.add_column("Result") + table_detail.add_column("Statistics") + table_detail.add_column("Forward Fail", str(self.test_result_cnt.get("forward_fail_num"))) + table_detail.add_column("Backward Fail", str(self.test_result_cnt.get("backward_not_pass"))) + table_detail.add_column( + "Forward & Backward Fail", str(self.test_result_cnt.get("forward_and_backward_not_pass"))) + + console.print(table_total) + console.print(table_detail) def write_csv_title(self): summary_test_rows = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS, self.COLUMN_BACKWARD_SUCCESS]] @@ -108,6 +124,7 @@ class Comparator: self.compare_alg.update({name: (compare_func, standard)}) def compare_output(self, api_name, bench_out, npu_out, bench_grad=None, npu_grad=None): + self.test_result_cnt["total_num"] += 1 if "dropout" in api_name: is_fwd_success, fwd_compare_alg_results = self._compare_dropout(bench_out, npu_out) else: -- Gitee From 9e1bf49d889e4905b12811d182427a3665581f87 Mon Sep 17 00:00:00 2001 From: l30036321 Date: Tue, 12 Sep 2023 15:21:03 +0800 Subject: [PATCH 2/4] Optimization result printing --- .../api_accuracy_checker/compare/compare.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index b59901b997..378657710a 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -48,7 +48,7 @@ class Comparator: console = Console() table_total = Table( - show_header=True, header_style="bold magenta", title_style="Overall Statistics", show_lines=True, width=75 + show_header=True, title="Overall Statistics", show_lines=True, width=75 ) table_total.add_column("Result") table_total.add_column("Statistics") @@ -57,14 +57,14 @@ class Comparator: table_total.add_row("Passing Rate", passing_rate) table_detail = Table( - show_header=True, header_style="bold magenta", title_style="Detail Statistics", show_lines=True, width=75 + show_header=True, title="Detail Statistics", show_lines=True, width=75 ) table_detail.add_column("Result") table_detail.add_column("Statistics") - table_detail.add_column("Forward Fail", str(self.test_result_cnt.get("forward_fail_num"))) - table_detail.add_column("Backward Fail", str(self.test_result_cnt.get("backward_not_pass"))) - table_detail.add_column( - "Forward & Backward Fail", str(self.test_result_cnt.get("forward_and_backward_not_pass"))) + table_detail.add_row("Forward Fail", str(self.test_result_cnt.get("forward_fail_num"))) + table_detail.add_row("Backward Fail", str(self.test_result_cnt.get("backward_fail_num"))) + table_detail.add_row( + "Forward & Backward Fail", str(self.test_result_cnt.get("forward_and_backward_fail_num"))) console.print(table_total) console.print(table_detail) -- Gitee From 56e3c2d36029150e66530f4d49cdd4aaad37a19d Mon Sep 17 00:00:00 2001 From: l30036321 Date: Tue, 12 Sep 2023 16:27:09 +0800 Subject: [PATCH 3/4] Optimization result printing --- .../accuracy_tools/api_accuracy_checker/compare/compare.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index 378657710a..97f1005fe7 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -61,10 +61,10 @@ class Comparator: ) table_detail.add_column("Result") table_detail.add_column("Statistics") - table_detail.add_row("Forward Fail", str(self.test_result_cnt.get("forward_fail_num"))) - table_detail.add_row("Backward Fail", str(self.test_result_cnt.get("backward_fail_num"))) + table_detail.add_row("Only Forward Fail", str(self.test_result_cnt.get("forward_fail_num"))) + table_detail.add_row("Only Backward Fail", str(self.test_result_cnt.get("backward_fail_num"))) table_detail.add_row( - "Forward & Backward Fail", str(self.test_result_cnt.get("forward_and_backward_fail_num"))) + "Both Forward & Backward Fail", str(self.test_result_cnt.get("forward_and_backward_fail_num"))) console.print(table_total) console.print(table_detail) -- Gitee From 1a9afe516f2955b55fbb950383d70313dffb0082 Mon Sep 17 00:00:00 2001 From: l30036321 Date: Tue, 12 Sep 2023 20:36:29 +0800 Subject: [PATCH 4/4] Optimization result printing --- debug/accuracy_tools/api_accuracy_checker/compare/compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index 97f1005fe7..a15cbe0811 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -1,6 +1,6 @@ # 进行比对及结果展示 import os -from rich.table import Column, Table +from rich.table import Table from rich.console import Console from api_accuracy_checker.compare.algorithm import compare_core, cosine_sim, cosine_standard, get_max_rel_err, get_max_abs_err, \ compare_builtin_type, get_rel_err_ratio_thousandth, get_rel_err_ratio_ten_thousandth -- Gitee