diff --git a/src/kernels/include/atbops/params/mla.h b/src/kernels/include/atbops/params/mla.h
index 7a55e959664ad2cc5fab76336e54ad76f48166d8..dc1159a16e4318345672a38e60830d6c1606203b 100644
--- a/src/kernels/include/atbops/params/mla.h
+++ b/src/kernels/include/atbops/params/mla.h
@@ -37,7 +37,7 @@ struct MLA {
         MASK_TYPE_ALIBI = 2,
         MASK_TYPE_LOOK_AHEAD = 3,
         MASK_TYPE_MASK_FREE = 4,
-        MASK_TYPE_CAUSAL_COMPRESS = 5
+        MASK_TYPE_CAUSAL_MASK = 5
     };
 
     MaskType maskType = MASK_TYPE_NONE;
diff --git a/src/kernels/mixkernels/multi_latent_attention/mla_operation.cpp b/src/kernels/mixkernels/multi_latent_attention/mla_operation.cpp
index 14a41b4d86f1a398ca67621e0dcb2722911204a9..9ac7addcdda24425acbcaa75805fe96e49e02003 100644
--- a/src/kernels/mixkernels/multi_latent_attention/mla_operation.cpp
+++ b/src/kernels/mixkernels/multi_latent_attention/mla_operation.cpp
@@ -35,8 +35,12 @@ public:
             case OpParam::MLA::SPLIT_CACHE:
                 return GetKernelByName("MLAKernel");
             case OpParam::MLA::PREFILL_SPLIT_CACHE:
-                return inDtype == TENSOR_DTYPE_BF16 ? GetKernelByName("MLAPrefillBF16Kernel") :
-                                                      GetKernelByName("MLAPrefillKernel");
+                if (param.maskType == OpParam::MLA::MASK_TYPE_CAUSAL_MASK) {
+                    return GetKernelByName("MLAPrefillBF16Kernel");
+                } else {
+                    return inDtype == TENSOR_DTYPE_BF16 ? GetKernelByName("MLAPrefillBF16Kernel") :
+                                                          GetKernelByName("MLAPrefillKernel");
+                }
             default:
                 break;
         }
@@ -187,7 +191,7 @@ private:
         auto currentShape = tensorMask.desc.dims;
         auto maskShape = currentShape.size();
         MKI_CHECK(maskShape == DIM_2, "mask invalid, please check.", return false);
