diff --git a/impl/matmul/tiling/matmul_tiling_algorithm.cpp b/impl/matmul/tiling/matmul_tiling_algorithm.cpp index c00919461081615599ca8f807d5fc95bdda28c7a..7d1fa79f4868241409239844f85b1673ae6abcb2 100644 --- a/impl/matmul/tiling/matmul_tiling_algorithm.cpp +++ b/impl/matmul/tiling/matmul_tiling_algorithm.cpp @@ -2283,7 +2283,7 @@ bool MatmulTilingAlgorithm::CheckFinaleParams(const CoreStatusPack& coreStatus) } int dateDtypeSize = DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType); - if (tilingIns_->tiling_.get_BatchNum() > 0 && + if (!tilingIns_->isBMNKBmm && tilingIns_->tiling_.get_BatchNum() > 0 && ((tilingIns_->tiling_.get_singleCoreM() * tilingIns_->tiling_.get_singleCoreK() + tilingIns_->tiling_.get_singleCoreN() * tilingIns_->tiling_.get_singleCoreK()) * tilingIns_->tiling_.get_BatchNum() * dateDtypeSize / BITS_PER_BYTE + diff --git a/impl/matmul/tiling/matmul_tiling_base.cpp b/impl/matmul/tiling/matmul_tiling_base.cpp index c4011e1ab00725c44271fae3583b245d967cb587..40dadd8388732cc5ff006932d2664f01e11102d6 100644 --- a/impl/matmul/tiling/matmul_tiling_base.cpp +++ b/impl/matmul/tiling/matmul_tiling_base.cpp @@ -401,6 +401,7 @@ int32_t MatmulApiTilingBase::SetBatchInfoForNormal(int32_t batchA, int32_t batch this->cLayoutInfoN = 1; this->cLayoutInfoG = 1; this->cLayoutInfoS2 = n; + this->isBMNKBmm = true; return 0; } diff --git a/lib/matmul/matmul_tiling_base.h b/lib/matmul/matmul_tiling_base.h index b6708ce5ee31cc1af0bdb5de6e443f9f44166940..ccca8905dce43145dabc96fdf21c5cf69cf99136 100644 --- a/lib/matmul/matmul_tiling_base.h +++ b/lib/matmul/matmul_tiling_base.h @@ -335,6 +335,7 @@ public: int32_t mmConfigType = 1; // 0: Norm; 1: MDL bool enableL1CacheUB = false; bool enVecND2NZ = false; + bool isBMNKBmm = false; protected: virtual int64_t Compute() = 0;