diff --git a/src/kernels/mixkernels/mla_preprocess/mla_preprocess_operation.cpp b/src/kernels/mixkernels/mla_preprocess/mla_preprocess_operation.cpp index 77cfb71c406f03ae1b137a4747164c8c7cb69777..30c1967821645c6a0c8cff7306dd3621706ab499 100644 --- a/src/kernels/mixkernels/mla_preprocess/mla_preprocess_operation.cpp +++ b/src/kernels/mixkernels/mla_preprocess/mla_preprocess_operation.cpp @@ -8,6 +8,7 @@ * See LICENSE in the root of the software repository for the full text of the License. */ #include +#include #include #include #include @@ -36,7 +37,8 @@ public: if (deqOnTheFly) { return GetKernelByName("MLAPreprocessKernel"); } - return GetKernelByName("MLAPreprocessBF16Kernel"); + //return GetKernelByName("MLAPreprocessBF16Kernel"); + return GetKernelByName("MLAPreprocessKernel"); } int64_t GetInputNum(const Any &specificParam) const override diff --git a/src/kernels/mixkernels/mla_preprocess/op_kernel/mla_preprocess_mix.cce b/src/kernels/mixkernels/mla_preprocess/op_kernel/mla_preprocess_mix.cce index ac21dbe8ac5f1abf865a2dde8f0980a0032f10ed..4ab7c4e2d64b2e4fef6c046bd578ff4d49f95f62 100644 --- a/src/kernels/mixkernels/mla_preprocess/op_kernel/mla_preprocess_mix.cce +++ b/src/kernels/mixkernels/mla_preprocess/op_kernel/mla_preprocess_mix.cce @@ -65,7 +65,7 @@ constexpr uint8_t CACHE_MODE_NZCACHE = 3; // pp matmul namespace { -constexpr uint32_t HIDDTEN_STATE = 7168; +constexpr uint32_t HIDDTEN_STATE = 6144; constexpr uint32_t FLOAT_BLOCK_SIZE = 64; constexpr uint32_t HALF_BLOCK_SIZE = 64; constexpr uint32_t HALF_VECTOR_SIZE = 64; @@ -377,11 +377,11 @@ public: __aicore__ inline RmsNormQuant() {} __aicore__ inline void Init(AscendC::GlobalTensor gammaGmTensor, - AscendC::GlobalTensor betaGmTensor, - AscendC::GlobalTensor quantScaleGmTensor, + AscendC::GlobalTensor betaGmTensor, + AscendC::GlobalTensor quantScaleGmTensor, AscendC::GlobalTensor quantOffsetGmTensor, AscendC::GlobalTensor inputGmTensor, - AscendC::GlobalTensor outputGmTensor, + AscendC::GlobalTensor<__bf16> outputGmTensor, uint32_t stride, uint32_t num_col, float avg_factor, @@ -421,8 +421,8 @@ public: __aicore__ inline void Launch(const AscendC::LocalTensor &dstTensor, const AscendC::LocalTensor &srcTensor, const AscendC::LocalTensor &gammaTensor, - const AscendC::LocalTensor &betaTensor, - const AscendC::LocalTensor &quantScaleTensor, + const AscendC::LocalTensor &betaTensor, + const AscendC::LocalTensor &quantScaleTensor, const AscendC::LocalTensor &quantOffsetTensor, const AscendC::LocalTensor &res1Tensor, const AscendC::LocalTensor &res3Tensor) @@ -519,6 +519,8 @@ public: Mul(fp32_xy, fp32_xy, g, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64, {1, 1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); AscendC::PipeBarrier(); + + /* if constexpr (WITH_BETA) { // quant的beta是fp16加的 AscendC::LocalTensor b = this->betaTensor; Cast(work, @@ -555,17 +557,18 @@ public: num_col_align_withStride_fp32 / REPEAT_TIME_64, {1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE}); AscendC::PipeBarrier(); + */ - AscendC::LocalTensor tmpfp16 = - buf.ReinterpretCast()[OFFSET_SUM * num_col_align_withStride_fp32 * 2]; + AscendC::LocalTensor<__bf16> tmpfp16 = + buf.ReinterpretCast<__bf16>()[OFFSET_SUM * num_col_align_withStride_fp32 * 2]; CastFrom32To16(tmpfp16, fp32_xy, num_col_align_withStride_fp32); AscendC::PipeBarrier(); - CastFromF16ToI8(dstTensor, tmpfp16, quantMin_, num_col_align_withStride_fp16); + // CastFromF16ToI8(dstTensor, tmpfp16, quantMin_, 
num_col_align_withStride_fp16); SET_FLAG(V, MTE3, EVENT_ID0); WAIT_FLAG(V, MTE3, EVENT_ID0); AscendC::DataCopy(outputGmTensor[gm_out_offset_ + outOffset], - dstTensor, - AscendC::DataCopyParams(1, (num_col_ - input_stride_) / 32, 0, 0)); + tmpfp16, + AscendC::DataCopyParams(1, (num_col_ - input_stride_) / 16, 0, 0)); SET_FLAG(MTE3, V, EVENT_ID0); WAIT_FLAG(MTE3, V, EVENT_ID0); SET_FLAG(MTE3, MTE2, EVENT_ID0); @@ -580,16 +583,16 @@ private: AscendC::LocalTensor dstTensor; AscendC::LocalTensor srcTensor; AscendC::LocalTensor gammaTensor; - AscendC::LocalTensor betaTensor; + AscendC::LocalTensor betaTensor; AscendC::LocalTensor fp32_xy; AscendC::LocalTensor buf; AscendC::GlobalTensor gammaGmTensor; - AscendC::GlobalTensor betaGmTensor; - AscendC::GlobalTensor quantScaleGmTensor; + AscendC::GlobalTensor betaGmTensor; + AscendC::GlobalTensor quantScaleGmTensor; AscendC::GlobalTensor quantOffsetGmTensor; AscendC::GlobalTensor inputGmTensor; - AscendC::GlobalTensor outputGmTensor; + AscendC::GlobalTensor<__bf16> outputGmTensor; uint32_t num_col_{0}; // 输入的列数 uint32_t row_work{0}; // 需要计算多少行 @@ -850,10 +853,10 @@ struct MatCoord { uint64_t n{0}; }; -template +template class PpMatmulEinSum { - using InDtype = half; - using OutDtype = half; + using InDtype = __bf16; + using OutDtype = __bf16; using AccumDtype = float; template @@ -872,7 +875,7 @@ public: __aicore__ explicit PpMatmulEinSum(){}; __aicore__ __force_inline__ void Init(GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmC, - const AtbOps::MlaTilingData &mlaParams); + const AtbOps::MlaTilingData &mlaParams, uint32_t mode); __aicore__ __force_inline__ void Process(); __aicore__ __force_inline__ void PreloadB(); @@ -918,26 +921,58 @@ private: uint32_t core_idx{0}; uint32_t en_shuffle_k = 0; uint32_t ping_flag{0}; + uint32_t MM_mode{0}; }; -template -__aicore__ __force_inline__ void PpMatmulEinSum::Init( - GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmC, const AtbOps::MlaTilingData &mlaParams) +template +__aicore__ __force_inline__ void PpMatmulEinSum::Init( + GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmC, const AtbOps::MlaTilingData &mlaParams, uint32_t mode) { #ifdef __DAV_C220_CUBE__ - batch_size = mlaParams.mm3.numBatch; - m = mlaParams.mm3.m; - k = mlaParams.mm3.k; - n = mlaParams.mm3.n; - m0 = mlaParams.mm3.m0; - k0 = mlaParams.mm3.k0; - n0 = mlaParams.mm3.n0; - tdim.m = mlaParams.mm3.mLoop; - tdim.k = mlaParams.mm3.kLoop; - tdim.n = mlaParams.mm3.nLoop; - core_loop = mlaParams.mm3.coreLoop; - swizzle_cnt = mlaParams.mm3.swizzleCount; - num_core = mlaParams.mm3.blockDim; + MM_mode = mode; + if (mode == 0) { + batch_size = mlaParams.mm1.numBatch; + m = mlaParams.mm1.m; + k = mlaParams.mm1.k; + n = mlaParams.mm1.n; + m0 = mlaParams.mm1.m0; + k0 = mlaParams.mm1.k0; + n0 = mlaParams.mm1.n0; + tdim.m = mlaParams.mm1.mLoop; + tdim.k = mlaParams.mm1.kLoop; + tdim.n = mlaParams.mm1.nLoop; + core_loop = mlaParams.mm1.coreLoop; + swizzle_cnt = mlaParams.mm1.swizzleCount; + num_core = mlaParams.mm1.blockDim; + } else if (mode == 1) { + batch_size = mlaParams.mm2.numBatch; + m = mlaParams.mm2.m; + k = mlaParams.mm2.k; + n = mlaParams.mm2.n; + m0 = mlaParams.mm2.m0; + k0 = mlaParams.mm2.k0; + n0 = mlaParams.mm2.n0; + tdim.m = mlaParams.mm2.mLoop; + tdim.k = mlaParams.mm2.kLoop; + tdim.n = mlaParams.mm2.nLoop; + core_loop = mlaParams.mm2.coreLoop; + swizzle_cnt = mlaParams.mm2.swizzleCount; + num_core = mlaParams.mm2.blockDim; + } else { // mode == 3 + batch_size = mlaParams.mm3.numBatch; + m = mlaParams.mm3.m; + k = mlaParams.mm3.k; + n = mlaParams.mm3.n; + m0 = mlaParams.mm3.m0; + 
k0 = mlaParams.mm3.k0; + n0 = mlaParams.mm3.n0; + tdim.m = mlaParams.mm3.mLoop; + tdim.k = mlaParams.mm3.kLoop; + tdim.n = mlaParams.mm3.nLoop; + core_loop = mlaParams.mm3.coreLoop; + swizzle_cnt = mlaParams.mm3.swizzleCount; + num_core = mlaParams.mm3.blockDim; + } core_idx = AscendC::GetBlockIdx(); ping_flag = 1; @@ -946,16 +981,16 @@ __aicore__ __force_inline__ void PpMatmulEinSum(gmC)); AsdopsBuffer buf; - l1_base_a = buf.template GetBuffer(0); - l1_base_b = buf.template GetBuffer(RoundUp(m0 * k0 * sizeof(InDtype))); - l0a_base = buf.template GetBuffer(0); - l0b_base = buf.template GetBuffer(0); + l1_base_a = buf.template GetBuffer(0); + l1_base_b = buf.template GetBuffer(RoundUp(m0 * k0 * sizeof(InDtype))); + l0a_base = buf.template GetBuffer(0); + l0b_base = buf.template GetBuffer(0); #endif return; } -template -__aicore__ __force_inline__ void PpMatmulEinSum::GetBaseBlockIdx(uint64_t index, MatCoord &tidx) +template +__aicore__ __force_inline__ void PpMatmulEinSum::GetBaseBlockIdx(uint64_t index, MatCoord &tidx) { uint64_t in_batch_idx = index % (tdim.m * tdim.n); if constexpr (swizzleDirect == 0) { // Zn @@ -990,8 +1025,8 @@ __aicore__ __force_inline__ void PpMatmulEinSum -__aicore__ __force_inline__ void PpMatmulEinSum::PreloadB() +template +__aicore__ __force_inline__ void PpMatmulEinSum::PreloadB() { #ifdef __DAV_C220_CUBE__ uint64_t batch_idx = core_idx / tdim.n / tdim.m; @@ -1009,8 +1044,8 @@ __aicore__ __force_inline__ void PpMatmulEinSum -__aicore__ __force_inline__ uint64_t PpMatmulEinSum::GetOffsetB( +template +__aicore__ __force_inline__ uint64_t PpMatmulEinSum::GetOffsetB( const uint64_t batchIdx, const uint64_t kIdx, const uint64_t nIdx) { if constexpr (formatB == DataFormat::ND) { @@ -1030,8 +1065,8 @@ __aicore__ __force_inline__ uint64_t PpMatmulEinSum -__aicore__ __force_inline__ void PpMatmulEinSum::CopyTileA( +template +__aicore__ __force_inline__ void PpMatmulEinSum::CopyTileA( AscendC::LocalTensor &dstTensor, const AscendC::GlobalTensor &srcTensor, const uint64_t m_actual, @@ -1049,19 +1084,30 @@ __aicore__ __force_inline__ void PpMatmulEinSum(dstTensor, // dst - srcTensor, // src - m_actual, // nTileActual - m_round, // nTileCeil - m, // nVal - k_actual, // dTileActual - k_round, // dTileCeil - (k + splitGapA) * batch_size); // dVal + if constexpr (ein) { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + m_actual, // nTileActual + m_round, // nTileCeil + m, // nVal + k_actual, // dTileActual + k_round, // dTileCeil + (k + splitGapA) * batch_size); // dVal + } else { + CopyGmToCbuf(dstTensor, // dst + srcTensor, // src + m_actual, // nTileActual + m_round, // nTileCeil + m, // nVal + k_actual, // dTileActual + k_round, // dTileCeil + k); // dVal + } } } -template -__aicore__ __force_inline__ void PpMatmulEinSum::CopyTileB( +template +__aicore__ __force_inline__ void PpMatmulEinSum::CopyTileB( AscendC::LocalTensor &dstTensor, const AscendC::GlobalTensor &srcTensor, const uint64_t k_actual, @@ -1112,13 +1158,11 @@ __aicore__ __force_inline__ void PpMatmulEinSum -__aicore__ __force_inline__ void PpMatmulEinSum::Process() +template +__aicore__ __force_inline__ void PpMatmulEinSum::Process() { #ifdef __DAV_C220_CUBE__ if (block_idx >= num_core) { - WaitFlagDev(MM2OUT); - AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(BMM3SPLIT); return; } using LocalTensor = AscendC::LocalTensor; @@ -1136,7 +1180,12 @@ __aicore__ __force_inline__ void PpMatmulEinSum(m_actual); @@ -1144,7 +1193,11 @@ __aicore__ __force_inline__ void PpMatmulEinSum n_round ? 
m_round : n_round; uint64_t k_part_len = L0_PINGPONG_BUFFER_LEN / mn_max / CONST_16 * CONST_16; uint64_t shuffle_k = en_shuffle_k ? (core_idx % tdim.k) : 0; - offset_a = tidx.m * m0 * batch_size * (k + splitGapA) + batch_idx * (k + splitGapA) + shuffle_k * k0; + if constexpr (ein) { + offset_a = tidx.m * m0 * batch_size * (k + splitGapA) + batch_idx * (k + splitGapA) + shuffle_k * k0; + } else { + offset_a = batch_idx * m * k + shuffle_k * k0 + tidx.m * m0 * k; + } uint64_t k_actual = (shuffle_k == tdim.k - 1) ? k - shuffle_k * k0 : k0; uint64_t k_round = (k_actual + CONST_16 - 1) / CONST_16 * CONST_16; @@ -1153,9 +1206,6 @@ __aicore__ __force_inline__ void PpMatmulEinSum(BMM3SPLIT); - // Copy A from gm to l1 buffer WAIT_FLAG(MTE1, MTE2, event_id); CopyTileA(l1_buf_a, gm_a[offset_a], m_actual, m_round, k_actual, k_round); @@ -1177,8 +1227,13 @@ __aicore__ __force_inline__ void PpMatmulEinSum; if (core_idx >= core_num) { - if (MM1_MM2_mode == 0) { - WaitFlagDev(MM1); - } else if (MM1_MM2_mode == 1) { - WaitFlagDev(MM2QUANT); - } return; } SET_FLAG(MTE1, MTE2, EVENT_ID0); @@ -1789,15 +1852,6 @@ __aicore__ __force_inline__ void PpMatmulW8a8 class MLAOperation { - using qOutDtype = typename std::conditional_t; - using kNopeDtype = typename std::conditional_t; + using qOutDtype = typename std::conditional_t; + using kNopeDtype = typename std::conditional_t; public: __aicore__ inline MLAOperation(const AtbOps::MlaTilingData &mlaParams_, GM_ADDR tilingGm) { @@ -2088,49 +2142,49 @@ public: GM_ADDR s2Gm, GM_ADDR s3Gm) { - s1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(s1Gm)); - wdqkvGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(wdqkvGm)); + s1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ __bf16 *>(s1Gm)); + wdqkvGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ __bf16 *>(wdqkvGm)); bias1gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias1Gm)); descale1gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ uint64_t *>(descale1Gm)); - s3GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(s3Gm)); + s3GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ __bf16 *>(s3Gm)); #ifdef __DAV_C220_CUBE__ - mm_w8a8_1.Init(s1GmTensor, wdqkvGmTensor, bias1gmTensor, descale1gmTensor, s3GmTensor, mlaParams, 0); - mm_w8a8_1.PreloadDoubleWeight(); + mm_1.Init(hiddenStateGm, wdqkvGm, s3Gm, mlaParams, 0); + mm_1.PreloadB(); #endif hiddenStateGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(hiddenStateGm)); gamma1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(gamma1Gm)); quantScale1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(quantScale1Gm)); quantOffset1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOffset1Gm)); - gamma2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(gamma2Gm)); + gamma2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ __bf16 *>(gamma2Gm)); quantScale2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(quantScale2Gm)); quantScale3GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(gmCtkvScale)); quantOffset2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOffset2Gm)); - gamma3GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(gamma3Gm)); - sin1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(sin1Gm)); - cos1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(cos1Gm)); - sin2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(sin2Gm)); - cos2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(cos2Gm)); + gamma3GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ 
__bf16 *>(gamma3Gm)); + sin1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ __bf16 *>(sin1Gm)); + cos1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ __bf16 *>(cos1Gm)); + sin2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ __bf16 *>(sin2Gm)); + cos2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ __bf16 *>(cos2Gm)); keycacheGmTensor1.SetGlobalBuffer(reinterpret_cast<__gm__ kNopeDtype *>(keycacheOutGm)); - keycacheGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(keycacheOutGm2)); + keycacheGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ __bf16 *>(keycacheOutGm2)); slotMappingGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(slotMappingGm)); - wuqGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(wuqGm)); - wukGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(wukGm)); + wuqGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ __bf16 *>(wuqGm)); + wukGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ __bf16 *>(wukGm)); descale2gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ uint64_t *>(descale2Gm)); - s2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(s2Gm)); + s2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ __bf16 *>(s2Gm)); qGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ qOutDtype *>(qGm)); - qGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(qGm2)); + qGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ __bf16 *>(qGm2)); bias2gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias2Gm)); beta1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(beta1Gm)); beta2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ half *>(beta2Gm)); #ifdef __DAV_C220_CUBE__ - mm_w8a8_2.Init(s1GmTensor, wuqGmTensor, bias2gmTensor, descale2gmTensor, s2GmTensor, mlaParams, 1); + mm_2.Init(s1Gm, wuqGm, s2Gm, mlaParams, 1); if constexpr (cacheMode == CACHE_MODE_INT8_NZCACHE){ - mm_ein_sum.Init(s2Gm, wukGm, s1Gm, mlaParams); + mm_ein_sum.Init(s2Gm, wukGm, s1Gm, mlaParams, 2); }else{ - mm_ein_sum.Init(s2Gm, wukGm, qGm, mlaParams); + mm_ein_sum.Init(s2Gm, wukGm, qGm, mlaParams, 2); } #endif @@ -2147,19 +2201,19 @@ public: row_work_ = 0; } this->splitN = mlaParams.perTaskNum; - rmsNormQuant1.Init(gamma1GmTensor, - beta1GmTensor, - quantScale1GmTensor, - quantOffset1GmTensor, - hiddenStateGmTensor, - s1GmTensor, - 0, - num_col_1, - 0.0001395089285, - vectorBlockIdx * static_cast(row_work) * num_col_1, - vectorBlockIdx * static_cast(row_work) * num_col_1, - row_work_, - mlaParams); + // rmsNormQuant1.Init(gamma1GmTensor, + // beta1GmTensor, + // quantScale1GmTensor, + // quantOffset1GmTensor, + // hiddenStateGmTensor, + // s1GmTensor, + // 0, + // num_col_1, + // 0.0001395089285, + // vectorBlockIdx * static_cast(row_work) * num_col_1, + // vectorBlockIdx * static_cast(row_work) * num_col_1, + // row_work_, + // mlaParams); rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, @@ -2175,7 +2229,7 @@ public: row_work_, mlaParams); ropeFp16.RopeInit(s2GmTensor, cos2GmTensor, sin2GmTensor, qGmTensor, qGmTensor2, mlaParams); - einSumQuant.Init(s1Gm, gmQnopeScale, qGm, mlaParams); + // einSumQuant.Init(s1Gm, gmQnopeScale, qGm, mlaParams); ubTensor = buf.GetBuffer(0); ub8Tensor = buf.GetBuffer(0); ub32Tensor = buf.GetBuffer(0); @@ -2315,7 +2369,7 @@ private: AscendC::PipeBarrier(); Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, - AscendC::RoundMode::CAST_NONE, + AscendC::RoundMode::CAST_RINT, SPLIT_RMSNRORM_SIZE_TWO); /* Rope K end */ // reshapeAndcache @@ -2401,27 +2455,27 @@ private: AscendC::GlobalTensor quantScale1GmTensor; AscendC::GlobalTensor 
quantOffset1GmTensor; - AscendC::GlobalTensor wdqkvGmTensor; - AscendC::GlobalTensor gamma2GmTensor; + AscendC::GlobalTensor<__bf16> wdqkvGmTensor; + AscendC::GlobalTensor<__bf16> gamma2GmTensor; AscendC::GlobalTensor quantScale2GmTensor; AscendC::GlobalTensor quantScale3GmTensor; AscendC::GlobalTensor quantOffset2GmTensor; - AscendC::GlobalTensor gamma3GmTensor; - AscendC::GlobalTensor sin1GmTensor; - AscendC::GlobalTensor cos1GmTensor; - AscendC::GlobalTensor sin2GmTensor; - AscendC::GlobalTensor cos2GmTensor; + AscendC::GlobalTensor<__bf16> gamma3GmTensor; + AscendC::GlobalTensor<__bf16> sin1GmTensor; + AscendC::GlobalTensor<__bf16> cos1GmTensor; + AscendC::GlobalTensor<__bf16> sin2GmTensor; + AscendC::GlobalTensor<__bf16> cos2GmTensor; AscendC::GlobalTensor keycacheGmTensor1; - AscendC::GlobalTensor keycacheGmTensor2; + AscendC::GlobalTensor<__bf16> keycacheGmTensor2; AscendC::GlobalTensor slotMappingGmTensor; - AscendC::GlobalTensor wuqGmTensor; - AscendC::GlobalTensor wukGmTensor; + AscendC::GlobalTensor<__bf16> wuqGmTensor; + AscendC::GlobalTensor<__bf16> wukGmTensor; AscendC::GlobalTensor qGmTensor; - AscendC::GlobalTensor qGmTensor2; - AscendC::GlobalTensor s1GmTensor; - AscendC::GlobalTensor s2GmTensor; - AscendC::GlobalTensor s3GmTensor; + AscendC::GlobalTensor<__bf16> qGmTensor2; + AscendC::GlobalTensor<__bf16> s1GmTensor; + AscendC::GlobalTensor<__bf16> s2GmTensor; + AscendC::GlobalTensor<__bf16> s3GmTensor; AscendC::GlobalTensor descale1gmTensor; AscendC::GlobalTensor descale2gmTensor; AscendC::GlobalTensor beta1GmTensor; @@ -2431,17 +2485,19 @@ private: AscendC::GlobalTensor bias2gmTensor; #ifdef __DAV_C220_CUBE__ - PpMatmulW8a8 mm_w8a8_1; - PpMatmulW8a8 mm_w8a8_2; + // PpMatmulW8a8 mm_w8a8_1; + // PpMatmulW8a8 mm_w8a8_2; + PpMatmulEinSum mm_1; + PpMatmulEinSum mm_2; static constexpr uint64_t splitGapC = cacheMode == CACHE_MODE_KVCACHE ? 
CONST_64 : CONST_0; - PpMatmulEinSum mm_ein_sum; + PpMatmulEinSum mm_ein_sum; #endif #ifdef __DAV_C220_VEC__ - RmsNormQuant rmsNormQuant1; - RmsNormQuant rmsNormQuant2; - RopeFp16 ropeFp16; - EinSumQuant einSumQuant; + // RmsNormQuant<__bf16, true, false> rmsNormQuant1; + RmsNormQuant<__bf16, true, false> rmsNormQuant2; + RopeFp16<__bf16, __bf16, qOutDtype, cacheMode> ropeFp16; + // EinSumQuant einSumQuant; #endif }; @@ -2449,15 +2505,22 @@ template::ProcessCube() { #ifdef __DAV_C220_CUBE__ - mm_w8a8_1.Process(); + // mm_w8a8_1.Process(); + // WaitFlagDev(MM1); + mm_1.Process(); FftsCrossCoreSync(RMSNORMQUANT2); WaitFlagDev(RMSNORMQUANT2); AscendC::CrossCoreSetFlag<0x2, PIPE_FIX>(MM1QUANT); - mm_w8a8_2.PreloadDoubleWeight(); - mm_w8a8_2.Process(); + // mm_w8a8_2.PreloadDoubleWeight(); + // mm_w8a8_2.Process(); + mm_2.PreloadB(); + WaitFlagDev(MM2QUANT); + mm_2.Process(); FftsCrossCoreSync(MM2OUT); mm_ein_sum.PreloadB(); + WaitFlagDev(MM2OUT); + FftsCrossCoreSync(BMM3SPLIT); mm_ein_sum.Process(); if constexpr (cacheMode == CACHE_MODE_INT8_NZCACHE) { FftsCrossCoreSync(EINSUMOUT); @@ -2471,6 +2534,7 @@ template::ProcessVector() { #ifdef __DAV_C220_VEC__ + /* if (row_work_ != 0) { uint32_t num_col_align_int8 = (num_col_1 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; uint32_t num_col_align_f16 = (num_col_1 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; @@ -2502,14 +2566,15 @@ __aicore__ inline void MLAOperation(RMSNORMQUANT1); WaitFlagDev(RMSNORMQUANT1); AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(MM1); + */ WaitFlagDev(MM1QUANT); if (row_work_ != 0) { uint32_t num_col_align_int8 = (num_col_2 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256; uint32_t num_col_align_f16 = (num_col_2 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128; uint32_t num_col_align_f32 = (num_col_2 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64; - AscendC::LocalTensor input_tensor = buf.GetBuffer(0); - AscendC::LocalTensor gamma_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2); + AscendC::LocalTensor<__bf16> input_tensor = buf.GetBuffer(0); + AscendC::LocalTensor<__bf16> gamma_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2); AscendC::LocalTensor beta_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2); AscendC::LocalTensor scale_tensor = @@ -2537,11 +2602,11 @@ __aicore__ inline void MLAOperation(MM2QUANT); if (row_work_ != 0) { - AscendC::LocalTensor input_tensor = buf.GetBuffer(0); - AscendC::LocalTensor gamma_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2); - AscendC::LocalTensor sin_tensor = - buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2); - AscendC::LocalTensor cos_tensor = buf.GetBuffer( + AscendC::LocalTensor<__bf16> input_tensor = buf.GetBuffer(0); + AscendC::LocalTensor<__bf16> gamma_tensor = buf.GetBuffer(MM1_OUT_SIZE * 2); + AscendC::LocalTensor<__bf16> sin_tensor = + buf.GetBuffer(MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2); + AscendC::LocalTensor<__bf16> cos_tensor = buf.GetBuffer( MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 2); AscendC::LocalTensor slotMapping_tensor = buf.GetBuffer( MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4); @@ -2550,28 +2615,28 @@ __aicore__ inline void MLAOperation temp_tensor = buf.GetBuffer(out_ub_offset); + AscendC::LocalTensor<__bf16> temp_tensor = buf.GetBuffer(out_ub_offset); AscendC::LocalTensor tmpfp16; AscendC::LocalTensor int8OutTensor; float scale3 = 0; - if constexpr (cacheMode == CACHE_MODE_INT8_NZCACHE) { - //quantScale3 - AscendC::LocalTensor 
quantScaleTensor = buf.GetBuffer(rms3_ub_offset); - AscendC::LocalTensor floatQuantScaleTensor = buf.GetBuffer(rms3_ub_offset + 32); - //int8out - tmpfp16 = buf.GetBuffer(rms3_ub_offset + SPLIT_RMSNRORM_SIZE_ONE * sizeof(float) * 2); - int8OutTensor = buf.GetBuffer(out_ub_offset); - AscendC::DataCopy(quantScaleTensor, quantScale3GmTensor, AscendC::DataCopyParams(1, 1, 0, 0)); - SET_FLAG(MTE2, V, EVENT_ID1); - WAIT_FLAG(MTE2, V, EVENT_ID1); - Cast(floatQuantScaleTensor, quantScaleTensor, AscendC::RoundMode::CAST_NONE, 1); - AscendC::SetFlag(EVENT_ID1); - AscendC::WaitFlag(EVENT_ID1); - scale3 = 1 / (float)(floatQuantScaleTensor.GetValue(0)); - } - - RmsNormAndRopeConvergence1( + // if constexpr (cacheMode == CACHE_MODE_INT8_NZCACHE) { + // //quantScale3 + // AscendC::LocalTensor quantScaleTensor = buf.GetBuffer(rms3_ub_offset); + // AscendC::LocalTensor floatQuantScaleTensor = buf.GetBuffer(rms3_ub_offset + 32); + // //int8out + // tmpfp16 = buf.GetBuffer(rms3_ub_offset + SPLIT_RMSNRORM_SIZE_ONE * sizeof(float) * 2); + // int8OutTensor = buf.GetBuffer(out_ub_offset); + // AscendC::DataCopy(quantScaleTensor, quantScale3GmTensor, AscendC::DataCopyParams(1, 1, 0, 0)); + // SET_FLAG(MTE2, V, EVENT_ID1); + // WAIT_FLAG(MTE2, V, EVENT_ID1); + // Cast(floatQuantScaleTensor, quantScaleTensor, AscendC::RoundMode::CAST_NONE, 1); + // AscendC::SetFlag(EVENT_ID1); + // AscendC::WaitFlag(EVENT_ID1); + // scale3 = 1 / (float)(floatQuantScaleTensor.GetValue(0)); + // } + + RmsNormAndRopeConvergence1<__bf16>( input_tensor, // n * 576 gamma_tensor, // gamma sin_tensor, // sin @@ -2592,11 +2657,11 @@ __aicore__ inline void MLAOperation op(mlaTilingData, tilingGm); + op.Init(hiddenStateGm, gamma1Gm, beta1Gm, quantScale1Gm, quantOffset1Gm, wdqkvGm, bias1Gm, gamma2Gm, beta2Gm, + quantScale2Gm, quantOffset2Gm, gamma3Gm, sin1Gm, cos1Gm, sin2Gm, cos2Gm, keycacheGm, slotMappingGm, + wuqGm, bias2Gm, wukGm, descale1Gm, descale2Gm, gmCtkvScale, gmQnopeScale, + qGm, keycacheOutGm, qGm2, keycacheOutGm2, s1Gm, s2Gm, s3Gm); + op.ProcessCube(); + op.ProcessVector(); + } else if (TILING_KEY_IS(152)) { + MLAOperation op(mlaTilingData, tilingGm); + op.Init(hiddenStateGm, gamma1Gm, beta1Gm, quantScale1Gm, quantOffset1Gm, wdqkvGm, bias1Gm, gamma2Gm, beta2Gm, + quantScale2Gm, quantOffset2Gm, gamma3Gm, sin1Gm, cos1Gm, sin2Gm, cos2Gm, keycacheGm, slotMappingGm, + wuqGm, bias2Gm, wukGm, descale1Gm, descale2Gm, gmCtkvScale, gmQnopeScale, + qGm, keycacheOutGm, qGm2, keycacheOutGm2, s1Gm, s2Gm, s3Gm); + op.ProcessCube(); + op.ProcessVector(); } -} \ No newline at end of file +} diff --git a/src/kernels/mixkernels/mla_preprocess/tiling/mla_preprocess_tiling.cpp b/src/kernels/mixkernels/mla_preprocess/tiling/mla_preprocess_tiling.cpp index 9fe26640a189f3dac3169b9df9a863fbc214789c..8f7e6d19359d39170dd5ec1d2826f9c761cf9710 100644 --- a/src/kernels/mixkernels/mla_preprocess/tiling/mla_preprocess_tiling.cpp +++ b/src/kernels/mixkernels/mla_preprocess/tiling/mla_preprocess_tiling.cpp @@ -8,6 +8,7 @@ * See LICENSE in the root of the software repository for the full text of the License. 
*/ +#include #include "mla_preprocess_tiling.h" #include "atbops/params/params.h" #include "mki/utils/assert/assert.h" @@ -33,7 +34,7 @@ constexpr uint32_t L1_SCALE_SIZE = 4096; constexpr uint32_t L1_BIAS_SIZE = 2048; constexpr uint32_t L0C_SIZE = 128 * 1024; constexpr uint32_t CONCAT_SIZE = 512; -constexpr uint32_t HIDDEN_STRATE = 7168; +constexpr uint32_t HIDDEN_STRATE = 6144; constexpr uint32_t HIDDEN_STRATE_ROPE = 192; constexpr uint32_t HIDDEN_STRATE_MM = 2112; constexpr uint32_t HIDDEN_STRATE_RMS = 1536; @@ -147,6 +148,7 @@ void PpMatmulTilingApi::GetTilingData(AtbOps::PpMatmulTilingData &tiling) tiling.swizzleDirect = swizzleDirect_; tiling.enShuffleK = static_cast(enShuffleK_); tiling.blockDim = blockDim_; + std::cout << "12345678: " << blockDim_ << std::endl; tiling.enLoadAllAmat = static_cast(enLoadAllAmat_); tiling.b0matPingPongBufferLen = b0matPingPongBufferLen_; } @@ -403,13 +405,12 @@ void MlaPreprocessTiling::EinSumQuantTiling(const OpParam::MlaPreprocess ¶m, uint32_t repeatMask = 0; if (inDtype == TENSOR_DTYPE_BF16 || param.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { - // 将scale一次性搬入、广播、缓存 H * 32bytes - uint32_t scaleUb = RoundUp(esqHeadNum) * CONST_32; - // bf16 input [H', colNum](f16 + fp32 + int8), ub reuse - splitFactor = esqColNum * (sizeof(uint16_t) + sizeof(float) + sizeof(uint8_t)); - splitFactor *= NUM2; - esqHeadPerLoop = (ubSize - scaleUb) / splitFactor; // 26 - repeatMask = FP32_REPEAT_MASK; + // fp16 input [H', cloNum](fp16*2 + int8) + [H', 1](fp16) + [H', 16](fp16) + splitFactor = + esqColNum * (NUM2 * sizeof(uint16_t) + sizeof(uint8_t)) + sizeof(uint16_t) + (CONST_16 * sizeof(uint16_t)); + esqHeadPerLoop = ubSize / splitFactor; + repeatMask = FP16_REPEAT_MASK; + esqHeadPerLoop = RoundDown(esqHeadPerLoop); // 向下16对齐 } else { // fp16 input [H', cloNum](fp16*2 + int8) + [H', 1](fp16) + [H', 16](fp16) splitFactor = @@ -566,7 +567,7 @@ void MlaPreprocessTiling::SetMlapoWorkSpace(const TensorDType inDtype, const OpP uint64_t workSizeS3 = static_cast(tilingData.n) * HIDDEN_STRATE_MM * sizeof(uint16_t); uint64_t workSizeS4 = static_cast(tilingData.n) * std::max(param.headNum * HIDDEN_STRATE_ROPE, HIDDEN_STRATE_MM) * sizeof(uint32_t); - uint64_t pertokenWorkspace = static_cast(tilingData.n) * sizeof(float) * 2; + // uint64_t pertokenWorkspace = static_cast(tilingData.n) * sizeof(float) * 2; uint64_t maxWorkspaceSize = 0; maxWorkspaceSize = std::max(maxWorkspaceSize, workSizeS1); maxWorkspaceSize = std::max(maxWorkspaceSize, workSizeS2); @@ -574,7 +575,7 @@ void MlaPreprocessTiling::SetMlapoWorkSpace(const TensorDType inDtype, const OpP maxWorkspaceSize = std::max(maxWorkspaceSize, workSizeS4); if (inDtype == TENSOR_DTYPE_BF16 || param.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { kernelInfo.GetScratchSizes() = { - maxWorkspaceSize, maxWorkspaceSize, maxWorkspaceSize, maxWorkspaceSize, pertokenWorkspace}; + maxWorkspaceSize, maxWorkspaceSize, maxWorkspaceSize}; } else { kernelInfo.GetScratchSizes() = { maxWorkspaceSize, maxWorkspaceSize, maxWorkspaceSize}; @@ -589,9 +590,9 @@ Mki::Status MlaPreprocessTiling::Init(const Mki::LaunchParam &launchParam, Mki:: tilingData.n = param.N; tilingData.numCore = aicNum; bool deqOnTheFly = false; - if (inDtype == TENSOR_DTYPE_BF16 || param.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { - deqOnTheFly = true; - } + // if (inDtype == TENSOR_DTYPE_BF16 || param.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) { + // deqOnTheFly = true; + // } MKI_LOG(INFO) << "tilingSize: " << kernelInfo.GetTilingSize(); auto tilingParam = 
reinterpret_cast(kernelInfo.GetTilingHostAddr()); RmsNormQuantTiling(param.N); @@ -603,7 +604,7 @@ Mki::Status MlaPreprocessTiling::Init(const Mki::LaunchParam &launchParam, Mki:: HIDDEN_STRATE_MM, // n false, // transA true, // transB - true, // enDequant + false, // enDequant deqOnTheFly); // in bf16.cce? mm1TilingApi.GetTilingData(tilingParam->mm1); PpMatmulTilingApi mm2TilingApi(1, // numBatch @@ -612,7 +613,7 @@ Mki::Status MlaPreprocessTiling::Init(const Mki::LaunchParam &launchParam, Mki:: param.headNum * HIDDEN_STRATE_ROPE, // n false, // transA true, // transB - true, // enDequant + false, // enDequant deqOnTheFly); // in bf16.cce? mm2TilingApi.GetTilingData(tilingParam->mm2); PpMatmulTilingApi mm3TilingApi(param.headNum, // numBatch diff --git a/src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp b/src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp index 090d740b31d09d9b7dc81e3cf80485e7ff73650d..d454b18c9d6db0aad9d3e5295e2dd876c318aca1 100644 --- a/src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp +++ b/src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp @@ -30,7 +30,7 @@ static const uint32_t KVCACHE_INDEX = 19; static const uint32_t KVCACHE_ROPE_INDEX = 20; static const uint32_t QOUT1_INDEX = 2; static const uint32_t KVCACHEOUT1_INDEX = 3; -static const uint32_t INNER_DIM_7168 = 7168; +static const uint32_t INNER_DIM_7168 = 6144; static const uint32_t INNER_DIM_2112 = 2112; static const uint32_t INNER_DIM_1536 = 1536; static const uint32_t INNER_DIM_576 = 576; diff --git a/tests/apitest/kernelstest/mix/test_mla_preprocess.py b/tests/apitest/kernelstest/mix/test_mla_preprocess.py index ea5a5ef5efd09639290e267112257f80771c1933..0165ef0067cc923fb7ad9269bfbd70ec7da4c2dc 100644 --- a/tests/apitest/kernelstest/mix/test_mla_preprocess.py +++ b/tests/apitest/kernelstest/mix/test_mla_preprocess.py @@ -12,8 +12,7 @@ import numpy as np import torch import torch_npu import torch.nn.functional as F -import sys,os -sys.path.append(os.path.join(os.path.dirname(__file__), "../")) +import sys, os import op_test import random import logging @@ -22,6 +21,7 @@ sys.path.append("../") sys.path.append("../..") from precision_calcu import * +torch.set_printoptions(threshold=float('inf')) OP_NAME = "MlaPreprocessOperation" QUANTMAX = 127 QUANTMIN = -128 @@ -171,7 +171,7 @@ class TestMLAPrepross(op_test.OpTest): deScale = torch.ones((dim1), dtype=torch.float32) if self.dtype == torch.float16: deScale = process_deq_scale(deScale) - in_tensors = [intensors[0], intensors[1], bias, deScale, torch.Tensor()] + in_tensors = [intensors[0], intensors[1], bias, deScale] self.set_param( "MatMulOperation", { @@ -184,7 +184,9 @@ class TestMLAPrepross(op_test.OpTest): ) in_tensors_npu = [tensor.npu() for tensor in in_tensors] out_tensors_npu = [out_tensors[i] if isinstance(i, int) else i.npu() for i in outtensors] + self.__set_envs({"ASDOPS_MATMUL_PP_FLAG": "1"}) self.mki.execute(in_tensors_npu, out_tensors_npu) + self.__unset_envs({"ASDOPS_MATMUL_PP_FLAG": "1"}) mm1Out1 = out_tensors_npu[0].to(torch.int32).to(torch.float32) # Mul perChannelDescale self.mm1_mul_descale_out = torch.zeros((self.input_token_num, dim1), dtype=torch.float32).npu() @@ -240,18 +242,21 @@ class TestMLAPrepross(op_test.OpTest): beta: torch.Tensor, quantScale: torch.Tensor, quantOffset: torch.Tensor, + normNeeded: bool ): + if not normNeeded: + return input.to(torch.bfloat16) out_shape = input.shape scale = 1.0 / quantScale.float().item() offset = quantOffset.float().item() square_sum = 
torch.sum(torch.square(input.float()), axis=-1, keepdims=True) factor = 1.0 / torch.sqrt(square_sum / out_shape[-1] + self.epsilon) output = input.float() * factor * gamma.float() - output = (output + beta.float()) * scale + offset - output = torch.round(output).half() - output = torch.min(output, torch.tensor(QUANTMAX, dtype=torch.half)) - output = torch.max(output, torch.tensor(QUANTMIN, dtype=torch.half)).to(torch.int8) - return output + # output = (output + beta.float()) * scale + offset + # output = torch.round(output).half() + # output = torch.min(output, torch.tensor(QUANTMAX, dtype=torch.half)) + # output = torch.max(output, torch.tensor(QUANTMIN, dtype=torch.half)).to(torch.int8) + return output.to(torch.bfloat16) def ein_sum_out_quant_golden(self, input, scale): if (scale.dtype == torch.bfloat16): @@ -264,7 +269,7 @@ class TestMLAPrepross(op_test.OpTest): return output.to(torch.int8) def calc_vec_mm_atb_data(self, N, headNum, data_type, cacheMode, quant_mode=0): - hiddenStrate = 7168 + hiddenStrate = 6144 blockNum = 192 blockSize = 128 headdim = 576 @@ -275,19 +280,17 @@ class TestMLAPrepross(op_test.OpTest): self.epsilon = 1e-6 self.dtype = data_type - self.input1 = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(N, 7168))).to(data_type) # + self.input1 = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(N, 6144))).to(data_type) # self.gamma1 = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(hiddenStrate))).to(data_type) self.quantScale1 = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(1))).to(data_type) self.quantOffset1 = torch.from_numpy(np.random.uniform(-128.0, 127.0, size=(1))).to(torch.int8) - self.wdqkv = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(2112, 7168))).to(torch.int8) # - + self.wdqkv = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(2112, 6144))).to(data_type) # torch.int8 self.deScale1 = torch.rand((2112), dtype=torch.float32) / 1000 self.gamma2 = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(1536))).to(data_type) self.quantScale2 = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(1))).to(data_type) self.quantOffset2 = torch.from_numpy(np.random.uniform(-128.0, 127.0, size=(1))).to(torch.int8) - self.wuq = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(headNum * 192, 1536))).to(torch.int8) # - + self.wuq = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(headNum * 192, 1536))).to(data_type) #torch.int8 self.deScale2 = torch.rand((headNum * 192), dtype=torch.float32) / 1000 self.gamma3 = torch.rand(size=(512,)).to(data_type) @@ -318,43 +321,47 @@ class TestMLAPrepross(op_test.OpTest): self.qNopeScale = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(1, headNum, 1))).to(data_type) self.quantScale3 = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(1))).to(data_type) - self.calc_vec_mm_data(N, headNum, data_type, quant_mode) + self.calc_vec_mm_data(N, headNum, data_type, quant_mode) #cpu 结果? 
+#--------------小算子融合---------------------- # RmsNorm - self.rms1Out1_npu = torch.zeros((N, 7168), dtype=torch.int8) - npu_device = self.__get_npu_device() - torch_npu.npu.set_device(npu_device) - in_tensors = [self.input1, self.gamma1, self.beta1, self.quantScale1, self.quantOffset1] - if quant_mode == 0: - out_tensors = [self.rms1Out1_npu] - OP_NAME = "NormOperation" - OP_PARAM0 = {"normType": 2, "inGamma": True, "inBeta": True, "epsilon": 1e-6} - self.set_param(OP_NAME, OP_PARAM0) - in_tensors_npu = [tensor.npu() for tensor in in_tensors] - out_tensors_npu = [out_tensors[i] if isinstance(i, int) else i.npu() for i in out_tensors] - self.mki.execute(in_tensors_npu, out_tensors_npu) - else: - out_tensors_npu = self.rmsNormPerTokenNpuGolden(in_tensors) - pertoken_descale = out_tensors_npu[1].unsqueeze(1).expand(-1, 2112).npu() - + # self.rms1Out1_npu = torch.zeros((N, 6144), dtype=torch.int8) + # npu_device = self.__get_npu_device() + # torch_npu.npu.set_device(npu_device) + # in_tensors = [self.input1, self.gamma1, self.beta1, self.quantScale1, self.quantOffset1] + # if quant_mode == 0: + # out_tensors = [self.rms1Out1_npu] + # OP_NAME = "NormOperation" + # OP_PARAM0 = {"normType": 2, "inGamma": True, "inBeta": True, "epsilon": 1e-6} + # self.set_param(OP_NAME, OP_PARAM0) + # in_tensors_npu = [tensor.npu() for tensor in in_tensors] + # out_tensors_npu = [out_tensors[i] if isinstance(i, int) else i.npu() for i in out_tensors] + # self.mki.execute(in_tensors_npu, out_tensors_npu) + # else: + # out_tensors_npu = self.rmsNormPerTokenNpuGolden(in_tensors) + # pertoken_descale = out_tensors_npu[1].unsqueeze(1).expand(-1, 2112).npu() + + # Ppmatmul if quant_mode == 0: self.mm1Out1_npu = torch.zeros((N, 2112), dtype=data_type) - in_tensors = [out_tensors_npu[0], self.wdqkv, self.bias1, self.deScale1, torch.Tensor()] + in_tensors = [self.input1, self.wdqkv] out_tensors = [self.mm1Out1_npu] self.set_param( "MatMulOperation", { "transposeA": False, "transposeB": True, - "withBias": True, - "enDequant": True, + "withBias": False, + "enDequant": False, "outDtype": 27 if data_type == torch.bfloat16 else 1, }, ) in_tensors_npu = [tensor.npu() for tensor in in_tensors] out_tensors_npu = [out_tensors[i] if isinstance(i, int) else i.npu() for i in out_tensors] + self.__set_envs({"ASDOPS_MATMUL_PP_FLAG": "1"}) self.mki.execute(in_tensors_npu, out_tensors_npu) + self.__unset_envs({"ASDOPS_MATMUL_PP_FLAG": "1"}) self.mm1Out1_npu = out_tensors_npu[0].cpu().clone() else: in_tensors = [out_tensors_npu[0], self.wdqkv, self.bias1, self.deScale1, pertoken_descale] @@ -377,17 +384,20 @@ class TestMLAPrepross(op_test.OpTest): self.mki.execute(in_tensors_npu, Split1_out_tensors_npu) # RmsNorm - in_tensors = [Split1_out_tensors_npu[2].cpu(), self.gamma2, self.beta2, self.quantScale2, self.quantOffset2] + in_tensors = [Split1_out_tensors_npu[2].cpu(), self.gamma2, self.quantScale2, self.quantOffset2] if quant_mode == 0: - self.rms2Out_npu = torch.zeros((N, 1536), dtype=torch.int8) - in_tensors = [Split1_out_tensors_npu[2], self.gamma2, self.beta2, self.quantScale2, self.quantOffset2] + #self.rms2Out_npu = torch.zeros((N, 1536), dtype=torch.int8) + self.rms2Out_npu = torch.zeros((N, 1536), dtype=data_type) + in_tensors = [Split1_out_tensors_npu[2], self.gamma2] out_tensors = [self.rms2Out_npu] OP_NAME = "NormOperation" - OP_PARAM0 = {"normType": 2, "inGamma": True, "inBeta": True, "epsilon": 1e-6} + OP_PARAM0 = {"normType": 2, "inGamma": True, "inBeta": False, "epsilon": 1e-6} #, "gemmaMode": 1 + #OP_PARAM0 = 
{"normType": 2, "inGamma": True, "epsilon": 1e-6} self.set_param(OP_NAME, OP_PARAM0) in_tensors_npu = [tensor.npu() for tensor in in_tensors] out_tensors_npu = [out_tensors[i] if isinstance(i, int) else i.npu() for i in out_tensors] self.mki.execute(in_tensors_npu, out_tensors_npu) + else: out_tensors_npu = self.rmsNormPerTokenNpuGolden(in_tensors) pertoken_descale = out_tensors_npu[1].unsqueeze(1).expand(-1, headNum * 192).npu() @@ -395,21 +405,23 @@ class TestMLAPrepross(op_test.OpTest): # Ppmatmul if quant_mode == 0: self.mm2Out_npu = torch.zeros((N, headNum * 192), dtype=data_type) - in_tensors = [out_tensors_npu[0], self.wuq, self.bias2, self.deScale2, torch.Tensor()] + in_tensors = [out_tensors_npu[0], self.wuq] out_tensors = [self.mm2Out_npu] self.set_param( "MatMulOperation", { "transposeA": False, "transposeB": True, - "withBias": True, - "enDequant": True, + "withBias": False, + "enDequant": False, "outDtype": 27 if data_type == torch.bfloat16 else 1, }, ) in_tensors_npu = [tensor.npu() for tensor in in_tensors] mm2out_tensors_npu = [out_tensors[i] if isinstance(i, int) else i.npu() for i in out_tensors] + self.__set_envs({"ASDOPS_MATMUL_PP_FLAG": "1"}) self.mki.execute(in_tensors_npu, mm2out_tensors_npu) + self.__unset_envs({"ASDOPS_MATMUL_PP_FLAG": "1"}) self.mm2Out_npu = mm2out_tensors_npu[0].cpu().clone() else: in_tensors = [out_tensors_npu[0], self.wuq, self.bias2, self.deScale2, pertoken_descale] @@ -450,7 +462,9 @@ class TestMLAPrepross(op_test.OpTest): Einsumout_tensors = [torch.zeros((N, headNum, 512), dtype=data_type)] in_tensors_npu = [tensor.npu() for tensor in in_tensors] Einsumout_tensors_npu = [tensor.npu() for tensor in Einsumout_tensors] + self.__set_envs({"ASDOPS_MATMUL_PP_FLAG": "1"}) self.mki.execute(in_tensors_npu, Einsumout_tensors_npu) + self.__unset_envs({"ASDOPS_MATMUL_PP_FLAG": "1"}) # Einsumout_tensors_npu = [torch.transpose(Einsumout_tensors_npu[0], 0, 1)] self.einsumOut = Einsumout_tensors_npu[0].cpu().clone() @@ -544,7 +558,6 @@ class TestMLAPrepross(op_test.OpTest): self.set_param(OP_NAME, OP_PARAM) self.mki.execute(in_tensors_npu, out_tensors_npu) self.keyout_npu = out_tensors_npu[0].cpu().clone() - def reshapeAndCache(self, input, keycache, slotMapping,num): keycache = keycache.reshape(-1, num) input = input.reshape(-1,num) @@ -583,16 +596,19 @@ class TestMLAPrepross(op_test.OpTest): return s8_res_cal def calc_vec_mm_data(self, N, headNum, data_type, quant_mode): - if quant_mode == 0: - mm1In = self.rms_norm_quant_calc(self.input1, self.gamma1, self.beta1, self.quantScale1, self.quantOffset1) - else: - [mm1In, perTokenDescale1] = self.rmsNormPerTokenGolden([self.input1, self.gamma1, self.beta1]) + # if quant_mode == 0: + # mm1In = self.rms_norm_quant_calc(self.input1, self.gamma1, self.beta1, self.quantScale1, self.quantOffset1, False) + # else: + # [mm1In, perTokenDescale1] = self.rmsNormPerTokenGolden([self.input1, self.gamma1, self.beta1]) + + # self.rmsquantOut1 = mm1In.clone() + mm1Out = torch.matmul(self.input1.to(torch.float32), self.wdqkv.transpose(0, 1).to(torch.float32)) + - self.rmsquantOut1 = mm1In.clone() - mm1Out = torch.matmul(mm1In.to(torch.float32), self.wdqkv.transpose(0, 1).to(torch.float32)) if quant_mode == 0: - mm1Out = mm1Out.to(torch.int32) + self.bias1 - mm1Out = (mm1Out.to(torch.float32) * self.deScale1).to(data_type) + # mm1Out = mm1Out.to(torch.int32) + self.bias1 + # mm1Out = (mm1Out.to(torch.float32) * self.deScale1).to(data_type) + mm1Out.to(data_type) else: perTokenDescale1 = perTokenDescale1.unsqueeze(1).expand(-1, 
2112) mm1Out = (mm1Out.to(torch.float32) * self.deScale1 * perTokenDescale1).to(data_type) @@ -602,15 +618,17 @@ class TestMLAPrepross(op_test.OpTest): mm1OutSplit1, mm1OutSplit2 = torch.split(mm1Out, [576, 1536], dim=1) if quant_mode == 0: rms2Out = self.rms_norm_quant_calc( - mm1OutSplit2, self.gamma2, self.beta2, self.quantScale2, self.quantOffset2 + mm1OutSplit2, self.gamma2, self.beta2, self.quantScale2, self.quantOffset2, True ) else: [rms2Out, perTokenDescale2] = self.rmsNormPerTokenGolden([mm1OutSplit2, self.gamma2, self.beta2]) self.rmsquantOut2 = rms2Out.clone() mm2Out = torch.matmul(rms2Out.to(torch.float32), self.wuq.transpose(0, 1).to(torch.float32)) if quant_mode == 0: - mm2Out = mm2Out.to(torch.int32) + self.bias2 - mm2Out = (mm2Out.to(torch.float32) * self.deScale2).to(data_type) + # mm2Out = mm2Out.to(torch.int32) + self.bias2 + # mm2Out = (mm2Out.to(torch.float32) * self.deScale2).to(data_type) + mm2Out.to(data_type) + else: perTokenDescale2 = perTokenDescale2.unsqueeze(1).expand(-1, headNum * 192) mm2Out = (mm2Out.to(torch.float32) * self.deScale2 * perTokenDescale2).to(data_type) @@ -633,8 +651,7 @@ class TestMLAPrepross(op_test.OpTest): ) self.gg = mm2OutSplit2.clone() qOut = self.RopeConcatGolden(mm2OutSplit2.reshape(N, headNum * 64), self.sin2, self.cos2, self.bmmOut) - self.qOut = qOut.to(data_type) - + self.qOut = qOut.to(data_type) #cpu的结果 def golden_calc(self, in_tensors): return [self.qOut, self.keyOut1] @@ -642,7 +659,7 @@ class TestMLAPrepross(op_test.OpTest): out = tensor1.flatten() out_len = out.shape[0] golden = tensor2.flatten() - diff = torch.abs(golden - out) + diff = torch.abs(golden - out) max_diff = diff.max().item() logging.info(f"maxDiff {max_diff}") golden = golden.to(torch.float32) @@ -663,10 +680,11 @@ class TestMLAPrepross(op_test.OpTest): self.compare_data(self.qOut_npu.flatten(), out_tensors[0].flatten()) result_double2 = compare_cv(self.keyout_npu, self.keyOut1, out_tensors[1]) or\ self.compare_data(self.keyout_npu.flatten(), out_tensors[1].flatten()) - return result_double1 and result_double2 + return result_double1 and result_double2 elif self.op_desc["specificParam"]["cacheMode"] == 1: result_double1 = compare_cv(self.qOut_npu[..., 0:512], self.qOut[..., 0:512], out_tensors[0]) or\ self.compare_data(self.qOut_npu[..., 0:512].flatten(), out_tensors[0].flatten()) + result_double2 = compare_cv(self.keyout_npu[..., 0:512], self.keyOut1[..., 0:512], out_tensors[1]) or\ self.compare_data(self.keyout_npu[..., 0:512].flatten(), out_tensors[1].flatten()) result_double3 = compare_cv(self.qOut_npu[..., 512:576], self.qOut[..., 512:576], out_tensors[2]) or\ @@ -675,15 +693,14 @@ class TestMLAPrepross(op_test.OpTest): self.compare_data(self.keyout_npu[..., 512:576].flatten(), out_tensors[3].flatten()) return result_double1 and result_double2 and result_double3 and result_double4 elif self.op_desc["specificParam"]["cacheMode"] == 2: - max_diff = torch.max(torch.abs(out_tensors[0].flatten() - self.qOutNopeQuant.flatten())) + diff = out_tensors[0].flatten() - self.qOutNopeQuant.flatten() + max_diff = torch.max(torch.abs(diff)) result_double1 = max_diff <= 1 - - max_diff = torch.max(torch.abs(self.keyCache1_out.flatten() - out_tensors[1].flatten())) - result_double2 = max_diff <= 1 - result_double3 = compare_cv(self.qOut_npu[..., 512:576], self.qOut[..., 512:576], out_tensors[2]) or\ self.compare_data(self.qOut_npu[..., 512:576].flatten(), out_tensors[2].flatten()) - + diff = self.keyCache1_out.flatten() - out_tensors[1].flatten() + max_diff = 
torch.max(torch.abs(diff)) + result_double2 = max_diff <= 1 result_double4 = self.compare_data(self.keyCache2_out.flatten(), out_tensors[3].flatten()) return result_double1 and result_double2 and result_double3 and result_double4 elif self.op_desc["specificParam"]["cacheMode"] == 3: @@ -703,7 +720,7 @@ class TestMLAPrepross(op_test.OpTest): num_heads: int, cache_mode: int, quant_mode: int, - weight_format: int, + weight_format: torch.dtype, #int 改成torch.dtype ) -> None: self.calc_vec_mm_atb_data(num_tokens, num_heads, data_type, cache_mode, quant_mode) self.set_param( @@ -745,7 +762,7 @@ class TestMLAPrepross(op_test.OpTest): self.beta1, self.quantScale1, self.quantOffset1, - transdata(self.wdqkv, (16, 32)), # self.wdqkv, + transdata(self.wdqkv, (16, 16)), # self.wdqkv, self.bias1, self.gamma2, self.beta2, @@ -758,7 +775,7 @@ class TestMLAPrepross(op_test.OpTest): self.cos2, self.keyCache, self.slotMapping, - transdata(self.wuq, (16, 32)), # self.wuq, + transdata(self.wuq, (16, 16)), # self.wuq, self.bias2, self.wuk if weight_format == self.format_nd else transdata_3d(self.wuk), self.deScale1, @@ -775,10 +792,20 @@ class TestMLAPrepross(op_test.OpTest): ] elif cache_mode == 1: out_tensors = [ + torch.zeros_like(self.qOut[..., :512], dtype=data_type), self.keyOutTensor[..., :512], torch.zeros_like(self.qOut[..., 512:], dtype=data_type), self.keyOutTensor[..., 512:], + #torch.zeros_like(self.qOut[..., :512], dtype=data_type), + #self.keyOutTensor[..., :512], + + #torch.zeros_like(self.keyOutTensor[..., :512], dtype=data_type), + #torch.zeros_like(self.qOut[..., 512:], dtype=data_type), + #self.keyOutTensor[..., 512:], + # torch.zeros((num_tokens, 2112)).to(data_type) + #torch.zeros_like(self.keyOutTensor[..., 512:], dtype=data_type), + #torch.ones_like(self.keyOutTensor[..., 512:], dtype=data_type), ] elif cache_mode == 2: out_tensors = [ @@ -808,13 +835,13 @@ class TestMLAPrepross(op_test.OpTest): self.__test_mlapo_impl(torch.float16, num_tokens, num_heads, cache_mode, quant_mode, weight_format) @op_test.only_910b - def test_mla_preprocess_fp16_cm1_qm0_nd_case1(self): + def test_mla_preprocess_fp16_cm1_qm0_nd_case1(self): #cx num_tokens = 32 num_heads = 32 cache_mode = 1 quant_mode = 0 weight_format = self.format_nd - self.__test_mlapo_impl(torch.float16, num_tokens, num_heads, cache_mode, quant_mode, weight_format) + self.__test_mlapo_impl(torch.bfloat16, num_tokens, num_heads, cache_mode, quant_mode, weight_format) @op_test.only_910b def test_mla_preprocess_fp16_cm2_qm0_nd_case1(self): @@ -877,10 +904,10 @@ class TestMLAPrepross(op_test.OpTest): cache_mode = 1 quant_mode = 0 weight_format = self.format_nd - self.__test_mlapo_impl(torch.bfloat16, num_tokens, num_heads, cache_mode, quant_mode, weight_format) + self.__test_mlapo_impl(torch.float16, num_tokens, num_heads, cache_mode, quant_mode, weight_format) @op_test.only_910b - def test_mla_preprocess_bf16_cm2_qm0_nd_case1(self): + def test_mla_preprocess_bf16_cm2_qm0_nd_case1(self): #wxh num_tokens = 32 num_heads = 32 cache_mode = 2 @@ -898,16 +925,16 @@ class TestMLAPrepross(op_test.OpTest): self.__test_mlapo_impl(torch.bfloat16, num_tokens, num_heads, cache_mode, quant_mode, weight_format) @op_test.only_910b - def test_mla_preprocess_bf16_cm1_qm0_nd_case2(self): - num_tokens = 1024 + def test_mla_preprocess_bf16_cm1_qm0_nd_case2(self): #test + num_tokens = 32 num_heads = 32 - cache_mode = 1 + cache_mode = 0 quant_mode = 0 weight_format = self.format_nd self.__test_mlapo_impl(torch.bfloat16, num_tokens, num_heads, cache_mode, quant_mode, 
weight_format) @op_test.only_910b - def test_mla_preprocess_bf16_cm2_qm0_nd_case2(self): + def test_mla_preprocess_bf16_cm2_qm0_nd_case2(self): #wxh cache_mode =2 和3的应该不测 num_tokens = 1024 num_heads = 32 cache_mode = 2 @@ -916,7 +943,7 @@ class TestMLAPrepross(op_test.OpTest): self.__test_mlapo_impl(torch.bfloat16, num_tokens, num_heads, cache_mode, quant_mode, weight_format) @op_test.only_910b - def test_mla_preprocess_bf16_cm2_qm0_nd_case3(self): + def test_mla_preprocess_bf16_cm2_qm0_nd_case3(self): #wxh num_tokens = 1 num_heads = 64 cache_mode = 2 @@ -997,7 +1024,7 @@ class TestMLAPrepross(op_test.OpTest): self.__test_mlapo_impl(torch.bfloat16, num_tokens, num_heads, cache_mode, quant_mode, weight_format) @op_test.only_910b - def test_mla_preprocess_bf16_cm3_qm0_nd_case1(self): + def test_mla_preprocess_bf16_cm3_qm0_nd_case1(self): #wxh num_tokens = 1024 num_heads = 32 cache_mode = 3 @@ -1013,9 +1040,9 @@ class TestMLAPrepross(op_test.OpTest): quant_mode = 0 weight_format = self.format_nz self.__test_mlapo_impl(torch.float16, num_tokens, num_heads, cache_mode, quant_mode, weight_format) - + ''' @op_test.only_910b - def test_mla_preprocess_bf16_cm0_qm1_nd_case1(self): + def test_mla_preprocess_bf16_cm0_qm1_nd_case1(self): #quant = 1的肯定不测 num_tokens = 32 num_heads = 32 cache_mode = 0 @@ -1085,6 +1112,6 @@ class TestMLAPrepross(op_test.OpTest): quant_mode = 1 weight_format = self.format_nd self.__test_mlapo_impl(torch.float16, num_tokens, num_heads, cache_mode, quant_mode, weight_format) - + ''' if __name__ == "__main__": unittest.main()
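Reference note for the test changes above: after this patch, the golden path in `calc_vec_mm_data` no longer quantizes to int8 or applies bias/descale; `rms_norm_quant_calc` returns the bf16 RMSNorm result and the mm1/mm2 goldens are plain fp32 matmuls of bf16 operands against the transposed weights. Below is a minimal standalone sketch of that reference math, assuming the shapes of the bf16 cm1/qm0 case in the diff (N=32 tokens, hidden size 6144, wdqkv 2112x6144, gamma2 of width 1536). The helper names (`rms_norm_bf16`, `matmul_bf16_golden`) are illustrative only and are not the test's own functions.

```python
import torch

def rms_norm_bf16(x: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMSNorm without the int8 quant step that the diff comments out:
    # y = x / sqrt(mean(x^2) + eps) * gamma, returned in bfloat16.
    square_sum = torch.sum(torch.square(x.float()), dim=-1, keepdim=True)
    factor = 1.0 / torch.sqrt(square_sum / x.shape[-1] + eps)
    return (x.float() * factor * gamma.float()).to(torch.bfloat16)

def matmul_bf16_golden(a_bf16: torch.Tensor, w_bf16: torch.Tensor) -> torch.Tensor:
    # Plain fp32 matmul against the transposed weight, cast back to bf16,
    # matching the non-quantized golden path used after this change.
    return torch.matmul(a_bf16.float(), w_bf16.transpose(0, 1).float()).to(torch.bfloat16)

# Example mirroring the mm1 / rms2 goldens in calc_vec_mm_data:
x = torch.randn(32, 6144).to(torch.bfloat16)        # hiddenStrate = 6144 per the diff
wdqkv = torch.randn(2112, 6144).to(torch.bfloat16)  # weight kept in bf16 instead of int8
gamma2 = torch.randn(1536).to(torch.bfloat16)

mm1_golden = matmul_bf16_golden(x, wdqkv)                  # input1 goes in un-normalized, as in the diff
rms2_golden = rms_norm_bf16(mm1_golden[:, 576:], gamma2)   # second RMSNorm over the 1536-wide split
```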