diff --git a/src/kernels/kernels/matmul/tiling/pp_matmul_info.h b/src/kernels/kernels/matmul/tiling/pp_matmul_info.h
index dd6c36f15d40dd6b4fcc14dc124625fd80c24150..099a83589ff0a5270180c9ad57624ae11980b005 100644
--- a/src/kernels/kernels/matmul/tiling/pp_matmul_info.h
+++ b/src/kernels/kernels/matmul/tiling/pp_matmul_info.h
@@ -181,10 +181,12 @@ struct PpTilingData {
     uint32_t splitk{0};
     uint32_t enShuffleK{0};
     uint32_t quantMode{0};
+    uint32_t enUnitFlag{0};
 
     void SetBaseShape(uint32_t batchSize, uint32_t m, uint32_t k, uint32_t n);
     void SetBaseOp(uint32_t coreNum, uint32_t mBase, uint32_t nBase, const MatMulInfo &mmInfo);
     void SetTilingKey(const MatMulInfo &mmInfo, uint32_t swizzleDirect, uint32_t enSplitK);
+    void SetUnitFlagSwitch();
     uint32_t End(const MatMulInfo &mmInfo);
 };
 } // namespace AsdOps
diff --git a/src/kernels/kernels/matmul/tiling/pp_matmul_tiling.cpp b/src/kernels/kernels/matmul/tiling/pp_matmul_tiling.cpp
index 6ee0dc68984f8c9f3377275ce81634bfa62f1cb7..99d140773130e443559fb7f1ea3c789cb855f121 100644
--- a/src/kernels/kernels/matmul/tiling/pp_matmul_tiling.cpp
+++ b/src/kernels/kernels/matmul/tiling/pp_matmul_tiling.cpp
@@ -25,6 +25,7 @@ constexpr uint32_t CONST_16 = 16;
 constexpr uint32_t CONST_32 = 32;
 constexpr uint32_t CONST_256 = 256;
 constexpr uint32_t CONST_512 = 512;
+constexpr float MIN_COMPUTE_INTENSITY = 2000.0f; // m * n / (m + n) must exceed this for enUnitFlag to be set
 
 const std::map<TensorDType, uint32_t> G_DTYPE_MAP = {
     {TENSOR_DTYPE_INT8, 0u}, {TENSOR_DTYPE_FLOAT16, 1u}, {TENSOR_DTYPE_BF16, 2u}, {TENSOR_DTYPE_FLOAT, 3u}};
@@ -128,6 +129,17 @@ uint32_t PpTilingData::End(const MatMulInfo &mmInfo)
     return blockDim;
 }
 
+// Sets enUnitFlag to 1 when the shape's m * n / (m + n) ratio exceeds MIN_COMPUTE_INTENSITY, otherwise to 0.
+void PpTilingData::SetUnitFlagSwitch()
+{
+    // Widen to 64 bits before multiplying/adding so that large shapes cannot wrap around in 32-bit arithmetic.
+    const uint64_t product = static_cast<uint64_t>(opShape.m) * static_cast<uint64_t>(opShape.n);
+    const uint64_t sum = static_cast<uint64_t>(opShape.m) + static_cast<uint64_t>(opShape.n);
+    const float computeIntensity = (sum == 0) ? 0.0f : static_cast<float>(product) / static_cast<float>(sum);
+    enUnitFlag = static_cast<uint32_t>(computeIntensity > MIN_COMPUTE_INTENSITY);
+    return;
+}
+
 void GetPpMatmulTiling(const MatMulInfo &mmInfo, const HardwareInfo &hwInfo, uint32_t &blockDim,
                        PpTilingData &tilingData)
 {
@@ -147,6 +159,7 @@ void GetPpMatmulTiling(const MatMulInfo &mmInfo, const HardwareInfo &hwInfo, uin
     uint32_t direct = Swizzl(tilingData);
     blockDim = tilingData.End(mmInfo);
     tilingData.SetTilingKey(mmInfo, direct, 0);
+    tilingData.SetUnitFlagSwitch();
 }
 
 void PrintPpMatmulTiling(const KernelInfo &kernelInfo)
@@ -166,6 +179,7 @@ void PrintPpMatmulTiling(const KernelInfo &kernelInfo)
     MKI_LOG(INFO) << "swizzlDirect = " << tilingData->swizzlDirect;
     MKI_LOG(INFO) << "enShuffleK = " << tilingData->enShuffleK;
     MKI_LOG(INFO) << "quantMode = " << tilingData->quantMode;
+    MKI_LOG(INFO) << "enUnitFlag = " << tilingData->enUnitFlag;
     return;
 }
 
diff --git a/src/kernels/kernels/matmul/tiling/tiling_data.h b/src/kernels/kernels/matmul/tiling/tiling_data.h
index 57fbb2785510089517e562041c29d43109df90ff..885f02c09cde701f710cfe8f9efdaadc2bb65186 100644
--- a/src/kernels/kernels/matmul/tiling/tiling_data.h
+++ b/src/kernels/kernels/matmul/tiling/tiling_data.h
@@ -32,6 +32,7 @@ struct PpMatmulTilingData {
     uint32_t splitk{0};
     uint32_t enShuffleK{0};
     uint32_t quantMode{0};
+    uint32_t enUnitFlag{0};
 };
 
 struct GemvMatmulTilingData {