diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py
index 9f968c1d87064d81c0e6b723a1754bb6dbdcf9d0..b1c97abef2c8216d346ca0604a35feeec7c4bda7 100644
--- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py
+++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py
@@ -4,7 +4,7 @@ import sys
 import torch_npu
 import torch
 from tqdm import tqdm
-from api_accuracy_checker.run_ut.run_ut import exec_api, generate_device_params, get_api_info
+from api_accuracy_checker.run_ut.run_ut import exec_api, get_api_info, GenerateParams
 from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, print_error_log
 from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import check_link

@@ -70,7 +70,7 @@ def run_torch_api(api_full_name, api_info_dict):

     if not need_grad:
         print_warn_log("%s function with out=... arguments don't support automatic differentiation, skip backward." % api_full_name)
-    npu_args, npu_kwargs = generate_device_params(args, kwargs, False, api_name)
+    npu_args, npu_kwargs = GenerateParams(args, kwargs, False, api_name, False).generate_params()
     if kwargs.get("device"):
         del kwargs["device"]
     out = exec_api(api_type, api_name, args, kwargs)
diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py
index 5101705434de0b0325817b05758922197f8bf381..a5e2160d2d7665c19e2cd3e7f0344bb642e5f725 100644
--- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py
+++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py
@@ -36,7 +36,6 @@ DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv"
 RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path', 'save_error_data', 'is_continue_run_ut', 'real_data_path'])

 not_backward_list = ['repeat_interleave']
-not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'}

 tqdm_params = {
     'smoothing': 0,  # Smooths the progress bar's estimated remaining time; value range is 0 to 1
@@ -67,75 +66,6 @@ def exec_api(api_type, api_name, args, kwargs):
     return out

-
-def deal_detach(arg, to_detach=True):
-    return arg.detach() if to_detach else arg
-
-
-def deal_dtype(arg, raise_dtype=None):
-    if raise_dtype is None or arg.dtype not in Const.RAISE_PRECISION or raise_dtype == arg.dtype:
-        return arg
-    return arg.type(raise_dtype)
-
-
-def generate_device_params(input_args, input_kwargs, need_backward, api_name):
-    def recursive_arg_to_device(arg_in, to_detach):
-        if isinstance(arg_in, (list, tuple)):
-            return type(arg_in)(recursive_arg_to_device(arg, to_detach) for arg in arg_in)
-        elif isinstance(arg_in, torch.Tensor):
-            if need_backward and arg_in.requires_grad:
-                arg_in = deal_detach(arg_in.clone(), to_detach).to(current_device).requires_grad_()
-                temp_arg_in = arg_in * 1
-                arg_in = temp_arg_in.type_as(arg_in)
-                arg_in.retain_grad()
-                return arg_in
-            else:
-                return deal_detach(arg_in.clone(), to_detach).to(current_device)
-        else:
-            return arg_in
-
-    is_detach = api_name not in not_detach_set
-    device_args = recursive_arg_to_device(input_args, is_detach)
-    device_kwargs = \
-        {key: recursive_arg_to_device(value, key != "out" and is_detach) for key, value in input_kwargs.items()}
-    return device_args, device_kwargs
-
-
-def generate_cpu_params(input_args, input_kwargs, need_backward, api_name):
-    def recursive_arg_to_cpu(arg_in, to_detach, raise_dtype=None):
-        if isinstance(arg_in, (list, tuple)):
-            return type(arg_in)(recursive_arg_to_cpu(arg, to_detach, raise_dtype=raise_dtype) for arg in arg_in)
-        elif isinstance(arg_in, torch.Tensor):
-            if need_backward and arg_in.requires_grad:
-                arg_in = deal_detach(deal_dtype(arg_in.clone(), raise_dtype), to_detach).requires_grad_()
-                temp_arg_in = arg_in * 1
-                arg_in = temp_arg_in.type_as(arg_in)
-                arg_in.retain_grad()
-                return arg_in
-            else:
-                return deal_detach(deal_dtype(arg_in.clone(), raise_dtype=raise_dtype), to_detach)
-        else:
-            return arg_in
-
-    def recursive_find_dtypes(arg_in):
-        if isinstance(arg_in, (list, tuple)):
-            return set().union(*tuple(recursive_find_dtypes(arg) for arg in arg_in))
-        elif isinstance(arg_in, torch.Tensor) and arg_in.dtype in Const.RAISE_PRECISION:
-            return set([arg_in.dtype])
-        return set()
-
-    raise_dtype = None
-    need_raise_dtypes = recursive_find_dtypes(input_args)
-    if len(need_raise_dtypes) == 1:
-        raise_dtype = Const.RAISE_PRECISION.get(need_raise_dtypes.pop())
-    elif len(need_raise_dtypes) >= 2:
-        raise_dtype = torch.float32
-
-    is_detach = api_name not in not_detach_set
-    cpu_args = recursive_arg_to_cpu(input_args, is_detach, raise_dtype=raise_dtype)
-    cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach) for key, value in input_kwargs.items()}
-    return cpu_args, cpu_kwargs
-
-
 def run_ut(config):
     print_info_log("start UT test")
     print_info_log(f"UT task result will be saved in {config.result_csv_path}")

@@ -211,8 +141,8 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
     need_backward = need_backward and need_grad
     if kwargs.get("device"):
         del kwargs["device"]
-    cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, need_backward, api_name)
-    device_args, device_kwargs = generate_device_params(args, kwargs, need_backward, api_name)
+    cpu_args, cpu_kwargs = GenerateParams(args, kwargs, need_backward, api_name).generate_params()
+    device_args, device_kwargs = GenerateParams(args, kwargs, need_backward, api_name, is_cpu=False).generate_params()
     bench_grad_out, device_grad_out = None, None
     out = exec_api(api_type, api_name, cpu_args, cpu_kwargs)
     device_out = exec_api(api_type, api_name, device_args, device_kwargs)
@@ -226,7 +156,7 @@
     if need_backward:
         backward_args = backward_content[api_full_name]
         grad = gen_args(backward_args, real_data_path=real_data_path)[0]
-        bench_grad, _ = generate_cpu_params(grad, {}, False, api_name)
+        bench_grad, _ = GenerateParams(grad, {}, False, api_name).generate_params()
         bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out)
         device_grad = grad.clone().detach().to(current_device)
         device_grad_out = run_backward(device_args, device_grad, grad_index, device_out)
@@ -247,7 +177,6 @@

 def run_backward(args, grad, grad_index, out):
-
     if grad_index is not None:
         out[grad_index].backward(grad)
     elif isinstance(out, (list, tuple)):
@@ -397,6 +326,62 @@ class UtAPIInfo(APIInfo):
         self.analyze_element(element)


+class GenerateParams:
+    def __init__(self, input_args, input_kwargs, need_backward, api_name, is_cpu=True):
+        self.input_args = input_args
+        self.input_kwargs = input_kwargs
+        self.need_backward = need_backward
+        self.api_name = api_name
+        self.is_cpu = is_cpu
+        self.not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'}
+        self.is_detach = self.api_name not in self.not_detach_set
+        self.raise_dtype = None
+        if self.is_cpu:
+            need_raise_dtypes = self._recursive_find_dtypes(input_args)
+            if len(need_raise_dtypes) == 1:
+                self.raise_dtype = Const.RAISE_PRECISION.get(need_raise_dtypes.pop())
+            elif len(need_raise_dtypes) >= 2:
+                self.raise_dtype = torch.float32
+
+    def generate_params(self):
+        args = self._recursive_arg(self.input_args, self.is_detach)
+        kwargs = {key: self._recursive_arg(value, key != "out" and self.is_detach)
+                  for key, value in self.input_kwargs.items()}
+        return args, kwargs
+
+    def _deal_detach_to_device(self, arg, to_detach=True):
+        arg = self._deal_dtype(arg)
+        arg = arg.detach() if to_detach else arg
+        return arg if self.is_cpu else arg.to(current_device)
+
+    def _recursive_find_dtypes(self, arg_in):
+        if isinstance(arg_in, (list, tuple)):
+            return set().union(*tuple(self._recursive_find_dtypes(arg) for arg in arg_in))
+        elif isinstance(arg_in, torch.Tensor) and arg_in.dtype in Const.RAISE_PRECISION:
+            return {arg_in.dtype}
+        return set()
+
+    def _deal_dtype(self, arg):
+        return arg \
+            if not self.is_cpu or self.raise_dtype in (None, arg.dtype) or arg.dtype not in Const.RAISE_PRECISION \
+            else arg.type(self.raise_dtype)
+
+    def _recursive_arg(self, arg_in, to_detach):
+        if isinstance(arg_in, (list, tuple)):
+            return type(arg_in)(self._recursive_arg(arg, to_detach) for arg in arg_in)
+        elif isinstance(arg_in, torch.Tensor):
+            if self.need_backward and arg_in.requires_grad:
+                arg_in = self._deal_detach_to_device(arg_in.clone(), to_detach).requires_grad_()
+                temp_arg_in = arg_in * 1
+                arg_in = temp_arg_in.type_as(arg_in)
+                arg_in.retain_grad()
+                return arg_in
+            else:
+                return self._deal_detach_to_device(arg_in.clone(), to_detach)
+        else:
+            return arg_in
+
+
 if __name__ == '__main__':
     _run_ut()
     print_info_log("UT task completed.")
diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py b/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py
index fdcc1cfddeb38d4fca0d2a67a09147b571b35def..7837edc43a897b1ee542459804a95edc8c48067e 100644
--- a/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py
+++ b/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py
@@ -19,7 +19,7 @@ class TestRunUtMethods(unittest.TestCase):
         api_info = copy.deepcopy(api_info_dict)
         [api_type, api_name, _] = api_full_name.split("*")
         args, kwargs, need_grad = get_api_info(api_info, api_name, None)
-        cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, True, '')
+        cpu_args, cpu_kwargs = GenerateParams(args, kwargs, True, '', True).generate_params()
         out = exec_api(api_type, api_name, cpu_args, cpu_kwargs)
         self.assertEqual(out.dtype, torch.float64)
         self.assertTrue(out.requires_grad)
@@ -42,7 +42,7 @@
             mocks['retain_grad'].return_value = None
             mocks['to'].return_value = mock_tensor

-            device_args, device_kwargs = generate_device_params([mock_tensor], {'inplace': False}, True, '')
+            device_args, device_kwargs = GenerateParams([mock_tensor], {'inplace': False}, True, '', False).generate_params()
             self.assertEqual(len(device_args), 1)
             self.assertEqual(device_args[0].dtype, torch.float32)
             self.assertTrue(device_args[0].requires_grad)
@@ -53,7 +53,7 @@
         api_info = copy.deepcopy(api_info_dict)
         [api_type, api_name, _] = api_full_name.split("*")
         args, kwargs, need_grad = get_api_info(api_info, api_name, None)
-        cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, True, '')
+        cpu_args, cpu_kwargs = GenerateParams(args, kwargs, True, '', True).generate_params()
         self.assertEqual(len(cpu_args), 1)
         self.assertEqual(cpu_args[0].dtype, torch.float64)
         self.assertTrue(cpu_args[0].requires_grad)
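Usage sketch for reviewers (illustrative only, not part of the change set): the two removed module-level helpers collapse into one class whose `is_cpu` flag selects between the CPU bench path (dtype raising per `Const.RAISE_PRECISION`) and the device path (tensors moved to `current_device`, no dtype raising). The input tensor, the float16 dtype, and the `'matmul'` api_name below are assumptions for demonstration.

```python
import torch
# Assumed import path, matching the diff's own import in run_overflow_check.py.
from api_accuracy_checker.run_ut.run_ut import GenerateParams

# A half-precision input with grad: exercises both the clone/detach handling
# and the dtype-raising rule (exactly one promotable dtype -> its mapped
# dtype; two or more distinct promotable dtypes -> torch.float32).
args = (torch.randn(2, 2, dtype=torch.float16, requires_grad=True),)
kwargs = {}

# CPU bench params: is_cpu defaults to True, so raise_dtype is computed.
cpu_args, cpu_kwargs = GenerateParams(args, kwargs, True, 'matmul').generate_params()

# Device params: is_cpu=False skips dtype raising and moves each tensor
# to current_device after the clone/detach step (requires an NPU device).
device_args, device_kwargs = GenerateParams(args, kwargs, True, 'matmul', is_cpu=False).generate_params()
```

One behavioral consequence of the merge worth checking in review: `generate_cpu_params` previously applied `raise_dtype` only to positional args (kwargs were converted with the default `raise_dtype=None`), whereas `GenerateParams._deal_dtype` now consults `self.raise_dtype` for kwargs tensors as well.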