diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 0c18da7d0e8f573260a68e94f89b25727a035b5a..350f436b380671d1875ab4a883000766a2019a9d 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -31,34 +31,20 @@ def exec_api(api_type, api_name, args, kwargs): def generate_npu_params(cpu_args, cpu_kwargs, need_backward): - npu_args = [] - npu_kwargs = {} - if need_backward: - for arg_in in cpu_args: - arg_in = arg_to_npu(arg_in) - npu_args.append(arg_in) - for key, value in cpu_kwargs.items(): - value = arg_to_npu(value) - npu_kwargs[key] = value - else: - for arg_in in cpu_args: - if isinstance(arg_in, torch.Tensor): - arg_in = arg_in.clone().detach().to("npu") - npu_args.append(arg_in) - for key, value in cpu_kwargs.items(): - if isinstance(value, torch.Tensor): - value = value.clone().detach().to("npu") - npu_kwargs[key] = value - return npu_args, npu_kwargs - + def recursive_arg_to_npu(arg_in): + if isinstance(arg_in, (list, tuple)): + return type(arg_in)(recursive_arg_to_npu(arg) for arg in arg_in) + elif isinstance(arg_in, torch.Tensor): + if need_backward and arg_in.requires_grad: + return arg_in.clone().detach().to("npu").requires_grad_() + else: + return arg_in.clone().detach().to("npu") + else: + return arg_in -def arg_to_npu(arg_in): - if isinstance(arg_in, torch.Tensor) and arg_in.dtype in [torch.float, torch.float16, - torch.float64] and arg_in.requires_grad: - arg_in = arg_in.clone().detach().to("npu").requires_grad_() - elif isinstance(arg_in, torch.Tensor): - arg_in = arg_in.clone().detach().to("npu") - return arg_in + npu_args = recursive_arg_to_npu(cpu_args) + npu_kwargs = {key: recursive_arg_to_npu(value) for key, value in cpu_kwargs.items()} + return npu_args, npu_kwargs def run_ut(forward_file, backward_file, out_path, save_error_data):