From d832dd3211a0209454c1777b7bdf003994f755ae Mon Sep 17 00:00:00 2001 From: l00520113 Date: Tue, 29 Oct 2024 09:59:53 +0800 Subject: [PATCH] add hash_table_import --- tf_adapter/kernels/aicore/npu_aicore_ops.cc | 11 +++++++++++ tf_adapter/ops/aicore/npu_aicore_ops.cc | 10 ++++++++++ .../python/npu_bridge/npu_cpu/npu_cpu_ops.py | 16 ++++++++++++++++ 3 files changed, 37 insertions(+) diff --git a/tf_adapter/kernels/aicore/npu_aicore_ops.cc b/tf_adapter/kernels/aicore/npu_aicore_ops.cc index c39396d1b..4fe4a347e 100644 --- a/tf_adapter/kernels/aicore/npu_aicore_ops.cc +++ b/tf_adapter/kernels/aicore/npu_aicore_ops.cc @@ -43,6 +43,17 @@ class FastGeluOp : public tensorflow::OpKernel { } }; +class EmbeddingHashTableImportOp : public tensorflow::OpKernel { +public: + explicit EmbeddingHashTableImportOp(tensorflow::OpKernelConstruction *context) : OpKernel(context) {} + ~EmbeddingHashTableImportOp() override {} + void Compute(tensorflow::OpKernelContext *context) override {} +}; + +REGISTER_KERNEL_BUILDER(Name("EmbeddingHashTableImport") +.Device(tensorflow::DEVICE_CPU), EmbeddingHashTableImportOp); + + class EmbeddingHashTableLookupOrInsertOp : public tensorflow::OpKernel { public: explicit EmbeddingHashTableLookupOrInsertOp(tensorflow::OpKernelConstruction *context) : OpKernel(context) {} diff --git a/tf_adapter/ops/aicore/npu_aicore_ops.cc b/tf_adapter/ops/aicore/npu_aicore_ops.cc index 12f5176d9..8374c620c 100644 --- a/tf_adapter/ops/aicore/npu_aicore_ops.cc +++ b/tf_adapter/ops/aicore/npu_aicore_ops.cc @@ -44,6 +44,16 @@ REGISTER_OP("FastGeluGrad") .Attr("T: realnumbertype") .SetShapeFn(tensorflow::shape_inference::MergeBothInputsShapeFn); +REGISTER_OP("EmbeddingHashTableImport") + .Input("table_handles: int64") + .Input("embedding_dims: int64") + .Input("bucket_sizes: int64") + .Input("keys: int64") + .Input("counters: uint64") + .Input("filter_flags: uint8") + .Input("values: float") + .SetShapeFn(tensorflow::shape_inference::NoOutputs); + REGISTER_OP("DynamicGruV2") .Input("x: T") .Input("weight_input: T") diff --git a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py index 805fb7479..09211c56f 100644 --- a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py +++ b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py @@ -593,4 +593,20 @@ def init_embedding_hashtable(table_handle, sampled_values, bucket_size, embeddin result = gen_npu_cpu_ops.InitEmbeddingHashTable( table_handle=table_handle, sampled_values=sampled_values, bucket_size=bucket_size, embedding_dim=embedding_dim, initializer_mode=initializer_mode, constant_value=constant_value) + return result + + +## 提供host侧hashTable导入功能 +# @param table_handles int64 类型 +# @param embedding_dims int64 类型 +# @param bucket_sizes int64 类型 +# @param keys int64 类型 +# @param counters uint64 类型 +# @param filter_flags uint8 类型 +# @param values float 类型 +def embedding_hash_table_import(table_handles, embedding_dims, bucket_sizes, keys, counters, filter_flags, values): + """ host embedding feature hash table import. """ + result = gen_npu_cpu_ops.EmbeddingHashTableImport( + table_handles=table_handles, embedding_dims=embedding_dims, bucket_sizes=bucket_sizes, + keys=keys, counters=counters, filter_flags=filter_flags, values=values) return result \ No newline at end of file -- Gitee