From 1562e155ee3b4a181491343070362c67f177a579 Mon Sep 17 00:00:00 2001 From: ivanshan_8170 Date: Mon, 15 Sep 2025 19:18:44 +0800 Subject: [PATCH 1/8] feat aclnnrunner mlapo hiddenSize --- CMakeLists.txt | 1 + include/atb/types.h | 71 +- src/CMakeLists.txt | 2 +- src/atb/kernel_cache/aclnn_executor_cache.cpp | 94 ++ src/atb/kernel_cache/aclnn_executor_cache.h | 38 + src/atb/runner/aclnn_runner.cpp | 154 +++ src/atb/runner/aclnn_runner.h | 38 + src/atb/runner/runner_type.h | 1 + src/atb/utils/aclnn_util.cpp | 300 ++++++ src/atb/utils/aclnn_util.h | 88 ++ src/atb/utils/dl_manager.cpp | 47 + src/atb/utils/dl_manager.h | 25 + .../mla_preprocess_aclnn_runner.cpp | 225 +++++ .../mla_preprocess_aclnn_runner.h | 46 + .../mla_preprocess_operation.cpp | 160 +++- .../mla_preprocess/mla_preprocess_operation.h | 6 +- src/torch_atb/CMakeLists.txt | 2 + .../test_mla_preprocess_aclnn.py | 899 ++++++++++++++++++ tests/cinterface/CMakeLists.txt | 2 +- tests/unittest/CMakeLists.txt | 3 +- 20 files changed, 2169 insertions(+), 33 deletions(-) create mode 100644 src/atb/kernel_cache/aclnn_executor_cache.cpp create mode 100644 src/atb/kernel_cache/aclnn_executor_cache.h create mode 100644 src/atb/runner/aclnn_runner.cpp create mode 100644 src/atb/runner/aclnn_runner.h create mode 100644 src/atb/utils/aclnn_util.cpp create mode 100644 src/atb/utils/aclnn_util.h create mode 100644 src/atb/utils/dl_manager.cpp create mode 100644 src/atb/utils/dl_manager.h create mode 100644 src/ops_infer/mla_preprocess/mla_preprocess_aclnn_runner.cpp create mode 100644 src/ops_infer/mla_preprocess/mla_preprocess_aclnn_runner.h create mode 100644 tests/apitest/opstest/python/operations/mla_preprocess/test_mla_preprocess_aclnn.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 60cc53e9..96c8c1a8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -89,6 +89,7 @@ include_directories( ${PROJECT_SOURCE_DIR}/3rdparty/mki/include ${PROJECT_SOURCE_DIR}/3rdparty/nlohmannJson/include $ENV{ASCEND_HOME_PATH}/include/aclnn + $ENV{ASCEND_HOME_PATH}/include $ENV{PYTHON_INCLUDE_PATH} $ENV{PYTORCH_INSTALL_PATH}/include $ENV{PYTORCH_INSTALL_PATH}/include/torch/csrc/api/include diff --git a/include/atb/types.h b/include/atb/types.h index 5ce82c92..de3f2814 100644 --- a/include/atb/types.h +++ b/include/atb/types.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * Copyright (c) 2025 Huawei Technologies Co., Ltd. * This file is a part of the CANN Open Software. * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). * Please refer to the License for details. You may not use this file except in compliance with the License. @@ -11,9 +11,11 @@ #define ATB_TYPES_H #include #include -#include +#include #include +#include #include +#include #include "atb/svector.h" //! @@ -189,5 +191,68 @@ struct GraphParam { //! \brief inferShape函数指针。 InferShapeFunc inferShapeFunc = nullptr; }; + +//! +//! \struct AclNNIntArray +//! +//! \brief aclnn host tensor +//! +struct AclNNIntArray { + /// This struct is created by calling `aclCreateIntArray` and should be destroyed by calling `aclDestroyIntArray`. + /// \brief It is used to create the `aclOpExecutor`. + aclIntArray *intArray = nullptr; + /// \brief Data used to create the `aclIntArray*`. It is copied from atb::Tensor's hostData. + std::vector data = {}; + /// \brief the origin of the data + std::vector dataOri = {}; + /// \brief The size of `data` in bytes. + uint64_t dataSize = 0; +}; + +//! +//! \struct AclNNTensor +//! +//! 
\brief Used by AclNN operations; wraps the Tensor together with its aclTensor and intArray host data
+//!
+class AclNNTensor {
+public:
+    /// A constant value to indicate that the `tensorListidx` is invalid.
+    static const int64_t notInTensorList = -1;
+
+    /// \brief Tensor passed through the ATB framework.
+    atb::Tensor atbTensor;
+    /// \brief The stride of each dimension in the tensor's view shape. Used when creating `aclTensor`.
+    atb::SVector<int64_t> strides = {};
+    /// \brief Tensor passed into the AclNN operation.
+    aclTensor *tensor = nullptr;
+    /// \brief An AclNNIntArray object containing the tensor's host data in the int array format.
+    AclNNIntArray intArrayHostData;
+    /// \brief The index of the tensor in the tensor list. Used when `aclTensor` is passed into `aclTensorList`.
+    int tensorListidx = notInTensorList;
+    /// \brief The index of the tensor in `aclOpExecutor`'s parameter list.
+    int tensorIdx = -1;
+    /// An indicator that shows whether the tensor's device data needs to be updated in the execution.
+    bool needUpdateTensorDataPtr = false;
+};
+
+//!
+//! \struct AclNNVariantPack
+//!
+//! \brief Input and output tensor pack of an AclNN operation
+//!
+struct AclNNVariantPack {
+    /// \brief A container that stores an AclNN operation's in tensors in order.
+    /// Each `AclNNTensor` object contains one `aclTensor`.
+    atb::SVector<std::shared_ptr<AclNNTensor>> aclInTensors;
+    /// \brief A container that stores an AclNN operation's out tensors in order.
+    /// Each `AclNNTensor` object contains one `aclTensor`.
+    atb::SVector<std::shared_ptr<AclNNTensor>> aclOutTensors;
+    /// \brief A container that stores an AclNN operation's input `aclTensorList` in order.
+    /// Each `aclTensorList` object may contain multiple `aclTensor`.
+    atb::SVector<aclTensorList *> aclInTensorList;
+    /// \brief A container that stores an AclNN operation's output `aclTensorList` in order.
+    /// Each `aclTensorList` object may contain multiple `aclTensor`.
+    atb::SVector<aclTensorList *> aclOutTensorList;
+};
 } // namespace atb
-#endif
\ No newline at end of file
+#endif
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 208da858..db3efe90 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -38,7 +38,7 @@ if(USE_ASAN)
     target_link_libraries(atb PRIVATE asan)
     target_link_libraries(atb_train PRIVATE asan)
 endif()
-target_link_libraries(atb PUBLIC dl mki asdops atb_mixops ascendcl profapi lcal hccl pthread acl_op_compiler nnopbase)
+target_link_libraries(atb PUBLIC dl mki asdops atb_mixops ascendcl profapi lcal hccl pthread acl_op_compiler nnopbase opapi)
 target_link_libraries(atb_train PUBLIC atb)
 target_include_directories(atb PUBLIC ${MSTX_PATH} ${ATB_INCLUDE_DIR})
 target_include_directories(atb_static PUBLIC ${MSTX_PATH} ${ATB_INCLUDE_DIR})
diff --git a/src/atb/kernel_cache/aclnn_executor_cache.cpp b/src/atb/kernel_cache/aclnn_executor_cache.cpp
new file mode 100644
index 00000000..8bacdec9
--- /dev/null
+++ b/src/atb/kernel_cache/aclnn_executor_cache.cpp
@@ -0,0 +1,94 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+#include "aclnn_executor_cache.h"
+#include "atb/utils/log.h"
+#include "atb/utils/aclnn_util.h"
+#include "atb/utils/tensor_util.h"
+
+namespace atb {
+AclnnExecutorCache::AclnnExecutorCache()
+{
+    ATB_LOG(INFO) << "ATB aclnn cache init";
+    cachePool_ = {};
+}
+
+AclnnExecutorCache::~AclnnExecutorCache() {}
+
+Status AclnnExecutorCache::FetchCacheSlot(const std::string &opNameStr, const RunnerVariantPack &aclnnCacheKey,
+                                          AclnnCacheSlot &outAclnnCacheSlot)
+{
+#ifdef _DEBUG
+    ATB_LOG(INFO) << "ATB aclnn op: " << opNameStr
+                  << ", try to fetch cache slot with RunnerVariantPack: " << aclnnCacheKey.ToString();
+#else
+    ATB_LOG(INFO) << "ATB aclnn op: " << opNameStr << ", try to fetch cache slot";
+#endif
+    std::map<std::string, std::vector<std::pair<RunnerVariantPack, AclnnCacheSlot>>>::iterator it =
+        cachePool_.find(opNameStr);
+    if (it == cachePool_.end()) {
+        ATB_LOG(INFO) << "ATB aclnn op: " << opNameStr << " not found in executor cache, consider adding the cache first";
+        return ERROR_INVALID_PARAM;
+    }
+    const std::vector<std::pair<RunnerVariantPack, AclnnCacheSlot>> &slotVec = it->second;
+    int slotVecSize = static_cast<int>(slotVec.size());
+    for (int i = 0; i < slotVecSize; ++i) {
+        if (slotVec[i].second.executor == nullptr) {
+            ATB_LOG(INFO) << "ATB aclnn op: " << opNameStr << " cache index [" << i << "]'s executor is nullptr";
+            continue;
+        }
+        if (TensorUtil::IsRunnerVariantPackInputEqual(aclnnCacheKey, slotVec[i].first)) {
+            ATB_LOG(INFO) << "ATB aclnn op: " << opNameStr << ", fetched slot at index [" << i << "]";
+            outAclnnCacheSlot = slotVec[i].second;
+            return NO_ERROR;
+        }
+    }
+    ATB_LOG(INFO) << "ATB aclnn op: " << opNameStr
+                  << ", cache slot was not found with variantPack: " << aclnnCacheKey.ToString();
+    return ERROR_INVALID_PARAM;
+}
+
+Status AclnnExecutorCache::AddCacheSlot(const std::string &opNameStr, const RunnerVariantPack &aclnnCacheKey,
+                                        AclnnCacheSlot &inAclnnCacheSlot)
+{
+#ifdef _DEBUG
+    ATB_LOG(INFO) << "ATB aclnn op: " << opNameStr
+                  << ", try to add cache slot with RunnerVariantPack: " << aclnnCacheKey.ToString();
+#else
+    ATB_LOG(INFO) << "ATB aclnn op: " << opNameStr << ", try to add cache slot";
+#endif
+    if (inAclnnCacheSlot.executor == nullptr) {
+        ATB_LOG(ERROR) << "ATB aclnn AddCacheSlot with op: " << opNameStr << " got null aclOpExecutor";
+        return ERROR_INVALID_PARAM;
+    }
+    std::map<std::string, std::vector<std::pair<RunnerVariantPack, AclnnCacheSlot>>>::iterator it =
+        cachePool_.find(opNameStr);
+    if (it == cachePool_.end()) {
+        ATB_LOG(INFO) << "ATB aclnn executor cache add op: " << opNameStr;
+        cachePool_[opNameStr].emplace_back(std::make_pair(aclnnCacheKey, inAclnnCacheSlot));
+        return NO_ERROR;
+    }
+
+    std::vector<std::pair<RunnerVariantPack, AclnnCacheSlot>> &slotVec = it->second;
+    size_t slotVecSize = slotVec.size();
+    if (slotVecSize < cacheCapacity_) {
+        slotVec.emplace_back(aclnnCacheKey, inAclnnCacheSlot);
+        ATB_LOG(INFO) << "ATB aclnn executor cache add op: " << opNameStr << " at index[" << slotVecSize << "]";
+        return NO_ERROR;
+    }
+
+    // Eviction policy: fixed-capacity vector replaced in FIFO order
+    ATB_LOG(INFO) << "ATB aclnn executor cache full for op: " << opNameStr << ", update index [" << nextUpdateIndex_
+                  << "]";
+    cachePool_[opNameStr][nextUpdateIndex_] = std::make_pair(aclnnCacheKey, inAclnnCacheSlot);
+    nextUpdateIndex_ = (nextUpdateIndex_ + 1) % cacheCapacity_;
+    return NO_ERROR;
+}
+
+} // namespace atb
diff --git a/src/atb/kernel_cache/aclnn_executor_cache.h b/src/atb/kernel_cache/aclnn_executor_cache.h
new file mode 100644
index 00000000..932313c7
--- /dev/null
+++ b/src/atb/kernel_cache/aclnn_executor_cache.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+#ifndef ATB_ACLNN_EXECUTOR_CACHE_H
+#define ATB_ACLNN_EXECUTOR_CACHE_H
+#include <map>
+#include "atb/types.h"
+#include "atb/utils/runner_variant_pack.h"
+namespace atb {
+
+struct AclnnCacheSlot {
+    uint64_t workspaceSize;
+    std::shared_ptr<aclOpExecutor> executor;
+};
+
+class AclnnExecutorCache {
+public:
+    AclnnExecutorCache();
+    ~AclnnExecutorCache();
+    Status FetchCacheSlot(
+        const std::string &opNameStr, const RunnerVariantPack &aclnnCacheKey, AclnnCacheSlot &outAclnnCacheSlot);
+    Status AddCacheSlot(
+        const std::string &opNameStr, const RunnerVariantPack &aclnnCacheKey, AclnnCacheSlot &inAclnnCacheSlot);
+
+private:
+    std::map<std::string, std::vector<std::pair<RunnerVariantPack, AclnnCacheSlot>>> cachePool_;
+    int nextUpdateIndex_ = 0;
+    uint32_t cacheCapacity_ = 16;
+};
+} // namespace atb
+#endif
diff --git a/src/atb/runner/aclnn_runner.cpp b/src/atb/runner/aclnn_runner.cpp
new file mode 100644
index 00000000..6904e285
--- /dev/null
+++ b/src/atb/runner/aclnn_runner.cpp
@@ -0,0 +1,154 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#include "atb/runner/aclnn_runner.h"
+#include "atb/kernel_cache/aclnn_executor_cache.h"
+#include "atb/utils/aclnn_util.h"
+#include "atb/utils/log.h"
+#include "atb/utils/singleton.h"
+
+namespace atb {
+
+AclnnRunner::AclnnRunner(const std::string &name, RunnerType runnerType) : Runner(name), runnerType_(runnerType) {}
+
+AclnnRunner::~AclnnRunner() {}
+
+Status AclnnRunner::SetupImpl(RunnerVariantPack &runnerVariantPack)
+{
+    ATB_LOG(INFO) << GetLogPrefix() << "aclnn runner setupImpl";
+    if (!runnerVariantPack.context) {
+        ATB_LOG(ERROR) << GetLogPrefix() << "context is null, setup fail";
+        return ERROR_INVALID_CONTEXT_ADDR;
+    }
+    Status ret = NO_ERROR;
+    const std::string &opName = this->GetName();
+    AclnnCacheSlot aclnnCacheSlot = {};
+    // executorCache hit
+    if (executorRepeatable_ &&
+        GetSingleton<AclnnExecutorCache>().FetchCacheSlot(opName, runnerVariantPack, aclnnCacheSlot) == NO_ERROR) {
+        if (!IsAclnnRunnerVariankPackEqual(this->aclnnVariantPack_, runnerVariantPack)) {
+            ATB_LOG(INFO) << GetLogPrefix()
+                          << "fetched cached runnerVariantPack not same as aclnnVariantPack_, build again";
+            ret = BuildAclnnVariantPack(runnerVariantPack);
+            if (ret != NO_ERROR) {
+                ATB_LOG(ERROR) << GetLogPrefix() << "BuildAclnnVariantPack failed!";
+                return ret;
+            }
+        }
+        this->atbVariantPack_.workspaceBufferSize = aclnnCacheSlot.workspaceSize;
+        this->aclnnExecutor_ = aclnnCacheSlot.executor;
+        return NO_ERROR;
+    }
+    // cache miss: create a new executor
+    ATB_LOG(INFO) << GetLogPrefix() << "ExecutorCache miss, BuildAclnnVariantPack directly";
+    ret = BuildAclnnVariantPack(runnerVariantPack);
+    if (ret != NO_ERROR) {
+        ATB_LOG(ERROR) << GetLogPrefix() << "BuildAclnnVariantPack failed!";
+        return ret;
+    }
+    aclnnStatus aclnnRet = ACL_SUCCESS;
+    aclnnRet = SetAclNNWorkspaceExecutor();
+    if (aclnnRet != ACL_SUCCESS) {
+        ATB_LOG(ERROR) << GetLogPrefix() << "Atb aclnn op set workspace failed with return value: " << aclnnRet;
+        return ERROR_CANN_ERROR;
+    }
+    ATB_LOG(INFO) << GetLogPrefix()
+                  << "getWorkspace success, workspaceSize: " << this->atbVariantPack_.workspaceBufferSize
+                  << ", workspace addr: " << this->atbVariantPack_.workspaceBuffer;
+    aclnnRet = aclSetAclOpExecutorRepeatable(this->aclnnExecutor_.get());
+    if (aclnnRet != 0) {
+        // failed to mark the executor repeatable; flag the cached executor as non-reusable
+        ATB_LOG(INFO) << this->GetName() << " call aclSetAclOpExecutorRepeatable fail: " << aclnnRet;
+        this->executorRepeatable_ = false;
+    } else {
+        // marked the executor repeatable; flag the cached executor as reusable
+        ATB_LOG(INFO) << this->GetName() << " call aclSetAclOpExecutorRepeatable success";
+        this->executorRepeatable_ = true;
+    }
+    aclnnCacheSlot = {this->atbVariantPack_.workspaceBufferSize, aclnnExecutor_};
+    ret = GetSingleton<AclnnExecutorCache>().AddCacheSlot(opName, runnerVariantPack, aclnnCacheSlot);
+    if (ret != NO_ERROR) {
+        ATB_LOG(ERROR) << GetLogPrefix() << "AclnnExecutorCache update cache failed!";
+    }
+    ATB_LOG(INFO) << GetLogPrefix() << "AclnnExecutorCache AddCacheSlot success opName: " << opName
+                  << ", runnerVariantPack: " << runnerVariantPack.ToString();
+    return ret;
+}
+
+uint64_t AclnnRunner::GetWorkspaceBufferSizeImpl()
+{
+    return this->atbVariantPack_.workspaceBufferSize;
+}
+
+Status AclnnRunner::PreExecuteImpl(RunnerVariantPack &runnerVariantPack)
+{
+    ATB_LOG(INFO) << GetLogPrefix() << "AclNNOpCacheUpdateAclNNVariantPack";
+    for (size_t i = 0; i < this->aclnnVariantPack_.aclInTensors.size(); ++i) {
+        int ret = -1;
+        if (!this->aclnnVariantPack_.aclInTensors[i]->needUpdateTensorDataPtr) {
+            continue;
+        }
+        this->aclnnVariantPack_.aclInTensors[i]->atbTensor =
runnerVariantPack.inTensors.at(i); + if (this->aclnnVariantPack_.aclInTensors[i]->tensorListidx == AclNNTensor::notInTensorList) { + ret = aclSetInputTensorAddr(this->aclnnExecutor_.get(), this->aclnnVariantPack_.aclInTensors[i]->tensorIdx, + this->aclnnVariantPack_.aclInTensors[i]->tensor, + this->aclnnVariantPack_.aclInTensors[i]->atbTensor.deviceData); + } else { + ret = aclSetDynamicInputTensorAddr( + this->aclnnExecutor_.get(), this->aclnnVariantPack_.aclInTensors[i]->tensorListidx, + this->aclnnVariantPack_.aclInTensors[i]->tensorIdx, + this->aclnnVariantPack_.aclInTensorList[this->aclnnVariantPack_.aclInTensors[i]->tensorListidx], + this->aclnnVariantPack_.aclInTensors[i]->atbTensor.deviceData); + } + if (ret != 0) { + ATB_LOG(ERROR) << GetLogPrefix() << "inTensor " << i << " call UpdateAclTensorDataPtr fail, error: " << ret; + return atb::ERROR_CANN_ERROR; + } + } + + for (size_t i = 0; i < this->aclnnVariantPack_.aclOutTensors.size(); ++i) { + int ret = -1; + if (!this->aclnnVariantPack_.aclOutTensors[i]->needUpdateTensorDataPtr) { + continue; + } + this->aclnnVariantPack_.aclOutTensors[i]->atbTensor = runnerVariantPack.outTensors.at(i); + if (this->aclnnVariantPack_.aclOutTensors[i]->tensorListidx == AclNNTensor::notInTensorList) { + ret = + aclSetOutputTensorAddr(this->aclnnExecutor_.get(), this->aclnnVariantPack_.aclOutTensors[i]->tensorIdx, + this->aclnnVariantPack_.aclOutTensors[i]->tensor, + this->aclnnVariantPack_.aclOutTensors[i]->atbTensor.deviceData); + } else { + ret = aclSetDynamicOutputTensorAddr( + this->aclnnExecutor_.get(), this->aclnnVariantPack_.aclOutTensors[i]->tensorListidx, + this->aclnnVariantPack_.aclOutTensors[i]->tensorIdx, + this->aclnnVariantPack_.aclOutTensorList[this->aclnnVariantPack_.aclOutTensors[i]->tensorListidx], + this->aclnnVariantPack_.aclOutTensors[i]->atbTensor.deviceData); + } + if (ret != 0) { + ATB_LOG(ERROR) << GetLogPrefix() << "outTensor " << i + << " call UpdateAclTensorDataPtr fail, error: " << ret; + return atb::ERROR_CANN_ERROR; + } + } + + return atb::NO_ERROR; +} + +void AclnnRunner::UpdateWorkspace(const RunnerVariantPack &runnerVariantPack) +{ + this->atbVariantPack_.workspaceBufferSize = runnerVariantPack.workspaceBufferSize; +} + +Status AclnnRunner::ExecuteImpl(RunnerVariantPack &runnerVariantPack) +{ + UpdateWorkspace(runnerVariantPack); + return LaunchAclnnKernel(); +} + +} // namespace atb diff --git a/src/atb/runner/aclnn_runner.h b/src/atb/runner/aclnn_runner.h new file mode 100644 index 00000000..26933b79 --- /dev/null +++ b/src/atb/runner/aclnn_runner.h @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */
+
+#ifndef ATB_ACLNN_RUNNER_H
+#define ATB_ACLNN_RUNNER_H
+#include "runner_type.h"
+#include "runner.h"
+
+namespace atb {
+class AclnnRunner : public Runner {
+public:
+    explicit AclnnRunner(const std::string &name, RunnerType runnerType = RUNNER_TYPE_UNDEFINED);
+    ~AclnnRunner() override;
+protected:
+    Status SetupImpl(RunnerVariantPack &runnerVariantPack) override;
+    virtual Status BuildAclnnVariantPack(const RunnerVariantPack &runnerVariantPack) = 0;
+    virtual aclnnStatus SetAclNNWorkspaceExecutor() = 0;
+
+    Status PreExecuteImpl(RunnerVariantPack &runnerVariantPack) override;
+    Status ExecuteImpl(RunnerVariantPack &runnerVariantPack) override;
+    virtual Status LaunchAclnnKernel() = 0;
+    uint64_t GetWorkspaceBufferSizeImpl() override;
+    void UpdateWorkspace(const RunnerVariantPack &runnerVariantPack);
+    RunnerType runnerType_ = RUNNER_TYPE_UNDEFINED;
+    bool executorRepeatable_ = false;
+    std::shared_ptr<aclOpExecutor> aclnnExecutor_ = nullptr;
+    AclNNVariantPack aclnnVariantPack_;
+    RunnerVariantPack atbVariantPack_;
+};
+} // namespace atb
+#endif
diff --git a/src/atb/runner/runner_type.h b/src/atb/runner/runner_type.h
index b226615d..fe92eac1 100644
--- a/src/atb/runner/runner_type.h
+++ b/src/atb/runner/runner_type.h
@@ -109,6 +109,7 @@ enum RunnerType : int {
     RUNNER_TYPE_GMM_DEQ_SWIGLU_QUANT_GMM_DEQ,
     RUNNER_TYPE_MM_DEQ_SWIGLU_QUANT_MM_DEQ,
     RUNNER_TYPE_RING_MLA,
+    RUNNER_TYPE_MLA_PREPROCESS_ACLNN,
     RUNNER_TYPE_MAX,
 };
 } // namespace atb
diff --git a/src/atb/utils/aclnn_util.cpp b/src/atb/utils/aclnn_util.cpp
new file mode 100644
index 00000000..dcbd1ac7
--- /dev/null
+++ b/src/atb/utils/aclnn_util.cpp
@@ -0,0 +1,300 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */ + +#include +#include +#include + +#include "aclnn_util.h" +#include "log.h" + +namespace { +/// Constant to represent index +const int DIM0 = 0; +const int DIM1 = 1; +const int DIM2 = 2; +const int DIM3 = 3; +} // namespace + + +namespace atb { +template +typename std::common_type::type CheckIntMulOverFlow(const T a, const U b) +{ + if (std::is_signed::value != std::is_signed::value) { + throw std::runtime_error("Multiplication between signed and unsigned integer not supported, it's not safe"); + } + using PromotedType = typename std::common_type::type; + if (a == 0 || b == 0) { + return 0; + } + + PromotedType pa = static_cast(a); + PromotedType pb = static_cast(b); + + if constexpr (std::is_signed::value) { + const PromotedType maxVal = std::numeric_limits::max(); + const PromotedType minVal = std::numeric_limits::min(); + if (pa > 0 && pb > 0) { + if (pa > maxVal / pb) { + throw std::overflow_error("Integer overflow detected."); + } + } else if (pa < 0 && pb < 0) { + if (pa < maxVal / pb) { + throw std::overflow_error("Integer overflow detected."); + } + } else if (pa > 0 && pb < 0) { + if (pa > minVal / pb) { + throw std::overflow_error("Integer overflow detected."); + } + } else if (pa < minVal / pb) { + throw std::overflow_error("Integer overflow detected."); + } + } else { + const PromotedType maxVal = std::numeric_limits::max(); + if (pa > maxVal / pb) { + throw std::overflow_error("Integer overflow detected."); + } + } + return pa * pb; +} + +atb::SVector GetCopyTensorStride(atb::Dims &tensorDims) +{ + atb::SVector tmpStrides(tensorDims.dimNum, 1); + if (tensorDims.dimNum > 8) { // 8: tensor最大维度数量 + ATB_LOG(ERROR) << "Tensor's dimNum is larger than 8, `GetCopyTensorStride` failed."; + return tmpStrides; + } + for (int64_t i = static_cast(tensorDims.dimNum) - 2; i >= 0; i--) { + tmpStrides[i] = CheckIntMulOverFlow(tensorDims.dims[i + 1], tmpStrides[i + 1]); + } + return tmpStrides; +} + +atb::SVector GetTransposeTensorStride(atb::Dims &tensorDims) +{ + atb::SVector tmptransposeStrides(tensorDims.dimNum, 1); + tmptransposeStrides[tensorDims.dimNum - 1] = tensorDims.dims[tensorDims.dimNum - 1]; + if (tensorDims.dimNum == 3) { // 3: 维度 + tmptransposeStrides[0] = CheckIntMulOverFlow( // 0: 第0维 + tensorDims.dims[1], + tensorDims.dims[2]); // 1, 2: 跳过第1维和第2维的大小 + } + return tmptransposeStrides; +} + +atb::Status CallAclCreateTensor( + atb::Dims &viewDims, atb::Dims &storageDims, atb::Tensor &atbTensor, std::shared_ptr aclnnTensor) +{ + aclnnTensor->tensor = aclCreateTensor(viewDims.dims, + viewDims.dimNum, + atbTensor.desc.dtype, + aclnnTensor->strides.data(), + 0, + atbTensor.desc.format, + storageDims.dims, + storageDims.dimNum, + atbTensor.deviceData); + if (aclnnTensor->tensor == nullptr) { + return atb::ERROR_INTERNAL_ERROR; + } + return atb::NO_ERROR; +} + +atb::Tensor SqueezeBatchSeq(atb::Tensor atbTensor) +{ + if (atbTensor.desc.shape.dimNum == DIM3) { + atbTensor.desc.shape.dimNum = DIM2; + atbTensor.desc.shape.dims[DIM0] = + CheckIntMulOverFlow(atbTensor.desc.shape.dims[DIM0], atbTensor.desc.shape.dims[DIM1]); + atbTensor.desc.shape.dims[DIM1] = atbTensor.desc.shape.dims[DIM2]; + } + return atbTensor; +} + +std::string PrintAclNNVariankPack(const AclNNVariantPack &aclnnVariantPack) +{ + std::stringstream ss; + ss << "ATB aclnn Op Cache: AclNNVariantPack "; + for (size_t i = 0; i < aclnnVariantPack.aclInTensors.size(); i++) { + const atb::TensorDesc &tensorDesc = aclnnVariantPack.aclInTensors[i]->atbTensor.desc; + ss << "index " << i << " dtype " << tensorDesc.dtype << " 
format " << tensorDesc.format << " dimNum " + << tensorDesc.shape.dimNum; + for (uint64_t j = 0; j < std::min(tensorDesc.shape.dimNum, static_cast(8)); + j++) { // 8: tensor最大维度数量 + ss << "dim[" << j << "]=" << tensorDesc.shape.dims[j] << " "; + } + } + return ss.str(); +} + +std::string PrintATBVariankPack(const atb::VariantPack &atbVariantPack) +{ + std::stringstream ss; + ss << "ATB aclnn Op Cache: ATBVariantPack "; + for (size_t i = 0; i < atbVariantPack.inTensors.size(); i++) { + const atb::TensorDesc &tensorDesc = atbVariantPack.inTensors[i].desc; + ss << "index " << i << " dtype " << tensorDesc.dtype << " format " << tensorDesc.format << " dimNum " + << tensorDesc.shape.dimNum; + for (uint64_t j = 0; j < std::min(tensorDesc.shape.dimNum, static_cast(8)); + j++) { // 8: tensor最大维度数量 + ss << "dim[" << j << "]=" << tensorDesc.shape.dims[j] << " "; + } + } + return ss.str(); +} + +bool IsHostDataEqual(const std::shared_ptr tensorA, const atb::Tensor &tensorB, int tensorIdx) +{ + if (tensorA->intArrayHostData.intArray != nullptr && tensorB.hostData == nullptr) { + ATB_LOG(DEBUG) << "ATB aclnn Op Cache: tensor index " << tensorIdx + << " aclnnVariantPack hostData is not null but atbVariantPack hostData is"; + return false; + } + if (tensorA->intArrayHostData.intArray == nullptr && tensorB.hostData != nullptr) { + ATB_LOG(DEBUG) << "ATB aclnn Op Cache: tensor index " << tensorIdx + << " aclnnVariantPack hostData is null but atbVariantPack hostData is not"; + return false; + } + if (tensorA->intArrayHostData.intArray != nullptr && tensorB.hostData != nullptr) { + if (tensorA->intArrayHostData.dataOri.size() * 4 != tensorB.dataSize) { // 8: int64_t in bytes + ATB_LOG(DEBUG) << "ATB aclnn Op Cache: tensor index " << tensorIdx << " dataSize not equal"; + return false; + } + if (memcmp(tensorA->intArrayHostData.dataOri.data(), tensorB.hostData, tensorB.dataSize) != 0) { + ATB_LOG(DEBUG) << "ATB aclnn Op Cache: tensor index " << tensorIdx << " hostData not equal"; + return false; + } + } + return true; +} + +bool IsTensorDescEqual(const atb::TensorDesc &tensorDescA, const atb::TensorDesc &tensorDescB, int tensorIdx) +{ + if (tensorDescA.dtype != tensorDescB.dtype) { + ATB_LOG(DEBUG) << "ATB aclnn Op Cache: tensor index " << tensorIdx + << " dtype not equal, aclnnVariantPack dtype " << tensorDescA.dtype << " atbVariantPack dtype " + << tensorDescB.dtype; + return false; + } + if (tensorDescA.format != tensorDescB.format) { + ATB_LOG(DEBUG) << "ATB aclnn Op Cache: tensor index " << tensorIdx + << " format not equal, aclnnVariantPack format " << tensorDescA.format + << " atbVariantPack format " << tensorDescB.format; + return false; + } + if (tensorDescA.shape.dimNum != tensorDescB.shape.dimNum || tensorDescA.shape.dimNum > 8 || + tensorDescA.shape.dimNum <= 0) { // 8: tensor最大维度数量 + ATB_LOG(DEBUG) << "ATB aclnn Op Cache: tensor index " << tensorIdx + << " dimNum not equal, aclnnVariantPack dimNum " << tensorDescA.shape.dimNum + << " atbVariantPack dimNum " << tensorDescB.shape.dimNum; + return false; + } + for (uint64_t j = 0; j < tensorDescA.shape.dimNum; j++) { + if (tensorDescA.shape.dims[j] != tensorDescB.shape.dims[j]) { + ATB_LOG(DEBUG) << "ATB aclnn Op Cache: : tensor index " << tensorIdx << " shape.dims " << j + << " not equal, aclnnVariantPack value " << tensorDescA.shape.dims[j] + << " atbVariantPack value " << tensorDescB.shape.dims[j]; + return false; + } + } + return true; +} + +bool AreTensorVectorsEqual( + const atb::SVector> &aclnnTensors, const atb::SVector &atbTensors) +{ + // 
Check the size of two vectors + if (aclnnTensors.size() != atbTensors.size()) { + ATB_LOG(DEBUG) << "ATB aclnn Op Cache: size not equal, aclnnVariantPack size " << aclnnTensors.size() + << " atbVariantPack size " << atbTensors.size(); + return false; + } + + // Check if every tensor in each vector has consistent data type, format, shape and host data. + for (size_t i = 0; i < aclnnTensors.size(); i++) { + const std::shared_ptr tensorA = aclnnTensors[i]; + const atb::Tensor &tensorB = atbTensors[i]; + + if (!IsHostDataEqual(tensorA, tensorB, i)) { + return false; + } + + if (!IsTensorDescEqual(tensorA->atbTensor.desc, tensorB.desc, i)) { + return false; + } + } + + return true; +} + +bool IsAclnnRunnerVariankPackEqual(const AclNNVariantPack &aclnnVariantPack, const RunnerVariantPack &runnerVariantPack) +{ + ATB_LOG(INFO) << "Compare AclNNVariantPack with RunnerVariantPack:"; + ATB_LOG(INFO) << PrintAclNNVariankPack(aclnnVariantPack); + ATB_LOG(INFO) << runnerVariantPack.ToString(); + if (!AreTensorVectorsEqual(aclnnVariantPack.aclInTensors, runnerVariantPack.inTensors)) { + return false; + } + + if (!AreTensorVectorsEqual(aclnnVariantPack.aclOutTensors, runnerVariantPack.outTensors)) { + return false; + } + ATB_LOG(INFO) << "ATB aclnn Op Cache: TensorDesc match"; + return true; +} + +bool IsAclnnAtbVariankPackEqual(const AclNNVariantPack &aclnnVariantPack, const atb::VariantPack &atbVariantPack) +{ + ATB_LOG(INFO) << PrintAclNNVariankPack(aclnnVariantPack); + ATB_LOG(INFO) << PrintATBVariankPack(atbVariantPack); + + if (!AreTensorVectorsEqual(aclnnVariantPack.aclInTensors, atbVariantPack.inTensors)) { + return false; + } + + if (!AreTensorVectorsEqual(aclnnVariantPack.aclOutTensors, atbVariantPack.outTensors)) { + return false; + } + ATB_LOG(INFO) << "ATB aclnn Op Cache: TensorDesc match"; + return true; +} + +std::shared_ptr CreateTensor(atb::Tensor atbTensor, int tensorIdx) +{ + std::shared_ptr aclnnTensor = std::make_shared(); + aclnnTensor->needUpdateTensorDataPtr = true; + aclnnTensor->atbTensor = atbTensor; + aclnnTensor->tensorIdx = tensorIdx; + aclnnTensor->strides = GetCopyTensorStride(atbTensor.desc.shape); + CallAclCreateTensor(atbTensor.desc.shape, atbTensor.desc.shape, atbTensor, aclnnTensor); + return aclnnTensor; +} + +int ConvertTensorToSeqLengths(atb::Tensor &tensor, aclIntArray *&actualSeqLengths) +{ + static std::vector seqLenCache; + size_t dataSize = tensor.dataSize / 8; // 8: int64 size + if (seqLenCache.size() < dataSize) { + seqLenCache.resize(dataSize); + } + if (memcpy_s(seqLenCache.data(), dataSize * 8, tensor.hostData, dataSize * 8) != 0) { // 8: int64 size + ATB_LOG(ERROR) << "memcpy_s failed!"; + return atb::ERROR_INTERNAL_ERROR; + } + if (actualSeqLengths != nullptr) { + aclDestroyIntArray(actualSeqLengths); + actualSeqLengths = nullptr; + } + actualSeqLengths = aclCreateIntArray(static_cast(seqLenCache.data()), dataSize); + return atb::NO_ERROR; +} +} // namespace atb diff --git a/src/atb/utils/aclnn_util.h b/src/atb/utils/aclnn_util.h new file mode 100644 index 00000000..9806541e --- /dev/null +++ b/src/atb/utils/aclnn_util.h @@ -0,0 +1,88 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ATB_ACLNN_UTILS_H +#define ATB_ACLNN_UTILS_H +#include "atb/types.h" +#include "atb/utils/runner_variant_pack.h" + +namespace atb { + +/// Calculate stride along each dimension when copying a tensor. +/// +/// \param tensorDims The size of each axis in a tensor. +/// \return The number of steps needed in memory to move to the next element +/// along each dimension when copying a tensor. +atb::SVector GetCopyTensorStride(atb::Dims &tensorDims); + +/// Calculate stride along each dimension when transposing a tensor. +/// +/// \param tensorDims The size of each axis in a tensor. +/// \return The number of steps needed in memory to move to the next element +/// along each dimension when transposing a tensor. +atb::SVector GetTransposeTensorStride(atb::Dims &tensorDims); + +/// Call `aclCreateTensor` API to create `aclTensor`. +/// +/// \param viewDims The dimension of tensor's view shape. +/// \param storageDims The dimension of tensor's storage shape. +/// \param atbTensor The tensor passed through ATB framework. +/// \param aclnnTensor A pointer to an `AclNNTensor` object whose `tensor` attribute is updated +/// using the return value of `aclCreateTensor`. +/// \return A status code that indicates whether `aclTensor` has been created. +atb::Status CallAclCreateTensor(atb::Dims &viewDims, atb::Dims &storageDims, atb::Tensor &atbTensor, + std::shared_ptr aclnnTensor); + +/// Reshape a tensor by squeezing batch size axis and seq len axis if the tensor's shape has two dimensions. +/// +/// \param atbTensor An `atb::Tensor` object whose tensor shape requires reshaping. +/// \return The `atb::Tensor` after reshaping. +atb::Tensor SqueezeBatchSeq(atb::Tensor atbTensor); + +/// Check whether `aclnnVariantPack` and `atbVariantPack` are the same, except for tensors' device data. +/// +/// Two variant packs, `aclnnVariantPack` and `atbVariantPack`, are considered the same if they have the same +/// number of tensors, and each corresponding tensor in the variant packs +/// has identical data type,format,shape and host data. +/// \param aclnnVariantPack An `AclNNVariantPack` object containing tensor info of existing AclNN operation. +/// \param atbVariantPack An `atb::VariantPack` object containing tensor info passed through ATB framework. +/// \return A boolean value that indicates whether `aclnnVariantPack` and `atbVariantPack` are the same, +/// except for tensors' device data. +bool IsAclnnAtbVariankPackEqual(const AclNNVariantPack &aclnnVariantPack, const VariantPack &atbVariantPack); + +/// Check whether `aclnnVariantPack` and `atbVariantPack` are the same, except for tensors' device data. +/// +/// Two variant packs, `aclnnVariantPack` and `atbVariantPack`, are considered the same if they have the same +/// number of tensors, and each corresponding tensor in the variant packs +/// has identical data type,format,shape and host data. +/// \param aclnnVariantPack An `AclNNVariantPack` object containing tensor info of existing AclNN operation. +/// \param atbVariantPack An `atb::VariantPack` object containing tensor info passed through ATB framework. 
+/// \return A boolean value that indicates whether `aclnnVariantPack` and `atbVariantPack` are the same, +/// except for tensors' device data. +bool IsAclnnRunnerVariankPackEqual(const AclNNVariantPack &aclnnVariantPack, const RunnerVariantPack &runnerVariantPack); + +/// Create a pointer to `AclNNTensor` by configuring it with tensor information extracted from `atbTensor`. +/// +/// `needUpdateTensorDataPtr` is set to true, `atbTensor` and `tensorIdx` are updated with input parameters. +/// `strides` is updated by `GetCopyTensorStride`. +/// \param atbTensor Tensor passed through the ATB framework. +/// \param tensorIdx The index of the tensor in `aclOpExecutor`'s parameter list. +/// \return A pointer to an `AclNNTensor` object whose attributes are updated. +std::shared_ptr CreateTensor(atb::Tensor atbTensor, int tensorIdx); + +int ConvertTensorToSeqLengths(atb::Tensor &tensor, aclIntArray *&actualSeqLengths); +} // namespace atb +#endif \ No newline at end of file diff --git a/src/atb/utils/dl_manager.cpp b/src/atb/utils/dl_manager.cpp new file mode 100644 index 00000000..5bc354a7 --- /dev/null +++ b/src/atb/utils/dl_manager.cpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ + +#include "atb/utils/dl_manager.h" +#include + +namespace atb { +DlManager::DlManager(std::string path) : path_(path) +{ + // 在构造函数中加载动态库 + handle_ = dlopen(path.c_str(), RTLD_LAZY); + if (handle_ == nullptr) { + ATB_LOG(ERROR) << "Dynamic library handle is null, please check the path: " << path_; + } +} + +DlManager::~DlManager() +{ + // 在析构函数中卸载动态库 + if (handle_ != nullptr) { + dlclose(handle_); + } +} + +Status DlManager::getSymbol(const std::string &symbol, void **symbolPtr) const +{ + if (handle_ == nullptr) { + ATB_LOG(ERROR) << "Dynamic library handle is null, please check the path: " << path_; + return ERROR_CANN_ERROR; + } + *symbolPtr = dlsym(handle_, symbol.c_str()); + char *errorInfo = dlerror(); + if (*symbolPtr == nullptr || errorInfo != nullptr) { + ATB_LOG(ERROR) << "Failed to find symbol " << symbol << " from path: " << path_ << ", error: " << errorInfo; + return ERROR_CANN_ERROR; + } + return NO_ERROR; +} + +} // namespace atb diff --git a/src/atb/utils/dl_manager.h b/src/atb/utils/dl_manager.h new file mode 100644 index 00000000..69568d6f --- /dev/null +++ b/src/atb/utils/dl_manager.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ */ + +#include +#include "atb/utils/log.h" + +namespace atb { +class DlManager { +public: + DlManager(std::string path); + ~DlManager(); + Status getSymbol(const std::string &symbol, void **symbolPtr) const; + +private: + std::string path_; + void *handle_ = nullptr; +}; +} // namespace atb diff --git a/src/ops_infer/mla_preprocess/mla_preprocess_aclnn_runner.cpp b/src/ops_infer/mla_preprocess/mla_preprocess_aclnn_runner.cpp new file mode 100644 index 00000000..fadf0457 --- /dev/null +++ b/src/ops_infer/mla_preprocess/mla_preprocess_aclnn_runner.cpp @@ -0,0 +1,225 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#include "mla_preprocess_aclnn_runner.h" +#include "atb/utils/dl_manager.h" +#include "atb/utils/aclnn_util.h" +#include "atb/utils/log.h" +#include + +namespace { +static const uint32_t IN_TENSOR_NUM = 24; +static const uint32_t OUT_TENSOR_NUM = 4; +static const uint32_t Q_OUT1_INDEX = 2; +static const uint32_t KV_CACHE_OUT1_INDEX = 3; +static const uint32_t KV_CACHE_ROPE_INDEX = 20; +} // namespace + +namespace atb { +// 初始化类函数指针 +aclnnStatus (*MlaPreprocessAclnnRunner::aclnnGetWorkspaceSizeFunc_)( + const aclTensor *, const aclTensor *, const aclTensor *, const aclTensor *, const aclTensor *, const aclTensor *, + const aclTensor *, const aclTensor *, const aclTensor *, const aclTensor *, const aclTensor *, const aclTensor *, + const aclTensor *, const aclTensor *, const aclTensor *, const aclTensor *, const aclTensor *, const aclTensor *, + const aclTensor *, const aclTensor *, const aclTensor *, const aclTensor *, const aclTensor *, const aclTensor *, + int64_t, int64_t, int64_t, double, int64_t, int64_t, bool, bool, bool, int64_t, int64_t, bool, int64_t, + const aclTensor *, const aclTensor *, const aclTensor *, const aclTensor *, uint64_t *, aclOpExecutor **) = nullptr; + +aclnnStatus (*MlaPreprocessAclnnRunner::aclnnExecuteFunc_)(void *, uint64_t, aclOpExecutor *, aclrtStream) = nullptr; + +MlaPreprocessAclnnRunner::MlaPreprocessAclnnRunner(const infer::MlaPreprocessParam ¶m) + : AclnnRunner("MlaPreprocessOpsRunner", RUNNER_TYPE_MLA_PREPROCESS_ACLNN), param_(param) +{ + ATB_LOG(INFO) << GetLogPrefix() << "MlaPreprocessAclnnRunner::MlaPreprocessAclnnRunner called"; +} + +MlaPreprocessAclnnRunner::MlaPreprocessAclnnRunner(const infer::MlaPreprocessParam ¶m, bool doRmsNorm) + : AclnnRunner("MlaPreprocessOpsRunner", RUNNER_TYPE_MLA_PREPROCESS_ACLNN), param_(param), doRmsNorm_(doRmsNorm) +{ + ATB_LOG(INFO) << GetLogPrefix() + << "MlaPreprocessAclnnRunnecr::MlaPreprocessAclnnRunner called, set doRmsNorm: " << doRmsNorm; +} + +MlaPreprocessAclnnRunner::~MlaPreprocessAclnnRunner() {} + +Status MlaPreprocessAclnnRunner::BuildAclnnVariantPack(const RunnerVariantPack &runnerVariantPack) +{ + ATB_LOG(INFO) << GetLogPrefix() << "BuildAclnnVariantPack"; + this->atbVariantPack_ = runnerVariantPack; + Status ret = NO_ERROR; + bool isRopeCache = param_.cacheMode != infer::MlaPreprocessParam::CacheMode::KVCACHE; + 
this->aclnnVariantPack_.aclInTensors.reserve(IN_TENSOR_NUM);
+    this->aclnnVariantPack_.aclInTensors.resize(IN_TENSOR_NUM);
+    for (size_t i = 0; i < this->aclnnVariantPack_.aclInTensors.size(); ++i) {
+        std::shared_ptr<AclNNTensor> aclnnTensorPtr = std::make_shared<AclNNTensor>();
+        if (i == KV_CACHE_ROPE_INDEX && !isRopeCache) {
+            // when kvCache carries no transposed rope part, kvCacheRope stays nullptr
+            this->aclnnVariantPack_.aclInTensors[i] = aclnnTensorPtr;
+            continue;
+        }
+        atb::Tensor atbTensor = runnerVariantPack.inTensors.at(i);
+        aclnnTensorPtr->atbTensor = atbTensor;
+        aclnnTensorPtr->strides = GetCopyTensorStride(atbTensor.desc.shape);
+        ret = CallAclCreateTensor(atbTensor.desc.shape, atbTensor.desc.shape, atbTensor, aclnnTensorPtr);
+        if (ret != NO_ERROR) {
+            ATB_LOG(ERROR) << GetLogPrefix() << "create aclTensor by aclCreateTensor failed!";
+            return ret;
+        }
+        aclnnTensorPtr->tensorIdx = i;
+        aclnnTensorPtr->needUpdateTensorDataPtr = true;
+        this->aclnnVariantPack_.aclInTensors[i] = aclnnTensorPtr;
+    }
+
+    this->aclnnVariantPack_.aclOutTensors.reserve(OUT_TENSOR_NUM);
+    this->aclnnVariantPack_.aclOutTensors.resize(OUT_TENSOR_NUM);
+    for (size_t i = 0; i < this->aclnnVariantPack_.aclOutTensors.size(); ++i) {
+        std::shared_ptr<AclNNTensor> aclnnTensorPtr = std::make_shared<AclNNTensor>();
+        if ((i == Q_OUT1_INDEX || i == KV_CACHE_OUT1_INDEX) && !isRopeCache) {
+            // when kvCache carries no transposed rope part, the two rope outputs are not produced
+            this->aclnnVariantPack_.aclOutTensors[i] = aclnnTensorPtr;
+            continue;
+        }
+        atb::Tensor atbTensor = runnerVariantPack.outTensors.at(i);
+        aclnnTensorPtr->atbTensor = atbTensor;
+        aclnnTensorPtr->strides = GetCopyTensorStride(atbTensor.desc.shape);
+        ret = CallAclCreateTensor(atbTensor.desc.shape, atbTensor.desc.shape, atbTensor, aclnnTensorPtr);
+        if (ret != NO_ERROR) {
+            ATB_LOG(ERROR) << GetLogPrefix() << "create aclTensor by aclCreateTensor failed!";
+            return ret;
+        }
+        aclnnTensorPtr->tensorIdx = i;
+        aclnnTensorPtr->needUpdateTensorDataPtr = true;
+        this->aclnnVariantPack_.aclOutTensors[i] = aclnnTensorPtr;
+    }
+    return atb::NO_ERROR;
+}
+
+aclnnStatus MlaPreprocessAclnnRunner::SetAclNNWorkspaceExecutor()
+{
+    ATB_LOG(INFO) << GetLogPrefix() << "aclnn mlaPreprocess setup start.";
+    Status status = MlaPreprocessAclnnRunner::LoadMethod();
+    if (status != NO_ERROR) {
+        ATB_LOG(ERROR) << GetLogPrefix()
+                       << "load getWorkspace function from aclnn failed!
Consider upgrade CANN first!"; + return status; + } + size_t inTensorStart = 0; + bool isRopeCache = param_.cacheMode != infer::MlaPreprocessParam::CacheMode::KVCACHE; + ATB_LOG(INFO) << GetLogPrefix() << "aclnn mlaPreprocess isRopeCache: " << isRopeCache + << ", aclInTensors size: " << this->aclnnVariantPack_.aclInTensors.size() + << ", aclOutTensors size: " << this->aclnnVariantPack_.aclOutTensors.size(); + aclTensor *input = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor; + aclTensor *gamma0 = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor; + aclTensor *beta0 = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor; + aclTensor *quantScale0 = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor; + aclTensor *quantOffset0 = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor; + aclTensor *wdqkv = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor; + aclTensor *deScale0 = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor; + aclTensor *bias0 = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor; + aclTensor *gamma1 = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor; + aclTensor *beta1 = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor; + aclTensor *quantScale1 = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor; + aclTensor *quantOffset1 = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor; + aclTensor *wuq = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor; + aclTensor *deScale1 = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor; + aclTensor *bias1 = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor; + aclTensor *gamma2 = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor; + aclTensor *cos = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor; + aclTensor *sin = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor; + aclTensor *wuk = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor; + aclTensor *kvCache = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor; + aclTensor *kRope = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor; + if (!isRopeCache) { + kRope = nullptr; + } + aclTensor *slotmapping = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor; + aclTensor *ctkvScale = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor; + aclTensor *qNopeScale = this->aclnnVariantPack_.aclInTensors.at(inTensorStart++)->tensor; + size_t outTensorStart = 0; + aclTensor *qOut0 = this->aclnnVariantPack_.aclOutTensors.at(outTensorStart++)->tensor; + aclTensor *kvCacheOut0 = this->aclnnVariantPack_.aclOutTensors.at(outTensorStart++)->tensor; + aclTensor *qOut1 = nullptr; + aclTensor *kvCacheOut1 = nullptr; + if (isRopeCache) { + qOut1 = this->aclnnVariantPack_.aclOutTensors.at(outTensorStart++)->tensor; + kvCacheOut1 = this->aclnnVariantPack_.aclOutTensors.at(outTensorStart++)->tensor; + } + + aclOpExecutor *raw_executor_ptr = this->aclnnExecutor_.get(); + ATB_LOG(INFO) << GetLogPrefix() << "&(this->aclnnExecutor_): " << &(this->aclnnExecutor_) + << ", addr of this->aclnnExecutor_: " << this->aclnnExecutor_ + << ", raw ptr from it: " << raw_executor_ptr + << ", then take the address of the raw ptr: " << &raw_executor_ptr; + + ATB_LOG(INFO) << GetLogPrefix() << "workspaceSize addr: " << &(this->atbVariantPack_.workspaceBufferSize); + + aclnnStatus ret = 
MlaPreprocessAclnnRunner::aclnnGetWorkspaceSizeFunc_(
+        input, gamma0, beta0, quantScale0, quantOffset0, wdqkv, deScale0, bias0, gamma1, beta1, quantScale1,
+        quantOffset1, wuq, deScale1, bias1, gamma2, cos, sin, wuk, kvCache, kRope, slotmapping, ctkvScale, qNopeScale,
+        param_.wdqDim, param_.qRopeDim, param_.kRopeDim, param_.epsilon, param_.qRotaryCoeff, param_.kRotaryCoeff,
+        param_.transposeWdq, param_.transposeWuq, param_.transposeWuk, param_.cacheMode, param_.quantMode,
+        doRmsNorm_, // doRmsNorm
+        1,          // wdkvSplitCount
+        qOut0, kvCacheOut0, qOut1, kvCacheOut1, &(this->atbVariantPack_.workspaceBufferSize), &raw_executor_ptr);
+    // TODO: move this aclSetAclOpExecutorRepeatable call up
+    aclError aclRet = aclSetAclOpExecutorRepeatable(this->aclnnExecutor_.get());
+    bool repeatable = this->executorRepeatable_;
+    this->aclnnExecutor_ = std::shared_ptr<aclOpExecutor>(raw_executor_ptr, [repeatable](aclOpExecutor *ptr) {
+        if (ptr && repeatable) { // only destroy the aclOpExecutor manually when it is marked repeatable
+            aclDestroyAclOpExecutor(ptr);
+        }
+    });
+    ATB_LOG(INFO) << GetLogPrefix() << "workspaceSize: " << this->atbVariantPack_.workspaceBufferSize;
+    return ret;
+}
+
+Status MlaPreprocessAclnnRunner::LaunchAclnnKernel()
+{
+    ATB_LOG(INFO) << GetLogPrefix() << " execute start.";
+    Status status = MlaPreprocessAclnnRunner::LoadMethod();
+    if (status != NO_ERROR) {
+        ATB_LOG(ERROR) << GetLogPrefix()
+                       << "load getWorkspace function from aclnn failed! Consider upgrading CANN first!";
+        return status;
+    }
+    void *executeStream = GetExecuteStream(this->atbVariantPack_.context);
+    aclnnStatus ret = MlaPreprocessAclnnRunner::aclnnExecuteFunc_(this->atbVariantPack_.workspaceBuffer,
+                                                                  this->atbVariantPack_.workspaceBufferSize,
+                                                                  this->aclnnExecutor_.get(), executeStream);
+    if (ret != ACL_SUCCESS) {
+        ATB_LOG(ERROR) << GetLogPrefix() << "Atb aclnn op kernel launch failed with return value: " << ret;
+        return ERROR_CANN_ERROR;
+    }
+    ATB_LOG(INFO) << GetLogPrefix() << " execute success.";
+    return NO_ERROR;
+}
+
+Status MlaPreprocessAclnnRunner::LoadMethod()
+{
+    ATB_LOG(INFO) << "MlaPreprocessAclnnRunner LoadMethod";
+    if (MlaPreprocessAclnnRunner::aclnnGetWorkspaceSizeFunc_ != nullptr &&
+        MlaPreprocessAclnnRunner::aclnnExecuteFunc_ != nullptr) {
+        return NO_ERROR;
+    }
+    DlManager dlManager = DlManager(std::string(std::getenv("ASCEND_HOME_PATH")) + "/lib64/libopapi.so");
+    Status ret = dlManager.getSymbol("aclnnMlaPreprocessGetWorkspaceSize",
+                                     (void **)&MlaPreprocessAclnnRunner::aclnnGetWorkspaceSizeFunc_);
+    if (ret != NO_ERROR) {
+        ATB_LOG(ERROR) << "load aclnnMlaPreprocessGetWorkspaceSize failed! Consider upgrading CANN first!";
+        return ret;
+    }
+    ret = dlManager.getSymbol("aclnnMlaPreprocess", (void **)&MlaPreprocessAclnnRunner::aclnnExecuteFunc_);
+    if (ret != NO_ERROR) {
+        ATB_LOG(ERROR) << "load aclnnMlaPreprocess failed! Consider upgrading CANN first!";
+        return ret;
+    }
+    ATB_LOG(INFO) << "load aclnnMlaPreprocess two-stage method success!";
+    return NO_ERROR;
+}
+} // namespace atb
diff --git a/src/ops_infer/mla_preprocess/mla_preprocess_aclnn_runner.h b/src/ops_infer/mla_preprocess/mla_preprocess_aclnn_runner.h
new file mode 100644
index 00000000..2bb87c15
--- /dev/null
+++ b/src/ops_infer/mla_preprocess/mla_preprocess_aclnn_runner.h
@@ -0,0 +1,46 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+#ifndef ATB_MLAPREPROCESS_ACLNN_RUNNER_H
+#define ATB_MLAPREPROCESS_ACLNN_RUNNER_H
+#include "atb/infer_op_params.h"
+#include "atb/runner/aclnn_runner.h"
+
+namespace atb {
+class MlaPreprocessAclnnRunner : public AclnnRunner {
+public:
+    explicit MlaPreprocessAclnnRunner(const infer::MlaPreprocessParam &param);
+    explicit MlaPreprocessAclnnRunner(const infer::MlaPreprocessParam &param, bool doRmsNorm);
+    ~MlaPreprocessAclnnRunner() override;
+
+    static Status LoadMethod();
+
+protected:
+    Status BuildAclnnVariantPack(const RunnerVariantPack &runnerVariantPack) override;
+    Status LaunchAclnnKernel() override;
+    aclnnStatus SetAclNNWorkspaceExecutor() override;
+
+private:
+    infer::MlaPreprocessParam param_;
+    bool doRmsNorm_ = true;
+
+    // two-phase interface matching aclnnop/aclnn_mla_preprocess.h
+    static aclnnStatus (*aclnnGetWorkspaceSizeFunc_)(
+        const aclTensor *, const aclTensor *, const aclTensor *, const aclTensor *, const aclTensor *,
+        const aclTensor *, const aclTensor *, const aclTensor *, const aclTensor *, const aclTensor *,
+        const aclTensor *, const aclTensor *, const aclTensor *, const aclTensor *, const aclTensor *,
+        const aclTensor *, const aclTensor *, const aclTensor *, const aclTensor *, const aclTensor *,
+        const aclTensor *, const aclTensor *, const aclTensor *, const aclTensor *, int64_t, int64_t, int64_t, double,
+        int64_t, int64_t, bool, bool, bool, int64_t, int64_t, bool, int64_t, const aclTensor *, const aclTensor *,
+        const aclTensor *, const aclTensor *, uint64_t *, aclOpExecutor **);
+    static aclnnStatus (*aclnnExecuteFunc_)(void *, uint64_t, aclOpExecutor *, aclrtStream);
+};
+} // namespace atb
+#endif
diff --git a/src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp b/src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp
index 126634b0..75ad479f 100644
--- a/src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp
+++ b/src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp
@@ -8,6 +8,7 @@
  * See LICENSE in the root of the software repository for the full text of the License.
*/ #include "mla_preprocess_operation.h" +#include "mla_preprocess_aclnn_runner.h" #include "mla_preprocess_ops_runner.h" #include "mla_preprocess_ops_runner_split.h" #include "atb/utils/log.h" @@ -23,6 +24,10 @@ namespace { static const uint32_t IN_TENSOR_NUM = 24; static const uint32_t OUT_TENSOR_NUM = 2; static const uint32_t OUT_TENSOR_NUM_SPLIT = 4; +static const uint32_t INPUT_INDEX = 0; +static const uint32_t GAMMA0_INDEX = 1; +static const uint32_t BETA0_INDEX = 2; +static const uint32_t WDQKV_INDEX = 5; static const uint32_t BIAS0_INDEX = 7; static const uint32_t BIAS1_INDEX = 14; static const uint32_t WUK_INDEX = 18; @@ -39,6 +44,8 @@ static const uint32_t INNER_DIM_192 = 192; static const uint32_t INNER_DIM_128 = 128; static const uint32_t INNER_DIM_64 = 64; static const int64_t MAX_TOKEN_NUM = 1024; +static const int64_t MAX_HIDDEN_SIZE = 8192; +static const int64_t MIN_HIDDEN_SIZE = 2048; } // namespace namespace atb { @@ -48,17 +55,19 @@ template <> Status CreateOperation(const infer::MlaPreprocessParam &opParam, Ope if (operation == nullptr) { return ERROR_INVALID_PARAM; } + ATB_LOG(INFO) << "CreateOperation with MlaPreprocessParam: " << OpParamToJson(opParam); if (!GetSingleton().Is910B()) { ATB_LOG(ERROR) << "only support Atlas 800I A2/A3 Inference Product"; return ERROR_INVALID_PARAM; } + if (opParam.cacheMode > infer::MlaPreprocessParam::CacheMode::NZCACHE) { ATB_LOG(ERROR) << "invalid cacheMode"; return ERROR_INVALID_PARAM; } if (opParam.quantMode == infer::MlaPreprocessParam::QuantMode::PER_TOKEN_QUANT_ASYMM || opParam.quantMode == infer::MlaPreprocessParam::QuantMode::UNQUANT) { - ATB_LOG(ERROR) << "invalid quantMode,dont support PER_TOKEN_ASYMM_QUANT and UNQUANT yet"; + ATB_LOG(ERROR) << "invalid quantMode, doesn't support PER_TOKEN_ASYMM_QUANT and UNQUANT"; return ERROR_INVALID_PARAM; } OP_PARAM_RSV_CHECK(opParam); @@ -221,7 +230,7 @@ Status MlaPreprocessOperation::OutTensorCheckSplit(const SVector &inTens Status MlaPreprocessOperation::SetupCheckImpl(const SVector &inTensors, const SVector &outTensors) const { - (void)outTensors; + ATB_LOG(INFO) << GetLogPrefix() << "SetupCheckImpl"; SVector inTensorDesc = {}; OperationUtil::InTensorsToInTensorDescs(inTensors, inTensorDesc); Status st = DimCheck(inTensorDesc); @@ -246,11 +255,13 @@ Status MlaPreprocessOperation::SetupCheckImpl(const SVector &inTensors, Status MlaPreprocessOperation::TensorShapeCheck(const atb::SVector &inTensorDesc, const std::vector> &inTensorShapes) const { + ATB_LOG(INFO) << GetLogPrefix() << "TensorShapeCheck start"; if (inTensorDesc.size() != inTensorShapes.size()) { ATB_LOG(ERROR) << "inTensor num is not equal to the expect num : " << inTensorShapes.size(); return ERROR_INVALID_IN_TENSOR_NUM; } for (size_t i = 0; i < inTensorDesc.size(); i++) { + ATB_LOG(INFO) << GetLogPrefix() << "TensorShapeCheck index [" << i << "]"; if (inTensorShapes.at(i).size() == 0 || inTensorDesc.at(i).shape.dimNum == 0) { continue; } @@ -267,6 +278,7 @@ Status MlaPreprocessOperation::TensorShapeCheck(const atb::SVector &inTensorDesc) const { + ATB_LOG(INFO) << GetLogPrefix() << "DimCheck start"; int64_t tokenNum = inTensorDesc.at(0).shape.dims[0]; int64_t headNum = inTensorDesc.at(WUK_INDEX).shape.dims[0]; if (tokenNum <= 0 || headNum <= 0) { @@ -295,42 +308,52 @@ Status MlaPreprocessOperation::DimCheck(const SVector &inTensorDesc) if (st != NO_ERROR) { return st; } + st = CheckAclnnKernel(inTensorDesc); + if (st != NO_ERROR) { + return st; + } + st = HiddenSizeCheck(inTensorDesc); + if (st != NO_ERROR) { + return 
+    }
     std::vector<std::vector<int64_t>> inTensorShapes = {
-        {tokenNum, INNER_DIM_7168},    // input
-        {INNER_DIM_7168},              // gamma0
-        {INNER_DIM_7168},              // beta0
-        {1},                           // quantScale0
-        {1},                           // quantOffset0
-        {},                            // wdqkv
-        {INNER_DIM_2112},              // deScale0
-        {},                            // bias0
-        {INNER_DIM_1536},              // gamma1
-        {INNER_DIM_1536},              // beta1
-        {1},                           // quantScale1
-        {1},                           // quantOffset1
-        {},                            // wuq
-        {headNum * INNER_DIM_192},     // deScale1
-        {},                            // bias1
-        {INNER_DIM_512},               // gamma2
-        {tokenNum, INNER_DIM_64},      // cos
-        {tokenNum, INNER_DIM_64},      // sin
-        {},                            // wuk
-        {},                            // ctkv
-        {},                            // kRope
-        {tokenNum},                    // slotmapping
-        {},                            // ctkvScale
-        {},                            // qNopeScale
+        {},                            // input
+        {},                            // gamma0
+        {},                            // beta0
+        {1},                           // quantScale0
+        {1},                           // quantOffset0
+        {},                            // wdqkv
+        {INNER_DIM_2112},              // deScale0
+        {},                            // bias0
+        {INNER_DIM_1536},              // gamma1
+        {INNER_DIM_1536},              // beta1
+        {1},                           // quantScale1
+        {1},                           // quantOffset1
+        {},                            // wuq
+        {headNum * INNER_DIM_192},     // deScale1
+        {},                            // bias1
+        {INNER_DIM_512},               // gamma2
+        {tokenNum, INNER_DIM_64},      // cos
+        {tokenNum, INNER_DIM_64},      // sin
+        {},                            // wuk
+        {},                            // ctkv
+        {},                            // kRope
+        {tokenNum},                    // slotmapping
+        {},                            // ctkvScale
+        {},                            // qNopeScale
     };
     ModifyInTensorShapeTable(inTensorDesc, inTensorShapes, headNum);
     st = TensorShapeCheck(inTensorDesc, inTensorShapes);
     if (st != NO_ERROR) {
         return st;
     }
+    ATB_LOG(INFO) << GetLogPrefix() << "DimCheck finished";
     return NO_ERROR;
 }
 
 Status MlaPreprocessOperation::BlockSizeCheck(const SVector<TensorDesc> &inTensorDesc) const
 {
+    ATB_LOG(INFO) << GetLogPrefix() << "BlockSizeCheck start";
     bool isNz = (param_.cacheMode == infer::MlaPreprocessParam::CacheMode::INT8_NZCACHE ||
                  param_.cacheMode == infer::MlaPreprocessParam::CacheMode::NZCACHE);
     int64_t blockSize = isNz ? inTensorDesc.at(KVCACHE_INDEX).shape.dims[2] : // 2: blocksizeIdx
@@ -343,12 +366,97 @@ Status MlaPreprocessOperation::BlockSizeCheck(const SVector<TensorDesc> &inTenso
         ATB_LOG(ERROR) << "blockSize should be 128 with INT8_NZCACHE or NZCACHE";
         return ERROR_INVALID_TENSOR_DIM;
     }
+    ATB_LOG(INFO) << GetLogPrefix() << "BlockSizeCheck finished";
+    return NO_ERROR;
+}
+
+Status MlaPreprocessOperation::HiddenSizeCheck(const SVector<TensorDesc> &inTensorDesc) const
+{
+    if (inTensorDesc.at(INPUT_INDEX).shape.dimNum != 2) { // 2: input [tokenNum, hiddenSize]
+        ATB_LOG(ERROR) << GetLogPrefix()
+                       << "dimNum of input should be 2, but got: " << inTensorDesc.at(INPUT_INDEX).shape.dimNum;
+        return ERROR_INVALID_TENSOR_DIM_NUM;
+    }
+    int64_t hiddenSize = inTensorDesc.at(INPUT_INDEX).shape.dims[1]; // 1: input hiddenSize index
+    if (useAclnnKernel_) {
+        if (hiddenSize < MIN_HIDDEN_SIZE || hiddenSize > MAX_HIDDEN_SIZE) {
+            ATB_LOG(ERROR) << GetLogPrefix() << "expect hiddenSize of input to be in range [" << MIN_HIDDEN_SIZE << ","
+                           << MAX_HIDDEN_SIZE << "], but got: " << hiddenSize;
+            return ERROR_INVALID_TENSOR_DIM;
+        }
+    } else if (hiddenSize != INNER_DIM_7168) {
+        ATB_LOG(ERROR) << GetLogPrefix() << "expect hiddenSize of input to be " << INNER_DIM_7168
+                       << ", but got: " << hiddenSize;
+        return ERROR_INVALID_TENSOR_DIM;
+    }
+    if (inTensorDesc.at(GAMMA0_INDEX).shape.dimNum != 1) { // 1: input [hiddenSize]
+        ATB_LOG(ERROR) << GetLogPrefix()
+                       << "dimNum of gamma0 should be 1, but got: " << inTensorDesc.at(GAMMA0_INDEX).shape.dimNum;
+        return ERROR_INVALID_TENSOR_DIM_NUM;
+    }
+    if (inTensorDesc.at(GAMMA0_INDEX).shape.dims[0] != hiddenSize) {
+        ATB_LOG(ERROR) << GetLogPrefix() << "hiddenSize of gamma0 should be the same as input's, " << hiddenSize
+                       << ", but got: " << inTensorDesc.at(GAMMA0_INDEX).shape.dims[0]; // 0: gamma0 hiddenSize index
+        return ERROR_INVALID_TENSOR_DIM;
+    }
+    if (inTensorDesc.at(BETA0_INDEX).shape.dimNum != 1) { // 1: input [hiddenSize]
+        ATB_LOG(ERROR) << GetLogPrefix()
+                       << "dimNum of beta0 should be 1, but got: " << inTensorDesc.at(BETA0_INDEX).shape.dimNum;
+        return ERROR_INVALID_TENSOR_DIM_NUM;
+    }
+    if (inTensorDesc.at(BETA0_INDEX).shape.dims[0] != hiddenSize) {
+        ATB_LOG(ERROR) << GetLogPrefix() << "hiddenSize of beta0 should be the same as input's, " << hiddenSize
+                       << ", but got: " << inTensorDesc.at(BETA0_INDEX).shape.dims[0]; // 0: beta0 hiddenSize index
+        return ERROR_INVALID_TENSOR_DIM;
+    }
+    return NO_ERROR;
+}
+
+Status MlaPreprocessOperation::CheckAclnnKernel(const SVector<TensorDesc> &inTensorDesc) const
+{
+    ATB_LOG(INFO) << GetLogPrefix() << "CheckAclnnKernel start";
+    // input's hiddenSize != 7168, generalize hiddenSize
+    bool generalizedHiddenSize = inTensorDesc.at(INPUT_INDEX).shape.dims[1] != INNER_DIM_7168; // 1: hiddenSize
+    // if wdqkv's dtype is bf16, then do not rmsNorm
+    doRmsNorm_ = inTensorDesc.at(WDQKV_INDEX).dtype != ACL_BF16;
+    if (!generalizedHiddenSize && doRmsNorm_) {
+        ATB_LOG(INFO) << GetLogPrefix()
+                      << "no need to use aclnn kernel: hiddenSize is not generalized and rmsNormQuant for input is on";
+        useAclnnKernel_ = false;
+        return NO_ERROR;
+    }
+    useAclnnKernel_ = true;
+    Status ret = MlaPreprocessAclnnRunner::LoadMethod();
+    ATB_LOG(INFO) << GetLogPrefix() << "MlaPreprocessAclnnRunner::LoadMethod() ret: " << ret;
+    if (ret != NO_ERROR) {
+        if (generalizedHiddenSize) {
+            ATB_LOG(INFO) << GetLogPrefix()
+                          << "Need to use aclnn kernel for generalized hiddenSize but load methods failed! Consider "
+                             "changing hiddenSize back to "
+                          << INNER_DIM_7168 << " to use the atb kernel";
+            return ERROR_INVALID_TENSOR_DIM_NUM;
+        }
+        if (!doRmsNorm_) {
+            ATB_LOG(INFO) << GetLogPrefix()
+                          << "Need to use aclnn kernel while skipping rmsNormQuant for input but load methods failed!";
+            return ret;
+        }
+    }
     return NO_ERROR;
 }
 
 std::shared_ptr<Runner> MlaPreprocessOperation::CreateRunner(Context &context) const
 {
     (void)context;
+    if (useAclnnKernel_) {
+        ATB_LOG(INFO) << GetLogPrefix() << "create MlaPreprocess AclnnRunner";
+        bool doRmsNormParam = doRmsNorm_;
+        useAclnnKernel_ = false;
+        doRmsNorm_ = true;
+        return std::make_shared<MlaPreprocessAclnnRunner>(param_, doRmsNormParam);
+    }
+
+    ATB_LOG(INFO) << GetLogPrefix() << "create MlaPreprocess OpsRunner";
     if (param_.cacheMode == infer::MlaPreprocessParam::CacheMode::KVCACHE) {
         return std::make_shared<MlaPreprocessOpsRunner>(param_);
     }
diff --git a/src/ops_infer/mla_preprocess/mla_preprocess_operation.h b/src/ops_infer/mla_preprocess/mla_preprocess_operation.h
index 7d4166fe..18743941 100644
--- a/src/ops_infer/mla_preprocess/mla_preprocess_operation.h
+++ b/src/ops_infer/mla_preprocess/mla_preprocess_operation.h
@@ -36,9 +36,13 @@ private:
     Status OutTensorCheck(const SVector<Tensor> &inTensors, const SVector<Tensor> &outTensors) const;
     Status OutTensorCheckSplit(const SVector<Tensor> &inTensors, const SVector<Tensor> &outTensors) const;
     Status BlockSizeCheck(const SVector<TensorDesc> &inTensorDesc) const;
-
+    // check the hiddenSize of input, gamma0 and beta0
+    Status HiddenSizeCheck(const SVector<TensorDesc> &inTensorDesc) const;
+    Status CheckAclnnKernel(const SVector<TensorDesc> &inTensorDesc) const;
 private:
     infer::MlaPreprocessParam param_;
+    mutable bool useAclnnKernel_ = false;
+    mutable bool doRmsNorm_ = true;
 };
 } // namespace atb
 #endif
\ No newline at end of file
diff --git a/src/torch_atb/CMakeLists.txt b/src/torch_atb/CMakeLists.txt
index f75429df..190e4123 100644
--- a/src/torch_atb/CMakeLists.txt
+++ b/src/torch_atb/CMakeLists.txt
@@ -8,6 +8,8 @@ # See LICENSE
in the root of the software repository for the full text of the License. # +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -w") + file(GLOB_RECURSE pybind11_source_files "*.cpp") pybind11_add_module(_C ${pybind11_source_files}) set_target_properties(_C PROPERTIES OUTPUT_NAME "_C" SUFFIX ".so") diff --git a/tests/apitest/opstest/python/operations/mla_preprocess/test_mla_preprocess_aclnn.py b/tests/apitest/opstest/python/operations/mla_preprocess/test_mla_preprocess_aclnn.py new file mode 100644 index 00000000..5f4d66e2 --- /dev/null +++ b/tests/apitest/opstest/python/operations/mla_preprocess/test_mla_preprocess_aclnn.py @@ -0,0 +1,899 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. +# This file is a part of the CANN Open Software. +# Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. + +import logging +import sys +import os +import unittest +import math +import numpy as np +import torch +import random +import json +import torch.nn.functional as F +import torch_npu +sys.path.append(os.path.join(os.path.dirname(__file__), "../")) +import operation_test +from precision_calcu import * + +OP_NAME = "MlaPreprocessOperation" +QUANTMAX = 127 +QUANTMIN = -128 +block_size = 128 + +random.seed(12) +np.random.seed(12) +torch.manual_seed(12) + +def process_deq_scale(deq_scale: torch.Tensor) -> np.ndarray: + ret = torch.frombuffer(deq_scale.numpy().tobytes(), dtype=torch.int32).to(torch.int64) + return ret + + +def round_up(val: int, align: int) -> int: + if align == 0: + return 0 + return -(val // -align) * align + + +def transdata(nd_mat, block_size: tuple = (16, 16)): + r = round_up(nd_mat.shape[0], block_size[0]) + c = round_up(nd_mat.shape[1], block_size[1]) + r_pad = r - nd_mat.shape[0] + c_pad = c - nd_mat.shape[1] + nd_mat = F.pad(nd_mat, ((0, r_pad, 0, c_pad))) + nz_mat = torch.permute( + torch.reshape(nd_mat, (r // block_size[0], block_size[0], c // block_size[1], block_size[1])), [2, 0, 1, 3] + ) + nz_mat = torch.reshape(nz_mat, (nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3])) + return nz_mat + +def transdata_3d(nd_mat, block_size: tuple = (16, 16)): + if nd_mat.ndim != 3: + raise ValueError("Expected a 3-dimensional input array.") + B, K, N = nd_mat.shape + processed_slices = [] + + for batch_index in range(B): + current_slice = nd_mat[batch_index] + nz_mat = transdata(current_slice, block_size) + processed_slices.append(nz_mat) + + result = torch.stack(processed_slices, axis=0) + return result + +class TestMLAPrepross(operation_test.OperationTest): + def __set_envs(self, env: dict): + if env: + for key, value in env.items(): + os.environ[key] = value + + def __unset_envs(self, env: dict): + if env: + for key, _ in env.items(): + os.environ[key] = "" + + def __get_npu_device(self): + npu_device = os.environ.get("MKI_NPU_DEVICE") + if npu_device is None: + npu_device = "npu:0" + else: + npu_device = f"npu:{npu_device}" + return npu_device + + def padHeadNum(self, headdim_sin_cos): + return torch.tile(headdim_sin_cos, (1, self.headNum)) + + def rotateHalfX(self, q_temp): + # 拆分成 head_num 个 [n,head_dim] 的二维向量 + # print("============== 
q_temp", q_temp) + # import pdb;pdb.set_trace() + q_splits = torch.chunk(q_temp, self.headNum, dim=1) + # 对每个 [n,head_dim] 向量的第二维进行分割,并对第二块乘以 -1再拼回到第一块前面 + # import pdb;pdb.set_trace() + processed_q_splits = [] + for q_split in q_splits: + # print("============== q_split", q_split) + # 分割第二维 + first_half, second_half = torch.chunk(q_split, 2, dim=1) + # 拼接回 [n,head_dim] 的二维向量 + # import pdb;pdb.set_trace() + processed_q_split = torch.cat((-second_half, first_half), dim=1) + processed_q_splits.append(processed_q_split) + + # 将所有处理后的 [n,head_dim] 向量拼回 [n,head_num*head_dim] 的二维向量 + return torch.cat(processed_q_splits, dim=1) + + def rotateHalf2(self, q_temp): + # 拆分成 head_num 个 [n,head_dim] 的二维向量 + # print("============== q_temp", q_temp) + q_splits = np.split(q_temp, self.headNum, axis=1) + # 对每个 [n,head_dim] 向量的第二维进行分割,并对第二块乘以 -1再拼回到第一块前面 + processed_q_splits = [] + for q_split in q_splits: + # print("============== q_split", q_split) + # 分割第二维 + first_half, second_half = np.split(q_split, 2, axis=1) + # 拼接回 [n,head_dim] 的二维向量 + processed_q_split = np.concatenate((second_half, first_half), axis=1) + processed_q_splits.append(processed_q_split) + + # 将所有处理后的 [n,head_dim] 向量拼回 [n,head_num*head_dim] 的二维向量 + return np.concatenate(processed_q_splits, axis=1) + + def RopeConcatGolden(self, q, sin, cos, concatInput): + pad_sin = self.padHeadNum(sin) + pad_cos = self.padHeadNum(cos) + # import pdb;pdb.set_trace() + # self.qcosOut = self.rotateHalf(q) + rope_res = q * pad_cos + self.rotateHalfX(q) * pad_sin + rope_res = rope_res.reshape(self.input_token_num, self.headNum, self.rope_hidden_size) + rope_res = rope_res.to(self.dtype) + + return torch.cat((concatInput.to(self.dtype), rope_res), dim=2) + + def rmsNormGolden(self, x, gamma): + x_float32 = x.to(torch.float32) + square_sum = torch.sum(torch.square(x_float32), axis=-1, keepdims=True) + rms = 1.0 / torch.sqrt(square_sum /self.rms_hidden_size + self.epsilon) + gamma_float32 = gamma.to(torch.float32) + rmsNorm = rms * x_float32 * gamma_float32 + result = rmsNorm.to(self.dtype) + np.set_printoptions(suppress=True, formatter={'float_kind':'{:.15f}'.format}) + return result + + def rotateHalf(self, k_temp): + first_half, second_half = torch.chunk(k_temp, 2, dim=1) + processed_k_split = torch.cat((-second_half, first_half), dim=1) + return processed_k_split + + def RopeGolden(self, keyRope, sin, cos): + RopeGolden = keyRope * cos + self.rotateHalf(keyRope) * sin + return RopeGolden + + def RACGolden(self, keyRAC, slotMapping, keycacheout_golden): + for i, slot in enumerate(slotMapping): + if slot < 0: + continue + block_index = slot // block_size + block_offset = slot % block_size + token_key = keyRAC[i] + + keycacheout_golden[block_index][block_offset] = token_key + return keycacheout_golden + + def RmsNormAndRopeAndReshapeAndCacheGolden(self, x, gamma, keyRope, cos, sin, slotMapping, keycachein): + rmsNormOutput = self.rmsNormGolden(x, gamma) + ropeOutput = self.RopeGolden(keyRope, sin, cos) + ropeReshape = ropeOutput.reshape(self.input_token_num, 1, self.rope_hidden_size) + + keyRAC = torch.cat((rmsNormOutput, ropeReshape), axis=-1) + return self.RACGolden(keyRAC, slotMapping, keycachein) + + def rms_norm_quant_calc(self, input, gamma, beta, quantScale, quantOffset): + out_shape = input.shape + scale = 1.0 / quantScale.item() + offset = quantOffset.item() + inputScale = torch.tensor(scale, dtype=torch.float) + inputOffset = torch.tensor(offset, dtype=torch.float) + input0 = torch.tensor(input).float() + input1 = torch.tensor(gamma).float() + 
square_sum = torch.sum(torch.square(input0), axis=-1, keepdims=True) + factor = 1.0 / torch.sqrt(square_sum / out_shape[-1] + self.epsilon) + output = input0 * factor * input1 + output = (output + beta.float()) * inputScale + inputOffset + self.xxTest = output.clone() + output = torch.round(output) + output = torch.tensor(output).to(torch.float16) + output = torch.min(output, torch.tensor(QUANTMAX, dtype=torch.half)) + output = torch.max(output, torch.tensor(QUANTMIN, dtype=torch.half)) + return torch.tensor(output).to(torch.int8) + + def calc_vec_mm_atb_data(self, N, headNum, data_type, hidden_size): + hiddenStrate = hidden_size + blockNum = 192 + blockSize = 128 + headdim = 576 + self.input_token_num = N + self.rms_hidden_size = 512 + self.rope_hidden_size = 64 + self.headNum = headNum + self.epsilon = 1e-6 + self.dtype = data_type + + self.input1 = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(N, hiddenStrate))).to(data_type)# + self.gamma1 = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(hiddenStrate))).to(data_type) + self.quantScale1 = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(1))).to(data_type) + self.quantOffset1 = torch.from_numpy(np.random.uniform(-128.0, 127.0, size=(1))).to(torch.int8) + self.wdqkv = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(2112, hiddenStrate))).to(torch.int8)# + + self.deScale1 = torch.rand((2112), dtype=torch.float32) / 1000 + self.gamma2 = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(1536))).to(data_type) + self.quantScale2 = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(1))).to(data_type) + self.quantOffset2 = torch.from_numpy(np.random.uniform(-128.0, 127.0, size=(1))).to(torch.int8) + + self.wuq = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(headNum * 192, 1536))).to(torch.int8)# + + self.deScale2 = torch.rand((headNum * 192), dtype=torch.float32) / 1000 + + self.gamma3 = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(512))).to(data_type) + self.sin1 = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(N, 64))).to(data_type) + self.cos1 = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(N, 64))).to(data_type) + self.keyCache = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(blockNum, blockSize, 1, headdim))).to(data_type) + + self.slotMapping = torch.from_numpy(np.random.choice(192 * 128, N, replace=False).astype(np.int32)).to(torch.int32) + + self.wuk = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(headNum, 128, 512))).to(data_type)# + self.sin2 = self.sin1 + self.cos2 = self.cos1 + + self.bias1 = torch.from_numpy(np.random.randint(-10, 10, (1, 2112)).astype(np.int32)).to(torch.int32) + self.bias2 = torch.from_numpy(np.random.randint(-10, 10, (1, headNum * 192)).astype(np.int32)).to(torch.int32) + + self.beta1 = torch.from_numpy(np.random.randint(-2, 2, (hiddenStrate)).astype(np.float16)).to(data_type) + self.beta2 = torch.from_numpy(np.random.randint(-2, 2, (1536)).astype(np.float16)).to(data_type) + + self.calc_vec_mm_data(N, headNum, data_type) + + ## RmsNorm + print("====================RmsNorm0====================") + self.rms1Out1_npu = torch.zeros((N, hiddenStrate), dtype=torch.int8) + npu_device = self.__get_npu_device() + torch_npu.npu.set_device(npu_device) + in_tensors = [self.input1, self.gamma1, self.beta1, self.quantScale1, self.quantOffset1] + out_tensors = [self.rms1Out1_npu] + in_tensors_npu = [tensor.npu() for tensor in in_tensors] + out_tensors_npu = [out_tensors[i] if isinstance(i, int) else i.npu() + for i in out_tensors] + op_name = "RmsNormOperation" + op_param 
=json.dumps({"layerType":1, "normParam":{"quantType": 2, "epsilon": 1e-6}}) + self.operation = torch.classes.OperationTorch.OperationTorch(op_name) + self.op_param = op_param + self.operation.set_param(op_param) + self.operation.execute_out(in_tensors_npu, out_tensors_npu) + self.rms1Out1_npu = out_tensors_npu[0].cpu().clone() + + ##Ppmatmul + print("====================Ppmatmul0====================") + self.mm1Out1_npu = torch.zeros((N, 2112), dtype=data_type) + in_tensors = [out_tensors_npu[0], self.wdqkv, self.bias1, self.deScale1] + out_tensors = [self.mm1Out1_npu] + in_tensors_npu = [tensor.npu() for tensor in in_tensors] + out_tensors_npu = [out_tensors[i] if isinstance(i, int) else i.npu() + for i in out_tensors] + op_name = "LinearOperation" + op_param =json.dumps({"transposeA": False, "transposeB": True, "hasBias": True, "outDataType": 1}) + self.operation = torch.classes.OperationTorch.OperationTorch(op_name) + self.op_param = op_param + self.operation.set_param(op_param) + self.operation.execute_out(in_tensors_npu, out_tensors_npu) + self.mm1Out1_npu = out_tensors_npu[0].cpu().clone() + print(self.mm1Out1_npu.size()) + + ##SplitV + print("====================SplitV0====================") + splitSize = [512, 64, 1536] + splitVDim = 1 + in_tensors_npu = [out_tensors_npu[0]] + Split1_out_tensors_npu = [] + shape = in_tensors_npu[0].shape + for size in splitSize: + slice_shape = list(shape) + slice_shape[splitVDim] = size + Split1_out_tensors_npu.append(torch.zeros(slice_shape, dtype=data_type).npu()) + op_name = "SplitOperation" + op_param =json.dumps({"splitNum": 3, "splitDim": splitVDim, "splitSizes": splitSize}) + self.operation = torch.classes.OperationTorch.OperationTorch(op_name) + self.op_param = op_param + self.operation.set_param(op_param) + self.operation.execute_out(in_tensors_npu, Split1_out_tensors_npu) + print(Split1_out_tensors_npu[0].size()) + ## RmsNorm + print("====================RmsNorm1====================") + self.rms2Out_npu = torch.zeros((N, 1536), dtype=torch.int8) + in_tensors = [Split1_out_tensors_npu[2], self.gamma2, self.beta2, self.quantScale2, self.quantOffset2] + out_tensors = [self.rms2Out_npu] + in_tensors_npu = [tensor.npu() for tensor in in_tensors] + out_tensors_npu = [out_tensors[i] if isinstance(i, int) else i.npu() + for i in out_tensors] + op_name = "RmsNormOperation" + op_param =json.dumps({"layerType":1, "normParam":{"quantType": 2, "epsilon": 1e-6}}) + self.operation = torch.classes.OperationTorch.OperationTorch(op_name) + self.op_param = op_param + self.operation.set_param(op_param) + self.operation.execute_out(in_tensors_npu, out_tensors_npu) + self.rms2Out_npu = out_tensors_npu[0].cpu().clone() + print(self.rms2Out_npu.size()) + + # ##Ppmatmul + print("====================Ppmatmul1====================") + self.mm2Out_npu = torch.zeros((N, headNum * 192), dtype=data_type) + in_tensors = [out_tensors_npu[0], self.wuq, self.bias2, self.deScale2] + out_tensors = [self.mm2Out_npu] + in_tensors_npu = [tensor.npu() for tensor in in_tensors] + mm2out_tensors_npu = [out_tensors[i] if isinstance(i, int) else i.npu() + for i in out_tensors] + op_name = "LinearOperation" + op_param =json.dumps({"transposeA": False, "transposeB": True, "hasBias": True, "outDataType": 1}) + self.operation = torch.classes.OperationTorch.OperationTorch(op_name) + self.op_param = op_param + self.operation.set_param(op_param) + self.operation.execute_out(in_tensors_npu, mm2out_tensors_npu) + self.mm2Out_npu = mm2out_tensors_npu[0].cpu().clone() + 
print(self.mm2Out_npu.size()) + + # # ##SplitV + print("====================SplitV1====================") + splitSize = [128, 64] + splitVDim = 2 + in_tensors = mm2out_tensors_npu[0].reshape(N, headNum, 192) + in_tensors_npu = in_tensors.npu() + Split2_out_tensors_npu = [] + shape = in_tensors_npu.shape + in_tensors_npu = [in_tensors_npu] + for size in splitSize: + slice_shape = list(shape) + slice_shape[splitVDim] = size + Split2_out_tensors_npu.append(torch.zeros(slice_shape, dtype=data_type).npu()) + op_name = "SplitOperation" + op_param =json.dumps({"splitNum": 2, "splitDim": splitVDim, "splitSizes": splitSize}) + self.operation = torch.classes.OperationTorch.OperationTorch(op_name) + self.op_param = op_param + self.operation.set_param(op_param) + self.operation.execute_out(in_tensors_npu, Split2_out_tensors_npu) + print(Split2_out_tensors_npu[0].size()) + + #EinSum + print("====================EinSum====================") + # EinSumInput = torch.transpose(Split2_out_tensors_npu[0], 0, 1) + self.trans_A, self.trans_B = False, False + bsize, msize, ksize, nsize = headNum, N, 128, 512 + in_tensors = [Split2_out_tensors_npu[0], self.wuk] + Einsumout_tensors = [torch.zeros((N, headNum, 512), dtype=data_type)] + in_tensors_npu = [tensor.npu() for tensor in in_tensors] + Einsumout_tensors_npu = [tensor.npu() for tensor in Einsumout_tensors] + op_name = "LinearOperation" + op_param =json.dumps({"matmulType": 1, "transposeA": False, "transposeB": False, "hasBias": False, "outDataType": -1}) + self.operation = torch.classes.OperationTorch.OperationTorch(op_name) + self.op_param = op_param + self.operation.set_param(op_param) + self.operation.execute_out(in_tensors_npu, Einsumout_tensors_npu) + + # Einsumout_tensors_npu = [torch.transpose(Einsumout_tensors_npu[0], 0, 1)] + self.einsumOut = Einsumout_tensors_npu[0].cpu().clone() + print(self.einsumOut.size()) + # RopeQConcat + print("====================RopeQConcat====================") + self.qOut_npu = torch.zeros((N, headNum, 576), dtype=data_type) + Split2_out_tensors_npu[1] = Split2_out_tensors_npu[1].cpu().reshape(N, headNum * 64) + in_tensors = [Split2_out_tensors_npu[1], self.cos2, self.sin2, Einsumout_tensors_npu[0]] + qOut_tensors = [self.qOut_npu] + in_tensors_npu = [tensor.npu() for tensor in in_tensors] + qout_tensors_npu = [qOut_tensors[i] if isinstance(i, int) else i.npu() + for i in qOut_tensors] + op_name = "RopeQConcatOperation" + op_param =json.dumps({}) + self.operation = torch.classes.OperationTorch.OperationTorch(op_name) + self.op_param = op_param + self.operation.set_param(op_param) + self.operation.execute_out(in_tensors_npu, qout_tensors_npu) + qout_tensors_npu[0] = torch.cat([qout_tensors_npu[0].cpu()[:,:,64:],qout_tensors_npu[0].cpu()[:,:,:64]], dim = 2) + self.qOut_npu = qout_tensors_npu[0].cpu().clone() + print(self.qOut_npu.size()) + # #Rope + print("====================Rope====================") + rotaryCoeff = 2 + self.RopeOut0 = torch.zeros((N, 64), dtype=data_type) + self.RopeOut1 = torch.zeros((N, 64), dtype=data_type) + seqlen = torch.randint(1, 2, (N,), dtype=torch.int32) + in_tensors = [Split1_out_tensors_npu[1].reshape(N, 1 * 64), Split1_out_tensors_npu[1].reshape(N, 1 * 64), self.cos1, self.sin1, seqlen] + in_tensors_npu = [tensor.npu() for tensor in in_tensors] + Ropeout_tensors_npu = [self.RopeOut0.npu(), self.RopeOut1.npu()] + + op_name = "RopeOperation" + op_param =json.dumps({"rotaryCoeff": rotaryCoeff}) + self.operation = torch.classes.OperationTorch.OperationTorch(op_name) + self.op_param = op_param 
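+        # The host-side golden reference above (RopeGolden/rotateHalf) implements rotate-half RoPE:
+        #     out = x * cos + rotate_half(x) * sin, with rotate_half(x) = concat(-x[:, d/2:], x[:, :d/2]).
+        # rotaryCoeff=2 is assumed here to select the matching half-rotation mode of RopeOperation,
+        # so the NPU output can be compared against that reference.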
+ self.operation.set_param(op_param) + self.operation.execute_out(in_tensors_npu, Ropeout_tensors_npu) + print(Ropeout_tensors_npu[0].size()) + + #RmsNorm + print("====================RmsNorm2====================") + in_tensors = [Split1_out_tensors_npu[0].reshape(N, 1, 512), self.gamma3] + in_tensors_npu = [tensor.npu() for tensor in in_tensors] + self.RmsNormOut = torch.zeros((N, 1, 512), dtype=data_type).npu() + RmsNormOut_tensors_npu = [self.RmsNormOut] + + op_name = "RmsNormOperation" + op_param =json.dumps({"layerType":1, "normParam":{"epsilon": 1e-6}}) + self.operation = torch.classes.OperationTorch.OperationTorch(op_name) + self.op_param = op_param + self.operation.set_param(op_param) + self.operation.execute_out(in_tensors_npu, RmsNormOut_tensors_npu) + + out_tensors_npu = RmsNormOut_tensors_npu + #Concat + print("====================Concat====================") + in_tensors = [RmsNormOut_tensors_npu[0], Ropeout_tensors_npu[0].cpu().reshape(N, 1, 64)] + ConCat2out_tensors = [torch.zeros((N, 1, 576), dtype=data_type)] + in_tensors_npu = [tensor.npu() for tensor in in_tensors] + ConCat2out_tensors_npu = [tensor.npu() for tensor in ConCat2out_tensors] + + op_name = "ConcatOperation" + op_param =json.dumps({"concatDim": 2}) + self.operation = torch.classes.OperationTorch.OperationTorch(op_name) + self.op_param = op_param + self.operation.set_param(op_param) + self.operation.execute_out(in_tensors_npu, ConCat2out_tensors_npu) + + #Reshape&Cache + print("====================Reshape&Cache====================") + in_tensors = [ConCat2out_tensors_npu[0], self.keyOutTensor, self.slotMapping] + out_tensors = [1] + in_tensors_npu = [tensor.npu() for tensor in in_tensors] + out_tensors_npu = [in_tensors_npu[i] if isinstance(i, int) else i.npu() + for i in out_tensors] + + op_name = "ReshapeAndCacheOperation" + op_param =json.dumps({"kvCacheCfg": 1}) + self.operation = torch.classes.OperationTorch.OperationTorch(op_name) + self.op_param = op_param + self.operation.set_param(op_param) + self.operation.execute_out(in_tensors_npu, out_tensors_npu) + self.keyout_npu = out_tensors_npu[0].cpu().clone() + + def calc_vec_mm_atb_data_bf16(self, N, headNum, data_type, hidden_size): + hiddenStrate = hidden_size + blockNum = 192 + blockSize = 128 + headdim = 576 + self.input_token_num = N + self.rms_hidden_size = 512 + self.rope_hidden_size = 64 + self.headNum = headNum + self.epsilon = 1e-6 + self.dtype = data_type + + self.input1 = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(N, hiddenStrate))).to(data_type)# + self.gamma1 = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(hiddenStrate))).to(data_type) + self.quantScale1 = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(1))).to(data_type) + self.quantOffset1 = torch.from_numpy(np.random.uniform(-128.0, 127.0, size=(1))).to(torch.int8) + self.wdqkv = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(2112, hiddenStrate))).to(torch.int8)# + + self.deScale1 = torch.rand((2112), dtype=torch.float32) / 1000 + self.gamma2 = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(1536))).to(data_type) + self.quantScale2 = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(1))).to(data_type) + self.quantOffset2 = torch.from_numpy(np.random.uniform(-128.0, 127.0, size=(1))).to(torch.int8) + + self.wuq = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(headNum * 192, 1536))).to(torch.int8)# + + self.deScale2 = torch.rand((headNum * 192), dtype=torch.float32) / 1000 + + self.gamma3 = torch.from_numpy(np.random.uniform(-1.0, 1.0, 
size=(512))).to(data_type) + self.sin1 = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(N, 64))).to(data_type) + self.cos1 = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(N, 64))).to(data_type) + self.keyCache = torch.from_numpy(np.random.uniform(-1.0, 1.0, size=(blockNum, blockSize, 1, headdim))).to(data_type) + + self.slotMapping = torch.from_numpy(np.random.choice(192 * 128, N, replace=False).astype(np.int32)).to(torch.int32) + + self.wuk = torch.from_numpy(np.random.uniform(-2.0, 2.0, size=(headNum, 128, 512))).to(data_type)# + self.sin2 = self.sin1 + self.cos2 = self.cos1 + + self.bias1 = torch.from_numpy(np.random.randint(-10, 10, (1, 2112)).astype(np.int32)).to(torch.int32) + self.bias2 = torch.from_numpy(np.random.randint(-10, 10, (1, headNum * 192)).astype(np.int32)).to(torch.int32) + + self.beta1 = torch.from_numpy(np.random.randint(-2, 2, (hiddenStrate)).astype(np.float16)).to(data_type) + self.beta2 = torch.from_numpy(np.random.randint(-2, 2, (1536)).astype(np.float16)).to(data_type) + + self.calc_vec_mm_data(N, headNum, data_type) + + ## RmsNorm + print("====================RmsNorm0====================") + self.rms1Out1_npu = torch.zeros((N, hiddenStrate), dtype=torch.int8) + npu_device = self.__get_npu_device() + torch_npu.npu.set_device(npu_device) + in_tensors = [self.input1, self.gamma1, self.beta1, self.quantScale1, self.quantOffset1] + out_tensors = [self.rms1Out1_npu] + in_tensors_npu = [tensor.npu() for tensor in in_tensors] + out_tensors_npu = [out_tensors[i] if isinstance(i, int) else i.npu() + for i in out_tensors] + op_name = "RmsNormOperation" + op_param =json.dumps({"layerType":1, "normParam":{"quantType": 2, "epsilon": 1e-6}}) + self.operation = torch.classes.OperationTorch.OperationTorch(op_name) + self.op_param = op_param + self.operation.set_param(op_param) + self.operation.execute_out(in_tensors_npu, out_tensors_npu) + self.rms1Out1_npu = out_tensors_npu[0].cpu().clone() + + ##Ppmatmul + print("====================Ppmatmul0====================") + self.mm1Out1_npu = torch.zeros((N, 2112), dtype=data_type) + in_tensors = [out_tensors_npu[0], self.wdqkv, self.bias1, self.deScale1] + out_tensors = [self.mm1Out1_npu] + in_tensors_npu = [tensor.npu() for tensor in in_tensors] + out_tensors_npu = [out_tensors[i] if isinstance(i, int) else i.npu() + for i in out_tensors] + op_name = "LinearOperation" + op_param =json.dumps({"transposeA": False, "transposeB": True, "hasBias": True, "outDataType": 27}) + self.operation = torch.classes.OperationTorch.OperationTorch(op_name) + self.op_param = op_param + self.operation.set_param(op_param) + self.operation.execute_out(in_tensors_npu, out_tensors_npu) + self.mm1Out1_npu = out_tensors_npu[0].cpu().clone() + print(self.mm1Out1_npu.size()) + + ##SplitV + print("====================SplitV0====================") + splitSize = [512, 64, 1536] + splitVDim = 1 + in_tensors_npu = [out_tensors_npu[0]] + Split1_out_tensors_npu = [] + shape = in_tensors_npu[0].shape + for size in splitSize: + slice_shape = list(shape) + slice_shape[splitVDim] = size + Split1_out_tensors_npu.append(torch.zeros(slice_shape, dtype=data_type).npu()) + op_name = "SplitOperation" + op_param =json.dumps({"splitNum": 3, "splitDim": splitVDim, "splitSizes": splitSize}) + self.operation = torch.classes.OperationTorch.OperationTorch(op_name) + self.op_param = op_param + self.operation.set_param(op_param) + self.operation.execute_out(in_tensors_npu, Split1_out_tensors_npu) + print(Split1_out_tensors_npu[0].size()) + ## RmsNorm + 
print("====================RmsNorm1====================") + self.rms2Out_npu = torch.zeros((N, 1536), dtype=torch.int8) + in_tensors = [Split1_out_tensors_npu[2], self.gamma2, self.beta2, self.quantScale2, self.quantOffset2] + out_tensors = [self.rms2Out_npu] + in_tensors_npu = [tensor.npu() for tensor in in_tensors] + out_tensors_npu = [out_tensors[i] if isinstance(i, int) else i.npu() + for i in out_tensors] + op_name = "RmsNormOperation" + op_param =json.dumps({"layerType":1, "normParam":{"quantType": 2, "epsilon": 1e-6}}) + self.operation = torch.classes.OperationTorch.OperationTorch(op_name) + self.op_param = op_param + self.operation.set_param(op_param) + self.operation.execute_out(in_tensors_npu, out_tensors_npu) + self.rms2Out_npu = out_tensors_npu[0].cpu().clone() + print(self.rms2Out_npu.size()) + + # ##Ppmatmul + print("====================Ppmatmul1====================") + self.mm2Out_npu = torch.zeros((N, headNum * 192), dtype=data_type) + in_tensors = [out_tensors_npu[0], self.wuq, self.bias2, self.deScale2] + out_tensors = [self.mm2Out_npu] + in_tensors_npu = [tensor.npu() for tensor in in_tensors] + mm2out_tensors_npu = [out_tensors[i] if isinstance(i, int) else i.npu() + for i in out_tensors] + op_name = "LinearOperation" + op_param =json.dumps({"transposeA": False, "transposeB": True, "hasBias": True, "outDataType": 27}) + self.operation = torch.classes.OperationTorch.OperationTorch(op_name) + self.op_param = op_param + self.operation.set_param(op_param) + self.operation.execute_out(in_tensors_npu, mm2out_tensors_npu) + self.mm2Out_npu = mm2out_tensors_npu[0].cpu().clone() + print(self.mm2Out_npu.size()) + + # # ##SplitV + print("====================SplitV1====================") + splitSize = [128, 64] + splitVDim = 2 + in_tensors = mm2out_tensors_npu[0].reshape(N, headNum, 192) + in_tensors_npu = in_tensors.npu() + Split2_out_tensors_npu = [] + shape = in_tensors_npu.shape + in_tensors_npu = [in_tensors_npu] + for size in splitSize: + slice_shape = list(shape) + slice_shape[splitVDim] = size + Split2_out_tensors_npu.append(torch.zeros(slice_shape, dtype=data_type).npu()) + op_name = "SplitOperation" + op_param =json.dumps({"splitNum": 2, "splitDim": splitVDim, "splitSizes": splitSize}) + self.operation = torch.classes.OperationTorch.OperationTorch(op_name) + self.op_param = op_param + self.operation.set_param(op_param) + self.operation.execute_out(in_tensors_npu, Split2_out_tensors_npu) + print(Split2_out_tensors_npu[0].size()) + + #EinSum + print("====================EinSum====================") + # EinSumInput = torch.transpose(Split2_out_tensors_npu[0], 0, 1) + self.trans_A, self.trans_B = False, False + bsize, msize, ksize, nsize = headNum, N, 128, 512 + in_tensors = [Split2_out_tensors_npu[0], self.wuk] + Einsumout_tensors = [torch.zeros((N, headNum, 512), dtype=data_type)] + in_tensors_npu = [tensor.npu() for tensor in in_tensors] + Einsumout_tensors_npu = [tensor.npu() for tensor in Einsumout_tensors] + op_name = "LinearOperation" + op_param =json.dumps({"matmulType": 1, "transposeA": False, "transposeB": False, "hasBias": False, "outDataType": -1}) + self.operation = torch.classes.OperationTorch.OperationTorch(op_name) + self.op_param = op_param + self.operation.set_param(op_param) + self.operation.execute_out(in_tensors_npu, Einsumout_tensors_npu) + + # Einsumout_tensors_npu = [torch.transpose(Einsumout_tensors_npu[0], 0, 1)] + self.einsumOut = Einsumout_tensors_npu[0].cpu().clone() + print(self.einsumOut.size()) + # RopeQConcat + 
print("====================RopeQConcat====================") + self.qOut_npu = torch.zeros((N, headNum, 576), dtype=data_type) + Split2_out_tensors_npu[1] = Split2_out_tensors_npu[1].cpu().reshape(N, headNum * 64) + in_tensors = [Split2_out_tensors_npu[1], self.cos2, self.sin2, Einsumout_tensors_npu[0]] + qOut_tensors = [self.qOut_npu] + in_tensors_npu = [tensor.npu() for tensor in in_tensors] + qout_tensors_npu = [qOut_tensors[i] if isinstance(i, int) else i.npu() + for i in qOut_tensors] + op_name = "RopeQConcatOperation" + op_param =json.dumps({}) + self.operation = torch.classes.OperationTorch.OperationTorch(op_name) + self.op_param = op_param + self.operation.set_param(op_param) + self.operation.execute_out(in_tensors_npu, qout_tensors_npu) + qout_tensors_npu[0] = torch.cat([qout_tensors_npu[0].cpu()[:,:,64:],qout_tensors_npu[0].cpu()[:,:,:64]], dim = 2) + self.qOut_npu = qout_tensors_npu[0].cpu().clone() + print(self.qOut_npu.size()) + # #Rope + print("====================Rope====================") + rotaryCoeff = 2 + self.RopeOut0 = torch.zeros((N, 64), dtype=data_type) + self.RopeOut1 = torch.zeros((N, 64), dtype=data_type) + seqlen = torch.randint(1, 2, (N,), dtype=torch.int32) + in_tensors = [Split1_out_tensors_npu[1].reshape(N, 1 * 64), Split1_out_tensors_npu[1].reshape(N, 1 * 64), self.cos1, self.sin1, seqlen] + in_tensors_npu = [tensor.npu() for tensor in in_tensors] + Ropeout_tensors_npu = [self.RopeOut0.npu(), self.RopeOut1.npu()] + + op_name = "RopeOperation" + op_param =json.dumps({"rotaryCoeff": rotaryCoeff}) + self.operation = torch.classes.OperationTorch.OperationTorch(op_name) + self.op_param = op_param + self.operation.set_param(op_param) + self.operation.execute_out(in_tensors_npu, Ropeout_tensors_npu) + print(Ropeout_tensors_npu[0].size()) + + #RmsNorm + print("====================RmsNorm2====================") + in_tensors = [Split1_out_tensors_npu[0].reshape(N, 1, 512), self.gamma3] + in_tensors_npu = [tensor.npu() for tensor in in_tensors] + self.RmsNormOut = torch.zeros((N, 1, 512), dtype=data_type).npu() + RmsNormOut_tensors_npu = [self.RmsNormOut] + + op_name = "RmsNormOperation" + op_param =json.dumps({"layerType":1, "normParam":{"epsilon": 1e-6}}) + self.operation = torch.classes.OperationTorch.OperationTorch(op_name) + self.op_param = op_param + self.operation.set_param(op_param) + self.operation.execute_out(in_tensors_npu, RmsNormOut_tensors_npu) + + out_tensors_npu = RmsNormOut_tensors_npu + #Concat + print("====================Concat====================") + in_tensors = [RmsNormOut_tensors_npu[0], Ropeout_tensors_npu[0].cpu().reshape(N, 1, 64)] + ConCat2out_tensors = [torch.zeros((N, 1, 576), dtype=data_type)] + in_tensors_npu = [tensor.npu() for tensor in in_tensors] + ConCat2out_tensors_npu = [tensor.npu() for tensor in ConCat2out_tensors] + + op_name = "ConcatOperation" + op_param =json.dumps({"concatDim": 2}) + self.operation = torch.classes.OperationTorch.OperationTorch(op_name) + self.op_param = op_param + self.operation.set_param(op_param) + self.operation.execute_out(in_tensors_npu, ConCat2out_tensors_npu) + + #Reshape&Cache + print("====================Reshape&Cache====================") + in_tensors = [ConCat2out_tensors_npu[0], self.keyOutTensor, self.slotMapping] + out_tensors = [1] + in_tensors_npu = [tensor.npu() for tensor in in_tensors] + out_tensors_npu = [in_tensors_npu[i] if isinstance(i, int) else i.npu() + for i in out_tensors] + + op_name = "ReshapeAndCacheOperation" + op_param =json.dumps({"kvCacheCfg": 1}) + self.operation = 
torch.classes.OperationTorch.OperationTorch(op_name) + self.op_param = op_param + self.operation.set_param(op_param) + self.operation.execute_out(in_tensors_npu, out_tensors_npu) + self.keyout_npu = out_tensors_npu[0].cpu().clone() + + def calc_vec_mm_data(self, N, headNum, data_type): + mm1In = self.rms_norm_quant_calc(self.input1, self.gamma1, self.beta1, self.quantScale1, self.quantOffset1) + + self.rmsquantOut1 = mm1In.clone() + mm1Out = torch.matmul(mm1In.to(torch.float32), self.wdqkv.transpose(0,1).to(torch.float32)) + mm1Out = mm1Out.to(torch.int32) + self.bias1 + mm1Out = (mm1Out.to(torch.float32) * self.deScale1).to(data_type) + self.mm1Out1 = mm1Out + if data_type == torch.float16: + self.deScale1 = process_deq_scale(deq_scale=self.deScale1) + mm1OutSplit1, mm1OutSplit2 = torch.split(mm1Out, [576, 1536], dim=1) + rms2Out = self.rms_norm_quant_calc(mm1OutSplit2, self.gamma2, self.beta2, self.quantScale2, self.quantOffset2) + self.rmsquantOut2 = rms2Out.clone() + mm2Out = torch.matmul(rms2Out.to(torch.float32), self.wuq.transpose(0, 1).to(torch.float32)) + mm2Out = mm2Out.to(torch.int32) + self.bias2 + mm2Out = (mm2Out.to(torch.float32) * self.deScale2).to(data_type) + + self.mm2Out2 = mm2Out + if data_type == torch.float16: + self.deScale2 = process_deq_scale(deq_scale=self.deScale2) + + mm11OutSplit1, mm12OutSplit1 = torch.split(mm1OutSplit1, [512, 64], dim=1) + mm11OutSplit1 = mm11OutSplit1.reshape(N, 1, 512) + self.keyOutTensor = self.keyCache.clone() + self.keyOut1 = self.RmsNormAndRopeAndReshapeAndCacheGolden(mm11OutSplit1, self.gamma3, mm12OutSplit1, self.cos1, self.sin1, self.slotMapping, self.keyCache) + mm2Out = mm2Out.reshape(N, headNum, 192) + mm2OutSplit1, mm2OutSplit2 = torch.split(mm2Out, [128, 64], dim=2) + self.bmmOut= torch.permute( + torch.matmul(torch.permute(mm2Out[:, :, :128], (1, 0, 2)).float(), self.wuk.float()), + (1, 0, 2), + ) + self.gg = mm2OutSplit2.clone() + qOut = self.RopeConcatGolden(mm2OutSplit2.reshape(N, headNum * 64), self.sin2, self.cos2, self.bmmOut) + self.qOut = qOut.to(data_type) + + def golden_calc(self, in_tensors): + # out_tensors = torch.load("out_tensors.pt") + # return [out_tensors[-2], out_tensors[-1]] + if self.cacheMode == 1: + return [self.qOut_npu[..., 0:512], self.keyout_npu[..., 0:512], self.qOut_npu[..., 512:576], self.keyout_npu[..., 512:576]] + else: + return [self.qOut_npu, self.keyout_npu] + + def compare_data(self, tensor1, tensor2): + out = tensor1.flatten() + out_len = out.shape[0] + golden = tensor2.flatten() + diff = torch.abs(golden - out) + max_diff = diff.max().item() + print(f"maxDiff {max_diff}") + golden = golden.to(torch.float32) + out = out.to(torch.float32) + limit_error = torch.maximum( + torch.abs(golden * 0.001), torch.tensor(0.001)) + strict_limit_error = torch.maximum( + torch.abs(golden * 0.003), torch.tensor(0.003)) + error_count = torch.gt(diff, limit_error).sum().item() + strict_error_count = torch.gt( + diff, strict_limit_error).sum().item() + print("1/1000 Accuracy is %f", + 1 - float(error_count) / out_len) + print("3/1000 Accuracy is %f", 1 - + float(strict_error_count) / out_len) + print("accuracy is correct: %r", (float(strict_error_count) / out_len) <= 0.001) + + def golden_compare(self, out_tensors, golden_tensors): + # self.compare_data(out_tensors.npu(), golden_tensors.npu()) + if self.cacheMode == 1: + print(f"qOut_npu npu max {torch.max(self.qOut_npu.clone()[..., 0:512].cpu())} min {torch.min(self.qOut_npu.clone()[..., 0:512].cpu())}") + print(f"qout max {torch.max(self.qOut.clone()[..., 
0:512].cpu())} min {torch.min(self.qOut.clone()[..., 0:512].cpu())}") + print(f"out_tensors max {torch.max(out_tensors.clone()[..., 0:512].cpu())} min {torch.min(out_tensors.clone()[..., 0:512].cpu())}") + if self.compare_count == 0: + self.compare_count += 1 + return compare_cv(self.qOut_npu[..., 0:512].npu(), self.qOut[..., 0:512].npu(), out_tensors.npu()) + elif self.compare_count == 1: + self.compare_count += 1 + return compare_cv(self.keyout_npu[..., 0:512].npu(), self.keyOut1[..., 0:512].npu(), out_tensors.npu()) + elif self.compare_count == 2: + self.compare_count += 1 + return compare_cv(self.qOut_npu[..., 512:576].npu(), self.qOut[..., 512:576].npu(), out_tensors.npu()) + elif self.compare_count == 3: + return compare_cv(self.keyout_npu[..., 512:576].npu(), self.keyOut1[..., 512:576].npu(), out_tensors.npu()) + else: + if self.compare_count == 0: + self.compare_count += 1 + return compare_cv(self.qOut_npu.npu(), self.qOut.npu(), out_tensors.npu()) + else: + return compare_cv(self.keyout_npu.npu(), self.keyOut1.npu(), out_tensors.npu()) + + def test_mla_preprocess(self): + if not operation_test.get_soc_version() == 'Ascend910B': + print("this testcase only supports Ascend910B") + return + self.compare_count = 0 + self.cacheMode = 1 + N = 32 + headNum = 32 + data_type = torch.float16 + hidden_size = 8000 + N = 32 + headNum = 128 + data_type = torch.float16 + OP_NAME = "MlaPreprocessOperation" + PARAM = json.dumps({"cacheMode":self.cacheMode}) + print("test_mla_preprocess") + self.calc_vec_mm_atb_data(N,headNum,data_type, hidden_size) + self.keyCache = self.keyCache.npu() + qOut: torch.Tensor = torch.zeros((N, headNum, 512), dtype=data_type).npu() # float16 + kvCacheOut: torch.Tensor = self.keyCache[..., 0:512].npu() # float16 + qRopeOut: torch.Tensor = torch.zeros((N, headNum, 64), dtype=data_type).npu() # float16 + krCacheOut: torch.Tensor = self.keyCache[..., 512:576].npu() + print("type -------------- [qOut, kvCacheOut, qRopeOut, krCacheOut]", type([qOut, kvCacheOut, qRopeOut, krCacheOut])) + self.execute_out(OP_NAME, PARAM, + [self.input1.npu(), # float16 + self.gamma1.npu(), # float16 + self.beta1.npu(), # float16 + self.quantScale1.npu(), # float16 + self.quantOffset1.npu(), # int8 + torch_npu.npu_format_cast(transdata(self.wdqkv, (16, 32)).contiguous().npu(), 29), # int8,nz + self.deScale1.to(torch.int64).npu(), # int64 + self.bias1.npu(), # int32 + self.gamma2.npu(), # float16 + self.beta2.npu(), # float16 + self.quantScale2.npu(), # float16 + self.quantOffset2.npu(), # int8 + torch_npu.npu_format_cast(transdata(self.wuq, (16, 32)).contiguous().npu(), 29), # int8,nz + self.deScale2.to(torch.int64).npu(), # int64 + self.bias2.npu(), # int32 + self.gamma3.npu(), # float16 + self.cos1.npu(), # float16 + self.sin1.npu(), # float16 + self.wuk.npu(), # float16 + self.keyCache[..., 0:512].npu(), # float16 + self.keyCache[..., 512:576].npu(), # float16 + self.slotMapping.npu(), # int32 + torch.tensor([]).to(data_type).npu(), + torch.tensor([]).to(data_type).npu()], + [qOut, kvCacheOut, qRopeOut, krCacheOut]) # float16 + print("================================qOut=============================", qOut) + + + def test_mla_preprocess_no_rms_norm(self): + if not operation_test.get_soc_version() == 'Ascend910B': + print("this testcase only supports Ascend910B") + return + self.compare_count = 0 + self.cacheMode = 0 + N = 32 + headNum = 32 + data_type = torch.float16 + hidden_size = 8000 + OP_NAME = "MlaPreprocessOperation" + PARAM = json.dumps({"cacheMode":self.cacheMode}) + 
self.calc_vec_mm_atb_data(N,headNum,data_type, hidden_size) + self.keyCache = self.keyCache.npu() + print("test_mla_preprocess_no_rms_norm") + self.execute_out(OP_NAME, PARAM, + [self.input1.npu(), + self.gamma1.npu(), + self.beta1.npu(), + self.quantScale1.npu(), + self.quantOffset1.npu(), + torch_npu.npu_format_cast(transdata(self.wdqkv, (16, 32)).contiguous().to(torch.bfloat16).npu(), 29), + self.deScale1.to(torch.float32).npu(), + self.bias1.npu(), + self.gamma2.npu(), + self.beta2.npu(), + self.quantScale2.npu(), + self.quantOffset2.npu(), + torch_npu.npu_format_cast(transdata(self.wuq, (16, 32)).contiguous().npu(), 29), + self.deScale2.to(torch.float32).npu(), + self.bias2.npu(), + self.gamma3.npu(), + self.cos1.npu(), + self.sin1.npu(), + self.wuk.npu(), + self.keyCache[..., 0:512].npu(), + self.keyCache[..., 512:576].npu(), + self.slotMapping.npu(), + torch.tensor([]).npu(), + torch.tensor([]).npu()], + [torch.zeros((N, headNum, 512), dtype=data_type).npu(), + self.keyCache[..., 0:512].npu(), + ]) + +def suite(): + suite = unittest.TestSuite() + for _ in range(1): + suite.addTest(TestMLAPrepross('test_mla_preprocess')) + # suite.addTest(TestMLAPrepross('test_mla_preprocess_no_rms_norm')) + return suite + +if __name__ == '__main__': + runner = unittest.TextTestRunner() + runner.run(suite()) diff --git a/tests/cinterface/CMakeLists.txt b/tests/cinterface/CMakeLists.txt index be235428..4860bdd1 100644 --- a/tests/cinterface/CMakeLists.txt +++ b/tests/cinterface/CMakeLists.txt @@ -18,5 +18,5 @@ target_include_directories(atb_cinterface PRIVATE ${ASCEND_HOME_PATH}/include/) target_link_directories(atb_cinterface PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/googletest/lib ${ASCEND_HOME_PATH}/lib64/ ${ATB_HOME_PATH}/lib) target_compile_options(atb_cinterface PRIVATE "-Wno-write-strings") target_compile_options(atb_cinterface PRIVATE -Wno-sign-compare -Wno-narrowing) -target_link_libraries(atb_cinterface PRIVATE atb nnopbase tbe_adapter -lgtest -lgtest_main -lc_sec) +target_link_libraries(atb_cinterface PRIVATE atb nnopbase tbe_adapter -lgtest -lgtest_main -lc_sec -lopapi) install(TARGETS atb_cinterface DESTINATION bin) \ No newline at end of file diff --git a/tests/unittest/CMakeLists.txt b/tests/unittest/CMakeLists.txt index a265bc53..d9e9fb88 100644 --- a/tests/unittest/CMakeLists.txt +++ b/tests/unittest/CMakeLists.txt @@ -17,8 +17,9 @@ add_executable(atb_unittest ${CORE_SOURCE} ${NORMAL_SOURCE} ${OPS_SOURCE}) target_include_directories(atb_unittest PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty) target_include_directories(atb_unittest PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/googletest/include) target_include_directories(atb_unittest PRIVATE ${ASCEND_HOME_PATH}/include/) +target_include_directories(atb_unittest PRIVATE ${ASCEND_HOME_PATH}/include/aclnnop) target_link_directories(atb_unittest PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/googletest/lib ${ASCEND_HOME_PATH}/lib64/) target_compile_options(atb_unittest PRIVATE "-Wno-write-strings") target_compile_options(atb_unittest PRIVATE -Wno-sign-compare -Wno-narrowing) -target_link_libraries(atb_unittest PRIVATE atb_test_utils tbe_adapter -lgtest -lgtest_main -lc_sec) +target_link_libraries(atb_unittest PRIVATE atb_test_utils tbe_adapter -lgtest -lgtest_main -lc_sec -lopapi) install(TARGETS atb_unittest DESTINATION bin) -- Gitee From 903fcc15c72e3bd740b0191daa609b0f5484349d Mon Sep 17 00:00:00 2001 From: ivanshan_8170 Date: Wed, 17 Sep 2025 09:25:38 +0800 Subject: [PATCH 2/8] task: add ini --- ops_configs/atb_ops_info.ini | 881 +++++++++--------- 
src/atb/runner/aclnn_runner.cpp | 3 +- .../mla_preprocess_aclnn_runner.cpp | 21 +- .../mla_preprocess_operation.cpp | 6 +- .../test_mla_preprocess_aclnn.py | 168 +++- 5 files changed, 591 insertions(+), 488 deletions(-) diff --git a/ops_configs/atb_ops_info.ini b/ops_configs/atb_ops_info.ini index 92edbfb7..c6f06b4a 100644 --- a/ops_configs/atb_ops_info.ini +++ b/ops_configs/atb_ops_info.ini @@ -5543,710 +5543,711 @@ output1.format=nd,nd,nd [MlaPreprocessOperation] input0.name=input -input0.dtype=float16,bf16,float16,bf16 -input0.format=nd,nd,nd,nd +input0.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input0.format=nd,nd,nd,nd,nd,nd,nd,nd input1.name=gamma0 -input1.dtype=float16,bf16,float16,bf16 -input1.format=nd,nd,nd,nd +input1.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input1.format=nd,nd,nd,nd,nd,nd,nd,nd input2.name=beta0 -input2.dtype=float16,bf16,float16,bf16 -input2.format=nd,nd,nd,nd +input2.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input2.format=nd,nd,nd,nd,nd,nd,nd,nd input3.name=quantScale0 -input3.dtype=float16,bf16,float16,bf16 -input3.format=nd,nd,nd,nd +input3.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input3.format=nd,nd,nd,nd,nd,nd,nd,nd input4.name=quantOffset0 -input4.dtype=int8,int8,int8,int8 -input4.format=nd,nd,nd,nd +input4.dtype=int8,int8,int8,int8,int8,int8,int8,int8 +input4.format=nd,nd,nd,nd,nd,nd,nd,nd input5.name=wdqkv -input5.dtype=int8,int8,int8,int8 -input5.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +input5.dtype=int8,int8,int8,int8,float16,bf16,float16,bf16 +input5.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz input6.name=deScale0 -input6.dtype=int64,float,int64,float -input6.format=nd,nd,nd,nd +input6.dtype=int64,float,int64,float,int64,float,int64,float +input6.format=nd,nd,nd,nd,nd,nd,nd,nd input7.name=bias0 -input7.dtype=int32,int32,int32,int32 -input7.format=nd,nd,nd,nd +input7.dtype=int32,int32,int32,int32,int32,int32,int32,int32 +input7.format=nd,nd,nd,nd,nd,nd,nd,nd input8.name=gamma1 -input8.dtype=float16,bf16,float16,bf16 -input8.format=nd,nd,nd,nd +input8.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input8.format=nd,nd,nd,nd,nd,nd,nd,nd input9.name=beta1 -input9.dtype=float16,bf16,float16,bf16 -input9.format=nd,nd,nd,nd +input9.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input9.format=nd,nd,nd,nd,nd,nd,nd,nd input10.name=quantScale1 -input10.dtype=float16,bf16,float16,bf16 -input10.format=nd,nd,nd,nd +input10.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input10.format=nd,nd,nd,nd,nd,nd,nd,nd input11.name=quantOffset1 -input11.dtype=int8,int8,int8,int8 -input11.format=nd,nd,nd,nd +input11.dtype=int8,int8,int8,int8,int8,int8,int8,int8 +input11.format=nd,nd,nd,nd,nd,nd,nd,nd input12.name=wuq -input12.dtype=int8,int8,int8,int8 -input12.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +input12.dtype=int8,int8,int8,int8,float16,bf16,float16,bf16 +input12.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz input13.name=deScale1 -input13.dtype=int64,float,int64,float -input13.format=nd,nd,nd,nd +input13.dtype=int64,float,int64,float,int64,float,int64,float +input13.format=nd,nd,nd,nd,nd,nd,nd,nd input14.name=bias1 -input14.dtype=int32,int32,int32,int32 -input14.format=nd,nd,nd,nd +input14.dtype=int32,int32,int32,int32,int32,int32,int32,int32 +input14.format=nd,nd,nd,nd,nd,nd,nd,nd input15.name=gamma2 -input15.dtype=float16,bf16,float16,bf16 
-input15.format=nd,nd,nd,nd +input15.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input15.format=nd,nd,nd,nd,nd,nd,nd,nd input16.name=cos -input16.dtype=float16,bf16,float16,bf16 -input16.format=nd,nd,nd,nd +input16.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input16.format=nd,nd,nd,nd,nd,nd,nd,nd input17.name=sin -input17.dtype=float16,bf16,float16,bf16 -input17.format=nd,nd,nd,nd +input17.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input17.format=nd,nd,nd,nd,nd,nd,nd,nd input18.name=wuk -input18.dtype=float16,bf16,float16,bf16 -input18.format=fractal_nz,fractal_nz,nd,nd +input18.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input18.format=fractal_nz,fractal_nz,nd,nd,fractal_nz,fractal_nz,nd,nd input19.name=kvcache -input19.dtype=float16,bf16,float16,bf16 -input19.format=nd,nd,nd,nd +input19.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input19.format=nd,nd,nd,nd,nd,nd,nd,nd input20.name=kvCacheRope -input20.dtype=float16,bf16,float16,bf16 -input20.format=nd,nd,nd,nd +input20.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input20.format=nd,nd,nd,nd,nd,nd,nd,nd input20.optional=true input21.name=slotmapping -input21.dtype=int32,int32,int32,int32 -input21.format=nd,nd,nd,nd +input21.dtype=int32,int32,int32,int32,int32,int32,int32,int32 +input21.format=nd,nd,nd,nd,nd,nd,nd,nd input22.name=ctkvScale -input22.dtype=float16,bf16,float16,bf16 -input22.format=nd,nd,nd,nd +input22.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input22.format=nd,nd,nd,nd,nd,nd,nd,nd input22.optional=true input23.name=qNopeScale -input23.dtype=float16,bf16,float16,bf16 -input23.format=nd,nd,nd,nd +input23.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input23.format=nd,nd,nd,nd,nd,nd,nd,nd input23.optional=true output0.name=qOut -output0.dtype=float16,bf16,float16,bf16 -output0.format=nd,nd,nd,nd +output0.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +output0.format=nd,nd,nd,nd,nd,nd,nd,nd output1.name=kvCacheOut -output1.dtype=float16,bf16,float16,bf16 -output1.format=nd,nd,nd,nd +output1.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +output1.format=nd,nd,nd,nd,nd,nd,nd,nd [MlaPreprocessOperationSplit] input0.name=input -input0.dtype=float16,bf16,float16,bf16 -input0.format=nd,nd,nd,nd +input0.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input0.format=nd,nd,nd,nd,nd,nd,nd,nd input1.name=gamma0 -input1.dtype=float16,bf16,float16,bf16 -input1.format=nd,nd,nd,nd +input1.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input1.format=nd,nd,nd,nd,nd,nd,nd,nd input2.name=beta0 -input2.dtype=float16,bf16,float16,bf16 -input2.format=nd,nd,nd,nd +input2.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input2.format=nd,nd,nd,nd,nd,nd,nd,nd input3.name=quantScale0 -input3.dtype=float16,bf16,float16,bf16 -input3.format=nd,nd,nd,nd +input3.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input3.format=nd,nd,nd,nd,nd,nd,nd,nd input4.name=quantOffset0 -input4.dtype=int8,int8,int8,int8 -input4.format=nd,nd,nd,nd +input4.dtype=int8,int8,int8,int8,int8,int8,int8,int8 +input4.format=nd,nd,nd,nd,nd,nd,nd,nd input5.name=wdqkv -input5.dtype=int8,int8,int8,int8 -input5.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +input5.dtype=int8,int8,int8,int8,float16,bf16,float16,bf16 +input5.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz input6.name=deScale0 -input6.dtype=int64,float,int64,float -input6.format=nd,nd,nd,nd 
+input6.dtype=int64,float,int64,float,int64,float,int64,float +input6.format=nd,nd,nd,nd,nd,nd,nd,nd input7.name=bias0 -input7.dtype=int32,int32,int32,int32 -input7.format=nd,nd,nd,nd +input7.dtype=int32,int32,int32,int32,int32,int32,int32,int32 +input7.format=nd,nd,nd,nd,nd,nd,nd,nd input8.name=gamma1 -input8.dtype=float16,bf16,float16,bf16 -input8.format=nd,nd,nd,nd +input8.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input8.format=nd,nd,nd,nd,nd,nd,nd,nd input9.name=beta1 -input9.dtype=float16,bf16,float16,bf16 -input9.format=nd,nd,nd,nd +input9.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input9.format=nd,nd,nd,nd,nd,nd,nd,nd input10.name=quantScale1 -input10.dtype=float16,bf16,float16,bf16 -input10.format=nd,nd,nd,nd +input10.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input10.format=nd,nd,nd,nd,nd,nd,nd,nd input11.name=quantOffset1 -input11.dtype=int8,int8,int8,int8 -input11.format=nd,nd,nd,nd +input11.dtype=int8,int8,int8,int8,int8,int8,int8,int8 +input11.format=nd,nd,nd,nd,nd,nd,nd,nd input12.name=wuq -input12.dtype=int8,int8,int8,int8 -input12.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +input12.dtype=int8,int8,int8,int8,float16,bf16,float16,bf16 +input12.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz input13.name=deScale1 -input13.dtype=int64,float,int64,float -input13.format=nd,nd,nd,nd +input13.dtype=int64,float,int64,float,int64,float,int64,float +input13.format=nd,nd,nd,nd,nd,nd,nd,nd input14.name=bias1 -input14.dtype=int32,int32,int32,int32 -input14.format=nd,nd,nd,nd +input14.dtype=int32,int32,int32,int32,int32,int32,int32,int32 +input14.format=nd,nd,nd,nd,nd,nd,nd,nd input15.name=gamma2 -input15.dtype=float16,bf16,float16,bf16 -input15.format=nd,nd,nd,nd +input15.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input15.format=nd,nd,nd,nd,nd,nd,nd,nd input16.name=cos -input16.dtype=float16,bf16,float16,bf16 -input16.format=nd,nd,nd,nd +input16.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input16.format=nd,nd,nd,nd,nd,nd,nd,nd input17.name=sin -input17.dtype=float16,bf16,float16,bf16 -input17.format=nd,nd,nd,nd +input17.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input17.format=nd,nd,nd,nd,nd,nd,nd,nd input18.name=wuk -input18.dtype=float16,bf16,float16,bf16 -input18.format=fractal_nz,fractal_nz,nd,nd +input18.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input18.format=fractal_nz,fractal_nz,nd,nd,fractal_nz,fractal_nz,nd,nd input19.name=kvcache -input19.dtype=float16,bf16,float16,bf16 -input19.format=nd,nd,nd,nd +input19.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input19.format=nd,nd,nd,nd,nd,nd,nd,nd input20.name=kvCacheRope -input20.dtype=float16,bf16,float16,bf16 -input20.format=nd,nd,nd,nd +input20.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input20.format=nd,nd,nd,nd,nd,nd,nd,nd input21.name=slotmapping -input21.dtype=int32,int32,int32,int32 -input21.format=nd,nd,nd,nd +input21.dtype=int32,int32,int32,int32,int32,int32,int32,int32 +input21.format=nd,nd,nd,nd,nd,nd,nd,nd input22.name=ctkvScale -input22.dtype=float16,bf16,float16,bf16 -input22.format=nd,nd,nd,nd +input22.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input22.format=nd,nd,nd,nd,nd,nd,nd,nd input22.optional=true input23.name=qNopeScale -input23.dtype=float16,bf16,float16,bf16 -input23.format=nd,nd,nd,nd +input23.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input23.format=nd,nd,nd,nd,nd,nd,nd,nd 
input23.optional=true output0.name=qOut0 -output0.dtype=float16,bf16,float16,bf16 -output0.format=nd,nd,nd,nd +output0.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +output0.format=nd,nd,nd,nd,nd,nd,nd,nd output1.name=kvCacheOut0 -output1.dtype=float16,bf16,float16,bf16 -output1.format=nd,nd,nd,nd +output1.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +output1.format=nd,nd,nd,nd,nd,nd,nd,nd output2.name=qOut1 -output2.dtype=float16,bf16,float16,bf16 -output2.format=nd,nd,nd,nd +output2.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +output2.format=nd,nd,nd,nd,nd,nd,nd,nd output3.name=kvCacheOut1 -output3.dtype=float16,bf16,float16,bf16 -output3.format=nd,nd,nd,nd +output3.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +output3.format=nd,nd,nd,nd,nd,nd,nd,nd [MlaPreprocessOperationINT8NZ] input0.name=input -input0.dtype=float16,bf16,float16,bf16 -input0.format=nd,nd,nd,nd +input0.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input0.format=nd,nd,nd,nd,nd,nd,nd,nd input1.name=gamma0 -input1.dtype=float16,bf16,float16,bf16 -input1.format=nd,nd,nd,nd +input1.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input1.format=nd,nd,nd,nd,nd,nd,nd,nd input2.name=beta0 -input2.dtype=float16,bf16,float16,bf16 -input2.format=nd,nd,nd,nd +input2.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input2.format=nd,nd,nd,nd,nd,nd,nd,nd input3.name=quantScale0 -input3.dtype=float16,bf16,float16,bf16 -input3.format=nd,nd,nd,nd +input3.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input3.format=nd,nd,nd,nd,nd,nd,nd,nd input4.name=quantOffset0 -input4.dtype=int8,int8,int8,int8 -input4.format=nd,nd,nd,nd +input4.dtype=int8,int8,int8,int8,int8,int8,int8,int8 +input4.format=nd,nd,nd,nd,nd,nd,nd,nd input5.name=wdqkv -input5.dtype=int8,int8,int8,int8 -input5.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +input5.dtype=int8,int8,int8,int8,float16,bf16,float16,bf16 +input5.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz input6.name=deScale0 -input6.dtype=int64,float,int64,float -input6.format=nd,nd,nd,nd +input6.dtype=int64,float,int64,float,int64,float,int64,float +input6.format=nd,nd,nd,nd,nd,nd,nd,nd input7.name=bias0 -input7.dtype=int32,int32,int32,int32 -input7.format=nd,nd,nd,nd +input7.dtype=int32,int32,int32,int32,int32,int32,int32,int32 +input7.format=nd,nd,nd,nd,nd,nd,nd,nd input8.name=gamma1 -input8.dtype=float16,bf16,float16,bf16 -input8.format=nd,nd,nd,nd +input8.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input8.format=nd,nd,nd,nd,nd,nd,nd,nd input9.name=beta1 -input9.dtype=float16,bf16,float16,bf16 -input9.format=nd,nd,nd,nd +input9.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input9.format=nd,nd,nd,nd,nd,nd,nd,nd input10.name=quantScale1 -input10.dtype=float16,bf16,float16,bf16 -input10.format=nd,nd,nd,nd +input10.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input10.format=nd,nd,nd,nd,nd,nd,nd,nd input11.name=quantOffset1 -input11.dtype=int8,int8,int8,int8 -input11.format=nd,nd,nd,nd +input11.dtype=int8,int8,int8,int8,int8,int8,int8,int8 +input11.format=nd,nd,nd,nd,nd,nd,nd,nd input12.name=wuq -input12.dtype=int8,int8,int8,int8 -input12.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +input12.dtype=int8,int8,int8,int8,float16,bf16,float16,bf16 +input12.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz input13.name=deScale1 -input13.dtype=int64,float,int64,float 
-input13.format=nd,nd,nd,nd +input13.dtype=int64,float,int64,float,int64,float,int64,float +input13.format=nd,nd,nd,nd,nd,nd,nd,nd input14.name=bias1 -input14.dtype=int32,int32,int32,int32 -input14.format=nd,nd,nd,nd +input14.dtype=int32,int32,int32,int32,int32,int32,int32,int32 +input14.format=nd,nd,nd,nd,nd,nd,nd,nd input15.name=gamma2 -input15.dtype=float16,bf16,float16,bf16 -input15.format=nd,nd,nd,nd +input15.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input15.format=nd,nd,nd,nd,nd,nd,nd,nd input16.name=cos -input16.dtype=float16,bf16,float16,bf16 -input16.format=nd,nd,nd,nd +input16.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input16.format=nd,nd,nd,nd,nd,nd,nd,nd input17.name=sin -input17.dtype=float16,bf16,float16,bf16 -input17.format=nd,nd,nd,nd +input17.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input17.format=nd,nd,nd,nd,nd,nd,nd,nd input18.name=wuk -input18.dtype=float16,bf16,float16,bf16 -input18.format=fractal_nz,fractal_nz,nd,nd +input18.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input18.format=fractal_nz,fractal_nz,nd,nd,fractal_nz,fractal_nz,nd,nd input19.name=kvcache -input19.dtype=int8,int8,int8,int8 -input19.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +input19.dtype=int8,int8,int8,int8,int8,int8,int8,int8 +input19.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz input20.name=kvCacheRope -input20.dtype=float16,bf16,float16,bf16 -input20.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +input20.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input20.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz input21.name=slotmapping -input21.dtype=int32,int32,int32,int32 -input21.format=nd,nd,nd,nd +input21.dtype=int32,int32,int32,int32,int32,int32,int32,int32 +input21.format=nd,nd,nd,nd,nd,nd,nd,nd input22.name=ctkvScale -input22.dtype=float16,bf16,float16,bf16 -input22.format=nd,nd,nd,nd +input22.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input22.format=nd,nd,nd,nd,nd,nd,nd,nd input23.name=qNopeScale -input23.dtype=float16,bf16,float16,bf16 -input23.format=nd,nd,nd,nd +input23.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input23.format=nd,nd,nd,nd,nd,nd,nd,nd output0.name=qOut0 -output0.dtype=int8,int8,int8,int8 -output0.format=nd,nd,nd,nd +output0.dtype=int8,int8,int8,int8,int8,int8,int8,int8 +output0.format=nd,nd,nd,nd,nd,nd,nd,nd output1.name=kvCacheOut0 -output1.dtype=int8,int8,int8,int8 -output1.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +output1.dtype=int8,int8,int8,int8,int8,int8,int8,int8 +output1.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz output2.name=qOut1 -output2.dtype=float16,bf16,float16,bf16 -output2.format=nd,nd,nd,nd +output2.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +output2.format=nd,nd,nd,nd,nd,nd,nd,nd output3.name=kvCacheOut1 -output3.dtype=float16,bf16,float16,bf16 -output3.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +output3.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +output3.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz [MlaPreprocessOperationNZ] input0.name=input -input0.dtype=float16,bf16,float16,bf16 -input0.format=nd,nd,nd,nd +input0.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input0.format=nd,nd,nd,nd,nd,nd,nd,nd input1.name=gamma0 -input1.dtype=float16,bf16,float16,bf16 
-input1.format=nd,nd,nd,nd +input1.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input1.format=nd,nd,nd,nd,nd,nd,nd,nd input2.name=beta0 -input2.dtype=float16,bf16,float16,bf16 -input2.format=nd,nd,nd,nd +input2.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input2.format=nd,nd,nd,nd,nd,nd,nd,nd input3.name=quantScale0 -input3.dtype=float16,bf16,float16,bf16 -input3.format=nd,nd,nd,nd +input3.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input3.format=nd,nd,nd,nd,nd,nd,nd,nd input4.name=quantOffset0 -input4.dtype=int8,int8,int8,int8 -input4.format=nd,nd,nd,nd +input4.dtype=int8,int8,int8,int8,int8,int8,int8,int8 +input4.format=nd,nd,nd,nd,nd,nd,nd,nd input5.name=wdqkv -input5.dtype=int8,int8,int8,int8 -input5.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +input5.dtype=int8,int8,int8,int8,float16,bf16,float16,bf16 +input5.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz input6.name=deScale0 -input6.dtype=int64,float,int64,float -input6.format=nd,nd,nd,nd +input6.dtype=int64,float,int64,float,int64,float,int64,float +input6.format=nd,nd,nd,nd,nd,nd,nd,nd input7.name=bias0 -input7.dtype=int32,int32,int32,int32 -input7.format=nd,nd,nd,nd +input7.dtype=int32,int32,int32,int32,int32,int32,int32,int32 +input7.format=nd,nd,nd,nd,nd,nd,nd,nd input8.name=gamma1 -input8.dtype=float16,bf16,float16,bf16 -input8.format=nd,nd,nd,nd +input8.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input8.format=nd,nd,nd,nd,nd,nd,nd,nd input9.name=beta1 -input9.dtype=float16,bf16,float16,bf16 -input9.format=nd,nd,nd,nd +input9.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input9.format=nd,nd,nd,nd,nd,nd,nd,nd input10.name=quantScale1 -input10.dtype=float16,bf16,float16,bf16 -input10.format=nd,nd,nd,nd +input10.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input10.format=nd,nd,nd,nd,nd,nd,nd,nd input11.name=quantOffset1 -input11.dtype=int8,int8,int8,int8 -input11.format=nd,nd,nd,nd +input11.dtype=int8,int8,int8,int8,int8,int8,int8,int8 +input11.format=nd,nd,nd,nd,nd,nd,nd,nd input12.name=wuq -input12.dtype=int8,int8,int8,int8 -input12.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +input12.dtype=int8,int8,int8,int8,float16,bf16,float16,bf16 +input12.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz input13.name=deScale1 -input13.dtype=int64,float,int64,float -input13.format=nd,nd,nd,nd +input13.dtype=int64,float,int64,float,int64,float,int64,float +input13.format=nd,nd,nd,nd,nd,nd,nd,nd input14.name=bias1 -input14.dtype=int32,int32,int32,int32 -input14.format=nd,nd,nd,nd +input14.dtype=int32,int32,int32,int32,int32,int32,int32,int32 +input14.format=nd,nd,nd,nd,nd,nd,nd,nd input15.name=gamma2 -input15.dtype=float16,bf16,float16,bf16 -input15.format=nd,nd,nd,nd +input15.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input15.format=nd,nd,nd,nd,nd,nd,nd,nd input16.name=cos -input16.dtype=float16,bf16,float16,bf16 -input16.format=nd,nd,nd,nd +input16.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input16.format=nd,nd,nd,nd,nd,nd,nd,nd input17.name=sin -input17.dtype=float16,bf16,float16,bf16 -input17.format=nd,nd,nd,nd +input17.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input17.format=nd,nd,nd,nd,nd,nd,nd,nd input18.name=wuk -input18.dtype=float16,bf16,float16,bf16 -input18.format=fractal_nz,fractal_nz,nd,nd +input18.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 
+input18.format=fractal_nz,fractal_nz,nd,nd,fractal_nz,fractal_nz,nd,nd input19.name=kvcache -input19.dtype=float16,bf16,float16,bf16 -input19.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +input19.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input19.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz input20.name=kvCacheRope -input20.dtype=float16,bf16,float16,bf16 -input20.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +input20.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input20.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz input21.name=slotmapping -input21.dtype=int32,int32,int32,int32 -input21.format=nd,nd,nd,nd +input21.dtype=int32,int32,int32,int32,int32,int32,int32,int32 +input21.format=nd,nd,nd,nd,nd,nd,nd,nd input22.name=ctkvScale -input22.dtype=float16,bf16,float16,bf16 -input22.format=nd,nd,nd,nd +input22.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input22.format=nd,nd,nd,nd,nd,nd,nd,nd input22.optional=true input23.name=qNopeScale -input23.dtype=float16,bf16,float16,bf16 -input23.format=nd,nd,nd,nd +input23.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input23.format=nd,nd,nd,nd,nd,nd,nd,nd input23.optional=true output0.name=qOut0 -output0.dtype=float16,bf16,float16,bf16 -output0.format=nd,nd,nd,nd +output0.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +output0.format=nd,nd,nd,nd,nd,nd,nd,nd output1.name=kvCacheOut0 -output1.dtype=float16,bf16,float16,bf16 -output1.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +output1.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +output1.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz output2.name=qOut1 -output2.dtype=float16,bf16,float16,bf16 -output2.format=nd,nd,nd,nd +output2.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +output2.format=nd,nd,nd,nd,nd,nd,nd,nd output3.name=kvCacheOut1 -output3.dtype=float16,bf16,float16,bf16 -output3.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +output3.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +output3.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz [MlaPreprocessOperationPertoken] input0.name=input -input0.dtype=float16,bf16,float16,bf16 -input0.format=nd,nd,nd,nd +input0.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input0.format=nd,nd,nd,nd,nd,nd,nd,nd input1.name=gamma0 -input1.dtype=float16,bf16,float16,bf16 -input1.format=nd,nd,nd,nd +input1.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input1.format=nd,nd,nd,nd,nd,nd,nd,nd input2.name=beta0 -input2.dtype=float16,bf16,float16,bf16 -input2.format=nd,nd,nd,nd +input2.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input2.format=nd,nd,nd,nd,nd,nd,nd,nd input3.name=quantScale0 -input3.dtype=float16,float16,float16,float16 -input3.format=nd,nd,nd,nd +input3.dtype=float16,float16,float16,float16,float16,float16,float16,float16 +input3.format=nd,nd,nd,nd,nd,nd,nd,nd input3.optional=true input4.name=quantOffset0 -input4.dtype=float16,float16,float16,float16 -input4.format=nd,nd,nd,nd +input4.dtype=float16,float16,float16,float16,float16,float16,float16,float16 +input4.format=nd,nd,nd,nd,nd,nd,nd,nd input4.optional=true input5.name=wdqkv -input5.dtype=int8,int8,int8,int8 -input5.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +input5.dtype=int8,int8,int8,int8,float16,bf16,float16,bf16 
+input5.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz input6.name=deScale0 -input6.dtype=float,float,float,float -input6.format=nd,nd,nd,nd +input6.dtype=float,float,float,float,float,float,float,float +input6.format=nd,nd,nd,nd,nd,nd,nd,nd input7.name=bias0 -input7.dtype=float16,bf16,float16,bf16 -input7.format=nd,nd,nd,nd +input7.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input7.format=nd,nd,nd,nd,nd,nd,nd,nd input7.optional=true input8.name=gamma1 -input8.dtype=float16,bf16,float16,bf16 -input8.format=nd,nd,nd,nd +input8.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input8.format=nd,nd,nd,nd,nd,nd,nd,nd input9.name=beta1 -input9.dtype=float16,bf16,float16,bf16 -input9.format=nd,nd,nd,nd +input9.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input9.format=nd,nd,nd,nd,nd,nd,nd,nd input10.name=quantScale1 -input10.dtype=float16,float16,float16,float16 -input10.format=nd,nd,nd,nd +input10.dtype=float16,float16,float16,float16,float16,float16,float16,float16 +input10.format=nd,nd,nd,nd,nd,nd,nd,nd input10.optional=true input11.name=quantOffset1 -input11.dtype=float16,float16,float16,float16 -input11.format=nd,nd,nd,nd +input11.dtype=float16,float16,float16,float16,float16,float16,float16,float16 +input11.format=nd,nd,nd,nd,nd,nd,nd,nd input11.optional=true input12.name=wuq -input12.dtype=int8,int8,int8,int8 -input12.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +input12.dtype=int8,int8,int8,int8,float16,bf16,float16,bf16 +input12.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz input13.name=deScale1 -input13.dtype=float,float,float,float -input13.format=nd,nd,nd,nd +input13.dtype=float,float,float,float,float,float,float,float +input13.format=nd,nd,nd,nd,nd,nd,nd,nd input14.name=bias1 -input14.dtype=float16,bf16,float16,bf16 -input14.format=nd,nd,nd,nd +input14.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input14.format=nd,nd,nd,nd,nd,nd,nd,nd input14.optional=true input15.name=gamma2 -input15.dtype=float16,bf16,float16,bf16 -input15.format=nd,nd,nd,nd +input15.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input15.format=nd,nd,nd,nd,nd,nd,nd,nd input16.name=cos -input16.dtype=float16,bf16,float16,bf16 -input16.format=nd,nd,nd,nd +input16.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input16.format=nd,nd,nd,nd,nd,nd,nd,nd input17.name=sin -input17.dtype=float16,bf16,float16,bf16 -input17.format=nd,nd,nd,nd +input17.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input17.format=nd,nd,nd,nd,nd,nd,nd,nd input18.name=wuk -input18.dtype=float16,bf16,float16,bf16 -input18.format=fractal_nz,fractal_nz,nd,nd +input18.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input18.format=fractal_nz,fractal_nz,nd,nd,fractal_nz,fractal_nz,nd,nd input19.name=kvcache -input19.dtype=float16,bf16,float16,bf16 -input19.format=nd,nd,nd,nd +input19.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input19.format=nd,nd,nd,nd,nd,nd,nd,nd input20.name=kvCacheRope -input20.dtype=float16,bf16,float16,bf16 -input20.format=nd,nd,nd,nd +input20.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input20.format=nd,nd,nd,nd,nd,nd,nd,nd input20.optional=true input21.name=slotmapping -input21.dtype=int32,int32,int32,int32 -input21.format=nd,nd,nd,nd +input21.dtype=int32,int32,int32,int32,int32,int32,int32,int32 +input21.format=nd,nd,nd,nd,nd,nd,nd,nd input22.name=ctkvScale 
-input22.dtype=float16,float16,float16,float16 -input22.format=nd,nd,nd,nd +input22.dtype=float16,float16,float16,float16,float16,float16,float16,float16 +input22.format=nd,nd,nd,nd,nd,nd,nd,nd input22.optional=true input23.name=qNopeScale -input23.dtype=float16,float16,float16,float16 -input23.format=nd,nd,nd,nd +input23.dtype=float16,float16,float16,float16,float16,float16,float16,float16 +input23.format=nd,nd,nd,nd,nd,nd,nd,nd input23.optional=true output0.name=qOut -output0.dtype=float16,bf16,float16,bf16 -output0.format=nd,nd,nd,nd +output0.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +output0.format=nd,nd,nd,nd,nd,nd,nd,nd output1.name=kvCacheOut -output1.dtype=float16,bf16,float16,bf16 -output1.format=nd,nd,nd,nd +output1.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +output1.format=nd,nd,nd,nd,nd,nd,nd,nd [MlaPreprocessOperationSplitPertoken] input0.name=input -input0.dtype=float16,bf16,float16,bf16 -input0.format=nd,nd,nd,nd +input0.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input0.format=nd,nd,nd,nd,nd,nd,nd,nd input1.name=gamma0 -input1.dtype=float16,bf16,float16,bf16 -input1.format=nd,nd,nd,nd +input1.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input1.format=nd,nd,nd,nd,nd,nd,nd,nd input2.name=beta0 -input2.dtype=float16,bf16,float16,bf16 -input2.format=nd,nd,nd,nd +input2.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input2.format=nd,nd,nd,nd,nd,nd,nd,nd input3.name=quantScale0 -input3.dtype=float16,float16,float16,float16 -input3.format=nd,nd,nd,nd +input3.dtype=float16,float16,float16,float16,float16,float16,float16,float16 +input3.format=nd,nd,nd,nd,nd,nd,nd,nd input3.optional=true input4.name=quantOffset0 -input4.dtype=float16,float16,float16,float16 -input4.format=nd,nd,nd,nd +input4.dtype=float16,float16,float16,float16,float16,float16,float16,float16 +input4.format=nd,nd,nd,nd,nd,nd,nd,nd input4.optional=true input5.name=wdqkv -input5.dtype=int8,int8,int8,int8 -input5.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +input5.dtype=int8,int8,int8,int8,float16,bf16,float16,bf16 +input5.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz input6.name=deScale0 -input6.dtype=float,float,float,float -input6.format=nd,nd,nd,nd +input6.dtype=float,float,float,float,float,float,float,float +input6.format=nd,nd,nd,nd,nd,nd,nd,nd input7.name=bias0 -input7.dtype=float16,bf16,float16,bf16 -input7.format=nd,nd,nd,nd +input7.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input7.format=nd,nd,nd,nd,nd,nd,nd,nd input7.optional=true input8.name=gamma1 -input8.dtype=float16,bf16,float16,bf16 -input8.format=nd,nd,nd,nd +input8.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input8.format=nd,nd,nd,nd,nd,nd,nd,nd input9.name=beta1 -input9.dtype=float16,bf16,float16,bf16 -input9.format=nd,nd,nd,nd +input9.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input9.format=nd,nd,nd,nd,nd,nd,nd,nd input10.name=quantScale1 -input10.dtype=float16,float16,float16,float16 -input10.format=nd,nd,nd,nd +input10.dtype=float16,float16,float16,float16,float16,float16,float16,float16 +input10.format=nd,nd,nd,nd,nd,nd,nd,nd input10.optional=true input11.name=quantOffset1 -input11.dtype=float16,float16,float16,float16 -input11.format=nd,nd,nd,nd +input11.dtype=float16,float16,float16,float16,float16,float16,float16,float16 +input11.format=nd,nd,nd,nd,nd,nd,nd,nd input11.optional=true input12.name=wuq -input12.dtype=int8,int8,int8,int8 
-input12.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +input12.dtype=int8,int8,int8,int8,float16,bf16,float16,bf16 +input12.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz input13.name=deScale1 -input13.dtype=float,float,float,float -input13.format=nd,nd,nd,nd +input13.dtype=float,float,float,float,float,float,float,float +input13.format=nd,nd,nd,nd,nd,nd,nd,nd input14.name=bias1 -input14.dtype=float16,bf16,float16,bf16 -input14.format=nd,nd,nd,nd +input14.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input14.format=nd,nd,nd,nd,nd,nd,nd,nd input14.optional=true input15.name=gamma2 -input15.dtype=float16,bf16,float16,bf16 -input15.format=nd,nd,nd,nd +input15.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input15.format=nd,nd,nd,nd,nd,nd,nd,nd input16.name=cos -input16.dtype=float16,bf16,float16,bf16 -input16.format=nd,nd,nd,nd +input16.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input16.format=nd,nd,nd,nd,nd,nd,nd,nd input17.name=sin -input17.dtype=float16,bf16,float16,bf16 -input17.format=nd,nd,nd,nd +input17.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input17.format=nd,nd,nd,nd,nd,nd,nd,nd input18.name=wuk -input18.dtype=float16,bf16,float16,bf16 -input18.format=fractal_nz,fractal_nz,nd,nd +input18.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input18.format=fractal_nz,fractal_nz,nd,nd,fractal_nz,fractal_nz,nd,nd input19.name=kvcache -input19.dtype=float16,bf16,float16,bf16 -input19.format=nd,nd,nd,nd +input19.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input19.format=nd,nd,nd,nd,nd,nd,nd,nd input20.name=kvCacheRope -input20.dtype=float16,bf16,float16,bf16 -input20.format=nd,nd,nd,nd +input20.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input20.format=nd,nd,nd,nd,nd,nd,nd,nd +input20.optional=true input21.name=slotmapping -input21.dtype=int32,int32,int32,int32 -input21.format=nd,nd,nd,nd +input21.dtype=int32,int32,int32,int32,int32,int32,int32,int32 +input21.format=nd,nd,nd,nd,nd,nd,nd,nd input22.name=ctkvScale -input22.dtype=float16,float16,float16,float16 -input22.format=nd,nd,nd,nd +input22.dtype=float16,float16,float16,float16,float16,float16,float16,float16 +input22.format=nd,nd,nd,nd,nd,nd,nd,nd input22.optional=true input23.name=qNopeScale -input23.dtype=float16,float16,float16,float16 -input23.format=nd,nd,nd,nd +input23.dtype=float16,float16,float16,float16,float16,float16,float16,float16 +input23.format=nd,nd,nd,nd,nd,nd,nd,nd input23.optional=true output0.name=qOut0 -output0.dtype=float16,bf16,float16,bf16 -output0.format=nd,nd,nd,nd +output0.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +output0.format=nd,nd,nd,nd,nd,nd,nd,nd output1.name=kvCacheOut0 -output1.dtype=float16,bf16,float16,bf16 -output1.format=nd,nd,nd,nd +output1.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +output1.format=nd,nd,nd,nd,nd,nd,nd,nd output2.name=qOut1 -output2.dtype=float16,bf16,float16,bf16 -output2.format=nd,nd,nd,nd +output2.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +output2.format=nd,nd,nd,nd,nd,nd,nd,nd output3.name=kvCacheOut1 -output3.dtype=float16,bf16,float16,bf16 -output3.format=nd,nd,nd,nd +output3.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +output3.format=nd,nd,nd,nd,nd,nd,nd,nd [MlaPreprocessOperationINT8NZPertoken] input0.name=input -input0.dtype=float16,bf16,float16,bf16 -input0.format=nd,nd,nd,nd +input0.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 
+input0.format=nd,nd,nd,nd,nd,nd,nd,nd input1.name=gamma0 -input1.dtype=float16,bf16,float16,bf16 -input1.format=nd,nd,nd,nd +input1.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input1.format=nd,nd,nd,nd,nd,nd,nd,nd input2.name=beta0 -input2.dtype=float16,bf16,float16,bf16 -input2.format=nd,nd,nd,nd +input2.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input2.format=nd,nd,nd,nd,nd,nd,nd,nd input3.name=quantScale0 -input3.dtype=float16,float16,float16,float16 -input3.format=nd,nd,nd,nd +input3.dtype=float16,float16,float16,float16,float16,float16,float16,float16 +input3.format=nd,nd,nd,nd,nd,nd,nd,nd input3.optional=true input4.name=quantOffset0 -input4.dtype=float16,float16,float16,float16 -input4.format=nd,nd,nd,nd +input4.dtype=float16,float16,float16,float16,float16,float16,float16,float16 +input4.format=nd,nd,nd,nd,nd,nd,nd,nd input4.optional=true input5.name=wdqkv -input5.dtype=int8,int8,int8,int8 -input5.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +input5.dtype=int8,int8,int8,int8,float16,bf16,float16,bf16 +input5.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz input6.name=deScale0 -input6.dtype=float,float,float,float -input6.format=nd,nd,nd,nd +input6.dtype=float,float,float,float,float,float,float,float +input6.format=nd,nd,nd,nd,nd,nd,nd,nd input7.name=bias0 -input7.dtype=float16,bf16,float16,bf16 -input7.format=nd,nd,nd,nd +input7.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input7.format=nd,nd,nd,nd,nd,nd,nd,nd input7.optional=true input8.name=gamma1 -input8.dtype=float16,bf16,float16,bf16 -input8.format=nd,nd,nd,nd +input8.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input8.format=nd,nd,nd,nd,nd,nd,nd,nd input9.name=beta1 -input9.dtype=float16,bf16,float16,bf16 -input9.format=nd,nd,nd,nd +input9.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input9.format=nd,nd,nd,nd,nd,nd,nd,nd input10.name=quantScale1 -input10.dtype=float16,float16,float16,float16 -input10.format=nd,nd,nd,nd +input10.dtype=float16,float16,float16,float16,float16,float16,float16,float16 +input10.format=nd,nd,nd,nd,nd,nd,nd,nd input10.optional=true input11.name=quantOffset1 -input11.dtype=float16,float16,float16,float16 -input11.format=nd,nd,nd,nd +input11.dtype=float16,float16,float16,float16,float16,float16,float16,float16 +input11.format=nd,nd,nd,nd,nd,nd,nd,nd input11.optional=true input12.name=wuq -input12.dtype=int8,int8,int8,int8 -input12.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +input12.dtype=int8,int8,int8,int8,float16,bf16,float16,bf16 +input12.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz input13.name=deScale1 -input13.dtype=float,float,float,float -input13.format=nd,nd,nd,nd +input13.dtype=float,float,float,float,float,float,float,float +input13.format=nd,nd,nd,nd,nd,nd,nd,nd input14.name=bias1 -input14.dtype=float16,bf16,float16,bf16 -input14.format=nd,nd,nd,nd +input14.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input14.format=nd,nd,nd,nd,nd,nd,nd,nd input14.optional=true input15.name=gamma2 -input15.dtype=float16,bf16,float16,bf16 -input15.format=nd,nd,nd,nd +input15.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input15.format=nd,nd,nd,nd,nd,nd,nd,nd input16.name=cos -input16.dtype=float16,bf16,float16,bf16 -input16.format=nd,nd,nd,nd +input16.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input16.format=nd,nd,nd,nd,nd,nd,nd,nd input17.name=sin 
-input17.dtype=float16,bf16,float16,bf16 -input17.format=nd,nd,nd,nd +input17.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input17.format=nd,nd,nd,nd,nd,nd,nd,nd input18.name=wuk -input18.dtype=float16,bf16,float16,bf16 -input18.format=fractal_nz,fractal_nz,nd,nd +input18.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input18.format=fractal_nz,fractal_nz,nd,nd,fractal_nz,fractal_nz,nd,nd input19.name=kvcache -input19.dtype=int8,int8,int8,int8 -input19.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +input19.dtype=int8,int8,int8,int8,int8,int8,int8,int8 +input19.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz input20.name=kvCacheRope -input20.dtype=float16,bf16,float16,bf16 -input20.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +input20.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input20.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz input21.name=slotmapping -input21.dtype=int32,int32,int32,int32 -input21.format=nd,nd,nd,nd +input21.dtype=int32,int32,int32,int32,int32,int32,int32,int32 +input21.format=nd,nd,nd,nd,nd,nd,nd,nd input22.name=ctkvScale -input22.dtype=float16,float16,float16,float16 -input22.format=nd,nd,nd,nd +input22.dtype=float16,float16,float16,float16,float16,float16,float16,float16 +input22.format=nd,nd,nd,nd,nd,nd,nd,nd input23.name=qNopeScale -input23.dtype=float16,float16,float16,float16 -input23.format=nd,nd,nd,nd +input23.dtype=float16,float16,float16,float16,float16,float16,float16,float16 +input23.format=nd,nd,nd,nd,nd,nd,nd,nd output0.name=qOut0 -output0.dtype=int8,int8,int8,int8 -output0.format=nd,nd,nd,nd +output0.dtype=int8,int8,int8,int8,int8,int8,int8,int8 +output0.format=nd,nd,nd,nd,nd,nd,nd,nd output1.name=kvCacheOut0 -output1.dtype=int8,int8,int8,int8 -output1.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +output1.dtype=int8,int8,int8,int8,int8,int8,int8,int8 +output1.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz output2.name=qOut1 -output2.dtype=float16,bf16,float16,bf16 -output2.format=nd,nd,nd,nd +output2.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +output2.format=nd,nd,nd,nd,nd,nd,nd,nd output3.name=kvCacheOut1 -output3.dtype=float16,bf16,float16,bf16 -output3.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +output3.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +output3.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz [MlaPreprocessOperationNZPertoken] input0.name=input -input0.dtype=float16,bf16,float16,bf16 -input0.format=nd,nd,nd,nd +input0.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input0.format=nd,nd,nd,nd,nd,nd,nd,nd input1.name=gamma0 -input1.dtype=float16,bf16,float16,bf16 -input1.format=nd,nd,nd,nd +input1.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input1.format=nd,nd,nd,nd,nd,nd,nd,nd input2.name=beta0 -input2.dtype=float16,bf16,float16,bf16 -input2.format=nd,nd,nd,nd +input2.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input2.format=nd,nd,nd,nd,nd,nd,nd,nd input3.name=quantScale0 -input3.dtype=float16,float16,float16,float16 -input3.format=nd,nd,nd,nd +input3.dtype=float16,float16,float16,float16,float16,float16,float16,float16 +input3.format=nd,nd,nd,nd,nd,nd,nd,nd input3.optional=true input4.name=quantOffset0 -input4.dtype=float16,float16,float16,float16 -input4.format=nd,nd,nd,nd 
+input4.dtype=float16,float16,float16,float16,float16,float16,float16,float16 +input4.format=nd,nd,nd,nd,nd,nd,nd,nd input4.optional=true input5.name=wdqkv -input5.dtype=int8,int8,int8,int8 -input5.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +input5.dtype=int8,int8,int8,int8,float16,bf16,float16,bf16 +input5.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz input6.name=deScale0 -input6.dtype=float,float,float,float -input6.format=nd,nd,nd,nd +input6.dtype=float,float,float,float,float,float,float,float +input6.format=nd,nd,nd,nd,nd,nd,nd,nd input7.name=bias0 -input7.dtype=float16,bf16,float16,bf16 -input7.format=nd,nd,nd,nd +input7.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input7.format=nd,nd,nd,nd,nd,nd,nd,nd input7.optional=true input8.name=gamma1 -input8.dtype=float16,bf16,float16,bf16 -input8.format=nd,nd,nd,nd +input8.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input8.format=nd,nd,nd,nd,nd,nd,nd,nd input9.name=beta1 -input9.dtype=float16,bf16,float16,bf16 -input9.format=nd,nd,nd,nd +input9.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input9.format=nd,nd,nd,nd,nd,nd,nd,nd input10.name=quantScale1 -input10.dtype=float16,float16,float16,float16 -input10.format=nd,nd,nd,nd +input10.dtype=float16,float16,float16,float16,float16,float16,float16,float16 +input10.format=nd,nd,nd,nd,nd,nd,nd,nd input10.optional=true input11.name=quantOffset1 -input11.dtype=float16,float16,float16,float16 -input11.format=nd,nd,nd,nd +input11.dtype=float16,float16,float16,float16,float16,float16,float16,float16 +input11.format=nd,nd,nd,nd,nd,nd,nd,nd input11.optional=true input12.name=wuq -input12.dtype=int8,int8,int8,int8 -input12.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +input12.dtype=int8,int8,int8,int8,float16,bf16,float16,bf16 +input12.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz input13.name=deScale1 -input13.dtype=float,float,float,float -input13.format=nd,nd,nd,nd +input13.dtype=float,float,float,float,float,float,float,float +input13.format=nd,nd,nd,nd,nd,nd,nd,nd input14.name=bias1 -input14.dtype=float16,bf16,float16,bf16 -input14.format=nd,nd,nd,nd +input14.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input14.format=nd,nd,nd,nd,nd,nd,nd,nd input14.optional=true input15.name=gamma2 -input15.dtype=float16,bf16,float16,bf16 -input15.format=nd,nd,nd,nd +input15.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input15.format=nd,nd,nd,nd,nd,nd,nd,nd input16.name=cos -input16.dtype=float16,bf16,float16,bf16 -input16.format=nd,nd,nd,nd +input16.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input16.format=nd,nd,nd,nd,nd,nd,nd,nd input17.name=sin -input17.dtype=float16,bf16,float16,bf16 -input17.format=nd,nd,nd,nd +input17.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input17.format=nd,nd,nd,nd,nd,nd,nd,nd input18.name=wuk -input18.dtype=float16,bf16,float16,bf16 -input18.format=fractal_nz,fractal_nz,nd,nd +input18.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input18.format=fractal_nz,fractal_nz,nd,nd,fractal_nz,fractal_nz,nd,nd input19.name=kvcache -input19.dtype=float16,bf16,float16,bf16 -input19.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +input19.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input19.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz input20.name=kvCacheRope -input20.dtype=float16,bf16,float16,bf16 
-input20.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +input20.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +input20.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz input21.name=slotmapping -input21.dtype=int32,int32,int32,int32 -input21.format=nd,nd,nd,nd +input21.dtype=int32,int32,int32,int32,int32,int32,int32,int32 +input21.format=nd,nd,nd,nd,nd,nd,nd,nd input22.name=ctkvScale -input22.dtype=float16,float16,float16,float16 -input22.format=nd,nd,nd,nd +input22.dtype=float16,float16,float16,float16,float16,float16,float16,float16 +input22.format=nd,nd,nd,nd,nd,nd,nd,nd input22.optional=true input23.name=qNopeScale -input23.dtype=float16,float16,float16,float16 -input23.format=nd,nd,nd,nd +input23.dtype=float16,float16,float16,float16,float16,float16,float16,float16 +input23.format=nd,nd,nd,nd,nd,nd,nd,nd input23.optional=true output0.name=qOut0 -output0.dtype=float16,bf16,float16,bf16 -output0.format=nd,nd,nd,nd +output0.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +output0.format=nd,nd,nd,nd,nd,nd,nd,nd output1.name=kvCacheOut0 -output1.dtype=float16,bf16,float16,bf16 -output1.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +output1.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +output1.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz output2.name=qOut1 -output2.dtype=float16,bf16,float16,bf16 -output2.format=nd,nd,nd,nd +output2.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +output2.format=nd,nd,nd,nd,nd,nd,nd,nd output3.name=kvCacheOut1 -output3.dtype=float16,bf16,float16,bf16 -output3.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz +output3.dtype=float16,bf16,float16,bf16,float16,bf16,float16,bf16 +output3.format=fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz,fractal_nz [ReshapeAndCacheOmniOperation] input0.name=key diff --git a/src/atb/runner/aclnn_runner.cpp b/src/atb/runner/aclnn_runner.cpp index 6904e285..36e15b9c 100644 --- a/src/atb/runner/aclnn_runner.cpp +++ b/src/atb/runner/aclnn_runner.cpp @@ -62,7 +62,7 @@ Status AclnnRunner::SetupImpl(RunnerVariantPack &runnerVariantPack) << "getWorkspace success, workspaceSize: " << this->atbVariantPack_.workspaceBufferSize << ", workspace addr: " << this->atbVariantPack_.workspaceBuffer; aclnnRet = aclSetAclOpExecutorRepeatable(this->aclnnExecutor_.get()); - if (ret != 0) { + if (aclnnRet != 0) { // 设置算子可复用失败,标记cache中executor不可复用 ATB_LOG(INFO) << this->GetName() << " call aclSetAclOpExecutorRepeatable fail: " << ret; this->executorRepeatable_ = false; @@ -143,6 +143,7 @@ Status AclnnRunner::PreExecuteImpl(RunnerVariantPack &runnerVariantPack) void AclnnRunner::UpdateWorkspace(const RunnerVariantPack &runnerVariantPack) { this->atbVariantPack_.workspaceBufferSize = runnerVariantPack.workspaceBufferSize; + this->atbVariantPack_.workspaceBuffer = runnerVariantPack.workspaceBuffer; } Status AclnnRunner::ExecuteImpl(RunnerVariantPack &runnerVariantPack) diff --git a/src/ops_infer/mla_preprocess/mla_preprocess_aclnn_runner.cpp b/src/ops_infer/mla_preprocess/mla_preprocess_aclnn_runner.cpp index fadf0457..a3195cf0 100644 --- a/src/ops_infer/mla_preprocess/mla_preprocess_aclnn_runner.cpp +++ b/src/ops_infer/mla_preprocess/mla_preprocess_aclnn_runner.cpp @@ -34,7 +34,7 @@ aclnnStatus (*MlaPreprocessAclnnRunner::aclnnGetWorkspaceSizeFunc_)( aclnnStatus (*MlaPreprocessAclnnRunner::aclnnExecuteFunc_)(void *, uint64_t, aclOpExecutor *, aclrtStream) = nullptr; 
MlaPreprocessAclnnRunner::MlaPreprocessAclnnRunner(const infer::MlaPreprocessParam ¶m) - : AclnnRunner("MlaPreprocessOpsRunner", RUNNER_TYPE_MLA_PREPROCESS_ACLNN), param_(param) + : AclnnRunner("MlaPreprocessAclnnRunner", RUNNER_TYPE_MLA_PREPROCESS_ACLNN), param_(param) { ATB_LOG(INFO) << GetLogPrefix() << "MlaPreprocessAclnnRunner::MlaPreprocessAclnnRunner called"; } @@ -107,7 +107,7 @@ aclnnStatus MlaPreprocessAclnnRunner::SetAclNNWorkspaceExecutor() if (status != NO_ERROR) { ATB_LOG(ERROR) << GetLogPrefix() << "load getWorkspace function from aclnn failed! Consider upgrade CANN first!"; - return status; + return 561003; // ACLNN_ERR_INNER_FIND_KERNEL_ERROR } size_t inTensorStart = 0; bool isRopeCache = param_.cacheMode != infer::MlaPreprocessParam::CacheMode::KVCACHE; @@ -159,19 +159,16 @@ aclnnStatus MlaPreprocessAclnnRunner::SetAclNNWorkspaceExecutor() ATB_LOG(INFO) << GetLogPrefix() << "workspaceSize addr: " << &(this->atbVariantPack_.workspaceBufferSize); + int64_t wdkvSplitCount = 1; aclnnStatus ret = MlaPreprocessAclnnRunner::aclnnGetWorkspaceSizeFunc_( input, gamma0, beta0, quantScale0, quantOffset0, wdqkv, deScale0, bias0, gamma1, beta1, quantScale1, quantOffset1, wuq, deScale1, bias1, gamma2, cos, sin, wuk, kvCache, kRope, slotmapping, ctkvScale, qNopeScale, param_.wdqDim, param_.qRopeDim, param_.kRopeDim, param_.epsilon, param_.qRotaryCoeff, param_.kRotaryCoeff, - param_.transposeWdq, param_.transposeWuq, param_.transposeWuk, param_.cacheMode, param_.quantMode, - doRmsNorm_, // doRmsNorm - 1, // wdkvSplitCount - qOut0, kvCacheOut0, qOut1, kvCacheOut1, &(this->atbVariantPack_.workspaceBufferSize), &raw_executor_ptr); - // TODO:上移 - aclError aclRet = aclSetAclOpExecutorRepeatable(this->aclnnExecutor_.get()); - bool repeatable = this->executorRepeatable_; - this->aclnnExecutor_ = std::shared_ptr(raw_executor_ptr, [repeatable](aclOpExecutor *ptr) { - if (ptr && repeatable) { // 可复用时才手动销毁aclOpExecutor + param_.transposeWdq, param_.transposeWuq, param_.transposeWuk, param_.cacheMode, param_.quantMode, doRmsNorm_, + wdkvSplitCount, qOut0, kvCacheOut0, qOut1, kvCacheOut1, &(this->atbVariantPack_.workspaceBufferSize), + &raw_executor_ptr); + this->aclnnExecutor_ = std::shared_ptr(raw_executor_ptr, [this](aclOpExecutor *ptr) { + if (ptr && this->executorRepeatable_) { // 可复用时才手动销毁aclOpExecutor aclDestroyAclOpExecutor(ptr); } }); @@ -207,7 +204,7 @@ Status MlaPreprocessAclnnRunner::LoadMethod() MlaPreprocessAclnnRunner::aclnnExecuteFunc_ != nullptr) { return NO_ERROR; } - DlManager dlManager = DlManager(std::string(std::getenv("ASCEND_HOME_PATH")) + "/lib64/libopapi.so"); + static DlManager dlManager = DlManager(std::string(std::getenv("ASCEND_HOME_PATH")) + "/lib64/libopapi.so"); Status ret = dlManager.getSymbol("aclnnMlaPreprocessGetWorkspaceSize", (void **)&MlaPreprocessAclnnRunner::aclnnGetWorkspaceSizeFunc_); if (ret != NO_ERROR) { diff --git a/src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp b/src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp index 75ad479f..e3d23119 100644 --- a/src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp +++ b/src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp @@ -417,8 +417,10 @@ Status MlaPreprocessOperation::CheckAclnnKernel(const SVector &inTen ATB_LOG(INFO) << GetLogPrefix() << "CheckAclnnKernel start"; // input's hiddenSize != 7168, generalize hiddenSize bool generalizedHiddenSize = inTensorDesc.at(INPUT_INDEX).shape.dims[1] != INNER_DIM_7168; // 1: hiddenSize - // if wdqkv's dtype is bf16, then do not rmsNorm - 
doRmsNorm_ = inTensorDesc.at(WDQKV_INDEX).dtype != ACL_BF16; + // if wdqkv's dtype is the same as the input and is either float16/bf16, then do not do rmsNormQuant + aclDataType inputDtype = inTensorDesc.at(INPUT_INDEX).dtype; + doRmsNorm_ = + inTensorDesc.at(WDQKV_INDEX).dtype == inputDtype && (inputDtype == ACL_FLOAT16 || inputDtype == ACL_BF16); if (!generalizedHiddenSize && doRmsNorm_) { ATB_LOG(INFO) << GetLogPrefix() << "no need to use aclnn kernel for non-generalized hiddenSize and rmsNormQuant for input is on"; diff --git a/tests/apitest/opstest/python/operations/mla_preprocess/test_mla_preprocess_aclnn.py b/tests/apitest/opstest/python/operations/mla_preprocess/test_mla_preprocess_aclnn.py index 5f4d66e2..75aab78e 100644 --- a/tests/apitest/opstest/python/operations/mla_preprocess/test_mla_preprocess_aclnn.py +++ b/tests/apitest/opstest/python/operations/mla_preprocess/test_mla_preprocess_aclnn.py @@ -744,6 +744,108 @@ class TestMLAPrepross(operation_test.OperationTest): return [self.qOut_npu[..., 0:512], self.keyout_npu[..., 0:512], self.qOut_npu[..., 512:576], self.keyout_npu[..., 512:576]] else: return [self.qOut_npu, self.keyout_npu] + + def __test_mlapo_impl( + self, + data_type: torch.dtype, + num_tokens: int, + num_heads: int, + cache_mode: int, + quant_mode: int, + weight_format: int, + ) -> None: + self.calc_vec_mm_atb_data(num_tokens, num_heads, data_type, cache_mode, quant_mode) + self.set_param( + "MlaPreprocessOperation", + {"N": num_tokens, "headNum": num_heads, "cacheMode": cache_mode, "quantMode": quant_mode}, + ) + self.set_input_formats( + [ + self.format_nd, + self.format_nd, + self.format_nd, + self.format_nd, + self.format_nd, + self.format_nz, # self.wdqkv, + self.format_nd, + self.format_nd, + self.format_nd, + self.format_nd, + self.format_nd, + self.format_nd, + self.format_nd, + self.format_nd, + self.format_nd, + self.format_nd, + self.format_nd, + self.format_nd, + self.format_nz, # self.wuq, + self.format_nd, + weight_format, # self.wuk + self.format_nd, + self.format_nd, + self.format_nd, + self.format_nd, + ] + ) + in_tensors = [ + self.input1, + self.gamma1, + self.beta1, + self.quantScale1, + self.quantOffset1, + transdata(self.wdqkv, (16, 32)), # self.wdqkv, + self.bias1, + self.gamma2, + self.beta2, + self.quantScale2, + self.quantOffset2, + self.gamma3, + self.sin1, + self.cos1, + self.sin2, + self.cos2, + self.keyCache, + self.slotMapping, + transdata(self.wuq, (16, 32)), # self.wuq, + self.bias2, + self.wuk if weight_format == self.format_nd else transdata_3d(self.wuk), + self.deScale1, + self.deScale2, + self.quantScale3, + self.qNopeScale, + ] + if cache_mode == 0: + out_tensors = [ + torch.zeros_like(self.qOut, dtype=data_type), + self.keyOutTensor, + torch.tensor([]), + torch.tensor([]), + ] + elif cache_mode == 1: + out_tensors = [ + torch.zeros_like(self.qOut[..., :512], dtype=data_type), + self.keyOutTensor[..., :512], + torch.zeros_like(self.qOut[..., 512:], dtype=data_type), + self.keyOutTensor[..., 512:], + ] + elif cache_mode == 2: + out_tensors = [ + torch.zeros_like(self.qOut[..., :512], dtype=torch.int8), + self.I8keyCache1, + torch.zeros_like(self.qOut[..., 512:], dtype=data_type), + self.keyCache2, + ] + else: + out_tensors = [ + torch.zeros_like(self.qOut[..., :512], dtype=data_type), + self.FPkeyCache1, + torch.zeros_like(self.qOut[..., 512:], dtype=data_type), + self.keyCache2, + ] + + self.execute(in_tensors, out_tensors) + return def compare_data(self, tensor1, tensor2): out = tensor1.flatten() @@ -770,9 +872,9 @@ class 
TestMLAPrepross(operation_test.OperationTest): def golden_compare(self, out_tensors, golden_tensors): # self.compare_data(out_tensors.npu(), golden_tensors.npu()) if self.cacheMode == 1: - print(f"qOut_npu npu max {torch.max(self.qOut_npu.clone()[..., 0:512].cpu())} min {torch.min(self.qOut_npu.clone()[..., 0:512].cpu())}") + print(f"qOut_npu npu max {torch.max(self.qOut_npu.clone().detach()[..., 0:512].cpu())} min {torch.min(self.qOut_npu.clone().detach()[..., 0:512].cpu())}") print(f"qout max {torch.max(self.qOut.clone()[..., 0:512].cpu())} min {torch.min(self.qOut.clone()[..., 0:512].cpu())}") - print(f"out_tensors max {torch.max(out_tensors.clone()[..., 0:512].cpu())} min {torch.min(out_tensors.clone()[..., 0:512].cpu())}") + print(f"out_tensors max {torch.max(out_tensors[..., 0:512].cpu())} min {torch.min(out_tensors[..., 0:512].cpu())}") if self.compare_count == 0: self.compare_count += 1 return compare_cv(self.qOut_npu[..., 0:512].npu(), self.qOut[..., 0:512].npu(), out_tensors.npu()) @@ -797,7 +899,6 @@ class TestMLAPrepross(operation_test.OperationTest): return self.compare_count = 0 self.cacheMode = 1 - N = 32 headNum = 32 data_type = torch.float16 hidden_size = 8000 @@ -840,7 +941,6 @@ class TestMLAPrepross(operation_test.OperationTest): torch.tensor([]).to(data_type).npu(), torch.tensor([]).to(data_type).npu()], [qOut, kvCacheOut, qRopeOut, krCacheOut]) # float16 - print("================================qOut=============================", qOut) def test_mla_preprocess_no_rms_norm(self): @@ -848,7 +948,7 @@ class TestMLAPrepross(operation_test.OperationTest): print("this testcase only supports Ascend910B") return self.compare_count = 0 - self.cacheMode = 0 + self.cacheMode = 1 N = 32 headNum = 32 data_type = torch.float16 @@ -858,40 +958,42 @@ class TestMLAPrepross(operation_test.OperationTest): self.calc_vec_mm_atb_data(N,headNum,data_type, hidden_size) self.keyCache = self.keyCache.npu() print("test_mla_preprocess_no_rms_norm") + qOut: torch.Tensor = torch.zeros((N, headNum, 512), dtype=data_type).npu() # float16 + kvCacheOut: torch.Tensor = self.keyCache[..., 0:512].npu() # float16 + qRopeOut: torch.Tensor = torch.zeros((N, headNum, 64), dtype=data_type).npu() # float16 + krCacheOut: torch.Tensor = self.keyCache[..., 512:576].npu() + print("type -------------- [qOut, kvCacheOut, qRopeOut, krCacheOut]", type([qOut, kvCacheOut, qRopeOut, krCacheOut])) self.execute_out(OP_NAME, PARAM, - [self.input1.npu(), - self.gamma1.npu(), - self.beta1.npu(), - self.quantScale1.npu(), - self.quantOffset1.npu(), - torch_npu.npu_format_cast(transdata(self.wdqkv, (16, 32)).contiguous().to(torch.bfloat16).npu(), 29), - self.deScale1.to(torch.float32).npu(), - self.bias1.npu(), - self.gamma2.npu(), - self.beta2.npu(), - self.quantScale2.npu(), - self.quantOffset2.npu(), - torch_npu.npu_format_cast(transdata(self.wuq, (16, 32)).contiguous().npu(), 29), - self.deScale2.to(torch.float32).npu(), - self.bias2.npu(), - self.gamma3.npu(), - self.cos1.npu(), - self.sin1.npu(), - self.wuk.npu(), - self.keyCache[..., 0:512].npu(), - self.keyCache[..., 512:576].npu(), - self.slotMapping.npu(), - torch.tensor([]).npu(), - torch.tensor([]).npu()], - [torch.zeros((N, headNum, 512), dtype=data_type).npu(), - self.keyCache[..., 0:512].npu(), - ]) + [self.input1.npu(), # float16 + self.gamma1.npu(), # float16 + self.beta1.npu(), # float16 + self.quantScale1.npu(), # float16 + self.quantOffset1.npu(), # int8 + torch_npu.npu_format_cast(transdata(self.wdqkv.to(torch.float16), (16, 32)).contiguous().npu(), 29), # 
float16,nz + self.deScale1.to(torch.int64).npu(), # int64 + self.bias1.npu(), # int32 + self.gamma2.npu(), # float16 + self.beta2.npu(), # float16 + self.quantScale2.npu(), # float16 + self.quantOffset2.npu(), # int8 + torch_npu.npu_format_cast(transdata(self.wuq.to(torch.float16), (16, 32)).contiguous().npu(), 29), # float16,nz + self.deScale2.to(torch.int64).npu(), # int64 + self.bias2.npu(), # int32 + self.gamma3.npu(), # float16 + self.cos1.npu(), # float16 + self.sin1.npu(), # float16 + self.wuk.npu(), # float16 + self.keyCache[..., 0:512].npu(), # float16 + self.keyCache[..., 512:576].npu(), # float16 + self.slotMapping.npu(), # int32 + torch.tensor([]).to(data_type).npu(), + torch.tensor([]).to(data_type).npu()], + [qOut, kvCacheOut, qRopeOut, krCacheOut]) # float16 def suite(): suite = unittest.TestSuite() for _ in range(1): suite.addTest(TestMLAPrepross('test_mla_preprocess')) - # suite.addTest(TestMLAPrepross('test_mla_preprocess_no_rms_norm')) return suite if __name__ == '__main__': -- Gitee From 1a32d3ce794d8aaf3ba53909f06b63f871d0a2a9 Mon Sep 17 00:00:00 2001 From: ivanshan_8170 Date: Wed, 17 Sep 2025 10:08:42 +0800 Subject: [PATCH 3/8] fix branch --- src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp b/src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp index e3d23119..bd12c926 100644 --- a/src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp +++ b/src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp @@ -417,10 +417,10 @@ Status MlaPreprocessOperation::CheckAclnnKernel(const SVector &inTen ATB_LOG(INFO) << GetLogPrefix() << "CheckAclnnKernel start"; // input's hiddenSize != 7168, generalize hiddenSize bool generalizedHiddenSize = inTensorDesc.at(INPUT_INDEX).shape.dims[1] != INNER_DIM_7168; // 1: hiddenSize - // if wdqkv's dtype is the same as the input and is either float16/bf16, then do not do rmsNormQuant - aclDataType inputDtype = inTensorDesc.at(INPUT_INDEX).dtype; - doRmsNorm_ = - inTensorDesc.at(WDQKV_INDEX).dtype == inputDtype && (inputDtype == ACL_FLOAT16 || inputDtype == ACL_BF16); + // if wdqkv's dtype is the same as the input and is either float16/bf16, then do not do rmsNormQuant + aclDataType inputDtype = inTensorDesc.at(INPUT_INDEX).dtype; + doRmsNorm_ = + !(inTensorDesc.at(WDQKV_INDEX).dtype == inputDtype && (inputDtype == ACL_FLOAT16 || inputDtype == ACL_BF16)); if (!generalizedHiddenSize && doRmsNorm_) { ATB_LOG(INFO) << GetLogPrefix() << "no need to use aclnn kernel for non-generalized hiddenSize and rmsNormQuant for input is on"; -- Gitee From 0940a0b2d7366e86e828fd2ef5701c7891e8d784 Mon Sep 17 00:00:00 2001 From: ivanshan_8170 Date: Wed, 17 Sep 2025 13:37:24 +0800 Subject: [PATCH 4/8] fix check --- src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp b/src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp index bd12c926..fe92c954 100644 --- a/src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp +++ b/src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp @@ -394,7 +394,7 @@ Status MlaPreprocessOperation::HiddenSizeCheck(const SVector &inTens << "dimNum of gamma0 should be 1, but got: " << inTensorDesc.at(GAMMA0_INDEX).shape.dimNum; return ERROR_INVALID_TENSOR_DIM_NUM; } - if (inTensorDesc.at(GAMMA0_INDEX).shape.dims[1] != hiddenSize) { + if 
(inTensorDesc.at(GAMMA0_INDEX).shape.dims[0] != hiddenSize) { ATB_LOG(ERROR) << GetLogPrefix() << "hiddenSize of gamma0 should be the same as input's, " << hiddenSize << ", but got: " << inTensorDesc.at(GAMMA0_INDEX).shape.dims[0]; // gamma0 hiddenSize index return ERROR_INVALID_TENSOR_DIM; @@ -404,7 +404,7 @@ Status MlaPreprocessOperation::HiddenSizeCheck(const SVector &inTens << "dimNum of beta0 should be 1, but got: " << inTensorDesc.at(BETA0_INDEX).shape.dimNum; return ERROR_INVALID_TENSOR_DIM_NUM; } - if (inTensorDesc.at(BETA0_INDEX).shape.dims[1] != hiddenSize) { + if (inTensorDesc.at(BETA0_INDEX).shape.dims[0] != hiddenSize) { ATB_LOG(ERROR) << GetLogPrefix() << "hiddenSize of beta0 should be the same as input's, " << hiddenSize << ", but got: " << inTensorDesc.at(BETA0_INDEX).shape.dims[0]; // beta0 hiddenSize index return ERROR_INVALID_TENSOR_DIM; -- Gitee From f2f7d63ba2e08888c8e6664070d37acfe2c537b1 Mon Sep 17 00:00:00 2001 From: ivanshan_8170 Date: Wed, 17 Sep 2025 13:49:06 +0800 Subject: [PATCH 5/8] cleancode --- src/atb/runner/aclnn_runner.cpp | 6 +++--- src/atb/utils/dl_manager.h | 4 ++++ .../mla_preprocess_aclnn_runner.cpp | 20 +++++++++---------- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/atb/runner/aclnn_runner.cpp b/src/atb/runner/aclnn_runner.cpp index 36e15b9c..10c9e64e 100644 --- a/src/atb/runner/aclnn_runner.cpp +++ b/src/atb/runner/aclnn_runner.cpp @@ -62,9 +62,9 @@ Status AclnnRunner::SetupImpl(RunnerVariantPack &runnerVariantPack) << "getWorkspace success, workspaceSize: " << this->atbVariantPack_.workspaceBufferSize << ", workspace addr: " << this->atbVariantPack_.workspaceBuffer; aclnnRet = aclSetAclOpExecutorRepeatable(this->aclnnExecutor_.get()); - if (aclnnRet != 0) { + if (aclnnRet != 0) { // 设置算子可复用失败,标记cache中executor不可复用 - ATB_LOG(INFO) << this->GetName() << " call aclSetAclOpExecutorRepeatable fail: " << ret; + ATB_LOG(INFO) << this->GetName() << " call aclSetAclOpExecutorRepeatable fail with error code: " << aclnnRet; this->executorRepeatable_ = false; } else { // 设置算子可复用成功,标记cache中executor可复用 @@ -143,7 +143,7 @@ Status AclnnRunner::PreExecuteImpl(RunnerVariantPack &runnerVariantPack) void AclnnRunner::UpdateWorkspace(const RunnerVariantPack &runnerVariantPack) { this->atbVariantPack_.workspaceBufferSize = runnerVariantPack.workspaceBufferSize; - this->atbVariantPack_.workspaceBuffer = runnerVariantPack.workspaceBuffer; + this->atbVariantPack_.workspaceBuffer = runnerVariantPack.workspaceBuffer; } Status AclnnRunner::ExecuteImpl(RunnerVariantPack &runnerVariantPack) diff --git a/src/atb/utils/dl_manager.h b/src/atb/utils/dl_manager.h index 69568d6f..626d630e 100644 --- a/src/atb/utils/dl_manager.h +++ b/src/atb/utils/dl_manager.h @@ -8,6 +8,9 @@ * See LICENSE in the root of the software repository for the full text of the License. 
*/ +#ifndef ATB_DISKUTIL_H +#define ATB_DISKUTIL_H + #include #include "atb/utils/log.h" @@ -23,3 +26,4 @@ private: void *handle_ = nullptr; }; } // namespace atb +#endif diff --git a/src/ops_infer/mla_preprocess/mla_preprocess_aclnn_runner.cpp b/src/ops_infer/mla_preprocess/mla_preprocess_aclnn_runner.cpp index a3195cf0..4c5a102e 100644 --- a/src/ops_infer/mla_preprocess/mla_preprocess_aclnn_runner.cpp +++ b/src/ops_infer/mla_preprocess/mla_preprocess_aclnn_runner.cpp @@ -34,13 +34,13 @@ aclnnStatus (*MlaPreprocessAclnnRunner::aclnnGetWorkspaceSizeFunc_)( aclnnStatus (*MlaPreprocessAclnnRunner::aclnnExecuteFunc_)(void *, uint64_t, aclOpExecutor *, aclrtStream) = nullptr; MlaPreprocessAclnnRunner::MlaPreprocessAclnnRunner(const infer::MlaPreprocessParam ¶m) - : AclnnRunner("MlaPreprocessAclnnRunner", RUNNER_TYPE_MLA_PREPROCESS_ACLNN), param_(param) + : AclnnRunner("MlaPreprocessAclnnRunner", RUNNER_TYPE_MLA_PREPROCESS_ACLNN), param_(param) { ATB_LOG(INFO) << GetLogPrefix() << "MlaPreprocessAclnnRunner::MlaPreprocessAclnnRunner called"; } MlaPreprocessAclnnRunner::MlaPreprocessAclnnRunner(const infer::MlaPreprocessParam ¶m, bool doRmsNorm) - : AclnnRunner("MlaPreprocessOpsRunner", RUNNER_TYPE_MLA_PREPROCESS_ACLNN), param_(param), doRmsNorm_(doRmsNorm) + : AclnnRunner("MlaPreprocessAclnnRunner", RUNNER_TYPE_MLA_PREPROCESS_ACLNN), param_(param), doRmsNorm_(doRmsNorm) { ATB_LOG(INFO) << GetLogPrefix() << "MlaPreprocessAclnnRunnecr::MlaPreprocessAclnnRunner called, set doRmsNorm: " << doRmsNorm; @@ -107,7 +107,7 @@ aclnnStatus MlaPreprocessAclnnRunner::SetAclNNWorkspaceExecutor() if (status != NO_ERROR) { ATB_LOG(ERROR) << GetLogPrefix() << "load getWorkspace function from aclnn failed! Consider upgrade CANN first!"; - return 561003; // ACLNN_ERR_INNER_FIND_KERNEL_ERROR + return 561003; // ACLNN_ERR_INNER_FIND_KERNEL_ERROR } size_t inTensorStart = 0; bool isRopeCache = param_.cacheMode != infer::MlaPreprocessParam::CacheMode::KVCACHE; @@ -159,16 +159,16 @@ aclnnStatus MlaPreprocessAclnnRunner::SetAclNNWorkspaceExecutor() ATB_LOG(INFO) << GetLogPrefix() << "workspaceSize addr: " << &(this->atbVariantPack_.workspaceBufferSize); - int64_t wdkvSplitCount = 1; + int64_t wdkvSplitCount = 1; aclnnStatus ret = MlaPreprocessAclnnRunner::aclnnGetWorkspaceSizeFunc_( input, gamma0, beta0, quantScale0, quantOffset0, wdqkv, deScale0, bias0, gamma1, beta1, quantScale1, quantOffset1, wuq, deScale1, bias1, gamma2, cos, sin, wuk, kvCache, kRope, slotmapping, ctkvScale, qNopeScale, param_.wdqDim, param_.qRopeDim, param_.kRopeDim, param_.epsilon, param_.qRotaryCoeff, param_.kRotaryCoeff, - param_.transposeWdq, param_.transposeWuq, param_.transposeWuk, param_.cacheMode, param_.quantMode, doRmsNorm_, - wdkvSplitCount, qOut0, kvCacheOut0, qOut1, kvCacheOut1, &(this->atbVariantPack_.workspaceBufferSize), - &raw_executor_ptr); - this->aclnnExecutor_ = std::shared_ptr(raw_executor_ptr, [this](aclOpExecutor *ptr) { - if (ptr && this->executorRepeatable_) { // 可复用时才手动销毁aclOpExecutor + param_.transposeWdq, param_.transposeWuq, param_.transposeWuk, param_.cacheMode, param_.quantMode, doRmsNorm_, + wdkvSplitCount, qOut0, kvCacheOut0, qOut1, kvCacheOut1, &(this->atbVariantPack_.workspaceBufferSize), + &raw_executor_ptr); + this->aclnnExecutor_ = std::shared_ptr(raw_executor_ptr, [this](aclOpExecutor *ptr) { + if (ptr && this->executorRepeatable_) { // 可复用时才手动销毁aclOpExecutor aclDestroyAclOpExecutor(ptr); } }); @@ -204,7 +204,7 @@ Status MlaPreprocessAclnnRunner::LoadMethod() MlaPreprocessAclnnRunner::aclnnExecuteFunc_ != 
nullptr) { return NO_ERROR; } - static DlManager dlManager = DlManager(std::string(std::getenv("ASCEND_HOME_PATH")) + "/lib64/libopapi.so"); + static DlManager dlManager = DlManager(std::string(std::getenv("ASCEND_HOME_PATH")) + "/lib64/libopapi.so"); Status ret = dlManager.getSymbol("aclnnMlaPreprocessGetWorkspaceSize", (void **)&MlaPreprocessAclnnRunner::aclnnGetWorkspaceSizeFunc_); if (ret != NO_ERROR) { -- Gitee From cac8170769a3fe0d91553f4452a99b7a046c34e5 Mon Sep 17 00:00:00 2001 From: ivanshan_8170 Date: Wed, 17 Sep 2025 14:07:51 +0800 Subject: [PATCH 6/8] remove opapi in cmake --- src/CMakeLists.txt | 2 +- tests/cinterface/CMakeLists.txt | 2 +- tests/unittest/CMakeLists.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index db3efe90..208da858 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -38,7 +38,7 @@ if(USE_ASAN) target_link_libraries(atb PRIVATE asan) target_link_libraries(atb_train PRIVATE asan) endif() -target_link_libraries(atb PUBLIC dl mki asdops atb_mixops ascendcl profapi lcal hccl pthread acl_op_compiler nnopbase opapi) +target_link_libraries(atb PUBLIC dl mki asdops atb_mixops ascendcl profapi lcal hccl pthread acl_op_compiler nnopbase) target_link_libraries(atb_train PUBLIC atb) target_include_directories(atb PUBLIC ${MSTX_PATH} ${ATB_INCLUDE_DIR}) target_include_directories(atb_static PUBLIC ${MSTX_PATH} ${ATB_INCLUDE_DIR}) diff --git a/tests/cinterface/CMakeLists.txt b/tests/cinterface/CMakeLists.txt index 4860bdd1..be235428 100644 --- a/tests/cinterface/CMakeLists.txt +++ b/tests/cinterface/CMakeLists.txt @@ -18,5 +18,5 @@ target_include_directories(atb_cinterface PRIVATE ${ASCEND_HOME_PATH}/include/) target_link_directories(atb_cinterface PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/googletest/lib ${ASCEND_HOME_PATH}/lib64/ ${ATB_HOME_PATH}/lib) target_compile_options(atb_cinterface PRIVATE "-Wno-write-strings") target_compile_options(atb_cinterface PRIVATE -Wno-sign-compare -Wno-narrowing) -target_link_libraries(atb_cinterface PRIVATE atb nnopbase tbe_adapter -lgtest -lgtest_main -lc_sec -lopapi) +target_link_libraries(atb_cinterface PRIVATE atb nnopbase tbe_adapter -lgtest -lgtest_main -lc_sec) install(TARGETS atb_cinterface DESTINATION bin) \ No newline at end of file diff --git a/tests/unittest/CMakeLists.txt b/tests/unittest/CMakeLists.txt index d9e9fb88..5c9c5a4c 100644 --- a/tests/unittest/CMakeLists.txt +++ b/tests/unittest/CMakeLists.txt @@ -21,5 +21,5 @@ target_include_directories(atb_unittest PRIVATE ${ASCEND_HOME_PATH}/include/acln target_link_directories(atb_unittest PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/googletest/lib ${ASCEND_HOME_PATH}/lib64/) target_compile_options(atb_unittest PRIVATE "-Wno-write-strings") target_compile_options(atb_unittest PRIVATE -Wno-sign-compare -Wno-narrowing) -target_link_libraries(atb_unittest PRIVATE atb_test_utils tbe_adapter -lgtest -lgtest_main -lc_sec -lopapi) +target_link_libraries(atb_unittest PRIVATE atb_test_utils tbe_adapter -lgtest -lgtest_main -lc_sec) install(TARGETS atb_unittest DESTINATION bin) -- Gitee From 5f3548607aaa29b7c605db726c4dd9e0d448fcfd Mon Sep 17 00:00:00 2001 From: ivanshan_8170 Date: Wed, 17 Sep 2025 15:37:22 +0800 Subject: [PATCH 7/8] remove redundant include --- CMakeLists.txt | 1 - tests/unittest/CMakeLists.txt | 1 - 2 files changed, 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 96c8c1a8..60cc53e9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -89,7 +89,6 @@ include_directories( 
${PROJECT_SOURCE_DIR}/3rdparty/mki/include ${PROJECT_SOURCE_DIR}/3rdparty/nlohmannJson/include $ENV{ASCEND_HOME_PATH}/include/aclnn - $ENV{ASCEND_HOME_PATH}/include $ENV{PYTHON_INCLUDE_PATH} $ENV{PYTORCH_INSTALL_PATH}/include $ENV{PYTORCH_INSTALL_PATH}/include/torch/csrc/api/include diff --git a/tests/unittest/CMakeLists.txt b/tests/unittest/CMakeLists.txt index 5c9c5a4c..a265bc53 100644 --- a/tests/unittest/CMakeLists.txt +++ b/tests/unittest/CMakeLists.txt @@ -17,7 +17,6 @@ add_executable(atb_unittest ${CORE_SOURCE} ${NORMAL_SOURCE} ${OPS_SOURCE}) target_include_directories(atb_unittest PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty) target_include_directories(atb_unittest PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/googletest/include) target_include_directories(atb_unittest PRIVATE ${ASCEND_HOME_PATH}/include/) -target_include_directories(atb_unittest PRIVATE ${ASCEND_HOME_PATH}/include/aclnnop) target_link_directories(atb_unittest PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/googletest/lib ${ASCEND_HOME_PATH}/lib64/) target_compile_options(atb_unittest PRIVATE "-Wno-write-strings") target_compile_options(atb_unittest PRIVATE -Wno-sign-compare -Wno-narrowing) -- Gitee From b39d3ea8cd3960f9e4b1af284e75fa277de729c7 Mon Sep 17 00:00:00 2001 From: ivanshan_8170 Date: Wed, 17 Sep 2025 16:17:24 +0800 Subject: [PATCH 8/8] add print --- src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp b/src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp index fe92c954..84020070 100644 --- a/src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp +++ b/src/ops_infer/mla_preprocess/mla_preprocess_operation.cpp @@ -444,6 +444,8 @@ Status MlaPreprocessOperation::CheckAclnnKernel(const SVector &inTen return ret; } } + ATB_LOG(INFO) << GetLogPrefix() << "aclnn kernel is required and usable, generalizedHiddenSize: " << generalizedHiddenSize + << ", doRmsNorm: " << doRmsNorm_; return NO_ERROR; } @@ -469,4 +471,4 @@ nlohmann::json MlaPreprocessOperation::GetParamJson() const { return OpParamToJson(param_); } -} // namespace atb \ No newline at end of file +} // namespace atb -- Gitee
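
For reference, below is a minimal, self-contained sketch (not the actual DlManager implementation, which is not shown here) of the runtime symbol-loading pattern that MlaPreprocessAclnnRunner::LoadMethod() appears to rely on: the aclnn entry point aclnnMlaPreprocessGetWorkspaceSize is resolved from $ASCEND_HOME_PATH/lib64/libopapi.so with dlopen/dlsym at run time. Only the library path and the symbol name are taken from the patches; everything else is a placeholder.

#include <dlfcn.h>
#include <cstdlib>
#include <iostream>
#include <string>

int main()
{
    // ASCEND_HOME_PATH and lib64/libopapi.so come from LoadMethod(); the rest
    // of this program is a stand-in, not the real DlManager code.
    const char *ascendHome = std::getenv("ASCEND_HOME_PATH");
    if (ascendHome == nullptr) {
        std::cerr << "ASCEND_HOME_PATH is not set" << std::endl;
        return 1;
    }
    const std::string libPath = std::string(ascendHome) + "/lib64/libopapi.so";

    // Keep the handle for the lifetime of the process, mirroring the
    // `static DlManager dlManager` in LoadMethod().
    void *handle = dlopen(libPath.c_str(), RTLD_LAZY | RTLD_LOCAL);
    if (handle == nullptr) {
        std::cerr << "dlopen failed: " << dlerror() << std::endl;
        return 1;
    }

    // Resolve the kernel entry point by name; the caller would cast the
    // result to the expected function-pointer type before use.
    void *sym = dlsym(handle, "aclnnMlaPreprocessGetWorkspaceSize");
    if (sym == nullptr) {
        // An older CANN package simply lacks the symbol, so the runner can
        // report the failure and fall back instead of failing at link time.
        std::cerr << "dlsym failed: " << dlerror() << std::endl;
    } else {
        std::cout << "resolved aclnnMlaPreprocessGetWorkspaceSize" << std::endl;
    }

    dlclose(handle);
    return 0;
}

Resolving the symbols at run time is presumably what allows patch 6/8 to drop opapi from target_link_libraries again: the build no longer needs the library at link time, only at run time if the aclnn kernel is actually requested.

A second sketch, likewise hypothetical, illustrates the ownership pattern used for aclnnExecutor_ in SetAclNNWorkspaceExecutor(): the raw executor is wrapped in a std::shared_ptr whose deleter destroys it only when aclSetAclOpExecutorRepeatable succeeded (executorRepeatable_ == true); otherwise the executor is treated as single-use and left for the framework to release. Executor, MarkRepeatable and DestroyExecutor below are placeholders for aclOpExecutor, aclSetAclOpExecutorRepeatable and aclDestroyAclOpExecutor.

#include <iostream>
#include <memory>

struct Executor {            // stand-in for aclOpExecutor
    int id = 0;
};

// Stand-ins for aclSetAclOpExecutorRepeatable / aclDestroyAclOpExecutor.
bool MarkRepeatable(Executor *e) { return e != nullptr; }
void DestroyExecutor(Executor *e) { delete e; }

class Runner {
public:
    void Setup()
    {
        Executor *raw = new Executor{1};
        // If marking fails, the executor is assumed to be released by the
        // framework after its single use, so the deleter must skip it.
        executorRepeatable_ = MarkRepeatable(raw);
        executor_ = std::shared_ptr<Executor>(raw, [this](Executor *ptr) {
            if (ptr != nullptr && executorRepeatable_) {
                DestroyExecutor(ptr);   // destroy manually only when reusable
            }
        });
    }

private:
    std::shared_ptr<Executor> executor_;
    bool executorRepeatable_ = false;
};

int main()
{
    Runner runner;
    runner.Setup();
    return 0;   // executor destroyed here when it was marked repeatable
}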