From f15388968e66dd45adaf80264f995c187edfa8fb Mon Sep 17 00:00:00 2001 From: Codersheepchen Date: Mon, 18 Aug 2025 23:37:38 -0400 Subject: [PATCH] add kdnn gflags instead of env & add copyright --- tensorflow/core/kernels/matmul_op.cc | 11 +++--- .../kernels/sparse_tensor_dense_matmul_op.cc | 11 +++--- third_party/KDNN/BUILD | 6 +++- third_party/KDNN/kdnn_adapter.h | 22 +++++++++++- third_party/KDNN/kdnn_flags.cc | 19 +++++++++++ third_party/KDNN/kdnn_threadpool.h | 34 +++++++++++++------ 6 files changed, 76 insertions(+), 27 deletions(-) create mode 100644 third_party/KDNN/kdnn_flags.cc diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc index 0f422330d..ff766392a 100644 --- a/tensorflow/core/kernels/matmul_op.cc +++ b/tensorflow/core/kernels/matmul_op.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ // See docs in ../ops/math_ops.cc. +#include "gflags/gflags.h" #define EIGEN_USE_THREADS @@ -24,7 +25,6 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/fill_functor.h" #include "tensorflow/core/util/matmul_autotune.h" -#include "tensorflow/core/util/env_var.h" #if GOOGLE_CUDA #include "third_party/gpus/cuda/include/cuda.h" #endif @@ -35,6 +35,7 @@ limitations under the License. 
#if defined(ENABLE_KDNN) #include "kdnn_adapter.h" +DECLARE_bool(enable_kdnn); #endif namespace tensorflow { @@ -447,16 +448,13 @@ template class MatMulOp : public OpKernel { public: explicit MatMulOp(OpKernelConstruction* ctx) - : OpKernel(ctx), algorithms_set_already_(false), kdnn_enable(true) { + : OpKernel(ctx), algorithms_set_already_(false) { OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_)); LaunchMatMul::GetBlasGemmAlgorithm( ctx, &algorithms_, &algorithms_set_already_); use_autotune_ = MatmulAutotuneEnable(); -#if defined(ENABLE_KDNN) - OP_REQUIRES_OK(ctx, tensorflow::ReadBoolFromEnvVar("KDNN_ENABLE", true, &kdnn_enable)); -#endif } void Compute(OpKernelContext* ctx) override { @@ -526,7 +524,7 @@ class MatMulOp : public OpKernel { out->flat().data(), out->NumElements()); } #if defined(ENABLE_KDNN) - else if (kdnn_enable && std::is_same::value && + else if (FLAGS_enable_kdnn && std::is_same::value && !transpose_a_ && !transpose_b_) { kdnnGemm(ctx, a, b, out, transpose_a_, transpose_b_); } @@ -543,7 +541,6 @@ class MatMulOp : public OpKernel { bool use_autotune_; bool transpose_a_; bool transpose_b_; - bool kdnn_enable; }; namespace functor { diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc index f2dcab4e5..f4f4c6f98 100644 --- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc +++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc @@ -24,10 +24,11 @@ limitations under the License. 
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/kernels/fill_functor.h" #include "tensorflow/core/lib/bfloat16/bfloat16.h" -#include "tensorflow/core/util/env_var.h" +#include "gflags/gflags.h" #if defined(ENABLE_KDNN) #include "kdnn_adapter.h" +DECLARE_bool(enable_kdnn); #endif #include @@ -42,12 +43,9 @@ template class SparseTensorDenseMatMulOp : public OpKernel { public: explicit SparseTensorDenseMatMulOp(OpKernelConstruction* ctx) - : OpKernel(ctx), kdnn_enable(true) { + : OpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("adjoint_a", &adjoint_a_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("adjoint_b", &adjoint_b_)); -#if defined(ENABLE_KDNN) - OP_REQUIRES_OK(ctx, ReadBoolFromEnvVar("KDNN_ENABLE", true, &kdnn_enable)); -#endif } void Compute(OpKernelContext* ctx) override { @@ -155,7 +153,7 @@ class SparseTensorDenseMatMulOp : public OpKernel { b->matrix()); \ OP_REQUIRES_OK(ctx, functor_status); \ } - if (kdnn_enable && std::is_same::value && adjoint_b_ == false) { + if (FLAGS_enable_kdnn && std::is_same::value && adjoint_b_ == false) { KDNN_ADJOINT(false, false); KDNN_ADJOINT(true, false); return; @@ -184,7 +182,6 @@ class SparseTensorDenseMatMulOp : public OpKernel { private: bool adjoint_a_; bool adjoint_b_; - bool kdnn_enable; }; #define REGISTER_CPU(TypeT, TypeIndex) \ diff --git a/third_party/KDNN/BUILD b/third_party/KDNN/BUILD index e70246b1c..644314afc 100644 --- a/third_party/KDNN/BUILD +++ b/third_party/KDNN/BUILD @@ -47,9 +47,13 @@ cc_library( cc_library( name = "kdnn_adapter", + srcs = ["kdnn_flags.cc"], hdrs = ["kdnn_adapter.h", "kdnn_threadpool.h"], strip_include_prefix = "/third_party/KDNN", - deps = [":kdnn"], + deps = [ + "@com_github_gflags_gflags//:gflags", + ":kdnn", + ], visibility = ["//visibility:public"], ) diff --git a/third_party/KDNN/kdnn_adapter.h b/third_party/KDNN/kdnn_adapter.h index c3bbf71fd..0f8c4a6d9 100644 --- a/third_party/KDNN/kdnn_adapter.h +++ b/third_party/KDNN/kdnn_adapter.h @@ -1,3 +1,21 @@ +/* 
Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_KDNN_KDNN_ADAPTER_H_ +#define THIRD_PARTY_KDNN_KDNN_ADAPTER_H_ + #include "kdnn.hpp" #include "kdnn_threadpool.h" @@ -76,4 +94,6 @@ void kdnnSparseMatmul(const std::size_t nnz, kdnnSparseMatmulCSR(nnz, rhs_right, lhs_right, lhs_index_a, rhs_index_a, out, a_indices, a_values, b); } -}// namespace tensorflow \ No newline at end of file +}// namespace tensorflow + +#endif // THIRD_PARTY_KDNN_KDNN_ADAPTER_H_ \ No newline at end of file diff --git a/third_party/KDNN/kdnn_flags.cc b/third_party/KDNN/kdnn_flags.cc new file mode 100644 index 000000000..e5882c914 --- /dev/null +++ b/third_party/KDNN/kdnn_flags.cc @@ -0,0 +1,19 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "gflags/gflags.h" + +DEFINE_bool(enable_kdnn, true, "Enable KDNN Operation"); +DEFINE_int64(kdnn_num_threads, -1, "Set KDNN num threads, default same as intra threadpool"); \ No newline at end of file diff --git a/third_party/KDNN/kdnn_threadpool.h b/third_party/KDNN/kdnn_threadpool.h index 83d2f1f97..4b71c1234 100644 --- a/third_party/KDNN/kdnn_threadpool.h +++ b/third_party/KDNN/kdnn_threadpool.h @@ -1,5 +1,20 @@ -#ifndef TENSORFLOW_CORE_UTIL_KDNN_THREADPOOL_H_ -#define TENSORFLOW_CORE_UTIL_KDNN_THREADPOOL_H_ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef THIRD_PARTY_KDNN_KDNN_THREADPOOL_H_ +#define THIRD_PARTY_KDNN_KDNN_THREADPOOL_H_ #include #include @@ -7,6 +22,7 @@ #include #include #include +#include "gflags/gflags.h" #define EIGEN_USE_THREADS @@ -15,7 +31,8 @@ #include "tensorflow/core/platform/threadpool.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/util/env_var.h" + +DECLARE_int64(kdnn_num_threads); namespace kdnn { @@ -56,13 +73,8 @@ class KDNNThreadPool : public ThreadpoolWrapper { return; } - tensorflow::int64 env_threads = -1; - tensorflow::Status status = tensorflow::ReadInt64FromEnvVar("KDNN_NUM_THREADS", -1, &env_threads); - if (!status.ok()) { - LOG(WARNING) << "Parse env KDNN_NUM_THREADS failed, use default thread nums"; - } - if (env_threads > 0) { - num_threads_ = std::min(pool_threads, static_cast(env_threads)); + if (FLAGS_kdnn_num_threads > 0) { + num_threads_ = std::min(pool_threads, static_cast(FLAGS_kdnn_num_threads)); return; } @@ -72,4 +84,4 @@ class KDNNThreadPool : public ThreadpoolWrapper { } // namespace kdnn -#endif // TENSORFLOW_CORE_UTIL_KDNN_THREADPOOL_H_ +#endif // THIRD_PARTY_KDNN_KDNN_THREADPOOL_H_ -- Gitee