# Reconstructed from the flattened unified diff of
# debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py
# (hunk "@@ -48,58 +48,62 @@"). The code below is the post-patch ("+" side)
# state of the definitions that are fully visible in that hunk.
# `torch` and `Const` are expected to be provided by the enclosing module.


def deal_detach(arg, to_detach=True):
    """Return ``arg.detach()`` when ``to_detach`` is true, else ``arg`` unchanged.

    Args:
        arg: Object exposing a ``detach()`` method (a tensor in this file).
        to_detach: When false, ``arg`` is returned as-is (same object).
    """
    return arg.detach() if to_detach else arg


def _enable_grad_tracking(tensor):
    """Turn ``tensor`` into a grad-requiring, non-leaf tensor that keeps ``.grad``.

    ``requires_grad_()`` is applied first; multiplying by 1 then yields a
    result of an autograd operation (a non-leaf tensor), and ``retain_grad()``
    asks autograd to keep the gradient of that non-leaf after backward.
    """
    tensor = tensor.requires_grad_()
    # NOTE(review): the "* 1" indirection presumably exists so that in-place
    # APIs can be exercised on the returned tensor - confirm with the authors.
    tracked = (tensor * 1).type_as(tensor)
    tracked.retain_grad()
    return tracked


def generate_npu_params(input_args, input_kwargs, need_backward):
    """Copy positional and keyword arguments onto the ``"npu"`` device.

    Lists and tuples are rebuilt recursively with their own type; tensors are
    cloned, optionally detached, and moved to ``"npu"``; every other value is
    passed through untouched.

    Args:
        input_args: Positional arguments (arbitrarily nested lists/tuples).
        input_kwargs: Keyword arguments. The value stored under the key
            ``"out"`` is cloned but NOT detached.
        need_backward: When true, tensors whose ``requires_grad`` is set are
            returned as grad-tracking tensors (see ``_enable_grad_tracking``).

    Returns:
        Tuple ``(npu_args, npu_kwargs)``.
    """

    def recursive_arg_to_npu(arg_in, to_detach=True):
        if isinstance(arg_in, (list, tuple)):
            return type(arg_in)(recursive_arg_to_npu(item, to_detach) for item in arg_in)
        if not isinstance(arg_in, torch.Tensor):
            return arg_in

        # Decide on the original tensor before it is cloned/detached.
        wants_grad = need_backward and arg_in.requires_grad
        on_npu = deal_detach(arg_in.clone(), to_detach).to("npu")
        if not wants_grad:
            return on_npu
        return _enable_grad_tracking(on_npu)

    npu_args = recursive_arg_to_npu(input_args)
    npu_kwargs = {
        key: recursive_arg_to_npu(value, key != "out")
        for key, value in input_kwargs.items()
    }
    return npu_args, npu_kwargs


def generate_cpu_params(input_args, input_kwargs, need_backward):
    """Copy positional and keyword arguments for the CPU run, raising precision.

    Lists and tuples are rebuilt recursively with their own type; non-tensor
    values are passed through untouched. A tensor whose ``str(dtype)`` is a key
    of ``Const.RAISE_PRECISION`` and whose dtype differs from ``first_dtype``
    is converted to the mapped dtype. ``first_dtype`` is recorded once: it is
    the (already raised) dtype of the first tensor that was converted, and it
    is shared across all positional and keyword arguments of one call.

    Args:
        input_args: Positional arguments (arbitrarily nested lists/tuples).
        input_kwargs: Keyword arguments. The value stored under the key
            ``"out"`` is cloned but NOT detached.
        need_backward: When true, tensors whose ``requires_grad`` is set are
            returned as grad-tracking tensors (see ``_enable_grad_tracking``).

    Returns:
        Tuple ``(cpu_args, cpu_kwargs)``.
    """
    first_dtype = None

    def recursive_arg_to_cpu(arg_in, to_detach=True):
        nonlocal first_dtype
        if isinstance(arg_in, (list, tuple)):
            return type(arg_in)(recursive_arg_to_cpu(item, to_detach) for item in arg_in)
        if not isinstance(arg_in, torch.Tensor):
            return arg_in

        # Decide on the original tensor before it is cloned/detached.
        wants_grad = need_backward and arg_in.requires_grad
        dtype_name = str(arg_in.dtype)
        copied = arg_in.clone()
        if dtype_name in Const.RAISE_PRECISION and arg_in.dtype != first_dtype:
            # NOTE(review): eval() turns the dtype name stored in
            # Const.RAISE_PRECISION into a dtype object. That is only safe
            # while the table is a trusted in-repo constant - confirm, or
            # replace with an explicit name-to-dtype mapping.
            copied = copied.type(eval(Const.RAISE_PRECISION[dtype_name]))
            if first_dtype is None:
                first_dtype = copied.dtype
        copied = deal_detach(copied, to_detach)
        if not wants_grad:
            return copied
        return _enable_grad_tracking(copied)

    cpu_args = recursive_arg_to_cpu(input_args)
    cpu_kwargs = {
        key: recursive_arg_to_cpu(value, key != "out")
        for key, value in input_kwargs.items()
    }
    return cpu_args, cpu_kwargs