diff --git a/include/atb/infer_op_params.h b/include/atb/infer_op_params.h
index da14251d4ef01df5c5f67a165023f98ec3a8385d..7acda597ce708fb9d8223afeec6535a909c777e1 100644
--- a/include/atb/infer_op_params.h
+++ b/include/atb/infer_op_params.h
@@ -2898,6 +2898,7 @@ struct MultiLatentAttentionParam {
         UNDEFINED = 0,       //!< 默认值,全0的mask
         MASK_TYPE_SPEC,      //!< qseqlen > 1时的mask
         MASK_TYPE_MASK_FREE, //!< mask free
+        MASK_TYPE_CAUSAL_MASK, //!< mask generated internally
     };
     //!
     //! \brief mask类型
diff --git a/src/ops_infer/multi_latent_attention/multi_latent_attention_operation.cpp b/src/ops_infer/multi_latent_attention/multi_latent_attention_operation.cpp
index 27ebcb95fd660c5f0cec20d0c44797a4acbe5150..ae93e72e65f466e39ad8cd1dc90a230eacddbfde 100644
--- a/src/ops_infer/multi_latent_attention/multi_latent_attention_operation.cpp
+++ b/src/ops_infer/multi_latent_attention/multi_latent_attention_operation.cpp
@@ -113,7 +113,7 @@ static bool ParamRangeCheck(const infer::MultiLatentAttentionParam &opParam)
         return false;
     }
     if (opParam.maskType < infer::MultiLatentAttentionParam::MaskType::UNDEFINED ||
-        opParam.maskType > infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_MASK_FREE) {
+        opParam.maskType > infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_CAUSAL_MASK) {
         ATB_LOG(ERROR) << "invalid maskType";
         return false;
     }
@@ -135,12 +135,15 @@ static bool ParamPrefillCheck(const infer::MultiLatentAttentionParam &opParam)
         ATB_LOG(ERROR) << "headNum should be >= 1 and <= 128";
         return false;
     }
