From d9491eef2e4c8b01cdc3edb63d970d816769170a Mon Sep 17 00:00:00 2001 From: gnodli Date: Tue, 5 Nov 2024 13:35:26 +0800 Subject: [PATCH 1/4] add hashTableExport tf plugin --- tf_adapter/kernels/aicore/npu_aicore_ops.cc | 12 ++++- tf_adapter/ops/aicore/npu_aicore_ops.cc | 47 +++++++++++++++++++ .../python/npu_bridge/npu_cpu/npu_cpu_ops.py | 17 +++++++ 3 files changed, 75 insertions(+), 1 deletion(-) diff --git a/tf_adapter/kernels/aicore/npu_aicore_ops.cc b/tf_adapter/kernels/aicore/npu_aicore_ops.cc index 3c6c5adb0..49b126b43 100644 --- a/tf_adapter/kernels/aicore/npu_aicore_ops.cc +++ b/tf_adapter/kernels/aicore/npu_aicore_ops.cc @@ -53,6 +53,16 @@ public: REGISTER_KERNEL_BUILDER(Name("EmbeddingHashTableImport") .Device(tensorflow::DEVICE_CPU), EmbeddingHashTableImportOp); +class EmbeddingHashTableExportOp : public tensorflow::OpKernel { +public: + explicit EmbeddingHashTableExportOp(tensorflow::OpKernelConstruction *context) : OpKernel(context) {} + ~EmbeddingHashTableExportOp() override {} + void Compute(tensorflow::OpKernelContext *context) override {} +}; + +REGISTER_KERNEL_BUILDER(Name("EmbeddingHashTableExport") +.Device(tensorflow::DEVICE_CPU), EmbeddingHashTableExportOp); + class EmbeddingHashTableLookupOrInsertOp : public tensorflow::OpKernel { public: @@ -141,7 +151,7 @@ public: }; REGISTER_KERNEL_BUILDER(Name("InitEmbeddingHashTable").Device(tensorflow::DEVICE_CPU), InitEmbeddingHashTableOp); - + class EmbeddingHashTableApplyAdamWOp : public tensorflow::OpKernel { public: explicit EmbeddingHashTableApplyAdamWOp(tensorflow::OpKernelConstruction *context) diff --git a/tf_adapter/ops/aicore/npu_aicore_ops.cc b/tf_adapter/ops/aicore/npu_aicore_ops.cc index 166956fcb..127ba7b47 100644 --- a/tf_adapter/ops/aicore/npu_aicore_ops.cc +++ b/tf_adapter/ops/aicore/npu_aicore_ops.cc @@ -54,6 +54,53 @@ REGISTER_OP("EmbeddingHashTableImport") .Input("values: float") .SetShapeFn(tensorflow::shape_inference::NoOutputs); +REGISTER_OP("EmbeddingHashTableExport") + 
.Input("table_handles: int64") + .Input("table_sizes: int64") + .Input("embedding_dims: int64") + .Input("bucket_sizes: int64") + .Output("keys: int64") + .Output("counters: uint64") + .Output("filter_flags: uint8") + .Output("values: float") + .Attr("export_mode: string = all") + .Attr("filter_export_flag: bool = false") + .SetIsStateful() + .SetShapeFn([](InferenceContext *c) { + + int32_t real_dim_num = InferenceContext::Rank(c->input(0)); + int32_t begin_norm_axis = 0; + TF_RETURN_IF_ERROR(c->GetAttr("begin_norm_axis", &begin_norm_axis)); + if (begin_norm_axis < 0) { + begin_norm_axis += real_dim_num; + } + if (begin_norm_axis < 0 || begin_norm_axis >= real_dim_num) { + return errors::InvalidArgument("begin_norm_axis is invalid"); + } + ShapeHandle input_shape_handle; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), real_dim_num, &input_shape_handle)); + ShapeHandle out_shape_handle; + for (int32_t i = 0; i < real_dim_num; ++i) { + DimensionHandle tmp_dim_handle = c->Dim(input_shape_handle, i); + if (i >= begin_norm_axis) { + tmp_dim_handle = c->MakeDim(1); + TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape_handle, i, tmp_dim_handle, &out_shape_handle)); + } + } + c->set_output(0, c->input(0)); + c->set_output(1, out_shape_handle); + c->set_output(2, out_shape_handle); + + + auto num_table_shape = c->input(1); + auto output_h_shape = c->MakeShape({input_shape, 0}); + c->set_output(0, output_keys_shape); + c->set_output(1, output_counters_shape); + c->set_output(2, output_filter_flags_shape); + c->set_output(3, output_values_shape); + return Status::OK(); + }); + REGISTER_OP("DynamicGruV2") .Input("x: T") .Input("weight_input: T") diff --git a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py index 301e64913..7b1a50666 100644 --- a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py +++ b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py @@ -612,6 +612,23 @@ def embedding_hash_table_import(table_handles, 
embedding_dims, bucket_sizes, key return result +## 提供host侧hashTable导出功能 +# @param table_handles int64 类型 +# @param table_sizes int64 类型 +# @param embedding_dims int64 类型 +# @param bucket_sizes int64 类型 +# @param export_mode string 类型 +# @param filter_export_flag bool 类型 +def embedding_hash_table_export(table_handles, table_sizes, embedding_dims, bucket_sizes, export_mode='all', + filter_export_flag=False): + """ host embedding feature hash table export. """ + result = gen_npu_cpu_ops.EmbeddingHashTableExport( + table_handles=table_handles, table_sizes=table_sizes, embedding_dims=embedding_dims, bucket_sizes=bucket_sizes, + export_mode=export_mode, filter_export_flag=filter_export_flag) + return result + + + ## EmbeddingHashTableApplyAdamW AdamW 更新功能 # @param table_handle int64 类型 # @param keys int64 类型 -- Gitee From 26aa112bc211e9b47762f3258c72065cdf42308b Mon Sep 17 00:00:00 2001 From: gnodli Date: Tue, 5 Nov 2024 13:47:16 +0800 Subject: [PATCH 2/4] update --- tf_adapter/ops/aicore/npu_aicore_ops.cc | 35 +++---------------------- 1 file changed, 4 insertions(+), 31 deletions(-) diff --git a/tf_adapter/ops/aicore/npu_aicore_ops.cc b/tf_adapter/ops/aicore/npu_aicore_ops.cc index 127ba7b47..f447ef019 100644 --- a/tf_adapter/ops/aicore/npu_aicore_ops.cc +++ b/tf_adapter/ops/aicore/npu_aicore_ops.cc @@ -67,37 +67,10 @@ REGISTER_OP("EmbeddingHashTableExport") .Attr("filter_export_flag: bool = false") .SetIsStateful() .SetShapeFn([](InferenceContext *c) { - - int32_t real_dim_num = InferenceContext::Rank(c->input(0)); - int32_t begin_norm_axis = 0; - TF_RETURN_IF_ERROR(c->GetAttr("begin_norm_axis", &begin_norm_axis)); - if (begin_norm_axis < 0) { - begin_norm_axis += real_dim_num; - } - if (begin_norm_axis < 0 || begin_norm_axis >= real_dim_num) { - return errors::InvalidArgument("begin_norm_axis is invalid"); - } - ShapeHandle input_shape_handle; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), real_dim_num, &input_shape_handle)); - ShapeHandle out_shape_handle; - for 
(int32_t i = 0; i < real_dim_num; ++i) { - DimensionHandle tmp_dim_handle = c->Dim(input_shape_handle, i); - if (i >= begin_norm_axis) { - tmp_dim_handle = c->MakeDim(1); - TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape_handle, i, tmp_dim_handle, &out_shape_handle)); - } - } - c->set_output(0, c->input(0)); - c->set_output(1, out_shape_handle); - c->set_output(2, out_shape_handle); - - - auto num_table_shape = c->input(1); - auto output_h_shape = c->MakeShape({input_shape, 0}); - c->set_output(0, output_keys_shape); - c->set_output(1, output_counters_shape); - c->set_output(2, output_filter_flags_shape); - c->set_output(3, output_values_shape); + c->set_output(0, c->UnknownShape()); + c->set_output(1, c->UnknownShape()); + c->set_output(2, c->UnknownShape()); + c->set_output(3, c->UnknownShape()); return Status::OK(); }); -- Gitee From 744ef17bf41261896ed09b0318ea2c65839b5a8c Mon Sep 17 00:00:00 2001 From: gnodli Date: Tue, 5 Nov 2024 14:06:17 +0800 Subject: [PATCH 3/4] fix codecheck --- tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py index 7b1a50666..caa573909 100644 --- a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py +++ b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py @@ -628,7 +628,6 @@ def embedding_hash_table_export(table_handles, table_sizes, embedding_dims, buck return result - ## EmbeddingHashTableApplyAdamW AdamW 更新功能 # @param table_handle int64 类型 # @param keys int64 类型 -- Gitee From b875aa9a7644b85eb2fea54603ad63ecad2c272e Mon Sep 17 00:00:00 2001 From: gnodli Date: Tue, 5 Nov 2024 14:11:49 +0800 Subject: [PATCH 4/4] fix codecheck --- tf_adapter/ops/aicore/npu_aicore_ops.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tf_adapter/ops/aicore/npu_aicore_ops.cc b/tf_adapter/ops/aicore/npu_aicore_ops.cc index f447ef019..b74a9680e 100644 --- 
a/tf_adapter/ops/aicore/npu_aicore_ops.cc +++ b/tf_adapter/ops/aicore/npu_aicore_ops.cc @@ -63,7 +63,7 @@ REGISTER_OP("EmbeddingHashTableExport") .Output("counters: uint64") .Output("filter_flags: uint8") .Output("values: float") - .Attr("export_mode: string = all") + .Attr("export_mode: string = 'all'") .Attr("filter_export_flag: bool = false") .SetIsStateful() .SetShapeFn([](InferenceContext *c) { -- Gitee