From 1e2587912e8082a802d265ec640d3a5aa75e2886 Mon Sep 17 00:00:00 2001 From: lovline Date: Wed, 21 Feb 2024 11:01:47 +0800 Subject: [PATCH] support host feature mapping export value --- .../kernels/aicpu/host_feature_mapping.h | 48 ++++ .../aicpu/host_feature_mapping_export.cc | 267 +++++++++++++++++- .../aicpu/host_feature_mapping_import.cc | 88 +++++- tf_adapter/ops/aicpu/npu_cpu_ops.cc | 4 + .../python/npu_bridge/npu_cpu/npu_cpu_ops.py | 12 +- 5 files changed, 401 insertions(+), 18 deletions(-) diff --git a/tf_adapter/kernels/aicpu/host_feature_mapping.h b/tf_adapter/kernels/aicpu/host_feature_mapping.h index 963cf8f09..4e18755dd 100644 --- a/tf_adapter/kernels/aicpu/host_feature_mapping.h +++ b/tf_adapter/kernels/aicpu/host_feature_mapping.h @@ -13,7 +13,10 @@ #ifndef TENSORFLOW_TF_ADAPTER_KERNELS_HOST_FEATURE_MAPPING_OP_H #define TENSORFLOW_TF_ADAPTER_KERNELS_HOST_FEATURE_MAPPING_OP_H +#include #include +#include +#include #include #include @@ -25,6 +28,15 @@ namespace tensorflow { namespace featuremapping { +#define MEM_NOK(mem_ret, logText) \ + do { \ + if (mem_ret != EOK) { \ + ADP_LOG(ERROR) << logText; \ + return false; \ + } \ + } while (0) + + using HashmapType = std::unordered_map>; struct FeatureMappingTable { explicit FeatureMappingTable(int32_t input_buckets_num, int32_t input_threshold) @@ -45,6 +57,42 @@ struct FeatureMappingTable { std::vector feature_mappings_ptr; // buckets_num分桶 }; extern std::unordered_map feature_mapping_table; + +struct MappingExportLineInfo { + uint32_t feature_id_size = sizeof(int64_t); + uint32_t counts_size = sizeof(int32_t); + uint32_t offset_id_size = sizeof(int64_t); + uint32_t pair_size_per_line = 0; + uint32_t value_export_size = 0; + int64_t feature_id = 0; + int64_t counts = 0; + int64_t offset_id = 0; + uint8_t *value = nullptr; +}; + +class ScopeGuard { + public: + explicit ScopeGuard(const std::function exitScope) : exitScope_(exitScope) {} + ~ScopeGuard() { + if (exitScope_ == nullptr) { + return; + } + + try 
{ + exitScope_(); + } catch (...) { + // pass + } + } + + private: + ScopeGuard(const ScopeGuard&) = delete; + ScopeGuard(ScopeGuard&&) = delete; + ScopeGuard& operator=(const ScopeGuard&) = delete; + ScopeGuard& operator=(ScopeGuard&&) = delete; + + std::function exitScope_; +}; } // namespace featuremapping } // namespace tensorflow diff --git a/tf_adapter/kernels/aicpu/host_feature_mapping_export.cc b/tf_adapter/kernels/aicpu/host_feature_mapping_export.cc index d29cc6ceb..fc1547074 100644 --- a/tf_adapter/kernels/aicpu/host_feature_mapping_export.cc +++ b/tf_adapter/kernels/aicpu/host_feature_mapping_export.cc @@ -21,6 +21,8 @@ namespace tensorflow { namespace featuremapping { +const std::string kTxtFileSuffix = ".txt"; +const std::string kMetaFileSuffix = ".meta"; const std::string kBinFileSuffix = ".bin"; class FeatureMappingExportOp : public OpKernel { @@ -28,12 +30,39 @@ class FeatureMappingExportOp : public OpKernel { explicit FeatureMappingExportOp(OpKernelConstruction *ctx) : OpKernel(ctx) { ADP_LOG(DEBUG) << "FeatureMappingExport built"; OP_REQUIRES_OK(ctx, ctx->GetAttr("table_name_list", &table_name_list_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("embedding_dims", &embedding_dims_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("variables", &value_vairables_)); } ~FeatureMappingExportOp() override { ADP_LOG(DEBUG) << "FeatureMappingExport has been destructed"; } - void WriteMappingContens2File(std::string &table_name, std::string &dst_path) const { + bool CopyOneLineData(char* dst_buf, MappingExportLineInfo &mapping_info, uint32_t &offset, uint32_t total_size) const { + ADP_LOG(DEBUG) << "CopyOneLine offset " << offset << " total_size " << total_size; + auto mem_ret = memcpy_s(dst_buf + offset, total_size - offset, &(mapping_info.feature_id), mapping_info.feature_id_size); + MEM_NOK(mem_ret, "host export cpy feature id failed!"); + offset += mapping_info.feature_id_size; + + ADP_LOG(DEBUG) << "CopyOneLine offset " << offset; + mem_ret = memcpy_s(dst_buf + offset, 
total_size - offset, &(mapping_info.counts), mapping_info.counts_size); + MEM_NOK(mem_ret, "host export cpy counts failed!"); + offset += mapping_info.counts_size; + + ADP_LOG(DEBUG) << "CopyOneLine offset " << offset; + mem_ret = memcpy_s(dst_buf + offset, total_size - offset, &(mapping_info.offset_id), mapping_info.offset_id_size); + MEM_NOK(mem_ret, "host export cpy offset id failed!"); + offset += mapping_info.offset_id_size; + + ADP_LOG(DEBUG) << "CopyOneLine offset " << offset; + mem_ret = memcpy_s(dst_buf + offset, total_size - offset, mapping_info.value, + mapping_info.value_export_size); + MEM_NOK(mem_ret, "host export cpy value failed!"); + offset += mapping_info.value_export_size; + ADP_LOG(DEBUG) << "CopyOneLine offset " << offset; + return true; + } + + void WriteMappingContens2FileTxt(OpKernelContext &ctx, std::string &table_name, std::string &dst_path) const { auto it = feature_mapping_table.find(table_name); if (it == feature_mapping_table.end()) { ADP_LOG(WARNING) << "this table " << table_name << " is not in mapping, just skip"; @@ -46,32 +75,202 @@ class FeatureMappingExportOp : public OpKernel { return; } + // write contents to meta file not inluded value try { std::ofstream out_stream(dst_path); + ScopeGuard file_guard([&out_stream]() { out_stream.close(); }); // current use only one bucket refer to host feature mapping op int32_t bucket_index = 0; const auto mapping_map = table->feature_mappings_ptr[bucket_index]; + std::unordered_map>::iterator map_iter; for (map_iter = mapping_map->begin(); map_iter != mapping_map->end(); ++map_iter) { - const int64_t feature_id = map_iter->first; + int64_t feature_id = map_iter->first; std::pair &count_and_offset = map_iter->second; - const int64_t counts = count_and_offset.first; - const int64_t offset_id = count_and_offset.second; - // feature_id: 3 | counts: 1 | offset_id: 7 + int64_t counts = count_and_offset.first; + int64_t offset_id = count_and_offset.second; + // feature_id: 3 | counts: 1 | 
offset_id: 7 | value: [x, x, x] std::string content = "feature_id: " + std::to_string(feature_id) + " | " + "counts: " + std::to_string(counts) + " | " + "offset_id: " + std::to_string(offset_id); + ADP_LOG(DEBUG) << "TXT content: " << content; + out_stream << content << std::endl; + } + } catch (std::exception &e) { + ADP_LOG(ERROR) << "write to file " << dst_path << " failed, err: " << e.what(); + return; + } + } + + void WriteMappingContens2FileMeta(OpKernelContext &ctx, std::string &table_name, std::string &dst_path) const { + auto it = feature_mapping_table.find(table_name); + if (it == feature_mapping_table.end()) { + ADP_LOG(WARNING) << "this table " << table_name << " is not in mapping, just skip"; + return; + } + + FeatureMappingTable *table = it->second; + if (table == nullptr) { + ADP_LOG(ERROR) << "table map find but table is nullptr"; + return; + } + + // write contents to meta file inluded value + try { + std::ofstream out_stream(dst_path); + ScopeGuard file_guard([&out_stream]() { out_stream.close(); }); + // current use only one bucket refer to host feature mapping op + int32_t bucket_index = 0; + const auto mapping_map = table->feature_mappings_ptr[bucket_index]; + + auto iter = std::find(table_name_list_.begin(), table_name_list_.end(), table_name); + if (iter == table_name_list_.end()) { + ADP_LOG(WARNING) << "this table " << table_name << " is not in table_name_list_, error"; + return; + } + size_t index = std::distance(table_name_list_.begin(), iter); + if (index >= embedding_dims_.size()) { + ADP_LOG(ERROR) << "index " << index << " over embedding dims size " << embedding_dims_.size(); + return; + } + uint32_t embedding_dim = embedding_dims_[index]; + if (embedding_dims_.size() != value_vairables_.size()) { + ADP_LOG(ERROR) << "embedding dims size " << embedding_dims_.size() << " value_vairables size " << value_vairables_.size(); + return; + } + ADP_LOG(DEBUG) << "index " << index << " embedding_dim " << embedding_dim; + + // const Tensor 
&value_vairable_tensor = value_vairables_[index]; + const Tensor &value_vairable_tensor = ctx.input(1)[index]; + auto value_varaible_flat = value_vairable_tensor.flat(); + ADP_LOG(DEBUG) << "value_varaible_flat(0) " << value_varaible_flat(0); + size_t value_size = value_varaible_flat.size(); + ADP_LOG(DEBUG) << "value_varaible_flat size " << value_size; + + std::unordered_map>::iterator map_iter; + for (map_iter = mapping_map->begin(); map_iter != mapping_map->end(); ++map_iter) { + int64_t feature_id = map_iter->first; + std::pair &count_and_offset = map_iter->second; + int64_t counts = count_and_offset.first; + int64_t offset_id = count_and_offset.second; + ADP_LOG(DEBUG) << " feature_id " << feature_id << " counts " << counts << " offset_id " << offset_id; + + std::string value_content = "["; + std::ostringstream ss; + float value_tmp = 0.0; + uint32_t offset = 0; + for (offset = 0; offset < embedding_dim - 1; ++offset) { + value_tmp = value_varaible_flat(offset + offset_id * embedding_dim); + ss.str(""); + ss << value_tmp; + value_content += ss.str() + ", "; + ADP_LOG(DEBUG) << "for offset " << offset << " value_content " << value_content; + } + value_tmp = value_varaible_flat(offset + offset_id * embedding_dim); + ss.str(""); + ss << value_tmp; + value_content += ss.str() + "]"; + ADP_LOG(DEBUG) << "for offset " << offset << " value_content " << value_content; + + // feature_id: 3 | counts: 1 | offset_id: 7 | value: [x, x, x] + std::string content = "feature_id: " + std::to_string(feature_id) + " | " + + "counts: " + std::to_string(counts) + " | " + + "offset_id: " + std::to_string(offset_id) + " | " + + "value: " + value_content; ADP_LOG(DEBUG) << "content: " << content; out_stream << content << std::endl; } - out_stream.close(); } catch (std::exception &e) { ADP_LOG(ERROR) << "write to file " << dst_path << " failed, err: " << e.what(); return; } } - void SaveFeatureMapping2File(const std::string &path) const { + void WriteMappingContens2File(OpKernelContext 
&ctx, std::string &table_name, std::string &dst_path) const { + auto it = feature_mapping_table.find(table_name); + if (it == feature_mapping_table.end()) { + ADP_LOG(WARNING) << "this table " << table_name << " is not in mapping, just skip"; + return; + } + + FeatureMappingTable *table = it->second; + if (table == nullptr) { + ADP_LOG(ERROR) << "table map find but table is nullptr"; + return; + } + + try { + std::ofstream out_stream(dst_path); + ScopeGuard file_guard([&out_stream]() { out_stream.close(); }); + // current use only one bucket refer to host feature mapping op + int32_t bucket_index = 0; + const auto mapping_map = table->feature_mappings_ptr[bucket_index]; + + auto iter = std::find(table_name_list_.begin(), table_name_list_.end(), table_name); + if (iter == table_name_list_.end()) { + ADP_LOG(WARNING) << "this table " << table_name << " is not in table_name_list_, error"; + return; + } + size_t index = std::distance(table_name_list_.begin(), iter); + if (index >= embedding_dims_.size()) { + ADP_LOG(ERROR) << "index " << index << " over embedding dims size " << embedding_dims_.size(); + return; + } + int embedding_dim = embedding_dims_[index]; + ADP_LOG(DEBUG) << "index " << index << " embedding_dim " << embedding_dim; + if (embedding_dims_.size() != value_vairables_.size()) { + ADP_LOG(ERROR) << "embedding dims size " << embedding_dims_.size() << " value_vairables size " << value_vairables_.size(); + return; + } + + MappingExportLineInfo mapping_info; + mapping_info.value_export_size = embedding_dim * sizeof(float); + mapping_info.pair_size_per_line = mapping_info.feature_id_size + mapping_info.counts_size + + mapping_info.offset_id_size + mapping_info.value_export_size; + uint32_t chunk_size = mapping_info.pair_size_per_line * mapping_map->size(); + ADP_LOG(DEBUG) << "table_name: " << table_name; + ADP_LOG(DEBUG) << "pair_size_per_line: " << mapping_info.pair_size_per_line; + ADP_LOG(DEBUG) << "value_export_size: " << mapping_info.value_export_size; 
+ ADP_LOG(DEBUG) << "chunk_size: " << chunk_size; + + char *chunk_read_buf = new (std::nothrow) char[chunk_size]; + if (chunk_read_buf == nullptr) { + ADP_LOG(ERROR) << "host export new char failed "; + return; + } + + // const Tensor &value_vairable_tensor = value_vairables_[index]; + const Tensor &value_vairable_tensor = ctx.input(1)[index]; + auto value_varaible = (float *)(value_vairable_tensor.tensor_data().data()); + ADP_LOG(DEBUG) << "value_variable_addr " << value_varaible; + + auto value_varaible_flat = value_vairable_tensor.flat(); + ADP_LOG(DEBUG) << "value_varaible_flat(0) " << value_varaible_flat(0); + size_t value_size = value_varaible_flat.size(); + ADP_LOG(DEBUG) << "value_varaible_flat size " << value_size; + + uint32_t offset = 0; + std::unordered_map>::iterator map_iter; + for (map_iter = mapping_map->begin(); map_iter != mapping_map->end(); ++map_iter) { + mapping_info.feature_id = map_iter->first; + std::pair &count_and_offset = map_iter->second; + mapping_info.counts = count_and_offset.first; + mapping_info.offset_id = count_and_offset.second; + mapping_info.value = reinterpret_cast(value_varaible + embedding_dim * mapping_info.offset_id); + if (!CopyOneLineData(chunk_read_buf, mapping_info, offset, chunk_size)) { + ADP_LOG(ERROR) << "host export copy one line data failed "; + break; + } + } + out_stream.write(chunk_read_buf, chunk_size); + delete[] chunk_read_buf; + } catch (std::exception &e) { + ADP_LOG(ERROR) << "write to file " << dst_path << " failed, err: " << e.what(); + return; + } + } + + void SaveFeatureMapping2File(OpKernelContext &ctx, const std::string &path) const { const size_t path_length = path.size(); std::string dst_path_way = path; if (path[path_length - 1] != '/') { @@ -91,23 +290,64 @@ class FeatureMappingExportOp : public OpKernel { const size_t name_size = table_name_list_.size(); ADP_LOG(DEBUG) << "dst_path_way " << dst_path_way << " name_size " << name_size; + + ADP_LOG(DEBUG) << "111 export TXT file"; + if (name_size 
== 0) { + ADP_LOG(DEBUG) << "default export all feature mapping"; + for (const auto &map_pair : feature_mapping_table) { + std::string table_name = map_pair.first; + std::string dst_path_file = dst_path_way + table_name + kTxtFileSuffix; + ADP_LOG(DEBUG) << "TXT table_name " << table_name << " dst_path_file " << dst_path_file; + WriteMappingContens2FileTxt(ctx, table_name, dst_path_file); + } + } else { + ADP_LOG(DEBUG) << "export attr name of user specified"; + for (size_t index = 0; index < name_size; ++index) { + std::string attr_table_name = std::string(table_name_list_[index]); + std::string dst_file_path = dst_path_way + attr_table_name + kTxtFileSuffix; + ADP_LOG(DEBUG) << "TXT attr_table_name " << attr_table_name << " dst_file_path " << dst_file_path; + WriteMappingContens2FileTxt(ctx, attr_table_name, dst_file_path); + } + } + + ADP_LOG(DEBUG) << "222 export META file"; + if (name_size == 0) { + ADP_LOG(DEBUG) << "default export all feature mapping"; + for (const auto &map_pair : feature_mapping_table) { + std::string table_name = map_pair.first; + std::string dst_path_file = dst_path_way + table_name + kMetaFileSuffix; + ADP_LOG(DEBUG) << "META table_name " << table_name << " dst_path_file " << dst_path_file; + WriteMappingContens2FileMeta(ctx, table_name, dst_path_file); + } + } else { + ADP_LOG(DEBUG) << "export attr name of user specified"; + for (size_t index = 0; index < name_size; ++index) { + std::string attr_table_name = std::string(table_name_list_[index]); + std::string dst_file_path = dst_path_way + attr_table_name + kMetaFileSuffix; + ADP_LOG(DEBUG) << "META attr_table_name " << attr_table_name << " dst_file_path " << dst_file_path; + WriteMappingContens2FileMeta(ctx, attr_table_name, dst_file_path); + } + } + + ADP_LOG(DEBUG) << "333 export BIN file"; if (name_size == 0) { ADP_LOG(DEBUG) << "default export all feature mapping"; for (const auto &map_pair : feature_mapping_table) { std::string table_name = map_pair.first; std::string 
dst_path_file = dst_path_way + table_name + kBinFileSuffix; - ADP_LOG(DEBUG) << "table_name " << table_name << " dst_path_file " << dst_path_file; - WriteMappingContens2File(table_name, dst_path_file); + ADP_LOG(DEBUG) << "BIN table_name " << table_name << " dst_path_file " << dst_path_file; + WriteMappingContens2File(ctx, table_name, dst_path_file); } } else { ADP_LOG(DEBUG) << "export attr name of user specified"; for (size_t index = 0; index < name_size; ++index) { std::string attr_table_name = std::string(table_name_list_[index]); std::string dst_file_path = dst_path_way + attr_table_name + kBinFileSuffix; - ADP_LOG(DEBUG) << "attr_table_name " << attr_table_name << " dst_file_path " << dst_file_path; - WriteMappingContens2File(attr_table_name, dst_file_path); + ADP_LOG(DEBUG) << "BIN attr_table_name " << attr_table_name << " dst_file_path " << dst_file_path; + WriteMappingContens2File(ctx, attr_table_name, dst_file_path); } } + return; } @@ -124,12 +364,15 @@ class FeatureMappingExportOp : public OpKernel { errors::InvalidArgument("path should be a valid string.")); Tensor *output_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, save_path_tensor.shape(), &output_tensor)); - SaveFeatureMapping2File(std::string(save_path)); + + SaveFeatureMapping2File(*ctx, std::string(save_path)); ADP_LOG(INFO) << "Host FeatureMappingExport compute end"; } private: std::vector table_name_list_{}; + std::vector embedding_dims_{}; + std::vector value_vairables_{}; }; REGISTER_KERNEL_BUILDER(Name("FeatureMappingExport").Device(DEVICE_CPU), FeatureMappingExportOp); diff --git a/tf_adapter/kernels/aicpu/host_feature_mapping_import.cc b/tf_adapter/kernels/aicpu/host_feature_mapping_import.cc index ed3a2d2c6..b5947c30b 100644 --- a/tf_adapter/kernels/aicpu/host_feature_mapping_import.cc +++ b/tf_adapter/kernels/aicpu/host_feature_mapping_import.cc @@ -28,6 +28,8 @@ class FeatureMappingImportOp : public OpKernel { public: explicit 
FeatureMappingImportOp(OpKernelConstruction *ctx) : OpKernel(ctx) { ADP_LOG(DEBUG) << "Host FeatureMappingImport built"; + OP_REQUIRES_OK(ctx, ctx->GetAttr("table_name_list", &table_name_list_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("embedding_dims", &embedding_dims_)); } ~FeatureMappingImportOp() override { ADP_LOG(DEBUG) << "Host FeatureMappingImport has been destructed"; @@ -103,8 +105,85 @@ class FeatureMappingImportOp : public OpKernel { return; } - void FindTableDoImport(std::string &dst_path_way, std::string &file_name) const { + void FindTableDoImportBin(std::string &dst_path_way, std::string &file_name) const { + if (file_name.back() == 't' || file_name.back() == 'a') { + ADP_LOG(INFO) << "BIN skip txt and meta file " << file_name; + return; + } + + std::string table_name = ""; + size_t pos_period = file_name.find_last_of("."); + if (pos_period != std::string::npos) { + table_name = file_name.substr(0, pos_period); + } else { + ADP_LOG(ERROR) << "parse file " << file_name << " error"; + return; + } + ADP_LOG(DEBUG) << "BIN parse bin file file_name " << file_name << " table_name " << table_name; + + auto it = feature_mapping_table.find(table_name); + if (it == feature_mapping_table.end()) { + ADP_LOG(WARNING) << "this table " << table_name << " is not in mapping, just skip"; + return; + } + + FeatureMappingTable *table = it->second; + if (table == nullptr) { + ADP_LOG(ERROR) << "table map find but table is nullptr"; + return; + } + + auto iter = std::find(table_name_list_.begin(), table_name_list_.end(), table_name); + if (iter == table_name_list_.end()) { + ADP_LOG(WARNING) << "this table " << table_name << " is not in table_name_list_, error"; + return; + } + size_t index = std::distance(table_name_list_.begin(), iter); + if (index >= embedding_dims_.size()) { + ADP_LOG(ERROR) << "index " << index << " over embedding dims size " << embedding_dims_.size(); + return; + } + int embedding_dim = embedding_dims_[index]; + ADP_LOG(DEBUG) << "index " << index << " 
BIN embedding_dim " << embedding_dim; + std::string src_file_name = dst_path_way + file_name; + std::ifstream weight_stream(src_file_name, std::ios::binary | std::ios::ate); + if (!weight_stream.is_open()) { + ADP_LOG(ERROR) << "open file failed " << src_file_name; + return; + } + ScopeGuard file_guard([&weight_stream]() { weight_stream.close(); }); + weight_stream.seekg(0, std::ios::end); + uint64_t file_total_size = weight_stream.tellg(); + weight_stream.seekg(0, std::ios::beg); + + char *chunk_read_buf = new (std::nothrow) char[file_total_size]; + if (chunk_read_buf == nullptr) { + ADP_LOG(ERROR) << "host import new char failed "; + return; + } + + weight_stream.read(chunk_read_buf, file_total_size); + for (uint64_t offset = 0; offset < file_total_size;) { + int64_t feature_id = *((int64_t *)(chunk_read_buf + offset)); + offset += sizeof(int64_t); + int64_t counts = *((int32_t *)(chunk_read_buf + offset)); + offset += sizeof(int32_t); + int64_t offset_id = *((int64_t *)(chunk_read_buf + offset)); + offset += sizeof(int64_t); + offset += embedding_dim * sizeof(float); + ADP_LOG(DEBUG) << "import BIN offset " << offset << " feature_id " << feature_id << " counts " \ + << counts << " offset_id " << offset_id; + } + delete[] chunk_read_buf; + ADP_LOG(DEBUG) << "Bin file " << src_file_name << " read succ."; + } + + void FindTableDoImportTxt(std::string &dst_path_way, std::string &file_name) const { + std::string src_file_name = dst_path_way + file_name; + if (file_name.back() == 'n' || file_name.back() == 'a') { + ADP_LOG(INFO) << "skip bin and meta file " << file_name; + return; + } try { std::ifstream in_stream(src_file_name); if (!in_stream.is_open()) { @@ -155,7 +234,8 @@ class FeatureMappingImportOp : public OpKernel { continue; } ADP_LOG(DEBUG) << "file_name: " << ent->d_name; - FindTableDoImport(dst_path_way, file_name); + FindTableDoImportTxt(dst_path_way, file_name); + FindTableDoImportBin(dst_path_way, file_name); } closedir(dir); } else { @@ -179,6 +259,10 @@ class 
FeatureMappingImportOp : public OpKernel { TraverseAndParse(std::string(restore_path)); ADP_LOG(INFO) << "Host FeatureMappingImport compute end"; } + + private: + std::vector table_name_list_{}; + std::vector embedding_dims_{}; }; REGISTER_KERNEL_BUILDER(Name("FeatureMappingImport").Device(DEVICE_CPU), FeatureMappingImportOp); diff --git a/tf_adapter/ops/aicpu/npu_cpu_ops.cc b/tf_adapter/ops/aicpu/npu_cpu_ops.cc index 489805ba8..e91d0ae34 100644 --- a/tf_adapter/ops/aicpu/npu_cpu_ops.cc +++ b/tf_adapter/ops/aicpu/npu_cpu_ops.cc @@ -515,6 +515,8 @@ REGISTER_OP("HostFeatureMapping") REGISTER_OP("FeatureMappingExport") .Input("path: string") .Attr("table_name_list: list(string)") + .Attr("embedding_dims: list(int)") + .Input("variable: list(float)") .Output("export_fake_output: string") .SetShapeFn([](shape_inference::InferenceContext *c) { c->set_output(0, c->input(0)); @@ -523,6 +525,8 @@ REGISTER_OP("FeatureMappingExport") REGISTER_OP("FeatureMappingImport") .Input("path: string") + .Attr("table_name_list: list(string)") + .Attr("embedding_dims: list(int)") .Output("import_fake_output: string") .SetShapeFn([](shape_inference::InferenceContext *c) { c->set_output(0, c->input(0)); diff --git a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py index 480582121..ada1fe01d 100644 --- a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py +++ b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py @@ -346,10 +346,13 @@ def device_feature_mapping(feature_id): ## 提供host侧FeatureMapping Import功能 # @param path string 类型 # @param table_name string 类型 +# @param embedding_dims int 类型 +# @param variable float 类型 # @return fake int32 类型 -def host_feature_mapping_export(path, table_name_list): +def host_feature_mapping_export(path, table_name_list, embedding_dims, variable): """ host feature mapping export. 
""" - result = gen_npu_cpu_ops.FeatureMappingExport(path=path, table_name_list=table_name_list) + result = gen_npu_cpu_ops.FeatureMappingExport(path=path, table_name_list=table_name_list, + embedding_dims=embedding_dims, variable=variable) return result @@ -357,7 +360,8 @@ def host_feature_mapping_export(path, table_name_list): # @param path string 类型 # @param table_name string 类型 # @return fake int32 类型 -def host_feature_mapping_import(path): +def host_feature_mapping_import(path, table_name_list, embedding_dims): """ host feature mapping export. """ - result = gen_npu_cpu_ops.FeatureMappingImport(path=path) + result = gen_npu_cpu_ops.FeatureMappingImport(path=path, table_name_list=table_name_list, + embedding_dims=embedding_dims) return result -- Gitee