diff --git a/src/kernels/mixkernels/multi_latent_attention/mla_kernel.cpp b/src/kernels/mixkernels/multi_latent_attention/mla_kernel.cpp index ee0d14b719f988b7f881145a528c0655ebd63598..6bd3f70e518a98a11c90f89e58843b9c0c93dc93 100644 --- a/src/kernels/mixkernels/multi_latent_attention/mla_kernel.cpp +++ b/src/kernels/mixkernels/multi_latent_attention/mla_kernel.cpp @@ -52,19 +52,14 @@ public: "paged attention: param type invalid", return false); auto param = AnyCast(launchParam.GetParam()); auto batch = param.kvSeqLen.size(); - if (param.headSize == M_LIMIT) { - uint64_t taskNum = param.qSeqLen.data() == nullptr ? batch : - std::accumulate(param.qSeqLen.data(), - param.qSeqLen.data() + batch, static_cast(0)); - uint32_t blockDim = PlatformInfo::Instance().GetCoreNum(CoreType::CORE_TYPE_CUBE); - uint64_t bufferSize = - Utils::RoundUp(launchBufferSize_ + TILING_PARA_SIZE_TP1 * (taskNum - 1) * sizeof(uint32_t) + - TILING_PARA_SIZE_TP1 * blockDim * blockDim * sizeof(uint32_t) + blockDim * 2 * sizeof(uint32_t), - TILINGMIN); - return bufferSize; - } + uint64_t taskNum = param.qSeqLen.data() == nullptr ? batch : + std::accumulate(param.qSeqLen.data(), + param.qSeqLen.data() + batch, static_cast(0)); + uint32_t blockDim = PlatformInfo::Instance().GetCoreNum(CoreType::CORE_TYPE_CUBE); uint64_t bufferSize = - Utils::RoundUp(launchBufferSize_ + TILING_PARA_SIZE * (batch - 1) * sizeof(uint32_t), TILINGMIN); + Utils::RoundUp(launchBufferSize_ + TILING_PARA_SIZE_TP1 * (taskNum - 1) * sizeof(uint32_t) + + TILING_PARA_SIZE_TP1 * blockDim * blockDim * sizeof(uint32_t) + blockDim * 2 * sizeof(uint32_t), + TILINGMIN); return bufferSize; } }; diff --git a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce index 92fcddcfc08c4429f908cfbc8c11c8a623946bd7..5b7228212d6f08981775282fbb7097830828c377 100644 --- a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce +++ b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce @@ -761,7 +761,6 @@ private: embed_split_size, // k embed_split_idx == 0 // cmatrixInitVal ); - PIPE_BARRIER(M); SET_FLAG(M, MTE1, embed_split_idx % 2); SET_FLAG(M, MTE1, embed_split_idx % 2 + 2); @@ -949,6 +948,7 @@ private: SET_FLAG(MTE2, MTE1, EVENT_ID0); WAIT_FLAG(MTE2, MTE1, EVENT_ID0); uint32_t q_load_coeff = q_heads; + uint32_t q_load_coeff_round = RoundUp<16>(q_load_coeff); uint32_t l0a_pingpong_size = L0AB_UINT8_BUF_SIZE / sizeof(IN_DTYPE); uint32_t l0b_pingpong_size = L0AB_UINT8_BUF_SIZE / sizeof(IN_DTYPE); uint32_t l0c_pingpong_size = L0C_FLOAT_BUF_SIZE; @@ -1034,7 +1034,7 @@ private: uint32_t qk_c_pingpong_flag = (now_idx * cur_q_seqlen + q_i) % 2; uint32_t qk_0_pingpong_flag = (embed_split_idx * cur_q_seqlen + q_i) % 2; WAIT_FLAG(M, MTE1, qk_0_pingpong_flag); - for (uint64_t loa_load_idx = 0; loa_load_idx < q_load_coeff / BLOCK_SIZE; ++loa_load_idx) { + for (uint64_t loa_load_idx = 0; loa_load_idx < q_load_coeff_round / BLOCK_SIZE; ++loa_load_idx) { l1_to_l0_a( l0a_buf_tensor[qk_0_pingpong_flag * l0a_pingpong_size + loa_load_idx * round_embed_split_size * BLOCK_SIZE], l1q_buf_addr_tensor[q_i * q_heads * BLOCK_SIZE + embed_split_idx * m * round_embed_split_size + loa_load_idx * T_CUBE_MATRIX_SIZE], @@ -1119,7 +1119,7 @@ private: uint32_t qk_c_pingpong_flag = (now_idx * cur_q_seqlen + q_i) % 2; uint32_t qk_0_pingpong_flag = (embed_split_idx * cur_q_seqlen + q_i) % 2; WAIT_FLAG(M, MTE1, qk_0_pingpong_flag); - for (uint64_t loa_load_idx = 0; loa_load_idx < q_load_coeff / BLOCK_SIZE; ++loa_load_idx) { + for (uint64_t loa_load_idx = 0; loa_load_idx < q_load_coeff_round / BLOCK_SIZE; ++loa_load_idx) { l1_to_l0_a( l0a_buf_tensor[qk_0_pingpong_flag * 16384 + loa_load_idx * round_embed_split_size * BLOCK_SIZE], l1q_buf_addr_tensor[q_i * q_heads * BLOCK_SIZE + embed_split_idx * m * 128 + loa_load_idx * CUBE_MATRIX_SIZE], @@ -1611,7 +1611,7 @@ public: WAIT_FLAG(V, MTE2, EVENT_ID4); WAIT_FLAG(V, MTE2, EVENT_ID7); WAIT_FLAG(MTE3, V, EVENT_ID2); - + PIPE_BARRIER(ALL); if constexpr (flashDecoding) { int reduce_flag_id = 3; FftsCrossCoreSync(reduce_flag_id); @@ -2406,11 +2406,8 @@ private: (split_num - 1) * __k * 4 // dstGap ); - WAIT_FLAG(MTE3, V, EVENT_ID3); SET_FLAG(MTE3, V, EVENT_ID3); - - WAIT_FLAG(MTE3, MTE2, EVENT_ID1); - SET_FLAG(MTE3, MTE2, EVENT_ID1); + WAIT_FLAG(MTE3, V, EVENT_ID3); } __aicore__ __attribute__((always_inline)) inline void CombineScaleBlock(uint32_t num_tokens, uint32_t q_heads, uint32_t kv_split_core_num, uint32_t embedding_size) @@ -3391,6 +3388,9 @@ private: if constexpr (std::is_same_v) { oPingPongFlag = 0; } + if (head_loop <= 1 && q_seq_len <= 1) { + oPingPongFlag = 0; + } uint32_t oUbOffset = oPingPongFlag * 4096; // oPingPongFlag * 8 * 512 WAIT_FLAG(V, MTE2, oPingPongFlag); if (n_idx != 4) { @@ -3423,7 +3423,7 @@ private: round_sub_m / FLOAT_BLOCK_SIZE // repeat ); PIPE_BARRIER(V); - if (head_loop > 1) { + if (head_loop > 1 || q_seq_len > 1) { gm_to_ub( go32_ubuf_tensor[oUbOffset], go_gm_tensor, @@ -3657,7 +3657,7 @@ private: } } } - else if (head_loop > 1) { + else if (head_loop > 1 || q_seq_len > 1) { SET_FLAG(V, MTE3, EVENT_ID5); WAIT_FLAG(V, MTE3, EVENT_ID5); ub_to_gm( diff --git a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp index e37116fa25d108a0c1ffff7d377adab3b1ee07d4..580ef077787f1c818e46d059a27ded103f89f118 100644 --- a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp +++ b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp @@ -73,7 +73,7 @@ Status GetFlashDecodingInfo(MLAInfo &mmInfo, OpParam::MLA ¶m, uint32_t block if (!mmInfo.flashDecoding) { return Status::OkStatus(); } - mmInfo.splitKVNum = blockDim / mmInfo.flashDecodingTaskNum > 1 ? blockDim / mmInfo.flashDecodingTaskNum : + mmInfo.splitKVNum = blockDim / mmInfo.flashDecodingTaskNum > 1 ? blockDim / mmInfo.flashDecodingTaskNum : CalcSplitNum(mmInfo, blockDim, *minKVSeqlen, mmInfo.blockSize); mmInfo.flashDecoding = mmInfo.splitKVNum == 1 ? false : true; if (mmInfo.flashDecoding) { @@ -84,7 +84,6 @@ Status GetFlashDecodingInfo(MLAInfo &mmInfo, OpParam::MLA ¶m, uint32_t block int32_t taskNum = mmInfo.quantFlag ? mmInfo.totalTaskNum : mmInfo.batch; mmInfo.normalTaskNum = taskNum / blockDim * blockDim; } - MKI_LOG(INFO) << "flashDecoding is = " << mmInfo.flashDecoding; return Status::OkStatus(); } @@ -115,16 +114,17 @@ Status GetMLANdInfo(const LaunchParam &launchParam, MLAInfo &mmInfo, mmInfo.numHeads = static_cast(param.headSize); mmInfo.maskType = static_cast(param.maskType); mmInfo.quantFlag = (static_cast(mmInfo.type) < NUM2) ? 0 : 1; - mmInfo.mtpTp1Flag = (mmInfo.numHeads == M_LIMIT); - if (mmInfo.mtpTp1Flag || static_cast(mmInfo.type) >= NUM2) { - mmInfo.maskType = 0; - } mmInfo.totalTaskNum = mmInfo.qSeqLen != nullptr ? std::accumulate(mmInfo.qSeqLen, mmInfo.qSeqLen + mmInfo.batch, static_cast(0)) : mmInfo.batch; - if (mmInfo.mtpTp1Flag) { - OP_TILING_CHECK_STATUS_RETURN(GetFlashDecodingInfo(mmInfo, param, blockDim)); + OP_TILING_CHECK_STATUS_RETURN(GetFlashDecodingInfo(mmInfo, param, blockDim)); + mmInfo.mtpTp1Flag = (mmInfo.numHeads == M_LIMIT || + (mmInfo.flashDecoding && !mmInfo.quantFlag && mmInfo.numHeads % NUM8 == 0)); + mmInfo.flashDecoding = !mmInfo.mtpTp1Flag ? false : mmInfo.flashDecoding; + if (mmInfo.mtpTp1Flag || static_cast(mmInfo.type) >= NUM2) { + mmInfo.maskType = 0; } + MKI_LOG(INFO) << "flashDecoding is = " << mmInfo.flashDecoding; return Status::OkStatus(); } diff --git a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp index eacfe351021c4a6a45386adc722f2a0681552689..a98a3720367e316a1871e5f4c358f0dd8301d62e 100644 --- a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp +++ b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp @@ -361,10 +361,7 @@ Status GetMLATilingParam(const LaunchParam &launchParam, MLAInfo &mmInfo, auto param = AnyCast(launchParam.GetParam()); float tor = mmInfo.tor; uint32_t *torPtr = reinterpret_cast(&tor); - uint64_t curTilingParamSize = mmInfo.mtpTp1Flag ? - (TILING_HEAD_SIZE + TILING_PARA_SIZE_TP1 * mmInfo.totalTaskNum) * sizeof(uint32_t) : - (TILING_HEAD_SIZE + TILING_PARA_SIZE * mmInfo.batch) * sizeof(uint32_t); - MKI_CHECK(memset_s(tilingParam, tilingParamSize, 0, curTilingParamSize) == EOK, "init tiling failed", + MKI_CHECK(memset_s(tilingParam, tilingParamSize, 0, tilingParamSize) == EOK, "init tiling failed", return Status::FailStatus(ERROR_INVALID_VALUE)); if (mmInfo.mtpTp1Flag) { if (mmInfo.flashDecoding) { diff --git a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.h b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.h index 1e20e6465aa37928b93833b006826e6c4e03abd7..bf467a0af4242f67ad1c3cf78fa1399a06dc4d48 100644 --- a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.h +++ b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.h @@ -26,7 +26,7 @@ constexpr int32_t TILING_PARA_SIZE = 8; constexpr int32_t TILING_PARA_SIZE_PREFILL = 27; constexpr int32_t TILING_HEAD_SIZE_PREFILL = 18; constexpr int32_t MAX_EMBEDDING = 576; -constexpr int32_t TILING_PARA_SIZE_TP1 = 7; +constexpr int32_t TILING_PARA_SIZE_TP1 = 8; constexpr int32_t TILING_HEAD_SIZE = 18; constexpr int32_t M_LIMIT = 128; constexpr int32_t ND_BATCH_LIMIT = INT32_MAX;