-        auto normCompress = param.maskType == OpParam::MLA::MASK_TYPE_CAUSAL_COMPRESS;
+        auto normCompress = param.maskType == OpParam::MLA::MASK_TYPE_MASK_FREE;
         // The full prefill mask currently only supports the 512x512 compressed mask; other shapes are unsupported and require isTriuMask to be enabled
         std::vector, bool>> supports = {
             {{LONG_SEQ_LEN_COMPRESS, LONG_SEQ_LEN_COMPRESS}, normCompress},
@@ -206,7 +210,7 @@ private:
         uint32_t batch = kvSeqLen.size();
         auto maxQSeqlenIter = std::max_element(qSeqLen.begin(), qSeqLen.end());
         auto maxQ = maxQSeqlenIter != qSeqLen.end() ? *maxQSeqlenIter : 1;
-        if (param.maskType == OpParam::MLA::MASK_TYPE_NONE) {
+        if (param.maskType == OpParam::MLA::MASK_TYPE_NONE || param.maskType == OpParam::MLA::MASK_TYPE_CAUSAL_MASK) {
            MKI_CHECK(CheckEmptyTensor(mask), "mask type inconsistent", return false);
         } else {
            MKI_CHECK(!CheckEmptyTensor(mask), "mask type inconsistent", return false);
diff --git a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_prefill_high_precision.cce b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_prefill_high_precision.cce
index 851d2156a3623720216a655602234ec0f4471e3f..ab6d60e8b2e42e0815432efffb38c894149aebd8 100644
--- a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_prefill_high_precision.cce
+++ b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_prefill_high_precision.cce
@@ -220,7 +220,7 @@ public:
 
         uint32_t m_loop = (q_seqlen + pp_m_scalar - 1) / pp_m_scalar;
         uint32_t n_loop = (kv_seqlen + pp_n_scalar - 1) / pp_n_scalar;
-
+        uint32_t prefixIdx = (kv_seqlen - q_seqlen) / pp_n_scalar;
         uint32_t qk_m = (m_idx == (m_loop - 1)) ?
                         (q_seqlen - m_idx * pp_m_scalar) : pp_m_scalar;
         uint32_t qk_round_m = (qk_m + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE;
@@ -300,7 +300,7 @@ public:
 
         uint32_t n_end = n_loop;
         if (is_triu_mask) {
-            n_end = m_idx + 1;
+            n_end = m_idx + prefixIdx + 1;
         }
         uint32_t launch_delay = S_BLOCK_STACK * 2;
         uint32_t vect_mod = 2 * launch_delay;
@@ -622,40 +622,6 @@ private:
     uint32_t tiling_head_size{0};
     uint32_t tiling_para_size{0};
 };
-
-extern "C" __global__ __aicore__ void mla_prefill_bf16(
-    __gm__ uint8_t *__restrict__ sync,
-    __gm__ uint8_t *__restrict__ q_gm,
-    __gm__ uint8_t *__restrict__ q_rope_gm,
-    __gm__ uint8_t *__restrict__ k_gm,
-    __gm__ uint8_t *__restrict__ k_rope_gm,
-    __gm__ uint8_t *__restrict__ v_gm,
-    __gm__ uint8_t *__restrict__ mask_gm,
-    __gm__ uint8_t *__restrict__ alibi_coeff_gm,
-    __gm__ uint8_t *__restrict__ deq_qk_gm,
-    __gm__ uint8_t *__restrict__ off_qk_gm,
-    __gm__ uint8_t *__restrict__ deq_pv_gm,
-    __gm__ uint8_t *__restrict__ off_pv_gm,
-    __gm__ uint8_t *__restrict__ quant_p_gm,
-    __gm__ uint8_t *__restrict__ logN_gm,
-    __gm__ uint8_t *__restrict__ o_gm,
-    __gm__ uint8_t *__restrict__ s_gm,
-    __gm__ uint8_t *__restrict__ p_gm,
-    __gm__ uint8_t *__restrict__ o_tmp_gm,
-    __gm__ uint8_t *__restrict__ upo_tmp_gm,
-    __gm__ uint8_t *__restrict__ tiling_para_gm)
-{
-    uint32_t tilingKey = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 12));
-    // Only tilingKey == 1 is bf16; every other case is treated as half
-    if (tilingKey == 1) {
-        FlashAttentionEncoderHighPrecision<__bf16, __bf16, __bf16, 0> fa_cube(sync, q_gm, q_rope_gm, k_gm, k_rope_gm, v_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm);
-        fa_cube.Run();
-    } else if (tilingKey == 2){
-        FlashAttentionEncoderHighPrecision<__bf16, __bf16, __bf16, 1> fa_cube(sync, q_gm, q_rope_gm, k_gm, k_rope_gm, v_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm);
-        fa_cube.Run();
-    }
-}
-
 #endif
 
 #ifdef __DAV_C220_VEC__
@@ -685,7 +651,7 @@ __aicore__ __attribute__((always_inline)) inline void __set_mask(int32_t len)
     }
 }
 
