From 3853bac4ca0db3961f4ee2b82000fe5ba65c96fa Mon Sep 17 00:00:00 2001
From: Codersheepchen
Date: Thu, 31 Jul 2025 08:29:00 -0400
Subject: [PATCH] add sparse matmul

---
 third_party/KDNN/BUILD          |  4 +-
 third_party/KDNN/kdnn_adapter.h | 75 ++++++++++++++++++---------------
 2 files changed, 43 insertions(+), 36 deletions(-)

diff --git a/third_party/KDNN/BUILD b/third_party/KDNN/BUILD
index 0652fbbca..e70246b1c 100644
--- a/third_party/KDNN/BUILD
+++ b/third_party/KDNN/BUILD
@@ -33,6 +33,7 @@ cc_library(
         "src/dnn/src/sparse/kernel/gemm_csr_row_k_small.c",
         "src/dnn/src/sparse/**/*.S",
     ]),
+    visibility = ["//visibility:public"],
 )
 
 cc_library(
@@ -41,14 +42,15 @@ cc_library(
     includes = ["src/dnn/include"],
     srcs = glob(["src/dnn/src/*.cpp", "src/dnn/src/*.hpp"]),
     deps = [":ksparse"],
+    visibility = ["//visibility:public"],
 )
 
 cc_library(
     name = "kdnn_adapter",
     hdrs = ["kdnn_adapter.h", "kdnn_threadpool.h"],
     strip_include_prefix = "/third_party/KDNN",
-    visibility = ["//visibility:public"],
     deps = [":kdnn"],
+    visibility = ["//visibility:public"],
 )
 
 bzl_library(
diff --git a/third_party/KDNN/kdnn_adapter.h b/third_party/KDNN/kdnn_adapter.h
index 35248e1d5..c3bbf71fd 100644
--- a/third_party/KDNN/kdnn_adapter.h
+++ b/third_party/KDNN/kdnn_adapter.h
@@ -24,51 +24,56 @@ inline void kdnnGemm(OpKernelContext* ctx, const Tensor& a, const Tensor& b, Ten
     gemm.Run(A, B, C);
 }
 
-struct NonZeroElement {
-    int row;
-    int col;
-    float val;
-};
-
-static bool compareByRow(const NonZeroElement& a, const NonZeroElement& b) {
-    return a.row < b.row;
-}
-
 template <typename T>
-void kdnnSparseMatmul(const std::size_t nnz,
-                      const std::size_t rhs_right, const std::size_t lhs_right,
-                      const int lhs_index_a, const int rhs_index_a,
-                      typename TTypes<T>::Matrix out,
-                      typename TTypes<int64_t>::ConstMatrix a_indices,
-                      typename TTypes<T>::ConstVec a_values,
-                      typename TTypes<T>::ConstMatrix b) {
-    KDNN_INT idx[nnz];
-    float val[nnz];
-    std::vector<NonZeroElement> elements;
-    for (size_t i = 0; i < nnz; ++i) {
-        elements.emplace_back(NonZeroElement{a_indices(i, lhs_index_a), a_indices(i, rhs_index_a), a_values(i)});
-    }
-    std::sort(elements.begin(), elements.end(), compareByRow);
-    for (size_t i = 0; i < nnz; ++i) {
-        idx[i] = elements[i].col;
-        val[i] = elements[i].val;
-    }
+inline void kdnnSparseMatmulCSR(const std::size_t nnz,
+                                const std::size_t rhs_right, const std::size_t lhs_right,
+                                const int lhs_index_a, const int rhs_index_a,
+                                typename TTypes<T>::Matrix out,
+                                typename TTypes<int64_t>::ConstMatrix a_indices,
+                                typename TTypes<T>::ConstVec a_values,
+                                typename TTypes<T>::ConstMatrix b) {
+    std::vector<KDNN_INT> idx(nnz);
     int m = out.dimension(0);
-    KDNN_INT pntrb[m] = {0};
-    KDNN_INT pntre[m] = {0};
-    std::vector<int> row_counts(m, 0);
-    for (const auto& t : elements) {
-        row_counts[t.row]++;
+    std::vector<KDNN_INT> pntrb(m);
+    std::vector<KDNN_INT> pntre(m);
+    std::vector<KDNN_INT> row_counts(m);
+    for (size_t i = 0; i < nnz; ++i) {
+        idx[i] = a_indices(i, rhs_index_a);
+        ++row_counts[a_indices(i, lhs_index_a)];
     }
+
     int current_pos = 0;
     for (size_t i = 0; i < m; ++i) {
         pntrb[i] = current_pos;
         current_pos += row_counts[i];
         pntre[i] = current_pos;
     }
+    const KDNN::CsrSparseTensorInfo aInfo = {{m, lhs_right},
+        KDNN::Element::TypeT::F32, KDNN::Layout::AB, pntrb, pntre, idx, nnz};
+    const KDNN::TensorInfo bInfo = {{lhs_right, rhs_right},
+        KDNN::Element::TypeT::F32, KDNN::Layout::AB};
+    const KDNN::TensorInfo dstInfo = {{m, rhs_right},
+        KDNN::Element::TypeT::F32, KDNN::Layout::AB};
+    KDNN::SparseGemm sparse_csr(aInfo, bInfo, dstInfo);
+    sparse_csr.Run(a_values.data(), b.data(), out.data());
+}
+
+template <typename T>
+void kdnnSparseMatmul(const std::size_t nnz,
+                      const std::size_t rhs_right, const std::size_t lhs_right,
+                      const int lhs_index_a, const int rhs_index_a,
+                      typename TTypes<T>::Matrix out,
+                      typename TTypes<int64_t>::ConstMatrix a_indices,
+                      typename TTypes<T>::ConstVec a_values,
+                      typename TTypes<T>::ConstMatrix b) {
+    static const std::size_t kNumCSR = 720;
+    int m = out.dimension(0);
     VLOG(1) << "kdnnSparseMatmul, M: " << m << " N:" << rhs_right
             << " K:" << lhs_right << " nnz:" << nnz;
-    KDNN::SparseCsrmm(KDNN_SPARSE_OPERATION_NON_TRANSPOSE, m, rhs_right, lhs_right,
-        1.0, "G00C", val, idx, pntrb, pntre, b.data(), rhs_right, 0.0, out.data(), rhs_right);
+    if ((m > kint32max) || (rhs_right > kint32max) || (lhs_right > kint32max)) {
+        LOG(WARNING) << "too large m/n/k in KDNN sparse matmul, max allowed is " << kint32max;
+        return;
+    }
+    kdnnSparseMatmulCSR<T>(nnz, rhs_right, lhs_right, lhs_index_a, rhs_index_a, out, a_indices, a_values, b);
 }
 }// namespace tensorflow
\ No newline at end of file
-- 
Gitee
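
For context (not part of the patch): the new kdnnSparseMatmulCSR builds CSR row pointers (pntrb/pntre) and column indices (idx) from the COO-style a_indices and hands them, with the values and the dense operand, to KDNN::SparseGemm. The standalone sketch below illustrates only that COO-to-CSR step using the standard library; the Csr struct and CooToCsr name are illustrative, not KDNN or TensorFlow APIs, and it assumes, as the patch appears to, that the COO entries are already ordered by row.

// Illustrative sketch of the COO -> CSR conversion performed in
// kdnnSparseMatmulCSR above. Plain std::int64_t stands in for KDNN_INT.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

struct Csr {
    std::vector<std::int64_t> pntrb;  // index of the first nonzero of each row
    std::vector<std::int64_t> pntre;  // one past the last nonzero of each row
    std::vector<std::int64_t> idx;    // column index of each nonzero
};

Csr CooToCsr(const std::vector<std::pair<std::int64_t, std::int64_t>>& coo,
             std::int64_t m) {
    Csr csr;
    csr.idx.resize(coo.size());
    csr.pntrb.resize(m);
    csr.pntre.resize(m);
    std::vector<std::int64_t> row_counts(m, 0);
    // Count nonzeros per row and record column indices in input order
    // (assumes the entries are already grouped by row, row-major).
    for (std::size_t i = 0; i < coo.size(); ++i) {
        csr.idx[i] = coo[i].second;
        ++row_counts[coo[i].first];
    }
    // Prefix-sum the per-row counts into begin/end pointers.
    std::int64_t current_pos = 0;
    for (std::int64_t r = 0; r < m; ++r) {
        csr.pntrb[r] = current_pos;
        current_pos += row_counts[r];
        csr.pntre[r] = current_pos;
    }
    return csr;
}

int main() {
    // 3x4 matrix with nonzeros at (0,1), (0,3), (2,0).
    Csr csr = CooToCsr({{0, 1}, {0, 3}, {2, 0}}, /*m=*/3);
    for (std::int64_t r = 0; r < 3; ++r) {
        std::cout << "row " << r << ": [" << csr.pntrb[r] << ", " << csr.pntre[r] << ")\n";
    }
    return 0;
}

Built with any C++11 compiler, the example prints the per-row nonzero ranges [0, 2), [2, 2), [2, 3), which is the pntrb/pntre layout the patch feeds into KDNN::CsrSparseTensorInfo.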