diff --git a/torch_npu/csrc/InitNpuBindings.cpp b/torch_npu/csrc/InitNpuBindings.cpp index 44df05c2f7d5acb8acd5f42005ee86811bed7fc3..79a8620758b0312d87421df0d98fdc9e774ac530 100644 --- a/torch_npu/csrc/InitNpuBindings.cpp +++ b/torch_npu/csrc/InitNpuBindings.cpp @@ -19,6 +19,7 @@ #include #include "torch_npu/csrc/npu/Event.h" +#include "torch_npu/csrc/npu/ReplayFunctions.h" #include "torch_npu/csrc/core/npu/NPUCachingAllocator.h" #include "torch_npu/csrc/framework/graph/execute/GraphExecutor.h" #include "torch_npu/csrc/core/npu/sys_ctrl/npu_sys_ctrl.h" @@ -77,6 +78,7 @@ static PyMethodDef TorchNpuMethods[] = { void THNPStream_init(PyObject *module); void THNPEvent_init(PyObject *module); +void THNPReplayGraph_init(PyObject *module); bool THPGenerator_init(PyObject *module); PyMethodDef* THNPModule_get_methods(); @@ -112,6 +114,7 @@ PyObject* initModule(){ // C, so these lines have to execute first).. THNPStream_init(module); THNPEvent_init(module); + THNPReplayGraph_init(module); THPGenerator_init(module); torch_npu::autograd::initTorchFunctions(module); diff --git a/torch_npu/csrc/framework/graph/ReplayGraph.cpp b/torch_npu/csrc/framework/graph/ReplayGraph.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8fbb78c0e6ada5819775e5cb443862d9332ea970 --- /dev/null +++ b/torch_npu/csrc/framework/graph/ReplayGraph.cpp @@ -0,0 +1,280 @@ +// Copyright (c) 2020 Huawei Technologies Co., Ltd +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
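+
+// ReplayGraph captures the GE graph that GenerateGraph builds for one
+// iteration, caches it keyed by a hash of the first input's shape, and
+// re-executes it in Replay by rebinding input/output device buffers on the
+// cached ge::Tensor handles before calling GraphExecutor::RunGraph.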
+
+#include
+#include "ReplayGraph.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace at_npu {
+namespace native {
+bool ReplayGraphImpl::ReplayCacheHit(const at::TensorList& inputs) {
+    auto input_tensor_shape = inputs[0].sizes();
+    return replay_graph_cache_.FindGraphCache(input_tensor_shape);
+}
+
+TensorVec ReplayGraphImpl::Replay(const at::TensorList& inputs, at::TensorList assigned_outputs) {
+    int64_t info_list_index = 0;
+    auto input_tensor_shape = inputs[0].sizes();
+    info_list_index = replay_graph_cache_.ReturnGraphCache(input_tensor_shape);
+    auto& graphinfo = graph_infolist_[info_list_index];
+    TORCH_CHECK(inputs.size() == graphinfo.inputs.python_inputs_info.size(),
+        "Replay must have the same number of inputs as the generated graph");
+    TORCH_CHECK(assigned_outputs.size() == graphinfo.outputs.python_assigned_outputs_info.size(),
+        "Replay must have the same number of assigned outputs as the generated graph");
+
+    for (size_t i = 0; i < inputs.size(); i++) {
+        const int64_t& idx = graphinfo.inputs.inputs_mapping[i];
+        if (idx >= 0) {
+            if (NpuUtils::check_match(&inputs[i])) {
+                auto data_ptr = inputs[i].data_ptr();
+                TORCH_CHECK(data_ptr != nullptr, "Input for replay graph must have a data ptr");
+                size_t numel = NPUNativeFunctions::get_storage_size(inputs[i]);
+                graphinfo.inputs.graph_inputs_ge_tensors[idx].SetData(reinterpret_cast<uint8_t*>(data_ptr),
+                    numel * inputs[i].itemsize(), [](uint8_t* device_ptr) { return; });
+            } else {
+                auto contiguous_input = NpuUtils::format_contiguous(inputs[i]);
+                auto data_ptr = contiguous_input.data_ptr();
+                TORCH_CHECK(data_ptr != nullptr, "Input for replay graph must have a data ptr");
+                size_t numel = NPUNativeFunctions::get_storage_size(contiguous_input);
+                graphinfo.inputs.graph_inputs_ge_tensors[idx].SetData(reinterpret_cast<uint8_t*>(data_ptr),
+                    numel * contiguous_input.itemsize(), [](uint8_t* device_ptr) { return; });
+            }
+        }
+    }
+
+    for (size_t i = 0; i < assigned_outputs.size(); i++) {
+        const int64_t& idx = graphinfo.outputs.assigned_outputs_mapping[i];
+        size_t numel = NPUNativeFunctions::get_storage_size(assigned_outputs[i]);
+        auto data_ptr = assigned_outputs[i].data_ptr();
+        graphinfo.outputs.graph_outputs_ge_tensors[idx].SetData(reinterpret_cast<uint8_t*>(data_ptr),
+            numel * assigned_outputs[i].itemsize(), [](uint8_t* device_ptr) { return; });
+    }
+
+    std::vector<at::Tensor> returnable_outputs;
+    const std::vector<TensorInfo>& returnable_info = graphinfo.outputs.python_returnable_outputs_info;
+    for (size_t i = 0; i < graphinfo.outputs.python_returnable_outputs_info.size(); i++) {
+        auto options = at::TensorOptions().dtype(returnable_info[i].dtype).
+            device(at_npu::key::NativeDeviceType);
+        auto tensor = NPUNativeFunctions::empty_with_format(returnable_info[i].sizes,
+            optTypeMetaToScalarType(options.dtype_opt()),
+            options.layout_opt(), options.device_opt(),
+            options.pinned_memory_opt(),
+            returnable_info[i].tensor_desc.npu_format_);
+        const int64_t& idx = graphinfo.outputs.returnable_outputs_mapping[i];
+        size_t numel = NPUNativeFunctions::get_storage_size(tensor);
+        graphinfo.outputs.graph_outputs_ge_tensors[idx].SetData(reinterpret_cast<uint8_t*>(tensor.data_ptr()),
+            numel * tensor.itemsize(), [](uint8_t* device_ptr) { return; });
+        returnable_outputs.emplace_back(tensor);
+    }
+
+    if (this->retain_inner_output_) {
+        graphinfo.outputs.inner_outputs_tensors.clear();
+        const std::vector<TensorInfo>& inner_outputs_info = graphinfo.outputs.inner_outputs_info;
+        for (size_t i = 0; i < graphinfo.outputs.inner_outputs_info.size(); i++) {
+            auto options = at::TensorOptions().dtype(inner_outputs_info[i].dtype).
+                device(at_npu::key::NativeDeviceType);
+            auto tensor = NPUNativeFunctions::empty_with_format(inner_outputs_info[i].tensor_desc.base_sizes_,
+                optTypeMetaToScalarType(options.dtype_opt()),
+                options.layout_opt(), options.device_opt(),
+                options.pinned_memory_opt(),
+                inner_outputs_info[i].tensor_desc.npu_format_);
+            const int64_t& idx = graphinfo.outputs.inner_outputs_mapping[i];
+            size_t numel = NPUNativeFunctions::get_storage_size(tensor);
+            graphinfo.outputs.graph_outputs_ge_tensors[idx].
+                SetData(reinterpret_cast<uint8_t*>(tensor.data_ptr()),
+                    numel * tensor.itemsize(), [](uint8_t* device_ptr) { return; });
+            graphinfo.outputs.inner_outputs_tensors.emplace_back(tensor);
+        }
+    }
+
+    GraphExecutor::GetInstance().RunGraph(graphinfo.graph_id_, graphinfo.inputs.graph_inputs_ge_tensors,
+        graphinfo.outputs.graph_outputs_ge_tensors);
+
+    return returnable_outputs;
+}
+
+int64_t ReplayGraphImpl::FindMapping(const std::vector<uint64_t>& graph_uid,
+                                     const torch_npu::NpuGraphDesc& desc) {
+    uint64_t uid = desc.unique_id;
+    for (size_t i = 0; i < graph_uid.size(); i++) {
+        if (uid == graph_uid[i]) {
+            return static_cast<int64_t>(i);
+        }
+    }
+    return -1;
+}
+
+void ReplayGraphImpl::BuildReplayGraphInfo(const at::TensorList& tensors, std::vector<TensorInfo>& tensorinfo,
+                                           CombinedInfo& combinedinfo, std::vector<int64_t>& map) {
+    const auto& unique_ids = combinedinfo.unique_ids;
+    for (const auto& tensor : tensors) {
+        torch_npu::NPUStorageDesc& storage_desc = torch_npu::NPUBridge::GetNpuStorageImpl(tensor)->npu_desc_;
+        tensorinfo.emplace_back(TensorInfo(tensor.sizes().vec(),
+            tensor.strides().vec(), tensor.storage_offset(),
+            tensor.dtype(), storage_desc));
+        torch_npu::NpuGraphDesc& graph_desc = torch_npu::NPUBridge::
+            GetNpuStorageImpl(tensor)->get_mutable_npu_graph_desc();
+        map.emplace_back(FindMapping(unique_ids, graph_desc));
+    }
+}
+
+void ReplayGraphImpl::SetInnerOutput(CombinedInfo& outputcombinedinfo, ReplayGraphInfo& graphinfo) {
+    std::vector<int64_t> id_mask = {0};
+    const auto& out_ids = outputcombinedinfo.unique_ids;
+    id_mask.resize(out_ids.size());
+    for (const auto& map_id : graphinfo.outputs.returnable_outputs_mapping) {
+        id_mask[map_id] = 1;
+    }
+
+    auto full_output_storages = NpuGraphContextManager::GetInstance().
+        GetAllStorageOfLiveTensors(c10_npu::current_device());
+    std::vector<c10::StorageImpl*> output_storages;
+    for (const auto& s : full_output_storages) {
+        if (!(GraphUtils::IsTensorWithoutNode(s) || GraphUtils::IsDataTensor(s))) {
+            output_storages.emplace_back(s);
+        }
+    }
+
+    for (size_t i = 0UL; i < id_mask.size(); i++) {
+        if (id_mask[i] == 1) {
+            continue;
+        }
+        const auto inner_id = out_ids[i];
+        for (size_t storage_idx = 0; storage_idx < output_storages.size(); storage_idx++) {
+            auto& graph_desc = torch_npu::NPUBridge::
+                GetNpuStorageImpl(output_storages[storage_idx])->get_mutable_npu_graph_desc();
+            if (GraphUtils::IsTensorWithoutNode(output_storages[storage_idx])) {
+                continue;
+            }
+            if (graph_desc.unique_id == inner_id) {
+                graphinfo.outputs.inner_outputs_mapping.emplace_back(i);
+                const auto& storage_desc = torch_npu::NPUBridge::
+                    GetNpuStorageImpl(output_storages[storage_idx])->get_npu_desc();
+                std::vector<int64_t> sizes;
+                for (const auto& size : storage_desc.base_sizes_) {
+                    sizes.emplace_back(size);
+                }
+                std::vector<int64_t> strides;
+                for (const auto& stride : storage_desc.base_strides_) {
+                    strides.emplace_back(stride);
+                }
+                graphinfo.outputs.inner_outputs_info.emplace_back(TensorInfo(sizes, strides,
+                    storage_desc.base_offset_, storage_desc.data_type_, storage_desc));
+                break;
+            }
+        }
+    }
+
+    graphinfo.outputs.inner_outputs_tensors.clear();
+    for (size_t i = 0; i < graphinfo.outputs.inner_outputs_info.size(); i++) {
+        const int64_t& idx = graphinfo.outputs.inner_outputs_mapping[i];
+        c10::intrusive_ptr<c10::StorageImpl> storage_impl = c10::intrusive_ptr<c10::StorageImpl>::
+            unsafe_reclaim_from_nonowning(output_storages[idx]);
+        auto tensor = at::detail::make_tensor(storage_impl, storage_impl,
+            graphinfo.outputs.inner_outputs_info[i].dtype);
+        graphinfo.outputs.inner_outputs_tensors.emplace_back(tensor);
+    }
+}
+
+void ReplayGraphImpl::GenerateGraph(const at::TensorList& inputs, at::TensorList assigned_outputs,
+                                    at::TensorList returnable_outputs, bool retain_inner_output) {
+    auto input_tensor_shape = inputs[0].sizes();
+    int64_t info_list_index = replay_graph_cache_.AddGraphCache(input_tensor_shape, graph_infolist_);
+    auto& graphinfo = graph_infolist_[info_list_index];
+    this->retain_inner_output_ = retain_inner_output;
+    GraphExecutor::GetInstance().CheckDeviceIdAndInit();
+    ScalarMemContext::GetContext().ExecuteH2D(c10_npu::getCurrentNPUStream());
+    auto input_info = GraphExecutor::GetInstance().GetInputCombinedInfo();
+    auto output_info = GraphExecutor::GetInstance().GetOutputCombinedInfo();
+    const auto& out_ids = output_info.unique_ids;
+    const auto& in_ids = input_info.unique_ids;
+
+    for (const auto& ge_input_tensor : input_info.tensors) {
+        graphinfo.inputs.graph_inputs_ge_tensors.emplace_back(ge_input_tensor);
+    }
+
+    for (const auto& ge_output_tensor : output_info.tensors) {
+        graphinfo.outputs.graph_outputs_ge_tensors.emplace_back(ge_output_tensor);
+    }
+
+    BuildReplayGraphInfo(inputs, graphinfo.inputs.python_inputs_info, input_info,
+        graphinfo.inputs.inputs_mapping);
+    BuildReplayGraphInfo(assigned_outputs, graphinfo.outputs.python_assigned_outputs_info, output_info,
+        graphinfo.outputs.assigned_outputs_mapping);
+    BuildReplayGraphInfo(returnable_outputs, graphinfo.outputs.python_returnable_outputs_info, output_info,
+        graphinfo.outputs.returnable_outputs_mapping);
+
+    bool is_cache_hit = false;
+    graphinfo.graph_id_ = GraphExecutor::GetInstance().GetGraphIdDependOnCompileTypeAndCache(
+        input_info, output_info, is_cache_hit);
+    if (this->retain_inner_output_) {
+        SetInnerOutput(output_info, graphinfo);
+    }
+
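+    // The graph has been added to the GE session and cached above; clear the
+    // single-op capture state so that subsequent eager ops or the next
+    // GenerateGraph call start from a clean slate.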
+    ScalarMemContext::GetContext().Reset();
+    GraphExecutor::GetInstance().ResetGraphOutputs();
+    GraphExecutor::GetInstance().RefreshGraphInputs();
+    GraphExecutor::GetInstance().ClearDataStore();
+    return;
+}
+
+TensorVec ReplayGraphImpl::GetInnerOutputs(const at::TensorList& inputs) {
+    auto input_tensor_shape = inputs[0].sizes();
+    auto info_list_index = replay_graph_cache_.ReturnGraphCache(input_tensor_shape);
+    auto& graphinfo = graph_infolist_[info_list_index];
+    if (this->retain_inner_output_) {
+        return graphinfo.outputs.inner_outputs_tensors;
+    }
+    AT_ERROR("GetInnerOutputs requires the graph to be generated with retain_inner_output = true");
+}
+
+TensorVec ReplayGraph::Replay(const at::TensorList& inputs, at::TensorList assigned_outputs) {
+    if (this->replay_graph_ == nullptr) {
+        AT_ERROR("replay_graph_ is nullptr!");
+    }
+    return this->replay_graph_->Replay(inputs, assigned_outputs);
+}
+
+void ReplayGraph::GenerateGraph(const at::TensorList& inputs, at::TensorList assigned_outputs,
+                                at::TensorList returnable_outputs, bool retain_inner_output) {
+    if (this->replay_graph_ == nullptr) {
+        AT_ERROR("replay_graph_ is nullptr!");
+    }
+    this->replay_graph_->GenerateGraph(inputs, assigned_outputs, returnable_outputs, retain_inner_output);
+    return;
+}
+
+TensorVec ReplayGraph::GetInnerOutputs(const at::TensorList& inputs) {
+    if (this->replay_graph_ == nullptr) {
+        AT_ERROR("replay_graph_ is nullptr!");
+    }
+    return this->replay_graph_->GetInnerOutputs(inputs);
+}
+
+bool ReplayGraph::ReplayCacheHit(const at::TensorList& inputs) {
+    if (this->replay_graph_ == nullptr) {
+        AT_ERROR("replay_graph_ is nullptr!");
+    }
+    return this->replay_graph_->ReplayCacheHit(inputs);
+}
+}
+}
\ No newline at end of file
diff --git a/torch_npu/csrc/framework/graph/ReplayGraph.h b/torch_npu/csrc/framework/graph/ReplayGraph.h
new file mode 100644
index 0000000000000000000000000000000000000000..3b7c4f81d4f544e112745a2b7431d520957f724f
--- /dev/null
+++ b/torch_npu/csrc/framework/graph/ReplayGraph.h
@@ -0,0 +1,145 @@
+#pragma once
+#include
+#include
+#include
+
+namespace at_npu {
+namespace native {
+using TensorVec = std::vector<at::Tensor>;
+
+struct TensorInfo {
+    TensorInfo(std::vector<int64_t> size, std::vector<int64_t> stride, int64_t offset,
+               caffe2::TypeMeta type, torch_npu::NPUStorageDesc desc):
+        sizes(size), strides(stride), storage_offset(offset), dtype(type),
+        tensor_desc(desc) {}
+
+    std::vector<int64_t> sizes;
+    std::vector<int64_t> strides;
+    int64_t storage_offset;
+    caffe2::TypeMeta dtype;
+    torch_npu::NPUStorageDesc tensor_desc;
+};
+
+struct ReplayGraphInputs {
+    std::vector<TensorInfo> python_inputs_info;
+    std::vector<ge::Tensor> graph_inputs_ge_tensors;
+    std::vector<int64_t> inputs_mapping;
+};
+
+struct ReplayGraphOutputs {
+    std::vector<TensorInfo> python_assigned_outputs_info;
+    std::vector<TensorInfo> python_returnable_outputs_info;
+    std::vector<TensorInfo> inner_outputs_info;
+    std::vector<ge::Tensor> graph_outputs_ge_tensors;
+    std::vector<at::Tensor> inner_outputs_tensors;
+    std::vector<int64_t> assigned_outputs_mapping;
+    std::vector<int64_t> returnable_outputs_mapping;
+    std::vector<int64_t> inner_outputs_mapping;
+};
+
+struct ReplayGraphInfo {
+    uint32_t graph_id_ = 0;
+    ReplayGraphInputs inputs;
+    ReplayGraphOutputs outputs;
+};
+
+class ReplayGraphCache {
+public:
+    ReplayGraphCache() = default;
+    ReplayGraphCache(const ReplayGraphCache& other) = delete;
+    ReplayGraphCache& operator=(const ReplayGraphCache& other) = delete;
+    ReplayGraphCache(ReplayGraphCache&& other) = delete;
+    ReplayGraphCache& operator=(ReplayGraphCache&& other) = delete;
+
+    template <typename T, typename L>
+    int64_t AddGraphCache(T hash_key_input, L& graph_list) {
+        if (!replay_cache_flag_) {
+            return 0;
+        }
+        int64_t index = 0;
+        auto key = multi_hash(hash_key_input);
+        auto cache_index = graph_cache_.find(key);
+        if (cache_index == graph_cache_.end()) {
+            index = graph_list.size();
+            graph_cache_.emplace(key, index);
+            ReplayGraphInfo info;
+            graph_list.emplace_back(info);
+        } else {
+            index = cache_index->second;
+            ReplayGraphInfo info;
+            graph_list[index] = info;
+        }
+        return index;
+    }
+
+    template <typename T>
+    int64_t ReturnGraphCache(T hash_key_input) {
+        if (!replay_cache_flag_) {
+            return 0;
+        }
+        auto cache_index = graph_cache_.find(multi_hash(hash_key_input));
+        TORCH_CHECK(cache_index != graph_cache_.end(), "The graph is not generated when replay.");
+        return cache_index->second;
+    }
+
+    template <typename T>
+    bool FindGraphCache(T hash_key_input) {
+        if (!replay_cache_flag_) {
+            return true;
+        }
+        if (graph_cache_.empty() || (graph_cache_.find(multi_hash(hash_key_input)) == graph_cache_.end())) {
+            return false;
+        }
+        return true;
+    }
+private:
+    std::unordered_map<hash_t, int64_t> graph_cache_;
+    bool replay_cache_flag_ = true;
+};
+
+class ReplayGraphImpl {
+public:
+    ReplayGraphImpl() = default;
+    ReplayGraphImpl(const ReplayGraphImpl& other) = delete;
+    ReplayGraphImpl& operator=(const ReplayGraphImpl& other) = delete;
+    ReplayGraphImpl(ReplayGraphImpl&& other) = delete;
+    ReplayGraphImpl& operator=(ReplayGraphImpl&& other) = delete;
+
+    void GenerateGraph(const at::TensorList& inputs, at::TensorList assigned_outputs,
+                       at::TensorList returnable_outputs, bool retain_inner_output);
+    TensorVec Replay(const at::TensorList& inputs, at::TensorList assigned_outputs);
+    TensorVec GetInnerOutputs(const at::TensorList& inputs);
+    bool ReplayCacheHit(const at::TensorList& inputs);
+
+private:
+    bool retain_inner_output_ = false;
+    std::vector<ReplayGraphInfo> graph_infolist_;
+    ReplayGraphCache replay_graph_cache_;
+    int64_t FindMapping(const std::vector<uint64_t>& graph_uid, const torch_npu::NpuGraphDesc& desc);
+    void BuildReplayGraphInfo(const at::TensorList& tensors, std::vector<TensorInfo>& tensorinfo,
+                              CombinedInfo& combinedinfo, std::vector<int64_t>& map);
+    void SetInnerOutput(CombinedInfo& outputcombinedinfo, ReplayGraphInfo& graphinfo);
+};
+
+class ReplayGraph {
+public:
+    ReplayGraph() : replay_graph_(std::make_shared<ReplayGraphImpl>()) {}
+    ~ReplayGraph() {
+        replay_graph_ = nullptr;
+    }
+    ReplayGraph(const ReplayGraph& other) = default;
+    ReplayGraph& operator=(const ReplayGraph& other) = default;
+    ReplayGraph(ReplayGraph&& other) = default;
+    ReplayGraph& operator=(ReplayGraph&& other) = default;
+
+    void GenerateGraph(const at::TensorList& inputs, at::TensorList assigned_outputs,
+                       at::TensorList returnable_outputs, bool retain_inner_output = false);
+    TensorVec Replay(const at::TensorList& inputs, at::TensorList assigned_outputs);
+    TensorVec GetInnerOutputs(const at::TensorList& inputs);
+    bool ReplayCacheHit(const at::TensorList& inputs);
+
+private:
+    std::shared_ptr<ReplayGraphImpl> replay_graph_;
+};
+}
+}
\ No newline at end of file
diff --git a/torch_npu/csrc/framework/graph/execute/GraphExecutor.cpp b/torch_npu/csrc/framework/graph/execute/GraphExecutor.cpp
index 8c27be10f206d9968935cd056942f3f914be46b5..4edfd205fd35f0ddc00ed398efefbac2c1cc159b 100644
--- a/torch_npu/csrc/framework/graph/execute/GraphExecutor.cpp
+++ b/torch_npu/csrc/framework/graph/execute/GraphExecutor.cpp
@@ -77,6 +77,28 @@ void GraphExecutor::RunGraph(
     }
 }
 
+void GraphExecutor::RunGraph(
+    uint32_t graph_id,
+    const std::vector<ge::Tensor>& inputs,
+    std::vector<ge::Tensor>& outputs) {
+    RECORD_HOST_FUNCTION("RunGraph", std::vector<c10::IValue>({}));
+    aclrtStream cal_stream =
+        const_cast<aclrtStream>(c10_npu::getCurrentNPUStream().stream());
+
+    auto start_time = std::chrono::steady_clock::now();
+    C10_NPU_CHECK(session_->RunGraphWithStreamAsync(graph_id,
+                                                    cal_stream,
+                                                    inputs,
+                                                    outputs));
+    auto duration = std::chrono::duration_cast<std::chrono::microseconds>(
+        std::chrono::steady_clock::now() - start_time);
+    if (verbose_) {
+        NPU_LOGI("RunGraph Time: duration = %.3f ms", static_cast<double>(duration.count()) *
+            std::chrono::microseconds::period::num /
+            std::chrono::milliseconds::period::den);
+    }
+}
+
 void GraphExecutor::ConstructAndExecuteGraph() {
     RECORD_HOST_FUNCTION("ConstructAndExecuteGraph", std::vector<c10::IValue>({}));
     auto ret = CheckDeviceIdAndInit();
@@ -92,30 +114,13 @@ void GraphExecutor::ConstructAndExecuteGraph() {
         return;
     }
 
-    uint32_t cur_graph_id = graph_id + 1;
-    auto cached_graph_id = cacher_.GetCacheGraphId(
-        inputs.hash_of_topo_and_attr,
-        inputs.hash_of_shape,
-        outputs.hash_of_topo_and_attr,
-        outputs.hash_of_shape,
-        cur_graph_id);
-
-    if (!cached_graph_id.has_value()) {
-        RECORD_HOST_FUNCTION("ConstructGraph", std::vector<c10::IValue>({}));
-        ConstructOps(outputs);
-        ge::Graph graph(kPytorchGraphName);
-        graph.SetInputs(GetInputOps()).SetOutputs(GetOutputOps());
-
-        C10_NPU_CHECK(session_->AddGraph(cur_graph_id, graph));
-        graph_id = cur_graph_id;
-    } else {
-        cur_graph_id = cached_graph_id.value();
-    }
+    bool is_cache_hit = false;
+    auto cur_graph_id = GetGraphIdDependOnCompileTypeAndCache(inputs, outputs, is_cache_hit);
 
     size_t input_number = inputs.tensors.size();
     size_t output_number = outputs.tensors.size();
     if (verbose_) {
-        string is_cache = cached_graph_id.has_value() ? "true" : "false";
+        string is_cache = is_cache_hit ? "true" : "false";
         NPU_LOGI("Using Graph Mode: current graph id = %u, cache hit = %s, input number = %zu, output number = %zu",
             cur_graph_id, is_cache.c_str(), input_number, output_number);
     }
@@ -131,7 +136,7 @@ void GraphExecutor::ConstructAndExecuteGraph() {
     ScalarMemContext::GetContext().Reset();
     ResetGraphOutputs();
 
-    if (!cached_graph_id.has_value()) {
+    if (!is_cache_hit) {
         // Data of new graph maybe inputs of old graphs,
         // GE will change its attr
         // so we need to refresh it
@@ -141,6 +146,32 @@
     return;
 }
 
+uint32_t GraphExecutor::GetGraphIdDependOnCompileTypeAndCache(const CombinedInfo& inputs,
+                                                              CombinedInfo& outputs,
+                                                              bool& is_cache_hit) {
+    uint32_t cur_graph_id = graph_id + 1;
+    auto cached_graph_id = cacher_.GetCacheGraphId(
+        inputs.hash_of_topo_and_attr,
+        inputs.hash_of_shape,
+        outputs.hash_of_topo_and_attr,
+        outputs.hash_of_shape,
+        cur_graph_id);
+
+    if (!cached_graph_id.has_value()) {
+        RECORD_HOST_FUNCTION("ConstructGraph", std::vector<c10::IValue>({}));
+        ConstructOps(outputs);
+        ge::Graph graph(kPytorchGraphName);
+        graph.SetInputs(GetInputOps()).SetOutputs(GetOutputOps());
+
+        C10_NPU_CHECK(session_->AddGraph(cur_graph_id, graph));
+        graph_id = cur_graph_id;
+    } else {
+        cur_graph_id = cached_graph_id.value();
+    }
+    is_cache_hit = cached_graph_id.has_value();
+    return cur_graph_id;
+}
+
 void GraphExecutor::Init() {
     auto device_id = std::to_string(init_device_id_);
     std::map<std::string, std::string> config = {
@@ -287,6 +318,17 @@ CombinedInfo GraphExecutor::GetInputCombinedInfo() {
         hash_t shape_hash = GraphCache::GetTensorShapeHash(topo_hash, tensor_desc);
         input_infos.hash_of_shape.push_back(shape_hash);
     }
+
+    if (replay_graph_mode) {
+        for (size_t index = 0; index < input_storages.size(); ++index) {
+            torch_npu::NpuGraphDesc& graph_desc =
+                torch_npu::NPUBridge::GetNpuStorageImpl(input_storages[index])->get_mutable_npu_graph_desc();
+            auto data_node = graph_desc.graph_value.GetDataNode();
+            if (data_node.value()->GetOpType() == kDataNodeType) {
+                input_infos.unique_ids.emplace_back(graph_desc.unique_id);
+            }
+        }
+    }
     return input_infos;
 }
 
@@ -323,6 +365,18 @@ CombinedInfo GraphExecutor::GetOutputCombinedInfo() {
         hash_t shape_hash = GraphCache::GetTensorShapeHash(topo_hash, tensor_desc);
         output_infos.hash_of_shape.push_back(shape_hash);
     }
+
+    if (replay_graph_mode) {
+        for (size_t index = 0; index < output_storages.size(); ++index) {
+            const auto& output_storage = output_storages[index];
+            if (!(GraphUtils::IsTensorWithoutNode(output_storage) ||
+                  GraphUtils::IsDataTensor(output_storage))) {
+                torch_npu::NpuGraphDesc& graph_desc =
+                    torch_npu::NPUBridge::GetNpuStorageImpl(output_storage)->get_mutable_npu_graph_desc();
+                output_infos.unique_ids.emplace_back(graph_desc.unique_id);
+            }
+        }
+    }
     return output_infos;
 }
 
diff --git a/torch_npu/csrc/framework/graph/execute/GraphExecutor.h b/torch_npu/csrc/framework/graph/execute/GraphExecutor.h
index f6553204f2cb72da8349c2666873dbb3465fd05f..cd36d6cde549a44d8dcf0bb3a539c6adabd0abe7 100644
--- a/torch_npu/csrc/framework/graph/execute/GraphExecutor.h
+++ b/torch_npu/csrc/framework/graph/execute/GraphExecutor.h
@@ -43,6 +43,7 @@ struct CombinedInfo {
     std::vector<ge::Tensor> tensors;
     std::vector<hash_t> hash_of_topo_and_attr;
     std::vector<hash_t> hash_of_shape;
+    std::vector<uint64_t> unique_ids;
 };
 
 class GraphExecutor {
@@ -60,10 +61,35 @@ public:
         return instance;
     }
 
+    void RunGraph(
+        uint32_t graph_id,
+        const std::vector<ge::Tensor>& inputs,
+        std::vector<ge::Tensor>& outputs);
+
+    uint32_t GetGraphIdDependOnCompileTypeAndCache(const CombinedInfo& inputs,
+                                                   CombinedInfo& outputs,
+                                                   bool& is_cache_hit);
+
+    bool CheckDeviceIdAndInit();
+
+    CombinedInfo GetInputCombinedInfo();
+
+    CombinedInfo GetOutputCombinedInfo();
+
+    void ResetGraphOutputs();
+
+    void RefreshGraphInputs();
+
+    void ClearDataStore();
+
     void SetVerbose(bool verbose) {
         verbose_ = verbose;
     }
 
+    void SetReplayGraphMode(bool is_replay_graph_mode) {
+        replay_graph_mode = is_replay_graph_mode;
+    }
+
     void Finalize();
 
 private:
@@ -80,8 +106,6 @@
      *
      * 2, you can not construct graph in two different device.
     */
-    bool CheckDeviceIdAndInit();
-
     void RunGraph(
         uint32_t graph_id,
         CombinedInfo& inputs,
@@ -93,10 +117,6 @@
 
     GeOutPutOpType GetOutputOps();
 
-    CombinedInfo GetInputCombinedInfo();
-
-    CombinedInfo GetOutputCombinedInfo();
-
     static ge::Tensor PrepareInputTensor(
         const c10::StorageImpl* const storage,
         const ge::TensorDesc& desc,
@@ -106,18 +126,14 @@
         c10::StorageImpl* storage,
         const ge::TensorDesc& desc);
 
-    void ResetGraphOutputs();
-
-    void RefreshGraphInputs();
-
-    void ClearDataStore();
-
     static uint32_t graph_id;
 
     c10::DeviceIndex init_device_id_ = -1;
 
     bool verbose_ = false;
 
+    bool replay_graph_mode = false;
+
     std::unique_ptr<ge::Session> session_ = nullptr;
 
     GraphCache cacher_;
diff --git a/torch_npu/csrc/npu/Module.cpp b/torch_npu/csrc/npu/Module.cpp
index 273298233739c8d4801563087e893c8c546dd122..8ef8c3265d93280f1f49dd39783214e6f778df01 100644
--- a/torch_npu/csrc/npu/Module.cpp
+++ b/torch_npu/csrc/npu/Module.cpp
@@ -228,6 +228,27 @@ PyObject* THNPModule_disable_graph_mode_wrap(PyObject* self, PyObject* noargs) {
     END_HANDLE_TH_ERRORS
 }
 
+PyObject* THNPModule_enable_replay_graph_mode_wrap(PyObject* self, PyObject* arg) {
+    HANDLE_TH_ERRORS
+    pybind11::gil_scoped_release no_gil;
+    bool verbose = THPUtils_unpackBool(arg);
+    at_npu::native::GraphExecutor::GetInstance().SetVerbose(verbose);
+    at_npu::native::GraphExecutor::GetInstance().SetReplayGraphMode(true);
+    c10_npu::NpuRunMode::SetNpuRunMode(c10_npu::ModeKind::GRAPH_MODE);
+    Py_RETURN_NONE;
+    END_HANDLE_TH_ERRORS
+}
+
+PyObject* THNPModule_disable_replay_graph_mode_wrap(PyObject* self, PyObject* noargs) {
+    HANDLE_TH_ERRORS
+    pybind11::gil_scoped_release no_gil;
+    at_npu::native::GraphExecutor::GetInstance().ConstructAndExecuteGraph();
+    at_npu::native::GraphExecutor::GetInstance().SetReplayGraphMode(false);
+    c10_npu::NpuRunMode::SetNpuRunMode(c10_npu::ModeKind::SINGLE_OP_MODE);
+    Py_RETURN_NONE;
+    END_HANDLE_TH_ERRORS
+}
+
 PyObject* THNPModule_launch_graph_wrap(PyObject* self, PyObject* noargs) {
     HANDLE_TH_ERRORS
     pybind11::gil_scoped_release no_gil;
@@ -559,6 +580,8 @@ static struct PyMethodDef THNPModule_methods[] = {
     {"_npu_setStream", (PyCFunction)THNPModule_setStream_wrap, METH_O, nullptr},
     {"_npu_enable_graph_mode", (PyCFunction)THNPModule_enable_graph_mode_wrap, METH_O, nullptr},
     {"_npu_disable_graph_mode", (PyCFunction)THNPModule_disable_graph_mode_wrap, METH_NOARGS, nullptr},
+    {"_npu_enable_replay_graph_mode", (PyCFunction)THNPModule_enable_replay_graph_mode_wrap, METH_O, nullptr},
+    {"_npu_disable_replay_graph_mode", (PyCFunction)THNPModule_disable_replay_graph_mode_wrap, METH_NOARGS, nullptr},
     {"_npu_launch_graph", (PyCFunction)THNPModule_launch_graph_wrap, METH_NOARGS, nullptr},
     {"_npu_is_graph_mode", (PyCFunction)THNPModule_is_graph_mode_wrap, METH_NOARGS, nullptr},
     {"_npu_emptyCache", (PyCFunction) THNPModule_emptyCache, METH_NOARGS, nullptr},
diff --git a/torch_npu/csrc/npu/ReplayFunctions.cpp b/torch_npu/csrc/npu/ReplayFunctions.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..69367ee619fb2ce69095ca9e4b00b44dd0bec94f
--- /dev/null
+++ b/torch_npu/csrc/npu/ReplayFunctions.cpp
@@ -0,0 +1,213 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ReplayFunctions.h"
+#include
+#include
+#include
+#include
+#include
+#include "torch/csrc/autograd/utils/wrap_outputs.h"
+
+PyObject *THNPReplayGraphClass = nullptr;
+using at::TensorList;
+static PyObject* THNPReplayGraph_pynew(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
+    HANDLE_TH_ERRORS
+    THPObjectPtr ptr(type->tp_alloc(type, 0));
+    if (!ptr) {
+        return nullptr;
+    }
+
+    THNPReplayGraph* self = (THNPReplayGraph*)ptr.get();
+    new (&self->replay_graph) at_npu::native::ReplayGraph();
+
+    return (PyObject*)ptr.release();
+    END_HANDLE_TH_ERRORS
+}
+
+static void THNPReplayGraph_dealloc(THNPReplayGraph* self) {
+    self->replay_graph.~ReplayGraph();
+    Py_TYPE(self)->tp_free((PyObject*)self);
+}
+
+PyObject* THNPReplayGraph_generate_replay_graph(THNPReplayGraph* self, PyObject* args) {
+    HANDLE_TH_ERRORS
+    PyObject* inputs = nullptr;
+    PyObject* assigned_outputs = nullptr;
+    PyObject* returnable_outputs = nullptr;
+    PyObject* retain_inner_outputs = nullptr;
+    if (!PyArg_ParseTuple(args, "OOOO", &inputs, &assigned_outputs, &returnable_outputs, &retain_inner_outputs)) {
+        THPUtils_invalidArguments(
+            args,
+            nullptr,
+            "generate_replay_graph",
+            1,
+            "(TensorList inputs, TensorList assigned_outputs, TensorList returnable_outputs, bool retain_inner_outputs);");
+        return nullptr;
+    }
+
+    static torch::PythonArgParser parser({
+        "generate_replay_graph(TensorList inputs, TensorList assigned_outputs, TensorList returnable_outputs, bool retain_inner_outputs)",
+    }, true);
+    torch::ParsedArgs<4> parsed_args;
+    auto _r = parser.parse(args, nullptr, parsed_args);
+    pybind11::gil_scoped_release no_gil;
+    self->replay_graph.GenerateGraph(_r.tensorlist(0), _r.tensorlist(1), _r.tensorlist(2), _r.toBool(3));
+    Py_RETURN_NONE;
+    END_HANDLE_TH_ERRORS
+}
+
+PyObject* THNPReplayGraph_replay(THNPReplayGraph* self, PyObject* args) {
+    HANDLE_TH_ERRORS
+    PyObject* inputs = nullptr;
+    PyObject* assigned_outputs = nullptr;
+    if (!PyArg_ParseTuple(args, "OO", &inputs, &assigned_outputs)) {
+        THPUtils_invalidArguments(
+            args,
+            nullptr,
+            "replay",
+            1,
+            "(TensorList inputs, TensorList assigned_outputs);");
+        return nullptr;
+    }
+
+    static torch::PythonArgParser parser({
+        "replay(TensorList inputs, TensorList assigned_outputs)",
+    }, true);
+    torch::ParsedArgs<2> parsed_args;
+    auto _r = parser.parse(args, nullptr, parsed_args);
+    auto call_replay = [&](const at::TensorList& inputs, at::TensorList assigned_outputs) -> std::vector<at::Tensor> {
+        pybind11::gil_scoped_release no_gil;
+        return self->replay_graph.Replay(inputs, assigned_outputs);
+    };
+    return torch::autograd::utils::wrap(call_replay(_r.tensorlist(0), _r.tensorlist(1)));
+    END_HANDLE_TH_ERRORS
+}
+
+PyObject* THNPReplayGraph_get_inner_outputs(THNPReplayGraph* self, PyObject* args) {
+    HANDLE_TH_ERRORS
+    PyObject* inputs = nullptr;
+    if (!PyArg_ParseTuple(args, "O", &inputs)) {
+        THPUtils_invalidArguments(
+            args,
+            nullptr,
+            "get_inner_outputs",
+            1,
+            "(TensorList inputs);");
+        return nullptr;
+    }
+
+    static torch::PythonArgParser parser({
+        "get_inner_outputs(TensorList inputs)",
+    }, true);
+    torch::ParsedArgs<1> parsed_args;
+    auto _r = parser.parse(args, nullptr, parsed_args);
+    auto call_get_inner_outputs = [&](const at::TensorList& inputs) -> std::vector<at::Tensor> {
+        pybind11::gil_scoped_release no_gil;
+        return self->replay_graph.GetInnerOutputs(inputs);
+    };
+    return torch::autograd::utils::wrap(call_get_inner_outputs(_r.tensorlist(0)));
+    END_HANDLE_TH_ERRORS
+}
+
+PyObject* THNPReplayGraph_is_replay_cache_hit(THNPReplayGraph* self, PyObject* args) {
+    HANDLE_TH_ERRORS
+    PyObject* inputs = nullptr;
+    if (!PyArg_ParseTuple(args, "O", &inputs)) {
+        THPUtils_invalidArguments(
+            args,
+            nullptr,
+            "is_replay_cache_hit",
+            1,
+            "(TensorList inputs);");
+        return nullptr;
+    }
+
+    static torch::PythonArgParser parser({
+        "is_replay_cache_hit(TensorList inputs)",
+    }, true);
+    torch::ParsedArgs<1> parsed_args;
+    auto _r = parser.parse(args, nullptr, parsed_args);
+    auto call_is_replay_cache_hit = [&](const at::TensorList& inputs) -> bool {
+        pybind11::gil_scoped_release no_gil;
+        return self->replay_graph.ReplayCacheHit(inputs);
+    };
+    return torch::autograd::utils::wrap(call_is_replay_cache_hit(_r.tensorlist(0)));
+    END_HANDLE_TH_ERRORS
+}
+
+static struct PyGetSetDef THNPReplayGraph_properties[] = {
+    {nullptr}
+};
+
+static PyMethodDef THNPReplayGraph_methods[] = {
+    {"generate_replay_graph", (PyCFunction)THNPReplayGraph_generate_replay_graph, METH_VARARGS, nullptr},
+    {"replay", (PyCFunction)THNPReplayGraph_replay, METH_VARARGS, nullptr},
+    {"get_inner_outputs", (PyCFunction)THNPReplayGraph_get_inner_outputs, METH_VARARGS, nullptr},
+    {"is_replay_cache_hit", (PyCFunction)THNPReplayGraph_is_replay_cache_hit, METH_VARARGS, nullptr},
+    {nullptr}
+};
+
+PyTypeObject THNPReplayGraphType = {
+    PyVarObject_HEAD_INIT(nullptr, 0)
+    "torch_npu._C._NPUReplayGraphBase",       /* tp_name */
+    sizeof(THNPReplayGraph),                  /* tp_basicsize */
+    0,                                        /* tp_itemsize */
+    (destructor)THNPReplayGraph_dealloc,      /* tp_dealloc */
+    0,                                        /* tp_vectorcall_offset */
+    0,                                        /* tp_getattr */
+    0,                                        /* tp_setattr */
+    0,                                        /* tp_reserved */
+    0,                                        /* tp_repr */
+    0,                                        /* tp_as_number */
+    0,                                        /* tp_as_sequence */
+    0,                                        /* tp_as_mapping */
+    0,                                        /* tp_hash */
+    0,                                        /* tp_call */
+    0,                                        /* tp_str */
+    0,                                        /* tp_getattro */
+    0,                                        /* tp_setattro */
+    0,                                        /* tp_as_buffer */
+    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
+    nullptr,                                  /* tp_doc */
+    0,                                        /* tp_traverse */
+    0,                                        /* tp_clear */
+    0,                                        /* tp_richcompare */
+    0,                                        /* tp_weaklistoffset */
+    0,                                        /* tp_iter */
+    0,                                        /* tp_iternext */
+    THNPReplayGraph_methods,                  /* tp_methods */
+    0,                                        /* tp_members */
+    THNPReplayGraph_properties,               /* tp_getset */
+    0,                                        /* tp_base */
+    0,                                        /* tp_dict */
+    0,                                        /* tp_descr_get */
+    0,                                        /* tp_descr_set */
+    0,                                        /* tp_dictoffset */
+    0,                                        /* tp_init */
+    0,                                        /* tp_alloc */
+    THNPReplayGraph_pynew,                    /* tp_new */
+};
+
+void THNPReplayGraph_init(PyObject* module) {
+    THNPReplayGraphClass = (PyObject*)&THNPReplayGraphType;
+    if (PyType_Ready(&THNPReplayGraphType) < 0) {
+        throw python_error();
+    }
+    Py_INCREF(&THNPReplayGraphType);
+    if (PyModule_AddObject(module, "_NPUReplayGraphBase", (PyObject*)&THNPReplayGraphType) < 0) {
+        throw python_error();
+    }
+}
\ No newline at end of file
diff --git a/torch_npu/csrc/npu/ReplayFunctions.h b/torch_npu/csrc/npu/ReplayFunctions.h
new file mode 100644
index 0000000000000000000000000000000000000000..982b8165db1bad2d5a44a167e77ad546c6f9b3e6
--- /dev/null
+++ b/torch_npu/csrc/npu/ReplayFunctions.h
@@ -0,0 +1,16 @@
+#pragma once
+#include <Python.h>
+#include "torch_npu/csrc/framework/graph/ReplayGraph.h"
+
+struct THNPReplayGraph {
+    PyObject_HEAD
+    at_npu::native::ReplayGraph replay_graph;
+};
+
+extern PyObject*
THNPReplayGraphClass; + +void THNPReplayGraph_init(PyObject *module); + +inline bool THNPReplayGraph_Check(PyObject* obj) { + return THNPReplayGraphClass && PyObject_IsInstance(obj, THNPReplayGraphClass); +} \ No newline at end of file diff --git a/torch_npu/npu/__init__.py b/torch_npu/npu/__init__.py index 3bf44f5a198309328427202181aed0e51a77a6e9..cc3815a76f8b8fa6ef13797682488cdfa16f5c19 100644 --- a/torch_npu/npu/__init__.py +++ b/torch_npu/npu/__init__.py @@ -27,7 +27,7 @@ __all__ = [ "memory_allocated", "max_memory_allocated", "memory_reserved", "max_memory_reserved", "memory_cached", "max_memory_cached", "memory_snapshot", "memory_summary", "Stream", "Event", "profiler", "set_option", "set_aoe", "profile", "prof_init", - "prof_start", "prof_stop", "prof_finalize", "iteration_start", "iteration_end", + "prof_start", "prof_stop", "prof_finalize", "iteration_start", "iteration_end", "profileConfig", "_in_bad_fork", "set_compile_mode", "FloatTensor", "IntTensor", "DoubleTensor", "LongTensor", "ShortTensor", "CharTensor", "ByteTensor", "HalfTensor", "set_mm_bmm_format_nd", "get_mm_bmm_format_nd", @@ -52,7 +52,9 @@ from .memory import (_free_mutex, caching_allocator_alloc, caching_allocator_del max_memory_allocated, memory_reserved, max_memory_reserved, memory_cached, max_memory_cached, memory_snapshot, memory_summary) from .streams import Stream, Event -from .graph import is_graph_mode, disable_graph_mode, enable_graph_mode, launch_graph +from .graph import (is_graph_mode, disable_graph_mode, enable_graph_mode, + launch_graph, enable_replay_graph_mode, disable_replay_graph_mode) +from .replay_graph import make_replay_graph from . import profiler from .npu_frontend_enhance import (set_option, set_aoe, profile, prof_init, prof_start, prof_stop, prof_finalize, iteration_start, iteration_end, diff --git a/torch_npu/npu/graph.py b/torch_npu/npu/graph.py index d6719314acf8af63ab3f7c7111f755500918d25e..8c2d79f4cb5ab0fbed9fe85f2f15b1fc2c98f61b 100644 --- a/torch_npu/npu/graph.py +++ b/torch_npu/npu/graph.py @@ -26,6 +26,15 @@ def disable_graph_mode(): torch_npu._C._npu_disable_graph_mode() +def enable_replay_graph_mode(verbose=False): + torch_npu._C._npu_enable_replay_graph_mode(verbose) + + +def disable_replay_graph_mode(): + _lazy_init() + torch_npu._C._npu_disable_replay_graph_mode() + + def is_graph_mode() -> bool: return torch_npu._C._npu_is_graph_mode() diff --git a/torch_npu/npu/replay_graph.py b/torch_npu/npu/replay_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..768302c889c510aa865f6c495815b8509bdadcec --- /dev/null +++ b/torch_npu/npu/replay_graph.py @@ -0,0 +1,192 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
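+
+# This module wraps an nn.Module so that its forward and backward passes are
+# captured once as NPU replay graphs during warm-up steps and then re-executed
+# via ReplayGraph.replay on later steps, rebinding fresh input/gradient
+# buffers into the cached GE graphs instead of rebuilding them.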
+ +import torch +import torch_npu + +class ReplayGraph(torch_npu._C._NPUReplayGraphBase): + def __new__(cls, **kwargs): + return super(ReplayGraph, cls).__new__(cls, **kwargs) + + def generate_replay_graph(self, inputs, assigned_outputs, + returnable_outputs, retain_inner_outputs=False): + super(ReplayGraph, self).generate_replay_graph(inputs, assigned_outputs, + returnable_outputs, retain_inner_outputs) + + def replay(self, inputs, assigned_outputs): + return super(ReplayGraph, self).replay(inputs, assigned_outputs) + + def get_inner_outputs(self, inputs): + return super(ReplayGraph, self).get_inner_outputs(inputs) + + def is_replay_cache_hit(self, inputs): + return super(ReplayGraph, self).is_replay_cache_hit(inputs) + + +class WrapModule(object): + def __init__(self, module, func, warm_up_step=3, verbose=False): + self.module = module + self.func = func + self.warm_up_step = warm_up_step + self.cur_step = 0 + self.fwd_graph = None + self.bwd_graph = None + self.call_func = None + self.param_grad = [] + self.verbose = verbose + + def wrap_forward(self, *args, **kwargs): + origin_inputs = [] + for arg in args: + if isinstance(arg, torch.Tensor): + arg.requires_grad_(True) + origin_inputs.append(arg) + + replay_cache = False + if (self.fwd_graph is not None): + replay_cache = self.fwd_graph.is_replay_cache_hit(origin_inputs) + + if (self.cur_step < self.warm_up_step) or not (replay_cache): + for p in self.module.parameters(): + p.grad = torch.zeros_like(p) + shallow_args = () + fwd_inputs = [] + for arg in args: + if isinstance(arg, torch.Tensor): + shallow_input = torch.empty_like(arg) + shallow_input.requires_grad_(True) + tu = (shallow_input,) + fwd_inputs.append(shallow_input) + shallow_args = shallow_args + tu + else: + tu = (arg,) + shallow_args = shallow_args + tu + + torch_npu.npu.enable_replay_graph_mode(self.verbose) + + shallow_fwd_output = self.func(*shallow_args, **kwargs) + fwd_graph_inputs = [] + fwd_graph_inputs.extend(fwd_inputs) + fwd_graph_inputs.extend(self.module.parameters()) + fwd_graph_inputs.extend(self.module.buffers()) + fwd_assigned_outputs = [] + if (self.fwd_graph is None): + self.fwd_graph = generate_replay_graph(inputs=fwd_graph_inputs, + assigned_outputs=fwd_assigned_outputs, + returnable_outputs=[shallow_fwd_output], + retain_inner_outputs=True) + else: + self.fwd_graph.generate_replay_graph(inputs=fwd_graph_inputs, + assigned_outputs=fwd_assigned_outputs, + returnable_outputs=[shallow_fwd_output], + retain_inner_outputs=True) + + saved_var = self.fwd_graph.get_inner_outputs(inputs=origin_inputs) + grad_input = torch.empty_like(shallow_fwd_output) + torch.autograd.backward(shallow_fwd_output, grad_input) + + self.param_grad = [] + for p in self.module.parameters(): + if p.grad is not None: + self.param_grad.append(p.grad) + + grad_output = [] + for fwd_input in fwd_inputs: + grad_output.append(fwd_input.grad) + + bwd_graph_inputs = [] + bwd_graph_inputs.extend(fwd_graph_inputs) + bwd_graph_inputs.extend(saved_var) + bwd_graph_inputs.append(grad_input) + bwd_graph_inputs.extend(self.param_grad) + bwd_graph_inputs.extend([shallow_fwd_output]) + if (self.bwd_graph is None): + self.bwd_graph = generate_replay_graph(inputs=bwd_graph_inputs, + assigned_outputs=self.param_grad, + returnable_outputs=grad_output) + else: + self.bwd_graph.generate_replay_graph(inputs=bwd_graph_inputs, + assigned_outputs=self.param_grad, + returnable_outputs=grad_output) + + torch_npu.npu.disable_replay_graph_mode() + + self.cur_step = self.cur_step + 1 + + class 
ReplayFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, *args, **kwargs): + fwd_inputs = [] + for arg in args: + if isinstance(arg, torch.Tensor): + fwd_inputs.append(arg) + + fwd_inputs_full = [] + fwd_inputs_full.extend(fwd_inputs) + fwd_inputs_full.extend(self.module.parameters()) + fwd_inputs_full.extend(self.module.buffers()) + fwd_assigned_outputs = [] + fwd_output = self.fwd_graph.replay(inputs=fwd_inputs_full, assigned_outputs=fwd_assigned_outputs) + save_var = self.fwd_graph.get_inner_outputs(inputs=origin_inputs) + ctx.fwd_input = fwd_inputs + ctx.saved_var = save_var + ctx.output = fwd_output[0] + fwd_output[0].requires_grad_(True) + return fwd_output + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, *grad_outputs): + need_init_grad = False + for p in self.module.parameters(): + if p.grad is None: + need_init_grad = True + break + + if need_init_grad: + self.param_grad = [] + for p in self.module.parameters(): + if p.grad is None: + p.grad = torch.zeros_like(p) + self.param_grad.append(p.grad) + + bwd_inputs_full = [] + bwd_inputs_full.extend(ctx.fwd_input) + bwd_inputs_full.extend(self.module.parameters()) + bwd_inputs_full.extend(self.module.buffers()) + bwd_inputs_full.extend(ctx.saved_var) + bwd_inputs_full.extend(grad_outputs) + bwd_inputs_full.extend(self.param_grad) + bwd_inputs_full.extend([ctx.output]) + bwd_output = self.bwd_graph.replay(inputs=bwd_inputs_full, assigned_outputs=self.param_grad) + ctx.saved_var = [] + ctx.output = [] + return bwd_output + + ret = ReplayFunction.apply(*args, **kwargs) + return ret[0] + + +def make_replay_graph(module, verbose_=False): + wrap_module = WrapModule(module, module.forward, verbose=verbose_) + module.forward = wrap_module.wrap_forward + module.is_replay_graph = True + return module + + +def generate_replay_graph(inputs, assigned_outputs, returnable_outputs, retain_inner_outputs=False): + replay_graph = ReplayGraph() + replay_graph.generate_replay_graph(inputs, assigned_outputs, returnable_outputs, retain_inner_outputs) + return replay_graph \ No newline at end of file diff --git a/torch_npu/npu/utils.py b/torch_npu/npu/utils.py index 54c8b98b2bbf055a0b811454c6a091bddf546ffb..9c3d744a12f1150a277c04f33b008cc916c661ec 100644 --- a/torch_npu/npu/utils.py +++ b/torch_npu/npu/utils.py @@ -331,6 +331,10 @@ if not hasattr(torch_npu._C, '_NPUStreamBase'): torch_npu._C.__dict__['_NPUStreamBase'] = _dummy_type('NPUStreamBase') torch_npu._C.__dict__['_NPUEventBase'] = _dummy_type('NPUEventBase') +if not hasattr(torch_npu._C, '_NPUReplayGraphBase'): + # Define dummy base classes + torch_npu._C.__dict__['_NPUReplayGraphBase'] = _dummy_type('NPUReplayGraphBase') + def init_dump(): torch_npu.npu._lazy_init()
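A minimal end-to-end sketch of the Python API this patch exposes. The module, shapes, and step count below are hypothetical, and a working torch_npu build with an NPU device is assumed; only make_replay_graph (and the replay-graph mode toggles it drives internally) come from this change:

    import torch
    import torch_npu

    model = torch.nn.Linear(16, 4).npu()            # hypothetical module
    model = torch_npu.npu.make_replay_graph(model, verbose_=False)

    x = torch.randn(8, 16).npu()
    for step in range(5):
        out = model(x)          # warm-up steps capture fwd/bwd replay graphs
        loss = out.sum()
        loss.backward()         # later steps replay the cached graphs

make_replay_graph patches module.forward with WrapModule.wrap_forward: the first warm_up_step calls (and any cache miss for a new input shape) re-capture the graphs, after which forward and backward run through ReplayGraph.replay.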