diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 2bbae538e8bd0b2739f0095484176800cc74c5da..961468ea70165f9f9ceaf6fba9280a5880b0c7d3 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -4002,7 +4002,8 @@ tf_kernel_library(
     prefix = "batch_matmul_op",
     deps = MATH_DEPS + [":eigen_contraction_kernel"] + if_mkl_ml([
         "//third_party/mkl:intel_binary_blob",
-    ]),
+    ]) + kdnn_deps(),
+    defines = if_enable_kdnn(["ENABLE_KDNN=1"]),
 )
 
 tf_mkl_kernel_library(
diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h
index 5649c06878076efee790cec0421cda7989b0a379..a8de9d1e586edfc4f8af41f19f5c0467b76bb58e 100644
--- a/tensorflow/core/kernels/batch_matmul_op_impl.h
+++ b/tensorflow/core/kernels/batch_matmul_op_impl.h
@@ -46,6 +46,11 @@ limitations under the License.
 #include "tensorflow/core/platform/stream_executor.h"
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
+#if defined(ENABLE_KDNN)
+#include "kdnn_adapter.h"
+DECLARE_bool(enable_kdnn);
+#endif
+
 namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -123,6 +128,12 @@ struct ParallelMatMulKernel {
   static void Run(const OpKernelContext* context, const Tensor& in_x,
                   const Tensor& in_y, bool adj_x, bool adj_y,
                   const MatMulBCast& bcast, Tensor* out, int start, int limit) {
+#if defined(ENABLE_KDNN)
+    if (FLAGS_enable_kdnn && std::is_same<Scalar, float>::value && !adj_x && !adj_y) {
+      kdnnParallelGemm(context, in_x, in_y, out, bcast, start, limit);
+      return;
+    }
+#endif
     auto Tx = in_x.tensor<Scalar, 3>();
     auto Ty = in_y.tensor<Scalar, 3>();
     auto Tz = out->tensor<Scalar, 3>();
@@ -171,6 +182,12 @@ struct SequentialMatMulKernel {
   static void Run(const Tensor& in_x, const Tensor& in_y, bool adj_x,
                   bool adj_y, const MatMulBCast& bcast, Tensor* out, int start,
                   int limit) {
+#if defined(ENABLE_KDNN)
+    if (FLAGS_enable_kdnn && std::is_same<Scalar, float>::value && !adj_x && !adj_y) {
+      kdnnSeqGemm(in_x, in_y, out, bcast, start, limit);
+      return;
+    }
+#endif
     const bool should_bcast = bcast.IsBroadcastingRequired();
     const auto& x_batch_indices = bcast.x_batch_indices();
     const auto& y_batch_indices = bcast.y_batch_indices();
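Note: the two ENABLE_KDNN blocks above gate dispatch to KDNN on a runtime flag, an FP32 scalar type, and the absence of adjoint operands; everything else falls through to the existing Eigen path. The snippet below is a minimal illustrative sketch of that predicate, not code from the patch; UseKdnnGemm is a hypothetical helper name, and the flag is assumed to be defined elsewhere with the matching DEFINE_bool(enable_kdnn, ...).

// Sketch only: mirrors the guard added to ParallelMatMulKernel/SequentialMatMulKernel.
#include <type_traits>

extern bool FLAGS_enable_kdnn;  // assumption: defined via DEFINE_bool elsewhere

template <typename Scalar>
inline bool UseKdnnGemm(bool adj_x, bool adj_y) {  // hypothetical helper
  // KDNN path: flag enabled, FP32 only, no adjoint on either operand.
  return FLAGS_enable_kdnn && std::is_same<Scalar, float>::value &&
         !adj_x && !adj_y;
}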
#include "kdnn.hpp" #include "kdnn_threadpool.h" +#include "tensorflow/core/util/matmul_bcast.h" namespace tensorflow { @@ -34,14 +35,70 @@ inline void kdnnGemm(OpKernelContext* ctx, const Tensor& a, const Tensor& b, Ten ctx->device() ->tensorflow_cpu_worker_threads() ->workers; - kdnn::KDNNThreadPool eigen_tp(thread_pool); + kdnn::KDNNThreadPool kdnn_tp(thread_pool); const KDNN::TensorInfo srcInfo = {{m, k}, KDNN::Element::TypeT::F32, KDNN::Layout::AB}; const KDNN::TensorInfo weightsInfo = {{k, n}, KDNN::Element::TypeT::F32, KDNN::Layout::AB}; const KDNN::TensorInfo dstInfo = {{m, n}, KDNN::Element::TypeT::F32, KDNN::Layout::AB}; - KDNN::Gemm gemm(srcInfo, weightsInfo, dstInfo, &eigen_tp); + KDNN::Gemm gemm(srcInfo, weightsInfo, dstInfo, &kdnn_tp); gemm.Run(A, B, C); } +inline void kdnnParallelGemm(const OpKernelContext* ctx, const Tensor& a, const Tensor& b, Tensor* out, + const MatMulBCast& bcast, int start, int end) { + const bool should_bcast = bcast.IsBroadcastingRequired(); + const auto& x_batch_indices = bcast.x_batch_indices(); + const auto& y_batch_indices = bcast.y_batch_indices(); + int m = a.dim_size(1); + int n = b.dim_size(2); + int k = a.dim_size(2); + int stride_a = m * k; + int stride_b = k * n; + int stride_c = m * n; + const float *A = a.flat().data(); + const float *B = b.flat().data(); + float *C = out->flat().data(); + // intra_op thread_pool + thread::ThreadPool* thread_pool = + ctx->device() + ->tensorflow_cpu_worker_threads() + ->workers; + kdnn::KDNNThreadPool kdnn_tp(thread_pool); + const KDNN::TensorInfo srcInfo = {{m, k}, KDNN::Element::TypeT::F32, KDNN::Layout::AB}; + const KDNN::TensorInfo weightsInfo = {{k, n}, KDNN::Element::TypeT::F32, KDNN::Layout::AB}; + const KDNN::TensorInfo dstInfo = {{m, n}, KDNN::Element::TypeT::F32, KDNN::Layout::AB}; + KDNN::Gemm gemm(srcInfo, weightsInfo, dstInfo, &kdnn_tp); + for (int64_t i = start; i < end; ++i) { + const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i; + const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i; + gemm.Run(A + x_batch_index * stride_a, B + y_batch_index * stride_b, C + i * stride_c); + } +} + +inline void kdnnSeqGemm(const Tensor& a, const Tensor& b, Tensor* out, + const MatMulBCast& bcast, int start, int end) { + const bool should_bcast = bcast.IsBroadcastingRequired(); + const auto& x_batch_indices = bcast.x_batch_indices(); + const auto& y_batch_indices = bcast.y_batch_indices(); + int m = a.dim_size(1); + int n = b.dim_size(2); + int k = a.dim_size(2); + int stride_a = m * k; + int stride_b = k * n; + int stride_c = m * n; + const float *A = a.flat().data(); + const float *B = b.flat().data(); + float *C = out->flat().data(); + const KDNN::TensorInfo srcInfo = {{m, k}, KDNN::Element::TypeT::F32, KDNN::Layout::AB}; + const KDNN::TensorInfo weightsInfo = {{k, n}, KDNN::Element::TypeT::F32, KDNN::Layout::AB}; + const KDNN::TensorInfo dstInfo = {{m, n}, KDNN::Element::TypeT::F32, KDNN::Layout::AB}; + KDNN::Gemm gemm(srcInfo, weightsInfo, dstInfo); + for (int64_t i = start; i < end; ++i) { + const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i; + const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i; + gemm.Run(A + x_batch_index * stride_a, B + y_batch_index * stride_b, C + i * stride_c); + } +} + template inline void kdnnSparseMatmulCSR(const std::size_t nnz, const std::size_t rhs_right, const std::size_t lhs_right,