From 5cbf294a82cfd3cac91fb20fe306d6de8965624b Mon Sep 17 00:00:00 2001 From: Liexss Date: Thu, 15 May 2025 19:05:29 +0800 Subject: [PATCH 1/8] [feature]: support mla flashdecoding for tp8 --- .../multi_latent_attention/mla_kernel.cpp | 19 ++++++--------- .../multi_latent_attention/op_kernel/mla.cce | 21 ++++++++-------- .../tiling/mla_tiling.cpp | 24 +++++++++---------- .../tiling/mla_tiling_dependency.h | 2 +- 4 files changed, 30 insertions(+), 36 deletions(-) diff --git a/src/kernels/mixkernels/multi_latent_attention/mla_kernel.cpp b/src/kernels/mixkernels/multi_latent_attention/mla_kernel.cpp index ee0d14b7..6bd3f70e 100644 --- a/src/kernels/mixkernels/multi_latent_attention/mla_kernel.cpp +++ b/src/kernels/mixkernels/multi_latent_attention/mla_kernel.cpp @@ -52,19 +52,14 @@ public: "paged attention: param type invalid", return false); auto param = AnyCast(launchParam.GetParam()); auto batch = param.kvSeqLen.size(); - if (param.headSize == M_LIMIT) { - uint64_t taskNum = param.qSeqLen.data() == nullptr ? batch : - std::accumulate(param.qSeqLen.data(), - param.qSeqLen.data() + batch, static_cast(0)); - uint32_t blockDim = PlatformInfo::Instance().GetCoreNum(CoreType::CORE_TYPE_CUBE); - uint64_t bufferSize = - Utils::RoundUp(launchBufferSize_ + TILING_PARA_SIZE_TP1 * (taskNum - 1) * sizeof(uint32_t) + - TILING_PARA_SIZE_TP1 * blockDim * blockDim * sizeof(uint32_t) + blockDim * 2 * sizeof(uint32_t), - TILINGMIN); - return bufferSize; - } + uint64_t taskNum = param.qSeqLen.data() == nullptr ? 
batch : + std::accumulate(param.qSeqLen.data(), + param.qSeqLen.data() + batch, static_cast(0)); + uint32_t blockDim = PlatformInfo::Instance().GetCoreNum(CoreType::CORE_TYPE_CUBE); uint64_t bufferSize = - Utils::RoundUp(launchBufferSize_ + TILING_PARA_SIZE * (batch - 1) * sizeof(uint32_t), TILINGMIN); + Utils::RoundUp(launchBufferSize_ + TILING_PARA_SIZE_TP1 * (taskNum - 1) * sizeof(uint32_t) + + TILING_PARA_SIZE_TP1 * blockDim * blockDim * sizeof(uint32_t) + blockDim * 2 * sizeof(uint32_t), + TILINGMIN); return bufferSize; } }; diff --git a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce index 92fcddcf..5cae4b16 100644 --- a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce +++ b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce @@ -761,7 +761,6 @@ private: embed_split_size, // k embed_split_idx == 0 // cmatrixInitVal ); - PIPE_BARRIER(M); SET_FLAG(M, MTE1, embed_split_idx % 2); SET_FLAG(M, MTE1, embed_split_idx % 2 + 2); @@ -1034,7 +1033,7 @@ private: uint32_t qk_c_pingpong_flag = (now_idx * cur_q_seqlen + q_i) % 2; uint32_t qk_0_pingpong_flag = (embed_split_idx * cur_q_seqlen + q_i) % 2; WAIT_FLAG(M, MTE1, qk_0_pingpong_flag); - for (uint64_t loa_load_idx = 0; loa_load_idx < q_load_coeff / BLOCK_SIZE; ++loa_load_idx) { + for (uint64_t loa_load_idx = 0; loa_load_idx < m / BLOCK_SIZE; ++loa_load_idx) { l1_to_l0_a( l0a_buf_tensor[qk_0_pingpong_flag * l0a_pingpong_size + loa_load_idx * round_embed_split_size * BLOCK_SIZE], l1q_buf_addr_tensor[q_i * q_heads * BLOCK_SIZE + embed_split_idx * m * round_embed_split_size + loa_load_idx * T_CUBE_MATRIX_SIZE], @@ -1119,7 +1118,7 @@ private: uint32_t qk_c_pingpong_flag = (now_idx * cur_q_seqlen + q_i) % 2; uint32_t qk_0_pingpong_flag = (embed_split_idx * cur_q_seqlen + q_i) % 2; WAIT_FLAG(M, MTE1, qk_0_pingpong_flag); - for (uint64_t loa_load_idx = 0; loa_load_idx < q_load_coeff / BLOCK_SIZE; ++loa_load_idx) { + for 
(uint64_t loa_load_idx = 0; loa_load_idx < m / BLOCK_SIZE; ++loa_load_idx) { l1_to_l0_a( l0a_buf_tensor[qk_0_pingpong_flag * 16384 + loa_load_idx * round_embed_split_size * BLOCK_SIZE], l1q_buf_addr_tensor[q_i * q_heads * BLOCK_SIZE + embed_split_idx * m * 128 + loa_load_idx * CUBE_MATRIX_SIZE], @@ -1618,7 +1617,7 @@ public: WaitFlagDev(reduce_flag_id); CombineScaleBlock(num_batches, q_heads, kv_split_core_num, embedding_size); } - } + } private: @@ -2262,10 +2261,10 @@ private: uint32_t i) { constexpr uint32_t ubuf_offset = 3 * STAGE2_UB_UINT8_BLOCK_SIZE; - constexpr uint32_t lo_ubuf_size = ubuf_offset + STAGE2_UB_UINT8_BLOCK_SIZE * 4; // 8 * 512 * doublebuffer * sizeof(float) - constexpr uint32_t to_ubuf_size = lo_ubuf_size + STAGE2_UB_UINT8_BLOCK_SIZE * 4; // 8 * 512 * doublebuffer * sizeof(float) - constexpr uint32_t go_ubuf_size = to_ubuf_size + STAGE2_UB_UINT8_BLOCK_SIZE * 2; // 8 * 512 * sizeof(float) - constexpr uint32_t go16_ubuf_size = go_ubuf_size + STAGE2_UB_UINT8_BLOCK_SIZE; // 8 * 512 * sizeof(float) / 2 + constexpr uint32_t lo_ubuf_size = ubuf_offset + STAGE2_UB_UINT8_BLOCK_SIZE * 4; //8 * 512 * doublebuffer * sizeof(float) + constexpr uint32_t to_ubuf_size = lo_ubuf_size + STAGE2_UB_UINT8_BLOCK_SIZE * 4; //8 * 512 * doublebuffer * sizeof(float) + constexpr uint32_t go_ubuf_size = to_ubuf_size + STAGE2_UB_UINT8_BLOCK_SIZE * 2; //8 * 512 * sizeof(float) + constexpr uint32_t go16_ubuf_size = go_ubuf_size + STAGE2_UB_UINT8_BLOCK_SIZE; //8 * 512 * sizeof(float) / 2 constexpr uint32_t lBrcb_size = FLOAT_VECTOR_SIZE * 8; AscendC::LocalTensor loTensor = buf.GetBuffer(ubuf_offset); AscendC::LocalTensor toTensor = buf.GetBuffer(lo_ubuf_size); @@ -2422,7 +2421,7 @@ private: uint32_t batchTaskNum = (uint32_t)(*((__gm__ uint32_t *)tiling_gm + 17)); uint32_t offset_tiling = tiling_head_size + tiling_para_size * (totalTaskNum + flashDecodingTaskNum); uint32_t q_seqlen = (uint32_t)(*((__gm__ uint32_t *)tiling_gm + 1 + tiling_head_size + tiling_para_size * 
totalTaskNum)); - uint32_t decodingQSeqLens = batchTaskNum * q_seqlen; + uint32_t decodingQSeqLens = decodingQSeqLens = batchTaskNum * q_seqlen; if constexpr (tilingKeyType == TilingKeyType::TILING_INT8_DATA) { q_seqlen = 1; } @@ -3423,7 +3422,7 @@ private: round_sub_m / FLOAT_BLOCK_SIZE // repeat ); PIPE_BARRIER(V); - if (head_loop > 1) { + if (head_loop > 1 || q_seq_len > 1) { gm_to_ub( go32_ubuf_tensor[oUbOffset], go_gm_tensor, @@ -3657,7 +3656,7 @@ private: } } } - else if (head_loop > 1) { + else if (head_loop > 1 || q_seq_len > 1) { SET_FLAG(V, MTE3, EVENT_ID5); WAIT_FLAG(V, MTE3, EVENT_ID5); ub_to_gm( diff --git a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp index e37116fa..2b714d79 100644 --- a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp +++ b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp @@ -31,7 +31,7 @@ const int32_t NUM512 = 512; const int32_t NUM576 = 576; const float SPLITKV_SEQLEN = 2048; -int32_t CalcSplitNum(MLAInfo &mmInfo, int32_t blockDim, int32_t minKVSeqlen, int32_t blockSize) +int32_t CalcSplitNum(MLAInfo &mmInfo, int32_t blockDim, int32_t minKVSeqlen, int32_t maxKVSeqlen, int32_t blockSize) { if (blockDim - mmInfo.flashDecodingTaskNum <= NUM4 || mmInfo.quantFlag) { return NUM1; @@ -39,14 +39,14 @@ int32_t CalcSplitNum(MLAInfo &mmInfo, int32_t blockDim, int32_t minKVSeqlen, int if (blockSize == 0 || blockDim == 0) { return NUM1; } - int32_t minKVBlocks = (minKVSeqlen + blockSize - 1) / blockSize; + int32_t maxKVBlocks = (maxKVSeqlen + blockSize - 1) / blockSize; for (int32_t splitNum = 2; splitNum <= NUM6; splitNum++) { if ((mmInfo.flashDecodingTaskNum * splitNum) % blockDim != 0) { continue; } int32_t repeatTimesPerBlock = mmInfo.flashDecodingTaskNum * splitNum / blockDim; if (minKVSeqlen / splitNum >= blockSize && - repeatTimesPerBlock + NUM6 <= minKVBlocks - (minKVBlocks + splitNum - 1) / splitNum * 
repeatTimesPerBlock) { + repeatTimesPerBlock + NUM6 <= maxKVBlocks - (maxKVBlocks + splitNum - 1) / splitNum * repeatTimesPerBlock) { return splitNum; } else { return NUM1; @@ -64,6 +64,7 @@ Status GetFlashDecodingInfo(MLAInfo &mmInfo, OpParam::MLA &param, uint32_t block mmInfo.tailTaskNum = mmInfo.totalTaskNum % blockDim; mmInfo.flashDecodingTaskNum = mmInfo.quantFlag ? mmInfo.tailTaskNum : mmInfo.tailBatch; auto minKVSeqlen = std::min_element(param.kvSeqLen.begin(), param.kvSeqLen.end()); + auto maxKVSeqlen = std::max_element(param.kvSeqLen.begin(), param.kvSeqLen.end()); auto minQSeqlen = mmInfo.qSeqLen != nullptr ? *std::min_element(param.qSeqLen.begin(), param.qSeqLen.end()) : 1; auto maxQSeqlen = mmInfo.qSeqLen != nullptr ? *std::max_element(param.qSeqLen.begin(), param.qSeqLen.end()) : 1; mmInfo.flashDecoding = !mmInfo.quantFlag && mmInfo.flashDecodingTaskNum != 0 && param.isRing == 0 && @@ -73,8 +74,8 @@ Status GetFlashDecodingInfo(MLAInfo &mmInfo, OpParam::MLA &param, uint32_t block if (!mmInfo.flashDecoding) { return Status::OkStatus(); } - mmInfo.splitKVNum = blockDim / mmInfo.flashDecodingTaskNum > 1 ? blockDim / mmInfo.flashDecodingTaskNum : - CalcSplitNum(mmInfo, blockDim, *minKVSeqlen, mmInfo.blockSize); + mmInfo.splitKVNum = blockDim / mmInfo.flashDecodingTaskNum > 1 ? blockDim / mmInfo.flashDecodingTaskNum : + CalcSplitNum(mmInfo, blockDim, *minKVSeqlen, *maxKVSeqlen, mmInfo.blockSize); mmInfo.flashDecoding = mmInfo.splitKVNum == 1 ? false : true; if (mmInfo.flashDecoding) { for (int32_t batchIdx = 0; batchIdx < mmInfo.batch; batchIdx++) { @@ -84,7 +85,7 @@ Status GetFlashDecodingInfo(MLAInfo &mmInfo, OpParam::MLA &param, uint32_t block int32_t taskNum = mmInfo.quantFlag ? 
mmInfo.totalTaskNum : mmInfo.batch; mmInfo.normalTaskNum = taskNum / blockDim * blockDim; } - MKI_LOG(INFO) << "flashDecoding is = " << mmInfo.flashDecoding; + MKI_LOG(INFO) << "mmInfo.flashDecoding is = " << mmInfo.flashDecoding; return Status::OkStatus(); } @@ -115,15 +116,14 @@ Status GetMLANdInfo(const LaunchParam &launchParam, MLAInfo &mmInfo, mmInfo.numHeads = static_cast(param.headSize); mmInfo.maskType = static_cast(param.maskType); mmInfo.quantFlag = (static_cast(mmInfo.type) < NUM2) ? 0 : 1; - mmInfo.mtpTp1Flag = (mmInfo.numHeads == M_LIMIT); - if (mmInfo.mtpTp1Flag || static_cast(mmInfo.type) >= NUM2) { - mmInfo.maskType = 0; - } mmInfo.totalTaskNum = mmInfo.qSeqLen != nullptr ? std::accumulate(mmInfo.qSeqLen, mmInfo.qSeqLen + mmInfo.batch, static_cast(0)) : mmInfo.batch; - if (mmInfo.mtpTp1Flag) { - OP_TILING_CHECK_STATUS_RETURN(GetFlashDecodingInfo(mmInfo, param, blockDim)); + OP_TILING_CHECK_STATUS_RETURN(GetFlashDecodingInfo(mmInfo, param, blockDim)); + mmInfo.mtpTp1Flag = (mmInfo.numHeads == M_LIMIT || + (mmInfo.flashDecoding && !mmInfo.quantFlag && mmInfo.numHeads >= NUM16)); + if (mmInfo.mtpTp1Flag || static_cast(mmInfo.type) >= NUM2) { + mmInfo.maskType = 0; } return Status::OkStatus(); } diff --git a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.h b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.h index 1e20e646..bf467a0a 100644 --- a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.h +++ b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.h @@ -26,7 +26,7 @@ constexpr int32_t TILING_PARA_SIZE = 8; constexpr int32_t TILING_PARA_SIZE_PREFILL = 27; constexpr int32_t TILING_HEAD_SIZE_PREFILL = 18; constexpr int32_t MAX_EMBEDDING = 576; -constexpr int32_t TILING_PARA_SIZE_TP1 = 7; +constexpr int32_t TILING_PARA_SIZE_TP1 = 8; constexpr int32_t TILING_HEAD_SIZE = 18; constexpr int32_t M_LIMIT = 128; constexpr int32_t ND_BATCH_LIMIT = INT32_MAX; 
-- Gitee From 3dec73ef361682e5740d173ed479d955a4465742 Mon Sep 17 00:00:00 2001 From: Liexss Date: Fri, 16 May 2025 16:23:01 +0800 Subject: [PATCH 2/8] [fix]: fix q=2 error --- .../mixkernels/multi_latent_attention/op_kernel/mla.cce | 5 +++-- .../multi_latent_attention/tiling/mla_tiling_dependency.cpp | 5 +---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce index 5cae4b16..cfbe5887 100644 --- a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce +++ b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce @@ -948,6 +948,7 @@ private: SET_FLAG(MTE2, MTE1, EVENT_ID0); WAIT_FLAG(MTE2, MTE1, EVENT_ID0); uint32_t q_load_coeff = q_heads; + uint32_t q_load_coeff_round = RoundUp<16>(q_load_coeff); uint32_t l0a_pingpong_size = L0AB_UINT8_BUF_SIZE / sizeof(IN_DTYPE); uint32_t l0b_pingpong_size = L0AB_UINT8_BUF_SIZE / sizeof(IN_DTYPE); uint32_t l0c_pingpong_size = L0C_FLOAT_BUF_SIZE; @@ -1033,7 +1034,7 @@ private: uint32_t qk_c_pingpong_flag = (now_idx * cur_q_seqlen + q_i) % 2; uint32_t qk_0_pingpong_flag = (embed_split_idx * cur_q_seqlen + q_i) % 2; WAIT_FLAG(M, MTE1, qk_0_pingpong_flag); - for (uint64_t loa_load_idx = 0; loa_load_idx < m / BLOCK_SIZE; ++loa_load_idx) { + for (uint64_t loa_load_idx = 0; loa_load_idx < q_load_coeff_round / BLOCK_SIZE; ++loa_load_idx) { l1_to_l0_a( l0a_buf_tensor[qk_0_pingpong_flag * l0a_pingpong_size + loa_load_idx * round_embed_split_size * BLOCK_SIZE], l1q_buf_addr_tensor[q_i * q_heads * BLOCK_SIZE + embed_split_idx * m * round_embed_split_size + loa_load_idx * T_CUBE_MATRIX_SIZE], @@ -1118,7 +1119,7 @@ private: uint32_t qk_c_pingpong_flag = (now_idx * cur_q_seqlen + q_i) % 2; uint32_t qk_0_pingpong_flag = (embed_split_idx * cur_q_seqlen + q_i) % 2; WAIT_FLAG(M, MTE1, qk_0_pingpong_flag); - for (uint64_t loa_load_idx = 0; loa_load_idx < m / BLOCK_SIZE; ++loa_load_idx) { + for 
(uint64_t loa_load_idx = 0; loa_load_idx < q_load_coeff_round / BLOCK_SIZE; ++loa_load_idx) { l1_to_l0_a( l0a_buf_tensor[qk_0_pingpong_flag * 16384 + loa_load_idx * round_embed_split_size * BLOCK_SIZE], l1q_buf_addr_tensor[q_i * q_heads * BLOCK_SIZE + embed_split_idx * m * 128 + loa_load_idx * CUBE_MATRIX_SIZE], diff --git a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp index eacfe351..a98a3720 100644 --- a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp +++ b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp @@ -361,10 +361,7 @@ Status GetMLATilingParam(const LaunchParam &launchParam, MLAInfo &mmInfo, auto param = AnyCast(launchParam.GetParam()); float tor = mmInfo.tor; uint32_t *torPtr = reinterpret_cast(&tor); - uint64_t curTilingParamSize = mmInfo.mtpTp1Flag ? - (TILING_HEAD_SIZE + TILING_PARA_SIZE_TP1 * mmInfo.totalTaskNum) * sizeof(uint32_t) : - (TILING_HEAD_SIZE + TILING_PARA_SIZE * mmInfo.batch) * sizeof(uint32_t); - MKI_CHECK(memset_s(tilingParam, tilingParamSize, 0, curTilingParamSize) == EOK, "init tiling failed", + MKI_CHECK(memset_s(tilingParam, tilingParamSize, 0, tilingParamSize) == EOK, "init tiling failed", return Status::FailStatus(ERROR_INVALID_VALUE)); if (mmInfo.mtpTp1Flag) { if (mmInfo.flashDecoding) { -- Gitee From 241b658d5d26959997e1cad358edb846fbdd910b Mon Sep 17 00:00:00 2001 From: Liexss Date: Fri, 16 May 2025 17:15:39 +0800 Subject: [PATCH 3/8] [fix]: code check fix --- .../multi_latent_attention/op_kernel/mla.cce | 10 +++++----- .../multi_latent_attention/tiling/mla_tiling.cpp | 11 +++++------ 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce index cfbe5887..55cdffbb 100644 --- 
a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce +++ b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce @@ -2262,10 +2262,10 @@ private: uint32_t i) { constexpr uint32_t ubuf_offset = 3 * STAGE2_UB_UINT8_BLOCK_SIZE; - constexpr uint32_t lo_ubuf_size = ubuf_offset + STAGE2_UB_UINT8_BLOCK_SIZE * 4; //8 * 512 * doublebuffer * sizeof(float) - constexpr uint32_t to_ubuf_size = lo_ubuf_size + STAGE2_UB_UINT8_BLOCK_SIZE * 4; //8 * 512 * doublebuffer * sizeof(float) - constexpr uint32_t go_ubuf_size = to_ubuf_size + STAGE2_UB_UINT8_BLOCK_SIZE * 2; //8 * 512 * sizeof(float) - constexpr uint32_t go16_ubuf_size = go_ubuf_size + STAGE2_UB_UINT8_BLOCK_SIZE; //8 * 512 * sizeof(float) / 2 + constexpr uint32_t lo_ubuf_size = ubuf_offset + STAGE2_UB_UINT8_BLOCK_SIZE * 4; // 8 * 512 * doublebuffer * sizeof(float) + constexpr uint32_t to_ubuf_size = lo_ubuf_size + STAGE2_UB_UINT8_BLOCK_SIZE * 4; // 8 * 512 * doublebuffer * sizeof(float) + constexpr uint32_t go_ubuf_size = to_ubuf_size + STAGE2_UB_UINT8_BLOCK_SIZE * 2; // 8 * 512 * sizeof(float) + constexpr uint32_t go16_ubuf_size = go_ubuf_size + STAGE2_UB_UINT8_BLOCK_SIZE; // 8 * 512 * sizeof(float) / 2 constexpr uint32_t lBrcb_size = FLOAT_VECTOR_SIZE * 8; AscendC::LocalTensor loTensor = buf.GetBuffer(ubuf_offset); AscendC::LocalTensor toTensor = buf.GetBuffer(lo_ubuf_size); @@ -2422,7 +2422,7 @@ private: uint32_t batchTaskNum = (uint32_t)(*((__gm__ uint32_t *)tiling_gm + 17)); uint32_t offset_tiling = tiling_head_size + tiling_para_size * (totalTaskNum + flashDecodingTaskNum); uint32_t q_seqlen = (uint32_t)(*((__gm__ uint32_t *)tiling_gm + 1 + tiling_head_size + tiling_para_size * totalTaskNum)); - uint32_t decodingQSeqLens = decodingQSeqLens = batchTaskNum * q_seqlen; + uint32_t decodingQSeqLens = batchTaskNum * q_seqlen; if constexpr (tilingKeyType == TilingKeyType::TILING_INT8_DATA) { q_seqlen = 1; } diff --git a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp 
b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp index 2b714d79..0eb1977d 100644 --- a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp +++ b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp @@ -31,7 +31,7 @@ const int32_t NUM512 = 512; const int32_t NUM576 = 576; const float SPLITKV_SEQLEN = 2048; -int32_t CalcSplitNum(MLAInfo &mmInfo, int32_t blockDim, int32_t minKVSeqlen, int32_t maxKVSeqlen, int32_t blockSize) +int32_t CalcSplitNum(MLAInfo &mmInfo, int32_t blockDim, int32_t minKVSeqlen, int32_t blockSize) { if (blockDim - mmInfo.flashDecodingTaskNum <= NUM4 || mmInfo.quantFlag) { return NUM1; @@ -39,14 +39,14 @@ int32_t CalcSplitNum(MLAInfo &mmInfo, int32_t blockDim, int32_t minKVSeqlen, int if (blockSize == 0 || blockDim == 0) { return NUM1; } - int32_t maxKVBlocks = (maxKVSeqlen + blockSize - 1) / blockSize; + int32_t minKVBlocks = (minKVSeqlen + blockSize - 1) / blockSize; for (int32_t splitNum = 2; splitNum <= NUM6; splitNum++) { if ((mmInfo.flashDecodingTaskNum * splitNum) % blockDim != 0) { continue; } int32_t repeatTimesPerBlock = mmInfo.flashDecodingTaskNum * splitNum / blockDim; if (minKVSeqlen / splitNum >= blockSize && - repeatTimesPerBlock + NUM6 <= maxKVBlocks - (maxKVBlocks + splitNum - 1) / splitNum * repeatTimesPerBlock) { + repeatTimesPerBlock + NUM6 <= minKVBlocks - (minKVBlocks + splitNum - 1) / splitNum * repeatTimesPerBlock) { return splitNum; } else { return NUM1; @@ -64,7 +64,6 @@ Status GetFlashDecodingInfo(MLAInfo &mmInfo, OpParam::MLA &param, uint32_t block mmInfo.tailTaskNum = mmInfo.totalTaskNum % blockDim; mmInfo.flashDecodingTaskNum = mmInfo.quantFlag ? mmInfo.tailTaskNum : mmInfo.tailBatch; auto minKVSeqlen = std::min_element(param.kvSeqLen.begin(), param.kvSeqLen.end()); - auto maxKVSeqlen = std::max_element(param.kvSeqLen.begin(), param.kvSeqLen.end()); auto minQSeqlen = mmInfo.qSeqLen != nullptr ? 
*std::min_element(param.qSeqLen.begin(), param.qSeqLen.end()) : 1; auto maxQSeqlen = mmInfo.qSeqLen != nullptr ? *std::max_element(param.qSeqLen.begin(), param.qSeqLen.end()) : 1; mmInfo.flashDecoding = !mmInfo.quantFlag && mmInfo.flashDecodingTaskNum != 0 && param.isRing == 0 && @@ -75,7 +74,7 @@ Status GetFlashDecodingInfo(MLAInfo &mmInfo, OpParam::MLA &param, uint32_t block return Status::OkStatus(); } mmInfo.splitKVNum = blockDim / mmInfo.flashDecodingTaskNum > 1 ? blockDim / mmInfo.flashDecodingTaskNum : - CalcSplitNum(mmInfo, blockDim, *minKVSeqlen, *maxKVSeqlen, mmInfo.blockSize); + CalcSplitNum(mmInfo, blockDim, *minKVSeqlen, mmInfo.blockSize); mmInfo.flashDecoding = mmInfo.splitKVNum == 1 ? false : true; if (mmInfo.flashDecoding) { for (int32_t batchIdx = 0; batchIdx < mmInfo.batch; batchIdx++) { @@ -85,7 +84,7 @@ Status GetFlashDecodingInfo(MLAInfo &mmInfo, OpParam::MLA &param, uint32_t block int32_t taskNum = mmInfo.quantFlag ? mmInfo.totalTaskNum : mmInfo.batch; mmInfo.normalTaskNum = taskNum / blockDim * blockDim; } - MKI_LOG(INFO) << "mmInfo.flashDecoding is = " << mmInfo.flashDecoding; + MKI_LOG(INFO) << "flashDecoding is = " << mmInfo.flashDecoding; return Status::OkStatus(); } -- Gitee From aacd5af062faa88f27774ec41af68d42cadd0551 Mon Sep 17 00:00:00 2001 From: Liexss Date: Fri, 16 May 2025 17:31:11 +0800 Subject: [PATCH 4/8] [fix]: code check fix --- src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce index 55cdffbb..2faaba42 100644 --- a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce +++ b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce @@ -1618,7 +1618,7 @@ public: WaitFlagDev(reduce_flag_id); CombineScaleBlock(num_batches, q_heads, kv_split_core_num, embedding_size); } - } + } private: -- Gitee From 
95a17ea8a006ff4d9dd754f599964767a073414e Mon Sep 17 00:00:00 2001 From: Liexss Date: Mon, 19 May 2025 16:30:27 +0800 Subject: [PATCH 5/8] [fix]: bugfix for Tp1Flag when flashDecoding --- .../mixkernels/multi_latent_attention/tiling/mla_tiling.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp index 0eb1977d..29c64b03 100644 --- a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp +++ b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp @@ -120,7 +120,8 @@ Status GetMLANdInfo(const LaunchParam &launchParam, MLAInfo &mmInfo, mmInfo.batch; OP_TILING_CHECK_STATUS_RETURN(GetFlashDecodingInfo(mmInfo, param, blockDim)); mmInfo.mtpTp1Flag = (mmInfo.numHeads == M_LIMIT || - (mmInfo.flashDecoding && !mmInfo.quantFlag && mmInfo.numHeads >= NUM16)); + (mmInfo.flashDecoding && !mmInfo.quantFlag && mmInfo.numHeads % NUM8 == 0)); + mmInfo.flashDecoding = !mmInfo.mtpTp1Flag ? 
false : mmInfo.flashDecoding; if (mmInfo.mtpTp1Flag || static_cast(mmInfo.type) >= NUM2) { mmInfo.maskType = 0; } -- Gitee From 0edbf6b6b04aec1fc3ab29433f3f03162373f97c Mon Sep 17 00:00:00 2001 From: Liexss Date: Tue, 20 May 2025 10:25:53 +0800 Subject: [PATCH 6/8] [fix]: code check fix --- .../mixkernels/multi_latent_attention/op_kernel/mla.cce | 6 ------ .../mixkernels/multi_latent_attention/tiling/mla_tiling.cpp | 2 +- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce index 2faaba42..44787551 100644 --- a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce +++ b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce @@ -2405,12 +2405,6 @@ private: src_gap, // srcGap (split_num - 1) * __k * 4 // dstGap ); - - WAIT_FLAG(MTE3, V, EVENT_ID3); - SET_FLAG(MTE3, V, EVENT_ID3); - - WAIT_FLAG(MTE3, MTE2, EVENT_ID1); - SET_FLAG(MTE3, MTE2, EVENT_ID1); } __aicore__ __attribute__((always_inline)) inline void CombineScaleBlock(uint32_t num_tokens, uint32_t q_heads, uint32_t kv_split_core_num, uint32_t embedding_size) diff --git a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp index 29c64b03..580ef077 100644 --- a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp +++ b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp @@ -84,7 +84,6 @@ Status GetFlashDecodingInfo(MLAInfo &mmInfo, OpParam::MLA &param, uint32_t block int32_t taskNum = mmInfo.quantFlag ? 
mmInfo.totalTaskNum : mmInfo.batch; mmInfo.normalTaskNum = taskNum / blockDim * blockDim; } - MKI_LOG(INFO) << "flashDecoding is = " << mmInfo.flashDecoding; return Status::OkStatus(); } @@ -125,6 +124,7 @@ Status GetMLANdInfo(const LaunchParam &launchParam, MLAInfo &mmInfo, if (mmInfo.mtpTp1Flag || static_cast(mmInfo.type) >= NUM2) { mmInfo.maskType = 0; } + MKI_LOG(INFO) << "flashDecoding is = " << mmInfo.flashDecoding; return Status::OkStatus(); } -- Gitee From 2bcc84697a41db25785a52a6d544b82548870383 Mon Sep 17 00:00:00 2001 From: Liexss Date: Tue, 20 May 2025 15:35:43 +0800 Subject: [PATCH 7/8] [fix]: code check fix --- .../mixkernels/multi_latent_attention/op_kernel/mla.cce | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce index 44787551..fa5ed55b 100644 --- a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce +++ b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce @@ -1611,7 +1611,7 @@ public: WAIT_FLAG(V, MTE2, EVENT_ID4); WAIT_FLAG(V, MTE2, EVENT_ID7); WAIT_FLAG(MTE3, V, EVENT_ID2); - + PIPE_BARRIER(ALL); if constexpr (flashDecoding) { int reduce_flag_id = 3; FftsCrossCoreSync(reduce_flag_id); @@ -2405,6 +2405,9 @@ private: src_gap, // srcGap (split_num - 1) * __k * 4 // dstGap ); + + SET_FLAG(MTE3, V, EVENT_ID3); + WAIT_FLAG(MTE3, V, EVENT_ID3); } __aicore__ __attribute__((always_inline)) inline void CombineScaleBlock(uint32_t num_tokens, uint32_t q_heads, uint32_t kv_split_core_num, uint32_t embedding_size) -- Gitee From da74c15383e7bf975ec860a3bdf19ce05a1ff093 Mon Sep 17 00:00:00 2001 From: Liexss Date: Fri, 23 May 2025 17:45:30 +0800 Subject: [PATCH 8/8] [fix]: merge fix --- .../mixkernels/multi_latent_attention/op_kernel/mla.cce | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce 
b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce index fa5ed55b..5b722821 100644 --- a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce +++ b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce @@ -3388,6 +3388,9 @@ private: if constexpr (std::is_same_v) { oPingPongFlag = 0; } + if (head_loop <= 1 && q_seq_len <= 1) { + oPingPongFlag = 0; + } uint32_t oUbOffset = oPingPongFlag * 4096; // oPingPongFlag * 8 * 512 WAIT_FLAG(V, MTE2, oPingPongFlag); if (n_idx != 4) { -- Gitee