From d0b71697b995689cc94987bcb2f3565edf37c3b6 Mon Sep 17 00:00:00 2001 From: peng-peng-zhang Date: Tue, 28 Feb 2023 17:06:28 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9EmbeddingTableFind=20infersha?= =?UTF-8?q?pe?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tf_adapter/ops/aicpu/npu_cpu_ops.cc | 9 +++++-- .../st/kernels/testcase/npu_cpu_ops_test.cc | 25 ++++++++++++++++--- .../ut/kernels/testcase/npu_cpu_ops_test.cc | 22 ++++++++++++++-- 3 files changed, 49 insertions(+), 7 deletions(-) diff --git a/tf_adapter/ops/aicpu/npu_cpu_ops.cc b/tf_adapter/ops/aicpu/npu_cpu_ops.cc index ad5432292..f0f621fb9 100644 --- a/tf_adapter/ops/aicpu/npu_cpu_ops.cc +++ b/tf_adapter/ops/aicpu/npu_cpu_ops.cc @@ -254,8 +254,13 @@ REGISTER_OP("EmbeddingTableFind") .Output("values: float32") .Attr("embedding_dim: int = 0") .SetShapeFn([](shape_inference::InferenceContext *c) { - auto data_shape = c->input(0); - c->set_output(0, data_shape); + ShapeHandle keys_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys_shape)); + int embedding_dim; + if (!c->GetAttr("embedding_dim", &embedding_dim).ok()) { + return errors::InvalidArgument("Invalid embedding_dim"); + } + c->set_output(0, c->Matrix(c->Dim(keys_shape, 0), embedding_dim)); return Status::OK(); }); diff --git a/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc b/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc index a6fa3eaf9..af9ea8b0c 100644 --- a/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc +++ b/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc @@ -86,8 +86,8 @@ TEST(EmbeddingOpsTest, TestInitEmbeddingHashmap) { delete context; } -TEST(EmbeddingOpsTest, TestEmbeddingTableFind) { - DataTypeSlice input_types({DT_UINT32}); +TEST(EmbeddingOpsTest, TestEmbeddingTableFind01) { + DataTypeSlice input_types({DT_INT32, DT_INT64}); MemoryTypeSlice input_memory_types; DataTypeSlice output_types({DT_FLOAT}); MemoryTypeSlice output_memory_types; @@ -100,12 +100,31 @@ TEST(EmbeddingOpsTest, TestEmbeddingTableFind) { EmbeddingTableFindOp cache(context); OpKernelContext *ctx = nullptr; cache.Compute(ctx); + delete device; delete node_def; delete op_def; delete context; } +TEST(EmbeddingOpsTest, TestEmbeddingTableFind02) { + const OpRegistrationData *reg; + TF_CHECK_OK(OpRegistry::Global()->LookUp("EmbeddingTableFind", ®)); + OpDef op_def = reg->op_def; + NodeDef def; + TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) + .Attr("embedding_dim", 4) + .Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_INT64)) + .Finalize(&def)); + + shape_inference::InferenceContext c( + 0, &def, op_def, + {TShape({1}), TShape({16})}, + {}, {}, {}); + TF_CHECK_OK(reg->shape_inference_fn(&c)); +} + TEST(EmbeddingOpsTest, TestEmbeddingTableImport) { DataTypeSlice input_types({DT_UINT32}); MemoryTypeSlice input_memory_types; @@ -291,7 +310,7 @@ TEST(EmbeddingOpsTest, TestEmbeddingFeatureMappingShapeInfer) { TF_CHECK_OK(OpRegistry::Global()->LookUp("EmbeddingFeatureMapping", ®)); OpDef op_def = reg->op_def; NodeDef def; - TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) + TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) .Input(FakeInputStub(DT_INT64)) .Finalize(&def)); shape_inference::InferenceContext c(0, &def, op_def,{TShape({2, 2, 3, 4})}, {}, {}, {}); diff --git a/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc b/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc index 850d2b738..6f4ae8cbc 100644 --- a/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc +++ b/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc @@ -106,7 +106,7 @@ TEST(EmbeddingOpsTest, TestInitEmbeddingHashmap) { delete context; } -TEST(EmbeddingOpsTest, TestEmbeddingTableFind) { +TEST(EmbeddingOpsTest, TestEmbeddingTableFind01) { DataTypeSlice input_types({DT_INT32}); MemoryTypeSlice input_memory_types; DataTypeSlice output_types({DT_FLOAT}); @@ -126,6 +126,24 @@ TEST(EmbeddingOpsTest, TestEmbeddingTableFind) { delete context; } +TEST(EmbeddingOpsTest, TestEmbeddingTableFind02) { + const OpRegistrationData *reg; + TF_CHECK_OK(OpRegistry::Global()->LookUp("EmbeddingTableFind", ®)); + OpDef op_def = reg->op_def; + NodeDef def; + TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) + .Attr("embedding_dim", 4) + .Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_INT64)) + .Finalize(&def)); + + shape_inference::InferenceContext c( + 0, &def, op_def, + {TShape({1}), TShape({16})}, + {}, {}, {}); + TF_CHECK_OK(reg->shape_inference_fn(&c)); +} + TEST(EmbeddingOpsTest, TestEmbeddingTableImport) { DataTypeSlice input_types({DT_INT32}); MemoryTypeSlice input_memory_types; @@ -310,7 +328,7 @@ TEST(EmbeddingOpsTest, TestEmbeddingFeatureMappingShapeInfer) { TF_CHECK_OK(OpRegistry::Global()->LookUp("EmbeddingFeatureMapping", ®)); OpDef op_def = reg->op_def; NodeDef def; - TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) + TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) .Input(FakeInputStub(DT_INT64)) .Finalize(&def)); shape_inference::InferenceContext c(0, &def, op_def,{TShape({2, 2, 3, 4})}, {}, {}, {}); -- Gitee