From 2b1c20eba25a95ccc85d87d8910745b493fa3c30 Mon Sep 17 00:00:00 2001 From: yaolun Date: Mon, 20 Mar 2023 11:44:24 +0800 Subject: [PATCH] =?UTF-8?q?=E9=A2=84=E5=A4=84=E7=90=86=E7=AE=97=E5=AD=90?= =?UTF-8?q?=E5=91=8A=E8=AD=A6=E6=B8=85=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tf_adapter/kernels/aicpu/dataset_function.cc | 11 ++++------- tf_adapter/kernels/aicpu/dataset_function.h | 8 ++++---- tf_adapter/kernels/aicpu/stream_pool.h | 4 +++- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/tf_adapter/kernels/aicpu/dataset_function.cc b/tf_adapter/kernels/aicpu/dataset_function.cc index 38811c113..4b39a2408 100644 --- a/tf_adapter/kernels/aicpu/dataset_function.cc +++ b/tf_adapter/kernels/aicpu/dataset_function.cc @@ -624,9 +624,8 @@ Status DatasetFunction::InitAccelateOpList(std::vector &acc_while_l return Status::OK(); } -void DatasetFunction::MarkDvppGraphNodes(Graph &sub_graph_tf, std::vector &acc_nodes, - std::vector &dvpp_graph_nodes, - const std::vector acc_while_list) { +void DatasetFunction::MarkDvppGraphNodes(Graph &sub_graph_tf, std::vector &dvpp_graph_nodes, + const std::vector acc_while_list) const { for (Node *node : sub_graph_tf.nodes()) { if (node->IsSource() || node->IsSink()) { continue; } @@ -639,7 +638,6 @@ void DatasetFunction::MarkDvppGraphNodes(Graph &sub_graph_tf, std::vector }; for (std::string item : acc_while_list) { if (node->type_string().find(item) != string::npos) { - acc_nodes.emplace_back(node); tensorflow::DFSFrom(sub_graph_tf, {node}, {}, leave, {}, {}); } } @@ -869,9 +867,8 @@ std::string DatasetFunction::SplitSubGraph(FunctionLibraryDefinition &flib_def, // mark all the nodes with white list in acc_nodes // dvpp_graph_nodes stores acc_nodes and nodes which depend on the output of nodes in acc_nodes - std::vector acc_nodes; std::vector dvpp_graph_nodes; - MarkDvppGraphNodes(sub_graph_tf, acc_nodes, dvpp_graph_nodes, acc_while_list); + MarkDvppGraphNodes(sub_graph_tf, dvpp_graph_nodes, acc_while_list); // handle const node, save the const node which connect host graph and dvpp graph, // we will copy this const node for dvpp graph MarkConstNodes(sub_graph_tf, dvpp_graph_nodes); @@ -952,7 +949,7 @@ Status DatasetFunction::CreateGeGraph(const std::shared_ptr & return Status::OK(); } -bool DatasetFunction::IsSplitGraph() { +bool DatasetFunction::IsSplitGraph() const { return run_split_graph_; } diff --git a/tf_adapter/kernels/aicpu/dataset_function.h b/tf_adapter/kernels/aicpu/dataset_function.h index 83fb15c84..661067dbb 100644 --- a/tf_adapter/kernels/aicpu/dataset_function.h +++ b/tf_adapter/kernels/aicpu/dataset_function.h @@ -117,7 +117,7 @@ class DatasetFunction { static void DestoryAclModelDesc(aclmdlDesc *model_desc); static Status GetAclTenorDescDims(aclTensorDesc *desc, std::vector &ret_dims); static void *ReAllocDeviceMem(void *addr, size_t len); - bool IsSplitGraph(); + bool IsSplitGraph() const; static inline bool CheckMultiplyOverflow(int64_t a, int64_t b) { const static int64_t max_int64 = INT64_MAX; @@ -173,8 +173,8 @@ class DatasetFunction { } private: - void MarkDvppGraphNodes(Graph &sub_graph_tf, std::vector &acc_nodes, - std::vector &dvpp_graph_nodes, const std::vector acc_while_list); + void MarkDvppGraphNodes(Graph &sub_graph_tf, std::vector &dvpp_graph_nodes, + const std::vector acc_while_list) const; void MarkConstNodes(const Graph &sub_graph_tf, std::vector &dvpp_graph_nodes) const; bool CheckCorrectness(const tensorflow::Graph &sub_graph_tf, const std::vector dvpp_graph_nodes, const std::vector host_graph_nodes) const; @@ -187,7 +187,7 @@ class DatasetFunction { const std::map dvpp_arg_indexs) const; std::string SplitSubGraph(FunctionLibraryDefinition &flib_def, const std::vector acc_while_list); Status InitAccelateOpList(std::vector &acc_while_list) const; - Status ReadJsonFile(const string &file_name, nlohmann::json &json_read) const; + Status ReadJsonFile(const string &json_file_path, nlohmann::json &json_read) const; tensorflow::DataType EdgeDataType(const tensorflow::Edge &edge) const; std::string GetSocVersion() const; diff --git a/tf_adapter/kernels/aicpu/stream_pool.h b/tf_adapter/kernels/aicpu/stream_pool.h index 4d5234a39..f454311a7 100644 --- a/tf_adapter/kernels/aicpu/stream_pool.h +++ b/tf_adapter/kernels/aicpu/stream_pool.h @@ -255,6 +255,8 @@ public: ADP_LOG(ERROR) << "[StreamPool] Failed to reset cur_event_num_ memory. size_count=" << size_count; } streams_.resize(pos_stream_num, nullptr); + // reset max_stream_ to be the same as the number of threads + max_stream_ = pos_stream_num; } ~StreamPool() { @@ -322,7 +324,7 @@ public: private: std::mutex mtx_; - const uint64_t max_stream_; + uint64_t max_stream_; const uint64_t max_task_; uint64_t cur_stream_num_ = 0; uint64_t *cur_event_num_ = nullptr; -- Gitee