From 62d845fdf8d04bd284729229180a36cb814885d3 Mon Sep 17 00:00:00 2001
From: ivanshan_8170
Date: Thu, 3 Jul 2025 10:55:15 +0800
Subject: [PATCH 1/4] cleancode

---
 .../paged_cache_load_operation.cpp            | 15 ++--
 .../tiling/paged_cache_load_tiling.cpp        |  4 +-
 .../tiling/reshape_and_cache_tiling.cpp       |  6 +-
 .../ring_mla/ring_mla_operation.cpp           |  1 +
 .../ring_mla/tiling/ring_mla_tiling.cpp       | 68 ++++++++--------
 .../tiling/ring_mla_tiling_dependency.cpp     | 14 ++--
 ..._and_rope_and_reshape_and_cache_tiling.cpp |  1 +
 .../tiling/flash_attention_tiling.cpp         |  4 +-
 .../flash_attention_tiling_dependency.cpp     | 80 +++++++++++--------
 .../unpad_flash_attention_kernel.cpp          |  6 +-
 .../tiling/flash_attention_nz_tiling.cpp      |  2 +-
 .../flash_attention_nz_tiling_dependency.cpp  |  2 +-
 .../all_to_all/all_to_all_operation.cpp       |  9 ++-
 .../linear_parallel_operation.cpp             |  5 +-
 .../mm_deq_swiglu_quant_mm_deq_operation.cpp  |  2 +-
 src/ops_infer/relay_attention/param.cpp       |  2 +-
 src/ops_infer/ring_mla/ring_mla_operation.cpp |  4 +-
 17 files changed, 121 insertions(+), 104 deletions(-)

diff --git a/src/kernels/mixkernels/paged_cache_load/paged_cache_load_operation.cpp b/src/kernels/mixkernels/paged_cache_load/paged_cache_load_operation.cpp
index 77490911..61121be0 100644
--- a/src/kernels/mixkernels/paged_cache_load/paged_cache_load_operation.cpp
+++ b/src/kernels/mixkernels/paged_cache_load/paged_cache_load_operation.cpp
@@ -136,13 +136,14 @@ public:
         auto headSizeK = keyCache.desc.dims[DIM_3];
         auto headSizeV = valueCache.desc.dims[DIM_3];
         constexpr int32_t baiscBlockSize = 32;
