From 6e84c9dfbe80c35bf945ca5f89c22f657fdd1fbf Mon Sep 17 00:00:00 2001 From: sunyiming Date: Thu, 11 Jan 2024 01:04:11 +0000 Subject: [PATCH 1/3] bugfix Signed-off-by: sunyiming --- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index cf9bfe29bd..cb49c0e621 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -236,7 +236,7 @@ def run_backward(api_full_name, args, backward_content, grad_index, device_args, if grad_index is not None: out[grad_index].backward(cpu_grad) elif isinstance(out, (list, tuple)): - raise NotImplementedError("Multiple backward is not supported.") + return None, None, None, None else: out.backward(cpu_grad) args_grad = [] -- Gitee From ad84e4bb41b86fac4fba5d3fca6db0da1b6abae1 Mon Sep 17 00:00:00 2001 From: sunyiming Date: Thu, 11 Jan 2024 07:44:31 +0000 Subject: [PATCH 2/3] update debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py. Signed-off-by: sunyiming --- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index cb49c0e621..23bca30a6e 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -236,7 +236,7 @@ def run_backward(api_full_name, args, backward_content, grad_index, device_args, if grad_index is not None: out[grad_index].backward(cpu_grad) elif isinstance(out, (list, tuple)): - return None, None, None, None + return (None,) * 4 else: out.backward(cpu_grad) args_grad = [] -- Gitee From 2a20713d5ce4927d76825c60a3ecc62bdd3cbb74 Mon Sep 17 00:00:00 2001 From: sunyiming Date: Fri, 12 Jan 2024 01:38:37 +0000 Subject: [PATCH 3/3] update debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py. Signed-off-by: sunyiming --- .../api_accuracy_checker/run_ut/run_ut.py | 44 ++++++++++--------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 23bca30a6e..c5d4daada6 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -233,28 +233,30 @@ def run_backward(api_full_name, args, backward_content, grad_index, device_args, backward_args = backward_content[api_full_name] grad = gen_args(backward_args)[0] cpu_grad, _ = generate_cpu_params(grad, {}, False) - if grad_index is not None: - out[grad_index].backward(cpu_grad) - elif isinstance(out, (list, tuple)): - return (None,) * 4 - else: - out.backward(cpu_grad) - args_grad = [] - for arg in args: - if isinstance(arg, torch.Tensor): - args_grad.append(arg.grad) - grad_out = args_grad - device_grad = grad.clone().detach().to(current_device) - if grad_index is not None: - device_out[grad_index].backward(device_grad) + if isinstance(out, (list, tuple)): + grad_out = [None] * len(out) + device_grad_out = [None] * len(device_out) else: - device_out.backward(device_grad) - device_args_grad = [] - for arg in device_args: - if isinstance(arg, torch.Tensor): - device_args_grad.append(arg.grad) - device_grad_out = device_args_grad - return grad_out, device_grad_out, grad, device_grad + if grad_index is not None: + out[grad_index].backward(cpu_grad) + device_out[grad_index].backward(grad.clone().detach().to(current_device)) + else: + out.backward(cpu_grad) + device_out.backward(grad.clone().detach().to(current_device)) + + args_grad = [] + for arg in args: + if isinstance(arg, torch.Tensor): + args_grad.append(arg.grad) + grad_out = args_grad + + device_args_grad = [] + for arg in device_args: + if isinstance(arg, torch.Tensor): + device_args_grad.append(arg.grad) + device_grad_out = device_args_grad + + return grad_out, device_grad_out, grad, cpu_grad def initialize_save_error_data(): -- Gitee