diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index 4d9db1aeccd6875472b14ec5e1fa4cd4a3488530..a6241ea9bad3cd4a1c73b054b49ebe83f8e16e0b 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -187,8 +187,9 @@ def read_json(file): return obj def write_csv(data, filepath): - data_frame = pd.DataFrame(columns=data) - data_frame.to_csv(filepath, index=False) + with open(filepath, 'a') as f: + writer = csv.writer(f) + writer.writerows(data) def _print_log(level, msg): current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(time.time()))) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index c5ae3729947ed75eaefe439b3d9a2fd6c243ef1b..f57a8a96c1a475c308031bf276579dc60747774c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -10,6 +10,7 @@ from api_accuracy_checker.compare.compare_utils import CompareConst class Comparator: TEST_FILE_NAME = "pretest_result.csv" DETAIL_TEST_FILE_NAME = "pretest_details.csv" + # consts for result csv COLUMN_API_NAME = "API name" COLUMN_FORWARD_SUCCESS = "Forward Test Success" @@ -29,11 +30,12 @@ class Comparator: self.register_compare_algorithm("Thousandth Relative Error Ratio", get_rel_err_ratio_thousandth, None) self.register_compare_algorithm("Ten Thousandth Relative Error Ratio", get_rel_err_ratio_ten_thousandth, None) self.register_compare_algorithm("Default: isEqual", compare_builtin_type, None) - self.test_results = [] + self.test_result_cnt = { "forward_fail_num": 0, "backward_fail_num": 0, "forward_and_backward_fail_num": 0, "success_num": 0 } self.result_save_path = result_save_path + self.write_csv_title() def print_pretest_result(self): res_dict = { @@ -48,25 +50,13 @@ class Comparator: info_tb = str(tb) print_info_log(info_tb) - def write_compare_csv(self): - self.write_summary_csv() - self.write_detail_csv() + - def write_summary_csv(self): - test_rows = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS, self.COLUMN_BACKWARD_SUCCESS]] - if self.stack_info: - test_rows[0].append(self.COLUMN_STACK_INFO) - for result in self.test_results: - name = result[0] - df_row = list(result[:3]) - if self.stack_info: - stack_info = "\n".join(self.stack_info[name]) - df_row.append(stack_info) - test_rows.append(df_row) - write_csv(test_rows, self.save_path) + def write_csv_title(self): + summary_test_rows = [[self.COLUMN_API_NAME, self.COLUMN_FORWARD_SUCCESS, self.COLUMN_BACKWARD_SUCCESS]] + write_csv(summary_test_rows, self.save_path) - def write_detail_csv(self): - test_rows = [[ + detail_test_rows = [[ "Subject", "Bench Dtype", "NPU Dtype", "Cosine Similarity", "Cosine Similarity Message", "Max Rel Error", "Max Rel Err Message", @@ -75,23 +65,42 @@ class Comparator: "Compare Builtin Type", "Builtin Type Message", "Pass" ]] - for test_result in self.test_results: - subject_prefix = test_result[0] - fwd_result = test_result[3] - bwd_result = test_result[4] - if isinstance(fwd_result, list): - for i, test_subject in enumerate(fwd_result): - subject = subject_prefix + ".forward.output." + str(i) - test_rows.append([subject] + list(test_subject)) - if isinstance(bwd_result, list): - for i, test_subject in enumerate(bwd_result): - subject = subject_prefix + ".backward.output." + str(i) - test_rows.append([subject] + list(test_subject)) + write_csv(detail_test_rows, self.detail_save_path) + + def write_summary_csv(self, test_result): + test_rows = [] + if self.stack_info: + test_rows[0].append(self.COLUMN_STACK_INFO) + + name = test_result[0] + df_row = list(test_result[:3]) + if self.stack_info: + stack_info = "\n".join(self.stack_info[name]) + df_row.append(stack_info) + test_rows.append(df_row) + write_csv(test_rows, self.save_path) + + def write_detail_csv(self, test_result): + test_rows = [] + + subject_prefix = test_result[0] + fwd_result = test_result[3] + bwd_result = test_result[4] + if isinstance(fwd_result, list): + for i, test_subject in enumerate(fwd_result): + subject = subject_prefix + ".forward.output." + str(i) + test_rows.append([subject] + list(test_subject)) + if isinstance(bwd_result, list): + for i, test_subject in enumerate(bwd_result): + subject = subject_prefix + ".backward.output." + str(i) + test_rows.append([subject] + list(test_subject)) write_csv(test_rows, self.detail_save_path) def record_results(self, *args): - self.test_results.append(args) + self.write_summary_csv(args) + self.write_detail_csv(args) + def register_compare_algorithm(self, name, compare_func, standard): self.compare_alg.update({name: (compare_func, standard)}) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index e5737b511068814d85ffb3d3eb5654288a46c93f..7fcb0dbd96fa0c82f46f6c28b123092a750e8f77 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -86,7 +86,6 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): print_error_log(f"Run {api_full_name} UT Error: %s" % str(err)) compare.print_pretest_result() - compare.write_compare_csv() def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict):