diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py index fb455434cf5464d416b16cb6a903aa9d29c2236b..7b23ea55793022183ea5aec788f7a2f71693d742 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py @@ -4,11 +4,9 @@ import sys import torch_npu import torch from tqdm import tqdm -from api_accuracy_checker.run_ut.run_ut import exec_api, generate_device_params, run_backward, get_api_info -from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, api_info_preprocess, \ - print_error_log - -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileCheckConst, check_file_suffix, check_link +from api_accuracy_checker.run_ut.run_ut import exec_api, generate_device_params, get_api_info +from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, print_error_log +from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import check_link def check_tensor_overflow(x): @@ -46,16 +44,12 @@ def check_data_overflow(x): return check_tensor_overflow(x) -def run_overflow_check(forward_file, backward_file): +def run_overflow_check(forward_file): print_info_log("start UT test") forward_content = get_json_contents(forward_file) - backward_content = {} - if backward_file: - backward_content = get_json_contents(backward_file) - api_setting_dict = get_json_contents("torch_ut_setting.json") for api_full_name, api_info_dict in tqdm(forward_content.items()): try: - run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict) + run_torch_api(api_full_name, api_info_dict) except Exception as err: api_name = api_full_name.split("_", 1)[1].rsplit("_", 2)[0] if "not implemented for 'Half'" in str(err): @@ -68,44 +62,25 @@ def run_overflow_check(forward_file, backward_file): print_error_log(f"Run {api_full_name} UT Error: %s" % str(err)) -def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict): +def run_torch_api(api_full_name, api_info_dict): torch.npu.clear_npu_overflow_flag() api_type = api_full_name.split("_")[0] api_name = api_full_name.split("_", 1)[1].rsplit("_", 2)[0] args, kwargs, need_grad = get_api_info(api_info_dict, api_name) - need_backward = api_full_name.replace("forward", "backward") in backward_content - need_backward = need_backward and need_grad if not need_grad: print_warn_log("%s function with out=... arguments don't support automatic differentiation, skip backward." % api_full_name) - npu_args, npu_kwargs = generate_device_params(args, kwargs, need_backward) + npu_args, npu_kwargs = generate_device_params(args, kwargs, False) if kwargs.get("device"): del kwargs["device"] out = exec_api(api_type, api_name, args, kwargs) npu_out = exec_api(api_type, api_name, npu_args, npu_kwargs) - if not need_backward: - cpu_overflow = check_data_overflow(out) - npu_overflow = torch_npu.npu.utils.npu_check_overflow(npu_out) - if cpu_overflow == npu_overflow: - print_warn_log("The %s overflow is a normal overflow." % api_full_name) - else: - print_warn_log("The %s overflow is an abnormal overflow." % api_full_name) - return + cpu_overflow = check_data_overflow(out) + npu_overflow = torch_npu.npu.utils.npu_check_overflow(npu_out) + if cpu_overflow == npu_overflow: + print_warn_log("The %s overflow is a normal overflow." % api_full_name) else: - api_full_name = api_full_name.replace("forward", "backward") - grad_input_index = api_setting_dict.get(api_name) - grad_index = None - if grad_input_index is not None: - grad_index = grad_input_index.get('grad_index') - - grad_out, npu_grad_out = run_backward(api_full_name, args, backward_content, grad_index, npu_args, npu_out, out) - - cpu_overflow = check_data_overflow(grad_out) - npu_overflow = torch_npu.npu.utils.npu_check_overflow(npu_grad_out) - if cpu_overflow == npu_overflow: - print_warn_log("The %s overflow is a normal overflow." % api_full_name) - else: - print_warn_log("The %s overflow is an abnormal overflow." % api_full_name) - return + print_warn_log("The %s overflow is an abnormal overflow." % api_full_name) + return def _run_ut_parser(parser): @@ -113,10 +88,6 @@ def _run_ut_parser(parser): help=" The api param tool forward result file: generate from api param tool, " "a json file.", required=True) - parser.add_argument("-backward", "--backward_input_file", dest="backward_input_file", default="", - help=" The api param tool backward result file: generate from api param tool, " - "a json file.", - required=False) parser.add_argument("-j", "--jit_compile", dest="jit_compile", help=" whether to turn on jit compile", default=False, required=False) parser.add_argument("-d", "--device", dest="device_id", type=int, help=" set NPU device id to run ut", @@ -131,18 +102,12 @@ def _run_overflow_check(): npu_device = "npu:" + str(args.device_id) check_link(args.forward_input_file) forward_file = os.path.realpath(args.forward_input_file) - backward_file = "" - if args.backward_input_file: - check_link(args.backward_input_file) - backward_file = os.path.realpath(args.backward_input_file) - check_file_suffix(backward_file, FileCheckConst.JSON_SUFFIX) - check_file_suffix(forward_file, FileCheckConst.JSON_SUFFIX) try: torch.npu.set_device(npu_device) except Exception as error: print_error_log(f"Set NPU device id failed. device id is: {args.device_id}") raise NotImplementedError from error - run_overflow_check(forward_file, backward_file) + run_overflow_check(forward_file) if __name__ == '__main__': diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index fb885168d662006b3fa639c2322113d6f1f6b386..8986849da6815b146005d4ecf7caa5837dc5ede3 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -296,7 +296,7 @@ def _run_ut_parser(parser): parser.add_argument("-backward", "--backward_input_file", dest="backward_input_file", default="", type=str, help=" The api param tool backward result file: generate from api param tool, " "a json file.", - required=True) + required=False) parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, help=" The ut task result out path.", required=False) @@ -328,17 +328,19 @@ def _run_ut(): print_error_log(f"Set device id failed. device id is: {args.device_id}") raise NotImplementedError from error check_link(args.forward_input_file) - check_link(args.backward_input_file) forward_file = os.path.realpath(args.forward_input_file) - backward_file = os.path.realpath(args.backward_input_file) check_file_suffix(forward_file, FileCheckConst.JSON_SUFFIX) - check_file_suffix(backward_file, FileCheckConst.JSON_SUFFIX) out_path = os.path.realpath(args.out_path) if args.out_path else "./" out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) out_path = out_path_checker.common_check() save_error_data = args.save_error_data forward_content = get_json_contents(forward_file) - backward_content = get_json_contents(backward_file) + backward_content = {} + if args.backward_input_file: + check_link(args.backward_input_file) + backward_file = os.path.realpath(args.backward_input_file) + check_file_suffix(backward_file, FileCheckConst.JSON_SUFFIX) + backward_content = get_json_contents(backward_file) result_csv_path = os.path.join(out_path, RESULT_FILE_NAME) details_csv_path = os.path.join(out_path, DETAILS_FILE_NAME) test_result_cnt = None