diff --git a/impl/matmul/kernel_kfc.h b/impl/matmul/kernel_kfc.h index 04a50d6eb1891cc946810a1bd2e861657b1b71a8..c6e06859895d61b2ef43f4160cc26bdbaa3df3a8 100644 --- a/impl/matmul/kernel_kfc.h +++ b/impl/matmul/kernel_kfc.h @@ -328,7 +328,7 @@ public: __aicore__ inline void SetUserDefInfo(const uint64_t tilingPtr) {} __aicore__ inline void SetQuantScalar(const uint64_t quantScalar) {} __aicore__ inline void SetQuantVector(const GlobalTensor& quantTensor) {} - template __aicore__ inline void SetWorkspace(__gm__ T* addr, int len) {}; + template __aicore__ inline void SetWorkspace(__gm__ T* addr, int size) {}; template __aicore__ inline void SetWorkspace(GlobalTensor& addr){}; __aicore__ inline void End(){}; __aicore__ inline void SetHF32(bool enableHF32 = false, int32_t transMode = 0){}; @@ -351,7 +351,7 @@ public: __aicore__ inline void GetTensorC(const GlobalTensor& gm, uint8_t enAtomic = 0, bool enSequentialWrite = false){}; template - __aicore__ inline void GetTensorC(const GlobalTensor &c, const LocalTensor &cLocal, + __aicore__ inline void GetTensorC(const GlobalTensor &gm, const LocalTensor &cLocal, uint8_t enAtomic = 0, bool enSequentialWrite = false) {}; template __aicore__ inline GlobalTensor GetTensorC(uint8_t enAtomic = 0, bool enSequentialWrite = false) diff --git a/impl/matmul/matmul_impl.h b/impl/matmul/matmul_impl.h index 78db7293716c941633750d64aaca0064fbb10f64..7423ff6d40976a8596c1ab04198b7df66a8fca0f 100644 --- a/impl/matmul/matmul_impl.h +++ b/impl/matmul/matmul_impl.h @@ -20,10 +20,6 @@ namespace matmul { constexpr int32_t MAX_BLOCK_COUNT_SIZE = 4095; constexpr int32_t DOUBLE_SIZE = 2; -#ifdef ASCENDC_CPU_DEBUG -#define REGIST_MATMUL_OBJ_REMOTE(tpipe, workspace, maxTimes, ...) -#endif - template __aicore__ inline void GlobalCache::Init(const TCubeTiling* __restrict cubeTiling, TPipe* tpipe) { @@ -164,9 +160,9 @@ __aicore__ inline void GlobalCache::ReduceCacheSize() --cacheSize_; } -template -__aicore__ inline void SetTPipe(MatmulImpl &mm, +template +__aicore__ inline void SetTPipe(MatmulImpl &mm, TPipe* tpipe) { mm.var.tpipe_ = tpipe; @@ -6184,21 +6180,21 @@ __aicore__ inline void MatmulImplbaseN]; MatmulInstr::biasType_ = IsSameType::value ? 2 : 1; // 2:f32, 1:f16 MatmulInstr::sL1BiasOffset_ = 0; - MatmulInstr::template Compute(a1, b1, - var.cMatrix_, bias, 0, 0, var.sMadMStep_, var.sMadNStep_); + MatmulInstr::template Compute(a1, b1, var.cMatrix_, bias, + 0, 0, var.sMadMStep_, var.sMadNStep_); if constexpr (A_TYPE::layout == LayoutMode::NONE || MM_CFG.batchMode == BatchMode::SINGLE_LARGE_THAN_L1) { var.qidBias_.FreeTensor(bias); } } else { MatmulInstr::biasType_ = 0; - MatmulInstr::template Compute(a1, b1, - var.cMatrix_, bias, 0, 0, var.sMadMStep_, var.sMadNStep_); + MatmulInstr::template Compute(a1, b1, var.cMatrix_, bias, + 0, 0, var.sMadMStep_, var.sMadNStep_); } } else { MatmulInstr::biasType_ = 0; - MatmulInstr::template Compute(a1, b1, - var.cMatrix_, bias, 0, 0, var.sMadMStep_, var.sMadNStep_); + MatmulInstr::template Compute(a1, b1, var.cMatrix_, bias, + 0, 0, var.sMadMStep_, var.sMadNStep_); } } @@ -6394,20 +6390,20 @@ __aicore__ inline void MatmulImpl::value ? 2 : 1; // 2:f32, 1:f16 MatmulInstr::sL1BiasOffset_ = 0; - MatmulInstr::template Compute(a1, b1, - var.cMatrix_, bias, 0, 0, var.sMadMStep_, var.sMadNStep_); + MatmulInstr::template Compute(a1, b1, var.cMatrix_, bias, + 0, 0, var.sMadMStep_, var.sMadNStep_); if constexpr (A_TYPE::layout == LayoutMode::NONE || MM_CFG.batchMode == BatchMode::SINGLE_LARGE_THAN_L1) { var.qidBias_.FreeTensor(bias); } } else { MatmulInstr::biasType_ = 0; - MatmulInstr::template Compute(a1, b1, - var.cMatrix_, bias, 0, 0, var.sMadMStep_, var.sMadNStep_); + MatmulInstr::template Compute(a1, b1, var.cMatrix_, bias, + 0, 0, var.sMadMStep_, var.sMadNStep_); } } else { MatmulInstr::biasType_ = 0; - MatmulInstr::template Compute(a1, b1, - var.cMatrix_, bias, 0, 0, var.sMadMStep_, var.sMadNStep_); + MatmulInstr::template Compute(a1, b1, var.cMatrix_, bias, + 0, 0, var.sMadMStep_, var.sMadNStep_); } #elif __CCE_AICORE__ == 200 if (var.enableBias_) { @@ -9722,7 +9718,8 @@ __aicore__ inline void MatmulImpl -__aicore__ inline void MatmulImpl::UpdateDataCopyParamForQuant( +__aicore__ inline +void MatmulImpl::UpdateDataCopyParamForQuant( DataCopyEnhancedParams& enhancedParams) { if constexpr (IsSameType::value) { @@ -10127,10 +10124,11 @@ template ::CopyCo22GMNZ2NDOnTheFly( const GlobalTensor& gmC, const LocalTensor& src, bool enSequentialWrite) { + uint32_t dimN = (Kc_ != 0) ? Kc_ : N_; const int blockCount = sizeof(DstT) == B32_BYTE_SIZE ? BLOCK_CUBE : ONE_BLK_SIZE / sizeof(DstT); const int oneBlockCount = ONE_BLK_SIZE / sizeof(DstT); int calcWidth = var.baseUseN_ / blockCount; - int dstOffset = var.curM_ * var.tiling_->baseM * N_ + var.curN_ * var.tiling_->baseN; + int dstOffset = var.curM_ * var.tiling_->baseM * dimN + var.curN_ * var.tiling_->baseN; int blockLen = blockCount * sizeof(DstT) / ONE_BLK_SIZE; int srcRepeatGap = (var.blockUseM_ * BLOCK_CUBE * blockCount - blockCount) * sizeof(DstT) / ONE_BLK_SIZE; int tail = var.baseUseN_ % blockCount; @@ -10142,7 +10140,7 @@ __aicore__ inline void MatmulImpl::CopyCo22GMNZ2ND( const GlobalTensor& gmC, LocalTensor& src, bool enSequentialWrite) { + uint32_t dimN = (Kc_ != 0) ? Kc_ : N_; const int blockCount = sizeof(DstT) == B32_BYTE_SIZE ? BLOCK_CUBE : ONE_BLK_SIZE / sizeof(DstT); int width = var.blockUseN_ * blockCount; if constexpr (IsSameType::value || IsSameType::value) { @@ -10427,15 +10426,15 @@ __aicore__ inline void MatmulImpl= width), - { KERNEL_LOG(KERNEL_ERROR, "N_ is %d, width is %d, N_ should be no less than width", N_, width); }); - int dstStride = (N_ - width) * sizeof(DstT) / ONE_BLK_SIZE; - int dstOffset = var.curM_ * var.tiling_->baseM * N_ + var.curN_ * var.tiling_->baseN; - int offset = N_; + ASCENDC_ASSERT((dimN >= width), + { KERNEL_LOG(KERNEL_ERROR, "dimN is %d, width is %d, dimN should be no less than width", dimN, width); }); + int dstStride = (dimN - width) * sizeof(DstT) / ONE_BLK_SIZE; + int dstOffset = var.curM_ * var.tiling_->baseM * dimN + var.curN_ * var.tiling_->baseN; + int offset = dimN; if (enSequentialWrite) { isGmAligned = (var.baseUseN_ % blockCount) == 0; dstStride = 0; @@ -10451,7 +10450,7 @@ __aicore__ inline void MatmulImpl::value) { CopyToGMForNotAligned(gmC, trans, blocklen, enSequentialWrite, isTragetAligned); } else { @@ -10522,9 +10521,10 @@ template ::CopyCo22UBNZ2ND( const LocalTensor& dst, const LocalTensor& src, bool enSequentialWrite) { + uint32_t dimN = (Kc_ != 0) ? Kc_ : N_; const int blockCount = sizeof(DstT) == B32_BYTE_SIZE ? BLOCK_CUBE : ONE_BLK_SIZE / sizeof(DstT); - int dstOffset = var.curM_ * var.tiling_->baseM * N_ + var.curN_ * var.tiling_->baseN; - int offset = Ceil(N_, blockCount) * blockCount; + int dstOffset = var.curM_ * var.tiling_->baseM * dimN + var.curN_ * var.tiling_->baseN; + int offset = Ceil(dimN, blockCount) * blockCount; if (enSequentialWrite) { dstOffset = 0; offset = var.tiling_->baseN; @@ -10924,8 +10924,8 @@ __aicore__ inline MatmulImpl 16 or m,n<16 - const int32_t m0 = min(minMNSize, min(coreStatus.m, minTotalSize / n0)); - const int32_t k0 = min(min(minKSize / m0, minKSize / n0), coreStatus.k); + const int32_t m0 = min(minMNSize, ((n0 == 0) ? 0 : min(coreStatus.m, minTotalSize / n0))); + const int32_t k0 = (m0 != 0 && n0 != 0) ? + min(min(minKSize / m0, minKSize / n0), coreStatus.k) : coreStatus.k; const int32_t dbBuffer = 2; // A/B fullload or A fullload + B Kdim fullload or B fullload + A Kdim fullload(1/2/4) @@ -1628,10 +1629,11 @@ int32_t MatmulTilingAlgorithm::LoopNumFromSingleCoreToL0(const CoreStatusPack& c constexpr int32_t minSize = 64; constexpr int32_t minN0Size = 16; int32_t n0 = min(min(minN0Size, coreStatus.n), minSize); - int32_t m0 = min(min(coreStatus.m, minTotalSize / n0), minSize); + int32_t m0 = (n0 == 0) ? 0 : min(min(coreStatus.m, minTotalSize / n0), minSize); n0 = (m0 == 0) ? 0 : min(min(coreStatus.n, minTotalSize / m0), minSize); m0 = (n0 == 0) ? 0 : min(min(coreStatus.m, minTotalSize / n0), minSize); - const int32_t k0 = (m0 != 0 && n0 != 0) ? min(min(minSize / m0, minSize / n0), coreStatus.k) : coreStatus.k; + const int32_t k0 = (m0 != 0 && n0 != 0) ? + min(min(minSize / m0, minSize / n0), coreStatus.k) : coreStatus.k; const int32_t loopNum = MathUtil::CeilDivision(coreStatus.m, m0) * MathUtil::CeilDivision(coreStatus.n, n0) * MathUtil::CeilDivision(coreStatus.k, k0); return loopNum; @@ -1670,12 +1672,12 @@ int MatmulTilingAlgorithm::GetBigPackageCondition(CoreStatusPack &coreStatus, void MatmulTilingAlgorithm::GetBlockDimHelper(const DimFactor& blockDim, CoreStatusPack& coreStatus, BlockDimCalculator& blockDimRes, const MatmulRunParas& params) { - blockDimRes.kNum = params.k32 / blockDim.k * C0_SIZE * REDUCE_BLOCK_SIZE; // contain k * 16 + blockDimRes.kNum = (blockDim.k == 0) ? 0 : params.k32 / blockDim.k * C0_SIZE * REDUCE_BLOCK_SIZE; // contain k * 16 blockDimRes.kBytes = blockDimRes.kNum * INPUTDTYPE_BYTES; // contain k * 16 * 2 coreStatus.batch = MathUtil::CeilDivision(params.batch32, blockDim.batch); coreStatus.m = MathUtil::CeilDivision(params.m32, blockDim.m); coreStatus.n = MathUtil::CeilDivision(params.n32, blockDim.n); - coreStatus.k = params.k32 / blockDim.k; + coreStatus.k = (blockDim.k == 0) ? 0 : params.k32 / blockDim.k; if (tilingIns_->enableSplitK_) { if (params.kMapped != params.k32) { // need check--splitK blockDimRes.kNum = params.kMapped / blockDim.k * NUM_TWO * C0_SIZE * REDUCE_BLOCK_SIZE; @@ -1805,7 +1807,7 @@ bool MatmulTilingAlgorithm::PreProcessMiniShape(const std::string& opType, CoreS coreStatus.n = coreStatus.nDim == 1 ? params.n32 : MathUtil::CeilDivision(params.nMapped, coreStatus.nDim); coreStatus.m = coreStatus.mDim == 1 ? params.m32 : MathUtil::CeilDivision(params.mMapped, coreStatus.mDim); coreStatus.k = coreStatus.kDim == 1 ? params.k32 : MathUtil::CeilDivision(params.kMapped, coreStatus.kDim); - params.nonFactorK = params.k32 % coreStatus.kDim == 0 ? false : true; + params.nonFactorK = (coreStatus.kDim == 0) ? false : (params.k32 % coreStatus.kDim == 0 ? false : true); return true; } return false; @@ -1859,7 +1861,7 @@ void MatmulTilingAlgorithm::AddOptimalFactors(const std::string& opType, const M // A/B fullload or A fullload + B Kdim fullload or B fullload + A Kdim fullload(1/2/4) const int32_t mnCore = MathUtil::CeilDivision(coreNum, params.batch32); if (mnCore > 1) { - const float optPoint = sqrt((params.m32 + 0.0) / params.n32 * mnCore); + const float optPoint = static_cast(sqrt((params.m32 + 0.0) / params.n32 * mnCore)); const int32_t mdim = static_cast(ceil(optPoint)); const int32_t ndim = static_cast(ceil(mnCore / optPoint)); MathUtil::AddFactor(blockDimRes.mDimFactors, mdim); @@ -2275,6 +2277,27 @@ int64_t MatmulTilingAlgorithm::Process() tilingIns_->tiling_.set_baseN(singleCoreStatus.l0Status.nL0 * C0_SIZE); const int32_t reduceSize = C0_BYTE_SIZE / DTYPE_BIT_TAB.at(tilingIns_->aType_.dataType) * BITS_PER_BYTE; tilingIns_->tiling_.set_baseK(singleCoreStatus.l0Status.kL0 * reduceSize); + // check whether OUTER_PRODUCT is supported + if ((tilingIns_->scheduleType == ScheduleType::OUTER_PRODUCT) && + (tilingIns_->tiling_.get_baseK() < tilingIns_->tiling_.get_singleCoreK())) { + TILING_LOG_WARNING("Unsupported scheduleType is OUTER_PRODUCT"); + return -1; + } + tilingIns_->tiling_.set_iterateOrder(GetIteratorOrder(singleCoreStatus, singleCoreM, singleCoreN, singleCoreK)); + int32_t newBaseM = singleCoreStatus.l0Status.mL0 * C0_SIZE; + int32_t newBaseN = singleCoreStatus.l0Status.nL0 * C0_SIZE; + if (tilingIns_->scheduleType == ScheduleType::OUTER_PRODUCT) { + // when scheduleType is OUT_PRODUCT, each iteration computes 2 * basicBlock size of data + bool isL0CFullUsed = (newBaseM * newBaseN * NUM_TWO) * static_cast(DTYPE_BYTE_TAB.at(tilingIns_->cType_.dataType) > + tilingIns_->bufferPool_.l0CSize) ? 1 : 0; + if (isL0CFullUsed && (tilingIns_->tiling_.get_iterateOrder() == 0)) { + newBaseN = MathUtil::CeilDivision(newBaseN / NUM_TWO, C0_SIZE); + } else if (isL0CFullUsed && (tilingIns_->tiling_.get_iterateOrder() == 1)) { + newBaseM = MathUtil::CeilDivision(newBaseM / NUM_TWO, C0_SIZE); + } + tilingIns_->tiling_.set_baseM(newBaseM); + tilingIns_->tiling_.set_baseN(newBaseN); + } tilingIns_->baseM = tilingIns_->tiling_.get_baseM(); tilingIns_->baseN = tilingIns_->tiling_.get_baseN(); tilingIns_->baseK = tilingIns_->tiling_.get_baseK(); @@ -2299,7 +2322,6 @@ int64_t MatmulTilingAlgorithm::Process() int32_t b1LengthCache = 0; SetDepthL1CacheUBParams(a1LengthCache, b1LengthCache); tilingIns_->tiling_.set_transLength(transLength); // a1 b1 c1 reuse on ub - tilingIns_->tiling_.set_iterateOrder(GetIteratorOrder(singleCoreStatus, singleCoreM, singleCoreN, singleCoreK)); tilingIns_->tiling_.set_shareMode(0); tilingIns_->tiling_.set_dbL0A(singleCoreStatus.l0Status.dbL0A); tilingIns_->tiling_.set_dbL0B(singleCoreStatus.l0Status.dbL0B); diff --git a/impl/matmul/matmul_tiling_base.cpp b/impl/matmul/matmul_tiling_base.cpp index 3a9d37d65e80f0359e9dc3c54b553199fee5f7db..07d197bbd4ab6c7c7de8243d9ab794c4ff5b1efe 100644 --- a/impl/matmul/matmul_tiling_base.cpp +++ b/impl/matmul/matmul_tiling_base.cpp @@ -581,6 +581,18 @@ void MatmulApiTilingBase::SetMatmulConfigParams(int32_t mmConfigTypeIn, bool ena this->enableL1CacheUB = enableL1CacheUBIn; } +void MatmulApiTilingBase::SetMatmulConfigParams(const MatmulConfigParams& configParams) +{ + TILING_LOG_DEBUG("Set MatmulConfigType: %d", static_cast(configParams.mmConfigType)); + TILING_LOG_DEBUG("Set EnableL1CacheUB: %d", static_cast(configParams.enableL1CacheUB)); + TILING_LOG_DEBUG("Set ScheduleType: %d", static_cast(configParams.scheduleType)); + TILING_LOG_DEBUG("Set Traverse: %d", static_cast(configParams.traverse)); + this->mmConfigType = configParams.mmConfigType; + this->enableL1CacheUB = configParams.enableL1CacheUB; + this->scheduleType = configParams.scheduleType; + this->traverse_ = configParams.traverse; +} + bool MatmulApiTilingBase::CheckSetParam() { if (socVersion == platform_ascendc::SocVersion::ASCEND910 || diff --git a/impl/matmul/modules/dfx/dfx_config.h b/impl/matmul/modules/dfx/dfx_config.h new file mode 100644 index 0000000000000000000000000000000000000000..83d0470e3474b21c18bee1e0157c9e83a569e935 --- /dev/null +++ b/impl/matmul/modules/dfx/dfx_config.h @@ -0,0 +1,28 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + + /*! + * \file dfx_config.h + * \brief + */ + +#ifndef MATMUL_DFX_CONFIG_H +#define MATMUL_DFX_CONFIG_H + +#include "handlers/dfx_chain_handler.h" +#include "dfx_func_info.h" + +namespace matmul { +struct DfxConfig { + static constexpr bool ENABLE = false; + using EnabledHandlers = DfxChainHandler <>; +}; +} +#endif \ No newline at end of file diff --git a/impl/matmul/modules/dfx/dfx_func_info.h b/impl/matmul/modules/dfx/dfx_func_info.h new file mode 100644 index 0000000000000000000000000000000000000000..02e4fdd1be05d5f6eed610750982cc56e200b7a4 --- /dev/null +++ b/impl/matmul/modules/dfx/dfx_func_info.h @@ -0,0 +1,29 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file dfx_func_info.h + * \brief + */ + + #ifndef MATMUL_DFX_FUNC_INFO_H +#define MATMUL_DFX_FUNC_INFO_H + +namespace matmul { +struct DfxFuncInfo { + __aicore__ inline DfxFuncInfo(__gm__ const char* module, __gm__ const char* func, uint32_t funcId) + :module(module), func(func), funcId(funcId) { + } + __gm__ const char* module; + __gm__ const char* func; + uint32_t funcId; +}; +} +#endif \ No newline at end of file diff --git a/impl/matmul/modules/dfx/dfx_handler.h b/impl/matmul/modules/dfx/dfx_handler.h new file mode 100644 index 0000000000000000000000000000000000000000..ce6c0fc31d6cd825e01c61cd78a4de958bf2dda9 --- /dev/null +++ b/impl/matmul/modules/dfx/dfx_handler.h @@ -0,0 +1,41 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file dfx_handler.h + * \brief + */ + +#ifndef MATMUL_DFX_HANDLER_H +#define MATMUL_DFX_HANDLER_H + +#include "dfx_config.h" + +namespace matmul { + +struct DfxHandler { + template + __aicore__ inline static void PreCall(const DfxFuncInfo& info, Args&&... args) { + DfxConfig::EnabledHandlers::PreCall(info, std::forward(args)...); + } + + template + __aicore__ inline static void PostCall(const DfxFuncInfo& info, const RT& ret) { + DfxConfig::EnabledHandlers::PostCall(info, ret); + } + + __aicore__ inline static void PostCall(const DfxFuncInfo& info) { + DfxConfig::EnabledHandlers::PostCall(info); + } +}; + +} + +#endif \ No newline at end of file diff --git a/impl/matmul/modules/dfx/dfx_proxy.h b/impl/matmul/modules/dfx/dfx_proxy.h new file mode 100644 index 0000000000000000000000000000000000000000..e6ad92c17b73bc13c3a317db64e57cb2145ab0d5 --- /dev/null +++ b/impl/matmul/modules/dfx/dfx_proxy.h @@ -0,0 +1,172 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file dfx_proxy.h + * \brief + */ + +#ifndef MATMUL_DFX_PROXY_H +#define MATMUL_DFX_PROXY_H + +#include +#include "dfx_handler.h" + +namespace matmul { + +template +using enable_if_t = typename std::enable_if::type; + +template +constexpr bool is_void_v = std::is_void::value; + +/////////////////////////////////////////////////////////////////////////////// +template +struct DfxProxy : MODULE { + __aicore__ inline auto operator->() { return this; } + __aicore__ inline operator MODULE*() { return this; } +}; + +/////////////////////////////////////////////////////////////////////////////// +#define MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC) \ +template \ +__aicore__ inline auto FUNC(Args&&... args) -> enable_if_t(args)...))>, \ +decltype(MODULE().MODULE::FUNC(std::forward(args)...))>{ \ + DfxFuncInfo info{#MODULE, #FUNC, __COUNTER__}; \ + DfxHandler::PreCall(info, std::forward(args)...); \ + auto ret = M_.MODULE::FUNC(std::forward(args)...); \ + DfxHandler::PostCall(info, ret); \ + return ret; \ +} \ +template \ +__aicore__ inline auto FUNC(Args&&... args) -> enable_if_t(args)...))>> { \ + DfxFuncInfo info{#MODULE, #FUNC, __COUNTER__}; \ + DfxHandler::PreCall(info, std::forward(args)...); \ + M_.MODULE::FUNC(std::forward(args)...); \ + DfxHandler::PostCall(info); \ +} + +/////////////////////////////////////////////////////////////////////////////// +#define MATMUL_COUNT_ARGS_IMPL(_1, _2, _3, _4, _5, _6, _7, _8, _9, N, ...) N +#define MATMUL_COUNT_ARGS(...) \ +MATMUL_COUNT_ARGS_IMPL(__VA_ARGS__, 9, 8, 7, 6, 5, 4, 3, 2, 1) + +#define MATMUL_DEF_PROXY_FUNC_1(M_, MODULE, FUNC1) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC1) + +#define MATMUL_DEF_PROXY_FUNC_2(M_, MODULE, FUNC1, FUNC2) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC1) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC2) + +#define MATMUL_DEF_PROXY_FUNC_3(M_, MODULE, FUNC1, FUNC2, FUNC3) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC1) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC2) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC3) + +#define MATMUL_DEF_PROXY_FUNC_4(M_, MODULE, FUNC1, FUNC2, FUNC3, FUNC4) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC1) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC2) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC3) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC4) + +#define MATMUL_DEF_PROXY_FUNC_5(M_, MODULE, FUNC1, FUNC2, FUNC3, FUNC4, FUNC5) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC1) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC2) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC3) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC4) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC5) + +#define MATMUL_DEF_PROXY_FUNC_6(M_, MODULE, FUNC1, FUNC2, FUNC3, FUNC4, FUNC5, FUNC6) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC1) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC2) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC3) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC4) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC5) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC6) + +#define MATMUL_DEF_PROXY_FUNC_7(M_, MODULE, FUNC1, FUNC2, FUNC3, FUNC4, FUNC5, FUNC6, FUNC7) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC1) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC2) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC3) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC4) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC5) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC6) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC7) + +#define MATMUL_DEF_PROXY_FUNC_8(M_, MODULE, FUNC1, FUNC2, FUNC3, FUNC4, FUNC5, FUNC6, FUNC7, FUNC8) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC1) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC2) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC3) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC4) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC5) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC6) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC7) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC8) + +#define MATMUL_DEF_PROXY_FUNC_9(M_, MODULE, FUNC1, FUNC2, FUNC3, FUNC4, FUNC5, FUNC6, FUNC7, FUNC8, FUNC9) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC1) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC2) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC3) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC4) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC5) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC6) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC7) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC8) \ +MATMUL_DEF_DFX_PROXY_FUNC(M_, MODULE, FUNC9) + +#define MATMUL_DEF_DFX_FUNCS_IMPL2(N, M_, MODULE, ...) \ +MATMUL_DEF_PROXY_FUNC_##N(M_, MODULE, __VA_ARGS__) + +#define MATMUL_DEF_DFX_FUNCS_IMPL(N, M_, MODULE, ...) \ +MATMUL_DEF_DFX_FUNCS_IMPL2(N, M_, MODULE, __VA_ARGS__) + +#define MATMUL_DEF_DFX_FUNCS(M_, MODULE, ...) \ +MATMUL_DEF_DFX_FUNCS_IMPL(MATMUL_COUNT_ARGS(__VA_ARGS__), M_, MODULE, __VA_ARGS__) + +/////////////////////////////////////////////////////////////////////////////// +#define MATMUL_DFX_PROXY_REGISTER(MODULE, ...) \ +template \ +struct DfxProxy { \ + using MODULE = typename IMPL::MODULE; \ + __aicore__ inline DfxProxy(MODULE& module) : proxy{module} {} \ + struct FuncProxy { \ + __aicore__ inline FuncProxy(MODULE& module) : m_{module} {} \ + __aicore__ inline auto& operator*() { return m_; } \ + MATMUL_DEF_DFX_FUNCS(m_, MODULE, __VA_ARGS__) \ + private: \ + MODULE& m_; \ + }; \ + __aicore__ inline auto operator->() { return &proxy; } \ + __aicore__ inline operator MODULE*() { return &(*proxy); } \ +private: \ + FuncProxy proxy; \ +}; \ +template \ +struct DfxProxy { \ + using MODULE = typename IMPL::MODULE; \ + __aicore__ inline DfxProxy(const MODULE& module) : proxy{module} {} \ + struct FuncProxy { \ + __aicore__ inline FuncProxy(const MODULE& module) : m_{module} {} \ + __aicore__ inline const auto& operator*() { return m_; } \ + MATMUL_DEF_DFX_FUNCS(m_, MODULE, __VA_ARGS__) \ + private: \ + const MODULE& m_; \ + }; \ + __aicore__ inline auto operator->() { return &proxy; } \ + __aicore__ inline operator MODULE*() { return &(*proxy); } \ +private: \ + FuncProxy proxy; \ +} + +} + +#endif \ No newline at end of file diff --git a/impl/matmul/modules/dfx/dfx_registry.h b/impl/matmul/modules/dfx/dfx_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..400891abe3b2a209e20d7c706aff7f5f6a035b35 --- /dev/null +++ b/impl/matmul/modules/dfx/dfx_registry.h @@ -0,0 +1,26 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file dfx_registry.h + * \brief + */ + + +#ifndef MATMUL_DFX_REGISTRY_H +#define MATMUL_DFX_REGISTRY_H + +#include "dfx_proxy.h" + +namespace matmul { + MATMUL_DFX_PROXY_REGISTER(InputL1Cache, ClearAL1Cache, ClearBL1Cache); +} + +#endif \ No newline at end of file diff --git a/impl/matmul/modules/dfx/handlers/dfx_chain_handler.h b/impl/matmul/modules/dfx/handlers/dfx_chain_handler.h new file mode 100644 index 0000000000000000000000000000000000000000..4112043814e4f452a724af93744227d9bfd16c43 --- /dev/null +++ b/impl/matmul/modules/dfx/handlers/dfx_chain_handler.h @@ -0,0 +1,42 @@ +/** + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +/*! + * \file dfx_chain_handler.h + * \brief + */ + +#ifndef MATMUL_DFX_CHAIN_HANDLER_H +#define MATMUL_DFX_CHAIN_HANDLER_H + +namespace matmul { + +struct DfxFuncInfo; + +template +struct DfxChainHandler { + template + __aicore__ inline static void PreCall(const DfxFuncInfo& info, Args&&... args) { + (HANDLERS::PreCall(info, std::forward(args)...), ...); + } + + template + __aicore__ inline static void PostCall(const DfxFuncInfo& info, const RT& ret) { + (HANDLERS::PostCall(info, ret), ...); + } + + __aicore__ inline static void PostCall(const DfxFuncInfo& info) { + (HANDLERS::PostCall(info), ...); + } +}; + +} + +#endif \ No newline at end of file diff --git a/impl/matmul/modules/matmul_module.h b/impl/matmul/modules/matmul_module.h index eed5071e9b3a690e33d4bd367c60c510d32b77b6..884dba270b00ea52e0a0d5bbd0045e0349d105fb 100644 --- a/impl/matmul/modules/matmul_module.h +++ b/impl/matmul/modules/matmul_module.h @@ -15,6 +15,9 @@ #ifndef IMPL_MATMUL_MODULES_MATMUL_MODULE_H #define IMPL_MATMUL_MODULES_MATMUL_MODULE_H +#include "dfx/dfx_registry.h" +#include "dfx/dfx_config.h" + /* MatmulModuleBase */ namespace matmul { template @@ -38,6 +41,7 @@ struct MatmulModuleBase> { /* MatmulImpl */ #define MATMUL_IMPL__ IMPL +#define MATMUL_POLICY__ POLICY #define MATMUL_CAST_TO_IMPL() static_cast(this) #define MATMUL_CAST_TO_CONST_IMPL() static_cast(this) @@ -48,33 +52,61 @@ struct MatmulModuleBase> { #define MATMUL_CAST_TO_CONST_IMPL_OF(...) \ (static_cast(MATMUL_CAST_TO_CONST_IMPL())) +#define MATMUL_CAST_TO_PROXY_OF(NAME) \ +typename matmul::DfxProxy (*MATMUL_CAST_TO_IMPL_OF(NAME)) + +#define MATMUL_CAST_TO_CONST_PROXY_OF(NAME) \ +typename matmul::DfxProxy (*MATMUL_CAST_TO_CONST_IMPL_OF(NAME)) + #define MATMUL_MODULE(NAME) cast_to_##NAME() -#define MATMUL_USE_MODULE(NAME) \ -__aicore__ inline constexpr decltype(auto) MATMUL_MODULE(NAME) { \ - return MATMUL_CAST_TO_IMPL_OF(NAME); \ +#define MATMUL_USE_MODULE(NAME) \ +__aicore__ inline constexpr decltype(auto) MATMUL_MODULE(NAME) { \ + if constexpr (DfxConfig::ENABLE) { \ + return MATMUL_CAST_TO_PROXY_OF(NAME); \ + } else { \ + return MATMUL_CAST_TO_IMPL_OF(NAME); \ + } \ +} \ +__aicore__ inline constexpr decltype(auto) MATMUL_MODULE(NAME) const { \ + if constexpr (DfxConfig::ENABLE) { \ + return MATMUL_CAST_TO_CONST_PROXY_OF(NAME); \ + } else { \ + return MATMUL_CAST_TO_CONST_IMPL_OF(NAME); \ + } \ } -#define MATMUL_USE_MODULE_ON(NAME, ...) \ -__aicore__ inline constexpr decltype(auto) MATMUL_MODULE(NAME) { \ - return MATMUL_CAST_TO_IMPL_OF(template NAME<__VA_ARGS__>); \ +#define MATMUL_USE_MODULE_ON(NAME, ...) \ +__aicore__ inline constexpr decltype(auto) MATMUL_MODULE(NAME) { \ + if constexpr (DfxConfig::ENABLE) { \ + return MATMUL_CAST_TO_PROXY_OF(template NAME<__VA_ARGS__>); \ + } else { \ + return MATMUL_CAST_TO_IMPL_OF(template NAME<__VA_ARGS__>); \ + } \ +} \ +__aicore__ inline constexpr decltype(auto) MATMUL_MODULE(NAME) const { \ + if constexpr (DfxConfig::ENABLE) { \ + return MATMUL_CAST_TO_CONST_PROXY_OF(template NAME<__VA_ARGS__>);\ + } else { \ + return MATMUL_CAST_TO_CONST_IMPL_OF(template NAME<__VA_ARGS__>); \ + } \ } /* MatmulPolicy */ #define MATMUL_POLICY_TEMPLATE MATMUL_POLICY #define MATMUL_POLICY_DEFAULT_OF(DEFAULT) \ -template \ +template \ class MATMUL_POLICY = DEFAULT #define MATMUL_POLICY_TEMPLATE_OF(NAME) \ -template class NAME +template class NAME #define MATMUL_IMPL_TYPE \ MatmulImpl #define MATMUL_MODULE_IN_POLICY(...) \ -MATMUL_POLICY_TEMPLATE::__VA_ARGS__ +MATMUL_POLICY_TEMPLATE::__VA_ARGS__ #define MATMUL_IMPORT_MODULE(...) private MATMUL_MODULE_IN_POLICY(__VA_ARGS__) diff --git a/impl/matmul/modules/matmul_policy.h b/impl/matmul/modules/matmul_policy.h index 0184d5ad4ea4dc00af5ebbcd5d820487ce80fa68..1db8dca529dcf0d1457d303f971cb283344c07ca 100644 --- a/impl/matmul/modules/matmul_policy.h +++ b/impl/matmul/modules/matmul_policy.h @@ -21,7 +21,7 @@ namespace matmul { -template struct MatmulPolicy { diff --git a/lib/matmul/matmul.h b/lib/matmul/matmul.h index fcfc558ab21b53efffffbb68a0643830250a09d8..ce0be5fd825ef50e200fbb541b9a7cf69b2c1abc 100644 --- a/lib/matmul/matmul.h +++ b/lib/matmul/matmul.h @@ -212,6 +212,8 @@ public: MATMUL_ALLOW_USING(InputL1Cache); MATMUL_ALLOW_USING(Co1Buffer); +private: + template friend struct DfxProxy; private: template diff --git a/lib/matmul/matmul_client.h b/lib/matmul/matmul_client.h index 1b1ebb9e722f321678b1d2afc50ff5eb38041fbe..d0766cfdc344b04592dea4517bc8741163605573 100644 --- a/lib/matmul/matmul_client.h +++ b/lib/matmul/matmul_client.h @@ -129,7 +129,7 @@ public: ASSERT(addr.GetSize() > 0); SetWorkspace(addr.GetPhyAddr(), addr.GetSize() * sizeof(T)); } - template __aicore__ inline void SetWorkspace(__gm__ const T* addr, int size) + template __aicore__ inline void SetWorkspace(__gm__ const T* addr, int len) { ASSERT(addr != nullptr); ASSERT(this->cubeTiling != nullptr); @@ -155,6 +155,11 @@ public: PostMessage(); } + __aicore__ inline void SetSingleShape(int singleM, int singleN, int singleK) + { + SetTail(singleM, singleN, singleK); + } + __aicore__ inline void SetTail(int tailM = -1, int tailN = -1, int tailK = -1) { if (tailM != -1) { @@ -337,6 +342,7 @@ public: template __aicore__ inline bool Iterate(bool enPartialSum = false) { + ASSERT(!(A_TYPE::ibShare && B_TYPE::ibShare) && "Iterate not support when samebab is enabled"); TRACE_START(TraceId::KFC_CLIENT_POST_MSG); if (unlikely(kfcMsg_.body.isFirstIter)) { cntIter_ = 0; @@ -369,6 +375,7 @@ public: kfcMsg_.body.sync = sync; kfcMsg_.body.cAddr = reinterpret_cast(cacheWorkspaceAddr); PostMessage(); + SyncCubeWithVec(); TRACE_STOP(TraceId::KFC_CLIENT_POST_MSG); return true; } @@ -378,7 +385,13 @@ public: __aicore__ inline void WaitIterateAll() { ASSERT(!isSyncGetC); // Must be asynchronous mode - WaitEvent(this->devEvtID); + auto intraId = this->devEvtID; + if constexpr (A_TYPE::ibShare && B_TYPE::ibShare) { + if (GetSubBlockIdxImpl() == 1) { + intraId = this->devEvtID - 1; + } + } + WaitEvent(intraId); } // Only support the mode that the IterateAll is asynchronous and GM output is continuous. @@ -457,6 +470,7 @@ public: { TRACE_START(TraceId::KFC_CLIENT_POST_MSG); ASSERT(kfcMsg_.body.isFirstIter == 1); + ASSERT(!(A_TYPE::ibShare && B_TYPE::ibShare) && "IterateBatch not support when sameab is enabled"); kfcMsg_.body.cAddr = reinterpret_cast(gm.GetPhyAddr()); kfcMsg_.body.enSequentialWrite = enSequentialWrite; kfcMsg_.body.sync = sync; @@ -491,6 +505,7 @@ public: TRACE_START(TraceId::KFC_CLIENT_POST_MSG); ASSERT(sync == true); ASSERT(kfcMsg_.body.isFirstIter == 1); + ASSERT(!(A_TYPE::ibShare && B_TYPE::ibShare) && "IterateBatch not support when sameab is enabled"); if (ubCmatrix.GetPosition() == static_cast(TPosition::TSCM)) { kfcMsg_.body.cAddr = GetTscmAddr(ubCmatrix); kfcMsg_.body.cIsTscm = 1; @@ -536,7 +551,7 @@ public: curProcess = 0; ASSERT(kfcMsg_.body.isFirstIter == 1); ASSERT(cacheWorkspaceAddr); - + ASSERT(!(A_TYPE::ibShare && B_TYPE::ibShare) && "IterateBatch not support when sameab is enabled"); kfcMsg_.body.cAddr = reinterpret_cast(cacheWorkspaceAddr); kfcMsg_.body.enSequentialWrite = enSequentialWrite; kfcMsg_.body.sync = sync; @@ -785,6 +800,8 @@ public: BIAS_TYPE, MM_CFG, MM_CB>::MATMUL mm; + constexpr static bool aIbShare = A_TYPE::ibShare; + constexpr static bool bIbShare = B_TYPE::ibShare; #endif @@ -795,7 +812,7 @@ private: // Use shared memory to get the queue. KfcCommClient* client; TPipe* tpipe; - TCubeTiling* cubeTiling; + const TCubeTiling* cubeTiling; KfcMsg kfcMsg_; bool isSyncGetC; @@ -809,7 +826,7 @@ private: uint32_t mnIter_; uint64_t cOffset_; template - friend __aicore__ inline void InitKfcClient(T& mm, U* cubeTiling, TPipe* tpipe, KfcCommClient* client, int instIdx, + friend __aicore__ inline void InitKfcClient(T& mm, U* tiling, TPipe* tpipe, KfcCommClient* client, int instIdx, GM_ADDR workspace); private: @@ -819,7 +836,7 @@ private: ASSERT(cubeTiling != nullptr && "cubeTiling cannot be nullptr when init matmul client"); ASSERT(sizeof(TCubeTiling) % sizeof(uint64_t) == 0); - this->cubeTiling = const_cast(cubeTiling); + this->cubeTiling = cubeTiling; *((uint64_t*)&kfcMsg_) = 0; *((uint64_t*)&(kfcMsg_.body)) = 0; @@ -869,6 +886,14 @@ private: } template __aicore__ inline void PostMessage() { + if constexpr (A_TYPE::ibShare && B_TYPE::ibShare) { + ASSERT(DoMatmulNorm(MM_CFG) && "MM_CFG should use norm config when sameab is enabled"); + if (GetSubBlockIdxImpl() == 1) { + *((uint32_t *)&kfcMsg_.body) = 0; + kfcMsg_.ubAddr = -1; + return; + } + } kfcMsg_.head = KfcMsgMakeFlag(funID, this->instIdx); auto msg = client->AllocMessage(); diff --git a/lib/matmul/matmul_intf.h b/lib/matmul/matmul_intf.h index 6431b73ccc71123601b2f4ac173a86e69e3e2ea0..a5f2c127fa1669344814112b8548a2bbb61c4950 100644 --- a/lib/matmul/matmul_intf.h +++ b/lib/matmul/matmul_intf.h @@ -25,7 +25,7 @@ #if __CCE_AICORE__ == 220 #ifdef ASCENDC_CUBE_ONLY namespace matmul { -template , MATMUL_POLICY_DEFAULT_OF(MatmulPolicy)> using Matmul = matmul::MatmulImpl; } @@ -114,10 +114,11 @@ __aicore__ inline void InitCurObj(AscendC::TPipe* tpipe, T& a, Args&&... b) #define REGIST_MATMUL_OBJ(tpipe, workspace, ...) \ InitCurObj(tpipe, __VA_ARGS__) namespace matmul { -template , MATMUL_POLICY_DEFAULT_OF(MatmulPolicy)> using Matmul = matmul::MatmulImpl; } + #endif #else @@ -152,31 +153,54 @@ __aicore__ inline void InitCurObj(AscendC::TPipe* tpipe, T& a, Args&&... b) } } +#ifdef ASCENDC_TIME_STAMP_ON +#define REGIST_MATMUL_OBJ(tpipe, workspace, ...) \ + InitCurObj(tpipe, __VA_ARGS__); \ + AscendC::AscendCTimeStamp(static_cast(AscendC::TimeStampId::TIME_STAMP_MATMUL_SERVER_OBJ)) +#else #define REGIST_MATMUL_OBJ(tpipe, workspace, ...) \ InitCurObj(tpipe, __VA_ARGS__) +#endif #define REGIST_MATMUL_OBJ_REMOTE(tpipe, workspace, ...) namespace matmul { -template , MATMUL_POLICY_DEFAULT_OF(MatmulPolicy)> using Matmul = matmul::MatmulImpl; } #else +#ifdef ASCENDC_TIME_STAMP_ON #define REGIST_MATMUL_OBJ(tpipe, workspace, ...) \ if ASCEND_IS_AIC { \ AscendC::KfcServer server; \ + AscendC::AscendCTimeStamp(static_cast(AscendC::TimeStampId::TIME_STAMP_MATMUL_SERVER)); \ server.Init(workspace); \ + AscendC::AscendCTimeStamp(static_cast(AscendC::TimeStampId::TIME_STAMP_MATMUL_SERVER_INIT)); \ server.InitObj(tpipe, __VA_ARGS__); \ + AscendC::AscendCTimeStamp(static_cast(AscendC::TimeStampId::TIME_STAMP_MATMUL_SERVER_OBJ)); \ while (server.isRun()) { \ server.Run(__VA_ARGS__); \ }; \ server.Quit(); \ return; \ } +#else +#define REGIST_MATMUL_OBJ(tpipe, workspace, ...) \ + if ASCEND_IS_AIC { \ + AscendC::KfcServer server; \ + server.Init(workspace); \ + server.InitObj(tpipe, __VA_ARGS__); \ + while (server.isRun()) { \ + server.Run(__VA_ARGS__); \ + }; \ + server.Quit(); \ + return; \ + } +#endif namespace matmul { -template , MATMUL_POLICY_DEFAULT_OF(MatmulPolicy)> -using Matmul = MatmulServiceAux; +using Matmul = matmul::MatmulServiceAux; } #endif @@ -185,14 +209,25 @@ using Matmul = MatmulServiceAux(AscendC::TimeStampId::TIME_STAMP_MATMUL_CLIENT_KFC)); \ + AscendC::SetMatrixKfc(tpipe, &__kfcClient__, 0, workspace, __VA_ARGS__); \ + AscendC::AscendCTimeStamp(static_cast(AscendC::TimeStampId::TIME_STAMP_MATMUL_MATRIX_KFC)); \ + AscendC::WaitEvent(matmul::WORKSPACE_SYNC_ID); \ + AscendC::AscendCTimeStamp(static_cast(AscendC::TimeStampId::TIME_STAMP_MATMUL_WAIT_EVE)) +#else #define REGIST_MATMUL_OBJ(tpipe, workspace, ...) \ AscendC::KfcCommClient __kfcClient__(workspace, AscendC::GetSubBlockIdx()); \ AscendC::g_kfcClient = &__kfcClient__; \ AscendC::SetMatrixKfc(tpipe, &__kfcClient__, 0, workspace, __VA_ARGS__); \ AscendC::WaitEvent(matmul::WORKSPACE_SYNC_ID) #endif +#endif namespace matmul { -template , MATMUL_POLICY_DEFAULT_OF(MatmulPolicy)> using Matmul = matmul::MatmulClient; } @@ -225,16 +260,20 @@ __aicore__ inline void InitCurObj(AscendC::TPipe* tpipe, T& a, Args&&... b) } } +#ifdef ASCENDC_TIME_STAMP_ON +#define REGIST_MATMUL_OBJ(tpipe, workspace, ...) \ + InitCurObj(tpipe, __VA_ARGS__); \ + AscendC::AscendCTimeStamp(static_cast(AscendC::TimeStampId::TIME_STAMP_MATMUL_OBJ)) +#else #define REGIST_MATMUL_OBJ(tpipe, workspace, ...) \ InitCurObj(tpipe, __VA_ARGS__) - +#endif namespace matmul { -template , MATMUL_POLICY_DEFAULT_OF(MatmulPolicy)> using Matmul = matmul::MatmulImpl; -} +} //namespace matmul #endif #endif - -#endif \ No newline at end of file +#endif \ No newline at end of file diff --git a/lib/matmul/matmul_tiling_base.h b/lib/matmul/matmul_tiling_base.h index 33319d47e208eb6e71070e7bdfca06f8f7b0c946..a724fdad9a51facd681d1a25ddbadfe0b844815e 100644 --- a/lib/matmul/matmul_tiling_base.h +++ b/lib/matmul/matmul_tiling_base.h @@ -132,6 +132,11 @@ enum class DequantType : int32_t { TENSOR = 1, }; +enum class ScheduleType : int32_t { + INNER_PRODUCT = 0, + OUTER_PRODUCT = 1, +}; + struct SysTilingTempBufSize { int32_t ubSize = 0; int32_t l1Size = 0; @@ -170,6 +175,13 @@ struct PlatformInfo { uint64_t l0BSize = 0; }; +struct MatmulConfigParams { + int32_t mmConfigType = 1; + bool enableL1CacheUB = false; + ScheduleType scheduleType = ScheduleType::INNER_PRODUCT; + MatrixTraverse traverse = MatrixTraverse::NOSET; +}; + class MatmulApiTilingBase { public: MatmulApiTilingBase(); @@ -207,6 +219,7 @@ public: int32_t SetDoubleBuffer(bool a, bool b, bool c, bool bias, bool transND2NZ = true, bool transNZ2ND = true); void SetMatmulConfigParams(int32_t mmConfigTypeIn = 1, bool enableL1CacheUBIn = false); + void SetMatmulConfigParams(const MatmulConfigParams& configParams); int32_t GetBaseM() const { @@ -288,6 +301,7 @@ public: BufferPool bufferPool_; MatrixTraverse traverse_ = MatrixTraverse::FIRSTM; MatrixMadType madType_ = MatrixMadType::NORMAL; + ScheduleType scheduleType = ScheduleType::INNER_PRODUCT; bool transND2NZ_ = false; bool transNZ2ND_ = false; diff --git a/tests/tiling/test_tiling.cpp b/tests/tiling/test_tiling.cpp index 6d2358417596dd31b83a99f6fe1302285557ea3a..2a66ea90d089c9fa464ef348ccf256b54a99ee75 100644 --- a/tests/tiling/test_tiling.cpp +++ b/tests/tiling/test_tiling.cpp @@ -38,11 +38,12 @@ TEST_F(TestTiling, MultiCoreSmallMN) rnnMatmul3.SetBType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType ::DT_FLOAT); rnnMatmul3.SetCType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::NZ, matmul_tiling::DataType ::DT_FLOAT); rnnMatmul3.SetBiasType(matmul_tiling::TPosition::VECCALC, matmul_tiling::CubeFormat::ND, matmul_tiling::DataType ::DT_FLOAT); + rnnMatmul3.SetSingleRange(-1,-1,-1,-1,-1,-1); auto ret = rnnMatmul3.SetBias(true); ret = rnnMatmul3.SetDim(24); ret = rnnMatmul3.SetOrgShape(5, 40, 986); ret = rnnMatmul3.SetShape(5, 10, 986); - ret = rnnMatmul3.SetBufferSpace(); // will use all buffer space if not explicitly specified + ret = rnnMatmul3.SetBufferSpace(-1, -1, -1); // will use all buffer space if not explicitly specified optiling::TCubeTiling tilingData; ret = rnnMatmul3.GetTiling(tilingData); rnnMatmul3.PrintTilingData(); @@ -74,6 +75,7 @@ TEST_F(TestTiling, PlatformConstructor) optiling::TCubeTiling tilingData; int ret = tiling.GetTiling(tilingData); tiling.PrintTilingData(); + tiling.PrintTilingDataInfo(tilingData); EXPECT_EQ(ret, 0); } @@ -319,7 +321,7 @@ TEST_F(TestTiling, L1CacheUBCase02NeiABFullLoad) tiling.PrintTilingData(); EXPECT_EQ(ret, 0); EXPECT_EQ(tilingData.get_depthAL1CacheUB(), 0); - EXPECT_EQ(tilingData.get_depthBL1CacheUB(), 3); + EXPECT_EQ(tilingData.get_depthBL1CacheUB(), 0); } TEST_F(TestTiling, L1CacheUBCase03BothABFullLoad) @@ -339,7 +341,7 @@ TEST_F(TestTiling, L1CacheUBCase03BothABFullLoad) tiling.PrintTilingData(); EXPECT_EQ(ret, 0); EXPECT_EQ(tilingData.get_depthAL1CacheUB(), 1); - EXPECT_EQ(tilingData.get_depthBL1CacheUB(), 1); + EXPECT_EQ(tilingData.get_depthBL1CacheUB(), 0); } TEST_F(TestTiling, L1CacheUBCase04OnlyAFullLoad) @@ -358,7 +360,7 @@ TEST_F(TestTiling, L1CacheUBCase04OnlyAFullLoad) int ret = tiling.GetTiling(tilingData); tiling.PrintTilingData(); EXPECT_EQ(ret, 0); - EXPECT_EQ(tilingData.get_depthAL1CacheUB(), 1); + EXPECT_EQ(tilingData.get_depthAL1CacheUB(), 0); } TEST_F(TestTiling, L1CacheUBCase04OnlyBFullLoad) @@ -377,7 +379,7 @@ TEST_F(TestTiling, L1CacheUBCase04OnlyBFullLoad) int ret = tiling.GetTiling(tilingData); tiling.PrintTilingData(); EXPECT_EQ(ret, 0); - EXPECT_EQ(tilingData.get_depthBL1CacheUB(), 1); + EXPECT_EQ(tilingData.get_depthBL1CacheUB(), 0); } TEST_F(TestTiling, L1CacheUBCase05BothCache) @@ -398,7 +400,7 @@ TEST_F(TestTiling, L1CacheUBCase05BothCache) tiling.PrintTilingData(); EXPECT_EQ(ret, 0); EXPECT_EQ(tilingData.get_depthAL1CacheUB(), 1); - EXPECT_EQ(tilingData.get_depthBL1CacheUB(), 2); + EXPECT_EQ(tilingData.get_depthBL1CacheUB(), 20); } TEST_F(TestTiling, L1CacheUBCase06AMatIsTSCM) @@ -418,7 +420,7 @@ TEST_F(TestTiling, L1CacheUBCase06AMatIsTSCM) tiling.PrintTilingData(); EXPECT_EQ(ret, 0); EXPECT_EQ(tilingData.get_depthAL1CacheUB(), 0); - EXPECT_EQ(tilingData.get_depthBL1CacheUB(), 1); + EXPECT_EQ(tilingData.get_depthBL1CacheUB(), 0); } TEST_F(TestTiling, L1CacheUBCase07BMatIsTSCM) @@ -1885,7 +1887,7 @@ TEST_P(RnnTilingbTestSuite, TestMatmulApiTilngRnnRealCase) rnnParams.baseK = rnnMatmul.GetBaseK(); // get output info after cut ret = rnnMatmul.GetSingleShape(rnnParams.singleM, rnnParams.singleN, rnnParams.singleK); // get single process info ret = rnnMatmul.GetCoreNum(dim, mDim, - nDim); // get used blockdim after multi-cores cut, carried by user to kernel, contrl Kernel + nDim); // get used blockdim after multi-cores cut, carried by user to kernel, contrl Kernel // input mm int32_t l1_left = 512 * 1024 - 64 - rnnParams.singleN * (input_align + hidden_align) * sizeof(float) * 2; ret = rnnMatmul1.SetBufferSpace(l1_left, rnnParams.maxUbSize, rnnParams.maxUbSize); @@ -1922,7 +1924,7 @@ TEST_P(RnnTilingbTestSuite, TestMatmulApiTilngRnnRealCase) sizeof(float); EXPECT_LT(l1UsedSize, 512 * 1024 - 64); rnnParams.usedCoreNum = dim * 4; - } else { // part of full load + } else { // part of full load // two matmul time sharing auto ret = rnnMatmul.SetBias(true); ret = rnnMatmul.SetDim(rnnParams.sysAivCoreNum / 4); @@ -1943,7 +1945,7 @@ TEST_P(RnnTilingbTestSuite, TestMatmulApiTilngRnnRealCase) ret = rnnMatmul.GetSingleShape(rnnParams.singleM, rnnParams.singleN, rnnParams.singleK); // get single process info ret = rnnMatmul.GetCoreNum(dim, mDim, - nDim); // get used blockdim after multi-cores cut, carried by user to kernel, contrl Kernel business + nDim); // get used blockdim after multi-cores cut, carried by user to kernel, contrl Kernel business // input mm ret = rnnMatmul1.SetBufferSpace(-1, rnnParams.maxUbSize, rnnParams.maxUbSize); ret = rnnMatmul1.SetOrgShape(rnnParams.batch, rnnParams.hiddenSize * 4, rnnParams.inputSize); @@ -1977,7 +1979,7 @@ TEST_P(RnnTilingbTestSuite, TestMatmulApiTilngRnnRealCase) sizeof(float); EXPECT_LT(l1UsedSize, 512 * 1024 - 64); rnnParams.usedCoreNum = dim * 4; - } else { // no cache, reset AB,mm cache mechanism lose efficacy + } else { // no cache, reset AB, m cache mechanism lose efficacy std::cout << "can not load any weight" << std::endl; rnnMatmul.SetAType(matmul_tiling::TPosition::GM, matmul_tiling::CubeFormat::ND, (matmul_tiling::DataType)dataType); @@ -2020,7 +2022,7 @@ TEST_P(RnnTilingbTestSuite, TestMatmulApiTilngRnnRealCase) ret = rnnMatmul.GetSingleShape(rnnParams.singleM, rnnParams.singleN, rnnParams.singleK); // get single core data ret = rnnMatmul.GetCoreNum(dim, mDim, - nDim); // get used blockdim after multi-cores cut, carried by user to kernel, contrl Kernel business + nDim); // get used blockdim after multi-cores cut, carried by user to kernel, contrl Kernel business // input mm ret = rnnMatmul1.SetBufferSpace(-1, rnnParams.maxUbSize, rnnParams.maxUbSize); ret = rnnMatmul1.SetOrgShape(rnnParams.batch, rnnParams.hiddenSize * 4, rnnParams.inputSize); @@ -3005,4 +3007,4 @@ TEST_F(TestTiling, tiling_compute_error) bmm_tiling.biasType_.pos = TPosition::TSCM; ret = bmm_tiling.Compute(); EXPECT_EQ(ret, -1); -} +} \ No newline at end of file