diff --git a/test/test_replay.py b/test/test_replay.py
index b06c7c72bf2136a7ecb314160d6688c4bc942e61..43b7fc9b1d449d332aebe72f32e218bb456a6208 100644
--- a/test/test_replay.py
+++ b/test/test_replay.py
@@ -33,7 +33,7 @@ class Net(torch.nn.Module):
 
 class TestReplay(TestCase):
     def test_replay_graph(self):
-        if support_replay_model == False:
+        if support_replay_model is False:
             self.assertNotEqual(support_replay_model, True)
             return
         def train():
diff --git a/torch_npu/npu/replay_graph.py b/torch_npu/npu/replay_graph.py
index 9616a2255a72da121f6f2464c072d52041265ab9..4a116855436f2f116a73affc19537a79b2e4f474 100644
--- a/torch_npu/npu/replay_graph.py
+++ b/torch_npu/npu/replay_graph.py
@@ -83,7 +83,9 @@ class WrapModule(object):
 
         with enable_replay_graph_mode(self.verbose):
            shallow_fwd_output = self.fwd_func(*shallow_args, **kwargs)
-
+        if not isinstance(shallow_fwd_output, torch.Tensor):
+            raise TypeError("shallow_fwd_output should be a tensor.")
+
         fwd_graph_info = [fwd_inputs, self.module.parameters(), self.module.buffers()]
         fwd_graph_inputs = []
         for fwd_info in fwd_graph_info:
@@ -130,7 +132,7 @@ class WrapModule(object):
         save_var = self.fwd_graph._ReplayGraph__get_inner_outputs(inputs=fwd_inputs)
         ctx.fwd_input = fwd_inputs
         ctx.saved_var = save_var
-        if fwd_output[0] is not None:
+        if fwd_output is not None and fwd_output[0] is not None:
             ctx.output = fwd_output[0]
             fwd_output[0].requires_grad_(True)
         else:
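
Reviewer note (not part of the patch): the hunk at replay_graph.py:83 fails fast when the replayed forward does not return a single tensor, and the hunk at replay_graph.py:130 guards fwd_output itself before indexing, since fwd_output[0] on a None value would raise "TypeError: 'NoneType' object is not subscriptable" rather than reaching the else: branch. A minimal standalone sketch of the two guards, reusing the diff's variable names in a hypothetical helper that does not exist in torch_npu:

import torch

def check_forward_output(shallow_fwd_output, fwd_output):
    # Fail fast with a clear error when the forward pass did not
    # return a single tensor (mirrors the hunk at replay_graph.py:83).
    if not isinstance(shallow_fwd_output, torch.Tensor):
        raise TypeError("shallow_fwd_output should be a tensor.")
    # `and` short-circuits, so fwd_output[0] is never evaluated when
    # fwd_output itself is None (mirrors the hunk at replay_graph.py:130).
    if fwd_output is not None and fwd_output[0] is not None:
        return fwd_output[0]
    return None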