From 0f400107e4a91d737df703d77ddce154699a0073 Mon Sep 17 00:00:00 2001 From: Liexss Date: Mon, 21 Jul 2025 11:20:59 +0800 Subject: [PATCH 1/4] [feature]: support mla C8 lse --- .../multi_latent_attention/mla_kernel.cpp | 5 +- .../multi_latent_attention/op_kernel/mla.cce | 22 + .../op_kernel/mla_int8.cce | 955 +++++++++--------- .../tiling/mla_tiling.cpp | 4 +- .../test_paged_attention_mla_split_quant.py | 161 ++- 5 files changed, 657 insertions(+), 490 deletions(-) diff --git a/src/kernels/mixkernels/multi_latent_attention/mla_kernel.cpp b/src/kernels/mixkernels/multi_latent_attention/mla_kernel.cpp index ee0d14b7..143d2a2d 100644 --- a/src/kernels/mixkernels/multi_latent_attention/mla_kernel.cpp +++ b/src/kernels/mixkernels/multi_latent_attention/mla_kernel.cpp @@ -52,7 +52,10 @@ public: "paged attention: param type invalid", return false); auto param = AnyCast(launchParam.GetParam()); auto batch = param.kvSeqLen.size(); - if (param.headSize == M_LIMIT) { + auto maxQSeqlen = param.qSeqLen.data() != nullptr ? *std::max_element(param.qSeqLen.begin(), param.qSeqLen.end()) : 1; + auto mtpTp1Flag = ((param.headSize == M_LIMIT) || + (launchParam.GetInTensor(DIM_0).desc.dtype == TENSOR_DTYPE_INT8 && maxQSeqlen > 1)); + if (mtpTp1Flag) { uint64_t taskNum = param.qSeqLen.data() == nullptr ? batch : std::accumulate(param.qSeqLen.data(), param.qSeqLen.data() + batch, static_cast(0)); diff --git a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce index 92fcddcf..ee3d7c7b 100644 --- a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce +++ b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce @@ -4401,6 +4401,28 @@ extern "C" __global__ __aicore__ void mla( MLADecoderAiv pa_aiv {}; pa_aiv.SetArgs(sync, block_tables_gm, deq_qk_gm, deq_pv_gm, o_gm, s_gm, s_rope_out_gm, p_gm, o_tmp_gm, go_gm, tiling_para_gm, mask_gm); pa_aiv.RunTP1(); +#endif + } else if (TILING_KEY_IS(54)) { // INT8 fp16 FOUR_FLOW RING +#ifdef __DAV_C220_CUBE__ + MLAttentionDecoderAic pa_aic {}; + pa_aic.SetArgs(sync, q_gm, q_rope_gm, ctkv_gm, ctkv_rope_gm, block_tables_gm, o_gm, s_gm, s_rope_out_gm, p_gm, o_tmp_gm, tiling_para_gm); + pa_aic.RunTP1(); +#elif __DAV_C220_VEC__ + MLADecoderAiv pa_aiv {}; + pa_aiv.SetArgs(sync, block_tables_gm, deq_qk_gm, deq_pv_gm, o_gm, s_gm, s_rope_out_gm, p_gm, o_tmp_gm, go_gm, tiling_para_gm, mask_gm); + pa_aiv.SetArgs2(lse_gm); + pa_aiv.RunTP1(); +#endif + } else if (TILING_KEY_IS(55)) { // INT8 bf16 FOUR_FLOW RING +#ifdef __DAV_C220_CUBE__ + MLAttentionDecoderAic pa_aic {}; + pa_aic.SetArgs(sync, q_gm, q_rope_gm, ctkv_gm, ctkv_rope_gm, block_tables_gm, o_gm, s_gm, s_rope_out_gm, p_gm, o_tmp_gm, tiling_para_gm); + pa_aic.RunTP1(); +#elif __DAV_C220_VEC__ + MLADecoderAiv pa_aiv {}; + pa_aiv.SetArgs(sync, block_tables_gm, deq_qk_gm, deq_pv_gm, o_gm, s_gm, s_rope_out_gm, p_gm, o_tmp_gm, go_gm, tiling_para_gm, mask_gm); + pa_aiv.SetArgs2(lse_gm); + pa_aiv.RunTP1(); #endif } else if (TILING_KEY_IS(68)) { // fp16 + nd + FOUR_FLOW + flashdecoding 0b1000100 #ifdef __DAV_C220_CUBE__ diff --git a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_int8.cce b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_int8.cce index a21b696b..8d1d3ef9 100644 --- a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_int8.cce +++ b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_int8.cce @@ -9,8 +9,8 @@ */ constexpr int32_t CHUNK_BUFFER_SIZE = 2; -constexpr int32_t 
TMP_SIZE_DECODER_INT8 = SP_BLOCK_SIZE*K_BLOCK_NUM*CHUNK_BUFFER_SIZE; -constexpr int64_t P_TMP_SIZE_INT8 = 128*512*CHUNK_BUFFER_SIZE; +constexpr int32_t TMP_SIZE_DECODER_INT8 = SP_BLOCK_SIZE * K_BLOCK_NUM * CHUNK_BUFFER_SIZE; +constexpr int64_t P_TMP_SIZE_INT8 = 128 * 512 * CHUNK_BUFFER_SIZE; constexpr int32_t UB_FLOAT_LINE_SIZE_TP1 = 128; #ifdef __DAV_C220_CUBE__ @@ -22,11 +22,11 @@ constexpr uint32_t TP1_INT8_L1_K_ROPE_SIZE = 128 * 64 * 4; constexpr uint32_t TP1_INT8_L1_P_SIZE = 128 * 512 * 2; constexpr uint32_t TP1_INT8_L1_V_SIZE = 128 * 512 * 2; constexpr uint32_t TP1_INT8_Q_OFFSET = 0; -constexpr uint32_t TP1_INT8_Q_ROPE_OFFSET = TP1_INT8_Q_OFFSET+TP1_INT8_L1_Q_SIZE; -constexpr uint32_t TP1_INT8_K_OFFSET = TP1_INT8_Q_ROPE_OFFSET+TP1_INT8_L1_Q_ROPE_SIZE; -constexpr uint32_t TP1_INT8_K_ROPE_OFFSET = TP1_INT8_K_OFFSET+TP1_INT8_L1_K_SIZE; -constexpr uint32_t TP1_INT8_P_OFFSET = TP1_INT8_K_ROPE_OFFSET+TP1_INT8_L1_K_ROPE_SIZE; -constexpr uint32_t TP1_INT8_V_OFFSET = TP1_INT8_P_OFFSET+TP1_INT8_L1_P_SIZE; +constexpr uint32_t TP1_INT8_Q_ROPE_OFFSET = TP1_INT8_Q_OFFSET + TP1_INT8_L1_Q_SIZE; +constexpr uint32_t TP1_INT8_K_OFFSET = TP1_INT8_Q_ROPE_OFFSET + TP1_INT8_L1_Q_ROPE_SIZE; +constexpr uint32_t TP1_INT8_K_ROPE_OFFSET = TP1_INT8_K_OFFSET + TP1_INT8_L1_K_SIZE; +constexpr uint32_t TP1_INT8_P_OFFSET = TP1_INT8_K_ROPE_OFFSET + TP1_INT8_L1_K_ROPE_SIZE; +constexpr uint32_t TP1_INT8_V_OFFSET = TP1_INT8_P_OFFSET + TP1_INT8_L1_P_SIZE; constexpr int32_t FLOAT_VECTOR_SIZE = 64; template (cur_head_num); - uint32_t v_rows_mad = 256; - uint32_t v_rows_mad_round = RoundUp(v_rows_mad); - uint32_t v_row_loop_num = (512 + v_rows_mad - 1) / v_rows_mad; + uint32_t v_rows_mad = 256; + uint32_t v_rows_mad_round = RoundUp(v_rows_mad); + uint32_t v_row_loop_num = (512 + v_rows_mad - 1) / v_rows_mad; uint32_t l0_p_pingpong_offset = block_size * v_rows_mad; uint32_t l1_p_pingpong_offset = block_size * v_rows_mad; uint32_t p_cols_l1_2_l0 = v_rows_mad; uint32_t p_cols_l1_2_l0_round = RoundUp(p_cols_l1_2_l0); uint32_t ndNum_k_gm_2_l1 = 2; - uint32_t v_rows_one_loop = block_size * K_BLOCK_NUM; + uint32_t v_rows_one_loop = block_size * K_BLOCK_NUM; uint32_t pv_n_2 = pp_n_scalar * K_BLOCK_NUM; uint32_t pv_round_n_2 = RoundUp(pv_n_2); uint32_t pv_round_n_2_l1 = RoundUp(pv_n_2); @@ -901,383 +901,367 @@ private: uint32_t embed_split_size = 256; uint32_t rope_embed_split_size = 64; uint32_t round_embed_split_size = RoundUp(embed_split_size); - uint32_t embed_split_loop_v = (embed_size + round_embed_split_size-1)/round_embed_split_size; - uint32_t q_load_coeff = cur_head_num_round; - uint32_t l0a_data_size = round_embed_split_size*q_load_coeff*sizeof(int8_t); - uint32_t l0b_data_size = round_embed_split_size*pp_n_scalar*sizeof(int8_t); - uint32_t l0c_data_size = q_load_coeff*pp_n_scalar; - - for (uint32_t n_idx = 0; n_idx < n_loop + 1; n_idx+=1){ + uint32_t embed_split_loop_v = (embed_size + round_embed_split_size - 1) / round_embed_split_size; + uint32_t q_load_coeff = cur_head_num; + uint32_t q_load_coeff_round = cur_head_num_round; + uint32_t l0a_data_size = round_embed_split_size * 128 * sizeof(int8_t); + uint32_t l0b_data_size = round_embed_split_size * pp_n_scalar * sizeof(int8_t); + uint32_t l0c_data_size = 128 * pp_n_scalar; + + for (uint32_t n_idx = 0; n_idx < n_loop + 1; n_idx += 1) { if (n_idx != n_loop) { - uint32_t cur_k_block_num = (n_idx == (n_loop - 1)) ? 
n_block_num - n_idx*K_BLOCK_NUM: K_BLOCK_NUM; - uint32_t gm_pingpong_flag = n_idx%CHUNK_BUFFER_SIZE; - - uint32_t l0c_pingpong_flag = 0; - uint32_t l0ab_pingpong_flag = 0; - for (uint32_t sub_idx = 0; sub_idx < cur_k_block_num; sub_idx+=1) { - uint32_t l1_kv_rope_pingpong_flag = sub_idx % 2; - qk_n = left_kv_seqlen > pp_n_scalar ? pp_n_scalar : left_kv_seqlen; - qk_round_n = RoundUp(qk_n); - - uint64_t hiddenSize_offset = start_head * embedding_size; - - /* ************ CUBE1 stage1 ************* */ + uint32_t cur_k_block_num = (n_idx == (n_loop - 1)) ? n_block_num - n_idx*K_BLOCK_NUM: K_BLOCK_NUM; + uint32_t gm_pingpong_flag = n_idx % CHUNK_BUFFER_SIZE; + uint32_t l0c_pingpong_flag = 0; + uint32_t l0ab_pingpong_flag = 0; + for (uint32_t sub_idx = 0; sub_idx < cur_k_block_num; sub_idx += 1) { + uint32_t l1_kv_rope_pingpong_flag = sub_idx % 2; + qk_n = left_kv_seqlen > pp_n_scalar ? pp_n_scalar : left_kv_seqlen; + qk_round_n = RoundUp(qk_n); - uint32_t block_table_id = (uint32_t)(*(block_tables_gm + - cur_batch * max_num_blocks_per_query + start_kv / block_size + real_block_idx)); - int64_t kv_offset = (int64_t)block_table_id * block_size * stride_kv; - int64_t kv_offset_rope = (int64_t)block_table_id * block_size * stride_kv_rope; - - uint32_t embed_split_idx = 0; - for (embed_split_idx = 0; embed_split_idx < embed_split_loop_v; ++embed_split_idx) { - uint32_t l1_kv_pingpong_flag = embed_split_idx % 2; - l0ab_pingpong_flag = embed_split_idx % 2; - - WAIT_FLAG(MTE1, MTE2, l1_kv_pingpong_flag); - - gm_to_l1( - l1kv_buf_addr_tensor[l1_kv_pingpong_flag * 128 * 256], - k_gm_tensor[kv_offset + embed_split_idx * block_size * embed_split_size], - qk_round_n, // nValue - qk_round_n, // nValue - pp_n_scalar, // dstNzC0Stride - embed_split_size, // dValue // 256 - embed_split_size, // dValue // 256 - 0 // srcDValue - ); - SET_FLAG(MTE2, MTE1, l1_kv_pingpong_flag); - WAIT_FLAG(MTE2, MTE1, l1_kv_pingpong_flag); + uint64_t hiddenSize_offset = start_head * embedding_size; - WAIT_FLAG(M, MTE1, l0ab_pingpong_flag); - for (uint64_t loa_load_idx = 0; loa_load_idx < q_load_coeff / BLOCK_SIZE; ++loa_load_idx) { - l1_to_l0_a( - l0a_buf_tensor[l0ab_pingpong_flag * l0a_data_size + loa_load_idx * round_embed_split_size * BLOCK_SIZE], - l1q_buf_addr_tensor[embed_split_idx * cur_head_num_round * round_embed_split_size + loa_load_idx * T_CUBE_MATRIX_SIZE], - 0, - round_embed_split_size / T_BLOCK_SIZE, // repeat - 0, - q_load_coeff / BLOCK_SIZE, // srcStride - 0, - 0 // dstStride - ); - } + /* ************ CUBE1 stage1 ************* */ - SET_FLAG(MTE1, M, l0ab_pingpong_flag); - WAIT_FLAG(MTE1, M, l0ab_pingpong_flag); + uint32_t block_table_id = (uint32_t)(*(block_tables_gm + + cur_batch * max_num_blocks_per_query + start_kv / block_size + real_block_idx)); + int64_t kv_offset = (int64_t)block_table_id * block_size * stride_kv; + int64_t kv_offset_rope = (int64_t)block_table_id * block_size * stride_kv_rope; + + uint32_t embed_split_idx = 0; + for (embed_split_idx = 0; embed_split_idx < embed_split_loop_v; ++embed_split_idx) { + uint32_t l1_kv_pingpong_flag = embed_split_idx % 2; + l0ab_pingpong_flag = embed_split_idx % 2; + WAIT_FLAG(MTE1, MTE2, l1_kv_pingpong_flag); + gm_to_l1( + l1kv_buf_addr_tensor[l1_kv_pingpong_flag * 128 * 256], + k_gm_tensor[kv_offset + embed_split_idx * block_size * embed_split_size], + qk_round_n, // nValue + qk_round_n, // nValue + pp_n_scalar, // dstNzC0Stride + embed_split_size, // dValue // 256 + embed_split_size, // dValue // 256 + 0 // srcDValue + ); + SET_FLAG(MTE2, MTE1, 
l1_kv_pingpong_flag); + WAIT_FLAG(MTE2, MTE1, l1_kv_pingpong_flag); - WAIT_FLAG(M, MTE1, l0ab_pingpong_flag + 2); - l1_to_l0_b( - l0b_buf_tensor[l0ab_pingpong_flag * l0b_data_size], - l1kv_buf_addr_tensor[l1_kv_pingpong_flag * 128 * 256], + WAIT_FLAG(M, MTE1, l0ab_pingpong_flag); + for (uint64_t loa_load_idx = 0; loa_load_idx < q_load_coeff_round / BLOCK_SIZE; ++loa_load_idx) { + l1_to_l0_a( + l0a_buf_tensor[l0ab_pingpong_flag * l0a_data_size + loa_load_idx * round_embed_split_size * BLOCK_SIZE], + l1q_buf_addr_tensor[embed_split_idx * cur_head_num_round * round_embed_split_size + loa_load_idx * T_CUBE_MATRIX_SIZE], 0, - round_embed_split_size * qk_round_n / T_CUBE_MATRIX_SIZE, // repeat + round_embed_split_size / T_BLOCK_SIZE, // repeat 0, - 1, // srcStride + q_load_coeff_round / BLOCK_SIZE, // srcStride 0, - 0 // dstStride - ); - - SET_FLAG(MTE1, MTE2, l1_kv_pingpong_flag); - - SET_FLAG(MTE1, M, l0ab_pingpong_flag + 2); - WAIT_FLAG(MTE1, M, l0ab_pingpong_flag + 2); - - auto unit_flag = embed_split_idx == embed_split_loop_v-1 ? 0b11 : 0b10; - mmad( - mm1_l0c_buf_tensor[l0c_pingpong_flag * l0c_data_size], - l0a_buf_tensor[l0ab_pingpong_flag * l0a_data_size], - l0b_buf_tensor[l0ab_pingpong_flag * l0b_data_size], - m, // m - qk_n, // n - embed_split_size, // k - embed_split_idx == 0, // cmatrixInitVal - unit_flag + 0 // dstStride ); - - SET_FLAG(M, MTE1, l0ab_pingpong_flag); - SET_FLAG(M, MTE1, l0ab_pingpong_flag + 2); - // copy S to gm - - if (embed_split_idx == embed_split_loop_v-1) { - l0c_to_gm( - s_gm_tensor[(uint64_t)block_idx * TMP_SIZE_DECODER_INT8 + gm_pingpong_flag*SP_BLOCK_SIZE*K_BLOCK_NUM +sub_idx*pp_n_scalar], - mm1_l0c_buf_tensor[l0c_pingpong_flag * l0c_data_size], - m, // MSize - qk_round_n, // NSize - RoundUp<16>(m), // srcStride - pp_n_scalar*K_BLOCK_NUM, // dstStride_dst_D - 0b11 - ); - l0c_pingpong_flag = (l0c_pingpong_flag+1)%2; - } } - l0c_pingpong_flag = (l0c_pingpong_flag+1)%2; - left_kv_seqlen -= pp_n_scalar; - real_block_idx += 1; - } - - - WAIT_FLAG(M, MTE1, EVENT_ID0); - for (uint64_t loa_load_idx = 0; loa_load_idx < q_load_coeff / BLOCK_SIZE; ++loa_load_idx) { - l1_to_l0_a( - l0a_buf_tensor.template ReinterpretCast()[loa_load_idx * rope_embed_split_size * BLOCK_SIZE], - l1q_rope_buf_addr_tensor[loa_load_idx * CUBE_MATRIX_SIZE], - 0, - rope_embed_split_size / BLOCK_SIZE, // repeat - 0, - q_load_coeff / BLOCK_SIZE, // srcStride - 0, - 0 // dstStride - ); - } - SET_FLAG(MTE1, M, EVENT_ID0); - WAIT_FLAG(MTE1, M, EVENT_ID0); //L0A - - for(uint32_t sub_idx = 0; sub_idx < cur_k_block_num; sub_idx+=1){ - l0ab_pingpong_flag = sub_idx % 2; - uint32_t l1_kv_rope_pingpong_flag = sub_idx % 2; - - qk_n_rope = left_kv_seqlen_rope > pp_n_scalar ? 
pp_n_scalar : left_kv_seqlen_rope; - qk_round_n_rope = RoundUp(qk_n_rope); - - - uint32_t block_table_id = (uint32_t)(*(block_tables_gm + - cur_batch * max_num_blocks_per_query + start_kv / block_size + real_block_idx_rope)); - int64_t kv_offset_rope = (int64_t)block_table_id * block_size * stride_kv_rope; - - WAIT_FLAG(MTE1, MTE2, l1_kv_rope_pingpong_flag + 2); - gm_to_l1( - l1kv_rope_buf_addr_tensor[l1_kv_rope_pingpong_flag * 128 * 64], - k_rope_gm_tensor[kv_offset_rope], - qk_round_n_rope, // nValue - qk_round_n_rope, // nValue - pp_n_scalar, // dstNzC0Stride - rope_size, // dValue - rope_size, // dValue - 0 // srcDValue - ); - - SET_FLAG(MTE2, MTE1, l1_kv_rope_pingpong_flag + 2); - WAIT_FLAG(MTE2, MTE1, l1_kv_rope_pingpong_flag + 2); - + SET_FLAG(MTE1, M, l0ab_pingpong_flag); + WAIT_FLAG(MTE1, M, l0ab_pingpong_flag); WAIT_FLAG(M, MTE1, l0ab_pingpong_flag + 2); - l1_to_l0_b( - l0b_buf_tensor.template ReinterpretCast()[l0ab_pingpong_flag * (l0b_data_size / sizeof(IN_ROPE_DTYPE))], - l1kv_rope_buf_addr_tensor[l1_kv_rope_pingpong_flag * 128 * 64], + l1_to_l0_b( + l0b_buf_tensor[l0ab_pingpong_flag * l0b_data_size], + l1kv_buf_addr_tensor[l1_kv_pingpong_flag * 128 * 256], 0, - rope_embed_split_size * qk_round_n_rope / CUBE_MATRIX_SIZE, // repeat + round_embed_split_size * qk_round_n / T_CUBE_MATRIX_SIZE, // repeat 0, 1, // srcStride 0, 0 // dstStride ); - SET_FLAG(MTE1, MTE2, l1_kv_rope_pingpong_flag + 2); - SET_FLAG(MTE1, M, l0ab_pingpong_flag + 2); - WAIT_FLAG(MTE1, M, l0ab_pingpong_flag + 2);//L0B + + SET_FLAG(MTE1, MTE2, l1_kv_pingpong_flag); - mmad( - mm1_l0c_buf_tensor.template ReinterpretCast()[l0c_pingpong_flag * l0c_data_size], - l0a_buf_tensor.template ReinterpretCast()[0], - l0b_buf_tensor.template ReinterpretCast()[l0ab_pingpong_flag * (l0b_data_size / sizeof(IN_ROPE_DTYPE))], - m, // m - qk_n_rope, // n - rope_embed_split_size, // k - 1, // cmatrixInitVal - 0b11 + SET_FLAG(MTE1, M, l0ab_pingpong_flag + 2); + WAIT_FLAG(MTE1, M, l0ab_pingpong_flag + 2); + + auto unit_flag = embed_split_idx == embed_split_loop_v - 1 ? 
0b11 : 0b10; + mmad( + mm1_l0c_buf_tensor[l0c_pingpong_flag * l0c_data_size], + l0a_buf_tensor[l0ab_pingpong_flag * l0a_data_size], + l0b_buf_tensor[l0ab_pingpong_flag * l0b_data_size], + q_load_coeff, // m + qk_n, // n + embed_split_size, // k + embed_split_idx == 0, // cmatrixInitVal + unit_flag ); - + PIPE_BARRIER(M); + SET_FLAG(M, MTE1, l0ab_pingpong_flag); SET_FLAG(M, MTE1, l0ab_pingpong_flag + 2); - - l0c_to_gm( - s_rope_gm_tensor[(uint64_t)block_idx * TMP_SIZE_DECODER_INT8 + gm_pingpong_flag*SP_BLOCK_SIZE*K_BLOCK_NUM +sub_idx*pp_n_scalar], - mm1_l0c_buf_tensor.template ReinterpretCast()[l0c_pingpong_flag * l0c_data_size], - m, // MSize - qk_round_n_rope, // NSize - RoundUp<16>(m), // srcStride - pp_n_scalar*K_BLOCK_NUM, // dstStride_dst_D - 0b11 - ); - - real_block_idx_rope += 1; - left_kv_seqlen_rope -= pp_n_scalar; + // copy S to gm + if (embed_split_idx == embed_split_loop_v - 1) { + l0c_to_gm( + s_gm_tensor[(uint64_t)block_idx * TMP_SIZE_DECODER_INT8 + gm_pingpong_flag*SP_BLOCK_SIZE*K_BLOCK_NUM +sub_idx*pp_n_scalar], + mm1_l0c_buf_tensor[l0c_pingpong_flag * l0c_data_size], + q_load_coeff, // MSize + qk_round_n, // NSize + RoundUp<16>(q_load_coeff), // srcStride + pp_n_scalar*K_BLOCK_NUM, // dstStride_dst_D + 0b11 + ); + l0c_pingpong_flag = (l0c_pingpong_flag+1)%2; + } } - SET_FLAG(M, MTE1, EVENT_ID0); + l0c_pingpong_flag = (l0c_pingpong_flag+1) % 2; + left_kv_seqlen -= pp_n_scalar; + real_block_idx += 1; + } + WAIT_FLAG(M, MTE1, EVENT_ID0); + for (uint64_t loa_load_idx = 0; loa_load_idx < q_load_coeff_round / BLOCK_SIZE; ++loa_load_idx) { + l1_to_l0_a( + l0a_buf_tensor.template ReinterpretCast()[loa_load_idx * rope_embed_split_size * BLOCK_SIZE], + l1q_rope_buf_addr_tensor[loa_load_idx * CUBE_MATRIX_SIZE], + 0, + rope_embed_split_size / BLOCK_SIZE, // repeat + 0, + q_load_coeff_round / BLOCK_SIZE, // srcStride + 0, + 0 // dstStride + ); + } + SET_FLAG(MTE1, M, EVENT_ID0); + WAIT_FLAG(MTE1, M, EVENT_ID0); //L0A + + for(uint32_t sub_idx = 0; sub_idx < cur_k_block_num; sub_idx+=1){ + l0ab_pingpong_flag = sub_idx % 2; + uint32_t l1_kv_rope_pingpong_flag = sub_idx % 2; + qk_n_rope = left_kv_seqlen_rope > pp_n_scalar ? 
pp_n_scalar : left_kv_seqlen_rope; + qk_round_n_rope = RoundUp(qk_n_rope); + uint32_t block_table_id = (uint32_t)(*(block_tables_gm + + cur_batch * max_num_blocks_per_query + start_kv / block_size + real_block_idx_rope)); + int64_t kv_offset_rope = (int64_t)block_table_id * block_size * stride_kv_rope; + + WAIT_FLAG(MTE1, MTE2, l1_kv_rope_pingpong_flag + 2); + gm_to_l1( + l1kv_rope_buf_addr_tensor[l1_kv_rope_pingpong_flag * 128 * 64], + k_rope_gm_tensor[kv_offset_rope], + qk_round_n_rope, // nValue + qk_round_n_rope, // nValue + pp_n_scalar, // dstNzC0Stride + rope_size, // dValue + rope_size, // dValue + 0 // srcDValue + ); - FftsCrossCoreSync(QK_READY_DECODER); + SET_FLAG(MTE2, MTE1, l1_kv_rope_pingpong_flag + 2); + WAIT_FLAG(MTE2, MTE1, l1_kv_rope_pingpong_flag + 2); + + WAIT_FLAG(M, MTE1, l0ab_pingpong_flag + 2); + l1_to_l0_b( + l0b_buf_tensor.template ReinterpretCast()[l0ab_pingpong_flag * (l0b_data_size / sizeof(IN_ROPE_DTYPE))], + l1kv_rope_buf_addr_tensor[l1_kv_rope_pingpong_flag * 128 * 64], + 0, + rope_embed_split_size * qk_round_n_rope / CUBE_MATRIX_SIZE, // repeat + 0, + 1, // srcStride + 0, + 0 // dstStride + ); + SET_FLAG(MTE1, MTE2, l1_kv_rope_pingpong_flag + 2); + SET_FLAG(MTE1, M, l0ab_pingpong_flag + 2); + WAIT_FLAG(MTE1, M, l0ab_pingpong_flag + 2); + mmad( + mm1_l0c_buf_tensor.template ReinterpretCast()[l0c_pingpong_flag * l0c_data_size], + l0a_buf_tensor.template ReinterpretCast()[0], + l0b_buf_tensor.template ReinterpretCast()[l0ab_pingpong_flag * (l0b_data_size / sizeof(IN_ROPE_DTYPE))], + q_load_coeff, // m + qk_n_rope, // n + rope_embed_split_size, // k + 1, // cmatrixInitVal + 0b11 + ); + PIPE_BARRIER(M); + SET_FLAG(M, MTE1, l0ab_pingpong_flag + 2); + l0c_to_gm( + s_rope_gm_tensor[(uint64_t)block_idx * TMP_SIZE_DECODER_INT8 + gm_pingpong_flag*SP_BLOCK_SIZE*K_BLOCK_NUM +sub_idx*pp_n_scalar], + mm1_l0c_buf_tensor.template ReinterpretCast()[l0c_pingpong_flag * l0c_data_size], + q_load_coeff, // MSize + qk_round_n_rope, // NSize + RoundUp<16>(q_load_coeff), // srcStride + pp_n_scalar*K_BLOCK_NUM, // dstStride_dst_D + 0b11 + ); + real_block_idx_rope += 1; + left_kv_seqlen_rope -= pp_n_scalar; } - /* ************ CUBE2 stage1 ************* */ + SET_FLAG(M, MTE1, EVENT_ID0); + + FftsCrossCoreSync(QK_READY_DECODER); + } + /* ************ CUBE2 stage1 ************* */ if (n_idx != 0) { - if (n_idx == n_loop) { - pv_n_2 = (cur_kv_seqlen - (n_idx - 1) * v_rows_one_loop); - pv_round_n_2 = RoundUp(pv_n_2); - pv_round_n_2_l1 = RoundUp(pv_n_2); - v_row_loop_num = (pv_n_2 + v_rows_mad - 1) / v_rows_mad; - } - l1p_pingpong_flag = (n_idx-1)%2; - - WaitFlagDev(SOFTMAX_READY_DECODER); - WAIT_FLAG(MTE1, MTE2, l1p_pingpong_flag+4); - gm_to_l1( - l1p_buf_addr_tensor[l1p_pingpong_flag * P_TMP_SIZE_INT8 / CHUNK_BUFFER_SIZE], - p_gm_tensor[(uint64_t)block_idx * P_TMP_SIZE_INT8 + ((n_idx - 1) % CHUNK_BUFFER_SIZE)* P_TMP_SIZE_INT8 / CHUNK_BUFFER_SIZE], - cur_head_num, // nValue - RoundUp(cur_head_num),// dstNzC0Stride - 0, // dstNzMatrixStride, unused - pv_round_n_2, - 0, // dstNzMatrixStride, unused - pp_n_scalar*K_BLOCK_NUM // srcDValue - ); - SET_FLAG(MTE2, MTE1, l1p_pingpong_flag+4); + if (n_idx == n_loop) { + pv_n_2 = (cur_kv_seqlen - (n_idx - 1) * v_rows_one_loop); + pv_round_n_2 = RoundUp(pv_n_2); + pv_round_n_2_l1 = RoundUp(pv_n_2); + v_row_loop_num = (pv_n_2 + v_rows_mad - 1) / v_rows_mad; + } + l1p_pingpong_flag = (n_idx-1)%2; + + WaitFlagDev(SOFTMAX_READY_DECODER); + WAIT_FLAG(MTE1, MTE2, l1p_pingpong_flag + 4); + gm_to_l1( + l1p_buf_addr_tensor[l1p_pingpong_flag * P_TMP_SIZE_INT8 
/ CHUNK_BUFFER_SIZE], + p_gm_tensor[(uint64_t)block_idx * P_TMP_SIZE_INT8 + ((n_idx - 1) % CHUNK_BUFFER_SIZE) * P_TMP_SIZE_INT8 / CHUNK_BUFFER_SIZE], + cur_head_num, // nValue + RoundUp(cur_head_num),// dstNzC0Stride + 0, // dstNzMatrixStride, unused + pv_round_n_2, + 0, // dstNzMatrixStride, unused + pp_n_scalar * K_BLOCK_NUM // srcDValue + ); + SET_FLAG(MTE2, MTE1, l1p_pingpong_flag + 4); + uint32_t block_table_id = (uint32_t)(*(block_tables_gm + + cur_batch * max_num_blocks_per_query + start_kv / block_size + (n_idx - 1) * K_BLOCK_NUM)); + uint32_t embed_split_size = 128; + uint32_t embed_split_loop_v = 4; + uint32_t round_embed_split_size = RoundUp(embed_split_size); + for (uint32_t embed_split_idx = 0; embed_split_idx < embed_split_loop_v; ++embed_split_idx) { + uint32_t l0c_o_pingpong_flag = embed_split_idx % 2; + for(uint32_t v_row_loop_idx = 0; v_row_loop_idx < v_row_loop_num; v_row_loop_idx++){ + if (n_idx == n_loop) { + p_cols_l1_2_l0 = v_row_loop_idx == v_row_loop_num - 1 ? pv_n_2 - v_row_loop_idx * v_rows_mad : v_rows_mad; + p_cols_l1_2_l0_round = RoundUp(p_cols_l1_2_l0); + } - uint32_t block_table_id = (uint32_t)(*(block_tables_gm + - cur_batch * max_num_blocks_per_query + start_kv / block_size + (n_idx-1) * K_BLOCK_NUM)); - - uint32_t embed_split_size = 128; - uint32_t embed_split_loop_v = 4; - uint32_t round_embed_split_size = RoundUp(embed_split_size); - for (uint32_t embed_split_idx = 0; embed_split_idx < embed_split_loop_v; ++embed_split_idx) { - uint32_t l0c_o_pingpong_flag = embed_split_idx % 2; - for(uint32_t v_row_loop_idx = 0; v_row_loop_idx < v_row_loop_num; v_row_loop_idx++){ - if (n_idx == n_loop) { - p_cols_l1_2_l0 = v_row_loop_idx == v_row_loop_num - 1 ? pv_n_2 - v_row_loop_idx * v_rows_mad : v_rows_mad; - p_cols_l1_2_l0_round = RoundUp(p_cols_l1_2_l0); - } + uint32_t combine_loop_idx = embed_split_idx * v_row_loop_num + v_row_loop_idx; + uint32_t l1_right_v_pingpong_flag = combine_loop_idx % 2; + uint32_t l0_right_v_pingpong_flag = combine_loop_idx % 2; + uint32_t l0_left_p_pingpong_flag = v_row_loop_idx % 2; + uint32_t l1_left_p_pingpong_flag = v_row_loop_idx % 2; + + // move p from l1 to l0a + WAIT_FLAG(M, MTE1, l0_left_p_pingpong_flag); - uint32_t combine_loop_idx = embed_split_idx * v_row_loop_num + v_row_loop_idx; - uint32_t l1_right_v_pingpong_flag = combine_loop_idx % 2; - uint32_t l0_right_v_pingpong_flag = combine_loop_idx % 2; - uint32_t l0_left_p_pingpong_flag = v_row_loop_idx % 2; - uint32_t l1_left_p_pingpong_flag = v_row_loop_idx % 2; - - // move p from l1 to l0a - WAIT_FLAG(M, MTE1, l0_left_p_pingpong_flag); - - if(embed_split_idx == 0 || v_row_loop_idx >= 2){ - if(v_row_loop_idx == 0) { - WAIT_FLAG(MTE2, MTE1, l1p_pingpong_flag+4); - } - l1_to_l0_a( - l0a_buf_tensor[l0_left_p_pingpong_flag * l0_p_pingpong_offset], - l1p_buf_addr_tensor[l1p_pingpong_flag * P_TMP_SIZE_INT8 / CHUNK_BUFFER_SIZE + l1_left_p_pingpong_flag * l1_p_pingpong_offset], - RoundUp(cur_head_num), - p_cols_l1_2_l0_round, // repeat ori:128 -> new: 128*2 - 0, - 0, // srcStride - 0, - 0 // dstStride - ); - SET_FLAG(MTE1, M, l1p_pingpong_flag + 4); - WAIT_FLAG(MTE1, M, l1p_pingpong_flag + 4); - if(v_row_loop_idx == v_row_loop_num - 1) { - SET_FLAG(MTE1, MTE2, l1p_pingpong_flag + 4); - } - SET_FLAG(MTE1, MTE2, l0_left_p_pingpong_flag + 6); - WAIT_FLAG(MTE1, MTE2, l0_left_p_pingpong_flag + 6); + if(embed_split_idx == 0 || v_row_loop_idx >= 2){ + if(v_row_loop_idx == 0) { + WAIT_FLAG(MTE2, MTE1, l1p_pingpong_flag+4); } - - WAIT_FLAG(MTE1, MTE2, l1_right_v_pingpong_flag); - - 
block_table_id = (uint32_t)(*(block_tables_gm + - cur_batch * max_num_blocks_per_query + start_kv / block_size + (n_idx-1) * K_BLOCK_NUM + v_row_loop_idx * 2)); - int64_t kv_offset = (int64_t)block_table_id * block_size * stride_kv; + l1_to_l0_a( + l0a_buf_tensor[l0_left_p_pingpong_flag * l0_p_pingpong_offset], + l1p_buf_addr_tensor[l1p_pingpong_flag * P_TMP_SIZE_INT8 / CHUNK_BUFFER_SIZE + l1_left_p_pingpong_flag * q_load_coeff_round * 256], + RoundUp(cur_head_num), + p_cols_l1_2_l0_round, // repeat ori:128 -> new: 128*2 + 0, + 0, // srcStride + 0, + 0 // dstStride + ); + SET_FLAG(MTE1, M, l1p_pingpong_flag + 4); + WAIT_FLAG(MTE1, M, l1p_pingpong_flag + 4); + if(v_row_loop_idx == v_row_loop_num - 1) { + SET_FLAG(MTE1, MTE2, l1p_pingpong_flag + 4); + } + SET_FLAG(MTE1, MTE2, l0_left_p_pingpong_flag + 6); + WAIT_FLAG(MTE1, MTE2, l0_left_p_pingpong_flag + 6); + } + + WAIT_FLAG(MTE1, MTE2, l1_right_v_pingpong_flag); + + block_table_id = (uint32_t)(*(block_tables_gm + + cur_batch * max_num_blocks_per_query + start_kv / block_size + (n_idx-1) * K_BLOCK_NUM + v_row_loop_idx * 2)); + int64_t kv_offset = (int64_t)block_table_id * block_size * stride_kv; - uint32_t rows_v_gm_to_l1_one_loop = p_cols_l1_2_l0 > 128 ? 128 : p_cols_l1_2_l0; - uint32_t rows_v_gm_to_l1_one_loop_round = RoundUp(rows_v_gm_to_l1_one_loop); - uint32_t rows_v_gm_to_l1_one_loop2 = p_cols_l1_2_l0 - 128; - uint32_t rows_v_gm_to_l1_one_loop_round2 = RoundUp(rows_v_gm_to_l1_one_loop2); + uint32_t rows_v_gm_to_l1_one_loop = p_cols_l1_2_l0 > 128 ? 128 : p_cols_l1_2_l0; + uint32_t rows_v_gm_to_l1_one_loop_round = RoundUp(rows_v_gm_to_l1_one_loop); + uint32_t rows_v_gm_to_l1_one_loop2 = p_cols_l1_2_l0 - 128; + uint32_t rows_v_gm_to_l1_one_loop_round2 = RoundUp(rows_v_gm_to_l1_one_loop2); + gm_to_l1( + l1v_buf_addr_tensor[l1_right_v_pingpong_flag * block_size * v_rows_mad], + k_gm_tensor[kv_offset + embed_split_idx * block_size * block_size], + rows_v_gm_to_l1_one_loop_round, // nValue + rows_v_gm_to_l1_one_loop_round, // nValue + embed_split_size, // dstNzC0Stride + embed_split_size, // dValue + embed_split_size, // dValue + 0 // srcDValue + ); + if (p_cols_l1_2_l0 > 128) { + block_table_id = (uint32_t)(*(block_tables_gm + + cur_batch * max_num_blocks_per_query + start_kv / block_size + (n_idx - 1) * K_BLOCK_NUM + v_row_loop_idx * 2 + 1)); + int64_t kv_offset2 = (int64_t)block_table_id * block_size * stride_kv; gm_to_l1( - l1v_buf_addr_tensor[l1_right_v_pingpong_flag * block_size * v_rows_mad], - k_gm_tensor[kv_offset + embed_split_idx * block_size * block_size], - rows_v_gm_to_l1_one_loop_round, // nValue - rows_v_gm_to_l1_one_loop_round, // nValue - embed_split_size, // dstNzC0Stride - embed_split_size, // dValue - embed_split_size, // dValue - 0 // srcDValue - ); - if (p_cols_l1_2_l0 > 128) { - block_table_id = (uint32_t)(*(block_tables_gm + - cur_batch * max_num_blocks_per_query + start_kv / block_size + (n_idx-1) * K_BLOCK_NUM + v_row_loop_idx * 2 + 1)); - int64_t kv_offset2 = (int64_t)block_table_id * block_size * stride_kv; - gm_to_l1( l1v_buf_addr_tensor[l1_right_v_pingpong_flag * block_size * v_rows_mad + rows_v_gm_to_l1_one_loop_round * block_size], k_gm_tensor[kv_offset2 + embed_split_idx * block_size * block_size], rows_v_gm_to_l1_one_loop_round2, // nValue - rows_v_gm_to_l1_one_loop_round2, // nValue + rows_v_gm_to_l1_one_loop_round2, // nValue embed_split_size, // dstNzC0Stride embed_split_size, // dValue embed_split_size, // dValue 0 // srcDValue - ); - } - - SET_FLAG(MTE2, MTE1, l1_right_v_pingpong_flag + 6); - 
WAIT_FLAG(MTE2, MTE1, l1_right_v_pingpong_flag + 6); - - AscendC::LoadData2dTransposeParams loadDataParams; - loadDataParams.dstGap = 0; - loadDataParams.startIndex = 0; - loadDataParams.dstFracGap = 0; - if (block_size <= round_embed_split_size) { - loadDataParams.repeatTimes = round_embed_split_size / T_BLOCK_SIZE; - loadDataParams.srcStride = rows_v_gm_to_l1_one_loop_round / T_BLOCK_SIZE; - uint16_t dstGap = sizeof(int8_t) == 1 ? 1 : 0; - loadDataParams.dstGap = dstGap; - uint32_t copy_loops1 = (rows_v_gm_to_l1_one_loop + T_BLOCK_SIZE - 1) / T_BLOCK_SIZE; - WAIT_FLAG(M, MTE1, l0_right_v_pingpong_flag + 2); - for (uint32_t l0b_load_idx = 0; l0b_load_idx < copy_loops1; ++l0b_load_idx) { + ); + } + SET_FLAG(MTE2, MTE1, l1_right_v_pingpong_flag + 6); + WAIT_FLAG(MTE2, MTE1, l1_right_v_pingpong_flag + 6); + AscendC::LoadData2dTransposeParams loadDataParams; + loadDataParams.dstGap = 0; + loadDataParams.startIndex = 0; + loadDataParams.dstFracGap = 0; + if (block_size <= round_embed_split_size) { + loadDataParams.repeatTimes = round_embed_split_size / T_BLOCK_SIZE; + loadDataParams.srcStride = rows_v_gm_to_l1_one_loop_round / T_BLOCK_SIZE; + uint16_t dstGap = sizeof(int8_t) == 1 ? 1 : 0; + loadDataParams.dstGap = dstGap; + uint32_t copy_loops1 = (rows_v_gm_to_l1_one_loop + T_BLOCK_SIZE - 1) / T_BLOCK_SIZE; + WAIT_FLAG(M, MTE1, l0_right_v_pingpong_flag + 2); + for (uint32_t l0b_load_idx = 0; l0b_load_idx < copy_loops1; ++l0b_load_idx) { + AscendC::LoadDataWithTranspose( + l0b_buf_tensor[l0_right_v_pingpong_flag * 16384 * 2 + l0b_load_idx * RoundUp<16>(embed_split_size) * T_BLOCK_SIZE], + l1v_buf_addr_tensor[l1_right_v_pingpong_flag * 128 * 256 + l0b_load_idx * T_BLOCK_SIZE * T_BLOCK_SIZE], + loadDataParams); + } + uint32_t copy_loops2 = (rows_v_gm_to_l1_one_loop2 + T_BLOCK_SIZE - 1) / T_BLOCK_SIZE; + loadDataParams.srcStride = rows_v_gm_to_l1_one_loop_round2 / T_BLOCK_SIZE; + if (p_cols_l1_2_l0 > 128) { + for (uint32_t l0b_load_idx = 0; l0b_load_idx < copy_loops2; ++l0b_load_idx) { AscendC::LoadDataWithTranspose( - l0b_buf_tensor[l0_right_v_pingpong_flag * 16384 * 2 + l0b_load_idx * RoundUp<16>(embed_split_size) * T_BLOCK_SIZE], - l1v_buf_addr_tensor[l1_right_v_pingpong_flag * 128 * 256 + l0b_load_idx * T_BLOCK_SIZE * T_BLOCK_SIZE], + l0b_buf_tensor[l0_right_v_pingpong_flag * 16384 * 2 + l0b_load_idx * RoundUp<16>(embed_split_size) * T_BLOCK_SIZE + block_size * block_size], + l1v_buf_addr_tensor[l1_right_v_pingpong_flag * 128 * 256 + l0b_load_idx * T_BLOCK_SIZE * T_BLOCK_SIZE + block_size * block_size], loadDataParams); - } - uint32_t copy_loops2 = (rows_v_gm_to_l1_one_loop2 + T_BLOCK_SIZE - 1) / T_BLOCK_SIZE; - loadDataParams.srcStride = rows_v_gm_to_l1_one_loop_round2 / T_BLOCK_SIZE; - if (p_cols_l1_2_l0 > 128) { - for (uint32_t l0b_load_idx = 0; l0b_load_idx < copy_loops2; ++l0b_load_idx) { - AscendC::LoadDataWithTranspose( - l0b_buf_tensor[l0_right_v_pingpong_flag * 16384 * 2 + l0b_load_idx * RoundUp<16>(embed_split_size) * T_BLOCK_SIZE + block_size * block_size], - l1v_buf_addr_tensor[l1_right_v_pingpong_flag * 128 * 256 + l0b_load_idx * T_BLOCK_SIZE * T_BLOCK_SIZE + block_size * block_size], - loadDataParams); - } - } - - SET_FLAG(MTE1, MTE2, l1_right_v_pingpong_flag); - SET_FLAG(MTE1, M, l1_right_v_pingpong_flag); - WAIT_FLAG(MTE1, M, l1_right_v_pingpong_flag); - - bool init_c = v_row_loop_idx == 0 ? true : false; - auto unit_flag = v_row_loop_idx == v_row_loop_num - 1 ? 
0b11 : 0b10; - mmad( + } + } + + SET_FLAG(MTE1, MTE2, l1_right_v_pingpong_flag); + SET_FLAG(MTE1, M, l1_right_v_pingpong_flag); + WAIT_FLAG(MTE1, M, l1_right_v_pingpong_flag); + + bool init_c = v_row_loop_idx == 0 ? true : false; + auto unit_flag = v_row_loop_idx == v_row_loop_num - 1 ? 0b11 : 0b10; + mmad( + mm2_l0c_buf_tensor[l0c_o_pingpong_flag * 16384], + l0a_buf_tensor[l0_left_p_pingpong_flag * 16384 * 2], + l0b_buf_tensor[l0_right_v_pingpong_flag * 16384 * 2], + q_load_coeff, // m + embed_split_size, // n + p_cols_l1_2_l0, // k + init_c, // cmatrixInitVal + unit_flag + ); + PIPE_BARRIER(M); + SET_FLAG(M, MTE1, l0_left_p_pingpong_flag); + SET_FLAG(M, MTE1, l0_right_v_pingpong_flag + 2); + + if(v_row_loop_idx == v_row_loop_num - 1){ + // copy O to gm + l0c_to_gm( + o_tmp_gm_tensor[(uint64_t)block_idx * TMP_SIZE * CHUNK_BUFFER_SIZE + embed_split_idx * round_embed_split_size + ((n_idx - 1)%CHUNK_BUFFER_SIZE) * TMP_SIZE], mm2_l0c_buf_tensor[l0c_o_pingpong_flag * 16384], - l0a_buf_tensor[l0_left_p_pingpong_flag * 16384 * 2], - l0b_buf_tensor[l0_right_v_pingpong_flag * 16384 * 2], - m, // m - embed_split_size, // n - p_cols_l1_2_l0, // k - init_c, // cmatrixInitVal - unit_flag - ); - SET_FLAG(M, MTE1, l0_left_p_pingpong_flag); - SET_FLAG(M, MTE1, l0_right_v_pingpong_flag + 2); - - if(v_row_loop_idx == v_row_loop_num - 1){ - // copy O to gm - l0c_to_gm( - o_tmp_gm_tensor[(uint64_t)block_idx * TMP_SIZE * CHUNK_BUFFER_SIZE + embed_split_idx * round_embed_split_size + ((n_idx - 1)%CHUNK_BUFFER_SIZE) * TMP_SIZE], - mm2_l0c_buf_tensor[l0c_o_pingpong_flag * 16384], - m, // MSize - RoundUp<16>(embed_split_size), // NSize - RoundUp<16>(m), // srcStride - round_v, // dstStride_dst_D - 0b11 - ); - } + q_load_coeff, // MSize + RoundUp<16>(embed_split_size), // NSize + RoundUp<16>(q_load_coeff), // srcStride + round_v, // dstStride_dst_D + 0b11 + ); } } } - FftsCrossCoreSync(UPDATE_READY_DECODER); } + FftsCrossCoreSync(UPDATE_READY_DECODER); + } } } @@ -1309,9 +1293,9 @@ private: AscendC::GlobalTensor block_tables_gm_tensor; const uint32_t l1q_buf_addr_offset = 0; - const uint32_t l1q_rope_buf_addr_offset = 128*512*2; - const uint32_t l1kv_buf_addr_offset = 128*576*2; - const uint32_t l1kv_rope_buf_addr_offset= 128 * 576 * 2 + 128 * 512 * 2; + const uint32_t l1q_rope_buf_addr_offset = 128 * 512 * 2; + const uint32_t l1kv_buf_addr_offset = 128 * 576 * 2; + const uint32_t l1kv_rope_buf_addr_offset = 128 * 576 * 2 + 128 * 512 * 2; const uint32_t l1p_buf_addr_offset = 128 * 576 * 6; AsdopsBuffer buf; @@ -2856,26 +2840,25 @@ __aicore__ __attribute__((always_inline)) inline void OnlineSoftmaxStage1Step1Qu 8 // srcRepeatStride ); PIPE_BARRIER(V); - if(n_real%64) - { + if (n_real % 64) { uint64_t n_stride_2 = n_real; - if(n_real%8) - n_stride_2= n_real- n_real % 8; + if(n_real % 8) + n_stride_2 = n_real- n_real % 8; uint64_t need_dup = n_stride - n_real; // n_stride = qk_round_n = RoundUp<64>(qk_n); float scalar = -1e30; uint64_t mask64 = 0; uint64_t range_mask = (1ULL << (need_dup )) - 1; - mask64 = range_mask << (RoundUp<8>(need_dup)-need_dup); + mask64 = range_mask << (RoundUp<8>(need_dup) - need_dup); uint64_t mask[2] = { mask64, 0 }; - AscendC::Duplicate(s_ub[n_stride_2], scalar, mask, \ + AscendC::Duplicate(s_ub[n_stride_2], scalar, mask, round_m, //repeat 1, //dstBlockStride - n_stride/8 ); //dstRepeatStride; + n_stride / 8); //dstRepeatStride; PIPE_BARRIER(V); SetVectorMask((uint64_t)-1, (uint64_t)-1); - } + } RowmaxQuant( s_ub, @@ -3095,7 +3078,7 @@ template s_exp_ub, local_rowsum_ub, tmp_ub, - 
first_n_iter,tor,m,n_real,n_stride + first_n_iter, tor, m, n_real, n_stride ); //quant @@ -3189,7 +3172,7 @@ template uint32_t head_res_row_num, uint32_t head_start_sblock_idx, uint32_t tail_res_row_num - ) + ) { uint32_t sub_m_d64 = (sub_m + 63) / 64; // up aligned to 64 uint32_t round_sub_m = (sub_m + 15) / 16 * 16; @@ -3550,8 +3533,9 @@ template uint32_t pingpong_flag, uint32_t cur_nIndx, uint32_t prev_split_num, - uint32_t split_num - ) + uint32_t split_num, + uint32_t all_sub_m + ) { uint32_t sub_m_d64 = (sub_m + 63) / 64; // up aligned to 64 uint32_t round_sub_m = (sub_m + 15) / 16 * 16; @@ -3580,7 +3564,6 @@ template SetVectorMask((uint64_t)-1, (uint64_t)-1); WAIT_FLAG(MTE3, MTE2, pingpong_flag); if (n_idx != 0) { - SetVectorMask((uint64_t)-1, (uint64_t)-1); brcb_v(tv32_ubuf_tensor.ReinterpretCast(), dm32_ubuf_tensor.ReinterpretCast(), @@ -3589,19 +3572,17 @@ template round_sub_m / FLOAT_BLOCK_SIZE // repeat ); PIPE_BARRIER(V); - if (head_loop > 1) { - gm_to_ub( - go32_ubuf_tensor_cur, - go_gm_tensor, - 0, - 1, - sub_m * round_v / FLOAT_BLOCK_SIZE, - 0, - 0 - ); - SET_FLAG(MTE2, V, pingpong_flag); - WAIT_FLAG(MTE2, V, pingpong_flag); - } + gm_to_ub( + go32_ubuf_tensor_cur, + go_gm_tensor, + 0, + 1, + sub_m * round_v / FLOAT_BLOCK_SIZE, + 0, + 0 + ); + SET_FLAG(MTE2, V, pingpong_flag); + WAIT_FLAG(MTE2, V, pingpong_flag); // *** go = go * dm_block SetVectorMask((uint64_t)-1, (uint64_t)-1); @@ -3765,10 +3746,62 @@ template 0, // srcGap 0 // dstGap ); - } + } + if constexpr (IS_RING) { + if (head_loop_idx == head_loop - 1) { + uint32_t outBytes = sizeof(OUT_DTYPE); + uint32_t cal_head_num = all_sub_m; + uint32_t cal_head_num_d64 = (cal_head_num + 63) / 64; + ln_v(lse32_ubuf_tensor, + gl32_ubuf_tensor, + cal_head_num_d64, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 8, // dstRepeatStride + 8 // srcRepeatStride + ); + PIPE_BARRIER(V); + add_v(lse32_ubuf_tensor, + lse32_ubuf_tensor, + gm32_ubuf_tensor, + cal_head_num_d64, // repeat + 1, // dstBlockStride + 1, // src0BlockStride + 1, // src1BlockStride + 8, // dstRepeatStride + 8, // src0RepeatStride + 8 // src1RepeatStride + ); + PIPE_BARRIER(V); + conv_v(lse_conv_ubuf_tensor, + lse32_ubuf_tensor, + cal_head_num_d64, // repeat + 1, // dstBlockStride + 1, // srcBlockStride + 4, // dstRepeatStride + 8 // srcRepeatStride + ); + SET_FLAG(V, MTE3, EVENT_ID1); + WAIT_FLAG(V, MTE3, EVENT_ID1); + // 搬出lse + ub_to_gm_align( + lse_gm_tensor[(int64_t)(o_offset / __k)], + lse_conv_ubuf_tensor, + 0, // sid + 1, // nBurst + outBytes * all_sub_m,// lenBurst + 0, // leftPaddingNum + 0, // rightPaddingNum + 0, // srcGap + 0 // dstGap + ); + SET_FLAG(MTE3, V, EVENT_ID1); + WAIT_FLAG(MTE3, V, EVENT_ID1); + } + } } // ********************* move O to GM ************************ - else if (head_loop > 1) { + else { SET_FLAG(V, MTE3, pingpong_flag); WAIT_FLAG(V, MTE3, pingpong_flag); ub_to_gm( @@ -3936,39 +3969,25 @@ template uint32_t offset_tiling, uint32_t embed_split_size_v, uint32_t embed_split_loop_v, uint32_t prev_split_num = 0, uint32_t split_num = 1, uint32_t kvEndFlag = 1) { uint32_t cur_kv_seqlen = cur_kvs_seqlen + (- cur_q_seqlen + 1) * kvEndFlag; - uint64_t addr_o_scalar = 0; - uint32_t pp_n_scalar = block_size; uint32_t sub_n_loop = pp_n_scalar / block_size; uint32_t n_block_num = (cur_kv_seqlen + (cur_q_seqlen - 1) * kvEndFlag + pp_n_scalar - 1) / pp_n_scalar; uint32_t n_loop = (n_block_num + K_BLOCK_NUM - 1) / K_BLOCK_NUM; - - uint32_t qk_gm_stride = pp_n_scalar*K_BLOCK_NUM; - + uint32_t qk_gm_stride = pp_n_scalar * K_BLOCK_NUM; 
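Note on the IS_RING branch added above: the new lse output is assembled from the running softmax statistics the vector core already tracks per head — ln_v takes the log of the accumulated exponential sum (gl32_ubuf_tensor) and add_v adds the running row maximum (gm32_ubuf_tensor), i.e. lse = ln(sum(exp(s - max))) + max, which is then copied out per head. A minimal NumPy sketch of that relation (illustrative names, not the kernel's buffers):

    import numpy as np

    def attention_lse(scores):
        # scores: (heads, kv_len) attention logits for one query token
        row_max = np.max(scores, axis=-1, keepdims=True)                    # gm: running max
        row_sum = np.sum(np.exp(scores - row_max), axis=-1, keepdims=True)  # gl: rescaled exp-sum
        return np.log(row_sum) + row_max                                    # lse = ln(gl) + gm

Exporting this quantity alongside the attention output is what lets a ring-style consumer rescale and merge partial results computed over different kv chunks.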
uint32_t qk_n = qk_gm_stride; - uint32_t qk_round_n = RoundUp<64>(qk_n); - uint32_t qk_n_2 = pp_n_scalar; uint32_t qk_round_n_2 = RoundUp(qk_n_2); - uint32_t sub_m = (sub_block_idx == 1) ? (cur_head_num - cur_head_num / 2) : cur_head_num / 2; uint32_t head_idx = (sub_block_idx == 0) ? 0 : 0 + cur_head_num / 2 * cur_q_seqlen; - o_offset = addr_o_scalar + start_head * embedding_size + sub_block_idx * cur_head_num / 2 * embedding_size; // for NSD -> SND - #ifdef VEC_NO_PP - uint32_t m_slice = FLOAT_VECTOR_SIZE / K_BLOCK_NUM; - #else + o_offset = addr_o_scalar + start_head * embedding_size + sub_block_idx * (cur_head_num / 2) * embedding_size; // for NSD -> SND uint32_t m_slice = 16; - #endif uint32_t m_end = (sub_m + m_slice - 1) / m_slice; - uint32_t sub_m_d128 = (sub_m + 127) / 128; // up aligned to 128 uint32_t sub_m_d64 = (sub_m + 63) / 64; // up aligned to 128 uint32_t round_sub_m = (sub_m + 15) / 16 * 16; - uint32_t start_kv = 0; - //wait last batch finish WAIT_FLAG(V, MTE2, EVENT_ID7); gm_to_ub( descale_q1_ubuf_tensor, @@ -3993,93 +4012,91 @@ template uint32_t pingpong_flag = 0; for (uint32_t n_idx = 0; n_idx < n_loop + 1; n_idx++) { if (n_idx != n_loop) { - if (n_idx == (n_loop - 1)) { - qk_n = (cur_kv_seqlen - n_idx * pp_n_scalar * K_BLOCK_NUM); - qk_round_n = RoundUp<64>(qk_n); - } - WaitFlagDev(QK_READY_DECODER); - /* ************ softmax1 stage1 ************* */ - for (uint32_t m_ind = 0; m_ind < m_end; m_ind++) { - uint32_t row_offset = m_ind * m_slice; - uint32_t curr_m = m_ind == m_end - 1 ? sub_m - row_offset : m_slice; // m_slice=8 - if constexpr (flashDecodingVec) { - if (n_idx == 0) { - if (gl_flag_scalar == 1) { - WAIT_FLAG(MTE3, V, EVENT_ID2); - gl_flag_scalar = 0; - } + if (n_idx == (n_loop - 1)) { + qk_n = (cur_kv_seqlen - n_idx * pp_n_scalar * K_BLOCK_NUM); + qk_round_n = RoundUp<64>(qk_n); + } + WaitFlagDev(QK_READY_DECODER); + /* ************ softmax1 stage1 ************* */ + for (uint32_t m_ind = 0; m_ind < m_end; m_ind++) { + uint32_t row_offset = m_ind * m_slice; + uint32_t curr_m = m_ind == m_end - 1 ? 
sub_m - row_offset : m_slice; // m_slice=8 + if constexpr (flashDecodingVec) { + if (n_idx == 0) { + if (gl_flag_scalar == 1) { + WAIT_FLAG(MTE3, V, EVENT_ID2); + gl_flag_scalar = 0; } } - if (sub_m > 0) { - uint32_t s_ub_offset = pingpong_flag * m_slice*512; - uint32_t s_rope_offset = m_slice*512*2; - uint32_t p_gm_offset = (uint64_t)block_idx * P_TMP_SIZE_INT8 + - (uint64_t)sub_block_idx * cur_head_num / 2 * qk_gm_stride + + } + if (sub_m > 0) { + uint32_t s_ub_offset = pingpong_flag * m_slice*512; + uint32_t s_rope_offset = m_slice*512*2; + uint32_t p_gm_offset = (uint64_t)block_idx * P_TMP_SIZE_INT8 + + (uint64_t)sub_block_idx * cur_head_num / 2 * qk_gm_stride + + row_offset * qk_gm_stride + + (uint64_t)((n_idx)%CHUNK_BUFFER_SIZE) * P_TMP_SIZE_INT8 / CHUNK_BUFFER_SIZE; + + uint32_t s_gm_offset = (int64_t)block_idx * TMP_SIZE_DECODER_INT8 + + (int64_t)sub_block_idx * cur_head_num / 2 * qk_gm_stride + row_offset * qk_gm_stride + - (uint64_t)((n_idx)%CHUNK_BUFFER_SIZE) * P_TMP_SIZE_INT8 / CHUNK_BUFFER_SIZE; - - uint32_t s_gm_offset = (int64_t)block_idx * TMP_SIZE_DECODER_INT8 + - (int64_t)sub_block_idx * cur_head_num / 2 * qk_gm_stride + - row_offset * qk_gm_stride + - (uint64_t)((n_idx)%CHUNK_BUFFER_SIZE) * TMP_SIZE_DECODER_INT8 / CHUNK_BUFFER_SIZE; - - OnlineSoftmaxStage1Quant( - ls32_ubuf_tensor[s_ub_offset].template ReinterpretCast(), // s - ls32_ubuf_tensor[s_ub_offset + s_rope_offset], //s rope - mask_ubuf_tensor, - mask_ubuf_tensor.template ReinterpretCast(), - lm32_ubuf_tensor[row_offset], - hm32_ubuf_tensor[row_offset], - gm32_ubuf_tensor[row_offset], - dm32_ubuf_tensor[(n_idx%CHUNK_BUFFER_SIZE) * UB_FLOAT_LINE_SIZE_TP1 + row_offset], - ls32_ubuf_tensor[s_ub_offset], - ll_ubuf_tensor[row_offset], - gl32_ubuf_tensor[row_offset], - descale_q1_ubuf_tensor[row_offset], - pm32_ubuf_tensor[(n_idx%CHUNK_BUFFER_SIZE) * UB_FLOAT_LINE_SIZE_TP1 + row_offset], - lp_ubuf_tensor[s_ub_offset * 4 / sizeof(int8_t)], - tv32_ubuf_tensor, - s_gm_tensor[s_gm_offset], - s_rope_gm_tensor[s_gm_offset], - p_gm_tensor[p_gm_offset], - n_idx == 0, this->tor, - curr_m, qk_n, qk_round_n, qk_gm_stride, pingpong_flag - ); - pingpong_flag = 1 - pingpong_flag; - } + (uint64_t)((n_idx)%CHUNK_BUFFER_SIZE) * TMP_SIZE_DECODER_INT8 / CHUNK_BUFFER_SIZE; + + OnlineSoftmaxStage1Quant( + ls32_ubuf_tensor[s_ub_offset].template ReinterpretCast(), // s + ls32_ubuf_tensor[s_ub_offset + s_rope_offset], //s rope + mask_ubuf_tensor, + mask_ubuf_tensor.template ReinterpretCast(), + lm32_ubuf_tensor[row_offset], + hm32_ubuf_tensor[row_offset], + gm32_ubuf_tensor[row_offset], + dm32_ubuf_tensor[(n_idx%CHUNK_BUFFER_SIZE) * UB_FLOAT_LINE_SIZE_TP1 + row_offset], + ls32_ubuf_tensor[s_ub_offset], + ll_ubuf_tensor[row_offset], + gl32_ubuf_tensor[row_offset], + descale_q1_ubuf_tensor[row_offset], + pm32_ubuf_tensor[(n_idx%CHUNK_BUFFER_SIZE) * UB_FLOAT_LINE_SIZE_TP1 + row_offset], + lp_ubuf_tensor[s_ub_offset * 4 / sizeof(int8_t)], + tv32_ubuf_tensor, + s_gm_tensor[s_gm_offset], + s_rope_gm_tensor[s_gm_offset], + p_gm_tensor[p_gm_offset], + n_idx == 0, this->tor, + curr_m, qk_n, qk_round_n, qk_gm_stride, pingpong_flag + ); + pingpong_flag = 1 - pingpong_flag; } - FftsCrossCoreSync(SOFTMAX_READY_DECODER); - } + FftsCrossCoreSync(SOFTMAX_READY_DECODER); + } pingpong_flag = 0; - /* ************ softmax2 stage1 ************* */ + /* ************ softmax2 stage1 ************* */ if (n_idx != 0) { - if (n_idx == n_loop) { - qk_n_2 = (cur_kv_seqlen - (n_idx - 1) * pp_n_scalar); - qk_round_n_2 = RoundUp(qk_n_2); + if (n_idx == n_loop) { + qk_n_2 = 
(cur_kv_seqlen - (n_idx - 1) * pp_n_scalar); + qk_round_n_2 = RoundUp(qk_n_2); + } + WaitFlagDev(UPDATE_READY_DECODER); + if (sub_m > 0) { + uint32_t head_loop = (sub_m + 15) / 16; + for (uint32_t head_loop_idx = 0; head_loop_idx < head_loop; ++head_loop_idx) { + uint32_t head_offset = head_loop_idx * 16 * round_v; + uint32_t cur_sub_m = head_loop_idx == (head_loop - 1) ? sub_m - head_loop_idx * 16 : 16; // 16 + + SoftmaxStage2MLAHeadLoopTP1( + o_tmp_gm_tensor[(uint64_t)(block_idx * TMP_SIZE * CHUNK_BUFFER_SIZE + sub_block_idx * cur_head_num / 2 * round_v + head_offset + ((n_idx - 1) % CHUNK_BUFFER_SIZE) * TMP_SIZE)], + go_gm_tensor[(uint64_t)(block_idx * TMP_SIZE + sub_block_idx * cur_head_num / 2 * round_v + head_offset)], + o_gm_tensor[(uint64_t)(o_offset + head_offset)], + dm32_ubuf_tensor[(uint64_t)(((n_idx - 1) % CHUNK_BUFFER_SIZE) * 128 + head_loop_idx * 16)], ll_ubuf_tensor[(uint64_t)(((n_idx - 1) % CHUNK_BUFFER_SIZE) * 256 + head_loop_idx * 16)], + pm32_ubuf_tensor[(uint64_t)(((n_idx - 1) % CHUNK_BUFFER_SIZE) * 128 + head_loop_idx * 16)], + descale_k1_ubuf_tensor[head_loop_idx * 16], + n_idx - 1, n_loop, qk_n_2, RoundUp(qk_round_n_2), cur_sub_m, o_offset, head_idx + head_loop_idx * 16, + (n_idx + 1) % 2, head_loop, head_loop_idx, cur_q_seqlen, head_loop_idx * 16, pingpong_flag, cur_nIndx, prev_split_num, split_num, sub_m); + pingpong_flag = 1 - pingpong_flag; } - WaitFlagDev(UPDATE_READY_DECODER); - if (sub_m > 0) { - uint32_t head_loop = (sub_m + 15) / 16; - for (uint32_t head_loop_idx = 0; head_loop_idx < head_loop; ++head_loop_idx) { - uint32_t head_offset = head_loop_idx * 16 * round_v; - uint32_t cur_sub_m = head_loop_idx == (head_loop - 1) ? sub_m - head_loop_idx * 16 : 16; // 16 - - SoftmaxStage2MLAHeadLoopTP1( - o_tmp_gm_tensor[(uint64_t)(block_idx * TMP_SIZE * CHUNK_BUFFER_SIZE + sub_block_idx * cur_head_num / 2 * round_v + head_offset + ((n_idx - 1)%CHUNK_BUFFER_SIZE) * TMP_SIZE)], - go_gm_tensor[(uint64_t)(block_idx * TMP_SIZE + sub_block_idx * cur_head_num / 2 * round_v + head_offset)], - o_gm_tensor[(uint64_t)(o_offset + head_offset)], - dm32_ubuf_tensor[(uint64_t)(((n_idx - 1)%CHUNK_BUFFER_SIZE) * 128 + head_loop_idx * 16)], ll_ubuf_tensor[(uint64_t)(((n_idx - 1)%CHUNK_BUFFER_SIZE) * 256 + head_loop_idx * 16)], - pm32_ubuf_tensor[(uint64_t)(((n_idx - 1)%CHUNK_BUFFER_SIZE) * 128 + head_loop_idx * 16)], - descale_k1_ubuf_tensor[head_loop_idx * 16], - n_idx - 1, n_loop, qk_n_2, RoundUp(qk_round_n_2), cur_sub_m, o_offset, head_idx + head_loop_idx * 16, - (n_idx + 1) % 2, head_loop, head_loop_idx, cur_q_seqlen, head_loop_idx * 16, pingpong_flag, cur_nIndx, prev_split_num, split_num); - pingpong_flag = 1 - pingpong_flag; - } - } } + } } - //current batch done SET_FLAG(V, MTE2, EVENT_ID7); } diff --git a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp index 9b885438..802ebee0 100644 --- a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp +++ b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp @@ -133,7 +133,9 @@ Status GetMLANdInfo(const LaunchParam &launchParam, MLAInfo &mmInfo, mmInfo.numHeads = static_cast(param.headSize); mmInfo.maskType = static_cast(param.maskType); mmInfo.quantFlag = (static_cast(mmInfo.type) < NUM2) ? 0 : 1; - mmInfo.mtpTp1Flag = (mmInfo.numHeads == M_LIMIT); + auto maxQSeqlen = mmInfo.qSeqLen != nullptr ? 
*std::max_element(param.qSeqLen.begin(), param.qSeqLen.end()) : 1; + mmInfo.mtpTp1Flag = ((param.headSize == M_LIMIT) || + (launchParam.GetInTensor(DIM_0).desc.dtype == TENSOR_DTYPE_INT8 && maxQSeqlen > 1)); if (mmInfo.mtpTp1Flag || static_cast(mmInfo.type) >= NUM2) { mmInfo.maskType = 0; } diff --git a/tests/apitest/kernelstest/mix/test_paged_attention_mla_split_quant.py b/tests/apitest/kernelstest/mix/test_paged_attention_mla_split_quant.py index 189053d0..8f8be60a 100644 --- a/tests/apitest/kernelstest/mix/test_paged_attention_mla_split_quant.py +++ b/tests/apitest/kernelstest/mix/test_paged_attention_mla_split_quant.py @@ -148,7 +148,7 @@ class TestPagedAttentionMLA(op_test.OpTest): sim_sub = np.exp(sim_sub) row_sum = np.sum(sim_sub, axis=-1, keepdims=True) soft_res = sim_sub / row_sum - return soft_res + return soft_res, row_sum + np.log(row_max) def softmax_quant_numpy(self, sim, is_first): lm = np.max(sim, axis=-1, keepdims=True) @@ -168,7 +168,7 @@ class TestPagedAttentionMLA(op_test.OpTest): soft_res = sim_int8.astype("float16") soft_res = np.rint(soft_res).astype("int8") de_scalev = self.de_scale2_fp32 * row_maxp[:,0,0] / 127 - return soft_res, row_sum, de_scalev, hm, self.dm + return soft_res, row_sum, de_scalev, hm, self.dm, row_sum + np.log(row_maxp) def softmax_quant_numpy_online(self, sim, heads, kv_head, value): group_head = heads // kv_head @@ -199,7 +199,7 @@ class TestPagedAttentionMLA(op_test.OpTest): qk_n = cur_kv_seqlen - n_idx * block_size_calc end_kv = end_kv + qk_n sim_block = sim[:, :, start_kv : end_kv] - p_block, ll, de_scalev, hm, dm = self.softmax_quant_numpy(sim_block, is_first) + p_block, ll, de_scalev, hm, dm, _ = self.softmax_quant_numpy(sim_block, is_first) self.de_scalev = de_scalev value_block = value[:, start_kv : end_kv, :] lo = self.group_mm_torch(heads, kv_head, torch.from_numpy(p_block), value_block, 0) @@ -230,7 +230,7 @@ class TestPagedAttentionMLA(op_test.OpTest): o = o * scale.transpose(1, 0)[:,:,np.newaxis,np.newaxis] self.go = np.sum(o, axis=0, keepdims=True) self.go = np.squeeze(self.go, axis=0) - return torch.from_numpy(self.go) + return torch.from_numpy(self.go), ls def ref_masked_attention(self, query, # (1, num_heads, head_size) @@ -245,6 +245,7 @@ class TestPagedAttentionMLA(op_test.OpTest): # Q * K.T query = query query = torch.permute(query, (1, 0, 2)) + lse = None if self.is_quant_flag: query_rope = torch.permute(query_rope, (1, 0, 2)) key_rope = torch.permute(key_rope, (1, 2, 0)) @@ -264,32 +265,36 @@ class TestPagedAttentionMLA(op_test.OpTest): sim_high = sim_high + alibi_bias.to(torch.float32) if self.is_quant_flag: self.gm = np.full([query.shape[0] , 1, 1], np.finfo(np.float32).min) - p_high, row_sum, de_scalev, _, _ = self.softmax_quant_numpy(sim_high.numpy(), 1) + p_high, row_sum, de_scalev, _, _, lse = self.softmax_quant_numpy(sim_high.numpy(), 1) + lse = torch.permute(torch.from_numpy(lse).to(mask_data_type), (1, 0, 2)) self.de_scalev = de_scalev value = torch.permute(value, (1, 0, 2)) out_high = self.group_mm_torch(query.shape[0], key.shape[0], torch.from_numpy(p_high), value, 0) out_high = out_high / row_sum out_high = torch.permute(out_high, (1, 0, 2)) s_qk = sim_high.numpy() - out = self.softmax_quant_numpy_online(s_qk, query.shape[0], key.shape[0], value) + out, lse_high = self.softmax_quant_numpy_online(s_qk, query.shape[0], key.shape[0], value) + lse_high = torch.permute(torch.from_numpy(lse_high).to(mask_data_type), (1, 0, 2)) else: # softmax - p_high = self.softmax_numpy(sim_high) + p_high, lse = 
self.softmax_numpy(sim_high)
             p = torch.from_numpy(p_high).to(mask_data_type)
             p_high = torch.from_numpy(p_high)
-
+            lse = torch.permute(torch.from_numpy(lse).to(mask_data_type), (1, 0, 2))
         # P * V
+        lse_high = lse
         value = torch.permute(value, (1, 0, 2))
         out = self.group_mm_torch(query.shape[0], key.shape[0], p, value, 0)
         out_high = self.group_mm_torch(query.shape[0], key.shape[0], p_high, value, 0)
         out = torch.permute(out, (1, 0, 2))
         out_high = torch.permute(out_high, (1, 0, 2))
         sim_out = torch.permute(sim_out, (1, 0, 2))
-        return out, out_high, sim_out
+        return out, out_high, sim_out, lse, lse_high
 
     def ref_single_query_cached_kv_attention(self,
                                              output,
                                              true_out,
+                                             lse,
                                              query,
                                              key_cache,    # (num_blocks, block_size, num_heads, head_size)
                                              value_cache,  # (num_blocks, block_size, num_heads, head_size)
@@ -354,14 +359,16 @@ class TestPagedAttentionMLA(op_test.OpTest):
             values = torch.stack(values, axis=0)
             scale = np.float32(1.0 / (576 ** 0.5))
             if mask_dim == 4:
-                out,out_high,sim_out = self.ref_masked_attention(q, keys, values, scale, mask[i, :, :, :context_len], mask_data_type, q_rope, keys_rope)
+                out, out_high, sim_out, lse_i, lse_high = self.ref_masked_attention(q, keys, values, scale, mask[i, :, :, :context_len], mask_data_type, q_rope, keys_rope)
                 out = out.reshape(num_heads, head_size_vo)
             elif mask_dim == 3:
-                out,out_high,sim_out = self.ref_masked_attention(q, keys, values, scale, mask[i // mask_index_coff, :, :context_len], mask_data_type, q_rope, keys_rope)
+                out, out_high, sim_out, lse_i, lse_high = self.ref_masked_attention(q, keys, values, scale, mask[i // mask_index_coff, :, :context_len], mask_data_type, q_rope, keys_rope)
                 out = out.reshape(num_heads, head_size_vo)
             else:
-                out,out_high,sim_out = self.ref_masked_attention(q, keys, values, scale, mask, mask_data_type, q_rope, keys_rope)
+                out, out_high, sim_out, lse_i, lse_high = self.ref_masked_attention(q, keys, values, scale, mask, mask_data_type, q_rope, keys_rope)
                 out = out.reshape(num_heads, head_size_vo)
+                lse_high = lse_high.reshape(num_heads, 1)
+                lse[index] = lse_high.to(mask_data_type)
             out_high = out_high.reshape(num_heads, head_size_vo)
             sim_out = sim_out.reshape(1, num_heads * context_len)
             output[index] = out.to(mask_data_type)
@@ -429,7 +436,6 @@ class TestPagedAttentionMLA(op_test.OpTest):
         self.is_quant_flag = is_quant_flag
         self.q_seqlen = q_seqlen
         self.fa_block_size = fa_block_size
-
         logging.debug(f'input info: {num_tokens}, {num_heads}, {kv_heads}, {head_size_qk}, {head_size_vo}, {block_size}, {num_blocks}, {k_seqlen}, {dtype}')
         q_min_range = -1.0
         q_max_range = 1.0
@@ -474,7 +480,6 @@ class TestPagedAttentionMLA(op_test.OpTest):
         else:
             value_cache = key_cache[:, :, :, :head_size_vo]
         self.data_type = dtype
-
         if dynamic_batch:
             context_lens = dynamic_seqlen
         else:
@@ -542,12 +547,14 @@ class TestPagedAttentionMLA(op_test.OpTest):
 
         self.block_size_calc = self.get_blockszie_calc(max_context_len, block_size, head_size_qk, head_size_vo)
         self.block_size = block_size
-        shape_out = (num_tokens*q_seqlen, num_heads, head_size_vo)
-        ref_output = torch.zeros(shape_out, dtype=mask_data_type)
-        true_out = torch.zeros(shape_out, dtype=torch.float32)
+        shape_out = (num_tokens * q_seqlen, num_heads, head_size_vo)
+        ref_output = torch.zeros(shape_out, dtype = mask_data_type)
+        true_out = torch.zeros(shape_out, dtype = torch.float32)
+        lse = torch.zeros((num_tokens * q_seqlen, num_heads, 1), dtype = mask_data_type)
         self.ref_single_query_cached_kv_attention(
             ref_output,
             true_out,
+            lse,
             query,
             key_cache,
             value_cache,
@@ -586,15 +593,21 @@ class TestPagedAttentionMLA(op_test.OpTest):
         self.alib_mask = mask
         self.golden_out = ref_output
         self.true_out = true_out
+        self.lse = lse
 
     def golden_calc(self, in_tensors):
         golden_out = torch.tensor(self.golden_out)
-        return [golden_out]
+        return [golden_out, self.lse]
 
     def golden_compare(self, out_tensors, golden_tensors):
         result_double = compare_cv(self.true_out, golden_tensors[0], out_tensors[0])
         result_old = self.compare_output_data(out_tensors[0], golden_tensors[0], [0.001, 0.001, 0.005, 0.005])
-        return (result_double or result_old)
+        lse_double = True
+        lse_old = True
+        if self.is_ring:
+            lse_double = compare_cv(golden_tensors[1], golden_tensors[1], out_tensors[1])
+            lse_old = self.compare_output_data(out_tensors[1], golden_tensors[1], [0.001, 0.001, 0.005, 0.005])
+        return (result_double or result_old) and (lse_double or lse_old)
 
     @op_test.only_910b
     def test_paged_mla_split_cache_norm_128_quant_fp16(self):
@@ -1226,5 +1239,115 @@ class TestPagedAttentionMLA(op_test.OpTest):
                 ]
             )
 
+    @op_test.only_910b
+    def test_paged_mla_split_cache_norm_fp16_q2_numheads128_quant(self):
+        self.set_support_910b_only()
+        num_tokens = 8
+        q_seqlen = 2
+        k_seqlen = 1025
+        q_seqlen_list = [q_seqlen] * num_tokens
+        k_seqlen_list = [k_seqlen] * num_tokens
+        num_heads = 32
+        kv_heads = 1
+        block_size = 128
+        head_size_qk = 576
+        head_size_vo = 512
+        num_blocks = 2048
+        tor = 1.0 / (head_size_qk ** 0.5)
+        mask_dim = 0
+        dtype = torch.float16
+        is_kv_combined = True
+        is_quant_flag = True
+        self.is_ring = 0
+        fa_block_size = 512
+
+        self.calc_data(num_tokens, q_seqlen, num_heads, kv_heads, head_size_qk, head_size_vo, block_size, num_blocks, k_seqlen, dtype, mask_dim, torch.float16,
+                       is_kv_combined = is_kv_combined, is_quant_flag = is_quant_flag, fa_block_size = fa_block_size)
+        OP_NAME = "MLAOperation"
+        OP_PARAM = {"type": 0, "kvHead": kv_heads, "headSize": num_heads, "tor": tor,
+                    "kvSeqLen": k_seqlen_list, "qSeqLen": q_seqlen_list, "maskType": 0, "isRing": 0}
+
+        self.set_param(OP_NAME, OP_PARAM)
+        self.set_input_formats([self.format_nd, self.format_nd, self.format_nz, self.format_nz, self.format_nd, self.format_nd, self.format_nd, self.format_nd])
+        self.set_output_formats([self.format_nd] * 2)
+        logging.debug(f"blcok_tables shape: {self.block_tables}")
+        logging.debug(f"contex_lens shape: {self.contex_lens}")
+        logging.debug(f"numTokens: {num_tokens}, numHeads: {num_heads}, kvHead: {kv_heads}"
+                      f", blockSize: {block_size}, headSizeQK: {head_size_qk}, headSizeVO: {head_size_vo}, numBlocks: {num_blocks}")
+        shape_out = ((num_tokens * q_seqlen, num_heads, head_size_vo))
+        attention_out = torch.zeros(shape_out, dtype=torch.float16)
+        for i in range (1):
+            self.execute(
+                [
+                    self.q_split1,
+                    self.q_split2,
+                    self.key_cache_split1,
+                    self.key_cache_split2,
+                    torch.tensor(self.block_tables).int(),
+                    torch.tensor([], dtype=torch.float16),
+                    self.de_scale1_fp32,
+                    self.de_scale2_fp32
+                ],
+                [
+                    attention_out, torch.tensor([])
+                ]
+            )
+
+    @op_test.only_910b
+    def test_paged_mla_split_cache_quant_fp16_q4_numheads32_quant_ring(self):
+        self.set_support_910b_only()
+        num_tokens = 8
+        q_seqlen = 4
+        k_seqlen = 1500
+        q_seqlen_list = [q_seqlen] * num_tokens
+        k_seqlen_list = [k_seqlen] * num_tokens
+        num_heads = 32
+        kv_heads = 1
+        block_size = 128
+        head_size_qk = 576
+        head_size_vo = 512
+        num_blocks = 512
+        tor = 1.0 / (head_size_qk ** 0.5)
+        mask_dim = 0
+        dtype = torch.float16
+        is_kv_combined = True
+        is_quant_flag = True
+        self.is_ring = 1
+        fa_block_size = 512
+
+        self.calc_data(num_tokens, q_seqlen, num_heads, kv_heads, head_size_qk, head_size_vo, block_size, num_blocks, k_seqlen, dtype, mask_dim, dtype,
+                       is_kv_combined = is_kv_combined, is_quant_flag = is_quant_flag, fa_block_size = fa_block_size)
+        OP_NAME = "MLAOperation"
+        OP_PARAM = {"type": 0, "kvHead": kv_heads, "headSize": num_heads, "tor": tor,
+                    "kvSeqLen": k_seqlen_list, "qSeqLen": q_seqlen_list, "maskType": 0, "isRing": self.is_ring}
+
+        self.set_param(OP_NAME, OP_PARAM)
+        self.set_input_formats([self.format_nd, self.format_nd, self.format_nz, self.format_nz, self.format_nd, self.format_nd, self.format_nd, self.format_nd])
+        self.set_output_formats([self.format_nd] * 2)
+        logging.debug(f"blcok_tables shape: {self.block_tables}")
+        logging.debug(f"contex_lens shape: {self.contex_lens}")
+        logging.debug(f"numTokens: {num_tokens}, numHeads: {num_heads}, kvHead: {kv_heads}"
+                      f", blockSize: {block_size}, headSizeQK: {head_size_qk}, headSizeVO: {head_size_vo}, numBlocks: {num_blocks}")
+        shape_out = ((num_tokens * q_seqlen, num_heads, head_size_vo))
+        attention_out = torch.zeros(shape_out, dtype=dtype)
+
+        shape_out_2 = ((num_tokens * q_seqlen, num_heads, 1))
+        lse = torch.zeros(shape_out_2, dtype=dtype)
+        for i in range (1):
+            self.execute(
+                [
+                    self.q_split1,
+                    self.q_split2,
+                    self.key_cache_split1,
+                    self.key_cache_split2,
+                    torch.tensor(self.block_tables).int(),
+                    torch.tensor([], dtype=dtype),
+                    self.de_scale1_fp32,
+                    self.de_scale2_fp32
+                ],
+                [
+                    attention_out, lse
+                ]
+            )
 if __name__ == '__main__':
     unittest.main()
-- Gitee

From a61602f06c036bbc06b6b79bf6028a20411f3777 Mon Sep 17 00:00:00 2001
From: Liexss
Date: Mon, 21 Jul 2025 15:34:43 +0800
Subject: [PATCH 2/4] [fix]: fix int8 error

---
 .../op_kernel/mla_int8.cce | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_int8.cce b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_int8.cce
index 8d1d3ef9..5e959cce 100644
--- a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_int8.cce
+++ b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_int8.cce
@@ -3989,23 +3989,27 @@ template
     uint32_t round_sub_m = (sub_m + 15) / 16 * 16;
     uint32_t start_kv = 0;
     WAIT_FLAG(V, MTE2, EVENT_ID7);
-    gm_to_ub(
+    gm_to_ub_align(
         descale_q1_ubuf_tensor,
         deq_scale_gm_tensor_q1[head_idx],
         0, // sid
         1, // nBurst
-        cur_head_num / FLOAT_BLOCK_SIZE, // lenBurst
-        0, // srcGap
-        0 // dstGap
+        sub_m * sizeof(mmScaleType), // lenBurst
+        0,
+        0,
+        0,
+        0
     );
-    gm_to_ub(
+    gm_to_ub_align(
         descale_k1_ubuf_tensor,
         deq_scale_gm_tensor_k1[head_idx],
         0, // sid
         1, // nBurst
-        cur_head_num / FLOAT_BLOCK_SIZE, // lenBurst
-        0, // srcGap
-        0 // dstGap
+        sub_m * sizeof(mmScaleType), // lenBurst
+        0,
+        0,
+        0,
+        0
     );
     SET_FLAG(MTE2, V, EVENT_ID0);
     WAIT_FLAG(MTE2, V, EVENT_ID0);
-- Gitee

From a49d2e25b5a9edebd52c8dab60567f8b385e0602 Mon Sep 17 00:00:00 2001
From: Liexss
Date: Tue, 22 Jul 2025 12:00:14 +0800
Subject: [PATCH 3/4] [feature]: support lse for float dtype

---
 .../multi_latent_attention/op_kernel/mla.cce  | 20 ++++++
 .../op_kernel/mla_int8.cce                    | 39 +++--------
 .../test_paged_attention_mla_split_quant.py   | 67 +++++++++++++++++--
 3 files changed, 92 insertions(+), 34 deletions(-)

diff --git a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce
index ee3d7c7b..9f2b8b9a 100644
--- a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce
+++ b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce
@@ -4483,6 +4483,26 @@ extern "C" __global__ __aicore__ void mla(
         MLADecoderAiv pa_aiv {};
         pa_aiv.SetArgs(sync, block_tables_gm, deq_qk_gm, deq_pv_gm, o_gm, s_gm, s_rope_out_gm, p_gm, o_tmp_gm, go_gm, tiling_para_gm, mask_gm, o_core_tmp_gm, l_gm);
         pa_aiv.RunTP1();
+#endif
+    } else if (TILING_KEY_IS(50)) { // int8_t(IN) fp16(OUT)
+#ifdef __DAV_C220_CUBE__
+        MLAttentionDecoderAic pa_aic_fp16 {};
+        pa_aic_fp16.SetArgs(sync, q_gm, q_rope_gm, ctkv_gm, ctkv_rope_gm, block_tables_gm, o_gm, s_gm, s_rope_out_gm, p_gm, o_tmp_gm, tiling_para_gm);
+        pa_aic_fp16.Run();
+#elif __DAV_C220_VEC__
+        MLADecoderAiv pa_aiv {};
+        pa_aiv.SetArgs(sync, block_tables_gm, deq_qk_gm, deq_pv_gm, o_gm, s_gm, s_rope_out_gm, p_gm, o_tmp_gm, go_gm, tiling_para_gm, mask_gm);
+        pa_aiv.Run();
+#endif
+    } else if (TILING_KEY_IS(51)) { // int8_t(IN) bf16(OUT)
+#ifdef __DAV_C220_CUBE__
+        MLAttentionDecoderAic pa_aic_bf16 {};
+        pa_aic_bf16.SetArgs(sync, q_gm, q_rope_gm, ctkv_gm, ctkv_rope_gm, block_tables_gm, o_gm, s_gm, s_rope_out_gm, p_gm, o_tmp_gm, tiling_para_gm);
+        pa_aic_bf16.Run();
+#elif __DAV_C220_VEC__
+        MLADecoderAiv pa_aiv {};
+        pa_aiv.SetArgs(sync, block_tables_gm, deq_qk_gm, deq_pv_gm, o_gm, s_gm, s_rope_out_gm, p_gm, o_tmp_gm, go_gm, tiling_para_gm, mask_gm);
+        pa_aiv.Run();
 #endif
     }
 }
\ No newline at end of file
diff --git a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_int8.cce b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_int8.cce
index 5e959cce..76e37132 100644
--- a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_int8.cce
+++ b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla_int8.cce
@@ -1438,8 +1438,8 @@ public:
     __aicore__ __attribute__((always_inline)) inline void SetArgs2(
         __gm__ uint8_t *__restrict__ lse_out_gm)
     {
-        lse_gm = reinterpret_cast<__gm__ OUT_DTYPE *>(lse_out_gm);
-        lse_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ OUT_DTYPE *>(lse_gm));
+        lse_gm = reinterpret_cast<__gm__ float *>(lse_out_gm);
+        lse_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(lse_gm));
     }
 
     __aicore__ __attribute__((always_inline)) inline void Run()
@@ -3445,7 +3445,7 @@ template
             }
             // ********************* move O to GM ************************
             if constexpr (IS_RING) {
-                uint32_t lenBurst = sizeof(OUT_DTYPE);
+                uint32_t lenBurst = sizeof(float);
                 ln_v(lse32_ubuf_tensor,
                     gl32_ubuf_tensor,
                     sub_m_d64, // repeat
@@ -3466,21 +3466,12 @@ template
                     8, // src0RepeatStride
                     8 // src1RepeatStride
                 );
-                PIPE_BARRIER(V);
-                conv_v(lse_conv_ubuf_tensor,
-                    lse32_ubuf_tensor,
-                    sub_m_d64, // repeat
-                    1, // dstBlockStride
-                    1, // srcBlockStride
-                    4, // dstRepeatStride
-                    8 // srcRepeatStride
-                );
                 SET_FLAG(V, MTE3, EVENT_ID1);
                 WAIT_FLAG(V, MTE3, EVENT_ID1);
                 // 搬出lse
-                ub_to_gm_align(
+                ub_to_gm_align(
                     lse_gm_tensor[(int64_t)(o_offset / __k)],
-                    lse_conv_ubuf_tensor,
+                    lse32_ubuf_tensor,
                     0, // sid
                     1, // nBurst
                     lenBurst * sub_m * head_loop, // lenBurst
@@ -3492,7 +3483,6 @@ template
                 SET_FLAG(MTE3, V, EVENT_ID1);
                 WAIT_FLAG(MTE3, V, EVENT_ID1);
             }
-
         } else if (head_loop > 1) {
             SET_FLAG(V, MTE3, EVENT_ID5);
             WAIT_FLAG(V, MTE3, EVENT_ID5);
@@ -3749,7 +3739,7 @@ template
                 }
                 if constexpr (IS_RING) {
                     if (head_loop_idx == head_loop - 1) {
-                        uint32_t outBytes = sizeof(OUT_DTYPE);
+                        uint32_t outBytes = sizeof(float);
                         uint32_t cal_head_num = all_sub_m;
                         uint32_t cal_head_num_d64 = (cal_head_num + 63) / 64;
                         ln_v(lse32_ubuf_tensor,
@@ -3772,21 +3762,12 @@ template
                             8, // src0RepeatStride
                             8 // src1RepeatStride
                         );
-                        PIPE_BARRIER(V);
-                        conv_v(lse_conv_ubuf_tensor,
-                            lse32_ubuf_tensor,
-                            cal_head_num_d64, // repeat
-                            1, // dstBlockStride
-                            1, // srcBlockStride
-                            4, // dstRepeatStride
-                            8 // srcRepeatStride
-                        );
                         SET_FLAG(V, MTE3, EVENT_ID1);
                         WAIT_FLAG(V, MTE3, EVENT_ID1);
                         // 搬出lse
-                        ub_to_gm_align(
+                        ub_to_gm_align(
                             lse_gm_tensor[(int64_t)(o_offset / __k)],
-                            lse_conv_ubuf_tensor,
+                            lse32_ubuf_tensor,
                             0, // sid
                             1, // nBurst
                             outBytes * all_sub_m,// lenBurst
@@ -4111,7 +4092,7 @@ private:
     __gm__ mm2CopyType *__restrict__ o_tmp_gm{nullptr};
    __gm__ float *__restrict__ go_gm{nullptr};
     __gm__ int32_t* __restrict__ gm_block_tables_{nullptr};
-    __gm__ OUT_DTYPE *__restrict__ lse_gm{nullptr};
+    __gm__ float *__restrict__ lse_gm{nullptr};
     __gm__ OUT_DTYPE *__restrict__ o_gm{nullptr};
     __gm__ OUT_DTYPE *__restrict__ mask_gm{nullptr};
     __gm__ uint8_t *__restrict__ tiling_gm{nullptr};
@@ -4157,7 +4138,7 @@ private:
 
     AscendC::GlobalTensor mask_gm_tensor;
     AscendC::GlobalTensor o_gm_tensor;
-    AscendC::GlobalTensor lse_gm_tensor;
+    AscendC::GlobalTensor lse_gm_tensor;
     AscendC::GlobalTensor s_gm_tensor;
     AscendC::GlobalTensor s_rope_gm_tensor;
     AscendC::GlobalTensor p_gm_tensor;
diff --git a/tests/apitest/kernelstest/mix/test_paged_attention_mla_split_quant.py b/tests/apitest/kernelstest/mix/test_paged_attention_mla_split_quant.py
index 8f8be60a..68184a56 100644
--- a/tests/apitest/kernelstest/mix/test_paged_attention_mla_split_quant.py
+++ b/tests/apitest/kernelstest/mix/test_paged_attention_mla_split_quant.py
@@ -274,7 +274,7 @@ class TestPagedAttentionMLA(op_test.OpTest):
             out_high = torch.permute(out_high, (1, 0, 2))
             s_qk = sim_high.numpy()
             out, lse_high = self.softmax_quant_numpy_online(s_qk, query.shape[0], key.shape[0], value)
-            lse_high = torch.permute(torch.from_numpy(lse_high).to(mask_data_type), (1, 0, 2))
+            lse_high = torch.permute(torch.from_numpy(lse_high).to(torch.float32), (1, 0, 2))
         else:
             # softmax
             p_high, lse = self.softmax_numpy(sim_high)
@@ -368,7 +368,7 @@ class TestPagedAttentionMLA(op_test.OpTest):
                 out, out_high, sim_out, lse_i, lse_high = self.ref_masked_attention(q, keys, values, scale, mask, mask_data_type, q_rope, keys_rope)
                 out = out.reshape(num_heads, head_size_vo)
                 lse_high = lse_high.reshape(num_heads, 1)
-                lse[index] = lse_high.to(mask_data_type)
+                lse[index] = lse_high.to(torch.float32)
             out_high = out_high.reshape(num_heads, head_size_vo)
             sim_out = sim_out.reshape(1, num_heads * context_len)
             output[index] = out.to(mask_data_type)
@@ -550,7 +550,7 @@ class TestPagedAttentionMLA(op_test.OpTest):
         shape_out = (num_tokens * q_seqlen, num_heads, head_size_vo)
         ref_output = torch.zeros(shape_out, dtype = mask_data_type)
         true_out = torch.zeros(shape_out, dtype = torch.float32)
-        lse = torch.zeros((num_tokens * q_seqlen, num_heads, 1), dtype = mask_data_type)
+        lse = torch.zeros((num_tokens * q_seqlen, num_heads, 1), dtype = torch.float32)
         self.ref_single_query_cached_kv_attention(
             ref_output,
             true_out,
@@ -1329,10 +1329,67 @@ class TestPagedAttentionMLA(op_test.OpTest):
         logging.debug(f"numTokens: {num_tokens}, numHeads: {num_heads}, kvHead: {kv_heads}"
                       f", blockSize: {block_size}, headSizeQK: {head_size_qk}, headSizeVO: {head_size_vo}, numBlocks: {num_blocks}")
         shape_out = ((num_tokens * q_seqlen, num_heads, head_size_vo))
-        attention_out = torch.zeros(shape_out, dtype=dtype)
+        attention_out = torch.zeros(shape_out, dtype = dtype)
 
         shape_out_2 = ((num_tokens * q_seqlen, num_heads, 1))
-        lse = torch.zeros(shape_out_2, dtype=dtype)
+        lse = torch.zeros(shape_out_2, dtype = torch.float32)
+        for i in range (1):
+            self.execute(
+                [
+                    self.q_split1,
+                    self.q_split2,
+                    self.key_cache_split1,
+                    self.key_cache_split2,
+                    torch.tensor(self.block_tables).int(),
+                    torch.tensor([], dtype=dtype),
+                    self.de_scale1_fp32,
+                    self.de_scale2_fp32
+                ],
+                [
+                    attention_out, lse
+                ]
+            )
+
+    @op_test.only_910b
+    def test_paged_mla_split_cache_quant_fp16_q1_numheads8_quant_ring(self):
+        self.set_support_910b_only()
+        num_tokens = 32
+        q_seqlen = 1
+        k_seqlen = 1700
+        q_seqlen_list = [q_seqlen] * num_tokens
+        k_seqlen_list = [k_seqlen] * num_tokens
+        num_heads = 8
+        kv_heads = 1
+        block_size = 128
+        head_size_qk = 576
+        head_size_vo = 512
+        num_blocks = 512
+        tor = 1.0 / (head_size_qk ** 0.5)
+        mask_dim = 0
+        dtype = torch.float16
+        is_kv_combined = True
+        is_quant_flag = True
+        self.is_ring = 1
+        fa_block_size = 512
+
+        self.calc_data(num_tokens, q_seqlen, num_heads, kv_heads, head_size_qk, head_size_vo, block_size, num_blocks, k_seqlen, dtype, mask_dim, dtype,
+                       is_kv_combined = is_kv_combined, is_quant_flag = is_quant_flag, fa_block_size = fa_block_size)
+        OP_NAME = "MLAOperation"
+        OP_PARAM = {"type": 0, "kvHead": kv_heads, "headSize": num_heads, "tor": tor,
+                    "kvSeqLen": k_seqlen_list, "qSeqLen": q_seqlen_list, "maskType": 0, "isRing": self.is_ring}
+
+        self.set_param(OP_NAME, OP_PARAM)
+        self.set_input_formats([self.format_nd, self.format_nd, self.format_nz, self.format_nz, self.format_nd, self.format_nd, self.format_nd, self.format_nd])
+        self.set_output_formats([self.format_nd] * 2)
+        logging.debug(f"blcok_tables shape: {self.block_tables}")
+        logging.debug(f"contex_lens shape: {self.contex_lens}")
+        logging.debug(f"numTokens: {num_tokens}, numHeads: {num_heads}, kvHead: {kv_heads}"
+                      f", blockSize: {block_size}, headSizeQK: {head_size_qk}, headSizeVO: {head_size_vo}, numBlocks: {num_blocks}")
+        shape_out = ((num_tokens * q_seqlen, num_heads, head_size_vo))
+        attention_out = torch.zeros(shape_out, dtype = dtype)
+
+        shape_out_2 = ((num_tokens * q_seqlen, num_heads, 1))
+        lse = torch.zeros(shape_out_2, dtype = torch.float32)
         for i in range (1):
             self.execute(
                 [
-- Gitee

From 9af16f30f9fdd658cc056bbbf8360bc7e7329400 Mon Sep 17 00:00:00 2001
From: Liexss
Date: Tue, 22 Jul 2025 14:13:44 +0800
Subject: [PATCH 4/4] [fix]: fix gm error

---
 .../multi_latent_attention/op_kernel/mla.cce | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce
index 9f2b8b9a..db5c4a35 100644
--- a/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce
+++ b/src/kernels/mixkernels/multi_latent_attention/op_kernel/mla.cce
@@ -4430,9 +4430,9 @@ extern "C" __global__ __aicore__ void mla(
         pa_aic_fp16.SetArgs(sync, q_gm, q_rope_gm, ctkv_gm, ctkv_rope_gm, block_tables_gm, o_gm, s_gm, s_rope_out_gm, p_gm, o_tmp_gm, tiling_para_gm);
         pa_aic_fp16.RunTP1();
 #elif __DAV_C220_VEC__
-        MLADecoderAiv pa_aiv {};
-        pa_aiv.SetArgs(sync, block_tables_gm, deq_qk_gm, deq_pv_gm, o_gm, s_gm, s_rope_out_gm, p_gm, o_tmp_gm, go_gm, tiling_para_gm, mask_gm, o_core_tmp_gm, l_gm);
-        pa_aiv.RunTP1();
+        MLADecoderAiv pa_aiv {};
+        pa_aiv.SetArgs(sync, block_tables_gm, deq_qk_gm, deq_pv_gm, o_gm, s_gm, s_rope_out_gm, p_gm, o_tmp_gm, go_gm, tiling_para_gm, mask_gm, o_core_tmp_gm, l_gm);
+        pa_aiv.RunTP1();
 #endif
     } else if (TILING_KEY_IS(69)) { // bf16 + nd + FOUR_FLOW + flashdecoding 0b1000101
 #ifdef __DAV_C220_CUBE__
@@ -4484,7 +4484,7 @@ extern "C" __global__ __aicore__ void mla(
         pa_aiv.SetArgs(sync, block_tables_gm, deq_qk_gm, deq_pv_gm, o_gm, s_gm, s_rope_out_gm, p_gm, o_tmp_gm, go_gm, tiling_para_gm, mask_gm, o_core_tmp_gm, l_gm);
         pa_aiv.RunTP1();
 #endif
-    } else if (TILING_KEY_IS(50)) { // int8_t(IN) fp16(OUT)
+    } else if (TILING_KEY_IS(50)) { // int8_t(IN) fp16(OUT) RING
 #ifdef __DAV_C220_CUBE__
         MLAttentionDecoderAic pa_aic_fp16 {};
         pa_aic_fp16.SetArgs(sync, q_gm, q_rope_gm, ctkv_gm, ctkv_rope_gm, block_tables_gm, o_gm, s_gm, s_rope_out_gm, p_gm, o_tmp_gm, tiling_para_gm);
@@ -4492,9 +4492,10 @@ extern "C" __global__ __aicore__ void mla(
 #elif __DAV_C220_VEC__
         MLADecoderAiv pa_aiv {};
         pa_aiv.SetArgs(sync, block_tables_gm, deq_qk_gm, deq_pv_gm, o_gm, s_gm, s_rope_out_gm, p_gm, o_tmp_gm, go_gm, tiling_para_gm, mask_gm);
+        pa_aiv.SetArgs2(lse_gm);
         pa_aiv.Run();
 #endif
-    } else if (TILING_KEY_IS(51)) { // int8_t(IN) bf16(OUT)
+    } else if (TILING_KEY_IS(51)) { // int8_t(IN) bf16(OUT) RING
 #ifdef __DAV_C220_CUBE__
         MLAttentionDecoderAic pa_aic_bf16 {};
         pa_aic_bf16.SetArgs(sync, q_gm, q_rope_gm, ctkv_gm, ctkv_rope_gm, block_tables_gm, o_gm, s_gm, s_rope_out_gm, p_gm, o_tmp_gm, tiling_para_gm);
@@ -4502,6 +4503,7 @@ extern "C" __global__ __aicore__ void mla(
 #elif __DAV_C220_VEC__
         MLADecoderAiv pa_aiv {};
         pa_aiv.SetArgs(sync, block_tables_gm, deq_qk_gm, deq_pv_gm, o_gm, s_gm, s_rope_out_gm, p_gm, o_tmp_gm, go_gm, tiling_para_gm, mask_gm);
+        pa_aiv.SetArgs2(lse_gm);
         pa_aiv.Run();
 #endif
     }
-- Gitee
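
Note on the lse output exercised by the ring tests above: for each (token, head) pair it is the log-sum-exp of the scaled QK^T scores, i.e. the softmax normalizer in log space, which is the quantity needed to merge partial attention results across ring/flash-decoding chunks. The sketch below only illustrates that relationship; softmax_with_lse is a hypothetical helper (not the softmax_numpy / softmax_quant_numpy_online routines used by the test class), and it assumes unmasked fp32 scores of shape (num_heads, q_seqlen, context_len) that are already scaled by tor = 1 / sqrt(head_size_qk).

    import numpy as np

    def softmax_with_lse(s):
        # s: fp32 attention scores, shape (num_heads, q_seqlen, context_len)
        m = np.max(s, axis=-1, keepdims=True)       # row-wise max for numerical stability
        e = np.exp(s - m)
        denom = np.sum(e, axis=-1, keepdims=True)
        p = e / denom                               # softmax probabilities
        lse = m + np.log(denom)                     # log-sum-exp, shape (num_heads, q_seqlen, 1)
        return p, lse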