-    if (opParam.headNum != opParam.kvHeadNum) {
-        ATB_LOG(ERROR) << "Prefill, headNum should be equal to kvHeadNum";
-        return false;
+    if (opParam.maskType != infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_CAUSAL_MASK) {
+        if (opParam.headNum != opParam.kvHeadNum) {
+            ATB_LOG(ERROR) << "Prefill, headNum should be equal to kvHeadNum";
+            return false;
+        }
     }
     if (opParam.maskType != infer::MultiLatentAttentionParam::MaskType::UNDEFINED &&
-        opParam.maskType != infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_MASK_FREE) {
+        opParam.maskType != infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_MASK_FREE &&
+        opParam.maskType != infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_CAUSAL_MASK) {
         ATB_LOG(ERROR) << "Prefill, maskType support UNDEFINED and MASK_TYPE_MASK_FREE";
         return false;
     }
@@ -156,7 +159,8 @@ MultiLatentAttentionOperation::MultiLatentAttentionOperation(const infer::MultiL
 {
     std::string opIrKeyStr;
     opIrKeyStr += "MultiLatentAttentionOperation";
-    if (param_.maskType != infer::MultiLatentAttentionParam::MaskType::UNDEFINED) {
+    if (param_.maskType != infer::MultiLatentAttentionParam::MaskType::UNDEFINED &&
+        param_.maskType != infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_CAUSAL_MASK) {
         opIrKeyStr += "Mask";
     }
     if (param_.calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_SPEC || param_.
@@ -184,7 +188,8 @@ MultiLatentAttentionOperation::~MultiLatentAttentionOperation() {}
 uint32_t MultiLatentAttentionOperation::GetInputNum() const
 {
     uint32_t intensorNumBase = IN_TENSOR_NUM;
-    if (param_.maskType != infer::MultiLatentAttentionParam::MaskType::UNDEFINED) {
+    if (param_.maskType != infer::MultiLatentAttentionParam::MaskType::UNDEFINED &&
+        param_.maskType != infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_CAUSAL_MASK) {
         intensorNumBase++;
     }
     if (param_.calcType == infer::MultiLatentAttentionParam::CalcType::CALC_TYPE_PREFILL) {
diff --git a/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp b/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp
index 0c0e7aef8a678977340faa635ce8bae3ed221fe2..b50767133aa7161e00d6b257d1d788a3742fdd36 100644
--- a/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp
+++ b/src/ops_infer/multi_latent_attention/multi_latent_attention_ops_runner_prefill.cpp
@@ -66,9 +66,13 @@ Status MultiLatentAttentionOpsRunnerPrefill::SetupKernelGraph(const OpsTensorPac
     asdParam.tor = param_.qkScale;
     asdParam.kvHead = param_.kvHeadNum;
     asdParam.isRing = 0;
-    asdParam.maskType = param_.maskType == infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_MASK_FREE ?
-                            AtbOps::OpParam::MLA::MaskType::MASK_TYPE_CAUSAL_COMPRESS :
-                            AtbOps::OpParam::MLA::MaskType::MASK_TYPE_NONE;
+    if (param_.maskType == infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_MASK_FREE) {
+        asdParam.maskType = AtbOps::OpParam::MLA::MaskType::MASK_TYPE_MASK_FREE;
+    } else if (param_.maskType == infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_CAUSAL_MASK) {
+        asdParam.maskType = AtbOps::OpParam::MLA::MaskType::MASK_TYPE_CAUSAL_MASK;
+    } else {
+        asdParam.maskType = AtbOps::OpParam::MLA::MaskType::MASK_TYPE_NONE;
+    }
     mlaNode.opDesc = {0, "MLAOperation", asdParam};
     mlaNode.inTensors = {&query, &queryRope, &key, &keyRope, &value, &mask, &alibiCoeff,
@@ -85,7 +89,8 @@ Status MultiLatentAttentionOpsRunnerPrefill::ModifyKernelGraph(const OpsTensorPa
         ATB_LOG(ERROR) << GetLogPrefix() << " build param from host tensor fail";
         return ERROR_INVALID_PARAM;
     }
-    if (newParam_.contextLens != newParam_.qSeqlen) {
+    if (param_.maskType != infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_CAUSAL_MASK &&
+        newParam_.contextLens != newParam_.qSeqlen) {
         ATB_LOG(ERROR) << GetLogPrefix() << " qSeqLen and kvSeqLen should be same";
         return ERROR_INVALID_PARAM;
     }
@@ -98,7 +103,7 @@ Status MultiLatentAttentionOpsRunnerPrefill::ModifyKernelGraph(const OpsTensorPa
     asdParam.qSeqLen = newParam_.qSeqlen;
     asdParam.isRing = 0;
     asdParam.maskType = param_.maskType == infer::MultiLatentAttentionParam::MaskType::MASK_TYPE_MASK_FREE ?
-                            AtbOps::OpParam::MLA::MaskType::MASK_TYPE_CAUSAL_COMPRESS :
+                            AtbOps::OpParam::MLA::MaskType::MASK_TYPE_CAUSAL_MASK :
                             AtbOps::OpParam::MLA::MaskType::MASK_TYPE_NONE;
     mlaNode.opDesc = {0, "MLAOperation", asdParam};
     return NO_ERROR;
diff --git a/tests/apitest/opstest/python/operations/multi_latent_attention/test_multi_latent_attention_prefill.py b/tests/apitest/opstest/python/operations/multi_latent_attention/test_multi_latent_attention_prefill.py
index 8a51c7e26eb90b6f480fe0e55276f221be6ea7cf..d3faf0157e064dfb534707ae97ced1ec9af76d64 100644
--- a/tests/apitest/opstest/python/operations/multi_latent_attention/test_multi_latent_attention_prefill.py
+++ b/tests/apitest/opstest/python/operations/multi_latent_attention/test_multi_latent_attention_prefill.py
@@ -1033,5 +1033,179 @@ class TestMLAPrefill(operation_test.OperationTest):
             self.mask.to(data_type).npu()
         ])
 
+    def test_flash_attention_mla_fp16_causal_mask(self):
+        if not operation_test.get_soc_version() == 'Ascend910B':
+            print("this testcase only supports Ascend910B")
+            return
+        batch = 1
+        kv_head = 1 # kv_head num
+        isdecoder = 0 # prefill or decoder
+        heads = 1 # llama7b hidden_size 4096
+        embeddim = 192
+        embeddimV = 128
+        max_seq = 128
+        tor = 1.0 / math.sqrt(1.0 * embeddim)
+        dynamic_batch = False
+        kv_seqLen = [128] * batch
+        q_seqLen = kv_seqLen
+        is_clamp = 0
+        clamp_min = 0
+        clamp_max = 0
+        OP_NAME = "MultiLatentAttentionOperation"
+        OP_PARAM = {
+            "headNum": heads,
+            "qkScale": tor,
+            "kvHeadNum": kv_head,
+            "maskType": 3,
+            "calcType": 4,
+            "cacheMode": 1
+        }
+        RUN_PARAM = json.dumps({"qSeqlen": q_seqLen, "contextLens": kv_seqLen})
+        data_type = torch.float16
+
+        self.set_data_params(dynamic_batch = dynamic_batch,
+                             is_decoder = isdecoder, batch = batch, kv_head = kv_head, heads = heads,
+                             embeddim = embeddim, embeddimv = embeddimV, max_seq = max_seq, kv_seqLen = kv_seqLen,
+                             is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min,
+                             data_type = data_type, is_alibi = False, is_triu_mask = False,
+                             op_type = 1, mask_type = MASK_TYPE_NO_BATCH, tor = tor, long_seq = True)
+        self.gen_out_tensor()
+        self.mask = self.mask.view(512, 512).to(data_type)
+
+        logging.debug("**********input shape***********")
+        logging.info(f"q shape: {self.q.shape}")
+        logging.info(f"k shape: {self.k.shape}")
+        logging.info(f"v shape: {self.v.shape}")
+        logging.info(f"layer_id shape: {self.layer_id.shape}")
+        logging.info(f"mask shape: {self.mask.shape}")
+        logging.info(f"golden shape: {self.golden_out.shape}")
+
+        for i in range(1):
+            self.execute_with_param(OP_NAME, OP_PARAM, RUN_PARAM, [
+                self.q_split1.npu(),
+                self.q_split2.npu(),
+                self.k_split1[0].npu(),
+                self.k_split2[0].npu(),
+                self.v[0].npu(),
+                torch.tensor(q_seqLen, dtype=torch.int32).npu(),
+                torch.tensor(kv_seqLen, dtype=torch.int32).npu()
+            ])
+
+    def test_flash_attention_mla_bf16_batch_causal_mask(self):
+        if not operation_test.get_soc_version() == 'Ascend910B':
+            print("this testcase only supports Ascend910B")
+            return
+        batch = 2
+        kv_head = 1 # kv_head num
+        isdecoder = 0 # prefill or decoder
+        heads = 1 # llama7b hidden_size 4096
+        embeddim = 192
+        embeddimV = 128
+        max_seq = 128
+        tor = 1.0 / math.sqrt(1.0 * embeddim)
+        dynamic_batch = False
+        kv_seqLen = [128] * batch
+        q_seqLen = kv_seqLen
+        is_clamp = 0
+        clamp_min = 0
+        clamp_max = 0
+        OP_NAME = "MultiLatentAttentionOperation"
+        OP_PARAM = {
+            "headNum": heads,
+            "qkScale": tor,
+            "kvHeadNum": kv_head,
+            "maskType": 3,
+            "calcType": 4,
+            "cacheMode": 1
+        }
+        RUN_PARAM = json.dumps({"qSeqlen": q_seqLen, "contextLens": kv_seqLen})
"contextLens": kv_seqLen}) + data_type = torch.bfloat16 + + self.set_data_params(dynamic_batch = dynamic_batch, + is_decoder = isdecoder, batch = batch, kv_head = kv_head, heads = heads, + embeddim = embeddim,embeddimv = embeddimV, max_seq = max_seq, kv_seqLen = kv_seqLen, + is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, + data_type = data_type, is_alibi = False, is_triu_mask = False, + op_type = 1, mask_type = MASK_TYPE_NO_BATCH, tor = tor, long_seq = True) + self.gen_out_tensor() + self.mask = self.mask.view(512, 512).to(data_type) + + logging.debug("**********input shape***********") + logging.info(f"q shape: {self.q.shape}") + logging.info(f"k shape: {self.k.shape}") + logging.info(f"v shape: {self.v.shape}") + logging.info(f"layer_id shape: {self.layer_id.shape}") + logging.info(f"mask shape: {self.mask.shape}") + logging.info(f"golden shape: {self.golden_out.shape}") + + for i in range(1): + self.execute_with_param(OP_NAME, OP_PARAM, RUN_PARAM, [ + self.q_split1.npu(), + self.q_split2.npu(), + self.k_split1[0].npu(), + self.k_split2[0].npu(), + self.v[0].npu(), + torch.tensor(q_seqLen, dtype=torch.int32).npu(), + torch.tensor(kv_seqLen, dtype=torch.int32).npu() + ]) + + def test_flash_attention_mla_bf16_head_batch_causal_mask(self): + if not operation_test.get_soc_version() == 'Ascend910B': + print("this testcase only supports Ascend910B") + return + batch = 2 + kv_head = 1 # kv_head num + isdecoder = 0 # prefill or decoder + heads = 2 # llama7b hidden_size 4096 + embeddim = 192 + embeddimV = 128 + max_seq = 128 + tor = 1.0 / math.sqrt(1.0 * embeddim) + dynamic_batch = False + kv_seqLen = [128] * batch + q_seqLen = kv_seqLen + is_clamp = 0 + clamp_min = 0 + clamp_max = 0 + OP_NAME = "MultiLatentAttentionOperation" + OP_PARAM = { + "headNum": heads, + "qkScale": tor, + "kvHeadNum": kv_head, + "maskType": 3, + "calcType": 4, + "cacheMode": 1 + } + RUN_PARAM = json.dumps({"qSeqlen": q_seqLen, "contextLens": kv_seqLen}) + data_type = torch.bfloat16 + + self.set_data_params(dynamic_batch = dynamic_batch, + is_decoder = isdecoder, batch = batch, kv_head = kv_head, heads = heads, + embeddim = embeddim,embeddimv = embeddimV, max_seq = max_seq, kv_seqLen = kv_seqLen, + is_clamp = is_clamp, clamp_max = clamp_max, clamp_min = clamp_min, + data_type = data_type, is_alibi = False, is_triu_mask = False, + op_type = 1, mask_type = MASK_TYPE_NO_BATCH, tor = tor, long_seq = True) + self.gen_out_tensor() + self.mask = self.mask.view(512, 512).to(data_type) + + logging.debug("**********input shape***********") + logging.info(f"q shape: {self.q.shape}") + logging.info(f"k shape: {self.k.shape}") + logging.info(f"v shape: {self.v.shape}") + logging.info(f"layer_id shape: {self.layer_id.shape}") + logging.info(f"mask shape: {self.mask.shape}") + logging.info(f"golden shape: {self.golden_out.shape}") + + for i in range(1): + self.execute_with_param(OP_NAME, OP_PARAM, RUN_PARAM, [ + self.q_split1.npu(), + self.q_split2.npu(), + self.k_split1[0].npu(), + self.k_split2[0].npu(), + self.v[0].npu(), + torch.tensor(q_seqLen, dtype=torch.int32).npu(), + torch.tensor(kv_seqLen, dtype=torch.int32).npu() + ]) + if __name__ == '__main__': unittest.main() \ No newline at end of file