diff --git a/src/kernels/mixkernels/genattentionmask/genattentionmask_operation.cpp b/src/kernels/mixkernels/genattentionmask/genattentionmask_operation.cpp index 5b4edec75f2ad41f5297d6b030246fe632f1c255..0d17e1ddcd2ff8995b228d3a611c6b98f4fab335 100644 --- a/src/kernels/mixkernels/genattentionmask/genattentionmask_operation.cpp +++ b/src/kernels/mixkernels/genattentionmask/genattentionmask_operation.cpp @@ -45,11 +45,10 @@ public: MKI_CHECK(launchParam.GetParam().Type() == typeid(OpParam::GenAttentionMask), "OpParam is invalid", return Status::FailStatus(ERROR_INFERSHAPE_ERROR, "OpParam is invalid")); - auto param = AnyCast(launchParam.GetParam()); + auto opParam = AnyCast(launchParam.GetParam()); - MKI_LOG(INFO) << "infer shape param: " << param.headNum; + MKI_LOG(INFO) << "infer shape param: " << opParam.headNum; MKI_CHECK(CheckGenAttentionMask(launchParam), "Failed to check launch param", return Status::FailStatus(ERROR_INFERSHAPE_ERROR, "Failed to check launch param")); - auto opParam = AnyCast(launchParam.GetParam()); int64_t outdims = 0; for (int i = 0; i < launchParam.GetInTensor(0).desc.dims[0]; i++) { outdims += static_cast(opParam.headNum) * opParam.qSeqLen[i] * opParam.qSeqLen[i]; diff --git a/src/kernels/mixkernels/multi_latent_attention/mla_kernel.cpp b/src/kernels/mixkernels/multi_latent_attention/mla_kernel.cpp index 143d2a2d70f82c72dfb614d826fb164f10913d87..a15204e9b5ed645d155af1246c29151668795f45 100644 --- a/src/kernels/mixkernels/multi_latent_attention/mla_kernel.cpp +++ b/src/kernels/mixkernels/multi_latent_attention/mla_kernel.cpp @@ -49,7 +49,7 @@ public: uint64_t GetTilingSize(const LaunchParam &launchParam) const override { MKI_CHECK(launchParam.GetParam().Type() == typeid(OpParam::MLA), - "paged attention: param type invalid", return false); + "multi latent attention: param type invalid", return false); auto param = AnyCast(launchParam.GetParam()); auto batch = param.kvSeqLen.size(); auto maxQSeqlen = param.qSeqLen.data() != nullptr ? 
*std::max_element(param.qSeqLen.begin(), param.qSeqLen.end()) : 1; diff --git a/src/kernels/mixkernels/multi_latent_attention/mla_operation.cpp b/src/kernels/mixkernels/multi_latent_attention/mla_operation.cpp index f7929d3f04618f3297765d31d02f66bc152a9973..dfdd6bfb17982286d7e72bdea60b1adaa6af459e 100644 --- a/src/kernels/mixkernels/multi_latent_attention/mla_operation.cpp +++ b/src/kernels/mixkernels/multi_latent_attention/mla_operation.cpp @@ -141,11 +141,11 @@ private: return false); MKI_CHECK(tensorKcache.desc.dtype == TENSOR_DTYPE_FLOAT16 || tensorKcache.desc.dtype == TENSOR_DTYPE_BF16, "Input1 dtype " << GetStrWithDType(tensorKcache.desc.dtype) - << " invalid, should be float16 or bfloat16 or int8", + << " invalid, should be float16 or bfloat16", return false); MKI_CHECK(tensorVcache.desc.dtype == TENSOR_DTYPE_FLOAT16 || tensorVcache.desc.dtype == TENSOR_DTYPE_BF16, "Input2 dtype " << GetStrWithDType(tensorVcache.desc.dtype) - << " invalid, should be float16 or bfloat16 or int8", + << " invalid, should be float16 or bfloat16", return false); MKI_CHECK(tensorQ.desc.dtype == tensorKcache.desc.dtype && tensorKcache.desc.dtype == tensorVcache.desc.dtype, "tensorQ K V must be the same dtype,here Q dtype is " diff --git a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp index 802ebee0097c1f28864bec5fa69e0d5670df0e68..18ce20dfba99b217f6fe15b129e61cd12e10bf6f 100644 --- a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp +++ b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling.cpp @@ -26,11 +26,12 @@ const int32_t NUM6 = 6; const int32_t NUM8 = 8; const int32_t NUM16 = 16; const int32_t NUM32 = 32; -const int32_t NUM64 = 64; +const int32_t ROPE_DIM = 64; const int32_t NUM256 = 256; const int32_t NUM512 = 512; const int32_t NUM576 = 576; const float SPLITKV_SEQLEN = 2048; +const uint32_t BLOCK_WORKSPACE_SIZE = 32768; int32_t CalcSplitNum(MLAInfo 
&mmInfo, int32_t blockDim, int32_t minKVSeqlen, int32_t blockSize) { @@ -193,6 +194,7 @@ Status MLATiling(const LaunchParam &launchParam, KernelInfo &kernelInfo) GetTilingKeyTypeBase(mmInfo, qTensor, qRopeTensor); uint32_t blockDim = PlatformInfo::Instance().GetCoreNum(CoreType::CORE_TYPE_CUBE); Status ret1 = GetMLAInfo(launchParam, mmInfo, param, blockDim); + OP_TILING_CHECK_STATUS_RETURN(ret1); uint32_t *tilingParam = reinterpret_cast(kernelInfo.GetTilingHostAddr()); uint64_t tilingSize = kernelInfo.GetTilingSize(); Status ret = GetMLATilingParam(launchParam, mmInfo, blockDim, tilingParam, tilingSize); @@ -319,7 +321,7 @@ void MLAPrefillFillInfo(MLAInfo &mmInfo, OpParam::MLA ¶m, int32_t batch, int mmInfo.batch = batch; mmInfo.numHeads = param.headSize; mmInfo.embeddingSize = embed; - mmInfo.embeddingSizeV = embed - NUM64; + mmInfo.embeddingSizeV = embed - ROPE_DIM; mmInfo.qSeqLen = param.qSeqLen.data(); mmInfo.kvSeqLen = param.kvSeqLen.data(); mmInfo.tor = param.tor; @@ -398,7 +400,7 @@ Status MLAPrefillTiling(const LaunchParam &launchParam, KernelInfo &kernelInfo) ret = GenMlaPrefillTilingKey(mmInfo, kernelInfo, param); OP_TILING_CHECK_STATUS_RETURN(ret); uint64_t dataLenFloat = sizeof(float); - uint64_t sSize = static_cast(blockDim) * static_cast(32768) * 16 * dataLenFloat; + uint64_t sSize = static_cast(blockDim) * static_cast(BLOCK_WORKSPACE_SIZE) * NUM16 * dataLenFloat; kernelInfo.GetScratchSizes() = {sSize, sSize, sSize, sSize * 2}; // oTmp/S/P kernelInfo.SetBlockDim(blockDim); return Status::OkStatus(); diff --git a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp index 41d48eafa09094647c5bce93172527eac24759f4..f90a72dd4e350bf2c4db296d51af40d7fbc8e0f5 100644 --- a/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp +++ b/src/kernels/mixkernels/multi_latent_attention/tiling/mla_tiling_dependency.cpp @@ -90,7 +90,7 @@ constexpr 
std::array QN_TILE_LIST = { 128, 64, 32, 16, 8, 1 }; using IndexArr = std::array; inline uint32_t GetHigh32Bit(uint64_t v) { return static_cast(v >> HIGH_32BIT); } -inline uint32_t GetLoww32Bit(uint64_t v) { return static_cast(v); } +inline uint32_t GetLow32Bit(uint64_t v) { return static_cast(v); } inline int32_t ConvertValueToIndexMM(int32_t val, int32_t idxBound) // 16, 7 { @@ -110,13 +110,13 @@ const int32_t NUM_TWO = 2; void GetAddrOffsetMLA(uint32_t *tilingParam, const AddrOffsets addrOffsets, const int32_t tilingOffset) { tilingParam[tilingOffset + NUM2] = GetHigh32Bit(addrOffsets.addrQSeqOffset); - tilingParam[tilingOffset + NUM3] = GetLoww32Bit(addrOffsets.addrQSeqOffset); + tilingParam[tilingOffset + NUM3] = GetLow32Bit(addrOffsets.addrQSeqOffset); tilingParam[tilingOffset + NUM4] = GetHigh32Bit(addrOffsets.addrOSeqOffset); - tilingParam[tilingOffset + NUM5] = GetLoww32Bit(addrOffsets.addrOSeqOffset); + tilingParam[tilingOffset + NUM5] = GetLow32Bit(addrOffsets.addrOSeqOffset); // mask offset tilingParam[tilingOffset + NUM6] = GetHigh32Bit(addrOffsets.addrMaskOffset); - tilingParam[tilingOffset + NUM7] = GetLoww32Bit(addrOffsets.addrMaskOffset); + tilingParam[tilingOffset + NUM7] = GetLow32Bit(addrOffsets.addrMaskOffset); } int32_t GetQNBlockTile(const MLAInfo &mmInfo, int32_t qSeqLen) @@ -408,19 +408,19 @@ void FillPrefillTilingOffsetParam(int32_t seqIdx, AddrOffsets &addrOffsets, uint tilingParam[TILING_HEAD_SIZE_PREFILL + seqIdx * TILING_PARA_SIZE_PREFILL + NUM4] = GetHigh32Bit(addrOffsets.addrQSeqOffset); tilingParam[TILING_HEAD_SIZE_PREFILL + seqIdx * TILING_PARA_SIZE_PREFILL + NUM5] = - GetLoww32Bit(addrOffsets.addrQSeqOffset); + GetLow32Bit(addrOffsets.addrQSeqOffset); tilingParam[TILING_HEAD_SIZE_PREFILL + seqIdx * TILING_PARA_SIZE_PREFILL + NUM6] = GetHigh32Bit(addrOffsets.addrKSeqOffset); tilingParam[TILING_HEAD_SIZE_PREFILL + seqIdx * TILING_PARA_SIZE_PREFILL + NUM7] = - GetLoww32Bit(addrOffsets.addrKSeqOffset); + 
GetLow32Bit(addrOffsets.addrKSeqOffset); tilingParam[TILING_HEAD_SIZE_PREFILL + seqIdx * TILING_PARA_SIZE_PREFILL + NUM8] = GetHigh32Bit(addrOffsets.addrVSeqOffset); tilingParam[TILING_HEAD_SIZE_PREFILL + seqIdx * TILING_PARA_SIZE_PREFILL + NUM9] = - GetLoww32Bit(addrOffsets.addrVSeqOffset); + GetLow32Bit(addrOffsets.addrVSeqOffset); tilingParam[TILING_HEAD_SIZE_PREFILL + seqIdx * TILING_PARA_SIZE_PREFILL + NUM10] = GetHigh32Bit(addrOffsets.addrOSeqOffset); tilingParam[TILING_HEAD_SIZE_PREFILL + seqIdx * TILING_PARA_SIZE_PREFILL + NUM11] = - GetLoww32Bit(addrOffsets.addrOSeqOffset); + GetLow32Bit(addrOffsets.addrOSeqOffset); } uint32_t GetPrefillTilingKey(const MLAInfo &mmInfo) @@ -537,7 +537,7 @@ Status GetMLAPrefillTilingParam(const MLAInfo &mmInfo, uint32_t &blockDim, } MKI_LOG(INFO) << i << " index: " << tilingParam[i] << ","; } - MKI_CHECK(addrOffsets.block > 0 && static_cast(addrOffsets.block) <= UINT32_MAX, "blockDim overflow", + MKI_CHECK(addrOffsets.block > 0 && static_cast(addrOffsets.block) <= UINT32_MAX, "blockDim overflow", return Status::FailStatus(ERROR_INVALID_VALUE)); uint32_t logicalblock = static_cast(addrOffsets.block); MKI_LOG(INFO) << "totalQBlkNum is " << addrOffsets.totalQBlkNum; diff --git a/src/kernels/mixkernels/pagedattention/paged_attention_operation.cpp b/src/kernels/mixkernels/pagedattention/paged_attention_operation.cpp index 0a889ded4ec3f0330054f257d35d3a25b8c38be1..cdbc99af9dd6d31e46da7b791ab1ddcbd5d49427 100644 --- a/src/kernels/mixkernels/pagedattention/paged_attention_operation.cpp +++ b/src/kernels/mixkernels/pagedattention/paged_attention_operation.cpp @@ -182,7 +182,7 @@ private: } MKI_CHECK(embeddingDimQ > 0, "Shape of Input0 invalid, headSize must > 0", return false); if (param.type != OpParam::PagedAttention::PAGED_ATTENTION_MASK_ND) { - MKI_CHECK(CheckPagedMLAttetionCache(launchParam, param, embeddingDimQ), + MKI_CHECK(CheckPagedMLAttentionCache(launchParam, param, embeddingDimQ), "check cache shape fail", return false); 
} else { MKI_CHECK(CheckPagedAttentionCache(launchParam, param, embeddingDimQ), @@ -235,14 +235,14 @@ private: return true; } - bool CheckPagedMLAttetionCache(const LaunchParam &launchParam, const OpParam::PagedAttention ¶m, + bool CheckPagedMLAttentionCache(const LaunchParam &launchParam, const OpParam::PagedAttention ¶m, int64_t embeddingDimQ) const { auto &tensorKcache = launchParam.GetInTensor(DIM_1); // K.shape = [num_blocks, block_size, num_heads, head_size] MKI_CHECK(tensorKcache.desc.dtype == TENSOR_DTYPE_FLOAT16 || tensorKcache.desc.dtype == TENSOR_DTYPE_BF16 || tensorKcache.desc.dtype == TENSOR_DTYPE_INT8, "Input1 dtype " << GetStrWithDType(tensorKcache.desc.dtype) - << " invalid, should be float16 or BF16", + << " invalid, should be float16 or bfloat16 or int8", return false); static const size_t KV_CACHE_DIM_NUM = 4; MKI_CHECK(tensorKcache.desc.dims.size() == KV_CACHE_DIM_NUM, @@ -255,7 +255,7 @@ private: MKI_CHECK(embeddingDimV <= embeddingDimK, "embeddingDimV should not exceed embeddingDimK for combined cache paged-MLA", return false); MKI_CHECK(numBlocks > 0 && blockSize > 0 && kvHead > 0 && embeddingDimK == embeddingDimQ, - "Shape of key_cache is vaild, must be [numBlocks, blockSize, numHeads, headSize]", return false); + "Shape of key_cache is valid, must be [numBlocks, blockSize, numHeads, headSize]", return false); if (tensorKcache.desc.dtype == TENSOR_DTYPE_FLOAT16 || tensorKcache.desc.dtype == TENSOR_DTYPE_BF16) { MKI_CHECK(kvHead == 1, "MLA operation ONLY supports MQA with 16bit input ", return false); } @@ -287,13 +287,13 @@ private: MKI_CHECK(tensorKcache.desc.dtype == TENSOR_DTYPE_FLOAT16 || tensorKcache.desc.dtype == TENSOR_DTYPE_BF16 || tensorKcache.desc.dtype == TENSOR_DTYPE_INT8, "Input1 dtype " << GetStrWithDType(tensorKcache.desc.dtype) - << " invalid, should be float16 or BF16", + << " invalid, should be float16 or bfloat16 or int8", return false); auto &tensorVcache = launchParam.GetInTensor(DIM_2); // V.shape = [num_blocks, 
block_size, num_heads, head_size] MKI_CHECK(tensorVcache.desc.dtype == TENSOR_DTYPE_FLOAT16 || tensorVcache.desc.dtype == TENSOR_DTYPE_BF16 || tensorVcache.desc.dtype == TENSOR_DTYPE_INT8, "Input2 dtype " << GetStrWithDType(tensorVcache.desc.dtype) - << " invalid, should be float16 or BF16", + << " invalid, should be float16 or bfloat16 or int8", return false); static const size_t KV_CACHE_DIM_NUM = 4; MKI_CHECK(tensorKcache.desc.dims.size() == KV_CACHE_DIM_NUM, @@ -310,7 +310,7 @@ private: auto embeddingDimK = tensorKcache.desc.dims[DIM_3]; MKI_CHECK(numBlocks > 0 && blockSize > 0 && kvHead > 0 && embeddingDimK == embeddingDimQ, - "Shape of key_cache is vaild, must be [numBlocks, blockSize, numHeads, headSize]", return false); + "Shape of key_cache is valid, must be [numBlocks, blockSize, numHeads, headSize]", return false); if (!param.compressHead) { MKI_CHECK(kvHead == static_cast(param.kvHead > 0 ? param.kvHead : param.headSize), @@ -658,10 +658,10 @@ private: MKI_CHECK(tensorVcache.desc.dims[DIM_3] == 16, "V_cache Shape should be in nz format", return false); auto &blockTables = launchParam.GetInTensor(DIM_3); // K.shape = [num_blocks, hidden_size/16, block_size, 16] MKI_CHECK(blockTables.desc.dtype == TENSOR_DTYPE_INT32, - "Input3 dtype " << GetStrWithDType(blockTables.desc.dtype) << " invalid, should be float16", return false); + "Input3 dtype " << GetStrWithDType(blockTables.desc.dtype) << " invalid, should be int32", return false); auto &blockContexLen = launchParam.GetInTensor(DIM_4); // V.shape = [num_blocks, hidden_size/16, block_size, 16] MKI_CHECK(blockContexLen.desc.dtype == TENSOR_DTYPE_INT32, "Input4 dtype " << - GetStrWithDType(blockContexLen.desc.dtype) << " invalid, should be float16", return false); + GetStrWithDType(blockContexLen.desc.dtype) << " invalid, should be int32", return false); static const size_t BLOCK_TABLE_DIM_NUM = 2; MKI_CHECK(blockTables.desc.dims.size() == BLOCK_TABLE_DIM_NUM, "Input3 dim num " << blockTables.desc.dims.size() 
<< " invalid, should be " << BLOCK_TABLE_DIM_NUM, return false); @@ -719,7 +719,7 @@ private: {{batch, maxQ, maxKv}, norm}, {{headSize, maxQ, maxKv}, alibi}, {{static_cast(batch) / kvHead, maxQ, maxKv}, norm && param.compressHead}, - {{headSize, maxQ, maxKv}, alibi && alibi && param.compressHead}, + {{headSize, maxQ, maxKv}, alibi && param.compressHead}, {{headSize, maxQ, LONG_SEQ_LEN}, isAlibiCompress}, {{batch, headSize, maxQ, maxKv}, true}, {{static_cast(batch) / kvHead, headSize, maxQ, maxKv}, alibi && param.compressHead}, diff --git a/src/kernels/mixkernels/pagedattention/tiling/paged_attention_tiling.cpp b/src/kernels/mixkernels/pagedattention/tiling/paged_attention_tiling.cpp index 3008cc1fad5c30e7a84abc481ae9338bbb97b229..d4323abd713762506d70893df5ddcf8d2186fde0 100644 --- a/src/kernels/mixkernels/pagedattention/tiling/paged_attention_tiling.cpp +++ b/src/kernels/mixkernels/pagedattention/tiling/paged_attention_tiling.cpp @@ -277,7 +277,7 @@ Status CheckQSeqlenValid(const std::vector& qSeqlen) int64_t numTokens = 0; for (size_t idx = 0; idx < qSeqlen.size(); ++idx) { int64_t curQSeqlen = static_cast(qSeqlen.at(idx)); - MKI_CHECK(curQSeqlen <= INT32_MAX && curQSeqlen > 0, "seqlen is invalid", + MKI_CHECK(curQSeqlen > 0, "seqlen is invalid", return Status::FailStatus(ERROR_INVALID_VALUE)); numTokens += curQSeqlen; MKI_CHECK(numTokens <= INT32_MAX, "seqlen is invalid", @@ -427,7 +427,7 @@ Status PagedAttentionTiling(const LaunchParam &launchParam, KernelInfo &kernelIn uint64_t basicWorkSpaceFloat = blockDim * WORKSPACE_BLOCK_SIZE_DB * dataLenFloat; if (param.type == OpParam::PagedAttention::PAGED_ATTENTION_MASK_ND) { oCoreTempSize = blockDim * SPLITKV_RATION * static_cast(mmInfo.numHeads) * - blockDim * static_cast(mmInfo.embeddingSize * sizeof(float)); + blockDim * static_cast(mmInfo.embeddingSize * sizeof(float)); // (num_tokens, num_heads, kvsplit) lCoreTempSize = blockDim * SPLITKV_RATION * static_cast(mmInfo.numHeads) * blockDim * sizeof(float); diff 
--git a/src/kernels/mixkernels/ring_mla/ring_mla_operation.cpp b/src/kernels/mixkernels/ring_mla/ring_mla_operation.cpp index 9979f12c4801c637219b3e5937302e0dd8a73298..852daebfc0e33d3491b2a871e1a8c4656e6168c2 100644 --- a/src/kernels/mixkernels/ring_mla/ring_mla_operation.cpp +++ b/src/kernels/mixkernels/ring_mla/ring_mla_operation.cpp @@ -60,8 +60,6 @@ public: MKI_CHECK(specificParam.Type() == typeid(OpParam::RINGMLA), "OpParam is invalid", return 0); auto param = AnyCast(specificParam); switch (param.type) { - case OpParam::RINGMLA::SPLIT_CACHE: - return DIM_2; case OpParam::RINGMLA::PREFILL_SPLIT_CACHE: return DIM_2; default: diff --git a/src/kernels/mixkernels/ring_mla/tiling/ring_mla_tiling.cpp b/src/kernels/mixkernels/ring_mla/tiling/ring_mla_tiling.cpp index 7178dedfe9ecf2769d62d88dc75d59b6934a31f0..0cea30f4e4df51e4dae37620b10cf67eb368b14f 100644 --- a/src/kernels/mixkernels/ring_mla/tiling/ring_mla_tiling.cpp +++ b/src/kernels/mixkernels/ring_mla/tiling/ring_mla_tiling.cpp @@ -28,6 +28,7 @@ const int32_t NUM256 = 256; const int32_t NUM512 = 512; const int32_t NUM576 = 576; const float SPLITKV_RATION = 0.8; +const int32_t DEFAULT_HEADDIM = 192; Status GetMLANdInfo(const LaunchParam &launchParam, RINGMLAInfo &mmInfo, OpParam::RINGMLA ¶m) @@ -275,7 +276,7 @@ Status InitInfo(RINGMLAInfo &mmInfo, OpParam::RINGMLA ¶m) SVector kcacheShape; SVector vcacheShape; kcacheShape = mmInfo.tensors.kCache.desc.dims; - int32_t embed = 192; + int32_t embed = DEFAULT_HEADDIM; int32_t maxKvSeqLen = 0; if (kcacheShape.size() == DIM_3) { diff --git a/src/kernels/mixkernels/ring_mla/tiling/ring_mla_tiling_dependency.cpp b/src/kernels/mixkernels/ring_mla/tiling/ring_mla_tiling_dependency.cpp index 1c0e6dfdb8f111a6b978d7ed7079d26defdfca21..8327547e15cfedd4cf54de50242abf7b96045bcc 100644 --- a/src/kernels/mixkernels/ring_mla/tiling/ring_mla_tiling_dependency.cpp +++ b/src/kernels/mixkernels/ring_mla/tiling/ring_mla_tiling_dependency.cpp @@ -106,7 +106,7 @@ constexpr std::array 
QN_TILE_LIST = { 128, 64, 32, 16, 8, 1 }; using IndexArr = std::array; inline uint32_t GetHigh32Bit(uint64_t v) { return static_cast(v >> HIGH_32BIT); } -inline uint32_t GetLoww32Bit(uint64_t v) { return static_cast(v); } +inline uint32_t GetLow32Bit(uint64_t v) { return static_cast(v); } inline int32_t ConvertValueToIndexMM(int32_t val, int32_t idxBound) // 16, 7 { @@ -126,13 +126,13 @@ const int32_t NUM_TWO = 2; void GetAddrOffsetRINGMLA(uint32_t *tilingParam, const AddrOffsets addrOffsets, const int32_t tilingOffset) { tilingParam[tilingOffset + NUM2] = GetHigh32Bit(addrOffsets.addrQSeqOffset); - tilingParam[tilingOffset + NUM3] = GetLoww32Bit(addrOffsets.addrQSeqOffset); + tilingParam[tilingOffset + NUM3] = GetLow32Bit(addrOffsets.addrQSeqOffset); tilingParam[tilingOffset + NUM4] = GetHigh32Bit(addrOffsets.addrOSeqOffset); - tilingParam[tilingOffset + NUM5] = GetLoww32Bit(addrOffsets.addrOSeqOffset); + tilingParam[tilingOffset + NUM5] = GetLow32Bit(addrOffsets.addrOSeqOffset); // mask offset tilingParam[tilingOffset + NUM6] = GetHigh32Bit(addrOffsets.addrMaskOffset); - tilingParam[tilingOffset + NUM7] = GetLoww32Bit(addrOffsets.addrMaskOffset); + tilingParam[tilingOffset + NUM7] = GetLow32Bit(addrOffsets.addrMaskOffset); } int32_t GetQNBlockTile(const RINGMLAInfo &mmInfo, int32_t qSeqLen) @@ -319,19 +319,19 @@ void RingFillPrefillTilingOffsetParam(int32_t seqIdx, AddrOffsets &addrOffsets, tilingParam[TILING_HEAD_SIZE_PREFILL + seqIdx * TILING_PARA_SIZE_PREFILL + NUM4] = GetHigh32Bit(addrOffsets.addrQSeqOffset); tilingParam[TILING_HEAD_SIZE_PREFILL + seqIdx * TILING_PARA_SIZE_PREFILL + NUM5] = - GetLoww32Bit(addrOffsets.addrQSeqOffset); + GetLow32Bit(addrOffsets.addrQSeqOffset); tilingParam[TILING_HEAD_SIZE_PREFILL + seqIdx * TILING_PARA_SIZE_PREFILL + NUM6] = GetHigh32Bit(addrOffsets.addrKSeqOffset); tilingParam[TILING_HEAD_SIZE_PREFILL + seqIdx * TILING_PARA_SIZE_PREFILL + NUM7] = - GetLoww32Bit(addrOffsets.addrKSeqOffset); + 
GetLow32Bit(addrOffsets.addrKSeqOffset); tilingParam[TILING_HEAD_SIZE_PREFILL + seqIdx * TILING_PARA_SIZE_PREFILL + NUM8] = GetHigh32Bit(addrOffsets.addrVSeqOffset); tilingParam[TILING_HEAD_SIZE_PREFILL + seqIdx * TILING_PARA_SIZE_PREFILL + NUM9] = - GetLoww32Bit(addrOffsets.addrVSeqOffset); + GetLow32Bit(addrOffsets.addrVSeqOffset); tilingParam[TILING_HEAD_SIZE_PREFILL + seqIdx * TILING_PARA_SIZE_PREFILL + NUM10] = GetHigh32Bit(addrOffsets.addrOSeqOffset); tilingParam[TILING_HEAD_SIZE_PREFILL + seqIdx * TILING_PARA_SIZE_PREFILL + NUM11] = - GetLoww32Bit(addrOffsets.addrOSeqOffset); + GetLow32Bit(addrOffsets.addrOSeqOffset); } void PrefillTilingHead(const RINGMLAInfo &mmInfo, const uint32_t &torUptr, AddrOffsets &addrOffsets, diff --git a/src/kernels/mixkernels/unpad/tiling/unpad_tiling.cpp b/src/kernels/mixkernels/unpad/tiling/unpad_tiling.cpp index bb0cf714a842c621b600afe430708943a7583ad8..f997c9b89f3530e0887aad8a963f4bc80a2d4f79 100644 --- a/src/kernels/mixkernels/unpad/tiling/unpad_tiling.cpp +++ b/src/kernels/mixkernels/unpad/tiling/unpad_tiling.cpp @@ -13,6 +13,7 @@ namespace AtbOps { using namespace Mki; +const uint64_t SYSTEM_WORKSPACE_SIZE = 16; void FillTilingParam(const LaunchParam &launchParam, UnpadTilingData *tilingDataPtr) { tilingDataPtr->padLength = launchParam.GetInTensor(0).desc.dims[1]; @@ -25,7 +26,7 @@ Status UnpadTiling(const LaunchParam &launchParam, KernelInfo &kernelInfo) FillTilingParam(launchParam, tilingDataPtr); kernelInfo.SetBlockDim(1); - uint64_t sysWorkspaceSize = 16; + uint64_t sysWorkspaceSize = SYSTEM_WORKSPACE_SIZE; kernelInfo.GetScratchSizes() = {sysWorkspaceSize}; return Status::OkStatus(); }