From ce337412cd6752a3e491df8a804533c6426358ea Mon Sep 17 00:00:00 2001 From: stormchasingg Date: Mon, 22 Nov 2021 16:26:33 +0800 Subject: [PATCH 01/12] HGEMM operator --- tf_adapter/kernels/h_gemm.cc | 32 +++++++++ tf_adapter/ops/npu_ops.cc | 13 ++++ .../python/npu_bridge/estimator/npu_ops.py | 30 ++++++++ .../tests/ut/kernels/testcase/h_gemm_test.cc | 70 +++++++++++++++++++ 4 files changed, 145 insertions(+) create mode 100644 tf_adapter/kernels/h_gemm.cc create mode 100644 tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc diff --git a/tf_adapter/kernels/h_gemm.cc b/tf_adapter/kernels/h_gemm.cc new file mode 100644 index 000000000..ab28c08a8 --- /dev/null +++ b/tf_adapter/kernels/h_gemm.cc @@ -0,0 +1,32 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2019-2020. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tf_adapter/common/adp_logger.h" + +namespace tensorflow { +class HGEMMOp : public OpKernel { + public: + explicit HGEMMOp(OpKernelConstruction *context) : OpKernel(context) {} + ~HGEMMOp() override = default; + void Compute(OpKernelContext *context) override { + ADP_LOG(INFO) << "HGEMMOp Compute "; + } + bool IsExpensive() override { return false; } +}; + +REGISTER_KERNEL_BUILDER(Name("HGEMM").Device(DEVICE_CPU), HGEMMOp); +} // namespace tensorflow diff --git a/tf_adapter/ops/npu_ops.cc b/tf_adapter/ops/npu_ops.cc index 1da33c991..76d920d4c 100644 --- a/tf_adapter/ops/npu_ops.cc +++ b/tf_adapter/ops/npu_ops.cc @@ -522,5 +522,18 @@ REGISTER_OP("KMeansCentroidsV2") c->set_output(2, c->MakeShape({1,})); return Status::OK(); }); + +REGISTER_OP("HGEMM") + .Input("panel_data: Ref(float16)") + .Input("l_offset: T") + .Input("u_offset: T") + .Input("b_offset: T") + .Output("output_panel_data: Ref(float16)") + .Attr("T: {int32, int64}") + .Attr("block_size: int = 2048") + .SetShapeFn([](shape_inference::InferenceContext *c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }); } // namespace } // namespace tensorflow diff --git a/tf_adapter/python/npu_bridge/estimator/npu_ops.py b/tf_adapter/python/npu_bridge/estimator/npu_ops.py index 43b326742..e8596817f 100644 --- a/tf_adapter/python/npu_bridge/estimator/npu_ops.py +++ b/tf_adapter/python/npu_bridge/estimator/npu_ops.py @@ -313,3 +313,33 @@ def k_means_centroids(x, y, sum_square_y, sum_square_x, use_actual_distance=Fals def npu_onnx_graph_op(inputs, tout, model_path, name=None): output = gen_npu_ops.npu_onnx_graph_op(inputs, tout, model_path, name) return output + +def h_gemm(panel_data, l_offset, u_offset, b_offset, block_size=2048, name=None): + """h_gemm. + + Args: + panel_data: a variable tensor with type float16. + l_offset: a tensor with type int32 or int64. + u_offset: a tensor with type int32 or int64. 
+ b_offset: a tensor with type int32 or int64. + block_size: block size of lower or upper triangular matrix with type int. + name: operator name. + + Returns: + A tensor. + """ + if context.executing_eagerly(): + raise RuntimeError("npu_ops.h_gemm() is not compatible with " + "eager execution.") + + l_offset = ops.convert_to_tensor(l_offset, name="l_offset") + b_offset = ops.convert_to_tensor(b_offset, name="b_offset") + u_offset = ops.convert_to_tensor(u_offset, name="u_offset") + + if not panel_data.dtype._is_ref_dtype: + raise TypeError("'HGEMM' op requires that input 'panel_data' " + "be a mutable tensor (e.g.: a tf.Variable)") + + result = gen_npu_ops.HGEMM(panel_data=panel_data, l_offset=l_offset, u_offset=u_offset, + b_offset=b_offset, block_size=block_size, name=name) + return result diff --git a/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc b/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc new file mode 100644 index 000000000..605358ef7 --- /dev/null +++ b/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc @@ -0,0 +1,70 @@ +#include +#include "tf_adapter/kernels/h_gemm.cc" +#include "gtest/gtest.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +PartialTensorShape TShape(std::initializer_list dims) { + return PartialTensorShape(dims); +} + +FakeInputFunctor FakeInputStub(DataType dt) { + return [dt](const OpDef& op_def, int in_index, const NodeDef& node_def, + NodeDefBuilder* builder) { + char c = 'a' + (in_index % 26); + string in_node = string(&c, 1); + builder->Input(in_node, 0, dt); + return Status::OK(); + }; +} + +TEST(HGEMMOpTest, 
TestHGEMM) { + DataTypeSlice input_types({DT_FLOAT16, DT_INT32, DT_INT32, DT_INT32}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({DT_FLOAT16}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, + input_types, input_memory_types, output_types, output_memory_types, + 1, nullptr); + HGEMMOp h_gemm(context); + OpKernelContext *ctx = nullptr; + h_gemm.Compute(ctx); + h_gemm.IsExpensive(); + delete device; + delete node_def; + delete op_def; + delete context; +} + +TEST(HGEMMOpTest, TestHGEMMShapeInference) { + const OpRegistrationData* reg; + TF_CHECK_OK(OpRegistry::Global()->LookUp("HGEMM", &reg)); + OpDef op_def = reg->op_def; + NodeDef def; + TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) + .Attr("T", DT_INT32) + .Attr("block_size", 2048) + .Input(FakeInputStub(DT_FLOAT16)) + .Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_INT32)) + .Finalize(&def)); + shape_inference::InferenceContext c(0, &def, op_def,{TShape({4096, 4096}), TShape({1,}), TShape({1,}), TShape({1,})}, {}); + std::vector input_shapes; + TF_CHECK_OK(reg->shape_inference_fn(&c)); + ASSERT_EQ("[4096,4096]", c.DebugString(c.output(0))); +} +} // namespace +} // namespace tensorflow \ No newline at end of file -- Gitee From d4d01b553f1b613dae6ad1c3cdca93c956aec558 Mon Sep 17 00:00:00 2001 From: stormchasingg Date: Mon, 22 Nov 2021 16:50:19 +0800 Subject: [PATCH 02/12] fix some bug --- tf_adapter/kernels/h_gemm.cc | 2 +- tf_adapter/python/npu_bridge/estimator/npu_ops.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tf_adapter/kernels/h_gemm.cc b/tf_adapter/kernels/h_gemm.cc index ab28c08a8..f15b6fa8a 100644 --- a/tf_adapter/kernels/h_gemm.cc +++ b/tf_adapter/kernels/h_gemm.cc @@ -1,5 +1,5 @@ /* - * 
Copyright (c) Huawei Technologies Co., Ltd. 2019-2020. All rights reserved. + * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/tf_adapter/python/npu_bridge/estimator/npu_ops.py b/tf_adapter/python/npu_bridge/estimator/npu_ops.py index e8596817f..d0e92255d 100644 --- a/tf_adapter/python/npu_bridge/estimator/npu_ops.py +++ b/tf_adapter/python/npu_bridge/estimator/npu_ops.py @@ -333,11 +333,11 @@ def h_gemm(panel_data, l_offset, u_offset, b_offset, block_size=2048, name=None) "eager execution.") l_offset = ops.convert_to_tensor(l_offset, name="l_offset") - b_offset = ops.convert_to_tensor(b_offset, name="b_offset") u_offset = ops.convert_to_tensor(u_offset, name="u_offset") + b_offset = ops.convert_to_tensor(b_offset, name="b_offset") if not panel_data.dtype._is_ref_dtype: - raise TypeError("'HGEMM' op requires that input 'panel_data' " + raise TypeError("'HGEMM' Op requires that input 'panel_data' " "be a mutable tensor (e.g.: a tf.Variable)") result = gen_npu_ops.HGEMM(panel_data=panel_data, l_offset=l_offset, u_offset=u_offset, -- Gitee From be2e5e3d6ce61bbe89455ab0269ab5dc8c22d677 Mon Sep 17 00:00:00 2001 From: stormchasingg Date: Mon, 22 Nov 2021 17:21:00 +0800 Subject: [PATCH 03/12] fake code smell --- tf_adapter/kernels/h_gemm.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tf_adapter/kernels/h_gemm.cc b/tf_adapter/kernels/h_gemm.cc index f15b6fa8a..6f153be4f 100644 --- a/tf_adapter/kernels/h_gemm.cc +++ b/tf_adapter/kernels/h_gemm.cc @@ -19,7 +19,7 @@ namespace tensorflow { class HGEMMOp : public OpKernel { - public: +public: explicit HGEMMOp(OpKernelConstruction *context) : OpKernel(context) {} ~HGEMMOp() override = default; void Compute(OpKernelContext *context) override { -- Gitee From 6c7e3bc4d9f2a4ac352b0df47233dbca126628d4 Mon Sep 17 00:00:00 2001 From: 
stormchasingg Date: Mon, 22 Nov 2021 22:03:19 +0800 Subject: [PATCH 04/12] fix some bug --- tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc b/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc index 605358ef7..bbf44a578 100644 --- a/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc +++ b/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc @@ -28,9 +28,9 @@ FakeInputFunctor FakeInputStub(DataType dt) { } TEST(HGEMMOpTest, TestHGEMM) { - DataTypeSlice input_types({DT_FLOAT16, DT_INT32, DT_INT32, DT_INT32}); + DataTypeSlice input_types({DT_HALF, DT_INT32, DT_INT32, DT_INT32}); MemoryTypeSlice input_memory_types; - DataTypeSlice output_types({DT_FLOAT16}); + DataTypeSlice output_types({DT_HALF}); MemoryTypeSlice output_memory_types; DeviceBase *device = new DeviceBase(Env::Default()); NodeDef *node_def = new NodeDef(); @@ -56,7 +56,7 @@ TEST(HGEMMOpTest, TestHGEMMShapeInference) { TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) .Attr("T", DT_INT32) .Attr("block_size", 2048) - .Input(FakeInputStub(DT_FLOAT16)) + .Input(FakeInputStub(DT_HALF)) .Input(FakeInputStub(DT_INT32)) .Input(FakeInputStub(DT_INT32)) .Input(FakeInputStub(DT_INT32)) -- Gitee From 3d9dc63ad794b80e5ada4f46485d9b768b6aae73 Mon Sep 17 00:00:00 2001 From: stormchasingg Date: Tue, 23 Nov 2021 09:37:25 +0800 Subject: [PATCH 05/12] fix some bug --- tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc b/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc index bbf44a578..51a073e8a 100644 --- a/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc +++ b/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc @@ -61,7 +61,7 @@ TEST(HGEMMOpTest, TestHGEMMShapeInference) { .Input(FakeInputStub(DT_INT32)) .Input(FakeInputStub(DT_INT32)) .Finalize(&def)); - 
shape_inference::InferenceContext c(0, &def, op_def,{TShape({4096, 4096}), TShape({1,}), TShape({1,}), TShape({1,})}, {}); + shape_inference::InferenceContext c(0, &def, op_def, {TShape({4096, 4096}), TShape({1,}), TShape({1,}), TShape({1,})}, {}, {}, {}); std::vector input_shapes; TF_CHECK_OK(reg->shape_inference_fn(&c)); ASSERT_EQ("[4096,4096]", c.DebugString(c.output(0))); -- Gitee From b569262334a5f7f01d56d6313e6fefa8288247cf Mon Sep 17 00:00:00 2001 From: stormchasingg Date: Tue, 23 Nov 2021 11:17:34 +0800 Subject: [PATCH 06/12] try to fix half_ref --- tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc b/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc index 51a073e8a..f963b8e1a 100644 --- a/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc +++ b/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc @@ -6,6 +6,7 @@ #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/platform/test.h" @@ -28,9 +29,9 @@ FakeInputFunctor FakeInputStub(DataType dt) { } TEST(HGEMMOpTest, TestHGEMM) { - DataTypeSlice input_types({DT_HALF, DT_INT32, DT_INT32, DT_INT32}); + DataTypeSlice input_types({Ref(DT_HALF), DT_INT32, DT_INT32, DT_INT32}); MemoryTypeSlice input_memory_types; - DataTypeSlice output_types({DT_HALF}); + DataTypeSlice output_types({Ref(DT_HALF)}); MemoryTypeSlice output_memory_types; DeviceBase *device = new DeviceBase(Env::Default()); NodeDef *node_def = new NodeDef(); @@ -56,7 +57,7 @@ TEST(HGEMMOpTest, TestHGEMMShapeInference) { TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) .Attr("T", DT_INT32) .Attr("block_size", 2048) - .Input(FakeInputStub(DT_HALF)) + 
.Input(FakeInputStub(Ref(DT_HALF))) .Input(FakeInputStub(DT_INT32)) .Input(FakeInputStub(DT_INT32)) .Input(FakeInputStub(DT_INT32)) -- Gitee From a6e2450e0b3d39e0217498e938fa8e5cb0964e21 Mon Sep 17 00:00:00 2001 From: stormchasingg Date: Tue, 23 Nov 2021 11:29:37 +0800 Subject: [PATCH 07/12] try to fix half_ref --- tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc b/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc index f963b8e1a..a72f09374 100644 --- a/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc +++ b/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc @@ -29,9 +29,9 @@ FakeInputFunctor FakeInputStub(DataType dt) { } TEST(HGEMMOpTest, TestHGEMM) { - DataTypeSlice input_types({Ref(DT_HALF), DT_INT32, DT_INT32, DT_INT32}); + DataTypeSlice input_types({ref(DT_HALF), DT_INT32, DT_INT32, DT_INT32}); MemoryTypeSlice input_memory_types; - DataTypeSlice output_types({Ref(DT_HALF)}); + DataTypeSlice output_types({ref(DT_HALF)}); MemoryTypeSlice output_memory_types; DeviceBase *device = new DeviceBase(Env::Default()); NodeDef *node_def = new NodeDef(); @@ -57,7 +57,7 @@ TEST(HGEMMOpTest, TestHGEMMShapeInference) { TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) .Attr("T", DT_INT32) .Attr("block_size", 2048) - .Input(FakeInputStub(Ref(DT_HALF))) + .Input(FakeInputStub(ref(DT_HALF))) .Input(FakeInputStub(DT_INT32)) .Input(FakeInputStub(DT_INT32)) .Input(FakeInputStub(DT_INT32)) -- Gitee From e7e6dd8bb0cb3bdc277448cf98f5fc957f5eb84a Mon Sep 17 00:00:00 2001 From: stormchasingg Date: Tue, 23 Nov 2021 14:06:51 +0800 Subject: [PATCH 08/12] fix bug half_ref --- tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc b/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc index a72f09374..9dc75dcab 100644 --- 
a/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc +++ b/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc @@ -6,9 +6,9 @@ #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -29,9 +29,9 @@ FakeInputFunctor FakeInputStub(DataType dt) { } TEST(HGEMMOpTest, TestHGEMM) { - DataTypeSlice input_types({ref(DT_HALF), DT_INT32, DT_INT32, DT_INT32}); + DataTypeSlice input_types({MakeRefType(DT_HALF), DT_INT32, DT_INT32, DT_INT32}); MemoryTypeSlice input_memory_types; - DataTypeSlice output_types({ref(DT_HALF)}); + DataTypeSlice output_types({MakeRefType(DT_HALF)}); MemoryTypeSlice output_memory_types; DeviceBase *device = new DeviceBase(Env::Default()); NodeDef *node_def = new NodeDef(); @@ -57,7 +57,7 @@ TEST(HGEMMOpTest, TestHGEMMShapeInference) { TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) .Attr("T", DT_INT32) .Attr("block_size", 2048) - .Input(FakeInputStub(ref(DT_HALF))) + .Input(FakeInputStub(MakeRefType(DT_HALF))) .Input(FakeInputStub(DT_INT32)) .Input(FakeInputStub(DT_INT32)) .Input(FakeInputStub(DT_INT32)) -- Gitee From d0114cff445045372e0d615b43e20f3d11eb23cb Mon Sep 17 00:00:00 2001 From: stormchasingg Date: Tue, 23 Nov 2021 15:15:58 +0800 Subject: [PATCH 09/12] st coverage debug --- tf_adapter/python/npu_bridge/estimator/npu_ops.py | 1 + tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc | 2 ++ 2 files changed, 3 insertions(+) diff --git a/tf_adapter/python/npu_bridge/estimator/npu_ops.py b/tf_adapter/python/npu_bridge/estimator/npu_ops.py index d0e92255d..87b35dfdc 100644 --- a/tf_adapter/python/npu_bridge/estimator/npu_ops.py +++ b/tf_adapter/python/npu_bridge/estimator/npu_ops.py @@ -314,6 
+314,7 @@ def npu_onnx_graph_op(inputs, tout, model_path, name=None): output = gen_npu_ops.npu_onnx_graph_op(inputs, tout, model_path, name) return output + def h_gemm(panel_data, l_offset, u_offset, b_offset, block_size=2048, name=None): """h_gemm. diff --git a/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc b/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc index 9dc75dcab..641583b39 100644 --- a/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc +++ b/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc @@ -29,6 +29,7 @@ FakeInputFunctor FakeInputStub(DataType dt) { } TEST(HGEMMOpTest, TestHGEMM) { + std::cout << "[DEBUG dsh] TestHGEMM" << std::endl; DataTypeSlice input_types({MakeRefType(DT_HALF), DT_INT32, DT_INT32, DT_INT32}); MemoryTypeSlice input_memory_types; DataTypeSlice output_types({MakeRefType(DT_HALF)}); @@ -50,6 +51,7 @@ TEST(HGEMMOpTest, TestHGEMM) { } TEST(HGEMMOpTest, TestHGEMMShapeInference) { + std::cout << "[DEBUG dsh] TestHGEMMShapeInference" << std::endl; const OpRegistrationData* reg; TF_CHECK_OK(OpRegistry::Global()->LookUp("HGEMM", &reg)); OpDef op_def = reg->op_def; -- Gitee From 90ec83678e82dbf17dded15e3cb3cd2ff42fcb9f Mon Sep 17 00:00:00 2001 From: stormchasingg Date: Tue, 23 Nov 2021 20:31:38 +0800 Subject: [PATCH 10/12] add st testcase --- .../tests/st/kernels/testcase/h_gemm_test.cc | 75 +++++++++++++++++++ .../tests/ut/kernels/testcase/h_gemm_test.cc | 2 + 2 files changed, 77 insertions(+) create mode 100644 tf_adapter/tests/st/kernels/testcase/h_gemm_test.cc diff --git a/tf_adapter/tests/st/kernels/testcase/h_gemm_test.cc b/tf_adapter/tests/st/kernels/testcase/h_gemm_test.cc new file mode 100644 index 000000000..056737c5d --- /dev/null +++ b/tf_adapter/tests/st/kernels/testcase/h_gemm_test.cc @@ -0,0 +1,75 @@ +#include +#include "tf_adapter/kernels/h_gemm.cc" +#include "gtest/gtest.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include 
"tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +PartialTensorShape TShape(std::initializer_list dims) { + return PartialTensorShape(dims); +} + +FakeInputFunctor FakeInputStub(DataType dt) { + return [dt](const OpDef& op_def, int in_index, const NodeDef& node_def, + NodeDefBuilder* builder) { + char c = 'a' + (in_index % 26); + string in_node = string(&c, 1); + builder->Input(in_node, 0, dt); + return Status::OK(); + }; +} + +TEST(HGEMMOpTest, TestHGEMM) { + std::cout << "[DEBUG dsh] TestHGEMM" << std::endl; + DataTypeSlice input_types({MakeRefType(DT_HALF), DT_INT32, DT_INT32, DT_INT32}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({MakeRefType(DT_HALF)}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, + input_types, input_memory_types, output_types, output_memory_types, + 1, nullptr); + HGEMMOp h_gemm(context); + OpKernelContext *ctx = nullptr; + h_gemm.Compute(ctx); + h_gemm.IsExpensive(); + delete device; + delete node_def; + delete op_def; + delete context; + std::cout << "[DEBUG dsh] TestHGEMM END" << std::endl; +} + +TEST(HGEMMOpTest, TestHGEMMShapeInference) { + std::cout << "[DEBUG dsh] TestHGEMMShapeInference" << std::endl; + const OpRegistrationData* reg; + TF_CHECK_OK(OpRegistry::Global()->LookUp("HGEMM", &reg)); + OpDef op_def = reg->op_def; + NodeDef def; + TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) + .Attr("T", DT_INT32) + .Attr("block_size", 2048) + 
.Input(FakeInputStub(MakeRefType(DT_HALF))) + .Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_INT32)) + .Finalize(&def)); + shape_inference::InferenceContext c(0, &def, op_def, {TShape({4096, 4096}), TShape({1,}), TShape({1,}), TShape({1,})}, {}, {}, {}); + std::vector input_shapes; + TF_CHECK_OK(reg->shape_inference_fn(&c)); + ASSERT_EQ("[4096,4096]", c.DebugString(c.output(0))); + std::cout << "[DEBUG dsh] TestHGEMMShapeInference END" << std::endl; +} +} // namespace +} // namespace tensorflow \ No newline at end of file diff --git a/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc b/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc index 641583b39..056737c5d 100644 --- a/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc +++ b/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc @@ -48,6 +48,7 @@ TEST(HGEMMOpTest, TestHGEMM) { delete node_def; delete op_def; delete context; + std::cout << "[DEBUG dsh] TestHGEMM END" << std::endl; } TEST(HGEMMOpTest, TestHGEMMShapeInference) { @@ -68,6 +69,7 @@ TEST(HGEMMOpTest, TestHGEMMShapeInference) { std::vector input_shapes; TF_CHECK_OK(reg->shape_inference_fn(&c)); ASSERT_EQ("[4096,4096]", c.DebugString(c.output(0))); + std::cout << "[DEBUG dsh] TestHGEMMShapeInference END" << std::endl; } } // namespace } // namespace tensorflow \ No newline at end of file -- Gitee From 909b6daebb43cdcdc579b87ca170c314b28718bf Mon Sep 17 00:00:00 2001 From: stormchasingg Date: Wed, 24 Nov 2021 10:12:12 +0800 Subject: [PATCH 11/12] delete std::cout --- tf_adapter/tests/st/kernels/testcase/h_gemm_test.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tf_adapter/tests/st/kernels/testcase/h_gemm_test.cc b/tf_adapter/tests/st/kernels/testcase/h_gemm_test.cc index 056737c5d..9dc75dcab 100644 --- a/tf_adapter/tests/st/kernels/testcase/h_gemm_test.cc +++ b/tf_adapter/tests/st/kernels/testcase/h_gemm_test.cc @@ -29,7 +29,6 @@ FakeInputFunctor FakeInputStub(DataType dt) { } TEST(HGEMMOpTest, 
TestHGEMM) { - std::cout << "[DEBUG dsh] TestHGEMM" << std::endl; DataTypeSlice input_types({MakeRefType(DT_HALF), DT_INT32, DT_INT32, DT_INT32}); MemoryTypeSlice input_memory_types; DataTypeSlice output_types({MakeRefType(DT_HALF)}); @@ -48,11 +47,9 @@ TEST(HGEMMOpTest, TestHGEMM) { delete node_def; delete op_def; delete context; - std::cout << "[DEBUG dsh] TestHGEMM END" << std::endl; } TEST(HGEMMOpTest, TestHGEMMShapeInference) { - std::cout << "[DEBUG dsh] TestHGEMMShapeInference" << std::endl; const OpRegistrationData* reg; TF_CHECK_OK(OpRegistry::Global()->LookUp("HGEMM", &reg)); OpDef op_def = reg->op_def; @@ -69,7 +66,6 @@ TEST(HGEMMOpTest, TestHGEMMShapeInference) { std::vector input_shapes; TF_CHECK_OK(reg->shape_inference_fn(&c)); ASSERT_EQ("[4096,4096]", c.DebugString(c.output(0))); - std::cout << "[DEBUG dsh] TestHGEMMShapeInference END" << std::endl; } } // namespace } // namespace tensorflow \ No newline at end of file -- Gitee From 9368c3666cc9ac324fd12b71bcb454c328296c47 Mon Sep 17 00:00:00 2001 From: stormchasingg Date: Wed, 24 Nov 2021 10:17:04 +0800 Subject: [PATCH 12/12] delete std::cout --- tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc b/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc index 056737c5d..9dc75dcab 100644 --- a/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc +++ b/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc @@ -29,7 +29,6 @@ FakeInputFunctor FakeInputStub(DataType dt) { } TEST(HGEMMOpTest, TestHGEMM) { - std::cout << "[DEBUG dsh] TestHGEMM" << std::endl; DataTypeSlice input_types({MakeRefType(DT_HALF), DT_INT32, DT_INT32, DT_INT32}); MemoryTypeSlice input_memory_types; DataTypeSlice output_types({MakeRefType(DT_HALF)}); @@ -48,11 +47,9 @@ TEST(HGEMMOpTest, TestHGEMM) { delete node_def; delete op_def; delete context; - std::cout << "[DEBUG dsh] TestHGEMM END" << std::endl; } TEST(HGEMMOpTest, 
TestHGEMMShapeInference) { - std::cout << "[DEBUG dsh] TestHGEMMShapeInference" << std::endl; const OpRegistrationData* reg; TF_CHECK_OK(OpRegistry::Global()->LookUp("HGEMM", &reg)); OpDef op_def = reg->op_def; @@ -69,7 +66,6 @@ TEST(HGEMMOpTest, TestHGEMMShapeInference) { std::vector input_shapes; TF_CHECK_OK(reg->shape_inference_fn(&c)); ASSERT_EQ("[4096,4096]", c.DebugString(c.output(0))); - std::cout << "[DEBUG dsh] TestHGEMMShapeInference END" << std::endl; } } // namespace } // namespace tensorflow \ No newline at end of file -- Gitee