diff --git a/operator/ascendc/4_best_practices/6_group_matmul/KernelLaunch/CMakeLists.txt b/operator/ascendc/4_best_practices/6_group_matmul/KernelLaunch/CMakeLists.txt
index fce09c38fae463a41a5934fecc97b0efcd92e007..dfbc7b76a656b28db89041ab1df4d01ea255a6db 100644
--- a/operator/ascendc/4_best_practices/6_group_matmul/KernelLaunch/CMakeLists.txt
+++ b/operator/ascendc/4_best_practices/6_group_matmul/KernelLaunch/CMakeLists.txt
@@ -16,6 +16,7 @@ endif()
 # ${KERNEL_FILES} are used to compile library, push files written by ascendc in ${KERNEL_FILES}.
 # ref to cmake/npu.cmake ascendc_library, cmake/cpu.cmake add_library
 file(GLOB KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/quant_group_matmul_custom.cpp)
+file(GLOB KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/gemm_a8w4_combine.cpp)
 
 if("${RUN_MODE}" STREQUAL "cpu")
     include(cmake/cpu_lib.cmake)
@@ -28,6 +29,7 @@ endif()
 add_executable(ascendc_kernels_bbit
     ${CMAKE_CURRENT_SOURCE_DIR}/main.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/quant_group_matmul_custom_tiling.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/gemm_a8w4_combine_tiling.cpp
 )
 
 target_compile_options(ascendc_kernels_bbit PRIVATE
diff --git a/operator/ascendc/4_best_practices/6_group_matmul/KernelLaunch/gemm_a8w4_combine.cpp b/operator/ascendc/4_best_practices/6_group_matmul/KernelLaunch/gemm_a8w4_combine.cpp
new file mode 100755
index 0000000000000000000000000000000000000000..94d677f80c74a7339a90f4f559fef569a616544b
--- /dev/null
+++ b/operator/ascendc/4_best_practices/6_group_matmul/KernelLaunch/gemm_a8w4_combine.cpp
@@ -0,0 +1,687 @@
+/**
+ * @file gemm_a8w4_combine.cpp
+ *
+ * Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+ */ +#include "kernel_operator.h" +#include "lib/matmul_intf.h" +#include "gemm_a8w4_combine_tiling.h" + +using namespace matmul; +using namespace AscendC; + +using aT = MatmulType; +using bT = MatmulType; +using BiasT = MatmulType; +using cT = MatmulType; + +using MT = matmul::MatmulImpl; + +constexpr uint32_t BROADCAST_DIM = 2; +constexpr uint64_t SYNC_AIV_TO_AIC = 3; +constexpr uint64_t SYNC_AIC_TO_AIV = 5; +constexpr uint32_t BUFFER_NUM = 2; +static constexpr uint32_t POST_SKIP_ITER_NUM = 1; + +struct MNConfig { + uint32_t m = 0; + uint32_t k = 0; + uint32_t n = 0; + uint32_t baseM = 0; + uint32_t baseN = 0; + uint32_t mIdx = 0; + uint32_t nIdx = 0; + uint32_t blockDimM = 0; + uint32_t blockDimN = 0; + uint32_t singleM = 0; + uint32_t singleN = 0; + uint32_t offsetM = 0; + uint64_t workSpaceOffset = 0; + uint32_t tailN = 0; + uint32_t groupIdx = 0; +}; + +template +__aicore__ inline void DataCopyPad2D(const LocalTensor dst, const GlobalTensor src, uint32_t dim1, uint32_t dim0, + uint32_t srcDim0) { + DataCopyExtParams params; + params.blockCount = dim1; + params.blockLen = dim0 * sizeof(T); + params.srcStride = (srcDim0 - dim0) * sizeof(T); + // 32: int32 -> float16, 为防止跨行数据进入同一32B block,提前每行按偶数block对齐 + params.dstStride = Ceil(dim0 * sizeof(T), 32) % 2; + + DataCopyPadExtParams padParams{true, 0, 0, 0}; + DataCopyPad(dst, src, params, padParams); +} + +template +__aicore__ inline void DataCopyPad2D(const GlobalTensor dst, const LocalTensor src, uint32_t dim1, uint32_t dim0, + uint32_t srcDim0, uint32_t dstDim0) { + DataCopyExtParams params; + params.blockCount = dim1; + params.blockLen = dim0 * sizeof(T); + // 32: ub访问粒度为32B + params.srcStride = (srcDim0 - dim0) * sizeof(T) / 32; + params.dstStride = (dstDim0 - dim0) * sizeof(T); + DataCopyPad(dst, src, params); +} + +class A8W4GroupMatmul { +public: + __aicore__ inline A8W4GroupMatmul(MT &matmul) : mm(matmul) {} + __aicore__ inline void Init(GM_ADDR inputs, GM_ADDR weights, GM_ADDR bias, + GM_ADDR logits, GM_ADDR group_tokens, GM_ADDR token_ranks, GM_ADDR residual, + GM_ADDR inputs_dq_channel_scale, GM_ADDR weight_group_scale, GM_ADDR inputs_dq_token_scale, + GM_ADDR outputs, GM_ADDR workspace, GemmA8W4CombineTilingData *tilingData, TPipe *tPipeIn); + __aicore__ inline void Process(); +private: + __aicore__ inline void InitUbBuffer(); + __aicore__ inline void CastWeightProcess(MNConfig& mnConfig, uint32_t secondHalfIterCount); + __aicore__ inline void MMCompute(MNConfig& mnConfig, uint32_t secondHalfIterCount); + __aicore__ inline void VectorCompute(MNConfig& mnConfig, bool isLastGroup, uint32_t secondHalfIterCount); + __aicore__ inline void PostCompute(); + __aicore__ inline void TailProcess(MNConfig &mnConfig, uint32_t secondHalfIterCount); + __aicore__ inline void ComputeDequantAndFinalizeRounting(MNConfig& mnConfig, uint32_t curVecBaseM, uint32_t alignBaseN, + uint32_t curVecBaseN, uint32_t offsetM); + __aicore__ inline void DataCopyScale(uint32_t curBaseN, uint32_t alignBaseN, uint64_t scaleOffset); + __aicore__ inline void DataCopyPerTokenScaleAndLogits(MNConfig& mnConfig, uint32_t curBaseM, uint32_t alignBaseN, + uint32_t offsetM); + __aicore__ inline void PrefillWorkspace(); + __aicore__ inline void DataCopyAntiQuantScale(uint32_t curBaseN, uint32_t alignBaseN, uint64_t realScaleOffset); + __aicore__ inline void ComputeUbBaseK(uint32_t curSingleK, uint32_t offsetK, uint32_t newBaseK, + uint32_t& curUsedGroupSize, uint32_t& curBaseK); + __aicore__ inline void FreeScaleAndOffset(bool& firstLoop); + __aicore__ inline void 
SetGmToUbDataCopyParams(const uint32_t curBaseN, const uint32_t curBaseK, + const MNConfig& mnConfig, DataCopyExtParams& intriParams); + __aicore__ inline void CastWeightCompute(uint32_t curCalcK, uint32_t curCalcAlignN); + __aicore__ inline void SetUbToGmDataCopyParams(const uint32_t curBaseN, const uint32_t alignRowLen, + const uint32_t curBaseK, const MNConfig& mnConfig, + DataCopyExtParams& intriParams); + __aicore__ inline void DataCopyLogitsAndBrcb(MNConfig& mnConfig, uint32_t curBaseM, uint32_t alignBaseN, uint32_t offsetM); + + + +private: + MT& mm; + GlobalTensor xGm; + GlobalTensor weightWorkspaceGm; + GlobalTensor weightGm; + GlobalTensor biasGm; + GlobalTensor logitsGm; + GlobalTensor tokenRanksGm; + GlobalTensor residualGm; + GlobalTensor mmOutGm; + GlobalTensor perChannelScaleGm; + GlobalTensor perGroupScaleGm; + GlobalTensor perTokenScaleGm; + GlobalTensor groupListGm; + GlobalTensor yGm; + // define the que + TQue vecInQueue; + TQue vecOutQueue; + TQue nAxisInQueue; + TQue perTokenScaleInQueue; + TQue logitsInQueue; + TQue tokenRanksInQueue; + TBuf tmpBuff; + LocalTensor perChannelScaleInUb; + LocalTensor perTokenScaleInUb; + LocalTensor perGroupScaleInUb; + LocalTensor logitsInUb; + LocalTensor logitsBrcbInUb; + LocalTensor residualLocal; + + LocalTensor dequantMiddleResult; + LocalTensor sharedTmpLocal; + LocalTensor mulsResultLocal; + LocalTensor pertokenBrcbLocal; + LocalTensor actResultLocal; + uint32_t subBlockIdx; + uint32_t coreIdx; + uint32_t taskRation; + uint32_t cubeCount = 0; + TPipe *pipe; + uint32_t perGroupSize; + GemmA8W4CombineTilingData *tiling; + MNConfig mnConfigs[POST_SKIP_ITER_NUM + 1]; +}; + +__aicore__ inline void A8W4GroupMatmul::Init(GM_ADDR inputs, GM_ADDR weights, GM_ADDR bias, + GM_ADDR logits, GM_ADDR group_tokens, GM_ADDR token_ranks, GM_ADDR residual, + GM_ADDR inputs_dq_channel_scale, GM_ADDR weight_group_scale, GM_ADDR inputs_dq_token_scale, + GM_ADDR outputs, GM_ADDR workspace, GemmA8W4CombineTilingData *tilingData, TPipe *tPipeIn) +{ + tiling = tilingData; + xGm.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(inputs)); + weightWorkspaceGm.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(workspace)); + weightGm.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(weights)); + biasGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias)); // unused + + logitsGm.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(logits)); + tokenRanksGm.SetGlobalBuffer(reinterpret_cast<__gm__ int64_t *>(token_ranks)); + residualGm.SetGlobalBuffer(reinterpret_cast<__gm__ bfloat16_t *>(residual)); + mmOutGm.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(workspace + tiling->workspaceSize)); + perChannelScaleGm.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(inputs_dq_channel_scale)); + perGroupScaleGm.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(weight_group_scale)); + perTokenScaleGm.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(inputs_dq_token_scale)); + groupListGm.SetGlobalBuffer(reinterpret_cast<__gm__ int64_t *>(group_tokens)); + yGm.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(outputs)); + subBlockIdx = GetSubBlockIdx(); + coreIdx = GetBlockIdx(); + taskRation = GetTaskRation(); + if ASCEND_IS_AIV { + coreIdx /= taskRation; + } + pipe = tPipeIn; + perGroupSize = tiling->perGroupSize; + InitUbBuffer(); +} + +__aicore__ inline void A8W4GroupMatmul::InitUbBuffer() +{ + if ASCEND_IS_AIC { + return; + } + pipe->InitBuffer(nAxisInQueue, BUFFER_NUM, 1024 * sizeof(float)); + pipe->InitBuffer(perTokenScaleInQueue, BUFFER_NUM, 64 * sizeof(float)); 
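+    // perTokenScaleInQueue and logitsInQueue each hold one M-tile (at most 64 values); the two are
+    // multiplied together and broadcast along N in DataCopyPerTokenScaleAndLogits() for the per-token dequant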
+ pipe->InitBuffer(logitsInQueue, BUFFER_NUM, 64 * sizeof(float)); + pipe->InitBuffer(vecInQueue, BUFFER_NUM, tiling->ubCalSize * sizeof(int32_t)); + pipe->InitBuffer(vecOutQueue, BUFFER_NUM, tiling->ubCalSize * sizeof(float)); + pipe->InitBuffer(tmpBuff, tiling->ubRestBytes); + uint32_t ubCalSizeFloat = tiling->ubCalSize * sizeof(float); + // ub分配,依次划分中间结果 + dequantMiddleResult = tmpBuff.GetWithOffset(tiling->ubCalSize, 0); + logitsBrcbInUb = tmpBuff.GetWithOffset(tiling->ubCalSize, 0); + + pertokenBrcbLocal = tmpBuff.GetWithOffset(tiling->ubCalSize, ubCalSizeFloat); + // 2: 偏移两份,前面为反量化输出和pertoken scale + mulsResultLocal = tmpBuff.GetWithOffset(tiling->ubCalSize, 3 * ubCalSizeFloat); + // api需要的临时空间,复用中间结果的空间 + // 2: ub临时空间总共4份,高层api分配两份 + sharedTmpLocal = tmpBuff.GetWithOffset(2 * ubCalSizeFloat, 2 * ubCalSizeFloat); +} + +__aicore__ inline void A8W4GroupMatmul::Process() +{ + //PrefillWorkspace(); + //SyncAll(); + + MNConfig mnConfig; + mnConfig.baseM = tiling->mmTilingData.baseM; + mnConfig.baseN = tiling->mmTilingData.baseN; + mnConfig.singleN = tiling->singleN; + mnConfig.singleM = mnConfig.baseM; + mnConfig.blockDimN = Ceil(tiling->n, mnConfig.singleN); + mnConfig.n = tiling->n; + mnConfig.k = tiling->k; + uint32_t secondHalfIterCount = 0; + for (uint32_t groupIdx = 0, preCount = 0; groupIdx < tiling->groupNum; ++groupIdx) { + uint32_t m = static_cast(groupListGm.GetValue(groupIdx)); + if (m <= 0) { + continue; + } + mnConfig.m = static_cast(m); + mnConfig.blockDimM = Ceil(mnConfig.m, mnConfig.singleM); + uint32_t curCount = preCount + mnConfig.blockDimN * mnConfig.blockDimM; + uint32_t curBlock = coreIdx >= preCount ? coreIdx : coreIdx + tiling->coreNum; + + while (curBlock < curCount) { + mnConfig.mIdx = (curBlock - preCount) / mnConfig.blockDimN; + mnConfig.nIdx = (curBlock - preCount) % mnConfig.blockDimN; + mnConfig.tailN = mnConfig.nIdx * mnConfig.singleN; + mnConfig.groupIdx = groupIdx; + if ASCEND_IS_AIV { + CastWeightProcess(mnConfig, secondHalfIterCount); + } + MMCompute(mnConfig, secondHalfIterCount); + VectorCompute(mnConfig, false, secondHalfIterCount); + secondHalfIterCount++; + + curBlock += tiling->coreNum; + } + + preCount = curCount % tiling->coreNum; + mnConfig.offsetM += mnConfig.m; + } + TailProcess(mnConfig, secondHalfIterCount); + //PostCompute(); +} + +__aicore__ inline void A8W4GroupMatmul::ComputeUbBaseK( + uint32_t curSingleK, uint32_t offsetK, uint32_t newBaseK, uint32_t& curUsedGroupSize, uint32_t& curBaseK) { + if (unlikely(offsetK + newBaseK >= curUsedGroupSize)) { + curBaseK = curUsedGroupSize - offsetK; + curUsedGroupSize += perGroupSize; + if (offsetK + curBaseK > curSingleK) { + curBaseK = curSingleK - offsetK; + } + } else if (unlikely(offsetK + newBaseK > curSingleK)) { + curBaseK = curSingleK - offsetK; + } else { + curBaseK = newBaseK; + } +} + +__aicore__ inline void A8W4GroupMatmul::DataCopyAntiQuantScale(uint32_t curBaseN, uint32_t alignBaseN, + uint64_t realScaleOffset) { + // copy scale and offset frome GM + DataCopyPadParams padParams; + DataCopyParams scaleParams; + scaleParams.blockLen = curBaseN * sizeof(half); + scaleParams.blockCount = 1; + scaleParams.srcStride = 0; + scaleParams.dstStride = 0; + + LocalTensor perGroupScaleLocal = nAxisInQueue.AllocTensor(); + + DataCopyPad(perGroupScaleLocal, perGroupScaleGm[realScaleOffset], scaleParams, padParams); + + nAxisInQueue.EnQue(perGroupScaleLocal); + perGroupScaleInUb = nAxisInQueue.DeQue(); + perGroupScaleInUb.SetSize(alignBaseN); +} + +__aicore__ inline void 
A8W4GroupMatmul::SetGmToUbDataCopyParams(const uint32_t curBaseN, + const uint32_t curBaseK, const MNConfig& mnConfig, DataCopyExtParams& intriParams) { + if (tiling->transposeW) { + intriParams.blockLen = curBaseK / 2; + intriParams.blockCount = curBaseN; + intriParams.srcStride = (mnConfig.k - curBaseK) / 2; + intriParams.dstStride = 0; + } else { + intriParams.blockLen = curBaseN / 2; + intriParams.blockCount = curBaseK; + intriParams.srcStride = (mnConfig.n - curBaseN) / 2; + intriParams.dstStride = 0; + } +} + +__aicore__ inline void A8W4GroupMatmul::FreeScaleAndOffset(bool& firstLoop) { + if (firstLoop) { + firstLoop = false; + } else { + nAxisInQueue.FreeTensor(perGroupScaleInUb); + } +} + +__aicore__ inline void A8W4GroupMatmul::CastWeightCompute(uint32_t curCalcK, uint32_t curCalcAlignN) { + LocalTensor wInUb = vecInQueue.DeQue(); + uint32_t computeSize = curCalcK * curCalcAlignN; + wInUb.SetSize(computeSize); + LocalTensor wMidFp16Ub = tmpBuff.GetWithOffset(tiling->ubCalSize * 2, 0); + Cast(wMidFp16Ub, wInUb, RoundMode::CAST_NONE, computeSize); + PipeBarrier(); + vecInQueue.FreeTensor(wInUb); + uint8_t repeatSride = curCalcAlignN / 16; + for (uint32_t idx = 0; idx < curCalcAlignN; idx = idx + 128) { + Mul(wMidFp16Ub[idx], wMidFp16Ub[idx], perGroupScaleInUb[idx], 128, curCalcK, + {1, 1, 1, repeatSride, repeatSride, 0}); + PipeBarrier(); + } + LocalTensor wResUb = vecOutQueue.AllocTensor(); + Cast(wResUb, wMidFp16Ub, RoundMode::CAST_RINT, computeSize); + PipeBarrier(); + vecOutQueue.EnQue(wResUb); +} + +__aicore__ inline void A8W4GroupMatmul::SetUbToGmDataCopyParams(const uint32_t curBaseN, + const uint32_t alignRowLen, const uint32_t curBaseK, const MNConfig& mnConfig, DataCopyExtParams& intriParams) { + if (tiling->transposeW) { + uint32_t alignBaseK = (Ceil(curBaseK, uint32_t(8)) * 8); + intriParams.blockLen = curBaseK * sizeof(int8_t); + intriParams.blockCount = curBaseN; + intriParams.srcStride = (alignRowLen - curBaseK) * sizeof(int8_t) / 32; + intriParams.dstStride = (mnConfig.k - curBaseK) * sizeof(int8_t); + } else { + intriParams.blockLen = curBaseN * sizeof(int8_t); + intriParams.blockCount = curBaseK; + intriParams.srcStride = 0; + intriParams.dstStride = (alignRowLen - curBaseN) * sizeof(int8_t); + } +} + +__aicore__ inline void A8W4GroupMatmul::CastWeightProcess(MNConfig& mnConfig, uint32_t secondHalfIterCount) { + uint32_t workspacePieceIdx = secondHalfIterCount % (POST_SKIP_ITER_NUM + 1); + uint32_t curSingleK = mnConfig.k; + uint32_t curSingleN = mnConfig.singleN; + if (mnConfig.nIdx == mnConfig.blockDimN - 1) { + curSingleN = tiling->n - mnConfig.tailN; + } + uint32_t newBaseN = tiling->ubBaseN; + uint32_t newBaseK = tiling->ubCalSize * 4 / newBaseN; + uint64_t weightWorkspaceOffset = (coreIdx + workspacePieceIdx * tiling->coreNum) * mnConfig.k * mnConfig.singleN; + uint64_t weightOffset = mnConfig.groupIdx * mnConfig.n * mnConfig.k; + uint64_t wInOffset = tiling->transposeW ? mnConfig.tailN * mnConfig.k : mnConfig.tailN; + if (tiling->isPerGroup) { + newBaseK = newBaseK > perGroupSize ? 
perGroupSize : newBaseK; + } + DataCopyPadExtParams padParams; + for (uint32_t offsetN(0), curBaseN(newBaseN), vecCount(0); offsetN < curSingleN; offsetN += newBaseN) { + if (unlikely(offsetN + newBaseN > curSingleN)) { + curBaseN = curSingleN - offsetN; + } + uint32_t alignBaseN = (Ceil(curBaseN, uint32_t(8)) * 8); + uint32_t curBaseK = newBaseK; + uint32_t curUsedGroupSize = perGroupSize; + bool firstKLoop = true; + int32_t prePergroupIdx = -1; + int32_t curPergroupIdx = 0; + for (uint32_t offsetK(0); offsetK < curSingleK; offsetK += curBaseK) { + vecCount++; + ComputeUbBaseK(curSingleK, offsetK, newBaseK, curUsedGroupSize, curBaseK); + if (taskRation != 0 && vecCount % taskRation != subBlockIdx) { + continue; + } + + if (tiling->isPerGroup) { + curPergroupIdx = offsetK / perGroupSize; + if (firstKLoop || curPergroupIdx > prePergroupIdx) { // load new group + FreeScaleAndOffset(firstKLoop); + + DataCopyAntiQuantScale(curBaseN, alignBaseN, mnConfig.tailN + offsetN + (mnConfig.groupIdx * \ + mnConfig.k / perGroupSize + curPergroupIdx) * mnConfig.n); + prePergroupIdx = curPergroupIdx; + } + } + LocalTensor inLocal = vecInQueue.AllocTensor(); + DataCopyExtParams gmToUbIntriParams; + SetGmToUbDataCopyParams(curBaseN, curBaseK, mnConfig, gmToUbIntriParams); + uint64_t weightInOffset = tiling->transposeW ? offsetK + static_cast(offsetN) * mnConfig.k : + static_cast(offsetK) * mnConfig.n + offsetN; + DataCopyPad(inLocal, weightGm[((weightOffset + weightInOffset + wInOffset)) / 2], + gmToUbIntriParams, padParams); + vecInQueue.EnQue(inLocal); + + DataCopyExtParams ubToGmIntriParams; + if (tiling->transposeW) { + uint32_t alignBaseK = (Ceil(curBaseK, uint32_t(8)) * 8); + CastWeightCompute(alignBaseK, alignBaseN); + SetUbToGmDataCopyParams(curBaseN, alignBaseK, curBaseK, mnConfig, ubToGmIntriParams); + } else { + CastWeightCompute(curBaseK, alignBaseN); + SetUbToGmDataCopyParams(curBaseN, curSingleN, curBaseK, mnConfig, ubToGmIntriParams); + } + LocalTensor wResUb = vecOutQueue.DeQue(); + uint64_t weightOutOffset = tiling->transposeW ? 
offsetK + offsetN * mnConfig.k : offsetK * curSingleN + offsetN; + DataCopyPad(weightWorkspaceGm[weightWorkspaceOffset + weightOutOffset], wResUb, ubToGmIntriParams); + vecOutQueue.FreeTensor(wResUb); + } + if (!firstKLoop) { + nAxisInQueue.FreeTensor(perGroupScaleInUb); + } + } + CrossCoreSetFlag<2, PIPE_MTE3>(SYNC_AIV_TO_AIC); +} + +__aicore__ inline void A8W4GroupMatmul::MMCompute(MNConfig& mnConfig, uint32_t secondHalfIterCount) +{ + uint32_t workspacePieceIdx = secondHalfIterCount % (POST_SKIP_ITER_NUM + 1); + uint32_t curSingleN = mnConfig.singleN; + if (mnConfig.nIdx == mnConfig.blockDimN - 1) { + curSingleN = tiling->n - mnConfig.tailN; + } + uint32_t curSingleM = mnConfig.singleM; + if (mnConfig.mIdx == mnConfig.blockDimM - 1) { + curSingleM = mnConfig.m - mnConfig.mIdx * mnConfig.singleM; + } + uint64_t xOffset = (mnConfig.offsetM + mnConfig.mIdx * mnConfig.singleM) * tiling->k; + uint64_t weightOffset = (coreIdx + workspacePieceIdx * tiling->coreNum) * mnConfig.k * curSingleN; + mnConfig.workSpaceOffset = mnConfig.singleN * mnConfig.singleM * (coreIdx + workspacePieceIdx * tiling->coreNum); + if ASCEND_IS_AIC { + CrossCoreWaitFlag(SYNC_AIV_TO_AIC); + + mm.SetOrgShape(mnConfig.m, curSingleN, tiling->k); + mm.SetSingleShape(curSingleM, curSingleN, tiling->k); + mm.SetTensorA(xGm[xOffset]); + auto weightSlice = weightWorkspaceGm[weightOffset]; + if (mnConfig.blockDimM == 1) { + weightSlice.SetL2CacheHint(CacheMode::CACHE_MODE_DISABLE); + } + mm.SetTensorB(weightSlice); + uint64_t worskspaceOffset = mnConfig.workSpaceOffset; + while (mm.Iterate()) { + mm.GetTensorC(mmOutGm[worskspaceOffset], 0, true); + worskspaceOffset += (mnConfig.baseM * mnConfig.baseN); + } + CrossCoreSetFlag<2, PIPE_FIX>(SYNC_AIC_TO_AIV); // 2: mode为2, group内同步 + } + cubeCount++; +} + +__aicore__ inline void A8W4GroupMatmul::TailProcess(MNConfig &mnConfig, uint32_t secondHalfIterCount) { + uint32_t resPostLoop = POST_SKIP_ITER_NUM; + if (secondHalfIterCount < POST_SKIP_ITER_NUM) { + resPostLoop = secondHalfIterCount; + secondHalfIterCount = POST_SKIP_ITER_NUM; + } + for (uint32_t resIdx = 0; resIdx < resPostLoop; ++resIdx) { + VectorCompute(mnConfig, true, secondHalfIterCount); + secondHalfIterCount++; + } +} + +__aicore__ inline void A8W4GroupMatmul::VectorCompute(MNConfig& mnConfig, bool isLastGroup, uint32_t secondHalfIterCount) +{ + if ASCEND_IS_AIC { + return; + } + if (!isLastGroup) { + mnConfigs[secondHalfIterCount % (POST_SKIP_ITER_NUM + 1)] = mnConfig; + if (secondHalfIterCount < POST_SKIP_ITER_NUM) { + return; + } + } + MNConfig postMNConfig = mnConfigs[(secondHalfIterCount - POST_SKIP_ITER_NUM) % (POST_SKIP_ITER_NUM + 1)]; + uint32_t curCubeSingleN = postMNConfig.singleN; + if (postMNConfig.nIdx == postMNConfig.blockDimN - 1) { + curCubeSingleN = tiling->n - postMNConfig.nIdx * postMNConfig.singleN; + } + uint32_t curCubeSingleM = postMNConfig.singleM; + if (postMNConfig.mIdx == postMNConfig.blockDimM - 1) { + curCubeSingleM = postMNConfig.m - postMNConfig.mIdx * postMNConfig.singleM; + } + uint32_t vecBaseM = tiling->ubCalSize / (Ceil(postMNConfig.baseN, uint32_t(8)) * 8); // 8: num int32_t in 32B ub block + vecBaseM = vecBaseM < curCubeSingleM ? 
vecBaseM : curCubeSingleM; + uint32_t curVecBaseN = postMNConfig.baseN; + uint64_t scaleOffset = postMNConfig.groupIdx * tiling->n + postMNConfig.nIdx * postMNConfig.singleN; + uint64_t outOffset = (postMNConfig.offsetM + postMNConfig.mIdx * postMNConfig.singleM) * tiling->n + \ + postMNConfig.nIdx * postMNConfig.singleN; + bool firstFlag = true; + for (uint32_t offsetN = 0, vecCount = 0; offsetN < curCubeSingleN; offsetN += postMNConfig.baseN) { + if (unlikely(offsetN + postMNConfig.baseN >= curCubeSingleN)) { + curVecBaseN = curCubeSingleN - offsetN; + } + uint32_t alignBaseN = Ceil(curVecBaseN, uint32_t(8)) * 8; // 8: num int32_t in 32B ub block + DataCopyScale(curVecBaseN, alignBaseN, scaleOffset + offsetN); + uint32_t curVecBaseM = vecBaseM; + uint64_t mmOutOffset = postMNConfig.workSpaceOffset + offsetN * mnConfig.baseM; + + for (uint32_t offsetM = 0; offsetM < curCubeSingleM; offsetM += vecBaseM) { + vecCount++; + if (taskRation == 0 || vecCount % taskRation == subBlockIdx) { + DataCopyPerTokenScaleAndLogits(postMNConfig, curVecBaseM, alignBaseN, offsetM); + } + if (firstFlag) { + CrossCoreWaitFlag(SYNC_AIC_TO_AIV); + firstFlag = false; + } + if (taskRation != 0 && vecCount % taskRation != subBlockIdx) { + continue; + } + if (unlikely(offsetM + vecBaseM >= curCubeSingleM)) { + curVecBaseM = curCubeSingleM - offsetM; + } + // 使用AscendDequant接口做perchannel反量化 + LocalTensor mmOutLocal = vecInQueue.AllocTensor(); + DataCopyPad2D(mmOutLocal, mmOutGm[mmOutOffset + offsetM * curVecBaseN], + curVecBaseM, curVecBaseN, curVecBaseN); + vecInQueue.EnQue(mmOutLocal); + ComputeDequantAndFinalizeRounting(postMNConfig, curVecBaseM, alignBaseN, curVecBaseN, offsetM); + LocalTensor yLocal = vecOutQueue.DeQue(); + //DataCopyPad2D(yGm[outOffset + offsetM * tiling->n + offsetN], yLocal, + // curVecBaseM, curVecBaseN, alignBaseN, tiling->n); + for (uint32_t outIdx = offsetM; outIdx < curVecBaseM; outIdx++) { + int32_t outputTokenId = tokenRanksGm.GetValue(postMNConfig.offsetM + outIdx); + int32_t outputTokenOffset = outputTokenId * tiling->n + offsetN + postMNConfig.tailN; + SetAtomicAdd(); + DataCopyPad2D(yGm[outputTokenOffset], yLocal[(outIdx - offsetM) * curVecBaseN], + 1, curVecBaseN, 0, 0); + SetAtomicNone(); + } + vecOutQueue.FreeTensor(yLocal); + } + nAxisInQueue.FreeTensor(perChannelScaleInUb); + } + //CrossCoreSetFlag<2, PIPE_MTE2>(SYNC_AIV_TO_AIC); // 2: mode为2, group内同步 +} + +__aicore__ inline void A8W4GroupMatmul::PostCompute() { + if ASCEND_IS_AIC { + if (likely(cubeCount > 0)) { // exits last cycle wait + uint32_t loop = cubeCount < tiling->parallNum ? 
cubeCount : tiling->parallNum; + for (int32_t idx = 0; idx < loop; ++idx) { + CrossCoreWaitFlag(SYNC_AIV_TO_AIC); + } + } + } +} + +__aicore__ inline void A8W4GroupMatmul::PrefillWorkspace() +{ + if ASCEND_IS_AIC { + return; + } + uint32_t prefillBaseN = 512; + uint32_t prefillBaseM = 10; + uint32_t prefillSingleN = 1024; + uint32_t prefillSingleM = 24; + uint32_t curPrefillBaseM = prefillBaseM; + uint32_t curPrefillBaseN = prefillBaseN; + + uint32_t prefillMIdx = coreIdx / 8; + uint32_t prefillNIdx = coreIdx % 8; + for (uint32_t offsetN = 0, vecCount = 0; offsetN < prefillSingleN; offsetN += prefillBaseN) { + if (unlikely(offsetN + prefillBaseN >= prefillSingleN)) { + curPrefillBaseN = prefillSingleN - offsetN; + } + uint32_t alignBaseN = Ceil(curPrefillBaseN, uint32_t(8)) * 8; // 8: num int32_t in 32B ub block + for (uint32_t offsetM = 0; offsetM < prefillSingleM; offsetM += prefillBaseM) { + vecCount++; + if (taskRation != 0 && vecCount % taskRation != subBlockIdx) { + continue; + } + if (unlikely(offsetM + prefillBaseM >= prefillSingleM)) { + curPrefillBaseM = prefillSingleM - offsetM; + } + LocalTensor residualOutLocal = vecInQueue.AllocTensor(); + DataCopyPad2D(residualOutLocal, residualGm[(prefillMIdx * prefillSingleM + offsetM) * tiling->n + prefillNIdx * prefillSingleN + offsetN], + curPrefillBaseM, curPrefillBaseN, tiling->n); + + vecInQueue.EnQue(residualOutLocal); + LocalTensor residualOutInUb = vecInQueue.DeQue(); + LocalTensor residualResInUb = vecOutQueue.AllocTensor(); + Cast(residualResInUb, residualOutInUb, RoundMode::CAST_NONE, curPrefillBaseM * curPrefillBaseN); + vecInQueue.FreeTensor(residualOutInUb); + Muls(residualResInUb, residualResInUb, tiling->residualSacle, curPrefillBaseM * curPrefillBaseN); + + vecOutQueue.EnQue(residualResInUb); + LocalTensor yLocal = vecOutQueue.DeQue(); + DataCopyPad2D(yGm[(prefillMIdx * prefillSingleM + offsetM) * tiling->n + prefillNIdx * prefillSingleN + offsetN], yLocal, + curPrefillBaseM, curPrefillBaseN, curPrefillBaseN, tiling->n); + vecOutQueue.FreeTensor(yLocal); + + } + } + +} + +__aicore__ inline void A8W4GroupMatmul::ComputeDequantAndFinalizeRounting(MNConfig& mnConfig, + uint32_t curVecBaseM, uint32_t alignBaseN, uint32_t curVecBaseN, uint32_t offsetM) +{ + LocalTensor mmOutInUb = vecInQueue.DeQue(); + AscendDequant(dequantMiddleResult, mmOutInUb, perChannelScaleInUb, sharedTmpLocal, {curVecBaseM, alignBaseN, curVecBaseN}); + PipeBarrier(); + vecInQueue.FreeTensor(mmOutInUb); + // pertoken反量化 + uint32_t computeSize = curVecBaseM * alignBaseN; + LocalTensor yLocalInUb = vecOutQueue.AllocTensor(); + Mul(yLocalInUb, dequantMiddleResult, pertokenBrcbLocal, computeSize); + PipeBarrier(); + vecOutQueue.EnQue(yLocalInUb); +} + +__aicore__ inline void A8W4GroupMatmul::DataCopyScale(uint32_t curBaseN, uint32_t alignBaseN, uint64_t scaleOffset) +{ + // GM拷贝scale + DataCopyPadExtParams padParams; + DataCopyExtParams scaleParams{1, static_cast(curBaseN * sizeof(float)), 1, 1, 0}; + LocalTensor scaleLocal = nAxisInQueue.AllocTensor(); + DataCopyPad(scaleLocal, perChannelScaleGm[scaleOffset], scaleParams, padParams); + nAxisInQueue.EnQue(scaleLocal); + + perChannelScaleInUb = nAxisInQueue.DeQue(); + perChannelScaleInUb.SetSize(alignBaseN); +} + +__aicore__ inline void A8W4GroupMatmul::DataCopyPerTokenScaleAndLogits(MNConfig& mnConfig, + uint32_t curBaseM, uint32_t alignBaseN, uint32_t offsetM) +{ + uint64_t mAxisOffset = mnConfig.offsetM + mnConfig.mIdx * mnConfig.singleM + offsetM; + // GM拷贝per token scale + DataCopyPadExtParams padParams; 
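+    // the same M-axis copy params are reused below for both the per-token scale row and the logits row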
+ DataCopyExtParams mAxisParams{1, static_cast(curBaseM * sizeof(float)), 0, 0, 0}; + LocalTensor perTokenScaleLocal = perTokenScaleInQueue.AllocTensor(); + DataCopyPad(perTokenScaleLocal, perTokenScaleGm[mAxisOffset], mAxisParams, padParams); + perTokenScaleInQueue.EnQue(perTokenScaleLocal); + + sharedTmpLocal = tmpBuff.GetWithOffset(2 * tiling->ubCalSize * sizeof(float), + 2 * tiling->ubCalSize * sizeof(float)); + perTokenScaleInUb = perTokenScaleInQueue.DeQue(); + LocalTensor logitsLocal = logitsInQueue.AllocTensor(); + DataCopyPad(logitsLocal, logitsGm[mAxisOffset], mAxisParams, padParams); + logitsInQueue.EnQue(logitsLocal); + logitsInUb = logitsInQueue.DeQue(); + Mul(perTokenScaleInUb, perTokenScaleInUb, logitsInUb, curBaseM); + PipeBarrier(); + const uint32_t broadCastDst[BROADCAST_DIM] = {curBaseM, alignBaseN}; + const uint32_t broadCastSrc[BROADCAST_DIM] = {curBaseM, 1}; + BroadCast(pertokenBrcbLocal, perTokenScaleInUb, broadCastDst, broadCastSrc, + sharedTmpLocal); + perTokenScaleInQueue.FreeTensor(perTokenScaleInUb); + logitsInQueue.FreeTensor(logitsInUb); +} + +__aicore__ inline void CopyTiling(GemmA8W4CombineTilingData *tiling, GM_ADDR tilingGM) +{ + uint32_t *ptr = reinterpret_cast(tiling); + auto tiling32 = reinterpret_cast<__gm__ uint32_t *>(tilingGM); + + for (uint32_t i = 0; i < sizeof(GemmA8W4CombineTilingData) / sizeof(uint32_t); i++, ptr++) { + *ptr = *(tiling32 + i); + } + return; +} + +extern "C" __global__ __aicore__ void gemm_a8w4_combine(GM_ADDR inputs, GM_ADDR weights, GM_ADDR bias, + GM_ADDR logits, GM_ADDR group_tokens, GM_ADDR token_ranks, GM_ADDR residual, + GM_ADDR inputs_dq_channel_scale, GM_ADDR weight_group_scale, GM_ADDR inputs_dq_token_scale, + GM_ADDR outputs, GM_ADDR workspace, GM_ADDR tilingGm) +{ + TPipe tPipe; + GemmA8W4CombineTilingData tilingData; + CopyTiling(&tilingData, tilingGm); + MT mm; + if ASCEND_IS_AIC { + mm.SetSubBlockIdx(0); + mm.Init(&tilingData.mmTilingData, &tPipe); + } + A8W4GroupMatmul op(mm); + op.Init(inputs, weights, bias, logits, group_tokens, token_ranks, residual, inputs_dq_channel_scale, weight_group_scale, + inputs_dq_token_scale, outputs, workspace, &tilingData, &tPipe); + op.Process(); +} \ No newline at end of file diff --git a/operator/ascendc/4_best_practices/6_group_matmul/KernelLaunch/gemm_a8w4_combine_tiling.cpp b/operator/ascendc/4_best_practices/6_group_matmul/KernelLaunch/gemm_a8w4_combine_tiling.cpp new file mode 100755 index 0000000000000000000000000000000000000000..a36fb0e1cc3aa79fbf6a47271985cdde34f88da6 --- /dev/null +++ b/operator/ascendc/4_best_practices/6_group_matmul/KernelLaunch/gemm_a8w4_combine_tiling.cpp @@ -0,0 +1,61 @@ +/** + * @file quant_group_matmul_custom_tiling.cpp + * + * Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
+ */ +#include +#include +#include +#include +#include + +#include "tiling/tiling_api.h" +#include "tiling/platform/platform_ascendc.h" +#include "gemm_a8w4_combine_tiling.h" + +using namespace matmul_tiling; +using namespace std; + +constexpr uint32_t BEST_BASE_M = 64; +constexpr uint32_t BEST_BASE_K = 64; +constexpr uint32_t BEST_BASE_N = 512; + +bool GenerateTiling(GemmA8W4CombineTilingData &gmmTiling) +{ + optiling::TCubeTiling tilingData; + auto ascendcPlatform = platform_ascendc::PlatformAscendCManager::GetInstance(); + MultiCoreMatmulTiling tilingApi(*ascendcPlatform); + + tilingApi.SetDim(1); + tilingApi.SetAType(TPosition::GM, CubeFormat::ND, DataType::DT_INT8, false); + tilingApi.SetBType(TPosition::GM, CubeFormat::ND, DataType::DT_INT8, false); + tilingApi.SetCType(TPosition::GM, CubeFormat::ND, DataType::DT_INT32); + tilingApi.SetBias(false); + + tilingApi.SetOrgShape(BEST_BASE_M, gmmTiling.n, gmmTiling.k); + tilingApi.SetShape(BEST_BASE_M, gmmTiling.n, gmmTiling.k); + tilingApi.SetFixSplit(BEST_BASE_M, BEST_BASE_N, BEST_BASE_K); + + int64_t res = tilingApi.GetTiling(tilingData); + if (res == -1) { + std::cout << "gen tiling failed" << std::endl; + return false; + } + tilingData.set_dbL0C(1); + tilingData.set_stepKa(4); // 4: L1中左矩阵单次搬运基于baseK的4倍数据 + tilingData.set_stepKb(4); // 4: L1中右矩阵单次搬运基于baseK的4倍数据 + tilingData.set_depthA1(8); // 8: stepKa的两倍,开启double buffer + tilingData.set_depthB1(8); // 8: stepKb的两倍,开启double buffer + tilingData.set_stepM(1); + tilingData.set_stepN(1); + + uint32_t tilingSize = tilingData.GetDataSize(); + tilingData.SaveToBuffer(&gmmTiling.mmTilingData, tilingSize); + + return true; +} + diff --git a/operator/ascendc/4_best_practices/6_group_matmul/KernelLaunch/gemm_a8w4_combine_tiling.h b/operator/ascendc/4_best_practices/6_group_matmul/KernelLaunch/gemm_a8w4_combine_tiling.h new file mode 100755 index 0000000000000000000000000000000000000000..5976712a72174a3ea87da1fc7531c59a25560711 --- /dev/null +++ b/operator/ascendc/4_best_practices/6_group_matmul/KernelLaunch/gemm_a8w4_combine_tiling.h @@ -0,0 +1,35 @@ +/** + * @file quant_group_matmul_custom_tiling.h + * + * Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. + */ +#ifndef GROUP_A8W4_MATMUL_TILING_H +#define GROUP_A8W4_MATMUL_TILING_H + +#include "kernel_tiling/kernel_tiling.h" + +struct GemmA8W4CombineTilingData +{ + uint32_t coreNum; + uint32_t groupNum; + uint32_t totalInGroup; + uint32_t k; + uint32_t n; + uint32_t ubCalSize; + uint32_t ubRestBytes; + uint32_t parallNum; + uint32_t perGroupSize; + uint32_t ubBaseN; + uint32_t singleN; + uint64_t workspaceSize; + float residualSacle; + bool transposeW; + bool isPerGroup; + TCubeTiling mmTilingData; +}; + +#endif \ No newline at end of file diff --git a/operator/ascendc/4_best_practices/6_group_matmul/KernelLaunch/main.cpp b/operator/ascendc/4_best_practices/6_group_matmul/KernelLaunch/main.cpp old mode 100644 new mode 100755 index 0ba82c66cb58ea9079afad28f3fba2ae0fccbc9d..a31dd972d9a410a5c6cb65bcc8a8be9c73b5088f --- a/operator/ascendc/4_best_practices/6_group_matmul/KernelLaunch/main.cpp +++ b/operator/ascendc/4_best_practices/6_group_matmul/KernelLaunch/main.cpp @@ -7,18 +7,22 @@ * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
*/ +#include #include "data_utils.h" -#include "quant_group_matmul_custom_tiling.h" #include "tiling/platform/platform_ascendc.h" +#include "gemm_a8w4_combine_tiling.h" #ifndef ASCENDC_CPU_DEBUG #include "acl/acl.h" -#include "aclrtlaunch_quant_group_matmul_custom.h" +//#include "aclrtlaunch_quant_group_matmul_custom.h" +#include "aclrtlaunch_gemm_a8w4_combine.h" + #else #include "tikicpulib.h" -extern "C" void quant_group_matmul_custom(uint8_t *, uint8_t *, uint8_t *, uint8_t *, uint8_t *, uint8_t *, +extern "C" void gemm_a8w4_combine(uint8_t *, uint8_t *, uint8_t *, uint8_t *, uint8_t *, + uint8_t *, uint8_t *, uint8_t *, uint8_t *, uint8_t *, uint8_t *, uint8_t *, uint8_t *); #endif -extern bool GenerateTiling(QuantGroupMatmulCustomTilingData &gmmTiling); +extern bool GenerateTiling(GemmA8W4CombineTilingData &gmmTiling); struct Data { size_t size = 0; @@ -33,8 +37,12 @@ struct OpArgs { Data x; Data weight; Data bias; + Data logits; Data groupList; + Data token_ranks; + Data residual; Data scale; + Data pergroup_scale; Data perTokenScale; Data tiling; Data workspace; @@ -67,6 +75,10 @@ bool CreateInput(OpArgs &args) { AllocMem("x", args.x); AllocMem("weight", args.weight); AllocMem("bias", args.bias); + AllocMem("residual", args.residual); + AllocMem("perGroupScale", args.pergroup_scale); + AllocMem("tokenTanks", args.token_ranks); + AllocMem("logits", args.logits); AllocMem("groupList", args.groupList); AllocMem("scale", args.scale); AllocMem("perTokenScale", args.perTokenScale); @@ -98,6 +110,11 @@ bool FreeData(OpArgs &args) FreeMem(args.x); FreeMem(args.weight); FreeMem(args.bias); + FreeMem(args.residual); + FreeMem(args.logits); + FreeMem(args.pergroup_scale); + FreeMem(args.token_ranks); + FreeMem(args.groupList); FreeMem(args.scale); FreeMem(args.perTokenScale); @@ -110,26 +127,34 @@ bool FreeData(OpArgs &args) int32_t main(int32_t argc, char *argv[]) { - int m = 1024; - int k = 1024; - int n = 8192; - int groupNum = 8; - uint32_t blockDim = 8; + //int m = 432; + //int k = 1024; + //int n = 8192; + //int groupNum = 128; + int m = 3072; + int k = 1536; + int n = 4096; + int groupNum = 32; + uint32_t blockDim = 24; uint32_t cvParallNum = 4; - size_t userWorkspaceSize = cvParallNum * 256 * 128 * sizeof(int32_t) * blockDim; + size_t userWorkspaceSize = 1024*1024 * sizeof(int) * blockDim * 32; auto ascendcPlatform = platform_ascendc::PlatformAscendCManager::GetInstance(); size_t systemWorkspaceSize = static_cast(ascendcPlatform->GetLibApiWorkSpaceSize()); size_t workspaceSize = userWorkspaceSize + systemWorkspaceSize; OpArgs args { .x=Data(m * k * sizeof(int8_t)), - .weight=Data(groupNum * k * n * sizeof(int8_t)), + .weight=Data(groupNum * k * n * sizeof(int8_t) / 2), .bias=Data(0), // no bias + .logits=Data(m * sizeof(float)), .groupList=Data(groupNum * sizeof(int64_t)), + .token_ranks=Data(m * sizeof(int64_t)), + .residual=Data(72 * n * sizeof(float)/2), .scale=Data(groupNum * n * sizeof(float)), + .pergroup_scale=Data(groupNum * (k/512) * n * sizeof(float) / 2), .perTokenScale=Data(m * sizeof(float)), - .tiling=Data(sizeof(QuantGroupMatmulCustomTilingData), false), + .tiling=Data(sizeof(GemmA8W4CombineTilingData), false), .workspace=Data(workspaceSize, false), - .y=Data(m * n * sizeof(int16_t), false) // sizeof(half) + .y=Data(m * n * sizeof(float), false) // sizeof(half) }; #ifndef ASCENDC_CPU_DEBUG CHECK_ACL(aclInit(nullptr)); @@ -139,22 +164,30 @@ int32_t main(int32_t argc, char *argv[]) CHECK_ACL(aclrtCreateStream(&stream)); #endif CreateInput(args); - 
QuantGroupMatmulCustomTilingData &gmmTiling = *reinterpret_cast(args.tiling.host); + GemmA8W4CombineTilingData &gmmTiling = *reinterpret_cast(args.tiling.host); gmmTiling.coreNum = blockDim; gmmTiling.groupNum = groupNum; gmmTiling.totalInGroup = m; gmmTiling.k = k; gmmTiling.n = n; - gmmTiling.ubCalSize = 24 * 256; // 24: vector每次计算的行数,256: 每次计算的列数,与cube baseN保持一致 - gmmTiling.ubRestBytes = 118784; // 118784: ub剩余大小 + gmmTiling.ubCalSize = 5120; // 24: vector每次计算的行数,256: 每次计算的列数,与cube baseN保持一致 + gmmTiling.ubRestBytes = 81920; // 118784: ub剩余大小 gmmTiling.parallNum = cvParallNum; + gmmTiling.perGroupSize = 512; + gmmTiling.residualSacle = 0.8; + gmmTiling.transposeW = false; + gmmTiling.isPerGroup = true; + gmmTiling.workspaceSize = k*1024*48; + gmmTiling.singleN = 1024; + gmmTiling.ubBaseN = 512; GenerateTiling(gmmTiling); #ifdef ASCENDC_CPU_DEBUG // AscendC::SetKernelMode(KernelMode::AIV_MODE); - ICPU_RUN_KF(quant_group_matmul_custom, blockDim, args.x.host, args.weight.host, args.bias.host, - args.groupList.host, args.scale.host, args.perTokenScale.host, args.y.host, args.workspace.host, + ICPU_RUN_KF(gemm_a8w4_combine, blockDim, args.x.host, args.weight.host, args.bias.host, + args.logits.host, args.groupList.host, args.token_ranks.host, args.residual.host, args.scale.host, + args.pergroup_scale.host, args.perTokenScale.host, args.y.host, args.workspace.host, args.tiling.host); WriteFile("./output/output.bin", args.y.host, args.y.size); @@ -163,8 +196,9 @@ int32_t main(int32_t argc, char *argv[]) CHECK_ACL(aclrtMemcpy(args.tiling.device, args.tiling.size, args.tiling.host, args.tiling.size, ACL_MEMCPY_HOST_TO_DEVICE)); - ACLRT_LAUNCH_KERNEL(quant_group_matmul_custom) - (blockDim, stream, args.x.device, args.weight.device, args.bias.device, args.groupList.device, args.scale.device, + ACLRT_LAUNCH_KERNEL(gemm_a8w4_combine) + (blockDim, stream, args.x.device, args.weight.device, args.bias.device, args.logits.device, args.groupList.device, args.token_ranks.device, + args.residual.device, args.scale.device, args.pergroup_scale.device, args.perTokenScale.device, args.y.device, args.workspace.device, args.tiling.device); CHECK_ACL(aclrtSynchronizeStream(stream)); @@ -178,4 +212,4 @@ int32_t main(int32_t argc, char *argv[]) CHECK_ACL(aclFinalize()); #endif return 0; -} \ No newline at end of file +} diff --git a/operator/ascendc/4_best_practices/6_group_matmul/KernelLaunch/scripts/gen_data.py b/operator/ascendc/4_best_practices/6_group_matmul/KernelLaunch/scripts/gen_data.py old mode 100644 new mode 100755 index 216fbb626cd9ca97a3aa5fa5be65534cb3cb1327..2810371f7d1d62c6e36fc892f275fc9067e68c23 --- a/operator/ascendc/4_best_practices/6_group_matmul/KernelLaunch/scripts/gen_data.py +++ b/operator/ascendc/4_best_practices/6_group_matmul/KernelLaunch/scripts/gen_data.py @@ -10,41 +10,69 @@ import numpy as np import os - +import tensorflow as tf +bfloat16 = tf.bfloat16.as_numpy_dtype def FastGelu(x): x2 = np.exp(0.851 * (x - np.abs(x))) x3 = 1 + np.exp(-1.702 * np.abs(x)) return x * x2 / x3 +def pack_4bit_to_bytes(data): + data = data.astype(np.int8) + data[data < 0] += 16 + packed = (data[..., 1::2] << 4) | (data[..., ::2] & 0x0F) + return packed def gen_golden_data(): - groupNum = 8 - m = 1024 + groupNum = 128 + m = 432 k = 1024 n = 8192 + quantGroupNum = k // 512 - groupList = np.array([128] * groupNum, dtype=np.int64) - + groupList = np.array([3] * 80 + [4] * 48, dtype=np.int64) + print(groupList) x = np.random.randint(-10, 10, [m, k]).astype(np.int8) - weight = np.random.randint(-10, 10, 
[groupNum, k, n]).astype(np.int8) + weight = np.random.randint(-8, 8, [groupNum, k, n], dtype=np.int8) + weightInt4 = pack_4bit_to_bytes(weight) + perGroupScale = np.random.normal(0, 1, (groupNum, quantGroupNum, n)).astype(np.float16) scale = np.random.normal(0, 0.01, (groupNum, n)).astype(np.float32) - perTokenScale = np.random.normal(0, 0.01, (m, 1)).astype(np.float32) + perTokenScale = np.random.normal(0, 1, (m, 1)).astype(np.float32) + logits = np.random.normal(0, 0.01, (m, 1)).astype(np.float32) index = np.cumsum(groupList) + #tokenRanks = np.random.randint(0, 71, [m, 1]).astype(np.int64) + tokenRanks = np.repeat(np.arange(72), 6)[:432].astype(np.int64) + print(tokenRanks) + residual = np.random.normal(-10, 10, (72, n)).astype(bfloat16) xSplit = np.split(x, index, axis=0) perTokenScaleSplit = np.split(perTokenScale, index, axis=0) + logitsSplit = np.split(logits, index, axis=0) + mmOuts = [] + weightHigh = weight.astype(np.float32).reshape([groupNum, quantGroupNum, -1, n]) * perGroupScale.reshape([groupNum, quantGroupNum, 1, n]) + weightHigh = weightHigh.reshape([groupNum, k, n]) + weightHigh = np.minimum(weightHigh, 127) + weightHigh = np.maximum(weightHigh, -128) + for i in range(groupNum): - mm = np.matmul(xSplit[i].astype(np.int32), weight[i].astype(np.int32)) - mm = mm.astype(np.float32) * scale[i].astype(np.float32) * perTokenScaleSplit[i] + ww = np.round(weightHigh[i]).astype(np.int8).astype(np.int32) + mm = np.matmul(xSplit[i].astype(np.int32), ww) + mm = mm.astype(np.float32) * scale[i].astype(np.float32) * perTokenScaleSplit[i] * logitsSplit[i] mmOuts.append(mm) - golden = np.concatenate(mmOuts, axis=0).astype(np.float16) - golden = FastGelu(golden) + residual = residual.astype(np.float32) + residual = residual * 0.8 + golden = np.concatenate(mmOuts, axis=0).astype(np.float32) + #for i in range(432): + # residual[tokenRanks[i]] = residual[tokenRanks[i]] + golden[i] os.makedirs("input", exist_ok=True) os.makedirs("output", exist_ok=True) x.tofile("./input/x.bin") - weightNz = weight.reshape([groupNum, k // 16, 16, n // 32, 32]).transpose([0, 3, 1, 2, 4]) - weightNz.tofile("./input/weight.bin") + weightInt4.tofile("./input/weight.bin") + residual.tofile("./input/residual.bin") + perGroupScale.tofile("./input/perGroupScale.bin") + tokenRanks.tofile("./input/tokenTanks.bin") + logits.tofile("./input/logits.bin") groupList.tofile("./input/groupList.bin") scale.tofile("./input/scale.bin") perTokenScale.tofile("./input/perTokenScale.bin")