-template
+template
 __aicore__ __attribute__((always_inline)) void OnlineSoftmaxStage1BF16(
     const AscendC::LocalTensor &s_ub,
     const AscendC::LocalTensor &mask_orig_ub,
@@ -706,7 +672,7 @@ __aicore__ __attribute__((always_inline)) void OnlineSoftmaxStage1BF16(
     bool move_flag)
 {
     uint32_t round_m = (m + FLOAT_BLOCK_SIZE - 1) / FLOAT_BLOCK_SIZE * FLOAT_BLOCK_SIZE;
-    if constexpr(MASK_TYPE != 0) {
+    if constexpr(MASK_TYPE == MaskType::MASK_TYPE_MASK_FREE) {
         if (move_flag) {
             WAIT_FLAG(MTE2, V, EVENT_ID6);
             conv_v(
@@ -758,7 +724,7 @@ __aicore__ __attribute__((always_inline)) void OnlineSoftmaxStage1BF16(
                 8  // srcRepeatStride
             );
             PIPE_BARRIER(V);
-    if constexpr(MASK_TYPE != 0) {
+    if constexpr(MASK_TYPE == MaskType::MASK_TYPE_MASK_FREE) {
         if (move_flag) {
             add_v(
                 s_ub,
@@ -1012,7 +978,7 @@ __aicore__ __attribute__((always_inline)) void OnlineSoftmaxStage1BF16(
     }
 }
 
-template
+template
 class FlashAttentionEncoderHighPrecisionVec {
 public:
     __aicore__ __attribute__((always_inline)) inline FlashAttentionEncoderHighPrecisionVec(
@@ -1048,7 +1014,7 @@ public:
         if (this->data_shape_type == 1) {
             this->stride_qo = embdv;
         }
-        if constexpr (MASK_TYPE != 0) {
+        if constexpr (MASK_TYPE != MaskType::MASK_TYPE_NONE) {
             this->is_triu_mask = 1;
         }
         this->__k = embd;
@@ -1106,6 +1072,8 @@ public:
         o_tmp_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(o_tmp_gm));
         logN_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ IN_DATA_TYPE*>(logN_gm));
         logN_float_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(logN_gm));
+        float mask_coff = std::is_same_v ? HALF_MASK_COFF : BF16_MASK_COFF;
+        GenMaskMatrix(mask16_ubuf_tensor.template ReinterpretCast(), mask_coff);
         uint32_t next_process = 0;
         for (uint32_t process = block_idx; process < process_num; process = next_process) {
             while (process >= cur_total_q_blk_num * q_heads) {
@@ -1141,6 +1109,7 @@ public:
 
             uint32_t m_loop = (q_seqlen + pp_m_scalar - 1) / pp_m_scalar;
             uint32_t n_loop = (kv_seqlen + pp_n_scalar - 1) / pp_n_scalar;
+            uint32_t prefixIdx = (kv_seqlen - q_seqlen) / pp_n_scalar;
             uint32_t qk_m = (m_idx == (m_loop - 1)) ?
                             (q_seqlen - m_idx * pp_m_scalar) : pp_m_scalar;
             uint32_t sub_m = (sub_block_idx == 1) ? (qk_m - qk_m / 2) : qk_m / 2;
@@ -1166,7 +1135,7 @@ public:
 
             uint32_t n_end = n_loop;
             if (is_triu_mask) {
-                n_end = m_idx + 1;
+                n_end = m_idx + prefixIdx + 1;
             }
             uint32_t qk_n_triu = n_end * pp_n_scalar;
             uint32_t pv_stage = 3;
@@ -1187,7 +1156,8 @@ public:
                     continue;
                 }
                 uint32_t p_scale_offset = n_idx / S_BLOCK_STACK % pv_stage * pp_m_scalar;
-                if (n_idx + S_BLOCK_STACK > n_end - 1) {
+                bool last_n_loop = n_idx + S_BLOCK_STACK > n_end - 1;
+                if (last_n_loop) {
                     qk_n = qk_n_triu > kv_seqlen ? kv_seqlen - n_idx * pp_n_scalar : qk_n_triu - n_idx * pp_n_scalar;
                 } else {
                     qk_n = pp_n_scalar * S_BLOCK_STACK;
@@ -1204,7 +1174,7 @@ public:
                     uint32_t curr_m = m_ind == m_end - 1 ? sub_m - m_ind * m_slice : m_slice;
                     uint32_t s_ub_offset = pingpong_flag * S_DB_SIZE;
                     uint32_t row_offset = m_ind * m_slice;
-                    if constexpr(MASK_TYPE != 0) {
+                    if constexpr(MASK_TYPE == MaskType::MASK_TYPE_MASK_FREE) {
                         if (move_flag) {
                             uint64_t cmp_mask_offset = (m_idx - n_idx) * 512 * 128;
                             WAIT_FLAG(V, MTE2, EVENT_ID6);
@@ -1229,26 +1199,48 @@ public:
                     if (curr_m == 0) {
                         continue;
                     }
-                    OnlineSoftmaxStage1BF16 (
-                        ls_ubuf_tensor[pingpong_flag * S_DB_SIZE],
-                        mask16_ubuf_tensor,
-                        mask_ubuf_tensor,
-                        lm_ubuf_tensor[m_ind * m_slice],
-                        hm_ubuf_tensor[m_ind * m_slice],
-                        gm_ubuf_tensor[m_ind * m_slice],
-                        dm_ubuf_tensor[((n_idx / S_BLOCK_STACK) % 4) * UB_FLOAT_LINE_SIZE + m_ind * m_slice],
-                        ls_ubuf_tensor[pingpong_flag * S_DB_SIZE],
-                        ll_ubuf_tensor[m_ind * m_slice],
-                        gl_ubuf_tensor[m_ind * m_slice],
-                        lp_ubuf_tensor[pingpong_flag * S_DB_SIZE].ReinterpretCast(),
-                        tv_ubuf_tensor,
-                        s_gm_tensor[(uint64_t)block_idx * TMP_SIZE + n_idx % vect_mod * TMP_SIZE / vect_mod +
-                                    (uint64_t)sub_block_idx * qk_m / 2 * qk_round_n + m_ind * m_slice * qk_round_n],
-                        p_gm_tensor[((uint64_t)block_idx * TMP_SIZE + n_idx % vect_mod * TMP_SIZE / vect_mod +
-                                    ((uint64_t)sub_block_idx * qk_m / 2 + m_ind * m_slice) * qk_round_n) * 2 / sizeof(QKV_DT)],
-                        n_idx == 0, this->tor,
-                        curr_m, qk_n, qk_round_n, pingpong_flag, move_flag
-                    );
+                    if constexpr (MASK_TYPE == MaskType::MASK_TYPE_CAUSAL_MASK) {
+                        OnlineSoftmaxStage1 (
+                            ls_ubuf_tensor[pingpong_flag * S_DB_SIZE],
+                            mask16_ubuf_tensor.template ReinterpretCast(),
+                            lm_ubuf_tensor[row_offset],
+                            hm_ubuf_tensor[row_offset],
+                            gm_ubuf_tensor[row_offset],
+                            dm_ubuf_tensor[((n_idx / S_BLOCK_STACK) % 4) * UB_FLOAT_LINE_SIZE + row_offset],
+                            ls_ubuf_tensor[pingpong_flag * S_DB_SIZE],
+                            ll_ubuf_tensor[row_offset],
+                            gl_ubuf_tensor[row_offset],
+                            lp_ubuf_tensor[pingpong_flag * S_DB_SIZE].ReinterpretCast(),
+                            tv_ubuf_tensor,
+                            s_gm_tensor[(uint64_t)block_idx * TMP_SIZE + n_idx % vect_mod * TMP_SIZE / vect_mod +
+                                        (uint64_t)sub_block_idx * qk_m / 2 * qk_round_n + m_ind * m_slice * qk_round_n],
+                            p_gm_tensor[((uint64_t)block_idx * TMP_SIZE + n_idx % vect_mod * TMP_SIZE / vect_mod +
+                                        ((uint64_t)sub_block_idx * qk_m / 2 + m_ind * m_slice) * qk_round_n) * 2 / sizeof(QKV_DT)],
+                            n_idx == 0, this->tor, mask_coff, curr_m, qk_n, qk_round_n, pp_n_scalar,
+                            last_n_loop, m_ind * m_slice, qk_m / 2, pingpong_flag
+                        );
+                    } else {
+                        OnlineSoftmaxStage1BF16 (
+                            ls_ubuf_tensor[pingpong_flag * S_DB_SIZE],
+                            mask16_ubuf_tensor,
+                            mask_ubuf_tensor,
+                            lm_ubuf_tensor[row_offset],
+                            hm_ubuf_tensor[row_offset],
+                            gm_ubuf_tensor[row_offset],
+                            dm_ubuf_tensor[((n_idx / S_BLOCK_STACK) % 4) * UB_FLOAT_LINE_SIZE + row_offset],
+                            ls_ubuf_tensor[pingpong_flag * S_DB_SIZE],
+                            ll_ubuf_tensor[row_offset],
+                            gl_ubuf_tensor[row_offset],
+                            lp_ubuf_tensor[pingpong_flag * S_DB_SIZE].ReinterpretCast(),
+                            tv_ubuf_tensor,
+                            s_gm_tensor[(uint64_t)block_idx * TMP_SIZE + n_idx % vect_mod * TMP_SIZE / vect_mod +
+                                        (uint64_t)sub_block_idx * qk_m / 2 * qk_round_n + m_ind * m_slice * qk_round_n],
+                            p_gm_tensor[((uint64_t)block_idx * TMP_SIZE + n_idx % vect_mod * TMP_SIZE / vect_mod +
+                                        ((uint64_t)sub_block_idx * qk_m / 2 + m_ind * m_slice) * qk_round_n) * 2 / sizeof(QKV_DT)],
+                            n_idx == 0, this->tor,
+                            curr_m, qk_n, qk_round_n, pingpong_flag, move_flag
+                        );
+                    }
                     pingpong_flag = 1 - pingpong_flag;
                 }
                 FftsCrossCoreSync(SOFTMAX_READY);
@@ -1523,6 +1515,7 @@ private:
     uint32_t data_shape_type{0};
     uint32_t quantType{0};
 };
+#endif
 
 extern "C" __global__ __aicore__ void mla_prefill_bf16(
     __gm__ uint8_t *__restrict__ sync,
@@ -1546,17 +1539,45 @@ extern "C" __global__ __aicore__ void mla_prefill_bf16(
     __gm__ uint8_t *__restrict__ upo_tmp_gm,
     __gm__ uint8_t *__restrict__ tiling_para_gm)
 {
-    uint32_t tilingKey = (uint32_t)(*((__gm__ int32_t *)tiling_para_gm + 12));
-    if (tilingKey == 1) {
-        FlashAttentionEncoderHighPrecisionVec<__bf16, __bf16, __bf16, 0> fa_vec(sync, mask_gm, alibi_coeff_gm, o_gm, s_gm, p_gm, o_tmp_gm,
-                                                                                tiling_para_gm, deq_qk_gm, off_qk_gm, quant_p_gm,
-                                                                                deq_pv_gm, off_pv_gm, logN_gm);
+    if (TILING_KEY_IS(1)) {
+#ifdef __DAV_C220_CUBE__
+        FlashAttentionEncoderHighPrecision<__bf16, __bf16, __bf16, 0> fa_cube(sync, q_gm, q_rope_gm, k_gm, k_rope_gm, v_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm);
+        fa_cube.Run();
+#elif __DAV_C220_VEC__
+        FlashAttentionEncoderHighPrecisionVec<__bf16, __bf16, __bf16, MaskType::MASK_TYPE_NONE> fa_vec(sync, mask_gm, alibi_coeff_gm, o_gm, s_gm, p_gm, o_tmp_gm,
+                                                                                                       tiling_para_gm, deq_qk_gm, off_qk_gm, quant_p_gm,
+                                                                                                       deq_pv_gm, off_pv_gm, logN_gm);
+        fa_vec.Run();
+#endif
+    } else if (TILING_KEY_IS(9)) {
+#ifdef __DAV_C220_CUBE__
+        FlashAttentionEncoderHighPrecision<__bf16, __bf16, __bf16, 1> fa_cube(sync, q_gm, q_rope_gm, k_gm, k_rope_gm, v_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm);
+        fa_cube.Run();
+#elif __DAV_C220_VEC__
+        FlashAttentionEncoderHighPrecisionVec<__bf16, __bf16, __bf16, MaskType::MASK_TYPE_MASK_FREE> fa_vec(sync, mask_gm, alibi_coeff_gm, o_gm, s_gm, p_gm, o_tmp_gm,
+                                                                                                            tiling_para_gm, deq_qk_gm, off_qk_gm, quant_p_gm,
+                                                                                                            deq_pv_gm, off_pv_gm, logN_gm);
         fa_vec.Run();
-    } else if (tilingKey == 2) {
-        FlashAttentionEncoderHighPrecisionVec<__bf16, __bf16, __bf16, 1> fa_vec(sync, mask_gm, alibi_coeff_gm, o_gm, s_gm, p_gm, o_tmp_gm,
-                                                                                tiling_para_gm, deq_qk_gm, off_qk_gm, quant_p_gm,
-                                                                                deq_pv_gm, off_pv_gm, logN_gm);
+#endif
+    } else if (TILING_KEY_IS(10)) {
+#ifdef __DAV_C220_CUBE__
+        FlashAttentionEncoderHighPrecision fa_cube(sync, q_gm, q_rope_gm, k_gm, k_rope_gm, v_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm);
+        fa_cube.Run();
+#elif __DAV_C220_VEC__
+        FlashAttentionEncoderHighPrecisionVec fa_vec(sync, mask_gm, alibi_coeff_gm, o_gm, s_gm, p_gm, o_tmp_gm,
+                                                     tiling_para_gm, deq_qk_gm, off_qk_gm, quant_p_gm,
+                                                     deq_pv_gm, off_pv_gm, logN_gm);
         fa_vec.Run();
-    }
-}
 #endif
+    } else if (TILING_KEY_IS(11)) {
+#ifdef __DAV_C220_CUBE__
+        FlashAttentionEncoderHighPrecision<__bf16, __bf16, __bf16, 1> fa_cube(sync, q_gm, q_rope_gm, k_gm, k_rope_gm, v_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm);
+        fa_cube.Run();
+#elif __DAV_C220_VEC__
+        FlashAttentionEncoderHighPrecisionVec<__bf16, __bf16, __bf16, MaskType::MASK_TYPE_CAUSAL_MASK> fa_vec(sync, mask_gm, alibi_coeff_gm, o_gm, s_gm, p_gm, o_tmp_gm,
+                                                                                                              tiling_para_gm, deq_qk_gm, off_qk_gm, quant_p_gm,
+                                                                                                              deq_pv_gm, off_pv_gm, logN_gm);
+        fa_vec.Run();
+#endif
+}
+}
\ No newline at end of file
diff --git a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp
index e37116fa25d108a0c1ffff7d377adab3b1ee07d4..b5ea63c133131ee44353a30a124b6383b5d4fd81 100644
--- a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp
+++ b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp
@@ -222,23 +222,24 @@ inline Status PrefillPreCheck(OpParam::MLA &param)
 
 inline Status GetPrefiillMaskInfo(MLAInfo &mmInfo, OpParam::MLA &param, const Tensor &tensorMask)
 {
-    auto maskShape = tensorMask.desc.dims;
     auto maskType = param.maskType;
+    if (maskType == OpParam::MLA::MASK_TYPE_NONE || maskType == OpParam::MLA::MASK_TYPE_CAUSAL_MASK) {
+        MKI_LOG(INFO) << "No Check for nomask or causal mask";
+        return Status::OkStatus();
+    }
+    auto maskShape = tensorMask.desc.dims;
     auto maxKvSeq = std::max_element(param.kvSeqLen.begin(), param.kvSeqLen.end());
     MKI_LOG(INFO) << "max kv seq" << *maxKvSeq;
-    if (maskType != OpParam::MLA::MASK_TYPE_NONE) {
-        auto maskDim = maskShape.size();
-        int32_t maxSeq = maskShape.at(maskDim - 1);
-        mmInfo.maxSeqLen = maxSeq;
-        MKI_CHECK(maskType == OpParam::MLA::MASK_TYPE_CAUSAL_COMPRESS,
-            "mask type invalid",
-            return Status::FailStatus(ERROR_INVALID_VALUE));
-        MKI_CHECK(maskDim == DIM_2, "maskdim invalid",
-            return Status::FailStatus(ERROR_INVALID_VALUE));
-        MKI_CHECK(maskShape.at(1) == NORM_CMP_MASK_LEN, "compress mask shape should be 512, 512",
-            return Status::FailStatus(ERROR_INVALID_VALUE));
-    }
-
+    auto maskDim = maskShape.size();
+    int32_t maxSeq = maskShape.at(maskDim - 1);
+    mmInfo.maxSeqLen = maxSeq;
+    MKI_CHECK(maskType == OpParam::MLA::MASK_TYPE_MASK_FREE,
+        "mask type invalid",
+        return Status::FailStatus(ERROR_INVALID_VALUE));
+    MKI_CHECK(maskDim == DIM_2, "maskdim invalid",
+        return Status::FailStatus(ERROR_INVALID_VALUE));
+    MKI_CHECK(maskShape.at(1) == NORM_CMP_MASK_LEN, "compress mask shape should be 512, 512",
+        return Status::FailStatus(ERROR_INVALID_VALUE));
     return Status::OkStatus();
 }
 
@@ -309,6 +310,16 @@ Status InitInfo(MLAInfo &mmInfo, OpParam::MLA &param)
     return Status::OkStatus();
 }
 
+Status GenMlaPrefillTilingKey(MLAInfo &mmInfo, KernelInfo &kernelInfo, OpParam::MLA &param)
+{
+    // currently only support fp16/bf16 maskfree prefill or bf16 prefill kernel
+    uint32_t dataType = static_cast<uint32_t>(mmInfo.type);
+    uint32_t tilingKey = dataType + (mmInfo.maskType << NUM1);
+    kernelInfo.SetTilingId(tilingKey);
+    MKI_LOG(INFO) << "MLA Prefill TILING KEY IS = " << tilingKey;
+    return Status::OkStatus();
+}
+
 Status MLAPrefillTiling(const LaunchParam &launchParam, KernelInfo &kernelInfo)
 {
     MLAInfo mmInfo;
@@ -333,6 +344,9 @@ Status MLAPrefillTiling(const LaunchParam &launchParam, KernelInfo &kernelInfo)
     uint32_t *tilingParam = reinterpret_cast<uint32_t *>(tilingHost);
     ret = GetMLAPrefillTilingParam(mmInfo, blockDim, tilingParam, tilingSize);
     OP_TILING_CHECK_STATUS_RETURN(ret);
+    GetTilingKeyTypeBase(mmInfo, mmInfo.tensors.query, mmInfo.tensors.queryRope);
+    ret = GenMlaPrefillTilingKey(mmInfo, kernelInfo, param);
+    OP_TILING_CHECK_STATUS_RETURN(ret);
     uint64_t dataLenFloat = sizeof(float);
     uint64_t sSize = static_cast<uint64_t>(blockDim) * static_cast<uint64_t>(32768) * 16 * dataLenFloat;
     kernelInfo.GetScratchSizes() = {sSize, sSize, sSize, sSize * 2}; // oTmp/S/P
diff --git a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp
index eacfe351021c4a6a45386adc722f2a0681552689..92abb29829d92c6109e6ed56b14147205af62c58 100644
--- a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp
+++ b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp
@@ -397,8 +397,6 @@ Status CheckSeqlen(const MLAInfo &mmInfo, const AddrOffsets &addrOffsets, int32_
     UNUSED_VALUE(seqIdx);
     MKI_CHECK(qSeqlen >= 0, "qSeqlen is invalid", return Status::FailStatus(ERROR_INVALID_VALUE));
     MKI_CHECK(kvSeqlen >= 0, "kvSeqlen is invalid", return Status::FailStatus(ERROR_INVALID_VALUE));
-    MKI_CHECK(kvSeqlen == qSeqlen, "kvSeqlen should equal qSeqlen",
-        return Status::FailStatus(ERROR_INVALID_VALUE));
     MKI_CHECK(addrOffsets.totalQBlkNum >= 0, "totalQBlkNum overflow",
         return Status::FailStatus(ERROR_INVALID_VALUE));
     return Status::OkStatus();
diff --git a/src/kernels/mixkernels/unpad_flash_attention/op_kernel/fa_common.cce b/src/kernels/mixkernels/unpad_flash_attention/op_kernel/fa_common.cce
index 91630817ab8cda8419418bf0a9a2fafc5a7257d4..92930bb0ad356f3c957e53eed87a4b0adfba098b 100644
--- a/src/kernels/mixkernels/unpad_flash_attention/op_kernel/fa_common.cce
+++ b/src/kernels/mixkernels/unpad_flash_attention/op_kernel/fa_common.cce
@@ -50,7 +50,8 @@ enum class MaskType {
     MASK_TYPE_ALIBI_COMPRESS_SQRT = 7,
     MASK_TYPE_ALIBI_COMPRESS_LEFT_ALIGN = 8,
     MASK_TYPE_ALIBI_COMPRESS_128 = 9,
-    MASK_TYPE_CAUSAL_MASK = 10 // mask is generated internally
+    MASK_TYPE_CAUSAL_MASK = 10,
+    MASK_TYPE_MASK_FREE = 11
 };
 
 __aicore__ __attribute__((always_inline)) inline void SetVecMask(int32_t len)
diff --git a/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp b/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp
index 0c0e7aef8a678977340faa635ce8bae3ed221fe2..c0be80c8ea47a37431c668cf67bd20c4e38d889d 100644
--- a/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp
+++ b/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp
@@ -67,7 +67,7 @@ Status MultiLatentAttentionOpsRunnerPrefill::SetupKernelGraph(const OpsTensorPac
     asdParam.kvHead = param_.kvHeadNum;
     asdParam.isRing = 0;
     asdParam.maskType = param_.maskType == infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_MASK_FREE ?
-                            AtbOps::OpParam::MLA::MaskType::MASK_TYPE_CAUSAL_COMPRESS :
+                            AtbOps::OpParam::MLA::MaskType::MASK_TYPE_MASK_FREE :
                             AtbOps::OpParam::MLA::MaskType::MASK_TYPE_NONE;
 
     mlaNode.opDesc = {0, "MLAOperation", asdParam};
@@ -98,7 +98,7 @@ Status MultiLatentAttentionOpsRunnerPrefill::ModifyKernelGraph(const OpsTensorPa
     asdParam.qSeqLen = newParam_.qSeqlen;
     asdParam.isRing = 0;
     asdParam.maskType = param_.maskType == infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_MASK_FREE ?
-                            AtbOps::OpParam::MLA::MaskType::MASK_TYPE_CAUSAL_COMPRESS :
+                            AtbOps::OpParam::MLA::MaskType::MASK_TYPE_MASK_FREE :
                             AtbOps::OpParam::MLA::MaskType::MASK_TYPE_NONE;
     mlaNode.opDesc = {0, "MLAOperation", asdParam};
     return NO_ERROR;
diff --git a/tests/apitest/kernelstest/mix/test_mla_prefill_rope.py b/tests/apitest/kernelstest/mix/test_mla_prefill_rope.py
index 82ab88a4c15dd1854894853e416d72184ac74b16..ada3810ebd07dbe50dbafa853dcef9a057501b76 100644
--- a/tests/apitest/kernelstest/mix/test_mla_prefill_rope.py
+++ b/tests/apitest/kernelstest/mix/test_mla_prefill_rope.py
@@ -865,7 +865,7 @@ class TestMLAPrefill(op_test.OpTest):
         clamp_min = 0
         clamp_max = 0
         OP_NAME = "MLAOperation"
-        OP_PARAM = {"type": 1, "qSeqLen":kv_seqLen, "kvSeqLen": kv_seqLen, "headDimV": embeddimV,"headSize": heads, "tor": tor, "maskType": 5}
+        OP_PARAM = {"type": 1, "qSeqLen":kv_seqLen, "kvSeqLen": kv_seqLen, "headDimV": embeddimV,"headSize": heads, "tor": tor, "maskType": 4}
         self.set_param(OP_NAME, OP_PARAM)
         self.set_input_formats([self.format_nd] * 13)
         self.set_output_formats([self.format_nd])
@@ -963,7 +963,7 @@ class TestMLAPrefill(op_test.OpTest):
         clamp_min = 0
         clamp_max = 0
         OP_NAME = "MLAOperation"
-        OP_PARAM = {"type": 1, "qSeqLen":kv_seqLen, "kvSeqLen": kv_seqLen, "headDimV": embeddimV,"headSize": heads, "tor": tor, "maskType": 5, "kvHead": heads}
+        OP_PARAM = {"type": 1, "qSeqLen":kv_seqLen, "kvSeqLen": kv_seqLen, "headDimV": embeddimV,"headSize": heads, "tor": tor, "maskType": 4, "kvHead": heads}
         self.set_param(OP_NAME, OP_PARAM)
         self.set_input_formats([self.format_nd] * 13)
         self.set_output_formats([self.format_nd])
@@ -999,5 +999,113 @@ class TestMLAPrefill(op_test.OpTest):
                          [torch.tensor(attention_out, dtype=data_type)])
 
+
+    @op_test.only_910b
+    def test_flash_attention_mla_bf16_causal_mask(self):
+        batch = 1
+        kv_head = 1 # kv_head num
+        isdecoder = 0 # prefill or decoder
+        heads = 1 # llama7b hidden_size 4096
+        embeddim = 192
+        embeddimV = 128
+        max_seq = 128
+        tor = 1.0 / math.sqrt(1.0 * embeddim)
+        dynamic_batch = False
+        kv_seqLen = [128] * batch
+        is_clamp = 0
+        clamp_min = 0
+        clamp_max = 0
+        OP_NAME = "MLAOperation"
+        OP_PARAM = {"type": 1, "qSeqLen":kv_seqLen, "kvSeqLen": kv_seqLen, "headDimV": embeddimV,"headSize": heads, "tor": tor, "maskType": 5, "kvHead": heads}
+        self.set_param(OP_NAME, OP_PARAM)
+        self.set_input_formats([self.format_nd] * 13)
+        self.set_output_formats([self.format_nd])
+        data_type = torch.bfloat16
+
+        self.set_data_params(dynamic_batch = dynamic_batch,
+                             is_decoder = isdecoder, batch = batch, kv_head = kv_head, heads = heads,
+                             embeddim = embeddim,embeddimv = embeddimV, max_seq = max_seq, kv_seqLen = kv_seqLen,
+                             is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min,
+                             data_type = data_type, is_alibi = False, is_triu_mask = True,
+                             op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH, tor = tor, long_seq = True)
+        self.gen_out_tensor()
+        self.mask = self.mask.view(512, 512).to(data_type)
+
+
+        logging.debug("**********input shape***********")
+        logging.info(f"q shape: {self.q.shape}")
+        logging.info(f"k shape: {self.k.shape}")
+        logging.info(f"v shape: {self.v.shape}")
+        logging.info(f"layer_id shape: {self.layer_id.shape}")
+        logging.info(f"mask shape: {self.mask.shape}")
+        logging.info(f"golden shape: {self.golden_out.shape}")
+
+        attention_out = np.zeros_like(self.golden_out.to(torch.float16))
+
+        for i in range(1):
+            self.execute([self.q_split1, self.q_split2, self.k_split1[0], self.k_split2[0], self.v[0],
+                          torch.tensor([], dtype=data_type),
+                          torch.tensor([], dtype=torch.float),
+                          torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.int32),
+                          torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.int32),
+                          torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.float)],
+                          [torch.tensor(attention_out, dtype=data_type)])
+
+
+    @op_test.only_910b
+    def test_flash_attention_mla_bf16_prefix_with_causal_mask(self):
+        batch = 1
+        kv_head = 1 # kv_head num
+        isdecoder = 0 # prefill or decoder
+        heads = 1 # llama7b hidden_size 4096
+        embeddim = 192
+        embeddimV = 128
+        max_seq = 1280
+        tor = 1.0 / math.sqrt(1.0 * embeddim)
+        dynamic_batch = False
+        prefix = 256
+        content = 129
+        q_seqLen = [content] * batch
+        kv_seqLen = [prefix + content] * batch
+        is_clamp = 0
+        clamp_min = 0
+        clamp_max = 0
+        OP_NAME = "MLAOperation"
+        OP_PARAM = {"type": 1, "qSeqLen":q_seqLen, "kvSeqLen": kv_seqLen, "headDimV": embeddimV,"headSize": heads, "tor": tor, "maskType": 5, "kvHead": heads}
+        self.set_param(OP_NAME, OP_PARAM)
+        self.set_input_formats([self.format_nd] * 13)
+        self.set_output_formats([self.format_nd])
+        data_type = torch.bfloat16
+
+        self.set_data_params(dynamic_batch = dynamic_batch, q_seqlens=q_seqLen,
+                             is_decoder = isdecoder, batch = batch, kv_head = kv_head, heads = heads,
+                             embeddim = embeddim,embeddimv = embeddimV, max_seq = max_seq, kv_seqLen = kv_seqLen,
+                             is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min,
+                             data_type = data_type, is_alibi = False, is_triu_mask = True,
+                             op_type = OP_PARAM["type"], mask_type = MASK_TYPE_NO_BATCH_WITH_PREFIX, tor = tor, long_seq = True)
+        self.gen_out_tensor()
+        self.mask = self.mask.view(512, 512).to(data_type)
+
+
+        logging.debug("**********input shape***********")
+        logging.info(f"q shape: {self.q.shape}")
+        logging.info(f"k shape: {self.k.shape}")
+        logging.info(f"v shape: {self.v.shape}")
+        logging.info(f"layer_id shape: {self.layer_id.shape}")
+        logging.info(f"mask shape: {self.mask.shape}")
+        logging.info(f"golden shape: {self.golden_out.shape}")
+
+        attention_out = np.zeros_like(self.golden_out.to(torch.float16))
+
+        for i in range(1):
+            self.execute([self.q_split1, self.q_split2, self.k_split1[0], self.k_split2[0], self.v[0],
+                          torch.tensor([], dtype=data_type),
+                          torch.tensor([], dtype=torch.float),
+                          torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.int32),
+                          torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.int32),
+                          torch.tensor([], dtype=torch.float), torch.tensor([], dtype=torch.float)],
+                          [torch.tensor(attention_out, dtype=data_type)])
+
+
 if __name__ == '__main__':
     unittest.main()
\ No newline at end of file