From b471b2527b922245b2f2d8ad2fd179b81af1b449 Mon Sep 17 00:00:00 2001
From: xujingjing
Date: Wed, 28 May 2025 13:55:00 +0100
Subject: [PATCH 1/5] atb: add MASK_TYPE_CAUSAL_MASK to MultiLatentAttentionParam

---
 include/atb/infer_op_params.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/include/atb/infer_op_params.h b/include/atb/infer_op_params.h
index da14251d..7acda597 100644
--- a/include/atb/infer_op_params.h
+++ b/include/atb/infer_op_params.h
@@ -2898,6 +2898,7 @@ struct MultiLatentAttentionParam {
         UNDEFINED = 0,         //!< Default value: an all-zero mask
         MASK_TYPE_SPEC,        //!< Mask used when qseqlen > 1
         MASK_TYPE_MASK_FREE,   //!< Mask-free
+        MASK_TYPE_CAUSAL_MASK, //!< Mask generated internally by the kernel
     };
     //!
     //! \brief Mask type
-- 
Gitee
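The new value shifts mask construction into the kernel: with MASK_TYPE_CAUSAL_MASK the operation generates the causal mask internally, so the caller no longer supplies a mask tensor. As a point of reference, the following minimal PyTorch sketch shows the attention semantics an internally generated causal mask implies, assuming the standard convention that query position i attends only to key positions <= i. It is an illustration, not the ATB kernel, and the function name is invented for this note.

    import torch

    def causal_attention_reference(q, k, v, qk_scale):
        # q, k: [seq, heads, d_qk]; v: [seq, heads, d_v]; single batch, heads == kv heads.
        scores = torch.einsum("qhd,khd->hqk", q.float(), k.float()) * qk_scale
        seq_len = q.shape[0]
        # The "internally generated" mask: forbid attending to future key positions.
        future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(future, float("-inf"))
        probs = torch.softmax(scores, dim=-1)
        return torch.einsum("hqk,khd->qhd", probs, v.float())

Passing an explicit upper-triangular -inf mask produces the same probabilities; the enum value simply removes the need to materialize and transfer that tensor.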
From 1555aa506734fbc393ef84396c6ff8cf6a4bf02a Mon Sep 17 00:00:00 2001
From: xujingjing
Date: Wed, 28 May 2025 13:57:09 +0100
Subject: [PATCH 2/5] atb: support MASK_TYPE_CAUSAL_MASK in MLA prefill

---
 ops_configs/atb_ops_info.ini | 1 +
 .../multi_latent_attention_operation.cpp | 21 +-
 ...ti_latent_attention_ops_runner_prefill.cpp | 3 +-
 .../test_multi_latent_attention_prefill.py | 176 ++++++++++++++++++
 4 files changed, 192 insertions(+), 9 deletions(-)

diff --git a/ops_configs/atb_ops_info.ini b/ops_configs/atb_ops_info.ini
index 8d662f22..aeed5bb3 100644
--- a/ops_configs/atb_ops_info.ini
+++ b/ops_configs/atb_ops_info.ini
@@ -3410,6 +3410,7 @@ input3.name=blockTables
 input3.dtype=int32,int32,int32,int32
 input3.format=nd,nd,nd,nd
 input4.name=mask
+input4.optional=true
 input4.dtype=float16,bf16,float16,bf16
 input4.format=nd,nd,nd,nd
 input5.name=seqLen
diff --git a/src/ops_infer/multi_latent_attention/multi_latent_attention_operation.cpp b/src/ops_infer/multi_latent_attention/multi_latent_attention_operation.cpp
index 27ebcb95..ae93e72e 100644
--- a/src/ops_infer/multi_latent_attention/multi_latent_attention_operation.cpp
+++ b/src/ops_infer/multi_latent_attention/multi_latent_attention_operation.cpp
@@ -113,7 +113,7 @@ static bool ParamRangeCheck(const infer::MultiLatentAttentionParam &opParam)
         return false;
     }
     if (opParam.maskType < infer::MultiLatentAttentionParam::MaskType::UNDEFINED ||
-        opParam.maskType > infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_MASK_FREE) {
+        opParam.maskType > infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_CAUSAL_MASK) {
         ATB_LOG(ERROR) << "invalid maskType";
         return false;
     }
@@ -135,12 +135,15 @@ static bool ParamPrefillCheck(const infer::MultiLatentAttentionParam &opParam)
         ATB_LOG(ERROR) << "headNum should be >= 1 and <= 128";
         return false;
     }
-    if (opParam.headNum != opParam.kvHeadNum) {
-        ATB_LOG(ERROR) << "Prefill, headNum should be equal to kvHeadNum";
-        return false;
+    if (opParam.maskType != infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_CAUSAL_MASK) {
+        if (opParam.headNum != opParam.kvHeadNum) {
+            ATB_LOG(ERROR) << "Prefill, headNum should be equal to kvHeadNum";
+            return false;
+        }
     }
     if (opParam.maskType != infer::MultiLatentAttentionParam::MaskType::UNDEFINED &&
-        opParam.maskType != infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_MASK_FREE) {
-        ATB_LOG(ERROR) << "Prefill, maskType support UNDEFINED and MASK_TYPE_MASK_FREE";
+        opParam.maskType != infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_MASK_FREE &&
+        opParam.maskType != infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_CAUSAL_MASK) {
+        ATB_LOG(ERROR) << "Prefill, maskType supports UNDEFINED, MASK_TYPE_MASK_FREE and MASK_TYPE_CAUSAL_MASK";
         return false;
     }
@@ -156,7 +159,8 @@ MultiLatentAttentionOperation::MultiLatentAttentionOperation(const infer::MultiL
 {
     std::string opIrKeyStr;
     opIrKeyStr += "MultiLatentAttentionOperation";
-    if (param_.maskType != infer::MultiLatentAttentionParam::MaskType::UNDEFINED) {
+    if (param_.maskType != infer::MultiLatentAttentionParam::MaskType::UNDEFINED &&
+        param_.maskType != infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_CAUSAL_MASK) {
         opIrKeyStr += "Mask";
     }
     if (param_.calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_SPEC || param_.
@@ -184,7 +188,8 @@ MultiLatentAttentionOperation::~MultiLatentAttentionOperation() {}
 uint32_t MultiLatentAttentionOperation::GetInputNum() const
 {
     uint32_t intensorNumBase = IN_TENSOR_NUM;
-    if (param_.maskType != infer::MultiLatentAttentionParam::MaskType::UNDEFINED) {
+    if (param_.maskType != infer::MultiLatentAttentionParam::MaskType::UNDEFINED &&
+        param_.maskType != infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_CAUSAL_MASK) {
         intensorNumBase++;
     }
     if (param_.calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_PREFILL) {
diff --git a/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp b/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp
index 0c0e7aef..98194815 100644
--- a/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp
+++ b/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp
@@ -85,7 +85,8 @@ Status MultiLatentAttentionOpsRunnerPrefill::ModifyKernelGraph(const OpsTensorPa
         ATB_LOG(ERROR) << GetLogPrefix() << " build param from host tensor fail";
         return ERROR_INVALID_PARAM;
     }
-    if (newParam_.contextLens != newParam_.qSeqlen) {
+    if (param_.maskType != infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_CAUSAL_MASK &&
+        newParam_.contextLens != newParam_.qSeqlen) {
         ATB_LOG(ERROR) << GetLogPrefix() << " qSeqLen and kvSeqLen should be same";
         return ERROR_INVALID_PARAM;
     }
diff --git a/tests/apitest/opstest/python/operations/multi_latent_attention/test_multi_latent_attention_prefill.py b/tests/apitest/opstest/python/operations/multi_latent_attention/test_multi_latent_attention_prefill.py
index 8a51c7e2..7095a1b6 100644
--- a/tests/apitest/opstest/python/operations/multi_latent_attention/test_multi_latent_attention_prefill.py
+++ b/tests/apitest/opstest/python/operations/multi_latent_attention/test_multi_latent_attention_prefill.py
@@ -793,6 +793,8 @@ class TestMLAPrefill(operation_test.OperationTest):
             result_double = compare_cv(self.golden_out_true.npu(), golden_tensors[0].npu(), out_tensors[0].npu())
             return (result_double or result_single)
         else:
+            print(f"{golden_tensors[0]=}")
+            print(f"{out_tensors[0]=}")
             result_double = compare_cv(self.golden_out_true.npu(), golden_tensors[0].npu(), out_tensors[0].npu())
             return (result_double or result_single)
 
@@ -1033,5 +1035,179 @@ class TestMLAPrefill(operation_test.OperationTest):
             self.mask.to(data_type).npu()
         ])
 
+    def test_flash_attention_mla_fp16_causal_mask(self):
+        if not operation_test.get_soc_version() == 'Ascend910B':
+            print("this testcase only supports Ascend910B")
+            return
+        batch = 2
+        kv_head = 1  # kv_head num
+        isdecoder = 0  # prefill or decoder
+        heads = 1
+        embeddim = 192
+        embeddimV = 128
+        max_seq = 128
+        tor = 1.0 / math.sqrt(1.0 * embeddim)
+        dynamic_batch = False
+        kv_seqLen = [128] * batch
+        q_seqLen = kv_seqLen
+        is_clamp = 0
+        clamp_min = 0
+        clamp_max = 0
+        OP_NAME = "MultiLatentAttentionOperation"
+        OP_PARAM = {
+            "headNum": heads,
+            "qkScale": tor,
+            "kvHeadNum": kv_head,
+            "maskType": 3,
+            "calcType": 4,
+            "cacheMode": 1
+        }
+        RUN_PARAM = json.dumps({"qSeqlen": q_seqLen, "contextLens": kv_seqLen})
+        data_type = torch.float16
+
+        self.set_data_params(dynamic_batch = dynamic_batch,
+                             is_decoder = isdecoder, batch = batch, kv_head = kv_head, heads = heads,
+                             embeddim = embeddim, embeddimv = embeddimV, max_seq = max_seq, kv_seqLen = kv_seqLen,
+                             is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min,
+                             data_type = data_type, is_alibi = False, is_triu_mask = False,
+                             op_type = 1, mask_type = MASK_TYPE_NO_BATCH, tor = tor, long_seq = True)
+        self.gen_out_tensor()
+        self.mask = self.mask.view(512, 512).to(data_type)
+
+        logging.debug("**********input shape***********")
+        logging.info(f"q shape: {self.q.shape}")
+        logging.info(f"k shape: {self.k.shape}")
+        logging.info(f"v shape: {self.v.shape}")
+        logging.info(f"layer_id shape: {self.layer_id.shape}")
+        logging.info(f"mask shape: {self.mask.shape}")
+        logging.info(f"golden shape: {self.golden_out.shape}")
+
+        for i in range(1):
+            self.execute_with_param(OP_NAME, OP_PARAM, RUN_PARAM, [
+                self.q_split1.npu(),
+                self.q_split2.npu(),
+                self.k_split1[0].npu(),
+                self.k_split2[0].npu(),
+                self.v[0].npu(),
+                torch.tensor(q_seqLen, dtype=torch.int32).npu(),
+                torch.tensor(kv_seqLen, dtype=torch.int32).npu()
+            ])
+
+    def test_flash_attention_mla_bf16_batch_causal_mask(self):
+        if not operation_test.get_soc_version() == 'Ascend910B':
+            print("this testcase only supports Ascend910B")
+            return
+        batch = 2
+        kv_head = 1  # kv_head num
+        isdecoder = 0  # prefill or decoder
+        heads = 1
+        embeddim = 192
+        embeddimV = 128
+        max_seq = 128
+        tor = 1.0 / math.sqrt(1.0 * embeddim)
+        dynamic_batch = False
+        kv_seqLen = [128] * batch
+        q_seqLen = kv_seqLen
+        is_clamp = 0
+        clamp_min = 0
+        clamp_max = 0
+        OP_NAME = "MultiLatentAttentionOperation"
+        OP_PARAM = {
+            "headNum": heads,
+            "qkScale": tor,
+            "kvHeadNum": kv_head,
+            "maskType": 3,
+            "calcType": 4,
+            "cacheMode": 1
+        }
+        RUN_PARAM = json.dumps({"qSeqlen": q_seqLen, "contextLens": kv_seqLen})
+        data_type = torch.bfloat16
+
+        self.set_data_params(dynamic_batch = dynamic_batch,
+                             is_decoder = isdecoder, batch = batch, kv_head = kv_head, heads = heads,
+                             embeddim = embeddim, embeddimv = embeddimV, max_seq = max_seq, kv_seqLen = kv_seqLen,
+                             is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min,
+                             data_type = data_type, is_alibi = False, is_triu_mask = False,
+                             op_type = 1, mask_type = MASK_TYPE_NO_BATCH, tor = tor, long_seq = True)
+        self.gen_out_tensor()
+        self.mask = self.mask.view(512, 512).to(data_type)
+
+        logging.debug("**********input shape***********")
+        logging.info(f"q shape: {self.q.shape}")
+        logging.info(f"k shape: {self.k.shape}")
+        logging.info(f"v shape: {self.v.shape}")
+        logging.info(f"layer_id shape: {self.layer_id.shape}")
+        logging.info(f"mask shape: {self.mask.shape}")
+        logging.info(f"golden shape: {self.golden_out.shape}")
+
+        for i in range(1):
+            self.execute_with_param(OP_NAME, OP_PARAM, RUN_PARAM, [
+                self.q_split1.npu(),
+                self.q_split2.npu(),
+                self.k_split1[0].npu(),
+                self.k_split2[0].npu(),
+                self.v[0].npu(),
+                torch.tensor(q_seqLen, dtype=torch.int32).npu(),
+                torch.tensor(kv_seqLen, dtype=torch.int32).npu()
+            ])
+
+    def test_flash_attention_mla_bf16_head_batch_causal_mask(self):
+        if not operation_test.get_soc_version() == 'Ascend910B':
+            print("this testcase only supports Ascend910B")
+            return
+        batch = 2
+        kv_head = 1  # kv_head num
+        isdecoder = 0  # prefill or decoder
+        heads = 2
+        embeddim = 192
+        embeddimV = 128
+        max_seq = 128
+        tor = 1.0 / math.sqrt(1.0 * embeddim)
+        dynamic_batch = False
+        kv_seqLen = [128] * batch
+        q_seqLen = kv_seqLen
+        is_clamp = 0
+        clamp_min = 0
+        clamp_max = 0
+        OP_NAME = "MultiLatentAttentionOperation"
+        OP_PARAM = {
+            "headNum": heads,
+            "qkScale": tor,
+            "kvHeadNum": kv_head,
+            "maskType": 3,
+            "calcType": 4,
+            "cacheMode": 1
+        }
+        RUN_PARAM = json.dumps({"qSeqlen": q_seqLen, "contextLens": kv_seqLen})
+        data_type = torch.bfloat16
+
+        self.set_data_params(dynamic_batch = dynamic_batch,
+                             is_decoder = isdecoder, batch = batch, kv_head = kv_head, heads = heads,
+                             embeddim = embeddim, embeddimv = embeddimV, max_seq = max_seq, kv_seqLen = kv_seqLen,
+                             is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min,
+                             data_type = data_type, is_alibi = False, is_triu_mask = False,
+                             op_type = 1, mask_type = MASK_TYPE_NO_BATCH, tor = tor, long_seq = True)
+        self.gen_out_tensor()
+        self.mask = self.mask.view(512, 512).to(data_type)
+
+        logging.debug("**********input shape***********")
+        logging.info(f"q shape: {self.q.shape}")
+        logging.info(f"k shape: {self.k.shape}")
+        logging.info(f"v shape: {self.v.shape}")
+        logging.info(f"layer_id shape: {self.layer_id.shape}")
+        logging.info(f"mask shape: {self.mask.shape}")
+        logging.info(f"golden shape: {self.golden_out.shape}")
+
+        for i in range(1):
+            self.execute_with_param(OP_NAME, OP_PARAM, RUN_PARAM, [
+                self.q_split1.npu(),
+                self.q_split2.npu(),
+                self.k_split1[0].npu(),
+                self.k_split2[0].npu(),
+                self.v[0].npu(),
+                torch.tensor(q_seqLen, dtype=torch.int32).npu(),
+                torch.tensor(kv_seqLen, dtype=torch.int32).npu()
+            ])
+
 if __name__ == '__main__':
     unittest.main()
\ No newline at end of file
-- 
Gitee
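A note on the integer fields in OP_PARAM above: the tests encode the enums by ordinal. From the declaration order in patch 1, MASK_TYPE_CAUSAL_MASK is 3, which is what "maskType": 3 selects; "calcType": 4 picks the prefill path in these tests, though the CalcType declaration itself is not part of this series, so that value is taken on faith from the test setup. A small sketch of the mapping (ours, for illustration):

    from enum import IntEnum

    class MlaMaskType(IntEnum):
        # Mirrors the MaskType declaration order from patch 1.
        UNDEFINED = 0              # default: all-zero mask
        MASK_TYPE_SPEC = 1         # mask for qseqlen > 1
        MASK_TYPE_MASK_FREE = 2    # mask-free
        MASK_TYPE_CAUSAL_MASK = 3  # mask generated internally

    OP_PARAM = {"headNum": 1, "qkScale": 1.0, "kvHeadNum": 1,
                "maskType": int(MlaMaskType.MASK_TYPE_CAUSAL_MASK),  # == 3
                "calcType": 4, "cacheMode": 1}

Note also that the causal-mask tests pass seven tensors and no mask, unlike the earlier masked test above: this matches the GetInputNum() change, which stops counting the mask input when maskType is MASK_TYPE_CAUSAL_MASK.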
From 1ff190e3cfd14ca011c76c21e5374d40880c89ff Mon Sep 17 00:00:00 2001
From: xujingjing
Date: Thu, 29 May 2025 08:18:36 +0100
Subject: [PATCH 3/5] atb: wire causal mask into the MLA prefill runner

---
 .../multi_latent_attention_ops_runner_prefill.cpp | 2 +-
 .../test_multi_latent_attention_prefill.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp b/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp
index 98194815..dd9fd4c5 100644
--- a/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp
+++ b/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp
@@ -68,7 +68,7 @@ Status MultiLatentAttentionOpsRunnerPrefill::SetupKernelGraph(const OpsTensorPac
     asdParam.isRing = 0;
     asdParam.maskType = param_.maskType == infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_MASK_FREE ?
                             AtbOps::OpParam::MLA::MaskType::MASK_TYPE_CAUSAL_COMPRESS :
-                            AtbOps::OpParam::MLA::MaskType::MASK_TYPE_NONE;
+                            AtbOps::OpParam::MLA::MaskType::MASK_TYPE_CASUAL_MASK;
 
     mlaNode.opDesc = {0, "MLAOperation", asdParam};
     mlaNode.inTensors = {&query, &queryRope, &key, &keyRope, &value, &mask, &alibiCoeff,
diff --git a/tests/apitest/opstest/python/operations/multi_latent_attention/test_multi_latent_attention_prefill.py b/tests/apitest/opstest/python/operations/multi_latent_attention/test_multi_latent_attention_prefill.py
index 7095a1b6..ce42bbc5 100644
--- a/tests/apitest/opstest/python/operations/multi_latent_attention/test_multi_latent_attention_prefill.py
+++ b/tests/apitest/opstest/python/operations/multi_latent_attention/test_multi_latent_attention_prefill.py
@@ -1039,7 +1039,7 @@ class TestMLAPrefill(operation_test.OperationTest):
         if not operation_test.get_soc_version() == 'Ascend910B':
             print("this testcase only supports Ascend910B")
             return
-        batch = 2
+        batch = 1
         kv_head = 1  # kv_head num
         isdecoder = 0  # prefill or decoder
         heads = 1
-- 
Gitee
From 20c90468122cbc75c4ff0c726f4e0f364a38de0d Mon Sep 17 00:00:00 2001
From: xujingjing
Date: Thu, 29 May 2025 09:28:28 +0100
Subject: [PATCH 4/5] atb: fix MASK_TYPE_CASUAL_MASK typo

---
 .../multi_latent_attention_ops_runner_prefill.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp b/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp
index dd9fd4c5..31b04d4b 100644
--- a/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp
+++ b/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp
@@ -68,7 +68,7 @@ Status MultiLatentAttentionOpsRunnerPrefill::SetupKernelGraph(const OpsTensorPac
     asdParam.isRing = 0;
     asdParam.maskType = param_.maskType == infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_MASK_FREE ?
                             AtbOps::OpParam::MLA::MaskType::MASK_TYPE_CAUSAL_COMPRESS :
-                            AtbOps::OpParam::MLA::MaskType::MASK_TYPE_CASUAL_MASK;
+                            AtbOps::OpParam::MLA::MaskType::MASK_TYPE_CAUSAL_MASK;
 
     mlaNode.opDesc = {0, "MLAOperation", asdParam};
     mlaNode.inTensors = {&query, &queryRope, &key, &keyRope, &value, &mask, &alibiCoeff,
-- 
Gitee
From 724f63d81f7074f78511b43ff504494ff76bf5e2 Mon Sep 17 00:00:00 2001
From: xujingjing
Date: Tue, 3 Jun 2025 15:11:45 +0100
Subject: [PATCH 5/5] atb: make the MLA prefill maskType mapping explicit

---
 ops_configs/atb_ops_info.ini | 1 -
 .../multi_latent_attention_ops_runner_prefill.cpp | 12 ++++++++----
 .../test_multi_latent_attention_prefill.py | 2 --
 3 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/ops_configs/atb_ops_info.ini b/ops_configs/atb_ops_info.ini
index aeed5bb3..8d662f22 100644
--- a/ops_configs/atb_ops_info.ini
+++ b/ops_configs/atb_ops_info.ini
@@ -3410,7 +3410,6 @@ input3.name=blockTables
 input3.dtype=int32,int32,int32,int32
 input3.format=nd,nd,nd,nd
 input4.name=mask
-input4.optional=true
 input4.dtype=float16,bf16,float16,bf16
 input4.format=nd,nd,nd,nd
 input5.name=seqLen
diff --git a/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp b/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp
index 31b04d4b..b5076713 100644
--- a/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp
+++ b/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp
@@ -66,9 +66,13 @@ Status MultiLatentAttentionOpsRunnerPrefill::SetupKernelGraph(const OpsTensorPac
     asdParam.tor = param_.qkScale;
     asdParam.kvHead = param_.kvHeadNum;
     asdParam.isRing = 0;
-    asdParam.maskType = param_.maskType == infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_MASK_FREE ?
-                            AtbOps::OpParam::MLA::MaskType::MASK_TYPE_CAUSAL_COMPRESS :
-                            AtbOps::OpParam::MLA::MaskType::MASK_TYPE_CAUSAL_MASK;
+    if (param_.maskType == infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_MASK_FREE) {
+        asdParam.maskType = AtbOps::OpParam::MLA::MaskType::MASK_TYPE_MASK_FREE;
+    } else if (param_.maskType == infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_CAUSAL_MASK) {
+        asdParam.maskType = AtbOps::OpParam::MLA::MaskType::MASK_TYPE_CAUSAL_MASK;
+    } else {
+        asdParam.maskType = AtbOps::OpParam::MLA::MaskType::MASK_TYPE_NONE;
+    }
 
     mlaNode.opDesc = {0, "MLAOperation", asdParam};
     mlaNode.inTensors = {&query, &queryRope, &key, &keyRope, &value, &mask, &alibiCoeff,
@@ -99,7 +103,7 @@ Status MultiLatentAttentionOpsRunnerPrefill::ModifyKernelGraph(const OpsTensorPa
     asdParam.qSeqLen = newParam_.qSeqlen;
     asdParam.isRing = 0;
     asdParam.maskType = param_.maskType == infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_MASK_FREE ?
-                            AtbOps::OpParam::MLA::MaskType::MASK_TYPE_CAUSAL_COMPRESS :
+                            AtbOps::OpParam::MLA::MaskType::MASK_TYPE_CAUSAL_MASK :
                             AtbOps::OpParam::MLA::MaskType::MASK_TYPE_NONE;
     mlaNode.opDesc = {0, "MLAOperation", asdParam};
     return NO_ERROR;
diff --git a/tests/apitest/opstest/python/operations/multi_latent_attention/test_multi_latent_attention_prefill.py b/tests/apitest/opstest/python/operations/multi_latent_attention/test_multi_latent_attention_prefill.py
index ce42bbc5..d3faf015 100644
--- a/tests/apitest/opstest/python/operations/multi_latent_attention/test_multi_latent_attention_prefill.py
+++ b/tests/apitest/opstest/python/operations/multi_latent_attention/test_multi_latent_attention_prefill.py
@@ -793,8 +793,6 @@ class TestMLAPrefill(operation_test.OperationTest):
             result_double = compare_cv(self.golden_out_true.npu(), golden_tensors[0].npu(), out_tensors[0].npu())
             return (result_double or result_single)
         else:
-            print(f"{golden_tensors[0]=}")
-            print(f"{out_tensors[0]=}")
             result_double = compare_cv(self.golden_out_true.npu(), golden_tensors[0].npu(), out_tensors[0].npu())
             return (result_double or result_single)
 
-- 
Gitee
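After the final patch, the prefill runner's mapping from the user-facing maskType to the kernel-side MLA mask type is explicit in SetupKernelGraph. A summary sketch (ours, not ATB source); MASK_TYPE_SPEC never reaches this mapping because ParamPrefillCheck rejects it for prefill:

    # User-facing MultiLatentAttentionParam::MaskType -> kernel AtbOps MLA MaskType,
    # as set in SetupKernelGraph after this series.
    PREFILL_KERNEL_MASK = {
        "UNDEFINED": "MASK_TYPE_NONE",                     # all-zero mask semantics
        "MASK_TYPE_MASK_FREE": "MASK_TYPE_MASK_FREE",
        "MASK_TYPE_CAUSAL_MASK": "MASK_TYPE_CAUSAL_MASK",  # kernel builds the mask; no mask tensor input
    }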