diff --git a/tf_adapter/kernels/aicpu/npu_cpu_ops.cc b/tf_adapter/kernels/aicpu/npu_cpu_ops.cc
index 9bb7a953c708abc9f3f8b55261e4a41ae7971c23..e5b29206d1914c9d09e7ecce868f3ebf54917045 100644
--- a/tf_adapter/kernels/aicpu/npu_cpu_ops.cc
+++ b/tf_adapter/kernels/aicpu/npu_cpu_ops.cc
@@ -410,6 +410,57 @@ public:
   void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "EmbeddingFeatureMappingInsertOp Compute"; }
 };
 
+// CPU kernels for the embedding hashmap V2 ops. Every Compute() below has an empty body;
+// the classes only give each op a kernel that can be registered for DEVICE_CPU.
+class InitEmbeddingHashmapV2Op : public OpKernel {
+public:
+  explicit InitEmbeddingHashmapV2Op(OpKernelConstruction *context) : OpKernel(context) {}
+  ~InitEmbeddingHashmapV2Op() override {}
+  void Compute(OpKernelContext *context) override {}
+};
+
+class DeinitEmbeddingHashmapV2Op : public OpKernel {
+public:
+  explicit DeinitEmbeddingHashmapV2Op(OpKernelConstruction *context) : OpKernel(context) {}
+  ~DeinitEmbeddingHashmapV2Op() override {}
+  void Compute(OpKernelContext *context) override {}
+};
+
+class TableToResourceV2Op : public OpKernel {
+public:
+  explicit TableToResourceV2Op(OpKernelConstruction *context) : OpKernel(context) {}
+  ~TableToResourceV2Op() override {}
+  void Compute(OpKernelContext *context) override {}
+};
+
+class EmbeddingHashmapExportOp : public OpKernel {
+public:
+  explicit EmbeddingHashmapExportOp(OpKernelConstruction *context) : OpKernel(context) {}
+  ~EmbeddingHashmapExportOp() override {}
+  void Compute(OpKernelContext *context) override {}
+};
+
+class EmbeddingHashmapSizeOp : public OpKernel {
+public:
+  explicit EmbeddingHashmapSizeOp(OpKernelConstruction *context) : OpKernel(context) {}
+  ~EmbeddingHashmapSizeOp() override {}
+  void Compute(OpKernelContext *context) override {}
+};
+
+class EmbeddingHashmapFileSizeOp : public OpKernel {
+public:
+  explicit EmbeddingHashmapFileSizeOp(OpKernelConstruction *context) : OpKernel(context) {}
+  ~EmbeddingHashmapFileSizeOp() override {}
+  void Compute(OpKernelContext *context) override {}
+};
+
+class EmbeddingHashmapImportOp : public OpKernel {
+public:
+  explicit EmbeddingHashmapImportOp(OpKernelConstruction *context) : OpKernel(context) {}
+  ~EmbeddingHashmapImportOp() override {}
+  void Compute(OpKernelContext *context) override {}
+};
+
 REGISTER_KERNEL_BUILDER(Name("ScatterElementsV2").Device(DEVICE_CPU), ScatterElementsV2Op);
 REGISTER_KERNEL_BUILDER(Name("EmbeddingRankId").Device(DEVICE_CPU), EmbeddingRankIdOpKernel);
 REGISTER_KERNEL_BUILDER(Name("EmbeddingLocalIndex").Device(DEVICE_CPU), EmbeddingLocalIndexOpKernel);
