diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/api_precision_compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/api_precision_compare.py index 783334349245e8ed1e4062843a258b905c839296..dd88e37c49c8cea22344e9763e5d856827652a3f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/api_precision_compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/api_precision_compare.py @@ -446,8 +446,8 @@ def _api_precision_compare(parser=None): def _api_precision_compare_command(args): - npu_csv_path = get_validated_result_csv_path(args.npu_csv_path, 'detail') - gpu_csv_path = get_validated_result_csv_path(args.gpu_csv_path, 'detail') + npu_csv_path = get_validated_result_csv_path(args.npu_csv_path, 'detail', None) + gpu_csv_path = get_validated_result_csv_path(args.gpu_csv_path, 'detail', None) out_path = os.path.realpath(args.out_path) if args.out_path else "./" check_path_before_create(out_path) create_directory(out_path) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py index 760e088eb38a26ba01fd25ac579130240a76a9e8..316bcf7cde1ed87901d564a906932440c6c79bbe 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py @@ -163,7 +163,7 @@ def prepare_config(args): print_info_log(f"UT task result will be saved in {result_csv_path}") print_info_log(f"UT task details will be saved in {details_csv_path}") else: - result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result') + result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result', 'accuracy_checking_result') details_csv_path = get_validated_details_csv_path(result_csv_path) print_info_log(f"UT task result will be saved in {result_csv_path}") print_info_log(f"UT task details will be saved in {details_csv_path}") diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index f47c4f4ea5a18988908b6434ed008be51769a7fe..a2fdec2811d44de94a84a128608d0f6e47f69921 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -45,7 +45,7 @@ UT_ERROR_DATA_DIR = 'ut_error_data' + current_time RESULT_FILE_NAME = f"accuracy_checking_result_" + current_time + ".csv" DETAILS_FILE_NAME = f"accuracy_checking_details_" + current_time + ".csv" RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path', - 'save_error_data', 'is_continue_run_ut', 'real_data_path']) + 'save_error_data', 'is_continue_run_ut', 'real_data_path', 'error_api']) not_backward_list = ['repeat_interleave'] not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'} not_raise_dtype_set = {'type_as'} @@ -211,8 +211,8 @@ def run_api_offline(config, compare, api_name_set): continue data_info = run_torch_api(api_full_name, config.real_data_path, config.backward_content, api_info_dict) is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info) - if config.save_error_data: - do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) + if config.save_error_data and api_full_name in config.error_api: + do_save_error_data(api_full_name, data_info) except Exception as err: [_, api_name, _] = api_full_name.split("*") if "expected scalar type Long" in str(err): @@ -258,16 +258,15 @@ def run_api_online(config, compare): dispatcher.update_consume_queue(api_data) -def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success): - if not is_fwd_success or not is_bwd_success: - api_full_name = api_full_name.replace("*", ".") - for element in data_info.in_fwd_data_list: - UtAPIInfo(api_full_name + '.forward.input', element) - UtAPIInfo(api_full_name + '.forward.output.bench', data_info.bench_output) - UtAPIInfo(api_full_name + '.forward.output.device', data_info.device_output) - UtAPIInfo(api_full_name + '.backward.input', data_info.grad_in) - UtAPIInfo(api_full_name + '.backward.output.bench', data_info.bench_grad) - UtAPIInfo(api_full_name + '.backward.output.device', data_info.device_grad) +def do_save_error_data(api_full_name, data_info): + api_full_name = api_full_name.replace("*", ".") + for element in data_info.in_fwd_data_list: + UtAPIInfo(api_full_name + '.forward.input', element) + UtAPIInfo(api_full_name + '.forward.output.bench', data_info.bench_output) + UtAPIInfo(api_full_name + '.forward.output.device', data_info.device_output) + UtAPIInfo(api_full_name + '.backward.input', data_info.grad_in) + UtAPIInfo(api_full_name + '.backward.output.bench', data_info.bench_grad) + UtAPIInfo(api_full_name + '.backward.output.device', data_info.device_grad) def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict): @@ -373,7 +372,7 @@ def initialize_save_error_data(): initialize_save_path(error_data_path, UT_ERROR_DATA_DIR) -def get_validated_result_csv_path(result_csv_path, mode): +def get_validated_result_csv_path(result_csv_path, mode, prefix): if mode not in ['result', 'detail']: raise ValueError("The csv mode must be result or detail") result_csv_path_checker = FileChecker(result_csv_path, FileCheckConst.FILE, ability=FileCheckConst.READ_WRITE_ABLE, @@ -381,9 +380,12 @@ def get_validated_result_csv_path(result_csv_path, mode): validated_result_csv_path = result_csv_path_checker.common_check() if mode == 'result': result_csv_name = os.path.basename(validated_result_csv_path) - pattern = r"^accuracy_checking_result_\d{14}\.csv$" + pattern = rf"^{prefix}_\d{{14}}\.csv$" if not re.match(pattern, result_csv_name): - raise ValueError("When continue run ut, please do not modify the result csv name.") + if prefix == 'accuracy_checking_result': + raise ValueError("When continue run ut, please do not modify the result csv name.") + if prefix == 'api_precision_compare_result': + raise ValueError("When save error data, please do not modify the result csv name.") return validated_result_csv_path @@ -414,8 +416,10 @@ def _run_ut_parser(parser): parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, help=" The ut task result out path.", required=False) - parser.add_argument('-save_error_data', dest="save_error_data", action="store_true", - help=" Save compare failed api output.", required=False) + parser.add_argument('-save_error_data', dest="save_error_data", default="", type=str, + help=" The path of api_precision_compare_result_{timestamp}.csv, " + "when need to save error data, enter the path to save error data.", + required=False) parser.add_argument("-j", "--jit_compile", dest="jit_compile", action="store_true", help=" whether to turn on jit compile", required=False) @@ -507,7 +511,7 @@ def run_ut_command(args): create_directory(out_path) out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) out_path = out_path_checker.common_check() - save_error_data = args.save_error_data + save_error_data = True if args.save_error_data else False forward_content = {} if args.forward_input_file: check_link(args.forward_input_file) @@ -526,8 +530,9 @@ def run_ut_command(args): backward_content = get_json_contents(backward_file) result_csv_path = os.path.join(out_path, RESULT_FILE_NAME) details_csv_path = os.path.join(out_path, DETAILS_FILE_NAME) + error_api = [] if args.result_csv_path: - result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result') + result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result', 'accuracy_checking_result') details_csv_path = get_validated_details_csv_path(result_csv_path) if save_error_data: if args.result_csv_path: @@ -535,8 +540,26 @@ def run_ut_command(args): global UT_ERROR_DATA_DIR UT_ERROR_DATA_DIR = 'ut_error_data' + time_info initialize_save_error_data() + error_data_csv_path = get_validated_result_csv_path(args.save_error_data, 'result', 'api_precision_compare_result') + with FileOpen(error_data_csv_path, 'r') as file: + reader = csv.reader(file) + result_csv_rows = [row for row in reader] + result_csv_name = os.path.basename(error_data_csv_path) + checklist = [CompareConst.PASS, CompareConst.ERROR, CompareConst.WARNING, CompareConst.SPACE, CompareConst.SKIP, "skip"] + for item in result_csv_rows[1:]: + if not isinstance(item, list) or len(item) < 3: + raise ValueError("The number of columns in %s is incorrect" % result_csv_name) + if not all(item[i] and item[i] in checklist for i in (1, 2)): + raise ValueError( + "The value in the 2nd or 3rd column of %s is wrong, it must be pass, error, warning, skip, or SPACE" + % result_csv_name) + api_name = item[0] + column1 = item[1] + column2 = item[2] + if column1 == CompareConst.ERROR or column2 == CompareConst.ERROR: + error_api.append(api_name) run_ut_config = RunUTConfig(forward_content, backward_content, result_csv_path, details_csv_path, save_error_data, - args.result_csv_path, args.real_data_path) + args.result_csv_path, args.real_data_path, error_api) run_ut(run_ut_config)