diff --git a/tf_adapter/kernels/aicore/npu_aicore_ops.cc b/tf_adapter/kernels/aicore/npu_aicore_ops.cc index 83a027b6ea2e367f2d36cf86c431216bc6bcc8e5..74ec03259ba2ba957501718bb3637683114adbf5 100644 --- a/tf_adapter/kernels/aicore/npu_aicore_ops.cc +++ b/tf_adapter/kernels/aicore/npu_aicore_ops.cc @@ -43,6 +43,15 @@ class FastGeluOp : public tensorflow::OpKernel { } }; +class EmbeddingHashTableImportOp : public tensorflow::OpKernel { +public: + explicit EmbeddingHashTableImportOp(tensorflow::OpKernelConstruction *context) : OpKernel(context) {} + ~EmbeddingHashTableImportOp() override {} + void Compute(tensorflow::OpKernelContext *context) override {} +}; + +REGISTER_KERNEL_BUILDER(Name("EmbeddingHashTableImport").Device(tensorflow::DEVICE_CPU), EmbeddingHashTableImportOp); + REGISTER_KERNEL_BUILDER( Name("FastGelu") . diff --git a/tf_adapter/ops/aicore/npu_aicore_ops.cc b/tf_adapter/ops/aicore/npu_aicore_ops.cc index fbe6035bdb24c97067cf9517a9dd736a31070c3d..817370440832f1c217809a1aa9fe636203c91e77 100644 --- a/tf_adapter/ops/aicore/npu_aicore_ops.cc +++ b/tf_adapter/ops/aicore/npu_aicore_ops.cc @@ -44,6 +44,16 @@ REGISTER_OP("FastGeluGrad") .Attr("T: realnumbertype") .SetShapeFn(tensorflow::shape_inference::MergeBothInputsShapeFn); +REGISTER_OP("EmbeddingHashTableImport") + .Input("table_handles: int64") + .Input("embedding_dims: int64") + .Input("bucket_sizes: int64") + .Input("keys: int64") + .Input("counters: uint64") + .Input("filter_flags: uint8") + .Input("values: float") + .SetShapeFn(tensorflow::shape_inference::NoOutputs); + REGISTER_OP("DynamicGruV2") .Input("x: T") .Input("weight_input: T") diff --git a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py index a29a93c5e7a14100c15157a04c166594cd01763d..adc196e91dd17f9ce82d9a2624c711048b756090 100644 --- a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py +++ b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py @@ -557,3 +557,19 @@ def embedding_feature_mapping_import(file_path, table_ids, table_sizes, table_na file_path=file_path, table_ids=table_ids, table_sizes=table_sizes, table_names=table_names, global_step=global_step, embedding_dims=embedding_dims, num=num) return result + + +## 提供host侧hashmap导入功能 +# @param table_handles int64 类型 +# @param embedding_dims int64 类型 +# @param bucket_sizes int64 类型 +# @param keys int64 类型 +# @param counters uint64 类型 +# @param filter_flags uint8 类型 +# @param values float 类型 +def embedding_hash_table_import(table_handles, embedding_dims, bucket_sizes, keys, counters, filter_flags, values): + """ host embedding feature hash table import. """ + result = gen_npu_cpu_ops.EmbeddingHashTableImport( + table_handles=table_handles, embedding_dims=embedding_dims, bucket_sizes=bucket_sizes, + keys=keys, counters=counters, filter_flags=filter_flags, values=values) + return result \ No newline at end of file