From 2164758e24b7f0092c5755d45685b5da7d175c34 Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 27 Sep 2023 16:23:16 +0800 Subject: [PATCH] fix --- .../api_accuracy_checker/run_ut/run_ut.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 33e2e5adf..30dace5e3 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -66,13 +66,17 @@ def generate_npu_params(input_args, input_kwargs, need_backward): return npu_args, npu_kwargs def generate_cpu_params(input_args, input_kwargs, need_backward): + first_dtype = None def recursive_arg_to_cpu(arg_in): + nonlocal first_dtype if isinstance(arg_in, (list, tuple)): return type(arg_in)(recursive_arg_to_cpu(arg) for arg in arg_in) elif isinstance(arg_in, torch.Tensor): if need_backward and arg_in.requires_grad: - if str(arg_in.dtype) in Const.RAISE_PRECISION.keys(): + if str(arg_in.dtype) in Const.RAISE_PRECISION.keys() and arg_in.dtype != first_dtype: arg_in = arg_in.clone().type(eval(Const.RAISE_PRECISION[str(arg_in.dtype)])).detach().requires_grad_() + if first_dtype is None: + first_dtype = arg_in.dtype else: arg_in = arg_in.clone().detach().requires_grad_() temp_arg_in = arg_in * 1 @@ -80,8 +84,11 @@ def generate_cpu_params(input_args, input_kwargs, need_backward): arg_in.retain_grad() return arg_in else: - if str(arg_in.dtype) in Const.RAISE_PRECISION.keys(): - return arg_in.clone().type(eval(Const.RAISE_PRECISION[str(arg_in.dtype)])).detach() + if str(arg_in.dtype) in Const.RAISE_PRECISION.keys() and arg_in.dtype != first_dtype: + arg_in = arg_in.clone().type(eval(Const.RAISE_PRECISION[str(arg_in.dtype)])).detach() + if first_dtype is None: + first_dtype = arg_in.dtype + return arg_in return arg_in.clone().detach() else: return arg_in -- Gitee