diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index a584405c007ec02ca56a92d8cc008eea35dd29db..a15cbe08117219e0d80974baf2bb5a7bc5ff6a46 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -1,6 +1,7 @@ # 进行比对及结果展示 import os -from prettytable import PrettyTable +from rich.table import Table +from rich.console import Console from api_accuracy_checker.compare.algorithm import compare_core, cosine_sim, cosine_standard, get_max_rel_err, get_max_abs_err, \ compare_builtin_type, get_rel_err_ratio_thousandth, get_rel_err_ratio_ten_thousandth from api_accuracy_checker.common.utils import get_json_contents, print_info_log, write_csv @@ -33,25 +34,40 @@ class Comparator: self.register_compare_algorithm("Default: isEqual", compare_builtin_type, None) self.test_result_cnt = { - "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0, "success_num": 0 + "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0, "success_num": 0, + "total_num": 0 } self.result_save_path = result_save_path self.write_csv_title() def print_pretest_result(self): - res_dict = { - "forward_not_pass": self.test_result_cnt['forward_fail_num'], - "backward_not_pass": self.test_result_cnt['backward_fail_num'], - "forward_and_backward_not_pass": self.test_result_cnt['forward_and_backward_fail_num'], - "pass": self.test_result_cnt['success_num'] - } - tb = PrettyTable() - tb.add_column("Category", list(res_dict.keys())) - tb.add_column("statistics", list(res_dict.values())) - info_tb = str(tb) - print_info_log(info_tb) - - + if self.test_result_cnt.get("total_num") != 0: + passing_rate = str(self.test_result_cnt.get("success_num") / self.test_result_cnt.get("total_num")) + else: + passing_rate = "0" + + console = Console() + table_total = Table( + show_header=True, title="Overall Statistics", show_lines=True, width=75 + ) + table_total.add_column("Result") + table_total.add_column("Statistics") + table_total.add_row("[green]Total Pass[/green]", str(self.test_result_cnt.get("success_num"))) + table_total.add_row("[red]Total Fail[/red]", str(self.test_result_cnt.get("forward_and_backward_fail_num"))) + table_total.add_row("Passing Rate", passing_rate) + + table_detail = Table( + show_header=True, title="Detail Statistics", show_lines=True, width=75 + ) + table_detail.add_column("Result") + table_detail.add_column("Statistics") + table_detail.add_row("Only Forward Fail", str(self.test_result_cnt.get("forward_fail_num"))) + table_detail.add_row("Only Backward Fail", str(self.test_result_cnt.get("backward_fail_num"))) + table_detail.add_row( + "Both Forward & Backward Fail", str(self.test_result_cnt.get("forward_and_backward_fail_num"))) + + console.print(table_total) + console.print(table_detail) def write_csv_title(self): summary_test_rows = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS, self.COLUMN_BACKWARD_SUCCESS]] @@ -108,6 +124,7 @@ class Comparator: self.compare_alg.update({name: (compare_func, standard)}) def compare_output(self, api_name, bench_out, npu_out, bench_grad=None, npu_grad=None): + self.test_result_cnt["total_num"] += 1 if "dropout" in api_name: is_fwd_success, fwd_compare_alg_results = self._compare_dropout(bench_out, npu_out) else: