From 1f9dd137740d43a8ac4d0daba22bd8924cd666ae Mon Sep 17 00:00:00 2001 From: gitee Date: Sun, 28 Apr 2024 16:55:16 +0800 Subject: [PATCH] fix --- .../compare/api_precision_compare.py | 4 +- .../run_ut/multi_run_ut.py | 2 +- .../api_accuracy_checker/run_ut/run_ut.py | 65 +++++++++++++------ 3 files changed, 47 insertions(+), 24 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/api_precision_compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/api_precision_compare.py index a68cdce4a2..a2801b55a4 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/api_precision_compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/api_precision_compare.py @@ -427,8 +427,8 @@ def _api_precision_compare(parser=None): def _api_precision_compare_command(args): - npu_csv_path = get_validated_result_csv_path(args.npu_csv_path, 'detail') - gpu_csv_path = get_validated_result_csv_path(args.gpu_csv_path, 'detail') + npu_csv_path = get_validated_result_csv_path(args.npu_csv_path, 'detail', None) + gpu_csv_path = get_validated_result_csv_path(args.gpu_csv_path, 'detail', None) out_path = os.path.realpath(args.out_path) if args.out_path else "./" check_path_before_create(out_path) create_directory(out_path) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py index 760e088eb3..316bcf7cde 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py @@ -163,7 +163,7 @@ def prepare_config(args): print_info_log(f"UT task result will be saved in {result_csv_path}") print_info_log(f"UT task details will be saved in {details_csv_path}") else: - result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result') + result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result', 'accuracy_checking_result') details_csv_path = get_validated_details_csv_path(result_csv_path) print_info_log(f"UT task result will be saved in {result_csv_path}") print_info_log(f"UT task details will be saved in {details_csv_path}") diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 05bd4305a3..e3df4cced2 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -39,7 +39,7 @@ UT_ERROR_DATA_DIR = 'ut_error_data' + current_time RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv" DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv" RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path', - 'save_error_data', 'is_continue_run_ut', 'real_data_path']) + 'save_error_data', 'is_continue_run_ut', 'real_data_path', 'error_api']) not_backward_list = ['repeat_interleave'] not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'} not_raise_dtype_set = {'type_as'} @@ -176,8 +176,8 @@ def run_ut(config): data_info = run_torch_api(api_full_name, config.real_data_path, config.backward_content, api_info_dict) is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info) - if config.save_error_data: - do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) + if config.save_error_data and api_full_name in config.error_api: + do_save_error_data(api_full_name, data_info) except Exception as err: [_, api_name, _] = api_full_name.split("*") if "expected scalar type Long" in str(err): @@ -199,16 +199,15 @@ def run_ut(config): compare.print_pretest_result() -def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success): - if not is_fwd_success or not is_bwd_success: - api_full_name = api_full_name.replace("*", ".") - for element in data_info.in_fwd_data_list: - UtAPIInfo(api_full_name + '.forward.input', element) - UtAPIInfo(api_full_name + '.forward.output.bench', data_info.bench_out) - UtAPIInfo(api_full_name + '.forward.output.device', data_info.device_out) - UtAPIInfo(api_full_name + '.backward.input', data_info.grad_in) - UtAPIInfo(api_full_name + '.backward.output.bench', data_info.bench_grad_out) - UtAPIInfo(api_full_name + '.backward.output.device', data_info.device_grad_out) +def do_save_error_data(api_full_name, data_info): + api_full_name = api_full_name.replace("*", ".") + for element in data_info.in_fwd_data_list: + UtAPIInfo(api_full_name + '.forward.input', element) + UtAPIInfo(api_full_name + '.forward.output.bench', data_info.bench_output) + UtAPIInfo(api_full_name + '.forward.output.device', data_info.device_output) + UtAPIInfo(api_full_name + '.backward.input', data_info.grad_in) + UtAPIInfo(api_full_name + '.backward.output.bench', data_info.bench_grad) + UtAPIInfo(api_full_name + '.backward.output.device', data_info.device_grad) def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict): @@ -297,7 +296,7 @@ def initialize_save_error_data(): initialize_save_path(error_data_path, UT_ERROR_DATA_DIR) -def get_validated_result_csv_path(result_csv_path, mode): +def get_validated_result_csv_path(result_csv_path, mode, prefix): if mode not in ['result', 'detail']: raise ValueError("The csv mode must be result or detail") result_csv_path_checker = FileChecker(result_csv_path, FileCheckConst.FILE, ability=FileCheckConst.READ_WRITE_ABLE, @@ -305,9 +304,12 @@ def get_validated_result_csv_path(result_csv_path, mode): validated_result_csv_path = result_csv_path_checker.common_check() if mode == 'result': result_csv_name = os.path.basename(validated_result_csv_path) - pattern = r"^accuracy_checking_result_\d{14}\.csv$" + pattern = rf"^{prefix}_\d{{14}}\.csv$" if not re.match(pattern, result_csv_name): - raise ValueError("When continue run ut, please do not modify the result csv name.") + if prefix == 'accuracy_checking_result': + raise ValueError("When continue run ut, please do not modify the result csv name.") + if prefix == 'api_precision_compare_result': + raise ValueError("When save error data, please do not modify the result csv name.") return validated_result_csv_path @@ -333,8 +335,10 @@ def _run_ut_parser(parser): parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, help=" The ut task result out path.", required=False) - parser.add_argument('-save_error_data', dest="save_error_data", action="store_true", - help=" Save compare failed api output.", required=False) + parser.add_argument('-save_error_data', dest="save_error_data", default="", type=str, + help=" The path of api_precision_compare_result_{timestamp}.csv, " + "when need to save error data, enter the path to save error data.", + required=False) parser.add_argument("-j", "--jit_compile", dest="jit_compile", action="store_true", help=" whether to turn on jit compile", required=False) @@ -420,7 +424,7 @@ def run_ut_command(args): create_directory(out_path) out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) out_path = out_path_checker.common_check() - save_error_data = args.save_error_data + save_error_data = True if args.save_error_data else False forward_content = get_json_contents(forward_file) if args.filter_api: forward_content = preprocess_forward_content(forward_content) @@ -432,8 +436,9 @@ def run_ut_command(args): backward_content = get_json_contents(backward_file) result_csv_path = os.path.join(out_path, RESULT_FILE_NAME) details_csv_path = os.path.join(out_path, DETAILS_FILE_NAME) + error_api = [] if args.result_csv_path: - result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result') + result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result', 'accuracy_checking_result') details_csv_path = get_validated_details_csv_path(result_csv_path) if save_error_data: if args.result_csv_path: @@ -441,8 +446,26 @@ def run_ut_command(args): global UT_ERROR_DATA_DIR UT_ERROR_DATA_DIR = 'ut_error_data' + time_info initialize_save_error_data() + error_data_csv_path = get_validated_result_csv_path(args.save_error_data, 'result', 'api_precision_compare_result') + with FileOpen(error_data_csv_path, 'r') as file: + reader = csv.reader(file) + result_csv_rows = [row for row in reader] + result_csv_name = os.path.basename(error_data_csv_path) + checklist = [CompareConst.PASS, CompareConst.ERROR, CompareConst.WARNING, CompareConst.SPACE, CompareConst.SKIP, "skip"] + for item in result_csv_rows[1:]: + if not isinstance(item, list) or len(item) < 3: + raise ValueError("The number of columns in %s is incorrect" % result_csv_name) + if not all(item[i] and item[i] in checklist for i in (1, 2)): + raise ValueError( + "The value in the 2nd or 3rd column of %s is wrong, it must be pass, error, warning, skip, or SPACE" + % result_csv_name) + api_name = item[0] + column1 = item[1] + column2 = item[2] + if column1 == CompareConst.ERROR or column2 == CompareConst.ERROR: + error_api.append(api_name) run_ut_config = RunUTConfig(forward_content, backward_content, result_csv_path, details_csv_path, save_error_data, - args.result_csv_path, args.real_data_path) + args.result_csv_path, args.real_data_path, error_api) run_ut(run_ut_config) -- Gitee