diff --git a/impl/matmul/policy/matmul_private_modules.h b/impl/matmul/policy/matmul_private_modules.h
index 996f48cf7d4b9b2df656c4e152967ce4faf2ce4d..53f5f2892ed452cff7aa4e2a91531f7a1326715b 100644
--- a/impl/matmul/policy/matmul_private_modules.h
+++ b/impl/matmul/policy/matmul_private_modules.h
@@ -87,7 +87,7 @@ struct MatmulPrivateModules {
     using MatmulUserDefineInfo = AscendC::Impl::Detail::MatmulUserDefineInfo;
     using MatmulUnitFlag = AscendC::Impl::Detail::MatmulUnitFlag;
-    using BatchLoop = AscendC::Impl::Detail::BatchLoop, MM_CFG>;
+    using BatchLoop = AscendC::Impl::Detail::BatchLoop, BIAS_TYPE, MM_CFG>;
     using CopyCubeOutUtils = AscendC::Impl::Detail::CopyCubeOutWrapper;

     // using compute modules
diff --git a/impl/matmul/scheduler/iterator/batch_loop/batch_loop_intf.h b/impl/matmul/scheduler/iterator/batch_loop/batch_loop_intf.h
index 49b53b7856fb93f81980fa24b1459900303b2660..7dc0c164defb9041002547c32d48d0c5778571ae 100644
--- a/impl/matmul/scheduler/iterator/batch_loop/batch_loop_intf.h
+++ b/impl/matmul/scheduler/iterator/batch_loop/batch_loop_intf.h
@@ -24,7 +24,7 @@ namespace Detail {
     We retain the freedom to make incompatible changes, but do not guarantee the stability.
     BatchLoop is only for internal usage, does not support extension or customized specialization!
 */
-template
+template
 class BatchLoop {
 public:
diff --git a/impl/matmul/scheduler/iterator/batch_loop/batch_loop_multi.h b/impl/matmul/scheduler/iterator/batch_loop/batch_loop_multi.h
index 18c98b414eeca1e930d05f754c1bc2edef371b22..38b0a7e8768d1f4e9f8d2fa4f846368f419fad06 100644
--- a/impl/matmul/scheduler/iterator/batch_loop/batch_loop_multi.h
+++ b/impl/matmul/scheduler/iterator/batch_loop/batch_loop_multi.h
@@ -27,14 +27,15 @@ namespace Detail {
     We retain the freedom to make incompatible changes, but do not guarantee the stability.
     BatchLoop is only for internal usage, does not support extension or customized specialization!
 */
-template
-class BatchLoop
+class BatchLoop() == Impl::Detail::CopyCubeInType::BMM) || (Impl::Detail::IsBMMFromL1())>>
 {
     MATMUL_USE_MODULE(MatmulShapeTiling);
     MATMUL_USE_MODULE(MatmulShapeInfo);

     using SrcT = typename INPUT_TYPE::T;
+    using BiasT = typename BIAS_TYPE::T;

 public:
     __aicore__ inline BatchLoop() = default;
@@ -71,7 +72,7 @@ public:
     {
         outerIdx_++;
         dstOffset_ += batchCalcSize_;
-        if (newProcess_ && outerIdx_ == batchOuter_ - 1) {
+        if (oddAndLargeThanL1_ && outerIdx_ == batchOuter_ - 1) {
             const int32_t tail = inputBatchNum_ % batchA_;
             batchA_ = tail == 0 ? mainBatchInner_ : tail;
             batchB_ = batchA_;
@@ -127,10 +128,11 @@ public:
     __aicore__ inline void SplitStart()
     {
         // Check that the total amount of data to be transferred is less than L1.
-        ASSERT((batchA_ * MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetSingleCoreM() *
-            MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetSingleCoreK() +
-            batchB_ * MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetSingleCoreN() *
-            MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetSingleCoreK()) * sizeof(SrcT) <= TOTAL_L1_SIZE);
+        const auto &tiling = MATMUL_MODULE(MatmulShapeTiling)->GetTiling();
+        ASSERT((batchA_ * tiling.GetSingleCoreM() * tiling.GetSingleCoreK() + batchB_ * tiling.GetSingleCoreN() *
+            tiling.GetSingleCoreK()) * sizeof(SrcT) + tiling.IsBias() * tiling.GetSingleCoreN() *
+            sizeof(BiasT) <= TOTAL_L1_SIZE);
+
         splitOuterIdx_ = 0;
         splitBatchIdx_ = 0;
     }
@@ -176,7 +178,7 @@ public:
     __aicore__ inline bool InnerEnd()
     {
-        if ((!newProcess_) || (batchNum_ % DB_FACTOR == 0) || (splitSize_ < DB_FACTOR)) {
+        if ((!oddAndLargeThanL1_) || (batchNum_ % DB_FACTOR == 0) || (splitSize_ < DB_FACTOR)) {
             return (innerIdx_ >= splitBatchNum_) || (splitOuterIdx_ * splitBatchNum_ >= batchNum_);
         }
         const auto firstBatchNum = batchNum_ / splitSize_;
@@ -245,7 +247,9 @@ private:
             (layoutBatchNumA % layoutBatchNumB == 0 || layoutBatchNumB % layoutBatchNumA == 0));
         int32_t aMatrixSingleBatchSize = GetSingleSizeAlignA();
         int32_t bMatrixSingleBatchSize = GetSingleSizeAlignB();
-        if ((layoutBatchNumA * aMatrixSingleBatchSize + layoutBatchNumB * bMatrixSingleBatchSize) <= TOTAL_L1_SIZE) {
+        if ((layoutBatchNumA * aMatrixSingleBatchSize + layoutBatchNumB * bMatrixSingleBatchSize +
+            MATMUL_MODULE(MatmulShapeTiling)->GetTiling().IsBias() *
+            MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetSingleCoreN() * sizeof(BiasT)) <= TOTAL_L1_SIZE) {
             batchOuter_ = 1;
             batchA_ = layoutBatchNumA;
             batchB_ = layoutBatchNumB;
@@ -296,13 +300,15 @@ private:
         int32_t largeMatrixSingleBatchSize, int32_t lessMatrixSingleBatchSize)
     {
         int32_t multiples = batchNumLarge / batchNumLess;
-        int32_t singleBatchSize = multiples * largeMatrixSingleBatchSize + lessMatrixSingleBatchSize;
+        int32_t singleBatchSize = multiples * largeMatrixSingleBatchSize + lessMatrixSingleBatchSize +
+            MATMUL_MODULE(MatmulShapeTiling)->GetTiling().IsBias() *
+            MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetSingleCoreN() * sizeof(BiasT);
         int32_t batchInner = TOTAL_L1_SIZE / singleBatchSize;
         inputBatchNum_ = batchNumLarge;
         ASSERT(batchInner > 0);
-        newProcess_ = (multiples == 1) && (inputBatchNum_ % DB_FACTOR != 0);
-        if (newProcess_) {
+        oddAndLargeThanL1_ = (multiples == 1) && (inputBatchNum_ % DB_FACTOR != 0);
+        if (oddAndLargeThanL1_) {
             mainBatchInner_ = batchInner;
             batchOuter_ = CeilT(batchNumLess, batchInner);
             batchA_ = batchInner;
@@ -369,7 +375,7 @@ private:
     int32_t batchOutOffsetNum_ = 0;
     int32_t inputBatchNum_ = 0;

-    bool newProcess_ = false; // new logical judgment condition for handling odd batchNum
+    bool oddAndLargeThanL1_ = false; // condition for handling odd batchNum when the data is larger than L1
     int32_t mainBatchInner_ = 0; // outerLoop main block
 };
 } // namespace Detail
diff --git a/impl/matmul/scheduler/iterator/batch_loop/batch_loop_single.h b/impl/matmul/scheduler/iterator/batch_loop/batch_loop_single.h
index 1b9cfce3a31a7f1960b183b4e0f745ba6bf0d3ff..1486c3e499007d89ac977f9cd95be1a26f84de74 100644
--- a/impl/matmul/scheduler/iterator/batch_loop/batch_loop_single.h
+++ b/impl/matmul/scheduler/iterator/batch_loop/batch_loop_single.h
@@ -27,8 +27,8 @@ namespace Detail {
     We retain the freedom to make incompatible changes, but do not guarantee the stability.
     BatchLoop is only for internal usage, does not support extension or customized specialization!
 */
-template
-class BatchLoop
+class BatchLoop>
 {
diff --git a/impl/matmul/stage/copy_cube_in/batch/batch_copy_cube_in.h b/impl/matmul/stage/copy_cube_in/batch/batch_copy_cube_in.h
index 31edec87ff4b9d429dc911f56efb766f6b4df778..33994ef2645d389ad747e2a331dcd6ed7102f161 100644
--- a/impl/matmul/stage/copy_cube_in/batch/batch_copy_cube_in.h
+++ b/impl/matmul/stage/copy_cube_in/batch/batch_copy_cube_in.h
@@ -80,7 +80,7 @@ private:
         // Calculate iter numbers by line of BSNGD layout
         int32_t batchNum = MATMUL_MODULE(BatchCopyCubeInParams)->GetBatchNum(); // batchA_ or batchB_
         int32_t iterNum = 1;
-        int32_t tmpBatchNum = batchNum / splitSize;
+        int32_t batchNumIdx = batchNum / splitSize;
         UpdataBatchNum(batchNum, iterNum);

         // Calculate srcDValue for ND copy
@@ -90,13 +90,14 @@ private:
         // if user input matrixStride, use matrixStride as srcStride
         auto srcStride = matrixStride != 0 ? matrixStride : GetSrcStride();
         auto dstStride = MATMUL_MODULE(BatchCopyCubeInParams)->template GetSingleSizeAlign();
-        int64_t srcOffset = tmpBatchNum * splitIdx * srcStride;
-        int64_t dstOffset = tmpBatchNum * splitIdx * dstStride;
-        auto paramsBatchNum = splitIdx == 0 ? tmpBatchNum : batchNum - tmpBatchNum;
+        int64_t srcOffset = batchNumIdx * splitIdx * srcStride;
+        int64_t dstOffset = batchNumIdx * splitIdx * dstStride;
+        // if batchNum is odd, the first block is not equal to the second block
+        auto batchBlock = splitIdx == 0 ? batchNumIdx : batchNum - batchNumIdx;

         // Calculate src and dst stride of one line
-        auto iterSrcStride = paramsBatchNum * GetSingleSize();
-        auto iterDstStride = paramsBatchNum * GetSingleSize();
+        auto iterSrcStride = batchBlock * GetSingleSize();
+        auto iterDstStride = batchBlock * GetSingleSize();

         // Complete datacopy by line
         GlobalTensor srcGlobal;
@@ -104,7 +105,7 @@ private:
         srcGlobal.SetAddr(batchOffset);
         for (int32_t idx = 0; idx < iterNum; ++idx) {
             if (srcStride >= UINT16_MAX) {
-                for (int i = 0; i < paramsBatchNum; ++i) {
+                for (int i = 0; i < batchBlock; ++i) {
                     MATMUL_MODULE(DataCopyWrapper)->CopyND2NZ(
                         dstTensor[dstOffset], srcGlobal[srcOffset], 0, 0,
                         MATMUL_MODULE(BatchCopyCubeInParams)->template GetSingleHeight(),
@@ -117,7 +118,7 @@ private:
                     dstTensor[dstOffset], srcGlobal[srcOffset], 0, 0,
                     MATMUL_MODULE(BatchCopyCubeInParams)->template GetSingleHeight(),
                     MATMUL_MODULE(BatchCopyCubeInParams)->template GetSingleWidth(),
-                    srcDValue, paramsBatchNum, srcStride, dstStride);
+                    srcDValue, batchBlock, srcStride, dstStride);
             }
             dstOffset += iterDstStride;
             srcOffset += iterSrcStride;
@@ -256,13 +257,13 @@ private:
         // 2. Calculate src and dst stride of one step
         auto batchNum = MATMUL_MODULE(BatchCopyCubeInParams)->GetBatchNum();
-        int32_t tmpBatchNum = batchNum / splitSize;
+        int32_t batchNumIdx = batchNum / splitSize;
         int64_t srcStride = alignWidth * alignHeight;
         int64_t dstStride = MATMUL_MODULE(BatchCopyCubeInParams)->template GetSingleSizeAlign();
-        int64_t srcOffset = tmpBatchNum * splitIdx * srcStride;
-        int64_t dstOffset = tmpBatchNum * splitIdx * dstStride;
-        auto paramsBatchNum = splitIdx == 0 ? tmpBatchNum : batchNum - tmpBatchNum;
+        int64_t srcOffset = batchNumIdx * splitIdx * srcStride;
+        int64_t dstOffset = batchNumIdx * splitIdx * dstStride;
+        auto batchBlock = splitIdx == 0 ? batchNumIdx : batchNum - batchNumIdx;

         // 3. loop copy NZ data by batch
         bool iskRowDirec = IS_KROW && IsSupportB8();
@@ -270,7 +271,7 @@ private:
         GlobalTensor srcGlobal;
         srcGlobal.SetGlobalBuffer(MATMUL_MODULE(MatmulTensorInfo)->GetGlobalTensor().address_);
         srcGlobal.SetAddr(batchOffset);
-        for (int i = 0; i < paramsBatchNum; ++i) {
+        for (int i = 0; i < batchBlock; ++i) {
             MATMUL_MODULE(DataCopyWrapper)->CopyNZ2NZ(
                 dstTensor[dstOffset], srcGlobal[srcOffset], 0, 0,
                 alignHeight, alignWidth, alignHeight, iskRowDirec);
diff --git a/impl/matmul/stage/copy_cube_in/batch/batch_copy_cube_in_using_ub.h b/impl/matmul/stage/copy_cube_in/batch/batch_copy_cube_in_using_ub.h
index 17c68246d4cbcc3c59c9b7b34f2eb35a4cc3e4b3..2f8f5b5126f2f9f56f2d3a3448df8dd8ff99e9e4 100644
--- a/impl/matmul/stage/copy_cube_in/batch/batch_copy_cube_in_using_ub.h
+++ b/impl/matmul/stage/copy_cube_in/batch/batch_copy_cube_in_using_ub.h
@@ -79,13 +79,16 @@ private:
         const int32_t outerIdx, const int32_t splitIdx, const int32_t splitSize)
     {
         // 1. calculate src stride and dst stride by db split loop index
-        auto batchNum = MATMUL_MODULE(BatchCopyCubeInParams)->GetBatchNum() / splitSize;
+        int32_t batchNum = MATMUL_MODULE(BatchCopyCubeInParams)->GetBatchNum();
+        int32_t batchNumIdx = batchNum / splitSize;
         auto alignWidth = CeilAlign(MATMUL_MODULE(BatchCopyCubeInParams)->template GetSingleWidth(), c0Size_);
         auto alignHeight = CeilAlign(MATMUL_MODULE(BatchCopyCubeInParams)->template GetSingleHeight(), BLOCK_CUBE);
         auto srcStride = GetSingleSize();
         auto dstStride = alignWidth * alignHeight;
-        uint64_t srcOffset = batchNum * splitIdx * srcStride;
-        uint64_t dstOffset = batchNum * splitIdx * dstStride;
+        uint64_t srcOffset = batchNumIdx * splitIdx * srcStride;
+        uint64_t dstOffset = batchNumIdx * splitIdx * dstStride;
+        // if batchNum is odd, the first block is not equal to the second block
+        auto batchBlock = splitIdx == 0 ? batchNumIdx : batchNum - batchNumIdx;

         // 2. copy batch matrix in
         int64_t batchOffset = outerIdx * MATMUL_MODULE(BatchCopyCubeInParams)->GetBatchNum() * srcStride;
@@ -94,16 +97,16 @@ private:
         srcGlobal.SetAddr(batchOffset + srcOffset);
         if constexpr (ToMatmulConfig(MM_CFG).enVecND2NZ) {
             CopyND2NZThroughVec(
-                dstTensor[dstOffset], srcGlobal, batchNum, outerIdx, splitIdx, alignHeight, alignWidth);
+                dstTensor[dstOffset], srcGlobal, batchBlock, outerIdx, splitIdx, alignHeight, alignWidth);
         } else {
             if constexpr (isKRow) {
                 MATMUL_MODULE(DataCopyWrapper)->CopyND2NZOnTheFly(
                     dstTensor[dstOffset], srcGlobal, 0, 0,
                     MATMUL_MODULE(BatchCopyCubeInParams)->template GetSingleHeight(),
-                    batchNum * alignWidth, batchNum * alignWidth);
+                    batchBlock * alignWidth, batchBlock * alignWidth);
             } else {
                 MATMUL_MODULE(DataCopyWrapper)->CopyND2NZOnTheFly(
-                    dstTensor[dstOffset], srcGlobal, 0, 0, batchNum * alignHeight,
+                    dstTensor[dstOffset], srcGlobal, 0, 0, batchBlock * alignHeight,
                     MATMUL_MODULE(BatchCopyCubeInParams)->template GetSingleWidth(),
                     MATMUL_MODULE(CopyCubeInParams)->template GetOrgWidth());
             }
@@ -303,14 +306,18 @@ private:
         // 1. Calculate batch outer loop offset
         auto alignHeight = CeilAlign(MATMUL_MODULE(BatchCopyCubeInParams)->template GetSingleHeight(), BLOCK_CUBE);
         auto alignWidth = CeilAlign(MATMUL_MODULE(BatchCopyCubeInParams)->template GetSingleWidth(), c0Size_);
-        auto batchNum = MATMUL_MODULE(BatchCopyCubeInParams)->GetBatchNum() / splitSize;
+        int32_t batchNum = MATMUL_MODULE(BatchCopyCubeInParams)->GetBatchNum();
+        int32_t batchNumIdx = batchNum / splitSize;
         bool iskRowDirec = isKRow && IsSameTypeV;

         // 2. Calculate src and dst stride of one step
         auto srcStride = alignWidth * alignHeight;
         auto dstStride = MATMUL_MODULE(BatchCopyCubeInParams)->template GetSingleSizeAlign();
-        int64_t srcOffset = batchNum * splitIdx * srcStride;
-        int64_t dstOffset = batchNum * splitIdx * dstStride;
+        int64_t srcOffset = batchNumIdx * splitIdx * srcStride;
+        int64_t dstOffset = batchNumIdx * splitIdx * dstStride;
+        // if batchNum is odd, the first block is not equal to the second block
+        auto batchBlock = splitIdx == 0 ? batchNumIdx : batchNum - batchNumIdx;
+
         // 3. set input srctensor addr
         auto batchOffset = outerIdx * MATMUL_MODULE(BatchCopyCubeInParams)->GetBatchNum() * srcStride;
@@ -326,7 +333,7 @@ private:
         }

         // 4. loop copy NZ data by batch
-        for (auto i = 0; i < batchNum; ++i) {
+        for (auto i = 0; i < batchBlock; ++i) {
             MATMUL_MODULE(DataCopyWrapper)->CopyNZ2NZ(dstTensor[dstOffset], srcTensor[srcOffset], 0, 0,
                 alignHeight, alignWidth, alignHeight, iskRowDirec);
             dstOffset += dstStride;
diff --git a/tests/matmul/iterator/test_batch_loop.cpp b/tests/matmul/iterator/test_batch_loop.cpp
index 4c5ae929572793ac55e41f53c923202ba778c97d..00f95f97e31179103adcb84e724e5a936ca3562d 100644
--- a/tests/matmul/iterator/test_batch_loop.cpp
+++ b/tests/matmul/iterator/test_batch_loop.cpp
@@ -33,7 +33,7 @@ template
 {
 public:
-    using BatchLoop = Impl::Detail::BatchLoop, MM_CFG>;
+    using BatchLoop = Impl::Detail::BatchLoop, BIAS_TYPE, MM_CFG>;
 };

 template
 {
 public:
-    using BatchLoop = Impl::Detail::BatchLoop, MM_CFG>;
+    using BatchLoop = Impl::Detail::BatchLoop, BIAS_TYPE, MM_CFG>;
 };

 template
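
Illustrative note (not part of the patch): the two behavioral changes above are (1) the L1 capacity checks now reserve room for the bias row when tiling.IsBias() is set, and (2) the double-buffered copy-in paths split an odd batchNum into two unequal blocks instead of assuming batchNum / splitSize for both halves. The standalone C++ sketch below mirrors that arithmetic only; the names TilingView, FitsInL1, SplitBlock and the 512 KB TOTAL_L1_SIZE value are assumptions for illustration, not part of the AscendC API.

#include <cstdint>
#include <cassert>

// Assumed L1 capacity in bytes; the real value comes from the platform, not from this sketch.
constexpr int32_t TOTAL_L1_SIZE = 512 * 1024;

// Hypothetical, simplified view of the tiling fields used by the checks above.
struct TilingView {
    int32_t singleCoreM;
    int32_t singleCoreK;
    int32_t singleCoreN;
    bool hasBias;
};

// Mirrors the updated SplitStart() assertion: the A block, the B block and,
// when bias is enabled, one bias row must fit into L1 together.
template <typename SrcT, typename BiasT>
bool FitsInL1(const TilingView& t, int32_t batchA, int32_t batchB)
{
    int64_t aBytes = static_cast<int64_t>(batchA) * t.singleCoreM * t.singleCoreK * sizeof(SrcT);
    int64_t bBytes = static_cast<int64_t>(batchB) * t.singleCoreN * t.singleCoreK * sizeof(SrcT);
    int64_t biasBytes = t.hasBias ? static_cast<int64_t>(t.singleCoreN) * sizeof(BiasT) : 0;
    return aBytes + bBytes + biasBytes <= TOTAL_L1_SIZE;
}

// Mirrors the odd-batch split in the copy-in paths: when batchNum is odd,
// the first block gets batchNum / splitSize and the second gets the remainder.
int32_t SplitBlock(int32_t batchNum, int32_t splitSize, int32_t splitIdx)
{
    int32_t batchNumIdx = batchNum / splitSize;                   // size of the first block
    return splitIdx == 0 ? batchNumIdx : batchNum - batchNumIdx;  // remainder goes to the second block
}

int main()
{
    // uint16_t stands in for a half-precision input type; float stands in for the bias type.
    TilingView t{64, 128, 64, true};
    assert((FitsInL1<uint16_t, float>(t, 4, 4)));
    assert(SplitBlock(7, 2, 0) == 3 && SplitBlock(7, 2, 1) == 4);
    return 0;
}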