diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index c71b32cefe02418bc4148545d862921c18a7c773..82c8a5e45d9792f0d3e74ac9c853100ebaf9a788 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -116,6 +116,8 @@ def run_torch_api(api_full_name, api_setting_dict, backward_content, value): if isinstance(arg, torch.Tensor): npu_args_grad.append(arg.grad) npu_grad_out = npu_args_grad + if grad_index is not None: + return grad_out, npu_grad_out, npu_out[grad_index], out[grad_index] return grad_out, npu_grad_out, npu_out, out