diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py
index cf9bfe29bdfd054db96ae878813c099892430bbb..c5d4daada698acedc4d3854a5b6f415c0ee4fb75 100644
--- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py
+++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py
@@ -233,28 +233,30 @@ def run_backward(api_full_name, args, backward_content, grad_index, device_args,
     backward_args = backward_content[api_full_name]
     grad = gen_args(backward_args)[0]
     cpu_grad, _ = generate_cpu_params(grad, {}, False)
-    if grad_index is not None:
-        out[grad_index].backward(cpu_grad)
-    elif isinstance(out, (list, tuple)):
-        raise NotImplementedError("Multiple backward is not supported.")
-    else:
-        out.backward(cpu_grad)
-    args_grad = []
-    for arg in args:
-        if isinstance(arg, torch.Tensor):
-            args_grad.append(arg.grad)
-    grad_out = args_grad
-    device_grad = grad.clone().detach().to(current_device)
-    if grad_index is not None:
-        device_out[grad_index].backward(device_grad)
+    if isinstance(out, (list, tuple)):
+        grad_out = [None] * len(out)
+        device_grad_out = [None] * len(device_out)
     else:
-        device_out.backward(device_grad)
-    device_args_grad = []
-    for arg in device_args:
-        if isinstance(arg, torch.Tensor):
-            device_args_grad.append(arg.grad)
-    device_grad_out = device_args_grad
-    return grad_out, device_grad_out, grad, device_grad
+        if grad_index is not None:
+            out[grad_index].backward(cpu_grad)
+            device_out[grad_index].backward(grad.clone().detach().to(current_device))
+        else:
+            out.backward(cpu_grad)
+            device_out.backward(grad.clone().detach().to(current_device))
+
+    args_grad = []
+    for arg in args:
+        if isinstance(arg, torch.Tensor):
+            args_grad.append(arg.grad)
+    grad_out = args_grad
+
+    device_args_grad = []
+    for arg in device_args:
+        if isinstance(arg, torch.Tensor):
+            device_args_grad.append(arg.grad)
+    device_grad_out = device_args_grad
+
+    return grad_out, device_grad_out, grad, cpu_grad
 
 
 def initialize_save_error_data():
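
For reference, below is a sketch of the full `run_backward` body after this change, with two rough edges in the hunk smoothed out: the `[None] * len(out)` placeholders built in the list/tuple branch are unconditionally overwritten by `args_grad`/`device_args_grad` a few lines later, and the `grad_index` check now sits in the non-tuple branch, where `out[grad_index]` would index into a single tensor rather than select one output of a multi-output API as the pre-change code did. This is a sketch under assumptions, not the patch itself: the tail of the real signature is truncated in the hunk header, and `gen_args`, `generate_cpu_params`, and `current_device` are taken to be the module's existing helpers.

```python
import torch

# Sketch only -- the trailing parameters of the real signature are truncated
# in the hunk header above; device_out and out are inferred from the body.
def run_backward(api_full_name, args, backward_content, grad_index, device_args,
                 device_out, out):
    backward_args = backward_content[api_full_name]
    grad = gen_args(backward_args)[0]                   # existing helper in run_ut.py
    cpu_grad, _ = generate_cpu_params(grad, {}, False)  # existing helper in run_ut.py
    device_grad = grad.clone().detach().to(current_device)

    if isinstance(out, (list, tuple)):
        if grad_index is not None:
            # Multi-output API with a designated output: backprop through it,
            # matching the pre-change behavior.
            out[grad_index].backward(cpu_grad)
            device_out[grad_index].backward(device_grad)
        # Without grad_index, backward is skipped and every .grad stays None.
    else:
        out.backward(cpu_grad)
        device_out.backward(device_grad)

    # Collect gradients for every tensor input on both sides; entries are
    # None whenever backward was skipped.
    grad_out = [arg.grad for arg in args if isinstance(arg, torch.Tensor)]
    device_grad_out = [arg.grad for arg in device_args
                       if isinstance(arg, torch.Tensor)]

    return grad_out, device_grad_out, grad, cpu_grad
```

Note also that the fourth return value changes from the device-side gradient (`device_grad`) to `cpu_grad`; callers of `run_backward` should be checked against that new contract.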