diff --git a/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc b/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc index f78ead500bac32e2d66cf41ead5da3491ee7fa2c..3bd5ce25c5976a0e1159b439b0ede23250f3b70c 100644 --- a/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc +++ b/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc @@ -693,7 +693,7 @@ class HostQueueDatasetOp : public DatasetOpKernel { total_bytes_ -= tensor.TotalBytes(); } } - ADP_LOG(INFO) << "Host queue " << dataset()->channel_name_ + ADP_LOG(INFO) << "Host queue " << dataset()->channel_name_ << " buffer_size: " << buffer_.size() << ", data_type:" << data_type; } Status status; diff --git a/tf_adapter/optimizers/dp_tf_to_ge_conversion_pass.cc b/tf_adapter/optimizers/dp_tf_to_ge_conversion_pass.cc index 17f2686408a5958dfe35f6260a45047e407aa7ff..9c9ec3b9fbabd149da9d3f1cb5023e0b0cb8d93d 100644 --- a/tf_adapter/optimizers/dp_tf_to_ge_conversion_pass.cc +++ b/tf_adapter/optimizers/dp_tf_to_ge_conversion_pass.cc @@ -45,7 +45,10 @@ static std::atomic g_channel_index(1); const static std::vector GE_OPS_WHITELIST = { "MapDataset", "ParallelMapDataset", "BatchDataset", "MapAndBatchDataset", "DeviceQueueDataset", "BatchDatasetV2", "MapAndBatchDatasetV2", "ModelDataset", "OptimizeDataset"}; - +// Indicates the chips that allows datesets to be executed on the npu +const static std::set kToNpuDatasetChips = { + "Ascend910", "Ascend910A", "Ascend910B", "Ascend910ProA", "Ascend910ProB", "Ascend910PremiumA" +}; // Used for 0-input NodeDefBuilder const static std::vector EMPTY_DEF_INPUT; // Used for 0-input NodeBuilder @@ -72,6 +75,8 @@ class DpTfToGEConversionPassImpl { inline bool IsIteratorNode(const Node &n) const; inline bool IsSkipDataset(const Node &n) const; inline bool IsGeSupportDataset(const Node &n) const; + inline bool NeedDeviceDataset(const Node &n, const std::string &socVersion) const; + inline bool IsSupportNpuDatasetChip(const string &socVersion) const; inline std::string GetEdgeName(const Edge *e) const; inline std::string GetRandomName(const std::string &prefix) const; std::string GetRandomName() const; @@ -79,9 +84,12 @@ class DpTfToGEConversionPassImpl { inline bool CheckNode(const std::string &op) const; inline bool IsDeviceSupportedOp(const NodeDef &n) const; inline bool IsDeviceSupportedFunc(const std::string &fn) const; - inline Status GetSplitEdges(const Node &n, std::vector &split_edges, const Edge *last_edge); + inline Status GetSplitEdges(const Node &n, std::vector &split_edges, + const Edge *last_edge, const std::string &socVersion); inline void RemoveSplitEdges(Node *topo_end); - Status InsertChannelQueue(Node *topo_end, std::string &host_queue_name, std::string &device_queue_name, + Status InsertChannelQueue(Node *topo_end, std::string &host_queue_name, + std::string &device_queue_name, + const std::string &socVersion, const std::map &all_options) const; bool GetNodeFuncs(const FunctionLibraryDefinition &flib_def, const Node &node, std::vector &node_funcs) const; bool RemoveIsolatedNode(Graph &g, std::unordered_set visited) const; @@ -98,7 +106,7 @@ class DpTfToGEConversionPassImpl { const std::string &fn_geop_dataset, const string &default_device, const std::map &all_options) const; Status BuildGeOpDatasetFunction(FunctionDefLibrary &fdeflib, const Graph &device_graph, - const std::string &fn_geop_dataset, const string &default_device, + const std::string &fn_geop_dataset, const string &default_device, const std::map &all_options) const; Status AddGeOpDatasetFunctionLibrary(FunctionLibraryDefinition *flib, const Node &topo_end, const std::string &device_channel_name, const std::string &fn_geop_dataset, @@ -108,7 +116,7 @@ class DpTfToGEConversionPassImpl { const std::string &device_channel_name) const; void AddOptionAttr(std::vector nodes, const std::map &all_options) const; bool GetSkipOptimizeFlag(const std::map &pass_options, - const OptimizationPassRegistry::Grouping pass_group_value) const; + const OptimizationPassRegistry::Grouping pass_group_value) const; // graph num int graph_run_num_; // All split edges, split edges means edges that combine A and B in this case @@ -151,6 +159,10 @@ inline bool DpTfToGEConversionPassImpl::IsGeSupportDataset(const Node &n) const return std::find(GE_OPS_WHITELIST.begin(), GE_OPS_WHITELIST.end(), n.type_string()) != GE_OPS_WHITELIST.end(); } +inline bool DpTfToGEConversionPassImpl::IsSupportNpuDatasetChip(const string &socVersion) const { + return (kToNpuDatasetChips.find(socVersion) != kToNpuDatasetChips.end()); +} + inline std::string DpTfToGEConversionPassImpl::GetEdgeName(const Edge *e) const { if (e == nullptr || e->src() == nullptr || e->dst() == nullptr) { return "invalid_edge"; @@ -318,13 +330,13 @@ inline bool DpTfToGEConversionPassImpl::IsDeviceSupportedFunc(const std::string } inline Status DpTfToGEConversionPassImpl::GetSplitEdges(const Node &n, std::vector &split_edges, - const Edge *last_edge) { + const Edge *last_edge, const std::string &socVersion) { if (IsMakeIteratorNode(n)) { for (const Edge *e : n.in_edges()) { REQUIRES_NOT_NULL(e); if (!IsIteratorNode(*(e->src()))) { last_edge = e; - ADP_LOG(INFO) << "last edge" << GetEdgeName(last_edge); + ADP_LOG(INFO) << "Last edge is " << GetEdgeName(last_edge); } } } @@ -339,18 +351,17 @@ inline Status DpTfToGEConversionPassImpl::GetSplitEdges(const Node &n, std::vect "optimize"); } // GE supported node, continue find - if (kIsHeterogeneous) { + if ((kIsHeterogeneous) || (!IsSupportNpuDatasetChip(socVersion))) { if (!IsIteratorNode(*(e->src()))) { split_edges.push_back(last_edge); } } else if (IsDeviceSupportedOp(e->src()->def())) { - Status s = GetSplitEdges(*(e->src()), split_edges, last_edge); + Status s = GetSplitEdges(*(e->src()), split_edges, last_edge, socVersion); if (!s.ok()) { return s; } } else { // GE unsupported node, this is a split edge - ADP_LOG(INFO) << "Split_" << GetEdgeName(e); - ADP_LOG(INFO) << "Begin check split edge."; + ADP_LOG(INFO) << "Begin check split edge [Split_" + GetEdgeName(e) + "]."; if (IsSkipDataset(*(e->dst()))) { ADP_LOG(INFO) << "ADD last edge " << GetEdgeName(last_edge); split_edges.push_back(last_edge); @@ -364,8 +375,20 @@ inline Status DpTfToGEConversionPassImpl::GetSplitEdges(const Node &n, std::vect return Status::OK(); } +inline bool DpTfToGEConversionPassImpl::NeedDeviceDataset(const Node &n, + const std::string &socVersion) const { + if (kIsHeterogeneous) { + return false; + } + if (!NpuAttrs::GetNewDataTransferFlag()) { + return true; + } + return (IsSupportNpuDatasetChip(socVersion) ? IsGeSupportDataset(n) : false); +} + Status DpTfToGEConversionPassImpl::InsertChannelQueue(Node *topo_end, std::string &host_queue_name, std::string &device_queue_name, + const std::string &socVersion, const std::map &all_options) const { ADP_LOG(INFO) << "Start to insert HostQueueDataset and DeviceQueueDataset."; REQUIRES_NOT_NULL(topo_end); @@ -378,18 +401,10 @@ Status DpTfToGEConversionPassImpl::InsertChannelQueue(Node *topo_end, std::strin REQUIRES_NOT_NULL(e); REQUIRES_NOT_NULL(e->src()); REQUIRES_NOT_NULL(e->dst()); - bool need_add_device_dataset = false; - if (kIsHeterogeneous) { - need_add_device_dataset = false; - } else if ((!NpuAttrs::GetNewDataTransferFlag()) || (IsGeSupportDataset(*(e->dst())))) { - need_add_device_dataset = true; - } else { - need_add_device_dataset = false; - } - std::string local_rank_id = all_options.at("local_rank_id"); std::string local_device_list = all_options.at("local_device_list"); std::string channel_name; + bool need_add_device_dataset = NeedDeviceDataset(*(e->dst()), socVersion); if (local_rank_id == "-1") { REQUIRES_NOT_NULL(iterator_node); if (!need_add_device_dataset) { @@ -554,15 +569,16 @@ Status DpTfToGEConversionPassImpl::AddDataTransDatasets(Node *topo_end, std::str std::string &device_channel_name, const std::map &all_options) { const Edge *tmp_edge = nullptr; - Status ret = GetSplitEdges(*topo_end, split_edges_[topo_end], tmp_edge); + const char *soc_name = aclrtGetSocName(); + const std::string socVersion = (soc_name == nullptr) ? "" : soc_name; + Status ret = GetSplitEdges(*topo_end, split_edges_[topo_end], tmp_edge, socVersion); if (!ret.ok()) { return ret; } - // Start optimize graph // Insert Host and Device queue - ADP_LOG(INFO) << "Start to add host and device queue on split edges"; - ret = InsertChannelQueue(topo_end, host_channel_name, device_channel_name, all_options); + ADP_LOG(INFO) << "Start to add host and device queue on split edges, SOC_VERSION [" + socVersion +"]."; + ret = InsertChannelQueue(topo_end, host_channel_name, device_channel_name, socVersion, all_options); if (!ret.ok()) { return ret; } diff --git a/tf_adapter/tests/depends/ascendcl/src/ascendcl_stub.cc b/tf_adapter/tests/depends/ascendcl/src/ascendcl_stub.cc index a118d583faa48dccdce73dc0c8e894398c7d3382..277d68aecd22a777b1bdc04cc54e4eb50893d5cc 100644 --- a/tf_adapter/tests/depends/ascendcl/src/ascendcl_stub.cc +++ b/tf_adapter/tests/depends/ascendcl/src/ascendcl_stub.cc @@ -329,6 +329,10 @@ aclError aclrtDestroyEvent(aclrtEvent event) { return ACL_ERROR_NONE; } +const char *aclrtGetSocName() { + return "Ascend910B"; +} + // for GE RunGraph api #if 0 aclError aclrtSynchronizeStream(aclrtStream stream) {