From e27ab6a05bb715086eac8f45bc8073a0b2a8afdb Mon Sep 17 00:00:00 2001 From: gnodli Date: Mon, 25 Nov 2024 16:49:58 +0800 Subject: [PATCH] add embedding_hashtable_export --- tf_adapter/kernels/aicore/npu_aicore_ops.cc | 11 +++++++- tf_adapter/ops/aicore/npu_aicore_ops.cc | 25 +++++++++++++++++++ .../python/npu_bridge/npu_cpu/npu_cpu_ops.py | 16 ++++++++++++ 3 files changed, 51 insertions(+), 1 deletion(-) diff --git a/tf_adapter/kernels/aicore/npu_aicore_ops.cc b/tf_adapter/kernels/aicore/npu_aicore_ops.cc index 3c6c5adb0..35c12ac69 100644 --- a/tf_adapter/kernels/aicore/npu_aicore_ops.cc +++ b/tf_adapter/kernels/aicore/npu_aicore_ops.cc @@ -53,6 +53,15 @@ public: REGISTER_KERNEL_BUILDER(Name("EmbeddingHashTableImport") .Device(tensorflow::DEVICE_CPU), EmbeddingHashTableImportOp); +class EmbeddingHashTableExportOp : public tensorflow::OpKernel { +public: + explicit EmbeddingHashTableExportOp(tensorflow::OpKernelConstruction *context) : OpKernel(context) {} + ~EmbeddingHashTableExportOp() override {} + void Compute(tensorflow::OpKernelContext *context) override {} +}; + +REGISTER_KERNEL_BUILDER(Name("EmbeddingHashTableExport") +.Device(tensorflow::DEVICE_CPU), EmbeddingHashTableExportOp); class EmbeddingHashTableLookupOrInsertOp : public tensorflow::OpKernel { public: @@ -141,7 +150,7 @@ public: }; REGISTER_KERNEL_BUILDER(Name("InitEmbeddingHashTable").Device(tensorflow::DEVICE_CPU), InitEmbeddingHashTableOp); - + class EmbeddingHashTableApplyAdamWOp : public tensorflow::OpKernel { public: explicit EmbeddingHashTableApplyAdamWOp(tensorflow::OpKernelConstruction *context) diff --git a/tf_adapter/ops/aicore/npu_aicore_ops.cc b/tf_adapter/ops/aicore/npu_aicore_ops.cc index 166956fcb..1f68c787e 100644 --- a/tf_adapter/ops/aicore/npu_aicore_ops.cc +++ b/tf_adapter/ops/aicore/npu_aicore_ops.cc @@ -54,6 +54,31 @@ REGISTER_OP("EmbeddingHashTableImport") .Input("values: float") .SetShapeFn(tensorflow::shape_inference::NoOutputs); +REGISTER_OP("EmbeddingHashTableExport") + 
.Input("table_handles: int64") + .Input("table_sizes: int64") + .Input("embedding_dims: int64") + .Input("bucket_sizes: int64") + .Output("keys: num * int64") + .Output("counters: num * uint64") + .Output("filter_flags: num * uint8") + .Output("values: num * float") + .Attr("export_mode: string = 'all'") + .Attr("filter_export_flag: bool = false") + .Attr("num: int >= 1") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext *c) { + int64 num = 0; + TF_RETURN_IF_ERROR(c->GetAttr("num", &num)); + for (int64_t i = 0; i < num; ++i) { + c->set_output(i, c->Vector(c->UnknownDim())); + c->set_output(i + num, c->Vector(c->UnknownDim())); + c->set_output(i + 2 * num, c->Vector(c->UnknownDim())); + c->set_output(i + 3 * num, c->Vector(c->UnknownDim())); + } + return Status::OK(); + }); + REGISTER_OP("DynamicGruV2") .Input("x: T") .Input("weight_input: T") diff --git a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py index 41a509cd2..273c67241 100644 --- a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py +++ b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py @@ -613,6 +613,22 @@ def embedding_hash_table_import(table_handles, embedding_dims, bucket_sizes, key return result +## 提供host侧hashTable导出功能 (NOTE(review): op attr 'num' is not passed to the generated op below — confirm the wrapper does not require it) +# @param table_handles int64 类型 +# @param table_sizes int64 类型 +# @param embedding_dims int64 类型 +# @param bucket_sizes int64 类型 +# @param export_mode string 类型 +# @param filter_export_flag bool 类型 +def embedding_hash_table_export(table_handles, table_sizes, embedding_dims, bucket_sizes, export_mode='all', + filter_export_flag=False): + """ host embedding feature hash table export. 
""" + result = gen_npu_cpu_ops.EmbeddingHashTableExport( + table_handles=table_handles, table_sizes=table_sizes, embedding_dims=embedding_dims, bucket_sizes=bucket_sizes, + export_mode=export_mode, filter_export_flag=filter_export_flag) + return result + + ## EmbeddingHashTableApplyAdamW AdamW 更新功能 # @param table_handle int64 类型 # @param keys int64 类型 -- Gitee