diff --git a/impl/matmul/scheduler/batch/batch_scheduler.h b/impl/matmul/scheduler/batch/batch_scheduler.h
index add2660f324c192a9a910ae05c7f5ca88d8274b6..32822d167b3cd564694e8902fa3d144c364363c5 100644
--- a/impl/matmul/scheduler/batch/batch_scheduler.h
+++ b/impl/matmul/scheduler/batch/batch_scheduler.h
@@ -61,89 +61,6 @@ public:
     __aicore__ inline BatchScheduler() = default;
     __aicore__ inline ~BatchScheduler() = default;
 
-    template <class T>
-    __aicore__ inline void ComputeInner(const T& dst, LocalTensor<TransT>& a1, LocalTensor<TransT>& b1,
-        LocalTensor<BiasT>& bias, bool enPartialSum, uint8_t enAtomic,
-        bool enSequentialWrite, BatchOffsetInfo& batchOffsetInfo,
-        BatchSchedulerContext& ctx, event_t eventIDMte2ToMte1, event_t eventIDMToMte1)
-    {
-        auto batchLoop = MATMUL_MODULE(BatchLoop);
-        for (batchLoop->InnerStart(); !batchLoop->InnerEnd(); batchLoop->InnerNext()) {
-            BASE_MODULE::isFirstIter_ = true;
-            if (batchOffsetInfo.setBiasFlag && (batchLoop->GetBatchIndex() % batchOffsetInfo.divisorBias == 1)) {
-                MATMUL_MODULE(BiasScheduler)->StopBias(bias);
-            }
-            UpdateOffset(batchOffsetInfo, ctx);
-            while (BASE_MODULE::MoveNext()) { // iterate
-                MATMUL_MODULE(CubeOutBuffer)->AllocTensor();
-                ComputeBatch(a1, b1, bias, enPartialSum, ctx);
-                BatchScheduler::GetBatchResultImpl(dst, ctx, enAtomic, enSequentialWrite);
-                SetFlag<HardEvent::M_MTE1>(eventIDMToMte1);
-                WaitFlag<HardEvent::M_MTE1>(eventIDMToMte1);
-            }
-            EndIterate();
-        }
-    }
-
-    template <class T>
-    __aicore__ inline enable_if_t<ToMatmulConfig(MM_CFG).batchMode != BatchMode::BATCH_LESS_THAN_L1, void>
-    ComputeSplit(const T& dst, LocalTensor<BiasT>& bias, bool enPartialSum, uint8_t enAtomic, bool enSequentialWrite,
-        uint32_t matrixStrideA, uint32_t matrixStrideB, BatchOffsetInfo& batchOffsetInfo,
-        BatchSchedulerContext& ctx)
-    {
-        auto a1 = MATMUL_MODULE(BatchCopyCubeInA)->AllocTensor();
-        auto b1 = MATMUL_MODULE(BatchCopyCubeInB)->AllocTensor();
-        event_t eventIDMte2ToMte1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_MTE1));
-        event_t eventIDMToMte1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::M_MTE1));
-        auto batchLoop = MATMUL_MODULE(BatchLoop);
-        for (batchLoop->SplitStart(); !batchLoop->SplitEnd(); batchLoop->SplitNext()) {
-            MATMUL_MODULE(BatchCopyCubeInA)->BatchLoad(a1, matrixStrideA, batchLoop->GetOuterIndex(),
-                batchLoop->GetSplitIndex(), batchLoop->GetSplitSize());
-            MATMUL_MODULE(BatchCopyCubeInB)->BatchLoad(b1, matrixStrideB, batchLoop->GetOuterIndex(),
-                batchLoop->GetSplitIndex(), batchLoop->GetSplitSize());
-            SetFlag<HardEvent::MTE2_MTE1>(eventIDMte2ToMte1);
-            WaitFlag<HardEvent::MTE2_MTE1>(eventIDMte2ToMte1);
-            ComputeInner(dst, a1, b1, bias, enPartialSum, enAtomic, enSequentialWrite, batchOffsetInfo, ctx,
-                eventIDMte2ToMte1, eventIDMToMte1);
-            BASE_MODULE::End();
-        }
-        MATMUL_MODULE(BatchCopyCubeInA)->BatchDestroy(a1);
-        MATMUL_MODULE(BatchCopyCubeInB)->BatchDestroy(b1);
-    }
-
-    template <class T>
-    __aicore__ inline enable_if_t<ToMatmulConfig(MM_CFG).batchMode == BatchMode::BATCH_LESS_THAN_L1, void>
-    ComputeSplit(const T& dst, LocalTensor<BiasT>& bias, bool enPartialSum, uint8_t enAtomic, bool enSequentialWrite,
-        uint32_t matrixStrideA, uint32_t matrixStrideB, BatchOffsetInfo& batchOffsetInfo,
-        BatchSchedulerContext& ctx)
-    {
-        auto batchLoop = MATMUL_MODULE(BatchLoop);
-        for (batchLoop->SplitStart(); !batchLoop->SplitEnd(); batchLoop->SplitNext()) {
-            LocalTensor<TransT> a1;
-            LocalTensor<TransT> b1;
-            event_t eventIDMte2ToMte1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_MTE1));
-            event_t eventIDMToMte1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::M_MTE1));
-            auto splitIdxA = batchLoop->template GetSplitIndex<InputTypeTag::A>();
-            auto splitIdxB = batchLoop->template GetSplitIndex<InputTypeTag::B>();
-            MATMUL_MODULE(BatchCopyCubeInA)->BatchLoad(a1, matrixStrideA, batchLoop->GetOuterIndex(),
-                splitIdxA, batchLoop->GetSplitSize());
-            MATMUL_MODULE(BatchCopyCubeInB)->BatchLoad(b1, matrixStrideB, batchLoop->GetOuterIndex(),
-                splitIdxB, batchLoop->GetSplitSize());
-            SetFlag<HardEvent::MTE2_MTE1>(eventIDMte2ToMte1);
-            WaitFlag<HardEvent::MTE2_MTE1>(eventIDMte2ToMte1);
-            ComputeInner(dst, a1, b1, bias, enPartialSum, enAtomic, enSequentialWrite, batchOffsetInfo, ctx,
-                eventIDMte2ToMte1, eventIDMToMte1);
-
-            MATMUL_MODULE(BatchCopyCubeInA)->BatchDestroy(a1);
-            MATMUL_MODULE(BatchCopyCubeInB)->BatchDestroy(b1);
-            MATMUL_MODULE(BiasScheduler)->End();
-            MATMUL_MODULE(CubeOutBuffer)->Destroy();
-        }
-
-        MATMUL_MODULE(BatchCopyCubeInA)->Reset();
-        MATMUL_MODULE(BatchCopyCubeInB)->Reset();
-    }
-
     template <class T>
     __aicore__ inline void Schedule(const T& dst, bool enPartialSum, uint8_t enAtomic, bool enSequentialWrite,
         const uint32_t matrixStrideA, const uint32_t matrixStrideB, const uint32_t matrixStrideC)
@@ -169,8 +86,80 @@ public:
                 batchLoop->GetBatchNum(), batchLoop->GetBiasBatchSrcOffset());
         }
 
-        ComputeSplit(dst, bias, enPartialSum, enAtomic, enSequentialWrite, matrixStrideA, matrixStrideB,
-            batchOffsetInfo, ctx);
+        if constexpr (ToMatmulConfig(MM_CFG).batchMode != BatchMode::BATCH_LESS_THAN_L1) {
+            auto a1 = MATMUL_MODULE(BatchCopyCubeInA)->AllocTensor();
+            auto b1 = MATMUL_MODULE(BatchCopyCubeInB)->AllocTensor();
+            event_t eventIDMte2ToMte1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_MTE1));
+            event_t eventIDMToMte1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::M_MTE1));
+            auto batchLoop = MATMUL_MODULE(BatchLoop);
+            for (batchLoop->SplitStart(); !batchLoop->SplitEnd(); batchLoop->SplitNext()) {
+                MATMUL_MODULE(BatchCopyCubeInA)->BatchLoad(a1, matrixStrideA, batchLoop->GetOuterIndex(),
+                    batchLoop->GetSplitIndex(), batchLoop->GetSplitSize());
+                MATMUL_MODULE(BatchCopyCubeInB)->BatchLoad(b1, matrixStrideB, batchLoop->GetOuterIndex(),
+                    batchLoop->GetSplitIndex(), batchLoop->GetSplitSize());
+                SetFlag<HardEvent::MTE2_MTE1>(eventIDMte2ToMte1);
+                WaitFlag<HardEvent::MTE2_MTE1>(eventIDMte2ToMte1);
+                auto batchLoop = MATMUL_MODULE(BatchLoop);
+                for (batchLoop->InnerStart(); !batchLoop->InnerEnd(); batchLoop->InnerNext()) {
+                    BASE_MODULE::isFirstIter_ = true;
+                    if (batchOffsetInfo.setBiasFlag && (batchLoop->GetBatchIndex() % batchOffsetInfo.divisorBias == 1)) {
+                        MATMUL_MODULE(BiasScheduler)->StopBias(bias);
+                    }
+                    UpdateOffset(batchOffsetInfo, ctx);
+                    while (BASE_MODULE::MoveNext()) { // iterate
+                        MATMUL_MODULE(CubeOutBuffer)->AllocTensor();
+                        ComputeBatch(a1, b1, bias, enPartialSum, ctx);
+                        BatchScheduler::GetBatchResultImpl(dst, ctx, enAtomic, enSequentialWrite);
+                        SetFlag<HardEvent::M_MTE1>(eventIDMToMte1);
+                        WaitFlag<HardEvent::M_MTE1>(eventIDMToMte1);
+                    }
+                    EndIterate();
+                }
+                BASE_MODULE::End();
+            }
+            MATMUL_MODULE(BatchCopyCubeInA)->BatchDestroy(a1);
+            MATMUL_MODULE(BatchCopyCubeInB)->BatchDestroy(b1);
+        } else {
+            auto batchLoop = MATMUL_MODULE(BatchLoop);
+            for (batchLoop->SplitStart(); !batchLoop->SplitEnd(); batchLoop->SplitNext()) {
+                LocalTensor<TransT> a1;
+                LocalTensor<TransT> b1;
+                event_t eventIDMte2ToMte1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_MTE1));
+                event_t eventIDMToMte1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::M_MTE1));
+                auto splitIdxA = batchLoop->template GetSplitIndex<InputTypeTag::A>();
+                auto splitIdxB = batchLoop->template GetSplitIndex<InputTypeTag::B>();
+                MATMUL_MODULE(BatchCopyCubeInA)->BatchLoad(a1, matrixStrideA, batchLoop->GetOuterIndex(),
+                    splitIdxA, batchLoop->GetSplitSize());
+                MATMUL_MODULE(BatchCopyCubeInB)->BatchLoad(b1, matrixStrideB, batchLoop->GetOuterIndex(),
+                    splitIdxB, batchLoop->GetSplitSize());
+                SetFlag<HardEvent::MTE2_MTE1>(eventIDMte2ToMte1);
+                WaitFlag<HardEvent::MTE2_MTE1>(eventIDMte2ToMte1);
+                auto batchLoop = MATMUL_MODULE(BatchLoop);
+                for (batchLoop->InnerStart(); !batchLoop->InnerEnd(); batchLoop->InnerNext()) {
+                    BASE_MODULE::isFirstIter_ = true;
+                    if (batchOffsetInfo.setBiasFlag && (batchLoop->GetBatchIndex() % batchOffsetInfo.divisorBias == 1)) {
+                        MATMUL_MODULE(BiasScheduler)->StopBias(bias);
+                    }
+                    UpdateOffset(batchOffsetInfo, ctx);
+                    while (BASE_MODULE::MoveNext()) { // iterate
+                        MATMUL_MODULE(CubeOutBuffer)->AllocTensor();
+                        ComputeBatch(a1, b1, bias, enPartialSum, ctx);
+                        BatchScheduler::GetBatchResultImpl(dst, ctx, enAtomic, enSequentialWrite);
+                        SetFlag<HardEvent::M_MTE1>(eventIDMToMte1);
+                        WaitFlag<HardEvent::M_MTE1>(eventIDMToMte1);
+                    }
+                    EndIterate();
+                }
+
+                MATMUL_MODULE(BatchCopyCubeInA)->BatchDestroy(a1);
+                MATMUL_MODULE(BatchCopyCubeInB)->BatchDestroy(b1);
+                MATMUL_MODULE(BiasScheduler)->End();
+                MATMUL_MODULE(CubeOutBuffer)->Destroy();
+            }
+
+            MATMUL_MODULE(BatchCopyCubeInA)->Reset();
+            MATMUL_MODULE(BatchCopyCubeInB)->Reset();
+        }
 
         if constexpr (ToMatmulConfig(MM_CFG).isBiasBatch) {
             MATMUL_MODULE(BiasScheduler)->Destroy(bias);
@@ -193,44 +182,40 @@ private:
         return batchOffsetInfo;
     }
 
-    template <class T = void>
-    __aicore__ inline enable_if_t<ToMatmulConfig(MM_CFG).batchMode != BatchMode::BATCH_LESS_THAN_L1, void>
-    UpdateOffset(BatchOffsetInfo& batchOffsetInfo, BatchSchedulerContext& ctx)
+    __aicore__ inline void UpdateOffset(BatchOffsetInfo& batchOffsetInfo, BatchSchedulerContext& ctx)
     {
-        auto batchIndex = MATMUL_MODULE(BatchLoop)->GetBatchIndex();
-        ctx.offsetA = batchOffsetInfo.alignA *
-            (batchIndex % batchOffsetInfo.modA + batchIndex / batchOffsetInfo.divisorA);
-        ctx.offsetB = batchOffsetInfo.alignB *
-            (batchIndex % batchOffsetInfo.modB + batchIndex / batchOffsetInfo.divisorB);
-        ctx.offsetBias = batchOffsetInfo.alignBias *
-            (batchIndex % batchOffsetInfo.modBias + batchIndex / batchOffsetInfo.divisorBias);
-        if constexpr (ToMatmulConfig(MM_CFG).isNBatchOut) {
-            MATMUL_MODULE(BatchLoop)->SetBatchOutCacheNum(MATMUL_MODULE(BatchLoop)->GetBatchOutCacheNum() + 1);
-        }
-    }
-
-    template <class T = void>
-    __aicore__ inline enable_if_t<ToMatmulConfig(MM_CFG).batchMode == BatchMode::BATCH_LESS_THAN_L1, void>
-    UpdateOffset(BatchOffsetInfo& batchOffsetInfo, BatchSchedulerContext& ctx)
-    {
-        auto batchAIndex = 0, batchBIndex = 0;
-        auto biasIndex = MATMUL_MODULE(BatchLoop)->GetBatchIndex();
-        const auto& bL = MATMUL_MODULE(BatchLoop);
-        batchAIndex = bL->GetBatchA() <= bL->GetSplitBatchNum() ? bL->GetBatchIndex()
-            : bL->GetBatchIndex() % bL->GetSplitBatchNum();
-        ctx.offsetA = batchOffsetInfo.alignA *
-            (batchAIndex % batchOffsetInfo.modA + batchAIndex / batchOffsetInfo.divisorA);
-
-        batchBIndex = bL->GetBatchB() <= bL->GetSplitBatchNum() ? bL->GetBatchIndex()
-            : bL->GetBatchIndex() % bL->GetSplitBatchNum();
-        ctx.offsetB = batchOffsetInfo.alignB *
-            (batchBIndex % batchOffsetInfo.modB + batchBIndex / batchOffsetInfo.divisorB);
-
-        ctx.offsetBias = batchOffsetInfo.alignBias *
-            (biasIndex % batchOffsetInfo.modBias + biasIndex / batchOffsetInfo.divisorBias);
-        if constexpr (ToMatmulConfig(MM_CFG).isNBatchOut) {
-            bL->SetBatchOutCacheNum(bL->GetBatchOutCacheNum() + 1);
+        const auto& bL = MATMUL_MODULE(BatchLoop);
+        if constexpr (ToMatmulConfig(MM_CFG).batchMode != BatchMode::BATCH_LESS_THAN_L1) {
+            auto batchIndex = bL->GetBatchIndex();
+            ctx.offsetA = batchOffsetInfo.alignA *
+                (batchIndex % batchOffsetInfo.modA + batchIndex / batchOffsetInfo.divisorA);
+            ctx.offsetB = batchOffsetInfo.alignB *
+                (batchIndex % batchOffsetInfo.modB + batchIndex / batchOffsetInfo.divisorB);
+            ctx.offsetBias = batchOffsetInfo.alignBias *
+                (batchIndex % batchOffsetInfo.modBias + batchIndex / batchOffsetInfo.divisorBias);
+            if constexpr (ToMatmulConfig(MM_CFG).isNBatchOut) {
+                bL->SetBatchOutCacheNum(bL->GetBatchOutCacheNum() + 1);
+            }
+        } else {
+            auto batchAIndex = 0, batchBIndex = 0;
+            auto biasIndex = bL->GetBatchIndex();
+
+            batchAIndex = bL->GetBatchA() <= bL->GetSplitBatchNum() ? bL->GetBatchIndex()
+                : bL->GetBatchIndex() % bL->GetSplitBatchNum();
+            ctx.offsetA = batchOffsetInfo.alignA *
+                (batchAIndex % batchOffsetInfo.modA + batchAIndex / batchOffsetInfo.divisorA);
+
+            batchBIndex = bL->GetBatchB() <= bL->GetSplitBatchNum() ? bL->GetBatchIndex()
+                : bL->GetBatchIndex() % bL->GetSplitBatchNum();
+            ctx.offsetB = batchOffsetInfo.alignB *
+                (batchBIndex % batchOffsetInfo.modB + batchBIndex / batchOffsetInfo.divisorB);
+
+            ctx.offsetBias = batchOffsetInfo.alignBias *
+                (biasIndex % batchOffsetInfo.modBias + biasIndex / batchOffsetInfo.divisorBias);
+            if constexpr (ToMatmulConfig(MM_CFG).isNBatchOut) {
+                bL->SetBatchOutCacheNum(bL->GetBatchOutCacheNum() + 1);
+            }
         }
     }
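
Reviewer note: the patch replaces the `enable_if_t`-gated `ComputeSplit`/`UpdateOffset` overload pairs with single bodies that branch on `if constexpr (ToMatmulConfig(MM_CFG).batchMode != BatchMode::BATCH_LESS_THAN_L1)`. The sketch below is a minimal, host-side illustration of that dispatch together with the `align * (index % mod + index / divisor)` offset arithmetic used in `UpdateOffset`; `Mode`, `OffsetInfo`, `BatchOffset`, and `OffsetFor` are hypothetical stand-ins, not the library's API.

```cpp
// Standalone sketch (hypothetical names): compile-time branch selection plus the
// folded-index offset formula that the merged UpdateOffset applies per input.
#include <cstdint>
#include <iostream>

enum class Mode { BatchLessThanL1, Other };   // stand-in for BatchMode

struct OffsetInfo {                           // stand-in for one input's slice of BatchOffsetInfo
    uint32_t align;
    uint32_t mod;
    uint32_t divisor;
};

// offset = align * (index % mod + index / divisor), as computed for A, B and bias.
constexpr uint32_t BatchOffset(const OffsetInfo& info, uint32_t index)
{
    return info.align * (index % info.mod + index / info.divisor);
}

// One function instead of two enable_if_t overloads: the branch not taken is
// discarded at compile time for each instantiation.
template <Mode kMode>
uint32_t OffsetFor(const OffsetInfo& info, uint32_t batchIndex, uint32_t batchNum, uint32_t splitBatchNum)
{
    if constexpr (kMode != Mode::BatchLessThanL1) {
        // Mirrors the first branch: the raw batch index feeds the formula.
        return BatchOffset(info, batchIndex);
    } else {
        // Mirrors the else branch: wrap the index by the split batch count when
        // the input carries more batches than one split holds.
        uint32_t idx = batchNum <= splitBatchNum ? batchIndex : batchIndex % splitBatchNum;
        return BatchOffset(info, idx);
    }
}

int main()
{
    OffsetInfo a{512, 8, 8};   // arbitrary example values
    for (uint32_t i = 0; i < 4; ++i) {
        std::cout << OffsetFor<Mode::Other>(a, i, 4, 2) << ' '
                  << OffsetFor<Mode::BatchLessThanL1>(a, i, 4, 2) << '\n';
    }
    return 0;
}
```

In the patch itself, only `batchAIndex`/`batchBIndex` are wrapped by `GetSplitBatchNum()` in the `BATCH_LESS_THAN_L1` branch; `biasIndex` keeps the raw `GetBatchIndex()` value before the same formula is applied.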