From 2d7212af43deb163ad869a309be6cb9f05c23add Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=9C=8B=E6=9C=8B?= Date: Sat, 28 Jan 2023 03:40:27 +0000 Subject: [PATCH 1/2] =?UTF-8?q?!1964=20EmbeddingService=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=E6=96=B0=E5=A2=9E=E7=AE=97=E5=AD=90=20Merge=20pull=20request?= =?UTF-8?q?=20!1964=20from=20=E5=BC=A0=E6=9C=8B=E6=9C=8B/EmbeddingService?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tf_adapter/kernels/aicpu/npu_cpu_ops.cc | 40 +++++++ tf_adapter/ops/aicpu/npu_cpu_ops.cc | 77 ++++++++++++++ .../st/kernels/testcase/npu_cpu_ops_test.cc | 100 ++++++++++++++++++ .../ut/kernels/testcase/npu_cpu_ops_test.cc | 99 +++++++++++++++++ 4 files changed, 316 insertions(+) diff --git a/tf_adapter/kernels/aicpu/npu_cpu_ops.cc b/tf_adapter/kernels/aicpu/npu_cpu_ops.cc index 1c00ce3b6..14093961e 100644 --- a/tf_adapter/kernels/aicpu/npu_cpu_ops.cc +++ b/tf_adapter/kernels/aicpu/npu_cpu_ops.cc @@ -242,6 +242,41 @@ public: void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "UninitEmbeddingHashmapOp Compute"; } }; +class TableToResourceOp : public OpKernel { +public: + explicit TableToResourceOp(OpKernelConstruction *context) : OpKernel(context) {} + ~TableToResourceOp() override {} + void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "TableToResourceOp Compute"; } +}; + +class EmbeddingTableFindAndInitOp : public OpKernel { +public: + explicit EmbeddingTableFindAndInitOp(OpKernelConstruction *context) : OpKernel(context) {} + ~EmbeddingTableFindAndInitOp() override {} + void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "EmbeddingTableFindAndInitOp Compute"; } +}; + +class EmbeddingApplyAdamOp : public OpKernel { +public: + explicit EmbeddingApplyAdamOp(OpKernelConstruction *context) : OpKernel(context) {} + ~EmbeddingApplyAdamOp() override {} + void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "EmbeddingApplyAdamOp Compute"; } +}; + +class EmbeddingApplyAdaGradOp : public OpKernel { +public: + explicit EmbeddingApplyAdaGradOp(OpKernelConstruction *context) : OpKernel(context) {} + ~EmbeddingApplyAdaGradOp() override {} + void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "EmbeddingApplyAdaGradOp Compute"; } +}; + +class EmbeddingTableExportOp : public OpKernel { +public: + explicit EmbeddingTableExportOp(OpKernelConstruction *context) : OpKernel(context) {} + ~EmbeddingTableExportOp() override {} + void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "EmbeddingTableExportOp Compute"; } +}; + REGISTER_KERNEL_BUILDER(Name("ScatterElementsV2").Device(DEVICE_CPU), ScatterElementsV2Op); REGISTER_KERNEL_BUILDER(Name("EmbeddingRankId").Device(DEVICE_CPU), EmbeddingRankIdOpKernel); REGISTER_KERNEL_BUILDER(Name("EmbeddingLocalIndex").Device(DEVICE_CPU), EmbeddingLocalIndexOpKernel); @@ -266,6 +301,11 @@ REGISTER_KERNEL_BUILDER(Name("EmbeddingTableFind").Device(DEVICE_CPU), Embedding REGISTER_KERNEL_BUILDER(Name("EmbeddingTableImport").Device(DEVICE_CPU), EmbeddingTableImportOp); REGISTER_KERNEL_BUILDER(Name("UninitPartitionMap").Device(DEVICE_CPU), UninitPartitionMapOp); REGISTER_KERNEL_BUILDER(Name("UninitEmbeddingHashmap").Device(DEVICE_CPU), UninitEmbeddingHashmapOp); +REGISTER_KERNEL_BUILDER(Name("TableToResource").Device(DEVICE_CPU), TableToResourceOp); +REGISTER_KERNEL_BUILDER(Name("EmbeddingTableFindAndInit").Device(DEVICE_CPU), EmbeddingTableFindAndInitOp); +REGISTER_KERNEL_BUILDER(Name("EmbeddingApplyAdam").Device(DEVICE_CPU), EmbeddingApplyAdamOp); +REGISTER_KERNEL_BUILDER(Name("EmbeddingApplyAdaGrad").Device(DEVICE_CPU), EmbeddingApplyAdaGradOp); +REGISTER_KERNEL_BUILDER(Name("EmbeddingTableExport").Device(DEVICE_CPU), EmbeddingTableExportOp); class DecodeImageV3Op : public OpKernel { public: diff --git a/tf_adapter/ops/aicpu/npu_cpu_ops.cc b/tf_adapter/ops/aicpu/npu_cpu_ops.cc index 4d5875817..f246da40a 100644 --- a/tf_adapter/ops/aicpu/npu_cpu_ops.cc +++ b/tf_adapter/ops/aicpu/npu_cpu_ops.cc @@ -266,6 +266,83 @@ REGISTER_OP("UninitEmbeddingHashmap") .Input("table_id: int32") .SetShapeFn(shape_inference::NoOutputs); +REGISTER_OP("TableToResource") + .Input("table_id: int32") + .Output("table_handle: resource") + .SetShapeFn([](shape_inference::InferenceContext *c) { + auto data_shape = c->input(0); + c->set_output(0, data_shape); + return Status::OK(); + }); + +REGISTER_OP("EmbeddingTableFindAndInit") + .Input("table_id: int32") + .Input("keys: int64") + .Output("values: float32") + .Attr("embedding_dim: int = 0") + .Attr("value_total_len: int = 0") + .Attr("random_alg: string = 'random_uniform'") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .SetShapeFn([](shape_inference::InferenceContext *c) { + ShapeHandle keys_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys_shape)); + int embedding_dim; + if (!c->GetAttr("embedding_dim", &embedding_dim).ok()) { + return errors::InvalidArgument("Invalid embedding_dim"); + } + c->set_output(0, c->Matrix(c->Dim(keys_shape, 0), embedding_dim)); + return Status::OK(); + }); + +REGISTER_OP("EmbeddingApplyAdam") + .Input("var_handle: resource") + .Input("beta1_power: T") + .Input("beta2_power: T") + .Input("lr: T") + .Input("beta1: T") + .Input("beta2: T") + .Input("epsilon: T") + .Input("grad: T") + .Input("keys: int64") + .Input("global_step: Tstep") + .Output("var_handle_output: resource") + .Attr("embedding_dim: int = 0") + .Attr("T: {float32, float16}") + .Attr("Tstep: {int32, int64}") + .SetShapeFn([](shape_inference::InferenceContext *c) { + auto data_shape = c->input(0); + c->set_output(0, data_shape); + return Status::OK(); + }); + +REGISTER_OP("EmbeddingApplyAdaGrad") + .Input("var_handle: resource") + .Input("lr: T") + .Input("grad: T") + .Input("keys: int64") + .Input("global_step: Tstep") + .Output("var_handle_output: resource") + .Attr("embedding_dim: int = 0") + .Attr("T: {float32, float16}") + .Attr("Tstep: {int32, int64}") + .SetShapeFn([](shape_inference::InferenceContext *c) { + auto data_shape = c->input(0); + c->set_output(0, data_shape); + return Status::OK(); + }); + +REGISTER_OP("EmbeddingTableExport") + .Input("file_path: string") + .Input("file_name: string") + .Input("ps_id: int32") + .Input("table_id: int32") + .Attr("embedding_dim: int = 0") + .Attr("value_total_len: int = 0") + .Attr("only_var_flag: bool = false") + .Attr("file_type: string = 'bin' ") + .SetShapeFn(shape_inference::NoOutputs); + // regist dense image warp op REGISTER_OP("DenseImageWarp") .Input("image: T") diff --git a/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc b/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc index 827b714dc..96ffc9f95 100644 --- a/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc +++ b/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc @@ -146,4 +146,104 @@ TEST(EmbeddingOpsTest, TestUninitEmbeddingHashmap) { delete op_def; delete context; } + +TEST(EmbeddingOpsTest, TestTableToResource) { + DataTypeSlice input_types({DT_INT32}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({DT_RESOURCE}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, + input_types, input_memory_types, output_types, output_memory_types, + 1, nullptr); + TableToResourceOp cache(context); + OpKernelContext *ctx = nullptr; + cache.Compute(ctx); + delete device; + delete node_def; + delete op_def; + delete context; +} + +TEST(EmbeddingOpsTest, TestEmbeddingTableFindAndInit) { + DataTypeSlice input_types({DT_INT32}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({DT_INT32}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, + input_types, input_memory_types, output_types, output_memory_types, + 1, nullptr); + EmbeddingTableFindAndInitOp cache(context); + OpKernelContext *ctx = nullptr; + cache.Compute(ctx); + delete device; + delete node_def; + delete op_def; + delete context; +} + +TEST(EmbeddingOpsTest, TestEmbeddingTableExport) { + DataTypeSlice input_types({DT_STRING}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({DT_STRING}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, + input_types, input_memory_types, output_types, output_memory_types, + 1, nullptr); + EmbeddingTableExportOp cache(context); + OpKernelContext *ctx = nullptr; + cache.Compute(ctx); + delete device; + delete node_def; + delete op_def; + delete context; +} + +TEST(EmbeddingOpsTest, TestEmbeddingApplyAdam) { + DataTypeSlice input_types({DT_RESOURCE}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({DT_RESOURCE}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, + input_types, input_memory_types, output_types, output_memory_types, + 1, nullptr); + EmbeddingApplyAdamOp cache(context); + OpKernelContext *ctx = nullptr; + cache.Compute(ctx); + delete device; + delete node_def; + delete op_def; + delete context; +} + +TEST(EmbeddingOpsTest, TestEmbeddingApplyAdaGrad) { + DataTypeSlice input_types({DT_RESOURCE}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({DT_RESOURCE}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, + input_types, input_memory_types, output_types, output_memory_types, + 1, nullptr); + EmbeddingApplyAdaGradOp cache(context); + OpKernelContext *ctx = nullptr; + cache.Compute(ctx); + delete device; + delete node_def; + delete op_def; + delete context; +} } \ No newline at end of file diff --git a/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc b/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc index 5e9f047f3..b0c61a51b 100644 --- a/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc +++ b/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc @@ -166,4 +166,103 @@ TEST(EmbeddingOpsTest, TestUninitEmbeddingHashmap) { delete op_def; delete context; } +TEST(EmbeddingOpsTest, TestTableToResource) { + DataTypeSlice input_types({DT_INT32}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({DT_RESOURCE}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, + input_types, input_memory_types, output_types, output_memory_types, + 1, nullptr); + TableToResourceOp cache(context); + OpKernelContext *ctx = nullptr; + cache.Compute(ctx); + delete device; + delete node_def; + delete op_def; + delete context; +} + +TEST(EmbeddingOpsTest, TestEmbeddingTableFindAndInit) { + DataTypeSlice input_types({DT_INT32}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({DT_INT32}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, + input_types, input_memory_types, output_types, output_memory_types, + 1, nullptr); + EmbeddingTableFindAndInitOp cache(context); + OpKernelContext *ctx = nullptr; + cache.Compute(ctx); + delete device; + delete node_def; + delete op_def; + delete context; +} + +TEST(EmbeddingOpsTest, TestEmbeddingTableExport) { + DataTypeSlice input_types({DT_STRING}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({DT_STRING}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, + input_types, input_memory_types, output_types, output_memory_types, + 1, nullptr); + EmbeddingTableExportOp cache(context); + OpKernelContext *ctx = nullptr; + cache.Compute(ctx); + delete device; + delete node_def; + delete op_def; + delete context; +} + +TEST(EmbeddingOpsTest, TestEmbeddingApplyAdam) { + DataTypeSlice input_types({DT_RESOURCE}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({DT_RESOURCE}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, + input_types, input_memory_types, output_types, output_memory_types, + 1, nullptr); + EmbeddingApplyAdamOp cache(context); + OpKernelContext *ctx = nullptr; + cache.Compute(ctx); + delete device; + delete node_def; + delete op_def; + delete context; +} + +TEST(EmbeddingOpsTest, TestEmbeddingApplyAdaGrad) { + DataTypeSlice input_types({DT_RESOURCE}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({DT_RESOURCE}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, + input_types, input_memory_types, output_types, output_memory_types, + 1, nullptr); + EmbeddingApplyAdaGradOp cache(context); + OpKernelContext *ctx = nullptr; + cache.Compute(ctx); + delete device; + delete node_def; + delete op_def; + delete context; +} } \ No newline at end of file -- Gitee From 886f811f059181d7ecadef928b904aedfb2174b4 Mon Sep 17 00:00:00 2001 From: peng-peng-zhang Date: Sat, 28 Jan 2023 17:45:44 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E5=9B=9E=E9=80=80ls=20EmbeddingService?= =?UTF-8?q?=E8=AE=AD=E7=BB=83=E6=96=B0=E5=A2=9E=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tf_adapter/kernels/aicpu/npu_cpu_ops.cc | 40 ------- tf_adapter/ops/aicpu/npu_cpu_ops.cc | 77 -------------- .../st/kernels/testcase/npu_cpu_ops_test.cc | 100 ------------------ .../ut/kernels/testcase/npu_cpu_ops_test.cc | 99 ----------------- 4 files changed, 316 deletions(-) diff --git a/tf_adapter/kernels/aicpu/npu_cpu_ops.cc b/tf_adapter/kernels/aicpu/npu_cpu_ops.cc index 14093961e..1c00ce3b6 100644 --- a/tf_adapter/kernels/aicpu/npu_cpu_ops.cc +++ b/tf_adapter/kernels/aicpu/npu_cpu_ops.cc @@ -242,41 +242,6 @@ public: void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "UninitEmbeddingHashmapOp Compute"; } }; -class TableToResourceOp : public OpKernel { -public: - explicit TableToResourceOp(OpKernelConstruction *context) : OpKernel(context) {} - ~TableToResourceOp() override {} - void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "TableToResourceOp Compute"; } -}; - -class EmbeddingTableFindAndInitOp : public OpKernel { -public: - explicit EmbeddingTableFindAndInitOp(OpKernelConstruction *context) : OpKernel(context) {} - ~EmbeddingTableFindAndInitOp() override {} - void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "EmbeddingTableFindAndInitOp Compute"; } -}; - -class EmbeddingApplyAdamOp : public OpKernel { -public: - explicit EmbeddingApplyAdamOp(OpKernelConstruction *context) : OpKernel(context) {} - ~EmbeddingApplyAdamOp() override {} - void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "EmbeddingApplyAdamOp Compute"; } -}; - -class EmbeddingApplyAdaGradOp : public OpKernel { -public: - explicit EmbeddingApplyAdaGradOp(OpKernelConstruction *context) : OpKernel(context) {} - ~EmbeddingApplyAdaGradOp() override {} - void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "EmbeddingApplyAdaGradOp Compute"; } -}; - -class EmbeddingTableExportOp : public OpKernel { -public: - explicit EmbeddingTableExportOp(OpKernelConstruction *context) : OpKernel(context) {} - ~EmbeddingTableExportOp() override {} - void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "EmbeddingTableExportOp Compute"; } -}; - REGISTER_KERNEL_BUILDER(Name("ScatterElementsV2").Device(DEVICE_CPU), ScatterElementsV2Op); REGISTER_KERNEL_BUILDER(Name("EmbeddingRankId").Device(DEVICE_CPU), EmbeddingRankIdOpKernel); REGISTER_KERNEL_BUILDER(Name("EmbeddingLocalIndex").Device(DEVICE_CPU), EmbeddingLocalIndexOpKernel); @@ -301,11 +266,6 @@ REGISTER_KERNEL_BUILDER(Name("EmbeddingTableFind").Device(DEVICE_CPU), Embedding REGISTER_KERNEL_BUILDER(Name("EmbeddingTableImport").Device(DEVICE_CPU), EmbeddingTableImportOp); REGISTER_KERNEL_BUILDER(Name("UninitPartitionMap").Device(DEVICE_CPU), UninitPartitionMapOp); REGISTER_KERNEL_BUILDER(Name("UninitEmbeddingHashmap").Device(DEVICE_CPU), UninitEmbeddingHashmapOp); -REGISTER_KERNEL_BUILDER(Name("TableToResource").Device(DEVICE_CPU), TableToResourceOp); -REGISTER_KERNEL_BUILDER(Name("EmbeddingTableFindAndInit").Device(DEVICE_CPU), EmbeddingTableFindAndInitOp); -REGISTER_KERNEL_BUILDER(Name("EmbeddingApplyAdam").Device(DEVICE_CPU), EmbeddingApplyAdamOp); -REGISTER_KERNEL_BUILDER(Name("EmbeddingApplyAdaGrad").Device(DEVICE_CPU), EmbeddingApplyAdaGradOp); -REGISTER_KERNEL_BUILDER(Name("EmbeddingTableExport").Device(DEVICE_CPU), EmbeddingTableExportOp); class DecodeImageV3Op : public OpKernel { public: diff --git a/tf_adapter/ops/aicpu/npu_cpu_ops.cc b/tf_adapter/ops/aicpu/npu_cpu_ops.cc index f246da40a..4d5875817 100644 --- a/tf_adapter/ops/aicpu/npu_cpu_ops.cc +++ b/tf_adapter/ops/aicpu/npu_cpu_ops.cc @@ -266,83 +266,6 @@ REGISTER_OP("UninitEmbeddingHashmap") .Input("table_id: int32") .SetShapeFn(shape_inference::NoOutputs); -REGISTER_OP("TableToResource") - .Input("table_id: int32") - .Output("table_handle: resource") - .SetShapeFn([](shape_inference::InferenceContext *c) { - auto data_shape = c->input(0); - c->set_output(0, data_shape); - return Status::OK(); - }); - -REGISTER_OP("EmbeddingTableFindAndInit") - .Input("table_id: int32") - .Input("keys: int64") - .Output("values: float32") - .Attr("embedding_dim: int = 0") - .Attr("value_total_len: int = 0") - .Attr("random_alg: string = 'random_uniform'") - .Attr("seed: int = 0") - .Attr("seed2: int = 0") - .SetShapeFn([](shape_inference::InferenceContext *c) { - ShapeHandle keys_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys_shape)); - int embedding_dim; - if (!c->GetAttr("embedding_dim", &embedding_dim).ok()) { - return errors::InvalidArgument("Invalid embedding_dim"); - } - c->set_output(0, c->Matrix(c->Dim(keys_shape, 0), embedding_dim)); - return Status::OK(); - }); - -REGISTER_OP("EmbeddingApplyAdam") - .Input("var_handle: resource") - .Input("beta1_power: T") - .Input("beta2_power: T") - .Input("lr: T") - .Input("beta1: T") - .Input("beta2: T") - .Input("epsilon: T") - .Input("grad: T") - .Input("keys: int64") - .Input("global_step: Tstep") - .Output("var_handle_output: resource") - .Attr("embedding_dim: int = 0") - .Attr("T: {float32, float16}") - .Attr("Tstep: {int32, int64}") - .SetShapeFn([](shape_inference::InferenceContext *c) { - auto data_shape = c->input(0); - c->set_output(0, data_shape); - return Status::OK(); - }); - -REGISTER_OP("EmbeddingApplyAdaGrad") - .Input("var_handle: resource") - .Input("lr: T") - .Input("grad: T") - .Input("keys: int64") - .Input("global_step: Tstep") - .Output("var_handle_output: resource") - .Attr("embedding_dim: int = 0") - .Attr("T: {float32, float16}") - .Attr("Tstep: {int32, int64}") - .SetShapeFn([](shape_inference::InferenceContext *c) { - auto data_shape = c->input(0); - c->set_output(0, data_shape); - return Status::OK(); - }); - -REGISTER_OP("EmbeddingTableExport") - .Input("file_path: string") - .Input("file_name: string") - .Input("ps_id: int32") - .Input("table_id: int32") - .Attr("embedding_dim: int = 0") - .Attr("value_total_len: int = 0") - .Attr("only_var_flag: bool = false") - .Attr("file_type: string = 'bin' ") - .SetShapeFn(shape_inference::NoOutputs); - // regist dense image warp op REGISTER_OP("DenseImageWarp") .Input("image: T") diff --git a/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc b/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc index 96ffc9f95..827b714dc 100644 --- a/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc +++ b/tf_adapter/tests/st/kernels/testcase/npu_cpu_ops_test.cc @@ -146,104 +146,4 @@ TEST(EmbeddingOpsTest, TestUninitEmbeddingHashmap) { delete op_def; delete context; } - -TEST(EmbeddingOpsTest, TestTableToResource) { - DataTypeSlice input_types({DT_INT32}); - MemoryTypeSlice input_memory_types; - DataTypeSlice output_types({DT_RESOURCE}); - MemoryTypeSlice output_memory_types; - DeviceBase *device = new DeviceBase(Env::Default()); - NodeDef *node_def = new NodeDef(); - OpDef *op_def = new OpDef(); - OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, - input_types, input_memory_types, output_types, output_memory_types, - 1, nullptr); - TableToResourceOp cache(context); - OpKernelContext *ctx = nullptr; - cache.Compute(ctx); - delete device; - delete node_def; - delete op_def; - delete context; -} - -TEST(EmbeddingOpsTest, TestEmbeddingTableFindAndInit) { - DataTypeSlice input_types({DT_INT32}); - MemoryTypeSlice input_memory_types; - DataTypeSlice output_types({DT_INT32}); - MemoryTypeSlice output_memory_types; - DeviceBase *device = new DeviceBase(Env::Default()); - NodeDef *node_def = new NodeDef(); - OpDef *op_def = new OpDef(); - OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, - input_types, input_memory_types, output_types, output_memory_types, - 1, nullptr); - EmbeddingTableFindAndInitOp cache(context); - OpKernelContext *ctx = nullptr; - cache.Compute(ctx); - delete device; - delete node_def; - delete op_def; - delete context; -} - -TEST(EmbeddingOpsTest, TestEmbeddingTableExport) { - DataTypeSlice input_types({DT_STRING}); - MemoryTypeSlice input_memory_types; - DataTypeSlice output_types({DT_STRING}); - MemoryTypeSlice output_memory_types; - DeviceBase *device = new DeviceBase(Env::Default()); - NodeDef *node_def = new NodeDef(); - OpDef *op_def = new OpDef(); - OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, - input_types, input_memory_types, output_types, output_memory_types, - 1, nullptr); - EmbeddingTableExportOp cache(context); - OpKernelContext *ctx = nullptr; - cache.Compute(ctx); - delete device; - delete node_def; - delete op_def; - delete context; -} - -TEST(EmbeddingOpsTest, TestEmbeddingApplyAdam) { - DataTypeSlice input_types({DT_RESOURCE}); - MemoryTypeSlice input_memory_types; - DataTypeSlice output_types({DT_RESOURCE}); - MemoryTypeSlice output_memory_types; - DeviceBase *device = new DeviceBase(Env::Default()); - NodeDef *node_def = new NodeDef(); - OpDef *op_def = new OpDef(); - OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, - input_types, input_memory_types, output_types, output_memory_types, - 1, nullptr); - EmbeddingApplyAdamOp cache(context); - OpKernelContext *ctx = nullptr; - cache.Compute(ctx); - delete device; - delete node_def; - delete op_def; - delete context; -} - -TEST(EmbeddingOpsTest, TestEmbeddingApplyAdaGrad) { - DataTypeSlice input_types({DT_RESOURCE}); - MemoryTypeSlice input_memory_types; - DataTypeSlice output_types({DT_RESOURCE}); - MemoryTypeSlice output_memory_types; - DeviceBase *device = new DeviceBase(Env::Default()); - NodeDef *node_def = new NodeDef(); - OpDef *op_def = new OpDef(); - OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, - input_types, input_memory_types, output_types, output_memory_types, - 1, nullptr); - EmbeddingApplyAdaGradOp cache(context); - OpKernelContext *ctx = nullptr; - cache.Compute(ctx); - delete device; - delete node_def; - delete op_def; - delete context; -} } \ No newline at end of file diff --git a/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc b/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc index b0c61a51b..5e9f047f3 100644 --- a/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc +++ b/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc @@ -166,103 +166,4 @@ TEST(EmbeddingOpsTest, TestUninitEmbeddingHashmap) { delete op_def; delete context; } -TEST(EmbeddingOpsTest, TestTableToResource) { - DataTypeSlice input_types({DT_INT32}); - MemoryTypeSlice input_memory_types; - DataTypeSlice output_types({DT_RESOURCE}); - MemoryTypeSlice output_memory_types; - DeviceBase *device = new DeviceBase(Env::Default()); - NodeDef *node_def = new NodeDef(); - OpDef *op_def = new OpDef(); - OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, - input_types, input_memory_types, output_types, output_memory_types, - 1, nullptr); - TableToResourceOp cache(context); - OpKernelContext *ctx = nullptr; - cache.Compute(ctx); - delete device; - delete node_def; - delete op_def; - delete context; -} - -TEST(EmbeddingOpsTest, TestEmbeddingTableFindAndInit) { - DataTypeSlice input_types({DT_INT32}); - MemoryTypeSlice input_memory_types; - DataTypeSlice output_types({DT_INT32}); - MemoryTypeSlice output_memory_types; - DeviceBase *device = new DeviceBase(Env::Default()); - NodeDef *node_def = new NodeDef(); - OpDef *op_def = new OpDef(); - OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, - input_types, input_memory_types, output_types, output_memory_types, - 1, nullptr); - EmbeddingTableFindAndInitOp cache(context); - OpKernelContext *ctx = nullptr; - cache.Compute(ctx); - delete device; - delete node_def; - delete op_def; - delete context; -} - -TEST(EmbeddingOpsTest, TestEmbeddingTableExport) { - DataTypeSlice input_types({DT_STRING}); - MemoryTypeSlice input_memory_types; - DataTypeSlice output_types({DT_STRING}); - MemoryTypeSlice output_memory_types; - DeviceBase *device = new DeviceBase(Env::Default()); - NodeDef *node_def = new NodeDef(); - OpDef *op_def = new OpDef(); - OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, - input_types, input_memory_types, output_types, output_memory_types, - 1, nullptr); - EmbeddingTableExportOp cache(context); - OpKernelContext *ctx = nullptr; - cache.Compute(ctx); - delete device; - delete node_def; - delete op_def; - delete context; -} - -TEST(EmbeddingOpsTest, TestEmbeddingApplyAdam) { - DataTypeSlice input_types({DT_RESOURCE}); - MemoryTypeSlice input_memory_types; - DataTypeSlice output_types({DT_RESOURCE}); - MemoryTypeSlice output_memory_types; - DeviceBase *device = new DeviceBase(Env::Default()); - NodeDef *node_def = new NodeDef(); - OpDef *op_def = new OpDef(); - OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, - input_types, input_memory_types, output_types, output_memory_types, - 1, nullptr); - EmbeddingApplyAdamOp cache(context); - OpKernelContext *ctx = nullptr; - cache.Compute(ctx); - delete device; - delete node_def; - delete op_def; - delete context; -} - -TEST(EmbeddingOpsTest, TestEmbeddingApplyAdaGrad) { - DataTypeSlice input_types({DT_RESOURCE}); - MemoryTypeSlice input_memory_types; - DataTypeSlice output_types({DT_RESOURCE}); - MemoryTypeSlice output_memory_types; - DeviceBase *device = new DeviceBase(Env::Default()); - NodeDef *node_def = new NodeDef(); - OpDef *op_def = new OpDef(); - OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, - input_types, input_memory_types, output_types, output_memory_types, - 1, nullptr); - EmbeddingApplyAdaGradOp cache(context); - OpKernelContext *ctx = nullptr; - cache.Compute(ctx); - delete device; - delete node_def; - delete op_def; - delete context; -} } \ No newline at end of file -- Gitee