diff --git a/tf_adapter/ops/aicpu/npu_cpu_ops.cc b/tf_adapter/ops/aicpu/npu_cpu_ops.cc index ad5432292611dd0332ff1cc6d5edee64f2df815b..f0f621fb9376673548957e14a550f665079b9f3b 100644 --- a/tf_adapter/ops/aicpu/npu_cpu_ops.cc +++ b/tf_adapter/ops/aicpu/npu_cpu_ops.cc @@ -254,8 +254,13 @@ REGISTER_OP("EmbeddingTableFind") .Output("values: float32") .Attr("embedding_dim: int = 0") .SetShapeFn([](shape_inference::InferenceContext *c) { - auto data_shape = c->input(0); - c->set_output(0, data_shape); + ShapeHandle keys_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys_shape)); + int embedding_dim; + if (!c->GetAttr("embedding_dim", &embedding_dim).ok()) { + return errors::InvalidArgument("Invalid embedding_dim"); + } + c->set_output(0, c->Matrix(c->Dim(keys_shape, 0), embedding_dim)); return Status::OK(); }); diff --git a/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc b/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc index a6fa3eaf93c97e940af98f9be73dd9eff1f58e5f..af9ea8b0c1e5e92ee271cc846c60f910b48408e9 100644 --- a/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc +++ b/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc @@ -86,8 +86,8 @@ TEST(EmbeddingOpsTest, TestInitEmbeddingHashmap) { delete context; } -TEST(EmbeddingOpsTest, TestEmbeddingTableFind) { - DataTypeSlice input_types({DT_UINT32}); +TEST(EmbeddingOpsTest, TestEmbeddingTableFind01) { + DataTypeSlice input_types({DT_INT32, DT_INT64}); MemoryTypeSlice input_memory_types; DataTypeSlice output_types({DT_FLOAT}); MemoryTypeSlice output_memory_types; @@ -100,12 +100,31 @@ TEST(EmbeddingOpsTest, TestEmbeddingTableFind) { EmbeddingTableFindOp cache(context); OpKernelContext *ctx = nullptr; cache.Compute(ctx); + delete device; delete node_def; delete op_def; delete context; } +TEST(EmbeddingOpsTest, TestEmbeddingTableFind02) { + const OpRegistrationData *reg; + TF_CHECK_OK(OpRegistry::Global()->LookUp("EmbeddingTableFind", ®)); + OpDef op_def = reg->op_def; + NodeDef def; + TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) + .Attr("embedding_dim", 4) + .Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_INT64)) + .Finalize(&def)); + + shape_inference::InferenceContext c( + 0, &def, op_def, + {TShape({1}), TShape({16})}, + {}, {}, {}); + TF_CHECK_OK(reg->shape_inference_fn(&c)); +} + TEST(EmbeddingOpsTest, TestEmbeddingTableImport) { DataTypeSlice input_types({DT_UINT32}); MemoryTypeSlice input_memory_types; @@ -291,7 +310,7 @@ TEST(EmbeddingOpsTest, TestEmbeddingFeatureMappingShapeInfer) { TF_CHECK_OK(OpRegistry::Global()->LookUp("EmbeddingFeatureMapping", ®)); OpDef op_def = reg->op_def; NodeDef def; - TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) + TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) .Input(FakeInputStub(DT_INT64)) .Finalize(&def)); shape_inference::InferenceContext c(0, &def, op_def,{TShape({2, 2, 3, 4})}, {}, {}, {}); diff --git a/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc b/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc index 850d2b738812a2042597617b4973cf09905e7ed0..6f4ae8cbcf95fdf3526e23f87363f5813729ac2e 100644 --- a/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc +++ b/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc @@ -106,7 +106,7 @@ TEST(EmbeddingOpsTest, TestInitEmbeddingHashmap) { delete context; } -TEST(EmbeddingOpsTest, TestEmbeddingTableFind) { +TEST(EmbeddingOpsTest, TestEmbeddingTableFind01) { DataTypeSlice input_types({DT_INT32}); MemoryTypeSlice input_memory_types; DataTypeSlice output_types({DT_FLOAT}); @@ -126,6 +126,24 @@ TEST(EmbeddingOpsTest, TestEmbeddingTableFind) { delete context; } +TEST(EmbeddingOpsTest, TestEmbeddingTableFind02) { + const OpRegistrationData *reg; + TF_CHECK_OK(OpRegistry::Global()->LookUp("EmbeddingTableFind", ®)); + OpDef op_def = reg->op_def; + NodeDef def; + TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) + .Attr("embedding_dim", 4) + .Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_INT64)) + .Finalize(&def)); + + shape_inference::InferenceContext c( + 0, &def, op_def, + {TShape({1}), TShape({16})}, + {}, {}, {}); + TF_CHECK_OK(reg->shape_inference_fn(&c)); +} + TEST(EmbeddingOpsTest, TestEmbeddingTableImport) { DataTypeSlice input_types({DT_INT32}); MemoryTypeSlice input_memory_types; @@ -310,7 +328,7 @@ TEST(EmbeddingOpsTest, TestEmbeddingFeatureMappingShapeInfer) { TF_CHECK_OK(OpRegistry::Global()->LookUp("EmbeddingFeatureMapping", ®)); OpDef op_def = reg->op_def; NodeDef def; - TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) + TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) .Input(FakeInputStub(DT_INT64)) .Finalize(&def)); shape_inference::InferenceContext c(0, &def, op_def,{TShape({2, 2, 3, 4})}, {}, {}, {});