From d01e02a68dce4dfb0b24a6fcf28984f0102d3c87 Mon Sep 17 00:00:00 2001 From: Codersheepchen Date: Wed, 6 Aug 2025 23:51:17 -0400 Subject: [PATCH] add env KDNN_NUM_THREADS to set max threads can be used by KDNN --- .../kernels/sparse_tensor_dense_matmul_op.cc | 5 ++-- third_party/KDNN/kdnn_adapter.h | 3 ++- third_party/KDNN/kdnn_threadpool.h | 25 ++++++++++++++++--- 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc index f43095775..96a7b34ad 100644 --- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc +++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc @@ -44,7 +44,7 @@ class SparseTensorDenseMatMulOp : public OpKernel { : OpKernel(ctx), kdnn_enable(true) { OP_REQUIRES_OK(ctx, ctx->GetAttr("adjoint_a", &adjoint_a_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("adjoint_b", &adjoint_b_)); - TF_CHECK_OK(ReadBoolFromEnvVar("KDNN_ENABLE", true, &kdnn_enable)); + OP_REQUIRES_OK(ctx, ReadBoolFromEnvVar("KDNN_ENABLE", true, &kdnn_enable)); } void Compute(OpKernelContext* ctx) override { @@ -385,8 +385,7 @@ struct KDNNSparseMatMulFunctor { } } } else { - const int b_chip_index = 0; - kdnnSparseMatmul(nnz, rhs_right, lhs_right, out, a_indices, a_values, b); + kdnnSparseMatmul(nnz, rhs_right, lhs_right, lhs_index_a, rhs_index_a, out, a_indices, a_values, b); } return Status::OK(); } diff --git a/third_party/KDNN/kdnn_adapter.h b/third_party/KDNN/kdnn_adapter.h index e7836c892..35248e1d5 100644 --- a/third_party/KDNN/kdnn_adapter.h +++ b/third_party/KDNN/kdnn_adapter.h @@ -37,6 +37,7 @@ static bool compareByRow(const NonZeroElement& a, const NonZeroElement& b) { template <typename T, typename Tindices> void kdnnSparseMatmul(const std::size_t nnz, const std::size_t rhs_right, const std::size_t lhs_right, + const int lhs_index_a, const int rhs_index_a, typename TTypes<T>::Matrix out, typename TTypes<Tindices>::ConstMatrix a_indices, typename TTypes<T>::ConstVec a_values, @@ -45,7 +45,7 @@ 
void kdnnSparseMatmul(const std::size_t nnz, float val[nnz]; std::vector<NonZeroElement> elements; for (size_t i = 0; i < nnz; ++i) { - elements.emplace_back(NonZeroElement{a_indices(i, 0), a_indices(i, 1), a_values(i)}); + elements.emplace_back(NonZeroElement{a_indices(i, lhs_index_a), a_indices(i, rhs_index_a), a_values(i)}); } std::sort(elements.begin(), elements.end(), compareByRow); for (size_t i = 0; i < nnz; ++i) { diff --git a/third_party/KDNN/kdnn_threadpool.h b/third_party/KDNN/kdnn_threadpool.h index 193a2e033..83d2f1f97 100644 --- a/third_party/KDNN/kdnn_threadpool.h +++ b/third_party/KDNN/kdnn_threadpool.h @@ -13,6 +13,9 @@ #include "kdnn.hpp" #include "tensorflow/core/platform/blocking_counter.h" #include "tensorflow/core/platform/threadpool.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/env_var.h" namespace kdnn { @@ -29,7 +32,7 @@ class KDNNThreadPool : public ThreadpoolWrapper { eigen_interface_(thread_pool->AsEigenThreadPool()) { set_num_and_max_threads(num_threads); } - + virtual int GetNumThreads() const override { return num_threads_; } virtual bool GetInParallel() const override { return (eigen_interface_->CurrentThreadId() != -1) ? true : false; @@ -46,8 +49,24 @@ class KDNNThreadPool : public ThreadpoolWrapper { Eigen::ThreadPoolInterface* eigen_interface_ = nullptr; int num_threads_ = 1; inline void set_num_and_max_threads(int num_threads) { - num_threads_ = - num_threads == -1 ? 
eigen_interface_->NumThreads() : num_threads; + int pool_threads = eigen_interface_->NumThreads(); + + if (num_threads > 0) { + num_threads_ = std::min(pool_threads, num_threads); + return; + } + + tensorflow::int64 env_threads = -1; + tensorflow::Status status = tensorflow::ReadInt64FromEnvVar("KDNN_NUM_THREADS", -1, &env_threads); + if (!status.ok()) { + LOG(WARNING) << "Parse env KDNN_NUM_THREADS failed, use default thread nums"; + } + if (env_threads > 0) { + num_threads_ = std::min(pool_threads, static_cast<int>(env_threads)); + return; + } + + num_threads_ = pool_threads; } }; -- Gitee