From f229ba5ae49beae5add8acb533da9fcfbfa1df42 Mon Sep 17 00:00:00 2001 From: wangyang519 Date: Thu, 21 Aug 2025 19:40:43 +0800 Subject: [PATCH 1/2] fix code check of ATB for reshapeAndCache and pagedCacheLoad --- .../paged_cache_load/tiling/paged_cache_load_tiling.cpp | 2 ++ .../reshape_and_cache/tiling/reshape_and_cache_tiling.cpp | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/src/kernels/mixkernels/paged_cache_load/tiling/paged_cache_load_tiling.cpp b/src/kernels/mixkernels/paged_cache_load/tiling/paged_cache_load_tiling.cpp index 6b20b79d..cb4d071b 100644 --- a/src/kernels/mixkernels/paged_cache_load/tiling/paged_cache_load_tiling.cpp +++ b/src/kernels/mixkernels/paged_cache_load/tiling/paged_cache_load_tiling.cpp @@ -59,6 +59,8 @@ bool CommonPagedCacheLoadTiling(const LaunchParam &launchParam, KernelInfo &kern MKI_CHECK(tokenSizeK > 0 && tokenSizeK <= INT_MAX, "tokenSizeK is invalid", return false); MKI_CHECK(tokenSizeV > 0 && tokenSizeV <= INT_MAX, "tokenSizeV is invalid", return false); MKI_CHECK(typeByte > 0 && typeByte <= INT_MAX, "typeByte is invalid", return false); + MKI_CHECK(tokenSizeK <= INT_MAX / blockSize / typeByte, "tokenSizeK * blockSize is too large", return false); + MKI_CHECK(tokenSizeV <= INT_MAX / blockSize / typeByte, "tokenSizeV * blockSize is too large", return false); PagedCacheLoadTilingData *tilingDataPtr = reinterpret_cast(kernelInfo.GetTilingHostAddr()); diff --git a/src/kernels/mixkernels/reshape_and_cache/tiling/reshape_and_cache_tiling.cpp b/src/kernels/mixkernels/reshape_and_cache/tiling/reshape_and_cache_tiling.cpp index 1893ccba..f649e826 100644 --- a/src/kernels/mixkernels/reshape_and_cache/tiling/reshape_and_cache_tiling.cpp +++ b/src/kernels/mixkernels/reshape_and_cache/tiling/reshape_and_cache_tiling.cpp @@ -40,6 +40,8 @@ bool CommonReshapeAndCacheTiling(const LaunchParam &launchParam, KernelInfo &ker TensorDType inDtype = launchParam.GetInTensor(0).desc.dtype; uint32_t typeByte = static_cast(GetTensorElementSize(inDtype)); tilingDataPtr->typeByte = typeByte; + MKI_CHECK(typeByte > 0 && typeByte <= INT_MAX, "typeByte is invalid", return false); + MKI_CHECK(headSizeK <= UINT32_MAX / numHeads / typeByte, "headSizeK * numHeads is too large", return false); return true; } @@ -54,6 +56,8 @@ Status ReshapeAndCacheTilingNd(const LaunchParam &launchParam, KernelInfo &kerne uint32_t headSizeV = static_cast(vShape.at(DIM_2)); MKI_CHECK(headSizeV > 0 && headSizeV <= INT_MAX, "headSizeV is invalid", return Status::FailStatus(ERROR_INVALID_VALUE)); + MKI_CHECK(headSizeV <= UINT32_MAX / tilingDataPtr->numHeads / tilingDataPtr->typeByte, + "headSizeV * numHeads is too large", return false); if (!key.desc.IsContiguous()) { ReshapeAndCacheNctTilingData *tilingDataPtr = reinterpret_cast(kernelInfo.GetTilingHostAddr()); -- Gitee From 177d55668a1ffb1a3e13bcd1fd567f0ee7bcab8d Mon Sep 17 00:00:00 2001 From: wangyang519 Date: Fri, 22 Aug 2025 09:55:14 +0800 Subject: [PATCH 2/2] fix code check of ATB for reshapeAndCache and pagedCacheLoad --- .../paged_cache_load/op_kernel/paged_cache_load_nd.cpp | 4 ++-- .../paged_cache_load/tiling/paged_cache_load_tiling.cpp | 6 ++++-- .../reshape_and_cache/tiling/reshape_and_cache_tiling.cpp | 7 +++++-- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/kernels/mixkernels/paged_cache_load/op_kernel/paged_cache_load_nd.cpp b/src/kernels/mixkernels/paged_cache_load/op_kernel/paged_cache_load_nd.cpp index c1c3648a..b7bc8cb5 100644 --- a/src/kernels/mixkernels/paged_cache_load/op_kernel/paged_cache_load_nd.cpp +++ b/src/kernels/mixkernels/paged_cache_load/op_kernel/paged_cache_load_nd.cpp @@ -12,7 +12,7 @@ namespace PagedCacheLoad { constexpr int32_t BLOCK_SIZE = 32; -constexpr int32_t TYPEBYPE_ID = 5; +constexpr int32_t TYPEBYTE_ID = 5; constexpr int32_t UB_BUF_SIZE = 192 * 1024; constexpr int32_t VEC_SUBDIM = 2; @@ -237,7 +237,7 @@ extern "C" __global__ __aicore__ void paged_cache_load_nd( __gm__ uint8_t * __restrict__ valueOutGm, __gm__ uint8_t * __restrict__ tilingParaGm) { - int32_t typeByte = ((__gm__ uint32_t *)tilingParaGm)[PagedCacheLoad::TYPEBYPE_ID]; + int32_t typeByte = ((__gm__ uint32_t *)tilingParaGm)[PagedCacheLoad::TYPEBYTE_ID]; if (typeByte == 1) { PagedCacheLoad::PagedCacheLoadNd op; op.Init(keyCacheInGm, valueCacheInGm, blockTablesInGm, diff --git a/src/kernels/mixkernels/paged_cache_load/tiling/paged_cache_load_tiling.cpp b/src/kernels/mixkernels/paged_cache_load/tiling/paged_cache_load_tiling.cpp index cb4d071b..54495d2f 100644 --- a/src/kernels/mixkernels/paged_cache_load/tiling/paged_cache_load_tiling.cpp +++ b/src/kernels/mixkernels/paged_cache_load/tiling/paged_cache_load_tiling.cpp @@ -59,8 +59,10 @@ bool CommonPagedCacheLoadTiling(const LaunchParam &launchParam, KernelInfo &kern MKI_CHECK(tokenSizeK > 0 && tokenSizeK <= INT_MAX, "tokenSizeK is invalid", return false); MKI_CHECK(tokenSizeV > 0 && tokenSizeV <= INT_MAX, "tokenSizeV is invalid", return false); MKI_CHECK(typeByte > 0 && typeByte <= INT_MAX, "typeByte is invalid", return false); - MKI_CHECK(tokenSizeK <= INT_MAX / blockSize / typeByte, "tokenSizeK * blockSize is too large", return false); - MKI_CHECK(tokenSizeV <= INT_MAX / blockSize / typeByte, "tokenSizeV * blockSize is too large", return false); + MKI_CHECK(tokenSizeK <= INT_MAX / blockSize / static_cast(typeByte), + "tokenSizeK * blockSize is too large", return false); + MKI_CHECK(tokenSizeV <= INT_MAX / blockSize / static_cast(typeByte), + "tokenSizeV * blockSize is too large", return false); PagedCacheLoadTilingData *tilingDataPtr = reinterpret_cast(kernelInfo.GetTilingHostAddr()); diff --git a/src/kernels/mixkernels/reshape_and_cache/tiling/reshape_and_cache_tiling.cpp b/src/kernels/mixkernels/reshape_and_cache/tiling/reshape_and_cache_tiling.cpp index f649e826..9effb3a9 100644 --- a/src/kernels/mixkernels/reshape_and_cache/tiling/reshape_and_cache_tiling.cpp +++ b/src/kernels/mixkernels/reshape_and_cache/tiling/reshape_and_cache_tiling.cpp @@ -56,11 +56,12 @@ Status ReshapeAndCacheTilingNd(const LaunchParam &launchParam, KernelInfo &kerne uint32_t headSizeV = static_cast(vShape.at(DIM_2)); MKI_CHECK(headSizeV > 0 && headSizeV <= INT_MAX, "headSizeV is invalid", return Status::FailStatus(ERROR_INVALID_VALUE)); - MKI_CHECK(headSizeV <= UINT32_MAX / tilingDataPtr->numHeads / tilingDataPtr->typeByte, - "headSizeV * numHeads is too large", return false); + if (!key.desc.IsContiguous()) { ReshapeAndCacheNctTilingData *tilingDataPtr = reinterpret_cast(kernelInfo.GetTilingHostAddr()); + MKI_CHECK(headSizeV <= UINT32_MAX / tilingDataPtr->numHeads / tilingDataPtr->typeByte, + "headSizeV * numHeads is too large", return Status::FailStatus(ERROR_INVALID_VALUE)); auto &offsetK = key.desc.offset; auto &offsetV = value.desc.offset; @@ -96,6 +97,8 @@ Status ReshapeAndCacheTilingNd(const LaunchParam &launchParam, KernelInfo &kerne } else { ReshapeAndCacheTilingData *tilingDataPtr = reinterpret_cast(kernelInfo.GetTilingHostAddr()); + MKI_CHECK(headSizeV <= UINT32_MAX / tilingDataPtr->numHeads / tilingDataPtr->typeByte, + "headSizeV * numHeads is too large", return Status::FailStatus(ERROR_INVALID_VALUE)); tilingDataPtr->headSizeV = static_cast(headSizeV); blockDim = tilingDataPtr->numTokens < blockDim ? tilingDataPtr->numTokens : blockDim; bool isAlign = ((tilingDataPtr->numHeads * tilingDataPtr->headSizeK * tilingDataPtr->typeByte) % ALIGN == 0 -- Gitee