diff --git a/impl/matmul/tiling/matmul_tiling_algorithm.cpp b/impl/matmul/tiling/matmul_tiling_algorithm.cpp index a4f9b40d84cda0c55af91ebd61262aa074484eec..208654488bcfdbaab94590ed7b9a416c741fed1d 100644 --- a/impl/matmul/tiling/matmul_tiling_algorithm.cpp +++ b/impl/matmul/tiling/matmul_tiling_algorithm.cpp @@ -450,13 +450,15 @@ void MatmulTilingAlgorithm::GetL0FactorsCand(L0Factors& resFactors, const CoreSt if (k0 == 0 || minorDimFactor * k0 > minorDimK || majorDimFactor * k0 > majorDimK) { continue; } + // when k_axis <= 8, adjust k0_factor to 2, to generate tiling baseK align to 16 - if (((tilingIns_->aType_.dataType == DataType::DT_FLOAT && tilingIns_->aType_.type == CubeFormat::NZ && + if (((tilingIns_->aType_.dataType == DataType::DT_FLOAT && tilingIns_->aType_.isTrans) || - (tilingIns_->bType_.dataType == DataType::DT_FLOAT && tilingIns_->bType_.type == CubeFormat::NZ && + (tilingIns_->bType_.dataType == DataType::DT_FLOAT && !tilingIns_->bType_.isTrans)) && k0 == 1) { k0 = NUM_TWO; } + // Check if the buffer size allocated exceed the hardware buffer size in Float Mode if (tilingIns_->aType_.dataType == DataType::DT_FLOAT) { int32_t mL0 = majorDimFactor; @@ -2982,13 +2984,13 @@ void MatmulTilingAlgorithm::CheckL0DB(SingleCoreStatus& singleCoreStatus, const tilingIns_->bType_.scalePos == TPosition::TSCM) { baseN = MathUtil::Align(singleCoreStatus.l0Status.nL0, L0_FACTOR_NUM_LIMIT) * C0_SIZE; } - if (baseM * baseK > tilingIns_->bufferPool_.l0ASize / DB_ON) { + if ((baseM * baseK * DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType) / BITS_PER_BYTE) > (tilingIns_->bufferPool_.l0ASize / DB_ON)) { singleCoreStatus.l0Status.dbL0A = DB_OFF; } - if (baseN * baseK > tilingIns_->bufferPool_.l0BSize / DB_ON) { + if ((baseN * baseK * DTYPE_BIT_TAB.at(tilingIns_->bType_.dataType) / BITS_PER_BYTE) > (tilingIns_->bufferPool_.l0BSize / DB_ON)) { singleCoreStatus.l0Status.dbL0B = DB_OFF; } - if (baseM * baseN > tilingIns_->bufferPool_.l0CSize / DB_ON) { + if ((baseM * baseN * FP32_BYTES) > (tilingIns_->bufferPool_.l0CSize / DB_ON)) { singleCoreStatus.l0Status.dbL0C = DB_OFF; } } @@ -3093,10 +3095,14 @@ void MatmulTilingAlgorithm::AdjustMxL0Factors(SingleCoreStatus& singleCoreStatus void MatmulTilingAlgorithm::AdjustB32L0Factors(SingleCoreStatus& singleCoreStatus) const { const int32_t reduceSize = C0_BYTE_SIZE / DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType) * BITS_PER_BYTE; - if ((tilingIns_->aType_.dataType == DataType::DT_FLOAT && tilingIns_->aType_.isTrans) || - (tilingIns_->bType_.dataType == DataType::DT_FLOAT && !tilingIns_->bType_.isTrans)) { - singleCoreStatus.l0Status.kL0 = MathUtil::Align(singleCoreStatus.l0Status.kL0 * reduceSize, C0_SIZE) / reduceSize; - } + if ((tilingIns_->aType_.dataType == DataType::DT_FLOAT && tilingIns_->aType_.isTrans) || + (tilingIns_->bType_.dataType == DataType::DT_FLOAT && !tilingIns_->bType_.isTrans)) { + singleCoreStatus.l0Status.kL0 = MathUtil::Align(singleCoreStatus.l0Status.kL0 * reduceSize, C0_SIZE) / reduceSize; + } + int32_t baseK = + singleCoreStatus.l0Status.kL0 * (C0_BYTE_SIZE / DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType) * BITS_PER_BYTE); + // check L0A/L0B/L0CSize for L0DB + CheckL0DB(singleCoreStatus, baseK); } void MatmulTilingAlgorithm::AdjustMxL1Factors(SingleCoreStatus& singleCoreStatus, const int32_t k0Size) const