From 84836b5731d7dca092d6dfcc0ce700380c748d08 Mon Sep 17 00:00:00 2001
From: weidandan 00687068
Date: Mon, 20 Feb 2023 11:40:09 +0800
Subject: [PATCH] =?UTF-8?q?=E6=95=B0=E6=8D=AE=E9=A2=84=E5=A4=84=E7=90=86?=
 =?UTF-8?q?=E5=8A=A8=E6=80=81=E6=8E=A7=E5=88=B6device=E5=86=85=E5=AD=98?=
 =?UTF-8?q?=E5=8D=A0=E7=94=A8?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../kernels/aicpu/host_queue_dataset_op.cc    | 68 ++++++++++++++++++-
 1 file changed, 67 insertions(+), 1 deletion(-)

diff --git a/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc b/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc
index f032e347f..7b96b25f4 100644
--- a/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc
+++ b/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc
@@ -197,6 +197,8 @@ class HostQueueDatasetOp : public DatasetOpKernel {
     for (size_t i = 0UL; i < output_shape_size; i++) {
       DataType tensor_data_type = output_types_.at(i);
       if (tensor_data_type == DT_STRING) {
+        // String tensors switch the sender to the byte-budgeted path (SendDataByControl).
+        string_flag = true;
         ADP_LOG(INFO) << "Current tensor type is DT_STRING.";
         return kStringTypeDepth;
       }
@@ -643,6 +645,28 @@ class HostQueueDatasetOp : public DatasetOpKernel {
         return true;
       }
 
+      // Sends one batch over the ACL channel while capping the bytes queued on the device.
+      // Polls the channel occupancy and sends as soon as at most one element is queued,
+      // or the recorded sizes of the queued elements sum to no more than kMaxBytes.
+      // Returns the status of SendDataByAclQueue, or an internal error when the
+      // occupancy query fails.
+      Status SendDataByControl(const std::vector<Tensor> &args, const acltdtTensorType &data_type) {
+        while (true) {
+          size_t size = 0UL;
+          // acltdtQueryChannelSize reports the occupancy through an out-pointer.
+          const aclError acl_status = acltdtQueryChannelSize(acl_handle_, &size);
+          if (acl_status != ACL_SUCCESS) {
+            ADP_LOG(ERROR) << "acltdtQueryChannelSize failed: acl_status = " << acl_status;
+            return errors::Internal("acltdtQueryChannelSize failed: acl_status = ", acl_status);
+          }
+          if ((size <= 1UL) || (Queue_.ComputeSumBeforeRear(size) <= kMaxBytes)) {
+            return SendDataByAclQueue(args, data_type);
+          }
+          // Over budget: let other threads run before polling again.
+          sched_yield();
+        }
+      }
+
       void SendDataByQueueThread(const std::shared_ptr &ctx) {
         ADP_LOG(INFO) << "Begin to send data to the NPU. ";
         rtError_t rt = rtSetDevice(dataset()->device_id_);
@@ -686,18 +710,28 @@ class HostQueueDatasetOp : public DatasetOpKernel {
             if (buffer_.front().host_thread_finished) {
               data_type = buffer_.front().status.ok() ? ACL_TENSOR_DATA_END_OF_SEQUENCE : ACL_TENSOR_DATA_ABNORMAL;
             } else {
+              uint64_t bytes_sum = 0ULL;
               args = buffer_.front().value;
               buffer_.pop_front();
               for (auto &tensor : args) {
                 total_bytes_ -= tensor.TotalBytes();
+                bytes_sum += tensor.TotalBytes();
               }
+              // Record this batch's size so SendDataByControl can total what is queued.
+              Queue_.EnQueue(bytes_sum);
             }
             ADP_LOG(INFO) << "Host queue " << dataset()->channel_name_
                           << ", buffer_size: " << buffer_.size() << ", data_type:" << data_type;
           }
           Status status;
           if (dataset()->channel_type_ == ChannelType::ACL_QUEUE) {
-            status = SendDataByAclQueue(args, data_type);
+            // NOTE(review): string_flag appears to be declared on the enclosing op class
+            // (last hunk); confirm it is accessible from this nested iterator.
+            if (string_flag) {
+              status = SendDataByControl(args, data_type);
+            } else {
+              status = SendDataByAclQueue(args, data_type);
+            }
           } else {
             status = SendDataByHostQueue(args, data_type);
           }
@@ -976,6 +1010,36 @@ class HostQueueDatasetOp : public DatasetOpKernel {
       acltdtChannelHandle *acl_handle_;
       uint32_t queue_id_;
       int active_thread_num = 0;
+      // Fixed-capacity ring that remembers the byte size of the most recently dequeued batches.
+      class SqQueue {
+       public:
+        // Stores the byte size of one batch, overwriting the oldest slot once the ring wraps.
+        void EnQueue(uint64_t data) {
+          args_bytes[rear] = data;
+          rear = (rear + 1UL) % kDepth;
+        }
+
+        // Sums the `size` entries recorded just before the newest one; the newest entry is
+        // skipped. `size` is capped at kDepth - 1 so the walk never wraps onto the newest
+        // slot or indexes outside the ring.
+        uint64_t ComputeSumBeforeRear(size_t size) const {
+          const size_t count = (size < kDepth - 1UL) ? size : (kDepth - 1UL);
+          uint64_t sum = 0ULL;
+          // Offset 1 is the newest entry, so start at offset 2.
+          for (size_t i = 2UL; i < count + 2UL; i++) {
+            sum += args_bytes[(rear + kDepth - i) % kDepth];
+          }
+          return sum;
+        }
+
+       private:
+        static constexpr size_t kDepth = static_cast<size_t>(kStringTypeDepth);
+        // Per-instance storage, zero-initialized so unused slots add nothing to the sum.
+        uint64_t args_bytes[kDepth] = {};
+        size_t rear = 0UL;  // next slot to write
+      };
+
+      SqQueue Queue_;
     };
     const std::vector inputs_;
     std::string channel_name_;
@@ -999,6 +1063,8 @@ class HostQueueDatasetOp : public DatasetOpKernel {
   ChannelType channel_type_;
   uint32_t queue_id_;
   std::string queue_name_;
+  // True once a DT_STRING output has been seen; selects the byte-budgeted send path.
+  bool string_flag = false;
 };
 
 REGISTER_KERNEL_BUILDER(Name("HostQueueDataset").Device(DEVICE_CPU), HostQueueDatasetOp);
-- 
Gitee