From 0183db70115ea52fc2c60043abc50159b48be6cc Mon Sep 17 00:00:00 2001 From: Codersheepchen Date: Thu, 4 Sep 2025 02:57:52 -0400 Subject: [PATCH] fused matmul --- tensorflow/core/BUILD | 15 + tensorflow/core/kernels/BUILD | 34 + .../core/kernels/kdnn_matmul_op_fused.cc | 150 ++++ .../core/kernels/kdnn_matmul_op_test.cc | 693 ++++++++++++++++++ tensorflow/core/ops/kdnn_ops.cc | 53 ++ third_party/KDNN/kdnn_adapter.h | 57 ++ 6 files changed, 1002 insertions(+) create mode 100644 tensorflow/core/kernels/kdnn_matmul_op_fused.cc create mode 100644 tensorflow/core/kernels/kdnn_matmul_op_test.cc create mode 100644 tensorflow/core/ops/kdnn_ops.cc diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 98ece453..54905cc5 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -150,6 +150,10 @@ load( "//third_party/annc:build_defs.bzl", "if_enable_annc", ) +load( + "//third_party/KDNN:build_defs.bzl", + "if_enable_kdnn", +) # Placeholder for Google-internal load statements. package( @@ -677,6 +681,13 @@ tf_gen_op_libs( deps = [":protos_all_cc"], ) +tf_gen_op_libs( + op_lib_names = [ + "kdnn_ops", + ], + deps = [":protos_all_cc"], +) + tf_gen_op_libs( op_lib_names = [ "string_ops", @@ -921,6 +932,8 @@ cc_library( "//tensorflow/compiler/tf2tensorrt:trt_op_libs", ]) + if_enable_annc([ ":embedding_fused_ops_op_lib", + ]) + if_enable_kdnn([ + ":kdnn_ops_op_lib", ]), alwayslink = 1, ) @@ -1117,6 +1130,8 @@ cc_library( "//tensorflow/compiler/tf2tensorrt:trt_op_kernels", ]) + if_enable_annc([ "//tensorflow/core/kernels:embedding_fused_ops", + ]) + if_enable_kdnn([ + "//tensorflow/core/kernels:kdnn_matmul_op", ]), ) diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index f74da7e8..1f09075f 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -4208,6 +4208,14 @@ tf_kernel_library( copts = if_enable_kdnn(["-DENABLE_KDNN=1"]), ) +tf_kernel_library( + name = "kdnn_matmul_op", + srcs = if_enable_kdnn([ + "kdnn_matmul_op_fused.cc", + ]), + deps = MATH_DEPS + kdnn_deps(), +) + tf_mkl_kernel_library( name = "mkl_matmul_op", srcs = [ @@ -4415,6 +4423,32 @@ tf_cuda_cc_test( ], ) +tf_cuda_cc_test( + name = "kdnn_matmul_op_test", + size = "small", + srcs = if_enable_kdnn([ + "kdnn_matmul_op_test.cc", + ]), + deps = [ + ":ops_testutil", + ":ops_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:client_session", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "@com_google_absl//absl/algorithm:container", + ] + if_enable_kdnn([ + ":kdnn_matmul_op", + ]), +) + tf_cuda_cc_test( name = "batch_matmul_op_test", size = "small", diff --git a/tensorflow/core/kernels/kdnn_matmul_op_fused.cc b/tensorflow/core/kernels/kdnn_matmul_op_fused.cc new file mode 100644 index 00000000..829ce194 --- /dev/null +++ b/tensorflow/core/kernels/kdnn_matmul_op_fused.cc @@ -0,0 +1,150 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Implements matmul operations with other kernels baked into the
+// processing, to optimize latency and memory usage:
+//   - MatMul + BiasAdd
+//   - MatMul + BiasAdd + Relu activation
+//
+// Currently supported only on CPU device.
+
+#ifndef TENSORFLOW_CORE_KERNELS_KDNN_MATMUL_OP_FUSED_H_
+#define TENSORFLOW_CORE_KERNELS_KDNN_MATMUL_OP_FUSED_H_
+
+#define USE_EIGEN_TENSOR
+#define EIGEN_USE_THREADS
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/bounds_check.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/fill_functor.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+#include "kdnn_adapter.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename Device, typename T>
+class KdnnFusedMatMulOp : public OpKernel {
+ public:
+  explicit KdnnFusedMatMulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("fused_ops", &fused_ops_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_));
+    OP_REQUIRES(ctx, fused_ops_.size() <= 2,
+                errors::InvalidArgument(
+                    "KdnnFusedMatMul must have 2 post-arguments at most."));
+    OP_REQUIRES(
+        ctx, fused_ops_[0] == "BiasAdd",
+        errors::InvalidArgument(
+            "The 1st post-argument of KdnnFusedMatMul must be BiasAdd."));
+    OP_REQUIRES(
+        ctx, transpose_a_ == false,
+        errors::InvalidArgument("In[0] of KdnnFusedMatMul can't be transposed."));
+    OP_REQUIRES(
+        ctx, transpose_b_ == false,
+        errors::InvalidArgument("In[1] of KdnnFusedMatMul can't be transposed."));
+  }
+
+  void ExtendKdnnFusedMatMulParams(OpKernelContext* ctx,
+                                   kdnnFusedType& fused_type) {
+    if (fused_ops_.size() == 2) {
+      string post_op = fused_ops_[1];
+      if (post_op == "Relu") {
+        fused_type = kdnnFusedType::FUSED_TYPE_BIASRELU;
+      } else {
+        OP_REQUIRES_OK(
+            ctx, errors::InvalidArgument(
+                     "Unsupported post-argument in KdnnFusedMatMul: ", post_op));
+      }
+    }
+  }
+
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& a = ctx->input(0);
+    const Tensor& b = ctx->input(1);
+    const Tensor& bias = ctx->input(2);
+
+    // Check that the dimensions of the two matrices are valid.
+    OP_REQUIRES(
+        ctx, TensorShapeUtils::IsMatrix(a.shape()),
+        errors::InvalidArgument("In[0] is not a matrix. Instead it has shape ",
+                                a.shape().DebugString()));
+    OP_REQUIRES(
+        ctx, TensorShapeUtils::IsMatrix(b.shape()),
+        errors::InvalidArgument("In[1] is not a matrix. Instead it has shape ",
+                                b.shape().DebugString()));
+    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(bias.shape()),
+                errors::InvalidArgument("Bias must be 1D"));
+
+    OP_REQUIRES(
+        ctx, a.dim_size(1) == b.dim_size(0),
+        errors::InvalidArgument(
+            "Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
+            ", In[1]: ", b.shape().DebugString()));
+    TensorShape out_shape(
+        {a.dim_size(0), b.dim_size(1)});
+    Tensor* out = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
+
+    if (out->NumElements() == 0) {
+      // If a has shape [0, x] or b has shape [x, 0], the output shape
+      // is a 0-element matrix, so there is nothing to do.
+      return;
+    }
+
+    if (a.NumElements() == 0 && b.NumElements() == 0) {
+      // If a has shape [x, 0] and b has shape [0, y], the
+      // output shape is [x, y] where x and y are non-zero, so we fill
+      // the output with zeros.
+      functor::SetZeroFunctor<Device, T> f;
+      f(ctx->eigen_device<Device>(), out->flat<T>());
+      return;
+    }
+
+    kdnnFusedType fused_type = kdnnFusedType::FUSED_TYPE_BIAS;
+    this->ExtendKdnnFusedMatMulParams(ctx, fused_type);
+    kdnnFusedGemm(ctx, a, b, bias, out, false, false, fused_type);
+    return;
+  }
+
+ private:
+  bool transpose_a_;
+  bool transpose_b_;
+
+  std::vector<string> fused_ops_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(KdnnFusedMatMulOp);
+};
+
+// Registration of the CPU implementations.
+#define REGISTER_FUSED_CPU_MATMUL(T)                                      \
+  REGISTER_KERNEL_BUILDER(                                                \
+      Name("_KdnnFusedMatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+      KdnnFusedMatMulOp<CPUDevice, T>);
+
+TF_CALL_float(REGISTER_FUSED_CPU_MATMUL);
+
+#undef REGISTER_FUSED_CPU_MATMUL
+
+}  // namespace tensorflow
+#endif  // TENSORFLOW_CORE_KERNELS_KDNN_MATMUL_OP_FUSED_H_
diff --git a/tensorflow/core/kernels/kdnn_matmul_op_test.cc b/tensorflow/core/kernels/kdnn_matmul_op_test.cc
new file mode 100644
index 00000000..592bab1a
--- /dev/null
+++ b/tensorflow/core/kernels/kdnn_matmul_op_test.cc
@@ -0,0 +1,693 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "absl/algorithm/container.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+
+template <typename T>
+class FusedMatMulOpTest : public OpsTestBase {
+ protected:
+  using BiasAddGraphRunner =
+      std::function<void(const Tensor& lhs_data, const Tensor& rhs_data,
+                         const Tensor& bias_data, Tensor* out)>;
+
+  // Runs a TensorFlow graph defined by the root scope, and fetches the result
+  // of the 'fetch' node into the output Tensor. Optional `fetch_node` parameter
+  // allows defining a fetch node directly using a NodeDef for ops that are
+  // not supported by the C++ API.
+  void RunAndFetch(const tensorflow::Scope& root, const string& fetch,
+                   Tensor* output, bool allow_gpu_device,
+                   const NodeDef* fetch_node = nullptr) {
+    tensorflow::GraphDef graph;
+    TF_ASSERT_OK(root.ToGraphDef(&graph));
+
+    if (fetch_node) {
+      *graph.add_node() = *fetch_node;
+    }
+
+    // We really want to make sure that the graph is executed exactly as we
+    // passed it to the session, so we disable various optimizations.
+    tensorflow::SessionOptions session_options;
+
+    // Disable common runtime constant folding.
+    session_options.config.mutable_graph_options()
+        ->mutable_optimizer_options()
+        ->set_opt_level(OptimizerOptions::L0);
+
+    // Disable Grappler optimizations for tests.
+    tensorflow::RewriterConfig* cfg =
+        session_options.config.mutable_graph_options()
+            ->mutable_rewrite_options();
+    cfg->set_constant_folding(tensorflow::RewriterConfig::OFF);
+    cfg->set_layout_optimizer(tensorflow::RewriterConfig::OFF);
+    cfg->set_remapping(tensorflow::RewriterConfig::OFF);
+
+    std::unique_ptr<tensorflow::Session> session(
+        tensorflow::NewSession(session_options));
+
+    std::vector<DeviceAttributes> available_devices;
+    TF_ASSERT_OK(session->ListDevices(&available_devices))
+        << "Failed to get available session devices";
+
+    // Check if session has an available GPU device.
+    const bool has_gpu_device =
+        absl::c_any_of(available_devices, [](const DeviceAttributes& device) {
+          return device.device_type() == DEVICE_GPU;
+        });
+
+    // If the fused computation is implemented only for CPU, we don't want to
+    // compare GPU vs CPU numbers in this test, so place all nodes on CPU.
+    const bool place_all_on_gpu = allow_gpu_device && has_gpu_device;
+
+    const string device = place_all_on_gpu ? "/device:GPU:0" : "/device:CPU:0";
"/device:GPU:0" : "/device:CPU:0"; + for (NodeDef& mutable_node : *graph.mutable_node()) { + mutable_node.set_device(device); + } + + TF_ASSERT_OK(session->Create(graph)); + + std::vector unfused_tensors; + TF_ASSERT_OK(session->Run({}, {fetch}, {}, &unfused_tensors)); + + *output = unfused_tensors[0]; + } + + void RunMatMulWithBias(const Tensor& lhs_data, const Tensor& rhs_data, + const Tensor& bias_data, bool transpose_a, + bool transpose_b, Tensor* output, + bool allow_gpu_device = false) { + Scope root = tensorflow::Scope::NewRootScope(); + + ops::MatMul matmul = ops::MatMul( + root.WithOpName("matmul"), + ops::Const(root.WithOpName("lhs"), Input::Initializer(lhs_data)), + ops::Const(root.WithOpName("rhs"), Input::Initializer(rhs_data)), + ops::MatMul::Attrs().TransposeA(transpose_a).TransposeB(transpose_b)); + + ops::BiasAdd with_bias = ops::BiasAdd( + root.WithOpName("with_bias"), matmul, + ops::Const(root.WithOpName("bias"), Input::Initializer(bias_data))); + + RunAndFetch(root, "with_bias", output, allow_gpu_device); + } + + void RunMatMulWithBiasAndActivation( + const Tensor& lhs_data, const Tensor& rhs_data, const Tensor& bias_data, + bool transpose_a, bool transpose_b, const string& activation_type, + Tensor* output, bool allow_gpu_device = false) { + Scope root = tensorflow::Scope::NewRootScope(); + + ops::MatMul matmul = ops::MatMul( + root.WithOpName("matmul"), + ops::Const(root.WithOpName("lhs"), Input::Initializer(lhs_data)), + ops::Const(root.WithOpName("rhs"), Input::Initializer(rhs_data)), + ops::MatMul::Attrs().TransposeA(transpose_a).TransposeB(transpose_b)); + + ops::BiasAdd with_bias = ops::BiasAdd( + root.WithOpName("with_bias"), matmul, + ops::Const(root.WithOpName("bias"), Input::Initializer(bias_data))); + + if (activation_type == "Relu") { + ops::Relu(root.WithOpName("with_activation"), with_bias); + } else { + ops::Identity(root.WithOpName("with_activation"), with_bias); + } + + RunAndFetch(root, "with_activation", output, allow_gpu_device); + } + + void RunFusedMatMulOp(const Tensor& lhs_data, const Tensor& rhs_data, + const std::vector& args_data, + const std::vector& fused_ops, bool transpose_a, + bool transpose_b, Tensor* output, + bool allow_gpu_device = false) { + Scope root = tensorflow::Scope::NewRootScope(); + + DataType dtype = DataTypeToEnum::v(); + int num_args = static_cast(args_data.size()); + + Output a = + ops::Const(root.WithOpName("a"), Input::Initializer(lhs_data)); + Output b = + ops::Const(root.WithOpName("b"), Input::Initializer(rhs_data)); + + std::vector args; + for (int i = 0; i < num_args; ++i) { + Output arg = ops::Const(root.WithOpName(absl::StrCat("arg", i)), + Input::Initializer(args_data[i])); + args.emplace_back(arg.name(), 0, dtype); + } + + NodeDef kdnn_fused_matmul; + TF_EXPECT_OK(NodeDefBuilder("kdnn_fused_matmul", "_KdnnFusedMatMul") + .Input({a.name(), 0, dtype}) + .Input({b.name(), 0, dtype}) + .Input(args) + .Attr("num_args", num_args) + .Attr("T", dtype) + .Attr("fused_ops", fused_ops) + .Attr("transpose_a", transpose_a) + .Attr("transpose_b", transpose_b) + .Finalize(&kdnn_fused_matmul)); + + RunAndFetch(root, kdnn_fused_matmul.name(), output, allow_gpu_device, + &kdnn_fused_matmul); + } + + void VerifyBiasAddTensorsNear(int m, int k, int n, + const BiasAddGraphRunner& run_default, + const BiasAddGraphRunner& run_fused) { + DataType dtype = DataTypeToEnum::v(); + + Tensor lhs(dtype, {m, k}); + lhs.flat() = lhs.flat().setRandom(); + + // Add some negative values to filter to properly test Relu. 
+    Tensor rhs(dtype, {k, n});
+    rhs.flat<T>() = rhs.flat<T>().setRandom();
+    rhs.flat<T>() -= rhs.flat<T>().constant(static_cast<T>(0.5f));
+
+    // Bias added to the inner dimension.
+    const int bias_size = n;
+    Tensor bias(dtype, {bias_size});
+    bias.flat<T>() = bias.flat<T>().setRandom();
+    bias.flat<T>() += bias.flat<T>().constant(static_cast<T>(0.5f));
+
+    Tensor matmul;
+    Tensor fused_matmul;
+
+    run_default(lhs, rhs, bias, &matmul);
+    run_fused(lhs, rhs, bias, &fused_matmul);
+
+    ASSERT_EQ(matmul.dtype(), fused_matmul.dtype());
+    ASSERT_EQ(matmul.shape(), fused_matmul.shape());
+
+    test::ExpectClose(matmul, fused_matmul, /*atol=*/1e-5);
+  }
+
+  // Verifies that computing MatMul+BiasAdd in a graph is identical to
+  // FusedMatMul.
+  void VerifyMatMulWithBias(int m, int k, int n, bool transpose_a,
+                            bool transpose_b) {
+    const BiasAddGraphRunner run_default =
+        [&](const Tensor& input_data, const Tensor& filter_data,
+            const Tensor& bias_data, Tensor* out) {
+          RunMatMulWithBias(input_data, filter_data, bias_data, transpose_a,
+                            transpose_b, out);
+        };
+
+    const BiasAddGraphRunner run_fused =
+        [&](const Tensor& input_data, const Tensor& filter_data,
+            const Tensor& bias_data, Tensor* out) {
+          RunFusedMatMulOp(input_data, filter_data, {bias_data}, {"BiasAdd"},
+                           transpose_a, transpose_b, out);
+        };
+
+    VerifyBiasAddTensorsNear(m, k, n, run_default, run_fused);
+  }
+
+  // Verifies that computing MatMul+BiasAdd+{Activation} in a graph is
+  // identical to FusedMatMul.
+  void VerifyMatmulWithBiasAndActivation(int m, int k, int n, bool transpose_a,
+                                         bool transpose_b,
+                                         const string& activation) {
+    const BiasAddGraphRunner run_default = [&](const Tensor& input_data,
+                                               const Tensor& filter_data,
+                                               const Tensor& bias_data,
+                                               Tensor* out) {
+      RunMatMulWithBiasAndActivation(input_data, filter_data, bias_data,
+                                     transpose_a, transpose_b, activation, out);
+    };
+
+    const BiasAddGraphRunner run_fused = [&](const Tensor& input_data,
+                                             const Tensor& filter_data,
+                                             const Tensor& bias_data,
+                                             Tensor* out) {
+      RunFusedMatMulOp(input_data, filter_data, {bias_data},
+                       {"BiasAdd", activation}, transpose_a, transpose_b, out);
+    };
+
+    VerifyBiasAddTensorsNear(m, k, n, run_default, run_fused);
+  }
+};
+
+// MatMul with BatchNorm can be tested only with `T=float`, because the default
+// `FusedBatchNorm` kernel supports only floats for scale, mean and variance.
+
+template <typename T>
+class FusedMatMulWithBiasOpTest : public FusedMatMulOpTest<T> {};
+
+TYPED_TEST_SUITE_P(FusedMatMulWithBiasOpTest);
+
+// -------------------------------------------------------------------------- //
+// MatMul + BiasAdd                                                            //
+// -------------------------------------------------------------------------- //
+
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul256x256x256) {
+  this->VerifyMatMulWithBias(256, 256, 256, false, false);
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul1x256x256) {
+  this->VerifyMatMulWithBias(1, 256, 256, false, false);
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul256x256x1) {
+  this->VerifyMatMulWithBias(256, 256, 1, false, false);
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul1x256x1) {
+  this->VerifyMatMulWithBias(1, 256, 1, false, false);
+}
+
+// -------------------------------------------------------------------------- //
+// MatMul Extended Random Test                                                 //
+// Cover: jdtest                                                               //
+// -------------------------------------------------------------------------- //
+
+// === Medium-scale shapes & common KDNN scenarios ===
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_5530x104x32) {
+  this->VerifyMatMulWithBias(5530, 104, 32, false, false);
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_5530x116x32) {
+  this->VerifyMatMulWithBias(5530, 116, 32, false, false);
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_5530x32x16) {
+  this->VerifyMatMulWithBias(5530, 32, 16, false, false);
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_5530x16x1) {
+  this->VerifyMatMulWithBias(5530, 16, 1, false, false);
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_7000x104x32) {
+  this->VerifyMatMulWithBias(7000, 104, 32, false, false);
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_7000x116x32) {
+  this->VerifyMatMulWithBias(7000, 116, 32, false, false);
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_7000x32x16) {
+  this->VerifyMatMulWithBias(7000, 32, 16, false, false);
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_7000x16x1) {
+  this->VerifyMatMulWithBias(7000, 16, 1, false, false);
+}
+
+// Tiny-dimension tests: zero-sized dimensions (empty matrices).
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_0x256x256) {
+  this->VerifyMatMulWithBias(0, 256, 256, false, false);
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_256x256x0) {
+  this->VerifyMatMulWithBias(256, 256, 0, false, false);
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_0x0x0) {
+  this->VerifyMatMulWithBias(0, 0, 0, false, false);
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_0x0x1) {
+  this->VerifyMatMulWithBias(0, 0, 1, false, false);
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_0x1x0) {
+  this->VerifyMatMulWithBias(0, 1, 0, false, false);
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_1x0x0) {
+  this->VerifyMatMulWithBias(1, 0, 0, false, false);
+}
+
+// Non-power-of-two dimensions (unaligned memory).
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_257x257x257) {
+  this->VerifyMatMulWithBias(257, 257, 257, false, false);
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_250x240x230) {
+  this->VerifyMatMulWithBias(250, 240, 230, false, false);
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_123x456x789) {
+  this->VerifyMatMulWithBias(123, 456, 789, false, false);
+}
+
+// Large k (high compute density).
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_64x8192x64) {
+  this->VerifyMatMulWithBias(64, 8192, 64, false, false);
+}
+
+// Small k.
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_256x1x256) {
+  this->VerifyMatMulWithBias(256, 1, 256, false, false);
+}
+
+// Large k: e.g. an embedding followed by an FFN.
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_64x4096x512) {
+  this->VerifyMatMulWithBias(64, 4096, 512, false, false);
+}
+
+// Small m, large n: e.g. a classification head.
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_1x512x1000) {
+  this->VerifyMatMulWithBias(1, 512, 1000, false, false);
+}
+
+// Large m, small n: e.g. a large batch with a small output.
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_1024x256x1) {
+  this->VerifyMatMulWithBias(1024, 256, 1, false, false);
+}
+
+// Extremely small dimension combinations (broadcast / kernel-selection corner cases).
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_1x1x1) {
+  this->VerifyMatMulWithBias(1, 1, 1, false, false);
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_1x1x64) {
+  this->VerifyMatMulWithBias(1, 1, 64, false, false);
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul_64x1x1) {
+  this->VerifyMatMulWithBias(64, 1, 1, false, false);
+}
+
+// -------------------------------------------------------------------------- //
+// MatMul + BiasAdd + {Activation}                                             //
+// -------------------------------------------------------------------------- //
+template <typename T>
+class FusedMatMulWithBiasActivationOpTest : public FusedMatMulOpTest<T> {};
+TYPED_TEST_SUITE_P(FusedMatMulWithBiasActivationOpTest);
+
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul256x256x256WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(256, 256, 256, false, false,
+                                            activation);
+  }
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul1x256x256WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(1, 256, 256, false, false,
+                                            activation);
+  }
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul256x256x1WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(256, 256, 1, false, false,
+                                            activation);
+  }
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul1x256x1WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(1, 256, 1, false, false,
+                                            activation);
+  }
+}
+
+// -------------------------------------------------------------------------- //
+// MatMul Extended Random Test                                                 //
+// Cover: jdtest                                                               //
+// -------------------------------------------------------------------------- //
+
+// === Medium-scale shapes & common KDNN scenarios ===
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_5530x104x32WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(5530, 104, 32, false, false,
+                                            activation);
+  }
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_5530x116x32WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(5530, 116, 32, false, false,
+                                            activation);
+  }
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_5530x32x16WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(5530, 32, 16, false, false,
+                                            activation);
+  }
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_5530x16x1WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(5530, 16, 1, false, false,
+                                            activation);
+  }
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_7000x104x32WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(7000, 104, 32, false, false,
+                                            activation);
+  }
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_7000x116x32WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(7000, 116, 32, false, false,
+                                            activation);
+  }
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_7000x32x16WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(7000, 32, 16, false, false,
+                                            activation);
+  }
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_7000x16x1WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(7000, 16, 1, false, false,
+                                            activation);
+  }
+}
+
+// Tiny-dimension tests: zero-sized dimensions (empty matrices).
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_0x256x256WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(0, 256, 256, false, false,
+                                            activation);
+  }
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_256x256x0WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(256, 256, 0, false, false,
+                                            activation);
+  }
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_0x0x0WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(0, 0, 0, false, false,
+                                            activation);
+  }
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_0x0x1WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(0, 0, 1, false, false,
+                                            activation);
+  }
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_0x1x0WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(0, 1, 0, false, false,
+                                            activation);
+  }
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_1x0x0WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(1, 0, 0, false, false,
+                                            activation);
+  }
+}
+
+// Non-power-of-two dimensions (unaligned memory).
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_257x257x257WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(257, 257, 257, false, false,
+                                            activation);
+  }
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_250x240x230WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(250, 240, 230, false, false,
+                                            activation);
+  }
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_123x456x789WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(123, 456, 789, false, false,
+                                            activation);
+  }
+}
+
+// Large k (high compute density).
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_64x8192x64WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(64, 8192, 64, false, false,
+                                            activation);
+  }
+}
+
+// Small k.
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_256x1x256WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(256, 1, 256, false, false,
+                                            activation);
+  }
+}
+
+// Large k: e.g. an embedding followed by an FFN.
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_64x4096x512WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(64, 4096, 512, false, false,
+                                            activation);
+  }
+}
+
+// Small m, large n: e.g. a classification head.
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_1x512x1000WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(1, 512, 1000, false, false,
+                                            activation);
+  }
+}
+
+// Large m, small n: e.g. a large batch with a small output.
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_1024x256x1WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(1024, 256, 1, false, false,
+                                            activation);
+  }
+}
+
+// Extremely small dimension combinations (broadcast / kernel-selection corner cases).
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_1x1x1WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(1, 1, 1, false, false,
+                                            activation);
+  }
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_1x1x64WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(1, 1, 64, false, false,
+                                            activation);
+  }
+}
+
+TYPED_TEST_P(FusedMatMulWithBiasActivationOpTest, MatMul_64x1x1WithActivation) {
+  for (const string& activation : {"Relu"}) {
+    this->VerifyMatmulWithBiasAndActivation(64, 1, 1, false, false,
+                                            activation);
+  }
+}
+
+REGISTER_TYPED_TEST_SUITE_P(FusedMatMulWithBiasOpTest,
+                            MatMul256x256x256,
+                            MatMul1x256x256,
+                            MatMul256x256x1,
+                            MatMul1x256x1,
+                            MatMul_5530x104x32,
+                            MatMul_5530x116x32,
+                            MatMul_5530x32x16,
+                            MatMul_5530x16x1,
+                            MatMul_7000x104x32,
+                            MatMul_7000x116x32,
+                            MatMul_7000x32x16,
+                            MatMul_7000x16x1,
+                            MatMul_0x256x256,
+                            MatMul_256x256x0,
+                            MatMul_0x0x0,
+                            MatMul_0x0x1,
+                            MatMul_0x1x0,
+                            MatMul_1x0x0,
+                            MatMul_257x257x257,
+                            MatMul_250x240x230,
+                            MatMul_123x456x789,
+                            MatMul_64x8192x64,
+                            MatMul_256x1x256,
+                            MatMul_64x4096x512,
+                            MatMul_1x512x1000,
+                            MatMul_1024x256x1,
+                            MatMul_1x1x1,
+                            MatMul_1x1x64,
+                            MatMul_64x1x1);
+
+REGISTER_TYPED_TEST_SUITE_P(FusedMatMulWithBiasActivationOpTest,
+                            MatMul256x256x256WithActivation,
+                            MatMul1x256x256WithActivation,
+                            MatMul256x256x1WithActivation,
+                            MatMul1x256x1WithActivation,
+                            MatMul_5530x104x32WithActivation,
+                            MatMul_5530x116x32WithActivation,
+                            MatMul_5530x32x16WithActivation,
+                            MatMul_5530x16x1WithActivation,
+                            MatMul_7000x104x32WithActivation,
+                            MatMul_7000x116x32WithActivation,
+                            MatMul_7000x32x16WithActivation,
+                            MatMul_7000x16x1WithActivation,
+                            MatMul_0x256x256WithActivation,
+                            MatMul_256x256x0WithActivation,
+                            MatMul_0x0x0WithActivation,
+                            MatMul_0x0x1WithActivation,
+                            MatMul_0x1x0WithActivation,
+                            MatMul_1x0x0WithActivation,
+                            MatMul_257x257x257WithActivation,
+                            MatMul_250x240x230WithActivation,
+                            MatMul_123x456x789WithActivation,
+                            MatMul_64x8192x64WithActivation,
+                            MatMul_256x1x256WithActivation,
+                            MatMul_64x4096x512WithActivation,
+                            MatMul_1x512x1000WithActivation,
+                            MatMul_1024x256x1WithActivation,
+                            MatMul_1x1x1WithActivation,
+                            MatMul_1x1x64WithActivation,
+                            MatMul_64x1x1WithActivation);
+
+// TODO(ezhulenev): Add support for more data types.
+using FusedBiasAddDataTypes = ::testing::Types<float>;
+INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedMatMulWithBiasOpTest,
+                               FusedBiasAddDataTypes);
+
+using FusedBiasAddActivationDataTypes = ::testing::Types<float>;
+INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedMatMulWithBiasActivationOpTest,
+                               FusedBiasAddActivationDataTypes);
+}  // end namespace tensorflow
diff --git a/tensorflow/core/ops/kdnn_ops.cc b/tensorflow/core/ops/kdnn_ops.cc
new file mode 100644
index 00000000..2b00ec1a
--- /dev/null
+++ b/tensorflow/core/ops/kdnn_ops.cc
@@ -0,0 +1,53 @@
+/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/util/mirror_pad_mode.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+
+REGISTER_OP("_KdnnFusedMatMul")
+    .Input("a: T")
+    .Input("b: T")
+    .Input("args: num_args * T")
+    .Output("product: T")
+    .Attr("transpose_a: bool = false")
+    .Attr("transpose_b: bool = false")
+    .Attr("T: {float}")
+    .Attr("num_args: int >= 0")
+    .Attr("fused_ops: list(string) = []")
+    // Attributes for the FusedBatchNorm ----------- //
+    .Attr("epsilon: float = 0.0001")
+    // --------------------------------------------- //
+    .SetShapeFn(shape_inference::MatMulShape)
+    .Doc(R"doc(
+KDNN version of the FusedMatMul operator, implemented with KDNN APIs.
+
+NOTE: Do not invoke this operator directly in Python. A graph rewrite pass is
+expected to create and invoke this operator.
+)doc");
+
+
+}  // namespace tensorflow
diff --git a/third_party/KDNN/kdnn_adapter.h b/third_party/KDNN/kdnn_adapter.h
index 8502bf7b..b067af2a 100644
--- a/third_party/KDNN/kdnn_adapter.h
+++ b/third_party/KDNN/kdnn_adapter.h
@@ -151,6 +151,63 @@ void kdnnSparseMatmul(const std::size_t nnz,
   kdnnSparseMatmulCSR(nnz, rhs_right, lhs_right, lhs_index_a,
                       rhs_index_a, out, a_indices, a_values, b);
 }
+typedef enum
+{
+    FUSED_TYPE_UNDEFINED = 0,
+    FUSED_TYPE_BIAS = 1,
+    FUSED_TYPE_BIASRELU = 2
+} kdnnFusedType;
+
+inline void kdnnFusedGemm(const OpKernelContext* ctx, const Tensor& a, const Tensor& b, const Tensor& bias,
+                          Tensor* out, bool trans_a, bool trans_b, kdnnFusedType fused_type) {
+  const float* bias_data = bias.flat<float>().data();
+
+  long unsigned int m = a.dim_size(0);
+  long unsigned int n = b.dim_size(1);
+  long unsigned int k = a.dim_size(1);
+  const float* A = a.flat<float>().data();
+  const float* B = b.flat<float>().data();
+  float* C = out->flat<float>().data();
+  // intra_op thread_pool
+  thread::ThreadPool* thread_pool =
+      ctx->device()
+          ->tensorflow_cpu_worker_threads()
+          ->workers;
+  kdnn::KDNNThreadPool eigen_tp(thread_pool);
+  const KDNN::TensorInfo srcInfo = {{m, k}, KDNN::Element::TypeT::F32, KDNN::Layout::AB};
+  const KDNN::TensorInfo weightsInfo = {{k, n}, KDNN::Element::TypeT::F32, KDNN::Layout::AB};
+  const KDNN::TensorInfo dstInfo = {{m, n}, KDNN::Element::TypeT::F32, KDNN::Layout::AB};
+  KDNN::Gemm gemm(srcInfo, weightsInfo, dstInfo, &eigen_tp);
+  auto relu = [](float x) { return std::max(0.0f, x); };
+  auto identity = [](float x) { return x; };
+
+  auto matmul_bias_activation = [&](auto activation) {
+    gemm.Run(A, B, C);
+    int cost_per_unit = 2 * n;
+    thread_pool->ParallelFor(m, cost_per_unit, [&](int start, int end) {
+      for (int i = start; i < end; i++) {
+        for (int j = 0; j < n; j++) {
+          float val = C[i * n + j] + bias_data[j];
+          C[i * n + j] = activation(val);
+        }
+      }
+    });
+  };
+  switch (fused_type) {
+    case kdnnFusedType::FUSED_TYPE_BIAS:
+      matmul_bias_activation(identity);
+      break;
+    case kdnnFusedType::FUSED_TYPE_BIASRELU:
+      matmul_bias_activation(relu);
+      break;
+    case kdnnFusedType::FUSED_TYPE_UNDEFINED:
+      throw std::invalid_argument("Fusion type is undefined");
+      break;
+    default:
+      throw std::invalid_argument("Fusion type is not supported");
+  }
+}
+
 }// namespace tensorflow
 #endif // THIRD_PARTY_KDNN_KDNN_ADAPTER_H_
\ No newline at end of file
-- 
Gitee
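
Note on usage: the registered op is meant to be created by a graph rewrite pass rather than called directly, mirroring the NodeDef construction exercised in kdnn_matmul_op_test.cc. Below is a minimal sketch (not part of the patch) of the NodeDef such a pass might emit for a MatMul + BiasAdd + Relu pattern; the producer node names "input", "weights", and "bias" and the helper name are hypothetical placeholders.

#include <string>
#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Builds the NodeDef a remapper could emit when it matches
// MatMul -> BiasAdd -> Relu and replaces it with _KdnnFusedMatMul.
Status BuildKdnnFusedMatMulNode(NodeDef* fused) {
  return NodeDefBuilder("fc1/kdnn_fused_matmul", "_KdnnFusedMatMul")
      .Input("input", 0, DT_FLOAT)    // a: [m, k] activations
      .Input("weights", 0, DT_FLOAT)  // b: [k, n] weights
      .Input(std::vector<NodeDefBuilder::NodeOut>{{"bias", 0, DT_FLOAT}})
      .Attr("num_args", 1)            // one fused argument: the bias vector
      .Attr("T", DT_FLOAT)
      .Attr("fused_ops", std::vector<string>{"BiasAdd", "Relu"})
      .Attr("transpose_a", false)     // the kernel rejects transposed inputs
      .Attr("transpose_b", false)
      .Finalize(fused);
}

}  // namespace tensorflow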