diff --git a/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc b/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc
index f032e347fcb918dc9baf4d059d6b3dd4aa2b045d..6f56379182ab4d15cf5dd7bc6c29af16abab291f 100644
--- a/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc
+++ b/tf_adapter/kernels/aicpu/host_queue_dataset_op.cc
@@ -13,6 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include <sched.h>
 #include
 #include
 #include
@@ -280,6 +281,13 @@ class HostQueueDatasetOp : public DatasetOpKernel {
                 dataset()->device_id_,
                 dataset()->local_device_list_,
                 dataset()->channel_name_)) {
+        // Enable the buffered-bytes check before sending when any output type is DT_STRING.
+        for (const DataType tensor_data_type : dataset()->output_types_) {
+          if (tensor_data_type == DT_STRING) {
+            is_hold_data_trans = true;
+            break;
+          }
+        }
       }
 
       ~Iterator() override {
@@ -581,7 +589,11 @@ class HostQueueDatasetOp : public DatasetOpKernel {
 
       Status SendDataByAclQueue(const vector &args, const acltdtTensorType &data_type) {
         Status status;
         bool is_need_resend = false;
+        // True while the channel's estimated buffered bytes exceed kMaxBytes, so the send is deferred.
+        bool is_need_recompute_mbuf = false;
+        // True once args has been handed over successfully; only then is its size recorded.
+        bool is_sent = false;
         do {
           {
             mutex_lock lck(mu_);
@@ -589,8 +601,36 @@ class HostQueueDatasetOp : public DatasetOpKernel {
               break;
             }
           }
-          status = SendTensorsByAcl(acl_handle_, data_type, args, is_need_resend);
-        } while (status.ok() && is_need_resend);
+          // A resend requested by SendTensorsByAcl bypasses the byte check, as does a dataset without DT_STRING.
+          bool is_send_allowed = !is_hold_data_trans || is_need_resend;
+          if (!is_send_allowed) {
+            size_t size = 0UL;
+            const aclError acl_status = acltdtQueryChannelSize(acl_handle_, &size);
+            if (acl_status != ACL_SUCCESS) {
+              return errors::InvalidArgument("Failed to get the mbuf size!");
+            }
+            is_send_allowed = (size <= 1UL) || (GetMbufTotalBytes(size) <= kMaxBytes);
+          }
+          is_need_recompute_mbuf = !is_send_allowed;
+          if (is_send_allowed) {
+            status = SendTensorsByAcl(acl_handle_, data_type, args, is_need_resend);
+            is_sent = status.ok() && !is_need_resend;
+          } else {
+            // Yield the CPU before polling the channel size again.
+            sched_yield();
+          }
+        } while ((status.ok() && is_need_resend) || is_need_recompute_mbuf);
+
+        // Record how many bytes this element carries. Skipped for datasets that never
+        // consult the history and for elements that were not actually sent.
+        if (is_hold_data_trans && is_sent) {
+          uint64_t bytes_sum = 0ULL;
+          for (const auto &tensor : args) {
+            bytes_sum += tensor.TotalBytes();
+          }
+          args_bytes[args_bytes_rear] = bytes_sum;
+          args_bytes_rear = (args_bytes_rear + 1) % kStringTypeDepth;
+        }
         return status;
       }
 
@@ -643,6 +683,21 @@ class HostQueueDatasetOp : public DatasetOpKernel {
         return true;
       }
 
+      // Sums the recorded byte sizes of the newest `size` sent elements.
+      // `size` is capped at kStringTypeDepth, the capacity of the history ring,
+      // so no slot is counted twice and the index arithmetic cannot wrap.
+      uint64_t GetMbufTotalBytes(size_t size) {
+        const size_t depth = static_cast<size_t>(kStringTypeDepth);
+        const size_t count = (size < depth) ? size : depth;
+        const size_t rear = static_cast<size_t>(args_bytes_rear);
+        uint64_t sum = 0ULL;
+        for (size_t i = 1UL; i <= count; i++) {
+          // rear < depth and i <= depth, so rear + depth - i never underflows.
+          sum += args_bytes[(rear + depth - i) % depth];
+        }
+        return sum;
+      }
+
       void SendDataByQueueThread(const std::shared_ptr &ctx) {
         ADP_LOG(INFO) << "Begin to send data to the NPU. ";
         rtError_t rt = rtSetDevice(dataset()->device_id_);
@@ -976,6 +1031,13 @@ class HostQueueDatasetOp : public DatasetOpKernel {
       acltdtChannelHandle *acl_handle_;
       uint32_t queue_id_;
       int active_thread_num = 0;
+      // Ring buffer of the byte sizes of the most recently sent elements; zero-initialized
+      // so that slots not yet written contribute nothing to GetMbufTotalBytes.
+      uint64_t args_bytes[kStringTypeDepth] = {0ULL};
+      // Index of the next slot of args_bytes to overwrite.
+      int args_bytes_rear = 0;
+      // Set when any output type is DT_STRING; enables the buffered-bytes check before sending.
+      bool is_hold_data_trans = false;
     };
     const std::vector inputs_;
     std::string channel_name_;
diff --git a/tf_adapter/tests/depends/ascendcl/src/ascendcl_stub.cc b/tf_adapter/tests/depends/ascendcl/src/ascendcl_stub.cc
index 52109405e46bad63ac6784e42f195a288f2ca69d..87862c1d09ac82dd67c0d88dad32495b81d90131 100644
--- a/tf_adapter/tests/depends/ascendcl/src/ascendcl_stub.cc
+++ b/tf_adapter/tests/depends/ascendcl/src/ascendcl_stub.cc
@@ -103,6 +103,14 @@ aclError aclrtResetDevice(int32_t deviceId) {
   return ACL_SUCCESS;
 }
 
+// Stub: reports an empty channel so callers never read an unset size.
+aclError acltdtQueryChannelSize(const acltdtChannelHandle *handle, size_t *size) {
+  if (size != nullptr) {
+    *size = 0UL;
+  }
+  return ACL_SUCCESS;
+}
+
 acltdtChannelHandle *acltdtCreateChannelWithCapacity(uint32_t deviceId,
                                                      const char *name,
                                                      size_t capacity) {