diff --git a/tf_adapter/kernels/aicore/npu_aicore_ops.cc b/tf_adapter/kernels/aicore/npu_aicore_ops.cc
index 3c6c5adb0bb1a9f2e0f6d8fbb739d4d9ccf443d1..49b126b43370ac4a90b1abc1ab7698309a19c51c 100644
--- a/tf_adapter/kernels/aicore/npu_aicore_ops.cc
+++ b/tf_adapter/kernels/aicore/npu_aicore_ops.cc
@@ -53,6 +53,16 @@ public:
 
 REGISTER_KERNEL_BUILDER(Name("EmbeddingHashTableImport")
                             .Device(tensorflow::DEVICE_CPU), EmbeddingHashTableImportOp);
 
+class EmbeddingHashTableExportOp : public tensorflow::OpKernel {
+public:
+  explicit EmbeddingHashTableExportOp(tensorflow::OpKernelConstruction *context) : OpKernel(context) {}
+  ~EmbeddingHashTableExportOp() override {}
+  void Compute(tensorflow::OpKernelContext *context) override {}
+};
+
+REGISTER_KERNEL_BUILDER(Name("EmbeddingHashTableExport")
+                            .Device(tensorflow::DEVICE_CPU), EmbeddingHashTableExportOp);
+
 class EmbeddingHashTableLookupOrInsertOp : public tensorflow::OpKernel {
 public:
@@ -141,7 +151,7 @@ public:
 };
 
 REGISTER_KERNEL_BUILDER(Name("InitEmbeddingHashTable").Device(tensorflow::DEVICE_CPU), InitEmbeddingHashTableOp);
-
+
 class EmbeddingHashTableApplyAdamWOp : public tensorflow::OpKernel {
 public:
   explicit EmbeddingHashTableApplyAdamWOp(tensorflow::OpKernelConstruction *context)
diff --git a/tf_adapter/ops/aicore/npu_aicore_ops.cc b/tf_adapter/ops/aicore/npu_aicore_ops.cc
index 166956fcbfcad85a0e319981250a071c43517d5d..b74a9680e84f4371ab71200a882057d2ef956a4e 100644
--- a/tf_adapter/ops/aicore/npu_aicore_ops.cc
+++ b/tf_adapter/ops/aicore/npu_aicore_ops.cc
@@ -54,6 +54,26 @@ REGISTER_OP("EmbeddingHashTableImport")
     .Input("values: float")
     .SetShapeFn(tensorflow::shape_inference::NoOutputs);
 
+REGISTER_OP("EmbeddingHashTableExport")
+    .Input("table_handles: int64")
+    .Input("table_sizes: int64")
+    .Input("embedding_dims: int64")
+    .Input("bucket_sizes: int64")
+    .Output("keys: int64")
+    .Output("counters: uint64")
+    .Output("filter_flags: uint8")
+    .Output("values: float")
+    .Attr("export_mode: string = 'all'")
+    .Attr("filter_export_flag: bool = false")
+    .SetIsStateful()
+    .SetShapeFn([](InferenceContext *c) {
+      c->set_output(0, c->UnknownShape());
+      c->set_output(1, c->UnknownShape());
+      c->set_output(2, c->UnknownShape());
+      c->set_output(3, c->UnknownShape());
+      return Status::OK();
+    });
+
 REGISTER_OP("DynamicGruV2")
     .Input("x: T")
     .Input("weight_input: T")
diff --git a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py
index 301e6491316d196683935853cc650c47483f66e0..caa5739097780cf634c749512e7881b9cfa968f7 100644
--- a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py
+++ b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py
@@ -612,6 +612,22 @@ def embedding_hash_table_import(table_handles, embedding_dims, bucket_sizes, key
     return result
 
 
+## Provides the host-side hashTable export function
+# @param table_handles int64 type
+# @param table_sizes int64 type
+# @param embedding_dims int64 type
+# @param bucket_sizes int64 type
+# @param export_mode string type
+# @param filter_export_flag bool type
+def embedding_hash_table_export(table_handles, table_sizes, embedding_dims, bucket_sizes, export_mode='all',
+                                filter_export_flag=False):
+    """ host embedding feature hash table export. """
+    result = gen_npu_cpu_ops.EmbeddingHashTableExport(
+        table_handles=table_handles, table_sizes=table_sizes, embedding_dims=embedding_dims, bucket_sizes=bucket_sizes,
+        export_mode=export_mode, filter_export_flag=filter_export_flag)
+    return result
+
+
 ## EmbeddingHashTableApplyAdamW AdamW 更新功能
 # @param table_handle int64 类型
 # @param keys int64 类型