@@ -458,6 +509,13 @@ REGISTER_KERNEL_BUILDER(Name("EmbeddingFeatureMappingExport").Device(DEVICE_CPU)
 REGISTER_KERNEL_BUILDER(Name("EmbeddingFeatureMappingFileSize").Device(DEVICE_CPU), EmbeddingFeatureMappingFileSizeOp);
 REGISTER_KERNEL_BUILDER(Name("EmbeddingFeatureMappingImport").Device(DEVICE_CPU), EmbeddingFeatureMappingImportOp);
 REGISTER_KERNEL_BUILDER(Name("EmbeddingFeatureMappingInsert").Device(DEVICE_CPU), EmbeddingFeatureMappingInsertOp);
+REGISTER_KERNEL_BUILDER(Name("InitEmbeddingHashmapV2").Device(DEVICE_CPU), InitEmbeddingHashmapV2Op);
+REGISTER_KERNEL_BUILDER(Name("DeinitEmbeddingHashmapV2").Device(DEVICE_CPU), DeinitEmbeddingHashmapV2Op);
+REGISTER_KERNEL_BUILDER(Name("TableToResourceV2").Device(DEVICE_CPU), TableToResourceV2Op);
+REGISTER_KERNEL_BUILDER(Name("EmbeddingHashmapExport").Device(DEVICE_CPU), EmbeddingHashmapExportOp);
+REGISTER_KERNEL_BUILDER(Name("EmbeddingHashmapSize").Device(DEVICE_CPU), EmbeddingHashmapSizeOp);
+REGISTER_KERNEL_BUILDER(Name("EmbeddingHashmapFileSize").Device(DEVICE_CPU), EmbeddingHashmapFileSizeOp);
+REGISTER_KERNEL_BUILDER(Name("EmbeddingHashmapImport").Device(DEVICE_CPU), EmbeddingHashmapImportOp);
 
 class DecodeImageV3Op : public OpKernel {
 public:
diff --git a/tf_adapter/ops/aicpu/npu_cpu_ops.cc b/tf_adapter/ops/aicpu/npu_cpu_ops.cc
index bc017af0a06292670f2f5e4b768cc827a0340231..6fb60d2427f4335c211dd98fd430ed2351327b87 100644
--- a/tf_adapter/ops/aicpu/npu_cpu_ops.cc
+++ b/tf_adapter/ops/aicpu/npu_cpu_ops.cc
@@ -587,6 +587,93 @@ REGISTER_OP("EmbeddingFeatureMappingInsert")
     .Input("offset_id: int32")
     .SetShapeFn(shape_inference::NoOutputs);
 
+REGISTER_OP("InitEmbeddingHashmapV2")
+    .Input("table_id: int32")
+    .Output("table_handle: int64")
+    .Attr("bucket_size: int")
+    .Attr("load_factor: int")
+    .Attr("embedding_dim: int")
+    .Attr("dtype: type = DT_FLOAT")
+    .SetShapeFn([](shape_inference::InferenceContext *c) {
+      c->set_output(0, c->Scalar());
+      return Status::OK();
+    });
+
+REGISTER_OP("DeinitEmbeddingHashmapV2")
+    .Input("table_id: int32")
+    .SetShapeFn(shape_inference::NoOutputs);
+
+REGISTER_OP("TableToResourceV2")
+    .Input("table_id: int32")
+    .Output("table_handle: int64")
+    .SetShapeFn([](shape_inference::InferenceContext *c) {
+      c->set_output(0, c->Scalar());
+      return Status::OK();
+    });
+
+REGISTER_OP("EmbeddingHashmapExport")
+    .Input("file_path: string")
+    .Input("table_ids: int32")
+    .Input("table_names: string")
+    .Input("global_step: TStep")
+    .Input("keys: num * int64")
+    .Input("counts: num * int64")
+    .Input("filter_flags: num * uint8")
+    .Input("values: num * float32")
+    .Attr("num: int >= 1")
+    .Attr("TStep: {int32, int64}")
+    .SetShapeFn(shape_inference::NoOutputs);
+
+REGISTER_OP("EmbeddingHashmapSize")
+    .Input("table_ids: int32")
+    .Output("table_sizes: int64")
+    .Attr("filter_export_flag: bool = false")
+    .Attr("export_mode: {'all', 'old', 'new', 'specifiednew'} = 'all'")
+    .SetShapeFn([](shape_inference::InferenceContext *c) {
+      c->set_output(0, c->input(0));
+      return Status::OK();
+    });
+
+REGISTER_OP("EmbeddingHashmapFileSize")
+    .Input("file_path: string")
+    .Input("table_ids: int32")
+    .Input("table_names: string")
+    .Input("global_step: TStep")
+    .Output("table_sizes: int64")
+    .Attr("embedding_dims: list(int)")
+    .Attr("TStep: {int32, int64}")
+    .SetShapeFn([](shape_inference::InferenceContext *c) {
+      c->set_output(0, c->input(1));
+      return Status::OK();
+    });
+
+REGISTER_OP("EmbeddingHashmapImport")
+    .Input("file_path: string")
+    .Input("table_ids: int32")
+    .Input("table_sizes: int64")
+    .Input("table_names: string")
+    .Input("global_step: TStep")
+    .Output("keys: num * int64")
+    .Output("counts: num * int64")
+    .Output("filter_flags: num * uint8")
+    .Output("values: num * float32")
+    .Attr("embedding_dims: list(int)")
+    .Attr("num: int >= 1")
+    .Attr("TStep: {int32, int64}")
+    .SetShapeFn([](shape_inference::InferenceContext *c) {
+      int64 num = 0;
+      TF_RETURN_IF_ERROR(c->GetAttr("num", &num));
+      // Outputs are grouped as num keys, num counts, num filter_flags, num values;
+      // every one of them is reported as a 1-D tensor of unknown length.
+      for (int64_t i = 0; i < num; ++i) {
+        c->set_output(i, c->Vector(c->UnknownDim()));
+        c->set_output(i + num, c->Vector(c->UnknownDim()));
+        c->set_output(i + 2 * num, c->Vector(c->UnknownDim()));
+        c->set_output(i + 3 * num, c->Vector(c->UnknownDim()));
+      }
+      return Status::OK();
+    });
+
 REGISTER_OP("HostFeatureMapping")
     .Input("feature_id: int64")
     .Output("offset_id: int64")
diff --git a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py
index 9a02da6ab1b28156155b70985769954f08fc0f0b..4ac7ae1ab037eaf9a2eb88e253bfa6fc70231d7d 100644
--- a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py
+++ b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py
@@ -459,3 +459,98 @@ def host_feature_mapping_import(path):
     """ host feature mapping export. """
     result = gen_npu_cpu_ops.FeatureMappingImport(path=path)
     return result
+
+
+## Initialize an embedding hashmap table on the device side.
+# @param table_id int32
+# @param bucket_size int64
+# @param load_factor int64
+# @param embedding_dim int64
+# @param dtype type
+# @return table_handle int64
+def init_embedding_hashmap_v2(table_id, bucket_size, load_factor, embedding_dim, dtype):
+    """ device init embedding hashmap v2. """
+    result = gen_npu_cpu_ops.InitEmbeddingHashmapV2(
+        table_id=table_id, bucket_size=bucket_size,
+        load_factor=load_factor, embedding_dim=embedding_dim, dtype=dtype)
+    return result
+
+
+## De-initialize an embedding hashmap table on the device side.
+# @param table_id int32
+def deinit_embedding_hashmap_v2(table_id):
+    """ device deinit embedding hashmap v2. """
+    gen_npu_cpu_ops.DeinitEmbeddingHashmapV2(table_id=table_id)
+
+
+## Map an embedding hashmap table id to its table handle on the device side.
+# @param table_id int32
+# @param bucket_size, load_factor, embedding_dim, dtype unused; optional, kept for existing callers
+# @return table_handle int64
+def table_to_resource_v2(table_id, bucket_size=None, load_factor=None, embedding_dim=None, dtype=None):
+    """ device embedding hashmap to handle. """
+    result = gen_npu_cpu_ops.TableToResourceV2(table_id=table_id)
+    return result
+
+
+## Compute the size of embedding hashmap tables on the device side.
+# @param table_ids int32
+# @param filter_export_flag bool, defaults to False like the op attr
+# @param export_mode string, defaults to 'all' like the op attr
+# @return table_sizes int64
+def embedding_hashmap_table_size_v2(table_ids, filter_export_flag=False, export_mode='all'):
+    """ device embedding hashmap table size. """
+    result = gen_npu_cpu_ops.EmbeddingHashmapSize(
+        table_ids=table_ids, filter_export_flag=filter_export_flag, export_mode=export_mode)
+    return result
+
+
+## Export embedding hashmap data on the host side.
+# @param file_path string
+# @param table_ids int32
+# @param table_names string
+# @param global_step int32/int64
+# @param keys list of int64 tensors
+# @param counts list of int64 tensors
+# @param filter_flag list of uint8 tensors, forwarded to the op input named filter_flags
+# @param values list of float32 tensors
+# @param num unused; the op infers the num attr from the length of the list inputs
+def embedding_hashmap_export_v2(file_path, table_ids, table_names, global_step, keys, counts, filter_flag, values, num):
+    """ host embedding hashmap export. """
+    gen_npu_cpu_ops.EmbeddingHashmapExport(
+        file_path=file_path, table_ids=table_ids, table_names=table_names, global_step=global_step,
+        keys=keys, counts=counts, filter_flags=filter_flag, values=values)
+
+
+## Compute the size of an embedding hashmap file on the host side.
+# @param file_path string
+# @param table_ids int32
+# @param table_names string
+# @param global_step int32/int64
+# @param embedding_dims list of int
+# @return table_sizes int64
+def embedding_hashmap_file_size_v2(file_path, table_ids, table_names, global_step, embedding_dims):
+    """ host embedding hashmap file size. """
+    result = gen_npu_cpu_ops.EmbeddingHashmapFileSize(
+        file_path=file_path, table_ids=table_ids, table_names=table_names,
+        global_step=global_step, embedding_dims=embedding_dims)
+    return result
+
+
+## Import embedding hashmap data on the host side.
+# NOTE(review): this name follows the feature-mapping wrappers rather than the `_v2` hashmap
+# wrappers above; confirm it does not shadow an existing function in this module.
+# @param file_path string
+# @param table_ids int32
+# @param table_names string
+# @param table_sizes int64
+# @param global_step int32/int64
+# @param embedding_dims list of int
+# @param num int64, number of tables (length of each output list)
+# @return keys(int64)/counts(int64)/filter_flags(uint8)/values(float32)
+def embedding_feature_mapping_import(file_path, table_ids, table_names, table_sizes, global_step, embedding_dims, num):
+    """ host embedding hashmap import. """
+    result = gen_npu_cpu_ops.EmbeddingHashmapImport(
+        file_path=file_path, table_ids=table_ids, table_names=table_names,
+        table_sizes=table_sizes, global_step=global_step, embedding_dims=embedding_dims, num=num)
+    return result
diff --git a/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc b/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc
index 90e47d5b3f87791e4c9ef473475b9f208a719528..7f994a56e7b5e3c5841ce1e89166af7b75b6961d 100644
--- a/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc
+++ b/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc
@@ -714,5 +714,75 @@ TEST(EmbeddingOpsTest, TestEmbeddingFeatureMappingFileSizeShapeInfer) {
       {}, {}, {});
   TF_CHECK_OK(reg->shape_inference_fn(&c));
 }
+
+TEST(EmbeddingOpsTest, InitEmbeddingHashmapV2ShapeInfer) {
+  const OpRegistrationData* reg;
+  TF_CHECK_OK(OpRegistry::Global()->LookUp("InitEmbeddingHashmapV2", &reg));
+  OpDef op_def = reg->op_def;
+  NodeDef def;
+  TF_CHECK_OK(NodeDefBuilder("dummy", &op_def)
+                  .Attr("bucket_size", 10)
+                  .Attr("load_factor", 80)
+                  .Attr("embedding_dim", 2)
+                  .Attr("dtype", DT_FLOAT)
+                  .Input(FakeInputStub(DT_INT32))
+                  .Finalize(&def));
+  shape_inference::InferenceContext c(
+      0, &def, op_def,
+      {TShape({6})},
+      {}, {}, {});
+  TF_CHECK_OK(reg->shape_inference_fn(&c));
+}
+
+TEST(EmbeddingOpsTest, DeinitEmbeddingHashmapV2ShapeInfer) {
+  const OpRegistrationData* reg;
+  TF_CHECK_OK(OpRegistry::Global()->LookUp("DeinitEmbeddingHashmapV2", &reg));
+  OpDef op_def = reg->op_def;
+  NodeDef def;
+  TF_CHECK_OK(NodeDefBuilder("dummy", &op_def)
+                  .Input(FakeInputStub(DT_INT32))
+                  .Finalize(&def));
+  shape_inference::InferenceContext c(
+      0, &def, op_def,
+      {TShape({6})},
+      {}, {}, {});
+  TF_CHECK_OK(reg->shape_inference_fn(&c));
+}
+
+TEST(EmbeddingOpsTest, TableToResourceV2ShapeInfer) {
+  const OpRegistrationData* reg;
+  TF_CHECK_OK(OpRegistry::Global()->LookUp("TableToResourceV2", &reg));
+  OpDef op_def = reg->op_def;
+  NodeDef def;
+  TF_CHECK_OK(NodeDefBuilder("dummy", &op_def)
+                  .Input(FakeInputStub(DT_INT32))
+                  .Finalize(&def));
+  shape_inference::InferenceContext c(
+      0, &def, op_def,
+      {TShape({6})},
+      {}, {}, {});
+  TF_CHECK_OK(reg->shape_inference_fn(&c));
+}
+
+TEST(EmbeddingOpsTest, EmbeddingHashmapImportShapeInfer) {
+  const OpRegistrationData *reg;
+  TF_CHECK_OK(OpRegistry::Global()->LookUp("EmbeddingHashmapImport", &reg));
+  OpDef op_def = reg->op_def;
+  NodeDef def;
+  TF_CHECK_OK(NodeDefBuilder("dummy", &op_def)
+                  .Attr("embedding_dims", {4})
+                  .Attr("num", 1)
+                  .Input(FakeInputStub(DT_STRING))
+                  .Input(FakeInputStub(DT_INT32))
+                  .Input(FakeInputStub(DT_INT64))
+                  .Input(FakeInputStub(DT_STRING))
+                  .Input(FakeInputStub(DT_INT64))
+                  .Finalize(&def));
+  shape_inference::InferenceContext c(
+      0, &def, op_def,
+      {TShape({}), TShape({}), TShape({6})},
+      {}, {}, {});
+  TF_CHECK_OK(reg->shape_inference_fn(&c));
+}
 }
 }
