From 3eabcb5044a5687fe4bb12773e4932988dacf6fd Mon Sep 17 00:00:00 2001
From: "liyefeng803@huawei.com"
Date: Sat, 24 May 2025 15:34:16 +0800
Subject: [PATCH 1/2] fix hashtable ops

Register the EmbeddingHashTableImport / Export / LookupOrInsert /
ApplyAdamW and InitEmbeddingHashTable ops, add CPU kernels with empty
Compute bodies for them, and add the matching Python wrappers.

Review fixes folded into this patch:
- shape functions return the Status of GetAttr instead of dropping it;
- EmbeddingHashTableLookupOrInsert checks that `keys` has rank >= 1
  before reading its first dimension;
- embedding_hash_table_export forwards the `num` attr, which cannot be
  inferred from the op inputs because it only sizes the outputs;
- wrapper comments are in English and match the real parameter names.

---
 tf_adapter/kernels/aicore/npu_aicore_ops.cc   | 50 ++++++++++
 tf_adapter/ops/aicore/npu_aicore_ops.cc       | 87 +++++++++++++++++
 .../python/npu_bridge/npu_cpu/npu_cpu_ops.py  | 97 +++++++++++++++++++
 3 files changed, 234 insertions(+)

diff --git a/tf_adapter/kernels/aicore/npu_aicore_ops.cc b/tf_adapter/kernels/aicore/npu_aicore_ops.cc
index b66740143..cfb8253c8 100644
--- a/tf_adapter/kernels/aicore/npu_aicore_ops.cc
+++ b/tf_adapter/kernels/aicore/npu_aicore_ops.cc
@@ -43,6 +43,36 @@ class FastGeluOp : public tensorflow::OpKernel {
   }
 };
 
+class EmbeddingHashTableImportOp : public tensorflow::OpKernel {
+public:
+  explicit EmbeddingHashTableImportOp(tensorflow::OpKernelConstruction *context) : OpKernel(context) {}
+  ~EmbeddingHashTableImportOp() override = default;
+  void Compute(tensorflow::OpKernelContext *context) override {}  // Empty body: this CPU kernel does no work.
+};
+
+REGISTER_KERNEL_BUILDER(Name("EmbeddingHashTableImport")
+.Device(tensorflow::DEVICE_CPU), EmbeddingHashTableImportOp);
+
+class EmbeddingHashTableExportOp : public tensorflow::OpKernel {
+public:
+  explicit EmbeddingHashTableExportOp(tensorflow::OpKernelConstruction *context) : OpKernel(context) {}
+  ~EmbeddingHashTableExportOp() override = default;
+  void Compute(tensorflow::OpKernelContext *context) override {}  // Empty body: this CPU kernel does no work.
+};
+
+REGISTER_KERNEL_BUILDER(Name("EmbeddingHashTableExport")
+.Device(tensorflow::DEVICE_CPU), EmbeddingHashTableExportOp);
+
+class EmbeddingHashTableLookupOrInsertOp : public tensorflow::OpKernel {
+public:
+  explicit EmbeddingHashTableLookupOrInsertOp(tensorflow::OpKernelConstruction *context) : OpKernel(context) {}
+  ~EmbeddingHashTableLookupOrInsertOp() override = default;
+  void Compute(tensorflow::OpKernelContext *context) override {}  // Empty body: this CPU kernel does no work.
+};
+
+REGISTER_KERNEL_BUILDER(Name("EmbeddingHashTableLookupOrInsert")
+.Device(tensorflow::DEVICE_CPU), EmbeddingHashTableLookupOrInsertOp);
+
 REGISTER_KERNEL_BUILDER(
   Name("FastGelu")
 .
@@ -121,4 +151,24 @@ REGISTER_KERNEL_BUILDER(
 Device(tensorflow::DEVICE_CPU)
 .TypeConstraint("T"),
 FastGeluGradOp);
+
+class InitEmbeddingHashTableOp : public tensorflow::OpKernel {
+public:
+  explicit InitEmbeddingHashTableOp(tensorflow::OpKernelConstruction *context) : OpKernel(context) {}
+  ~InitEmbeddingHashTableOp() override = default;
+  void Compute(tensorflow::OpKernelContext *context) override {}  // Empty body: this CPU kernel does no work.
+};
+
+REGISTER_KERNEL_BUILDER(Name("InitEmbeddingHashTable").Device(tensorflow::DEVICE_CPU), InitEmbeddingHashTableOp);
+
+class EmbeddingHashTableApplyAdamWOp : public tensorflow::OpKernel {
+public:
+  explicit EmbeddingHashTableApplyAdamWOp(tensorflow::OpKernelConstruction *context)
+    : OpKernel(context) {}
+  ~EmbeddingHashTableApplyAdamWOp() override = default;
+  void Compute(tensorflow::OpKernelContext *context) override {}  // Empty body: this CPU kernel does no work.
+};
+
+REGISTER_KERNEL_BUILDER(Name("EmbeddingHashTableApplyAdamW").Device(tensorflow::DEVICE_CPU),
+                        EmbeddingHashTableApplyAdamWOp);
 } // namespace tensorflow
