diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index 07dd4e6bfca90336b4d9c3c4f1deaed9aded17b6..2bf26d73552a84dd7b50e7dc62273f6b25079a4e 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -19,12 +19,15 @@ class Config: 'dump_step': int, 'error_data_path': str, 'enable_dataloader': bool, - 'target_iter': int + 'target_iter': int, + 'precision': int } if not isinstance(value, validators[key]): raise ValueError(f"{key} must be {validators[key].__name__} type") if key == 'target_iter' and value < 0: raise ValueError("target_iter must be greater than 0") + if key == 'precision' and value < 0: + raise ValueError("precision must be greater than 0") return value def __getattr__(self, item): diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index 1dad82f9d6d0a9570df7cffa005759c982d65956..37c04ad1014eab1ec11bdd2bc0bf41cfa4c44db2 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -5,7 +5,7 @@ from api_accuracy_checker.compare.algorithm import compare_core, cosine_sim, cos compare_builtin_type, get_rel_err_ratio_thousandth, get_rel_err_ratio_ten_thousandth from api_accuracy_checker.common.utils import get_json_contents, print_info_log, write_csv from api_accuracy_checker.compare.compare_utils import CompareConst - +from api_accuracy_checker.common.config import msCheckerConfig class Comparator: TEST_FILE_NAME = "accuracy_checking_result.csv" @@ -95,10 +95,12 @@ class Comparator: if isinstance(fwd_result, list): for i, test_subject in enumerate(fwd_result): subject = subject_prefix + ".forward.output." + str(i) + test_subject = ["{:.{}f}".format(item, msCheckerConfig.precision) if isinstance(item, float) else item for item in test_subject] test_rows.append([subject] + list(test_subject)) if isinstance(bwd_result, list): for i, test_subject in enumerate(bwd_result): subject = subject_prefix + ".backward.output." + str(i) + test_subject = ["{:.{}f}".format(item, msCheckerConfig.precision) if isinstance(item, float) else item for item in test_subject] test_rows.append([subject] + list(test_subject)) write_csv(test_rows, self.detail_save_path) diff --git a/debug/accuracy_tools/api_accuracy_checker/config.yaml b/debug/accuracy_tools/api_accuracy_checker/config.yaml index 46f0ed8d41af82c43154263d940678321f57b814..143629d92ca1e3f4b0324a37b1f48c91d085131a 100644 --- a/debug/accuracy_tools/api_accuracy_checker/config.yaml +++ b/debug/accuracy_tools/api_accuracy_checker/config.yaml @@ -6,4 +6,6 @@ real_data: False dump_step: 1000 error_data_path: './' enable_dataloader: True -target_iter: 1 \ No newline at end of file +target_iter: 1 +precision: 14 + \ No newline at end of file