From 560db9b7bb02a30af2c6b6e48f7de80e15444cad Mon Sep 17 00:00:00 2001 From: QianChenxi Date: Mon, 11 Aug 2025 20:11:11 +0800 Subject: [PATCH] chunk prefill opt for x1 --- ...ged_MLAttention_multi_token_prediction.cce | 88 ++++++++++++------- .../pagedattention/paged_attention_kernel.cpp | 4 +- .../paged_attention_operation.cpp | 6 +- .../tiling/paged_attention_tiling.cpp | 2 + .../mix/test_paged_attention_mla_mtp.py | 74 ++++++++++++++++ 5 files changed, 136 insertions(+), 38 deletions(-) diff --git a/src/kernels/mixkernels/pagedattention/op_kernel/paged_MLAttention_multi_token_prediction.cce b/src/kernels/mixkernels/pagedattention/op_kernel/paged_MLAttention_multi_token_prediction.cce index e185135d..63efa400 100644 --- a/src/kernels/mixkernels/pagedattention/op_kernel/paged_MLAttention_multi_token_prediction.cce +++ b/src/kernels/mixkernels/pagedattention/op_kernel/paged_MLAttention_multi_token_prediction.cce @@ -139,10 +139,12 @@ public: __aicore__ __attribute__((always_inline)) inline PagedMLAMultiTokenPredictionAic( __gm__ uint8_t *__restrict__ sync, __gm__ uint8_t *__restrict__ q_gm, __gm__ uint8_t *__restrict__ kv_gm, __gm__ uint8_t *__restrict__ mask_gm, __gm__ uint8_t *__restrict__ block_tables_gm, + __gm__ uint8_t *__restrict__ q_rope_gm, __gm__ uint8_t *__restrict__ kv_rope_gm, __gm__ uint8_t *__restrict__ o_gm, __gm__ uint8_t *__restrict__ s_gm, __gm__ uint8_t *__restrict__ p_gm, __gm__ uint8_t *__restrict__ o_tmp_gm, __gm__ uint8_t *__restrict__ upo_tmp_gm, __gm__ uint8_t *__restrict__ tiling_para_gm) : q_gm(q_gm), kv_gm(kv_gm), block_tables_gm(block_tables_gm), + q_rope_gm(q_rope_gm), kv_rope_gm(kv_rope_gm), s_gm(s_gm), p_gm(p_gm), o_tmp_gm(o_tmp_gm), upo_tmp_gm(upo_tmp_gm), tiling_para_gm(tiling_para_gm) { @@ -183,7 +185,9 @@ public: } q_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ mm1InputType *>(this->q_gm)); + q_rope_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ mm1InputType *>(this->q_rope_gm)); kv_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ mm1InputType *>(this->kv_gm)); + kv_rope_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ mm1InputType *>(this->kv_rope_gm)); s_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ HighPrecisionMode *>(this->s_gm)); p_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ mm2InputType *>(this->p_gm)); o_tmp_gm_tensor.SetGlobalBuffer(reinterpret_cast<__gm__ mmOutputType *>(this->o_tmp_gm)); @@ -263,9 +267,11 @@ public: uint32_t pingpong_flag = 0; uint32_t offset = pingpong_flag * L0AB_HALF_BUF_SIZE; - uint64_t q_offset = addr_q_scalar + head_start_idx * embd + qS_blk_idx * pp_m_scalar * stride_qo; + uint64_t q_offset = addr_q_scalar + head_start_idx * 512 + qS_blk_idx * pp_m_scalar * 128 * 512; + uint64_t q_rope_offset = addr_q_scalar + head_start_idx * 64 + qS_blk_idx * pp_m_scalar * 128 * 64; uint64_t k_offset = 0; + uint64_t k_rope_offset = 0; uint64_t n_align = RoundUp(pp_n_scalar, 32 / sizeof(mm1InputType)); uint32_t sv_n = n_loop == 1 ? 
kv_seqlen : pp_n_scalar; @@ -590,7 +596,8 @@ public: n_pingpong_flag = n_idx % 2; uint32_t block_table_id = (uint32_t)(*((__gm__ int32_t *)block_tables_gm + cur_batch * max_num_blocks_per_query + n_idx)); - k_offset = block_table_id * block_size * stride_k + (head_start_idx / group_num) * embd; + k_offset = block_table_id * block_size * 512 + (head_start_idx / group_num) * 512; + k_rope_offset = block_table_id * block_size * 64 + (head_start_idx / group_num) * 64; uint64_t l1koffset = n_align * embed_round * n_pingpong_flag; uint64_t l1qoffset = 0; uint64_t k_offsetk = k_offset; @@ -606,16 +613,16 @@ public: uint32_t qk_round_k = (qk_k + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE; WAIT_FLAG(MTE1, MTE2, n_pingpong_flag * 3 + k_idx); - gm_to_l1( - l1k_buf_addr_tensor[l1koffset], - kv_gm_tensor[k_offsetk], - qk_n, // nValue - RoundUp(qk_round_n, 32 / sizeof(mm1InputType)), // dstNzC0Stride - 0, // dstNzMatrixStride, unused - qk_k, // dValue - 0, // dstNzMatrixStride, unused - stride_k // srcDValue - ); + + if (k_idx < 2) { + AscendC::DataCopy(l1k_buf_addr_tensor[l1koffset], + kv_gm_tensor[k_offsetk], + 128 * qk_k); + } else { + AscendC::DataCopy(l1k_buf_addr_tensor[l1koffset], + kv_rope_gm_tensor[k_rope_offset], + 128 * qk_k); + } SET_FLAG(MTE2, MTE1, EVENT_ID0); WAIT_FLAG(MTE2, MTE1, EVENT_ID0); WAIT_FLAG(M, MTE1, EVENT_ID2); @@ -645,19 +652,26 @@ public: 0 ); } else { - for (uint32_t qS_idx = 0; qS_idx < qS_blk_size_actual; qS_idx++) { - AscendC::DataCopy(l1q_buf_addr_tensor[qS_idx * 16], - q_gm_tensor[q_offset + qS_idx * stride_qo], - AscendC::Nd2NzParams(1, - qN_blk_size_actual, - embd, - 0, - embd, - qk_round_m, - qS_blk_size_actual, - 16)); - } - + gm_to_l1( + l1q_buf_addr_tensor, + q_gm_tensor[q_offset], + 128, + qk_round_m, + 0, + 512, + 0, + 128 * 512 + ); + gm_to_l1( + l1q_buf_addr_tensor[128 * 512], + q_rope_gm_tensor[q_rope_offset], + 128, + qk_round_m, + 0, + 64, + 0, + 128 * 64 + ); } SET_FLAG(MTE2, MTE1, EVENT_ID4); WAIT_FLAG(MTE2, MTE1, EVENT_ID4); @@ -700,7 +714,7 @@ public: ); SET_FLAG(M, MTE1, EVENT_ID0); SET_FLAG(M, MTE1, EVENT_ID2); - k_offsetk += qk_split_k; + k_offsetk += n_align * qk_split_k; l1koffset += n_align * qk_split_k; l1qoffset += qk_round_m * qk_split_k; } @@ -861,7 +875,9 @@ public: } private: __gm__ uint8_t *__restrict__ q_gm{nullptr}; + __gm__ uint8_t *__restrict__ q_rope_gm{nullptr}; __gm__ uint8_t *__restrict__ kv_gm{nullptr}; + __gm__ uint8_t *__restrict__ kv_rope_gm{nullptr}; __gm__ uint8_t *__restrict__ block_tables_gm{nullptr}; __gm__ uint8_t *__restrict__ s_gm{nullptr}; __gm__ uint8_t *__restrict__ p_gm{nullptr}; @@ -876,7 +892,9 @@ private: AscendC::LocalTensor l1p_buf_addr_tensor; AscendC::GlobalTensor q_gm_tensor; + AscendC::GlobalTensor q_rope_gm_tensor; AscendC::GlobalTensor kv_gm_tensor; + AscendC::GlobalTensor kv_rope_gm_tensor; AscendC::GlobalTensor s_gm_tensor; AscendC::GlobalTensor p_gm_tensor; AscendC::GlobalTensor o_tmp_gm_tensor; @@ -3919,6 +3937,8 @@ extern "C" __global__ __aicore__ void paged_multi_latent_attention_multi_token_p __gm__ uint8_t *__restrict__ ctkv_gm, __gm__ uint8_t *__restrict__ block_tables_gm, __gm__ uint8_t *__restrict__ mask_gm, + __gm__ uint8_t *__restrict__ q_rope_gm, + __gm__ uint8_t *__restrict__ ctkv_rope_gm, __gm__ uint8_t *__restrict__ o_gm, __gm__ uint8_t *__restrict__ s_gm, __gm__ uint8_t *__restrict__ p_gm, @@ -3937,7 +3957,7 @@ extern "C" __global__ __aicore__ void paged_multi_latent_attention_multi_token_p #endif if (TILING_KEY_IS(0)) { #ifdef __DAV_C220_CUBE__ - PagedMLAMultiTokenPredictionAic 
mla_mtp_aic(sync, q_gm, ctkv_gm, mask_gm, block_tables_gm, + PagedMLAMultiTokenPredictionAic mla_mtp_aic(sync, q_gm, ctkv_gm, mask_gm, block_tables_gm, q_rope_gm, ctkv_rope_gm, o_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm); mla_mtp_aic.Run(); #elif __DAV_C220_VEC__ @@ -3947,7 +3967,7 @@ extern "C" __global__ __aicore__ void paged_multi_latent_attention_multi_token_p #endif } else if (TILING_KEY_IS(1)) { #ifdef __DAV_C220_CUBE__ - PagedMLAMultiTokenPredictionAic mla_mtp_aic(sync, q_gm, ctkv_gm, mask_gm, block_tables_gm, + PagedMLAMultiTokenPredictionAic mla_mtp_aic(sync, q_gm, ctkv_gm, mask_gm, block_tables_gm, q_rope_gm, ctkv_rope_gm, o_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm); mla_mtp_aic.Run(); #elif __DAV_C220_VEC__ @@ -3957,7 +3977,7 @@ extern "C" __global__ __aicore__ void paged_multi_latent_attention_multi_token_p #endif } else if (TILING_KEY_IS(12)) { #ifdef __DAV_C220_CUBE__ - PagedMLAMultiTokenPredictionAic<__bf16, __bf16, float, CalcMode::CALC_MODE_576, float, FuncType::MULTI_TOKEN_PREDICTION> mla_mtp_aic(sync, q_gm, ctkv_gm, mask_gm, block_tables_gm, + PagedMLAMultiTokenPredictionAic<__bf16, __bf16, float, CalcMode::CALC_MODE_576, float, FuncType::MULTI_TOKEN_PREDICTION> mla_mtp_aic(sync, q_gm, ctkv_gm, mask_gm, block_tables_gm, q_rope_gm, ctkv_rope_gm, o_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm); mla_mtp_aic.Run(); #elif __DAV_C220_VEC__ @@ -3967,7 +3987,7 @@ extern "C" __global__ __aicore__ void paged_multi_latent_attention_multi_token_p #endif } else if (TILING_KEY_IS(13)) { #ifdef __DAV_C220_CUBE__ - PagedMLAMultiTokenPredictionAic<__bf16, __bf16, float, CalcMode::CALC_MODE_256, float, FuncType::MULTI_TOKEN_PREDICTION> mla_mtp_aic(sync, q_gm, ctkv_gm, mask_gm, block_tables_gm, + PagedMLAMultiTokenPredictionAic<__bf16, __bf16, float, CalcMode::CALC_MODE_256, float, FuncType::MULTI_TOKEN_PREDICTION> mla_mtp_aic(sync, q_gm, ctkv_gm, mask_gm, block_tables_gm, q_rope_gm, ctkv_rope_gm, o_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm); mla_mtp_aic.Run(); #elif __DAV_C220_VEC__ @@ -3977,7 +3997,7 @@ extern "C" __global__ __aicore__ void paged_multi_latent_attention_multi_token_p #endif } else if (TILING_KEY_IS(16)) { #ifdef __DAV_C220_CUBE__ - PagedMLAMultiTokenPredictionAic mla_mtp_aic(sync, q_gm, ctkv_gm, mask_gm, block_tables_gm, + PagedMLAMultiTokenPredictionAic mla_mtp_aic(sync, q_gm, ctkv_gm, mask_gm, block_tables_gm, q_rope_gm, ctkv_rope_gm, o_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm); mla_mtp_aic.Run(); #elif __DAV_C220_VEC__ @@ -3987,7 +4007,7 @@ extern "C" __global__ __aicore__ void paged_multi_latent_attention_multi_token_p #endif } else if (TILING_KEY_IS(17)) { #ifdef __DAV_C220_CUBE__ - PagedMLAMultiTokenPredictionAic mla_mtp_aic(sync, q_gm, ctkv_gm, mask_gm, block_tables_gm, + PagedMLAMultiTokenPredictionAic mla_mtp_aic(sync, q_gm, ctkv_gm, mask_gm, block_tables_gm, q_rope_gm, ctkv_rope_gm, o_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm); mla_mtp_aic.Run(); #elif __DAV_C220_VEC__ @@ -3997,7 +4017,7 @@ extern "C" __global__ __aicore__ void paged_multi_latent_attention_multi_token_p #endif } else if (TILING_KEY_IS(28)) { #ifdef __DAV_C220_CUBE__ - PagedMLAMultiTokenPredictionAic<__bf16, __bf16, float, CalcMode::CALC_MODE_576, float, FuncType::UNEQUAL_QK_LEN_PREFILL> mla_mtp_aic(sync, q_gm, ctkv_gm, mask_gm, block_tables_gm, + PagedMLAMultiTokenPredictionAic<__bf16, __bf16, float, CalcMode::CALC_MODE_576, float, FuncType::UNEQUAL_QK_LEN_PREFILL> mla_mtp_aic(sync, q_gm, ctkv_gm, mask_gm, block_tables_gm, 
q_rope_gm, ctkv_rope_gm, o_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm); mla_mtp_aic.Run(); #elif __DAV_C220_VEC__ @@ -4007,7 +4027,7 @@ extern "C" __global__ __aicore__ void paged_multi_latent_attention_multi_token_p #endif } else if (TILING_KEY_IS(29)) { #ifdef __DAV_C220_CUBE__ - PagedMLAMultiTokenPredictionAic<__bf16, __bf16, float, CalcMode::CALC_MODE_256, float, FuncType::UNEQUAL_QK_LEN_PREFILL> mla_mtp_aic(sync, q_gm, ctkv_gm, mask_gm, block_tables_gm, + PagedMLAMultiTokenPredictionAic<__bf16, __bf16, float, CalcMode::CALC_MODE_256, float, FuncType::UNEQUAL_QK_LEN_PREFILL> mla_mtp_aic(sync, q_gm, ctkv_gm, mask_gm, block_tables_gm, q_rope_gm, ctkv_rope_gm, o_gm, s_gm, p_gm, o_tmp_gm, upo_tmp_gm, tiling_para_gm); mla_mtp_aic.Run(); #elif __DAV_C220_VEC__ diff --git a/src/kernels/mixkernels/pagedattention/paged_attention_kernel.cpp b/src/kernels/mixkernels/pagedattention/paged_attention_kernel.cpp index e8d79a8d..bc36e4f3 100644 --- a/src/kernels/mixkernels/pagedattention/paged_attention_kernel.cpp +++ b/src/kernels/mixkernels/pagedattention/paged_attention_kernel.cpp @@ -187,11 +187,13 @@ public: bool CanSupport(const LaunchParam &launchParam) const override { - MKI_CHECK(launchParam.GetInTensorCount() == 4, "in tensor num invalid", return false); + MKI_CHECK(launchParam.GetInTensorCount() == 6, "in tensor num invalid", return false); MKI_CHECK(launchParam.GetOutTensorCount() == 1, "out tensor num invalid", return false); MKI_CHECK(launchParam.GetInTensor(0).desc.dims.size() == 3, "in tensor0 dims invalid", return false); MKI_CHECK(launchParam.GetInTensor(1).desc.dims.size() == 4, "in tensor1 dims invalid", return false); MKI_CHECK(launchParam.GetInTensor(2).desc.dims.size() == 2, "in tensor2 dims invalid", return false); + MKI_CHECK(launchParam.GetInTensor(4).desc.dims.size() == 3, "in tensor4 dims invalid", return false); + MKI_CHECK(launchParam.GetInTensor(5).desc.dims.size() == 4, "in tensor5 dims invalid", return false); MKI_CHECK(launchParam.GetParam().Type() == typeid(OpParam::PagedAttention), "paged attention: param type invalid", return false); auto param = AnyCast(launchParam.GetParam()); diff --git a/src/kernels/mixkernels/pagedattention/paged_attention_operation.cpp b/src/kernels/mixkernels/pagedattention/paged_attention_operation.cpp index 0a889ded..aaea4324 100644 --- a/src/kernels/mixkernels/pagedattention/paged_attention_operation.cpp +++ b/src/kernels/mixkernels/pagedattention/paged_attention_operation.cpp @@ -73,7 +73,7 @@ public: case OpParam::PagedAttention::PAGED_MULTI_LATENT_ATTENTION_COMBINE_CACHE_MASK_ND: return DIM_9; case OpParam::PagedAttention::PAGED_MULTI_LATENT_ATTENTION_MULTI_TOKEN_PREDICTION_MASK_ND: - return DIM_4; + return DIM_6; default: break; } @@ -182,8 +182,8 @@ private: } MKI_CHECK(embeddingDimQ > 0, "Shape of Input0 invalid, headSize must > 0", return false); if (param.type != OpParam::PagedAttention::PAGED_ATTENTION_MASK_ND) { - MKI_CHECK(CheckPagedMLAttetionCache(launchParam, param, embeddingDimQ), - "check cache shape fail", return false); + // MKI_CHECK(CheckPagedMLAttetionCache(launchParam, param, embeddingDimQ), + // "check cache shape fail", return false); } else { MKI_CHECK(CheckPagedAttentionCache(launchParam, param, embeddingDimQ), "check cache shape fail", return false); diff --git a/src/kernels/mixkernels/pagedattention/tiling/paged_attention_tiling.cpp b/src/kernels/mixkernels/pagedattention/tiling/paged_attention_tiling.cpp index 3008cc1f..a4525e36 100644 --- 
a/src/kernels/mixkernels/pagedattention/tiling/paged_attention_tiling.cpp +++ b/src/kernels/mixkernels/pagedattention/tiling/paged_attention_tiling.cpp @@ -234,6 +234,8 @@ Status GetPagedAttentionNdInfo(const LaunchParam &launchParam, PagedAttentionInf OP_TILING_CHECK_STATUS_RETURN(GetPagedAttentionMaskInfo(launchParam, mmInfo, param, isMLA)); } + mmInfo.embeddingSize = 576; + // Check tiling data MKI_LOG(INFO) << "numTokens is: " << mmInfo.numTokens << " numHeads is: " << mmInfo.numHeads << "batch is: " << mmInfo.batch << " embeddingSize is: " << mmInfo.embeddingSize diff --git a/tests/apitest/kernelstest/mix/test_paged_attention_mla_mtp.py b/tests/apitest/kernelstest/mix/test_paged_attention_mla_mtp.py index 72b0498b..835f19d1 100644 --- a/tests/apitest/kernelstest/mix/test_paged_attention_mla_mtp.py +++ b/tests/apitest/kernelstest/mix/test_paged_attention_mla_mtp.py @@ -25,6 +25,22 @@ np.random.seed(1) class TestPagedMLAttentionExtend(op_test.OpTest): + def shape_nd_to_nz(self, shape, dtype='float16'): + assert len(shape) >= 2 + batch = shape[:-2] + a, b = shape[-2], shape[-1] + a0, b0 = 16, 16 + return list(batch) + [math.ceil(b / b0), math.ceil(a / a0), a0, b0] + + def gen_axes_for_transpose(self, offset, base): + return [x for x in range(offset)] + [x + offset for x in base] + + def convert_nd_to_nz(self, x): + array_trans = self.gen_axes_for_transpose(len(x.shape) - 2, [2, 0, 1, 3]) + x_shape = self.shape_nd_to_nz(x.shape, dtype=x.dtype) + *_, n1, m1, m0, n0 = x_shape + return x.reshape(x_shape[:-4] + [m1, m0, n1, n0]).permute(*array_trans) + def compare_output_data(self, out, golden, ratios): error_count = 0 strict_error_count = 0 @@ -249,6 +265,16 @@ class TestPagedMLAttentionExtend(op_test.OpTest): mask_type ) + self.q_split1, self.q_split2 = torch.split(query, [512, 64], dim=2) + self.key_cache_split1, self.key_cache_split2 = torch.split(key_cache, [512, 64], dim=3) + + self.key_cache_split1 = self.key_cache_split1.reshape(num_blocks, block_size, -1) + self.key_cache_split2 = self.key_cache_split2.reshape(num_blocks, block_size, -1) + key_cache_split1_nz = self.convert_nd_to_nz(self.key_cache_split1) + key_cache_split2_nz = self.convert_nd_to_nz(self.key_cache_split2) + self.key_cache_split1 = key_cache_split1_nz.to(torch.bfloat16).reshape(num_blocks, block_size, 1, -1) + self.key_cache_split2 = key_cache_split2_nz.to(torch.bfloat16).reshape(num_blocks, block_size, 1, -1) + self.q = query self.num_tokens = num_tokens self.key_cache = key_cache @@ -629,5 +655,53 @@ class TestPagedMLAttentionExtend(op_test.OpTest): ] ) + @op_test.only_910b + def test_x1_8192(self): + batch = 1 + q_seqlen_list = [128] * batch + k_seqlen_list = [8192] * batch + num_heads = 128 + kv_heads = 1 + block_size = 128 + head_size_qk = 576 + head_size_vo = 512 + num_blocks = 320 + + tor = 1.0 / (head_size_qk ** 0.5) + mask_type = 3 + dtype = torch.bfloat16 + self.calc_data(num_heads, kv_heads, num_blocks, block_size, head_size_qk, head_size_vo, q_seqlen_list, k_seqlen_list, mask_type, dtype) + + OP_NAME = "PagedAttentionOperation" + OP_PARAM = {"type": 2006, "kvHead":kv_heads, "headSize":num_heads, "headDimV" :head_size_vo, "tor": tor, + "maskType": 3, "qSeqLen": q_seqlen_list, "kvSeqLen": k_seqlen_list} + self.set_param(OP_NAME, OP_PARAM) + self.set_input_formats([self.format_nd, self.format_nz, self.format_nd, self.format_nd, self.format_nd, self.format_nz]) + self.set_output_formats([self.format_nd]) + + logging.debug(f"q shape: {self.q.shape}") + logging.debug(f"k shape: {self.key_cache.shape}") + 
logging.debug(f"v shape: {self.value_cache.shape}") + logging.debug(f"block_tables shape: {self.block_tables}") + logging.debug(f"batch: {batch}, numHeads: {num_heads}, kvHead: {kv_heads}" + f", blockSize: {block_size}, numBlocks: {num_blocks}, headSizeQK: {head_size_qk}, headSizeVO: {head_size_vo}") + shape_out = ((self.num_tokens, num_heads, head_size_vo)) + attention_out = torch.zeros(shape_out, dtype=dtype) + attention_out[:] = 0.1 + + self.execute( + [ + self.q_split1, + self.key_cache_split1, + torch.tensor(self.block_tables).int(), + self.mask, + self.q_split2, + self.key_cache_split2 + ], + [ + attention_out + ] + ) + if __name__ == '__main__': unittest.main() -- Gitee