\ No newline at end of file
diff --git a/tf_adapter/ops/aicore/npu_aicore_ops.cc b/tf_adapter/ops/aicore/npu_aicore_ops.cc
index dfaac8e8f..36bd35ebf 100644
--- a/tf_adapter/ops/aicore/npu_aicore_ops.cc
+++ b/tf_adapter/ops/aicore/npu_aicore_ops.cc
@@ -44,6 +44,63 @@ REGISTER_OP("FastGeluGrad")
     .Attr("T: realnumbertype")
     .SetShapeFn(tensorflow::shape_inference::MergeBothInputsShapeFn);
 
+REGISTER_OP("EmbeddingHashTableImport")
+    .Input("table_handles: int64")
+    .Input("embedding_dims: int64")
+    .Input("bucket_sizes: int64")
+    .Input("keys: num * int64")
+    .Input("counters: num * uint64")
+    .Input("filter_flags: num * uint8")
+    .Input("values: num * float32")
+    .Attr("num: int >= 1")
+    .SetShapeFn(tensorflow::shape_inference::NoOutputs);
+
+REGISTER_OP("EmbeddingHashTableExport")
+    .Input("table_handles: int64")
+    .Input("table_sizes: int64")
+    .Input("embedding_dims: int64")
+    .Input("bucket_sizes: int64")
+    .Output("keys: num * int64")
+    .Output("counters: num * uint64")
+    .Output("filter_flags: num * uint8")
+    .Output("values: num * float")
+    .Attr("export_mode: string = 'all'")
+    .Attr("filter_export_flag: bool = false")
+    .Attr("num: int >= 1")
+    .SetIsStateful()
+    .SetShapeFn([](shape_inference::InferenceContext *c) {
+      int64 num = 0;
+      TF_RETURN_IF_ERROR(c->GetAttr("num", &num));  // Surface a bad/missing attr instead of ignoring it.
+      for (int64 i = 0; i < num; ++i) {  // Four output lists of `num` 1-D tensors with unknown length.
+        c->set_output(i, c->Vector(c->UnknownDim()));
+        c->set_output(i + num, c->Vector(c->UnknownDim()));
+        c->set_output(i + 2 * num, c->Vector(c->UnknownDim()));
+        c->set_output(i + 3 * num, c->Vector(c->UnknownDim()));
+      }
+      return Status::OK();
+    });
+
+REGISTER_OP("EmbeddingHashTableApplyAdamW")
+    .Input("table_handle: int64")
+    .Input("keys: int64")
+    .Input("m: Ref(T)")
+    .Input("v: Ref(T)")
+    .Input("beta1_power: Ref(T)")
+    .Input("beta2_power: Ref(T)")
+    .Input("lr: T")
+    .Input("weight_decay: T")
+    .Input("beta1: T")
+    .Input("beta2: T")
+    .Input("epsilon: T")
+    .Input("grad: T")
+    .Input("max_grad_norm: Ref(T)")
+    .Attr("embedding_dim: int")
+    .Attr("bucket_size: int")
+    .Attr("amsgrad: bool = false")
+    .Attr("maximize: bool = false")
+    .Attr("T: {float16, float32}")
+    .SetShapeFn(tensorflow::shape_inference::NoOutputs);
+
 REGISTER_OP("DynamicGruV2")
     .Input("x: T")
     .Input("weight_input: T")
@@ -459,6 +516,27 @@ REGISTER_OP("DynamicRnnGrad")
       return Status::OK();
     });
 
