diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py index aa48ce5edd2c7fbd7e21cb3f9ca9048116bc0abf..28ca86793f1db03647df525ea5037b64623dde35 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py @@ -47,6 +47,8 @@ def gen_data(info, need_grad, convert_type): data = gen_random_tensor(info, convert_type) if info.get('requires_grad') and need_grad: data.requires_grad_(True) + temp_data = data * 1 + data = temp_data.type_as(data) data.retain_grad() else: data = info.get('value') diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index bfb2296a42568a4f76c1732bee635360d0b874dd..52d5d928d045913112f03a5093d936dc1627428b 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -44,7 +44,11 @@ def generate_npu_params(cpu_args, cpu_kwargs, need_backward): return type(arg_in)(recursive_arg_to_npu(arg) for arg in arg_in) elif isinstance(arg_in, torch.Tensor): if need_backward and arg_in.requires_grad: - return arg_in.clone().detach().to("npu").requires_grad_() + arg_in = arg_in.clone().detach().to("npu").requires_grad_() + temp_arg_in = arg_in * 1 + arg_in = temp_arg_in.type_as(arg_in) + arg_in.retain_grad() + return arg_in else: return arg_in.clone().detach().to("npu") else: