diff --git a/impl/matmul/scheduler/iterator/batch_loop/batch_loop_multi.h b/impl/matmul/scheduler/iterator/batch_loop/batch_loop_multi.h index 1a8379e17fe6649c726c52bc0d56ae326e12d9fe..48f391b0ebd676aa983523ed88e23561779fd2ed 100644 --- a/impl/matmul/scheduler/iterator/batch_loop/batch_loop_multi.h +++ b/impl/matmul/scheduler/iterator/batch_loop/batch_loop_multi.h @@ -71,6 +71,16 @@ public: { outerIdx_++; dstOffset_ += batchCalcSize_; + if (newProcess_ && outerIdx_ == batchOuter_ - 1) { + const int32_t tail = inputBatchNum_ % batchA_; + batchA_ = tail == 0 ? mainBatchInner_ : tail; + batchB_ = batchA_; + batchNum_ = batchA_; + batchCalcSize_ = batchNum_ * MATMUL_MODULE(MatmulShapeInfo)->GetSingleCoreM() * + MATMUL_MODULE(MatmulShapeInfo)->GetSingleCoreN(); + splitSize_ = (batchA_ >= DB_FACTOR) ? DB_FACTOR : 1; + splitBatchNum_ = batchNum_ / splitSize_; + } } __aicore__ inline bool OuterEnd() { @@ -78,6 +88,11 @@ public: return outerIdx_ >= batchOuter_; } + __aicore__ inline int32_t GetMainBatchBlock() const + { + return mainBatchInner_; // batchNum main block in outLoop + } + __aicore__ inline uint32_t GetOuterIndex() const { return outerIdx_; @@ -105,7 +120,7 @@ public: __aicore__ inline int32_t GetBiasBatchSrcOffset() const { - return outerIdx_ * batchNum_ * MATMUL_MODULE(MatmulShapeInfo)->GetSingleCoreN(); + return outerIdx_ * mainBatchInner_ * MATMUL_MODULE(MatmulShapeInfo)->GetSingleCoreN(); } // Double Buffer Loop @@ -161,7 +176,15 @@ public: __aicore__ inline bool InnerEnd() { - return innerIdx_ >= splitBatchNum_ || splitOuterIdx_ * splitBatchNum_ >= batchNum_; + if ((!newProcess_) || (batchNum_ % DB_FACTOR == 0) || (splitSize_ < DB_FACTOR)) { + return (innerIdx_ >= splitBatchNum_) || (splitOuterIdx_ * splitBatchNum_ >= batchNum_); + } + const auto firstBatchNum = batchNum_ / splitSize_; + if (splitOuterIdx_ < 1) { + return innerIdx_ >= firstBatchNum; + } else { + return innerIdx_ >= batchNum_ - firstBatchNum; + } } __aicore__ inline uint32_t GetInnerIndex() 
const @@ -274,26 +297,42 @@ private: int32_t multiples = batchNumLarge / batchNumLess; int32_t singleBatchSize = multiples * largeMatrixSingleBatchSize + lessMatrixSingleBatchSize; int32_t batchInner = TOTAL_L1_SIZE / singleBatchSize; + inputBatchNum_ = batchNumLarge; + ASSERT(batchInner > 0); - while (batchNumLess % batchInner != 0 && batchInner > 0) { - --batchInner; + newProcess_ = (multiples == 1) && (inputBatchNum_ % DB_FACTOR != 0); + if (newProcess_) { + mainBatchInner_ = batchInner; + batchOuter_ = CeilT(batchNumLess, batchInner); + batchA_ = batchInner; + batchB_ = batchInner; + } else { + while (batchNumLess % batchInner != 0 && batchInner > 0) { + --batchInner; + } + mainBatchInner_ = batchInner; + batchOuter_ = batchNumLess / batchInner; + batchA_ = multiples * batchInner; + batchB_ = batchInner; } - batchOuter_ = batchNumLess / batchInner; - batchA_ = multiples * batchInner; - batchB_ = batchInner; } __aicore__ inline void UpdateBatchNumParams() { batchNum_ = batchA_ > batchB_ ? batchA_ : batchB_; - splitSize_ = (batchNum_ >= DB_FACTOR) && (batchA_ % DB_FACTOR == 0) && - (batchB_ % DB_FACTOR == 0) ? DB_FACTOR : 1; - splitBatchNum_ = batchNum_ / splitSize_; + if (batchOuter_ > 1 && batchA_ == batchB_) { + splitSize_ = (batchA_ >= DB_FACTOR) ? DB_FACTOR : 1; + splitBatchNum_ = batchNum_ / splitSize_; + } else { + splitSize_ = (batchNum_ >= DB_FACTOR) && (batchA_ % DB_FACTOR == 0) && + (batchB_ % DB_FACTOR == 0) ? 
DB_FACTOR : 1; + splitBatchNum_ = batchNum_ / splitSize_; + } } __aicore__ inline void UpdateSplitParams() { - splitBatchIdx_ += batchNum_ / splitSize_; + splitBatchIdx_ += splitBatchNum_; } __aicore__ inline void UpdateInnerParams() @@ -301,9 +340,9 @@ private: innerBatchIdx_ = innerIdx_ + splitBatchIdx_; } - int32_t batchA_; - int32_t batchB_; - int32_t batchNum_; + int32_t batchA_; // outerLoop main/tail block + int32_t batchB_; // outerLoop main/tail block + int32_t batchNum_; // outerLoop main/tail block int32_t batchOuter_ = 1; constexpr static int32_t c0Size_ = AuxGetC0Size(); @@ -327,6 +366,10 @@ private: int32_t nBatchOutNum_ = 1; int32_t batchOutCacheNum_ = 0; int32_t batchOutOffsetNum_ = 0; + + int32_t inputBatchNum_ = 0; + bool newProcess_ = false; // new logical judgment condition for handling odd batchNum + int32_t mainBatchInner_ = 0; // outerLoop main block }; } // namespace Detail } // namespace Impl diff --git a/impl/matmul/stage/copy_cube_in/batch/batch_copy_cube_in.h b/impl/matmul/stage/copy_cube_in/batch/batch_copy_cube_in.h index 809aab80f7a98e406d164163f4f05a8cc3864b72..31edec87ff4b9d429dc911f56efb766f6b4df778 100644 --- a/impl/matmul/stage/copy_cube_in/batch/batch_copy_cube_in.h +++ b/impl/matmul/stage/copy_cube_in/batch/batch_copy_cube_in.h @@ -75,13 +75,13 @@ private: // Calculate batch outer loop offset // the parameter false means don't need to use constant parameters int64_t batchOffset = outerIdx * GetSingleSize() * - MATMUL_MODULE(BatchCopyCubeInParams)->GetBatchNum(); + MATMUL_MODULE(BatchCopyCubeInParams)->GetBatchMainBlock(); // Calculate iter numbers by line of BSNGD layout int32_t batchNum = MATMUL_MODULE(BatchCopyCubeInParams)->GetBatchNum(); // batchA_ or batchB_ int32_t iterNum = 1; + int32_t tmpBatchNum = batchNum / splitSize; UpdataBatchNum(batchNum, iterNum); - batchNum /= splitSize; // Calculate srcDValue for ND copy auto srcDValue = MATMUL_MODULE(BatchCopyCubeInParams)->template GetBatchOrgWidth(); @@ -90,12 +90,13 @@ 
private: // if user input matrixStride, use matrixStride as srcStride auto srcStride = matrixStride != 0 ? matrixStride : GetSrcStride(); auto dstStride = MATMUL_MODULE(BatchCopyCubeInParams)->template GetSingleSizeAlign(); - int64_t srcOffset = batchNum * splitIdx * srcStride; - int64_t dstOffset = batchNum * splitIdx * dstStride; + int64_t srcOffset = tmpBatchNum * splitIdx * srcStride; + int64_t dstOffset = tmpBatchNum * splitIdx * dstStride; + auto paramsBatchNum = splitIdx == 0 ? tmpBatchNum : batchNum - tmpBatchNum; // Calculate src and dst stride of one line - auto iterSrcStride = batchNum * GetSingleSize(); - auto iterDstStride = batchNum * GetSingleSize(); + auto iterSrcStride = paramsBatchNum * GetSingleSize(); + auto iterDstStride = paramsBatchNum * GetSingleSize(); // Complete datacopy by line GlobalTensor srcGlobal; @@ -103,7 +104,7 @@ private: srcGlobal.SetAddr(batchOffset); for (int32_t idx = 0; idx < iterNum; ++idx) { if (srcStride >= UINT16_MAX) { - for (int i = 0; i < batchNum; ++i) { + for (int i = 0; i < paramsBatchNum; ++i) { MATMUL_MODULE(DataCopyWrapper)->CopyND2NZ( dstTensor[dstOffset], srcGlobal[srcOffset], 0, 0, MATMUL_MODULE(BatchCopyCubeInParams)->template GetSingleHeight(), @@ -116,7 +117,7 @@ private: dstTensor[dstOffset], srcGlobal[srcOffset], 0, 0, MATMUL_MODULE(BatchCopyCubeInParams)->template GetSingleHeight(), MATMUL_MODULE(BatchCopyCubeInParams)->template GetSingleWidth(), - srcDValue, batchNum, srcStride, dstStride); + srcDValue, paramsBatchNum, srcStride, dstStride); } dstOffset += iterDstStride; srcOffset += iterSrcStride; @@ -254,19 +255,22 @@ private: auto alignWidth = CeilAlign(MATMUL_MODULE(BatchCopyCubeInParams)->template GetSingleWidth(), c0Size_); // 2. 
Calculate src and dst stride of one step - auto batchNum = MATMUL_MODULE(BatchCopyCubeInParams)->GetBatchNum() / splitSize; + auto batchNum = MATMUL_MODULE(BatchCopyCubeInParams)->GetBatchNum(); + int32_t tmpBatchNum = batchNum / splitSize; + int64_t srcStride = alignWidth * alignHeight; int64_t dstStride = MATMUL_MODULE(BatchCopyCubeInParams)->template GetSingleSizeAlign(); - int64_t srcOffset = batchNum * splitIdx * srcStride; - int64_t dstOffset = batchNum * splitIdx * dstStride; + int64_t srcOffset = tmpBatchNum * splitIdx * srcStride; + int64_t dstOffset = tmpBatchNum * splitIdx * dstStride; + auto paramsBatchNum = splitIdx == 0 ? tmpBatchNum : batchNum - tmpBatchNum; // 3. loop copy NZ data by batch bool iskRowDirec = IS_KROW && IsSupportB8(); - auto batchOffset = outerIdx * MATMUL_MODULE(BatchCopyCubeInParams)->GetBatchNum() * srcStride; + auto batchOffset = outerIdx * MATMUL_MODULE(BatchCopyCubeInParams)->GetBatchMainBlock() * srcStride; GlobalTensor srcGlobal; srcGlobal.SetGlobalBuffer(MATMUL_MODULE(MatmulTensorInfo)->GetGlobalTensor().address_); srcGlobal.SetAddr(batchOffset); - for (int i = 0; i < batchNum; ++i) { + for (int i = 0; i < paramsBatchNum; ++i) { MATMUL_MODULE(DataCopyWrapper)->CopyNZ2NZ( dstTensor[dstOffset], srcGlobal[srcOffset], 0, 0, alignHeight, alignWidth, alignHeight, iskRowDirec); diff --git a/impl/matmul/stage/copy_cube_in/batch/batch_copy_cube_in_params.h b/impl/matmul/stage/copy_cube_in/batch/batch_copy_cube_in_params.h index fb2cfe19c6ecc0f8e4af10b21643a1aa60198ad9..6e3a04a93a165ec5d8f0856344b6d71a9f76409a 100644 --- a/impl/matmul/stage/copy_cube_in/batch/batch_copy_cube_in_params.h +++ b/impl/matmul/stage/copy_cube_in/batch/batch_copy_cube_in_params.h @@ -36,6 +36,11 @@ public: } } + __aicore__ inline uint32_t GetBatchMainBlock() + { + return MATMUL_MODULE(BatchLoop)->GetMainBatchBlock(); + } + template __aicore__ inline int32_t GetBatchOrgWidth() {