diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py
index ff0c4917c7ef2014dec3c113ec64fd7ebdd7716a..cecf81a0c75fbdc3b17253d4d3598d00a8453f78 100644
--- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py
+++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py
@@ -5,6 +5,7 @@ import sys
 import torch_npu
 import yaml
 import torch
+from multiprocessing import Pool, cpu_count
 from tqdm import tqdm
 from api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args
 from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, api_info_preprocess, \
@@ -90,19 +91,25 @@ def generate_cpu_params(input_args, input_kwargs, need_backward):
     cpu_kwargs = {key: recursive_arg_to_cpu(value) for key, value in input_kwargs.items()}
     return cpu_args, cpu_kwargs
 
+
 def run_ut(forward_file, backward_file, out_path, save_error_data):
     print_info_log("start UT test")
     forward_content = get_json_contents(forward_file)
     backward_content = get_json_contents(backward_file)
     api_setting_dict = get_json_contents("torch_ut_setting.json")
     compare = Comparator(out_path)
+    processes = (cpu_count() + 1) // 2  # use about half the cores for CPU baseline workers; the rest serve the NPU-side main process
+    pool = Pool(processes)  # NOTE(review): workers are forked after torch_npu is imported -- confirm fork safety on this platform
+
     for api_full_name, api_info_dict in tqdm(forward_content.items()):
         try:
-            data_info = run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict)
+            cpu_async = pool.apply_async(run_torch_api_cpu, args=(api_full_name, api_setting_dict, backward_content, api_info_dict))
+            data_info = run_torch_api_npu(api_full_name, api_setting_dict, backward_content, api_info_dict)  # NOTE(review): bench fields are None here, so do_save_error_data below no longer saves CPU outputs -- confirm acceptable
+            cpu_data_info = cpu_async.get()  # AsyncResult.get() re-raises worker exceptions; a bare Manager-queue get() would hang forever if the CPU task died
             is_fwd_success, is_bwd_success = compare.compare_output(api_full_name,
-                                                                   data_info.bench_out,
+                                                                   cpu_data_info.bench_out,
                                                                    data_info.npu_out,
-                                                                   data_info.bench_grad_out,
+                                                                   cpu_data_info.bench_grad_out,
                                                                    data_info.npu_grad_out)
             if save_error_data:
                 do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success)
@@ -115,6 +122,8 @@ def run_ut(forward_file, backward_file, out_path, save_error_data):
             print_error_log(f"Run {api_full_name} UT Error: %s" % str(err))
             compare.write_summary_csv((api_full_name, "SKIP", "SKIP", str(err)))
     compare.print_pretest_result()
+    pool.close()
+    pool.join()
 
 
 def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success):
@@ -133,4 +142,5 @@ def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success)
-def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict):
+def run_torch_api_cpu(api_full_name, api_setting_dict, backward_content, api_info_dict):
+    """CPU-side half of the old run_torch_api: forward (and optional backward) on CPU only; runs in a pool worker."""
     in_fwd_data_list = []
     [api_type, api_name, _] = api_full_name.split("*")
     args, kwargs, need_grad = get_api_info(api_info_dict, api_name)
@@ -141,12 +151,10 @@ def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict):
     if not need_grad:
         print_warn_log("%s involves in-place operations, skip backward" % api_full_name)
     cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, need_backward)
-    npu_args, npu_kwargs = generate_npu_params(args, kwargs, need_backward)
     grad_out, npu_grad_out = None, None
     if kwargs.get("device"):
         del kwargs["device"]
     out = exec_api(api_type, api_name, cpu_args, cpu_kwargs)
-    npu_out = exec_api(api_type, api_name, npu_args, npu_kwargs)
     grad_input_index = api_setting_dict.get(api_name)
     grad_index = None
     grad = None
@@ -154,12 +162,43 @@ def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict):
         grad_index = grad_input_index.get('grad_index')
 
     if need_backward:
-        grad_out, npu_grad_out, grad, npu_grad = run_backward(api_full_name, cpu_args, backward_content, grad_index, npu_args,
-                                                              npu_out, out)
+        grad_out, _, grad, _ = run_backward(api_full_name, cpu_args, backward_content, grad_index, cpu_args, out, out)  # NOTE(review): passing out for both slots may call backward twice on one graph -- confirm run_backward clones or retains it
     if grad_index is not None:
-        return UtDataInfo(grad_out, npu_grad_out, npu_out[grad_index], out[grad_index], grad, in_fwd_data_list)
-    return UtDataInfo(grad_out, npu_grad_out, npu_out, out, grad, in_fwd_data_list)
+        cpu_data_info = UtDataInfo(grad_out, None, out[grad_index], out[grad_index], grad, in_fwd_data_list)
+    else:
+        cpu_data_info = UtDataInfo(grad_out, None, out, out, grad, in_fwd_data_list)
+    return cpu_data_info  # returned through the pool's AsyncResult so worker exceptions surface in the parent process
+
+
+def run_torch_api_npu(api_full_name, api_setting_dict, backward_content, api_info_dict):
+    """NPU-side half of the old run_torch_api: forward (and optional backward) on NPU only; runs in the main process."""
+    in_fwd_data_list = []
+    [api_type, api_name, _] = api_full_name.split("*")
+    args, kwargs, need_grad = get_api_info(api_info_dict, api_name)
+    in_fwd_data_list.append(args)
+    in_fwd_data_list.append(kwargs)
+    need_backward = api_full_name in backward_content and api_name[-1] != "_"
+    need_backward = need_backward and need_grad
+    if not need_grad:
+        print_warn_log("%s involves in-place operations, skip backward" % api_full_name)
+    npu_args, npu_kwargs = generate_npu_params(args, kwargs, need_backward)
+    grad_out, npu_grad_out = None, None
+    if kwargs.get("device"):
+        del kwargs["device"]
+    npu_out = exec_api(api_type, api_name, npu_args, npu_kwargs)
+    grad_input_index = api_setting_dict.get(api_name)
+    grad_index = None
+    npu_grad = None
+    if grad_input_index is not None:
+        grad_index = grad_input_index.get('grad_index')
+    if need_backward:
+        _, npu_grad_out, _, npu_grad = run_backward(api_full_name, npu_args, backward_content, grad_index, npu_args, npu_out, npu_out)
+    if grad_index is not None:
+        npu_data_info = UtDataInfo(None, npu_grad_out, npu_out[grad_index], None, npu_grad, in_fwd_data_list)
+    else:
+        npu_data_info = UtDataInfo(None, npu_grad_out, npu_out, None, npu_grad, in_fwd_data_list)
+    return npu_data_info
 
 
 def get_api_info(api_info_dict, api_name):
     convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict)