From 0f2688394bafebf6273330b5c1174f2084b011ee Mon Sep 17 00:00:00 2001 From: zhongsunjian1 Date: Fri, 23 May 2025 11:17:16 +0800 Subject: [PATCH 1/3] Init MLA MaskFree --- .../multi_latent_attention/mla_operation.cpp | 8 +- .../op_kernel/mla_prefill_high_precision.cce | 75 +++++++------------ .../tiling/mla_tiling.cpp | 11 +++ .../tiling/mla_tiling_dependency.cpp | 2 +- .../op_kernel/fa_common.cce | 3 +- 5 files changed, 45 insertions(+), 54 deletions(-) diff --git a/src/kernels/mixkernels/multi_latent_attention/mla_operation.cpp b/src/kernels/mixkernels/multi_latent_attention/mla_operation.cpp index 14a41b4d..e4e41b7f 100644 --- a/src/kernels/mixkernels/multi_latent_attention/mla_operation.cpp +++ b/src/kernels/mixkernels/multi_latent_attention/mla_operation.cpp @@ -35,8 +35,12 @@ public: case OpParam::MLA::SPLIT_CACHE: return GetKernelByName("MLAKernel"); case OpParam::MLA::PREFILL_SPLIT_CACHE: - return inDtype == TENSOR_DTYPE_BF16 ? GetKernelByName("MLAPrefillBF16Kernel") : - GetKernelByName("MLAPrefillKernel"); + if (param.maskType == OpParam::MLA::MASK_TYPE_MASK_FREE) { + return GetKernelByName("MLAPrefillBF16Kernel"); + } else { + return inDtype == TENSOR_DTYPE_BF16 ? GetKernelByName("MLAPrefillBF16Kernel") : + GetKernelByName("MLAPrefillKernel"); + } default: break; } diff --git a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_prefill_high_precision.cce b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_prefill_high_precision.cce index 851d2156..6da1ff83 100644 --- a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_prefill_high_precision.cce +++ b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_prefill_high_precision.cce @@ -622,40 +622,6 @@ private: uint32_t tiling_head_size{0}; uint32_t tiling_para_size{0}; }; - -extern "C" __global__ __aicore__ void mla_prefill_bf16( - __gm__ uint8_t *__restrict__ sync, - __gm__ uint8_t *__restrict__ q_gm, - __gm__ uint8_t *__restrict__ q_rope_gm, - __gm__ uint8_t *__restrict__ k_gm, - __gm__ uint8_t *__restrict__ k_rope_gm, - __gm__ uint8_t *__restrict__ v_gm, - __gm__ uint8_t *__restrict__ mask_gm, - __gm__ uint8_t *__restrict__ alibi_coeff_gm, - __gm__ uint8_t *__restrict__ deq_qk_gm, - __gm__ uint8_t *__restrict__ off_qk_gm, - __gm__ uint8_t *__restrict__ deq_pv_gm, - __gm__ uint8_t *__restrict__ off_pv_gm, - __gm__ uint8_t *__restrict__ quant_p_gm, - __gm__ uint8_t *__restrict__ logN_gm, - __gm__ uint8_t *__restrict__ o_gm, - __gm__ uint8_t *__restrict__ s_gm, - __gm__ uint8_t *__restrict__ p_gm, - __gm__ uint8_t *__restrict__ o_tmp_gm, - __gm__ uint8_t *__restrict__ upo_tmp_gm, - __gm__ uint8_t *__restrict__ tiling_para_gm) -{ - uint32_t tilingKey = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 12)); - // 仅在tilingKey为1的时候是bf,其余场景均当做half处理 - if (tilingKey == 1) { - FlashAttentionEncoderHighPrecision<__bf16, __bf16, __bf16, 0> fa_cube(sync, q_gm, q_rope_gm, k_gm, k_rope_gm, v_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm); - fa_cube.Run(); - } else if (tilingKey == 2){ - FlashAttentionEncoderHighPrecision<__bf16, __bf16, __bf16, 1> fa_cube(sync, q_gm, q_rope_gm, k_gm, k_rope_gm, v_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm); - fa_cube.Run(); - } -} - #endif #ifdef __DAV_C220_VEC__ @@ -685,7 +651,7 @@ __aicore__ __attribute__((always_inline)) inline void __set_mask(int32_t len) } } -template +template __aicore__ __attribute__((always_inline)) void OnlineSoftmaxStage1BF16( const AscendC::LocalTensor &s_ub, const AscendC::LocalTensor &mask_orig_ub, @@ -706,7 +672,7 @@ __aicore__ 
__attribute__((always_inline)) void OnlineSoftmaxStage1BF16( bool move_flag) { uint32_t round_m = (m + FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE * FLOAT_BLOCK_SIZE; - if constexpr(MASK_TYPE != 0) { + if constexpr(MASK_TYPE == MaskType::MASK_TYPE_CAUSAL_COMPRESS) { if (move_flag) { WAIT_FLAG(MTE2, V, EVENT_ID6); conv_v( @@ -758,7 +724,7 @@ __aicore__ __attribute__((always_inline)) void OnlineSoftmaxStage1BF16( 8 // srcRepeatStride ); PIPE_BARRIER(V); - if constexpr(MASK_TYPE != 0) { + if constexpr(MASK_TYPE != MaskType::MASK_TYPE_NONE) { if (move_flag) { add_v( s_ub, @@ -1012,7 +978,7 @@ __aicore__ __attribute__((always_inline)) void OnlineSoftmaxStage1BF16( } } -template +template class FlashAttentionEncoderHighPrecisionVec { public: __aicore__ __attribute__((always_inline)) inline FlashAttentionEncoderHighPrecisionVec( @@ -1048,7 +1014,7 @@ public: if (this->data_shape_type == 1) { this->stride_qo = embdv; } - if constexpr (MASK_TYPE != 0) { + if constexpr (MASK_TYPE != MaskType::MASK_TYPE_NONE) { this->is_triu_mask = 1; } this->__k = embd; @@ -1204,7 +1170,7 @@ public: uint32_t curr_m = m_ind == m_end - 1 ? sub_m - m_ind * m_slice : m_slice; uint32_t s_ub_offset = pingpong_flag * S_DB_SIZE; uint32_t row_offset = m_ind * m_slice; - if constexpr(MASK_TYPE != 0) { + if constexpr(MASK_TYPE == MaskType::MASK_TYPE_CAUSAL_COMPRESS) { if (move_flag) { uint64_t cmp_mask_offset = (m_idx - n_idx) * 512 * 128; WAIT_FLAG(V, MTE2, EVENT_ID6); @@ -1523,6 +1489,7 @@ private: uint32_t data_shape_type{0}; uint32_t quantType{0}; }; +#endif extern "C" __global__ __aicore__ void mla_prefill_bf16( __gm__ uint8_t *__restrict__ sync, @@ -1546,17 +1513,25 @@ extern "C" __global__ __aicore__ void mla_prefill_bf16( __gm__ uint8_t *__restrict__ upo_tmp_gm, __gm__ uint8_t *__restrict__ tiling_para_gm) { - uint32_t tilingKey = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 12)); - if (tilingKey == 1) { - FlashAttentionEncoderHighPrecisionVec<__bf16, __bf16, __bf16, 0> fa_vec(sync, mask_gm, alibi_coeff_gm, o_gm, s_gm, p_gm, o_tmp_gm, - tiling_para_gm, deq_qk_gm, off_qk_gm, quant_p_gm, - deq_pv_gm, off_pv_gm, logN_gm); + if (TILING_KEY_IS(1)) +#ifdef __DAV_C220_CUBE__ + FlashAttentionEncoderHighPrecision<__bf16, __bf16, __bf16, 0> fa_cube(sync, q_gm, q_rope_gm, k_gm, k_rope_gm, v_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm); + fa_cube.Run(); +#elif __DAV_C220_VEC__ + FlashAttentionEncoderHighPrecisionVec<__bf16, __bf16, __bf16, MaskType::MASK_TYPE_NONE> fa_vec(sync, mask_gm, alibi_coeff_gm, o_gm, s_gm, p_gm, o_tmp_gm, + tiling_para_gm, deq_qk_gm, off_qk_gm, quant_p_gm, + deq_pv_gm, off_pv_gm, logN_gm); fa_vec.Run(); - } else if (tilingKey == 2) { +#endif + } else if (TILING_KEY_IS(11)) { +#ifdef __DAV_C220_CUBE__ FlashAttentionEncoderHighPrecisionVec<__bf16, __bf16, __bf16, 1> fa_vec(sync, mask_gm, alibi_coeff_gm, o_gm, s_gm, p_gm, o_tmp_gm, - tiling_para_gm, deq_qk_gm, off_qk_gm, quant_p_gm, - deq_pv_gm, off_pv_gm, logN_gm); + tiling_para_gm, deq_qk_gm, off_qk_gm, quant_p_gm, + deq_pv_gm, off_pv_gm, logN_gm); fa_vec.Run(); - } -} +#elif __DAV_C220_VEC__ + FlashAttentionEncoderHighPrecision<__bf16, __bf16, __bf16, MaskType::MASK_TYPE_CAUSAL_COMPRESS> fa_cube(sync, q_gm, q_rope_gm, k_gm, k_rope_gm, v_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm); + fa_cube.Run(); #endif +} +} \ No newline at end of file diff --git a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp index e37116fa..db492e30 100644 --- 
a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp +++ b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp @@ -309,6 +309,15 @@ Status InitInfo(MLAInfo &mmInfo, OpParam::MLA ¶m) return Status::OkStatus(); } +Status GenMlaPrefillTilingKey(MLAInfo &mmInfo, KernelInfo &kernelInfo, OpParam::MLA ¶m) +{ + uint32_t dataType = static_cast(mmInfo.type); + uint32_t tilingKey = dataType + (mmInfo.maskType << NUM1); + kernelInfo.SetTilingId(tilingKey); + MKI_LOG(INFO) << "MLA Prefill TILING KEY IS = " << tilingKey; + return Status::OkStatus(); +} + Status MLAPrefillTiling(const LaunchParam &launchParam, KernelInfo &kernelInfo) { MLAInfo mmInfo; @@ -333,6 +342,8 @@ Status MLAPrefillTiling(const LaunchParam &launchParam, KernelInfo &kernelInfo) uint32_t *tilingParam = reinterpret_cast(tilingHost); ret = GetMLAPrefillTilingParam(mmInfo, blockDim, tilingParam, tilingSize); OP_TILING_CHECK_STATUS_RETURN(ret); + ret = GenMlaPrefillTilingKey(mmInfo, kernelInfo, param); + OP_TILING_CHECK_STATUS_RETURN(ret); uint64_t dataLenFloat = sizeof(float); uint64_t sSize = static_cast(blockDim) * static_cast(32768) * 16 * dataLenFloat; kernelInfo.GetScratchSizes() = {sSize, sSize, sSize, sSize * 2}; // oTmp/S/P diff --git a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp index eacfe351..bccbbe39 100644 --- a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp +++ b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp @@ -439,7 +439,7 @@ void PrefillTilingHead(const MLAInfo &mmInfo, const uint32_t &torUptr, AddrOffse tilingParam[NUM9] = static_cast(addrOffsets.totalQBlkNum); tilingParam[NUM10] = static_cast(TILING_HEAD_SIZE_PREFILL); tilingParam[NUM11] = static_cast(TILING_PARA_SIZE_PREFILL); - tilingParam[NUM12] = mmInfo.maskType == 0 ? 
1 : 2; + tilingParam[NUM12] = mmInfo.maskType; tilingParam[NUM13] = 0; tilingParam[NUM14] = mmInfo.maxKvSeqLen; tilingParam[NUM15] = mmInfo.maskType; diff --git a/src/kernels/mixkernels/unpad_flash_attention/op_kernel/fa_common.cce b/src/kernels/mixkernels/unpad_flash_attention/op_kernel/fa_common.cce index 91630817..92930bb0 100644 --- a/src/kernels/mixkernels/unpad_flash_attention/op_kernel/fa_common.cce +++ b/src/kernels/mixkernels/unpad_flash_attention/op_kernel/fa_common.cce @@ -50,7 +50,8 @@ enum class MaskType { MASK_TYPE_ALIBI_COMPRESS_SQRT = 7, MASK_TYPE_ALIBI_COMPRESS_LEFT_ALIGN = 8, MASK_TYPE_ALIBI_COMPRESS_128 = 9, - MASK_TYPE_CAUSAL_MASK = 10 // mask内部生成 + MASK_TYPE_CAUSAL_MASK = 10, + MASK_TYPE_MASK_FREE = 11 }; __aicore__ __attribute__((always_inline)) inline void SetVecMask(int32_t len) -- Gitee From debabe846f94bd9a9a2cc772e6607eebcab1060c Mon Sep 17 00:00:00 2001 From: zhongsunjian1 Date: Tue, 27 May 2025 16:40:54 +0800 Subject: [PATCH 2/3] mla_maskfree code update --- .../multi_latent_attention/mla_operation.cpp | 2 +- .../op_kernel/mla_prefill_high_precision.cce | 100 +++++++++++----- .../tiling/mla_tiling.cpp | 4 +- .../tiling/mla_tiling_dependency.cpp | 4 +- .../kernelstest/mix/test_mla_prefill_rope.py | 109 ++++++++++++++++++ 5 files changed, 187 insertions(+), 32 deletions(-) diff --git a/src/kernels/mixkernels/multi_latent_attention/mla_operation.cpp b/src/kernels/mixkernels/multi_latent_attention/mla_operation.cpp index e4e41b7f..bf5d7a7f 100644 --- a/src/kernels/mixkernels/multi_latent_attention/mla_operation.cpp +++ b/src/kernels/mixkernels/multi_latent_attention/mla_operation.cpp @@ -210,7 +210,7 @@ private: uint32_t batch = kvSeqLen.size(); auto maxQSeqlenIter = std::max_element(qSeqLen.begin(), qSeqLen.end()); auto maxQ = maxQSeqlenIter != qSeqLen.end() ? *maxQSeqlenIter : 1; - if (param.maskType == OpParam::MLA::MASK_TYPE_NONE) { + if (param.maskType == OpParam::MLA::MASK_TYPE_NONE || param.maskType == OpParam::MLA::MASK_TYPE_MASK_FREE) { MKI_CHECK(CheckEmptyTensor(mask), "mask type inconsistent", return false); } else { MKI_CHECK(!CheckEmptyTensor(mask), "mask type inconsistent", return false); diff --git a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_prefill_high_precision.cce b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_prefill_high_precision.cce index 6da1ff83..d9adf7f1 100644 --- a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_prefill_high_precision.cce +++ b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_prefill_high_precision.cce @@ -220,7 +220,7 @@ public: uint32_t m_loop = (q_seqlen + pp_m_scalar - 1) / pp_m_scalar; uint32_t n_loop = (kv_seqlen + pp_n_scalar - 1) / pp_n_scalar; - + uint32_t prefixIdx = (kv_seqlen - q_seqlen) / pp_n_scalar; uint32_t qk_m = (m_idx == (m_loop - 1)) ? (q_seqlen - m_idx * pp_m_scalar) : pp_m_scalar; uint32_t qk_round_m = (qk_m + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; @@ -300,7 +300,7 @@ public: uint32_t n_end = n_loop; if (is_triu_mask) { - n_end = m_idx + 1; + n_end = m_idx + prefixIdx + 1; } uint32_t launch_delay = S_BLOCK_STACK * 2; uint32_t vect_mod = 2 * launch_delay; @@ -1072,6 +1072,8 @@ public: o_tmp_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(o_tmp_gm)); logN_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ IN_DATA_TYPE*>(logN_gm)); logN_float_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(logN_gm)); + float mask_coff = std::is_same_v ? 
HALF_MASK_COFF : BF16_MASK_COFF; + GenMaskMatrix(mask16_ubuf_tensor.template ReinterpretCast(), mask_coff); uint32_t next_process = 0; for (uint32_t process = block_idx; process < process_num; process = next_process) { while (process >= cur_total_q_blk_num * q_heads) { @@ -1107,6 +1109,7 @@ public: uint32_t m_loop = (q_seqlen + pp_m_scalar - 1) / pp_m_scalar; uint32_t n_loop = (kv_seqlen + pp_n_scalar - 1) / pp_n_scalar; + uint32_t prefixIdx = (kv_seqlen - q_seqlen) / pp_n_scalar; uint32_t qk_m = (m_idx == (m_loop - 1)) ? (q_seqlen - m_idx * pp_m_scalar) : pp_m_scalar; uint32_t sub_m = (sub_block_idx == 1) ? (qk_m - qk_m / 2) : qk_m / 2; @@ -1132,7 +1135,7 @@ public: uint32_t n_end = n_loop; if (is_triu_mask) { - n_end = m_idx + 1; + n_end = m_idx + prefixIdx + 1; } uint32_t qk_n_triu = n_end * pp_n_scalar; uint32_t pv_stage = 3; @@ -1153,7 +1156,8 @@ public: continue; } uint32_t p_scale_offset = n_idx / S_BLOCK_STACK % pv_stage * pp_m_scalar; - if (n_idx + S_BLOCK_STACK > n_end - 1) { + bool last_n_loop = n_idx + S_BLOCK_STACK > n_end - 1; + if (last_n_loop) { qk_n = qk_n_triu > kv_seqlen ? kv_seqlen - n_idx * pp_n_scalar : qk_n_triu - n_idx * pp_n_scalar; } else { qk_n = pp_n_scalar * S_BLOCK_STACK; @@ -1195,26 +1199,48 @@ public: if (curr_m == 0) { continue; } - OnlineSoftmaxStage1BF16 ( - ls_ubuf_tensor[pingpong_flag * S_DB_SIZE], - mask16_ubuf_tensor, - mask_ubuf_tensor, - lm_ubuf_tensor[m_ind * m_slice], - hm_ubuf_tensor[m_ind * m_slice], - gm_ubuf_tensor[m_ind * m_slice], - dm_ubuf_tensor[((n_idx / S_BLOCK_STACK) % 4) * UB_FLOAT_LINE_SIZE + m_ind * m_slice], - ls_ubuf_tensor[pingpong_flag * S_DB_SIZE], - ll_ubuf_tensor[m_ind * m_slice], - gl_ubuf_tensor[m_ind * m_slice], - lp_ubuf_tensor[pingpong_flag * S_DB_SIZE].ReinterpretCast(), - tv_ubuf_tensor, - s_gm_tensor[(uint64_t)block_idx * TMP_SIZE + n_idx % vect_mod * TMP_SIZE / vect_mod + - (uint64_t)sub_block_idx * qk_m / 2 * qk_round_n + m_ind * m_slice * qk_round_n], - p_gm_tensor[((uint64_t)block_idx * TMP_SIZE + n_idx % vect_mod * TMP_SIZE / vect_mod + - ((uint64_t)sub_block_idx * qk_m / 2 + m_ind * m_slice) * qk_round_n) * 2 / sizeof(QKV_DT)], - n_idx == 0, this->tor, - curr_m, qk_n, qk_round_n, pingpong_flag, move_flag - ); + if constexpr (MASK_TYPE == MaskType::MASK_TYPE_MASK_FREE) { + OnlineSoftmaxStage1 ( + ls_ubuf_tensor[pingpong_flag * S_DB_SIZE], + mask16_ubuf_tensor.template ReinterpretCast(), + lm_ubuf_tensor[m_ind * m_slice], + hm_ubuf_tensor[m_ind * m_slice], + gm_ubuf_tensor[m_ind * m_slice], + dm_ubuf_tensor[((n_idx / S_BLOCK_STACK) % 4) * UB_FLOAT_LINE_SIZE + m_ind * m_slice], + ls_ubuf_tensor[pingpong_flag * S_DB_SIZE], + ll_ubuf_tensor[m_ind * m_slice], + gl_ubuf_tensor[m_ind * m_slice], + lp_ubuf_tensor[pingpong_flag * S_DB_SIZE].ReinterpretCast(), + tv_ubuf_tensor, + s_gm_tensor[(uint64_t)block_idx * TMP_SIZE + n_idx % vect_mod * TMP_SIZE / vect_mod + + (uint64_t)sub_block_idx * qk_m / 2 * qk_round_n + m_ind * m_slice * qk_round_n], + p_gm_tensor[((uint64_t)block_idx * TMP_SIZE + n_idx % vect_mod * TMP_SIZE / vect_mod + + ((uint64_t)sub_block_idx * qk_m / 2 + m_ind * m_slice) * qk_round_n) * 2 / sizeof(QKV_DT)], + n_idx == 0, this->tor, mask_coff, curr_m, qk_n, qk_round_n, pp_n_scalar, + last_n_loop, m_ind * m_slice, qk_m / 2, pingpong_flag + ); + } else { + OnlineSoftmaxStage1BF16 ( + ls_ubuf_tensor[pingpong_flag * S_DB_SIZE], + mask16_ubuf_tensor, + mask_ubuf_tensor, + lm_ubuf_tensor[m_ind * m_slice], + hm_ubuf_tensor[m_ind * m_slice], + gm_ubuf_tensor[m_ind * m_slice], + dm_ubuf_tensor[((n_idx / 
S_BLOCK_STACK) % 4) * UB_FLOAT_LINE_SIZE + m_ind * m_slice], + ls_ubuf_tensor[pingpong_flag * S_DB_SIZE], + ll_ubuf_tensor[m_ind * m_slice], + gl_ubuf_tensor[m_ind * m_slice], + lp_ubuf_tensor[pingpong_flag * S_DB_SIZE].ReinterpretCast(), + tv_ubuf_tensor, + s_gm_tensor[(uint64_t)block_idx * TMP_SIZE + n_idx % vect_mod * TMP_SIZE / vect_mod + + (uint64_t)sub_block_idx * qk_m / 2 * qk_round_n + m_ind * m_slice * qk_round_n], + p_gm_tensor[((uint64_t)block_idx * TMP_SIZE + n_idx % vect_mod * TMP_SIZE / vect_mod + + ((uint64_t)sub_block_idx * qk_m / 2 + m_ind * m_slice) * qk_round_n) * 2 / sizeof(QKV_DT)], + n_idx == 0, this->tor, + curr_m, qk_n, qk_round_n, pingpong_flag, move_flag + ); + } pingpong_flag = 1 - pingpong_flag; } FftsCrossCoreSync(SOFTMAX_READY); @@ -1513,7 +1539,7 @@ extern "C" __global__ __aicore__ void mla_prefill_bf16( __gm__ uint8_t *__restrict__ upo_tmp_gm, __gm__ uint8_t *__restrict__ tiling_para_gm) { - if (TILING_KEY_IS(1)) + if (TILING_KEY_IS(1)) { #ifdef __DAV_C220_CUBE__ FlashAttentionEncoderHighPrecision<__bf16, __bf16, __bf16, 0> fa_cube(sync, q_gm, q_rope_gm, k_gm, k_rope_gm, v_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm); fa_cube.Run(); @@ -1525,13 +1551,33 @@ extern "C" __global__ __aicore__ void mla_prefill_bf16( #endif } else if (TILING_KEY_IS(11)) { #ifdef __DAV_C220_CUBE__ - FlashAttentionEncoderHighPrecisionVec<__bf16, __bf16, __bf16, 1> fa_vec(sync, mask_gm, alibi_coeff_gm, o_gm, s_gm, p_gm, o_tmp_gm, + FlashAttentionEncoderHighPrecision<__bf16, __bf16, __bf16, 1> fa_cube(sync, q_gm, q_rope_gm, k_gm, k_rope_gm, v_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm); + fa_cube.Run(); +#elif __DAV_C220_VEC__ + FlashAttentionEncoderHighPrecisionVec<__bf16, __bf16, __bf16, MaskType::MASK_TYPE_CAUSAL_COMPRESS> fa_vec(sync, mask_gm, alibi_coeff_gm, o_gm, s_gm, p_gm, o_tmp_gm, tiling_para_gm, deq_qk_gm, off_qk_gm, quant_p_gm, deq_pv_gm, off_pv_gm, logN_gm); fa_vec.Run(); +#endif + } else if (TILING_KEY_IS(8)) { +#ifdef __DAV_C220_CUBE__ + FlashAttentionEncoderHighPrecision fa_cube(sync, q_gm, q_rope_gm, k_gm, k_rope_gm, v_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm); + fa_cube.Run(); #elif __DAV_C220_VEC__ - FlashAttentionEncoderHighPrecision<__bf16, __bf16, __bf16, MaskType::MASK_TYPE_CAUSAL_COMPRESS> fa_cube(sync, q_gm, q_rope_gm, k_gm, k_rope_gm, v_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm); + FlashAttentionEncoderHighPrecisionVec fa_vec(sync, mask_gm, alibi_coeff_gm, o_gm, s_gm, p_gm, o_tmp_gm, + tiling_para_gm, deq_qk_gm, off_qk_gm, quant_p_gm, + deq_pv_gm, off_pv_gm, logN_gm); + fa_vec.Run(); +#endif + } else if (TILING_KEY_IS(9)) { +#ifdef __DAV_C220_CUBE__ + FlashAttentionEncoderHighPrecision<__bf16, __bf16, __bf16, 1> fa_cube(sync, q_gm, q_rope_gm, k_gm, k_rope_gm, v_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm); fa_cube.Run(); +#elif __DAV_C220_VEC__ + FlashAttentionEncoderHighPrecisionVec<__bf16, __bf16, __bf16, MaskType::MASK_TYPE_MASK_FREE> fa_vec(sync, mask_gm, alibi_coeff_gm, o_gm, s_gm, p_gm, o_tmp_gm, + tiling_para_gm, deq_qk_gm, off_qk_gm, quant_p_gm, + deq_pv_gm, off_pv_gm, logN_gm); + fa_vec.Run(); #endif } } \ No newline at end of file diff --git a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp index db492e30..a523aa79 100644 --- a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp +++ b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp @@ -226,7 +226,7 @@ inline Status 
GetPrefiillMaskInfo(MLAInfo &mmInfo, OpParam::MLA ¶m, auto maskType = param.maskType; auto maxKvSeq = std::max_element(param.kvSeqLen.begin(), param.kvSeqLen.end()); MKI_LOG(INFO) << "max kv seq" << *maxKvSeq; - if (maskType != OpParam::MLA::MASK_TYPE_NONE) { + if (maskType != OpParam::MLA::MASK_TYPE_NONE && maskType != OpParam::MLA::MASK_TYPE_MASK_FREE) { auto maskDim = maskShape.size(); int32_t maxSeq = maskShape.at(maskDim - 1); mmInfo.maxSeqLen = maxSeq; @@ -311,6 +311,7 @@ Status InitInfo(MLAInfo &mmInfo, OpParam::MLA ¶m) Status GenMlaPrefillTilingKey(MLAInfo &mmInfo, KernelInfo &kernelInfo, OpParam::MLA ¶m) { + // currently only support fp16/bf16 maskfree prefill or bf16 prefill kernel uint32_t dataType = static_cast(mmInfo.type); uint32_t tilingKey = dataType + (mmInfo.maskType << NUM1); kernelInfo.SetTilingId(tilingKey); @@ -342,6 +343,7 @@ Status MLAPrefillTiling(const LaunchParam &launchParam, KernelInfo &kernelInfo) uint32_t *tilingParam = reinterpret_cast(tilingHost); ret = GetMLAPrefillTilingParam(mmInfo, blockDim, tilingParam, tilingSize); OP_TILING_CHECK_STATUS_RETURN(ret); + GetTilingKeyTypeBase(mmInfo, mmInfo.tensors.query, mmInfo.tensors.queryRope); ret = GenMlaPrefillTilingKey(mmInfo, kernelInfo, param); OP_TILING_CHECK_STATUS_RETURN(ret); uint64_t dataLenFloat = sizeof(float); diff --git a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp index bccbbe39..92abb298 100644 --- a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp +++ b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp @@ -397,8 +397,6 @@ Status CheckSeqlen(const MLAInfo &mmInfo, const AddrOffsets &addrOffsets, int32_ UNUSED_VALUE(seqIdx); MKI_CHECK(qSeqlen >= 0, "qSeqlen is invalid", return Status::FailStatus(ERROR_INVALID_VALUE)); MKI_CHECK(kvSeqlen >= 0, "kvSeqlen is invalid", return Status::FailStatus(ERROR_INVALID_VALUE)); - MKI_CHECK(kvSeqlen == qSeqlen, "kvSeqlen should equal qSeqlen", - return Status::FailStatus(ERROR_INVALID_VALUE)); MKI_CHECK(addrOffsets.totalQBlkNum >= 0, "totalQBlkNum overflow", return Status::FailStatus(ERROR_INVALID_VALUE)); return Status::OkStatus(); @@ -439,7 +437,7 @@ void PrefillTilingHead(const MLAInfo &mmInfo, const uint32_t &torUptr, AddrOffse tilingParam[NUM9] = static_cast(addrOffsets.totalQBlkNum); tilingParam[NUM10] = static_cast(TILING_HEAD_SIZE_PREFILL); tilingParam[NUM11] = static_cast(TILING_PARA_SIZE_PREFILL); - tilingParam[NUM12] = mmInfo.maskType; + tilingParam[NUM12] = mmInfo.maskType == 0 ? 
1 : 2; tilingParam[NUM13] = 0; tilingParam[NUM14] = mmInfo.maxKvSeqLen; tilingParam[NUM15] = mmInfo.maskType; diff --git a/tests/apitest/kernelstest/mix/test_mla_prefill_rope.py b/tests/apitest/kernelstest/mix/test_mla_prefill_rope.py index 82ab88a4..2dd4e260 100644 --- a/tests/apitest/kernelstest/mix/test_mla_prefill_rope.py +++ b/tests/apitest/kernelstest/mix/test_mla_prefill_rope.py @@ -14,6 +14,7 @@ import unittest import math import numpy as np sys.path.append('../') +sys.path.append('/opt/cl/ascend-transformer-boost/tests/apitest/kernelstest') sys.path.append('../..') import op_test import torch @@ -999,5 +1000,113 @@ class TestMLAPrefill(op_test.OpTest): [torch.tensor(attention_out, dtype=data_type)]) + @op_test.only_910b + def test_flash_attention_mla_bf16_mask_free(self): + batch = 1 + kv_head = 1 # kv_head num + isdecoder = 0 # prefill or decoder + heads = 1 # llama7b hidden_size 4096 + embeddim = 192 + embeddimV = 128 + max_seq = 128 + tor = 1.0 / math.sqrt(1.0 * embeddim) + dynamic_batch = False + kv_seqLen = [128] * batch + is_clamp = 0 + clamp_min = 0 + clamp_max = 0 + OP_NAME = "MLAOperation" + OP_PARAM = {"type": 1, "qSeqLen":kv_seqLen, "kvSeqLen": kv_seqLen, "headDimV": embeddimV,"headSize": heads, "tor": tor, "maskType": 4, "kvHead": heads} + self.set_param(OP_NAME, OP_PARAM) + self.set_input_formats([self.format_nd] * 13) + self.set_output_formats([self.format_nd]) + data_type = torch.bfloat16 + + self.set_data_params(dynamic_batch = dynamic_batch, + is_decoder = isdecoder, batch = batch, kv_head = kv_head, heads = heads, + embeddim = embeddim,embeddimv = embeddimV, max_seq = max_seq, kv_seqLen = kv_seqLen, + is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, + data_type = data_type, is_alibi = False, is_triu_mask = True, + op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH, tor = tor, long_seq = True) + self.gen_out_tensor() + self.mask = self.mask.view(512, 512).to(data_type) + + + logging.debug("**********input shape***********") + logging.info(f"q shape: {self.q.shape}") + logging.info(f"k shape: {self.k.shape}") + logging.info(f"v shape: {self.v.shape}") + logging.info(f"layer_id shape: {self.layer_id.shape}") + logging.info(f"mask shape: {self.mask.shape}") + logging.info(f"golden shape: {self.golden_out.shape}") + + attention_out = np.zeros_like(self.golden_out.to(torch.float16)) + + for i in range(1): + self.execute([self.q_split1, self.q_split2, self.k_split1[0], self.k_split2[0], self.v[0], + torch.tensor([], dtype=data_type), + torch.tensor([], dtype=torch.float), + torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.int32), + torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.int32), + torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.float)], + [torch.tensor(attention_out, dtype=data_type)]) + + + @op_test.only_910b + def test_flash_attention_mla_bf16_prefix_with_mask_free(self): + batch = 1 + kv_head = 1 # kv_head num + isdecoder = 0 # prefill or decoder + heads = 1 # llama7b hidden_size 4096 + embeddim = 192 + embeddimV = 128 + max_seq = 1280 + tor = 1.0 / math.sqrt(1.0 * embeddim) + dynamic_batch = False + prefix = 256 + content = 129 + q_seqLen = [content] * batch + kv_seqLen = [prefix + content] * batch + is_clamp = 0 + clamp_min = 0 + clamp_max = 0 + OP_NAME = "MLAOperation" + OP_PARAM = {"type": 1, "qSeqLen":q_seqLen, "kvSeqLen": kv_seqLen, "headDimV": embeddimV,"headSize": heads, "tor": tor, "maskType": 4, "kvHead": heads} + self.set_param(OP_NAME, OP_PARAM) + 
self.set_input_formats([self.format_nd] * 13) + self.set_output_formats([self.format_nd]) + data_type = torch.bfloat16 + + self.set_data_params(dynamic_batch = dynamic_batch, q_seqlens=q_seqLen, + is_decoder = isdecoder, batch = batch, kv_head = kv_head, heads = heads, + embeddim = embeddim,embeddimv = embeddimV, max_seq = max_seq, kv_seqLen = kv_seqLen, + is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, + data_type = data_type, is_alibi = False, is_triu_mask = True, + op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH_WITH_PREFIX, tor = tor, long_seq = True) + self.gen_out_tensor() + self.mask = self.mask.view(512, 512).to(data_type) + + + logging.debug("**********input shape***********") + logging.info(f"q shape: {self.q.shape}") + logging.info(f"k shape: {self.k.shape}") + logging.info(f"v shape: {self.v.shape}") + logging.info(f"layer_id shape: {self.layer_id.shape}") + logging.info(f"mask shape: {self.mask.shape}") + logging.info(f"golden shape: {self.golden_out.shape}") + + attention_out = np.zeros_like(self.golden_out.to(torch.float16)) + + for i in range(1): + self.execute([self.q_split1, self.q_split2, self.k_split1[0], self.k_split2[0], self.v[0], + torch.tensor([], dtype=data_type), + torch.tensor([], dtype=torch.float), + torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.int32), + torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.int32), + torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.float)], + [torch.tensor(attention_out, dtype=data_type)]) + + + if __name__ == '__main__': unittest.main() \ No newline at end of file -- Gitee From 8d5fb35ef4dfc0f22c82a7b81eabe7b4a3fa8cb5 Mon Sep 17 00:00:00 2001 From: zhongsunjian1 Date: Thu, 29 May 2025 16:55:55 +0800 Subject: [PATCH 3/3] update CAUSAL_MASK MaskType --- src/kernels/include/atbops/params/mla.h | 2 +- .../multi_latent_attention/mla_operation.cpp | 6 +-- .../op_kernel/mla_prefill_high_precision.cce | 44 +++++++++---------- .../tiling/mla_tiling.cpp | 29 ++++++------ ...ti_latent_attention_ops_runner_prefill.cpp | 4 +- .../kernelstest/mix/test_mla_prefill_rope.py | 13 +++--- 6 files changed, 49 insertions(+), 49 deletions(-) diff --git a/src/kernels/include/atbops/params/mla.h b/src/kernels/include/atbops/params/mla.h index 7a55e959..dc1159a1 100644 --- a/src/kernels/include/atbops/params/mla.h +++ b/src/kernels/include/atbops/params/mla.h @@ -37,7 +37,7 @@ struct MLA { MASK_TYPE_ALIBI = 2, MASK_TYPE_LOOK_AHEAD = 3, MASK_TYPE_MASK_FREE = 4, - MASK_TYPE_CAUSAL_COMPRESS = 5 + MASK_TYPE_CAUSAL_MASK = 5 }; MaskType maskType = MASK_TYPE_NONE; diff --git a/src/kernels/mixkernels/multi_latent_attention/mla_operation.cpp b/src/kernels/mixkernels/multi_latent_attention/mla_operation.cpp index bf5d7a7f..9ac7addc 100644 --- a/src/kernels/mixkernels/multi_latent_attention/mla_operation.cpp +++ b/src/kernels/mixkernels/multi_latent_attention/mla_operation.cpp @@ -35,7 +35,7 @@ public: case OpParam::MLA::SPLIT_CACHE: return GetKernelByName("MLAKernel"); case OpParam::MLA::PREFILL_SPLIT_CACHE: - if (param.maskType == OpParam::MLA::MASK_TYPE_MASK_FREE) { + if (param.maskType == OpParam::MLA::MASK_TYPE_CAUSAL_MASK) { return GetKernelByName("MLAPrefillBF16Kernel"); } else { return inDtype == TENSOR_DTYPE_BF16 ? 
GetKernelByName("MLAPrefillBF16Kernel") : @@ -191,7 +191,7 @@ private: auto currentShape = tensorMask.desc.dims; auto maskShape = currentShape.size(); MKI_CHECK(maskShape == DIM_2, "mask invalid, please check.", return false); - auto normCompress = param.maskType == OpParam::MLA::MASK_TYPE_CAUSAL_COMPRESS; + auto normCompress = param.maskType == OpParam::MLA::MASK_TYPE_MASK_FREE; // 全量mask当前仅支持512,512的压缩mask,其余不支持,需配合isTriuMask开启 std::vector, bool>> supports = { {{LONG_SEQ_LEN_COMPRESS, LONG_SEQ_LEN_COMPRESS}, normCompress}, @@ -210,7 +210,7 @@ private: uint32_t batch = kvSeqLen.size(); auto maxQSeqlenIter = std::max_element(qSeqLen.begin(), qSeqLen.end()); auto maxQ = maxQSeqlenIter != qSeqLen.end() ? *maxQSeqlenIter : 1; - if (param.maskType == OpParam::MLA::MASK_TYPE_NONE || param.maskType == OpParam::MLA::MASK_TYPE_MASK_FREE) { + if (param.maskType == OpParam::MLA::MASK_TYPE_NONE || param.maskType == OpParam::MLA::MASK_TYPE_CAUSAL_MASK) { MKI_CHECK(CheckEmptyTensor(mask), "mask type inconsistent", return false); } else { MKI_CHECK(!CheckEmptyTensor(mask), "mask type inconsistent", return false); diff --git a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_prefill_high_precision.cce b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_prefill_high_precision.cce index d9adf7f1..ab6d60e8 100644 --- a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_prefill_high_precision.cce +++ b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_prefill_high_precision.cce @@ -672,7 +672,7 @@ __aicore__ __attribute__((always_inline)) void OnlineSoftmaxStage1BF16( bool move_flag) { uint32_t round_m = (m + FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE * FLOAT_BLOCK_SIZE; - if constexpr(MASK_TYPE == MaskType::MASK_TYPE_CAUSAL_COMPRESS) { + if constexpr(MASK_TYPE == MaskType::MASK_TYPE_MASK_FREE) { if (move_flag) { WAIT_FLAG(MTE2, V, EVENT_ID6); conv_v( @@ -724,7 +724,7 @@ __aicore__ __attribute__((always_inline)) void OnlineSoftmaxStage1BF16( 8 // srcRepeatStride ); PIPE_BARRIER(V); - if constexpr(MASK_TYPE != MaskType::MASK_TYPE_NONE) { + if constexpr(MASK_TYPE == MaskType::MASK_TYPE_MASK_FREE) { if (move_flag) { add_v( s_ub, @@ -1174,7 +1174,7 @@ public: uint32_t curr_m = m_ind == m_end - 1 ? 
sub_m - m_ind * m_slice : m_slice; uint32_t s_ub_offset = pingpong_flag * S_DB_SIZE; uint32_t row_offset = m_ind * m_slice; - if constexpr(MASK_TYPE == MaskType::MASK_TYPE_CAUSAL_COMPRESS) { + if constexpr(MASK_TYPE == MaskType::MASK_TYPE_MASK_FREE) { if (move_flag) { uint64_t cmp_mask_offset = (m_idx - n_idx) * 512 * 128; WAIT_FLAG(V, MTE2, EVENT_ID6); @@ -1199,17 +1199,17 @@ public: if (curr_m == 0) { continue; } - if constexpr (MASK_TYPE == MaskType::MASK_TYPE_MASK_FREE) { + if constexpr (MASK_TYPE == MaskType::MASK_TYPE_CAUSAL_MASK) { OnlineSoftmaxStage1 ( ls_ubuf_tensor[pingpong_flag * S_DB_SIZE], mask16_ubuf_tensor.template ReinterpretCast(), - lm_ubuf_tensor[m_ind * m_slice], - hm_ubuf_tensor[m_ind * m_slice], - gm_ubuf_tensor[m_ind * m_slice], - dm_ubuf_tensor[((n_idx / S_BLOCK_STACK) % 4) * UB_FLOAT_LINE_SIZE + m_ind * m_slice], + lm_ubuf_tensor[row_offset], + hm_ubuf_tensor[row_offset], + gm_ubuf_tensor[row_offset], + dm_ubuf_tensor[((n_idx / S_BLOCK_STACK) % 4) * UB_FLOAT_LINE_SIZE + row_offset], ls_ubuf_tensor[pingpong_flag * S_DB_SIZE], - ll_ubuf_tensor[m_ind * m_slice], - gl_ubuf_tensor[m_ind * m_slice], + ll_ubuf_tensor[row_offset], + gl_ubuf_tensor[row_offset], lp_ubuf_tensor[pingpong_flag * S_DB_SIZE].ReinterpretCast(), tv_ubuf_tensor, s_gm_tensor[(uint64_t)block_idx * TMP_SIZE + n_idx % vect_mod * TMP_SIZE / vect_mod + @@ -1224,13 +1224,13 @@ public: ls_ubuf_tensor[pingpong_flag * S_DB_SIZE], mask16_ubuf_tensor, mask_ubuf_tensor, - lm_ubuf_tensor[m_ind * m_slice], - hm_ubuf_tensor[m_ind * m_slice], - gm_ubuf_tensor[m_ind * m_slice], - dm_ubuf_tensor[((n_idx / S_BLOCK_STACK) % 4) * UB_FLOAT_LINE_SIZE + m_ind * m_slice], + lm_ubuf_tensor[row_offset], + hm_ubuf_tensor[row_offset], + gm_ubuf_tensor[row_offset], + dm_ubuf_tensor[((n_idx / S_BLOCK_STACK) % 4) * UB_FLOAT_LINE_SIZE + row_offset], ls_ubuf_tensor[pingpong_flag * S_DB_SIZE], - ll_ubuf_tensor[m_ind * m_slice], - gl_ubuf_tensor[m_ind * m_slice], + ll_ubuf_tensor[row_offset], + gl_ubuf_tensor[row_offset], lp_ubuf_tensor[pingpong_flag * S_DB_SIZE].ReinterpretCast(), tv_ubuf_tensor, s_gm_tensor[(uint64_t)block_idx * TMP_SIZE + n_idx % vect_mod * TMP_SIZE / vect_mod + @@ -1549,32 +1549,32 @@ extern "C" __global__ __aicore__ void mla_prefill_bf16( deq_pv_gm, off_pv_gm, logN_gm); fa_vec.Run(); #endif - } else if (TILING_KEY_IS(11)) { + } else if (TILING_KEY_IS(9)) { #ifdef __DAV_C220_CUBE__ FlashAttentionEncoderHighPrecision<__bf16, __bf16, __bf16, 1> fa_cube(sync, q_gm, q_rope_gm, k_gm, k_rope_gm, v_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm); fa_cube.Run(); #elif __DAV_C220_VEC__ - FlashAttentionEncoderHighPrecisionVec<__bf16, __bf16, __bf16, MaskType::MASK_TYPE_CAUSAL_COMPRESS> fa_vec(sync, mask_gm, alibi_coeff_gm, o_gm, s_gm, p_gm, o_tmp_gm, + FlashAttentionEncoderHighPrecisionVec<__bf16, __bf16, __bf16, MaskType::MASK_TYPE_MASK_FREE> fa_vec(sync, mask_gm, alibi_coeff_gm, o_gm, s_gm, p_gm, o_tmp_gm, tiling_para_gm, deq_qk_gm, off_qk_gm, quant_p_gm, deq_pv_gm, off_pv_gm, logN_gm); fa_vec.Run(); #endif - } else if (TILING_KEY_IS(8)) { + } else if (TILING_KEY_IS(10)) { #ifdef __DAV_C220_CUBE__ FlashAttentionEncoderHighPrecision fa_cube(sync, q_gm, q_rope_gm, k_gm, k_rope_gm, v_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm); fa_cube.Run(); #elif __DAV_C220_VEC__ - FlashAttentionEncoderHighPrecisionVec fa_vec(sync, mask_gm, alibi_coeff_gm, o_gm, s_gm, p_gm, o_tmp_gm, + FlashAttentionEncoderHighPrecisionVec fa_vec(sync, mask_gm, alibi_coeff_gm, o_gm, s_gm, p_gm, o_tmp_gm, tiling_para_gm, deq_qk_gm, off_qk_gm, 
quant_p_gm, deq_pv_gm, off_pv_gm, logN_gm); fa_vec.Run(); #endif - } else if (TILING_KEY_IS(9)) { + } else if (TILING_KEY_IS(11)) { #ifdef __DAV_C220_CUBE__ FlashAttentionEncoderHighPrecision<__bf16, __bf16, __bf16, 1> fa_cube(sync, q_gm, q_rope_gm, k_gm, k_rope_gm, v_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm); fa_cube.Run(); #elif __DAV_C220_VEC__ - FlashAttentionEncoderHighPrecisionVec<__bf16, __bf16, __bf16, MaskType::MASK_TYPE_MASK_FREE> fa_vec(sync, mask_gm, alibi_coeff_gm, o_gm, s_gm, p_gm, o_tmp_gm, + FlashAttentionEncoderHighPrecisionVec<__bf16, __bf16, __bf16, MaskType::MASK_TYPE_CAUSAL_MASK> fa_vec(sync, mask_gm, alibi_coeff_gm, o_gm, s_gm, p_gm, o_tmp_gm, tiling_para_gm, deq_qk_gm, off_qk_gm, quant_p_gm, deq_pv_gm, off_pv_gm, logN_gm); fa_vec.Run(); diff --git a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp index a523aa79..b5ea63c1 100644 --- a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp +++ b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp @@ -222,23 +222,24 @@ inline Status PrefillPreCheck(OpParam::MLA ¶m) inline Status GetPrefiillMaskInfo(MLAInfo &mmInfo, OpParam::MLA ¶m, const Tensor &tensorMask) { - auto maskShape = tensorMask.desc.dims; auto maskType = param.maskType; + if (maskType == OpParam::MLA::MASK_TYPE_NONE || maskType == OpParam::MLA::MASK_TYPE_CAUSAL_MASK) { + MKI_LOG(INFO) << "No Check for nomask or causal mask"; + return Status::OkStatus(); + } + auto maskShape = tensorMask.desc.dims; auto maxKvSeq = std::max_element(param.kvSeqLen.begin(), param.kvSeqLen.end()); MKI_LOG(INFO) << "max kv seq" << *maxKvSeq; - if (maskType != OpParam::MLA::MASK_TYPE_NONE && maskType != OpParam::MLA::MASK_TYPE_MASK_FREE) { - auto maskDim = maskShape.size(); - int32_t maxSeq = maskShape.at(maskDim - 1); - mmInfo.maxSeqLen = maxSeq; - MKI_CHECK(maskType == OpParam::MLA::MASK_TYPE_CAUSAL_COMPRESS, - "mask type invalid", - return Status::FailStatus(ERROR_INVALID_VALUE)); - MKI_CHECK(maskDim == DIM_2, "maskdim invalid", - return Status::FailStatus(ERROR_INVALID_VALUE)); - MKI_CHECK(maskShape.at(1) == NORM_CMP_MASK_LEN, "compress mask shape should be 512, 512", - return Status::FailStatus(ERROR_INVALID_VALUE)); - } - + auto maskDim = maskShape.size(); + int32_t maxSeq = maskShape.at(maskDim - 1); + mmInfo.maxSeqLen = maxSeq; + MKI_CHECK(maskType == OpParam::MLA::MASK_TYPE_MASK_FREE, + "mask type invalid", + return Status::FailStatus(ERROR_INVALID_VALUE)); + MKI_CHECK(maskDim == DIM_2, "maskdim invalid", + return Status::FailStatus(ERROR_INVALID_VALUE)); + MKI_CHECK(maskShape.at(1) == NORM_CMP_MASK_LEN, "compress mask shape should be 512, 512", + return Status::FailStatus(ERROR_INVALID_VALUE)); return Status::OkStatus(); } diff --git a/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp b/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp index 0c0e7aef..c0be80c8 100644 --- a/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp +++ b/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp @@ -67,7 +67,7 @@ Status MultiLatentAttentionOpsRunnerPrefill::SetupKernelGraph(const OpsTensorPac asdParam.kvHead = param_.kvHeadNum; asdParam.isRing = 0; asdParam.maskType = param_.maskType == infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_MASK_FREE ? 
- AtbOps::OpParam::MLA::MaskType::MASK_TYPE_CAUSAL_COMPRESS : + AtbOps::OpParam::MLA::MaskType::MASK_TYPE_MASK_FREE : AtbOps::OpParam::MLA::MaskType::MASK_TYPE_NONE; mlaNode.opDesc = {0, "MLAOperation", asdParam}; @@ -98,7 +98,7 @@ Status MultiLatentAttentionOpsRunnerPrefill::ModifyKernelGraph(const OpsTensorPa asdParam.qSeqLen = newParam_.qSeqlen; asdParam.isRing = 0; asdParam.maskType = param_.maskType == infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_MASK_FREE ? - AtbOps::OpParam::MLA::MaskType::MASK_TYPE_CAUSAL_COMPRESS : + AtbOps::OpParam::MLA::MaskType::MASK_TYPE_MASK_FREE : AtbOps::OpParam::MLA::MaskType::MASK_TYPE_NONE; mlaNode.opDesc = {0, "MLAOperation", asdParam}; return NO_ERROR; diff --git a/tests/apitest/kernelstest/mix/test_mla_prefill_rope.py b/tests/apitest/kernelstest/mix/test_mla_prefill_rope.py index 2dd4e260..ada3810e 100644 --- a/tests/apitest/kernelstest/mix/test_mla_prefill_rope.py +++ b/tests/apitest/kernelstest/mix/test_mla_prefill_rope.py @@ -14,7 +14,6 @@ import unittest import math import numpy as np sys.path.append('../') -sys.path.append('/opt/cl/ascend-transformer-boost/tests/apitest/kernelstest') sys.path.append('../..') import op_test import torch @@ -866,7 +865,7 @@ class TestMLAPrefill(op_test.OpTest): clamp_min = 0 clamp_max = 0 OP_NAME = "MLAOperation" - OP_PARAM = {"type": 1, "qSeqLen":kv_seqLen, "kvSeqLen": kv_seqLen, "headDimV": embeddimV,"headSize": heads, "tor": tor, "maskType": 5} + OP_PARAM = {"type": 1, "qSeqLen":kv_seqLen, "kvSeqLen": kv_seqLen, "headDimV": embeddimV,"headSize": heads, "tor": tor, "maskType": 4} self.set_param(OP_NAME, OP_PARAM) self.set_input_formats([self.format_nd] * 13) self.set_output_formats([self.format_nd]) @@ -964,7 +963,7 @@ class TestMLAPrefill(op_test.OpTest): clamp_min = 0 clamp_max = 0 OP_NAME = "MLAOperation" - OP_PARAM = {"type": 1, "qSeqLen":kv_seqLen, "kvSeqLen": kv_seqLen, "headDimV": embeddimV,"headSize": heads, "tor": tor, "maskType": 5, "kvHead": heads} + OP_PARAM = {"type": 1, "qSeqLen":kv_seqLen, "kvSeqLen": kv_seqLen, "headDimV": embeddimV,"headSize": heads, "tor": tor, "maskType": 4, "kvHead": heads} self.set_param(OP_NAME, OP_PARAM) self.set_input_formats([self.format_nd] * 13) self.set_output_formats([self.format_nd]) @@ -1001,7 +1000,7 @@ class TestMLAPrefill(op_test.OpTest): @op_test.only_910b - def test_flash_attention_mla_bf16_mask_free(self): + def test_flash_attention_mla_bf16_causal_mask(self): batch = 1 kv_head = 1 # kv_head num isdecoder = 0 # prefill or decoder @@ -1016,7 +1015,7 @@ class TestMLAPrefill(op_test.OpTest): clamp_min = 0 clamp_max = 0 OP_NAME = "MLAOperation" - OP_PARAM = {"type": 1, "qSeqLen":kv_seqLen, "kvSeqLen": kv_seqLen, "headDimV": embeddimV,"headSize": heads, "tor": tor, "maskType": 4, "kvHead": heads} + OP_PARAM = {"type": 1, "qSeqLen":kv_seqLen, "kvSeqLen": kv_seqLen, "headDimV": embeddimV,"headSize": heads, "tor": tor, "maskType": 5, "kvHead": heads} self.set_param(OP_NAME, OP_PARAM) self.set_input_formats([self.format_nd] * 13) self.set_output_formats([self.format_nd]) @@ -1053,7 +1052,7 @@ class TestMLAPrefill(op_test.OpTest): @op_test.only_910b - def test_flash_attention_mla_bf16_prefix_with_mask_free(self): + def test_flash_attention_mla_bf16_prefix_with_causal_mask(self): batch = 1 kv_head = 1 # kv_head num isdecoder = 0 # prefill or decoder @@ -1071,7 +1070,7 @@ class TestMLAPrefill(op_test.OpTest): clamp_min = 0 clamp_max = 0 OP_NAME = "MLAOperation" - OP_PARAM = {"type": 1, "qSeqLen":q_seqLen, "kvSeqLen": kv_seqLen, "headDimV": embeddimV,"headSize": 
heads, "tor": tor, "maskType": 4, "kvHead": heads} + OP_PARAM = {"type": 1, "qSeqLen":q_seqLen, "kvSeqLen": kv_seqLen, "headDimV": embeddimV,"headSize": heads, "tor": tor, "maskType": 5, "kvHead": heads} self.set_param(OP_NAME, OP_PARAM) self.set_input_formats([self.format_nd] * 13) self.set_output_formats([self.format_nd]) -- Gitee