diff --git a/impl/matmul/scheduler/iterator/batch_loop/batch_loop_multi.h b/impl/matmul/scheduler/iterator/batch_loop/batch_loop_multi.h index 684823a87c8729b4aeba624074aa568adc220c63..5b204767d270029f7bc8059f62d7c37385cb9f08 100644 --- a/impl/matmul/scheduler/iterator/batch_loop/batch_loop_multi.h +++ b/impl/matmul/scheduler/iterator/batch_loop/batch_loop_multi.h @@ -81,7 +81,7 @@ public: if (newProcess_ && outerIdx_ == batchOuter_ - 1) { const int32_t tail = inputBatchNum_ % batchA_; batchA_ = tail == 0 ? mainBatchInner_ : tail; - batchB_ = tail == 0 ? mainBatchInner_ : tail; + batchB_ = batchA_; batchNum_ = batchA_; batchCalcSize_ = batchNum_ * MATMUL_MODULE(MatmulShapeInfo)->GetSingleCoreM() * MATMUL_MODULE(MatmulShapeInfo)->GetSingleCoreN(); @@ -143,7 +143,7 @@ public: __aicore__ inline int32_t GetBiasBatchSrcOffset() const { - return outerIdx_ * batchNum_ * MATMUL_MODULE(MatmulShapeInfo)->GetSingleCoreN(); + return outerIdx_ * mainBatchInner_ * MATMUL_MODULE(MatmulShapeInfo)->GetSingleCoreN(); } // Double Buffer Loop @@ -290,6 +290,7 @@ private: (batchNumA % batchNumB == 0 || batchNumB % batchNumA == 0)); batchA_ = batchNumA; batchB_ = batchNumB; + mainBatchInner_ = batchA_; return; } @@ -301,6 +302,7 @@ private: batchOuter_ = 1; batchA_ = layoutBatchNumA; batchB_ = layoutBatchNumB; + mainBatchInner_ = batchA_; return; } if (layoutBatchNumA >= layoutBatchNumB) { @@ -353,7 +355,7 @@ private: inputBatchNum_ = batchNumLarge; ASSERT(batchInner > 0); - newProcess_ = (multiples == 1) && (inputBatchNum_ % DB_FACTOR != 0) && (inputBatchNum_ >= batchInner); + newProcess_ = (multiples == 1) && (inputBatchNum_ % DB_FACTOR != 0); if (newProcess_) { mainBatchInner_ = batchInner; batchOuter_ = CeilT(batchNumLess, batchInner); @@ -374,12 +376,12 @@ private: { batchNum_ = batchA_ > batchB_ ? batchA_ : batchB_; if constexpr (!IsBmmDoubleBuffer()) { - if (!newProcess_ || batchA_ != batchB_) { - splitSize_ = (batchNum_ >= DB_FACTOR) && (batchA_ % DB_FACTOR == 0) && - (batchB_ % DB_FACTOR == 0) ? DB_FACTOR : 1; + if (batchOuter_ > 1 && batchA_ == batchB_) { + splitSize_ = (batchA_ >= DB_FACTOR) ? DB_FACTOR : 1; splitBatchNum_ = batchNum_ / splitSize_; } else { - splitSize_ = (batchA_ >= DB_FACTOR) ? DB_FACTOR : 1; + splitSize_ = (batchNum_ >= DB_FACTOR) && (batchA_ % DB_FACTOR == 0) && + (batchB_ % DB_FACTOR == 0) ? DB_FACTOR : 1; splitBatchNum_ = batchNum_ / splitSize_; } } diff --git a/impl/matmul/stage/copy_cube_in/batch/batch_copy_cube_in.h b/impl/matmul/stage/copy_cube_in/batch/batch_copy_cube_in.h index fa2adcbd91ed331d097e946ecef80c9fd6ad40b5..2969c0abc33a93248707c117e7167f670999518b 100644 --- a/impl/matmul/stage/copy_cube_in/batch/batch_copy_cube_in.h +++ b/impl/matmul/stage/copy_cube_in/batch/batch_copy_cube_in.h @@ -362,19 +362,22 @@ private: auto alignWidth = CeilAlign(MATMUL_MODULE(BatchCopyCubeInParams)->template GetSingleWidth(), c0Size_); // 2. Calculate src and dst stride of one step - auto batchNum = MATMUL_MODULE(BatchCopyCubeInParams)->GetBatchNum() / splitSize; + auto batchNum = MATMUL_MODULE(BatchCopyCubeInParams)->GetBatchNum(); + int32_t tmpBatchNum = batchNum / splitSize; + int64_t srcStride = alignWidth * alignHeight; int64_t dstStride = MATMUL_MODULE(BatchCopyCubeInParams)->template GetSingleSizeAlign(); - int64_t srcOffset = batchNum * splitIdx * srcStride; - int64_t dstOffset = batchNum * splitIdx * dstStride; + int64_t srcOffset = tmpBatchNum * splitIdx * srcStride; + int64_t dstOffset = tmpBatchNum * splitIdx * dstStride; + auto paramsBatchNum = splitIdx == 0 ? tmpBatchNum : batchNum - tmpBatchNum; // 3. loop copy NZ data by batch bool iskRowDirec = IS_KROW && IsSupportB8(); - auto batchOffset = outerIdx * MATMUL_MODULE(BatchCopyCubeInParams)->GetBatchNum() * srcStride; + auto batchOffset = outerIdx * MATMUL_MODULE(BatchCopyCubeInParams)->GetBatchMainBlock() * srcStride; GlobalTensor srcGlobal; srcGlobal.SetGlobalBuffer(MATMUL_MODULE(MatmulTensorInfo)->GetGlobalTensor().address_); srcGlobal.SetAddr(batchOffset); - for (int i = 0; i < batchNum; ++i) { + for (int i = 0; i < paramsBatchNum; ++i) { MATMUL_MODULE(DataCopyWrapper)->CopyNZ2NZ( dstTensor[dstOffset], srcGlobal[srcOffset], 0, 0, alignHeight, alignWidth, alignHeight, iskRowDirec);