diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 536d607dd696b16e10ee5302aafbc64e975aa33d..228e10b3851d9176fc37d3955ac6ea15a73d0dd6 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -185,10 +185,11 @@ def run_backward(api_full_name, args, backward_content, grad_index, npu_args, np backward_args = backward_content[api_full_name] grad = gen_args(backward_args)[0] cpu_grad, _ = generate_cpu_params(grad, {}, False) - if grad_index is not None: + if isinstance(out, tuple): + grad_outputs = tuple(torch.randn_like(out_item) for out_item in out) + torch.autograd.backward(out, grad_outputs=grad_outputs) + elif grad_index is not None: out[grad_index].backward(cpu_grad) - elif isinstance(out, (list, tuple)): - raise NotImplementedError("Multiple backward is not supported.") else: out.backward(cpu_grad) args_grad = [] @@ -197,7 +198,10 @@ def run_backward(api_full_name, args, backward_content, grad_index, npu_args, np args_grad.append(arg.grad) grad_out = args_grad npu_grad = grad.clone().detach().npu() - if grad_index is not None: + if isinstance(npu_out, tuple): + npu_grad_outputs = tuple(torch.randn_like(npu_out_item) for npu_out_item in npu_out) + torch.autograd.backward(npu_out, grad_outputs=npu_grad_outputs) + elif grad_index is not None: npu_out[grad_index].backward(npu_grad) else: npu_out.backward(npu_grad) diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/common/test_config.py b/debug/accuracy_tools/api_accuracy_checker/test/ut/common/test_config.py index ed764987d5a9287293f183c0bde1d86afd90ccae..a68057dfb41ca38ba79e1daa992a8f51ce4d64e4 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/common/test_config.py +++ b/debug/accuracy_tools/api_accuracy_checker/test/ut/common/test_config.py @@ -17,5 +17,5 @@ class TestConfig(unittest.TestCase): def test_update_config(self): - self.config.update_config(dump_path='/new/path/to/dump', enable_dataloader=False) + self.config.update_config(dump_path='/new/path/to/dump') self.assertEqual(self.config.dump_path, '/new/path/to/dump') diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_dump_scopr.py b/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_dump_scopr.py index addba38e38446b177942a104b4194efe910b1f7c..b892a6077a3c26ae27343734aca8012e21d3fc2c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_dump_scopr.py +++ b/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_dump_scopr.py @@ -10,12 +10,12 @@ class TestDumpScope(unittest.TestCase): wrapped_func = iter_tracer(dummy_func) result = wrapped_func() - self.assertEqual(DumpUtil.dump_switch, "ON") + self.assertEqual(DumpUtil.dump_switch, "OFF") self.assertEqual(result, "Hello, World!") def another_dummy_func(): return 123 wrapped_func = iter_tracer(another_dummy_func) result = wrapped_func() - self.assertEqual(DumpUtil.dump_switch, "ON") + self.assertEqual(DumpUtil.dump_switch, "OFF") self.assertEqual(result, 123)