+REGISTER_OP("EmbeddingHashTableLookupOrInsert")
+    .Input("table_handle: int64")
+    .Input("keys:int64")
+    .Output("values: float")
+    .Attr("bucket_size:int")
+    .Attr("embedding_dim:int")
+    .Attr("filter_mode:string='no_filter'")
+    .Attr("filter_freq:int=0")
+    .Attr("default_key_or_value:bool = false")
+    .Attr("default_key: int = 0")
+    .Attr("default_value: float = 0.0")
+    .SetIsStateful()
+    .SetShapeFn([](InferenceContext* c) {
+      int64 embedding_dim = 0;
+      TF_RETURN_IF_ERROR(c->GetAttr("embedding_dim", &embedding_dim));
+      shape_inference::ShapeHandle keys;  // Dim(keys, 0) below is only valid for rank >= 1 (or unknown rank).
+      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &keys));
+      c->set_output(0, c->Matrix(c->Dim(keys, 0), embedding_dim));  // values: [keys.dim0, embedding_dim]
+      return Status::OK();
+    });
+
 REGISTER_OP("LRUCacheV2")
     .Input("index_list: T")
     .Input("data: Ref(dtype)")
@@ -792,5 +870,14 @@ REGISTER_OP("TabulateFusionGrad")
       c->set_output(1, c->input(3));
       return Status::OK();
     });
+
+REGISTER_OP("InitEmbeddingHashTable")
+    .Input("table_handle: int64")
+    .Input("sampled_values: float")
+    .Attr("bucket_size : int")
+    .Attr("embedding_dim : int")
+    .Attr("initializer_mode : string='random'")
+    .Attr("constant_value : float=0.0")
+    .SetShapeFn(shape_inference::NoOutputs);
 } // namespace
 } // namespace tensorflow
diff --git a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py
index ba6ed2b5d..b8c73495e 100644
--- a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py
+++ b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py
@@ -24,6 +24,25 @@ from npu_bridge.helper import helper
 
 gen_npu_cpu_ops = helper.get_gen_ops()
 
+## Device-side FeatureMapping LookupOrInsert
+# @param table_handle int64 type
+# @param keys int64 type
+# @param bucket_size int type
+# @param embedding_dim int type
+# @param filter_mode string type
+# @param filter_freq int type
+# @param default_key_or_value bool type
+# @param default_key int type
+# @param default_value float type
+# @return values float type
+def embedding_hashtable_lookup_or_insert(table_handle, keys, bucket_size, embedding_dim, filter_mode, filter_freq,
+                                         default_key_or_value, default_key, default_value):
+    """ device embedding feature mapping lookup or insert. """
+    result = gen_npu_cpu_ops.EmbeddingHashTableLookupOrInsert(
+        table_handle=table_handle, keys=keys, bucket_size=bucket_size, embedding_dim=embedding_dim,
+        filter_mode=filter_mode, filter_freq=filter_freq, default_key_or_value=default_key_or_value,
+        default_key=default_key, default_value=default_value)
+    return result
 
 ## 提供embeddingrankid功能
 # @param addr_tensor tensorflow的tensor类型,embeddingrankid操作的输入;
@@ -577,3 +596,81 @@ def embedding_hashmap_import_v2(file_path, table_ids, table_sizes, table_names,
         file_path=file_path, table_ids=table_ids, table_sizes=table_sizes,
         table_names=table_names, global_step=global_step, embedding_dims=embedding_dims, num=num)
     return result
