diff --git a/torch_npu/csrc/framework/graph/execute/GraphExecutor.cpp b/torch_npu/csrc/framework/graph/execute/GraphExecutor.cpp
index 1dd74dee33cfb4fc4709528828b272947a4e77f1..609a058414b32996c3490d71acc11692c57a808a 100644
--- a/torch_npu/csrc/framework/graph/execute/GraphExecutor.cpp
+++ b/torch_npu/csrc/framework/graph/execute/GraphExecutor.cpp
@@ -55,28 +55,6 @@ static ge::Tensor MakeGeTensor(
 
 uint32_t GraphExecutor::graph_id = 0;
 
-void GraphExecutor::RunGraph(
-    uint32_t graph_id,
-    CombinedInfo& inputs,
-    CombinedInfo& outputs) {
-  RECORD_HOST_FUNCTION("RunGraph", std::vector({}));
-  aclrtStream cal_stream =
-      const_cast(c10_npu::getCurrentNPUStream().stream());
-
-  auto start_time = std::chrono::steady_clock::now();
-  C10_NPU_CHECK(session_->RunGraphWithStreamAsync(graph_id,
-                                                  cal_stream,
-                                                  inputs.tensors,
-                                                  outputs.tensors));
-  auto duration = std::chrono::duration_cast(
-      std::chrono::steady_clock::now() - start_time);
-  if (verbose_) {
-    NPU_LOGI("RunGraph Time: duration = %.3f ms",static_cast(duration.count()) *
-             std::chrono::microseconds::period::num /
-             std::chrono::milliseconds::period::den);
-  }
-}
-
 void GraphExecutor::RunGraph(
     uint32_t graph_id,
     const std::vector& inputs,
@@ -128,10 +106,10 @@ void GraphExecutor::ConstructAndExecuteGraph() {
   // Release GIL to avoid deadlocks.
   if (PyGILState_Check()) {
     Py_BEGIN_ALLOW_THREADS
-    RunGraph(cur_graph_id, inputs, outputs);
+    RunGraph(cur_graph_id, inputs.tensors, outputs.tensors);
     Py_END_ALLOW_THREADS
   } else {
-    RunGraph(cur_graph_id, inputs, outputs);
+    RunGraph(cur_graph_id, inputs.tensors, outputs.tensors);
   }
 
   ScalarMemContext::GetContext().Reset();
diff --git a/torch_npu/npu/__init__.py b/torch_npu/npu/__init__.py
index e447648e91393ed264027f8a1d74a2b8da03b833..77e7f68c2298e27c12da5276763a87870a82f8a2 100644
--- a/torch_npu/npu/__init__.py
+++ b/torch_npu/npu/__init__.py
@@ -32,7 +32,7 @@ __all__ = [
     "FloatTensor", "IntTensor", "DoubleTensor", "LongTensor", "ShortTensor",
     "CharTensor", "ByteTensor", "HalfTensor", "set_mm_bmm_format_nd", "get_mm_bmm_format_nd",
     "get_npu_overflow_flag", "clear_npu_overflow_flag", "get_rng_state", "set_rng_state",
-    "get_rng_state_all", "set_rng_state_all",
+    "get_rng_state_all", "set_rng_state_all", "make_replay_graph"
 ]
 
 import torch
@@ -55,8 +55,7 @@ from .memory import (_free_mutex, caching_allocator_alloc, caching_allocator_del
                      max_memory_allocated, memory_reserved, max_memory_reserved,
                      memory_cached, max_memory_cached, memory_snapshot, memory_summary)
 from .streams import Stream, Event
-from .graph import (is_graph_mode, disable_graph_mode, enable_graph_mode,
-                    launch_graph, enable_replay_graph_mode, disable_replay_graph_mode)
+from .graph import is_graph_mode, disable_graph_mode, enable_graph_mode, launch_graph
 from .replay_graph import make_replay_graph
 from . import profiler
 from .npu_frontend_enhance import (set_option, set_aoe, profile, prof_init,
diff --git a/torch_npu/npu/graph.py b/torch_npu/npu/graph.py
index 8c2d79f4cb5ab0fbed9fe85f2f15b1fc2c98f61b..d6719314acf8af63ab3f7c7111f755500918d25e 100644
--- a/torch_npu/npu/graph.py
+++ b/torch_npu/npu/graph.py
@@ -26,15 +26,6 @@ def disable_graph_mode():
     torch_npu._C._npu_disable_graph_mode()
 
 
-def enable_replay_graph_mode(verbose=False):
-    torch_npu._C._npu_enable_replay_graph_mode(verbose)
-
-
-def disable_replay_graph_mode():
-    _lazy_init()
-    torch_npu._C._npu_disable_replay_graph_mode()
-
-
 def is_graph_mode() -> bool:
     return torch_npu._C._npu_is_graph_mode()
 
diff --git a/torch_npu/npu/replay_graph.py b/torch_npu/npu/replay_graph.py
index 3225d12ece4add73e977a9fc0432b1235f1b9f38..0cf20dce54a5e34dc325bc209df01afdeffad0f4 100644
--- a/torch_npu/npu/replay_graph.py
+++ b/torch_npu/npu/replay_graph.py
@@ -15,50 +15,57 @@
 import torch
 import torch_npu
+from contextlib import contextmanager
 
 
 class ReplayGraph(torch_npu._C._NPUReplayGraphBase):
     def __new__(cls, **kwargs):
         return super(ReplayGraph, cls).__new__(cls, **kwargs)
 
-    def generate_replay_graph(self, inputs, assigned_outputs,
-                              returnable_outputs, retain_inner_outputs=False):
+    def __generate_replay_graph(self, inputs: list, assigned_outputs: list,
+                                returnable_outputs: list, retain_inner_outputs: bool=False):
         super(ReplayGraph, self).generate_replay_graph(inputs, assigned_outputs,
                                                        returnable_outputs, retain_inner_outputs)
 
-    def replay(self, inputs, assigned_outputs):
+    def __replay(self, inputs: list, assigned_outputs: list) -> tuple:
         return super(ReplayGraph, self).replay(inputs, assigned_outputs)
 
-    def get_inner_outputs(self, inputs):
+    def __get_inner_outputs(self, inputs: list) -> tuple:
         return super(ReplayGraph, self).get_inner_outputs(inputs)
 
-    def is_replay_cache_hit(self, inputs):
+    def __is_replay_cache_hit(self, inputs: list) -> bool:
         return super(ReplayGraph, self).is_replay_cache_hit(inputs)
 
 
 class WrapModule(object):
-    def __init__(self, module, func, warm_up_step=3, verbose=False):
+    def __init__(self, module, fwd_func, warm_up_step=3, verbose=False):
         self.module = module
-        self.func = func
+        self.fwd_func = fwd_func
         self.warm_up_step = warm_up_step
         self.cur_step = 0
         self.fwd_graph = None
         self.bwd_graph = None
-        self.call_func = None
         self.param_grad = []
         self.verbose = verbose
 
-    def wrap_forward(self, *args, **kwargs):
-        origin_inputs = []
+    def __wrap_forward(self, *args, **kwargs):
         for arg in args:
             if isinstance(arg, torch.Tensor):
                 arg.requires_grad_(True)
-                origin_inputs.append(arg)
+            else:
+                raise TypeError("All args should be tensor in replay graph mode")
+
+        for arg in kwargs.values():
+            if not isinstance(arg, torch.Tensor):
+                raise TypeError("All args should be tensor in replay graph mode")
 
         replay_cache = False
-        if (self.fwd_graph is not None):
-            replay_cache = self.fwd_graph.is_replay_cache_hit(origin_inputs)
+        if self.fwd_graph is not None:
+            origin_inputs = [arg for arg in args if isinstance(arg, torch.Tensor)]
+            replay_cache = self.fwd_graph._ReplayGraph__is_replay_cache_hit(origin_inputs)
+            del origin_inputs
 
-        if (self.cur_step < self.warm_up_step) or not (replay_cache):
+        if self.cur_step < self.warm_up_step or not replay_cache:
+            self.cur_step = self.cur_step + 1
             for p in self.module.parameters():
                 p.grad = torch.zeros_like(p)
             shallow_args = ()
@@ -74,76 +81,58 @@ class WrapModule(object):
                     tu = (arg,)
                     shallow_args = shallow_args + tu
 
-            torch_npu.npu.enable_replay_graph_mode(self.verbose)
+            with enable_replay_graph_mode(self.verbose):
+                shallow_fwd_output = self.fwd_func(*shallow_args, **kwargs)
 
-            shallow_fwd_output = self.func(*shallow_args, **kwargs)
-            fwd_graph_inputs = []
-            fwd_graph_inputs.extend(fwd_inputs)
-            fwd_graph_inputs.extend(self.module.parameters())
-            fwd_graph_inputs.extend(self.module.buffers())
-            fwd_assigned_outputs = []
-            if (self.fwd_graph is None):
-                self.fwd_graph = generate_replay_graph(inputs=fwd_graph_inputs,
+                fwd_graph_info = [fwd_inputs, self.module.parameters(), self.module.buffers()]
+                fwd_graph_inputs = []
+                for fwd_info in fwd_graph_info:
+                    fwd_graph_inputs.extend(fwd_info)
+                fwd_assigned_outputs = []
+                self.fwd_graph = generate_replay_graph(replay_graph=self.fwd_graph,
+                                                       inputs=fwd_graph_inputs,
                                                        assigned_outputs=fwd_assigned_outputs,
                                                        returnable_outputs=[shallow_fwd_output],
                                                        retain_inner_outputs=True)
-            else:
-                self.fwd_graph.generate_replay_graph(inputs=fwd_graph_inputs,
-                                                     assigned_outputs=fwd_assigned_outputs,
-                                                     returnable_outputs=[shallow_fwd_output],
-                                                     retain_inner_outputs=True)
-
-            saved_var = self.fwd_graph.get_inner_outputs(inputs=origin_inputs)
-            grad_input = torch.empty_like(shallow_fwd_output)
-            torch.autograd.backward(shallow_fwd_output, grad_input)
-            self.param_grad = []
-            for p in self.module.parameters():
-                if p.grad is not None:
-                    self.param_grad.append(p.grad)
-
-            grad_output = []
-            for fwd_input in fwd_inputs:
-                grad_output.append(fwd_input.grad)
-
-            bwd_graph_inputs = []
-            bwd_graph_inputs.extend(fwd_graph_inputs)
-            bwd_graph_inputs.extend(saved_var)
-            bwd_graph_inputs.append(grad_input)
-            bwd_graph_inputs.extend(self.param_grad)
-            bwd_graph_inputs.extend([shallow_fwd_output])
-            if (self.bwd_graph is None):
-                self.bwd_graph = generate_replay_graph(inputs=bwd_graph_inputs,
+                saved_var = self.fwd_graph._ReplayGraph__get_inner_outputs(inputs=fwd_inputs)
+                grad_input = torch.empty_like(shallow_fwd_output)
+                torch.autograd.backward(shallow_fwd_output, grad_input)
+
+                self.param_grad = [p.grad for p in self.module.parameters() if p.grad is not None]
+                grad_output = [fwd_input.grad for fwd_input in fwd_inputs]
+                bwd_graph_info = [fwd_graph_inputs, saved_var, [grad_input], self.param_grad, [shallow_fwd_output]]
+                bwd_graph_inputs = []
+                for bwd_info in bwd_graph_info:
+                    bwd_graph_inputs.extend(bwd_info)
+                self.bwd_graph = generate_replay_graph(replay_graph=self.bwd_graph,
+                                                       inputs=bwd_graph_inputs,
                                                        assigned_outputs=self.param_grad,
                                                        returnable_outputs=grad_output)
-            else:
-                self.bwd_graph.generate_replay_graph(inputs=bwd_graph_inputs,
-                                                     assigned_outputs=self.param_grad,
-                                                     returnable_outputs=grad_output)
-
-            torch_npu.npu.disable_replay_graph_mode()
-            self.cur_step = self.cur_step + 1
+                del saved_var, fwd_graph_inputs, bwd_graph_inputs, grad_input, grad_output, shallow_fwd_output,\
+                    fwd_inputs, shallow_input
 
         class ReplayFunction(torch.autograd.Function):
             @staticmethod
             def forward(ctx, *args, **kwargs):
-                fwd_inputs = []
-                for arg in args:
-                    if isinstance(arg, torch.Tensor):
-                        fwd_inputs.append(arg)
-
+                fwd_inputs = [arg for arg in args if isinstance(arg, torch.Tensor)]
+                fwd_inputs_full_info = [fwd_inputs, self.module.parameters(), self.module.buffers()]
                 fwd_inputs_full = []
-                fwd_inputs_full.extend(fwd_inputs)
-                fwd_inputs_full.extend(self.module.parameters())
-                fwd_inputs_full.extend(self.module.buffers())
+                for info in fwd_inputs_full_info:
+                    fwd_inputs_full.extend(info)
                 fwd_assigned_outputs = []
-                fwd_output = self.fwd_graph.replay(inputs=fwd_inputs_full, assigned_outputs=fwd_assigned_outputs)
-                save_var = self.fwd_graph.get_inner_outputs(inputs=fwd_inputs)
+
+                fwd_output = self.fwd_graph._ReplayGraph__replay(inputs=fwd_inputs_full,
+                                                                 assigned_outputs=fwd_assigned_outputs)
+                save_var = self.fwd_graph._ReplayGraph__get_inner_outputs(inputs=fwd_inputs)
                 ctx.fwd_input = fwd_inputs
                 ctx.saved_var = save_var
-                ctx.output = fwd_output[0]
-                fwd_output[0].requires_grad_(True)
+                if fwd_output[0] is not None:
+                    ctx.output = fwd_output[0]
+                    fwd_output[0].requires_grad_(True)
+                else:
+                    raise ValueError("Forward output has no value")
                 return fwd_output
 
             @staticmethod
@@ -162,31 +151,41 @@ class WrapModule(object):
                     p.grad = torch.zeros_like(p)
                     self.param_grad.append(p.grad)
 
+                bwd_inputs_full_info = [ctx.fwd_input, self.module.parameters(), self.module.buffers(),
+                                        ctx.saved_var, grad_outputs, self.param_grad, [ctx.output]]
                 bwd_inputs_full = []
-                bwd_inputs_full.extend(ctx.fwd_input)
-                bwd_inputs_full.extend(self.module.parameters())
-                bwd_inputs_full.extend(self.module.buffers())
-                bwd_inputs_full.extend(ctx.saved_var)
-                bwd_inputs_full.extend(grad_outputs)
-                bwd_inputs_full.extend(self.param_grad)
-                bwd_inputs_full.extend([ctx.output])
-                bwd_output = self.bwd_graph.replay(inputs=bwd_inputs_full, assigned_outputs=self.param_grad)
-                ctx.saved_var = []
-                ctx.output = []
+                for info in bwd_inputs_full_info:
+                    bwd_inputs_full.extend(info)
+
+                bwd_output = self.bwd_graph._ReplayGraph__replay(inputs=bwd_inputs_full,
+                                                                 assigned_outputs=self.param_grad)
+                if bwd_output is None:
+                    raise ValueError("Backward output has no value")
+                ctx.saved_var, ctx.output = [], []
                 return bwd_output
 
         ret = ReplayFunction.apply(*args, **kwargs)
+        if ret[0] is None:
+            raise ValueError("ReplayFunction return has no value")
         return ret[0]
 
 
-def make_replay_graph(module, verbose_=False):
+def make_replay_graph(module: torch.nn.Module, verbose_: bool=False) -> torch.nn.Module:
     wrap_module = WrapModule(module, module.forward, verbose=verbose_)
-    module.forward = wrap_module.wrap_forward
+    module.forward = wrap_module._WrapModule__wrap_forward
     module.is_replay_graph = True
     return module
 
 
-def generate_replay_graph(inputs, assigned_outputs, returnable_outputs, retain_inner_outputs=False):
-    replay_graph = ReplayGraph()
-    replay_graph.generate_replay_graph(inputs, assigned_outputs, returnable_outputs, retain_inner_outputs)
-    return replay_graph
\ No newline at end of file
+def generate_replay_graph(replay_graph: ReplayGraph, inputs: list, assigned_outputs: list,
+                          returnable_outputs: list, retain_inner_outputs: bool=False) -> ReplayGraph:
+    if replay_graph is None:
+        replay_graph = ReplayGraph()
+    replay_graph._ReplayGraph__generate_replay_graph(inputs, assigned_outputs, returnable_outputs, retain_inner_outputs)
+    return replay_graph
+
+
+@contextmanager
+def enable_replay_graph_mode(verbose: bool=False):
+    torch_npu.npu.enable_graph_mode(verbose)
+    yield 1
+    torch_npu.npu.disable_graph_mode()