From 456eca3afdff99cc41f64549047246b33d60f936 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=B7=E6=AC=A2?= Date: Wed, 1 Feb 2023 21:29:00 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=A4=9A=E7=BA=BF?= =?UTF-8?q?=E7=A8=8B=E5=90=AF=E5=8A=A8=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tf_adapter/kernels/aicpu/host_queue_dataset_op.cc | 2 +- tf_adapter/util/host_thread_pool.cc | 15 +++++++++++++-- tf_adapter/util/host_thread_pool.h | 5 ++++- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc b/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc index f032e347f..d7248ac1f 100644 --- a/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc +++ b/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc @@ -914,7 +914,7 @@ class HostQueueDatasetOp : public DatasetOpKernel { } } if (dataset()->channel_type_ == ChannelType::ACL_QUEUE) { - TF_RETURN_IF_ERROR(thread_pool_.Init(dataset()->device_id_)); + TF_RETURN_IF_ERROR(thread_pool_.Init(ctx, dataset()->device_id_)); } { mutex_lock lck(mu_); diff --git a/tf_adapter/util/host_thread_pool.cc b/tf_adapter/util/host_thread_pool.cc index 7710277a7..d716789ab 100644 --- a/tf_adapter/util/host_thread_pool.cc +++ b/tf_adapter/util/host_thread_pool.cc @@ -22,11 +22,12 @@ namespace { const uint32_t MAX_THREAD_NUM = 4U; } namespace tensorflow { +namespace data { HostThreadPool::HostThreadPool() : thread_stop_flag_(false), device_id_(0U) {} HostThreadPool::~HostThreadPool() {} - Status HostThreadPool::Init(uint32_t device_id) { + Status HostThreadPool::Init(IteratorContext *ctx, uint32_t device_id) { ADP_LOG(INFO) << "Start to start thread pool."; device_id_ = device_id; copy_thread_pool_.resize(MAX_THREAD_NUM); @@ -37,8 +38,17 @@ namespace tensorflow { for (size_t idx = 0UL; idx < copy_thread_pool_.size(); idx++) { if (copy_thread_pool_[idx] == nullptr) { std::string thread_name = "thread_pool" + std::to_string(idx); - copy_thread_pool_[idx].reset( + if (ctx == nullptr) { + copy_thread_pool_[idx].reset( Env::Default()->StartThread({}, thread_name, [this]() { ParallelForCopyThread(); })); + } else { + copy_thread_pool_[idx].reset( + ctx->env()->StartThread({}, thread_name, [this]() { ParallelForCopyThread(); })); + std::cout << "1111 " << Env::Default() << ";" << std::endl; + std::cout << "2222 " << ctx->env() << std::endl; + ADP_LOG(ERROR) << "2222->" << ctx->env(); + ADP_LOG(ERROR) << "11111->" << Env::Default(); + } } } return Status::OK(); @@ -73,4 +83,5 @@ namespace tensorflow { thread_stop_flag_.store(true); queue_var_.notify_all(); } +} } \ No newline at end of file diff --git a/tf_adapter/util/host_thread_pool.h b/tf_adapter/util/host_thread_pool.h index 9f5f5cfc0..3f76733c0 100644 --- a/tf_adapter/util/host_thread_pool.h +++ b/tf_adapter/util/host_thread_pool.h @@ -24,13 +24,15 @@ #include #include #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/lib/core/threadpool.h" namespace tensorflow { +namespace data { class HostThreadPool { public: HostThreadPool(); - Status Init(uint32_t device_id); + Status Init(IteratorContext *ctx, uint32_t device_id); void PushTask(const std::function &closure); void StopThreadPool(); ~HostThreadPool(); @@ -44,4 +46,5 @@ class HostThreadPool { uint32_t device_id_; }; } +} #endif \ No newline at end of file -- Gitee From baa02e1ee72a62befb4ecfdfc82acdf62ca9bcec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=B7=E6=AC=A2?= Date: Thu, 2 Feb 2023 10:31:08 +0800 Subject: [PATCH 2/2] =?UTF-8?q?Revert=20"=E4=BF=AE=E6=94=B9=E5=A4=9A?= =?UTF-8?q?=E7=BA=BF=E7=A8=8B=E5=90=AF=E5=8A=A8=E6=96=B9=E5=BC=8F"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit 456eca3afdff99cc41f64549047246b33d60f936. --- tf_adapter/kernels/aicpu/host_queue_dataset_op.cc | 2 +- tf_adapter/util/host_thread_pool.cc | 15 ++------------- tf_adapter/util/host_thread_pool.h | 5 +---- 3 files changed, 4 insertions(+), 18 deletions(-) diff --git a/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc b/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc index d7248ac1f..f032e347f 100644 --- a/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc +++ b/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc @@ -914,7 +914,7 @@ class HostQueueDatasetOp : public DatasetOpKernel { } } if (dataset()->channel_type_ == ChannelType::ACL_QUEUE) { - TF_RETURN_IF_ERROR(thread_pool_.Init(ctx, dataset()->device_id_)); + TF_RETURN_IF_ERROR(thread_pool_.Init(dataset()->device_id_)); } { mutex_lock lck(mu_); diff --git a/tf_adapter/util/host_thread_pool.cc b/tf_adapter/util/host_thread_pool.cc index d716789ab..7710277a7 100644 --- a/tf_adapter/util/host_thread_pool.cc +++ b/tf_adapter/util/host_thread_pool.cc @@ -22,12 +22,11 @@ namespace { const uint32_t MAX_THREAD_NUM = 4U; } namespace tensorflow { -namespace data { HostThreadPool::HostThreadPool() : thread_stop_flag_(false), device_id_(0U) {} HostThreadPool::~HostThreadPool() {} - Status HostThreadPool::Init(IteratorContext *ctx, uint32_t device_id) { + Status HostThreadPool::Init(uint32_t device_id) { ADP_LOG(INFO) << "Start to start thread pool."; device_id_ = device_id; copy_thread_pool_.resize(MAX_THREAD_NUM); @@ -38,17 +37,8 @@ namespace data { for (size_t idx = 0UL; idx < copy_thread_pool_.size(); idx++) { if (copy_thread_pool_[idx] == nullptr) { std::string thread_name = "thread_pool" + std::to_string(idx); - if (ctx == nullptr) { - copy_thread_pool_[idx].reset( + copy_thread_pool_[idx].reset( Env::Default()->StartThread({}, thread_name, [this]() { ParallelForCopyThread(); })); - } else { - copy_thread_pool_[idx].reset( - ctx->env()->StartThread({}, thread_name, [this]() { ParallelForCopyThread(); })); - std::cout << "1111 " << Env::Default() << ";" << std::endl; - std::cout << "2222 " << ctx->env() << std::endl; - ADP_LOG(ERROR) << "2222->" << ctx->env(); - ADP_LOG(ERROR) << "11111->" << Env::Default(); - } } } return Status::OK(); @@ -83,5 +73,4 @@ namespace data { thread_stop_flag_.store(true); queue_var_.notify_all(); } -} } \ No newline at end of file diff --git a/tf_adapter/util/host_thread_pool.h b/tf_adapter/util/host_thread_pool.h index 3f76733c0..9f5f5cfc0 100644 --- a/tf_adapter/util/host_thread_pool.h +++ b/tf_adapter/util/host_thread_pool.h @@ -24,15 +24,13 @@ #include #include #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/lib/core/threadpool.h" namespace tensorflow { -namespace data { class HostThreadPool { public: HostThreadPool(); - Status Init(IteratorContext *ctx, uint32_t device_id); + Status Init(uint32_t device_id); void PushTask(const std::function &closure); void StopThreadPool(); ~HostThreadPool(); @@ -46,5 +44,4 @@ class HostThreadPool { uint32_t device_id_; }; } -} #endif \ No newline at end of file -- Gitee