diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 33e2e5adf5e46e9c82a7ed334c29a72814b3bc0e..30dace5e3891109032961bd977467cd917b2f0cd 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -66,13 +66,17 @@ def generate_npu_params(input_args, input_kwargs, need_backward): return npu_args, npu_kwargs def generate_cpu_params(input_args, input_kwargs, need_backward): + first_dtype = None def recursive_arg_to_cpu(arg_in): + nonlocal first_dtype if isinstance(arg_in, (list, tuple)): return type(arg_in)(recursive_arg_to_cpu(arg) for arg in arg_in) elif isinstance(arg_in, torch.Tensor): if need_backward and arg_in.requires_grad: - if str(arg_in.dtype) in Const.RAISE_PRECISION.keys(): + if str(arg_in.dtype) in Const.RAISE_PRECISION.keys() and arg_in.dtype != first_dtype: arg_in = arg_in.clone().type(eval(Const.RAISE_PRECISION[str(arg_in.dtype)])).detach().requires_grad_() + if first_dtype is None: + first_dtype = arg_in.dtype else: arg_in = arg_in.clone().detach().requires_grad_() temp_arg_in = arg_in * 1 @@ -80,8 +84,11 @@ def generate_cpu_params(input_args, input_kwargs, need_backward): arg_in.retain_grad() return arg_in else: - if str(arg_in.dtype) in Const.RAISE_PRECISION.keys(): - return arg_in.clone().type(eval(Const.RAISE_PRECISION[str(arg_in.dtype)])).detach() + if str(arg_in.dtype) in Const.RAISE_PRECISION.keys() and arg_in.dtype != first_dtype: + arg_in = arg_in.clone().type(eval(Const.RAISE_PRECISION[str(arg_in.dtype)])).detach() + if first_dtype is None: + first_dtype = arg_in.dtype + return arg_in return arg_in.clone().detach() else: return arg_in