From 14e3675625991debe92ecf242670c204e4e12b79 Mon Sep 17 00:00:00 2001 From: h00613304 Date: Thu, 3 Aug 2023 20:19:02 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E5=89=8D=E5=90=91=E6=AF=94?= =?UTF-8?q?=E5=AF=B9=E9=9C=80=E8=A6=81=E4=BC=A0=E5=85=A5=E7=9A=84=E6=95=B0?= =?UTF-8?q?=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index c71b32cefe0..82c8a5e45d9 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -116,6 +116,8 @@ def run_torch_api(api_full_name, api_setting_dict, backward_content, value): if isinstance(arg, torch.Tensor): npu_args_grad.append(arg.grad) npu_grad_out = npu_args_grad + if grad_index is not None: + return grad_out, npu_grad_out, npu_out[grad_index], out[grad_index] return grad_out, npu_grad_out, npu_out, out -- Gitee