\ No newline at end of file
diff --git a/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc b/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc
index 5206be4a9b46c7d181e4b8b02469e1b4e23e1c7d..cda33e074d90b7c246ad141f692a56a459680bc7 100644
--- a/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc
+++ b/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc
@@ -732,5 +732,75 @@ TEST(EmbeddingOpsTest, TestEmbeddingFeatureMappingFileSizeShapeInfer) {
       {}, {}, {});
   TF_CHECK_OK(reg->shape_inference_fn(&c));
 }
+
+TEST(EmbeddingOpsTest, InitEmbeddingHashmapV2ShapeInfer) {
+  const OpRegistrationData* reg;
+  TF_CHECK_OK(OpRegistry::Global()->LookUp("InitEmbeddingHashmapV2", &reg));
+  OpDef op_def = reg->op_def;
+  NodeDef def;
+  TF_CHECK_OK(NodeDefBuilder("dummy", &op_def)
+                  .Attr("bucket_size", 10)
+                  .Attr("load_factor", 80)
+                  .Attr("embedding_dim", 2)
+                  .Attr("dtype", DT_FLOAT)
+                  .Input(FakeInputStub(DT_INT32))
+                  .Finalize(&def));
+  shape_inference::InferenceContext c(
+      0, &def, op_def,
+      {TShape({6})},
+      {}, {}, {});
+  TF_CHECK_OK(reg->shape_inference_fn(&c));
+}
+
+TEST(EmbeddingOpsTest, DeinitEmbeddingHashmapV2ShapeInfer) {
+  const OpRegistrationData* reg;
+  TF_CHECK_OK(OpRegistry::Global()->LookUp("DeinitEmbeddingHashmapV2", &reg));
+  OpDef op_def = reg->op_def;
+  NodeDef def;
+  TF_CHECK_OK(NodeDefBuilder("dummy", &op_def)
+                  .Input(FakeInputStub(DT_INT32))
+                  .Finalize(&def));
+  shape_inference::InferenceContext c(
+      0, &def, op_def,
+      {TShape({6})},
+      {}, {}, {});
+  TF_CHECK_OK(reg->shape_inference_fn(&c));
+}
+
+TEST(EmbeddingOpsTest, TableToResourceV2ShapeInfer) {
+  const OpRegistrationData* reg;
+  TF_CHECK_OK(OpRegistry::Global()->LookUp("TableToResourceV2", &reg));
+  OpDef op_def = reg->op_def;
+  NodeDef def;
+  TF_CHECK_OK(NodeDefBuilder("dummy", &op_def)
+                  .Input(FakeInputStub(DT_INT32))
+                  .Finalize(&def));
+  shape_inference::InferenceContext c(
+      0, &def, op_def,
+      {TShape({6})},
+      {}, {}, {});
+  TF_CHECK_OK(reg->shape_inference_fn(&c));
+}
+
+TEST(EmbeddingOpsTest, EmbeddingHashmapImportShapeInfer) {
+  const OpRegistrationData *reg;
+  TF_CHECK_OK(OpRegistry::Global()->LookUp("EmbeddingHashmapImport", &reg));
+  OpDef op_def = reg->op_def;
+  NodeDef def;
+  TF_CHECK_OK(NodeDefBuilder("dummy", &op_def)
+                  .Attr("embedding_dims", {4})
+                  .Attr("num", 1)
+                  .Input(FakeInputStub(DT_STRING))
+                  .Input(FakeInputStub(DT_INT32))
+                  .Input(FakeInputStub(DT_INT64))
+                  .Input(FakeInputStub(DT_STRING))
+                  .Input(FakeInputStub(DT_INT64))
+                  .Finalize(&def));
+  shape_inference::InferenceContext c(
+      0, &def, op_def,
+      {TShape({}), TShape({}), TShape({6})},
+      {}, {}, {});
+  TF_CHECK_OK(reg->shape_inference_fn(&c));
+}
 }
 }
\ No newline at end of file