From 88771ca8da79e4d29eaec4b54d67e470315c5205 Mon Sep 17 00:00:00 2001
From: void-ptr
Date: Wed, 25 Jun 2025 15:22:05 +0800
Subject: [PATCH] enable unitflag in mac-bound scenario.

---
Review notes (text between the three-dash line and the first "diff --git" is
not part of the commit message):
- Line breaks and indentation were reconstructed from whitespace-collapsed
  text. Context lines keep the tokens exactly as received; some of them appear
  to have lost angle-bracket spans (e.g. "std::map G_DTYPE_MAP", the author
  e-mail). Regenerate with git format-patch from the tree before applying.
- The blob IDs on the "index" lines are those of the original revision.
- SetUnitFlagSwitch() now forms m * n and m + n in 64 bits and guards m + n == 0;
  hunk offsets and the diffstat were updated for the six extra added lines.

 src/kernels/kernels/matmul/tiling/pp_matmul_info.h     |  2 ++
 src/kernels/kernels/matmul/tiling/pp_matmul_tiling.cpp | 16 ++++++++++++++++
 src/kernels/kernels/matmul/tiling/tiling_data.h        |  1 +
 3 files changed, 19 insertions(+)

diff --git a/src/kernels/kernels/matmul/tiling/pp_matmul_info.h b/src/kernels/kernels/matmul/tiling/pp_matmul_info.h
index dd6c36f1..099a8358 100644
--- a/src/kernels/kernels/matmul/tiling/pp_matmul_info.h
+++ b/src/kernels/kernels/matmul/tiling/pp_matmul_info.h
@@ -181,10 +181,12 @@ struct PpTilingData {
     uint32_t splitk{0};
     uint32_t enShuffleK{0};
     uint32_t quantMode{0};
+    uint32_t enUnitFlag{0}; // written by SetUnitFlagSwitch(): 1 = enabled, 0 = disabled

     void SetBaseShape(uint32_t batchSize, uint32_t m, uint32_t k, uint32_t n);
     void SetBaseOp(uint32_t coreNum, uint32_t mBase, uint32_t nBase, const MatMulInfo &mmInfo);
     void SetTilingKey(const MatMulInfo &mmInfo, uint32_t swizzleDirect, uint32_t enSplitK);
+    void SetUnitFlagSwitch(); // derives enUnitFlag from opShape.m and opShape.n
     uint32_t End(const MatMulInfo &mmInfo);
 };
 } // namespace AsdOps
diff --git a/src/kernels/kernels/matmul/tiling/pp_matmul_tiling.cpp b/src/kernels/kernels/matmul/tiling/pp_matmul_tiling.cpp
index 6ee0dc68..99d14077 100644
--- a/src/kernels/kernels/matmul/tiling/pp_matmul_tiling.cpp
+++ b/src/kernels/kernels/matmul/tiling/pp_matmul_tiling.cpp
@@ -25,6 +25,7 @@ constexpr uint32_t CONST_16 = 16;
 constexpr uint32_t CONST_32 = 32;
 constexpr uint32_t CONST_256 = 256;
 constexpr uint32_t CONST_512 = 512;
+constexpr float MIN_COMPUTE_INTENSITY = 2000.0f; // m*n/(m+n) must exceed this to enable the unit flag

 const std::map G_DTYPE_MAP = {
     {TENSOR_DTYPE_INT8, 0u}, {TENSOR_DTYPE_FLOAT16, 1u}, {TENSOR_DTYPE_BF16, 2u}, {TENSOR_DTYPE_FLOAT, 3u}};
@@ -128,6 +129,19 @@ uint32_t PpTilingData::End(const MatMulInfo &mmInfo)
     return blockDim;
 }

+// Sets enUnitFlag to 1 when m*n / (m+n) of the current opShape exceeds MIN_COMPUTE_INTENSITY,
+// and to 0 otherwise. The product and sum are formed in 64 bits: SetBaseShape takes uint32_t
+// dimensions, and a 32-bit m * n wraps for large shapes, which would wrongly clear the flag.
+void PpTilingData::SetUnitFlagSwitch()
+{
+    const uint64_t mnProduct = static_cast<uint64_t>(opShape.m) * opShape.n;
+    const uint64_t mnSum = static_cast<uint64_t>(opShape.m) + opShape.n;
+    // Guard m + n == 0 explicitly instead of relying on 0/0 producing NaN; the flag stays 0.
+    const float computeIntensity = (mnSum == 0) ? 0.0f : static_cast<float>(mnProduct) / static_cast<float>(mnSum);
+    enUnitFlag = static_cast<uint32_t>(computeIntensity > MIN_COMPUTE_INTENSITY);
+    return;
+}
+
 void GetPpMatmulTiling(const MatMulInfo &mmInfo, const HardwareInfo &hwInfo, uint32_t &blockDim,
                        PpTilingData &tilingData)
 {
@@ -147,6 +161,7 @@ void GetPpMatmulTiling(const MatMulInfo &mmInfo, const HardwareInfo &hwInfo, uin
     uint32_t direct = Swizzl(tilingData);
     blockDim = tilingData.End(mmInfo);
     tilingData.SetTilingKey(mmInfo, direct, 0);
+    tilingData.SetUnitFlagSwitch();
 }

 void PrintPpMatmulTiling(const KernelInfo &kernelInfo)
@@ -166,6 +181,7 @@ void PrintPpMatmulTiling(const KernelInfo &kernelInfo)
     MKI_LOG(INFO) << "swizzlDirect = " << tilingData->swizzlDirect;
     MKI_LOG(INFO) << "enShuffleK = " << tilingData->enShuffleK;
     MKI_LOG(INFO) << "quantMode = " << tilingData->quantMode;
+    MKI_LOG(INFO) << "enUnitFlag = " << tilingData->enUnitFlag;
     return;
 }

diff --git a/src/kernels/kernels/matmul/tiling/tiling_data.h b/src/kernels/kernels/matmul/tiling/tiling_data.h
index 57fbb278..885f02c0 100644
--- a/src/kernels/kernels/matmul/tiling/tiling_data.h
+++ b/src/kernels/kernels/matmul/tiling/tiling_data.h
@@ -32,6 +32,7 @@ struct PpMatmulTilingData {
     uint32_t splitk{0};
     uint32_t enShuffleK{0};
     uint32_t quantMode{0};
+    uint32_t enUnitFlag{0};
 };

 struct GemvMatmulTilingData {
--
Gitee