diff --git a/tf_adapter_2.x/npu_device/core/npu_device.cpp b/tf_adapter_2.x/npu_device/core/npu_device.cpp index be0a990e9c578e4b8d6cb2ffcf3318694260bef8..64df7d1d745b4767d7efa253326761d0ed47372d 100644 --- a/tf_adapter_2.x/npu_device/core/npu_device.cpp +++ b/tf_adapter_2.x/npu_device/core/npu_device.cpp @@ -30,6 +30,7 @@ #include "npu_managed_buffer.h" #include "npu_tensor.h" #include "npu_utils.h" +#include "acl/acl_rt.h" #include "optimizers/npu_optimizer_manager.h" @@ -936,7 +937,20 @@ void NpuDevice::TransTfInputs2GeInputs(int num_inputs, TFE_TensorHandle **inputs dims.emplace_back(dim_size); } input.SetTensorDesc(ge::TensorDesc(ge::Shape(dims), ge::FORMAT_ND, ge_type)); - input.SetData(static_cast(tensor->data()), tensor->TotalBytes(), [](uint8_t *) {}); + void *host_memory; + auto ret = aclrtMallocHost(&host_memory, tensor->TotalBytes()); + if (ret != ACL_SUCCESS) { + status.status = tensorflow::errors::Internal("Call aclrtMallocHost failed"); + return; + } + ret = aclrtMemcpy(host_memory, tensor->TotalBytes(), tensor->data(), tensor->TotalBytes(), ACL_MEMCPY_HOST_TO_HOST); + if (ret != ACL_SUCCESS) { + status.status = tensorflow::errors::Internal("Call aclrtMemcpy failed"); + return; + } + input.SetData(static_cast(host_memory), tensor->TotalBytes(), [](uint8_t *ptr) {aclrtFreeHost(ptr);}); + // input.SetData(static_cast(tensor->data()), tensor->TotalBytes(), [](uint8_t *) {}); + // input.SetData(static_cast(tensor->data()), tensor->TotalBytes()); ge_inputs.emplace_back(input); DLOG() << " input " << i << " ge enum " << ge_type << " tf type " << tensorflow::DataTypeString(tensor->dtype()) << VecToString(dims); diff --git a/tf_adapter_2.x/tests/stub/include/ge/ge_api.h b/tf_adapter_2.x/tests/stub/include/ge/ge_api.h index 23a37692d5fe645b201a6107fdfed88c4f495381..ae7c5275d1afe24c9e4dd66f9bb12d4e94cad82b 100644 --- a/tf_adapter_2.x/tests/stub/include/ge/ge_api.h +++ b/tf_adapter_2.x/tests/stub/include/ge/ge_api.h @@ -87,7 +87,8 @@ class Tensor { std::unique_ptr ResetData() { return std::move(data_); } graphStatus SetData(const uint8_t *data, size_t size) { const static DeleteFunc deleter = [](uint8_t *p) { delete[] p; }; - data_ = std::unique_ptr(new uint8_t[size], deleter); + REQUIRES_ACL_STATUS_OK(aclrtMallocHost(data_, size), aclrtMallocHost); + // data_ = std::unique_ptr(new uint8_t[size], deleter); std::memcpy(data_.get(), data, size); size_ = size; return GRAPH_SUCCESS;