+
+## EmbeddingHashTable Init
+# @param table_handle int64 type
+# @param sampled_values float type
+# @param bucket_size int type
+# @param embedding_dim int type
+# @param initializer_mode string type
+# @param constant_value float type
+def init_embedding_hashtable(table_handle, sampled_values, bucket_size, embedding_dim, initializer_mode,
+                             constant_value):
+    """ device init embedding hashtable. """
+    result = gen_npu_cpu_ops.InitEmbeddingHashTable(
+        table_handle=table_handle, sampled_values=sampled_values, bucket_size=bucket_size, embedding_dim=embedding_dim,
+        initializer_mode=initializer_mode, constant_value=constant_value)
+    return result
+
+
+## Host-side hash table import
+# @param table_handles int64 type
+# @param embedding_dims int64 type
+# @param bucket_sizes int64 type
+# @param keys int64 type
+# @param counters uint64 type
+# @param filter_flags uint8 type
+# @param values float type
+def embedding_hash_table_import(table_handles, embedding_dims, bucket_sizes, keys, counters, filter_flags, values):
+    """ host embedding feature hash table import. """
+    result = gen_npu_cpu_ops.EmbeddingHashTableImport(
+        table_handles=table_handles, embedding_dims=embedding_dims, bucket_sizes=bucket_sizes,
+        keys=keys, counters=counters, filter_flags=filter_flags, values=values)
+    return result
+
+
+## Host-side hash table export
+# @param table_handles int64 type
+# @param table_sizes int64 type
+# @param embedding_dims int64 type
+# @param bucket_sizes int64 type
+# @param export_mode string type
+# @param filter_export_flag bool type
+# @param num int type, length of each output list (op attr "num: int >= 1"); presumably the table count - confirm
+def embedding_hash_table_export(table_handles, table_sizes, embedding_dims, bucket_sizes, export_mode='all',
+                                filter_export_flag=False, num=1):
+    """ host embedding feature hash table export. """
+    result = gen_npu_cpu_ops.EmbeddingHashTableExport(
+        table_handles=table_handles, table_sizes=table_sizes, embedding_dims=embedding_dims, bucket_sizes=bucket_sizes,
+        export_mode=export_mode, filter_export_flag=filter_export_flag, num=num)
+    return result
+
+
+## EmbeddingHashTableApplyAdamW: AdamW update
+# @param table_handle int64 type
+# @param keys int64 type
+# @param m float16, float32 type
+# @param v float16, float32 type
+# @param beta1_power float16, float32 type
+# @param beta2_power float16, float32 type
+# @param lr float16, float32 type
+# @param weight_decay float16, float32 type
+# @param beta1 float16, float32 type
+# @param beta2 float16, float32 type
+# @param epsilon float16, float32 type
+# @param grad float16, float32 type
+# @param max_grad_norm float16, float32 type
+# @param embedding_dim int type
+# @param bucket_size int type
+# @param amsgrad bool type
+# @param maximize bool type
+def embedding_hashtable_apply_adam_w(table_handle, keys, m, v, beta1_power, beta2_power, lr, weight_decay,
+                                     beta1, beta2, epsilon, grad, max_grad_norm, embedding_dim,
+                                     bucket_size, amsgrad, maximize):
+    """ device update embedding hashtable using AdamW. """
+    result = gen_npu_cpu_ops.EmbeddingHashTableApplyAdamW(
+        table_handle=table_handle, keys=keys, m=m, v=v, beta1_power=beta1_power, beta2_power=beta2_power,
+        lr=lr, weight_decay=weight_decay, beta1=beta1, beta2=beta2, epsilon=epsilon, grad=grad,
+        max_grad_norm=max_grad_norm, embedding_dim=embedding_dim, bucket_size=bucket_size,
+        amsgrad=amsgrad, maximize=maximize)
+    return result
-- 
Gitee

From 83644ac7178a23419fc681c1b8fbe2c1286e0112 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E4=B8=9A=E4=B8=B0?=
Date: Wed, 28 May 2025 02:58:44 +0000
Subject: [PATCH 2/2] update tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py. fix codecheck
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: 李业丰
---
 tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py
index b8c73495e..f21dc95a8 100644
--- a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py
+++ b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py
@@ -24,6 +24,7 @@ from npu_bridge.helper import helper
 
 gen_npu_cpu_ops = helper.get_gen_ops()
 
+
 ## Device-side FeatureMapping LookupOrInsert
 # @param table_handle int64 type
 # @param keys int64 type
@@ -44,6 +45,7 @@ def embedding_hashtable_lookup_or_insert(table_handle, keys, bucket_size, embedd
         default_key=default_key, default_value=default_value)
     return result
 
+
 ## 提供embeddingrankid功能
 # @param addr_tensor tensorflow的tensor类型,embeddingrankid操作的输入;
 # @param index tensorflow的tensor类型,embeddingrankid操作的输入;
@@ -597,6 +599,7 @@ def embedding_hashmap_import_v2(file_path, table_ids, table_sizes, table_names,
         table_names=table_names, global_step=global_step, embedding_dims=embedding_dims, num=num)
     return result
 
+
 ## EmbeddingHashTable Init
 # @param table_handle int64 type
 # @param sampled_values float type
-- 
Gitee