From 239d98df3f75b53ae9c10a107174aa201b452976 Mon Sep 17 00:00:00 2001 From: x30073543 Date: Wed, 11 Jun 2025 21:24:41 +0800 Subject: [PATCH] auto_fusion_v1 --- docs/cpp_api.rst | 6 + include/atb/auto_fusion.h | 74 ++ include/atb/context.h | 7 + include/atb/infer_op_params.h | 32 + scripts/filelist.csv | 2 + src/atb/core/auto_fusion_tool.cpp | 751 ++++++++++++++++++ src/atb/core/context_base.cpp | 9 +- src/atb/core/node_impl/mki_node_implement.cpp | 18 + src/atb/operation/operation_base.cpp | 2 +- src/atb/runner/graph_runner.cpp | 48 ++ src/atb/runner/ops_runner.cpp | 9 + src/atb/utils/param_to_json.cpp | 7 + src/include/atb/core/auto_fusion_tool.h | 201 +++++ src/include/atb/core/context_base.h | 3 + src/include/atb/core/runner_type.h | 1 + src/ops_infer/fusion/fusion_operation.cpp | 117 +++ src/ops_infer/fusion/fusion_operation.h | 34 + src/ops_infer/fusion/fusion_ops_runner.cpp | 121 +++ src/ops_infer/fusion/fusion_ops_runner.h | 30 + src/torch_atb/bindings.cpp | 4 +- src/torch_atb/graph_operation_builder.cpp | 25 +- src/torch_atb/graph_operation_builder.h | 3 +- src/torch_atb/operation_wrapper.cpp | 20 +- src/torch_atb/operation_wrapper.h | 5 +- .../graph_auto_fusion_linear_bias_test.py | 73 ++ ...auto_fusion_matmul_add_gelu_signal_test.py | 113 +++ .../graph_auto_fusion_matmul_add_gelu_test.py | 103 +++ ...raph_auto_fusion_matmul_add_muliti_test.py | 141 ++++ .../graph_auto_fusion_matmul_add_test.py | 88 ++ .../graph_auto_fusion_matmul_gelu_test.py | 91 +++ tests/unittest/normal/auto_fusion_graph.cpp | 211 +++++ 31 files changed, 2341 insertions(+), 8 deletions(-) create mode 100644 include/atb/auto_fusion.h create mode 100644 src/atb/core/auto_fusion_tool.cpp create mode 100644 src/include/atb/core/auto_fusion_tool.h create mode 100644 src/ops_infer/fusion/fusion_operation.cpp create mode 100644 src/ops_infer/fusion/fusion_operation.h create mode 100644 src/ops_infer/fusion/fusion_ops_runner.cpp create mode 100644 src/ops_infer/fusion/fusion_ops_runner.h create mode 100644 tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_linear_bias_test.py create mode 100644 tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_matmul_add_gelu_signal_test.py create mode 100644 tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_matmul_add_gelu_test.py create mode 100644 tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_matmul_add_muliti_test.py create mode 100644 tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_matmul_add_test.py create mode 100644 tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_matmul_gelu_test.py create mode 100644 tests/unittest/normal/auto_fusion_graph.cpp diff --git a/docs/cpp_api.rst b/docs/cpp_api.rst index bd52c68c..6cf90b0e 100644 --- a/docs/cpp_api.rst +++ b/docs/cpp_api.rst @@ -71,4 +71,10 @@ atb_acl.h ----------------------- .. doxygenfile:: atb_acl.h + :project: ATB_CPP_API + +auto_fusion.h +----------------------- + +.. doxygenfile:: auto_fusion.h :project: ATB_CPP_API \ No newline at end of file diff --git a/include/atb/auto_fusion.h b/include/atb/auto_fusion.h new file mode 100644 index 00000000..2b5ad36b --- /dev/null +++ b/include/atb/auto_fusion.h @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef AUTO_FUSION_H
+#define AUTO_FUSION_H
+#include <cstdint>
+#include <set>
+#include <string>
+#include <vector>
+#include "atb/context.h"
+#include "atb/graph_op_builder.h"
+#include "atb/infer_op_params.h"
+#include "atb/train_op_params.h"
+#include "atb/operation.h"
+#include "atb/svector.h"
+#include "atb/types.h"
+#include "atb/utils.h"
+
+//!
+//! \file auto_fusion.h
+//!
+//! \brief Defines the AutoFusion class for automatic operator fusion
+//!
+namespace atb {
+//!
+//! \class AutoFusion
+//!
+//! \brief Interface of the automatic operator fusion tool
+//!
+//! This interface class defines the automatic operator fusion tool and declares its entry points.
+//!
+
+class AutoFusion {
+public:
+    //! \brief Default constructor
+    //!
+    AutoFusion() = default;
+    //! \brief Constructor
+    //!
+    //! \param graph The graph structure that the AutoFusion instance is bound to on creation
+    //!
+    AutoFusion(atb::GraphParam &graph)
+    {
+        (void)(graph);
+    }
+
+    //! \brief Destructor.
+    virtual ~AutoFusion() = default;
+
+    //! \brief Entry point that performs auto fusion
+    //!
+    //! \param fusionClassArray Fusion classes requested by the user; the default empty set disables auto fusion
+    //!
+    virtual void DoAutoFusion(const std::set<std::string> &fusionClassArray = {}) = 0;
+};
+
+//! \brief Creation interface of the auto fusion tool
+//!
+//! \param graph The graph structure that the AutoFusion instance is bound to on creation
+//!
+//! \param autoFusionTool Handle that receives the auto fusion tool pointer
+//!
+//! \return Status indicating whether the auto fusion tool was created successfully
+//!
+Status CreateAutoFusionTool(atb::GraphParam &graph, atb::AutoFusion **autoFusionTool);
+} // namespace atb
+#endif
\ No newline at end of file
diff --git a/include/atb/context.h b/include/atb/context.h
index 9f9d8150..0fc18931 100644
--- a/include/atb/context.h
+++ b/include/atb/context.h
@@ -139,6 +139,13 @@ public:
     //!
     //! \return 当前的算子下发模式
     virtual LaunchMode GetLaunchMode() = 0;
+
+    //!
+    //! \brief Sets whether the current graph is in the auto-fusion state
+    //!
+    //! \param flag Whether the graph's operators are auto-fused
+    virtual void SetAutoFusionFlag(bool flag = false) = 0;
+
 };
 
 //!
diff --git a/include/atb/infer_op_params.h b/include/atb/infer_op_params.h
index e1425e98..69e7f52f 100644
--- a/include/atb/infer_op_params.h
+++ b/include/atb/infer_op_params.h
@@ -420,6 +420,38 @@ struct ElewiseParam {
     uint8_t rsv[8] = {0};
 };
 
+//!
+//! \struct FusionParam
+//!
+//! \brief Parameters for automatic operator fusion
+//!
+//! The fusion classes currently supported are MATMUL_ADD and MATMUL_GELU
+//!
+//!
+struct FusionParam {
+    //!
+    //! \enum FusionType
+    //!
+    //! \brief The fusion type
+    //!
+    enum FusionType : int {
+        NON_FUSION = 0,
+        MATMUL_ADD = 1,
+        MATMUL_GELU = 2,
+        MATMUL_SIGMOID = 3,
+        MATMUL_SWIGLU = 4,
+    };
+
+    //! The fusion type
+    FusionType fusionType = NON_FUSION;
+    //! Output data type of the requested data-type conversion
+    aclDataType outTensorType = ACL_DT_UNDEFINED;
+    //!
+    //! \brief Reserved field
+    //!
+    uint8_t rsv[8] = {0};
+};
+
 //!
 //! \struct KvCacheParam
 //!
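For context, a minimal sketch of how the new public API is meant to be driven, based only on the declarations above; the wrapper function and the point at which the context flag is set are illustrative, not part of this patch:

```cpp
#include "atb/auto_fusion.h"

// Sketch only: graphParam is assumed to be an already-populated atb::GraphParam whose
// nodes contain a Linear -> Elewise(add) or Linear -> Activation(gelu) chain.
atb::Status RunAutoFusion(atb::GraphParam &graphParam, atb::Context *context)
{
    atb::AutoFusion *tool = nullptr;
    atb::Status st = atb::CreateAutoFusionTool(graphParam, &tool);
    if (st != atb::NO_ERROR) {
        return st;
    }
    tool->DoAutoFusion({"matmul_add", "matmul_gelu"}); // an empty set is a no-op
    delete tool;
    // Fused kernels need their tiling copied to device even in launch-with-tiling mode,
    // which is what this flag arranges (see the operation_base.cpp change below).
    context->SetAutoFusionFlag(true);
    return atb::NO_ERROR;
}
```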
diff --git a/scripts/filelist.csv b/scripts/filelist.csv
index d03d9c7d..4a0d40cc 100644
--- a/scripts/filelist.csv
+++ b/scripts/filelist.csv
@@ -21,6 +21,7 @@ atb/cxx_abi_0/include/atb/types.h
 atb/cxx_abi_0/include/atb/utils.h
 atb/cxx_abi_0/include/atb/comm.h
 atb/cxx_abi_0/include/atb/operation_infra.h
+atb/cxx_abi_0/include/atb/auto_fusion.h
 atb/cxx_abi_1/include/atb/atb_infer.h
 atb/cxx_abi_1/include/atb/context.h
 atb/cxx_abi_1/include/atb/graph_op_builder.h
@@ -34,6 +35,7 @@ atb/cxx_abi_1/include/atb/types.h
 atb/cxx_abi_1/include/atb/utils.h
 atb/cxx_abi_1/include/atb/comm.h
 atb/cxx_abi_1/include/atb/operation_infra.h
+atb/cxx_abi_1/include/atb/auto_fusion.h
 atb/cxx_abi_0/lib/libasdops.so
 atb/cxx_abi_0/lib/libasdops_static.a
 atb/cxx_abi_0/lib/libatb.so
diff --git a/src/atb/core/auto_fusion_tool.cpp b/src/atb/core/auto_fusion_tool.cpp
new file mode 100644
index 00000000..454dff2c
--- /dev/null
+++ b/src/atb/core/auto_fusion_tool.cpp
@@ -0,0 +1,751 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#include "atb/core/auto_fusion_tool.h"
+#include <algorithm>
+#include <dlfcn.h>
+#include <fstream>
+#include <unordered_map>
+#include "atb/operation/operation_base.h"
+#include "atb/utils/log.h"
+#include "mki/utils/file_system/file_system.h"
+namespace atb {
+Status CreateAutoFusionTool(atb::GraphParam &graph, atb::AutoFusion **autoFusionTool)
+{
+    *autoFusionTool = new (std::nothrow) AutoFusionTool(graph);
+    if (*autoFusionTool == nullptr) {
+        ATB_LOG(ERROR) << "failed to create AutoFusionTool";
+        return ERROR_OUT_OF_HOST_MEMORY;
+    }
+    return NO_ERROR;
+}
+
+//!
+//! \brief Builds the directory paths where auto fusion caches the BiShengIR binaries
+//!
+#define AUTOFUSIONSTRPATH(path) std::string(path)
+
+//!
+//! \brief Install path of the BiShengIR toolchain, taken from the environment
+//!
+#define AUTOFUSIONGETBISHENGPATH() std::getenv("BISHENG_INSTALL_PATH")
+
+//!
+//! \brief Signature of the tiling function exported by the BiShengIR host binary
+//!
+typedef void (*TILING_FUNC_GET)(void *);
+
+
+//!
+//! \struct AutoFusionTilingData
+//!
+//! \brief Tiling data structure used by the BiShengIR host binary
+//!
+struct AutoFusionTilingData {
+    int64_t key{0};
+    int64_t mTile{0};
+    int64_t nTile{0};
+    int64_t kTile{0};
+    int64_t processM{0};
+    int64_t processN{0};
+    int64_t processK{0};
+    int64_t splitKSlices{1};
+
+    int64_t swizzleDir{0};
+    int64_t swizzleCnt{1};
+    int64_t shuffleKType{0};
+    int64_t workspaceReuse{0};
+    int64_t workspaceBufferNum{2};
+    int64_t pTiles{0};
+    int64_t ubMaxBitSize{0};
+};
+
+//!
+//! \struct AutoFusionKernelArgs
+//!
+//! \brief Argument block passed to the tiling function of the BiShengIR device binary
+//!
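+// Assumption, for readers: the pointer / duplicated-pointer / offset / sizes / strides
+// grouping below appears to mirror the MLIR memref descriptor ABI (allocated pointer,
+// aligned pointer, offset, sizes[rank], strides[rank]) that bishengir-compile emits for a
+// rank-2 dynamic tensor argument, which is why every tensor is passed as seven fields.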
+struct AutoFusionKernelArgs {
+    void *xDevice = nullptr;
+    void *xDeviceDup = nullptr;
+    int64_t offsetX = 0;
+    int64_t sizeX0 = -1;
+    int64_t sizeX1 = -1;
+    int64_t strideX0 = -1;
+    int64_t strideX1 = -1;
+
+    void *yDevice = nullptr;
+    void *yDeviceDup = nullptr;
+    int64_t offsetY = 0;
+    int64_t sizeY0 = -1;
+    int64_t sizeY1 = -1;
+    int64_t strideY0 = -1;
+    int64_t strideY1 = -1;
+
+    void *vDevice = nullptr;
+    void *vDeviceDup = nullptr;
+    int64_t offsetV = 0;
+    int64_t sizeV0 = -1;
+    int64_t sizeV1 = -1;
+    int64_t strideV0 = -1;
+    int64_t strideV1 = -1;
+
+    void *oDevice = nullptr;
+    void *oDeviceDup = nullptr;
+    int64_t offsetO = -1;
+    int64_t sizeO0 = -1;
+    int64_t sizeO1 = -1;
+    int64_t strideO0 = -1;
+    int64_t strideO1 = -1;
+
+    void *tilingDevice = nullptr;
+    void *tilingDeviceDup = nullptr;
+    int64_t offsetTiling = 0;
+    int64_t sizeTiling = sizeof(AutoFusionTilingData);
+    int64_t strideTiling = 1;
+};
+
+AutoFusionTool::AutoFusionTool(atb::GraphParam &graph) : graph_(graph)
+{
+    const char *home = std::getenv("HOME");
+    homePath_ = home != nullptr ? std::string(home) : std::string("");
+    homePath_ += AUTOFUSIONSTRPATH("/.atb_auto_fusion/bishengir_bin/");
+}
+
+void AutoFusionTool::SetFusionClass(const std::set<std::string> &fusionClassArray)
+{
+    fusionClassMap_ = fusionClassArray;
+}
+
+void AutoFusionTool::genAllTensorIDs(std::set<uint32_t> &allTensorIds, std::set<uint32_t> &allOutTensorIds)
+{
+    size_t nodeSize = graph_.nodes.size();
+    for (size_t i = 0; i < nodeSize; i++) {
+        const auto &node = graph_.nodes.at(i);
+        for (auto id : node.inTensorIds) {
+            allTensorIds.insert(id);
+        }
+        for (auto id : node.outTensorIds) {
+            allTensorIds.insert(id);
+            allOutTensorIds.insert(id);
+        }
+    }
+}
+
+void AutoFusionTool::DoAutoFusion(const std::set<std::string> &fusionClassArray)
+{
+    if (fusionClassArray.empty()) {
+        return;
+    }
+    std::set<uint32_t> allTensorIds;
+    std::set<uint32_t> outTensorIds;
+    genAllTensorIDs(allTensorIds, outTensorIds);
+    SetFusionClass(fusionClassArray);
+    GetFusionBinAndUpdateFusedGraph();
+    // Drop the placeholder nodes that the fusion pass emptied out.
+    std::vector<Node> nodesNew;
+    size_t nodeSize = graph_.nodes.size();
+    for (size_t i = 0; i < nodeSize; i++) {
+        const auto &node = graph_.nodes.at(i);
+        if (!(node.inTensorIds.empty() && node.outTensorIds.empty())) {
+            nodesNew.push_back(node);
+        }
+    }
+    graph_.nodes = nodesNew;
+    std::set<uint32_t> allTensorIdsNew;
+    std::set<uint32_t> allOutTensorIdsNew;
+    genAllTensorIDs(allTensorIdsNew, allOutTensorIdsNew);
+    updateAllTensorIDs(outTensorIds, allOutTensorIdsNew);
+}
+
+void AutoFusionTool::updateAllTensorIDs(const std::set<uint32_t> &outTensorIds, const std::set<uint32_t> &allOutTensorIdsNew)
+{
+    // Walk both sorted id sets in lockstep until they diverge; from that point on, every
+    // remaining new id is remapped onto the corresponding old id so the id space stays compact.
+    auto it = outTensorIds.begin();
+    auto itNew = allOutTensorIdsNew.begin();
+    while (it != outTensorIds.end() && itNew != allOutTensorIdsNew.end() && *it == *itNew) {
+        itNew++;
+        it++;
+    }
+    std::unordered_map<uint32_t, uint32_t> changeMap;
+    int erasedInternalNum = static_cast<int>(outTensorIds.size()) - static_cast<int>(allOutTensorIdsNew.size());
+    graph_.internalTensorNum -= erasedInternalNum;
+    while (itNew != allOutTensorIdsNew.end() && it != outTensorIds.end()) {
+        changeMap[*itNew] = *it;
+        itNew++;
+        it++;
+    }
+    size_t nodeSize = graph_.nodes.size();
+    for (size_t i = 0; i < nodeSize; i++) {
+        auto &node = graph_.nodes.at(i);
+        for (auto &id : node.inTensorIds) {
+            auto found = changeMap.find(id);
+            if (found != changeMap.end()) {
+                id = found->second;
+            }
+        }
+        for (auto &id : node.outTensorIds) {
+            auto found = changeMap.find(id);
+            if (found != changeMap.end()) {
+                id = found->second;
+            }
+        }
+    }
+}
+
+void AutoFusionTool::findNodesWhichInputsIs(const uint32_t id, std::vector<uint32_t> &re)
+{
+    size_t counter = graph_.nodes.size();
+    for (size_t i = 0; i < counter; i++) {
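+        // Record each consumer node at most once, even if it reads the tensor
+        // through several of its inputs (hence the break below).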
+        const auto &inTensorIds = graph_.nodes[i].inTensorIds;
+        for (auto inId : inTensorIds) {
+            if (inId == id) {
+                re.push_back(i);
+                break;
+            }
+        }
+    }
+}
+
+void AutoFusionTool::parseCollectNodes(std::vector<std::pair<std::string, std::vector<uint32_t>>> &fusionclassAndIndex)
+{
+    for (auto it = linearNodes_.begin(); it != linearNodes_.end();) {
+        uint32_t i = *it;
+        const auto &outTensorIds = graph_.nodes[i].outTensorIds;
+        std::vector<uint32_t> re;
+        findNodesWhichInputsIs(outTensorIds[0], re);
+        if (re.size() != 1) {
+            // The matmul output must have exactly one consumer to be fusible.
+            it = linearNodes_.erase(it);
+            continue;
+        }
+        uint32_t key = re[0];
+        if (eleAddNodes_.find(key) != eleAddNodes_.end()) {
+            fusionclassAndIndex.emplace_back("matmul_add", std::vector<uint32_t>{i, key});
+        } else if (actGeluNodes_.find(key) != actGeluNodes_.end()) {
+            fusionclassAndIndex.emplace_back("matmul_gelu", std::vector<uint32_t>{i, key});
+        } else if (actSigmoidNodes_.find(key) != actSigmoidNodes_.end()) {
+            fusionclassAndIndex.emplace_back("matmul_sigmoid", std::vector<uint32_t>{i, key});
+        } else if (actSwiGluNodes_.find(key) != actSwiGluNodes_.end()) {
+            fusionclassAndIndex.emplace_back("matmul_swiglu", std::vector<uint32_t>{i, key});
+        }
+        ++it;
+    }
+    for (const auto &i : fusionclassAndIndex) {
+        ATB_LOG(INFO) << "fusion type = " << i.first;
+        ATB_LOG(INFO) << "{";
+        for (auto j : i.second) {
+            ATB_LOG(INFO) << j << ",";
+        }
+        ATB_LOG(INFO) << "}";
+    }
+}
+
+void AutoFusionTool::CollectNodes()
+{
+    const auto &nodes = graph_.nodes;
+    bool ismatmulAdd = (fusionClassMap_.find("matmul_add") != fusionClassMap_.end());
+    bool ismatmulGelu = (fusionClassMap_.find("matmul_gelu") != fusionClassMap_.end());
+    // matmul_sigmoid / matmul_swiglu are declared in FusionParam, but their activation
+    // nodes are not collected yet; only the two classes below trigger collection.
+    if (!ismatmulAdd && !ismatmulGelu) {
+        return;
+    }
+    size_t counter = nodes.size();
+    for (size_t i = 0; i < counter; i++) {
+        OperationBase *opBase = dynamic_cast<OperationBase *>(nodes[i].operation);
+        if (opBase == nullptr) {
+            continue;
+        }
+        nlohmann::json paramJson = opBase->GetParamJson();
+        size_t inTensorSize = nodes[i].inTensorIds.size();
+        size_t outTensorSize = nodes[i].outTensorIds.size();
+        if (opBase->GetName() == "LinearOperation" && inTensorSize == 2 && outTensorSize == 1 &&
+            paramJson["matmulType"] == 0) {
+            linearNodes_.insert(i);
+        }
+        if (ismatmulAdd && opBase->GetName() == "ElewiseOperation" && inTensorSize == 2 && outTensorSize == 1 &&
+            paramJson["elewiseType"] == atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD) {
+            eleAddNodes_.insert(i);
+        }
+        if (ismatmulGelu && opBase->GetName() == "ActivationOperation" && inTensorSize == 1 && outTensorSize == 1 &&
+            paramJson["activationType"] == atb::infer::ActivationType::ACTIVATION_GELU) {
+            actGeluNodes_.insert(i);
+        }
+    }
+}
+
+void AutoFusionTool::ParseFusion(std::vector<std::pair<std::string, std::vector<uint32_t>>> &fusionclassAndIndex)
+{
+    CollectNodes();
+    parseCollectNodes(fusionclassAndIndex);
+}
+
+void AutoFusionTool::GetFusionBinAndUpdateFusedGraph()
+{
+    std::vector<std::pair<std::string, std::vector<uint32_t>>> fusionclassAndIndex;
+    ParseFusion(fusionclassAndIndex);
+    std::vector<std::pair<std::string, std::string>> fusionclassAndBin;
+    for (const auto &i : fusionclassAndIndex) {
+        auto path = callBiShengIR(i) ? homePath_ : "";
+        fusionclassAndBin.push_back(std::make_pair(i.first, path));
+    }
+
+    UpdateFusedGraph(fusionclassAndIndex, fusionclassAndBin);
+}
+
+bool AutoFusionTool::callBiShengIR(const std::pair<std::string, std::vector<uint32_t>> &subFusion)
+{
+    // The MLIR description of the pattern is the starting point for code generation.
+    std::string fusionClass = subFusion.first;
+    std::string hostBinPath = homePath_ + "lib" + fusionClass + ".so";
+    std::string preFill = homePath_ + fusionClass;
+    std::string deviceBinPath = preFill + ".o";
+    std::string jsonPath = preFill + ".json";
+    std::string cppPath = preFill + ".cpp";
+    if (Mki::FileSystem::Exists(hostBinPath) && Mki::FileSystem::Exists(deviceBinPath) &&
+        Mki::FileSystem::Exists(jsonPath) && Mki::FileSystem::Exists(cppPath)) {
+        return true; // all artifacts are already cached
+    }
+    if (!AUTOFUSIONGETBISHENGPATH()) {
+        return false;
+    }
+    if (!Mki::FileSystem::Exists(homePath_)) {
+        std::string cmd = "mkdir -p " + homePath_;
+        auto ret = system(cmd.c_str());
+        if (ret != 0) {
+            ATB_LOG(ERROR) << "BishengIR folder create failed!";
+            return false;
+        }
+    }
+    if (!genMlirAndBin(fusionClass)) {
+        return false;
+    }
+    return genFusionKernelDef(fusionClass);
+}
+
+void AutoFusionTool::UpdateFusedGraph(std::vector<std::pair<std::string, std::vector<uint32_t>>> &fusionclassAndIndex,
+                                      const std::vector<std::pair<std::string, std::string>> &fusionclassAndBin)
+{
+    size_t counter = fusionclassAndBin.size();
+    for (size_t i = 0; i < counter; i++) {
+        if (fusionclassAndBin[i].second.empty()) {
+            continue;
+        }
+        SortNodeByTensorIdsAndChangeTopology(fusionclassAndIndex[i].second, fusionclassAndIndex[i].first);
+    }
+}
+
+void AutoFusionTool::SortNodeByTensorIdsAndChangeTopology(std::vector<uint32_t> &fusedNodes, const std::string &fusionClass)
+{
+    // Besides ordering producer before consumer, the comparator marks the tensor ids shared
+    // between the two nodes as UINT32_MAX; UpdateGraphStruct later drops those ids when it
+    // merges the nodes' inputs.
+    std::sort(fusedNodes.begin(), fusedNodes.end(), [this](int index1, int index2) {
+        bool flag = false;
+        for (auto &subIndex : graph_.nodes[index2].inTensorIds) {
+            for (auto &subIndex1 : graph_.nodes[index1].outTensorIds) {
+                if (subIndex1 == subIndex && uint32_t(-1) != subIndex1) {
+                    subIndex1 = -1;
+                    subIndex = -1;
+                    flag = true;
+                }
+            }
+        }
+        if (flag) {
+            return flag;
+        }
+        for (auto &subIndex : graph_.nodes[index1].inTensorIds) {
+            for (auto &subIndex1 : graph_.nodes[index2].outTensorIds) {
+                if (subIndex1 == subIndex && uint32_t(-1) != subIndex1) {
+                    subIndex1 = -1;
+                    subIndex = -1;
+                    flag = false;
+                }
+            }
+        }
+        return flag;
+    });
+    UpdateGraphStruct(fusedNodes, fusionClass);
+}
+
+void AutoFusionTool::SetFusionParam(atb::infer::FusionParam &param, const std::string &fusionClass)
+{
+    if (fusionClass == "matmul_add") {
+        param.fusionType = atb::infer::FusionParam::MATMUL_ADD;
+    } else if (fusionClass == "matmul_gelu") {
+        param.fusionType = atb::infer::FusionParam::MATMUL_GELU;
+    } else if (fusionClass == "matmul_sigmoid") {
+        param.fusionType = atb::infer::FusionParam::MATMUL_SIGMOID;
+    } else if (fusionClass == "matmul_swiglu") {
+        param.fusionType = atb::infer::FusionParam::MATMUL_SWIGLU;
+    } else {
+        param.fusionType = atb::infer::FusionParam::NON_FUSION;
+    }
+}
+
+void AutoFusionTool::UpdateGraphStruct(const std::vector<uint32_t> &fusedNodes, const std::string &fusionClass)
+{
+    auto &preInTensorIds = graph_.nodes[fusedNodes[0]].inTensorIds;
+    auto &preOutTensorIds = graph_.nodes[fusedNodes[0]].outTensorIds;
+    size_t counter = fusedNodes.size();
+    for (size_t i = 1; i < counter; i++) {
+        auto &currentInTensorIds = graph_.nodes[fusedNodes[i]].inTensorIds;
+        for (auto subIndex : currentInTensorIds) {
+            if (subIndex != uint32_t(-1)) {
+                preInTensorIds.push_back(subIndex);
+            }
+        }
+        preOutTensorIds = graph_.nodes[fusedNodes[i]].outTensorIds;
+    }
+    atb::infer::FusionParam param;
+    SetFusionParam(param, fusionClass);
+    UpdateReshapeFunc(fusedNodes);
+    atb::Operation *rawPtr = nullptr;
+    CreateOperation(param, &rawPtr);
+    if (rawPtr == nullptr) {
+        throw std::runtime_error("fusion operation is null.");
+    }
+    if (graph_.nodes[fusedNodes[0]].operation) {
+        delete graph_.nodes[fusedNodes[0]].operation;
+        graph_.nodes[fusedNodes[0]].operation = nullptr;
+    }
+    graph_.nodes[fusedNodes[0]].operation = rawPtr;
+    counter = fusedNodes.size();
+    for (size_t i = 1; i < counter; i++) {
+        // The remaining merged nodes become placeholder (NON_FUSION) operations with no
+        // inputs or outputs; DoAutoFusion later drops nodes whose tensor-id lists are empty.
+        ATB_LOG(INFO) << "placeholder fusion operation created";
+        atb::infer::FusionParam placeholderParam;
+        atb::Operation *placeholderPtr = nullptr;
+        CreateOperation(placeholderParam, &placeholderPtr);
+        if (placeholderPtr == nullptr) {
+            throw std::runtime_error("placeholder fusion operation is null.");
+        }
+        if (graph_.nodes[fusedNodes[i]].operation) {
+            delete graph_.nodes[fusedNodes[i]].operation;
+            graph_.nodes[fusedNodes[i]].operation = nullptr;
+        }
+        graph_.nodes[fusedNodes[i]].operation = placeholderPtr;
+        graph_.nodes[fusedNodes[i]].inTensorIds.clear();
+        graph_.nodes[fusedNodes[i]].outTensorIds.clear();
+    }
+    ATB_LOG(INFO) << " UpdateGraphStruct Done!";
+}
+
+void AutoFusionTool::UpdateReshapeFunc(const std::vector<uint32_t> &fusedNodes)
+{
+    auto &inTensorReshapeFuncs = graph_.nodes[fusedNodes[0]].inTensorReshapeFuncs;
+    size_t counter = fusedNodes.size();
+    for (size_t i = 1; i < counter; i++) {
+        size_t funcsCounter = graph_.nodes[fusedNodes[i]].inTensorReshapeFuncs.size();
+        // j starts at 1: input 0 of a merged node is the intermediate tensor produced by the
+        // previous node, which disappears after fusion along with its reshape func.
+        for (size_t j = 1; j < funcsCounter; j++) {
+            inTensorReshapeFuncs.push_back(graph_.nodes[fusedNodes[i]].inTensorReshapeFuncs[j]);
+        }
+    }
+}
+
+std::string AutoFusionTool::getTilingKey(const std::string &fusionClass)
+{
+    std::string path = homePath_ + "lib" + fusionClass + ".so";
+    void *handle = dlopen(path.c_str(), RTLD_LAZY);
+    if (handle == nullptr) {
+        ATB_LOG(ERROR) << "host tiling load error!";
+        return "";
+    }
+    TILING_FUNC_GET tiling_func = nullptr;
+    std::string tiling_func_name = fusionClass + "_tiling_func";
+    *(void **)(&tiling_func) = dlsym(handle, tiling_func_name.c_str());
+    if (tiling_func == nullptr) {
+        ATB_LOG(ERROR) << "tiling func " << tiling_func_name << " not found!";
+        dlclose(handle);
+        return "";
+    }
+    AutoFusionTilingData tilingData;
+    AutoFusionKernelArgs kernelArgs;
+    kernelArgs.tilingDevice = static_cast<void *>(&tilingData);
+    kernelArgs.tilingDeviceDup = kernelArgs.tilingDevice;
+
+    tiling_func(static_cast<void *>(&kernelArgs));
+    std::string key = std::to_string(tilingData.key);
+    dlclose(handle);
+    return key;
+}
+bool AutoFusionTool::genFusionKernelDef(const std::string &fusionClass)
+{
+    nlohmann::json fusionJson{{"binFileName", "matmul_add"},
+                              {"binFileSuffix", ".o"},
+                              {"blockDim", 40},
+                              {"coreType", "MIX"},
+                              {"core_type", "MIX"},
+                              {"intercoreSync", 0},
+                              {"magic", "RT_DEV_BINARY_MAGIC_ELF"},
+                              {"memoryStamping", {}},
+                              {"opParaSize", 0},
+                              {"parameters", {}},
+                              {"sha256", ""},
+                              {"kernelList", {}},
+                              {"compileInfo", nlohmann::json::object()}};
+    fusionJson["binFileName"] = nlohmann::json::array({});
+    fusionJson["parameters"] = nlohmann::json::array({});
+    fusionJson["memoryStamping"] = nlohmann::json::array({});
+    fusionJson["kernelList"] = nlohmann::json::array({});
+    std::string key = getTilingKey(fusionClass);
+    if (key != "") {
+        key = "_" + key;
+    }
+    if ("matmul_add" == fusionClass) {
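+        // The JSON written here mirrors the metadata that the runtime binary loader parses
+        // (see get_header_from_file in the generated Python below): binFileName/binFileSuffix
+        // locate the device .o, and each kernelList entry names one kernel symbol, with the
+        // tiling key appended so the tiling-specialized entry point can be resolved.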
fusionJson["binFileName"] = "matmul_add"; + nlohmann::json JObject = {{"kernelName", "matmul_add" + key}}; + fusionJson["kernelList"].push_back(JObject); + } else if ("matmul_gelu" == fusionClass) { + fusionJson["binFileName"] = "matmul_gelu"; + nlohmann::json JObject = {{"kernelName", "matmul_gelu" + key}}; + fusionJson["kernelList"].push_back(JObject); + } else if ("matmul_sigmoid" == fusionClass) { + fusionJson["binFileName"] = "matmul_sigmoid"; + nlohmann::json JObject = {{"kernelName", "matmul_sigmoid" + key}}; + fusionJson["kernelList"].push_back(JObject); + } else if ("matmul_swiglu" == fusionClass) { + fusionJson["binFileName"] = "matmul_swiglu"; + nlohmann::json JObject = {{"kernelName", "matmul_swiglu" + key}}; + fusionJson["kernelList"].push_back(JObject); + } + std::string jsonPath = homePath_ + fusionClass + ".json"; + std::ofstream file(jsonPath); + file << fusionJson.dump(4); + file.close(); + std::string pyImport = "import os\nimport configparser\nimport json\nimport struct\nimport logging\n"; + std::string aligned = "def aligned_string(s:str, align:int) -> str:\n\ + width = (len(s) // align + 1) * align\n\ + return s.ljust(width, '\\0')\n"; + std::string pyGetHeader = + "def get_header_from_file(file_path):\n result = True\n magic_dict = {\"RT_DEV_BINARY_MAGIC_ELF\": 0x43554245,\"RT_DEV_BINARY_MAGIC_ELF_AIVEC\": 0x41415246,\"RT_DEV_BINARY_MAGIC_ELF_AICUBE\": 0x41494343}\n core_type_dict = {\"AiCore\": 0, \"VectorCore\": 2, \"MIX\": 4}\n aling_bytes = struct.calcsize('I')\n fixed_header_len = 128\n header = b''\n try:\n with open(file_path) as f:\n text = json.load(f) \n\ + version = 0 \n\ + crc = 0 \n\ + compile_info_str = aligned_string(json.dumps(text[\"compileInfo\"]), aling_bytes) \n\ + op_para_size = text[\"opParaSize\"] \n\ + core_type = core_type_dict.get(text[\"coreType\"], 0) \n\ + magic_type = text[\"magic\"] \n\ + if magic_type not in magic_dict: \n\ + logging.error(\"magic %s is invalid\", magic_type) \n\ + result = False \n\ + else: \n\ + magic = magic_dict[magic_type] \n\ + kernel_list = [] \n\ + if \"kernelList\" in text: \n\ + for kernel_item in text[\"kernelList\"]: \n\ + kernel_list.append(aligned_string(kernel_item[\"kernelName\"], aling_bytes)) \n\ + else: \n\ + kernel_list.append(aligned_string(text[\"kernelName\"], aling_bytes)) \n\ + kernel_num = len(kernel_list) \n\ + if kernel_num == 0: \n\ + result = False \n\ + header = struct.pack('I', version) + struct.pack('I', magic) + struct.pack('I', op_para_size) + struct.pack('I', core_type) + struct.pack('I', kernel_num) \n\ + offset = 0 \n\ + kernel_name_offset = offset \n\ + for kernel_name in kernel_list: \n\ + offset += (aling_bytes + len(kernel_name)) \n\ + compile_info_offset = offset \n\ + offset += (aling_bytes + len(compile_info_str)) \n\ + binary_offset = offset \n\ + header = header + struct.pack('I', kernel_name_offset) + struct.pack('I', compile_info_offset) + struct.pack('I', binary_offset) \n\ + intercore_sync = text.get(\"intercoreSync\", 0) \n\ + task_ration_type = text.get(\"taskRation\", \"tilingKey\") \n\ + if task_ration_type == \"tilingKey\": \n\ + task_ration = 0 \n\ + else: \n\ + ration = [int(r) for r in task_ration_type.split(\":\")] \n\ + if len(ration) != 2: \n\ + logging.error(f\"ration is invalid: {task_ration_type}\") \n\ + result = False \n\ + task_ration = (ration[0] << 16) + ration[1] \n\ + header = header + struct.pack('I', intercore_sync) + struct.pack('I', task_ration) \n\ + header = header.ljust(fixed_header_len - aling_bytes, b'\\x00') \n\ + header += 
struct.pack('I', crc) \n\ + for kernel_name in kernel_list: \n\ + header += struct.pack('I', len(kernel_name)) \n\ + header += kernel_name.encode('utf-8') \n\ + header += struct.pack('I', len(compile_info_str)) \n\ + header += compile_info_str.encode('utf-8') \n\ + except FileNotFoundError: \n\ + logging.error(\"file %s is not found!\", file_path) \n\ + result = False \n\ + except json.decoder.JSONDecodeError: \n\ + logging.error(\"file %s is not json file!\", file_path) \n\ + result = False \n\ + except KeyError: \n\ + logging.error(\"keyerror in file %s!\", file_path) \n\ + result = False \n\ + return header, result\n"; + // here + std::string pyWriteCpp = + "def write_to_cpp(binary_path, header, dst_cpp_path, kernel, target_version, is_const=True):\n try:\n " + " with open(binary_path, 'rb') as f:\n data = f.read()\n binary_size = len(data)\n " + " header += struct.pack('I', binary_size)\n data = header + data\n except " + "FileNotFoundError:\n logging.error(\"file %s is not found!\", binary_path)\n return False\n " + "name = f'KERNELBIN_{kernel.upper()}_{target_version.upper()}'\n data_type = 'const uint8_t' if is_const " + "else 'uint8_t'\n with open(dst_cpp_path, 'w') as f:\n for i in range(0, len(data), 1):\n " + "f.write(''.join('{:02x}'.format(b) for b in data[i:i+1]))\n f.write('\\n')\n return True\n"; + + std::string finalPy = "def compile_ascendc_code(obj_path, dst_cpp_path, is_const=True):\n\ + if not obj_path.endswith('.o'): \n\ + logging.error(\"%s is not an obj file.\", obj_path) \n\ + exit(1) \n\ + json_file = obj_path.rsplit('.', 1)[0] + '.json' \n\ + header, result = get_header_from_file(json_file) \n\ + if not result: \n\ + logging.error(\"failed to parse file %s.\", json_file) \n\ + exit(1) \n\ + obj_realpath = os.path.realpath(obj_path) \n\ + kernel = obj_path.split('/')[-2] \n\ + target_version = obj_realpath.split('/')[-4] \n\ + output_dir = os.path.dirname(dst_cpp_path) \n\ + if not os.path.exists(output_dir): \n\ + os.makedirs(output_dir, exist_ok=True) \n\ + result = write_to_cpp(obj_path, header, dst_cpp_path, kernel, target_version, is_const) \n\ + if not result: \n\ + logging.error(\"failed to write into file %s.\", dst_cpp_path) \n\ + exit(1) \n"; + std::string cmd1 = "compile_ascendc_code(\"" + homePath_ + fusionClass + + ".o\"," + "\"" + + homePath_ + fusionClass + ".cpp" + "\")"; + std::string cmd2 = "python3 " + homePath_ + "genCpp.py"; + std::string pypath = homePath_ + "genCpp.py"; + std::ofstream outfile; + outfile.open(pypath.c_str(), std::ios::out | std::ios::trunc); + outfile << pyImport << aligned << pyGetHeader << pyWriteCpp << finalPy << cmd1; + outfile.close(); + auto ret = system(cmd2.c_str()); + if (ret != 0) { + return false; + } + return true; +} + + +bool AutoFusionTool::genMlirAndBin(const std::string &fusionClass) +{ + int res = -1; + if ("matmul_add" == fusionClass) { + std::string matmul_add_mlir = "module { \ + func.func @matmul_add(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor \ + attributes {hacc.entry, hacc.function_kind = #hacc.function_kind, hfusion.fusion_kind = #hfusion.fusion_kind} { \ + %c0 = arith.constant 0 : index \ + %c1 = arith.constant 0 : index \ + %dim = tensor.dim %arg0, %c0 : tensor \ + %dim2 = tensor.dim %arg1, %c1 : tensor \ + %0 = tensor.empty(%dim, %dim2) : tensor \ + %1 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor, tensor) outs(%0 : tensor) -> tensor \ + %2 = linalg.elemwise_binary {fun = #linalg.binary_fn} ins(%1, %arg2 : tensor, tensor) outs(%arg3 : tensor) -> tensor \ + return 
%2 : tensor \ + } \ + }"; + std::ofstream outfile; + auto mlirPath = homePath_ + "matmul_add.mlir"; + outfile.open(mlirPath.c_str(), std::ios::out | std::ios::trunc); + outfile << matmul_add_mlir; + outfile.close(); + std::string matmul_add_bash = + "#!/bin/bash\npath=$HOME'/.atb_auto_fusion/bishengir_bin/'\narch=$(uname -i)\nif [ \"$arch\" != " + "\"aarch64\" ]; then\n arch=\'x86\'\nfi\necho " + "\'Arch is \'$arch\', BishengIR \'$arch\' should installed ...\'\nBISHENG_PATH=$BISHENG_PATH\necho " + "$BISHENG_PATH\nbishengir_opt=$BISHENG_PATH$\'bishengir_\'$arch\'/bin/" + "bishengir-opt\'\nbishengir_compile=$BISHENG_PATH$\'bishengir_\'$arch\'/bin/bishengir-compile\'\nexport " + "LD_LIBRARY_PATH=$BISHENG_PATH:$LD_LIBRARY_PATH\nexport " + "PATH=$BISHENG_PATH/ccec_compiler_$arch/bin:$PATH\nexport " + "BISHENG_INSTALL_PATH=$BISHENG_PATH/ccec_compiler_$arch/bin\necho $bishengir_opt\necho " + "$bishengir_compile\n$bishengir_opt -lower-hfusion-pipeline=\"block-dim=40 enable-workspace-reuse=true\" " + "-convert-hfusion-to-hivm -cse $path/matmul_add.mlir -o $path/matmul_add_lower.mlir\n$bishengir_compile " + "-enable-lir-compile=true -enable-hfusion-compile=false -enable-hivm-compile=true " + "-enable-multi-kernel-compile $path/matmul_add_lower.mlir -o $path/matmul_add"; + mlirPath = homePath_ + "gen_matmul_add_bin.sh"; + outfile.open(mlirPath.c_str(), std::ios::out | std::ios::trunc); + outfile << matmul_add_bash; + outfile.close(); + res = system(("bash " + homePath_ + AUTOFUSIONSTRPATH("gen_matmul_add_bin.sh")).c_str()); + } else if ("matmul_gelu" == fusionClass) { + std::string matmul_gelu_mlir = "module { \ + func.func @matmul_gelu(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor attributes {hacc.entry, hacc.function_kind = #hacc.function_kind, hfusion.fusion_kind = #hfusion.fusion_kind} {\ + %c0 = arith.constant 0 : index \ + %c1 = arith.constant 0 : index \ + %c2 = arith.constant 1 : index \ + %cst = arith.constant 4.470830e-02 : f16 \ + %cst_0 = arith.constant -1.595700e+00 : f16 \ + %cst_1 = arith.constant 1.000000e+00 : f16 \ + %dim = tensor.dim %arg0, %c0 : tensor \ + %dim2 = tensor.dim %arg1, %c1 : tensor \ + %0 = tensor.empty(%dim, %dim2) : tensor \ + %1 = linalg.matmul_transpose_b ins(%arg0, %arg1 : tensor, tensor) outs(%0 : tensor) -> tensor \ + %dim_2 = tensor.dim %1, %c0 : tensor \ + %dim_3 = tensor.dim %1, %c2 : tensor \ + %2 = tensor.empty(%dim_2, %dim_3) : tensor \ + %3 = linalg.elemwise_binary {fun = #linalg.binary_fn} ins(%1, %1 : tensor, tensor) outs(%2 : tensor) -> tensor \ + %4 = linalg.elemwise_binary {fun = #linalg.binary_fn} ins(%3, %1 : tensor, tensor) outs(%2 : tensor) -> tensor \ + %5 = linalg.elemwise_binary {fun = #linalg.binary_fn} ins(%4, %cst : tensor, f16) outs(%2 : tensor) -> tensor \ + %6 = linalg.elemwise_binary {fun = #linalg.binary_fn} ins(%5, %1 : tensor, tensor) outs(%2 : tensor) -> tensor \ + %7 = linalg.elemwise_binary {fun = #linalg.binary_fn} ins(%6, %cst_0 : tensor, f16) outs(%2 : tensor) -> tensor \ + %8 = linalg.elemwise_unary {fun = #linalg.unary_fn} ins(%7 : tensor) outs(%2 : tensor) -> tensor \ + %9 = linalg.elemwise_binary {fun = #linalg.binary_fn} ins(%8, %cst_1 : tensor, f16) outs(%2 : tensor) -> tensor \ + %10 = linalg.elemwise_binary {fun = #linalg.binary_fn
} ins(%1, %9 : tensor, tensor) outs(%2 : tensor) -> tensor \
+            return %10 : tensor \
+          } \
+        }";
+        std::ofstream outfile;
+        auto mlirPath = homePath_ + "matmul_gelu.mlir";
+        outfile.open(mlirPath.c_str(), std::ios::out | std::ios::trunc);
+        outfile << matmul_gelu_mlir;
+        outfile.close();
+        std::string matmul_gelu_bash =
+            "#!/bin/bash\npath=$HOME'/.atb_auto_fusion/bishengir_bin/'\narch=$(uname -i)\nif [ \"$arch\" != "
+            "\"aarch64\" ]; then\n arch=\'x86\'\nfi\necho "
+            "\'Arch is \'$arch\', BishengIR \'$arch\' should installed ...\'\nBISHENG_PATH=$BISHENG_PATH\necho "
+            "$BISHENG_PATH\nbishengir_opt=$BISHENG_PATH$\'bishengir_\'$arch\'/bin/"
+            "bishengir-opt\'\nbishengir_compile=$BISHENG_PATH$\'bishengir_\'$arch\'/bin/bishengir-compile\'\nexport "
+            "LD_LIBRARY_PATH=$BISHENG_PATH:$LD_LIBRARY_PATH\nexport "
+            "PATH=$BISHENG_PATH/ccec_compiler_$arch/bin:$PATH\nexport "
+            "BISHENG_INSTALL_PATH=$BISHENG_PATH/ccec_compiler_$arch/bin\necho $bishengir_opt\necho "
+            "$bishengir_compile\n$bishengir_opt -lower-hfusion-pipeline=\"block-dim=40 enable-workspace-reuse=true\" "
+            "-convert-hfusion-to-hivm -cse $path/matmul_gelu.mlir -o $path/matmul_gelu_lower.mlir\n$bishengir_compile "
+            "-enable-lir-compile=true -enable-hfusion-compile=false -enable-hivm-compile=true "
+            "-enable-multi-kernel-compile $path/matmul_gelu_lower.mlir -o $path/matmul_gelu";
+        mlirPath = homePath_ + "gen_matmul_gelu_bin.sh";
+        outfile.open(mlirPath.c_str(), std::ios::out | std::ios::trunc);
+        outfile << matmul_gelu_bash;
+        outfile.close();
+        res = system(("bash " + homePath_ + AUTOFUSIONSTRPATH("gen_matmul_gelu_bin.sh")).c_str());
+    }
+    if (res != 0) {
+        return false;
+    }
+    return true;
+}
+} // namespace atb
\ No newline at end of file
diff --git a/src/atb/core/context_base.cpp b/src/atb/core/context_base.cpp
index ae0812d8..00583982 100644
--- a/src/atb/core/context_base.cpp
+++ b/src/atb/core/context_base.cpp
@@ -353,5 +353,12 @@ bool ContextBase::GetLaunchWithTilingStatus()
 {
     return mode_ != GRAPH_LAUNCH_MODE;
 }
-
+void ContextBase::SetAutoFusionFlag(bool autoFusionFlag)
+{
+    autoFusionFlag_ = autoFusionFlag;
+}
+bool ContextBase::GetAutoFusionFlag() const
+{
+    return autoFusionFlag_;
+}
 } // namespace atb
diff --git a/src/atb/core/node_impl/mki_node_implement.cpp b/src/atb/core/node_impl/mki_node_implement.cpp
index dcdc5462..6e7a48d0 100644
--- a/src/atb/core/node_impl/mki_node_implement.cpp
+++ b/src/atb/core/node_impl/mki_node_implement.cpp
@@ -10,6 +10,9 @@
 #include "atb/core/node_impl/mki_node_implement.h"
 #include
 #include
+#include
+#include
+#include
 #include "atb/utils/log.h"
 #include "atb/utils/config.h"
 #include "atb/utils/tensor_util.h"
@@ -87,6 +90,21 @@ bool MkiNodeImplement::BuildLaunchParam(const SVector &inTensors,
 bool MkiNodeImplement::OperationGetBestKernel()
 {
     Mki::Kernel *kernel = operation_->GetBestKernel(launchParam_);
+    if (kernel == nullptr) {
+        if (operation_->GetName().find("Fusion") != std::string::npos ||
+            operation_->GetName().find("fusion") != std::string::npos) {
+            auto &fusionType = launchParam_.GetParam<AtbOps::OpParam::Fusion>();
+            if (AtbOps::OpParam::Fusion::NON_FUSION == fusionType.fusionType) {
+                // A non-primary node left behind by auto fusion (theMain == false):
+                // no dynamic registration is done for it, it is simply skipped.
+                return true;
+            }
+            // Otherwise register the fused operator's binary through the dynamic
+            // registration interface, refresh the schedule, and retry the lookup.
+            operation_->DynamicRegisterKernelByName(launchParam_, operation_->GetName());
+            AtbOps::Ops::Instance().UpdateSchedule();
+            kernel = operation_->GetBestKernel(launchParam_);
+        }
+    }
     if (kernel == nullptr) {
         ATB_LOG(ERROR) <<
GetLogPrefix() << " " << operation_->GetName() << " get best kernel fail, kernel count:" << operation_->GetKernelList().size(); diff --git a/src/atb/operation/operation_base.cpp b/src/atb/operation/operation_base.cpp index aa92c305..42a1cca9 100644 --- a/src/atb/operation/operation_base.cpp +++ b/src/atb/operation/operation_base.cpp @@ -849,7 +849,7 @@ Status OperationBase::PreExecuteThrow(const VariantPack &variantPack, uint8_t *w UpdateTensorData(variantPack, workspace); Status st = NO_ERROR; - if (!(runnerVariantPack_.context->GetLaunchWithTilingStatus())) { + if (!(runnerVariantPack_.context->GetLaunchWithTilingStatus()) || runnerVariantPack_.context->GetAutoFusionFlag()) { st = CopyTilingToDevice(); if (st != 0) { return st; diff --git a/src/atb/runner/graph_runner.cpp b/src/atb/runner/graph_runner.cpp index 62df27e6..039de05f 100644 --- a/src/atb/runner/graph_runner.cpp +++ b/src/atb/runner/graph_runner.cpp @@ -280,6 +280,9 @@ Status GraphRunner::SetupNodes(const RunnerVariantPack &runnerVariantPack) Status st = NO_ERROR; for (size_t nodeId = 0; nodeId < runnerGraph_.nodes.size(); ++nodeId) { auto &node = runnerGraph_.nodes.at(nodeId); + if (node.runner->GetName() == "NOT_MAIN_FUSION") { + continue; + } st = PreparseNodeVariantPack(nodeId, node, runnerVariantPack, nodeHostTilingBuffer, maxTilingSize); if (st != 0) { @@ -310,6 +313,9 @@ Status GraphRunner::SetupImpl(RunnerVariantPack &runnerVariantPack) } for (size_t nodeId = 0; nodeId < runnerGraph_.nodes.size(); ++nodeId) { auto &node = runnerGraph_.nodes.at(nodeId); + if (node.runner->GetName() == "NOT_MAIN_FUSION") { + continue; + } ReserveSvector(node); } if (runnerGraph_.inTensors.size() != runnerVariantPack.inTensors.size() || @@ -359,6 +365,9 @@ Status GraphRunner::FillHostTilingBufferImpl(uint8_t *hostTilingBuffer, uint64_t uint64_t tilingOffset = 0; for (size_t nodeId = 0; nodeId < runnerGraph_.nodes.size(); ++nodeId) { auto &node = runnerGraph_.nodes.at(nodeId); + if (node.runner->GetName() == "NOT_MAIN_FUSION") { + continue; + } Status ret = node.runner->FillHostTilingBuffer(hostTilingBuffer + tilingOffset, tilingBufferSizes_.at(nodeId), context); if (ret != NO_ERROR) { @@ -375,6 +384,9 @@ std::vector &GraphRunner::GetWorkspaceBufferSize() { for (size_t nodeId = 0; nodeId < runnerGraph_.nodes.size(); ++nodeId) { auto &node = runnerGraph_.nodes.at(nodeId); + if (node.runner->GetName() == "NOT_MAIN_FUSION") { + continue; + } const std::vector &runnerWorkspaceBufferSize = node.runner->GetWorkspaceBufferSize(); for (size_t i = 0; i < runnerWorkspaceBufferSize.size(); ++i) { multiStreamWorkspaceSizes_.at(i) = @@ -435,6 +447,9 @@ void GraphRunner::SetSaveTensorDir(const std::string &tensorDir) tensorDir_ = tensorDir; for (size_t nodeId = 0; nodeId < runnerGraph_.nodes.size(); nodeId++) { auto &node = runnerGraph_.nodes.at(nodeId); + if (node.runner->GetName() == "NOT_MAIN_FUSION") { + continue; + } node.runner->SetSaveTensorDir(tensorDir + "/" + std::to_string(nodeId) + "_" + node.runner->operationName_); } } @@ -809,6 +824,9 @@ void GraphRunner::CalcTilingBufferSize() tilingBufferSizes_.resize(runnerGraph_.nodes.size()); for (size_t nodeId = 0; nodeId < runnerGraph_.nodes.size(); ++nodeId) { auto &node = runnerGraph_.nodes.at(nodeId); + if (node.runner->GetName() == "NOT_MAIN_FUSION") { + continue; + } uint64_t runnerTilingBufferSize = node.runner->GetTilingBufferSize(); ATB_LOG(INFO) << GetLogPrefix() << " node[" << nodeId << "] tiling buffer size:" << runnerTilingBufferSize; totalTilingBufferSize_ += runnerTilingBufferSize; 
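Every per-node loop in GraphRunner now opens with the same name-based guard, because the placeholder runners created for absorbed fusion nodes must not contribute tiling, workspace, or args sizes. A sketch of a helper that would keep the sentinel string in one place; the helper name and the node type are illustrative, not part of this patch:

```cpp
// Illustrative only: centralizes the placeholder check used by the loops above and below.
// "NOT_MAIN_FUSION" is the runner name that FusionOperation::CreateRunner assigns to
// nodes absorbed into a fused neighbor.
static bool IsPlaceholderFusionNode(const RunnerGraphNode &node)
{
    return node.runner != nullptr && node.runner->GetName() == "NOT_MAIN_FUSION";
}

// Usage inside each per-node loop:
//     if (IsPlaceholderFusionNode(node)) {
//         continue;
//     }
```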
@@ -824,12 +842,18 @@ void GraphRunner::CalcIntermediateBufferSize() if (GetSingleton().IsworkspaceMemAllocGlobal()) { // 全局mem alloc时,所有runner共用一份内存 maxIntermediateBufferSize_ = memAllocationSolver_->GetSize(); for (size_t nodeId = 0; nodeId < runnerGraph_.nodes.size(); ++nodeId) { + if (runnerGraph_.nodes.at(nodeId).runner->GetName() == "NOT_MAIN_FUSION") { + continue; + } intermediateBufferSizes_.at(nodeId) = maxIntermediateBufferSize_; } return; } for (size_t nodeId = 0; nodeId < runnerGraph_.nodes.size(); ++nodeId) { auto &node = runnerGraph_.nodes.at(nodeId); + if (node.runner->GetName() == "NOT_MAIN_FUSION") { + continue; + } uint64_t runnerIntermediateBufferSize = node.runner->GetIntermediateBufferSize(); intermediateBufferSizes_.at(nodeId) = runnerIntermediateBufferSize; maxIntermediateBufferSize_ = std::max(maxIntermediateBufferSize_, runnerIntermediateBufferSize); @@ -846,6 +870,9 @@ void GraphRunner::UpdateVariantPackBuffer(RunnerVariantPack &runnerVariantPack) uint64_t offset = 0; for (size_t nodeId = 0; nodeId < runnerGraph_.nodes.size(); ++nodeId) { auto &node = runnerGraph_.nodes.at(nodeId); + if (node.runner->GetName() == "NOT_MAIN_FUSION") { + continue; + } node.runnerVariantPack.tilingBuffer = runnerVariantPack.tilingBuffer + offset; node.runnerVariantPack.tilingBufferSize = tilingBufferSizes_.at(nodeId); offset += tilingBufferSizes_.at(nodeId); @@ -856,11 +883,17 @@ void GraphRunner::UpdateVariantPackBuffer(RunnerVariantPack &runnerVariantPack) for (size_t nodeId = 0; nodeId < runnerGraph_.nodes.size(); ++nodeId) { auto &node = runnerGraph_.nodes.at(nodeId); + if (node.runner->GetName() == "NOT_MAIN_FUSION") { + continue; + } node.runnerVariantPack.workspaceBuffer = runnerVariantPack.workspaceBuffer; } for (size_t nodeId = 0; nodeId < runnerGraph_.nodes.size(); ++nodeId) { auto &node = runnerGraph_.nodes.at(nodeId); + if (node.runner->GetName() == "NOT_MAIN_FUSION") { + continue; + } node.runnerVariantPack.intermediateBuffer = runnerVariantPack.intermediateBuffer + selfIntermediateBufferSize_; node.runnerVariantPack.intermediateBufferSize = intermediateBufferSizes_.at(nodeId); } @@ -869,6 +902,9 @@ void GraphRunner::UpdateVariantPackBuffer(RunnerVariantPack &runnerVariantPack) uint64_t offset = 0; for (size_t nodeId = 0; nodeId < runnerGraph_.nodes.size(); ++nodeId) { auto &node = runnerGraph_.nodes.at(nodeId); + if (node.runner->GetName() == "NOT_MAIN_FUSION") { + continue; + } node.runnerVariantPack.argsDeviceBuffer = runnerVariantPack.argsDeviceBuffer + offset; offset += node.runner->GetArgsSize(); ATB_LOG(DEBUG) << GetLogPrefix() << "Graph node " << nodeId << " argsDeviceAddr is " @@ -880,6 +916,9 @@ void GraphRunner::UpdateVariantPackBuffer(RunnerVariantPack &runnerVariantPack) uint64_t offset = 0; for (size_t nodeId = 0; nodeId < runnerGraph_.nodes.size(); ++nodeId) { auto &node = runnerGraph_.nodes.at(nodeId); + if (node.runner->GetName() == "NOT_MAIN_FUSION") { + continue; + } node.runnerVariantPack.argsHostBuffer = runnerVariantPack.argsHostBuffer + offset; offset += node.runner->GetArgsSize(); ATB_LOG(DEBUG) << GetLogPrefix() << "Graph node " << nodeId << " argsHostAddr is " @@ -908,6 +947,9 @@ void GraphRunner::UpdateVariantPackTensorData(RunnerVariantPack &runnerVariantPa for (size_t nodeId = 0; nodeId < runnerGraph_.nodes.size(); ++nodeId) { auto &node = runnerGraph_.nodes.at(nodeId); + if (node.runner->GetName() == "NOT_MAIN_FUSION") { + continue; + } ATB_LOG(INFO) << GetLogPrefix() << " update tensor.data node[" << nodeId << "]"; for (size_t i = 0; i < 
node.runnerVariantPack.inTensors.size(); ++i) { auto &tensor = node.runnerVariantPack.inTensors.at(i); @@ -943,6 +985,9 @@ Status GraphRunner::ExecuteAllRunner(RunnerVariantPack &runnerVariantPack) { for (size_t nodeId = 0; nodeId < runnerGraph_.nodes.size(); ++nodeId) { auto &node = runnerGraph_.nodes.at(nodeId); + if (node.runner->GetName() == "NOT_MAIN_FUSION") { + continue; + } ATB_LOG(INFO) << GetLogPrefix() << " mstx registe tensor.data node[" << nodeId << "]" << "graphrunner start"; if (runnerVariantPack.mstxMemRegister != nullptr && !(dynamic_cast(node.runner.get()))) { runnerVariantPack.mstxMemRegister->ClearMstxMemRegions(); @@ -985,6 +1030,9 @@ Status GraphRunner::PreExecuteAllRunner(RunnerVariantPack &runnerVariantPack) { for (size_t nodeId = 0; nodeId < runnerGraph_.nodes.size(); ++nodeId) { auto &node = runnerGraph_.nodes.at(nodeId); + if (node.runner->GetName() == "NOT_MAIN_FUSION") { + continue; + } ATB_LOG(INFO) << GetLogPrefix() << " node[" << nodeId << "] PreExecute start, runner:" << node.runner->GetName(); node.runnerVariantPack.context = runnerVariantPack.context; diff --git a/src/atb/runner/ops_runner.cpp b/src/atb/runner/ops_runner.cpp index eb754df9..b4c44397 100644 --- a/src/atb/runner/ops_runner.cpp +++ b/src/atb/runner/ops_runner.cpp @@ -169,6 +169,9 @@ bool OpsRunner::SetupCanReuse(RunnerVariantPack &runnerVariantPack, bool &kernel } if (!needKernelGraphModify_) { bool launchWithTiling = runnerVariantPack.context->GetLaunchWithTilingStatus(); + if (runnerVariantPack.context->GetAutoFusionFlag()) { + launchWithTiling = false; + } SetupCacheGetCachedTiling(runnerVariantPack.hostTilingBuffer, runnerVariantPack.tilingBufferSize, launchWithTiling); return true; // 组图不改,参数不改,直接返回 @@ -290,6 +293,9 @@ Status OpsRunner::FillSingleKernelHostTilingBuffer(KernelGraphNode &node, size_t GetOpSetupStatistic().tilingCacheMissCount += 1; Mki::Timer fillTimer; bool launchWithTiling = context->GetLaunchWithTilingStatus(); + if (context->GetAutoFusionFlag()) { + launchWithTiling = false; + } Status status = node.impl->InitKernelInfo(kernelHostTilingBuffer, tilingSize, launchWithTiling); if (status != NO_ERROR) { ATB_LOG(ERROR) << GetLogPrefix() << " node[" << nodeId << "] InitRunInfo failed!"; @@ -357,6 +363,9 @@ Status OpsRunner::UpdateDeviceRealAddr(const RunnerVariantPack &runnerVariantPac uint8_t *deviceIntermediateBuffer = runnerVariantPack.intermediateBuffer; bool isLaunchKernelWithTiling = runnerVariantPack.context->GetLaunchWithTilingStatus(); bool needSetTiling = !(isLaunchKernelWithTiling || (totalTilingSize_ == 0)); + if (runnerVariantPack.context->GetAutoFusionFlag()) { + needSetTiling = true; + } bool needSetworkspace = (workspaceSize_ != 0); uint64_t tilingOffset = 0; uint64_t deviceArgsSizeOffset = 0; diff --git a/src/atb/utils/param_to_json.cpp b/src/atb/utils/param_to_json.cpp index 51a42763..bd08f828 100644 --- a/src/atb/utils/param_to_json.cpp +++ b/src/atb/utils/param_to_json.cpp @@ -764,6 +764,13 @@ template <> nlohmann::json OpParamToJson(const infer::CohereLayerNormParam &opPa return cohereLayerNormParamsJson; } +template <> nlohmann::json OpParamToJson(const infer::FusionParam &opParam) +{ + nlohmann::json fusionParamsJson; + fusionParamsJson["type"] = opParam.fusionType; + return fusionParamsJson; +} + template <> nlohmann::json OpParamToJson(const infer::GatherPreRmsNormParam &opParam) { nlohmann::json gatherPreRmsNormParamJson; diff --git a/src/include/atb/core/auto_fusion_tool.h b/src/include/atb/core/auto_fusion_tool.h new file mode 100644 index 
00000000..21cdd6dd
--- /dev/null
+++ b/src/include/atb/core/auto_fusion_tool.h
@@ -0,0 +1,201 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef AUTO_FUSION_TOOL_H
+#define AUTO_FUSION_TOOL_H
+#include <cstdint>
+#include <set>
+#include <string>
+#include <vector>
+#include "atb/auto_fusion.h"
+#include "atb/context.h"
+#include "atb/graph_op_builder.h"
+#include "atb/infer_op_params.h"
+#include "atb/train_op_params.h"
+#include "atb/operation.h"
+#include "atb/svector.h"
+#include "atb/types.h"
+#include "atb/utils.h"
+
+//!
+//! \file auto_fusion_tool.h
+//!
+//! \brief Defines the AutoFusionTool implementation of the AutoFusion interface
+//!
+namespace atb {
+//!
+//! \class AutoFusionTool
+//!
+//! \brief Implementation of the automatic operator fusion tool
+//!
+//! This class implements the AutoFusion interface and provides the fusion methods.
+//!
+
+class AutoFusionTool : public AutoFusion {
+public:
+    //! \brief Constructor
+    //!
+    //! \param graph The graph structure that the AutoFusion instance is bound to on creation
+    //!
+    AutoFusionTool(atb::GraphParam &graph);
+
+    //! \brief Destructor.
+    ~AutoFusionTool() = default;
+
+    //! \brief Entry point that performs auto fusion
+    //!
+    //! \param fusionClassArray Fusion classes requested by the user; the default empty set disables auto fusion
+    //!
+    void DoAutoFusion(const std::set<std::string> &fusionClassArray = {}) override;
+private:
+    //! \brief Finds the nodes that consume the tensor with the given id as an input
+    //!
+    //! \param id The tensor id to look up
+    //! \param re Receives the indices of the consumer nodes
+    void findNodesWhichInputsIs(const uint32_t id, std::vector<uint32_t> &re);
+
+    //! \brief Analyzes the collected nodes and filters out the ones that can be fused
+    //! \param fusionclassAndIndex Container receiving the (fusion class, node indices) pairs
+    void parseCollectNodes(std::vector<std::pair<std::string, std::vector<uint32_t>>> &fusionclassAndIndex);
+
+    //! \brief Collects the nodes that meet the fusion preconditions
+    //!
+    void CollectNodes();
+
+    //! \brief Describes the graph in dialect terms and decides, from each node's name plus the
+    //! input/output topology, whether fusion is possible, producing an array of
+    //! std::pair<std::string, std::vector<uint32_t>>
+    //! \param fusionclassAndIndex The (fusion class, node indices) pairs of the graph to be updated
+    //!
+    void ParseFusion(std::vector<std::pair<std::string, std::vector<uint32_t>>> &fusionclassAndIndex);
+
+
+    //! \brief For each pair produced above, fetches the operator binary compiled by BiShengIR
+    //!
+    void GetFusionBinAndUpdateFusedGraph();
+
+    //! \brief Updates the fused operators and tags the node structs as fused, so that kernel
+    //! lookup during inference can route statically and dynamically registered kernels separately
+    //!
+    //! \param fusionclassAndIndex The (fusion class, node indices) pairs of the graph to be updated
+    //! \param fusionclassAndBin The (fusion class, binary path) pairs of the graph to be updated
+    //!
+    void UpdateFusedGraph(std::vector<std::pair<std::string, std::vector<uint32_t>>> &fusionclassAndIndex,
+                          const std::vector<std::pair<std::string, std::string>> &fusionclassAndBin);
+
+    //! \brief Sorts the nodes by their input/output tensor ids, records the ids of the connecting
+    //! tensors, and merges the inputs and outputs of the two nodes
+    //!
+    //! \param fusedNodes The nodes to be fused
+    //! \param fusionClass The fusion class
+    //!
+    void SortNodeByTensorIdsAndChangeTopology(std::vector<uint32_t> &fusedNodes,
+                                              const std::string &fusionClass);
+
+    //! \brief Updates the graph structure according to the sorted result
+    //!
+    //! \param fusedNodes The nodes to be fused
+    //! \param fusionClass The fusion class
+    //!
+    void UpdateGraphStruct(const std::vector<uint32_t> &fusedNodes,
+                           const std::string &fusionClass);
+
+    //! \brief Updates the reshape funcs after fusion
+    //!
+    //! \param fusedNodes The nodes to be fused
+    //!
+    void UpdateReshapeFunc(const std::vector<uint32_t> &fusedNodes);
+
+    //! \brief Calls BiShengIR to generate the binaries
+    //!
+    //! \param subFusion The (fusion class, node indices) pair
+    //! \return Whether the binaries were generated, or were already cached, successfully
+    //!
+    bool callBiShengIR(const std::pair<std::string, std::vector<uint32_t>> &subFusion);
+
+    //! \brief Generates the MLIR and compiles it into binaries
+    //!
+    //! \param fusionClass The fusion class
+    //! \return Whether generation succeeded
+    //!
+    bool genMlirAndBin(const std::string &fusionClass);
+
+    //! \brief Generates the fused-kernel definition
+    //!
+    //! \param fusionClass The fusion class
+    //! \return Whether generation succeeded
+    //!
+    bool genFusionKernelDef(const std::string &fusionClass);
+
+    //! \brief Collects every tensor id and every output tensor id of the current graph
+    //!
+    //! \param allTensorIds All tensor ids
+    //! \param allOutTensorIds All output tensor ids
+    //!
+    void genAllTensorIDs(std::set<uint32_t> &allTensorIds, std::set<uint32_t> &allOutTensorIds);
+
+    //! \brief Remaps the tensor ids after fused nodes are removed so the id space stays contiguous
+    //!
+    //! \param outTensorIds All output tensor ids before the update
+    //! \param allOutTensorIdsNew All output tensor ids after the update
+    //!
+    void updateAllTensorIDs(const std::set<uint32_t> &outTensorIds, const std::set<uint32_t> &allOutTensorIdsNew);
+    //! \brief Gets the tiling key of the operator
+    //!
+    //! \param fusionClass The fusion class
+    //! \return The tiling key as a string
+    //!
+    std::string getTilingKey(const std::string &fusionClass);
+
+    //! \brief Configures the fusion type of the operation to be created
+    //!
+    //! \param param Parameters of the operation to be created
+    //! \param fusionClass The fusion class
+    //!
+    void SetFusionParam(atb::infer::FusionParam &param, const std::string &fusionClass);
+
+    //! \brief Enables the fusion classes requested by the user
+    //!
+    //! \param fusionClassArray Fusion classes requested by the user
+    //!
+    void SetFusionClass(const std::set<std::string> &fusionClassArray);
+
+    //! \brief Hidden directory built under the home path for BiShengIR
+    //!
+    //! Caches the binaries produced by BiShengIR
+    //!
+    std::string homePath_;
+
+    //! \brief The ATB graph to be analyzed
+    //!
+    atb::GraphParam &graph_;
+
+    //! \brief Fusible Linear (matmul) nodes of the graph
+    //!
+    std::set<uint32_t> linearNodes_;
+
+    //! \brief Fusible elementwise-add nodes of the graph
+    //!
+    std::set<uint32_t> eleAddNodes_;
+
+    //! \brief Fusible GELU activation nodes of the graph (for matmul-gelu)
+    //!
+    std::set<uint32_t> actGeluNodes_;
+
+    //! \brief Fusible Sigmoid activation nodes of the graph (for matmul-sigmoid)
+    //!
+    std::set<uint32_t> actSigmoidNodes_;
+
+    //! \brief Fusible SwiGLU activation nodes of the graph (for matmul-swiglu)
+    //!
+    std::set<uint32_t> actSwiGluNodes_;
+
+    //! \brief Fusion classes requested by the user
+    //!
+    std::set<std::string> fusionClassMap_;
+};
+} // namespace atb
+#endif
\ No newline at end of file
diff --git a/src/include/atb/core/context_base.h b/src/include/atb/core/context_base.h
index baf07ace..8c76b1f1 100644
--- a/src/include/atb/core/context_base.h
+++ b/src/include/atb/core/context_base.h
@@ -50,6 +50,8 @@ public:
     Status FreeArgsDeviceBuffer(void *addr);
     Status FreeArgsHostBuffer(void *addr);
     bool GetLaunchWithTilingStatus();
+    void SetAutoFusionFlag(bool flag = false) override;
+    bool GetAutoFusionFlag() const;
 
 private:
     Status CreateCopyStreamAndEvents();
@@ -70,6 +72,7 @@ private:
     Tensor overflowOutTensor_;
     static thread_local ExecuteType executeType_;
     LaunchMode mode_ = KERNEL_LAUNCH_MODE;
+    bool autoFusionFlag_{false};
     std::unique_ptr deviceAllocator_; // 一开始就赋值为defaultDeviceAllocator
     std::unique_ptr hostAllocator_; // 一开始就赋值为defaultHostAllocator
     std::function allocateFunc_; // 默认使用defaultDeviceAllocator中的Allocate方法
diff --git a/src/include/atb/core/runner_type.h b/src/include/atb/core/runner_type.h
index b226615d..581f383d 100644
--- a/src/include/atb/core/runner_type.h
+++ b/src/include/atb/core/runner_type.h
@@ -20,6 +20,7 @@ enum RunnerType : int {
     RUNNER_TYPE_CUMSUM,
     RUNNER_TYPE_DYNAMICNTK,
     RUNNER_TYPE_ELEWISE,
+    RUNNER_TYPE_FUSION,
     RUNNER_TYPE_GATHER,
     RUNNER_TYPE_LINEAR,
     RUNNER_TYPE_MATMUL,
diff --git a/src/ops_infer/fusion/fusion_operation.cpp b/src/ops_infer/fusion/fusion_operation.cpp
new file mode 100644
index 00000000..715ed2dd
--- /dev/null
+++ b/src/ops_infer/fusion/fusion_operation.cpp
@@ -0,0 +1,117 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#include "fusion_operation.h"
+#include "fusion_ops_runner.h"
+#include "atb/utils/tensor_check.h"
+#include "atb/utils/config.h"
+#include "atb/utils/param_to_json.h"
+#include "atb/utils/singleton.h"
+#include "atb/core/atb_operation_ir_cfg.h"
+#include "atb/core/op_param_funcs.h"
+
+namespace atb {
+const uint32_t TENSOR_NUM_ONE = 1;
+const uint32_t TENSOR_NUM_TWO = 2;
+const uint32_t TENSOR_NUM_THREE = 3;
+const uint32_t TENSOR_IDX_ZERO = 0;
+const uint32_t TENSOR_IDX_ONE = 1;
+const uint32_t TENSOR_IDX_TWO = 2;
+template <> Status CreateOperation(const infer::FusionParam &opParam, Operation **operation)
+{
+    if (operation == nullptr) {
+        return ERROR_INVALID_PARAM;
+    }
+    OP_PARAM_RSV_CHECK(opParam);
+
+    *operation = new (std::nothrow) FusionOperation(opParam);
+
+    if (*operation == nullptr) {
+        ATB_LOG(ERROR) << "failed to new operation";
+        return ERROR_OUT_OF_HOST_MEMORY;
+    }
+    return NO_ERROR;
+}
+
+FusionOperation::FusionOperation(const infer::FusionParam &param) : OperationBase("FusionOperation"), param_(param) {}
+
+FusionOperation::~FusionOperation() {}
+
+uint32_t FusionOperation::GetInputNum() const
+{
+    static std::map<infer::FusionParam::FusionType, uint32_t> inTensorNumTable = {
+        {infer::FusionParam::FusionType::MATMUL_ADD, TENSOR_NUM_THREE},
+        {infer::FusionParam::FusionType::MATMUL_GELU, TENSOR_NUM_TWO},
+        {infer::FusionParam::FusionType::MATMUL_SIGMOID, TENSOR_NUM_TWO},
+        {infer::FusionParam::FusionType::MATMUL_SWIGLU, TENSOR_NUM_TWO},
+    };
+    auto it = inTensorNumTable.find(param_.fusionType);
+    if (it != inTensorNumTable.end()) {
+        return it->second;
+    }
+    ATB_LOG(ERROR) << "param_.fusionType is invalid, type:" << param_.fusionType;
+    return 0;
+}
+
+uint32_t FusionOperation::GetOutputNum() const
+{
+    return TENSOR_NUM_ONE;
+}
+
+Status FusionOperation::InferShapeImpl(const SVector<TensorDesc> &inTensorDescs,
+                                       SVector<TensorDesc> &outTensorDescs) const
+{
+    if (infer::FusionParam::FusionType::NON_FUSION == param_.fusionType) {
+        // Placeholder nodes produced by fusion have nothing to infer.
+        return NO_ERROR;
+    }
+    if (infer::FusionParam::FusionType::MATMUL_ADD == param_.fusionType) {
+        outTensorDescs.at(TENSOR_IDX_ZERO) = inTensorDescs.at(TENSOR_IDX_TWO);
+    } else {
+        // matmul_transpose_b semantics: [m, k] x [n, k] -> [m, n]
+        outTensorDescs.at(TENSOR_IDX_ZERO) = inTensorDescs.at(TENSOR_IDX_ONE);
+        outTensorDescs.at(TENSOR_IDX_ZERO).shape.dims[0] = inTensorDescs.at(TENSOR_IDX_ZERO).shape.dims[0];
+        outTensorDescs.at(TENSOR_IDX_ZERO).shape.dims[1] = inTensorDescs.at(TENSOR_IDX_ONE).shape.dims[0];
+    }
+    return NO_ERROR;
+}
+
+SVector<bool> FusionOperation::GetEmptyInTensorPermissions() const
+{
+    if (GetInputNum() == TENSOR_NUM_THREE) {
+        SVector<bool> emptyTensorPerms(GetInputNum(), false);
+        emptyTensorPerms.at(TENSOR_NUM_THREE - 1) = true; // the bias input may be empty
+        return emptyTensorPerms;
+    }
+    return SVector<bool>();
+}
+
+std::shared_ptr<Runner> FusionOperation::CreateRunner(Context &context) const
+{
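+    // A NON_FUSION param marks a node whose computation was absorbed into its fused
+    // neighbor; it only gets a named placeholder runner so that GraphRunner can skip
+    // it by name ("NOT_MAIN_FUSION") during setup and execution.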
+std::shared_ptr<Runner> FusionOperation::CreateRunner(Context &context) const
+{
+    (void)context;
+    if (param_.fusionType == infer::FusionParam::FusionType::NON_FUSION) {
+        return std::make_shared("NOT_MAIN_FUSION");
+    }
+    return std::make_shared<FusionOpsRunner>(param_);
+}
+
+nlohmann::json FusionOperation::GetParamJson() const
+{
+    return OpParamToJson(param_);
+}
+} // namespace atb
diff --git a/src/ops_infer/fusion/fusion_operation.h b/src/ops_infer/fusion/fusion_operation.h
new file mode 100644
index 00000000..b7e6c3e9
--- /dev/null
+++ b/src/ops_infer/fusion/fusion_operation.h
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef ATB_FUSION_OPERATION_H
+#define ATB_FUSION_OPERATION_H
+#include <memory>
+#include "atb/operation/operation_base.h"
+#include "atb/infer_op_params.h"
+namespace atb {
+class FusionOperation : public OperationBase {
+public:
+    explicit FusionOperation(const infer::FusionParam &param);
+    ~FusionOperation() override;
+    uint32_t GetInputNum() const override;
+    uint32_t GetOutputNum() const override;
+
+protected:
+    Status InferShapeImpl(const SVector<TensorDesc> &inTensorDescs, SVector<TensorDesc> &outTensorDescs) const override;
+    std::shared_ptr<Runner> CreateRunner(Context &context) const override;
+    SVector<bool> GetEmptyInTensorPermissions() const override;
+    nlohmann::json GetParamJson() const override;
+
+private:
+    infer::FusionParam param_;
+};
+} // namespace atb
+#endif
\ No newline at end of file
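Since the file above adds a CreateOperation specialization for infer::FusionParam, a direct construction sketch may help; the helper name is illustrative, and ACL_DT_UNDEFINED simply keeps the kernel's default output dtype (FusionParam's documented default).

// Hedged sketch: build a fused MATMUL_ADD operation through the public API.
#include "atb/infer_op_params.h"
#include "atb/operation.h"

atb::Status MakeMatmulAddOp(atb::Operation **op)
{
    atb::infer::FusionParam param;
    param.fusionType = atb::infer::FusionParam::MATMUL_ADD; // expects 3 inputs: A, B, addend
    param.outTensorType = ACL_DT_UNDEFINED;                 // keep the default output dtype
    return atb::CreateOperation(param, op);
}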
diff --git a/src/ops_infer/fusion/fusion_ops_runner.cpp b/src/ops_infer/fusion/fusion_ops_runner.cpp
new file mode 100644
index 00000000..e6d5345d
--- /dev/null
+++ b/src/ops_infer/fusion/fusion_ops_runner.cpp
@@ -0,0 +1,121 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#include "fusion_ops_runner.h"
+#include <map>
+#include "atb/utils/log.h"
+
+namespace atb {
+static const uint32_t NUMONE = 1;
+static const uint32_t NUMTWO = 2;
+static const uint32_t NUMTHREE = 3;
+static const uint32_t INDEX_ZERO = 0;
+static const uint32_t INDEX_ONE = 1;
+static const uint32_t INDEX_TWO = 2;
+
+FusionOpsRunner::FusionOpsRunner(const infer::FusionParam &param)
+    : OpsRunner("FusionOpsRunner", RUNNER_TYPE_FUSION), param_(param)
+{
+    ATB_LOG(INFO) << "FusionOpsRunner::FusionOpsRunner called";
+
+    kernelGraph_.nodes.resize(NUMONE);
+    auto &fusionNode = kernelGraph_.nodes.at(INDEX_ZERO);
+    if (!SetIntensor(fusionNode)) {
+        return;
+    }
+    SetOuttensor(fusionNode);
+    ATB_LOG(INFO) << "FusionOpsRunner::FusionOpsRunner end";
+}
+
+FusionOpsRunner::~FusionOpsRunner() {}
+
+bool FusionOpsRunner::SetIntensor(KernelGraphNode &fusionNode)
+{
+    uint32_t inTensorNum = GetIntensorSize();
+    if (inTensorNum == 0) {
+        return false;
+    }
+    kernelGraph_.inTensors.resize(inTensorNum);
+    if (inTensorNum == NUMONE) {
+        Mki::Tensor &aTensor = kernelGraph_.inTensors.at(INDEX_ZERO);
+        fusionNode.inTensors = {&aTensor};
+    } else if (inTensorNum == NUMTWO) {
+        Mki::Tensor &aTensor = kernelGraph_.inTensors.at(INDEX_ZERO);
+        Mki::Tensor &bTensor = kernelGraph_.inTensors.at(INDEX_ONE);
+        fusionNode.inTensors = {&aTensor, &bTensor};
+    } else if (inTensorNum == NUMTHREE) {
+        Mki::Tensor &aTensor = kernelGraph_.inTensors.at(INDEX_ZERO);
+        Mki::Tensor &bTensor = kernelGraph_.inTensors.at(INDEX_ONE);
+        Mki::Tensor &cTensor = kernelGraph_.inTensors.at(INDEX_TWO);
+        fusionNode.inTensors = {&aTensor, &bTensor, &cTensor};
+    } else {
+        ATB_LOG(WARN) << "FusionOpsRunner::FusionOpsRunner unexpected inTensorNum: " << inTensorNum;
+    }
+    return true;
+}
+
+void FusionOpsRunner::SetOuttensor(KernelGraphNode &fusionNode)
+{
+    AtbOps::OpParam::Fusion::FusionType opFusionType = GetOpFusionType();
+    AtbOps::OpParam::Fusion fusionParam = {opFusionType};
+    kernelGraph_.outTensors.resize(1);
+    Mki::Tensor &operationOutTensor0 = kernelGraph_.outTensors.at(INDEX_ZERO);
+    fusionNode.outTensors = {&operationOutTensor0};
+    fusionParam.outTensorType = GetOutTensorType(param_.outTensorType);
+    fusionNode.opDesc = {0, "FusionOperation", fusionParam};
+    if (fusionParam.fusionType == AtbOps::OpParam::Fusion::MATMUL_ADD) {
+        fusionNode.opDesc = {1, "FusionOperation", fusionParam};
+    } else if (fusionParam.fusionType == AtbOps::OpParam::Fusion::MATMUL_GELU) {
+        fusionNode.opDesc = {2, "FusionOperation", fusionParam};
+    } else if (fusionParam.fusionType == AtbOps::OpParam::Fusion::MATMUL_SIGMOID) {
+        fusionNode.opDesc = {3, "FusionOperation", fusionParam};
+    } else if (fusionParam.fusionType == AtbOps::OpParam::Fusion::MATMUL_SWIGLU) {
+        fusionNode.opDesc = {4, "FusionOperation", fusionParam};
+    }
+}
+
+uint32_t FusionOpsRunner::GetIntensorSize() const
+{
+    static std::map<infer::FusionParam::FusionType, uint32_t> inTensorNumTable = {
+        {infer::FusionParam::FusionType::MATMUL_ADD, NUMTHREE},
+        {infer::FusionParam::FusionType::MATMUL_GELU, NUMTWO},
+        {infer::FusionParam::FusionType::MATMUL_SIGMOID, NUMTWO},
+        {infer::FusionParam::FusionType::MATMUL_SWIGLU, NUMTWO},
+    };
+    std::map<infer::FusionParam::FusionType, uint32_t>::const_iterator it = inTensorNumTable.find(param_.fusionType);
+    return it == inTensorNumTable.end() ? 0 : it->second;
+}
+
+AtbOps::OpParam::Fusion::FusionType FusionOpsRunner::GetOpFusionType() const
+{
+    static std::map<infer::FusionParam::FusionType, AtbOps::OpParam::Fusion::FusionType> typeTable = {
+        {infer::FusionParam::FusionType::MATMUL_ADD, AtbOps::OpParam::Fusion::MATMUL_ADD},
+        {infer::FusionParam::FusionType::MATMUL_GELU, AtbOps::OpParam::Fusion::MATMUL_GELU},
+        {infer::FusionParam::FusionType::MATMUL_SIGMOID, AtbOps::OpParam::Fusion::MATMUL_SIGMOID},
+        {infer::FusionParam::FusionType::MATMUL_SWIGLU, AtbOps::OpParam::Fusion::MATMUL_SWIGLU},
+    };
+    std::map<infer::FusionParam::FusionType, AtbOps::OpParam::Fusion::FusionType>::const_iterator it =
+        typeTable.find(param_.fusionType);
+    // Fallback only; NON_FUSION never reaches the runner (handled in FusionOperation::CreateRunner).
+    return it == typeTable.end() ? AtbOps::OpParam::Fusion::MATMUL_ADD : it->second;
+}
+
+Mki::TensorDType FusionOpsRunner::GetOutTensorType(const aclDataType outType) const
+{
+    static std::map<aclDataType, Mki::TensorDType> typeTable = {
+        {aclDataType::ACL_INT8, Mki::TensorDType::TENSOR_DTYPE_INT8},
+        {aclDataType::ACL_FLOAT, Mki::TensorDType::TENSOR_DTYPE_FLOAT},
+        {aclDataType::ACL_FLOAT16, Mki::TensorDType::TENSOR_DTYPE_FLOAT16},
+        {aclDataType::ACL_INT32, Mki::TensorDType::TENSOR_DTYPE_INT32},
+        {aclDataType::ACL_INT64, Mki::TensorDType::TENSOR_DTYPE_INT64},
+        {aclDataType::ACL_BF16, Mki::TensorDType::TENSOR_DTYPE_BF16},
+    };
+    std::map<aclDataType, Mki::TensorDType>::const_iterator it = typeTable.find(outType);
+    return it == typeTable.end() ? Mki::TensorDType::TENSOR_DTYPE_UNDEFINED : it->second;
+}
+} // namespace atb
\ No newline at end of file
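The runner above always materializes a one-node kernel graph whose node tensors alias the graph-level tensors. A reduced, self-contained model of that wiring (toy structs, not the Mki types):

// Simplified stand-in for SetIntensor/SetOuttensor: one node, aliased tensors.
#include <vector>

struct ToyTensor { int id = 0; };
struct ToyNode { std::vector<ToyTensor *> ins; std::vector<ToyTensor *> outs; };
struct ToyKernelGraph { std::vector<ToyTensor> inTensors, outTensors; ToyNode node; };

ToyKernelGraph BuildSingleNodeGraph(int numIn)
{
    ToyKernelGraph g;
    g.inTensors.resize(numIn);
    g.outTensors.resize(1);
    for (auto &t : g.inTensors) {
        g.node.ins.push_back(&t);            // the node reads every graph input
    }
    g.node.outs.push_back(&g.outTensors.front()); // single fused output
    return g;
}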
diff --git a/src/ops_infer/fusion/fusion_ops_runner.h b/src/ops_infer/fusion/fusion_ops_runner.h
new file mode 100644
index 00000000..7359fd5c
--- /dev/null
+++ b/src/ops_infer/fusion/fusion_ops_runner.h
@@ -0,0 +1,30 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+#ifndef ATB_FUSION_OPS_RUNNER_H
+#define ATB_FUSION_OPS_RUNNER_H
+#include "atb/runner/ops_runner.h"
+#include "atb/infer_op_params.h"
+namespace atb {
+class FusionOpsRunner : public OpsRunner {
+public:
+    explicit FusionOpsRunner(const infer::FusionParam &param);
+    ~FusionOpsRunner() override;
+
+private:
+    infer::FusionParam param_;
+    Mki::Tensor nullTensor_ = {};
+    bool SetIntensor(KernelGraphNode &fusionNode);
+    void SetOuttensor(KernelGraphNode &fusionNode);
+    uint32_t GetIntensorSize() const;
+    AtbOps::OpParam::Fusion::FusionType GetOpFusionType() const;
+    Mki::TensorDType GetOutTensorType(const aclDataType outType) const;
+};
+} // namespace atb
+#endif
\ No newline at end of file
diff --git a/src/torch_atb/bindings.cpp b/src/torch_atb/bindings.cpp
index 17ff9f35..67b9f145 100644
--- a/src/torch_atb/bindings.cpp
+++ b/src/torch_atb/bindings.cpp
@@ -81,7 +81,7 @@ PYBIND11_MODULE(_C, m)
         .def_property_readonly("name", &TorchAtb::OperationWrapper::GetName)
         .def_property_readonly("input_num", &TorchAtb::OperationWrapper::GetInputNum)
         .def_property_readonly("output_num", &TorchAtb::OperationWrapper::GetOutputNum)
-        .def("forward", &TorchAtb::OperationWrapper::Forward)
+        .def("forward", &TorchAtb::OperationWrapper::Forward, py::arg("inTensors"), py::arg("autoFusionFlag") = false)
         .def("__repr__", [](const TorchAtb::OperationWrapper &opWrapper) {
             std::stringstream ss;
             ss << "op name: " << opWrapper.GetName() << ", input_num: " << opWrapper.GetInputNum()
@@ -178,7 +178,7 @@ PYBIND11_MODULE(_C, m)
         .def("set_input_output", &TorchAtb::GraphOperationBuilder::SetInputOutput)
         .def("reshape", &TorchAtb::GraphOperationBuilder::Reshape)
         .def("add_operation", &TorchAtb::GraphOperationBuilder::AddOperation)
-        .def("build", &TorchAtb::GraphOperationBuilder::Build);
+        .def("build", &TorchAtb::GraphOperationBuilder::Build, py::arg("autoFusionClassArray") = std::set<std::string>{});
 
     py::enum_<aclDataType>(m, "AclDataType")
         .value("ACL_DT_UNDEFINED", aclDataType::ACL_DT_UNDEFINED)
diff --git a/src/torch_atb/graph_operation_builder.cpp b/src/torch_atb/graph_operation_builder.cpp
index 5c089037..6666216a 100644
--- a/src/torch_atb/graph_operation_builder.cpp
+++ b/src/torch_atb/graph_operation_builder.cpp
@@ -101,12 +101,35 @@ GraphOperationBuilder &GraphOperationBuilder::Reshape(const std::string &srcTensorName
     return *this;
 }
 
-OperationWrapper GraphOperationBuilder::Build()
+OperationWrapper GraphOperationBuilder::Build(const std::set<std::string> &autoFusionClassArray)
 {
     graphParam_.internalTensorNum = internalTensorNum_;
+    if (!autoFusionClassArray.empty()) {
+        return OperationWrapper(graphParam_, autoFusionClassArray);
+    }
     return OperationWrapper(graphParam_);
 }
 
+void GraphOperationBuilder::UpdateReshapeFunc(const std::vector<size_t> &fusedNodes)
+{
+    const size_t sizeOfInterIds = graphParam_.nodes[fusedNodes[0]].inTensorIds.size();
+    for (size_t i = 0; i < sizeOfInterIds; i++) {
+        bool flagHaveReshapeFunc = false;
+        for (const auto &reshapeFunc : reshapedTensorIds_) {
+            if (reshapeFunc.second.first == graphParam_.nodes[fusedNodes[0]].inTensorIds.at(i)) {
+                graphParam_.nodes[fusedNodes[0]].inTensorReshapeFuncs.at(i) = reshapeFunc.second.second;
+                flagHaveReshapeFunc = true;
+                break;
+            }
+        }
+        if (!flagHaveReshapeFunc) {
+            graphParam_.nodes[fusedNodes[0]].inTensorReshapeFuncs.at(i) = nullptr;
+        }
+    }
+}
+
 uint32_t GraphOperationBuilder::GetTensorId(const std::string &tensorName)
 {
     if (inTensorIds_.find(tensorName) != inTensorIds_.end()) {
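The bindings hunk above relies on pybind11 keyword arguments with defaults so Python callers can omit the new parameters. A standalone sketch of the same pattern; the module, class, and function names here are illustrative only:

// Minimal pybind11 keyword-argument-with-default example (builds as a module).
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <set>
#include <string>
namespace py = pybind11;

int Build(const std::set<std::string> &classes) { return static_cast<int>(classes.size()); }

PYBIND11_MODULE(fusion_demo, m)
{
    // Callable from Python as build() or build({"matmul_add"}).
    m.def("build", &Build, py::arg("autoFusionClassArray") = std::set<std::string>{});
}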
diff --git a/src/torch_atb/graph_operation_builder.h b/src/torch_atb/graph_operation_builder.h
index cc6fc5bf..c32d7755 100644
--- a/src/torch_atb/graph_operation_builder.h
+++ b/src/torch_atb/graph_operation_builder.h
@@ -24,7 +24,7 @@ public:
                                            const std::vector &outTensorNames);
     GraphOperationBuilder &Reshape(const std::string &srcTensorName, const ReshapeHandler &reshapeHandler,
                                    const std::string &reshapedTensorName);
-    OperationWrapper Build();
+    OperationWrapper Build(const std::set<std::string> &autoFusionClassArray = {});
 
 private:
     uint32_t GetTensorId(const std::string &tensorName);
@@ -36,6 +36,7 @@ private:
     std::map<std::string, uint32_t> outTensorIds_;
     std::map<std::string, uint32_t> internalTensorIds_;
     std::map<std::string, std::pair<uint32_t, atb::ReshapeFunc>> reshapedTensorIds_;
+    void UpdateReshapeFunc(const std::vector<size_t> &fusedNodes);
 };
 } // namespace TorchAtb
 #endif // TORCH_ATB_GRAPH_OPERATION_WRAPPER_H
\ No newline at end of file
diff --git a/src/torch_atb/operation_wrapper.cpp b/src/torch_atb/operation_wrapper.cpp
index 7f3a1602..a4f8979f 100644
--- a/src/torch_atb/operation_wrapper.cpp
+++ b/src/torch_atb/operation_wrapper.cpp
@@ -11,6 +11,7 @@
 #include "atb/utils/log.h"
+#include "atb/core/auto_fusion_tool.h"
 #include "resource/utils.h"
 #include "resource/memory_manager.h"
 #include "prof/prof_stats.h"
@@ -198,6 +199,11 @@ OperationWrapper::OperationWrapper(const RelayAttentionParam &param)
     CreateOpUniquePtr(param);
 }
 
+OperationWrapper::OperationWrapper(const FusionParam &param)
+{
+    CreateOpUniquePtr(param);
+}
+
 OperationWrapper::OperationWrapper(const TopkToppSamplingParam &param)
 {
     CreateOpUniquePtr(param);
@@ -213,6 +219,16 @@ OperationWrapper::OperationWrapper(const GraphParam &param)
     CreateOpUniquePtr(param);
 }
 
+OperationWrapper::OperationWrapper(GraphParam &param, const std::set<std::string> &fusionClassArray)
+{
+    if (!fusionClassArray.empty()) {
+        AutoFusion *autoFusionTool = nullptr;
+        if (atb::CreateAutoFusionTool(param, &autoFusionTool) == atb::NO_ERROR && autoFusionTool != nullptr) {
+            autoFusionTool->DoAutoFusion(fusionClassArray);
+            delete autoFusionTool;
+        }
+    }
+    CreateOpUniquePtr(param);
+}
+
 std::string OperationWrapper::GetName() const
 {
     return operation_->GetName();
@@ -228,12 +244,13 @@ uint32_t OperationWrapper::GetOutputNum() const
     return operation_->GetOutputNum();
 }
 
-std::vector<torch::Tensor> OperationWrapper::Forward(std::vector<torch::Tensor> &inTensors)
+std::vector<torch::Tensor> OperationWrapper::Forward(std::vector<torch::Tensor> &inTensors, bool autoFusionFlag)
 {
     Mki::Timer runTimer;
     if (!operation_) {
         throw std::runtime_error("call Forward fail, operation is nullptr");
     }
+    autoFusionFlag_ = autoFusionFlag;
     std::vector<torch::Tensor> outTensors;
     Setup(inTensors, outTensors);
     Execute();
@@ -275,6 +292,7 @@ void OperationWrapper::Setup(std::vector<torch::Tensor> &inTensors, std::vector<torch::Tensor> &outTensors)
         variantPack_.outTensors.at(i) = Utils::ConvertToAtbTensor(outTensors.at(i));
     }
     atb::Context *context = Utils::GetAtbContext();
+    context->SetAutoFusionFlag(autoFusionFlag_);
     atb::Status st = operation_->Setup(variantPack_, workspaceSize_, context);
     if (st != NO_ERROR) {
         throw std::runtime_error("call operation_->Setup fail");
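How the per-call flag propagates in the wrapper above can be modeled in a few lines: Forward stores the flag, and Setup pushes it into the context before the operation runs. Toy types only; not the ATB classes.

// Standalone model of the autoFusionFlag_ hand-off (assumed simplification).
struct ToyContext {
    bool autoFusion = false;
    void SetAutoFusionFlag(bool f) { autoFusion = f; }
};

class ToyWrapper {
public:
    void Forward(ToyContext &ctx, bool autoFusionFlag)
    {
        autoFusionFlag_ = autoFusionFlag; // remembered for this call
        Setup(ctx);                       // Setup sees the latest flag
    }

private:
    void Setup(ToyContext &ctx) { ctx.SetAutoFusionFlag(autoFusionFlag_); }
    bool autoFusionFlag_ = false;
};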
diff --git a/src/torch_atb/operation_wrapper.h b/src/torch_atb/operation_wrapper.h
index 4291c5fa..0c91b098 100644
--- a/src/torch_atb/operation_wrapper.h
+++ b/src/torch_atb/operation_wrapper.h
@@ -64,11 +64,13 @@ public:
     explicit OperationWrapper(const atb::infer::TopkToppSamplingParam &param);
     explicit OperationWrapper(const atb::infer::AllToAllParam &param);
     explicit OperationWrapper(const atb::GraphParam &param);
+    explicit OperationWrapper(atb::GraphParam &param, const std::set<std::string> &fusionClassArray = {});
+    explicit OperationWrapper(const atb::infer::FusionParam &param);
     atb::Operation *ReleaseOperation();
     std::string GetName() const;
     uint32_t GetInputNum() const;
     uint32_t GetOutputNum() const;
-    std::vector<torch::Tensor> Forward(std::vector<torch::Tensor> &inTensors);
+    std::vector<torch::Tensor> Forward(std::vector<torch::Tensor> &inTensors, bool autoFusionFlag = false);
 
 private:
     template <typename OpParam> void CreateOpUniquePtr(const OpParam &param);
@@ -82,6 +84,7 @@ private:
     std::unique_ptr<atb::Operation> operation_;
     atb::VariantPack variantPack_;
     uint64_t workspaceSize_{0};
+    bool autoFusionFlag_{false};
 };
 } // namespace TorchAtb
 #endif // TORCH_ATB_OPERATION_WRAPPER_H
\ No newline at end of file
diff --git a/tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_linear_bias_test.py b/tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_linear_bias_test.py
new file mode 100644
index 00000000..5ca2cb96
--- /dev/null
+++ b/tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_linear_bias_test.py
@@ -0,0 +1,73 @@
+import os
+import torch
+import torch_atb
+import numpy as np
+
+os.environ['ATB_LAUNCH_KERNEL_WITH_TILING'] = "0"
+
+def run_test():
+    print("----------- graph test begin ------------")
+    m, n, k = 1024, 1024, 1024
+    linear_param = torch_atb.LinearParam()
+    linear_param.has_bias = True
+    linear_param.transpose_b = True
+    linear = torch_atb.Operation(linear_param)
+
+    linear_param_1 = torch_atb.LinearParam()
+    linear_param_1.has_bias = True
+    linear_param_1.transpose_b = True
+    linear_1 = torch_atb.Operation(linear_param_1)
+
+    elewise_param_2 = torch_atb.ElewiseParam()
+    elewise_param_2.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD
+    elewise_add_2 = torch_atb.Operation(elewise_param_2)
+
+    graph = torch_atb.GraphBuilder("matmul_add_fuse_test") \
+        .set_input_output(["a", "b", "d", "f", "g", "i"], ["k"]) \
+        .add_operation(linear, ["a", "b", "d"], ["c"]) \
+        .add_operation(linear_1, ["f", "g", "i"], ["h"]) \
+        .add_operation(elewise_add_2, ["c", "h"], ["k"]) \
+        .build({"matmul_add"})
+    a = torch.ones(m, n, dtype=torch.float16)
+    b = torch.ones(k, n, dtype=torch.float16)
+    d = torch.ones(1, n, dtype=torch.float16)
+
+    f = torch.ones(m, n, dtype=torch.float16)
+    g = torch.ones(k, n, dtype=torch.float16)
+    i = torch.ones(1, n, dtype=torch.float16)
+
+    tensors_npu = [tensor.npu() for tensor in [a, b, d, f, g, i]]
+
+    def graph_run():
+        return graph.forward(tensors_npu, True)
+
+    def golden():
+        result_1 = torch.matmul(a, b.transpose(0, 1))
+        result_1 = result_1 + d
+
+        result_2 = torch.matmul(f, g.transpose(0, 1))
+        result_2 = result_2 + i
+
+        result = result_1 + result_2
+        return [result]
+
+    cpu_goldens = golden()
+    npu_outputs = graph_run()
+    print("cpu_goldens: ", cpu_goldens[0])
+    print("npu_outputs: ", npu_outputs[0].cpu())
+    print("cpu_goldens shape: ", cpu_goldens[0].shape)
+    print("npu_outputs shape: ", npu_outputs[0].cpu().shape)
+    difference = (cpu_goldens[0] - npu_outputs[0].cpu()).numpy()
+    print("difference = ", difference)
+    mean = np.mean(difference)
+    print("mean of difference:", mean)
+    std = np.std(difference)
+    print("std of difference:", std)
+
+if __name__ == "__main__":
+    run_test()
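The test above enables fusion per call via graph.forward(tensors, True); the native counterpart is the per-context switch added to context.h. A minimal C++ sketch of that switch (the helper name is illustrative; the full Setup/Execute flow appears in the unit test at the end of this patch):

#include "atb/context.h"

void EnableAutoFusion(atb::Context *context)
{
    if (context != nullptr) {
        context->SetAutoFusionFlag(true); // graphs set up on this context may be auto-fused
    }
}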
diff --git a/tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_matmul_add_gelu_signal_test.py b/tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_matmul_add_gelu_signal_test.py
new file mode 100644
index 00000000..8babe54a
--- /dev/null
+++ b/tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_matmul_add_gelu_signal_test.py
@@ -0,0 +1,113 @@
+import os
+import torch
+import torch_atb
+import numpy as np
+import torch.nn.functional as F
+
+os.environ['ATB_LAUNCH_KERNEL_WITH_TILING'] = "0"
+
+def run_test():
+    print("----------- graph test begin ------------")
+    m, n, k = 1024, 1024, 1024
+    linear_param = torch_atb.LinearParam()
+    linear_param.has_bias = False
+    linear_param.transpose_b = True
+    linear = torch_atb.Operation(linear_param)
+
+    elewise_param = torch_atb.ElewiseParam()
+    elewise_param.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD
+    elewise_add = torch_atb.Operation(elewise_param)
+
+    linear_param_1 = torch_atb.LinearParam()
+    linear_param_1.has_bias = False
+    linear_param_1.transpose_b = True
+    linear_1 = torch_atb.Operation(linear_param_1)
+
+    elewise_param_1 = torch_atb.ElewiseParam()
+    elewise_param_1.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD
+    elewise_add_1 = torch_atb.Operation(elewise_param_1)
+
+    elewise_param_2 = torch_atb.ElewiseParam()
+    elewise_param_2.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD
+    elewise_add_2 = torch_atb.Operation(elewise_param_2)
+
+    linear_param = torch_atb.LinearParam()
+    linear_param.has_bias = False
+    linear_param.transpose_b = True
+    linear_gelu = torch_atb.Operation(linear_param)
+
+    activate_param = torch_atb.ActivationParam()
+    activate_param.activation_type = torch_atb.ActivationType.ACTIVATION_GELU
+    activate_gelu = torch_atb.Operation(activate_param)
+
+    graph = torch_atb.GraphBuilder("matmul_add_fuse_test") \
+        .set_input_output(["a", "b", "d", "f", "g", "i", "l", "m"], ["k", "o"]) \
+        .add_operation(linear, ["a", "b"], ["c"]) \
+        .add_operation(elewise_add, ["c", "d"], ["e"]) \
+        .add_operation(linear_1, ["f", "g"], ["h"]) \
+        .add_operation(elewise_add_1, ["h", "i"], ["j"]) \
+        .add_operation(elewise_add_2, ["e", "j"], ["k"]) \
+        .add_operation(linear_gelu, ["l", "m"], ["n"]) \
+        .add_operation(activate_gelu, ["n"], ["o"]) \
+        .build({"matmul_add", "matmul_gelu"})
+    a = torch.randn(m, n, dtype=torch.float16)
+    b = torch.randn(k, n, dtype=torch.float16)
+    d = torch.randn(m, k, dtype=torch.float16)
+
+    f = torch.randn(m, n, dtype=torch.float16)
+    g = torch.randn(k, n, dtype=torch.float16)
+    i = torch.randn(m, k, dtype=torch.float16)
+
+    # inputs feeding the graph tensors "l" and "m"
+    l_t = torch.randn(m, n, dtype=torch.float16)
+    m_t = torch.randn(k, n, dtype=torch.float16)
+
+    tensors_npu = [tensor.npu() for tensor in [a, b, d, f, g, i, l_t, m_t]]
+
+    def graph_run():
+        return graph.forward(tensors_npu, True)
+
+    def golden():
+        result_1 = torch.matmul(a, b.transpose(0, 1))
+        result_1 = result_1 + d
+        result_2 = torch.matmul(f, g.transpose(0, 1))
+        result_2 = result_2 + i
+        result = result_1 + result_2
+        result_gelu = torch.matmul(l_t, m_t.transpose(0, 1))
+        result_gelu_1 = F.gelu(result_gelu)
+        return [result, result_gelu_1]
+
+    cpu_goldens = golden()
+    npu_outputs = graph_run()
+    print("cpu_goldens[0]: ", cpu_goldens[0])
+    print("npu_outputs[0]: ", npu_outputs[0].cpu())
+    print("cpu_goldens[1]: ", cpu_goldens[1])
+    print("npu_outputs[1]: ", npu_outputs[1].cpu())
+    print("cpu_goldens shape: ", cpu_goldens[0].shape)
+    print("npu_outputs shape: ", npu_outputs[0].cpu().shape)
+
+    difference = (cpu_goldens[0] - npu_outputs[0].cpu()).numpy()
+    print("difference = ", difference)
+    mean = np.mean(difference)
+    print("mean of difference:", mean)
+    std = np.std(difference)
+    print("std of difference:", std)
+
+    difference = (cpu_goldens[1] - npu_outputs[1].cpu()).numpy()
+    print("difference = ", difference)
+    mean = np.mean(difference)
+    print("mean of difference:", mean)
+    std = np.std(difference)
+    print("std of difference:", std)
+
+if __name__ == "__main__":
+    run_test()
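In the test's own tensor names, the golden path computes

\[
Y_1 = A B^{\top} + D, \qquad
Y_2 = F G^{\top} + I, \qquad
K = Y_1 + Y_2, \qquad
O = \operatorname{GELU}(L M^{\top}),
\]

with every operand in \(\mathbb{R}^{1024 \times 1024}\); the first three products are candidates for the matmul_add fusion and the last for matmul_gelu.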
diff --git a/tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_matmul_add_gelu_test.py b/tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_matmul_add_gelu_test.py
new file mode 100644
index 00000000..85285cc3
--- /dev/null
+++ b/tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_matmul_add_gelu_test.py
@@ -0,0 +1,103 @@
+import os
+import torch
+import torch_atb
+import numpy as np
+import torch.nn.functional as F
+
+os.environ['ATB_LAUNCH_KERNEL_WITH_TILING'] = "0"
+
+def run_test():
+    print("----------- graph test begin ------------")
+    m, n, k = 1024, 1024, 1024
+    linear_param = torch_atb.LinearParam()
+    linear_param.has_bias = False
+    linear_param.transpose_b = True
+    linear = torch_atb.Operation(linear_param)
+
+    elewise_param = torch_atb.ElewiseParam()
+    elewise_param.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD
+    elewise_add = torch_atb.Operation(elewise_param)
+
+    linear_param_1 = torch_atb.LinearParam()
+    linear_param_1.has_bias = False
+    linear_param_1.transpose_b = True
+    linear_1 = torch_atb.Operation(linear_param_1)
+
+    elewise_param_1 = torch_atb.ElewiseParam()
+    elewise_param_1.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD
+    elewise_add_1 = torch_atb.Operation(elewise_param_1)
+
+    elewise_param_2 = torch_atb.ElewiseParam()
+    elewise_param_2.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD
+    elewise_add_2 = torch_atb.Operation(elewise_param_2)
+
+    linear_param = torch_atb.LinearParam()
+    linear_param.has_bias = False
+    linear_param.transpose_b = True
+    linear_gelu = torch_atb.Operation(linear_param)
+
+    activate_param = torch_atb.ActivationParam()
+    activate_param.activation_type = torch_atb.ActivationType.ACTIVATION_GELU
+    activate_gelu = torch_atb.Operation(activate_param)
+
+    graph = torch_atb.GraphBuilder("matmul_add_fuse_test") \
+        .set_input_output(["a", "b", "d", "f", "g", "i", "l", "m"], ["o"]) \
+        .add_operation(linear, ["a", "b"], ["c"]) \
+        .add_operation(elewise_add, ["c", "d"], ["e"]) \
+        .add_operation(linear_1, ["f", "g"], ["h"]) \
+        .add_operation(elewise_add_1, ["h", "i"], ["j"]) \
+        .add_operation(elewise_add_2, ["e", "j"], ["k"]) \
+        .add_operation(linear_gelu, ["k", "m"], ["n"]) \
+        .add_operation(activate_gelu, ["n"], ["o"]) \
+        .build({"matmul_add", "matmul_gelu"})
+    a = torch.randn(m, n, dtype=torch.float16)
+    b = torch.randn(k, n, dtype=torch.float16)
+    d = torch.randn(m, k, dtype=torch.float16)
+
+    f = torch.randn(m, n, dtype=torch.float16)
+    g = torch.randn(k, n, dtype=torch.float16)
+    i = torch.randn(m, k, dtype=torch.float16)
+
+    # "l" is declared as a graph input but never consumed; l_t only fills its slot
+    l_t = torch.randn(m, k, dtype=torch.float16)
+    m_t = torch.randn(n, k, dtype=torch.float16)
+
+    tensors_npu = [tensor.npu() for tensor in [a, b, d, f, g, i, l_t, m_t]]
+
+    def graph_run():
+        return graph.forward(tensors_npu, True)
+
+    def golden():
+        result_1 = torch.matmul(a, b.transpose(0, 1))
+        result_1 = result_1 + d
+        result_2 = torch.matmul(f, g.transpose(0, 1))
+        result_2 = result_2 + i
+        result = result_1 + result_2
+        result_gelu = torch.matmul(result, m_t.transpose(0, 1))
+        result_gelu_1 = F.gelu(result_gelu)
+        return [result_gelu_1]
+
+    cpu_goldens = golden()
+    npu_outputs = graph_run()
+    print("cpu_goldens: ", cpu_goldens[0])
+    print("npu_outputs: ", npu_outputs[0].cpu())
+    print("cpu_goldens shape: ", cpu_goldens[0].shape)
+    print("npu_outputs shape: ", npu_outputs[0].cpu().shape)
+
+    difference = (cpu_goldens[0] - npu_outputs[0].cpu()).numpy()
+    print("difference = ", difference)
+    mean = np.mean(difference)
+    print("mean of difference:", mean)
+    std = np.std(difference)
+    print("std of difference:", std)
+
+if __name__ == "__main__":
+    run_test()
npu_outputs[0].cpu().shape) + + + difference = (cpu_goldens[0] - npu_outputs[0].cpu()).numpy() + print("difference = ", difference) + mean = np.mean(difference) + print("差异均值", mean) + std = np.std(difference) + print("差异标准差", std) +if __name__ == "__main__": + run_test() + diff --git a/tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_matmul_add_muliti_test.py b/tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_matmul_add_muliti_test.py new file mode 100644 index 00000000..e68cde12 --- /dev/null +++ b/tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_matmul_add_muliti_test.py @@ -0,0 +1,141 @@ + +import os +import torch +import torch.nn as nn +import torch_atb +import numpy as np + +os.environ['ATB_LAUNCH_KERNEL_WITH_TILING']="0" +ATB_LAUNCH_KERNEL_WITH_TILING = os.environ.get("ATB_LAUNCH_KERNEL_WITH_TILING") + +def run_test(): + print("----------- graph test begin ------------") + m, n, k = 512, 512, 512 + linear_param = torch_atb.LinearParam() + linear_param.has_bias = False + linear_param.transpose_b = True + linear = torch_atb.Operation(linear_param) + + elewise_param = torch_atb.ElewiseParam() + elewise_param.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD + elewise_add = torch_atb.Operation(elewise_param) + + linear_param_1 = torch_atb.LinearParam() + linear_param_1.has_bias = False + linear_param_1.transpose_b = True + linear_1 = torch_atb.Operation(linear_param_1) + + elewise_param_1 = torch_atb.ElewiseParam() + elewise_param_1.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD + elewise_add_1 = torch_atb.Operation(elewise_param_1) + + linear_param_2 = torch_atb.LinearParam() + linear_param_2.has_bias = False + linear_param_2.transpose_b = True + linear_2 = torch_atb.Operation(linear_param_2) + + elewise_param_2 = torch_atb.ElewiseParam() + elewise_param_2.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD + elewise_add_2 = torch_atb.Operation(elewise_param_2) + + + linear_param_3_ = torch_atb.LinearParam() + linear_param_3_.has_bias = False + linear_param_3_.transpose_b = True + linear_3_ = torch_atb.Operation(linear_param_3_) + + elewise_param_3_ = torch_atb.ElewiseParam() + elewise_param_3_.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD + elewise_add_3_ = torch_atb.Operation(elewise_param_3_) + + + elewise_param_3 = torch_atb.ElewiseParam() + elewise_param_3.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD + elewise_add_3 = torch_atb.Operation(elewise_param_3) + + + elewise_param_4 = torch_atb.ElewiseParam() + elewise_param_4.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD + elewise_add_4 = torch_atb.Operation(elewise_param_4) + + + elewise_param_5 = torch_atb.ElewiseParam() + elewise_param_5.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD + elewise_add_5 = torch_atb.Operation(elewise_param_5) + + + + + graph = torch_atb.GraphBuilder("matmul_add_fuse_test") \ + .set_input_output(["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11"], ["k"]) \ + .add_operation(linear, ["0", "1"], ["a"]) \ + .add_operation(elewise_add, ["2", "a"], ["b"]) \ + .add_operation(linear_1, ["3", "4"], ["c"]) \ + .add_operation(elewise_add_1, ["5", "c"], ["d"]) \ + .add_operation(linear_2, ["6", "7"], ["e"]) \ + .add_operation(elewise_add_2, ["8", "e"], ["f"]) \ + .add_operation(linear_3_, ["9", "10"], ["g"]) \ + .add_operation(elewise_add_3_, ["11", "g"], ["h"]) \ + .add_operation(elewise_add_3, ["b", "d"], ["i"]) \ + .add_operation(elewise_add_4, ["f", "h"], ["j"]) \ + 
.add_operation(elewise_add_5, ["i", "j"], ["k"]) \ + .build({"matmul_add"}) + a = torch.randn(m, n, dtype=torch.float16) + b = torch.randn(k, n, dtype=torch.float16) + d = torch.randn(m, k, dtype=torch.float16) + + f = torch.randn(m, n, dtype=torch.float16) + g = torch.randn(k, n, dtype=torch.float16) + i = torch.randn(m, k, dtype=torch.float16) + + a1 = torch.randn(m, n, dtype=torch.float16) + b1 = torch.randn(k, n, dtype=torch.float16) + d1 = torch.randn(m, k, dtype=torch.float16) + + f1 = torch.randn(m, n, dtype=torch.float16) + g1 = torch.randn(k, n, dtype=torch.float16) + i1 = torch.randn(m, k, dtype=torch.float16) + + tensors_npu = [tensor.npu() for tensor in [a, b, d, f, g, i, a1, b1, d1, f1, g1, i1]] + + def graph_run(): + return graph.forward(tensors_npu, True) + + def golden(): + result_1 = torch.matmul(a, b.transpose(1, 0)) + result_1 = torch.add(result_1, d) + result_2 = torch.matmul(f, g.transpose(1, 0)) + result_2 = torch.add(result_2, i) + + result_3 = torch.matmul(a1, b1.transpose(1, 0)) + result_3 = torch.add(result_3, d1) + result_3_ = torch.matmul(f1, g1.transpose(1, 0)) + result_3_ = torch.add(result_3_, i1) + + + result_4 = result_1 + result_2 + result_5 = result_3_ + result_3 + result_6 = result_4 + result_5 + return [result_6] + + cpu_goldens = golden() + print("cpu_goldens", cpu_goldens) + + npu_outputs = graph_run() + print("cpu_goldens: ", cpu_goldens[0]) + print("npu_outputs: ", npu_outputs[0].cpu()) + print("cpu_goldens: ", npu_outputs[0].shape) + print("npu_outputs: ", npu_outputs[0].cpu().shape) + + + difference = (cpu_goldens[0] - npu_outputs[0].cpu()).numpy() + print("difference = ", difference) + mean = np.mean(difference) + print("差异均值", mean) + std = np.std(difference) + print("差异标准差", std) + fake = (cpu_goldens[0] == npu_outputs[0].cpu()) + print("fake = ", fake) +if __name__ == "__main__": + run_test() + diff --git a/tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_matmul_add_test.py b/tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_matmul_add_test.py new file mode 100644 index 00000000..510b6f91 --- /dev/null +++ b/tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_matmul_add_test.py @@ -0,0 +1,88 @@ + +import os +import torch +import torch.nn as nn +import torch_atb +import numpy as np + +os.environ['ATB_LAUNCH_KERNEL_WITH_TILING']="0" +ATB_LAUNCH_KERNEL_WITH_TILING = os.environ.get("ATB_LAUNCH_KERNEL_WITH_TILING") + +def run_test(): + print("----------- graph test begin ------------") + m, n, k = 512, 512, 512 + linear_param = torch_atb.LinearParam() + linear_param.has_bias = False + linear_param.transpose_b = True + linear = torch_atb.Operation(linear_param) + + elewise_param = torch_atb.ElewiseParam() + elewise_param.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD + elewise_add = torch_atb.Operation(elewise_param) + + linear_param_1 = torch_atb.LinearParam() + linear_param_1.has_bias = False + linear_param_1.transpose_b = True + linear_1 = torch_atb.Operation(linear_param_1) + + elewise_param_1 = torch_atb.ElewiseParam() + elewise_param_1.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD + elewise_add_1 = torch_atb.Operation(elewise_param_1) + + elewise_param_2 = torch_atb.ElewiseParam() + elewise_param_2.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD + elewise_add_2 = torch_atb.Operation(elewise_param_2) + + + graph = torch_atb.GraphBuilder("matmul_add_fuse_test") \ + .set_input_output(["a", "b", "d", "f", "g", "i"], ["k"]) \ + .add_operation(linear, ["a", "b"], ["c"]) \ + 
diff --git a/tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_matmul_add_test.py b/tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_matmul_add_test.py
new file mode 100644
index 00000000..510b6f91
--- /dev/null
+++ b/tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_matmul_add_test.py
@@ -0,0 +1,88 @@
+import os
+import torch
+import torch_atb
+import numpy as np
+
+os.environ['ATB_LAUNCH_KERNEL_WITH_TILING'] = "0"
+
+def run_test():
+    print("----------- graph test begin ------------")
+    m, n, k = 512, 512, 512
+    linear_param = torch_atb.LinearParam()
+    linear_param.has_bias = False
+    linear_param.transpose_b = True
+    linear = torch_atb.Operation(linear_param)
+
+    elewise_param = torch_atb.ElewiseParam()
+    elewise_param.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD
+    elewise_add = torch_atb.Operation(elewise_param)
+
+    linear_param_1 = torch_atb.LinearParam()
+    linear_param_1.has_bias = False
+    linear_param_1.transpose_b = True
+    linear_1 = torch_atb.Operation(linear_param_1)
+
+    elewise_param_1 = torch_atb.ElewiseParam()
+    elewise_param_1.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD
+    elewise_add_1 = torch_atb.Operation(elewise_param_1)
+
+    elewise_param_2 = torch_atb.ElewiseParam()
+    elewise_param_2.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD
+    elewise_add_2 = torch_atb.Operation(elewise_param_2)
+
+    graph = torch_atb.GraphBuilder("matmul_add_fuse_test") \
+        .set_input_output(["a", "b", "d", "f", "g", "i"], ["k"]) \
+        .add_operation(linear, ["a", "b"], ["c"]) \
+        .add_operation(elewise_add, ["c", "d"], ["e"]) \
+        .add_operation(linear_1, ["f", "g"], ["h"]) \
+        .add_operation(elewise_add_1, ["h", "i"], ["j"]) \
+        .add_operation(elewise_add_2, ["e", "j"], ["k"]) \
+        .build({"matmul_add"})
+    a = torch.randn(m, n, dtype=torch.float16)
+    b = torch.randn(k, n, dtype=torch.float16)
+    d = torch.randn(m, k, dtype=torch.float16)
+
+    f = torch.randn(m, n, dtype=torch.float16)
+    g = torch.randn(k, n, dtype=torch.float16)
+    i = torch.randn(m, k, dtype=torch.float16)
+
+    tensors_npu = [tensor.npu() for tensor in [a, b, d, f, g, i]]
+
+    def graph_run():
+        return graph.forward(tensors_npu, True)
+
+    def golden():
+        result_1 = torch.matmul(a, b.transpose(0, 1))
+        result_1 = result_1 + d
+        result_2 = torch.matmul(f, g.transpose(0, 1))
+        result_2 = result_2 + i
+        result = result_1 + result_2
+        return [result]
+
+    cpu_goldens = golden()
+    npu_outputs = graph_run()
+    print("cpu_goldens: ", cpu_goldens[0])
+    print("npu_outputs: ", npu_outputs[0].cpu())
+    print("cpu_goldens shape: ", cpu_goldens[0].shape)
+    print("npu_outputs shape: ", npu_outputs[0].cpu().shape)
+
+    difference = (cpu_goldens[0] - npu_outputs[0].cpu()).numpy()
+    print("difference = ", difference)
+    mean = np.mean(difference)
+    print("mean of difference:", mean)
+    std = np.std(difference)
+    print("std of difference:", std)
+    exact_match = (cpu_goldens[0] == npu_outputs[0].cpu())
+    print("exact_match = ", exact_match)
+
+if __name__ == "__main__":
+    run_test()
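The next test checks the matmul_gelu path against torch.nn.functional.gelu, which defaults to the exact erf-based GELU. For checking goldens by hand, a reference implementation of that exact form (the tanh variant is only an approximation and would not match bit-for-bit):

// Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2))).
#include <cmath>

float GeluExact(float x)
{
    return 0.5f * x * (1.0f + std::erf(x / std::sqrt(2.0f)));
}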
diff --git a/tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_matmul_gelu_test.py b/tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_matmul_gelu_test.py
new file mode 100644
index 00000000..4a409a49
--- /dev/null
+++ b/tests/apitest/torch_atb_test/graph_test/graph_auto_fusion_matmul_gelu_test.py
@@ -0,0 +1,91 @@
+import os
+import torch
+import torch_atb
+import numpy as np
+import torch.nn.functional as F
+
+os.environ['ATB_LAUNCH_KERNEL_WITH_TILING'] = "0"
+
+def run_test():
+    print("----------- graph test begin ------------")
+    m, n, k = 512, 512, 512
+    linear_param = torch_atb.LinearParam()
+    linear_param.has_bias = False
+    linear_param.transpose_b = True
+    linear = torch_atb.Operation(linear_param)
+
+    activate_param = torch_atb.ActivationParam()
+    activate_param.activation_type = torch_atb.ActivationType.ACTIVATION_GELU
+    activate_gelu = torch_atb.Operation(activate_param)
+
+    linear_param_1 = torch_atb.LinearParam()
+    linear_param_1.has_bias = False
+    linear_param_1.transpose_b = True
+    linear_1 = torch_atb.Operation(linear_param_1)
+
+    activate_param_1 = torch_atb.ActivationParam()
+    activate_param_1.activation_type = torch_atb.ActivationType.ACTIVATION_GELU
+    activate_gelu_1 = torch_atb.Operation(activate_param_1)
+
+    elewise_param_2 = torch_atb.ElewiseParam()
+    elewise_param_2.elewise_type = torch_atb.ElewiseParam.ElewiseType.ELEWISE_ADD
+    elewise_add_2 = torch_atb.Operation(elewise_param_2)
+
+    graph = torch_atb.GraphBuilder("matmul_add_fuse_test") \
+        .set_input_output(["a", "b", "f", "g"], ["k"]) \
+        .add_operation(linear, ["a", "b"], ["c"]) \
+        .add_operation(activate_gelu, ["c"], ["e"]) \
+        .add_operation(linear_1, ["f", "g"], ["h"]) \
+        .add_operation(activate_gelu_1, ["h"], ["j"]) \
+        .add_operation(elewise_add_2, ["e", "j"], ["k"]) \
+        .build({"matmul_gelu"})
+    a = torch.randn(m, n, dtype=torch.float16)
+    b = torch.ones(k, n, dtype=torch.float16)
+
+    f = torch.ones(m, n, dtype=torch.float16)
+    g = torch.randn(k, n, dtype=torch.float16)
+
+    tensors_npu = [tensor.npu() for tensor in [a, b, f, g]]
+
+    def graph_run():
+        return graph.forward(tensors_npu, True)
+
+    def golden():
+        result_1 = torch.matmul(a, b.transpose(0, 1))
+        result_1 = F.gelu(result_1)
+        result_2 = torch.matmul(f, g.transpose(0, 1))
+        result_2 = F.gelu(result_2)
+        result = result_1 + result_2
+        return [result]
+
+    cpu_goldens = golden()
+    npu_outputs = graph_run()
+    print("cpu_goldens: ", cpu_goldens[0])
+    print("npu_outputs: ", npu_outputs[0].cpu())
+    print("cpu_goldens shape: ", cpu_goldens[0].shape)
+    print("npu_outputs shape: ", npu_outputs[0].cpu().shape)
+
+    difference = (cpu_goldens[0] - npu_outputs[0].cpu()).numpy()
+    print("difference = ", difference)
+    mean = np.mean(difference)
+    print("mean of difference:", mean)
+    std = np.std(difference)
+    print("std of difference:", std)
+    exact_match = (cpu_goldens[0] == npu_outputs[0].cpu())
+    print("exact_match = ", exact_match)
+
+if __name__ == "__main__":
+    run_test()
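The C++ unit test that follows walks the standard ATB execution lifecycle: Setup sizes the workspace, the workspace is allocated only when non-zero, Execute launches, and the stream is synchronized. A condensed sketch of that loop; the failure status values reuse codes visible in this patch and are assumptions, not the library's prescribed mapping:

#include <acl/acl.h>
#include "atb/context.h"
#include "atb/operation.h"

atb::Status RunOnce(atb::Operation *op, atb::VariantPack &pack, atb::Context *ctx, aclrtStream stream)
{
    uint64_t wsSize = 0;
    atb::Status st = op->Setup(pack, wsSize, ctx); // 1. size the workspace
    if (st != atb::NO_ERROR) {
        return st;
    }
    void *ws = nullptr;
    if (wsSize > 0 && aclrtMalloc(&ws, wsSize, ACL_MEM_MALLOC_HUGE_FIRST) != 0) {
        return atb::ERROR_INVALID_GRAPH; // assumed failure mapping
    }
    st = op->Execute(pack, static_cast<uint8_t *>(ws), wsSize, ctx); // 2. launch
    if (st == atb::NO_ERROR && aclrtSynchronizeStream(stream) != 0) { // 3. wait
        st = atb::ERROR_INVALID_GRAPH; // assumed failure mapping
    }
    if (ws != nullptr) {
        aclrtFree(ws);
    }
    return st;
}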
diff --git a/tests/unittest/normal/auto_fusion_graph.cpp b/tests/unittest/normal/auto_fusion_graph.cpp
new file mode 100644
index 00000000..fb941c47
--- /dev/null
+++ b/tests/unittest/normal/auto_fusion_graph.cpp
@@ -0,0 +1,211 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ */
+
+#include <gtest/gtest.h>
+#include <iostream>
+#include "test_utils/test_common.h"
+#include "atb/operation.h"
+#include "atb/utils/tensor_util.h"
+#include "test_utils/operation_test.h"
+#include "atb/utils/operation_util.h"
+#include "atb/operation_infra.h"
+#include "atb/operation/operation_base.h"
+#include "atb/utils/config.h"
+#include "atb/utils/singleton.h"
+#include "atb/auto_fusion.h"
+
+using namespace atb;
+
+static void CreateMatMulAddGraphOperation(atb::GraphParam &opGraph, atb::Operation **operation,
+    bool autoFusion = false)
+{
+    // Build the subgraph: two matmul+add branches joined by a final add.
+    opGraph.inTensorNum = 6;
+    opGraph.outTensorNum = 1;
+    opGraph.internalTensorNum = 4;
+    opGraph.nodes.resize(5);
+
+    size_t nodeId = 0;
+    atb::Node &linNode1 = opGraph.nodes.at(nodeId++);
+    atb::Node &addNode1 = opGraph.nodes.at(nodeId++);
+    atb::Node &linNode2 = opGraph.nodes.at(nodeId++);
+    atb::Node &addNode2 = opGraph.nodes.at(nodeId++);
+    atb::Node &addNode = opGraph.nodes.at(nodeId++);
+
+    atb::infer::LinearParam linearParam1;
+    linearParam1.hasBias = false;
+    linearParam1.transposeB = true;
+    atb::CreateOperation(linearParam1, &linNode1.operation);
+    linNode1.inTensorIds = {0, 1};
+    linNode1.outTensorIds = {7};
+
+    atb::infer::ElewiseParam addParam1;
+    addParam1.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
+    CreateOperation(addParam1, &addNode1.operation);
+    addNode1.inTensorIds = {2, 7};
+    addNode1.outTensorIds = {9};
+
+    atb::infer::LinearParam linearParam2;
+    linearParam2.hasBias = false;
+    linearParam2.transposeB = true;
+    atb::CreateOperation(linearParam2, &linNode2.operation);
+    linNode2.inTensorIds = {3, 4};
+    linNode2.outTensorIds = {8};
+
+    atb::infer::ElewiseParam addParam2;
+    addParam2.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
+    CreateOperation(addParam2, &addNode2.operation);
+    addNode2.inTensorIds = {5, 8};
+    addNode2.outTensorIds = {10};
+
+    atb::infer::ElewiseParam addParam;
+    addParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
+    CreateOperation(addParam, &addNode.operation);
+    addNode.inTensorIds = {9, 10};
+    addNode.outTensorIds = {6};
+
+    if (autoFusion) {
+        atb::AutoFusion *autoFusionTool = nullptr;
+        atb::CreateAutoFusionTool(opGraph, &autoFusionTool);
+        autoFusionTool->DoAutoFusion();
+    }
+
+    atb::CreateOperation(opGraph, operation);
+}
+
+static void CreateInTensorDescs(atb::SVector<atb::TensorDesc> &intensorDescs)
+{
+    for (size_t i = 0; i < intensorDescs.size(); i++) {
+        intensorDescs.at(i).dtype = ACL_FLOAT16;
+        intensorDescs.at(i).format = ACL_FORMAT_ND;
+        intensorDescs.at(i).shape.dimNum = 2;
+        intensorDescs.at(i).shape.dims[0] = 2;
+        intensorDescs.at(i).shape.dims[1] = 2;
+    }
+}
+
+static void CreateInTensors(atb::SVector<atb::Tensor> &inTensors, atb::SVector<atb::TensorDesc> &intensorDescs)
+{
+    std::vector<uint16_t> zeroData(24, 1); // host-side fill pattern, large enough for every 2x2 fp16 input
+    for (size_t i = 0; i < inTensors.size(); i++) {
+        inTensors.at(i).desc = intensorDescs.at(i);
+        inTensors.at(i).dataSize = atb::Utils::GetTensorSize(inTensors.at(i));
+        int ret = aclrtMalloc(&inTensors.at(i).deviceData, inTensors.at(i).dataSize, ACL_MEM_MALLOC_HUGE_FIRST); // allocate NPU memory
+        if (ret != 0) {
+            std::cout << "alloc error!";
+            exit(0);
+        }
+        ret = aclrtMemcpy(inTensors.at(i).deviceData, inTensors.at(i).dataSize, zeroData.data(),
+                          inTensors.at(i).dataSize, ACL_MEMCPY_HOST_TO_DEVICE); // copy host data to the device
+    }
+}
+
+static void CreateOutTensors(atb::SVector<atb::Tensor> &outTensors, atb::SVector<atb::TensorDesc> &outtensorDescs)
+{
+    for (size_t i = 0; i < outTensors.size(); i++) {
+        outTensors.at(i).desc = outtensorDescs.at(i);
+        outTensors.at(i).dataSize = atb::Utils::GetTensorSize(outTensors.at(i));
+        int ret = aclrtMalloc(&outTensors.at(i).deviceData, outTensors.at(i).dataSize, ACL_MEM_MALLOC_HUGE_FIRST);
+        if (ret != 0) {
+            std::cout << "alloc error!";
+            exit(0);
+        }
+    }
+}
+
+TEST(CreateMatMulAddGraphOperation, TestMatmulAdd)
+{
+    if (!GetSingleton<Config>().Is910B()) {
+        exit(0);
+    }
+    uint32_t deviceId = 0;
+    aclrtSetDevice(deviceId);
+    atb::Context *context = nullptr;
+    Status st = atb::CreateContext(&context);
+    ATB_LOG_IF(st != 0, ERROR) << "CreateContext fail";
+    ASSERT_NE(context, nullptr);
+    context->SetAutoFusionFlag(true);
+
+    aclrtStream stream = nullptr;
+    st = aclrtCreateStream(&stream);
+    ATB_LOG_IF(st != 0, ERROR) << "aclrtCreateStream fail";
+    context->SetExecuteStream(stream);
+
+    atb::Operation *graphOp = nullptr;
+    atb::GraphParam graphParam1;
+    CreateMatMulAddGraphOperation(graphParam1, &graphOp, true);
+
+    // Prepare the input and output tensors.
+    atb::VariantPack pack;
+    atb::SVector<atb::TensorDesc> intensorDescs1;
+    atb::SVector<atb::TensorDesc> outtensorDescs1;
+
+    uint32_t inTensorNum = 6;
+    uint32_t outTensorNum = 1;
+    pack.inTensors.resize(inTensorNum);
+    pack.outTensors.resize(outTensorNum);
+    intensorDescs1.resize(inTensorNum);
+
+    CreateInTensorDescs(intensorDescs1);
+    CreateInTensors(pack.inTensors, intensorDescs1);
+
+    outtensorDescs1.resize(outTensorNum);
+    outtensorDescs1.at(0).dtype = ACL_FLOAT16;
+    outtensorDescs1.at(0).format = ACL_FORMAT_ND;
+    outtensorDescs1.at(0).shape.dimNum = 2;
+    outtensorDescs1.at(0).shape.dims[0] = 2;
+    outtensorDescs1.at(0).shape.dims[1] = 2;
+    CreateOutTensors(pack.outTensors, outtensorDescs1);
+
+    // Setup
+    uint64_t workspaceSize = 0;
+    graphOp->Setup(pack, workspaceSize, context);
+    void *workSpace = nullptr;
+    int ret1 = 0;
+    if (workspaceSize != 0) {
+        ret1 = aclrtMalloc(&workSpace, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST);
+        if (ret1 != 0) {
+            std::cout << "alloc error!";
+            exit(0);
+        }
+    }
+
+    // Execute
+    atb::Status st1 = graphOp->Execute(pack, (uint8_t *)workSpace, workspaceSize, context);
+
+    st = (st1 == atb::NO_ERROR) ? atb::NO_ERROR : atb::ERROR_INVALID_GRAPH;
+
+    // Synchronize the stream.
+    ret1 = aclrtSynchronizeStream(stream);
+    EXPECT_EQ(ret1, atb::NO_ERROR);
+    if (ret1 != 0) {
+        std::cout << "sync error!";
+        exit(0);
+    }
+
+    // Release resources.
+    atb::DestroyOperation(graphOp);
+    atb::DestroyContext(context);
+    for (size_t i = 0; i < pack.inTensors.size(); i++) {
+        aclrtFree(pack.inTensors.at(i).deviceData);
+    }
+    for (size_t i = 0; i < pack.outTensors.size(); i++) {
+        aclrtFree(pack.outTensors.at(i).deviceData);
+    }
+    if (workSpace != nullptr) {
+        aclrtFree(workSpace);
+    }
+    aclrtDestroyStream(stream);
+    aclrtResetDevice(deviceId);
+    ASSERT_EQ(st, atb::NO_ERROR);
+}
-- 
Gitee