diff --git a/tf_adapter/kernels/h_gemm.cc b/tf_adapter/kernels/h_gemm.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6f153be4f7545add444544f94c0d8d2cbfeeac5e
--- /dev/null
+++ b/tf_adapter/kernels/h_gemm.cc
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2019-2021. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tf_adapter/common/adp_logger.h"
+
+namespace tensorflow {
+class HGEMMOp : public OpKernel {
+public:
+  explicit HGEMMOp(OpKernelConstruction *context) : OpKernel(context) {}
+  ~HGEMMOp() override = default;
+  void Compute(OpKernelContext *context) override {
+    ADP_LOG(INFO) << "HGEMMOp Compute ";
+  }
+  bool IsExpensive() override { return false; }
+};
+
+REGISTER_KERNEL_BUILDER(Name("HGEMM").Device(DEVICE_CPU), HGEMMOp);
+} // namespace tensorflow
diff --git a/tf_adapter/ops/npu_ops.cc b/tf_adapter/ops/npu_ops.cc
index 1da33c9914dc33d001512ed85eddb8bcb81bebcb..76d920d4ca4c572cba818d4985ce27f9dd050cfa 100644
--- a/tf_adapter/ops/npu_ops.cc
+++ b/tf_adapter/ops/npu_ops.cc
@@ -522,5 +522,18 @@ REGISTER_OP("KMeansCentroidsV2")
       c->set_output(2, c->MakeShape({1,}));
       return Status::OK();
     });
+
+REGISTER_OP("HGEMM")
+    .Input("panel_data: Ref(float16)")
+    .Input("l_offset: T")
+    .Input("u_offset: T")
+    .Input("b_offset: T")
+    .Output("output_panel_data: Ref(float16)")
+    .Attr("T: {int32, int64}")
+    .Attr("block_size: int = 2048")
+    .SetShapeFn([](shape_inference::InferenceContext *c) {
+      c->set_output(0, c->input(0));
+      return Status::OK();
+    });
 } // namespace
 } // namespace tensorflow
diff --git a/tf_adapter/python/npu_bridge/estimator/npu_ops.py b/tf_adapter/python/npu_bridge/estimator/npu_ops.py
index 43b32674283f72faea7b003de769aa66ad3d23bb..87b35dfdce70b8edc8c6651833908ea13324e346 100644
--- a/tf_adapter/python/npu_bridge/estimator/npu_ops.py
+++ b/tf_adapter/python/npu_bridge/estimator/npu_ops.py
@@ -313,3 +313,34 @@ def k_means_centroids(x, y, sum_square_y, sum_square_x, use_actual_distance=Fals
 def npu_onnx_graph_op(inputs, tout, model_path, name=None):
     output = gen_npu_ops.npu_onnx_graph_op(inputs, tout, model_path, name)
     return output
+
+
+def h_gemm(panel_data, l_offset, u_offset, b_offset, block_size=2048, name=None):
+    """Half-precision general matrix multiplication (HGEMM) on a mutable float16 panel tensor.
+
+    Args:
+        panel_data: a variable tensor with type float16.
+        l_offset: a tensor with type int32 or int64.
+        u_offset: a tensor with type int32 or int64.
+        b_offset: a tensor with type int32 or int64.
+        block_size: block size of the lower or upper triangular matrix, with type int (default 2048).
+        name: operator name.
+
+    Returns:
+        A tensor.
+ """ + if context.executing_eagerly(): + raise RuntimeError("npu_ops.h_gemm() is not compatible with " + "eager execution.") + + l_offset = ops.convert_to_tensor(l_offset, name="l_offset") + u_offset = ops.convert_to_tensor(u_offset, name="u_offset") + b_offset = ops.convert_to_tensor(b_offset, name="b_offset") + + if not panel_data.dtype._is_ref_dtype: + raise TypeError("'HGEMM' Op requires that input 'panel_data' " + "be a mutable tensor (e.g.: a tf.Variable)") + + result = gen_npu_ops.HGEMM(panel_data=panel_data, l_offset=l_offset, u_offset=u_offset, + b_offset=b_offset, block_size=block_size, name=name) + return result diff --git a/tf_adapter/tests/st/kernels/testcase/h_gemm_test.cc b/tf_adapter/tests/st/kernels/testcase/h_gemm_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9dc75dcab9212f7fd2ad0069e2f1a578c8c7b1bd --- /dev/null +++ b/tf_adapter/tests/st/kernels/testcase/h_gemm_test.cc @@ -0,0 +1,71 @@ +#include +#include "tf_adapter/kernels/h_gemm.cc" +#include "gtest/gtest.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +PartialTensorShape TShape(std::initializer_list dims) { + return PartialTensorShape(dims); +} + +FakeInputFunctor FakeInputStub(DataType dt) { + return [dt](const OpDef& op_def, int in_index, const NodeDef& node_def, + NodeDefBuilder* builder) { + char c = 'a' + (in_index % 26); + string in_node = string(&c, 1); + builder->Input(in_node, 0, dt); + return Status::OK(); + }; +} + +TEST(HGEMMOpTest, TestHGEMM) { + DataTypeSlice input_types({MakeRefType(DT_HALF), DT_INT32, DT_INT32, DT_INT32}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({MakeRefType(DT_HALF)}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, + input_types, input_memory_types, output_types, output_memory_types, + 1, nullptr); + HGEMMOp h_gemm(context); + OpKernelContext *ctx = nullptr; + h_gemm.Compute(ctx); + h_gemm.IsExpensive(); + delete device; + delete node_def; + delete op_def; + delete context; +} + +TEST(HGEMMOpTest, TestHGEMMShapeInference) { + const OpRegistrationData* reg; + TF_CHECK_OK(OpRegistry::Global()->LookUp("HGEMM", ®)); + OpDef op_def = reg->op_def; + NodeDef def; + TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) + .Attr("T", DT_INT32) + .Attr("block_size", 2048) + .Input(FakeInputStub(MakeRefType(DT_HALF))) + .Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_INT32)) + .Finalize(&def)); + shape_inference::InferenceContext c(0, &def, op_def, {TShape({4096, 4096}), TShape({1,}), TShape({1,}), TShape({1,})}, {}, {}, {}); + std::vector input_shapes; + TF_CHECK_OK(reg->shape_inference_fn(&c)); + ASSERT_EQ("[4096,4096]", c.DebugString(c.output(0))); +} +} // namespace +} // namespace tensorflow \ No newline at end of file diff --git a/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc 
new file mode 100644
index 0000000000000000000000000000000000000000..9dc75dcab9212f7fd2ad0069e2f1a578c8c7b1bd
--- /dev/null
+++ b/tf_adapter/tests/ut/kernels/testcase/h_gemm_test.cc
@@ -0,0 +1,71 @@
+#include <stdlib.h>
+#include "tf_adapter/kernels/h_gemm.cc"
+#include "gtest/gtest.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+PartialTensorShape TShape(std::initializer_list<int64> dims) {
+  return PartialTensorShape(dims);
+}
+
+FakeInputFunctor FakeInputStub(DataType dt) {
+  return [dt](const OpDef& op_def, int in_index, const NodeDef& node_def,
+              NodeDefBuilder* builder) {
+    char c = 'a' + (in_index % 26);
+    string in_node = string(&c, 1);
+    builder->Input(in_node, 0, dt);
+    return Status::OK();
+  };
+}
+
+TEST(HGEMMOpTest, TestHGEMM) {
+  DataTypeSlice input_types({MakeRefType(DT_HALF), DT_INT32, DT_INT32, DT_INT32});
+  MemoryTypeSlice input_memory_types;
+  DataTypeSlice output_types({MakeRefType(DT_HALF)});
+  MemoryTypeSlice output_memory_types;
+  DeviceBase *device = new DeviceBase(Env::Default());
+  NodeDef *node_def = new NodeDef();
+  OpDef *op_def = new OpDef();
+  OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr,
+                                                           input_types, input_memory_types, output_types, output_memory_types,
+                                                           1, nullptr);
+  HGEMMOp h_gemm(context);
+  OpKernelContext *ctx = nullptr;
+  h_gemm.Compute(ctx);
+  h_gemm.IsExpensive();
+  delete device;
+  delete node_def;
+  delete op_def;
+  delete context;
+}
+
+TEST(HGEMMOpTest, TestHGEMMShapeInference) {
+  const OpRegistrationData* reg;
+  TF_CHECK_OK(OpRegistry::Global()->LookUp("HGEMM", &reg));
+  OpDef op_def = reg->op_def;
+  NodeDef def;
+  TF_CHECK_OK(NodeDefBuilder("dummy", &op_def)
+                  .Attr("T", DT_INT32)
+                  .Attr("block_size", 2048)
+                  .Input(FakeInputStub(MakeRefType(DT_HALF)))
+                  .Input(FakeInputStub(DT_INT32))
+                  .Input(FakeInputStub(DT_INT32))
+                  .Input(FakeInputStub(DT_INT32))
+                  .Finalize(&def));
+  shape_inference::InferenceContext c(0, &def, op_def, {TShape({4096, 4096}), TShape({1,}), TShape({1,}), TShape({1,})}, {}, {}, {});
+  std::vector<shape_inference::ShapeHandle> input_shapes;
+  TF_CHECK_OK(reg->shape_inference_fn(&c));
+  ASSERT_EQ("[4096,4096]", c.DebugString(c.output(0)));
+}
+} // namespace
+} // namespace tensorflow
\ No newline at end of file
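For reference, and not part of the diff itself: a minimal usage sketch of the new `h_gemm` Python wrapper. This is an illustration under assumptions, not code from this change. It assumes TF 1.15 graph mode with the NPU adapter loaded; the offset values and panel layout below are hypothetical, since the op's offset semantics live in the NPU kernel rather than in this patch.

```python
import tensorflow as tf
from npu_bridge.estimator import npu_ops

# Hypothetical setup: a 4096x4096 float16 panel processed in 2048-wide blocks
# (matching the shapes used in the shape-inference test above).
# h_gemm requires a mutable (ref) tensor, so panel_data must be a tf.Variable.
panel = tf.Variable(tf.zeros([4096, 4096], dtype=tf.float16), name="panel")

update = npu_ops.h_gemm(panel_data=panel,
                        l_offset=tf.constant([0], dtype=tf.int32),     # placeholder offsets; real
                        u_offset=tf.constant([2048], dtype=tf.int32),  # values depend on the panel
                        b_offset=tf.constant([2048], dtype=tf.int32),  # being updated
                        block_size=2048)

with tf.Session() as sess:  # graph mode only; h_gemm raises RuntimeError under eager execution
    sess.run(tf.global_variables_initializer())
    sess.run(update)
```

Because the shape function aliases the output to input 0 (`c->set_output(0, c->input(0))`), `update` carries the same `[4096, 4096]` shape as `panel`, so downstream ops can chain on it directly.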