diff --git a/include/atb/infer_op_params.h b/include/atb/infer_op_params.h index 635ad86c06e51b03a706f278c13295ec79ffb8e1..82d380801a9860f320864b690be2c5305288d21c 100644 --- a/include/atb/infer_op_params.h +++ b/include/atb/infer_op_params.h @@ -3233,9 +3233,10 @@ struct RingMLAParam { //! \brief 计算类型 //! enum CalcType : int { - CALC_TYPE_DEFAULT = 0, // 默认,非首末卡场景,有prev_lse, prev_o传入,生成softmaxLse输出 - CALC_TYPE_FISRT_RING, // 首卡场景,无prev_lse, prev_o传入,生成softmaxLse输出 - CALC_TYPE_MAX + CALC_TYPE_DEFAULT = 0, // 默认,非首末卡场景,有prev_lse, prev_o传入,生成softmaxLse输出 + CALC_TYPE_FIRST_RING = 1, // 首卡场景,无prev_lse, prev_o传入,生成softmaxLse输出 + CALC_TYPE_FISRT_RING = 1, // 兼容 + CALC_TYPE_MAX = 2 }; //! //! \enum KernelType diff --git a/src/ops_infer/ring_mla/ring_mla_operation.cpp b/src/ops_infer/ring_mla/ring_mla_operation.cpp index 31bec89e7e09286dd5ab09c9fb806f3ffbb6f9bd..3f0d1062d7fc31f29d8b21a49ac00534a144caf0 100644 --- a/src/ops_infer/ring_mla/ring_mla_operation.cpp +++ b/src/ops_infer/ring_mla/ring_mla_operation.cpp @@ -53,8 +53,8 @@ namespace atb { bool ParamCheck(const infer::RingMLAParam &opParam, ExternalError &extError) { if (opParam.calcType != infer::RingMLAParam::CalcType::CALC_TYPE_DEFAULT && - opParam.calcType != infer::RingMLAParam::CalcType::CALC_TYPE_FISRT_RING) { - extError.errorDesc = "Ring MLA expects calcType to be one of CALC_TYPE_DEFAULT, CALC_TYPE_FISRT_RING."; + opParam.calcType != infer::RingMLAParam::CalcType::CALC_TYPE_FIRST_RING) { + extError.errorDesc = "Ring MLA expects calcType to be one of CALC_TYPE_DEFAULT, CALC_TYPE_FIRST_RING."; extError.errorData = OperationUtil::ConcatInfo("But got param.calcType = ", opParam.calcType); ATB_LOG(ERROR) << extError; return false; diff --git a/tests/cinterface/ring_mla_c_interface_test.cpp b/tests/cinterface/ring_mla_c_interface_test.cpp index 70ae54c62dbbbc226c2dc2ec59a57fe2b46719df..2676eff33f0d46bf557407db8f61bd6d67d44378 100644 --- a/tests/cinterface/ring_mla_c_interface_test.cpp +++ b/tests/cinterface/ring_mla_c_interface_test.cpp @@ -103,7 +103,7 @@ void TestRingMLA(const int64_t headNum, const int64_t kvHeadNum, const int64_t h uint64_t workspaceSize = 0; atb::Operation *op = nullptr; - if (calcType == 1) { // CALC_TYPE_FISRT_RING case: Don't take in prevOut, prevLse + if (calcType == 1) { // CALC_TYPE_FIRST_RING case: Don't take in prevOut, prevLse aclrtFreeHost(inoutDevice[PREF_OUT_INDEX]); // preOut aclrtFreeHost(inoutDevice[PREF_OUT_INDEX + 1]); // preLse inoutHost[PREF_OUT_INDEX] = nullptr;