-        MKI_CHECK(numHeads * headSizeK * GetTensorElementSize(inDtype) % baiscBlockSize == 0,
-            "dim_2*dim_3 of KCache should be 32B aligned", return Status::FailStatus(ERROR_INFERSHAPE_ERROR));
-        MKI_CHECK(numHeads * headSizeV * GetTensorElementSize(inDtype) % baiscBlockSize == 0,
-            "dim_2*dim_3 of VCache should be 32B aligned", return Status::FailStatus(ERROR_INFERSHAPE_ERROR));
-
-        MKI_CHECK(key.desc.dims.size() == DIM_3 && value.desc.dims.size() == DIM_3,
-            "K&V's dim num should be " << DIM_3, return Status::FailStatus(ERROR_INFERSHAPE_ERROR));
+        int32_t typeByte = static_cast<int32_t>(GetTensorElementSize(inDtype));
+        MKI_CHECK(numHeads * headSizeK * typeByte % baiscBlockSize == 0, "dim_2*dim_3 of KCache should be 32B aligned",
+            return Status::FailStatus(ERROR_INFERSHAPE_ERROR));
+        MKI_CHECK(numHeads * headSizeV * typeByte % baiscBlockSize == 0, "dim_2*dim_3 of VCache should be 32B aligned",
+            return Status::FailStatus(ERROR_INFERSHAPE_ERROR));
+
+        MKI_CHECK(key.desc.dims.size() == DIM_3 && value.desc.dims.size() == DIM_3, "K&V's dim num should be " << DIM_3,
+            return Status::FailStatus(ERROR_INFERSHAPE_ERROR));
         MKI_CHECK(key.desc.dtype == inDtype && value.desc.dtype == inDtype, "K&V's dtype must be same as KCache&VCache",
             return Status::FailStatus(ERROR_INFERSHAPE_ERROR));
         MKI_CHECK(key.desc.dims[DIM_0] == value.desc.dims[DIM_0],
diff --git a/src/kernels/mixkernels/paged_cache_load/tiling/paged_cache_load_tiling.cpp b/src/kernels/mixkernels/paged_cache_load/tiling/paged_cache_load_tiling.cpp
index 6b20b79d..ddfe9816 100644
--- a/src/kernels/mixkernels/paged_cache_load/tiling/paged_cache_load_tiling.cpp
+++ b/src/kernels/mixkernels/paged_cache_load/tiling/paged_cache_load_tiling.cpp
@@ -51,7 +51,7 @@ bool CommonPagedCacheLoadTiling(const LaunchParam &launchParam, KernelInfo &kern
     int32_t numblkTabCol = static_cast<int32_t>(blockTablesShape.at(DIM_1)); // block tables column
     TensorDType inDtype = launchParam.GetInTensor(DIM_0).desc.dtype;
-    uint32_t typeByte = static_cast<uint32_t>(GetTensorElementSize(inDtype));
+    uint32_t typeByte = GetTensorElementSize(inDtype);

     MKI_CHECK(blockSize > 0 && blockSize <= INT_MAX, "blockSize is invalid", return false);
     MKI_CHECK(numTokens > 0 && numTokens <= INT_MAX, "numTokens is invalid", return false);
@@ -68,7 +68,7 @@ bool CommonPagedCacheLoadTiling(const LaunchParam &launchParam, KernelInfo &kern
     tilingDataPtr->numblkTabCol = numblkTabCol;
     tilingDataPtr->tokenSizeK = tokenSizeK;
     tilingDataPtr->tokenSizeV = tokenSizeV;
-    tilingDataPtr->typeByte = typeByte;
+    tilingDataPtr->typeByte = static_cast(typeByte);

    tilingDataPtr->hasSeqStarts = param.hasSeqStarts;
    tilingDataPtr->cuSeqLens = param.cuSeqLens;
diff --git a/src/kernels/mixkernels/reshape_and_cache/tiling/reshape_and_cache_tiling.cpp b/src/kernels/mixkernels/reshape_and_cache/tiling/reshape_and_cache_tiling.cpp
index 1893ccba..e649de82 100644
--- a/src/kernels/mixkernels/reshape_and_cache/tiling/reshape_and_cache_tiling.cpp
+++ b/src/kernels/mixkernels/reshape_and_cache/tiling/reshape_and_cache_tiling.cpp
@@ -61,10 +61,10 @@ Status ReshapeAndCacheTilingNd(const LaunchParam &launchParam, KernelInfo &kerne
     auto &offsetV = value.desc.offset;
     tilingDataPtr->headSizeV = static_cast(headSizeV);
-    tilingDataPtr->offsetK = static_cast(offsetK);
-    tilingDataPtr->offsetV = static_cast(offsetV);
+    tilingDataPtr->offsetK = static_cast(offsetK);
+    tilingDataPtr->offsetV = static_cast(offsetV);
     for (uint32_t i = 0; i < SHAPE_DIMS; ++i) {
-        tilingDataPtr->stride[i] = static_cast(key.desc.strides[i]);
+        tilingDataPtr->stride[i] = static_cast(key.desc.strides[i]);
     }
     blockDim = tilingDataPtr->numTokens < blockDim ? tilingDataPtr->numTokens : blockDim;
     bool isAlign = ((tilingDataPtr->numHeads * tilingDataPtr->headSizeK * tilingDataPtr->typeByte) % ALIGN == 0
diff --git a/src/kernels/mixkernels/ring_mla/ring_mla_operation.cpp b/src/kernels/mixkernels/ring_mla/ring_mla_operation.cpp
index 9979f12c..7106429b 100644
--- a/src/kernels/mixkernels/ring_mla/ring_mla_operation.cpp
+++ b/src/kernels/mixkernels/ring_mla/ring_mla_operation.cpp
@@ -284,6 +284,7 @@ private:

     bool CheckMLA(const LaunchParam &launchParam) const
     {
+        (void)launchParam;
         return true;
     }
 };
diff --git a/src/kernels/mixkernels/ring_mla/tiling/ring_mla_tiling.cpp b/src/kernels/mixkernels/ring_mla/tiling/ring_mla_tiling.cpp
index 7178dedf..7c128795 100644
--- a/src/kernels/mixkernels/ring_mla/tiling/ring_mla_tiling.cpp
+++ b/src/kernels/mixkernels/ring_mla/tiling/ring_mla_tiling.cpp
@@ -1,12 +1,12 @@
 /*
-* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
-* This file is a part of the CANN Open Software.
-* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
-* Please refer to the License for details. You may not use this file except in compliance with the License.
-* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
-* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
-* See LICENSE in the root of the software repository for the full text of the License.
-*/
+ * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
 #include
 #include
 #include "mki/utils/assert/assert.h"
@@ -29,16 +29,15 @@
 const int32_t NUM512 = 512;
 const int32_t NUM576 = 576;
 const float SPLITKV_RATION = 0.8;
-Status GetMLANdInfo(const LaunchParam &launchParam, RINGMLAInfo &mmInfo,
-    OpParam::RINGMLA &param)
+Status GetMLANdInfo(const LaunchParam &launchParam, RINGMLAInfo &mmInfo, OpParam::RINGMLA &param)
 {
     auto kcacheShape = launchParam.GetInTensor(DIM_2).desc.dims;
     auto KDims = kcacheShape.size();
     auto tableShape = launchParam.GetInTensor(DIM_4).desc.dims;
     mmInfo.kNz = (kcacheShape.at(KDims - 1) == NUM16 || kcacheShape.at(KDims - 1) == NUM32) ? 1 : 0;
     if (mmInfo.kNz) {
-        mmInfo.embeddingSize = static_cast(kcacheShape.at(DIM_3)) *
-            static_cast(kcacheShape.at(DIM_1));
+        mmInfo.embeddingSize =
+            static_cast(kcacheShape.at(DIM_3)) * static_cast(kcacheShape.at(DIM_1));
         mmInfo.blockSize = static_cast(kcacheShape.at(DIM_2));
     } else {
         mmInfo.embeddingSize = static_cast(kcacheShape.at(DIM_3));
@@ -92,8 +91,9 @@ Status GetTilingKeyTypeBase(RINGMLAInfo &mmInfo, const Tensor &qTensor, const Te

 Status GenTilingKey(RINGMLAInfo &mmInfo, KernelInfo &kernelInfo, OpParam::RINGMLA &param)
 {
-    uint32_t dataType = static_cast(mmInfo.type);
-    uint32_t tilingKey = dataType + (mmInfo.kNz << NUM4) + (mmInfo.mtpTp1Flag << NUM2) + (param.isRing << NUM5);
+    uint32_t dataType = static_cast(mmInfo.type);
+    uint32_t tilingKey =
+        dataType + (mmInfo.kNz << NUM4) + (mmInfo.mtpTp1Flag << NUM2) + (static_cast(param.isRing) << NUM5);
     kernelInfo.SetTilingId(tilingKey);
     MKI_LOG(INFO) << "TILING KEY IS = " << tilingKey;
     return Status::OkStatus();
@@ -107,7 +107,7 @@ Status RINGMLATiling(const LaunchParam &launchParam, KernelInfo &kernelInfo)

     RINGMLAInfo mmInfo;
     GetTilingKeyTypeBase(mmInfo, qTensor, qRopeTensor);
-    Status ret1 = GetRINGMLAInfo(launchParam, mmInfo, param);
+    Status ret1 = GetRINGMLAInfo(launchParam, mmInfo, param);
     uint32_t blockDim = PlatformInfo::Instance().GetCoreNum(CoreType::CORE_TYPE_CUBE);
     uint32_t *tilingParam = reinterpret_cast<uint32_t *>(kernelInfo.GetTilingHostAddr());
     uint64_t tilingSize = kernelInfo.GetTilingSize();
@@ -124,15 +124,14 @@ Status RINGMLATiling(const LaunchParam &launchParam, KernelInfo &kernelInfo)
         uint64_t sWorkSpaceSize = mmInfo.mtpTp1Flag ? basicWorkSpaceFloat * 2 : basicWorkSpaceFloat;
         uint64_t pWorkSpaceSize = basicWorkSpaceInt8;
         uint64_t oTempWorkSpcaceSize = basicWorkSpaceInt8 * 2;
-        kernelInfo.GetScratchSizes() = {sWorkSpaceSize, sWorkSpaceSize, pWorkSpaceSize,
-            oTempWorkSpcaceSize, basicWorkSpaceFloat};
+        kernelInfo.GetScratchSizes() = {sWorkSpaceSize, sWorkSpaceSize, pWorkSpaceSize, oTempWorkSpcaceSize,
+            basicWorkSpaceFloat};
     } else {
         uint64_t sWorkSpaceSize = mmInfo.mtpTp1Flag ? basicWorkSpaceFloat * 4 : basicWorkSpaceFloat * 2;
         uint64_t pWorkSpaceSize = mmInfo.mtpTp1Flag ? basicWorkSpaceHalf * 4 : basicWorkSpaceHalf * 2;
         uint64_t oTempWorkSpcaceSize = mmInfo.mtpTp1Flag ? basicWorkSpaceFloat * 4 : basicWorkSpaceFloat * 2;
         uint64_t goWorkSpcaceSize = mmInfo.mtpTp1Flag ? basicWorkSpaceFloat * 2 : basicWorkSpaceFloat;
-        kernelInfo.GetScratchSizes() = {sWorkSpaceSize, NUM512, pWorkSpaceSize, oTempWorkSpcaceSize,
-            goWorkSpcaceSize};
+        kernelInfo.GetScratchSizes() = {sWorkSpaceSize, NUM512, pWorkSpaceSize, oTempWorkSpcaceSize, goWorkSpcaceSize};
     }
     Status ret2 = GenTilingKey(mmInfo, kernelInfo, param);
     OP_TILING_CHECK_STATUS_RETURN(ret2);
@@ -158,20 +157,20 @@ inline Status PrefillPreCheck(OpParam::RINGMLA &param)
     return Status::OkStatus();
 }

-Status GetAlibiMaskInfo(RINGMLAInfo &mmInfo, OpParam::RINGMLA &param, const Tensor &tensorMask,
-    int32_t maxSeq)
+Status GetAlibiMaskInfo(RINGMLAInfo &mmInfo, OpParam::RINGMLA &param, const Tensor &tensorMask, int32_t maxSeq)
 {
+    UNUSED_VALUE(param);
     auto maskShape = tensorMask.desc.dims;
     auto maskDim = maskShape.size();
     MKI_CHECK(maskDim <= DIM_4 && maskDim >= DIM_2, "maskdim invalid", return Status::FailStatus(ERROR_INVALID_VALUE));
     if (maskDim == DIM_3) { // [h, ms, ms]
         mmInfo.maskStride = 0;
-        mmInfo.headStride = static_cast(maxSeq);
+        mmInfo.headStride = maxSeq;
     } else if (maskDim == DIM_4) { // [bs,1,ms,ms] [bs,headnum,ms,ms]
         MKI_CHECK(maskShape.at(DIM_2) * maskShape.at(1) <= UINT32_MAX, "maxSeq * headnum can not large than UINT32_MAX",
             return Status::FailStatus(ERROR_INVALID_VALUE));
         mmInfo.maskStride = maskShape.at(1) * maskShape.at(DIM_2);
-        mmInfo.headStride = static_cast(maxSeq);
+        mmInfo.headStride = maxSeq;
     } else if (maskDim == DIM_2) {
         MKI_CHECK(maxSeq == LONG_SEQ_ALIBI_LEN, "alibi mask shape must be [256, 256] for long seq opt",
             return Status::FailStatus(ERROR_INVALID_VALUE));
@@ -181,8 +180,7 @@ Status GetAlibiMaskInfo(RINGMLAInfo &mmInfo, OpParam::RINGMLA &param, const Tens
     return Status::OkStatus();
 }

-inline Status GetPrefiillMaskInfo(RINGMLAInfo &mmInfo, OpParam::RINGMLA &param,
-    const Tensor &tensorMask)
+inline Status GetPrefiillMaskInfo(RINGMLAInfo &mmInfo, OpParam::RINGMLA &param, const Tensor &tensorMask)
 {
     auto &tensorAlibiCoeff = mmInfo.tensors.alibiCoeff;
     auto maskShape = tensorMask.desc.dims;
@@ -232,22 +230,22 @@ inline Status MLAPrefillPostCheck(RINGMLAInfo &mmInfo, OpParam::RINGMLA &param)
     MKI_CHECK(mmInfo.isTriuMask == 0 || mmInfo.isTriuMask == 1, "param isTriuMask is invalid",
         return Status::FailStatus(ERROR_INVALID_VALUE));
     MKI_CHECK(mmInfo.embeddingSize <= MAX_EMBEDDING && mmInfo.embeddingSize > 0, "headdimQ is invalid",
-              return Status::FailStatus(ERROR_INVALID_VALUE));
+        return Status::FailStatus(ERROR_INVALID_VALUE));
     MKI_CHECK(mmInfo.embeddingSizeV <= MAX_EMBEDDING && mmInfo.embeddingSizeV > 0, "headdimV is invalid",
-              return Status::FailStatus(ERROR_INVALID_VALUE));
+        return Status::FailStatus(ERROR_INVALID_VALUE));
     return Status::OkStatus();
 }

 inline void PrefillLog(RINGMLAInfo &mmInfo, OpParam::RINGMLA &param)
 {
+    UNUSED_VALUE(param);
     MKI_LOG(INFO) << "batch is: " << mmInfo.batch << " kvMaxSeq is: " << mmInfo.maxKvSeqLen
                   << " maskMaxSeq is: " << mmInfo.maxSeqLen << " head is: " << mmInfo.numHeads
                   << " qSeq is: " << *(mmInfo.qSeqLen) << " kvSeq is: " << *(mmInfo.kvSeqLen)
-                  << " tor is : " << mmInfo.tor
-                  << " kv_head is: " << mmInfo.kvHeads << " embed is: " << mmInfo.embeddingSize
-                  << " maskStrid is: " << mmInfo.maskStride << " headStride is: " << mmInfo.headStride << " isTriuMask "
-                  << mmInfo.isTriuMask << " mask type is " << mmInfo.maskType << " longseq is " << mmInfo.isLongSeq
-                  << "isRing is" << mmInfo.isRing;
+                  << " tor is : " << mmInfo.tor << " kv_head is: " << mmInfo.kvHeads
+                  << " embed is: " << mmInfo.embeddingSize << " maskStrid is: " << mmInfo.maskStride
+                  << " headStride is: " << mmInfo.headStride << " isTriuMask " << mmInfo.isTriuMask << " mask type is "
+                  << mmInfo.maskType << " longseq is " << mmInfo.isLongSeq << "isRing is" << mmInfo.isRing;
 }

 void MLAPrefillFillInfo(RINGMLAInfo &mmInfo, OpParam::RINGMLA &param, size_t batch, int32_t embed)
@@ -263,9 +261,9 @@ void MLAPrefillFillInfo(RINGMLAInfo &mmInfo, OpParam::RINGMLA &param, size_t bat
     mmInfo.isTriuMask = param.isTriuMask;
     mmInfo.kTensorList = param.kTensorList;
     mmInfo.vTensorList = param.vTensorList;
-    mmInfo.maskType = static_cast(param.maskType);
+    mmInfo.maskType = static_cast(param.maskType);
     mmInfo.quantType = static_cast(param.quantType);
-    mmInfo.isRing = param.isRing;
+    mmInfo.isRing = static_cast(param.isRing);
 }

 Status InitInfo(RINGMLAInfo &mmInfo, OpParam::RINGMLA &param)
@@ -318,7 +316,7 @@ Status RINGMLAPrefillTiling(const LaunchParam &launchParam, KernelInfo &kernelIn
     OP_TILING_CHECK_STATUS_RETURN(ret);
     uint32_t blockDim = PlatformInfo::Instance().GetCoreNum(CoreType::CORE_TYPE_CUBE);
     MKI_CHECK(blockDim > 0, "blockDim cannot <= 0", return Status::FailStatus(ERROR_INVALID_VALUE));
-    mmInfo.blockDim = blockDim;
+    mmInfo.blockDim = static_cast(blockDim);
     uint8_t *tilingHost = kernelInfo.GetTilingHostAddr();
     uint64_t tilingSize = kernelInfo.GetTilingSize();
     uint32_t *tilingParam = reinterpret_cast<uint32_t *>(tilingHost);
diff --git a/src/kernels/mixkernels/ring_mla/tiling/ring_mla_tiling_dependency.cpp b/src/kernels/mixkernels/ring_mla/tiling/ring_mla_tiling_dependency.cpp
index 1c0e6dfd..4ff14008 100644
--- a/src/kernels/mixkernels/ring_mla/tiling/ring_mla_tiling_dependency.cpp
+++ b/src/kernels/mixkernels/ring_mla/tiling/ring_mla_tiling_dependency.cpp
@@ -167,9 +167,9 @@ int32_t GetMaxKVseqlen(const OpParam::RINGMLA &param)
 Status GetNdMLATiling(const RINGMLAInfo &mmInfo, uint32_t &blockDim, uint32_t *tilingParam,
     const OpParam::RINGMLA &param)
 {
+    UNUSED_VALUE(blockDim);
     AddrOffsets addrOffsets {};
-    auto qSeqLen = param.qSeqLen;
     int32_t maxQseqlen = GetMaxQseqlen(param);
     MKI_CHECK(maxQseqlen > 0, "qSeqlen max value invalid, please check",
         return AtbOps::Status::FailStatus(ERROR_INFERSHAPE_ERROR, "OpParam is invalid"));
@@ -208,6 +208,7 @@ Status GetNdMLATiling(const RINGMLAInfo &mmInfo, uint32_t &blockDim, uint32_t *t
 void GetNdMLAMtpTilingTP1(const RINGMLAInfo &mmInfo, uint32_t &blockDim, uint32_t *tilingParam,
     const OpParam::RINGMLA &param)
 {
+    UNUSED_VALUE(param);
     bool isFP16 = static_cast(mmInfo.type) < NUM2;
     int32_t maxQPerJob = isFP16 ? NUM2 : NUM1;
     int32_t prevTaskNum = 0;
@@ -224,10 +225,10 @@ void GetNdMLAMtpTilingTP1(const RINGMLAInfo &mmInfo, uint32_t &blockDim, uint32_
             int32_t tilingOffset = TILING_HEAD_SIZE + TILING_PARA_SIZE_TP1 * prevTaskNum;
             int32_t curQLen = ((qSeqLen - qSeq) > maxQPerJob) ? maxQPerJob : (qSeqLen - qSeq);
             curKvSeq += curQLen;
-            tilingParam[tilingOffset] = seqIdx;
-            tilingParam[tilingOffset + NUM1] = qRowIdx;
-            tilingParam[tilingOffset + NUM2] = curKvSeq;
-            tilingParam[tilingOffset + NUM3] = curQLen;
+            tilingParam[tilingOffset] = static_cast<uint32_t>(seqIdx);
+            tilingParam[tilingOffset + NUM1] = static_cast<uint32_t>(qRowIdx);
+            tilingParam[tilingOffset + NUM2] = static_cast<uint32_t>(curKvSeq);
+            tilingParam[tilingOffset + NUM3] = static_cast<uint32_t>(curQLen);
             prevTaskNum++;
             qRowIdx += curQLen;
             totalTaskNum -= curQLen;
@@ -243,6 +244,7 @@ void GetNdMLAMtpTilingTP1(const RINGMLAInfo &mmInfo, uint32_t &blockDim, uint32_
 void GetTilingHead(const RINGMLAInfo &mmInfo, const OpParam::RINGMLA &param, uint32_t *tilingParam,
     const uint32_t *torPtr)
 {
+    UNUSED_VALUE(param);
     tilingParam[TILING_BATCH] = static_cast(mmInfo.batch);
     tilingParam[TILING_HEADSIZE] = static_cast(TILING_HEAD_SIZE);
     tilingParam[TILING_PARASIZE] = mmInfo.mtpTp1Flag ? static_cast(TILING_PARA_SIZE_TP1) :
@@ -295,6 +297,7 @@ Status GetRINGMLATilingParam(const LaunchParam &launchParam, const RINGMLAInfo &
 void InitTilingKWithN(const RINGMLAInfo &mmInfo, int32_t &embeddingSizeAligned, int32_t &nIbd, int32_t seqIdx,
     uint32_t *tilingParam)
 {
+    UNUSED_VALUE(embeddingSizeAligned);
     int32_t kvSeqlen = *(mmInfo.kvSeqLen + seqIdx);
     int32_t kvSeqlenAligned = (kvSeqlen + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE;
     int32_t nUbd = std::min(LONG_SEQ_LEN, kvSeqlenAligned);
@@ -307,6 +310,7 @@ Status CheckSeqlen(const RINGMLAInfo &mmInfo, const AddrOffsets &addrOffsets, in
     int32_t kvSeqlen, int32_t seqIdx)
 {
     UNUSED_VALUE(seqIdx);
+    UNUSED_VALUE(mmInfo);
     MKI_CHECK(qSeqlen >= 0, "qSeqlen is invalid", return Status::FailStatus(ERROR_INVALID_VALUE));
     MKI_CHECK(kvSeqlen >= 0, "kvSeqlen is invalid", return Status::FailStatus(ERROR_INVALID_VALUE));
     MKI_CHECK(addrOffsets.totalQBlkNum >= 0, "totalQBlkNum overflow",
diff --git a/src/kernels/mixkernels/rms_norm_and_rope_and_reshape_and_cache/tiling/rms_norm_and_rope_and_reshape_and_cache_tiling.cpp b/src/kernels/mixkernels/rms_norm_and_rope_and_reshape_and_cache/tiling/rms_norm_and_rope_and_reshape_and_cache_tiling.cpp
index 54b2357c..b03c9f85 100644
--- a/src/kernels/mixkernels/rms_norm_and_rope_and_reshape_and_cache/tiling/rms_norm_and_rope_and_reshape_and_cache_tiling.cpp
+++ b/src/kernels/mixkernels/rms_norm_and_rope_and_reshape_and_cache/tiling/rms_norm_and_rope_and_reshape_and_cache_tiling.cpp
@@ -48,6 +48,7 @@ T CeilDiv(T x, T y)
 }
 namespace AtbOps {
 Status TilingKeyChose(const LaunchParam &launchParam, KernelInfo &kernelInfo,
     RmsNormAndRopeAndReshapeAndCacheTilingData *tilingDataPtr)
 {
+    (void)tilingDataPtr;
     auto platformType = PlatformInfo::Instance().GetPlatformType();
     if (platformType != PlatformType::ASCEND_910B) {
         MKI_LOG(ERROR) << "BF16 only supports 800I A2";
diff --git a/src/kernels/mixkernels/unpad_flash_attention/tiling/flash_attention_tiling.cpp b/src/kernels/mixkernels/unpad_flash_attention/tiling/flash_attention_tiling.cpp
index 9522ee80..c862ae2e 100644
--- a/src/kernels/mixkernels/unpad_flash_attention/tiling/flash_attention_tiling.cpp
+++ b/src/kernels/mixkernels/unpad_flash_attention/tiling/flash_attention_tiling.cpp
@@ -136,7 +136,7 @@ Status FlashAttentionTiling(const LaunchParam &launchParam, KernelInfo &kernelIn
         uint64_t glWorkSize = static_cast(mmInfo.batchSize) * static_cast(mmInfo.innerBatchSize) * 4 * 2;
         uint64_t goWorkSize = static_cast(mmInfo.batchSize) * static_cast(mmInfo.innerBatchSize) *
-            mmInfo.embeddingSize * 4 * 2;
+            static_cast(mmInfo.embeddingSize) * 4 * 2;
         kernelInfo.GetScratchSizes() = {workSize, workSize, workSize, goWorkSize, glWorkSize};
         blockDim = PlatformInfo::Instance().GetCoreNum(CoreType::CORE_TYPE_CUBE);
     } else {
@@ -629,7 +629,7 @@ Status InitMaxKVSeqlen(size_t &batch, Mki::SVector &kcacheShape, UnpadF
         mmInfo.isNoCache = true;
     } else if (param.type == OpParam::UnpadFlashAttention::UNPAD_FLASH_ATTENTION_ENCODER_PREFIX_CACHE_ND) {
         mmInfo.isNoCache = true;
-        mmInfo.maxKvSeqLen = kcacheShape.at(0) * kcacheShape.at(0) / batch;
+        mmInfo.maxKvSeqLen = static_cast(kcacheShape.at(0) * kcacheShape.at(0)) / static_cast(batch);
         OP_TILING_CHECK_STATUS_RETURN(GetFlashAttentionNoCacheMaskInfo(mmInfo, param, mmInfo.tensors.mask));
     } else {
         if (param.dataShapeType == 1) {
diff --git a/src/kernels/mixkernels/unpad_flash_attention/tiling/flash_attention_tiling_dependency.cpp b/src/kernels/mixkernels/unpad_flash_attention/tiling/flash_attention_tiling_dependency.cpp
index d25b5b21..5372ddc7 100644
--- a/src/kernels/mixkernels/unpad_flash_attention/tiling/flash_attention_tiling_dependency.cpp
+++ b/src/kernels/mixkernels/unpad_flash_attention/tiling/flash_attention_tiling_dependency.cpp
@@ -247,7 +247,7 @@ Status EncoderFillTilingParam(const UnpadFlashAttentionInfo &mmInfo, const uint3
                 PP_MM[mIbd]) * mmInfo.tileQ : 0;
             uint32_t textBlKNum = (qSeqlen != 0 && kvSeqlen != 0) ?
                 ((mmInfo.textQLen + PP_MM[mIbd] - 1) / PP_MM[mIbd]) : 0;
-            addrOffsets.totalQBlkNum += razorBlkNum + textBlKNum;
+            addrOffsets.totalQBlkNum += static_cast(razorBlkNum + textBlKNum);
             tilingParam[TILING_HEAD_SIZE + seqIdx * TILING_PARA_SIZE + INDEX3] = static_cast(PP_NN[nIbd]);
         } else {
             addrOffsets.totalQBlkNum += (qSeqlen != 0 && kvSeqlen != 0) ? ((qSeqlen + mUbd - 1) / mUbd) : 0;
@@ -273,22 +273,28 @@ void FillAddrOffsets(const AtbOps::UnpadFlashAttentionInfo &mmInfo, AtbOps::Addr
     int32_t kvRealHeads, int32_t qSeqlen, int32_t kvFactor)
 {
+    int64_t maxQSeqLen = static_cast<int64_t>(mmInfo.maxQSeqLen);
+    int64_t innerBatchSize = static_cast<int64_t>(mmInfo.innerBatchSize);
+    int64_t embeddingSize = static_cast<int64_t>(mmInfo.embeddingSize);
+    int64_t embeddingSizeV = static_cast<int64_t>(mmInfo.embeddingSizeV);
+    int64_t maxKvSeqLen = static_cast<int64_t>(mmInfo.maxKvSeqLen);
+    int64_t kvRealHeadsInt64 = static_cast<int64_t>(kvRealHeads);
+    int64_t qSeqlenInt64 = static_cast<int64_t>(qSeqlen);
+    int64_t kvFactorInt64 = static_cast<int64_t>(kvFactor);
     if (mmInfo.dataShapeType == 1) {
-        MKI_LOG(INFO) << "addrOffsets mmInfo.maxSeqLen" << mmInfo.maxSeqLen << "mmInfo.innerBatchSize"
-                      << mmInfo.innerBatchSize << "mmInfo.embeddingSize" << mmInfo.embeddingSize;
-        MKI_LOG(INFO) << "addrOffsets mmInfo.maxKvSeqLen" << mmInfo.maxKvSeqLen << "kvRealHeads"
-                      << mmInfo.innerBatchSize << "mmInfo.embeddingSize" << mmInfo.embeddingSize;
-        addrOffsets.addrQSeqOffset +=
-            static_cast(mmInfo.maxQSeqLen) * mmInfo.innerBatchSize * mmInfo.embeddingSize;
-        addrOffsets.addrKSeqOffset += static_cast(mmInfo.maxKvSeqLen) * kvRealHeads * mmInfo.embeddingSize;
-        addrOffsets.addrVSeqOffset += static_cast(mmInfo.maxKvSeqLen) * kvRealHeads * mmInfo.embeddingSizeV;
-        addrOffsets.addrOSeqOffset +=
-            static_cast(mmInfo.maxQSeqLen) * mmInfo.innerBatchSize * mmInfo.embeddingSizeV;
+        MKI_LOG(INFO) << "addrOffsets mmInfo.maxSeqLen" << mmInfo.maxSeqLen << "mmInfo.innerBatchSize" << innerBatchSize
+                      << "mmInfo.embeddingSize" << embeddingSize;
+        MKI_LOG(INFO) << "addrOffsets mmInfo.maxKvSeqLen" << maxKvSeqLen << "kvRealHeads" << innerBatchSize
+                      << "mmInfo.embeddingSize" << embeddingSize;
+        addrOffsets.addrQSeqOffset += static_cast(maxQSeqLen * innerBatchSize * embeddingSize);
+        addrOffsets.addrKSeqOffset += static_cast(maxKvSeqLen * kvRealHeadsInt64 * embeddingSize);
+        addrOffsets.addrVSeqOffset += static_cast(maxKvSeqLen * kvRealHeadsInt64 * embeddingSizeV);
+        addrOffsets.addrOSeqOffset += static_cast(maxQSeqLen * innerBatchSize * embeddingSizeV);
     } else {
-        addrOffsets.addrQSeqOffset += static_cast(qSeqlen) * mmInfo.innerBatchSize * mmInfo.embeddingSize;
-        addrOffsets.addrKSeqOffset += static_cast(kvFactor) * kvRealHeads * mmInfo.embeddingSize;
-        addrOffsets.addrVSeqOffset += static_cast(kvFactor) * kvRealHeads * mmInfo.embeddingSizeV;
-        addrOffsets.addrOSeqOffset += static_cast(qSeqlen) * mmInfo.innerBatchSize * mmInfo.embeddingSizeV;
+        addrOffsets.addrQSeqOffset += static_cast(qSeqlenInt64 * innerBatchSize * embeddingSize);
+        addrOffsets.addrKSeqOffset += static_cast(kvFactorInt64 * kvRealHeads * embeddingSize);
+        addrOffsets.addrVSeqOffset += static_cast(kvFactorInt64 * kvRealHeads * embeddingSizeV);
+        addrOffsets.addrOSeqOffset += static_cast(qSeqlenInt64 * innerBatchSize * embeddingSizeV);
     }
 }
@@ -323,27 +329,30 @@ void FillSplitBatchPtr(const UnpadFlashAttentionInfo &mmInfo, uint32_t *tilingPa
 void SplitTaskRelay(const UnpadFlashAttentionInfo &mmInfo, uint32_t *tilingParam, int32_t groupNum,
     uint32_t blockIdx, uint32_t shareBlockTiling)
 {
-    int32_t taskHeadStart = tilingParam[TILING_RELAY_HEAD_SIZE + blockIdx * shareBlockTiling];
-    int32_t taskHeadEnd = tilingParam[TILING_RELAY_HEAD_SIZE + blockIdx * shareBlockTiling + 1];
+    int32_t taskHeadStart = static_cast<int32_t>(tilingParam[TILING_RELAY_HEAD_SIZE + blockIdx * shareBlockTiling]);
+    uint32_t taskHeadEnd = static_cast<uint32_t>(tilingParam[TILING_RELAY_HEAD_SIZE + blockIdx * shareBlockTiling + 1]);
     int32_t curBatch = taskHeadStart % (mmInfo.batchSize * groupNum) / groupNum;
     int32_t curHeadId = taskHeadStart / (mmInfo.batchSize * groupNum) * groupNum;
     uint32_t shareIndex = TILING_RELAY_HEAD_SIZE + blockIdx * shareBlockTiling + INDEX3;
     uint32_t unshareIndex = TILING_RELAY_HEAD_SIZE + blockIdx * shareBlockTiling + shareBlockTiling / 2 + INDEX3;
     int totalShareNum = 0;
+    uint32_t batchSizeUint32 = static_cast<uint32_t>(mmInfo.batchSize);
+    uint32_t groupNumUint32 = static_cast<uint32_t>(groupNum);
     for (int32_t headId = curHeadId; headId < mmInfo.innerBatchSize; headId += groupNum) {
-        int32_t nowBatch = headId == curHeadId ? curBatch : 0;
-        if (mmInfo.batchSize * groupNum * (headId / groupNum) + nowBatch * groupNum >= taskHeadEnd) {
+        uint32_t nowBatch = static_cast<uint32_t>(headId == curHeadId ? curBatch : 0);
+        uint32_t headIdUint32 = static_cast<uint32_t>(headId);
+        if (batchSizeUint32 * groupNum * (headId / groupNum) + nowBatch * groupNum >= taskHeadEnd) {
             break;
         }
-        uint64_t offsetTiling = TILING_RELAY_HEAD_SIZE +
-            shareBlockTiling * mmInfo.blockDim + TILING_PARA_SIZE * nowBatch;
+        uint64_t offsetTiling = TILING_RELAY_HEAD_SIZE + shareBlockTiling * mmInfo.blockDim +
+            static_cast<uint64_t>(TILING_PARA_SIZE * nowBatch);
         uint32_t nowShare = tilingParam[INDEX18 + offsetTiling];
         uint32_t qLen = 0;
-        while (mmInfo.batchSize * groupNum * (headId / groupNum) +
-            nowBatch * groupNum < taskHeadEnd && nowBatch < mmInfo.batchSize) {
+        while (batchSizeUint32 * groupNum * (headId / groupNum) +
+            nowBatch * groupNum < taskHeadEnd && nowBatch < batchSizeUint32) {
             uint32_t shareIdx = tilingParam[INDEX18 + offsetTiling];
             tilingParam[unshareIndex] = nowBatch;
-            tilingParam[unshareIndex + 1] = headId;
+            tilingParam[unshareIndex + 1] = headIdUint32;
             unshareIndex += INDEX4;
             if (shareIdx == static_cast<uint32_t>(-1)) {
                 nowBatch++;
@@ -355,22 +364,22 @@ void SplitTaskRelay(const UnpadFlashAttentionInfo &mmInfo, uint32_t *tilingParam
                 qLen = 0;
             }
             tilingParam[shareIndex] = nowBatch;
-            tilingParam[shareIndex + 1] = headId;
+            tilingParam[shareIndex + 1] = headIdUint32;
             tilingParam[shareIndex + INDEX2] = shareIdx;
             tilingParam[shareIndex + INDEX3] = 0;
-            if (qLen == 0 && (mmInfo.batchSize * groupNum * (headId / groupNum) +
-                nowBatch * groupNum + groupNum >= taskHeadEnd || nowBatch + 1 == mmInfo.batchSize ||
+            if (qLen == 0 && (batchSizeUint32 * groupNum * (headId / groupNum) +
+                nowBatch * groupNum + groupNum >= taskHeadEnd || nowBatch + 1 == batchSizeUint32 ||
                 tilingParam[offsetTiling + TILING_PARA_SIZE + INDEX18] != nowShare) && groupNum == 1) {
                 tilingParam[shareIndex + INDEX3] = 1;
             }
             shareIndex += INDEX4;
             totalShareNum++;
-            qLen = qLen + groupNum;
+            qLen += groupNumUint32;
             nowBatch++;
            offsetTiling += TILING_PARA_SIZE;
         }
     }
-    tilingParam[TILING_RELAY_HEAD_SIZE + blockIdx * shareBlockTiling + INDEX2] = totalShareNum;
+    tilingParam[TILING_RELAY_HEAD_SIZE + blockIdx * shareBlockTiling + INDEX2] = static_cast<uint32_t>(totalShareNum);
 }

 void SplitCoreRelay(const UnpadFlashAttentionInfo &mmInfo, uint32_t *tilingParam, uint32_t shareBlockTiling)
@@ -385,8 +394,8 @@ void SplitCoreRelay(const UnpadFlashAttentionInfo &mmInfo, uint32_t *tilingParam
         taskStart = taskEnd;
         taskEnd = taskEnd + (blockIdx < tailTaskNum ? taskNumPerCore + 1 : taskNumPerCore);
         uint32_t nowBlockTiling = TILING_RELAY_HEAD_SIZE + blockIdx * shareBlockTiling;
-        tilingParam[nowBlockTiling] = taskStart * groupNum;
-        tilingParam[nowBlockTiling + 1] = taskEnd * groupNum;
+        tilingParam[nowBlockTiling] = static_cast<uint32_t>(taskStart * groupNum);
+        tilingParam[nowBlockTiling + 1] = static_cast<uint32_t>(taskEnd * groupNum);
         if (taskStart >= taskEnd) {
             continue;
         }
@@ -504,14 +513,15 @@ Status DecoderFillTilingParamRelay(const UnpadFlashAttentionInfo &mmInfo, const
         int32_t nUbd = std::min((DECODER_PP_BLOCK_BUFFER_SIZE * INDEX2 / embeddingSizeAligned / BLOCK_SIZE) *
             BLOCK_SIZE, kvSeqlenAligned);
         int32_t nIbd = ConvertValueToIndexNN(nUbd, PP_NN_NUM - 1);
+        int64_t seqIdxint64 = static_cast<int64_t>(seqIdx);
         OP_TILING_CHECK_STATUS_RETURN(CheckSeqlen(mmInfo, addrOffsets, qSeqlen, kvSeqlen, seqIdx));
         addrOffsets.addrQSeqOffset =
             static_cast(seqIdx * qSeqlen * mmInfo.innerBatchSize * mmInfo.embeddingSize);
         addrOffsets.addrOSeqOffset =
             static_cast(seqIdx * qSeqlen * mmInfo.innerBatchSize * mmInfo.embeddingSize);
-        addrOffsets.addrLSeqOffset = seqIdx * mmInfo.innerBatchSize * INDEX2;
-        addrOffsets.addrOFdSeqOffset = seqIdx * static_cast(mmInfo.innerBatchSize *
-            mmInfo.embeddingSize * qSeqlen) * INDEX2;
+        addrOffsets.addrLSeqOffset = static_cast(seqIdxint64 * mmInfo.innerBatchSize * INDEX2);
+        addrOffsets.addrOFdSeqOffset =
+            static_cast(seqIdxint64 * mmInfo.innerBatchSize * mmInfo.embeddingSize * qSeqlen * INDEX2);
         uint32_t nowBatchTiling = TILING_RELAY_HEAD_SIZE + shareBlockTiling * mmInfo.blockDim +
             nowBatch * TILING_PARA_SIZE;
         tilingParam[nowBatchTiling] = static_cast(qSeqlen);
@@ -523,7 +533,7 @@ Status DecoderFillTilingParamRelay(const UnpadFlashAttentionInfo &mmInfo, const
         tilingParam[nowBatchTiling + INDEX11] = GetLoww32Bit(addrOffsets.addrOSeqOffset);
         tilingParam[nowBatchTiling + INDEX13] = static_cast(shareLen);
         tilingParam[nowBatchTiling + INDEX18] = static_cast(it.first);
-        tilingParam[nowBatchTiling + INDEX19] = seqIdx;
+        tilingParam[nowBatchTiling + INDEX19] = static_cast<uint32_t>(seqIdx);
         tilingParam[nowBatchTiling + INDEX20] = GetHigh32Bit(addrOffsets.addrLSeqOffset);
         tilingParam[nowBatchTiling + INDEX21] = GetLoww32Bit(addrOffsets.addrLSeqOffset);
         tilingParam[nowBatchTiling + INDEX22] = GetHigh32Bit(addrOffsets.addrOFdSeqOffset);
diff --git a/src/kernels/mixkernels/unpad_flash_attention/unpad_flash_attention_kernel.cpp b/src/kernels/mixkernels/unpad_flash_attention/unpad_flash_attention_kernel.cpp
index c69f3e7e..d8845387 100644
--- a/src/kernels/mixkernels/unpad_flash_attention/unpad_flash_attention_kernel.cpp
+++ b/src/kernels/mixkernels/unpad_flash_attention/unpad_flash_attention_kernel.cpp
@@ -265,11 +265,11 @@ public:
         MKI_CHECK(launchParam.GetParam().Type() == typeid(OpParam::UnpadFlashAttention),
             "unpad_flash_attention: param type invalid", return false);
         auto &param = AnyCast(launchParam.GetParam());
-        auto batch = param.qSeqLen.size();
-        auto kvHead = param.kvHead == 0 ? param.headSize : param.kvHead;
+        const size_t batch = param.qSeqLen.size();
+        const size_t kvHead = param.kvHead == 0 ? param.headSize : param.kvHead;
         MKI_CHECK(batch <= 60, "batch should not larger than 60", return false);
         MKI_CHECK(kvHead <= 8, "kvhead should not larger than 8", return false);
-        auto blockDim = PlatformInfo::Instance().GetCoreNum(CoreType::CORE_TYPE_CUBE);
+        const size_t blockDim = PlatformInfo::Instance().GetCoreNum(CoreType::CORE_TYPE_CUBE);
         MKI_CHECK(batch > 0 && batch <= ND_BATCH_LIMIT, "batch is invalid", return 0);
         auto shareBlockTiling = ((batch * kvHead + blockDim - 1) / blockDim + 1) * RELAY_BLOCK_TILING;
         uint64_t bufferSize =
diff --git a/src/kernels/mixkernels/unpad_flash_attention_nz/tiling/flash_attention_nz_tiling.cpp b/src/kernels/mixkernels/unpad_flash_attention_nz/tiling/flash_attention_nz_tiling.cpp
index 7cc742b4..c282c811 100644
--- a/src/kernels/mixkernels/unpad_flash_attention_nz/tiling/flash_attention_nz_tiling.cpp
+++ b/src/kernels/mixkernels/unpad_flash_attention_nz/tiling/flash_attention_nz_tiling.cpp
@@ -236,7 +236,7 @@ Status FillFlashAttentionNzInfo(UnpadFlashAttentionNzInfo &mmInfo, const LaunchP
     mmInfo.alibiLeftAlign = param.alibiLeftAlign;
     MKI_CHECK(param.windowSize >= 0, "invalid param.windowSize < 0, which is: " << param.windowSize,
         return Status::FailStatus(Mki::ERROR_INVALID_VALUE));
-    mmInfo.windowLen = param.windowSize;
+    mmInfo.windowLen = static_cast(param.windowSize);
     mmInfo.cacheType = param.cacheType;
     if (mmInfo.maskType != OpParam::UnpadFlashAttentionNz::MASK_TYPE_NONE &&
         !(param.type == OpParam::UnpadFlashAttentionNz::UNPAD_FLASH_ATTENTION_NZ_DECODER
diff --git a/src/kernels/mixkernels/unpad_flash_attention_nz/tiling/flash_attention_nz_tiling_dependency.cpp b/src/kernels/mixkernels/unpad_flash_attention_nz/tiling/flash_attention_nz_tiling_dependency.cpp
index 66622572..1ae5548e 100644
--- a/src/kernels/mixkernels/unpad_flash_attention_nz/tiling/flash_attention_nz_tiling_dependency.cpp
+++ b/src/kernels/mixkernels/unpad_flash_attention_nz/tiling/flash_attention_nz_tiling_dependency.cpp
@@ -293,7 +293,7 @@ Status GetUnpadFlashAttentionTilingParam(const UnpadFlashAttentionNzInfo mmInfo,
     }
     const uint32_t nzRealCoreNum = PlatformInfo::Instance().GetCoreNum(CoreType::CORE_TYPE_CUBE);
     uint32_t initSize = static_cast<uint32_t>(mmInfo.batchSize * NZ_REAL_CORE_TILING_SIZE * sizeof(uint32_t) +
-        GetNzRealCoreTilingOffset());
+        static_cast(GetNzRealCoreTilingOffset()));
     auto ret = memset_s(tilingParam, tilingParamSize, 0, initSize);
     MKI_CHECK(ret == EOK, "Failed to clear the array", return Status::FailStatus(-1));
     AddrOffsetsNz addrOffsets;
diff --git a/src/ops_infer/all_to_all/all_to_all_operation.cpp b/src/ops_infer/all_to_all/all_to_all_operation.cpp
index 53ba5f85..03a5afe7 100644
--- a/src/ops_infer/all_to_all/all_to_all_operation.cpp
+++ b/src/ops_infer/all_to_all/all_to_all_operation.cpp
@@ -26,8 +26,8 @@ namespace atb {
 static const int32_t IN_TENSOR_NUM = 1;
 static const int32_t OUT_TENSOR_NUM = 1;
 static const int32_t TRANSPOSE_IN_TENSOR_DIM_NUM = 2;
-static const int64_t MAX_W_SIZE = 90LL * 1024LL;                // 90KB
-static const int64_t MAX_TENSOR_SIZE = 190LL * 1024LL * 1024LL; // 190MB
+static const uint64_t MAX_W_SIZE = 90LL * 1024LL;                // 90KB
+static const uint64_t MAX_TENSOR_SIZE = 190LL * 1024LL * 1024LL; // 190MB

 template <> Status CreateOperation(const infer::AllToAllParam &opParam, Operation **operation)
 {
@@ -112,12 +112,13 @@ Status AllToAllOperation::InferShapeCheckImpl(const SVector &inTenso
                        << inTensorDescs.at(0).shape.dims[1] << ", rankSize: " << param_.rankSize;
         return ERROR_INVALID_TENSOR_DIM;
     }
-    int64_t wSize = inTensorDescs.at(0).shape.dims[TRANSPOSE_IN_TENSOR_DIM_NUM - 1] * sizeof(inTensorDescs.at(0).dtype);
+    uint64_t wSize =
+        inTensorDescs.at(0).shape.dims[TRANSPOSE_IN_TENSOR_DIM_NUM - 1] * sizeof(inTensorDescs.at(0).dtype);
     if (wSize / param_.rankSize >= MAX_W_SIZE) {
         ATB_LOG(ERROR) << "intensors[0].dims[1] / rankSize must be no greater than 90K, but got bytes: " << wSize;
         return ERROR_INVALID_TENSOR_DIM;
     }
-    int64_t tensorSize = Utils::GetTensorSize(inTensorDescs.at(0));
+    uint64_t tensorSize = Utils::GetTensorSize(inTensorDescs.at(0));
     if (tensorSize > MAX_TENSOR_SIZE) {
         ATB_LOG(ERROR) << "intensors[0] total tensor size must be no greater than 190MB, but got bytes: " << tensorSize;
         return ERROR_INVALID_TENSOR_DIM;
diff --git a/src/ops_infer/linear_parallel/linear_parallel_operation.cpp b/src/ops_infer/linear_parallel/linear_parallel_operation.cpp
index e9182ada..7013435b 100644
--- a/src/ops_infer/linear_parallel/linear_parallel_operation.cpp
+++ b/src/ops_infer/linear_parallel/linear_parallel_operation.cpp
@@ -107,8 +107,9 @@ template <> Status CreateOperation(const infer::LinearParallelParam &opParam, Op
         ATB_LOG(ERROR) << "LinearParallelOperation DistributedInitCheck failed.";
         return ERROR_INVALID_PARAM;
     }
-    if (opParam.rankSize <= 0 || (opParam.rankSize & (opParam.rankSize - 1)) != 0) {
-        ATB_LOG(ERROR) << "LinearParallel rankSize support power of 2 but got [" << opParam.rankSize << "]";
+    int rankSize = opParam.rankSize;
+    if (opParam.rankSize <= 0 || (rankSize & (rankSize - 1)) != 0) {
+        ATB_LOG(ERROR) << "LinearParallel rankSize only support positive power of 2 but got [" << rankSize << "]";
         return ERROR_INVALID_PARAM;
     }
     if (opParam.backend != "hccl" && opParam.backend != "lccl" && opParam.backend != "lcoc") {
diff --git a/src/ops_infer/mm_deq_swiglu_quant_mm_deq/mm_deq_swiglu_quant_mm_deq_operation.cpp b/src/ops_infer/mm_deq_swiglu_quant_mm_deq/mm_deq_swiglu_quant_mm_deq_operation.cpp
index 6ee4de82..c048a46d 100644
--- a/src/ops_infer/mm_deq_swiglu_quant_mm_deq/mm_deq_swiglu_quant_mm_deq_operation.cpp
+++ b/src/ops_infer/mm_deq_swiglu_quant_mm_deq/mm_deq_swiglu_quant_mm_deq_operation.cpp
@@ -304,7 +304,7 @@ std::shared_ptr MmDeqSwigluQuantMmDeqOperation::CreateRunner(Context &co
         ATB_LOG(DEBUG) << "MallocRunner from pool failed!";
         return std::make_shared(param_);
     }
-    return std::shared_ptr<Runner>(runner, [&pool](Runner *runner) { pool.FreeRunner(runner); });
+    return std::shared_ptr<Runner>(runner, [poolPtr = &pool](Runner *runner) { poolPtr->FreeRunner(runner) };);
 }

 nlohmann::json MmDeqSwigluQuantMmDeqOperation::GetParamJson() const
diff --git a/src/ops_infer/relay_attention/param.cpp b/src/ops_infer/relay_attention/param.cpp
index 335a8c38..14efa483 100644
--- a/src/ops_infer/relay_attention/param.cpp
+++ b/src/ops_infer/relay_attention/param.cpp
@@ -42,7 +42,7 @@ bool RelayAttentionVariantPackParam::BuildFromTensor(const SVector
 void RelayAttentionVariantPackParam::ReintCastShapeFix(const Mki::Tensor tensor, std::vector &tensorList)
 {
     if (tensor.desc.dims.size() - 1 != tensorList[0].desc.shape.dimNum) {
-        int diffDimNum = tensorList[0].desc.shape.dimNum;
+        std::size_t diffDimNum = static_cast<std::size_t>(tensorList[0].desc.shape.dimNum);
         for (std::size_t i = 0; i < tensorList.size(); i++) {
             for (std::size_t j = diffDimNum; j < tensor.desc.dims.size() - 1; j++) {
                 tensorList[i].desc.shape.dims[j] = 1;
diff --git a/src/ops_infer/ring_mla/ring_mla_operation.cpp b/src/ops_infer/ring_mla/ring_mla_operation.cpp
index 425e47a7..6f4bf236 100644
--- a/src/ops_infer/ring_mla/ring_mla_operation.cpp
+++ b/src/ops_infer/ring_mla/ring_mla_operation.cpp
@@ -452,13 +452,13 @@ Status RingMLAOperation::SetupCheckImpl(const SVector &inTensors, const
         if (!TensorUtil::TensorDescEqual(outTensorDescs.at(i), targetOutTensorDescs.at(i))) {
             extError.errorDesc = OperationUtil::ConcatInfo("Invalid outTensor shape at outTensors[", i, "].");
             ss << "Target outTensor shape: [";
-            int dimNum = (int)targetOutTensorDescs.at(i).shape.dimNum;
+            int dimNum = static_cast<int>(targetOutTensorDescs.at(i).shape.dimNum);
             for (int j = 0; j < dimNum - 1; ++j) {
                 ss << targetOutTensorDescs.at(i).shape.dims[j] << ", ";
             }
             ss << targetOutTensorDescs.at(i).shape.dims[dimNum - 1] << "], ";
             ss << "but got outTensor shape: [";
-            dimNum = (int)outTensorDescs.at(i).shape.dimNum;
+            dimNum = static_cast<int>(outTensorDescs.at(i).shape.dimNum);
             for (int j = 0; j < dimNum - 1; ++j) {
                 ss << outTensorDescs.at(i).shape.dims[j] << ", ";
             }
--
Gitee

From 571db81b36394efe07130db0ab5471df2d5e708b Mon Sep 17 00:00:00 2001
From: ivanshan_8170
Date: Thu, 3 Jul 2025 10:59:16 +0800
Subject: [PATCH 2/4] cleancode

---
 .../mm_deq_swiglu_quant_mm_deq_operation.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/ops_infer/mm_deq_swiglu_quant_mm_deq/mm_deq_swiglu_quant_mm_deq_operation.cpp b/src/ops_infer/mm_deq_swiglu_quant_mm_deq/mm_deq_swiglu_quant_mm_deq_operation.cpp
index c048a46d..f7e5fa85 100644
--- a/src/ops_infer/mm_deq_swiglu_quant_mm_deq/mm_deq_swiglu_quant_mm_deq_operation.cpp
+++ b/src/ops_infer/mm_deq_swiglu_quant_mm_deq/mm_deq_swiglu_quant_mm_deq_operation.cpp
@@ -304,7 +304,7 @@ std::shared_ptr MmDeqSwigluQuantMmDeqOperation::CreateRunner(Context &co
         ATB_LOG(DEBUG) << "MallocRunner from pool failed!";
         return std::make_shared(param_);
     }
-    return std::shared_ptr<Runner>(runner, [poolPtr = &pool](Runner *runner) { poolPtr->FreeRunner(runner) };);
+    return std::shared_ptr<Runner>(runner, [poolPtr = &pool](Runner *runner) { poolPtr->FreeRunner(runner); });
 }

 nlohmann::json MmDeqSwigluQuantMmDeqOperation::GetParamJson() const
--
Gitee

From 7dd5fef6cfda80c5ce73c67a0c314557d58d9605 Mon Sep 17 00:00:00 2001
From: ivanshan_8170
Date: Thu, 3 Jul 2025 14:35:26 +0800
Subject: [PATCH 3/4] cleancode kernels

---
 .../tiling/mla_tiling_dependency.cpp          | 34 ++++++++++++-------
 .../tiling/paged_attention_tiling.cpp         |  1 +
 .../paged_attention_tiling_dependency.cpp     | 12 ++++---
 3 files changed, 29 insertions(+), 18 deletions(-)

diff --git a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp
index 53cab8c1..4b21105b 100644
--- a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp
+++ b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp
@@ -266,11 +266,11 @@ Status GetNdMLADecodingMtpTilingTP1(MLAInfo &mmInfo, uint32_t blockDim, uint32_t
             tilingParam[tilingOffset + 1] = qSeqLen;
             tilingParam[tilingOffset + NUM2] = kvBlockIdx == splitNum - 1 ?
                 nowKVSeqlen - prevKVSeqlen : nowBlocks * mmInfo.blockSize;
-            tilingParam[tilingOffset + NUM3] = prevKVSeqlen;
-            tilingParam[tilingOffset + NUM4] = mmInfo.quantFlag ? batchIdx * qSeqLen + qSeq : batchIdx * qSeqLen;
-            tilingParam[tilingOffset + NUM5] = splitNum;
-            tilingParam[tilingOffset + NUM6] = mmInfo.prevSplitNumSum;
-            prevKVSeqlen += tilingParam[tilingOffset + NUM2];
+            tilingParam[tilingOffset + NUM3] = static_cast<uint32_t>(prevKVSeqlen);
+            tilingParam[tilingOffset + NUM4] = static_cast<uint32_t>(mmInfo.quantFlag ? batchIdx * qSeqLen + qSeq : batchIdx * qSeqLen);
+            tilingParam[tilingOffset + NUM5] = static_cast<uint32_t>(splitNum);
+            tilingParam[tilingOffset + NUM6] = static_cast<uint32_t>(mmInfo.prevSplitNumSum);
+            prevKVSeqlen += static_cast(tilingParam[tilingOffset + NUM2]);
             tilingOffset += TILING_PARA_SIZE_TP1;
             prevTaskNum++;
         }
@@ -301,10 +301,10 @@ void GetNdMLAMtpTilingTP1(const MLAInfo &mmInfo, uint32_t &blockDim, uint32_t *t
             int32_t tilingOffset = TILING_HEAD_SIZE + TILING_PARA_SIZE_TP1 * prevTaskNum;
             int32_t curQLen = ((qSeqLen - qSeq) > maxQPerJob) ? maxQPerJob : (qSeqLen - qSeq);
             curKvSeq += curQLen;
-            tilingParam[tilingOffset] = batchIdx;
-            tilingParam[tilingOffset + NUM1] = beginQ + qSeq;
-            tilingParam[tilingOffset + NUM2] = curKvSeq;
-            tilingParam[tilingOffset + NUM3] = curQLen;
+            tilingParam[tilingOffset] = static_cast<uint32_t>(batchIdx);
+            tilingParam[tilingOffset + NUM1] = static_cast<uint32_t>(beginQ + qSeq);
+            tilingParam[tilingOffset + NUM2] = static_cast<uint32_t>(curKvSeq);
+            tilingParam[tilingOffset + NUM3] = static_cast<uint32_t>(curQLen);
             prevTaskNum++;
             totalTaskNum -= curQLen;
             MKI_LOG(INFO) << "seqIdx = " << tilingParam[tilingOffset];
@@ -319,6 +319,8 @@ void GetNdMLAMtpTilingTP1(const MLAInfo &mmInfo, uint32_t &blockDim, uint32_t *t
 void GetTilingHead(const MLAInfo &mmInfo, const OpParam::MLA &param, uint32_t *tilingParam, const uint32_t *torPtr,
     uint32_t blockDim)
 {
+    UNUSED_VALUE(param);
+    UNUSED_VALUE(blockDim);
     tilingParam[TILING_BATCH] = static_cast(mmInfo.batch);
     tilingParam[TILING_HEADSIZE] = static_cast(TILING_HEAD_SIZE);
     tilingParam[TILING_PARASIZE] = mmInfo.mtpTp1Flag ? static_cast(TILING_PARA_SIZE_TP1) :
@@ -383,6 +385,7 @@ Status GetMLATilingParam(const LaunchParam &launchParam, MLAInfo &mmInfo,
 void InitTilingKWithN(const MLAInfo &mmInfo, int32_t &embeddingSizeAligned, int32_t &nIbd, int32_t seqIdx,
     uint32_t *tilingParam)
 {
+    UNUSED_VALUE(embeddingSizeAligned);
     int32_t kvSeqlen = *(mmInfo.kvSeqLen + seqIdx);
     int32_t kvSeqlenAligned = (kvSeqlen + BLOCK_SIZE - 1) / BLOCK_SIZE * BLOCK_SIZE;
     int32_t nUbd = std::min(LONG_SEQ_LEN, kvSeqlenAligned);
@@ -395,6 +398,7 @@ Status CheckSeqlen(const MLAInfo &mmInfo, const AddrOffsets &addrOffsets, int32_
     int32_t kvSeqlen, int32_t seqIdx)
 {
     UNUSED_VALUE(seqIdx);
+    UNUSED_VALUE(mmInfo);
     MKI_CHECK(qSeqlen >= 0, "qSeqlen is invalid", return Status::FailStatus(ERROR_INVALID_VALUE));
     MKI_CHECK(kvSeqlen >= 0, "kvSeqlen is invalid", return Status::FailStatus(ERROR_INVALID_VALUE));
     MKI_CHECK(addrOffsets.totalQBlkNum >= 0, "totalQBlkNum overflow",
@@ -477,10 +481,14 @@ Status PrefillTilingParam(const MLAInfo &mmInfo, const uint32_t &torUptr, AddrOf
             static_cast(addrOffsets.totalQBlkNum);
         tilingParam[TILING_HEAD_SIZE_PREFILL + seqIdx * TILING_PARA_SIZE_PREFILL + NUM14] = 1;
         auto kvFactor = GetKvFactor(mmInfo, kvSeqlen);
-        addrOffsets.addrQSeqOffset += static_cast(qSeqlen) * mmInfo.numHeads * mmInfo.embeddingSizeV;
-        addrOffsets.addrKSeqOffset += static_cast(kvFactor) * kvRealHeads * mmInfo.embeddingSizeV;
-        addrOffsets.addrVSeqOffset += static_cast(kvFactor) * kvRealHeads * mmInfo.embeddingSizeV;
-        addrOffsets.addrOSeqOffset += static_cast(qSeqlen) * mmInfo.numHeads * mmInfo.embeddingSizeV;
+        int64_t qSeqlenInt64 = static_cast<int64_t>(qSeqlen);
+        int64_t numHeadsInt64 = static_cast<int64_t>(mmInfo.numHeads);
+        int64_t embeddingSizeVInt64 = static_cast<int64_t>(mmInfo.embeddingSizeV);
+        int64_t kvRealHeadsInt64 = static_cast<int64_t>(kvRealHeads);
+        addrOffsets.addrQSeqOffset += static_cast(qSeqlenInt64 * numHeadsInt64 * embeddingSizeVInt64);
+        addrOffsets.addrKSeqOffset += static_cast(kvFactor * kvRealHeadsInt64 * embeddingSizeVInt64);
+        addrOffsets.addrVSeqOffset += static_cast(kvFactor * kvRealHeadsInt64 * embeddingSizeVInt64);
+        addrOffsets.addrOSeqOffset += static_cast(qSeqlenInt64 * numHeadsInt64 * embeddingSizeVInt64);
     }
     PrefillTilingHead(mmInfo, torUptr, addrOffsets, tilingParam, kvRealHeads);
     addrOffsets.block = static_cast(mmInfo.numHeads) * addrOffsets.totalQBlkNum;
diff --git a/src/kernels/mixkernels/pagedattention/tiling/paged_attention_tiling.cpp b/src/kernels/mixkernels/pagedattention/tiling/paged_attention_tiling.cpp
index 3008cc1f..5bc2935f 100644
--- a/src/kernels/mixkernels/pagedattention/tiling/paged_attention_tiling.cpp
+++ b/src/kernels/mixkernels/pagedattention/tiling/paged_attention_tiling.cpp
@@ -156,6 +156,7 @@ Status GetPagedAttentionMaskInfo(const LaunchParam &launchParam, PagedAttentionI
 Status CheckPagedAttentionNdEnhance(const LaunchParam &launchParam, PagedAttentionInfo &mmInfo,
     OpParam::PagedAttention &param)
 {
+    (void)param;
     auto qTensor = launchParam.GetInTensor(DIM_0);
     auto kTensor = launchParam.GetInTensor(DIM_1);
     if ((qTensor.desc.dtype == TENSOR_DTYPE_FLOAT16 || qTensor.desc.dtype == TENSOR_DTYPE_BF16) &&
diff --git a/src/kernels/mixkernels/pagedattention/tiling/paged_attention_tiling_dependency.cpp b/src/kernels/mixkernels/pagedattention/tiling/paged_attention_tiling_dependency.cpp
index 44fec937..85bb0479 100644
--- a/src/kernels/mixkernels/pagedattention/tiling/paged_attention_tiling_dependency.cpp
+++ b/src/kernels/mixkernels/pagedattention/tiling/paged_attention_tiling_dependency.cpp
@@ -653,7 +653,7 @@ void GetMlaMtpBatchTiling(uint32_t *tilingParam, const OpParam::PagedAttention &
         int32_t kvSeqlen = *(mmInfo.kvSeqLen + seqIdx);
         maxKVseqLen = std::max(maxKVseqLen, kvSeqlen);
         int32_t mUbd = qSeqlenAligned;
-        int32_t mIbd = ConvertValueToIndexMM(mUbd, PP_MM_NUM - 1);
+        std::size_t mIbd = static_cast<std::size_t>(ConvertValueToIndexMM(mUbd, PP_MM_NUM - 1));
         int32_t tilingOffset = TILING_HEAD_SIZE;
         int32_t curQSBlockTile = PP_MM[mIbd];
         int32_t curQSBlockNum = (qSeqLen + curQSBlockTile - 1) / curQSBlockTile; // alignment handling
@@ -909,6 +909,7 @@ Status GetNzPagedAttentionTiling(const PagedAttentionInfo &mmInfo, uint32_t &blo
         return Status::FailStatus(ERROR_INVALID_VALUE));
     uint32_t taskNum = static_cast<uint32_t>(taskNumI64);
     blockDim = taskNum < blockDim ? taskNum : blockDim;
+    MKI_CHECK(blockDim > 0, "blockDim is invalid", return Status::FailStatus(ERROR_INVALID_VALUE));
     if (tilingParam[tilingDecodeTypeOffset] == static_cast(CalcType::CALC_TYPE_PREFILL)) {
         GetPaBlockTilingParallel(tilingParam, mmInfo, taskNum, blockDim);
     } else {
@@ -937,10 +938,11 @@ Status GetPagedAttentionTilingParam(const LaunchParam &launchParam, const PagedA
             return Status::FailStatus(ERROR_INVALID_VALUE));
     } else {
         auto tilingHeadSize = is910A ? TILING_HEAD_SIZE_910A : TILING_HEAD_SIZE_NZ;
-        curTilingParamSize = (tilingHeadSize + TILING_PARA_SIZE_NZ * param.qSeqLen.size()) * sizeof(uint32_t);
-        MKI_CHECK(memset_s(tilingParam, tilingParamSize, 0,
-            curTilingParamSize) == EOK,
-            "init tiling failed", return Status::FailStatus(ERROR_INVALID_VALUE));
+        curTilingParamSize = static_cast((static_cast(tilingHeadSize) +
+                                          static_cast(TILING_PARA_SIZE_NZ) * param.qSeqLen.size()) *
+                                         sizeof(uint32_t));
+        MKI_CHECK(memset_s(tilingParam, tilingParamSize, 0, curTilingParamSize) == EOK, "init tiling failed",
+            return Status::FailStatus(ERROR_INVALID_VALUE));
     }
     float tor = mmInfo.tor;
     uint32_t *torPtr = reinterpret_cast<uint32_t *>(&tor);
--
Gitee

From 6b2cadc5da84d35325bd02db1d51b54210247f05 Mon Sep 17 00:00:00 2001
From: ivanshan_8170
Date: Thu, 3 Jul 2025 14:36:59 +0800
Subject: [PATCH 4/4] cleancode

---
 .../multi_latent_attention/tiling/mla_tiling_dependency.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp
index 4b21105b..854f8e1c 100644
--- a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp
+++ b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp
@@ -265,7 +265,7 @@ Status GetNdMLADecodingMtpTilingTP1(MLAInfo &mmInfo, uint32_t blockDim, uint32_t
             tilingParam[tilingOffset] = batchIdx;
             tilingParam[tilingOffset + 1] = qSeqLen;
             tilingParam[tilingOffset + NUM2] = kvBlockIdx == splitNum - 1 ?
-                nowKVSeqlen - prevKVSeqlen : nowBlocks * mmInfo.blockSize;
+                nowKVSeqlen - prevKVSeqlen : nowBlocks * static_cast(mmInfo.blockSize);
             tilingParam[tilingOffset + NUM3] = static_cast<uint32_t>(prevKVSeqlen);
             tilingParam[tilingOffset + NUM4] = static_cast<uint32_t>(mmInfo.quantFlag ? batchIdx * qSeqLen + qSeq : batchIdx * qSeqLen);
             tilingParam[tilingOffset + NUM5] = static_cast<uint32_t>(splitNum);
             tilingParam[tilingOffset + NUM6] = static_cast<uint32_t>(mmInfo.prevSplitNumSum);
--
Gitee