diff --git a/impl/matmul/scheduler/iterator/batch_loop/batch_loop_multi.h b/impl/matmul/scheduler/iterator/batch_loop/batch_loop_multi.h index eb99a7ae990f1368b3ffa1bc216f3c83afa6f4d9..e84f56f51d026d2dd426af7119a3dfe75942a585 100644 --- a/impl/matmul/scheduler/iterator/batch_loop/batch_loop_multi.h +++ b/impl/matmul/scheduler/iterator/batch_loop/batch_loop_multi.h @@ -74,7 +74,7 @@ public: if (newProcess_ && outerIdx_ == batchOuter_ - 1) { const int32_t tail = inputBatchNum_ % batchA_; batchA_ = tail == 0 ? mainBatchInner_ : tail; - batchB_ = tail == 0 ? mainBatchInner_ : tail; + batchB_ = batchA_; batchNum_ = batchA_; batchCalcSize_ = batchNum_ * MATMUL_MODULE(MatmulShapeInfo)->GetSingleCoreM() * MATMUL_MODULE(MatmulShapeInfo)->GetSingleCoreN(); @@ -120,7 +120,7 @@ public: __aicore__ inline int32_t GetBiasBatchSrcOffset() const { - return outerIdx_ * batchNum_ * MATMUL_MODULE(MatmulShapeInfo)->GetSingleCoreN(); + return outerIdx_ * mainBatchInner_ * MATMUL_MODULE(MatmulShapeInfo)->GetSingleCoreN(); } // Double Buffer Loop @@ -237,6 +237,7 @@ private: (batchNumA % batchNumB == 0 || batchNumB % batchNumA == 0)); batchA_ = batchNumA; batchB_ = batchNumB; + mainBatchInner_ = batchA_; return; } @@ -248,6 +249,7 @@ private: batchOuter_ = 1; batchA_ = layoutBatchNumA; batchB_ = layoutBatchNumB; + mainBatchInner_ = batchA_; return; } if (layoutBatchNumA >= layoutBatchNumB) { @@ -300,7 +302,7 @@ private: inputBatchNum_ = batchNumLarge; ASSERT(batchInner > 0); - newProcess_ = (multiples == 1) && (inputBatchNum_ % DB_FACTOR != 0) && (inputBatchNum_ >= batchInner); + newProcess_ = (multiples == 1) && (inputBatchNum_ % DB_FACTOR != 0); if (newProcess_) { mainBatchInner_ = batchInner; batchOuter_ = CeilT(batchNumLess, batchInner); @@ -320,12 +322,12 @@ private: __aicore__ inline void UpdateBatchNumParams() { batchNum_ = batchA_ > batchB_ ? batchA_ : batchB_; - if (!newProcess_ || batchA_ != batchB_) { - splitSize_ = (batchNum_ >= DB_FACTOR) && (batchA_ % DB_FACTOR == 0) && - (batchB_ % DB_FACTOR == 0) ? DB_FACTOR : 1; + if (batchOuter_ > 1 && batchA_ == batchB_) { + splitSize_ = (batchA_ >= DB_FACTOR) ? DB_FACTOR : 1; splitBatchNum_ = batchNum_ / splitSize_; } else { - splitSize_ = (batchA_ >= DB_FACTOR) ? DB_FACTOR : 1; + splitSize_ = (batchNum_ >= DB_FACTOR) && (batchA_ % DB_FACTOR == 0) && + (batchB_ % DB_FACTOR == 0) ? DB_FACTOR : 1; splitBatchNum_ = batchNum_ / splitSize_; } } diff --git a/impl/matmul/stage/copy_cube_in/batch/batch_copy_cube_in.h b/impl/matmul/stage/copy_cube_in/batch/batch_copy_cube_in.h index c9e86c4975afa8020f10c41c5c8a5cb61e0f537c..d4047712a53c3e13f680f36628d8d4450530abf6 100644 --- a/impl/matmul/stage/copy_cube_in/batch/batch_copy_cube_in.h +++ b/impl/matmul/stage/copy_cube_in/batch/batch_copy_cube_in.h @@ -255,19 +255,21 @@ private: auto alignWidth = CeilAlign(MATMUL_MODULE(BatchCopyCubeInParams)->template GetSingleWidth(), c0Size_); // 2. Calculate src and dst stride of one step - auto batchNum = MATMUL_MODULE(BatchCopyCubeInParams)->GetBatchNum() / splitSize; + auto batchNum = MATMUL_MODULE(BatchCopyCubeInParams)->GetBatchNum(); + int32_t tmpBatchNum = batchNum / splitSize; int64_t srcStride = alignWidth * alignHeight; int64_t dstStride = MATMUL_MODULE(BatchCopyCubeInParams)->template GetSingleSizeAlign(); - int64_t srcOffset = batchNum * splitIdx * srcStride; - int64_t dstOffset = batchNum * splitIdx * dstStride; + int64_t srcOffset = tmpBatchNum * splitIdx * srcStride; + int64_t dstOffset = tmpBatchNum * splitIdx * dstStride; + auto paramsBatchNum = splitIdx == 0 ? tmpBatchNum : batchNum - tmpBatchNum; // 3. loop copy NZ data by batch bool iskRowDirec = IS_KROW && IsSupportB8(); - auto batchOffset = outerIdx * MATMUL_MODULE(BatchCopyCubeInParams)->GetBatchNum() * srcStride; + auto batchOffset = outerIdx * MATMUL_MODULE(BatchCopyCubeInParams)->GetBatchMainBlock() * srcStride; GlobalTensor srcGlobal; srcGlobal.SetGlobalBuffer(MATMUL_MODULE(MatmulTensorInfo)->GetGlobalTensor().address_); srcGlobal.SetAddr(batchOffset); - for (int i = 0; i < batchNum; ++i) { + for (int i = 0; i < paramsBatchNum; ++i) { MATMUL_MODULE(DataCopyWrapper)->CopyNZ2NZ( dstTensor[dstOffset], srcGlobal[srcOffset], 0, 0, alignHeight, alignWidth, alignHeight, iskRowDirec); diff --git a/tests/matmul/iterator/test_batch_loop.cpp b/tests/matmul/iterator/test_batch_loop.cpp index 7f7418055335952c8fccc8870aa9bb87632ca18d..4c5ae929572793ac55e41f53c923202ba778c97d 100644 --- a/tests/matmul/iterator/test_batch_loop.cpp +++ b/tests/matmul/iterator/test_batch_loop.cpp @@ -117,7 +117,7 @@ TEST_F(TestBatchLoop, batch_loop) { EXPECT_EQ(mm.GetOuterIndex(), 1); EXPECT_EQ(mm.GetDstOffset(), 33264); EXPECT_EQ(mm.GetBatchNum(), 3); - EXPECT_EQ(mm.GetBiasBatchSrcOffset(), 231); + EXPECT_EQ(mm.GetBiasBatchSrcOffset(), 0); EXPECT_EQ(mm.GetSplitIndex(), 1); EXPECT_EQ(mm.GetSplitSize(), 1); EXPECT_EQ(mm.GetSplitBatchNum(), 3); @@ -137,7 +137,7 @@ TEST_F(TestBatchLoop, batch_loop_db) { EXPECT_EQ(mm1.GetOuterIndex(), 1); EXPECT_EQ(mm1.GetDstOffset(), 49152); EXPECT_EQ(mm1.GetBatchNum(), 6); - EXPECT_EQ(mm1.GetBiasBatchSrcOffset(), 1536); + EXPECT_EQ(mm1.GetBiasBatchSrcOffset(), 0); EXPECT_EQ(mm1.GetSplitIndex(), 2); EXPECT_EQ(mm1.GetSplitSize(), 2); EXPECT_EQ(mm1.GetSplitBatchNum(), 3);