From 805daa5e2f60d06d7275a51ba79855c90df25b87 Mon Sep 17 00:00:00 2001 From: h00613304 Date: Thu, 23 May 2024 17:39:37 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dapi=5Fprecision=5Fcompare?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../compare/api_precision_compare.py | 86 +++++++++---------- 1 file changed, 41 insertions(+), 45 deletions(-) diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py index f7f61a23e..6a544de21 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py @@ -9,42 +9,39 @@ from ..common.utils import print_info_log, print_warn_log, print_error_log, writ CompareException, create_directory from ..common.config import msCheckerConfig from ..compare.compare_utils import CompareConst, API_PRECISION_COMPARE_RESULT_FILE_NAME, \ -API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \ + API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \ ApiPrecisionCompareColumn, AbsoluteStandardApi, BinaryStandardApi, BINARY_COMPARE_UNSUPPORT_LIST, \ convert_str_to_float, CompareMessage from ..compare.compare_column import ApiPrecisionOutputColumn from ..run_ut.run_ut import get_validated_result_csv_path from ...common.file_check import FileCheckConst, FileChecker, change_mode, check_path_before_create - CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path']) unsupported_message = 'This data type does not support benchmark compare.' - benchmark_algorithms_thresholds = { - 'small_value' : { - 'error_threshold' : 2, - 'warning_threshold' : 1 + 'small_value': { + 'error_threshold': 2, + 'warning_threshold': 1 }, - 'rmse' : { - 'error_threshold' : 2, - 'warning_threshold' : 1 + 'rmse': { + 'error_threshold': 2, + 'warning_threshold': 1 }, - 'max_rel_err' : { - 'error_threshold' : 10, - 'warning_threshold' : 1 + 'max_rel_err': { + 'error_threshold': 10, + 'warning_threshold': 1 }, - 'mean_rel_err' : { - 'error_threshold' : 2, - 'warning_threshold' : 1 + 'mean_rel_err': { + 'error_threshold': 2, + 'warning_threshold': 1 }, - 'eb' : { - 'error_threshold' : 2, - 'warning_threshold' : 1 + 'eb': { + 'error_threshold': 2, + 'warning_threshold': 1 } } - benchmark_message = { "small_value_err_status": { CompareConst.ERROR: "ERROR: 小值域错误比值超过阈值\n", @@ -107,18 +104,20 @@ class BenchmarkStandard: self.npu_precision.get(ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE), self.gpu_precision.get(ApiPrecisionCompareColumn.SMALL_VALUE_ERROR_RATE), 10000.0) self.rmse_ratio = self._calc_ratio(self.npu_precision.get(ApiPrecisionCompareColumn.RMSE), - self.gpu_precision.get(ApiPrecisionCompareColumn.RMSE), 10000.0) + self.gpu_precision.get(ApiPrecisionCompareColumn.RMSE), 10000.0) self.max_rel_err_ratio = self._calc_ratio(self.npu_precision.get(ApiPrecisionCompareColumn.MAX_REL_ERR), - self.gpu_precision.get(ApiPrecisionCompareColumn.MAX_REL_ERR), 10000.0) + self.gpu_precision.get(ApiPrecisionCompareColumn.MAX_REL_ERR), + 10000.0) self.mean_rel_err_ratio = self._calc_ratio(self.npu_precision.get(ApiPrecisionCompareColumn.MEAN_REL_ERR), - self.gpu_precision.get(ApiPrecisionCompareColumn.MEAN_REL_ERR), 10000.0) + self.gpu_precision.get(ApiPrecisionCompareColumn.MEAN_REL_ERR), + 10000.0) self.eb_ratio = self._calc_ratio(self.npu_precision.get(ApiPrecisionCompareColumn.EB), - self.gpu_precision.get(ApiPrecisionCompareColumn.EB), 10000.0) + self.gpu_precision.get(ApiPrecisionCompareColumn.EB), 10000.0) def to_column_value(self): - return [self.small_value_err_ratio, self.small_value_err_status, self.rmse_ratio, - self.rmse_status, self.max_rel_err_ratio, self.max_rel_err_status, self.mean_rel_err_ratio, - self.mean_rel_err_status, self.eb_ratio, self.eb_status] + return [self.small_value_err_ratio, self.small_value_err_status, self.rmse_ratio, + self.rmse_status, self.max_rel_err_ratio, self.max_rel_err_status, self.mean_rel_err_ratio, + self.mean_rel_err_status, self.eb_ratio, self.eb_status] @staticmethod def _get_status(ratio, algorithm): @@ -142,7 +141,7 @@ class BenchmarkStandard: def write_detail_csv(content, save_path): rows = [] content = ["{:.{}f}".format(item, msCheckerConfig.precision) \ - if isinstance(item, float) else item for item in content] + if isinstance(item, float) else item for item in content] rows.append(content) write_csv(rows, save_path) @@ -175,13 +174,13 @@ def api_precision_compare(config): def analyse_csv(npu_data, gpu_data, config): forward_status, backward_status = [], [] - full_last_api_name, last_api_dtype = None, None + last_api_name, last_api_dtype = None, None for _, row_npu in npu_data.iterrows(): message = '' compare_column = ApiPrecisionOutputColumn() full_api_name_with_direction_status = row_npu[ApiPrecisionCompareColumn.API_NAME] row_gpu = gpu_data[gpu_data[ApiPrecisionCompareColumn.API_NAME] == full_api_name_with_direction_status] - full_api_name, direction_status, _, _ = full_api_name_with_direction_status.split(".") + _, api_name, _, direction_status, _, _ = full_api_name_with_direction_status.split(".") if row_gpu.empty: print_warn_log(f'This API : {full_api_name_with_direction_status} does not exist in the GPU data.') continue @@ -189,14 +188,14 @@ def analyse_csv(npu_data, gpu_data, config): msg = f'This API : {full_api_name_with_direction_status} has multiple records in the GPU data.' raise CompareException(CompareException.INVALID_DATA_ERROR, msg) row_gpu = row_gpu.iloc[0] - #当前API的输出为空(例如反向过程中requires_grad=False),跳过比对 + # 当前API的输出为空(例如反向过程中requires_grad=False),跳过比对 if row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE].isspace(): continue - _, api_name, _ = full_api_name.split("*") new_status = CompareConst.SPACE compare_column.api_name = full_api_name_with_direction_status - if row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] not in BINARY_COMPARE_UNSUPPORT_LIST or api_name in BinaryStandardApi: - new_status = record_binary_consistency_result(api_name, compare_column, row_npu) + if row_npu[ + ApiPrecisionCompareColumn.DEVICE_DTYPE] not in BINARY_COMPARE_UNSUPPORT_LIST or api_name in BinaryStandardApi: + new_status = record_binary_consistency_result(api_name, compare_column, row_npu) elif api_name in AbsoluteStandardApi: new_status = record_absolute_threshold_result(compare_column, row_npu) elif row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] in BENCHMARK_COMPARE_SUPPORT_LIST: @@ -204,24 +203,23 @@ def analyse_csv(npu_data, gpu_data, config): new_status = record_benchmark_compare_result(compare_column, bs) write_detail_csv(compare_column.to_column_value(), config.details_csv_path) - if full_last_api_name is not None and full_api_name != full_last_api_name: + if last_api_name is not None and api_name != last_api_name: if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST: message = unsupported_message - write_csv([[full_last_api_name, "skip", "skip", message]], config.result_csv_path) + write_csv([[last_api_name, "skip", "skip", message]], config.result_csv_path) forward_status, backward_status = [], [] message = '' else: forward_result = get_api_checker_result(forward_status) backward_result = get_api_checker_result(backward_status) - _, last_api_name, _ = full_last_api_name.split("*") message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else "" - write_csv([[full_last_api_name, forward_result, backward_result, message]], config.result_csv_path) + write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path) forward_status, backward_status = [], [] message = '' - + is_supported = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] not in API_PRECISION_COMPARE_UNSUPPORT_LIST - full_last_api_name = full_api_name - + last_api_name = api_name + last_api_dtype = row_npu[ApiPrecisionCompareColumn.DEVICE_DTYPE] if not is_supported: continue @@ -233,16 +231,15 @@ def analyse_csv(npu_data, gpu_data, config): else: print_error_log(f"Invalid direction status: {direction_status}") - if full_last_api_name is not None: + if last_api_name is not None: if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST: message = unsupported_message - write_csv([[full_last_api_name, "skip", "skip", message]], config.result_csv_path) + write_csv([[last_api_name, "skip", "skip", message]], config.result_csv_path) else: forward_result = get_api_checker_result(forward_status) backward_result = get_api_checker_result(backward_status) - _, last_api_name, _ = full_last_api_name.split("*") message += CompareMessage.get(last_api_name, "") if forward_result == CompareConst.ERROR else "" - write_csv([[full_last_api_name, forward_result, backward_result, message]], config.result_csv_path) + write_csv([[last_api_name, forward_result, backward_result, message]], config.result_csv_path) def check_error_rate(npu_error_rate): @@ -388,4 +385,3 @@ def _api_precision_compare_parser(parser): if __name__ == '__main__': _api_precision_compare() print_info_log("Compare task completed.") - \ No newline at end of file -- Gitee