From bc2330ad5e6ebd71650b469aedb7c0f7ae30fff6 Mon Sep 17 00:00:00 2001 From: ivanshan_8170 Date: Mon, 25 Aug 2025 12:03:12 +0800 Subject: [PATCH] ringmla typo fix --- include/atb/infer_op_params.h | 7 ++++--- src/ops_infer/ring_mla/ring_mla_operation.cpp | 4 ++-- tests/cinterface/ring_mla_c_interface_test.cpp | 2 +- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/include/atb/infer_op_params.h b/include/atb/infer_op_params.h index 7896da47..579e3eff 100644 --- a/include/atb/infer_op_params.h +++ b/include/atb/infer_op_params.h @@ -3241,9 +3241,10 @@ struct RingMLAParam { //! \brief 计算类型 //! enum CalcType : int { - CALC_TYPE_DEFAULT = 0, // 默认,非首末卡场景,有prev_lse, prev_o传入,生成softmaxLse输出 - CALC_TYPE_FISRT_RING, // 首卡场景,无prev_lse, prev_o传入,生成softmaxLse输出 - CALC_TYPE_MAX + CALC_TYPE_DEFAULT = 0, // 默认,非首末卡场景,有prev_lse, prev_o传入,生成softmaxLse输出 + CALC_TYPE_FIRST_RING = 1, // 首卡场景,无prev_lse, prev_o传入,生成softmaxLse输出 + CALC_TYPE_FISRT_RING = 1, // 兼容 + CALC_TYPE_MAX = 2 }; //! //! \enum KernelType diff --git a/src/ops_infer/ring_mla/ring_mla_operation.cpp b/src/ops_infer/ring_mla/ring_mla_operation.cpp index ffbc3233..a23acdbc 100644 --- a/src/ops_infer/ring_mla/ring_mla_operation.cpp +++ b/src/ops_infer/ring_mla/ring_mla_operation.cpp @@ -53,8 +53,8 @@ namespace atb { bool ParamCheck(const infer::RingMLAParam &opParam, ExternalError &extError) { if (opParam.calcType != infer::RingMLAParam::CalcType::CALC_TYPE_DEFAULT && - opParam.calcType != infer::RingMLAParam::CalcType::CALC_TYPE_FISRT_RING) { - extError.errorDesc = "Ring MLA expects calcType to be one of CALC_TYPE_DEFAULT, CALC_TYPE_FISRT_RING."; + opParam.calcType != infer::RingMLAParam::CalcType::CALC_TYPE_FIRST_RING) { + extError.errorDesc = "Ring MLA expects calcType to be one of CALC_TYPE_DEFAULT, CALC_TYPE_FIRST_RING."; extError.errorData = OperationUtil::ConcatInfo("But got param.calcType = ", opParam.calcType); ATB_LOG(ERROR) << extError; return false; diff --git a/tests/cinterface/ring_mla_c_interface_test.cpp b/tests/cinterface/ring_mla_c_interface_test.cpp index e23db906..b34ef3cd 100644 --- a/tests/cinterface/ring_mla_c_interface_test.cpp +++ b/tests/cinterface/ring_mla_c_interface_test.cpp @@ -104,7 +104,7 @@ void TestRingMLA(const int64_t headNum, const int64_t kvHeadNum, const int64_t h uint64_t workspaceSize = 0; atb::Operation *op = nullptr; - if (calcType == 1) { // CALC_TYPE_FISRT_RING case: Don't take in prevOut, prevLse + if (calcType == 1) { // CALC_TYPE_FIRST_RING case: Don't take in prevOut, prevLse aclrtFreeHost(inoutDevice[PREF_OUT_INDEX]); // preOut aclrtFreeHost(inoutDevice[PREF_OUT_INDEX + 1]); // preLse inoutHost[PREF_OUT_INDEX] = nullptr; -- Gitee