diff --git a/CMakeLists.txt b/CMakeLists.txt
index fc2cc7356905e9b75d84c893d1b05bacbef62eac..7fa1fe27dc53f3c6a6e17a2eac631fdca7431927 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -20,24 +20,29 @@ if (ENABLE_OPEN_SRC)
     if (NOT DEFINED ASCEND_OPENSDK_DIR)
         set(ASCEND_OPENSDK_DIR $ENV{ASCEND_CUSTOM_PATH}/opensdk/opensdk)
     endif()
+    # ASCEND_INSTALLED_PATH gets its own guard: it feeds the include/ and include/acl
+    # directories below and must not stay empty when -DASCEND_OPENSDK_DIR is passed.
+    if (NOT DEFINED ASCEND_INSTALLED_PATH)
+        set(ASCEND_INSTALLED_PATH $ENV{ASCEND_CUSTOM_PATH})
+    endif()
     include(${CMAKE_CURRENT_LIST_DIR}/cmake/nlohmann_json.cmake)
     include(${CMAKE_CURRENT_LIST_DIR}/cmake/secure_c.cmake)
     include(${CMAKE_CURRENT_LIST_DIR}/cmake/tensorflow.cmake)
     include_directories(${CMAKE_CURRENT_LIST_DIR})
-    include_directories(${ASCEND_OPENSDK_DIR}/include)
-    include_directories(${ASCEND_OPENSDK_DIR}/include/aoe)
-    include_directories(${ASCEND_OPENSDK_DIR}/include/slog)
-    include_directories(${ASCEND_OPENSDK_DIR}/include/runtime)
-    include_directories(${ASCEND_OPENSDK_DIR}/include/ascendcl/external)
-    include_directories(${ASCEND_OPENSDK_DIR}/include/msprof)
-    include_directories(${ASCEND_OPENSDK_DIR}/include/parser)
-    include_directories(${ASCEND_OPENSDK_DIR}/include/parser/external)
-    include_directories(${ASCEND_OPENSDK_DIR}/include/air)
-    include_directories(${ASCEND_OPENSDK_DIR}/include/air/external)
-    include_directories(${ASCEND_OPENSDK_DIR}/include/metadef)
-    include_directories(${ASCEND_OPENSDK_DIR}/include/metadef/external)
+    include_directories(${CMAKE_CURRENT_LIST_DIR}/inc/aoe)
+    include_directories(${CMAKE_CURRENT_LIST_DIR}/inc/slog)
+    include_directories(${CMAKE_CURRENT_LIST_DIR}/inc/runtime)
+    include_directories(${ASCEND_INSTALLED_PATH}/include)
+    include_directories(${ASCEND_INSTALLED_PATH}/include/acl)
+    include_directories(${CMAKE_CURRENT_LIST_DIR}/inc/msprof)
+    include_directories(${CMAKE_CURRENT_LIST_DIR}/inc/parser)
+    include_directories(${CMAKE_CURRENT_LIST_DIR}/inc/parser/external)
+    include_directories(${CMAKE_CURRENT_LIST_DIR}/inc/air)
+    include_directories(${CMAKE_CURRENT_LIST_DIR}/inc/air/external)
+    include_directories(${CMAKE_CURRENT_LIST_DIR}/inc/metadef)
+    include_directories(${CMAKE_CURRENT_LIST_DIR}/inc/metadef/external)
     include_directories(${CMAKE_CURRENT_LIST_DIR}/inc)
     if (NOT EXISTS ${CMAKE_CURRENT_LIST_DIR}/tools/COMPILE_FLAGS OR NOT EXISTS
diff --git a/build.sh b/build.sh
index b557f98a935366dff2cb4fe1291ac2730c4adb04..4fef5d892e23c5af05140822c7988e30e0f7407b 100755
--- a/build.sh
+++ b/build.sh
@@ -25,7 +25,7 @@ RELEASE_TARGET="tfadapter.tar"
 # print usage message
 usage() {
   echo "Usage:"
-  echo " bash build.sh [-h] [-j[n]] [-v] [-g] [-u] [-s] [-c] [-x]"
+  echo " bash build.sh [-h] [-j[n]] [-v] [-g] [-u] [-s] [-c] [-x] [-p] [-q]"
   echo ""
   echo "Options:"
   echo " -h Print usage"
@@ -36,6 +36,8 @@ usage() {
   echo " -s TF_adapter stest"
   echo " -c TF_adapter ci build"
   echo " -x TF_adapter2.x ci build"
+  echo " -p TF_adapter build"
+  echo " -q TF_adapter2.x build"
   echo "to be continued ..."
 }
@@ -52,8 +54,10 @@ checkopts() {
   ENABLE_TFADAPTER_ST="off"
   ENABLE_CI_BUILD="off"
   ENABLE_2X_CI_BUILD="off"
+  ENABLE_1X_BUILD="off"
+  ENABLE_2X_BUILD="off"
   # Process the options
-  while getopts 'hj:vuscg:x' opt
+  while getopts 'hj:vuscg:xpq' opt
   do
     case "${opt}" in
       h) usage
@@ -65,6 +69,8 @@ checkopts() {
       s) ENABLE_TFADAPTER_ST="on" ;;
       c) ENABLE_CI_BUILD="on" ;;
       x) ENABLE_2X_CI_BUILD="on" ;;
+      p) ENABLE_CI_BUILD="on" && ENABLE_1X_BUILD="on" ;;
+      q) ENABLE_2X_BUILD="on" ;;
       *) logging "Undefined option: ${opt}"
          usage
          exit 1 ;;
@@ -114,14 +120,22 @@ build_tfadapter() {
   else
     make ${VERBOSE} -j${THREAD_NUM}
     logging "tfadapter build success!"
-    chmod +x "${BASE_PATH}/tf_adapter_2.x/CI_Build"
-    sh "${BASE_PATH}/tf_adapter_2.x/CI_Build"
+    if [[ "X$ENABLE_1X_BUILD" = "Xoff" ]]; then
+      chmod +x "${BASE_PATH}/tf_adapter_2.x/CI_Build"
+      sh "${BASE_PATH}/tf_adapter_2.x/CI_Build"
+    fi
   fi
 }
 release_tfadapter() {
   logging "Create output directory"
-  cd ${CMAKE_PATH}/dist/python/dist && mkdir -p fwkplugin/bin && mv npu_bridge-*.whl fwkplugin/bin && mv "${BASE_PATH}/tf_adapter_2.x/build/dist/python/dist/npu_device-0.1-py3-none-any.whl" fwkplugin/bin && tar cfz "${RELEASE_TARGET}" * && mv "${RELEASE_TARGET}" "${RELEASE_PATH}"
+  cd "${CMAKE_PATH}/dist/python/dist" && mkdir -p fwkplugin/bin && mv npu_bridge-*.whl fwkplugin/bin && {
+    npu_device_file="${BASE_PATH}/tf_adapter_2.x/build/dist/python/dist/npu_device-0.1-py3-none-any.whl"
+    # The npu_bridge mv stays in the && chain so its failure aborts packaging; npu_device is optional (e.g. absent after -p).
+    if [ -f "${npu_device_file}" ]; then
+      mv "${npu_device_file}" fwkplugin/bin
+    fi
+  } && tar cfz "${RELEASE_TARGET}" * && mv "${RELEASE_TARGET}" "${RELEASE_PATH}"
 }
 main() {
@@ -133,6 +147,10 @@ main() {
   if [[ "X$ENABLE_2X_CI_BUILD" = "Xon" ]]; then
     chmod +x "${BASE_PATH}/tf_adapter_2.x/CI_Build"
     sh "${BASE_PATH}/tf_adapter_2.x/CI_Build" "python3.7"
+  elif [[ "X$ENABLE_2X_BUILD" = "Xon" ]]; then
+    logging "build for tf2.x"
+    chmod +x "${BASE_PATH}/tf_adapter_2.x/build_tf2.sh"
+    sh "${BASE_PATH}/tf_adapter_2.x/build_tf2.sh" "python3.7"
   else
     build_tfadapter
   fi
diff --git a/cmake/nlohmann_json.cmake b/cmake/nlohmann_json.cmake
index a059581f521bbf04a94a66a552d899c2599a64cb..13098408e6fdd5e59b6e43f8ef25893ca678646d 100755
--- a/cmake/nlohmann_json.cmake
+++ b/cmake/nlohmann_json.cmake
@@ -1,7 +1,5 @@
 include(FetchContent)
-set(_json_url "")
 if(TF_PKG_SERVER)
-  set(_json_url "${TF_PKG_SERVER}/libs/json/v3.6.1/include.zip")
   FetchContent_Declare(
     nlohmann_json
     URL https://ascend-cann.obs.myhuaweicloud.com/json/repository/archive/json-v3.10.1.zip
diff --git a/cmake/secure_c.cmake b/cmake/secure_c.cmake
index 7b00a4890f3cecdba5bd6496aba6eab1bf512c59..cfa9f6dc402a1ff069738182ca9d8e0b14308fbb 100755
--- a/cmake/secure_c.cmake
+++ b/cmake/secure_c.cmake
@@ -1,10 +1,8 @@
 include(FetchContent)
-set(_json_url "")
 if(TF_PKG_SERVER)
-  set(_json_url "${TF_PKG_SERVER}/libs/securec/v1.1.10.tar.gz")
   FetchContent_Declare(
     secure_c
-    URL ${_json_url}
+    URL "${TF_PKG_SERVER}/libs/securec/v1.1.10.tar.gz"
   )
 else()
   FetchContent_Declare(
diff --git a/cmake/tensorflow.cmake b/cmake/tensorflow.cmake
index d7b87279f1853b883df58d2b251c75a1783a000f..5213b23a68d71aef3eee58400fe25d4c95042ae7 100755
--- a/cmake/tensorflow.cmake
+++ b/cmake/tensorflow.cmake
@@ -1,11 +1,8 @@
 include(FetchContent)
-set(_json_url "")
 if(TF_PKG_SERVER)
-  set(_json_url "${TF_PKG_SERVER}/libs/tensorflow/v1.15.0.zip")
   FetchContent_Declare(
     tensorflow
-    URL ${_json_url}
-    URL_HASH MD5=0ad811d8d59f56ecc1a6032af997ec1d
+    URL "${TF_PKG_SERVER}/libs/tensorflow/v1.15.0.zip"
   )
 else()
   FetchContent_Declare(
diff --git a/inc/air/external/ge/ge_api.h b/inc/air/external/ge/ge_api.h
new file mode 100644
index 0000000000000000000000000000000000000000..5d42fca85d947413544109cee150dda81e1c714b
--- /dev/null
+++ b/inc/air/external/ge/ge_api.h
@@ -0,0 +1,477 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_GE_GE_API_H_ +#define INC_EXTERNAL_GE_GE_API_H_ + +#include +#include +#include +#include + +#include "ge/ge_api_error_codes.h" +#include "ge/ge_api_types.h" +#include "ge/ge_data_flow_api.h" +#include "ge/ge_continuous_tensor_list_api.h" +#include "ge/ge_graph_compile_summary.h" +#include "graph/graph.h" +#include "graph/tensor.h" +#include "ge/ge_allocator.h" +#include "exe_graph/runtime/tensor.h" +namespace ge { +typedef uint32_t (*pCallBackFunc)(uint32_t graph_id, const std::map ¶ms_list); + +namespace session { +typedef uint32_t (*pCallBackFunc)(uint32_t graph_id, const std::map ¶ms_list); +} + +// Initialize GE +ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY Status GEInitialize(const std::map &)) +GE_FUNC_VISIBILITY Status GEInitialize(const std::map &options); + +GE_FUNC_VISIBILITY Status GEInitialize(const std::map &options); + +// Finalize GE, release all resources +GE_FUNC_VISIBILITY Status GEFinalize(); + +GE_FUNC_VISIBILITY std::string GEGetErrorMsg(); + +GE_FUNC_VISIBILITY ge::AscendString GEGetErrorMsgV2(); + +GE_FUNC_VISIBILITY std::string GEGetWarningMsg(); + +GE_FUNC_VISIBILITY ge::AscendString GEGetWarningMsgV2(); + +GE_FUNC_VISIBILITY Status GetModelDistributeDesc(const void *data, const uint64_t length, + ModelDistibuteDesc &model_dist_desc); + +class GE_FUNC_VISIBILITY Session { + public: + ATTRIBUTED_DEPRECATED(Session(const std::map &)) + explicit Session(const std::map &options); + + explicit Session(const std::map &options); + + ~Session(); + + /// + /// @ingroup client + /// @brief add a graph with a specific graph id + /// @param [in] graph_id graph id + /// @return Status result of function + /// + Status AddGraph(uint32_t graph_id, const Graph &graph); + + /// + /// @ingroup client + /// @brief add a graph with a specific graph id and graphOptions + /// @param [in] graphId graph id + /// @param [in] 
graph the graph + /// @param [in] options graph options + /// @return Status result of function + /// + ATTRIBUTED_DEPRECATED(Status AddGraph(uint32_t, const Graph &, const std::map &)) + Status AddGraph(uint32_t graph_id, const Graph &graph, const std::map &options); + + /// + /// @ingroup client + /// @brief add a graph with a specific graphId and graphOptions + /// @param [in] graphId graph id + /// @param [in] graph the graph + /// @param [in] options graph options + /// @return Status result of function + /// + Status AddGraph(uint32_t graph_id, const Graph &graph, const std::map &options); + + /// + /// @ingroup client + /// @brief add a copy graph with a specific graphId + /// @param [in] graphId graph id + /// @param [in] graph the graph + /// @return Status result of function + /// + Status AddGraphWithCopy(uint32_t graph_id, const Graph &graph); + + /// + /// @ingroup client + /// @brief add a copy graph with a specific graphId and graphOptions + /// @param [in] graphId graph id + /// @param [in] graph the graph + /// @param [in] options graph options + /// @return Status result of function + /// + Status AddGraphWithCopy(uint32_t graph_id, const Graph &graph, const std::map &options); + + /// + /// @ingroup ge_graph + /// @brief remove a graph of the session with specific session id + /// @param [in] graph_d graph id + /// @return Status result of function + /// + Status RemoveGraph(uint32_t graph_id); + + /// + /// @ingroup ge_graph + /// @brief run a graph of the session with specific session id + /// @param [in] graphId graph id + /// @param [in] inputs input data + /// @param [out] outputs output data + /// @return Status result of function + /// + Status RunGraph(uint32_t graph_id, const std::vector &inputs, std::vector &outputs); + + /// + /// @ingroup ge_graph + /// @brief run a graph of the session with specific session id + /// @param [in] graphId graph id + /// @param [in] inputs input data list + /// @param [out] outputs output data + /// 
@return Status result of function + /// + Status RunGraph(uint32_t graph_id, const ContinuousTensorList &inputs, std::vector &outputs); + + /// + /// @ingroup ge_graph + /// @brief run a distributed graph of the session with specific session id + /// @param [in] graphId graph id + /// @param [in] device id respond to inputs input data + /// @param [out] device id respond to outputs output data + /// @return Status result of function + /// + Status RunGraphDistribute(uint32_t graph_id, const std::map> &device_to_inputs, + std::map> &device_to_outputs); + + /// + /// @ingroup ge_graph + /// @brief Load graph from om + /// @param [in] graphId graph id + /// @param [in] options graph options + /// @param [in] om_file_path offline om path + /// @return Status result of function + /// + /* 规避方案:规避acl系列接口不支持拉远环境;aclModelLoad接口当前不支持分布式的模型加载,不支持session管理权重变量复用 + 根据20230615 SEG会议纪要可以在session中开接口作为临时方案 + 方案详述:acl接口不支持拉远形态,不支持外置权重。使用该接口实现在session加载离线om。 + 1.将om根据指定om_file_path直接加载到ModelManager中 + 2.modelManager判断是异构模型,调用异构部署函数,生成deloyplan + 3.ModelManger提供接口返回modelid和flowModelPtr GraphPtr,InnerSession根据返回信息注册在GraphManager中生成graphnode + 4.当前调用进程可以通过RunGraph接口去执行加载的离线模型 + 方案约束:只用于加载异构离线模型,即存储格式为flowmodel下包含一个或多个submodel和modelrelation的离线模型; + 不支持包含variable的离线模型,variable的值在device侧目前还没有保存到om中的方案 + */ + Status LoadGraph(const uint32_t graph_id, const std::map &options, + const std::string &om_file_path) const; + + /// + /// @ingroup ge_graph + /// @brief load a graph of the session with specific session id and specific stream, + /// @param [in] graph_id graph id + /// @param [in] stream specific stream + /// @param [in] options graph options + /// @return Status result of function + /// + /* + 方案约束:只用于加载已经完成CompileGraph的图,不支持重复加载图, + 使用方法:CompileGraph + LoadGraph + ExecuteGraphWithStreamAsync + */ + Status LoadGraph(const uint32_t graph_id, const std::map &options, + void *stream) const; + + /// + /// @ingroup ge_graph + /// @brief run a graph of the session with specific session id 
and specific stream asynchronously + /// @param [in] graph_id graph id + /// @param [in] stream specific stream + /// @param [in] inputs input data + /// @param [out] outputs output data + /// @return Status result of function + /// + Status RunGraphWithStreamAsync(uint32_t graph_id, void *stream, const std::vector &inputs, + std::vector &outputs); + + Status ExecuteGraphWithStreamAsync(uint32_t graph_id, void *stream, const std::vector &inputs, + std::vector &outputs); + + /// + /// @ingroup ge_graph + /// @brief build graph in the session with specific session id + /// @param [in] graphId: graph id + /// @param [in] inputs: input data + /// @return Status result of function + /// + Status BuildGraph(uint32_t graph_id, const std::vector &inputs); + + Status BuildGraph(uint32_t graph_id, const std::vector &inputs); /*lint !e148*/ + + /// + /// @ingroup ge_graph + /// @brief run graph in the session with specific session id asynchronously + /// @param [in] graphId: graph id + /// @param [in] inputs: input data + /// @param [out] callback: callback while runing graph has been finished. + /// The callback function will not be checked. + /// Please ensure that the implementation of the function is trusted. + /// @return Status result of function + /// + Status RunGraphAsync(uint32_t graph_id, const std::vector &inputs, RunAsyncCallback callback); + + /// + /// @ingroup ge_graph + /// @brief run graph in the session with specific session id asynchronously + /// @param [in] graphId: graph id + /// @param [in] inputs: input data list + /// @param [out] callback: callback while runing graph has been finished. + /// The callback function will not be checked. + /// Please ensure that the implementation of the function is trusted. 
+ /// @return Status result of function + /// + Status RunGraphAsync(uint32_t graph_id, const ContinuousTensorList &inputs, RunAsyncCallback callback); + + /// + /// @ingroup ge_graph + /// @brief get variables in the session with specific session id + /// @param [in] var_names: variable names + /// @param [out] var_values: variable values + /// @return Status result of function + /// + ATTRIBUTED_DEPRECATED(Status GetVariables(const std::vector &, std::vector &)) + Status GetVariables(const std::vector &var_names, std::vector &var_values); + + /// + /// @ingroup ge_graph + /// @brief get variables in the session with specific session id + /// @param [in] var_names: variable names + /// @param [out] var_values: variable values + /// @return Status result of function + /// + Status GetVariables(const std::vector &var_names, std::vector &var_values); + + /// + /// @ingroup ge_graph + /// @brief register callback func with specific summary or checkpoint by users + /// @param [in] key: func key + /// @param [in] callback: callback specific summary or checkpoint. + /// The callback function will not be checked. + /// Please ensure that the implementation of the function is trusted. + /// @return Status result of function + /// + ATTRIBUTED_DEPRECATED(Status RegisterCallBackFunc(const char *, const session::pCallBackFunc &)) + Status RegisterCallBackFunc(const std::string &key, const pCallBackFunc &callback); + + Status RegisterCallBackFunc(const char *key, const session::pCallBackFunc &callback); + + bool IsGraphNeedRebuild(uint32_t graph_id); + + uint64_t GetSessionId() const; + + /// @ingroup ge_graph + /// @brief Feed input data to graph. 
+ /// @param [in] graph_id graph id + /// @param [in] inputs input data + /// @param [in] info intput data flow flag + /// @param [in] timeout data feed timeout(ms), -1 means never timeout + /// @return Status result of function + Status FeedDataFlowGraph(uint32_t graph_id, const std::vector &inputs, const DataFlowInfo &info, + int32_t timeout); + + /// @ingroup ge_graph + /// @brief Feed input data to graph. + /// @param [in] graph_id graph id + /// @param [in] indexes fetch output data order(index cannot be duplicated) + /// @param [in] inputs input data + /// @param [in] info intput data flow flag + /// @param [in] timeout data feed timeout(ms), -1 means never timeout + /// @return Status result of function + Status FeedDataFlowGraph(uint32_t graph_id, const std::vector &indexes, const std::vector &inputs, + const DataFlowInfo &info, int32_t timeout); + + /// @ingroup ge_graph + /// @brief Feed input data to graph. + /// @param [in] graph_id graph id + /// @param [in] raw_data_list can be 1 or n, feed will be combine n raw data automatically + /// @param [in] indexes feed input index + /// @param [in] info intput data flow flag + /// @param [in] timeout data feed timeout(ms), -1 means never timeout + /// @return Status result of function + Status FeedRawData(uint32_t graph_id, const std::vector &raw_data_list, const uint32_t index, + const DataFlowInfo &info, int32_t timeout); + + /// @ingroup ge_graph + /// @brief Fetch graph output data in order. + /// @param [in] graph_id graph id + /// @param [out] outputs output data + /// @param [out] info output data flow flag + /// @param [in] timeout data fetch timeout(ms), -1 means never timeout + /// @return Status result of function + Status FetchDataFlowGraph(uint32_t graph_id, std::vector &outputs, DataFlowInfo &info, int32_t timeout); + + /// @ingroup ge_graph + /// @brief Fetch graph output data in order. 
+ /// @param [in] graph_id graph id + /// @param [in] indexes fetch output data order(index cannot be duplicated) + /// @param [out] outputs output data + /// @param [out] info output data flow flag + /// @param [in] timeout data ftech timeout(ms), -1 means never timeout + /// @return Status result of function + Status FetchDataFlowGraph(uint32_t graph_id, const std::vector &indexes, std::vector &outputs, + DataFlowInfo &info, int32_t timeout); + + /// @ingroup ge_graph + /// @brief Feed input data to graph. + /// @param [in] graph_id graph id + /// @param [in] inputs input data + /// @param [in] timeout data feed timeout(ms), -1 means never timeout + /// @return Status result of function + Status FeedDataFlowGraph(uint32_t graph_id, const std::vector &inputs, int32_t timeout); + + /// @ingroup ge_graph + /// @brief Feed input data to graph. + /// @param [in] graph_id graph id + /// @param [in] indexes fetch output data order(index cannot be duplicated) + /// @param [in] inputs input data + /// @param [in] timeout data feed timeout(ms), -1 means never timeout + /// @return Status result of function + Status FeedDataFlowGraph(uint32_t graph_id, const std::vector &indexes, + const std::vector &inputs, int32_t timeout); + + /// @ingroup ge_graph + /// @brief Fetch graph output data in order. + /// @param [in] graph_id graph id + /// @param [out] outputs output data + /// @param [in] timeout data fetch timeout(ms), -1 means never timeout + /// @return Status result of function + Status FetchDataFlowGraph(uint32_t graph_id, std::vector &outputs, int32_t timeout); + + /// @ingroup ge_graph + /// @brief Fetch graph output data in order. 
+ /// @param [in] graph_id graph id + /// @param [in] indexes fetch output data order(index cannot be duplicated) + /// @param [out] outputs output data + /// @param [in] timeout data ftech timeout(ms), -1 means never timeout + /// @return Status result of function + Status FetchDataFlowGraph(uint32_t graph_id, const std::vector &indexes, + std::vector &outputs, int32_t timeout); + + /// @ingroup ge_graph + /// @brief compile graph in the session with specific session id + /// @param [in] graphId: graph id + /// @return Status result of function + Status CompileGraph(uint32_t graph_id); + + /// + /// @ingroup ge_graph + /// @brief get graph resource summary after compiled + /// @param [in] graphId: graph id + /// @return share_ptr of CompiledGraphSummary + /// + CompiledGraphSummaryPtr GetCompiledGraphSummary(uint32_t graph_id); + + /// + /// @ingroup ge_graph + /// @brief set const memory base after compiled and before loaded, only allows setting once + /// @param [in] graphId graph id + /// @param [in] memory const memory base + /// @param [out] size const memory size + /// @return Status result of function + /// + Status SetGraphConstMemoryBase(uint32_t graph_id, const void *const memory, size_t size); + + /// + /// @ingroup ge_graph + /// @brief set or update fearture memory base after compiled + /// @param [in] graphId graph id + /// @param [in] memory feature map memory base, without input and output mem + /// @param [out] size feature map memory size + /// @return Status result of function + /// + Status UpdateGraphFeatureMemoryBase(uint32_t graph_id, const void *const memory, size_t size); + + /// + /// @ingroup ge_graph + /// @brief set fix feature memory base after compiled and before loaded, only allows setting once + /// @param [in] graphId graph id + /// @param [in] memory const memory base + /// @param [out] size const memory size + /// @return Status result of function + /// + Status SetGraphFixedFeatureMemoryBase(uint32_t graph_id, const void *const 
memory, size_t size); + + /// + /// @ingroup ge_graph + /// @brief set fix feature memory base with type after compiled and before loaded, one type only allows setting once + /// @param [in] graph_id graph id + /// @param [in] type memory type + /// @param [in] memory const memory base + /// @param [out] size const memory size + /// @return Status result of function + /// + Status SetGraphFixedFeatureMemoryBaseWithType(uint32_t graph_id, MemoryType type, const void *const memory, + size_t size); + + /// + /// @ingroup ge_graph + /// @brief set or update tefreshable fearture memory base after compiled, not include fix memory + /// @param [in] graphId graph id + /// @param [in] memory feature map memory base, without input and output mem + /// @param [out] size feature map memory size + /// @return Status result of function + /// + Status UpdateGraphRefreshableFeatureMemoryBase(uint32_t graph_id, const void *const memory, size_t size); + + /// @ingroup ge_graph + /// @brief register external allocator to GE. + /// @param [in] stream stream handle + /// @param [in] allocator_obj allocator object handle + /// @return Status result of function + Status RegisterExternalAllocator(const void *const stream, AllocatorPtr allocator) const; + + /// @ingroup ge_graph + /// @brief unregister external allocator to GE. 
+ /// @param [in] stream stream handle + /// @return Status result of function + Status UnregisterExternalAllocator(const void *const stream) const; + + /// @ingroup ge_graph + /// @brief shard graphs in the session according the add graph sequence + /// @return Status result of function + Status ShardGraphsToFile(const char_t *file_path = "./") const; + + /// @ingroup ge_graph + /// @brief shard graphs in the session according the add graph sequence + /// @return Status result of function + Status ShardGraphs() const; + + /// @ingroup ge_graph + /// @brief save graphs in the session with specific file path + /// @return Status result of function + Status SaveGraphsToPb(const char_t *file_path) const; + + /// @ingroup ge_graph + /// @brief virtual memory need to remap + /// @param [in] va virtual memory + /// @param [in] new_pa new physical memory + /// @param [in] len the lens of va to remap + /// @return Status result of function + Status PaRemapped(const uint64_t va, const uint64_t new_pa, const uint64_t len) const; + + private: + uint64_t sessionId_{0}; +}; +} // namespace ge +extern "C" { +ge::Status GeSessionLoadGraph(ge::Session &session, uint32_t graph_id, + const std::map &options, + void *stream); + +ge::Status GeSessionExecuteGraphWithStreamAsync(ge::Session &session, uint32_t graph_id, void *stream, + const std::vector &inputs, + std::vector &outputs); +} // extern "C" + +#endif // INC_EXTERNAL_GE_GE_API_H_ diff --git a/inc/air/external/ge/ge_api_error_codes.h b/inc/air/external/ge/ge_api_error_codes.h new file mode 100644 index 0000000000000000000000000000000000000000..b9a1e7175409c196422cbe8e213d256972940c1a --- /dev/null +++ b/inc/air/external/ge/ge_api_error_codes.h @@ -0,0 +1,106 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. 
 You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_EXTERNAL_GE_GE_API_ERROR_CODES_H_
+#define INC_EXTERNAL_GE_GE_API_ERROR_CODES_H_
+
+#include
+#include
+#include "ge_error_codes.h"
+#include "ge_api_types.h"
+
+#ifdef __GNUC__
+#define ATTRIBUTED_DEPRECATED(replacement) __attribute__((deprecated("Please use " #replacement " instead.")))
+#else
+#define ATTRIBUTED_DEPRECATED(replacement) __declspec(deprecated("Please use " #replacement " instead."))
+#endif
+#ifndef GE_ERRORNO_EXTERNAL
+#define GE_ERRORNO_EXTERNAL(name, desc) const ErrorNoRegisterar g_errorno_##name((name), (desc))
+#endif
+#ifndef GE_ERRORNO  // NOTE: this guard stays open across namespace ge below and is closed by the bare #endif near the end of the file
+// Code compose(4 byte), runtime: 2 bit, type: 2 bit, level: 3 bit, sysid: 8 bit, modid: 5 bit, value: 12 bit. NOTE(review): every field except value is masked with 0xFFU regardless of its documented width - confirm callers pass in-range values.
+#define GE_ERRORNO(runtime, type, level, sysid, modid, name, value, desc) \
+  constexpr ge::Status name = (static_cast(0xFFU & (static_cast(runtime))) << 30U) | \
+                              (static_cast(0xFFU & (static_cast(type))) << 28U) | \
+                              (static_cast(0xFFU & (static_cast(level))) << 25U) | \
+                              (static_cast(0xFFU & (static_cast(sysid))) << 17U) | \
+                              (static_cast(0xFFU & (static_cast(modid))) << 12U) | \
+                              (static_cast(0x0FFFU) & (static_cast(value))); \
+  const ErrorNoRegisterar g_errorno_##name((name), (desc))
+
+
+namespace ge {
+class GE_FUNC_VISIBILITY StatusFactory {  // Registry mapping error codes to their description strings
+ public:
+  static StatusFactory *Instance() {  // Returns the single shared registry
+    static StatusFactory instance;  // function-local static, constructed on first call
+    return &instance;
+  }
+
+  void RegisterErrorNo(const uint32_t err, const std::string &desc) {
+    // Keep the first description registered for a code; later registrations are ignored
+    if (err_desc_.find(err) != err_desc_.end()) {
+      return;
+    }
+    err_desc_[err] = desc;
+  }
+
+  void RegisterErrorNo(const uint32_t err, const char *const desc) {
+    if (desc == nullptr) {  // nothing to register without a description
+      return;
+    }
+    const std::string error_desc = desc;
+    if (err_desc_.find(err) != err_desc_.end()) {  // first registration wins, as in the std::string overload
+      return;
+    }
+    err_desc_[err] = error_desc;
+  }
+
+  std::string GetErrDesc(const uint32_t err) {  // Description registered for err, or "" when none is registered
+    const auto iter_find = static_cast::const_iterator>(err_desc_.find(err));
+    if (iter_find == err_desc_.cend()) {
+      return "";
+    }
+    return iter_find->second;
+  }
+
+  AscendString GetErrDescV2(const uint32_t err) {  // Same lookup as GetErrDesc, returned as AscendString
+    const std::map::const_iterator iter_find = err_desc_.find(err);
+    if (iter_find == err_desc_.cend()) {
+      return AscendString("");
+    }
+    return AscendString(iter_find->second.c_str());
+  }
+
+ protected:
+  StatusFactory() = default;  // non-public: instances are obtained through Instance()
+  ~StatusFactory() = default;
+
+ private:
+  std::map err_desc_;  // error code -> description
+};
+
+class GE_FUNC_VISIBILITY ErrorNoRegisterar {  // Helper whose constructors register (err, desc) with StatusFactory
+ public:
+  ErrorNoRegisterar(const uint32_t err, const std::string &desc) noexcept {
+    StatusFactory::Instance()->RegisterErrorNo(err, desc);
+  }
+  ErrorNoRegisterar(const uint32_t err, const char *const desc) noexcept {
+    StatusFactory::Instance()->RegisterErrorNo(err, desc);
+  }
+  ~ErrorNoRegisterar() = default;
+};
+
+// General error code
+GE_ERRORNO(0, 0, 0, 0, 0, SUCCESS, 0, "success");
+GE_ERRORNO(0b11, 0b11, 0b111, 0xFFU, 0b11111, FAILED, 0xFFFU, "failed"); /*lint !e401*/
+
+} // namespace ge
+#endif  // GE_ERRORNO
+#endif // INC_EXTERNAL_GE_GE_API_ERROR_CODES_H_
diff --git a/inc/air/external/ge/ge_api_types.h b/inc/air/external/ge/ge_api_types.h
new file mode 100644
index 0000000000000000000000000000000000000000..5a71643554f6948299b7aad9181de48bd27f2636
--- /dev/null
+++ b/inc/air/external/ge/ge_api_types.h
@@ -0,0 +1,679 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details.
You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_GE_GE_API_TYPES_H_ +#define INC_EXTERNAL_GE_GE_API_TYPES_H_ + +#include +#include +#include +#include +#include +#include +#include "graph/tensor.h" +#include "graph/types.h" + +#ifndef GE_API_TYPES_DEF +#define GE_API_TYPES_DEF +namespace ge { +// Option key: graph run mode +const char_t *const OPTION_GRAPH_RUN_MODE = "ge.graphRunMode"; +const char_t *const OPTION_DEVICE_TYPE = "ge.deviceType"; +// Option key: ome init +const char_t *const OPTION_EXEC_SESSION_ID = "ge.exec.sessionId"; +const char_t *const OPTION_EXEC_DEVICE_ID = "ge.exec.deviceId"; +const char_t *const OPTION_EXEC_JOB_ID = "ge.exec.jobId"; +const char_t *const OPTION_EXEC_IS_USEHCOM = "ge.exec.isUseHcom"; +const char_t *const OPTION_EXEC_IS_USEHVD = "ge.exec.isUseHvd"; +const char_t *const OPTION_EXEC_RANK_ID = "ge.exec.rankId"; +const char_t *const OPTION_EXEC_POD_NAME = "ge.exec.podName"; +const char_t *const OPTION_EXEC_DEPLOY_MODE = "ge.exec.deployMode"; +const char_t *const OPTION_EXEC_RANK_TABLE_FILE = "ge.exec.rankTableFile"; +const char_t *const OPTION_EXEC_EXTERN_PLUGIN_PATH = "ge.soLoadPath"; +// Dump flag and para +const char_t *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; +const char_t *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; +const char_t *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep"; +const char_t *const OPTION_EXEC_DUMP_MODE = "ge.exec.dumpMode"; +const char_t *const OPTION_EXEC_DUMP_DATA = "ge.exec.dumpData"; +const char_t *const OPTION_EXEC_DUMP_LAYER = 
"ge.exec.dumpLayer"; +const char_t *const OPTION_EXEC_ENABLE_DUMP_DEBUG = "ge.exec.enableDumpDebug"; +const char_t *const OPTION_EXEC_DUMP_DEBUG_MODE = "ge.exec.dumpDebugMode"; +const char_t *const OPTION_EXEC_ENABLE_INCRE_BUILD = "ge.exec.enableIncreBuild"; +const char_t *const OPTION_EXEC_INCRE_BUILD_CACHE_PATH = "ge.exec.increBuildCachePath"; +const char_t *const OPTION_EXEC_ENABLE_EXCEPTION_DUMP = "ge.exec.enable_exception_dump"; +const char_t *const OPTION_EXEC_ENABLE_SCOPE_FUSION_PASSES = "ge.exec.enableScopeFusionPasses"; +const char_t *const OPTION_EXEC_PROFILING_FPPONIT_OPTIONS = "ge.exec.profilingFpPointOptions"; +const char_t *const OPTION_EXEC_PROFILING_BPPONIT_OPTIONS = "ge.exec.profilingBpPointOptions"; +// profiling flag +const char_t *const OPTION_EXEC_PROFILING_MODE = "ge.exec.profilingMode"; +const char_t *const OPTION_EXEC_PROFILING_OPTIONS = "ge.exec.profilingOptions"; +// Hccl flag, if ge.exec.hcclFlag =1, it means load plugin for opskernel, else:ge.exec.hcclFlag =0 +const char_t *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag"; +const char_t *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic"; +const char_t *const OPTION_EXEC_DISABLE_REUSED_MEMORY = "ge.exec.disableReuseMemory"; +const char_t *const OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION = "ge.exec.isTailingOptimization"; +// Dynamic input flag. 
ge.exec.dynamicInput=1, means enable dynaimc input, +// ge.exec.dynamicGraphExecuteMode, dynamic_execute[default] +const char_t *const OPTION_EXEC_DYNAMIC_INPUT = "ge.exec.dynamicInput"; +const char_t *const OPTION_EXEC_DYNAMIC_EXECUTE_MODE = "ge.exec.dynamicGraphExecuteMode"; +const char_t *const OPTION_EXEC_DATA_INPUTS_SHAPE_RANGE = "ge.exec.dataInputsShapeRange"; +const char_t *const OPTION_EXEC_ENABLE_COPY_OUTPUT_ADDR = "ge.exec.enableCopyOutputAddr"; +const char_t *const OPTION_EXEC_GRAPH_EXEC_TIMEOUT = "ge.exec.graphExecTimeout"; +const char_t *const OPTION_EXEC_OPTIMIZE_SHAPE = "ge.exec.dataInputsOptimizeShape"; + +// Option key: memory init +const char_t *const GRAPH_MEMORY_MAX_SIZE = "ge.graphMemoryMaxSize"; +const char_t *const VARIABLE_MEMORY_MAX_SIZE = "ge.variableMemoryMaxSize"; +const char_t *const OPTION_EXEC_REUSE_ZERO_COPY_MEMORY = "ge.exec.reuseZeroCopyMemory"; +const char_t *const OPTION_INPUT_REUSE_MEM_INDEXES = "ge.exec.inputReuseMemIndexes"; +const char_t *const OPTION_OUTPUT_REUSE_MEM_INDEXES = "ge.exec.outputReuseMemIndexes"; +const char_t *const OPTION_GRAPH_IO_MEM_ALLOC_MODE = "ge.exec.graphIOMemAllocMode"; + +const std::string ATOMIC_CLEAN_POLICY = "ge.exec.atomicCleanPolicy"; +const std::string MEMORY_OPTIMIZATION_POLICY = "ge.exec.memoryOptimizationPolicy"; +const std::string STATIC_MEMORY_POLICY = "ge.exec.staticMemoryPolicy"; +const char_t *const OPTION_FEATURE_BASE_REFRESHABLE = "ge.featureBaseRefreshable"; +const char_t *const OPTION_CONST_LIFECYCLE = "ge.constLifecycle"; +const char_t *const OPTION_HOST_SCHEDULING_MAX_THRESHOLD = "ge.exec.hostSchedulingMaxThreshold"; + +const char_t *const OPTION_EXEC_LOGICAL_DEVICE_CLUSTER_DEPLOY_MODE = "ge.exec.logicalDeviceClusterDeployMode"; +const char_t *const OPTION_EXEC_LOGICAL_DEVICE_ID = "ge.exec.logicalDeviceId"; +const char_t *const OPTION_EXEC_MODEL_DEPLOY_MODE = "ge.exec.modelDeployMode"; +const char_t *const OPTION_EXEC_MODEL_DEPLOY_DEVICELIST = "ge.exec.modelDeployDevicelist"; 
+const char_t *const OPTION_EXEC_ENABLE_FUSION = "ge.exec.enableFusion"; + +const std::string OPTION_EXEC_CM_CHIEF_IP = "ge.cmChiefIp"; +const std::string OPTION_EXEC_CM_CHIEF_PORT = "ge.cmChiefPort"; +const std::string OPTION_EXEC_CM_CHIEF_DEVICE = "ge.cmChiefWorkerDevice"; +const std::string OPTION_EXEC_CM_WORKER_IP = "ge.cmWorkerIp"; +const std::string OPTION_EXEC_CM_WORKER_SIZE = "ge.cmWorkerSize"; + +const std::string OPTION_NAME_MAP = "ge.optionNameMap"; + +const std::string OPTION_EXEC_STREAM_SYNC_TIMEOUT = "stream_sync_timeout"; +const std::string OPTION_EXEC_EVENT_SYNC_TIMEOUT = "event_sync_timeout"; + +// Option key: embedding service +const char_t *const OPTION_EXEC_PS_ID = "ge.exec.psId"; +const char_t *const OPTION_EXEC_CLUSTER_SPEC = "ge.exec.clusterSpec"; +const char_t *const OPTION_EXEC_RANK_TABLE_ADDR = "ge.exec.rankTableAddr"; +const char_t *const OPTION_EXEC_ROLE_TABLE_ADDR = "ge.exec.roleTableAddr"; +const char_t *const OPTION_EXEC_RANK_TABLE_LEN = "ge.exec.rankTableLen"; +const char_t *const OPTION_EXEC_ROLE_TABLE_LEN = "ge.exec.roleTableLen"; +const char_t *const OPTION_EXEC_WORKER_NUM = "ge.exec.workerNum"; +const char_t *const OPTION_EXECUTE_TIMES = "execute_times"; +const char_t *const OPTION_MAX_KEY_NUM = "ge.max_num"; +const char_t *const OPTION_EMBEDDING_DIM = "ge.embedding_dim"; +const char_t *const OPTION_USE_COUNTER_FILTER = "ge.use_counter_filter"; +const char_t *const OPTION_ES_MAX_REMOTEOP_NUM_PER_STREAM = "es_max_remoteop_num_per_stream"; + +// Option key: Offload +constexpr char_t const OPTION_EXEC_RANK_MAP[] = "ge.exec.rankMap"; + +// Option key: enable engine parallel or not +constexpr char_t const OPTION_EXEC_ENABLE_ENGINE_PARALLEL[] = "ge.exec.enableEngineParallel"; +constexpr char_t const OPTION_EXEC_ENGINE_PARALLEL_CONFIG_PATH[] = "ge.exec.engineParallelConfigPath"; +constexpr char_t const OPTION_EXEC_IS_IN_SHARD_GRAPH[] = "ge.exec.isInShardGraph"; + +// Option key: host env os & cpu +const char_t *const OPTION_HOST_ENV_OS 
= "ge.host_env_os"; +const char_t *const OPTION_HOST_ENV_CPU = "ge.host_env_cpu"; + +// config value should be a exist dir +const char_t *const OPTION_GRAPH_COMPILER_CACHE_DIR = "ge.graph_compiler_cache_dir"; +// graph unique key +const char_t *const OPTION_GRAPH_KEY = "ge.graph_key"; + +// optimizations to disable, split by comma +const char_t *const OPTION_DISABLE_OPTIMIZATIONS = "ge.disableOptimizations"; +const char_t *const ENABLE_GRAPH_PARALLEL = "ge.enableGraphParallel"; +const char_t *const GRAPH_PARALLEL_OPTION_PATH = "ge.graphParallelOptionPath"; +const std::string DISTRIBUTED_CLUSTER_BUILD = "ge.distributed_cluster_build"; +const std::string MODEL_RELATION_CONFIG = "ge.offline_model_relation"; +const std::string CLUSTER_CONFIG = "ge.cluster_config"; +const std::string OPTION_HCCL_COMPILER_OFFLINE = "ge.offline_hccl_compile"; + +// option for screen log +const char_t *const OPTION_SCREEN_PRINT_MODE = "ge.screen_print_mode"; + +// option for experimental +constexpr char_t const OPTION_STATIC_MODEL_OPS_LOWER_LIMIT[] = "ge.exec.static_model_ops_lower_limit"; + +namespace configure_option { +const char_t *const STREAM_NUM = "ge.streamNum"; +const char_t *const HEAD_STREAM = "ge.headStream"; +const char_t *const PERF_LEVEL = "ge.perfLevel"; +const char_t *const ENCRYPT_MODE = "ge.encryptMode"; +const char_t *const EK_FILE = "ge.ekFile"; +const char_t *const CERT_FILE = "ge.certFile"; +const char_t *const HW_KEY_FILE = "ge.hwKeyFile"; +const char_t *const PRIVATE_KEY_FILE = "ge.privateKeyFile"; +const char_t *const FRAMEWORK_TYPE = "ge.frameworkType"; +const char_t *const CALIBRATION_CONF_FILE = "ge.calibrationConfFile"; +const char_t *const INSERT_OP_FILE = "ge.insertOpFile"; +const char_t *const OUTPUT_NODE_NAME = "ge.outputNodeName"; +const char_t *const COMPRESS_FLAG = "ge.compressFlag"; +const char_t *const PRECISION_MODE = "ge.exec.precision_mode"; +const char_t *const PRECISION_MODE_V2 = "ge.exec.precision_mode_v2"; +const char_t *const SINGLE_OP_FLAG = 
"ge.exec.single_op"; +const char_t *const TRAIN_FLAG = "ge.trainFlag"; +const char_t *const RUN_FLAG = "ge.runFlag"; +const char_t *const LOCAL_FMKOP_FLAG = "ge.enabledLocalFmkop"; +const char_t *const TBE_PLUGIN_PATH_FLAG = "ge.TBE_plugin_path"; +const char_t *const STREAM_MAX_PARALLEL_NUM = "ge.streamMaxParallelNum"; +const char_t *const AC_PARALLEL_ENABLE = "ac_parallel_enable"; +const char_t *const OUTPUT_DATATYPE = "ge.outputDatatype"; +const char_t *const OP_SELECT_IMPL_MODE = "ge.opSelectImplmode"; +const char_t *const OPTYPELIST_FOR_IMPLMODE = "ge.optypelistForImplmode"; +const char_t *const HCOM_PARALLEL = "ge.hcomParallel"; +const char_t *const AUTO_TUNE_MODE = "ge.autoTuneMode"; +const char_t *const SOC_VERSION = "ge.socVersion"; +const char_t *const VIRTUAL_TYPE = "ge.virtual_type"; +const char_t *const CORE_TYPE = "ge.engineType"; +const char_t *const AICORE_NUM = "ge.aicoreNum"; +const char_t *const L1_FUSION = "ge.l1Fusion"; +const char_t *const BUFFER_OPTIMIZE = "ge.bufferOptimize"; +const char_t *const ENABLE_SMALL_CHANNEL = "ge.enableSmallChannel"; +const char_t *const ENABLE_COMPRESS_WEIGHT = "ge.enableCompressWeight"; +const char_t *const FUSION_SWITCH_FILE = "ge.fusionSwitchFile"; +const char_t *const SAVE_ORIGINAL_MODEL = "ge.saveOriginalModel"; +const char_t *const ORIGINAL_MODEL_FILE = "ge.originalModelFile"; +const char_t *const INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16"; +const char_t *const OP_DEBUG_LEVEL = "ge.opDebugLevel"; +const char_t *const PERFORMANCE_MODE = "ge.performance_mode"; +const char_t *const SHAPE_GENERALIZED_BUILD_MODE = "ge.shape_generalized_build_mode"; +const char_t *const MODIFY_MIXLIST = "ge.exec.modify_mixlist"; +const char_t *const OP_PRECISION_MODE = "ge.exec.op_precision_mode"; +const char_t *const ALLOW_HF32 = "ge.exec.allow_hf32"; +const char_t *const CUSTOMIZE_DTYPES = "ge.customizeDtypes"; +const char_t *const COMPRESSION_OPTIMIZE_CONF = "ge.compressionOptimizeConf"; +const char_t *const OP_DEBUG_CONFIG = 
"op_debug_config"; +const char_t *const ATOMIC_CLEAN_POLICY = "ge.exec.atomicCleanPolicy"; +const char_t *const STATIC_MEMORY_POLICY = "ge.exec.staticMemoryPolicy"; +const char_t *const EXTERNAL_WEIGHT = "ge.externalWeight"; +const char_t *const QUANT_DUMPABLE = "ge.quant_dumpable"; +static const char_t *const DETERMINISTIC = "ge.deterministic"; +const char_t *const TILING_SCHEDULE_OPTIMIZE = "ge.tiling_schedule_optimize"; +const char_t *const GRAPH_MAX_PARALLEL_MODEL_NUM = "ge.graphMaxParallelModelNum"; +} // namespace configure_option +// Configure stream num by Session constructor options param, +// its value should be int32_t type, default value is "1" +const std::string STREAM_NUM = "ge.streamNum"; + +// Configure add head stream to model. +// its value should be "0" or "1", default value is "0" +const std::string HEAD_STREAM = "ge.headStream"; + +// Configure perf level by Session constructor options param, +// its value please see enum PerfLevel, default value is "4" +const std::string PERF_LEVEL = "ge.perfLevel"; + +// Configure encrypt mode by Session constructor options param, +// its value should be int32_t type, default value is "-1" +const std::string ENCRYPT_MODE = "ge.encryptMode"; + +// configure ek file by Session constructor options param, +// its value should be file path, default value is "" +const std::string EK_FILE = "ge.ekFile"; + +// Configure cert file by Session constructor options param, +// its value should be file path, default value is "" +const std::string CERT_FILE = "ge.certFile"; + +// Configure hw key file by Session constructor options param, +// its value should be file path, default value is "" +const std::string HW_KEY_FILE = "ge.hwKeyFile"; + +// Configure private file by Session constructor options param, +// its value should be file path, default value is "" +const std::string PRIVATE_KEY_FILE = "ge.privateKeyFile"; + +// Configure framework type by Session constructor options param, +// its value please see enum 
FrameworkType, default value is "3" +const std::string FRAMEWORK_TYPE = "ge.frameworkType"; + +// Configure calibration info file by Session constructor options param, +// its value should be file path, default value is "" +const std::string CALIBRATION_CONF_FILE = "ge.calibrationConfFile"; + +// Configure insert op info file by Session constructor options param, +// its value should be file path, default value is "" +const std::string INSERT_OP_FILE = "ge.insertOpFile"; + +// Configure output node name by Session constructor options param, +// its value should be std::string type, default value is "" +const std::string OUTPUT_NODE_NAME = "ge.outputNodeName"; + +// Configure weight compress flag by Session constructor options param, +// its value should be "0" or "1", default value is "0" +const std::string COMPRESS_FLAG = "ge.compressFlag"; + +const std::string PRECISION_MODE = "ge.exec.precision_mode"; + +const std::string PRECISION_MODE_V2 = "ge.exec.precision_mode_v2"; + +const std::string TUNE_DEVICE_IDS = "ge.exec.tuneDeviceIds"; + +// Configure single op flag for FE +// its value should be "0" or "1", default value is "0" +const std::string SINGLE_OP_FLAG = "ge.exec.single_op"; + +// Configure train flag by Session constructor options param, +// its value should be "0" or "1", default value is "0" +const std::string TRAIN_FLAG = "ge.trainFlag"; + +// Configure run flag by Session constructor options param, +// its value should be "0" or "1", default value is "0" +const std::string RUN_FLAG = "ge.runFlag"; + +// Configure run flag by Session constructor options param, +// its value should be "0" or "1", default value is "0" +// this option is to enable local framework op feature +const std::string LOCAL_FMKOP_FLAG = "ge.enabledLocalFmkop"; + +// Configure run flag by Session constructor options param, +// its value should be a path +// this option is to obtain the TBE op plugin path +const std::string TBE_PLUGIN_PATH_FLAG = "ge.TBE_plugin_path"; + +// 
Configure stream max parallel num only by Session constructor options param, +// its value should be stream:int, such as "DNN_V100:2,DNN_HCCL:3", +// default value is "1", such as "DNN_V100:1,DNN_HCCL:1" +// this option is to obtain stream max parallel num +const std::string STREAM_MAX_PARALLEL_NUM = "ge.streamMaxParallelNum"; + +// Configure engines such as Aicpu to compute in parallel with other engines in dynamic shape graphs. +// its value should be "0" or "1", default value is "0" +const std::string AC_PARALLEL_ENABLE = "ac_parallel_enable"; + +// configure outputDatatype to set the net output type +const std::string OUTPUT_DATATYPE = "ge.outputDatatype"; + +// configure opSelectImplmode to set the op select implmode +const std::string OP_SELECT_IMPL_MODE = "ge.opSelectImplmode"; + +// configure optypelist_for_implmode to set which ops use implmode +const std::string OPTYPELIST_FOR_IMPLMODE = "ge.optypelistForImplmode"; + +// configure whether to enable hcom parallel by session constructor options param, +// its value should be "0" or "1", default value is "0" +const std::string HCOM_PARALLEL = "ge.hcomParallel"; + +// configure whether to use dynamic batch size +const char_t *const kDynamicBatchSize = "ge.dynamicBatchSize"; + +// configure threshold of fusion data size for communication op +const std::string FUSION_TENSOR_SIZE = "ge.fusionTensorSize"; + +const std::string INPUT_SHAPE = "ge.inputShape"; + +const std::string DYNAMIC_NODE_TYPE = "ge.dynamicNodeType"; +// configure whether to use dynamic image size +const char_t *const kDynamicImageSize = "ge.dynamicImageSize"; + +// Configure whether to use dynamic dims +const char_t *const kDynamicDims = "ge.dynamicDims"; + +// Configure auto tune mode, this option only takes effect while AUTO_TUNE_FLAG is Y, +// example: GA|RL, multiple values are supported, split by | +const std::string AUTO_TUNE_MODE = "ge.autoTuneMode"; + +// Configure soc version, example: "Ascend310" +const std::string SOC_VERSION = 
"ge.socVersion"; + +// configure whether to enable virtualization, +// its value should be "0" or "1", default value is "0" +const std::string VIRTUAL_TYPE = "ge.virtual_type"; + +// Configure core type "VectorEngine", default value is "AIcoreEngine" +const std::string CORE_TYPE = "ge.engineType"; + +// Configure graph exclude one or more engines +const std::string EXCLUDE_ENGINES = "ge.exec.exclude_engines"; + +// Configure AICORE NUM +const std::string AICORE_NUM = "ge.aicoreNum"; + +// Configure L1FUSION +const std::string L1_FUSION = "ge.l1Fusion"; + +// Configure l1,l2,and others optimize option +const std::string BUFFER_OPTIMIZE = "ge.bufferOptimize"; + +// Configure Small Channel flag +const std::string ENABLE_SMALL_CHANNEL = "ge.enableSmallChannel"; + +// Configure Compress Weight flag +const std::string ENABLE_COMPRESS_WEIGHT = "ge.enableCompressWeight"; + +// Configure Sparse Matrix Weight flag +const std::string ENABLE_SPARSE_MATRIX_WEIGHT = "ge.enableSparseMatrixWeight"; + +// Configure fusion switch file path +const std::string FUSION_SWITCH_FILE = "ge.fusionSwitchFile"; + +// Configure compression optimize file path +const std::string COMPRESSION_OPTIMIZE_CONF = "ge.compressionOptimizeConf"; + +// Configure customize dtypes path +const std::string CUSTOMIZE_DTYPES = "ge.customizeDtypes"; + +// Configure switch for op debug config such as op memory detection +const std::string OP_DEBUG_CONFIG = "op_debug_config"; + +// Save original model +const std::string SAVE_ORIGINAL_MODEL = "ge.saveOriginalModel"; + +// Save original model file name +const std::string ORIGINAL_MODEL_FILE = "ge.originalModelFile"; + +// Output max size +const std::string OUTPUT_MAX_SIZE = "ge.outputMaxSize"; + +const char_t *const OPTION_GE_MAX_DUMP_FILE_NUM = "ge.maxDumpFileNum"; +const char_t *const OPTION_GE_MAX_DUMP_FILE_SIZE = "ge.maxDumpFileSize"; +const char_t *const OPTION_GE_MAX_DUMP_OP_NUM = "ge.maxDumpOpNum"; + +// Configure for print op pass +// Its value should be "0" 
or "1", default value is "1" +const char_t *const ENABLE_PRINT_OP_PASS = "ge.enablePrintOpPass"; + +// Configure operator compilation path +// Its value should be file path, default value is "./" +const char_t *const DEBUG_DIR = "ge.debugDir"; + +// Configure switch for op status check such as overflow +// Its value should be true or false +const char_t *const STATUS_CHECK = "ge.status_check"; + +// Configure operator compiler cache path +// Its value should be file path, default value is "./" +const char_t *const OP_COMPILER_CACHE_DIR = "ge.op_compiler_cache_dir"; + +// Configure operator compiler cache mode +// Its value should be "disable", "enable" or "force", default value is "disable" +const char_t *const OP_COMPILER_CACHE_MODE = "ge.op_compiler_cache_mode"; + +// Configure build model type. FE needs this option to judge inner model or not +// Its value should be "true" or "false" +const char_t *const BUILD_INNER_MODEL = "ge.build_inner_model"; + +// Configure whether to use single stream. +// Its value should be "true" or "false", default value is "false" +const char_t *const ENABLE_SINGLE_STREAM = "ge.enableSingleStream"; + +// Configure input fp16 nodes +const std::string INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16"; + +// Configure debug level, its value should be 0(default), 1 or 2. +// 0: close debug; 1: open TBE compiler; 2: open ccec compiler +const std::string OP_DEBUG_LEVEL = "ge.opDebugLevel"; + +// Configure model bank path +const std::string MDL_BANK_PATH_FLAG = "ge.mdl_bank_path"; + +// Configure display_model_info flag +const std::string DISPLAY_MODEL_INFO = "ge.display_model_info"; + +// Configure op bank path +const std::string OP_BANK_PATH_FLAG = "ge.op_bank_path"; +const std::string OP_BANK_UPDATE_FLAG = "ge.op_bank_update"; + +// Configure for fix hcombroadcast format. 
+// when config model multi, broadcast format should be fixed +// 0: data multi; 1: model multi; +const std::string HCOM_MULTI_MODE = "ge.hcomMultiMode"; + +// atc and ir option +const char_t *const INPUT_SHAPE_RANGE = "input_shape_range"; + +// Configure express high compile performance or high execute performance +// normal: no need to compile, used saved .o files directly +// high: need to recompile, high execute performance mode +const std::string PERFORMANCE_MODE = "ge.performance_mode"; + +// For selecting the mode of shape generalization when build graph. +// shape_generalized: Shape will be generalized during graph build. +// shape_precise: Shape will not be generalized, use precise shape. +const std::string SHAPE_GENERALIZED_BUILD_MODE = "ge.shape_generalized_build_mode"; + +const std::string JIT_COMPILE = "ge.jit_compile"; + +const std::string MODIFY_MIXLIST = "ge.exec.modify_mixlist"; + +const std::string OP_PRECISION_MODE = "ge.exec.op_precision_mode"; + +const std::string ALLOW_HF32 = "ge.exec.allow_hf32"; + +const std::string OP_WAIT_TIMEOUT = "ge.exec.opWaitTimeout"; + +const std::string OP_EXECUTE_TIMEOUT = "ge.exec.opExecuteTimeout"; + +const char_t *const FILE_CONSTANT_PATH = "ge.exec.value_bins"; + +// Configure whether convert const to fileconstant and save weight to file. +// Its value should be "0" or "1", default value is "0". 
+const std::string EXTERNAL_WEIGHT = "ge.externalWeight";
+
+const std::string DETERMINISTIC = "ge.deterministic";
+
+const std::string QUANT_DUMPABLE = "ge.quant_dumpable";
+
+// option for tiling sink
+const std::string TILING_SCHEDULE_OPTIMIZE = "ge.tiling_schedule_optimize";
+
+const std::string GRAPH_MAX_PARALLEL_MODEL_NUM = "ge.graphMaxParallelModelNum";
+
+constexpr char_t EVENT[] = "ge.event";
+
+// Graph run mode
+enum GraphRunMode : std::int32_t { PREDICTION = 0, TRAIN };
+// Input tensor info: plain aggregate, members are not default-initialized
+struct InputTensorInfo {
+  uint32_t data_type;         // data type
+  std::vector<int64_t> dims;  // shape description
+  void *data;                 // tensor data
+  int64_t length;             // tensor length
+};
+
+struct OutputTensorInfo {  // output tensor info: move-only, copy operations are deleted below
+  uint32_t data_type{};              // data type
+  std::vector<int64_t> dims;         // shape description
+  std::unique_ptr<uint8_t[]> data;   // tensor data, owned by this struct
+  int64_t length{};                  // tensor length
+  OutputTensorInfo() : dims({}), data(nullptr) {}
+  OutputTensorInfo(OutputTensorInfo &&out) noexcept = default;
+  OutputTensorInfo &operator=(OutputTensorInfo &&out)& noexcept {
+    if (this != &out) {  // self-assignment must not release the owned buffer
+      data_type = out.data_type;
+      dims = std::move(out.dims);  // move, not copy: a copy may throw inside this noexcept function
+      data = std::move(out.data);
+      length = out.length;
+    }
+    return *this;
+  }
+  OutputTensorInfo(const OutputTensorInfo &) = delete;
+  OutputTensorInfo &operator=(const OutputTensorInfo &)& = delete;
+};
+
+struct ModelDistibuteDesc {
+  uint32_t logic_device_number;
+};
+
+enum class MemoryType : std::int64_t {
+  /*
+   * call aclrtMalloc with aclrtMemMallocPolicy::ACL_MEM_MALLOC_HUGE_FIRST,
+   * ACL_MEM_MALLOC_HUGE_ONLY, ACL_MEM_MALLOC_NORMAL_ONLY
+   */
+  MEMORY_TYPE_DEFAULT,
+  /*
+   * call aclrtMalloc with aclrtMemMallocPolicy::ACL_MEM_MALLOC_HUGE_FIRST_P2P,
+   * ACL_MEM_MALLOC_HUGE_ONLY_P2P, ACL_MEM_MALLOC_NORMAL_ONLY_P2P
+   */
+  MEMORY_TYPE_P2P
+};
+
+using Status = uint32_t;
+using RunAsyncCallback = std::function<void(Status, std::vector<ge::Tensor> &)>;
+
+// for ir build
+namespace ir_option {
+static const char_t *const INPUT_FORMAT = "input_format";
+static const char_t *const INPUT_SHAPE =
"input_shape"; +static const char_t *const INPUT_SHAPE_RANGE = ge::INPUT_SHAPE_RANGE; +static const char_t *const OP_NAME_MAP = "op_name_map"; +static const char_t *const IS_DYNAMIC_INPUT = "is_dynamic_input"; +static const char_t *const IS_INPUT_ADJUST_HW_LAYOUT = "is_input_adjust_hw_layout"; +static const char_t *const IS_OUTPUT_ADJUST_HW_LAYOUT = "is_output_adjust_hw_layout"; +static const char_t *const ENABLE_SCOPE_FUSION_PASSES = "enable_scope_fusion_passes"; +static const char_t *const OUTPUT = "output"; +static const char_t *const DYNAMIC_BATCH_SIZE = kDynamicBatchSize; +static const char_t *const DYNAMIC_IMAGE_SIZE = kDynamicImageSize; +static const char_t *const DYNAMIC_DIMS = kDynamicDims; +static const char_t *const INSERT_OP_FILE = ge::INSERT_OP_FILE.c_str(); +static const char_t *const PRECISION_MODE = ge::PRECISION_MODE.c_str(); +static const char_t *const PRECISION_MODE_V2 = ge::PRECISION_MODE_V2.c_str(); +static const char_t *const TUNE_DEVICE_IDS = ge::TUNE_DEVICE_IDS.c_str(); +static const char_t *const EXEC_DISABLE_REUSED_MEMORY = ge::OPTION_EXEC_DISABLE_REUSED_MEMORY; +static const char_t *const AUTO_TUNE_MODE = ge::AUTO_TUNE_MODE.c_str(); +static const char_t *const CORE_TYPE = ge::CORE_TYPE.c_str(); +static const char_t *const SOC_VERSION = ge::SOC_VERSION.c_str(); +static const char_t *const VIRTUAL_TYPE = ge::VIRTUAL_TYPE.c_str(); +static const char_t *const ENABLE_SINGLE_STREAM = ge::ENABLE_SINGLE_STREAM; +static const char_t *const AC_PARALLEL_ENABLE = ge::AC_PARALLEL_ENABLE.c_str(); +static const char_t *const AICORE_NUM = ge::AICORE_NUM.c_str(); +static const char_t *const FUSION_SWITCH_FILE = ge::FUSION_SWITCH_FILE.c_str(); +static const char_t *const ENABLE_SMALL_CHANNEL = ge::ENABLE_SMALL_CHANNEL.c_str(); +static const char_t *const OP_SELECT_IMPL_MODE = ge::OP_SELECT_IMPL_MODE.c_str(); +static const char_t *const OUTPUT_TYPE = ge::OUTPUT_DATATYPE.c_str(); +static const char_t *const BUFFER_OPTIMIZE = ge::BUFFER_OPTIMIZE.c_str(); 
+static const char_t *const ENABLE_COMPRESS_WEIGHT = ge::ENABLE_COMPRESS_WEIGHT.c_str(); +static const char_t *const SPARSITY = ge::ENABLE_SPARSE_MATRIX_WEIGHT.c_str(); +static const char_t *const COMPRESS_WEIGHT_CONF = "compress_weight_conf"; +static const char_t *const OUT_NODES = ge::OUTPUT_NODE_NAME.c_str(); +static const char_t *const INPUT_FP16_NODES = ge::INPUT_FP16_NODES.c_str(); +static const char_t *const LOG_LEVEL = "log"; +static const char_t *const OPTYPELIST_FOR_IMPLMODE = ge::OPTYPELIST_FOR_IMPLMODE.c_str(); +static const char_t *const DEBUG_DIR = ge::DEBUG_DIR; +static const char_t *const OP_COMPILER_CACHE_DIR = ge::OP_COMPILER_CACHE_DIR; +static const char_t *const OP_COMPILER_CACHE_MODE = ge::OP_COMPILER_CACHE_MODE; +static const char_t *const BUILD_INNER_MODEL = ge::BUILD_INNER_MODEL; +static const char_t *const MDL_BANK_PATH = ge::MDL_BANK_PATH_FLAG.c_str(); +static const char_t *const OP_BANK_PATH = ge::OP_BANK_PATH_FLAG.c_str(); +static const char_t *const OP_BANK_UPDATE = ge::OP_BANK_UPDATE_FLAG.c_str(); +static const char_t *const OP_DEBUG_LEVEL = ge::OP_DEBUG_LEVEL.c_str(); +static const char_t *const PERFORMANCE_MODE = ge::PERFORMANCE_MODE.c_str(); +static const char_t *const SHAPE_GENERALIZED_BUILD_MODE = ge::SHAPE_GENERALIZED_BUILD_MODE.c_str(); +static const char_t *const MODIFY_MIXLIST = ge::MODIFY_MIXLIST.c_str(); +static const char_t *const OP_PRECISION_MODE = ge::OP_PRECISION_MODE.c_str(); +static const char_t *const ALLOW_HF32 = ge::ALLOW_HF32.c_str(); +static const char_t *const CUSTOMIZE_DTYPES = "ge.customizeDtypes"; +static const char_t *const COMPRESSION_OPTIMIZE_CONF = "ge.compressionOptimizeConf"; +static const char_t *const INPUT_DATA_NAMES = "input_data_names"; +static const char_t *const OP_DEBUG_CONFIG = "op_debug_config"; +static const char_t *const ATOMIC_CLEAN_POLICY = "ge.exec.atomicCleanPolicy"; +static const char_t *const EXTERNAL_WEIGHT = ge::EXTERNAL_WEIGHT.c_str(); +static const char_t *const EXCLUDE_ENGINES = 
ge::EXCLUDE_ENGINES.c_str();
+static const char_t *const DETERMINISTIC = ge::DETERMINISTIC.c_str();
+static const char_t *const DISTRIBUTED_CLUSTER_BUILD = ge::DISTRIBUTED_CLUSTER_BUILD.c_str();
+static const char_t *const MODEL_RELATION_CONFIG = ge::MODEL_RELATION_CONFIG.c_str();
+static const char_t *const CLUSTER_CONFIG = ge::CLUSTER_CONFIG.c_str();
+static const char_t *const ENABLE_GRAPH_PARALLEL = "ge.enableGraphParallel";
+static const char_t *const GRAPH_PARALLEL_OPTION_PATH = "ge.graphParallelOptionPath";
+static const char_t *const QUANT_DUMPABLE = ge::QUANT_DUMPABLE.c_str();
+static const char_t *const TILING_SCHEDULE_OPTIMIZE = ge::TILING_SCHEDULE_OPTIMIZE.c_str();
+static const char_t *const GRAPH_MAX_PARALLEL_MODEL_NUM = ge::GRAPH_MAX_PARALLEL_MODEL_NUM.c_str();
+// option keys for interface: aclgrphBuildModel (the option sets are only defined when __GNUC__ is set)
+#ifdef __GNUC__
+const std::set<std::string> ir_builder_suppported_options = {INPUT_FORMAT,
+                                                             INPUT_SHAPE,
+                                                             INPUT_SHAPE_RANGE,
+                                                             OP_NAME_MAP,
+                                                             DYNAMIC_BATCH_SIZE,
+                                                             DYNAMIC_IMAGE_SIZE,
+                                                             DYNAMIC_DIMS,
+                                                             INSERT_OP_FILE,
+                                                             OP_PRECISION_MODE,
+                                                             ALLOW_HF32,
+                                                             PRECISION_MODE,
+                                                             PRECISION_MODE_V2,
+                                                             TUNE_DEVICE_IDS,
+                                                             EXEC_DISABLE_REUSED_MEMORY,
+                                                             AUTO_TUNE_MODE,
+                                                             OUTPUT_TYPE,
+                                                             OUT_NODES,
+                                                             INPUT_FP16_NODES,
+                                                             LOG_LEVEL,
+                                                             OP_DEBUG_LEVEL,
+                                                             DEBUG_DIR,
+                                                             OP_COMPILER_CACHE_DIR,
+                                                             OP_COMPILER_CACHE_MODE,
+                                                             MDL_BANK_PATH,
+                                                             OP_BANK_PATH,
+                                                             OP_BANK_UPDATE,
+                                                             PERFORMANCE_MODE,
+                                                             SHAPE_GENERALIZED_BUILD_MODE,
+                                                             MODIFY_MIXLIST,
+                                                             CUSTOMIZE_DTYPES,
+                                                             BUILD_INNER_MODEL,
+                                                             OP_DEBUG_CONFIG,
+                                                             EXCLUDE_ENGINES,
+                                                             EXTERNAL_WEIGHT,
+                                                             DISTRIBUTED_CLUSTER_BUILD,
+                                                             MODEL_RELATION_CONFIG,
+                                                             ENABLE_GRAPH_PARALLEL,
+                                                             AC_PARALLEL_ENABLE,
+                                                             GRAPH_PARALLEL_OPTION_PATH,
+                                                             QUANT_DUMPABLE,
+                                                             TILING_SCHEDULE_OPTIMIZE,
+                                                             GRAPH_MAX_PARALLEL_MODEL_NUM};
+
+// option keys for interface: aclgrphParse
+const std::set<std::string> ir_parser_suppported_options = {
+    INPUT_FP16_NODES, IS_INPUT_ADJUST_HW_LAYOUT, IS_OUTPUT_ADJUST_HW_LAYOUT, OUTPUT,
+    OUT_NODES,        ENABLE_SCOPE_FUSION_PASSES, INPUT_DATA_NAMES,          INPUT_SHAPE};
+
+// option keys for interface: aclgrphBuildInitialize
+const std::set<std::string> global_options = {CORE_TYPE,
+                                              SOC_VERSION,
+                                              VIRTUAL_TYPE,
+                                              BUFFER_OPTIMIZE,
+                                              ENABLE_COMPRESS_WEIGHT,
+                                              COMPRESS_WEIGHT_CONF,
+                                              SPARSITY,
+                                              PRECISION_MODE,
+                                              PRECISION_MODE_V2,
+                                              ALLOW_HF32,
+                                              TUNE_DEVICE_IDS,
+                                              EXEC_DISABLE_REUSED_MEMORY,
+                                              AUTO_TUNE_MODE,
+                                              ENABLE_SINGLE_STREAM,
+                                              AC_PARALLEL_ENABLE,
+                                              AICORE_NUM,
+                                              FUSION_SWITCH_FILE,
+                                              ENABLE_SMALL_CHANNEL,
+                                              OP_SELECT_IMPL_MODE,
+                                              OPTYPELIST_FOR_IMPLMODE,
+                                              OP_DEBUG_LEVEL,
+                                              DEBUG_DIR,
+                                              OP_COMPILER_CACHE_DIR,
+                                              OP_COMPILER_CACHE_MODE,
+                                              MODIFY_MIXLIST,
+                                              COMPRESSION_OPTIMIZE_CONF,
+                                              OP_DEBUG_CONFIG,
+                                              DETERMINISTIC,
+                                              CLUSTER_CONFIG,
+                                              TILING_SCHEDULE_OPTIMIZE,
+                                              GRAPH_MAX_PARALLEL_MODEL_NUM};
+#endif
+}  // namespace ir_option
+}  // namespace ge
+#endif  // GE_API_TYPES_DEF
+#endif  // INC_EXTERNAL_GE_GE_API_TYPES_H_
diff --git a/inc/air/external/ge/ge_continuous_tensor_list_api.h b/inc/air/external/ge/ge_continuous_tensor_list_api.h
new file mode 100644
index 0000000000000000000000000000000000000000..4f174f1e3a326a5cb4fa0e709ff00b50e926e0d0
--- /dev/null
+++ b/inc/air/external/ge/ge_continuous_tensor_list_api.h
@@ -0,0 +1,40 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_EXTERNAL_GE_GE_CONTINUOUS_TENSOR_LIST_API_H
+#define INC_EXTERNAL_GE_GE_CONTINUOUS_TENSOR_LIST_API_H
+#include <memory>
+#include <vector>
+#include "graph/tensor.h"
+#include "ge_error_codes.h"
+
+namespace ge {
+class ContinuousTensorListImpl;  // opaque implementation type, only forward-declared in this header
+class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ContinuousTensorList {
+ public:
+  ContinuousTensorList() = default;
+  ~ContinuousTensorList();
+  explicit ContinuousTensorList(const std::vector<TensorDesc> &tensor_desc_list);
+
+  graphStatus Initialize();
+  graphStatus Finalize();
+
+  const std::vector<TensorDesc> &GetTensorDescList() const;
+  const std::vector<Tensor> &GetTensorList() const;
+
+  std::vector<void *> GetDataList() const;  // NOTE(review): element type reconstructed, confirm against upstream
+  std::vector<size_t> GetSizeList() const;
+  size_t GetBufferSize() const;
+  const void *GetBuffer() const;
+
+ private:
+  std::shared_ptr<ContinuousTensorListImpl> impl_;  // shared handle to the implementation object
+};
+}  // namespace ge
+#endif  // INC_EXTERNAL_GE_GE_CONTINUOUS_TENSOR_LIST_API_H
diff --git a/inc/air/external/ge/ge_data_flow_api.h b/inc/air/external/ge/ge_data_flow_api.h
new file mode 100644
index 0000000000000000000000000000000000000000..b29af8fd3e3bd3a9c3c179fab6cf042f2796e747
--- /dev/null
+++ b/inc/air/external/ge/ge_data_flow_api.h
@@ -0,0 +1,229 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_EXTERNAL_GE_GE_DATA_FLOW_API_H
+#define INC_EXTERNAL_GE_GE_DATA_FLOW_API_H
+#include <memory>
+#include "ge_api_error_codes.h"
+#include "graph/tensor.h"
+
+namespace ge {
+enum class DataFlowFlag : uint32_t {
+  DATA_FLOW_FLAG_EOS = (1U << 0U),  // data flag end
+  DATA_FLOW_FLAG_SEG = (1U << 1U)   // segment flag for discontinuous
+};
+
+struct RawData {
+  const void *addr;
+  size_t len;
+};
+
+class DataFlowInfoImpl;
+class GE_FUNC_VISIBILITY DataFlowInfo {
+ public:
+  DataFlowInfo();
+  ~DataFlowInfo();
+
+  DataFlowInfo(const DataFlowInfo &context) = delete;
+  DataFlowInfo(const DataFlowInfo &&context) = delete;
+  DataFlowInfo &operator=(const DataFlowInfo &context)& = delete;
+  DataFlowInfo &operator=(const DataFlowInfo &&context)& = delete;
+
+  void SetStartTime(const uint64_t start_time);
+  uint64_t GetStartTime() const;
+
+  void SetEndTime(const uint64_t end_time);
+  uint64_t GetEndTime() const;
+
+  /**
+   * @brief Set the Flow Flags object.
+   * @param flow_flags bitwise OR (|) of one or more DataFlowFlag values.
+   */
+  void SetFlowFlags(const uint32_t flow_flags);
+  /**
+   * @brief Get the Flow Flags object.
+   * @return flow flags; use bitwise AND (&) with a DataFlowFlag value to check whether that bit is set.
+   */
+  uint32_t GetFlowFlags() const;
+
+  /**
+   * @brief set user data, max data size is 64.
+   * @param data user data pointer, input.
+   * @param size user data size, need in (0, 64].
+   * @param offset user data offset, need in [0, 64), size + offset <= 64.
+   * @return success:0, failed:others.
+   */
+  Status SetUserData(const void *data, size_t size, size_t offset = 0U);
+
+  /**
+   * @brief get user data, max data size is 64.
+   * @param data user data pointer, output.
+   * @param size user data size, need in (0, 64]. The value must be the same as that of SetUserData.
+   * @param offset user data offset, need in [0, 64), size + offset <= 64. The value must be the same as that of
+   * SetUserData.
+   * @return success:0, failed:others.
+   */
+  Status GetUserData(void *data, size_t size, size_t offset = 0U) const;
+
+  /**
+   * get transaction id.
+   * the transaction id of fetched data can only be taken out when the n-mapping attr value is set to true.
+   * @return transaction id
+   */
+  uint64_t GetTransactionId() const;
+
+  /**
+   * set transaction id.
+   * Enabled only when the n-mapping attr value is set to true.
+   * @param transaction_id transaction id
+   */
+  void SetTransactionId(uint64_t transaction_id);
+
+ private:
+  std::shared_ptr<DataFlowInfoImpl> impl_;
+  friend class DataFlowInfoUtils;
+};
+
+enum class MsgType : uint16_t {
+  MSG_TYPE_TENSOR_DATA = 0,  // tensor data msg type
+  MSG_TYPE_RAW_MSG = 1,      // raw data msg type
+  MSG_TYPE_USER_DEFINE_START = 1024
+};
+
+
+class GE_FUNC_VISIBILITY FlowMsg {
+ public:
+  FlowMsg() = default;
+
+  virtual ~FlowMsg() = default;
+
+  /**
+   * @brief get msg type.
+   * @return msg type.
+   */
+  virtual MsgType GetMsgType() const = 0;
+
+  /**
+   * @brief set msg type.
+   * @param msg type.
+   */
+  virtual void SetMsgType(MsgType msg_type) = 0;
+
+  /**
+   * @brief get tensor.
+   * when MsgType=TENSOR_DATA_TYPE, can get tensor.
+   * @return tensor.
+   * attention:
+   * Tensor life cycle is same as FlowMsg, Caller should not keep it alone.
+   */
+  virtual Tensor *GetTensor() const = 0;
+
+  /**
+   * @brief get raw data.
+   * @param data ptr.
+   * @param data size.
+   * @return success:0, failed:others.
+   */
+  virtual Status GetRawData(void *&data_ptr, uint64_t &data_size) const = 0;
+
+  /**
+   * @brief get ret code.
+   * @return ret code.
+   */
+  virtual int32_t GetRetCode() const = 0;
+
+  /**
+   * @brief set ret code.
+   * @param ret code.
+   */
+  virtual void SetRetCode(int32_t ret_code) = 0;
+
+  /**
+   * @brief set start time.
+   * @param start time.
+   */
+  virtual void SetStartTime(uint64_t start_time) = 0;
+
+  /**
+   * @brief get start time.
+   * @return start time.
+ */ + virtual uint64_t GetStartTime() const = 0; + + /** + * @brief set end time. + * @param end time. + */ + virtual void SetEndTime(uint64_t end_time) = 0; + + /** + * @brief get end time. + * @return end time. + */ + virtual uint64_t GetEndTime() const = 0; + + /** + * @brief set flow flags. + * @param flags flags is FlowFlag set, multi FlowFlags can use bit or operation. + */ + virtual void SetFlowFlags(uint32_t flags) = 0; + + /** + * @brief get flow flags. + * @return FlowFlag set, can use bit and operation check if FlowFlag is set. + */ + virtual uint32_t GetFlowFlags() const = 0; + + /** + * @brief get transaction id. + * @return transaction id. + */ + virtual uint64_t GetTransactionId() const = 0; + + /** + * @brief set transaction id. + * @param transaction id. + */ + virtual void SetTransactionId(uint64_t transaction_id) = 0; + + /** + * @brief set user data, max data size is 64. + * @param data user data point, input. + * @param size user data size, need in (0, 64]. + * @param offset user data offset, need in [0, 64), size + offset <= 64. + * @return success:0, failed:others. + */ + virtual Status SetUserData(const void *data, size_t size, size_t offset = 0U) = 0; + + /** + * @brief get user data, max data size is 64. + * @param data user data point, output. + * @param size user data size, need in (0, 64]. The value must be the same as that of SetUserData. + * @param offset user data offset, need in [0, 64), size + offset <= 64. The value must be the same as that of + * SetUserData. + * @return success:0, failed:others. 
+ */ + virtual Status GetUserData(void *data, size_t size, size_t offset = 0U) const = 0; +}; + +using FlowMsgPtr = std::shared_ptr; + +class GE_FUNC_VISIBILITY FlowBufferFactory { + public: + static std::shared_ptr AllocTensor(const std::vector &shape, + DataType data_type, + uint32_t align = 512U); + static FlowMsgPtr AllocTensorMsg(const std::vector &shape, DataType data_type, uint32_t align = 512U); + static FlowMsgPtr AllocRawDataMsg(size_t size, uint32_t align = 512U); + static FlowMsgPtr AllocEmptyDataMsg(MsgType type); + static FlowMsgPtr ToFlowMsg(const Tensor &tensor); + static FlowMsgPtr ToFlowMsg(const RawData &raw_data); +}; +} // namespace ge +#endif // INC_EXTERNAL_GE_GE_DATA_FLOW_API_H diff --git a/inc/air/external/ge/ge_error_codes.h b/inc/air/external/ge/ge_error_codes.h new file mode 100644 index 0000000000000000000000000000000000000000..c73edfcbd6daaecfda98250513bacb9a3405e3ab --- /dev/null +++ b/inc/air/external/ge/ge_error_codes.h @@ -0,0 +1,75 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_GE_GE_ERROR_CODES_H_ +#define INC_EXTERNAL_GE_GE_ERROR_CODES_H_ + +#if defined(_MSC_VER) +#ifdef FUNC_VISIBILITY +#define GE_FUNC_VISIBILITY _declspec(dllexport) +#else +#define GE_FUNC_VISIBILITY +#endif +#else +#ifdef FUNC_VISIBILITY +#define GE_FUNC_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_VISIBILITY +#endif +#endif + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif +static const uint32_t ACL_ERROR_GE_PARAM_INVALID = 145000U; +static const uint32_t ACL_ERROR_GE_EXEC_NOT_INIT = 145001U; +static const uint32_t ACL_ERROR_GE_EXEC_MODEL_PATH_INVALID = 145002U; +static const uint32_t ACL_ERROR_GE_EXEC_MODEL_ID_INVALID = 145003U; +static const uint32_t ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID = 145006U; +static const uint32_t ACL_ERROR_GE_EXEC_MODEL_ADDR_INVALID = 145007U; +static const uint32_t ACL_ERROR_GE_EXEC_MODEL_QUEUE_ID_INVALID = 145008U; +static const uint32_t ACL_ERROR_GE_EXEC_LOAD_MODEL_REPEATED = 145009U; +static const uint32_t ACL_ERROR_GE_DYNAMIC_INPUT_ADDR_INVALID = 145011U; +static const uint32_t ACL_ERROR_GE_DYNAMIC_INPUT_LENGTH_INVALID = 145012U; +static const uint32_t ACL_ERROR_GE_DYNAMIC_BATCH_SIZE_INVALID = 145013U; +static const uint32_t ACL_ERROR_GE_AIPP_BATCH_EMPTY = 145014U; +static const uint32_t ACL_ERROR_GE_AIPP_NOT_EXIST = 145015U; +static const uint32_t ACL_ERROR_GE_AIPP_MODE_INVALID = 145016U; +static const uint32_t ACL_ERROR_GE_OP_TASK_TYPE_INVALID = 145017U; +static const uint32_t ACL_ERROR_GE_OP_KERNEL_TYPE_INVALID = 145018U; +static const uint32_t ACL_ERROR_GE_PLGMGR_PATH_INVALID = 145019U; +static const uint32_t ACL_ERROR_GE_FORMAT_INVALID = 145020U; +static const uint32_t ACL_ERROR_GE_SHAPE_INVALID = 145021U; +static const uint32_t ACL_ERROR_GE_DATATYPE_INVALID = 145022U; +static const uint32_t ACL_ERROR_GE_MEMORY_ALLOCATION = 245000U;
+static const uint32_t ACL_ERROR_GE_MEMORY_OPERATE_FAILED = 245001U; +static const uint32_t ACL_ERROR_GE_DEVICE_MEMORY_OPERATE_FAILED = 245002U; +static const uint32_t ACL_ERROR_GE_SUBHEALTHY = 345102U; +static const uint32_t ACL_ERROR_GE_USER_RAISE_EXCEPTION = 345103U; +static const uint32_t ACL_ERROR_GE_DATA_NOT_ALIGNED = 345104U; +static const uint32_t ACL_ERROR_GE_INTERNAL_ERROR = 545000U; +static const uint32_t ACL_ERROR_GE_LOAD_MODEL = 545001U; +static const uint32_t ACL_ERROR_GE_EXEC_LOAD_MODEL_PARTITION_FAILED = 545002U; +static const uint32_t ACL_ERROR_GE_EXEC_LOAD_WEIGHT_PARTITION_FAILED = 545003U; +static const uint32_t ACL_ERROR_GE_EXEC_LOAD_TASK_PARTITION_FAILED = 545004U; +static const uint32_t ACL_ERROR_GE_EXEC_LOAD_KERNEL_PARTITION_FAILED = 545005U; +static const uint32_t ACL_ERROR_GE_EXEC_RELEASE_MODEL_DATA = 545006U; +static const uint32_t ACL_ERROR_GE_COMMAND_HANDLE = 545007U; +static const uint32_t ACL_ERROR_GE_GET_TENSOR_INFO = 545008U; +static const uint32_t ACL_ERROR_GE_UNLOAD_MODEL = 545009U; +static const uint32_t ACL_ERROR_GE_MODEL_EXECUTE_TIMEOUT = 545601U; +static const uint32_t ACL_ERROR_GE_REDEPLOYING = 545602U; + +#ifdef __cplusplus +} // extern "C" +#endif +#endif // INC_EXTERNAL_GE_GE_ERROR_CODES_H_ diff --git a/inc/air/external/ge/ge_feature_memory.h b/inc/air/external/ge/ge_feature_memory.h new file mode 100644 index 0000000000000000000000000000000000000000..6200eb05f4b1f102aa587558feff07f316fa5a3a --- /dev/null +++ b/inc/air/external/ge/ge_feature_memory.h @@ -0,0 +1,54 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_GE_GE_FEATURE_MEMORY_H +#define INC_EXTERNAL_GE_GE_FEATURE_MEMORY_H + +#include +#include +#include "ge_api_types.h" +#include "ge_error_codes.h" + +namespace ge { +class GE_FUNC_VISIBILITY FeatureMemory { /* read-only description of one feature memory (type, size, fixed flag); copying is deleted and the default constructor is private */ + public: + class Builder; + + class FeatureMemoryData; + + ~FeatureMemory(); + + FeatureMemory &operator=(const FeatureMemory &) & = delete; + FeatureMemory(const FeatureMemory &) = delete; + + /// + /// @brief get memory type + /// @return memory type + /// + MemoryType GetType() const; + + /// + /// @brief get memory size + /// @return memory size + /// + size_t GetSize() const; + + /// + /// @brief query whether this feature memory is fixed + /// @return true if feature memory is fixed, otherwise false + /// + bool IsFixed() const; + + private: + FeatureMemory() = default; + std::shared_ptr data_{nullptr}; +}; +using FeatureMemoryPtr = std::shared_ptr; +} // namespace ge +#endif // INC_EXTERNAL_GE_GE_FEATURE_MEMORY_H diff --git a/inc/air/external/ge/ge_graph_compile_summary.h b/inc/air/external/ge/ge_graph_compile_summary.h new file mode 100644 index 0000000000000000000000000000000000000000..adefb42bab7126f388c7fb4fe9ebe1b1912c9c7b --- /dev/null +++ b/inc/air/external/ge/ge_graph_compile_summary.h @@ -0,0 +1,111 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_GE_GE_GRAPH_COMPILE_SUMMARY_H +#define INC_EXTERNAL_GE_GE_GRAPH_COMPILE_SUMMARY_H + +#include +#include +#include "ge_api_types.h" +#include "ge_error_codes.h" +#include "ge_feature_memory.h" + +namespace ge { +class GE_FUNC_VISIBILITY CompiledGraphSummary { + public: + class Builder; + class SummaryData; + + ~CompiledGraphSummary(); + CompiledGraphSummary &operator=(const CompiledGraphSummary &) & = delete; + CompiledGraphSummary(const CompiledGraphSummary &) = delete; + + /// + /// @brief get whether or not the graph is static compiled + /// @return return true if static + /// + bool IsStatic() const; + + /// + /// @brief get const memory size after compiled + /// @param [out] size const memory size + /// @return Status result of function + /// + Status GetConstMemorySize(size_t &size) const; + + /// + /// @brief get featuremap memory size after compiled, without input and output + /// @param [out] size featuremap memory size + /// @return Status result of function + /// + Status GetFeatureMemorySize(size_t &size) const; + + /// + /// @brief get fixed feature memory size after compiled + /// @param [out] size fixed feature memory size + /// @return Status result of function + /// + Status GetFixedFeatureMemorySize(size_t &size) const; + + /// + /// @brief get all type feature memory size after compiled + /// @return vector of FeatureMemory pointer + /// + std::vector GetAllFeatureMemoryTypeSize() const; + + /// + /// @brief get refreshable featuremap memory size after compiled, without input and output and fixed memory + ///
@param [out] size featuremap memory size + /// @return Status result of function + /// + Status GetRefreshableFeatureMemorySize(size_t &size) const; + + /// + /// @brief get whether or not the graph supports featuremap memory base refreshable + /// @param [out] v refreshable or not + /// @return Status result of function + /// + Status GetFeatureMemoryBaseRefreshable (bool &v) const; + + /// + /// @brief get the used stream number of the whole compiled graph + /// @param [out] num used stream number + /// @return Status result of function + /// + Status GetStreamNum(size_t &num) const; + + /// + /// @brief get the used event number of the whole compiled graph + /// @param [out] num used event number + /// @return Status result of function + /// + Status GetEventNum(size_t &num) const; + + /// + /// @brief get the output tensor shapes of the whole compiled graph + /// @param [out] shapes vector of ge::Shape + /// @return Status result of function + /// + Status GetOutputShapes(std::vector &shapes) const; + + Status GetOutputDtypes(std::vector &dtypes) const; + + Status GetIOIndexesWithSameAddr(std::vector> &io_indexes) const; + + Status GetInputShardMethod(std::map>>> + &device_id_to_tensor_deployment) const; + + private: + CompiledGraphSummary() = default; + std::shared_ptr data_{nullptr}; +}; + +using CompiledGraphSummaryPtr = std::shared_ptr; +} // namespace ge +#endif // INC_EXTERNAL_GE_GE_GRAPH_COMPILE_SUMMARY_H diff --git a/inc/air/external/ge/ge_ir_build.h b/inc/air/external/ge/ge_ir_build.h new file mode 100644 index 0000000000000000000000000000000000000000..d931f4bc4b7460b7c397c915e7f37e1e78071194 --- /dev/null +++ b/inc/air/external/ge/ge_ir_build.h @@ -0,0 +1,223 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details.
You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_GE_IR_BUILD_H_ +#define INC_EXTERNAL_GE_IR_BUILD_H_ + +#if defined(_MSC_VER) +#ifdef FUNC_VISIBILITY +#define GE_FUNC_VISIBILITY _declspec(dllexport) +#else +#define GE_FUNC_VISIBILITY +#endif +#else +#ifdef FUNC_VISIBILITY +#define GE_FUNC_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_VISIBILITY +#endif +#endif + +#include +#include +#include +#include "graph/graph.h" +#include "graph/ge_error_codes.h" +namespace ge { +const int32_t IR_MAJOR_VERSION = 1; +const int32_t IR_MINOR_VERSION = 0; +const int32_t IR_PATCH_VERSION = 0; + +struct ModelBufferData { /* model buffer handed between the aclgrph build and save entry points declared in this header */ + std::shared_ptr data = nullptr; + uint64_t length = 0U; /* zero by default so a default-constructed object never exposes an indeterminate length; presumably the byte size of data - confirm */ +}; + +struct GraphWithOptions { /* one graph paired with its own build options */ + ge::Graph graph; + std::map build_options; +}; + +struct WeightRefreshableGraphs { /* output bundle of aclgrphConvertToWeightRefreshableGraphs */ + ge::Graph infer_graph; + ge::Graph var_init_graph; + ge::Graph var_update_graph; +}; + +enum aclgrphAttrType { ATTR_TYPE_KEEP_DTYPE = 0, ATTR_TYPE_WEIGHT_COMPRESS }; + +/** + * @ingroup AscendCL + * @brief build model.Notice the model is stored in buffer + * + * @param global_options[IN] global init params for build + * @retval GRAPH_SUCCESS The function is successfully executed.
+ * @retval OtherValues Failure + */ +ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphBuildInitialize(std::map &)) +GE_FUNC_VISIBILITY graphStatus aclgrphBuildInitialize(std::map global_options); + +GE_FUNC_VISIBILITY graphStatus aclgrphBuildInitialize(std::map &global_options); + +/** + * @ingroup AscendCL + * @brief build model.Notice the model is stored in buffer + * + */ +GE_FUNC_VISIBILITY void aclgrphBuildFinalize(); + +/** + * @ingroup AscendCL + * @brief build model.Notice the model is stored in buffer + * + * @param graph[IN] the graph ready to build + * @param build_options[IN] options used for build + * @param model[OUT] built model + * @retval GRAPH_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &, + const std::map &, + ModelBufferData &)) +GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph, + const std::map &build_options, + ModelBufferData &model); + +GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph, + const std::map &build_options, + ModelBufferData &model); + +/** + * @ingroup AscendCL + * @brief build model.Notice the model is stored in buffer + * + * @param graphs[IN] the multiple graphs ready to build + * @param build_options[IN] options used for build + * @param model[OUT] built model + * @retval GRAPH_SUCCESS The function is successfully executed.
+ * @retval OtherValues Failure + */ +GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const std::vector &graphs, + const std::map &build_options, + ModelBufferData &model); + +GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const std::vector &graphs, + const std::map &build_options, + ModelBufferData &model); + +/** + * @ingroup AscendCL + * @brief save model buffer to file + * + * @param output_file[IN] the file path to be saved + * @param model[IN] model buffer data + * @retval GRAPH_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphSaveModel(const char_t *, const ModelBufferData &)) +GE_FUNC_VISIBILITY graphStatus aclgrphSaveModel(const std::string &output_file, const ModelBufferData &model); + +GE_FUNC_VISIBILITY graphStatus aclgrphSaveModel(const char_t *output_file, const ModelBufferData &model); + +/** + * @ingroup AscendCL + * + * @param origin_graph[IN] the origin graph ready to be converted + * @param const_names[IN] names of the consts in the origin graph that are to be converted to variables + * @param weight_refreshable_graphs[OUT] refreshable weight graphs + * @retval GRAPH_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +GE_FUNC_VISIBILITY graphStatus aclgrphConvertToWeightRefreshableGraphs(const ge::Graph &origin_graph, + const std::vector &const_names, WeightRefreshableGraphs &weight_refreshable_graphs); +/** + * @ingroup AscendCL + * @brief build model.Notice the model is stored in buffer + * + * @param graph_with_options[IN] the multiple graphs and build options ready to build + * @param model[OUT] built bundle model + * @retval GRAPH_SUCCESS The function is successfully executed.
+ * @retval OtherValues Failure + */ +GE_FUNC_VISIBILITY graphStatus aclgrphBundleBuildModel(const std::vector &graph_with_options, + ModelBufferData &model); + +/** + * @ingroup AscendCL + * @brief save bundle model buffer to file + * + * @param output_file[IN] the file path to be saved + * @param model[IN] model buffer data + * @retval GRAPH_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +GE_FUNC_VISIBILITY graphStatus aclgrphBundleSaveModel(const char_t *output_file, const ModelBufferData &model); + +/** + * @ingroup AscendCL + * @brief query IR interface version + * + * @param major_version[OUT] IR interface major version + * @param minor_version[OUT] IR interface minor version + * @param patch_version[OUT] IR interface patch version + * @retval GRAPH_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +GE_FUNC_VISIBILITY graphStatus aclgrphGetIRVersion(int32_t *major_version, int32_t *minor_version, + int32_t *patch_version); + +/** + * @ingroup AscendCL + * @brief dump graph + * + * @param graph[IN] the graph ready to build + * @param file[IN] file path + * @param len[IN] file path string len + * @retval GRAPH_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +GE_FUNC_VISIBILITY graphStatus aclgrphDumpGraph(const ge::Graph &graph, const char_t *file, const size_t len); + +/** + * @ingroup AscendCL + * @brief create single op graph + * + * @param op_type[IN] the op_type + * @param inputs[IN] the inputdesc + * @param outputs[IN] the outputdesc + * @param graph[OUT] the graph + * @retval GRAPH_SUCCESS The function is successfully executed.
+ * @retval OtherValues Failure + */ +GE_FUNC_VISIBILITY graphStatus aclgrphGenerateForOp(const AscendString &op_type, const std::vector &inputs, + const std::vector &outputs, Graph &graph); + +/** + * @ingroup AscendCL + * @brief create single op graphs + * + * @param json_path[IN] the path of singleop json file + * @param graphs[OUT] the graphs + * @retval GRAPH_SUCCESS The function is successfully executed. + * @retval OtherValues Failure + */ +GE_FUNC_VISIBILITY graphStatus aclgrphGenerateForOp(const AscendString &json_path, std::vector &graphs); + +/** + * @name aclgrphSetOpAttr + * @brief set attribute for operators in the configuration file + * @param graph [IN/OUT] compute graph + * @param attr_type [IN] attribute type + * @param cfg_path [IN] the config file path + * @return graphStatus + */ +GE_FUNC_VISIBILITY graphStatus aclgrphSetOpAttr(Graph &graph, aclgrphAttrType attr_type, const char_t *cfg_path); + +} // namespace ge +#endif // INC_EXTERNAL_GE_IR_BUILD_H_ diff --git a/inc/air/framework/common/debug/ge_log.h b/inc/air/framework/common/debug/ge_log.h new file mode 100644 index 0000000000000000000000000000000000000000..9037a834bf272427c012fec56ec55c656ceb5941 --- /dev/null +++ b/inc/air/framework/common/debug/ge_log.h @@ -0,0 +1,13 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_COMMON_DEBUG_GE_LOG_H_ +#define INC_FRAMEWORK_COMMON_DEBUG_GE_LOG_H_ +#include "common/ge_common/debug/ge_log.h" +#endif // INC_FRAMEWORK_COMMON_DEBUG_GE_LOG_H_ diff --git a/inc/air/framework/common/debug/log.h b/inc/air/framework/common/debug/log.h new file mode 100644 index 0000000000000000000000000000000000000000..98a3eeb86ec85211c29e88d482a582a664fb739e --- /dev/null +++ b/inc/air/framework/common/debug/log.h @@ -0,0 +1,13 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ +#define INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ +#include "common/ge_common/debug/log.h" +#endif // INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ diff --git a/inc/air/framework/common/fmk_error_codes.h b/inc/air/framework/common/fmk_error_codes.h new file mode 100644 index 0000000000000000000000000000000000000000..8270b580d8e81a5a299a8a2d0d48d8ee302c7d6a --- /dev/null +++ b/inc/air/framework/common/fmk_error_codes.h @@ -0,0 +1,13 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). 
+ * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_COMMON_FMK_ERROR_CODES_H_ +#define INC_FRAMEWORK_COMMON_FMK_ERROR_CODES_H_ +#include "common/ge_common/fmk_error_codes.h" +#endif // INC_FRAMEWORK_COMMON_FMK_ERROR_CODES_H_ diff --git a/inc/air/framework/common/fmk_types.h b/inc/air/framework/common/fmk_types.h new file mode 100644 index 0000000000000000000000000000000000000000..9ce216701280f497bf5e3d7c830884f939be44ba --- /dev/null +++ b/inc/air/framework/common/fmk_types.h @@ -0,0 +1,13 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_COMMON_FMK_TYPES_H_ +#define INC_FRAMEWORK_COMMON_FMK_TYPES_H_ +#include "common/ge_common/fmk_types.h" +#endif // INC_FRAMEWORK_COMMON_FMK_TYPES_H_ diff --git a/inc/air/framework/common/ge_compiler_options.h b/inc/air/framework/common/ge_compiler_options.h new file mode 100644 index 0000000000000000000000000000000000000000..fd00a1830df2eb39c9c34221474892dec566ac6b --- /dev/null +++ b/inc/air/framework/common/ge_compiler_options.h @@ -0,0 +1,25 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_COMMON_GE_COMPILER_OPTIONS_H_ +#define INC_FRAMEWORK_COMMON_GE_COMPILER_OPTIONS_H_ + +namespace ge { /* note: the #define macros below are preprocessor-level and are not scoped by this namespace */ +#ifdef __GNUC__ /* compilers that define __GNUC__ */ +#define GE_ATTRIBUTE_UNUSED __attribute__((unused)) +#define GE_FUNCTION_IDENTIFIER __PRETTY_FUNCTION__ +#define GE_BUILTIN_PREFETCH(args_addr) __builtin_prefetch(args_addr) +#else /* all other compilers: the attribute and prefetch macros expand to nothing; NOTE(review): __FUNCSIG__ is an MSVC-specific identifier - confirm no other non-GNU compiler is targeted */ +#define GE_ATTRIBUTE_UNUSED +#define GE_FUNCTION_IDENTIFIER __FUNCSIG__ +#define GE_BUILTIN_PREFETCH(args_addr) +#endif +} // namespace ge + +#endif // INC_FRAMEWORK_COMMON_GE_COMPILER_OPTIONS_H_ diff --git a/inc/air/framework/common/ge_format_util.h b/inc/air/framework/common/ge_format_util.h new file mode 100644 index 0000000000000000000000000000000000000000..7ac51618505d6866ed70a9d754b6e1f4c1378d22 --- /dev/null +++ b/inc/air/framework/common/ge_format_util.h @@ -0,0 +1,33 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_COMMON_GE_FORMAT_UTIL_H_ +#define INC_FRAMEWORK_COMMON_GE_FORMAT_UTIL_H_ + +#include + +#include "framework/common/ge_inner_error_codes.h" +#include "graph/tensor.h" + +namespace ge { +class GE_FUNC_VISIBILITY GeFormatUtil { /* declares a single static helper; no instance state is declared */ + public: + /// + /// @name TransShape + /// @brief transform the shape of a tensor according to the destination format + /// @param [in] src_desc source tensor desc + /// @param [in] dst_format destination format + /// @param [out] dst_shape destination shape + /// @return Status result of function + /// + static Status TransShape(const TensorDesc &src_desc, const Format dst_format, std::vector &dst_shape); +}; +} // namespace ge + +#endif // INC_FRAMEWORK_COMMON_GE_FORMAT_UTIL_H_ diff --git a/inc/air/framework/common/ge_inner_error_codes.h b/inc/air/framework/common/ge_inner_error_codes.h new file mode 100644 index 0000000000000000000000000000000000000000..cb01e8f31ccde309679b0217471e9f7f4eacdff2 --- /dev/null +++ b/inc/air/framework/common/ge_inner_error_codes.h @@ -0,0 +1,16 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +/*lint -e* */ +#ifndef INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_ +#define INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_ +#include "external/ge/ge_api_error_codes.h" +#include "external/ge/ge_error_codes.h" +#include "common/ge_common/ge_inner_error_codes.h" +#endif // INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_ diff --git a/inc/air/framework/common/ge_model_inout_types.h b/inc/air/framework/common/ge_model_inout_types.h new file mode 100644 index 0000000000000000000000000000000000000000..ded022edef06a90d7e5c677b411918f18c9bc394 --- /dev/null +++ b/inc/air/framework/common/ge_model_inout_types.h @@ -0,0 +1,66 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_COMMON_GE_MODEL_INOUT_TYPES_H_ +#define INC_FRAMEWORK_COMMON_GE_MODEL_INOUT_TYPES_H_ +#include +#include +#include +#include + +#include "external/graph/types.h" + +namespace ge { + struct ModelInOutTensorDesc { + std::string name; + size_t size; + Format format; + DataType dataType; + std::vector dims; + std::vector dimsV2; // supported for static aipp scene + std::vector> shape_ranges; + }; + + struct ModelInOutInfo { + std::vector input_desc; + std::vector output_desc; + std::vector dynamic_batch; + std::vector> dynamic_hw; + std::vector> dynamic_dims; + std::vector dynamic_output_shape; + std::vector data_name_order; + }; + + enum class ModelDescType : uint32_t { + MODEL_INPUT_DESC, + MODEL_OUTPUT_DESC, + MODEL_DYNAMIC_BATCH, + MODEL_DYNAMIC_HW, + MODEL_DYNAMIC_DIMS, + MODEL_DYNAMIC_OUTPUT_SHAPE, + MODEL_DESIGNATE_SHAPE_ORDER + }; + + struct ModelDescTlvConfig { + int32_t type = 0; + uint32_t length = 0U; + const uint8_t *value = nullptr; + }; + + struct ModelTensorDescBaseInfo { + size_t size = 0U; + Format format; + DataType dt; + uint32_t name_len = 0U; + uint32_t dims_len = 0U; + uint32_t dimsV2_len = 0U; + uint32_t shape_range_len = 0U; + }; +} // namespace ge +#endif // INC_FRAMEWORK_COMMON_GE_MODEL_INOUT_TYPES_H_ diff --git a/inc/air/framework/common/ge_types.h b/inc/air/framework/common/ge_types.h new file mode 100644 index 0000000000000000000000000000000000000000..9c150efd6e12cd778b86f77b38665e53b5e42e7b --- /dev/null +++ b/inc/air/framework/common/ge_types.h @@ -0,0 +1,13 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_COMMON_GE_TYPES_H_ +#define INC_FRAMEWORK_COMMON_GE_TYPES_H_ +#include "common/ge_common/ge_types.h" +#endif // INC_FRAMEWORK_COMMON_GE_TYPES_H_ diff --git a/inc/air/framework/common/ge_visibility.h b/inc/air/framework/common/ge_visibility.h new file mode 100644 index 0000000000000000000000000000000000000000..a60fea6f8a359d2a333ff8cb6e54fc96e4eb5037 --- /dev/null +++ b/inc/air/framework/common/ge_visibility.h @@ -0,0 +1,21 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef AIR_CXX_INC_FRAMEWORK_COMMON_GE_VISIBILITY_H_ +#define AIR_CXX_INC_FRAMEWORK_COMMON_GE_VISIBILITY_H_ + +#if defined(_MSC_VER) +#define VISIBILITY_EXPORT _declspec(dllexport) +#define VISIBILITY_HIDDEN _declspec(dllimport) +#else +#define VISIBILITY_EXPORT __attribute__((visibility("default"))) +#define VISIBILITY_HIDDEN __attribute__((visibility("hidden"))) +#endif + +#endif // AIR_CXX_INC_FRAMEWORK_COMMON_GE_VISIBILITY_H_ diff --git a/inc/air/framework/common/helper/model_helper.h b/inc/air/framework/common/helper/model_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..66bdfb6262c98e4cd8dfe724174289303ad7b5c0 --- /dev/null +++ b/inc/air/framework/common/helper/model_helper.h @@ -0,0 +1,172 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_COMMON_HELPER_MODEL_HELPER_H_ +#define INC_FRAMEWORK_COMMON_HELPER_MODEL_HELPER_H_ + +#include +#include + +#include "framework/common/helper/om_file_helper.h" +#include "framework/common/helper/model_save_helper.h" +#include "framework/common/types.h" +#include "graph/model.h" +#include "platform/platform_info.h" +#include "common/op_so_store/op_so_store.h" +#include "common/host_resource_center/host_resource_serializer.h" + +namespace ge { +class GeModel; +class GeRootModel; +class GE_FUNC_VISIBILITY ModelHelper : public ModelSaveHelper { + public: + ModelHelper() noexcept = default; + virtual ~ModelHelper() override = default; + ModelHelper(const ModelHelper &) = default; + ModelHelper &operator=(const ModelHelper &) & = default; + + Status SaveToOmModel(const GeModelPtr &ge_model, const std::string &output_file, ge::ModelBufferData &model, + const GeRootModelPtr &ge_root_model = nullptr) override; + Status GenerateGeModel(const OmFileLoadHelper &om_load_helper, GeModelPtr &cur_model, GeModelPtr &first_ge_model, + const size_t mode_index, const bool is_dyn_root) const; + Status SaveToOmRootModel(const GeRootModelPtr &ge_root_model, const std::string &output_file, ModelBufferData &model, + const bool is_unknown_shape) override; + Status SaveOriginalGraphToOmModel(const ge::Graph &graph, const std::string &output_file) const; + + Status LoadModel(const ge::ModelData &model_data); + Status LoadRootModel(const ge::ModelData &model_data); + static Status GetModelFileHead(const ge::ModelData &model_data, const ModelFileHeader *&file_header); + static Status SetModelToGeModel(const GeModelPtr &ge_model, const GeModelPtr &first_ge_model, Model &model); + static Status SaveBundleModelBufferToMem(const std::vector &model_buffers, + ModelBufferData &output_buffer); + static std::string GetOutputFileName() { + return 
output_file_name_; + } + Status LoadPartInfoFromModel(const ge::ModelData &model_data, ModelPartition &partition); + + GeModelPtr GetGeModel(); + GeRootModelPtr GetGeRootModel(); + virtual void SetSaveMode(const bool val) override { + is_offline_ = val; + } + + void SetSharedWeightFlag(const bool val) { + is_shared_weight_ = val; + } + + bool GetModelType() const { + return is_unknown_shape_model_; + } + + Status GetBaseNameFromFileName(const std::string &file_name, std::string &base_name) const; + Status GetModelNameFromMergedGraphName(const ComputeGraphPtr &compute_graph, std::string &model_name) const; + + // for soft sync op + Status GetHardwareInfo(std::map &options) const; + Status HandleDeviceInfo(fe::PlatFormInfos &platform_infos) const; + static Status InitRuntimePlatform(); + Status GetPlatformInfo(int32_t device_id, const std::string &soc_version, fe::PlatformInfo &platform_info, + int32_t &virtual_type) const; + Status SetPlatformInfos(const std::string &soc_version, const fe::PlatformInfo &platform_info, + fe::PlatFormInfos &platform_infos) const; + + Status CheckOsCpuInfoAndOppVersion(); + + Status GetSoBinData(const string &cpu_info, const string &os_info); + + const uint8_t *GetOpSoStoreData() const; + + size_t GetOpStoreDataSize() const; + + void SetRepackSoFlag(const bool val); + + Status PackSoToModelData(const ModelData &model_data, const std::string &output_file, ModelBufferData &model_buffer, + const bool save_to_file = true); + bool IsSoStore() const { + return is_so_store_; + } + + static constexpr const char_t *kFilePreffix = ".exeom"; + static constexpr const char_t *kDebugPreffix = ".dbg"; + + protected: + Status SaveModelCustAICPU(shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + const size_t model_index = 0U) const; + static Status GetOppVersion(std::string &version); + static Status EnsureKernelBuilt(const GeModelPtr &model); + Status SaveSoStoreModelPartitionInfo(std::shared_ptr &om_file_save_helper, + const 
GeRootModelPtr &ge_root_model, string &output_file_name, + const GeModelPtr &first_ge_model); + Status SaveModelHeader(shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + const size_t model_num = 1U, const bool need_check_os_cpu = false, + const bool is_unknow_shape = false) const; + Status SaveModelPartition(std::shared_ptr &om_file_save_helper, const ModelPartitionType type, + const uint8_t *const data, const size_t size, const size_t model_index) const; + Status SaveModelWeights(shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + const size_t model_index = 0U) const; + Status SaveModelIntroduction(std::shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + const bool is_dynamic = false) const; + Status SaveModelTbeKernel(shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + const size_t model_index = 0U) const; + GeModelPtr model_; + bool is_so_store_ = false; + + private: + bool is_assign_model_ = false; + bool is_offline_ = true; + bool save_to_file_ = true; + bool is_unknown_shape_model_ = false; + bool is_shared_weight_ = false; + const ModelFileHeader *file_header_ = nullptr; + GeRootModelPtr root_model_; + OpSoStore op_so_store_; + bool is_repack_so_ = false; + bool is_need_compress_ = true; + static std::string output_file_name_; + std::unordered_set custom_compiler_versions_{}; + HostResourceSerializer host_serializer_; + + bool IsPartitionedGraph(const GeModelPtr &cur_model) const; + + Status GenerateGeRootModel(const OmFileLoadHelper &om_load_helper, const ModelData &model_data); + + Status CheckIfWeightPathValid(const ge::ComputeGraphPtr &graph, const ge::ModelData &model_data) const; + Status LoadModelData(const OmFileLoadHelper &om_load_helper, const GeModelPtr &cur_model, + const GeModelPtr &first_ge_model, const size_t mode_index) const; + virtual Status LoadWeights(const OmFileLoadHelper &om_load_helper, const GeModelPtr &cur_model, + const size_t mode_index) const; + Status LoadTask(const OmFileLoadHelper 
&om_load_helper, const GeModelPtr &cur_model, const size_t mode_index) const; + Status LoadTBEKernelStore(const OmFileLoadHelper &om_load_helper, const GeModelPtr &cur_model, + const size_t mode_index) const; + Status LoadCustAICPUKernelStore(const OmFileLoadHelper &om_load_helper, const GeModelPtr &cur_model, + const size_t mode_index) const; + + Status SaveModelDef(shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + Buffer &model_buffer, const size_t model_index = 0U) const; + Status SaveSizeToModelDef(const GeModelPtr &ge_model, const size_t model_index) const; + + Status SaveModelTaskDef(shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + Buffer &task_buffer, const size_t model_index = 0U) const; + Status SaveAllModelPartiton(shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + Buffer &model_buffer, Buffer &task_buffer, const size_t model_index = 0U) const; + + Status LoadOpSoBin(const OmFileLoadHelper &om_load_helper, const GeRootModelPtr &ge_root_model) const; + Status LoadTilingData(const OmFileLoadHelper &om_load_helper, const GeRootModelPtr &ge_root_model) const; + Status SaveTilingData(std::shared_ptr &om_file_save_helper, const GeRootModelPtr &ge_root_model); + void SaveOpSoInfo(const GeRootModelPtr &ge_root_model) const; + Status SetModelCompilerVersion(const GeModelPtr &first_ge_model); + Status LoadAndStoreOppSo(const string &path, bool is_split); + void SaveOutNodesFromRootGraph(const GeRootModelPtr &ge_root_model, GeModelPtr &first_ge_model) const; + Status LoadAndStoreOppSo(const std::vector &op_so_list, const SoBinType so_bin_type); + Status SaveSpaceRegistrySoBin(const GeRootModelPtr &ge_root_model, const GeModelPtr &first_ge_model, + string &output_file_name); + Status SaveOpMasterDeviceSoBin(const GeRootModelPtr &ge_root_model); +}; +} // namespace ge +#endif // INC_FRAMEWORK_COMMON_HELPER_MODEL_HELPER_H_ diff --git a/inc/air/framework/common/helper/model_save_helper.h 
b/inc/air/framework/common/helper/model_save_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..571a24b3e32e5a226335b38709ebd3cb20278ea5 --- /dev/null +++ b/inc/air/framework/common/helper/model_save_helper.h @@ -0,0 +1,48 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_COMMON_HELPER_MODEL_SAVE_HELPER_H_ +#define INC_FRAMEWORK_COMMON_HELPER_MODEL_SAVE_HELPER_H_ + +#include +#include + +#include "framework/common/helper/model_save_helper_factory.h" +#include "framework/common/helper/om_file_helper.h" +#include "common/model/ge_model.h" +#include "common/model/ge_root_model.h" +#include "framework/common/types.h" +#include "graph/model.h" +#include "platform/platform_info.h" +#include "common/op_so_store/op_so_store.h" + +namespace ge { +class GE_FUNC_VISIBILITY ModelSaveHelper { + public: + ModelSaveHelper() = default; + + virtual ~ModelSaveHelper() = default; + + virtual Status SaveToOmRootModel(const GeRootModelPtr &ge_root_model, + const std::string &output_file, + ModelBufferData &model, + const bool is_unknown_shape) = 0; + + virtual Status SaveToOmModel(const GeModelPtr &ge_model, + const std::string &output_file, + ModelBufferData &model, + const GeRootModelPtr &ge_root_model = nullptr) = 0; + + virtual void SetSaveMode(const bool val) = 0; 
+ protected: + ModelSaveHelper(const ModelSaveHelper &) = default; + ModelSaveHelper &operator=(const ModelSaveHelper &) & = default; +}; +} // namespace ge +#endif // INC_FRAMEWORK_COMMON_HELPER_MODEL_SAVE_HELPER_H_ diff --git a/inc/air/framework/common/helper/model_save_helper_factory.h b/inc/air/framework/common/helper/model_save_helper_factory.h new file mode 100644 index 0000000000000000000000000000000000000000..2c1f43f24cdf0cca5b0dd5941cd808ac55a65fb6 --- /dev/null +++ b/inc/air/framework/common/helper/model_save_helper_factory.h @@ -0,0 +1,80 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_COMMON_HELPER_MODEL_SAVE_HELPER_FACTORY_H_ +#define INC_FRAMEWORK_COMMON_HELPER_MODEL_SAVE_HELPER_FACTORY_H_ + +#include +#include + +#include "framework/common/ge_types.h" +#include "framework/common/debug/ge_log.h" + +namespace ge { +class ModelSaveHelper; +using ModelSaveHelperPtr = std::shared_ptr; + +class ModelSaveHelperFactory { + public: + using ModelSaveHelperCreatorFun = std::function; + + static ModelSaveHelperFactory &Instance() { + static ModelSaveHelperFactory instance; + return instance; + } + + ModelSaveHelperPtr Create(const OfflineModelFormat type) { + const auto iter = creator_map_.find(type); + if (iter == creator_map_.end()) { + return nullptr; + } + return iter->second(); + } + + // ModelSaveHelper registrar: its constructor registers a creator for the given type with the factory singleton + class Registerar { + public: + Registerar(const OfflineModelFormat type, const ModelSaveHelperCreatorFun &func) { + ModelSaveHelperFactory::Instance().RegisterCreator(type, func); + } + + ~Registerar() = default; + }; + + private: + ModelSaveHelperFactory() = default; + ~ModelSaveHelperFactory() = default; + + // Register a creator; called from the Registerar constructor. A creator already registered for the type is kept. + void RegisterCreator(const OfflineModelFormat type, const ModelSaveHelperCreatorFun &func) { + const auto iter = creator_map_.find(type); + if (iter != creator_map_.end()) { + return; + } + + creator_map_[type] = func; + } + + std::map creator_map_; +}; +} // namespace ge + +#define REGISTER_MODEL_SAVE_HELPER(type, clazz) \ + namespace { \ + ModelSaveHelperPtr Creator_##type##_Model_Save_Helper() { \ + try { \ + return std::make_shared(); \ + } catch (...)
{ \ + return nullptr; \ + } \ + } \ + ModelSaveHelperFactory::Registerar g_##type##_Model_Save_Helper_Creator(type, Creator_##type##_Model_Save_Helper); \ + } // namespace + +#endif // INC_FRAMEWORK_COMMON_HELPER_MODEL_SAVE_HELPER_FACTORY_H_ diff --git a/inc/air/framework/common/helper/nano_model_save_helper.h b/inc/air/framework/common/helper/nano_model_save_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..6a61298f67676507d7500c371e91557ee79df61d --- /dev/null +++ b/inc/air/framework/common/helper/nano_model_save_helper.h @@ -0,0 +1,67 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_COMMON_HELPER_NANO_MODEL_SAVE_HELPER_H_ +#define INC_FRAMEWORK_COMMON_HELPER_NANO_MODEL_SAVE_HELPER_H_ + +#include "framework/common/helper/model_helper.h" + +namespace ge { +class GE_FUNC_VISIBILITY NanoModelSaveHelper : public ModelHelper { + public: + NanoModelSaveHelper() = default; + virtual ~NanoModelSaveHelper() override = default; + NanoModelSaveHelper(const NanoModelSaveHelper &) = default; + NanoModelSaveHelper &operator=(const NanoModelSaveHelper &) & = default; + + virtual Status SaveToOmRootModel(const GeRootModelPtr &ge_root_model, const std::string &output_file, + ModelBufferData &model, const bool is_unknown_shape) override; + + virtual void SetSaveMode(const bool val) override { + is_offline_ = val; + } + + private: + Status SaveToDbg(const GeModelPtr &ge_model, const std::string &output_file) const; + Status SaveToExeOmModel(const GeModelPtr &ge_model, const std::string &output_file, + ModelBufferData &model); + Status SaveAllModelPartiton(shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + const size_t model_index = 0UL); + Status SaveModelDesc(std::shared_ptr &om_file_save_helper, + const GeModelPtr &ge_model, const size_t model_index = 0UL); + Status SaveStaticTaskDesc(std::shared_ptr &om_file_save_helper, + const GeModelPtr &ge_model, const size_t model_index = 0UL); + Status SaveDynamicTaskDesc(std::shared_ptr &om_file_save_helper, + const GeModelPtr &ge_model, const size_t model_index = 0UL); + Status SaveModelWeights(std::shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + const size_t model_index = 0UL) const; + Status SaveTaskParam(std::shared_ptr &om_file_save_helper, + const GeModelPtr &ge_model, const size_t model_index = 0UL); + Status SaveModelTbeKernel(shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + const size_t model_index = 0UL) const; + Status 
SaveModelDescExtend(shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + const size_t model_index = 0UL); + + private: + void SetModelDescInfo(const std::shared_ptr &buff) { model_desc_info_ = buff; } + void SetStaticTaskInfo(const std::shared_ptr &buff) { static_task_info_ = buff; } + void SetDynamicTaskInfo(const std::shared_ptr &buff) { dynamic_task_info_ = buff; } + void SetTaskParamInfo(const std::shared_ptr &buff) { task_param_info_ = buff; } + void SetModelExtendInfo(const std::shared_ptr &buff) { model_extend_info_ = buff; } + + std::shared_ptr model_desc_info_ = nullptr; + std::shared_ptr static_task_info_ = nullptr; + std::shared_ptr dynamic_task_info_ = nullptr; + std::shared_ptr task_param_info_ = nullptr; + std::shared_ptr model_extend_info_ = nullptr; + bool is_offline_ = true; + std::unordered_map search_ids_; +}; +} // namespace ge +#endif // INC_FRAMEWORK_COMMON_HELPER_NANO_MODEL_SAVE_HELPER_H_ diff --git a/inc/air/framework/common/helper/om_file_helper.h b/inc/air/framework/common/helper/om_file_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..b9f4125da6c2bdfb2cf3577290aea92f46f43878 --- /dev/null +++ b/inc/air/framework/common/helper/om_file_helper.h @@ -0,0 +1,96 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_COMMON_HELPER_OM_FILE_HELPER_H_ +#define INC_FRAMEWORK_COMMON_HELPER_OM_FILE_HELPER_H_ + +#include +#include + +#include "external/ge/ge_ir_build.h" +#include "framework/common/types.h" +#include "framework/common/ge_types.h" +#include "graph/model.h" + +namespace ge { +struct ModelPartition { + ModelPartitionType type; + const uint8_t *data = nullptr; + uint64_t size = 0UL; +}; + +struct OmFileContext { + std::vector partition_datas_; + std::vector partition_table_; + uint64_t model_data_len_ = 0UL; +}; + +class GE_FUNC_VISIBILITY OmFileLoadHelper { + public: + Status Init(const ModelData &model); + + Status Init(uint8_t *const model_data, const uint32_t model_data_size); + + Status Init(uint8_t *const model_data, const uint32_t model_data_size, const uint32_t model_num); + + Status Init(uint8_t *const model_data, const uint64_t model_data_size, + const ModelFileHeader *file_header = nullptr); + + Status Init(uint8_t *const model_data, + const uint64_t model_data_size, + const uint32_t model_num, + const ModelFileHeader *file_header = nullptr); + + Status GetModelPartition(const ModelPartitionType type, ModelPartition &partition); + + Status GetModelPartition(const ModelPartitionType type, ModelPartition &partition, const size_t model_index) const; + + const std::vector &GetModelPartitions(const size_t model_index) const; + + Status CheckModelCompatibility(const Model &model) const; + + static bool CheckPartitionTableNum(const uint32_t partition_num); + + OmFileContext context_; + + std::vector model_contexts_; + + private: + Status LoadModelPartitionTable(const uint8_t *const model_data, const uint64_t model_data_size, + const size_t model_index, size_t &mem_offset, + const ModelFileHeader *file_header = nullptr); + + Status LoadModelPartitionTable(const uint8_t *const model_data, + const uint64_t model_data_size, + 
const uint32_t model_num, + const ModelFileHeader *file_header = nullptr); + + bool is_inited_{false}; +}; + +class GE_FUNC_VISIBILITY OmFileSaveHelper { + public: + ModelFileHeader &GetModelFileHeader() { return model_header_; } + + ModelPartitionTable *GetPartitionTable(); + + Status AddPartition(const ModelPartition &partition); + + Status AddPartition(const ModelPartition &partition, const size_t cur_index); + + Status SaveModel(const char_t *const output_file, ModelBufferData &model, const bool save_to_file = true); + + ModelPartitionTable *GetPartitionTable(const size_t cur_ctx_index); + + private: + ModelFileHeader model_header_; + std::vector model_contexts_; +}; +} // namespace ge +#endif // INC_FRAMEWORK_COMMON_HELPER_OM_FILE_HELPER_H_ diff --git a/inc/air/framework/common/helper/pre_model_helper.h b/inc/air/framework/common/helper/pre_model_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..0ba93d86f06c98b7641e7696c8a64a8c9da99999 --- /dev/null +++ b/inc/air/framework/common/helper/pre_model_helper.h @@ -0,0 +1,63 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_COMMON_HELPER_PRE_MODEL_HELPER_H_ +#define INC_FRAMEWORK_COMMON_HELPER_PRE_MODEL_HELPER_H_ + +#include "framework/common/helper/model_helper.h" + +namespace ge { +class GE_FUNC_VISIBILITY PreModelHelper : public ModelHelper { + public: + PreModelHelper() = default; + virtual ~PreModelHelper() override = default; + PreModelHelper(const PreModelHelper &) = default; + PreModelHelper &operator=(const PreModelHelper &) & = default; + + Status SaveToOmRootModel(const GeRootModelPtr &ge_root_model, const std::string &output_file, + ModelBufferData &model, const bool is_unknown_shape) override; + + private: + Status SaveToExeOmModel(const GeModelPtr &ge_model, const std::string &output_file, ModelBufferData &model, + const GeRootModelPtr &ge_root_model = nullptr); + Status SaveAllModelPartiton(std::shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + const size_t model_index = 0UL); + Status SavePreModelHeader(std::shared_ptr &om_file_save_helper, const GeModelPtr &ge_model) const; + Status SaveModelDesc(std::shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + const size_t model_index = 0UL); + Status SaveTaskDesc(std::shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + const size_t model_index = 0UL); + Status SaveKernelArgs(std::shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + const size_t model_index = 0UL); + Status SaveKernelBin(std::shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + const size_t model_index = 0UL); + Status SaveModelCustAICPU(std::shared_ptr &om_file_save_helper, const GeModelPtr &ge_model, + const size_t model_index = 0UL) const; + + private: + void SetKernelArgsInfo(const std::shared_ptr &buff) { + kernel_args_info_ = buff; + } + void SetTaskDescInfo(const std::shared_ptr &buff) { + task_desc_info_ = buff; + } + void SetModelDescInfo(const 
std::shared_ptr &buff) { + model_desc_info_ = buff; + } + void SetKernelBinInfo(const std::shared_ptr &buff) { + kernel_bin_info_ = buff; + } + + std::shared_ptr kernel_args_info_ = nullptr; + std::shared_ptr task_desc_info_ = nullptr; + std::shared_ptr model_desc_info_ = nullptr; + std::shared_ptr kernel_bin_info_ = nullptr; +}; +} // namespace ge +#endif // INC_FRAMEWORK_COMMON_HELPER_PRE_MODEL_HELPER_H_ diff --git a/inc/air/framework/common/host_resource_center/host_resource_center.h b/inc/air/framework/common/host_resource_center/host_resource_center.h new file mode 100644 index 0000000000000000000000000000000000000000..96c02e216eb2ccebe41b569544173c6f4e7eef3b --- /dev/null +++ b/inc/air/framework/common/host_resource_center/host_resource_center.h @@ -0,0 +1,33 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef AIR_CXX_HOST_RESOURCE_CENTER_H +#define AIR_CXX_HOST_RESOURCE_CENTER_H +#include +#include "common/host_resource_center/host_resource_manager.h" +namespace ge { +using HostResourceMgrPtr = std::unique_ptr; +enum class HostResourceType : uint32_t { + kWeight = 0, // including const/constant + kTilingData, + kNum, +}; +class HostResourceCenter { + public: + HostResourceCenter(); + const HostResourceManager *GetHostResourceMgr(HostResourceType type) const; + Status TakeOverHostResources(const ComputeGraphPtr &root_graph); + ~HostResourceCenter() = default; + + private: + std::array(HostResourceType::kNum)> resource_managers_{nullptr}; +}; +using HostResourceCenterPtr = std::shared_ptr; +} // namespace ge +#endif // AIR_CXX_HOST_RESOURCE_CENTER_H \ No newline at end of file diff --git a/inc/air/framework/common/host_resource_center/host_resource_manager.h b/inc/air/framework/common/host_resource_center/host_resource_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..2dcd59fde272f426f80c7e958a5e88308326ae76 --- /dev/null +++ b/inc/air/framework/common/host_resource_center/host_resource_manager.h @@ -0,0 +1,33 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef AIR_CXX_HOST_RESOURCE_MANAGER_H +#define AIR_CXX_HOST_RESOURCE_MANAGER_H + +#include "graph/host_resource/host_resource.h" +#include "graph/detail/attributes_holder.h" +#include "ge/ge_api_types.h" + +namespace ge { +class HostResourceManager { + public: + HostResourceManager() = default; + virtual const HostResource *GetResource(const std::shared_ptr &attr_holder, int64_t type) const = 0; + virtual Status AddResource(const std::shared_ptr &attr_holder, int64_t type, + std::shared_ptr &host_resource) = 0; + virtual Status TakeoverResources(const std::shared_ptr &attr_holder) = 0; + virtual ~HostResourceManager() = default; + + HostResourceManager(const HostResourceManager &host_resource_mgr) = delete; + HostResourceManager(const HostResourceManager &&host_resource_mgr) = delete; + HostResourceManager &operator=(const HostResourceManager &host_resource_mgr) = delete; + HostResourceManager &operator=(HostResourceManager &&host_resource_mgr) = delete; +}; +} // namespace ge +#endif // AIR_CXX_HOST_RESOURCE_MANAGER_H diff --git a/inc/air/framework/common/l2_cache_optimize.h b/inc/air/framework/common/l2_cache_optimize.h new file mode 100644 index 0000000000000000000000000000000000000000..6f7081a8896c6e7aae7b5b8a852bf11c24518428 --- /dev/null +++ b/inc/air/framework/common/l2_cache_optimize.h @@ -0,0 +1,33 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_FRAMEWORK_COMMON_L2_CACHE_OPTIMIZE_H_
+#define INC_FRAMEWORK_COMMON_L2_CACHE_OPTIMIZE_H_
+
+#include
+
+#include
+#include
+#include
+#include
+
+#include "framework/common/types.h"
+#include "framework/common/util.h"
+#include "graph/compute_graph.h"
+
+namespace ge {
+// Size of RC memory alignment: 2 MiB (2 * 1024 * 1024 = 2097152 bytes)
+constexpr size_t ALIGN_SIZE = 2097152U;
+
+constexpr uint32_t RC_VALUE_DEFAULT = 1U;  // NOTE(review): "RC" is not defined in this header - confirm meaning
+constexpr uint32_t RC_VALUE_MAX = 32U;     // presumably the largest accepted RC value - confirm against users
+
+}  // namespace ge
+
+#endif  // INC_FRAMEWORK_COMMON_L2_CACHE_OPTIMIZE_H_
diff --git a/inc/air/framework/common/op/ge_op_utils.h b/inc/air/framework/common/op/ge_op_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..f75731372117bb4452a74cdd0d68d1c57744c233
--- /dev/null
+++ b/inc/air/framework/common/op/ge_op_utils.h
@@ -0,0 +1,103 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_FRAMEWORK_COMMON_OP_GE_OP_UTILS_H_
+#define INC_FRAMEWORK_COMMON_OP_GE_OP_UTILS_H_
+
+#include
+#include
+
+#include "graph/debug/ge_attr_define.h"
+#include "framework/common/util.h"
+#include "graph/attr_value.h"
+#include "graph/ge_tensor.h"
+#include "graph/node.h"
+#include "graph/op_desc.h"
+#include "proto/insert_op.pb.h"
+
+namespace ge {
+
+// Add Sub Mul
+extern GE_FUNC_VISIBILITY const uint32_t ADD_INPUT_NUM;
+extern GE_FUNC_VISIBILITY const uint32_t MUL_INPUT_NUM;
+
+// Permute
+extern GE_FUNC_VISIBILITY const int32_t PERMUTE_ORDER_NUM;
+
+// Ssd PriorBox
+extern GE_FUNC_VISIBILITY const float64_t SSD_PRIORBOX_ASPECT_RATIO_VALUE;
+
+extern GE_FUNC_VISIBILITY const uint32_t STRIDEDSLICE_INPUT_NUM;
+
+// Switch
+extern GE_FUNC_VISIBILITY const uint32_t SWITCH_INPUT_NUM;
+extern GE_FUNC_VISIBILITY const uint32_t SWITCH_OUTPUT_NUM;
+extern GE_FUNC_VISIBILITY const uint32_t SWITCH_FALSE_OUTPUT;
+extern GE_FUNC_VISIBILITY const uint32_t SWITCH_TRUE_OUTPUT;
+extern GE_FUNC_VISIBILITY const uint32_t SWITCH_DATA_INPUT;
+extern GE_FUNC_VISIBILITY const uint32_t SWITCH_PRED_INPUT;
+
+// Merge
+extern GE_FUNC_VISIBILITY const int32_t MERGE_DATA_OUTPUT;
+extern GE_FUNC_VISIBILITY const int32_t MERGE_INDEX_OUTPUT;
+
+// FunctionOp
+extern GE_FUNC_VISIBILITY const uint32_t IF_COND_INPUT;
+extern GE_FUNC_VISIBILITY const uint32_t FOR_START_INPUT;
+extern GE_FUNC_VISIBILITY const uint32_t FOR_LIMIT_INPUT;
+extern GE_FUNC_VISIBILITY const uint32_t FOR_DELTA_INPUT;
+extern GE_FUNC_VISIBILITY const uint32_t FOR_DATA_INPUT;
+
+extern GE_FUNC_VISIBILITY const int32_t NORMAL_TENSOR_SIZE;
+/*lint -e148*/
+class GE_FUNC_VISIBILITY OpUtils {
+ public:
+  ///
+  /// @brief Extract AIPP parameters from AttrDefMap and splice them
+  /// @param [in] aipp_attr attr of operator
+  /// @param [out] aipp_params aipp parameters
+  /// @return Status result of the conversion
+  ///
+
+  static Status ConvertAippParams(const GeAttrValue::NamedAttrs &aipp_attr,
+                                  domi::AippOpParams &aipp_params);
+  template
+  static void SliceData(const std::vector &input, const int64_t chunk_size, std::vector &output,
+                        const int64_t begin, const int64_t out_dim, const int64_t stride);
+  template
+  static Status SetDataByDataType(const size_t out_size, const std::vector &chunk_input,
+                                  const std::vector &chunk_output, GeTensor *const output);
+  template
+  static Status SetOutputSliceDataByDataType(void *const data, const int64_t data_size,
+                                             const std::vector &input_dims,
+                                             const std::vector &begin,
+                                             const std::vector &output_dims,
+                                             ge::GeTensor *const output, const std::vector &stride);
+  static Status SetOutputSliceData(void *const data, const int64_t data_size, const int32_t data_type,
+                                   const std::vector &input_dims, const std::vector &begin,
+                                   const std::vector &output_dims, GeTensor *const output,
+                                   const std::vector &stride);
+  static Status GetShapeDataFromConstTensor(const ConstGeTensorPtr &tensor, const DataType type,
+                                            std::vector &dims);
+  /// Return true if node is hcom node and does not support addr refresh
+  /// which means its io can not support zero copy
+  /// \param op_desc op description of the node to check
+  /// \return true if the node is an hcom node without addr refresh support
+  static bool IsHcomNodeNotSupportAddrRefresh(const OpDescPtr &op_desc);
+  ///
+  /// @brief get memory size for constant and const with DT_STRING data type
+  /// @param [in] op_desc
+  /// @param [out] mem_size
+  /// @return SUCCESS
+  ///
+  static Status GetConstantStrMemSize(const OpDescPtr &op_desc, int64_t &mem_size);
+};
+/*lint +e148*/
+}  // namespace ge
+#endif  // INC_FRAMEWORK_COMMON_OP_GE_OP_UTILS_H_
diff --git a/inc/air/framework/common/op/op_parser_util.h b/inc/air/framework/common/op/op_parser_util.h
new file mode 100644
index 0000000000000000000000000000000000000000..2e7fefcd84af4e80a41d2ecdc4885ea3743e55ad
--- /dev/null
+++ b/inc/air/framework/common/op/op_parser_util.h
@@ -0,0 +1,412 @@
+/* Copyright (c) 2024 Huawei
Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_ +#define INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_ + +#include +#include +#include + +namespace ge { +// general +const float DEFAULT_ALPHA_VALUE = 1.0; +const float DEFAULT_BETA_VALUE = 0.0; +const uint32_t NORMAL_INPUT_NUM = 1; +const uint32_t NORMAL_OUTPUT_NUM = 1; +const uint32_t NORMAL_WORKSPACE_NUM = 0; +const int32_t NORMAL_1D_DIM_NUM = 1; +const int32_t NORMAL_SCALE_DIM_NUM = 0; +const int32_t NORMAL_TENSOR_SIZE = 4; +const uint32_t DEFAULT_REAL_DIM_CNT = 4; + +// const +const uint32_t CONST_OP_INPUT_NUM = 0; +const uint32_t CONST_OP_NORMAL_WEIGHT_SIZE = 1; + +// MatMul +const uint32_t MATMUL_INPUT_NUM = 2; + +// ActivationGrad +const int32_t ACTIVATIONGRAD_INPUT_NUM = 2; + +// FusedBatchNorm +const int32_t FUSED_BATCH_NORM_WORKSPACE_NUM = 1; +const int32_t FUSED_BATCH_NORM_INPUT_NUM = 5; +const int32_t FUSED_BATCH_NORM_OUTPUT_NUM = 5; +// FusedBatchNormGrad +const int32_t FUSEDBATCHNORMGRAD_WORKSPACE_NUM = 1; +const int32_t FUSEDBATCHNORMGRAD_INPUT_NUM = 5; +const int32_t FUSEDBATCHNORMGRAD_OUTPUT_NUM = 3; + +// conv +const uint32_t CONVOLUTION_WORKSPACE_NUM = 1; +const uint32_t CONVOLUTION_PAD_SIZE = 4; +const uint32_t CONVOLUTION_STRIDE_SIZE = 2; +const uint32_t CONVOLUTION_DILATION_SIZE = 2; +const int32_t 
CONVOLUTION_ADJ_SIZE = 2; +const int32_t CONVOLUTION_TARGET_SHAPE_SIZE = 2; + +// ConvGradFilter +const uint32_t CONVGRADFILTER_WORKSPACE_NUM = 1; +const uint32_t CONVGRADFILTER_INPUT_NUM = 3; + +// Pooling +const uint32_t POOLING_WINDOW_SIZE = 2; +const uint32_t POOLING_STRIDE_SIZE = 2; +const uint32_t POOLING_PAD_SIZE = 4; + +// Add Sub Mul +const uint32_t ADD_INPUT_NUM = 2; +const uint32_t SUB_INPUT_NUM = 2; +const uint32_t MUL_INPUT_NUM = 2; +const uint32_t DIV_INPUT_NUM = 2; +const uint32_t ADD_WORKSPACE_NUM = 1; +const uint32_t SUB_WORKSPACE_NUM = 1; +const uint32_t MUL_WORKSPACE_NUM = 1; +const uint32_t DIV_WORKSPACE_NUM = 1; + +const int32_t DEFAULT_AXIS_VALUE = -1; + +const int32_t RESHAPE_AXIS_DEFAULT_VALUE = 0; +const int32_t RESHAPE_NUM_AXES_DEFAULT_VALUE = -1; +const uint32_t RESHAPE_WORKSPACE_NUM = 1; + +const uint32_t FLATTEN_WORKSPACE_NUM = 1; + +const int32_t CONCAT_MIN_INPUT_SIZE = 1; +const int32_t CONCAT_DEFAULT_AXIS = 1; +const uint32_t CONCAT_WORKSPACE_NUM = 1; + +// The value for LRN parameters +const uint32_t LRN_DEFAULT_NORM_REGION = 0; +const float LRN_DEFAULT_K = 1.0; +const uint32_t LRN_DEFAULT_LOCAL_SIZE = 5; +const float LRN_DEFAULT_ALPHA = 1.0; +const float LRN_DEFAULT_BETA = 0.75; + +/// +/// @ingroup domi_common +/// @brief roipooling default value +/// +const uint32_t ROIPOOLING_DEFAULT_POOLED_H = 0; +const uint32_t ROIPOOLING_DEFAULT_POOLED_W = 0; +const float ROIPOOLING_DEFAULT_SPATIAL_SCALE = 1; +const int32_t ROIPOOLING_DEFAULT_SAMPLING_RATIO = -1; + +// DetectionOutput +const int32_t DETECTIONOUTPUT_INPUT_SIZE = 3; +const int32_t DETECTIONOUTPUT_OUTPUT_SIZE = 2; +const int32_t DETECTIONOUTPUT_WORKSPACE_NUM = 1; +const int32_t DETECTIONOUTPUT_CLASS_NUM = 20; // Number of background categories +const int32_t DETECTIONOUTPUT_NUM_CLASSES_DEFAULT_VALUE = 21; +const float DETECTIONOUTPUT_NMS_THRESHOLD_DEFAULT_VALUE = 0.3; +const float DETECTIONOUTPUT_CONFIDENCE_THRESHOLD_DEFAULT_VALUE = 0.8; + +// Proposal +const int32_t 
PROPOSAL_INPUT_SIZE = 3; +const int32_t PROPOSAL_OUTPUT_MAX_SIZE = 2; +const int32_t PROPOSAL_WORKSPACE_NUM = 1; +const float PROPOSAL_BASE_SIZE_DEFAULT_VALUE = 16; +const float PROPOSAL_RATIO_DIM_0_DEFAULT_VALUE = 0.5; +const float PROPOSAL_RATIO_DIM_1_DEFAULT_VALUE = 1; +const float PROPOSAL_RATIO_DIM_2_DEFAULT_VALUE = 2; +const float PROPOSAL_SCALE_DIM_0_DEFAULT_VALUE = 8; +const float PROPOSAL_SCALE_DIM_1_DEFAULT_VALUE = 16; +const float PROPOSAL_SCALE_DIM_2_DEFAULT_VALUE = 32; +const float PROPOSAL_MIN_SIZE_DEFAULT_VALUE = 16; +const int32_t PROPOSAL_PRE_NMS_TOPN_DEFAULT_VALUE = 6000; +const int32_t PROPOSAL_POST_NMS_TOPN_DEFAULT_VALUE = 304; +const float PROPOSAL_NMS_THRESH_DEFAULT_VALUE = 0.7; +const float PROPOSAL_FILTER_THRESH_DEFAULT_VALUE = 0; + +// TVM OP +const uint32_t DEFAULT_KERNEL_BLOCK_DIM = 1; + +// Softmax +const int32_t SOFTMAX_WORKSPACE_NUM = 1; + +// SoftmaxCrossEntropy +const int32_t SOFTMAXCROSSENTROPY_INPUT_NUM = 2; +const int32_t SOFTMAXCROSSENTROPY_OUTPUT_NUM = 2; + +// Permute +const int32_t PERMUTE_INPUT_NUM = 1; +const int32_t PERMUTE_OUTPUT_NUM = 1; +const int32_t PERMUTE_WORKSPACE_NUM = 1; +const int32_t PERMUTE_ORDER_NUM = 4; + +// Ssd normalize +const int32_t SSD_NORMALIZE_INPUT_SIZE = 1; +const float SSD_NORMALIZE_EPS_DEFAULT_VALUE = 2e-7; + +// SsdPriroBox +const int32_t SSD_PRIOR_BOX_WORKSPACE_NUM = 1; +const int32_t SSD_PRIOR_BOX_INPUT_NUM = 2; +const bool SSD_PRIOR_BOX_FLIP_VALUE = true; +const bool SSD_PRIOR_BOX_CLIP_VALUE = false; +const double SSD_PRIOR_BOX_ASPECT_OFFSET_VALUE = 0.5; +const double SSD_PRIORBOX_VARIANCE_VALUE = 0.1; +const double SSD_PRIORBOX_VARIANCE_SIZE_ONE = 1; +const double SSD_PRIORBOX_VARIANCE_SIZE_FOUR = 4; +const double SSD_PRIORBOX_ASPECT_RATIO_VALUE = 1.0; +const int32_t SSD_PRIOR_BOX_CODETYPE_CORNER_VALUE = 1; +const int32_t SSD_PRIOR_BOX_CODETYPE_CENTER_SIZE_VALUE = 2; +const int32_t SSD_PRIOR_BOX_CODETYPE_CORNER_SIZE_VALUE = 3; + +// Ssd DetectionOutput +const int32_t 
SSD_DETECTIONOUTPUT_INPUT_SIZE = 3; +const int32_t SSD_DETECTIONOUTPUT_INPUT_SIZE_AFTER_FUSION = 2; +const int32_t SSD_DETECTIONOUTPUT_OUTPUT_SIZE = 2; +const int32_t SSD_DETECTIONOUTPUT_OUTPUT_SIZE_AFTER_FUSION = 3; +const int32_t SSD_DETECTIONOUTPUT_WORKSPACE_NUM = 1; +const int32_t SSD_DETECTIONOUTPUT_WORKSPACE_NUM_AFTER_FUSION = 0; +const bool SSD_DETECTIONOUTPUT_SHARED_LOCATION_DEFAULT_VALUE = true; +const int32_t SSD_DETECTIONOUTPUT_BACKGROUND_LABEL_ID_DEFAULT_VALUE = 0; +const float SSD_DETECTIONOUTPUT_NMS_THRESHOLD_DEFAULT_VALUE = 0.3; +const int32_t SSD_DETECTIONOUTPUT_TOP_K_DEFAULT_VALUE = 200; +const float SSD_DETECTIONOUTPUT_ETA_DEFAULT_VALUE = 1.0; +const int32_t SSD_DETECTIONOUTPUT_KEEP_TOP_K_DEFAULT_VALUE = 200; +const bool SSD_DETECTIONOUTPUT_VARIANCE_ENCODED_IN_TARGET_DEFAULT_VALUE = false; +const float SSD_DETECTIONOUTPUT_CONFIDENCE_THRESHOLD_DEFAULT_VALUE = 0.1; + +// Refinedet DetectionOutput +const int32_t REFINEDET_DETECTIONOUTPUT_INPUT_SIZE = 5; +const int32_t REFINEDET_DETECTIONOUTPUT_INPUT_SIZE_AFTER_FUSION = 2; +const int32_t REFINEDET_DETECTIONOUTPUT_OUTPUT_SIZE = 2; +const int32_t REFINEDET_DETECTIONOUTPUT_OUTPUT_SIZE_AFTER_FUSION = 3; +const int32_t REFINEDET_DETECTIONOUTPUT_WORKSPACE_NUM = 1; +const bool REFINEDET_DETECTIONOUTPUT_SHARED_LOCATION_DEFAULT_VALUE = true; +const int32_t REFINEDET_DETECTIONOUTPUT_BACKGROUND_LABEL_ID_DEFAULT_VALUE = 0; +const float REFINEDET_DETECTIONOUTPUT_NMS_THRESHOLD_DEFAULT_VALUE = 0.3; +const int32_t REFINEDET_DETECTIONOUTPUT_TOP_K_DEFAULT_VALUE = 200; +const float REFINEDET_DETECTIONOUTPUT_ETA_DEFAULT_VALUE = 1.0; +const bool REFINEDET_DETECTIONOUTPUT_VARIANCE_ENCODED_IN_TARGET_DEFAULT_VALUE = false; +const int32_t REFINEDET_DETECTIONOUTPUT_KEEP_TOP_K_DEFAULT_VALUE = 200; +const float REFINEDET_DETECTIONOUTPUT_CONFIDENCE_THRESHOLD_DEFAULT_VALUE = 0.1; +const float REFINEDET_DETECTIONOUTPUT_OBJECTNESS_SCORE_DEFAULT_VALUE = 0; + +// Channel axpy +const int32_t CHANNEL_AXPY_INPUT_NUM = 3; +const int32_t 
CHANNEL_AXPY_INPUT_DIM_SIZE = 4; +const int32_t CHANNEL_AXPY_WORKSPACE_NUM = 1; + +// Psroi pooling +const int32_t PSROI_POOLING_INPUT_COUNT = 2; +const int32_t PSROI_POOLING_WORKSPACE_NUM = 1; + +// MaxPoolWithArgmax +const uint32_t MAX_POOL_WITH_ARGMAX_OUTPUT_NUM = 2; +const uint32_t MAX_POOL_GRAD_WITH_ARGMAX_INPUT_NUM = 3; + +// AvgPoolGrad +const uint32_t AVG_POOL_GRAD_INPUT_NUM = 2; + +// ROIAlign +const int32_t ROIALIGN_INPUT_SIZE = 2; +const int32_t ROIALIGN_WORKSPACE_NUM = 1; +const int32_t ROIALIGN_DEFAULT_POOLED_H = 1; +const int32_t ROIALIGN_DEFAULT_POOLED_W = 1; + +// Correlation +const uint32_t CORRELATION_INPUT_NUM = 2; +const int32_t CORRELATION_WORKSPACE_NUM = 1; + +// Detectionpostprocess +const int32_t POSTPROCESS_INPUT_SIZE = 4; +const int32_t POSTPROCESS_OUTPUT_SIZE = 2; +const int32_t POSTPROCESS_WORKSPACE_NUM = 1; +const uint32_t POSTPROCESS_CLS_NUM_DEFAULT_VALUE = 12; +const uint32_t POSTPROCESS_POST_NMS_TOPN_DEFAULT_VALUE = 100; +const float POSTPROCESS_NMS_THRESH_DEFAULT_VALUE = 0.3; +const float POSTPROCESS_CONF_THRESH_DEFAULT_VALUE = 0.5; +const float POSTPROCESS_BBOX_REG_WEIGHT_DIM_DEFAULT_VALUE = 1.0; +const int32_t POSTPROCESS_BBOX_REG_WEIGHT_SIZE_DEFAULT_VALUE = 4; + +// Split +const int32_t SPLIT_INPUT_NUM = 2; +const int32_t SPLIT_DEFAULT_AXIS_VALUE = 1; +const int32_t SPLIT_MIN_OUTPUT_SIZE = 1; + +const uint32_t STRIDEDSLICE_INPUT_NUM = 4; +// Slice +const int32_t SLICE_INPUT_NUM = 3; +const int32_t SLICE_WEIGHT_NUM = 2; + +// GatherNd +const int32_t GATHERND_INPUT_NUM = 2; +// ArgMax +const int32_t ARGMAX_INPUT_NUM = 2; +const int32_t ARGMAX_REAL_INPUT_NUM = 1; + +// HighWay +const int32_t HIGHWAY_INPUT_NUM = 4; +const int32_t HIGHWAY_WORKSPACE_NUM = 1; +// RealDiv +const int32_t REALDIV_INPUT_NUM = 2; + +// Range +const int32_t RANGE_INPUT_NUM = 3; +const int32_t RANGE_OUTPUT_NUM = 1; +const int32_t RANGE_INPUT_DIM_SIZE = 0; + +// Pad +const int32_t PAD_WEIGHT_NUM = 1; +const int32_t PAD_DIM_SIZE = 2; +const int32_t PAD_DIM0 = 4; 
+const int32_t PAD_DIM1 = 2; +const int32_t PAD_WEIGHT_WITH_CONSTANT_NUM = 2; +const int32_t PAD_CONSTATNT_DEFAULT_VALUE = 0; +const int32_t PAD_PADDINGS_SIZE = 8; + +// Tile +const int32_t TILE_WEIGHT_NUM = 1; +const int32_t TILE_MULTIPLES_DIM_SIZE = 1; + +// DecodeBbox +const int32_t DECODE_BBOX_INPUT_NUM = 2; + +// GenerateRpnProposals +const int32_t GENERATE_RPN_PROPOSAL_INPUT_SIZE = 2; +const int32_t GENERATE_RPN_PROPOSAL_OUTPUT_SIZE = 3; + +// Decode_BBox +const int32_t DECODE_BBOX_INPUT_SIZE = 2; +const int32_t DEFAULT_DECODE_CLIP_VALUE = 0; + +// FastRcnnPredictions +const int32_t FASTRCNN_PREDICTIONS_INPUT_SIZE = 2; +const int32_t FASTRCNN_PREDICTIONS_OUTPUT_SIZE = 4; + +const int32_t CLIP_BOXES_INPUT_NUM = 1; +const int32_t CLIP_BOXES_WEIGHT_SIZE = 1; +const int32_t CLIP_BOXES_WEIGHT_ITEM_SIZE = 2; +const int32_t CLIP_BOXES_OUTPUT_NUM = 1; + +const int32_t FLOORDIV_INPUT_NUM = 2; +// Mean +const int32_t MEAN_WEIGHT_SIZE = 1; +const int32_t MEAN_WEIGHT_DIM_SIZE = 1; +const int32_t MEAN_WEIGHT_DIM = 2; +const int32_t MEAN_FIRST_AXIS = 2; +const int32_t MEAN_SECOND_AXIS = 3; +const int32_t MEAN_STRIDE_PLACE_HOLD = 1; +// Switch +const uint32_t SWITCH_INPUT_NUM = 2; +const uint32_t SWITCH_OUTPUT_NUM = 2; +// Merge +const uint32_t MERGE_INPUT_NUM = 2; +// Greater +const uint32_t GREATER_OUTPUT_NUM = 1; +const uint32_t GREATER_INPUT_NUM = 0; +const uint32_t GREATER_WEIGHT_NUM = 2; + +// Yolo region +const uint32_t YOLO_REGION_OUTPUT_NUM = 3; +const uint32_t YOLO_REGION_WORKSPACE_NUM = 1; +const uint32_t YOLO_REGION_COORDS = 4; +const uint32_t YOLO_REGION_CLASSES = 20; +const uint32_t YOLO_REGION_BOXES = 1; +const bool YOLO_REGION_BACKGROUND = false; +const bool YOLO_REGION_SOFTMAX = false; +const bool YOLO_REGION_SOFTMAX_TREE = false; + +// Yolo detectionoutput +const uint32_t YOLO_DETECTIONOUTPUT_INPUT_SIZE = 4; +const uint32_t YOLO_DETECTIONOUTPUT_OUTPUT_SIZE = 2; +const uint32_t YOLO_DETECTION_OUTPUT_WORKSPACE_NUM = 1; +const uint32_t 
YOLO_DETECTION_OUTPUT_CLASSES = 20; +const uint32_t YOLO_DETECTION_OUTPUT_BOXES_V2 = 5; +const uint32_t YOLO_DETECTION_OUTPUT_BOXES_V3 = 3; +const bool YOLO_DETECTION_OUTPUT_RELATIVE = true; +const float YOLO_DETECTION_OUTPUT_OBJECTNESS_THRESHOLD = 0.5; +const float YOLO_DETECTION_OUTPUT_CLASS_THRESHOLD = 0.5; +const uint32_t YOLO_DETECTION_OUTPUT_POST_TOP_K = UINT_MAX; +const float YOLO_DETECTION_OUTPUT_NMS_THRESHOLD = 0; +const float YOLO_DETECTION_OUTPUT_IOU_THRESHOLD_DECAY = 1.0; +const float YOLO_DETECTION_OUTPUT_COOR_SCALE_FACTOR = 1.0; + +// Reorg +const int32_t REORG_DEFAULT_STRIDE = 2; +const uint32_t REORG_INPUT_COUNT = 1; +// Reshape +const int32_t RESHAPE_INPUT_NUM = 2; +// Maximum +const int32_t MAXIMUM_INPUT_NUM = 2; + +// Spatialtf +const int32_t SPATIALTF_WORKSPACE_NUM = 1; + +const int32_t REVERSE_DEFAULT_AXIS = 1; +// Crop +const int32_t CROP_AXIS = 2; +const int32_t CROP_INPUT_NUM = 2; + +// ConvGradInput +const uint32_t CONVGRADINPUT_WORKSPACE_NUM = 1; +const uint32_t CONVGRADINPUT_INPUT_NUM = 3; + +// RNN +const uint32_t RNN_WORKSPACE_NUM = 1; + +// Cropandresize +const int32_t CROPANDRESIZE_WEIGHT_NUM = 1; +const int32_t CROPANDRESIZE_CROP_DIM_SIZE = 1; +const int32_t CROP_DIM0 = 2; + +// Attention decoder weight index +const uint32_t ATTENTION_DECODER_WEIGHT_ATTENW0 = 0; +const uint32_t ATTENTION_DECODER_WEIGHT_ATTENTION0_KERNEL = 1; +const uint32_t ATTENTION_DECODER_WEIGHT_ATTNOUTPUTPROJECTION_KERNEL = 2; +const uint32_t ATTENTION_DECODER_WEIGHT_ATTENTION_DECODER_KERNEL = 3; +const uint32_t ATTENTION_DECODER_WEIGHT_CELL0_GATES_KERNEL = 4; +const uint32_t ATTENTION_DECODER_WEIGHT_CELL0_CANDIDATE_KERNEL = 5; +const uint32_t ATTENTION_DECODER_WEIGHT_CELL1_GATES_KERNEL = 6; +const uint32_t ATTENTION_DECODER_WEIGHT_CELL1_CANDIDATE_KERNEL = 7; +const uint32_t ATTENTION_DECODER_WEIGHT_ATTENTION0_BIAS = 8; +const uint32_t ATTENTION_DECODER_WEIGHT_ATTNOUTPUTPROJECTION_BIAS = 9; +const uint32_t ATTENTION_DECODER_WEIGHT_ATTENTION_DECODER_BIAS = 10; 
+const uint32_t ATTENTION_DECODER_WEIGHT_CELL0_GATES_BIAS = 11; +const uint32_t ATTENTION_DECODER_WEIGHT_CELL0_CANDIDATE_BIAS = 12; +const uint32_t ATTENTION_DECODER_WEIGHT_CELL1_GATES_BIAS = 13; +const uint32_t ATTENTION_DECODER_WEIGHT_CELL1_CANDIDATE_BIAS = 14; +const uint32_t ATTENTION_DECODER_WEIGHT_EMBEDDING = 15; +const uint32_t ATTENTION_DECODER_WEIGHT_ATTENVA = 16; +const uint32_t ATTENTION_DECODER_WEIGHT_DECODER_INITIAL = 17; +// Attention decoder weight size +const uint32_t ATTENTION_DECODER_WEIGHT_SIZE = 18; + +const uint32_t ATTENTION_DECODER_INPUT_SIZE = 2; +const uint32_t ATTENTION_DECODER_WORKSPACE_NUM = 1; +const uint32_t ATTENTION_DECODER_INPUT_DECODER_INPUTS = 0; +const uint32_t ATTENTION_DECODER_INPUT_DECODER_INITIAL_HIDDEN = 1; + +const int32_t ATTENTION_DECODER_ALGO_NORMAL = 0; +const int32_t ATTENTION_DECODER_SYMBOLS = 10000; +const int32_t ATTENTION_DECODER_EMBEDDING_SIZE = 128; +const int32_t ATTENTION_DECODER_ATTENTION_NUM_HIDDEN = 256; +const int32_t ATTENTION_DECODER_DECODER_NUM_HIDDEN = 128; +const int32_t ATTENTION_DECODER_DECODER_NUM_LAYERS = 2; +const int32_t ATTENTION_DECODER_RNN_UNBIDIRECTIONAL = 0; +const int32_t ATTENTION_DECODER_SEQLEN_VALUE = 57; +const int32_t ATTENTION_DECODER_GRU = 3; + +// Logicaland +const int32_t LOGICAL_AND_INPUT_NUM = 2; +const int32_t EQUAL_INPUT_NUM = 2; + +static const int32_t OP_WEIGHT_MEM_BASE_OFFSET = 512; + +// MultiShape +const uint32_t MULTI_SHAPE_INPUT_NUM = 2; + +// Shufflechannel +const uint32_t SHUFFLECHANNEL_DEFAULT_GROUP = 1; +} // namespace ge +#endif // INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_ diff --git a/inc/air/framework/common/op_types.h b/inc/air/framework/common/op_types.h new file mode 100644 index 0000000000000000000000000000000000000000..edce0f7d2f89db9de096ef68cd706cd97d9ebfe8 --- /dev/null +++ b/inc/air/framework/common/op_types.h @@ -0,0 +1,51 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_FRAMEWORK_COMMON_OP_TYPES_H_
+#define INC_FRAMEWORK_COMMON_OP_TYPES_H_
+
+#include <set>
+#include <string>
+
+#include "graph/types.h"
+#include "graph/compiler_def.h"
+
+namespace ge {
+class GE_FUNC_VISIBILITY OpTypeContainer {  // registry of op type names, reached through Instance()
+ public:
+  static OpTypeContainer *Instance() {
+    static OpTypeContainer instance;
+    return &instance;
+  }
+  ~OpTypeContainer() = default;
+  // Adds op_type to the registry; returns false when it was already registered.
+  bool Register(const std::string &op_type) { return op_type_list_.insert(op_type).second; }
+  // Returns true when op_type has been registered. Read-only, hence const.
+  bool IsExisting(const std::string &op_type) const {
+    return op_type_list_.find(op_type) != op_type_list_.end();
+  }
+
+ protected:
+  OpTypeContainer() = default;
+
+ private:
+  std::set<std::string> op_type_list_;
+};
+}  // namespace ge
+
+#define REGISTER_OPTYPE_DECLARE(var_name, str_name) \
+  FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char_t *var_name
+
+#define REGISTER_OPTYPE_DEFINE(var_name, str_name) \
+  const char_t *var_name = str_name;               \
+  const bool VAR_UNUSED g_##var_name##_reg = ge::OpTypeContainer::Instance()->Register(str_name)
+
+#define IS_OPTYPE_EXISTING(str_name) (ge::OpTypeContainer::Instance()->IsExisting(str_name))
+#endif  // INC_FRAMEWORK_COMMON_OP_TYPES_H_
+
diff --git a/inc/air/framework/common/profiling/ge_profiling.h b/inc/air/framework/common/profiling/ge_profiling.h
new file mode 100644
index 0000000000000000000000000000000000000000..2c2ebf13b73e772ca62c14c983529ed864c749c5
---
/dev/null
+++ b/inc/air/framework/common/profiling/ge_profiling.h
@@ -0,0 +1,22 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_FRAMEWORK_COMMON_GE_PROFILING_H_
+#define INC_FRAMEWORK_COMMON_GE_PROFILING_H_
+
+#include "external/ge/ge_api_error_codes.h"
+#include "runtime/base.h"
+
+/// @brief Output the profiling data of a single operator in PyTorch; multithreading is not supported
+/// @return Status result
+GE_FUNC_VISIBILITY ge::Status ProfSetStepInfo(const uint64_t index_id, const uint16_t tag_id, rtStream_t const stream);
+/// NOTE(review): presumably writes the device of graph_id into device_id ("Form" looks like a typo of "From") - confirm
+GE_FUNC_VISIBILITY ge::Status ProfGetDeviceFormGraphId(const uint32_t graph_id, uint32_t &device_id);
+
+#endif  // INC_FRAMEWORK_COMMON_GE_PROFILING_H_
diff --git a/inc/air/framework/common/profiling/ge_runner_profiling.h b/inc/air/framework/common/profiling/ge_runner_profiling.h
new file mode 100644
index 0000000000000000000000000000000000000000..7a5251dc62741a759a70af1b096003878885dad4
--- /dev/null
+++ b/inc/air/framework/common/profiling/ge_runner_profiling.h
@@ -0,0 +1,17 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_FRAMEWORK_COMMON_GE_RUNNER_PROFILING_H_
+#define INC_FRAMEWORK_COMMON_GE_RUNNER_PROFILING_H_
+
+#include "framework/common/profiling/ge_profiling.h"
+/// NOTE(review): presumably reports whether GE has been initialized - confirm against the implementation
+GE_FUNC_VISIBILITY bool IsInitialize();
+
+#endif  // INC_FRAMEWORK_COMMON_GE_RUNNER_PROFILING_H_
diff --git a/inc/air/framework/common/profiling_definitions.h b/inc/air/framework/common/profiling_definitions.h
new file mode 100644
index 0000000000000000000000000000000000000000..42c2bd7f2a3b707705604a1b0658e7d154543e8f
--- /dev/null
+++ b/inc/air/framework/common/profiling_definitions.h
@@ -0,0 +1,275 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef AIR_CXX_PROFILING_DEFINITIONS_H
+#define AIR_CXX_PROFILING_DEFINITIONS_H
+#include
+#include
+#include
+#include
+#include "graph/utils/profiler.h"
+#include "external/ge/ge_api_types.h"
+#include "toolchain/prof_api.h"
+namespace gert {
+namespace profiling {
+enum {
+  // ACL Interface
+  kAclCreateTensorDesc,
+  kAclSetTensorFormat,
+  kAclSetTensorPlacement,
+  kAclSetTensorShape,
+  kAclSetTensorDescName,
+  kAclCreateDataBuffer,
+  kAclRtMalloc,
+  kAclRtFree,
+  kAclRtMemcpyAsync,
+  kAclRtMemcpy,
+  kAclRtSynchronizeStream,
+  kAclRtStreamWaitEvent,
+  kAclRtSynchronizeDevice,
+  kAclRtDestoryEvent,
+  kAclRtRecordEvent,
+  kAclRtSynchronizeEvent,
+  kAclRtCreateEventWithFlag,
+  kAclRtEventWaitStatus,
+  kAclRtEventRecordedStatus,
+  kAclRtQueryEventStatus,
+  kAclCompileAndExecute,
+  kAclCompileAndExecuteV2,
+  // ACL Internal
+  kAclMatchOpModel,
+  kAclMatchStaticOpModel,
+  kAclMatchDynamicOpModel,
+  kAclExecuteAsync,
+  kAclExecuteSync,
+  kAclLoadSingleOp,
+  kAclBuildOpModel,
+  // inherit from rt2
+  kModelExecute,
+  kInitInferShapeContext,
+  kTiling,
+  kUpdateShape,
+  kAllocMem,
+  kAtomic,
+  kOpExecute,
+  kKernelLaunchPrepare,
+  kInitHybridExecuteArgs,
+  kKnownGetAddrAndPrefCnt,
+  kKernelGetAddrAndPrefCnt,
+  kUpdateAddrAndPrefCnt,
+  kRtEventCreateRecord,
+  kRtEventSync,
+  kRtEventDestroy,
+  // static single op
+  kStaticSingleOpExecute,
+  kStaticSingleOpKernelLaunch,
+  kStaticSingleOpCopyH2D,
+  // model v2 executor
+  kModel,
+  kExecute,
+  // Default
+  kUnknownName,
+  kProfilingIndexEnd
+};
+}
+}  // namespace gert
+
+namespace ge {
+namespace profiling {
+enum {
+  kAclCompileAndExecute,
+  kAclMatchOpModel,
+  kAclMatchStaticOpModel,
+  kAclMatchDynamicOpModel,
+  kAclExecuteAsync,
+  kAclLoadSingleOp,
+  kAclBuildOpModel,
+  kInferShape,
+  kTiling,
+  kUpdateShape,
+  kConstPrepare,
+  kInitHybridExecuteArgs,
+  kInitInferShapeContext,
+  kDestroyInferShapeContext,
+  kResetSubgraphExecutor,
+  kCommitInferShapeTask,
+  kDeviceToHost,
+  kPrepareTask,
+  kLaunchTask,
+  kCommitTilingTask,
+  kAtomic,
+  kKernelLaunchPrepare,
+  kRtKernelLaunch,
+  kRtEventCreateRecord,
+  kRtEventSync,
+  kRtEventDestroy,
+  kRtStreamSync,
+  kOpExecute,
+  kModelExecute,
+  kAllocMem,
+  kCopyH2D,
+  kPrepareNode,
+  kWaitForPrepareDone,
+  kPropgateOutputs,
+  kOnNodeDoneCallback,
+  kValidateInputTensor,
+  kAfterExecuted,
+  kRtEventSychronize,
+  kInferShapeWaitDependShape,
+  kInferShapeWaitInputTensor,
+  kInferShapeCallInferFunc,
+  kInferShapePropgate,
+  // v2 control node
+  kSelectBranch,
+  kExecuteSubGraph,
+  kInitSubGraphExecutor,
+  // fuzz compile
+  kSelectBin,
+  kFindCompileCache,
+  kAddCompileCache,
+  kFuzzCompileOp,
+  kCalcRuningParam,
+  kGenTask,
+  kRegisterBin,
+
+  // FFTS Plus
+  kFftsPlusPreThread,
+  kFftsPlusNodeThread,
+  kFftsPlusInferShape,
+  kOpFftsCalculateV2,
+  kInitThreadRunInfo,
+  kFftsPlusGraphSchedule,
+  kKnownGetAddrAndPrefCnt,
+  kKernelGetAddrAndPrefCnt,
+  kUpdateAddrAndPrefCnt,
+  kInitOpRunInfo,
+  kGetAutoThreadParam,
+  kAllocateOutputs,
+  kAllocateWorkspaces,
+  kInitTaskAddrs,
+  kInitThreadRunParam,
+  kUpdateTaskAndCache,
+  kFftsPlusTaskLaunch,
+
+  // Add new definitions here (kProfilingIndexEnd must stay last; str_index_ below starts from it)
+  kProfilingIndexEnd
+};
+constexpr uint64_t kInvalidHashId = 0UL;
+// Holds the enabled flag, the string-to-index registry and the Profiler that records events.
+class ProfilingContext {
+ public:
+  static bool IsDumpToStdEnabled();
+  static ProfilingContext &GetInstance();
+  ProfilingContext();
+  ~ProfilingContext();
+
+  /*
+   * An alternative design: let `IsEnabled` only check whether profiler_ is null and drop the separate enabled flag.
+   * That would require the profiler_ instance to be a null pointer whenever profiling is not enabled.
+   * For performance, `RegisterString` is called at compile/load time to register strings with profiler_; execution
+   * later uses only the registered indexes. So this scenario exists: profiling is off during compilation (it is
+   * slow, and profiling it would not reflect execution cost), so registering strings then would have no effect;
+   * if profiling is switched on dynamically at execution time, the registered strings would be unavailable.
+   */
+  bool IsEnabled() const noexcept {
+    return enabled_ && (profiler_ != nullptr);
+  }
+  void SetEnable() noexcept {
+    enabled_ = true;
+  }
+  void SetDisable() noexcept {
+    enabled_ = false;
+  }
+
+  void RecordCurrentThread(const int64_t element, const int64_t event, const EventType et,
+                           const std::chrono::time_point time_point) const noexcept {
+    if (IsEnabled()) {  // silently does nothing while profiling is disabled
+      profiler_->RecordCurrentThread(element, event, et, time_point);
+    }
+  }
+
+  void RecordCurrentThread(const int64_t element, const int64_t event, const EventType et) const noexcept {
+    RecordCurrentThread(element, event, et, std::chrono::system_clock::now());  // stamps the record with "now"
+  }
+
+  const Profiler *GetProfiler() const {
+    return profiler_.get();
+  }
+
+  void Dump(std::ostream &out_stream) const {
+    if (IsEnabled()) {
+      profiler_->Dump(out_stream);
+    } else {
+      out_stream << "Profiling not enable, skip to dump" << std::endl;
+    }
+  }
+
+  void DumpToStdOut() const {
+    Dump(std::cout);
+  }
+
+  void Reset() {
+    if (IsEnabled()) {  // note: the profiler is not reset while profiling is disabled
+      profiler_->Reset();
+    }
+  }
+
+  int64_t RegisterString(const std::string &str);
+  int64_t RegisterStringHash(const uint64_t hash_id, const std::string &str);
+  void UpdateElementHashId();
+  size_t GetRegisterStringNum() const {
+    return strings_to_index_.size();  // read without taking strings_to_index_mutex_
+  }
+
+  void Init();
+ private:
+  void UpdateHashByStr(const std::string &str, const uint64_t hash);
+
+ private:
+  bool inited_ = false;
+  bool enabled_ = false;
+  int64_t str_index_ = kProfilingIndexEnd;  // NOTE(review): presumably the next id RegisterString hands out - confirm
+  std::unordered_map strings_to_index_;
+  std::mutex strings_to_index_mutex_;
+  std::unique_ptr profiler_;
+};
+// Records a start/end event pair for (element, event) when the object is destroyed.
+class ScopeProfiler {
+ public:
+  ScopeProfiler(const int64_t element, const int64_t event) : element_(element), event_(event) {
+    if (ProfilingContext::GetInstance().IsEnabled()) {  // start time is captured only if profiling is already on
+      start_trace_ = std::chrono::system_clock::now();
+    }
+  }
+  ~ScopeProfiler() {
+    // NOTE(review): if profiling was enabled after construction, start_trace_ is still default-constructed here.
+    if (ProfilingContext::GetInstance().IsEnabled()) {
+      ProfilingContext::GetInstance().RecordCurrentThread(element_, event_, EventType::kEventStart, start_trace_);
+      ProfilingContext::GetInstance().RecordCurrentThread(element_, event_, EventType::kEventEnd);
+    }
+  }
+  void SetElement(const int64_t element) {
+    element_ = element;
+  }
+
+ private:
+  std::chrono::time_point start_trace_;
+  int64_t element_;
+  int64_t event_;
+};
+}  // namespace profiling
+}  // namespace ge
+#define PROFILING_START(element, event)                                                  \
+  ge::profiling::ProfilingContext::GetInstance().RecordCurrentThread((element), (event), \
+                                                                     ge::profiling::EventType::kEventStart)
+#define PROFILING_END(element, event)                                                    \
+  ge::profiling::ProfilingContext::GetInstance().RecordCurrentThread((element), (event), \
+                                                                     ge::profiling::EventType::kEventEnd)
+#define PROFILING_SCOPE(element, event) ge::profiling::ScopeProfiler profiler((element), (event))
+#define PROFILING_SCOPE_CONST(element, event) const ge::profiling::ScopeProfiler profiler((element), (event))
+#define PROFILING_SCOPE_ELEMENT(element) profiler.SetElement((element))
+#endif  // AIR_CXX_PROFILING_DEFINITIONS_H
diff --git a/inc/air/framework/common/runtime_tensor_desc.h b/inc/air/framework/common/runtime_tensor_desc.h
new file mode 100644
index 0000000000000000000000000000000000000000..e44e5d87b568fc147a2e633cb07e8b12bbf85d20
--- /dev/null
+++ b/inc/air/framework/common/runtime_tensor_desc.h
@@ -0,0 +1,33 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_FRAMEWORK_COMMON_RUNTIME_TENSOR_DESC_H_
+#define INC_FRAMEWORK_COMMON_RUNTIME_TENSOR_DESC_H_
+
+#include <cstdint>
+
+namespace ge {
+constexpr int64_t kMaxDimSize = 32;  // number of dimension slots; the shape arrays below hold one extra leading entry
+
+#pragma pack(push, 1)
+struct RuntimeTensorDesc {  // byte-packed tensor descriptor; the members below add up to exactly 1024 bytes
+  uint64_t data_addr;
+  int64_t data_offset_size;
+  int64_t dtype;
+  int64_t shape[kMaxDimSize + 1];           // shape:Dim_Num|DIM0|DIM1|...|DIM31
+  int64_t original_shape[kMaxDimSize + 1];  // original_shape:Dim_Num|DIM0|DIM1|...|DIM31
+  int64_t format;
+  int64_t sub_format;
+  uint64_t data_size;
+  uint8_t reserved[448];  // padding to 1024 bytes (the preceding members occupy 576 bytes)
+};
+#pragma pack(pop)
+}  // namespace ge
+
+#endif  // INC_FRAMEWORK_COMMON_RUNTIME_TENSOR_DESC_H_
diff --git a/inc/air/framework/common/scope_guard.h b/inc/air/framework/common/scope_guard.h
new file mode 100644
index 0000000000000000000000000000000000000000..3ecd0e3c188e22880314e6a010507e81048e23e0
--- /dev/null
+++ b/inc/air/framework/common/scope_guard.h
@@ -0,0 +1,13 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_COMMON_SCOPE_GUARD_H_ +#define INC_FRAMEWORK_COMMON_SCOPE_GUARD_H_ +#include "common/ge_common/scope_guard.h" +#endif // INC_FRAMEWORK_COMMON_SCOPE_GUARD_H_ diff --git a/inc/air/framework/common/string_util.h b/inc/air/framework/common/string_util.h new file mode 100644 index 0000000000000000000000000000000000000000..a4175e56d14c891c18300873f6e06e4f2758f210 --- /dev/null +++ b/inc/air/framework/common/string_util.h @@ -0,0 +1,27 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/
+
+#ifndef INC_FRAMEWORK_COMMON_STRING_UTIL_H_
+#define INC_FRAMEWORK_COMMON_STRING_UTIL_H_
+#include "common/ge_common/string_util.h"
+namespace ge {
+template <typename Iterator>  // Iterator: dereferenced with *, advanced with ++ and std::next; elements must be streamable with <<
+static std::string StrJoin(Iterator begin, Iterator end, const std::string &separator) {  // joins [begin, end) into one string with separator between elements
+  if (begin == end) {
+    return "";  // empty range yields an empty string
+  }
+  std::stringstream str_stream;
+  str_stream << *begin;  // first element is written without a leading separator
+  for (Iterator it = std::next(begin); it != end; ++it) {
+    str_stream << separator << *it;
+  }
+  return str_stream.str();
+}
+}  // namespace ge
+#endif  // INC_FRAMEWORK_COMMON_STRING_UTIL_H_
diff --git a/inc/air/framework/common/taskdown_common.h b/inc/air/framework/common/taskdown_common.h
new file mode 100644
index 0000000000000000000000000000000000000000..346be128f6d885e0e40c87d1f639014e9ec528a0
--- /dev/null
+++ b/inc/air/framework/common/taskdown_common.h
@@ -0,0 +1,78 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_FRAMEWORK_COMMON_TASKDOWN_COMMON_H_
+#define INC_FRAMEWORK_COMMON_TASKDOWN_COMMON_H_
+
+#include "runtime/rt.h"
+#include "ge/framework/common/taskdown_common.h"
+
+namespace ge {
+
+constexpr int32_t CC_FUSION_OP_MAX = 32;  /**< capacity of tagOpContext::opIndex2 */
+
+enum class ccStatus_t : uint32_t {
+  CC_STATUS_SUCCESS = 0,         /**< succ */
+  CC_STATUS_NOT_INITIALIZED = 1, /**< not init */
+  CC_STATUS_ALLOC_FAILED = 2,    /**< alloc mem failed */
+  CC_STATUS_BAD_PARAM = 3,       /**< para check failed */
+  CC_STATUS_INTERNAL_ERROR = 4,  /**< internal error */
+  CC_STATUS_KERNEL_ERROR = 5,    /**< kernel error */
+  CC_STATUS_RUNTIME_ERROR = 6,   /**< runtime error */
+  CC_STATUS_NOT_SUPPORTED = 7,   /**< unsupport error */
+  CC_STATUS_INVALID_VALUE = 7,   /**< invalid value error for blas; NOTE(review): same value as CC_STATUS_NOT_SUPPORTED, so the two cannot be told apart - confirm this is intended */
+  CC_STATUS_RESERVED = 8,        /**< just for check */
+};
+
+using ccOpContext = struct tagOpContext {  /**< NOTE(review): field meanings are not documented in this header */
+  ccKernelType kernelType;  /**< ccKernelType is not declared in this header; presumably it comes from one of the includes above - confirm */
+  uint32_t opId;
+  uint32_t kernelFuncId;
+  uint32_t opIndex;
+  uint32_t opCount;
+  uint32_t opIndex2[CC_FUSION_OP_MAX];  /**< fixed array of CC_FUSION_OP_MAX (32) entries */
+  bool isFlowtable;
+  uint16_t *argsOffset;
+  uint32_t argsCount;
+  uint64_t genDataBaseAddr;
+  uint64_t genDataBaseSize;
+  uint64_t genWeightBaseAddr;
+  uint64_t genWeightBaseSize;
+  uint64_t genVariableBaseAddr;
+  uint64_t genVariableBaseSize;
+  uint64_t l2ctrlSize;
+};
+
+enum class tagOpTensorFormat : uint32_t {
+  OP_TENSOR_FORMAT_NC1HWC0 = 0,
+  OP_TENSOR_FORMAT_ND,       /**< implicit value 1 */
+  OP_TENSOR_FORMAT_RESERVED  /**< implicit value 2 */
+};
+
+enum class tagOpDataType : uint32_t {
+  OP_DATA_FLOAT = 0,            /**< float type */
+  OP_DATA_HALF,                 /**< fp16 type */
+  OP_DATA_INT8,                 /**< int8 type */
+  OP_DATA_INT32,                /**< int32 type */
+  OP_DATA_UINT8,                /**< uint8 type */
+  OP_DATA_HALF_UINT16_PROPOSAL, /**< mixed type for proposal */
+  OP_DATA_RESERVED              /**< implicit value 6 */
+};
+
+// AICPU Tensor
+using ccAICPUTensor = struct tagOpTensor {
+  // real dim info
+  tagOpTensorFormat format;
+  tagOpDataType data_type;
+  int32_t dim_cnt;  /**< presumably the number of valid entries in dim - TODO confirm */
+  int32_t mm;       /**< NOTE(review): purpose is not visible from this header */
+  int32_t dim[8];   /**< fixed array of 8 entries */
+};
+}  // namespace ge
+#endif  // INC_FRAMEWORK_COMMON_TASKDOWN_COMMON_H_
diff --git a/inc/air/framework/common/tlv/nano_dbg_desc.h b/inc/air/framework/common/tlv/nano_dbg_desc.h
new file mode 100644
index 0000000000000000000000000000000000000000..8e03509fe58e54087604dddc1c9a86c6123c4839
--- /dev/null
+++ b/inc/air/framework/common/tlv/nano_dbg_desc.h
@@ -0,0 +1,164 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_FRAMEWORK_COMMON_TLV_NANO_DBG_DESC_H_
+#define INC_FRAMEWORK_COMMON_TLV_NANO_DBG_DESC_H_
+
+#include "framework/common/tlv/tlv.h"
+
+#if defined(__cplusplus)
+extern "C" {
+#endif
+
+#pragma pack(push, 1)  // single-byte alignment; push saves the includer's packing so it can be restored below
+
+/************************** definition of dbg data desc *************************************/
+#define DBG_DATA_HEAD_MAGIC 0x5A5A5A5AU
+struct DbgDataHead {
+  uint32_t version_id;
+  uint32_t magic;  // NOTE(review): presumably expected to equal DBG_DATA_HEAD_MAGIC - confirm against the reader
+  uint8_t tlv[0];  // TlvHead; zero-length trailing array (compiler extension) marking where the payload starts
+};
+
+/********************************************************************************************/
+/************************** definition of dbg data Level 1 tlv ******************************/
+#define DBG_L1_TLV_TYPE_MODEL_NAME 0U  // uint8_t[]
+#define DBG_L1_TLV_TYPE_OP_DESC 1U
+
+struct DbgOpDescParamTlv1 {
+  uint32_t task_id;
+  uint32_t stream_id;
+  uint32_t logic_stream_id;
+  int32_t task_type;
+  uint32_t block_dim;
+  uint8_t datadump_is_multiop;
+  uint32_t l2_tlv_list_len;  // sits at byte offset 21 because of the 1-byte packing
+  uint8_t l2_tlv[0];  // TlvHead
+};
+
+struct DbgOpDescTlv1 {
+  uint32_t num;
+  uint8_t param[0];  // DbgOpDescParamTlv1
+};
+
+/********************************************************************************************/
+/************************** definition of dbg data desc Level 2 tlv *************************/
+#define DBG_L2_TLV_TYPE_OP_NAME 0U          // uint8_t[]
+#define DBG_L2_TLV_TYPE_OP_TYPE 1U          // uint8_t[]
+#define DBG_L2_TLV_TYPE_ORI_OP_NAME 2U      // uint8_t[]
+#define DBG_L2_TLV_TYPE_L1_SUB_GRAPH_NO 3U  // uint8_t[]
+#define DBG_L2_TLV_TYPE_INPUT_DESC 4U
+#define DBG_L2_TLV_TYPE_OUTPUT_DESC 5U
+#define DBG_L2_TLV_TYPE_WORKSPACE_DESC 6U
+#define DBG_L2_TLV_TYPE_OP_BUF 7U
+#define DBG_L2_TLV_TYPE_MEM_INFO 8U
+#define DBG_L2_TLV_TYPE_MAX 9U
+
+/************************** definition of dbg input tlv's value *****************************/
+#define DBG_INPUT_DESC_L3_TLV_TYPE_SHAPE_DIMS 0U
+#define DBG_INPUT_DESC_L3_TLV_TYPE_ORI_SHAPE_DIMS 1U
+#define DBG_INPUT_DESC_L3_TLV_TYPE_MAX 2U
+
+struct DbgInputDescParamTlv2 {
+  int32_t data_type;
+  int32_t format;
+  int32_t addr_type;
+  uint64_t addr;
+  uint64_t offset;
+  uint64_t size;
+  uint32_t l3_tlv_list_len;
+  uint8_t l3_tlv[0];  // TlvHead
+};
+
+struct DbgInputDescTlv2 {
+  uint32_t num;
+  uint8_t param[0];  // DbgInputDescParamTlv2
+};
+
+/********************************************************************************************/
+/************************** definition of dbg output tlv's value ****************************/
+#define DBG_OUTPUT_DESC_L3_TLV_TYPE_SHAPE_DIMS 0U
+#define DBG_OUTPUT_DESC_L3_TLV_TYPE_ORI_SHAPE_DIMS 1U
+#define DBG_OUTPUT_DESC_L3_TLV_TYPE_ORI_NAME 2U
+#define DBG_OUTPUT_DESC_L3_TLV_TYPE_MAX 3U
+
+struct DbgOutputDescParamTlv2 {
+  int32_t data_type;
+  int32_t format;
+  int32_t addr_type;
+  int32_t original_index;
+  int32_t original_data_type;
+  int32_t original_format;
+  uint64_t addr;
+  uint64_t offset;
+  uint64_t size;
+  uint32_t l3_tlv_list_len;
+  uint8_t l3_tlv[0];  // TlvHead
+};
+
+struct DbgOutputDescTlv2 {
+  uint32_t num;
+  uint8_t param[0];  // DbgOutputDescParamTlv2
+};
+
+/************************** definition of dbg workspace tlv's value *************************/
+struct DbgWorkspaceDescParamTlv2 {
+  int32_t type;
+  uint64_t data_addr;
+  uint64_t size;
+  uint32_t l3_tlv_list_len;
+  uint8_t l3_tlv[0];  // TlvHead
+};
+
+struct DbgWorkspaceDescTlv2 {
+  uint32_t num;
+  uint8_t param[0];  // DbgWorkspaceDescParamTlv2
+};
+
+/********************************************************************************************/
+/************************** definition of dbg op buf tlv's value ****************************/
+// L2 tlv op buffer list
+struct DbgOpBufParamTlv2 {
+  uint8_t type;
+  uint64_t addr;  // sits at byte offset 1 because of the 1-byte packing
+  uint64_t size;
+  uint32_t l3_tlv_list_len;
+  uint8_t l3_tlv[0];  // TlvHead
+};
+
+struct DbgOpBufTlv2 {
+  uint32_t num;
+  uint8_t param[0];  // DbgOpBufParamTlv2
+};
+
+/************************** definition of dbg mem info tlv's value **************************/
+// L2 tlv mem info list
+struct DbgOpMemInfoTlv2 {
+  uint64_t input_mem_size;
+  uint64_t output_mem_size;
+  uint64_t weight_mem_size;
+  uint64_t workspace_mem_size;
+  uint64_t total_mem_size;
+  uint32_t l3_tlv_list_len;
+  uint8_t l3_tlv[0];  // TlvHead
+};
+
+struct DbgOpMemInfosTlv2 {
+  uint32_t num;
+  uint8_t param[0];  // DbgOpMemInfoTlv2
+};
+
+/********************************************************************************************/
+#pragma pack(pop)  // restores the packing that was in effect before this header (a bare pack() would reset to the default)
+
+#if defined(__cplusplus)
+}
+#endif
+
+#endif  // INC_FRAMEWORK_COMMON_TLV_NANO_DBG_DESC_H_
diff --git a/inc/air/framework/common/tlv/pre_model_desc.h b/inc/air/framework/common/tlv/pre_model_desc.h
new file mode 100644
index 0000000000000000000000000000000000000000..f8ccab7623cd1b89e67de7c6c6633b9f9be48870
--- /dev/null
+++ b/inc/air/framework/common/tlv/pre_model_desc.h
@@ -0,0 +1,41 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_FRAMEWORK_COMMON_TLV_pre_model_desc_H_
+#define INC_FRAMEWORK_COMMON_TLV_pre_model_desc_H_
+
+namespace ge {
+#pragma pack(push, 1)  // single-byte alignment; push saves the includer's packing so it can be restored below
+
+enum KERNEL_ARG_UPADTE_TYPE {  // NOTE(review): "UPADTE" is spelled this way in the public name; kept as is for source compatibility
+  KERNEL_ARG_UPDATE_TYPE_ADDR,  // implicit value 0; the following enumerators count up by one
+  KERNEL_ARG_UPDATE_TYPE_TS,
+  KERNEL_ARG_UPDATE_TYPE_P2P,
+  KERNEL_ARG_UPDATE_TYPE_CPU_KERNEL_ARGS,
+  KERNEL_ARG_UPDATE_TYPE_SESSIONID,
+  KERNEL_ARG_UPDATE_TYPE_KERNELID,
+  KERNEL_ARG_UPDATE_TYPE_EVENTID,
+  KERNEL_ARG_UPDATE_TYPE_BUFF  // implicit value 7
+};
+enum KERNEL_ARG_UPADTE_ADDR_TYPE {
+  KERNEL_ARG_UPADTE_ADDR_TYPE_ARGS,  // implicit value 0; the following enumerators count up by one
+  KERNEL_ARG_UPADTE_ADDR_TYPE_WORKSPACE,
+  KERNEL_ARG_UPADTE_ADDR_TYPE_WEIGHT,
+  KERNEL_ARG_UPADTE_ADDR_TYPE_L1,
+  KERNEL_ARG_UPADTE_ADDR_TYPE_TS,
+  KERNEL_ARG_UPADTE_ADDR_TYPE_P2P,
+  KERNEL_ARG_UPADTE_ADDR_TYPE_VAR,
+  KERNEL_ARG_UPADTE_ADDR_TYPE_KERNEL_BIN,
+  KERNEL_ARG_UPADTE_ADDR_TYPE_BUFF  // implicit value 8
+};
+
+/********************************************************************************************/
+#pragma pack(pop)  // restores the packing that was in effect before this header
+}  // namespace ge
+#endif  // INC_FRAMEWORK_COMMON_TLV_pre_model_desc_H_
diff --git a/inc/air/framework/common/tlv/tlv.h b/inc/air/framework/common/tlv/tlv.h
new file mode 100644
index 0000000000000000000000000000000000000000..fa50df1659550e2ee636b892f8887b1cad7b8038
--- /dev/null
+++ b/inc/air/framework/common/tlv/tlv.h
@@ -0,0 +1,34 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_FRAMEWORK_COMMON_TLV_TLV_H_
+#define INC_FRAMEWORK_COMMON_TLV_TLV_H_
+
+#include <stdint.h>  // uint32_t / uint8_t; the C header is used because this file is also wrapped for C via extern "C"
+
+#if defined(__cplusplus)
+extern "C" {
+#endif
+
+#pragma pack(push, 1)  // single-byte alignment; push saves the includer's packing so it can be restored below
+
+// tlv struct
+struct TlvHead {
+  uint32_t type;
+  uint32_t len;     // NOTE(review): presumably the byte length of data - confirm against the writer
+  uint8_t data[0];  // zero-length trailing array (compiler extension) marking where the value bytes start
+};
+
+#pragma pack(pop)  // restores the packing that was in effect before this header
+
+#if defined(__cplusplus)
+}
+#endif
+
+#endif  // INC_FRAMEWORK_COMMON_TLV_TLV_H_
diff --git a/inc/air/framework/common/types.h b/inc/air/framework/common/types.h
new file mode 100644
index 0000000000000000000000000000000000000000..1ed3a1cdfd4878a3d4efb30c34b702a806c689be
--- /dev/null
+++ b/inc/air/framework/common/types.h
@@ -0,0 +1,822 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_COMMON_TYPES_H_ +#define INC_FRAMEWORK_COMMON_TYPES_H_ + +#include +#include +#include +#include + +#include "framework/common/fmk_error_codes.h" +#include "framework/common/fmk_types.h" +#include "framework/common/op_types.h" +#include "register/register_types.h" + +namespace ge { +// dump +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_MODEL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_ALL_MODEL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_LAYER_OP_MODEL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_WATCHER_MODEL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_STATUS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_LAYER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_FILE_PATH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_MODE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_AICORE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_ATOMIC; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_ALL; + +// Profile-related constants +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILE_MODEL_ID; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASKS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_BASE_ADDR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_HOST_BASE_ADDR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_HOST_SVM_BASE_ADDR; 
+FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_HOST_MEMORY_SIZE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_HOST_SVM_SIZE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_WEIGHT_ADDR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_FUSION_MODEL_DEF; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint64_t ALLOC_MEMORY_MAX_SIZE; // Max size of 8 GB. + +REGISTER_OPTYPE_DECLARE(DATA, "Data"); +REGISTER_OPTYPE_DECLARE(REFDATA, "RefData"); +REGISTER_OPTYPE_DECLARE(AIPPDATA, "AippData"); +REGISTER_OPTYPE_DECLARE(QUEUE_DATA, "QueueData"); +REGISTER_OPTYPE_DECLARE(CONVOLUTION, "Convolution"); +REGISTER_OPTYPE_DECLARE(CORRELATION, "Correlation"); +REGISTER_OPTYPE_DECLARE(CORRELATIONV2, "Correlation_V2"); +REGISTER_OPTYPE_DECLARE(DECONVOLUTION, "Deconvolution"); +REGISTER_OPTYPE_DECLARE(POOLING, "Pooling"); +REGISTER_OPTYPE_DECLARE(ELTWISE, "Eltwise"); +REGISTER_OPTYPE_DECLARE(RELU, "ReLU"); +REGISTER_OPTYPE_DECLARE(RELU6, "ReLU6"); +REGISTER_OPTYPE_DECLARE(SIGMOID, "Sigmoid"); +REGISTER_OPTYPE_DECLARE(ABSVAL, "AbsVal"); +REGISTER_OPTYPE_DECLARE(TANH, "TanH"); +REGISTER_OPTYPE_DECLARE(PRELU, "PReLU"); +REGISTER_OPTYPE_DECLARE(BATCHNORM, "BatchNorm"); +REGISTER_OPTYPE_DECLARE(FUSIONBATCHNORM, "FusionBatchNorm"); +REGISTER_OPTYPE_DECLARE(SCALE, "Scale"); +REGISTER_OPTYPE_DECLARE(FULL_CONNECTION, "FullConnection"); +REGISTER_OPTYPE_DECLARE(SOFTMAX, "Softmax"); +REGISTER_OPTYPE_DECLARE(PLUS, "Plus"); +REGISTER_OPTYPE_DECLARE(ACTIVATION, "Activation"); +REGISTER_OPTYPE_DECLARE(FLATTEN, "Flatten"); +REGISTER_OPTYPE_DECLARE(ADD, "Add"); +REGISTER_OPTYPE_DECLARE(SUB, "Sub"); +REGISTER_OPTYPE_DECLARE(MUL, "Mul"); +REGISTER_OPTYPE_DECLARE(MATMUL, "MatMul"); +REGISTER_OPTYPE_DECLARE(RSQRT, "Rsqrt"); +REGISTER_OPTYPE_DECLARE(BIASADD, "BiasAdd"); +REGISTER_OPTYPE_DECLARE(RESHAPE, "Reshape"); 
+REGISTER_OPTYPE_DECLARE(REFORMAT, "ReFormat"); +REGISTER_OPTYPE_DECLARE(DEPCONVOLUTION, "ConvolutionDepthwise"); +REGISTER_OPTYPE_DECLARE(DROPOUT, "Dropout"); +REGISTER_OPTYPE_DECLARE(DROPOUTDOMASK, "DropOutDoMask"); +REGISTER_OPTYPE_DECLARE(DROPOUTDOMASKV3, "DropOutDoMaskV3"); +REGISTER_OPTYPE_DECLARE(DROPOUTDOMASKV3D, "DropOutDoMaskV3D"); +REGISTER_OPTYPE_DECLARE(SOFTMAXV2WITHDROPOUTDOMASKV3D, "SoftmaxV2WithDropOutDoMaskV3D"); +REGISTER_OPTYPE_DECLARE(ATTENTIONSCORE, "AttentionScore"); +REGISTER_OPTYPE_DECLARE(ATTENTIONSCOREGRAD, "AttentionScoreGrad"); +REGISTER_OPTYPE_DECLARE(DROPOUTGENMASK, "DropOutGenMask"); +REGISTER_OPTYPE_DECLARE(DROPOUTGENMASKV3, "DropOutGenMaskV3"); +REGISTER_OPTYPE_DECLARE(AXPYWITHSOFTMAXANDDROPOUTDOMASK, "AxpyWithSoftmaxAndDropOutDoMask"); +REGISTER_OPTYPE_DECLARE(CONCAT, "Concat"); +REGISTER_OPTYPE_DECLARE(ROIPOOLING, "ROIPooling"); +REGISTER_OPTYPE_DECLARE(PROPOSAL, "Proposal"); +REGISTER_OPTYPE_DECLARE(FSRDETECTIONOUTPUT, "FSRDetectionOutput"); +REGISTER_OPTYPE_DECLARE(DETECTIONPOSTPROCESS, "Detectpostprocess"); +REGISTER_OPTYPE_DECLARE(LRN, "LRN"); +REGISTER_OPTYPE_DECLARE(TRANSDATA, "TransData"); +REGISTER_OPTYPE_DECLARE(PERMUTE, "Permute"); +REGISTER_OPTYPE_DECLARE(SSDNORMALIZE, "SSDNormalize"); +REGISTER_OPTYPE_DECLARE(SSDPRIORBOX, "SSDPriorBox"); +REGISTER_OPTYPE_DECLARE(NETOUTPUT, "NetOutput"); +REGISTER_OPTYPE_DECLARE(SSDDETECTIONOUTPUT, "SSDDetectionOutput"); +REGISTER_OPTYPE_DECLARE(REFINEDETDETECTIONOUTPUT, "RefinedetDetectionOutput"); +REGISTER_OPTYPE_DECLARE(CHANNELAXPY, "ChannelAxpy"); +REGISTER_OPTYPE_DECLARE(PSROIPOOLING, "PSROIPooling"); +REGISTER_OPTYPE_DECLARE(POWER, "Power"); +REGISTER_OPTYPE_DECLARE(POW, "Pow"); +REGISTER_OPTYPE_DECLARE(ROIALIGN, "ROIAlign"); +REGISTER_OPTYPE_DECLARE(PYTHON, "Python"); +REGISTER_OPTYPE_DECLARE(FREESPACEEXTRACT, "FreespaceExtract"); +REGISTER_OPTYPE_DECLARE(SPATIALTF, "SpatialTransform"); +REGISTER_OPTYPE_DECLARE(SHAPE, "Shape"); +REGISTER_OPTYPE_DECLARE(SHAPEN, "ShapeN"); 
+REGISTER_OPTYPE_DECLARE(ARGMAX, "ArgMax"); +REGISTER_OPTYPE_DECLARE(GATHERND, "GatherNd"); +REGISTER_OPTYPE_DECLARE(GATHER, "Gather"); +REGISTER_OPTYPE_DECLARE(REALDIV, "RealDiv"); +REGISTER_OPTYPE_DECLARE(PACK, "Pack"); +REGISTER_OPTYPE_DECLARE(SLICE, "Slice"); +REGISTER_OPTYPE_DECLARE(SLICEWITHAXES, "SliceWithAxes"); +REGISTER_OPTYPE_DECLARE(SLICED, "SliceD"); +REGISTER_OPTYPE_DECLARE(FLOORDIV, "FloorDiv"); +REGISTER_OPTYPE_DECLARE(SQUEEZE, "Squeeze"); +REGISTER_OPTYPE_DECLARE(UNSQUEEZE, "Unsqueeze"); +REGISTER_OPTYPE_DECLARE(SQUEEZEV2, "SqueezeV2"); +REGISTER_OPTYPE_DECLARE(UNSQUEEZEV2, "UnsqueezeV2"); +REGISTER_OPTYPE_DECLARE(SQUEEZEV3, "SqueezeV3"); +REGISTER_OPTYPE_DECLARE(UNSQUEEZEV3, "UnsqueezeV3"); +REGISTER_OPTYPE_DECLARE(STRIDEDSLICE, "StridedSlice"); +REGISTER_OPTYPE_DECLARE(RANGE, "Range"); +REGISTER_OPTYPE_DECLARE(RPNPROPOSALS, "GenerateRpnProposals"); +REGISTER_OPTYPE_DECLARE(DECODEBBOX, "DecodeBBox"); +REGISTER_OPTYPE_DECLARE(PAD, "Pad"); +REGISTER_OPTYPE_DECLARE(PADV2, "PadV2"); +REGISTER_OPTYPE_DECLARE(MIRRORPAD, "MirrorPad"); +REGISTER_OPTYPE_DECLARE(TILE, "Tile"); +REGISTER_OPTYPE_DECLARE(SIZE, "Size"); +REGISTER_OPTYPE_DECLARE(CLIPBOXES, "Clipboxes"); +REGISTER_OPTYPE_DECLARE(FASTRCNNPREDICTIONS, "FastrcnnPredictions"); +REGISTER_OPTYPE_DECLARE(SPLIT, "Split"); +REGISTER_OPTYPE_DECLARE(SPLITV, "SplitV"); +REGISTER_OPTYPE_DECLARE(EXPANDDIMS, "ExpandDims"); +REGISTER_OPTYPE_DECLARE(EMPTY, "Empty"); +REGISTER_OPTYPE_DECLARE(MEAN, "Mean"); +REGISTER_OPTYPE_DECLARE(GREATER, "Greater"); +REGISTER_OPTYPE_DECLARE(SWITCH, "Switch"); +REGISTER_OPTYPE_DECLARE(SWITCHN, "SwitchN"); +REGISTER_OPTYPE_DECLARE(REFSWITCH, "RefSwitch"); +REGISTER_OPTYPE_DECLARE(MERGE, "Merge"); +REGISTER_OPTYPE_DECLARE(REFMERGE, "RefMerge"); +REGISTER_OPTYPE_DECLARE(ENTER, "Enter"); +REGISTER_OPTYPE_DECLARE(REFENTER, "RefEnter"); +REGISTER_OPTYPE_DECLARE(LOOPCOND, "LoopCond"); +REGISTER_OPTYPE_DECLARE(NEXTITERATION, "NextIteration"); +REGISTER_OPTYPE_DECLARE(REFNEXTITERATION, 
"RefNextIteration"); +REGISTER_OPTYPE_DECLARE(EXIT, "Exit"); +REGISTER_OPTYPE_DECLARE(REFEXIT, "RefExit"); +REGISTER_OPTYPE_DECLARE(CONTROLTRIGGER, "ControlTrigger"); +REGISTER_OPTYPE_DECLARE(SYMBOLICGRADIENT, "SymbolicGradient"); +REGISTER_OPTYPE_DECLARE(REMOTECALL, "RemoteCall"); +REGISTER_OPTYPE_DECLARE(_IF, "_If"); +REGISTER_OPTYPE_DECLARE(STATELESSIF, "StatelessIf"); +REGISTER_OPTYPE_DECLARE(IF, "If"); +REGISTER_OPTYPE_DECLARE(CASE, "Case"); +REGISTER_OPTYPE_DECLARE(STATELESSCASE, "StatelessCase"); +REGISTER_OPTYPE_DECLARE(_WHILE, "_While"); +REGISTER_OPTYPE_DECLARE(WHILE, "While"); +REGISTER_OPTYPE_DECLARE(STATELESSWHILE, "StatelessWhile"); +REGISTER_OPTYPE_DECLARE(FOR, "For"); +REGISTER_OPTYPE_DECLARE(PARTITIONEDCALL, "PartitionedCall"); +REGISTER_OPTYPE_DECLARE(STATEFULPARTITIONEDCALL, "StatefulPartitionedCall"); +REGISTER_OPTYPE_DECLARE(FAKEPARAM, "FakeParam"); +REGISTER_OPTYPE_DECLARE(TRANSPOSE, "Transpose"); +REGISTER_OPTYPE_DECLARE(TRANSPOSED, "TransposeD"); +REGISTER_OPTYPE_DECLARE(CAST, "Cast"); +REGISTER_OPTYPE_DECLARE(REGION, "Region"); +REGISTER_OPTYPE_DECLARE(YOLO, "Yolo"); +REGISTER_OPTYPE_DECLARE(YOLODETECTIONOUTPUT, "YoloDetectionOutput"); +REGISTER_OPTYPE_DECLARE(FILL, "Fill"); +REGISTER_OPTYPE_DECLARE(RANK, "Rank"); +REGISTER_OPTYPE_DECLARE(REVERSE, "Reverse"); +REGISTER_OPTYPE_DECLARE(UNPACK, "Unpack"); +REGISTER_OPTYPE_DECLARE(YOLO2REORG, "Yolo2Reorg"); +REGISTER_OPTYPE_DECLARE(REDUCESUM, "ReduceSum"); +REGISTER_OPTYPE_DECLARE(SUM, "Sum"); +REGISTER_OPTYPE_DECLARE(CONSTANT, "Const"); +REGISTER_OPTYPE_DECLARE(RESIZEBILINEAR, "ResizeBilinear"); +REGISTER_OPTYPE_DECLARE(RESIZEBILINEARGRAD, "ResizeBilinearGrad"); +REGISTER_OPTYPE_DECLARE(MAXIMUM, "Maximum"); +REGISTER_OPTYPE_DECLARE(FRAMEWORKOP, "FrameworkOp"); +REGISTER_OPTYPE_DECLARE(ARG, "_Arg"); +REGISTER_OPTYPE_DECLARE(FUSEDBATCHNORMGRAD, "FusedBatchNormGrad"); +REGISTER_OPTYPE_DECLARE(LSTM, "LSTM"); +REGISTER_OPTYPE_DECLARE(HIGHWAY, "HighWay"); +REGISTER_OPTYPE_DECLARE(RNN, "RNN"); 
+REGISTER_OPTYPE_DECLARE(ATTENTIONDECODER, "AttentionDecoder"); +REGISTER_OPTYPE_DECLARE(LOGICAL_NOT, "LogicalNot"); +REGISTER_OPTYPE_DECLARE(LOGICAL_AND, "LogicalAnd"); +REGISTER_OPTYPE_DECLARE(LOGICAL_OR, "LogicalOr"); +REGISTER_OPTYPE_DECLARE(EQUAL, "Equal"); +REGISTER_OPTYPE_DECLARE(NOTEQUAL, "NotEqual"); +REGISTER_OPTYPE_DECLARE(INTERP, "Interp"); +REGISTER_OPTYPE_DECLARE(SHUFFLECHANNEL, "ShuffleChannel"); +REGISTER_OPTYPE_DECLARE(AIPP, "Aipp"); +REGISTER_OPTYPE_DECLARE(MULTISHAPE, "MultiShape"); +REGISTER_OPTYPE_DECLARE(RECIPROCAL, "Reciprocal"); +REGISTER_OPTYPE_DECLARE(SELU, "Selu"); +REGISTER_OPTYPE_DECLARE(ELU, "Elu"); +REGISTER_OPTYPE_DECLARE(ACOSH, "Acosh"); +REGISTER_OPTYPE_DECLARE(ASINH, "Asinh"); +REGISTER_OPTYPE_DECLARE(MINIMUM, "Minimum"); +REGISTER_OPTYPE_DECLARE(CLIP, "Clip"); +REGISTER_OPTYPE_DECLARE(L2NORMALIZE, "L2Normalize"); +REGISTER_OPTYPE_DECLARE(CROPANDRESIZE, "CropAndResize"); +REGISTER_OPTYPE_DECLARE(UNUSEDCONST, "UnusedConst"); +REGISTER_OPTYPE_DECLARE(SPARSETODENSE, "SparseToDense"); +REGISTER_OPTYPE_DECLARE(NONMAXSUPPRESSION, "NonMaxSuppression"); +REGISTER_OPTYPE_DECLARE(TOPKV2, "TopKV2"); +REGISTER_OPTYPE_DECLARE(INVERTPERMUTATION, "InvertPermutation"); +REGISTER_OPTYPE_DECLARE(MULTINOMIAL, "Multinomial"); +REGISTER_OPTYPE_DECLARE(REVERSESEQUENCE, "ReverseSequence"); +REGISTER_OPTYPE_DECLARE(REDUCEPROD, "ReduceProd"); +REGISTER_OPTYPE_DECLARE(REDUCEMAX, "ReduceMax"); +REGISTER_OPTYPE_DECLARE(REDUCEMIN, "ReduceMin"); +REGISTER_OPTYPE_DECLARE(EXTRACTIMAGEPATCHES, "ExtractImagePatches"); +REGISTER_OPTYPE_DECLARE(SQRT, "Sqrt"); +REGISTER_OPTYPE_DECLARE(REDUCEALL, "ReduceAll"); +REGISTER_OPTYPE_DECLARE(RESIZENEARESTNEIGHBOR, "ResizeNearestNeighbor"); +REGISTER_OPTYPE_DECLARE(SPACETOBATCHND, "SpaceToBatchND"); +REGISTER_OPTYPE_DECLARE(BATCHTOSPACEND, "BatchToSpaceND"); +REGISTER_OPTYPE_DECLARE(ASSERT, "Assert"); +REGISTER_OPTYPE_DECLARE(GREATEREQUAL, "GreaterEqual"); +REGISTER_OPTYPE_DECLARE(FLOOR, "Floor"); 
+REGISTER_OPTYPE_DECLARE(RANDOMUNIFORM, "RandomUniform"); +REGISTER_OPTYPE_DECLARE(BATCHMATMUL, "BatchMatMul"); +REGISTER_OPTYPE_DECLARE(LESSEQUAL, "LessEqual"); +REGISTER_OPTYPE_DECLARE(ONEHOT, "OneHot"); +REGISTER_OPTYPE_DECLARE(LAYERNORM, "LayerNorm"); +REGISTER_OPTYPE_DECLARE(SPACETODEPTH, "SpaceToDepth"); +REGISTER_OPTYPE_DECLARE(DEPTHTOSPACE, "DepthToSpace"); +REGISTER_OPTYPE_DECLARE(RINT, "Rint"); +REGISTER_OPTYPE_DECLARE(ATAN, "Atan"); +REGISTER_OPTYPE_DECLARE(ATAN2, "Atan2"); +REGISTER_OPTYPE_DECLARE(ATANH, "Atanh"); +REGISTER_OPTYPE_DECLARE(ACOS, "Acos"); +REGISTER_OPTYPE_DECLARE(ASIN, "Asin"); +REGISTER_OPTYPE_DECLARE(NEG, "Neg"); +REGISTER_OPTYPE_DECLARE(LOG, "Log"); +REGISTER_OPTYPE_DECLARE(TAN, "Tan"); +REGISTER_OPTYPE_DECLARE(ROUND, "Round"); +REGISTER_OPTYPE_DECLARE(UPSAMPLE, "Upsample"); +REGISTER_OPTYPE_DECLARE(FLOORMOD, "FloorMod"); +REGISTER_OPTYPE_DECLARE(LESS, "Less"); +REGISTER_OPTYPE_DECLARE(ZEROSLIKE, "ZerosLike"); +REGISTER_OPTYPE_DECLARE(EXP, "Exp"); +REGISTER_OPTYPE_DECLARE(WHERE, "Where"); +REGISTER_OPTYPE_DECLARE(FAKEQUANTWITHMINMAXVARS, "FakeQuantWithMinMaxVars"); +REGISTER_OPTYPE_DECLARE(SOFTPLUS, "Softplus"); +REGISTER_OPTYPE_DECLARE(SOFTSIGN, "Softsign"); +REGISTER_OPTYPE_DECLARE(COSH, "Cosh"); +REGISTER_OPTYPE_DECLARE(SINH, "Sinh"); +REGISTER_OPTYPE_DECLARE(RETINAMULTIANCHORS, "RetinaMultiAnchor"); +REGISTER_OPTYPE_DECLARE(SQUAREDDIFFERENCE, "SquaredDifference"); +REGISTER_OPTYPE_DECLARE(REQUIREDSPACETOBATCHPADDINGS, "RequiredSpaceToBatchPaddings"); // for retinanet scope fusion +REGISTER_OPTYPE_DECLARE(SSDPOSTPROCESSOR, "SSDPostProcessor"); +REGISTER_OPTYPE_DECLARE(SSDANCHORGENERATOR, "SSDAnchorGenerator"); +REGISTER_OPTYPE_DECLARE(RETINANETBOXES, "RetinanetBoxes"); +REGISTER_OPTYPE_DECLARE(RETINANETCLIPPEDBOXES, "RetinanetClippedBoxes"); +REGISTER_OPTYPE_DECLARE(RETINANETFILTEREDDETECTIONS, "RetinanetFilteredDetections"); +REGISTER_OPTYPE_DECLARE(RETINANETPOSTPROCESSOR, "RetinanetPostProcessor"); 
+REGISTER_OPTYPE_DECLARE(RETINANETANCHORS, "RetinanetAnchors"); +REGISTER_OPTYPE_DECLARE(FASTERRCNNMAP, "FasterRCNNMap"); +REGISTER_OPTYPE_DECLARE(FASTERRCNNMAP1, "FasterRCNNMap1"); +REGISTER_OPTYPE_DECLARE(FASTERRCNNSECONDSTAGEPOSTPROCESSOR, "FasterRCNNSecondStagePostprocessor"); +REGISTER_OPTYPE_DECLARE(FASTERRCNNROIINTERPOOLING, "FasterRCNNROIInterPooling"); +REGISTER_OPTYPE_DECLARE(FASTERRCNNFIRSTSTAGEPOSTPROCESSOR, "FasterRCNNFirstStagePostprocessor"); +REGISTER_OPTYPE_DECLARE(FASTERRCNNGRIDANCHORGENERATOR, "FasterRCNNGridAnchorGenerator"); +REGISTER_OPTYPE_DECLARE(ROIINTERPOOLING, "ROIInterPooling"); +REGISTER_OPTYPE_DECLARE(FASTERRCNNCLIPTOWINDOW, "FasterRCNNClipToWindow"); +REGISTER_OPTYPE_DECLARE(EMBEDLOOKUP, "EmbedLookup"); +REGISTER_OPTYPE_DECLARE(HASHLOOKUP, "HashLookup"); +REGISTER_OPTYPE_DECLARE(LSH_PROJ, "LshProject"); +REGISTER_OPTYPE_DECLARE(SVDF, "SVDF"); +REGISTER_OPTYPE_DECLARE(IDENTITY, "Identity"); +REGISTER_OPTYPE_DECLARE(PLACEHOLDERWITHDEFAULT, "PlaceholderWithDefault"); +REGISTER_OPTYPE_DECLARE(IDENTITYN, "IdentityN"); +REGISTER_OPTYPE_DECLARE(GETSPAN, "GetSpan"); +REGISTER_OPTYPE_DECLARE(STOPGRADIENT, "StopGradient"); +REGISTER_OPTYPE_DECLARE(PREVENTGRADIENT, "PreventGradient"); +REGISTER_OPTYPE_DECLARE(GUARANTEECONST, "GuaranteeConst"); +REGISTER_OPTYPE_DECLARE(BROADCASTGRADIENTARGS, "BroadcastGradientArgs"); +REGISTER_OPTYPE_DECLARE(BROADCASTARGS, "BroadcastArgs"); +REGISTER_OPTYPE_DECLARE(CONCATV2, "ConcatV2"); +REGISTER_OPTYPE_DECLARE(CONCATOFFSET, "ConcatOffset"); +REGISTER_OPTYPE_DECLARE(LESSEQUAL, "LessEqual"); +REGISTER_OPTYPE_DECLARE(SELECT, "Select"); +REGISTER_OPTYPE_DECLARE(CONFUSIONMATRIX, "ConfusionMatrix"); +REGISTER_OPTYPE_DECLARE(PLACEHOLDER, "PlaceHolder"); +REGISTER_OPTYPE_DECLARE(END, "End"); +REGISTER_OPTYPE_DECLARE(BASICLSTMCELL, "BasicLSTMCell"); +REGISTER_OPTYPE_DECLARE(GETNEXT, "GetNext"); +REGISTER_OPTYPE_DECLARE(DYNAMICGETNEXT, "DynamicGetNext"); +REGISTER_OPTYPE_DECLARE(DYNAMICGETNEXTV2, "DynamicGetNextV2"); 
+REGISTER_OPTYPE_DECLARE(ITERATOR, "Iterator"); +REGISTER_OPTYPE_DECLARE(ITERATORV2, "IteratorV2"); +REGISTER_OPTYPE_DECLARE(INITDATA, "InitData"); +REGISTER_OPTYPE_DECLARE(TRANSSHAPE, "TransShape"); +REGISTER_OPTYPE_DECLARE(REFIDENTITY, "RefIdentity"); +REGISTER_OPTYPE_DECLARE(BITCAST, "Bitcast"); +REGISTER_OPTYPE_DECLARE(GATHERSHAPES, "GatherShapes"); +REGISTER_OPTYPE_DECLARE(FLATTENV2, "FlattenV2"); +REGISTER_OPTYPE_DECLARE(FILECONSTANT, "FileConstant"); +REGISTER_OPTYPE_DECLARE(CONSTPLACEHOLDER, "ConstPlaceHolder"); + +// ANN dedicated operator +REGISTER_OPTYPE_DECLARE(ANN_MEAN, "AnnMean"); +REGISTER_OPTYPE_DECLARE(ANN_CONVOLUTION, "AnnConvolution"); +REGISTER_OPTYPE_DECLARE(ANN_DEPCONVOLUTION, "AnnDepthConv"); +REGISTER_OPTYPE_DECLARE(ANN_FULLCONNECTION, "AnnFullConnection"); +REGISTER_OPTYPE_DECLARE(ANN_NETOUTPUT, "AnnNetOutput"); +REGISTER_OPTYPE_DECLARE(ANN_DATA, "AnnData"); +REGISTER_OPTYPE_DECLARE(ANN_RESHAPE, "AnnReshape"); +REGISTER_OPTYPE_DECLARE(ANN_ADD, "AnnAdd"); +REGISTER_OPTYPE_DECLARE(ANN_MUL, "AnnMul"); +REGISTER_OPTYPE_DECLARE(ANN_SUB, "AnnSub"); +REGISTER_OPTYPE_DECLARE(ANN_DIV, "AnnDiv"); +REGISTER_OPTYPE_DECLARE(ANN_DEQUANTIZE, "AnnDequant"); +REGISTER_OPTYPE_DECLARE(ANN_QUANTIZE, "AnnQuant"); +REGISTER_OPTYPE_DECLARE(ANN_PAD, "AnnPad"); +REGISTER_OPTYPE_DECLARE(ANN_RESIZE_BILINEAR, "AnnResizeBilinear"); + +// Training operator +REGISTER_OPTYPE_DECLARE(GATHERV2, "GatherV2"); +REGISTER_OPTYPE_DECLARE(CONVGRADFILTER, "Conv2DBackpropFilter"); +REGISTER_OPTYPE_DECLARE(CONV2D, "Conv2D"); +REGISTER_OPTYPE_DECLARE(CONV2DBACKPROPINPUT, "Conv2DBackpropInput"); +REGISTER_OPTYPE_DECLARE(FUSEDBATCHNORM, "FusedBatchNorm"); +REGISTER_OPTYPE_DECLARE(BIASADDGRAD, "BiasAddGrad"); +REGISTER_OPTYPE_DECLARE(ACTIVATIONGRAD, "ReluGrad"); +REGISTER_OPTYPE_DECLARE(MAXPOOLWITHARGMAX, "MaxPoolWithArgmax"); +REGISTER_OPTYPE_DECLARE(MAXPOOLGRADWITHARGMAX, "MaxPoolGradWithArgmax"); +REGISTER_OPTYPE_DECLARE(SPARSESOFTMAXCROSSENTROPYWITHLOGITS, 
"SparseSoftmaxCrossEntropyWithLogits"); +REGISTER_OPTYPE_DECLARE(SNAPSHOT, "Snapshot"); +REGISTER_OPTYPE_DECLARE(LAYERNORM, "LayerNorm"); +REGISTER_OPTYPE_DECLARE(HUBERLOSSGRAD, "HuberLossGrad"); +REGISTER_OPTYPE_DECLARE(HUBERLOSS, "HuberLoss"); +REGISTER_OPTYPE_DECLARE(NEGATIVE, "Negative"); +REGISTER_OPTYPE_DECLARE(SSDCAST, "SSDCast"); +REGISTER_OPTYPE_DECLARE(SSDSQUEEZEFUSION, "SsdSqueezeFusion"); +REGISTER_OPTYPE_DECLARE(SPARSESOFTMAXCROSSENTROPY, "SsdSparseSoftmaxCrossEntropy"); +REGISTER_OPTYPE_DECLARE(SPARSESOFTMAXCROSSENTROPYGRAD, "SsdSparseSoftmaxCrossEntropyGrad"); +REGISTER_OPTYPE_DECLARE(CONCATFIVE2FOUR, "ConcatFive2Four"); +REGISTER_OPTYPE_DECLARE(CONCATFOUR2FIVE, "ConcatFour2Five"); +REGISTER_OPTYPE_DECLARE(SSDREALDIVTILEMUL, "SSDRealdivTileMul"); +REGISTER_OPTYPE_DECLARE(SSDSUMMULREALDIVMEAN, "SSDSumMulRealdivMean"); + +REGISTER_OPTYPE_DECLARE(MEANGRAD, "MeanGrad"); +REGISTER_OPTYPE_DECLARE(TRANSLATE, "Translate"); +REGISTER_OPTYPE_DECLARE(ADDN, "AddN"); +REGISTER_OPTYPE_DECLARE(L2LOSS, "L2Loss"); +REGISTER_OPTYPE_DECLARE(MULTIPLY, "Multiply"); +REGISTER_OPTYPE_DECLARE(RELU6GRAD, "Relu6Grad"); +REGISTER_OPTYPE_DECLARE(AVGPOOLGRAD, "AvgPoolGrad"); +REGISTER_OPTYPE_DECLARE(DEPTHWISECONV2DBACKPROPFILTER, "DepthwiseConv2dNativeBackpropFilter"); +REGISTER_OPTYPE_DECLARE(DEPTHWISECONV2DBACKPORPINPUT, "DepthwiseConv2dNativeBackpropInput"); +REGISTER_OPTYPE_DECLARE(DEPTHWISECONV2DFORWARDNATIVE, "DepthwiseConv2dNative"); +REGISTER_OPTYPE_DECLARE(DROPOUTGRAD, "DropOutGrad"); +REGISTER_OPTYPE_DECLARE(APPLYRMSPROPMIXEDPRECISION, "apply_rms_prop_mixed_precision"); +REGISTER_OPTYPE_DECLARE(APPLYRMSPROP, "ApplyRMSProp"); +REGISTER_OPTYPE_DECLARE(LARS, "Lars"); +REGISTER_OPTYPE_DECLARE(DYNAMICSTITCH, "DynamicStitch"); + +// Variable sink related +REGISTER_OPTYPE_DECLARE(VARIABLEV2, "VariableV2"); +REGISTER_OPTYPE_DECLARE(VARHANDLEOP, "VarHandleOp"); +REGISTER_OPTYPE_DECLARE(TEMPORARYVARIABLE, "TemporaryVariable"); +REGISTER_OPTYPE_DECLARE(DESTROYTEMPORARYVARIABLE, 
"DestroyTemporaryVariable"); +REGISTER_OPTYPE_DECLARE(VARIABLE, "Variable"); + +REGISTER_OPTYPE_DECLARE(READVARIABLEOP, "ReadVariableOp"); + +REGISTER_OPTYPE_DECLARE(VARISINITIALIZEDOP, "VarIsInitializedOp"); +REGISTER_OPTYPE_DECLARE(ISVARIABLEINITIALIZED, "IsVariableInitialized"); + +REGISTER_OPTYPE_DECLARE(ASSIGN, "Assign"); +REGISTER_OPTYPE_DECLARE(ASSIGNVARIABLEOP, "AssignVariableOp"); + +REGISTER_OPTYPE_DECLARE(ASSIGNADD, "AssignAdd"); +REGISTER_OPTYPE_DECLARE(ASSIGNADDVARIABLEOP, "AssignAddVariableOp"); + +REGISTER_OPTYPE_DECLARE(ASSIGNSUB, "AssignSub"); +REGISTER_OPTYPE_DECLARE(ASSIGNSUBVARIABLEOP, "AssignSubVariableOp"); + +REGISTER_OPTYPE_DECLARE(APPLYMOMENTUM, "ApplyMomentum"); +REGISTER_OPTYPE_DECLARE(RESOURCEAPPLYMOMENTUM, "ResourceApplyMomentum"); +REGISTER_OPTYPE_DECLARE(SGD, "SGD"); +REGISTER_OPTYPE_DECLARE(NOOP, "NoOp"); +REGISTER_OPTYPE_DECLARE(LAYERNORMGRAD, "LayerNormGrad"); + +REGISTER_OPTYPE_DECLARE(SQUARE, "Square"); +REGISTER_OPTYPE_DECLARE(HCOMBROADCAST, "HcomBroadcast"); +REGISTER_OPTYPE_DECLARE(HCOMALLGATHER, "HcomAllGather"); +REGISTER_OPTYPE_DECLARE(HCOMALLGATHERV, "HcomAllGatherV"); +REGISTER_OPTYPE_DECLARE(HCOMALLREDUCE, "HcomAllReduce"); +REGISTER_OPTYPE_DECLARE(HCOMREDUCESCATTER, "HcomReduceScatter"); +REGISTER_OPTYPE_DECLARE(HCOMREDUCESCATTERV, "HcomReduceScatterV"); +REGISTER_OPTYPE_DECLARE(HCOMREDUCE, "HcomReduce"); +REGISTER_OPTYPE_DECLARE(HCOMSEND, "HcomSend"); +REGISTER_OPTYPE_DECLARE(HCOMRECEIVE, "HcomReceive"); +REGISTER_OPTYPE_DECLARE(HCOMREMOTEREAD, "HcomRemoteRead"); +REGISTER_OPTYPE_DECLARE(HCOMREMOTEREFREAD, "HcomRemoteRefRead"); +REGISTER_OPTYPE_DECLARE(HCOMREMOTEWRITE, "HcomRemoteWrite"); +REGISTER_OPTYPE_DECLARE(HCOMREMOTESCATTERWRITE, "HcomRemoteScatterWrite"); +REGISTER_OPTYPE_DECLARE(HCOMALLTOALLV, "HcomAllToAllV"); +REGISTER_OPTYPE_DECLARE(HCOMGATHERALLTOALLV, "HcomGatherAllToAllV"); +REGISTER_OPTYPE_DECLARE(HCOMALLTOALLVC, "HcomAllToAllVC"); +REGISTER_OPTYPE_DECLARE(HCOMALLTOALL, "HcomAllToAll"); 
+REGISTER_OPTYPE_DECLARE(HCOMREMOTELOOKUP, "HcomRemoteLookup"); +REGISTER_OPTYPE_DECLARE(HCOMCOLLREMOTELOOKUP, "HcomCollRemoteLookup"); +REGISTER_OPTYPE_DECLARE(HCOMCOLLREMOTEUPDATE, "HcomCollRemoteUpdate"); +REGISTER_OPTYPE_DECLARE(HCOMGATHER, "HcomGather"); +REGISTER_OPTYPE_DECLARE(HCOMCOLLREMOTELOOKUPPAIRED, "HcomCollRemoteLookupPaired"); +REGISTER_OPTYPE_DECLARE(HCOMCOLLREMOTEUPDATEPAIRED, "HcomCollRemoteUpdatePaired"); +REGISTER_OPTYPE_DECLARE(HCOMCOLLREMOTELOOKUPUNIQUEDANDPAIRED, "HcomCollRemoteLookupUniquedAndPaired"); + +REGISTER_OPTYPE_DECLARE(VARASSIGN, "VarAssign"); +REGISTER_OPTYPE_DECLARE(VARISINITIALIZEDOP, "VarIsInitializedOp"); +REGISTER_OPTYPE_DECLARE(LogTimeStamp, "LogTimeStamp"); +REGISTER_OPTYPE_DECLARE(PARALLELCONCATSTART, "_ParallelConcatStart"); +REGISTER_OPTYPE_DECLARE(CONSTANTOP, "Constant"); +REGISTER_OPTYPE_DECLARE(STREAMSWITCH, "StreamSwitch"); +REGISTER_OPTYPE_DECLARE(STREAMSWITCHN, "StreamSwitchN"); +REGISTER_OPTYPE_DECLARE(STREAMACTIVE, "StreamActive"); +REGISTER_OPTYPE_DECLARE(MEMCPYASYNC, "MemcpyAsync"); +REGISTER_OPTYPE_DECLARE(MEMCPYADDRASYNC, "MemcpyAddrAsync"); +REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge"); +REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); +REGISTER_OPTYPE_DECLARE(MODELEXIT, "ModelExit"); +REGISTER_OPTYPE_DECLARE(SEND, "Send"); +REGISTER_OPTYPE_DECLARE(SENDNOTIFY, "SendNotify"); +REGISTER_OPTYPE_DECLARE(RECV, "Recv"); +REGISTER_OPTYPE_DECLARE(RECVNOTIFY, "RecvNotify"); +REGISTER_OPTYPE_DECLARE(ENDOFSEQUENCE, "EndOfSequence"); +REGISTER_OPTYPE_DECLARE(STARTOFSEQUENCE, "StartOfSequence"); +REGISTER_OPTYPE_DECLARE(NPUGETFLOATSTATUS, "NPUGetFloatStatus"); +REGISTER_OPTYPE_DECLARE(NPUCLEARFLOATSTATUS, "NPUClearFloatStatus"); +REGISTER_OPTYPE_DECLARE(NPUGETFLOATSTATUSV2, "NPUGetFloatStatusV2"); +REGISTER_OPTYPE_DECLARE(NPUCLEARFLOATSTATUSV2, "NPUClearFloatStatusV2"); +REGISTER_OPTYPE_DECLARE(NPUGETFLOATDEBUGSTATUS, "NPUGFloatDebugStatus"); +REGISTER_OPTYPE_DECLARE(NPUCLEARFLOATDEBUGSTATUS, 
"NPUClearFloatDebugStatus"); + +REGISTER_OPTYPE_DECLARE(LABELSET, "LabelSet"); +REGISTER_OPTYPE_DECLARE(LABELGOTO, "LabelGoto"); +REGISTER_OPTYPE_DECLARE(LABELGOTOEX, "LabelGotoEx"); +REGISTER_OPTYPE_DECLARE(LABELSWITCH, "LabelSwitch"); +REGISTER_OPTYPE_DECLARE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); + +REGISTER_OPTYPE_DECLARE(ATOMICADDRCLEAN, "AtomicAddrClean"); +REGISTER_OPTYPE_DECLARE(MEMSET, "MemSet"); + +REGISTER_OPTYPE_DECLARE(ABS_GRAD, "AbsGrad"); +REGISTER_OPTYPE_DECLARE(ACCUMULATE_N_V2, "AccumulateNV2"); +REGISTER_OPTYPE_DECLARE(ACOS_GRAD, "AcosGrad"); +REGISTER_OPTYPE_DECLARE(ACOSH_GRAD, "AcoshGrad"); +REGISTER_OPTYPE_DECLARE(ANY, "Any"); +REGISTER_OPTYPE_DECLARE(APPROXIMATE_EQUAL, "ApproximateEqual"); +REGISTER_OPTYPE_DECLARE(ASIN_GRAD, "AsinGrad"); +REGISTER_OPTYPE_DECLARE(ASINH_GRAD, "AsinhGrad"); +REGISTER_OPTYPE_DECLARE(ATAN_GRAD, "AtanGrad"); +REGISTER_OPTYPE_DECLARE(BROADCAST_TO, "BroadcastTo"); +REGISTER_OPTYPE_DECLARE(ELU_GRAD, "EluGrad"); +REGISTER_OPTYPE_DECLARE(ADD_V2, "AddV2"); +REGISTER_OPTYPE_DECLARE(DATAFORMATDIMMAP, "DataFormatDimMap"); +REGISTER_OPTYPE_DECLARE(DATAFORMATVECPERMUTE, "DataFormatVecPermute"); +REGISTER_OPTYPE_DECLARE(DEQUANTIZE, "Dequantize"); +REGISTER_OPTYPE_DECLARE(APPLYADADELTA, "ApplyAdadelta"); +REGISTER_OPTYPE_DECLARE(APPLYADAGRAD, "ApplyAdagrad"); +REGISTER_OPTYPE_DECLARE(APPLYADAGRADDA, "ApplyAdagradDA"); +REGISTER_OPTYPE_DECLARE(APPLYADAM, "ApplyAdam"); +REGISTER_OPTYPE_DECLARE(APPLYADAMAX, "ApplyAdaMax"); +REGISTER_OPTYPE_DECLARE(APPLYADDSIGN, "ApplyAddSign"); +REGISTER_OPTYPE_DECLARE(APPLYCENTEREDRMSPROP, "ApplyCenteredRMSProp"); +REGISTER_OPTYPE_DECLARE(APPLYFTRL, "ApplyFtrl"); +REGISTER_OPTYPE_DECLARE(APPLYFTRLV2, "ApplyFtrlv2"); +REGISTER_OPTYPE_DECLARE(APPLYGRADIENTDESCENT, "ApplyGradientDescent"); +REGISTER_OPTYPE_DECLARE(APPLYPOWERSIGN, "ApplyPowerSign"); +REGISTER_OPTYPE_DECLARE(APPLYPROXIMALADAGRAD, "ApplyProximalAdagrad"); +REGISTER_OPTYPE_DECLARE(APPLYPROXIMALGRADIENTDESCENT, 
"ApplyProximalGradientDescent"); + +REGISTER_OPTYPE_DECLARE(FOCAL_LOSS, "FocalLoss"); +REGISTER_OPTYPE_DECLARE(FOCAL_LOSS_GRAD, "FocalLossGrad"); +REGISTER_OPTYPE_DECLARE(SMOOTHL1_LOSS, "SmoothL1Loss"); +REGISTER_OPTYPE_DECLARE(SMOOTHL1_LOSS_grad, "SmoothL1LossGrad"); +REGISTER_OPTYPE_DECLARE(REDUCEMEAN, "ReduceMean"); +REGISTER_OPTYPE_DECLARE(CONCAT_V2, "ConcatV2"); +REGISTER_OPTYPE_DECLARE(ONEHOT_V2, "OneHotV2"); +REGISTER_OPTYPE_DECLARE(SLICE_V2, "SliceV2"); +REGISTER_OPTYPE_DECLARE(TILE_V2, "TileV2"); +REGISTER_OPTYPE_DECLARE(SUM_V2, "SumV2"); +// Common operator type when operators have the same name +REGISTER_OPTYPE_DECLARE(DETECTIONOUTPUT, "DetectionOutput"); + +// custom operator +REGISTER_OPTYPE_DECLARE(CUSTOMOP, "CustomOp"); +REGISTER_OPTYPE_DECLARE(CUSTOMOP_NCHW, "CustomOpNchw"); +REGISTER_OPTYPE_DECLARE(CUSTOMOP_NHWC, "CustomOpNhwc"); +REGISTER_OPTYPE_DECLARE(CUSTOMOP_NC1HWC0, "CustomOpNc1hwc0"); + +// Depthwise 4d_2_6d,6d_2_4d +REGISTER_OPTYPE_DECLARE(DEPTHWISEWEIGHT4D26D, "depthwise_weight_4d_2_6d"); +REGISTER_OPTYPE_DECLARE(DEPTHWISEWEIGHT6D24D, "depthwise_weight_6d_2_4d"); + +REGISTER_OPTYPE_DECLARE(SQRTGRAD, "SqrtGrad"); +REGISTER_OPTYPE_DECLARE(SIGMOIDGRAD, "SigmoidGrad"); + +// Horovod operator +REGISTER_OPTYPE_DECLARE(HVDCALLBACKALLREDUCE, "HorovodAllreduce"); +REGISTER_OPTYPE_DECLARE(HVDCALLBACKALLGATHER, "HorovodAllgather"); +REGISTER_OPTYPE_DECLARE(HVDCALLBACKBROADCAST, "HorovodBroadcast"); +REGISTER_OPTYPE_DECLARE(HVDWAIT, "HorovodWait"); + +// aicpu op for online_infer dynamic_dims +REGISTER_OPTYPE_DECLARE(GETDYNAMICDIMS, "GetDynamicDims"); + +// profiling training trace node +REGISTER_OPTYPE_DECLARE(PROFILINGTRAININGTRACE, "ProfilingTrainingTrace"); + +// Stack series +REGISTER_OPTYPE_DECLARE(STACK, "Stack"); +REGISTER_OPTYPE_DECLARE(STACKPUSH, "StackPush"); +REGISTER_OPTYPE_DECLARE(STACKPOP, "StackPop"); +REGISTER_OPTYPE_DECLARE(STACKCLOSE, "StackClose"); + +// embedding service +REGISTER_OPTYPE_DECLARE(EMBEDDING_TABLE_FIND, 
"EmbeddingTableFind"); +REGISTER_OPTYPE_DECLARE(EMBEDDING_TABLE_FIND_AND_INIT, "EmbeddingTableFindAndInit"); +REGISTER_OPTYPE_DECLARE(EMBEDDING_APPLY_ADAM, "EmbeddingApplyAdam"); +REGISTER_OPTYPE_DECLARE(EMBEDDING_APPLY_ADAM_W, "EmbeddingApplyAdamW"); +REGISTER_OPTYPE_DECLARE(EMBEDDING_APPLY_ADA_GRAD, "EmbeddingApplyAdaGrad"); +REGISTER_OPTYPE_DECLARE(EMBEDDING_APPLY_SGD, "EmbeddingApplySgd"); +REGISTER_OPTYPE_DECLARE(EMBEDDING_APPLY_RMSPROP, "EmbeddingApplyRmsprop"); +REGISTER_OPTYPE_DECLARE(EMBEDDING_APPLY_FTRL, "EmbeddingApplyFtrl"); +REGISTER_OPTYPE_DECLARE(EMBEDDING_TABLE_EXPORT, "EmbeddingTableExport"); +REGISTER_OPTYPE_DECLARE(EMBEDDING_TABLE_IMPORT, "EmbeddingTableImport"); +REGISTER_OPTYPE_DECLARE(EMBEDDING_COMPUTE_VAR_EXPORT, "EmbeddingComputeVarExport"); +REGISTER_OPTYPE_DECLARE(TABLE_TO_RESOURCE, "TableToResource"); +REGISTER_OPTYPE_DECLARE(EXPONENTIAL_DECAY_LR, "ExponentialDecayLR"); +REGISTER_OPTYPE_DECLARE(EMBEDDING_FEATURE_MAPPING_V2, "EmbeddingFeatureMappingV2"); +REGISTER_OPTYPE_DECLARE(EMBEDDING_FEATURE_MAPPING_IMPORT, "EmbeddingFeatureMappingImport"); +REGISTER_OPTYPE_DECLARE(EMBEDDING_FEATURE_MAPPING_INSERT, "EmbeddingFeatureMappingInsert"); +REGISTER_OPTYPE_DECLARE(EMBEDDING_FEATURE_MAPPING_FILE_SIZE, "EmbeddingFeatureMappingFileSize"); +REGISTER_OPTYPE_DECLARE(EMBEDDING_TABLE_EVICT, "EmbeddingTableEvict"); + +// Data flow +REGISTER_OPTYPE_DECLARE(FLOWNODE, "FlowNode"); +REGISTER_OPTYPE_DECLARE(FLOWFUNC, "FlowFunc"); +REGISTER_OPTYPE_DECLARE(FLOWSEND, "FlowSend"); +REGISTER_OPTYPE_DECLARE(FLOWRECV, "FlowRecv"); + +// Phony +REGISTER_OPTYPE_DECLARE(PHONYCONCAT, "PhonyConcat"); +REGISTER_OPTYPE_DECLARE(PHONYSPLIT, "PhonySplit"); + +REGISTER_OPTYPE_DECLARE(CMO, "Cmo"); + +// Dsa +const char_t *const DSAGENBITMASK = "DSAGenBitMask"; +const char_t *const DSARANDOMTRUNCATEDNORMAL = "DSARandomTruncatedNormal"; +const char_t *const DSARANDOMNORMAL = "DSARandomNormal"; +const char_t *const DSARANDOMUNIFORM = "DSARandomUniform"; + +const std::set 
kFixedAddrNodeTypes = {DSAGENBITMASK, DSARANDOMNORMAL, DSARANDOMTRUNCATEDNORMAL, + DSARANDOMUNIFORM}; +// @brief encryption type of the model file +enum ModelEncryptType { + UNENCRYPTED, // not encrypted + ENCRYPTED // encrypted +}; + +/// +/// @brief signature verification +/// +enum ModelCheckType { + CHECK, // signature verification + UNCHECK // no verification +}; + +/// +/// @brief dynamic input type +/// +enum DynamicInputType { + FIXED = 0, // default mode + DYNAMIC_BATCH = 1, + DYNAMIC_IMAGE = 2, + DYNAMIC_DIMS = 3 +}; + +/// +/// @brief magic number of the model file +/// +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t MODEL_FILE_MAGIC_NUM; + +/// +/// @brief model header length +/// +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t MODEL_FILE_HEAD_LEN; + +/// +/// @brief model name length +/// +constexpr uint32_t MODEL_NAME_LENGTH = 32U; + +/// +/// @brief length of user-defined information +/// +constexpr uint32_t USER_DEFINE_INFO_LENGTH = 32U; + +/// +/// @brief length of the model file signature +/// +constexpr uint32_t MODEL_FILE_CHECKSUM_LENGTH = 64U; + +/// +/// @brief length of the reserved field in the model file header +/// +constexpr uint32_t MODEL_FILE_RESERVED_LENGTH = 62U; + +// DATA node type +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DATA_TYPE; + +// DATA Operator Type +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string AIPP_DATA_TYPE; + +// framework Operator Type +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string FRAMEWORK_OP_TYPE; + +// DATA node type +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string ANN_DATA_TYPE; +// convolution node type +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_NET_OUTPUT; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_END_GRAPH; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern 
const std::string NODE_NAME_OP_DEBUG; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_TYPE_OP_DEBUG; + +// delimiter of operator configuration items +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_CONF_DELIMITER; + +// dim default size value +constexpr int32_t DIM_DEFAULT_SIZE = 4; + +// default NCHW index +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NCHW_DIM_N; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NCHW_DIM_C; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NCHW_DIM_H; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NCHW_DIM_W; + +// default NHWC index +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NHWC_DIM_N; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NHWC_DIM_H; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NHWC_DIM_W; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t NHWC_DIM_C; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t MODEL_VERSION; // model version 1.0 + +// flowctrl +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_STREAM_SWITCH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_FLOWCTRL_LOOP_PER_ITER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_FLOWCTRL_LOOP_COND; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_FLOWCTRL_LOOP_INCREMENT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_FLOWCTRL_LOOP_RESETVALUE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_FLOWCTRL_LOOP_ASSIGNADD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_FLOWCTRL_LOOP_ASSIGN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern 
const std::string NODE_NAME_ATOMIC_ADDR_CLEAN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_ATOMIC_MEMSET; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t TRUE_STREAM_ID; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t STREAM_SWITCH_INPUT_NUM; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_GLOBAL_STEP; + +constexpr uint32_t PLATFORM_VERSION_LEN = 20U; + +enum class OsCpuInfoCheckTyep :uint8_t { + NO_CHECK, + NEED_CHECK +}; + +// Definition of the file header of the model file +struct ModelFileHeader { + uint32_t magic = MODEL_FILE_MAGIC_NUM; // magic number of DOMI + uint32_t headsize = MODEL_FILE_HEAD_LEN; // length of the model header. The value is fixed at 256 + uint32_t version = MODEL_VERSION; // version 1.0 + uint8_t checksum[MODEL_FILE_CHECKSUM_LENGTH] = {0U}; // signature + uint32_t length = 0U; // Ciphertext length. In the non-encryption model, the length is the plaintext length. + // whether encrypted 0:not encrypt, 1:encrypt + uint8_t is_encrypt = static_cast(ModelEncryptType::UNENCRYPTED); + uint8_t is_checksum = static_cast(ModelCheckType::CHECK); // whether to check the checksum + uint8_t modeltype = 0U; // 0:IR model 1:standard model 2:OM Tiny model 3:flow model + uint8_t genmode = 0U; // 0:offline generate 1:online generate + uint8_t name[MODEL_NAME_LENGTH] = {0U}; // Model name, which contains 32 characters + uint32_t ops = 0U; // Computing power (Kops) + uint8_t userdefineinfo[USER_DEFINE_INFO_LENGTH] = {0U}; // User-defined information. 
The value contains 32 characters + uint32_t om_ir_version = 0U; + uint32_t model_num = 0U; + uint8_t platform_version[PLATFORM_VERSION_LEN] = {0U}; + uint8_t platform_type = {0U}; + uint8_t padd[3] = {0}; // For initializing aligned memory + uint64_t model_length = 0UL; + uint8_t need_check_os_cpu_info = static_cast(OsCpuInfoCheckTyep::NO_CHECK); + uint8_t is_unknow_model = 0U; // 0:static model 1:dynamic model + uint8_t reserved[MODEL_FILE_RESERVED_LENGTH] = {0U}; // Reserved field 62 +}; + +constexpr uint8_t TARGET_TYPE_LTTE_8BIT = 0U; +constexpr uint8_t TARGET_TYPE_MINI_8BIT = 1U; + +// number of partitions in the current model +constexpr uint32_t PARTITION_SIZE = 5U; + +enum ModelPartitionType { + MODEL_DEF = 0, + WEIGHTS_DATA = 1, + TASK_INFO = 2, + TBE_KERNELS = 3, + CUST_AICPU_KERNELS = 4, + SO_BINS = 5, + FLOW_MODEL = 6, + FLOW_SUBMODEL = 7, + MODEL_INOUT_INFO = 8, + STATIC_TASK_DESC = 9, + DYNAMIC_TASK_DESC = 10, + TASK_PARAM = 11, + TILING_DATA = 12, + PRE_MODEL_DESC = 20, + PRE_MODEL_SQE = 21, + PRE_KERNEL_ARGS = 22, + PRE_MODEL_DESC_EXTEND = 23, + BUNDLE_MODEL_INFO = 24 +}; + +enum ModelHeaderType { + MODEL_TYPE_IR_MODEL = 0, + MODEL_TYPE_STANDARD_MODEL = 1, + MODEL_TYPE_OM_TINY_MODEL = 2, + MODEL_TYPE_FLOW_MODEL = 3, + MODEL_TYPE_LITE_MODEL = 4, + MODEL_TYPE_BUNDLE_MODEL = 5, + MODEL_TYPE_RESERVED = 6 +}; + +struct TinyModelPartitionMemInfo { + ModelPartitionType type; + uint32_t mem_offset; + uint32_t mem_size; +}; + +struct TinyModelPartitionTable { + uint32_t num; + TinyModelPartitionMemInfo partition[0]; +}; + +inline uint64_t SizeOfTinyModelPartitionTable(const TinyModelPartitionTable &table) { + return sizeof(TinyModelPartitionTable) + (sizeof(TinyModelPartitionMemInfo) * static_cast(table.num)); +} + +struct ModelPartitionMemInfo { + ModelPartitionType type; + uint64_t mem_offset; + uint64_t mem_size; +}; + +struct ModelPartitionTable { + uint32_t num; + ModelPartitionMemInfo partition[0]; +}; + +inline uint64_t SizeOfModelPartitionTable(const 
ModelPartitionTable &table) { + return sizeof(ModelPartitionTable) + (sizeof(ModelPartitionMemInfo) * static_cast(table.num)); +} +// mode of activation +typedef enum tagDomiActivationMode { + DOMI_ACTIVATION_SIGMOID = 0, // sigmoid + DOMI_ACTIVATION_RELU, // ReLU + DOMI_ACTIVATION_TANH, // tanh + DOMI_ACTIVATION_CLIPPED_RELU, // clipped ReLU + DOMI_ACTIVATION_ELU, // ELU + DOMI_ACTIVATION_LEAKY_RELU, + DOMI_ACTIVATION_ABS, // Abs + DOMI_ACTIVATION_RELU1, // relu1 + DOMI_ACTIVATION_SOFTSIGN, // softsign + DOMI_ACTIVATION_SOFTPLUS, // softplus + DOMI_ACTIVATION_HARDSIGMOID, // hardsigmoid + DOMI_ACTIVATION_THRESHOLD_RELU, // threshold + DOMI_ACTIVATION_SELU, // selu + DOMI_ACTIVATION_LINEAR, // linear + DOMI_ACTIVATION_RESERVED +} domiActivationMode_t; + +enum class MemorySizeCalcType { + NORMAL = 0, + ALWAYS_EMPTY +}; + +enum AicpuWorkSpaceType { + CUST_LOG = 0, + INVALID_TYPE +}; +} // namespace ge + +namespace domi { +/// @brief Data structure definition related to task sinking +enum BuildMode { + GEN_TASK_WITHOUT_L2FUSION = 3, // Carrying task data (L2 convergence function disabled) + GEN_TASK_WITHOUT_FUSION = 4, // Carrying task data (all convergence functions disabled) + GEN_TASK_WITH_FUSION = 5 // Carrying task data (with UB/L1/L2 enabled for all convergence functions) +}; +} // namespace domi + +#endif // INC_FRAMEWORK_COMMON_TYPES_H_ diff --git a/inc/air/framework/common/util.h b/inc/air/framework/common/util.h new file mode 100644 index 0000000000000000000000000000000000000000..bc7799de07fc8ae59b1b135fafd399c6d3e7d39a --- /dev/null +++ b/inc/air/framework/common/util.h @@ -0,0 +1,17 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef AIR_INC_FRAMEWORK_COMMON_UTIL_H_ +#define AIR_INC_FRAMEWORK_COMMON_UTIL_H_ +#include "common/ge_common/util.h" +namespace ge { +GE_FUNC_VISIBILITY Status ConvertToInt64(const std::string &str, int64_t &val); +GE_FUNC_VISIBILITY Status ConvertToUint64(const std::string &str, uint64_t &val); +} +#endif // AIR_INC_FRAMEWORK_COMMON_UTIL_H_ diff --git a/inc/air/framework/engine/dnnengine.h b/inc/air/framework/engine/dnnengine.h new file mode 100644 index 0000000000000000000000000000000000000000..0ee7f71b92bc5656727319c55c68261bf5d934da --- /dev/null +++ b/inc/air/framework/engine/dnnengine.h @@ -0,0 +1,69 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_ENGINE_DNNENGINE_H_ +#define INC_FRAMEWORK_ENGINE_DNNENGINE_H_ + +#include +#include +#include + +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/ge_types.h" +#include "graph/types.h" + +namespace ge { +enum class PriorityEnum : uint32_t { + COST_0 = 0, + COST_1 = 1, + COST_2 = 2, + COST_3 = 3, + COST_4 = 4, + COST_9 = 9, + COST_10 = 10, +}; + +struct DNNEngineAttribute { + std::string engine_name; + std::vector mem_type; + PriorityEnum compute_cost; + enum RuntimeType runtime_type; // HOST, DEVICE + // If engine input format must be specific, set this attribute, else set FORMAT_RESERVED + Format engine_input_format; + Format engine_output_format; + bool atomic_engine_flag; +}; + +class GE_FUNC_VISIBILITY DNNEngine { + public: + DNNEngine() = default; + explicit DNNEngine(const DNNEngineAttribute &attrs) { + engine_attribute_ = attrs; + } + virtual ~DNNEngine() = default; + Status Initialize(const std::map &options) const { + (void)options; + return SUCCESS; + } + Status Finalize() const { + return SUCCESS; + } + void GetAttributes(DNNEngineAttribute &attr) const { + attr = engine_attribute_; + } + bool IsAtomic() const { + return engine_attribute_.atomic_engine_flag; + } + + protected: + DNNEngineAttribute engine_attribute_; +}; +} // namespace ge + +#endif // INC_FRAMEWORK_ENGINE_DNNENGINE_H_ diff --git a/inc/air/framework/executor/ge_executor.h b/inc/air/framework/executor/ge_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..71727d6a163410b239aaa848762bc7c4aed6b2eb --- /dev/null +++ b/inc/air/framework/executor/ge_executor.h @@ -0,0 +1,359 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). 
+ * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_EXECUTOR_GE_EXECUTOR_H_ +#define INC_FRAMEWORK_EXECUTOR_GE_EXECUTOR_H_ + +#include +#include +#include + +#include "ge/ge_allocator.h" +#include "common/dynamic_aipp.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/ge_types.h" +#include "framework/common/types.h" +#include "framework/runtime/rt_session.h" +#include "graph/tensor.h" +#include "graph/ge_tensor.h" +#include "framework/common/ge_model_inout_types.h" + +namespace ge { +class SingleOp; +class DynamicSingleOp; +class GeRootModel; + +struct RunModelData { + uint32_t index; // Data index + uint32_t modelId; + std::vector blobs; // All input/output data buffer + uint32_t timestamp; // Data creation time + uint32_t timeout; // Processing timeout + uint64_t request_id = 0UL; // Request ID + uint64_t dynamic_batch_size = 0UL; // Dynamic batch size scene, set dynamic size, not supported by default:0 + uint64_t dynamic_image_height = 0UL; // Dynamic image size scene, set image height, not supported by default:0 + uint64_t dynamic_image_width = 0UL; // Dynamic image size scene, set image width, not supported by default:0 + std::vector dynamic_dims; // Dynamic dims scene, set dynamic dims, not supported by default:empty +}; + +struct ModelLoadArg { + void *dev_ptr; + size_t mem_size; + void *weight_ptr; + size_t weight_size; + gert::RtSession *rt_session = nullptr; +}; + +class GE_FUNC_VISIBILITY GeExecutor { + public: + 
GeExecutor(); + ~GeExecutor() = default; + + Status Initialize(); + Status Finalize(); + + /// + /// @ingroup ge + /// @brief Initialize global execute environment. + /// @param [in] options: environment variables. + /// @return init result + /// + static Status Initialize(const std::map &options); + + /// + /// @ingroup ge + /// @brief Finalize global execute environment. + /// @return execute result + /// + static Status FinalizeEx(); + + Status UnloadModel(const uint32_t model_id); + + // Get input and output descriptor + Status GetModelDescInfo(const uint32_t model_id, std::vector &input_desc, + std::vector &output_desc, const bool new_model_desc = false); + + Status GetModelDescInfoFromMem(const ModelData &model_data, ModelInOutInfo &info) const; + + /// + /// @ingroup ge + /// @brief Set dynamic batch size + /// @param [in] model_id: model id allocate from manager + /// @param [in] dynamic_input_addr: dynamic input addr created by user + /// @param [in] length: length of dynamic input addr + /// @param [in] batch_size: batch size entered by user in dynamic multi-batch scenario + /// @return execute result + /// + Status SetDynamicBatchSize(const uint32_t model_id, void *const dynamic_input_addr, const uint64_t length, + const uint64_t batch_size); + + /// + /// @ingroup ge + /// @brief Set dynamic image info + /// @param [in] model_id: model id allocate from manager + /// @param [in] dynamic_input_addr: dynamic input addr created by user + /// @param [in] length: length of dynamic input addr + /// @param [in] image_height: image height entered by user in dynamic multi-resolution scenario + /// @param [in] image_width: image width entered by user in dynamic multi-resolution scenario + /// @return execute result + /// + Status SetDynamicImageSize(const uint32_t model_id, void *const dynamic_input_addr, const uint64_t length, + const uint64_t image_height, const uint64_t image_width); + + /// + /// @ingroup ge + /// @brief Set dynamic dims info + /// @param [in] 
model_id: model id allocate from manager + /// @param [in] dynamic_input_addr: dynamic input addr created by user + /// @param [in] length: length of dynamic input addr + /// @param [in] dynamic_dim_num: number of dynamic dimension + /// @param [in] dynamic_dims: array of dynamic dimensions + /// @return execute result + /// + Status SetDynamicDims(const uint32_t model_id, void *const dynamic_input_addr, const uint64_t length, + const std::vector &dynamic_dims); + + /// + /// @ingroup ge + /// @brief Get current dynamic dims info by combined dims + /// @param [in] model_id: model id allocate from manager + /// @param [in] dynamic_dims: cur gear dynamic dims value + /// @param [out] cur_dynamic_dims: current dynamic dims + /// @return execute result + /// + Status GetCurDynamicDims(const uint32_t model_id, const std::vector &dynamic_dims, + std::vector &cur_dynamic_dims); + + /// + /// @ingroup ge + /// @brief Get dynamic batch_info + /// @param [in] model_id + /// @param [out] batch_info + /// @param [out] dynamic_type + /// @return execute result + /// + Status GetDynamicBatchInfo(const uint32_t model_id, std::vector> &batch_info, + int32_t &dynamic_type); + + /// + /// @ingroup ge + /// @brief Get combined dynamic dims info + /// @param [in] model_id + /// @param [out] batch_info + /// @return execute result + /// + Status GetCombinedDynamicDims(const uint32_t model_id, std::vector> &batch_info); + + /// + /// @ingroup ge + /// @brief Get user designeate shape order + /// @param [in] model_id + /// @param [out] user_designate_shape_order + /// @return execute result + /// + Status GetUserDesignateShapeOrder(const uint32_t model_id, + std::vector &user_designate_shape_order); + + Status GetCurShape(const uint32_t model_id, std::vector &batch_info, int32_t &dynamic_type); + + /// + /// @ingroup ge + /// @brief Set dynamic image info + /// @param [in] model_id: model id allocate from manager + /// @param [in] dynamic_input_addr: dynamic input addr created by user + 
/// @param [in] length: length of dynamic input addr + /// @param [in] aipp_batch_para: kAippDynamicBatchPara vector by user in dynamic aipp + /// @param [in] aipp_parms: kAippDynamicPara by user in dynamic aipp + /// @return execute result + /// + Status SetDynamicAippData(const uint32_t model_id, void *const dynamic_input_addr, const uint64_t length, + const std::vector &aipp_batch_para, + const kAippDynamicPara &aipp_parms); + + Status GetAIPPInfo(const uint32_t model_id, const uint32_t index, AippConfigInfo &aipp_info); + + Status GetOpAttr(const uint32_t model_id, const std::string &op_name, const std::string &attr_name, + std::string &attr_value); + + Status GetModelAttr(const uint32_t model_id, std::vector &dynamic_output_shape_info); + + Status GetAippType(const uint32_t model_id, const uint32_t index, InputAippType &type, size_t &aipp_index); + + Status CommandHandle(const Command &command) const; + + Status SetDump(const DumpConfig &dump_config); + + /// + /// @ingroup ge + /// @brief Query model memory consuming interface + /// @param [in] model_id Offline model ID + /// @param [out] max_size Memory size + /// @return SUCCESS + /// @return FAILED + /// + Status GetMaxUsedMemory(const uint32_t model_id, uint32_t &max_size); + + /// + /// @ingroup ge + /// @brief Load data from model file to memory + /// @param [in] const std::string &path: Offline model file path + /// @param [out] ModelData &model_data: Offline model memory data + /// @return SUCCESS handle successfully / others handle failed + /// + Status LoadDataFromFile(const std::string &path, ModelData &model_data); + + /// + /// @ingroup ge + /// @brief Load model from offline model memory data + /// @param [in] ModelData &model_data: Offline model data + /// @param [in] void *dev_ptr: Input/Output memory address + /// @param [in] size_t mem_size: Input/Output memory length + /// @param [in] void *weight_ptr: Weight memory address + /// @param [in] size_t weight_size: Weight memory length + /// 
@param [out] uint32_t &model_id: Corresponding identification after model loading + /// @return SUCCESS handle successfully / others handle failed + /// + Status LoadModelFromData(uint32_t &model_id, const ModelData &model_data, void *const dev_ptr, + const size_t mem_size, void *const weight_ptr, const size_t weight_size); + + Status LoadModelFromDataWithArgs(uint32_t &model_id, const ModelData &model_data, const ModelLoadArg &load_arg); + + /// + /// @ingroup ge + /// @brief Load task list from ModelData with queue. + /// @param [out] model_id: model id allocate from manager. + /// @param [in] model_data: Model data load from offline model. + /// @param [in] input_queue_ids: input queue ids created by user. + /// @param [in] output_queue_ids: output queue ids created by user. + /// @return: 0 for success / others for fail + /// + Status LoadModelWithQ(uint32_t &model_id, const ModelData &model_data, const std::vector &input_queue_ids, + const std::vector &output_queue_ids); + + /// + /// @ingroup ge + /// @brief Load task list from GeRootModel with queue and param. + /// @param [out] model_id: model id allocate from manager. + /// @param [in] root_model: Instance of GeRootModel. + /// @param [in] model_queue_param: params and queue ids created by user. + /// @return: 0 for success / others for fail + /// + Status LoadModelWithQ(uint32_t &model_id, + const std::shared_ptr &root_model, + const ModelQueueParam &model_queue_param) const; + + /// + /// @ingroup ge + /// @brief Synchronous execution of offline model(Do not create thread) + /// @param [in] uint32_t model_id: Model ID to execute + /// @param [in] void* stream: stream to execute + /// @param [in] bool async_mode: is asynchronous mode. 
+ /// @param [in] const RunModelData &input_data: Model input data + /// @param [out] RunModelData &output_data: Model output data + /// @return SUCCESS handle successfully / others handle failed + /// + Status ExecModel(const uint32_t model_id, void *const stream, const RunModelData &input_data, + RunModelData &output_data, const bool async_mode = false); + + /// + /// @ingroup ge + /// @brief Load task list from root_model without input queue or output queue. + /// @param [out] model_id: model id allocate from manager. + /// @param [in] root_model: Instance of GeRootModel. + /// @return: 0 for success / others for fail + /// + Status LoadModelWithoutQ(uint32_t &model_id, const std::shared_ptr &root_model) const; + + /// + /// @ingroup ge + /// @brief Synchronous execution of offline model(Do not create thread) + /// @param [in] uint32_t model_id: Model ID to execute + /// @param [in] void* stream: stream to execute + /// @param [in] bool async_mode: is asynchronous mode. + /// @param [in] const RunModelData &run_input_data: Model input data + /// @param [in] const std::vector &input_desc: description of model input data + /// @param [out] RunModelData &run_output_data: Model output data + /// @param [out] std::vector &output_desc: description of model output data + /// @return SUCCESS handle successfully / others handle failed + /// + Status ExecModel(const uint32_t model_id, void *const stream, const RunModelData &run_input_data, + const std::vector &input_desc, RunModelData &run_output_data, + std::vector &output_desc, const bool async_mode = false); + + /// + /// @ingroup ge + /// @brief Get execution memory size and weight memory size from model file + /// @param [in] const std::string &path: Offline model file path + /// @param [out] size_t &mem_size Execution memory size + /// @param [out] size_t &weight_size Weight memory space size + /// @return SUCCESS handle successfully / others handle failed + /// + Status GetMemAndWeightSize(const std::string &path, size_t &mem_size, 
size_t &weight_size); + + /// + /// @ingroup ge + /// @brief Get execution memory size and weight memory size from offline model buffer + /// @param [in] const void *model_data Offline model buffer + /// @param [in] size_t model_size Offline model buffer length + /// @param [out] size_t &mem_size Execution memory size + /// @param [out] size_t &weight_size Weight memory space size + /// @return SUCCESS handle successfully / others handle failed + /// + Status GetMemAndWeightSize(const void *const model_data, const size_t model_size, size_t &mem_size, + size_t &weight_size); + + static Status LoadSingleOp(const std::string &model_name, const ModelData &model_data, void *const stream, + SingleOp **const single_op); + + static Status LoadSingleOpV2(const std::string &model_name, const ModelData &model_data, void *const stream, + SingleOp **const single_op, const uint64_t model_id); + + static Status ExecuteAsync(SingleOp *const executor, const std::vector &inputs, + std::vector &outputs); + + static Status LoadDynamicSingleOp(const std::string &model_name, const ModelData &model_data, void *const stream, + DynamicSingleOp **const single_op); + + static Status LoadDynamicSingleOpV2(const std::string &model_name, const ModelData &model_data, void *const stream, + DynamicSingleOp **const single_op, const uint64_t model_id); + + static Status UnloadSingleOp(const uint64_t op_id); + + static Status UnloadDynamicSingleOp(const uint64_t op_id); + + static Status ExecuteAsync(DynamicSingleOp *const executor, const std::vector &input_desc, + const std::vector &inputs, std::vector &output_desc, + std::vector &outputs); + + static Status ReleaseSingleOpResource(void *const stream); + + static Status ClearCustomAicpuSo(const uint32_t device_id); + + static Status GetDeviceIdByModelId(const uint32_t model_id, uint32_t &device_id); + + Status GetBatchInfoSize(const uint32_t model_id, size_t &shape_count); + Status GetOrigInputInfo(const uint32_t model_id, const uint32_t index, OriginInputInfo &orig_input_info); + 
Status GetAllAippInputOutputDims(const uint32_t model_id, const uint32_t index, + std::vector &input_dims, std::vector &output_dims); + Status GetOpDescInfo(const uint32_t device_id, const uint32_t stream_id, const uint32_t task_id, + OpDescInfo &op_desc_info); + + static Status SetAllocator(void *const stream, ge::Allocator *const external_allocator); + + static Status ReleaseResource(const uint32_t device_id); + + static Status ReleaseResource(); + + static Status GetRuntimeModelId(const uint32_t model_id, uint32_t &model_runtime_id); + private: + static std::atomic_bool is_inited_; +}; +} // namespace ge + +#endif // INC_FRAMEWORK_EXECUTOR_GE_EXECUTOR_H_ diff --git a/inc/air/framework/executor_c/ge_executor.h b/inc/air/framework/executor_c/ge_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..d9cbee9db713d3796982ad4f8478dece62c263a8 --- /dev/null +++ b/inc/air/framework/executor_c/ge_executor.h @@ -0,0 +1,41 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_EXECUTOR_GE_C_EXECUTOR_H_ +#define INC_FRAMEWORK_EXECUTOR_GE_C_EXECUTOR_H_ + +#include +#include "framework/executor_c/ge_executor_types.h" +#include "framework/executor_c/types.h" +#if defined(__cplusplus) +extern "C" { +#endif + +Status GeInitialize(void); +Status GeFinalize(void); +Status GetModelDescInfo(uint32_t modelId, ModelInOutInfo *info); + +Status GetMemAndWeightSize(const char *fileName, size_t *workSize, size_t *weightSize); +Status GetPartitionSize(const char *fileName, GePartitionSize *mdlPartitionSize); +Status ExecModel(uint32_t modelId, ExecHandleDesc *execDesc, bool sync, const InputData *inputData, + OutputData *outputData); +Status GeLoadModelFromData(uint32_t *modelId, const ModelData *data); +Status LoadDataFromFile(const char *modelPath, ModelData *data); +void FreeModelData(ModelData *data); +Status UnloadModel(uint32_t modelId); +Status GetModelDescInfoFromMem(const ModelData *modelData, ModelInOutInfo *info); +void DestoryModelInOutInfo(ModelInOutInfo *info); +Status GeDbgInit(const char *configPath); +Status GeDbgDeInit(void); +Status GeNofifySetDevice(uint32_t chipId, uint32_t deviceId); +#if defined(__cplusplus) +} +#endif + +#endif // INC_FRAMEWORK_EXECUTOR_GE_C_EXECUTOR_H_ diff --git a/inc/air/framework/executor_c/ge_executor_types.h b/inc/air/framework/executor_c/ge_executor_types.h new file mode 100644 index 0000000000000000000000000000000000000000..6abf374420b887a187a549fab5fe52bc1ae8b987 --- /dev/null +++ b/inc/air/framework/executor_c/ge_executor_types.h @@ -0,0 +1,101 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_EXECUTOR_GE_EXECUTOR_C_TYPES_H_ +#define INC_FRAMEWORK_EXECUTOR_GE_EXECUTOR_C_TYPES_H_ +#include +#include +#include +#include "mmpa_api.h" +#include "vector.h" +#if defined(__cplusplus) +extern "C" { +#endif +typedef uint32_t Status; +#define SUCCESS 0 + +typedef struct { + void *weightPtr; + size_t weightSize; + void *modelDescPtr; + size_t modelDescSize; + void *kernelPtr; + size_t kernelSize; + void *paramPtr; + size_t paramSize; + void *taskPtr; + size_t taskSize; + void *dynTaskPtr; + size_t dynTaskSize; +} Partition; + +typedef enum { + NO_NEED_READ_FROM_FD, + NEED_READ_FROM_FD +} ReadFileFlag; + +typedef struct { + void *modelData; + uint64_t modelLen; + int32_t priority; + mmFileHandle *fd; + ReadFileFlag flag; + Partition part; + size_t memType; + int32_t resv[3]; +} ModelData; + +typedef struct { + void *data; + uint64_t length; +} DataBuffer; + +typedef struct { + DataBuffer *dataBuffer; +} DataBlob; + +typedef struct { + Vector blobs; // type : DataBlob + uint64_t* io_addr; + uint32_t ioa_size; +} DataSet; + +typedef DataSet InputData; +typedef DataSet OutputData; + +typedef struct { + void *stream; + void *workPtr; + size_t workSize; + size_t mpamId; + size_t aicQos; + size_t aicOst; + size_t mecTimeThreshHold; +} ExecHandleDesc; + +typedef struct { + size_t workSize; + size_t weightSize; + size_t modelDescSize; + size_t kernelSize; + size_t kernelArgsSize; + size_t staticTaskSize; + size_t dynamicTaskSize; +} GePartitionSize; + +typedef struct { + uint32_t tlvType; + uint32_t (*pfnTlvProc)(uint32_t len, 
uint8_t *tlvValue, void *appInfo); +} TlvProcPair; + +#if defined(__cplusplus) +} +#endif + +#endif // INC_FRAMEWORK_EXECUTOR_GE_EXECUTOR_C_TYPES_H_ diff --git a/inc/air/framework/executor_c/ge_log.h b/inc/air/framework/executor_c/ge_log.h new file mode 100644 index 0000000000000000000000000000000000000000..5efabde3b03640871a2db531c7a7f1a50e2cb41e --- /dev/null +++ b/inc/air/framework/executor_c/ge_log.h @@ -0,0 +1,100 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_EXECUTOR_C_GE_LOG_H_ +#define INC_FRAMEWORK_EXECUTOR_C_GE_LOG_H_ + +#include +#include +#include +#include "toolchain/slog.h" +#include "mmpa_api.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define GE_GET_ERRORNO_STR "GE_ERRORNO_STR" +#define GE_GET_ERROR_LOG_HEADER "[GE][MODULE]" +#define GE_MODULE_NAME ((int32_t)GE) + +#ifndef FAILED +#define FAILED (-1) +#endif +// trace status of log +enum TraceStatus { TRACE_INT = 0, TRACE_RUNNING, TRACE_WAITTING, TRACE_STOP }; + +static inline uint64_t GetTid(void) { + const uint64_t tid = (uint64_t)(mmGetTaskId()); + return tid; +} + +#ifdef RUN_TEST +#define GELOGD(fmt, ...) \ +do { \ + printf("[DEBUG][%s:%d]%ld " fmt "\n", __FILE__, __LINE__, (long int)GetTid(), ##__VA_ARGS__); \ + } while (false) + +#define GELOGI(fmt, ...) 
\ + do { \ + printf("[INFO][%s:%d]%ld " fmt "\n", __FILE__, __LINE__, (long int)GetTid(), ##__VA_ARGS__); \ + } while (false) + +#define GELOGW(fmt, ...) \ + do { \ + printf("[WARNING][%s:%d]%ld " fmt "\n", __FILE__, __LINE__, (long int)GetTid(), ##__VA_ARGS__); \ + } while (false) + +#define GELOGE(ERROR_CODE, fmt, ...) \ + do { \ + printf("[ERROR][%s:%d]%ld ErrorNo:%ld " fmt "\n", __FILE__, __LINE__, (long int)GetTid(), (long int)ERROR_CODE, \ + ##__VA_ARGS__); \ + } while (false) + +#define GEEVENT(fmt, ...) \ + do { \ + printf("[EVENT][%s:%d]%ld " fmt "\n", __FILE__, __LINE__, (long int)GetTid(), ##__VA_ARGS__); \ + } while (false) +#else +#define GELOGD(fmt, ...) \ + do { \ + DlogRecord(GE_MODULE_NAME, DLOG_DEBUG, "[DEBUG][%s:%d]%ld " fmt "\n", __FILE__, __LINE__, \ + (long int)GetTid(), ##__VA_ARGS__); \ + } while (false) + +#define GELOGI(fmt, ...) \ + do { \ + DlogRecord(GE_MODULE_NAME, DLOG_INFO, "[INFO][%s:%d]%ld " fmt "\n", __FILE__, __LINE__, \ + (long int)GetTid(), ##__VA_ARGS__); \ + } while (false) + +#define GELOGW(fmt, ...) \ + do { \ + DlogRecord(GE_MODULE_NAME, DLOG_WARN, "[WARNING][%s:%d]%ld " fmt "\n", __FILE__, __LINE__, \ + (long int)GetTid(), ##__VA_ARGS__); \ + } while (false) + +#define GEEVENT(fmt, ...) \ + do { \ + DlogRecord(GE_MODULE_NAME, DLOG_EVENT, "[Event][%s:%d]%ld " fmt "\n", __FILE__, __LINE__, \ + (long int)GetTid(), ##__VA_ARGS__); \ + } while (false) + +#define GELOGE(ERROR_CODE, fmt, ...) 
\ + do { \ + DlogRecord(GE_MODULE_NAME, DLOG_ERROR, "[ERROR][%s:%d]%ld:ErrorNo:%u(%s)%s " fmt "\n", \ + __FILE__, __LINE__, (long int)GetTid(), (uint32_t)ERROR_CODE, \ + GE_GET_ERRORNO_STR, GE_GET_ERROR_LOG_HEADER, ##__VA_ARGS__); \ + } while (false) +#endif + +#ifdef __cplusplus +} +#endif +#endif // INC_FRAMEWORK_EXECUTOR_C_GE_LOG_H_ diff --git a/inc/air/framework/executor_c/types.h b/inc/air/framework/executor_c/types.h new file mode 100644 index 0000000000000000000000000000000000000000..e373fc93b99f26c3d0941e247793932b3cfe7007 --- /dev/null +++ b/inc/air/framework/executor_c/types.h @@ -0,0 +1,264 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_EXECUTOR_C_TYPES_H_ +#define INC_FRAMEWORK_EXECUTOR_C_TYPES_H_ +#include +#include +#include "vector.h" +#if defined(__cplusplus) +extern "C" { +#endif +#define MODEL_FILE_CHECKSUM_LENGTH 64 +#define MODEL_NAME_LENGTH 32 +#define USER_DEFINE_INFO_LENGTH 32 +#define PLATFORM_VERSION_LEN 20 +#define MODEL_FILE_RESERVED_LENGTH 75 +#define MODEL_FILE_MAGIC_NUM 0x444F4D49 + +enum ModelEncryptType { + UNENCRYPTED, + ENCRYPTED +}; + +enum ModelCheckType { + CHECK, + UNCHECK +}; + +typedef enum { + MODEL_DEF = 0, + WEIGHTS_DATA = 1, + TASK_INFO = 2, + TBE_KERNELS = 3, + CUST_AICPU_KERNELS = 4, + SO_BINS = 5, + FLOW_MODEL = 6, + FLOW_SUBMODEL = 7, + MODEL_INOUT_INFO = 8, + STATIC_TASK_DESC = 9, + DYNAMIC_TASK_DESC = 10, + TASK_PARAM = 11, + PRE_MODEL_DESC = 20, + PRE_MODEL_SQE = 21, + PRE_KERNEL_ARGS = 22, + PRE_MODEL_DESC_EXTEND = 23 +} ModelPartitionType; + + +typedef struct { + uint32_t magic; // magic number of DOMI + uint32_t headsize ; // length of the model header. The value is fixed at 256 + uint32_t version; // version 1.0 + uint8_t checksum[MODEL_FILE_CHECKSUM_LENGTH]; // signature + uint32_t length; // Ciphertext length. In the non-encryption model, the length is the plaintext length. + // whether encrypted 0:not encrypt, 1:encrypt + uint8_t is_encrypt; + uint8_t is_checksum; // whether to check the checksum + uint8_t modeltype; // 0:IR model 1:standard model 2:OM Tiny model 3:flow model + uint8_t genmode; // 0:offline generate 1:online generate + uint8_t name[MODEL_NAME_LENGTH]; // Model name, which contains 32 characters + uint32_t ops; // Computing power (Kops) + uint8_t userdefineinfo[USER_DEFINE_INFO_LENGTH]; // User-defined information. 
The value contains 32 characters + uint32_t om_ir_version; + uint32_t model_num; + uint8_t platform_version[PLATFORM_VERSION_LEN]; + uint8_t platform_type; + uint8_t reserved[MODEL_FILE_RESERVED_LENGTH]; // Reserved field, MODEL_FILE_RESERVED_LENGTH bytes +} ModelFileHeader; + +typedef struct { + ModelPartitionType type; + uint64_t mem_offset; + uint64_t mem_size; +} ModelPartitionMemInfo; + +typedef struct { + uint32_t num; + ModelPartitionMemInfo partition[0]; +} ModelPartitionTable; + +enum TagType {INPUT_DESC_TAG = 0, OUTPUT_DESC_TAG, NAME_TAG = 10, DIMES_TAG = 11}; + +typedef struct { + uint8_t tag; + uint32_t len; + uint8_t name[0]; +} IOParamName; + +typedef struct { + uint8_t tag; + uint32_t len; + uint8_t dim[0]; +} IOParamDims; + +typedef enum { + FORMAT_NCHW = 0, // NCHW + FORMAT_NHWC, // NHWC + FORMAT_ND, // Nd Tensor + FORMAT_NC1HWC0, // NC1HWC0 + FORMAT_FRACTAL_Z, // FRACTAL_Z + FORMAT_NC1C0HWPAD = 5, + FORMAT_NHWC1C0, + FORMAT_FSR_NCHW, + FORMAT_FRACTAL_DECONV, + FORMAT_C1HWNC0, + FORMAT_FRACTAL_DECONV_TRANSPOSE = 10, + FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, + FORMAT_NC1HWC0_C04, // NC1HWC0, C0 is 4 + FORMAT_FRACTAL_Z_C04, // FRACZ, C0 is 4 + FORMAT_CHWN, + FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15, + FORMAT_HWCN, + FORMAT_NC1KHKWHWC0, // KH,KW kernel h& kernel w maxpooling max output format + FORMAT_BN_WEIGHT, + FORMAT_FILTER_HWCK, // filter input tensor format + FORMAT_HASHTABLE_LOOKUP_LOOKUPS = 20, + FORMAT_HASHTABLE_LOOKUP_KEYS, + FORMAT_HASHTABLE_LOOKUP_VALUE, + FORMAT_HASHTABLE_LOOKUP_OUTPUT, + FORMAT_HASHTABLE_LOOKUP_HITS, + FORMAT_C1HWNCoC0 = 25, + FORMAT_MD, + FORMAT_NDHWC, + FORMAT_FRACTAL_ZZ, + FORMAT_FRACTAL_NZ, + FORMAT_NCDHW = 30, + FORMAT_DHWCN, // 3D filter input tensor format + FORMAT_NDC1HWC0, + FORMAT_FRACTAL_Z_3D, + FORMAT_CN, + FORMAT_NC = 35, + FORMAT_DHWNC, + FORMAT_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format + FORMAT_FRACTAL_ZN_LSTM, + FORMAT_FRACTAL_Z_G, + FORMAT_RESERVED = 40, + FORMAT_ALL, + FORMAT_NULL, + FORMAT_ND_RNN_BIAS, + 
FORMAT_FRACTAL_ZN_RNN, + FORMAT_NYUV = 45, + FORMAT_NYUV_A, + FORMAT_NCL, + // Add new formats definition here + FORMAT_END, + // FORMAT_MAX defines the max value of Format. + // Any Format should not exceed the value of FORMAT_MAX. + // ** Attention ** : FORMAT_MAX stands for the SPEC of enum Format and almost SHOULD NOT be used in code. + // If you want to judge the range of Format, you can use FORMAT_END. + FORMAT_MAX = 0xff +} Format; + +typedef enum { + DT_FLOAT = 0, // float type + DT_FLOAT16 = 1, // fp16 type + DT_INT8 = 2, // int8 type + DT_INT16 = 6, // int16 type + DT_UINT16 = 7, // uint16 type + DT_UINT8 = 4, // uint8 type + DT_INT32 = 3, // int32 type + DT_INT64 = 9, // int64 type + DT_UINT32 = 8, // unsigned int32 + DT_UINT64 = 10, // unsigned int64 + DT_BOOL = 12, // bool type + DT_DOUBLE = 11, // double type + DT_STRING = 13, // string type + DT_DUAL_SUB_INT8 = 14, // dual output int8 type + DT_DUAL_SUB_UINT8 = 15, // dual output uint8 type + DT_COMPLEX64 = 16, // complex64 type + DT_COMPLEX128 = 17, // complex128 type + DT_QINT8 = 18, // qint8 type + DT_QINT16 = 19, // qint16 type + DT_QINT32 = 20, // qint32 type + DT_QUINT8 = 21, // quint8 type + DT_QUINT16 = 22, // quint16 type + DT_RESOURCE = 23, // resource type + DT_STRING_REF = 24, // string ref type + DT_DUAL = 25, // dual output type + DT_VARIANT = 26, // dt_variant type + DT_BF16 = 27, // bf16 type + DT_UNDEFINED = 28, // Used to indicate a DataType field has not been set. 
+ DT_INT4 = 29, // int4 type + DT_UINT1 = 30, // uint1 type + DT_INT2 = 31, // int2 type + DT_UINT2 = 32, // uint2 type + DT_MAX // Mark the boundaries of data types +} DataType; +#pragma pack(push) +#pragma pack(1) +typedef struct { + uint32_t task_num; + uint64_t workspace_size; + uint64_t weight_size; + uint32_t weight_type; + uint8_t profile_enable; + uint8_t model_interrupt; +} ModelDesc; +#pragma pack(pop) + +typedef struct { + uint32_t type : 2; + uint32_t pre_p : 2; + uint32_t post_p : 2; + uint32_t cond_s : 2; + uint32_t res : 2; + uint32_t prefetch_num : 2; + uint32_t block_dim : 2; + uint32_t code_size : 2; + uint32_t soft_user : 2; + uint32_t task_pc_offset : 2; + uint32_t task_param_offset : 2; +} TaskDesc; + +typedef struct { + char *name; + size_t size; + Format format; + DataType dataType; + Vector dims; +} ModelInOutTensorDesc; + +typedef struct { + Vector input_desc; + Vector output_desc; +} ModelInOutInfo; + +enum ModelDescType { + MODEL_INPUT_DESC, + MODEL_OUTPUT_DESC +}; + +typedef struct { + int32_t type; + uint32_t length; + uint8_t *value; +} ModelDescTlvConfig; + +typedef struct { + uint64_t size; + Format format; + DataType dt; + uint32_t name_len; + uint32_t dims_len; + uint32_t dimsV2_len; + uint32_t shape_range_len; +} ModelTensorDescBaseInfo; + +typedef struct { + ModelPartitionType type; + uint8_t *data; + uint32_t size; +} ModelPartition; + +#if defined(__cplusplus) +} +#endif + +#endif // INC_FRAMEWORK_EXECUTOR_C_TYPES_H_ diff --git a/inc/air/framework/generator/ge_generator.h b/inc/air/framework/generator/ge_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..63ad9770610d65bdc642963539050c53be6e1d9a --- /dev/null +++ b/inc/air/framework/generator/ge_generator.h @@ -0,0 +1,154 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). 
+ * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_GENERATOR_GE_GENERATOR_H_ +#define INC_FRAMEWORK_GENERATOR_GE_GENERATOR_H_ + +#include +#include +#include +#include +#include "external/ge/ge_ir_build.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/ge_types.h" +#include "graph/ge_tensor.h" +#include "graph/graph.h" +#include "graph/op_desc.h" +#include "framework/omg/omg_inner_types.h" +#include "graph/detail/attributes_holder.h" + +namespace ge { +const std::string kAttrSupportDynamicShape = "support_dynamicshape"; + +class GeRootModel; +class GE_FUNC_VISIBILITY GeGenerator { + public: + using InOutTensorRef = std::pair &, const std::vector &>; + static GeGenerator &GetInstance() { + static GeGenerator Instance; + return Instance; + } + GeGenerator() = default; + + ~GeGenerator() { (void)Finalize(); } + + GeGenerator(const GeGenerator &) = delete; + + GeGenerator &operator=(const GeGenerator &) = delete; + + Status Initialize(const std::map &options); + Status Initialize(const std::map &options, OmgContext &context); + + Status Finalize(); + Status GenerateOfflineModel(const Graph &graph, const std::string &file_name_prefix, + const std::vector &inputs = std::vector(), + OfflineModelFormat om_format = OM_FORMAT_DEFAULT); + + Status GenerateOnlineModel(const Graph &graph, const std::vector &inputs, ge::ModelBufferData &model); + + Status GenerateInfershapeGraph(const Graph &graph); + + /// + /// @ingroup ge + /// @brief: Build 
single OP in Model. + /// @param [in] op_desc: the OP description. + /// @param [in] inputs: input tensors. + /// @param [in] outputs: output tensors. + /// @param [in] model_file_name: name of model file. + /// @param [in] compile_flag: op build flag, accurate build is 0, fuzz build is 1 + /// @return SUCCESS or FAILED + /// + Status BuildSingleOpModel(OpDescPtr &op_desc, const std::vector &inputs, + const std::vector &outputs, const std::string &model_file_name, + int32_t compile_flag = 0); + /// + /// @ingroup ge + /// @brief: Build single Op into model buff. + /// @param [in] op_desc: the OP description. + /// @param [in] inputs: input tensors. + /// @param [in] outputs: output tensors. + /// @param [in] engine_type: engine type. + /// @param [in] compile_flag: op build flag, accurate build is 0, fuzz build is 1 + /// @param [out] model_buff: model buff of op. + /// @return SUCCESS or FAILED + Status BuildSingleOpModel(OpDescPtr &op_desc, const std::vector &inputs, + const std::vector &outputs, + OpEngineType engine_type, ModelBufferData &model_buff); + Status BuildSingleOpModel(OpDescPtr &op_desc, const std::vector &inputs, + const std::vector &outputs, + OpEngineType engine_type, int32_t compile_flag, ModelBufferData &model_buff); + Status BuildSingleOpModel(OpDescPtr &op_desc, const std::vector &inputs, + const std::vector &outputs, OpEngineType engine_type, + int32_t compile_flag, ModelBufferData &model_buff, + GraphStage graph_stage, ComputeGraphPtr &compute_graph); + + /// + /// @ingroup ge + /// @brief: Build single Op into graph. + /// @param [in] op_desc: the OP description. + /// @param [in] inputs: input tensors. + /// @param [in] outputs: output tensors. + /// @param [in] graph_name: graph name. + /// @param [out] graph: graph of single op. 
+ /// @return SUCCESS or FAILED + Status BuildSingleOpGraph(const OpDescPtr &op_desc, const InOutTensorRef &inputs_outputs, + std::string graph_name, Graph &graph, + std::vector> &inputs_name_type) const; + Status BuildOriginalGraphInfo(OpDescPtr &op_desc, const std::vector &inputs, + const std::vector &outputs, const std::string &model_file_name, + bool is_offline, GraphStage graph_stage, Graph &graph, ComputeGraphPtr &compute_graph, + std::vector> &inputs_name_type); + // temp solution, ge_ir should share one session between some generators, delete after ge_ir use session manager + Status SetCurrentSessionId(const uint64_t session_id) const; + // temp solution, ge_ir should share one rebuild_ctrl between some generators, delete after ge_ir use session manager + Status SetExternalGraphRebuildStateCtrl(void *rebuild_ctrl) const; + + private: + Status GenerateModel(const Graph &graph, const std::string &file_name_prefix, + const std::vector &inputs, + ge::ModelBufferData &model, bool is_offline = true, + OfflineModelFormat om_format = OM_FORMAT_DEFAULT); + Status BuildSingleOp(OpDescPtr &op_desc, const std::vector &inputs, const std::vector &outputs, + const std::string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, + ComputeGraphPtr &comp_graph, bool is_offline = true, int32_t compile_flag = 0, + GraphStage graph_stage = GraphStage::GRAPH_STAGE_RESERVED); + static Status CheckEngineTypeSupport(const NodePtr &node, OpEngineType engine_type); + static void RemoveConst(const std::vector &inputs, std::vector &outputs); + static Status CheckForSingleOp(const OpDescPtr &op_desc, const std::vector &inputs, + const std::vector &outputs); + static Status InferFormatForSingleOp(const OpDescPtr &op_desc, const Graph &graph); + static Status ResetAiCpuToDynamicShape(const ComputeGraphPtr &graph); + + using GeRootModelPtr = std::shared_ptr; + static Status SetModelNameForDump(const GeRootModelPtr &ge_root_model); + static Status 
CreateGeneralizedBuildAttrs(const GeRootModelPtr &ge_root_model, + const std::vector &inputs, + const std::vector &outputs, + const std::vector> &inputs_name_type, + std::vector &generalized_build_attrs); + static void AddExcludeEnginesOption(const OpDescPtr &op_desc, std::map &graph_options); + void AddShapeGeneralizedOption(std::map &graph_options) const; + void SetFuzzCompile(const std::vector &inputs, int32_t compile_flag) const; + bool IsFuzzCompileEnable() const; + void ConvertOpInfosToOptions(const OpDescPtr &op_desc) const; + static Status ResetInputOutputShape(const ComputeGraphPtr &graph, + const std::vector> &inputs_name_type, + std::vector &inputs_dynamic, + std::vector &outputs_dynamic); + static Status ResetOutputShapeRange(const OpDescPtr &op_desc, const size_t index, + std::vector> &shape_range); + static Status ResetTensorDesc(const size_t index, const GeShape &data_shape, std::vector &vector_dynamic, + std::vector> &dynamic_shape_range); + + class Impl; + std::shared_ptr impl_; +}; +} // namespace ge + +#endif // INC_FRAMEWORK_GENERATOR_GE_GENERATOR_H_ diff --git a/inc/air/framework/memory/allocator_desc.h b/inc/air/framework/memory/allocator_desc.h new file mode 100644 index 0000000000000000000000000000000000000000..d66236ff84e8d6c61f33291131b964cf506b83e7 --- /dev/null +++ b/inc/air/framework/memory/allocator_desc.h @@ -0,0 +1,30 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_MEMORY_ALLOCATOR_DESC_H_ +#define INC_FRAMEWORK_MEMORY_ALLOCATOR_DESC_H_ +#include + +namespace ge { +using AllocFunc = void *(*)(void *obj, size_t size); +using FreeFunc = void (*)(void *obj, void *block); +using AllocAdviseFunc = void *(*)(void *obj, size_t size, void *addr); +using GetAddrFromBlockFunc = void *(*)(void *block); + +using AllocatorDesc = struct AllocatorDesc { + AllocFunc alloc_func; + FreeFunc free_func; + AllocAdviseFunc alloc_advise_func; + GetAddrFromBlockFunc get_addr_from_block_func; + + void *obj; +}; +} // namespace ge + +#endif // INC_FRAMEWORK_MEMORY_ALLOCATOR_DESC_H_ diff --git a/inc/air/framework/memory/memory_api.h b/inc/air/framework/memory/memory_api.h new file mode 100644 index 0000000000000000000000000000000000000000..104b07429d1e2a9afedf4532a272e651cf6e98ef --- /dev/null +++ b/inc/air/framework/memory/memory_api.h @@ -0,0 +1,83 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_MEMORY_MEMORY_API_H_ +#define INC_FRAMEWORK_MEMORY_MEMORY_API_H_ + +#include "external/ge/ge_api_error_codes.h" +#include "graph/compute_graph.h" +#include "runtime/mem.h" + +namespace ge { +enum MemStorageType { + HBM = 0, + RDMA_HBM, + HOST_DDR, + HOST_SVM, +}; + +struct HostVarInfo { + uint64_t base_addr; + uint64_t var_size; +}; + +struct TensorInfo { + std::string var_name; + std::vector dims; + DataType data_type; +}; + +/// \param size [in] rdma pool memory size to be allocated. +/// \param mem_type [in] memory type for rdma pool. +/// \return Status result of function +GE_FUNC_VISIBILITY Status InitRdmaPool(size_t size, rtMemType_t mem_type = RT_MEMORY_HBM); + +/// \param var_info [in] host variable addr infos. +/// \param mem_type [in] memory type for rdma pool. +/// \return Status result of function +GE_FUNC_VISIBILITY Status RdmaRemoteRegister(const std::vector &var_info, + rtMemType_t mem_type = RT_MEMORY_HBM); + +/// \param tensor_info [in] description for tensor stored shared memory. +/// \param dev_addr [out] malloced shared memory addr. +/// \param memory_size [out] malloced shared memory size. +/// \return Status result of function +GE_FUNC_VISIBILITY Status MallocSharedMemory(const TensorInfo &tensor_info, uint64_t &dev_addr, uint64_t &memory_size); + +/// \param var_name [in] var_name name of host variable. +/// \param base_addr [out] base_addr base addr of host variable. +/// \param var_size [out] var_size memory_size of host variable. +/// \return Status result of function +GE_FUNC_VISIBILITY Status GetVarBaseAddrAndSize(const std::string &var_name, uint64_t &base_addr, uint64_t &var_size); + +/// \param var_name [in] var_name name of host variable. +/// \param base_addr [out] base_addr base addr of host variable. +/// \param var_size [out] var_size memory_size of host variable.
+/// \return Status result of function +GE_FUNC_VISIBILITY Status GetVarBaseAddrAndSize(const char_t *var_name, + uint64_t &base_addr, + uint64_t &var_size); + +/* + * @brief + * @param [in] graph + * @param [out] available_mem:graph mem size + */ +GE_FUNC_VISIBILITY Status GetGraphAvailableMemory(const ComputeGraphPtr &graph, uint64_t &available_mem); + +/* + * @brief + * @param [in] session_id + * @param [out] var_size:session variables mem size + * @param [out] graphs_mem_info: graphs mem info, include key:graph_id; value: {feature_map_size, const_size} + */ +GE_FUNC_VISIBILITY Status GetSessionMemInfo(const uint64_t session_id, uint64_t &var_size, + std::map> &graphs_mem_info); +} // namespace ge +#endif // INC_FRAMEWORK_MEMORY_MEMORY_API_H_ diff --git a/inc/air/framework/memory/memory_assigner.h b/inc/air/framework/memory/memory_assigner.h new file mode 100644 index 0000000000000000000000000000000000000000..7a950941dab9ec385527d202221afa002e968a0b --- /dev/null +++ b/inc/air/framework/memory/memory_assigner.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_MEMORY_MEMORY_ASSIGNER_H_ +#define INC_FRAMEWORK_MEMORY_MEMORY_ASSIGNER_H_ + +#include + +#include "graph/build/memory/graph_mem_assigner.h" +#include "framework/common/ge_inner_error_codes.h" +#include "graph/node.h" + + +namespace ge { +class GE_FUNC_VISIBILITY MemoryAssigner { + public: + explicit MemoryAssigner(ge::ComputeGraphPtr compute_graph) : compute_graph_(std::move(compute_graph)) {} + virtual ~MemoryAssigner() = default; + + MemoryAssigner(const MemoryAssigner &) = delete; + + MemoryAssigner &operator=(const MemoryAssigner &) = delete; + + Status AssignMemory(std::map &mem_offsets, size_t &zero_copy_mem_size, + const bool has_assigned_var_mem = false); + + const std::vector> &GetSubMemOffsets() const { + return sub_mem_offsets_; + } + + const std::shared_ptr GetGraphMemoryAssigner() const { + return graph_mem_assigner_; + } + + private: + ge::ComputeGraphPtr compute_graph_; + std::vector> sub_mem_offsets_; + std::shared_ptr graph_mem_assigner_; +}; +} // namespace ge +#endif // INC_FRAMEWORK_MEMORY_MEMORY_ASSIGNER_H_ diff --git a/inc/air/framework/omg/ge_init.h b/inc/air/framework/omg/ge_init.h new file mode 100644 index 0000000000000000000000000000000000000000..a0ac859fb9c11c6f20ec6f9e904f28eb58134349 --- /dev/null +++ b/inc/air/framework/omg/ge_init.h @@ -0,0 +1,29 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_OMG_GE_INIT_H_ +#define INC_FRAMEWORK_OMG_GE_INIT_H_ +#include +#include +#include "framework/common/ge_inner_error_codes.h" + +namespace ge { +class GE_FUNC_VISIBILITY GEInit { + public: + // GE Environment Initialize, return Status: SUCCESS,FAILED + static Status Initialize(const std::map &options); + + static std::string GetPath(); + + // GE Environment Finalize, return Status: SUCCESS,FAILED + static Status Finalize(); +}; +} // namespace ge + +#endif // INC_FRAMEWORK_OMG_GE_INIT_H_ diff --git a/inc/air/framework/omg/omg.h b/inc/air/framework/omg/omg.h new file mode 100644 index 0000000000000000000000000000000000000000..eb737b3cf71cf05cbae8f91ee901a3c977ccfc84 --- /dev/null +++ b/inc/air/framework/omg/omg.h @@ -0,0 +1,111 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_OMG_OMG_H_ +#define INC_FRAMEWORK_OMG_OMG_H_ + +#include +#include +#include + +#include +#include "external/ge/ge_api_types.h" +#include "framework/omg/omg_inner_types.h" +#include "framework/omg/parser/parser_inner_ctx.h" +#include "proto/ge_ir.pb.h" +#include "proto/om.pb.h" + +#include "graph/compute_graph.h" +#include "graph/graph.h" +#include "graph/model.h" +#include "runtime/kernel.h" + +namespace ge { +/** + * @ingroup domi_omg + * @brief init omg context + * @return void + */ +GE_FUNC_VISIBILITY domi::Status InitDomiOmgContext(const std::string &input_shape, const std::string &input_format, + const std::string &net_format, bool is_dynamic_input); + +/** + * @ingroup domi_omg + * @brief generate graph based on the input model file and weight file + * @param [out] graph graph + * @param [in] model_file path of model file + * @param [in] weights_file path of weight file + * @param [in] type type of the input model + * @param [in] op_conf op mapping configuration + * @param [in] target type of platform. 
If a tiny model is generated, set target to tiny + * @param [in] run_mode run model + * @param [in] enable_l2dynamic enable l2dynamic + * @param [in] is_dynamic_input dynamic input, true or false + * @param [in] atc_params multiple atc params + * @return Status result code + */ +GE_FUNC_VISIBILITY domi::Status ParseGraph(ge::Graph &graph, const std::map &atc_params, + const char *model_file, const char *weights_file, domi::FrameworkType type, + const char *op_conf = nullptr, const char *target = nullptr, + RunMode run_mode = RunMode::GEN_OM_MODEL, bool is_dynamic_input = false); + + +/** + * @ingroup domi_omg + * @brief generates a simplified JSON file based on the key value of the offline model file in protobuf format + * @param [in] model_file path of offline model file + * @param [out] json_file path of json file + * @param [key] encrypted key + * @return Status result code + */ +GE_FUNC_VISIBILITY domi::Status ConvertOm(const char *model_file, const char *json_file, bool is_covert_to_json); + +GE_FUNC_VISIBILITY domi::Status ConvertPbtxtToJson(const char *model_file, const char *json_file); +/** + * @ingroup domi_omg + * @brief convert the model file in protobuf format into a JSON file. + * @param [in] framework type of model + * @param [in] om model_file path of offline model file + * @param [out] json_file path of json file + * @param [key] encrypted key + * @return Status result code + */ +GE_FUNC_VISIBILITY domi::Status ConvertFwkModelToJson(const domi::FrameworkType framework, const char *model_file, + const char *json_file); + +/** + * @ingroup domi_omg + * @brief convert the model file into an execute-om file.
+ * @param [in] om model_file path of offline model file + * @param [out] output path of execute-om file + * @return Status result code + */ +GE_FUNC_VISIBILITY void GetGroupName(ge::proto::ModelDef &model_def); + +GE_FUNC_VISIBILITY void FindParserSo(const std::string &path, std::vector &file_list, + std::string &caffe_parser_path); + +GE_FUNC_VISIBILITY domi::Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file); + +GE_FUNC_VISIBILITY domi::Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type); + +GE_FUNC_VISIBILITY domi::Status GetOutputLeaf(ge::NodePtr node, + std::vector> &output_nodes_info); + +GE_FUNC_VISIBILITY void CreateOutputNodesInfo(std::vector> &output_nodes_info, + std::vector &output_nodes_name); + +GE_FUNC_VISIBILITY void UpdateOmgCtxWithParserCtx(); + +GE_FUNC_VISIBILITY void UpdateParserCtxWithOmgCtx(); + +GE_FUNC_VISIBILITY void PrintModelInfo(ge::proto::ModelDef *model_def, uint32_t modeldef_size); +} // namespace ge + +#endif // INC_FRAMEWORK_OMG_OMG_H_ diff --git a/inc/air/framework/omg/omg_inner_types.h b/inc/air/framework/omg/omg_inner_types.h new file mode 100644 index 0000000000000000000000000000000000000000..d739bda23de48e73f800cbfb6df791b8899b833b --- /dev/null +++ b/inc/air/framework/omg/omg_inner_types.h @@ -0,0 +1,123 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_OMG_OMG_INNER_TYPES_H_ +#define INC_FRAMEWORK_OMG_OMG_INNER_TYPES_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "framework/common/fmk_error_codes.h" +#include "register/register_fmk_types.h" +#include "graph/node.h" + +using domi::DOMI_TENSOR_ND; +using domi::domiTensorFormat_t; +using domi::DOMI_TENSOR_RESERVED; +using std::unordered_map; + +namespace ge { +/** + * @ingroup domi_omg + * @brief run model + */ +enum RunMode { + GEN_OM_MODEL = 0, // generate offline model file + MODEL_TO_JSON = 1, // convert to JSON file + ONLY_PRE_CHECK = 3, // only for pre-check + PBTXT_TO_JSON = 5, // pbtxt to json + DISPLAY_OM_INFO = 6, // display model info + GEN_EXE_OM = 10, // generate exe-om file + MODEL_TO_EXE_OM = 20, // convert om to exe-om file + GEN_EXE_OM_FOR_NANO = 30 // convert om to exe-om file for nano +}; + +struct OmgContext { + OmgContext() : format(domi::DOMI_TENSOR_ND) {} + domi::domiTensorFormat_t format; + + // format of the input specified by the command line + std::unordered_map input_nodes_format_map; + std::vector output_formats; + + // user-designate input dims + std::vector>> user_input_dims; + // global input dims + std::map> input_dims; + + // resolve the mapping between operators with the same name and corresponding network. format e.g. + // Detectionoutput:SsdDetectiontOutput + std::map op_conf_map; + // save the output node of the network. key = operator name, value = index, index indicates the output index of the + // operator + std::map> out_nodes_map; + // user-designate out nodes (this is used for determining the order) + std::vector> user_out_nodes; + // default out nodes (this is used for determining the order) + std::vector> default_out_nodes; + // save the output node of the network, value = topName, + // tensorName indicates the output name of the operator.
+ std::vector user_out_tensors; + // net out nodes (where user_out_nodes or leaf nodes) + std::vector net_out_nodes; + // net out nodes tensor names(caffe or onnx) + std::vector out_tensor_names; + // net data nodes tensor names(caffe or onnx) + std::vector data_tensor_names; + // preferential format used by the entire network + domi::domiTensorFormat_t net_format = domi::DOMI_TENSOR_RESERVED; + domi::FrameworkType type = domi::FRAMEWORK_RESERVED; + RunMode run_mode = RunMode::ONLY_PRE_CHECK; + bool train_flag = false; + + std::string output_type; + + // Whether to use dynamic batch size or dynamic image size + bool is_dynamic_input = false; + std::string dynamic_batch_size; + std::string dynamic_image_size; + std::string dynamic_dims; + std::string dynamic_node_type; + bool need_multi_batch = false; + std::vector data_nodes; + std::vector getnext_nosink_nodes; + bool fuzz_compile_flag = false; + std::string atc_cmdline; + bool user_attr_index_valid = false; + bool is_online_model = false; + bool is_subgraph_multi_batch = false; + // dynamic dims + std::vector> batch_shapes; + // name of the pass that needs to take effect + std::string enable_scope_fusion_passes; +}; +} // namespace ge + +namespace domi { +/** + * @ingroup domi_omg + * @brief get OMG context + * @return OmgContext context + */ +GE_FUNC_VISIBILITY ge::OmgContext &GetContext(); + +struct TEBinInfo { + // It is obsolete. It will be automatically obtained from the binfilename field of the JSON file later. 
+ // To be compatible with use cases written by previous users, fields are not deleted.(2018.11.21) + std::string bin_file_path; + std::string json_file_path; + std::string ddk_version; +}; +} // namespace domi + +#endif // INC_FRAMEWORK_OMG_OMG_INNER_TYPES_H_ diff --git a/inc/air/framework/omg/omg_types.h b/inc/air/framework/omg/omg_types.h new file mode 100644 index 0000000000000000000000000000000000000000..9f8eabe9231fb637adbc50a22935e7c1845344fb --- /dev/null +++ b/inc/air/framework/omg/omg_types.h @@ -0,0 +1,15 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_OMG_OMG_TYPES_H_ +#define INC_FRAMEWORK_OMG_OMG_TYPES_H_ + +#include "register/register_fmk_types.h" + +#endif // INC_FRAMEWORK_OMG_OMG_TYPES_H_ diff --git a/inc/air/framework/omg/parser/model_parser.h b/inc/air/framework/omg/parser/model_parser.h new file mode 100644 index 0000000000000000000000000000000000000000..4c0b39b54625905da9c194d5797f362b7fbc76a1 --- /dev/null +++ b/inc/air/framework/omg/parser/model_parser.h @@ -0,0 +1,242 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. 
You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_OMG_PARSER_MODEL_PARSER_H_ +#define INC_FRAMEWORK_OMG_PARSER_MODEL_PARSER_H_ + +#include +#include "framework/omg/parser/parser_types.h" +#include "framework/omg/omg_inner_types.h" +#include "graph/attr_value.h" +#include "graph/compute_graph.h" +#include "graph/ge_tensor.h" +#include "graph/graph.h" +#include "graph/op_desc.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/graph_utils_ex.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" +#include "graph/ascend_string.h" + +namespace domi { +using GetGraphCallback = std::function( + const google::protobuf::Message *root_proto, const std::string &graph)>; + +using GetGraphCallbackV2 = std::function; + +using GetGraphCallbackV3 = std::function; + +class GE_FUNC_VISIBILITY ModelParser { + public: + ModelParser() {} + + virtual ~ModelParser() {} + + /** + * @ingroup domi_omg + * @brief Analyze network model data + * @param [in] file Network model file path + * @param [in|out] graph Save the network information after analysis + * @return SUCCESS + * @return Others failed + */ + virtual Status Parse(const char *file, ge::Graph &graph) = 0; + + /** + * @ingroup domi_omg + * @brief Parse relevant data from memory and save it to graph + * @param [in] input Model file memory data + * @param [in] input Model file memory size + * @param [in|out] graph A graph for saving the model information after analysis + * @return SUCCESS + * @return 
FAILED + * @author + */ + virtual Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) = 0; + + /** + * @ingroup domi_omg + * @brief Parse relevant data from memory and save it to graph + * @param [in] input Model file memory data + * @param [in] input Model file memory size + * @param [in|out] graph A graph for saving the model information after analysis + * @return SUCCESS + * @return FAILED + * @author + */ + virtual Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) = 0; + + /** + * @ingroup domi_omg + * @brief Analyze network model data + * @param [in] proto network model + * @param [in|out] graph Save the network information after analysis + * @return SUCCESS + * @return Others failed + */ + virtual Status ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) = 0; + + /** + * @ingroup domi_omg + * @brief Analyze callback model data in subgraph + * @param [in] proto network model + * @param [in] callback callback of subgraph + * @param [in|out] graph Save the network information after analysis + * @return SUCCESS + * @return Others failed + */ + virtual Status ParseProtoWithSubgraph(const google::protobuf::Message *proto, GetGraphCallback callback, + ge::ComputeGraphPtr &graph) = 0; + /** + * @ingroup domi_omg + * @brief Convert model files to JSON format + * @param [in] model_file Model file path to be converted + * @param [out] json_file Converted JSON file path + * @return SUCCESS + * @return Others failed + */ + virtual Status ToJson(const char *model_file, const char *json_file) { + (void)model_file; + (void)json_file; + return SUCCESS; + } + + /* + * @ingroup domi_omg + * @brief Convert network data type + * @param [in] type Data type to be converted + * @return ge::DataType + */ + virtual ge::DataType ConvertToGeDataType(const uint32_t type) = 0; + + virtual Status ParseAllGraph(const google::protobuf::Message *root_proto, ge::ComputeGraphPtr &root_graph) = 0; + + /** + * 
@ingroup domi_omg + * @brief Analyze network model data + * @param [in] proto serialized network model + * @param [in|out] graph Save the network information after analysis + * @return SUCCESS + * @return Others failed + */ + virtual Status ParseProto(const std::string &serialized_proto, ge::ComputeGraphPtr &graph) { + (void)serialized_proto; + (void)graph; + return UNSUPPORTED; + } + + /** + * @ingroup domi_omg + * @brief Analyze callback model data in subgraph + * @param [in] proto serialized network model + * @param [in] callback callback of subgraph + * @param [in|out] graph Save the network information after analysis + * @return SUCCESS + * @return Others failed + */ + virtual Status ParseProtoWithSubgraph(const std::string &serialized_proto, GetGraphCallbackV2 callback, + ge::ComputeGraphPtr &graph) { + (void)serialized_proto; + (void)callback; + (void)graph; + return UNSUPPORTED; + } + + /** + * @ingroup domi_omg + * @brief Analyze callback model data in subgraph + * @param [in] partitioned_serialized partitioned serialized network model + * @param [in] const_value_map const value map, key: constant node name value: serialized constant output tensor + * @param [in] callback callback of subgraph + * @param [in|out] graph Save the network information after analysis + * @return SUCCESS + * @return Others failed + */ + virtual Status ParseProtoWithSubgraph(const std::vector &partitioned_serialized, + const std::map &const_value_map, + GetGraphCallbackV2 callback, + ge::ComputeGraphPtr &graph) { + (void)partitioned_serialized; + (void)const_value_map; + (void)callback; + (void)graph; + return UNSUPPORTED; + } + + /** + * @ingroup domi_omg + * @brief Analyze network model data + * @param [in] partitioned_serialized partitioned serialized network model + * @param [in] const_value_map const value map, key: constant node name value: serialized constant output tensor + * @param [in|out] graph Save the network information after analysis + * @return SUCCESS + * @return 
Others failed + */ + virtual Status ParseProto(const std::vector &partitioned_serialized, + const std::map &const_value_map, + ge::ComputeGraphPtr &graph) { + (void)partitioned_serialized; + (void)const_value_map; + (void)graph; + return UNSUPPORTED; + } + + /** + * @ingroup domi_omg + * @brief Analyze callback model data in subgraph + * @param [in] partitioned_serialized partitioned serialized network model + * @param [in] const_value_map const value map, key: constant node name value: serialized constant output tensor + * @param [in] callback callback of subgraph + * @param [in|out] graph Save the network information after analysis + * @return SUCCESS + * @return Others failed + */ + virtual Status ParseProtoWithSubgraph(const std::vector &partitioned_serialized, + const std::map &const_value_map, + GetGraphCallbackV3 callback, + ge::ComputeGraphPtr &graph) { + (void)partitioned_serialized; + (void)const_value_map; + (void)callback; + (void)graph; + return UNSUPPORTED; + } + + /** + * @ingroup domi_omg + * @brief Analyze callback model data in subgraph + * @param [in] proto serialized network model + * @param [in] callback callback of subgraph + * @param [in|out] graph Save the network information after analysis + * @return SUCCESS + * @return Others failed + */ + virtual Status ParseProtoWithSubgraph(const ge::AscendString &serialized_proto, GetGraphCallbackV3 callback, + ge::ComputeGraphPtr &graph) { + (void)serialized_proto; + (void)callback; + (void)graph; + return UNSUPPORTED; + } + + virtual bool HasError() { + return false; + } + + virtual Status Save(const std::string &file) { + (void)file; + return SUCCESS; + } + + virtual void Clear() {}; +}; +} // namespace domi + +#endif // INC_FRAMEWORK_OMG_PARSER_MODEL_PARSER_H_ diff --git a/inc/air/framework/omg/parser/op_parser.h b/inc/air/framework/omg/parser/op_parser.h new file mode 100644 index 0000000000000000000000000000000000000000..da9479ace63639f5072b067fd121e818522cc500 --- /dev/null +++ 
b/inc/air/framework/omg/parser/op_parser.h @@ -0,0 +1,82 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_OMG_PARSER_OP_PARSER_H_ +#define INC_FRAMEWORK_OMG_PARSER_OP_PARSER_H_ + +#include +#include "framework/omg/parser/parser_types.h" +#include "framework/omg/omg_inner_types.h" +#include "proto/om.pb.h" +#include "graph/utils/op_desc_utils.h" + +using google::protobuf::Message; + +namespace ge { +/** + * @ingroup domi_omg + * @brief Used to analyze operator information + * + */ +class GE_FUNC_VISIBILITY OpParser { + public: + /** + * @ingroup domi_omg + * @brief Destructor + */ + virtual ~OpParser() {} + + /** + * @ingroup domi_omg + * @brief Analytic operator parameters + * @param [in] op_src Parameter data to be resolved + * @param [out] op_desc Parsed parameter data + * @return SUCCESS + * @return FAILED + */ + virtual domi::Status ParseParams(const google::protobuf::Message *op_src, ge::OpDescPtr &op_desc) = 0; + + /** + * @ingroup domi_omg + * @brief Analytic operator parameters + * @param [in] op_src Parameter data to be resolved + * @param [out] op_dest Operator parameter data + * @return SUCCESS + * @return FAILED + */ + virtual domi::Status ParseParams(const google::protobuf::Message *op_src, ge::Operator &op_dest) = 0; + + /** + * @ingroup domi_omg + * @brief Analytic
operator weight information + * @param [in] op_src Weight data to be resolved + * @param [out] op_dest Weight data after analysis + * @return SUCCESS + * @return FAILED + */ + virtual domi::Status ParseWeights(const google::protobuf::Message *op_src, ge::NodePtr &node) = 0; + + /** + * @ingroup domi_omg + * @brief Get the format information according to the parameters in the operator + * @param [in] op_src Parameter data to be resolved + * @param [out] format Output the parsed format + * @return SUCCESS + * @return FAILED + */ + virtual domi::Status GetFormat(const google::protobuf::Message *op_src, domi::domiTensorFormat_t &format) { + (void)op_src; + // Indicates that the op does not provide a value for format + format = domi::DOMI_TENSOR_RESERVED; + return domi::SUCCESS; + } +}; +} // namespace ge + +#endif // INC_FRAMEWORK_OMG_PARSER_OP_PARSER_H_ diff --git a/inc/air/framework/omg/parser/parser_api.h b/inc/air/framework/omg/parser/parser_api.h new file mode 100644 index 0000000000000000000000000000000000000000..aca844db7a05a3b6bf3a0cd7ea26f232b4ed1577 --- /dev/null +++ b/inc/air/framework/omg/parser/parser_api.h @@ -0,0 +1,26 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_OMG_PARSER_PARSER_API_H_ +#define INC_FRAMEWORK_OMG_PARSER_PARSER_API_H_ + +#include +#include +#include "external/ge/ge_api_error_codes.h" +#include "graph/ascend_string.h" + +namespace ge { +// Initialize parser +GE_FUNC_VISIBILITY Status ParserInitialize(const std::map& options); +// Finalize parser, release all resources +GE_FUNC_VISIBILITY Status ParserFinalize(); +// Initialize parser +GE_FUNC_VISIBILITY Status ParserInitialize(const std::map& options); +} // namespace ge +#endif // INC_FRAMEWORK_OMG_PARSER_PARSER_API_H_ diff --git a/inc/air/framework/omg/parser/parser_factory.h b/inc/air/framework/omg/parser/parser_factory.h new file mode 100644 index 0000000000000000000000000000000000000000..d8fa71c183a4a629c50ff099e9f6030b11a41e81 --- /dev/null +++ b/inc/air/framework/omg/parser/parser_factory.h @@ -0,0 +1,136 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_OMG_PARSER_PARSER_FACTORY_H_ +#define INC_FRAMEWORK_OMG_PARSER_PARSER_FACTORY_H_ + +#include +#include +#include +#include +#include "framework/omg/omg_inner_types.h" +#include "framework/omg/parser/parser_types.h" +#include "external/register/register.h" + +namespace domi { +class WeightsParser; +class ModelParser; + +using MODEL_PARSER_CREATOR_FUN = std::shared_ptr (*)(void); + +// Create modelparser for different frameworks +class GE_FUNC_VISIBILITY ModelParserFactory { + public: + static ModelParserFactory *Instance(); + + /** + * @ingroup domi_omg + * @brief Create a modelparser based on the type entered + * @param [in] type Framework type + * @return Created modelparser + */ + std::shared_ptr CreateModelParser(const domi::FrameworkType type); + + /** + * @ingroup domi_omg + * @brief Register create function + * @param [in] type Framework type + * @param [in] fun ModelParser's create function + */ + void RegisterCreator(const domi::FrameworkType type, MODEL_PARSER_CREATOR_FUN fun); + + protected: + ModelParserFactory() {} + ~ModelParserFactory(); + + private: + std::map creator_map_; +}; // end class ModelParserFactory + +class GE_FUNC_VISIBILITY ModelParserRegisterar { + public: + ModelParserRegisterar(const domi::FrameworkType type, MODEL_PARSER_CREATOR_FUN const fun) noexcept { + ModelParserFactory::Instance()->RegisterCreator(type, fun); + } + ~ModelParserRegisterar() {} +}; + +// Registration macros for model parsers +#define REGISTER_MODEL_PARSER_CREATOR(type, clazz) \ + std::shared_ptr Creator_##type##_Model_Parser() { \ + std::shared_ptr ptr = nullptr; \ + try { \ + ptr = make_shared(); \ + } catch (...) 
{ \ + ptr = nullptr; \ + } \ + return std::shared_ptr(ptr); \ + } \ + ModelParserRegisterar g_##type##_Model_Parser_Creator(type, Creator_##type##_Model_Parser) + +using WEIGHTS_PARSER_CREATOR_FUN = std::shared_ptr (*)(void); + +// Create weightsparser for different frameworks +class GE_FUNC_VISIBILITY WeightsParserFactory { + public: + static WeightsParserFactory *Instance(); + + /** + * @ingroup domi_omg + * @brief Create weightsparser based on the type entered + * @param [in] type Framework type + * @return Created weightsparser + */ + std::shared_ptr CreateWeightsParser(const domi::FrameworkType type); + + /** + * @ingroup domi_omg + * @brief Register create function + * @param [in] type Framework type + * @param [in] fun WeightsParser's create function + */ + void RegisterCreator(const domi::FrameworkType type, WEIGHTS_PARSER_CREATOR_FUN fun); + + protected: + WeightsParserFactory() {} + ~WeightsParserFactory(); + + private: + std::map creator_map_; +}; // end class WeightsParserFactory + +class GE_FUNC_VISIBILITY WeightsParserRegisterar { + public: + WeightsParserRegisterar(const domi::FrameworkType type, WEIGHTS_PARSER_CREATOR_FUN const fun) noexcept { + WeightsParserFactory::Instance()->RegisterCreator(type, fun); + } + ~WeightsParserRegisterar() {} +}; + +// Register macro of weight resolver +#define REGISTER_WEIGHTS_PARSER_CREATOR(type, clazz) \ + std::shared_ptr Creator_##type##_Weights_Parser() { \ + std::shared_ptr ptr = nullptr; \ + try { \ + ptr = make_shared(); \ + } catch (...) 
{ \ + ptr = nullptr; \ + } \ + return std::shared_ptr(ptr); \ + } \ + WeightsParserRegisterar g_##type##_Weights_Parser_Creator(type, Creator_##type##_Weights_Parser) + +class GE_FUNC_VISIBILITY OpRegTbeParserFactory { + public: + static OpRegTbeParserFactory *Instance(); + void Finalize(const domi::OpRegistrationData ®_data); +}; +} // namespace domi + +#endif // INC_FRAMEWORK_OMG_PARSER_PARSER_FACTORY_H_ diff --git a/inc/air/framework/omg/parser/parser_inner_ctx.h b/inc/air/framework/omg/parser/parser_inner_ctx.h new file mode 100644 index 0000000000000000000000000000000000000000..80bb1ee8fca119002e695f4aa2562eb34f0d1d02 --- /dev/null +++ b/inc/air/framework/omg/parser/parser_inner_ctx.h @@ -0,0 +1,66 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_OMG_PARSER_PARSER_INNER_CONTEXT_H_ +#define INC_FRAMEWORK_OMG_PARSER_PARSER_INNER_CONTEXT_H_ + +#include +#include +#include +#include +#include +#include "external/register/register_fmk_types.h" +#include "external/register/register_types.h" +#include "framework/omg/omg_inner_types.h" + +namespace ge { +struct ParserContext { + // format of the input specified by the command line + std::unordered_map input_nodes_format_map; + std::vector output_formats; + // user-designate input dims + std::vector>> user_input_dims; + std::map> input_dims; + // resolve the mapping between operators with the same name and corresponding network. format e.g. + // Detectionoutput:SsdDetectiontOutput + std::map op_conf_map; + // user-designate out nodes (this is used for determing the orders) + std::vector> user_out_nodes; + // default out nodes (this is used for determing the orders) + std::vector> default_out_nodes; + // save the output node of the network. key = operator name, value = index, index indicates the output index of the + // operator + std::map> out_nodes_map; + // save the output node of the network, value = topName, + // tensorName indicates the output name of the operator. 
+ std::vector user_out_tensors; + // net out nodes (where user_out_nodes or leaf nodes) + std::vector net_out_nodes; + // net out nodes tensor names(caffe or onnx) + std::vector out_tensor_names; + // net data nodes tensor names(caffe or onnx) + std::vector data_tensor_names; + // Whether to use dynamic batch size or dynamic image size + bool is_dynamic_input = false; + bool train_flag = false; + domi::domiTensorFormat_t format = domi::DOMI_TENSOR_ND; + domi::FrameworkType type = domi::FRAMEWORK_RESERVED; + RunMode run_mode = RunMode::GEN_OM_MODEL; + // save caffe custom proto path, used by caffe parse + std::string custom_proto_path; + // save caffe proto path, used by caffe parse + std::string caffe_proto_path; + // name of the pass that needs to take effect + std::string enable_scope_fusion_passes; +}; + +GE_FUNC_VISIBILITY ParserContext &GetParserContext(); +} // namespace ge + +#endif // INC_FRAMEWORK_OMG_PARSER_PARSER_INNER_CONTEXT_H_ diff --git a/inc/air/framework/omg/parser/parser_types.h b/inc/air/framework/omg/parser/parser_types.h new file mode 100644 index 0000000000000000000000000000000000000000..dce163f87486d9b633665d9aa7f1466af66a2130 --- /dev/null +++ b/inc/air/framework/omg/parser/parser_types.h @@ -0,0 +1,508 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_FRAMEWORK_OMG_PARSER_TYPES_H_ +#define INC_FRAMEWORK_OMG_PARSER_TYPES_H_ + +#include +#include + +#include "register/register_types.h" +#include "graph/types.h" + +#if !defined(__ANDROID__) && !defined(ANDROID) +#ifndef DOMI_DYNAMIC_CAST +#define DOMI_DYNAMIC_CAST static_cast +#endif +#ifndef DOMI_DYNAMIC_POINTER_CAST +#define DOMI_DYNAMIC_POINTER_CAST std::static_pointer_cast +#endif +#else +#ifndef DOMI_DYNAMIC_CAST +#define DOMI_DYNAMIC_CAST static_cast +#endif +#ifndef DOMI_DYNAMIC_POINTER_CAST +#define DOMI_DYNAMIC_POINTER_CAST std::static_pointer_cast +#endif +#endif + +namespace ge { +namespace parser { +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const DATA; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const AIPPDATA; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CONVOLUTION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CORRELATION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CORRELATIONV2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const DECONVOLUTION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const POOLING; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ELTWISE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const RELU; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const RELU6; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SIGMOID; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ABSVAL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const TANH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const 
ge::char_t *const PRELU; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const BATCHNORM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FUSIONBATCHNORM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SCALE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FULL_CONNECTION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SOFTMAX; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const PLUS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ACTIVATION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FLATTEN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ADD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SUB; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const MUL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const MATMUL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const RSQRT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const BIASADD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const RESHAPE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const REFORMAT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const DEPCONVOLUTION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const DROPOUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const DROPOUTGENMASK; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const DROPOUTDOMASK; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CONCAT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const 
ge::char_t *const ROIPOOLING; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const PROPOSAL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FSRDETECTIONOUTPUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const DETECTIONPOSTPROCESS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const LRN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const TRANSDATA; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const PERMUTE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SSDNORMALIZE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SSDPRIORBOX; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const NETOUTPUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SSDDETECTIONOUTPUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const REFINEDETDETECTIONOUTPUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CHANNELAXPY; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const PSROIPOOLING; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const POWER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const POW; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ROIALIGN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const PYTHON; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FREESPACEEXTRACT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SPATIALTF; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SHAPE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SHAPEN; 
+FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ARGMAX; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const GATHERND; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const GATHER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const REALDIV; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const PACK; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SLICE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SLICED; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FLOORDIV; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SQUEEZE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const UNSQUEEZE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const STRIDEDSLICE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const RANGE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const RPNPROPOSALS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const DECODEBBOX; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const PAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const PADV2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const MIRRORPAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const TILE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SIZE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CLIPBOXES; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FASTRCNNPREDICTIONS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SPLIT; 
+FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SPLITV; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const EXPANDDIMS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const EMPTY; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const MEAN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const GREATER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SWITCH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SWITCHN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const MERGE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SYMBOLICGRADIENT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const REMOTECALL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const _IF; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const STATELESSIF; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const IF; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CASE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const STATELESSCASE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const _WHILE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const WHILE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const STATELESSWHILE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FOR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const PARTITIONEDCALL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const STATEFULPARTITIONEDCALL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const 
FAKEPARAM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const TRANSPOSE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const TRANSPOSED; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CAST; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const REGION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const YOLO; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const YOLODETECTIONOUTPUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FILL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const REVERSE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const UNPACK; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const YOLO2REORG; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const REDUCESUM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SUM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CONSTANT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FILECONSTANT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const RESIZEBILINEAR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const RESIZEBILINEARGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const MAXIMUM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FRAMEWORKOP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ARG; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FUSEDBATCHNORMGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const LSTM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern 
const ge::char_t *const HIGHWAY; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const RNN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ATTENTIONDECODER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const LOGICAL_NOT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const LOGICAL_AND; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const LOGICAL_OR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const EQUAL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const NOTEQUAL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const INTERP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SHUFFLECHANNEL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const AIPP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const MULTISHAPE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const RECIPROCAL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SELU; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ELU; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ACOSH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ASINH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const MINIMUM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CLIP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const L2NORMALIZE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CROPANDRESIZE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const UNUSEDCONST; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY 
extern const ge::char_t *const SPARSETODENSE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const NONMAXSUPPRESSION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const TOPKV2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const INVERTPERMUTATION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const MULTINOMIAL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const REVERSESEQUENCE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const REDUCEPROD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const REDUCEMAX; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const REDUCEMIN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const EXTRACTIMAGEPATCHES; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SQRT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const REDUCEALL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const RESIZENEARESTNEIGHBOR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SPACETOBATCHND; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const BATCHTOSPACEND; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ASSERT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const GREATEREQUAL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FLOOR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const RANDOMUNIFORM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const BATCHMATMUL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SPACETODEPTH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern 
const ge::char_t *const DEPTHTOSPACE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const RINT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ATAN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ATAN2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ATANH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ACOS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ASIN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const NEG; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const LOG; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const TAN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ROUND; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const UPSAMPLE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FLOORMOD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const LESS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const LESSEQUAL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ONEHOT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const REFSWITCH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const REFMERGE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ENTER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const REFENTER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const LOOPCOND; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const NEXTITERATION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const 
REFNEXTITERATION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const EXIT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const REFEXIT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CONTROLTRIGGER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ZEROSLIKE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const EXP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const WHERE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FAKEQUANTWITHMINMAXVARS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SOFTPLUS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SOFTSIGN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const COSH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SINH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SQUAREDDIFFERENCE; +// for retinanet scope fusion +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const REQUIREDSPACETOBATCHPADDINGS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SSDPOSTPROCESSOR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const RETINANETBOXES; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const RETINAMULTIANCHORS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const RETINANETCLIPPEDBOXES; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const RETINANETFILTEREDDETECTIONS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const RETINANETPOSTPROCESSOR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const RETINANETANCHORS; 
+FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FASTERRCNNMAP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FASTERRCNNMAP1; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FASTERRCNNSECONDSTAGEPOSTPROCESSOR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FASTERRCNNROIINTERPOOLING; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FASTERRCNNFIRSTSTAGEPOSTPROCESSOR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FASTERRCNNGRIDANCHORGENERATOR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ROIINTERPOOLING; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FASTERRCNNCLIPTOWINDOW; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const EMBEDLOOKUP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const HASHLOOKUP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const LSH_PROJ; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SVDF; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SSDANCHORGENERATOR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const IDENTITY; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const IDENTITYN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const PLACEHOLDERWITHDEFAULT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SELECT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const GETSPAN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const STOPGRADIENT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const PREVENTGRADIENT; 
+FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const GUARANTEECONST; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const BROADCASTGRADIENTARGS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const BROADCASTARGS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CONFUSIONMATRIX; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const RANK; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const PLACEHOLDER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const END; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const BASICLSTMCELL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const GETNEXT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const INITDATA; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const REFIDENTITY; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const BITCAST; + +/***************Ann special operator*************************/ +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ANN_MEAN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ANN_CONVOLUTION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ANN_DEPCONVOLUTION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ANN_FULLCONNECTION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ANN_NETOUTPUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ANN_DATA; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ANN_RESHAPE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ANN_ADD; +FMK_FUNC_HOST_VISIBILITY 
FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ANN_MUL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ANN_SUB; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ANN_DIV; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ANN_DEQUANTIZE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ANN_QUANTIZE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ANN_PAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ANN_RESIZE_BILINEAR; + +/***************************************************/ +/******************Training operator*************************/ +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const GATHERV2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CONVGRADFILTER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CONV2D; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CONV2DBACKPROPINPUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FUSEDBATCHNORM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const BIASADDGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ACTIVATIONGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const MAXPOOLWITHARGMAX; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const MAXPOOLGRADWITHARGMAX; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SPARSESOFTMAXCROSSENTROPYWITHLOGITS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SNAPSHOT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const VAR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const 
MEANGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const TRANSLATE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ADDN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const L2LOSS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const MULTIPLY; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const HUBERLOSSGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const HUBERLOSS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const NEGATIVE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SSDCAST; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SPARSESOFTMAXCROSSENTROPY; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SPARSESOFTMAXCROSSENTROPYGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SSDSQUEEZEFUSION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CONCATFOUR2FIVE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CONCATFIVE2FOUR; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SSDREALDIVTILEMUL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SSDSUMMULREALDIVMEAN; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const VARIABLEV2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const VARHANDLEOP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const TEMPORARYVARIABLE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const DESTROYTEMPORARYVARIABLE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const VARIABLE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern 
const ge::char_t *const ASSIGN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ASSIGNVARIABLEOP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ASSIGNADD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ASSIGNADDVARIABLEOP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ASSIGNSUB; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ASSIGNSUBVARIABLEOP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const APPLYMOMENTUM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const RESOURCEAPPLYMOMENTUM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SGD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const NOOP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const READVARIABLEOP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const PARALLELCONCATSTART; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CONSTANTOP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const DEPTHWISECONV2DBACKPROPFILTER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const DEPTHWISECONV2DBACKPORPINPUT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const DEPTHWISECONV2DFORWARDNATIVE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const DROPOUTGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const APPLYRMSPROPMIXEDPRECISION; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const APPLYRMSPROP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const RELU6GRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const 
AVGPOOLGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CONCATV2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CONCATOFFSET; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const LAYERNORMGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const LAYERNORM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const LARS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const DYNAMICSTITCH; + +/***************************************************/ +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SQUARE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const HCOMBROADCAST; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const HCOMALLGATHER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const HCOMALLGATHERV; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const HCOMALLREDUCE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const HCOMREDUCESCATTER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const HCOMREDUCESCATTERV; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const HCOMSEND; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const HCOMRECEIVE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const HCOMREMOTEREAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const HCOMREMOTEREFREAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const HCOMREMOTEWRITE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const HCOMREMOTESCATTERWRITE; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const VARASSIGN; 
+FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const VARISINITIALIZEDOP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const LogTimeStamp; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ISVARIABLEINITIALIZED; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const STREAMSWITCH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const STREAMSWITCHN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const STREAMACTIVE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const MEMCPYASYNC; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const MEMCPYADDRASYNC; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const STREAMMERGE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ENDGRAPH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SEND; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const RECV; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ENDOFSEQUENCE; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const LABELSET; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const LABELGOTO; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const LABELGOTOEX; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const LABELSWITCH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const LABELSWITCHBYINDEX; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ATOMICADDRCLEAN; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ABS_GRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ACCUMULATE_N_V2; 
+FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ACOS_GRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ACOSH_GRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ANY; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const APPROXIMATE_EQUAL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ASIN_GRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ASINH_GRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ATAN_GRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const BROADCAST_TO; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ELU_GRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ADD_V2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const DATAFORMATDIMMAP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const DATAFORMATVECPERMUTE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const BESSELI0E; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const BESSELI1E; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const APPLYADADELTA; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const APPLYADAGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const APPLYADAGRADDA; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const APPLYADAM; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const APPLYADAMAX; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const APPLYADDSIGN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const APPLYCENTEREDRMSPROP; 
+FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const APPLYFTRL; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const APPLYFTRLV2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const APPLYGRADIENTDESCENT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const APPLYPOWERSIGN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const APPLYPROXIMALADAGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const APPLYPROXIMALGRADIENTDESCENT; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const DEQUANTIZE; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FOCAL_LOSS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const FOCAL_LOSS_GRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SMOOTHL1_LOSS; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SMOOTHL1_LOSS_grad; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const REDUCEMEAN; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CONCAT_V2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const ONEHOT_V2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SLICE_V2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const TILE_V2; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SUM_V2; +// Common type when the operator has the same name +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const DETECTIONOUTPUT; +// Custom operator +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CUSTOMOP; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CUSTOMOP_NCHW; 
+FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CUSTOMOP_NHWC; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const CUSTOMOP_NC1HWC0; + +// Depthwise 4d_2_6d,6d_2_4d +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const DEPTHWISEWEIGHT4D26D; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const DEPTHWISEWEIGHT6D24D; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SQRTGRAD; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const SIGMOIDGRAD; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const TRANSSHAPE; + +// Horovod operator +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const HVDCALLBACKALLREDUCE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const HVDCALLBACKALLGATHER; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const HVDCALLBACKBROADCAST; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const ge::char_t *const HVDWAIT; + +/// +/// @brief Magic number of model file +/// +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t MODEL_FILE_MAGIC_NUM; // magic number + +/// +/// @brief Model head length +/// +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t MODEL_FILE_HEAD_LEN; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t MODEL_VERSION; ///< Model version 1.0/// + +// alpha default value +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const float32_t ALPHA_DEFAULT_VALUE; + +// beta default value +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const float32_t BETA_DEFAULT_VALUE; + +/// +/// @ingroup domi_omg +/// @brief INPUT node type +/// +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string INPUT_TYPE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const 
std::string DUMMY_DATA; + +// dim default size value +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY constexpr int32_t DIM_DEFAULT_SIZE = 4; + +// for fusion op plugin +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string ATTR_NAME_INPUT_TENSOR_DESC; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_TENSOR_DESC; + +// DATA node type +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DATA_TYPE; + +// framework Operator Type +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string FRAMEWORK_OP_TYPE; + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_NET_OUTPUT; + +#pragma pack() // Cancels single-byte alignment +} // namespace parser +} // namespace ge + +#endif // INC_FRAMEWORK_OMG_PARSER_TYPES_H_ diff --git a/inc/air/framework/omg/parser/weights_parser.h b/inc/air/framework/omg/parser/weights_parser.h new file mode 100644 index 0000000000000000000000000000000000000000..4319de90aa5ce601f7e35108d8061940f811d752 --- /dev/null +++ b/inc/air/framework/omg/parser/weights_parser.h @@ -0,0 +1,77 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/
+
+#ifndef INC_FRAMEWORK_OMG_PARSER_WEIGHTS_PARSER_H_
+#define INC_FRAMEWORK_OMG_PARSER_WEIGHTS_PARSER_H_
+
+#include "external/register/register_error_codes.h"
+#include "graph/graph.h"
+#include "graph/attr_value.h"
+#include "graph/compute_graph.h"
+#include "graph/ge_tensor.h"
+#include "graph/op_desc.h"
+#include "graph/utils/attr_utils.h"
+#include "graph/utils/op_desc_utils.h"
+#include "graph/utils/tensor_utils.h"
+
+namespace domi {
+/**
+ * @ingroup domi_omg
+ * @brief Weight information resolver
+ *
+ */
+class GE_FUNC_VISIBILITY WeightsParser {
+ public:
+  /**
+   * @ingroup domi_omg
+   * @brief Constructor
+   */
+  WeightsParser() {}
+
+  /**
+   * @ingroup domi_omg
+   * @brief Destructor
+   */
+  virtual ~WeightsParser() {}
+
+  /**
+   * @ingroup domi_omg
+   * @brief Analyze weight data
+   * @param [in] file Path of weight file after training
+   * @param [in|out] graph Graph for saving weight information after analysis
+   * @return SUCCESS
+   * @return Others failed
+   */
+  virtual Status Parse(const char *file, ge::Graph &graph) = 0;
+
+  /**
+   * @ingroup domi_omg
+   * @brief Parse relevant data from memory and save it to graph
+   * @param [in] input Model file memory data
+   * @param [in] length Length of the data at input (presumably in bytes; confirm with implementations)
+   * @param [in|out] graph A graph for saving the model information after analysis
+   * @return SUCCESS
+   * @return FAILED
+   */
+  virtual Status ParseFromMemory(const char *input, uint32_t length, ge::ComputeGraphPtr &graph) = 0;
+
+  virtual bool HasError() {  // default implementation: never reports an error
+    return false;
+  }
+
+  virtual Status Save(const std::string &file) {
+    (void)file;  // default implementation ignores the path and reports SUCCESS
+    return SUCCESS;
+  }
+
+  virtual void Clear() {}  // default implementation: nothing to clear
+};
+}  // namespace domi
+
+#endif  // INC_FRAMEWORK_OMG_PARSER_WEIGHTS_PARSER_H_
diff --git a/inc/air/framework/omg/version.h b/inc/air/framework/omg/version.h
new file mode 100644
index 0000000000000000000000000000000000000000..c48fe1952932bc1daf0cf7b554652756936f29ff
--- /dev/null
+++ b/inc/air/framework/omg/version.h
@@ -0,0 +1,30 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_FRAMEWORK_OMG_VERSION_H_
+#define INC_FRAMEWORK_OMG_VERSION_H_
+
+#include <string>
+#include "ge/ge_api_error_codes.h"
+#include "framework/common/debug/ge_log.h"
+
+namespace ge {
+class GE_FUNC_VISIBILITY PlatformVersionManager {  // static-only: constructor and destructor are deleted
+ public:
+  PlatformVersionManager() = delete;
+  ~PlatformVersionManager() = delete;
+  static Status GetPlatformVersion(std::string &ver) {  // overwrites ver with the fixed string, returns SUCCESS
+    ver = "1.11.z";  // format x.y.z
+    GELOGI("current platform version: %s.", ver.c_str());
+    return SUCCESS;
+  }
+};  // class PlatformVersionManager
+}  // namespace ge
+
+#endif  // INC_FRAMEWORK_OMG_VERSION_H_
diff --git a/inc/air/framework/pne/flow_model.h b/inc/air/framework/pne/flow_model.h
new file mode 100644
index 0000000000000000000000000000000000000000..7b8937c2fd4201f557673884e5fa129c524f8431
--- /dev/null
+++ b/inc/air/framework/pne/flow_model.h
@@ -0,0 +1,80 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef GE_MODEL_FLOW_MODEL_H_
+#define GE_MODEL_FLOW_MODEL_H_
+
+#include 
+#include "graph/compute_graph.h"
+#include "framework/pne/pne_model.h"
+
+namespace ge {
+struct HcomClusterDesc {
+  std::string name;
+  std::string rank_table;
+  std::map> device_to_rank_ids;
+  std::map> group_name_to_rank_ids;
+  bool operator==(const HcomClusterDesc &rhs) const;  // declared only; defined outside this header
+};
+
+class FlowModel : public PneModel {
+ public:
+  using PneModel::PneModel;  // inherit the base-class constructors
+  explicit FlowModel(const ComputeGraphPtr &root_graph);
+  ~FlowModel() override = default;
+
+  Status SerializeModel(ModelBufferData &model_buff) override;
+
+  Status UnSerializeModel(const ModelBufferData &model_buff) override;
+
+  void SetModelsEschedPriority(std::map> models_esched_priority) {
+    models_esched_priority_ = std::move(models_esched_priority);  // by-value argument is moved into the member
+  }
+
+  const std::map> &GetModelsEschedPriority() const {
+    return models_esched_priority_;  // returned by reference without taking flow_model_mutex_
+  }
+
+  void SetModelNameToRankId(const std::map &model_name_to_rank_id);
+
+  const std::map> &GetGroupNameToRankIds() const;
+
+  void SetGroupNameToRankIds(const std::map> &group_name_to_rank_ids);
+
+  const std::map> &GetDeviceToRankIds() const;
+  void SetDeviceToRankIds(const std::map> &device_to_rank_ids);
+
+  void SetHcomClusterDescs(const std::map &hcom_cluster_descs);
+  Status MergeHcomClusterInfo(FlowModel &sub_flow_model);
+  const std::map &GetHcomClusterDescs() const;
+
+  void SetModelNameToClusterAndRankId(
+      const std::map> &model_name_to_cluster_and_rank_id);
+  const std::map> &GetModelNameToClusterAndRankId() const;
+
+  void SetLogicDeviceToMemCfg(std::map> logic_dev_id_to_mem_cfg) {
+    logic_dev_id_to_mem_cfg_ = std::move(logic_dev_id_to_mem_cfg);  // by-value argument is moved into the member
+  }
+
+  const std::map> &GetLogicDeviceToMemCfg() const {
+    return logic_dev_id_to_mem_cfg_;  // returned by reference without taking flow_model_mutex_
+  }
+
+ private:
+  HcomClusterDesc &GetOrCreateHcomClusterDesc(const std::string &name);
+  mutable std::mutex flow_model_mutex_;  // not locked by any inline member in this header
+  std::map> models_esched_priority_;
+  std::map hcom_cluster_descs_;
+  std::map> model_name_to_cluster_and_rank_id_;
+  // key logic_device_id | value(std_mem_size, shared_mem_size)
+  std::map> logic_dev_id_to_mem_cfg_;
+};
+using FlowModelPtr = std::shared_ptr;
+}  // namespace ge
+#endif  // GE_MODEL_FLOW_MODEL_H_
diff --git a/inc/air/framework/pne/graph_deployment_optimizer.h b/inc/air/framework/pne/graph_deployment_optimizer.h
new file mode 100644
index 0000000000000000000000000000000000000000..fcc2bde9ccd34cec4d0910e928502f12ff0d21f8
--- /dev/null
+++ b/inc/air/framework/pne/graph_deployment_optimizer.h
@@ -0,0 +1,27 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef AIR_INC_FRAMEWORK_PNE_GRAPH_DEPLOYMENT_OPTIMIZER_H_
+#define AIR_INC_FRAMEWORK_PNE_GRAPH_DEPLOYMENT_OPTIMIZER_H_
+
+#include "framework/common/ge_inner_error_codes.h"
+#include "graph/compute_graph.h"
+#include "graph/parallelism/tensor_parallel_attrs.h"  // presumably the element type of nodes_sliced_infos - TODO confirm
+
+namespace ge {
+class GraphDeploymentOptimizer {  // static-only API; the definitions are not part of this header
+ public:
+  static Status OptimizeForAutoDeploy(const ComputeGraphPtr &compute_graph, ComputeGraphPtr &optimized_graph);
+  static Status OptimizeByRecomputation(const ComputeGraphPtr &graph);
+  static Status OptimizeByGraphSlicing(const std::vector> &nodes_sliced_infos,
+                                       const ComputeGraphPtr &compute_graph);
+};
+}  // namespace ge
+
+#endif  // AIR_INC_FRAMEWORK_PNE_GRAPH_DEPLOYMENT_OPTIMIZER_H_
diff --git a/inc/air/framework/pne/pne_model.h b/inc/air/framework/pne/pne_model.h
new file mode 100644
index 0000000000000000000000000000000000000000..3ac1c47c2a6b767db7a4fc2b8471fd3acc29b3ea
--- /dev/null
+++ b/inc/air/framework/pne/pne_model.h
@@ -0,0 +1,169 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_FRAMEWORK_PNE_MODEL_H_
+#define INC_FRAMEWORK_PNE_MODEL_H_
+
+#include 
+#include 
+#include 
+
+#include "graph/compute_graph.h"
+#include "framework/common/debug/log.h"
+#include "framework/common/ge_inner_error_codes.h"
+#include "framework/common/ge_types.h"
+#include "framework/engine/dnnengine.h"
+#include "external/ge/ge_ir_build.h"
+#include "common/model/model_deploy_resource.h"
+namespace ge {
+const std::string PNE_ID_NPU = "NPU";
+const std::string PNE_ID_CPU = "HOST_CPU";
+const std::string PNE_ID_UDF = "UDF";
+const std::string PNE_ID_PS = "PS";
+
+struct ModelRelation;
+struct ModelDeployResource;
+class PneModel {
+ public:
+  PneModel() = default;
+  explicit PneModel(const ComputeGraphPtr &root_graph) : root_graph_(root_graph) {}
+  virtual ~PneModel() = default;
+  PneModel(const PneModel &other) = delete;
+  PneModel &operator=(const PneModel &other) = delete;
+
+  inline Status AddSubModel(const std::shared_ptr &submodel, const std::string &type = "") {
+    const std::lock_guard lk(pne_model_mutex_);
+    if (submodel == nullptr) {
+      GELOGE(INTERNAL_ERROR, "submodel is nullptr, type = %s", type.c_str());
+      return INTERNAL_ERROR;
+    }
+    submodel->SetModelType(type);  // note: the type is stamped even if the insert below is rejected
+    if (!submodels_.emplace(submodel->GetModelName(), submodel).second) {  // name already registered
+      GELOGE(INTERNAL_ERROR, "submodel already exist, name = %s, type = %s", submodel->GetModelName().c_str(),
+             type.c_str());
+      return INTERNAL_ERROR;
+    }
+    return SUCCESS;
+  }
+
+  inline const std::shared_ptr GetSubmodel(const std::string &name) const {
+    const std::lock_guard lk(pne_model_mutex_);
+    const auto &it = submodels_.find(name);
+    if (it == submodels_.end()) {
+      return nullptr;  // unknown name
+    }
+    return it->second;
+  }
+
+  inline const std::map> &GetSubmodels() const {
+    const std::lock_guard lk(pne_model_mutex_);
+    return submodels_;  // NOTE(review): the reference outlives lk, so later reads by the caller are unguarded
+  }
+
+  inline void SetSubmodels(std::map> submodels) {
+    const std::lock_guard lk(pne_model_mutex_);
+    submodels_ = std::move(submodels);  // replaces the whole submodel table
+  }
+
+  inline void SetModelType(const std::string &type) { model_type_ = type; }
+
+  inline const std::string &GetModelType() const { return model_type_; }
+
+  inline void SetModelName(const std::string &model_name) { model_name_ = model_name; }
+
+  inline const std::string &GetModelName() const { return model_name_; }
+
+  inline void SetNormalizedModelName(const std::string &model_name) { normalized_model_name_ = model_name; }
+
+  inline const std::string &GetNormalizedModelName() const { return normalized_model_name_; }
+
+  inline void SetRootGraph(const ComputeGraphPtr &graph) { root_graph_ = graph; }
+
+  inline const ComputeGraphPtr &GetRootGraph() const { return root_graph_; }
+
+  inline void SetModelRelation(std::shared_ptr model_relation) {
+    model_relation_ = std::move(model_relation);
+  }
+
+  inline const std::shared_ptr GetModelRelation() const { return model_relation_; }
+
+  inline void SetDeployResource(std::shared_ptr deploy_resource) {
+    deploy_resource_ = std::move(deploy_resource);
+  }
+
+  inline const std::shared_ptr GetDeployResource() const { return deploy_resource_; }
+
+  inline void SetCompileResource(std::shared_ptr compile_resource) {
+    compile_resource_ = std::move(compile_resource);
+  }
+
+  inline const std::shared_ptr GetCompileResource() const { return compile_resource_; }
+
+  inline void SetDeviceId(const int32_t device_id) { device_id_ = device_id; }
+
+  inline int32_t GetDeviceId() const { return device_id_; }
+
+ public:
+  virtual Status SerializeModel(ModelBufferData &model_buff) = 0;
+
+  virtual Status UnSerializeModel(const ModelBufferData &model_buff) = 0;
+
+  virtual void SetModelId(const uint32_t model_id) { model_id_ = model_id; }
+
+  virtual uint32_t GetModelId() const { return model_id_; }
+
+  virtual std::string GetLogicDeviceId() const { return ""; }  // base class stores no logic device id
+
+  virtual Status SetLogicDeviceId(const std::string &logic_device_id) {
+    const std::lock_guard lk(pne_model_mutex_);
+    for (const auto &submodel : submodels_) {
+      (void)submodel.second->SetLogicDeviceId(logic_device_id);  // forwarded; per-submodel status is discarded
+    }
+    return SUCCESS;
+  }
+
+  virtual std::string GetRedundantLogicDeviceId() const { return ""; }  // base class stores no redundant id
+
+  virtual std::string GetSavedModelPath() const { return saved_model_path_; }
+
+  virtual void SetSavedModelPath(const std::string &saved_model_path) { saved_model_path_ = saved_model_path; }
+
+  virtual bool GetIsBuiltinModel() const { return is_builtin_model_; }
+
+  virtual void SetIsBuiltinModel(const bool is_builtin) { is_builtin_model_ = is_builtin; }
+
+  virtual Status SetRedundantLogicDeviceId(const std::string &logic_device_id) {
+    const std::lock_guard lk(pne_model_mutex_);  // same guard as SetLogicDeviceId while walking submodels_
+    for (const auto &submodel : submodels_) {
+      (void)submodel.second->SetRedundantLogicDeviceId(logic_device_id);  // per-submodel status is discarded
+    }
+    return SUCCESS;
+  }
+
+ private:
+  mutable std::mutex pne_model_mutex_;  // guards submodels_
+  std::map> submodels_;
+  std::shared_ptr model_relation_;
+  std::shared_ptr deploy_resource_;
+  std::shared_ptr compile_resource_;
+  ComputeGraphPtr root_graph_ = nullptr;
+  std::string model_name_;
+  // user define function is not empty
+  std::string normalized_model_name_;
+  std::string model_type_;
+  std::string saved_model_path_;
+  uint32_t model_id_ = INVALID_MODEL_ID;
+  int32_t device_id_ = -1;
+  bool is_builtin_model_ = false;
+};
+
+using PneModelPtr = std::shared_ptr;
+}  // namespace ge
+
+#endif  // INC_FRAMEWORK_PNE_MODEL_H_
diff --git a/inc/air/framework/pne/process_node_engine.h b/inc/air/framework/pne/process_node_engine.h
new file mode 100644
index 0000000000000000000000000000000000000000..19be5c8a08ce4741586734826b8d80e723e8bfb6
--- /dev/null
+++ b/inc/air/framework/pne/process_node_engine.h
@@ -0,0 +1,78 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details.
You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_FRAMEWORK_PROCESS_NODE_ENGINE_H_
+#define INC_FRAMEWORK_PROCESS_NODE_ENGINE_H_
+
+#include 
+#include 
+#include 
+
+#include "framework/common/ge_inner_error_codes.h"
+#include "framework/common/ge_types.h"
+#include "graph/manager/graph_manager_utils.h"
+#include "framework/pne/pne_model.h"
+
+namespace ge {
+class ProcessNodeEngineImpl {  // pure interface: both operations must be supplied by an implementation
+ public:
+  virtual ~ProcessNodeEngineImpl() = default;
+
+  virtual Status OptimizeGraph(const std::vector &inputs, ComputeGraphPtr &compute_graph) = 0;
+
+  virtual Status BuildGraph(ComputeGraphPtr &compute_graph, PneModelPtr &model) = 0;
+};
+
+using ProcessNodeEngineImplPtr = std::shared_ptr;
+
+class ProcessNodeEngine {  // abstract engine; non-copyable (copy constructor and copy assignment are deleted)
+ public:
+  ProcessNodeEngine() = default;
+  virtual ~ProcessNodeEngine() = default;
+  ProcessNodeEngine(const ProcessNodeEngine &other) = delete;
+  ProcessNodeEngine &operator=(const ProcessNodeEngine &other) = delete;
+
+ public:
+  virtual Status Initialize(const std::map &options) = 0;
+
+  virtual Status Finalize() = 0;
+
+  virtual Status OptimizeGraph(const std::vector &inputs, ComputeGraphPtr &compute_graph) = 0;
+
+  virtual Status BuildGraph(ComputeGraphPtr &compute_graph, PneModelPtr &model) = 0;
+
+  virtual const std::string &GetEngineName(const ge::NodePtr &node_ptr = nullptr) const = 0;
+
+  virtual void SetImpl(ProcessNodeEngineImplPtr impl) = 0;
+
+  virtual Status AddGraph(const ComputeGraphPtr &compute_graph, const std::map &options) {
+    (void) compute_graph;
+    (void) options;
+    return SUCCESS;  // default: both arguments are ignored
+  }
+
+  virtual Status RemoveGraph(const uint32_t graph_id) {
+    (void) graph_id;
+    return SUCCESS;  // default: the id is ignored
+  }
+
+  virtual Status ParallelPartition(ComputeGraphPtr &compute_graph) {
+    (void) compute_graph;
+    return NOT_CHANGED;  // default: the graph is left untouched
+  }
+
+ protected:
+  std::string engine_id_;  // never assigned in this header
+  ProcessNodeEngineImplPtr impl_ = nullptr;  // SetImpl is pure virtual here; presumably stored by subclasses - confirm
+};
+
+using ProcessNodeEnginePtr = std::shared_ptr;
+}  // namespace ge
+
+#endif  // INC_FRAMEWORK_PROCESS_NODE_ENGINE_H_
diff --git a/inc/air/framework/runtime/device_memory_recorder.h b/inc/air/framework/runtime/device_memory_recorder.h
new file mode 100644
index 0000000000000000000000000000000000000000..45d6905513e8944d18bd529f4c9a7b87f3f62837
--- /dev/null
+++ b/inc/air/framework/runtime/device_memory_recorder.h
@@ -0,0 +1,46 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef RUNTIME_V2_KERNEL_MEMORY_UTIL_DEVICE_MEMORY_STATS_H
+#define RUNTIME_V2_KERNEL_MEMORY_UTIL_DEVICE_MEMORY_STATS_H
+#include  // NOTE(review): header names and template arguments are missing in this copy - restore from the original
+#include
+#include
+#include "toolchain/prof_common.h"
+#include "toolchain/prof_api.h"
+
+namespace gert {
+  struct MemoryRecorder {  // One record: size, address, the two memory totals and a timestamp.
+    int64_t size;
+    uint64_t addr;
+    uint64_t total_allocate_memory;
+    uint64_t total_reserve_memory;
+    uint64_t time_stamp;
+  };
+
+  class DeviceMemoryRecorder {  // All-static bookkeeping: two atomic counters plus a queue of MemoryRecorder entries.
+  public:
+    static uint64_t GetTotalAllocateMemory() { return total_allocate_memory_.load(); }
+    static uint64_t GetTotalReserveMemory() { return total_reserve_memory_.load(); }
+    static void AddTotalReserveMemory(const uint64_t &num) { total_reserve_memory_ += num; }
+    static void ReduceTotalReserveMemory(const uint64_t &num) { total_reserve_memory_ -= num; }  // No underflow check.
+    static void ClearReserveMemory() { total_reserve_memory_.store(0UL); }
+    static const MemoryRecorder GetRecorder();  // Defined out of line.
+    static bool IsRecorderEmpty();  // Defined out of line.
+    static void SetRecorder(const void *const addr, const int64_t size);  // Defined out of line.
+    static void AddTotalAllocateMemory(const uint64_t &num) { total_allocate_memory_ += num; }
+    static void ReduceTotalAllocateMemory(const uint64_t &num) { total_allocate_memory_ -= num; }  // No underflow check.
+  private:
+    static std::atomic total_allocate_memory_;
+    static std::atomic total_reserve_memory_;
+    static std::queue memory_record_queue_;
+    static std::mutex mtx_;  // Presumably guards memory_record_queue_ - confirm in the implementation file.
+  };
+}  // namespace gert
+#endif  // RUNTIME_V2_KERNEL_MEMORY_UTIL_DEVICE_MEMORY_STATS_H
diff --git a/inc/air/framework/runtime/event_allocator.h b/inc/air/framework/runtime/event_allocator.h
new file mode 100644
index 0000000000000000000000000000000000000000..a4331cd10c585a3d6eba4fd980827b98893d0d98
--- /dev/null
+++ b/inc/air/framework/runtime/event_allocator.h
@@ -0,0 +1,40 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef AIR_CXX_RUNTIME_EVENT_ALLOCATOR_H_
+#define AIR_CXX_RUNTIME_EVENT_ALLOCATOR_H_
+
+#include  // NOTE(review): header name and template arguments are missing in this copy - restore from the original
+#include "runtime/event.h"
+#include "common/ge_visibility.h"
+#include "framework/common/ge_inner_error_codes.h"
+#include "exe_graph/runtime/continuous_vector.h"
+
+namespace gert {
+class VISIBILITY_EXPORT EventAllocator {  // Non-copyable; owns a ContinuousVector created for kMaxEventNum entries.
+ public:
+  static constexpr size_t kMaxEventNum = 4096U;
+  explicit EventAllocator(uint32_t flag = RT_EVENT_DDSYNC_NS)  // flag is stored as default_flag_.
+      : events_holder_(ContinuousVector::Create(kMaxEventNum)), default_flag_(flag) {}
+  EventAllocator(const EventAllocator &) = delete;
+  EventAllocator &operator=(const EventAllocator &) = delete;
+  ~EventAllocator();
+
+  TypedContinuousVector *AcquireEvents(size_t event_num);  // Defined out of line.
+
+ private:
+  TypedContinuousVector *Events();  // Presumably a typed view over events_holder_ - confirm in the .cc file.
+
+ private:
+  std::unique_ptr events_holder_;
+  uint32_t default_flag_;
+};
+}  // namespace gert
+
+#endif  // AIR_CXX_RUNTIME_EVENT_ALLOCATOR_H_
diff --git a/inc/air/framework/runtime/exe_graph_executor.h b/inc/air/framework/runtime/exe_graph_executor.h
new file mode 100644
index 0000000000000000000000000000000000000000..e6634ca767ebad660ed3b704538db496eaedfbcb
--- /dev/null
+++ b/inc/air/framework/runtime/exe_graph_executor.h
@@ -0,0 +1,58 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details.
You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef AIR_CXX_INC_FRAMEWORK_RUNTIME_EXE_GRAPH_EXECUTOR_H_
+#define AIR_CXX_INC_FRAMEWORK_RUNTIME_EXE_GRAPH_EXECUTOR_H_
+#include "graph/ge_error_codes.h"
+
+#include "common/ge_visibility.h"
+#include "exe_graph_resource_guard.h"
+#include "subscriber/executor_subscriber_c.h"
+namespace gert {
+enum SubExeGraphType { kInitExeGraph, kMainExeGraph, kDeInitExeGraph, kSubExeGraphTypeEnd };  // Last entry ends the list.
+using ResourceGuardPtr = std::unique_ptr;  // NOTE(review): template argument missing in this copy.
+class VISIBILITY_EXPORT ExeGraphExecutor {  // Holds opaque execution data plus the two function pointers that run it.
+ public:
+  using ExecuteFunc = UINT32 (*)(void *);
+  using ExecuteWithCallbackFunc = UINT32 (*)(int32_t, void *, ExecutorSubscriber *);
+  ge::graphStatus Load() const {  // No-op: always returns GRAPH_SUCCESS.
+    return ge::GRAPH_SUCCESS;
+  }
+  ge::graphStatus UnLoad() const {  // No-op: always returns GRAPH_SUCCESS.
+    return ge::GRAPH_SUCCESS;
+  }
+
+  /**
+   * Set the inputs/outputs of the graph execution. Note: the caller must make sure inputs/outputs are fully refreshed!
+   */
+  ge::graphStatus SpecifyInputs(void *const *inputs, size_t start, size_t num) const;
+  ge::graphStatus SpecifyInput(void *input, size_t start) const {  // Forwards one input to SpecifyInputs with num = 1.
+    return SpecifyInputs(&input, start, 1U);
+  }
+  ge::graphStatus SpecifyOutputs(void *const *outputs, size_t num) const;
+  ge::graphStatus Execute() const;
+  ge::graphStatus Execute(SubExeGraphType sub_graph_type, ExecutorSubscriber *callback) const;
+
+  const void *GetExecutionData() const {
+    return execution_data_;
+  }
+  void SetExecutionData(void *execution_data, ResourceGuardPtr resource_guard);
+
+  void SetExecuteFunc(ExecuteFunc execute_func, ExecuteWithCallbackFunc callback_func);
+
+ private:
+  friend class ModelV2ExecutorTestHelper;  // Grants a test helper access to the private members.
+
+  void *execution_data_{nullptr};
+  ExecuteFunc execute_func_{nullptr};
+  ExecuteWithCallbackFunc execute_with_callback_func_{nullptr};
+  ResourceGuardPtr resource_guard_;
+};
+}  // namespace gert
+#endif  // AIR_CXX_INC_FRAMEWORK_RUNTIME_EXE_GRAPH_EXECUTOR_H_
diff --git a/inc/air/framework/runtime/exe_graph_resource_guard.h b/inc/air/framework/runtime/exe_graph_resource_guard.h
new file mode 100644
index 0000000000000000000000000000000000000000..8ed5eab749eb2df3d05c13ab6dcc49bb0bf35320
--- /dev/null
+++ b/inc/air/framework/runtime/exe_graph_resource_guard.h
@@ -0,0 +1,76 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef AIR_CXX_INC_FRAMEWORK_RUNTIME_EXE_GRAPH_RESOURCE_GUARD_H_
+#define AIR_CXX_INC_FRAMEWORK_RUNTIME_EXE_GRAPH_RESOURCE_GUARD_H_
+#include  // NOTE(review): header names and template arguments are missing in this copy - restore from the original
+#include
+#include
+#include
+#include "common/ge_visibility.h"
+
+namespace gert {
+class VISIBILITY_EXPORT ResourceGuard {  // Base guard: owns the execution-data buffer.
+ public:
+  void *ResetExecutionData(std::unique_ptr execution_data);
+  void *GetExecutionData();
+  virtual ~ResourceGuard() = default;
+
+ private:
+  std::unique_ptr execution_data_holder_;
+};
+
+/*
+* The original resource guard has been moved as-is into this subclass.
+* If the resource guard used on the model and the one used on the sub exe graph ever need to be separated,
+* this class can be split further at that point.
+*/
+class VISIBILITY_EXPORT TopologicalResourceGuard : public ResourceGuard {
+ public:
+  void ResetAnyValue(std::unique_ptr any_values, size_t count);
+  void PushNode(void *node);
+  void PushWatcher(void *watcher);
+  void *ResetNodesArray(std::unique_ptr nodes_array);
+  void *ResetStartNodesArray(std::unique_ptr start_nodes_array);
+  void *ResetNodesIndgreeArray(std::unique_ptr nodes_indgree_array);
+  void *ResetNodesWaitIndgreeArray(std::unique_ptr nodes_indgree_array);
+  void *ResetInputsArray(std::unique_ptr inputs_array);
+  void *ResetOutputsArray(std::unique_ptr outputs_array);
+  void *ResetWatchersArray(std::unique_ptr watchers_array);
+  void *ResetReadyQueue(void *ready_queue);
+  void *ResetBuffer(std::unique_ptr buffer);
+  void *ResetComputeNodeInfo(std::unique_ptr compute_node_info);
+  void *ResetKernelExtendInfo(std::unique_ptr kernel_extend_info);
+  void *ResetModelDesc(std::unique_ptr model_desc);
+
+  ~TopologicalResourceGuard() override;
+
+ private:
+  size_t any_values_num_;  // NOTE(review): no default initializer - confirm it is always set before being read.
+  std::unique_ptr any_values_guard_;
+  std::vector> nodes_guarder_;
+
+  std::vector> watchers_guarder_;
+  std::unique_ptr continuous_buffer_guarder_;
+  std::unique_ptr buffer_guarder_;
+  std::unique_ptr compute_node_info_guarder_;
+  std::unique_ptr kernel_extend_info_guarder_;
+  std::unique_ptr model_desc_guarder_;
+
+  std::unique_ptr nodes_array_guarder_;
+  std::unique_ptr start_nodes_array_guarder_;
+  std::unique_ptr nodes_indgree_array_guarder_;
+  std::unique_ptr nodes_wait_indgree_array_guarder_;
+  std::unique_ptr inputs_array_guarder_;
+  std::unique_ptr outputs_array_guarder_;
+  std::unique_ptr watchers_array_guarder_;
+  std::unique_ptr ready_queue_guarder_{nullptr, nullptr};  // Two nullptrs: pointer and, presumably, a custom deleter.
+};
+}  // namespace gert
+#endif  // AIR_CXX_INC_FRAMEWORK_RUNTIME_EXE_GRAPH_RESOURCE_GUARD_H_
diff --git a/inc/air/framework/runtime/executor_option/executor_option.h b/inc/air/framework/runtime/executor_option/executor_option.h
new file mode 100644
index 0000000000000000000000000000000000000000..222b40e323e3b8b143a83e7829f556fd30ebbb57
--- /dev/null
+++ b/inc/air/framework/runtime/executor_option/executor_option.h
@@ -0,0 +1,48 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef AIR_CXX_EXECUTOR_OPTION_H +#define AIR_CXX_EXECUTOR_OPTION_H + +namespace gert { +enum class ExecutorType { + // 顺序优先级执行器,基于优先级完成拓扑排序后,使用基于拓扑排序后的结果执行, + // 本执行器调度速度最快,但是不支持条件/循环节点 + kSequentialPriority, + + // 基于拓扑的执行器,执行时基于拓扑动态计算ready节点,并做调度执行 + kTopological, + + // 基于拓扑的优先级执行器,在`kTopological`的基础上,将ready节点做优先级排序,总是优先执行优先级高的节点 + kTopologicalPriority, + + // 基于host缓存的执行器,在`kTopologicalPriority`的基础上,支持按照冻结后的图执行,被冻结的节点不再执行 + kHostCache, + + // 基于拓扑的多线程执行器 + kTopologicalMultiThread, + + kEnd +}; + +class VISIBILITY_EXPORT ExecutorOption { + public: + ExecutorOption() : executor_type_(ExecutorType::kEnd) {} + explicit ExecutorOption(ExecutorType executor_type) : executor_type_(executor_type) {} + ExecutorType GetExecutorType() const { + return executor_type_; + } + virtual ~ExecutorOption() = default; + + private: + ExecutorType executor_type_; +}; +} // namespace gert + +#endif // AIR_CXX_EXECUTOR_OPTION_H diff --git a/inc/air/framework/runtime/executor_option/multi_thread_executor_option.h b/inc/air/framework/runtime/executor_option/multi_thread_executor_option.h new file mode 100644 index 0000000000000000000000000000000000000000..1e6c7dc9e7e7df4a66ba93ac5d9c39b8041e5cc5 --- /dev/null +++ b/inc/air/framework/runtime/executor_option/multi_thread_executor_option.h @@ -0,0 +1,39 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef AIR_CXX_MULTI_THREAD_EXECUTOR_OPTION_H
+#define AIR_CXX_MULTI_THREAD_EXECUTOR_OPTION_H
+
+#include  // NOTE(review): header name is missing in this copy - restore from the original
+#include "framework/runtime/executor_option/executor_option.h"
+
+namespace gert {
+constexpr size_t kLeastThreadNumber = 2U;  // minimum number of new threads: one normal worker plus one memory worker
+
+class VISIBILITY_EXPORT MultiThreadExecutorOption : public ExecutorOption {  // Adds a thread count to ExecutorOption.
+ public:
+  MultiThreadExecutorOption() : MultiThreadExecutorOption(kLeastThreadNumber) {}  // Defaults to the minimum count.
+  explicit MultiThreadExecutorOption(size_t thread_num)
+      : ExecutorOption(ExecutorType::kTopologicalMultiThread), thread_num_(thread_num) {}
+  MultiThreadExecutorOption(ExecutorType executor_type, size_t thread_num)
+      : ExecutorOption(executor_type), thread_num_(thread_num) {}
+
+  size_t GetThreadNum() const {
+    return thread_num_;
+  }
+
+ private:
+  /**
+   * Number of threads.
+   * Valid range: 2 <= thread_num. NOTE(review): the constructors store the value without checking this range.
+   */
+  size_t thread_num_;
+};
+}  // namespace gert
+#endif  // AIR_CXX_MULTI_THREAD_EXECUTOR_OPTION_H
diff --git a/inc/air/framework/runtime/executor_option/rt_cache_executor_option.h b/inc/air/framework/runtime/executor_option/rt_cache_executor_option.h
new file mode 100644
index 0000000000000000000000000000000000000000..b9f732ad46bb02235f6548fe12e8e80f6b10a61c
--- /dev/null
+++ b/inc/air/framework/runtime/executor_option/rt_cache_executor_option.h
@@ -0,0 +1,54 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef AIR_CXX_RT_CACHE_EXECUTOR_OPTION_H
+#define AIR_CXX_RT_CACHE_EXECUTOR_OPTION_H
+
+#include "framework/runtime/executor_option/executor_option.h"
+
+namespace gert {
+enum class RtCacheMode {
+  /**
+   * Disabled.
+   */
+  kTurnOff,
+
+  /**
+   * Enable HostCache.
+   * With this option, execution data is cached per input shape, which skips some execution nodes that depend only
+   * on the input shape (infershape/tiling/malloc, etc.) and saves host scheduling time.
+   */
+  kHostCache,
+
+  /**
+   * Enable DeviceCache. TODO: add a description once this mode is supported.
+   */
+  kDeviceCache,
+
+  kEnd
+};
+class VISIBILITY_EXPORT RtCacheExecutorOption : public ExecutorOption {  // Always reports ExecutorType::kHostCache.
+ public:
+  RtCacheExecutorOption() : ExecutorOption(ExecutorType::kHostCache), rt_cache_mode_(RtCacheMode::kTurnOff) {}
+  explicit RtCacheExecutorOption(RtCacheMode rt_cache_mode)
+      : ExecutorOption(ExecutorType::kHostCache), rt_cache_mode_(rt_cache_mode) {}
+  const RtCacheMode &GetCacheMode() const {
+    return rt_cache_mode_;
+  }
+
+ private:
+  /**
+   * Cache mode of the dynamic-shape runtime.
+   * When enabled, runtime data is cached to improve host/device scheduling performance.
+   */
+  RtCacheMode rt_cache_mode_;
+};
+}  // namespace gert
+
+#endif  // AIR_CXX_RT_CACHE_EXECUTOR_OPTION_H
diff --git a/inc/air/framework/runtime/gert_api.h b/inc/air/framework/runtime/gert_api.h
new file mode 100644
index 0000000000000000000000000000000000000000..026292aebccc6dbd6cd6e7e0d06674f6ec0879dc
--- /dev/null
+++ b/inc/air/framework/runtime/gert_api.h
@@ -0,0 +1,108 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details.
You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef AIR_CXX_INC_FRAMEWORK_RUNTIME_GERT_API_H_
+#define AIR_CXX_INC_FRAMEWORK_RUNTIME_GERT_API_H_
+#include "model_v2_executor.h"
+#include "stream_executor.h"
+#include "common/ge_types.h"
+#include "common/ge_visibility.h"
+#include "mem_allocator.h"
+#include "memory/allocator_desc.h"
+#include "exe_graph/lowering/lowering_opt.h"
+#include "stream_allocator.h"
+#include "event_allocator.h"
+#include "rt_session.h"
+
+namespace gert {
+/**
+ * Allocator factory class; creates Allocators.
+ */
+class VISIBILITY_EXPORT AllocatorFactory {
+ public:
+  // Create an allocator according to the placement.
+  static std::unique_ptr Create(const std::string &graph_name, const TensorPlacement &placement);
+  static std::unique_ptr Create(const TensorPlacement &placement);
+ private:
+  AllocatorFactory() = default;  // Private: the class is used only through its static Create functions.
+};
+
+VISIBILITY_EXPORT
+std::unique_ptr LoadExecutorFromFile(const ge::char_t *model_path, ge::graphStatus &error_code);
+
+VISIBILITY_EXPORT
+std::unique_ptr LoadExecutorFromModelData(const ge::ModelData &model_data,
+                                          ge::graphStatus &error_code);
+
+/**
+ * Load ModelData into an Executor. An executor created this way must later be given the same RtSession in executor->Load.
+ * @param model_data content of the model data after it has been read from a file
+ * @param rt_session session object; resources such as variables are shared through the session
+ * @param error_code if loading fails (a null pointer is returned), receives the corresponding error code
+ * @return pointer to the StreamExecutor on success, null pointer on failure.
+ *         The return type is unique_ptr, so the lifetime of the returned StreamExecutor is managed by the caller.
+ */
+VISIBILITY_EXPORT
+std::unique_ptr LoadExecutorFromModelDataWithRtSession(const ge::ModelData &model_data,
+                                                       RtSession *const rt_session,
+                                                       ge::graphStatus &error_code);
+
+VISIBILITY_EXPORT
+std::unique_ptr LoadExecutorFromModelData(const ge::ModelData &model_data,
+                                          const ExecutorOption &executor_option,
+                                          ge::graphStatus &error_code);
+
+VISIBILITY_EXPORT
+std::unique_ptr LoadExecutorFromModelData(const ge::ModelData &model_data,
+                                          const ExecutorOption &executor_option,
+                                          StreamAllocator *const stream_allocator,
+                                          EventAllocator *const event_allocator,
+                                          NotifyAllocator *const notify_allocator,
+                                          ge::graphStatus &error_code);
+
+VISIBILITY_EXPORT
+std::unique_ptr LoadStreamExecutorFromModelData(const ge::ModelData &model_data, const void *weight_ptr,
+                                                const size_t weight_size, ge::graphStatus &error_code);
+
+/**
+ * Load ModelData into a StreamExecutor; equivalent to LoadStreamExecutorFromModelData(model_data, {}, error_code);
+ * @param model_data content of the model data after it has been read from a file
+ * @param error_code if loading fails (a null pointer is returned), receives the corresponding error code
+ * @return pointer to the StreamExecutor on success, null pointer on failure.
+ *         The return type is unique_ptr, so the lifetime of the returned StreamExecutor is managed by the caller.
+ */
+VISIBILITY_EXPORT
+std::unique_ptr LoadStreamExecutorFromModelData(const ge::ModelData &model_data,
+                                                ge::graphStatus &error_code);
+
+/**
+ * Load ModelData into a StreamExecutor.
+ * @param model_data content of the model data after it has been read from a file
+ * @param optimize_option optimization options
+ * @param error_code if loading fails (a null pointer is returned), receives the corresponding error code
+ * @return pointer to the StreamExecutor on success, null pointer on failure.
+ *         The return type is unique_ptr, so the lifetime of the returned StreamExecutor is managed by the caller.
+ */
+
+VISIBILITY_EXPORT
+std::unique_ptr LoadStreamExecutorFromModelData(const ge::ModelData &model_data,
+                                                const LoweringOption &optimize_option,
+                                                ge::graphStatus &error_code);
+VISIBILITY_EXPORT
+ge::graphStatus IsDynamicModel(const void *const model, size_t model_size, bool &is_dynamic_model);
+VISIBILITY_EXPORT
+ge::graphStatus IsDynamicModel(const ge::char_t *model_path, bool &is_dynamic_model);
+
+VISIBILITY_EXPORT
+ge::graphStatus LoadDataFromFile(const ge::char_t *model_path, ge::ModelData &model_data);
+
+VISIBILITY_EXPORT
+std::unique_ptr CreateExternalAllocator(const ge::AllocatorDesc * const allocatorDesc);
+}  // namespace gert
+#endif  // AIR_CXX_INC_FRAMEWORK_RUNTIME_GERT_API_H_
diff --git a/inc/air/framework/runtime/gert_const_types.h b/inc/air/framework/runtime/gert_const_types.h
new file mode 100644
index 0000000000000000000000000000000000000000..f376ca7c8293d49d54563af26d2b7db38b808809
--- /dev/null
+++ b/inc/air/framework/runtime/gert_const_types.h
@@ -0,0 +1,35 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef AIR_CXX_GE_RT_CONST_TYPES_H
+#define AIR_CXX_GE_RT_CONST_TYPES_H
+#include  // NOTE(review): header name and template arguments are missing in this copy - restore from the original
+#include "external/graph/types.h"
+namespace gert {
+// Types that a ConstData node can represent.
+// Unlike Data, ConstData carries execution-time semantics, so its type must be fixed at lowering time.
+// To add a new ConstData type, extend the enum below and add the matching key to kConstDataTypes.
+enum class ConstDataType {
+  kRtSession = 0,
+  KWeight,  // NOTE(review): capital 'K' differs from the k-prefix used by the sibling enumerators.
+  kModelDescription,
+  kTypeEnd,
+};
+
+// Unique keys of the ConstData nodes, used to index the corresponding holder in global data.
+const ge::char_t* const kConstDataTypes[] = {"RtSession", "OuterWeightMem", "ModelDesc"};
+inline std::string GetConstDataTypeStr(ConstDataType type) {  // Key for type, or "" when type is past the table.
+  auto len = sizeof(kConstDataTypes) / sizeof(ge::char_t *);
+  if (static_cast(type) >= len) {
+    return "";
+  }
+  return kConstDataTypes[static_cast(type)];
+}
+}  // namespace gert
+#endif  // AIR_CXX_GE_RT_CONST_TYPES_H
diff --git a/inc/air/framework/runtime/mem_allocator.h b/inc/air/framework/runtime/mem_allocator.h
new
file mode 100644
index 0000000000000000000000000000000000000000..5e393a111a5407d652fdb781e56412295e5b80f0
--- /dev/null
+++ b/inc/air/framework/runtime/mem_allocator.h
@@ -0,0 +1,61 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef AIR_MEM_ALLOCATOR_H
+#define AIR_MEM_ALLOCATOR_H
+
+#include "exe_graph/runtime/tensor_data.h"
+#include "exe_graph/runtime/allocator.h"
+#include "common/ge_visibility.h"
+#include "ge/ge_api_error_codes.h"
+#include "runtime/mem.h"
+#include "ge/ge_allocator.h"
+
+namespace gert {
+namespace memory {
+struct MemSynchronizer {  // Abstract interface with two pure-virtual hooks.
+  MemSynchronizer() = default;
+  virtual ~MemSynchronizer() = default;
+  // Wait until the memory is actually freed after the task has completed
+  virtual ge::Status Synchronize() const = 0;
+  virtual void Recycle() = 0;
+};
+
+class AllocatorManager {  // Base class whose virtual hooks all have do-nothing defaults.
+ public:
+  AllocatorManager() = default;
+  virtual ~AllocatorManager() = default;
+  virtual ge::Status Initialize(const std::vector &memory_types) {  // Default: ignores the list, returns SUCCESS.
+    (void)memory_types;
+    return ge::SUCCESS;
+  }
+  virtual void Finalize() {};
+  virtual void ReleaseResource(const uint32_t device_id = 0U) {  // Default: no-op.
+    (void)device_id;
+  }
+  virtual ge::Allocator *CreateAllocator(const uint32_t device_id, const rtMemType_t memory_type) {  // Default: nullptr.
+    (void)device_id;
+    (void)memory_type;
+    return nullptr;
+  }
+};
+}  // namespace memory
+
+class VISIBILITY_EXPORT Allocators {  // Table of shared allocators indexed by [placement][usage].
+ public:
+  ge::Allocator *GetAllocator(const TensorPlacement &placement, const size_t &usage);
+  ge::Status SetAllocator(const TensorPlacement &placement, const size_t &usage,
+                          std::shared_ptr &allocator);
+ private:
+  // TODO: remove the usage dimension
+  std::shared_ptr allocators[static_cast(TensorPlacement::kTensorPlacementEnd)]
+                            [static_cast(AllocatorUsage::kEnd)];
+};
+}  // namespace gert
+#endif  // AIR_MEM_ALLOCATOR_H
diff --git a/inc/air/framework/runtime/model_desc.h b/inc/air/framework/runtime/model_desc.h
new file mode 100644
index 0000000000000000000000000000000000000000..5b43b2ddc4c9f742250c02b6a1092c85d1c5070f
--- /dev/null
+++ b/inc/air/framework/runtime/model_desc.h
@@ -0,0 +1,136 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef AIR_CXX_INC_FRAMEWORK_RUNTIME_MODEL_DESC_H_
+#define AIR_CXX_INC_FRAMEWORK_RUNTIME_MODEL_DESC_H_
+#include "common/ge_types.h"
+#include "common/ge_visibility.h"
+
+#include "exe_graph/runtime/shape.h"
+#include "exe_graph/runtime/continuous_vector.h"
+#include "exe_graph/runtime/storage_format.h"
+#include "exe_graph/runtime/storage_shape.h"
+#include "register/op_impl_space_registry.h"
+
+namespace gert {
+class VISIBILITY_EXPORT ShapeRange {  // A pair of shapes: the minimum and the maximum.
+ public:
+  ShapeRange() = default;
+  ShapeRange(const Shape &min_shape, const Shape &max_shape);
+  bool operator==(const ShapeRange &other) const;
+
+  const Shape &GetMin() const;
+  const Shape &GetMax() const;
+  Shape &MutableMin();
+  Shape &MutableMax();
+
+ private:
+  Shape min_;
+  Shape max_;
+};
+
+class VISIBILITY_EXPORT ModelIoDesc {  // Describes one model input or output: name, data type, format, shapes, ranges.
+ public:
+  const ge::char_t *GetName() const;
+  int32_t GetDataType() const;
+  ge::Format GetStorageFormat() const;
+  ge::Format GetOriginFormat() const;
+  int64_t GetSize() const;
+  const Shape &GetStorageShape() const;
+  const Shape &GetOriginShape() const;
+  const ShapeRange &GetOriginShapeRange() const;
+  const ShapeRange &GetStorageShapeRange() const;
+  std::vector> GetOriginShapeRangeVector() const;
+  std::vector> GetStorageShapeRangeVector() const;
+  bool IsShapeInRange(const Shape &shape) const;
+  bool IsOriginShapeInRange(const Shape &shape) const;
+
+  void SetName(const ge::char_t *name);
+  void SetDataType(int32_t data_type);
+  void SetStorageFormat(ge::Format format);
+  void SetOriginFormat(ge::Format format);
+  Shape &MutableStorageShape();
+  Shape &MutableOriginShape();
+  ShapeRange &MutableOriginShapeRange();
+  ShapeRange &MutableStorageShapeRange();
+
+  const Shape &GetAippShape() const;
+  Shape &MutableAippShape();
+
+ private:
+  const ge::char_t *name_;  // Raw pointer; presumably the string is owned elsewhere - confirm against SetName callers.
+  int32_t data_type_;
+  StorageFormat format_;
+  StorageShape shape_;
+  ShapeRange storage_shape_range_;
+  ShapeRange origin_shape_range_;
+  // used for aipp
+  Shape aipp_shape_;
+};
+
+class VISIBILITY_EXPORT ModelDesc {  // Model-level description; model_io_descs_ is deliberately the last member.
+ public:
+  static size_t CalcSize(size_t input_num, size_t output_num);
+  const ModelIoDesc *GetInputDesc(size_t index) const;
+  const ModelIoDesc *GetAllInputsDesc(size_t &input_num) const;
+
+  const ModelIoDesc *GetOutputDesc(size_t index) const;
+  const ModelIoDesc *GetAllOutputsDesc(size_t &output_num) const;
+
+  size_t GetInputNum() const;
+  size_t GetOutputNum() const;
+
+  ModelIoDesc *MutableInputDesc(size_t index);
+  ModelIoDesc *MutableOutputDesc(size_t index);
+  ModelIoDesc *AllMutableIoDesc(size_t &input_num, size_t &output_num);
+  void SetInputNum(size_t input_num);
+  void SetOutputNum(size_t output_num);
+  void SetSpaceRegistries(OpImplSpaceRegistryArray *space_registries);
+  OpImplSpaceRegistryArray *GetSpaceRegistries() const;
+  void SetFileConstantWeightDir(const ge::char_t *file_constant_weight_dir);
+  const ge::char_t *GetFileConstantWeightDir() const;
+
+  ge::graphStatus GetDynamicBatchInfo(std::vector> &batch_info, int32_t &dynamic_type) const;
+  ge::graphStatus GetUserDesignateShapeOrder(std::vector &user_designate_shape_order) const;
+  ge::graphStatus GetModelAttrs(std::vector &attrs) const;
+
+  /*
+   * Get the stream count of the root model, excluding the persistent streams of static sub models
+   */
+  size_t GetReusableStreamNum() const;
+  /*
+   * Set the stream count of the root model, excluding the persistent streams of static sub models
+   */
+  void SetReusableStreamNum(size_t stream_num);
+  size_t GetReusableEventNum() const;
+  void SetReusableEventNum(size_t event_num);
+
+  size_t GetReusableNotifyNum() const;
+  void SetReusableNotifyNum(const size_t notify_num);
+
+  size_t GetAttachedStreamNum() const;
+  void SetAttachedStreamNum(const size_t stream_num);
+
+ private:
+  OpImplSpaceRegistryArray *space_registries_;
+  const ge::char_t *file_constant_weight_dir_;
+  size_t input_num_;
+  size_t output_num_;
+  size_t reusable_stream_num_;
+  size_t reusable_event_num_;
+  size_t reusable_notify_num_;
+  size_t attached_stream_num_;
+
+  ContinuousVector model_io_descs_;
+  // do not add any member variable after this point
+};
+static_assert(std::is_trivial::value, "The class ModelDesc must be a POD");  // NOTE(review): template argument missing.
+}  // namespace gert
+
+#endif  // AIR_CXX_INC_FRAMEWORK_RUNTIME_MODEL_DESC_H_
diff --git a/inc/air/framework/runtime/model_v2_executor.h b/inc/air/framework/runtime/model_v2_executor.h
new file mode 100644
index 0000000000000000000000000000000000000000..976c3b6a4dadde617fa921c4e003aae4cc248f5e
--- /dev/null
+++ b/inc/air/framework/runtime/model_v2_executor.h
@@ -0,0 +1,262 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef AIR_CXX_RUNTIME_V2_CORE_MODEL_V_2_EXECUTOR_H_
+#define AIR_CXX_RUNTIME_V2_CORE_MODEL_V_2_EXECUTOR_H_
+#include
+#include
+#include "graph/compute_graph.h"
+#include "graph/fast_graph/execute_graph.h"
+#include "graph/ge_error_codes.h"
+#include "model_desc.h"
+#include "runtime/stream.h"
+#include "exe_graph/runtime/tensor.h"
+#include "common/ge_visibility.h"
+#include "exe_graph_resource_guard.h"
+#include "exe_graph_executor.h"
+#include "subscriber/executor_subscribers_scheduler.h"
+#include "common/ge_types.h"
+#include "mem_allocator.h"
+#include "framework/runtime/rt_session.h"
+#include "register/op_impl_space_registry.h"
+#include "framework/common/ge_types.h"
+#include "framework/runtime/executor_option/executor_option.h"
+#include "framework/runtime/stream_allocator.h"
+#include "framework/runtime/event_allocator.h"
+#include "common/host_resource_center/host_resource_center.h"
+namespace gert {
+enum class ExecutorState { kInit, kLoaded };
+// Returns the display name of a sub execute graph type.
+// NOTE(review): `type` indexes the table without a range check; callers presumably pass a value
+// below kSubExeGraphTypeEnd - confirm.
+inline const ge::char_t *GetSubExeGraphTypeStr(const SubExeGraphType type) {
+  constexpr const ge::char_t *kSubExeGraphTypeStrs[kSubExeGraphTypeEnd] = {"Init", "Main", "DeInit"};
+  return kSubExeGraphTypeStrs[type];
+}
+
+// Indices of the extra execute arguments. kNotifies is -kNum and the following enumerators count
+// up from it, so kEnd evaluates to 0.
+enum class ExecuteArgIndex {
+  kNum = 4,
+  kNotifies = -1 * kNum,
+  kRtEvents,
+  kExternalAllocator,
+  kStream,
+  kEnd,
+};
+
+struct OuterWeightMem {
+  const void *weight_ptr;  // this mem must be available and used for the same model
+  size_t weight_size;
+};
+
+struct ModelExecuteArg {
+  /**
+   * If the stream passed in from outside is not null, this model is executed on that stream.
+   */
+  rtStream_t stream;
+
+  /**
+   * Whether to use an external allocator to allocate on-stream memory. If this pointer is null, the executor
+   * creates and uses a default allocator.
+   * After the host dispatches a task, the device always runs it asynchronously on the stream, so the allocator
+   * must satisfy the following requirements:
+   * 1. One allocator corresponds to exactly one stream.
+   * 2. Before the corresponding stream is synchronized, memory in the allocator's pool must not be returned to
+   *    the operating system.
+   * 3. Before the corresponding stream is synchronized, the allocator must not be destructed (same reason as 2).
+   */
+  Allocators *external_allocator;
+
+  /**
+   * Whether to use an external stream allocator to acquire secondary streams. If this pointer is null, the
+   * executor creates and uses a default stream allocator.
+   * The stream allocator shares its lifetime with the first member `stream` (the external main stream). If the
+   * main stream is switched, this allocator must be switched accordingly.
+   */
+  StreamAllocator *external_stream_allocator;
+  /**
+   * Whether to use an external event allocator to acquire events. If this pointer is null, the executor creates
+   * and uses a default event allocator.
+   * The event allocator shares its lifetime with the first member `stream` (the external main stream). If the
+   * main stream is switched, this allocator must be switched accordingly.
+   * The event allocator and the stream allocator must be passed in together.
+   */
+  EventAllocator *external_event_allocator;
+  /**
+   * Whether to use an external notify allocator to acquire notifies. If this pointer is null, the executor
+   * creates and uses a default notify allocator.
+   * The notify allocator shares its lifetime with the first member `stream` (the external main stream). If the
+   * main stream is switched, this allocator must be switched accordingly.
+   * The notify allocator and the stream allocator must be passed in together.
+   */
+  NotifyAllocator *external_notify_allocator;
+
+  uint64_t reserved[8];
+
+  ModelExecuteArg() : ModelExecuteArg(nullptr, nullptr) {}
+  ModelExecuteArg(const rtStream_t stream_, Allocators *const external_allocator_ = nullptr)
+      : stream(stream_),
+        external_allocator(external_allocator_),
+        external_stream_allocator{nullptr},
+        external_event_allocator{nullptr},
+        external_notify_allocator{nullptr},
+        reserved{0U} {}
+};
+static_assert(std::is_standard_layout::value, "The class ModelExecuteArg must be a POD");
+
+// Arguments supplied once at load time: the runtime session and optional externally owned weight memory.
+struct ModelLoadArg {
+  RtSession *rt_session;
+  OuterWeightMem outer_weight_mem;
+  ModelLoadArg() : rt_session(nullptr), outer_weight_mem({nullptr, 0U}) {}
+  // The first parameter deliberately has no default value: with every parameter defaulted this constructor
+  // was also callable with zero arguments, which made `ModelLoadArg x;` ambiguous with the default constructor.
+  ModelLoadArg(RtSession *rt_session_tmp, OuterWeightMem outer_weight_mem_tmp = {nullptr, 0U})
+      : rt_session(rt_session_tmp),
+        outer_weight_mem(outer_weight_mem_tmp) {}
+};
+
+static_assert(std::is_standard_layout::value, "The class ModelLoadArg must be a POD");
+
+class VISIBILITY_EXPORT ModelV2Executor {
+ public:
+  static std::unique_ptr Create(const ge::ExecuteGraphPtr &exe_graph, const ge::ModelData &model_data,
+                                const std::shared_ptr &root_model);
+  static std::unique_ptr Create(const ge::ExecuteGraphPtr &exe_graph,
+                                const std::shared_ptr &root_model);
+  static std::unique_ptr Create(const ge::ExecuteGraphPtr &exe_graph, const ExecutorOption &option,
+                                const std::shared_ptr &root_model);
+
+  ge::graphStatus Load();
+  /**
+   * Loads the model. This interface must be called before the model is executed. Loading performs the work that
+   * is needed only once in the lifetime of the model, such as initializing the model and copying important data
+   * to the NPU.
+   * @param arg
+   * Execute arguments of the model. Note that the arguments passed here should use the same stream and
+   * allocator as the arguments passed to Execute; otherwise the caller must synchronize the stream after the
+   * load completes to avoid ordering problems.
+   * @return `ge::GRAPH_SUCCESS` on success
+   */
+  ge::graphStatus Load(const ModelExecuteArg &arg);
+
+  /**
+   * Loads the model. This interface must be called before the model is executed. Loading performs the work that
+   * is needed only once in the lifetime of the model, such as initializing the model and copying important data
+   * to the NPU.
+   * @param arg
+   * Execute arguments of the model. Note that the arguments passed here should use the same stream and
+   * allocator as the arguments passed to Execute; otherwise the caller must synchronize the stream after the
+   * load completes to avoid ordering problems.
+   * @param load_arg Load arguments of the model, supplied by the caller at load time and not changed during
+   * execution, e.g. RtSession.
+   * @return `ge::GRAPH_SUCCESS` on success
+   */
+  ge::graphStatus Load(const ModelExecuteArg &arg, const ModelLoadArg &load_arg);
+
+  /**
+   * Executes the model asynchronously. The model is dispatched to the NPU asynchronously; returning from this
+   * interface does not mean that execution has finished, and the caller must synchronize the stream to wait for
+   * completion.
+   * Make sure `Load` has been called before calling this interface.
+   *
+   * Output tensors can be specified in several ways, with the following behavior:
+   *
+   * * The caller allocates sufficiently large output memory before the call and passes it in through the output
+   *   tensors: after execution the results are written into the caller's output tensors.
+   *   If a caller-provided output tensor is too small, this interface returns failure.
+   * * The caller creates output tensors without allocating output memory and passes in tensors that hold no
+   *   memory: this interface allocates the output memory itself and hands it out.
+   *   If no Allocator is specified in arg, the output memory has the same lifetime as this executor;
+   *   if an Allocator is passed in arg, the output memory is allocated with that Allocator.
+   *
+   * Notes:
+   *
+   * 1. This interface does not support concurrent calls.
+   * 2. If an external Allocator is specified, it should be bound to a stream. When the same allocator is used
+   *    with different streams across several Execute calls, two conditions must hold: the calls must not be
+   *    concurrent, and the previous stream must be synchronized before switching to another stream.
+   * 3. If an external Allocator is specified, memory held by the Allocator must not be returned to the operating
+   *    system before model execution completes (even if the executor has already returned that memory to the
+   *    Allocator).
+   *
+   * @param arg execute arguments
+   * @param inputs input tensors of the network; the caller must keep them valid from this call until the stream
+   *        synchronization that waits for this model to finish
+   * @param input_num number of input tensors
+   * @param outputs output tensors of the network
+   * @param output_num number of output tensors
+   * @return `ge::GRAPH_SUCCESS` on success
+   */
+  ge::graphStatus Execute(const ModelExecuteArg &arg, Tensor **inputs, size_t input_num, Tensor **outputs,
+                          size_t output_num);
+  ge::graphStatus ExecuteSync(Tensor **inputs, size_t input_num, Tensor **outputs, size_t output_num);
+  ge::graphStatus UnLoad();
+
+  const ModelDesc &GetModelDesc() const;
+  void SetModelDesc(ModelDesc *model_desc);
+  // Returns the executor of the given sub graph, or nullptr when `type` is out of range.
+  ExeGraphExecutor *GetExeGraphExecutor(const SubExeGraphType type) {
+    if (type >= kSubExeGraphTypeEnd) {
+      return nullptr;
+    }
+    return &graphs_[static_cast(type)];
+  }
+  ExecutorSubscribersScheduler &GetSubscribers();
+  uint32_t GetIterationNum() const;
+  const ExecutorSubscribersScheduler &GetSubscribers() const;
+  ge::graphStatus ArrangeModelLoadArg(const ModelLoadArg &arg, std::vector &const_inputs) const;
+
+  ModelV2Executor(const ModelV2Executor &) = delete;
+  ModelV2Executor(ModelV2Executor &&) = delete;
+  ModelV2Executor &operator=(const ModelV2Executor &) = delete;
+  ModelV2Executor &operator=(ModelV2Executor &&) = delete;
+  void SetFileConstantWeightDir(const std::string &file_constant_weight_dir) {
+    file_constant_weight_dir_ = file_constant_weight_dir;
+  }
+  const std::string &GetFileConstantWeightDir() const {
+    return file_constant_weight_dir_;
+  }
+  /**
+   * @brief Gets the aipp config info
+   * @param [in] index input index
+   * @param [out] aipp_info receives the aipp info
+   * @return `ge::SUCCESS` on success
+   */
+  ge::Status GetAippInfo(const uint32_t index, ge::AippConfigInfo &aipp_info) const;
+  /**
+   * @brief Gets the aipp input type
+   * @param [in] index input index
+   * @param [out] aipp_type receives the aipp type
+   * @param [out] aipp_index receives the aipp index
+   * @return `ge::SUCCESS` on success
+   */
+  ge::Status GetAippType(const uint32_t index, ge::InputAippType &aipp_type, size_t &aipp_index) const;
+
+  ge::Status GetOriginAippInputInfo(const uint32_t index, ge::OriginInputInfo &orig_aipp_input_info) const;
+
+  ge::Status GetAllAippInputOutputDims(const uint32_t index, std::vector &input_dims,
+                                       std::vector &output_dims) const;
+
+  ge::Status InitAipp(const ge::ComputeGraphPtr &root_graph);
+
+ private:
+  friend class ModelV2ExecutorBuilder;
+  friend class ModelV2ExecutorTestHelper;
+  ModelV2Executor();
+  ge::graphStatus SpecifyArgsInputs(const ModelExecuteArg &arg, size_t input_num, ExeGraphExecutor &graph_executor);
+  ge::graphStatus OccupyStreamResource(const ModelExecuteArg &arg, TypedContinuousVector *&streams,
+                                       TypedContinuousVector *&events,
+                                       TypedContinuousVector *&notifies);
+  ge::Status InitRtVarManager(const ModelLoadArg &load_arg) const;
+
+ private:
+  // to keep host resource live longer than resource_guard_
+  // resource guarder may holding pointer from host_resource_center_
+  ge::HostResourceCenterPtr host_resource_center_;
+  TopologicalResourceGuard resource_guard_;
+  std::array graphs_;
+  ModelDesc *model_desc_ = nullptr;
+  rtStream_t default_stream_ = nullptr;
+  ExecutorSubscribersScheduler subscribers_;
+  ExecutorState state_ = ExecutorState::kInit;
+  std::string file_constant_weight_dir_;
+  /*
+   * Background: in the aipp offline inference scenario, acl needs a lot of compile-time aipp information, and
+   *   the rt2 path has to support that.
+   * Temporary workaround: this is a customer issue on a tight schedule, so the current code simply mirrors the
+   *   static-shape logic for fetching aipp info. The data structures and the APIs offered to acl stay the same;
+   *   rt2 keeps the data as ModelV2Executor members and provides the matching interfaces.
+   * Planned solution: define a new data structure for the aipp scenario, e.g. an AllAippInfo struct holding all
+   *   information aipp needs, and provide a single API for both static and dynamic cases, e.g. GetAllAippINfo(),
+   *   so that ge offers one interface and the individual Get interfaces are closed within the acl layer.
+   *   That solution requires coordinated changes in the metadef, air and acl repositories.
+   */
+  // for aipp
+  std::map aipp_info_list_;
+  std::map> aipp_type_list_;
+  std::map orig_aipp_input_info_;
+  std::map, std::vector>> aipp_dims_info_;
+
+  EventAllocator builtin_event_allocator_;
+  StreamAllocator builtin_stream_allocator_;  // ensure stream destroy before event
+  NotifyAllocator builtin_notify_allocator_;
+  uint64_t load_session_id_{std::numeric_limits::max()};
+};
+}  // namespace gert
+
+#endif  // AIR_CXX_RUNTIME_V2_CORE_MODEL_V_2_EXECUTOR_H_
diff --git a/inc/air/framework/runtime/notify_allocator.h b/inc/air/framework/runtime/notify_allocator.h
new file mode 100644
index 0000000000000000000000000000000000000000..a1b3886aa7841dfdfc48b76c10ee5a2fa6457ab0
--- /dev/null
+++ b/inc/air/framework/runtime/notify_allocator.h
@@ -0,0 +1,38 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef AIR_CXX_RUNTIME_NOTIFY_ALLOCATOR_H
+#define AIR_CXX_RUNTIME_NOTIFY_ALLOCATOR_H
+
+#include
+#include "runtime/event.h"
+#include "common/ge_visibility.h"
+#include "framework/common/ge_inner_error_codes.h"
+#include "exe_graph/runtime/continuous_vector.h"
+
+namespace gert {
+// Non-copyable owner of a continuous vector created with capacity kMaxNotifyNum.
+// AcquireNotifies and the destructor are defined out of line.
+class VISIBILITY_EXPORT NotifyAllocator {
+ public:
+  // Capacity passed to ContinuousVector::Create for the holder below.
+  static constexpr size_t kMaxNotifyNum = 1024U;
+
+  explicit NotifyAllocator() : notifies_holder_(ContinuousVector::Create(kMaxNotifyNum)) {}
+  ~NotifyAllocator();
+
+  NotifyAllocator(const NotifyAllocator &) = delete;
+  NotifyAllocator &operator=(const NotifyAllocator &) = delete;
+
+  // NOTE(review): presumably returns a vector holding at least `notify_num` notifies for
+  // `device_id`, or nullptr on failure - confirm against the implementation.
+  TypedContinuousVector *AcquireNotifies(const int32_t device_id, const size_t notify_num);
+
+ private:
+  TypedContinuousVector *Notifies();
+
+  std::unique_ptr notifies_holder_;
+};
+}  // namespace gert
+
+#endif  // AIR_CXX_RUNTIME_NOTIFY_ALLOCATOR_H
diff --git a/inc/air/framework/runtime/rt_session.h b/inc/air/framework/runtime/rt_session.h
new file mode 100644
index 0000000000000000000000000000000000000000..3336d74622a1946508845b43c94cde1de7b8b822
--- /dev/null
+++ b/inc/air/framework/runtime/rt_session.h
@@ -0,0 +1,58 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef AIR_RUNTIME_SESSION_H
+#define AIR_RUNTIME_SESSION_H
+#include "rt_var_manager.h"
+#include "stream_allocator.h"
+#include "event_allocator.h"
+#include "notify_allocator.h"
+
+namespace gert {
+// Per-session state: the session id, a non-owning variable manager pointer, and one owned
+// stream / event / notify allocator each.
+class RtSession {
+ public:
+  // The allocators are created with nothrow new, so each getter below may return nullptr
+  // when the corresponding allocation failed.
+  explicit RtSession(uint64_t session_id = 0UL)
+      : session_id_(session_id),
+        stream_allocator_(new (std::nothrow) StreamAllocator()),
+        event_allocator_(new (std::nothrow) EventAllocator()),
+        notify_allocator_(new (std::nothrow) NotifyAllocator()) {}
+
+  uint64_t GetSessionId() const {
+    return session_id_;
+  }
+  void SetSessionId(uint64_t session_id) {
+    session_id_ = session_id;
+  }
+
+  // The variable manager is stored as a raw pointer and is not owned by the session.
+  const RtVarManager *GetVarManager() const {
+    return var_manager_;
+  }
+  void SetVarManager(RtVarManager *var_manager) {
+    var_manager_ = var_manager;
+  }
+
+  StreamAllocator *GetStreamAllocator() const {
+    return stream_allocator_.get();
+  }
+
+  EventAllocator *GetEventAllocator() const {
+    return event_allocator_.get();
+  }
+
+  NotifyAllocator *GetNotifyAllocator() const {
+    return notify_allocator_.get();
+  }
+
+ private:
+  uint64_t session_id_;
+  RtVarManager *var_manager_{nullptr};
+  std::unique_ptr stream_allocator_;
+  std::unique_ptr event_allocator_;
+  std::unique_ptr notify_allocator_;
+};
+}  // namespace gert
+#endif  // AIR_RUNTIME_SESSION_H
diff --git a/inc/air/framework/runtime/rt_var_manager.h b/inc/air/framework/runtime/rt_var_manager.h
new file mode 100644
index 0000000000000000000000000000000000000000..2ee158b9d0904b00739b6283ad934374a75e7234
--- /dev/null
+++ b/inc/air/framework/runtime/rt_var_manager.h
@@ -0,0 +1,24 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details.
You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef AIR_RUNTIME_VAR_MANAGER_H
+#define AIR_RUNTIME_VAR_MANAGER_H
+#include "exe_graph/runtime/storage_shape.h"
+#include "exe_graph/runtime/tensor_data.h"
+#include "external/ge/ge_api_types.h"
+namespace gert {
+// Abstract interface for looking up a variable by id.
+class RtVarManager {
+ public:
+  RtVarManager() = default;
+  virtual ~RtVarManager() = default;
+
+  // Writes the shape and memory of the variable named `id` into the two output parameters.
+  // NOTE(review): presumably returns ge::SUCCESS when the variable is found - confirm with
+  // the implementations.
+  virtual ge::Status GetVarShapeAndMemory(const std::string &id,
+                                          StorageShape &shape,
+                                          TensorData &memory) const = 0;
+};
+}  // namespace gert
+#endif  // AIR_RUNTIME_VAR_MANAGER_H
diff --git a/inc/air/framework/runtime/stream_allocator.h b/inc/air/framework/runtime/stream_allocator.h
new file mode 100644
index 0000000000000000000000000000000000000000..fd14549fd9ac2603b650b30dd0c1b9c89ccc7873
--- /dev/null
+++ b/inc/air/framework/runtime/stream_allocator.h
@@ -0,0 +1,41 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef AIR_CXX_RUNTIME_STREAM_ALLOCATOR_H_
+#define AIR_CXX_RUNTIME_STREAM_ALLOCATOR_H_
+
+#include
+#include
+#include "runtime/stream.h"
+#include "common/ge_visibility.h"
+#include "framework/common/ge_inner_error_codes.h"
+#include "exe_graph/runtime/continuous_vector.h"
+namespace gert {
+// Non-copyable stream allocator that remembers a default priority and default flags.
+// The constructor, destructor and AcquireStreams are defined out of line.
+class VISIBILITY_EXPORT StreamAllocator {
+ public:
+  static constexpr size_t kMaxStreamNum = 2024U;
+
+  explicit StreamAllocator(int32_t priority = RT_STREAM_PRIORITY_DEFAULT, uint32_t flags = RT_STREAM_DEFAULT);
+  ~StreamAllocator();
+
+  StreamAllocator(const StreamAllocator &) = delete;
+  StreamAllocator &operator=(const StreamAllocator &) = delete;
+
+  // NOTE(review): presumably returns a vector holding at least `stream_num` streams, or
+  // nullptr on failure - confirm against the implementation.
+  TypedContinuousVector *AcquireStreams(size_t stream_num);
+
+ private:
+  TypedContinuousVector *Streams();
+
+  std::unique_ptr streams_holder_;
+  int32_t default_priority_;
+  uint32_t default_flags_;
+};
+}  // namespace gert
+
+#endif  // AIR_CXX_RUNTIME_STREAM_ALLOCATOR_H_
diff --git a/inc/air/framework/runtime/stream_executor.h b/inc/air/framework/runtime/stream_executor.h
new file mode 100644
index 0000000000000000000000000000000000000000..c603785c3974909c9ce748304dbff49d79949d31
--- /dev/null
+++ b/inc/air/framework/runtime/stream_executor.h
@@ -0,0 +1,48 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef AIR_CXX_STREAM_EXECUTOR_H
+#define AIR_CXX_STREAM_EXECUTOR_H
+#include
+#include
+#include
+#include "runtime/base.h"
+#include "common/checker.h"
+#include "model_v2_executor.h"
+namespace gert {
+// do not expose the Builder class definition to external api
+class ModelV2ExecutorBuilder;
+// Keeps one ModelV2Executor per stream; lookups and creation are serialized by mutex_.
+class VISIBILITY_EXPORT StreamExecutor {
+ public:
+  explicit StreamExecutor(ModelV2ExecutorBuilder *builder);
+  ~StreamExecutor();
+
+  StreamExecutor(const StreamExecutor &) = delete;
+  StreamExecutor &operator=(const StreamExecutor &) = delete;
+  StreamExecutor(StreamExecutor &&) = delete;
+  StreamExecutor &operator=(StreamExecutor &&) = delete;
+
+  // Returns the executor already registered for `stream`; when none exists yet, delegates to
+  // CreateAndLoad and returns its result. The map keeps ownership of the executors.
+  ModelV2Executor *GetOrCreateLoaded(rtStream_t stream, const ModelExecuteArg &arg) {
+    const std::lock_guard lock(mutex_);
+    const auto found = streams_to_executor_.find(stream);
+    if (found == streams_to_executor_.cend()) {
+      return CreateAndLoad(stream, arg);
+    }
+    return found->second.get();
+  }
+  ge::graphStatus Erase(rtStream_t stream);
+
+ private:
+  ModelV2Executor *CreateAndLoad(rtStream_t stream, const ModelExecuteArg &arg);
+
+  std::recursive_mutex mutex_;
+  ModelV2ExecutorBuilder *builder_;
+  std::map> streams_to_executor_;
+};
+}  // namespace gert
+#endif  // AIR_CXX_STREAM_EXECUTOR_H
diff --git a/inc/air/framework/runtime/subscriber/built_in_subscriber_definitions.h b/inc/air/framework/runtime/subscriber/built_in_subscriber_definitions.h
new file mode 100644
index 0000000000000000000000000000000000000000..48c9c395a26a5870a069a26dbf430fc5dca43e1f
--- /dev/null
+++ b/inc/air/framework/runtime/subscriber/built_in_subscriber_definitions.h
@@ -0,0 +1,142 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details.
You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef AIR_CXX_INC_FRAMEWORK_RUNTIME_SUBSCRIBER_BUILT_IN_SUBSCRIBER_DEFINITIONS_H_
+#define AIR_CXX_INC_FRAMEWORK_RUNTIME_SUBSCRIBER_BUILT_IN_SUBSCRIBER_DEFINITIONS_H_
+#include
+#include
+#include "graph/gnode.h"
+#include "common/ge_types.h"
+#include "framework/common/ge_visibility.h"
+#include "exe_graph/runtime/kernel_run_context.h"
+#include "graph/anchor.h"
+#include "framework/common/profiling_definitions.h"
+#include "runtime/base.h"
+#include "graph/fast_graph/execute_graph.h"
+
+namespace ge {
+class GeRootModel;
+}
+namespace gert {
+constexpr size_t kProfilingDataCap = 10UL * 1024UL * 1024UL;
+constexpr size_t kInitSize = 10UL * 1024UL;
+constexpr size_t kDouble = 2UL;
+using SymbolsToValue = std::unordered_map;
+static_assert(kInitSize > static_cast(gert::profiling::kProfilingIndexEnd),
+              "The max init size is less than kProfilingIndexEnd.");
+enum class BuiltInSubscriberType {
+  kGeProfiling,
+  kDumper,
+  kTracer,
+  kCannProfilerV2,
+  kCannHostProfiler,
+  kMemoryProfiler,
+  kHostDumper,
+  kNum
+};
+
+enum class ProfilingType {
+  kCannHost = 0,  // enable profiling of host-side scheduling
+  kDevice = 1,
+  kGeHost = 2,  // enable profiling of GE host-side scheduling
+  kTrainingTrace = 3,
+  kTaskTime = 4,
+  kMemory = 5,
+  kCannHostL1 = 6,
+  kNum,
+  kAll = kNum
+};
+static_assert(static_cast(ProfilingType::kNum) < sizeof(uint64_t) * static_cast(8),
+              "The max num of profiling type must less than the width of uint64");
+
+enum class DumpType {
+  kDataDump = 0,
+  kExceptionDump = 1,
+  kOverflowDump = 2,
+  kHostDump = 3,
+  // kNeedSubscribe is a dividing line: enumerators whose value is below kNeedSubscribe
+  // denote dumps that are performed through OnEvent in rt2.0;
+  // enumerators above kNeedSubscribe denote dumps that do not need OnEvent
+  // to perform the dump action
+  kNeedSubscribe = 4,
+  kLiteExceptionDump = 4,
+  kNum = 5,
+  kAll = kNum
+};
+static_assert(static_cast(DumpType::kNum) < sizeof(uint64_t) * static_cast(8),
+              "The max num of dumper type must less than the width of uint64");
+class ModelV2Executor;
+struct TraceAttr {
+  bool is_fp = false;
+  bool is_bp = false;
+  int64_t start_log_id = -1;
+  int64_t logic_stream_id = 0;
+};
+// todo : needs rework; root_model and similar members should be removed
+struct SubscriberExtendInfo {
+  SubscriberExtendInfo(ModelV2Executor *out_executor, const ge::ExecuteGraphPtr &out_exe_graph,
+                       const ge::ComputeGraphPtr &out_root_graph, const ge::ModelData &out_model_data,
+                       const std::shared_ptr &out_root_model,
+                       const SymbolsToValue &out_symbols_to_value, const uint32_t out_model_id,
+                       const std::string &out_model_name, const rtStream_t out_stream,
+                       const std::unordered_map &out_node_names_to_attrs)
+      : executor(out_executor),
+        exe_graph(out_exe_graph),
+        root_graph(out_root_graph),
+        model_data(out_model_data),
+        root_model(out_root_model),
+        symbols_to_value(out_symbols_to_value),
+        model_id(out_model_id),
+        model_name(out_model_name),
+        stream(out_stream),
+        node_names_to_attrs(out_node_names_to_attrs) {}
+  SubscriberExtendInfo()
+      : executor(nullptr),
+        exe_graph(nullptr),
+        root_graph(nullptr),
+        root_model(nullptr),
+        model_id(0U),
+        stream(nullptr) {}
+  ModelV2Executor *executor;
+  ge::ExecuteGraphPtr exe_graph;
+  ge::ComputeGraphPtr root_graph;
+  ge::ModelData model_data;
+  std::shared_ptr root_model;
+  SymbolsToValue symbols_to_value;
+  uint32_t model_id;
+  std::string model_name;
+  rtStream_t stream;
+  std::unordered_map node_names_to_attrs;
+};
+
+// Helpers that turn enable-type enumerators into a 64-bit flag word.
+class VISIBILITY_EXPORT BuiltInSubscriberUtil {
+ public:
+  // Returns a word with only the bit numbered `et` set.
+  // The shifted operand is an explicit uint64_t: `1UL` is 32 bits wide on LLP64 targets, where
+  // shifting it by the larger bit indices permitted by the static_asserts above would be undefined.
+  template ::value) || (std::is_same::value),
+                                    int>::type = 0>
+  constexpr static uint64_t EnableBit(T et) {
+    return uint64_t{1U} << static_cast(et);
+  }
+
+  // ORs the bits of all listed types together. As soon as T::kAll is encountered, every bit
+  // below T::kNum is returned instead.
+  template ::value) || (std::is_same::value),
+                                    int>::type = 0>
+  static uint64_t BuildEnableFlags(const std::vector &enable_types) {
+    uint64_t flag = 0UL;
+    for (const auto &et : enable_types) {
+      if (et == T::kAll) {
+        return EnableBit(T::kNum) - 1UL;
+      }
+      flag |= EnableBit(et);
+    }
+    return flag;
+  }
+};
+}  // namespace gert
+#endif  // AIR_CXX_INC_FRAMEWORK_RUNTIME_SUBSCRIBER_BUILT_IN_SUBSCRIBER_DEFINITIONS_H_
diff --git a/inc/air/framework/runtime/subscriber/executor_subscriber_c.h b/inc/air/framework/runtime/subscriber/executor_subscriber_c.h
new file mode 100644
index 0000000000000000000000000000000000000000..d640e737a573b05434ae9474e574583afb213b95
--- /dev/null
+++ b/inc/air/framework/runtime/subscriber/executor_subscriber_c.h
@@ -0,0 +1,32 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef AIR_CXX_INC_FRAMEWORK_RUNTIME_EXECUTOR_SUBSCRIBER_C_H_
+#define AIR_CXX_INC_FRAMEWORK_RUNTIME_EXECUTOR_SUBSCRIBER_C_H_
+#include "exe_graph/runtime/base_type.h"
+#ifdef __cplusplus
+extern "C" {
+#endif
+/* Events passed to a subscriber callback; kExecuteEventEnd is the last enumerator. */
+typedef enum {
+  kExecuteStart,
+  kExecuteEnd,
+  kModelStart,
+  kModelEnd,
+  kExecuteEventEnd
+} ExecutorEvent;
+
+/* Callback signature of a subscriber. `arg` is the opaque pointer stored next to the callback
+ * in ExecutorSubscriber. */
+typedef void (*SubscriberFunc)(int type, void *arg, ExecutorEvent event, const void *node, KernelStatus result);
+
+/* A callback together with the opaque argument handed back to it. */
+typedef struct {
+  SubscriberFunc callback;
+  void *arg;
+} ExecutorSubscriber;
+#ifdef __cplusplus
+}
+#endif
+#endif  // AIR_CXX_INC_FRAMEWORK_RUNTIME_EXECUTOR_SUBSCRIBER_C_H_
diff --git a/inc/air/framework/runtime/subscriber/executor_subscriber_guarder.h b/inc/air/framework/runtime/subscriber/executor_subscriber_guarder.h
new file mode 100644
index 0000000000000000000000000000000000000000..e25c540df9da3ab53e26788812445236189dfd3c
--- /dev/null
+++ b/inc/air/framework/runtime/subscriber/executor_subscriber_guarder.h
@@ -0,0 +1,89 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef AIR_CXX_INC_FRAMEWORK_RUNTIME_EXECUTOR_SUBSCRIBER_GUARDER_H_
+#define AIR_CXX_INC_FRAMEWORK_RUNTIME_EXECUTOR_SUBSCRIBER_GUARDER_H_
+
+#include "framework/common/ge_visibility.h"
+#include "common/checker.h"
+#include "executor_subscriber_c.h"
+namespace gert {
+// Deleter usable as ExecutorSubscriberGuarder::ArgDeleter: casts `obj` back and deletes it.
+template
+void ObjectDeleter(void *obj) {
+  delete static_cast(obj);
+}
+
+// Move-only owner of an ExecutorSubscriber. The subscriber's `arg` is released through the stored
+// deleter when the guarder is destroyed or move-assigned over. An optional predicate decides
+// whether the subscriber is currently enabled.
+class VISIBILITY_EXPORT ExecutorSubscriberGuarder {
+ public:
+  using ArgDeleter = void (*)(void *);
+
+  ExecutorSubscriberGuarder(::SubscriberFunc func, void *arg, ArgDeleter deleter)
+      : subscriber_({func, arg}), arg_deleter_(deleter) {}
+  ExecutorSubscriberGuarder(ExecutorSubscriberGuarder &&other) noexcept {
+    MoveAssignment(other);
+  }
+  ExecutorSubscriberGuarder &operator=(ExecutorSubscriberGuarder &&other) noexcept {
+    // Self move-assignment must be a no-op; without this guard the owned arg was deleted and the
+    // guarder was left empty.
+    if (this != &other) {
+      DeleteArg();
+      MoveAssignment(other);
+    }
+    return *this;
+  }
+
+  ExecutorSubscriber &GetSubscriber() {
+    return subscriber_;
+  }
+
+  const ExecutorSubscriber &GetSubscriber() const {
+    return subscriber_;
+  }
+
+  ~ExecutorSubscriberGuarder() {
+    DeleteArg();
+  }
+
+  void SetEnabledFunc(const std::function &enabled_func) {
+    enabled_func_ = enabled_func;
+  }
+
+  // A guarder without a predicate is always enabled; otherwise the predicate decides.
+  bool IsEnabled() const {
+    return (enabled_func_ == nullptr) || enabled_func_();
+  }
+
+  ExecutorSubscriberGuarder(const ExecutorSubscriberGuarder &) = delete;
+  ExecutorSubscriberGuarder &operator=(const ExecutorSubscriberGuarder &) = delete;
+
+ private:
+  // Releases the owned arg through the deleter (when one is set) and drops the predicate.
+  void DeleteArg() {
+    if (arg_deleter_ != nullptr) {
+      arg_deleter_(subscriber_.arg);
+    }
+    enabled_func_ = nullptr;
+  }
+  // Takes over the state of `other` and leaves `other` empty so its destructor releases nothing.
+  void MoveAssignment(ExecutorSubscriberGuarder &other) {
+    subscriber_ = other.subscriber_;
+    arg_deleter_ = other.arg_deleter_;
+    enabled_func_ = other.enabled_func_;
+    other.subscriber_ = {nullptr, nullptr};
+    other.arg_deleter_ = nullptr;
+    other.enabled_func_ = nullptr;
+  }
+
+ private:
+  ExecutorSubscriber subscriber_{nullptr, nullptr};
+  ArgDeleter arg_deleter_{nullptr};
+  std::function enabled_func_{nullptr};
+};
+using ExecutorSubscriberGuarderPtr = std::shared_ptr;
+}  // namespace gert
+#endif  // AIR_CXX_INC_FRAMEWORK_RUNTIME_EXECUTOR_SUBSCRIBER_GUARDER_H_
diff --git a/inc/air/framework/runtime/subscriber/executor_subscribers_scheduler.h b/inc/air/framework/runtime/subscriber/executor_subscribers_scheduler.h
new file mode 100644
index 0000000000000000000000000000000000000000..56eb5c8a5ae2d92a175e915d8b496e7e4e55ef18
--- /dev/null
+++ b/inc/air/framework/runtime/subscriber/executor_subscribers_scheduler.h
@@ -0,0 +1,236 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef AIR_CXX_INC_FRAMEWORK_RUNTIME_EXECUTOR_SUBSCRIBERS_SCHEDULER_H_ +#define AIR_CXX_INC_FRAMEWORK_RUNTIME_EXECUTOR_SUBSCRIBERS_SCHEDULER_H_ +#include +#include +#include "built_in_subscriber_definitions.h" +#include "executor_subscriber_guarder.h" +#include "framework/common/ge_visibility.h" +#include "global_profiler.h" +#include "global_dumper.h" +#include "global_tracer.h" +#include "graph/any_value.h" +#include "framework/runtime/exe_graph_executor.h" +#include "common/util/mem_utils.h" + +namespace gert { +class VISIBILITY_EXPORT ExecutorSubscribersScheduler { + public: + static void OnExecuteEvent(SubExeGraphType sub_exe_graph_type, const ExecutorSubscribersScheduler *ins, + ExecutorEvent event, const void *node, KernelStatus result); + + ExecutorSubscribersScheduler() + : enabled_(false), + built_in_subscribers_ptr_(), + sub_exe_graph_subscribers_(), + subscribers_holder_(), + subscriber_wrapper_({reinterpret_cast<::SubscriberFunc>(ExecutorSubscribersScheduler::OnExecuteEvent), this}) {} + void Init(const std::shared_ptr &extend_info); + ExecutorSubscribersScheduler(const ExecutorSubscribersScheduler &) = delete; + ExecutorSubscribersScheduler &operator=(const ExecutorSubscribersScheduler &) = delete; + ExecutorSubscriber &GetSubscriber(SubExeGraphType sub_exe_graph_type); + const std::vector &GetWorkingSubscribers() const { + return working_sub_exe_graph_subscribers_; + } + /** + * 为所有子图类型设置订阅者,订阅者需要实现一个static方法,原型为: + * ```c++ + * static void OnExecuteEvent(T *void_arg, ExecutorEvent event, const void *node, KernelStatus result); + * ``` + * + * 默认情况下,subscribers处于disable状态,在添加首个subscriber时,自动将状态切换到enable状态。 + * + * @tparam T 订阅者类型 + * @tparam Args 订阅者初始化参数类型 + * @param args 订阅者初始化参数 + * @return 添加的subscriber指针,注意subscriber所有权归`ExecutorSubscribersScheduler`所有,外部使用者不可以释放此指针 + */ + template + T *AddSubscriber(Args... 
args) { + auto ins = AddSubscriberGuarder(args...); + if (ins == nullptr) { + return nullptr; + } + for (size_t i = 0U; i < kSubExeGraphTypeEnd; ++i) { + sub_exe_graph_subscribers_[i].emplace_back(subscribers_holder_[subscribers_holder_.size() - 1U]); + } + return ins; + } + + /** + * 为指定子图类型设置订阅者,订阅者需要实现一个static方法,原型为: + * ```c++ + * static void OnExecuteEvent(T *void_arg, ExecutorEvent event, const void *node, KernelStatus result); + * ``` + * + * 默认情况下,subscribers处于disable状态,在添加首个subscriber时,自动将状态切换到enable状态。 + * + * @tparam T 订阅者类型 + * @param sub_exe_graph_type 子图类型 + * @tparam Args 订阅者初始化参数类型 + * @param args 订阅者初始化参数 + * @return 添加的subscriber指针,注意subscriber所有权归`ExecutorSubscribersScheduler`所有,外部使用者不可以释放此指针 + */ + template + T *AddSubscriber(SubExeGraphType sub_exe_graph_type, Args... args) { + auto ins = AddSubscriberGuarder(args...); + if (ins == nullptr) { + return nullptr; + } + sub_exe_graph_subscribers_[sub_exe_graph_type].emplace_back(subscribers_holder_[subscribers_holder_.size() - 1U]); + return ins; + } + + template + T *AddSubscriber(SubExeGraphType sub_exe_graph_type, const std::function &enabled_func, Args... args) { + auto ins = AddSubscriberGuarder(args...); + if (ins == nullptr) { + return nullptr; + } + auto &subscriber_guarder = subscribers_holder_[subscribers_holder_.size() - 1U]; + subscriber_guarder->SetEnabledFunc(enabled_func); + sub_exe_graph_subscribers_[sub_exe_graph_type].emplace_back(subscriber_guarder); + return ins; + } + + template + T *AddSubscriber(const std::function &enabled_func, Args... 
args) { + auto ins = AddSubscriberGuarder(args...); + if (ins == nullptr) { + return nullptr; + } + for (size_t i = 0U; i < kSubExeGraphTypeEnd; ++i) { + auto &subscriber_guarder = subscribers_holder_[subscribers_holder_.size() - 1U]; + subscriber_guarder->SetEnabledFunc(enabled_func); + sub_exe_graph_subscribers_[i].emplace_back(subscriber_guarder); + } + return ins; + } + /** + * 添加一个内置的subscriber + * 内置subscriber较少,当前没有使用注册机制,后续如果需要扩展,那么可以考虑通过注册机制自动注册。 + * 为了易用性,在本类提供了获取内置subscriber的指针的接口。而自注册的subscriber将丢失此能力。 + * @param subscriber_type + */ + template + void AddBuiltIn(BuiltInSubscriberType subscriber_type, uint64_t enable_flag, + const std::shared_ptr &extend_info, + SubExeGraphType sub_graph_type, const std::function &enabled_func) { + (void)enable_flag; + if (subscriber_type >= BuiltInSubscriberType::kNum) { + GELOGW("Unexpected built-in subscriber type %zu", static_cast(subscriber_type)); + return; + } + + auto subscriber_index = static_cast(subscriber_type); + if (built_in_subscribers_ptr_[subscriber_index] != nullptr) { + GELOGW("The built in subscriber %zu already exists, ignore the add operation", subscriber_index); + return; + } + + void *ins; + if (sub_graph_type == kSubExeGraphTypeEnd) { + ins = AddSubscriber(enabled_func, extend_info); + } else { + ins = AddSubscriber(sub_graph_type, enabled_func, extend_info); + } + built_in_subscribers_ptr_[subscriber_index] = ins; + } + void RemoveSubscriber(const void *subscriber_ptr) { + for (auto iter = subscribers_holder_.begin(); iter != subscribers_holder_.end(); ++iter) { + if ((*iter)->GetSubscriber().arg == subscriber_ptr) { + RemoveFromSubExeGraphSubscribers(subscriber_ptr); + subscribers_holder_.erase(iter); + break; + } + } + for (auto &built_in_subscriber : built_in_subscribers_ptr_) { + if (built_in_subscriber == subscriber_ptr) { + built_in_subscriber = nullptr; + } + } + if (subscribers_holder_.size() == static_cast(BuiltInSubscriberType::kNum)) { + enabled_ = false; + } + } + + template + inline 
T *MutableBuiltInSubscriber(const BuiltInSubscriberType type) { + return static_cast(built_in_subscribers_ptr_[static_cast(type)]); + } + + template + inline const T *GetBuiltInSubscriber(const BuiltInSubscriberType type) const { + return static_cast(built_in_subscribers_ptr_[static_cast(type)]); + } + + bool IsEnable() const { + return enabled_ || static_cast(GlobalProfilingWrapper::GetInstance()->GetEnableFlags()) || + (GlobalDumper::GetInstance()->IsEnableSubscribeDump()) || + static_cast(GlobalTracer::GetInstance()->GetEnableFlags()); + } + void SetEnable(bool enable_flag) { + enabled_ = enable_flag; + } + void Clear() { + subscribers_holder_.clear(); + for (auto &built_in_subscriber : built_in_subscribers_ptr_) { + built_in_subscriber = nullptr; + } + for (auto &subscribers_vec : sub_exe_graph_subscribers_) { + subscribers_vec.clear(); + } + enabled_ = false; + } + size_t GetSize() const { + return subscribers_holder_.size(); + } + private: + template + T *AddSubscriberGuarder(Args... 
args) { + auto ins = new (std::nothrow) T(args...); + if (ins == nullptr) { + return nullptr; + } + // profiler exists when ess init + if (subscribers_holder_.size() == static_cast(BuiltInSubscriberType::kNum)) { + enabled_ = true; + } + auto guarder = ge::MakeShared(reinterpret_cast<::SubscriberFunc>(T::OnExecuteEvent), + ins, ObjectDeleter); + if (guarder == nullptr) { + delete ins; + return nullptr; + } + subscribers_holder_.emplace_back(guarder); + return ins; + } + void RemoveFromSubExeGraphSubscribers(const void *subscriber_ptr) { + for (auto &subscribers_vec : sub_exe_graph_subscribers_) { + for (auto iter = subscribers_vec.begin(); iter != subscribers_vec.end(); ++iter) { + if (subscriber_ptr == (*iter)->GetSubscriber().arg) { + subscribers_vec.erase(iter); + return; + } + } + } + } + private: + bool enabled_{false}; + std::array(BuiltInSubscriberType::kNum)> + built_in_subscribers_ptr_; + std::array, kSubExeGraphTypeEnd> sub_exe_graph_subscribers_; + std::vector working_sub_exe_graph_subscribers_{}; + std::vector subscribers_holder_; + ExecutorSubscriber subscriber_wrapper_; +}; +} // namespace gert +#endif // AIR_CXX_INC_FRAMEWORK_RUNTIME_EXECUTOR_SUBSCRIBERS_SCHEDULER_H_ diff --git a/inc/air/framework/runtime/subscriber/global_dumper.h b/inc/air/framework/runtime/subscriber/global_dumper.h new file mode 100644 index 0000000000000000000000000000000000000000..a173760d392ea09f5c9e88a353a499c2ea44c77a --- /dev/null +++ b/inc/air/framework/runtime/subscriber/global_dumper.h @@ -0,0 +1,120 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef AIR_CXX_INC_FRAMEWORK_RUNTIME_EXECUTOR_GLOBAL_DUMPER_H_ +#define AIR_CXX_INC_FRAMEWORK_RUNTIME_EXECUTOR_GLOBAL_DUMPER_H_ +#include + +#include "built_in_subscriber_definitions.h" +#include "framework/common/ge_visibility.h" +#include "graph/compute_graph.h" +#include "runtime/base.h" +#include "common/debug/ge_log.h" + +namespace ge { +class ExceptionDumper; +} +namespace gert { +struct GlobalDumperSwitchHandler { + void *arg; + ge::Status (*on_event)(void *arg, uint64_t event); +}; +// for global info for exception_dump and global switch +class VISIBILITY_EXPORT GlobalDumper { + public: + GlobalDumper(const GlobalDumper &) = delete; + GlobalDumper(GlobalDumper &&) = delete; + GlobalDumper &operator=(const GlobalDumper &) = delete; + GlobalDumper &operator==(GlobalDumper &&) = delete; + static GlobalDumper *GetInstance(); + + static void OnGlobalDumperSwitch(void *ins, uint64_t enable_flags); + + void AddExceptionInfo(const rtExceptionInfo &exception_data); + + ge::ExceptionDumper *MutableExceptionDumper(); + + void SetEnableFlags(const uint64_t enable_flags) { + enable_flags_ = enable_flags; + PublishEvent(); + } + + void RegisterHandler(const void *key, GlobalDumperSwitchHandler handler) { + if (handler.on_event == nullptr) { + GELOGW("Register global dumper handler use a nullptr, ignore it"); + return; + } + + std::lock_guard lock(mutex_); + if (!keys_to_handler_.emplace(key, handler).second) { + return; + } + (void)handler.on_event(handler.arg, enable_flags_); + } + + void UnregisterHandler(const void * const key) { + const 
std::lock_guard lock(mutex_); + (void)keys_to_handler_.erase(key); + } + + size_t GetHandleSize() { + std::lock_guard lock(mutex_); + return keys_to_handler_.size(); + } + + // 判断是否有需要通过OnEvent接口做dump的Dump工具被打开 + bool IsEnableSubscribeDump() const { + return static_cast(enable_flags_ & + (BuiltInSubscriberUtil::EnableBit(DumpType::kNeedSubscribe) - 1UL)); + }; + + bool IsEnable(DumpType dump_type) const { + return static_cast(enable_flags_ & BuiltInSubscriberUtil::EnableBit(dump_type)); + } + + std::set &GetInnerExceptionDumpers() { + std::lock_guard lk(mutex_); + return exception_dumpers_; + } + + void ClearInnerExceptionDumpers() { + std::lock_guard lk(mutex_); + exception_dumpers_.clear(); + } + + void SaveInnerExceptionDumpers(ge::ExceptionDumper *exception_dumper) { + std::lock_guard lk(mutex_); + const auto ret = exception_dumpers_.insert(exception_dumper); + if (!ret.second) { + GELOGW("[Dumper] Save exception dumper of davinci model failed."); + } + } + + void RemoveInnerExceptionDumpers(ge::ExceptionDumper *exception_dumper) { + std::lock_guard lk(mutex_); + (void)exception_dumpers_.erase(exception_dumper); + } + + private: + GlobalDumper(); + uint64_t enable_flags_{0UL}; + // each davinci model has own exception dumper + std::set exception_dumpers_{}; + std::mutex mutex_; + std::map keys_to_handler_; + void PublishEvent() { + std::lock_guard lock(mutex_); + for (const auto &key_to_handler : keys_to_handler_) { + auto &handler = key_to_handler.second; + (void)handler.on_event(handler.arg, enable_flags_); + } + } +}; +} // namespace gert +#endif diff --git a/inc/air/framework/runtime/subscriber/global_profiler.h b/inc/air/framework/runtime/subscriber/global_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..e9198c022cfd1184a74f07c1f37bab961a9ab893 --- /dev/null +++ b/inc/air/framework/runtime/subscriber/global_profiler.h @@ -0,0 +1,426 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. 
+ * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef AIR_CXX_INC_FRAMEWORK_RUNTIME_SUBSCRIBER_GLOBAL_PROFILER_H_ +#define AIR_CXX_INC_FRAMEWORK_RUNTIME_SUBSCRIBER_GLOBAL_PROFILER_H_ + +#include +#include +#include +#include +#include "mmpa/mmpa_api.h" +#include "built_in_subscriber_definitions.h" +#include "common/debug/ge_log.h" +#include "framework/common/ge_visibility.h" +#include "runtime/subscriber/executor_subscriber_c.h" +#include "runtime/base.h" +#include "toolchain/prof_api.h" +#include "common/checker.h" +#include "graph/gnode.h" + +namespace gert { +constexpr uint32_t kTensorInfoBytes = 44UL; +constexpr uint32_t kTensorInfoBytesWithCap = 56U; +constexpr size_t kMaxContextIdNum = + (static_cast(MSPROF_ADDTIONAL_INFO_DATA_LENGTH) - sizeof(uint32_t) - sizeof(uint64_t)) / sizeof(uint32_t); +struct ProfilingData { + uint64_t name_idx; + uint64_t type_idx; + ExecutorEvent event; + std::chrono::time_point timestamp; + int64_t thread_id; +}; +enum class GeProfInfoType { + // model level + kModelExecute = MSPROF_REPORT_MODEL_GRAPH_ID_MAP_TYPE + 1, // 模型执行 + kModelLoad, // 模型加载 + kInputCopy, // input拷贝 + kOutputCopy, // output拷贝 + kModelLevelEnd, + // node level + // todo + kHostOpExec = MSPROF_REPORT_NODE_HOST_OP_EXEC_TYPE, + kInferShape = MSPROF_REPORT_NODE_GE_API_BASE_TYPE + 1, + kCompatibleInferShape, + kTiling, + kCompatibleTiling, + 
kStreamSync, + kStepInfo, + kNodeLevelEnd, + // acl level + kIsGraphNeedRebuild = MSPROF_REPORT_ACL_GRAPH_BASE_TYPE + 1, + kRemoveGraph, + kAddGraph, + kBuildGraph, + kRunGraphAsync, + kGEInitialize, + kGEFinalize, + kAclLevelEnd +}; + +struct ContextIdInfoWrapper { + MsprofAdditionalInfo context_id_info; + std::string op_name; +}; + +extern const std::unordered_map kNamesToProfTypes; +class GlobalProfiler { + public: + GlobalProfiler() = default; + void Record(uint64_t name_idx, uint64_t type_idx, ExecutorEvent event, + std::chrono::time_point timestamp) { + auto index = count_.fetch_add(1, std::memory_order_relaxed); + if (index >= kProfilingDataCap) { + return; + } + thread_local static auto tid = static_cast(mmGetTid()); + records_[index] = {name_idx, type_idx, event, timestamp, tid}; + } + void Dump(std::ostream &out_stream, std::vector &idx_to_str) const; + size_t GetCount() const { + return count_.load(); + } + + private: + std::atomic count_{0UL}; + ProfilingData records_[kProfilingDataCap]; +}; + +struct ProfFusionMemSize { + uint64_t input_mem_size{0UL}; + uint64_t output_mem_size{0UL}; + uint64_t weight_mem_size{0UL}; + uint64_t workspace_mem_size{0UL}; +}; + +class VISIBILITY_EXPORT GlobalProfilingWrapper { + public: + GlobalProfilingWrapper(const GlobalProfilingWrapper &) = delete; + GlobalProfilingWrapper(GlobalProfilingWrapper &&) = delete; + GlobalProfilingWrapper &operator=(const GlobalProfilingWrapper &) = delete; + GlobalProfilingWrapper &operator=(GlobalProfilingWrapper &&) = delete; + + static GlobalProfilingWrapper *GetInstance() { + static GlobalProfilingWrapper global_prof_wrapper; + return &global_prof_wrapper; + } + + static void OnGlobalProfilingSwitch(void *ins, uint64_t enable_flags); + + void Init(const uint64_t enable_flags); + + void Free() { + global_profiler_.reset(nullptr); + SetEnableFlags(0UL); + } + + GlobalProfiler *GetGlobalProfiler() const { + return global_profiler_.get(); + } + + void SetEnableFlags(const uint64_t 
enable_flags) { + enable_flags_.store(enable_flags); + } + + uint64_t GetRecordCount() { + if (global_profiler_ == nullptr) { + return 0UL; + } + return global_profiler_->GetCount(); + } + + uint64_t GetEnableFlags() const { + return enable_flags_.load(); + } + + bool IsEnabled(ProfilingType profiling_type) { + return enable_flags_.load() & BuiltInSubscriberUtil::EnableBit(profiling_type); + } + + void IncreaseProfCount() { + prof_count_++; + } + + uint64_t GetProfCount() const { + return prof_count_.load(); + } + + void DumpAndFree(std::ostream &out_stream) { + Dump(out_stream); + Free(); + } + void Dump(std::ostream &out_stream) { + if (global_profiler_ != nullptr) { + global_profiler_->Dump(out_stream, idx_to_str_); + } + } + void Record(uint64_t name_idx, uint64_t type_idx, ExecutorEvent event, + std::chrono::time_point timestamp) { + if (global_profiler_ != nullptr) { + global_profiler_->Record(name_idx, type_idx, event, timestamp); + } + } + + uint64_t RegisterString(const std::string &name); + + const std::vector &GetIdxToStr() const { + return idx_to_str_; + } + uint32_t GetProfModelId() const; + ge::Status RegisterExtendProfType(const std::string &name, const uint32_t idx) const; + void IncProfModelId(); + void RegisterBuiltInString(); + ge::Status RegisterProfType() const; + static ge::Status ReportEvent(const uint64_t item_id, const uint32_t request_id, const GeProfInfoType type, + MsprofEvent &prof_single_event); + static ge::Status ReportApiInfo(const uint64_t begin_time, const uint64_t end_time, const uint64_t item_id, + const uint32_t api_type); + static ge::Status ReportApiInfoModelLevel(const uint64_t begin_time, const uint64_t end_time, const uint64_t item_id, + const uint32_t api_type); + static ge::Status ReportTensorInfo(const uint32_t tid, const bool is_aging, const ge::TaskDescInfo &task_desc_info); + static void BuildNodeBasicInfo(const ge::OpDescPtr &op_desc, const uint32_t block_dim, + const std::pair &op_name_and_type_hash, const uint32_t 
task_type, + MsprofCompactInfo &node_basic_info); + static void BuildCompactInfo(const uint64_t prof_time, MsprofCompactInfo &node_basic_info); + static void BuildApiInfo(const std::pair &prof_time, const uint32_t api_type, + const uint64_t item_id, MsprofApi &api); + static void BuildContextIdInfo(const uint64_t prof_time, const std::vector &context_ids, + const size_t op_name, const std::string &op_name_str, + std::vector &infos); + + static ge::Status ReportGraphIdMap(const uint64_t prof_time, const uint32_t tid, + const std::pair graph_id_and_model_id, const bool is_aging, + const size_t model_name = 0UL); + static ge::Status ProfileStepTrace(const uint64_t index_id, const uint32_t model_id, const uint16_t tag_id, + const rtStream_t stream); + static void BuildSingleProfTensorInfo(const uint32_t tid, const ge::TaskDescInfo &task_desc_info, const size_t index, + const uint32_t tensor_num, MsprofAdditionalInfo &tensor_info); + + static void BuildFusionOpInfo(const ProfFusionMemSize &mem_size, const std::vector &origin_op_names, + const size_t op_name, std::vector &infos); + static ge::Status RecordAndReportMallocTaskMemoryInfo(const void *const addr, const size_t size, + const std::string &model_name); + static ge::Status RecordAndReportFreeTaskMemoryInfo(const void *const addr, const size_t size, + const std::string &model_name); + static unsigned int ReportLogicStreamInfo( + const uint64_t timestamp, const uint32_t tid, + const std::unordered_map> &logic_stream_ids_to_physic_stream_ids, + const uint16_t is_aging); + + static ge::Status ReportStaticOpMemInfo(const ge::ComputeGraphPtr &graph, const ge::OpDescPtr &op_desc, + const uint64_t mem_size, const uint64_t life_start, const uint64_t life_end); + + private: + GlobalProfilingWrapper(); + static ge::Status ReportTaskMemoryInfo(const std::string &model_name); + static void BuildSingleContextIdInfo(const uint64_t prof_time, const vector &context_ids, + const size_t index, const size_t context_id_num, 
MsprofAdditionalInfo &info); + + static void BuildProfFusionInfoBase(const ProfFusionMemSize &mem_size, const size_t fusion_op_num, + const size_t op_name, ProfFusionOpInfo *prof_fusion_data); + + private: + std::unique_ptr global_profiler_{nullptr}; + std::atomic enable_flags_{0UL}; + uint64_t str_idx_{0UL}; + bool is_builtin_string_registered_{false}; + std::vector idx_to_str_; + std::mutex register_mutex_; + std::mutex mutex_; + // rt2流程acl会给推理场景生成一个model id,但是静态子图没有办法获取这个model id,生成的davinci model + // 的model id为uint32_t的最大值,因此这种情况下,需要给它一个model id且生成model id的逻辑需要与acl一致 + std::atomic_uint32_t model_id_generator_{std::numeric_limits::max() / 2U}; + std::atomic prof_count_{0UL}; +}; + +class ScopeProfiler { + public: + ScopeProfiler(const size_t element, const size_t event) : element_(element), event_(event) { + if (GlobalProfilingWrapper::GetInstance()->IsEnabled(ProfilingType::kGeHost)) { + start_trace_ = std::chrono::system_clock::now(); + } + } + + void SetElement(const size_t element) { + element_ = element; + } + + ~ScopeProfiler() { + if (GlobalProfilingWrapper::GetInstance()->IsEnabled(ProfilingType::kGeHost)) { + GlobalProfilingWrapper::GetInstance()->Record(element_, event_, kExecuteStart, start_trace_); + GlobalProfilingWrapper::GetInstance()->Record(element_, event_, kExecuteEnd, std::chrono::system_clock::now()); + } + } + + private: + std::chrono::time_point start_trace_; + size_t element_; + size_t event_; +}; + +class GraphProfilingReporter { + public: + explicit GraphProfilingReporter(const GeProfInfoType api_id) : graphApi_(api_id) { + if (GlobalProfilingWrapper::GetInstance()->IsEnabled(ProfilingType::kTaskTime)) { + start_time_ = MsprofSysCycleTime(); + } + } + + ~GraphProfilingReporter() { + if (GlobalProfilingWrapper::GetInstance()->IsEnabled(ProfilingType::kTaskTime)) { + const uint64_t end_time_ = MsprofSysCycleTime(); + MsprofApi api{}; + api.beginTime = start_time_; + api.endTime = end_time_; + thread_local static auto tid = mmGetTid(); + 
api.threadId = static_cast(tid); + api.level = MSPROF_REPORT_ACL_LEVEL; + api.type = static_cast(graphApi_); + (void)MsprofReportApi(true, &api); + } + } + + private: + uint64_t start_time_ = 0UL; + const GeProfInfoType graphApi_; +}; + +class VISIBILITY_EXPORT ProfilerRegistry { + public: + ProfilerRegistry(const ProfilerRegistry &) = delete; + ProfilerRegistry(ProfilerRegistry &&) = delete; + ProfilerRegistry &operator=(const ProfilerRegistry &) = delete; + ProfilerRegistry &operator=(ProfilerRegistry &&) = delete; + + static ProfilerRegistry &GetInstance(); + + void SaveRegistryType(const std::string &type, const bool launch_flag); + bool IsProfLaunchType(const std::string &kernel_type, const bool launch_flag = true); + bool IsProfDavinciModelExecuteType(const std::string &kernel_type) const; + private: + ProfilerRegistry() noexcept = default; + std::vector register_prof_launch_type_{}; + std::vector register_prof_non_launch_type_{}; + std::mutex mutex_; +}; + +class ProfLaunchTypeRegistry { + public: + explicit ProfLaunchTypeRegistry(const std::string &type, const bool launch_flag) noexcept { + ProfilerRegistry::GetInstance().SaveRegistryType(type, launch_flag); + } +}; +} // namespace gert + +#define GE_PROFILING_START(event) \ + std::chrono::time_point event##start_time; \ + if (gert::GlobalProfilingWrapper::GetInstance()->IsEnabled(gert::ProfilingType::kGeHost)) { \ + event##start_time = std::chrono::system_clock::now(); \ + } + +#define GE_PROFILING_END(name_idx, type_idx, event) \ + do { \ + if (gert::GlobalProfilingWrapper::GetInstance()->IsEnabled(gert::ProfilingType::kGeHost)) { \ + gert::GlobalProfilingWrapper::GetInstance()->Record(name_idx, type_idx, ExecutorEvent::kExecuteStart, \ + event##start_time); \ + gert::GlobalProfilingWrapper::GetInstance()->Record(name_idx, type_idx, ExecutorEvent::kExecuteEnd, \ + std::chrono::system_clock::now()); \ + } \ + } while (false) + +#define CANN_PROFILING_API_END(item_id, info_type, event) \ + do { \ + if 
(gert::GlobalProfilingWrapper::GetInstance()->IsEnabled(gert::ProfilingType::kTaskTime)) { \ + const auto end_time = MsprofSysCycleTime(); \ + gert::GlobalProfilingWrapper::GetInstance()->ReportApiInfo(event##begin_time, end_time, item_id, info_type); \ + } \ + } while (false) + +#define CANN_PROFILING_API_START(event) \ + uint64_t event##begin_time; \ + if (gert::GlobalProfilingWrapper::GetInstance()->IsEnabled(gert::ProfilingType::kTaskTime)) { \ + event##begin_time = MsprofSysCycleTime(); \ + } + +#define CANN_PROFILING_MODEL_API_END(item_id, info_type, end_time, event) \ + do { \ + if (gert::GlobalProfilingWrapper::GetInstance()->IsEnabled(gert::ProfilingType::kTaskTime)) { \ + gert::GlobalProfilingWrapper::GetInstance()->ReportApiInfoModelLevel( \ + event##begin_time, end_time, item_id, info_type); \ + } \ + } while (false) + +#define CANN_PROFILING_MODEL_API_START(event) \ + uint64_t event##begin_time = 0UL; \ + if (gert::GlobalProfilingWrapper::GetInstance()->IsEnabled(gert::ProfilingType::kTaskTime)) { \ + event##begin_time = MsprofSysCycleTime(); \ + } + +#define CANN_PROFILING_EVENT_START(item_id, request_id, info_type, single_event) \ + do { \ + if (gert::GlobalProfilingWrapper::GetInstance()->IsEnabled(gert::ProfilingType::kTaskTime)) { \ + (void)gert::GlobalProfilingWrapper::ReportEvent(item_id, request_id, info_type, single_event); \ + } \ + } while (false) + +#define CANN_PROFILING_EVENT_END(item_id, request_id, info_type, single_event) \ + do { \ + if (gert::GlobalProfilingWrapper::GetInstance()->IsEnabled(gert::ProfilingType::kTaskTime)) { \ + const uint64_t prof_time = MsprofSysCycleTime(); \ + (single_event).timeStamp = prof_time; \ + (void)MsprofReportEvent(true, &(single_event)); \ + } \ + } while (false) + +#define CANN_PROFILING_INFER_EVENT_START(request_id, info_type, single_event) \ + do { \ + if (gert::GlobalProfilingWrapper::GetInstance()->IsEnabled(gert::ProfilingType::kTaskTime)) { \ + 
(void)gert::GlobalProfilingWrapper::ReportEvent(gert::GlobalProfilingWrapper::GetInstance()->GetProfModelId(), \ + request_id, info_type, single_event); \ + } \ + } while (false) + +#define CANN_PROFILING_INFER_EVENT_END(request_id, info_type, single_event) \ + do { \ + if (gert::GlobalProfilingWrapper::GetInstance()->IsEnabled(gert::ProfilingType::kTaskTime)) { \ + const uint64_t prof_time = MsprofSysCycleTime(); \ + (single_event).timeStamp = prof_time; \ + (void)MsprofReportEvent(true, &(single_event)); \ + } \ + } while (false) + +#define CANN_PROFILING_GRAPH_ID(prof_time, tid, graph_id, model_id, is_aging, model_name) \ + do { \ + if (gert::GlobalProfilingWrapper::GetInstance()->IsEnabled(gert::ProfilingType::kTaskTime)) { \ + (void)gert::GlobalProfilingWrapper::ReportGraphIdMap(prof_time, tid, {graph_id, model_id}, is_aging, \ + model_name); \ + } \ + } while (false) + +#define CANN_PROFILING_STEP_TRACE(item_id, request_id, tag_id, stream) \ + do { \ + (void)gert::GlobalProfilingWrapper::ProfileStepTrace(request_id, item_id, tag_id, stream); \ + } while (false) +#define REGISTER_PROF_TYPE(type) const gert::ProfLaunchTypeRegistry type##prof_type_registry(#type, true) +#define REGISTER_PROF_NON_LAUNCH_TYPE(type) const gert::ProfLaunchTypeRegistry type##prof_type_registry(#type, false) +#define GE_ASSERT_MSPROF_OK(v, ...) 
\ + GE_ASSERT((((v) == MSPROF_ERROR_NONE) || ((v) == MSPROF_ERROR_UNINITIALIZE)), __VA_ARGS__) +#define RT2_PROFILING_SCOPE(element, event) gert::ScopeProfiler profiler((element), event) +#define RT2_PROFILING_SCOPE_CONST(element, event) const gert::ScopeProfiler profiler((element), (event)) +#define RT2_PROFILING_SCOPE_ELEMENT(element) profiler.SetElement(element) +#define GRAPH_PROFILING_REG(api_id) const gert::GraphProfilingReporter profilingReporter(api_id) + +#define CANN_PROFILING_REPORT_STATIC_OP_MEM_INFO(graph, op_desc, mem_size, life_start, life_end) \ + do { \ + if (gert::GlobalProfilingWrapper::GetInstance()->IsEnabled(gert::ProfilingType::kMemory)) { \ + (void)gert::GlobalProfilingWrapper::ReportStaticOpMemInfo(graph, op_desc, mem_size, life_start, life_end); \ + } \ + } while (false) + +#endif diff --git a/inc/air/framework/runtime/subscriber/global_tracer.h b/inc/air/framework/runtime/subscriber/global_tracer.h new file mode 100644 index 0000000000000000000000000000000000000000..c5fc904956403d9baa82bc347baff117456735da --- /dev/null +++ b/inc/air/framework/runtime/subscriber/global_tracer.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef AIR_CXX_INC_FRAMEWORK_RUNTIME_EXECUTOR_GLOBAL_TRACER_H_ +#define AIR_CXX_INC_FRAMEWORK_RUNTIME_EXECUTOR_GLOBAL_TRACER_H_ +#include "framework/common/debug/ge_log.h" +#include "framework/common/ge_visibility.h" + +namespace gert { +class VISIBILITY_EXPORT GlobalTracer { + public: + static GlobalTracer *GetInstance() { + static GlobalTracer global_tracer; + return &global_tracer; + } + uint64_t GetEnableFlags() const { + return static_cast(IsLogEnable(GE_MODULE_NAME, DLOG_INFO)); + }; + private: + GlobalTracer() {}; + ~GlobalTracer() {}; + GlobalTracer(const GlobalTracer &) = delete; + GlobalTracer(GlobalTracer &&) = delete; + GlobalTracer &operator=(const GlobalTracer &) = delete; + GlobalTracer &operator=(GlobalTracer &&) = delete; +}; +} +#endif // AIR_CXX_INC_FRAMEWORK_RUNTIME_EXECUTOR_GLOBAL_TRACER_H_ diff --git a/inc/air/nneng/common/aicore_util_attr_define.h b/inc/air/nneng/common/aicore_util_attr_define.h new file mode 100644 index 0000000000000000000000000000000000000000..2a86508fd046428e6b7f4bd0c996096ff9d2156c --- /dev/null +++ b/inc/air/nneng/common/aicore_util_attr_define.h @@ -0,0 +1,266 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef FUSION_ENGINE_INC_COMMON_AICORE_UTIL_ATTR_DEFINE_H_ +#define FUSION_ENGINE_INC_COMMON_AICORE_UTIL_ATTR_DEFINE_H_ + +#include + +namespace fe { +// sgt attr +const std::string NON_TAIL_WORKSPACE_SIZE = "_non_tail_workspace_size"; + +const std::string ATTR_NAME_THREAD_CUBE_VECTOR_CORE_TYPE = "_thread_cube_vector_core_type"; + +const std::string ATTR_NAME_THREAD_TBE_KERNEL_SIZE = "_thread_tbeKernelSize"; + +const std::string ATTR_NAME_THREAD_TVM_CACHE_READ_MODE = "thread_tvm_cache_read_mode"; + +const std::string kThreadTaskRadio = "_thread_task_ratio"; + +const std::string kThreadModeInArgsFirstField = "_thread_mode_in_args_first_field"; + +constexpr char kThreadIntercoreSync[] = "_thread_mode_inter_core_sync"; + +const std::string kContextId = "_context_id"; + +const std::string kPrefetchEnableBm = "_prefetch_enable_bm"; + +const std::string kInvalidateBm = "_invalidate_bm"; + +const std::string kWriteBackBm = "_write_back_bm"; + +const std::string kAutoCtxIdList = "_context_id_list"; + +const std::string kAttrName64BytesFlag = "_align_64_bytes_flag"; + +const std::string SCOPE_ID_ATTR = "fusion_scope"; + +const std::string kPassNameAttr = "pass_name"; + +const std::string kPassNameUbAttr = "pass_name_ub"; + +const std::string kAttrNameIsFusionOp = "_is_fusion_op"; + +const std::string kOpConstValueList = "_const_value_list"; + +const std::string L1_SCOPE_ID_ATTR = "_l1_fusion_scope"; + +const std::string FAILED_L1_SCOPE_ID_ATTR = "_failed_l1_fusion_scope"; + +const std::string FE_IMPLY_TYPE = "_fe_imply_type"; + +const std::string PARENT_OP_TYPE = "parentOpType"; + +const std::string ATTR_NAME_TASK_L2_FUSION_INFO_EXTEND_PTR = "task_l2_fusion_info_extend_content"; + +const std::string ATTR_DATA_DUMP_REF = "_datadump_ref"; + +const std::string ATTR_NAME_L2_FUSION_EXTEND_PTR = "l2_fusion_extend_content"; + +const std::string 
ATTR_NAME_UNKNOWN_SHAPE = "_unknown_shape"; + +const std::string ATTR_NAME_OWNER_GRAPH_IS_UNKNOWN = "OwnerGraphIsUnknown"; + +const std::string ATTR_NAME_TBE_KERNEL_SIZE = "_tbeKernelSize"; + +const std::string ATTR_NAME_ALIAS_ENGINE_NAME = "_alias_engine_name"; +/* This is not same with the attr engine type. + * It reveals the op is cube operator or vector operator(implementation related). + * Firstly, it will be set by op information library and secondly + * it will be set after pre-compile (then using this attr to set compilation params) + * Finnaly, after compilation, it will be set by the compilation result. */ +const std::string ATTR_NAME_CUBE_VECTOR_CORE_TYPE = "_cube_vector_core_type"; + +const std::string kSgtCubeVectorCoreType = "_sgt_cube_vector_core_type"; + +const std::string kTaskRadio = "_task_ratio"; + +const std::string kModeInArgsFirstField = "_mode_in_args_first_field"; + +const std::string kForceFp32ToFp16 = "_force_fp32_to_fp16"; + +const std::string ATTR_NAME_KERNEL_LIST_FIRST_NAME = "_kernel_list_first_name"; + +const std::string ATTR_NAME_IS_FIRST_NODE = "is_first_node"; + +const std::string ATTR_NAME_IS_LAST_NODE = "is_last_node"; + +const std::string OP_SLICE_INFO = "_op_slice_info"; + +const std::string FUSION_OP_SLICE_INFO = "_fusion_op_slice_info"; + +const char *const SINGLE_SLICE_INFO = "_op_slice_info"; + +const char *const FUSION_SLICE_INFO = "_fusion_op_slice_info"; + +const std::string NEED_RE_PRECOMPILE = "need_re_precompile"; + +const std::string CORE_TYPE_VALUE = "core_type_value"; + +// set the attr by L2 Fusion +const std::string ATTR_NAME_LX_FUSION_PASS = "lx_fusion_pass"; + +const std::string ATTR_NAME_SUPPORT_DYNAMIC_SHAPE = "support_dynamicshape"; + +const std::string ATTR_NAME_IS_OP_DYNAMIC_IMPL = "_is_op_dynamic_impl"; + +const std::string kNeedSetSliceInfo = "_need_set_slice_info"; + +const std::string kThreadScopeId = "_thread_scope_id"; + +const std::string kThreadMode = "_thread_mode"; + +const std::string 
kTypeFFTSPlus = "_ffts_plus"; + +const std::string kNoMemReuse = "_no_mem_reuse"; + +const std::string kNoStreamSplit = "_no_stream_split"; + +const std::string kAttrEngineType = "_engine_type"; + +const std::string kLxNoReuseMemFlag = "no_reuse_mem_flag"; + +const std::string kReusedParam = "_param_reused"; + +const std::string kRelationReusedParam = "_param_reused_relation"; + +const std::string kCmoPrefetch = "Prefetch"; + +const std::string kCmoInvalid = "Invalid"; + +const std::string kCmoBarrier = "Barrier"; + +const std::string kCmoWriteback = "Writeback"; + +const std::string kOpExtattrNameCmo = "cmo_"; + +const std::string kAttrExtraParams = "_extra_params"; + +const std::string kOpShapeOrRangeUnsupport = "_shape_or_range_not_support"; + +const std::string kGlobalworkspaceBytes = "globalworkspace_spec_workspace_bytes"; + +const std::string kFusionOpBuildOptions = "fusion_op_build_options"; + +const std::string kConcatCOptimizeInfoPtr = "_concat_c_optimize_info_ptr"; + +const std::string kConcatCOptimizeInfoStr = "_concat_c_optimize_info_str"; + +const std::string kSplitCOptimizeInfoPtr = "_split_c_optimize_info_ptr"; + +const std::string kSplitCOptimizeInfoStr = "_split_c_optimize_info_str"; + +const std::string kGlobalworkspaceSize = "globalworkspace_size"; + +const std::string kGlobalworkspaceType = "globalworkspace_type"; + +const std::string kAtomicCleanTogetherMode = "0"; + +const std::string kInFixpipeSubGraph = "_in_fixipe_sub_graph"; + +const std::string kSameFixpipeNodeScope = "_fixpipe_scope"; + +const std::string kNotSupportFixpipeFusion = "_not_support_fixpipe_fusion"; + +const std::string kStaticToDynamicSoftSyncOp = "_static_to_dynamic_softsync_op"; + +const std::string kFFTSForceModeName = "_ffts_support_type"; + +const std::string kSoftsyncDynamicImpl = "_softsync_dynamic_impl"; + +const std::string kSoftSyncDynOriShape = "_softsync_dyn_ori_shape"; + +const std::string kSoftSyncDynShape = "_softsync_dyn_shape"; + +const std::string 
kAttrDynamicCompileStatic = "_dynamic_compile_static"; + +const std::string kAttrOpImplSwitch = "_op_impl_switch"; + +const std::string kAttrOpImplSwitchValue = "_op_impl_switch_value"; + +const std::string kSupportStridedwriteOptimize = "_support_stridedwrite_optimize"; + +const std::string kSupportStridedreadOptimize = "_support_stridedread_optimize"; + +const std::string kPhonyConcatOffset = "_phony_concat_offset"; + +const std::string kStrDynamic = "dynamic"; + +const std::string kFunc = "func"; + +const std::string kNodeSupportDynamicShape = "_node_support_dynamic_shape"; + +const std::string OP_JSON_FILE_PATH = "json_file_path"; + +const std::string ATTR_KEY_ATOMIC_KERNEL_NAME = "_atomic_kernelname"; + +constexpr char const *kPartSrcGraph = "part_src_graph"; + +constexpr char kAttrIntercoreSync[] = "_inter_core_sync"; + +constexpr char const * kMixEnhancedKernel = "_mix_with_enhanced_kernel"; + +extern const std::string kKernelName; + +extern const std::string kOpPattern; + +extern const std::string kCanNotReuseOm; + +extern const std::string AOE_TYPE; + +extern const std::string IS_NEED_TUNEFORMAT; + +extern const std::string AOE_TUNEFORMAT_REQ; + +extern const std::string AOE_TUNEFORMAT; + +extern const std::string ATTR_NAME_ATOMIC_CLEAN_NODE; + +extern const std::string ATTR_NAME_MEMSET_NODE; + +extern const std::string ATTR_NAME_ORIGINAL_NODE; + +extern const std::string ATOMIC_CLEAN_OP_TYPE; + +extern const std::string MEMSET_OP_TYPE; + +extern const std::string DYNAMIC_ATOMIC_CLEAN_OP_TYPE; + +extern const std::string TBE_OP_ATOMIC_OUTPUT_INDEX; + +extern const std::string kDeletedAtomicOutputIndex; + +extern const std::string TBE_OP_ATOMIC_WORKSPACE_INDEX; + +extern const std::string TBE_OP_ATOMIC_WSP_MODE; + +extern const std::string TBE_OP_ATOMIC_DTYPES; + +extern const std::string TBE_OP_ATOMIC_INT64_VALUES; + +extern const std::string TBE_OP_ATOMIC_FLOAT_VALUES; + +extern const std::string TBE_OP_ATOMIC_WORKSPACE_FLAG; + +extern const std::string 
COMPILE_INFO_JSON; + +extern const std::string COMPILE_INFO_KEY; + +extern const std::string kThreadAddrOffset; + +extern const std::string ATTR_NAME_FALLBACK_ACLNN; + +extern const std::string kDynamicTilingDependOp; + +extern const std::string kOpDfxOptions; + +extern const std::string kOpDfxBufferSize; +} // namespace fe +#endif // FUSION_ENGINE_INC_COMMON_AICORE_UTIL_ATTR_DEFINE_H_ diff --git a/inc/air/nneng/common/aicore_util_types.h b/inc/air/nneng/common/aicore_util_types.h new file mode 100644 index 0000000000000000000000000000000000000000..7034335596a41c68d03b52b1a099d199adb2a195 --- /dev/null +++ b/inc/air/nneng/common/aicore_util_types.h @@ -0,0 +1,319 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef FUSION_ENGINE_INC_COMMON_AICORE_UTIL_TYPES_H_ +#define FUSION_ENGINE_INC_COMMON_AICORE_UTIL_TYPES_H_ + +#include +#include +#include +#include "graph/anchor.h" +#include "graph/op_desc.h" +#include "runtime/kernel.h" + +namespace fe { +struct FusionOpSrc { + uint32_t src_op_id; + ge::AnchorPtr src_anchor; + int32_t fusion_src_index; + int32_t fusion_dst_index; +}; + +struct FusionOpDst { + uint32_t dst_op_id; + ge::AnchorPtr dst_anchor; +}; + +struct FusionDataFlow { + std::pair edge; + std::pair node_dataindex_pair; +}; + +using L2FusionData_t = struct tag_l2_fusion_data { + uint32_t l2Index; + uint64_t l2Addr; + uint64_t l2PageNum; +}; +using L2FusionDataMap_t = std::map; + +const uint32_t L2_MAXDATANUM = 8; +using fe_sm_desc_t = struct tag_fe_sm_desc { + rtL2Ctrl_t l2ctrl; + std::string node_name[L2_MAXDATANUM]; + uint8_t output_index[L2_MAXDATANUM]; +}; + +using TaskL2FusionInfo_t = struct TagTaskL2FusionInfo { + std::string node_name; + fe_sm_desc_t l2_info; + L2FusionDataMap_t input; + L2FusionDataMap_t output; + uint32_t is_used; +}; + +using L2FusionInfoPtr = std::shared_ptr; + +using ToOpStruct_t = struct ToOpStruct { + int64_t op_l1_space = 0; + std::vector op_l1_fusion_type; + int64_t op_l1_workspace_flag = 0; // for workspace flag + int64_t op_l1_workspace_size = 0; + std::vector> slice_input_shape; + std::vector> slice_output_shape; + std::vector> + slice_input_offset; // conv & pooling & ReadSelect + std::vector> slice_output_offset; // WriteSelect + std::vector total_shape; + uint32_t split_index = 0; + ToOpStruct() { + // set invalid value for essential variable + op_l1_space = -1; + op_l1_workspace_size = -1; + } +}; + +enum class BitShift { + BIT_SHIFT_8 = 8, + BIT_SHIFT_16 = 16, + BIT_SHIFT_24 = 24, + BIT_SHIFT_32 = 32, + BIT_SHIFT_40 = 40, + BIT_SHIFT_48 = 48, +}; + +struct OpCustomizeDtype { + std::vector 
input_dtypes; + std::vector output_dtypes; +}; + +enum SlicePattern { + ELEMENT_WISE = 0, + ELEMENT_WISE_BROADCAST, + BROADCAST, + SLIDING_WINDOW, + SLIDING_WINDOW_DECONV, + CUBE_MATMUL, + SLICE_PATTERN_REDUCE, + SLICE_PATTERN_RESIZE, + SLICE_PATTERN_SCATTER, + SLICE_PATTERN_SEGMENT, + PATTERN_RESERVED +}; + +enum AICoreMode { + FFTS_MODE_NO_FFTS = 0, + FFTS_MODE_FFTS, + FFTS_MODE_FFTS_PLUS, + FFTS_MODE_RESERVED +}; + +enum class FFTS_SUPPORT_TYPE { + FFTS_NO_FORCE = 0, + FFTS_FORCE_MANUAL, + FFTS_FORCE_AUTO +}; + +enum OpImplType : int64_t { + EN_IMPL_CUSTOM_CONSTANT_CCE = 0, // custom constant op + EN_IMPL_CUSTOM_TIK, // custom tik op + EN_IMPL_CUSTOM_TBE, // custom tbe op + EN_IMPL_HW_CONSTANT_CCE, // Huawei built-in constant op + EN_IMPL_HW_GENERAL_CCE, // Huawei built-in cce op + EN_IMPL_HW_TIK, // Huawei built-in tik op + EN_IMPL_HW_TBE, // Huawei built-in tbe op + EN_IMPL_RL, // RL op + EN_IMPL_PLUGIN_TBE, // Huawei built-in tbe plugin op + EN_IMPL_VECTOR_CORE_HW_TBE, // Huawei built-in tbe op + EN_IMPL_VECTOR_CORE_CUSTOM_TBE, // custom tbe op + EN_IMPL_NON_PERSISTENT_CUSTOM_TBE, // custom tbe op + EN_IMPL_HW_DSA, // Huawei built-in DSA op + EN_RESERVED, // reserved value + // sub type used for CUSTOM: sub_type = main_type | (priority << 32) + EN_SUBTYPE_RESERVED = 0xFFFFFFFFFF // reserved value +}; + +enum AOEOption { + AOE_OPT_USE_KB = 0, + AOE_OPT_NOT_USE_KB, + AOE_OPT_ONLY_USE_KB, + AOE_OPT_RESERVED +}; + +struct OptimizeConfig { + bool enable_superkernel_plus; + AOEOption aoe_option; +}; + +struct PassChangeInfo { + std::string pass_name; + int32_t recovery_times; + int32_t rollback_times; + vector recovery_scope_ids; +}; + +struct FEOpsStoreInfo { + int32_t priority; + std::string fe_ops_store_name; + OpImplType op_impl_type; + std::string cfg_file_path; + std::string op_impl_file_path; + bool need_pre_compile; + bool need_compile; + FEOpsStoreInfo() : priority(0), fe_ops_store_name(), op_impl_type(EN_RESERVED), cfg_file_path(), op_impl_file_path(), + 
need_pre_compile(false), need_compile(false) {} + FEOpsStoreInfo(int32_t priority_value, const std::string &ops_store_name_value, OpImplType op_impl_type_value, + const std::string &cfg_file_path_value, const std::string &op_impl_file_path_value, + bool need_pre_compile_value, bool need_compile_value) + : priority(priority_value), fe_ops_store_name(ops_store_name_value), op_impl_type(op_impl_type_value), + cfg_file_path(cfg_file_path_value), op_impl_file_path(op_impl_file_path_value), + need_pre_compile(need_pre_compile_value), need_compile(need_compile_value) {} + FEOpsStoreInfo(int32_t priority_value, const std::string &ops_store_name_value, OpImplType op_impl_type_value, + const std::string &cfg_file_path_value, const std::string &op_impl_file_path_value) + : priority(priority_value), fe_ops_store_name(ops_store_name_value), op_impl_type(op_impl_type_value), + cfg_file_path(cfg_file_path_value), op_impl_file_path(op_impl_file_path_value), + need_pre_compile(false), need_compile(false) {} +}; + +enum class ISAArchVersion { EN_ISA_ARCH_V100 = 0, EN_ISA_ARCH_V200, EN_ISA_ARCH_V220, EN_ISA_ARCH_V300, + EN_ISA_ARCH_V350 }; + +enum class CubeVecState { CUBE_VEC_UNKNOWN = 0, CUBE_VEC_DECOUPLE, CUBE_VEC_MIX }; // keep for canndev + +enum class CubeVecStateNew { CUBE_VEC_UNKNOWN = 0, CUBE_VEC_SPLIT, CUBE_VEC_FUSE }; + +enum class UBFusionType { FUSION_TYPE_UB = 0, FUSION_TYPE_MIXL2, FUSION_TYPE_NONE, FUSION_TYPE_RESERVED }; + +enum class AppendArgsMode { NO_ARGS = 0, L2_BUFFER_ARGS = 1, L2_CACHE_ARGS = 999}; + +enum BufferOptimize { EN_UNKNOWN_OPTIMIZE = 0, EN_OFF_OPTIMIZE, EN_L1_OPTIMIZE, EN_L2_OPTIMIZE }; + +enum AutoTuneMode { TUNE_MODE_NO_TUNE = 0, TUNE_MODE_AUTO_TUNE, TUNE_MODE_RL_TUNE, TUNE_MODE_AUTO_AND_RL_TUNE }; + +enum PrecisionPolicy { WHITE = 0, BLACK = 1, GRAY = 2 }; + +enum class JitCompileCfg { CFG_FALSE = 0, CFG_TRUE = 1, CFG_AUTO = 2}; + +enum OpPattern { + OP_PATTERN_OP_KERNEL = 0, + OP_PATTERN_OP_CUSTOMIZE, + OP_PATTERN_FORMAT_AGNOSTIC, + 
OP_PATTERN_BROADCAST, + OP_PATTERN_REDUCE, + OP_PATTERN_RANGE_AGNOSTIC, + OP_PATTERN_BROADCAST_ENHANCED +}; + +enum OpParamType { REQUIRED = 0, OPTIONAL, DYNAMIC, RESERVED }; + +enum OpConstValueDepend { CONST_IGNORE = 0, CONST_REQUIRED, CONST_OPTIONAL }; + +enum class DynamicCompileStatic { TUNE = 0, COMPILE, NOT_SUPPORT }; + +enum class DynamicRankType { NOT_SUPPORT = 0, SUPPORT, UPGRADE_TO_SUPPORT }; + +enum class VectorCoreType { DEFAULT = 0, ENABLE, DISABLE }; + +enum class JitCompile { + DEFAULT = 0, + ONLINE, // static_true, dynamic_true || true + REUSE_BINARY, // static_true, dynamic_false || false + STATIC_BINARY_DYNAMIC_ONLINE, // static_false, dynamic_true + STATIC_BINARY_DYNAMIC_BINARY // static_false, dynamic_false +}; + +enum class RangeLimitType { DEFAULT = 0, LIMITED, UNLIMITED, DYNAMIC }; + +const std::unordered_set kWeightTypes = {"Const", "Constant", "Variable"}; + +const std::unordered_set kConstTypes = {"Const", "Constant"}; + +const std::unordered_set kMixL2PassName = {"TbeConvBnreduceFusionPass", + "TbeConv2DBackpropElemwiseFusionPass"}; + +enum class CmoType { + CMO_TYPE_PREFETCH = 0, + CMO_TYPE_INVALID, + CMO_TYPE_BARRIER, + CMO_TYPE_WRITEBACK, + CMO_TYPE_BUTT +}; + +enum class CmoTypeObject { + INPUT = 0, + WEIGHT, + OUTPUT, + WORKSPACE +}; + +enum class L2CacheMode { + DEFAULT = 0, + RC, + CMO +}; + +enum class L2CacheReadMode { + RM_NONE = -1, + READ_LAST = 1, + READ_INVALID = 2, + NOT_NEED_WRITEBACK = 3 +}; + +enum class FormatModeType { + FORMAT_MODE_NZNZ = 0, + FORMAT_MODE_NDND = 1, + FORMAT_MODE_NDNZ = 2, + FORMAT_MODE_INVALID, +}; + +enum class WeightCompressType { + LOW_SPARSE_COMPRESS = 0, + HIGH_SPARSE_COMPRESS, + DISABLE_COMPRESS +}; + +enum class AclnnSupportType { + DEFAULT = 0, + SUPPORT_ACLNN, + ACLNN_ONLY +}; + +enum class MultiKernelSupportType { + DEFAULT = 0, + MULTI_SUPPORT, + SINGLE_SUPPORT +}; + +enum class ExportCompileStatType { + NONE = 0, + AFTER_EXEC_COMPLITE, + AFTER_COMPILE_COMPLITE +}; + +using CmoAttr = 
struct CMO_ATTR { + ge::NodePtr node; + CmoTypeObject object; + int32_t object_index; +}; + +using CmoExtraAttr = std::map>; + +inline bool IsTbe(const OpImplType& impl_type) +{ + OpImplType main_type = static_cast(impl_type & 0xFFFFFFFF); + return main_type == EN_IMPL_HW_TBE || main_type == EN_IMPL_VECTOR_CORE_HW_TBE|| main_type == EN_IMPL_CUSTOM_TBE || + main_type == EN_IMPL_VECTOR_CORE_CUSTOM_TBE || main_type == EN_IMPL_NON_PERSISTENT_CUSTOM_TBE ; +} + +template +Ret GetMainImplType(T a) +{ + return static_cast(a & 0xFFFFFFFF); +} +} +#endif // FUSION_ENGINE_INC_COMMON_AICORE_UTIL_TYPES_H_ diff --git a/inc/air/nneng/common/base_config_parser.h b/inc/air/nneng/common/base_config_parser.h new file mode 100644 index 0000000000000000000000000000000000000000..6d0dc3635647a25987f61a09f99bab0fd97a73b3 --- /dev/null +++ b/inc/air/nneng/common/base_config_parser.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef AIR_COMPILER_GRAPHCOMPILER_ENGINES_NNENG_INC_COMMON_BASE_CONFIG_PARSER_ +#define AIR_COMPILER_GRAPHCOMPILER_ENGINES_NNENG_INC_COMMON_BASE_CONFIG_PARSER_ + +#include +#include +#include "register/graph_optimizer/graph_optimize_register_error_codes.h" + +namespace fe { +class BaseConfigParser { +public: + BaseConfigParser() {} + virtual ~BaseConfigParser() {} + + virtual Status InitializeFromOptions(const std::map &options) { + (void)options; + return SUCCESS; + } + virtual Status InitializeFromContext() { + return SUCCESS; + } + virtual Status InitializeFromContext(const std::string &combined_params_key) { + (void)combined_params_key; + return SUCCESS; + } +}; +using BaseConfigParserPtr = std::shared_ptr; +} // namespace fe +#endif // AIR_COMPILER_GRAPHCOMPILER_ENGINES_NNENG_INC_COMMON_BASE_CONFIG_PARSER_ diff --git a/inc/air/nneng/common/configuration.h b/inc/air/nneng/common/configuration.h new file mode 100644 index 0000000000000000000000000000000000000000..36520e80b27b6c14f24745850a93df9995af75fc --- /dev/null +++ b/inc/air/nneng/common/configuration.h @@ -0,0 +1,478 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef FUSION_ENGINE_INC_COMMON_CONFIGURATION_H_ +#define FUSION_ENGINE_INC_COMMON_CONFIGURATION_H_ + +#include +#include +#include +#include +#include +#include "register/graph_optimizer/graph_optimize_register_error_codes.h" +#include "common/aicore_util_types.h" +#include "common/base_config_parser.h" + +namespace fe { +using std::string; +enum class CONFIG_PARAM { + SmallChannel = 0, + JitCompile, + VirtualType, + CompressWeight, + SparseMatrixWeight, + ReuseMemory, + BufferOptimize, + FormatMode, + QuantDumpable, + QuantBiasOptimize, + ExportCompileStat, + ConfigParamBottom +}; + +enum class CONFIG_STR_PARAM { + HardwareInfo = 0, + FusionLicense, + ConfigStrParamBottom +}; + +enum class ENV_STR_PARAM { + AscendOppPath = 0, + NetworkAnalysis, + DynamicImplFirst, + AscendCustomOppPath, + DumpGeGraph, + DumpGraphLevel, + NpuCollectPath, + MinCompileResourceUsageCtrl, + EnableAclnn, + AscendWorkPath, + EnableRt2, + EnvStrParamBottom +}; + +enum class CONFIG_PARSER_PARAM { + ImplMode = 0, + CustDtypes, + ModifyMixlist, + ConfigParserParamBottom +}; + +/** @brief Configuration. +* Used to manage all the configuration data within the fusion engine module. */ +class Configuration { + public: + Configuration(const Configuration &) = delete; + Configuration &operator=(const Configuration &) = delete; + + /** + * Get the Singleton Instance by engine name + */ + static Configuration &Instance(const string &engine_name); + + /** + * Initialize the content_map and ops_store_info_vector_ + * Read the content from the config file, the save the data into content_map + * Find the data about the Op Store Info from the content_map + * and build the ops_store_info_vector_ with them. + * @return Whether the object has been initialized successfully. 
+ */ + Status Initialize(const std::map &options); + + Status InitializeConfigParser(const std::map &options); + + Status InitializeExtend(const std::map &options); + + Status Finalize(); + + /** + * Find the FEOpsStoreInfo object with the OpImplType. + * @param op_impl_type ImplType of OP storeinfo + * @param op_store_info output value. + * if the object has been found, the op_store_info will refer to this object + * @return Whether the FEOpsStoreInfo object has been found. + * Status SUCCESS:found, FAILED:not found + */ + Status GetOpStoreInfoByImplType(OpImplType op_impl_type, FEOpsStoreInfo &op_store_info) const; + + /* + * to get the OpsStoreInfo out of the current configuration object + * @return the OpsStoreInfo + */ + const std::vector &GetOpsStoreInfo() const; + + void SetOpsStoreInfo(const FEOpsStoreInfo &fe_ops_store_info); + + /* + * get small channel + */ + bool IsEnableSmallChannel() const; + + /* + * get jit compile + */ + JitCompileCfg GetJitCompileCfg() const; + + /* + * get virtualization on compute capability + */ + bool IsEnableVirtualType() const; + + /* + * get compress weight + */ + bool IsEnableCompressWeight() const; + + /* + * get quant dumpable + */ + bool IsQuantDumpable() const; + + /* + * get compress sparse_matrix + */ + bool IsEnableSparseMatrixWeight() const; + + const std::map& GetCompressRatios() const; + + const float& GetAICoreCompressRatio() const; + + /* + * to get l1fusion option out of the current configuration object + * @return true/false + */ + bool EnableL1Fusion() const; + + bool EnableL2Fusion() const; + + bool GetDuplicationSwitch() const; + + bool IsEnableNetworkAnalysis() const; + + bool IsEnableOpImplStrategy() const; + + bool IsEnableUbFusion() const; + + bool IsEnableAclnn() const; + + bool IsEnableRt2() const; + + bool IsEnableFirstLayerQuantization() const; + + bool IsEnableCustomImplMode() const; + + bool IsEnableL2Buffer() const; + /* + * to get switch switch of dump original nodes to fusion node + * @return 
true/false + */ + bool GetDumpOriginalNodesEnable() const; + /* + * to get switch switch of mix_l2 + * @return true/false + */ + bool GetMixL2Enable() const; + + bool IsEnableSuperkernelPlus() const; + + /* + * to get the soc version out of the current configuration object + * @return soc version + */ + const string &GetSocVersion() const; + + /** + * Get the rootdir from configuration file + * @return root_path + */ + string GetRootPath(); + + static string GetPrecisionModeStr(); + + const string& GetLicenseFusionStr() const; + + bool IsEnableReuseMemory() const; + + bool IsEnableQuantBiasOptimize() const; + + bool IsConfigDebugListOp(const ge::OpDescPtr &op_des_ptr) const; + + string GetBuiltInFusionConfigFilePath() const; + + Status GetCompatibleBuiltInPath(string &builtin_pass_file_path) const; + /** + * Get the fusionpassmgr.graphpasspath from configuration file + * @return builtin_pass_file_path + */ + Status GetBuiltinPassFilePath(string &builtin_pass_file_path); + + /** + * Get the fusionpassmgr.custompasspath from configuration file + * @return custom_pass_file_path + */ + Status GetCustomPassFilePath(string &custom_pass_file_path); + + /** + * Get the fusionrulemgr.graphfilepath from configuration file + * @return graphfilepath + */ + Status GetGraphFilePath(string &graph_file_path); + + /** + * Get the fusionrulemgr.customfilepath from configuration file + * @return customfilepath + */ + Status GetCustomFilePath(string &custom_file_path); + + void GetCompilerGraphFilePath(string &graph_file_path); + + void GetCompilerPassFilePath(string &compiler_pass_file_path); + + void GetCompilerFusionConfigFilePath(std::string &fusion_config_file_path) const; + + bool IsInLicenseControlMap(const string &key) const; + + string GetFeLibPath() const; + + int32_t GetMemReuseDistThreshold() const; + + const std::unordered_set& GetFp16OpTypeList() const; + + const std::set& GetLicenseFusionDetailInfo() const; + + const std::string& GetDumpGeGraph() const; + + const 
std::string& GetDumpGraphLevel() const; + + const std::string GetOpDebugConfig() const; + + bool GetOpImplMode(const string &op_name, const string &op_type, string &op_impl_mode) const; + + bool GetCustomizeDtypeByOpType(const string &op_type, OpCustomizeDtype &custom_dtype); + + bool GetCustomizeDtypeByOpName(const string &op_name, OpCustomizeDtype &custom_dtype); + + FormatModeType GetFormatModeCfg() const; + + ExportCompileStatType GetExportCompileStat() const; + + PrecisionPolicy GetPrecisionPolicy(const std::string &op_type, const PrecisionPolicy &op_kernel_policy); + + const std::map& GetBinaryPathMap() const; + + std::string GetAllOpsImplPath() const; + + std::map GetHardwareInfo() const; + + bool IsDynamicImplFirst() const; + + Status RefreshParameters(); + + bool GetMemoryCheckSwitch() const; + + const std::string& GetBinaryConfigFilePath() const; + + const std::string& GetAscendWorkPath() const; + + uint64_t GetCompileTaskTraceTimeInterval() const; + + uint64_t GetCompileTaskTraceTimeConstThreshold() const; + + private: + explicit Configuration(const string &engine_name); + ~Configuration(); + bool is_init_; + string engine_name_; + string lib_path_; + std::map content_map_; + std::vector ops_store_info_vector_; + + bool is_dynamic_impl_first_; // env + bool enable_network_analysis_; + bool enable_op_impl_strategy_; + bool enable_ub_fusion_; + bool enable_aclnn_; + bool enable_rt2_ = true; + string ascend_ops_path_; // env + + std::array(ENV_STR_PARAM::EnvStrParamBottom)> env_str_param_vec_; + std::array(CONFIG_PARAM::ConfigParamBottom)> config_param_vec_; + std::array(CONFIG_STR_PARAM::ConfigStrParamBottom)> config_str_param_vec_; + std::array, + static_cast(CONFIG_PARSER_PARAM::ConfigParserParamBottom)> config_parser_map_vec_; + std::array(CONFIG_PARSER_PARAM::ConfigParserParamBottom)> config_parser_map_mutex_vec_; + + std::set license_fusion_detail_value_; + std::map hardware_info_map_; + + bool enable_first_layer_quantization_; + string bin_cfg_file_; 
+ std::unordered_set fp16_op_type_list_; + + int64_t op_store_priority_count_; + int32_t mem_reuse_dist_threshold_; + + std::map compress_ratios_; + float ai_core_compress_ratio_; + std::map op_binary_path_map_; + + BaseConfigParserPtr cust_dtypes_parser_; + BaseConfigParserPtr impl_mode_parser_; + BaseConfigParserPtr mix_list_parser_; + BaseConfigParserPtr op_debug_config_parse_; + std::map mix_list_parser_map_; + mutable std::mutex config_param_mutex_; + mutable std::mutex mix_list_parser_map_mutex_; + mutable std::mutex ops_store_info_vector_mutex_; + mutable std::mutex hardware_info_map_mutex_; + + void InitParamFromEnv(); + + /** + * Initialize the parameters from options + * @param options parameters map + */ + Status InitConfigParamFromOptions(const std::map &options); + + Status InitConfigParamFromContext(); + + std::string GetConfigStrParamValueFromContext(CONFIG_STR_PARAM config_str_param_enum_type) const; + + int64_t GetConfigParamValueFromContext(CONFIG_PARAM config_param_enum_type) const; + + std::string CombinedParamsKeyFromOptions(CONFIG_PARSER_PARAM config_parser_param_enum_type, + const std::map &options) const; + + std::string CombinedParamsKeyFromContext(CONFIG_PARSER_PARAM config_parser_param_enum_type) const; + + BaseConfigParserPtr GetConfigParserFromContext(CONFIG_PARSER_PARAM config_parser_param_enum_type, + const std::string &combined_params_key); + + Status RefreshImplMode(); + + void RefreshSmallChannel(); + + Status InitDebugConfigParam(); + + Status RefreshCustDtypes(); + + Status RefreshMixList(); + + /** + * Get the real Path of current so lib + */ + Status InitLibPath(); + + /** + * Get the real Path of ops + * path of ops is the path of so package + ops_relative_path + */ + void InitCustomOpStore(); + + Status InitAscendOpsPath(); + + bool IsPathExistedInOpp(const std::string &path, bool is_full_path) const; + + void ResolveBinaryPath(const std::string &sub_path, const std::string &path_type, const int64_t main_impl_type, + bool 
isOm, const std::string &binaryKey); + + bool AddCustomOpStoreContent(const std::string &full_or_sub_path, const std::string &path_type, + const int64_t main_impl_type, const bool is_full_path); + + bool CheckIsValidAbstractPath(const std::string &path) const; + + Status LoadOppConfigFile(); + + Status GetCustomOppPathFromOppConfigFile(std::vector &custom_opp_path_vec) const; + + /** + * Read the content of configuration file(FE_CONFIG_FILE_PATH) + * Save the data into content_map + * @return Whether the config file has been loaded successfully. + */ + Status LoadConfigFile(); + + /** + * Find the OpsStoreInfo from the content_map, + * then use the data to build up ops_store_info_vector. + * @return Whether the OpsStoreInfoVector has been built up successfully. + */ + Status AssembleOpsStoreInfoVector(); + + Status AssembleEachOpsStoreInfo(string &op_store_name, std::vector &op_store_vector, + FEOpsStoreInfo &ops_store_info); + + Status VerifyOpStoreVector(std::vector &op_store_vector, const string &op_store_name) const; + + bool IsIgnoreOpStore(const FEOpsStoreInfo &ops_store_info) const; + + Status CheckOpStoreInfo(const FEOpsStoreInfo &op_store_info) const; + + /** + * Check whether the content_map contain the input key. + * @param key + * @return Whether the content_map contain the input key. + */ + bool ContainKey(const string &key) const; + + /** + * Get the value from the content_map if the content_map contains the input key. + * @param key config key + * @param return_value output value. if the value has been found, + * return_value will refer to this value. + * @return Whether the value has been found. + * Status SUCCESS:found, FAILED:not found + */ + Status GetStringValue(const string &key, string &return_value) const; + + /** + * Find the value from the content_map by the input key, + * convert the value to bool type and return the bool value + * return the input default value if the value is not found.
+ * @param key + * @param default_value + * This value will be returned if the input key cannot be found in content_map. + * @return bool value + */ + bool GetBoolValue(const string &key, bool default_value) const; + + BufferOptimize GetBufferOptimize() const; + + void SetEnableSmallChannel(const bool &small_channel); + + void InitParametersOfConfigFile(); + + void InitISAArchVersion(); + + int32_t ParseDataVisitDistThreshold() const; + + void InitMemReuseDistThreshold(); + + void InitFp16OpType(); + + void InitCompressRatio(); + + void ParseHardwareInfo(); + + void ParseFusionLicense(const bool &is_config); + + Status ParseVirtualType(); + + void InitBinaryConfigFilePath(); + + std::vector ParseConfig(const string &key, char pattern) const; + + bool InitFirstLayerQuantization(const std::map &options); + + bool GetConfigValueByKey(const std::map &options, const string &file_key, + const string &cfg_key, string &value, string &file_path) const; +}; +} // namespace fe +#endif // FUSION_ENGINE_INC_COMMON_CONFIGURATION_H_ diff --git a/inc/air/nneng/common/fe_error_code.h b/inc/air/nneng/common/fe_error_code.h new file mode 100644 index 0000000000000000000000000000000000000000..7a0791c3b391d1071571169565058a6f512d9282 --- /dev/null +++ b/inc/air/nneng/common/fe_error_code.h @@ -0,0 +1,141 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef FUSION_ENGINE_INC_COMMON_FE_ERROR_CODE_H_ +#define FUSION_ENGINE_INC_COMMON_FE_ERROR_CODE_H_ + +#include + +namespace fe { + +/* args list for error message */ +const std::string PARAMETER0 = "parameter0"; +const std::string PARAMETER1 = "parameter1"; +const std::string PARAMETER2 = "parameter2"; +const std::string EM_VALUE = "value"; +const std::string EM_OPTION = "option"; +const std::string EM_AICORE_NUM = "ai_core_num"; +const std::string EM_NEW_OP = "new_op"; +const std::string EM_SRC_OP = "src_op"; +const std::string EM_SRC_FORMAT = "src_format"; +const std::string EM_DEST_OP = "dest_op"; +const std::string EM_DEST_FORMAT = "dest_format"; +const std::string EM_GRAPH_NAME = "graph_name"; +const std::string EM_FILE = "file"; +const std::string EM_ERROR_MSG = "errmsg"; +const std::string EM_PARAM = "parameter"; +const std::string EM_OP_NAME = "op_name"; +const std::string EM_OP_TYPE = "op_type"; +const std::string EM_GRAPH_ID = "graph_id"; +const std::string EM_THREAD_ID = "thread_id"; +const std::string EM_TASK_ID = "task_id"; +const std::string EM_PASS_NAME = "pass_name"; +const std::string EM_PASS_TYPE = "pass_type"; +const std::string EM_ATTR_OR_PATTERN_NAME = "attr_or_pattern_name"; +const std::string EM_OPS_STORE_NAME = "ops_store_name"; +const std::string EM_CORRECT_VALUE = "correct_value"; +const std::string EM_INDEX = "correct_value"; +const std::string EM_COMMON_ERROR = "common_error"; +const std::string EM_ORIGIN_DTYPE = "origin_dtype"; +const std::string EM_ENV = "env"; +const std::string EM_SITUATION = "situation"; + +/* Input parameter[--%s]'s value is empty! + * parameter + * */ +const std::string EM_INPUT_PARAM_EMPTY = "E10004"; + +const std::string EM_INNER_ERROR = "E29999"; +const std::string EM_INNER_WARN = "W29999"; + +/* Failed to compile Op [%s]. 
(optype: [%s]) + * parameter0,parameter1 + * */ +const std::string EM_COMPLIE_FAILED = "E20001"; + +/* Value [%s] for environment variable [%s] is invalid when %s. + * parameter0,parameter1,parameter2 + * */ +const std::string EM_ENVIRONMENT_VARIABLE_FAILED = "E20002"; + +/* Configuration file [%s] for parameter [%s] has invalid content. Reason: [%s] + * parameter0,parameter1,parameter2 + * */ +const std::string EM_INVALID_CONTENT = "E20003"; + +/* + * Failed to run graph fusion pass [%s]. The pass type is [%s] + * parameter0,parameter1 + * */ +const std::string EM_RUN_PASS_FAILED = "E20007"; + +/* Value [%s] for parameter [%s] is invalid. Reason: %s. + * parameter0,parameter1,parameter2 + * */ +const std::string EM_INPUT_OPTION_INVALID = "E20101"; + +/* Value [%s] for parameter [aicore_num] is invalid. The value must be in the range of (0, %s] + * parameter0,parameter1 + * */ +const std::string EM_AICORENUM_OUT_OF_RANGE = "E20103"; + +/* Failed to open file [%s]. + * parameter0 + * */ +const std::string EM_OPEN_FILE_FAILED = "E21001"; + +/* Failed to read file [%s]. Reason: %s. 
+ * parameter0,parameter1 + * */ +const std::string EM_READ_FILE_FAILED = "E21002"; + +const std::string EM_FAILED_TO_TOPO_SORTING = "E2100C"; + +const std::string EM_GRAPH_FUSION_FAILED = "E2100D"; + +const std::string EM_GRAPH_PASS_OWNER_INVALID = "E2100E"; + +const std::string EM_TAG_NO_CONST_FOLDING_FAILED = "E2100F"; + +const std::string EM_COMMON_NULL_PTR = "E21010"; + +const std::string EM_INVALID_IMPLEMENTATION = "E21011"; + +const std::string EM_INNER_ERROR_1 = "E21012"; + +const std::string EM_INVALID_OUTPUT_NAME_INDEX = "E21013"; + +const std::string EM_INVALID_TENSOR_NAME_INDEX = "E21014"; + +const std::string EM_FAILED_TO_ASSEMBLE_TBE_INFO = "E21015"; + +const std::string EM_FAILED_TO_ASSEMBLE_INPUT_INFO = "E21016"; + +const std::string EM_FAILED_TO_ASSEMBLE_OUTPUT_INFO = "E21017"; + +const std::string EM_INVALID_TENSOR_DATA_TYPE = "E21018"; + +const std::string EM_INVALID_ATTR_DATA_TYPE = "E21019"; + +const std::string EM_SELECT_OP_FORMAT = "E21019"; + +const std::string EM_FAILED_TO_PARSE_FORMAT_JSON = "E2101A"; + +const std::string EM_FAILED_TO_CONVERT_FORMAT_FROM_JSON = "E2101B"; + +const std::string EM_FORMAT_VECTOR_SIZE_INVALID = "E2101C"; + +const std::string EM_INVALID_FORMAT_IN_JSON = "E2101D"; + +const std::string EM_INVALID_DTYPE_IN_JSON = "E2101E"; + +const std::string EM_ORIGINAL_DATATYPE_IS_NOT_SUPPORTED = "E21020"; +} + +#endif // FUSION_ENGINE_INC_COMMON_FE_ERROR_CODE_H_ diff --git a/inc/air/nneng/common/fe_log.h b/inc/air/nneng/common/fe_log.h new file mode 100644 index 0000000000000000000000000000000000000000..a2cf6d971efae599e5d3e18f0296af51793947a5 --- /dev/null +++ b/inc/air/nneng/common/fe_log.h @@ -0,0 +1,148 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef FUSION_ENGINE_INC_COMMON_FE_LOG_H_ +#define FUSION_ENGINE_INC_COMMON_FE_LOG_H_ + +#include +#include +#include +#include +#include +#include "register/graph_optimizer/graph_optimize_register_error_codes.h" +#include "toolchain/slog.h" +#include "common/util/error_manager/error_manager.h" +#include "common/fe_error_code.h" + +/** Assigned FE name in log */ +const std::string FE_MODULE_NAME = "FE"; +inline uint64_t FeGetTid() { /* thread id from the gettid syscall, computed once per thread and cached in a thread_local */ + thread_local static uint64_t tid = static_cast(syscall(__NR_gettid)); + return tid; +} +/* D_FE_LOG{D,I,W,E}: forward to dlog_debug/info/warn/error under module id FE, prefixing the thread id and __FUNCTION__. MOD_NAME is accepted but not used in the expansion; fmt is stringified with #fmt. */ +#define D_FE_LOGD(MOD_NAME, fmt, ...) dlog_debug(FE, "%lu %s:" #fmt "\n", FeGetTid(), __FUNCTION__, ##__VA_ARGS__) +#define D_FE_LOGI(MOD_NAME, fmt, ...) dlog_info(FE, "%lu %s:" #fmt "\n", FeGetTid(), __FUNCTION__, ##__VA_ARGS__) +#define D_FE_LOGW(MOD_NAME, fmt, ...) dlog_warn(FE, "%lu %s:" #fmt "\n", FeGetTid(), __FUNCTION__, ##__VA_ARGS__) +#define D_FE_LOGE(MOD_NAME, fmt, ...) dlog_error(FE, "%lu %s:" #fmt "\n", FeGetTid(), __FUNCTION__, ##__VA_ARGS__) + +#define FE_LOGDEBUG(...) D_FE_LOGD(FE_MODULE_NAME, __VA_ARGS__) +#define FE_LOGI(...) D_FE_LOGI(FE_MODULE_NAME, __VA_ARGS__) +#define FE_LOGW(...) D_FE_LOGW(FE_MODULE_NAME, __VA_ARGS__) +#define FE_LOGE(...) D_FE_LOGE(FE_MODULE_NAME, __VA_ARGS__) + +constexpr const int FE_MAX_LOG_SIZE = 8192; +constexpr const int FE_MSG_LEN = 1024; +constexpr const size_t FE_MSG_MAX_LEN = FE_MSG_LEN - 200; +/* FE_LOGD: debug log with long-message splitting. Does nothing unless CheckLogLevel(FE, DLOG_DEBUG) returns 1. The message is first formatted into a local buffer of FE_MAX_LOG_SIZE bytes; if the result is shorter than FE_MSG_MAX_LEN it is logged in one FE_LOGDEBUG call, otherwise it is split at newline characters and each piece is logged in chunks of at most FE_MSG_MAX_LEN characters. */ +#define FE_LOGD(...)
\ + do { \ + if (CheckLogLevel(static_cast(FE), DLOG_DEBUG) != 1) { \ + break; \ + } \ + char msgbuf[FE_MAX_LOG_SIZE]; \ + if (snprintf_s(msgbuf, sizeof(msgbuf), sizeof(msgbuf) - 1, ##__VA_ARGS__) == -1) { \ + msgbuf[sizeof(msgbuf) - 1] = '\0'; \ + } \ + size_t msg_len = std::strlen(msgbuf); \ + if (msg_len < FE_MSG_MAX_LEN) { \ + FE_LOGDEBUG(__VA_ARGS__); \ + break; \ + } \ + char *msg_chunk_begin = msgbuf; \ + char *msg_chunk_end = nullptr; \ + while (msg_chunk_begin < msgbuf + msg_len) { \ + if (msg_chunk_begin[0] == '\n') { \ + FE_LOGDEBUG(""); \ + msg_chunk_begin += 1; \ + continue; \ + } \ + msg_chunk_end = std::strchr(msg_chunk_begin, '\n'); \ + if (msg_chunk_end == nullptr) { \ + msg_chunk_end = msg_chunk_begin + std::strlen(msg_chunk_begin); \ + } \ + while (msg_chunk_end > msg_chunk_begin) { \ + std::string msg_chunk(msg_chunk_begin, \ + std::min(FE_MSG_MAX_LEN, static_cast(msg_chunk_end - msg_chunk_begin))); \ + FE_LOGDEBUG("%s", msg_chunk.c_str()); \ + msg_chunk_begin += msg_chunk.size(); \ + } \ + msg_chunk_begin += 1; \ + } \ + } while (0) + +#define IsDebugLogLevel (CheckLogLevel(FE, DLOG_DEBUG) == 1) +/* FE_LOG{D,I,W,E}_IF: log at the given level only when cond is true. NOTE(review): these expand to a bare if statement (no do/while(0) wrapper), so an else written after the macro call would attach to it - confirm no caller relies on that. */ +#define FE_LOGD_IF(cond, ...) \ + if ((cond)) { \ + FE_LOGD(__VA_ARGS__); \ + } + +#define FE_LOGI_IF(cond, ...) \ + if ((cond)) { \ + FE_LOGI(__VA_ARGS__); \ + } + +#define FE_LOGW_IF(cond, ...) \ + if ((cond)) { \ + FE_LOGW(__VA_ARGS__); \ + } + +#define FE_LOGE_IF(cond, ...) \ + if ((cond)) { \ + FE_LOGE(__VA_ARGS__); \ + } +/* FE_CHECK: when cond is true, run log_func and then return_expr, in that order. */ +#define FE_CHECK(cond, log_func, return_expr) \ + do { \ + if (cond) { \ + log_func; \ + return_expr; \ + } \ + } while (0) + +// If failed to make_shared, then print log and execute the statement. +#define FE_MAKE_SHARED(exec_expr0, exec_expr1) \ + do { \ + try { \ + exec_expr0; \ + } catch (...)
{ \ + FE_LOGE("Make shared failed"); \ + exec_expr1; \ + } \ + } while (0) +/* FE_CHECK_NOTNULL / FE_CHECK_NOTNULL_WARNLOG: if val is nullptr, log the parameter name (error level / warning level respectively) and return fe::PARAM_INVALID from the enclosing function. */ +#define FE_CHECK_NOTNULL(val) \ + do { \ + if ((val) == nullptr) { \ + FE_LOGE("Parameter[%s] must not be null.", #val); \ + return fe::PARAM_INVALID; \ + } \ + } while (0) + +#define FE_CHECK_NOTNULL_WARNLOG(val) \ + do { \ + if ((val) == nullptr) { \ + FE_LOGW("Parameter[%s] must not be null.", #val); \ + return fe::PARAM_INVALID; \ + } \ + } while (0) +/* REPORT_FE_ERROR / REPORT_FE_WARN: pass fmt and its arguments to REPORT_INNER_ERROR and then log the same message with FE_LOGE / FE_LOGW. Both go through REPORT_INNER_ERROR; only the code differs. NOTE(review): EM_INNER_ERROR and EM_INNER_WARN are not defined in this header - confirm which included header provides them. */ +#define REPORT_FE_ERROR(fmt, ...) \ + do { \ + REPORT_INNER_ERROR(EM_INNER_ERROR, fmt, ##__VA_ARGS__); \ + FE_LOGE(fmt, ##__VA_ARGS__); \ + } while (0) + +#define REPORT_FE_WARN(fmt, ...) \ + do { \ + REPORT_INNER_ERROR(EM_INNER_WARN, fmt, ##__VA_ARGS__); \ + FE_LOGW(fmt, ##__VA_ARGS__); \ + } while (0) +#endif // FUSION_ENGINE_INC_COMMON_FE_LOG_H_ diff --git a/inc/air/nneng/common/lxfusion_json_util.h b/inc/air/nneng/common/lxfusion_json_util.h new file mode 100644 index 0000000000000000000000000000000000000000..b6872b1938c94fa396bb46eebb6d947f384bc03e --- /dev/null +++ b/inc/air/nneng/common/lxfusion_json_util.h @@ -0,0 +1,111 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef FUSION_ENGINE_INC_COMMON_LXFUSION_JSON_UTIL_H_ +#define FUSION_ENGINE_INC_COMMON_LXFUSION_JSON_UTIL_H_ + +#include +#include "common/aicore_util_types.h" +#include "common/aicore_util_attr_define.h" +#include "graph_optimizer/fusion_common/op_slice_info.h" +#include "graph_optimizer/graph_optimize_register_error_codes.h" +#include "graph/compute_graph.h" +/* nlohmann::json conversion functions (from_json / to_json) for rtSmData_t and rtSmDesc_t, declared at global scope, outside namespace fe. */ +void from_json(const nlohmann::json& json_value, rtSmData_t& rtSmData); +void from_json(const nlohmann::json& json_value, rtSmDesc_t& rtSmDesc); +void to_json(nlohmann::json& json_value, const rtSmData_t& rtSmData); +void to_json(nlohmann::json& json_value, const rtSmDesc_t& rtSmDesc); + +namespace fe { +using ToOpStructPtr = std::shared_ptr; +using L2FusionInfoPtr = std::shared_ptr; +extern const std::string L1_FUSION_TO_OP_STRUCT; +extern const std::string L2_FUSION_TO_OP_STRUCT; + +Status GetL1InfoFromJson(ge::OpDescPtr op_desc_ptr); + +Status GetL2InfoFromJson(ge::OpDescPtr op_desc_ptr); + +Status GetTaskL2FusionInfoFromJson(ge::OpDescPtr op_desc_ptr); + +Status ReadGraphInfoFromJson(ge::ComputeGraph &graph); + +Status WriteGraphInfoToJson(ge::ComputeGraph &graph); + +Status ReadOpSliceInfoFromJson(ge::ComputeGraph &graph); + +Status WriteOpSliceInfoToJson(ge::ComputeGraph &graph); + +void SetOpSliceInfoToJson(fe::OpCalcInfo &op_calc_info, std::string &op_calc_info_str); + +void SetFusionOpSliceInfoToJson(fe::OpCalcInfo &op_calc_info, std::string &op_calc_info_str); + +void GetOpSliceInfoFromJson(fe::OpCalcInfo &op_calc_info, std::string &op_calc_info_str); + +void GetFusionOpSliceInfoFromJson(fe::OpCalcInfo &op_calc_info, std::string &op_calc_info_str); + +void GetL2ToOpStructFromJson(ge::OpDescPtr &op_desc_ptr, ToOpStructPtr &l2_info_ptr); + +void GetL1ToOpStructFromJson(ge::OpDescPtr &op_desc_ptr, ToOpStructPtr &l1_info_ptr); + +void
GetStridedToOpStructFromJson(ge::OpDescPtr& op_desc_ptr, ToOpStructPtr& strided_info_ptr, + const string& pointer_key, const string& json_key); /* NOTE(review): string is unqualified here while the rest of the header writes std::string - presumably resolved by a using-declaration in an included header; confirm. */ + +L2FusionInfoPtr GetL2FusionInfoFromJson(ge::OpDescPtr &op_desc_ptr); + +void SetL2FusionInfoToNode(ge::OpDescPtr &op_desc_ptr, L2FusionInfoPtr &l2_fusion_info_ptr); + +void SetStridedInfoToNode(ge::OpDescPtr& op_desc_ptr, ToOpStructPtr& strided_info_ptr, + const string& pointer_key, const string& json_key); +/* nlohmann::json conversion functions (to_json / from_json) for the fusion and slice-info types, declared inside namespace fe. */ +void to_json(nlohmann::json& json_value, const InputReduceInfo& input_reduce_info); + +void to_json(nlohmann::json& json_value, const OutputReduceInfo& output_reduce_info); + +void to_json(nlohmann::json& json_value, const fe_sm_desc_t& fe_sm_desc); + +void to_json(nlohmann::json& json_value, const InputSplitInfo& input_split_info); + +void to_json(nlohmann::json& json_value, const OutputSplitInfo& output_split_info); + +void to_json(nlohmann::json& json_value, const ToOpStruct_t& op_struct); + +void to_json(nlohmann::json& json_value, const AxisSplitMap& axis_split_map); + +void to_json(nlohmann::json& json_value, const AxisReduceMap& axis_reduce_map); + +void to_json(nlohmann::json& json_value, const OpCalcInfo& op_calc_info); + +void to_json(nlohmann::json& json_value, const TaskL2FusionInfo_t& task_l2_fusion_info); + +void to_json(nlohmann::json& json_value, const L2FusionData_t& l2_fusion_data); + +void from_json(const nlohmann::json& json_value, L2FusionData_t& l2_fusion_data); + +void from_json(const nlohmann::json& json_value, AxisReduceMap& axis_reduce_map); + +void from_json(const nlohmann::json& json_value, InputReduceInfo& input_reduce_info); + +void from_json(const nlohmann::json& json_value, ToOpStruct_t& op_struct); + +void from_json(const nlohmann::json& json_value, OpCalcInfo& op_calc_info); + +void from_json(const nlohmann::json& json_value, OutputReduceInfo& output_reduce_info); + +void from_json(const nlohmann::json& json_value, fe_sm_desc_t& fe_sm_desc); + +void from_json(const nlohmann::json&
json_value, AxisSplitMap& axis_split_map); + +void from_json(const nlohmann::json& json_value, InputSplitInfo& input_split_info); + +void from_json(const nlohmann::json& json_value, TaskL2FusionInfo_t& task_l2_fusion_info); + +void from_json(const nlohmann::json& json_value, OutputSplitInfo& output_split_info); +} // namespace fe +#endif // FUSION_ENGINE_INC_COMMON_LXFUSION_JSON_UTIL_H_ diff --git a/inc/aoe/acg/aoe_error_manager.h b/inc/aoe/acg/aoe_error_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..1a520cdaabad3fe20808bc770d248a5d8678de7e --- /dev/null +++ b/inc/aoe/acg/aoe_error_manager.h @@ -0,0 +1,48 @@ +/** + * @file aoe_error_manager.h + * + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. + * + */ + +#ifndef AOE_ERROR_MANAGER_H +#define AOE_ERROR_MANAGER_H + +#include +#include +#include "common/util/error_manager/error_manager.h" +#include "aoe_types.h" + +namespace Aoe { +using ErrorContext = error_message::Context; +/* AoeErrorManager: error-message facade for AOE. Construction and destruction are private and copy/move are deleted, so the only access path is the static GetInstance(). */ +class AoeErrorManager { +public: + static AoeErrorManager &GetInstance(); + AoeStatus AoeInitErrorManager(const std::string &path); + AoeStatus AoeReportInterErrorMessage(const std::string &errorCode, const std::string &errorMsg) const; + AoeStatus AoeReportErrorMessage(const std::string &errorCode, + const std::map &argsMap) const; + std::string AoeGetErrorMessage() const; + AoeStatus AoeOutputErrorMessage(int32_t handle) const; + std::string AoeGetWarningMessage() const; + AoeStatus AoeOutputMessage(int32_t handle) const; + ErrorContext &AoeGetErrorContext() const; + void AoeSetErrorContext(const ErrorContext &context) const; + void AoeGenWorkStreamIdDefault() const; +private: + AoeErrorManager() {} + ~AoeErrorManager() {} + AoeErrorManager(const AoeErrorManager &) = delete; +
AoeErrorManager(AoeErrorManager &&) = delete; + AoeErrorManager &operator=(const AoeErrorManager &) = delete; + AoeErrorManager &operator=(AoeErrorManager &&) = delete; + bool isInit_{false}; +}; +} + +#endif diff --git a/inc/aoe/acg/progress.h b/inc/aoe/acg/progress.h new file mode 100644 index 0000000000000000000000000000000000000000..2202b0ee3968c83bd27fc2a912baadc11e1e2d0e --- /dev/null +++ b/inc/aoe/acg/progress.h @@ -0,0 +1,159 @@ +/** + * @file progress.h + * + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. + * + */ +#ifndef ACG_PROGRESS_H +#define ACG_PROGRESS_H +#include +#include +#include +#include +#include + +namespace Acg { +constexpr float MAX_PROGRESS_WEIGHT = 100; +constexpr float MID_PROGRESS_VALUE = 50; +constexpr float MAX_PROGRESS_VALUE = 100; +constexpr float DEFAULT_PROGRESS_VALUE = 0; + +enum class Status { + SUCCESS, + INVALID_PARAMETER, + NOT_FOUND_MARK, + OUT_OF_WEIGHT_RANGE, + OUT_OF_SIZE_LIMIT, +}; + +using ProgressReportFunc = std::function; +/* Weight/value pair of one progress node; default-constructed with weight = MAX_PROGRESS_WEIGHT and value = DEFAULT_PROGRESS_VALUE. */ +struct ProgressAttr { + ProgressAttr() : weight(MAX_PROGRESS_WEIGHT), value(DEFAULT_PROGRESS_VALUE) {}; + float weight; + float value; +}; + +class Progress : public std::enable_shared_from_this { +public: + explicit Progress(const std::string &tag) + : selfTag_(tag), prevTag_(""), progressFunc_(nullptr) {} + ~Progress(); + /** + * @brief set progress report function + * @param [in] func : progress report function + * @return success == SUCCESS; failed != SUCCESS + */ + Status SetProgressReport(const ProgressReportFunc &func); + + /** + * @brief set progress attribute + * @param [in] attr : progress attribute + * @return success == SUCCESS; failed != SUCCESS + */ + Status SetAttr(const ProgressAttr &attr); + + /** + * @brief associate previous progress + * @param [in]
prevTag : previous progress mark + * @return success == SUCCESS; failed != SUCCESS + */ + Status AssociateProgress(const std::string &prevTag); + + /** + * @brief report progress value and message + * @param [in] msg : progress message + * @param [in] value : progress value + * @return success == SUCCESS; failed != SUCCESS + */ + Status ReportProgress(const std::string &msg, float value = 0.0); + + /** + * @brief get progress mark + * @return progress mark + */ + const std::string GetSelfTag() const; + + /** + * @brief get progress attribute + * @return progress attribute (weight and value; value presumably in 0~100 - confirm) + */ + const ProgressAttr GetProgressAttr() const; + +private: + /** + * @brief calculate progress + * @param [in] value : progress value + * @return success is [0.00, 100.00]; other is failed + */ + float CaculateProgress(float value); + + /** + * @brief add sub progress + * @param [in] subProgress : sub progress + * @return success == SUCCESS; failed != SUCCESS + */ + Status AddSubProgress(const std::shared_ptr &subProgress); + + std::string selfTag_; + std::string prevTag_; + ProgressReportFunc progressFunc_; + ProgressAttr attr_; + std::unordered_set> subProgressSet_; + std::recursive_mutex rmtx_; + std::mutex mtx_; +}; + +using ProgressPtr = std::shared_ptr; + +/** + * @brief get progress of tag + * @param [in] tag : progress mark + * @return success != nullptr; failed == nullptr + */ +ProgressPtr GetProgress(const std::string &tag); + +/** + * @brief create progress by tag + * @param [in] tag : progress mark + * @return success != nullptr; failed == nullptr + */ +ProgressPtr CreateProgress(const std::string &tag); + +/** + * @brief associate progress of tag to previous tag + * @param [in] tag : progress mark + * @param [in] prevTag : previous progress mark + * @return success == SUCCESS; failed != SUCCESS + */ +Status AssociatedProgress(const std::string &tag, const std::string &prevTag); + +/** + * @brief destroy progress of tag + * @param [in] tag : progress mark + * @return None + */ +void
DestroyProgress(const std::string &tag); + +/** + * @brief report progress of tag + * @param [in] tag : progress mark + * @param [in] msg : progress update message + * @param [in] value : progress value + * @return success == SUCCESS; failed != SUCCESS + */ +Status ReportProgress(const std::string &tag, const std::string &msg, float value = 0.0); + +/** + * @brief display progress on the screen + * @param [in] msg : progress message + * @param [in] value : progress value + * @return None + */ +void DisplayProgressOnScreen(const std::string &msg, float value); +} +#endif diff --git a/inc/aoe/acg/report_manager.h b/inc/aoe/acg/report_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..6ee15f3544ee6b66082721764692a8a6f80aad3e --- /dev/null +++ b/inc/aoe/acg/report_manager.h @@ -0,0 +1,135 @@ +/** + * @file report_manager.h + * + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+ * + */ + +#ifndef AOE_REPORT_MANAGER_H +#define AOE_REPORT_MANAGER_H + +#include +#include +#include "nlohmann/json.hpp" +#include "aoe_types.h" + +namespace Aoe { +class Report; +using ReportPtr = std::shared_ptr; + +class ReportManager { +public: + ReportManager() = default; + ~ReportManager() = default; + /** + * @brief : Obtain ReportManager instance + * @return : ReportManager instance + */ + static ReportManager &GetInstance() + { + static ReportManager instance; + return instance; + } + + /** + * @brief : create report object + * filePath can be empty, single report cannot be saved at this time + * @param [in] : tag tag, unified use job_id + * @param [in] : filePath file Path to save + * @return : AOE_SUCCESS: == 0; AOE_FAILURE: != 0 + */ + AoeStatus CreateReport(const std::string &tag, const std::string &filePath); + + /** + * @brief : create report object together with a basic message (presumably submitted as the first message - confirm) + * filePath can be empty, single report cannot be saved at this time + * @param [in] : tag tag, unified use job_id + * @param [in] : filePath file Path to save; object: basic message object; priority: message priority + * @return : AOE_SUCCESS: == 0; AOE_FAILURE: != 0 + */ + AoeStatus CreateReportWithBasicMsg(const std::string &tag, const std::string &filePath, + const nlohmann::json &object, const uint32_t priority); + + /** + * @brief : submit report message + * @param [in] : tag tag, unified use job_id + * @param [in] : object report message object + * @param [in] : priority report message priority, limit: 0-100 + * @return : AOE_SUCCESS: == 0; AOE_FAILURE: != 0 + */ + AoeStatus SubmitReport(const std::string &tag, const nlohmann::json &object, const uint32_t priority); + + /** + * @brief : save report file + * @param [in] : tag tag, unified use job_id + * @return : AOE_SUCCESS: == 0; AOE_FAILURE: != 0 + */ + AoeStatus SaveReport(const std::string &tag); + + /** + * @brief : Update report path to the specified path + * filePath can be empty, report will save to current dir/aoe_result.json + * @param [in] : filePath file Path to save + * @return :
None (the function is declared void) + */ + void UpdateReportPath(const std::string &filePath); + + /** + * @brief : summarize all completed reports to the specified path + * filePath can be empty, report will save to current dir/aoe_result.json + * @note : takes no parameter; the path is presumably the one set through UpdateReportPath (confirm) + * @return : AOE_SUCCESS: == 0; AOE_FAILURE: != 0 + */ + AoeStatus SummaryReports(); + + /** + * @brief : clear all completed reports + * @return : None (the function is declared void) + */ + void ClearReports(); + + /** + * @brief update report of tag + * @param [in] tag : report mark + * @param [in] reportJson : report update message; priority : message priority + * @return AOE_SUCCESS: == 0; AOE_FAILURE: != 0 + */ + AoeStatus UpdateReport(const std::string &tag, const nlohmann::json &reportJson, const uint32_t priority); + + /** + * @brief update report result of tag + * @param [in] tag : report mark + * @param [in] tuneResult : report result string; priority : message priority + * @return AOE_SUCCESS: == 0; AOE_FAILURE: != 0 + */ + AoeStatus UpdateReportResult(const std::string &tag, const std::string &tuneResult, const uint32_t priority); + /** + * @brief change result of tag + * @param [in] tag : report mark + * @return AOE_SUCCESS: == 0; AOE_FAILURE: != 0 + */ + AoeStatus ChangeReportResult(const std::string &tag); + + /** + * @brief : Check whether new repo is added or updated.
+ * @param [in] : tag tag, unified use job_id + * @param [in] : priority report message priority, limit: 0-100 + * @return : true or false + */ + bool IsUpdateRepo(const std::string &tag, const uint32_t priority); + +private: + std::mutex mgrMtx_; + std::map runtimeReports_; + std::multimap historyReports_; + std::string resultPath_; +}; + +constexpr uint32_t MAX_PRIORITY = 100; // 90-100 reserved for aoe framework +constexpr uint32_t MAX_PLUGIN_PRIORITY = 90; // 0-90 reserved for aoe plug-in +} +#endif \ No newline at end of file diff --git a/inc/aoe/acg/trace.h b/inc/aoe/acg/trace.h new file mode 100644 index 0000000000000000000000000000000000000000..e21bc5c2a552664d1704ac67c5dc33d96a2347c4 --- /dev/null +++ b/inc/aoe/acg/trace.h @@ -0,0 +1,48 @@ +/** + * @file trace.h + * + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+ * + */ + +#ifndef AOE_TRACE_H +#define AOE_TRACE_H +#include +/* Trace: scoped trace marker - every constructor calls Start() and the destructor calls Stop(). NOTE(review): the (name, bool start) constructor stores start in start_ but still calls Start() unconditionally; presumably Start() checks start_ itself - confirm in the implementation. pidTag_ and tidTag_ are only set by the three-string constructor. */ +namespace Aoe { +class Trace { +public: + explicit Trace(const std::string &name) + : start_(true), stop_(false), name_(name) + { + Start(); + } + explicit Trace(const std::string &name, const std::string &pidTag, const std::string &tidTag) + : start_(true), stop_(false), name_(name), pidTag_(pidTag), tidTag_(tidTag) + { + Start(); + } + explicit Trace(const std::string &name, bool start) : start_(start), stop_(false), name_(name) + { + Start(); + } + ~Trace() + { + Stop(); + } + static void StartDumpTraceToPath(const std::string &recordPath); + void Start(); + void Stop(); +private: + bool start_; + bool stop_; + std::string name_; + std::string pidTag_; + std::string tidTag_; +}; +} +#endif diff --git a/inc/aoe/aoe_executor.h b/inc/aoe/aoe_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..6b37c8e625c158f302b6039429870a27f67fcc27 --- /dev/null +++ b/inc/aoe/aoe_executor.h @@ -0,0 +1,351 @@ +/** + * @file aoe_executor.h + * + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+ */ + +/** @defgroup aoe executor */ +#ifndef AOE_EXECUTOR_H +#define AOE_EXECUTOR_H +#include +#include +#include +#include +#include "graph/graph.h" +#include "graph/tuning_utils.h" +#include "aoe_types.h" + +namespace Aoe { +using ExecuteSession = uint64_t; +using ExecuteBufferData = struct AoeDataInfo; +using ExecuteResult = struct RunnerResult; +using ExecuteOnlineResult = struct RunnerRunResult; +/* Compile flavours; each one is a key of BUILD_MAP, which gives the ge build options used for it. */ +enum class CompileType { + COMPILE_TYPE_BASIC = 0, + COMPILE_TYPE_SPLIT, + COMPILE_TYPE_TUNING, + COMPILE_TYPE_TUNED, + COMPILE_TYPE_NORMAL +}; +/* Executor capability bits; they are OR-ed into uint32_t masks (see the EXE_*_EXECUTOR constants). */ +enum class ExecutorType { + EXE_DEFAULT = 0, // follow initialize executor type + EXE_OFFLINE_COMPILE = 1 << 0, // support offline compiler + EXE_OFFLINE_RUN = 1 << 1, // support offline runner + EXE_ONLINE_COMPILE = 1 << 2, // support online compiler + EXE_ONLINE_RUN = 1 << 3, // support online runner + EXE_REMOTE_COMPILE = 1 << 4, // use remote compiler + EXE_REMOTE_RUN = 1 << 5, // use remote runner + EXE_REMOTE_COMPILE_RUN = 1 << 6, // remote support build and run together + EXE_QTEST_RUN = 1 << 7, // support qtest runner + EXE_ACLNN_RUN = 1 << 8, // support aclnn runner +}; + +struct ExecuteInitConfig { + uint32_t executorType = static_cast(ExecutorType::EXE_DEFAULT); // executor type + std::map options; // init option +}; + +struct OfflineRunParam { + bool isProf = true; + uint32_t loop; /* NOTE(review): no default member initializer, unlike isProf and takeTimeUpperBound - confirm callers always set it */ + uint64_t takeTimeUpperBound = std::numeric_limits::max(); + std::string dynamicType; + std::string dynamicValue; + std::string failedFilePath; + std::string buildRunKey; + std::vector input; + std::vector output; + ExecuteResult costTime; + RunMode runMode = RunMode::DEFAULT_RUN_MODE; +}; + +struct OnlineRunParam { + bool isProf = true; + uint32_t loop; /* NOTE(review): no default member initializer - confirm callers always set it */ + uint64_t takeTimeUpperBound = std::numeric_limits::max(); + ge::Session *session = nullptr; + std::map options; + std::vector> inputs; + std::vector> outputs; + std::vector dependGraph; + std::vector costTime; + RunMode runMode = RunMode::DEFAULT_RUN_MODE; +}; + +constexpr
uint32_t EXE_OFFLINE_EXECUTOR = static_cast(ExecutorType::EXE_OFFLINE_COMPILE) | + static_cast(ExecutorType::EXE_OFFLINE_RUN); +constexpr uint32_t EXE_ONLINE_EXECUTOR = static_cast(ExecutorType::EXE_ONLINE_COMPILE) | + static_cast(ExecutorType::EXE_ONLINE_RUN); +constexpr uint32_t EXE_REMOTE_EXECUTOR = static_cast(ExecutorType::EXE_REMOTE_COMPILE) | + static_cast(ExecutorType::EXE_REMOTE_RUN); +constexpr uint32_t EXE_REMOTE_COMPILE_RUN_EXECUTOR = EXE_REMOTE_EXECUTOR | + static_cast(ExecutorType::EXE_REMOTE_COMPILE_RUN); +constexpr uint32_t EXE_QTEST_EXECUTOR = static_cast(ExecutorType::EXE_QTEST_RUN); +constexpr uint32_t EXE_ACLNN_EXECUTOR = static_cast(ExecutorType::EXE_ACLNN_RUN); +/* Maps each CompileType to its ge build options: BUILD_MODE for every flavour, plus BUILD_STEP for the split / tuning / tuned flavours. */ +const std::map> BUILD_MAP = { + {CompileType::COMPILE_TYPE_BASIC, { + {ge::BUILD_MODE, ge::BUILD_MODE_BASELINE}} + }, + {CompileType::COMPILE_TYPE_NORMAL, { + {ge::BUILD_MODE, ge::BUILD_MODE_NORMAL}} + }, + {CompileType::COMPILE_TYPE_SPLIT, { + {ge::BUILD_MODE, ge::BUILD_MODE_TUNING}, + {ge::BUILD_STEP, ge::BUILD_STEP_BEFORE_UB_MATCH}} + }, + {CompileType::COMPILE_TYPE_TUNING, { + {ge::BUILD_MODE, ge::BUILD_MODE_TUNING}, + {ge::BUILD_STEP, ge::BUILD_STEP_AFTER_UB_MATCH}} + }, + {CompileType::COMPILE_TYPE_TUNED, { + {ge::BUILD_MODE, ge::BUILD_MODE_TUNING}, + {ge::BUILD_STEP, ge::BUILD_STEP_AFTER_MERGE}} + }, +}; + +// this set for executor init option +const std::set EXECUTOR_INIT_OPTION_SET = { + SERVER_IP, + SERVER_PORT, + DEVICE, + CORE_TYPE, + BUFFER_OPTIMIZE, + ENABLE_COMPRESS_WEIGHT, + COMPRESS_WEIGHT_CONF, + PRECISION_MODE, + DISABLE_REUSE_MEMORY, + ENABLE_SINGLE_STREAM, + SOC_VER, + AICORE_NUM, + GROUP_ID, + FUSION_SWITCH_FILE, + ENABLE_SMALL_CHANNEL, + OP_SELECT_IMPL_MODE, + OPTYPELIST_FOR_IMPLMODE, + ENABLE_SCOPE_FUSION_PASSES, + OP_DEBUG_LEVEL, + VIRTUAL_TYPE, + SPARSITY, + MODIFY_MIXLIST, + CUSTOMIZE_DTYPES, + FRAMEWORK, + COMPRESSION_OPTIMIZE_CONF, + + QTEST_SOC_VERSION, + SOC_VERSION, + TUNE_DEVICE_IDS, + EXEC_DISABLE_REUSED_MEMORY, + AUTO_TUNE_MODE, +
OP_COMPILER_CACHE_MODE, + OP_COMPILER_CACHE_DIR, + EXTERNAL_WEIGHT, + DETERMINISTIC, + OPTION_HOST_ENV_OS, + OPTION_HOST_ENV_CPU, + HOST_ENV_OS, + HOST_ENV_CPU, + DEBUG_DIR, + OP_TUNE_MODE, + OPTION_SCREEN_PRINT_MODE, + INIT_BYPASS, +}; +/* this set for executor compile option */ +const std::set EXECUTOR_COMPILE_OPTION_SET = { + INPUT_FORMAT, + INPUT_SHAPE, + INPUT_SHAPE_RANGE, + OP_NAME_MAP, + DYNAMIC_BATCH_SIZE, + DYNAMIC_IMAGE_SIZE, + DYNAMIC_DIMS, + PRECISION_MODE, + OUTPUT_TYPE, + INPUT_FP16_NODES, + OUT_NODES, + LOG_LEVEL, + OP_DEBUG_LEVEL, + INSERT_OP_FILE, + DISABLE_REUSE_MEMORY, + GE_INPUT_SHAPE_RANGE, + OPAT_BUILD_RUN_KEY, + ge::BUILD_MODE, + ge::BUILD_STEP, + ge::TUNING_PATH, + ge::ir_option::OP_BANK_UPDATE, + OP_PRECISION_MODE, + QTEST_SOC_VERSION, + HOST_ENV_OS, + HOST_ENV_CPU, + + EXEC_DISABLE_REUSED_MEMORY, + AUTO_TUNE_MODE, + OP_COMPILER_CACHE_MODE, + OP_COMPILER_CACHE_DIR, + DEBUG_DIR, + MDL_BANK_PATH, + OP_BANK_PATH, + MODIFY_MIXLIST, + SHAPE_GENERALIZED_BUILD_MODE, + OP_DEBUG_CONFIG, + EXTERNAL_WEIGHT, + EXCLUDE_ENGINES, + ge::OPTION_GRAPH_RUN_MODE, + DETERMINISTIC, + IMPL_MODE, + BUILD_BYPASS, +}; + +/** + * @ingroup aoe executor + * @brief aoe executor initialize + * @param [in] ExecuteInitConfig &config config + * @return success == AOE_SUCCESS; failed != AOE_SUCCESS + */ +extern "C" AoeStatus AoeExecutorInitialize(ExecuteInitConfig &config); + +/** + * @ingroup aoe executor + * @brief aoe executor finalize + * @return success == AOE_SUCCESS; failed != AOE_SUCCESS + */ +extern "C" AoeStatus AoeExecutorFinalize(); + +/** + * @ingroup aoe executor + * @brief aoe executor create session + * @param [out] ExecuteSession &session session + * @return success == AOE_SUCCESS; failed != AOE_SUCCESS + */ +extern "C" AoeStatus AoeExecutorCreateSession(ExecuteSession &session); + +/** + * @ingroup aoe executor + * @brief aoe executor destroy session + * @param [in] ExecuteSession &session session + * @return success == AOE_SUCCESS; failed != AOE_SUCCESS + */ +extern "C" AoeStatus
AoeExecutorDestroySession(const ExecuteSession &session); + +/** + * @ingroup aoe executor + * @brief aoe executor save model + * @param [in] const ExecuteSession &session session + * @param [in] std::string &omPath save model path + * @return success == AOE_SUCCESS; failed != AOE_SUCCESS + */ +extern "C" AoeStatus AoeExecutorSaveModel(const ExecuteSession &session, const std::string &omPath); + +/** + * @ingroup aoe executor + * @brief aoe executor load model + * @param [in] const ExecuteSession &session session + * @param [in] std::string &omPath load model path + * @return success == AOE_SUCCESS; failed != AOE_SUCCESS + */ +extern "C" AoeStatus AoeExecutorLoadModel(const ExecuteSession &session, const std::string &omPath); + +/** + * @ingroup aoe executor + * @brief aoe executor unload model + * @param [in] const ExecuteSession &session session + * @return success == AOE_SUCCESS; failed != AOE_SUCCESS + */ +extern "C" AoeStatus AoeExecutorUnloadModel(const ExecuteSession &session); +/** + * @ingroup aoe executor + * @brief aoe executor offline compile graph by session (note: the exported name is spelled Offilne) + * @param [in] const ExecuteSession &session session + * @param [in] const ge::Graph &graph compile graph + * @param [in] const std::map &compileOptions compile option + * @return success == AOE_SUCCESS; failed != AOE_SUCCESS + */ +extern "C" AoeStatus AoeExecutorOffilneCompile(const ExecuteSession &session, const ge::Graph &graph, + const std::map &compileOptions); + +/** + * @ingroup aoe executor + * @brief aoe executor offline run graph by session + * @param [in] const ExecuteSession &session session + * @param [out] OfflineRunParam &param run param + * @return success == AOE_SUCCESS; failed != AOE_SUCCESS + */ +extern "C" AoeStatus AoeExecutorOfflineRun(const ExecuteSession &session, OfflineRunParam ¶m); + +/** + * @ingroup aoe executor + * @brief aoe executor offline split graph (note: the exported name is spelled Spilt) + * @param [in] const ge::Graph &graph graph to split + * @param [in] const std::string &subgraphPath split graph path + *
@param [in] const std::map &splitOptions split graph option + * @return success == AOE_SUCCESS; failed != AOE_SUCCESS + */ +extern "C" AoeStatus AoeExecutorOfflineSpiltGraph(const ge::Graph &graph, const std::string &subgraphPath, + const std::map &splitOptions); + +/** + * @ingroup aoe executor + * @brief aoe executor offline compile run async + * @param [in] const ExecuteSession &session session + * @param [in] const ge::Graph &graph compile graph + * @param [in] const std::map &compileOptions compile option + * @param [out] OfflineRunParam &param offline param + * @return std::future, presumably of AoeStatus (confirm): success == AOE_SUCCESS; failed != AOE_SUCCESS + */ +extern "C" std::future AoeExecutorOfflineCompileRunAsync(const ExecuteSession &session, + const ge::Graph &graph, const std::map &compileOptions, OfflineRunParam ¶m); + +/** + * @ingroup aoe executor + * @brief aoe executor qtest compile run async + * @param [in] const ExecuteSession &session session + * @param [in] const ge::Graph &graph compile graph + * @param [in] const std::map &compileOptions compile option + * @param [out] OfflineRunParam &param offline param + * @return std::future, presumably of AoeStatus (confirm): success == AOE_SUCCESS; failed != AOE_SUCCESS + */ +extern "C" std::future AoeExecutorQtestCompileRunAsync(const ExecuteSession &session, + const ge::Graph &graph, const std::map &compileOptions, OfflineRunParam ¶m); + +/** + * @ingroup aoe executor + * @brief aoe executor online run graph + * @param [in] const ge::Graph &graph run graph + * @param [out] OnlineRunParam &param run param + * @return success == AOE_SUCCESS; failed != AOE_SUCCESS + */ +extern "C" AoeStatus AoeExecutorOnlineRun(const ge::Graph &graph, OnlineRunParam ¶m); + +/** + * @ingroup aoe executor + * @brief aoe executor qtest run graph + * @param [in] const ge::Graph &graph run graph + * @param [in] const std::string &graphPath graph path + * @param [out] OfflineRunParam &param run param + * @return success == AOE_SUCCESS; failed != AOE_SUCCESS + */ +extern "C" AoeStatus AoeExecutorQtestRun(const ge::Graph &graph, const
std::string &graphPath, OfflineRunParam ¶m); + +/** + * @ingroup aoe executor + * @brief aoe executor online split graph (note: the exported name is spelled Spilt) + * @param [in] const ge::Graph &graph graph to split + * @param [in] ge::Session *session ge session + * @param [in] const std::string &subgraphPath split graph path + * @param [in] const std::map &splitOptions split graph option + * @param [in] const std::vector &inputs graph inputs + * @return success == AOE_SUCCESS; failed != AOE_SUCCESS + */ +extern "C" AoeStatus AoeExecutorOnlineSpiltGraph(const ge::Graph &graph, ge::Session *session, + const std::string &subgraphPath, const std::map &splitOptions, + const std::vector &inputs); +/* aclnn runner entry point: takes an AclnnRunConfig, writes into a RunnerResult and returns a std::future (asynchronous - presumably; confirm against the implementation). */ +extern "C" std::future AoeExecutorAclnnRun(const AclnnRunConfig ¶m, RunnerResult &result); +} +#endif diff --git a/inc/aoe/aoe_tuning_api.h b/inc/aoe/aoe_tuning_api.h new file mode 100644 index 0000000000000000000000000000000000000000..593e4f0612c537511a5fd33db43b134f66fa4a7f --- /dev/null +++ b/inc/aoe/aoe_tuning_api.h @@ -0,0 +1,149 @@ +/** + * @file aoe_tuning_api.h + * + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2023.
All rights reserved.\n + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n + * + */ + +#ifndef AOE_TUNING_API_H +#define AOE_TUNING_API_H + +#include +#include "aoe_types.h" +#include "ge/ge_api.h" +#include "external/aoe.h" + +namespace Aoe { +// this set for global option key +const std::set GLOBAL_OPTION_SET = { + ge::AscendString(WORK_PATH), + ge::AscendString(SERVER_IP), + ge::AscendString(SERVER_PORT), + ge::AscendString(TUNING_PARALLEL_NUM), + ge::AscendString(DEVICE), + ge::AscendString(CORE_TYPE), + ge::AscendString(BUFFER_OPTIMIZE), + ge::AscendString(ENABLE_COMPRESS_WEIGHT), + ge::AscendString(COMPRESS_WEIGHT_CONF), + ge::AscendString(PRECISION_MODE), + ge::AscendString(DISABLE_REUSE_MEMORY), + ge::AscendString(ENABLE_SINGLE_STREAM), + ge::AscendString(AICORE_NUM), + ge::AscendString(GROUP_ID), + ge::AscendString(FUSION_SWITCH_FILE), + ge::AscendString(ENABLE_SMALL_CHANNEL), + ge::AscendString(OP_SELECT_IMPL_MODE), + ge::AscendString(OPTYPELIST_FOR_IMPLMODE), + ge::AscendString(ENABLE_SCOPE_FUSION_PASSES), + ge::AscendString(OP_DEBUG_LEVEL), + ge::AscendString(VIRTUAL_TYPE), + ge::AscendString(SPARSITY), + ge::AscendString(MODIFY_MIXLIST), + ge::AscendString(CUSTOMIZE_DTYPES), + ge::AscendString(FRAMEWORK), + ge::AscendString(JOB_TYPE), + ge::AscendString(RUN_LOOP), + ge::AscendString(RESOURCE_CONFIG_PATH), + + ge::AscendString(SOC_VERSION), + ge::AscendString(TUNE_DEVICE_IDS), + ge::AscendString(EXEC_DISABLE_REUSED_MEMORY), + ge::AscendString(AUTO_TUNE_MODE), + ge::AscendString(OP_COMPILER_CACHE_MODE), + ge::AscendString(OP_COMPILER_CACHE_DIR), + ge::AscendString(DEBUG_DIR), + ge::AscendString(EXTERNAL_WEIGHT), + ge::AscendString(DETERMINISTIC), + ge::AscendString(OPTION_HOST_ENV_OS), + ge::AscendString(OPTION_HOST_ENV_CPU), + ge::AscendString(HOST_ENV_OS), + ge::AscendString(HOST_ENV_CPU), + 
ge::AscendString(COMPRESSION_OPTIMIZE_CONF), + ge::AscendString(OPTION_GRAPH_RUN_MODE), + ge::AscendString(OP_TUNE_MODE), + ge::AscendString(SOC_VER), + ge::AscendString(OPTION_SCREEN_PRINT_MODE), + ge::AscendString(INIT_BYPASS), +}; + +// this set for tuning option key +const std::set TUNING_OPTION_SET = { + ge::AscendString(INPUT_FORMAT), + ge::AscendString(INPUT_SHAPE), + ge::AscendString(INPUT_SHAPE_RANGE), + ge::AscendString(OP_NAME_MAP), + ge::AscendString(DYNAMIC_BATCH_SIZE), + ge::AscendString(DYNAMIC_IMAGE_SIZE), + ge::AscendString(DYNAMIC_DIMS), + ge::AscendString(PRECISION_MODE), + ge::AscendString(OUTPUT_TYPE), + ge::AscendString(OUT_NODES), + ge::AscendString(INPUT_FP16_NODES), + ge::AscendString(LOG_LEVEL), + ge::AscendString(OP_DEBUG_LEVEL), + ge::AscendString(INSERT_OP_FILE), + ge::AscendString(GE_INPUT_SHAPE_RANGE), + ge::AscendString(OUTPUT), + ge::AscendString(RELOAD), + ge::AscendString(TUNING_NAME), + ge::AscendString(FRAMEWORK), + ge::AscendString(MODEL_PATH), + ge::AscendString(TUNE_OPS_FILE), + ge::AscendString(RECOMPUTE), + ge::AscendString(AOE_CONFIG_FILE), + ge::AscendString(OP_PRECISION_MODE), + ge::AscendString(KEEP_DTYPE), + ge::AscendString(SINGLE_OP), + ge::AscendString(TUNE_OPTIMIZATION_LEVEL), + ge::AscendString(FEATURE_DEEPER_OPAT), + ge::AscendString(FEATURE_NONHOMO_SPLIT), + ge::AscendString(FEATURE_INNER_AXIS_CUT), + ge::AscendString(FEATURE_OP_FORMAT), + ge::AscendString(OUT_FILE_NAME), + ge::AscendString(HOST_ENV_OS), + ge::AscendString(HOST_ENV_CPU), + ge::AscendString(EXEC_DISABLE_REUSED_MEMORY), + ge::AscendString(AUTO_TUNE_MODE), + ge::AscendString(OP_COMPILER_CACHE_MODE), + ge::AscendString(OP_COMPILER_CACHE_DIR), + ge::AscendString(DEBUG_DIR), + ge::AscendString(MDL_BANK_PATH), + ge::AscendString(OP_BANK_PATH), + ge::AscendString(MODIFY_MIXLIST), + ge::AscendString(SHAPE_GENERALIZED_BUILD_MODE), + ge::AscendString(OP_DEBUG_CONFIG), + ge::AscendString(EXTERNAL_WEIGHT), + ge::AscendString(EXCLUDE_ENGINES), + 
ge::AscendString(OP_TUNE_MODE), + ge::AscendString(OP_TUNE_KERNEL_PATH), + ge::AscendString(OP_TUNE_KERNEL_NAME), + ge::AscendString(OP_TUNE_JSON_FILE), + ge::AscendString(OP_LIST), + ge::AscendString(IMPL_MODE), + ge::AscendString(DETERMINISTIC), + ge::AscendString(OPTION_SCREEN_PRINT_MODE), + ge::AscendString(BUILD_BYPASS), +}; + +/** + * @brief : set depend graphs for session id + * @param [in] : uint64_t sessionId session id + * @param [in] : std::vector &dependGraphs depend graphs + * @return : == AOE_SUCCESS : success, != AOE_SUCCESS : failed + */ +extern "C" AoeStatus AoeSetDependGraphs(uint64_t sessionId, const std::vector &dependGraphs); + +/** + * @brief : set inputs of depend graphs for session id + * @param [in] : uint64_t sessionId session id + * @param [in] : std::vector> &inputs depend input tensor + * @return : == AOE_SUCCESS : success, != AOE_SUCCESS : failed + */ +extern "C" AoeStatus AoeSetDependGraphsInputs(uint64_t sessionId, const std::vector> &inputs); +} // namespace Aoe + +#endif diff --git a/inc/aoe/aoe_types.h b/inc/aoe/aoe_types.h new file mode 100644 index 0000000000000000000000000000000000000000..198fd220fba617467377b82177bc1df7e7b9d6f1 --- /dev/null +++ b/inc/aoe/aoe_types.h @@ -0,0 +1,276 @@ +/** + * @file aoe_types.h + * + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. 
All rights reserved.\n + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n + * + */ + +#ifndef AOE_TYPES_H +#define AOE_TYPES_H +#include +#include +#include +#include "graph/graph.h" +#include "ge/ge_api.h" +#include "external/aoe_errcodes.h" + +namespace Aoe { + +// aoe 框架参数 (包括SGAT/OPAT等调优参数) +const char * const WORK_PATH = "work_path"; +const char * const SERVER_IP = "ip"; +const char * const SERVER_PORT = "port"; +const char * const TUNING_PARALLEL_NUM = "tuning_parallel_num"; +const char * const DEVICE = "device"; +const char * const JOB_TYPE = "job_type"; +const char * const RUN_LOOP = "run_loop"; +const char * const OUTPUT = "output"; +const char * const FRAMEWORK = "framework"; +const char * const MODEL_PATH = "model_path"; +const char * const TUNE_OPS_FILE = "tune_ops_file"; +const char * const SINGLE_OP = "singleop"; +const char * const OP_TUNE_FILE = "op_tune_file"; +const char * const OP_TUNE_MODE = "op_tune_mode"; +const char * const OP_TUNE_KERNEL_PATH = "op_tune_kernel_path"; +const char * const TUNE_OPTIMIZATION_LEVEL = "tune_optimization_level"; +const char * const FEATURE_DEEPER_OPAT = "Fdeeper_opat"; +const char * const FEATURE_NONHOMO_SPLIT = "Fnonhomo_split"; +const char * const FEATURE_INNER_AXIS_CUT = "Finner_axis_cut"; +const char * const FEATURE_OP_FORMAT = "Fop_format"; +const char * const RELOAD = "reload"; +const char * const MODEL = "model"; +const char * const WEIGHT = "weight"; +const char * const HELP = "help"; +const char * const PROGRESS = "progress_bar"; +const char * const INIT_BYPASS = "init_bypass"; +const char * const BUILD_BYPASS = "build_bypass"; + +// aoe 命令行支持的GE参数 (包括parser) +const char * const SOC_VER = "soc_version"; // AOE命令行不支持,内部通过acl接口获取 +const char * const CORE_TYPE = "core_type"; +const char * const BUFFER_OPTIMIZE = "buffer_optimize"; +const char * const 
ENABLE_COMPRESS_WEIGHT = "enable_compress_weight"; +const char * const COMPRESS_WEIGHT_CONF = "compress_weight_conf"; +const char * const PRECISION_MODE = "precision_mode"; +const char * const DISABLE_REUSE_MEMORY = "disable_reuse_memory"; +const char * const ENABLE_SINGLE_STREAM = "enable_single_stream"; +const char * const AICORE_NUM = "aicore_num"; +const char * const GROUP_ID = "group_id"; +const char * const FUSION_SWITCH_FILE = "fusion_switch_file"; +const char * const ENABLE_SMALL_CHANNEL = "enable_small_channel"; +const char * const OP_SELECT_IMPL_MODE = "op_select_implmode"; +const char * const OPTYPELIST_FOR_IMPLMODE = "optypelist_for_implmode"; +const char * const ENABLE_SCOPE_FUSION_PASSES = "enable_scope_fusion_passes"; +const char * const OP_DEBUG_LEVEL = "op_debug_level"; +const char * const INPUT_FORMAT = "input_format"; +const char * const INPUT_SHAPE = "input_shape"; +const char * const INPUT_SHAPE_RANGE = "input_shape_range"; +const char * const OP_NAME_MAP = "op_name_map"; +const char * const DYNAMIC_BATCH_SIZE = "dynamic_batch_size"; +const char * const DYNAMIC_IMAGE_SIZE = "dynamic_image_size"; +const char * const DYNAMIC_DIMS = "dynamic_dims"; +const char * const OUTPUT_TYPE = "output_type"; +const char * const OUT_NODES = "out_nodes"; +const char * const INPUT_FP16_NODES = "input_fp16_nodes"; +const char * const LOG_LEVEL = "log"; +const char * const INSERT_OP_FILE = "insert_op_conf"; +const char * const VIRTUAL_TYPE = "virtual_type"; +const char * const COMPRESSION_OPTIMIZE_CONF = "compression_optimize_conf"; +const char * const SPARSITY = "sparsity"; +const char * const OP_PRECISION_MODE = "op_precision_mode"; +const char * const MODIFY_MIXLIST = "modify_mixlist"; +const char * const KEEP_DTYPE = "keep_dtype"; // 转换成图属性, 不需要传给ACL +const char * const CUSTOMIZE_DTYPES = "customize_dtypes"; +const char * const HOST_ENV_OS = "host_env_os"; +const char * const HOST_ENV_CPU = "host_env_cpu"; +const char * const IS_INPUT_ADJUST_HW_LAYOUT = 
"is_input_adjust_hw_layout"; // parser only +const char * const IS_OUTPUT_ADJUST_HW_LAYOUT = "is_output_adjust_hw_layout"; // parser only + +// tf only +const char * const AOE_CONFIG_FILE = "ge.aoe_config_file"; + +// aoe 内部参数 +const char * const JOB_ID = "job_id"; +const char * const TUNING_NAME = "tuning_name"; +const char * const OPAT_BUILD_RUN_KEY = "opat_build_run_key"; // opat +const char * const PLUGIN_OPTION_TUNING_GRAPH = "tuning_graph"; +const char * const PLUGIN_OPTION_TUNING_DEPEND_GRAPH = "tuning_depend_graph"; +const char * const PLUGIN_OPTION_GRAPH_INPUTS = "graph_inputs"; +const char * const PLUGIN_OPTION_DEPEND_GRAPH_INPUTS = "depend_graph_inputs"; +const char * const PLUGIN_OPTION_GE_SESSION = "ge_session"; +const char * const PLUGIN_OPTION_IS_SINGLE_OP = "is_singleop"; +const char * const PLUGIN_OPTION_IS_TF_OFFLINE = "is_tfoffline"; +const char * const PLUGIN_OPTION_SESSION_ID = "session_id"; +const char * const FEATURE = "feature"; +const char * const OUT_FILE_NAME = "out_file_name"; +const char * const OP_TUNE_JSON_FILE = "op_tune_json_file"; +const char * const OP_LIST = "op_list"; + +// aoe API透传的GE参数 +const char * const GE_INPUT_SHAPE_RANGE = "ge.exec.dataInputsShapeRange"; +const char * const RESOURCE_CONFIG_PATH = "ge.resourceConfigPath"; +const char * const RECOMPUTE = "ge.recompute"; +const char * const AUTO_TUNE_MODE = "ge.autoTuneMode"; +const char * const DEBUG_DIR = "ge.debugDir"; +const char * const DETERMINISTIC = "ge.deterministic"; +const char * const EXTERNAL_WEIGHT = "ge.externalWeight"; +const char * const OP_COMPILER_CACHE_DIR = "ge.op_compiler_cache_dir"; +const char * const OP_COMPILER_CACHE_MODE = "ge.op_compiler_cache_mode"; +const char * const OP_DEBUG_CONFIG = "op_debug_config"; +const char * const OPTION_HOST_ENV_CPU = "ge.host_env_cpu"; +const char * const OPTION_HOST_ENV_OS = "ge.host_env_os"; +const char * const SHAPE_GENERALIZED_BUILD_MODE = "ge.shape_generalized_build_mode"; +const char * const SOC_VERSION = 
"ge.socVersion"; +const char * const EXCLUDE_ENGINES = "ge.exec.exclude_engines"; +const char * const EXEC_DISABLE_REUSED_MEMORY = "ge.exec.disableReuseMemory"; +const char * const MDL_BANK_PATH = "ge.mdl_bank_path"; +const char * const OP_BANK_PATH = "ge.op_bank_path"; +const char * const TUNE_DEVICE_IDS = "ge.exec.tuneDeviceIds"; +const char * const OPTION_GRAPH_RUN_MODE = "ge.graphRunMode"; +const char * const OP_TUNE_KERNEL_NAME = "_kernelname"; +const char * const IMPL_MODE = "ge.opSelectImplmode"; +const char * const OPTION_SCREEN_PRINT_MODE = "ge.screen_print_mode"; + + +// 其他常量定义 +constexpr char NO_DYNAMIC_PARAM[] = "no_dynamic_param"; +constexpr char const *QTEST_SOC_VERSION = "QTEST_SOC_VERSION"; +constexpr char OP_TUNE_MODE_STATIC_KERNEL[] = "static_kernel"; +constexpr char OP_TUNE_MODE_FAST[] = "fast"; +constexpr char OP_TUNE_MODE_BIN_BANK[] = "bin_bank"; + +const AoeStatus AOE_ERROR_NO_AICORE_GRAPH = AOE_ERROR_NON_OPTIMIZE_GRAPH; // for sgat + +/* system api error code */ +const AoeStatus AOE_ERROR_LIB_ACCESS = 201; // 内部状态码 +const AoeStatus AOE_ERROR_LIB_SYMBOL = 202; +const AoeStatus AOE_ERROR_MEMORY_OPERATION = 203; +const AoeStatus AOE_ERROR_CHECK_MEMORY = 204; + +/* aoe cmd error code */ +const AoeStatus AOE_NO_ERROR = 301; // --help +const AoeStatus AOE_ERROR_NO_SUPPORT = 302; // not support model +const AoeStatus AOE_ERROR_PARAM_EXCP = 303; // param exception + +/* aoe library error code */ +const AoeStatus AOE_ERROR_REPEAT_GRAPH = 401; +const AoeStatus AOE_ERR_PLUGIN_LIBRARY_LOAD = 402; +const AoeStatus AOE_ERR_PLUGIN_LIBRARY_UNLOAD = 403; + +/* aoe executor compiler return code */ +const AoeStatus AOE_ERROR_EXECUTE_COMPILER = 2000; + +/* aoe executor runner return code */ +const AoeStatus AOE_ERROR_EXECUTE_RUNNER = 3000; +const AoeStatus AOE_ERROR_NO_SPACE = 3001; +const AoeStatus AOE_ERROR_EXECUTE_FAILED = 3002; +const AoeStatus AOE_ERROR_OUT_OF_BOUND = 3003; +const AoeStatus AOE_ERROR_SET_DYNAMIC_PARAM_FAILED = 3004; + +/* profiling return 
code */ +const AoeStatus AOE_ERROR_PROF_TIMEOUT = 3202; +const AoeStatus AOE_ERROR_PROF_PARSER = 3203; + +/* network(nca) return code */ +const AoeStatus AOE_ERROR_NET_DOWN = 3300; +const AoeStatus AOE_ERROR_NET_UNREACH = 3301; + +enum class RunMode : uint32_t { + DEFAULT_RUN_MODE = 0, + LOAD_MODEL_PREV_ALLOC_MEM, + LOAD_MODEL_WITHOUT_CHECK_MEM, + RUN_MODE_MAX = 100 +}; + +struct AoeDataInfo { + uint8_t *ptr = nullptr; + size_t size = 0; + int64_t *dims = nullptr; + uint32_t dimNum = 0; + ge::DataType type = ge::DT_UNDEFINED; +}; + +struct AoeBufferData { + std::shared_ptr data = nullptr; + uint64_t length; +}; + +struct RunnerInitConfig { + // ncs only + std::vector devList; + uint32_t groupId = 0; + uint32_t aicoreNum = 8; +}; + +struct RunnerConfig { + bool isProf = true; + uint32_t loop; + // Running time above this threshold will end the run prematurely + uint64_t takeTimeUpperBound = std::numeric_limits::max(); + // offline only + std::vector input; + std::vector output; + std::string modelPath; // run with model file + AoeDataInfo modelData; // run with model data + uint32_t modelId; // run with model Id + // online only + ge::Session *session = nullptr; + std::map options; + std::vector> inputs; + std::vector> outputs; + std::vector dependGraph; // run graph (for training) + RunMode runMode = RunMode::DEFAULT_RUN_MODE; + std::string dynamicType; + std::string dynamicValue; + std::string buildRunKey; +}; + +struct RunnerOpInfo { + std::string opName; + uint64_t opCostTime; + uint64_t aicoreCostTime; + // gradient_split only + std::string modelName; + std::string opType; + uint64_t start; + uint64_t end; +}; + +struct RunnerModelInfo { + uint64_t totalCostTime; +}; + +// online run result +struct RunnerRunResult { + std::vector modelInfo; + std::vector opInfo; +}; + +// offline run result +struct RunnerResult { + uint64_t totalCostTime; + uint64_t profE2ECostTime; + std::map opCostTime; + std::map aicoreCostTime; +}; + +// tuning result +struct 
PerformanceMetric { + uint64_t totalCostTime = 0UL; + uint64_t totalProfilingCostTime = 0UL; + std::map opCostTime; + std::map aicoreCostTime; +}; + +struct AclnnRunConfig { + uint32_t loop; + std::string aclnnJson; + std::string strategyJson; + void *tilingContext; +}; +} +#endif diff --git a/inc/aoe/external/aoe.h b/inc/aoe/external/aoe.h new file mode 100644 index 0000000000000000000000000000000000000000..09c04d7541d749a4faa3b98ff7da58fc1d01738c --- /dev/null +++ b/inc/aoe/external/aoe.h @@ -0,0 +1,81 @@ +/** + * @file aoe.h + * + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.\n + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n + * + */ + +#ifndef AOE_EXTERNAL_AOE_H +#define AOE_EXTERNAL_AOE_H + +#include +#include "ge/ge_api.h" +#include "graph/ascend_string.h" +#include "external/aoe_errcodes.h" + +namespace Aoe { +/** + * @brief : initialize aoe tuning api + * @param [in] : std::map &globalOptions global options + * @return : == AOE_SUCCESS : success, != AOE_SUCCESS : failed + */ +extern "C" AoeStatus AoeInitialize(const std::map &globalOptions); + +/** + * @brief : finalize aoe tuning api + * @return : == AOE_SUCCESS : success, != AOE_SUCCESS : failed + */ +extern "C" AoeStatus AoeFinalize(); + +/** + * @brief : create aoe session + * @param [out] : uint64_t sessionId session id + * @return : == AOE_SUCCESS : success, != AOE_SUCCESS : failed + */ +extern "C" AoeStatus AoeCreateSession(uint64_t &sessionId); + +/** + * @brief : destroy aoe session + * @param [in] : uint64_t sessionId session id + * @return : == AOE_SUCCESS : success, != AOE_SUCCESS : failed + */ +extern "C" AoeStatus AoeDestroySession(uint64_t sessionId); + +/** + * @brief : set ge session for session id + * @param [in] : uint64_t sessionId session id + * @param [in] : ge::Session *geSession ge session handle + * 
@return : == AOE_SUCCESS : success, != AOE_SUCCESS : failed + */ +extern "C" AoeStatus AoeSetGeSession(uint64_t sessionId, ge::Session *geSession); + +/** + * @brief : set tuning graphs for session id + * @param [in] : uint64_t sessionId session id + * @param [in] : ge::Graph &tuningGraph tuning graph + * @return : == AOE_SUCCESS : success, != AOE_SUCCESS : failed + */ +extern "C" AoeStatus AoeSetTuningGraph(uint64_t sessionId, const ge::Graph &tuningGraph); + +/** + * @brief : set input of tuning graph for session id + * @param [in] : uint64_t sessionId session id + * @param [in] : std::vector &inputs input tensor + * @return : == AOE_SUCCESS : success, != AOE_SUCCESS : failed + */ +extern "C" AoeStatus AoeSetTuningGraphInput(uint64_t sessionId, const std::vector &input); + +/** + * @brief : tuning graph + * @param [in] : uint64_t sessionId session id + * @param [in] : std::map &tuningOptions tuning options + * @return : == AOE_SUCCESS : success, != AOE_SUCCESS : failed + */ +extern "C" AoeStatus AoeTuningGraph(uint64_t sessionId, + const std::map &tuningOptions); +} // namespace Aoe +#endif // AOE_EXTERNAL_AOE_H diff --git a/inc/aoe/external/aoe_errcodes.h b/inc/aoe/external/aoe_errcodes.h new file mode 100644 index 0000000000000000000000000000000000000000..0d40fe9506a755fcbb2b9f8692abc330ca737442 --- /dev/null +++ b/inc/aoe/external/aoe_errcodes.h @@ -0,0 +1,37 @@ +/** + * @file aoe_errcodes.h + * + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. 
All rights reserved.\n + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.\n + * + */ +#ifndef AOE_EXTERNAL_AOE_ERRCODES_H +#define AOE_EXTERNAL_AOE_ERRCODES_H + +#include + +namespace Aoe { +using AoeStatus = int32_t; + +constexpr AoeStatus AOE_FAILURE = -1; // 通用AOE错误 +constexpr AoeStatus AOE_SUCCESS = 0; // 成功 +constexpr AoeStatus AOE_ERROR_MEMORY_ALLOC = 1; // 内存分配失败 +constexpr AoeStatus AOE_ERROR_INVALID_PARAM = 2; // 非法参数 +constexpr AoeStatus AOE_ERROR_UNINITIALIZED = 3; // API未初始化 +constexpr AoeStatus AOE_ERROR_REPEAT_INITIALIZE = 4; // API重复初始化 +constexpr AoeStatus AOE_ERROR_INVALID_SESSION = 5; // 非法的sessionID +constexpr AoeStatus AOE_ERROR_BUSY = 6; // 多线程调用阻塞 +constexpr AoeStatus AOE_ERROR_INVALID_GRAPH = 7; // 非法的调优图 +constexpr AoeStatus AOE_ERROR_NON_OPTIMIZE_GRAPH = 8; // 无调优图 +constexpr AoeStatus AOE_ERROR_DYNAMIC_GRAPH = 9; // 调优业务不支持动态图 +constexpr AoeStatus AOE_ERROR_DYNAMIC_SHAPE_RANGE = 10; // 动态shape(opat) +constexpr AoeStatus AOE_ERROR_INVALID_OPTION_ATTR = 11; // option属性错误 +constexpr AoeStatus QTEST_LIB_OPEN_FAILED = 12; // 打开lib失败 +constexpr AoeStatus QTEST_LIB_INVALID_FUNC = 13; // 无效函数 + +constexpr AoeStatus AOE_ERROR_TUNING_ERROR = 100; // 其他调优错误 +} // namespace Aoe +#endif // AOE_EXTERNAL_AOE_ERRCODES_H \ No newline at end of file diff --git a/inc/metadef/common/blocking_queue.h b/inc/metadef/common/blocking_queue.h new file mode 100644 index 0000000000000000000000000000000000000000..6e4bf199b56753e97682fbe1bf7c49ddd5648e9f --- /dev/null +++ b/inc/metadef/common/blocking_queue.h @@ -0,0 +1,159 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_COMMON_BLOCKING_QUEUE_H_ +#define INC_COMMON_BLOCKING_QUEUE_H_ + +#include +#include +#include + +namespace ge { +constexpr uint32_t kDefaultMaxQueueSize = 2048U; +constexpr int32_t kDefaultWaitTimeoutInSec = 600; + +template +class BlockingQueue { + public: + explicit BlockingQueue(const uint32_t max_size = kDefaultMaxQueueSize) : max_size_(max_size) {} + + ~BlockingQueue() = default; + + bool Pop(T &item, const int32_t time_out = INT32_MAX) { + std::unique_lock lock(mutex_); + + while (!empty_cond_.wait_for(lock, std::chrono::seconds(time_out), + [this]() -> bool { return (!queue_.empty()) || (is_stoped_); })) { + is_stuck_ = true; + return false; + } + + if (is_stoped_) { + return false; + } + + item = std::move(queue_.front()); + queue_.pop_front(); + + full_cond_.notify_one(); + + return true; + } + + bool Pop(T &item, bool &is_stuck) { + const auto ret = Pop(item, kDefaultWaitTimeoutInSec); + is_stuck = is_stuck_; + return ret; + } + + bool Push(const T &item, const bool is_wait = true) { + std::unique_lock lock(mutex_); + + while ((queue_.size() >= max_size_) && (!is_stoped_)) { + if (!is_wait) { + return false; + } + full_cond_.wait(lock); + } + + if (is_stoped_) { + return false; + } + + queue_.push_back(item); + + empty_cond_.notify_one(); + + return true; + } + + bool Push(T &&item, const bool is_wait = true) { + std::unique_lock lock(mutex_); + + while ((queue_.size() >= max_size_) && (!is_stoped_)) { + if (!is_wait) { + return false; + } + full_cond_.wait(lock); + } + + if (is_stoped_) { + return false; + } + + 
queue_.emplace_back(std::move(item)); + + empty_cond_.notify_one(); + + return true; + } + + void Stop() { + { + const std::unique_lock lock(mutex_); + is_stoped_ = true; + } + + full_cond_.notify_all(); + empty_cond_.notify_all(); + } + + void Restart() { + const std::unique_lock lock(mutex_); + is_stoped_ = false; + } + + // if the queue is stoped ,need call this function to release the unprocessed items + std::list GetRemainItems() { + const std::unique_lock lock(mutex_); + + if (!is_stoped_) { + return std::list(); + } + + return queue_; + } + + bool IsFull() { + const std::unique_lock lock(mutex_); + return queue_.size() >= max_size_; + } + + void Clear() { + const std::unique_lock lock(mutex_); + queue_.clear(); + } + + void SetMaxSize(const uint32_t size) { + const std::unique_lock lock(mutex_); + if (size == 0U) { + max_size_ = kDefaultMaxQueueSize; + return; + } + max_size_ = size; + } + + uint32_t Size() { + const std::unique_lock lock(mutex_); + return static_cast(queue_.size()); + } + + private: + std::list queue_; + std::mutex mutex_; + std::condition_variable empty_cond_; + std::condition_variable full_cond_; + uint32_t max_size_; + + bool is_stoped_{false}; + bool is_stuck_{false}; +}; +} // namespace ge + +#endif // INC_COMMON_BLOCKING_QUEUE_H_ diff --git a/inc/metadef/common/checker.h b/inc/metadef/common/checker.h new file mode 100644 index 0000000000000000000000000000000000000000..3c09e03f673acca05c8aaf775c2c246d3019dee5 --- /dev/null +++ b/inc/metadef/common/checker.h @@ -0,0 +1,118 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_COMMON_CHECKER_H_ +#define METADEF_CXX_INC_COMMON_CHECKER_H_ +#include +#include +#include +#include "graph/ge_error_codes.h" +#include "common/ge_common/debug/ge_log.h" +#include "hyper_status.h" + +struct ErrorResult { + operator bool() const { + return false; + } + operator ge::graphStatus() const { + return ge::PARAM_INVALID; + } + template + operator std::unique_ptr() const { + return nullptr; + } + template + operator std::shared_ptr() const { + return nullptr; + } + template + operator T *() const { + return nullptr; + } + template + operator std::vector>() const { + return {}; + } + template + operator std::vector() const { + return {}; + } + operator std::string() const { + return ""; + } + template + operator T() const { + return T(); + } +}; + +inline std::vector CreateErrorMsg(const char *format, ...) { + va_list args; + va_start(args, format); + va_list args_copy; + va_copy(args_copy, args); + const size_t len = static_cast(vsnprintf(nullptr, 0, format, args_copy)); + va_end(args_copy); + std::vector msg(len + 1U, '\0'); + const auto ret = vsnprintf_s(msg.data(), len + 1U, len, format, args); + va_end(args); + return (ret > 0) ? 
msg : std::vector{}; +} + +inline std::vector CreateErrorMsg() { + return {}; +} + +#define GE_ASSERT_EQ(x, y) \ + do { \ + const auto &xv = (x); \ + const auto &yv = (y); \ + if (xv != yv) { \ + std::stringstream ss; \ + ss << "Assert (" << #x << " == " << #y << ") failed, expect " << yv << " actual " << xv; \ + REPORT_INNER_ERROR("E19999", "%s", ss.str().c_str()); \ + GELOGE(ge::FAILED, "%s", ss.str().c_str()); \ + return ::ErrorResult(); \ + } \ + } while (false) + +#define GE_WARN_ASSERT(exp, ...) \ + do { \ + if (!(exp)) { \ + auto msg = CreateErrorMsg(__VA_ARGS__); \ + GELOGW("Assert failed: %s", (msg.empty() ? #exp : msg.data())); \ + return ::ErrorResult(); \ + } \ + } while (false) + +#define GE_ASSERT(exp, ...) \ + do { \ + if (!(exp)) { \ + auto msg = CreateErrorMsg(__VA_ARGS__); \ + if (msg.empty()) { \ + REPORT_INNER_ERROR("E19999", "Assert %s failed", #exp); \ + GELOGE(ge::FAILED, "Assert %s failed", #exp); \ + } else { \ + REPORT_INNER_ERROR("E19999", "%s", msg.data()); \ + GELOGE(ge::FAILED, "%s", msg.data()); \ + } \ + return ::ErrorResult(); \ + } \ + } while (false) + +#define GE_ASSERT_NOTNULL(v, ...) GE_ASSERT(((v) != nullptr), __VA_ARGS__) +#define GE_ASSERT_SUCCESS(v, ...) GE_ASSERT(((v) == ge::SUCCESS), __VA_ARGS__) +#define GE_ASSERT_GRAPH_SUCCESS(v, ...) GE_ASSERT(((v) == ge::GRAPH_SUCCESS), __VA_ARGS__) +#define GE_ASSERT_RT_OK(v, ...) GE_ASSERT(((v) == 0), __VA_ARGS__) +#define GE_ASSERT_EOK(v, ...) GE_ASSERT(((v) == EOK), __VA_ARGS__) +#define GE_ASSERT_TRUE(v, ...) GE_ASSERT((v), __VA_ARGS__) +#define GE_ASSERT_HYPER_SUCCESS(v, ...) GE_ASSERT(((v).IsSuccess()), __VA_ARGS__) + +#define GE_WARN_ASSERT_GRAPH_SUCCESS(v, ...) 
GE_WARN_ASSERT(((v) == ge::GRAPH_SUCCESS), __VA_ARGS__) +#endif // METADEF_CXX_INC_COMMON_CHECKER_H_ diff --git a/inc/metadef/common/dynamic_aipp.h b/inc/metadef/common/dynamic_aipp.h new file mode 100644 index 0000000000000000000000000000000000000000..47b22c6fd8f6aeef37344a9babf8fdf078df23a2 --- /dev/null +++ b/inc/metadef/common/dynamic_aipp.h @@ -0,0 +1,99 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_COMMON_DYNAMIC_AIPP_H_ +#define INC_COMMON_DYNAMIC_AIPP_H_ + +#include + +/** +* @ingroup dnn +* @brief struct define of dynamic aipp batch parameter. 
+*/ +struct tagAippDynamicBatchPara { + int8_t cropSwitch; // crop switch + int8_t scfSwitch; // resize switch + int8_t paddingSwitch; // 0: unable padding + // 1: padding config value,sfr_filling_hblank_ch0 ~ sfr_filling_hblank_ch2 + // 2: padding source picture data, single row/collumn copy + // 3: padding source picture data, block copy + // 4: padding source picture data, mirror copy + int8_t rotateSwitch; // rotate switch,0: non-ratate, + // 1: ratate 90° clockwise,2: ratate 180° clockwise,3: ratate 270° clockwise + int8_t reserve[4]; + int32_t cropStartPosW; // the start horizontal position of cropping + int32_t cropStartPosH; // the start vertical position of cropping + int32_t cropSizeW; // crop width + int32_t cropSizeH; // crop height + + int32_t scfInputSizeW; // input width of scf + int32_t scfInputSizeH; // input height of scf + int32_t scfOutputSizeW; // output width of scf + int32_t scfOutputSizeH; // output height of scf + + int32_t paddingSizeTop; // top padding size + int32_t paddingSizeBottom; // bottom padding size + int32_t paddingSizeLeft; // left padding size + int32_t paddingSizeRight; // right padding size + + int16_t dtcPixelMeanChn0; // mean value of channel 0 + int16_t dtcPixelMeanChn1; // mean value of channel 1 + int16_t dtcPixelMeanChn2; // mean value of channel 2 + int16_t dtcPixelMeanChn3; // mean value of channel 3 + + uint16_t dtcPixelMinChn0; // min value of channel 0 + uint16_t dtcPixelMinChn1; // min value of channel 1 + uint16_t dtcPixelMinChn2; // min value of channel 2 + uint16_t dtcPixelMinChn3; // min value of channel 3 + uint16_t dtcPixelVarReciChn0; // sfr_dtc_pixel_variance_reci_ch0 + uint16_t dtcPixelVarReciChn1; // sfr_dtc_pixel_variance_reci_ch1 + uint16_t dtcPixelVarReciChn2; // sfr_dtc_pixel_variance_reci_ch2 + uint16_t dtcPixelVarReciChn3; // sfr_dtc_pixel_variance_reci_ch3 + + int8_t reserve1[16]; // 32B assign, for ub copy +}; +using kAippDynamicBatchPara = tagAippDynamicBatchPara; + +/** +* @ingroup dnn +* 
@brief struct define of dynamic aipp parameter. lite:64+96*batchNum byte ; tiny:64+64*batchNum byte +*/ +struct tagAippDynamicPara { + uint8_t inputFormat; // input format:YUV420SP_U8/XRGB8888_U8/RGB888_U8 + int8_t cscSwitch; // csc switch + int8_t rbuvSwapSwitch; // rb/uv swap switch + int8_t axSwapSwitch; // RGBA->ARGB, YUVA->AYUV swap switch + int8_t batchNum; // batch parameter number + int8_t reserve1[3]; + int32_t srcImageSizeW; // source image width + int32_t srcImageSizeH; // source image height + int16_t cscMatrixR0C0; // csc_matrix_r0_c0 + int16_t cscMatrixR0C1; // csc_matrix_r0_c1 + int16_t cscMatrixR0C2; // csc_matrix_r0_c2 + int16_t cscMatrixR1C0; // csc_matrix_r1_c0 + int16_t cscMatrixR1C1; // csc_matrix_r1_c1 + int16_t cscMatrixR1C2; // csc_matrix_r1_c2 + int16_t cscMatrixR2C0; // csc_matrix_r2_c0 + int16_t cscMatrixR2C1; // csc_matrix_r2_c1 + int16_t cscMatrixR2C2; // csc_matrix_r2_c2 + int16_t reserve2[3]; + uint8_t cscOutputBiasR0; // output Bias for RGB to YUV, element of row 0, unsigned number + uint8_t cscOutputBiasR1; // output Bias for RGB to YUV, element of row 1, unsigned number + uint8_t cscOutputBiasR2; // output Bias for RGB to YUV, element of row 2, unsigned number + uint8_t cscInputBiasR0; // input Bias for YUV to RGB, element of row 0, unsigned number + uint8_t cscInputBiasR1; // input Bias for YUV to RGB, element of row 1, unsigned number + uint8_t cscInputBiasR2; // input Bias for YUV to RGB, element of row 2, unsigned number + uint8_t reserve3[2]; + int8_t reserve4[16]; // 32B assign, for ub copy + + kAippDynamicBatchPara aippBatchPara; // allow transfer several batch para.
+}; +using kAippDynamicPara = tagAippDynamicPara; + +#endif // INC_COMMON_DYNAMIC_AIPP_H_ diff --git a/inc/metadef/common/fe_executor/ffts_plus_qos_update.h b/inc/metadef/common/fe_executor/ffts_plus_qos_update.h new file mode 100644 index 0000000000000000000000000000000000000000..94a22aea257666813ce0cc1d5a75828ddcfec575 --- /dev/null +++ b/inc/metadef/common/fe_executor/ffts_plus_qos_update.h @@ -0,0 +1,22 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef FFTS_PLUS_QOS_UPDATE_H_ +#define FFTS_PLUS_QOS_UPDATE_H_ + +#include "runtime/rt_ffts_plus_define.h" +#include "graph/utils/node_utils.h" +namespace ffts { + +bool UpdateAicAivCtxQos(rtFftsPlusAicAivCtx_t *ctx, int label, int device_id); +bool UpdateMixAicAivCtxQos(rtFftsPlusMixAicAivCtx_t *ctx, int label, int device_id); +bool UpdateDataCtxQos(rtFftsPlusDataCtx_t *ctx, int device_id); + +} +#endif diff --git a/inc/metadef/common/ge_common/debug/ge_log.h b/inc/metadef/common/ge_common/debug/ge_log.h new file mode 100644 index 0000000000000000000000000000000000000000..ea198f6b72bed8d9be731054ec511394543b181a --- /dev/null +++ b/inc/metadef/common/ge_common/debug/ge_log.h @@ -0,0 +1,140 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_COMMON_GE_COMMON_DEBUG_GE_LOG_H_ +#define INC_COMMON_GE_COMMON_DEBUG_GE_LOG_H_ + +#include +#include + +#include "common/ge_common/ge_inner_error_codes.h" +#include "common/util/error_manager/error_manager.h" +#include "toolchain/slog.h" +#ifdef __GNUC__ +#include +#include +#else +#include "mmpa/mmpa_api.h" +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +#define GE_MODULE_NAME static_cast(GE) +#define GE_MODULE_NAME_U16 static_cast(GE) + +// trace status of log +enum TraceStatus { TRACE_INIT = 0, TRACE_RUNNING, TRACE_WAITING, TRACE_STOP }; + +class GE_FUNC_VISIBILITY GeLog { + public: + static uint64_t GetTid() { +#ifdef __GNUC__ + const uint64_t tid = static_cast(syscall(__NR_gettid)); +#else + const uint64_t tid = static_cast(GetCurrentThreadId()); +#endif + return tid; + } +}; + +inline bool IsLogEnable(const int32_t module_name, const int32_t log_level) { + const int32_t enable = CheckLogLevel(module_name, log_level); + // 1:enable, 0:disable + return (enable == 1); +} + +inline bool IsLogPrintStdout() { + static int32_t stdout_flag = -1; + if (stdout_flag == -1) { + const char *env_ret = getenv("ASCEND_SLOG_PRINT_TO_STDOUT"); + const bool print_stdout = ((env_ret != nullptr) && (strcmp(env_ret, "1") == 0)); + stdout_flag = print_stdout ? 1 : 0; + } + return (stdout_flag == 1) ? 
true : false; +} + +#define GELOGE(ERROR_CODE, fmt, ...) \ + do { \ + dlog_error(GE_MODULE_NAME, "%" PRIu64 " %s: ErrorNo: %" PRIuLEAST8 "(%s) %s" fmt, \ + GeLog::GetTid(), &__FUNCTION__[0U], \ + (ERROR_CODE), ((GE_GET_ERRORNO_STR(ERROR_CODE)).c_str()), \ + ErrorManager::GetInstance().GetLogHeader().c_str(), ##__VA_ARGS__); \ + } while (false) + +#define GELOGW(fmt, ...) \ + do { \ + dlog_warn(GE_MODULE_NAME, "%" PRIu64 " %s:" fmt, GeLog::GetTid(), &__FUNCTION__[0U], ##__VA_ARGS__); \ + } while (false) + +#define GELOGI(fmt, ...) \ + do { \ + dlog_info(GE_MODULE_NAME, "%" PRIu64 " %s:" fmt, GeLog::GetTid(), &__FUNCTION__[0U], ##__VA_ARGS__); \ + } while (false) + +#define GELOGD(fmt, ...) \ + do { \ + dlog_debug(GE_MODULE_NAME, "%" PRIu64 " %s:" fmt, GeLog::GetTid(), &__FUNCTION__[0U], ##__VA_ARGS__); \ + } while (false) + +#define GEEVENT(fmt, ...) \ + do { \ + dlog_info(static_cast(static_cast(RUN_LOG_MASK) | static_cast(GE_MODULE_NAME)), \ + "%" PRIu64 " %s:" fmt, GeLog::GetTid(), &__FUNCTION__[0U], ##__VA_ARGS__); \ + if (!IsLogPrintStdout()) { \ + dlog_info(GE_MODULE_NAME, "%" PRIu64 " %s:" fmt, GeLog::GetTid(), &__FUNCTION__[0U], ##__VA_ARGS__); \ + } \ + } while (false) + +#define GERUNINFO(fmt, ...) \ + do { \ + dlog_info(static_cast(static_cast(RUN_LOG_MASK) | static_cast(GE_MODULE_NAME)), \ + "%" PRIu64 " %s:" fmt, GeLog::GetTid(), &__FUNCTION__[0U], ##__VA_ARGS__); \ + if (!IsLogPrintStdout()) { \ + dlog_info(GE_MODULE_NAME, "%" PRIu64 " %s:" fmt, GeLog::GetTid(), &__FUNCTION__[0U], ##__VA_ARGS__); \ + } \ + } while (false) + +#define GELOGT(VALUE, fmt, ...) 
\ + do { \ + constexpr const char_t *TraceStatStr[] = {"INIT", "RUNNING", "WAITING", "STOP"}; \ + constexpr int32_t idx = static_cast(VALUE); \ + char_t *v = const_cast(TraceStatStr[idx]); \ + dlog_info((static_cast(RUN_LOG_MASK) | static_cast(GE_MODULE_NAME)), \ + "[status:%s]%" PRIu64 " %s:" fmt, v, GeLog::GetTid(), &__FUNCTION__[0U], ##__VA_ARGS__); \ + } while (false) + +#define GE_LOG_ERROR(MOD_NAME, ERROR_CODE, fmt, ...) \ + do { \ + dlog_error((MOD_NAME), "%" PRIu64 " %s: ErrorNo: %" PRIuLEAST8 "(%s) %s" fmt, GeLog::GetTid(), \ + &__FUNCTION__[0U], (ERROR_CODE), \ + ((GE_GET_ERRORNO_STR(ERROR_CODE)).c_str()), ErrorManager::GetInstance().GetLogHeader().c_str(), \ + ##__VA_ARGS__); \ + } while (false) + +// print memory when it is greater than 1KB. +#define GE_PRINT_DYNAMIC_MEMORY(FUNC, PURPOSE, SIZE) \ + do { \ + if (static_cast(SIZE) > 1024UL) { \ + GELOGI("MallocMemory, func=%s, size=%zu, purpose=%s", (#FUNC), static_cast(SIZE), (PURPOSE)); \ + } \ + } while (false) + +#define GELOG_DEPRECATED(option) \ + do { \ + std::cout << "[WARNING][GE] Option " << (option) << " is deprecated and will be removed in future version." \ + " Please do not configure this option in the future." << std::endl; \ + } while (false) + + +#ifdef __cplusplus +} +#endif +#endif // INC_COMMON_GE_COMMON_DEBUG_GE_LOG_H_ diff --git a/inc/metadef/common/ge_common/debug/log.h b/inc/metadef/common/ge_common/debug/log.h new file mode 100644 index 0000000000000000000000000000000000000000..2c113387a1914e3abf06beca70faaff6b4183b32 --- /dev/null +++ b/inc/metadef/common/ge_common/debug/log.h @@ -0,0 +1,214 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_COMMON_GE_COMMON_DEBUG_LOG_H_ +#define INC_COMMON_GE_COMMON_DEBUG_LOG_H_ + +#include +#include + +#include "common/ge_common/string_util.h" +#include "common/ge_common/util.h" +#include "common/ge_common/debug/ge_log.h" +#include "external/ge_common/ge_api_error_codes.h" + +#if !defined(__ANDROID__) && !defined(ANDROID) +#define DOMI_LOGE(fmt, ...) GE_LOG_ERROR(GE_MODULE_NAME, (ge::FAILED), fmt, ##__VA_ARGS__) +#else +#include +#if defined(BUILD_VERSION_PERF) +#define DOMI_LOGE(fmt, ...) +#else +// The Android system has strict log control. Do not modify the log. +#define DOMI_LOGE(fmt, ...) \ + __android_log_print(ANDROID_LOG_ERROR, "NPU_FMK", "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__) +#endif +#endif + +// ge macro +#define GE_LOGI_IF(condition, ...) \ + if ((condition)) { \ + GELOGI(__VA_ARGS__); \ + } + +#define GE_LOGW_IF(condition, ...) \ + if ((condition)) { \ + GELOGW(__VA_ARGS__); \ + } + +#define GE_LOGE_IF(condition, ...) \ + if ((condition)) { \ + GELOGE((ge::FAILED), __VA_ARGS__); \ + } + +// If expr is not SUCCESS, print the log and return the same value +#define GE_CHK_STATUS_RET(expr, ...) \ + do { \ + const ge::Status _chk_status = (expr); \ + if (_chk_status != ge::SUCCESS) { \ + GELOGE((ge::FAILED), __VA_ARGS__); \ + return _chk_status; \ + } \ + } while (false) + +// If expr is not SUCCESS, print the log and do not execute return +#define GE_CHK_STATUS(expr, ...)
\ + do { \ + const ge::Status _chk_status = (expr); \ + if (_chk_status != ge::SUCCESS) { \ + GELOGE(_chk_status, __VA_ARGS__); \ + } \ + } while (false) + +// If expr is not SUCCESS, return the same value +#define GE_CHK_STATUS_RET_NOLOG(expr) \ + do { \ + const ge::Status _chk_status = (expr); \ + if (_chk_status != ge::SUCCESS) { \ + return _chk_status; \ + } \ + } while (false) + +// If expr is not GRAPH_SUCCESS, print the log and return FAILED +#define GE_CHK_GRAPH_STATUS_RET(expr, ...) \ + do { \ + if ((expr) != ge::GRAPH_SUCCESS) { \ + REPORT_CALL_ERROR("E19999", "Operator graph failed"); \ + GELOGE(ge::FAILED, __VA_ARGS__); \ + return (ge::FAILED); \ + } \ + } while (false) + +// If expr is not ge::SUCCESS, print the log and execute a custom statement +#define GE_CHK_STATUS_EXEC(expr, exec_expr, ...) \ + do { \ + const ge::Status _chk_status = (expr); \ + GE_CHK_BOOL_EXEC(_chk_status == ge::SUCCESS, exec_expr, __VA_ARGS__); \ + } while (false) + +// If expr is not true, print the log and return the specified status +#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ + do { \ + const bool b = (expr); \ + if (!b) { \ + REPORT_INNER_ERROR("E19999", __VA_ARGS__); \ + GELOGE((_status), __VA_ARGS__); \ + return (_status); \ + } \ + } while (false) + +// If expr is true, print info log and return the specified status +#define GE_CHK_BOOL_RET_SPECIAL_STATUS(expr, _status, ...) \ + do { \ + const bool b = (expr); \ + if (b) { \ + GELOGI(__VA_ARGS__); \ + return (_status); \ + } \ + } while (false) + +// If expr is not true, return the specified status without printing a log (variadic args are ignored) +#define GE_CHK_BOOL_RET_STATUS_NOLOG(expr, _status, ...) \ + do { \ + const bool b = (expr); \ + if (!b) { \ + return (_status); \ + } \ + } while (false) + +// If expr is not true, print the log and execute a custom statement +#define GE_CHK_BOOL_EXEC(expr, exec_expr, ...)
\ + { \ + const bool b = (expr); \ + if (!b) { \ + GELOGE(ge::FAILED, __VA_ARGS__); \ + exec_expr; \ + } \ + } + +// -----------------runtime related macro definitions------------------------------- +// If expr is not RT_ERROR_NONE, print the log +#define GE_CHK_RT(expr) \ + do { \ + const rtError_t _rt_err = (expr); \ + if (_rt_err != RT_ERROR_NONE) { \ + GELOGE(ge::RT_FAILED, "Call rt api failed, ret: 0x%X", _rt_err); \ + } \ + } while (false) + +// If expr is not RT_ERROR_NONE, print the log and execute the exec_expr expression +#define GE_CHK_RT_EXEC(expr, exec_expr) \ + { \ + const rtError_t _rt_ret = (expr); \ + if (_rt_ret != RT_ERROR_NONE) { \ + GELOGE(ge::RT_FAILED, "Call rt api failed, ret: 0x%X", _rt_ret); \ + exec_expr; \ + } \ + } + +// If expr is not RT_ERROR_NONE, print the log and return +#define GE_CHK_RT_RET(expr) \ + do { \ + const rtError_t _rt_ret = (expr); \ + if (_rt_ret != RT_ERROR_NONE) { \ + REPORT_CALL_ERROR("E19999", "Call %s fail, ret: 0x%X", #expr, _rt_ret); \ + GELOGE(ge::RT_FAILED, "Call rt api failed, ret: 0x%X", _rt_ret); \ + return RT_ERROR_TO_GE_STATUS(_rt_ret); \ + } \ + } while (false) + +// If expr is true, execute exec_expr without printing logs +#define GE_IF_BOOL_EXEC(expr, exec_expr) \ + { \ + if (expr) { \ + exec_expr; \ + } \ + } + +// If make_shared is abnormal, print the log and execute the statement +#define GE_MAKE_SHARED(exec_expr0, exec_expr1) \ + try { \ + exec_expr0; \ + } catch (const std::bad_alloc &) { \ + GELOGE(ge::FAILED, "Make shared failed"); \ + exec_expr1; \ + } + +#define GE_ERRORLOG_AND_ERRORMSG(_status, errormsg) \ + do { \ + GELOGE((_status), "[Check][InnerData]%s", (errormsg)); \ + REPORT_INNER_ERROR("E19999", "%s", (errormsg)); \ + } while (false) + +#define GE_WARNINGLOG_AND_ERRORMSG(errormsg) \ + do { \ + GELOGW("%s", (errormsg)); \ + ErrorManager::GetInstance().ATCReportErrMessage("E10052", {"reason"}, {(errormsg)}); \ + } while (false) + +#define GE_CHK_LOG_AND_ERRORMSG(expr, _status, 
errormsg) \ + do { \ + const bool b = (expr); \ + if (!b) { \ + GELOGE((_status), "%s", (errormsg)); \ + ErrorManager::GetInstance().ATCReportErrMessage("E10052", {"reason"}, {(errormsg)}); \ + return (_status); \ + } \ + } while (false) +namespace ge { +template +GE_FUNC_VISIBILITY std::string FmtToStr(const T &t) { + std::string fmt; + std::stringstream st; + st << "[" << t << "]"; + fmt = st.str(); + return fmt; +} +} +#endif // INC_COMMON_GE_COMMON_DEBUG_LOG_H_ diff --git a/inc/metadef/common/ge_common/fmk_error_codes.h b/inc/metadef/common/ge_common/fmk_error_codes.h new file mode 100644 index 0000000000000000000000000000000000000000..7d6c44a0ec2c05b086cdc24ea8d889254c6a564a --- /dev/null +++ b/inc/metadef/common/ge_common/fmk_error_codes.h @@ -0,0 +1,88 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_COMMON_FMK_ERROR_CODES_H_ +#define INC_COMMON_FMK_ERROR_CODES_H_ + +#if defined(_MSC_VER) +#ifdef FUNC_VISIBILITY +#define GE_OBJECT_VISIBILITY +#else +#define GE_OBJECT_VISIBILITY +#endif +#else +#ifdef FUNC_VISIBILITY +#define GE_OBJECT_VISIBILITY +#else +#define GE_OBJECT_VISIBILITY __attribute__((visibility("hidden"))) +#endif +#endif + +#include +#include + +#include "common/ge_common/fmk_types.h" +#include "register/register_error_codes.h" +#include "external/ge_common/ge_error_codes.h" + +// Each module uses the following four macros to define error codes: +#define DECLARE_ERRORNO_OMG(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_OMG, name, value) +#define DECLARE_ERRORNO_OME(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_OME, name, value) +#define DECLARE_ERRORNO_CALIBRATION(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_CALIBRATION, name, value) + +#define DEF_ERRORNO(name, desc) \ + const bool g_##name##_errorno = StatusFactory::Instance()->RegisterErrorNo(name, desc) + +// Interface for Obtaining Error Code Description +#define GET_ERRORNO_STR(value) domi::StatusFactory::Instance()->GetErrDesc(value) + +namespace domi { +constexpr int32_t MODID_OMG = 1; // OMG module ID +constexpr int32_t MODID_OME = 2; // OME module ID +constexpr int32_t MODID_CALIBRATION = 3; // Calibration module ID + +class GE_FUNC_VISIBILITY StatusFactory { + public: + static StatusFactory *Instance(); + + bool RegisterErrorNo(const uint32_t err, const std::string &desc); + + std::string GetErrDesc(const uint32_t err); + + protected: + StatusFactory() = default; + virtual ~StatusFactory() = default; + + private: + std::map err_desc_; +}; + +// Common error code +DECLARE_ERRORNO_COMMON(MEMALLOC_FAILED, 0); // 50331648 +DECLARE_ERRORNO_COMMON(CCE_FAILED, 2); // 50331650 +DECLARE_ERRORNO_COMMON(RT_FAILED, 3); // 50331651
+DECLARE_ERRORNO_COMMON(INTERNAL_ERROR, 4); // 50331652 +DECLARE_ERRORNO_COMMON(CSEC_ERROR, 5); // 50331653 +DECLARE_ERRORNO_COMMON(TEE_ERROR, 6); // 50331654 +DECLARE_ERRORNO_COMMON(UNSUPPORTED, 100); +DECLARE_ERRORNO_COMMON(OUT_OF_MEMORY, 101); + +// Omg errorcode +DECLARE_ERRORNO_OMG(PARSE_MODEL_FAILED, 0); +DECLARE_ERRORNO_OMG(PARSE_WEIGHTS_FAILED, 1); +DECLARE_ERRORNO_OMG(NOT_INITIALIZED, 2); +DECLARE_ERRORNO_OMG(TIMEOUT, 3); + +// Ome errorcode +DECLARE_ERRORNO_OME(MODEL_NOT_READY, 0); +DECLARE_ERRORNO_OME(PUSH_DATA_FAILED, 1); +DECLARE_ERRORNO_OME(DATA_QUEUE_ISFULL, 2); +} // namespace domi + +#endif // INC_COMMON_FMK_ERROR_CODES_H_ diff --git a/inc/metadef/common/ge_common/fmk_types.h b/inc/metadef/common/ge_common/fmk_types.h new file mode 100644 index 0000000000000000000000000000000000000000..5447d982af0a8cfe9a885a4493464f0e212c695c --- /dev/null +++ b/inc/metadef/common/ge_common/fmk_types.h @@ -0,0 +1,16 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef INC_COMMON_FMK_TYPES_H_ +#define INC_COMMON_FMK_TYPES_H_ + +#include "graph/types.h" +#include "register/register_types.h" + +#endif // INC_COMMON_FMK_TYPES_H_ diff --git a/inc/metadef/common/ge_common/ge_inner_error_codes.h b/inc/metadef/common/ge_common/ge_inner_error_codes.h new file mode 100644 index 0000000000000000000000000000000000000000..34d8f177b7dce67e6b678b1241a85e63f2705a94 --- /dev/null +++ b/inc/metadef/common/ge_common/ge_inner_error_codes.h @@ -0,0 +1,364 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +/*lint -e* */ +#ifndef INC_COMMON_GE_INNER_ERROR_CODES_H_ +#define INC_COMMON_GE_INNER_ERROR_CODES_H_ + +#include +#include +#include "external/ge_common/ge_api_error_codes.h" + +// Each module defines error codes using the following macros, name can not be modified to (name) +#define GE_ERRORNO_COMMON(name, value, desc) \ + GE_ERRORNO(ge::InnLogRuntime::RT_HOST, ge::InnErrorCodeType::ERROR_CODE, \ + ge::InnErrorLevel::COMMON_LEVEL, ge::InnSystemIdType::SYSID_GE, \ + ge::InnSubModuleId::COMMON_MODULE, name, (value), (desc)) +#define GE_ERRORNO_CLIENT(name, value, desc) \ + GE_ERRORNO(ge::InnLogRuntime::RT_HOST, ge::InnErrorCodeType::ERROR_CODE, \ + ge::InnErrorLevel::COMMON_LEVEL, ge::InnSystemIdType::SYSID_GE, \ + ge::InnSubModuleId::CLIENT_MODULE, name, (value), (desc)) +#define GE_ERRORNO_INIT(name, value, desc) \ + GE_ERRORNO(ge::InnLogRuntime::RT_HOST, ge::InnErrorCodeType::ERROR_CODE, \ + ge::InnErrorLevel::COMMON_LEVEL, ge::InnSystemIdType::SYSID_GE, \ + ge::InnSubModuleId::INIT_MODULE, name, (value), (desc)) +#define GE_ERRORNO_SESSION(name, value, desc) \ + GE_ERRORNO(ge::InnLogRuntime::RT_HOST, ge::InnErrorCodeType::ERROR_CODE, \ + ge::InnErrorLevel::COMMON_LEVEL, ge::InnSystemIdType::SYSID_GE, \ + ge::InnSubModuleId::SESSION_MODULE, name, (value), (desc)) +#define GE_ERRORNO_GRAPH(name, value, desc) \ + GE_ERRORNO(ge::InnLogRuntime::RT_HOST, ge::InnErrorCodeType::ERROR_CODE, \ + ge::InnErrorLevel::COMMON_LEVEL, ge::InnSystemIdType::SYSID_GE, \ + ge::InnSubModuleId::GRAPH_MODULE, name, (value), (desc)) +#define GE_ERRORNO_ENGINE(name, value, desc) \ + GE_ERRORNO(ge::InnLogRuntime::RT_HOST, ge::InnErrorCodeType::ERROR_CODE, \ + ge::InnErrorLevel::COMMON_LEVEL, ge::InnSystemIdType::SYSID_GE, \ + ge::InnSubModuleId::ENGINE_MODULE, name, (value), (desc)) +#define GE_ERRORNO_OPS(name, value, desc) \ + 
GE_ERRORNO(ge::InnLogRuntime::RT_HOST, ge::InnErrorCodeType::ERROR_CODE, \ + ge::InnErrorLevel::COMMON_LEVEL, ge::InnSystemIdType::SYSID_GE, \ + ge::InnSubModuleId::OPS_MODULE, name, (value), (desc)) +#define GE_ERRORNO_PLUGIN(name, value, desc) \ + GE_ERRORNO(ge::InnLogRuntime::RT_HOST, ge::InnErrorCodeType::ERROR_CODE, \ + ge::InnErrorLevel::COMMON_LEVEL, ge::InnSystemIdType::SYSID_GE, \ + ge::InnSubModuleId::PLUGIN_MODULE, name, (value), (desc)) +#define GE_ERRORNO_RUNTIME(name, value, desc) \ + GE_ERRORNO(ge::InnLogRuntime::RT_HOST, ge::InnErrorCodeType::ERROR_CODE, \ + ge::InnErrorLevel::COMMON_LEVEL, ge::InnSystemIdType::SYSID_GE, \ + ge::InnSubModuleId::RUNTIME_MODULE, name, (value), (desc)) +#define GE_ERRORNO_EXECUTOR(name, value, desc) \ + GE_ERRORNO(ge::InnLogRuntime::RT_DEVICE, ge::InnErrorCodeType::ERROR_CODE, \ + ge::InnErrorLevel::COMMON_LEVEL, ge::InnSystemIdType::SYSID_GE, \ + ge::InnSubModuleId::EXECUTOR_MODULE, name, (value), (desc)) +#define GE_ERRORNO_GENERATOR(name, value, desc) \ + GE_ERRORNO(ge::InnLogRuntime::RT_HOST, ge::InnErrorCodeType::ERROR_CODE, \ + ge::InnErrorLevel::COMMON_LEVEL, ge::InnSystemIdType::SYSID_GE, \ + ge::InnSubModuleId::GENERATOR_MODULE, name, (value), (desc)) + +#define LLM_ERRORNO_COMMON(name, value, desc) \ + GE_ERRORNO(ge::InnLogRuntime::RT_HOST, ge::InnErrorCodeType::ERROR_CODE, \ + ge::InnErrorLevel::COMMON_LEVEL, ge::InnSystemIdType::SYSID_GE, \ + ge::InnSubModuleId::LLM_ENGINE_MODULE, name, (value), (desc)) + +// Get error code description +#define GE_GET_ERRORNO_STR(value) ge::StatusFactory::Instance()->GetErrDesc(value) + +#define RT_ERROR_TO_GE_STATUS(RT_ERROR) static_cast(RT_ERROR) + +namespace ge { +// System ID +enum class InnSystemIdType { SYSID_GE = 8 }; +// Runtime location +enum class InnLogRuntime { + RT_HOST = 0b01, + RT_DEVICE = 0b10, +}; + +// Sub module +enum class InnSubModuleId { + COMMON_MODULE = 0, + CLIENT_MODULE = 1, + INIT_MODULE = 2, + SESSION_MODULE = 3, + GRAPH_MODULE = 4, +
ENGINE_MODULE = 5, + OPS_MODULE = 6, + PLUGIN_MODULE = 7, + RUNTIME_MODULE = 8, + EXECUTOR_MODULE = 9, + GENERATOR_MODULE = 10, + LLM_ENGINE_MODULE = 11, +}; + +// Error code type +enum class InnErrorCodeType { + ERROR_CODE = 0b01, + EXCEPTION_CODE = 0b10, +}; + +// Error level +enum class InnErrorLevel { + COMMON_LEVEL = 0b000, + SUGGESTION_LEVEL = 0b001, + MINOR_LEVEL = 0b010, + MAJOR_LEVEL = 0b011, + CRITICAL_LEVEL = 0b100, +}; + +// Common module error code definition +GE_ERRORNO_COMMON(MEMALLOC_FAILED, 0, "Failed to allocate memory!"); // 1343225856 +GE_ERRORNO_COMMON(PARAM_INVALID, 1, "Parameter invalid!"); // 1343225857 +GE_ERRORNO_COMMON(CCE_FAILED, 2, "Failed to call CCE API!"); // 1343225858 +GE_ERRORNO_COMMON(RT_FAILED, 3, "Failed to call runtime API!"); // 1343225859 +GE_ERRORNO_COMMON(INTERNAL_ERROR, 4, "Internal errors"); // 1343225860 +GE_ERRORNO_COMMON(CSEC_ERROR, 5, "Failed to call libc_sec API!"); // 1343225861 +GE_ERRORNO_COMMON(TEE_ERROR, 6, "Failed to call tee API!"); // 1343225862 +GE_ERRORNO_EXTERNAL(END_OF_SEQUENCE, "End of sequence!"); // 1343225863 +GE_ERRORNO_COMMON(PATH_INVALID, 8, "Path is invalid!"); // 1343225864 + +// Error code for plugin manager +GE_ERRORNO_COMMON(GE_PLGMGR_PATH_INVALID, 30, "Path is invalid!"); // 1343225886 +GE_ERRORNO_COMMON(GE_PLGMGR_SO_NOT_EXIST, 31, "Failed to find any valid so file!"); // 1343225887 +GE_ERRORNO_COMMON(GE_PLGMGR_FUNC_NOT_EXIST, 32, "Failed to find any function!"); // 1343225888 +GE_ERRORNO_COMMON(GE_PLGMGR_INVOKE_FAILED, 33, "Failed to invoke any function!"); // 1343225889 + +GE_ERRORNO_COMMON(UNSUPPORTED, 100, "Parameter unsupported!"); + +GE_ERRORNO_COMMON(OUT_OF_MEMORY, 101, "Out of memory!"); + +// Client module error code definition +GE_ERRORNO_CLIENT(GE_CLI_INIT_FAILED, 1, "GEInitialize Failed."); // 1343229953 +GE_ERRORNO_CLIENT(GE_CLI_FINAL_FAILED, 2, "GEFinalize Failed."); // 1343229954 +GE_ERRORNO_CLIENT(GE_CLI_SESS_CONSTRUCT_FAILED, 3, "Session constructor Failed."); // 1343229955 
+GE_ERRORNO_CLIENT(GE_CLI_SESS_DESTROY_FAILED, 4, "Session destructor Failed."); // 1343229956 +GE_ERRORNO_CLIENT(GE_CLI_SESS_ADD_FAILED, 5, "Session AddGraph Failed."); // 1343229957 +GE_ERRORNO_CLIENT(GE_CLI_SESS_ADD_GRAPH_FAILED, 6, + "Session AddGraph Failed converting protobuf GraphProto."); // 1343229958 +GE_ERRORNO_CLIENT(GE_CLI_SESS_REMOVE_FAILED, 7, "Session RemoveGraph Failed."); // 1343229959 +GE_ERRORNO_CLIENT(GE_CLI_SESS_RUN_FAILED, 8, "Session RunGraph Failed."); // 1343229960 +GE_ERRORNO_CLIENT(GE_CLI_SESS_RUN_TENSOR_FAILED, 9, + "Session RunGraph Failed converting protobuf TensorProto."); // 1343229961 +GE_ERRORNO_CLIENT(GE_CLI_GE_ALREADY_INITIALIZED, 10, "GE is already initialized."); // 1343229962 +GE_ERRORNO_CLIENT(GE_CLI_GE_NOT_INITIALIZED, 11, "GE is not yet initialized or is finalized."); // 1343229963 + +// Init module error code definition +GE_ERRORNO_INIT(GE_MULTI_INIT, 0, "Multiple initializations are not supported."); // 1343234048 +GE_ERRORNO_INIT(GE_FINALIZE_NOT_INIT, 1, "Finalize is not allowed before initialization."); // 1343234049 +GE_ERRORNO_INIT(GE_MULTI_FINALIZE, 2, "Multiple finalizations are not supported."); // 1343234050 +GE_ERRORNO_INIT(GE_PROF_MULTI_INIT, 3, "Multiple profiling initializations are not supported."); // 1343234051 +GE_ERRORNO_INIT(GE_PROF_NOT_INIT, 4, "Profing initializations have not been done."); // 1343234052 +GE_ERRORNO_INIT(GE_PROF_MODE_CONFLICT, 5, + "Profiling command mode which is preferred is running, the api mode will not work."); // 1343234053 + +// Session module error code definition +GE_ERRORNO_SESSION(GE_SESS_INIT_FAILED, 0, "Failed to initialize session."); // 1343238144 +GE_ERRORNO_SESSION(GE_SESS_ALREADY_RUNNING, 1, "Session already running,not support parallel run."); // 1343238145 +GE_ERRORNO_SESSION(GE_SESS_GRAPH_NOT_EXIST, 2, "Graph ID not exist."); // 1343238146 +GE_ERRORNO_SESSION(GE_SESS_GRAPH_ALREADY_EXIST, 3, "Graph ID already exist."); // 1343238147 
+GE_ERRORNO_SESSION(GE_SESS_GRAPH_IS_RUNNING, 4, "Graph is running."); // 1343238148 +GE_ERRORNO_SESSION(GE_SESSION_NOT_EXIST, 5, "Can not find session with specific session id."); // 1343238149 +GE_ERRORNO_SESSION(GE_SESSION_MANAGER_NOT_INIT, 6, "Session manager has not been initialized."); // 1343238150 + +// Graph module error code definition +GE_ERRORNO_GRAPH(GE_GRAPH_INIT_FAILED, 0, "Failed to initialize graph."); // 1343242240 +GE_ERRORNO_GRAPH(GE_GRAPH_ALREADY_RUNNING, 1, "graph already running,not support parallel run."); // 1343242241 +GE_ERRORNO_GRAPH(GE_GRAPH_GRAPH_NOT_EXIST, 2, "graph ID not exist."); // 1343242242 +GE_ERRORNO_GRAPH(GE_GRAPH_GRAPH_ALREADY_EXIST, 3, "Graph ID already exist."); // 1343242243 +GE_ERRORNO_GRAPH(GE_GRAPH_GRAPH_IS_RUNNING, 4, "Graph is running."); // 1343242244 +GE_ERRORNO_GRAPH(GE_GRAPH_MALLOC_FAILED, 5, "Graph malloc failed."); // 1343242245 +GE_ERRORNO_GRAPH(GE_GRAPH_FREE_FAILED, 6, "Graph FREE failed."); // 1343242246 +GE_ERRORNO_GRAPH(GE_GRAPH_NOT_MALLOC_BUFFER, 7, "Graph FREE failed, not malloc buffer."); // 1343242247 +GE_ERRORNO_GRAPH(GE_GRAPH_PARAM_NULLPTR, 8, "Graph param is NULL."); // 1343242248 +GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, 9, "Get computeGraph by graphNode failed."); // 1343242249 +GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_RUN_GRAPH_NODE_NULL, 10, "Run graph node is null."); // 1343242250 +GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_RUN_GRAPH_INVALID, 11, "Get computeGraph by graphNode failed."); // 1343242251 +GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_INSERT_DYN_OP_FAILED, 12, "Graph which insert dynamic op failed."); // 1343242252 +GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_PREPROCESS_FAILED, 13, "Graph preprocess failed."); // 1343242253 +GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_GRAPH_FUSION_FAILED, 14, "Graph fusion failed."); // 1343242254 +GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_CALIBRATION_FAILED, 16, "Calibration failed."); // 1343242256 +GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_NUM_ZERO, 17, "Graph partition success, but 
subGraph num is 0."); // 1343242257 +GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_ENGINENAME_REPEATED, 18, "Graph subGraph engine name is repeated."); // 1343242258 +GE_ERRORNO_GRAPH(GE_GRAPH_GET_IN_OUT_FAILED, 19, "OME GetInputOutputDescInfo failed."); // 1343242259 +GE_ERRORNO_GRAPH(GE_GRAPH_DATA_INPUT_FAILED, 20, "OME DataInput failed."); // 1343242260 +GE_ERRORNO_GRAPH(GE_GRAPH_EXECUTE_FAILED, 21, "Execute graph failed."); // 1343242261 +GE_ERRORNO_GRAPH(GE_GRAPH_DUPLICATE_ENGINE, 22, "Duplicate engine."); // 1343242262 +GE_ERRORNO_GRAPH(GE_GRAPH_EMPTY_SUBGRAPH, 23, "Empty sub graph info."); // 1343242263 +GE_ERRORNO_GRAPH(GE_GRAPH_EXECUTE_NOT_INIT, 24, "Call SetCondition first."); // 1343242264 +GE_ERRORNO_GRAPH(GE_GRAPH_PREPARE_FAILED, 25, "Prepare failed."); // 1343242265 +GE_ERRORNO_GRAPH(GE_GRAPH_SERIALIZE_FAILED, 26, "OMG SerializeModelDef failed."); // 1343242266 +GE_ERRORNO_GRAPH(GE_GRAPH_SAVE_FAILED, 27, "OMG SaveModel failed."); // 1343242267 +GE_ERRORNO_GRAPH(GE_GRAPH_PRERUN_FAILED, 28, "PreRun failed."); // 1343242268 +GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_ID_INVALID, 29, "Graph subGraph id is invalid."); // 1343242269 +GE_ERRORNO_GRAPH(GE_GRAPH_INFERSHAPE_FAILED, 30, "Prepare Graph infershape failed"); // 1343242270 +GE_ERRORNO_GRAPH(GE_GRAPH_ISNULL, 31, "RunGraph input compute graph is NULL."); // 1343242271 +GE_ERRORNO_GRAPH(GE_GRAPH_SYNC_MODEL_FAILED, 32, "Graph SyncExecuteModel failed."); // 1343242272 +GE_ERRORNO_GRAPH(GE_GRAPH_RUNGRAPH_FAILED, 33, "Graph RunGraph failed."); // 1343242273 +GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_PARSE_DYN_OP_FAILED, 34, "Parse dynamic node config file failed"); // 1343242274 +GE_ERRORNO_GRAPH(GE_GRAPH_MULTI_SUBGRAPH_BUILD, 35, "Save model with multiple sub graph"); // 1343242275 +GE_ERRORNO_GRAPH(GE_GRAPH_GRAPH_NODE_NULL, 36, "Graph get graph node failed."); // 1343242276 +GE_ERRORNO_GRAPH(GE_GRAPH_NOT_INIT, 37, "Graph do not init."); // 1343242277 +GE_ERRORNO_GRAPH(GE_GRAPH_NULL_INPUT, 38, "input graph is null"); // 1343242278 
+GE_ERRORNO_GRAPH(GE_GRAPH_TOPO_SORT_FAILED, 39, "topological sorting an partition failed"); // 1343242279 +GE_ERRORNO_GRAPH(GE_GRAPH_EMPTY_PARTITION, 40, "accessing an empty partition"); // 1343242280 +GE_ERRORNO_GRAPH(GE_GRAPH_UNSUPPORTED, 41, "unsupported feature in partition"); // 1343242281 +GE_ERRORNO_GRAPH(GE_GRAPH_ASSIGN_ENGINE_FAILED, 42, "assign engine failed"); // 1343242282 +GE_ERRORNO_GRAPH(GE_GRAPH_ADD_PLC_END_FAILED, 43, "add placeholder end node failed"); // 1343242283 +GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_PARSE_OUT_NODE_FAILED, 44, "Parse out node failed."); // 1343242284 +GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_INSERT_OP_PARSE_FAILED, 45, + "OMG parse dynamic node config file failed."); // 1343242285 +GE_ERRORNO_GRAPH(GE_GRAPH_SAVE_WEIGHTS_FAILED, 46, "OMG Save Weights to Model failed."); // 1343242286 +GE_ERRORNO_GRAPH(GE_GRAPH_EMPTY_STRING_NAME, 47, "Empty string name."); // 1343242287 +GE_ERRORNO_GRAPH(GE_GRAPH_EMPTY_VARIABLE_TENSOR_TABLE, 48, "Empty variable-tensor table."); // 1343242288 +GE_ERRORNO_GRAPH(GE_GRAPH_VARIABLE_ALREADY_EXIST, 49, "Variable already exist."); // 1343242289 +GE_ERRORNO_GRAPH(GE_GRAPH_VARIABLE_DOES_NOT_EXIST, 50, "Variable does not exist."); // 1343242290 +GE_ERRORNO_GRAPH(GE_GRAPH_OPTIONS_INVALID, 51, "Client session options is invalid."); // 1343242291 +GE_ERRORNO_GRAPH(GE_GRAPH_NO_OUTPUT_DESC_INFO, 52, "No output desc info."); // 1343242292 +GE_ERRORNO_GRAPH(GE_GRAPH_OUTPUT_DESCINFO_TENSOR_NUM_MISMATCH, 53, + "Number of output descinfo and tensor mismatch."); // 1343242293 +GE_ERRORNO_GRAPH(GE_GRAPH_FILENAMEPREFIX_INVALID, 54, "Graph Save Model fileNamePrefix is invalid."); // 1343242294 +GE_ERRORNO_GRAPH(GE_GRAPH_NOT_BUILT, 55, "Graph is not compiled."); // 1343242295 +GE_ERRORNO_GRAPH(GE_GRAPH_SAVEMODEL_FAILED, 56, "Graph SaveModel failed."); // 1343242296 +GE_ERRORNO_GRAPH(GE_GRAPH_MEMORY_ALLOC_FAILED, 57, "Failed allocating memory for model file header."); // 1343242297 
+GE_ERRORNO_GRAPH(GE_GRAPH_NODE_SEARCHER_REMOVE_GRAPH_FAILED, 58, "Failed remove graph in node seacher."); // 1343242298 +GE_ERRORNO_GRAPH(GE_GRAPH_NODE_SEARCHER_ADD_GRAPH_FAILED, 59, "Failed add graph in node seacher."); // 1343242299 +GE_ERRORNO_GRAPH(GE_GRAPH_NODE_SEARCHER_GET_GRAPH_REBUILD_FAILED, 60, + "Failed add graph in node seacher."); // 1343242300 +GE_ERRORNO_GRAPH(GE_GRAPH_NODE_SEARCHER_SET_GRAPH_FINISH_REBUILD_GRAPH_FAILED, 61, + "Failed set graph finish rebuild in node searcher."); // 1343242301 +GE_ERRORNO_GRAPH(GE_GRAPH_VARIABLE_OP_PASS_FAILED, 62, "Failed to run variable pass."); // 1343242302 +GE_ERRORNO_GRAPH(TO_BE_DELETED, 63, "The node of the graph to be deleted."); // 1343242303 +GE_ERRORNO_GRAPH(NOT_CHANGED, 64, "The node of the graph no changed."); // 1343242304 +GE_ERRORNO_GRAPH(SUSPEND, 65, "The optimize operation is suspended."); // 1343242305 +GE_ERRORNO_GRAPH(GE_GRAPH_REPEAT_OPERATION, 66, "Repeat operation is not allowed."); // 1343242306 + +// Engine_manager module error code definition +GE_ERRORNO_ENGINE(GE_ENG_INIT_FAILED, 0, "Failed to initialize engine."); // 1343246336 +GE_ERRORNO_ENGINE(GE_ENG_FINALIZE_FAILED, 1, "Engine finalize failed."); // 1343246337 +GE_ERRORNO_ENGINE(GE_ENG_MEMTYPE_ERROR, 2, "Memory type HBM is necessary when engine is in device"); // 1343246338 + +// Ops module error code definition +GE_ERRORNO_OPS(GE_OPS_KERNEL_STORE_INIT_FAILED, 0, "Failed to initialize OpsKernelInfoStore."); // 1343250432 +GE_ERRORNO_OPS(GE_OPS_GRAPH_OPTIMIZER_INIT_FAILED, 1, "Failed to initialize GraphOptimizer."); // 1343250433 +GE_ERRORNO_OPS(GE_OPS_KERNEL_INFO_NOT_EXIST, 2, "OpsKernelInfo not exist."); // 1343250434 +GE_ERRORNO_OPS(GE_OPS_KERNEL_STORE_NOT_EXIST, 3, "OpsKernelInfoStore not exist."); // 1343250435 +GE_ERRORNO_OPS(GE_OPS_CALC_RUNNING_PARAM_FAILED, 4, "Failed to CalcOpRunningParam."); // 1343250436 +GE_ERRORNO_OPS(GE_OPS_GENERATE_TASK_FAILED, 5, "Failed to GenerateTask."); // 1343250437 
+GE_ERRORNO_OPS(GE_OPS_OPTIMIZE_ORIGINAL_GRAPH_FAILED, 6, "Failed to OptimizeOriginalGraph."); // 1343250438 +GE_ERRORNO_OPS(GE_OPS_OPTIMIZE_FUSED_GRAPH_FAILED, 7, "Failed to OptimizeFusedGraph."); // 1343250439 +GE_ERRORNO_OPS(GE_OPS_ENGINE_IS_NOT_REGISTERED, 8, "Engine is not registered."); // 1343250440 +GE_ERRORNO_OPS(GE_OPS_GET_NO_VALID_SO, 9, + "There is no valid so about OpsKernelInfoStore or GraphOptimizer."); // 1343250441 +GE_ERRORNO_OPS(GE_OPS_GET_OPTIMIZE_BY_ENGINE_FAILED, 10, "Failed to get graphOptimizer by name."); // 1343250442 +GE_ERRORNO_OPS(GE_OPS_GET_OPTIMIZE_BY_PRIORITY_FAILED, 11, "Failed to get graphOptimizer by priority."); // 1343250443 +GE_ERRORNO_OPS(GE_OPS_LOAD_GE_OPTIMIZER_FAILED, 12, "Failed to load ge graphOptimizer."); // 1343250444 + +// Runtime module error code definition +GE_ERRORNO_RUNTIME(GE_RTI_DEVICE_ID_INVALID, 1, "device id is invalid"); +GE_ERRORNO_RUNTIME(GE_RTI_DEVICE_NOT_READY, 2, "set device failed, device not ready"); +GE_ERRORNO_RUNTIME(GE_RTI_MEMALLOC_FAILED, 3, "malloc memory failed"); +GE_ERRORNO_RUNTIME(GE_RTI_MODEL_NOT_LOADED, 4, "model has not been loaded"); +GE_ERRORNO_RUNTIME(GE_RTI_THREAD_POOL_IS_NULL, 5, "model excute failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_CCE_CREATE_HANDLE_FAILED, 6, "cce create handle failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_CCE_SET_STREAM_FAILED, 7, "cce set stream failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_CREATE_RTMODEL_FAILED, 8, "call runtime create rtModel failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_CREATE_STREAM_FAILED, 9, "call runtime create stream failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_BIND_STREAM_FAILED, 10, "call runtime bind stream to model failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_CREATE_LABLE_FAILED, 11, "call runtime create lable failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_MODEL_LOAD_COMPLETE_FAILED, 12, "call runtime model load complete failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_MODEL_GET_TASK_ID_FAILED, 14, "call runtime 
get task id failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_KERNEL_LAUNCH_FAILED, 13, "call runtime kernel launch failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_KERNEL_LAUNCHEX_FAILED, 15, "call runtime kernel launchex failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_KERNEL_FUSION_START_FAILED, 16, "call runtime kernel fusion start failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_KERNEL_FUSION_END_FAILED, 17, "call runtime kernel fusion end failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_LABEL_SET_FAILED, 18, "call runtime lable set failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_LABLE_GOTO_FAILED, 19, "call runtime lable goto failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_LABLE_SWITCH_FAILED, 20, "call runtime lable switch failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_MEM_ALLOC_MANAGED_FAILED, 21, "call runtime mem alloc managed failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_MEM_FREE_MANAGED_FAILED, 22, "call runtime mem free managed failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_FREE_FAILED, 23, "call runtime free failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_STREAM_SYNC_FAILED, 24, "call runtime sync stream failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_MODEL_EXCUTE_FAILED, 25, "call runtime model excute failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_MEM_ASYNC_FAILED, 26, "call runtime mem async failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_MEM_ALLOC_HOST_FAILED, 27, "call runtime alloc host memory failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_MEM_FREE_HOST_FAILED, 28, "call runtime free host memory failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_MEM_ALLOC_DEVICE_FAILED, 29, "call runtime alloc device memory failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_MEM_FREE_DEVICE_FAILED, 30, "call runtime free device memory failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_FLUSH_CACHE_FAILED, 31, "call runtime flush cache failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_UNBIND_STREAM_FAILED, 32, "unbind rtstream from rtmodel 
failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_DESTORY_STREAM_FAILED, 33, "destory stream failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_DESTORY_LABEL_FAILED, 34, "destory label failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_DESTORY_MODEL_FAILED, 35, "destory model failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_CCE_TRANS_TENSOR_FAILED, 36, "call cce transfer tensor descriptor failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_CCE_TRANS_FILTER_FAILED, 37, "call cce transfer filter descriptor failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_CCE_UPDATE_KERNEL_ARGS_FAILED, 38, "call cce update kernel args failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_CCE_DESTORY_HANDLE_FAILED, 39, "destory handle failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_CREATE_EVENT_FAILED, 40, "call rutime create event failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_EVENT_RECORD_FAILED, 41, "call rutime event record failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_STREAM_WAIT_EVENT_FAILED, 42, "call rutime stream wait event failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_HCCL_BROADCAST_FAILED, 43, "call hccl hcom broadcast failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_HCCL_ALL_GATHER_FAILED, 44, "call hccl hcom all gather failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_HCCL_ALL_REDUCE_FAILED, 45, "call hccl hcom all reduce failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_RUNTIME_DESTORY_EVENT_FAILED, 46, "destory rt event failed"); +GE_ERRORNO_RUNTIME(GE_RTI_CALL_HCCL_REDUCE_SCATTER_FAILED, 47, "call hccl hcom reduce scatter failed"); + +// Executor module error code definition +GE_ERRORNO_EXECUTOR(GE_EXEC_NOT_INIT, 1, "GE Executor is not yet initialized."); +GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_PATH_INVALID, 2, "Model file path is invalid."); +GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_KEY_PATH_INVALID, 3, "Key file path of model is invalid."); +GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_ID_INVALID, 4, "Model id is invalid."); +GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_DATA_SIZE_INVALID, 5, "Data size of model is invalid."); 
+GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_PARTITION_NUM_INVALID, 6, "Partition number of model is invalid."); +GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_QUEUE_ID_INVALID, 7, "Queue id of model is invalid."); +GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_NOT_SUPPORT_ENCRYPTION, 8, "Model does not support encryption."); +GE_ERRORNO_EXECUTOR(GE_EXEC_READ_MODEL_FILE_FAILED, 9, "Failed to read model file."); +GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_MODEL_REPEATED, 10, "The model is loaded repeatedly."); +GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_MODEL_PARTITION_FAILED, 11, "Failed to load model partition."); +GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_WEIGHT_PARTITION_FAILED, 12, "Failed to load weight partition."); +GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_TASK_PARTITION_FAILED, 13, "Failed to load task partition."); +GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_KERNEL_PARTITION_FAILED, 14, "Failed to load kernel partition."); +GE_ERRORNO_EXECUTOR(GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED, 15, "Failed to allocate feature map memory."); +GE_ERRORNO_EXECUTOR(GE_EXEC_ALLOC_WEIGHT_MEM_FAILED, 16, "Failed to allocate weight memory."); +GE_ERRORNO_EXECUTOR(GE_EXEC_ALLOC_VAR_MEM_FAILED, 17, "Failed to allocate variable memory."); +GE_ERRORNO_EXECUTOR(GE_AIPP_NOT_EXIST, 18, "GE AIPP is not exist."); +GE_ERRORNO_EXECUTOR(GE_DYNAMIC_AIPP_NOT_SUPPORT_QUERY, 19, "GE Dynamic AIPP is not support to query temporarily."); +GE_ERRORNO_EXECUTOR(GE_EXEC_ALLOC_P2P_MEM_FAILED, 20, "Failed to allocate P2P memory"); + +// Generator module error code definition +GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED, 1, "Graph manager initialize failed."); +GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, 2, "Graph manager add graph failed."); +GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, 3, "Graph manager build graph failed."); +GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED, 4, "Graph manager finalize failed."); +GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_SAVE_MODEL_FAILED, 5, "Graph manager save 
model failed."); + +// 分步将错误码移动到air仓llm目录,完成后删除该宏 +#ifndef LLM_ERROR_CODES +#define LLM_ERROR_CODES +LLM_ERRORNO_COMMON(LLM_WAIT_PROC_TIMEOUT, 1, "request wait to be processed timeout!"); +LLM_ERRORNO_COMMON(LLM_KV_CACHE_NOT_EXIST, 2, "kv cache not exit!"); +LLM_ERRORNO_COMMON(LLM_REPEAT_REQUEST, 3, "repeat request!"); +LLM_ERRORNO_COMMON(LLM_REQUEST_ALREADY_COMPLETED, 4, "request already complete!"); +LLM_ERRORNO_COMMON(LLM_PARAM_INVALID, 5, "parameter is invalid!"); +LLM_ERRORNO_COMMON(LLM_ENGINE_FINALIZED, 6, "llm engine finalized!"); +LLM_ERRORNO_COMMON(LLM_NOT_YET_LINK, 7, "local cluster is not linked with remote cluster!"); +LLM_ERRORNO_COMMON(LLM_ALREADY_LINK, 8, "local cluster is already linked with remote cluster!"); +LLM_ERRORNO_COMMON(LLM_LINK_FAILED, 9, "local cluster link with remote cluster failed!"); +LLM_ERRORNO_COMMON(LLM_UNLINK_FAILED, 10, "local cluster unlink with remote cluster failed!"); +LLM_ERRORNO_COMMON(LLM_NOTIFY_PROMPT_UNLINK_FAILED, 11, "local cluster notify remote cluster do unlink failed!"); +LLM_ERRORNO_COMMON(LLM_CLUSTER_NUM_EXCEED_LIMIT, 12, "cluster num exceed limit!"); +LLM_ERRORNO_COMMON(LLM_PROCESSING_LINK, 13, "link is current processing, try again later!"); +LLM_ERRORNO_COMMON(LLM_DEVICE_OUT_OF_MEMORY, 14, "device out of memory!"); +LLM_ERRORNO_COMMON(LLM_PREFIX_ALREADY_EXIST, 15, "Prefix has already existed."); +LLM_ERRORNO_COMMON(LLM_PREFIX_NOT_EXIST, 16, "Prefix does not exist."); +LLM_ERRORNO_COMMON(LLM_SEQ_LEN_OVER_LIMIT, 17, "Sequence length exceed limit."); +LLM_ERRORNO_COMMON(LLM_NO_FREE_BLOCK, 18, "No free block."); +LLM_ERRORNO_COMMON(LLM_BLOCKS_OUT_OF_MEMORY, 19, "Block is out of memory."); +#endif +} // namespace ge + +#endif // INC_COMMON_GE_INNER_ERROR_CODES_H_ diff --git a/inc/metadef/common/ge_common/ge_types.h b/inc/metadef/common/ge_common/ge_types.h new file mode 100644 index 0000000000000000000000000000000000000000..37a3615ce1ee8f7d5bc87349c9eeb425c191ac5b --- /dev/null +++ 
b/inc/metadef/common/ge_common/ge_types.h @@ -0,0 +1,518 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_COMMON_GE_TYPES_H_ +#define INC_COMMON_GE_TYPES_H_ + +#include +#include +#include + +#include "common/ge_common/fmk_error_codes.h" +#include "external/ge_common/ge_api_error_codes.h" +#include "external/graph/types.h" +#include "external/ge_common/ge_api_types.h" + +namespace ge { +enum RuntimeType { HOST = 0, DEVICE = 1 }; + +enum class PerfLevel : int32_t { + GEN_TASK_WITH_FUSION = -1, + GEN_TASK_WITHOUT_L2FUSION = 3, + GEN_TASK_WITHOUT_FUSION = 4 +}; + +enum FrameworkType { + CAFFE = 0, + MINDSPORE = 1, + TENSORFLOW = 3, + ANDROID_NN = 4, + ONNX = 5, +}; + +enum class GraphStage : int64_t { + GRAPH_STAGE_FUZZ = 0, + GRAPH_STAGE_RESERVED +}; + +const char_t *const kGraphDumpStage = "DumpStage"; + +const std::map kFwkTypeToStr = {{"0", "Caffe"}, + {"1", "MindSpore"}, + {"3", "TensorFlow"}, + {"4", "Android_NN"}, + {"5", "Onnx"}}; + +enum OpEngineType { + ENGINE_SYS = 0, // default engine + ENGINE_AICORE = 1, + ENGINE_VECTOR = 2, + ENGINE_AICUBE = 3, // not support + ENGINE_AIVECTOR = 4 // not support +}; + +enum InputAippType { DATA_WITHOUT_AIPP = 0, DATA_WITH_STATIC_AIPP, DATA_WITH_DYNAMIC_AIPP, DYNAMIC_AIPP_NODE }; + +enum OfflineModelFormat { OM_FORMAT_DEFAULT, 
OM_FORMAT_LITE, OM_FORMAT_NANO }; + +const char_t *const GE_ENGINE_ATTR_MEM_TYPE_HBM = "HBM"; +const char_t *const GE_OPTION_EXEC_PLACEMENT = "ge.exec.placement"; + +// profiling data + +const std::string kTaskTypeAicore = "AI_CORE"; +const std::string kTaskTypeAiv = "AIV"; +const std::string kTaskTypeMixAic = "MIX_AIC"; +const std::string kTaskTypeMixAiv = "MIX_AIV"; +const std::string kTaskTypeAicpu = "AI_CPU"; +const std::string kTaskTypeDsa = "DSA"; +const std::string kTaskTypeWriteBackData = "WRITE_BACK"; +const std::string kTaskTypeInvalidData = "INVALID"; +const std::string kTaskTypeInvalid = "TASK_TYPE_INVALID"; +const std::string kTaskTypeFftsPlus = "FFTS_PLUS"; +const std::string kEngineNameVectorCore = "VectorEngine"; + +const std::string kEngineNameHccl = "ops_kernel_info_hccl"; +const std::string kEngineNameRts = "DNN_VM_RTS_OP_STORE"; +const std::string kEngineNameHostCpu = "DNN_VM_HOST_CPU_OP_STORE"; +const std::string kEngineNameGeLocal = "DNN_VM_GE_LOCAL_OP_STORE"; +const std::string kEngineNameAiCpu = "aicpu_ascend_kernel"; +const std::string kEngineNameAiCpuTf = "aicpu_tf_kernel"; +const std::string kEngineNameAiCore = "AIcoreEngine"; +const std::string kEngineNameDvpp = "dvpp_ops_kernel"; +const std::string kEngineNameDsa = "DSAEngine"; +const std::string kAtomicOpType = "DynamicAtomicAddrClean"; +const char_t *const kAICpuKernelLibName = "aicpu_kernel_lib_name"; +const char_t *const kPartiallySupported = "partially_supported"; + +// runtime2.0 lowering func +const std::string kAttrLowingFunc = "_ge_attr_lowering_func"; +const std::string kFFTSAiCoreLowerFunc = "ffts_ai_core_lower_func"; +const std::string kFFTSGraphLowerFunc = "ffts_graph_lower_func"; +const std::string kFFTSStaticGraphLowerFunc = "ffts_static_graph_lower_func"; +const std::string kFFTSMixL2LowerFunc = "ffts_mix_l2_lower_func"; +// runtime2.0 calculate func +const std::string kAttrCalcArgsSizeFunc = "_ge_attr_calculate_func"; +const std::string kFFTSMixL2CalcFunc = 
"ffts_mix_l2_calc_func"; + +const std::string kInputTensorIndexs = "input_tensor_indexs"; +const std::string kOutputTensorIndexs = "output_tensor_indexs"; +const std::string kShapeTypeStatic = "static"; +const std::string kShapeTypeDynamic = "dynamic"; +const std::string kAtomicPrefix = "_atomic"; + +constexpr uint64_t kInferSessionId = 0U; +constexpr uint64_t kReleaseFlag = 1U; +constexpr uint32_t kInvalidModelId = 0xFFFFFFFFU; +constexpr size_t kNumTaskWithAtomicAddrCleanTask = 2U; +constexpr uint32_t INVALID_MODEL_ID = 0xFFFFFFFFU; + +// dynamic execute mode +const char_t *const kLazyRecompile = "lazy_recompile"; +const char_t *const kIsCopyOuputAddr = "1"; + +constexpr size_t kMaxHostMemInputLen = 128U; // 64 aligned + +// memory policy +const std::string kBalanceMode = "BalanceMode"; +const std::string kMemoryPriority = "MemoryPriority"; +const std::set kValidValues = {"", kBalanceMode, kMemoryPriority}; + +const uint32_t kManualThread = 0U; +const uint32_t kAutoThread = 1U; + +// model deploy mode +const std::string kModelDeployModeSpmd = "SPMD"; + +// dsa +constexpr size_t kDSASetInputAddr = 0U; +constexpr size_t kDSAOutputAddrSize = 1U; +constexpr size_t kDSAWorkspaceAddrSize = 2U; +constexpr size_t kDSAInputAddrSize = 3U; +constexpr size_t kDSAArgsInputAddrSize = 4U; +constexpr size_t kDSAStateInputAddrSize = 5U; +constexpr size_t k32Bits = 32U; + +// mix +constexpr size_t kMixMultiKernelPcAddrCnt = 1U; +constexpr size_t kMixSingleKernelPcAddrCnt = 2U; +constexpr size_t kMixSingleOnlyKernelPcAddrCnt = 1U; +constexpr size_t kMixSingleKernelAicPcIndex = 0U; +constexpr size_t kMixSingleKernelAivPcIndex = 1U; + +// Data cache, including data address and length +struct DataBuffer { + void *data; // Data address + uint64_t length; // Data length + bool isDataSupportMemShare; + uint32_t placement; + + DataBuffer(void *const data_in, const uint64_t data_len, const bool is_support_mem_share = false, + const uint32_t data_placement = 0U) : data(data_in), 
length(data_len), + isDataSupportMemShare(is_support_mem_share), + placement(data_placement) {} + DataBuffer() : data(nullptr), length(0UL), isDataSupportMemShare(false), + placement(0U) {} +}; + +/// +/// @ingroup domi_ome +/// @brief External input data +/// +struct InputData { + uint32_t index; // Index of input data + uint32_t timestamp; // Data creation time + uint32_t timeout; // Processing timeout + uint32_t model_id; // Model ID required for data processing + uint64_t request_id = 0UL; // Request ID + std::vector blobs; // Actual input data, currently only supports one input + bool is_dynamic_batch = false; // Whether is dynamic batch size scene, default:false + std::string batch_label; // Gear used for current inference in dynamic batch scene + std::vector> shapes; // Input shapes +}; + +/// Output result structure definition +struct OutputData { + uint32_t index; // Index of input data + uint32_t model_id; // The model ID corresponding to the processing result + /// Output data cache, arranged in sequence of output operators. 
+ /// If the operator has multiple outputs, + /// the data buffer order of the operator is the same as that defined in the + /// offline model + std::vector blobs; +}; + +// The definition of command data structure +struct Command { + std::string cmd_type; // Command type + std::vector cmd_params; // Command params + uint64_t module_index; // prof module + uint32_t cache_flag; // clear prof cache flag +}; + +// The definition of I/O shape description +struct ShapeDescription { + int64_t num = 0L; + int64_t channel = 0L; + int64_t height = 0L; + int64_t width = 0L; + std::vector dims; + std::vector> shape_ranges; +}; + +// Definition of input and output description information +struct InputOutputDescInfo { + std::string name; + uint64_t size; + uint32_t data_type; + ShapeDescription shape_info; +}; + +// Definition of model io dims +struct InputOutputDims { + std::string name; + size_t dim_num; + uint32_t size; + std::vector dims; +}; + +// Definition of model io dims +struct OriginInputInfo { + Format format; + DataType data_type; + uint32_t dim_num; +}; + +// The structure of AIPP info +struct AippConfigInfo { + int8_t aipp_mode; + int8_t input_format; + int32_t src_image_size_w; + int32_t src_image_size_h; + int8_t crop; + int32_t load_start_pos_w; + int32_t load_start_pos_h; + int32_t crop_size_w; + int32_t crop_size_h; + int8_t resize; + int32_t resize_output_w; + int32_t resize_output_h; + int8_t padding; + int32_t left_padding_size; + int32_t right_padding_size; + int32_t top_padding_size; + int32_t bottom_padding_size; + int8_t csc_switch; + int8_t rbuv_swap_switch; + int8_t ax_swap_switch; + int8_t single_line_mode; + int32_t matrix_r0c0; + int32_t matrix_r0c1; + int32_t matrix_r0c2; + int32_t matrix_r1c0; + int32_t matrix_r1c1; + int32_t matrix_r1c2; + int32_t matrix_r2c0; + int32_t matrix_r2c1; + int32_t matrix_r2c2; + int32_t output_bias_0; + int32_t output_bias_1; + int32_t output_bias_2; + int32_t input_bias_0; + int32_t input_bias_1; + int32_t 
input_bias_2; + int32_t mean_chn_0; + int32_t mean_chn_1; + int32_t mean_chn_2; + int32_t mean_chn_3; + float32_t min_chn_0; + float32_t min_chn_1; + float32_t min_chn_2; + float32_t min_chn_3; + float32_t var_reci_chn_0; + float32_t var_reci_chn_1; + float32_t var_reci_chn_2; + float32_t var_reci_chn_3; + int8_t support_rotation; + uint32_t related_input_rank; + uint32_t max_src_image_size; +}; + +// The structure of offline Modeldata +struct ModelData { + void *model_data = nullptr; // Model binary data start addr + uint64_t model_len = 0UL; // Model binary data length + int32_t priority = 0; // Model priority + std::string key; // Key path for encrypt model, Empty for unencrypt + std::string om_name; // om file name, used for data dump + std::string om_path; // om file path, used for concatenating file constant path + std::string weight_path; // weight path, used for load weight +}; + +struct ModelParam { + ModelParam() : priority(0), mem_base(0U), mem_size(0U), weight_base(0U), weight_size(0U), fixed_mem_base(0U), + fixed_mem_size(0U), p2p_fixed_mem_base(0U), p2p_fixed_mem_size(0U) {} + ModelParam(const int32_t pri, const uintptr_t m_base, const size_t m_len, const uintptr_t w_base, const size_t w_len) + : priority(pri), mem_base(m_base), mem_size(m_len), weight_base(w_base), weight_size(w_len), fixed_mem_base(0U), + fixed_mem_size(0U), p2p_fixed_mem_base(0U), p2p_fixed_mem_size(0U) {} + virtual ~ModelParam() = default; + + int32_t priority; + uintptr_t mem_base; + size_t mem_size; + uintptr_t weight_base; + size_t weight_size; + uintptr_t fixed_mem_base; + size_t fixed_mem_size; + uintptr_t p2p_fixed_mem_base; + size_t p2p_fixed_mem_size; +}; + +// The definition of Model information +struct ModelInfo { + uint32_t version = 0U; + std::string name; + bool is_encrypt = false; // 0:unencrypt, 1:encrypt + std::vector input_desc; + std::vector output_desc; + uint8_t reserved[3] = {0U}; // 3-byte reserved field +}; + +// Asynchronous callback interface, implemented 
by the caller +class GE_FUNC_VISIBILITY ModelListener { + public: + virtual ~ModelListener() = default; + ModelListener() = default; + ModelListener(const ModelListener &) = delete; + ModelListener& operator=(const ModelListener &) & = delete; + /// + /// @brief Asynchronous callback interface + /// @param [in] model_id Model ID of the callback + /// @param [in] data_index Index of the input_data + /// @param [in] resultCode Execution results + /// + virtual Status OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t result_code, + std::vector &outputs) = 0; + + virtual void SetCallback(const RunAsyncCallback &callback) { + (void)callback; + } + + virtual uint32_t GetResultCode() { return 0U; }; + + virtual Status ResetResult() { return SUCCESS; }; +}; + +// Profiling info of task +struct TaskDescInfo { + uint64_t prof_time; + std::string model_name; + std::string op_name; + std::string op_type; + uint32_t block_dim; + uint32_t task_id; + uint32_t stream_id; + std::string shape_type; + int64_t cur_iter_num; + std::string task_type; + std::vector input_format; + std::vector> input_shape; + std::vector input_data_type; + std::vector output_format; + std::vector> output_shape; + std::vector output_data_type; + uint32_t context_id = 0xFFFFFFFFU; +}; + +struct OpDescInfoId { + uint32_t task_id; + uint32_t stream_id; + uint32_t context_id; + uint32_t thread_id; + int32_t device_id; + + OpDescInfoId(): task_id(0U), stream_id(0U), context_id(UINT32_MAX), thread_id(UINT32_MAX), device_id(0) {} + OpDescInfoId(const uint32_t task, const uint32_t stream) + : OpDescInfoId() { + task_id = task; + stream_id = stream; + } + OpDescInfoId(const uint32_t task, const uint32_t stream, const uint32_t context, const uint32_t thread) + : OpDescInfoId(task, stream) { + context_id = context; + thread_id = thread; + } + // for exception dump(one process may has multi thread for mutli device ids) + OpDescInfoId(const uint32_t task, const uint32_t stream, const int32_t dev_id) + : 
task_id(task), stream_id(stream), context_id(UINT32_MAX), thread_id(UINT32_MAX), device_id(dev_id) {} + OpDescInfoId(const uint32_t task, const uint32_t stream, const uint32_t context, const uint32_t thread, + const int32_t dev_id) + : task_id(task), stream_id(stream), context_id(context), thread_id(thread), device_id(dev_id) {} +}; + +struct OpDescInfo { + std::string op_name; + std::string op_type; + OpDescInfoId id; + uint32_t imply_type = 0U; + uint32_t block_dim = 0U; + std::string op_file_path; + std::string dev_func; + std::string tvm_magic; + uint32_t tiling_key = 0U; + uintptr_t args = 0U; + size_t args_size = 0UL; + std::map cust_to_relevant_offset_; + std::string tiling_data; + bool is_mem_log; + std::vector space_addrs; + std::string node_info; + std::vector workspace_bytes; + std::vector input_format; + std::vector> input_shape; + std::vector input_data_type; + std::vector input_addrs; + std::vector input_size; + std::vector output_format; + std::vector> output_shape; + std::vector output_data_type; + std::vector output_addrs; + std::vector output_size; + bool is_host_args{false}; + std::string all_attrs; + std::string args_before_execute; +}; +struct ModelDumpConfig { + std::string model_name; + std::vector layers; + std::vector watcher_nodes; +}; + +struct DumpConfig { + std::string dump_path; + std::string dump_mode; + std::string dump_status; + std::string dump_op_switch; + std::string dump_debug; + std::string dump_step; + std::string dump_exception; + std::vector dump_list; + std::string dump_data; + std::string dump_level; + std::vector dump_stats; +}; + +struct QueueAttrs { + uint32_t queue_id; + int32_t device_type; // CPU NPU + int32_t device_id; + uint32_t logic_id {0U}; +}; + +struct InputAlignAttrs { + uint32_t align_max_cache_num; // 0 means align not enable + int32_t align_timeout; // -1 means never timeout + bool drop_when_not_align; + uint8_t res[3]; +}; +static_assert(std::is_pod::value, "The class InputAlignAttrs must be a POD"); + 
+struct ModelQueueParam { + uint32_t group_total_count{1}; + uint32_t group_index{0U}; + uint32_t group_policy{0U}; + std::vector input_queues; + std::vector output_queues; + std::vector input_fusion_offsets; + std::vector input_events; + std::vector output_events; + std::vector input_queues_attrs; + std::vector output_queues_attrs; + QueueAttrs status_output_queue; + uint32_t model_uuid {0U}; + bool is_dynamic_sched {false}; + bool need_report_status {false}; + InputAlignAttrs input_align_attrs{}; + bool is_head {true}; + bool no_need_check_inputs {false}; +}; + +// internal options +// 1: Graph resource evaluation does not limit model memory size. +const char_t *const EVALUATE_GRAPH_RESOURCE_MODE = "ge.evaluateGraphResourceMode"; + +// 3: Config all resource and device mesh +const char_t *const RESOURCE_CONFIG_PATH = "ge.resourceConfigPath"; + +// 5: auto recompute attribute +const char_t *const RECOMPUTE = "ge.recompute"; +const char_t *const GRAPH_SLICE_MODE = "ge.graphSliceMode"; + +// 6: Topological Sorting Mode +const char_t *const OPTION_TOPOSORTING_MODE = "ge.topoSortingMode"; + +const char_t *const OPTION_EXEC_RANK_TABLE = "ge.exec.rankTable"; +const char_t *const OPTION_EXEC_HCOM_GROUPLIST = "ge.exec.hcomGrouplist"; +const char_t *const OPTION_EXEC_HCOM_RANK_MAPPING = "ge.exec.hcomRankMapping"; + +const char_t *const OPTION_NUMA_CONFIG = "ge.numaConfig"; + +// 7: config format mode(expirimental option) +const char_t *const OPTION_EXEC_FORMAT_MODEL = "ge.exec.formatMode"; + +// 8: config build graph mode(online or offline) +const char_t *const OPTION_BUILD_GRAPH_MODE = "ge.buildGraphMode"; + +const std::set ir_builder_suppported_options_inner = {EVALUATE_GRAPH_RESOURCE_MODE, + RESOURCE_CONFIG_PATH, + RECOMPUTE, + OPTION_TOPOSORTING_MODE, + GRAPH_SLICE_MODE}; +} // namespace ge +#endif // INC_COMMON_GE_TYPES_H_ diff --git a/inc/metadef/common/ge_common/scope_guard.h b/inc/metadef/common/ge_common/scope_guard.h new file mode 100644 index 
0000000000000000000000000000000000000000..bc7b5541e8264265c9ac0700faee1d340d0a631d
--- /dev/null
+++ b/inc/metadef/common/ge_common/scope_guard.h
@@ -0,0 +1,53 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_COMMON_SCOPE_GUARD_H_
+#define INC_COMMON_SCOPE_GUARD_H_
+
+#include <functional>
+
+/// Usage (the callback runs when the guard variable goes out of scope, unless it was dismissed):
+///   Acquire Resource 1
+///   GE_MAKE_GUARD(res1, [&] { Release Resource 1 });
+///   Acquire Resource 2
+///   GE_MAKE_GUARD(res2, [&] { Release Resource 2 });
+#define GE_MAKE_GUARD(var, callback) const ::ge::ScopeGuard const_guard_##var(callback)
+
+#define GE_DISMISSABLE_GUARD(var, callback) ::ge::ScopeGuard make_guard_##var(callback)
+#define GE_DISMISS_GUARD(var) make_guard_##var.Dismiss()
+
+namespace ge {
+class GE_FUNC_VISIBILITY ScopeGuard {
+ public:
+  // Copying is disabled: the copy constructor and copy assignment are deleted.
+  ScopeGuard(ScopeGuard const &) = delete;
+  ScopeGuard &operator=(ScopeGuard const &) & = delete;
+
+  explicit ScopeGuard(const std::function<void()> &on_exit_scope) : on_exit_scope_(on_exit_scope), dismissed_(false) {}
+
+  ~ScopeGuard() {
+    // Nothing to do once dismissed, or when no callable target is stored.
+    if ((!dismissed_) && (on_exit_scope_ != nullptr)) {
+      try {
+        on_exit_scope_();
+      } catch (...) {
+        // Whatever the callback throws is swallowed so that nothing escapes the destructor.
+      }
+    }
+  }
+
+  void Dismiss() { dismissed_ = true; }
+
+ private:
+  std::function<void()> on_exit_scope_;
+  bool dismissed_;
+};
+}  // namespace ge
+
+#endif  // INC_COMMON_SCOPE_GUARD_H_
diff --git a/inc/metadef/common/ge_common/string_util.h b/inc/metadef/common/ge_common/string_util.h
new file mode 100644
index 0000000000000000000000000000000000000000..2448030cdeb3e12cb2dd403ee7a8581d180452fb
--- /dev/null
+++ b/inc/metadef/common/ge_common/string_util.h
@@ -0,0 +1,198 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_COMMON_STRING_UTIL_H_
+#define INC_COMMON_STRING_UTIL_H_
+
+#include <cctype>
+#include <securec.h>
+
+#include <algorithm>
+#include <functional>
+#include <sstream>
+#include <string>
+#include <vector>
+#include "graph/types.h"
+#include "external/ge_common/ge_error_codes.h"
+
+namespace ge {
+class GE_FUNC_VISIBILITY StringUtils {
+ public:
+  static std::string &Ltrim(std::string &s) {  // strips leading whitespace in place and returns s
+#if __cplusplus >= 201103L
+    (void)s.erase(s.begin(), std::find_if(s.begin(), s.end(),
+        [](const char_t c) { return std::isspace(static_cast<unsigned char>(c)) == 0; }));
+#else  // pre-C++11 fallback, left as it was
+    (void)s.erase(s.begin(), std::find_if(s.begin(), s.end(),
+        std::not1(std::ptr_fun<int32_t, int32_t>(std::isspace))));
+#endif
+    return s;
+  }
+  // lint -esym(551,*)
+  static std::string &Rtrim(std::string &s) { /*lint !e618*/  // strips trailing whitespace in place and returns s
+#if __cplusplus >= 201103L
+    (void)s.erase(std::find_if(s.rbegin(), s.rend(),
+        [](const char_t c) { return std::isspace(static_cast<unsigned char>(c)) == 0; }).base(), s.end());
+#else  // pre-C++11 fallback, left as it was
+    (void)s.erase(std::find_if(s.rbegin(), s.rend(),
+        std::not1(std::ptr_fun<int32_t, int32_t>(std::isspace))).base(), s.end());
+#endif
+    return s;
+  }
+  // lint -esym(551,*)
+  ///
+  /// @ingroup domi_common
+  /// @brief delete spaces at the beginning and end of a string (the string is modified in place)
+  /// @param [in] string to be trimmed
+  /// @return reference to the same string after trim
+  ///
+  static std::string &Trim(std::string &s) { return Ltrim(Rtrim(s)); }
+
+  ///
+  /// @ingroup domi_common
+  /// @brief string splitting
+  /// @param [in] str string to be split
+  /// @param [in] delim separator
+  /// @return segments in order; an empty str yields one empty segment, a trailing delim adds a final empty segment
+  ///
+  static std::vector<std::string, std::allocator<std::string>> Split(const std::string &str, const char_t delim) {
+    std::vector<std::string, std::allocator<std::string>> elems;
+
+    if (str.empty()) {
+      (void) elems.emplace_back("");
+      return elems;
+    }
+
+    std::stringstream ss(str);
+    std::string item;
+
+    while (getline(ss, item, delim)) {
+      (void) elems.push_back(item);
+    }
+    // getline drops the empty segment after a trailing delimiter, so it is appended explicitly.
+    const auto str_size = str.size();
+    if ((str_size > 0U) && (str[str_size - 1U] == delim)) {
+      (void) elems.emplace_back("");
+    }
+
+    return elems;
+  }
+
+  template<typename Iterator>
+  static std::string Join(Iterator begin, Iterator end, const std::string &separator) {
+    if (begin == end) {
+      return "";
+    }
+    std::stringstream str_stream;
+    str_stream << *begin;
+    for (Iterator it = std::next(begin); it != end; ++it) {
+      str_stream << separator << *it;
+    }
+    return str_stream.str();
+  }
+
+  ///
+  /// @ingroup domi_common
+  /// @brief obtain the file name (the text after the last '/')
+  /// @param [in] s path name
+  /// @return file name; empty when s is empty or ends with '/'
+  ///
+  static std::string GetFileName(const std::string &s) {
+    if (s.empty()) {
+      return "";
+    }
+    const std::vector<std::string> files = StringUtils::Split(s, '/');
+
+    return files.empty() ? "" : files[files.size() - 1U];
+  }
+  ///
+  /// @ingroup domi_common
+  /// @brief full replacement of every occurrence of old_value, scanning left to right
+  /// @link
+  /// @param [in] str string to be processed (taken by value, the caller's copy is untouched)
+  /// @param [in] old_value text to search for
+  /// @param [in] new_value text inserted in place of each match
+  /// @return string after replacement; str is returned unchanged when old_value is empty
+  ///
+  static std::string ReplaceAll(std::string str, const std::string &old_value, const std::string &new_value) {
+    const std::string::size_type old_length = old_value.length();
+    const std::string::size_type new_length = new_value.length();
+    // An empty pattern matches at every offset, so the scan below would never finish: nothing to replace.
+    if (old_length == 0U) {
+      return str;
+    }
+    // Each search resumes right after the text just inserted, so replacements are never re-scanned.
+    std::string::size_type cur_pos = str.find(old_value, 0U);
+    while (cur_pos != std::string::npos) {
+      (void)str.replace(cur_pos, old_length, new_value);
+      cur_pos = str.find(old_value, cur_pos + new_length);
+    }
+    return str;
+  }
+
+  ///
+  /// @ingroup domi_common
+  /// @brief checks whether a character string starts with a character string (prefix)
+  /// @link
+  /// @param [in] str string to be compared
+  /// @param [in] str_x prefix
+  /// @return if the value is a prefix, true is returned. Otherwise, false is returned
+  ///
+  static bool StartWith(const std::string &str, const std::string &str_x) {
+    return ((str.size() >= str_x.size()) && (str.compare(0U, str_x.size(), str_x) == 0));
+  }
+
+  ///
+  /// @ingroup domi_common
+  /// @brief format string
+  /// @link
+  /// @param [in] format specifies the character string format
+  /// @param [in] ... format Filling Content
+  /// @return formatted string; empty when vsnprintf_s does not return a positive value
+  ///
+  static std::string FormatString(const char_t * const format, ...) {
+    const uint32_t MAX_BUFFER_LEN = 1024U;  // the stack memory plint check result must be less than 1024
+    va_list args;
+    va_start(args, format);
+    char_t buffer[MAX_BUFFER_LEN] = {};
+    const int32_t ret = vsnprintf_s(&buffer[0], MAX_BUFFER_LEN, MAX_BUFFER_LEN - 1U, format, args);
+    va_end(args);
+    return (ret > 0) ? buffer : "";
+  }
+
+  ///
+  /// @ingroup domi_common
+  /// @brief check whether the input str is signed int32: an optional '+' or '-' followed by decimal digits only
+  /// @link
+  /// @param [in] str input string
+  /// @return true/false (false for empty input, surrounding whitespace, and values std::stoi cannot convert)
+  ///
+  static bool IsSignedInt32(const std::string &str) {
+    // std::stoi skips leading whitespace, so the first character must itself be a sign or a digit.
+    if (str.empty() || ((str[0U] != '+') && (str[0U] != '-') && (isdigit(static_cast<unsigned char>(str[0U])) == 0))) {
+      return false;
+    }
+    try {
+      (void)std::stoi(str);
+    } catch (...) {
+      return false;
+    }
+    // std::stoi stops at the first non-digit, so trailing garbage such as "12abc" is rejected here.
+    for (size_t i = 1U; i < str.size(); i++) {
+      if (isdigit(static_cast<unsigned char>(str.at(i))) == 0) {
+        return false;
+      }
+    }
+
+    return true;
+  }
+};
+}  // namespace ge
+
+#endif  // INC_COMMON_STRING_UTIL_H_
diff --git a/inc/metadef/common/ge_common/util.h b/inc/metadef/common/ge_common/util.h
new file mode 100644
index 0000000000000000000000000000000000000000..818ef345258b69c2a8e2fd24a0f4f42a2aad02f4
--- /dev/null
+++ b/inc/metadef/common/ge_common/util.h
@@ -0,0 +1,318 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef AIR_INC_COMMON_UTIL_H_ +#define AIR_INC_COMMON_UTIL_H_ + +#include +#include +#include + +#include "external/graph/types.h" +#include "external/register/register.h" +#include "common/ge_common/debug/log.h" +#include "common/ge_common/scope_guard.h" +#include "common/ge_common/ge_inner_error_codes.h" + +#define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ + do { \ + if ((size) <= 0) { \ + GELOGE(ge::FAILED, "param[%s] is not a positive number", #size); \ + return PARAM_INVALID; \ + } \ + } while (false) + +#define CHECK_FALSE_EXEC(expr, exec_expr, ...) \ + { \ + const bool b = (expr); \ + if (!b) { \ + exec_expr; \ + } \ + } + +// new ge marco +// Encapsulate common resource releases +#define GE_MAKE_GUARD_RTMEM(var) \ + GE_MAKE_GUARD(var, [&var]() { \ + if ((var) != nullptr) { \ + GE_CHK_RT(rtFreeHost(var)); \ + } \ + }) + +#define GE_MAKE_GUARD_RTSTREAM(var) \ + GE_MAKE_GUARD(var, [&]() { \ + if ((var) != nullptr) { \ + GE_CHK_RT(rtStreamDestroy(var)); \ + } \ + }) + +// For propagating errors when calling a function. +#define GE_RETURN_IF_ERROR(expr) \ + do { \ + const ge::Status _chk_status = (expr); \ + if (_chk_status != ge::SUCCESS) { \ + return _chk_status; \ + } \ + } while (false) + +#define GE_RETURN_WITH_LOG_IF_ERROR(expr, ...) \ + do { \ + const ge::Status _chk_status = (expr); \ + if (_chk_status != ge::SUCCESS) { \ + GELOGE(ge::FAILED, __VA_ARGS__); \ + return _chk_status; \ + } \ + } while (false) + +// check whether the parameter is true. 
If it is, return FAILED and record the error log +#define GE_RETURN_WITH_LOG_IF_TRUE(condition, ...) \ + do { \ + if (condition) { \ + GELOGE(ge::FAILED, __VA_ARGS__); \ + return ge::FAILED; \ + } \ + } while (false) + +// Check if the parameter is false. If yes, return FAILED and record the error log +#define GE_RETURN_WITH_LOG_IF_FALSE(condition, ...) \ + do { \ + const bool _condition = (condition); \ + if (!_condition) { \ + GELOGE(ge::FAILED, __VA_ARGS__); \ + return ge::FAILED; \ + } \ + } while (false) + +// Checks whether the parameter is true. If so, returns PARAM_INVALID and records the error log +#define GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(condition, ...) \ + do { \ + if (condition) { \ + GELOGE(ge::FAILED, __VA_ARGS__); \ + return ge::PARAM_INVALID; \ + } \ + } while (false) + +// Check if the parameter is false. If yes, return PARAM_INVALID and record the error log +#define GE_RT_PARAM_INVALID_WITH_LOG_IF_FALSE(condition, ...) \ + do { \ + const bool _condition = (condition); \ + if (!_condition) { \ + GELOGE(ge::FAILED, __VA_ARGS__); \ + return ge::PARAM_INVALID; \ + } \ + } while (false) + +// Check if the parameter is null. If yes, return PARAM_INVALID and record the error +#define GE_CHECK_NOTNULL(val, ...) \ + do { \ + if ((val) == nullptr) { \ + REPORT_INNER_ERROR("E19999", "Param:" #val " is nullptr, check invalid" __VA_ARGS__); \ + GELOGE(ge::FAILED, "[Check][Param:" #val "]null is invalid" __VA_ARGS__); \ + return ge::PARAM_INVALID; \ + } \ + } while (false) + +// Check if the parameter is null. If yes, just return and record the error +#define GE_CHECK_NOTNULL_JUST_RETURN(val) \ + do { \ + if ((val) == nullptr) { \ + GELOGE(ge::FAILED, "param[%s] must not be null.", #val); \ + return; \ + } \ + } while (false) + +// Check whether the parameter is null. 
If so, execute the exec_expr expression and record the error log +#define GE_CHECK_NOTNULL_EXEC(val, exec_expr) \ + do { \ + if ((val) == nullptr) { \ + GELOGE(ge::FAILED, "param[%s] must not be null.", #val); \ + exec_expr; \ + } \ + } while (false) + +// Check whether the parameter is null. If yes, return directly and record the error log +#define GE_RT_VOID_CHECK_NOTNULL(val) \ + do { \ + if ((val) == nullptr) { \ + GELOGE(ge::FAILED, "param[%s] must not be null.", #val); \ + return; \ + } \ + } while (false) + +// Check if the parameter is null. If yes, return false and record the error log +#define GE_RT_FALSE_CHECK_NOTNULL(val) \ + do { \ + if ((val) == nullptr) { \ + GELOGE(ge::FAILED, "param[%s] must not be null.", #val); \ + return false; \ + } \ + } while (false) + +// Check if the parameter is out of bounds +#define GE_CHECK_SIZE(size) \ + do { \ + if ((size) == 0U) { \ + GELOGE(ge::FAILED, "param[%s] is out of range", #size); \ + return ge::PARAM_INVALID; \ + } \ + } while (false) + +// Check if the value on the left is greater than or equal to the value on the right +#define GE_CHECK_GE(lhs, rhs) \ + do { \ + if ((lhs) < (rhs)) { \ + GELOGE(ge::FAILED, "param[%s][%ld] is less than[%s][%ld]", \ + #lhs, static_cast(lhs), #rhs, static_cast(rhs)); \ + return ge::PARAM_INVALID; \ + } \ + } while (false) + +// Check if the value on the left is less than or equal to the value on the right +#define GE_CHECK_LE(lhs, rhs) \ + do { \ + if ((lhs) > (rhs)) { \ + GELOGE(ge::FAILED, "param[%s][%ld] is greater than[%s][%ld]", \ + #lhs, static_cast(lhs), #rhs, static_cast(rhs)); \ + return ge::PARAM_INVALID; \ + } \ + } while (false) + +#define GE_DELETE_NEW_SINGLE(var) \ + do { \ + if ((var) != nullptr) { \ + delete (var); \ + (var) = nullptr; \ + } \ + } while (false) + +#define GE_DELETE_NEW_ARRAY(var) \ + do { \ + if ((var) != nullptr) { \ + delete[] (var); \ + (var) = nullptr; \ + } \ + } while (false) + +#define GE_FREE_RT_LOG(addr) \ + do { \ + if ((addr) != 
nullptr) { \ + const rtError_t error = rtFree(addr); \ + if (error != RT_ERROR_NONE) { \ + GELOGE(ge::RT_FAILED, "Call rtFree failed, error: %#x", error); \ + } \ + (addr) = nullptr; \ + } \ + } while (false) + +namespace ge { +/** + * @ingroup domi_common + * @brief version of om.proto file + */ +constexpr int32_t OM_PROTO_VERSION = 2; + +/// @ingroup domi_common +/// @brief onverts Vector of a number to a string. +/// @param [in] v Vector of a number +/// @return string +template +GE_FUNC_VISIBILITY std::string ToString(const std::vector &v) { + bool first = true; + std::stringstream ss; + ss << "["; + for (const T &x : v) { + if (first) { + first = false; + ss << x; + } else { + ss << ", " << x; + } + } + ss << "]"; + return ss.str(); +} + +/// @ingroup: domi_common +/// @brief: get length of file +/// @param [in] input_file: path of file +/// @return int64_t: File length. If the file length fails to be obtained, the value -1 is returned. +GE_FUNC_VISIBILITY extern int64_t GetFileLength(const std::string &input_file); + +/// @ingroup domi_common +/// @brief Reads all data from a binary file. +/// @param [in] file_name path of file +/// @param [out] buffer Output memory address, which needs to be released by the caller. +/// @param [out] length Output memory size +/// @return false fail +/// @return true success +GE_FUNC_VISIBILITY bool ReadBytesFromBinaryFile(const char_t *const file_name, char_t **const buffer, int32_t &length); + +///@ingroup domi_common +/// @brief Get binary file from file +/// @param [in] path file path. +/// @param [out] buffer char[] used to store file data +/// @param [out] data_len store read size +/// @return graphStatus GRAPH_SUCCESS: success, OTHERS: fail +GE_FUNC_VISIBILITY graphStatus GetBinFromFile(const std::string &path, char_t *buffer, size_t &data_len); + +/// @ingroup domi_common +/// @brief Recursively Creating a Directory +/// @param [in] directory_path Path, which can be a multi-level directory. 
+/// @return 0 success +/// @return -1 fail +GE_FUNC_VISIBILITY extern int32_t CreateDirectory(const std::string &directory_path); + +/// @ingroup domi_common +/// @brief Obtains the current time string. +/// @return Time character string in the format : %Y%m%d%H%M%S, eg: 20171011083555 +GE_FUNC_VISIBILITY std::string CurrentTimeInStr(); + +/// @ingroup domi_common +/// @brief Obtains the absolute time (timestamp) of the current system. +/// @return Timestamp, in microseconds (US) +GE_FUNC_VISIBILITY uint64_t GetCurrentTimestamp(); + +/// @ingroup domi_common +/// @brief Obtains the absolute time (timestamp) of the current system. +/// @return Timestamp, in seconds (US) +GE_FUNC_VISIBILITY uint32_t GetCurrentSecondTimestap(); + +/// @ingroup domi_common +/// @brief Absolute path for obtaining files. +/// @param [in] path of input file +/// @param [out] Absolute path of a file. If the absolute path cannot be obtained, an empty string is returned +GE_FUNC_VISIBILITY std::string RealPath(const char_t *path); + +/// @ingroup domi_common +/// @brief Check whether the specified input file path is valid. +/// 1. The specified path cannot be empty. +/// 2. The path can be converted to an absolute path. +/// 3. The file path exists and is readable. +/// @param [in] file_path path of input file +/// @param [out] result +GE_FUNC_VISIBILITY bool CheckInputPathValid(const std::string &file_path, const std::string &atc_param = ""); + +/// @ingroup domi_common +/// @brief Checks whether the specified output file path is valid. +/// @param [in] file_path path of output file +/// @param [out] result +GE_FUNC_VISIBILITY bool CheckOutputPathValid(const std::string &file_path, const std::string &atc_param = ""); + +/// @ingroup domi_common +/// @brief Check whether the file path meets the whitelist verification requirements. 
+/// @param [in] str file path +/// @param [out] result +GE_FUNC_VISIBILITY bool ValidateStr(const std::string &file_path, const std::string &mode); + +GE_FUNC_VISIBILITY Status ConvertToInt32(const std::string &str, int32_t &val); + +GE_FUNC_VISIBILITY std::string GetErrorNumStr(const int32_t errorNum); +} // namespace ge + +#endif // AIR_INC_COMMON_UTIL_H_ diff --git a/inc/metadef/common/hyper_status.h b/inc/metadef/common/hyper_status.h new file mode 100644 index 0000000000000000000000000000000000000000..627f00838bb8bbed68dee69c69aa6196dcffdef9 --- /dev/null +++ b/inc/metadef/common/hyper_status.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef AIR_CXX_BASE_COMMON_HYPER_STATUS_H_ +#define AIR_CXX_BASE_COMMON_HYPER_STATUS_H_ +#include +#include +#include "graph/types.h" +namespace gert { +ge::char_t *CreateMessage(const ge::char_t *format, va_list arg); +class HyperStatus { + public: + bool IsSuccess() const { + return status_ == nullptr; + } + const ge::char_t *GetErrorMessage() const noexcept { + return status_; + } + ~HyperStatus() { + delete[] status_; + } + + HyperStatus() {} + HyperStatus(const HyperStatus &other); + HyperStatus(HyperStatus &&other) noexcept; + HyperStatus &operator=(const HyperStatus &other); + HyperStatus &operator=(HyperStatus &&other) noexcept; + + static HyperStatus Success(); + static HyperStatus ErrorStatus(const ge::char_t *message, ...); + static HyperStatus ErrorStatus(std::unique_ptr message); + + private: + ge::char_t *status_ = nullptr; +}; +} // namespace gert + +namespace ge { +using HyperStatus = gert::HyperStatus; +} // namespace ge +#endif // AIR_CXX_BASE_COMMON_HYPER_STATUS_H_ diff --git a/inc/metadef/common/large_bm.h b/inc/metadef/common/large_bm.h new file mode 100644 index 0000000000000000000000000000000000000000..907c63d4e9a18dc4bde4dfa86f7228e52e2cb255 --- /dev/null +++ b/inc/metadef/common/large_bm.h @@ -0,0 +1,57 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_COMMON_LARGE_BM_H_ +#define INC_COMMON_LARGE_BM_H_ + +#include +#include + +/* LargeBitmap create a way to generate bitmaps larger than 64bit. */ +namespace ge { +class LargeBitmap { +public: + explicit LargeBitmap(const size_t &size); + + ~LargeBitmap() = default; + + bool operator==(const LargeBitmap &another_bm) const; + + bool operator!=(const LargeBitmap &another_bm) const; + + // set all vector to specific value + void SetValues(const uint64_t &value); + + // Get the value on position index + bool GetBit(const size_t &index) const; + + // Set the value on position index to 1 + void SetBit(const size_t &index); + + // Combine two bitmap with the following rule. + // If one bit of either one of the two bitmaps is 1, + // the result of final bitmap is 1. + void Or(const LargeBitmap &another_bm); + + // Combine two bitmap with the following rule. + // If one bit of either one of the two bitmaps is 0, + // the result of final bitmap is 0. + void And(const LargeBitmap &another_bm); + + void ClearBit(size_t bit_idx); + + void ResizeBits(size_t new_size); +private: + // Number of element in vector bits + size_t size_; + + std::vector bits_; +}; +} +#endif // INC_COMMON_LARGE_BM_H_ diff --git a/inc/metadef/common/npu_error_define.h b/inc/metadef/common/npu_error_define.h new file mode 100644 index 0000000000000000000000000000000000000000..f88c3ae9dc98ccc8057920932c483f5c8b6c4d61 --- /dev/null +++ b/inc/metadef/common/npu_error_define.h @@ -0,0 +1,87 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_COMMON_NPU_ERROR_DEFINE_H_
+#define INC_COMMON_NPU_ERROR_DEFINE_H_
+
+typedef enum tagHiAiNpuLocal {  // encoded into bits 31-30 by HIAI_NPU_LOC_BIT
+  HIAI_HOST = 1,
+  HIAI_DEVICE = 2,
+} HiAiNpuLocal;
+
+typedef enum tagHiAiNpuCodeType {  // encoded into bits 29-28 by HIAI_NPU_CODE_TYPE_BIT
+  ERROR_CODE = 1,
+  EXCEPTION_CODE = 2,
+} HiAiNpuCodeType;
+
+typedef enum tagHiAiNpuErrLevel {  // encoded into bits 27-25 by HIAI_NPU_ERR_LEV_BIT
+  NONE_LEVEL = 0,
+  SUGGESTION_LEVEL = 1,
+  NORMAL_LEVEL = 2,
+  SERIOUS_LEVEL = 3,
+  CRITICAL_ERROR = 4,
+} HiAiNpuErrLevel;
+
+typedef enum tagHiAiNpuModuleId {  // encoded into bits 24-17 by HIAI_NPU_MOD_ID_BIT
+  HIAI_DRIVER = 1,
+  HIAI_CTRLCPU = 2,
+  HIAI_TS = 3,
+  HIAI_RUNTIME = 4,
+  HIAI_AICPU = 5,
+  HIAI_CCE = 6,
+  HIAI_TVM = 7,
+  HIAI_FRAMEWORK = 8,
+  HiAI_ENGINE = 9,
+  HIAI_DVPP = 10,
+  HIAI_AIPP = 11,
+  HIAI_LOWPOWER = 12,
+  HIAI_MDC = 13,
+  HIAI_COMPILE = 14,
+  HIAI_TOOLCHIAN = 15,
+  HIAI_ALG = 16,
+  HIAI_PROFILING = 17,
+  HIAI_HCCL = 18,
+  HIAI_SIMULATION = 19,
+  HIAI_BIOS = 20,
+  HIAI_SEC = 21,
+  HIAI_TINY = 22,
+  HIAI_DP = 23,
+} HiAiNpuModuleId;
+
+/* bits 31-30 (2 bits): host/device selector, a HiAiNpuLocal value */
+#define HIAI_NPULOCAL_MASK 0xC0000000
+#define SHIFT_LOCAL_MASK 30
+#define HIAI_NPULOCAL_VAL_MASK 0x3
+/* bits 29-28 (2 bits): code type, a HiAiNpuCodeType value */
+#define HIAI_CODE_TYPE_MASK 0x30000000
+#define SHIFT_CODE_MASK 28
+#define HIAI_CODE_TYPE_VAL_MASK 0x3
+/* bits 27-25 (3 bits): error level, a HiAiNpuErrLevel value */
+#define HIAI_ERROR_LEVEL_MASK 0x0E000000
+#define SHIFT_ERROR_LVL_MASK 25
+#define HIAI_ERROR_LEVEL_VAL_MASK 0x7
+/* bits 24-17 (8 bits): module id, a HiAiNpuModuleId value */
+#define HIAI_MODE_ID_MASK 0x01FE0000
+#define SHIFT_MODE_MASK 17
+#define HIAI_MODE_ID_VAL_MASK 0xFF
+/* Each *_BIT macro truncates its argument to the field width, then shifts it into the field position. */
+#define HIAI_NPU_LOC_BIT(a) \
+  (HIAI_NPULOCAL_MASK & ((unsigned int)((HiAiNpuLocal)(a)) & HIAI_NPULOCAL_VAL_MASK) << SHIFT_LOCAL_MASK)
+#define HIAI_NPU_CODE_TYPE_BIT(a) \
+  (HIAI_CODE_TYPE_MASK & ((unsigned int)((HiAiNpuCodeType)(a)) & HIAI_CODE_TYPE_VAL_MASK) << SHIFT_CODE_MASK)
+#define HIAI_NPU_ERR_LEV_BIT(a) \
+  (HIAI_ERROR_LEVEL_MASK & ((unsigned int)((HiAiNpuErrLevel)(a)) & HIAI_ERROR_LEVEL_VAL_MASK) << SHIFT_ERROR_LVL_MASK)
+#define HIAI_NPU_MOD_ID_BIT(a) \
+  (HIAI_MODE_ID_MASK & ((unsigned int)((HiAiNpuModuleId)(a)) & HIAI_MODE_ID_VAL_MASK) << SHIFT_MODE_MASK)
+/* Sum of the four non-overlapping fields above: the header occupying bits 31-17 of a code. */
+#define HIAI_NPU_ERR_CODE_HEAD(npuLocal, codeType, errLevel, moduleId) \
+  (HIAI_NPU_LOC_BIT(npuLocal) + HIAI_NPU_CODE_TYPE_BIT(codeType) + HIAI_NPU_ERR_LEV_BIT(errLevel) + \
+   HIAI_NPU_MOD_ID_BIT(moduleId))
+
+#endif  // INC_COMMON_NPU_ERROR_DEFINE_H_
diff --git a/inc/metadef/common/opskernel/ge_task_info.h b/inc/metadef/common/opskernel/ge_task_info.h
new file mode 100644
index 0000000000000000000000000000000000000000..eec374718defa7921619ae81b84057d9536c77ea
--- /dev/null
+++ b/inc/metadef/common/opskernel/ge_task_info.h
@@ -0,0 +1,86 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_ +#define INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_ + +#include +#include +#include +#include "runtime/rt.h" +#include "graph/op_desc.h" + +namespace ge { +struct HcclDumpInfo { + uint32_t task_id; + uint32_t stream_id; + uint32_t sub_task_type; + void *input_addr; + uint64_t input_size; + void *output_addr; + uint64_t output_size; +}; + +struct DvppInfo { + OpDescPtr op_desc; + std::vector io_addrs; + uint32_t sqe[16]; +}; + +// when need to eliminate GETaskKernelHcclInfo, so not need DAVINCI_TRAIN/DAVINCI_CLOUD +struct GETaskKernelHcclInfo { + std::string input_name; + std::string hccl_type; + void *inputDataAddr; + void *outputDataAddr; + void *workSpaceAddr; + int64_t count; + int32_t dataType; + int32_t opType; + int64_t rootId; + uint64_t workSpaceMemSize; + std::vector dims; + std::vector hcclStreamList; + std::vector hccl_dump_info; + std::vector global_workspace_addr; + uint32_t hcclQosCfg; + std::vector inputDataAddrs; + std::vector outputDataAddrs; + std::vector workSpaceAddrs; + std::vector workSpaceMemSizes; + std::vector inputZeroCopyFlags; + std::vector outputZeroCopyFlags; +}; + +struct GETaskInfo { + uint32_t id; + uint16_t type; + uint32_t streamID; + void *stream; // rtKernelLaunch input argument + void *event; + void *privateDef; + uint32_t privateDefLen; + void *opsKernelStorePtr; + std::vector kernelHcclInfo; + DvppInfo dvpp_info; + bool needRefresh{false}; + std::vector rt_attached_streams; +}; + +struct HcomRemoteAccessAddrInfo +{ + uint32_t remotetRankID; + uint64_t remoteAddr; // host embedding table address + uint64_t localAddr; // device HBM address + uint64_t length; // memory Length in Bytes +}; + + +} // namespace ge +#endif // INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_ diff --git a/inc/metadef/common/opskernel/ops_kernel_builder.h 
b/inc/metadef/common/opskernel/ops_kernel_builder.h new file mode 100644 index 0000000000000000000000000000000000000000..4213f276ffbf4a6eb9de2d5c39d995ad717924d9 --- /dev/null +++ b/inc/metadef/common/opskernel/ops_kernel_builder.h @@ -0,0 +1,81 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_COMMON_OPSKERNEL_OPS_KERNEL_BUILDER_H_ +#define INC_COMMON_OPSKERNEL_OPS_KERNEL_BUILDER_H_ + +#include "external/ge_common/ge_api_error_codes.h" +#include "cce/aicpu_engine_struct.h" +#include "common/opskernel/ops_kernel_info_types.h" +#include "graph/node.h" +#include "external/ge_common/ge_api_types.h" +#include "proto/task.pb.h" + +namespace ge { +class OpsKernelBuilder { + public: + enum class Mode : uint32_t { + kNormal, + kFfts, + kFftsPlus + }; + OpsKernelBuilder() = default; + virtual ~OpsKernelBuilder() = default; + OpsKernelBuilder(const OpsKernelBuilder &) = delete; + OpsKernelBuilder(OpsKernelBuilder &&) = delete; + OpsKernelBuilder &operator=(const OpsKernelBuilder &)& = delete; + OpsKernelBuilder &operator=(OpsKernelBuilder &&)& = delete; + + // initialize OpsKernelBuilder + virtual Status Initialize(const std::map &options) = 0; + + // finalize OpsKernelBuilder + virtual Status Finalize() = 0; + + // memory allocation requirement + virtual Status CalcOpRunningParam(Node &node) = 
0; + + // generate task for op + virtual Status GenerateTask(const Node &node, RunContext &context, + std::vector &tasks) = 0; + + // generate task for op with different mode + virtual Status GenerateTask(const Node &node, RunContext &context, std::vector &tasks, + OpsKernelBuilder::Mode) { + (void)node; + (void)context; + (void)tasks; + return SUCCESS; + } + + // update task which need stream event info, after SplitStream. Only change field in task, forbid change tasks size + virtual Status UpdateTask(const Node &node, std::vector &tasks) { + (void)node; + (void)tasks; + return SUCCESS; + } + + // only call aicpu interface to generate task struct + virtual Status GenSingleOpRunTask(const NodePtr &node, STR_FWK_OP_KERNEL &task, std::string &task_info) { + (void)node; + (void)task; + (void)task_info; + return FAILED; + } + + // only call aicpu interface to generate task struct + virtual Status GenMemCopyTask(const uint64_t count, STR_FWK_OP_KERNEL &task, std::string &task_info) { + (void)count; + (void)task; + (void)task_info; + return FAILED; + } +}; +} // namespace ge +#endif // INC_COMMON_OPSKERNEL_OPS_KERNEL_BUILDER_H_ diff --git a/inc/metadef/common/opskernel/ops_kernel_info_store.h b/inc/metadef/common/opskernel/ops_kernel_info_store.h new file mode 100644 index 0000000000000000000000000000000000000000..f9171baee8f4f7d397e416afbb37a9490b183148 --- /dev/null +++ b/inc/metadef/common/opskernel/ops_kernel_info_store.h @@ -0,0 +1,137 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_STORE_H_ +#define INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_STORE_H_ + +#include +#include +#include +#include +#include + +#include "common/opskernel/ge_task_info.h" +#include "common/opskernel/ops_kernel_info_types.h" +#include "common/ge_common/ge_inner_error_codes.h" +#include "graph/node.h" +#include "external/graph/operator.h" + +namespace ge { +class OpsKernelInfoStore { + public: + OpsKernelInfoStore() = default; + + virtual ~OpsKernelInfoStore() = default; + OpsKernelInfoStore(const OpsKernelInfoStore &) = delete; + OpsKernelInfoStore(OpsKernelInfoStore &&) = delete; + OpsKernelInfoStore &operator=(const OpsKernelInfoStore &)& = delete; + OpsKernelInfoStore &operator=(OpsKernelInfoStore &&)& = delete; + + // initialize opsKernelInfoStore + virtual Status Initialize(const std::map &options) = 0; + + // close opsKernelInfoStore + virtual Status Finalize() = 0; /*lint -e148*/ + + virtual Status CreateSession(const std::map &session_options) { + (void)session_options; + return SUCCESS; + } + + virtual Status DestroySession(const std::map &session_options) { + (void)session_options; + return SUCCESS; + } + + // get all opsKernelInfo + virtual void GetAllOpsKernelInfo(std::map &infos) const = 0; + + // whether the opsKernelInfoStore is supported based on the operator attribute + virtual bool CheckSupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason) const = 0; + + virtual bool CheckAccuracySupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason, + const bool realQuery = false) const { + (void)realQuery; + return CheckSupported(opDescPtr, un_supported_reason); + } + // opsFlag opsFlag[0] indicates constant folding is supported or not + virtual void opsFlagCheck(const ge::Node 
&node, std::string &opsFlag) { + (void)node; + (void)opsFlag; + }; + + // only call fe engine interface to compile single op + virtual Status CompileOp(std::vector &node_vec) { + (void) node_vec; + return SUCCESS; + } + virtual Status CompileOpRun(std::vector &node_vec) { + (void)node_vec; + return SUCCESS; + } + + // prepare task for op + virtual Status PrepareTaskAsync(GETaskInfo &task) { + (void)task; + return SUCCESS; + } + + // load task for op + virtual Status LoadTask(GETaskInfo &task) { + (void)task; + return SUCCESS; + } + + virtual bool CheckSupported(const ge::NodePtr &node, std::string &un_supported_reason) const { + if (node == nullptr) { + return false; + } + return CheckSupported(node->GetOpDesc(), un_supported_reason); + } + + virtual bool CheckAccuracySupported(const ge::NodePtr &node, std::string &un_supported_reason, + const bool realQuery = false) const { + (void)realQuery; + if (node == nullptr) { + return false; + } + return CheckAccuracySupported(node->GetOpDesc(), un_supported_reason, realQuery); + } + // Set cut support info + virtual Status SetCutSupportedInfo(const ge::NodePtr &node) { + (void)node; + return SUCCESS; + } + // unload task for op + virtual Status UnloadTask(GETaskInfo &task) { + (void)task; + return SUCCESS; + } + + // fuzz compile interface + virtual Status FuzzCompileOp(std::vector &node_vec) { + (void) node_vec; + return SUCCESS; + } + + // Query information such as format/dtype/impl supported by operators (extensible) + virtual bool GetNodeSupportInfo(const OperatorPtr &op, std::string &support_info) { + (void)op; + (void)support_info; + return false; + } + + virtual bool CheckSupported(const ge::NodePtr &node, std::string &un_supported_reason, + CheckSupportFlag &flag) const { + (void)flag; + return CheckSupported(node, un_supported_reason); + } +}; +} // namespace ge +#endif // INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_STORE_H_ diff --git a/inc/metadef/common/opskernel/ops_kernel_info_types.h
b/inc/metadef/common/opskernel/ops_kernel_info_types.h new file mode 100644 index 0000000000000000000000000000000000000000..c6eba5a0aca57a365da344d162a4981c49b9365d --- /dev/null +++ b/inc/metadef/common/opskernel/ops_kernel_info_types.h @@ -0,0 +1,64 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_TYPES_H_ +#define INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_TYPES_H_ + +#include +#include +#include +#include "graph/buffer.h" +#include "runtime/rt_model.h" + +namespace ge { +/*lint -e148*/ +struct RunContext { + rtModel_t model; + rtStream_t stream; + uint64_t sessionId; + uint64_t dataMemSize; + uint8_t *dataMemBase; + std::map mem_type_data_mem_size; + std::map mem_type_data_mem_base; + uint64_t weightMemSize; + uint8_t *weightMemBase; + ge::Buffer weightsBuffer; + std::vector graphStreamList; // all streams of graph, order by ge stream id(0,1,...) + std::vector graphEventList; // all events of graph, order by ge event id(0,1,...) + std::vector graphLabelList; // all labels of graph, order by ge label id(0,1,...) + std::vector graphNotifyList; // all notifies of graph, order by ge notify id(0,1,...)
+}; + +/*lint +e148*/ +struct Task { + uint32_t id; + uint16_t type; + void *stream; + void *event; +}; + +struct OpInfo { + std::string engine; // which engine + /*lint -e148*/ + std::string opKernelLib; // which opsKernelStore + int32_t computeCost; // compute cost + bool flagPartial; // whether support is related to shape + bool flagAsync; // Whether to support asynchronous + bool isAtomic; // whether to support atomic addr clean + std::string opFileName; // op file name + std::string opFuncName; // op function name +}; + +enum class CheckSupportFlag : uint32_t { + kDefault = 0, + kNotSupportDynamicShape +}; +} // namespace ge + +#endif // INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_TYPES_H_ diff --git a/inc/metadef/common/optimizer/graph_optimizer.h b/inc/metadef/common/optimizer/graph_optimizer.h new file mode 100644 index 0000000000000000000000000000000000000000..9ad555ac33e146fb54a897457cdf95ec7af403db --- /dev/null +++ b/inc/metadef/common/optimizer/graph_optimizer.h @@ -0,0 +1,125 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_ +#define INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_ + +#include +#include +#include "graph_optimizer_types.h" +#include "optimize_utility.h" +#include "common/ge_common/ge_inner_error_codes.h" +#include "common/opskernel/ops_kernel_info_types.h" +#include "graph/compute_graph.h" +#include "graph/op_kernel_bin.h" + +/*lint -e148*/ +namespace ge { +class GraphOptimizer { + public: + virtual ~GraphOptimizer() {} + + // initialize graphOptimizer + virtual Status Initialize(const std::map &options, + OptimizeUtility *const optimize_utility) = 0; + + // close graphOptimizer + virtual Status Finalize() = 0; + + // init process for optimize graph every time because options may differ between build processes + // Engines currently obtain the build options in OptimizeGraphPrepare, and that interface filters out the vector engine by default. + // Problem scenario: during subgraph optimization, operator fusion may select the vector engine directly, so the vector engine cannot obtain the build options. + // Decision: add the OptimizeGraphInit interface, which does not filter engines and is called for all of them, so obtaining the build options is decoupled from OptimizeGraphPrepare. 
+ virtual Status OptimizeGraphInit(ComputeGraph& graph) { + (void)graph; + return SUCCESS; + } + + // optimize original graph for FE quant optimize + virtual Status OptimizeGraphPrepare(ComputeGraph& graph) { + (void)graph; + return SUCCESS; + } + + // optimize graph after normalization, include multi dims and pre/post process + virtual Status OptimizeAfterGraphNormalization(const ComputeGraphPtr& graph) { + (void)graph; + return SUCCESS; + } + + // optimize graph before build for RTS + virtual Status OptimizeGraphBeforeBuild(ComputeGraph& graph) { + (void)graph; + return SUCCESS; + } + + // optimize original graph, using in graph preparation stage + virtual Status OptimizeOriginalGraph(ComputeGraph &graph) = 0; + + // optimize original graph, using for conversion operator insert in graph preparation stage + virtual Status
OptimizeOriginalGraphJudgeInsert(ComputeGraph &graph) { + (void)graph; + return SUCCESS; + } + + // optimize fused graph + virtual Status OptimizeFusedGraph(ComputeGraph &graph) = 0; + + // optimize whole graph, using after graph merged stage + virtual Status OptimizeWholeGraph(ComputeGraph &graph) = 0; + + // get attribute of graph optimizer + virtual Status GetAttributes(GraphOptimizerAttribute &attrs) const = 0; + + // optimize streamed Graph + virtual Status OptimizeStreamGraph(ComputeGraph &graph, const RunContext &context) { + (void)graph; + (void)context; + return SUCCESS; + } + + // optimize streamed whole Graph + virtual Status OptimizeStreamedWholeGraph(ComputeGraph &graph) { + (void)graph; + return SUCCESS; + } + + // op compile + virtual Status OptimizeFusedGraphAfterGraphSlice(ComputeGraph &graph) { + (void)graph; + return SUCCESS; + } + + // optimize whole graph, using after stage1 + virtual Status OptimizeAfterStage1(ComputeGraph &graph) { + (void)graph; + return SUCCESS; + } + + // recover compile result of precompiled op + using KernelLookup = std::function; + virtual Status OptimizeSubgraphOfPrecompiledOp(ComputeGraph &graph, const KernelLookup &lookup) { + static_cast(graph); + static_cast(lookup); + return SUCCESS; + } + + // To avoid data read/write conflicts caused by multi-threaded operations during subgraph optimization, single-threaded interfaces are provided before and after subgraph optimization; engines implement them to modify the graph + virtual Status OptimizeSubgraphPreProc(ComputeGraph &graph) { + (void)graph; + return SUCCESS; + } + virtual Status OptimizeSubgraphPostProc(ComputeGraph &graph) { + (void)graph; + return SUCCESS; + } +}; +} // namespace ge +/*lint +e148*/ +#endif // INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_ diff --git a/inc/metadef/common/optimizer/graph_optimizer_types.h b/inc/metadef/common/optimizer/graph_optimizer_types.h new file mode 100644 index 
0000000000000000000000000000000000000000..5b51a3aff02fc610409dfdec97702ca1cb9ce97c --- /dev/null +++ b/inc/metadef/common/optimizer/graph_optimizer_types.h @@ -0,0 +1,27 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_TYPES_H_ +#define INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_TYPES_H_ + +#include + +namespace ge { +enum OPTIMIZER_SCOPE { + UNIT = 0, + ENGINE, +}; + +struct GraphOptimizerAttribute { + std::string engineName; + OPTIMIZER_SCOPE scope; +}; +} // namespace ge + +#endif // INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_TYPES_H_ diff --git a/inc/metadef/common/optimizer/optimize_utility.h b/inc/metadef/common/optimizer/optimize_utility.h new file mode 100644 index 0000000000000000000000000000000000000000..c3f0d6c73eeecdc958672d2113105ed3f604bd9e --- /dev/null +++ b/inc/metadef/common/optimizer/optimize_utility.h @@ -0,0 +1,43 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_COMMON_OPTIMIZER_OPTIMIZE_UTILITY_H_ +#define INC_COMMON_OPTIMIZER_OPTIMIZE_UTILITY_H_ + +#include "common/ge_common/ge_inner_error_codes.h" +#include "graph/compute_graph.h" + +namespace ge { +class OptimizeUtility { + public: + virtual ~OptimizeUtility() {} + + // Deprecated: will delete later. Graph infershape util + virtual Status InferShape(ComputeGraph &compute_graph) { + (void)compute_graph; + return SUCCESS; + } + + // Graph infershape util + virtual Status InferShape(const ComputeGraphPtr &compute_graph) = 0; + + // Multi dims and pre/post process + virtual Status MultiDimsProcess(const ComputeGraphPtr &compute_graph) { + (void)compute_graph; + return SUCCESS; + } + + // Constant folding + virtual Status ConstantFolding(NodePtr &node) { + (void)node; + return SUCCESS; + } +}; +} // namespace ge +#endif // INC_COMMON_OPTIMIZER_OPTIMIZE_UTILITY_H_ diff --git a/inc/metadef/common/plugin/plugin_manager.h b/inc/metadef/common/plugin/plugin_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..987bb2a39506303a343dece00b91c384dd8806de --- /dev/null +++ b/inc/metadef/common/plugin/plugin_manager.h @@ -0,0 +1,304 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef GE_COMMON_GE_PLUGIN_MANAGER_H_ +#define GE_COMMON_GE_PLUGIN_MANAGER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common/ge_common/ge_inner_error_codes.h" +#include "common/ge_common/debug/ge_log.h" +#include "mmpa/mmpa_api.h" + +namespace ge { +class DNNEngine; +class PluginManager { + public: + class InvokeFuncPerfRecorder { + public: + InvokeFuncPerfRecorder(const std::string &func_name, const std::string &lib_name) + : func_name_(func_name), lib_name_(lib_name) { + start_ = std::chrono::high_resolution_clock::now(); + } + ~InvokeFuncPerfRecorder() { + const auto end = std::chrono::high_resolution_clock::now(); + GEEVENT("[GEPERFTRACE] The time cost of InvokeAll [%s] in [%s] is [%lu] micro second.", func_name_.c_str(), + lib_name_.c_str(), std::chrono::duration_cast(end - start_).count()); + } + private: + std::chrono::high_resolution_clock::time_point start_; + const std::string &func_name_; + const std::string &lib_name_; + }; + PluginManager() = default; + + ~PluginManager(); + + static void SplitPath(const std::string &mutil_path, std::vector &path_vec, const char sep = ':'); + + static Status GetOppPath(std::string &opp_path); + + static Status GetUpgradedOppPath(std::string &opp_path); + + static bool IsNewOppPathStruct(const std::string &opp_path); + + static Status GetOppPluginVendors(const std::string &vendors_config, std::vector &vendors); + + static Status ReversePathString(std::string &path_str); + + static void GetPluginPathFromCustomOppPath(const std::string &sub_path, std::string &plugin_path); + + static Status GetOppPluginPathOld(const std::string &opp_path, + const std::string &path_fmt, + std::string &plugin_path, + const std::string &path_fmt_custom = ""); + + static Status GetOppPluginPathNew(const std::string &opp_path, + const std::string 
&path_fmt, + std::string &plugin_path, + const std::string &old_custom_path, + const std::string &path_fmt_custom = ""); + + static bool IsSplitOpp(); + + static Status GetOpsProtoPath(std::string &opsproto_path); + + static Status GetUpgradedOpsProtoPath(std::string &opsproto_path); + + static Status GetUpgradedOpMasterPath(std::string &op_tiling_path); + + static Status GetCustomOpPath(const std::string &fmk_type, std::string &customop_path); + + static Status GetCustomCaffeProtoPath(std::string &customcaffe_path); + + static Status GetOpTilingPath(std::string &op_tiling_path); + + static Status GetOpTilingForwardOrderPath(std::string &op_tiling_path); + + static Status GetConstantFoldingOpsPath(const std::string &path_base, std::string &constant_folding_ops_path); + + Status LoadSoWithFlags(const std::string &path, const int32_t flags, + const std::vector &func_check_list = std::vector()); + + Status LoadSo(const std::string &path, const std::vector &func_check_list = std::vector()); + + Status Load(const std::string &path, const std::vector &func_check_list = std::vector()); + + Status LoadWithFlags(const std::string &path, const int32_t flags, + const std::vector &func_check_list = std::vector()); + + static void GetOppSupportedOsAndCpuType( + std::unordered_map> &opp_supported_os_cpu, + std::string opp_path = "", std::string os_name = "", uint32_t layer = 0U); + + static void GetCurEnvPackageOsAndCpuType(std::string &host_env_os, std::string &host_env_cpu); + + static bool GetVersionFromPath(const std::string &file_path, std::string &version); + + static void GetFileListWithSuffix(const std::string &path, const std::string &so_suff, + std::vector &file_list); + + static bool IsVendorVersionValid(const std::string &opp_version, const std::string &compiler_version); + + static bool IsVendorVersionValid(const std::string &vendor_path); + + static void GetPackageSoPath(std::vector &vendors); + + static bool GetVersionFromPathWithName(const std::string &file_path, 
std::string &version, + const std::string version_name); + static void FindSoFilesInCustomPassDirs(const std::string &directory, + std::vector &so_files); + + static bool IsEndWith(const std::string &path, const std::string &suff); + + static Status GetOpMasterDeviceSoPath(std::string &op_master_device_path); + + template + Status GetAllFunctions(const std::string &func_name, std::map> &funcs) { + for (const auto &handle : handles_) { + const auto real_fn = reinterpret_cast(mmDlsym(handle.second, func_name.c_str())); + if (real_fn == nullptr) { + const char *error = mmDlerror(); + if (error == nullptr) { + error = ""; + } + GELOGW("Failed to get function %s in %s! errmsg:%s", func_name.c_str(), handle.first.c_str(), error); + return GE_PLGMGR_FUNC_NOT_EXIST; + } else { + funcs[handle.first] = real_fn; + } + } + return SUCCESS; + } + + template + Status InvokeAll(const std::string &func_name, const Types... args) { + for (const auto &handle : handles_) { + InvokeFuncPerfRecorder recorder(func_name, handle.first); + // If the funcName is existed, signature of realFn can be casted to any type + const auto real_fn = reinterpret_cast(mmDlsym(handle.second, func_name.c_str())); + if (real_fn == nullptr) { + const char *error = mmDlerror(); + if (error == nullptr) { + error = ""; + } + GELOGW("Failed to invoke function %s in %s! errmsg:%s", func_name.c_str(), handle.first.c_str(), error); + return GE_PLGMGR_INVOKE_FAILED; + } else { + real_fn(args...); + } + } + return SUCCESS; + } + + template + Status InvokeAll(const std::string &func_name, const T arg) { + for (const auto &handle : handles_) { + // If the funcName is existed, signature of realFn can be casted to any type + InvokeFuncPerfRecorder recorder(func_name, handle.first); + const auto real_fn = reinterpret_cast(mmDlsym(handle.second, func_name.c_str())); + if (real_fn == nullptr) { + const char_t *error = mmDlerror(); + if (error == nullptr) { + error = ""; + } + GELOGW("Failed to invoke function %s in %s! 
errmsg:%s", func_name.c_str(), handle.first.c_str(), error); + return GE_PLGMGR_INVOKE_FAILED; + } + typename std::remove_reference::type arg_temp; + real_fn(arg_temp); + + if (std::is_same::type, + std::map>>::value) { + for (const auto &val : arg_temp) { + if (arg.find(val.first) != arg.end()) { + GELOGW("FuncName %s in so %s find the same key: %s, will replace it", func_name.c_str(), + handle.first.c_str(), val.first.c_str()); + arg[val.first] = val.second; + } + } + } + arg.insert(arg_temp.begin(), arg_temp.end()); + } + return SUCCESS; + } + + template + void OptionalInvokeAll(const std::string &func_name, const Args... args) const { + for (const auto &handle : handles_) { + // If the funcName is existed, signature of realFn can be casted to any type + InvokeFuncPerfRecorder recorder(func_name, handle.first); + const auto real_fn = reinterpret_cast(mmDlsym(handle.second, func_name.c_str())); + if (real_fn == nullptr) { + GELOGI("func %s not exist in so %s", handle.first.c_str(), func_name.c_str()); + continue; + } else { + GELOGI("func %s exists in so %s", handle.first.c_str(), func_name.c_str()); + real_fn(args...); + } + } + } + + template + Status InvokeAll(const std::string &func_name, const T1 arg) { + for (const auto &handle : handles_) { + // If the funcName is existed, signature of realFn can be casted to any type + InvokeFuncPerfRecorder recorder(func_name, handle.first); + const auto real_fn = reinterpret_cast(mmDlsym(handle.second, func_name.c_str())); + if (real_fn == nullptr) { + const char_t *error = mmDlerror(); + if (error == nullptr) { + error = ""; + } + GELOGW("Failed to invoke function %s in %s! 
errmsg:%s", func_name.c_str(), handle.first.c_str(), error); + return GE_PLGMGR_INVOKE_FAILED; + } else { + const T2 res = real_fn(arg); + if (res != SUCCESS) { + return FAILED; + } + } + } + return SUCCESS; + } + + template + Status InvokeAll(const std::string &func_name) { + for (const auto &handle : handles_) { + // If the funcName is existed, signature of realFn can be casted to any type + InvokeFuncPerfRecorder recorder(func_name, handle.first); + const auto real_fn = reinterpret_cast(mmDlsym(handle.second, func_name.c_str())); + if (real_fn == nullptr) { + const char_t *error = mmDlerror(); + if (error == nullptr) { + error = ""; + } + GELOGW("Failed to invoke function %s in %s! errmsg:%s", func_name.c_str(), handle.first.c_str(), error); + return GE_PLGMGR_INVOKE_FAILED; + } else { + const T res = real_fn(); + if (res != SUCCESS) { + return FAILED; + } + } + } + return SUCCESS; + } + + private: + void ClearHandles_() noexcept; + Status ValidateSo(const std::string &file_path, const int64_t size_of_loaded_so, int64_t &file_size) const; + static bool ParseVersion(std::string &line, std::string &version, const std::string version_name); + static bool GetRequiredOppAbiVersion(std::vector> &required_opp_abi_version); + static bool GetEffectiveVersion(const std::string &opp_version, uint32_t &effective_version); + static bool CheckOppAndCompilerVersions(const std::string &opp_version, const std::string &compiler_version, + const std::vector> &required_version); + static void GetOppAndCompilerVersion(const std::string &vendor_path, std::string &opp_version, + std::string &compiler_version); + std::vector so_list_; + std::map handles_; +}; + +inline std::string GetModelPath() { + mmDlInfo dl_info; + if ((mmDladdr(reinterpret_cast(&GetModelPath), &dl_info) != EN_OK) || (dl_info.dli_fname == nullptr)) { + GELOGW("Failed to read the shared library file path! 
errmsg:%s", mmDlerror()); + return std::string(); + } + + if (strlen(dl_info.dli_fname) >= MMPA_MAX_PATH) { + GELOGW("The shared library file path is too long!"); + return std::string(); + } + + char_t path[MMPA_MAX_PATH] = {}; + if (mmRealPath(dl_info.dli_fname, &path[0], MMPA_MAX_PATH) != EN_OK) { + constexpr size_t max_error_strlen = 128U; + char_t err_buf[max_error_strlen + 1U] = {}; + const auto err_msg = mmGetErrorFormatMessage(mmGetErrorCode(), &err_buf[0], max_error_strlen); + GELOGW("Failed to get realpath of %s, errmsg:%s", dl_info.dli_fname, err_msg); + return std::string(); + } + + std::string so_path = path; + so_path = so_path.substr(0U, so_path.rfind('/') + 1U); + return so_path; +} +} // namespace ge + +#endif // GE_COMMON_GE_PLUGIN_MANAGER_H_ diff --git a/inc/metadef/common/screen_printer.h b/inc/metadef/common/screen_printer.h new file mode 100644 index 0000000000000000000000000000000000000000..dee0176c29415d19f35dc8cf6f213eaa2a5f9cff --- /dev/null +++ b/inc/metadef/common/screen_printer.h @@ -0,0 +1,48 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef METADEF_INC_COMMON_SCREEN_PRINTER_H_ +#define METADEF_INC_COMMON_SCREEN_PRINTER_H_ + +#include +#include +#include +#include +#include "external/graph/types.h" + +namespace ge { +class ScreenPrinter { + public: + static ScreenPrinter &GetInstance(); + void Log(const char *fmt, ...); + void Init(const std::string &print_mode); + + private: + ScreenPrinter() = default; + ~ScreenPrinter() = default; + + ScreenPrinter(const ScreenPrinter &) = delete; + ScreenPrinter(const ScreenPrinter &&) = delete; + ScreenPrinter &operator=(const ScreenPrinter &)& = delete; + ScreenPrinter &operator=(const ScreenPrinter &&)& = delete; + + enum class PrintMode : uint32_t { + ENABLE = 0U, + DISABLE = 1U + }; + PrintMode print_mode_ = PrintMode::ENABLE; + std::mutex mutex_; +}; + +#define SCREEN_LOG(fmt, ...) \ + do { \ + ScreenPrinter::GetInstance().Log(fmt, ##__VA_ARGS__); \ + } while (false) +} // namespace ge +#endif // METADEF_INC_COMMON_SCREEN_PRINTER_H_ diff --git a/inc/metadef/common/sgt_slice_type.h b/inc/metadef/common/sgt_slice_type.h new file mode 100644 index 0000000000000000000000000000000000000000..6b62f7a475f4bd35e5aeae792ef0742e81c07b6f --- /dev/null +++ b/inc/metadef/common/sgt_slice_type.h @@ -0,0 +1,95 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_COMMON_SGT_SLICE_TYPES_H_ +#define INC_COMMON_SGT_SLICE_TYPES_H_ + +#include +#include +#include +#include + +namespace ffts { +const std::string kAttrSgtJsonInfo = "_sgt_json_info"; +const std::string kAttrSgtStructInfo = "_sgt_struct_info"; +const std::string kAttrSgtStructInfoDy = "_sgt_struct_info_dy"; +const size_t kSgtTillingNum = 2U; + +struct OpCut { + int16_t split_cut_idx = -1; + int16_t reduce_cut_idx = -1; + int64_t cut_id = -1; +}; + +struct DimRange { + int64_t lower; + int64_t higher; + bool operator==(const DimRange& dim_range) const { + return (this->higher == dim_range.higher) && (this->lower == dim_range.lower); + } +}; + +enum class AtomicType { + None = 0, + ADD = 1, + SUB, + MUL, + DIV +}; + +struct ThreadSliceMap { + uint32_t thread_scope_id; + bool is_first_node_in_topo_order; + uint32_t thread_mode; + uint32_t node_num_in_thread_scope; + bool is_input_node_of_thread_scope; + bool is_output_node_of_thread_scope; + std::vector>> ori_input_tensor_shape; + std::vector>> ori_output_tensor_shape; + std::string original_node; + uint32_t slice_instance_num; + uint32_t parallel_window_size; + uint32_t thread_id; + std::vector>> dependencies; + std::vector core_num; + std::vector cut_type; + std::vector atomic_types; + std::vector same_atomic_clean_nodes; + std::vector input_axis; + std::vector output_axis; + std::vector input_tensor_indexes; + std::vector output_tensor_indexes; + std::vector>> input_tensor_slice; + std::vector>> output_tensor_slice; + std::vector>> ori_input_tensor_slice; + std::vector>> ori_output_tensor_slice; + std::vector> input_cut_list; + std::vector> output_cut_list; + ThreadSliceMap() : thread_scope_id(1U), is_first_node_in_topo_order(false), thread_mode(0U), + node_num_in_thread_scope(1U), is_input_node_of_thread_scope(false), is_output_node_of_thread_scope(false), + 
slice_instance_num(1U), parallel_window_size(1U), thread_id(0U) {} + bool GetThreadMode() const { + return (thread_mode == 0U) ? false : true; + } +}; + +struct ThreadSliceMapDy { + uint32_t slice_instance_num; + uint32_t parallel_window_size; + std::vector input_tensor_indexes; + std::vector output_tensor_indexes; + std::vector>> input_tensor_slice; + std::vector>> output_tensor_slice; + ThreadSliceMapDy() : slice_instance_num(1U), parallel_window_size(1U) {} +}; + +using ThreadSliceMapPtr = std::shared_ptr; +using ThreadSliceMapDyPtr = std::shared_ptr; +} // namespace ffts +#endif // INC_COMMON_SGT_SLICE_TYPES_H_ diff --git a/inc/metadef/common/util/compress/compress.h b/inc/metadef/common/util/compress/compress.h new file mode 100644 index 0000000000000000000000000000000000000000..99e862d0416ee3e1fee9f72dec196dddb9ab80b8 --- /dev/null +++ b/inc/metadef/common/util/compress/compress.h @@ -0,0 +1,39 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef COMPRESS_H +#define COMPRESS_H + +#include +#include + +enum CmpStatus { + RET_SUCCESS = 0, + RET_ERROR = -1 +}; + +struct CompressConfig { + size_t inputSize; // length of data to compress + size_t engineNum; // how many decompress engines + size_t maxRatio; // how much size of a basic compression block (8x: 64 4x: 32) + size_t channel; // channels of L2 or DDR. For load balance + size_t fractalSize; // size of compressing block + bool isTight; // whether compose compressed data tightly + size_t init_offset; + int32_t compressType = 0; // compress type flag only 0 and 1 supported (0: low_sparse, 1: high_sparse) +}; + +CmpStatus CompressWeights(char* input, + const CompressConfig& compressConfig, + char* indexs, + char* output, + size_t& compressedLength); + +CmpStatus SparseWeightsConv2D(const char *const input, size_t weight_size); +#endif // COMPRESS_H diff --git a/inc/metadef/common/util/compress/compress_weight.h b/inc/metadef/common/util/compress/compress_weight.h new file mode 100644 index 0000000000000000000000000000000000000000..459c3a5f1eaa5e8dddc1f95c5d9bfc833a1bb4b1 --- /dev/null +++ b/inc/metadef/common/util/compress/compress_weight.h @@ -0,0 +1,29 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef COMPRESS_WEIGHT_H +#define COMPRESS_WEIGHT_H + +#include +#include "compress.h" + +const int SHAPE_SIZE_WEIGHT = 4; + +struct CompressOpConfig { + int64_t wShape[SHAPE_SIZE_WEIGHT]; + size_t compressTilingK; + size_t compressTilingN; + struct CompressConfig compressConfig; +}; + +extern "C" CmpStatus CompressWeightsConv2D(const char *const input, + char *const zipBuffer, + char *const infoBuffer, + CompressOpConfig *const param); +#endif // COMPRESS_WEIGHT_H diff --git a/inc/metadef/common/util/error_manager/error_manager.h b/inc/metadef/common/util/error_manager/error_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..fbbe42b83cf7f6769b16488baf057b021e3fd1ca --- /dev/null +++ b/inc/metadef/common/util/error_manager/error_manager.h @@ -0,0 +1,264 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef ERROR_MANAGER_H_ +#define ERROR_MANAGER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace error_message { +using char_t = char; +std::string TrimPath(const std::string &str); +#ifdef __GNUC__ +int32_t FormatErrorMessage(char_t *str_dst, size_t dst_max, + const char_t *format, ...)__attribute__((format(printf, 3, 4))); + +void ReportInnerError(const char_t *file_name, const char_t *func, uint32_t line, const std::string error_code, + const char_t *format, ...) __attribute__((format(printf, 5, 6))); +#else +int32_t FormatErrorMessage(char_t *str_dst, size_t dst_max, const char_t *format, ...); +void ReportInnerError(const char_t *file_name, const char_t *func, uint32_t line, const std::string error_code, + const char_t *format, ...); +#endif +} + +constexpr size_t const LIMIT_PER_MESSAGE = 1024U; + +/// +/// @brief Report error message +/// @param [in] key: vector parameter key +/// @param [in] value: vector parameter value +/// +#define REPORT_INPUT_ERROR(error_code, key, value) \ + ErrorManager::GetInstance().ATCReportErrMessage(error_code, key, value) + +/// +/// @brief Report error message +/// @param [in] key: vector parameter key +/// @param [in] value: vector parameter value +/// +#define REPORT_ENV_ERROR(error_code, key, value) \ + ErrorManager::GetInstance().ATCReportErrMessage(error_code, key, value) + +#define REPORT_INNER_ERROR(error_code, fmt, ...) 
\ + error_message::ReportInnerError(__FILE__, &__FUNCTION__[0], __LINE__, (error_code), (fmt), ##__VA_ARGS__) + +#define REPORT_CALL_ERROR REPORT_INNER_ERROR + +namespace error_message { +// first stage +extern const std::string kInitialize; +extern const std::string kModelCompile; +extern const std::string kModelLoad; +extern const std::string kModelExecute; +extern const std::string kFinalize; + +// SecondStage +// INITIALIZE +extern const std::string kParser; +extern const std::string kOpsProtoInit; +extern const std::string kSystemInit; +extern const std::string kEngineInit; +extern const std::string kOpsKernelInit; +extern const std::string kOpsKernelBuilderInit; +// MODEL_COMPILE +extern const std::string kPrepareOptimize; +extern const std::string kOriginOptimize; +extern const std::string kSubGraphOptimize; +extern const std::string kMergeGraphOptimize; +extern const std::string kPreBuild; +extern const std::string kStreamAlloc; +extern const std::string kMemoryAlloc; +extern const std::string kTaskGenerate; +// COMMON +extern const std::string kOther; + +struct Context { + uint64_t work_stream_id; + std::string first_stage; + std::string second_stage; + std::string log_header; +}; + +enum class ErrorMsgMode : uint32_t { + // 0: internal mode (per-thread granularity for inference, per-session granularity for training); 1: per-process granularity + INTERNAL_MODE = 0U, + PROCESS_MODE = 1U, + ERR_MSG_MODE_MAX = 2U +}; + +struct ErrorItem { + std::string error_id; + std::string error_message; + std::string possible_cause; + std::string solution; + // args_map is used to fill in error_message when a json file configuration exists for the error_id + std::map args_map; + std::string report_time; + + friend bool operator==(const ErrorItem &lhs, const ErrorItem &rhs) noexcept { + return (lhs.error_id == rhs.error_id) && (lhs.error_message == rhs.error_message) && + (lhs.possible_cause == rhs.possible_cause) && (lhs.solution == rhs.solution); + } +}; +} // namespace error_message + +class ErrorManager { + public: + using ErrorItem = error_message::ErrorItem; + /// @brief Obtain ErrorManager
instance + /// @return ErrorManager instance + static ErrorManager &GetInstance(); + + /// @brief init + /// @return int 0(success) -1(fail) + int32_t Init(); + + int32_t Init(error_message::ErrorMsgMode error_mode); + + /// @brief init + /// @param [in] path: current so path + /// @return int 0(success) -1(fail) + int32_t Init(const std::string path); + + int32_t ReportInterErrMessage(const std::string error_code, const std::string &error_msg); + + /// @brief Report error message + /// @param [in] error_code: error code + /// @param [in] args_map: parameter map + /// @return int 0(success) -1(fail) + int32_t ReportErrMessage(const std::string error_code, const std::map &args_map); + + /// @brief output error message + /// @param [in] handle: print handle + /// @return int 0(success) -1(fail) + int32_t OutputErrMessage(int32_t handle); + + /// @brief output message + /// @param [in] handle: print handle + /// @return int 0(success) -1(fail) + int32_t OutputMessage(int32_t handle); + // 调用成功后会进行已有的错误信息的清理 + std::string GetErrorMessage(); + + // 获取当前线程下原始上报的内部,外部错误信息,顺序为上报的原始顺序 + // 便于调用方进行自定制的加工,调用成功后会进行已有的错误信息的清理 + std::vector GetRawErrorMessages(); + + std::string GetWarningMessage(); + + /// @brief Report error message + /// @param [in] key: vector parameter key + /// @param [in] value: vector parameter value + void ATCReportErrMessage(const std::string error_code, const std::vector &key = {}, + const std::vector &value = {}); + + /// @brief report graph compile failed message such as error code and op_name in mstune case + /// @param [in] graph_name: root graph name + /// @param [in] msg: failed message map, key is error code, value is op_name + /// @return int 0(success) -1(fail) + int32_t ReportMstuneCompileFailedMsg(const std::string &root_graph_name, + const std::map &msg); + + /// @brief get graph compile failed message in mstune case + /// @param [in] graph_name: graph name + /// @param [out] msg_map: failed message map, key is error code, value is op_name 
list + /// @return int 0(success) -1(fail) + int32_t GetMstuneCompileFailedMsg(const std::string &graph_name, + std::map> &msg_map); + + // @brief generate work_stream_id by current pid and tid, clear error_message stored by same work_stream_id + // used in external api entrance, all sync api can use + void GenWorkStreamIdDefault(); + + // @brief generate work_stream_id by args sessionid and graphid, clear error_message stored by same work_stream_id + // used in external api entrance + void GenWorkStreamIdBySessionGraph(const uint64_t session_id, const uint64_t graph_id); + + void GenWorkStreamIdWithSessionIdGraphId(const uint64_t session_id, const uint64_t graph_id); + + const std::string &GetLogHeader(); + + error_message::Context &GetErrorManagerContext(); + + void SetErrorContext(error_message::Context error_context); + + void SetStage(const std::string &first_stage, const std::string &second_stage); + + private: + struct ErrorInfoConfig { + std::string error_id; + std::string error_message; + std::string possible_cause; + std::string solution; + std::vector arg_list; + }; + ErrorManager() = default; + ~ErrorManager() = default; + + ErrorManager(const ErrorManager &) = delete; + ErrorManager(ErrorManager &&) = delete; + ErrorManager &operator=(const ErrorManager &)& = delete; + ErrorManager &operator=(ErrorManager &&)& = delete; + + int32_t ParseJsonFile(const std::string path); + + static int32_t ReadJsonFile(const std::string &file_path, void *const handle); + + void ClassifyCompileFailedMsg(const std::map &msg, + std::map> &classified_msg); + + bool IsInnerErrorCode(const std::string &error_code) const; + + bool IsParamCheckErrorId(const std::string &error_code) const; + + inline bool IsValidErrorCode(const std::string &error_codes) const { + constexpr uint32_t kErrorCodeValidLength = 6U; + return error_codes.size() == kErrorCodeValidLength; + } + + std::vector &GetErrorMsgContainerByWorkId(uint64_t work_id); + std::vector 
&GetWarningMsgContainerByWorkId(uint64_t work_id); + + std::vector &GetErrorMsgContainer(uint64_t work_stream_id); + std::vector &GetWarningMsgContainer(uint64_t work_stream_id); + + void AssembleInnerErrorMessage(const std::vector &error_messages, const std::string &first_code, + std::stringstream &err_stream) const; + + void ClearErrorMsgContainerByWorkId(const uint64_t work_stream_id); + void ClearWarningMsgContainerByWorkId(const uint64_t work_stream_id); + + void ClearErrorMsgContainer(const uint64_t work_stream_id); + void ClearWarningMsgContainer(const uint64_t work_stream_id); + + bool is_init_ = false; + std::mutex mutex_; + std::map error_map_; + std::map>> compile_failed_msg_map_; + + std::map> error_message_per_work_id_; + std::map> warning_messages_per_work_id_; + + thread_local static error_message::Context error_context_; + + error_message::ErrorMsgMode error_mode_ = error_message::ErrorMsgMode::INTERNAL_MODE; + std::vector error_message_process_; // 进程粒度,所有的errmsg存到同一个vector + std::vector warning_messages_process_; // 进程粒度,所有的warning msg存到同一个vector +}; +#endif // ERROR_MANAGER_H_ diff --git a/inc/metadef/common/util/mem_utils.h b/inc/metadef/common/util/mem_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..bcd361abcb321b752775cf31432a668632067d17 --- /dev/null +++ b/inc/metadef/common/util/mem_utils.h @@ -0,0 +1,25 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef COMMON_GRAPH_UTILS_MEM_UTILS_H_ +#define COMMON_GRAPH_UTILS_MEM_UTILS_H_ + +#include +#include + +namespace ge { +template +static inline std::shared_ptr<_Tp> MakeShared(_Args &&... __args) { + using _Tp_nc = typename std::remove_const<_Tp>::type; + const std::shared_ptr<_Tp> ret(new (std::nothrow) _Tp_nc(std::forward<_Args>(__args)...)); + return ret; +} +} + +#endif // COMMON_GRAPH_UTILS_MEM_UTILS_H_ diff --git a/inc/metadef/common/util/tiling_utils.h b/inc/metadef/common/util/tiling_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..eb518997ff6e0299e4a139686e16134584977972 --- /dev/null +++ b/inc/metadef/common/util/tiling_utils.h @@ -0,0 +1,74 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_COMMON_UTIL_TILING_UTILS_H_ +#define METADEF_CXX_INC_COMMON_UTIL_TILING_UTILS_H_ +#include +#include "graph/types.h" + +namespace optiling { +union Fp32 { + uint32_t u; + ge::float32_t f; +}; + +inline uint16_t Float32ToFloat16(const ge::float32_t value) { + constexpr Fp32 f32infty = {static_cast(255) << static_cast(23)}; + constexpr uint32_t sign_mask = 0x80000000U; + constexpr uint32_t right_shift_16 = 16U; + + Fp32 temp; + uint16_t out; + temp.f = value; + const uint32_t sign = temp.u & sign_mask; + temp.u ^= sign; + + if (temp.u >= f32infty.u) { + constexpr uint32_t round_max = 0x7FFFU; + constexpr uint32_t dst_addr = 0x7C00U; + out = (temp.u > f32infty.u) ? round_max : dst_addr; + } else { + constexpr uint32_t right_shift_13 = 13U; + constexpr Fp32 f16infty = {static_cast(31) << static_cast(23)}; + constexpr Fp32 magic = {static_cast(15) << static_cast(23)}; + constexpr uint32_t round_mask = static_cast(~0xFFFU); + + temp.u &= round_mask; + temp.f *= magic.f; + temp.u -= round_mask; + if (temp.u > f16infty.u) { + temp.u = f16infty.u; + } + out = uint16_t(temp.u >> right_shift_13); + } + + out = uint16_t(out | (sign >> right_shift_16)); + return out; +} + +inline uint16_t Float32ToBfloat16(const ge::float32_t value) { + Fp32 temp; + temp.f = value; + constexpr uint32_t right_shift_16 = 16U; + return uint16_t(temp.u >> right_shift_16); +} + +template +inline uint16_t OtherToFloat16(const T value) { + const ge::float32_t value_f32 = static_cast(value); + return Float32ToFloat16(value_f32); +} + +template +inline uint16_t OtherToBfloat16(const T value) { + const ge::float32_t value_f32 = static_cast(value); + return Float32ToBfloat16(value_f32); +} +} // namespace optiling +#endif diff --git a/inc/metadef/common/util/trace_manager/trace_manager.h b/inc/metadef/common/util/trace_manager/trace_manager.h new file mode 
100644 index 0000000000000000000000000000000000000000..e2ad0ae7871ec3853a07d08f771223259382bdf7 --- /dev/null +++ b/inc/metadef/common/util/trace_manager/trace_manager.h @@ -0,0 +1,110 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef COMMON_UTIL_TRACE_MANAGER_TRACE_MANAGER_H_ +#define COMMON_UTIL_TRACE_MANAGER_TRACE_MANAGER_H_ + +#include +#include +#include +#include +#include +#include "common/ge_common/util.h" + +namespace ge { +#define TRACE_GEN_RECORD(owner, action, graph_name, node_name, node_data, tensor_index, tensor_data, content) \ + do { \ + if (TraceManager::GetInstance().IsTraceEnabled()) { \ + if (TraceManager::GetTraceHeader().size() == 0) { \ + GELOGD("[Check][Param] owner and stage have not been set"); \ + } else { \ + std::stringstream ss; \ + ss << owner << "," << action << "," << graph_name << "," << node_name << "," << node_data << "," \ + << tensor_index << "," << tensor_data << "," << content; \ + TraceManager::GetInstance().AddTrace(ss.str()); \ + } \ + } \ + } while (false) + +using char_t = char; + +constexpr uint64_t kTraceSaveTriggerNum = 5000U; + +enum class ReadyPart { A, B, None }; + +class TraceManager { + public: + static TraceManager &GetInstance(); + + void AddTrace(std::string &&trace_info); + + bool IsTraceEnabled() const { + return enabled_; + } + 
void SetTraceOwner(const std::string &owner, const std::string &stage, const std::string &graph_name); + void ClearTraceOwner(); + static inline const std::string &GetTraceHeader() { + return trace_header_; + } + static inline const std::string &GetOutGraphName() { + return graph_name_; + } + + private: + TraceManager(); + ~TraceManager(); + TraceManager(const TraceManager &) = delete; + TraceManager(TraceManager &&) = delete; + TraceManager &operator=(const TraceManager &) = delete; + TraceManager &operator=(TraceManager &&) = delete; + Status Initialize(const char_t *file_save_path); + void Finalize(); + + std::string NextFileName(); + void SaveTraceBufferToFile(const ReadyPart ready_part); + void SaveBufferToFileThreadFunc(); + + static thread_local std::string trace_header_; + static thread_local std::string graph_name_; + + std::atomic enabled_{false}; + std::vector trace_array_; + std::atomic trace_index_{0}; + std::atomic total_saved_nums_{0}; + std::atomic part1_ready_nums_{0}; + std::atomic part2_ready_nums_{0}; + std::string trace_save_file_path_; + std::string current_saving_file_name_; + uint64_t current_file_saved_nums_ = 0; + ReadyPart ready_part_ = ReadyPart::None; + + std::mutex mu_; + std::thread save_thread_; + std::atomic stopped_{false}; + std::condition_variable data_ready_var_; +}; + +class TraceOwnerGuard { + public: + TraceOwnerGuard(const std::string &owner, const std::string &stage, const std::string &graph_name) { + TraceManager::GetInstance().SetTraceOwner(owner, stage, graph_name); + } + ~TraceOwnerGuard() { + TraceManager::GetInstance().ClearTraceOwner(); + } + TraceOwnerGuard(const TraceOwnerGuard &) = delete; + TraceOwnerGuard(TraceOwnerGuard &&) = delete; + TraceOwnerGuard &operator=(const TraceOwnerGuard &) = delete; + TraceOwnerGuard &operator=(TraceOwnerGuard &&) = delete; +}; + +#define TRACE TraceManager::GetInstance() +#define TRACE_HEADER TraceManager::GetTraceHeader() +} // namespace ge +#endif // 
COMMON_UTIL_TRACE_MANAGER_TRACE_MANAGER_H_ diff --git a/inc/metadef/exe_graph/lowering/bg_ir_attrs.h b/inc/metadef/exe_graph/lowering/bg_ir_attrs.h new file mode 100644 index 0000000000000000000000000000000000000000..1b8b8145a2d3185a45ad5c13777e75299227d2e0 --- /dev/null +++ b/inc/metadef/exe_graph/lowering/bg_ir_attrs.h @@ -0,0 +1,25 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef AIR_CXX_RUNTIME_V2_GRAPH_BUILDER_BG_IR_ATTRS_H_ +#define AIR_CXX_RUNTIME_V2_GRAPH_BUILDER_BG_IR_ATTRS_H_ +#include "graph/node.h" +#include "value_holder.h" +namespace gert { +namespace bg { +std::unique_ptr CreateAttrBuffer(const ge::NodePtr &node, size_t &size); +std::unique_ptr CreateAttrBuffer(const ge::NodePtr &node, + const std::vector &runtime_attrs_list, + size_t &size); +std::unique_ptr CreateAttrBufferWithoutIr(const ge::NodePtr &node, + const std::vector &runtime_attrs_list, + size_t &size); +} +} +#endif // AIR_CXX_RUNTIME_V2_GRAPH_BUILDER_BG_IR_ATTRS_H_ diff --git a/inc/metadef/exe_graph/lowering/bg_kernel_context_extend.h b/inc/metadef/exe_graph/lowering/bg_kernel_context_extend.h new file mode 100644 index 0000000000000000000000000000000000000000..24abb1105d6b6447a9d3adba7576aafcd00c7cad --- /dev/null +++ b/inc/metadef/exe_graph/lowering/bg_kernel_context_extend.h @@ -0,0 +1,25 @@ +/* Copyright 
(c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef AIR_CXX_RUNTIME_V2_GRAPH_BUILDER_BG_KERNEL_CONTEXT_EXTEND_H_ +#define AIR_CXX_RUNTIME_V2_GRAPH_BUILDER_BG_KERNEL_CONTEXT_EXTEND_H_ +#include "graph/node.h" +#include "buffer_pool.h" +#include "register/op_impl_registry.h" +namespace gert { +namespace bg { +std::unique_ptr CreateComputeNodeInfo(const ge::NodePtr &node, BufferPool &buffer_pool); +std::unique_ptr CreateComputeNodeInfo(const ge::NodePtr &node, BufferPool &buffer_pool, size_t &total_size); +std::unique_ptr CreateComputeNodeInfo(const ge::NodePtr &node, BufferPool &buffer_pool, + const gert::OpImplRegisterV2::PrivateAttrList &private_attrs, size_t &total_size); +std::unique_ptr CreateComputeNodeInfoWithoutIrAttr(const ge::NodePtr &node, BufferPool &buffer_pool, + const gert::OpImplRegisterV2::PrivateAttrList &private_attrs, size_t &total_size); +} +} +#endif // AIR_CXX_RUNTIME_V2_GRAPH_BUILDER_BG_KERNEL_CONTEXT_EXTEND_H_ diff --git a/inc/metadef/exe_graph/lowering/buffer_pool.h b/inc/metadef/exe_graph/lowering/buffer_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..474134131ecac509ca82987d5df92a1f1447912a --- /dev/null +++ b/inc/metadef/exe_graph/lowering/buffer_pool.h @@ -0,0 +1,42 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. 
+ * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef AIR_CXX_RUNTIME_V2_LOWERING_BUFFER_POOL_H_ +#define AIR_CXX_RUNTIME_V2_LOWERING_BUFFER_POOL_H_ +#include +#include +#include +#include + +namespace gert { +namespace bg { +class BufferPool { + public: + using BufId = size_t; + BufId AddStr(const char *data); + BufId AddBuf(const uint8_t *data, const size_t len); + std::unique_ptr Serialize(size_t &total_size) const; + std::unique_ptr Serialize() const; + size_t GetSize() const; + + // very slow, only use in UT + const char *GetBufById(const BufId id) const; + + private: + BufId AddBuf(std::string &&str); + BufId AddLargeBuf(std::string &&str); + + private: + std::unordered_map bufs_to_id_; + std::vector> large_bufs_to_id_; // large buf size, not do hash + uint64_t id_generator_{0U}; +}; +} // namespace bg +} // namespace gert +#endif // AIR_CXX_RUNTIME_V2_LOWERING_BUFFER_POOL_H_ diff --git a/inc/metadef/exe_graph/lowering/builtin_node_types.h b/inc/metadef/exe_graph/lowering/builtin_node_types.h new file mode 100644 index 0000000000000000000000000000000000000000..2eb20238b4f9b8dc82327c9c604c78461033571b --- /dev/null +++ b/inc/metadef/exe_graph/lowering/builtin_node_types.h @@ -0,0 +1,67 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_LOWERING_BUILTIN_NODE_TYPES_H_ +#define METADEF_CXX_INC_EXE_GRAPH_LOWERING_BUILTIN_NODE_TYPES_H_ +#include +#include "graph/types.h" + +namespace gert { +// Subgraph output. InnerNetOutput has multiple inputs and no outputs; each InnerNetOutput input corresponds to the parent node output with the same index +constexpr ge::char_t const *kInnerNetOutput = "InnerNetOutput"; + +// Subgraph input. InnerData has no inputs and one output, representing the corresponding input of the parent node. +// InnerData has an attribute keyed by index, of type int32, giving the index of the parent node input this InnerData corresponds to +constexpr ge::char_t const *kInnerData = "InnerData"; + +// Graph output. NetOutput appears on the Main graph and expresses the graph outputs after execution completes. NetOutput has multiple inputs and no outputs; each input corresponds to the graph output with the same index +constexpr ge::char_t const *kNetOutput = "NetOutput"; + +// Graph input. Data has no inputs and one output, representing an input of the graph. +// Data has an attribute keyed by index, of type int32, giving the index of the graph input +constexpr ge::char_t const *kData = "Data"; + +// Graph output. NetOutput will be replaced by OutputData in the future +// OutputData appears only on the Main graph; after execution completes, the graph outputs are written to OutputData +// OutputData has no inputs and multiple outputs; each output corresponds to the graph output with the same index +constexpr ge::char_t const *kOutputData = "OutputData"; + +// Constant node. It has no inputs and one output, representing the value of the constant +// A constant node has an attribute "value" holding its value; value is a binary blob, and the constant node itself does not care about the format of its content +constexpr ge::char_t const *kConst = "Const"; + +// Constant input node. It has no inputs and one output. +// Its value is not available during lowering; it is passed in externally at load time and does not change during execution +// ConstData has an attribute keyed by type, of type int32, which denotes the kind of ConstData and also its order +// See the ConstDataType enum in the air repository for details +constexpr ge::char_t const *kConstData = "ConstData"; + +inline bool IsTypeData(const ge::char_t
*const node_type) { + return strcmp(kData, node_type) == 0; +} +inline bool IsTypeInnerData(const ge::char_t *const node_type) { + return strcmp(kInnerData, node_type) == 0; +} +inline bool IsTypeInnerNetOutput(const ge::char_t *const node_type) { + return strcmp(kInnerNetOutput, node_type) == 0; +} +inline bool IsTypeNetOutput(const ge::char_t *const node_type) { + return strcmp(kNetOutput, node_type) == 0; +} +inline bool IsTypeConst(const ge::char_t *const node_type) { + return strcmp(kConst, node_type) == 0; +} +inline bool IsTypeConstData(const ge::char_t *const node_type) { + return strcmp(kConstData, node_type) == 0; +} +inline bool IsTypeOutputData(const ge::char_t *const node_type) { + return strcmp(kOutputData, node_type) == 0; +} +} // namespace gert +#endif // METADEF_CXX_INC_EXE_GRAPH_LOWERING_BUILTIN_NODE_TYPES_H_ diff --git a/inc/metadef/exe_graph/lowering/exe_graph_attrs.h b/inc/metadef/exe_graph/lowering/exe_graph_attrs.h new file mode 100644 index 0000000000000000000000000000000000000000..266318ac49caa13bec1c3f23d69540a73dd339f7 --- /dev/null +++ b/inc/metadef/exe_graph/lowering/exe_graph_attrs.h @@ -0,0 +1,70 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef AIR_CXX_RUNTIME_V2_METADEF_EXE_GRAPH_EXE_GRAPH_ATTRS_H_ +#define AIR_CXX_RUNTIME_V2_METADEF_EXE_GRAPH_EXE_GRAPH_ATTRS_H_ +#include "graph/types.h" + +namespace gert { +// 打在const节点上,表达const的值, +// todo 后面这个应该慢慢要废弃掉,通过buffer id代替 +constexpr const ge::char_t *kConstValue = "value"; + +constexpr const ge::char_t *kGraph = "graph"; + +// 打在exe node上,代表本node执行的stage(例如DeInit) +constexpr const ge::char_t *kStage = "stage"; + +// 打在feed节点上(Data、InnerData),代表这是第几个输入 +constexpr const ge::char_t *kFeedIndex = "index"; + +// 打在输出desc上,标识本输出不申请独立的内存,从某个node的某个输出index上引用过来使用 +constexpr const ge::char_t *kRefFromNode = "RefFromNode"; +constexpr const ge::char_t *kRefFromIndex = "RefFromIndex"; + +// 打在exe graph上,保存了本graph涉及的所有的ComputeNodeInfo +constexpr const ge::char_t *kComputeNodeInfo = "ComputeNodeInfo"; + +// 打在exe node上,用来标识本node所对应的计算图上的node的index +constexpr const ge::char_t *kComputeNodeIndex = "ComputeNodeIndex"; + +// 打在exe graph上,保存了本graph涉及的所有的KernelExtendInfo +constexpr const ge::char_t *kKernelExtendInfo = "KernelExtendInfo"; + +// 打在exe node上,用来标识本node所对应的kernel信息的index +constexpr const ge::char_t *kKernelExtendIndex = "KernelExtendInfoIndex"; + +// 打在exe node上,用来标识本node属于某一子图,且该子图内节点源自于子图外部,且对应子图外部节点拥有Guarder +constexpr const ge::char_t *kNodeWithGuarderOutside = "NodeWithGuarderOutside"; + +// 打在exe graph上,保存了本graph涉及的所有的二进制buffer(字符串、const值等) +constexpr const ge::char_t *kBuffer = "buffer"; + +// 打在exe graph上,保存了本graph涉及的ModelDesc信息 +constexpr const ge::char_t *kModelDesc = "ModelDesc"; + +// 打在exe node上,类型是int,代表两层含义:1. 本node释放一个资源;2. 
本node释放的资源位于本node的第n的输入index;n为属性的值 +constexpr ge::char_t kReleaseResourceIndex[] = "ReleaseResourceIndex"; + +// 作为扩展属性打在exe graph上,类型是ge::ComputeGraphPtr,保存的是原来的计算图,未来会删除,因为无法做序列化,执行图序列化反序列化后会丢失该属性 +constexpr const ge::char_t *kComputeGraph = "_compute_graph"; + +// 作为扩展属性打在exe node上,类型是PassChangedKernels,记录执行图经过pass后的新旧exe nodes输出的对应关系 +constexpr const ge::char_t *kPassChangedInfo = "_pass_changed_info"; + +// 打在exe graph上,保存space_registry的智能指针 +constexpr const ge::char_t *kSpaceRegistry = "SpaceRegistry"; + +// 打在exe graph上,保存外置权重文件目录的string +constexpr const ge::char_t *kExternalFileConstantDir = "ExternalFileConstantDir"; + +// 打在exe node上,类型是string,保存它的guarder节点的node type +constexpr const ge::char_t *kGuarderNodeType = "GuarderNodeType"; +} // namespace gert +#endif // AIR_CXX_RUNTIME_V2_METADEF_EXE_GRAPH_EXE_GRAPH_ATTRS_H_ diff --git a/inc/metadef/exe_graph/lowering/exe_res_generation_ctx_builder.h b/inc/metadef/exe_graph/lowering/exe_res_generation_ctx_builder.h new file mode 100644 index 0000000000000000000000000000000000000000..84f12896f947b2cf611fce3ef81fd5d9be87e84b --- /dev/null +++ b/inc/metadef/exe_graph/lowering/exe_res_generation_ctx_builder.h @@ -0,0 +1,31 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_EXE_GRAPH_LOWERING_EXE_RES_GENERATION_CTX_BUILDER_H +#define INC_EXE_GRAPH_LOWERING_EXE_RES_GENERATION_CTX_BUILDER_H + +#include "external/exe_graph/runtime/exe_res_generation_context.h" +#include "graph/node.h" +#include "exe_graph/runtime/kernel_run_context_builder.h" + +namespace gert { +using ExeResGenerationCtxHolderPtr = std::shared_ptr; +class ExeResGenerationCtxBuilder { + public: + ExeResGenerationCtxHolderPtr CreateOpExeContext(ge::Node &node); + private: + void CreateShapesInputs(const ge::Node &node, std::vector &inputs); + private: + ExeResGenerationCtxHolderPtr ctx_holder_ptr_; + std::vector input_shapes_; + std::vector output_shapes_; +}; +} // namespace gert + +#endif diff --git a/inc/metadef/exe_graph/lowering/frame_selector.h b/inc/metadef/exe_graph/lowering/frame_selector.h new file mode 100644 index 0000000000000000000000000000000000000000..14b85806954abaebeac86cebb93178e548c9ac29 --- /dev/null +++ b/inc/metadef/exe_graph/lowering/frame_selector.h @@ -0,0 +1,92 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef METADEF_CXX_INC_EXE_GRAPH_LOWERING_FRAME_SELECTOR_H_
+#define METADEF_CXX_INC_EXE_GRAPH_LOWERING_FRAME_SELECTOR_H_
+#include
+#include "value_holder.h"
+#include "common/checker.h"
+namespace gert {
+namespace bg {
+/**
+ * Frame selector: through it, execute-graph logic can be generated onto a chosen frame.
+ *
+ * The frame selector cannot yet provide OnInitFrame/OnDeInitFrame, because ValueHolder cross-graph connections are
+ * only supported from a parent graph to a subgraph; an edge from Init to a node in the Main graph cannot be handled at present. Init/DeInit support is to be developed on demand.
+ */
+class FrameSelector {
+ public:
+  /**
+   * Selects the Init graph, generates the builder's logic onto it, and returns the output ValueHolders of the **Init node**.
+   * Note: if an output ValueHolder created in the builder has a guarder, that guarder node is automatically moved to the DeInit graph.
+   * @param builder execute-graph building function
+   * @return on success, the ValueHolderPtrs returned by the builder; on failure, an empty vector
+   */
+  static std::vector OnInitRoot(const std::function()> &builder);
+
+  static ge::graphStatus OnInitRoot(const std::function()> &builder,
+                                    std::vector &init_graph_outputs,
+                                    std::vector &init_node_outputs);
+  /**
+   * Selects the DeInit graph, generates the builder's logic onto it, and returns the ValueHolders returned by the builder.
+   * Note: the builder creates ValueHolders without outputs (i.e. via CreateVoid) and returns them.
+   * @param builder execute-graph building function
+   * @return on success, the ValueHolder nodes returned by the builder in the DeInit graph; on failure, an empty vector
+   */
+  static std::vector OnDeInitRoot(const std::function()> &builder);
+
+  /**
+   * Selects the Main graph and generates the builder's logic onto it.
+   *
+   * Note that this function only guarantees that the builder's logic is generated onto the Main graph now; it does not guarantee that it stays there.
+   * After lowering has built the graph, optimizations such as CEM in the graph-optimization phase may move nodes from the Main graph into the Init graph.
+   *
+   * @param builder execute-graph building function
+   * @return on success, the ValueHolderPtrs returned by the builder; on failure, an empty vector
+   */
+  static std::vector OnMainRoot(const std::function()> &builder);
+
+  static ge::graphStatus OnMainRoot(const std::function()> &builder,
+                                    std::vector &outputs);
+  /**
+   * Selects the Main graph, generates the builder's logic onto it, and guarantees that the nodes generated by the builder execute at the very beginning of the Main graph.
+   * For the currently available stages, see the values of the bg::OnMainRootFirstExecStage enum.
+   *
+   * @param builder execute-graph building function
+   * @return on success, the ValueHolderPtrs returned by the builder; on failure, an empty vector
+   */
+  static std::vector OnMainRootFirst(const std::function()> &builder);
+
+  static ValueHolderPtr OnMainRootLast(const std::function &builder);
+
+  /**
+   * Selects the Main graph and generates the builder's logic onto it; the nodes generated by the builder execute in the LastEventSync stage.
+   * For the currently available stages, see the values of the bg::OnMainRootLastExecStage enum.
+   *
+   * @param builder execute-graph building function
+   * @return on success, the ValueHolderPtrs returned by the builder; on failure, an empty vector
+   */
+  static std::vector OnMainRootLastEventSync(
+      const std::function()> &builder);
+
+  /**
+   * Selects the Main graph and generates the builder's logic onto it; the nodes generated by the builder execute in the LastResourceClean stage.
+   * For the currently available stages, see the values of the bg::OnMainRootLastExecStage enum.
+   *
+   * @param builder execute-graph building function
+   * @return on success, the ValueHolderPtrs returned by the builder; on failure, an empty vector
+   */
+  static std::vector OnMainRootLastResourceClean(
+      const std::function()> &builder);
+};
+
+ValueHolderPtr HolderOnInit(const ValueHolderPtr &holder);
+} // namespace bg
+} // namespace gert
+#endif // METADEF_CXX_INC_EXE_GRAPH_LOWERING_FRAME_SELECTOR_H_
diff --git a/inc/metadef/exe_graph/lowering/generate_exe_graph.h b/inc/metadef/exe_graph/lowering/generate_exe_graph.h
new file mode 100644
index 0000000000000000000000000000000000000000..85c5104feb5edfb40bd8cad4981b29ff1b87f7a6
--- /dev/null
+++ b/inc/metadef/exe_graph/lowering/generate_exe_graph.h
@@ -0,0 +1,75 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_LOWERING_GENERATE_EXE_GRAPH_H_ +#define METADEF_CXX_INC_EXE_GRAPH_LOWERING_GENERATE_EXE_GRAPH_H_ +#include + +#include "dev_mem_value_holder.h" +#include "graph/compute_graph.h" +#include "lowering_global_data.h" +namespace gert { +namespace bg { +class GenerateExeGraph { + public: + struct ExeGraphGenerator { + using InferShapeFunc = std::vector (*)(const ge::NodePtr &node, + const std::vector &shapes); + using AllocOutputMemoryFunc = std::vector (*)(TensorPlacement placement, + const ge::NodePtr &node, + const std::vector &output_sizes, + LoweringGlobalData &global_data); + using CalcTensorSizeFunc = std::vector (*)(const ge::NodePtr &node, + const std::vector &output_shapes); + + InferShapeFunc infer_shape; + AllocOutputMemoryFunc alloc_output_memory; + CalcTensorSizeFunc calc_tensor_size; + }; + + public: + static std::vector InferShape(const ge::NodePtr &node, const std::vector &shapes) { + if (generator_.infer_shape == nullptr) { + return {}; + } + return generator_.infer_shape(node, shapes); + } + static std::vector AllocOutputMemory(TensorPlacement placement, const ge::NodePtr &node, + const std::vector &output_sizes, + LoweringGlobalData &global_data) { + if (generator_.alloc_output_memory == nullptr) { + return {}; + } + return generator_.alloc_output_memory(placement, node, output_sizes, global_data); + } + static std::vector CalcTensorSize(const ge::NodePtr &node, + const std::vector &output_shapes) { + if (generator_.calc_tensor_size == nullptr) { + return {}; + } + return generator_.calc_tensor_size(node, output_shapes); + } + + static void AddBuilderImplement(ExeGraphGenerator generator) { + generator_ = generator; + } + + static ValueHolderPtr MakeSureTensorAtHost(const ge::Node *node, LoweringGlobalData &global_data, + const ValueHolderPtr &addr, const ValueHolderPtr &size); + + static 
ValueHolderPtr CalcTensorSizeFromShape(ge::DataType dt, const ValueHolderPtr &shape); + + static ValueHolderPtr FreeMemoryGuarder(const ValueHolderPtr &resource); + + private: + static ExeGraphGenerator generator_; +}; +} // namespace bg +} // namespace gert +#endif // METADEF_CXX_INC_EXE_GRAPH_LOWERING_GENERATE_EXE_GRAPH_H_ diff --git a/inc/metadef/exe_graph/lowering/graph_frame.h b/inc/metadef/exe_graph/lowering/graph_frame.h new file mode 100644 index 0000000000000000000000000000000000000000..029385714f077af32545be34d2998f5e15cc410e --- /dev/null +++ b/inc/metadef/exe_graph/lowering/graph_frame.h @@ -0,0 +1,136 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef AIR_CXX_RUNTIME_V2_METADEF_EXE_GRAPH_GRAPH_FRAME_H_ +#define AIR_CXX_RUNTIME_V2_METADEF_EXE_GRAPH_GRAPH_FRAME_H_ +#include +#include +#include +#include +#include "graph/node.h" +#include "buffer_pool.h" +#include "bg_kernel_context_extend.h" +#include "graph/fast_graph/execute_graph.h" + +namespace gert { +namespace bg { +class ValueHolder; +using ValueHolderPtr = std::shared_ptr; +constexpr const ge::char_t *kStageIdsToLastPartitionedCall = "StageIdsToLastPartitionedCall"; +constexpr const ge::char_t *kStageIdsToFirstPartitionedCall = "StageIdsToFirstPartitionedCall"; +/* + * 执行的阶段, 越小越靠前执行,越大越靠后执行 + */ +enum class OnMainRootLastExecStage { + kFirstStage = 0, + kLastEventSyncStage, + kLastResourceClean, + // add level before this + kStageSize +}; + +/* + * 执行的阶段, 越小越靠前执行,越大越靠后执行 + */ +enum class OnMainRootFirstExecStage { + kFirstEventSyncStage = 0, + // add level before this + kStageSize +}; + +class GraphFrame { + public: + GraphFrame(const GraphFrame &) = delete; + GraphFrame(GraphFrame &&) = delete; + GraphFrame operator=(const GraphFrame &) = delete; + GraphFrame operator=(GraphFrame &&) = delete; + + GraphFrame(ge::ExecuteGraphPtr exe_graph, const GraphFrame &parent_frame) noexcept + : execute_graph_(std::move(exe_graph)), + current_compute_node_and_index_(), root_frame_(parent_frame.root_frame_), + nodes_to_index_(root_frame_.nodes_to_index_), indexes_to_node_(root_frame_.indexes_to_node_), + relevant_input_node_(root_frame_.relevant_input_node_) {} + + explicit GraphFrame(ge::ExecuteGraphPtr exe_graph) noexcept + : execute_graph_(std::move(exe_graph)), + current_compute_node_and_index_(), root_frame_(*this), + nodes_to_index_holder_(), nodes_to_index_(nodes_to_index_holder_), indexes_to_node_holder_(), + indexes_to_node_(indexes_to_node_holder_), relevant_input_node_holder_(), + 
relevant_input_node_(relevant_input_node_holder_) {} + + const ge::NodePtr &GetCurrentComputeNode() const { + return current_compute_node_and_index_.first; + } + void SetCurrentComputeNode(const ge::NodePtr ¤t_node) { + if (current_node == nullptr) { + current_compute_node_and_index_ = {nullptr, 0}; + return; + } + const auto result = nodes_to_index_.emplace(current_node, nodes_to_index_.size()); + current_compute_node_and_index_ = {current_node, result.first->second}; + if (result.second) { + indexes_to_node_.emplace_back(current_node); + } + } + void AddRelevantInputNode(const ge::NodePtr ¤t_node) { + relevant_input_node_.emplace_back(current_node); + } + bool GetCurrentNodeIndex(size_t &index) const { + if (current_compute_node_and_index_.first == nullptr) { + return false; + } + index = current_compute_node_and_index_.second; + return true; + } + + bool IsRootFrame() const { + return &root_frame_ == this; + } + + const ge::ExecuteGraphPtr &GetExecuteGraph() const { + return execute_graph_; + } + + const vector &GetIndexesToNode() const { + return indexes_to_node_; + } + + const std::unordered_map &GetNodesToIndex() const { + return nodes_to_index_; + } + + const std::vector &GetLastExecNodes() const { + return last_exec_nodes_; + } + + /* + * set last exec node, its priority is first level + */ + void SetLastExecNode(const ValueHolderPtr last_exec_node) { + if (last_exec_node != nullptr) { + last_exec_nodes_.emplace_back(last_exec_node); // todo to be deprecated + } + } + + private: + ge::ExecuteGraphPtr execute_graph_; + std::pair current_compute_node_and_index_; + GraphFrame &root_frame_; + std::unordered_map nodes_to_index_holder_; + std::unordered_map &nodes_to_index_; + std::vector indexes_to_node_holder_; + std::vector &indexes_to_node_; + std::vector last_exec_nodes_; // todo to be deprecated + std::vector relevant_input_node_holder_; + std::vector &relevant_input_node_; +}; +} // namespace bg +} // namespace gert + +#endif // 
AIR_CXX_RUNTIME_V2_METADEF_EXE_GRAPH_GRAPH_FRAME_H_
diff --git a/inc/metadef/exe_graph/lowering/lowering_definitions.h b/inc/metadef/exe_graph/lowering/lowering_definitions.h
new file mode 100644
index 0000000000000000000000000000000000000000..aa263424884adbec67acc5f56f16ded0378bfb5a
--- /dev/null
+++ b/inc/metadef/exe_graph/lowering/lowering_definitions.h
@@ -0,0 +1,19 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef AIR_CXX_RUNTIME_V2_LOWERING_LOWERING_DEFINITIONS_H_
+#define AIR_CXX_RUNTIME_V2_LOWERING_LOWERING_DEFINITIONS_H_
+#include "graph/types.h"
+
+namespace gert {
+constexpr const ge::char_t *kLoweringResult = "_lowering_result";
+constexpr const ge::char_t *kLoweringTensorResult = "_lowering_tensor_result";
+constexpr const ge::char_t *kLoweringHostTensorResult = "_lowering_host_tensor_result";
+}  // namespace gert
+#endif // AIR_CXX_RUNTIME_V2_LOWERING_LOWERING_DEFINITIONS_H_
diff --git a/inc/metadef/exe_graph/lowering/lowering_global_data.h b/inc/metadef/exe_graph/lowering/lowering_global_data.h
new file mode 100644
index 0000000000000000000000000000000000000000..c5637ac587f47a687f946840d57fda5db072f62d
--- /dev/null
+++ b/inc/metadef/exe_graph/lowering/lowering_global_data.h
@@ -0,0 +1,179 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef AIR_CXX_RUNTIME_V2_LOWERING_LOWERING_GLOBAL_DATA_H_ +#define AIR_CXX_RUNTIME_V2_LOWERING_LOWERING_GLOBAL_DATA_H_ +#include +#include "proto/task.pb.h" +#include "value_holder.h" +#include "exe_graph/runtime/tensor.h" +#include "exe_graph/runtime/allocator.h" +#include "exe_graph/runtime/execute_graph_types.h" +#include "register/op_impl_space_registry.h" +#include "exe_graph/lowering/lowering_opt.h" + +namespace gert { +constexpr int64_t kRtMemoryTypeHbm = 0x2; +constexpr int64_t kDefaultMainStreamId = 0; +// todo change to get stream num from model_desc const data +constexpr const ge::char_t *kGlobalDataModelStreamNum = "ModelStreamNum"; +class LoweringGlobalData { + public: + struct NodeCompileResult { + const std::vector &GetTaskDefs() const { + return task_defs; + } + std::vector task_defs; + }; + + std::vector LoweringAndSplitRtStreams(int64_t stream_num); + bg::ValueHolderPtr GetStreamById(int64_t logic_stream_id) const; + inline bg::ValueHolderPtr GetStream() const { + int64_t current_stream_id = kDefaultMainStreamId; + if ((bg::ValueHolder::GetCurrentFrame() != nullptr) && + (bg::ValueHolder::GetCurrentFrame()->GetCurrentComputeNode() != nullptr)) { + current_stream_id = bg::ValueHolder::GetCurrentFrame()->GetCurrentComputeNode()->GetOpDesc()->GetStreamId(); + } + return 
GetStreamById(current_stream_id); + } + + void SetRtNotifies(const std::vector ¬ify_holders); + bg::ValueHolderPtr GetNotifyById(int64_t logic_notify_id) const; + + const NodeCompileResult *FindCompiledResult(const ge::NodePtr &node) const; + LoweringGlobalData &AddCompiledResult(const ge::NodePtr &node, NodeCompileResult compile_result); + + void *GetGraphStaticCompiledModel(const std::string &graph_name) const; + LoweringGlobalData &AddStaticCompiledGraphModel(const std::string &graph_name, void *const model); + + bg::ValueHolderPtr GetL1Allocator(const AllocatorDesc &desc) const; + LoweringGlobalData &SetExternalAllocator(bg::ValueHolderPtr &&allocator); + LoweringGlobalData &SetExternalAllocator(bg::ValueHolderPtr &&allocator, const ExecuteGraphType graph_type); + + bg::ValueHolderPtr GetOrCreateL1Allocator(const AllocatorDesc desc); + bg::ValueHolderPtr GetOrCreateL2Allocator(int64_t logic_stream_id, const AllocatorDesc desc); + bg::ValueHolderPtr GetInitL2Allocator(const AllocatorDesc desc) const; + bg::ValueHolderPtr GetMainL2Allocator(int64_t logic_stream_id, const AllocatorDesc desc) const; + inline bg::ValueHolderPtr GetOrCreateAllocator(const AllocatorDesc desc) { + int64_t current_stream_id = kDefaultMainStreamId; + if ((bg::ValueHolder::GetCurrentFrame() != nullptr) && + (bg::ValueHolder::GetCurrentFrame()->GetCurrentComputeNode() != nullptr)) { + current_stream_id = bg::ValueHolder::GetCurrentFrame()->GetCurrentComputeNode()->GetOpDesc()->GetStreamId(); + } + return GetOrCreateL2Allocator(current_stream_id, desc); + } + bg::ValueHolderPtr GetOrCreateAllL2Allocators(); + + bg::ValueHolderPtr GetOrCreateUniqueValueHolder(const std::string &name, + const std::function &builder); + std::vector GetOrCreateUniqueValueHolder(const std::string &name, + const std::function()> &builder); + bg::ValueHolderPtr GetUniqueValueHolder(const std::string &name) const; + void SetUniqueValueHolder(const std::string &name, const bg::ValueHolderPtr &holder); + void 
SetValueHolders(const string &name, const bg::ValueHolderPtr &holder); + size_t GetValueHoldersSize(const string &name); + + void SetModelWeightSize(const size_t require_weight_size); + size_t GetModelWeightSize() const; + const OpImplSpaceRegistryArray &GetSpaceRegistries() { + return space_registries_; + } + const OpImplSpaceRegistryPtr GetSpaceRegistry(ge::OppImplVersion opp_impl_version = ge::OppImplVersion::kOpp) const { + if (opp_impl_version >= ge::OppImplVersion::kVersionEnd) { + return nullptr; + } + return space_registries_[static_cast(opp_impl_version)]; + }; + void SetSpaceRegistries(const gert::OpImplSpaceRegistryArray &space_registries) { + space_registries_ = space_registries; + } + // 兼容air, 随后air合入后删除 + void SetSpaceRegistry(gert::OpImplSpaceRegistryPtr space_registry) { + space_registries_[static_cast(ge::OppImplVersion::kOpp)] = space_registry; + } + + const LoweringOption &GetLoweringOption() const; + void SetLoweringOption(const LoweringOption &lowering_option); + + void SetStaicModelWsSize(const int64_t require_ws_size) { + static_model_ws_size = require_ws_size; + } + + int64_t GetStaticModelWsSize() const { + return static_model_ws_size; + } + + void SetFixedFeatureMemoryBase(const void * const memory, const size_t size) { + fixed_feature_mem_[kRtMemoryTypeHbm] = std::make_pair(memory, size); + } + + const std::pair &GetFixedFeatureMemoryBase() const { + const auto iter = fixed_feature_mem_.find(kRtMemoryTypeHbm); + if (iter != fixed_feature_mem_.end()) { + return iter->second; + } + static std::pair dummy_result; + return dummy_result; + } + + void SetFixedFeatureMemoryBase(const int64_t type, const void * const memory, const size_t size) { + fixed_feature_mem_[type] = std::make_pair(memory, size); + } + + /* + * 获取图所需fixed feature memory地址和长度 + * 1 地址为nullptr,长度为0:用户设置的结果,比较特殊,表示不需要GE默认申请fixed内存 + * 2 地址为nullptr,长度不为0,表示需要fixed内存,但是用户没有设置。GE要默认申请fixed内存 + * 3 地址不为nullptr,长度不为0,用户设置的结果。 + */ + const std::map> 
&GetAllTypeFixedFeatureMemoryBase() const { + return fixed_feature_mem_; + } + + bool IsSingleStreamScene() const { + return is_single_stream_scene_; + } + + void SetHostResourceCenter(void *host_resource_center_ptr) { + host_resource_center_ = host_resource_center_ptr; + } + void *GetHostResourceCenter() { + return host_resource_center_; + } + + private: + struct HolderByGraphs { + bg::ValueHolderPtr holders[static_cast(ExecuteGraphType::kNum)]; + }; + struct HoldersByGraphs { + std::vector holders[static_cast(ExecuteGraphType::kNum)]; + }; + + bg::ValueHolderPtr GetOrCreateInitL2Allocator(const AllocatorDesc desc); + bg::ValueHolderPtr GetExternalAllocator(const bool from_init, const string &key, const AllocatorDesc &desc); + bool CanUseExternalAllocator(const ExecuteGraphType &graph_type, const TensorPlacement placement) const; + private: + std::unordered_map node_name_to_compile_result_holders_; + std::map graph_to_static_models_; + std::unordered_map> unique_name_to_value_holders_; + HoldersByGraphs streams_; + HoldersByGraphs notifies_; + HolderByGraphs external_allocators_; + // todo need delete and change to const_data after const_data is ready + int64_t model_weight_size_; + int64_t static_model_ws_size; + OpImplSpaceRegistryArray space_registries_; + LoweringOption lowering_option_; + // addr为nullptr,但size不为0,表示用户没有设置fixed内存,需要GE默认申请fixed内存 + std::map> fixed_feature_mem_; + bool is_single_stream_scene_{true}; + void *host_resource_center_{nullptr}; +}; +} // namespace gert +#endif // AIR_CXX_RUNTIME_V2_LOWERING_LOWERING_GLOBAL_DATA_H_ diff --git a/inc/metadef/exe_graph/lowering/value_holder.h b/inc/metadef/exe_graph/lowering/value_holder.h new file mode 100644 index 0000000000000000000000000000000000000000..a14a5e8e3d14f487571be01ffaa9a82f60ac32b5 --- /dev/null +++ b/inc/metadef/exe_graph/lowering/value_holder.h @@ -0,0 +1,225 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef AIR_CXX_RUNTIME_V2_GRAPH_BUILDER_VALUE_HOLDER_H_ +#define AIR_CXX_RUNTIME_V2_GRAPH_BUILDER_VALUE_HOLDER_H_ +#include +#include +#include +#include + +#include "graph/buffer.h" +#include "graph/any_value.h" +#include "graph/compute_graph.h" +#include "graph/utils/node_utils.h" +#include "graph/utils/fast_node_utils.h" +#include "graph/node.h" +#include "common/hyper_status.h" +#include "graph_frame.h" +#include "exe_graph/runtime/tensor.h" +#include "common/checker.h" +#include "common/util/mem_utils.h" +#include "graph/fast_graph/execute_graph.h" + +namespace gert { +namespace bg { +class ValueHolder; +using ValueHolderPtr = std::shared_ptr; +class ValueHolder { + public: + enum class ValueHolderType { + kConst, // 常量,执行时不变 + kFeed, // 执行时外部指定 + kOutput, // 由node产生,包含数据输出与控制输出 + kConstData, // 常量Const,执行时由外部指定,执行时不变 + // Add new type definitions here + kValueHolderTypeEnd + }; + + class CurrentComputeNodeGuarder { + public: + explicit CurrentComputeNodeGuarder(ge::NodePtr old_node) : old_node_(std::move(old_node)) {} + ~CurrentComputeNodeGuarder() { + ValueHolder::SetCurrentComputeNode(old_node_); + } + + private: + ge::NodePtr old_node_; + }; + + ValueHolder(const ValueHolder &other) = delete; + ValueHolder &operator=(const ValueHolder &other) = delete; + virtual ~ValueHolder(); + + bool IsOk() const noexcept; + + HyperStatus 
AddInnerDataToKVMap(int32_t index) const noexcept; + + int64_t GetId() const noexcept; + ValueHolderType GetType() const noexcept; + + ge::FastNode *GetFastNode() const noexcept; + ge::ExecuteGraph *GetExecuteGraph() const noexcept; + + ValueHolderPtr GetGuarder() const noexcept; + void SetGuarder(const bg::ValueHolderPtr &guarder) noexcept; + + int32_t GetOutIndex() const noexcept; + + // ref-from other的含义是,本value指向了other(本value没有独立的内存) + ge::graphStatus RefFrom(const ValueHolderPtr &other); + + // 在other产生后,本holder的生命周期才结束 + void ReleaseAfter(const ValueHolderPtr &other); + + const int32_t &GetPlacement() const; + void SetPlacement(const int32_t &placement); + + template + std::vector> AppendOutputs(size_t append_count, Args... args) { + auto start_index = fast_node_->GetDataOutNum(); + auto ret = ge::FastNodeUtils::AppendOutputEdgeInfo(fast_node_, start_index + append_count); + if (ret != ge::GRAPH_SUCCESS) { + return {}; + } + return CreateFromNode(fast_node_, start_index, append_count, args...); + } + // src nodes may come from different graph from current node and can add data edges to current node + // currently only support to pass through parent nodes with only one subgraph + ge::graphStatus AppendInputs(const std::vector &src); + + static ValueHolderPtr CreateError(const ge::char_t *fmt, ...); + static ValueHolderPtr CreateError(const ge::char_t *fmt, va_list arg); + + static ValueHolderPtr CreateConst(const void *data, size_t size, bool is_string = false); + + static ValueHolderPtr CreateFeed(int64_t index); + + static ValueHolderPtr CreateConstData(int64_t index); + + static ValueHolderPtr CreateSingleDataOutput(const ge::char_t *node_type, const std::vector &inputs); + + static std::vector CreateDataOutput(const ge::char_t *node_type, + const std::vector &inputs, size_t out_count); + + template + static std::shared_ptr CreateVoid(const ge::char_t *node_type, const std::vector &inputs, + Args... 
args) { + auto node = CreateNode(node_type, inputs, 0); + GE_ASSERT_NOTNULL(node); + return CreateFromNode(node, -1, ValueHolderType::kOutput, args...); + } + + static ValueHolderPtr CreateVoidGuarder(const ge::char_t *node_type, const ValueHolderPtr &resource, + const std::vector &args); + + static HyperStatus AddDependency(const ValueHolderPtr &src, const ValueHolderPtr &dst); + + /** + * 压栈一个Root GraphFrame,只有栈底的GraphFrame才被称为ROOT GraphFrame,因此调用此借口前,需要保证栈内不存在GraphFrame,否则会失败 + * @return 成功后,返回创建好的GraphFrame指针,失败时返回空指针 + */ + static GraphFrame *PushGraphFrame(); + /** + * 压栈一个非root的GraphFrame + * @param belongs 新加入的GraphFrame所归属的ValueHolder,新压栈的GraphFrame会被挂在该ValueHolder所归属的Node上 + * @param graph_name 挂接GraphFrame到Node时,使用的name + * @return 创建且挂接成功后,返回创建好的GraphFrame指针,失败时返回空指针 + */ + static GraphFrame *PushGraphFrame(const ValueHolderPtr &belongs, const ge::char_t *graph_name); + /** + * 压栈一个GraphFrame, 若该graph frame非root frame,需要保证栈顶frame为其父frame + * @return 成功后,返回该GraphFrame指针,失败时返回空指针 + */ + static GraphFrame *PushGraphFrame(GraphFrame *graph_frame); + + static std::unique_ptr PopGraphFrame(); + static std::unique_ptr PopGraphFrame(const std::vector &outputs, + const std::vector &targets); + + static std::unique_ptr PopGraphFrame(const std::vector &outputs, + const std::vector &targets, + const ge::char_t *out_node_type); + + static GraphFrame *GetCurrentFrame(); + + static void ClearGraphFrameResource(); + + static ge::ExecuteGraph *GetCurrentExecuteGraph(); + + static void SetCurrentComputeNode(const ge::NodePtr &node); + static void AddRelevantInputNode(const ge::NodePtr &node); + static std::unique_ptr SetScopedCurrentComputeNode(const ge::NodePtr &node); + + static ge::FastNode *AddNode(const ge::char_t *node_type, size_t input_count, size_t output_count, + const GraphFrame &frame); + + template + static std::vector> CreateFromNode(ge::FastNode *node, size_t start_index, + size_t create_count, Args... 
args) { + if (node == nullptr) { + return {create_count, nullptr}; + } + std::vector> holders; + for (size_t i = 0; i < create_count; ++i) { + holders.emplace_back( + CreateFromNode(node, static_cast(i + start_index), ValueHolderType::kOutput, args...)); + } + + return holders; + } + + template + static std::shared_ptr CreateFromNode(ge::FastNode *node, int32_t index, ValueHolderType type, Args... args) { + auto holder = std::shared_ptr(new (std::nothrow) T(args...)); + GE_ASSERT_NOTNULL(holder); + + holder->type_ = type; + holder->fast_node_ = node; + holder->index_ = index; + holder->op_desc_ = holder->fast_node_->GetOpDescPtr(); + return holder; + } + + virtual ValueHolderPtr CreateMateFromNode(ge::FastNode *node, int32_t index, ValueHolderType type); + + static std::string GenerateNodeName(const ge::char_t *node_type, const GraphFrame &frame); + + static std::vector GetLastExecNodes(); + + protected: + ValueHolder(); + + static ge::FastNode *CreateNode(const ge::char_t *node_type, const std::vector &inputs, + size_t out_count); + + template + static std::vector> CreateFromNodeStart(ge::FastNode *node, size_t out_count, + Args... 
args) { + return CreateFromNode(node, 0U, out_count, args...); + } + + void SetErrorMsg(const char *fmt, va_list arg); + + private: + static std::atomic id_generator_; + int64_t id_; + ValueHolder::ValueHolderType type_; + ge::FastNode *fast_node_; // 通过ValueHolder创建的fast_node节点如果后续在图中被删除,此处会是无效指针,不能直接使用 + ge::OpDescPtr op_desc_; + int32_t index_; + int32_t placement_; + std::unique_ptr error_msg_; + ValueHolderPtr guarder_; + friend class ValueHolderUtils; +}; +} // namespace bg +} // namespace gert + +#endif // AIR_CXX_RUNTIME_V2_GRAPH_BUILDER_VALUE_HOLDER_H_ diff --git a/inc/metadef/exe_graph/runtime/allocator.h b/inc/metadef/exe_graph/runtime/allocator.h new file mode 100644 index 0000000000000000000000000000000000000000..9d7c44e49633879ba1edbebc14ffe78241f70ba7 --- /dev/null +++ b/inc/metadef/exe_graph/runtime/allocator.h @@ -0,0 +1,33 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef METADEF_INC_EXE_GRAPH_RUNTIME_ALLOCATOR_H_ +#define METADEF_INC_EXE_GRAPH_RUNTIME_ALLOCATOR_H_ +#include +#include "exe_graph/runtime/tensor.h" + +namespace gert { +enum class AllocatorUsage { + kAllocNodeOutput, + kAllocNodeWorkspace, + kAllocNodeShapeBuffer, + kEnd +}; +struct AllocatorDesc { + TensorPlacement placement; + AllocatorUsage usage; + bool operator<(const AllocatorDesc &other) const { + return std::tie(placement, usage) < std::tie(other.placement, other.usage); + } + std::string GetKey() const { + return "Allocator-" + std::to_string(placement); + } +}; +} +#endif // METADEF_INC_EXE_GRAPH_RUNTIME_ALLOCATOR_H_ diff --git a/inc/metadef/exe_graph/runtime/atomic_clean_tiling_context.h b/inc/metadef/exe_graph/runtime/atomic_clean_tiling_context.h new file mode 100644 index 0000000000000000000000000000000000000000..1e7f4db0853837a4869cb6897a498e4d06f83402 --- /dev/null +++ b/inc/metadef/exe_graph/runtime/atomic_clean_tiling_context.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/
+
+#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_ATOMICCLEANTILINGCONTEXT_H_
+#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_ATOMICCLEANTILINGCONTEXT_H_
+#include "exe_graph/runtime/tiling_context.h"
+#include "exe_graph/runtime/continuous_vector.h"
+namespace gert {
+class AtomicCleanTilingContext : public TilingContext {
+ public:
+  /**
+   * Gets the list of workspace sizes (read from input 0)
+   * @return the list of workspace sizes
+   */
+  const ContinuousVector *GetCleanWorkspaceSizes() const {
+    return GetInputPointer(0);
+  }
+
+  /**
+   * Gets the size of the output memory to clean, by the node's output index (read from input index + 1)
+   * @param index output index of the node
+   * @return size of the output memory to clean
+   */
+  uint64_t GetCleanOutputSize(size_t index) const {
+    return GetInputValue(index + 1U);
+  }
+};
+static_assert(std::is_standard_layout::value,
+              "The class AtomicCleanTilingContext must be a POD");
+} // namespace gert
+#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_ATOMICCLEANTILINGCONTEXT_H_
diff --git a/inc/metadef/exe_graph/runtime/continuous_buffer.h b/inc/metadef/exe_graph/runtime/continuous_buffer.h
new file mode 100644
index 0000000000000000000000000000000000000000..69df5e9a6807715a38174e9eafa815c48698485c
--- /dev/null
+++ b/inc/metadef/exe_graph/runtime/continuous_buffer.h
@@ -0,0 +1,85 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_CONTINUOUS_BUFFER_H_ +#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_CONTINUOUS_BUFFER_H_ +#include +#include +#include "graph/def_types.h" +namespace gert { +namespace bg { +class BufferPool; +} +class ContinuousBuffer { + public: + /** + * 获取buffer的数量 + * @return buffer的数量 + */ + size_t GetNum() const { + return num_; + } + /** + * 获取本实例的总长度 + * @return 本实例的总长度,单位为字节 + */ + size_t GetTotalLength() const { + return offsets_[num_]; + } + /** + * 获取一个buffer + * @tparam T buffer的类型 + * @param index buffer的index + * @return 指向该buffer的指针,若index非法,则返回空指针 + */ + template + const T *Get(size_t index) const { + if (index >= num_) { + return nullptr; + } + return ge::PtrToPtr(ge::PtrToPtr(this) + offsets_[index]); + } + /** + * 获取一个buffer,及其对应的长度 + * @tparam T buffer的类型 + * @param index buffer的index + * @param len buffer的长度 + * @return 指向该buffer的指针,若index非法,则返回空指针 + */ + template + const T *Get(size_t index, size_t &len) const { + if (index >= num_) { + return nullptr; + } + len = offsets_[index + 1] - offsets_[index]; + return ge::PtrToPtr(ge::PtrToPtr(this) + offsets_[index]); + } + /** + * 获取一个buffer + * @tparam T buffer的类型 + * @param index buffer的index + * @return 指向该buffer的指针,若index非法,则返回空指针 + */ + template + T *Get(size_t index) { + if (index >= num_) { + return nullptr; + } + return ge::PtrToPtr(ge::PtrToPtr(this) + offsets_[index]); + } + + private: + friend ::gert::bg::BufferPool; + size_t num_; + int64_t reserved_; // Reserved field, 8-byte aligned + size_t offsets_[1]; +}; +static_assert(std::is_standard_layout::value, "The class ContinuousText must be POD"); +} // namespace gert +#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_CONTINUOUS_BUFFER_H_ diff --git a/inc/metadef/exe_graph/runtime/dfx_info_filler.h b/inc/metadef/exe_graph/runtime/dfx_info_filler.h new file mode 100644 index 
0000000000000000000000000000000000000000..d620c407a7d0962b25959e7fcbeeb36fa1054702 --- /dev/null +++ b/inc/metadef/exe_graph/runtime/dfx_info_filler.h @@ -0,0 +1,84 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_INC_EXE_GRAPH_RUNTIME_DFX_INFO_FILLER_H +#define METADEF_INC_EXE_GRAPH_RUNTIME_DFX_INFO_FILLER_H + +#include +#include +#include + +namespace gert { +enum class NodeProfInfoType : uint32_t { + kOriginalNode, + kCmoPreFetch, + kCmoInvalidate, + kCmoWriteBack, + kMixVectorCore, + kNodeTypeMax +}; + +class ProfilingInfoWrapper { + public: + virtual ~ProfilingInfoWrapper() = default; + + virtual void SetBlockDim(uint32_t block_dim) { + (void)block_dim; + } + + virtual void SetBlockDim(const uint32_t block_dim, const NodeProfInfoType prof_info_type) { + (void) block_dim; + (void) prof_info_type; + } + + virtual void SetMixLaunchEnable(const bool mix_launch_enable) { + (void) mix_launch_enable; + } + + virtual void SetLaunchTimeStamp(const uint64_t begin_time, const uint64_t end_time, + const NodeProfInfoType prof_info_type) { + (void) begin_time; + (void) end_time; + (void) prof_info_type; + } + + virtual void SetBlockDimForAtomic(uint32_t block_dim) { + (void)block_dim; + } + + virtual ge::graphStatus FillShapeInfo(const std::vector> &input_shapes, + const std::vector> 
&output_shapes) { + (void)input_shapes; + (void)output_shapes; + return ge::GRAPH_SUCCESS; + } +}; + +class DataDumpInfoWrapper { + public: + virtual ~DataDumpInfoWrapper() = default; + virtual ge::graphStatus CreateFftsCtxInfo(uint32_t thread_id, uint32_t context_id) = 0; + virtual ge::graphStatus AddFftsCtxAddr(uint32_t thread_id, bool is_input, uint64_t address, uint64_t size) = 0; + virtual void AddWorkspace(uintptr_t addr, int64_t bytes) = 0; + virtual bool SetStrAttr(const std::string &name, const std::string &value) = 0; +}; + +class ExceptionDumpInfoWrapper { + public: + virtual ~ExceptionDumpInfoWrapper() = default; + virtual void SetTilingData(uintptr_t addr, size_t size) = 0; + virtual void SetTilingKey(uint32_t key) = 0; + virtual void SetHostArgs(uintptr_t addr, size_t size) = 0; + virtual void SetDeviceArgs(uintptr_t addr, size_t size) = 0; + virtual void AddWorkspace(uintptr_t addr, int64_t bytes) = 0; +}; +} + +#endif + diff --git a/inc/metadef/exe_graph/runtime/dvpp_context.h b/inc/metadef/exe_graph/runtime/dvpp_context.h new file mode 100644 index 0000000000000000000000000000000000000000..e4f6667fbacad5ad53a022565dcf404b77357f6b --- /dev/null +++ b/inc/metadef/exe_graph/runtime/dvpp_context.h @@ -0,0 +1,75 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_DVPP_CONTEXT_H_ +#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_DVPP_CONTEXT_H_ +#include +#include "exe_graph/runtime/storage_shape.h" +#include "exe_graph/runtime/tensor.h" +#include "exe_graph/runtime/extended_kernel_context.h" + +namespace gert { +/** + * Dvpp kernel的context + */ +class DvppContext : public ExtendedKernelContext { +public: + /** + * 获取输入shape,输入shape中包含了原始shape与运行时shape + * @param index 输入index + * @return 输入shape指针,index非法时返回空指针 + */ + const StorageShape *GetInputShape(size_t index) const { + auto compute_node_info = GetComputeNodeInfo(); + if (compute_node_info == nullptr) { + return nullptr; + } + + if (index >= compute_node_info->GetInputsNum()) { + return nullptr; + } + + return GetInputPointer(index); + } + + /** + * 获取输入tensor + * + * **注意:只有在`IMPL_OP`实现算子时, 将对应输入设置为数据依赖后, + * 才可以调用此接口获取tensor,否则行为是未定义的。** + * @param index 输入index + * @return 输入tensor指针,index非法时返回空指针 + */ + const Tensor *GetInputTensor(size_t index) const { + return GetInputPointer(index); + } + + /** + * 根据输出index,获取输出shape指针,shape中包含了原始shape与运行时shape + * @param index 输出index + * @return 输出shape指针,index非法时,返回空指针 + */ + const StorageShape *GetOutputShape(size_t index) const { + auto compute_node_info = GetComputeNodeInfo(); + if (compute_node_info == nullptr) { + return nullptr; + } + + if (index >= compute_node_info->GetOutputsNum()) { + return nullptr; + } + + size_t offset = compute_node_info->GetInputsNum(); + return GetInputPointer(offset + index); + } +}; +static_assert(std::is_standard_layout::value, + "The class DvppContext must be a POD"); +} // namespace gert +#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_DVPP_CONTEXT_H_ diff --git a/inc/metadef/exe_graph/runtime/execute_graph_types.h b/inc/metadef/exe_graph/runtime/execute_graph_types.h new file mode 100644 index 
0000000000000000000000000000000000000000..6ee4d6c35fd4f0c9edea39e16a23311cba368e34 --- /dev/null +++ b/inc/metadef/exe_graph/runtime/execute_graph_types.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_EXECUTE_GRAPH_TYPES_H_ +#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_EXECUTE_GRAPH_TYPES_H_ +#include +namespace gert { +/** + * 执行图类型,在一个Model中,包含多张执行图,本枚举定义了所有执行图的类型 + */ +enum class ExecuteGraphType { + kInit, //!< 初始化图,本张图在图加载阶段执行 + kMain, //!< 主图,每次执行图时,均执行本张图 + kDeInit, //!< 去初始化图,在图卸载时,执行本张图 + kNum +}; + +/** + * 获取执行图的字符串描述 + * @param type 执行图类型枚举 + * @return + */ +inline const char *GetExecuteGraphTypeStr(const ExecuteGraphType type) { + if (type >= ExecuteGraphType::kNum) { + return nullptr; + } + constexpr const char *kStrs[static_cast(ExecuteGraphType::kNum)] = {"Init", "Main", "DeInit"}; + return kStrs[static_cast(type)]; +} +} // namespace gert +#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_EXECUTE_GRAPH_TYPES_H_ diff --git a/inc/metadef/exe_graph/runtime/getcdim.h b/inc/metadef/exe_graph/runtime/getcdim.h new file mode 100644 index 0000000000000000000000000000000000000000..a8f37aac4e62a8fc2d6d976ce94b519ad6f25f9a --- /dev/null +++ b/inc/metadef/exe_graph/runtime/getcdim.h @@ -0,0 +1,18 @@ +/* Copyright (c) 2024 Huawei Technologies 
Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_GETCDIM_H_ +#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_GETCDIM_H_ +#include "exe_graph/runtime/infer_shape_context.h" +#include "exe_graph/runtime/tiling_context.h" +namespace gert { + int64_t GetInputCDim(gert::TilingContext *kernel_context, const size_t index); + int64_t GetOutputCDim(gert::TilingContext *kernel_context, const size_t index); +} // namespace gert +#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_TILING_PARSE_CONTEXT_H_ diff --git a/inc/metadef/exe_graph/runtime/shape_utils.h b/inc/metadef/exe_graph/runtime/shape_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..712b2c7d17edb027cf532a66cf364c68f738ab98 --- /dev/null +++ b/inc/metadef/exe_graph/runtime/shape_utils.h @@ -0,0 +1,55 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_SHAPE_UTILS_H_ +#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_SHAPE_UTILS_H_ +#include +#include +#include "exe_graph/runtime/shape.h" +#include "graph/ge_error_codes.h" +#include "graph/types.h" + +namespace gert { +extern const Shape g_vec_1_shape; +/** + * 确保返回的shape是非scalar的。 + * 当一个shape的dim num为0时,此shape被认为表达了一个scalar。 + * 本函数在接受一个非scalar的shape时,会返回原有shape;在接收到scalar shape时,会返回返回一个{1}的vector shape + * @param in_shape 输入shape + * @return 保证非scalar的shape + */ +inline const Shape &EnsureNotScalar(const Shape &in_shape) { + if (in_shape.IsScalar()) { + return g_vec_1_shape; + } + return in_shape; +} +/** + * 返回shape的字符串,本函数性能较低,不可以在执行时的正常流程中使用 + * @param shape 需要转为字符串的shape实例 + * @param join_char 每个Dim的间隔,默认为`,` + * @return 转好的字符串 + */ +inline std::string ShapeToString(const Shape &shape, const char *join_char = ",") { + if (join_char == nullptr) { + join_char = ","; + } + std::stringstream ss; + for (size_t i = 0U; i < shape.GetDimNum(); ++i) { + if (i > 0U) { + ss << join_char; + } + ss << shape[i]; + } + return ss.str(); +} + +ge::graphStatus CalcAlignedSizeByShape(const Shape &shape, ge::DataType data_type, uint64_t &ret_tensor_size); +} // namespace gert +#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_SHAPE_UTILS_H_ diff --git a/inc/metadef/exe_graph/runtime/tiling_context_builder.h b/inc/metadef/exe_graph/runtime/tiling_context_builder.h new file mode 100644 index 0000000000000000000000000000000000000000..feda6882b389ecf9377104ddb011066be907bf9e --- /dev/null +++ b/inc/metadef/exe_graph/runtime/tiling_context_builder.h @@ -0,0 +1,70 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef GE_COMMMON_RUNTIME_TILING_KERNEL_CONTEXT_BUILDER_H_ +#define GE_COMMMON_RUNTIME_TILING_KERNEL_CONTEXT_BUILDER_H_ + +#include "graph/node.h" +#include "exe_graph/runtime/compute_node_info.h" +#include "exe_graph/runtime/kernel_context.h" +#include "exe_graph/lowering/buffer_pool.h" +#include "exe_graph/runtime/tiling_context.h" +#include "exe_graph/runtime/kernel_run_context_builder.h" +#include "register/op_impl_space_registry.h" + +namespace gert { +class TilingContextBuilder { + public: + TilingContextBuilder &CompileInfo(void *compile_info); + TilingContextBuilder &Deterministic(int32_t deterministic); + TilingContextBuilder &PlatformInfo(void *platform_info); + TilingContextBuilder &TilingData(void *tiling_data); + TilingContextBuilder &Workspace(ContinuousVector *workspace); + // 兼容air, 随后air合入后删除 + TilingContextBuilder &SpaceRegistry(const gert::OpImplSpaceRegistryPtr &space_registry); + TilingContextBuilder &SpaceRegistries(const gert::OpImplSpaceRegistryArray &space_registries); + KernelContextHolder Build(const ge::Operator &op); // deprecated later + KernelContextHolder Build(const ge::Operator &op, ge::graphStatus &ret); + + private: + ge::graphStatus GetDependInputTensorAddr(const ge::Operator &op, const size_t input_idx, TensorAddress &address); + ge::graphStatus BuildRtTensor(const ge::GeTensorDesc &tensor_desc, 
ConstTensorAddressPtr address, + std::unique_ptr &rt_tensor_holder) const; + ge::graphStatus BuildRTInputTensors(const ge::Operator &op); + ge::graphStatus BuildRTOutputShapes(const ge::Operator &op); + + void *compile_info_{nullptr}; + void *platform_info_{nullptr}; + int32_t deterministic_; + std::vector> depend_ge_tensor_holders_; + std::vector> rt_tensor_holders_; + std::vector outputs_ {TilingContext::kOutputNum}; + KernelRunContextBuilder base_builder_; + OpImplSpaceRegistryArray space_registries_; +}; + +class AtomicTilingContextBuilder { + public: + AtomicTilingContextBuilder &CompileInfo(void *compile_info); + AtomicTilingContextBuilder &CleanWorkspaceSizes(ContinuousVector *workspace_sizes); + AtomicTilingContextBuilder &CleanOutputSizes(const std::vector &output_sizes); + AtomicTilingContextBuilder &TilingData(void *tiling_data); + AtomicTilingContextBuilder &Workspace(ContinuousVector *workspace); + KernelContextHolder Build(const ge::Operator &op); // deprecated later + KernelContextHolder Build(const ge::Operator &op, ge::graphStatus &ret); + + private: + void *compile_info_{nullptr}; + void *worksapce_sizes_{nullptr}; + std::vector clean_output_sizes_; + std::vector outputs_ {TilingContext::kOutputNum}; + KernelRunContextBuilder base_builder_; +}; +} // namespace gert +#endif // GE_COMMMON_RUNTIME_TILING_KERNEL_CONTEXT_BUILDER_H_ diff --git a/inc/metadef/exe_graph/runtime/tiling_parse_context_builder.h b/inc/metadef/exe_graph/runtime/tiling_parse_context_builder.h new file mode 100644 index 0000000000000000000000000000000000000000..15df28291a6e7e905864bf58dcfdc099d4cba095 --- /dev/null +++ b/inc/metadef/exe_graph/runtime/tiling_parse_context_builder.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. 
You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef GE_RUNTIME_TILING_PARSE_CONTEXT_BUILDER_H_ +#define GE_RUNTIME_TILING_PARSE_CONTEXT_BUILDER_H_ + +#include "exe_graph/runtime/kernel_context.h" +#include "graph/operator.h" +#include "exe_graph/runtime/kernel_run_context_builder.h" +#include "register/op_impl_registry.h" + +namespace gert { +class TilingParseContextBuilder { + public: + TilingParseContextBuilder &CompileJson(const ge::char_t *compile_json); + TilingParseContextBuilder &PlatformInfo(void *platform_info); + TilingParseContextBuilder &CompileInfoCreatorFunc(OpImplRegisterV2::CompileInfoCreatorFunc create_func); + TilingParseContextBuilder &CompileInfoDeleterFunc(OpImplRegisterV2::CompileInfoDeleterFunc delete_func); + KernelContextHolder Build(const ge::Operator &op); + + private: + void *compile_json_{ nullptr }; + void *platform_info_{ nullptr }; + OpImplRegisterV2::CompileInfoCreatorFunc create_func_{ nullptr }; + OpImplRegisterV2::CompileInfoDeleterFunc delete_func_{ nullptr }; +}; +} // namespace gert +#endif // GE_RUNTIME_TILING_PARSE_CONTEXT_BUILDER_H_ diff --git a/inc/metadef/external/exe_graph/lowering/lowering_opt.h b/inc/metadef/external/exe_graph/lowering/lowering_opt.h new file mode 100644 index 0000000000000000000000000000000000000000..f6de626de0d7e71b053fa293a2058c33816480e1 --- /dev/null +++ b/inc/metadef/external/exe_graph/lowering/lowering_opt.h @@ -0,0 +1,63 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_LOWERING_OPT_H_ +#define INC_EXTERNAL_LOWERING_OPT_H_ + +#include + +namespace gert { +struct LoweringOption { + /** + * 是否相信用户传入的输出tensor上的shape,如果开启了本选项,可以省去计算图上输出节点的InferShape,提升一点Host调度性能。 + * 与此同时,也会损失掉对外部传入的输出Tensor Shape、TensorData长度的校验能力。 + * + * 约束: + * 1. 如果一个节点有多个输出,并且部分输出并不是网络的输出, + * 那么这个节点的InferShape不会被省掉,体现为在这个节点上,本选项会被忽略。 + * 2. 如果一个节点没有InferShape函数,例如第三类、第四类算子, + * 需要从Device拷贝回Shape,那么在这个节点上,本选项会被忽略。 + * 3. 本选项是个加载时选项,一旦选定后,意味着后续本model的每次调用都需要用户传入输出shape,否则可能会导致执行失败 + */ + bool trust_shape_on_out_tensor = false; + + /** + * 总是零拷贝开关,默认关闭。如果本开关打开,含义是外部调用者总是保证会正确地申请输出内存,包含: + * 1. 申请的输出内存大于等于输出shape所以计算出的Tensor大小 + * 2. 输出内存的placement正确 + * + * 打开本开关后,可以提升一点Host调度性能。与此同时,对于零拷贝失效的回退处理将不再进行, + * 在外部申请的输出内存错误、或未申请输出内存时,执行报错。 + */ + bool always_zero_copy = false; + + /** + * 总是使用外部allocator开关,默认关闭。如果本开关打开,含义是外部调用者总是保证会传入所有allocator,包含: + * 1. 创建所有在加载/执行阶段所需的allocator并传入 + * 2. 
由于总是信任外部allocator,一旦开启后,如果在加载/执行阶段获取外置allocator失败,则报错。 + * + * 打开本开关后,在执行器内部不需要再创建allocator,减少资源浪费 + */ + bool always_external_allocator = false; + + /** + * 使能单流,默认关闭。如果本开关打开,执行时动态根图任务下发在一条流上。 + * 该开关由rt2的使用者(acl/hybrid model)根据设备上流是否充裕,来决定是否只使用一条流资源。 + * + */ + bool enable_single_stream = false; + + /** + * 二进制兼容保留字段,增加option时,对应缩减删除reserved长度 + */ + uint8_t reserved[4U + 8U] = {0U}; +}; +} // namespace ge + +#endif // INC_EXTERNAL_GRAPH_GRAPH_H_ diff --git a/inc/metadef/external/exe_graph/runtime/base_type.h b/inc/metadef/external/exe_graph/runtime/base_type.h new file mode 100644 index 0000000000000000000000000000000000000000..eec116dc46439efe3a015c8492b59f7bb05f75c0 --- /dev/null +++ b/inc/metadef/external/exe_graph/runtime/base_type.h @@ -0,0 +1,24 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_BASE_TYPE_H_ +#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_BASE_TYPE_H_ + +#ifdef __cplusplus +extern "C" { +#endif + +typedef unsigned int UINT32; +typedef UINT32 KernelStatus; + +#ifdef __cplusplus +} +#endif + +#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_BASE_TYPE_H_ diff --git a/inc/metadef/external/exe_graph/runtime/compute_node_info.h b/inc/metadef/external/exe_graph/runtime/compute_node_info.h new file mode 100644 index 0000000000000000000000000000000000000000..20a6d5f78672d569900c1d24d9ce6fd61f3f7660 --- /dev/null +++ b/inc/metadef/external/exe_graph/runtime/compute_node_info.h @@ -0,0 +1,326 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_COMPUTE_NODE_INFO_H_ +#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_COMPUTE_NODE_INFO_H_ +#include +#include +#include +#include "graph/types.h" +#include "storage_format.h" +#include "runtime_attrs.h" +namespace gert { +class AnchorInstanceInfo { + public: + AnchorInstanceInfo() { + (void)memset_s(reserved_, sizeof(reserved_), 0, sizeof(reserved_)); + }; + AnchorInstanceInfo(const uint32_t instance_start, const uint32_t instantiation_num) + : instance_start_(instance_start), instantiation_num_(instantiation_num) { + (void)memset_s(reserved_, sizeof(reserved_), 0, sizeof(reserved_)); + } + + /** + * 获取本输入/输出实例化的个数 + * @return 实例化个数 + */ + size_t GetInstanceNum() const { + return instantiation_num_; + } + + /** + * 获取本输入/输出首个实例化的Anchor的index + * @return 首个实例化的Anchor的index + */ + size_t GetInstanceStart() const { + return instance_start_; + } + + /** + * 设置本输入/输出首个实例化Anchor的index + * @param instance_start 首个实例化Anchor的index + */ + void SetInstanceStart(const uint32_t instance_start) { + instance_start_ = instance_start; + } + + /** + * 设置本输入/输出实例化的个数 + * @param instantiation_num 实例化的个数 + */ + void SetInstantiationNum(const uint32_t instantiation_num) { + instantiation_num_ = instantiation_num; + } + + private: + uint32_t instance_start_; + uint32_t instantiation_num_; + uint8_t reserved_[40]; // Reserved field, 32+8, do not directly use when only 8-byte left +}; +static_assert(std::is_standard_layout::value, "The class AnchorInstanceInfo must be a POD"); + +class CompileTimeTensorDesc { + public: + CompileTimeTensorDesc() { + (void)memset_s(reserved_, sizeof(reserved_), 0, sizeof(reserved_)); + } + /** + * 获取DataType + * @return DataType + */ + ge::DataType GetDataType() const { + return data_type_; + } + /** + * 获取format + * @return format + */ + const StorageFormat &GetFormat() const { + return 
storage_format_; + } + /** + * 获取原始format + * @return 原始format + */ + ge::Format GetOriginFormat() const { + return storage_format_.GetOriginFormat(); + } + /** + * 获取运行时format + * @return 运行时format + */ + ge::Format GetStorageFormat() const { + return storage_format_.GetStorageFormat(); + } + /** + * 获取原始format向运行时format转换时的补维规则 + * @return 补维规则 + */ + ExpandDimsType GetExpandDimsType() const { + return storage_format_.GetExpandDimsType(); + } + /** + * 设置 data type + * @param data_type data type + */ + void SetDataType(const ge::DataType data_type) { + data_type_ = data_type; + } + /** + * 设置运行时format + * @param format 运行时format + */ + void SetStorageFormat(const ge::Format format) { + storage_format_.SetStorageFormat(format); + } + /** + * 设置原始format + * @param format 原始format + */ + void SetOriginFormat(const ge::Format format) { + storage_format_.SetOriginFormat(format); + } + /** + * 设置原始format向运行时format转换时的补维规则 + * @param expand_dims_type 补维规则 + */ + void SetExpandDimsType(const ExpandDimsType &expand_dims_type) { + storage_format_.SetExpandDimsType(expand_dims_type); + } + + private: + ge::DataType data_type_; + StorageFormat storage_format_; + uint8_t reserved_[40]; // Reservd field, 32+8, do not directly use when only 8-byte left +}; +static_assert(std::is_standard_layout::value, "The class CompileTimeTensorDesc must be a POD"); + +class ComputeNodeInfo { + public: + /** + * 获取计算节点的node type + * @return node type + */ + const ge::char_t *GetNodeType() const { + return node_type_; + } + /** + * 获取计算节点的node name + * @return node name + */ + const ge::char_t *GetNodeName() const { + return node_name_; + } + /** + * 获取算子IR原型定义中的输入个数 + * @return IR原型中定义的输入个数 + */ + size_t GetIrInputsNum() const { + return ir_inputs_num_; + } + /** + * 获取算子IR原型定义中的输出个数 + * @return IR原型中定义的输出个数 + */ + size_t GetIrOutputsNum() const { + return ir_outputs_num_; + } + /** + * 获取计算节点的输入个数 + * @return 计算节点的输入个数 + */ + size_t GetInputsNum() const { + return inputs_num_; + } + /** + * 
获取计算节点的输出个数 + * @return 计算节点的输出个数 + */ + size_t GetOutputsNum() const { + return outputs_num_; + } + /** + * 根据IR原型中的输入index,获取对应的实例化信息 + * @param ir_index IR原型定义中的输入index + * @return 输入的实例化信息 + */ + const AnchorInstanceInfo *GetInputInstanceInfo(const size_t ir_index) const { + if (ir_index >= ir_inputs_num_) { + return nullptr; + } + const auto inputs = reinterpret_cast(&place_holder); + return inputs + ir_index; + } + /** + * 根据IR原型中的输出index,获取对应的实例化信息 + * @param ir_index IR原型定义中的输出index + * @return 输出的实例化信息 + */ + const AnchorInstanceInfo *GetOutputInstanceInfo(const size_t ir_index) const { + if (ir_index >= ir_outputs_num_) { + return nullptr; + } + const auto outputs = reinterpret_cast( + reinterpret_cast(&place_holder) + + sizeof(AnchorInstanceInfo) * ir_inputs_num_ + + sizeof(CompileTimeTensorDesc) * (inputs_num_ + outputs_num_) + + runtime_attr_size_); + return outputs + ir_index; + } + /** + * 获取计算节点输入的Tensor描述,注意,编译时无法确定的shape信息不在Tensor描述中 + * @param index 计算节点的输入index + * @return Tensor描述 + */ + const CompileTimeTensorDesc *GetInputTdInfo(const size_t index) const { + if (index >= inputs_num_) { + return nullptr; + } + const auto inputs = reinterpret_cast( + reinterpret_cast(&place_holder) + sizeof(AnchorInstanceInfo) * ir_inputs_num_); + return inputs + index; + } + /** + * 获取计算节点输出的Tensor描述,注意,编译时无法确定的shape信息不在Tensor描述中 + * @param index 计算节点的输出index + * @return Tensor描述 + */ + const CompileTimeTensorDesc *GetOutputTdInfo(const size_t index) const { + if (index >= outputs_num_) { + return nullptr; + } + const auto outputs = reinterpret_cast( + reinterpret_cast(&place_holder) + sizeof(AnchorInstanceInfo) * ir_inputs_num_ + + sizeof(CompileTimeTensorDesc) * inputs_num_); + return outputs + index; + } + /** + * 获取计算节点上的属性值,仅IR定义的属性值会被返回,其他属性值被丢弃 + * @return 所有IR原型定义过的属性值,属性值按照IR原型定义的顺序依次保存 + */ + const RuntimeAttrs *GetAttrs() const { + return reinterpret_cast(reinterpret_cast(&place_holder) + + sizeof(AnchorInstanceInfo) * ir_inputs_num_ + + 
sizeof(CompileTimeTensorDesc) * (inputs_num_ + outputs_num_)); + } + /** + * 设置计算节点的node type + * @param node_type 计算节点的node type + */ + void SetNodeType(const ge::char_t *node_type) { + node_type_ = node_type; + } + /** + * 设置计算节点的node name + * @param node_name 计算节点的node name + */ + void SetNodeName(const ge::char_t *node_name) { + node_name_ = node_name; + } + /** + * 根据IR原型中的输入index,获取对应的实例化信息 + * @param ir_index IR原型定义中的输入index + * @return 输入的实例化信息 + */ + AnchorInstanceInfo *MutableInputInstanceInfo(const size_t ir_index) const; + /** + * 根据IR原型中的输出index,获取对应的实例化信息 + * @param ir_index IR原型定义中的输出index + * @return 输出的实例化信息 + */ + AnchorInstanceInfo *MutableOutputInstanceInfo(const size_t ir_index) const; + /** + * 获取计算节点输入的Tensor描述,注意,编译时无法确定的shape信息不在Tensor描述中 + * @param index 计算节点的输入index + * @return Tensor描述 + */ + CompileTimeTensorDesc *MutableInputTdInfo(const size_t index) const; + /** + * 获取计算节点输出的Tensor描述,注意,编译时无法确定的shape信息不在Tensor描述中 + * @param index 计算节点的输出index + * @return Tensor描述 + */ + CompileTimeTensorDesc *MutableOutputTdInfo(const size_t index) const; + /** + * 获取计算节点上的属性值,仅IR定义的属性值会被返回,其他属性值被丢弃 + * @return 所有IR原型定义过的属性值,属性值按照IR原型定义的顺序依次保存 + */ + RuntimeAttrs *MutableAttrs() const; + static ge::graphStatus CalcSize(const size_t ir_inputs_num, const size_t inputs_num, + const size_t outputs_num, size_t &total_size); + void Init(const size_t ir_inputs_num, const size_t inputs_num, const size_t outputs_num, + const ge::char_t *node_name, const ge::char_t *node_type); + static ge::graphStatus CalcSize(const size_t ir_inputs_num, const size_t ir_outputs_num, const size_t inputs_num, + const size_t outputs_num, size_t &total_size); + void Init(const size_t ir_inputs_num, const size_t ir_outputs_num, const size_t inputs_num, const size_t outputs_num, + const size_t attr_size, const ge::char_t *node_name, const ge::char_t *node_type); + + ComputeNodeInfo() = delete; + ComputeNodeInfo(const ComputeNodeInfo &) = delete; + ComputeNodeInfo(ComputeNodeInfo 
&&) = delete; + ComputeNodeInfo &operator=(const ComputeNodeInfo &) = delete; + ComputeNodeInfo &operator=(ComputeNodeInfo &&) = delete; + + private: + const ge::char_t *node_type_; + const ge::char_t *node_name_; + size_t ir_inputs_num_; + size_t inputs_num_; + size_t outputs_num_; + size_t ir_outputs_num_; + size_t runtime_attr_size_; + uint8_t reserved_[24]; // Reserved field, 24+8, do not directly use when only 8-byte left + // following by inputs-AnchorInstanceInfo, inputs-outputs-CompileTimeTensorDesc, RuntimeAttrs, + // outputs-AnchorInstanceInfo + uint64_t place_holder; +}; +static_assert(std::is_standard_layout::value, "The class ComputeNodeInfo must be a POD"); +} // namespace gert + +#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_COMPUTE_NODE_INFO_H_ diff --git a/inc/metadef/external/exe_graph/runtime/context_extend.h b/inc/metadef/external/exe_graph/runtime/context_extend.h new file mode 100644 index 0000000000000000000000000000000000000000..a1c2c410a31c8e5e814a2ffed1a2ecc377759bcb --- /dev/null +++ b/inc/metadef/external/exe_graph/runtime/context_extend.h @@ -0,0 +1,62 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_CONTEXT_EXTEND_H_ +#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_CONTEXT_EXTEND_H_ +#include +#include +#include "compute_node_info.h" + +namespace gert { +class KernelExtendInfo { + public: + /** + * 获取kernel name + * @return kernel name + */ + const ge::char_t *GetKernelName() const { + return kernel_name_; + } + /** + * 设置kernel name + * @param kernel_name kernel name + */ + void SetKernelName(const ge::char_t *kernel_name) { + kernel_name_ = kernel_name; + } + /** + * 获取kernel type + * @return kernel type + */ + const ge::char_t *GetKernelType() const { + return kernel_type_; + } + /** + * 设置kernel type + * @param kernel_type kernel type + */ + void SetKernelType(const ge::char_t *const kernel_type) { + (void) reserved_; + kernel_type_ = kernel_type; + } + + KernelExtendInfo() = delete; + KernelExtendInfo(const KernelExtendInfo &) = delete; + KernelExtendInfo(KernelExtendInfo &&) = delete; + KernelExtendInfo &operator=(const KernelExtendInfo &) = delete; + KernelExtendInfo &operator=(KernelExtendInfo &&) = delete; + + private: + const ge::char_t *kernel_name_; + const ge::char_t *kernel_type_; + uint8_t reserved_[56]; // Reserved field, 32+8, do not directly use when only 8-byte left +}; +static_assert(std::is_standard_layout::value, "The class KernelExtendInfo must be a POD"); +} // namespace gert +#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_CONTEXT_EXTEND_H_ diff --git a/inc/metadef/external/exe_graph/runtime/continuous_vector.h b/inc/metadef/external/exe_graph/runtime/continuous_vector.h new file mode 100644 index 0000000000000000000000000000000000000000..1581ac15dba3cb8767f3748c0dd72071e13004a7 --- /dev/null +++ b/inc/metadef/external/exe_graph/runtime/continuous_vector.h @@ -0,0 +1,210 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_CONTINUOUS_VECTOR_H_ +#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_CONTINUOUS_VECTOR_H_ +#include +#include +#include +#include +#include +#include "graph/ge_error_codes.h" +#include "utils/extern_math_util.h" +#include "external/ge_common/ge_api_error_codes.h" + +namespace gert { +class ContinuousVector { + public: + /** + * Create a ContinuousVector instance; a ContinuousVector cannot grow dynamically + * @tparam T element type held by the instance + * @param capacity maximum number of elements of the instance + * @param total_size total length of this instance (output, written by this function) + * @return pointer owning the instance; nullptr when the size computation overflows or the allocation fails + */ + template + static std::unique_ptr Create(size_t capacity, size_t &total_size) { + if (ge::MulOverflow(capacity, sizeof(T), total_size)) { + return nullptr; + } + if (ge::AddOverflow(total_size, sizeof(ContinuousVector), total_size)) { + return nullptr; + } + auto holder = std::unique_ptr(new (std::nothrow) uint8_t[total_size]); + if (holder == nullptr) { + return nullptr; + } + reinterpret_cast(holder.get())->Init(capacity); + return holder; + } + /** + * Create a ContinuousVector instance; a ContinuousVector cannot grow dynamically + * @tparam T element type held by the instance + * @param capacity maximum 
number of elements of the instance + * @return pointer owning the instance; nullptr when the size computation overflows or the allocation fails + */ + template + static std::unique_ptr Create(const size_t capacity) { + size_t total_size; + return Create(capacity, total_size); + } + /** + * Initialize this instance with its maximum capacity; the stored element count is reset to 0 + * @param capacity maximum capacity + */ + void Init(const size_t capacity) { + capacity_ = capacity; + size_ = 0U; +
 (void)memset_s(reserved_, sizeof(reserved_), 0, sizeof(reserved_)); + } + /** + * Get the number of elements currently stored + * @return number of elements currently stored + */ + size_t GetSize() const { + return size_; + } + /** + * Set the number of elements currently stored + * @param size number of elements currently stored + * @return ge::GRAPH_SUCCESS on success, ge::GRAPH_PARAM_OUT_OF_RANGE when size exceeds the capacity + */ + ge::graphStatus SetSize(const size_t size) { + if (size > capacity_) { + return ge::GRAPH_PARAM_OUT_OF_RANGE; + } + size_ = size; + return ge::GRAPH_SUCCESS; + } + /** + * Get the maximum number of elements that can be stored + * @return maximum number of elements that can be stored + */ + size_t GetCapacity() const { + return capacity_; + } + /** + * Get the address of the first element; the data in [GetData(), GetData() + GetSize()) is what the container holds + * @return address of the first element + */ + const void *GetData() const { + return elements; + } + /** + * Get the address of the first element; the data in [GetData(), GetData() + GetSize()) is what the container holds + * @return address of the first element + */ + void *MutableData() { + return elements; + } + + private: + size_t capacity_; + size_t size_; + uint8_t reserved_[40]; // Reserved field, 40 bytes (32+8); do not use directly when only 8 bytes are left + uint8_t elements[8]; +}; +static_assert(std::is_standard_layout::value, "The ContinuousVector must be a POD"); + +template +class TypedContinuousVector : private ContinuousVector { + public: + using ContinuousVector::GetCapacity; + using ContinuousVector::GetSize; + using ContinuousVector::SetSize; + /** + * Get the address of the first element; the data in [GetData(), GetData() + GetSize()) is what the container holds + * @return address of the first element, typed as T + */ + T *MutableData() { + return static_cast(ContinuousVector::MutableData()); + } + /** + * Get the address of the first element; 
the data in [GetData(), GetData() + GetSize()) is what the container holds + * @return address of the first element, typed as T + */ + const T *GetData() const { + return static_cast(ContinuousVector::GetData()); + } +}; + +/* + * memory layout: |capacity_|size_|reserved_|offset1|offset2|...|ContinuousVector1|ContinuousVector2|...| + * size_ : number of ContinuousVector added so far (NOTE(review): Init() does not reset it; confirm callers zero it) + * offset1 : offset of ContinuousVector1, counted from the start of this object + */ +class ContinuousVectorVector { + public: + void Init(const size_t capacity) { + capacity_ = capacity; + if (capacity_ == 0U) { + return; + } + SetOffset(0U,
 GetOverHeadLength(capacity_)); + (void)memset_s(reserved_, sizeof(reserved_), 0, sizeof(reserved_)); + } + // Initializes the next inner vector at the current offset; returns nullptr when full or when its length overflows. + template + ContinuousVector *Add(size_t inner_vector_capacity) { + if (size_ >= capacity_) { + return nullptr; + } + const auto inner_vector = + reinterpret_cast(reinterpret_cast(this) + GetOffset(size_)); + inner_vector->Init(inner_vector_capacity); + (void) inner_vector->SetSize(inner_vector_capacity); + size_t inner_vector_length = 0U; + if (ge::MulOverflow(inner_vector_capacity, sizeof(T), inner_vector_length)) { + return nullptr; + } + if (ge::AddOverflow(inner_vector_length, sizeof(ContinuousVector), inner_vector_length)) { + return nullptr; + } + ++size_; + if (size_ < capacity_) { + SetOffset(size_, GetOffset(size_ - 1U) + inner_vector_length); + } + return inner_vector; + } + // Returns the inner vector located at offset_[index]; index is not range-checked here. + const ContinuousVector *Get(const size_t index) const { + return reinterpret_cast(reinterpret_cast(this) + GetOffset(index)); + } + + size_t GetSize() const { + return size_; + } + // Length of the header that precedes the first inner vector (fixed fields plus one offset slot per entry). 
+ static size_t GetOverHeadLength(const size_t capacity) { + if (capacity == 0U) { + return sizeof(ContinuousVectorVector); + } + return sizeof(capacity_) + sizeof(size_) + sizeof(reserved_) + sizeof(size_t) * capacity; + } + + private: + void SetOffset(const size_t index, const size_t offset) { + size_t *const offset_ptr = &offset_[0U]; + offset_ptr[index] = offset; + } + + size_t GetOffset(const size_t index) const { + const size_t *const offset_ptr = &offset_[0U]; + return offset_ptr[index]; + } + + private: + size_t capacity_ = 0U; + size_t size_ = 0U; + uint8_t reserved_[40]; // Reserved field, 40 bytes (32+8); do not use directly when only 8 bytes are left + size_t offset_[1U]; +}; +static_assert(std::is_standard_layout::value, "The ContinuousVectorVector must be a POD"); +} // namespace gert +#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_CONTINUOUS_VECTOR_H_ diff --git a/inc/metadef/external/exe_graph/runtime/exe_res_generation_context.h b/inc/metadef/external/exe_graph/runtime/exe_res_generation_context.h
new file mode 100644 index 0000000000000000000000000000000000000000..85dee70410a8018b649c08312195067b1c055f20 --- /dev/null +++ b/inc/metadef/external/exe_graph/runtime/exe_res_generation_context.h @@ -0,0 +1,105 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_REGISTER_EXE_RES_GENERATION_CONTEXT_H +#define INC_EXTERNAL_REGISTER_EXE_RES_GENERATION_CONTEXT_H + +#include +#include "graph/ascend_string.h" +#include "external/ge_common/ge_api_types.h" +#include "exe_graph/runtime/storage_shape.h" +#include "exe_graph/runtime/compute_node_info.h" +#include "exe_graph/runtime/extended_kernel_context.h" +namespace gert { +enum class ExecuteMode { + kStaticOffloadExecute, // static sink + kDynamicExecute, + kEnd +}; + +/* + * input: + * name: usage type; also used as the reuse key when reuse_key is empty + * reuse_key: key for resource reuse; entries with the same key reuse the same resource + * depend_value_input_indices: indices of the inputs whose values are depended on + * required: whether the resource is mandatory (initialized to true) + * output: + * is_valid: resource allocation status (initialized to false) + * stream_id: allocated stream id (initialized to -1) + * */ +struct StreamInfo { + ge::AscendString name; + ge::AscendString reuse_key; + std::vector 
depend_value_input_indices; + bool required{true}; + bool is_valid{false}; + int64_t stream_id{-1}; +}; + +enum class SyncResType { + SYNC_RES_EVENT, + SYNC_RES_NOTIFY, +
 END +}; + +/* + * input: + * type: resource type + * name: usage type; also used as the reuse key when reuse_key is empty + * reuse_key: key for resource reuse; entries with the same key reuse the same resource + * required: whether the resource is mandatory (initialized to true) + * output: + * is_valid: resource allocation status (initialized to false) + * sync_res_id: allocated sync resource id (initialized to -1) + * */ +struct SyncResInfo { + SyncResType type; + ge::AscendString name; + ge::AscendString reuse_key; + bool required{true}; + bool is_valid{false}; + int32_t sync_res_id{-1}; +}; + +class ExeResGenerationContext : public ExtendedKernelContext { + public: + ExecuteMode GetExecuteMode() const; + + // NOTE(review): takes an input name; presumably reports whether that input is const data - confirm in the .cc + bool IsConstInput(const ge::AscendString &name) const; + + // get the input/output shape by its index in the graph + const gert::StorageShape* GetInputShape(int64_t index) const; + const gert::StorageShape* GetOutputShape(int64_t index) const; + + // set/get the attached stream resources + ge::graphStatus SetAttachedStreamInfos(std::vector &stream_info_vec) const; + std::vector GetAttachedStreamInfos() const; + int64_t GetStreamId() const; + + // set/get the sync resources + ge::graphStatus SetSyncResInfos(std::vector &sync_info_vec) const; + std::vector GetSyncResInfos() const; + + // get/adjust the workspace sizes + std::vector GetWorkspaceBytes() const; + void SetWorkspaceBytes(const std::vector &workspace_bytes) const; + + int64_t GetOpId() const; + + private: + friend class ExeResGenerationCtxBuilder; + // validity needs to be checked after construction + bool CheckContextValid() const; +}; +static_assert(std::is_standard_layout::value && std::is_trivial::value, + "The class ExeResGenerationContext must be a POD"); +} // namespace gert + +#endif diff --git a/inc/metadef/external/exe_graph/runtime/expand_dims_type.h 
b/inc/metadef/external/exe_graph/runtime/expand_dims_type.h new file mode 100644 index 0000000000000000000000000000000000000000..6ab500c82c3806954ac170f1d91f266c582f2f66 --- /dev/null +++ b/inc/metadef/external/exe_graph/runtime/expand_dims_type.h @@ -0,0 +1,180 @@ +/* Copyright (c) 2024
Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_EXPAND_DIMS_TYPE_H_ +#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_EXPAND_DIMS_TYPE_H_ +#include +#include +#include "graph/ge_error_codes.h" +#include "graph/types.h" +#include "shape.h" + +namespace gert { +/** + * This class describes an expand-dims rule in terms of the expanded shape. It was designed to be as long as one uint64_t so that copies are cheap (NOTE(review): with reserved_[40] the object is larger than that). + * + * Expanding dims is similar to the ExpandDims operator: one or more dimensions are added to the original shape. For example the original shape [2,2] has two axes; inserting two dims between them gives [2,1,1,2]. + * Axes 0 and 3 of the expanded shape are called original axes; axes 1 and 2 are called expanded axes. 
+ * + * The rule is written with 1 and 0: 1 marks an expanded axis, 0 marks an original axis; read from left to right they give the origin of every axis of the resulting shape, for example: + * | rule | shape before | shape after | + * | -------- | ----------- | ------------------------------------------------------------ | + * | 0110 | [2, 2] | [2, 1, 1, 2] | + * | 100 | [2, 3] | [1, 2, 3] | + * | 1000 | [2, 3] | The rule does not match the shape before expansion: it names 3 original axes but the shape has only 2, so expansion fails. | + * + */ +class ExpandDimsType { + public: + using AxisIndex = uint64_t; + static constexpr size_t kMaxExpandSize = 56; + + ExpandDimsType() : size_(0U), mask_(0U) { + (void)memset(reserved_, 0, sizeof(reserved_)); // suppresses the MISRA warning for the memset call + } + + /** + * Create an expand-dims rule from a string; a null pointer or a string longer than kMaxExpandSize leaves the rule empty + * @param expand_dims_type the rule as a string of '0'/'1' characters + */ + explicit ExpandDimsType(const ge::char_t *const expand_dims_type) : ExpandDimsType() { + if (expand_dims_type == nullptr) { + return; + } + const auto size = strlen(expand_dims_type); + if
 (size > kMaxExpandSize) { + // log error + return; + } + + size_ = size; + for (AxisIndex i = 0; i < size; ++i) { + if (expand_dims_type[i] == '1') { + SetExpandIndex(i); + } + } + } + + /** + * Create an expand-dims rule from an integer whose bit fields are defined as + * | field | type | meaning | + * | -------- | ------- | ---------------------- | + * | high 8 bits | uint8_t | length of the rule | + * | low 56 bits | bit field | the rule, described with 0 and 1 | + * + * To keep the implementation simple the rule bits are in the reverse order of the string form: the string rule "1100" corresponds to the bits "0011", i.e. the number 3. All 56 low bits are kept. + * @param reshape_type_mask the packed rule; 0 yields an empty rule (NOTE(review): a length above kMaxExpandSize leaves size_ set with an empty mask - confirm intended) + */ + explicit ExpandDimsType(const int64_t reshape_type_mask) : ExpandDimsType() { + if (reshape_type_mask == 0) { + return; + } + size_ = static_cast(static_cast(reshape_type_mask) >> kMaxExpandSize); + if (size_ > kMaxExpandSize) { + return; + } + mask_ = static_cast(static_cast(reshape_type_mask) & ((1ULL << kMaxExpandSize) - 1ULL)); + } + /** + * Check whether two expand-dims rules are identical (same length and same mask) + * @param other the other instance + * @return true/false + */ + bool operator==(const ExpandDimsType &other) const { + return (size_ == other.size_) && (mask_ == other.mask_); + } + /** + * Get the number of dims after expansion + * @return number of dims of the expanded shape + */ + AxisIndex GetFullSize() const { + return size_; + } + /** + * Mark an axis as an expanded axis + * @param index axis `index` becomes an expanded axis; must be smaller than kMaxExpandSize (not checked here) + */ + void SetExpandIndex(const AxisIndex index) { + mask_ |= (1ULL << index); + } + /** + * In terms of the expanded shape, tell whether axis `index` is an expanded axis + * @param index the axis to query; must be smaller than 64 (not 
checked here) + * @return true for an expanded axis, false for an axis of the original shape + */ + bool IsExpandIndex(const AxisIndex index) const { + return static_cast((1ULL << index) & mask_); + } + /** + * Expand `shape` and write the result to `out_shape`; a shape that already has GetFullSize() dims is copied unchanged + * @param shape input shape, the shape before expansion + * @param out_shape output shape, the shape after expansion; dims of `shape` not consumed by the rule are appended at the end + * @return ge::GRAPH_SUCCESS on success, ge::GRAPH_FAILED when the rule needs more original axes than `shape` has + */ + ge::graphStatus Expand(const Shape &shape, Shape &out_shape) const { + if (shape.GetDimNum() == size_) { + out_shape = shape; + return ge::GRAPH_SUCCESS; + } + size_t shape_pos = 0; + out_shape.SetDimNum(0); + for (size_t out_shape_pos = 0; out_shape_pos < size_; ++out_shape_pos) { + if (!IsExpandIndex(out_shape_pos)) { + if (shape_pos >= shape.GetDimNum()) { + return ge::GRAPH_FAILED; + } + (void)
 out_shape.AppendDim(shape.GetDim(shape_pos++)); + } else { + (void) out_shape.AppendDim(1); + } + } + + for (; shape_pos < shape.GetDimNum(); ++shape_pos) { + (void) out_shape.AppendDim(shape.GetDim(shape_pos)); + } + return ge::GRAPH_SUCCESS; + } + /** + * Expand a shape in place + * @param shape the shape that is expanded directly + * @return always ge::GRAPH_SUCCESS in this implementation + */ + ge::graphStatus Expand(Shape &shape) const { + // full_size:4, shape:[A,B], reshape_type (mask written in binary, lowest bit = axis 0):1010 + // shape:[A,B] + full_size:4 -> [A,B,1,1] + if (shape.GetDimNum() == size_) { + return ge::GRAPH_SUCCESS; + } + size_t dim_size = shape.GetDimNum(); + size_t cur_dim_size = dim_size; + while (cur_dim_size++ < size_) { + (void) shape.AppendDim(1); + } + + // shape:[A,B,1,1] + 1010 -> [A,1,B,1]; walk from the last axis and move original dims back into their slots + for (int32_t i = static_cast(size_ - 1UL); i >= 0; --i) { + if (!IsExpandIndex(static_cast(i))) { + if ((dim_size > 0) && (dim_size - 1UL < static_cast(i))) { + shape.SetDim(static_cast(i), shape.GetDim(dim_size - 1)); + shape.SetDim(dim_size - 1UL, 1); + dim_size--; + } + } + } + return ge::GRAPH_SUCCESS; + } + + private: + uint64_t size_ : 8; + uint64_t mask_ : kMaxExpandSize; + uint8_t reserved_[40]; // Reserved field, 40 bytes (32+8); do not use directly when only 8 bytes are left +}; +} // namespace gert + +#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_EXPAND_DIMS_TYPE_H_ diff --git a/inc/metadef/external/exe_graph/runtime/extended_kernel_context.h b/inc/metadef/external/exe_graph/runtime/extended_kernel_context.h new file mode 100644 index 0000000000000000000000000000000000000000..0fd654bd8d9965e0b245593c4b11f4b2f66643c5 --- /dev/null +++ b/inc/metadef/external/exe_graph/runtime/extended_kernel_context.h @@ -0,0 +1,226 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_EXTENDED_KERNEL_CONTEXT_H_ +#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_EXTENDED_KERNEL_CONTEXT_H_ +#include +#include "kernel_context.h" +#include "context_extend.h" + +namespace gert { +class ExtendedKernelContext : protected KernelContext { + public: + /** + * Get the tensor description of a node input by index + * @param index input index + * @return pointer to the input TensorDesc; nullptr when the index is invalid + */ + const CompileTimeTensorDesc *GetInputDesc(const size_t index) const { + const auto compute_node_info = GetComputeNodeInfo(); + if (compute_node_info == nullptr) { + return nullptr; + } + return compute_node_info->GetInputTdInfo(index); + } + /** + * Based on the operator IR prototype, get the CompileTimeTensorDesc pointer of an `OPTIONAL_INPUT` input + * @param ir_index index in the IR prototype + * @return CompileTimeTensorDesc pointer; nullptr when the index is invalid or the input is not instantiated + */ + const CompileTimeTensorDesc *GetOptionalInputDesc(const size_t ir_index) const { + return GetDynamicInputDesc(ir_index, 0); + } + /** + * Based on the operator IR prototype, get the CompileTimeTensorDesc pointer of a `REQUIRED_INPUT` input + * @param ir_index index in the IR prototype + * @return CompileTimeTensorDesc pointer; nullptr when the index is invalid or the input is not instantiated + */ + const CompileTimeTensorDesc *GetRequiredInputDesc(const size_t ir_index) const { + return GetDynamicInputDesc(ir_index, 0); + } + /** + * Based on the operator IR 
prototype, get the CompileTimeTensorDesc pointer of a `DYNAMIC_INPUT` input + * @param ir_index index in the IR prototype + * @param relative_index relative index among the instances of this input; e.g. if a DYNAMIC_INPUT is instantiated as 3 inputs, the valid range of relative_index is [0,2] + * @return CompileTimeTensorDesc pointer; nullptr when index or relative_index is invalid + */ + const CompileTimeTensorDesc
*GetDynamicInputDesc(const size_t ir_index, const size_t relative_index) const { + const auto compute_node_info = GetComputeNodeInfo(); + if (compute_node_info == nullptr) { + return nullptr; + } + const auto ins_info = GetIrInputInstanceInfo(ir_index); + if (ins_info == nullptr) { + return nullptr; + } + if (ins_info->GetInstanceNum() <= relative_index) { + return nullptr; + } + return compute_node_info->GetInputTdInfo(ins_info->GetInstanceStart() + relative_index); + } + /** + * Get the tensor description of a node output by index + * @param index output index + * @return pointer to the output TensorDesc; nullptr when the index is invalid + */ + const CompileTimeTensorDesc *GetOutputDesc(const size_t index) const { + const auto compute_node_info = GetComputeNodeInfo(); + if (compute_node_info == nullptr) { + return nullptr; + } + return compute_node_info->GetOutputTdInfo(index); + } + /** + * Get the instantiation info of an input on the node by its index in the IR prototype + * @param ir_index index defined in the IR prototype + * @return instantiation info; nullptr when the compute node info is unavailable + */ + const AnchorInstanceInfo *GetIrInputInstanceInfo(const size_t ir_index) const { + const auto compute_node_info = GetComputeNodeInfo(); + if (compute_node_info == nullptr) { + return nullptr; + } + return compute_node_info->GetInputInstanceInfo(ir_index); + } + + /** + * Get the instantiation info of an output on the node by its index in the IR prototype + * @param ir_index index defined in the IR prototype + * @return instantiation info; nullptr when the compute node info is unavailable + */ + const AnchorInstanceInfo *GetIrOutputInstanceInfo(const size_t ir_index) const { + const auto compute_node_info = GetComputeNodeInfo(); + if (compute_node_info == nullptr) { + return nullptr; + } + return compute_node_info->GetOutputInstanceInfo(ir_index); + } + + /** + * Get the number of inputs of the compute node + * @return number of inputs of the compute 
node; 0 when the compute node info is unavailable + */ + size_t GetComputeNodeInputNum() const { + const auto compute_node_info = GetComputeNodeInfo(); + if (compute_node_info == nullptr) { + return 0; + } + return compute_node_info->GetInputsNum(); + } + /** + * Get the number of outputs of the compute node + * @return number of outputs of the compute node; 0 when the compute node info is unavailable + */ + size_t GetComputeNodeOutputNum() const { + const auto compute_node_info = GetComputeNodeInfo(); + if (compute_node_info == nullptr) { + return
 0; + } + return compute_node_info->GetOutputsNum(); + } + /** + * Get the attributes of the compute node; only attributes defined in the IR prototype can be obtained + * @return attributes of the compute node; nullptr when the compute node info is unavailable + */ + const RuntimeAttrs *GetAttrs() const { + const auto compute_node_info = GetComputeNodeInfo(); + if (compute_node_info == nullptr) { + return nullptr; + } + return compute_node_info->GetAttrs(); + } + /** + * Get the type of the compute node + * @return type of the compute node; nullptr when the compute node info is unavailable + */ + const ge::char_t *GetNodeType() const { + const auto compute_node_info = GetComputeNodeInfo(); + if (compute_node_info == nullptr) { + return nullptr; + } + return compute_node_info->GetNodeType(); + } + /** + * Get the name of the compute node + * @return name of the compute node; nullptr when the compute node info is unavailable + */ + const ge::char_t *GetNodeName() const { + const auto compute_node_info = GetComputeNodeInfo(); + if (compute_node_info == nullptr) { + return nullptr; + } + return compute_node_info->GetNodeName(); + } + /** + * Get the info of the compute node that this kernel corresponds to + * @return info of the compute node + */ + const ComputeNodeInfo *GetComputeNodeInfo() const { + return static_cast(GetComputeNodeExtend()); + } + /** + * Get the name of this kernel + * @return name of this kernel; nullptr when the extend info is unavailable + */ + const ge::char_t *GetKernelName() const { + const auto extend_info = GetExtendInfo(); + if (extend_info == nullptr) { + return nullptr; + } + return extend_info->GetKernelName(); + } + /** + * Get the type of this kernel + * @return type of this kernel; nullptr when the extend info is unavailable + */ + const ge::char_t *GetKernelType() const { + const auto extend_info = GetExtendInfo(); + if (extend_info == nullptr) { + return nullptr; + } + return extend_info->GetKernelType(); + } + /** + * Get the 
extend info of this kernel + * @return extend info of this kernel + */ + const KernelExtendInfo *GetExtendInfo() const { + return static_cast(GetKernelExtend()); + } + + protected: + template + const T *GetDynamicInputPointer(size_t ir_index, size_t relative_ins_index) const { + const auto ins_info = GetIrInputInstanceInfo(ir_index); + if (ins_info == nullptr) { + return nullptr; + } + if (ins_info->GetInstanceNum() <= relative_ins_index) { + return nullptr; + } + return GetInputPointer(Offset + ins_info->GetInstanceStart() + relative_ins_index); + }
+ // Output counterpart of GetDynamicInputPointer: the slot index is additionally shifted by the node's input count. + template + const T *GetDynamicOutputPointer(size_t ir_index, size_t relative_ins_index) const { + const auto ins_info = GetIrOutputInstanceInfo(ir_index); + if (ins_info == nullptr) { + return nullptr; + } + if (ins_info->GetInstanceNum() <= relative_ins_index) { + return nullptr; + } + const size_t input_num = GetComputeNodeInputNum(); + return GetInputPointer(Offset + input_num + ins_info->GetInstanceStart() + relative_ins_index); + } +}; +static_assert(std::is_standard_layout::value, + "The class ExtendedKernelRunContext must be a POD"); +} // namespace gert +#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_EXTENDED_KERNEL_CONTEXT_H_ diff --git a/inc/metadef/external/exe_graph/runtime/infer_datatype_context.h b/inc/metadef/external/exe_graph/runtime/infer_datatype_context.h new file mode 100644 index 0000000000000000000000000000000000000000..3e04f0077b9830229dc6d15b061c87e7c42c9bad --- /dev/null +++ b/inc/metadef/external/exe_graph/runtime/infer_datatype_context.h @@ -0,0 +1,106 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_INFER_DATATYPE_CONTEXT_H_ +#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_INFER_DATATYPE_CONTEXT_H_ +#include +#include "tensor.h" +#include "runtime_attrs.h" +#include "extended_kernel_context.h" +#include "graph/types.h" + +namespace gert { +/** + * Context of the InferDataType kernel + */ +class InferDataTypeContext : public ExtendedKernelContext { + public: + /** + * Get the DataType of an input by its input index + * @param index input index + * @return input datatype; DT_UNDEFINED when the index is invalid + */ + ge::DataType GetInputDataType(const size_t index) const { + const auto in_datatype = GetInputPointer(index); + if (in_datatype == nullptr) { + return ge::DT_UNDEFINED; + } + return *in_datatype; + } + + /** + * Based on the operator IR prototype, get the DataType of an `OPTIONAL_INPUT` input + * @param ir_index index in the IR prototype + * @return input datatype; DT_UNDEFINED when the index is invalid or the input is not instantiated + */ + ge::DataType GetOptionalInputDataType(const size_t ir_index) const { + const auto in_datatype = GetDynamicInputPointer(ir_index, 0); + if (in_datatype == nullptr) { + return ge::DT_UNDEFINED; + } + return *in_datatype; + } + /** + * Based on the operator IR prototype, get the DataType of a `DYNAMIC_INPUT` input + * @param ir_index index in the IR prototype + * @param relative_index relative index among the instances of this input; e.g. 
if a DYNAMIC_INPUT is instantiated as 3 inputs, the valid range of relative_index is [0,2] + * @return datatype; DT_UNDEFINED when index or relative_index is invalid + */ + ge::DataType GetDynamicInputDataType(const size_t ir_index, const size_t relative_index) const { + const auto in_datatype = GetDynamicInputPointer(ir_index, relative_index); + if (in_datatype == nullptr) { + return ge::DT_UNDEFINED; + } + return *in_datatype; + } + + /** + * Based on the operator IR prototype, get the DataType of a `REQUIRED_INPUT` input + * @param ir_index index in the IR prototype + * @return input datatype; DT_UNDEFINED when the index is invalid or the input is not instantiated + */ + ge::DataType GetRequiredInputDataType(const size_t ir_index) const { + const auto in_datatype =
 GetDynamicInputPointer(ir_index, 0); + if (in_datatype == nullptr) { + return ge::DT_UNDEFINED; + } + return *in_datatype; + } + + /** + * Get the DataType of an output by its output index + * @param index output index + * @return output datatype; DT_UNDEFINED when the index is invalid + */ + ge::DataType GetOutputDataType(const size_t index) const { + const auto datatype_ptr = GetOutputPointer(index); + if (datatype_ptr == nullptr) { + return ge::DT_UNDEFINED; + } + return *datatype_ptr; + } + + /** + * Set the DataType of an output by its output index + * @param index output index + * @param datatype output datatype + * @return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED when the index is invalid + */ + ge::graphStatus SetOutputDataType(const size_t index, const ge::DataType datatype) { + const auto output_dtype = GetOutputPointer(index); + if (output_dtype == nullptr) { + return ge::GRAPH_FAILED; + } + *output_dtype = datatype; + return ge::GRAPH_SUCCESS; + } +}; +static_assert(std::is_standard_layout::value, "The class InferDataTypeContext must be a POD"); +} // namespace gert +#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_INFER_DATATYPE_CONTEXT_H_ diff --git a/inc/metadef/external/exe_graph/runtime/infer_shape_context.h b/inc/metadef/external/exe_graph/runtime/infer_shape_context.h new file mode 100644 index 0000000000000000000000000000000000000000..ec73928ee37ff04c32aeb160baca3b17ddd329b1 --- /dev/null +++ b/inc/metadef/external/exe_graph/runtime/infer_shape_context.h @@ -0,0 +1,135 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_INFER_SHAPE_CONTEXT_H_ +#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_INFER_SHAPE_CONTEXT_H_ +#include +#include "shape.h" +#include "tensor.h" +#include "runtime_attrs.h" +#include "extended_kernel_context.h" +#include "graph/inference_context.h" +namespace gert { +/** + * Indices of the extended inputs that follow the node inputs; add a new enumerator here when another one is needed + */ +enum class InputExternLayout : uint32_t { + kInferShapeFunc = 1, // only the exe runtime needs the infer shape func; the compile stage must set it to null + kInferenceContext = 2, // only resource ops in the compile stage need the inference context; the exe runtime must set it to null +}; +/** + * Context of the InferShape kernel + */ +class InferShapeContext : public ExtendedKernelContext { + public: + /** + * Get the input shape pointer by input index + * @param index input index + * @return input shape pointer; nullptr when the index is invalid + */ + const Shape *GetInputShape(const size_t index) const { + return GetInputPointer(index); + } + /** + * Get the input tensor pointer by input index + * If the operator is configured with a 'data' dependency, the returned Tensor holds a host memory address; otherwise its address is nullptr. + * @param index input index + * @return input tensor pointer; nullptr when the index is invalid + */ + const Tensor *GetInputTensor(const size_t index) const { + return GetInputPointer(index); + } + + /** + * Based on the operator IR prototype, get the tensor pointer of an `OPTIONAL_INPUT` input + * If the operator is configured with a 'data' dependency, the returned Tensor holds a host memory address; otherwise its address is nullptr. 
+ * @param ir_index index in the IR prototype + * @return tensor pointer; nullptr when the index is invalid or the input is not instantiated + */ + const Tensor *GetOptionalInputTensor(const size_t ir_index) const { + return GetDynamicInputPointer(ir_index, 0); + } + + /** + * Based on the operator IR prototype, get the shape pointer of an `OPTIONAL_INPUT` input + * @param ir_index index in the IR prototype + * @return shape pointer; nullptr when the index is invalid or the input is not instantiated + */ + const Shape *GetOptionalInputShape(const size_t ir_index) const { + return GetDynamicInputPointer(ir_index, 0); + } + + /** + * Based on the operator IR prototype, get the shape pointer of a `DYNAMIC_INPUT` input + * @param ir_index index in the IR prototype + * @param relative_index
 relative index among the instances of this input; e.g. if a DYNAMIC_INPUT is instantiated as 3 inputs, the valid range of relative_index is [0,2] + * @return shape pointer; nullptr when index or relative_index is invalid + */ + const Shape *GetDynamicInputShape(const size_t ir_index, const size_t relative_index) const { + return GetDynamicInputPointer(ir_index, relative_index); + } + + /** + * Based on the operator IR prototype, get the tensor pointer of a `DYNAMIC_INPUT` input + * If the operator is configured with a 'data' dependency, the returned Tensor holds a host memory address; otherwise its address is nullptr. + * @param ir_index index in the IR prototype + * @param relative_index relative index among the instances of this input; e.g. if a DYNAMIC_INPUT is instantiated as 3 inputs, the valid range of relative_index is [0,2] + * @return tensor pointer; nullptr when index or relative_index is invalid + */ + const Tensor *GetDynamicInputTensor(const size_t ir_index, const size_t relative_index) const { + return GetDynamicInputPointer(ir_index, relative_index); + } + + /** + * Based on the operator IR prototype, get the Tensor pointer of a `REQUIRED_INPUT` input + * If the operator is configured with a 'data' dependency, the returned Tensor holds a host memory address; otherwise its address is nullptr. 
+ * @param ir_index index in the IR prototype + * @return Tensor pointer; nullptr when the index is invalid + */ + const Tensor *GetRequiredInputTensor(const size_t ir_index) const { + return GetDynamicInputPointer(ir_index, 0); + } + /** + * Based on the operator IR prototype, get the shape pointer of a `REQUIRED_INPUT` input (NOTE(review): the original comment says the shape holds both the original and the runtime shape, yet the return type is Shape - confirm) + * @param ir_index index in the IR prototype + * @return shape pointer; nullptr when the index is invalid or the input is not instantiated + */ + const Shape *GetRequiredInputShape(const size_t ir_index) const { + return GetDynamicInputPointer(ir_index, 0); + } + + /** + * Get the output shape pointer by output index + * @param index output index + * @return output shape pointer; nullptr when the index is invalid + */ + Shape *GetOutputShape(const size_t index) { + return GetOutputPointer(index); + } + + /** + * Get the InferenceContext pointer + * @param NA + * @return the InferenceContext pointer; it is stored after the node input addresses and is used only at compile time + */ + const ge::InferenceContext *GetInferenceContextPtr() const { + const auto compute_node_info = reinterpret_cast(GetContext()->compute_node_info); + if (compute_node_info == nullptr) { + return nullptr; + } + const size_t offset = compute_node_info->GetInputsNum() + static_cast(InputExternLayout::kInferenceContext); + if (GetContext()->input_size < offset) { +
 return nullptr; + } + return GetInputPointer(offset - 1UL); + } +}; +static_assert(std::is_standard_layout::value, "The class InferShapeContext must be a POD"); +} // namespace gert +#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_INFER_SHAPE_CONTEXT_H_ diff --git a/inc/metadef/external/exe_graph/runtime/infer_shape_range_context.h b/inc/metadef/external/exe_graph/runtime/infer_shape_range_context.h new file mode 100644 index 0000000000000000000000000000000000000000..25fe7ad299b9a402a26b03a06ea9ab684194e316 --- /dev/null +++ b/inc/metadef/external/exe_graph/runtime/infer_shape_range_context.h @@ -0,0 +1,111 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_INFER_SHAPE_RANGE_CONTEXT_H_ +#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_INFER_SHAPE_RANGE_CONTEXT_H_ +#include +#include "range.h" +#include "tensor.h" +#include "runtime_attrs.h" +#include "extended_kernel_context.h" +namespace gert { +using TensorRange = Range; +/** + * InferShapeRange kernel的context + */ +class InferShapeRangeContext : public ExtendedKernelContext { + public: + /** + * 根据输入index,获取输入shape range指针 + * @param index 输入index + * @return 输入shape range指针,index非法时,返回空指针 + */ + const Range *GetInputShapeRange(const size_t index) const { + return GetInputPointer>(index); + } + + /** + * 根据输入index,获取输出tensor指针 + * 若算子被配置为'data'数据依赖,则返回的Tensor对象中保存了Host内存地址;反之,内存地址为nullptr。 + * @param index 输入index + * @return 输入tensor指针,index非法时,返回空指针 + */ + const TensorRange *GetInputTensorRange(const size_t index) const { + return GetInputPointer(index); + } + + /** + * 基于算子IR原型定义,获取`OPTIONAL_INPUT`类型的输入tensor range指针 + * 若算子被配置为'data'数据依赖,则返回的Tensor对象中保存了Host内存地址;反之,内存地址为nullptr。 + * @param ir_index IR原型定义中的index + * @return tensor range指针,index非法,或该INPUT没有实例化时,返回空指针 + */ + const TensorRange *GetOptionalInputTensorRange(const size_t ir_index) const { + return GetDynamicInputPointer(ir_index, 0); + } + + /** + * 根据输入index,获取输入shape range指针 + * @param ir_index IR原型定义中的index + * @return 输入shape range指针,ir_index非法时,返回空指针 + */ + const Range *GetRequiredInputShapeRange(const size_t ir_index) const { + return GetDynamicInputPointer>(ir_index, 0); + } + + /** + * 根据输入index,获取输出tensor指针 + * 若算子被配置为'data'数据依赖,则返回的Tensor对象中保存了Host内存地址;反之,内存地址为nullptr。 + * @param ir_index IR原型定义中的index + * @return 输入tensor range指针,ir_index非法时,返回空指针 + */ + const TensorRange *GetRequiredInputTensorRange(const size_t ir_index) const { + return GetDynamicInputPointer(ir_index, 0); + } + + /** + * 基于算子IR原型定义,获取`DYNAMIC_INPUT`类型的输入tensor指针 
+ * 若算子被配置为'data'数据依赖,则返回的Tensor对象中保存了Host内存地址;反之,内存地址为nullptr。 + * @param ir_index IR原型定义中的index + * @param relative_index 该输入实例化后的相对index,例如某个DYNAMIC_INPUT实例化了3个输入,那么relative_index的有效范围是[0,2] + * @return tensor range指针,index或relative_index非法时,返回空指针 + */ + const TensorRange *GetDynamicInputTensorRange(const size_t ir_index, const size_t relative_index) const { + return GetDynamicInputPointer(ir_index, relative_index); + } + /** + * 基于算子IR原型定义,获取`OPTIONAL_INPUT`类型的输入shape range指针 + * @param ir_index IR原型定义中的index + * @return shape range指针,index非法,或该INPUT没有实例化时,返回空指针 + */ + const Range *GetOptionalInputShapeRange(const size_t ir_index) const { + return GetDynamicInputPointer>(ir_index, 0); + } + /** + * 基于算子IR原型定义,获取`DYNAMIC_INPUT`类型的输入shape指针 + * @param ir_index IR原型定义中的index + * @param relative_index 该输入实例化后的相对index,例如某个DYNAMIC_INPUT实例化了3个输入,那么relative_index的有效范围是[0,2] + * @return shape range指针,index或relative_index非法时,返回空指针 + */ + const Range *GetDynamicInputShapeRange(const size_t ir_index, const size_t relative_index) const { + return GetDynamicInputPointer>(ir_index, relative_index); + } + + /** + * 根据输出index,获取输出shape range指针 + * @param index 输出index + * @return 输出shape range指针,index非法时,返回空指针 + */ + Range *GetOutputShapeRange(const size_t index) { + return GetOutputPointer>(index); + } +}; +static_assert(std::is_standard_layout::value, "The class InferShapeContext must be a POD"); +} // namespace gert +#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_INFER_SHAPE_RANGE_CONTEXT_H_ diff --git a/inc/metadef/external/exe_graph/runtime/kernel_context.h b/inc/metadef/external/exe_graph/runtime/kernel_context.h new file mode 100644 index 0000000000000000000000000000000000000000..424956c3f9cb2204ab987b074b9f851f0294165c --- /dev/null +++ b/inc/metadef/external/exe_graph/runtime/kernel_context.h @@ -0,0 +1,322 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_KERNEL_CONTEXT_H_ +#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_KERNEL_CONTEXT_H_ +#include +#include "kernel_run_context.h" +namespace gert { +class Chain { + public: + using Deleter = void (*)(void *); + /** + * 获取Chain中保存的数据的指针 + * @tparam T 数据类型 + * @return 指向数据的指针 + */ + template::type = 0> + const T *GetPointer() const { + return reinterpret_cast(any_value_.data.inplace); + } + /** + * 获取Chain中保存的数据的指针 + * @tparam T 数据类型 + * @return 指向数据的指针 + */ + template sizeof(void *)), int>::type = 0> + const T *GetPointer() const { + return reinterpret_cast(any_value_.data.pointer); + } + /** + * 获取Chain中保存的数据的指针 + * @tparam T 数据类型 + * @return 指向数据的指针 + */ + template::type = 0> + auto GetPointer() -> T* { + return reinterpret_cast(any_value_.data.inplace); + } + /** + * 获取Chain中保存的数据的指针 + * @tparam T 数据类型 + * @return 指向数据的指针 + */ + template sizeof(void *)), int>::type = 0> + auto GetPointer() -> T* { + return reinterpret_cast(any_value_.data.pointer); + } + /** + * 获取Chain中保存的数据的值 + * @tparam T 数据类型 + * @return 数据的值的引用 + */ + template::type = 0> + const T &GetValue() const { + return *reinterpret_cast(any_value_.data.inplace); + } + /** + * 获取Chain中保存的数据的值 + * @tparam T 数据类型 + * @return 数据的值的引用 + */ + template::type = 0> + auto GetValue() -> T& { + return *reinterpret_cast(any_value_.data.inplace); + } + /** + * 
将数据设置到Chain中。设置数据指针时,会尝试调用deleter将原有保存在Chain的数据删除 + * @param data 指向数据的指针 + * @param deleter 释放数据的接口,空指针的含义为不需要释放 + */ + void Set(void * const data, const Chain::Deleter deleter) { + FreeResource(); + any_value_.data.pointer = data; + any_value_.deleter = deleter; + } + /** + * 将数据设置到Chain中。设置数据指针时,会尝试调用deleter将原有保存在Chain的数据删除 + * @tparam T 数据的类型 + * @param data 数据的指针 + */ + template::value), int>::type = 0> + void SetWithDefaultDeleter(T *data) { + Set(data, reinterpret_cast(DefaultDeleter)); + } + /** + * 将数据设置到Chain中。设置数据指针时,会尝试调用deleter将原有保存在Chain的数据删除 + * @tparam T 数据的类型 + * @param data 数据的指针 + */ + template::type, + typename std::enable_if::value, int>::type = 0> + void SetWithDefaultDeleter(PureT *data) { + Set(data, reinterpret_cast(DefaultArrayDeleter)); + } + /** + * 判断当前Chain中保存的数据是否有deleter + * @return true代表含有deleter + */ + bool HasDeleter() const { + return any_value_.deleter != nullptr; + } + + private: + template + static void DefaultArrayDeleter(T *data) { + delete[] data; + } + + template + static void DefaultDeleter(T *data) { + delete data; + } + + void FreeResource() { + if (any_value_.deleter != nullptr) { + any_value_.deleter(any_value_.data.pointer); + } + } + + AsyncAnyValue any_value_; +}; +static_assert(std::is_standard_layout::value, "The class Chain must be a POD"); + +class KernelContext { + public: + /** + * 获取kernel的输入数量 + * @return kernel的输入数量 + */ + size_t GetInputNum() const { + return context_.input_size; + } + /** + * 获取kernel的输出数量 + * @return kernel的输出数量 + */ + size_t GetOutputNum() const { + return context_.output_size; + } + /** + * 获取输入的Chain指针 + * @param i kernel的输入index + * @return 输入Chain的指针 + */ + const Chain *GetInput(const size_t i) const { + if (i >= context_.input_size) { + return nullptr; + } + return reinterpret_cast(context_.values[i]); + } + /** + * 获取输入的Chain指针 + * @param i kernel的输入index + * @return 输入Chain的指针 + */ + Chain *MutableInput(const size_t i) const { + if (i >= context_.input_size) { + return nullptr; 
+ } + return reinterpret_cast(context_.values[i]); + } + /** + * 获取输出的Chain指针 + * @param i kernel的输出index + * @return 输出Chain的指针 + */ + Chain *GetOutput(const size_t i) { + if (i >= context_.output_size) { + return nullptr; + } + return reinterpret_cast(context_.values[context_.input_size + i]); + } + /** + * 获取输出的Chain指针 + * @param i kernel的输出index + * @return 输出Chain的指针 + */ + const Chain *GetOutput(const size_t i) const { + if (i >= context_.output_size) { + return nullptr; + } + return reinterpret_cast(context_.values[context_.input_size + i]); + } + Chain *GetOutput2(const size_t i) { + if (i >= context_.output_size) { + return nullptr; + } + return reinterpret_cast(context_.output_start[i]); + } + /** + * 获取输入数据的值,本函数首先获取输入Chain,然后从输入Chain中获取值 + * @tparam T 值的类型 + * @param i kernel的输入index + * @return 输入的值 + */ + template + const T GetInputValue(size_t i) const { + const auto av = GetInput(i); + if (av == nullptr) { + return {}; + } + return av->GetValue(); + } + /** + * 获取输入数据的指针,本函数首先获取输入Chain,然后从输入Chain中获取指针 + * @tparam T 值的类型 + * @param i kernel的输入index + * @return 输入数据的指针 + */ + template + const T *GetInputPointer(size_t i) const { + const auto av = GetInput(i); + if (av == nullptr) { + return nullptr; + } + return av->GetPointer(); + } + /** + * 获取输入数据的指针,本函数首先获取输入Chain,然后从输入Chain中获取指针 + * @tparam T 值的类型 + * @param i kernel的输入index + * @return 输入数据的指针 + */ + template + auto MutableInputPointer(size_t i) const -> T* { + const auto av = MutableInput(i); + if (av == nullptr) { + return nullptr; + } + return av->GetPointer(); + } + /** + * 获取输入字符串的指针,本函数首先获取输入Chain,然后从输入Chain中获取指针 + * + * todo 特化一个模板就可以了 + * @param i kernel的输入index + * @return 字符串的指针 + */ + const char *GetInputStrPointer(const size_t i) const { + const auto av = GetInput(i); + if (av == nullptr) { + return nullptr; + } + return av->GetValue(); + } + /** + * 获取计算节点信息的指针 + * @return 计算节点信息的指针 + */ + const void *GetComputeNodeExtend() const { + return context_.compute_node_info; + } + /** + * 
获取kernel扩展信息的指针 + * @return + */ + const void *GetKernelExtend() const { + return context_.kernel_extend_info; + } + /** + * 获取输出数据的指针,本函数首先获取输出Chain,然后从Chain中获取指针 + * @tparam T 数据的类型 + * @param i kernel的输出index + * @return 输出数据的指针 + */ + template + auto GetOutputPointer(size_t i) -> T* { + const auto av = GetOutput(i); + if (av == nullptr) { + return nullptr; + } + return av->GetPointer(); + } + /** + * 获取输出数据的指针,本函数首先获取输出Chain,然后从Chain中获取指针 + * @tparam T 数据的类型 + * @param i kernel的输出index + * @return 输出数据的指针 + */ + template + const T *GetOutputPointer(size_t i) const { + const auto av = GetOutput(i); + if (av == nullptr) { + return nullptr; + } + return av->GetPointer(); + } + /** + * 获取底层的context结构体,非框架代码请勿直接操作此结构体 + * @return 指向context结构体的指针 + */ + KernelRunContext *GetContext() { + return &context_; + } + /** + * 获取底层的context结构体,非框架代码请勿直接操作此结构体 + * @return 指向context结构体的指针 + */ + const KernelRunContext *GetContext() const { + return &context_; + } + /** + * 根据数据的长度判断一个数据是否会被Inline存储,所谓Inline存储是指此数据保存在context中时不需要单独分配内存 + * @param size 数据的长度 + * @return true代表会被inline存储 + */ + static bool IsInlineSize(const size_t size) { + return size <= sizeof(void *); + } + + private: + KernelRunContext context_; +}; +static_assert(std::is_standard_layout::value, "The class KernelContext must be a POD"); +} // namespace gert +#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_KERNEL_CONTEXT_H_ diff --git a/inc/metadef/external/exe_graph/runtime/kernel_run_context.h b/inc/metadef/external/exe_graph/runtime/kernel_run_context.h new file mode 100644 index 0000000000000000000000000000000000000000..5aa401a2de3ffba8f2899e5af8eab134f63f7342 --- /dev/null +++ b/inc/metadef/external/exe_graph/runtime/kernel_run_context.h @@ -0,0 +1,48 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. 
You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_KERNEL_RUN_CONTEXT_H_ +#define METADEF_CXX_INC_EXE_GRAPH_KERNEL_RUN_CONTEXT_H_ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +typedef void (*FreeCallback)(void *); + +/** + * Chain的底层数据结构,不要直接引用和操作此数据结构 + */ +typedef struct { + union { + void *pointer; + unsigned char inplace[sizeof(void *)]; + } data; + FreeCallback deleter; +} AsyncAnyValue; + +/** + * KernelContext的底层数据结构,不要直接引用和操作此数据结构 + */ +typedef struct { + size_t input_size; + size_t output_size; + const void *compute_node_info; + const void *kernel_extend_info; + AsyncAnyValue **output_start; // todo delete this + AsyncAnyValue *values[1]; +} KernelRunContext; + +#ifdef __cplusplus +} +#endif + +#endif // METADEF_CXX_INC_EXE_GRAPH_KERNEL_RUN_CONTEXT_H_ diff --git a/inc/metadef/external/exe_graph/runtime/op_execute_context.h b/inc/metadef/external/exe_graph/runtime/op_execute_context.h new file mode 100644 index 0000000000000000000000000000000000000000..4035d2adea5237d0bb3eb2efab63fc4b842eb55e --- /dev/null +++ b/inc/metadef/external/exe_graph/runtime/op_execute_context.h @@ -0,0 +1,209 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_OP_EXECUTE_CONTEXT_H_ +#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_OP_EXECUTE_CONTEXT_H_ +#include +#include "exe_graph/runtime/shape.h" +#include "exe_graph/runtime/tensor.h" +#include "exe_graph/runtime/runtime_attrs.h" +#include "exe_graph/runtime/extended_kernel_context.h" +#include "ge/ge_allocator.h" + +namespace gert { +using rtStream = void *; +struct OpExecuteOptions { + int32_t precision_mode; // 精度模式 + int32_t deterministic; // 确定性计算 + char allow_hf32[3UL]; // hf32 + char reserve[53]; // 预留 +}; + +enum class OpExecuteInputExtendIndex{ + kAllocate, + kStream, + kExecuteOption, + kExecuteFunc, + // add new extend input here + kNum +}; + +enum class OpExecuteOutputIndex{ + kBlockMemory, + // add new extend output here + kNum +}; + +/** + * Aclnn kernel的context + */ +class OpExecuteContext : public ExtendedKernelContext { +public: + /** + * 根据输入index,获取输出tensor指针 + * @param index 输入index + * @return 输入tensor指针,index非法时,返回空指针 + */ + const Tensor *GetInputTensor(const size_t index) const { + return GetInputPointer(index); + } + /** + * 基于算子IR原型定义,获取`OPTIONAL_INPUT`类型的输入tensor指针 + * @param ir_index IR原型定义中的index + * @return tensor指针,index非法,或该INPUT没有实例化时,返回空指针 + */ + const Tensor *GetOptionalInputTensor(const size_t ir_index) const { + return GetDynamicInputPointer(ir_index, 0); + } + /** + * 基于算子IR原型定义,获取`DYNAMIC_INPUT`类型的输入Tensor指针 + * @param ir_index IR原型定义中的index + * @param relative_index 该输入实例化后的相对index,例如某个DYNAMIC_INPUT实例化了3个输入,那么relative_index的有效范围是[0,2] + * @return 
tensor指针,index或relative_index非法时,返回空指针 + */ + const Tensor *GetDynamicInputTensor(const size_t ir_index, const size_t relative_index) const { + return GetDynamicInputPointer(ir_index, relative_index); + } + + /** + * 根据输出index,获取输出tensor指针 + * @param index 输出index + * @return 输出tensor指针,index非法时,返回空指针 + */ + const Tensor *GetOutputTensor(const size_t index) const { + const size_t input_num = GetComputeNodeInputNum(); + return GetInputPointer(input_num + index); + } + + /** + * 基于算子IR原型定义,获取`DYNAMIC_OUTPUT`类型的输入Tensor指针 + * @param ir_index IR原型定义中的index + * @param relative_index 该输入实例化后的相对index,例如某个DYNAMIC_OUTPUT实例化了3个输入,那么relative_index的有效范围是[0,2] + * @return tensor指针,index或relative_index非法时,返回空指针 + */ + const Tensor *GetDynamicOutputTensor(const size_t ir_index, + const size_t relative_index) const { + return GetDynamicOutputPointer(ir_index, relative_index); + } + + /** + * 基于算子IR原型定义,获取`REQUIRED_INPUT`类型的输入Tensor指针 + * @param ir_index IR原型定义中的index + * @return Tensor指针,index非法时,返回空指针 + */ + const Tensor *GetRequiredInputTensor(const size_t ir_index) const { + return GetDynamicInputPointer(ir_index, 0); + } + + /** + * 基于算子IR原型定义,获取`REQUIRED_OUTPUT`类型的输入Tensor指针 + * @param ir_index IR原型定义中的index + * @return Tensor指针,index非法时,返回空指针 + */ + const Tensor *GetRequiredOutputTensor(const size_t ir_index) const { + return GetDynamicOutputPointer(ir_index, 0); + } + + /** + * 获取stream + * @return rtStream, aclnn算子下发的流, 异常情况返回nullptr + */ + rtStream GetStream() const { + const size_t input_num = GetComputeNodeInputNum(); + const size_t output_num = GetComputeNodeOutputNum(); + auto stream = + GetInputPointer(input_num + output_num + + static_cast(OpExecuteInputExtendIndex::kStream)); + if (stream == nullptr) { + return nullptr; + } + return *stream; + } + + /** + * 获取aclnn接口 + * @return void *, aclnn接口指针, 异常情况返回nullptr + */ + void *GetOpExecuteFunc() const { + const size_t input_num = GetComputeNodeInputNum(); + const size_t output_num = GetComputeNodeOutputNum(); + auto 
op_execute_func = + GetInputPointer(input_num + output_num + + static_cast(OpExecuteInputExtendIndex::kExecuteFunc)); + if (op_execute_func == nullptr) { + return nullptr; + } + return *op_execute_func; + } + + /** + * 申请workspace内存大小 + * @param size 申请内存的大小 + * @return void *,内存地址,异常情况返回nullptr + */ + void *MallocWorkspace(const size_t size); + + /** + * 释放workspace内存 + */ + void FreeWorkspace(); + + /** + * 获取确定性计算模式 + * @return bool,是否开启确定性计算, 异常情况默认返回false + */ + bool GetDeterministic() const { + const size_t input_num = GetComputeNodeInputNum(); + const size_t output_num = GetComputeNodeOutputNum(); + const OpExecuteOptions *options = + GetInputPointer(input_num + output_num + + static_cast(OpExecuteInputExtendIndex::kExecuteOption)); + if (options == nullptr) { + return false; + } + return (options->deterministic != 0); + } + + /** + * 获取allow_hf32 + * @return string,是否开启hf32,正常情况返回 01,00,10,11四种字符串 + * 第一个字符表示Conv类算子是否支持hf32 + * 第二个字符表示MatMul类算子是否支持hf32,异常情况返回nullptr + */ + const char *GetAllowHf32() const { + const size_t input_num = GetComputeNodeInputNum(); + const size_t output_num = GetComputeNodeOutputNum(); + const OpExecuteOptions *options = + GetInputPointer(input_num + output_num + + static_cast(OpExecuteInputExtendIndex::kExecuteOption)); + if (options == nullptr) { + return nullptr; + } + return options->allow_hf32; + } + + /** + * 获取精度模式 + * @return int32,精度模式,异常情况返回一个int32的极大值 + */ + int32_t GetPrecisionMode() const { + const size_t input_num = GetComputeNodeInputNum(); + const size_t output_num = GetComputeNodeOutputNum(); + const OpExecuteOptions *options = + GetInputPointer(input_num + output_num + + static_cast(OpExecuteInputExtendIndex::kExecuteOption)); + if (options == nullptr) { + return std::numeric_limits::max(); + } + return options->precision_mode; + } +}; +static_assert(std::is_standard_layout::value, "The class OpExecuteContext must be a POD"); +} // namespace gert +#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_OP_EXECUTE_CONTEXT_H_ 
diff --git a/inc/metadef/external/exe_graph/runtime/range.h b/inc/metadef/external/exe_graph/runtime/range.h new file mode 100644 index 0000000000000000000000000000000000000000..7f666cb365cb056d9cf541bc45586e5cb6a90d89 --- /dev/null +++ b/inc/metadef/external/exe_graph/runtime/range.h @@ -0,0 +1,114 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_RANGE_H_ +#define METADEF_CXX_INC_EXE_GRAPH_RANGE_H_ + +#include +#include +#include "utils/extern_math_util.h" +#include "shape.h" + +namespace gert { +template +class Range { + public: + /** + * 默认构造一个range + */ + Range() : min_(nullptr), max_(nullptr) { + (void)memset(reserved_, 0, sizeof(reserved_)); // memset函数misra告警屏蔽 + } + + /** + * 通过最小和最大T对象指针构造range + * @param min range最小值指针 + * @param max range最大值指针 + */ + Range(T *min, T* max) : min_(min), max_(max) { + (void)memset(reserved_, 0, sizeof(reserved_)); // memset函数misra告警屏蔽 + } + + /** + * 通过一个T指针构造range,表示最大最小值相同 + * @param same_ele T指针 + */ + explicit Range(T *same_ele) : min_(same_ele), max_(same_ele) { + (void)memset(reserved_, 0, sizeof(reserved_)); // memset函数misra告警屏蔽 + } + + /** + * 判断与另外一个range对象是否相等,如果两个range的最小和最大元素指针分别相等,那么认为两个range相等, + * 如果存在指针不相等的情况,再对比T对象是否分别相等 + * @param rht 另一个Range对象 + * @return true/false + */ + bool operator==(const Range 
&rht) const { + if ((this->min_ == rht.min_) && (this->max_ == rht.max_)) { + return true; + } else { + return (*this->min_ == *rht.min_) && (*this->max_ == *rht.max_); + } + } + + /** + * 设置最小的T对象指针 + * @param min 最小的T对象指针 + */ + void SetMin(T *min) { + min_ = min; + } + + /** + * 设置最大的T对象指针 + * @param max 最大的T对象指针 + */ + void SetMax(T *max) { + max_ = max; + } + + /** + * 获取最小的T对象指针 + * @return + */ + const T *GetMin() const { + return min_; + } + + /** + * 获取最大的T对象指针 + * @return + */ + const T *GetMax() const { + return max_; + } + + /** + * 获取最小的T对象指针 + * @return + */ + T *GetMin() { + return min_; + } + + /** + * 获取最大的T对象指针 + * @return + */ + T *GetMax() { + return max_; + } + private: + T *min_; + T *max_; + uint8_t reserved_[40]; // Reserved field, 32+8, do not directly use when only 8-byte left +}; +} // namespace gert + +#endif // METADEF_CXX_INC_EXE_GRAPH_RANGE_H_ diff --git a/inc/metadef/external/exe_graph/runtime/runtime_attrs.h b/inc/metadef/external/exe_graph/runtime/runtime_attrs.h new file mode 100644 index 0000000000000000000000000000000000000000..7308ed020fa960e4e59d76d5052e1e08b79698a9 --- /dev/null +++ b/inc/metadef/external/exe_graph/runtime/runtime_attrs.h @@ -0,0 +1,125 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_ATTRS_H_ +#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_ATTRS_H_ +#include +#include +#include +#include "continuous_vector.h" +#include "tensor.h" + +namespace gert { +class RuntimeAttrs { + public: + /** + * 获取属性 + * @tparam T 属性类型 + * @param index 属性index + * @return 指向属性的指针 + */ + template + const T *GetAttrPointer(size_t index) const { + return reinterpret_cast(GetPointerByIndex(index)); + } + /** + * 获取int类型属性 + * @param index 属性index + * @return 指向属性值的指针 + */ + const int64_t *GetInt(const size_t index) const { + return GetAttrPointer(index); + } + /** + * 获取list int类型属性 + * @param index 属性index + * @return 指向属性值的指针 + */ + const TypedContinuousVector *GetListInt(const size_t index) const { + return GetAttrPointer>(index); + } + /** + * 获取list int类型属性 + * @param index 属性index + * @return 指向属性值的指针 + */ + const ContinuousVectorVector *GetListListInt(const size_t index) const { + return GetAttrPointer(index); + } + /** + * 获取string类型属性 + * @param index 属性index + * @return 指向属性值的指针 + */ + const char *GetStr(const size_t index) const { + return GetAttrPointer(index); + } + /** + * 获取tensor类型属性 + * @param index 属性index + * @return 指向属性值的指针 + */ + const Tensor *GetTensor(const size_t index) const { + return GetAttrPointer(index); + } + /** + * 获取float类型属性 + * @param index 属性index + * @return 指向属性值的指针 + */ + const float *GetFloat(const size_t index) const { + return GetAttrPointer(index); + } + /** + * 获取bool类型属性 + * @param index 属性index + * @return 指向属性值的指针 + */ + const bool *GetBool(const size_t index) const { + return GetAttrPointer(index); + } + + /** + * 获取list_float32类型属性 + * @param index 属性index + * @return 指向属性值的指针 + */ + const TypedContinuousVector *GetListFloat(const size_t index) const { + return GetAttrPointer>(index); + } + + /** + * 获取list_float32类型属性 + * @param index 属性index + * @return 
指向属性值的指针 + */ + const ContinuousVectorVector *GetListListFloat(const size_t index) const { + return GetAttrPointer(index); + } + + /** + * 获取属性数量 + * @return 属性数量 + */ + size_t GetAttrNum() const; + + RuntimeAttrs() = delete; + RuntimeAttrs(const RuntimeAttrs &) = delete; + RuntimeAttrs(RuntimeAttrs &&) = delete; + RuntimeAttrs &operator=(const RuntimeAttrs &) = delete; + RuntimeAttrs &operator=(RuntimeAttrs &&) = delete; + + private: + const void *GetPointerByIndex(size_t index) const; + + uint64_t placeholder_; +}; +static_assert(std::is_standard_layout::value, "This class must be a POD"); +} // namespace gert +#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_ATTRS_H_ diff --git a/inc/metadef/external/exe_graph/runtime/shape.h b/inc/metadef/external/exe_graph/runtime/shape.h new file mode 100644 index 0000000000000000000000000000000000000000..fcdc39d586fdcca0cbd60a891ecd1ca3325c4908 --- /dev/null +++ b/inc/metadef/external/exe_graph/runtime/shape.h @@ -0,0 +1,216 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_SHAPE_H_ +#define METADEF_CXX_INC_EXE_GRAPH_SHAPE_H_ + +#include +#include +#include +#include +#include +#include +#include "utils/extern_math_util.h" + +namespace gert { +struct Shape { + public: + static constexpr size_t kMaxDimNum = 25; + static constexpr int64_t kInvalidDimValue = std::numeric_limits::min(); + + public: + /** + * 默认构造一个shape,默认构造的shape实例中,dim_num长度为0 + */ + Shape() : dim_num_(0), dims_{0} { + (void)memset(reserved_, 0, sizeof(reserved_)); // memset函数misra告警屏蔽 + } + + /** + * 通过dims值构造shape,例如:Shape({8,3,224,224})创建一个Shape实例,有4个维度,每个维度的值分别是8,3,224,224 + * @param dims shape的所有dim值 + */ + Shape(const std::initializer_list &args) : Shape() { + if (args.size() > kMaxDimNum) { + return; + } + dim_num_ = args.size(); + size_t i = 0; + for (const int64_t arg : args) { + dims_[i++] = arg; + } + } + + /** + * 拷贝构造 + * @param other 源对象 + * 为了提升性能,dims_超过dim_num_的空间没有拷贝,可能有脏数据 + */ + Shape(const Shape &other) { + dim_num_ = other.dim_num_; + for (size_t i = 0U; i < dim_num_; ++i) { + dims_[i] = other.dims_[i]; + } + (void)memset(reserved_, 0, sizeof(reserved_)); // memset函数misra告警屏蔽 + } + + /** + * 拷贝赋值 + * @param other + * @return + * 为了提升性能,dims_超过dim_num_的空间没有拷贝,可能有脏数据 + */ + Shape &operator=(const Shape &other) { + if (&other != this) { + dim_num_ = other.dim_num_; + for (size_t i = 0U; i < dim_num_; ++i) { + dims_[i] = other.dims_[i]; + } + } + (void)memset(reserved_, 0, sizeof(reserved_)); // memset函数misra告警屏蔽 + return *this; + } + + /** + * 判断与另外一个shape对象是否相等,如果两个shape的dim num并且dim num内每个dim的值都相等,那么认为两个shape相等 + * @param rht 另一个Shape对象 + * @return true/false + */ + bool operator==(const Shape &rht) const { + if (this->dim_num_ != rht.dim_num_) { + return false; + } + for (size_t i = 0; i < this->dim_num_; i++) { + if (this->dims_[i] != rht.dims_[i]) { + return false; + } + } + return true; + 
} + + /** + * 判断与另一个Shape对象是否不等 + * @param rht 另一个Shape对象 + * @return true/false + */ + bool operator!=(const Shape &rht) const { + return !(*this == rht); + } + + /** + * 获取shape size,所谓shape size,是指shape中有多少个元素 + * @return shape-size,在发生溢出时,返回`kInvalidDimValue` + */ + int64_t GetShapeSize() const { + int64_t shape_size = 1; + for (size_t i = 0; i < dim_num_; ++i) { + if (ge::MulOverflow(shape_size, dims_[i], shape_size)) { + return kInvalidDimValue; + } + } + return shape_size; + } + + /** + * 判断本shape是否为标量,所谓标量,是指GetDimNum()为0 + * @return true/false + */ + bool IsScalar() const { + return dim_num_ == 0L; + } + + /** + * 设置shape为标量 + * @param none + */ + void SetScalar() { + dim_num_ = 0; + } + + /** + * 获取dim num + * @return + */ + size_t GetDimNum() const { + return dim_num_; + } + + /** + * 设置dim num + * @param dim_num + */ + void SetDimNum(const size_t dim_num) { + this->dim_num_ = dim_num; + } + + /** + * 获取dim值 + * @param idx dim的index,调用者需要保证index合法 + * @return dim值,在idx超出MaxDimNum时,返回`kInvalidDimValue` + */ + int64_t GetDim(const size_t idx) const { + if (idx >= kMaxDimNum) { + return kInvalidDimValue; + } + return dims_[idx]; + } + + /** + * 获取dim值 + * @param idx dim的index,调用者需要保证index合法 + * @return dim值,行为未定义 + */ + const int64_t &operator[](const size_t idx) const { + return dims_[idx]; + } + + /** + * 获取dim值 + * @param idx dim的index,调用者需要保证index合法 + * @return dim值,在idx超出MaxDimNum时,行为未定义 + */ + int64_t &operator[](const size_t idx) { + return dims_[idx]; + } + + /** + * 设置dim值 + * @param idx dim的index,调用者需要保证index合法 + * @return + */ + void SetDim(size_t idx, const int64_t dim_value) { + if (idx >= kMaxDimNum) { + return; + } + dims_[idx] = dim_value; + this->dim_num_ = (this->dim_num_ < idx) ? 
idx : this->dim_num_; + } + + /** + * 向后扩展一个dim值,如果扩展的dim数量超出Shape的最大限制,那么本函数不做任何事情 + * @param 扩展的dim值 + * @return this引用 + */ + Shape& AppendDim(const int64_t value) { + if (this->dim_num_ >= kMaxDimNum) { + return *this; + } + dims_[this->dim_num_++] = value; + return *this; + } + + private: + size_t dim_num_; + int64_t dims_[kMaxDimNum]; + uint8_t reserved_[40]; // Reserved field, 32+8, do not directly use when only 8-byte left +}; +static_assert(std::is_standard_layout::value, "The class Shape must be a POD"); +} // namespace gert + +#endif // METADEF_CXX_INC_EXE_GRAPH_SHAPE_H_ diff --git a/inc/metadef/external/exe_graph/runtime/storage_format.h b/inc/metadef/external/exe_graph/runtime/storage_format.h new file mode 100644 index 0000000000000000000000000000000000000000..de24a6a9d3897778924ea2d7cb75c7cb897d6896 --- /dev/null +++ b/inc/metadef/external/exe_graph/runtime/storage_format.h @@ -0,0 +1,107 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_STORAGE_FORMAT_H_ +#define METADEF_CXX_INC_EXE_GRAPH_STORAGE_FORMAT_H_ +#include +#include "graph/types.h" +#include "expand_dims_type.h" +namespace gert { +struct StorageFormat { + public: + + StorageFormat() { + (void)memset(reserved_, 0, sizeof(reserved_)); // memset函数misra告警屏蔽 + }; + /** + * 构造一个格式,格式包括原始格式、运行时格式、补维规则 + * @param origin_format 原始格式 + * @param storage_format 运行时格式 + * @param expand_dims_type 补维规则 + */ + StorageFormat(const ge::Format origin_format, const ge::Format storage_format, const ExpandDimsType &expand_dims_type) + : origin_format_(origin_format), storage_format_(storage_format), expand_dims_type_(expand_dims_type) { + (void)memset(reserved_, 0, sizeof(reserved_)); // memset函数misra告警屏蔽 + } + /** + * 获取原始format + * @return 原始format + */ + ge::Format GetOriginFormat() const { + return origin_format_; + } + /** + * 设置原始format + * @param origin_format 原始format + */ + void SetOriginFormat(const ge::Format origin_format) { + origin_format_ = origin_format; + } + /** + * 获取运行时format + * @return 运行时format + */ + ge::Format GetStorageFormat() const { + return storage_format_; + } + /** + * 设置运行时format + * @param storage_format 运行时format + */ + void SetStorageFormat(const ge::Format storage_format) { + storage_format_ = storage_format; + } + /** + * 获取补维规则 + * @return 补维规则 + */ + ExpandDimsType GetExpandDimsType() const { + return expand_dims_type_; + } + /** + * 设置补维规则 + * @param expand_dims_type 补维规则 + */ + void SetExpandDimsType(const ExpandDimsType &expand_dims_type) { + expand_dims_type_ = expand_dims_type; + } + /** + * 获取可写的补维规则 + * @return 补维规则引用 + */ + ExpandDimsType &MutableExpandDimsType() { + return expand_dims_type_; + } + /** + * 判断格式是否相等 + * @param other 另一个格式 + * @return true代表相等 + */ + bool operator==(const StorageFormat &other) const { + return origin_format_ == 
other.origin_format_ && storage_format_ == other.storage_format_ && + expand_dims_type_ == other.expand_dims_type_; + } + /** + * 判断格式是否不相等 + * @param other 另一个格式 + * @return true代表不相等 + */ + bool operator!=(const StorageFormat &other) const { + return !(*this == other); + } + + private: + ge::Format origin_format_; + ge::Format storage_format_; + ExpandDimsType expand_dims_type_; + uint8_t reserved_[40]; // Reserved field, 32+8, do not directly use when only 8-byte left +}; +static_assert(std::is_standard_layout::value, "The class StorageFormat must be a POD"); +} // namespace gert +#endif // METADEF_CXX_INC_EXE_GRAPH_STORAGE_FORMAT_H_ diff --git a/inc/metadef/external/exe_graph/runtime/storage_shape.h b/inc/metadef/external/exe_graph/runtime/storage_shape.h new file mode 100644 index 0000000000000000000000000000000000000000..b239ca19f8c91966d9837bde5cd060906481da26 --- /dev/null +++ b/inc/metadef/external/exe_graph/runtime/storage_shape.h @@ -0,0 +1,82 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_STORAGE_SHAPE_H_ +#define METADEF_CXX_INC_EXE_GRAPH_STORAGE_SHAPE_H_ +#include +#include "shape.h" + +namespace gert { +struct StorageShape { + public: + StorageShape() { + (void)memset(reserved_, 0, sizeof(reserved_)); // memset函数misra告警屏蔽 + }; + /** + * 构造一个运行时shape实例 + * @param origin_shape 原始shape + * @param storage_shape 运行时shape + */ + StorageShape(const std::initializer_list &origin_shape, const std::initializer_list &storage_shape) + : origin_shape_(origin_shape), storage_shape_(storage_shape) { + (void)memset(reserved_, 0, sizeof(reserved_)); // memset函数misra告警屏蔽 + } + /** + * 获取原始shape + * @return 原始shape + */ + const Shape &GetOriginShape() const { + return origin_shape_; + } + /** + * 获取运行时shape + * @return 运行时shape + */ + const Shape &GetStorageShape() const { + return storage_shape_; + } + /** + * 获取可写的原始shape + * @return 可写的原始shape + */ + Shape &MutableOriginShape() { + return origin_shape_; + } + /** + * 获取可写的运行时shape + * @return 可写的运行时shape + */ + Shape &MutableStorageShape() { + return storage_shape_; + } + /** + * 判断shape是否相等 + * @param other 另一个shape + * @return true表示相等 + */ + bool operator==(const StorageShape &other) const { + return origin_shape_ == other.origin_shape_ && storage_shape_ == other.storage_shape_; + } + /** + * 判断shape是否不相等 + * @param other 另一个shape + * @return true表示不相等 + */ + bool operator!=(const StorageShape &other) const { + return !(*this == other); + } + private: + Shape origin_shape_; + Shape storage_shape_; + uint8_t reserved_[40]; // Reserved field, 32+8, do not directly use when only 8-byte left +}; +static_assert(std::is_standard_layout::value, "The class must be a POD"); +} // namespace gert + +#endif // METADEF_CXX_INC_EXE_GRAPH_STORAGE_SHAPE_H_ diff --git a/inc/metadef/external/exe_graph/runtime/tensor.h b/inc/metadef/external/exe_graph/runtime/tensor.h new 
file mode 100644 index 0000000000000000000000000000000000000000..82504595f5994c41879ea62093e3a29f0dd6c5fa --- /dev/null +++ b/inc/metadef/external/exe_graph/runtime/tensor.h @@ -0,0 +1,317 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_TENSOR_H_ +#define METADEF_CXX_INC_EXE_GRAPH_TENSOR_H_ + +#include + +#include "graph/ge_error_codes.h" +#include "storage_shape.h" +#include "storage_format.h" +#include "tensor_data.h" + +namespace gert { +using TensorAddress = void *; ///< Tensor地址 +using ConstTensorAddress = void *const; ///< Tensor地址 + +class Tensor { + public: + // memse函数misra告警屏蔽 + Tensor() { + (void) memset(reserved_, 0, sizeof(reserved_)); + (void) memset(reserved_field_, 0, sizeof(reserved_field_)); + } + Tensor(const StorageShape &storage_shape, const StorageFormat &storage_format, const TensorPlacement placement, + const ge::DataType data_type, TensorAddress addr) + : Tensor(storage_shape, storage_format, placement, data_type, addr, nullptr) {} + Tensor(const StorageShape &storage_shape, const StorageFormat &storage_format, ge::DataType data_type) + : storage_shape_(storage_shape), storage_format_(storage_format), data_type_(data_type) { + (void) memset(reserved_, 0, sizeof(reserved_)); + (void) memset(reserved_field_, 0, sizeof(reserved_field_)); + 
} + Tensor(const StorageShape &storage_shape, const StorageFormat &storage_format, const TensorPlacement placement, + const ge::DataType data_type, TensorAddress addr, TensorAddrManager manager) + : storage_shape_(storage_shape), storage_format_(storage_format), data_type_(data_type), + tensor_data_(addr, manager, static_cast(ge::GetSizeInBytes(GetShapeSize(), data_type_)), placement) { + (void) memset(reserved_, 0, sizeof(reserved_)); + (void) memset(reserved_field_, 0, sizeof(reserved_field_)); + } + /** + * 获取shape size,所谓shape size是指本shape中包含的element数量 + * @return shape size + */ + int64_t GetShapeSize() const { + return storage_shape_.GetStorageShape().GetShapeSize(); + } + /** + * 获取Tensor的数据地址 + * @tparam T 数据类型 + * @return 数据地址 + */ + template + const T *GetData() const { + return static_cast(GetAddr()); + } + /** + * 获取Tensor的数据地址 + * @tparam T 数据类型 + * @return 数据地址 + */ + template + auto GetData() -> T* { + return static_cast(GetAddr()); + } + /** + * 设置Tensor数据 + * @param data 数据 + */ + void SetData(TensorData &&data) { + tensor_data_ = std::move(data); + } + /** + * 获取Tensor的数据地址 + * @return 数据地址 + */ + const void *GetAddr() const { + if (tensor_data_.GetPlacement() == kFollowing) { + return reinterpret_cast(reinterpret_cast(this) + sizeof(*this)); + } else { + return tensor_data_.GetAddr(); + } + } + /** + * 获取Tensor的数据地址 + * @return 数据地址 + */ + void *GetAddr() { + if (tensor_data_.GetPlacement() == kFollowing) { + return reinterpret_cast(reinterpret_cast(this) + sizeof(*this)); + } else { + return tensor_data_.GetAddr(); + } + } + /** + * 获取Tensor的内存大小 + * @return 内存大小 + */ + size_t GetSize() const { + return tensor_data_.GetSize(); + } + + /** + * 设置Tensor的内存大小 + * @param Tensor的内存大小 + */ + void SetSize(const size_t size) { + tensor_data_.SetSize(size); + } + + /** + * 获取Tensor的data type + * @return data type + */ + ge::DataType GetDataType() const { + return data_type_; + } + /** + * 设置Tensor的data type + * @param data_type data type + */ + void 
SetDataType(const ge::DataType data_type) { + data_type_ = data_type; + } + /** + * 创建一个Tensor,tensor数据在Tensor对象后面连续排布 + * @param shape_size 元素个数 + * @param dt 数据类型 + * @param total_size 创建出的Tensor在内存中的长度 + * @return 创建出的Tensor指针 + */ + static std::unique_ptr CreateFollowing(const int64_t shape_size, const ge::DataType dt, + size_t &total_size) { + total_size = static_cast(ge::GetSizeInBytes(shape_size, dt)); + return NewFollowingTensor(dt, total_size); + } + /** + * 创建一个Tensor,tensor数据在Tensor对象后面连续排布 + * @param tensor_size tensor长度 + * @param dt 数据类型 + * @param total_size 创建出的Tensor在内存中的长度 + * @return 创建出的Tensor指针 + */ + static std::unique_ptr CreateFollowing(const ge::DataType dt, const size_t tensor_size, + size_t &total_size) { + total_size = tensor_size; + if (ge::AddOverflow(total_size, sizeof(Tensor), total_size)) { + return nullptr; + } + auto holder = std::unique_ptr(new (std::nothrow) uint8_t[total_size]); + if (holder == nullptr) { + return nullptr; + } + auto tensor = reinterpret_cast(holder.get()); + new (holder.get()) Tensor({}, {}, dt); + tensor->SetPlacement(kFollowing); + tensor->tensor_data_ = TensorData(nullptr, nullptr, total_size - sizeof(Tensor), kFollowing); + return holder; + } + /** + * 获取运行shape + * @return 只读的运行时shape引用 + */ + const Shape &GetStorageShape() const { + return storage_shape_.GetStorageShape(); + } + /** + * 获取运行shape + * @return 运行时shape的引用 + */ + Shape &MutableStorageShape() { + return storage_shape_.MutableStorageShape(); + } + /** + * 获取原始shape + * @return 只读的原始shape引用 + */ + const Shape &GetOriginShape() const { + return storage_shape_.GetOriginShape(); + } + /** + * 获取原始shape + * @return 原始shape引用 + */ + Shape &MutableOriginShape() { + return storage_shape_.MutableOriginShape(); + } + /** + * 获取shape,包含运行和原始shape + * @return 只读的shape引用 + */ + const StorageShape &GetShape() const { + return storage_shape_; + } + /** + * 获取shape,包含运行和原始shape + * @return shape引用 + */ + StorageShape &GetShape() { + return storage_shape_; + 
} + /** + * 获取运行时format + * @return 运行时format + */ + ge::Format GetStorageFormat() const { + return storage_format_.GetStorageFormat(); + } + /** + * 设置运行时format + * @param storage_format 运行时format + */ + void SetStorageFormat(const ge::Format storage_format) { + storage_format_.SetStorageFormat(storage_format); + } + /** + * 获取原始format + * @return 原始format + */ + ge::Format GetOriginFormat() const { + return storage_format_.GetOriginFormat(); + } + /** + * 设置原始format + * @param origin_format 原始format + */ + void SetOriginFormat(const ge::Format origin_format) { + storage_format_.SetOriginFormat(origin_format); + } + /** + * 获取format,format包含运行时format和原始format + * @return 只读的format引用 + */ + const StorageFormat &GetFormat() const { + return storage_format_; + } + /** + * 获取format,format包含运行时format和原始format + * @return format引用 + */ + StorageFormat &MutableFormat() { + return storage_format_; + } + /** + * 获取补维规则 + * @return 补维规则 + */ + ExpandDimsType GetExpandDimsType() const { + return storage_format_.GetExpandDimsType(); + } + /** + * 设置补维规则 + * @param expand_dims_type 补维规则 + */ + void SetExpandDimsType(const ExpandDimsType &expand_dims_type) { + storage_format_.SetExpandDimsType(expand_dims_type); + } + /** + * 获取tensor的placement + * @return tensor的placement + */ + TensorPlacement GetPlacement() const { + return tensor_data_.GetPlacement(); + } + /** + * 设置tensor的placement + * @param tensor的placement + */ + void SetPlacement(const TensorPlacement placement) { + tensor_data_.SetPlacement(placement); + } + /** + * 获取tensor data + * @return 只读的tensor data引用 + */ + const TensorData &GetTensorData() const { + return tensor_data_; + } + /** + * 获取tensor data + * @return 可写的tensor data引用 + */ + TensorData &MutableTensorData() { + return tensor_data_; + } + private: + static std::unique_ptr NewFollowingTensor(const ge::DataType dt, size_t &total_size) { + if (ge::AddOverflow(total_size, sizeof(Tensor), total_size)) { + return nullptr; + } + auto holder = 
std::unique_ptr(new (std::nothrow) uint8_t[total_size]); + if (holder == nullptr) { + return nullptr; + } + auto tensor = reinterpret_cast(holder.get()); + new (holder.get()) Tensor({}, {}, dt); + tensor->SetPlacement(kFollowing); + tensor->tensor_data_ = TensorData(nullptr, nullptr, total_size - sizeof(Tensor), kFollowing); + return holder; + } + private: + StorageShape storage_shape_; + StorageFormat storage_format_; + uint8_t reserved_[4]; // Reserved field, 4-byte aligned + ge::DataType data_type_; + TensorData tensor_data_; + uint8_t reserved_field_[40]; // Reserved field, 32+8, do not directly use when only 8-byte left +}; +static_assert(std::is_standard_layout::value, "The class Tensor must be a POD"); +} // namespace gert + +#endif // METADEF_CXX_INC_EXE_GRAPH_TENSOR_H_ diff --git a/inc/metadef/external/exe_graph/runtime/tensor_data.h b/inc/metadef/external/exe_graph/runtime/tensor_data.h new file mode 100644 index 0000000000000000000000000000000000000000..b723c2f9adee7ac12ec8715e0303815b607b43df --- /dev/null +++ b/inc/metadef/external/exe_graph/runtime/tensor_data.h @@ -0,0 +1,240 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_TENSOR_DATA_H_ +#define METADEF_CXX_INC_EXE_GRAPH_TENSOR_DATA_H_ + +#include +#include +#include +#include "graph/ge_error_codes.h" + +namespace gert { +using TensorAddress = void *; +using ConstTensorAddressPtr = const void *; + +enum TensorPlacement : int32_t { + kOnDeviceHbm, ///< Tensor位于Device上的HBM内存 + kOnHost, ///< Tensor位于Host + kFollowing, ///< Tensor位于Host,且数据紧跟在结构体后面 + kOnDeviceP2p, ///< Tensor位于Device上的P2p内存 + kTensorPlacementEnd +}; + +class TensorPlacementUtils { + public: + static bool IsOnDevice(TensorPlacement placement) { + return (placement == kOnDeviceHbm) || (placement == kOnDeviceP2p); + } + static bool IsOnHost(TensorPlacement placement) { + return (placement == kOnHost) || (placement == kFollowing); + } + static bool IsOnHostFollowing(TensorPlacement placement) { + return (placement == kFollowing); + } + static bool IsOnHostNotFollowing(TensorPlacement placement) { + return (placement == kOnHost); + } + static bool IsOnDeviceHbm(TensorPlacement placement) { + return (placement == kOnDeviceHbm); + } + static bool IsOnDeviceP2p(TensorPlacement placement) { + return (placement == kOnDeviceP2p); + } +}; + +enum TensorOperateType : int32_t { + kGetTensorAddress, ///< 获取Tensor的地址 + kFreeTensor, ///< 释放Tensor + kPlusShareCount, ///< 共享Tensor + kTensorOperateType +}; + +/** + * Tensor的管理函数 + */ +using TensorAddrManager = ge::graphStatus (*)(TensorAddress addr, TensorOperateType operate_type, void **out); + +class TensorData { + public: + /** + * 构造一个TensorData + * @param addr tensor的地址 + * @param manager tensor data的管理函数,若manager为空,则认为addr就是tensor的数据地址,且此数据不需要被释放 + */ + // memse函数misra告警屏蔽 + explicit TensorData(TensorAddress addr = nullptr, const TensorAddrManager manager = nullptr) + : TensorData(addr, manager, 0U, kTensorPlacementEnd) { + } + explicit TensorData(TensorAddress addr, const 
TensorAddrManager manager, size_t size, TensorPlacement placement) + : addr_(addr), manager_(manager), size_(size), placement_(placement), reserved_0_(0U) { + (void)memset(reserved_1_, 0, sizeof(reserved_1_)); + } + TensorData(const TensorData &) = delete; + TensorData(TensorData &&other) noexcept : addr_(other.addr_), manager_(other.manager_), + size_(other.size_), placement_(other.placement_) { + other.addr_ = nullptr; + other.manager_ = nullptr; + other.size_ = 0U; + other.placement_ = kTensorPlacementEnd; + reserved_0_ = other.reserved_0_; + (void)memcpy(reserved_1_, other.reserved_1_, sizeof(reserved_1_)); + } + TensorData &operator=(const TensorData &other) = delete; + TensorData &operator=(TensorData &&other) noexcept { + if (this != &other) { + static_cast(Free()); + addr_ = other.addr_; + manager_ = other.manager_; + size_ = other.size_; + placement_ = other.placement_; + other.addr_ = nullptr; + other.manager_ = nullptr; + other.size_ = 0U; + other.placement_ = kTensorPlacementEnd; + reserved_0_ = other.reserved_0_; + (void)memcpy(reserved_1_, other.reserved_1_, sizeof(reserved_1_)); + } + return *this; + } + ~TensorData() noexcept { + static_cast(Free()); + } + + /** + * 获取tensor地址 + * @return tensor地址 + */ + TensorAddress GetAddr() const { + if (manager_ == nullptr || addr_ == nullptr) { + return addr_; + } + TensorAddress addr; + if (manager_(addr_, kGetTensorAddress, &addr) != ge::GRAPH_SUCCESS) { + return nullptr; + } else { + return addr; + } + } + + /** + * 获取tensor的内存大小 + * @return tensor所占内存大小 + */ + size_t GetSize() const { + return size_; + } + /** + * 设置tensor的内存大小 + * @param tensor的内存大小 + */ + void SetSize(const size_t size) { + size_ = size; + } + + /** + * 获取tensor的placement + * @return tensor的placement + */ + TensorPlacement GetPlacement() const { + return placement_; + } + /** + * 设置tensor的placement + * @param tensor的placement + */ + void SetPlacement(const TensorPlacement placement) { + placement_ = placement; + } + /** + * 释放tensor + * 
@return 成功时返回ge::GRAPH_SUCCESS + */ + ge::graphStatus Free() { + if (manager_ == nullptr) { + return ge::GRAPH_SUCCESS; + } + const auto ret = manager_(addr_, kFreeTensor, nullptr); + if (ret == ge::GRAPH_SUCCESS) { + addr_ = nullptr; + manager_ = nullptr; + } + return ret; + } + /** + * 设置tensor地址 + * @param addr tensor地址 + * @param manager tensor的管理函数 + */ + ge::graphStatus SetAddr(const ConstTensorAddressPtr addr, TensorAddrManager manager) { + const auto ret = Free(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + addr_ = const_cast(addr); + manager_ = manager; + return ge::GRAPH_SUCCESS; + } + + bool IsSharedWith(const TensorData &other) const { + return (addr_ == other.addr_ && manager_ == other.manager_ && size_ == other.size_ && + placement_ == other.placement_); + } + + ge::graphStatus ShareFrom(const TensorData &other) { + if (IsSharedWith(other)) { + return ge::GRAPH_SUCCESS; + } + const auto ret = Free(); + if (ret != ge::GRAPH_SUCCESS) { + return ret; + } + + addr_ = other.addr_; + manager_ = other.manager_; + size_ = other.size_; + placement_ = other.placement_; + if (manager_ != nullptr) { + return manager_(addr_, kPlusShareCount, nullptr); + } else { + return ge::GRAPH_SUCCESS; + } + } + + /** + * 释放对TensorAddress的所有权,本接口调用后,本对象不再管理TensorAddress, + * 而且TensorAddress并没有被释放,因此调用者负责通过manager释放TensorAddress + * @param manager 出参,用于管理返回的`TensorAddress`的函数,若本对象无所有权,那么manager被设置为nullptr + * @return 本对象持有的`TensorAddress`指针,若本对象不持有任何指针,则返回nullptr + */ + TensorAddress Release(TensorAddrManager &manager) { + auto tmp_addr = addr_; + manager = manager_; + + addr_ = nullptr; + manager_ = nullptr; + size_ = 0U; + placement_ = kTensorPlacementEnd; + + return tmp_addr; + } + + private: + TensorAddress addr_; + TensorAddrManager manager_; + size_t size_; + TensorPlacement placement_; + uint32_t reserved_0_; // Reserved field, 8-byte aligned for TensorPlacement + uint8_t reserved_1_[40]; // Reserved field, 32+8, do not directly use when only 8-byte left + + 
friend class TensorUtils; +}; +} // namespace gert + +#endif // METADEF_CXX_INC_EXE_GRAPH_TENSOR_DATA_H_ diff --git a/inc/metadef/external/exe_graph/runtime/tiling_context.h b/inc/metadef/external/exe_graph/runtime/tiling_context.h new file mode 100644 index 0000000000000000000000000000000000000000..c2054acc42746b09698ef37a0954e5224d9e142c --- /dev/null +++ b/inc/metadef/external/exe_graph/runtime/tiling_context.h @@ -0,0 +1,416 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_TILING_CONTEXT_H_ +#define METADEF_CXX_INC_EXE_GRAPH_TILING_CONTEXT_H_ +#include "storage_shape.h" +#include "tensor.h" +#include "continuous_vector.h" +#include "extended_kernel_context.h" +#include "tiling_data.h" +#include "external/ge_common/ge_api_error_codes.h" + +namespace fe { +class PlatFormInfos; +} // namespace fe + +namespace gert { +/** + * tiling kernel的context + */ +class TilingContext : public ExtendedKernelContext { + public: + const void *GetCompileInfo() const { + const auto compute_node_info = GetComputeNodeInfo(); + if (compute_node_info == nullptr) { + return nullptr; + } + + const size_t index = compute_node_info->GetInputsNum() + compute_node_info->GetOutputsNum(); + const auto av = GetInput(index); + if (av == nullptr) { + return nullptr; + } + return av->GetValue(); + } + /** + * 获取CompileInfo + * @tparam T CompileInfo的类型 + * @return CompileInfo的指针 + */ + template + const T *GetCompileInfo() const { + return reinterpret_cast(GetCompileInfo()); + } + /** + * 获取输入shape,输入shape中包含了原始shape与运行时shape + * @param index 输入index + * @return 输入shape指针,index非法时返回空指针 + */ + const StorageShape *GetInputShape(const size_t index) const { + const auto compute_node_info = GetComputeNodeInfo(); + if (compute_node_info == nullptr) { + return nullptr; + } + + if (index >= compute_node_info->GetInputsNum()) { + return nullptr; + } + return GetInputPointer(index); + } + /** + * 获取输入tensor + * 若算子被配置为'tiling'数据依赖,则返回的Tensor对象中保存了Host内存地址;反之,内存地址为nullptr。 + * @param index 输入index + * @return 输入tensor指针,index非法时返回空指针 + */ + const Tensor *GetInputTensor(const size_t index) const { + return GetInputPointer(index); + } + /** + * 基于算子IR原型定义,获取`OPTIONAL_INPUT`类型的输入tensor指针 + * 若算子被配置为'tiling'数据依赖,则返回的Tensor对象中保存了Host内存地址;反之,内存地址为nullptr。 + * @param ir_index IR原型定义中的index + * @return 
tensor指针,index非法,或该INPUT没有实例化时,返回空指针 + */ + const Tensor *GetOptionalInputTensor(const size_t ir_index) const { + return GetDynamicInputPointer(ir_index, 0); + } + /** + * 基于算子IR原型定义,获取`DYNAMIC_INPUT`类型的输入Tensor指针 + * 若算子被配置为'tiling'数据依赖,则返回的Tensor对象中保存了Host内存地址;反之,内存地址为nullptr。 + * @param ir_index IR原型定义中的index + * @param relative_index 该输入实例化后的相对index,例如某个DYNAMIC_INPUT实例化了3个输入,那么relative_index的有效范围是[0,2] + * @return Tensor指针,index或relative_index非法时,返回空指针 + */ + const Tensor *GetDynamicInputTensor(const size_t ir_index, const size_t relative_index) const { + return GetDynamicInputPointer(ir_index, relative_index); + } + /** + * 基于算子IR原型定义,获取`OPTIONAL_INPUT`类型的输入shape指针,shape中包含了原始shape与运行时shape + * @param ir_index IR原型定义中的index + * @return shape指针,index非法,或该INPUT没有实例化时,返回空指针 + */ + const StorageShape *GetOptionalInputShape(const size_t ir_index) const { + return GetDynamicInputPointer(ir_index, 0); + } + /** + * 基于算子IR原型定义,获取`REQUIRED_INPUT`类型的输入Tensor指针 + * 若算子被配置为'tiling'数据依赖,则返回的Tensor对象中保存了Host内存地址;反之,内存地址为nullptr。 + * @param ir_index IR原型定义中的index + * @return Tensor指针,index非法时,返回空指针 + */ + const Tensor *GetRequiredInputTensor(const size_t ir_index) const { + return GetDynamicInputPointer(ir_index, 0); + } + /** + * 基于算子IR原型定义,获取`REQUIRED_INPUT`类型的输入shape指针,shape中包含了原始shape与运行时shape + * @param ir_index IR原型定义中的index + * @return shape指针,index非法,或该INPUT没有实例化时,返回空指针 + */ + const StorageShape *GetRequiredInputShape(const size_t ir_index) const { + return GetDynamicInputPointer(ir_index, 0); + } + /** + * 基于算子IR原型定义,获取`DYNAMIC_INPUT`类型的输入shape指针,shape中包含了原始shape与运行时shape + * @param ir_index IR原型定义中的index + * @param relative_index 该输入实例化后的相对index,例如某个DYNAMIC_INPUT实例化了3个输入,那么relative_index的有效范围是[0,2] + * @return shape指针,index或relative_index非法时,返回空指针 + */ + const StorageShape *GetDynamicInputShape(const size_t ir_index, const size_t relative_index) const { + return GetDynamicInputPointer(ir_index, relative_index); + } + /** + * 
根据输出index,获取输出shape指针,shape中包含了原始shape与运行时shape + * @param index 输出index + * @return 输出shape指针,index非法时,返回空指针 + */ + const StorageShape *GetOutputShape(const size_t index) const { + const auto compute_node_info = GetComputeNodeInfo(); + if (compute_node_info == nullptr) { + return nullptr; + } + + if (index >= compute_node_info->GetOutputsNum()) { + return nullptr; + } + + const size_t offset = compute_node_info->GetInputsNum(); + return GetInputPointer(offset + index); + } + + /* + * outputs, tiling的outputs以如下顺序排列: + * outputs[0]: tiling-key + * outputs[1]: block-dim + * outputs[2]: atomic-clean-flag + * outputs[3]: tiling-data + * outputs[4]: workspace sizes + * outputs[5]: tiling condition + * outputs[6]: schedule mode + */ + enum TilingOutputIndex : uint32_t { + kOutputTilingKey, + kOutputBlockDim, + kOutputAtomicCleanFlag, + kOutputTilingData, + kOutputWorkspace, + kOutputTilingCond, + kOutputScheduleMode, + kOutputLocalMemorySize, + // add new output definitions here + kOutputNum + }; + + /* + * outputs[0]: fallible tiling condition + */ + enum FallibleTilingOutputIndex : uint32_t { + kTilingStatus = TilingOutputIndex::kOutputNum, + // add new output definitions here + kFallibleOutputNum + }; + + /** + * 设置tiling key + * @param tiling_key tiling key + * @return 成功时返回ge::GRAPH_SUCCESS + */ + ge::graphStatus SetTilingKey(const uint64_t tiling_key) { + const auto p = GetOutputPointer(kOutputTilingKey); + if (p == nullptr) { + return ge::GRAPH_FAILED; + } + *p = tiling_key; + return ge::GRAPH_SUCCESS; + } + /** + * 获取tiling key + * @return tiling key,获取失败时 + */ + uint64_t GetTilingKey() const { + const auto p = GetOutputPointer(kOutputTilingKey); + if (p == nullptr) { + return std::numeric_limits::max(); + } + return *p; + } + + /** + * 设置schedule_mode + * @param schedule_mode schedule_mode + * @return 成功时返回ge::GRAPH_SUCCESS + */ + ge::graphStatus SetScheduleMode(const uint32_t schedule_mode) { + const auto p = GetOutputPointer(kOutputScheduleMode); + if (p == 
nullptr) { + return ge::GRAPH_FAILED; + } + *p = schedule_mode; + return ge::GRAPH_SUCCESS; + } + /** + * 获取设置schedule_mode + * @return 设置schedule_mode,获取失败时 + */ + uint32_t GetScheduleMode() const { + const auto p = GetOutputPointer(kOutputScheduleMode); + if (p == nullptr) { + return 0U; + } + return *p; + } + + /** + * 设置block dim + * @param block_dim block dim + * @return 成功时返回ge::GRAPH_SUCCESS + */ + ge::graphStatus SetBlockDim(const uint32_t block_dim) { + const auto p = GetOutputPointer(kOutputBlockDim); + if (p == nullptr) { + return ge::GRAPH_FAILED; + } + *p = block_dim; + return ge::GRAPH_SUCCESS; + } + /** + * 获取block dim + * @return block dim + */ + uint32_t GetBlockDim() const { + const auto p = GetOutputPointer(kOutputBlockDim); + if (p == nullptr) { + return std::numeric_limits::max(); + } + return *p; + } + /** + * 设置tiling cond + * @param tiling_cond tiling condition + * @return 成功时返回ge::GRAPH_SUCCESS + */ + ge::graphStatus SetTilingCond(int32_t tiling_cond) { + const auto p = GetOutputPointer(kOutputTilingCond); + if (p == nullptr) { + return ge::GRAPH_FAILED; + } + *p = tiling_cond; + return ge::GRAPH_SUCCESS; + } + /** + * 获取tiling cond + * @return tiling cond:有效的tiling_cond大于等于0,若该值无效返回-1 + */ + int32_t GetTilingCond() const { + const auto p = GetOutputPointer(kOutputTilingCond); + if (p == nullptr) { + return -1; + } + return *p; + } + /** + * 设置是否需要atomic clean + * @param atomic true/false代表是否需要做atomic clean + * @return 成功时返回ge::GRAPH_SUCCESS + */ + ge::graphStatus SetNeedAtomic(const bool atomic) { + const auto p = GetOutputPointer(kOutputAtomicCleanFlag); + if (p == nullptr) { + return ge::GRAPH_FAILED; + } + *p = atomic; + return ge::GRAPH_SUCCESS; + } + /** + * 获取是否需要atomic clean + * @return true/false + */ + bool NeedAtomic() const { + const auto p = GetOutputPointer(kOutputAtomicCleanFlag); + if (p == nullptr) { + return false; + } + return *p; + } + /** + * 获取有类型的tiling data指针 + * @tparam T tiling data类型,sizeof(T)不可以大于编译结果中指定的最大tiling 
data长度 + * @return tiling data指针,失败时返回空指针 + */ + template + auto GetTilingData() -> T* { + auto tiling_data = GetRawTilingData(); + if (tiling_data == nullptr) { + return nullptr; + } + if (tiling_data->GetCapacity() < sizeof(T)) { + return nullptr; + } + tiling_data->SetDataSize(sizeof(T)); + return static_cast(tiling_data->GetData()); + } + /** + * 获取无类型的tiling data码流 + * @return tiling data指针,失败时返回空指针 + */ + TilingData *GetRawTilingData() { + return *GetOutputPointer(kOutputTilingData); + } + /** + * 获取workspace sizes指针 + * @param workspace_count workspace的个数,传入的workspace个数不可以超过编译时指定的最大workspace个数 + * @return workspace sizes指针 + */ + size_t *GetWorkspaceSizes(const size_t workspace_count) { + const auto workspace = GetOutputPointer>(kOutputWorkspace); + if (workspace == nullptr) { + return nullptr; + } + if (workspace->SetSize(workspace_count) != ge::SUCCESS) { + return nullptr; + } + return workspace->MutableData(); + } + /** + * 获取 workspace 个数 + * @return workspace 个数 + */ + size_t GetWorkspaceNum() const { + const auto workspace = GetOutputPointer>(kOutputWorkspace); + if (workspace == nullptr) { + return 0U; + } + return workspace->GetSize(); + } + /** + * 获取 fe::PlatFormInfos 指针 + * @return fe::PlatFormInfos 指针 + */ + fe::PlatFormInfos *GetPlatformInfo() const { + const auto compute_node_info = GetComputeNodeInfo(); + if (compute_node_info == nullptr) { + return nullptr; + } + + const size_t index = compute_node_info->GetInputsNum() + compute_node_info->GetOutputsNum(); + const auto av = GetInput(index + 1U); + if (av == nullptr) { + return nullptr; + } + return av->GetValue(); + } + + /** + * 获取 确定性计算变量 + * @return int32 变量 + */ + int32_t GetDeterministic() const { + const auto compute_node_info = GetComputeNodeInfo(); + if (compute_node_info == nullptr) { + return std::numeric_limits::max(); + } + const size_t index = compute_node_info->GetInputsNum() + compute_node_info->GetOutputsNum(); + // 此处按照tiling内存排布,将确定性计算的字段添加在 + // inputshape outputshape 
compileinfo platform tiling_func之后 + const auto av = GetInput(index + 3U); + if (av == nullptr) { + return std::numeric_limits::max(); + } + return av->GetValue(); + } + + /** + * 设置 local memory size, 默认值为0 + * @param local_memory_size + * @return 成功返回ge::GRAPH_SUCCESS + */ + ge::graphStatus SetLocalMemorySize(const uint32_t local_memory_size) { + const auto p = GetOutputPointer(kOutputLocalMemorySize); + if (p == nullptr) { + return ge::GRAPH_FAILED; + } + *p = local_memory_size; + return ge::GRAPH_SUCCESS; + } + /** + * 获取 local memory size,默认值为0 + * @return local memory size, 失败返回 std::numeric_limits::max() + */ + uint32_t GetLocalMemorySize() const { + const auto p = GetOutputPointer(kOutputLocalMemorySize); + if (p == nullptr) { + return std::numeric_limits::max(); + } + return *p; + } +}; +static_assert(std::is_standard_layout::value, "The class TilingContext must be a POD"); +} // namespace gert +#endif // METADEF_CXX_INC_EXE_GRAPH_TILING_CONTEXT_H_ diff --git a/inc/metadef/external/exe_graph/runtime/tiling_data.h b/inc/metadef/external/exe_graph/runtime/tiling_data.h new file mode 100644 index 0000000000000000000000000000000000000000..d382b073b1e63ddcd27e0f2b66dc0c743c3dfb38 --- /dev/null +++ b/inc/metadef/external/exe_graph/runtime/tiling_data.h @@ -0,0 +1,222 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_TILING_DATA_H_ +#define METADEF_CXX_INC_EXE_GRAPH_TILING_DATA_H_ +#include +#include +#include +#include +#include +#include "exe_graph/runtime/continuous_vector.h" +#include "exe_graph/runtime/runtime_attrs.h" +#include "graph/ge_error_codes.h" +#include "utils/extern_math_util.h" + +namespace gert { +enum class AttrDataType : int32_t { + kBool = 0, + kString, + kInt32, + kInt64, + kUint32, + kFloat32, + kFloat16, + kListBool, + kListString, + kListInt32, + kListInt64, + kListUint32, + kListFloat32, + kListFloat16, + kListListInt32, + kListListInt64, + kBfloat16, + kInt8, + kInt16, + kListInt8, + kListInt16, + kListListBool, + kListListFloat16, + kListBfloat16, + kListListBfloat16, + kListListFloat32, + kListListInt8, + kListListInt16, + kUint8, + kListUint8, + kListListUint8, + kUint16, + kListUint16, + kListListUint16, + kListListUint32, + kUint64, + kListUint64, + kListListUint64, + kTypeEnd +}; +class TilingData { + public: + /** + * 获取本实例可容纳的最大tiling data长度 + * @return 最大tiling data长度 + */ + size_t GetCapacity() const { + return capacity_; + } + /** + * 获取tiling data长度 + * @return tiling data长度 + */ + size_t GetDataSize() const { + return data_size_; + } + /** + * 设置tiling data长度 + * @param size tiling data长度 + */ + void SetDataSize(const size_t size) { + data_size_ = size; + } + /** + * 获取data指针 + * @return data指针 + */ + void *GetData() { + return data_; + } + /** + * 获取data指针 + * @return data指针 + */ + const void *GetData() const { + return data_; + } + /** + * 向后添加tiling data,若添加超过可容纳的最大长度,则添加失败 + * @tparam T 添加的tiling data的类型 + * @param data 添加的tiling data实例 + * @return 成功返回ge::GRAPH_SUCCESS + */ + template::value, int>::type = 0> + ge::graphStatus Append(const T &data) { + auto data_ptr = Expand(sizeof(data)); + if (data_ptr == nullptr) { + return ge::GRAPH_FAILED; + } + *reinterpret_cast(data_ptr) 
= data; + return ge::GRAPH_SUCCESS; + } + + template::value, int>::type = 0> + ge::graphStatus Append(const T *data, size_t append_num) { + size_t append_size; + if (ge::MulOverflow(sizeof(T), append_num, append_size)) { + return ge::GRAPH_MUL_OVERFLOW; + } + auto data_ptr = Expand(append_size); + if (data_ptr == nullptr) { + return ge::GRAPH_ADD_OVERFLOW; + } + // Expand中已保证append_size的合法性,为减少冗余判断逻辑此处使用memcpy直接拷贝 + (void)memcpy(data_ptr, data, append_size); + return ge::GRAPH_SUCCESS; + } + + /** + * 将tiling data拓展size大小 + * @param size 需要拓展的大小,单位为字节 + * @return 返回扩展对应size的内存地址 + */ + void *Expand(size_t size) { + size_t after_size; + if (ge::AddOverflow(data_size_, size, after_size)) { + return nullptr; + } + if (after_size > capacity_) { + return nullptr; + } + void *ret = reinterpret_cast(reinterpret_cast(data_) + data_size_); + data_size_ = after_size; + return ret; + } + + /** + * 通过最大容量创建一个TilingData类实例 + * @param cap_size 最大容量,单位为字节 + * @return 实例指针 + */ + static std::unique_ptr CreateCap(const size_t cap_size) { + size_t total_size; + if (ge::AddOverflow(sizeof(TilingData), cap_size, total_size)) { + return nullptr; + } + auto td_buf = std::unique_ptr(new (std::nothrow) uint8_t[total_size]()); + if (td_buf == nullptr) { + return nullptr; + } + const auto td = reinterpret_cast(td_buf.get()); + td->Init(cap_size, td_buf.get() + sizeof(TilingData)); + return td_buf; + } + /** + * 通过最大容量计算TilingData实例所占用的内存空间 + * @param cap_size 最大容量,单位为字节 + * @param total_size 内存空间,单位为字节 + * @return 成功返回ge::GRAPH_SUCCESS + */ + static ge::graphStatus CalcTotalSize(const size_t cap_size, size_t &total_size) { + if (ge::AddOverflow(sizeof(TilingData), cap_size, total_size)) { + return ge::GRAPH_FAILED; + } + return ge::GRAPH_SUCCESS; + } + /** + * 初始化TilingData + * @param cap_size 最大容量 + * @param data tiling data的地址 + */ + void Init(const size_t cap_size, void *const data) { + capacity_ = cap_size; + data_size_ = 0UL; + data_ = data; + (void)memset_s(reserved_, 
sizeof(reserved_), 0, sizeof(reserved_)); + } + + ge::graphStatus AppendConvertedAttrVal(const RuntimeAttrs *attrs, const size_t attr_index, + const AttrDataType src_type, const AttrDataType dst_type); + TilingData(const TilingData &) = delete; + TilingData(TilingData &&) = delete; + TilingData operator=(const TilingData &) = delete; + TilingData operator=(TilingData &&) = delete; + + private: + TilingData() = default; + size_t capacity_; + size_t data_size_; + void *data_; + uint8_t reserved_[40]; // Reserved field, 32+8, do not directly use when only 8-byte left +}; + +/** + * 向后添加tiling data,若添加超过可容纳的最大长度,则忽略本次操作 + * @tparam T 添加的tiling data的类型 + * @param out TilingData类实例 + * @param data 添加的tiling data的实例 + * @return + */ +template +TilingData &operator<<(TilingData &out, const T &data) { + out.Append(data); // we can not throw exception, so callers can not get the error information + return out; +} +static_assert(std::is_standard_layout::value, "The class TilingData must be a POD"); +} // namespace gert + +#endif // METADEF_CXX_INC_EXE_GRAPH_TILING_DATA_H_ diff --git a/inc/metadef/external/exe_graph/runtime/tiling_parse_context.h b/inc/metadef/external/exe_graph/runtime/tiling_parse_context.h new file mode 100644 index 0000000000000000000000000000000000000000..23b4c7e6760771fdfcdfa6ecfd7e9e0941679a85 --- /dev/null +++ b/inc/metadef/external/exe_graph/runtime/tiling_parse_context.h @@ -0,0 +1,52 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_TILING_PARSE_CONTEXT_H_ +#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_TILING_PARSE_CONTEXT_H_ +#include "exe_graph/runtime/extended_kernel_context.h" +#include "graph/types.h" + +namespace fe { +class PlatFormInfos; +} // namespace fe + +namespace gert { +class TilingParseContext : public ExtendedKernelContext { + public: + /** + * 获取算子编译产生的json字符串 + * @return json字符串 + */ + const ge::char_t *GetCompiledJson() const { + return GetInputValue(0); + } + /** + * 获取`CompiledInfo`实例 + * @tparam T 实例类型,该类型需要与`IMPL_OP`注册时TilingParse的类型一致 + * @return 指向`CompiledInfo`实例的指针 + */ + template + auto GetCompiledInfo() -> T* { + auto av = GetOutput(0); + if (av == nullptr) { + return nullptr; + } + return av->GetValue(); + } + /** + * 获取 fe::PlatFormInfos 指针 + * @return fe::PlatFormInfos 指针 + */ + fe::PlatFormInfos *GetPlatformInfo() const { + return GetInputValue(1); + } +}; +static_assert(std::is_standard_layout::value, "The class TilingParseContext must be a POD"); +} // namespace gert +#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_TILING_PARSE_CONTEXT_H_ diff --git a/inc/metadef/external/flow_graph/data_flow.h b/inc/metadef/external/flow_graph/data_flow.h new file mode 100644 index 0000000000000000000000000000000000000000..b2dab9759edb31fa01732afca0f6dafa3105a84e --- /dev/null +++ b/inc/metadef/external/flow_graph/data_flow.h @@ -0,0 +1,17 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_FLOW_GRAPH_DATA_FLOW_H_ +#define INC_EXTERNAL_FLOW_GRAPH_DATA_FLOW_H_ + +#include "flow_attr.h" +#include "flow_graph.h" +#include "process_point.h" + +#endif // INC_EXTERNAL_FLOW_GRAPH_DATA_FLOW_H_ diff --git a/inc/metadef/external/flow_graph/flow_attr.h b/inc/metadef/external/flow_graph/flow_attr.h new file mode 100644 index 0000000000000000000000000000000000000000..e470a167f62987cd325126ce75a6f652e768e474 --- /dev/null +++ b/inc/metadef/external/flow_graph/flow_attr.h @@ -0,0 +1,51 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_FLOW_GRAPH_FLOW_ATTR_H_ +#define INC_EXTERNAL_FLOW_GRAPH_FLOW_ATTR_H_ + +#include + +namespace ge { +namespace dflow { +struct CountBatch { + int64_t batch_size = 0L; + int64_t slide_stride = 0L; + int64_t timeout = 0L; + int64_t batch_dim = 0L; + int32_t flag = 0; // eg: eos/seg + bool padding = false; + bool drop_remainder = false; + int8_t rsv[90] = {}; +}; + +struct TimeBatch { + int64_t time_window = 0L; + int64_t time_interval = 0L; + int64_t timeout = 0L; + int64_t batch_dim = -1; + int32_t flag = 0; // eg: eos/seg + bool padding = false; + bool drop_remainder = false; + int8_t rsv[90] = {}; +}; + +enum class DataFlowAttrType { + COUNT_BATCH = 0, + TIME_BATCH = 1, + INVALID = 2, +}; + +struct DataFlowInputAttr { + DataFlowAttrType attr_type; + void *attr_value; +}; +} // namespace dflow +} // namespace ge +#endif // INC_EXTERNAL_FLOW_GRAPH_FLOW_ATTR_H_ diff --git a/inc/metadef/external/flow_graph/flow_graph.h b/inc/metadef/external/flow_graph/flow_graph.h new file mode 100644 index 0000000000000000000000000000000000000000..42f4a92e3c67c204d4c4acfba64e5acb6d76ed23 --- /dev/null +++ b/inc/metadef/external/flow_graph/flow_graph.h @@ -0,0 +1,101 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_FLOW_GRAPH_FLOW_GRAPH_H_ +#define INC_EXTERNAL_FLOW_GRAPH_FLOW_GRAPH_H_ + +#include +#include + +#include "flow_attr.h" +#include "graph/graph.h" +#include "process_point.h" + +namespace ge { +namespace dflow { +class FlowOperator : public ge::Operator { +public: + ~FlowOperator() override; + +protected: + FlowOperator(const char *name, const char *type); +}; + +class FlowData : public FlowOperator { +public: + FlowData(const char *name, int64_t index); + ~FlowData() override; +}; + +class FlowNodeImpl; +class FlowNode : public FlowOperator { +public: + FlowNode(const char *name, uint32_t input_num, uint32_t output_num); + ~FlowNode() override; + FlowNode &SetInput(uint32_t dst_index, const FlowOperator &src_op, uint32_t src_index = 0); + FlowNode &MapInput(uint32_t node_input_index, const ProcessPoint &pp, uint32_t pp_input_index, + const std::vector &attrs = {}); + FlowNode &MapOutput(uint32_t node_output_index, const ProcessPoint &pp, uint32_t pp_output_index); + FlowNode &AddPp(const ProcessPoint &pp); + FlowNode &SetBalanceScatter(); + FlowNode &SetBalanceGather(); + + private: + std::shared_ptr impl_; +}; + +class FlowGraphImpl; +using FlowGraphImplPtr = std::shared_ptr; +class FlowGraph { +public: + explicit FlowGraph(const char *name); + ~FlowGraph(); + const ge::Graph &ToGeGraph() const; + FlowGraph &SetInputs(const std::vector &inputs); + FlowGraph &SetOutputs(const std::vector &outputs); + FlowGraph &SetOutputs(const std::vector>> &output_indexes); + void SetGraphPpBuilderAsync(bool graphpp_builder_async); + const char *GetName() const; + + /** + * @brief Set the Contains N Mapping Node object. + * default is false, when contain a n-mapping node, you need call this method to set value to true. 
+ * n-mapping means one2n(one input split to multi outputs), n2one(multi inputs combine to one output) or + * n2n(multi inputs generate multi outputs) mapping. + * @param contains_n_mapping_node whether the graph contains an n-mapping node. + * @return FlowGraph& return current object + */ + FlowGraph &SetContainsNMappingNode(bool contains_n_mapping_node); + + /** + * @brief Set the inputs align attrs. + * + * @param align_max_cache_num align max cache num(a cache can save a set of inputs), + * 0 means align is not enabled, max value is 1024. + * @param align_timeout align timeout(ms), -1 means never timeout, otherwise over 0 and less than 600 * 1000(10 minutes). + * @param dropout_when_not_align whether to drop data when over max buf num or timeout, default value=false. + * @return FlowGraph& return current object + */ + FlowGraph &SetInputsAlignAttrs(uint32_t align_max_cache_num, int32_t align_timeout, + bool dropout_when_not_align = false); + /** + * @brief Set exception catch. + * + * @param enable_exception_catch whether to enable user exception catch, default value=false. + * @return FlowGraph& return current object + */ + FlowGraph &SetExceptionCatch(bool enable_exception_catch); + + private: + FlowGraphImplPtr impl_; +}; +} // namespace dflow +} // namespace ge + +#endif // INC_EXTERNAL_FLOW_GRAPH_FLOW_GRAPH_H_ diff --git a/inc/metadef/external/flow_graph/process_point.h b/inc/metadef/external/flow_graph/process_point.h new file mode 100644 index 0000000000000000000000000000000000000000..b89f882256190559463a2cbfb2f58b7c6210fa96 --- /dev/null +++ b/inc/metadef/external/flow_graph/process_point.h @@ -0,0 +1,85 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_FLOW_GRAPH_PROCESS_POINT_H_ +#define INC_EXTERNAL_FLOW_GRAPH_PROCESS_POINT_H_ + +#include +#include +#include +#include "graph/graph.h" +#include "graph/types.h" + +namespace ge { +namespace dflow { +using GraphBuilder = std::function; +enum class ProcessPointType { + FUNCTION = 0, + GRAPH = 1, + INNER = 2, + INVALID = 3, +}; + +class ProcessPointImpl; +class ProcessPoint { +public: + virtual ~ProcessPoint(); + ProcessPointType GetProcessPointType() const; + const char_t *GetProcessPointName() const; + const char_t *GetCompileConfig() const; + virtual void Serialize(ge::AscendString &str) const = 0; + +protected: + ProcessPoint(const char_t *pp_name, ProcessPointType pp_type); + void SetCompileConfigFile(const char_t *json_file_path); + +private: + std::shared_ptr impl_; +}; + +class GraphPpImpl; +class GraphPp : public ProcessPoint { +public: + GraphPp(const char_t *pp_name, const GraphBuilder &builder); + ~GraphPp() override; + GraphPp &SetCompileConfig(const char_t *json_file_path); + GraphBuilder GetGraphBuilder() const; + void Serialize(ge::AscendString &str) const override; +private: + std::shared_ptr impl_; +}; + +class FunctionPpImpl; +class FunctionPp : public ProcessPoint { +public: + explicit FunctionPp(const char_t *pp_name); + ~FunctionPp() override; + FunctionPp &SetCompileConfig(const char_t *json_file_path); + FunctionPp &SetInitParam(const char_t *attr_name, const ge::AscendString &value); + FunctionPp &SetInitParam(const char_t *attr_name, const char_t *value); + FunctionPp &SetInitParam(const char_t 
*attr_name, const std::vector &value); + FunctionPp &SetInitParam(const char_t *attr_name, const int64_t &value); + FunctionPp &SetInitParam(const char_t *attr_name, const std::vector &value); + FunctionPp &SetInitParam(const char_t *attr_name, const std::vector> &value); + FunctionPp &SetInitParam(const char_t *attr_name, const float &value); + FunctionPp &SetInitParam(const char_t *attr_name, const std::vector &value); + FunctionPp &SetInitParam(const char_t *attr_name, const bool &value); + FunctionPp &SetInitParam(const char_t *attr_name, const std::vector &value); + FunctionPp &SetInitParam(const char_t *attr_name, const ge::DataType &value); + FunctionPp &SetInitParam(const char_t *attr_name, const std::vector &value); + FunctionPp &AddInvokedClosure(const char_t *name, const GraphPp &graph_pp); + FunctionPp &AddInvokedClosure(const char_t *name, const ProcessPoint &pp); + const std::map &GetInvokedClosures() const; + void Serialize(ge::AscendString &str) const override; +private: + std::shared_ptr impl_; +}; +} // namespace dflow +} // namespace ge +#endif // INC_EXTERNAL_FLOW_GRAPH_PROCESS_POINT_H_ diff --git a/inc/metadef/external/ge/framework/common/taskdown_common.h b/inc/metadef/external/ge/framework/common/taskdown_common.h new file mode 100644 index 0000000000000000000000000000000000000000..d31ee6af5e5b02c3b2a11d5250ec67b0bc51989b --- /dev/null +++ b/inc/metadef/external/ge/framework/common/taskdown_common.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_GE_FRAMEWORK_COMMON_TASKDOWN_COMMON_H_ +#define INC_EXTERNAL_GE_FRAMEWORK_COMMON_TASKDOWN_COMMON_H_ + +namespace ge { +enum class ccKernelType : uint32_t { + CCE_AI_CORE = 0, /* cce aicore */ + CCE_AI_CPU = 1, /* cce aicpu */ + TE = 2, /* te operator */ + CUSTOMIZED = 3, /* customized operator */ + TE_AI_CORE = 4, /* te aicore operator */ + TE_AI_CPU = 5, /* te aicpu operator */ + AI_CPU = 6, /* aicpu */ + CUST_AI_CPU = 7, /* custom aicpu */ + HOST_CPU = 8, /* host cpu */ + DVPP = 9, /* dvpp */ + AI_CPU_KFC = 10, /* aicpu kfc */ + MIX_AICORE = 11, + MIX_VECTOR_CORE = 12, /* vector core only */ + INVALID = 10000 /* unknown kernel type */ +}; +} // namespace ge + +#endif // INC_EXTERNAL_GE_FRAMEWORK_COMMON_TASKDOWN_COMMON_H_ diff --git a/inc/metadef/external/ge/ge_allocator.h b/inc/metadef/external/ge/ge_allocator.h new file mode 100644 index 0000000000000000000000000000000000000000..cf0f3470f1197d2524c2e8d56743f44c46137c0a --- /dev/null +++ b/inc/metadef/external/ge/ge_allocator.h @@ -0,0 +1,76 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXTERNAL_GE_ALLOCATOR_H_ +#define METADEF_CXX_INC_EXTERNAL_GE_ALLOCATOR_H_ +#include +#include +namespace ge { +class MemBlock; +class Allocator { +public: + Allocator() = default; + virtual ~Allocator() = default; + Allocator(const Allocator &) = delete; + Allocator &operator=(const Allocator &) = delete; + + virtual MemBlock *Malloc(size_t size) = 0; + virtual void Free(MemBlock *block) = 0; + + // Apply for suggested address memory, default ignore suggested address + virtual MemBlock *MallocAdvise(size_t size, void *addr) { + (void)addr; + return Malloc(size); + } +}; + +class MemBlock { +public: + MemBlock(Allocator &allocator, void *addr, size_t block_size) + : allocator_(allocator), addr_(addr), count_(1U), block_size_(block_size) {} + virtual ~MemBlock() = default; + const void *GetAddr() const { + return addr_; + } + void *GetAddr() { + return addr_; + } + size_t GetSize() const { + return block_size_; + } + void SetSize(const size_t mem_size) { + block_size_ = mem_size; + } + void Free() { + if (GetCount() > 0U) { + if (SubCount() == 0U) { + return allocator_.Free(this); + } + } + } + + size_t AddCount() { + return ++count_; + } + size_t SubCount() { + return --count_; + } + size_t GetCount() const { + return count_; + } +private: + Allocator &allocator_; + void *addr_; + size_t count_; + size_t block_size_; +}; + +using AllocatorPtr = std::shared_ptr; +} // namespace ge +#endif // METADEF_CXX_INC_EXTERNAL_GE_ALLOCATOR_H_ diff --git a/inc/metadef/external/ge_common/ge_api_error_codes.h b/inc/metadef/external/ge_common/ge_api_error_codes.h new file mode 100644 index 0000000000000000000000000000000000000000..c31d54f1200d073becfcdcb0a4f300e319cf1009 --- /dev/null +++ b/inc/metadef/external/ge_common/ge_api_error_codes.h @@ -0,0 +1,125 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. 
+ * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_GE_COMMON_GE_API_ERROR_CODES_H_ +#define INC_EXTERNAL_GE_COMMON_GE_API_ERROR_CODES_H_ + +#include +#include +#include "ge_error_codes.h" +#include "ge_api_types.h" +#include "graph/ascend_string.h" + +#ifdef __GNUC__ + +#ifdef NO_METADEF_ABI_COMPATIABLE +#define ATTRIBUTED_DEPRECATED(replacement) +#else +#define ATTRIBUTED_DEPRECATED(replacement) __attribute__((deprecated("Please use " #replacement " instead."))) +#endif +#else +#ifdef NO_METADEF_ABI_COMPATIABLE +#define ATTRIBUTED_DEPRECATED(replacement) +#else +#define ATTRIBUTED_DEPRECATED(replacement) __declspec(deprecated("Please use " #replacement " instead.")) +#endif +#endif + +#ifndef GE_ERRORNO_DEFINE +#define GE_ERRORNO_DEFINE(runtime, type, level, sysid, modid, name, value) \ + constexpr ge::Status name = ((static_cast(0xFFU & (static_cast(runtime))) << 30U) | \ + (static_cast(0xFFU & (static_cast(type))) << 28U) | \ + (static_cast(0xFFU & (static_cast(level))) << 25U) | \ + (static_cast(0xFFU & (static_cast(sysid))) << 17U) | \ + (static_cast(0xFFU & (static_cast(modid))) << 12U) | \ + (static_cast(0x0FFFU) & (static_cast(value)))) +#endif + +#ifndef GE_ERRORNO_EXTERNAL +#define GE_ERRORNO_EXTERNAL(name, desc) const ErrorNoRegisterar g_errorno_##name((name), (desc)) +#endif + +#ifndef GE_ERRORNO +// Code compose(4 
byte), runtime: 2 bit, type: 2 bit, level: 3 bit, sysid: 8 bit, modid: 5 bit, value: 12 bit +#define GE_ERRORNO(runtime, type, level, sysid, modid, name, value, desc) \ + GE_ERRORNO_DEFINE(runtime, type, level, sysid, modid, name, value); \ + GE_ERRORNO_EXTERNAL(name, desc) + +namespace ge { +class GE_FUNC_VISIBILITY StatusFactory { + public: + static StatusFactory *Instance() { + static StatusFactory instance; + return &instance; + } + + void RegisterErrorNo(const uint32_t err, const std::string &desc) { + // Avoid repeated addition + if (err_desc_.find(err) != err_desc_.end()) { + return; + } + err_desc_[err] = desc; + } + + void RegisterErrorNo(const uint32_t err, const char *const desc) { + if (desc == nullptr) { + return; + } + const std::string error_desc = desc; + if (err_desc_.find(err) != err_desc_.end()) { + return; + } + err_desc_[err] = error_desc; + } + + std::string GetErrDesc(const uint32_t err) { + const auto iter_find = static_cast::const_iterator>(err_desc_.find(err)); + if (iter_find == err_desc_.cend()) { + return ""; + } + return iter_find->second; + } + + AscendString GetErrDescV2(const uint32_t err) { + const auto iter_find = static_cast::const_iterator>(err_desc_.find(err)); + if (iter_find == err_desc_.cend()) { + return AscendString(""); + } + return AscendString(iter_find->second.c_str()); + } + + protected: + StatusFactory() = default; + ~StatusFactory() = default; + + private: + std::map err_desc_; +}; + +class GE_FUNC_VISIBILITY ErrorNoRegisterar { + public: + ErrorNoRegisterar(const uint32_t err, const std::string &desc) noexcept { + StatusFactory::Instance()->RegisterErrorNo(err, desc); + } + ErrorNoRegisterar(const uint32_t err, const char *const desc) noexcept { + StatusFactory::Instance()->RegisterErrorNo(err, desc); + } + ~ErrorNoRegisterar() = default; +}; + +// General error code +GE_ERRORNO(0, 0, 0, 0, 0, SUCCESS, 0, "success"); +GE_ERRORNO(0b11, 0b11, 0b111, 0xFFU, 0b11111, FAILED, 0xFFFU, "failed"); /*lint !e401*/ +} // 
namespace ge +#endif + +namespace ge { + GE_ERRORNO_DEFINE(0b01, 0b01, 0b000, 8, 0, END_OF_SEQUENCE, 7); +} +#endif // INC_EXTERNAL_GE_COMMON_GE_API_ERROR_CODES_H_ diff --git a/inc/metadef/external/ge_common/ge_api_types.h b/inc/metadef/external/ge_common/ge_api_types.h new file mode 100644 index 0000000000000000000000000000000000000000..cd04ca245aa74532882dfeba71886f820db7963e --- /dev/null +++ b/inc/metadef/external/ge_common/ge_api_types.h @@ -0,0 +1,752 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_GE_COMMON_GE_API_TYPES_H_ +#define INC_EXTERNAL_GE_COMMON_GE_API_TYPES_H_ + +#include +#include +#include +#include +#include +#include +#include "graph/tensor.h" +#include "graph/types.h" + +#ifndef GE_API_TYPES_DEF +#define GE_API_TYPES_DEF +namespace ge { +// Option key: graph run mode +const char_t *const OPTION_GRAPH_RUN_MODE = "ge.graphRunMode"; +const char_t *const OPTION_DEVICE_TYPE = "ge.deviceType"; +// Option key: ome init +const char_t *const OPTION_EXEC_SESSION_ID = "ge.exec.sessionId"; +const char_t *const OPTION_EXEC_DEVICE_ID = "ge.exec.deviceId"; +const char_t *const OPTION_EXEC_JOB_ID = "ge.exec.jobId"; +const char_t *const OPTION_EXEC_IS_USEHCOM = "ge.exec.isUseHcom"; +const char_t *const OPTION_EXEC_IS_USEHVD = "ge.exec.isUseHvd"; +const char_t *const OPTION_EXEC_RANK_ID = "ge.exec.rankId"; +const char_t *const OPTION_EXEC_POD_NAME = "ge.exec.podName"; +const char_t *const OPTION_EXEC_DEPLOY_MODE = "ge.exec.deployMode"; +const char_t *const OPTION_EXEC_RANK_TABLE_FILE = "ge.exec.rankTableFile"; +const char_t *const GE_AICPU_FLAG = "ge.aicpuFlag"; +const char_t *const OPTION_EXEC_EXTERN_PLUGIN_PATH = "ge.soLoadPath"; +// Dump flag and para +const char_t *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; +const char_t *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; +const char_t *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep"; +const char_t *const OPTION_EXEC_DUMP_MODE = "ge.exec.dumpMode"; +const char_t *const OPTION_EXEC_DUMP_DATA = "ge.exec.dumpData"; +const char_t *const OPTION_EXEC_DUMP_LAYER = "ge.exec.dumpLayer"; +const char_t *const OPTION_EXEC_ENABLE_DUMP_DEBUG = "ge.exec.enableDumpDebug"; +const char_t *const OPTION_EXEC_DUMP_DEBUG_MODE = "ge.exec.dumpDebugMode"; +const char_t *const OPTION_EXEC_ENABLE_INCRE_BUILD = "ge.exec.enableIncreBuild"; +const char_t *const 
OPTION_EXEC_INCRE_BUILD_CACHE_PATH = "ge.exec.increBuildCachePath"; +const char_t *const OPTION_EXEC_ENABLE_EXCEPTION_DUMP = "ge.exec.enable_exception_dump"; +const char_t *const OPTION_EXEC_ENABLE_SCOPE_FUSION_PASSES = "ge.exec.enableScopeFusionPasses"; +const char_t *const OPTION_EXEC_PROFILING_FPPONIT_OPTIONS = "ge.exec.profilingFpPointOptions"; +const char_t *const OPTION_EXEC_PROFILING_BPPONIT_OPTIONS = "ge.exec.profilingBpPointOptions"; +// profiling flag +const char_t *const OPTION_EXEC_PROFILING_MODE = "ge.exec.profilingMode"; +const char_t *const OPTION_EXEC_PROFILING_OPTIONS = "ge.exec.profilingOptions"; +// Hccl flag, if ge.exec.hcclFlag =1, it means load plugin for opskernel, else:ge.exec.hcclFlag =0 +const char_t *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag"; +const char_t *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic"; +const char_t *const OPTION_EXEC_DISABLE_REUSED_MEMORY = "ge.exec.disableReuseMemory"; +const char_t *const OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION = "ge.exec.isTailingOptimization"; +// Dynamic input flag. 
ge.exec.dynamicInput=1, means enable dynamic input, +// ge.exec.dynamicGraphExecuteMode, dynamic_execute[default] +const char_t *const OPTION_EXEC_DYNAMIC_INPUT = "ge.exec.dynamicInput"; +const char_t *const OPTION_EXEC_DYNAMIC_EXECUTE_MODE = "ge.exec.dynamicGraphExecuteMode"; +const char_t *const OPTION_EXEC_DATA_INPUTS_SHAPE_RANGE = "ge.exec.dataInputsShapeRange"; +const char_t *const OPTION_EXEC_ENABLE_COPY_OUTPUT_ADDR = "ge.exec.enableCopyOutputAddr"; +const char_t *const OPTION_EXEC_GRAPH_EXEC_TIMEOUT = "ge.exec.graphExecTimeout"; +const char_t *const OPTION_EXEC_OPTIMIZE_SHAPE = "ge.exec.dataInputsOptimizeShape"; +// dynamic graph parallel mode in heterogeneous scene +// "0": execute in parallel with independent streams(default), +// "1": execute in serial, +// "2": execute in parallel with single default stream (not recommended) +const char_t *const OPTION_EXEC_DYNAMIC_GRAPH_PARALLEL_MODE = "ge.experiment.dynamicGraphParallelMode"; + +// Option key: memory init +const char_t *const GRAPH_MEMORY_MAX_SIZE = "ge.graphMemoryMaxSize"; +const char_t *const VARIABLE_MEMORY_MAX_SIZE = "ge.variableMemoryMaxSize"; +const char_t *const OPTION_EXEC_REUSE_ZERO_COPY_MEMORY = "ge.exec.reuseZeroCopyMemory"; +const char_t *const OPTION_INPUT_REUSE_MEM_INDEXES = "ge.exec.inputReuseMemIndexes"; +const char_t *const OPTION_OUTPUT_REUSE_MEM_INDEXES = "ge.exec.outputReuseMemIndexes"; +const char_t *const OPTION_GRAPH_IO_MEM_ALLOC_MODE = "ge.exec.graphIOMemAllocMode"; +const char_t *const OPTION_FLOW_GRAPH_MEMORY_MAX_SIZE = "ge.flowGraphMemMaxSize"; + +const std::string ATOMIC_CLEAN_POLICY = "ge.exec.atomicCleanPolicy"; +const std::string MEMORY_OPTIMIZATION_POLICY = "ge.exec.memoryOptimizationPolicy"; +const std::string STATIC_MEMORY_POLICY = "ge.exec.staticMemoryPolicy"; +const char_t *const OPTION_FEATURE_BASE_REFRESHABLE = "ge.featureBaseRefreshable"; +const char_t *const OPTION_CONST_LIFECYCLE = "ge.constLifecycle"; + +const char_t *const 
OPTION_EXEC_LOGICAL_DEVICE_CLUSTER_DEPLOY_MODE = "ge.exec.logicalDeviceClusterDeployMode"; +const char_t *const OPTION_EXEC_LOGICAL_DEVICE_ID = "ge.exec.logicalDeviceId"; +const char_t *const OPTION_EXEC_MODEL_DEPLOY_MODE = "ge.exec.modelDeployMode"; +const char_t *const OPTION_EXEC_MODEL_DEPLOY_DEVICELIST = "ge.exec.modelDeployDevicelist"; +const char_t *const OPTION_EXEC_ENABLE_FUSION = "ge.exec.enableFusion"; + +// if the input_size is not bigger than this value, the H2D of those inputs will be merged (H2H2D) +const char_t *const OPTION_EXEC_INPUT_FUSION_SIZE = "ge.exec.input_fusion_size"; + +const std::string OPTION_EXEC_CM_CHIEF_IP = "ge.cmChiefIp"; +const std::string OPTION_EXEC_CM_CHIEF_PORT = "ge.cmChiefPort"; +const std::string OPTION_EXEC_CM_CHIEF_DEVICE = "ge.cmChiefWorkerDevice"; +const std::string OPTION_EXEC_CM_WORKER_IP = "ge.cmWorkerIp"; +const std::string OPTION_EXEC_CM_WORKER_SIZE = "ge.cmWorkerSize"; + +const std::string OPTION_NAME_MAP = "ge.optionNameMap"; + +const std::string OPTION_EXEC_STREAM_SYNC_TIMEOUT = "stream_sync_timeout"; +const std::string OPTION_EXEC_EVENT_SYNC_TIMEOUT = "event_sync_timeout"; + +// Option key: embedding service +const char_t *const OPTION_EXEC_PS_ID = "ge.exec.psId"; +const char_t *const OPTION_EXEC_CLUSTER_SPEC = "ge.exec.clusterSpec"; +const char_t *const OPTION_EXEC_RANK_TABLE_ADDR = "ge.exec.rankTableAddr"; +const char_t *const OPTION_EXEC_ROLE_TABLE_ADDR = "ge.exec.roleTableAddr"; +const char_t *const OPTION_EXEC_RANK_TABLE_LEN = "ge.exec.rankTableLen"; +const char_t *const OPTION_EXEC_ROLE_TABLE_LEN = "ge.exec.roleTableLen"; +const char_t *const OPTION_EXEC_WORKER_NUM = "ge.exec.workerNum"; +const char_t *const OPTION_MAX_KEY_NUM = "ge.max_num"; +const char_t *const OPTION_EMBEDDING_DIM = "ge.embedding_dim"; +const char_t *const OPTION_USE_COUNTER_FILTER = "ge.use_counter_filter"; + +// Option key: Offload +constexpr char_t const OPTION_EXEC_RANK_MAP[] = "ge.exec.rankMap"; + +// Option key: enable engine 
parallel or not +constexpr char_t const OPTION_EXEC_ENABLE_ENGINE_PARALLEL[] = "ge.exec.enableEngineParallel"; +constexpr char_t const OPTION_EXEC_ENGINE_PARALLEL_CONFIG_PATH[] = "ge.exec.engineParallelConfigPath"; +constexpr char_t const OPTION_EXEC_IS_IN_SHARD_GRAPH[] = "ge.exec.isInShardGraph"; + +// Option key: host env os & cpu +const char_t *const OPTION_HOST_ENV_OS = "ge.host_env_os"; +const char_t *const OPTION_HOST_ENV_CPU = "ge.host_env_cpu"; +const char_t *const OPTION_OP_DEPENDENCY_IN_OM = "ge.op_dependency_in_om"; + +// config value should be an existing directory +const char_t *const OPTION_GRAPH_COMPILER_CACHE_DIR = "ge.graph_compiler_cache_dir"; +// unique key of the graph +const char_t *const OPTION_GRAPH_KEY = "ge.graph_key"; + +// optimizations to disable, separated by commas +const char_t *const OPTION_DISABLE_OPTIMIZATIONS = "ge.disableOptimizations"; +const char_t *const ENABLE_GRAPH_PARALLEL = "ge.enableGraphParallel"; +const char_t *const GRAPH_PARALLEL_OPTION_PATH = "ge.graphParallelOptionPath"; +const std::string DISTRIBUTED_CLUSTER_BUILD = "ge.distributed_cluster_build"; +const std::string MODEL_RELATION_CONFIG = "ge.offline_model_relation"; +const std::string CLUSTER_CONFIG = "ge.cluster_config"; +const std::string OPTION_HCCL_COMPILER_OFFLINE = "ge.offline_hccl_compile"; + +// option for screen log +constexpr const char_t *OPTION_SCREEN_PRINT_MODE = "ge.screen_print_mode"; + +const char_t *const INPUT_FORMAT = "input_format"; +const char_t *const COMPRESS_WEIGHT_CONF = "compress_weight_conf"; +const char_t *const OPTION_EXEC_VARIABLE_ACC = "ge.exec.variable_acc"; +const char_t *const OPTION_EXEC_OP_DEBUG_CONFIG = "ge.exec.opDebugConfig"; +const char_t *const OPTION_JOB_TYPE = "ge.jobType"; +const char_t *const OPTION_TUNINGPATH = "ge.tuningPath"; +const char_t *const OPTION_AOE_CONFIG_FILE = "ge.aoe_config_file"; +const char_t *const OPTION_SESSION_DEVICE_ID = "ge.session_device_id"; +const char_t *const OPTION_DISTRIBUTE_CONFIG = "distribute_config"; +const 
char_t *const OPTION_EXEC_HCCL_EXECUTE_TIMEOUT = "ge.exec.hcclExecuteTimeOut"; +const char_t *const OPTION_EXEC_PLACEMENT = "ge.exec.placement"; +const char_t *const OPTION_IS_VAR_INIT_GRAPH = "ge.exec.isVarInitGraph"; +const char_t *const OPTION_EXEC_OVERFLOW = "ge.exec.overflow"; +const char_t *const OPTION_DATAFLOW_DEPLOY_INFO_PATH = "ge.experiment.data_flow_deploy_info_path"; +const char_t *const OPTION_MOMORY_POOL_THRESHOLD = "ge.experiment.memory_pool_threshold"; +const char_t *const OPTION_HCCL_ALGORITHM = "HCCL_algorithm"; +const char_t *const OPTION_ES_CLUSTER_CONFIG = "ge.esClusterConfig"; +const char_t *const OPTION_EXECUTE_TIMES = "execute_times"; +const char_t *const OPTION_ES_MAX_REMOTEOP_NUM_PER_STREAM = "es_max_remoteop_num_per_stream"; +const char_t *const OPTION_HOST_SCHEDULING_MAX_THRESHOLD = "ge.exec.hostSchedulingMaxThreshold"; + +// option for experimental +const char_t *const OPTION_STATIC_MODEL_OPS_LOWER_LIMIT = "ge.exec.static_model_ops_lower_limit"; + +namespace configure_option { +const char_t *const STREAM_NUM = "ge.streamNum"; +const char_t *const HEAD_STREAM = "ge.headStream"; +const char_t *const AC_PARALLEL_ENABLE = "ac_parallel_enable"; +const char_t *const PERF_LEVEL = "ge.perfLevel"; +const char_t *const ENCRYPT_MODE = "ge.encryptMode"; +const char_t *const EK_FILE = "ge.ekFile"; +const char_t *const CERT_FILE = "ge.certFile"; +const char_t *const HW_KEY_FILE = "ge.hwKeyFile"; +const char_t *const PRIVATE_KEY_FILE = "ge.privateKeyFile"; +const char_t *const FRAMEWORK_TYPE = "ge.frameworkType"; +const char_t *const CALIBRATION_CONF_FILE = "ge.calibrationConfFile"; +const char_t *const INSERT_OP_FILE = "ge.insertOpFile"; +const char_t *const OUTPUT_NODE_NAME = "ge.outputNodeName"; +const char_t *const COMPRESS_FLAG = "ge.compressFlag"; +const char_t *const PRECISION_MODE = "ge.exec.precision_mode"; +const char_t *const PRECISION_MODE_V2 = "ge.exec.precision_mode_v2"; +const char_t *const SINGLE_OP_FLAG = "ge.exec.single_op"; +const 
char_t *const TRAIN_FLAG = "ge.trainFlag"; +const char_t *const RUN_FLAG = "ge.runFlag"; +const char_t *const LOCAL_FMKOP_FLAG = "ge.enabledLocalFmkop"; +const char_t *const TBE_PLUGIN_PATH_FLAG = "ge.TBE_plugin_path"; +const char_t *const DDK_VERSION_FLAG = "ge.DDK_version"; +const char_t *const GE_FE_FLAG = "ge.feFlag"; +const char_t *const STREAM_MAX_PARALLEL_NUM = "ge.streamMaxParallelNum"; +const char_t *const OUTPUT_DATATYPE = "ge.outputDatatype"; +const char_t *const OP_SELECT_IMPL_MODE = "ge.opSelectImplmode"; +const char_t *const OPTYPELIST_FOR_IMPLMODE = "ge.optypelistForImplmode"; +const char_t *const HCOM_PARALLEL = "ge.hcomParallel"; +const char_t *const AUTO_TUNE_MODE = "ge.autoTuneMode"; +const char_t *const SOC_VERSION = "ge.socVersion"; +const char_t *const VIRTUAL_TYPE = "ge.virtual_type"; +const char_t *const CORE_TYPE = "ge.engineType"; +const char_t *const AICORE_NUM = "ge.aicoreNum"; +const char_t *const L1_FUSION = "ge.l1Fusion"; +const char_t *const BUFFER_OPTIMIZE = "ge.bufferOptimize"; +const char_t *const ENABLE_SMALL_CHANNEL = "ge.enableSmallChannel"; +const char_t *const ENABLE_COMPRESS_WEIGHT = "ge.enableCompressWeight"; +const char_t *const FUSION_SWITCH_FILE = "ge.fusionSwitchFile"; +const char_t *const SAVE_ORIGINAL_MODEL = "ge.saveOriginalModel"; +const char_t *const ORIGINAL_MODEL_FILE = "ge.originalModelFile"; +const char_t *const INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16"; +const char_t *const OP_DEBUG_LEVEL = "ge.opDebugLevel"; +const char_t *const PERFORMANCE_MODE = "ge.performance_mode"; +const char_t *const SHAPE_GENERALIZED_BUILD_MODE = "ge.shape_generalized_build_mode"; +const char_t *const MODIFY_MIXLIST = "ge.exec.modify_mixlist"; +const char_t *const OP_PRECISION_MODE = "ge.exec.op_precision_mode"; +const char_t *const ALLOW_HF32 = "ge.exec.allow_hf32"; +const char_t *const CUSTOMIZE_DTYPES = "ge.customizeDtypes"; +const char_t *const COMPRESSION_OPTIMIZE_CONF = "ge.compressionOptimizeConf"; +const char_t *const 
OP_DEBUG_CONFIG = "op_debug_config"; +const char_t *const ATOMIC_CLEAN_POLICY = "ge.exec.atomicCleanPolicy"; +const char_t *const STATIC_MEMORY_POLICY = "ge.exec.staticMemoryPolicy"; +const char_t *const EXTERNAL_WEIGHT = "ge.externalWeight"; +static const char_t *const QUANT_DUMPABLE = "ge.quant_dumpable"; +const char_t *const QUANT_BIAS_OPTIMIZE = "ge.experiment.quant_bias_optimize"; +static const char_t *const DETERMINISTIC = "ge.deterministic"; +const char_t *const OP_DEBUG_OPTION = "op_debug_option"; +const char_t *const TILING_SCHEDULE_OPTIMIZE = "ge.tiling_schedule_optimize"; +const char_t *const GRAPH_MAX_PARALLEL_MODEL_NUM = "ge.graphMaxParallelModelNum"; +} // namespace configure_option +// Configure stream num by Session constructor options param, +// its value should be int32_t type, default value is "1" +const std::string STREAM_NUM = "ge.streamNum"; + +// Configure add head stream to model. +// its value should be "0" or "1", default value is "0" +const std::string HEAD_STREAM = "ge.headStream"; + +// Configure engines such as Aicpu to compute parallelly with other engines in dynamic shape graphs. 
+// its value should be "0" or "1", default value is "0" +const std::string AC_PARALLEL_ENABLE = "ac_parallel_enable"; + +// Configure perf level by Session constructor options param, +// its value please see enum PerfLevel, default value is "4" +const std::string PERF_LEVEL = "ge.perfLevel"; + +// Configure encrypt mode by Session constructor options param, +// its value should be int32_t type, default value is "-1" +const std::string ENCRYPT_MODE = "ge.encryptMode"; + +// configure ek file by Session constructor options param, +// its value should be file path, default value is "" +const std::string EK_FILE = "ge.ekFile"; + +// Configure cert file by Session constructor options param, +// its value should be file path, default value is "" +const std::string CERT_FILE = "ge.certFile"; + +// Configure hw key file by Session constructor options param, +// its value should be file path, default value is "" +const std::string HW_KEY_FILE = "ge.hwKeyFile"; + +// Configure private file by Session constructor options param, +// its value should be file path, default value is "" +const std::string PRIVATE_KEY_FILE = "ge.privateKeyFile"; + +// Configure framework type by Session constructor options param, +// its value please see enum FrameworkType, default value is "3" +const std::string FRAMEWORK_TYPE = "ge.frameworkType"; + +// Configure calibration info file by Session constructor options param, +// its value should be file path, default value is "" +const std::string CALIBRATION_CONF_FILE = "ge.calibrationConfFile"; + +// Configure insert op info file by Session constructor options param, +// its value should be file path, default value is "" +const std::string INSERT_OP_FILE = "ge.insertOpFile"; + +// Configure output node name by Session constructor options param, +// its value should be std::string type, default value is "" +const std::string OUTPUT_NODE_NAME = "ge.outputNodeName"; + +// Configure weight compress flag by Session constructor options param, +// its 
value should be "0" or "1", default value is "0" +const std::string COMPRESS_FLAG = "ge.compressFlag"; + +const std::string PRECISION_MODE = "ge.exec.precision_mode"; + +const std::string PRECISION_MODE_V2 = "ge.exec.precision_mode_v2"; + +const std::string TUNE_DEVICE_IDS = "ge.exec.tuneDeviceIds"; + +// Configure single op flag for FE +// its value should be "0" or "1", default value is "0" +const std::string SINGLE_OP_FLAG = "ge.exec.single_op"; + +// Configure train flag by Session constructor options param, +// its value should be "0" or "1", default value is "0" +const std::string TRAIN_FLAG = "ge.trainFlag"; + +// Configure run flag by Session constructor options param, +// its value should be "0" or "1", default value is "0" +const std::string RUN_FLAG = "ge.runFlag"; + +// Configure local framework op flag by Session constructor options param, +// its value should be "0" or "1", default value is "0" +// this option is to enable local framework op feature +const std::string LOCAL_FMKOP_FLAG = "ge.enabledLocalFmkop"; + +// Configure TBE op plugin path by Session constructor options param, +// its value should be a path +// this option is to obtain the TBE op plugin path +const std::string TBE_PLUGIN_PATH_FLAG = "ge.TBE_plugin_path"; + +// Configure DDK version flag by Session constructor options param, +// its value should be a path +// this option is to obtain the DDK Version info +const std::string DDK_VERSION_FLAG = "ge.DDK_version"; + +// Configure fe flag by Session constructor options param, +// its value should be a path +// this option is to obtain fe flag +const std::string GE_FE_FLAG = "ge.feFlag"; + +// Configure stream max parallel num only by Session constructor options param, +// its value should be stream:int, such as "DNN_V100:2,DNN_HCCL:3", +// default value is "1", such as "DNN_V100:1,DNN_HCCL:1" +// this option is to obtain stream max parallel num +const std::string STREAM_MAX_PARALLEL_NUM = "ge.streamMaxParallelNum"; + +// configure outputDatatype to set the net output 
type +const std::string OUTPUT_DATATYPE = "ge.outputDatatype"; + +// configure opSelectImplmode to set the op select implmode +const std::string OP_SELECT_IMPL_MODE = "ge.opSelectImplmode"; + +// configure optypelist_for_implmode to set which ops use implmode +const std::string OPTYPELIST_FOR_IMPLMODE = "ge.optypelistForImplmode"; + +// configure whether to enable hcom parallel by session constructor options param, +// its value should be "0" or "1", default value is "0" +const std::string HCOM_PARALLEL = "ge.hcomParallel"; + +// configure whether to use dynamic batch size +const char_t *const kDynamicBatchSize = "ge.dynamicBatchSize"; + +// configure threshold of fusion data size for communication op +const std::string FUSION_TENSOR_SIZE = "ge.fusionTensorSize"; + +const std::string INPUT_SHAPE = "ge.inputShape"; + +const std::string OUTPUT_MAX_SIZE = "ge.outputMaxSize"; + +const std::string DYNAMIC_NODE_TYPE = "ge.dynamicNodeType"; +// configure whether to use dynamic image size +const char_t *const kDynamicImageSize = "ge.dynamicImageSize"; + +// Configure whether to use dynamic dims +const char_t *const kDynamicDims = "ge.dynamicDims"; + +// Configure auto tune mode, this option only takes effect while AUTO_TUNE_FLAG is Y, +// example: GA|RL, multiple modes may be configured, separated by | +const std::string AUTO_TUNE_MODE = "ge.autoTuneMode"; + +// Configure soc version, example: "Ascend310" +const std::string SOC_VERSION = "ge.socVersion"; + +// configure whether to enable virtualization, +// its value should be "0" or "1", default value is "0" +const std::string VIRTUAL_TYPE = "ge.virtual_type"; + +// Configure core type "VectorEngine", default value is "AIcoreEngine" +const std::string CORE_TYPE = "ge.engineType"; + +// Configure graph exclude one or more engines +const std::string EXCLUDE_ENGINES = "ge.exec.exclude_engines"; + +// Configure AICORE NUM +const std::string AICORE_NUM = "ge.aicoreNum"; + +// Configure L1FUSION +const std::string L1_FUSION = 
"ge.l1Fusion"; + +// Configure l1, l2 and other optimize options +const std::string BUFFER_OPTIMIZE = "ge.bufferOptimize"; + +// Configure Small Channel flag +const std::string ENABLE_SMALL_CHANNEL = "ge.enableSmallChannel"; + +// Configure Compress Weight flag +const std::string ENABLE_COMPRESS_WEIGHT = "ge.enableCompressWeight"; + +// Configure Sparse Matrix Weight flag +const std::string ENABLE_SPARSE_MATRIX_WEIGHT = "ge.enableSparseMatrixWeight"; + +// Configure fusion switch file path +const std::string FUSION_SWITCH_FILE = "ge.fusionSwitchFile"; + +// Configure compression optimize file path +const std::string COMPRESSION_OPTIMIZE_CONF = "ge.compressionOptimizeConf"; + +// Configure customize dtypes path +const std::string CUSTOMIZE_DTYPES = "ge.customizeDtypes"; + +// Configure switch for op debug config such as op memory detection +const std::string OP_DEBUG_CONFIG = "op_debug_config"; + +// Save original model +const std::string SAVE_ORIGINAL_MODEL = "ge.saveOriginalModel"; + +// Save original model file name +const std::string ORIGINAL_MODEL_FILE = "ge.originalModelFile"; + +const char_t *const OPTION_GE_MAX_DUMP_FILE_NUM = "ge.maxDumpFileNum"; +const char_t *const OPTION_GE_MAX_DUMP_FILE_SIZE = "ge.maxDumpFileSize"; +const char_t *const OPTION_GE_MAX_DUMP_OP_NUM = "ge.maxDumpOpNum"; + +// Configure for print op pass +// Its value should be "0" or "1", default value is "1" +const char_t *const ENABLE_PRINT_OP_PASS = "ge.enablePrintOpPass"; + +// Configure operator compilation path +// Its value should be a file path, default value is "./" +const char_t *const DEBUG_DIR = "ge.debugDir"; + +// Configure switch for op status check such as overflow +// Its value should be true or false +const char_t *const STATUS_CHECK = "ge.status_check"; + +// Configure operator compiler cache path +// Its value should be a file path, default value is "./" +const char_t *const OP_COMPILER_CACHE_DIR = "ge.op_compiler_cache_dir"; + +// Configure operator compiler cache mode +// 
Its value should be "disable", "enable" or "force", default value is "disable" +const char_t *const OP_COMPILER_CACHE_MODE = "ge.op_compiler_cache_mode"; + +// Configure build model type. FE need this option to judge inner model or not +// Its value should be "true" or "false" +const char_t *const BUILD_INNER_MODEL = "ge.build_inner_model"; + +// Configure whether to use single stream. +// Its value should be "true" or "false", default value is "false" +const char_t *const ENABLE_SINGLE_STREAM = "ge.enableSingleStream"; + +// Configure input fp16 nodes +const std::string INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16"; + +// Configure debug level, its value should be 0(default), 1 or 2. +// 0: close debug; 1: open TBE compiler; 2: open ccec compiler +const std::string OP_DEBUG_LEVEL = "ge.opDebugLevel"; + +// configure op compile param, example: "oom,-g" +const std::string OP_DEBUG_OPTION = "op_debug_option"; + +// Configure model bank path +const std::string MDL_BANK_PATH_FLAG = "ge.mdl_bank_path"; + +// Configure display_model_info flag +const std::string DISPLAY_MODEL_INFO = "ge.display_model_info"; + +// Configure op bank path +const std::string OP_BANK_PATH_FLAG = "ge.op_bank_path"; +const std::string OP_BANK_UPDATE_FLAG = "ge.op_bank_update"; + +// Configure for fix hcombroadcast format. +// when config model multi, broadcast format should be fixed +// 0: data multi; 1: model multi; +const std::string HCOM_MULTI_MODE = "ge.hcomMultiMode"; + +// atc and ir option +const char_t *const INPUT_SHAPE_RANGE = "input_shape_range"; + +// Configure express high compile performance or high execute performance +// normal: no need to compile, used saved .o files directly +// high: need to recompile, high execute performance mode +const std::string PERFORMANCE_MODE = "ge.performance_mode"; + +// For selecting the mode of shape generalization when build graph. +// shape_generalized: Shape will be generalized during graph build. 
+// shape_precise: Shape will not be generalized, use precise shape. +const std::string SHAPE_GENERALIZED_BUILD_MODE = "ge.shape_generalized_build_mode"; + +const std::string JIT_COMPILE = "ge.jit_compile"; + +const std::string MODIFY_MIXLIST = "ge.exec.modify_mixlist"; + +const std::string OP_PRECISION_MODE = "ge.exec.op_precision_mode"; + +const std::string ALLOW_HF32 = "ge.exec.allow_hf32"; + +const std::string OP_WAIT_TIMEOUT = "ge.exec.opWaitTimeout"; + +const std::string OP_EXECUTE_TIMEOUT = "ge.exec.opExecuteTimeout"; + +const char_t *const FILE_CONSTANT_PATH = "ge.exec.value_bins"; + +// Configure whether convert const to fileconstant and save weight to file. +// Its value should be "0" or "1", default value is "0". +const std::string EXTERNAL_WEIGHT = "ge.externalWeight"; + +const std::string DETERMINISTIC = "ge.deterministic"; + +const std::string QUANT_DUMPABLE = "ge.quant_dumpable"; + +const std::string QUANT_BIAS_OPTIMIZE = "ge.experiment.quant_bias_optimize"; + +// option for tiling sink +const std::string TILING_SCHEDULE_OPTIMIZE = "ge.tiling_schedule_optimize"; + +const std::string GRAPH_MAX_PARALLEL_MODEL_NUM = "ge.graphMaxParallelModelNum"; + +constexpr char_t EVENT[] = "ge.event"; + +// option for optimization options +const char_t *const OO_LEVEL = "ge.oo.level"; +const char_t *const OO_CONSTANT_FOLDING = "ge.oo.constantFolding"; +const char_t *const OO_DEAD_CODE_ELIMINATION = "ge.oo.deadCodeElimination"; + +// Configure statistics of the graph compiler +const char_t *const OPTION_EXPORT_COMPILE_STAT = "ge.exportCompileStat"; + +// Graph run mode +enum GraphRunMode : std::int32_t { PREDICTION = 0, TRAIN }; +// Input/Output tensor info +struct InputTensorInfo { + uint32_t data_type; // data type + std::vector dims; // shape description + void *data; // tensor data + int64_t length; // tensor length +}; + +struct OutputTensorInfo { + uint32_t data_type{}; // data type + std::vector dims; // shape description + std::unique_ptr data; // tensor data 
+ int64_t length{}; // tensor length + OutputTensorInfo() : dims({}), data(nullptr) {} + OutputTensorInfo(OutputTensorInfo &&out) noexcept = default; + OutputTensorInfo &operator=(OutputTensorInfo &&out)& noexcept { // move assignment: takes over every field of out; self-assignment is a no-op + if (this != &out) { + data_type = out.data_type; + dims = std::move(out.dims); // move, not copy: a copy allocates and could throw inside this noexcept function + data = std::move(out.data); + length = out.length; + } + return *this; + } + OutputTensorInfo(const OutputTensorInfo &) = delete; + OutputTensorInfo &operator=(const OutputTensorInfo &)& = delete; +}; + +struct ModelDistibuteDesc { + uint32_t logic_device_number; +}; + +enum class MemoryType : std::int64_t { + /* + * call aclrtMalloc with aclrtMemMallocPolicy::ACL_MEM_MALLOC_HUGE_FIRST, + * ACL_MEM_MALLOC_HUGE_ONLY, ACL_MEM_MALLOC_NORMAL_ONLY + */ + MEMORY_TYPE_DEFAULT, + /* + * call aclrtMalloc with aclrtMemMallocPolicy::ACL_MEM_MALLOC_HUGE_FIRST_P2P, + * ACL_MEM_MALLOC_HUGE_ONLY_P2P, ACL_MEM_MALLOC_NORMAL_ONLY_P2P + */ + MEMORY_TYPE_P2P +}; + +using Status = uint32_t; +using RunAsyncCallback = std::function &)>; + +// for ir build +namespace ir_option { +static const char_t *const INPUT_FORMAT = "input_format"; +static const char_t *const INPUT_SHAPE = "input_shape"; +static const char_t *const INPUT_SHAPE_RANGE = ge::INPUT_SHAPE_RANGE; +static const char_t *const OP_NAME_MAP = "op_name_map"; +static const char_t *const IS_DYNAMIC_INPUT = "is_dynamic_input"; +static const char_t *const IS_INPUT_ADJUST_HW_LAYOUT = "is_input_adjust_hw_layout"; +static const char_t *const IS_OUTPUT_ADJUST_HW_LAYOUT = "is_output_adjust_hw_layout"; +static const char_t *const ENABLE_SCOPE_FUSION_PASSES = "enable_scope_fusion_passes"; +static const char_t *const OUTPUT = "output"; +static const char_t *const DYNAMIC_BATCH_SIZE = kDynamicBatchSize; +static const char_t *const DYNAMIC_IMAGE_SIZE = kDynamicImageSize; +static const char_t *const DYNAMIC_DIMS = kDynamicDims; +static const char_t *const INSERT_OP_FILE = ge::INSERT_OP_FILE.c_str(); +static const char_t *const PRECISION_MODE = 
ge::PRECISION_MODE.c_str(); +static const char_t *const PRECISION_MODE_V2 = ge::PRECISION_MODE_V2.c_str(); +static const char_t *const TUNE_DEVICE_IDS = ge::TUNE_DEVICE_IDS.c_str(); +static const char_t *const EXEC_DISABLE_REUSED_MEMORY = ge::OPTION_EXEC_DISABLE_REUSED_MEMORY; +static const char_t *const AUTO_TUNE_MODE = ge::AUTO_TUNE_MODE.c_str(); +static const char_t *const CORE_TYPE = ge::CORE_TYPE.c_str(); +static const char_t *const SOC_VERSION = ge::SOC_VERSION.c_str(); +static const char_t *const VIRTUAL_TYPE = ge::VIRTUAL_TYPE.c_str(); +static const char_t *const ENABLE_SINGLE_STREAM = ge::ENABLE_SINGLE_STREAM; +static const char_t *const AC_PARALLEL_ENABLE = ge::AC_PARALLEL_ENABLE.c_str(); +static const char_t *const AICORE_NUM = ge::AICORE_NUM.c_str(); +static const char_t *const FUSION_SWITCH_FILE = ge::FUSION_SWITCH_FILE.c_str(); +static const char_t *const ENABLE_SMALL_CHANNEL = ge::ENABLE_SMALL_CHANNEL.c_str(); +static const char_t *const OP_SELECT_IMPL_MODE = ge::OP_SELECT_IMPL_MODE.c_str(); +static const char_t *const OUTPUT_TYPE = ge::OUTPUT_DATATYPE.c_str(); +static const char_t *const BUFFER_OPTIMIZE = ge::BUFFER_OPTIMIZE.c_str(); +static const char_t *const ENABLE_COMPRESS_WEIGHT = ge::ENABLE_COMPRESS_WEIGHT.c_str(); +static const char_t *const SPARSITY = ge::ENABLE_SPARSE_MATRIX_WEIGHT.c_str(); +static const char_t *const COMPRESS_WEIGHT_CONF = "compress_weight_conf"; +static const char_t *const OUT_NODES = ge::OUTPUT_NODE_NAME.c_str(); +static const char_t *const INPUT_FP16_NODES = ge::INPUT_FP16_NODES.c_str(); +static const char_t *const LOG_LEVEL = "log"; +static const char_t *const OPTYPELIST_FOR_IMPLMODE = ge::OPTYPELIST_FOR_IMPLMODE.c_str(); +static const char_t *const DEBUG_DIR = ge::DEBUG_DIR; +static const char_t *const OP_COMPILER_CACHE_DIR = ge::OP_COMPILER_CACHE_DIR; +static const char_t *const OP_COMPILER_CACHE_MODE = ge::OP_COMPILER_CACHE_MODE; +static const char_t *const BUILD_INNER_MODEL = ge::BUILD_INNER_MODEL; +static const 
char_t *const MDL_BANK_PATH = ge::MDL_BANK_PATH_FLAG.c_str(); +static const char_t *const OP_BANK_PATH = ge::OP_BANK_PATH_FLAG.c_str(); +static const char_t *const OP_BANK_UPDATE = ge::OP_BANK_UPDATE_FLAG.c_str(); +static const char_t *const OP_DEBUG_LEVEL = ge::OP_DEBUG_LEVEL.c_str(); +static const char_t *const PERFORMANCE_MODE = ge::PERFORMANCE_MODE.c_str(); +static const char_t *const SHAPE_GENERALIZED_BUILD_MODE = ge::SHAPE_GENERALIZED_BUILD_MODE.c_str(); +static const char_t *const MODIFY_MIXLIST = ge::MODIFY_MIXLIST.c_str(); +static const char_t *const OP_PRECISION_MODE = ge::OP_PRECISION_MODE.c_str(); +static const char_t *const ALLOW_HF32 = ge::ALLOW_HF32.c_str(); +static const char_t *const CUSTOMIZE_DTYPES = "ge.customizeDtypes"; +static const char_t *const COMPRESSION_OPTIMIZE_CONF = "ge.compressionOptimizeConf"; +static const char_t *const INPUT_DATA_NAMES = "input_data_names"; +static const char_t *const OP_DEBUG_CONFIG = "op_debug_config"; +static const char_t *const ATOMIC_CLEAN_POLICY = "ge.exec.atomicCleanPolicy"; +static const char_t *const EXTERNAL_WEIGHT = ge::EXTERNAL_WEIGHT.c_str(); +static const char_t *const EXCLUDE_ENGINES = ge::EXCLUDE_ENGINES.c_str(); +static const char_t *const DETERMINISTIC = ge::DETERMINISTIC.c_str(); +static const char_t *const DISTRIBUTED_CLUSTER_BUILD = ge::DISTRIBUTED_CLUSTER_BUILD.c_str(); +static const char_t *const MODEL_RELATION_CONFIG = ge::MODEL_RELATION_CONFIG.c_str(); +static const char_t *const CLUSTER_CONFIG = ge::CLUSTER_CONFIG.c_str(); +static const char_t *const ENABLE_GRAPH_PARALLEL = "ge.enableGraphParallel"; +static const char_t *const GRAPH_PARALLEL_OPTION_PATH = "ge.graphParallelOptionPath"; +static const char_t *const QUANT_DUMPABLE = ge::QUANT_DUMPABLE.c_str(); +static const char_t *const QUANT_BIAS_OPTIMIZE = ge::QUANT_BIAS_OPTIMIZE.c_str(); +static const char_t *const OP_DEPENDENCY_IN_OM = ge::OPTION_OP_DEPENDENCY_IN_OM; +static const char_t *const OP_DEBUG_OPTION = 
ge::OP_DEBUG_OPTION.c_str(); +static const char_t *const TILING_SCHEDULE_OPTIMIZE = ge::TILING_SCHEDULE_OPTIMIZE.c_str(); +static const char_t *const GRAPH_MAX_PARALLEL_MODEL_NUM = ge::GRAPH_MAX_PARALLEL_MODEL_NUM.c_str(); +static const char_t *const OO_LEVEL = ge::OO_LEVEL; +static const char_t *const OO_CONSTANT_FOLDING = ge::OO_CONSTANT_FOLDING; +static const char_t *const OO_DEAD_CODE_ELIMINATION = ge::OO_DEAD_CODE_ELIMINATION; +static const char_t *const OPTION_EXPORT_COMPILE_STAT = ge::OPTION_EXPORT_COMPILE_STAT; +// for interface: aclgrphBuildModel +#ifdef __GNUC__ +const std::set ir_builder_suppported_options = {INPUT_FORMAT, + INPUT_SHAPE, + INPUT_SHAPE_RANGE, + OP_NAME_MAP, + DYNAMIC_BATCH_SIZE, + DYNAMIC_IMAGE_SIZE, + DYNAMIC_DIMS, + INSERT_OP_FILE, + OP_PRECISION_MODE, + ALLOW_HF32, + PRECISION_MODE, + PRECISION_MODE_V2, + TUNE_DEVICE_IDS, + EXEC_DISABLE_REUSED_MEMORY, + AUTO_TUNE_MODE, + OUTPUT_TYPE, + OUT_NODES, + INPUT_FP16_NODES, + LOG_LEVEL, + OP_DEBUG_LEVEL, + OP_DEBUG_OPTION, + DEBUG_DIR, + OP_COMPILER_CACHE_DIR, + OP_COMPILER_CACHE_MODE, + MDL_BANK_PATH, + OP_BANK_PATH, + OP_BANK_UPDATE, + PERFORMANCE_MODE, + SHAPE_GENERALIZED_BUILD_MODE, + MODIFY_MIXLIST, + CUSTOMIZE_DTYPES, + BUILD_INNER_MODEL, + OP_DEBUG_CONFIG, + EXCLUDE_ENGINES, + EXTERNAL_WEIGHT, + DISTRIBUTED_CLUSTER_BUILD, + MODEL_RELATION_CONFIG, + ENABLE_GRAPH_PARALLEL, + AC_PARALLEL_ENABLE, + GRAPH_PARALLEL_OPTION_PATH, + QUANT_DUMPABLE, + TILING_SCHEDULE_OPTIMIZE, + GRAPH_MAX_PARALLEL_MODEL_NUM, + OO_LEVEL, + OO_CONSTANT_FOLDING, + OO_DEAD_CODE_ELIMINATION, + OPTION_EXPORT_COMPILE_STAT}; + +// for interface: aclgrphParse +const std::set ir_parser_suppported_options = { + INPUT_FP16_NODES, IS_INPUT_ADJUST_HW_LAYOUT, IS_OUTPUT_ADJUST_HW_LAYOUT, OUTPUT, + OUT_NODES, ENABLE_SCOPE_FUSION_PASSES, INPUT_DATA_NAMES, INPUT_SHAPE}; + +// for interface: aclgrphBuildInitialize +const std::set global_options = {CORE_TYPE, + SOC_VERSION, + VIRTUAL_TYPE, + BUFFER_OPTIMIZE, + ENABLE_COMPRESS_WEIGHT, 
+ COMPRESS_WEIGHT_CONF, + SPARSITY, + PRECISION_MODE, + PRECISION_MODE_V2, + ALLOW_HF32, + TUNE_DEVICE_IDS, + EXEC_DISABLE_REUSED_MEMORY, + AUTO_TUNE_MODE, + ENABLE_SINGLE_STREAM, + AC_PARALLEL_ENABLE, + AICORE_NUM, + FUSION_SWITCH_FILE, + ENABLE_SMALL_CHANNEL, + OP_SELECT_IMPL_MODE, + OPTYPELIST_FOR_IMPLMODE, + OP_DEBUG_LEVEL, + OP_DEBUG_OPTION, + DEBUG_DIR, + OP_COMPILER_CACHE_DIR, + OP_COMPILER_CACHE_MODE, + MODIFY_MIXLIST, + COMPRESSION_OPTIMIZE_CONF, + OP_DEBUG_CONFIG, + DETERMINISTIC, + CLUSTER_CONFIG, + OP_DEPENDENCY_IN_OM, + TILING_SCHEDULE_OPTIMIZE, + GRAPH_MAX_PARALLEL_MODEL_NUM, + OO_LEVEL, + OO_CONSTANT_FOLDING, + OO_DEAD_CODE_ELIMINATION, + OPTION_EXPORT_COMPILE_STAT}; +#endif +} // namespace ir_option +} // namespace ge +#endif // GE_API_TYPES_DEF +#endif // INC_EXTERNAL_GE_COMMON_GE_API_TYPES_H_ diff --git a/inc/metadef/external/ge_common/ge_error_codes.h b/inc/metadef/external/ge_common/ge_error_codes.h new file mode 100644 index 0000000000000000000000000000000000000000..af29330a9261bd980d0b59cb983410191ac8a905 --- /dev/null +++ b/inc/metadef/external/ge_common/ge_error_codes.h @@ -0,0 +1,29 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_GE_COMMON_GE_ERROR_CODES_H_ +#define INC_EXTERNAL_GE_COMMON_GE_ERROR_CODES_H_ + +#if defined(_MSC_VER) +#ifdef FUNC_VISIBILITY +#define GE_FUNC_VISIBILITY _declspec(dllexport) +#else +#define GE_FUNC_VISIBILITY +#endif +#else +#ifdef FUNC_VISIBILITY +#define GE_FUNC_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_VISIBILITY +#endif +#endif + +#include +#include +#endif // INC_EXTERNAL_GE_COMMON_GE_ERROR_CODES_H_ diff --git a/inc/metadef/external/graph/ascend_string.h b/inc/metadef/external/graph/ascend_string.h new file mode 100644 index 0000000000000000000000000000000000000000..55ce54ea4f3a02e111f1321c893cfc0010a2c377 --- /dev/null +++ b/inc/metadef/external/graph/ascend_string.h @@ -0,0 +1,66 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_GRAPH_ASCEND_STRING_H_ +#define INC_EXTERNAL_GRAPH_ASCEND_STRING_H_ + +#include +#include +#include +#include "graph/types.h" + +namespace ge { +class AscendString { + public: + AscendString() = default; + + ~AscendString() = default; + + AscendString(const char_t *const name); + + AscendString(const char_t *const name, size_t length); + + const char_t *GetString() const; + + size_t GetLength() const; + + size_t Find(const AscendString &ascend_string) const; + + size_t Hash() const; + + bool operator<(const AscendString &d) const; + + bool operator>(const AscendString &d) const; + + bool operator<=(const AscendString &d) const; + + bool operator>=(const AscendString &d) const; + + bool operator==(const AscendString &d) const; + + bool operator!=(const AscendString &d) const; + + bool operator==(const char_t *const d) const; + + bool operator!=(const char_t *const d) const; + + private: + std::shared_ptr name_; +}; +} // namespace ge + +namespace std { +template<> +struct hash { + size_t operator()(const ge::AscendString &name) const { + return name.Hash(); + } +}; +} // namespace std +#endif // INC_EXTERNAL_GRAPH_ASCEND_STRING_H_ diff --git a/inc/metadef/external/graph/attr_value.h b/inc/metadef/external/graph/attr_value.h new file mode 100644 index 0000000000000000000000000000000000000000..c36ca32e73f04948ae8ebfb29472ecc238493459 --- /dev/null +++ b/inc/metadef/external/graph/attr_value.h @@ -0,0 +1,71 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_GRAPH_ATTR_VALUE_H_ +#define INC_EXTERNAL_GRAPH_ATTR_VALUE_H_ + +#include +#include +#include +#include + +#include "./ge_error_codes.h" +#include "ascend_string.h" + +using std::make_shared; +using std::map; +using std::pair; +using std::string; +using std::to_string; +using std::unique_ptr; +using std::vector; + +namespace ge { +class AttrValueImpl; +/*lint -e148*/ +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue { + public: + using INT = int64_t; + using FLOAT = float; + using STR = std::string; + + AttrValue(); + ~AttrValue() = default; + + // GetValue, not list type + template + graphStatus GetValue(DT &val) const { + T valGet; + const auto status = GetValue(valGet); + if (status != GRAPH_SUCCESS) { + return status; + } + val = DT(valGet); + return GRAPH_SUCCESS; + } + + template + static T CreateFrom(DT &&val) { + return val; + } + + graphStatus GetValue(AscendString &val); + + std::shared_ptr impl; + + private: +#define VALUE_SET_GET_DEC(DT) graphStatus GetValue(DT &val) const; + VALUE_SET_GET_DEC(AttrValue::STR) + VALUE_SET_GET_DEC(AttrValue::INT) + VALUE_SET_GET_DEC(AttrValue::FLOAT) +#undef VALUE_SET_GET_DEC +}; +/*lint +e148*/ +} // namespace ge +#endif // INC_EXTERNAL_GRAPH_ATTR_VALUE_H_ diff --git a/inc/metadef/external/graph/compiler_def.h b/inc/metadef/external/graph/compiler_def.h new file mode 100644 index 0000000000000000000000000000000000000000..360763ecfce20bd05c43dcba95e009cbdcc7cc8e --- /dev/null +++ b/inc/metadef/external/graph/compiler_def.h @@ -0,0 +1,17 @@ +/* Copyright (c) 2024 
Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXTERNAL_GRAPH_COMPILER_DEF_H_ +#define METADEF_CXX_INC_EXTERNAL_GRAPH_COMPILER_DEF_H_ +#ifdef __GNUC__ +#define VAR_UNUSED __attribute__((unused)) +#else +#define VAR_UNUSED +#endif +#endif // METADEF_CXX_INC_EXTERNAL_GRAPH_COMPILER_DEF_H_ diff --git a/inc/metadef/external/graph/ge_error_codes.h b/inc/metadef/external/graph/ge_error_codes.h new file mode 100644 index 0000000000000000000000000000000000000000..9b55ab3f60f30ed629e04b96bbf64973073f6b27 --- /dev/null +++ b/inc/metadef/external/graph/ge_error_codes.h @@ -0,0 +1,67 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_GRAPH_GE_ERROR_CODES_H_ +#define INC_EXTERNAL_GRAPH_GE_ERROR_CODES_H_ + +#include + +namespace ge { +#if(defined(HOST_VISIBILITY)) && (defined(__GNUC__)) +#define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_HOST_VISIBILITY +#endif +#if(defined(DEV_VISIBILITY)) && (defined(__GNUC__)) +#define GE_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_DEV_VISIBILITY +#endif +#ifdef __GNUC__ +#ifdef NO_METADEF_ABI_COMPATIABLE +#define ATTRIBUTED_DEPRECATED(replacement) +#define ATTRIBUTED_NOT_SUPPORT() +#else +#define ATTRIBUTED_DEPRECATED(replacement) __attribute__((deprecated("Please use " #replacement " instead."))) +#define ATTRIBUTED_NOT_SUPPORT() __attribute__((deprecated("The method will not be supported in the future."))) +#endif +#else +#ifdef NO_METADEF_ABI_COMPATIABLE +#define ATTRIBUTED_DEPRECATED(replacement) +#define ATTRIBUTED_NOT_SUPPORT() +#else +#define ATTRIBUTED_DEPRECATED(replacement) __declspec(deprecated("Please use " #replacement " instead.")) +#define ATTRIBUTED_NOT_SUPPORT() __declspec(deprecated("The method will not be supported in the future.")) +#endif +#endif + +using graphStatus = uint32_t; +const graphStatus GRAPH_FAILED = 0xFFFFFFFF; +const graphStatus GRAPH_SUCCESS = 0; +const graphStatus GRAPH_NOT_CHANGED = 1343242304; + +const graphStatus GRAPH_PARAM_INVALID = 50331649; +const graphStatus GRAPH_NODE_WITHOUT_CONST_INPUT = 50331648; +const graphStatus GRAPH_NODE_NEED_REPASS = 50331647; +const graphStatus GRAPH_INVALID_IR_DEF = 50331646; +const graphStatus OP_WITHOUT_IR_DATATYPE_INFER_RULE = 50331645; +const graphStatus GRAPH_PARAM_OUT_OF_RANGE = 50331644; + +const graphStatus GRAPH_MEM_OPERATE_FAILED = 50331539; +const graphStatus GRAPH_NULL_PTR = 50331538; +const graphStatus GRAPH_MEMCPY_FAILED = 50331537; +const 
graphStatus GRAPH_MEMSET_FAILED = 50331536; + +const graphStatus GRAPH_MATH_CAL_FAILED = 50331429; +const graphStatus GRAPH_ADD_OVERFLOW = 50331428; +const graphStatus GRAPH_MUL_OVERFLOW = 50331427; +const graphStatus GRAPH_RoundUp_Overflow = 50331426; +} // namespace ge + +#endif // INC_EXTERNAL_GRAPH_GE_ERROR_CODES_H_ diff --git a/inc/metadef/external/graph/gnode.h b/inc/metadef/external/graph/gnode.h new file mode 100644 index 0000000000000000000000000000000000000000..a5706dd466e1e8a876b1eebdb3d26c653888c08c --- /dev/null +++ b/inc/metadef/external/graph/gnode.h @@ -0,0 +1,126 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_GRAPH_GNODE_H_ +#define INC_EXTERNAL_GRAPH_GNODE_H_ + +#include +#include + +#include "./ge_error_codes.h" +#include "./types.h" +#include "./tensor.h" +#include "./ascend_string.h" + +namespace ge { +class AttrValue; +class GNode; +class OpDesc; +class Graph; +class ComputeGraph; +using GNodePtr = std::shared_ptr; +using GraphPtr = std::shared_ptr; +using OpBytes = std::vector; +using OpDescPtr = std::shared_ptr; +using ComputeGraphPtr = std::shared_ptr; + +class NodeImpl; +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GNode { + public: + GNode(); + + ~GNode() = default; + + graphStatus GetType(AscendString &type) const; + + graphStatus GetName(AscendString &name) const; + + std::pair GetInDataNodesAndPortIndexs(const int32_t index) const; + + std::vector GetInControlNodes() const; + + std::vector> GetOutDataNodesAndPortIndexs(const int32_t index) const; + + std::vector GetOutControlNodes() const; + + graphStatus GetInputConstData(const int32_t index, Tensor &data) const; + + graphStatus GetInputIndexByName(const AscendString &name, int32_t &index); + + graphStatus GetOutputIndexByName(const AscendString &name, int32_t &index); + + graphStatus GetDynamicInputIndexesByName(const AscendString &name, std::vector &indexes); + + graphStatus GetDynamicOutputIndexesByName(const AscendString &name, std::vector &indexes); + + size_t GetInputsSize() const; + + size_t GetOutputsSize() const; + + graphStatus GetInputDesc(const int32_t index, TensorDesc &tensor_desc) const; + + graphStatus UpdateInputDesc(const int32_t index, const TensorDesc &tensor_desc); + + graphStatus GetOutputDesc(const int32_t index, TensorDesc &tensor_desc) const; + + graphStatus UpdateOutputDesc(const int32_t index, const TensorDesc &tensor_desc); + + graphStatus GetAttr(const AscendString &name, int64_t &attr_value) const; + graphStatus GetAttr(const 
AscendString &name, int32_t &attr_value) const; + graphStatus GetAttr(const AscendString &name, uint32_t &attr_value) const; + graphStatus GetAttr(const AscendString &name, float32_t &attr_value) const; + graphStatus GetAttr(const AscendString &name, AscendString &attr_value) const; + graphStatus GetAttr(const AscendString &name, bool &attr_value) const; + graphStatus GetAttr(const AscendString &name, Tensor &attr_value) const; + graphStatus GetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus GetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus GetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus GetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus GetAttr(const AscendString &name, std::vector &attr_values) const; + graphStatus GetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus GetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus GetAttr(const AscendString &name, OpBytes &attr_value) const; + graphStatus GetAttr(const AscendString &name, std::vector> &attr_value) const; + graphStatus GetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus GetAttr(const AscendString &name, ge::DataType &attr_value) const; + graphStatus GetAttr(const AscendString &name, AttrValue &attr_value) const; + + graphStatus SetAttr(const AscendString &name, int64_t &attr_value) const; + graphStatus SetAttr(const AscendString &name, int32_t &attr_value) const; + graphStatus SetAttr(const AscendString &name, uint32_t &attr_value) const; + graphStatus SetAttr(const AscendString &name, float32_t &attr_value) const; + graphStatus SetAttr(const AscendString &name, AscendString &attr_value) const; + graphStatus SetAttr(const AscendString &name, bool &attr_value) const; + graphStatus SetAttr(const AscendString &name, Tensor &attr_value) const; + graphStatus SetAttr(const AscendString &name, std::vector &attr_value) 
const; + graphStatus SetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus SetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus SetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus SetAttr(const AscendString &name, std::vector &attr_values) const; + graphStatus SetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus SetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus SetAttr(const AscendString &name, OpBytes &attr_value) const; + graphStatus SetAttr(const AscendString &name, std::vector> &attr_value) const; + graphStatus SetAttr(const AscendString &name, std::vector &attr_value) const; + graphStatus SetAttr(const AscendString &name, ge::DataType &attr_value) const; + graphStatus SetAttr(const AscendString &name, AttrValue &attr_value) const; + + bool HasAttr(const AscendString &name); + + graphStatus GetSubgraph(uint32_t index, GraphPtr &graph) const; + + graphStatus GetALLSubgraphs(std::vector &graph_list) const; + + private: + std::shared_ptr impl_; + friend class NodeAdapter; +}; +} // namespace ge + +#endif // INC_EXTERNAL_GRAPH_GNODE_H_ diff --git a/inc/metadef/external/graph/graph.h b/inc/metadef/external/graph/graph.h new file mode 100644 index 0000000000000000000000000000000000000000..85485ce53f00fb2e07a6d54e604b4e29694d3991 --- /dev/null +++ b/inc/metadef/external/graph/graph.h @@ -0,0 +1,134 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_GRAPH_GRAPH_H_ +#define INC_EXTERNAL_GRAPH_GRAPH_H_ + +#include +#include +#include +#include + +#include "./operator.h" +#include "./gnode.h" + +namespace ge { +class Graph; +class GraphImpl; +class GraphBuffer; + +using GraphImplPtr = std::shared_ptr; +using GraphPtr = std::shared_ptr; + +/*lint -e148*/ +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph { + friend class GraphUtils; + friend class GraphUtilsEx; + + public: + ATTRIBUTED_DEPRECATED(Graph(const char_t *)) + explicit Graph(const std::string &name); + + explicit Graph(const char_t *name); + + Graph() = default; + + ~Graph() = default; + + Graph &SetInputs(const std::vector &inputs); + + Graph &SetOutputs(const std::vector &outputs); + + Graph &SetOutputs(const std::vector>> &output_indexs); + + ATTRIBUTED_DEPRECATED(Graph &SetOutputs(const std::vector> &outputs); + + Graph &SetOutputs(const std::vector> &outputs); + + Graph &SetTargets(const std::vector &targets); + + bool IsValid() const; + + graphStatus AddOp(const ge::Operator &op); + + ATTRIBUTED_DEPRECATED(graphStatus FindOpByName(const char_t *, ge::Operator &)) + graphStatus FindOpByName(const std::string &name, ge::Operator &op) const; + + graphStatus FindOpByName(const char_t *name, ge::Operator &op) const; + + ATTRIBUTED_DEPRECATED(graphStatus FindOpByType(const char_t *, std::vector &)) + graphStatus FindOpByType(const std::string &type, std::vector &ops) const; + + graphStatus FindOpByType(const char_t *type, std::vector &ops) const; + + ATTRIBUTED_DEPRECATED(graphStatus GetAllOpName(std::vector &) const) + graphStatus GetAllOpName(std::vector &op_name) const; + + graphStatus GetAllOpName(std::vector &names) const; + + ATTRIBUTED_DEPRECATED(graphStatus SaveToFile(const char_t *file_name) const) + 
graphStatus SaveToFile(const std::string &file_name) const; + + graphStatus SaveToFile(const char_t *file_name) const; + + ATTRIBUTED_DEPRECATED(graphStatus LoadFromFile(const char_t *)) + graphStatus LoadFromFile(const std::string &file_name); + + graphStatus LoadFromFile(const char_t *file_name); + + graphStatus LoadFromSerializedModelArray(const void *serialized_model, size_t size); + + graphStatus SaveToMem(GraphBuffer &graph_buffer) const; + + graphStatus LoadFromMem(const GraphBuffer &graph_buffer); + + graphStatus LoadFromMem(const uint8_t *data, const size_t len); + + ATTRIBUTED_DEPRECATED(graphStatus GetName(AscendString &) const) + const std::string &GetName() const; + + graphStatus GetName(AscendString &name) const; + + /// + /// Set is need train iteration. + /// If set true, it means this graph need to be run iteration some + /// times(according variant "npu_runconfig/iterations_per_loop"). + /// @param need_iteration need_iteration:whether to set iteration or not + /// + void SetNeedIteration(bool need_iteration); + + std::vector GetAllNodes() const; + + std::vector GetDirectNode () const; + + graphStatus RemoveNode(GNode &node); + + graphStatus RemoveNode(GNode &node, bool contain_subgraph); + + graphStatus RemoveEdge(GNode &src_node, const int32_t src_port_index, GNode &dst_node, const int32_t dst_port_index); + + GNode AddNodeByOp(const Operator &op); + + graphStatus AddDataEdge(GNode &src_node, const int32_t src_port_index, + GNode &dst_node, const int32_t dst_port_index); + + graphStatus AddControlEdge(GNode &src_node, GNode &dst_node); + + graphStatus CopyFrom(const Graph &src_graph); + + static GraphPtr ConstructFromInputs(const std::vector &inputs, const AscendString &name); + + private: + + GraphImplPtr impl_{nullptr}; +}; +} // namespace ge + +#endif // INC_EXTERNAL_GRAPH_GRAPH_H_ diff --git a/inc/metadef/external/graph/graph_buffer.h b/inc/metadef/external/graph/graph_buffer.h new file mode 100644 index 
0000000000000000000000000000000000000000..a5e3f56be789c28a9364ef6d2b7f98b2bec57021 --- /dev/null +++ b/inc/metadef/external/graph/graph_buffer.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_GRAPH_BUFFER_H_ +#define INC_EXTERNAL_GRAPH_BUFFER_H_ + +#include +#include +#include +#include "./types.h" + +namespace ge { +class Graph; +class Buffer; +using BufferPtr = std::shared_ptr; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GraphBuffer { + public: + GraphBuffer(); + GraphBuffer(const GraphBuffer &) = delete; + GraphBuffer &operator=(const GraphBuffer &) = delete; + ~GraphBuffer(); + + const std::uint8_t *GetData() const; + std::size_t GetSize() const; + + private: + BufferPtr buffer_{nullptr}; + friend class Graph; +}; +} // namespace ge +#endif // INC_EXTERNAL_GRAPH_BUFFER_H_ diff --git a/inc/metadef/external/graph/inference_context.h b/inc/metadef/external/graph/inference_context.h new file mode 100644 index 0000000000000000000000000000000000000000..00fa2646a757427e10ed3dfd51f9483182ef52bd --- /dev/null +++ b/inc/metadef/external/graph/inference_context.h @@ -0,0 +1,129 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_ +#define INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_ + +#include +#include +#include +#include + +#include "tensor.h" +#include "types.h" +#include "ascend_string.h" +#include "resource_context.h" +#include "ge_error_codes.h" + +namespace ge { +class ShapeAndTypeImpl; +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ShapeAndType { + public: + ShapeAndType(); + ~ShapeAndType() = default; + + ShapeAndType(const Shape &shape, DataType data_type); + + void SetShape(const Shape &shape); + + void SetType(DataType data_type); + + Shape GetShape() const; + + DataType GetDataType() const; + + private: + std::shared_ptr shape_and_type_impl_; +}; + +struct InnerInferenceContext; +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferenceContext { + public: + ~InferenceContext() = default; + InferenceContext(const InferenceContext &context) = delete; + InferenceContext(const InferenceContext &&context) = delete; + InferenceContext &operator=(const InferenceContext &context) = delete; + InferenceContext &operator=(const InferenceContext &&context) = delete; + + void SetInputHandleShapesAndTypes(std::vector> &&shapes_and_types); + const std::vector> &GetInputHandleShapesAndTypes() const; + const std::vector> &GetOutputHandleShapesAndTypes() const; + void SetOutputHandleShapesAndTypes(const std::vector> 
&shapes_and_types); + void SetOutputHandleShapesAndTypes(std::vector> &&shapes_and_types); + + ATTRIBUTED_DEPRECATED(void SetMarks(const std::vector &)) + void SetMarks(const std::vector &marks); + void SetMarks(const std::vector &marks); + + + ATTRIBUTED_DEPRECATED(void GetMarks(std::vector &) const) + const std::vector &GetMarks() const; + void GetMarks(std::vector &marks) const; + + static std::unique_ptr Create(void *resource_context_mgr = nullptr); + /** + * Get corresponding resource_context by key + * For resource op infershape, invoked by op infer_func. + * @param key + * @return corresponding resource context. Check not null before use it. + */ + ResourceContext *GetResourceContext(const ge::AscendString &key); + + /** + * Set corresponding resource_context by key. For node which will write to resource. + * For resource op infershape, invoked by write_op infer_func. + * @param key + * @param resource_context pointer. + * @return status + */ + graphStatus SetResourceContext(const ge::AscendString &key, ResourceContext *resource_context); + /** + * Register resource key relied on. For node which will read from resource. + * For resource op infershape, invoked by read_op infer_func. + * @param key + * @return status + */ + graphStatus RegisterReliedOnResourceKey(const ge::AscendString &key); + + /** + * During infershape of write op, if resource shape changed, use this to tell. + * For resource op infershape, invoked by write_op infer_func. + * @param key + * @return status + */ + graphStatus AddChangedResourceKey(const ge::AscendString &key); + + /** + * After read_op infershaped, can get resource_keys relied on. + * For resource op infershape, invoked by ge infershape framework. + * @param keys + * @return status + */ + const std::set& GetReliedOnResourceKeys() const; + + /** + * After infershape of write op, ge can get resource_key which shape changed. + * For resource op infershape, invoked by ge infershape framework. 
+ * @return keys + */ + const std::set& GetChangedResourceKeys() const; + /** + * After handle changed resource shape, should clear changed_keys in context. + * For resource op infershape, invoked by ge infershape framework. + */ + void ClearChangedResourceKeys(); + + private: + explicit InferenceContext(std::unique_ptr &inner_context); + std::shared_ptr inner_inference_context_; +}; + +using InferenceContextPtr = std::shared_ptr; +} // namespace ge +#endif // INC_EXTERNAL_GRAPH_INFERENCE_CONTEXT_H_ diff --git a/inc/metadef/external/graph/operator.h b/inc/metadef/external/graph/operator.h new file mode 100644 index 0000000000000000000000000000000000000000..e5b3b6f57fd9a068af8dd3dab494de5a46d9d8b0 --- /dev/null +++ b/inc/metadef/external/graph/operator.h @@ -0,0 +1,673 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_GRAPH_OPERATOR_H_ +#define INC_EXTERNAL_GRAPH_OPERATOR_H_ + +#include +#include +#include +#include +#include + +#include "./ge_error_codes.h" +#include "./inference_context.h" +#include "./tensor.h" +#include "./types.h" + +#ifndef USER_GE_LOGI +#define USER_GE_LOGI(...) +#endif // USER_GE_LOGI + +#ifndef USER_GE_LOGW +#define USER_GE_LOGW(...) +#endif // USER_GE_LOGW + +#ifndef USER_GE_LOGE +#define USER_GE_LOGE(...) 
+#endif // USER_GE_LOGE + +#define DYNAMIC_OUTPUT_TD_NUM(name) ("__dynamic_output_" + (name) + "_cnt") +#define DYNAMIC_INPUT_TD_NUM(name) ("__dynamic_input_" + (name) + "_cnt") + +namespace ge { +class Operator; +class OperatorImpl; +class NodeUtils; +class NamedAttrs; +class Graph; +class AttrValue; +class Node; + +using SubgraphBuilder = std::function; +using OperatorImplPtr = std::shared_ptr; +using OperatorPtr = std::shared_ptr; + +class OpIO; +using OutHandler = std::shared_ptr; +using InHandler = std::shared_ptr; + +using std::function; +using std::shared_ptr; +using std::string; + +/*lint -e148*/ +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { + public: + friend class OperatorImpl; + friend class GraphBuilderImpl; + friend class MultiThreadGraphBuilder; + friend class NodeUtils; + friend class OpDescUtils; + friend class GraphUtils; + friend class NodeUtilsEx; + friend class GraphUtilsEx; + + using OpInt = int64_t; + using OpFloat = float32_t; + using OpString = std::string; + using OpAscendString = AscendString; + using OpBool = bool; + using OpTensor = Tensor; + using OpType = ge::DataType; + using OpNamedAttrs = ge::NamedAttrs; + using OpListInt = std::vector; + using OpListFloat = std::vector; + using OpListString = std::vector; + using OpListAcendString = std::vector; + using OpListAscendString = OpListAcendString; + using OpListBool = std::vector; + using OpListTensor = std::vector; + using OpBytes = std::vector; + using OpListListInt = std::vector>; + using OpListType = std::vector; + using OpListNamedAttrs = std::vector; + + Operator() {} + ATTRIBUTED_DEPRECATED(Operator(const char_t *)) + explicit Operator(const std::string &type); + + explicit Operator(const char_t *type); + + ATTRIBUTED_DEPRECATED(Operator(const char_t *, const char_t *)) + Operator(const std::string &name, const std::string &type); + + Operator(const AscendString &name, const AscendString &type); + + Operator(const char_t *name, const char_t *type); + + virtual 
~Operator() = default; + + bool IsEmpty() const; + + ATTRIBUTED_DEPRECATED(graphStatus GetName(AscendString &) const) + std::string GetName() const; + + graphStatus GetName(AscendString &name) const; + + ATTRIBUTED_DEPRECATED(graphStatus GetOpType(AscendString &) const) + std::string GetOpType() const; + + graphStatus GetOpType(AscendString &type) const; + + // Only has one output index = 0 + ATTRIBUTED_DEPRECATED(Operator &SetInput(const char_t *, const Operator &)) + Operator &SetInput(const std::string &dst_name, const Operator &src_oprt); + + Operator &SetInput(const char_t *dst_name, const Operator &src_oprt); + + ATTRIBUTED_DEPRECATED(Operator &SetInput(const char_t *, const Operator &, const char_t *)) + Operator &SetInput(const std::string &dst_name, const Operator &src_oprt, const std::string &name); + + Operator &SetInput(const char_t *dst_name, const Operator &src_oprt, const char_t *name); + + ATTRIBUTED_DEPRECATED(Operator &SetInput(const char_t *, const Operator &, uint32_t)) + Operator &SetInput(const std::string &dst_name, const Operator &src_oprt, uint32_t index); + + Operator &SetInput(const char_t *dst_name, const Operator &src_oprt, uint32_t index); + + Operator &SetInput(uint32_t dst_index, const Operator &src_oprt, uint32_t src_index); + + Operator &AddControlInput(const Operator &src_oprt); + + ATTRIBUTED_DEPRECATED(graphStatus GetInputConstData(const char_t *, Tensor &) const) + graphStatus GetInputConstData(const std::string &dst_name, Tensor &data) const; + + graphStatus GetInputConstData(const char_t *dst_name, Tensor &data) const; + + ATTRIBUTED_DEPRECATED(TensorDesc GetInputDescByName(const char_t *) const) + TensorDesc GetInputDesc(const std::string &name) const; + + TensorDesc GetInputDescByName(const char_t *name) const; + + TensorDesc GetInputDesc(uint32_t index) const; + + ATTRIBUTED_DEPRECATED(int GetDynamicOutputNum(const char_t *) const) + int32_t GetDynamicOutputNum(const std::string &name) const; + + int32_t 
GetDynamicOutputNum(const char_t *name) const; + + ATTRIBUTED_DEPRECATED(int GetDynamicInputNum(const char_t *)) + int32_t GetDynamicInputNum(const std::string &name) const; + + int32_t GetDynamicInputNum(const char_t *name) const; + + ATTRIBUTED_DEPRECATED(graphStatus TryGetInputDesc(const char_t *, TensorDesc &) const) + graphStatus TryGetInputDesc(const std::string &name, TensorDesc &tensor_desc) const; + + graphStatus TryGetInputDesc(const char_t *name, TensorDesc &tensor_desc) const; + + ATTRIBUTED_DEPRECATED(graphStatus UpdateInputDesc(const char_t *, const TensorDesc &)) + graphStatus UpdateInputDesc(const std::string &name, const TensorDesc &tensor_desc); + + graphStatus UpdateInputDesc(const char_t *name, const TensorDesc &tensor_desc); + + ATTRIBUTED_DEPRECATED(TensorDesc GetOutputDescByName(const char_t *) const) + TensorDesc GetOutputDesc(const std::string &name) const; + + TensorDesc GetOutputDescByName(const char_t *name) const; + + TensorDesc GetOutputDesc(uint32_t index) const; + + ATTRIBUTED_DEPRECATED(graphStatus UpdateOutputDesc(const char_t *, const TensorDesc &tensor_desc)) + graphStatus UpdateOutputDesc(const std::string &name, const TensorDesc &tensor_desc); + + graphStatus UpdateOutputDesc(const char_t *name, const TensorDesc &tensor_desc); + + ATTRIBUTED_DEPRECATED(TensorDesc GetDynamicInputDesc(const char_t *, uint32_t) const) + TensorDesc GetDynamicInputDesc(const std::string &name, uint32_t index) const; + + TensorDesc GetDynamicInputDesc(const char_t *name, uint32_t index) const; + + ATTRIBUTED_DEPRECATED(graphStatus UpdateDynamicInputDesc(const char_t *, uint32_t, const TensorDesc &)) + graphStatus UpdateDynamicInputDesc(const std::string &name, uint32_t index, const TensorDesc &tensor_desc); + + graphStatus UpdateDynamicInputDesc(const char_t *name, uint32_t index, const TensorDesc &tensor_desc); + + ATTRIBUTED_DEPRECATED(TensorDesc GetDynamicOutputDesc(const char_t *, uint32_t) const) + TensorDesc GetDynamicOutputDesc(const 
std::string &name, uint32_t index) const; + + TensorDesc GetDynamicOutputDesc(const char_t *name, uint32_t index) const; + + ATTRIBUTED_DEPRECATED(graphStatus UpdateDynamicOutputDesc(const char_t *, uint32_t, const TensorDesc &)) + graphStatus UpdateDynamicOutputDesc(const std::string &name, uint32_t index, const TensorDesc &tensor_desc); + + graphStatus UpdateDynamicOutputDesc(const char_t *name, uint32_t index, const TensorDesc &tensor_desc); + + graphStatus InferShapeAndType(); + + void SetInferenceContext(const InferenceContextPtr &inference_context); + InferenceContextPtr GetInferenceContext() const; + + graphStatus VerifyAllAttr(bool disable_common_verifier = false); + + size_t GetInputsSize() const; + + size_t GetOutputsSize() const; + + ATTRIBUTED_DEPRECATED(graphStatus GetAllAttrNamesAndTypes(std::map &) const) + const std::map GetAllAttrNamesAndTypes() const; + /** + * 获取调用对象上设置好的`自定义属性`和`ir定义属性`的属性名称和类型 + * @param attr_name_types 出参 + * @return + */ + graphStatus GetAllAttrNamesAndTypes(std::map &attr_name_types) const; + /** + * 获取ir定义的属性名称和类型(包含普通attr和required_attr) + * @param attr_name_types 出参 + * @return + */ + graphStatus GetAllIrAttrNamesAndTypes(std::map &attr_name_types) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char_t *, int64_t)) + Operator &SetAttr(const std::string &name, int64_t attr_value); + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char_t *, int32_t)) + Operator &SetAttr(const std::string &name, int32_t attr_value); + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char_t *, uint32_t)) + Operator &SetAttr(const std::string &name, uint32_t attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char_t *, int64_t &) const) + graphStatus GetAttr(const std::string &name, int64_t &attr_value) const; + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char_t *, int32_t &) const) + graphStatus GetAttr(const std::string &name, int32_t &attr_value) const; + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char_t *, uint32_t &) 
const) + graphStatus GetAttr(const std::string &name, uint32_t &attr_value) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char_t *, const std::vector &)) + Operator &SetAttr(const std::string &name, const std::vector &attr_value); + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char_t *, const std::vector &)) + Operator &SetAttr(const std::string &name, const std::vector &attr_value); + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char_t *, const std::vector &)) + Operator &SetAttr(const std::string &name, const std::vector &attr_value); + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char_t *, std::initializer_list &&)) + Operator &SetAttr(const std::string &name, std::initializer_list &&attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char_t *name, std::vector &) const) + graphStatus GetAttr(const std::string &name, std::vector &attr_value) const; + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char_t *name, std::vector &) const) + graphStatus GetAttr(const std::string &name, std::vector &attr_value) const; + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const std::string &, std::vector &) const) + graphStatus GetAttr(const std::string &name, std::vector &attr_value) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char_t *, float32_t attr_value)) + Operator &SetAttr(const std::string &name, float32_t attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char_t *, float32_t &) const) + graphStatus GetAttr(const std::string &name, float32_t &attr_value) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char_t *, const std::vector &)) + Operator &SetAttr(const std::string &name, const std::vector &attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char_t *, std::vector &) const) + graphStatus GetAttr(const std::string &name, std::vector &attr_value) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char_t *, AttrValue &&)) + Operator &SetAttr(const std::string &name, AttrValue &&attr_value); + 
ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char_t *, AttrValue &) const) + graphStatus GetAttr(const std::string &name, AttrValue &attr_value) const; + Operator &SetAttr(const std::string &name, const std::string &attr_value); + graphStatus GetAttr(const std::string &name, std::string &attr_value) const; + Operator &SetAttr(const std::string &name, const std::vector &attr_value); + graphStatus GetAttr(const std::string &name, std::vector &attr_value) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char_t *, bool)) + Operator &SetAttr(const std::string &name, bool attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char_t *, bool &) const) + graphStatus GetAttr(const std::string &name, bool &attr_value) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char_t *, const std::vector &)) + Operator &SetAttr(const std::string &name, const std::vector &attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char_t *, std::vector &) const) + graphStatus GetAttr(const std::string &name, std::vector &attr_value) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char_t *, const Tensor &)) + Operator &SetAttr(const std::string &name, const Tensor &attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char_t *, Tensor &) const) + graphStatus GetAttr(const std::string &name, Tensor &attr_value) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char_t *, const std::vector &)) + Operator &SetAttr(const std::string &name, const std::vector &attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char_t *, std::vector &) const) + graphStatus GetAttr(const std::string &name, std::vector &attr_value) const; + + // Bytes type + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char_t *, const OpBytes &)) + Operator &SetAttr(const std::string &name, const OpBytes &attr_value); + // Bytes type + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char_t *, OpBytes &) const) + graphStatus GetAttr(const std::string &name, OpBytes 
&attr_value) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char_t *, const std::vector> &)) + Operator &SetAttr(const std::string &name, const std::vector> &attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char_t *, std::vector> &) const) + graphStatus GetAttr(const std::string &name, std::vector> &attr_value) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char_t *, const std::vector &)) + Operator &SetAttr(const std::string &name, const std::vector &attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char_t *, std::vector &) const) + graphStatus GetAttr(const std::string &name, std::vector &attr_value) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char_t *, const ge::DataType &)) + Operator &SetAttr(const std::string &name, const ge::DataType &attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char_t *, ge::DataType &) const) + graphStatus GetAttr(const std::string &name, ge::DataType &attr_value) const; + + // func type + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char_t *, const ge::NamedAttrs &)) + Operator &SetAttr(const std::string &name, const ge::NamedAttrs &attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char_t *, ge::NamedAttrs &) const) + graphStatus GetAttr(const std::string &name, ge::NamedAttrs &attr_value) const; + ATTRIBUTED_DEPRECATED(Operator &SetAttr(const char_t *, const std::vector &)) + Operator &SetAttr(const std::string &name, const std::vector &attr_value); + ATTRIBUTED_DEPRECATED(graphStatus GetAttr(const char_t *, std::vector &) const) + graphStatus GetAttr(const std::string &name, std::vector &attr_value) const; + + Operator &SetAttr(const char_t *name, int64_t attr_value); + Operator &SetAttr(const char_t *name, int32_t attr_value); + Operator &SetAttr(const char_t *name, uint32_t attr_value); + graphStatus GetAttr(const char_t *name, int64_t &attr_value) const; + graphStatus GetAttr(const char_t *name, int32_t &attr_value) const; + graphStatus 
GetAttr(const char_t *name, uint32_t &attr_value) const; + Operator &SetAttr(const char_t *name, const std::vector &attr_value); + Operator &SetAttr(const char_t *name, const std::vector &attr_value); + Operator &SetAttr(const char_t *name, const std::vector &attr_value); + Operator &SetAttr(const char_t *name, std::initializer_list &&attr_value); + graphStatus GetAttr(const char_t *name, std::vector &attr_value) const; + graphStatus GetAttr(const char_t *name, std::vector &attr_value) const; + graphStatus GetAttr(const char_t *name, std::vector &attr_value) const; + + Operator &SetAttr(const char_t *name, float32_t attr_value); + graphStatus GetAttr(const char_t *name, float32_t &attr_value) const; + Operator &SetAttr(const char_t *name, const std::vector &attr_value); + graphStatus GetAttr(const char_t *name, std::vector &attr_value) const; + Operator &SetAttr(const char_t *name, AttrValue &&attr_value); + graphStatus GetAttr(const char_t *name, AttrValue &attr_value) const; + + Operator &SetAttr(const char_t *name, const char_t *attr_value); + Operator &SetAttr(const char_t *name, const AscendString &attr_value); + graphStatus GetAttr(const char_t *name, AscendString &attr_value) const; + Operator &SetAttr(const char_t *name, const std::vector &attr_values); + graphStatus GetAttr(const char_t *name, std::vector &attr_values) const; + + Operator &SetAttr(const char_t *name, bool attr_value); + graphStatus GetAttr(const char_t *name, bool &attr_value) const; + Operator &SetAttr(const char_t *name, const std::vector &attr_value); + graphStatus GetAttr(const char_t *name, std::vector &attr_value) const; + + Operator &SetAttr(const char_t *name, const Tensor &attr_value); + graphStatus GetAttr(const char_t *name, Tensor &attr_value) const; + Operator &SetAttr(const char_t *name, const std::vector &attr_value); + graphStatus GetAttr(const char_t *name, std::vector &attr_value) const; + + // Bytes type + Operator &SetAttr(const char_t *name, const OpBytes &attr_value); 
+ // Bytes type + graphStatus GetAttr(const char_t *name, OpBytes &attr_value) const; + + Operator &SetAttr(const char_t *name, const std::vector> &attr_value); + graphStatus GetAttr(const char_t *name, std::vector> &attr_value) const; + + Operator &SetAttr(const char_t *name, const std::vector &attr_value); + graphStatus GetAttr(const char_t *name, std::vector &attr_value) const; + + Operator &SetAttr(const char_t *name, const ge::DataType &attr_value); + graphStatus GetAttr(const char_t *name, ge::DataType &attr_value) const; + + // func type + Operator &SetAttr(const char_t *name, const ge::NamedAttrs &attr_value); + graphStatus GetAttr(const char_t *name, ge::NamedAttrs &attr_value) const; + Operator &SetAttr(const char_t *name, const std::vector &attr_value); + graphStatus GetAttr(const char_t *name, std::vector &attr_value) const; + + void BreakConnect() const; + + size_t GetSubgraphNamesCount() const; + ATTRIBUTED_DEPRECATED(graphStatus GetSubgraphNames(std::vector &) const) + std::vector GetSubgraphNames() const; + graphStatus GetSubgraphNames(std::vector &names) const; + ATTRIBUTED_DEPRECATED(SubgraphBuilder GetSubgraphBuilder(const char_t *) const) + SubgraphBuilder GetSubgraphBuilder(const std::string &name) const; + SubgraphBuilder GetSubgraphBuilder(const char_t *name) const; + ATTRIBUTED_DEPRECATED(Graph GetSubgraph(const char_t *) const) + Graph GetSubgraph(const std::string &name) const; + Graph GetSubgraph(const char_t *name) const; + ATTRIBUTED_DEPRECATED(SubgraphBuilder GetDynamicSubgraphBuilder(const char_t *, uint32_t) const) + SubgraphBuilder GetDynamicSubgraphBuilder(const std::string &name, uint32_t index) const; + SubgraphBuilder GetDynamicSubgraphBuilder(const char_t *name, uint32_t index) const; + ATTRIBUTED_DEPRECATED(Graph GetDynamicSubgraph(const char_t *, uint32_t) const) + Graph GetDynamicSubgraph(const std::string &name, uint32_t index) const; + Graph GetDynamicSubgraph(const char_t *name, uint32_t index) const; + + Operator 
&SetInputAttr(const int32_t index, const char_t *name, const char_t *attr_value); + Operator &SetOutputAttr(const int32_t index, const char_t *name, const char_t *attr_value); + Operator &SetInputAttr(const char_t *dst_name, const char_t *name, const char_t *attr_value); + Operator &SetOutputAttr(const char_t *dst_name, const char_t *name, const char_t *attr_value); + + Operator &SetInputAttr(const int32_t index, const char_t *name, const AscendString &attr_value); + graphStatus GetInputAttr(const int32_t index, const char_t *name, AscendString &attr_value) const; + Operator &SetOutputAttr(const int32_t index, const char_t *name, const AscendString &attr_value); + graphStatus GetOutputAttr(const int32_t index, const char_t *name, AscendString &attr_value) const; + Operator &SetInputAttr(const char_t *dst_name, const char_t *name, const AscendString &attr_value); + graphStatus GetInputAttr(const char_t *dst_name, const char_t *name, AscendString &attr_value) const; + Operator &SetOutputAttr(const char_t *dst_name, const char_t *name, const AscendString &attr_value); + graphStatus GetOutputAttr(const char_t *dst_name, const char_t *name, AscendString &attr_value) const; + + Operator &SetInputAttr(const int32_t index, const char_t *name, int64_t attr_value); + graphStatus GetInputAttr(const int32_t index, const char_t *name, int64_t &attr_value) const; + Operator &SetOutputAttr(const int32_t index, const char_t *name, int64_t attr_value); + graphStatus GetOutputAttr(const int32_t index, const char_t *name, int64_t &attr_value) const; + Operator &SetInputAttr(const char_t *dst_name, const char_t *name, int64_t attr_value); + graphStatus GetInputAttr(const char_t *dst_name, const char_t *name, int64_t &attr_value) const; + Operator &SetOutputAttr(const char_t *dst_name, const char_t *name, int64_t attr_value); + graphStatus GetOutputAttr(const char_t *dst_name, const char_t *name, int64_t &attr_value) const; + + Operator &SetInputAttr(const int32_t index, const char_t 
*name, int32_t attr_value); + graphStatus GetInputAttr(const int32_t index, const char_t *name, int32_t &attr_value) const; + Operator &SetOutputAttr(const int32_t index, const char_t *name, int32_t attr_value); + graphStatus GetOutputAttr(const int32_t index, const char_t *name, int32_t &attr_value) const; + Operator &SetInputAttr(const char_t *dst_name, const char_t *name, int32_t attr_value); + graphStatus GetInputAttr(const char_t *dst_name, const char_t *name, int32_t &attr_value) const; + Operator &SetOutputAttr(const char_t *dst_name, const char_t *name, int32_t attr_value); + graphStatus GetOutputAttr(const char_t *dst_name, const char_t *name, int32_t &attr_value) const; + + Operator &SetInputAttr(const int32_t index, const char_t *name, uint32_t attr_value); + graphStatus GetInputAttr(const int32_t index, const char_t *name, uint32_t &attr_value) const; + Operator &SetOutputAttr(const int32_t index, const char_t *name, uint32_t attr_value); + graphStatus GetOutputAttr(const int32_t index, const char_t *name, uint32_t &attr_value) const; + Operator &SetInputAttr(const char_t *dst_name, const char_t *name, uint32_t attr_value); + graphStatus GetInputAttr(const char_t *dst_name, const char_t *name, uint32_t &attr_value) const; + Operator &SetOutputAttr(const char_t *dst_name, const char_t *name, uint32_t attr_value); + graphStatus GetOutputAttr(const char_t *dst_name, const char_t *name, uint32_t &attr_value) const; + + Operator &SetInputAttr(const int32_t index, const char_t *name, bool attr_value); + graphStatus GetInputAttr(const int32_t index, const char_t *name, bool &attr_value) const; + Operator &SetOutputAttr(const int32_t index, const char_t *name, bool attr_value); + graphStatus GetOutputAttr(const int32_t index, const char_t *name, bool &attr_value) const; + Operator &SetInputAttr(const char_t *dst_name, const char_t *name, bool attr_value); + graphStatus GetInputAttr(const char_t *dst_name, const char_t *name, bool &attr_value) const; + Operator 
&SetOutputAttr(const char_t *dst_name, const char_t *name, bool attr_value); + graphStatus GetOutputAttr(const char_t *dst_name, const char_t *name, bool &attr_value) const; + + Operator &SetInputAttr(const int32_t index, const char_t *name, float32_t attr_value); + graphStatus GetInputAttr(const int32_t index, const char_t *name, float32_t &attr_value) const; + Operator &SetOutputAttr(const int32_t index, const char_t *name, float32_t attr_value); + graphStatus GetOutputAttr(const int32_t index, const char_t *name, float32_t &attr_value) const; + Operator &SetInputAttr(const char_t *dst_name, const char_t *name, float32_t attr_value); + graphStatus GetInputAttr(const char_t *dst_name, const char_t *name, float32_t &attr_value) const; + Operator &SetOutputAttr(const char_t *dst_name, const char_t *name, float32_t attr_value); + graphStatus GetOutputAttr(const char_t *dst_name, const char_t *name, float32_t &attr_value) const; + + Operator &SetInputAttr(const int32_t index, const char_t *name, const std::vector &attr_value); + graphStatus GetInputAttr(const int32_t index, const char_t *name, std::vector &attr_value) const; + Operator &SetOutputAttr(const int32_t index, const char_t *name, const std::vector &attr_value); + graphStatus GetOutputAttr(const int32_t index, const char_t *name, std::vector &attr_value) const; + Operator &SetInputAttr(const char_t *dst_name, const char_t *name, const std::vector &attr_value); + graphStatus GetInputAttr(const char_t *dst_name, const char_t *name, std::vector &attr_value) const; + Operator &SetOutputAttr(const char_t *dst_name, const char_t *name, const std::vector &attr_value); + graphStatus GetOutputAttr(const char_t *dst_name, const char_t *name, std::vector &attr_value) const; + + Operator &SetInputAttr(const int32_t index, const char_t *name, const std::vector &attr_value); + graphStatus GetInputAttr(const int32_t index, const char_t *name, std::vector &attr_value) const; + Operator &SetOutputAttr(const int32_t index, 
const char_t *name, const std::vector &attr_value); + graphStatus GetOutputAttr(const int32_t index, const char_t *name, std::vector &attr_value) const; + Operator &SetInputAttr(const char_t *dst_name, const char_t *name, const std::vector &attr_value); + graphStatus GetInputAttr(const char_t *dst_name, const char_t *name, std::vector &attr_value) const; + Operator &SetOutputAttr(const char_t *dst_name, const char_t *name, const std::vector &attr_value); + graphStatus GetOutputAttr(const char_t *dst_name, const char_t *name, std::vector &attr_value) const; + + Operator &SetInputAttr(const int32_t index, const char_t *name, const std::vector &attr_value); + graphStatus GetInputAttr(const int32_t index, const char_t *name, std::vector &attr_value) const; + Operator &SetOutputAttr(const int32_t index, const char_t *name, const std::vector &attr_value); + graphStatus GetOutputAttr(const int32_t index, const char_t *name, std::vector &attr_value) const; + Operator &SetInputAttr(const char_t *dst_name, const char_t *name, const std::vector &attr_value); + graphStatus GetInputAttr(const char_t *dst_name, const char_t *name, std::vector &attr_value) const; + Operator &SetOutputAttr(const char_t *dst_name, const char_t *name, const std::vector &attr_value); + graphStatus GetOutputAttr(const char_t *dst_name, const char_t *name, std::vector &attr_value) const; + + Operator &SetInputAttr(const int32_t index, const char_t *name, const std::vector &attr_value); + graphStatus GetInputAttr(const int32_t index, const char_t *name, std::vector &attr_value) const; + Operator &SetOutputAttr(const int32_t index, const char_t *name, const std::vector &attr_value); + graphStatus GetOutputAttr(const int32_t index, const char_t *name, std::vector &attr_value) const; + Operator &SetInputAttr(const char_t *dst_name, const char_t *name, const std::vector &attr_value); + graphStatus GetInputAttr(const char_t *dst_name, const char_t *name, std::vector &attr_value) const; + Operator 
&SetOutputAttr(const char_t *dst_name, const char_t *name, const std::vector &attr_value); + graphStatus GetOutputAttr(const char_t *dst_name, const char_t *name, std::vector &attr_value) const; + + Operator &SetInputAttr(const int32_t index, const char_t *name, const std::vector &attr_value); + graphStatus GetInputAttr(const int32_t index, const char_t *name, std::vector &attr_value) const; + Operator &SetOutputAttr(const int32_t index, const char_t *name, const std::vector &attr_value); + graphStatus GetOutputAttr(const int32_t index, const char_t *name, std::vector &attr_value) const; + Operator &SetInputAttr(const char_t *dst_name, const char_t *name, const std::vector &attr_value); + graphStatus GetInputAttr(const char_t *dst_name, const char_t *name, std::vector &attr_value) const; + Operator &SetOutputAttr(const char_t *dst_name, const char_t *name, const std::vector &attr_value); + graphStatus GetOutputAttr(const char_t *dst_name, const char_t *name, std::vector &attr_value) const; + + Operator &SetInputAttr(const int32_t index, const char_t *name, const std::vector &attr_value); + graphStatus GetInputAttr(const int32_t index, const char_t *name, std::vector &attr_value) const; + Operator &SetOutputAttr(const int32_t index, const char_t *name, const std::vector &attr_value); + graphStatus GetOutputAttr(const int32_t index, const char_t *name, std::vector &attr_value) const; + Operator &SetInputAttr(const char_t *dst_name, const char_t *name, const std::vector &attr_value); + graphStatus GetInputAttr(const char_t *dst_name, const char_t *name, std::vector &attr_value) const; + Operator &SetOutputAttr(const char_t *dst_name, const char_t *name, const std::vector &attr_value); + graphStatus GetOutputAttr(const char_t *dst_name, const char_t *name, std::vector &attr_value) const; + + Operator &SetInput(const char_t *dst_name, uint32_t dst_index, const Operator &src_oprt, + const char_t *name); + + Operator &SetInput(const char_t *dst_name, uint32_t dst_index, + 
const Operator &src_oprt); + + void DynamicInputRegister(const char_t *name, const uint32_t num, bool is_push_back = true); + void DynamicInputRegister(const char_t *name, const uint32_t num, const char_t *datatype_symbol, + bool is_push_back = true); + + void DynamicInputRegisterByIndex(const char_t *name, const uint32_t num, size_t index); + + void DynamicOutputRegister(const char_t *name, const uint32_t num, bool is_push_back = true); + void DynamicOutputRegister(const char_t *name, const uint32_t num, const char_t *datatype_symbol, + bool is_push_back = true); + + void SubgraphCountRegister(const char_t *ir_name, uint32_t count); + + void SetSubgraphBuilder(const char_t *ir_name, uint32_t index, const SubgraphBuilder &builder); + + graphStatus UpdateInputDesc(const uint32_t index, const TensorDesc &tensor_desc); + + graphStatus UpdateOutputDesc(const uint32_t index, const TensorDesc &tensor_desc); + + protected: + ATTRIBUTED_DEPRECATED(void AttrRegister(const char_t *, float32_t)) + void AttrRegister(const std::string &name, float32_t attr_value); + ATTRIBUTED_DEPRECATED(void AttrRegister(const char_t *, const std::vector &)) + void AttrRegister(const std::string &name, const std::vector &attr_value); + ATTRIBUTED_DEPRECATED(void AttrRegister(const char_t *, int64_t)) + void AttrRegister(const std::string &name, int64_t attr_value); + ATTRIBUTED_DEPRECATED(void AttrRegister(const char_t *, const std::vector &)) + void AttrRegister(const std::string &name, const std::vector &attr_value); + ATTRIBUTED_DEPRECATED(void AttrRegister(const char_t *, const AscendString &)) + void AttrRegister(const std::string &name, const std::string &attr_value); + ATTRIBUTED_DEPRECATED(void AttrRegister(const char_t *, const std::vector &)) + void AttrRegister(const std::string &name, const std::vector &attr_value); + ATTRIBUTED_DEPRECATED(void AttrRegister(const char_t *, bool)) + void AttrRegister(const std::string &name, bool attr_value); + ATTRIBUTED_DEPRECATED(void 
AttrRegister(const char_t *, const std::vector &)) + void AttrRegister(const std::string &name, const std::vector &attr_value); + ATTRIBUTED_DEPRECATED(void AttrRegister(const char_t *, const Tensor &)) + void AttrRegister(const std::string &name, const Tensor &attr_value); + ATTRIBUTED_DEPRECATED(void AttrRegister(const char_t *, const std::vector &)) + void AttrRegister(const std::string &name, const std::vector &attr_value); + ATTRIBUTED_DEPRECATED(void AttrRegister(const char_t *, const OpBytes &)) + void AttrRegister(const std::string &name, const OpBytes &attr_value); + ATTRIBUTED_DEPRECATED(void AttrRegister(const char_t *, const std::vector> &)) + void AttrRegister(const std::string &name, const std::vector> &attr_value); + ATTRIBUTED_DEPRECATED(void AttrRegister(const char_t *, const std::vector &)) + void AttrRegister(const std::string &name, const std::vector &attr_value); + ATTRIBUTED_DEPRECATED(void AttrRegister(const char_t *, const ge::DataType &)) + void AttrRegister(const std::string &name, const ge::DataType &attr_value); + ATTRIBUTED_DEPRECATED(void AttrRegister(const char_t *, const ge::NamedAttrs &)) + void AttrRegister(const std::string &name, const ge::NamedAttrs &attr_value); + ATTRIBUTED_DEPRECATED(void AttrRegister(const char_t *, const std::vector &)) + void AttrRegister(const std::string &name, const std::vector &attr_value); + ATTRIBUTED_DEPRECATED(void AttrRegister(const char_t *, const AscendString &)) + void AttrRegister(const std::string &name, const AscendString &attr_value); + ATTRIBUTED_DEPRECATED(void AttrRegister(const char_t *, const std::vector &)) + void AttrRegister(const std::string &name, const std::vector &attr_value); + + void AttrRegister(const char_t *name, float32_t attr_value); + void AttrRegister(const char_t *name, const std::vector &attr_value); + void AttrRegister(const char_t *name, int64_t attr_value); + void AttrRegister(const char_t *name, const std::vector &attr_value); + void AttrRegister(const char_t 
*name, const char_t *attr_value); + void AttrRegister(const char_t *name, bool attr_value); + void AttrRegister(const char_t *name, const std::vector &attr_value); + void AttrRegister(const char_t *name, const Tensor &attr_value); + void AttrRegister(const char_t *name, const std::vector &attr_value); + void AttrRegister(const char_t *name, const OpBytes &attr_value); + void AttrRegister(const char_t *name, const std::vector> &attr_value); + void AttrRegister(const char_t *name, const std::vector &attr_value); + void AttrRegister(const char_t *name, const ge::DataType &attr_value); + void AttrRegister(const char_t *name, const ge::NamedAttrs &attr_value); + void AttrRegister(const char_t *name, const std::vector &attr_value); + void AttrRegister(const char_t *name, const AscendString &attr_value); + void AttrRegister(const char_t *name, const std::vector &attr_value); + + explicit Operator(OperatorImplPtr &&op_impl); + + ATTRIBUTED_DEPRECATED(void InputRegister(const char_t *)) + void InputRegister(const std::string &name); + void InputRegister(const char_t *name); + void InputRegister(const char_t *name, const char_t *datatype_symbol); + + ATTRIBUTED_DEPRECATED(void OptionalInputRegister(const char_t *)) + void OptionalInputRegister(const std::string &name); + void OptionalInputRegister(const char_t *name); + void OptionalInputRegister(const char_t *name, const char_t *datatype_symbol); + + void InferFuncRegister(const std::function &func); + + void VerifierFuncRegister(const std::function &func); + + void InferFormatFuncRegister(const std::function &func); + + ATTRIBUTED_DEPRECATED(void OutputRegister(const char_t *)) + void OutputRegister(const std::string &name); + void OutputRegister(const char_t *name); + + void OutputRegister(const char_t *name, const char_t *datatype_symbol); + + ATTRIBUTED_DEPRECATED(void DynamicInputRegister(const char_t *, const uint32_t, bool)) + void DynamicInputRegister(const std::string &name, const uint32_t num, bool is_push_back = 
true); + + ATTRIBUTED_DEPRECATED(void DynamicInputRegisterByIndex(const char_t *, const uint32_t, size_t)) + void DynamicInputRegisterByIndex(const std::string &name, const uint32_t num, size_t index); + + ATTRIBUTED_DEPRECATED(void DynamicOutputRegister(const char_t *, const uint32_t, bool)) + void DynamicOutputRegister(const std::string &name, const uint32_t num, bool is_push_back = true); + + ATTRIBUTED_DEPRECATED(void RequiredAttrRegister(const char_t *)) + void RequiredAttrRegister(const std::string &name); + void RequiredAttrRegister(const char_t *name); + void RequiredAttrWithTypeRegister(const char_t *name, const char_t *type); + + void DataTypeRegister(const char_t *datatype_symbol, const TensorType &type_range); + void DataTypeRegister(const char_t *datatype_symbol, const ListTensorType &list_type_range); + void DataTypeRegister(const char_t *datatype_symbol, const Promote &promote_rule); + + graphStatus VerifyAll(); + + // Only has one output index = 0 + ATTRIBUTED_DEPRECATED(Operator &SetInput(const char_t *, uint32_t, const Operator &)) + Operator &SetInput(const std::string &dst_name, uint32_t dst_index, + const Operator &src_oprt); + + ATTRIBUTED_DEPRECATED(Operator &SetInput(const char_t *, uint32_t, const Operator &, const char_t *)) + Operator &SetInput(const std::string &dst_name, uint32_t dst_index, const Operator &src_oprt, + const std::string &name); + + ATTRIBUTED_DEPRECATED(void SubgraphRegister(const char_t *, bool)) + void SubgraphRegister(const std::string &ir_name, bool dynamic); + void SubgraphRegister(const char_t *ir_name, bool dynamic); + ATTRIBUTED_DEPRECATED(void SubgraphCountRegister(const char_t *, uint32_t)) + void SubgraphCountRegister(const std::string &ir_name, uint32_t count); + ATTRIBUTED_DEPRECATED(void SetSubgraphBuilder(const char_t *, uint32_t, const SubgraphBuilder &)) + void SetSubgraphBuilder(const std::string &ir_name, uint32_t index, const SubgraphBuilder &builder); + ATTRIBUTED_DEPRECATED(Graph 
GetSubgraphImpl(const char_t *) const) + Graph GetSubgraphImpl(const std::string &name) const; + Graph GetSubgraphImpl(const char_t *name) const; + + private: + ATTRIBUTED_DEPRECATED(Operator &SetInput(const char_t *, const OutHandler &)) + Operator &SetInput(const std::string &dst_name, const OutHandler &out_handler); + Operator &SetInput(const char_t *dst_name, const OutHandler &out_handler); + + ATTRIBUTED_DEPRECATED(OutHandler GetOutput(const char_t *) const) + OutHandler GetOutput(const std::string &name) const; + OutHandler GetOutput(const char_t *name) const; + + OutHandler GetOutput(uint32_t index) const; + + OperatorImplPtr GetOperatorImplPtr() const; + + OperatorImplPtr operator_impl_{nullptr}; + + ATTRIBUTED_DEPRECATED(graphStatus GetInputConstDataOut(const char_t *, Tensor &) const) + graphStatus GetInputConstDataOut(const std::string &dst_name, Tensor &data) const; + graphStatus GetInputConstDataOut(const char_t *dst_name, Tensor &data) const; + + std::shared_ptr GetNode() const; +}; +/*lint +e148*/ +} // namespace ge + +#endif // INC_EXTERNAL_GRAPH_OPERATOR_H_ diff --git a/inc/metadef/external/graph/operator_factory.h b/inc/metadef/external/graph/operator_factory.h new file mode 100644 index 0000000000000000000000000000000000000000..7e121653adad3dc43cb8e0922e80f7f729d18b9d --- /dev/null +++ b/inc/metadef/external/graph/operator_factory.h @@ -0,0 +1,95 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_GRAPH_OPERATOR_FACTORY_H_ +#define INC_EXTERNAL_GRAPH_OPERATOR_FACTORY_H_ + +#include +#include +#include +#include + +#include "./operator.h" +#include "./ge_error_codes.h" +#include "./ascend_string.h" +#include "./types.h" + +namespace ge { +using OpCreator = std::function; +using OpCreatorV2 = std::function; +using InferShapeFunc = std::function; +using InferFormatFunc = std::function; +using InferValueRangeFunc = std::function; +using VerifyFunc = std::function; + +enum WHEN_CALL { + INPUT_IS_DYNAMIC = 0, + INPUT_HAS_VALUE_RANGE = 1 +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorFactory { + public: + ATTRIBUTED_DEPRECATED(static Operator CreateOperator(const char_t *, const char_t *)) + static Operator CreateOperator(const std::string &operator_name, const std::string &operator_type); + + static Operator CreateOperator(const char_t *const operator_name, const char_t *const operator_type); + + ATTRIBUTED_DEPRECATED(graphStatus GetOpsTypeList(std::vector &)) + static graphStatus GetOpsTypeList(std::vector &all_ops); + + static graphStatus GetOpsTypeList(std::vector &all_ops); + + ATTRIBUTED_DEPRECATED(bool IsExistOp(const char_t *)) + static bool IsExistOp(const std::string &operator_type); + + static bool IsExistOp(const char_t *const operator_type); +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorCreatorRegister { + public: + ATTRIBUTED_DEPRECATED(OperatorCreatorRegister(const char_t *, OpCreatorV2 const &)) + OperatorCreatorRegister(const std::string &operator_type, OpCreator const &op_creator); + OperatorCreatorRegister(const char_t *const operator_type, OpCreatorV2 const &op_creator); + ~OperatorCreatorRegister() = default; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferShapeFuncRegister { + public: + ATTRIBUTED_DEPRECATED(InferShapeFuncRegister(const 
char_t *, const InferShapeFunc &)) + InferShapeFuncRegister(const std::string &operator_type, const InferShapeFunc &infer_shape_func); + InferShapeFuncRegister(const char_t *const operator_type, const InferShapeFunc &infer_shape_func); + ~InferShapeFuncRegister() = default; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferFormatFuncRegister { + public: + ATTRIBUTED_DEPRECATED(InferFormatFuncRegister(const char_t *, const InferFormatFunc &)) + InferFormatFuncRegister(const std::string &operator_type, const InferFormatFunc &infer_format_func); + InferFormatFuncRegister(const char_t *const operator_type, const InferFormatFunc &infer_format_func); + ~InferFormatFuncRegister() = default; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferValueRangeFuncRegister { + public: + InferValueRangeFuncRegister(const char_t *const operator_type, const WHEN_CALL when_call, + const InferValueRangeFunc &infer_value_range_func); + InferValueRangeFuncRegister(const char_t *const operator_type); + ~InferValueRangeFuncRegister() = default; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY VerifyFuncRegister { + public: + ATTRIBUTED_DEPRECATED(VerifyFuncRegister(const char_t *, const VerifyFunc &)) + VerifyFuncRegister(const std::string &operator_type, const VerifyFunc &verify_func); + VerifyFuncRegister(const char_t *const operator_type, const VerifyFunc &verify_func); + ~VerifyFuncRegister() = default; +}; +} // namespace ge + +#endif // INC_EXTERNAL_GRAPH_OPERATOR_FACTORY_H_ diff --git a/inc/metadef/external/graph/operator_reg.h b/inc/metadef/external/graph/operator_reg.h new file mode 100644 index 0000000000000000000000000000000000000000..9584f32a3a7aaf1cd1682e8b8218d0ba19e4766d --- /dev/null +++ b/inc/metadef/external/graph/operator_reg.h @@ -0,0 +1,736 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). 
+ * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_GRAPH_OPERATOR_REG_H_ +#define INC_EXTERNAL_GRAPH_OPERATOR_REG_H_ + +#include +#include +#include +#include + +#include "graph/operator.h" +#include "graph/operator_factory.h" +#include "graph/tensor.h" +#include "graph/types.h" +#include "graph/graph.h" + +#if defined(__GNUC__) || defined(__clang__) +#define FORCE_INLINE __attribute__((always_inline)) +#elif defined(_MSC_VER) || defined(__INTEL_COMPILER) +#define FORCE_INLINE __forceinline +#elif defined(__IBMCPP__) +#define FORCE_INLINE __inline(always) +#else +#define FORCE_INLINE inline +#endif + +template +ge::AscendString ConvertToAscendString(T str); + +template<> +inline ge::AscendString ConvertToAscendString(const char *str) { + return ge::AscendString(str); +} + +template<> +inline ge::AscendString ConvertToAscendString(std::string str) { + return ge::AscendString(str.c_str()); +} + +template<> +inline ge::AscendString ConvertToAscendString(ge::AscendString str) { + return std::move(str); +} + +template +std::vector ConvertToListAscendString(T strs); + +template<> +inline std::vector ConvertToListAscendString(std::vector strs) { + std::vector ascend_strs(strs.size()); + for (size_t i = 0; i < strs.size(); ++i) { + ascend_strs[i] = ge::AscendString(strs[i].c_str()); + } + return ascend_strs; +} + +template<> +inline std::vector ConvertToListAscendString(std::vector strs) { + return strs; +} +namespace ge { +using std::function; +using std::string; +using 
std::vector; + +#define ATTR_String(x, ...) \ + graphStatus get_attr_##x(AscendString &ret) const { \ + std::string ret_str = __VA_ARGS__; \ + if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ + ret = AscendString(ret_str.c_str()); \ + } \ + return GRAPH_SUCCESS; \ + } \ + _THIS_TYPE &set_attr_##x(const char *v) { \ + Operator::SetAttr(#x, v); \ + return *this; \ + } \ + _THIS_TYPE &set_attr_##x(const function &v) { \ + (void) v; \ + return *this; \ + } + +#define ATTR_ListString(x, ...) \ + graphStatus get_attr_##x(std::vector &ret) const { \ + std::vector ret_strs = __VA_ARGS__; \ + if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ + for (auto &ret_str : ret_strs) { \ + ret.emplace_back(ret_str.c_str()); \ + } \ + } \ + return GRAPH_SUCCESS; \ + } \ + _THIS_TYPE &set_attr_##x(const std::vector &v) { \ + Operator::SetAttr(#x, v); \ + return *this; \ + } \ + _THIS_TYPE &set_attr_##x(const function()> &v) { \ + (void) v; \ + return *this; \ + } + +#define ATTR_AscendString(x, ...) \ + graphStatus get_attr_##x(AscendString &ret) const { \ + AscendString ret_str = __VA_ARGS__; \ + if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ + ret = AscendString(ret_str); \ + } \ + return GRAPH_SUCCESS; \ + } + +#define ATTR_ListAscendString(x, ...) \ + graphStatus get_attr_##x(std::vector &ret) const { \ + std::vector ret_strs = __VA_ARGS__; \ + if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ + for (auto &ret_str : ret_strs) { \ + if (ret_str.GetString() != nullptr) { \ + ret.emplace_back(ret_str.GetString()); \ + } \ + } \ + } \ + return GRAPH_SUCCESS; \ + } + +#define ATTR_Int(x, ...) +#define ATTR_Float(x, ...) +#define ATTR_Bool(x, ...) +#define ATTR_Tensor(x, ...) +#define ATTR_Type(x, ...) +#define ATTR_NamedAttrs(x, ...) +#define ATTR_ListInt(x, ...) +#define ATTR_ListFloat(x, ...) +#define ATTR_ListBool(x, ...) +#define ATTR_ListTensor(x, ...) +#define ATTR_Bytes(x, ...) +#define ATTR_ListListInt(x, ...) +#define ATTR_ListType(x, ...) 
+#define ATTR_ListNamedAttrs(x, ...) + +#define SET_VALUE_String(x) auto value = ConvertToAscendString(x) +#define SET_VALUE_AscendString(x) auto value = ConvertToAscendString(x) + +#define SET_VALUE_ListString(x) \ + auto input = (x); \ + std::vector value = ConvertToListAscendString(input) + +#define SET_VALUE_ListAcendString(x) \ + auto input = (x); \ + std::vector value = ConvertToListAscendString(input) + +#define SET_VALUE_ListAscendString(x) \ + auto input = (x); \ + std::vector value = ConvertToListAscendString(input) + +#define SET_VALUE_Int(x) auto value = (x) +#define SET_VALUE_Float(x) auto value = (x) +#define SET_VALUE_Bool(x) auto value = (x) +#define SET_VALUE_Tensor(x) auto value = (x) +#define SET_VALUE_Type(x) auto value = (x) +#define SET_VALUE_NamedAttrs(x) auto value = (x) +#define SET_VALUE_ListInt(x) auto value = (x) +#define SET_VALUE_ListFloat(x) auto value = (x) +#define SET_VALUE_ListBool(x) auto value = (x) +#define SET_VALUE_ListTensor(x) auto value = (x) +#define SET_VALUE_Bytes(x) auto value = (x) +#define SET_VALUE_ListListInt(x) auto value = (x) +#define SET_VALUE_ListType(x) auto value = (x) +#define SET_VALUE_ListNamedAttrs(x) auto value = (x) + +#define REQUIRED_ATTR_String(x) \ + graphStatus get_attr_##x(AscendString &ret) const { \ + if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ + return GRAPH_FAILED; \ + } \ + return GRAPH_SUCCESS; \ + } \ + _THIS_TYPE &set_attr_##x(const char *v) { \ + Operator::SetAttr(#x, v); \ + return *this; \ + } \ + _THIS_TYPE &set_attr_##x(const function &v) { \ + (void) v; \ + return *this; \ + } + +#define REQUIRED_ATTR_ListString(x) \ + graphStatus get_attr_##x(std::vector &ret) const { \ + if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ + return GRAPH_FAILED; \ + } \ + return GRAPH_SUCCESS; \ + } \ + _THIS_TYPE &set_attr_##x(const std::vector &v) { \ + Operator::SetAttr(#x, v); \ + return *this; \ + } \ + _THIS_TYPE &set_attr_##x(const function()> &v) { \ + (void) v; \ + return *this; \ 
+ } + +#define REQUIRED_ATTR_AscendString(x) \ + graphStatus get_attr_##x(AscendString &ret) const { \ + if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ + return GRAPH_FAILED; \ + } \ + return GRAPH_SUCCESS; \ + } + +#define REQUIRED_ATTR_ListAscendString(x) \ + graphStatus get_attr_##x(std::vector &ret) const { \ + if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ + return GRAPH_FAILED; \ + } \ + return GRAPH_SUCCESS; \ + } + +#define REQUIRED_ATTR_Int(x) +#define REQUIRED_ATTR_Float(x) +#define REQUIRED_ATTR_Bool(x) +#define REQUIRED_ATTR_Tensor(x) +#define REQUIRED_ATTR_Type(x) +#define REQUIRED_ATTR_NamedAttrs(x) +#define REQUIRED_ATTR_ListInt(x) +#define REQUIRED_ATTR_ListFloat(x) +#define REQUIRED_ATTR_ListBool(x) +#define REQUIRED_ATTR_ListTensor(x) +#define REQUIRED_ATTR_Bytes(x) +#define REQUIRED_ATTR_ListListInt(x) +#define REQUIRED_ATTR_ListType(x) +#define REQUIRED_ATTR_ListNamedAttrs(x) + +class OpReg { + public: + OpReg &N() { + return *this; + } + + OpReg &ATTR() { + return *this; + } + + OpReg &REQUIRED_ATTR() { + return *this; + } + + OpReg &INPUT() { + return *this; + } + + OpReg &OPTIONAL_INPUT() { + return *this; + } + + OpReg &OUTPUT() { + return *this; + } + + OpReg &GRAPH() { + return *this; + } + + OpReg &DYNAMIC_GRAPH() { + return *this; + } + + OpReg &INFER_SHAPE_AND_TYPE() { + return *this; + } +}; + +#define REG_OP(x) \ + namespace op { \ + class x : public Operator { \ + typedef x _THIS_TYPE; \ + \ + public: \ + ATTRIBUTED_DEPRECATED(x(const char *)) \ + explicit FORCE_INLINE x(const std::string &name) : Operator(name.c_str(), #x) { \ + __##x(); \ + } \ + explicit FORCE_INLINE x(const char *name) : Operator(name, #x) { \ + __##x(); \ + } \ + explicit FORCE_INLINE x(const AscendString &name) : Operator(name, #x) { \ + __##x(); \ + } \ + FORCE_INLINE x() : Operator(#x) { \ + __##x(); \ + } \ + \ + private: \ + void FORCE_INLINE __##x() { \ + OpReg() + +#define ATTR(x, Type, ...) 
\ + N(); \ + __attr_##x(); \ + } \ + \ + public: \ + ATTRIBUTED_DEPRECATED(static const void name_attr_##x(AscendString &)) \ + static const std::string name_attr_##x() { \ + return #x; \ + } \ + static void name_attr_##x(AscendString &attr) { \ + attr = AscendString(#x); \ + } \ + ATTR_##Type(x, __VA_ARGS__) Op##Type get_attr_##x() const { \ + Op##Type ret = __VA_ARGS__; \ + if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ + return ret; \ + } \ + return ret; \ + } \ + _THIS_TYPE &set_attr_##x(const Op##Type &v) { \ + Operator::SetAttr(#x, v); \ + return *this; \ + } \ + _THIS_TYPE &set_attr_##x(const function &v) { \ + (void) v; \ + return *this; \ + } \ + \ + private: \ + void FORCE_INLINE __attr_##x() { \ + SET_VALUE_##Type(Op##Type(__VA_ARGS__)); \ + Operator::AttrRegister(#x, value); \ + std::string attr_name(#x); \ + (void) OpReg() + +#define REQUIRED_ATTR(x, Type) \ + N(); \ + __required_attr_##x(); \ + } \ + \ + public: \ + ATTRIBUTED_DEPRECATED(static const void name_attr_##x(AscendString &)) \ + static const std::string name_attr_##x() { \ + return #x; \ + } \ + static void name_attr_##x(AscendString &attr_name) { \ + attr_name = AscendString(#x); \ + } \ + REQUIRED_ATTR_##Type(x) Op##Type get_attr_##x() const { \ + Op##Type ret; \ + if (Operator::GetAttr(#x, ret) == GRAPH_FAILED) { \ + return ret; \ + } \ + return ret; \ + } \ + _THIS_TYPE &set_attr_##x(const Op##Type &v) { \ + Operator::SetAttr(#x, v); \ + return *this; \ + } \ + _THIS_TYPE &set_attr_##x(const function &v) { \ + (void) v; \ + return *this; \ + } \ + \ + private: \ + void FORCE_INLINE __required_attr_##x() { \ + Operator::RequiredAttrWithTypeRegister(#x, #Type); \ + std::string attr_name(#x); \ + (void) OpReg() + +#define DATATYPE(x, t) \ + N(); \ + __datatype_##x(); \ + } \ + \ + private: \ + void FORCE_INLINE __datatype_##x() { \ + auto type_range = t; \ + Operator::DataTypeRegister(#x, type_range); \ + (void) OpReg() + +#define INPUT(x, t) \ + N(); \ + __input_##x(); \ + } \ + \ + 
public: \ + ATTRIBUTED_DEPRECATED(static const void name_in_##x(AscendString &)) \ + static const std::string name_in_##x() { \ + return #x; \ + } \ + static void name_in_##x(AscendString &name) { \ + name = AscendString(#x); \ + } \ + ATTRIBUTED_DEPRECATED(_THIS_TYPE &set_input_##x##_by_name(Operator &, const char *)) \ + _THIS_TYPE &set_input_##x(Operator &v, const std::string &srcName) { \ + Operator::SetInput(#x, v, srcName.c_str()); \ + return *this; \ + } \ + _THIS_TYPE &set_input_##x##_by_name(Operator &v, const char *srcName) { \ + Operator::SetInput(#x, v, srcName); \ + return *this; \ + } \ + _THIS_TYPE &set_input_##x(Operator &v, uint32_t index) { \ + Operator::SetInput(#x, v, index); \ + return *this; \ + } \ + _THIS_TYPE &set_input_##x(Operator &v) { \ + Operator::SetInput(#x, v); \ + return *this; \ + } \ + TensorDesc get_input_desc_##x() const { \ + return Operator::GetInputDescByName(#x); \ + } \ + graphStatus update_input_desc_##x(const TensorDesc &tensorDesc) { \ + return Operator::UpdateInputDesc(#x, tensorDesc); \ + } \ + \ + private: \ + void FORCE_INLINE __input_##x() { \ + Operator::InputRegister(#x, #t); \ + (void) OpReg() + +#define OPTIONAL_INPUT(x, t) \ + N(); \ + __optional_input_##x(); \ + } \ + \ + public: \ + ATTRIBUTED_DEPRECATED(static const void name_in_##x(AscendString &)) \ + static const std::string name_in_##x() { \ + return #x; \ + } \ + static void name_in_##x(AscendString &name) { \ + name = AscendString(#x); \ + } \ + _THIS_TYPE &set_input_##x(Operator &v) { \ + Operator::SetInput(#x, v); \ + return *this; \ + } \ + ATTRIBUTED_DEPRECATED(_THIS_TYPE &set_input_##x##_by_name(Operator &, const char *)) \ + _THIS_TYPE &set_input_##x(Operator &v, const std::string &srcName) { \ + Operator::SetInput(#x, v, srcName.c_str()); \ + return *this; \ + } \ + _THIS_TYPE &set_input_##x##_by_name(Operator &v, const char *srcName) { \ + Operator::SetInput(#x, v, srcName); \ + return *this; \ + } \ + _THIS_TYPE &set_input_##x(Operator &v, 
uint32_t index) { \ + Operator::SetInput(#x, v, index); \ + return *this; \ + } \ + TensorDesc get_input_desc_##x() const { \ + return Operator::GetInputDescByName(#x); \ + } \ + graphStatus update_input_desc_##x(const TensorDesc &tensorDesc) { \ + return Operator::UpdateInputDesc(#x, tensorDesc); \ + } \ + \ + private: \ + void FORCE_INLINE __optional_input_##x() { \ + Operator::OptionalInputRegister(#x, #t); \ + (void) OpReg() + +#define OUTPUT(x, t) \ + N(); \ + __out_##x(); \ + } \ + \ + public: \ + ATTRIBUTED_DEPRECATED(static const void name_out_##x(AscendString &)) \ + static const std::string name_out_##x() { \ + return #x; \ + } \ + static void name_out_##x(AscendString &name) { \ + name = AscendString(#x); \ + } \ + TensorDesc get_output_desc_##x() const { \ + return Operator::GetOutputDescByName(#x); \ + } \ + graphStatus update_output_desc_##x(const TensorDesc &tensorDesc) { \ + return Operator::UpdateOutputDesc(#x, tensorDesc); \ + } \ + \ + private: \ + void FORCE_INLINE __out_##x() { \ + Operator::OutputRegister(#x, #t); \ + (void) OpReg() + +#define DYNAMIC_INPUT(x, t) \ + N(); \ + __dy_input_##x(); \ + } \ + \ + public: \ + _THIS_TYPE &create_dynamic_input_##x(uint32_t num, bool isPushBack = true) { \ + Operator::DynamicInputRegister(#x, num, #t, isPushBack); \ + return *this; \ + } \ + _THIS_TYPE &create_dynamic_input_byindex_##x(uint32_t num, size_t index) { \ + Operator::DynamicInputRegisterByIndex(#x, num, index); \ + return *this; \ + } \ + TensorDesc get_dynamic_input_desc_##x(uint32_t index) const { \ + return Operator::GetDynamicInputDesc(#x, index); \ + } \ + graphStatus update_dynamic_input_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \ + return Operator::UpdateDynamicInputDesc(#x, index, tensorDesc); \ + } \ + _THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v) { \ + Operator::SetInput(#x, dstIndex, v); \ + return *this; \ + } \ + ATTRIBUTED_DEPRECATED(_THIS_TYPE &set_dynamic_input_##x(uint32_t, Operator &, 
const char *)) \ + _THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v, const std::string &srcName) { \ + Operator::SetInput(#x, dstIndex, v, srcName.c_str()); \ + return *this; \ + } \ + _THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v, const char *srcName) { \ + Operator::SetInput(#x, dstIndex, v, srcName); \ + return *this; \ + } \ + \ + private: \ + void FORCE_INLINE __dy_input_##x() { \ + Operator::DynamicInputRegister(#x, 0, #t, true); \ + (void) OpReg() + +#define DYNAMIC_OUTPUT(x, t) \ + N(); \ + __dy_output_##x(); \ + } \ + \ + public: \ + _THIS_TYPE &create_dynamic_output_##x(uint32_t num, bool isPushBack = true) { \ + Operator::DynamicOutputRegister(#x, num, #t, isPushBack); \ + return *this; \ + } \ + TensorDesc get_dynamic_output_desc_##x(uint32_t index) const { \ + return Operator::GetDynamicOutputDesc(#x, index); \ + } \ + graphStatus update_dynamic_output_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \ + return Operator::UpdateDynamicOutputDesc(#x, index, tensorDesc); \ + } \ + \ + private: \ + void FORCE_INLINE __dy_output_##x() { \ + Operator::DynamicOutputRegister(#x, 0, #t, true); \ + (void) OpReg() + +#define GRAPH(x) \ + N(); \ + __graph_##x(); \ + } \ + \ + public: \ + ATTRIBUTED_DEPRECATED(static const void name_graph_##x(AscendString &)) \ + static const std::string name_graph_##x() { \ + return #x; \ + } \ + static void name_graph_##x(AscendString &name) { \ + name = AscendString(#x); \ + } \ + SubgraphBuilder get_subgraph_builder_##x() const { \ + return Operator::GetSubgraphBuilder(#x); \ + } \ + _THIS_TYPE &set_subgraph_builder_##x(const SubgraphBuilder &v) { \ + Operator::SetSubgraphBuilder(#x, 0, v); \ + return *this; \ + } \ + Graph get_subgraph_##x() const { \ + return Operator::GetSubgraph(#x); \ + } \ + \ + private: \ + void FORCE_INLINE __graph_##x() { \ + Operator::SubgraphRegister(#x, false); \ + Operator::SubgraphCountRegister(#x, 1); \ + (void) OpReg() + +#define DYNAMIC_GRAPH(x) \ + 
N(); \ + __graph_##x(); \ + } \ + \ + public: \ + ATTRIBUTED_DEPRECATED(static const void name_graph_##x(AscendString &)) \ + static const std::string name_graph_##x() { \ + return #x; \ + } \ + static void name_graph_##x(AscendString &name) { \ + name = AscendString(#x); \ + } \ + _THIS_TYPE &create_dynamic_subgraph_##x(uint32_t num) { \ + Operator::SubgraphCountRegister(#x, num); \ + return *this; \ + } \ + SubgraphBuilder get_dynamic_subgraph_builder_##x(uint32_t index) const { \ + return Operator::GetDynamicSubgraphBuilder(#x, index); \ + } \ + Graph get_dynamic_subgraph_##x(uint32_t index) const { \ + return Operator::GetDynamicSubgraph(#x, index); \ + } \ + _THIS_TYPE &set_dynamic_subgraph_builder_##x(uint32_t index, const SubgraphBuilder &v) { \ + Operator::SetSubgraphBuilder(#x, index, v); \ + return *this; \ + } \ + \ + private: \ + void FORCE_INLINE __graph_##x() { \ + Operator::SubgraphRegister(#x, true); \ + (void) OpReg() + +#define PASTE(g_register, y) g_register##y + +#define __OP_END_IMPL_WITHOUT_REGISTER__(x) \ + N(); \ + } \ + static_assert( \ + std::is_same::value, \ + "The class name entered into the OP_END_FACTORY_REG needs to be the same as the operator name you define."); \ + } \ + ; \ + } + +#ifdef DISABLE_COMPILE_V1 +#define __OP_END_IMPL__(x, y) \ + __OP_END_IMPL_WITHOUT_REGISTER__(x) +#else +#define __OP_END_IMPL__(x, y) \ + N(); \ + } \ + static_assert( \ + std::is_same::value, \ + "The class name entered into the OP_END_FACTORY_REG needs to be the same as the operator name you define."); \ + } \ + ; \ + static const OperatorCreatorRegister PASTE(g_register, y)(#x, [](const AscendString &name) { return x(name); }); \ + } +#endif +#define OP_END_FACTORY_REG(x) __OP_END_IMPL__(x, __COUNTER__) + +// Specialized shape inferencer macro + +#define IMPLEMT_INFERFUNC(op_name, func_name) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op) + +#define IMPLEMT_COMMON_INFERFUNC(func_name) \ + 
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(Operator &op) + +#define IMPLEMT_INFERFORMAT_FUNC(op_name, func_name) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op) + +// Specialized verifier macro + +#define IMPLEMT_VERIFIER(op_name, func_name) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name op) + +#define INFER_VERIFY_FUNC(op_name, x) [](Operator &v) { return x((op::op_name &) v); } + +#define COMMON_INFER_VERIFY_FUNC(x) [](Operator &v) { return x(v); } + +#define INFER_FORMAT_FUNC(op_name, x) [](Operator &v) { return x((op::op_name &) v); } + +#define __INFER_FUNC_REG_IMPL__(op_name, x, n) static const InferShapeFuncRegister PASTE(if_register, n)(#op_name, x) +#define __VERIFY_FUNC_REG_IMPL__(op_name, x, n) static const VerifyFuncRegister PASTE(vf_register, n)(#op_name, x) + +// Infer format func register +#define __INFER_FORMAT_FUNC_REG_IMPL__(op_name, x, n) \ + static const InferFormatFuncRegister PASTE(ff_register, n)(#op_name, x) + +// Shape inferencer & verifier register macro + +#define INFER_FUNC_REG(op_name, x) __INFER_FUNC_REG_IMPL__(op_name, INFER_VERIFY_FUNC(op_name, x), __COUNTER__) + +#define COMMON_INFER_FUNC_REG(op_name, x) __INFER_FUNC_REG_IMPL__(op_name, COMMON_INFER_VERIFY_FUNC((x)), __COUNTER__) + +#define VERIFY_FUNC_REG(op_name, x) __VERIFY_FUNC_REG_IMPL__(op_name, INFER_VERIFY_FUNC(op_name, x), __COUNTER__) + +// Value Range Infer +#define INFER_VALUE_RANGE_FUNC(op_name, x) [](Operator &v) { return x((op::op_name &) v); } + +#define INFER_VALUE_RANGE_DEFAULT_REG(op_name) __INFER_VALUE_RANGE_DEFAULT_REG_IMPL__(op_name, __COUNTER__) +#define __INFER_VALUE_RANGE_DEFAULT_REG_IMPL__(op_name, n) \ + static const InferValueRangeFuncRegister PASTE(iv_reg_default, n)(#op_name) + +#define INFER_VALUE_RANGE_CUSTOM_FUNC_REG(op_name, when_call, x) \ + __INFER_VALUE_RANGE_CUSTOM_FUNC_REG_IMPL__(op_name, when_call, 
INFER_VALUE_RANGE_FUNC(op_name, x), __COUNTER__) +#define __INFER_VALUE_RANGE_CUSTOM_FUNC_REG_IMPL__(op_name, when_call, x, n) \ + static const InferValueRangeFuncRegister PASTE(iv_reg_custom, n)(#op_name, when_call, x) + +#define IMPL_INFER_VALUE_RANGE_FUNC(op_name, func_name) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus func_name(op::op_name &op) + +// Infer format func reg +#define INFER_FORMAT_FUNC_REG(op_name, x) \ + __INFER_FORMAT_FUNC_REG_IMPL__(op_name, INFER_FORMAT_FUNC(op_name, x), __COUNTER__) + +// Common shape inferencer + +#define ELMTWISE_INFER_SHAPEANDTYPE(in_name, out_name) \ + [](Operator op) -> graphStatus { \ + auto x_input_desc = op.GetInputDescByName(in_name); \ + auto x_shape = x_input_desc.GetShape().GetDims(); \ + auto x_type = x_input_desc.GetDataType(); \ + std::vector> x_shape_range; \ + (void) x_input_desc.GetShapeRange(x_shape_range); \ + TensorDesc op_output_desc = op.GetOutputDescByName(out_name); \ + op_output_desc.SetShape(ge::Shape(x_shape)); \ + op_output_desc.SetOriginShape(ge::Shape(x_shape)); \ + op_output_desc.SetDataType(x_type); \ + if (!x_shape_range.empty()) { \ + op_output_desc.SetShapeRange(x_shape_range); \ + } \ + return op.UpdateOutputDesc(out_name, op_output_desc); \ + } + +graphStatus BroadCastInfer(const function()> &get_in1_shape, + const function()> &get_in2_shape, + const function &y_shape)> &set_out_shape); + +#define BROADCAST_INFER(in1_name, in2_name, out_name) \ + [](Operator op) -> graphStatus { \ + return BroadCastInfer([&]() { return op.GetInputDescByName(in1_name).GetShape().GetDims(); }, \ + [&]() { return op.GetInputDescByName(in2_name).GetShape().GetDims(); }, \ + [&](const std::vector &y_shape) { \ + TensorDesc op_output_desc = op.GetOutputDescByName(out_name); \ + op_output_desc.SetShape(ge::Shape(y_shape)); \ + (void) op.UpdateOutputDesc(out_name, op_output_desc); \ + }); \ + } +} // namespace ge +#endif // INC_EXTERNAL_GRAPH_OPERATOR_REG_H_ diff --git 
a/inc/metadef/external/graph/resource_context.h b/inc/metadef/external/graph/resource_context.h new file mode 100644 index 0000000000000000000000000000000000000000..1bae4753f13082de4c7e612f28e83483da8d6e9d --- /dev/null +++ b/inc/metadef/external/graph/resource_context.h @@ -0,0 +1,20 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_GRAPH_RESOURCE_CONTEXT_H_ +#define INC_EXTERNAL_GRAPH_RESOURCE_CONTEXT_H_ + +namespace ge { +// For resource op infershape, indicate content stored in resources, shape/dtype etc. +// Op can inherit from this struct and extend more content +struct ResourceContext { + virtual ~ResourceContext() {} +}; // struct ResourceContext +} // namespace ge +#endif // INC_EXTERNAL_GRAPH_RESOURCE_CONTEXT_H_ diff --git a/inc/metadef/external/graph/tensor.h b/inc/metadef/external/graph/tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..565dc04676b5de6f47b753095d61401224850b89 --- /dev/null +++ b/inc/metadef/external/graph/tensor.h @@ -0,0 +1,226 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. 
You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_GRAPH_TENSOR_H_ +#define INC_EXTERNAL_GRAPH_TENSOR_H_ + +#include +#include +#include +#include +#include + +#include "./ge_error_codes.h" +#include "./types.h" +#include "ascend_string.h" + +namespace ge { +class ShapeImpl; +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Shape { + public: + Shape(); + ~Shape() = default; + explicit Shape(const std::vector &dims); + + /** + * `GetDimNum()`标识有效的dim的个数,跟`GetDims().size()`不等价,调用方按需选择 + * 比如如果dim是[-2], 维度未可知时: + * GetDimNum()会返回0; + * 而GetDims().size()会返回dim的个数,即1; + * 另外如果需要判断是否是标量,推荐使用接口`GetDims.size() == 0U`来判断 + * @return + */ + size_t GetDimNum() const; + // If the idx is invalid, return 0 + int64_t GetDim(size_t idx) const; + graphStatus SetDim(size_t idx, int64_t value); + /** + * `GetDims`标识dim的个数,跟`GetDimNum()`不等价,调用方按需选择 + * 比如如果dim是[-2], 维度未可知时: + * GetDimNum()会返回0; + * 而GetDims().size()会返回dim的个数,即1; + * 另外如果需要判断是否是标量,推荐使用接口`GetDims.size() == 0U`来判断 + * @return + */ + std::vector GetDims() const; + /** + * 获取shape的各个维度的dim值的乘积 + * @return + * 如果dim值包含-1或者-2,那么size直接返回-1, 含义是unknown shape + * 如果dim值包含0,那么size直接返回0,含义是空tensor + * 如果dim值的个数为0,那么size直接返回0,含义是标量 + * 如果dim值的乘积产生了int64的溢出,那么size直接返回0,含义是乘积溢出 + */ + int64_t GetShapeSize() const; + + private: + std::shared_ptr impl_; +}; + +class TensorDescImpl; +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorDesc { + public: + TensorDesc(); + ~TensorDesc() = default; + explicit TensorDesc(Shape shape, Format format = FORMAT_ND, DataType dt = 
DT_FLOAT); + // Copy + TensorDesc(const TensorDesc &desc); + // Move + TensorDesc(TensorDesc &&desc); + // Copy + TensorDesc &operator=(const TensorDesc &desc); + // Move + TensorDesc &operator=(TensorDesc &&desc); + + void Update(const Shape &shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); + Shape GetShape() const; + void SetShape(const Shape &shape); + // set shape with -2, it stand for unknown shape + graphStatus SetUnknownDimNumShape(); + // for unknown shape + graphStatus SetShapeRange(const std::vector> &range); + graphStatus GetShapeRange(std::vector> &range) const; + + Format GetFormat() const; + void SetFormat(Format format); + + Shape GetOriginShape() const; + void SetOriginShape(const Shape &origin_shape); + + Format GetOriginFormat() const; + void SetOriginFormat(Format origin_format); + + DataType GetDataType() const; + void SetDataType(DataType dt); + + ATTRIBUTED_DEPRECATED(graphStatus GetName(AscendString &)) + std::string GetName() const; + graphStatus GetName(AscendString &name); + graphStatus GetName(AscendString &name) const; + + ATTRIBUTED_DEPRECATED(void SetName(const char_t *)) + void SetName(const std::string &name); + void SetName(const char_t *name); + + // Attr acess + void SetSize(int64_t size); + int64_t GetSize() const; + + int64_t GetRealDimCnt() const; + void SetRealDimCnt(const int64_t real_dim_cnt); + + void SetPlacement(Placement placement); + Placement GetPlacement() const; + + void SetConstData(std::unique_ptr const_data_buffer, const size_t &const_data_len); + bool GetConstData(uint8_t **const_data_buffer, size_t &const_data_len) const; + /* + * 补维类似于ExpandDims算子,在原有shape的基础上,添加一到多个维度,例如原shape[2,2]有两根轴,那么在两根轴中间补两维后的shape为[2,1,1,2]。 + * 补维后shape的第0、3根轴被称为原始轴,第1、2根轴被称为补维轴。 + * + * 通过1和0描述补维规则,1代表当前轴为补维轴,0代表当前轴为原始轴,从左到右依次代表当前shape每根轴的来源,例如: + * | 补维规则 | 补维前shape | 补维后shape | + * | -------- | ----------- | ------------------------------------------------------------ | + * | 0110 | [2, 2] | [2, 1, 1, 2] | + * | 100 | [2, 
3] | [1, 2, 3] | + * | 1000 | [2, 3] | 补维规则与补维前shape不匹配,规则指定原始轴有3根,但原始shape只有2根轴,补维报错。 | + * + */ + void SetExpandDimsRule(const AscendString &expand_dims_rule); + graphStatus GetExpandDimsRule(AscendString &expand_dims_rule) const; + + void SetReuseInputIndex(const uint32_t idx); + + private: + std::shared_ptr impl; + friend class TensorAdapter; +}; + +class TensorImpl; +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Tensor { + public: + using DeleteFunc = std::function; + Tensor(); + ~Tensor() = default; + explicit Tensor(const TensorDesc &tensor_desc); + Tensor(const TensorDesc &tensor_desc, const std::vector &data); + Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size); + Tensor(TensorDesc &&tensor_desc, std::vector &&data); + + TensorDesc GetTensorDesc() const; + graphStatus SetTensorDesc(const TensorDesc &tensor_desc); + + const uint8_t *GetData() const; + uint8_t *GetData(); + size_t GetSize() const; + std::unique_ptr ResetData(); + + graphStatus SetData(std::vector &&data); + graphStatus SetData(const std::vector &data); + graphStatus SetData(const uint8_t *data, size_t size); + ATTRIBUTED_DEPRECATED(graphStatus SetData(const char_t *data)) + graphStatus SetData(const std::string &data); + graphStatus SetData(const char_t *data); + ATTRIBUTED_DEPRECATED(graphStatus SetData(const std::vector &)) + graphStatus SetData(const std::vector &data); + graphStatus SetData(const std::vector &datas); + graphStatus SetData(uint8_t *data, size_t size, const Tensor::DeleteFunc &deleter_func); + graphStatus IsValid(); + + graphStatus SetOriginShapeDimNum(const size_t dim_num); + size_t GetOriginShapeDimNum() const; + + graphStatus SetOriginShapeDim(const size_t idx, const int64_t dim_value); + int64_t GetOriginShapeDim(const size_t idx) const; + + graphStatus SetOriginFormat(const ge::Format &format); + ge::Format GetOriginFormat() const; + + graphStatus SetShapeDimNum(const size_t dim_num); + size_t GetShapeDimNum() const; + + graphStatus 
SetShapeDim(const size_t idx, const int64_t dim_value); + int64_t GetShapeDim(const size_t idx) const; + + graphStatus SetFormat(const ge::Format &format); + ge::Format GetFormat() const; + + graphStatus SetDataType(const ge::DataType &dtype); + ge::DataType GetDataType() const; + + graphStatus SetPlacement(const ge::Placement &placement); + ge::Placement GetPlacement() const; + + /* + * 补维类似于ExpandDims算子,在原有shape的基础上,添加一到多个维度,例如原shape[2,2]有两根轴,那么在两根轴中间补两维后的shape为[2,1,1,2]。 + * 补维后shape的第0、3根轴被称为原始轴,第1、2根轴被称为补维轴。 + * + * 通过1和0描述补维规则,1代表当前轴为补维轴,0代表当前轴为原始轴,从左到右依次代表当前shape每根轴的来源,例如: + * | 补维规则 | 补维前shape | 补维后shape | + * | -------- | ----------- | ------------------------------------------------------------ | + * | 0110 | [2, 2] | [2, 1, 1, 2] | + * | 100 | [2, 3] | [1, 2, 3] | + * | 1000 | [2, 3] | 补维规则与补维前shape不匹配,规则指定原始轴有3根,但原始shape只有2根轴,补维报错。 | + * + */ + graphStatus SetExpandDimsRule(const AscendString &expand_dims_rule); + graphStatus GetExpandDimsRule(AscendString &expand_dims_rule) const; + + // 高性能接口,与SetData接口的区别是避免重复make_shared,此时需要用户保证该tensor的内存只被当前tensor使用,具有独占所有权 + graphStatus ResetData(uint8_t *data, size_t size, const Tensor::DeleteFunc &deleter_func); + + Tensor Clone() const; + + private: + std::shared_ptr impl; + friend class TensorAdapter; +}; +} // namespace ge + +#endif // INC_EXTERNAL_GRAPH_TENSOR_H_ diff --git a/inc/metadef/external/graph/types.h b/inc/metadef/external/graph/types.h new file mode 100644 index 0000000000000000000000000000000000000000..ef354125dd08f889c8c6686bd4b56dc7e86a04a1 --- /dev/null +++ b/inc/metadef/external/graph/types.h @@ -0,0 +1,402 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_GRAPH_TYPES_H_ +#define INC_EXTERNAL_GRAPH_TYPES_H_ + +#include +#include +#include + +namespace ge { +using char_t = char; +using float32_t = float; +using float64_t = double; +using vector_bit_t = std::vector; + +// set minimal range value as 0 to support empty tensor +static const int64_t SHAPE_RANGE_LOWER_LIMIT = 0; +static const int64_t UNKNOWN_DIM = -1; +static const int64_t UNKNOWN_DIM_NUM = -2; +static const std::vector UNKNOWN_SHAPE = {-1}; +static const std::vector UNKNOWN_RANK = {-2}; +static const std::vector DUMMY_SHAPE = {-3}; +// When data type unit is bit, this offset need to be added. 
+static constexpr int32_t kDataTypeSizeBitOffset = 1000; +static constexpr uint32_t kBitNumOfOneByte = 8U; +static constexpr uint32_t kBitThreeBytes = 24U; + +#if(defined(HOST_VISIBILITY)) && (defined(__GNUC__)) +#define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_HOST_VISIBILITY +#endif +#if(defined(DEV_VISIBILITY)) && (defined(__GNUC__)) +#define GE_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_DEV_VISIBILITY +#endif + +enum DataType { + DT_FLOAT = 0, // float type + DT_FLOAT16 = 1, // fp16 type + DT_INT8 = 2, // int8 type + DT_INT32 = 3, // int32 type + DT_UINT8 = 4, // uint8 type + // reserved + DT_INT16 = 6, // int16 type + DT_UINT16 = 7, // uint16 type + DT_UINT32 = 8, // unsigned int32 + DT_INT64 = 9, // int64 type + DT_UINT64 = 10, // unsigned int64 + DT_DOUBLE = 11, // double type + DT_BOOL = 12, // bool type + DT_STRING = 13, // string type + DT_DUAL_SUB_INT8 = 14, // dual output int8 type + DT_DUAL_SUB_UINT8 = 15, // dual output uint8 type + DT_COMPLEX64 = 16, // complex64 type + DT_COMPLEX128 = 17, // complex128 type + DT_QINT8 = 18, // qint8 type + DT_QINT16 = 19, // qint16 type + DT_QINT32 = 20, // qint32 type + DT_QUINT8 = 21, // quint8 type + DT_QUINT16 = 22, // quint16 type + DT_RESOURCE = 23, // resource type + DT_STRING_REF = 24, // string ref type + DT_DUAL = 25, // dual output type + DT_VARIANT = 26, // dt_variant type + DT_BF16 = 27, // bf16 type + DT_UNDEFINED = 28, // Used to indicate a DataType field has not been set. 
+ DT_INT4 = 29, // int4 type + DT_UINT1 = 30, // uint1 type + DT_INT2 = 31, // int2 type + DT_UINT2 = 32, // uint2 type + DT_COMPLEX32 = 33, // complex32 type + DT_HIFLOAT8 = 34, // hifloat8 type + DT_FLOAT8_E5M2 = 35, // float8_e5m2 type + DT_FLOAT8_E4M3FN = 36, // float8_e4m3fn type + DT_FLOAT8_E8M0 = 37, // float8_e8m0 type + DT_FLOAT6_E3M2 = 38, // float6_e3m2 type + DT_FLOAT6_E2M3 = 39, // float6_e2m3 type + DT_FLOAT4_E2M1 = 40, // float4_e2m1 type + DT_FLOAT4_E1M2 = 41, // float4_e1m2 type + DT_MAX // Mark the boundaries of data types +}; + +// used for data type of DT_STRING +struct StringHead { + int64_t addr; // the addr of string + int64_t len; // the length of string +}; + +inline int GetSizeByDataType(DataType data_type) { + static int data_type_size[DT_MAX] = { + 4, // DT_FLOAT = 0, float type + 2, // DT_FLOAT16 = 1, fp16 type + 1, // DT_INT8 = 2, int8 type + 4, // DT_INT32 = 3, int32 type + 1, // DT_UINT8 = 4, uint8 type + -1, // reserved + 2, // DT_INT16 = 6, int16 type + 2, // DT_UINT16 = 7, uint16 type + 4, // DT_UINT32 = 8, unsigned int32 + 8, // DT_INT64 = 9, int64 type + 8, // DT_UINT64 = 10, unsigned int64 + 8, // DT_DOUBLE = 11, double type + 1, // DT_BOOL = 12, bool type + -1, // DT_STRING = 13, string type + 1, // DT_DUAL_SUB_INT8 = 14, dual output int8 type + 1, // DT_DUAL_SUB_UINT8 = 15, dual output uint8 type + 8, // DT_COMPLEX64 = 16, complex64 type + 16, // DT_COMPLEX128 = 17, complex128 type + 1, // DT_QINT8 = 18, qint8 type + 2, // DT_QINT16 = 19, qint16 type + 4, // DT_QINT32 = 20, qint32 type + 1, // DT_QUINT8 = 21, quint8 type + 2, // DT_QUINT16 = 22, quint16 type + 8, // DT_RESOURCE = 23, resource type + -1, // DT_STRING_REF = 24, string ref type + 5, // DT_DUAL = 25, dual output type (float + int8) + 8, // DT_VARIANT variant type + 2, // DT_BF16 = 27, bf16 type + -1, // DT_UNDEFINED = 28 Used to indicate a DataType field has not been set. 
+ kDataTypeSizeBitOffset + 4, // DT_INT4 = 29, int4 type + kDataTypeSizeBitOffset + 1, // DT_UINT1 = 30, uint1 type + kDataTypeSizeBitOffset + 2, // DT_INT2 = 31, int2 type + kDataTypeSizeBitOffset + 2, // DT_UINT2 = 32, uint2 type + 4, // DT_COMPLEX32 = 33, complex32 type + 1, // DT_HIFLOAT8, hifloat8 type + 1, // DT_FLOAT8_E5M2, float8_e5m2 type + 1, // DT_FLOAT8_E4M3FN, float8_e4m3fn type + 1, // DT_FLOAT8_E8M0, float8_e8m0 type + kDataTypeSizeBitOffset + 6, // DT_FLOAT6_E3M2, float6_e3m2 type, 6bit + kDataTypeSizeBitOffset + 6, // DT_FLOAT6_E2M3, float6_e2m3 type, 6bit + kDataTypeSizeBitOffset + 4, // DT_FLOAT4_E2M1, float4_e2m1 type, 4bit + kDataTypeSizeBitOffset + 4, // DT_FLOAT4_E1M2, float4_e1m2 type, 4bit + // DT_MAX + }; + if ((data_type < 0) || (data_type >= DT_MAX)) { + return -1; + } + return data_type_size[data_type]; +} + +/// @brief Calculates the length in bytes based on the DataType and the number of elements. +/// @param element_count +/// @param data_type +/// @return +int64_t GetSizeInBytes(int64_t element_count, DataType data_type); + +enum Format { + FORMAT_NCHW = 0, // NCHW + FORMAT_NHWC, // NHWC + FORMAT_ND, // Nd Tensor + FORMAT_NC1HWC0, // NC1HWC0 + FORMAT_FRACTAL_Z, // FRACTAL_Z + FORMAT_NC1C0HWPAD = 5, + FORMAT_NHWC1C0, + FORMAT_FSR_NCHW, + FORMAT_FRACTAL_DECONV, + FORMAT_C1HWNC0, + FORMAT_FRACTAL_DECONV_TRANSPOSE = 10, + FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, + FORMAT_NC1HWC0_C04, // NC1HWC0, C0 is 4 + FORMAT_FRACTAL_Z_C04, // FRACZ, C0 is 4 + FORMAT_CHWN, + FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS = 15, + FORMAT_HWCN, + FORMAT_NC1KHKWHWC0, // KH,KW kernel h& kernel w maxpooling max output format + FORMAT_BN_WEIGHT, + FORMAT_FILTER_HWCK, // filter input tensor format + FORMAT_HASHTABLE_LOOKUP_LOOKUPS = 20, + FORMAT_HASHTABLE_LOOKUP_KEYS, + FORMAT_HASHTABLE_LOOKUP_VALUE, + FORMAT_HASHTABLE_LOOKUP_OUTPUT, + FORMAT_HASHTABLE_LOOKUP_HITS, + FORMAT_C1HWNCoC0 = 25, + FORMAT_MD, + FORMAT_NDHWC, + FORMAT_FRACTAL_ZZ, + FORMAT_FRACTAL_NZ, + 
FORMAT_NCDHW = 30, + FORMAT_DHWCN, // 3D filter input tensor format + FORMAT_NDC1HWC0, + FORMAT_FRACTAL_Z_3D, + FORMAT_CN, + FORMAT_NC = 35, + FORMAT_DHWNC, + FORMAT_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format + FORMAT_FRACTAL_ZN_LSTM, + FORMAT_FRACTAL_Z_G, + FORMAT_RESERVED = 40, + FORMAT_ALL, + FORMAT_NULL, + FORMAT_ND_RNN_BIAS, + FORMAT_FRACTAL_ZN_RNN, + FORMAT_NYUV = 45, + FORMAT_NYUV_A, + FORMAT_NCL, + FORMAT_FRACTAL_Z_WINO, + FORMAT_C1HWC0, + // Add new formats definition here + FORMAT_END, + // FORMAT_MAX defines the max value of Format. + // Any Format should not exceed the value of FORMAT_MAX. + // ** Attention ** : FORMAT_MAX stands for the SPEC of enum Format and almost SHOULD NOT be used in code. + // If you want to judge the range of Format, you can use FORMAT_END. + FORMAT_MAX = 0xff +}; + +/// Get format from primary and sub-format, +/// in bits field: +/// --------------------------------------------- +/// | 4bits | 4bits | 2 bytes | 1 byte | +/// |----------|-----------|------------|--------| +/// | reserved | c0_format | sub-format | format | +/// --------------------------------------------- +/// @param primary_format +/// @param sub_format +/// @param c0_format +/// @return +inline int32_t GetFormatFromSub(int32_t primary_format, int32_t sub_format) { + return static_cast((static_cast(primary_format) & 0xffU) | + ((static_cast(sub_format) & 0xffffU) << kBitNumOfOneByte)); +} + +inline int32_t GetFormatFromC0(int32_t format, int32_t c0_format) { + return static_cast((static_cast(format) & 0xffffffU) | + ((static_cast(c0_format) & 0xfU) << kBitThreeBytes)); +} + +inline int32_t GetFormatFromSubAndC0(int32_t primary_format, int32_t sub_format, int32_t c0_format) { + return static_cast((static_cast(primary_format) & 0xffU) | + ((static_cast(sub_format) & 0xffffU) << kBitNumOfOneByte) | + ((static_cast(c0_format) & 0xfU) << kBitThreeBytes)); +} + +inline int32_t GetPrimaryFormat(int32_t format) { + return 
static_cast(static_cast(format) & 0xffU); +} + +inline int32_t GetSubFormat(int32_t format) { + return static_cast((static_cast(format) & 0xffff00U) >> kBitNumOfOneByte); +} + +inline bool HasSubFormat(int32_t format) { + return GetSubFormat(format) > 0; +} + +inline bool HasC0Format(int32_t format) { + return ((static_cast(format) & 0xf000000U) >> kBitThreeBytes) > 0; +} + +inline int32_t GetC0Format(int32_t format) { + return static_cast((static_cast(format) & 0xf000000U) >> kBitThreeBytes); +} + +inline int64_t GetC0Value(int32_t format) { + if (!HasC0Format(format)) { + return -1; + } + return static_cast(1 << + (static_cast((static_cast(format) & 0xf000000U) >> kBitThreeBytes) - 1)); +} + +// for unknown shape op type +enum UnknowShapeOpType { + DEPEND_IN_SHAPE = 1, // op out shape get by input shape + DEPEND_CONST_VALUE = 2, // op out shape get by const op value + DEPEND_SHAPE_RANGE = 3, // op out shape get by range + DEPEND_COMPUTE = 4 // op out shape get by totally computing +}; + +struct TensorDescInfo { + Format format_ = FORMAT_RESERVED; // tbe op register support format + DataType dataType_ = DT_UNDEFINED; // tbe op register support datatype +}; + +enum DeviceType { + NPU = 0, + CPU = 1, +}; + +enum Placement { + kPlacementHost = 0, // host data addr + kPlacementDevice = 1, // device data addr + kPlacementEnd, +}; + +/// +/// @brief Get a format name from enum +/// @param format +/// @return +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY +const char_t *GetFormatName(Format format); + +class TensorTypeImpl; +struct TensorType { + explicit TensorType(DataType dt); + + TensorType(const std::initializer_list &initial_types); + + static TensorType ALL() { + return TensorType{DT_BOOL, DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, + DT_INT32, DT_INT64, DT_INT8, DT_QINT16, DT_QINT32, DT_QINT8, DT_QUINT16, + DT_QUINT8, DT_RESOURCE, DT_STRING, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8, + DT_BF16, DT_COMPLEX32}; + } + + static 
TensorType QuantifiedType() { return TensorType{DT_QINT16, DT_QINT32, DT_QINT8, DT_QUINT16, DT_QUINT8}; } + + static TensorType OrdinaryType() { + return TensorType{DT_BOOL, DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, + DT_INT32, DT_INT64, DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8, + DT_BF16, DT_COMPLEX32}; + } + + static TensorType BasicType() { + return TensorType{DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, + DT_INT32, DT_INT64, DT_INT8, DT_QINT16, DT_QINT32, DT_QINT8, + DT_QUINT16, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8, + DT_BF16, DT_COMPLEX32}; + } + + static TensorType NumberType() { + return TensorType{DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, + DT_INT8, DT_QINT32, DT_QINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8, + DT_BF16, DT_COMPLEX32}; + } + + static TensorType RealNumberType() { + return TensorType{DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_INT16, DT_INT32, DT_INT64, + DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8, DT_BF16}; + } + + static TensorType ComplexDataType() { return TensorType{DT_COMPLEX128, DT_COMPLEX64, DT_COMPLEX32}; } + + static TensorType IntegerDataType() { + return TensorType{DT_INT16, DT_INT32, DT_INT64, DT_INT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; + } + + static TensorType SignedDataType() { return TensorType{DT_INT16, DT_INT32, DT_INT64, DT_INT8}; } + + static TensorType UnsignedDataType() { return TensorType{DT_UINT16, DT_UINT32, DT_UINT64, DT_UINT8}; } + + static TensorType FloatingDataType() { return TensorType{DT_DOUBLE, DT_FLOAT, DT_FLOAT16}; } + + static TensorType IndexNumberType() { return TensorType{DT_INT32, DT_INT64}; } + + static TensorType UnaryDataType() { + return TensorType{DT_COMPLEX128, DT_COMPLEX64, DT_DOUBLE, DT_FLOAT, DT_FLOAT16, DT_BF16, DT_COMPLEX32}; + } + + static TensorType FLOAT() { return TensorType{DT_FLOAT, DT_FLOAT16, DT_BF16}; } + + std::shared_ptr 
tensor_type_impl_; +}; + +struct ListTensorType { + explicit ListTensorType(const TensorType &type) : tensor_type(type) {}; + TensorType tensor_type; +}; + +class Promote { + public: + Promote(const std::initializer_list &syms); + std::vector Syms() const; + + Promote(const Promote &other) = delete; + Promote &operator=(const Promote &other) = delete; + + Promote(Promote &&other) noexcept; + Promote &operator=(Promote &&other) noexcept; + + private: + std::shared_ptr data_; +}; +} // namespace ge + +namespace domi { +enum class ImplyType : unsigned int { + BUILDIN = 0, // Built in operator, normally executed by OME + TVM, // Compile to TVM bin file for execution + CUSTOM, // User defined calculation logic, executed by CPU + AI_CPU, // AICPU + CCE, // Cce + GELOCAL, // GE local, do node need execute by device + HCCL, // Hccl + INVALID = 0xFFFFFFFF, +}; +using char_t = ge::char_t; +using float32_t = ge::float32_t; +using float64_t = ge::float64_t; +} // namespace domi + +#endif // INC_EXTERNAL_GRAPH_TYPES_H_ diff --git a/inc/metadef/external/graph/utils/args_format_desc_utils.h b/inc/metadef/external/graph/utils/args_format_desc_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..aec7ac93e0659f1d52757b6537a6fbea7cb112ad --- /dev/null +++ b/inc/metadef/external/graph/utils/args_format_desc_utils.h @@ -0,0 +1,80 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. +* This file is a part of the CANN Open Software. +* Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). +* Please refer to the License for details. You may not use this file except in compliance with the License. +* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +* See LICENSE in the root of the software repository for the full text of the License. 
+* ===================================================================================================================*/ + +#ifndef METADEF_CXX_ARGS_FORMAT_DESC_UTILS_H +#define METADEF_CXX_ARGS_FORMAT_DESC_UTILS_H + +#include +#include + +#include "graph/ge_error_codes.h" +#include "graph/op_desc.h" +#include "register/hidden_inputs_func_registry.h" + +namespace ge { +enum class AddrType { + INPUT = 0, + OUTPUT, + WORKSPACE, + TILING, + INPUT_DESC, + OUTPUT_DESC, + FFTS_ADDR, + OVERFLOW_ADDR, + TILING_FFTS, + HIDDEN_INPUT, + TILING_CONTEXT, + OP_TYPE, + PLACEHOLDER, + CUSTOM_VALUE, + INPUT_INSTANCE, + OUTPUT_INSTANCE +}; + +enum class TilingContextSubType { TILING_CONTEXT = -1, TILING_DATA, TILING_KEY, BLOCK_DIM }; + +// i* -> ir_idx = -1,folded=false +// 对于输入输出,idx表示ir定义的idx,-1表示所有输入、所有输出,此时非动态输入、输出默认展开,动态输出要i1*这样才表示展开 +// 对于workspace -1表示个数未知,folded暂时无意义 +// 对ffts尾块非尾块地址,idx=0表示非尾块,idx=1表示尾块 +// 对于hidden input,支持多个,idx表示索引,从0开始,reserved字段表示类型(uint32) +// 对于custom value,reserved字段表示需要透传的值(uint64),其他字段无意义 +// 对于其他类型, idx和fold暂时没有意义 +struct ArgDesc { + AddrType addr_type; + int32_t ir_idx; + bool folded; + uint8_t reserved[8]; +}; + +class ArgsFormatDescUtils { + public: + static void Append(std::vector &arg_descs, AddrType type, int32_t ir_idx = -1, bool folded = false); + + static void AppendTilingContext(std::vector &arg_descs, + TilingContextSubType sub_type = TilingContextSubType::TILING_CONTEXT); + + // insert_pos为插入位置,-1表示添加到最后,0表示添加到最前面,以此类推,注意不能超过arg_descs的个数,否则会报错 + // input_cnt为插入hidden input的个数 + static graphStatus InsertHiddenInputs(std::vector &arg_descs, int32_t insert_pos, + HiddenInputsType hidden_type, size_t input_cnt = 1U); + + // insert_pos为插入位置,-1表示添加到最后,0表示添加到最前面,以此类推,注意不能超过arg_descs的个数,否则会报错 + static graphStatus InsertCustomValue(std::vector &arg_descs, int32_t insert_pos, uint64_t custom_value); + + static std::string ToString(const std::vector &arg_descs); + + // 字符串用i*这样的通配符时,返回的argDesc不会按照实际个数展开 + static graphStatus Parse(const 
std::string &str, std::vector &arg_descs); + + static std::string Serialize(const std::vector &arg_descs); +}; +} + +#endif // METADEF_CXX_ARGS_FORMAT_DESC_UTILS_H diff --git a/inc/metadef/external/graph/utils/type_utils.h b/inc/metadef/external/graph/utils/type_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..ed0ae38682b518e2c47b48d6bc068f616e20755c --- /dev/null +++ b/inc/metadef/external/graph/utils/type_utils.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/
+
+#ifndef INC_GRAPH_UTILS_TYPE_UTILS_H_
+#define INC_GRAPH_UTILS_TYPE_UTILS_H_
+
+#include  // NOTE(review): the targets of these four #include rows are missing from the patch text as shown - confirm against upstream
+#include 
+#include 
+#include 
+#include "graph/types.h"
+#include "graph/ascend_string.h"
+
+namespace ge {
+class TypeUtils {  // static conversion helpers; only declarations are visible in this header
+ public:
+  // TODO: add ATTRIBUTED_DEPRECATED to the std::string overloads below (AscendString overloads follow them).
+  static std::string DataTypeToSerialString(const DataType data_type);
+  static DataType SerialStringToDataType(const std::string &str);
+  static std::string FormatToSerialString(const Format format);
+  static Format SerialStringToFormat(const std::string &str);
+  static Format DataFormatToFormat(const std::string &str);
+  static bool GetDataTypeLength(const ge::DataType data_type, uint32_t &length);  // result is written to `length`; unit not stated here
+
+  static AscendString DataTypeToAscendString(const DataType &data_type);
+  static DataType AscendStringToDataType(const AscendString &str);
+  static AscendString FormatToAscendString(const Format &format);
+  static Format AscendStringToFormat(const AscendString &str);
+  static Format DataFormatToFormat(const AscendString &str);  // AscendString overload of the std::string variant above
+};
+} // namespace ge
+#endif // INC_GRAPH_UTILS_TYPE_UTILS_H_
diff --git a/inc/metadef/external/hcom/hcom_topo_info.h b/inc/metadef/external/hcom/hcom_topo_info.h
new file mode 100644
index 0000000000000000000000000000000000000000..36f7097f1d2066a9a436716a5cc9ac1b163d9019
--- /dev/null
+++ b/inc/metadef/external/hcom/hcom_topo_info.h
@@ -0,0 +1,58 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXTERNAL_HCOM_HCOM_TOPO_INFO_H_ +#define METADEF_CXX_INC_EXTERNAL_HCOM_HCOM_TOPO_INFO_H_ + +#include +#include +#include "ge_common/ge_api_types.h" + +namespace ge { +static constexpr uint32_t COMM_MESH = 0b1U; +static constexpr uint32_t COMM_SWITCH = (COMM_MESH << 1U); +static constexpr uint32_t COMM_RING = (COMM_MESH << 2U); +static constexpr uint32_t COMM_PAIRWISE = (COMM_MESH << 3U); +class HcomTopoInfo { + public: + enum class TopoLevel { + L0 = 0, + L1, + MAX, + }; + struct TopoLevelDesc { + uint32_t comm_sets; + uint32_t rank_size; + }; + using TopoDescs = TopoLevelDesc[static_cast(TopoLevel::MAX)]; + struct TopoInfo { + int64_t rank_size; + void *notify_handle; + TopoDescs topo_level_descs; + }; + static HcomTopoInfo &Instance(); + bool TopoInfoHasBeenSet(const char_t *group); + bool TryGetGroupTopoInfo(const char_t *group, TopoInfo &info); + Status SetGroupTopoInfo(const char_t *group, const TopoInfo &info); + Status GetGroupRankSize(const char_t *group, int64_t &rank_size); + TopoDescs *GetGroupTopoDesc(const char_t *group); + Status GetGroupNotifyHandle(const char_t *group, void *¬ify_handle); + void UnsetGroupTopoInfo(const char_t *group) { + const std::lock_guard lock(mutex_); + (void) rank_info_.erase(group); + } + private: + HcomTopoInfo() = default; + ~HcomTopoInfo() = default; + std::unordered_map rank_info_; + std::mutex mutex_; +}; +} + +#endif // METADEF_CXX_INC_EXTERNAL_HCOM_HCOM_TOPO_INFO_H_ diff --git a/inc/metadef/external/op_common/common_infershape_fns.h 
b/inc/metadef/external/op_common/common_infershape_fns.h
new file mode 100644
index 0000000000000000000000000000000000000000..35790d7921e6adb90c475ab621f2bc6c1a95fc28
--- /dev/null
+++ b/inc/metadef/external/op_common/common_infershape_fns.h
@@ -0,0 +1,27 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+/*!
+ * \file common_infershape_fns.h
+ * \brief Declarations of the opcommon InferShape4* entry points (broadcast, reduce, element-wise, per their names).
+ */
+
+#ifndef EXTERNAL_OP_COMMON_INFERSHAPE_FNS_H_
+#define EXTERNAL_OP_COMMON_INFERSHAPE_FNS_H_
+
+#include "external/exe_graph/runtime/shape.h"
+#include "external/exe_graph/runtime/infer_shape_context.h"
+
+namespace opcommon {  // each function takes a gert::InferShapeContext* and returns a ge::graphStatus; bodies live elsewhere
+ge::graphStatus InferShape4BroadcastOp(gert::InferShapeContext* context);
+ge::graphStatus InferShape4ReduceOp(gert::InferShapeContext* context);
+ge::graphStatus InferShape4ElewiseOp(gert::InferShapeContext* context);
+} // namespace opcommon
+
+#endif // EXTERNAL_OP_COMMON_INFERSHAPE_FNS_H_
diff --git a/inc/metadef/external/op_common/data_type_utils.h b/inc/metadef/external/op_common/data_type_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..9551e561a34ca5effe88e30dea80a012903726b6
--- /dev/null
+++ b/inc/metadef/external/op_common/data_type_utils.h
@@ -0,0 +1,130 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +/*! + * \file data_type_utils.h + * \brief + */ + +#ifndef EXTERNAL_OP_COMMON_DATA_TYPE_UTILS_H_ +#define EXTERNAL_OP_COMMON_DATA_TYPE_UTILS_H_ + +#include "graph/types.h" + +namespace opcommon { +inline bool IsComplexType(const ge::DataType type) { + return (type == ge::DataType::DT_COMPLEX32 || type == ge::DataType::DT_COMPLEX64 || + type == ge::DataType::DT_COMPLEX128); +} + +inline bool IsFloatingType(const ge::DataType type) +{ + return (type == ge::DataType::DT_DOUBLE || type == ge::DataType::DT_FLOAT || type == ge::DataType::DT_BF16 + || type == ge::DataType::DT_FLOAT16); +} + +inline bool IsIntegralType(const ge::DataType type) +{ + return (type == ge::DataType::DT_INT8 || type == ge::DataType::DT_INT16 || type == ge::DataType::DT_INT32 + || type == ge::DataType::DT_INT64 || type == ge::DataType::DT_UINT8 || type == ge::DataType::DT_UINT16 + || type == ge::DataType::DT_UINT32 || type == ge::DataType::DT_UINT64); +} + +inline bool IsIntegralType(const ge::DataType type, const bool include_bool) +{ + bool is_integral = IsIntegralType(type); + return include_bool ? 
(is_integral || (type == ge::DataType::DT_BOOL)) : is_integral; +} + +inline bool CanCast(const ge::DataType from, const ge::DataType to) +{ + if (IsComplexType(from) && !IsComplexType(to)) { + return false; + } + + if (IsFloatingType(from) && IsIntegralType(to, false)) { + return false; + } + + if (from != ge::DataType::DT_BOOL && to == ge::DataType::DT_BOOL) { + return false; + } + + return true; +} + +inline ge::DataType PromoteType(ge::DataType type_a, ge::DataType type_b) +{ + if (type_a < 0 || type_b < 0 || type_a >= ge::DataType::DT_MAX || type_b >= ge::DataType::DT_MAX) { + return ge::DataType::DT_UNDEFINED; + } + + if (type_a == type_b) { + return type_a; + } + + constexpr auto u1 = ge::DataType::DT_UINT8; + constexpr auto i1 = ge::DataType::DT_INT8; + constexpr auto i2 = ge::DataType::DT_INT16; + constexpr auto i4 = ge::DataType::DT_INT32; + constexpr auto i8 = ge::DataType::DT_INT64; + constexpr auto f2 = ge::DataType::DT_FLOAT16; + constexpr auto f4 = ge::DataType::DT_FLOAT; + constexpr auto f8 = ge::DataType::DT_DOUBLE; + constexpr auto c2 = ge::DataType::DT_COMPLEX32; + constexpr auto c4 = ge::DataType::DT_COMPLEX64; + constexpr auto c8 = ge::DataType::DT_COMPLEX128; + constexpr auto b1 = ge::DataType::DT_BOOL; + constexpr auto bf = ge::DataType::DT_BF16; + constexpr auto ud = ge::DataType::DT_UNDEFINED; + // @formatter:off + static constexpr ge::DataType kPromoteTypesLookup[static_cast( + ge::DataType::DT_MAX)][static_cast(ge::DataType::DT_MAX)] = { + /* f4 f2 i1 i4 u1 xx i2 u2 u4 i8 u8 f8 b1 sv d1 D1 c4 c8 q1 q2 q4 Q1 Q2 rs sr du va bf, ud t4 T1 t2 T2 c2*/ + /* f4 0 */ {f4, f4, f4, f4, f4, ud, f4, ud, ud, f4, ud, f8, f4, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, f4, ud, ud, ud, ud, ud, c4}, + /* f2 1 */ {f4, f2, f2, f2, f2, ud, f2, ud, ud, f2, ud, f8, f2, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, f4, ud, ud, ud, ud, ud, c2}, + /* i1 2 */ {f4, f2, i1, i4, i2, ud, i2, ud, ud, i8, ud, f8, i1, ud, ud, ud, c4, c8, ud, ud, ud, 
ud, ud, ud, ud, ud, ud, bf, ud, ud, ud, ud, ud, c2}, + /* i4 3 */ {f4, f2, i4, i4, i4, ud, i4, ud, ud, i8, ud, f8, i4, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, bf, ud, ud, ud, ud, ud, c2}, + /* u1 4 */ {f4, f2, i2, i4, u1, ud, i2, ud, ud, i8, ud, f8, u1, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, bf, ud, ud, ud, ud, ud, c2}, + /* xx 5 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, + /* i2 6 */ {f4, f2, i2, i4, i2, ud, i2, ud, ud, i8, ud, f8, i2, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, bf, ud, ud, ud, ud, ud, c2}, + /* u2 7 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, + /* u4 8 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, + /* i8 9 */ {f4, f2, i8, i8, i8, ud, i8, ud, ud, i8, ud, f8, i8, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, bf, ud, ud, ud, ud, ud, c2}, + /* u8 10*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, c8, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, + /* f8 11*/ {f8, f8, f8, f8, f8, ud, f8, ud, ud, f8, ud, f8, f8, ud, ud, ud, c8, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, f8, ud, ud, ud, ud, ud, c8}, + /* b1 12*/ {f4, f2, i1, i4, u1, ud, i2, ud, ud, i8, ud, f8, b1, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, bf, ud, ud, ud, ud, ud, c2}, + /* sv 13*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, + /* d1 14*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, + /* D1 15*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, 
ud, ud, ud, ud}, + /* c4 16*/ {c4, c4, c4, c4, c4, ud, c4, ud, ud, c4, ud, c8, c4, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, c4, ud, ud, ud, ud, ud, c4}, + /* c8 17*/ {c8, c8, c8, c8, c8, ud, c8, ud, ud, c8, ud, c8, c8, ud, ud, ud, c8, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, c8, ud, ud, ud, ud, ud, c8}, + /* q1 18*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, + /* q2 19*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, + /* q4 20*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, + /* Q1 21*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, + /* Q2 22*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, + /* rs 23*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, + /* sr 24*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, + /* du 25*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, + /* va 26*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, + /* bf 27*/ {f4, f4, bf, bf, bf, ud, bf, ud, ud, bf, ud, f8, bf, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, bf, ud, ud, ud, ud, ud, c4}, + /* ud 28*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, + /* t4 29*/ {ud, 
ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, + /* T1 30*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, + /* t2 31*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, + /* T2 32*/ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud}, + /* c2 33*/ {c4, c2, c2, c2, c2, ud, c2, ud, ud, c2, ud, c8, c2, ud, ud, ud, c4, c8, ud, ud, ud, ud, ud, ud, ud, ud, ud, c4, ud, ud, ud, ud, ud, c2}, + }; + // @formatter:on + return kPromoteTypesLookup[static_cast(type_a)][static_cast(type_b)]; +} +} // namespace opcommon + +#endif // EXTERNAL_OP_COMMON_DATA_TYPE_UTILS_H_ diff --git a/inc/metadef/external/op_common/op_dev.h b/inc/metadef/external/op_common/op_dev.h new file mode 100644 index 0000000000000000000000000000000000000000..780bd8e2325a9e735445b21357588b2027be77a2 --- /dev/null +++ b/inc/metadef/external/op_common/op_dev.h @@ -0,0 +1,21 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +/*! 
+ * \file op_dev.h + * \brief + */ + +#ifndef EXTERNAL_OP_COMMON_OP_DEV_H_ +#define EXTERNAL_OP_COMMON_OP_DEV_H_ + +#include "common_infershape_fns.h" +#include "op_error_code.h" + +#endif // EXTERNAL_OP_COMMON_OP_DEV_H_ diff --git a/inc/metadef/external/op_common/op_error_code.h b/inc/metadef/external/op_common/op_error_code.h new file mode 100644 index 0000000000000000000000000000000000000000..b329a1ee2fca451f76bf0b3feeb58e264dd7a2ec --- /dev/null +++ b/inc/metadef/external/op_common/op_error_code.h @@ -0,0 +1,23 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +/*! + * \file op_error_code.h + * \brief Declares the opcommon::ViewErrorCode enumeration. + */ +#ifndef EXTERNAL_OP_COMMON_ERROR_CODE_H_ +#define EXTERNAL_OP_COMMON_ERROR_CODE_H_ + +namespace opcommon { +enum ViewErrorCode { + VECTOR_INNER_ERROR = 89999 +}; +} // namespace opcommon + +#endif // EXTERNAL_OP_COMMON_ERROR_CODE_H_ diff --git a/inc/metadef/external/op_common/tiling_aicpu_task.h b/inc/metadef/external/op_common/tiling_aicpu_task.h new file mode 100644 index 0000000000000000000000000000000000000000..6d91d1f017f341f50ea17d89d80ad07adbaa1e49 --- /dev/null +++ b/inc/metadef/external/op_common/tiling_aicpu_task.h @@ -0,0 +1,22 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef EXTERNAL_OP_COMMON_TILING_AICPU_TASK_H_ +#define EXTERNAL_OP_COMMON_TILING_AICPU_TASK_H_ +#include "exe_graph/runtime/tiling_context.h" + +namespace optiling { +struct TilingAicpuTask { + gert::TilingContext *tilingContext; + const char *opType; + char reserve[64]; +}; +} // namespace optiling + +#endif \ No newline at end of file diff --git a/inc/metadef/external/platform/platform_info.h b/inc/metadef/external/platform/platform_info.h new file mode 100644 index 0000000000000000000000000000000000000000..14bcff47c4594fe7d6710afd6f748d2d9a9cdb07 --- /dev/null +++ b/inc/metadef/external/platform/platform_info.h @@ -0,0 +1,174 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef PLATFORM_INFO_H +#define PLATFORM_INFO_H + +#include +#include +#include +#include "platform_info_def.h" +#include "platform_infos_def.h" +#include "platform_infos_lite_def.h" + +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunused-private-field" +#endif + +namespace fe { +class PlatformInfoManager { + public: + PlatformInfoManager(const PlatformInfoManager &) = delete; + PlatformInfoManager &operator=(const PlatformInfoManager &) = delete; + + static PlatformInfoManager &Instance(); + static PlatformInfoManager &GeInstance(); + uint32_t InitializePlatformInfo(); + uint32_t Finalize(); + + uint32_t GetPlatformInfo(const std::string SoCVersion, + PlatformInfo &platform_info, + OptionalInfo &opti_compilation_info); + + uint32_t GetPlatformInfoWithOutSocVersion(PlatformInfo &platform_info, + OptionalInfo &opti_compilation_info); + + void SetOptionalCompilationInfo(OptionalInfo &opti_compilation_info); + + uint32_t GetPlatformInfos(const std::string SoCVersion, + PlatFormInfos &platform_info, + OptionalInfos &opti_compilation_info); + + uint32_t GetPlatformInfoWithOutSocVersion(PlatFormInfos &platform_info, + OptionalInfos &opti_compilation_info); + + void SetOptionalCompilationInfo(OptionalInfos &opti_compilation_info); + + uint32_t UpdatePlatformInfos(PlatFormInfos &platform_infos); + + uint32_t UpdatePlatformInfos(const string &soc_version, PlatFormInfos &platform_infos); + + uint32_t GetPlatformInstanceByDevice(const uint32_t &device_id, PlatFormInfos &platform_infos); + + uint32_t GetRuntimePlatformInfosByDevice(const uint32_t &device_id, PlatFormInfos &platform_infos); + + uint32_t UpdateRuntimePlatformInfosByDevice(const uint32_t &device_id, PlatFormInfos &platform_infos); + + uint32_t InitRuntimePlatformInfos(const std::string &SoCVersion); + + uint32_t GetPlatFormInfosLite(SocVersion 
soc_version, PlatFormInfosLite &platform_infos_lite); + private: + PlatformInfoManager(); + ~PlatformInfoManager(); + + uint32_t LoadIniFile(std::string ini_file_real_path); + + void Trim(std::string &str); + + uint32_t LoadConfigFile(std::string real_path); + + std::string RealPath(const std::string &path); + + std::string GetSoFilePath(); + + void ParseVersion(std::map &version_map, + std::string &soc_version, + PlatformInfo &platform_info_temp); + + void ParseSocInfo(std::map &soc_info_map, + PlatformInfo &platform_info_temp); + + void ParseCubeOfAICoreSpec(std::map &ai_core_spec_map, + PlatformInfo &platform_info_temp); + + void ParseBufferOfAICoreSpec(std::map &ai_core_spec_map, + PlatformInfo &platform_info_temp); + + void ParseUBOfAICoreSpec(std::map &ai_core_spec_map, + PlatformInfo &platform_info_temp); + + void ParseUnzipOfAICoreSpec(std::map &ai_core_spec_map, + PlatformInfo &platform_info_temp); + + void ParseAICoreSpec(std::map &ai_core_spec_map, + PlatformInfo &platform_info_temp); + + void ParseBufferOfAICoreMemoryRates(std::map &ai_core_memory_rates_map, + PlatformInfo &platform_info_temp); + + void ParseAICoreMemoryRates(std::map &ai_core_memory_rates_map, + PlatformInfo &platform_info_temp); + + void ParseUBOfAICoreMemoryRates(std::map &ai_core_memory_rates_map, + PlatformInfo &platform_info_temp); + + void ParseAICoreintrinsicDtypeMap(std::map &ai_coreintrinsic_dtype_map, + PlatformInfo &platform_info_temp); + + void ParseVectorCoreSpec(std::map &vector_core_spec_map, + PlatformInfo &platform_info_temp); + + void ParseVectorCoreMemoryRates(std::map &vector_core_memory_rates_map, + PlatformInfo &platform_info_temp); + + void ParseCPUCache(std::map &CPUCacheMap, + PlatformInfo &platform_info_temp); + + void ParseVectorCoreintrinsicDtypeMap(std::map &vector_coreintrinsic_dtype_map, + PlatformInfo &platform_info_temp); + + uint32_t ParsePlatformInfoFromStrToStruct(std::map> &content_info_map, + std::string &soc_version, + PlatformInfo 
&platform_info_temp); + + void ParseAICoreintrinsicDtypeMap(std::map &ai_coreintrinsic_dtype_map, + PlatFormInfos &platform_info_temp); + + void ParseVectorCoreintrinsicDtypeMap(std::map &vector_coreintrinsic_dtype_map, + PlatFormInfos &platform_info_temp); + + void ParseSoftwareSpec(map &software_spec_map, PlatformInfo &platform_info_temp); + + void ParsePlatformRes(const std::string &label, + std::map &platform_res_map, + PlatFormInfos &platform_info_temp); + + uint32_t ParsePlatformInfo(std::map> &content_info_map, + std::string &soc_version, + PlatFormInfos &platform_info_temp); + + uint32_t AssemblePlatformInfoVector(std::map> &content_info_map); + void FillupFixPipeInfo(PlatFormInfos &platform_infos); + + bool init_flag_; + bool runtime_init_flag_; + std::map platform_info_map_; + + OptionalInfo opti_compilation_info_; + + std::map platform_infos_map_; + + OptionalInfos opti_compilation_infos_; + + std::map device_platform_infos_map_; + + PlatFormInfos runtime_platform_infos_; + + std::map runtime_device_platform_infos_map_; + + std::array platform_infos_lite_vec_; +}; +} // namespace fe + +#ifdef __clang__ +#pragma clang diagnostic pop +#endif + +#endif diff --git a/inc/metadef/external/platform/platform_info_def.h b/inc/metadef/external/platform/platform_info_def.h new file mode 100644 index 0000000000000000000000000000000000000000..06da38de47880ab633bc97524e5c810bfc2e9eed --- /dev/null +++ b/inc/metadef/external/platform/platform_info_def.h @@ -0,0 +1,156 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef PLATFORM_INFO_DEF_H +#define PLATFORM_INFO_DEF_H + +#include +#include +#include +#include + +using std::map; +using std::vector; +using std::string; + +namespace fe { +enum MemoryType { DDR = 0, HBM }; + +enum L2Type { Cache = 0, Buff }; + +enum JitCompileMode { REUSE_BINARY = 0, COMPILE_ONLINE, AUTO }; + +typedef struct tag_str_info { + std::string aic_version; + std::string ccec_aic_version; + std::string ccec_aiv_version; + std::string is_support_ai_cpu_compiler; + std::string short_soc_version; +} StrInfo; + +typedef struct tag_so_c_info { + uint32_t ai_core_cnt; + uint32_t vector_core_cnt; + uint32_t ai_cpu_cnt; + MemoryType memory_type; + uint64_t memory_size; + L2Type l2_type; + uint64_t l2_size; + uint32_t l2PageNum; + uint32_t task_num; + int32_t arch_type; + int32_t chip_type; + tag_so_c_info() /* every member is initialised, in declaration order, so a default-constructed SoCInfo holds no indeterminate field; counts and sizes start at 0, the enums at their 0-valued enumerator, arch_type and chip_type at -1 */ + : ai_core_cnt(0), vector_core_cnt(0), ai_cpu_cnt(0), memory_type(DDR), memory_size(0), l2_type(Cache), + l2_size(0), l2PageNum(0), task_num(0), arch_type(-1), chip_type(-1) {} +} SoCInfo; + +typedef struct tag_ai_core_spec { + double cube_freq; + uint64_t cube_m_size; + uint64_t cube_n_size; + uint64_t cube_k_size; + uint64_t vec_calc_size; + uint64_t l0_a_size; + uint64_t l0_b_size; + uint64_t l0_c_size; + uint64_t l1_size; + uint64_t smask_buffer; + uint64_t ub_size; + uint64_t ubblock_size; + uint64_t ubbank_size; + uint64_t ubbank_num; + uint64_t ubburst_in_one_block; + uint64_t ubbank_group_num; + uint32_t unzip_engines; + uint32_t unzip_max_ratios; + uint32_t unzip_channels; + uint8_t unzip_is_tight; + uint8_t cube_vector_split; + uint8_t sparsity; +} AiCoreSpec; + +typedef
struct tag_ai_core_memory_rates { + double ddr_rate; + double ddr_read_rate; + double ddr_write_rate; + double l2_rate; + double l2_read_rate; + double l2_write_rate; + double l1_to_l0_a_rate; + double l1_to_l0_b_rate; + double l1_to_ub_rate; + double l0_c_to_ub_rate; + double ub_to_l2_rate; + double ub_to_ddr_rate; + double ub_to_l1_rate; +} AiCoreMemoryRates; + +typedef struct tag_vector_core_spec { + double vec_freq; + uint64_t vec_calc_size; + uint64_t smask_buffer; + uint64_t ub_size; + uint64_t ubblock_size; + uint64_t ubbank_size; + uint64_t ubbank_num; + uint64_t ubburst_in_one_block; + uint64_t ubbank_group_num; + uint64_t vector_reg_size; + uint64_t predicate_reg_size; + uint64_t address_reg_size; + uint64_t alignment_reg_size; +} VectorCoreSpec; + +typedef struct tag_vector_core_memory_rates { + double ddr_rate; + double ddr_read_rate; + double ddr_write_rate; + double l2_rate; + double l2_read_rate; + double l2_write_rate; + double ub_to_l2_rate; + double ub_to_ddr_rate; +} VectorCoreMemoryRates; + +typedef struct tag_cpu_cache { + uint32_t AICPUSyncBySW; + uint32_t TSCPUSyncBySW; +} CPUCache; + +typedef struct tag_software_spec { + bool jit_compile_default_value; + JitCompileMode jit_compile_mode; + tag_software_spec() { + jit_compile_default_value = false; + jit_compile_mode = AUTO; + } +} SoftwareSpec; + +typedef struct tag_platform_info { + StrInfo str_info; + SoCInfo soc_info; + AiCoreSpec ai_core_spec; + AiCoreMemoryRates ai_core_memory_rates; + std::map> ai_core_intrinsic_dtype_map; + VectorCoreSpec vector_core_spec; + VectorCoreMemoryRates vector_core_memory_rates; + CPUCache cpucache; + std::map> vector_core_intrinsic_dtype_map; + SoftwareSpec software_spec; +} PlatformInfo; + +typedef struct tag_optional_info { + std::string soc_version; + std::string core_type; + uint32_t ai_core_num; + std::string l1_fusion_flag; +} OptionalInfo; +} // namespace fe +#endif diff --git a/inc/metadef/external/platform/platform_infos_def.h 
b/inc/metadef/external/platform/platform_infos_def.h new file mode 100644 index 0000000000000000000000000000000000000000..b5a43e3471ef4d71cba91200033889ce195ae241 --- /dev/null +++ b/inc/metadef/external/platform/platform_infos_def.h @@ -0,0 +1,95 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef PLATFORM_INFOS_DEF_H +#define PLATFORM_INFOS_DEF_H + +#include +#include +#include +#include +#include "platform_info_def.h" + +namespace fe { +enum class LocalMemType { + L0_A = 0, + L0_B = 1, + L0_C = 2, + L1 = 3, + L2 = 4, + UB = 5, + HBM = 6, + RESERVED +}; +class PlatFormInfosImpl; +using PlatFormInfosImplPtr = std::shared_ptr; +class PlatFormInfos { + public: + bool Init(); + std::map> GetAICoreIntrinsicDtype(); + std::map> GetVectorCoreIntrinsicDtype(); + bool GetPlatformRes(const std::string &label, const std::string &key, std::string &val); + bool GetPlatformResWithLock(const std::string &label, const std::string &key, std::string &val); + bool GetPlatformRes(const std::string &label, std::map &res); + bool GetPlatformResWithLock(const std::string &label, std::map &res); + bool GetPlatformResWithLock(std::map> &res); + uint32_t GetCoreNum() const; + uint32_t GetCoreNumWithLock() const; + void GetLocalMemSize(const LocalMemType &mem_type, uint64_t &size); + void GetLocalMemBw(const 
LocalMemType &mem_type, uint64_t &bw_size); + + std::map> GetAICoreIntrinsicDtype() const; + std::map> GetVectorCoreIntrinsicDtype() const; + bool GetPlatformRes(const std::string &label, const std::string &key, std::string &val) const; + + void SetAICoreIntrinsicDtype(std::map> &intrinsic_dtypes); + void SetVectorCoreIntrinsicDtype(std::map> &intrinsic_dtypes); + void SetPlatformRes(const std::string &label, std::map &res); + void SetPlatformResWithLock(const std::string &label, std::map &res); + std::map> GetFixPipeDtypeMap(); + void SetFixPipeDtypeMap(const std::map> &fixpipe_dtype_map); + void SetCoreNumByCoreType(const std::string &core_type); + void SetCoreNum(const uint32_t &core_num); + + std::string SaveToBuffer(); + bool LoadFromBuffer(const char *buf_ptr, const size_t buf_len); + uint32_t GetCoreNumByType(const std::string &core_type); + private: + bool InitByInstance(); + uint32_t core_num_ {0}; + PlatFormInfosImplPtr platform_infos_impl_ {nullptr}; +}; + +class OptionalInfosImpl; +using OptionalInfosImplPtr = std::shared_ptr; +class OptionalInfos { + public: + bool Init(); + std::string GetSocVersion(); + std::string GetCoreType(); + uint32_t GetAICoreNum(); + std::string GetL1FusionFlag(); + std::map> GetFixPipeDtypeMap(); + + std::string GetSocVersion() const; + std::string GetCoreType() const; + uint32_t GetAICoreNum() const; + std::string GetL1FusionFlag() const; + std::map> GetFixPipeDtypeMap() const; + void SetFixPipeDtypeMap(const std::map> &fixpipe_dtype_map); + void SetSocVersion(std::string soc_version); + void SetSocVersionWithLock(std::string soc_version); + void SetCoreType(std::string core_type); + void SetAICoreNum(uint32_t ai_core_num); + void SetL1FusionFlag(std::string l1_fusion_flag); + private: + OptionalInfosImplPtr optional_infos_impl_ {nullptr}; +}; +} +#endif diff --git a/inc/metadef/external/platform/platform_infos_lite_def.h b/inc/metadef/external/platform/platform_infos_lite_def.h new file mode 100644 index 
0000000000000000000000000000000000000000..9d5f9540dfbdc94bf320c5af902581e4b4c56dd4 --- /dev/null +++ b/inc/metadef/external/platform/platform_infos_lite_def.h @@ -0,0 +1,567 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_PLATFORM_PLATFORM_INFOS_LITE_DEF_H_ +#define INC_EXTERNAL_PLATFORM_PLATFORM_INFOS_LITE_DEF_H_ +#include "platform_infos_def.h" +namespace fe { +enum class SocVersion { + /* cluster ASCEND910 */ + ASCEND910A = 0, + ASCEND910B = 1, + ASCEND910PREMIUMA = 2, + ASCEND910PROA = 3, + ASCEND910PROB = 4, + + /* cluster ASCEND910B */ + ASCEND910B1 = 5, + ASCEND910B2 = 6, + ASCEND910B3 = 7, + ASCEND910B4 = 8, + + /* cluster ASCEND910_93 */ + ASCEND910_9391 = 9, + ASCEND910_9381 = 10, + ASCEND910_9372 = 11, + + /* cluster ASCEND310P */ + ASCEND310P1 = 12, + ASCEND310P3 = 13, + + /* Others */ + ASCEND310 = 14, + ASCEND310B1 = 15, + ASCEND310M1 = 16, + ASCEND610 = 17, + ASCEND610LITE = 18, + ASCEND615 = 19, + ASCEND710VIR01 = 20, + ASCEND710VIR02 = 21, + ASCEND710VIR04 = 22, + ASCEND710VIR08 = 23, + BS9SX1AA = 24, + HI3796CV300CS = 25, + HI3796CV300ES = 26, + OPTG = 27, + SD3403 = 28, + TSNSC = 29, + TSNSE = 30, + ASCEND031 = 31, + ASCEND035 = 32, + ASCEND310B2 = 33, + ASCEND310B3 = 34, + ASCEND310B4 = 35, + ASCEND320 = 36, + ASCEND920A = 37, + ASCEND910B2C = 38, +
BS9SX2AA = 39, + BS9SX2AB = 40, + MC61AM21AB = 41, + AS31XM1X = 42, + MC61AM21AA = 43, + BS9SX1AB = 44, + BS9SX1AC = 45, + ASCEND035A = 46, + ASCEND035B = 47, + ASCEND910B4_1 = 48, + SOC_VERSION_BOTTOM +}; +const size_t TOTAL_SOC_COUNT = static_cast(SocVersion::SOC_VERSION_BOTTOM); + +const std::map SOC_VERSION_STR { + {"Ascend910A", SocVersion::ASCEND910A}, + {"Ascend910B", SocVersion::ASCEND910B}, + {"Ascend910PremiumA", SocVersion::ASCEND910PREMIUMA}, + {"Ascend910ProA", SocVersion::ASCEND910PROA}, + {"Ascend910ProB", SocVersion::ASCEND910PROB}, + {"Ascend910B1", SocVersion::ASCEND910B1}, + {"Ascend910B2", SocVersion::ASCEND910B2}, + {"Ascend910B3", SocVersion::ASCEND910B3}, + {"Ascend910B4", SocVersion::ASCEND910B4}, + {"Ascend910B4-1", SocVersion::ASCEND910B4_1}, + {"Ascend910B2C", SocVersion::ASCEND910B2C}, + {"Ascend910_9391", SocVersion::ASCEND910_9391}, + {"Ascend910_9381", SocVersion::ASCEND910_9381}, + {"Ascend910_9372", SocVersion::ASCEND910_9372}, + {"Ascend920A", SocVersion::ASCEND920A}, + {"Ascend310P1", SocVersion::ASCEND310P1}, + {"Ascend310P3", SocVersion::ASCEND310P3}, + {"Ascend310", SocVersion::ASCEND310}, + {"Ascend310B1", SocVersion::ASCEND310B1}, + {"Ascend310B2", SocVersion::ASCEND310B2}, + {"Ascend310B3", SocVersion::ASCEND310B3}, + {"Ascend310B4", SocVersion::ASCEND310B4}, + {"Ascend310M1", SocVersion::ASCEND310M1}, + {"Ascend320", SocVersion::ASCEND320}, + {"Ascend610", SocVersion::ASCEND610}, + {"Ascend610Lite", SocVersion::ASCEND610LITE}, + {"Ascend615", SocVersion::ASCEND615}, + {"Ascend710Vir01", SocVersion::ASCEND710VIR01}, + {"Ascend710Vir02", SocVersion::ASCEND710VIR02}, + {"Ascend710Vir04", SocVersion::ASCEND710VIR04}, + {"Ascend710Vir08", SocVersion::ASCEND710VIR08}, + {"BS9SX1AA", SocVersion::BS9SX1AA}, + {"Hi3796CV300CS", SocVersion::HI3796CV300CS}, + {"Hi3796CV300ES", SocVersion::HI3796CV300ES}, + {"OPTG", SocVersion::OPTG}, + {"SD3403", SocVersion::SD3403}, + {"TsnsC", SocVersion::TSNSC}, + {"TsnsE", SocVersion::TSNSE}, + 
{"Ascend031", SocVersion::ASCEND031}, + {"Ascend035", SocVersion::ASCEND035}, + {"BS9SX2AA", SocVersion::BS9SX2AA}, + {"BS9SX2AB", SocVersion::BS9SX2AB}, + {"MC61AM21AB", SocVersion::MC61AM21AB}, + {"AS31XM1X", SocVersion::AS31XM1X}, + {"MC61AM21AA", SocVersion::MC61AM21AA}, + {"BS9SX1AB", SocVersion::BS9SX1AB}, + {"BS9SX1AC", SocVersion::BS9SX1AC}, + {"Ascend035A", SocVersion::ASCEND035A}, + {"Ascend035B", SocVersion::ASCEND035B}, +}; +/* ==============================AICoreMemoryRates================================ */ +enum class AICoreMemoryRatesKey { + DDR_RATE = 0, + AI_CORE_MEMORY_RATES_KEY_END +}; + +const std::vector AI_CORE_MEMORY_RATES = { + "ddr_rate", +}; + +/* ============================AICoreSpecKey====================================== */ +enum class AICoreSpecKey { + UBBLOCK_SIZE = 0, + AI_CORE_SPEC_KEY_END +}; + +const std::vector AI_CORE_SPEC = { + "ubblock_size", +}; + +/* ============================CPUCacheKey======================================= */ +enum class CPUCacheKey { + AICPU_SYNC_BY_SW = 0, + TSCPU_SYNC_BY_SW = 0, /* NOTE(review): same value as AICPU_SYNC_BY_SW, so CPU_CACHE_KEY_END evaluates to 1 although CPU_CACHE lists two names - confirm whether 1 was intended before relying on this key as an index */ + CPU_CACHE_KEY_END +}; + +const std::vector CPU_CACHE = { + "AICPUSyncBySW", + "TSCPUSyncBySW" +}; + +/* =============================DSARandom======================================== */ +enum class DSARandomKey { + COUNTER_WORKSPACE_SIZE = 0, + DSA_RANDOM_KEY_END +}; + +const std::vector DSA_RANDOM = { + "CounterWorkspaceSize" +}; + +/* ============================SoCInfo=========================================== */ +enum class SocInfoKey { + AI_CORE_CNT = 0, + SOC_INFO_KEY_END +}; + +const std::vector SOC_INFO = { + "ai_core_cnt" +}; + +/* ============================SoftwareSpec======================================= */ +enum class SoftwareSpecKey { + JIT_COMPILE_DEFAULT_VALUE = 0, + JIT_COMPILE_MODE, + SOFTARE_SPEC_KEY_END /* NOTE(review): spelling SOFTARE (not SOFTWARE) is part of the public enumerator name; left unchanged */ +}; + +const std::vector SOFTWARE_SPEC = { + "jit_compile_default_value", + "jit_compile_mode" +}; + +/* ============================VectorCoreMemoryRates==================================== */
+enum class VectorCoreMemoryRatesKey { + DDR_RATE = 0, + VECTOR_CORE_MEMORY_RATES_KEY_END +}; + +const std::vector VECTOR_CORE_MEMORY_RATES = { + "ddr_rate" +}; + +/* ============================VectorCoreSpec====================================== */ +enum class VectorCoreSpecKey { + VEC_FREQ = 0, + VECTOR_CORE_SPEC_KEY_END +}; + +const std::vector VECTOR_CORE_SPEC = { + "vec_freq" +}; + +/* ============================Version============================================== */ +enum class VersionKey { + VERSION_KEY_END +}; + +const std::vector VERSION = { +}; + + +/* ============================IntrinsicTypeKey================================== */ +enum class IntrinsicTypeKey { + MMAD = 0, + INTRINSIC_TYPE_KEY_END +}; + +const std::vector INTRINSIC_DTYPE_MAP_KEY_NAME = { + "Intrinsic_mmad" // MMAD = 0 +}; + +const std::vector AI_CORE_INTRINSIC_DTYPE_MAP = INTRINSIC_DTYPE_MAP_KEY_NAME; + +const std::vector VECTOR_CORE_INTRINSIC_DTYPE_MAP = INTRINSIC_DTYPE_MAP_KEY_NAME; + +const std::vector ALL_LABEL_END = {}; +/* ============================IntrinsicAbility================================== */ +enum class IntrinsicAbility { + add = 0, + anti_quant_add = 1, + anti_quant_sub = 2, + b8u2 = 3, + bb = 4, + bf16 = 5, + bf162f32 = 6, + bf162s32a = 7, + bf162s32c = 8, + bf162s32f = 9, + bf162s32r = 10, + bf162s32z = 11, + bs16 = 12, + bs32 = 13, + bs8 = 14, + bu16 = 15, + bu32 = 16, + bu8 = 17, + cast = 18, + clip_relu = 19, + deq = 20, + deqs322f16 = 21, + dequant = 22, + f16 = 23, + f162f16 = 24, + f162f32 = 25, + f162s16 = 26, + f162s16a = 27, + f162s16c = 28, + f162s16f = 29, + f162s16r = 30, + f162s16z = 31, + f162s32a = 32, + f162s32c = 33, + f162s32f = 34, + f162s32r = 35, + f162s32z = 36, + f162s4 = 37, + f162s8 = 38, + f162s8a = 39, + f162s8c = 40, + f162s8f = 41, + f162s8r = 42, + f162s8z = 43, + f162u8 = 44, + f162u8a = 45, + f162u8c = 46, + f162u8f = 47, + f162u8r = 48, + f162u8z = 49, + f16f16 = 50, + f16f16f16 = 51, + f16f16s4 = 52, + f16f16s8 = 53, + f16f16u16 
= 54, + f16f16u2 = 55, + f16f32 = 56, + f16s16 = 57, + f16s32 = 58, + f16s32s32 = 59, + f16s8 = 60, + f16u16 = 61, + f16u16f16 = 62, + f16u2 = 63, + f16u8 = 64, + f32 = 65, + f322bf16 = 66, + f322bf16a = 67, + f322bf16c = 68, + f322bf16f = 69, + f322bf16r = 70, + f322bf16z = 71, + f322f16 = 72, + f322f16a = 73, + f322f16c = 74, + f322f16f = 75, + f322f16o = 76, + f322f16r = 77, + f322f16z = 78, + f322f32 = 79, + f322f32a = 80, + f322f32c = 81, + f322f32f = 82, + f322f32r = 83, + f322f32z = 84, + f322s16a = 85, + f322s16c = 86, + f322s16f = 87, + f322s16r = 88, + f322s16z = 89, + f322s32a = 90, + f322s32c = 91, + f322s32f = 92, + f322s32r = 93, + f322s32z = 94, + f322s4 = 95, + f322s64a = 96, + f322s64c = 97, + f322s64f = 98, + f322s64r = 99, + f322s64z = 100, + f322s8 = 101, + f322u8 = 102, + f32f16 = 103, + f32f16f16 = 104, + f32f32 = 105, + f32f32f32 = 106, + f32s16 = 107, + f32s32 = 108, + f32u32 = 109, + f32u32f32 = 110, + float16 = 111, + float32 = 112, + h322f32 = 113, + int16 = 114, + int32 = 115, + int64 = 116, + int8 = 117, + normal_relu = 118, + nz2nd = 119, + post_act = 120, + post_eltwise = 121, + post_quant = 122, + post_transform = 123, + pre_act = 124, + pre_conv = 125, + quant = 126, + requant = 127, + s16 = 128, + s162f16 = 129, + s162f16a = 130, + s162f16c = 131, + s162f16f = 132, + s162f16r = 133, + s162f16z = 134, + s162f32 = 135, + s162s16 = 136, + s162s32 = 137, + s162s8 = 138, + s162u32 = 139, + s162u8 = 140, + s16f16 = 141, + s16f32 = 142, + s16s16 = 143, + s16s16u16 = 144, + s16s32 = 145, + s16s48 = 146, + s16s64 = 147, + s16s8 = 148, + s16u16 = 149, + s16u16s16 = 150, + s16u16s8 = 151, + s16u32 = 152, + s16u8 = 153, + s24s16 = 154, + s24s8 = 155, + s24u16 = 156, + s24u8 = 157, + s32 = 158, + s322f16 = 159, + s322f32 = 160, + s322f32a = 161, + s322f32c = 162, + s322f32f = 163, + s322f32r = 164, + s322f32z = 165, + s322s16 = 166, + s322s4 = 167, + s322s64 = 168, + s322s8 = 169, + s322u16 = 170, + s322u8 = 171, + s32f32 = 172, + s32s16 = 173, 
+ s32s32 = 174, + s32s4s4 = 175, + s32s8s8 = 176, + s32u16 = 177, + s32u32 = 178, + s32u32s32 = 179, + s32u8 = 180, + s32u8s8 = 181, + s32u8u2 = 182, + s32u8u8 = 183, + s4 = 184, + s48s16 = 185, + s48s32 = 186, + s48u16 = 187, + s642f32a = 188, + s642f32c = 189, + s642f32f = 190, + s642f32r = 191, + s642f32z = 192, + s642s32 = 193, + s64s32 = 194, + s8 = 195, + s82f16 = 196, + s82s16 = 197, + s82s32 = 198, + s82s8 = 199, + s8f16 = 200, + s8f16f16 = 201, + s8s16 = 202, + s8s24 = 203, + s8s32 = 204, + s8s48 = 205, + s8s8 = 206, + s8s8u8 = 207, + s8u16 = 208, + scalar_relu = 209, + sub = 210, + u16 = 211, + u162s32 = 212, + u162u32 = 213, + u162u8 = 214, + u16s16 = 215, + u16s48 = 216, + u16s64 = 217, + u16u16 = 218, + u16u16u16 = 219, + u16u16u8 = 220, + u16u32 = 221, + u16u8 = 222, + u32 = 223, + u322s16 = 224, + u322u16 = 225, + u322u8 = 226, + u32s16 = 227, + u32s32 = 228, + u32u16 = 229, + u32u32 = 230, + u32u32u32 = 231, + u32u8 = 232, + u32u8u8 = 233, + u8 = 234, + u82f16 = 235, + u82s16 = 236, + u82s32 = 237, + u82u16 = 238, + u82u32 = 239, + u8f16 = 240, + u8f16f16 = 241, + u8s16 = 242, + u8s24 = 243, + u8s48 = 244, + u8s8 = 245, + u8u16 = 246, + u8u32 = 247, + u8u8 = 248, + u8u8u8 = 249, + uint16 = 250, + uint32 = 251, + uint64 = 252, + uint8 = 253, + vdeqs162b8 = 254, + vector_relu = 255 +}; + +/* ============================PlatformLabel================================== */ +enum class PlatformLabel { + ENUM_AI_CORE_MEMORY_RATES = 0, /* label string: AICoreMemoryRates */ + ENUM_AI_CORE_SPEC = 1, /* label string: AICoreSpec */ + ENUM_CPU_CACHE = 2, /* label string: CPUCache */ + ENUM_DSA_RANDOM = 3, /* label string: DSARandom */ + ENUM_SOC_INFO = 4, /* label string: SoCInfo */ + ENUM_SOFTWARE_SPEC = 5, /* label string: SoftwareSpec */ + ENUM_VECTOR_CORE_MEMORY_RATES = 6, /* label string: VectorCoreMemoryRates */ + ENUM_VECTOR_CORE_SPEC = 7, /* label string: VectorCoreSpec */ + ENUM_VERSION = 8, /* label string: version */ + + /* WARNING: 
ENUM_AI_CORE_INTRINSIC_DTYPE_MAP must be the first enum among all + * intrinsic dtype maps. */ + ENUM_AI_CORE_INTRINSIC_DTYPE_MAP = 9, /* label string: AICoreintrinsicDtypeMap */ + ENUM_VECTOR_CORE_INTRINSIC_DTYPE_MAP = 10, /* label string: VectorCoreintrinsicDtypeMap */ + ENUM_ALL_LABEL_END +}; +constexpr size_t NORMAL_LABEL_SIZE = static_cast(PlatformLabel::ENUM_AI_CORE_INTRINSIC_DTYPE_MAP); +constexpr size_t ALL_LABEL_SIZE = static_cast(PlatformLabel::ENUM_ALL_LABEL_END); + + +struct LabelAndKey { + LabelAndKey(const string &label_name, const std::vector &key_name_vector) { + pair = std::make_pair(label_name, key_name_vector); + } + + std::pair> pair; +}; + +enum { COUNTER_BASE = __COUNTER__ }; +template +LabelAndKey MakeQualifiedPair(const string &label_name, const std::vector &key_name_vector) { + static_assert((COUNTER - COUNTER_BASE - 1) == ENUM, "ENUM value is not in the correct sequence"); + return LabelAndKey(label_name, key_name_vector); +} + +#define GEN_LABEL(label_name, key_names_vector) \ + GEN_LABEL_UNIQ_HELPER(__COUNTER__, PlatformLabel::ENUM_##key_names_vector, label_name, key_names_vector) + +#define GEN_LABEL_UNIQ_HELPER(ctr, enum_label, label_name, key_names_vector) \ + MakeQualifiedPair(enum_label)>(label_name, key_names_vector) + +const std::vector LABEL_AND_KEYS = { + GEN_LABEL("AICoreMemoryRates", AI_CORE_MEMORY_RATES), + GEN_LABEL("AICoreSpec", AI_CORE_SPEC), + GEN_LABEL("CPUCache", CPU_CACHE), + GEN_LABEL("DSARandom", DSA_RANDOM), + GEN_LABEL("SoCInfo", SOC_INFO), + GEN_LABEL("SoftwareSpec", SOFTWARE_SPEC), + GEN_LABEL("VectorCoreMemoryRates", VECTOR_CORE_MEMORY_RATES), + GEN_LABEL("VectorCoreSpec", VECTOR_CORE_SPEC), + GEN_LABEL("version", VERSION), + GEN_LABEL("AICoreintrinsicDtypeMap", AI_CORE_INTRINSIC_DTYPE_MAP), + GEN_LABEL("VectorCoreintrinsicDtypeMap", VECTOR_CORE_INTRINSIC_DTYPE_MAP), + GEN_LABEL("ALL_LABEL_END", ALL_LABEL_END) +}; + +/* ============================key name string definition end================================== 
*/ +class PlatFormInfosLiteImpl; +using PlatFormInfosLiteImplPtr = std::shared_ptr; +class PlatFormInfosLite { + public: + bool InitPlatFormInfosLite(SocVersion soc_version, PlatFormInfos& old_platform_infos); + bool GetPlatformRes(PlatformLabel label, uint64_t key, uint64_t &val) const; + const std::vector &GetPlatformRes(PlatformLabel label) const; + bool CheckIntrinsicSupport(IntrinsicTypeKey intrinsic_type, IntrinsicAbility intrinsic_ability) const; + private: + PlatFormInfosLiteImplPtr platform_infos_lite_impl_ {nullptr}; +}; +} +#endif // INC_EXTERNAL_PLATFORM_PLATFORM_INFOS_LITE_DEF_H_ diff --git a/inc/metadef/external/register/device_op_impl_registry.h b/inc/metadef/external/register/device_op_impl_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..b94a99cdef0bac3a7169ba34f5cc92864ba42e77 --- /dev/null +++ b/inc/metadef/external/register/device_op_impl_registry.h @@ -0,0 +1,43 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +/*! 
+ * \file device_op_impl_registry.h + * \brief + */ + +#ifndef REGISTER_DEVICE_OP_IMPL_REGISTRY_H +#define REGISTER_DEVICE_OP_IMPL_REGISTRY_H +#include +#include "graph/compiler_def.h" +#include "exe_graph/runtime/tiling_context.h" + +namespace optiling { +using SinkTilingFunc = std::function; + +class DeviceOpImplRegisterImpl; +class DeviceOpImplRegister { + public: + DeviceOpImplRegister(const char *opType); + ~DeviceOpImplRegister(); + DeviceOpImplRegister(DeviceOpImplRegister &&other) noexcept; + DeviceOpImplRegister(const DeviceOpImplRegister &other); + DeviceOpImplRegister &operator=(const DeviceOpImplRegister &) = delete; + DeviceOpImplRegister &operator=(DeviceOpImplRegister &&) = delete; + DeviceOpImplRegister &Tiling(SinkTilingFunc func); + + private: + std::unique_ptr impl_; +}; +} // namespace optiling + +#define DEVICE_IMPL_OP_OPTILING(optype) \ + static optiling::DeviceOpImplRegister VAR_UNUSED g_deviceOpImplRegister##optype = \ + optiling::DeviceOpImplRegister(#optype) +#endif \ No newline at end of file diff --git a/inc/metadef/external/register/ffts_node_calculater_registry.h b/inc/metadef/external/register/ffts_node_calculater_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..4e8b2ec277ad297037e7ba72ac8ce352a53e00dd --- /dev/null +++ b/inc/metadef/external/register/ffts_node_calculater_registry.h @@ -0,0 +1,65 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef AIR_CXX_RUNTIME_V2_LOWERING_FFTS_NODE_CALCULATER_REGISTRY_H_ +#define AIR_CXX_RUNTIME_V2_LOWERING_FFTS_NODE_CALCULATER_REGISTRY_H_ +#include +#include +#include +#include "graph/node.h" +#include "exe_graph/lowering/value_holder.h" +#include "exe_graph/lowering/lowering_global_data.h" + +namespace gert { +struct NodeMemPara { + size_t size {0}; + void *dev_addr {nullptr}; + void *host_addr {nullptr}; +}; + +class FFTSNodeCalculaterRegistry { + public: + /* + * Output param: + * 1.size_t total_size -- memory size needed by the node + * 2.size_t pre_data_size -- size of the node's pre-processing data + * 3.std::unique_ptr pre_data_ptr -- node pre-processing data buffer of size pre_data_size; the framework will + * copy pre_data to newly allocated memory + */ + using NodeCalculater = ge::graphStatus (*)(const ge::NodePtr &node, const LoweringGlobalData *global_data, + size_t &total_size, size_t &pre_data_size, std::unique_ptr &pre_data_ptr); + static FFTSNodeCalculaterRegistry &GetInstance(); + NodeCalculater FindNodeCalculater(const std::string &func_name); + void Register(const std::string &func_name, const NodeCalculater func); + + private: + std::unordered_map names_to_calculater_; +}; + +class FFTSNodeCalculaterRegister { + public: + FFTSNodeCalculaterRegister(const std::string &func_name, FFTSNodeCalculaterRegistry::NodeCalculater func) noexcept; +}; +} // namespace gert + +#ifdef __GNUC__ +#define ATTRIBUTE_USED __attribute__((used)) +#else +#define ATTRIBUTE_USED +#endif + +#define GERT_REGISTER_FFTS_NODE_CALCULATER_COUNTER2(type, func, counter) \ + static const gert::FFTSNodeCalculaterRegister g_register_node_calculater_##counter ATTRIBUTE_USED = \ + gert::FFTSNodeCalculaterRegister(type, func) +#define GERT_REGISTER_FFTS_NODE_CALCULATER_COUNTER(type, func, counter) \ + GERT_REGISTER_FFTS_NODE_CALCULATER_COUNTER2(type, func, counter) +#define 
FFTS_REGISTER_NODE_CALCULATER(type, func) \ + GERT_REGISTER_FFTS_NODE_CALCULATER_COUNTER(type, func, __COUNTER__) + +#endif // AIR_CXX_RUNTIME_V2_LOWERING_FFTS_NODE_CALCULATER_REGISTRY_H_ diff --git a/inc/metadef/external/register/hidden_input_func_registry.h b/inc/metadef/external/register/hidden_input_func_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..5452f9b0ec92be56c002e683b26d2ad53666978d --- /dev/null +++ b/inc/metadef/external/register/hidden_input_func_registry.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_REGISTER_HIDDEN_INPUT_FUNC_REGISTRY_H_ +#define INC_EXTERNAL_REGISTER_HIDDEN_INPUT_FUNC_REGISTRY_H_ + +#include +#include +#include "graph/op_desc.h" +namespace ge { +enum class HiddenInputType : uint32_t { HCOM }; + +using GetHiddenAddr = ge::graphStatus (*)(const ge::OpDescPtr &op_desc, void *&addr); +class HiddenInputFuncRegistry { + public: + static HiddenInputFuncRegistry &GetInstance(); + GetHiddenAddr FindHiddenInputFunc(const HiddenInputType input_type); + void Register(const HiddenInputType input_type, const GetHiddenAddr func); + + private: + std::map type_to_funcs_; +}; + +class HiddenInputFuncRegister { + public: + HiddenInputFuncRegister(const HiddenInputType input_type, const GetHiddenAddr func); +}; +} // namespace ge + +#ifdef __GNUC__ +#define ATTRIBUTE_USED __attribute__((used)) +#else +#define ATTRIBUTE_USED +#endif +#define REG_HIDDEN_INPUT_FUNC(type, func) REG_HIDDEN_INPUT_FUNC_UNIQ_HELPER(type, func, __COUNTER__) +#define REG_HIDDEN_INPUT_FUNC_UNIQ_HELPER(type, func, counter) REG_HIDDEN_INPUT_FUNC_UNIQ(type, func, counter) +#define REG_HIDDEN_INPUT_FUNC_UNIQ(type, func, counter) \ + static ::ge::HiddenInputFuncRegister register_hidden_func_##counter ATTRIBUTE_USED = \ + ge::HiddenInputFuncRegister(type, func) + +#endif // INC_EXTERNAL_REGISTER_HIDDEN_INPUT_FUNC_REGISTRY_H_ diff --git a/inc/metadef/external/register/hidden_inputs_func_registry.h b/inc/metadef/external/register/hidden_inputs_func_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..8a4e43639ee68a631d77050798d4ca04e4e7f553 --- /dev/null +++ b/inc/metadef/external/register/hidden_inputs_func_registry.h @@ -0,0 +1,47 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). 
+ * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_REGISTER_HIDDEN_INPUTS_FUNC_REGISTRY_H_ +#define INC_EXTERNAL_REGISTER_HIDDEN_INPUTS_FUNC_REGISTRY_H_ + +#include +#include +#include "graph/op_desc.h" +namespace ge { +enum class HiddenInputsType : uint32_t { HCOM }; + +using GetHiddenAddrs = ge::graphStatus (*)(const ge::OpDescPtr &op_desc, std::vector &addr); +class HiddenInputsFuncRegistry { + public: + static HiddenInputsFuncRegistry &GetInstance(); + GetHiddenAddrs FindHiddenInputsFunc(const HiddenInputsType input_type); + void Register(const HiddenInputsType input_type, const GetHiddenAddrs func); + + private: + std::map type_to_funcs_; +}; + +class HiddenInputsFuncRegister { + public: + HiddenInputsFuncRegister(const HiddenInputsType input_type, const GetHiddenAddrs func); +}; +} // namespace ge + +#ifdef __GNUC__ +#define ATTRIBUTE_USED __attribute__((used)) +#else +#define ATTRIBUTE_USED +#endif +#define REG_HIDDEN_INPUTS_FUNC(type, func) REG_HIDDEN_INPUTS_FUNC_UNIQ_HELPER(type, func, __COUNTER__) +#define REG_HIDDEN_INPUTS_FUNC_UNIQ_HELPER(type, func, counter) REG_HIDDEN_INPUTS_FUNC_UNIQ(type, func, counter) +#define REG_HIDDEN_INPUTS_FUNC_UNIQ(type, func, counter) \ + static ::ge::HiddenInputsFuncRegister register_hidden_func_##counter ATTRIBUTE_USED = \ + ge::HiddenInputsFuncRegister(type, func) + +#endif // INC_EXTERNAL_REGISTER_HIDDEN_INPUTS_FUNC_REGISTRY_H_ diff --git a/inc/metadef/external/register/op_binary_resource_manager.h 
b/inc/metadef/external/register/op_binary_resource_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..d01ed9eb2ff0e091856fd57a12b43a010a516d01 --- /dev/null +++ b/inc/metadef/external/register/op_binary_resource_manager.h @@ -0,0 +1,89 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#include +#include +#include +#include +#include "nlohmann/json.hpp" +#include "graph/ascend_string.h" +#include "graph/ge_error_codes.h" + +#ifndef INC_EXTERNAL_REGISTER_OP_BINARY_RESOURCE_MANAGER_H_ +#define INC_EXTERNAL_REGISTER_OP_BINARY_RESOURCE_MANAGER_H_ + +namespace nnopbase { +// Binary content; describes a binary .o file held in memory +struct Binary { + const uint8_t *content; // start pointer of the binary content + uint32_t len; // length of the binary +}; + +class OpBinaryResourceManager { +public: + static OpBinaryResourceManager &GetInstance() + { + static OpBinaryResourceManager manager; + return manager; + } + + ~OpBinaryResourceManager() = default; + + void AddOpFuncHandle(const ge::AscendString &opType, const std::vector &opResourceHandle); + ge::graphStatus AddBinary(const ge::AscendString &opType, + const std::vector> &opBinary); + ge::graphStatus AddRuntimeKB(const ge::AscendString &opType, + const std::vector> &opRuntimeKb); + +// Resource getters +public: + // Get the description info of all operators + const std::map &GetAllOpBinaryDesc() const; + + // Get the description info of a single operator + 
ge::graphStatus GetOpBinaryDesc(const ge::AscendString &opType, nlohmann::json &binDesc) const; + + // Look up the .json/.o info by json file path (present in the operator json) + ge::graphStatus GetOpBinaryDescByPath(const ge::AscendString &jsonFilePath, + std::tuple &binInfo) const; + + // Look up the .json/.o info by simplifiedKey (present in the operator json) + ge::graphStatus GetOpBinaryDescByKey(const ge::AscendString &simplifiedKey, + std::tuple &binInfo) const; + + // Binary knowledge base + ge::graphStatus GetOpRuntimeKB(const ge::AscendString &opType, std::vector &kbList) const; + +private: + OpBinaryResourceManager() = default; // singleton: objects must not be created externally + OpBinaryResourceManager &operator=(const OpBinaryResourceManager &) = delete; // copying is forbidden + OpBinaryResourceManager &operator=(OpBinaryResourceManager &&) = delete; + OpBinaryResourceManager(const OpBinaryResourceManager &) = delete; + OpBinaryResourceManager(OpBinaryResourceManager &&) = delete; + + mutable std::recursive_mutex mutex_; + + // Binary description info: opType -> xxx.json + std::map opBinaryDesc_; + + // Binary jsonPath: simplifiedKey -> jsonPath; several simplifiedKeys may map to the same jsonPath + std::map keyToPath_; + + // Binary info: jsonPath -> xxx.json, xxx.o + std::map> pathToBinary_; + + // Binary knowledge base: opType -> xxx1.json, xxx2.json + std::map> runtimeKb_; + + // Pointers to global variables such as infershape/op tiling/runtime kb parser; registration-type resources that only need to be held + std::map> resourceHandle_; +}; +} // namespace nnopbase + +#endif // INC_EXTERNAL_REGISTER_OP_BINARY_RESOURCE_MANAGER_H_ diff --git a/inc/metadef/external/register/op_check.h b/inc/metadef/external/register/op_check.h new file mode 100644 index 0000000000000000000000000000000000000000..efc1d26f6d8af302419d03fcc76704bd6681792d --- /dev/null +++ b/inc/metadef/external/register/op_check.h @@ -0,0 +1,31 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_REGISTER_OP_CHECK_H_ +#define INC_EXTERNAL_REGISTER_OP_CHECK_H_ + +#include "op_check_register.h" +namespace optiling { + +#define REG_CHECK_SUPPORT(op_type, func) \ + static OpCheckFuncHelper op_check_registry_##op_type##_check_supported(FUNC_CHECK_SUPPORTED, #op_type, func) +#define REG_OP_SELECT_FORMAT(op_type, func) \ + static OpCheckFuncHelper op_check_registry_##op_type##_op_select_format(FUNC_OP_SELECT_FORMAT, #op_type, func) +#define REG_OP_SUPPORT_INFO(op_type, func) \ + static OpCheckFuncHelper op_check_registry_##op_type##_get_op_support_info(FUNC_GET_OP_SUPPORT_INFO, #op_type, func) +#define REG_OP_SPEC_INFO(op_type, func) \ + static OpCheckFuncHelper op_check_registry_##op_type##_get_specific_info(FUNC_GET_SPECIFIC_INFO, #op_type, func) + +#define REG_OP_PARAM_GENERALIZE(op_type, generalize_func) \ + static OpCheckFuncHelper op_check_generalize_registry_##op_type(#op_type, generalize_func) +// Pastes both op_type and soc_version into the variable name so each (op_type, soc_version) registration is a distinct static object. +#define REG_REPLAY_FUNC(op_type, soc_version, func) \ + static ReplayFuncHelper op_replay_registry_##op_type##_##soc_version(#op_type, #soc_version, func) +} // end of namespace optiling +#endif // INC_EXTERNAL_REGISTER_OP_CHECK_H_ \ No newline at end of file diff --git a/inc/metadef/external/register/op_compile_info_base.h b/inc/metadef/external/register/op_compile_info_base.h new file mode 100644 index 0000000000000000000000000000000000000000..3becaeceacb55956563ad0b090ec28f083df129a --- /dev/null +++ b/inc/metadef/external/register/op_compile_info_base.h @@ -0,0 +1,25 @@ +/* Copyright (c) 2024 Huawei Technologies Co., 
Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_REGISTER_OP_COMPILE_INFO_BASE_H_ +#define INC_EXTERNAL_REGISTER_OP_COMPILE_INFO_BASE_H_ + +#include + +namespace optiling { +class CompileInfoBase; +using CompileInfoPtr = std::shared_ptr; + +class CompileInfoBase { +public: + CompileInfoBase() {} + virtual ~CompileInfoBase() {} // virtual: lets derived objects be destroyed through a CompileInfoBase pointer +}; +} // namespace optiling +#endif // INC_EXTERNAL_REGISTER_OP_COMPILE_INFO_BASE_H_ diff --git a/inc/metadef/external/register/op_ct_impl_kernel_registry.h b/inc/metadef/external/register/op_ct_impl_kernel_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..4f20b38ff15f2b0348a07ca4b426ddd9a3f01607 --- /dev/null +++ b/inc/metadef/external/register/op_ct_impl_kernel_registry.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_REGISTER_OP_CT_IMPL_KERNEL_REGISTRY_H_ +#define INC_EXTERNAL_REGISTER_OP_CT_IMPL_KERNEL_REGISTRY_H_ +#include "exe_graph/runtime/base_type.h" +#include "graph/ascend_string.h" +#include "exe_graph/runtime/exe_res_generation_context.h" + +#define OP_CT_IMPL_MAIM_VERSION 1 // NOTE(review): "MAIM" looks like a typo of "MAIN"; kept because the macro name is public + +namespace gert { +struct OpCtImplKernelRegistry { + using OpType = ge::AscendString; + using OpCalcParamKernelFunc = UINT32 (*)(ExeResGenerationContext *context); + using OpGenTaskKernelFunc = UINT32 (*)(const ExeResGenerationContext *context, + std::vector> &tasks); + struct OpCtImplFunctions { + uint32_t st_size = sizeof(OpCtImplFunctions); // size of this struct as compiled + uint32_t version = OP_CT_IMPL_MAIM_VERSION; + OpCalcParamKernelFunc calc_op_param = nullptr; + OpGenTaskKernelFunc gen_task = nullptr; + }; +}; +} // namespace gert +#endif // INC_EXTERNAL_REGISTER_OP_CT_IMPL_KERNEL_REGISTRY_H_ diff --git a/inc/metadef/external/register/op_ct_impl_registry.h b/inc/metadef/external/register/op_ct_impl_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..4646af531cca022c88da00010491b9c38ae215c6 --- /dev/null +++ b/inc/metadef/external/register/op_ct_impl_registry.h @@ -0,0 +1,60 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_REGISTER_OP_CT_IMPL_REGISTRY_H_ +#define INC_EXTERNAL_REGISTER_OP_CT_IMPL_REGISTRY_H_ +#include +#include +#include "graph/types.h" +#include "graph/compiler_def.h" +#include "register/op_ct_impl_kernel_registry.h" + +namespace gert { +class OpCtImplRegistry : public OpCtImplKernelRegistry { + public: + static OpCtImplRegistry &GetInstance(); + OpCtImplFunctions &CreateOrGetOpImpl(const ge::char_t *op_type); + const OpCtImplFunctions *GetOpImpl(const ge::char_t *op_type) const; + const std::map &GetAllTypesToImpl() const; + std::map &GetAllTypesToImpl(); + + private: + std::map types_to_impl_; + uint8_t reserved_[40] = {0U}; // Reserved field, 32+8, do not directly use when only 8-byte left +}; + +class OpCtImplRegisterV2Impl { + public: + OpCtImplRegistry::OpType op_type; + OpCtImplRegistry::OpCtImplFunctions functions; +}; + +class OpCtImplRegisterV2 { + public: + explicit OpCtImplRegisterV2(const ge::char_t *op_type); + OpCtImplRegisterV2(OpCtImplRegisterV2 &&register_data) noexcept; + OpCtImplRegisterV2(const OpCtImplRegisterV2 &register_data); + OpCtImplRegisterV2 &operator=(const OpCtImplRegisterV2 &) = delete; + OpCtImplRegisterV2 &operator=(OpCtImplRegisterV2 &&) = delete; + ~OpCtImplRegisterV2(); + + public: + OpCtImplRegisterV2 &CalcOpParam(OpCtImplKernelRegistry::OpCalcParamKernelFunc calc_op_param_func); + OpCtImplRegisterV2 &GenerateTask(OpCtImplKernelRegistry::OpGenTaskKernelFunc gen_task_func); + private: + std::unique_ptr impl_; +}; +} // namespace gert + +#define IMPL_OP_CT_COUNTER(op_type, name, counter) \ + static gert::OpCtImplRegisterV2 VAR_UNUSED name##counter = gert::OpCtImplRegisterV2(#op_type) +#define IMPL_OP_CT_COUNTER_NUMBER(op_type, name, counter) IMPL_OP_CT_COUNTER(op_type, name, counter) +#define IMPL_OP_CT(op_type) IMPL_OP_CT_COUNTER_NUMBER(op_type, op_impl_register_##op_type, __COUNTER__) + 
+#endif // INC_EXTERNAL_REGISTER_OP_CT_IMPL_REGISTRY_H_ diff --git a/inc/metadef/external/register/op_ct_impl_registry_api.h b/inc/metadef/external/register/op_ct_impl_registry_api.h new file mode 100644 index 0000000000000000000000000000000000000000..fd47adfa95f27c600ca2e97370bfd9b605c981b6 --- /dev/null +++ b/inc/metadef/external/register/op_ct_impl_registry_api.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_REGISTER_OP_CT_IMPL_REGISTRY_API_H_ +#define INC_EXTERNAL_REGISTER_OP_CT_IMPL_REGISTRY_API_H_ + +#include +#include "op_ct_impl_kernel_registry.h" + +struct TypesToCtImpl { + const char *op_type; + gert::OpCtImplKernelRegistry::OpCtImplFunctions funcs; +}; + +#ifdef __cplusplus +extern "C" { +#endif + +#ifdef __GNUC__ +#define METADEF_FUNC_VISIBILITY __attribute__((visibility("default"))) +#else +#define METADEF_FUNC_VISIBILITY +#endif + +METADEF_FUNC_VISIBILITY size_t GetRegisteredOpCtNum(void); +METADEF_FUNC_VISIBILITY int32_t GetOpCtImplFunctions(TypesToCtImpl *impl, size_t impl_num); + +#ifdef __cplusplus +} +#endif + +#endif // INC_EXTERNAL_REGISTER_OP_CT_IMPL_REGISTRY_API_H_ diff --git a/inc/metadef/external/register/op_def.h b/inc/metadef/external/register/op_def.h new file mode 100644 index 
0000000000000000000000000000000000000000..abf8f9f89850aa35dda3e2060b2dc6f053b7e494 --- /dev/null +++ b/inc/metadef/external/register/op_def.h @@ -0,0 +1,491 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef OP_DEF_H +#define OP_DEF_H + +#include +#include +#include +#include "register/op_impl_registry.h" +#include "graph/operator_reg.h" + +namespace optiling { +#define FUNC_CHECK_SUPPORTED "check_supported" +#define FUNC_OP_SELECT_FORMAT "op_select_format" +#define FUNC_GET_OP_SUPPORT_INFO "get_op_support_info" +#define FUNC_GET_SPECIFIC_INFO "get_op_specific_info" + +using OP_CHECK_FUNC = ge::graphStatus (*)(const ge::Operator &op, ge::AscendString &result); + +using PARAM_GENERALIZE_FUNC = ge::graphStatus (*)(const ge::Operator &op, const ge::AscendString &generalize_config, + ge::AscendString &generalized_op_params); + +class OpCheckFuncHelper { +public: + OpCheckFuncHelper(const ge::AscendString &check_type, const ge::AscendString &op_type, OP_CHECK_FUNC func); + + OpCheckFuncHelper(const ge::AscendString &op_type, PARAM_GENERALIZE_FUNC func); +}; +} + +namespace ops { +class AclnnOpGenerator; +class Generator; +class OpProtoGenerator; +class GeneratorFactory; +class CfgGenerator; +class OpParamTrunk; + +enum Option { IGNORE = 0, OPTIONAL = 1, REQUIRED = 2, DYNAMIC = 3, VIRTUAL = 4 }; 
+ +enum class FormatCheckOption : uint32_t { + DEFAULT = 0, + STRICT = 1, + MAX +}; + +enum class DependScope : uint32_t { + ALL = 0, + TILING = 1, + INVALID_SCOPE +}; + +enum class FollowType : uint32_t { + ALL = 0, + DTYPE = 1, + FORMAT = 2, + SHAPE = 3, + INVALID_TYPE +}; + +enum class AttrDataType { + ATTR_DT_BOOL = 0, + ATTR_DT_FLOAT = 1, + ATTR_DT_INT = 2, + ATTR_DT_STR = 3, + ATTR_DT_LIST_BOOL = 4, + ATTR_DT_LIST_FLOAT = 5, + ATTR_DT_LIST_INT = 6, + ATTR_DT_LIST_LIST_INT = 7, + ATTR_DT_MAX +}; + +enum class InitValueType : uint32_t { + INIT_VALUE_UINT64_T = 0, + INIT_VALUE_DEFAULT = static_cast(-1), +}; + +enum class CommentSection : uint32_t { + CATEGORY = 0, + BRIEF = 1, + CONSTRAINTS = 2, + RESTRICTIONS = 3, + SEE = 4, + THIRDPARTYFWKCOMPAT = 5, + SECTION_MAX +}; + +enum class ScalarType : uint32_t { + UINT64 = 0, + INT64 = 1, + UINT32 = 2, + INT32 = 3, + UINT16 = 4, + INT16 = 5, + UINT8 = 6, + INT8 = 7, + FLOAT32 = 8, + FLOAT16 = 9, + INVALID_DTYPE = static_cast(-1), +}; + +union ScalarNum { + uint64_t value_u64; + int64_t value_i64; + float value_f32; + ScalarNum() : value_u64(0) {} + explicit ScalarNum(uint64_t value) : value_u64(value) {} + explicit ScalarNum(int64_t value) : value_i64(value) {} + explicit ScalarNum(float value) : value_f32(value) {} +}; + +using InitValueNum = ScalarNum; + +struct ScalarVar { + ScalarType scalar_type; + ScalarNum scalar_num; + ScalarVar() : scalar_type(ScalarType::INVALID_DTYPE) {} + ScalarVar(ScalarType type, uint64_t num) : scalar_type(type), scalar_num(num) { + if (type == ScalarType::FLOAT32 || type == ScalarType::FLOAT16) { + scalar_num = ScalarNum(static_cast(num)); + } + } + ScalarVar(ScalarType type, int64_t num) : scalar_type(type), scalar_num(num) { + if (type == ScalarType::FLOAT32 || type == ScalarType::FLOAT16) { + scalar_num = ScalarNum(static_cast(num)); + } + } + ScalarVar(ScalarType type, int num) : scalar_type(type), scalar_num(static_cast(num)) { + if (type == ScalarType::FLOAT32 || type == 
ScalarType::FLOAT16) { + scalar_num = ScalarNum(static_cast(num)); + } + } + ScalarVar(ScalarType type, unsigned int num) : scalar_type(type), scalar_num(static_cast(num)) { + if (type == ScalarType::FLOAT32 || type == ScalarType::FLOAT16) { + scalar_num = ScalarNum(static_cast(num)); + } + } + ScalarVar(ScalarType type, float num) : scalar_type(type), scalar_num(num) { + if (type != ScalarType::FLOAT32 && type != ScalarType::FLOAT16) { + if (type == ScalarType::UINT64) { + scalar_num = ScalarNum(static_cast(num)); + } + scalar_num = ScalarNum(static_cast(num)); + } + } + ScalarVar(ScalarType type, double num) : scalar_type(type), scalar_num(static_cast(num)) { + if (type != ScalarType::FLOAT32 && type != ScalarType::FLOAT16) { + if (type == ScalarType::UINT64) { + scalar_num = ScalarNum(static_cast(num)); + } + scalar_num = ScalarNum(static_cast(num)); + } + } + bool operator==(const ScalarVar& other) const { + if (scalar_type == other.scalar_type && scalar_num.value_u64 == other.scalar_num.value_u64) { + return true; + } + return false; + } +}; + +enum class ItemFindStatus { ITEM_FIND = 0, ITEM_NOEXIST = 1 }; + +class OpParamDefImpl; +class OpParamDef { +public: + explicit OpParamDef(const char *name); + OpParamDef(const OpParamDef &def); + ~OpParamDef(); + OpParamDef &operator=(const OpParamDef &def); + OpParamDef &ParamType(Option param_type); + OpParamDef &DataType(std::vector types); + OpParamDef &DataTypeList(std::vector types); + OpParamDef &Format(std::vector formats); + OpParamDef &FormatList(std::vector formats); + OpParamDef &DataTypeForBinQuery(std::vector types); + OpParamDef &FormatForBinQuery(std::vector formats); + OpParamDef &UnknownShapeFormat(std::vector formats); + OpParamDef &ValueDepend(Option value_depend); + OpParamDef &ValueDepend(Option value_depend, DependScope scope); + OpParamDef &IgnoreContiguous(void); + OpParamDef &AutoContiguous(); + OpParamDef &Scalar(); + OpParamDef &ScalarList(); + OpParamDef &To(const ge::DataType type); + 
OpParamDef &To(const char *name); + OpParamDef &Version(uint32_t version); + OpParamDef &InitValue(uint64_t value); + OpParamDef &InitValue(const ScalarVar &value); + OpParamDef &InitValue(const std::vector &value); + OpParamDef &OutputShapeDependOnCompute(); + OpParamDef &Follow(const char *paramName); + OpParamDef &Follow(const char *paramName, FollowType ftype); + OpParamDef &Comment(const char *comment); + +private: + friend class AclnnOpGenerator; + friend class Generator; + friend class OpProtoGenerator; + friend class GeneratorFactory; + friend class CfgGenerator; + friend class OpParamTrunk; + friend class OpDef; + + bool operator==(const OpParamDef &def) const; + void MergeParam(const OpParamDef &def); + ge::AscendString &GetParamName(void) const; + Option GetParamType(void); + std::vector &GetDataTypes(void); + std::vector &GetOriginDataTypes(void); + std::vector &GetDataTypesList(void); + std::vector &GetDataTypesForBin(void) const; + std::vector &GetFormats(void); + std::vector &GetFormatsList(void); + std::vector &GetFormatsForBin(void) const; + std::vector &GetUnknownShapeFormats(void); + ge::AscendString &GetValueDepend(void) const; + DependScope &GetDependScope(void) const; + ge::AscendString &GetFollowName(void) const; + FollowType &GetFollowType(void) const; + ge::AscendString &GetComment(void) const; + bool GetIgnoreContiguous(void); + bool GetAutoContiguous(void); + bool IsScalar(void) const; + bool IsScalarList(void) const; + bool IsScalarOrScalarList(void) const; + bool IsScalarTypeSet(void) const; + bool IsScalarNameSet(void) const; + bool IsValueDepend(void) const; + bool IsDtype(void) const; + bool IsDtypeList(void) const; + bool IsFormat(void) const; + bool IsFormatList(void) const; + bool IsOutputShapeDependOnCompute(void) const; + bool IsSetDtypeForBin(void) const; + bool IsSetFormatForBin(void) const; + ge::AscendString &GetScalarName(void) const; + ge::DataType GetScalarType(void) const; + uint32_t GetVersion(void); + InitValueType 
&GetInitValueType(void); + InitValueNum &GetInitValue(void); + std::vector &GetInitValueList(void); + std::unique_ptr impl_; +}; + +class OpAttrDefImpl; +class OpAttrDef { +public: + explicit OpAttrDef(const char *name); + OpAttrDef(const OpAttrDef &attr_def); + ~OpAttrDef(); + OpAttrDef &operator=(const OpAttrDef &attr_def); + OpAttrDef &AttrType(Option attr_type); + OpAttrDef &Bool(void); + OpAttrDef &Bool(bool value); + OpAttrDef &Float(void); + OpAttrDef &Float(float value); + OpAttrDef &Int(void); + OpAttrDef &Int(int64_t value); + OpAttrDef &String(void); + OpAttrDef &String(const char *value); + OpAttrDef &ListBool(void); + OpAttrDef &ListBool(std::vector value); + OpAttrDef &ListFloat(void); + OpAttrDef &ListFloat(std::vector value); + OpAttrDef &ListInt(void); + OpAttrDef &ListInt(std::vector value); + OpAttrDef &ListListInt(void); + OpAttrDef &ListListInt(std::vector> value); + OpAttrDef &Version(uint32_t version); + OpAttrDef &Comment(const char *comment); + ge::AscendString &GetName(void) const; + bool IsRequired(void); + +private: + friend class AclnnOpGenerator; + friend class Generator; + friend class OpProtoGenerator; + friend class GeneratorFactory; + friend class CfgGenerator; + friend class OpParamTrunk; + friend class OpDef; + + bool operator==(const OpAttrDef &attr_def) const; + ge::AscendString &GetCfgDataType(void) const; + ge::AscendString &GetProtoDataType(void) const; + ge::AscendString &GetAttrDefaultVal(const char *brac); + uint32_t GetVersion(void); + ge::AscendString &GetComment(void) const; + + std::unique_ptr impl_; +}; + +class OpAICoreConfigImpl; +class OpAICoreConfig { +public: + OpAICoreConfig(); + OpAICoreConfig(const OpAICoreConfig &aicore_config); + ~OpAICoreConfig(); + OpAICoreConfig &operator=(const OpAICoreConfig &aicore_config); + OpParamDef &Input(const char *name); + OpParamDef &Output(const char *name); + OpAICoreConfig &DynamicCompileStaticFlag(bool flag); + OpAICoreConfig &DynamicFormatFlag(bool flag); + 
OpAICoreConfig &DynamicRankSupportFlag(bool flag); + OpAICoreConfig &DynamicShapeSupportFlag(bool flag); + OpAICoreConfig &NeedCheckSupportFlag(bool flag); + OpAICoreConfig &PrecisionReduceFlag(bool flag); + OpAICoreConfig &ExtendCfgInfo(const char *key, const char *value); + +private: + friend class AclnnOpGenerator; + friend class Generator; + friend class OpProtoGenerator; + friend class GeneratorFactory; + friend class CfgGenerator; + friend class OpParamTrunk; + friend class OpDef; + + std::vector &GetInputs(void) const; + std::vector &GetOutputs(void) const; + std::vector &GetCfgKeys(void); + std::map &GetCfgInfo(void); + ge::AscendString &GetConfigValue(const char *key); + void AddCfgItem(const char *key, const char *value); + + std::unique_ptr impl_; +}; + +class OpAICoreDefImpl; +class OpAICoreDef { +public: + OpAICoreDef(); + OpAICoreDef(const OpAICoreDef &aicore_def); + ~OpAICoreDef(); + OpAICoreDef &operator=(const OpAICoreDef &aicore_def); + OpAICoreDef &SetTiling(gert::OpImplRegisterV2::TilingKernelFunc func); + OpAICoreDef &SetCheckSupport(optiling::OP_CHECK_FUNC func); + OpAICoreDef &SetOpSelectFormat(optiling::OP_CHECK_FUNC func); + OpAICoreDef &SetOpSupportInfo(optiling::OP_CHECK_FUNC func); + OpAICoreDef &SetOpSpecInfo(optiling::OP_CHECK_FUNC func); + OpAICoreDef &SetParamGeneralize(optiling::PARAM_GENERALIZE_FUNC func); + gert::OpImplRegisterV2::TilingKernelFunc &GetTiling(void); + optiling::OP_CHECK_FUNC &GetCheckSupport(void); + optiling::OP_CHECK_FUNC &GetOpSelectFormat(void); + optiling::OP_CHECK_FUNC &GetOpSupportInfo(void); + optiling::OP_CHECK_FUNC &GetOpSpecInfo(void); + optiling::PARAM_GENERALIZE_FUNC &GetParamGeneralize(void); + OpAICoreDef &AddConfig(const char *soc); + OpAICoreDef &AddConfig(const char *soc, OpAICoreConfig &aicore_config); + +private: + friend class AclnnOpGenerator; + friend class Generator; + friend class OpProtoGenerator; + friend class GeneratorFactory; + friend class CfgGenerator; + friend class OpParamTrunk; + 
friend class OpDef; + + std::map &GetAICoreConfigs(void); + void Log(const char *op_type, const char *info) const; + + std::unique_ptr impl_; +}; + +class OpMC2DefImpl; +class OpMC2Def { +public: + OpMC2Def(); + OpMC2Def(const OpMC2Def &mc2_def); + ~OpMC2Def(); + OpMC2Def &operator=(const OpMC2Def &mc2_def); + OpMC2Def &HcclGroup(const char *value); + OpMC2Def &HcclGroup(std::vector value); + +private: + friend class AclnnOpGenerator; + friend class Generator; + friend class OpProtoGenerator; + friend class GeneratorFactory; + friend class CfgGenerator; + friend class OpParamTrunk; + + std::vector &GetHcclGroups(void) const; + std::unique_ptr impl_; +}; + +class OpDefImpl; +class OpDef { +public: + explicit OpDef(const char *type); + OpDef(const OpDef &op_def); + ~OpDef(); + OpDef &operator=(const OpDef &op_def); + OpParamDef &Input(const char *name); + OpParamDef &Output(const char *name); + OpAttrDef &Attr(const char *name); + OpDef &Comment(CommentSection section, const char *comment); + OpDef &SetInferShape(gert::OpImplRegisterV2::InferShapeKernelFunc func); + OpDef &SetInferShapeRange(gert::OpImplRegisterV2::InferShapeRangeKernelFunc func); + OpDef &SetInferDataType(gert::OpImplRegisterV2::InferDataTypeKernelFunc func); + gert::OpImplRegisterV2::InferShapeKernelFunc &GetInferShape(void); + gert::OpImplRegisterV2::InferShapeRangeKernelFunc &GetInferShapeRange(void); + gert::OpImplRegisterV2::InferDataTypeKernelFunc &GetInferDataType(void); + OpAICoreDef &AICore(void); + OpMC2Def &MC2(void); + OpDef &FormatMatchMode(FormatCheckOption option); + +private: + friend class AclnnOpGenerator; + friend class Generator; + friend class OpProtoGenerator; + friend class GeneratorFactory; + friend class CfgGenerator; + friend class OpParamTrunk; + using ArrParam = std::pair; + struct DfsParam { + std::vector> full_types; + std::vector> full_formats; + std::vector types; + std::vector formats; + }; + enum class PortStat : uint32_t { + IN = 0, + OUT = 1, + INOUT = 2, + 
INVALID_STAT + }; + struct PortFollowInfo { + PortStat port_stat = PortStat::IN; + uint32_t index_in = 0; + uint32_t index_out = 0; + ge::AscendString follow_port_name = ""; + FollowType follow_type = FollowType::ALL; + }; + ge::AscendString &GetOpType(void); + ge::AscendString &GetCateGory(void) const; + std::vector &GetBrief(void) const; + std::vector &GetConstraints(void) const; + std::vector &GetRestrictions(void) const; + std::vector &GetSee(void) const; + std::vector &GetThirdPartyFwkCopat(void) const; + std::vector &GetInputs(void); + std::vector &GetOutputs(void); + std::vector &GetAttrs(void); + std::vector GetMergeInputs(OpAICoreConfig &aicore_config); + std::vector GetMergeOutputs(OpAICoreConfig &aicore_config); + void CheckIncompatible(const std::vector& all) const; + void FullPermutation(std::vector &input_param, std::vector &output_param); + void DfsFullPermutation(DfsParam &dfs_param, const std::vector &all_param, + uint32_t list_idx, uint32_t non_list_idx) const; + void DfsDataType(DfsParam &dfs_param, const std::vector &all_param, + uint32_t list_idx, uint32_t non_list_idx) const; + void DfsFormat(DfsParam &dfs_param, const std::vector &all_param, + uint32_t list_idx, uint32_t non_list_idx) const; + uint32_t GetNonListLen(std::vector &input_param, std::vector &output_param) const; + bool IsNonListTypes(const OpParamDef &def) const; + bool IsNonListFormats(const OpParamDef &def) const; + void SetDefaultND(std::vector &defs) const; + std::vector> GetMergeInputsOutputs(const OpAICoreConfig &aicore_config); + void SetPermutedParam(const DfsParam &dfs_param, std::vector &input, + std::vector &output); + void MergeParam(std::vector &merge, std::vector &aicore_params) const; + ItemFindStatus FindAttr(const char *name, OpAttrDef **attr); + OpAttrDef &AddAttr(OpAttrDef &attr); + OpAttrDef &GetOrCreateAttr(const char *name); + void FollowImpl(void); + void FollowListImpl(const DfsParam &dfs_param, std::vector& input, std::vector& output); + std::map 
GetFollowMap(void); + std::map>> GetFollowShapeMap(void); + std::map>> GetFollowTypeMap(void); + OpParamDef GetParamDef(const ge::AscendString& name, OpDef::PortStat stat); + FormatCheckOption GetFormatMatchMode(void); + void UpdateInput(const DfsParam &dfs_param, std::vector &input); + void UpdateOutput(const DfsParam &dfs_param, std::vector &output); + void UpdateDtypeImpl(const DfsParam &dfs_param, OpParamDef ¶m, const uint32_t ¶m_idx); + void UpdateFormatImpl(const DfsParam &dfs_param, OpParamDef ¶m, const uint32_t ¶m_idx); + + std::unique_ptr impl_; +}; +} // namespace ops + +#endif diff --git a/inc/metadef/external/register/op_def_factory.h b/inc/metadef/external/register/op_def_factory.h new file mode 100644 index 0000000000000000000000000000000000000000..68b72f8dcdfc4c6284833b90e2ec03993ef674ec --- /dev/null +++ b/inc/metadef/external/register/op_def_factory.h @@ -0,0 +1,33 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef OP_DEF_FACTORY_H +#define OP_DEF_FACTORY_H + +#include "register/op_def.h" + +namespace ops { +using OpDefCreator = std::function; +class OpDefFactory { +public: + static int OpDefRegister(const char *name, OpDefCreator creator); + +private: + friend class AclnnOpGenerator; + friend class Generator; + friend class OpProtoGenerator; + friend class GeneratorFactory; + friend class CfgGenerator; + + static OpDef OpDefCreate(const char *name); + static std::vector &GetAllOp(void); +}; +} // namespace ops + +#endif diff --git a/inc/metadef/external/register/op_def_registry.h b/inc/metadef/external/register/op_def_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..0c652a7a9e6ec34bdb1abd48e0e861592bcac77a --- /dev/null +++ b/inc/metadef/external/register/op_def_registry.h @@ -0,0 +1,55 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef OP_DEF_REGISTRY_H +#define OP_DEF_REGISTRY_H + +#include "register/op_def.h" +#include "register/op_def_factory.h" + +#if defined(OP_PROTO_LIB) + +#define OP_ADD(opType, ...) 
\ + static int g_##opType##_added = [](const char *name) { \ + opType op(#opType); \ + gert::OpImplRegisterV2 impl(#opType); \ + impl.InferShape(op.GetInferShape()) \ + .InferShapeRange(op.GetInferShapeRange()) \ + .InferDataType(op.GetInferDataType()); \ + gert::OpImplRegisterV2 implReg(impl); \ + return 0; \ + }(#opType) + +#elif defined(OP_TILING_LIB) + +#define OP_ADD(opType, ...) \ + struct OpAddCompilerInfoPlaceholder##opType {}; \ + static ge::graphStatus TilingPrepare##opType(gert::TilingParseContext *context) { return ge::GRAPH_SUCCESS; } \ + static int g_##opType##_added = [](const char *name) { \ + opType op(#opType); \ + gert::OpImplRegisterV2 impl(#opType); \ + impl.Tiling(op.AICore().GetTiling()); \ + impl.TilingParse(TilingPrepare##opType); \ + optiling::OpCheckFuncHelper(FUNC_CHECK_SUPPORTED, #opType, op.AICore().GetCheckSupport()); \ + optiling::OpCheckFuncHelper(FUNC_OP_SELECT_FORMAT, #opType, op.AICore().GetOpSelectFormat()); \ + optiling::OpCheckFuncHelper(FUNC_GET_OP_SUPPORT_INFO, #opType, op.AICore().GetOpSupportInfo()); \ + optiling::OpCheckFuncHelper(FUNC_GET_SPECIFIC_INFO, #opType, op.AICore().GetOpSpecInfo()); \ + optiling::OpCheckFuncHelper(#opType, op.AICore().GetParamGeneralize()); \ + gert::OpImplRegisterV2 implReg(impl); \ + return 0; \ + }(#opType) + +#else + +#define OP_ADD(opType, ...) \ + static int g_##opType##_added = \ + ops::OpDefFactory::OpDefRegister(#opType, [](const char *name) { return opType(#opType); }) + +#endif +#endif diff --git a/inc/metadef/external/register/op_impl_registry.h b/inc/metadef/external/register/op_impl_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..8eb90e68fd9516ed6963d274f28064f6b54c07b2 --- /dev/null +++ b/inc/metadef/external/register/op_impl_registry.h @@ -0,0 +1,129 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). 
+ * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_REGISTER_OP_IMPL_REGISTRY_H_ +#define INC_EXTERNAL_REGISTER_OP_IMPL_REGISTRY_H_ +#include +#include +#include +#include +#include "graph/ge_error_codes.h" +#include "graph/types.h" +#include "graph/compiler_def.h" +#include "exe_graph/runtime/base_type.h" +#include "exe_graph/runtime/infer_shape_context.h" +#include "exe_graph/runtime/infer_shape_range_context.h" +#include "exe_graph/runtime/infer_datatype_context.h" +#include "exe_graph/runtime/tiling_context.h" +#include "exe_graph/runtime/tiling_parse_context.h" +#include "exe_graph/runtime/op_execute_context.h" + +namespace ge { +class AnyValue; +} // namespace ge + +namespace gert { +enum TilingPlacement { + TILING_ON_HOST = 0, + TILING_ON_AICPU = 1, +}; + +class OpImplRegisterV2Impl; +class OpImplRegisterV2 { + public: + explicit OpImplRegisterV2(const ge::char_t *op_type); + OpImplRegisterV2(OpImplRegisterV2 &®ister_data) noexcept; + OpImplRegisterV2(const OpImplRegisterV2 ®ister_data); + OpImplRegisterV2 &operator=(const OpImplRegisterV2 &) = delete; + OpImplRegisterV2 &operator=(OpImplRegisterV2 &&) = delete; + ~OpImplRegisterV2(); + + using InferShapeKernelFunc = UINT32 (*)(InferShapeContext *); + using InferShapeRangeKernelFunc = UINT32 (*)(InferShapeRangeContext *); + using TilingKernelFunc = UINT32 (*)(TilingContext *); + using InferDataTypeKernelFunc = UINT32 (*)(InferDataTypeContext *); + using GenSimplifiedKeyKernelFunc = UINT32 (*)(TilingContext *, 
ge::char_t *); + // Prototype of the aclnn interface. Parameter: + // OpExecuteContext: holds the operator's Input, Output and Attr information + using OpExecFunc = UINT32 (*)(OpExecuteContext *); + using OpType = ge::AscendString; + using PrivateAttrList = std::vector>; + using PrivateAttrSet = std::unordered_set; + using CompileInfoCreatorFunc = void *(*) (); + using CompileInfoDeleterFunc = void (*)(void *); + using KernelFunc = UINT32 (*)(KernelContext *context); + using TilingParseFunc = UINT32 (*)(TilingParseContext *context); + + public: + OpImplRegisterV2 &InferShape(InferShapeKernelFunc infer_shape_func); + OpImplRegisterV2 &InferShapeRange(InferShapeRangeKernelFunc infer_shape_range_func); + OpImplRegisterV2 &InferDataType(InferDataTypeKernelFunc infer_datatype_func); + /* + * A datatype inference rule: the datatype of the first input is used as the datatype of every output. + * Usage: once this rule is registered, no custom inference rule needs to be registered. If both InferDataType and InferOutDataTypeSameWithFirstInput are registered, the rule registered last takes effect. + * Error case: inference reports an error if the operator has no input or the input datatype is undefined. + * */ + OpImplRegisterV2 &InferOutDataTypeSameWithFirstInput(); + OpImplRegisterV2 &Tiling(TilingKernelFunc tiling_func, size_t max_tiling_data_size = 2048); + OpImplRegisterV2 &GenSimplifiedKey(GenSimplifiedKeyKernelFunc gen_simplifiedkey_func); + OpImplRegisterV2 &PrivateAttr(const ge::char_t *private_attr); + OpImplRegisterV2 &PrivateAttr(const ge::char_t *private_attr, int64_t private_attr_val); + OpImplRegisterV2 &PrivateAttr(const ge::char_t *private_attr, const std::vector &private_attr_val); + OpImplRegisterV2 &PrivateAttr(const ge::char_t *private_attr, const ge::char_t *private_attr_val); + OpImplRegisterV2 &PrivateAttr(const ge::char_t *private_attr, ge::float32_t private_attr_val); + OpImplRegisterV2 &PrivateAttr(const ge::char_t *private_attr, bool private_attr_val); + OpImplRegisterV2 &PrivateAttr(const ge::char_t *private_attr, const std::vector &private_attr_val); + template + OpImplRegisterV2 &TilingParse(KernelFunc const tiling_parse_func) { + return TilingParse(tiling_parse_func, CreateCompileInfo, DeleteCompileInfo); + } + template 
OpImplRegisterV2 &TilingParse(TilingParseFunc const tiling_parse_func) { + return TilingParse(reinterpret_cast(tiling_parse_func), CreateCompileInfo, + DeleteCompileInfo); + } + OpImplRegisterV2 &InputsDataDependency(std::initializer_list inputs); + OpImplRegisterV2 &OpExecuteFunc(OpExecFunc op_execute_func); + OpImplRegisterV2 &HostInputs(std::initializer_list inputs); + OpImplRegisterV2 &TilingInputsDataDependency(std::initializer_list inputs); + OpImplRegisterV2 &TilingInputsDataDependency(std::initializer_list inputs, + std::initializer_list placements); + OpImplRegisterV2 &OutputShapeDependOnCompute(std::initializer_list outputs); + + private: + OpImplRegisterV2 &TilingParse(KernelFunc tiling_parse_func, + CompileInfoCreatorFunc creator_func, + CompileInfoDeleterFunc deleter_func); + OpImplRegisterV2 &PrivateAttr(const ge::char_t *private_attr, ge::AnyValue private_attr_av); + + template::value), int32_t>::type = 0> + static void *CreateCompileInfo() { + return new T(); + } + template + static void DeleteCompileInfo(void *const obj) { + delete reinterpret_cast(obj); + } + private: + std::unique_ptr impl_; +}; +} // namespace gert + +#define IMPL_OP_COUNTER(op_type, name, counter) \ + static gert::OpImplRegisterV2 VAR_UNUSED name##counter = gert::OpImplRegisterV2(#op_type) +#define IMPL_OP_COUNTER_NUMBER(op_type, name, counter) IMPL_OP_COUNTER(op_type, name, counter) +#define IMPL_OP(op_type) IMPL_OP_COUNTER_NUMBER(op_type, op_impl_register_##op_type, __COUNTER__) +#define IMPL_OP_DEFAULT() IMPL_OP(DefaultImpl) + +#define IMPL_OP_INFERSHAPE(op_type) \ + gert::OpImplRegisterV2 VAR_UNUSED op_impl_register_infershape_##op_type = gert::OpImplRegisterV2(#op_type) +#define IMPL_OP_OPTILING(op_type) \ + gert::OpImplRegisterV2 VAR_UNUSED op_impl_register_optiling_##op_type = gert::OpImplRegisterV2(#op_type) + +#endif // INC_EXTERNAL_REGISTER_OP_IMPL_REGISTRY_H_ diff --git a/inc/metadef/external/register/op_info_record_registry.h 
b/inc/metadef/external/register/op_info_record_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..1c25f3070a26d3893bac4c03dee0e6f014f48534 --- /dev/null +++ b/inc/metadef/external/register/op_info_record_registry.h @@ -0,0 +1,73 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_REGISTER_OP_INFO_RECORD_REGISTRY_H_ +#define INC_EXTERNAL_REGISTER_OP_INFO_RECORD_REGISTRY_H_ +#include + +#include "external/exe_graph/runtime/tiling_context.h" + +namespace OpInfoRecord { +struct OpCompilerOption { + explicit OpCompilerOption(const std::string &impl_mode_v, bool deterministic_v = true) : + impl_mode(impl_mode_v), deterministic(deterministic_v) {} + explicit OpCompilerOption(const char *impl_mode_v, bool deterministic_v = true) : + impl_mode(impl_mode_v), deterministic(deterministic_v) {} + std::string impl_mode; + bool deterministic; +}; + +struct OpKernelInfo { + explicit OpKernelInfo(const std::string &bin_info_v, int8_t bin_type_v) : + bin_info(bin_info_v), bin_type(bin_type_v) {} + explicit OpKernelInfo(const char *bin_info_v, int8_t bin_type_v) : + bin_info(bin_info_v), bin_type(bin_type_v) {} + std::string bin_info; + int8_t bin_type; +}; + +class __attribute__((visibility("default"))) OpInfoRecordRegister { +public: + using NotifyFn = void(*)(bool); 
+ static OpInfoRecordRegister *Instance(); + /* + * @ingroup OpInfoRecord + * @brief Register the notification function + * @param notifyFn [IN] Callback notification function. + */ + void RegNotify(const NotifyFn notifyFn) const; + + /* + * @ingroup OpInfoRecord + * @brief Obtains the current switch status + * @retval true: The switch is enabled. + * @retval false: The switch is disabled. + */ + bool GetSwitchState() const; + + /* + * @ingroup OpInfoRecord + * @brief Output the current operator information + * + * @param ctx [IN] Operator context information + * @param opt [IN] Operator compile option + * @param kernelInfo [IN] Operator kernel information + */ + void ExeOptInfoStat( + const gert::TilingContext *ctx, + const OpCompilerOption &opt, + const OpKernelInfo *kernelInfo) const; + +private: + OpInfoRecordRegister() = default; + ~OpInfoRecordRegister() = default; + OpInfoRecordRegister(const OpInfoRecordRegister &) = delete; + OpInfoRecordRegister &operator=(const OpInfoRecordRegister &) = delete; +}; // class OpInfoRecordRegister +} // namespace OpInfoRecord +#endif // INC_EXTERNAL_REGISTER_OP_INFO_RECORD_REGISTRY_H_ diff --git a/inc/metadef/external/register/op_lib_register.h b/inc/metadef/external/register/op_lib_register.h new file mode 100644 index 0000000000000000000000000000000000000000..db956468ae68fd69624db9133cf4fdf88fe63552 --- /dev/null +++ b/inc/metadef/external/register/op_lib_register.h @@ -0,0 +1,41 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_REGISTER_OP_LIB_REGISTER_H +#define INC_EXTERNAL_REGISTER_OP_LIB_REGISTER_H +#include "graph/compiler_def.h" +#include "graph/types.h" + +namespace ge { +class OpLibRegisterImpl; +class OpLibRegister { + public: + explicit OpLibRegister(const char_t *vendor_name); + OpLibRegister(OpLibRegister &&other) noexcept; + OpLibRegister(const OpLibRegister &other); + OpLibRegister &operator=(const OpLibRegister &) = delete; + OpLibRegister &operator=(OpLibRegister &&) = delete; + ~OpLibRegister(); + + using OpLibInitFunc = uint32_t (*)(); + OpLibRegister &RegOpLibInit(OpLibInitFunc func); + + private: + std::unique_ptr impl_; +}; +} // namespace ge + +#define REGISTER_OP_LIB(vendor_name) REGISTER_OP_LIB_UNIQ_HELPER(vendor_name, __COUNTER__) + +#define REGISTER_OP_LIB_UNIQ_HELPER(vendor_name, counter) REGISTER_OP_LIB_UNIQ(vendor_name, counter) + +#define REGISTER_OP_LIB_UNIQ(vendor_name, counter) \ + static ge::OpLibRegister VAR_UNUSED g_##vendor_name##counter = ge::OpLibRegister(#vendor_name) + +#endif // INC_EXTERNAL_REGISTER_OP_LIB_REGISTER_H diff --git a/inc/metadef/external/register/op_tiling_attr_utils.h b/inc/metadef/external/register/op_tiling_attr_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..a168d349ccb973d19f53ca47bef95cf9b352db79 --- /dev/null +++ b/inc/metadef/external/register/op_tiling_attr_utils.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_REGISTER_OP_TILING_ATTR_UTILS_H_ +#define INC_EXTERNAL_REGISTER_OP_TILING_ATTR_UTILS_H_ + +#include +#include "external/graph/operator.h" + +namespace optiling { +class AttrData; +using AttrDataPtr = std::shared_ptr; + +class AttrData { +public: + AttrData() {} + virtual ~AttrData() {} + virtual size_t GetSize() const = 0; + virtual const std::uint8_t *GetData() = 0; +}; + +ge::graphStatus GetOperatorAttrValue(const ge::Operator &op, const char *attr_name, const char *attr_dtype, + AttrDataPtr &attr_data_ptr, const char *target_dtype = nullptr); + +} // namespace optiling +#endif // INC_EXTERNAL_REGISTER_OP_TILING_ATTR_UTILS_H_ diff --git a/inc/metadef/external/register/op_tiling_info.h b/inc/metadef/external/register/op_tiling_info.h new file mode 100644 index 0000000000000000000000000000000000000000..4e51466a16463d73c4d67193008acc5dc6595ebf --- /dev/null +++ b/inc/metadef/external/register/op_tiling_info.h @@ -0,0 +1,170 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_REGISTER_OP_TILING_INFO_H_ +#define INC_EXTERNAL_REGISTER_OP_TILING_INFO_H_ + +#include +#include +#include +#include "external/graph/ge_error_codes.h" +#include "external/graph/ascend_string.h" +#include "external/graph/tensor.h" + +namespace optiling { +using ByteBuffer = std::stringstream; + +enum class TensorArgType { + TA_NONE, + TA_SINGLE, + TA_LIST, +}; + +class TeOpVarAttrArgsImpl; +class TeOpVarAttrArgs { + friend class VarAttrHelper; + +public: + TeOpVarAttrArgs() = default; + ~TeOpVarAttrArgs() = default; + const uint8_t *GetData(const std::string &name, const std::string &dtype, size_t &size) const; + +private: + std::shared_ptr impl_; +}; + +struct TeOpTensor { + std::vector shape; + std::vector ori_shape; + std::string format; + std::string ori_format; + std::string dtype; + std::string name; + std::map attrs; +}; + +struct TeOpTensorArg { + TensorArgType arg_type; + std::vector tensor; +}; + +struct OpRunInfo { + uint32_t block_dim; + std::vector workspaces; + ByteBuffer tiling_data; + bool clear_atomic; + uint64_t tiling_key; + int32_t tiling_cond; +}; + +using TeOpAttrArgs = std::vector; +using TeConstTensorData = std::tuple; + +struct TeOpParas { + std::vector inputs; + std::vector outputs; + std::map const_inputs; + TeOpAttrArgs attrs; + std::string op_type; + TeOpVarAttrArgs var_attrs; +}; + +struct OpCompileInfo { + std::string str; + std::string key; +}; + +namespace utils { +class OpRunInfoImpl; +class OpRunInfo { +public: + OpRunInfo(); + ~OpRunInfo() = default; + + OpRunInfo(const uint32_t &block_dim, const bool &clear_atomic, const uint64_t &tiling_key); + // Copy + OpRunInfo(const OpRunInfo &runinfo); + // Move + OpRunInfo(OpRunInfo &&runinfo); + // Copy + OpRunInfo &operator=(const OpRunInfo &runinfo); + // Move + OpRunInfo &operator=(OpRunInfo &&runinfo); + + void SetBlockDim(const uint32_t 
&block_dim); + uint32_t GetBlockDim() const; + void SetScheduleMode(const uint32_t schedule_mode); + uint32_t GetScheduleMode() const; + void AddWorkspace(const int64_t &workspace); + size_t GetWorkspaceNum() const; + ge::graphStatus GetWorkspace(const size_t &idx, int64_t &workspace) const; + void GetAllWorkspaces(std::vector &workspaces) const; + const std::vector &GetAllWorkspaces() const; + void SetWorkspaces(const std::vector &workspaces); + + template + void AddTilingData(const T &value) { + AddTilingData(reinterpret_cast(&value), sizeof(value)); + } + template + void operator << (const T &value) { + AddTilingData(reinterpret_cast(&value), sizeof(T)); + } + void AddTilingData(const char *value, const size_t size); + void* GetAddrBase(uint64_t& max_size) const; + void SetAddrBaseOffset(const uint64_t size); + ByteBuffer &GetAllTilingData(); + const ByteBuffer &GetAllTilingData() const; + void InternelSetTiling(const ByteBuffer &value); + void SetClearAtomic(const bool clear_atomic); + bool GetClearAtomic() const; + + void SetTilingKey(const uint64_t &new_tiling_key); + uint64_t GetTilingKey() const; + uint64_t GetTilingDataSize() const; + void ResetWorkspace(); + void ResetAddrBase(void *const addr_base, const uint64_t max_size); + void AlignOffsetWith64(); + bool SetMemCheckBaseOffset(const uint64_t &offset); + void SetTilingCond(const int32_t tiling_cond); + int32_t GetTilingCond() const; + void SetLocalMemorySize(const uint32_t local_memory_size); + uint32_t GetLocalMemorySize() const; +private: + std::shared_ptr impl_; +}; + +class OpCompileInfoImpl; +class OpCompileInfo { +public: + OpCompileInfo(); + ~OpCompileInfo() = default; + OpCompileInfo(const ge::AscendString &key, const ge::AscendString &value); + OpCompileInfo(const std::string &key, const std::string &value); + // Copy + OpCompileInfo(const OpCompileInfo &compileinfo); + // Move + OpCompileInfo(OpCompileInfo &&compileinfo); + // Copy + OpCompileInfo &operator=(const OpCompileInfo &compileinfo); 
+ // Move + OpCompileInfo &operator=(OpCompileInfo &&compileinfo); + + void SetKey(const ge::AscendString &key); + const ge::AscendString &GetKey() const; + + void SetValue(const ge::AscendString &value); + const ge::AscendString &GetValue() const; + +private: + std::shared_ptr impl_; +}; +} +} // namespace optiling +#endif // INC_REGISTER_OP_TILING_REGISTRY_H_ diff --git a/inc/metadef/external/register/op_tiling_registry.h b/inc/metadef/external/register/op_tiling_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..e54639f2ae5bb2e106797f99dc67fbd3bf4ff185 --- /dev/null +++ b/inc/metadef/external/register/op_tiling_registry.h @@ -0,0 +1,152 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_REGISTER_OP_TILING_REGISTRY_H_ +#define INC_EXTERNAL_REGISTER_OP_TILING_REGISTRY_H_ + +#include +#include +#include +#include +#include +#include "external/graph/operator.h" +#include "external/register/register_error_codes.h" +#include "external/register/register_types.h" +#include "external/register/op_compile_info_base.h" +#include "external/register/op_tiling_info.h" + +#define REGISTER_OP_TILING(optype, opfunc) REGISTER_OP_TILING_UNIQ_HELPER(optype, (opfunc), __COUNTER__) + +#define REGISTER_OP_TILING_UNIQ_HELPER(optype, opfunc, counter) REGISTER_OP_TILING_UNIQ(optype, (opfunc), counter) + +#define REGISTER_OP_TILING_V2(optype, opfunc) REGISTER_OP_TILING_UNIQ_HELPER_V2(optype, (opfunc), __COUNTER__) + +#define REGISTER_OP_TILING_UNIQ_HELPER_V2(optype, opfunc, counter) REGISTER_OP_TILING_UNIQ_V2(optype, (opfunc), counter) + +#define REGISTER_OP_TILING_V3(optype, tilingfunc, parsefunc) \ + REGISTER_OP_TILING_UNIQ_HELPER_V3(optype, (tilingfunc), (parsefunc), __COUNTER__) + +#define REGISTER_OP_TILING_UNIQ_HELPER_V3(optype, tilingfunc, parsefunc, counter) \ + REGISTER_OP_TILING_UNIQ_V3(optype, (tilingfunc), (parsefunc), counter) + +#define REGISTER_OP_TILING_V4(optype, tilingfunc, parsefunc) \ + REGISTER_OP_TILING_UNIQ_HELPER_V4(optype, (tilingfunc), (parsefunc), __COUNTER__) + +#define REGISTER_OP_TILING_UNIQ_HELPER_V4(optype, tilingfunc, parsefunc, counter) \ + REGISTER_OP_TILING_UNIQ_V4(optype, (tilingfunc), (parsefunc), counter) + +#ifdef DISABLE_COMPILE_V1 +#define REGISTER_OP_TILING_UNIQ(optype, opfunc, counter) +#define REGISTER_OP_TILING_UNIQ_V2(optype, opfunc, counter) +#define REGISTER_OP_TILING_UNIQ_V3(optype, tilingfunc, parsefunc, counter) +#define REGISTER_OP_TILING_UNIQ_V4(optype, tilingfunc, parsefunc, counter) +#else +#define REGISTER_OP_TILING_UNIQ(optype, opfunc, counter) \ + static 
optiling::OpTilingFuncRegistry g_##optype##TilingRegistryInterfV1##counter(#optype, (opfunc)) + +#define REGISTER_OP_TILING_UNIQ_V2(optype, opfunc, counter) \ + static optiling::OpTilingFuncRegistry g_##optype##TilingRegistryInterfV2##counter(#optype, (opfunc)) + +#define REGISTER_OP_TILING_UNIQ_V3(optype, tilingfunc, parsefunc, counter) \ + static optiling::OpTilingFuncRegistry g_##optype##TilingRegistryInterfV3##counter(#optype, (tilingfunc), (parsefunc)) + +#define REGISTER_OP_TILING_UNIQ_V4(optype, tilingfunc, parsefunc, counter) \ + static optiling::OpTilingFuncRegistry g_##optype##TilingRegistryInterfV4##counter(#optype, (tilingfunc), (parsefunc)) +#endif + + +using Status = domi::Status; +namespace optiling { +template +ByteBuffer &ByteBufferPut(ByteBuffer &buf, const T &buffer_value) { + (void) buf.write(reinterpret_cast(&buffer_value), static_cast(sizeof(buffer_value))); + (void) buf.flush(); + return buf; +} + +template +ByteBuffer &ByteBufferGet(ByteBuffer &buf, T &buffer_value) { + (void) buf.read(reinterpret_cast(&buffer_value), static_cast(sizeof(buffer_value))); + return buf; +} + +size_t ByteBufferGetAll(ByteBuffer &buf, ge::char_t *dest, size_t dest_len); +ByteBuffer &ByteBufferPut(ByteBuffer &buf, const uint8_t *data, size_t data_len); + +using OpTilingFunc = std::function; +using OpTilingFuncPtr = std::shared_ptr; +class FMK_FUNC_HOST_VISIBILITY OpTilingRegistryInterf { + public: + OpTilingRegistryInterf(std::string op_type, OpTilingFunc func); + ~OpTilingRegistryInterf() = default; + static std::unordered_map &RegisteredOpInterf(); +}; + +using OpRunInfoV2 = utils::OpRunInfo; +using OpCompileInfoV2 = utils::OpCompileInfo; +using OpTilingFuncV2 = std::function; +using OpTilingFuncV2Ptr = std::shared_ptr; +class FMK_FUNC_HOST_VISIBILITY OpTilingRegistryInterf_V2 { +public: + OpTilingRegistryInterf_V2(const std::string &op_type, OpTilingFuncV2 func); + ~OpTilingRegistryInterf_V2() = default; + static std::unordered_map &RegisteredOpInterf(); +}; + 
+using OpTilingFuncV3 = std::function; +using OpParseFuncV3 = std::function; +using OpTilingFuncV4 = std::function; +using OpParseFuncV4 = std::function; + +class OpTilingFuncInfo { +public: + explicit OpTilingFuncInfo(const std::string &op_type); + OpTilingFuncInfo() = default; + ~OpTilingFuncInfo() = default; + + bool IsFunctionV4(); + bool IsFunctionV3(); + bool IsFunctionV2(); + bool IsFunctionV1(); + void SetOpTilingFunc(OpTilingFunc &tiling_func); + void SetOpTilingFuncV2(OpTilingFuncV2 &tiling_func); + void SetOpTilingFuncV3(OpTilingFuncV3 &tiling_func, OpParseFuncV3 &parse_func); + void SetOpTilingFuncV4(OpTilingFuncV4 &tiling_func, OpParseFuncV4 &parse_func); + const OpTilingFunc& GetOpTilingFunc(); + const OpTilingFuncV2& GetOpTilingFuncV2(); + const OpTilingFuncV3& GetOpTilingFuncV3(); + const OpParseFuncV3& GetOpParseFuncV3(); + const OpTilingFuncV4& GetOpTilingFuncV4(); + const OpParseFuncV4& GetOpParseFuncV4(); + const std::string& GetOpType() const { + return op_type_; + } + +private: + std::string op_type_; + OpTilingFunc tiling_func_; + OpTilingFuncV2 tiling_func_v2_; + OpTilingFuncV3 tiling_func_v3_; + OpParseFuncV3 parse_func_v3_; + OpTilingFuncV4 tiling_func_v4_; + OpParseFuncV4 parse_func_v4_; +}; + +class FMK_FUNC_HOST_VISIBILITY OpTilingFuncRegistry { +public: + OpTilingFuncRegistry(const std::string &op_type, OpTilingFunc tiling_func); + OpTilingFuncRegistry(const std::string &op_type, OpTilingFuncV2 tiling_func); + OpTilingFuncRegistry(const std::string &op_type, OpTilingFuncV3 tiling_func, OpParseFuncV3 parse_func); + OpTilingFuncRegistry(const std::string &op_type, OpTilingFuncV4 tiling_func, OpParseFuncV4 parse_func); + ~OpTilingFuncRegistry() = default; + static std::unordered_map &RegisteredOpFuncInfo(); +}; + +} // namespace optiling +#endif // INC_EXTERNAL_REGISTER_OP_TILING_REGISTRY_H_ diff --git a/inc/metadef/external/register/register.h b/inc/metadef/external/register/register.h new file mode 100644 index 
0000000000000000000000000000000000000000..ee4a9bc2d3d7653bdef513dddc384483e08d9a43 --- /dev/null +++ b/inc/metadef/external/register/register.h @@ -0,0 +1,218 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_REGISTER_REGISTER_H +#define INC_EXTERNAL_REGISTER_REGISTER_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "graph/operator.h" +#include "register/register_error_codes.h" +#include "register/register_fmk_types.h" +#include "register/register_types.h" +#include "graph/ascend_string.h" +#include "graph/types.h" + + +using std::unique_ptr; +using std::make_shared; +using std::to_string; +using std::string; +using std::pair; +using std::map; +using std::vector; + +/*lint -e148*/ +namespace ge { +class TBEPluginManager; +} + +namespace google { +namespace protobuf { +class Message; +} +} + +namespace domi { +constexpr int64_t kMaxNameLength = 1048576; // 1M + +enum DynamicType : int16_t { + kInvalid = 0, + kInput = 1, + kOutput = 2 +}; +struct DynamicInputOutputInfo { + DynamicType type; // input/output + const char_t *port_name; + int64_t port_name_len; + const char_t *attr_name; + int64_t attr_name_len; + DynamicInputOutputInfo(const DynamicType type_instance, const char_t *const port_name_instance, + const int64_t 
port_name_len_instance, const char_t *const attr_name_instance, + const int64_t attr_name_len_instance) + : type(type_instance), port_name(port_name_instance), port_name_len(port_name_len_instance), + attr_name(attr_name_instance), attr_name_len(attr_name_len_instance) {} + DynamicInputOutputInfo() : DynamicInputOutputInfo(kInvalid, nullptr, 0L, nullptr, 0L) {} +}; +Status AutoMappingByOpFn(const ge::Operator &op_src, ge::Operator &op); +Status AutoMappingByOpFnDynamic(const ge::Operator &op_src, ge::Operator &op, + const std::vector &dynamic_name_attr_value); +ATTRIBUTED_DEPRECATED(Status AutoMappingByOpFn(const ge::Operator &, ge::Operator &)) +Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op); +ATTRIBUTED_DEPRECATED(Status AutoMappingByOpFnDynamic(const ge::Operator &, ge::Operator &, + const std::vector &)) +Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op, + std::map> dynamic_name_attr_value, + int32_t in_pos = -1, int32_t out_pos = -1); +Status AutoMappingSubgraphIndex(const ge::Graph &graph, + const std::function &input, + const std::function &output); +Status AutoMappingSubgraphIndex(const ge::Graph &graph, + const std::function &input, + const std::function &output); +using google::protobuf::Message; +class OpRegistrationDataImpl; +class FrameworkRegistryImpl; + +using ParseParamFunc = std::function; +using ParseParamByOpFunc = std::function; +using FusionParseParamFunc = std::function, + ge::Operator &)>; +using FusionParseParamByOpFunc = std::function &, ge::Operator &)>; +using ParseSubgraphFunc = std::function; +using ParseOpToGraphFunc = std::function; +using ParseSubgraphFuncV2 = std::function; +using AutoMappingSubgraphIOIndexFunc = std::function &input, + const std::function &output)>; + +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY FrameworkRegistry { + public: + FrameworkRegistry(const FrameworkRegistry &) = delete; + FrameworkRegistry& operator = (const 
FrameworkRegistry &) = delete; + ~FrameworkRegistry(); + static FrameworkRegistry& Instance(); + void AddAutoMappingSubgraphIOIndexFunc(domi::FrameworkType framework, AutoMappingSubgraphIOIndexFunc fun); + AutoMappingSubgraphIOIndexFunc GetAutoMappingSubgraphIOIndexFunc(domi::FrameworkType framework); + private: + FrameworkRegistry(); + std::unique_ptr impl_; +}; + +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY AutoMappingSubgraphIOIndexFuncRegister { + public: + AutoMappingSubgraphIOIndexFuncRegister(domi::FrameworkType framework, AutoMappingSubgraphIOIndexFunc fun); + ~AutoMappingSubgraphIOIndexFuncRegister() = default; +}; + +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { + public: + ATTRIBUTED_DEPRECATED(OpRegistrationData(const char_t *)) + OpRegistrationData(const std::string &om_optype); + + OpRegistrationData(const char_t *om_optype); + + ~OpRegistrationData(); + + OpRegistrationData &FrameworkType(const domi::FrameworkType &fmk_type); + + ATTRIBUTED_DEPRECATED(OpRegistrationData &OriginOpType(const std::vector &)) + OpRegistrationData &OriginOpType(const std::initializer_list &ori_optype_list); + + OpRegistrationData &OriginOpType(const std::vector &ori_op_type_list); + + ATTRIBUTED_DEPRECATED(OpRegistrationData &OriginOpType(const char_t *)) + OpRegistrationData &OriginOpType(const std::string &ori_optype); + + OpRegistrationData &OriginOpType(const char_t *ori_op_type); + + OpRegistrationData &ParseParamsFn(const ParseParamFunc &parseParamFn); + + OpRegistrationData &ParseParamsByOperatorFn(const ParseParamByOpFunc &parse_param_by_op_fn); + + OpRegistrationData &FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn); + + OpRegistrationData &FusionParseParamsFn(const FusionParseParamByOpFunc &fusion_parse_param_fn); + + ATTRIBUTED_DEPRECATED(OpRegistrationData &ParseSubgraphPostFn(const ParseSubgraphFuncV2 &)) + OpRegistrationData &ParseSubgraphPostFn(const ParseSubgraphFunc &subgraph_post_fn); + + 
OpRegistrationData &ParseSubgraphPostFn(const ParseSubgraphFuncV2 &subgraph_post_fn); + + OpRegistrationData &ImplyType(const domi::ImplyType &imply_type); + + ATTRIBUTED_DEPRECATED(OpRegistrationData &DelInputWithCond(int32_t, const char_t *, bool)) + OpRegistrationData &DelInputWithCond(int32_t inputIdx, const std::string &attrName, bool attrValue); + + OpRegistrationData &DelInputWithCond(int32_t input_idx, const char_t *attr_name, bool attr_value); + + ATTRIBUTED_DEPRECATED(OpRegistrationData &DelInputWithOriginalType(int32_t, const char_t *)) + OpRegistrationData &DelInputWithOriginalType(int32_t input_idx, const std::string &ori_type); + + OpRegistrationData &DelInputWithOriginalType(int32_t input_idx, const char_t *ori_type); + + OpRegistrationData &InputReorderVector(const std::vector &input_order); + + OpRegistrationData &ParseOpToGraphFn(const ParseOpToGraphFunc &parse_op_to_graph_fn); + + domi::ImplyType GetImplyType () const; + ATTRIBUTED_DEPRECATED(Status GetOmOptype(ge::AscendString &) const) + std::string GetOmOptype () const; + Status GetOmOptype(ge::AscendString &om_op_type) const; + ATTRIBUTED_DEPRECATED(GetOriginOpTypeSet(std::set &) const) + std::set GetOriginOpTypeSet() const; + Status GetOriginOpTypeSet(std::set &ori_op_type) const; + domi::FrameworkType GetFrameworkType() const; + ParseParamFunc GetParseParamFn() const; + ParseParamByOpFunc GetParseParamByOperatorFn() const; + FusionParseParamFunc GetFusionParseParamFn() const; + FusionParseParamByOpFunc GetFusionParseParamByOpFn() const; + ParseSubgraphFunc GetParseSubgraphPostFn() const; + ParseOpToGraphFunc GetParseOpToGraphFn() const; + Status GetParseSubgraphPostFn(ParseSubgraphFuncV2 &func) const; + + private: + std::shared_ptr impl_; + friend class OpRegistry; + friend class OpRegistrationTbe; + friend class ge::TBEPluginManager; +}; + +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver { + public: + OpReceiver(OpRegistrationData &reg_data); + ~OpReceiver() = default; +};
+} // namespace domi +namespace ge { +using OpRegistrationData = domi::OpRegistrationData; +using OpReceiver = domi::OpReceiver; +} // namespace ge + +#define REGISTER_CUSTOM_OP(name) REGISTER_CUSTOM_OP_UNIQ_HELPER(__COUNTER__, (name)) +#define REGISTER_CUSTOM_OP_UNIQ_HELPER(ctr, name) REGISTER_CUSTOM_OP_UNIQ(ctr, (name)) +#define REGISTER_CUSTOM_OP_UNIQ(ctr, name) \ + static OpReceiver register_op##ctr \ + __attribute__((unused)) = \ + OpRegistrationData(name) + +#define REGISTER_AUTOMAPPING_SUBGRAPH_IO_INDEX_FUNC(framework, fun) \ + static AutoMappingSubgraphIOIndexFuncRegister \ + auto_mapping_subgraph_fun_##framework(framework, (fun)); + +/*lint +e148*/ +#endif // INC_EXTERNAL_REGISTER_REGISTER_H diff --git a/inc/metadef/external/register/register_custom_pass.h b/inc/metadef/external/register/register_custom_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..9f2959d10f72eb4a1e6e28caf3982d01fa07e854 --- /dev/null +++ b/inc/metadef/external/register/register_custom_pass.h @@ -0,0 +1,70 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_REGISTER_REGISTER_PASS_H_ +#define INC_EXTERNAL_REGISTER_REGISTER_PASS_H_ + +#include +#include +#include + +#include "graph/graph.h" +#include "external/ge_common/ge_api_error_codes.h" +#include "register/register_types.h" + +namespace ge { +class PassRegistrationDataImpl; +class CustomPassContext; +class CustomPassContextImpl; +using CustomPassFunc = std::function; + +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY PassRegistrationData { + public: + PassRegistrationData() = default; + ~PassRegistrationData() = default; + + PassRegistrationData(std::string pass_name); + + PassRegistrationData &CustomPassFn(const CustomPassFunc &custom_pass_fn); + + std::string GetPassName() const; + + CustomPassFunc GetCustomPassFn() const; + + private: + std::shared_ptr impl_; +}; + +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY PassReceiver { + public: + PassReceiver(PassRegistrationData &reg_data); + ~PassReceiver() = default; +}; + +class CustomPassContext { + public: + CustomPassContext(); + virtual ~CustomPassContext() = default; + + void SetErrorMessage(const AscendString &error_message); + AscendString GetErrorMessage() const; + + private: + std::unique_ptr impl_; +}; +} // namespace ge + +#define REGISTER_CUSTOM_PASS(name) REGISTER_CUSTOM_PASS_UNIQ_HELPER(__COUNTER__, (name)) +#define REGISTER_CUSTOM_PASS_UNIQ_HELPER(ctr, name) REGISTER_CUSTOM_PASS_UNIQ(ctr, (name)) +#define REGISTER_CUSTOM_PASS_UNIQ(ctr, name) \ + static ::ge::PassReceiver register_pass##ctr \ + __attribute__((unused)) = \ + ::ge::PassRegistrationData((name)) + +#endif // INC_EXTERNAL_REGISTER_REGISTER_PASS_H_ diff --git a/inc/metadef/external/register/register_error_codes.h b/inc/metadef/external/register/register_error_codes.h new file mode 100644 index 0000000000000000000000000000000000000000..e22b9fae2da02f7f56b59fca81f79e1b395513c3 ---
/dev/null +++ b/inc/metadef/external/register/register_error_codes.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_ +#define INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_ + +#define SYSID_FWK 3U // Subsystem ID +#define MODID_COMMON 0U // Common module ID + +#define DECLARE_ERRORNO(sysid, modid, name, value) \ + constexpr domi::Status name = \ + ((static_cast(0xFFU & (static_cast(sysid)))) << 24U) | \ + ((static_cast(0xFFU & (static_cast(modid)))) << 16U) | \ + (static_cast((0xFFFFU & (static_cast(value))))); + +#define DECLARE_ERRORNO_COMMON(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_COMMON, name, value) + +namespace domi { +using Status = uint32_t; + +// General error code +DECLARE_ERRORNO(0U, 0U, SUCCESS, 0U); +DECLARE_ERRORNO(0xFFU, 0xFFU, FAILED, 0xFFFFFFFFU); +DECLARE_ERRORNO_COMMON(PARAM_INVALID, 1U); // 50331649 +DECLARE_ERRORNO(SYSID_FWK, 1U, SCOPE_NOT_CHANGED, 201U); +} // namespace domi + +#endif // INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_ diff --git a/inc/metadef/external/register/register_fmk_types.h b/inc/metadef/external/register/register_fmk_types.h new file mode 100644 index 0000000000000000000000000000000000000000..6f0c86bd020b54e6a7ca08fd5d378d27e6e316b9 --- /dev/null +++ 
b/inc/metadef/external/register/register_fmk_types.h @@ -0,0 +1,30 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_REGISTER_REGISTER_FMK_TYPES_H_ +#define INC_EXTERNAL_REGISTER_REGISTER_FMK_TYPES_H_ + +#include + +namespace domi { +/// +/// @ingroup domi_omg +/// @brief AI framework types +/// +enum FrameworkType { + CAFFE = 0, + MINDSPORE = 1, + TENSORFLOW = 3, + ANDROID_NN = 4, + ONNX = 5, + FRAMEWORK_RESERVED = 6 +}; +} // namespace domi + +#endif // INC_EXTERNAL_REGISTER_REGISTER_FMK_TYPES_H_ diff --git a/inc/metadef/external/register/register_types.h b/inc/metadef/external/register/register_types.h new file mode 100644 index 0000000000000000000000000000000000000000..c966222dc54e2979a629046da2a5a70e9fbc91e8 --- /dev/null +++ b/inc/metadef/external/register/register_types.h @@ -0,0 +1,65 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_REGISTER_REGISTER_TYPES_H_ +#define INC_EXTERNAL_REGISTER_REGISTER_TYPES_H_ + +#if(defined(HOST_VISIBILITY)) && (defined(__GNUC__)) +#define FMK_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) +#else +#define FMK_FUNC_HOST_VISIBILITY +#endif +#if(defined(DEV_VISIBILITY)) && (defined(__GNUC__)) +#define FMK_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) +#else +#define FMK_FUNC_DEV_VISIBILITY +#endif +#ifdef __GNUC__ + +#ifdef NO_METADEF_ABI_COMPATIABLE +#define ATTRIBUTED_DEPRECATED(replacement) +#else +#define ATTRIBUTED_DEPRECATED(replacement) __attribute__((deprecated("Please use " #replacement " instead."))) +#endif + +#else +#ifdef NO_METADEF_ABI_COMPATIABLE +#define ATTRIBUTED_DEPRECATED(replacement) +#else +#define ATTRIBUTED_DEPRECATED(replacement) __declspec(deprecated("Please use " #replacement " instead.")) +#endif +#endif +namespace domi { + +/// +/// @ingroup domi +/// @brief original tensor type +/// +typedef enum tagDomiTensorFormat { + DOMI_TENSOR_NCHW = 0, // < NCHW + DOMI_TENSOR_NHWC, // < NHWC + DOMI_TENSOR_ND, // < Nd Tensor + DOMI_TENSOR_NC1HWC0, // < NC1HWC0 + DOMI_TENSOR_FRACTAL_Z, // < FRACTAL_Z + DOMI_TENSOR_NC1C0HWPAD, + DOMI_TENSOR_NHWC1C0, + DOMI_TENSOR_FSR_NCHW, + DOMI_TENSOR_FRACTAL_DECONV, + DOMI_TENSOR_BN_WEIGHT, + DOMI_TENSOR_CHWN, // Android NN Depth CONV + DOMI_TENSOR_FILTER_HWCK, // filter input tensor format + DOMI_TENSOR_NDHWC, + DOMI_TENSOR_NCDHW, + DOMI_TENSOR_DHWCN, // 3D filter input tensor format + DOMI_TENSOR_DHWNC, + DOMI_TENSOR_RESERVED +} 
domiTensorFormat_t; +} // namespace domi + +#endif // INC_EXTERNAL_REGISTER_REGISTER_TYPES_H_ diff --git a/inc/metadef/external/register/scope/scope_fusion_pass_register.h b/inc/metadef/external/register/scope/scope_fusion_pass_register.h new file mode 100644 index 0000000000000000000000000000000000000000..417ca38bf30fb1fc6a5f7897dccfc9345269fc6d --- /dev/null +++ b/inc/metadef/external/register/scope/scope_fusion_pass_register.h @@ -0,0 +1,401 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_ +#define EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_ + +#include +#include +#include +#include +#include +#include "external/ge_common/ge_api_error_codes.h" +#include "register/register_error_codes.h" +#include "register/register_types.h" +#include "graph/operator.h" + +#define CHECK_INNER_NODE_CONDITION(cond, fusion_rlt) \ + do { \ + if (!(cond)) { \ + if ((fusion_rlt) != nullptr) { \ + (fusion_rlt)->SetType(ge::kScopeInvalidType); \ + } \ + return; \ + } \ + } while (0) + +namespace domi { +class TensorFlowModelParser; +} // namespace domi +namespace ge { +const int32_t kFusionDisableIndex = 99999; +const char_t *const kScopeToMultiNodes = "ScopeToMultiNodes"; +const char_t *const kScopeInvalidType = "ScopeInvalidType"; +const char_t *const kInputFromFusionScope = "InputFromFusionScope"; +const char_t *const kOutputToFusionScope = "OutputToFusionScope"; +class ScopePattern; +using ScopeFusionPatterns = std::vector>; + +class ScopePassManager; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY Scope { + public: + Scope(); + ATTRIBUTED_DEPRECATED(Status Init(const char_t *, const char_t *, Scope *)) + Status Init(const std::string &name, const std::string &sub_type = "", Scope *father_scope = nullptr); + Status Init(const char_t *name, const char_t *sub_type, Scope *father_scope = nullptr); + ~Scope(); + ATTRIBUTED_DEPRECATED(Status Name(AscendString &) const) + const std::string &Name() const; + Status Name(AscendString &name) const; + ATTRIBUTED_DEPRECATED(Status SubType(AscendString &) const) + const std::string &SubType() const; + Status SubType(AscendString &sub_type) const; + ATTRIBUTED_DEPRECATED(Status AllNodesMap(std::unordered_map &) const) + const std::unordered_map &AllNodesMap() const; + Status AllNodesMap(std::unordered_map &node_map) const; + 
ATTRIBUTED_DEPRECATED(Scope *GetSubScope(const char_t *scope_name) const) + Scope *GetSubScope(const std::string &scope_name) const; + Scope *GetSubScope(const char_t *scope_name) const; + ATTRIBUTED_DEPRECATED(Status LastName(AscendString &) const) + const std::string LastName() const; + Status LastName(AscendString &name) const; + const std::vector &GetAllSubScopes() const; + const Scope *GetFatherScope() const; + + private: + class ScopeImpl; + std::unique_ptr impl_; + friend class ScopeBasePass; + friend class ScopeTree; + friend class NodeOpTypeFeature; + friend class NodeAttrFeature; + friend class ScopeFeature; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY FusionScopesResult { + public: + FusionScopesResult(); + Status Init(); + ~FusionScopesResult(); + ATTRIBUTED_DEPRECATED(void SetName(const char_t *)) + void SetName(const std::string &name); + void SetName(const char_t *name); + ATTRIBUTED_DEPRECATED(void SetType(const char_t *)) + void SetType(const std::string &type); + void SetType(const char_t *type); + ATTRIBUTED_DEPRECATED(void SetDescription(const char_t *)) + void SetDescription(const std::string &description); + void SetDescription(const char_t *description); + ATTRIBUTED_DEPRECATED(const Status Name(AscendString &) const) + const std::string &Name() const; + Status Name(AscendString &name) const; + const std::vector &Nodes() const; + ATTRIBUTED_DEPRECATED(void InsertInputs(const char_t *, const std::vector &)) + void InsertInputs(const std::string &inner_op_name, const std::vector &index_map); + void InsertInputs(const char_t *inner_op_name, const std::vector &index_map); + ATTRIBUTED_DEPRECATED(void InsertOutputs(const char_t *, const std::vector &)) + void InsertOutputs(const std::string &inner_op_name, const std::vector &index_map); + void InsertOutputs(const char_t *inner_op_name, const std::vector &index_map); + + class InnerNodeInfo { + public: + ATTRIBUTED_DEPRECATED(InnerNodeInfo(const char_t *)) + explicit 
InnerNodeInfo(const std::string &fusion_node_name); + explicit InnerNodeInfo(const char_t *fusion_node_name); + ATTRIBUTED_DEPRECATED(InnerNodeInfo(const char_t *, const char_t *, const char_t *)) + InnerNodeInfo(const std::string &fusion_node_name, const std::string &name, const std::string &type); + InnerNodeInfo(const char_t *fusion_node_name, const char_t *name, const char_t *type); + InnerNodeInfo(InnerNodeInfo &&other) noexcept; + InnerNodeInfo &operator=(InnerNodeInfo &&other) noexcept; + InnerNodeInfo(const InnerNodeInfo &) = delete; + InnerNodeInfo &operator=(const InnerNodeInfo &) = delete; + ~InnerNodeInfo(); + ATTRIBUTED_DEPRECATED(InnerNodeInfo &SetName(const char_t *)) + InnerNodeInfo &SetName(const std::string &name); + InnerNodeInfo &SetName(const char_t *name); + ATTRIBUTED_DEPRECATED(InnerNodeInfo &SetType(const char_t *)) + InnerNodeInfo &SetType(const std::string &type); + InnerNodeInfo &SetType(const char_t *type); + ATTRIBUTED_DEPRECATED(InnerNodeInfo &InsertInput(const char_t *, int32_t)) + InnerNodeInfo &InsertInput(const std::string &input_node, int32_t peer_out_idx); + InnerNodeInfo &InsertInput(const char_t *input_node, int32_t peer_out_idx); + ATTRIBUTED_DEPRECATED(InnerNodeInfo &InsertOutput(const char_t *, int32_t)) + InnerNodeInfo &InsertOutput(const std::string &output_node, int32_t peer_in_idx); + InnerNodeInfo &InsertOutput(const char_t *output_node, int32_t peer_in_idx); + ge::graphStatus BuildInnerNode(); + ATTRIBUTED_DEPRECATED(ge::graphStatus SetInputFormat(const char_t *, const char_t *)) + ge::graphStatus SetInputFormat(const std::string &input_name, const std::string &format); + ge::graphStatus SetInputFormat(const char_t *input_name, const char_t *format); + ATTRIBUTED_DEPRECATED(ge::graphStatus SetOutputFormat(const char_t *, const char_t *)) + ge::graphStatus SetOutputFormat(const std::string &output_name, const std::string &format); + ge::graphStatus SetOutputFormat(const char_t *output_name, const char_t *format); + 
ATTRIBUTED_DEPRECATED(ge::graphStatus SetDynamicInputFormat(const char_t *, uint32_t index, const char_t *)) + ge::graphStatus SetDynamicInputFormat(const std::string &input_name, uint32_t index, const std::string &format); + ge::graphStatus SetDynamicInputFormat(const char_t *input_name, uint32_t index, const char_t *format); + ATTRIBUTED_DEPRECATED(ge::graphStatus SetDynamicOutputFormat(const char_t *, uint32_t, const char_t *)) + ge::graphStatus SetDynamicOutputFormat(const std::string &output_name, uint32_t index, const std::string &format); + ge::graphStatus SetDynamicOutputFormat(const char_t *output_name, uint32_t index, const char_t *format); + ge::Operator *MutableOperator(); + ATTRIBUTED_DEPRECATED(ge::graphStatus GetName(AscendString &) const) + std::string GetName() const; + ge::graphStatus GetName(AscendString &name) const; + ATTRIBUTED_DEPRECATED(ge::graphStatus GetType(AscendString &) const) + std::string GetType() const; + ge::graphStatus GetType(AscendString &type) const; + ATTRIBUTED_DEPRECATED(ge::graphStatus GetInputs(std::vector> &) const) + std::vector> GetInputs() const; + ge::graphStatus GetInputs(std::vector> &inputs) const; + ATTRIBUTED_DEPRECATED(ge::graphStatus GetOutputs(std::vector> &) const) + std::vector> GetOutputs() const; + ge::graphStatus GetOutputs(std::vector> &outputs) const; + private: + class InnerNodeInfoImpl; + std::unique_ptr impl_; + }; + ATTRIBUTED_DEPRECATED(InnerNodeInfo *AddInnerNode(const char_t *, const char_t *)) + InnerNodeInfo *AddInnerNode(const std::string &name, const std::string &type); + InnerNodeInfo *AddInnerNode(const char_t *name, const char_t *type); + InnerNodeInfo *MutableRecentInnerNode(); + InnerNodeInfo *MutableInnerNode(uint32_t index); + ge::graphStatus CheckInnerNodesInfo(); + + private: + class FusionScopesResultImpl; + std::unique_ptr impl_; + friend class ScopeGraph; + friend class ScopeBasePass; + friend class TensorFlowModelParser; +}; + +class GE_FUNC_HOST_VISIBILITY 
GE_FUNC_DEV_VISIBILITY ScopeTree { + public: + ScopeTree(); + Status Init(); + ScopeTree(const ScopeTree &scopetree) = delete; + ScopeTree &operator=(const ScopeTree &scopetree) = delete; + ~ScopeTree(); + + const std::vector &GetAllScopes() const; + + private: + class ScopeTreeImpl; + std::unique_ptr impl_; + friend class ScopeGraph; + friend class ScopeBasePass; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeGraph { + public: + ScopeGraph(); + Status Init(); + ScopeGraph(const ScopeGraph &scope_graph) = delete; + ScopeGraph &operator=(const ScopeGraph &scope_graph) = delete; + ~ScopeGraph(); + + const ScopeTree *GetScopeTree() const; + ATTRIBUTED_DEPRECATED(Status GetNodesMap(std::unordered_map &) const) + const std::unordered_map &GetNodesMap() const; + Status GetNodesMap(std::unordered_map &nodes_map) const; + + private: + class ScopeGraphImpl; + std::unique_ptr impl_; + friend class ScopePassManager; + friend class ScopeBasePass; + friend class TensorFlowModelParser; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeAttrValue { + public: + ScopeAttrValue(); + ScopeAttrValue(ScopeAttrValue const &attr_value); + ScopeAttrValue &operator=(ScopeAttrValue const &attr_value); + ~ScopeAttrValue(); + + void SetIntValue(int64_t value); + void SetFloatValue(float32_t value); + ATTRIBUTED_DEPRECATED(void SetStringValue(const char_t *)) + void SetStringValue(std::string value); + void SetStringValue(const char_t *value); + void SetBoolValue(bool value); + + private: + class ScopeAttrValueImpl; + std::unique_ptr impl_; + friend class NodeAttrFeature; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeBaseFeature { + public: + virtual bool Match(const Scope *scope) = 0; + ScopeBaseFeature() = default; + ScopeBaseFeature(const ScopeBaseFeature &) = delete; + ScopeBaseFeature &operator=(const ScopeBaseFeature &) = delete; + ScopeBaseFeature(ScopeBaseFeature &&) = delete; + ScopeBaseFeature &operator=(ScopeBaseFeature &&) = 
delete; + virtual ~ScopeBaseFeature()= default; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeOpTypeFeature : ScopeBaseFeature { + public: + ATTRIBUTED_DEPRECATED(NodeOpTypeFeature(const char_t *, int, int)) + NodeOpTypeFeature(std::string nodeType, int32_t num, int32_t step = 0); + NodeOpTypeFeature(const char_t *node_type, int32_t num, int32_t step = 0); + NodeOpTypeFeature(NodeOpTypeFeature const &feature); + NodeOpTypeFeature &operator=(NodeOpTypeFeature const &feature); + ~NodeOpTypeFeature() override; + bool Match(const Scope *scope) override; + + private: + class NodeOpTypeFeatureImpl; + std::unique_ptr impl_; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeAttrFeature : ScopeBaseFeature { + public: + ATTRIBUTED_DEPRECATED(NodeAttrFeature(const char_t *, const char_t *, ge::DataType, ScopeAttrValue &)) + NodeAttrFeature(std::string nodeType, std::string attr_name, + ge::DataType datatype, ScopeAttrValue &attr_value); + NodeAttrFeature(const char_t *node_type, const char_t *attr_name, + ge::DataType data_type, ScopeAttrValue &attr_value); + NodeAttrFeature(NodeAttrFeature const &feature); + NodeAttrFeature &operator=(NodeAttrFeature const &feature); + ~NodeAttrFeature() override; + bool Match(const Scope *scope) override; + + private: + class NodeAttrFeatureImpl; + std::unique_ptr impl_; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFeature : ScopeBaseFeature { + public: + ATTRIBUTED_DEPRECATED(ScopeFeature(const char_t *, int32_t, const char_t *, const char_t *, int)) + ScopeFeature(std::string sub_type, int32_t num, std::string suffix = "", + std::string sub_scope_mask = "", int32_t step = 0); + ScopeFeature(const char_t *sub_type, int32_t num, const char_t *suffix, + const char_t *sub_scope_mask, int32_t step = 0); + ScopeFeature(ScopeFeature const &feature); + ScopeFeature &operator=(ScopeFeature const &feature); + ~ScopeFeature() override; + bool Match(const Scope *scope) override; + + private: + class 
ScopeFeatureImpl; + std::unique_ptr impl_; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopePattern { + public: + ScopePattern(); + ~ScopePattern(); + ATTRIBUTED_DEPRECATED(ScopePattern &SetSubType(const char_t *)) + ScopePattern &SetSubType(const std::string &sub_type); + ScopePattern &SetSubType(const char_t *sub_type); + ScopePattern &AddNodeOpTypeFeature(NodeOpTypeFeature feature); + ScopePattern &AddNodeAttrFeature(NodeAttrFeature feature); + ScopePattern &AddScopeFeature(ScopeFeature feature); + + private: + class ScopePatternImpl; + std::unique_ptr impl_; + friend class ScopeBasePass; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopesResult { + public: + ScopesResult(); + ScopesResult(ScopesResult const &result); + ScopesResult &operator=(ScopesResult const &result); + ~ScopesResult(); + + void SetScopes(std::vector &scopes); + void SetNodes(std::vector &nodes); + + private: + class ScopesResultImpl; + std::unique_ptr impl_; + friend class ScopeBasePass; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeBasePass { + public: + ScopeBasePass(); + ScopeBasePass(const ScopeBasePass &) = delete; + ScopeBasePass &operator=(const ScopeBasePass &) = delete; + ScopeBasePass(ScopeBasePass &&) = delete; + ScopeBasePass &operator=(ScopeBasePass &&) = delete; + virtual ~ScopeBasePass(); + + protected: + // Subclasses implement respective fusion strategies and build the Patterns + virtual std::vector DefinePatterns() = 0; + // Define the name of the scope pass + virtual std::string PassName() = 0; + // Subclasses implement respective multi-scope or operator fusion methods across scopes + virtual Status LastMatchScopesAndOPs(std::shared_ptr &scope_graph, + std::vector &results) = 0; + // Subclasses implement their own results and set the input and output of the final fusion operator + virtual void GenerateFusionResult(const std::vector &scopes, FusionScopesResult *fusion_rlt) = 0; + + private: + class ScopeBasePassImpl; + 
std::unique_ptr impl_; + friend class ge::ScopePassManager; + friend class ScopeBasePassImpl; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistry { + public: + using CreateFn = ScopeBasePass *(*)(); + ~ScopeFusionPassRegistry(); + + static ScopeFusionPassRegistry& GetInstance(); + + ATTRIBUTED_DEPRECATED(void RegisterScopeFusionPass(const char_t *, CreateFn, bool)) + void RegisterScopeFusionPass(const std::string &pass_name, CreateFn create_fn, bool is_general); + + void RegisterScopeFusionPass(const char_t *pass_name, CreateFn create_fn, bool is_general); + + private: + ScopeFusionPassRegistry(); + class ScopeFusionPassRegistryImpl; + /*lint -e148*/ + std::unique_ptr impl_; + friend class TensorFlowModelParser; +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeUtil { + public: + ATTRIBUTED_DEPRECATED(static AscendString StringReplaceAll(const char_t *, const char_t *, const char_t *)) + static std::string StringReplaceAll(std::string str, const std::string &old_value, const std::string &new_value); + static AscendString StringReplaceAll(const char_t *str, const char_t *old_value, const char_t *new_value); + static void FreeScopePatterns(ScopeFusionPatterns &patterns); + static void FreeOneBatchPattern(std::vector &one_batch_pattern); +}; + +class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistrar { + public: + ScopeFusionPassRegistrar(const char_t *pass_name, ScopeBasePass *(*create_fn)(), bool is_general); + ~ScopeFusionPassRegistrar() = default; +}; +} // namespace ge +#define REGISTER_SCOPE_FUSION_PASS(pass_name, scope_pass, is_general) \ + REGISTER_SCOPE_FUSION_PASS_UNIQ_HELPER(__COUNTER__, (pass_name), scope_pass, (is_general)) + +#define REGISTER_SCOPE_FUSION_PASS_UNIQ_HELPER(ctr, pass_name, scope_pass, is_general) \ + REGISTER_SCOPE_FUSION_PASS_UNIQ(ctr, (pass_name), scope_pass, (is_general)) + +#define REGISTER_SCOPE_FUSION_PASS_UNIQ(ctr, pass_name, scope_pass, is_general) \ + static 
::ge::ScopeFusionPassRegistrar register_scope_fusion_pass##ctr __attribute__((unused)) = \ + ::ge::ScopeFusionPassRegistrar( \ + (pass_name), []() -> ::ge::ScopeBasePass * { return new (std::nothrow) scope_pass(); }, (is_general)) +#endif // EXTERNAL_REGISTER_SCOPE_SCOPE_FUSION_PASS_REGISTER_H_ + diff --git a/inc/metadef/external/register/tilingdata_base.h b/inc/metadef/external/register/tilingdata_base.h new file mode 100644 index 0000000000000000000000000000000000000000..8618184816139c193bc55f24587a816698d0b715 --- /dev/null +++ b/inc/metadef/external/register/tilingdata_base.h @@ -0,0 +1,266 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef __INC_REGISTER_ASCENDC_TILINGDATA_BASE_HEADER__ +#define __INC_REGISTER_ASCENDC_TILINGDATA_BASE_HEADER__ + +#include +#include +#include +#include +#include +#include "graph/ascend_string.h" + +namespace optiling { +struct CharPtrCmp { + bool operator()(const char *strLeft, const char *strRight) const + { + return strcmp(strLeft, strRight) < 0; + } +}; + +class StructSizeInfoBase { +public: + static StructSizeInfoBase &GetInstance() + { + static StructSizeInfoBase instance; + return instance; + } + void SetStructSize(const char *structType, const size_t structSize) + { + if (structSizeInfo.find(structType) != structSizeInfo.end()) { + return; + } + structSizeInfo[structType] = structSize; + } + size_t GetStructSize(const char *structType) + { + return structSizeInfo.at(structType); + } +private: + StructSizeInfoBase() { }; + ~StructSizeInfoBase() { }; + StructSizeInfoBase(const StructSizeInfoBase &); + StructSizeInfoBase &operator=(const StructSizeInfoBase &); + std::map structSizeInfo; +}; + +class FieldInfo { +public: + FieldInfo(const char *dtype, const char *name) + : dtype_(dtype), name_(name), classType_("0") {} + FieldInfo(const char *dtype, const char *name, size_t arrSize) + : dtype_(dtype), name_(name), arrSize_(arrSize), classType_("1") {} + FieldInfo(const char *dtype, const char *name, const char *structType, + size_t structSize) + : dtype_(dtype), name_(name), structType_(structType), structSize_(structSize), classType_("2") {} + +public: + const char *dtype_; + const char *name_; + size_t arrSize_; + const char *structType_; + size_t structSize_; + const char *classType_; +}; + +class TilingDef { +public: + ~TilingDef() + { + if (!inited_data_ptr && data_ptr_ != nullptr) { + delete[] data_ptr_; + } + data_ptr_ = nullptr; + class_name_ = nullptr; + } + void SaveToBuffer(void *pdata, size_t capacity); + std::vector 
GetFieldInfo() const; + const char *GetTilingClassName() const; + size_t GetDataSize() const; + void SetDataPtr(void *dataPtr); + void CheckAlignAndGenPlaceHolder(const char *name, size_t typeSize); +protected: + void InitData(); + void GeLogError(const std::string& str) const; + // dtype, name + std::vector field_info_; + uint8_t *data_ptr_ = nullptr; + size_t data_size_ = 0; + const char *class_name_; + std::vector> saveBufferPtr; + size_t struct_size_ = 0; + bool inited_data_ptr = false; + uint32_t feature_bit_flag = 0; + uint8_t reserved_buf[128] = {0}; +}; + +using TilingDataConstructor = std::shared_ptr (*)(); + +class CTilingDataClassFactory { +public: + static CTilingDataClassFactory &GetInstance(); + void RegisterTilingData(const char *op_type, const TilingDataConstructor constructor); + std::shared_ptr CreateTilingDataInstance(const char *op_type); + +private: + CTilingDataClassFactory() { }; + ~CTilingDataClassFactory() { }; + CTilingDataClassFactory(const CTilingDataClassFactory &); + CTilingDataClassFactory &operator=(const CTilingDataClassFactory &); + std::map instance_; +}; + +class TilingDataStructBase { +public: + static TilingDataStructBase &GetInstance() + { + static TilingDataStructBase instance; + return instance; + } + uint32_t __attribute__((weak)) RecordTilingStruct(const char* name, const char* file, uint32_t line); +private: + TilingDataStructBase() { }; + ~TilingDataStructBase() { }; + TilingDataStructBase(const TilingDataStructBase &); + TilingDataStructBase &operator=(const TilingDataStructBase &); + std::map, CharPtrCmp> records; +}; +} // end of namespace optiling + +/* +example: +// supported data_type: int8_t/uint8_t/int16_t/uint16_t/int32_t/uint32_t/int64_t/uint64_t +BEGIN_TILING_DATA_DEF(MaxPoolTilingData) + // format: TILING_DATA_FIELD_DEF(data_type, field_name); + TILING_DATA_FIELD_DEF(int32_t, dim_0); + TILING_DATA_FIELD_DEF(uint8_t, var_1); + TILING_DATA_FIELD_DEF(int64_t, factor_1); +END_TILING_DATA_DEF 
+REGISTER_TILING_DATA_CLASS(MaxPool, MaxPoolTilingData) +*/ +#define BEGIN_TILING_DATA_DEF(class_name) \ +namespace { \ + static uint32_t class_name##tiling = [](void) { \ + if (&TilingDataStructBase::RecordTilingStruct != nullptr) { \ + return TilingDataStructBase::GetInstance().RecordTilingStruct(#class_name, __FILE__, __LINE__); \ + } else { \ + return static_cast(0); \ + } \ + }(); \ +} \ + class class_name : public TilingDef { \ + public: \ + size_t FieldHandler(const char *dtype, const char *name, size_t typeSize, const char* namePh) { \ + CheckAlignAndGenPlaceHolder(namePh, typeSize); \ + field_info_.emplace_back(FieldInfo(dtype, name)); \ + size_t ret_val = data_size_; \ + data_size_ += typeSize; \ + return ret_val; \ + } \ + size_t FieldHandler(const char *dtype, const char *name, size_t typeSize, \ + size_t arrSize, const char* namePh) { \ + CheckAlignAndGenPlaceHolder(namePh, typeSize); \ + field_info_.emplace_back(FieldInfo(dtype, name, arrSize)); \ + size_t ret_val = data_size_; \ + data_size_ += typeSize * arrSize; \ + return ret_val; \ + } \ + size_t FieldHandler(const char *dtype, const char *name, \ + const char *structType, size_t structSize, void *ptr, const char* namePh) { \ + CheckAlignAndGenPlaceHolder(namePh, 8); \ + field_info_.emplace_back(FieldInfo(dtype, name, structType, structSize)); \ + size_t ret_val = data_size_; \ + data_size_ += structSize; \ + saveBufferPtr.emplace_back(std::make_pair(ptr, ret_val)); \ + struct_size_ += structSize; \ + return ret_val; \ + } \ + \ + public: \ + class_name() { \ + class_name_ = #class_name; \ + CheckAlignAndGenPlaceHolder(#class_name"PH", 8); \ + StructSizeInfoBase::GetInstance().SetStructSize(#class_name, data_size_); \ + InitData(); \ + } \ + explicit class_name(void *ptr) { \ + class_name_ = #class_name; \ + CheckAlignAndGenPlaceHolder(#class_name"PH", 8); \ + StructSizeInfoBase::GetInstance().SetStructSize(#class_name, data_size_); \ + if (ptr == nullptr) { \ + return; \ + } \ + SetDataPtr(ptr); 
\ + } + +#define TILING_DATA_FIELD_DEF(data_type, field_name) \ + public: \ + void set_##field_name(data_type field_name) { \ + field_name##_ = field_name; \ + *((data_type *) (data_ptr_ + field_name##_offset_)) = field_name; \ + } \ + data_type get_##field_name() { return field_name##_; } \ + \ + private: \ + data_type field_name##_ = 0; \ + size_t field_name##_offset_ = FieldHandler(#data_type, #field_name, sizeof(data_type), #field_name"PH"); \ + uint8_t field_name##_reserve_buf_[16] = {0}; + +#define TILING_DATA_FIELD_DEF_ARR(arr_type, arr_size, field_name) \ + public: \ + void set_##field_name(arr_type *field_name) { \ + field_name##_ = field_name; \ + auto offset = field_name##_offset_; \ + if (data_ptr_ + offset == (uint8_t *)field_name) { \ + return; \ + } \ + const auto err_t = memcpy_s(data_ptr_ + offset, data_size_ - offset, field_name, (arr_size) * sizeof(arr_type)); \ + if (err_t != EOK) { \ + GeLogError("tilingdata_base.h TILING_DATA_FIELD_DEF_ARR memcpy is failed !"); \ + } \ + } \ + arr_type *get_##field_name() { \ + return (arr_type *)(data_ptr_ + field_name##_offset_); \ + } \ + \ + private: \ + arr_type *field_name##_ = nullptr; \ + size_t field_name##_offset_ = FieldHandler(#arr_type, #field_name, sizeof(arr_type), arr_size, #field_name"PH"); \ + uint8_t field_name##_reserve_buf_[16] = {0}; + +#define TILING_DATA_FIELD_DEF_STRUCT(struct_type, field_name) \ + public: \ + struct_type field_name{nullptr}; \ + \ + private: \ + size_t field_name##_offset_ = \ + FieldHandler("struct", #field_name, #struct_type, \ + StructSizeInfoBase::GetInstance().GetStructSize(#struct_type), \ + (void *) &field_name, #field_name"PH"); \ + uint8_t field_name##_reserve_buf_[16] = {0}; + +#define END_TILING_DATA_DEF \ + } \ + ; + +#define REGISTER_TILING_DATA_CLASS(op_type, class_name) \ +namespace { \ + class op_type##class_name##Helper { \ + public: \ + op_type##class_name##Helper() { \ + CTilingDataClassFactory::GetInstance().RegisterTilingData(#op_type, \ + 
op_type##class_name##Helper::CreateTilingDataInstance); \ + } \ + static std::shared_ptr CreateTilingDataInstance() { return std::make_shared(); } \ + }; \ + static class_name g_##op_type##class_name##init; \ + static op_type##class_name##Helper g_tilingdata_##op_type##class_name##helper; \ +} +#endif // __INC_REGISTER_ASCENDC_TILINGDATA_BASE_HEADER__ diff --git a/inc/metadef/external/register/tuning_bank_key_registry.h b/inc/metadef/external/register/tuning_bank_key_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..2bcffe8bdbd25e450bb7b3859ece1293c74958c9 --- /dev/null +++ b/inc/metadef/external/register/tuning_bank_key_registry.h @@ -0,0 +1,198 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef __INC_REGISTER_TUNING_BANK_KEY_REGISTRY_HEADER__ +#define __INC_REGISTER_TUNING_BANK_KEY_REGISTRY_HEADER__ +#include +#include +#include +#include +#include "graph/ascend_string.h" +#include "register/register_types.h" +#include "exe_graph/runtime/tiling_context.h" + +// 待v2版本宏定义上库后删除 +#define REGISTER_OP_BANK_KEY_CONVERT_FUN(op, opfunc) \ + REGISTER_OP_BANK_KEY_CONVERT_FUN_UNIQ_HELPER(op, (opfunc)) + +#define REGISTER_OP_BANK_KEY_CONVERT_FUN_UNIQ_HELPER(optype, opfunc) \ + REGISTER_OP_BANK_KEY_UNIQ(optype, (opfunc)) + +#define REGISTER_OP_BANK_KEY_UNIQ(optype, opfunc) \ + static tuningtiling::OpBankKeyFuncRegistry g_##optype##BankKeyRegistryInterf(#optype, (opfunc)) + +#define REGISTER_OP_BANK_KEY_PARSE_FUN(op, parse_func, load_func) \ + REGISTER_OP_BANK_KEY_PARSE_FUN_UNIQ_HELPER(op, (parse_func), (load_func)) + +#define REGISTER_OP_BANK_KEY_PARSE_FUN_UNIQ_HELPER(optype, parse_func, load_func) \ + REGISTER_OP_BANK_KEY_PARSE_UNIQ(optype, (parse_func), (load_func)) + +#define REGISTER_OP_BANK_KEY_PARSE_UNIQ(optype, parse_func, load_func) \ + static tuningtiling::OpBankKeyFuncRegistry g_##optype##BankParseInterf(#optype, (parse_func), (load_func)) + +// 上库后替代上面的宏定义 +#define REGISTER_OP_BANK_KEY_CONVERT_FUN_V2(op, opfunc) \ + REGISTER_OP_BANK_KEY_CONVERT_FUN_UNIQ_HELPER_V2(op, (opfunc)) + +#define REGISTER_OP_BANK_KEY_CONVERT_FUN_UNIQ_HELPER_V2(optype, opfunc) \ + REGISTER_OP_BANK_KEY_UNIQ_V2(optype, (opfunc)) + +#define REGISTER_OP_BANK_KEY_UNIQ_V2(optype, opfunc) \ + static tuningtiling::OpBankKeyFuncRegistryV2 g_##optype##BankKeyRegistryInterf(#optype, (opfunc)) + +#define REGISTER_OP_BANK_KEY_PARSE_FUN_V2(op, parse_func, load_func) \ + REGISTER_OP_BANK_KEY_PARSE_FUN_UNIQ_HELPER_V2(op, (parse_func), (load_func)) + +#define REGISTER_OP_BANK_KEY_PARSE_FUN_UNIQ_HELPER_V2(optype, parse_func, load_func) \ + 
REGISTER_OP_BANK_KEY_PARSE_UNIQ_V2(optype, (parse_func), (load_func)) + +#define REGISTER_OP_BANK_KEY_PARSE_UNIQ_V2(optype, parse_func, load_func) \ + static tuningtiling::OpBankKeyFuncRegistryV2 g_##optype##BankParseInterf(#optype, (parse_func), (load_func)) + +#define TUNING_TILING_MAKE_SHARED(exec_expr0, exec_expr1) \ + do { \ + try { \ + exec_expr0; \ + } catch (...) { \ + exec_expr1; \ + } \ + } while (0) + +// 待v2版本宏定义上库后删除 +#define DECLARE_STRUCT_RELATE_WITH_OP(op, bank_key, ...) \ + NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(bank_key, __VA_ARGS__); \ + bool ParseFunc##op##bank_key(const std::shared_ptr &in_args, size_t len, nlohmann::json &bank_key_json) { \ + if (sizeof(bank_key) != len || in_args == nullptr) { \ + return false; \ + } \ + bank_key_json = *(std::static_pointer_cast(in_args)); \ + return true; \ + } \ + bool LoadFunc##op##bank_key(std::shared_ptr &in_args, size_t &len, const nlohmann::json &bank_key_json) { \ + len = sizeof(bank_key); \ + TUNING_TILING_MAKE_SHARED(in_args = std::make_shared(), return false); \ + auto op_ky = std::static_pointer_cast(in_args); \ + *op_ky = bank_key_json.get(); \ + return true; \ + } \ + REGISTER_OP_BANK_KEY_PARSE_FUN(op, ParseFunc##op##bank_key, LoadFunc##op##bank_key); + +// 上库后替代上面的宏定义 +#define DECLARE_STRUCT_RELATE_WITH_OP_V2(op, bank_key, ...) 
\ + NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(bank_key, __VA_ARGS__); \ + static bool ParseFuncV2##op##bank_key(const std::shared_ptr &in_args, size_t len, \ + ge::AscendString &bank_key_json_str) { \ + if (sizeof(bank_key) != len || in_args == nullptr) { \ + return false; \ + } \ + nlohmann::json bank_key_json; \ + bank_key_json = *(std::static_pointer_cast(in_args)); \ + try { \ + std::string json_dump_str = bank_key_json.dump(); \ + bank_key_json_str = ge::AscendString(json_dump_str.c_str()); \ + } catch (std::exception& e) { \ + return false; \ + } \ + return true; \ + } \ + static bool LoadFuncV2##op##bank_key(std::shared_ptr &in_args, size_t &len, \ + const ge::AscendString &bank_key_json_str) { \ + len = sizeof(bank_key); \ + TUNING_TILING_MAKE_SHARED(in_args = std::make_shared(), return false); \ + nlohmann::json bank_key_json; \ + try { \ + bank_key_json = nlohmann::json::parse(bank_key_json_str.GetString()); \ + auto op_ky = std::static_pointer_cast(in_args); \ + *op_ky = bank_key_json.get(); \ + } catch (std::exception& e) { \ + return false; \ + } \ + return true; \ + } \ + REGISTER_OP_BANK_KEY_PARSE_FUN_V2(op, ParseFuncV2##op##bank_key, LoadFuncV2##op##bank_key); + + +namespace tuningtiling { +// 待v2版本上库后删除 +using OpBankKeyConvertFun = std::function &, size_t &)>; +using OpBankParseFun = std::function &, size_t, nlohmann::json &)>; +using OpBankLoadFun = std::function &, size_t &, const nlohmann::json &)>; + +// 上库后删除上面的引用 +using OpBankKeyConvertFunV2 = std::function &, size_t &)>; +using OpBankParseFunV2 = std::function &, size_t, ge::AscendString &)>; +using OpBankLoadFunV2 = std::function &, size_t &, const ge::AscendString &)>; +// 待v2版本上库后删除 +class FMK_FUNC_HOST_VISIBILITY OpBankKeyFuncInfo { +public: + explicit OpBankKeyFuncInfo(const ge::AscendString &optype); + OpBankKeyFuncInfo() = default; + ~OpBankKeyFuncInfo() = default; + void SetOpConvertFunc(const OpBankKeyConvertFun &convert_func); + void SetOpParseFunc(const OpBankParseFun &parse_func); + 
void SetOpLoadFunc(const OpBankLoadFun &load_func); + const OpBankKeyConvertFun& GetBankKeyConvertFunc() const; + const OpBankParseFun& GetBankKeyParseFunc() const; + const OpBankLoadFun& GetBankKeyLoadFunc() const; + const ge::AscendString& GetOpType() const { + return optype_; + } + +private: + ge::AscendString optype_; + OpBankKeyConvertFun convert_func_; + OpBankParseFun parse_func_; + OpBankLoadFun load_func_; + +}; + +// 上库后删除上面的类 +class FMK_FUNC_HOST_VISIBILITY OpBankKeyFuncInfoV2 { +public: + explicit OpBankKeyFuncInfoV2(const ge::AscendString &optypeV2); + OpBankKeyFuncInfoV2() = default; + ~OpBankKeyFuncInfoV2() = default; + void SetOpConvertFuncV2(const OpBankKeyConvertFunV2 &convert_funcV2); + void SetOpParseFuncV2(const OpBankParseFunV2 &parse_funcV2); + void SetOpLoadFuncV2(const OpBankLoadFunV2 &load_funcV2); + const OpBankKeyConvertFunV2& GetBankKeyConvertFuncV2() const; + const OpBankParseFunV2& GetBankKeyParseFuncV2() const; + const OpBankLoadFunV2& GetBankKeyLoadFuncV2() const; + const ge::AscendString& GetOpTypeV2() const { + return optypeV2_; + } + +private: + ge::AscendString optypeV2_; + OpBankKeyConvertFunV2 convert_funcV2_; + OpBankParseFunV2 parse_funcV2_; + OpBankLoadFunV2 load_funcV2_; +}; + +// 该类在v2版本上库后删除 +class FMK_FUNC_HOST_VISIBILITY OpBankKeyFuncRegistry { +public: + OpBankKeyFuncRegistry(const ge::AscendString &optype, const OpBankKeyConvertFun &convert_func); + OpBankKeyFuncRegistry(const ge::AscendString &optype, const OpBankParseFun &parse_func, + const OpBankLoadFun &load_func); + ~OpBankKeyFuncRegistry() = default; + static std::unordered_map &RegisteredOpFuncInfo(); +}; + +// 上库后替代上面的类 +class FMK_FUNC_HOST_VISIBILITY OpBankKeyFuncRegistryV2 { +public: + OpBankKeyFuncRegistryV2(const ge::AscendString &optype, const OpBankKeyConvertFunV2 &convert_funcV2); + OpBankKeyFuncRegistryV2(const ge::AscendString &optype, const OpBankParseFunV2 &parse_funcV2, + const OpBankLoadFunV2 &load_funcV2); + ~OpBankKeyFuncRegistryV2() = 
default; + static std::unordered_map &RegisteredOpFuncInfoV2(); +}; +} // namespace tuningtiling +#endif diff --git a/inc/metadef/external/register/tuning_tiling_reflection_utils.h b/inc/metadef/external/register/tuning_tiling_reflection_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..bf44278486876bf45fb3b1f3c99bab032576ef2e --- /dev/null +++ b/inc/metadef/external/register/tuning_tiling_reflection_utils.h @@ -0,0 +1,169 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef __INC_REGISTER_TUNING_TILING_REFLECTION_UTILS_HEADER__ +#define __INC_REGISTER_TUNING_TILING_REFLECTION_UTILS_HEADER__ +#include +#include +#include +#include + +namespace tuningtiling { +// implement for std c++11 +template +using decay_t = typename std::decay::type; + +template +using enable_if_t = typename std::enable_if::type; + +template +struct integer_sequence { + using value_type = T; + static constexpr std::size_t size() { + return sizeof...(Ints); + } +}; + +template +using index_sequence = integer_sequence; + +template +struct make_integer_sequence : make_integer_sequence {}; + +template +struct make_integer_sequence : integer_sequence {}; + +template +using make_index_sequence = make_integer_sequence; + +template +struct StructInfo { + static std::tuple<> Info() { + return std::make_tuple(); + } +}; + +#define DECLARE_SCHEMA(Struct, ...) \ + template<> \ + struct StructInfo { \ + static decltype(std::make_tuple(__VA_ARGS__)) Info() { \ + return std::make_tuple(__VA_ARGS__); \ + } \ + }; + +#define FIELD(class, FieldName) std::make_tuple(#FieldName, &class ::FieldName) + +template +void ForEachTuple(Tuple &&tuple, Field &&fields, Fn &&fn, index_sequence) { + (void) std::initializer_list { + (fn(std::get<0>(std::get(fields)), tuple.*std::get<1>(std::get(fields))), Is)...}; +} + +template +void ForEachTuple(Tuple &&tuple, Fn &&fn) { + const auto fields = StructInfo>::Info(); + ForEachTuple(std::forward(tuple), fields, std::forward(fn), + make_index_sequence::value> {}); +} + +template +struct is_optional : std::false_type {}; + +template +struct is_optional> : std::true_type {}; + +template +bool is_optional_v() { + return is_optional>::value; +} + +template +decltype(std::begin(T()), std::true_type {}) containable(size_t); + +template +std::false_type containable(...); + +template +using is_containable = 
decltype(containable(0U)); + +template +constexpr bool IsSerializeType() { + return ((!std::is_class>::value) || is_containable>()); +} + +template +void ForEachField(T &&value, Fn &&fn) { + ForEachTuple(std::forward(value), std::forward(fn)); +} + +template +struct DumpFunctor; + +template()>* = nullptr> +void DumpObj(T &&obj, const std::string &field_name, Js &j) { + if (field_name.empty()) { + ForEachField(std::forward(obj), DumpFunctor(j)); + return; + } + ForEachField(std::forward(obj), DumpFunctor(j[field_name])); +} + +template()>* = nullptr> +void DumpObj(T &&obj, const std::string &field_name, Js &j) { + if (field_name.empty()) { + return; + } + j[field_name] = std::forward(obj); +} + +template +struct DumpFunctor { + explicit DumpFunctor(T &j) : js(j) {} + template + void operator()(Name &&name, Field &&field) const { + DumpObj(std::forward(field), std::forward(name), js); + } + T &js; +}; + +template +struct FromJsonFunctor; + +template()>* = nullptr> +void FromJsonImpl(T &&obj, const std::string &field_name, const Js &j) { + if (field_name.empty()) { + ForEachField(std::forward(obj), FromJsonFunctor(j)); + return; + } + if (j.find(field_name) == j.cend()) { + return; + } + ForEachField(std::forward(obj), FromJsonFunctor(j[field_name])); +} + +template()>* = nullptr> +void FromJsonImpl(T &&obj, const std::string &field_name, const Js &j) { + // ignore missing field of optional + if ((tuningtiling::is_optional_v()) || (j.find(field_name) == j.cend())) { + return; + } + j.at(field_name).get_to(std::forward(obj)); +} + +template +struct FromJsonFunctor { + explicit FromJsonFunctor(const Js &j) : js(j) {} + template + void operator()(Name &&name, Field &&field) const { + FromJsonImpl(std::forward(field), std::forward(name), js); + } + const Js &js; +}; +} // namespace tuningtiling +#endif diff --git a/inc/metadef/external/register/tuning_tiling_registry.h b/inc/metadef/external/register/tuning_tiling_registry.h new file mode 100644 index 
0000000000000000000000000000000000000000..a2fe31c6cc4d28ea205fe86725d95f0d26997f51 --- /dev/null +++ b/inc/metadef/external/register/tuning_tiling_registry.h @@ -0,0 +1,98 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef __INC_REGISTER_TUNING_TILING_REGISTRY_HEADER__ +#define __INC_REGISTER_TUNING_TILING_REGISTRY_HEADER__ +#include +#include +#include +#include +#include "graph/ascend_string.h" +#include "register/tuning_tiling_reflection_utils.h" +namespace tuningtiling { +struct TilingItem { + ge::AscendString dtype_; + ge::AscendString name_; +}; + +class TuningTilingDef { +public: + virtual void FromJson(const nlohmann::json &j) = 0; + virtual void ToJson(nlohmann::json &j) = 0; + ge::AscendString GetClassName() const; + virtual std::vector GetItemInfo() const = 0; + +protected: + TuningTilingDef() = default; + virtual ~TuningTilingDef() = default; + // dtype , name + std::vector field_info_; + ge::AscendString class_name_; +}; + +#define BEGIN_TUNING_TILING_DEF(class_name) \ + class class_name : public TuningTilingDef { \ + public: \ + virtual void FromJson(const nlohmann::json &j) { \ + FromJsonImpl(*this, "", j); \ + } \ + \ + virtual void ToJson(nlohmann::json &j) { \ + DumpObj(*this, "", j); \ + } \ + \ + std::vector GetItemInfo() const { \ + return field_info_; \ + } \ + \ + 
class FieldHandler { \ + public: \ + FieldHandler(class_name *pinstance, const ge::AscendString &dtype, const ge::AscendString &name) { \ + pinstance->field_info_.push_back( {dtype, name}); \ + } \ + }; \ + friend class FieldHandler; \ + \ + public: \ + class_name() { \ + class_name_ = #class_name; \ + }; + +#define TUNING_TILING_DATA_FIELD_DEF(data_type, field_name) \ + public: \ + data_type field_name; \ + FieldHandler field_name##_handler_ = FieldHandler(this, #data_type, #field_name); + +#define END_TUNING_TILING_DEF \ + } \ + ; + +using TuningTilingDefConstructor = std::shared_ptr (*)(); +class TuningTilingClassFactory { +public: + static std::map &RegisterInfo(); + static void RegisterTilingData(const ge::AscendString &optype, TuningTilingDefConstructor const constructor); + static std::shared_ptr CreateTilingDataInstance(const ge::AscendString &optype); +}; + +#define REGISTER_TUNING_TILING_CLASS(optype, class_name) \ + class optype##Helper { \ + public: \ + optype##Helper() { \ + TuningTilingClassFactory::RegisterTilingData(#optype, optype##Helper::CreateTilingDataInstance); \ + } \ + static std::shared_ptr CreateTilingDataInstance() { \ + return std::make_shared(); \ + } \ + }; \ + optype##Helper g_tuning_tiling_##optype##Helper; +using TuningTilingDefPtr = std::shared_ptr; +} // namespace tuningtiling + +#endif diff --git a/inc/metadef/external/utils/extern_math_util.h b/inc/metadef/external/utils/extern_math_util.h new file mode 100644 index 0000000000000000000000000000000000000000..45a9662a283120eb429f9c7f399775e8a34b8050 --- /dev/null +++ b/inc/metadef/external/utils/extern_math_util.h @@ -0,0 +1,90 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXTERNAL_UTILS_EXTERN_MATH_UTIL_H +#define METADEF_CXX_INC_EXTERNAL_UTILS_EXTERN_MATH_UTIL_H + +#include +#include +#include + +namespace ge { +template +class IntegerChecker { + public: + template + static bool Compat(const T1 v) { + static_assert(((sizeof(T) <= sizeof(uint64_t)) && (sizeof(T1) <= sizeof(uint64_t))), + "IntegerChecker can only check integers less than 64 bits"); + if (v >= static_cast(0)) { + return static_cast(v) <= static_cast(std::numeric_limits::max()); + } + return static_cast(v) >= static_cast(std::numeric_limits::min()); + } +}; + +template +bool MulOverflow(TLhs lhs, TRhs rhs, TRet &ret) { +#if __GNUC__ >= 5 + return __builtin_mul_overflow(lhs, rhs, &ret); +#else + if ((!IntegerChecker::Compat(lhs)) || (!IntegerChecker::Compat(rhs))) { + return true; + } + if ((lhs == 0) || (rhs == 0)) { + ret = 0; + return false; + } + TRet reminder = std::numeric_limits::max() / static_cast(rhs); + const TRet lhs_ret_type = static_cast(lhs); + if (lhs_ret_type < 0) { + if (reminder > 0) { + reminder *= static_cast(-1); + } + if (lhs_ret_type < reminder) { + return true; + } + } else { + if (reminder < 0) { + reminder *= static_cast(-1); + } + if (lhs_ret_type > reminder) { + return true; + } + } + ret = static_cast(lhs) * static_cast(rhs); + return false; +#endif +} + +template +bool AddOverflow(TLhs lhs, TRhs rhs, TRet &ret) { +#if __GNUC__ >= 5 + return __builtin_add_overflow(lhs, rhs, &ret); +#else + if ((!IntegerChecker::Compat(lhs)) || (!IntegerChecker::Compat(rhs))) { + return true; + } + if (rhs >= 0) { 
+ if (static_cast(lhs) > std::numeric_limits::max() - static_cast(rhs)) { + return true; + } + } else { + if (static_cast(lhs) < std::numeric_limits::min() - static_cast(rhs)) { + return true; + } + } + ret = static_cast(lhs) + static_cast(rhs); + return false; +#endif +} +} // namespace ge + + +#endif // METADEF_CXX_INC_EXTERNAL_UTILS_EXTERN_MATH_UTIL_H diff --git a/inc/metadef/graph/aligned_ptr.h b/inc/metadef/graph/aligned_ptr.h new file mode 100644 index 0000000000000000000000000000000000000000..e2fe66ca5f8786dba78614ba55728e553104e7fa --- /dev/null +++ b/inc/metadef/graph/aligned_ptr.h @@ -0,0 +1,43 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef GE_ALIGNED_PTR_H_ +#define GE_ALIGNED_PTR_H_ + +#include +#include + +namespace ge { +class AlignedPtr { + public: + using Deleter = std::function; + using Allocator = std::function &base_addr)>; + explicit AlignedPtr(const size_t buffer_size, const size_t alignment = 16U); + AlignedPtr() = default; + ~AlignedPtr() = default; + AlignedPtr(const AlignedPtr &) = delete; + AlignedPtr(AlignedPtr &&) = delete; + AlignedPtr &operator=(const AlignedPtr &) = delete; + AlignedPtr &operator=(AlignedPtr &&) = delete; + + const uint8_t *Get() const { return aligned_addr_; } + uint8_t *MutableGet() { return aligned_addr_; } + std::unique_ptr Reset(); + std::unique_ptr Reset(uint8_t *const data, const AlignedPtr::Deleter &delete_func); + + static std::shared_ptr BuildFromAllocFunc(const AlignedPtr::Allocator &alloc_func, + const AlignedPtr::Deleter &delete_func); + static std::shared_ptr BuildFromData(uint8_t * const data, + const AlignedPtr::Deleter &delete_func); + private: + std::unique_ptr base_ = nullptr; + uint8_t *aligned_addr_ = nullptr; +}; +} // namespace ge +#endif // GE_ALIGNED_PTR_H_ diff --git a/inc/metadef/graph/anchor.h b/inc/metadef/graph/anchor.h new file mode 100644 index 0000000000000000000000000000000000000000..0ed7d4683ba8d6d36cd733ee3ce71ff91ccc5500 --- /dev/null +++ b/inc/metadef/graph/anchor.h @@ -0,0 +1,308 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_ANCHOR_H_ +#define INC_GRAPH_ANCHOR_H_ + +#include "graph/compiler_options.h" + +#include +#include "graph/ge_error_codes.h" +#include "graph/range_vistor.h" +#include "graph/types.h" +#include "graph/node.h" + +namespace ge { +enum AnchorStatus { + ANCHOR_SUSPEND = 0, // dat null + ANCHOR_CONST = 1, + ANCHOR_DATA = 2, // Effective + ANCHOR_RESERVED = 3 +}; +using ConstAnchor = const Anchor; +class AnchorImpl; +using AnchorImplPtr = std::shared_ptr; +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Anchor : public std::enable_shared_from_this { + friend class AnchorUtils; + + public: + using TYPE = const char *; + template + using Vistor = RangeVistor>; + + Anchor(const NodePtr& owner_node, const int32_t idx); + Anchor(const Anchor &) = delete; + Anchor(Anchor &&) = delete; + Anchor &operator=(const Anchor &) = delete; + Anchor &operator=(Anchor &&) = delete; + virtual ~Anchor(); + + protected: + // Whether the two anchor is equal + virtual bool Equal(const AnchorPtr anchor) const = 0; + virtual bool IsTypeOf(const TYPE type) const; + virtual bool IsTypeIdOf(const TypeId& type) const; + virtual TYPE GetSelfType() const; + + public: + // Get all peer anchors connected to current anchor + Vistor GetPeerAnchors() const; + // Get all peer anchors bare ptr connected to current anchor + std::vector GetPeerAnchorsPtr() const; + // Get peer anchor size + size_t GetPeerAnchorsSize() const; + // Get first peer anchor + AnchorPtr GetFirstPeerAnchor() const; + + // Get the anchor belong to which node + // Normally, return value is not null + NodePtr GetOwnerNode() const; + + // Get the anchor belong to which node by Node*, + Node *GetOwnerNodeBarePtr() const; + + // Remove all links with the anchor + void UnlinkAll() noexcept; + + // Remove link 
with the given anchor + graphStatus Unlink(const AnchorPtr &peer); + + // insert node + // this--old_peer ---> this--first_peer second_peer--old_peer + graphStatus Insert(const AnchorPtr &old_peer, const AnchorPtr &first_peer, const AnchorPtr &second_peer); + + // Replace peer with new peer + graphStatus ReplacePeer(const AnchorPtr &old_peer, const AnchorPtr &new_peer); + + // Judge if the anchor is linked with the given anchor + bool IsLinkedWith(const AnchorPtr &peer) const; + + // Get anchor index of the node + int32_t GetIdx() const; + + // set anchor index of the node + void SetIdx(const int32_t index); + + protected: + template + static Anchor::TYPE TypeOf() { + static_assert(std::is_base_of::value, "T must be a Anchor!"); + return METADEF_FUNCTION_IDENTIFIER; + } + + public: + template + static std::shared_ptr DynamicAnchorCast(const AnchorPtr anchorPtr) { + static_assert(std::is_base_of::value, "T must be a Anchor!"); + if ((anchorPtr == nullptr) || (!anchorPtr->IsTypeIdOf())) { + return nullptr; + } + return std::static_pointer_cast(anchorPtr); + } + + template + static T *DynamicAnchorPtrCast(Anchor *const anchor) { + static_assert(std::is_base_of::value, "T must be a Anchor!"); + if ((anchor == nullptr) || (!anchor->IsTypeIdOf())) { + return nullptr; + } + return PtrToPtr(anchor); + } + + template + bool IsTypeOf() const { + return IsTypeOf(TypeOf()); + } + + template + bool IsTypeIdOf() const { + return IsTypeIdOf(GetTypeId()); + } + + + protected: + AnchorImplPtr impl_; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY DataAnchor : public Anchor { + friend class AnchorUtils; + + public: + explicit DataAnchor(const NodePtr &owner_node, const int32_t idx); + DataAnchor(const DataAnchor &) = delete; + DataAnchor &operator=(const DataAnchor &) = delete; + DataAnchor(DataAnchor &&) = delete; + DataAnchor &operator=(DataAnchor &&) = delete; + ~DataAnchor() override = default; + + protected: + bool IsTypeOf(const TYPE type) const override; + bool 
IsTypeIdOf(const TypeId& type) const override; + Anchor::TYPE GetSelfType() const override; + + private: + Format format_{FORMAT_ND}; + AnchorStatus status_{ANCHOR_SUSPEND}; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchor : public DataAnchor { + friend class OutDataAnchor; + + friend class OutControlAnchor; + + public: + explicit InDataAnchor(const NodePtr &owner_node, const int32_t idx); + + ~InDataAnchor() override = default; + + // Get source out data anchor + OutDataAnchorPtr GetPeerOutAnchor() const; + + // Build connection from OutDataAnchor to InDataAnchor + graphStatus LinkFrom(const OutDataAnchorPtr &src); + + protected: + bool Equal(const AnchorPtr anchor) const override; + bool IsTypeOf(const TYPE type) const override; + bool IsTypeIdOf(const TypeId& type) const override; + Anchor::TYPE GetSelfType() const override; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutDataAnchor : public DataAnchor { + friend class InDataAnchor; + + friend class AnchorUtils; + + public: + template + using Vistor = RangeVistor>; + + explicit OutDataAnchor(const NodePtr &owner_node, const int32_t idx); + + ~OutDataAnchor() override = default; + // Get dst in data anchor(one or more) + Vistor GetPeerInDataAnchors() const; + std::vector GetPeerInDataAnchorsPtr() const; + uint32_t GetPeerInDataNodesSize() const; + + // Get dst in control anchor(one or more) + Vistor GetPeerInControlAnchors() const; + + // Build connection from OutDataAnchor to InDataAnchor + graphStatus LinkTo(const InDataAnchorPtr &dest); + + // Build connection from OutDataAnchor to InControlAnchor + graphStatus LinkTo(const InControlAnchorPtr &dest); + + protected: + bool Equal(const AnchorPtr anchor) const override; + bool IsTypeOf(const TYPE type) const override; + bool IsTypeIdOf(const TypeId& type) const override; + Anchor::TYPE GetSelfType() const override; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ControlAnchor : public Anchor { + public: + 
explicit ControlAnchor(const NodePtr &owner_node); + + explicit ControlAnchor(const NodePtr &owner_node, const int32_t idx); + ~ControlAnchor() override = default; + + protected: + bool IsTypeOf(const TYPE type) const override; + bool IsTypeIdOf(const TypeId& type) const override; + Anchor::TYPE GetSelfType() const override; + ControlAnchor(const ControlAnchor &) = delete; + ControlAnchor &operator=(const ControlAnchor &) = delete; + ControlAnchor(ControlAnchor &&) = delete; + ControlAnchor &operator=(ControlAnchor &&) = delete; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InControlAnchor : public ControlAnchor { + friend class OutControlAnchor; + + friend class OutDataAnchor; + + public: + explicit InControlAnchor(const NodePtr &owner_node); + + explicit InControlAnchor(const NodePtr &owner_node, const int32_t idx); + + ~InControlAnchor() override = default; + + // Get source out control anchors + Vistor GetPeerOutControlAnchors() const; + std::vector GetPeerOutControlAnchorsPtr() const; + bool IsPeerOutAnchorsEmpty() const; + + // Get source out data anchors + Vistor GetPeerOutDataAnchors() const; + + // Build connection from OutControlAnchor to InControlAnchor + graphStatus LinkFrom(const OutControlAnchorPtr &src); + + protected: + bool Equal(const AnchorPtr anchor) const override; + bool IsTypeOf(const TYPE type) const override; + bool IsTypeIdOf(const TypeId& type) const override; + Anchor::TYPE GetSelfType() const override; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutControlAnchor : public ControlAnchor { + friend class InControlAnchor; + + public: + template + using Vistor = RangeVistor>; + + explicit OutControlAnchor(const NodePtr &owner_node); + + explicit OutControlAnchor(const NodePtr &owner_node, const int32_t idx); + + ~OutControlAnchor() override = default; + + // Get dst in control anchor(one or more) + Vistor GetPeerInControlAnchors() const; + std::vector GetPeerInControlAnchorsPtr() const; + // Get dst data anchor 
in control anchor(one or more) + Vistor GetPeerInDataAnchors() const; + + // Build connection from OutControlAnchor to InControlAnchor + graphStatus LinkTo(const InControlAnchorPtr &dest); + // Build connection from OutDataAnchor to InDataAnchor + graphStatus LinkTo(const InDataAnchorPtr &dest); + + protected: + bool Equal(const AnchorPtr anchor) const override; + bool IsTypeOf(const TYPE type) const override; + bool IsTypeIdOf(const TypeId& type) const override; + Anchor::TYPE GetSelfType() const override; +}; +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId(); + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId(); + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId(); + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId(); + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId(); + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId(); + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId(); +} // namespace ge +#endif // INC_GRAPH_ANCHOR_H_ diff --git a/inc/metadef/graph/any_value.h b/inc/metadef/graph/any_value.h new file mode 100644 index 0000000000000000000000000000000000000000..4144aa6ecfdaa28d7ba539c6fedb43ea9f72634b --- /dev/null +++ b/inc/metadef/graph/any_value.h @@ -0,0 +1,338 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef EXECUTE_GRAPH_ANY_VALUE_H +#define EXECUTE_GRAPH_ANY_VALUE_H +#include +#include +#include + +#include "graph/types.h" +#include "type_utils.h" +#include "external/graph/ge_error_codes.h" +#include "graph/def_types.h" +namespace ge { +class ComputeGraph; +using GeTensorPtr = std::shared_ptr; +using ComputeGraphPtr = std::shared_ptr; +class AnyValue { + public: + // 后续删除,新增代码请勿使用这堆using + using INT = int64_t; + using FLOAT = float; + using BOOL = bool; + using STR = std::string; + using TENSOR = GeTensorPtr; + using TENSOR_DESC = GeTensorDesc; + using GRAPH = ComputeGraphPtr; + using BYTES = Buffer; + using NAMED_ATTRS = ge::NamedAttrs; + using DATA_TYPE = ge::DataType; + + using LIST_INT = std::vector; + using LIST_FLOAT = std::vector; + using LIST_BOOL = std::vector; + using LIST_STR = std::vector; + using LIST_TENSOR = std::vector; + using LIST_TENSOR_DESC = std::vector; + using LIST_GRAPH = std::vector; + using LIST_BYTES = std::vector; + using LIST_NAMED_ATTRS = std::vector; + using LIST_DATA_TYPE = std::vector; + using LIST_LIST_INT = std::vector>; + using LIST_LIST_FLOAT = std::vector>; + using NamedAttrs = ge::NamedAttrs; + // public type definitions + // 这堆ValueType的预定义本质上是反向依赖,AnyValue不应该反向依赖ComputeGraph等数据结构 + // 后续整改掉 + enum ValueType { + VT_NONE = 0, + VT_STRING = 1, + VT_FLOAT = 2, + VT_BOOL = 3, + VT_INT = 4, + VT_TENSOR_DESC = 5, + VT_TENSOR = 6, + VT_BYTES = 7, + VT_GRAPH = 8, + VT_NAMED_ATTRS = 9, + VT_LIST_LIST_INT = 10, + VT_DATA_TYPE = 11, + VT_LIST_LIST_FLOAT = 12, + + VT_LIST_BASE = 1000, + VT_LIST_STRING = static_cast(VT_LIST_BASE) + static_cast(VT_STRING), + VT_LIST_FLOAT = static_cast(VT_LIST_BASE) + static_cast(VT_FLOAT), + VT_LIST_BOOL = static_cast(VT_LIST_BASE) + static_cast(VT_BOOL), + VT_LIST_INT = static_cast(VT_LIST_BASE) + static_cast(VT_INT), + VT_LIST_TENSOR_DESC = static_cast(VT_LIST_BASE) + 
static_cast(VT_TENSOR_DESC), + VT_LIST_TENSOR = static_cast(VT_LIST_BASE) + static_cast(VT_TENSOR), + VT_LIST_BYTES = static_cast(VT_LIST_BASE) + static_cast(VT_BYTES), + VT_LIST_GRAPH = static_cast(VT_LIST_BASE) + static_cast(VT_GRAPH), + VT_LIST_NAMED_ATTRS = static_cast(VT_LIST_BASE) + static_cast(VT_NAMED_ATTRS), + VT_LIST_DATA_TYPE = static_cast(VT_LIST_BASE) + static_cast(VT_DATA_TYPE), + }; + + AnyValue() = default; + AnyValue(AnyValue &&other) noexcept; + AnyValue(const AnyValue &other) { + if (!other.IsEmpty()) { + other.operate_(OperateType::kOpClone, &other, this); + } + } + AnyValue &operator=(AnyValue &&other) noexcept; + AnyValue &operator=(const AnyValue &other); + ~AnyValue() noexcept { + Clear(); + } + + template + static AnyValue CreateFrom(T &&value); + // 如果只有万能引用,那么Set(左值)这种调用方法会出错,因此有了这个函数 + template + static AnyValue CreateFrom(const T &value); + + template + graphStatus SetValue(T &&value); + + // 如果只有万能引用,那么Set(左值)这种调用方法会出错,因此有了这个函数 + template + graphStatus SetValue(const T &value); + + template + graphStatus SetValue(std::initializer_list values); + + template + graphStatus GetValue(T &value) const; + template + const T *Get() const; + template + T *MutableGet(); + + template + bool SameType() const noexcept; + + void Swap(AnyValue &other) noexcept; + + void Clear() { + if (operate_ == nullptr) { + return; + } + operate_(OperateType::kOpClear, nullptr, this); + } + + bool IsEmpty() const noexcept { + return operate_ == nullptr; + } + + ValueType GetValueType() const noexcept; + TypeId GetValueTypeId() const noexcept; + AnyValue Copy() const; + + private: + template + void InnerSet(T &&value); + const void *GetAddr() const; + + enum class OperateType { kOpClear, kOpGetAddr, kOpClone, kOpMove, kGetTypeId, kOperateTypeEnd }; + + template + struct InlineOperations { + static void Operate(const OperateType ot, const AnyValue *const av, void *const out); + static void Construct(const T &value, AnyValue *const av); + static void Construct(T 
&&value, AnyValue *const av); + }; + + template + struct AllocateOperations { + static void Operate(const OperateType ot, const AnyValue *const av, void *const out); + static void Construct(const T &value, AnyValue *const av); + static void Construct(T &&value, AnyValue *const av); + }; + + using ValueBuf = std::aligned_storage::type; + using ValueHolder = union { + void *pointer; + std::aligned_storage::type inline_buf; + }; + ValueHolder holder_ = {nullptr}; + + void (*operate_)(OperateType ot, const AnyValue *av, void *out){nullptr}; +}; +using GeAttrValue = AnyValue; + +template +void AnyValue::AllocateOperations::Construct(const T &value, AnyValue *const av) { + av->holder_.pointer = new (std::nothrow) T(value); + av->operate_ = AnyValue::AllocateOperations::Operate; +} +template +void AnyValue::AllocateOperations::Construct(T &&value, AnyValue *const av) { + av->holder_.pointer = ::new (std::nothrow) T(std::forward(value)); + av->operate_ = AnyValue::AllocateOperations::Operate; +} +template +void AnyValue::AllocateOperations::Operate(const AnyValue::OperateType ot, const AnyValue *const av, + void *const out) { + switch (ot) { + case OperateType::kOpClear: { + auto *const av_p = PtrToPtr(out); + delete PtrToPtr(av_p->holder_.pointer); + av_p->holder_.pointer = nullptr; + av_p->operate_ = nullptr; + break; + } + case OperateType::kOpGetAddr: + *PtrToPtr(out) = const_cast(av->holder_.pointer); + break; + case OperateType::kOpClone: + PtrToPtr(out)->holder_.pointer = + new (std::nothrow) T(*PtrToPtr(av->holder_.pointer)); + PtrToPtr(out)->operate_ = av->operate_; + break; + case OperateType::kOpMove: { + auto *const av_p = PtrToPtr(out); + av_p->holder_.pointer = av->holder_.pointer; + av_p->operate_ = av->operate_; + const_cast(av)->holder_.pointer = nullptr; + break; + } + case OperateType::kGetTypeId: + *PtrToPtr(out) = GetTypeId(); + break; + default: + break; + } +} +template +void AnyValue::InlineOperations::Construct(const T &value, AnyValue * const av) 
{ + (void)::new (&(av->holder_.inline_buf)) T(value); + av->operate_ = AnyValue::InlineOperations::Operate; +} +template +void AnyValue::InlineOperations::Construct(T &&value, AnyValue *const av) { + Construct(value, av); +} +template +void AnyValue::InlineOperations::Operate(const AnyValue::OperateType ot, const AnyValue *const av, + void *const out) { + switch (ot) { + case OperateType::kOpClear: { + auto *const av_p = PtrToPtr(out); + PtrToPtr::type, T>(&av_p->holder_.inline_buf)->~T(); + av_p->operate_ = nullptr; + break; + } + case OperateType::kOpGetAddr: + *PtrToPtr(out) = const_cast(PtrToPtr(&av->holder_.inline_buf)); + break; + case OperateType::kOpClone: { + auto *const av_p = PtrToPtr(out); + (void)new (&av_p->holder_.inline_buf) T(*PtrToPtr::type, + const T>(&av->holder_.inline_buf)); + av_p->operate_ = av->operate_; + break; + } + case OperateType::kOpMove: { + auto *const av_p = PtrToPtr(out); + auto *const moved_t_p = const_cast(PtrToPtr(&av->holder_.inline_buf)); + (void)new (&av_p->holder_.inline_buf) T(std::move(*moved_t_p)); + av_p->operate_ = av->operate_; + break; + } + case OperateType::kGetTypeId: + *PtrToPtr(out) = GetTypeId(); + break; + default: + break; + } +} + +template +AnyValue AnyValue::CreateFrom(T &&value) { + AnyValue av; + av.InnerSet(std::forward(value)); + return av; +} +template +AnyValue AnyValue::CreateFrom(const T &value) { + AnyValue av; + av.InnerSet(value); + return av; +} +template +void AnyValue::InnerSet(T &&value) { + using PureT = typename std::remove_cv::type>::type; + using Inline = std::integral_constant; + using Operations = + typename std::conditional, AnyValue::AllocateOperations>::type; + + Operations::Construct(std::forward(value), this); +} +template +graphStatus AnyValue::SetValue(T &&value) { + Clear(); + InnerSet(std::forward(value)); + return GRAPH_SUCCESS; +} +template +graphStatus AnyValue::SetValue(const T &value) { + Clear(); + InnerSet(value); + return GRAPH_SUCCESS; +} +template +graphStatus 
AnyValue::SetValue(std::initializer_list values) { + Clear(); + InnerSet(std::vector(std::move(values))); + return GRAPH_SUCCESS; +} +template +const T *AnyValue::Get() const { + if (!SameType()) { + return nullptr; + } + if (IsEmpty()) { + return nullptr; + } + return PtrToPtr(GetAddr()); +} +template +graphStatus AnyValue::GetValue(T &value) const { + auto *const p = Get(); + if (p == nullptr) { + return GRAPH_FAILED; + } + value = *p; + return GRAPH_SUCCESS; +} +template +T *AnyValue::MutableGet() { + if (!SameType()) { + return nullptr; + } + if (IsEmpty()) { + return nullptr; + } + void *addr = nullptr; + operate_(OperateType::kOpGetAddr, this, &addr); + return PtrToPtr(addr); +} +template +bool AnyValue::SameType() const noexcept { + if (operate_ == nullptr) { + return false; + } + TypeId tid = kInvalidTypeId; + operate_(OperateType::kGetTypeId, this, &tid); + return tid == GetTypeId(); +} +} // namespace ge + +#endif // EXECUTE_GRAPH_ANY_VALUE_H diff --git a/inc/metadef/graph/args_format_desc.h b/inc/metadef/graph/args_format_desc.h new file mode 100644 index 0000000000000000000000000000000000000000..1f32912ae456db5a40cde29d00af88ca5a0791b3 --- /dev/null +++ b/inc/metadef/graph/args_format_desc.h @@ -0,0 +1,58 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef METADEF_CXX_ARGS_FORMAT_H +#define METADEF_CXX_ARGS_FORMAT_H + +#include +#include + +#include "common/ge_common/debug/ge_log.h" +#include "graph/ge_error_codes.h" +#include "graph/op_desc.h" +#include "register/hidden_inputs_func_registry.h" +#include "graph/utils/args_format_desc_utils.h" + +namespace ge { + +class ArgsFormatDesc { + public: + // i* -> ir_idx = -1, folded=false + // 对于输入输出,idx表示ir定义的idx,-1表示所有输入、所有输出,此时非动态输入、输出默认展开,动态输出要i1*这样才表示展开 + // 对于workspace -1表示个数未知,folded暂时无意义 + // 对ffts尾块非尾块地址,idx=0表示非尾块,idx=1表示尾块 + // 对于其他类型, idx和fold 暂时没有意义 + void Append(AddrType type, int32_t ir_idx = -1, bool folded = false); + + void Clear(); + + void AppendTilingContext(TilingContextSubType sub_type = TilingContextSubType::TILING_CONTEXT); + + std::string ToString() const; + + graphStatus GetArgsSize(const OpDescPtr &op_desc, size_t &args_size) const; + + static graphStatus GetArgSize(const OpDescPtr &op_desc, const ArgDesc arg_desc, size_t &arg_size); + + static graphStatus Parse(const OpDescPtr &op_desc, const std::string &str, std::vector &arg_descs); + + // 为了方便使用,字符串用i*这样的通配符时,返回的argDesc会按照实际个数展开 + // easy mode 不需要进行展开,只根据字面值做反序列化,允许不传op_desc + static graphStatus Parse(const OpDescPtr &op_desc, const std::string &str, std::vector &arg_descs, + const bool easy_mode); + + // 抽取公共序列化/反序列化函数 + static std::string Serialize(const std::vector &arg_descs); + + private: + std::vector arg_descs_; +}; +} // namespace ge + +#endif // METADEF_CXX_ARGS_FORMAT_H diff --git a/inc/metadef/graph/ascend_limits.h b/inc/metadef/graph/ascend_limits.h new file mode 100644 index 0000000000000000000000000000000000000000..c923f20b42890dc81b195b2f21cf9c8977285f0a --- /dev/null +++ b/inc/metadef/graph/ascend_limits.h @@ -0,0 +1,21 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_CXX_ASCEND_LIMITS_H +#define METADEF_CXX_ASCEND_LIMITS_H +#include +#include + +namespace ge { +constexpr uint32_t kDefaultMaxInputNum = 8U; +constexpr uint32_t kDefaultMaxOutputNum = 8U; +constexpr size_t kDefaultDimsNum = 8U; +constexpr size_t kMaxNameLen = 1024UL; +} +#endif // METADEF_CXX_ASCEND_LIMITS_H diff --git a/inc/metadef/graph/attr_store.h b/inc/metadef/graph/attr_store.h new file mode 100644 index 0000000000000000000000000000000000000000..11c572e784be1624b60cf353e4aac5f30393aea0 --- /dev/null +++ b/inc/metadef/graph/attr_store.h @@ -0,0 +1,322 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/
+
+#ifndef EXECUTE_GRAPH_ATTR_STORE_H
+#define EXECUTE_GRAPH_ATTR_STORE_H
+#include
+#include
+#include
+#include
+#include
+#include "any_value.h"
+
+class AttrGroupsDef;
+
+namespace ge {
+using AttrId = uint64_t;  // packed id: attribute type in the high 32 bits, sub id in the low 32 bits (see GetAttrId)
+using AttrSubId = uint32_t;
+using AttrNameFilter = std::function;
+enum class AttrType : uint32_t {
+  kAttrPredefinedInIr = 0U,  // attribute predefined in the IR
+  kAttrGeneral = 1U,         // general attribute
+  kAttrTypeEnd = 2U
+};
+constexpr inline uint32_t GetAttrType(const AttrId id) {  // high 32 bits of the id
+  return static_cast(id >> 32U);
+}
+constexpr inline uint32_t GetSubAttrId(const AttrId id) {  // low 32 bits of the id
+  return static_cast(id & 0xffffffffU);
+}
+constexpr inline AttrId GetAttrId(const uint32_t type, const uint32_t sub_id) {  // inverse of the two getters above
+  return (static_cast(type) << 32U) | static_cast(sub_id);
+}
+
+struct AttrGroupsBase {  // polymorphic base of the attrs group an AttrStore can own; every default below is a no-op
+  AttrGroupsBase() {}
+
+  virtual ~AttrGroupsBase() {}
+
+  virtual graphStatus Serialize(AttrGroupsDef &attr_groups) {  // default: ignores the argument, reports success
+    (void)attr_groups;
+    return GRAPH_SUCCESS;
+  }
+
+  virtual graphStatus Deserialize(const AttrGroupsDef &attr_groups) {  // default: ignores the argument, reports success
+    (void)attr_groups;
+    return GRAPH_SUCCESS;
+  }
+
+  virtual AttrGroupsBase *CreateNew() {  // default: creates nothing
+    return nullptr;
+  }
+};
+
+struct OtherAttrs {
+ public:
+  /* Whether the attr is in the whitelist. */
+  bool CheckAttrIsValid(const std::string &attr) const;
+
+  graphStatus SetAttr(const std::string &attr, const AnyValue &value);
+
+  graphStatus GetAttr(const std::string &attr, AnyValue &value) const;
+
+  const std::unordered_map &FastGetAllAttr() const;  // by reference, unlike GetAllAttr() which returns by value
+
+  std::unordered_map GetAllAttr() const;
+
+  /* Whether the attr is already saved in this other-attr group. */
+  bool CheckAttrIsExist(const std::string &attr) const;
+
+  bool DeleteSingleAttr(const std::string &attr);
+
+  void DeleteAllAttrs();
+
+  void Swap(OtherAttrs &other);
+
+ private:
+  /* Whitelist of attr names. */
+  static const std::unordered_set valid_attrs_;
+  std::unordered_map keys_to_attrs_;
+};
+// Holds predefined attrs (by sub id), general attrs (by name), whitelisted other attrs and one optional attrs group.
+class AttrStore {
+ public:
+  static AttrStore Create(const size_t pre_defined_attr_count);
+
+  AttrStore() = default;
+  ~AttrStore() {  // releases the owned attrs group
+    ClearAttrsGroups();
+  }
+
+  AttrStore(const AttrStore &other);
+  AttrStore &operator=(const AttrStore &other);
+
+  AttrStore(AttrStore &&other);
+  AttrStore &operator=(AttrStore &&other);
+  // NOTE(review): both Set overloads are const although they go through GetOrCreateAnyValue(attr_id).
+  template
+  bool Set(const AttrId attr_id, T &&value) const;
+  template
+  bool Set(const AttrId attr_id, const T &value) const;
+  template
+  bool SetByName(const std::string &name, T &&value);
+  template
+  bool SetByName(const std::string &name, const T &value);
+  // The getters below return nullptr when no AnyValue is found for the id/name.
+  template
+  const T *Get(const AttrId attr_id) const;
+  template
+  T *MutableGet(const AttrId attr_id);
+  template
+  const T *GetByName(const std::string &name) const;
+  template
+  T *MutableGetByName(const std::string &name);
+
+  AttrId GetIdByName(const std::string &name) const noexcept;
+  void SetNameAndId(std::string name, const AttrId id);
+
+  bool Exists(const AttrId attr_id) const noexcept;
+  bool Exists(const std::string &name) const noexcept;
+
+  bool Delete(const std::string &name);
+  void Clear();
+
+  void Swap(AttrStore &other);
+  bool SetAnyValueByName(const std::string &name, const AnyValue &value);
+
+  // set/map are kept (over unordered ones) for old-API compatibility; whichever is used, these calls are very inefficient.
+  std::set GetAllAttrNames() const;
+  std::map GetAllAttrs() const;
+  std::map GetAllAttrsWithFilter(const AttrNameFilter &attr_filter) const;
+
+  AnyValue *MutableAnyValue(const std::string &name) const noexcept;
+  AnyValue *GetOrCreateAnyValue(const std::string &name);
+  const AnyValue *GetAnyValue(const std::string &name) const noexcept;
+  // The *OtherGroup functions presumably forward to the OtherAttrs member (other_attrs_) - confirm in the .cc.
+  graphStatus SetAttrToOtherGroup(const std::string &attr, const AnyValue &value);
+  graphStatus GetAttrFromOtherGroup(const std::string &attr, AnyValue &value) const;
+  const std::unordered_map &FastGetAllAttrsFromOtherGroup() const;
+  std::unordered_map GetAllAttrsFromOtherGroup() const;
+  bool CheckAttrIsExistInOtherGroup(const std::string &attr) const;
+  bool DeleteSingleAttrsInOtherGroup(const std::string &attr);
+  void DeleteAllAttrsInOtherGroup();
+
+  template
+  T *GetOrCreateAttrsGroup();
+  template
+  bool CheckAttrGroupIsExist() const;
+
+  // Returns a pair whose bool is true only if the accompanying T& is valid, i.e. an attrs group is stored.
+  template
+  std::pair GetAttrsGroup() const;
+
+  AttrGroupsBase *GetAttrsGroupPtr() const;
+
+  void ClearAllAttrs();
+  void ClearAllAttrsInOtherAttrs();
+  bool ClearAttrInOtherAttrs(const std::string &attr_name);
+
+  void ClearAttrsGroups() {  // frees the owned attrs group, if any, and resets the pointer
+    if (attrs_group_ptr_ != nullptr) {
+      delete attrs_group_ptr_;
+      attrs_group_ptr_ = nullptr;
+    }
+  }
+
+ private:
+  AnyValue *MutableAnyValue(const AttrId attr_id) const noexcept;
+  AnyValue *GetOrCreateAnyValue(const AttrId attr_id) const;
+  const AnyValue *GetAnyValue(const AttrId attr_id) const noexcept;
+  void CopyAttrStoreAllMembers(const AttrStore &other);
+
+  class PreDefinedAttrStore {  // attrs addressed by AttrSubId index
+   public:
+    bool Exists(const AttrSubId index) const noexcept;
+    bool Delete(const AttrSubId index);
+    void Clear();
+    void Swap(PreDefinedAttrStore &other);
+
+    AnyValue *GetOrCreateAnyValue(const AttrSubId index) const;
+    AnyValue *MutableAnyValue(const AttrSubId index) const noexcept;
+    const AnyValue *GetAnyValue(const AttrSubId index) const noexcept;
+
+    void Resize(const size_t s);
+
+   private:
+    std::vector attrs_;
+  };
+
+  class CustomDefinedAttrStore {  // attrs addressed by name
+   public:
+    bool Exists(const std::string &name) const noexcept;
+    bool Delete(const std::string &name);
+    void Clear();
+    void Swap(CustomDefinedAttrStore &other);
+
+    AnyValue *GetOrCreateAnyValue(const std::string &name);
+    AnyValue *MutableAnyValue(const std::string &name) const noexcept;
+    const AnyValue *GetAnyValue(const std::string &name) const noexcept;
+
+    void GetAllNames(std::set &names) const;
+    void GetAllAttrs(std::map &names_to_attr) const;
+    void GetAllAttrsWithFilter(std::map &names_to_attr, const AttrNameFilter &attr_filter) const;
+
+   private:
+    std::unordered_map attrs_;
+  };
+
+  std::unordered_map names_to_id_;
+  // Better: a virtual base with two derived classes, storing both pointers: `std::array, kAttrTypeEnd>`,
+  // then dispatching to the matching subclass by SubAttr type. But each AttrStore would then cost two heap allocations,
+  // so to reduce heap allocation the subclasses are laid out directly as member variables.
+  PreDefinedAttrStore pre_defined_attrs_;
+  CustomDefinedAttrStore general_attrs_;
+  OtherAttrs other_attrs_;
+  AttrGroupsBase *attrs_group_ptr_ = nullptr;  // owned; released by ClearAttrsGroups()
+};
+
+template
+bool AttrStore::Set(const AttrId attr_id, const T &value) const {  // copy overload
+  auto *const v = GetOrCreateAnyValue(attr_id);
+  if (v == nullptr) {
+    return false;  // no AnyValue could be obtained for this id
+  }
+  (void)v->SetValue(value);  // SetValue's result is deliberately discarded
+  return true;
+}
+template
+bool AttrStore::Set(const AttrId attr_id, T &&value) const {  // forwarding overload
+  auto *const v = GetOrCreateAnyValue(attr_id);
+  if (v == nullptr) {
+    return false;
+  }
+  (void)v->SetValue(std::forward(value));
+  return true;
+}
+template
+bool AttrStore::SetByName(const std::string &name, T &&value) {  // forwarding overload
+  auto *const v = GetOrCreateAnyValue(name);
+  if (v == nullptr) {
+    return false;
+  }
+  (void)v->SetValue(std::forward(value));
+  return true;
+}
+template
+bool AttrStore::SetByName(const std::string &name, const T &value) {  // copy overload
+  auto *const v = GetOrCreateAnyValue(name);
+  if (v == nullptr) {
+    return false;
+  }
+  (void)v->SetValue(value);
+  return true;
+}
+
+template
+const T *AttrStore::Get(const AttrId attr_id) const {
+  auto *const v = GetAnyValue(attr_id);
+  if (v == nullptr) {
+    return nullptr;  // nothing stored under this id
+  }
+  return v->Get();
+}
+template
+const T *AttrStore::GetByName(const std::string &name) const {
+  auto *const v = GetAnyValue(name);
+  if (v == nullptr) {
+    return nullptr;  // nothing stored under this name
+  }
+  return v->Get();
+}
+
+template
+T *AttrStore::MutableGet(const AttrId attr_id) {
+  auto *const v = MutableAnyValue(attr_id);
+  if (v == nullptr) {
+    return nullptr;
+  }
+  return v->MutableGet();
+}
+template
+T *AttrStore::MutableGetByName(const std::string &name) {
+  auto *const v = MutableAnyValue(name);
+  if (v == nullptr) {
+    return nullptr;
+  }
+  return v->MutableGet();
+}
+
+template
+T *AttrStore::GetOrCreateAttrsGroup() {
+  if (attrs_group_ptr_ == nullptr) {
+    attrs_group_ptr_ = new (std::nothrow) T;  // nothrow: the pointer may stay nullptr on allocation failure
+  }
+  // An already stored group is reused, never replaced; the result is whatever the cast yields for it.
+  return dynamic_cast(attrs_group_ptr_);
+}
+
+template
+std::pair AttrStore::GetAttrsGroup() const {
+  if (attrs_group_ptr_ == nullptr) {
+    static T t;  // fallback object handed out (with false) when no group is stored
+    return std::make_pair(false, t);
+  }
+
+  return std::make_pair(true, *attrs_group_ptr_);
+}
+
+template
+bool AttrStore::CheckAttrGroupIsExist() const {
+  T *ptr = dynamic_cast(attrs_group_ptr_);
+  return (ptr == nullptr) ? false : true;  // true only when the cast of the stored group succeeds
+}
+
+}  // namespace ge
+
+#endif  // EXECUTE_GRAPH_ATTR_STORE_H
diff --git a/inc/metadef/graph/attr_value_serializable.h b/inc/metadef/graph/attr_value_serializable.h
new file mode 100644
index 0000000000000000000000000000000000000000..ee065807ead3d70b141dc57ad65a2c8e32b2622a
--- /dev/null
+++ b/inc/metadef/graph/attr_value_serializable.h
@@ -0,0 +1,17 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_ +#define INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_ + +#include +#include "graph/ge_attr_value.h" +#include "graph/compiler_options.h" + +#endif // INC_GRAPH_ATTR_VALUE_SERIALIZABLE_H_ diff --git a/inc/metadef/graph/axis_type_info.h b/inc/metadef/graph/axis_type_info.h new file mode 100644 index 0000000000000000000000000000000000000000..ad939ed95185118ec9c8404e428fb823ba3b6ab0 --- /dev/null +++ b/inc/metadef/graph/axis_type_info.h @@ -0,0 +1,82 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/
+
+#ifndef INC_GRAPH_AXIS_TYPE_INFO_H_
+#define INC_GRAPH_AXIS_TYPE_INFO_H_
+
+#include
+#include
+#include
+
+#include "external/graph/ge_error_codes.h"
+
+namespace ge {
+using CutInfo = std::pair>;
+
+enum class AxisType {  // UNSPLIT is the value a default-constructed AxisTypeInfo starts with
+  UNSPLIT = 0,
+  ELEMENTWISE,
+  REDUCESUM,
+  REDUCEMAX,
+  REDUCEMIN,
+  REDUCEMEAN,
+  TRANSPOSE,
+  SLIDINGWINDOW,
+  REDUCEGATHER,
+  SLIDINGWINDOWGRAD,
+  ELEMENTWITHSHAPEVALUE,
+  REDUCEPROD,
+  ELEMENTWISEWITHFACTOR,
+};
+
+class AxisTypeInfo {  // one axis type plus the input/output cut infos recorded for it
+public:
+  AxisTypeInfo() = default;
+  ~AxisTypeInfo() = default;
+  void SetAxisType(const AxisType axis_type) { axis_type_ = axis_type; }
+  AxisType GetAxisType() const { return axis_type_; }
+  void AddInputCutInfo(CutInfo &input_cut_info);  // NOTE(review): non-const ref, unlike AddInputValueCutInfo below
+  void AddOutputCutInfo(CutInfo &output_cut_info);  // NOTE(review): non-const ref, unlike AddOutputValueCutInfo below
+  graphStatus GetInputCutInfo(const size_t index, CutInfo &input_cut_info) const;  // result via out-parameter
+  graphStatus GetOutputCutInfo(const size_t index, CutInfo &output_cut_info) const;  // result via out-parameter
+  const std::vector &GetRelateInputs() const { return relate_inputs_; }
+  const std::vector &GetRelateOutputs() const { return relate_outputs_; }
+  void SetRelateInputs(const std::vector &inputs_info) { relate_inputs_ = inputs_info; }
+  void SetRelateOutputs(const std::vector &outputs_info) { relate_outputs_ = outputs_info; }
+  const std::vector &GetOriRelateInputs() const { return ori_relate_inputs_; }
+  const std::vector &GetOriRelateOutputs() const { return ori_relate_outputs_; }
+  void SetOriRelateInputs(const std::vector &inputs_info) { ori_relate_inputs_ = inputs_info; }
+  void SetOriRelateOutputs(const std::vector &outputs_info) { ori_relate_outputs_ = outputs_info; }
+  void SetAxisTypes(const std::vector &axis_types) { axis_types_ = axis_types; }
+  const std::vector &GetAxisTypes() const { return axis_types_; }
+  void AddInputValueCutInfo(const CutInfo &cut_info);
+  void AddOutputValueCutInfo(const CutInfo &cut_info);
+  graphStatus GetInputValueCutInfo(const size_t index, CutInfo &cut_info) const;  // result via out-parameter
+  graphStatus GetOutputValueCutInfo(const size_t index, CutInfo &cut_info) const;  // result via out-parameter
+  const std::vector &GetRelateInputValues() const { return relate_input_values_; }
+  const std::vector &GetRelateOutputValues() const { return relate_output_values_; }
+  void SetAdditionInfo(const std::map &addition_info) { addition_info_ = addition_info; }
+  const std::map &GetAdditionInfo() const { return addition_info_; }
+
+private:
+  static graphStatus DoGetCutInfo(const std::vector &cut_infos, const size_t index, CutInfo &cut_info);
+  // DoGetCutInfo presumably backs the four Get*CutInfo accessors above - confirm in the .cc.
+  AxisType axis_type_ = AxisType::UNSPLIT;
+  std::vector relate_inputs_;
+  std::vector relate_outputs_;
+  std::vector axis_types_;
+  std::vector relate_input_values_;
+  std::vector relate_output_values_;
+  std::vector ori_relate_inputs_;  // backup of relate_inputs_
+  std::vector ori_relate_outputs_;  // backup of relate_outputs_
+  std::map addition_info_;
+};
+}  // namespace ge
+
+#endif  // INC_GRAPH_AXIS_TYPE_INFO_H_
diff --git a/inc/metadef/graph/buffer.h b/inc/metadef/graph/buffer.h
new file mode 100644
index 0000000000000000000000000000000000000000..b830fd5f98b185f238db17dd4eef7ec369951e15
--- /dev/null
+++ b/inc/metadef/graph/buffer.h
@@ -0,0 +1,69 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_GRAPH_BUFFER_H_
+#define INC_GRAPH_BUFFER_H_
+
+#include
+#include
+#include
+#include
+#include "detail/attributes_holder.h"
+#include "graph/compiler_options.h"
+
+namespace ge {
+class BufferImpl;
+using BufferImplPtr = std::shared_ptr;
+
+class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer {  // byte buffer whose state lives behind impl_
+ public:
+  Buffer();
+  Buffer(const Buffer &other);
+
+  explicit Buffer(const std::size_t buffer_size, const std::uint8_t default_val = 0U);
+
+  ~Buffer();
+
+  Buffer &operator=(const Buffer &other);
+  static Buffer CopyFrom(const std::uint8_t *const data, const std::size_t buffer_size);
+
+  const std::uint8_t *GetData() const;
+  std::uint8_t *GetData();
+  std::size_t GetSize() const;
+  void ClearBuffer();
+
+  // Lower-case names kept for compatibility; presumably aliases of GetData/GetSize/ClearBuffer - confirm in the .cc.
+  const std::uint8_t *data() const;
+  std::uint8_t *data();
+  std::size_t size() const;
+  void clear();
+  uint8_t operator[](const size_t index) const;  // returns the byte by value; no mutable overload is declared
+
+ private:
+  BufferImplPtr impl_;
+
+  // Construct from a protobuf-owned object; private, so only reachable by the friend classes below.
+  Buffer(const ProtoMsgOwner &proto_owner, proto::AttrDef *const buffer);
+  Buffer(const ProtoMsgOwner &proto_owner, std::string *const buffer);
+
+  friend class GeAttrValueImp;
+  friend class GeTensor;
+  friend class BufferUtils;
+};
+
+class BufferUtils {  // static helpers; a friend of Buffer, so they can reach its private state
+ public:
+  static Buffer CreateShareFrom(const Buffer &other);
+  static Buffer CreateCopyFrom(const Buffer &other);
+  static Buffer CreateCopyFrom(const std::uint8_t *const data, const std::size_t buffer_size);
+  static void ShareFrom(const Buffer &from, Buffer &to);
+  static void CopyFrom(const Buffer &from, Buffer &to);
+};
+}  // namespace ge
+#endif  // INC_GRAPH_BUFFER_H_
diff --git a/inc/metadef/graph/cache_policy/aging_policy.h b/inc/metadef/graph/cache_policy/aging_policy.h
new file mode 100644
index 0000000000000000000000000000000000000000..397b532fe03ce2ed04c0ecbd54453e93f6f35335
--- /dev/null
+++ 
b/inc/metadef/graph/cache_policy/aging_policy.h @@ -0,0 +1,28 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef GRAPH_CACHE_POLICY_POLICY_MANAGEMENT_AGING_POLICY_H_ +#define GRAPH_CACHE_POLICY_POLICY_MANAGEMENT_AGING_POLICY_H_ +#include "graph/cache_policy/cache_state.h" + +namespace ge { +constexpr const size_t kDefaultCacheQueueDepth = 1U; +class AgingPolicy { + public: + AgingPolicy() = default; + virtual ~AgingPolicy() = default; + virtual void SetCachedAgingDepth(size_t depth) = 0; + virtual std::vector DoAging(const CacheState &cache_state) const = 0; + virtual bool IsReadyToAddCache(const CacheHashKey hash_key, const CacheDescPtr &cache_desc) = 0; + private: + AgingPolicy &operator=(const AgingPolicy &anging_polocy) = delete; + AgingPolicy(const AgingPolicy &anging_polocy) = delete; +}; +} +#endif diff --git a/inc/metadef/graph/cache_policy/aging_policy_lru.h b/inc/metadef/graph/cache_policy/aging_policy_lru.h new file mode 100644 index 0000000000000000000000000000000000000000..5eb7f804a92c6ae00b5e3a2e3c4656fcbc937d8b --- /dev/null +++ b/inc/metadef/graph/cache_policy/aging_policy_lru.h @@ -0,0 +1,41 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef GRAPH_CACHE_POLICY_POLICY_MANAGEMENT_AGING_POLICY_LRU_H_ +#define GRAPH_CACHE_POLICY_POLICY_MANAGEMENT_AGING_POLICY_LRU_H_ +#include "graph/cache_policy/aging_policy.h" +#include "graph/cache_policy/policy_register.h" + +namespace ge { +class AgingPolicyLru : public AgingPolicy { +public: + virtual ~AgingPolicyLru() override = default; + void SetDeleteInterval(const uint64_t &interval) { + delete_interval_ = interval; + } + void SetCachedAgingDepth(size_t depth) override { + (void)depth; + } + bool IsReadyToAddCache(const CacheHashKey hash_key, const CacheDescPtr &cache_desc) override { + (void) hash_key; + (void) cache_desc; + return true; + } + std::vector DoAging(const CacheState &cache_state) const override; + +private: + uint64_t delete_interval_ = 0U; +}; + +REGISTER_AGING_POLICY_CREATOR(AgingPolicyType::AGING_POLICY_LRU, + []() { + return make_shared(); + }); +} // namespace ge +#endif diff --git a/inc/metadef/graph/cache_policy/aging_policy_lru_k.h b/inc/metadef/graph/cache_policy/aging_policy_lru_k.h new file mode 100644 index 0000000000000000000000000000000000000000..bc5439b9d70520c4d50e5bbb8aa1dd3d039ce9f2 --- /dev/null +++ b/inc/metadef/graph/cache_policy/aging_policy_lru_k.h @@ -0,0 +1,43 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef METADEF_CXX_GRAPH_CACHE_POLICY_AGING_POLICY_LRU_K_H
+#define METADEF_CXX_GRAPH_CACHE_POLICY_AGING_POLICY_LRU_K_H
+#include "graph/cache_policy/aging_policy.h"
+#include "graph/cache_policy/policy_register.h"
+
+namespace ge {
+class AgingPolicyLruK : public AgingPolicy {  // LRU-K flavour of AgingPolicy; DoAging is defined out of line
+ public:
+  AgingPolicyLruK() : depth_(kDefaultCacheQueueDepth) {}  // k_times_ keeps its in-class default of 2
+  explicit AgingPolicyLruK(size_t depth) : depth_(depth) {}  // k_times_ keeps its in-class default of 2
+  AgingPolicyLruK(size_t k_times, size_t depth) : k_times_(k_times), depth_(depth) {}
+  ~AgingPolicyLruK() override = default;
+
+  void SetCachedAgingDepth(size_t depth) override {
+    depth_ = depth;
+  }
+  bool IsReadyToAddCache(const CacheHashKey hash_key, const CacheDescPtr &cache_desc) override {
+    return IsCacheDescAppearKTimes(hash_key, cache_desc);  // readiness is decided entirely by the helper
+  }
+  std::vector DoAging(const CacheState &cache_state) const override;
+
+ private:
+  bool IsCacheDescAppearKTimes(const CacheHashKey hash_key, const CacheDescPtr &cache_desc);
+ private:  // NOTE(review): redundant second access specifier
+  size_t k_times_ = 2U;
+  size_t depth_;  // no in-class default; every constructor above initializes it
+  // TODO: aging of the history cache queue
+  std::mutex hash_2_cache_descs_and_count_mu_;  // presumably guards the map below - confirm in the .cc
+  std::unordered_map>> hash_2_cache_descs_and_count_;
+};
+REGISTER_AGING_POLICY_CREATOR(AgingPolicyType::AGING_POLICY_LRU_K,
+                              []() { return std::make_shared(); });
+}  // namespace ge
+#endif  // METADEF_CXX_GRAPH_CACHE_POLICY_AGING_POLICY_LRU_K_H
diff --git a/inc/metadef/graph/cache_policy/cache_desc.h 
b/inc/metadef/graph/cache_policy/cache_desc.h
new file mode 100644
index 0000000000000000000000000000000000000000..012f2005161e3439b8b15058da8af2432c95e252
--- /dev/null
+++ b/inc/metadef/graph/cache_policy/cache_desc.h
@@ -0,0 +1,27 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef GRAPH_CACHE_POLICY_CACHE_DESC_H
+#define GRAPH_CACHE_POLICY_CACHE_DESC_H
+
+#include
+#include "graph/utils/hash_utils.h"
+namespace ge {
+class CacheDesc;
+using CacheDescPtr = std::shared_ptr;
+class CacheDesc {  // abstract cache descriptor: comparable with another descriptor and hashable
+ public:
+  CacheDesc() = default;
+  virtual ~CacheDesc() = default;
+  virtual bool IsEqual(const CacheDescPtr &other) const = 0;
+  virtual bool IsMatch(const CacheDescPtr &other) const = 0;  // how this differs from IsEqual is up to implementers
+  virtual CacheHashKey GetCacheDescHash() const = 0;
+};
+}  // namespace ge
+#endif
diff --git a/inc/metadef/graph/cache_policy/cache_policy.h b/inc/metadef/graph/cache_policy/cache_policy.h
new file mode 100644
index 0000000000000000000000000000000000000000..57708694b38a12d36cdf7730ce902b13da17f044
--- /dev/null
+++ b/inc/metadef/graph/cache_policy/cache_policy.h
@@ -0,0 +1,55 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. 
You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef GRAPH_CACHE_POLICY_CACHE_POLICY_H_
+#define GRAPH_CACHE_POLICY_CACHE_POLICY_H_
+
+#include
+#include
+#include "cache_state.h"
+#include "policy_register.h"
+#include "graph/ge_error_codes.h"
+
+namespace ge {
+class CachePolicy {  // combines a CacheState with one match policy (mp_) and one aging policy (ap_)
+ public:
+  ~CachePolicy() = default;
+  // Neither copyable nor movable.
+  CachePolicy(const CachePolicy &) = delete;
+  CachePolicy(CachePolicy &&) = delete;
+  CachePolicy &operator=(const CachePolicy &) = delete;
+  CachePolicy &operator=(CachePolicy &&) = delete;
+  // Factories: either pass the policy objects directly, or pass policy types plus an optional aging depth.
+  static std::unique_ptr Create(const MatchPolicyPtr &mp, const AgingPolicyPtr &ap);
+  static std::unique_ptr Create(const MatchPolicyType mp_type, const AgingPolicyType ap_type,
+                                size_t cached_aging_depth = kDefaultCacheQueueDepth);
+
+  graphStatus SetMatchPolicy(const MatchPolicyPtr mp);
+
+  graphStatus SetAgingPolicy(const AgingPolicyPtr ap);
+
+  CacheItemId AddCache(const CacheDescPtr &cache_desc);
+
+  CacheItemId FindCache(const CacheDescPtr &cache_desc) const;
+
+  std::vector DeleteCache(const DelCacheFunc &func);
+
+  std::vector DeleteCache(const std::vector &delete_item);
+
+  std::vector DoAging();
+
+  CachePolicy() = default;  // public: the policies start out as nullptr until set
+
+ private:
+  CacheState compile_cache_state_;
+  MatchPolicyPtr mp_ = nullptr;
+  AgingPolicyPtr ap_ = nullptr;
+};
+}  // namespace ge
+#endif
diff --git a/inc/metadef/graph/cache_policy/cache_state.h b/inc/metadef/graph/cache_policy/cache_state.h
new file mode 100644
index 0000000000000000000000000000000000000000..881ba9429b0ac621da878d535f7971b02c8a5fdf
--- /dev/null
+++ 
b/inc/metadef/graph/cache_policy/cache_state.h
@@ -0,0 +1,118 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef GRAPH_CACHE_POLICY_CACHE_STATE_H
+#define GRAPH_CACHE_POLICY_CACHE_STATE_H
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "compile_cache_desc.h"
+
+namespace ge {
+class CacheInfo;
+using CacheItemId = uint64_t;
+constexpr CacheItemId KInvalidCacheItemId = std::numeric_limits::max();  // sentinel id, presumably "no such item"
+
+using DelCacheFunc = std::function;
+using CCStatType = std::unordered_map>;
+
+class CacheInfo {  // one cached entry: item id, descriptor and the timer count it was last refreshed with
+friend class CacheState;
+public:
+  CacheInfo(const uint64_t timer_count, const CacheItemId item_id, const CacheDescPtr &desc)
+      : item_id_(item_id), desc_(desc), timer_count_(timer_count) {}
+  CacheInfo(const CacheInfo &other)  // member-wise copy
+      : item_id_(other.item_id_), desc_(other.desc_), timer_count_(other.timer_count_) {}
+  CacheInfo &operator=(const CacheInfo &other) {  // member-wise copy
+    timer_count_ = other.timer_count_;
+    item_id_ = other.item_id_;
+    desc_ = other.desc_;
+    return *this;
+  }
+  CacheInfo() = delete;
+  ~CacheInfo() = default;
+
+  void RefreshTimerCount(uint64_t time_count) {  // overwrites the stored timer count
+    timer_count_ = time_count;
+  }
+
+  uint64_t GetTimerCount() const noexcept {
+    return timer_count_;
+  }
+
+  CacheItemId GetItemId() const noexcept {
+    return item_id_;
+  }
+
+  const CacheDescPtr &GetCacheDesc() const noexcept {
+    return desc_;
+  }
+
+private:
+  CacheItemId item_id_;
+  CacheDescPtr desc_;
+  uint64_t timer_count_;
+};
+
+struct CacheInfoQueue {  // cache infos grouped by main hash key (cc_state_) plus a running entry count
+  void Insert(const CacheHashKey main_hash_key, std::vector &cache_info);
+  void EmplaceBack(const CacheHashKey main_hash_key, CacheInfo &cache_info);
+  void Erase(std::vector &delete_ids, const DelCacheFunc &is_need_delete_func);
+
+  CCStatType cc_state_;
+  uint64_t cache_info_num_ = 0U;
+};
+
+class CacheState {  // owns the cache info queue and hands out item ids and timer counts
+public:
+  CacheState() = default;
+  ~CacheState() = default;
+
+  CacheItemId AddCache(const CacheHashKey main_hash_key, const CacheDescPtr &cache_desc);
+
+  std::vector DelCache(const DelCacheFunc &func);
+
+  std::vector DelCache(const std::vector &delete_item);
+
+  const CCStatType &GetState() const {  // NOTE(review): hands out a reference without taking cache_info_queue_mu_
+    return cache_info_queue.cc_state_;
+  }
+
+  uint64_t GetCacheInfoNum() const {
+    return cache_info_queue.cache_info_num_;
+  }
+
+  uint64_t GetCurTimerCount() const {
+    return cache_timer_count_;  // NOTE(review): read without cache_timer_count_mu_, unlike GetNextTimerCount
+  }
+private:
+  CacheItemId GetNextCacheItemId();
+  void RecoveryCacheItemId(const std::vector &cache_items);
+  uint64_t GetNextTimerCount() {
+    const std::lock_guard lock(cache_timer_count_mu_);
+    return cache_timer_count_++;  // post-increment under the lock: yields the value before the bump
+  }
+
+  std::mutex cache_info_queue_mu_;
+  std::mutex cache_item_mu_;
+
+  int64_t cache_item_counter_ = 0L;  // NOTE(review): signed, while CacheItemId is uint64_t
+  std::queue cache_item_queue_;
+  CacheInfoQueue cache_info_queue;
+
+  uint64_t cache_timer_count_ = 0U;
+  std::mutex cache_timer_count_mu_;
+};
+}  // namespace ge
+#endif
diff --git a/inc/metadef/graph/cache_policy/compile_cache_desc.h b/inc/metadef/graph/cache_policy/compile_cache_desc.h
new file mode 100644
index 0000000000000000000000000000000000000000..1a646ba7cdd4a37a0b88ebedd90a178a87ac5576
--- /dev/null
+++ b/inc/metadef/graph/cache_policy/compile_cache_desc.h
@@ -0,0 +1,123 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef GRAPH_CACHE_POLICY_COMPILE_CACHE_DESC_H +#define GRAPH_CACHE_POLICY_COMPILE_CACHE_DESC_H + +#include +#include +#include "cache_desc.h" +#include "graph/small_vector.h" +#include "graph/ascend_limits.h" +#include "graph/types.h" +#include "graph/def_types.h" +#include "graph/utils/hash_utils.h" +#include "common/ge_common/debug/ge_log.h" +#include "common/ge_common/debug/log.h" + +namespace ge { +class CompileCacheDesc; +using CompileCacheDescPtr = std::shared_ptr; +class BinaryHolder { + public: + BinaryHolder() = default; + + ~BinaryHolder() = default; + + BinaryHolder(const BinaryHolder &other); + BinaryHolder(BinaryHolder &&other); + BinaryHolder &operator=(const BinaryHolder &other); + BinaryHolder &operator=(BinaryHolder &&other); + + BinaryHolder(const uint8_t *const data, const size_t data_len); + + static std::unique_ptr createFrom(std::unique_ptr &&ptr, size_t length); + + const uint8_t *GetDataPtr() const noexcept; + + const size_t &GetDataLen() const noexcept; + + bool operator!=(const BinaryHolder &second) const; + + private: + std::unique_ptr holder_ = nullptr; + size_t data_len_ = 0UL; +}; + +class TensorInfoArgs { + public: + TensorInfoArgs(const Format format, const Format origin_format, const DataType data_type) + : format_(format), + origin_format_(origin_format), + data_type_(data_type) {} + + ~TensorInfoArgs() = default; + + bool IsUnknownShape() const; + bool IsShapeInRange(const TensorInfoArgs &other) 
const; + bool IsTensorInfoMatch(const TensorInfoArgs &other) const; + Format GetFormat() const; + Format GetOriginFormat() const; + DataType GetDataType() const; + void SetShape(const std::vector &shape); + void SetShape(const SmallVector &shape); + void SetOriginShape(const std::vector &origin_shape); + void SetOriginShape(const SmallVector &origin_shape); + void SetShapeRange(const std::vector> &ranges); + bool operator!=(const TensorInfoArgs &second) const; + + private: + Format format_; + Format origin_format_; + DataType data_type_; + SmallVector shape_; + SmallVector origin_shape_; + SmallVector, kDefaultMaxInputNum> shape_range_; +}; + +class CompileCacheDesc : public CacheDesc { + friend class CacheHasher; + public: + CompileCacheDesc() = default; + ~CompileCacheDesc() override = default; + bool IsEqual(const CacheDescPtr &other) const override; + bool IsMatch(const CacheDescPtr &other) const override; + CacheHashKey GetCacheDescHash() const override; + void SetOpType(const std::string &op_type); + void AddBinary(const BinaryHolder &holder); + void AddBinary(BinaryHolder &&holder); + void AddTensorInfo(const TensorInfoArgs &tensor_info); + void SetScopeId(const std::initializer_list scope_id); + size_t GetTensorInfoSize(); + TensorInfoArgs *MutableTensorInfo(size_t index); + + private: + bool CheckWithoutTensorInfo(const CompileCacheDesc *first, const CompileCacheDesc *second) const; + std::string op_type_; // op type + SmallVector scope_id_; // graph_id and session_id + SmallVector tensor_info_args_vec_; // input tensordescs + SmallVector other_desc_; // attrs float float size +}; +} // namespace ge + +namespace std { +template<> +struct hash { + size_t operator()(const ge::BinaryHolder &value) const { + GE_CHECK_NOTNULL(value.GetDataPtr()); + size_t seed = ge::HashUtils::MultiHash(); + const uint64_t u8_data = ge::PtrToValue(ge::PtrToPtr(value.GetDataPtr())); + for (size_t idx = 0UL; idx < value.GetDataLen(); idx++) { + seed = 
ge::HashUtils::HashCombine(seed, *(ge::PtrToPtr(ge::ValueToPtr(u8_data + idx)))); + } + return seed; + } +}; +} // namespace std +#endif diff --git a/inc/metadef/graph/cache_policy/match_policy.h b/inc/metadef/graph/cache_policy/match_policy.h new file mode 100644 index 0000000000000000000000000000000000000000..7a2ce33248ce90d7cdab3d0f016c0322dad0380a --- /dev/null +++ b/inc/metadef/graph/cache_policy/match_policy.h @@ -0,0 +1,25 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef GRAPH_CACHE_POLICY_POLICY_MANAGEMENT_MATCH_POLICY_H_ +#define GRAPH_CACHE_POLICY_POLICY_MANAGEMENT_MATCH_POLICY_H_ +#include "graph/cache_policy/cache_state.h" + +namespace ge { +class MatchPolicy { + public: + MatchPolicy() = default; + virtual ~MatchPolicy() = default; + virtual CacheItemId GetCacheItemId(const CCStatType &cc_state, const CacheDescPtr &desc) const = 0; + private: + MatchPolicy &operator=(const MatchPolicy &match_polocy) = delete; + MatchPolicy(const MatchPolicy &match_polocy) = delete; +}; +} // namespace ge +#endif diff --git a/inc/metadef/graph/cache_policy/match_policy_exact_only.h b/inc/metadef/graph/cache_policy/match_policy_exact_only.h new file mode 100644 index 0000000000000000000000000000000000000000..66ba40820bbfbd5f3f559545f1a6c0bc5fe224a4 --- /dev/null +++ b/inc/metadef/graph/cache_policy/match_policy_exact_only.h @@ -0,0 +1,28 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef GRAPH_CACHE_POLICY_POLICY_MANAGEMENT_MATCH_POLICY_EXACT_ONLY_H_ +#define GRAPH_CACHE_POLICY_POLICY_MANAGEMENT_MATCH_POLICY_EXACT_ONLY_H_ +#include "graph/cache_policy/match_policy.h" +#include "graph/cache_policy/policy_register.h" + +namespace ge { +class MatchPolicyExactOnly : public MatchPolicy { +public: + CacheItemId GetCacheItemId(const CCStatType &cc_state, const CacheDescPtr &desc) const override; + ~MatchPolicyExactOnly() override = default; +}; + +REGISTER_MATCH_POLICY_CREATOR(MatchPolicyType::MATCH_POLICY_EXACT_ONLY, + []() { + return make_shared(); + }); +} // namespace ge +#endif + diff --git a/inc/metadef/graph/cache_policy/match_policy_for_exactly_the_same.h b/inc/metadef/graph/cache_policy/match_policy_for_exactly_the_same.h new file mode 100644 index 0000000000000000000000000000000000000000..170c6edb51e50fdc88b1deea690188d12957ebbf --- /dev/null +++ b/inc/metadef/graph/cache_policy/match_policy_for_exactly_the_same.h @@ -0,0 +1,26 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef METADEF_CXX_GRAPH_CACHE_POLICY_MATCH_POLICY_FOR_EXACTLY_THE_SAME_H +#define METADEF_CXX_GRAPH_CACHE_POLICY_MATCH_POLICY_FOR_EXACTLY_THE_SAME_H +#include "graph/cache_policy/match_policy.h" +#include "graph/cache_policy/policy_register.h" + +namespace ge { +class MatchPolicyForExactlyTheSame : public MatchPolicy { + public: + MatchPolicyForExactlyTheSame() = default; + ~MatchPolicyForExactlyTheSame() override = default; + + CacheItemId GetCacheItemId(const CCStatType &cc_state, const CacheDescPtr &cache_desc) const override; +}; +REGISTER_MATCH_POLICY_CREATOR(MatchPolicyType::MATCH_POLICY_FOR_EXACTLY_THE_SAME, + []() { return std::make_shared(); }); +} // namespace ge +#endif // METADEF_CXX_GRAPH_CACHE_POLICY_MATCH_POLICY_FOR_EXACTLY_THE_SAME_H diff --git a/inc/metadef/graph/cache_policy/policy_register.h b/inc/metadef/graph/cache_policy/policy_register.h new file mode 100644 index 0000000000000000000000000000000000000000..c676c936e604a14be771778bae40d3ba24120821 --- /dev/null +++ b/inc/metadef/graph/cache_policy/policy_register.h @@ -0,0 +1,104 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef GRAPH_CACHE_POLICY_POLICY_MANAGEMENT_POLICY_REGISTER_H_ +#define GRAPH_CACHE_POLICY_POLICY_MANAGEMENT_POLICY_REGISTER_H_ +#include +#include +#include "common/checker.h" +#include "match_policy.h" +#include "aging_policy.h" + +namespace ge { +using MatchPolicyPtr = std::shared_ptr; +using AgingPolicyPtr = std::shared_ptr; +using MatchPolicyCreator = std::function; +using AgingPolicyCreator = std::function; +enum class MatchPolicyType { + MATCH_POLICY_EXACT_ONLY = 0, + MATCH_POLICY_FOR_EXACTLY_THE_SAME = 1 +}; +enum class AgingPolicyType { + AGING_POLICY_LRU = 0, + AGING_POLICY_LRU_K = 1 +}; + +class PolicyRegister { + public: + ~PolicyRegister() = default; + PolicyRegister(const PolicyRegister&) = delete; + PolicyRegister &operator=(const PolicyRegister &other) = delete; + static PolicyRegister &GetInstance(); + void RegisterMatchPolicy(const MatchPolicyType match_policy_type, const MatchPolicyCreator &creator) { + const std::lock_guard lock(mu_); + (void)match_policy_registry_.emplace(match_policy_type, creator); + return; + } + + void RegisterAgingPolicy(const AgingPolicyType aging_policy_type, const AgingPolicyCreator &creator) { + const std::lock_guard lock(mu_); + (void)aging_policy_registry_.emplace(aging_policy_type, creator); + } + + MatchPolicyPtr GetMatchPolicy(const MatchPolicyType match_policy_type) { + const auto iter = match_policy_registry_.find(match_policy_type); + if (iter != match_policy_registry_.end()) { + GE_ASSERT_NOTNULL(iter->second, "[GetMatchPolicy] failed. Match policy type : %d was incorrectly registered", + static_cast(match_policy_type)); + return iter->second(); + } + GELOGE(ge::GRAPH_FAILED, "[GetMatchPolicy] failed. 
Match policy type : %d has not been registered", + static_cast(match_policy_type)); + return nullptr; + } + AgingPolicyPtr GetAgingPolicy(const AgingPolicyType aging_policy_type) { + const auto iter = aging_policy_registry_.find(aging_policy_type); + if (iter != aging_policy_registry_.end()) { + GE_ASSERT_NOTNULL(iter->second, "[GetAgingPolicy] failed. Aging policy type : %d was incorrectly registered", + static_cast(aging_policy_type)); + return iter->second(); + } + GELOGE(ge::GRAPH_FAILED, "[GetAgingPolicy] failed. Aging policy type : %d has not been registered", + static_cast(aging_policy_type)); + return nullptr; + } + private: + PolicyRegister() = default; + std::mutex mu_; + std::map match_policy_registry_; + std::map aging_policy_registry_; +}; + +class MatchPolicyRegister { + public: + MatchPolicyRegister(const MatchPolicyType match_policy_type, const MatchPolicyCreator &creator); + ~MatchPolicyRegister() = default; +}; + +class AgingPolicyRegister { + public: + AgingPolicyRegister(const AgingPolicyType aging_policy_type, const AgingPolicyCreator &creator); + ~AgingPolicyRegister() = default; +}; + +#define REGISTER_MATCH_POLICY_CREATOR_COUNTER(policy_type, func, counter) \ + static MatchPolicyRegister match_policy_register##counter(policy_type, func) +#define REGISTER_MATCH_POLICY_CREATOR_COUNTER_NUMBER(policy_type, func, counter) \ + REGISTER_MATCH_POLICY_CREATOR_COUNTER(policy_type, func, counter) +#define REGISTER_MATCH_POLICY_CREATOR(policy_type, func) \ + REGISTER_MATCH_POLICY_CREATOR_COUNTER_NUMBER(policy_type, func, __COUNTER__) + +#define REGISTER_AGING_POLICY_CREATOR_COUNTER(policy_type, func, counter) \ + static AgingPolicyRegister aging_policy_register##counter(policy_type, func) +#define REGISTER_AGING_POLICY_CREATOR_COUNTER_NUMBER(policy_type, func, counter) \ + REGISTER_AGING_POLICY_CREATOR_COUNTER(policy_type, func, counter) +#define REGISTER_AGING_POLICY_CREATOR(policy_type, func) \ + 
REGISTER_AGING_POLICY_CREATOR_COUNTER_NUMBER(policy_type, func, __COUNTER__) +} // namespace ge +#endif diff --git a/inc/metadef/graph/common_error_codes.h b/inc/metadef/graph/common_error_codes.h new file mode 100644 index 0000000000000000000000000000000000000000..a7267477dce5385b932dfbf3f26251be28585710 --- /dev/null +++ b/inc/metadef/graph/common_error_codes.h @@ -0,0 +1,21 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_COMMON_ERROR_CODES_H_ +#define INC_GRAPH_COMMON_ERROR_CODES_H_ + +#include "external/graph/ge_error_codes.h" + +namespace ge { +constexpr graphStatus NO_DEPENDENCE_FUNC = 50331647U; +constexpr graphStatus NO_OVERLAP_DIM = 50331646U; +constexpr graphStatus NOT_SUPPORT_SLICE = 50331645U; +} // namespace ge + +#endif // INC_GRAPH_COMMON_ERROR_CODES_H_ diff --git a/inc/metadef/graph/compiler_options.h b/inc/metadef/graph/compiler_options.h new file mode 100644 index 0000000000000000000000000000000000000000..46e655edd1146d625ae9ea391f783e21962f4343 --- /dev/null +++ b/inc/metadef/graph/compiler_options.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. 
You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_COMPILER_OPTIONS_H_ +#define INC_GRAPH_COMPILER_OPTIONS_H_ + +#ifdef __GNUC__ +#define METADEF_ATTRIBUTE_UNUSED __attribute__((unused)) +#define METADEF_FUNCTION_IDENTIFIER __PRETTY_FUNCTION__ +#define METADEF_BUILTIN_PREFETCH(args_addr) __builtin_prefetch(args_addr) + +#ifdef HOST_VISIBILITY +#define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_HOST_VISIBILITY +#endif + +#ifdef DEV_VISIBILITY +#define GE_FUNC_DEV_VISIBILITY __attribute__((visibility("default"))) +#else +#define GE_FUNC_DEV_VISIBILITY +#endif + +#else // WINDOWS +#define METADEF_ATTRIBUTE_UNUSED +#define METADEF_FUNCTION_IDENTIFIER __FUNCSIG__ +#define METADEF_BUILTIN_PREFETCH(args_addr) +#define GE_FUNC_HOST_VISIBILITY +#define GE_FUNC_DEV_VISIBILITY +#endif + +#endif // INC_GRAPH_COMPILER_OPTIONS_H_ diff --git a/inc/metadef/graph/compute_graph.h b/inc/metadef/graph/compute_graph.h new file mode 100644 index 0000000000000000000000000000000000000000..43cd66e897c45cd384a5137a02f53a57fa86046f --- /dev/null +++ b/inc/metadef/graph/compute_graph.h @@ -0,0 +1,307 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_COMPUTE_GRAPH_H_ +#define INC_GRAPH_COMPUTE_GRAPH_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "detail/attributes_holder.h" +#include "graph/ge_attr_value.h" +#include "graph/anchor.h" +#include "graph/node.h" +#include "graph/op_desc.h" +#include "graph/range_vistor.h" + +namespace ge { +using ConstComputeGraph = const ComputeGraph; + +class OperatorImpl; +using OperatorImplPtr = std::shared_ptr; + +class ComputeGraphImpl; +using ComputeGraphImplPtr = std::shared_ptr; + +using AttrFilter = std::function; +using NodeFilter = std::function; +using GraphFilter = std::function; + +class ComputeGraph : public std::enable_shared_from_this, public AttrHolder { + friend class GraphUtils; + + public: + template + using Vistor = RangeVistor>; + + explicit ComputeGraph(const std::string &name); + explicit ComputeGraph(const char_t *name); + ~ComputeGraph() override; + ComputeGraph(const ge::ComputeGraph& compute_graph); + ComputeGraph(ge::ComputeGraph&& compute_graph); + + std::string GetName() const; + void SetName(const std::string &name); + + using AttrHolder::DelAttr; + using AttrHolder::GetAttr; + using AttrHolder::HasAttr; + using AttrHolder::SetAttr; + + size_t GetAllNodesSize() const; + /** + * 递归的获取当前图和其子图的节点合集 + * @return 有序的节点合集,按照如下顺序返回 + * {node0, node1, {subgraph_node0, {sub_subgraph_node0, ..., sub_subgraph_noden}... ,subgraph_noden}, ... 
,noden} + */ + Vistor GetAllNodes() const; + std::vector GetAllNodesPtr() const; + // is_unknown_shape: false, same with GetAllNodes func + // is_unknown_shape: true, same with GetDirectNodes func + Vistor GetNodes(const bool is_unknown_shape) const; + Vistor GetNodes(const bool is_unknown_shape, const NodeFilter &node_filter, + const GraphFilter &graph_filter) const; + size_t GetDirectNodesSize() const; + /** + * 获取当前图直接包含的节点合集,并不会递归处理当前图的子图节点,注意和GetAllNodes()区分 + * @return 有序的节点合集,按照如下顺序返回 + * {node0, node1, ..., noden} + */ + Vistor GetDirectNode() const; + std::vector GetDirectNodePtr() const; + Vistor GetInputNodes() const; + Vistor GetOutputNodes() const; + + NodePtr FindNode(const std::string &name) const; + NodePtr FindFirstNodeMatchType(const std::string &name) const; + // AddNode with NodePtr + NodePtr AddNode(const NodePtr node); + NodePtr AddNode(const OpDescPtr op); + NodePtr AddNode(const OpDescPtr op, const int64_t id); // for unserialize + + /** + * 将所有的insert_ops中的OpDesc构造出Node插到nodes_中,插在node的后面 + * @param node 被插Node + * @param insert_ops 需要插入的OpDesc + * @return 返回被插入的所有OpDesc生成的Nodes + */ + std::vector InsertNodes(const NodePtr &node, const std::vector &insert_ops); + + /** + * 使用insert_op构造出Node插到nodes_中,插在node的后面 + * @param node 被插Node + * @param insert_op 需要插入的OpDesc + * @return 返回被插入的OpDesc生成的Node + */ + NodePtr InsertNode(const NodePtr &node, const OpDescPtr &insert_op); + + /** + * 将所有的fusion_ops中的OpDesc构造出Node插到nodes_中,插在ori_nodes算子集合中topo序最小的算子后面 + * 性能更优,优先使用。 + * @param ori_node 被融合的Node + * @param fusion_ops 需要融合的OpDesc + * @return 返回被插入的所有OpDesc生成的Nodes + */ + std::vector FuseNodeKeepTopo(const std::vector &ori_nodes, + const std::vector &fusion_ops); + + NodePtr AddNodeFront(const NodePtr node); + NodePtr AddNodeFront(const OpDescPtr &op); + NodePtr AddInputNode(const NodePtr node); + NodePtr AddOutputNode(const NodePtr node); + NodePtr AddOutputNodeByIndex(const NodePtr node, const int32_t index); + + graphStatus RemoveNode(const 
NodePtr &node); + graphStatus RemoveInputNode(const NodePtr &node); + graphStatus RemoveOutputNode(const NodePtr &node); + graphStatus RemoveConstInput(const NodePtr &node); + + /// Add a subgraph to this graph. The subgraph must has a parent graph and parent node, + /// which means the member functions `SetParentGraph` and `SetParentNode` of the subgraph + /// must be called before add it to the root graph. and subgraph->GetParentNode()->GetOwnerGraph() + /// must equal to subgraph->GetParentGraph(). + /// The subgraphs can only be added to a *root graph*. A root graph is a graph without any parent graph. + /// The subgraph's name SHOULD(not must) be the same as the parameter `name` + graphStatus AddSubgraph(const std::string &name, const std::shared_ptr &subgraph); + graphStatus AddSubgraph(const std::shared_ptr &subgraph); + + void RemoveSubgraph(const std::string &name); + void RemoveSubgraph(const std::shared_ptr &subgraph); + + std::shared_ptr GetSubgraph(const std::string &name) const; + std::vector> GetAllSubgraphs() const; + void SetAllSubgraphs(const std::vector> &subgraphs); + + // obsolete + std::shared_ptr AddSubGraph(const std::shared_ptr sub_graph); + // obsolete + graphStatus RemoveSubGraph(const std::shared_ptr &sub_graph); + + /// @brief Update input-mapping + /// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input + /// @return graphStatus + graphStatus UpdateInputMapping(const std::map &input_mapping); + + /// @brief Update output-mapping + /// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output + /// @return graphStatus + graphStatus UpdateOutputMapping(const std::map &output_mapping); + + void TopologicalSorting(const std::function comp); + graphStatus TopologicalSorting(); + bool IsValid() const; + void InValid(); + void Dump() const; + + void Swap(ComputeGraph &graph); + + graphStatus IsolateNode(const NodePtr &node); + graphStatus InsertGraphEvents(); + bool 
operator==(const ComputeGraph &r_compute_graph) const; + ComputeGraph& operator=(ge::ComputeGraph &compute_graph); + + const std::map, std::vector> &GetShareParamLayer() const; + + void SetShareParamLayer(const std::map, std::vector> params_share_map); + + void SetInputsOrder(const std::vector &inputs_order); + + void SetGraphOutNodes(const std::map> out_nodes_map); + + void AppendGraphOutNodes(const std::map> out_nodes_map); + + std::shared_ptr GetParentGraph(); + const ComputeGraph *GetParentGraphBarePtr() const; + void SetParentGraph(const std::shared_ptr &parent); + std::shared_ptr GetParentNode(); + const Node *GetParentNodeBarePtr() const; + void SetParentNode(const std::shared_ptr &parent); + /** + * 获取图的`NETOUTPUT`节点, 如果图中直接获取的`NETOUTPUT`节点无效,则会遍历图寻找 + * 类型为`NETOUTPUT`的节点,调用`SetNetOutputNode`刷新并返回 + * @return 如果查找成功返回图的输出节点,如果失败返回nullptr + */ + std::shared_ptr GetOrUpdateNetOutputNode(); + void SetNetOutputNode(const std::shared_ptr &netoutput_node); + const std::map> &GetGraphOutNodes() const; + void SetOrigGraph(const ComputeGraphPtr orig_graph); + + ComputeGraphPtr GetOrigGraph(void); + void SetOutputSize(const uint32_t size); + uint32_t GetOutputSize() const; + void SetInputSize(const uint32_t size); + uint32_t GetInputSize() const; + + // false: known shape true: unknow shape + bool GetGraphUnknownFlag() const; + void SetGraphUnknownFlag(const bool flag); + + /// Set is need train iteration. + /// If set true, it means this graph need to be run iteration some + /// times(according variant "npu_runconfig/iterations_per_loop"). + /// @param need_iteration is need iteration + void SetNeedIteration(const bool need_iteration); + + void SetUserDefOutput(const std::string &output_name); + + const std::string GetOutput(); + + /// Get is need train iteration. 
+ /// @return is need iteration + bool GetNeedIteration() const; + + void SetGraphOpName(const std::map &op_name_map); + const std::map &GetGraphOpName() const; + + const std::map &GetAllNodesInfo() const; + + void SetAllNodesInfo(const std::map &nodes); + + void SetGraphOutNodesInfo(std::vector> &out_nodes_info); + void AppendGraphOutNodesInfo(std::vector> &out_nodes_info); + const std::vector> &GetGraphOutNodesInfo() const; + + void SetGraphTargetNodesInfo(const std::vector &target_nodes_info); + const std::vector &GetGraphTargetNodesInfo() const; + + void SetSessionID(const uint64_t session_id); + uint64_t GetSessionID() const; + + void SetGraphID(const uint32_t graph_id); + uint32_t GetGraphID() const; + + void SaveDataFormat(const ge::Format data_format); + ge::Format GetDataFormat() const; + bool IsSummaryGraph() const; + void SetSummaryFlag(const bool is_summary_graph); + + /// nodes like : (a) <--- (c) ---> (b) + /// node a and b have only one parent node c, and a is connected to c firstly + /// topo order of DFS is `c, b, a` with `dfs_reverse=false` as default + /// in same case, user could get `c, a, b` with `dfs_reverse=true` + graphStatus TopologicalSortingGraph(const bool dfs_reverse = false); + /** + * Move Send Event nodes after it`s control node + * Move Recv Event nodes before it`s control node + */ + graphStatus ReorderEventNodes(); + void ClearNodeList(); + void ReorderByNodeId(); + + template + T *GetOrCreateAttrsGroup() { + return MutableAttrMap().GetOrCreateAttrsGroup(); + } + + protected: + ProtoAttrMap &MutableAttrMap() override; + ConstProtoAttrMap &GetAttrMap() const override; + + private: + graphStatus DFSTopologicalSorting(std::vector &node_vec, std::map &map_in_edge_num, + std::vector &stack, const bool reverse); + graphStatus BFSTopologicalSorting(std::vector &node_vec, std::map &map_in_edge_num, + std::deque &stack); + graphStatus CollectBreadthOutNode(const NodePtr &node, std::map &map_in_edge_num, + std::map &breadth_node_map); + + 
graphStatus SortNodes(std::vector &stack, std::map &map_in_edge_num); + Vistor AllGraphNodes(std::vector &subgraphs) const; + std::vector AllGraphNodesPtr(std::vector &subgraphs) const; + Vistor GetAllNodes(const NodeFilter &node_filter, const GraphFilter &graph_filter) const; + size_t GetInEdgeSize(const NodePtr &node) const; + size_t GetOutEdgeSize(const NodePtr &node) const; + graphStatus RemoveExtraOutEdge(const NodePtr &node) const; + bool GraphMembersAreEqual(const ComputeGraph &r_graph) const; + bool GraphAttrsAreEqual(const ComputeGraph &r_graph) const; + bool VectorInputNodePtrIsEqual(const std::vector &left_nodes, + const std::vector &right_nodes) const; + + void SetNodesOwner(); + // Update parent graph of the subgraph which is the direct subgraph of the root graph. + void SetTopParentGraph(); + /** + * To improve preformace of list.size(), we should keep counter on nodes_.size() + * Use follow function to add/erase node from nodes_ + */ + void EraseFromNodeList(const std::list::iterator position); + + friend class ModelSerializeImp; + friend class OnnxUtils; + friend class TuningUtils; + friend class ExecuteGraphAdapter; + + ComputeGraphImplPtr impl_; +}; +} // namespace ge +#endif // INC_GRAPH_COMPUTE_GRAPH_H_ diff --git a/inc/metadef/graph/debug/ge_attr_define.h b/inc/metadef/graph/debug/ge_attr_define.h new file mode 100644 index 0000000000000000000000000000000000000000..db56e64c2f84fd15e641b5954c1a8949279f578c --- /dev/null +++ b/inc/metadef/graph/debug/ge_attr_define.h @@ -0,0 +1,1554 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ +#define INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ + +/*lint -e618*/ +#include <string> +#include "graph/types.h" +#include "graph/compiler_options.h" + +namespace ge { +// Public attribute names: every entry is an extern declaration of a const std::string; no definition is provided in this header +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_FILE_PATH; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FORCE_UNKNOWN_SHAPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_UNKNOWN_SHAPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GRAPH_UNKNOWN_FLAG; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNKNOWN_SHAPE_TYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAME; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHT_NAME; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_QUANTIZE_FACTOR; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ALPHA; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BETA; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADMODE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const
std::string ATTR_NAME_PADMODES; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MODE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FILTER; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS_TERM; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HAS_BIAS_VALUE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD_MODE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SCALE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WINDOWS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GLOBAL_POOLING; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CEIL_MODE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELUMODE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELU_FLAG; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ALGO; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FORMAT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STORAGE_FORMAT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ORIGIN_FORMAT_IS_SET; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern
const std::string ATTR_NAME_STORAGE_SHAPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FILTER_FORMAT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_K; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_NORM_REGION; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_LOCAL_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_ALPHA; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_BETA; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROADCAST; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TIDX; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TPADDINGS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_IMG_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_IMG_W; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NET_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NET_W; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TMULTIPLES; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTIPLES; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_T; + +extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string ATTR_NAME_N; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TSHAPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAN_OPT; +
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEW_AIPP_CONV_OP; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP_INPUTS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP_OUTPUTS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_DIMS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_AIPP_INPUT_DIMS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DATA_RELATED_AIPP_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DATA_AIPP_DATA_NAME_MAP; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GRAPH_HAS_BEEN_ADDED; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_GRAPH_NAME; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_FORMAT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FRAMEWORK_NODE_DEF; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FRAMEWORK_OP_DEF; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string
ATTR_NAME_FRAMEWORK_FUNC_DEF; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_TENSOR_DESC; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_TENSOR_DESC; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INFERRED_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_PERMUTE_DELETED; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IGNORE_PRED_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHT_SHA256; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_REUSE_EXTERNAL_WEIGHT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DIM_ALIGN; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATOMIC_MEMSET_SIZES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATOMIC_MEMSET_DTYPES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATOMIC_MEMSET_VALUES_INT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string
ATTR_NAME_ATOMIC_MEMSET_VALUES_FLOAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RTS_LABEL_NODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_OUTPUT_DIMS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_ORIGIN_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ROOT_GRAPH_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ROOT_GRAPH_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_ROOT_GRAPH; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NODE_CONNECT_INPUT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NODE_CONNECT_OUTPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEED_MAP_RANK_ID; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SEND_EVENT_IDS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RECV_EVENT_IDS; + +// to be deleted (NOTE(review): neither the scope of this marker nor a removal plan is stated here; confirm which of the following names it covers) +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_TO_BE_DELETED; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION_CONV_PROPOSAL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION_CONV_DECODEBBOX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION_BOX_TYPE_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_LOC_FUSION;
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_CONF_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_OCR_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_LOC_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_CONF_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEED_INFER; + +// _Arg (NOTE(review): this label and the ones after it look like the op type each name belongs to; confirm) +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INDEX; +// _RetVal +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETVAL_ATTR_NAME_INDEX; +// Data +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DATA_ATTR_NAME_DATA_TYPE; + +// Send +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SEND_ATTR_EVENT_ID; + +// SendNotify +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SEND_ATTR_NOTIFY_ID; + +// Recv +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RECV_ATTR_EVENT_ID; + +// RecvNotify +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RECV_ATTR_NOTIFY_ID; + +// Convolution
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COEF; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDES; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DILATION; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DILATIONS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_MODE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_ALGO; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_GROUP; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_PAD_MODE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_PAD; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_STRIDE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_DILATION; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_NUM_OUTPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_KERNEL; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_FILTER; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_BIAS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_RELU_FLAG; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_ADJ; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_TARGET_SHAPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_ATTR_NAME_BEFORE_PAD; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY
extern const std::string CONV_ATTR_NAME_HAS_BIAS; + +// Pooling +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_NAN_OPT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_PAD_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_GLOBAL_POOLING; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_WINDOW; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_PAD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_STRIDE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_CEIL_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_DATA_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_BEFORE_PAD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOLING_ATTR_NAME_ALGO; + +// Eltwise +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_COEFF; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_WEIGHT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_RELU_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_ALPHA; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ELTWISE_ATTR_BETA; + +// BatchNorm +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_EPSILON; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string
BATCHNORM_ATTR_USE_GLOBAL_STATS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_MOVING_AVERAGE_FRACTION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_ESTIMATED_MEAN; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_SCALE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_BIAS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_DATA_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION; + +// Huberloss +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HUBER_LOSS_ATTR_DELTA; + +// SSDRealDivTileMul +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA; + +// SSDSumMulRealDivMean +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM; +// ConcatFive2Four +// ConcatFour2Five (NOTE(review): presumably both labels apply to the names that follow; confirm) +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_CLASS_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TRANS_FOR_LOSS_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOX_TYPE_NUM;
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_HIGH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_WIDTH; +// Scale +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_SCALE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_BIAS; + +// FullConnection +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_CONNECTION_ATTR_FILTER; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_CONNECTION_ATTR_BIAS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_CONNECTION_ATTR_NUM_OUTPUT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_CONNECTION_ATTR_RELU_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FULL_ATTR_NAME_ALGO; + +// SoftmaxOpParams +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_ATTR_ALGO; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_ATTR_MODE; + +// SparseSoftmaxCrossEntropy +extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_ATTR_MODE; +extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_IS_GRAD; +// Label smoothing attribute +extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string SOFTMAX_CROSS_ENTROPY_LABELSMOOTHING; + +// ApplyMomentum (the identifier below is spelled APPLYMENTUM; it is a declared public name, so it is left unchanged) +extern GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string APPLYMENTUM_ATTR_IS_GRAPH_FUSION; + +// Activation +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ACTIVATION_ATTR_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ACTIVATION_ATTR_COEF; + +// Concat +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONCAT_ATTR_NAME_AXIS; + +// Const +GE_FUNC_DEV_VISIBILITY
GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_DATA_TRANSTYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_INPUT; + +// Roipooling (NOTE(review): one name below contains RIO where ROI may have been intended; confirm before relying on the spelling) +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_POOLED_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_POOLED_W; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_SPATIAL_SCALE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_RIO_POOLING_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_POOLING_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLING_ATTR_NAME_SAMPLING_RATIO; + +// DetectionOutput +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_TOP_K; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_IMG_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_IMG_W; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_BATCH_SIZE; +// Ssd DetectionOutput +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern
const std::string DETECTIONOUTPUT_ATTR_ETA; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_SHARED_LOCATION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_BACKGROUND_LABEL_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CODE_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_VARIANCE_ENCODED_IN_TARGET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_KEEP_TOP_K; + +// Refinedet DetectionOutput +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_SCORE; + +// Yolo DetectionOutput (the first identifier below contains a lowercase l in ClASSES; it is a declared public name, so it is left unchanged) +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_ClASSES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_BIASES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_RELATIVE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_THRESHOLD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CLASS_THRESHOLD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_POST_TOP_K; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_IOU_THRESHOLD_DECAY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_COOR_SCALE_FACTOR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_YOLO_VERSION; + +// DetectionPostprocess +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_NAME_CLS_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_NAME_CONF_THRESH;
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_NAME_NMS_THRESH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_POST_NMS_TOPN; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POSTPROCESS_ATTR_NAME_BBOX_REG_WEIGHT; + +// Spatialtransform (the identifiers below use the abbreviated prefix SPTIALTF) +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPTIALTF_ATTR_NAME_OUTPUT_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPTIALTF_ATTR_NAME_OUTPUT_W; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPTIALTF_ATTR_NAME_BORDER_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPTIALTF_ATTR_NAME_AFFINE_TRANSFORM; + +// Proposal +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_FEAT_STRIDE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_BASE_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_MIN_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_RATIO; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_SCALE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_PRE_NMS_TOPN; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_POST_NMS_TOPN; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_NMS_THRESH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_NAME_TOP_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_IMG_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PROPOSAL_ATTR_IMG_W; +// Softmax +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string
SOFTMAX_ATTR_AXIS; + +// Permute +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_ORDER; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_PERM; + +// SSD Normalize (the first identifier below is spelled with ACCROSS; it is a declared public name, so it is left unchanged) +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_CHANNEL_SHARED; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_EPS; + +// Flatten +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FLATTEN_ATTR_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FLATTEN_ATTR_END_AXIS; + +// SsdPRIORBOX +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_FLIP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_CLIP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_IMG_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_IMG_W; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_STEP_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_STEP_W; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO;
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_VARIANCE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM; + +// PRelu +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PRELU_ATTR_CHANNEL_SHARED; + +// Psroi pooling +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PSROIPOOLING_ATTR_SPATIAL_SCALE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PSROIPOOLING_ATTR_OUTPUT_DIM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PSROIPOOLING_ATTR_GROUP_SIZE; + +// Power +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_POWER; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SCALE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SHIFT; + +// Log +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SCALE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SHIFT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_BASE; + +// Pack +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PACK_ATTR_NAME_NUM; + +// Dynamic stitch +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; + +// Unpack +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UNPACK_ATTR_NAME_NUM; + +// Gathernd +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERND_ATTR_NAME_TINDICES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERND_ATTR_NAME_TPARAMS; + +// Argmax +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string
ARGMAX_ATTR_NAME_TOPK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_OUTMAX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXISTYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_KEEPDIMS; + +// Upsample +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_W; + +// Relu +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEGATIVE_SLOPE; + +// FreeSpaceExtract +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FREESPACEEXTRACT_ATTR_NAME_ORG_HEIGHT; + +// Split +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPLIT_ATTR_NAME_SLICE_POINT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPLIT_ATTR_NAME_SIZE_SPLIT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPLIT_ATTR_NAME_NUM_SPLIT; + +// Tvm +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_MAGIC; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_BLOCKDIM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_METADATA; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_WORKSPACE_TYPE; + +// Ffts Tvm +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_THREAD_MAGIC; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string
TVM_ATTR_NAME_THREAD_BLOCKDIM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_THREAD_METADATA; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_THREAD_WORKSPACE_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_THREAD_N_BATCH_SPLIT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_THREAD_TBE_KERNEL_BUFFER; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_THREAD_TBE_KERNEL_NAME; + +// Squeeze +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SQUEEZE_ATTR_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SQUEEZE_ATTR_DIMS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SQUEEZE_OP_NAME; + +// Stride slice +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_BEGIN_MASK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_END_MASK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_ELLIPSIS_MASK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_NEW_AXIS_MASK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK; + +// Slice +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SLICE_ATTR_NAME_BEGINS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SLICE_ATTR_NAME_SIZES; + +// Roialign +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_SPATIAL_SCALE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_SAMPLING_RATIO; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const
std::string ROIALIGN_ATTR_NAME_POOLED_W; + +// Generate_rpn_proposal +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_POST_NMS_TOPK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_MINI_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string + GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_NMS_THRESH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string + GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_FILTER_THRESH; +// Decode_bbox +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DECODE_BBOX_ATTR_DECODECLIP; + +// Cast +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CAST_ATTR_DSTT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CAST_ATTR_SRCT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CAST_ATTR_DST_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CAST_ATTR_TRUNCATE; + +// Fastrcnn predictions +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FASTRCNN_PREDICTIONS_ATTR_TOPK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FASTRCNN_PREDICTIONS_ATTR_SCORE_THRESHOLD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FASTRCNN_PREDICTIONS_ATTR_NMS_THRESHOLD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FASTRCNN_PREDICTIONS_ATTR_NUM_CLASSES; + +// REORG +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REORG_ATTR_STRIDE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REORG_ATTR_REVERSE; + +// MERGE +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_DEAD_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const
std::string MERGE_PRENODE_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TO_BE_OUTPUT; + +// ENTER +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_FRAME_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_CONSTANT_FLAG; + +// Concatv2 +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONCAT_V2_ATTR_TIDX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONCAT_V2_ATTR_N; +// SUM +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SUM_ATTR_TIDX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SUM_ATTR_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SUM_ATTR_KEEP_DIMS; + +// ResizeBilinear +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALIGN_CORNERS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_HEIGHT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_WIDTH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ZOOM_FACTOR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_SHRINK_FACTOR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_PAD_BEGIN; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_PAD_END; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALPHA; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_BETA; + +// RetinaNet +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_FILTER_BACKGROUND_TRUE;
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_ANCHOR_FUSION; + +// MatMul +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_X; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_W; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_HAS_BIAS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_ATTR_IS_TRAINING; + +// Flatten +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FLATTEN_START_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FLATTEN_END_AXIS; + +// Reshape +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NUM_AXES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_SHAPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_ALPHA; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_BETA; + +// Frameworkop +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string T_IN_DATATYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string T_OUT_DATATYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUT_N; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUT_C; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUT_H; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUT_W; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_PAD_DEPTH_CONV; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_PAD_CONV; + 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BEFORE_PAD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ANN_MEAN_KEEPDIMS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_ATTR_PADDINGDS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_ATTR_CONSTANT_VALUE; + +// ConvGradFilter +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE; + +// ConvGradInput +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE; + +// Rnn +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_MODE_STATIC; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MUTI_RNN; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CELL_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CNN_RNN; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_CELL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GRU_CELL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_HT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_XT_HT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_BATCH_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_CELL_CLIP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_PROJ_CLIP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_ACTIVATE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MAP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_STATE_OUT_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const 
std::string LSTM_TIME_MAJOR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_IS_INPUT_PRE_PROCESS; + +// Upsample +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE; + +// PadV2 +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PADS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_T; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PAD_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_CONST_VALUE; + +// MirrorPad +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PADS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE; + +// Filler +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_VALUE; + +// Shufflechannel +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHUFFLE_CHANNEL_GROUP; + +// TopKV2 +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TOPKV2_ATTR_K; + +// Calibration +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_H_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string STRIDE_W_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_TOP_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_BOTTOM_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern 
const std::string PAD_RIGHT_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_LEFT_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_ALGO_ATTR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_TYPE_ATTR; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_CONST; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GROUP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DILATION_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_EPSILON; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_POOLING_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CLASS_NUM; + +// Model +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TARGET_TYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_STREAM_NUM; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_EVENT_NUM; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_NOTIFY_NUM; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_NOTIFY_TYPES; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_HUGE_STREAM_LIST; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_LABEL_NUM; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_MEMORY_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_OUT_NODES_NAME; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_WEIGHT_SIZE; + 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_SESSION_SCOPE_MEMORY_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_SUB_MEMORY_INFO; + +// Used for om compress +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_OM_COMPRESS_VERSION; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ATTR_NAME_ENUM; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ATTR_VALUE_ENUM; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ATTRS_USE_STRING_VALUE; + +// Public attribute +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IMPLY_TYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BYTE_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_INFERENCE_ID; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_OPDEF; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IO_OP; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_SCOPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OPATTR; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SEQLEN_INDEX; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_X_INDEX; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONT_INDEX; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_XSTATIC_INDEX; + +GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string TARGET_TYPE_MINI; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TARGET_TYPE_TINY; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TARGET_TYPE_LITE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_INPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_INPUT_ALLOC; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_OUTPUT; + +// attr _input_mutable = true means node will modify its input in runtime +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MODIFY_INPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REFERENCE; + +// Used for operators that do not generate task +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOTASK; + +// Used for operators that output reuse input +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_REUSE_INPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REUSE_INPUT_ON_DIM_INDEX; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GRAPH_OUTPUT_MAX_SIZE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOPADDING_CONTINUOUS_INPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATOMIC_INDEX; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_STREAM_LABEL; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; + 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ATC_VERSION; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ATC_CMDLINE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_OPP_VERSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_COMPILER_VERSION; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; + +// L2_normalize +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_AXIS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_WINDOW; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_CEIL_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_DATA_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY 
extern const std::string POOL_PARAMA_ATTR_NAN_OP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_PAD_MOD; + +// HCOM +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_GROUP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SR_TAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SRC_RANK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DEST_RANK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; + +// Log time stamp +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_LOGID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_NOTIFY; + +// SpaceToDepth/DepthToSpace +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BLOCK_SIZE; + +// SparseSoftmaxCrossEntropyWithLogits +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPARSE_SOFT_MAX_ATTR_TLABLES; + +// MaxPoolGradWithArgmax +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MAX_POOL_GRAD_OUTPUT_SHAPE; + +// AvgPoolGrad +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string AVG_POOL_GRAD_OUTPUT_SHAPE; + +// Variable +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string 
VAR_ATTR_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FRACTALZ_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_4D_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_5D_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DATA_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHAPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HALF_VAR_NAME_END; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_CONTAINER; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHARED_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DTYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_SAVE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; + +// Assign +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ASSIGN_VALIDATE_SHAPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ASSIGN_VAR_NAME; + +// Inplace support +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY 
extern const std::string INPLACE_SUPPORT_INPUT_INDEX; + +// ShapeN +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_N; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_IN_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_OUT_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SPLIT_SHAPEN_ORIGIN_NAME; + +// Space2batch batch2space +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_BLOCK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_PADDING; +// Depth_to_space space_to_depth +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; +// FakeQuantWithMinMaxVars +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MAX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MIN; +// Mobilenet_ssd_conv_fusion +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_BOXES_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_SCORES_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM; + +// Lsh project +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSH_PROJ_TYPE; + +// Control flow +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ITERATORS_PER_LOOP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG; + +// GatherV2 attr def +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TAXIS; +GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TINDICES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TPARAMS; + +// Reshape attr def +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_INPUT_DESC; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC; + +// Axis attr def +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AXIS_ORG_OP; +// The node link with SparseSoftmaxCrossEntropyWithLogits +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LINK_WITH_SPARE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; +// For constant folding +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_NEED_CONSTANT_FOLDING; + +// Mark node inserted by inside +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_INSERTED_BY_CANN; + +// For AscendWeightQuant+Enter +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FINAL_CONST_NODE; + +// Used for mark the active label list to find stream of activated node +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_LABEL_LIST; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_VISIT_DISTANCE; + +// Multi batch +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_LABEL; 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMBINED_BATCH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_USER_DESIGNEATE_SHAPE_ORDER; + +// Control flow +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUBGRAPH_FIRST_ACTIVE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMBINED_DYNAMIC_DIMS; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_DATA_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ORIG_NODE_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CYCLIC_DEPENDENCE_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_TYPE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEXT_ITERATION; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEED_INFER_AGAIN; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MERGE_INPUT_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTROL_FLOW_GROUP; + +// Function Op +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_NODE_INDEX; + +// Used for mark the active node is for loop, type:bool +GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_LOOP_ACTIVE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_TYPE_INPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_TYPE_OUTPUT; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_TYPE_RANGE; + +// Atomic addr clean attrs +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_INPUT_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_OUTPUT_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_IS_FUSION_NODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_IS_ATOMIC_NODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_TVM_MAGIC; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_TVM_METADATA; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATOMIC_ATTR_TBE_KERNEL_NAME; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string EXT_ATTR_ATOMIC_WORKSPACE_INFO; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string EXT_ATTR_ATOMIC_WORKSPACE_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string EXT_ATTR_ATOMIC_TBE_KERNEL; +// Used for find variable session_id +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MODEL_ATTR_SESSION_ID; + +// Source/dst format for Op FormatTransfer +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FORMAT_TRANSFER_SRC_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FORMAT_TRANSFER_DST_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string 
FORMAT_TRANSFER_SRC_SUBFORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FORMAT_TRANSFER_DST_SUBFORMAT; + +// For compile op by ge call +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NEED_COMPILE; + +// For multi-batch +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERT_BY_MBATCH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_TYPE; + +// For inserted op +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERTED_BY_GE; + +// For compress weight +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMPRESS_WEIGHT; + +// For data dump +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_TYPES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_IS_MULTIOP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_SUB_SPLITER_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_GROUP_OP_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ORIGIN_OP_ATTRS_IN_FUSION_PROCESS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const 
std::string ATTR_NAME_ORIGIN_OP_ATTRS_MAP; + +// used for lX fusion +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_KEY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_KEY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_VIRTUAL_OP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DATA_DUMP_REF; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L2_FUSION_GROUP_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_ADDR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const 
std::string ATTR_NAME_ENGINE_NAME_FOR_LX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_KKERNEL_LIB_NAME_FOR_LX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEED_LX_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OPTIMIZE_GROUP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_COMPILE_STRATEGY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TBE_KERNEL_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TBE_KERNEL_NAME_FOR_LOAD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TBE_KERNEL_BUFFER; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_SLICE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEED_RECOVER_ATTR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OFF_SUPERKERNEL_ATTR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SRC_CONST_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEED_KEEP_ORIGIN_INPUT_AND_OUTPUT; + +// merge subgraph with output anchor map +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_ORIGIN_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_ORIGIN_OUTPUT_INDEX; + +// read var offset +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INNER_OFFSET; + +// used for memory allocate +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string 
ATTR_NAME_WORKSPACE_TYPE_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TENSOR_MEM_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_P2P_MEMORY_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUB_STREAM_ID; + +// for unregistered op +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNREGST_OPPATH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNREGST_ATTRLIST; + +// op overflow dump +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OP_DEBUG_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OP_DEBUG_MODE; + +// op dynamic input +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_INPUT_START; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_INPUT_END; + +// functional ops attr +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IF_THEN_BRANCH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IF_ELSE_BRANCH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_COND; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_BODY; + +// used for label switch +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUBGRAPH_END_NODE; + + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_DATATYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_DATATYPE; +// used for LX tiling +GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_L1_SPACE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_TYPE_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_VALID_INPUT_SHAPE_LIST_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_VALID_OUTPUT_SHAPE_LIST_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SLICE_INPUT_OFFSET_LIST_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SLICE_OUTPUT_OFFSET_LIST_LIST; + +// Used for support Horovod +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INTER_EVENT_IDENTIFY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_HOROVOD_ATTR_REDUCE_TYPE; +// for gradient group +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_GROUP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_FLAG; +// used for recording the number of tasks to be issued for each operator +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_TASK_NUM; +// used for recording task num of RTS nodes such as MemcpyAsync +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NODE_SQE_NUM; +// for parallel group +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARALLEL_GROUP; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_SUPPORT_ADDR_REFRESH; + +// dynamic shape attrs +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_SINGLE_AICPU; + 
+// atc user def dtype&format +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_DATATYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_FORMAT; + +// atc user def dtype&format +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_OUTPUT_NODES; + +// for fusion op plugin +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE; + +// graph partition for aicpu +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PLD_FRONT_NODE_ENGINE_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_END_REAR_NODE_ENGINE_NAME; + +// aicpu workspace type +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AICPU_WORKSPACE_TYPE; + +// input and output memory type +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_VARIABLE_PLACEMENT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INPUT_MEMORY_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OUTPUT_MEMORY_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SPECIAL_OUTPUT_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SPECIAL_INPUT_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_MEMORY_SCOPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_MEMORY_SCOPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TENSOR_MEMORY_SCOPE; + +// stage +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_STAGE_LEVEL; + +// input_output_offset +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ZERO_COPY_BASIC_OFFSET; +GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ZERO_COPY_RELATIVE_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_IS_ZERO_COPY_BLOCK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_OFFSET_LIST_FOR_CONTINUOUS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_LIST_FOR_CONTINUOUS; + +// mark node cannot be deleted +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CANNOT_BE_DELETED; + +// The processing mode of INF and NAN during floating-point number calculation. +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_FP_CEILING_MODE; +// count of data from getnext_sink +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_GETNEXT_SINK_DATA_COUNT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_GETNEXT_SINK_SHAPE_INFO; + +// getnext_sink marked on NetOutput +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_GETNEXT_SINK_DYNMAIC; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ALL_GEARS_INFO; + +// Calculate the operator output memory +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_SIZE_CALC_TYPE; +// Indicates which operators keep the precision unchanged +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_KEEP_DTYPE; + +// profiling task mark on fp bp +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INSERT_FP_PROFILILNG_TASK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INSERT_BP_PROFILILNG_TASK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INSERT_END_PROFILILNG_TASK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INSERT_PROFILILNG_TASK_LOG_ID; +// padding 
dimension type +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RESHAPE_INFER_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RESHAPE_TYPE_MASK; + +// mark single op scene +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_SINGLE_OP_SCENE; + +// for fe judge whether trans/cast op is inserted +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FORMAT_CONTINUOUS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REFRESH_CONTINUOUS_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FORMAT_AGNOSTIC; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FORMAT_AGNOSTIC_EXCEPT_OUTPUT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FORMAT_AGNOSTIC_EXCEPT_INPUT; + +// for ffts/ffts_plus +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FFTS_SUB_GRAPH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_THREAD_SCOPE_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_THREAD_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FFTS_PLUS_SUB_GRAPH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMPOSITE_ENGINE_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMPOSITE_ENGINE_KERNEL_LIB_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CUBE_VECTOR_CORE_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CACHE_PERSIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ALIAS_ENGINE_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_KERNEL_NAMES_PREFIX; 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FFTS_SUB_TASK_TENSOR_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FFTS_SUB_TASK_TENSOR_OFFSETS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_FFTS_UNSUPPORTED; + +// mark fuzz build scene +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUZZ_BUILD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PLACEMENT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_VALUE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_VALUE_RANGE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BUILD_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUZZ_BUILD_RES_ATTRS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUZZ_INPUTS_SUPPORTED_ATTRS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUZZ_OUTPUTS_SUPPORTED_ATTRS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUZZ_IS_HIGH_PERFORMANCE_ATTRS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_ORIGINAL_INPUT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_OP_GENERALIZED; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_DYNAMIC_MODEL; + +// buffer pool allocator +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BUFFER_POOL_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BUFFER_POOL_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_EVENT_MULTIPLEXING; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string 
ATTR_NAME_BUFFER_POOL_NODE_SIZE_AND_OFFSET; + +// session scope memory +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WORKSPACE_MEMORY_NO_REUSE_SCOPE; + +// for blocking op +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_BLOCKING_OP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BLOCKING_OP_TIMEOUT; + +// for op specified engine +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_SPECIFIED_ENGINE_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_SPECIFIED_KERNEL_LIB_NAME; + +// for pipeline partition +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PIPELINE_PARTITIONED; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_PIPELINE; + +// for partition +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NO_NEED_PARTITION_AND_MERGE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NO_NEED_PARTITION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NO_NEED_MERGE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NO_NEED_DYNAMIC_SHAPE_PARTITION; + +// model deploy scheduler(mds) +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GRADIENT_NODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRAINABLE_VAR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FISSION_FACTOR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DEPLOY_INFO; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CUT_INFO; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DEPLOY_DEVICE_TYPE; +GE_FUNC_DEV_VISIBILITY 
GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DEPLOY_DEVICE_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REDUNDANT_DEPLOY_DEVICE_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DEPLOY_GRAPH_INPUTS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DEPLOY_NEED_RETURN_RESULT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FORCE_ATTACH_STREAM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RT_MEMCPY_KIND; + +// for qos +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_QOS_SERVICE_LABEL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARALLEL_GROUP_ID; + +// for constant folding, mark potential const +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_POTENTIAL_CONST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_POTENTIAL_WEIGHT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_POTENTIAL_WEIGHT_INDICES; + +// name of network output tensor +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ORIGIN_OUTPUT_TENSOR_NAME; + +// for scope op to record the input and output information of the original graph node +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ORIGIN_GRAPH_NODE_INPUTS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ORIGIN_GRAPH_NODE_OUTPUTS; + +// for operator resource list(e.g. 
queues, channels) +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RESOURCE_LIST; + +// for no tiling +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_TILING_INLINE_ENGINE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_EXPORT_SHAPE_ENGINE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_MAX_SHAPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_MAX_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TENSOR_MAX_SHAPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_NO_TILING; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TENSOR_DESC_MEM_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TENSOR_NO_TILING_MEM_TYPE; + +// for soft sync op +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STATIC_TO_DYNAMIC_SOFT_SYNC_OP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SGT_CUBE_VECTOR_CORE_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MAX_TILING_SIZE; + +// for subgraph multi dims +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUBGRAPH_MULTI_DIMS_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUBGRAPH_MULTI_DIMS_INPUT_SHAPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUBGRAPH_MULTI_DIMS_INPUT_DIMS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUBGRAPH_IS_MULTI_DIMS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_MULTI_DIMS_INPUT_DIMS; + +// for support BlockDim +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string 
ATTR_NAME_SUPPORT_BLOCKDIM_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BLOCKDIM_INDEX; + +// for support dynamic data flow ops +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_FLOW_HANDLE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_FLOW_MAX_SIZE; + +// mark node inserted by ge +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_INSERTED_BY_GE; + +// for cmo feature +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_REUSE_INFO; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_READ_WRITE_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEM_RELEASE_FIRST_REUSE_FIRST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NODE_NEED_MULTI_TASK; + +// for support overflow detection +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GLOBALWORKSPACE_SPEC_WORKSPACE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GLOBALWORKSPACE_SPEC_WORKSPACE_BYTES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GLOBALWORKSPACE_TYPE; + +// for value depend +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_VALUE_DEPEND; + +// for process node engine +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PROCESS_NODE_ENGINE_ID; + +// for dynamic graph memory discontiguous +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMORY_DISCONTIGUOUS_ALLOCATION; + +// for flow attribute +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FLOW_ATTR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FLOW_ATTR_DEPTH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern 
const std::string ATTR_NAME_FLOW_ATTR_ENQUEUE_POLICY; + +// for align attribute +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUTS_ALIGN_ATTR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUTS_ALIGN_INTERVAL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUTS_ALIGN_OFFSET; + +// for binary om feature +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUT_SHAPE_LOCKED; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FORMAT_LOCKED; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SKIP_GEN_TASK; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OM_BINARY_PATH; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEED_GENTASK_ATOMIC; + +// for pipeline partition attribute +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARALLEL_SHARDED; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PIPELINE_STAGE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_VIRTUAL_STAGE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LOGIC_DEV_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REDUNDANT_LOGIC_DEV_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STAGE_ORDER_ID; + +// for graph memory optimize +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RECOMPUTE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BACKWARD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OPTIMIZER; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GRAPH_SLICE_SCOPE; 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GRAPH_SLICE_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_FIXED_ADDR_PRIOR; + +// for model esched priority +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ESCHED_PROCESS_PRIORITY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ESCHED_EVENT_PRIORITY; + +// for graph deployment +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DEVICE_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MODEL_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DEVICE_INDEX_TO_LOGIC_DEVICE_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NODE_DEPLOYMENT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TENSOR_DEPLOYMENT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NODE_DEPLOYMENTS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TENSOR_DEPLOYMENTS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MODEL_INDEX_TO_LOGIC_DEVICE_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ORIGINAL_CONST_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DEPLOY_DEVICE_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RECOMPUTE_MODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MODEL_EVENTS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCOM_GROUPS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SHARD_GRAPH_EXT_ATTRS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string 
ATTR_NAME_IS_SHARD_GRAPH_FOR_LOAD; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GRAPH_MODEL_DEPLOY_MODE; + +// for tensor parallelism +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TP_RESHARD_ATTR; + +// for lowering +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GRAPH_FLATTEN_OFFSET; + +// for fileconstant +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FILE_CONSTANT_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FILE_PATH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FILE_CONSTANT_PATH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LOCATION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LENGTH; + +// for embedding service +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_EXECUTE_TIMES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MAX_KEY_NUM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_EMBEDDING_DIM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TAG_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OPTIMIZER_GRAPH_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_EMBEDDING_GRAPH_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_TRANSFER_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMM_GROUP_NAMES; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_USE_COUNTER_FILTER; + +// for host env os & cpu +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY 
extern const std::string ATTR_MODEL_HOST_ENV_OS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_HOST_ENV_CPU; + +// for mc2 stream&notify assign +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATTACHED_STREAM_INFO; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATTACHED_NOTIFY_INFO; + +// for mc2 inner named-attr +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DISABLE_ATTACHED_RESOURCE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATTACHED_STREAM_KEY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATTACHED_STREAM_POLICY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATTACHED_STREAM_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATTACHED_NOTIFY_KEY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATTACHED_NOTIFY_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATTACHED_NOTIFY_POLICY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATTACHED_NOTIFY_NUM; + +// Set uniformly by the engine on the op_desc; the type is list_name_attr +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATTACHED_STREAM_INFO_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATTACHED_SYNC_RES_INFO_LIST; +// The attributes below live inside the ATTR_NAME_ATTACHED_STREAM_INFO_LIST and ATTR_NAME_ATTACHED_SYNC_RES_INFO_LIST attributes and are of name_attr type; do not set them directly on the op_desc +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATTACHED_RESOURCE_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATTACHED_RESOURCE_REUSE_KEY; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string + ATTR_NAME_ATTACHED_RESOURCE_DEPEND_VALUE_LIST_INT; 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATTACHED_RESOURCE_REQUIRED_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATTACHED_RESOURCE_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATTACHED_RESOURCE_IS_VALID; +// Only used inside the ATTR_NAME_ATTACHED_SYNC_RES_INFO_LIST attribute; identifies the event/notify type, defaulting to event +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATTACHED_RESOURCE_TYPE; + +// skip prune +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SKIP_PRUNE_OPTIMIZE; + +// for static shape reuse operator binary +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_RUN_INFO; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATOMIC_OP_RUN_INFO; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MEMSET_NODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_KERNEL_BIN_ID; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TILING_POLICY; + +// local reduce +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REDUCE_OP_ATTR_KEEP_DIMS; + +// logic stream id +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOGIC_STREAM_ID; + +// storage format +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_HEAVY_OP; +// tiling depend +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_TILING_DEPEND_OP; + +// for infer shape +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRESET_OUTPUT_SHAPES; +// session options +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_OPTIONS; + +// graph options +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string 
ATTR_NAME_GRAPH_OPTIONS; +// host tensor +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HOST_TENSOR; +// whether the dynamic-batch (gear) feature is supported +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ENABLE_DYNAMIC_BATCH; +// indexes of the inputs that need gears +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_BATCH_UNKNOWN_DATA_INDEX; +// dim_index whose shape is -1 within the gears +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_BATCH_UNKNOWN_DIM_INDEX; +// for inplace ability +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_INPLACE_ABILITY; +// for opp_kernel +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BINARY_SOURCE; +} // namespace ge + +/*lint +e618*/ +#endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ diff --git a/inc/metadef/graph/debug/ge_log.h b/inc/metadef/graph/debug/ge_log.h new file mode 100644 index 0000000000000000000000000000000000000000..4dc00227c3540ee4abee05f1befcf2df909835ea --- /dev/null +++ b/inc/metadef/graph/debug/ge_log.h @@ -0,0 +1,29 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef COMMON_GRAPH_DEBUG_GE_LOG_H_ +#define COMMON_GRAPH_DEBUG_GE_LOG_H_ + +#include "graph/ge_error_codes.h" +#include "common/ge_common/debug/ge_log.h" +#include "common/ge_common/debug/log.h" +/* Reports an error through GE_LOG_ERROR using GE_MODULE_NAME and the ge::FAILED code. */ +#define GE_LOGE(fmt, ...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, fmt, ##__VA_ARGS__) + +// Evaluates expr once; when it is false the message is only logged through GELOGI (no return, no error code) +#define GE_CHK_BOOL_ONLY_LOG(expr, ...) \ + do { \ + const bool b = (expr); \ + if (!b) { \ + GELOGI(__VA_ARGS__); \ + } \ + } while (false) + +#endif // COMMON_GRAPH_DEBUG_GE_LOG_H_ + diff --git a/inc/metadef/graph/debug/ge_op_types.h b/inc/metadef/graph/debug/ge_op_types.h new file mode 100644 index 0000000000000000000000000000000000000000..0e8ae1b4c89fb1f7e78b681ed8bb603d8cfd3b1e --- /dev/null +++ b/inc/metadef/graph/debug/ge_op_types.h @@ -0,0 +1,93 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ +#define COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ + +#include "graph/types.h" +#include "graph/compiler_options.h" +/* Declares a static `const ge::char_t *` named var_name, initialised to str_name and tagged METADEF_ATTRIBUTE_UNUSED. */ +#define GE_REGISTER_OPTYPE(var_name, str_name) static const ge::char_t *var_name METADEF_ATTRIBUTE_UNUSED = str_name +namespace ge { +GE_REGISTER_OPTYPE(DATA, "Data"); +GE_REGISTER_OPTYPE(AIPPDATA, "AippData"); +GE_REGISTER_OPTYPE(QUEUE_DATA, "QueueData"); +GE_REGISTER_OPTYPE(REFDATA, "RefData"); +GE_REGISTER_OPTYPE(MATMUL, "MatMul"); +GE_REGISTER_OPTYPE(RESHAPE, "Reshape"); +GE_REGISTER_OPTYPE(PERMUTE, "Permute"); +GE_REGISTER_OPTYPE(NETOUTPUT, "NetOutput"); +GE_REGISTER_OPTYPE(_WHILE, "_While"); +GE_REGISTER_OPTYPE(WHILE, "While"); +GE_REGISTER_OPTYPE(CASE, "Case"); +GE_REGISTER_OPTYPE(STATELESSWHILE, "StatelessWhile"); +GE_REGISTER_OPTYPE(SQUEEZE, "Squeeze"); +GE_REGISTER_OPTYPE(EXPANDDIMS, "ExpandDims"); +GE_REGISTER_OPTYPE(SWITCH, "Switch"); +GE_REGISTER_OPTYPE(REFSWITCH, "RefSwitch"); +GE_REGISTER_OPTYPE(SWITCHN, "SwitchN"); +GE_REGISTER_OPTYPE(MERGE, "Merge"); +GE_REGISTER_OPTYPE(REFMERGE, "RefMerge"); +GE_REGISTER_OPTYPE(STREAMMERGE, "StreamMerge"); +GE_REGISTER_OPTYPE(STREAMSWITCH, "StreamSwitch"); +GE_REGISTER_OPTYPE(STREAMACTIVE, "StreamActive"); +GE_REGISTER_OPTYPE(LABELSET, "LabelSet"); +GE_REGISTER_OPTYPE(LABELGOTOEX, "LabelGotoEx"); +GE_REGISTER_OPTYPE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); +GE_REGISTER_OPTYPE(ENTER, "Enter"); +GE_REGISTER_OPTYPE(REFENTER, "RefEnter"); +GE_REGISTER_OPTYPE(NEXTITERATION, "NextIteration"); +GE_REGISTER_OPTYPE(REFNEXTITERATION, "RefNextIteration"); +GE_REGISTER_OPTYPE(CONSTANT, "Const"); +GE_REGISTER_OPTYPE(PLACEHOLDER, "PlaceHolder"); +GE_REGISTER_OPTYPE(CONSTPLACEHOLDER, "ConstPlaceHolder"); +GE_REGISTER_OPTYPE(END, "End"); +GE_REGISTER_OPTYPE(FRAMEWORKOP, "FrameworkOp"); +GE_REGISTER_OPTYPE(GETNEXT, "GetNext"); +GE_REGISTER_OPTYPE(INITDATA, 
"InitData"); +GE_REGISTER_OPTYPE(REFIDENTITY, "RefIdentity"); +GE_REGISTER_OPTYPE(ANN_DATA, "AnnData"); +GE_REGISTER_OPTYPE(ATOMICADDRCLEAN, "AtomicAddrClean"); +GE_REGISTER_OPTYPE(MEMSET, "MemSet"); + +GE_REGISTER_OPTYPE(CONSTANTOP, "Constant"); +GE_REGISTER_OPTYPE(VARIABLE, "Variable"); +GE_REGISTER_OPTYPE(VARIABLEV2, "VariableV2"); +GE_REGISTER_OPTYPE(VARHANDLEOP, "VarHandleOp"); +GE_REGISTER_OPTYPE(PARTITIONEDCALL, "PartitionedCall"); + +GE_REGISTER_OPTYPE(INPUT_TYPE, "Input"); +GE_REGISTER_OPTYPE(FILECONSTANT, "FileConstant"); + +// Horovod operator +GE_REGISTER_OPTYPE(HVDCALLBACKALLREDUCE, "hvdCallbackAllreduce"); +GE_REGISTER_OPTYPE(HVDCALLBACKALLGATHER, "hvdCallbackAllgather"); +GE_REGISTER_OPTYPE(HVDCALLBACKBROADCAST, "hvdCallbackBroadcast"); +GE_REGISTER_OPTYPE(HVDWAIT, "hvdWait"); + +GE_REGISTER_OPTYPE(NODE_NAME_NET_OUTPUT, "Node_Output"); + +GE_REGISTER_OPTYPE(RECV, "Recv"); +GE_REGISTER_OPTYPE(RECV_NOTIFY, "RecvNotify"); +GE_REGISTER_OPTYPE(SEND, "Send"); +GE_REGISTER_OPTYPE(SEND_NOTIFY, "SendNotify"); +GE_REGISTER_OPTYPE(NOOP, "NoOp"); +GE_REGISTER_OPTYPE(ASSIGN, "Assign"); +GE_REGISTER_OPTYPE(ASSIGNVARIABLEOP, "AssignVariableOp"); +GE_REGISTER_OPTYPE(ASSIGNADD, "AssignAdd"); +GE_REGISTER_OPTYPE(ASSIGNADDVARIABLEOP, "AssignAddVariableOp"); +GE_REGISTER_OPTYPE(ASSIGNSUB, "AssignSub"); +GE_REGISTER_OPTYPE(ASSIGNSUBVARIABLEOP, "AssignSubVariableOp"); +GE_REGISTER_OPTYPE(IDENTITY, "Identity"); +GE_REGISTER_OPTYPE(READVARIABLEOP, "ReadVariableOp"); + +GE_REGISTER_OPTYPE(PHONYCONCAT, "PhonyConcat"); +GE_REGISTER_OPTYPE(PHONYSPLIT, "PhonySplit"); +} // namespace ge +#endif // COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ diff --git a/inc/metadef/graph/debug/ge_util.h b/inc/metadef/graph/debug/ge_util.h new file mode 100644 index 0000000000000000000000000000000000000000..fdb2dd648850e131ff2b9b0dfd2e7b8478cab284 --- /dev/null +++ b/inc/metadef/graph/debug/ge_util.h @@ -0,0 +1,58 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef COMMON_GRAPH_DEBUG_GE_UTIL_H_ +#define COMMON_GRAPH_DEBUG_GE_UTIL_H_ + +#include "common/ge_common/util.h" +#include "graph/debug/ge_log.h" +#include "graph/ge_error_codes.h" +namespace ge { +template +static inline std::shared_ptr ComGraphMakeShared(Args &&...args) { /* make_shared wrapper: on std::bad_alloc it logs "Make shared failed" and returns nullptr */ + using T_nc = typename std::remove_const::type; + std::shared_ptr ret = nullptr; + try { + ret = std::make_shared(std::forward(args)...); + } catch (const std::bad_alloc &) { + ret = nullptr; + GELOGE(ge::FAILED, "Make shared failed"); + } + return ret; +} +template +struct ComGraphMakeUniq { /* trait selecting a ComGraphMakeUnique overload; each specialization exposes unique_object, unique_array or invalid_type */ + using unique_object = std::unique_ptr; +}; + +template +struct ComGraphMakeUniq { + using unique_array = std::unique_ptr; +}; + +template +struct ComGraphMakeUniq { + struct invalid_type { }; +}; +/* Single-object form: allocates with new (std::nothrow), so an allocation failure yields an empty unique_ptr instead of std::bad_alloc. */ +template +static inline typename ComGraphMakeUniq::unique_object ComGraphMakeUnique(Args &&... args) { + using T_nc = typename std::remove_const::type; + return std::unique_ptr(new (std::nothrow) T_nc(std::forward(args)...)); +} +/* Array form: value-initialises num elements allocated with new (std::nothrow). */ +template +static inline typename ComGraphMakeUniq::unique_array ComGraphMakeUnique(const size_t num) { + return std::unique_ptr(new (std::nothrow) typename std::remove_extent::type[num]()); +} +/* The overload whose return type is invalid_type is deleted. */ +template +static inline typename ComGraphMakeUniq::invalid_type ComGraphMakeUnique(Args &&...) 
= delete; +} // namespace ge +#endif // COMMON_GRAPH_DEBUG_GE_UTIL_H_ diff --git a/inc/metadef/graph/def_types.h b/inc/metadef/graph/def_types.h new file mode 100644 index 0000000000000000000000000000000000000000..1a040e78a95c8270b0fec976416b30425c304415 --- /dev/null +++ b/inc/metadef/graph/def_types.h @@ -0,0 +1,52 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_DEF_TYPES_H_ +#define INC_GRAPH_DEF_TYPES_H_ + +#include +#include +#include +namespace ge { +inline uint64_t PtrToValue(const void *const ptr) { /* returns the pointer's address as a uint64_t */ + return static_cast(reinterpret_cast(ptr)); +} +/* Inverse of PtrToValue: reinterprets an integer value as a pointer. */ +inline void *ValueToPtr(const uint64_t value) { + return reinterpret_cast(static_cast(value)); +} +/* Applies PtrToValue to every element of v_ptr and returns the results in the same order. */ +inline std::vector VPtrToValue(const std::vector v_ptr) { + std::vector v_value; + for (const auto &ptr : v_ptr) { + v_value.emplace_back(PtrToValue(ptr)); + } + return v_value; +} +/* reinterpret_cast wrapper between pointer types. */ +template +inline TO *PtrToPtr(TI *const ptr) { + return reinterpret_cast(ptr); +} +/* Const overload of PtrToPtr. */ +template +inline const TO *PtrToPtr(const TI *const ptr) { + return reinterpret_cast(ptr); +} +/* Bounds-checked offset: returns ptr + idx only when ptr is non-null and idx < max_buf_len, otherwise nullptr. */ +template +inline T *PtrAdd(T *const ptr, const size_t max_buf_len, const size_t idx) { + if ((ptr != nullptr) && (idx < max_buf_len)) { + return reinterpret_cast(ptr + idx); + } + return nullptr; +} +} // namespace ge + +#endif // INC_GRAPH_DEF_TYPES_H_ diff --git 
a/inc/metadef/graph/detail/any_map.h b/inc/metadef/graph/detail/any_map.h new file mode 100644 index 0000000000000000000000000000000000000000..758eebf0f50501a0bdbbc92233b1d24feea5d113 --- /dev/null +++ b/inc/metadef/graph/detail/any_map.h @@ -0,0 +1,144 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_DETAIL_ANY_MAP_H_ +#define INC_GRAPH_DETAIL_ANY_MAP_H_ + +#include +#include +#include +#include +#include + +#include "graph/compiler_options.h" + +namespace ge { +class TypeID { /* type tag compared by string; Of() builds it from METADEF_FUNCTION_IDENTIFIER */ + public: + template + static TypeID Of() { + return TypeID(METADEF_FUNCTION_IDENTIFIER); + } + + ~TypeID() = default; + + bool operator==(const TypeID &other) const { return type_ == other.type_; } + bool operator!=(const TypeID &other) const { return type_ != other.type_; } + + private: + explicit TypeID(const std::string &type) : type_(type) {} + + std::string type_; +}; +/* Name-keyed container of type-erased values; every entry reports the TypeID it was stored with. */ +class AnyMap { + public: + template + bool Set(const std::string &name, const DT &val); + + template + bool Get(const std::string &name, T &ret_value) const; + + template + const T *Get(const std::string &name) const; + + bool Has(const std::string &name) const { return any_values_.find(name) != any_values_.end(); } + + void Swap(AnyMap &other) { + any_values_.swap(other.any_values_); + } + + void Names(std::set &names) const { + 
for (const auto &item : any_values_) { + (void)names.emplace(item.first); + } + } + bool Erase(const std::string &name) { return (any_values_.erase(name) > 0);} + + private: + class Placeholder { + public: + Placeholder() = default; + virtual ~Placeholder() = default; + Placeholder(const Placeholder &) = delete; + Placeholder &operator=(const Placeholder &) = delete; + Placeholder(Placeholder &&) = delete; + Placeholder &operator=(Placeholder &&) = delete; + virtual const TypeID &GetTypeInfo() const = 0; + }; + + template + class Holder : public Placeholder { + public: + explicit Holder(const VT &value) : Placeholder(), value_(value) {} + + ~Holder() override = default; + + const TypeID &GetTypeInfo() const override { + static const TypeID typeId = TypeID::Of(); + return typeId; + } + + friend class AnyMap; + + private: + const VT value_; + }; + + std::map> any_values_; +}; +/* Stores a copy of val under name. Returns false when creating the holder throws, or when name already holds a value of a different type; a value of the same type is replaced. */ +template +bool AnyMap::Set(const std::string &name, const DT &val) { + const auto it = any_values_.find(name); + + std::shared_ptr> tmp; + try { + tmp = std::make_shared>(val); + } catch (...) { + tmp = nullptr; + return false; + } + + if (it == any_values_.end()) { + (void) any_values_.emplace(name, tmp); + } else { + if (it->second && (it->second->GetTypeInfo() == TypeID::Of
())) { + it->second = tmp; + } else { + return false; + } + } + return true; +} +/* Copies the stored value into ret_value; returns false when the pointer-returning Get yields nullptr. */ +template +bool AnyMap::Get(const std::string &name, T &ret_value) const { + auto tp = Get(name); + if (tp == nullptr) { + return false; + } + ret_value = *tp; + return true; +} +/* Returns a pointer to the stored value, or nullptr when name is absent or was stored with a different type. */ +template +const T *AnyMap::Get(const std::string &name) const { + const auto iter = any_values_.find(name); + if (iter == any_values_.end()) { + return nullptr; + } + if (iter->second->GetTypeInfo() != TypeID::Of()) { + return nullptr; + } + const auto holder = static_cast*>(iter->second.get()); + return &holder->value_; +} +} // namespace ge +#endif // INC_GRAPH_DETAIL_ANY_MAP_H_ diff --git a/inc/metadef/graph/detail/attributes_holder.h b/inc/metadef/graph/detail/attributes_holder.h new file mode 100644 index 0000000000000000000000000000000000000000..123818741e46b87cd089d2cfb9727df7f23b746f --- /dev/null +++ b/inc/metadef/graph/detail/attributes_holder.h @@ -0,0 +1,232 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_ +#define INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "graph/detail/any_map.h" +#include "graph/ge_error_codes.h" +#include "graph/types.h" +#include "graph/attr_store.h" + +namespace google { +namespace protobuf { +class Message; +template +class Map; +} // namespace protobuf +} // namespace google + +namespace ge { +namespace proto { +class AttrDef; +class TensorDef; +class TensorDescriptor; +class ShapeDef; +class NamedAttrs; +class ModelDef; +class OpDef; +class GraphDef; +} // namespace proto + +using ProtoAttrMap = AttrStore; +using ConstProtoAttrMap = const AttrStore; +using ProtoMsgOwner = std::shared_ptr<::google::protobuf::Message>; +/* Pairs a raw protobuf message pointer (protoMsg_) with the shared owner (protoOwner_) it is part of. */ +template +class GeIrProtoHelper { + public: + GeIrProtoHelper(const ProtoMsgOwner &protoOwner, ProtoType *const protoMsg) + : protoOwner_(protoOwner), protoMsg_(protoMsg) {} + + GeIrProtoHelper() { + protoOwner_ = std::shared_ptr<::google::protobuf::Message>(nullptr); + protoMsg_ = nullptr; + } + virtual ~GeIrProtoHelper() = default; + + template + GeIrProtoHelper(const GeIrProtoHelper &other) { + protoOwner_ = other.protoOwner_; + protoMsg_ = other.protoMsg_; + } + + GeIrProtoHelper(const GeIrProtoHelper &other) { + protoOwner_ = other.protoOwner_; + protoMsg_ = other.protoMsg_; + } + template + GeIrProtoHelper &operator=(const GeIrProtoHelper &other) { + protoOwner_ = other.protoOwner_; + protoMsg_ = other.protoMsg_; + return *this; + } + + GeIrProtoHelper &operator=(const GeIrProtoHelper &other) { + if (this != &other) { + protoOwner_ = other.protoOwner_; + protoMsg_ = other.protoMsg_; + } + return *this; + } + void InitDefault(); + template + bool operator==(const GeIrProtoHelper &other) const { + return (protoOwner_ == other.protoOwner_) && (protoMsg_ == other.protoMsg_); + } +
+ inline const ProtoMsgOwner &GetProtoOwner() const { + return protoOwner_; + } + inline ProtoType *GetProtoMsg() const { + return protoMsg_; + } + void CopyValueFrom(const GeIrProtoHelper &other) { + if ((other.protoMsg_ != nullptr) && (protoMsg_ != nullptr)) { + *protoMsg_ = *other.protoMsg_; + } + } + void MoveValueFrom(GeIrProtoHelper &&other) { + if ((other.protoMsg_ != nullptr) && (protoMsg_ != nullptr)) { + *protoMsg_ = std::move(*other.protoMsg_); + } + } + + void Swap(GeIrProtoHelper &other) { + protoOwner_.swap(other.protoOwner_); + + ProtoType *const temp = protoMsg_; + protoMsg_ = other.protoMsg_; + other.protoMsg_ = temp; + } + + friend class GeIrProtoHelper::value, typename std::remove_const::type, const ProtoType>::type>; + friend class ComputerGraphImpl; + friend class GeTensorSerializeUtils; + + private: + // protoMsg_ is part of protoOwner_, they have the same runtime + ProtoMsgOwner protoOwner_ = nullptr; + ProtoType *protoMsg_ = nullptr; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrHolder { + public: + AttrHolder() = default; + virtual ~AttrHolder() = default; + /** + * Sets an attribute named `name` with value `value` on this AttrHolder. + * Note: if an attribute named `name` already exists on this object, its value is refreshed. + * @param name + * @param value + * @return GRAPH_SUCCESS on success, GRAPH_FAILED on failure + */ + graphStatus SetAttr(const std::string &name, const AnyValue &value); + /** + * Tries to set an attribute named `name` with value `value` on this AttrHolder. + * Note: if an attribute named `name` already exists on this object, its value is NOT refreshed. + * @param name + * @param value + * @return GRAPH_SUCCESS on success, GRAPH_FAILED on failure + */ + graphStatus TrySetAttr(const std::string &name, const AnyValue &value); + + graphStatus GetAttr(const std::string &name, AnyValue &value) const; + + bool HasAttr(const std::string &name) const; + bool HasRequiredAttr(const std::string &name) const { + return required_attrs_and_type_.find(name) != required_attrs_and_type_.end(); + } + graphStatus DelAttr(const std::string &name); + + void CopyAttrsFrom(const AttrHolder &holder); + +
void CopyFrom(const AttrHolder &holder); + + void SwapBase(AttrHolder &holder) { + required_attrs_and_type_.swap(holder.required_attrs_and_type_); + ext_attrs_.Swap(holder.ext_attrs_); + } + /** + * Sets an attribute with name `name`, value `value` and type T on this object. If the object already + * holds an attribute with that name and type T, its value is refreshed. Note that if the object already + * holds an attribute with that name but a type other than T, the set operation fails. + * @param name attribute name + * @param value attribute value of any type + * @return true/false true when the attribute was set, false when setting failed + */ + template + bool SetExtAttr(const std::string &name, const T &value) { + return ext_attrs_.Set(name, value); + } + /** + * Tries to get the value of the attribute with name `name` and type T. The lookup fails when the + * object has no attribute named `name` or the attribute's type is not T. + * @param name attribute name + * @param defaultValue default value, returned when the lookup fails + * @return the attribute value when the lookup succeeds, otherwise the passed-in default value + */ + template + T TryGetExtAttr(const std::string &name, const T defaultValue) const { + T ret(defaultValue); + (void) ext_attrs_.Get(name, ret); + return ret; + } + + /** + * Tries to get the value of the attribute with name `name` and type T. The lookup fails when the + * object has no attribute named `name` or the attribute's type is not T. + * @param name attribute name + * @return pointer to the attribute value when the lookup succeeds, otherwise a null pointer + */ + template + const T *GetExtAttr(const std::string &name) const { + return ext_attrs_.Get(name); + } + /* Non-const overload: strips constness from the pointer returned by ext_attrs_.Get. NOTE(review): AnyMap::Holder declares its value as const, so writing through this pointer looks like undefined behavior - confirm. */ + template + T *GetExtAttr(const std::string &name) { + return const_cast(ext_attrs_.Get(name)); + } + + bool DelExtAttr(const std::string &name) { + return ext_attrs_.Erase(name); + } + graphStatus AddRequiredAttr(const std::string &name); + graphStatus AddRequiredAttrWithType(const std::string &name, const std::string &type); + const std::unordered_map &GetRequiredAttrWithType() const { + return required_attrs_and_type_; + } + protected: + const std::set GetAllAttrNames() const; + const std::map GetAllAttrs() const; + const std::map GetAllAttrsWithFilter(const AttrNameFilter &attr_filter) const; + + virtual ProtoAttrMap &MutableAttrMap() = 0; + virtual ConstProtoAttrMap &GetAttrMap() const = 0; + + friend class ModelSerializeImp; + friend class AttrUtils; + friend class OpDescUtils; + friend class GraphUtils; + std::unordered_map 
required_attrs_and_type_; + + private: + AnyMap ext_attrs_; +}; +} // namespace ge +#endif // INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_ diff --git a/inc/metadef/graph/detail/model_serialize_imp.h b/inc/metadef/graph/detail/model_serialize_imp.h new file mode 100644 index 0000000000000000000000000000000000000000..b39209c2f5a9eba5c58e54adc93b26904db91847 --- /dev/null +++ b/inc/metadef/graph/detail/model_serialize_imp.h @@ -0,0 +1,159 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef INC_GRAPH_DETAIL_MODEL_SERIALIZE_IMP_H_ +#define INC_GRAPH_DETAIL_MODEL_SERIALIZE_IMP_H_ + +#include +#include +#include +#include +#include +#include "graph/buffer.h" +#include "graph/model.h" +#include "graph/anchor.h" +#include "graph/detail/attributes_holder.h" +#include "graph/ge_tensor.h" +#include "graph/graph.h" +#include "graph/node.h" +#include "external/ge_common/ge_api_types.h" + +namespace ge { +using ComputeGraphPtr = std::shared_ptr; +using AnchorWithIndex = std::pair; +struct MyCmp { /* orders AnchorWithIndex entries by their second member */ + bool operator()(const AnchorWithIndex &anchor1, const AnchorWithIndex &anchor2) const { + return anchor1.second < anchor2.second; + } +}; +using DstAnchors = std::set; +/* Record of a node name, an index and a graph; its fields are private and ModelSerializeImp is a friend. */ +struct NodeNameGraphReq { + public: + NodeNameGraphReq(const std::string &name, const int32_t index, const ComputeGraphPtr &graph) + : node_name(name), index(index), graph(graph) {} + friend class ModelSerializeImp; + + private: + std::string node_name; + int32_t index; + ComputeGraphPtr graph; +}; +/* Record of a source node name/indices and a destination node/index/name; ModelSerializeImp is a friend. */ +struct NodeNameNodeReq { + public: + NodeNameNodeReq(const std::string &src_name, const int32_t src_index, const int32_t src_out_peer_index, + const NodePtr dst_node, const int32_t dst_index, const std::string &dst_name) + : src_node_name(src_name), + src_out_index(src_index), + src_out_peer_index(src_out_peer_index), + dst_node(dst_node), + dst_in_index(dst_index), + dst_node_name(dst_name) {} + + friend class ModelSerializeImp; + private: + std::string src_node_name; + int32_t src_out_index; + int32_t src_out_peer_index; + NodePtr dst_node; + int32_t dst_in_index; + std::string dst_node_name; +}; +/* Declares the conversions between Model/ComputeGraph/Node/OpDesc objects and the proto::ModelDef/GraphDef/OpDef messages. */ +class ModelSerializeImp { + public: + bool SerializeModel(const Model &model, proto::ModelDef *const model_proto, const bool not_dump_all = false) const; + // if is_dump_graph is true, ensure peer anchors of node in the same order during serialization and deserialization + // 
if is_dump_graph is false, cannot guarantee peer anchors in the same order during serialization and deserialization + bool SerializeModel(const Model &model, const bool is_dump_graph, proto::ModelDef *const model_proto, + const bool not_dump_all = false) const; + + bool SerializeGraph(const ConstComputeGraphPtr &graph, proto::GraphDef *const graph_proto, + const bool not_dump_all = false) const; + bool SerializeGraph(const ConstComputeGraphPtr &graph, const bool is_dump_graph, proto::GraphDef *const graph_proto, + const bool not_dump_all = false) const; + + bool SerializeEdge(const NodePtr &node, proto::OpDef *const op_def_proto, const bool is_dump_graph = false) const; + + bool SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *const op_def_proto, + const bool not_dump_all = false) const; + + bool SerializeNode(const NodePtr &node, proto::OpDef *const op_def_proto, const bool not_dump_all = false) const; + bool SerializeNode(const NodePtr &node, const bool is_dump_graph, proto::OpDef *const op_def_proto, + const bool not_dump_all = false) const; + + bool SeparateModelDef(Buffer &buffer, const std::string &path, proto::ModelDef &model_def) const; + + bool SerializeToBuffer(const proto::ModelDef &model_def, Buffer &buffer) const; + + bool UnserializeModel(Model &model, proto::ModelDef &model_proto, + const bool is_enable_multi_thread = false); + bool SetWeightForModel(proto::OpDef &op_def) const; + + bool UnserializeGraphWithoutEdge(ComputeGraphPtr &graph, proto::GraphDef &graph_proto); + + bool UnserializeGraph(ComputeGraphPtr &graph, proto::GraphDef &graph_proto); + + bool HandleNodeNameRef(); + + void AttrDefToOpDescIrDef(OpDescPtr &op_desc, proto::OpDef &op_def_proto) const; + bool UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_def_proto) const; + void AttrDefToOpDescIn(OpDescPtr &op_desc, std::vector &key_in, std::vector &value_in) const; + void AttrDefToOpDesc(OpDescPtr &op_desc, std::vector &key_out, std::vector &value_out, + const 
std::vector &opt_input) const; + void OpDescToAttrDef(const ConstOpDescPtr &op_desc, proto::OpDef *const op_def_proto, + const bool not_dump_all = false) const; + void OpDescIrDefToAttrDef(const ConstOpDescPtr &op_desc, + google::protobuf::Map *op_desc_attr) const; + bool UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &op_def_proto); + + bool ParseNodeIndex(const std::string &node_index, std::string &node_name, int32_t &index) const; + + void SetProtobufOwner(const ProtoMsgOwner &buffer_proto_buf_onwer) { protobuf_owner_ = buffer_proto_buf_onwer; } + + bool LoadWeightFromFile(const std::string &file_path, const int64_t &length, std::string &weight) const; + + void SetAirModelPath(const std::string &path) { air_path_ = path; } + static bool SerializeAllAttrsFromAnyMap(const std::map &attr_map, + google::protobuf::Map *const mutable_attr); + static bool DeserializeAllAttrsToAttrHolder( + const google::protobuf::Map &proto_attr_map, AttrHolder *const attr_holder); + + private: + bool RebuildOwnership(ComputeGraphPtr &compute_graph, std::map &subgraphs) const; + Status ParallelUnserializeGraph( + std::map &graphs, + ::google::protobuf::RepeatedPtrField &graphs_proto); + Status UnserializeGraph( + std::map &graphs, + ::google::protobuf::RepeatedPtrField &graphs_proto); + void FixOpDefSubgraphInstanceName(const ConstOpDescPtr &op_desc) const; + + void ExtractMetaDataAttrIn(proto::OpDef &op_def_proto, std::vector &opt_input, + std::vector &key_in, std::vector &value_in) const; + void ExtractMetaDataAttr(proto::OpDef &op_def_proto, std::vector &key_out, + std::vector &value_out) const; + + int64_t GenDataInputInfo(const OutDataAnchorPtr &src_anchor, const InDataAnchorPtr &dst_anchor) const; + int64_t GenCtrlInputInfo(const OutControlAnchorPtr &src_anchor, const InControlAnchorPtr &dst_anchor) const; + void SaveEdgeInfo(const AnchorPtr &src_anchor, const AnchorPtr &dst_anchor, const int64_t src_out_peer_index, + const int64_t cur_index, std::unordered_map &edges) 
const; + bool LinkEdges(const std::unordered_map &edges) const; + + std::vector graph_input_node_names_; + std::vector graph_output_node_names_; + std::vector node_input_node_names_; + std::map node_map_; + ProtoMsgOwner protobuf_owner_; + std::string air_path_; // stores the air model path (assigned by SetAirModelPath) +}; +} // namespace ge + +#endif // INC_GRAPH_DETAIL_MODEL_SERIALIZE_IMP_H_ diff --git a/inc/metadef/graph/fast_graph/edge.h b/inc/metadef/graph/fast_graph/edge.h new file mode 100644 index 0000000000000000000000000000000000000000..148ec76d4583e1138640e4fd5af166ab3bf54c7b --- /dev/null +++ b/inc/metadef/graph/fast_graph/edge.h @@ -0,0 +1,64 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_EDGE_H +#define INC_GRAPH_EDGE_H + +#include "graph/anchor.h" + +namespace ge { +constexpr int32_t kInvalidEdgeIndex = -2; +constexpr int32_t kControlEdgeIndex = -1; + +enum class DirectionType { + kDirectionInType, + kDirectionOutType, +}; +/* Connection record between two endpoints of type T; every field starts unset (nullptr or -1). */ +template +struct Edge { + T *src = nullptr; // src node or output node + T *dst = nullptr; // dst node or input node + int32_t src_output = -1; // the output index of output node + int32_t dst_input = -1; // the input index of input node + int32_t in_edge_index = -1; // the record index of input node, it is used to quickly find the edge in node.
+ int32_t out_edge_index = -1; // the record index of output node, it is used to quickly find the edge in node. + Anchor *src_anchor_ptr = nullptr; // the reserved information. + Anchor *dst_anchor_ptr = nullptr; // the reserved information. +}; + +class FastNode; + +struct EdgeEndpoint { + FastNode *node; + int32_t index; + DirectionType type; +}; +struct EdgeEndpointWithDirection { + EdgeEndpointWithDirection() : node(nullptr), index(kInvalidEdgeIndex) {} + EdgeEndpointWithDirection(FastNode *n, int32_t i) : node(n), index(i) {} + bool operator<(const EdgeEndpointWithDirection &rhs) const { /* orders by node pointer first, then by index */ + if (node < rhs.node) { + return true; + } + if (node > rhs.node) { + return false; + } + return index < rhs.index; + } + bool operator==(const EdgeEndpointWithDirection &rhs) const { + return (node == rhs.node) && (index == rhs.index); + } + FastNode *node; + int32_t index; +}; +using EdgeDstEndpoint = EdgeEndpointWithDirection; +using EdgeSrcEndpoint = EdgeEndpointWithDirection; +} // namespace ge +#endif // INC_GRAPH_EDGE_H diff --git a/inc/metadef/graph/fast_graph/execute_graph.h b/inc/metadef/graph/fast_graph/execute_graph.h new file mode 100644 index 0000000000000000000000000000000000000000..056db5e4ea69d56ecc4730a80174252604512fba --- /dev/null +++ b/inc/metadef/graph/fast_graph/execute_graph.h @@ -0,0 +1,254 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef INC_GRAPH_EXECUTE_GRAPH_H +#define INC_GRAPH_EXECUTE_GRAPH_H + +#include +#include +#include +#include "graph/fast_graph/fast_node.h" +#include "graph/fast_graph/list_element.h" + +namespace ge { +template +class FastGraphImpl; + +using FastNodeFilter = std::function; + +class ExecuteGraph : public std::enable_shared_from_this, public AttrHolder { + public: + struct SubGraphInfo { + std::shared_ptr sub_graph; + ListElement *quick_graph; + }; + + explicit ExecuteGraph(const std::string &name); + ~ExecuteGraph() override{}; + + /** + * The function is shallow copy for ExecuteGraph + */ + ExecuteGraph &operator=(ge::ExecuteGraph &exec_graph); + + /** + * The function is deep copy for ExecuteGraph + */ + ExecuteGraph &CompleteCopy(ge::ExecuteGraph &exec_graph); + + /** + * The function is used to add node of graph. + * The node push back to container in graph. + */ + FastNode *AddNode(const OpDescPtr &op); + FastNode *AddNode(FastNode *const fast_node); + FastNode *AddNode(const OpDescPtr &op, int64_t id); + + /** + * The function is used to add node of graph. + * The node push front to container in graph. + */ + FastNode *AddNodeFront(const OpDescPtr &op); + FastNode *AddNodeFront(FastNode *const fast_node); + + /** + * The function is used to remove node of graph. + * The node don`t release and it will push to free container which is used to store free obj. + */ + graphStatus RemoveJustNode(const FastNode *const fast_node); + + /** + * The function is used to add input node of graph. + */ + FastNode *AddInputNode(FastNode *const fast_node); + + /** + * The function is used to remove input node of graph. + */ + graphStatus RemoveInputNode(FastNode *const fast_node); + + /** + * The function is used to add output node of graph. 
+ */ + FastNode *AddOutputNodeByIndex(FastNode *const fast_node, int32_t index); + + /** + * The function is used to remove output node of graph. + */ + graphStatus RemoveOutputNode(const FastNode *const fast_node); + + /** + * The function is used to add edge of graph. + */ + FastEdge *AddEdge(FastNode *const src, int32_t src_index, FastNode *const dst, int32_t dst_index); + + /** + * The function is used to remove edge of graph. + * The edge don`t release and it will push to free container which is used to store free obj. + */ + graphStatus RemoveEdge(const FastEdge *const edge); + + const FastNode *GetParentNodeBarePtr() const; + FastNode *GetParentNodeBarePtr(); + void SetParentNode(FastNode *const node); + + /** + * The function is used to directly add subgraph of graph without any check. + * The shared pointer of subgraph will record in graph. + */ + ExecuteGraph *AddSubGraph(const std::shared_ptr &sub_graph); + + /** + * The function will add subgraph After strict checking the valid of subgraph. + * The shared pointer of subgraph will record in graph. + */ + ExecuteGraph *AddSubGraph(const std::shared_ptr &sub_graph_ptr, const std::string &name); + + /** + * The function is used to remove subgraph of graph. + * The shared pointer of subgraph will clear in graph. + */ + graphStatus RemoveSubGraph(const ExecuteGraph *const sub_graph); + graphStatus RemoveSubGraph(const std::string &name); + + /** + * The function is used to get subgraph with name. + */ + ExecuteGraph *GetSubGraph(const std::string &name) const; + + /** + * remove all subgraph from parent graph. + */ + void ClearAllSubGraph(); + + /** + * get the number of direct nodes form graph. + */ + size_t GetDirectNodesSize() const; + + /** + * get direct nodes from graph (it is convert to vector which is long time). + * external modifications don`t affect internal nodes. + */ + std::vector GetDirectNode() const; + + /** + * get all edges from graph (it is convert to vector which is long time). 
+ * external modifications don`t affect internal edges. + */ + std::vector GetAllEdges() const; + + /** + * get all sub graph from graph (it is convert to vector which is long time). + * external modifications don`t affect internal edges. + */ + std::vector GetAllSubgraphs() const; + + /** + * find the node with node token in the graph. + */ + const FastNode *FindNode(size_t token) const; + + /** + * is is topo sort (include dfs, bfs, DFS_POSTORDER). + */ + graphStatus TopologicalSortingGraph(const ExecuteGraph *const execute_graph, const bool dfs_reverse); + + /** + * get name of graph. + */ + std::string GetName() const; + + /** + * set name of graph. + */ + void SetName(const std::string &name); + + void SetParentGraph(ExecuteGraph *const parent_graph); + + const ExecuteGraph *GetParentGraphBarePtr(void) const; + ExecuteGraph *GetParentGraphBarePtr(void); + + /** + * topo sort in the graph (include sub graph). + */ + graphStatus TopologicalSorting(); + + /** + * push edge to free edge. + */ + graphStatus RecycleQuickEdge(const FastEdge *const fast_edge); + + /** + * push node to free edge. + */ + graphStatus RecycleQuickNode(const FastNode *const fast_node); + + /** + * get all of nodes in graph (include subgraph). + */ + std::vector GetAllNodes() const; + std::vector GetAllNodes(const FastNodeFilter &fast_node_filter) const; + + /** + * It is used to set input order which is used in topo sorting + */ + void SetInputsOrder(const std::vector &inputs_order); + + void ReorderByNodeId(); + + void SetGraphId(size_t graph_id); + + size_t GetGraphId() const; + + bool CheckNodeIsInGraph(const FastNode *const node) const; + + bool CheckEdgeIsInGraph(const FastEdge *const edge) const; + + /** + * The edge belong to graph. + * somtime, we need to change the owner of edge to correct graph. 
+ */ + graphStatus MoveEdgeToGraph(const FastEdge *const edge); + + protected: + ProtoAttrMap &MutableAttrMap() override; + ConstProtoAttrMap &GetAttrMap() const override; + + private: + std::vector AllGraphNodes(std::vector> &subgraphs, + const FastNodeFilter &fast_node_filter) const; + void GetAllNodesFromOpdesc(std::vector> &subgraphs, const OpDesc &op_desc, + std::deque &candidates) const; + void RemoveNodeFromNodesFree(const FastNode *const fast_node) const; + graphStatus SortNodes(std::vector &stack, std::map &map_in_edge_num) const; + void GetOutNodesFromEdgesToMap(std::map &map_in_edge_num, FastNode *node, + std::map &breadth_node_map) const; + graphStatus CollectBreadthOutNode(const FastNode *const node, std::map &map_in_edge_num, + std::map &breadth_node_map) const; + graphStatus BFSTopologicalSorting(std::vector &node_vec, const bool reverse, + const ExecuteGraph *const compute_graph) const; + graphStatus DFSTopologicalSorting(std::vector &node_vec, const bool reverse, + const ExecuteGraph *const compute_graph) const; + graphStatus RDFSTopologicalSorting(std::vector &node_vec, const bool reverse, + const ExecuteGraph *const compute_graph) const; + void GetInNodes(const FastNode *const current, std::vector &input_nodes) const; + + private: + std::shared_ptr> graph_shared_; + std::unordered_map names_to_subgraph_; + std::vector inputs_order_; + AttrStore attrs_; + + friend class ExecuteGraphAdapter; + friend class ExecuteGraphUtils; +}; +using ExecuteGraphPtr = std::shared_ptr; +} // namespace ge +#endif // INC_GRAPH_EXECUTE_GRAPH_H diff --git a/inc/metadef/graph/fast_graph/fast_node.h b/inc/metadef/graph/fast_graph/fast_node.h new file mode 100644 index 0000000000000000000000000000000000000000..9fc0fb7da21d8b58aad9154dee99bbbb94c1aeff --- /dev/null +++ b/inc/metadef/graph/fast_graph/fast_node.h @@ -0,0 +1,433 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef D_INC_GRAPH_FAST_NODE_H +#define D_INC_GRAPH_FAST_NODE_H + +#include +#include +#include +#include +#include "graph/op_desc.h" +#include "graph/fast_graph/edge.h" + +namespace ge { +struct OutDataEdgeStatisticsInfo { + size_t total_num = 0UL; // The total number edge of all inputs or all outputs + std::vector per_edges_num; // The number of one input or one output. +}; +constexpr uint64_t kInvalidSymbol = UINT64_MAX; + +class ExecuteGraph; +class FastNode; +using FastEdge = Edge; + +class ExtendInfo { + public: + virtual ~ExtendInfo() {} + /** + * set the input index of graph. + */ + void SetInputIndex(int32_t idx); + + /** + * get the input index of graph. + * Don`t use this function unless it is explicitly required. + */ + int32_t GetInputIndex() const; + + /** + * add the output index of graph. + * Don`t use this function unless it is explicitly required. + */ + void AddOneOutputIndex(int32_t idx); + + /** + * get the output index of graph. + * Don`t use this function unless it is explicitly required. + */ + std::vector &GetOutputIndex(); + + /** + * get the owner graph of node. + */ + ExecuteGraph *GetOwnerGraphBarePtr() const; + + /** + * set the owner graph of node. + */ + graphStatus SetOwnerGraph(ExecuteGraph *const graph, const FastNode *const fast_node); + + /** + * check the extend information is same with r_info. 
+ */ + bool operator==(const ExtendInfo &r_info) const; + + /** + * get the flag of host node. + */ + bool GetHostNode() const; + + /** + * set the flag of host node. + */ + void SetHostNode(const bool is_host); + + /** + * clear members. + */ + void Clear(); + + /** + * update size of input_symbols. + */ + void UpdateInputSymbols(size_t data_in_num); + + /** + * update size of output_symbols. + */ + void UpdateOutputSymbols(size_t data_out_num); + + /** + * set symbol of the idx data input + */ + graphStatus SetInputSymbol(size_t idx, uint64_t symbol); + + /** + * set symbol of the idx data output + */ + graphStatus SetOutputSymbol(size_t idx, uint64_t symbol); + + /** + * get symbol of the idx data input + */ + uint64_t GetInputSymbol(size_t idx); + + /** + * get symbol of the idx data output + */ + uint64_t GetOutputSymbol(size_t idx); + + private: + bool IsDataIndexValid(size_t idx, const std::vector &symbols) const; + ExecuteGraph *execute_graph_ = nullptr; + std::vector output_index_; + int32_t input_index_ = kControlEdgeIndex; + bool host_node_ = false; + std::vector input_symbols_{}; + std::vector output_symbols_{}; +}; + +class FastNode { + public: + /** + * construct a fastnode. + * please call Init after construction + */ + FastNode(); + + ~FastNode(); + /** + * The function is used to init node with opdesc. + */ + graphStatus Init(const OpDescPtr &op); + + /** + * get the bare pointer of op desc. + */ + OpDesc *GetOpDescBarePtr() const; + + /** + * get the shared pointer of op desc. + */ + OpDescPtr GetOpDescPtr() const; + + /** + * get the type of node. + */ + std::string GetType() const; + + /** + * get the type of node. + */ + const char *GetTypePtr() const; + + /** + * get the name of node. + */ + std::string GetName() const; + + /** + * get the name of node. + */ + const char *GetNamePtr() const; + + /** + * record the edge info to node. + * The funcion is not recommended, please used the AddEdge funcion istead. 
+ */ + graphStatus RecordEdge(Edge *const edge, DirectionType type); + + /** + * clear the edge info to node. + * The funcion is not recommended, please used the RemoveEdge funcion istead. + */ + graphStatus EraseEdge(const Edge *const edge, DirectionType type); + + /** + * adjust the position of the edge in the node record. + */ + graphStatus MoveEdge(DirectionType type, int32_t io_idx, int32_t cur_array_index, int32_t replace_array_index); + + /** + * get a unique identifier of node. + * please notes the unique identifier is in the graph. + */ + size_t GetNodeToken() const; + + /** + * get the number of input data in the node. + */ + size_t GetDataInNum() const; + + /** + * get the number of output data in the node. + */ + size_t GetDataOutNum() const; + + /** + * update the number of data input in the node. + */ + void UpdateDataInNum(size_t new_num); + + /** + * update the number of data ouput in the node. + */ + void UpdateDataOutNum(size_t new_num); + + /** + * get the total number of output edges from the node. + */ + size_t GetAllOutEdgesSize() const; + size_t GetAllOutDataEdgesSize() const; + size_t GetAllOutControlEdgesSize() const; + + /** + * get the total number of input edges from the node. + */ + size_t GetAllInDataEdgesSize() const; + size_t GetAllInControlEdgesSize() const; + + /** + * check the node is same with r_node. + */ + bool operator==(const FastNode &r_node) const; + + /** + * get the total number of in edge from the node. + * the number include data edge and control edge. + */ + size_t GetAllInEdgeSize() const; + + /** + * collecting all input edge. + * please check the item, the item from vector may be nullptr. + * if the item is nullptr, it just continue to get next, no error handing is required. + */ + const std::vector *> &GetAllInDataEdgesRef() const; + + /** + * collecting all output edge. + * please check the item, the item from vector may be nullptr. 
+ * if the item is nullptr, it just continue to get next, no error handing is required. + */ + const std::vector *> &GetAllOutControlEdgesRef() const; + const std::vector *>> &GetAllOutDataEdgesRef() const; + + /** + * collecting all output or input edge. + * it already filter the nullptr item. + */ + std::vector *> GetAllOutDataEdges() const; + std::vector *> GetAllOutControlEdges() const; + std::vector *> GetAllInDataEdges() const; + std::vector *> &MutableAllInDataEdges(); + + /** + * collecting input control edges with input index. + * please check the item, the item from vector may be nullptr. + * if the item is nullptr, it just continue to get next, no error handing is required. + */ + std::vector *> GetAllInControlEdges() const; + const std::vector *> &GetAllInControlEdgesRef() const; + + /** + * Check the number of out edge is zero. + */ + bool OutNodesIsEmpty() const; + + /** + * Set the relative node information. + * Don`t use this function unless it is explicitly required. + */ + void SetNodePtr(const std::shared_ptr &node); + + /** + * clear the relative node information. + * Don`t use this function unless it is explicitly required. + */ + void ClearNodePtr(); + void ClearNodeBarePtr(); + + /** + * get the relative node information. + * Don`t use this function unless it is explicitly required. + */ + std::shared_ptr GetNodePtr() const; + Node *GetNodeBarePtr() const; + + /** + * get the total number of edge with input index. + */ + size_t GetInEdgesSizeByIndex(int32_t idx) const; + + /** + * get the total number of edge with output index. + */ + size_t GetOutEdgesSizeByIndex(int32_t idx) const; + + /** + * collecting input data edge with input index. + * please check the item, the item from vector may be nullptr. + * if the item is nullptr, it just continue to get next, no error handing is required. 
+ */ + Edge *GetInDataEdgeByIndex(int32_t idx) const; + + bool IsDirectlyControlledByNode(FastNode const *node) const; + + /** + * collecting all output edge with output index. + * please check the item, the item from vector may be nullptr. + * if the item is nullptr, it just continue to get next, no error handing is required. + */ + std::vector *> GetOutEdgesByIndex(int32_t idx) const; + const std::vector *> &GetOutEdgesRefByIndex(int32_t idx) const; + + /** + * remove all of edge in the node. + * please define remove_edge_func to delete edge. + * The general process is as follow: + * 1. clear the edge information in src node and dst node (use EraseEdge); + * 2. remove edge in container. + * 3. add the free edge to free container. + * example: + * node[1]->RemoveAllEdge([&compute_graph](FastEdge *e) { + * auto src_node = e->src; + * auto dst_node = e->dst; + * Utils::GetNode(src_node).EraseEdge(e, DirectionType::kDirectionOutType); + * Utils::GetNode(dst_node).EraseEdge(e, DirectionType::kDirectionInType); + * if (Utils::GetListElementAddr(e)->owner != nullptr) { + * Utils::GetListElementAddr(e)->owner->erase(Utils::GetListElementAddr(e)); + * } + * auto ret = compute_graph->RecycleQuickEdge(e); + * if ((ret != GRAPH_SUCCESS) && (e != nullptr)) { + * delete e; + * } + * }); + */ + void RemoveAllEdge(std::function *)> const &remove_edge_func); + + /** + * get the extend info of node. + */ + ExtendInfo *GetExtendInfo() const; + + /** + * get the numbers input edges which peer node is not NEXTITERATION or REFNEXTITERATION. + */ + size_t GetInEdgeSize() const; + + void UpdateOpDesc(const OpDescPtr &new_opdesc); + + /** + * get peer nodes from all input data edges. + */ + std::vector GetInDataNodes() const; + + /** + * get peer nodes from out data edges with index. + */ + std::vector GetOutDataNodesByIndex(int32_t index) const; + + /** + * get peer nodes from all out data edges. 
+ */ + std::vector GetOutDataNodes() const; + + /** + * get peer nodes from all out control edges. + */ + std::vector GetOutControlNodes() const; + + /** + * get peer nodes from all in control edges. + */ + std::vector GetInControlNodes() const; + + /** + * get peer nodes from all out edges. + */ + std::vector GetAllOutNodes() const; + + /** + * get peer nodes from all in edges. + */ + std::vector GetAllInNodes() const; + + private: + graphStatus CheckAllInputParamter(DirectionType type, int32_t io_idx, int32_t cur_array_index, + int32_t replace_array_index) const; + inline bool CheckDataIndexIsValid(int32_t index, DirectionType type) const; + graphStatus Reset(); + void UpdateDataForIoNumChange(); + graphStatus RecordInControlEdge(FastEdge *const edge); + graphStatus RecordOutControlEdge(FastEdge *const edge); + graphStatus RecordInDataEdge(FastEdge *const edge, int32_t index); + graphStatus RecordOutDataEdge(FastEdge *const edge, int32_t index); + graphStatus EraseInControlEdge(const FastEdge *const edge); + graphStatus EraseOutControlEdge(const FastEdge *const edge); + graphStatus EraseInDataEdge(const FastEdge *const edge); + graphStatus EraseOutDataEdge(const FastEdge *const edge, int32_t index); + graphStatus ModifySizeByNodeType(const FastEdge *const fast_edge, size_t &in_edge_size) const; + + private: + std::string name_; + size_t node_token_ = 0UL; + OpDescPtr opdesc_ = nullptr; + std::shared_ptr self_ptr_ = nullptr; + Node *node_bare_ptr_ = nullptr; + + size_t data_in_num_ = 0UL; + size_t data_out_num_ = 0UL; + + mutable std::vector *> in_data_edges_; + std::vector *> in_control_edges_; + std::vector *> out_control_edges_; + std::vector *>> out_data_edges_; + + size_t in_data_edges_count_ = 0UL; + size_t in_control_edge_count_ = 0UL; + size_t out_control_edges_count_ = 0UL; + OutDataEdgeStatisticsInfo out_data_edges_info_; + + std::unique_ptr extend_info_ = nullptr; +}; + +} // namespace ge +#endif // D_INC_GRAPH_FAST_NODE_H diff --git 
a/inc/metadef/graph/fast_graph/list_element.h b/inc/metadef/graph/fast_graph/list_element.h new file mode 100644 index 0000000000000000000000000000000000000000..8e8a8843fd3039aa1a6e707a01955bfdb71d011c --- /dev/null +++ b/inc/metadef/graph/fast_graph/list_element.h @@ -0,0 +1,36 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef D_INC_GRAPH_LIST_NODE_H +#define D_INC_GRAPH_LIST_NODE_H + +namespace ge { +enum class ListMode { kWorkMode = 0, kFreeMode }; + +template +class QuickList; + +template +struct ListElement { + ListElement *next; + ListElement *prev; + QuickList *owner; + ListMode mode; + T data; + explicit ListElement(const T &x) : data(x), next(nullptr), prev(nullptr), owner(nullptr), mode(ListMode::kFreeMode) {} + bool operator==(const ListElement &r_ListElement) const { + return data == r_ListElement.data; + } + ListElement() : next(nullptr), prev(nullptr), owner(nullptr), mode(ListMode::kFreeMode) {} + void SetOwner(QuickList *new_owner) { + owner = new_owner; + } +}; +} // namespace ge +#endif diff --git a/inc/metadef/graph/fast_graph/repeated_iterator.h b/inc/metadef/graph/fast_graph/repeated_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..da475f63ddef6ab7600d8fcebf99e732d62652ab --- /dev/null +++ 
b/inc/metadef/graph/fast_graph/repeated_iterator.h @@ -0,0 +1,58 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_CXX_REPEATED_ITERATOR_H +#define METADEF_CXX_REPEATED_ITERATOR_H +#include +#include + +namespace ge { +template +class RepeatedIterator { + public: + using iterator_category = std::forward_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = T; + using pointer = T *; + using reference = T &; + using size_type = size_t; + + RepeatedIterator(size_type index, reference value) : index_(index), value_(value) {} + + reference operator*() const { + return value_; + } + + pointer operator->() const { + return &value_; + } + + RepeatedIterator &operator++() { + ++index_; + return *this; + } + RepeatedIterator operator++(int) { + RepeatedIterator ret = *this; + ++*this; + return ret; + } + + friend bool operator==(const RepeatedIterator &lhs, const RepeatedIterator &rhs) { + return (lhs.index_ == rhs.index_) && (&lhs.value_ == &rhs.value_); + } + friend bool operator!=(const RepeatedIterator &lhs, const RepeatedIterator &rhs) { + return !(lhs == rhs); + }; + + private: + size_type index_; + reference value_; +}; +} // namespace ge +#endif // METADEF_CXX_REPEATED_ITERATOR_H diff --git a/inc/metadef/graph/flow_graph/data_flow_attr_define.h 
b/inc/metadef/graph/flow_graph/data_flow_attr_define.h new file mode 100644 index 0000000000000000000000000000000000000000..958a4fb37d067895f5cbe4ccf7cbdbc615822f04 --- /dev/null +++ b/inc/metadef/graph/flow_graph/data_flow_attr_define.h @@ -0,0 +1,69 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_FLOW_GRAPH_DATA_FLOW_ATTR_DEFINE_H_ +#define INC_GRAPH_FLOW_GRAPH_DATA_FLOW_ATTR_DEFINE_H_ + +#include "graph/debug/ge_attr_define.h" + +namespace ge { +namespace dflow { +// Public attribute +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_DATA_FLOW_PROCESS_POINTS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_IS_DATA_FLOW_GRAPH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_DATA_FLOW_INPUT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_DATA_FLOW_OUTPUT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_DATA_FLOW_CONTAINS_N_MAPPING_NODE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_DATA_FLOW_ENABLE_EXCEPTION_CATCH; + +// For data align +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_DATA_FLOW_INPUTS_ALIGN_TIMEOUT; 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_DATA_FLOW_INPUTS_ALIGN_MAX_CACHE_NUM; +// whether dropout data when no align +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_DATA_FLOW_INPUTS_ALIGN_DROPOUT; + +// For count batch +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_COUNT_BATCH_BATCH_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_COUNT_BATCH_SLIDE_STRIDE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_COUNT_BATCH_TIMEOUT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_COUNT_BATCH_BATCH_DIM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_COUNT_BATCH_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_COUNT_BATCH_PADDING; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_COUNT_BATCH_DROP_REMAINDER; + +// For time batch +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_TIME_BATCH_TIME_WINDOW; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_TIME_BATCH_TIME_INTERVAL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_TIME_BATCH_TIMEOUT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_TIME_BATCH_BATCH_DIM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_TIME_BATCH_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_TIME_BATCH_PADDING; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_TIME_BATCH_DROP_REMAINDER; + +// FlowFunc +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_FLOW_FUNC_BIN_PATH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const 
ATTR_NAME_FLOW_FUNC_FUNC_LIST; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_FLOW_FUNC_FUNC_INPUTS_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_FLOW_FUNC_FUNC_OUTPUTS_INDEX; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_FLOW_FUNC_FUNC_STREAM_INPUT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_FLOW_FUNC_INVOKE_KEYS; + +// for balance +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_BALANCE_SCATTER; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_BALANCE_GATHER; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char *const ATTR_NAME_DYNAMIC_BALANCED_DISTRIBUTION_HCOM; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char + *const ATTR_NAME_DYNAMIC_BALANCED_DISTRIBUTION_HCOM_GROUP; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char + *const ATTR_NAME_DYNAMIC_BALANCED_DISTRIBUTION_HCOM_TAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const char + *const ATTR_NAME_DYNAMIC_BALANCED_DISTRIBUTION_HCOM_IS_RECV; +} // namespace dflow +} // namespace ge +#endif // INC_GRAPH_FLOW_GRAPH_DATA_FLOW_ATTR_DEFINE_H_ diff --git a/inc/metadef/graph/ge_attr_value.h b/inc/metadef/graph/ge_attr_value.h new file mode 100644 index 0000000000000000000000000000000000000000..8cf567ca4d0d95e6fd5155336c37df81045c05ac --- /dev/null +++ b/inc/metadef/graph/ge_attr_value.h @@ -0,0 +1,99 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_GE_ATTR_VALUE_H_ +#define INC_GRAPH_GE_ATTR_VALUE_H_ + +#include +#include +#include +#include +#include +#include +#include "graph/buffer.h" +#include "detail/attributes_holder.h" +#include "graph/ge_error_codes.h" +#include "graph/ge_tensor.h" +#include "graph/any_value.h" + +using std::string; +using std::vector; +using std::map; + +namespace ge { +class GeTensor; + +using ConstGeTensorPtr = std::shared_ptr; + +class ComputeGraph; +using ConstComputeGraphPtr = std::shared_ptr; + +class GeTensorDesc; + +template +bool SetAttrValue(AttrStore &attrs, const std::string &name, T &&value) { + return attrs.SetByName(name, std::forward(value)); +} + +template +bool GetAttrValue(const AttrStore &attrs, const std::string &name, T &value) { + const auto p = attrs.GetByName(name); + if (p == nullptr) { + return false; + } + value = *p; + return true; +} + +template +const T *GetAttrValue(const AttrStore &attrs, const std::string &name) { + return attrs.GetByName(name); +} + +template::type> +RT *SetAndGetAttrValue(AttrStore &attrs, const std::string &name, T &&value) { + if (!attrs.SetByName(name, std::forward(value))) { + return nullptr; + } + return attrs.MutableGetByName(name); +} + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NamedAttrs : public AttrHolder { + public: + NamedAttrs() = default; + virtual ~NamedAttrs() = default; + void SetName(const std::string &name); + std::string GetName() const; + AnyValue GetItem(const std::string &key) const; + + protected: + ProtoAttrMap &MutableAttrMap() override; + 
ConstProtoAttrMap &GetAttrMap() const override; + + private: + AttrStore attrs_; + std::string name_; + + friend class GeAttrValueImp; +}; + +class AttrValueImpl { + public: + AttrValueImpl() = default; + ~AttrValueImpl() = default; + + friend class AttrValue; + friend class AttrHolder; + friend class Operator; + +private: + AnyValue geAttrValue_; +}; +} // namespace ge +#endif // INC_GRAPH_GE_ATTR_VALUE_H_ diff --git a/inc/metadef/graph/ge_context.h b/inc/metadef/graph/ge_context.h new file mode 100644 index 0000000000000000000000000000000000000000..47d1fc406d7fa21fc04ba5f9f23ef5ae1c4d8d8d --- /dev/null +++ b/inc/metadef/graph/ge_context.h @@ -0,0 +1,57 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_GRAPH_GE_CONTEXT_H_ +#define INC_GRAPH_GE_CONTEXT_H_ + +#include +#include +#include +#include "graph/ge_error_codes.h" +#include "graph/option/optimization_option.h" + +namespace ge { +class GEContext { + public: + graphStatus GetOption(const std::string &key, std::string &option); + const std::string &GetReadableName(const std::string &key); + bool GetHostExecFlag() const; + bool GetTrainGraphFlag() const; + bool IsOverflowDetectionOpen() const; + bool IsGraphLevelSat() const; + uint64_t GetInputFusionSize() const; + uint64_t SessionId() const; + uint32_t DeviceId() const; + int32_t StreamSyncTimeout() const; + int32_t EventSyncTimeout() const; + void Init(); + void SetSessionId(const uint64_t session_id); + void SetContextId(const uint64_t context_id); + void SetCtxDeviceId(const uint32_t device_id); + void SetStreamSyncTimeout(const int32_t timeout); + void SetEventSyncTimeout(const int32_t timeout); + graphStatus SetOptionNameMap(const std::string &option_name_map_json); + void SetMultiBatchShapeIndex(uint32_t graph_id, + const std::map> &data_index_and_shape_map); + const std::map> GetMultiBatchShapeIndex(uint32_t graph_id); + OptimizationOption &GetOo() const; + + private: + thread_local static uint64_t session_id_; + thread_local static uint64_t context_id_; + uint32_t device_id_ = 0U; + // GEContext不允许拓展新的成员变量 +}; // class GEContext + +/// Get context +/// @return +GEContext &GetContext(); +static_assert(sizeof(GEContext) == 4U, "Do not add member to a thread-safe global variable"); +} // namespace ge +#endif // INC_GRAPH_GE_CONTEXT_H_ diff --git a/inc/metadef/graph/ge_global_options.h b/inc/metadef/graph/ge_global_options.h new file mode 100644 index 0000000000000000000000000000000000000000..520ffb4182d48559a2e0a6f6f42c4a21024a69cf --- /dev/null +++ b/inc/metadef/graph/ge_global_options.h @@ -0,0 +1,21 @@ +/* Copyright 
(c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_GE_GLOBAL_OPTIONS_H_ +#define INC_GRAPH_GE_GLOBAL_OPTIONS_H_ + +#include +#include +#include + +namespace ge { +std::mutex &GetGlobalOptionsMutex(); +std::map &GetMutableGlobalOptions(); +} +#endif // INC_GRAPH_GE_GLOBAL_OPTIONS_H_ diff --git a/inc/metadef/graph/ge_local_context.h b/inc/metadef/graph/ge_local_context.h new file mode 100644 index 0000000000000000000000000000000000000000..a51a9f8de24d8337d238288aef1b28917386065a --- /dev/null +++ b/inc/metadef/graph/ge_local_context.h @@ -0,0 +1,51 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_GRAPH_GE_LOCAL_CONTEXT_H_ +#define INC_GRAPH_GE_LOCAL_CONTEXT_H_ + +#include +#include +#include "graph/ge_error_codes.h" +#include "graph/option/optimization_option.h" + +namespace ge { +class GEThreadLocalContext { + public: + graphStatus GetOption(const std::string &key, std::string &option); + void SetGraphOption(std::map options_map); + void SetSessionOption(std::map options_map); + void SetGlobalOption(std::map options_map); + graphStatus SetOptionNameMap(const std::string &option_name_map_json); + const std::string &GetReadableName(const std::string &key); + + void SetStreamSyncTimeout(const int32_t timeout); + void SetEventSyncTimeout(const int32_t timeout); + int32_t StreamSyncTimeout() const; + int32_t EventSyncTimeout() const; + OptimizationOption &GetOo(); + + std::map GetAllGraphOptions() const; + std::map GetAllSessionOptions() const; + std::map GetAllGlobalOptions() const; + std::map GetAllOptions() const; + + private: + std::map graph_options_; + std::map session_options_; + std::map global_options_; + std::map option_name_map_; + int32_t stream_sync_timeout_ = -1; + int32_t event_sync_timeout_ = -1; + OptimizationOption optimization_option_; +}; // class GEThreadLocalContext + +GEThreadLocalContext &GetThreadLocalContext(); +} // namespace ge +#endif // INC_GRAPH_GE_LOCAL_CONTEXT_H_ diff --git a/inc/metadef/graph/ge_tensor.h b/inc/metadef/graph/ge_tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..7ec94bea4d2b77d300384a5b110cb8df8c4bab87 --- /dev/null +++ b/inc/metadef/graph/ge_tensor.h @@ -0,0 +1,352 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. 
You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_GE_TENSOR_H_ +#define INC_GRAPH_GE_TENSOR_H_ + +#include +#include +#include +#include +#include "detail/attributes_holder.h" +#include "graph/buffer.h" +#include "graph/aligned_ptr.h" +#include "graph/ge_error_codes.h" +#include "graph/types.h" +#include "graph/small_vector.h" +#include "graph/ascend_limits.h" + +namespace ge { +class GeShapeImpl; +using GeShapeImplPtr = std::shared_ptr; + +class TensorDataImpl; +using TensorDataImplPtr = std::shared_ptr; + +class GeTensorDescImpl; +using GeTensorDescImplPtr = std::shared_ptr; + +class GeTensorImpl; +using GeTensorImplPtr = std::shared_ptr; + +class GeTensorSerializeUtils; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape { + public: + GeShape(); + ~GeShape(); + explicit GeShape(std::vector s); + + /* + * `GetDimNum()` reports the number of valid dims and is not equivalent to `GetDims().size()`; callers pick as needed. + * For example, when the dims are [-2] (the dim count is unknown): + * GetDimNum() returns 0; + * while GetDims().size() returns the number of dims, i.e. 1. + * To check whether the shape is a scalar, prefer the `IsScalar` interface. + */ + size_t GetDimNum() const; + void SetDimNum(const size_t dim_num); + void AppendDim(const int64_t dim_size); + bool IsUnknownDimNum() const; + void SetIsUnknownDimNum(); + // If the idx is invalid, return 0 + int64_t GetDim(const size_t idx) const; + graphStatus SetDim(const size_t idx, const int64_t value); + /* + * `GetDims` reports the number of dims and is not equivalent to `GetDimNum()`; callers pick as needed. + * For example, when the dims are [-2] (the dim count is unknown): + * GetDimNum() returns 0; + * while GetDims().size() returns the number of dims, i.e. 1. + * To check whether the shape is a scalar, prefer the `IsScalar` interface. + */ + std::vector GetDims() const; + const
SmallVector &GetMutableDims() const; + + int64_t GetShapeSize() const; + std::string ToString() const; + +/** + * Decides from the dim value of every dimension of the tensor's shape whether the tensor has an unknown shape. + * @return + * true if the dim value of any dimension is less than 0, meaning an unknown shape; + * false if the dim values of all dimensions are greater than or equal to 0, meaning a known shape. + */ + bool IsUnknownShape() const; + +/** + * Reports from the number of dim values in the tensor's shape whether the tensor is a scalar. + * @return + * true if the shape has 0 dimensions, meaning a scalar; + * false otherwise, meaning not a scalar. + */ + bool IsScalar() const; + +/** + * Decides from whether any dim value of the tensor's shape is 0 whether the tensor is an empty tensor. + * @return + * true if the dim value of any dimension is 0, meaning an empty tensor; + * false otherwise, meaning a non-empty tensor. + */ + bool IsEmptyTensor() const; + + GeShape(const GeShape &other); + GeShape(GeShape &&other); + GeShape &operator=(const GeShape &other); + GeShape &operator=(GeShape &&other); + bool operator==(const GeShape &other) const; + + private: + GeShapeImplPtr impl_; + friend class GeTensorDesc; + friend class GeTensorDescImpl; + friend class GeTensorSerializeUtils; + friend class ModelSerialize; + // Create from proto obj + GeShape(const ProtoMsgOwner &proto_owner, proto::ShapeDef *const proto_msg); +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrHolder { + friend class TensorUtils; + friend class ModelSerialize; + + public: + GeTensorDesc(); + explicit GeTensorDesc(const GeShape &shape, const Format format = FORMAT_ND, const DataType dt = DT_FLOAT); + GeTensorDesc(const GeTensorDesc &desc); + GeTensorDesc(GeTensorDesc &&desc); + + ~GeTensorDesc() override; + bool operator==(const GeTensorDesc &r_ge_tensor_desc) const; + + void Update(const GeShape &shape, const Format format = FORMAT_ND, const DataType dt = DT_FLOAT); + + const GeShape &GetShape() const; + GeShape &MutableShape(); + void SetShape(const GeShape &shape); + void SetShape(GeShape &&shape); + + // set shape with -2, it stand for unknown shape + void SetUnknownDimNumShape(); + // for unknown shape + graphStatus SetValueRange(const std::vector> &range); + graphStatus GetValueRange(std::vector> &range)
const; + graphStatus SetShapeRange(const std::vector> &range); + graphStatus SetOriginShapeRange(const std::vector> &range); + graphStatus GetShapeRange(std::vector> &range) const; + graphStatus GetOriginShapeRange(std::vector> &range) const; + + const GeShape &GetOriginShape() const; + GeShape &MutableOriginShape() const; + + void SetOriginShape(const GeShape &origin_shape); + bool IsOriginShapeInitialized() const; + + Format GetFormat() const; + void SetFormat(const Format format); + + Format GetOriginFormat() const; + void SetOriginFormat(const Format origin_format); + + const std::string GetExpandDimsRule() const; + void SetExpandDimsRule(const std::string &expand_dims_rule); + + void SetName(const std::string &name); + const std::string GetName() const; + + DataType GetDataType() const; + void SetDataType(const DataType data_type); + + DataType GetOriginDataType() const; + void SetOriginDataType(const DataType origin_data_type); + + std::vector GetRefPortIndex() const; + void SetRefPortByIndex(const std::vector &index); + + Placement GetPlacement() const; + void SetPlacement(const Placement placement); + + GeTensorDesc Clone() const; + GeTensorDesc &operator=(const GeTensorDesc &desc); + GeTensorDesc &operator=(GeTensorDesc &&desc); + + graphStatus IsValid() const; + + // Create from proto obj + explicit GeTensorDesc(proto::TensorDescriptor *const proto_msg); + using AttrHolder::DelAttr; + using AttrHolder::GetAllAttrs; + using AttrHolder::GetAttr; + using AttrHolder::HasAttr; + using AttrHolder::SetAttr; + + template + T *GetOrCreateAttrsGroup() { + return MutableAttrMap().GetOrCreateAttrsGroup(); + } + + protected: + ProtoAttrMap &MutableAttrMap() override; + ConstProtoAttrMap &GetAttrMap() const override; + + private: + bool GeTensorDescAttrsAreEqual(const GeTensorDesc &r_ge_tensor_desc) const; + + friend class GeTensor; + friend class GeTensorImpl; + friend class GeAttrValueImp; + friend class ModelSerializeImp; + friend class GeTensorSerializeUtils; + 
friend class OnnxUtils; + + GeTensorDescImplPtr impl_; + + GeShape &ShapeReference() const; + GeShape &OriginShapeReference() const; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorData { + public: + TensorData(); + ~TensorData(); + + graphStatus SetData(std::vector &&data); + graphStatus SetData(const std::vector &data); + graphStatus SetData(const Buffer &data); + graphStatus SetData(const TensorData &data); + graphStatus SetData(const uint8_t *const data, const size_t size); + graphStatus SetData(uint8_t *const data, const size_t size, const AlignedPtr::Deleter &delete_fuc); /*lint !e148*/ + + graphStatus ResetData(uint8_t *const data, const size_t size, const AlignedPtr::Deleter &delete_fuc); + + const uint8_t *MallocAlignedPtr(const size_t size); + + const std::uint8_t *data() const; + std::uint8_t *data(); + std::size_t size() const; + void clear(); + uint8_t operator[](const size_t index) const; + + std::size_t GetSize() const; + const std::uint8_t *GetData() const; + std::uint8_t *GetData(); + bool IsTensorDataValid() const; + + const std::shared_ptr &GetAlignedPtr(); + + // share data, share tensor_descriptor/aligned_ptr + // replace using TensorUtils::ShareTensorData(const TensorData &from, TensorData &to) + TensorData &operator=(const TensorData &other); + // share data share tensor_descriptor/aligned_ptr + // replace using TensorUtils::CreateShareTensorData(const TensorData &other) + TensorData(const TensorData &other); + // zero copy SetData + // replace using TensorUtils::ShareAlignedPtr(std::shared_ptr ptr, size_t size, TensorData &to) + void SetData(std::shared_ptr aligned_ptr, const size_t size); + private: + friend class GeTensor; + friend class GeTensorImpl; + friend class GeAttrValueImp; + friend class ModelSerializeImp; + friend class GeTensorSerializeUtils; + friend class TensorUtils; + TensorDataImplPtr impl_; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensor { + public: + GeTensor(); + GeTensor(GeTensor 
&&other) noexcept; + explicit GeTensor(const GeTensorDesc &tensor_desc); + explicit GeTensor(const GeTensorDesc &tensor_desc, const std::vector &data); + explicit GeTensor(const GeTensorDesc &tensor_desc, const Buffer &data); + explicit GeTensor(const GeTensorDesc &tensor_desc, const uint8_t *const data, const size_t size); + explicit GeTensor(GeTensorDesc &&tensor_desc, std::vector &&data); + explicit GeTensor(const GeTensorDesc &tensor_desc, const size_t size); + ~GeTensor(); + + const GeTensorDesc &GetTensorDesc() const; + GeTensorDesc &MutableTensorDesc(); + void SetTensorDesc(const GeTensorDesc &tensor_desc); + + std::shared_ptr GetAlignedPtr(); + + const TensorData &GetData() const; + TensorData &MutableData(); + + bool IsTensorDataValid() const; + + graphStatus SetData(std::vector &&data); + graphStatus SetData(const std::vector &data); + graphStatus SetData(const Buffer &data); + graphStatus SetData(const uint8_t *const data, const size_t size); + graphStatus SetData(const TensorData &data); + graphStatus SetData(uint8_t *const data, const size_t size, const AlignedPtr::Deleter &delete_fuc); + + graphStatus ResetData(uint8_t *const data, const size_t size, const AlignedPtr::Deleter &delete_fuc); + + void ClearData(); + GeTensor Clone() const; + + // zero copy SetData + // replace using TensorUtils::ShareAlignedPtr + void SetData(std::shared_ptr aligned_ptr, const size_t size); + // zero copy construction, share aligned_ptr, do not share tensor_desc + // replace using TensorUtils::CreateShareTensor + GeTensor(const GeTensorDesc &tensor_desc, std::shared_ptr aligned_ptr, const size_t size); + // Share tensor_data, tensor_desc + // replace using TensorUtils::CreateShareTensor + GeTensor(const GeTensor &other); + // Share tensor_data, tensor_desc + // replace using TensorUtils::ShareTensor + GeTensor &operator=(const GeTensor &other); + GeTensor &operator=(GeTensor &&other); + + private: + friend class GeAttrValueImp; + friend class ModelSerializeImp; + friend 
class GeTensorSerializeUtils; + friend class OnnxUtils; + friend class TensorData; + friend class TensorUtils; + friend class TensorAdapter; + // Create from proto obj + GeTensor(const ProtoMsgOwner &proto_owner, proto::TensorDef *proto_msg); + explicit GeTensor(GeTensorImplPtr impl); + void BuildAlignerPtrWithProtoData(); + GeTensorImplPtr impl_; + GeTensorDesc &DescReference() const; +}; + +class GeTensorSerializeUtils { + public: + static void GeShapeAsProto(const GeShape &shape, proto::ShapeDef *proto); + static void GeTensorDescAsProto(const GeTensorDescImpl &desc, proto::TensorDescriptor *proto); + static void GeTensorDescAsProto(const GeTensorDesc &desc, proto::TensorDescriptor *proto); + static void GeTensorAsProto(const GeTensorImpl &tensor, proto::TensorDef *proto); + static void GeTensorAsProto(const GeTensor &tensor, proto::TensorDef *proto); + + static void AssembleGeShapeFromProto(const proto::ShapeDef *proto, GeShape &shape); + static void AssembleGeTensorDescFromProto(const proto::TensorDescriptor *const proto, GeTensorDesc &desc); + static void AssembleGeTensorFromProto(const proto::TensorDef *proto, GeTensor &tensor); + + // normalize the input TensorDescriptor,A metadata information maybe stored in different fields of TensorDescriptor, + // This function needs to prioritize and determine the final metadata information used. 
+ // After standardization, the direct member field on TensorDescriptor is always valid + static void NormalizeGeTensorDescProto(proto::TensorDescriptor *proto); + static void GetShapeFromDescProto(const proto::TensorDescriptor *const proto, GeShape &shape); + static void GetOriginShapeFromDescProto(const proto::TensorDescriptor *const proto, GeShape &shape); + static void GetDtypeFromDescProto(const proto::TensorDescriptor *const proto, DataType &dtype); + static void GetOriginDtypeFromDescProto(const proto::TensorDescriptor *const proto, DataType &dtype); + static void GetFormatFromDescProto(const proto::TensorDescriptor *const proto, Format &format); + static void GetOriginFormatFromDescProto(const proto::TensorDescriptor *const proto, Format &format); +}; +using GeTensorDescPtr = std::shared_ptr; +using ConstGeTensorDescPtr = std::shared_ptr; +} // namespace ge +#endif // INC_GRAPH_GE_TENSOR_H_ diff --git a/inc/metadef/graph/graph_util.h b/inc/metadef/graph/graph_util.h new file mode 100644 index 0000000000000000000000000000000000000000..42e5de4bf4c96a67e4f0906edbc6603071a2c23d --- /dev/null +++ b/inc/metadef/graph/graph_util.h @@ -0,0 +1,127 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_GRAPH_GRAPH_UTIL_H_ +#define INC_GRAPH_GRAPH_UTIL_H_ + +#include + +namespace ge { +using AttrDefMap = ::google::protobuf::Map<::std::string, ::domi::AttrDef>; +bool HasOpAttr(const OpDef *opdef, std::string attr_name); +bool GetOpAttr(const std::string &key, int32_t *value, const OpDef *opdef); + +static const char OP_TYPE_DATA[] = "Data"; +static const char OP_TYPE_INPUT[] = "Input"; +static const char ATTR_KEY_INPUT_FORMAT[] = "input_format"; +static const char ATTR_KEY_OUTPUT_FORMAT[] = "output_format"; +static const char OP_TYPE_ANN_DATA[] = "AnnData"; +} // namespace ge + +#if !defined(__ANDROID__) && !defined(ANDROID) +#include "toolchain/slog.h" +const char levelStr[4][8] = {"ERROR", "WARN", "INFO", "DEBUG"}; +#else +#include +#include +const char levelStr[8][8] = {"EMERG", "ALERT", "CRIT", "ERROR", "WARNING", "NOTICE", "INFO", "DEBUG"}; +#endif + +#ifdef _MSC_VER +#define FUNC_NAME __FUNCTION__ +#else +#define FUNC_NAME __PRETTY_FUNCTION__ +#endif + +#if !defined(__ANDROID__) && !defined(ANDROID) +#define D_GRAPH_LOGI(MOD_NAME, fmt, ...) \ + dlog_info(FMK, "%s:%s:%d:" #fmt, __FUNCTION__, __FILE__, __LINE__, ##__VA_ARGS__) +#define D_GRAPH_LOGW(MOD_NAME, fmt, ...) \ + dlog_warn(FMK, "%s:%s:%d:" #fmt, __FUNCTION__, __FILE__, __LINE__, ##__VA_ARGS__) +#define D_GRAPH_LOGE(MOD_NAME, fmt, ...) \ + dlog_error(FMK, "%s:%s:%d:" #fmt, __FUNCTION__, __FILE__, __LINE__, ##__VA_ARGS__) +#else +#define D_GRAPH_LOG(level, format, ...) \ + do { \ + { \ + fprintf(stdout, "[%s] [%s] [%s] [%s] [%s:%d] " format "\n", "", "GRAPH", levelStr[level], __FUNCTION__, \ + __FILE__, __LINE__, ##__VA_ARGS__); \ + syslog(level, "%s %s:%d] [%s] %s " format "\n", "", __FILE__, __LINE__, "OPTIMIZER", __FUNCTION__, \ + ##__VA_ARGS__); \ + } \ + } while (0) +#define D_GRAPH_LOGI(MOD_NAME, fmt, ...) 
D_GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) +#define D_GRAPH_LOGW(MOD_NAME, fmt, ...) D_GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) +#define D_GRAPH_LOGE(MOD_NAME, fmt, ...) D_GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) +#endif + +#if !defined(__ANDROID__) && !defined(ANDROID) +#define GRAPH_LOGI(...) D_GRAPH_LOGI(GRAPH_MOD_NAME, __VA_ARGS__) +#define GRAPH_LOGW(...) D_GRAPH_LOGW(GRAPH_MOD_NAME, __VA_ARGS__) +#define GRAPH_LOGE(...) D_GRAPH_LOGE(GRAPH_MOD_NAME, __VA_ARGS__) +#else + +#define GRAPH_LOG(level, format, ...) \ + do { \ + { \ + fprintf(stdout, "[%s] [%s] [%s] [%s] [%s:%d] " format "\n", "", "GRAPH", levelStr[level], __FUNCTION__, \ + __FILE__, __LINE__, ##__VA_ARGS__); \ + syslog(level, "%s %s:%d] [%s] %s " format "\n", "", __FILE__, __LINE__, "OPTIMIZER", __FUNCTION__, \ + ##__VA_ARGS__); \ + } \ + } while (0) +#define GRAPH_LOGI(fmt, ...) GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) +#define GRAPH_LOGW(fmt, ...) GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) +#define GRAPH_LOGE(fmt, ...) GRAPH_LOG(ANDROID_LOG_INFO, #fmt, ##__VA_ARGS__) +#endif + +#define GRAPH_CHK_STATUS_RET_NOLOG(expr) \ + do { \ + const domi::graphStatus _status = (expr); \ + if (_status != domi::GRAPH_SUCCESS) { \ + return _status; \ + } \ + } while (0) + +#define GRAPH_CHK_BOOL_RET_STATUS(expr, _status, ...) \ + do { \ + bool b = (expr); \ + if (!b) { \ + GRAPH_LOGE(__VA_ARGS__); \ + return _status; \ + } \ + } while (0) + +// Do not add do...while(0), otherwise it will introduce security issues +#define GRAPH_CHK_BOOL_EXEC_NOLOG(expr, exec_expr) \ + { \ + bool b = (expr); \ + if (!b) { \ + exec_expr; \ + } \ + } + +// Do not add do...while(0), otherwise it will introduce security issues +#define GRAPH_IF_BOOL_EXEC(expr, exec_expr) \ + { \ + if (expr) { \ + exec_expr; \ + } \ + } + +#define GRAPH_RETURN_WITH_LOG_IF_ERROR(expr, ...)
\ + do { \ + const ::domi::graphStatus _status = (expr); \ + if (_status) { \ + GRAPH_LOGE(__VA_ARGS__); \ + return _status; \ + } \ + } while (0) + +#endif // INC_GRAPH_GRAPH_UTIL_H_ diff --git a/inc/metadef/graph/host_resource/host_resource.h b/inc/metadef/graph/host_resource/host_resource.h new file mode 100644 index 0000000000000000000000000000000000000000..69b69380780d557cd44e784dff05a91c79570d68 --- /dev/null +++ b/inc/metadef/graph/host_resource/host_resource.h @@ -0,0 +1,18 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_CXX_HOST_RESOURCE_H +#define INC_GRAPH_CXX_HOST_RESOURCE_H +namespace ge { +class HostResource { + public: + virtual ~HostResource() = default; +}; +} // namespace ge +#endif // INC_GRAPH_CXX_HOST_RESOURCE_H diff --git a/inc/metadef/graph/ir_definitions_recover.h b/inc/metadef/graph/ir_definitions_recover.h new file mode 100644 index 0000000000000000000000000000000000000000..d7ddf29a7f9be38809a29435592ec93f9d12740a --- /dev/null +++ b/inc/metadef/graph/ir_definitions_recover.h @@ -0,0 +1,20 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. 
You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef GRAPH_IR_DEFINITIONS_RECOVER_H_ +#define GRAPH_IR_DEFINITIONS_RECOVER_H_ + +#include +#include "graph/compute_graph.h" + +namespace ge { +ge::graphStatus RecoverIrDefinitions(const ge::ComputeGraphPtr &graph, const vector &attr_names = {}); +ge::graphStatus RecoverOpDescIrDefinition(const ge::OpDescPtr &desc, const std::string &op_type = ""); +} // namespace ge +#endif // GRAPH_IR_DEFINITIONS_RECOVER_H_ diff --git a/inc/metadef/graph/model.h b/inc/metadef/graph/model.h new file mode 100644 index 0000000000000000000000000000000000000000..c72a0636b12ac958cfc6b6067ef893beb816bc7c --- /dev/null +++ b/inc/metadef/graph/model.h @@ -0,0 +1,97 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_GRAPH_MODEL_H_ +#define INC_GRAPH_MODEL_H_ + +#include +#include +#include "detail/attributes_holder.h" +#include "graph/ge_attr_value.h" +#include "graph/compute_graph.h" + +namespace ge { +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Model : public AttrHolder { + public: + Model(); + + ~Model() override = default; + + Model(const std::string &name, const std::string &custom_version); + + Model(const char_t *name, const char_t *custom_version); + + std::string GetName() const; + + void SetName(const std::string &name); + + uint32_t GetVersion() const; + + void SetVersion(const uint32_t version) { version_ = version; } + + std::string GetPlatformVersion() const; + + void SetPlatformVersion(const std::string version) { platform_version_ = version; } + + const ComputeGraphPtr GetGraph() const; + + void SetGraph(const ComputeGraphPtr &graph); + + void SetAttr(const ProtoAttrMap &attrs); + + using AttrHolder::GetAllAttrNames; + using AttrHolder::GetAllAttrs; + using AttrHolder::GetAttr; + using AttrHolder::HasAttr; + using AttrHolder::SetAttr; + + graphStatus Save(Buffer &buffer, const bool is_dump = false) const; + graphStatus Save(proto::ModelDef &model_def, const bool is_dump = false) const; + graphStatus SaveWithoutSeparate(Buffer &buffer, const bool is_dump = false) const; + graphStatus SaveToFile(const std::string &file_name, const bool force_separate = false) const; + // Model will be rewritten + static graphStatus Load(const uint8_t *data, size_t len, Model &model); + /** + * Multi-threaded model loading interface; deserializes the content of data into the model object. + * When the model graph has multiple subgraphs, this interface can load the subgraphs in parallel on multiple threads for speed; the thread upper limit is 16. + * @param data pointer to the serialized model content + * @param len length of the serialized model content + * @param model object that holds the model after loading + * @return GRAPH_SUCCESS on success, GRAPH_FAILED on failure + */ + static graphStatus LoadWithMultiThread(const uint8_t *data, size_t len, Model &model); + graphStatus Load(ge::proto::ModelDef &model_def);
+ graphStatus LoadFromFile(const std::string &file_name); + + bool IsValid() const; + + protected: + ConstProtoAttrMap &GetAttrMap() const override; + ProtoAttrMap &MutableAttrMap() override; + + private: + void Init(); + graphStatus Load(ge::proto::ModelDef &model_def, const std::string &path); + graphStatus Save(Buffer &buffer, const std::string &path, const bool is_dump = false) const; + graphStatus SaveSeparateModel(Buffer &buffer, const std::string &path, const bool is_dump = false) const; + AttrStore attrs_; + friend class ModelSerializeImp; + friend class GraphDebugImp; + friend class OnnxUtils; + friend class ModelHelper; + friend class ModelBuilder; + std::string name_; + uint32_t version_; + std::string platform_version_{""}; + ComputeGraphPtr graph_; +}; +using ModelPtr = std::shared_ptr; +} // namespace ge + +#endif // INC_GRAPH_MODEL_H_ diff --git a/inc/metadef/graph/model_serialize.h b/inc/metadef/graph/model_serialize.h new file mode 100644 index 0000000000000000000000000000000000000000..3e4644b3c309d3cfdcb9d14fb64820a5ef6d7636 --- /dev/null +++ b/inc/metadef/graph/model_serialize.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_GRAPH_MODEL_SERIALIZE_H_ +#define INC_GRAPH_MODEL_SERIALIZE_H_ + +#include +#include +#include "graph/buffer.h" +#include "graph/compute_graph.h" +#include "graph/model.h" +#include "external/ge_common/ge_api_types.h" + +namespace ge { +class ModelSerialize { + public: + Buffer SerializeModel(const Model &model, const bool not_dump_all = false) const; + Buffer SerializeSeparateModel(const Model &model, const std::string &path, const bool not_dump_all = false) const; + Buffer SerializeModel(const Model &model, const std::string &path, + const bool is_need_separate, const bool not_dump_all = false) const; + Status SerializeModel(const Model &model, const bool not_dump_all, proto::ModelDef &model_def) const; + + bool UnserializeModel(const uint8_t *const data, const size_t len, + Model &model, const bool is_enable_multi_thread = false) const; + bool UnserializeModel(ge::proto::ModelDef &model_def, Model &model, const std::string &path) const; + bool UnserializeModel(ge::proto::ModelDef &model_def, Model &model) const; + private: + friend class ModelSerializeImp; + friend class GraphDebugImp; +}; +} // namespace ge +#endif // INC_GRAPH_MODEL_SERIALIZE_H_ diff --git a/inc/metadef/graph/node.h b/inc/metadef/graph/node.h new file mode 100644 index 0000000000000000000000000000000000000000..a596258c3afe41b7b41ff565d76b301d62484a57 --- /dev/null +++ b/inc/metadef/graph/node.h @@ -0,0 +1,246 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_NODE_H_ +#define INC_GRAPH_NODE_H_ + +#include +#include +#include +#include +#include +#include + +#include "utils/attr_utils.h" +#include "graph/range_vistor.h" + +namespace ge { +class ComputeGraph; + +class Anchor; + +using AnchorPtr = std::shared_ptr; + +class DataAnchor; + +using DataAnchorPtr = std::shared_ptr; + +class InDataAnchor; + +using InDataAnchorPtr = std::shared_ptr; + +class OutDataAnchor; + +using OutDataAnchorPtr = std::shared_ptr; + +class ControlAnchor; + +using ControlAnchorPtr = std::shared_ptr; + +class InControlAnchor; + +using InControlAnchorPtr = std::shared_ptr; + +class OutControlAnchor; + +using OutControlAnchorPtr = std::shared_ptr; + +using kFusionDataFlowVec_t = std::vector>; + +// Node is a component of ComputeGraph +class Node : public std::enable_shared_from_this { + friend class ComputeGraph; + friend class ComputeGraphImpl; + friend class ModelSerializeImp; + + public: + using ConstNode = const Node; + using NodePtr = std::shared_ptr; + using ConstNodePtr = std::shared_ptr; + + class NodeImpl; + using NodeImplPtr = std::shared_ptr; + template + using Vistor = RangeVistor>; + virtual ~Node(); + Node(const Node &) = delete; + Node &operator=(const Node &) = delete; + bool operator==(const Node &r_node) const; + + graphStatus Init(); + + std::string GetName() const; + const char *GetNamePtr() const; + std::string GetType() const; + const char *GetTypePtr() const; + + ComputeGraphPtr GetOwnerComputeGraph() const; + ComputeGraph *GetOwnerComputeGraphBarePtr() const; + graphStatus 
SetOwnerComputeGraph(const ComputeGraphPtr &graph); + graphStatus ClearOwnerGraph(const ComputeGraphPtr &graph); + + /** + * 获取Node的输入锚点的智能指针对象,存在复杂对象的拷贝性能较差,适用于边遍历边修改场景 + * @return + */ + Vistor GetAllInDataAnchors() const; + + /** + * 获取Node的输入锚点的裸指针,性能优于`Vistor GetAllInDataAnchors()`, 适用于只读场景 + * @return + */ + std::vector GetAllInDataAnchorsPtr() const; + + /** + * 获取Node的输出锚点的智能指针对象,存在复杂对象的拷贝所以性能较差,适用于边遍历边修改场景 + * @return + */ + Vistor GetAllOutDataAnchors() const; + + /** + * 获取Node的输出锚点的裸指针,性能优于`Vistor GetAllOutDataAnchors()`, 适用于只读场景 + * @return + */ + std::vector GetAllOutDataAnchorsPtr() const; + uint32_t GetAllInDataAnchorsSize() const; + uint32_t GetAllOutDataAnchorsSize() const; + Vistor GetAllOutAnchors() const; + + // there must be a control anchor + std::vector GetAllOutAnchorsPtr() const; + + Vistor GetAllInAnchors() const; + + // there must be a control anchor + std::vector GetAllInAnchorsPtr() const; + + InDataAnchorPtr GetInDataAnchor(const int32_t idx) const; + OutDataAnchorPtr GetOutDataAnchor(const int32_t idx) const; + InControlAnchorPtr GetInControlAnchor() const; + OutControlAnchorPtr GetOutControlAnchor() const; + /** + * 获取Node的输入节点的智能指针对象,存在拷贝所以性能较差,适用于边遍历边修改场景 + * @return + */ + Vistor GetInNodes() const; + /** + * 获取Node的输入节点的裸指针,性能优于`Vistor GetInNodes()`, 适用于只读场景 + * @return + */ + std::vector GetInNodesPtr() const; + + /** + * 获取Node的输出节点的智能指针对象,存在拷贝所以性能较差,适用于边遍历边修改场景 + * @return + */ + Vistor GetOutNodes() const; + + /** + * 获取Node的输出节点的裸指针,性能优于`Vistor GetOutNodes()`, 适用于只读场景 + * @return + */ + std::vector GetOutNodesPtr() const; + + AnchorPtr GetInAnchor(const int32_t idx) const; + AnchorPtr GetOutAnchor(const int32_t idx) const; + + bool IsAllInNodesSeen(const std::unordered_set &nodes_seen) const; + + // All in Data nodes + Vistor GetInDataNodes() const; + // All in Control nodes + Vistor GetInControlNodes() const; + + /** + * 等价于`Vistor GetInNodes()` + * @return + */ + Vistor GetInAllNodes() const; + + // All out Data nodes + 
Vistor GetOutDataNodes() const; + std::vector GetOutDataNodesPtr() const; + // All out Control nodes + Vistor GetOutControlNodes() const; + + /** + * 等价于`Vistor GetOutNodes()` + * @return + */ + Vistor GetOutAllNodes() const; + + uint32_t GetOutDataNodesSize() const; + uint32_t GetOutControlNodesSize() const; + uint32_t GetOutNodesSize() const; + size_t GetInDataNodesSize() const; + size_t GetInControlNodesSize() const; + size_t GetInNodesSize() const; + + // Get all in data nodes and its out-anchor + Vistor> GetInDataNodesAndAnchors() const; + + // Get all out data nodes and its in-anchor + Vistor> GetOutDataNodesAndAnchors() const; + + OpDescPtr GetOpDesc() const; + OpDesc *GetOpDescBarePtr() const; + + graphStatus UpdateOpDesc(const OpDescPtr &op_desc); + + graphStatus AddLinkFrom(const NodePtr &input_node); + + graphStatus AddLinkFrom(const uint32_t &index, const NodePtr input_node); + + graphStatus AddLinkFrom(const std::string &name, const NodePtr input_node); + + graphStatus AddLinkFromForParse(const NodePtr &input_node); + + void AddSendEventId(const uint32_t event_id); + + void AddRecvEventId(const uint32_t event_id); + + const std::vector &GetSendEventIdList() const; + + const std::vector &GetRecvEventIdList() const; /*lint !e148*/ + + void GetFusionInputFlowList(kFusionDataFlowVec_t &fusion_input_list); + + void GetFusionOutputFlowList(kFusionDataFlowVec_t &fusion_output_list); + + void SetFusionInputFlowList(const kFusionDataFlowVec_t &fusion_input_list); + + void SetFusionOutputFlowList(const kFusionDataFlowVec_t &fusion_output_list); + + bool GetHostNode() const; + void SetHostNode(const bool is_host); + + void SetOrigNode(const NodePtr &orignode); + NodePtr GetOrigNode(); + + protected: + Node(); + Node(const OpDescPtr &op, const ComputeGraphPtr &owner_graph); + + private: + bool NodeMembersAreEqual(const Node &r_node) const; + bool NodeInConnectsAreEqual(const Node &r_node) const; + bool NodeOutConnectsAreEqual(const Node &r_node) const; + bool 
NodeAnchorIsEqual(const AnchorPtr &left_anchor, const AnchorPtr &right_anchor, const size_t i) const; + NodeImplPtr impl_; + friend class NodeUtils; + friend class OnnxUtils; + friend class TuningUtils; +}; +using ConstNode = Node::ConstNode; +using NodePtr = Node::NodePtr; +using ConstNodePtr = Node::ConstNodePtr; +using NodeToOutAnchor = std::pair; +} // namespace ge + +#endif // INC_GRAPH_NODE_H_ diff --git a/inc/metadef/graph/op_desc.h b/inc/metadef/graph/op_desc.h new file mode 100644 index 0000000000000000000000000000000000000000..05cb6654f8088b095bb15661f6f6e5790c610185 --- /dev/null +++ b/inc/metadef/graph/op_desc.h @@ -0,0 +1,450 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_GRAPH_OP_DESC_H_ +#define INC_GRAPH_OP_DESC_H_ + +#include +#include "detail/attributes_holder.h" +#include "graph/range_vistor.h" +#include "graph/ge_tensor.h" +namespace ge { +using std::map; +using std::pair; +using std::shared_ptr; +using std::string; +using std::vector; + +class Operator; +class OpDescImpl; +class IRMetaData; +using OpDescImplPtr = std::shared_ptr; + +enum SubgraphType { + kStatic, + kDynamic, + kSubgraphTypeEnd +}; + +enum IrInputType { + kIrInputRequired, + kIrInputOptional, + kIrInputDynamic, + kIrInputTypeEnd +}; + +enum IrOutputType { + kIrOutputRequired, + kIrOutputDynamic, + kIrOutputTypeEnd +}; + +enum class OppImplVersion { + kOpp, + kOppKernel, + // add new version definitions here + kVersionEnd = 4 +}; + +class OpDesc : public std::enable_shared_from_this, public AttrHolder { + public: + using OpDescPtr = std::shared_ptr; + using ConstOpDescPtr = std::shared_ptr; + + template + using Vistor = RangeVistor>; + + friend class GraphBuilderImpl; + + friend class OperatorImpl; + + OpDesc(const std::string &name, const std::string &type); + + OpDesc(const OpDesc &op_desc); + + OpDesc(OpDesc &&op_desc); + + explicit OpDesc(const ge::proto::OpDef &op_def); + + OpDesc(); + + ~OpDesc() override; + + bool operator==(const OpDesc &r_op_desc) const; + + std::string GetName() const; + const char *GetNamePtr() const; + + void SetName(const std::string &name); + + void SetNamePtr(const char_t *name); + + std::string GetType() const; + const char *GetTypePtr() const; + + void SetType(const std::string &type); + + void SetIrRelated(const OpDescPtr &op_desc); + + graphStatus AddInputDesc(const GeTensorDesc &input_desc); + + graphStatus AddInputDesc(const std::string &name, const GeTensorDesc &input_desc); + + graphStatus AddInputDesc(const uint32_t index, const ge::GeTensorDesc &input_desc); + + graphStatus 
AddInputDescMiddle(const std::string &name, const uint32_t num, const size_t index); + + graphStatus AddOutputDescMiddle(const std::string &name, const uint32_t num, const size_t index); + + graphStatus AddOutputDescForward(const std::string &name, const uint32_t num); + + graphStatus AddOptionalInputDesc(const std::string &name, const GeTensorDesc &input_desc); + + graphStatus UpdateInputDesc(const uint32_t index, const GeTensorDesc &tensor_desc); + + graphStatus UpdateInputDesc(const std::string &name, const GeTensorDesc &tensor_desc); + + bool InputIsSet(const std::string &name) const; + + const GeTensorDesc &GetInputDesc(const uint32_t index) const; + + const GeTensorDesc &GetInputDesc(const std::string &name) const; + + bool IsOptionalInput(const uint32_t index) const; + + Vistor GetAllInputNames() const; + + GeTensorDescPtr MutableInputDesc(const uint32_t index) const; + + GeTensorDescPtr MutableInputDesc(const std::string &name) const; + /** + * 获取OpDesc的所有输入的GeTensorDesc对象的拷贝, + * 需要注意拷贝行为对性能的影响 + * @return + */ + Vistor GetAllInputsDesc() const; + /** + * 获取OpDesc的所有输入的GeTensorDesc对象的引用, + * 无特殊需求,推荐使用此接口替代GetAllInputsDesc() + * @return + */ + Vistor GetAllInputsDescPtr() const; + + size_t GetInputsSize() const; + + size_t GetAllInputsSize() const; + + graphStatus AddOutputDesc(const GeTensorDesc &output_desc); + + graphStatus AddOutputDesc(const std::string &name, const GeTensorDesc &output_desc); + + graphStatus UpdateOutputDesc(const uint32_t index, const GeTensorDesc &tensor_desc); + + graphStatus UpdateOutputDesc(const std::string &name, const GeTensorDesc &tensor_desc); + + const GeTensorDesc &GetOutputDesc(const uint32_t index) const; + + const GeTensorDesc &GetOutputDesc(const std::string &name) const; + + GeTensorDescPtr MutableOutputDesc(const uint32_t index) const; + + GeTensorDescPtr MutableOutputDesc(const std::string &name) const; + + uint32_t GetAllOutputsDescSize() const; + + Vistor GetAllOutputsDesc() const; + + Vistor 
GetAllOutputsDescPtr() const; + + size_t GetOutputsSize() const; + + ConstGeTensorDescPtr GetOutputDescPtr(const uint32_t index) const; + + ConstGeTensorDescPtr GetInputDescPtr(const uint32_t index) const; + + ConstGeTensorDescPtr GetInputDescPtrDfault(const uint32_t index) const; + + ConstGeTensorDescPtr GetInputDescPtr(const std::string &name) const; + + graphStatus AddDynamicInputDesc(const std::string &name, const uint32_t num, const bool is_push_back = true); + + graphStatus AddDynamicInputDescByIndex(const std::string &name, const uint32_t num, const size_t index); + + graphStatus AddDynamicOutputDesc(const std::string &name, const uint32_t num, const bool is_push_back = true); + + bool IsOptionalInput(const std::string &name) const; + + std::map GetAllInputName() const; + + std::map GetAllOutputName(); + + std::map& MutableAllInputName(); + + std::map& MutableAllOutputName(); + + bool UpdateInputName(const std::map input_name_idx); + + bool UpdateOutputName(const std::map output_name_idx); + + void *GetTilingFuncInfo() const; + + void SetTilingFuncInfo(void *tiling_func_info); + + void *GetAtomicTilingFuncInfo() const; + + void SetAtomicTilingFuncInfo(void *atomic_tiling_func_info); + + bool IsSupportSymbolicInferDataType() const; + graphStatus SymbolicInferDataType(); + + void AddInferFunc(const std::function &func); + void AddInferFormatFunc(const std::function &func); + void AddInferValueRangeFunc(const std::function &func); + void AddVerifierFunc(const std::function &func); + void AddInferDataSliceFunc(const std::function &func); + + graphStatus DefaultInferFormat(); + + std::function GetInferFunc() const; + std::function GetVerifyFunc() const; + std::function GetInferFormatFunc() const; + std::function GetInferDataSliceFunc() const; + std::function GetInferValueRangeFunc() const; + + graphStatus CommonVerify() const; + + graphStatus AddRegisterInputName(const std::string &name); + + graphStatus AddRegisterOutputName(const std::string &name); + + 
std::vector GetRegisterInputName() const; + + std::vector GetRegisterOutputName() const; + + void AppendIrAttrName(const std::string &name); + const std::vector &GetIrAttrNames() const; + + void AppendIrInput(std::string name, IrInputType input_type); + const std::vector> &GetIrInputs() const; + size_t GetIrInputsSize() const; + + void AppendIrOutput(std::string name, IrOutputType output_type); + const std::vector> &GetIrOutputs() const; + + void SetInputDtypeSymbol(const std::string &ir_input, IrInputType type, const std::string &sym_id); + void SetOutputDtypeSymbol(const std::string &ir_output, IrOutputType type, const std::string &sym_id); + void DeclareDtypeSymbol(const std::string &sym_id, const TensorType &type); + void DeclareDtypeSymbol(const std::string &sym_id, const ListTensorType &type); + void DeclareDtypeSymbol(const std::string &sym_id, const Promote &type); + void ShareDtypeSymbolsFrom(const ge::OpDesc &src); + + using AttrHolder::AddRequiredAttr; + using AttrHolder::DelAttr; + using AttrHolder::GetAllAttrNames; + using AttrHolder::GetAllAttrs; + using AttrHolder::GetAttr; + using AttrHolder::HasAttr; + using AttrHolder::SetAttr; + + void SetId(const int64_t id); + int64_t GetId() const; + void SetStreamId(const int64_t stream_id); + int64_t GetStreamId() const; + // 后续将会废弃,从流的结果将会是多个 + void SetAttachedStreamId(const int64_t stream_id); + int64_t GetAttachedStreamId() const; + void SetAttachedStreamIds(const std::vector &stream_id); + std::vector GetAttachedStreamIds() const; + bool HasValidAttachedStreamId() const; + + void SetInputName(const std::vector &input_name); + std::vector GetInputName() const; + void SetSrcName(const std::vector &src_name); + std::vector GetSrcName() const; + void SetSrcIndex(const std::vector &src_index); + std::vector GetSrcIndex() const; + void SetInputOffset(const std::vector &input); + std::vector GetInputOffset() const; + void SetOutputOffset(const std::vector &output); + std::vector GetOutputOffset() const; + void 
SetDstName(const std::vector &dst_name); + std::vector GetDstName() const; + void SetDstIndex(const std::vector &dst_index); + void SetWorkspace(const std::vector &workspace); + std::vector GetWorkspace() const; + void SetWorkspaceBytes(const std::vector &workspace_bytes); + std::vector GetWorkspaceBytes() const; + void SetIsInputConst(const std::vector &is_input_const); + std::vector GetIsInputConst() const; + + void SetOpInferDepends(const std::vector &depend_names); + std::vector GetOpInferDepends() const; + + std::string GetInputNameByIndex(const uint32_t index) const; + std::string GetValidInputNameByIndex(const uint32_t index) const; + int32_t GetInputIndexByName(const std::string &name) const; + + graphStatus GetDynamicInputIndexesByName(const std::string &name, std::vector &indexes) const; + + std::string GetOutputNameByIndex(const uint32_t index) const; + + int32_t GetOutputIndexByName(const std::string &name) const; + + graphStatus GetDynamicOutputIndexesByName(const std::string &name, std::vector &indexes) const; + + void SetOpKernelLibName(const std::string &name); + + std::string GetOpKernelLibName() const; + + void SetOpEngineName(const std::string &name); + + std::string GetOpEngineName() const; + + OppImplVersion GetOppImplVersion() const; + + void RegisterSubgraphIrName(const std::string &name, const SubgraphType type); + const std::map &GetSubgraphIrNames() const; + SubgraphType GetSubgraphTypeByIrName(const std::string &name) const; + + graphStatus AddSubgraphName(const std::string &name); + const std::map &GetSubgraphNameIndexes() const; + + std::string GetSubgraphInstanceName(const uint32_t index) const; + const std::vector &GetSubgraphInstanceNames() const; + /// Does not provide functions `AddSubgraphInstance` or `AppendSubgraphInstance`, + /// because this kind of functions will only append a new subgraph instance name + /// at the tail of `subgraph_instance_names_` and ignore the synchronous change of `subgraph_names_to_index_`. 
+ /// If we want to append a new subgraph instance name, the function `AddSubgraphName` should be called first. + /// \param index + /// \param name + /// \return + graphStatus SetSubgraphInstanceName(const uint32_t index, const std::string &name); + void RemoveSubgraphInstanceName(const std::string &name); + + graphStatus GetSubgraphNameByInstanceName(const std::string &instance_name, std::string &subgraph_name) const; + + graphStatus GetPromoteIrInputList(std::vector> &promote_index_list); + + template + T *GetOrCreateAttrsGroup() { + return MutableAttrMap().GetOrCreateAttrsGroup(); + } + + protected: + ProtoAttrMap &MutableAttrMap() override; + ConstProtoAttrMap &GetAttrMap() const override; + + private: + bool OpDescMembersAreEqual(const OpDesc &r_op_desc) const; + bool OpDescAttrsAreEqual(const OpDesc &r_op_desc) const; + bool OpDescGenTensorDescsAreEqual(const OpDesc &r_op_desc) const; + + OpDescImplPtr impl_; + friend class OpDescUtils; + friend class ModelSerializeImp; + friend class AttrUtils; + friend class GeAttrValueImp; + friend class OnnxUtils; + friend class GraphUtils; + friend class NodeUtils; + friend class FastNodeUtils; + friend class ExecuteGraphUtils; +}; + +using OpDescPtr = OpDesc::OpDescPtr; +using ConstOpDescPtr = OpDesc::ConstOpDescPtr; +using ConstOpDesc = const OpDesc; + +class OpDescBuilder { + public: + OpDescBuilder(std::string name, std::string type) : name_(std::move(name)), type_(std::move(type)) {} + OpDescBuilder(const OpDescBuilder &) = delete; + OpDescBuilder &operator=(const OpDescBuilder &) = delete; + OpDescBuilder(const OpDescBuilder &&) = delete; + OpDescBuilder &operator=(const OpDescBuilder &&) = delete; + ~OpDescBuilder() = default; + + /** + * @brief Add input + * @param [in] name + * @return OpDescBuilder + */ + OpDescBuilder &AddInput(const std::string &name); + + /** + * @brief Add input + * @param [in] name + * @param [in] tensor + * @return OpDescBuilder + */ + OpDescBuilder &AddInput(const std::string &name, 
const GeTensorDesc &tensor); + + /** + * @brief Add dynamic input + * @param [in] name + * @param [in] num + * @return OpDescBuilder + */ + OpDescBuilder &AddDynamicInput(const std::string &name, const uint32_t num); + + /** + * @brief Add dynamic input + * @param [in] name + * @param [in] num + * @param [in] tensor + * @return OpDescBuilder + */ + OpDescBuilder &AddDynamicInput(const std::string &name, const uint32_t num, const GeTensorDesc &tensor); + + /** + * @brief Add output + * @param [in] name + * @return OpDescBuilder + */ + OpDescBuilder &AddOutput(const std::string &name); + + /** + * @brief Add output + * @param [in] name + * @param [in] tensor + * @return OpDescBuilder + */ + OpDescBuilder &AddOutput(const std::string &name, const GeTensorDesc &tensor); + + /** + * @brief Add dynamic output + * @param [in] name + * @param [in] num + * @return OpDescBuilder + */ + OpDescBuilder &AddDynamicOutput(const std::string &name, const uint32_t num); + + /** + * @brief Add dynamic output + * @param [in] name + * @param [in] num + * @param [in] tensor + * @return OpDescBuilder + */ + OpDescBuilder &AddDynamicOutput(const std::string &name, const uint32_t num, const GeTensorDesc &tensor); + + /** + * @brief Build op_desc + * @return OpDescPtr + */ + OpDescPtr Build(); + + private: + std::string name_; + std::string type_; + std::vector> inputs_; + std::vector> outputs_; +}; +} // namespace ge +#endif // INC_GRAPH_OP_DESC_H_ diff --git a/inc/metadef/graph/op_io.h b/inc/metadef/graph/op_io.h new file mode 100644 index 0000000000000000000000000000000000000000..5abf5d844f9c4853f033460377799abe4e800c58 --- /dev/null +++ b/inc/metadef/graph/op_io.h @@ -0,0 +1,33 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_CXX_OP_IO_H +#define METADEF_CXX_OP_IO_H +namespace ge { + +class OpIO { + public: + OpIO(const std::string &name, const int32_t index, const OperatorImplPtr &owner) + : name_(name), index_(index), owner_(owner) {} + + ~OpIO() = default; + + std::string GetName() const { return name_; } + + int32_t GetIndex() const { return index_; } + + OperatorImplPtr GetOwner() const { return owner_; } + + private: + std::string name_; + int32_t index_; + std::shared_ptr owner_; +}; +} +#endif // METADEF_CXX_OP_IO_H diff --git a/inc/metadef/graph/op_kernel_bin.h b/inc/metadef/graph/op_kernel_bin.h new file mode 100644 index 0000000000000000000000000000000000000000..cf8291f7c81c757886a146dc107ea00eac9711a5 --- /dev/null +++ b/inc/metadef/graph/op_kernel_bin.h @@ -0,0 +1,44 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_GRAPH_OP_KERNEL_BIN_H_ +#define INC_GRAPH_OP_KERNEL_BIN_H_ + +#include +#include +#include +#include "graph/types.h" +#include "graph/def_types.h" +#include "graph/host_resource/host_resource.h" + +namespace ge { +class OpKernelBin : public HostResource { + public: + OpKernelBin(const std::string &name, std::vector &&data) : name_(name), data_(std::move(data)) {} + + ~OpKernelBin() = default; + + const std::string &GetName() const { return name_; } + const uint8_t *GetBinData() const { return ge::PtrToPtr(data_.data()); } + size_t GetBinDataSize() const { return data_.size(); } + OpKernelBin(const OpKernelBin &) = delete; + const OpKernelBin &operator=(const OpKernelBin &) = delete; + + private: + std::string name_; + std::vector data_; +}; + +using OpKernelBinPtr = std::shared_ptr; +constexpr char_t OP_EXTATTR_NAME_TBE_KERNEL[] = "tbeKernel"; +constexpr char_t OP_EXTATTR_NAME_THREAD_TBE_KERNEL[] = "thread_tbeKernel"; +constexpr char_t OP_EXTATTR_CUSTAICPU_KERNEL[] = "cust_aicpu_kernel"; +} // namespace ge + +#endif // INC_GRAPH_OP_KERNEL_BIN_H_ diff --git a/inc/metadef/graph/op_so_bin.h b/inc/metadef/graph/op_so_bin.h new file mode 100644 index 0000000000000000000000000000000000000000..873321a2ec09f6b541213e9cf43f0c8dc14dbddc --- /dev/null +++ b/inc/metadef/graph/op_so_bin.h @@ -0,0 +1,67 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_OP_SO_BIN_H +#define INC_GRAPH_OP_SO_BIN_H + +#include +#include +#include "graph/types.h" +#include "graph/def_types.h" + +namespace ge { +typedef struct { + std::string cpu_info; + std::string os_info; + std::string opp_version; + std::string compiler_version; +} SoInOmInfo; + +enum class SoBinType : uint16_t { + kSpaceRegistry = 0, // for use of rt2 infer shape/tiling so, managed in OpImplSpaceRegistry + kOpMasterDevice = 1 // for use of tiling so on device +}; + +class OpSoBin { +public: + OpSoBin(const std::string &so_name, const std::string &vendor_name, + std::unique_ptr data, uint32_t data_len) + : so_name_(so_name), vendor_name_(vendor_name), data_(std::move(data)), + data_size_(data_len) {} + + OpSoBin(const std::string &so_name, const std::string &vendor_name, std::unique_ptr data, uint32_t data_len, + SoBinType so_bin_type) + : so_name_(so_name), vendor_name_(vendor_name), data_(std::move(data)), data_size_(data_len), + so_bin_type_(so_bin_type) {} + + ~OpSoBin() = default; + + const std::string &GetSoName() const { return so_name_; } + const std::string &GetVendorName() const { return vendor_name_; } + const uint8_t *GetBinData() const { return ge::PtrToPtr(data_.get()); } + uint8_t *MutableBinData() const { + return PtrToPtr(data_.get()); + } + size_t GetBinDataSize() const { return data_size_; } + SoBinType GetSoBinType() const { return so_bin_type_; } + OpSoBin(const OpSoBin &) = delete; + const OpSoBin &operator=(const OpSoBin &) = delete; + +private: + std::string so_name_; + std::string vendor_name_; + std::unique_ptr data_; + uint32_t data_size_; + SoBinType so_bin_type_{SoBinType::kSpaceRegistry}; +}; + +using OpSoBinPtr = std::shared_ptr; +} // namespace ge + +#endif // INC_GRAPH_OP_SO_BIN_H diff --git 
a/inc/metadef/graph/op_types.h b/inc/metadef/graph/op_types.h new file mode 100644 index 0000000000000000000000000000000000000000..c1d7129190046da9f6a1fbf883601a9e2319e507 --- /dev/null +++ b/inc/metadef/graph/op_types.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_OP_TYPES_H_ +#define INC_GRAPH_OP_TYPES_H_ + +#include +#include + +#include "graph/types.h" + +namespace ge { +class GE_FUNC_VISIBILITY OpTypeContainer { + public: + static OpTypeContainer &Instance() { + static OpTypeContainer instance; + return instance; + } + ~OpTypeContainer() = default; + + void Register(const std::string &op_type) { static_cast(op_type_list_.insert(op_type)); } + + bool IsExisting(const std::string &op_type) { + return op_type_list_.find(op_type) != op_type_list_.end(); + } + + protected: + OpTypeContainer() {} + + private: + std::set op_type_list_; +}; + +class GE_FUNC_VISIBILITY OpTypeRegistrar { + public: + explicit OpTypeRegistrar(const std::string &op_type) noexcept { OpTypeContainer::Instance().Register(op_type); } + ~OpTypeRegistrar() {} +}; + +#define REGISTER_OPTYPE_DECLARE(var_name, str_name) \ + FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char_t *var_name + +#define REGISTER_OPTYPE_DEFINE(var_name, str_name) \ + const char_t *var_name = str_name; \ 
+ const ge::OpTypeRegistrar g_##var_name##_reg(str_name) + +#define IS_OPTYPE_EXISTING(str_name) (ge::OpTypeContainer::Instance().IsExisting(str_name)) +} // namespace ge + +#endif // INC_GRAPH_OP_TYPES_H_ diff --git a/inc/metadef/graph/operator_factory_impl.h b/inc/metadef/graph/operator_factory_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..00431709de3389a7d82795e7074266de0f43f969 --- /dev/null +++ b/inc/metadef/graph/operator_factory_impl.h @@ -0,0 +1,126 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_GRAPH_OPERATOR_FACTORY_IMPL_H_ +#define INC_GRAPH_OPERATOR_FACTORY_IMPL_H_ + +#include +#include +#include +#include +#include "graph/operator_factory.h" +#include "register/infer_data_slice_registry.h" +#include "register/infer_axis_slice_registry.h" +#include "register/op_impl_kernel_registry.h" +#include "graph/op_desc.h" + +namespace ge { +using InferShapeV2Func = uint32_t (*)(const ge::Operator &op, const OpDescPtr &); +using InferDataTypeFunc = uint32_t (*)(const OpDescPtr &); +using InferShapeRangeFunc = uint32_t (*)(const ge::Operator &op, const OpDescPtr &); +struct InferValueRangePara { + public: + InferValueRangePara() = default; + InferValueRangePara(const WHEN_CALL call, const bool cpu_kernel, const InferValueRangeFunc func) { + is_initialized = true; + use_cpu_kernel = cpu_kernel; + when_call = call; + infer_value_func = func; + } + friend class OpDescImpl; + friend class InferValueRangePass; + friend class OpDescUtilsEx; + ~InferValueRangePara() = default; +private: + bool is_initialized = false; + bool use_cpu_kernel = false; + WHEN_CALL when_call = INPUT_IS_DYNAMIC; + InferValueRangeFunc infer_value_func = nullptr; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorFactoryImpl { + public: + static Operator CreateOperator(const std::string &operator_name, const std::string &operator_type); + + static graphStatus GetOpsTypeList(std::vector &all_ops); + + static bool IsExistOp(const std::string &operator_type); + + static InferShapeFunc GetInferShapeFunc(const std::string &operator_type); + + static InferShapeV2Func GetInferShapeV2Func(); + + static InferDataTypeFunc GetInferDataTypeFunc(); + + static InferShapeRangeFunc GetInferShapeRangeFunc(); + + static InferFormatFunc GetInferFormatFunc(const std::string &operator_type); + + static InferValueRangePara GetInferValueRangePara(const std::string 
&operator_type); + + static VerifyFunc GetVerifyFunc(const std::string &operator_type); + + static InferDataSliceFunc GetInferDataSliceFunc(const std::string &operator_type); + + static InferAxisSliceFunc GetInferAxisSliceFunc(const std::string &operator_type); + + static InferAxisTypeInfoFunc GetInferAxisTypeInfoFunc(const std::string &operator_type); + + static void SetRegisterOverridable(const bool &is_overridable); + + static graphStatus RegisterOperatorCreator(const std::string &operator_type, OpCreator const &op_creator); + + static graphStatus RegisterOperatorCreator(const std::string &operator_type, OpCreatorV2 const &op_creator); + + static graphStatus RegisterInferShapeFunc(const std::string &operator_type, InferShapeFunc const infer_shape_func); + + static void RegisterInferShapeV2Func(InferShapeV2Func const infer_shape_func); + + static void RegisterInferDataTypeFunc(InferDataTypeFunc const infer_data_type_func); + + static void RegisterInferShapeRangeFunc(InferShapeRangeFunc const infer_shape_range_func); + + static graphStatus RegisterInferFormatFunc(const std::string &operator_type, InferFormatFunc const infer_format_func); + + static graphStatus RegisterVerifyFunc(const std::string &operator_type, VerifyFunc const verify_func); + + static graphStatus RegisterInferDataSliceFunc(const std::string &operator_type, + InferDataSliceFunc const infer_data_slice_func); + + static graphStatus RegisterInferValueRangeFunc(const std::string &operator_type); + + static graphStatus RegisterInferValueRangeFunc(const std::string &operator_type, + const WHEN_CALL when_call, + const bool use_cpu_kernel, + const InferValueRangeFunc &infer_value_range_func); + + static graphStatus RegisterInferAxisSliceFunc(const std::string &operator_type, + const InferAxisSliceFunc &infer_axis_slice_func); + + static graphStatus RegisterInferAxisTypeInfoFunc(const std::string &operator_type, + const InferAxisTypeInfoFunc &infer_axis_type_info_func); + + static void ReleaseRegInfo(); + 
+ static std::shared_ptr> operator_creators_; + static std::shared_ptr> operator_creators_v2_; + static std::shared_ptr> operator_infershape_funcs_; + static std::shared_ptr> operator_inferformat_funcs_; + static std::shared_ptr> operator_verify_funcs_; + static std::shared_ptr> operator_infer_data_slice_funcs_; + static std::shared_ptr> operator_infer_value_range_paras_; + static std::shared_ptr> operator_infer_axis_slice_funcs_; + static std::shared_ptr> operator_infer_axis_type_info_funcs_; + static InferShapeV2Func operator_infer_shape_v2_func_; + static InferDataTypeFunc operator_infer_datatype_func_; + static InferShapeRangeFunc operator_infer_shape_range_func_; +}; +} // namespace ge + +#endif // INC_GRAPH_OPERATOR_FACTORY_IMPL_H_ diff --git a/inc/metadef/graph/operator_impl.h b/inc/metadef/graph/operator_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..651b6568d10c0abd9daddc2f1eb02722196c87e7 --- /dev/null +++ b/inc/metadef/graph/operator_impl.h @@ -0,0 +1,148 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef METADEF_CXX_OPERATOR_IMPL_H +#define METADEF_CXX_OPERATOR_IMPL_H +#include +#include +#include "graph/op_desc.h" +#include "graph/node.h" +#include "graph/operator.h" +#include "graph/inference_context.h" +#include "graph/runtime_inference_context.h" +#include "graph/normal_graph/op_io.h" +namespace ge { +class OperatorImpl : public std::enable_shared_from_this { + public: + using GetConstInputOnRuntimeFun = + std::function; + explicit OperatorImpl(const std::string &name, const std::string &type); + explicit OperatorImpl(const OpDescPtr &op_desc); + explicit OperatorImpl(const ConstNodePtr node); + ~OperatorImpl(); + + void SetInputImpl(const std::string &dst_name, const Operator &src_oprt); + void SetInputImpl(const std::string &dst_name, const OutHandler &out_handler); + void AddControlInputImp(const Operator &src_oprt); + graphStatus GetInputImpl(const std::string &dst_name, ge::OpIO &out_handler) const; + graphStatus GetInputImpl(const uint32_t idx, ge::OpIO &out_handler) const; + graphStatus GetInputConstData(const std::string &dst_name, Tensor &data); + graphStatus GetInputConstData(const uint32_t idx, ConstGeTensorPtr &ge_tensor) const; + graphStatus GetInputConstDataOut(const std::string &dst_name, Tensor &data) const; + graphStatus GetInputConstDataOut(const uint32_t idx, ConstGeTensorPtr &ge_tensor) const; + bool InputIsSet(const std::string &name); + std::string GetName() const; + GeTensorDesc GetInputDesc(const std::string &name) const; + GeTensorDesc GetInputDesc(const uint32_t index) const; + GeTensorDescPtr MutableInputDesc(const std::string &name); + GeTensorDescPtr MutableInputDesc(const uint32_t index); + graphStatus UpdateInputDesc(const std::string &name, const GeTensorDesc &tensor_desc); + OutHandler GetOutput(const std::string &name); + OutHandler GetOutput(uint32_t index); + GeTensorDesc GetOutputDesc(const 
std::string &name) const; + GeTensorDesc GetOutputDesc(const uint32_t index) const; + GeTensorDescPtr MutableOutputDesc(const std::string &name); + GeTensorDescPtr MutableOutputDesc(const uint32_t index); + graphStatus UpdateOutputDesc(const std::string &name, const GeTensorDesc &tensor_desc); + size_t GetInputsSize() const; + size_t GetOutputsSize() const; + graphStatus SetAttr(const std::string &name, AnyValue &&attr_value); + graphStatus GetAttr(const std::string &name, AnyValue &attr_value) const; + OpDescPtr GetOpDescImpl() const; + void UpdateLinkMapImpl(const std::string &src_name, const OpIO &op_dst); + Operator ToOperator(); + void ClearOutputLinks() noexcept; + void ClearInputLinks() noexcept; + ge::ConstNodePtr GetNode() const; + graphStatus SetNode(const ConstNodePtr &node) ; + void SetInferenceContext(const InferenceContextPtr &inference_context); + InferenceContextPtr GetInferenceContext() const; + void SubgraphRegister(const std::string &ir_name, const bool dynamic); + void SubgraphCountRegister(const std::string &ir_name, const uint32_t count); + void SetSubgraphBuilder(const std::string &ir_name, const uint32_t index, const SubgraphBuilder &builder); + SubgraphBuilder GetSubgraphBuilder(const std::string &ir_name, const uint32_t index) const; + SubgraphBuilder GetSubgraphBuilder(const std::string &name) const; + std::vector GetSubgraphNames() const; + size_t GetSubgraphNamesCount() const; + + static OpDescPtr GetOpDesc(const Operator &oprt); + graphStatus UpdateInputDesc(const uint32_t index, const GeTensorDesc &tensor_desc); + graphStatus UpdateOutputDesc(const uint32_t index, const GeTensorDesc &tensor_desc); + + private: + graphStatus GetFromPeerNode(NodePtr &peer_node, const OutDataAnchorPtr &out_data_anchor, + ConstGeTensorPtr &ge_tensor) const; + + private: + OpDescPtr op_desc_ = nullptr; + ge::ConstNodePtr node_{nullptr}; + ge::InferenceContextPtr inference_context_; + std::map> output_links_{}; + std::map input_link_{}; + std::vector> 
control_input_link_{}; + std::vector> control_output_link_{}; + std::map subgraph_names_to_builders_; + RuntimeInferenceContext *runtime_context_{nullptr}; // deprecated, will delete when air support + GetConstInputOnRuntimeFun get_const_input_runtime_ = nullptr; + + private: + friend class GraphBuilderImpl; + friend class MultiThreadGraphBuilder; + friend class OpDescUtils; +}; +// Used to manage OperatorImpl instances created by ge api. +class OperatorKeeper { + public: + static OperatorKeeper &GetInstance(); + void CheckInOperator(const OperatorImplPtr &op_impl) { + if (op_impl) { + const std::lock_guard lock(mutex_); + (void)(operators_.insert(op_impl)); + } + } + void CheckOutOperator(const OperatorImplPtr &op_impl) { + if (op_impl) { + const std::lock_guard lock(mutex_); + (void)(operators_.erase(op_impl)); + } + } + + void ClearInvalidOp() { + const std::lock_guard lock(mutex_); + for (auto iter = operators_.begin(); iter != operators_.end();) { + auto op = iter->lock(); + if (op == nullptr) { + iter = operators_.erase(iter); + } else { + ++iter; + } + } + } + + private: + OperatorKeeper() = default; + ~OperatorKeeper() { + for (const auto &iter : operators_) { + if (!iter.expired()) { + iter.lock()->ClearInputLinks(); + } + if (!iter.expired()) { + iter.lock()->ClearOutputLinks(); + } + } + // Manually clean up because the `Operator` destructor may access `operators_` + auto operators = std::move(operators_); + operators.clear(); + } + std::set, std::owner_less>> operators_; + std::mutex mutex_; +}; +} // namespace ge + +#endif // METADEF_CXX_OPERATOR_IMPL_H diff --git a/inc/metadef/graph/opsproto_manager.h b/inc/metadef/graph/opsproto_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..4fb59c7639fa30d7b048b5507a42a5c6b65e76fc --- /dev/null +++ b/inc/metadef/graph/opsproto_manager.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_OPSPROTO_MANAGER_H_ +#define INC_GRAPH_OPSPROTO_MANAGER_H_ + +#include +#include +#include +#include + +namespace ge { +class OpsProtoManager { + public: + OpsProtoManager() = default; + ~OpsProtoManager(); + static OpsProtoManager *Instance(); + + bool Initialize(const std::map &options); + void Finalize(); + + private: + void LoadOpsProtoPluginSo(const std::string &path); + + std::string pluginPath_; + std::vector handles_; + bool is_init_ = false; + std::mutex mutex_; +}; +} // namespace ge + +#endif // INC_GRAPH_OPSPROTO_MANAGER_H_ diff --git a/inc/metadef/graph/option/optimization_option.h b/inc/metadef/graph/option/optimization_option.h new file mode 100644 index 0000000000000000000000000000000000000000..8ccce71474ba7489c0fe714e923d516e8b89f2c8 --- /dev/null +++ b/inc/metadef/graph/option/optimization_option.h @@ -0,0 +1,40 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_OPTION_OPTIMIZATION_OPTION_H_ +#define INC_GRAPH_OPTION_OPTIMIZATION_OPTION_H_ + +#include +#include +#include "graph/ge_error_codes.h" +#include "optimization_option_info.h" + +namespace ge { +class OptimizationOption { + public: + OptimizationOption() = default; + ~OptimizationOption() = default; + + graphStatus Initialize(const std::map &ge_options, + const std::unordered_map ®istered_options); + graphStatus GetValue(const std::string &opt_name, std::string &opt_value) const; + static graphStatus IsOoLevelValid(const std::string &oo_level); + static graphStatus IsOptionValueValid(const std::string &opt_name, const std::string &opt_value, + OoInfo::ValueChecker checker); + + private: + graphStatus InitWorkingOolevel(const std::map &ge_options); + void PrintAllWorkingOo(); + + private: + OoLevel working_oo_level_{OoLevel::kEnd}; + std::unordered_map working_opt_names_to_value_; +}; +} // namespace ge +#endif // INC_GRAPH_OPTION_OPTIMIZATION_OPTION_H_ diff --git a/inc/metadef/graph/option/optimization_option_info.h b/inc/metadef/graph/option/optimization_option_info.h new file mode 100644 index 0000000000000000000000000000000000000000..edb29f8c7d6993c83759883515103376d43c3c9e --- /dev/null +++ b/inc/metadef/graph/option/optimization_option_info.h @@ -0,0 +1,94 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. 
You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_OPTION_OPTIMIZATION_OPTION_INFO_H +#define INC_GRAPH_OPTION_OPTIMIZATION_OPTION_INFO_H +#include +#include +#include +#include +#include +#include + +namespace ge { +// pre-defined set of optimization templates +enum class OoLevel : uint32_t { + kO0 = 0, + kO1 = 1, + kO2 = 2, + kO3 = 3, + kEnd, +}; +static_assert(static_cast(OoLevel::kEnd) < 64, "The number of OoLevels exceeds 64!"); + +// the hierarchy of optimization options +enum class OoHierarchy : uint32_t { + kH1 = 0, // Primary options + kH2 = 1, // Secondary options + kEnd, +}; + +// the entry points to GE +enum class OoEntryPoint : uint32_t { + kSession = 0, + kIrBuild = 1, + kAtc = 2, + kEnd, +}; + +enum class OoCategory : uint32_t { + kGeneral = 0, + kInput, + kOutput, + kTarget, + kFeature, + kModelTuning, + kOperatorTuning, + kDebug, + kEnd, +}; + +struct OoShowInfo { + OoCategory catagory; + std::string show_name; +}; + +struct OoInfo { + using ValueChecker = bool (*)(const std::string& opt_value); + // identifies the visibility of the option at different entrances of the program + uint64_t visibility; + uint64_t levels; + OoHierarchy hierarchy; + ValueChecker checker; + std::string name; + std::string help_text; + // Maps each entry point to its corresponding display option information + std::map show_infos; + std::map default_values; + + explicit OoInfo(std::string opt_name, OoHierarchy opt_hierarchy = OoHierarchy::kEnd, uint64_t opt_level = 0UL, + uint64_t opt_vis = 0UL, std::map opt_values = {}, +
ValueChecker opt_checker = nullptr, std::map opt_entry_infos = {}, + std::string opt_help = "") + : visibility(opt_vis), levels(opt_level), hierarchy(opt_hierarchy), checker(opt_checker), + name(std::move(opt_name)), help_text(std::move(opt_help)), show_infos(std::move(opt_entry_infos)), + default_values(std::move(opt_values)) {} +}; + +class OoInfoUtils { + public: + static bool IsBitSet(const uint64_t bits, const uint32_t pos); + static uint64_t GenOptVisibilityBits(const std::vector &entries); + static uint64_t GenOptLevelBits(const std::vector &levels); + static std::string GenOoLevelStr(const uint64_t opt_level); + static std::string GetDefaultValue(const OoInfo &info, OoLevel target_level); + static bool IsSwitchOptValueValid(const std::string &opt_value); +}; +} // namespace ge +#endif // INC_GRAPH_OPTION_OPTIMIZATION_OPTION_INFO_H diff --git a/inc/metadef/graph/parallelism/graph_parallel_option.h b/inc/metadef/graph/parallelism/graph_parallel_option.h new file mode 100644 index 0000000000000000000000000000000000000000..74196f19c8783248fe5580c0a726a20685d05414 --- /dev/null +++ b/inc/metadef/graph/parallelism/graph_parallel_option.h @@ -0,0 +1,82 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef METADEF_INC_GRAPH_PARALLELISM_GRAPH_PARALLEL_OPTION_H_ +#define METADEF_INC_GRAPH_PARALLELISM_GRAPH_PARALLEL_OPTION_H_ + +#include +#include + +namespace ge { +struct PipelineParallelOption { + bool is_enabled = false; + bool is_auto = false; + std::string pipeline_strategy; + int32_t pipe_stage_num = -1; + int32_t schedule_opt_virtual_stage_num = -1; +}; + +struct TensorParallelOption { + bool is_enabled = false; + bool is_auto = false; + int32_t tensor_parallel_size = -1; + int32_t inter_batch_flow_num = 1; +}; + +struct DataParallelOption { + bool is_enabled = false; + bool is_auto = false; + // to be deleted below + bool optimizer_state_sharding = false; + bool gradient_sharding = false; + bool model_weight_sharding = false; + bool model_weight_prefetch = true; + int32_t data_parallel_size = -1; + // model weight prefetch buffer size(MB) + uint32_t model_weight_prefetch_buffer_size = 0U; +}; + +struct TensorShardingOption { + bool is_enabled = false; + bool optimizer_state_sharding = false; + bool gradient_sharding = false; + bool model_weight_sharding = false; + bool model_weight_prefetch = true; + // model weight prefetch buffer size(MB) + uint32_t model_weight_prefetch_buffer_size = 0U; +}; + +struct OptimizerOffloadGraphOption { + bool is_enabled = false; + std::string offload; // cpu or NVME, NVME is reserved + std::string offload_path; // NVME path, reserved +}; + +struct EngineParallelOption { + bool is_enabled = false; + bool is_auto = false; + std::string config_path; // used if is_auto == true +}; + +struct GraphParallelOption { + bool auto_deploy = false; + std::string mode; // AOE mode, search_strategy/search_and_shard_graph/load_strategy/load_and_eval_strategy + std::string work_dir; // AOE dump/load path for strategies + std::string opt_level; + int32_t global_batch_size = -1; + DataParallelOption 
data_parallel_option; + TensorParallelOption tensor_parallel_option; + TensorShardingOption tensor_sharding_option; + PipelineParallelOption pipeline_parallel_option; + OptimizerOffloadGraphOption optimizer_offload_option; + EngineParallelOption engine_parallel_option; +}; +} // namespace ge + +#endif // METADEF_INC_GRAPH_PARALLELISM_GRAPH_PARALLEL_OPTION_H_ diff --git a/inc/metadef/graph/parallelism/tensor_parallel_attrs.h b/inc/metadef/graph/parallelism/tensor_parallel_attrs.h new file mode 100644 index 0000000000000000000000000000000000000000..ea56fe1544cb10afe8b157902d9a6c39b9e8d0a9 --- /dev/null +++ b/inc/metadef/graph/parallelism/tensor_parallel_attrs.h @@ -0,0 +1,396 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef METADEF_INC_GRAPH_PARALLELISM_TENSOR_PARALLEL_ATTRS_H_ +#define METADEF_INC_GRAPH_PARALLELISM_TENSOR_PARALLEL_ATTRS_H_ + +#include +#include +#include +#include +#include +#include "external/ge_common/ge_api_types.h" + +namespace ge { +namespace tp { +constexpr const char_t *kCommTaskTypeConcat = "Concat"; +constexpr const char_t *kCommTaskTypeUniqueConcat = "UniqueConcat"; +constexpr const char_t *kCommTaskTypeModifyValue = "ModifyValue"; +constexpr const char_t *kCommTaskTypeSlice = "Slice"; +constexpr const char_t *kCommTaskTypeSliceByAxis = "SliceByAxis"; +constexpr const char_t *kCommTaskTypeSplit = "Split"; +constexpr const char_t *kCommTaskTypeTranspose = "Transpose"; +constexpr const char_t *kCommTaskTypeReshape = "Reshape"; +constexpr const char_t *kCommTaskTypeCast = "Cast"; +constexpr const char_t *kCommTaskTypeHcomAllGather = "HcomAllGather"; +constexpr const char_t *kCommTaskTypeHcomAllReduce = "HcomAllReduce"; +constexpr const char_t *kCommTaskTypeHcomAllReduceMean = "HcomAllReduceMean"; +constexpr const char_t *kCommTaskTypeHcomReduceScatter = "HcomReduceScatter"; +constexpr const char_t *kCommTaskTypeHcomBroadcast = "HcomBroadcast"; +constexpr const char_t *kCommTaskTypeHcomAllToAll = "HcomAllToAll"; +constexpr const char_t *kCommTaskTypeSendReceive = "SendReceive"; +constexpr const char_t *kCommTaskTypeLocalReduce = "LocalReduce"; +constexpr const char_t *kGraphSlicingSuffix = "_by_graph_slice_"; +constexpr const char_t *kFlowAttrEnqueuePolicyFifo = "FIFO"; +constexpr const char_t *kFlowAttrEnqueuePolicyOverwrite = "OVERWRITE"; +constexpr const char_t *kSendRecvCommTypeQueue = "Queue"; +constexpr const char_t *kSendRecvCommTypeP2p = "P2pComm"; + +// tensor deployment attrs +struct DimSlice { + int64_t begin; + int64_t end; +}; + +struct DeviceIndex { + std::string engine_type; + std::vector indices; + std::string 
DebugString() const; +}; + +struct ModelIndex { + // use this construct when need use stage id + ModelIndex() = default; + ModelIndex(const DeviceIndex &device_index, const int64_t stage_id, const int64_t virtual_stage_id) + : device_index(device_index), virtual_stage_id(virtual_stage_id), stage_id(stage_id) {} + // use this construct when do not need use stage id + ModelIndex(const DeviceIndex &device_index, const int64_t virtual_stage_id) + : device_index(device_index), virtual_stage_id(virtual_stage_id), stage_id(0L) {} + ~ModelIndex() = default; + DeviceIndex device_index; + int64_t virtual_stage_id = 0L; + int64_t stage_id = 0L; + std::string DebugString() const; +}; + +struct PipelineConfig { + int64_t micro_batch = 1L; + int64_t stage_id = 0L; + std::vector virtual_stage_id {0L}; +}; + +bool operator==(const DeviceIndex &lhs, const DeviceIndex &rhs); +bool operator!=(const DeviceIndex &lhs, const DeviceIndex &rhs); +bool operator<(const DeviceIndex &lhs, const DeviceIndex &rhs); + +bool operator==(const ModelIndex &lhs, const ModelIndex &rhs); +bool operator!=(const ModelIndex &lhs, const ModelIndex &rhs); +bool operator<(const ModelIndex &lhs, const ModelIndex &rhs); + +struct TensorSliceDeployment { + std::vector> axis_slices; + std::vector> device_indices_each_slice; + std::string reduce_type; +}; + +struct TensorDeployment { + TensorSliceDeployment shard_deployment; + std::string verbose; +}; + +struct NodeDeployment { + std::vector devices; + PipelineConfig pipeline_config; +}; + +struct NodeDeployments { + std::map deployments; +}; + +struct TensorDeployments { + std::map deployments; +}; + +// P2P communications +struct CommPair { + DeviceIndex src_device_index; + int64_t src_virtual_stage_id = 0L; + DeviceIndex dst_device_index; + int64_t dst_virtual_stage_id = 0L; +}; + +struct FlowAttr { + int32_t depth = 1; + std::string enqueue_policy = kFlowAttrEnqueuePolicyFifo; +}; + +struct SendRecvReshardTask { + std::vector comm_pairs; + std::string 
parallel_group; + std::string comm_type = kSendRecvCommTypeQueue; + FlowAttr flow_attr; // used when comm_type is Queue +}; + +struct CastReshardTask { + DataType dst_type = DT_MAX; +}; + +// group communications +struct CommGroup { + std::vector device_indices; +}; + +struct AllToAllReshardTask { + std::vector comm_groups; + std::string parallel_group; +}; + +struct AllGatherReshardTask { + std::vector comm_groups; + int32_t axis; // axis to concat + std::string parallel_group; + std::string output_allocator; +}; + +struct AllReduceReshardTask { + std::string reduction; + std::vector comm_groups; + std::string parallel_group; +}; + +struct AllReduceMeanReshardTask { + std::vector comm_groups; + int32_t axis; + int32_t value; + std::string parallel_group; +}; + +struct ReduceScatterReshardTask { + std::string reduction; + std::vector comm_groups; + std::string parallel_group; +}; + +struct BroadcastReshardTask { + std::vector root_device_indices; // size == num_groups + std::vector comm_groups; + std::string parallel_group; +}; + +// local reshardings +struct SliceReshardTask { + std::vector axes; + std::vector offsets; + std::vector sizes; + DeviceIndex device_index; +}; + +struct SliceByAxisReshardTask { + // key: axis to split + // value: index: slice index + // element: devices to deploy + std::map>> axis_to_slice_deployments; +}; + +struct SplitReshardTask { + int32_t split_dim = 0; + int32_t num_split = 0; +}; + +struct ConcatReshardTask { + int32_t concat_dim = 0; +}; + +struct UniqueConcatReshardTask { + std::string unique_id; + int32_t concat_dim = 0; + std::vector src_device_indices; + DeviceIndex dst_device_index; +}; + +struct TransposeReshardTask { + std::vector perm; +}; + +struct ReshapeReshardTask { + std::vector shape; +}; + +struct ModifyValueReshardTask { + std::string op_type; // mul, div + std::vector value; +}; + +struct LocalReduceReshardTask { + std::string op_type; +}; + +struct CommTask { + std::string task_type; + std::shared_ptr 
send_recv_reshard_task; + std::shared_ptr all_gather_reshard_task; + std::shared_ptr all_to_all_reshard_task; + std::shared_ptr all_reduce_reshard_task; + std::shared_ptr all_reduce_mean_reshard_task; + std::shared_ptr reduce_scatter_reshard_task; + std::shared_ptr broadcast_reshard_task; + std::shared_ptr split_reshard_task; + std::shared_ptr concat_reshard_task; + std::shared_ptr unique_concat_reshard_task; + std::shared_ptr slice_reshard_task; + std::shared_ptr slice_by_axis_reshard_task; + std::shared_ptr transpose_reshard_task; + std::shared_ptr modify_value_reshard_task; + std::shared_ptr local_reduce_reshard_task; + std::shared_ptr reshape_reshard_task; + std::shared_ptr cast_reshard_task; +}; + +struct CommStepInput { + int32_t step_id = -1; + int32_t output_index = -1; +}; + +bool operator==(const CommStepInput &lhs, const CommStepInput &rhs); +bool operator<(const CommStepInput &lhs, const CommStepInput &rhs); + +struct CommStep { + int32_t id; + std::vector inputs; + CommTask comm_task; +}; + +struct PeerInput { + int32_t step_id = -1; + std::string node_name; + uint32_t input_index; + int64_t stage_id = 0L; + int64_t virtual_stage_id = 0L; +}; + +// reshard ops for one output tensor +struct OutputReshardRes { + std::vector comm_steps; + std::vector peer_inputs; + std::vector device_indices; + int64_t stage_id = 0L; + int64_t virtual_stage_id = 0L; +}; + +struct ReshardAttr { + std::vector> reshard_infos; // indexed by output index +}; + +struct SrcNodeInfo { + int32_t inserted_node_id = -1; + int32_t output_index = -1; +}; +bool operator==(const SrcNodeInfo &lhs, const SrcNodeInfo &rhs); +bool operator<(const SrcNodeInfo &lhs, const SrcNodeInfo &rhs); + +struct OrigNodeInfo { + std::string node_name; + int32_t sliced_id = -1; + + std::string Name() const { + return (sliced_id == -1) ? 
node_name : (node_name + kGraphSlicingSuffix + std::to_string(sliced_id)); + } +}; + +bool operator==(const OrigNodeInfo &lhs, const OrigNodeInfo &rhs); +bool operator<(const OrigNodeInfo &lhs, const OrigNodeInfo &rhs); + +struct DstNodeInfo { + OrigNodeInfo orig_node_info; + std::vector input_indexes; + + std::string InputIndexesToString() const { + std::string res; + for (const uint32_t input_index : input_indexes) { + res += std::to_string(input_index) + " "; + } + return res; + } +}; + +bool operator==(const DstNodeInfo &lhs, const DstNodeInfo &rhs); +bool operator<(const DstNodeInfo &lhs, const DstNodeInfo &rhs); + +struct InsertedNodeInput { + SrcNodeInfo input_info; + OrigNodeInfo orig_node_info; +}; + +bool operator==(const InsertedNodeInput &lhs, const InsertedNodeInput &rhs); +bool operator<(const InsertedNodeInput &lhs, const InsertedNodeInput &rhs); + +struct PeerOutNodeInfo { + SrcNodeInfo input_info; + DstNodeInfo node_info; +}; + +bool operator==(const PeerOutNodeInfo &lhs, const PeerOutNodeInfo &rhs); +bool operator<(const PeerOutNodeInfo &lhs, const PeerOutNodeInfo &rhs); + +struct InsertedNodeInfo { + uint32_t id; + CommTask task; + std::vector inputs; +}; + +struct OutputSlicedRes { + std::vector inserted_nodes_info; + std::vector peer_out_nodes; +}; + +struct SlicedEdgeInfo { + std::vector steps_sliced; +}; + +struct TensorShapeSlicedInfo { + std::vector> axis_slices; +}; + +struct NodeSliceStrategy { + std::map input_shape_sliced_info; + std::map output_shape_sliced_info; + + std::map outputs_sliced_edge_info; + std::vector>> dependencies; + size_t size = 1U; +}; + +struct ShardGraphExtAttrs { + // ExtAttr _device_index_to_logic_device_id, key is DeviceIndex, value is logic device id + std::map> dev_index_to_logic_dev_id; + // ExtAttr _model_events, key1 is graph name, key2 is endpoint name, value is serialized endpoints + std::map>> graph_name_to_endpoints; + // ExtAttr _hcomgroups, key is group name, value is device ids + std::map> 
group_name_to_dev_ids; +}; + +class TensorParallelAttrs { + public: + static Status FromJson(const std::string &json_str, DeviceIndex &device_index); + static Status FromJson(const std::string &json_str, ModelIndex &model_index); + static Status FromJson(const std::string &json_str, PipelineConfig &pipeline_config); + static Status FromJson(const std::string &json_str, NodeDeployment &node_deployment); + static Status FromJson(const std::string &json_str, TensorDeployment &tensor_deployment); + static Status FromJson(const std::string &json_str, TensorDeployments &tensor_deployments); + static Status FromJson(const std::string &json_str, NodeDeployments &node_deployments); + static Status FromJson(const std::string &json_str, CommTask &comm_task); + static Status FromJson(const std::string &json_str, CommStep &comm_step); + static Status FromJson(const std::string &json_str, OutputReshardRes &output_reshard_res); + static Status FromJson(const std::string &json_str, ReshardAttr &reshard_attr); + static Status FromJson(const std::string &json_str, ShardGraphExtAttrs &shard_graph_ext_attrs); + + static std::string ToJson(const NodeDeployment &node_deployment); + static std::string ToJson(const DeviceIndex &device_index); + static std::string ToJson(const ModelIndex &model_index); + static std::string ToJson(const PipelineConfig &pipeline_config); + static std::string ToJson(const TensorDeployment &tensor_deployment); + static std::string ToJson(const NodeDeployments &node_deployments); + static std::string ToJson(const ReshardAttr &reshard_attr); + static std::string ToJson(const TensorDeployments &tensor_deployments); + static std::string ToJson(const ShardGraphExtAttrs &shard_graph_ext_attrs); +}; +} // namespace tp +} // namespace ge + +#endif // METADEF_INC_GRAPH_PARALLELISM_TENSOR_PARALLEL_ATTRS_H_ diff --git a/inc/metadef/graph/range_vistor.h b/inc/metadef/graph/range_vistor.h new file mode 100644 index 
0000000000000000000000000000000000000000..2365672a86b0e4ab25a165d22305a5798b0db4df --- /dev/null +++ b/inc/metadef/graph/range_vistor.h @@ -0,0 +1,59 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_RANGE_VISTOR_H_ +#define INC_GRAPH_RANGE_VISTOR_H_ + +#include +#include +#include + +/* RangeVistor keeps its own copy of a sequence of E in the elements_ vector (copied from a std::vector or from the begin/end of a std::list) together with an owner object of type O, and forwards begin/end/rbegin/rend/size/empty/at to that copy. at() forwards to the vector's at(). */ template +class RangeVistor { + public: + using Iterator = typename std::vector::iterator; + using ConstIterator = typename std::vector::const_iterator; + using ReverseIterator = typename std::vector::reverse_iterator; + using ConstReverseIterator = typename std::vector::const_reverse_iterator; + + RangeVistor(const O owner, const std::vector &vs) : owner_(owner), elements_(vs) {} + RangeVistor(const O owner, const std::list &vs) : owner_(owner), elements_(vs.begin(), vs.end()) {} + + ~RangeVistor() {} + + Iterator begin() { return elements_.begin(); } + + Iterator end() { return elements_.end(); } + + ConstIterator begin() const { return elements_.begin(); } + + ConstIterator end() const { return elements_.end(); } + + ReverseIterator rbegin() { return elements_.rbegin(); } + + ReverseIterator rend() { return elements_.rend(); } + + ConstReverseIterator rbegin() const { return elements_.rbegin(); } + + ConstReverseIterator rend() const { return elements_.rend(); } + + std::size_t size() const {
return elements_.size(); } + + bool empty() const { return elements_.empty(); } + + E &at(const std::size_t index) { return elements_.at(index); } + + const E &at(const std::size_t index) const { return elements_.at(index); } + + private: + /* owner_ is stored by value and is not read by any member of this class; presumably it keeps the owner alive while the range is in use - TODO confirm */ O owner_; + std::vector elements_; +}; + +#endif // INC_GRAPH_RANGE_VISTOR_H_ diff --git a/inc/metadef/graph/ref_relation.h b/inc/metadef/graph/ref_relation.h new file mode 100644 index 0000000000000000000000000000000000000000..8973e18c61a6705778262a229d9e2e864853677f --- /dev/null +++ b/inc/metadef/graph/ref_relation.h @@ -0,0 +1,74 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef COMMON_GRAPH_REF_RELATION_H_ +#define COMMON_GRAPH_REF_RELATION_H_ + +#include +#include +#include +#include +#include + +#include "graph/compute_graph.h" +#include "graph/ge_error_codes.h" +#include "node.h" + +namespace ge { +enum InOutFlag { + NODE_IN = 0, // input flag + NODE_OUT = 1, // output flag +}; + +// Once a RefCell object has been created its data members must not be modified (every member is const and copy assignment is deleted). +struct RefCell { + const std::string node_name; + const ge::NodePtr node; + const InOutFlag in_out; + const int32_t in_out_idx; + /* hash_key is built once in the constructor by concatenating node_name, in_out, in_out_idx and the numeric value of the raw node pointer */ const std::string hash_key; + + explicit RefCell(const std::string &name, const ge::NodePtr &node_ptr, const InOutFlag in_out_flag, const int32_t idx) + : node_name(name), node(node_ptr), in_out(in_out_flag), in_out_idx(idx), + hash_key(std::string("") + .append(node_name) + .append(std::to_string(in_out)) + .append(std::to_string(in_out_idx)) + .append(std::to_string(PtrToValue(node.get())))) {} + RefCell(const RefCell &ref_cell) + : node_name(ref_cell.node_name), node(ref_cell.node), in_out(ref_cell.in_out), in_out_idx(ref_cell.in_out_idx), + hash_key(ref_cell.hash_key) {} + ge::RefCell &operator=(const ge::RefCell &ref_cell) = delete; + /* equality compares the four fields that feed hash_key; hash_key itself is not compared */ bool operator == (const RefCell &c) const { + return node_name == c.node_name && node == c.node && in_out == c.in_out && in_out_idx == c.in_out_idx; + } + ~RefCell() = default; +}; + +/* Hash functor for RefCell: hashes the precomputed hash_key string. */ struct RefCellHash{ + size_t operator()(const RefCell &c) const { + return std::hash()(c.hash_key); + } +}; + +/* RefRelations declares lookup, build and clear operations that return graphStatus; its state lives behind the nested Impl class held through impl_. Member function bodies are not visible in this header. */ class RefRelations { + public: + graphStatus LookUpRefRelations(const RefCell &key, std::unordered_set &result); + graphStatus BuildRefRelations(ge::ComputeGraph &graph); + graphStatus Clear(); + + RefRelations(); + ~RefRelations() = default; + private: + class Impl; + std::shared_ptr impl_ = nullptr; +}; + +} // namespace ge +#endif // COMMON_GRAPH_REF_RELATION_H_ diff --git a/inc/metadef/graph/resource_context_mgr.h
b/inc/metadef/graph/resource_context_mgr.h new file mode 100644 index 0000000000000000000000000000000000000000..af0c1fd3896ef136b38a0ee28a1800ee8bddc895 --- /dev/null +++ b/inc/metadef/graph/resource_context_mgr.h @@ -0,0 +1,65 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_RESOURCE_CONTEXT_MGR_H_ +#define INC_GRAPH_RESOURCE_CONTEXT_MGR_H_ + +#include +#include +#include +#include "external/graph/resource_context.h" +#include "graph/ge_error_codes.h" +#include "graph/node.h" +#include "graph/utils/node_utils.h" + +namespace ge { +/* ResourceContextMgr keeps two maps keyed by resource key - one to resource contexts, one to the nodes recorded for the key - plus a mutex member. Method bodies are not visible in this header. */ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ResourceContextMgr { + public: + ResourceContextMgr() = default; + ~ResourceContextMgr() = default; + /** + * Given resource_key, return the corresponding resource pointer + * @param resource_key + * @return corresponding resource pointer + */ + ResourceContext *GetResourceContext(const std::string &resource_key); + /** + * Given resource_key and the corresponding resource pointer, set resource_context with the new resource + * @param resource_key + * @param context + * @return status + */ + graphStatus SetResourceContext(const std::string &resource_key, ResourceContext *const context); + /** + * Given resource_key and a node that relies on this resource, mgr will keep the relation + * @param resource_key + * @param node +
* @return status + */ + graphStatus RegisterNodeReliedOnResource(const std::string &resource_key, NodePtr &node); + /** + * Given resource_key, mgr finds the nodes that rely on this resource. + * @param resource_key + * @return reference to the node set kept for resource_key; + * the return type is a mutable OrderedNodeSet reference, not a status + */ + OrderedNodeSet &MutableNodesReliedOnResource(const std::string &resource_key); + /** + * Resource contexts need to be cleared when the session finalizes + * @return status + */ + graphStatus ClearContext(); + + private: + /* ctx_mu_ is the only mutex member; which methods lock it is decided in the implementation, not visible here */ std::mutex ctx_mu_; + std::map> resource_keys_to_contexts_; + std::map resource_keys_to_read_nodes_; +}; +} // namespace ge +#endif // INC_GRAPH_RESOURCE_CONTEXT_MGR_H_ diff --git a/inc/metadef/graph/runtime_inference_context.h b/inc/metadef/graph/runtime_inference_context.h new file mode 100644 index 0000000000000000000000000000000000000000..f053e0a72f7b5fd2c18ef2654388d2b829a0a4b7 --- /dev/null +++ b/inc/metadef/graph/runtime_inference_context.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_ +#define INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_ + +#include +#include +#include +#include +#include "external/graph/ge_error_codes.h" +#include "external/graph/tensor.h" +#include "ge_attr_value.h" + +namespace ge { +/* RuntimeInferenceContext declares SetTensor and GetTensor, both addressed by a node_id and an output_id and both returning graphStatus, plus Release(). Its state is the ge_tensors_ map and the mutex mu_; mu_ is declared mutable, presumably so the const GetTensor can lock it - the locking itself is in the implementation, not visible here. */ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY RuntimeInferenceContext { + public: + graphStatus SetTensor(int64_t node_id, int32_t output_id, GeTensorPtr tensor); + graphStatus GetTensor(const int64_t node_id, int32_t output_id, GeTensorPtr &tensor) const; + void Release(); + + private: + std::map> ge_tensors_; + mutable std::mutex mu_; +}; +} // namespace ge + +#endif // INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_ diff --git a/inc/metadef/graph/shape_refiner.h b/inc/metadef/graph/shape_refiner.h new file mode 100644 index 0000000000000000000000000000000000000000..ede6922aaca550fe50e289205705eeea4772da2b --- /dev/null +++ b/inc/metadef/graph/shape_refiner.h @@ -0,0 +1,46 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef INC_GRAPH_SHAPE_REFINER_H_ +#define INC_GRAPH_SHAPE_REFINER_H_ + +#include +#include "external/graph/inference_context.h" + +#include "external/graph/ge_error_codes.h" +#include "graph/node.h" +#include "graph/resource_context_mgr.h" + +namespace ge { +// ShapeRefiner performs shape inference for compute graphs. Every member declared here is static and has no body in this header; the overloads differ in whether they take an Operator and/or a before_subgraph flag. +class ShapeRefiner { + public: + static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op, const bool before_subgraph); + static graphStatus InferShapeAndType(const NodePtr &node, const bool before_subgraph); + static graphStatus InferShapeAndType(const NodePtr &node); + static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op); + static graphStatus DoInferShapeAndTypeForRunning(const ConstNodePtr &node, Operator &op, const bool before_subgraph); + static graphStatus InferShapeAndTypeForRunning(const NodePtr &node, Operator &op, const bool before_subgraph); + static void ClearContextMap(); + static graphStatus CreateInferenceContext(const NodePtr &node, + InferenceContextPtr &inference_context); + static graphStatus CreateInferenceContext(const NodePtr &node, + ResourceContextMgr *const resource_context_mgr, + InferenceContextPtr &inference_context); + static void PushToContextMap(const NodePtr &node, const InferenceContextPtr &inference_context); + + private: + static void PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase); + static graphStatus GetRealInNodesAndIndex(NodePtr &input_node, int32_t &output_idx, + std::map &nodes_idx); + static graphStatus PostProcessAfterInfershape(const NodePtr &node, const Operator &op, const bool is_unknown_graph); + static graphStatus UpdateInputOutputDesc(const NodePtr &node); +}; +} // namespace ge +#endif // INC_GRAPH_SHAPE_REFINER_H_ diff --git a/inc/metadef/graph/small_vector.h b/inc/metadef/graph/small_vector.h new file mode
100644 index 0000000000000000000000000000000000000000..cd28671e1ae548c7a2704da6be549dba9ccccdcf --- /dev/null +++ b/inc/metadef/graph/small_vector.h @@ -0,0 +1,518 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_CXX_SMALL_VECTOR_H +#define METADEF_CXX_SMALL_VECTOR_H +#include +#include +#include +#include "graph/def_types.h" + +namespace ge { +/* SmallVector is a vector-like container of T. While at most N elements are needed they live in the in-object buffer inline_storage_; when InitStorage or ExpandCap needs more than that, storage comes from the allocator and is tracked in allocated_storage_ (nullptr means the inline buffer is in use). at() and operator[] both bounds-check and throw std::out_of_range. */ template> +class SmallVector { + public: + using allocator_type = Alloc; + using value_type = T; + using size_type = size_t; + using difference_type = std::ptrdiff_t; + using reference = value_type &; + using const_reference = const value_type &; + using iterator = T *; + using const_iterator = const T *; + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = std::reverse_iterator; + + template + using ValidInputIt = typename std::enable_if< + std::is_convertible::iterator_category, std::input_iterator_tag>::value>::type; + + // constructors and destructor + explicit SmallVector(const allocator_type &alloc = Alloc()) + : size_(0UL), capacity_(N), allocated_storage_(nullptr), allocator_(alloc) {} + + // 2 do not support allocator + SmallVector(const size_type count, const T &value, const allocator_type &alloc = Alloc()) : allocator_(alloc) { + auto const iter = InitStorage(count); + (void)
std::uninitialized_fill_n(iter, size_, value); + } + + explicit SmallVector(const size_type count, const allocator_type &alloc = Alloc()) : allocator_(alloc) { + auto iter = InitStorage(count); + for (size_type i = 0UL; i < size_; ++i) { + new (iter++) T(); + } + } + /* Range constructor: copies [first, last). The guard must reject only a negative distance; the previous `count >= 0` test returned early for every valid range and left size_, capacity_ and allocated_storage_ uninitialized. */ template> + SmallVector(InputIt first, const InputIt last, const allocator_type &alloc = Alloc()) : allocator_(alloc) { + const auto count = std::distance(first, last); + if (count < 0) { /* reversed range: become a valid empty vector so the destructor sees initialized members */ (void) InitStorage(0UL); + return; + } + auto const iter = InitStorage(static_cast(count)); + CopyRange(iter, first, last); + } + SmallVector(const SmallVector &other) { + allocator_ = other.allocator_; + auto const iter = InitStorage(other.size_); + CopyRange(iter, other.begin(), other.end()); + } + // 7 do not support allocator + SmallVector(SmallVector &&other) noexcept { + MoveFrom(other); + } + // 9 do not support allocator + SmallVector(const std::initializer_list init, const allocator_type &alloc = Alloc()) : allocator_(alloc) { + auto const iter = InitStorage(init.size()); + CopyRange(iter, init.begin(), init.end()); + } + ~SmallVector() { + clear(); + } + + // operator= + SmallVector &operator=(const SmallVector &other) { + if (this != &other) { + allocator_ = other.allocator_; + assign(other.begin(), other.end()); + } + return *this; + } + SmallVector &operator=(SmallVector &&other) noexcept { + if (this != &other) { + clear(); + MoveFrom(other); + } + return *this; + } + SmallVector &operator=(const std::initializer_list ilist) { + assign(ilist.begin(), ilist.end()); + return *this; + } + + // assign + void assign(const size_type count, const T &value) { + auto iter = ClearElements(); + if (capacity_ < count) { + FreeStorage(); + iter = InitStorage(count); + } else { + size_ = count; + } + (void) std::uninitialized_fill_n(iter, count, value); + } + template> + void assign(InputIt first, const InputIt last) { + const auto dist = std::distance(first, last); + AssertNonNeg(dist); + const auto count = static_cast(dist); + auto iter =
ClearElements(); + if (capacity_ < count) { + FreeStorage(); + iter = InitStorage(count); + } else { + size_ = count; + } + CopyRange(iter, first, last); + } + void assign(const std::initializer_list ilist) { + assign(ilist.begin(), ilist.end()); + } + + reference at(const size_type index) { + CheckOutOfRange(index); + return *GetPointer(index); + } + const_reference at(const size_type index) const { + CheckOutOfRange(index); + return *GetPointer(index); + } + + reference operator[](const size_type index) { + return at(index); + } + const_reference operator[](const size_type index) const { + return at(index); + } + + reference front() { + return *begin(); + } + const_reference front() const { + return *begin(); + } + reference back() { + return *(rbegin()); + } + const_reference back() const { + return *(rbegin()); + } + T *data() noexcept { + return GetPointer(); + } + const T *data() const noexcept { + return GetPointer(); + } + + iterator begin() noexcept { + return GetPointer(); + } + const_iterator begin() const noexcept { + return GetPointer(); + } + const_iterator cbegin() const noexcept { + return GetPointer(); + } + iterator end() noexcept { + return GetPointer(size_); + } + const_iterator end() const noexcept { + return GetPointer(size_); + } + const_iterator cend() const noexcept { + return GetPointer(size_); + } + reverse_iterator rbegin() noexcept { + return reverse_iterator(end()); + } + const_reverse_iterator rbegin() const noexcept { + return const_reverse_iterator(end()); + } + const_reverse_iterator crbegin() const noexcept { + return const_reverse_iterator(end()); + } + reverse_iterator rend() noexcept { + return reverse_iterator(begin()); + } + const_reverse_iterator rend() const noexcept { + return const_reverse_iterator(begin()); + } + const_reverse_iterator crend() const noexcept { + return const_reverse_iterator(begin()); + } + + bool empty() const noexcept { + return size_ == 0UL; + } + size_type size() const noexcept { + return size_; + }
+ // do not support `max_size` now + void reserve(const size_type new_cap) { + if (new_cap > capacity()) { + (void) ExpandCap(size(), new_cap - size()); + } + } + size_type capacity() const noexcept { + return capacity_; + } + // do not support `shrink_to_fit` now + + void clear() noexcept { + T *addr = begin(); + while (addr != end()) { + allocator_.destroy(addr++); + } + FreeStorage(); + capacity_ = N; + size_ = 0UL; + } + iterator insert(const_iterator const pos, const T &value) { + return emplace(pos, value); + } + iterator insert(const_iterator const pos, T &&value) { + return emplace(pos, std::move(value)); + } + iterator insert(const_iterator const pos, const size_type count, const T &value) { + const auto index = static_cast(std::distance(cbegin(), pos)); + auto const iter = Expand(index, count); + (void) std::uninitialized_fill_n(iter, count, value); + + return iter; + } + + template> + iterator insert(const_iterator const pos, const InputIt first, const InputIt last) { + const auto count = std::distance(first, last); + AssertNonNeg(count); + const auto index = static_cast(std::distance(cbegin(), pos)); + auto const iter = Expand(index, static_cast(count)); + CopyRange(iter, first, last); + return iter; + } + + iterator insert(const_iterator const pos, const std::initializer_list value_list) { + return insert(pos, value_list.begin(), value_list.end()); + } + template + iterator emplace(const_iterator const pos, Args &&...args) { + const auto index = static_cast(std::distance(cbegin(), pos)); + auto const iter = Expand(index, 1UL); + allocator_.construct(iter, std::forward(args)...); + + return iter; + } + iterator erase(const_iterator const pos) { + const auto index = static_cast(std::distance(cbegin(), pos)); + if (pos != cend()) { + Shrink(index, index + 1UL); + } + return GetPointer(index); + } + iterator erase(const_iterator const first, const_iterator const last) { + const auto first_pos = static_cast(std::distance(cbegin(), first)); + if (first !=
last) { + Shrink(first_pos, static_cast(std::distance(cbegin(), last))); + } + return GetPointer(first_pos); + } + void push_back(const T &value) { + auto const iter = Expand(size_, 1UL); + allocator_.construct(iter, value); + } + void push_back(T &&value) { + auto const iter = Expand(size_, 1UL); + allocator_.construct(iter, std::move(value)); + } + template + void emplace_back(Args &&...args) { + auto const iter = Expand(size_, 1UL); + allocator_.construct(iter, std::forward(args)...); + } + void pop_back() { + Shrink(size_ - 1, size_); + } + void resize(const size_type count) { + if (count < size_) { + Shrink(count, size_); + } else { + const auto expand_size = count - size_; + auto iter = Expand(size_, expand_size); + for (size_type i = 0UL; i < expand_size; ++i) { + allocator_.construct(iter++); + } + } + } + void resize(const size_type count, const T &value) { + if (count < size_) { + Shrink(count, size_); + } else { + const auto expand_size = count - size_; + auto const iter = Expand(size_, expand_size); + (void) std::uninitialized_fill_n(iter, expand_size, value); + } + } + + /** + * In the STL, swap never calls the elements' copy constructor, move constructor or swap function; this is where this class differs from the standard library. + * In SmallVector, swap "may" call the elements' move constructor. + * @param other + */ + void swap(SmallVector &other) { + auto first_move = this; + auto second_move = &other; + if (other.capacity() > N) { + first_move = &other; + second_move = this; + } + SmallVector tmp; + tmp.MoveFrom(*first_move); + first_move->MoveFrom(*second_move); + second_move->MoveFrom(tmp); + } + + private: + T *GetPointer(const size_type idx = 0UL) { + auto const base = (allocated_storage_ == nullptr) ? PtrToPtr(&inline_storage_) : allocated_storage_; + return base + idx; + } + const T *GetPointer(const size_type idx = 0UL) const { + auto const base = (allocated_storage_ == nullptr) ?
PtrToPtr(&inline_storage_) : allocated_storage_; + return base + idx; + } + + iterator InitStorage(const size_type size) { + size_ = size; + if (size_ > N) { + capacity_ = size_; + allocated_storage_ = allocator_.allocate(capacity_); + if (allocated_storage_ == nullptr) { + throw std::bad_alloc(); + } + return allocated_storage_; + } else { + capacity_ = N; + allocated_storage_ = nullptr; + return PtrToPtr(&inline_storage_); + } + } + void FreeStorage() { + if (allocated_storage_ != nullptr) { + allocator_.deallocate(allocated_storage_, capacity_); + allocated_storage_ = nullptr; + } + } + + iterator ClearElements() { + T *addr = GetPointer(); + while (addr != end()) { + allocator_.destroy(addr++); + } + return GetPointer(); + } + template> + static void CopyRange(T *iter, InputIt first, const InputIt last) { + while (first != last) { + new (iter++) T(*first++); + } + } + void MoveFrom(SmallVector &other) noexcept { + size_ = other.size_; + capacity_ = other.capacity_; + allocator_ = other.allocator_; + if (other.allocated_storage_ != nullptr) { + allocated_storage_ = other.allocated_storage_; + } else { + auto addr = PtrToPtr(&inline_storage_); + auto other_addr = other.GetPointer(); + for (size_type i = 0UL; i < size_; ++i) { + allocator_.construct(addr++, std::move(*other_addr)); + allocator_.destroy(other_addr++); + } + allocated_storage_ = nullptr; + } + + (void) other.InitStorage(0UL); + } + void CheckOutOfRange(const size_type index) const { + if (index >= size_) { + throw std::out_of_range("Index out of range"); + } + } + static void AssertNonNeg(const difference_type value) { + if (value < 0) { + throw std::range_error("The first iter is greater than the last"); + } + } + + iterator ExpandCap(const size_type range_begin, const size_type range_len) { + const auto new_cap = std::max(capacity_ * static_cast(2), size_ + range_len); + auto const new_storage = allocator_.allocate(new_cap); + if (new_storage == nullptr) { + throw std::bad_alloc(); + } + auto
const old_storage = GetPointer(); + auto new_ptr = new_storage; + auto old_ptr = old_storage; + for (size_type i = 0UL; i < range_begin; ++i) { + allocator_.construct(new_ptr++, std::move(*old_ptr)); + allocator_.destroy(old_ptr++); + } + + new_ptr = PtrAdd(new_ptr, new_cap + 1UL, range_len); + for (size_type i = range_begin; i < size_; ++i) { + allocator_.construct(new_ptr++, std::move(*old_ptr)); + allocator_.destroy(old_ptr++); + } + + FreeStorage(); + allocated_storage_ = new_storage; + capacity_ = new_cap; + return new_storage + range_begin; + } + iterator ExpandSize(const size_type range_begin, const size_type range_len) { + auto const begin_storage = GetPointer(range_begin); + auto old_end = GetPointer(size_ - 1UL); + auto new_end = GetPointer(size_ + range_len - 1UL); + for (size_type i = size_; i > range_begin; --i) { + allocator_.construct(new_end--, std::move(*old_end)); + allocator_.destroy(old_end--); + } + size_ += range_len; + return begin_storage; + } + iterator Expand(const size_type range_begin, const size_type range_len) { + if ((range_len + size_) > capacity_) { + auto const ret = ExpandCap(range_begin, range_len); + size_ += range_len; + return ret; + } else { + return ExpandSize(range_begin, range_len); + } + } + void Shrink(const size_type range_begin, const size_type range_end) { + T *old_ptr = GetPointer(range_begin); + for (size_type i = range_begin; i < range_end; ++i) { + allocator_.destroy(old_ptr++); + } + size_type new_size = range_begin; + T *new_ptr = GetPointer(range_begin); + for (size_type i = range_end; i < size_; ++i) { + allocator_.construct(new_ptr++, std::move(*old_ptr)); + allocator_.destroy(old_ptr++); + ++new_size; + } + size_ = new_size; + } + + using InlineT = typename std::aligned_storage::type; + InlineT inline_storage_; + size_type size_; + size_type capacity_; + T *allocated_storage_; + allocator_type allocator_; +}; + +template> +bool operator==(const ge::SmallVector &sv1, const ge::SmallVector &sv2) { + if (N1 !=
N2) { + // This is debatable: even when N differs, the size and the contents can still be exactly the same + return false; + } + if (sv1.size() != sv2.size()) { + return false; + } + for (size_t i = 0UL; i < sv1.size(); ++i) { + if (sv1[i] != sv2[i]) { + return false; + } + } + return true; +} + +template> +bool operator!=(const ge::SmallVector &sv1, const ge::SmallVector &sv2) { + return !(sv1 == sv2); +} +template> +bool operator<(const ge::SmallVector &sv1, const ge::SmallVector &sv2) { + return std::lexicographical_compare(sv1.begin(), sv1.end(), sv2.begin(), sv2.end()); +} +template> +bool operator>(const ge::SmallVector &sv1, const ge::SmallVector &sv2) { + return std::lexicographical_compare(sv2.begin(), sv2.end(), sv1.begin(), sv1.end()); +} +template> +bool operator<=(const ge::SmallVector &sv1, const ge::SmallVector &sv2) { + return !(sv1 > sv2); +} +template> +bool operator>=(const ge::SmallVector &sv1, const ge::SmallVector &sv2) { + return !(sv1 < sv2); +} +} // namespace ge + +namespace std { +template> +void swap(ge::SmallVector &sv1, ge::SmallVector &sv2) { + sv1.swap(sv2); +} +} // namespace std + +#endif // METADEF_CXX_SMALL_VECTOR_H diff --git a/inc/metadef/graph/tuning_utils.h b/inc/metadef/graph/tuning_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..d61a1ea57d2c7fae34347f4bcd23867a5255f00c --- /dev/null +++ b/inc/metadef/graph/tuning_utils.h @@ -0,0 +1,158 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef MAIN_TUNING_UTILS_H +#define MAIN_TUNING_UTILS_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common/ge_common/debug/ge_log.h" +#include "utils/attr_utils.h" +#include "utils/node_utils.h" +#include "external/ge_common/ge_api_types.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" +namespace ge { +// Configure build mode, default value is "normal" +constexpr char_t BUILD_MODE[] = "ge.buildMode"; +constexpr char_t BUILD_STEP[] = "ge.buildStep"; +// Configure tuning path +constexpr char_t TUNING_PATH[] = "ge.tuningPath"; +// for interface: aclgrphBuildModel +extern const std::set ir_builder_supported_options_for_lx_fusion; + +// Build model +constexpr char_t BUILD_MODE_NORMAL[] = "normal"; +constexpr char_t BUILD_MODE_TUNING[] = "tuning"; +constexpr char_t BUILD_MODE_BASELINE[] = "baseline"; +constexpr char_t BUILD_MODE_OPAT_RESULT[] = "opat_result"; +extern const std::set build_mode_options; + +// Build step +constexpr char_t BUILD_STEP_BEFORE_UB_MATCH[] = "before_ub_match"; +constexpr char_t BUILD_STEP_AFTER_UB_MATCH[] = "after_ub_match"; +constexpr char_t BUILD_STEP_AFTER_BUILDER[] = "after_builder"; +constexpr char_t BUILD_STEP_AFTER_BUILDER_SUB[] = "after_builder_sub"; +constexpr char_t BUILD_STEP_AFTER_MERGE[] = "after_merge"; +constexpr char_t BUILD_STEP_BEFORE_BUILD[] = "before_build"; +constexpr char_t BUILD_STEP_AFTER_BUILD[] = "after_build"; +extern const std::set build_step_options; + +using SubgraphCreateOutNode = std::unordered_map; +using NodetoNodeMap = std::unordered_map; +using NodeVec = std::vector; +using NodeNametoNodeNameMap = std::map; +using NodetoNodeNameMap = std::unordered_map; +/* TuningUtils exposes three public static entry points (ConvertGraphToFile, ConvertFileToGraph, LinkSubgraph); everything else is private static helpers and private static state. Only declarations are visible in this header. */ class TuningUtils {
+ public: + TuningUtils() = default; + ~TuningUtils() = default; + // Dump all the subgraphs and modify + // the subgraphs in them to be executable subgraphs if exe_flag is true + // `tuning_path` (presumably the `path` parameter below - TODO confirm) means path to save the graphs + static graphStatus ConvertGraphToFile(std::vector tuning_subgraphs, + std::vector non_tuning_subgraphs = {}, + const bool exe_flag = false, + const std::string &path = "", + const std::string &user_path = ""); + // Recover `graph` from graph dump files configured in options + static graphStatus ConvertFileToGraph(const std::map &options, ge::Graph &graph); + + static graphStatus LinkSubgraph(ComputeGraphPtr &root_graph, const ComputeGraphPtr &graph, + const std::map &name_to_merged_subgraph); + +private: + // part 1 + class HelpInfo { /* The constructor and destructor sit under the class-default private access, so only the befriended TuningUtils can create a HelpInfo. NOTE(review): path_ and user_path_ are reference members bound to the constructor arguments, so the strings passed in must outlive the HelpInfo. */ + HelpInfo(const int64_t index, const bool exe_flag, const bool is_tuning_graph, const std::string &path, + const std::string &user_path) : index_(index), + exe_flag_(exe_flag), + is_tuning_graph_(is_tuning_graph), + path_(path), + user_path_(user_path) {} + ~HelpInfo() = default; + private: + int64_t index_; + bool exe_flag_; + bool is_tuning_graph_; + const std::string &path_; + const std::string &user_path_; + bool need_preprocess_ = false; + friend class TuningUtils; + }; + static graphStatus MakeExeGraph(ComputeGraphPtr &exe_graph, + const HelpInfo& help_info); + static graphStatus ConvertConstToWeightAttr(const ComputeGraphPtr &exe_graph); + static graphStatus SetFileConstInfo(const NodePtr &node, const GeTensorPtr &tensor, const std::string &aoe_path, + const OpDescPtr &op_desc); + static graphStatus HandlePld(NodePtr &node, const std::string &aoe_path); + static graphStatus HandleConst(NodePtr &node, const std::string &aoe_path); + static graphStatus PreProcessNode(const NodePtr &node); + static graphStatus HandleEnd(NodePtr &node); + static graphStatus ChangePld2Data(const NodePtr &node, const NodePtr &data_node); + static graphStatus ChangeEnd2NetOutput(NodePtr &end_node, NodePtr &out_node); +
static graphStatus LinkEnd2NetOutput(NodePtr &end_node, NodePtr &out_node); + static graphStatus CreateDataNode(NodePtr &node, const std::string &aoe_path, NodePtr &data_node); + static graphStatus CreateNetOutput(const NodePtr &node, NodePtr &out_node); + static graphStatus AddAttrToDataNodeForMergeGraph(const NodePtr &pld, const NodePtr &data_node); + static graphStatus AddAttrToNetOutputForMergeGraph(const NodePtr &end, const NodePtr &out_node, const int64_t index); + static void DumpGraphToPath(const ComputeGraphPtr &exe_graph, const int64_t index, + const bool is_tuning_graph, std::string path); + static void TryGetWeight(const NodePtr &node, std::vector &weight); + + static SubgraphCreateOutNode create_output_; + // part 2 + static graphStatus MergeGraph(const std::vector &subgraphs, + ComputeGraphPtr &output_merged_compute_graph); + static graphStatus MergeAllSubGraph(const std::vector &subgraphs, + ComputeGraphPtr &output_merged_compute_graph); + static graphStatus MergeSubGraph(const ComputeGraphPtr &subgraph); + // Deletes new data and output nodes added by call `MakeExeGraph()` func in part 1 + static graphStatus RemoveDataNetoutputEdge(ComputeGraphPtr &graph); + static NodePtr FindNode(const std::string &name, int64_t &in_index); + static graphStatus LoadGraphFromFile(const std::map &options, + std::vector &root_graphs, + std::map> &name_to_subgraphs); + static NodeNametoNodeNameMap data_2_end_; + static NodetoNodeNameMap data_node_2_end_node_; + static NodetoNodeMap data_node_2_netoutput_node_; + static NodeVec netoutput_nodes_; + static NodeVec merged_graph_nodes_; + static std::mutex mutex_; + static std::set reusable_weight_files_; + static std::map name_to_index_; + static std::map> hash_to_files_; + // for debug + static std::string PrintCheckLog(); + static std::string GetNodeNameByAnchor(const Anchor * const anchor); + static std::string GenerateFileConstPath(const std::string &aoe_path, const OpDescPtr &op_desc); + static Status
GetOrSaveReusableFileConst(const GeTensorPtr &tensor, std::string &file_path); + static Status CheckFilesSame(const std::string &file_name, const char_t *const data, const size_t data_length, + bool &is_content_same); +}; +} +#endif // MAIN_TUNING_UTILS_H diff --git a/inc/metadef/graph/type_utils.h b/inc/metadef/graph/type_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..eadaac8c53bef173eb1aaaad0d2c1db3e9fb3bef --- /dev/null +++ b/inc/metadef/graph/type_utils.h @@ -0,0 +1,112 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef EXECUTE_GRAPH_TYPE_UTILS_H +#define EXECUTE_GRAPH_TYPE_UTILS_H +#include +#include "graph/types.h" + +namespace ge { +class GeTensor; +class GeTensorDesc; +class Buffer; +class NamedAttrs; +namespace proto { +class GraphDef; +} + +template +struct GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeIdHolder { + static char_t id; +}; + +template char_t TypeIdHolder::id = static_cast(0); + +using TypeId = void *; +constexpr TypeId kInvalidTypeId = nullptr; + +template +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId(const T &obj) { + (void)obj; + return GetTypeId(); +} + +template +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() { + using PureT = typename std::remove_cv::type>::type; + return &(TypeIdHolder::id); +} + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() ; + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() ; + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() ; + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() ; + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() ; + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() ; + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() ; + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() ; + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() ; + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId>>() ; + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId() ; + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId>>() ; + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId>() ; + +template<> 
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId>() ; + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId>() ; + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId>() ; + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId>() ; + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId>() ; + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId>() ; + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId>() ; + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId>() ; + +template<> +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TypeId GetTypeId>() ; +} // namespace ge +#endif // EXECUTE_GRAPH_TYPE_UTILS_H diff --git a/inc/metadef/graph/utils/anchor_utils.h b/inc/metadef/graph/utils/anchor_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..2c63a98e8dece1db80314762346e642cdfaa8edc --- /dev/null +++ b/inc/metadef/graph/utils/anchor_utils.h @@ -0,0 +1,28 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_GRAPH_UTILS_ANCHOR_UTILS_H_ +#define INC_GRAPH_UTILS_ANCHOR_UTILS_H_ + +#include "graph/anchor.h" +#include "graph/node.h" + +namespace ge { +class AnchorUtils { + public: + // Get anchor status + static AnchorStatus GetStatus(const DataAnchorPtr &data_anchor); + + // Set anchor status + static graphStatus SetStatus(const DataAnchorPtr &data_anchor, const AnchorStatus anchor_status); + + static int32_t GetIdx(const AnchorPtr &anchor); +}; +} // namespace ge +#endif // INC_GRAPH_UTILS_ANCHOR_UTILS_H_ diff --git a/inc/metadef/graph/utils/attr_utils.h b/inc/metadef/graph/utils/attr_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..868b3f2ff898854b38ada39b900d9d9717b30f6a --- /dev/null +++ b/inc/metadef/graph/utils/attr_utils.h @@ -0,0 +1,174 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_GRAPH_UTILS_ATTR_UTILS_H_ +#define INC_GRAPH_UTILS_ATTR_UTILS_H_ + +#include "graph/detail/attributes_holder.h" +#include "graph/ge_attr_value.h" +#include "graph/op_desc.h" + +/*lint -e148*/ +namespace ge { +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils { + public: + class ConstAttrHolderAdapter; + class AttrHolderAdapter; + // Set + static bool HasAttr(ConstAttrHolderAdapter &&obj, const std::string &name); + + static bool SetInt(AttrHolderAdapter &&obj, const std::string &name, const int64_t &value); + static bool SetListInt(AttrHolderAdapter &&obj, const std::string &name, const std::vector &value); + static bool SetListInt(AttrHolderAdapter &&obj, const std::string &name, const std::vector &value); + static bool SetListInt(AttrHolderAdapter &&obj, const std::string &name, const std::vector &value); + static bool SetListInt(AttrHolderAdapter &&obj, const std::string &name, std::initializer_list &&value); + + static bool SetFloat(AttrHolderAdapter &&obj, const std::string &name, const float32_t &value); + static bool SetListFloat(AttrHolderAdapter &&obj, const std::string &name, const std::vector &value); + static bool SetBool(AttrHolderAdapter &&obj, const std::string &name, const bool &value); + static bool SetListBool(AttrHolderAdapter &&obj, const std::string &name, const std::vector &value); + static bool SetStr(AttrHolderAdapter &&obj, const std::string &name, const std::string &value); + static bool SetListStr(AttrHolderAdapter &&obj, const std::string &name, const std::vector &value); + static bool SetTensorDesc(AttrHolderAdapter &&obj, const std::string &name, const GeTensorDesc &value); + static bool SetListTensorDesc(AttrHolderAdapter &&obj, const std::string &name, + const std::vector &value); + static bool SetTensor(AttrHolderAdapter &&obj, const std::string &name, const GeTensorPtr &value); + 
static bool SetTensor(AttrHolderAdapter &&obj, const std::string &name, const ConstGeTensorPtr &value); + static bool SetTensor(AttrHolderAdapter &&obj, const std::string &name, const GeTensor &value); + static bool SetShareTensor(AttrUtils::AttrHolderAdapter &&obj, const std::string &name, const GeTensor &value); + static bool SetListTensor(AttrHolderAdapter &&obj, const std::string &name, const std::vector &value); + static bool SetListTensor(AttrHolderAdapter &&obj, const std::string &name, + const std::vector &value); + static bool SetListTensor(AttrHolderAdapter &&obj, const std::string &name, + std::initializer_list &&value); + static bool SetListTensor(AttrHolderAdapter &&obj, const std::string &name, const std::vector &value); + static bool SetGraph(AttrHolderAdapter &&obj, const std::string &name, const ComputeGraphPtr &value); + static bool SetListGraph(AttrHolderAdapter &&obj, const std::string &name, const std::vector &value); + static bool SetBytes(AttrHolderAdapter &&obj, const std::string &name, const Buffer &value); + static bool SetListBytes(AttrHolderAdapter &&obj, const std::string &name, const std::vector &value); + static bool SetNamedAttrs(AttrHolderAdapter &&obj, const std::string &name, const NamedAttrs &value); + static bool SetListNamedAttrs(AttrHolderAdapter &&obj, const std::string &name, + const std::vector &value); + + // Get + static bool GetInt(ConstAttrHolderAdapter &&obj, const std::string &name, int64_t &value); + static bool GetInt(ConstAttrHolderAdapter &&obj, const std::string &name, uint64_t &value); + static bool GetInt(ConstAttrHolderAdapter &&obj, const std::string &name, int32_t &value); + static bool GetInt(ConstAttrHolderAdapter &&obj, const std::string &name, uint32_t &value); + static bool GetListInt(ConstAttrHolderAdapter &&obj, const std::string &name, std::vector &value); + static bool GetListInt(ConstAttrHolderAdapter &&obj, const std::string &name, std::vector &value); + static bool 
GetListInt(ConstAttrHolderAdapter &&obj, const std::string &name, std::vector &value); + static bool GetFloat(ConstAttrHolderAdapter &&obj, const std::string &name, float32_t &value); + static bool GetListFloat(ConstAttrHolderAdapter &&obj, const std::string &name, std::vector &value); + static bool GetBool(ConstAttrHolderAdapter &&obj, const std::string &name, bool &value); + static bool GetListBool(ConstAttrHolderAdapter &&obj, const std::string &name, std::vector &value); + static bool GetStr(ConstAttrHolderAdapter &&obj, const std::string &name, std::string &value); + static bool GetStr(ConstAttrHolderAdapter &&obj, const std::string &name1, const std::string &name2, + std::string &value); + static const std::string *GetStr(ConstAttrHolderAdapter &&obj, const std::string &name); + static bool GetListStr(ConstAttrHolderAdapter &&obj, const std::string &name, std::vector &value); + static bool GetTensorDesc(ConstAttrHolderAdapter &&obj, const std::string &name, GeTensorDesc &value); + static bool GetListTensorDesc(ConstAttrHolderAdapter &&obj, const std::string &name, + std::vector &value); + static bool GetTensor(ConstAttrHolderAdapter &&obj, const std::string &name, ConstGeTensorPtr &value); + static bool MutableTensor(AttrHolderAdapter &&obj, const std::string &name, GeTensorPtr &value); + static bool GetListTensor(ConstAttrHolderAdapter &&obj, const std::string &name, + std::vector &value); + static bool MutableListTensor(AttrHolderAdapter &&obj, const std::string &name, std::vector &value); + static bool GetGraph(ConstAttrHolderAdapter &&obj, const std::string &name, ComputeGraphPtr &value); + static bool GetListGraph(ConstAttrHolderAdapter &&obj, const std::string &name, std::vector &value); + static bool GetBytes(ConstAttrHolderAdapter &&obj, const std::string &name, Buffer &value); + static bool GetListBytes(ConstAttrHolderAdapter &&obj, const std::string &name, std::vector &value); + static bool GetNamedAttrs(ConstAttrHolderAdapter &&obj, const 
std::string &name, NamedAttrs &value); + static bool GetListNamedAttrs(ConstAttrHolderAdapter &&obj, const std::string &name, std::vector &value); + // Value will be moved + static bool SetZeroCopyBytes(AttrHolderAdapter &&obj, const std::string &name, Buffer &&buffer); + static bool GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, const std::string &name, Buffer &buffer); + // Value will be moved + static bool SetZeroCopyListBytes(AttrHolderAdapter &&obj, const std::string &name, + std::vector &list_buffer); + static bool GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const std::string &name, + std::vector &list_buffer); + + static bool SetListListInt(AttrHolderAdapter &&obj, const std::string &name, + const std::vector> &value); + static bool GetListListInt(ConstAttrHolderAdapter &&obj, const std::string &name, + std::vector> &value); + + static bool SetListListFloat(AttrHolderAdapter &&obj, const std::string &name, + const std::vector> &value); + static bool GetListListFloat(ConstAttrHolderAdapter &&obj, const std::string &name, + std::vector> &value); + + static bool SetListDataType(AttrHolderAdapter &&obj, const std::string &name, const std::vector &value); + static bool GetListDataType(ConstAttrHolderAdapter &&obj, const std::string &name, std::vector &value); + + static bool SetDataType(AttrHolderAdapter &&obj, const std::string &name, const ge::DataType &value); + static bool GetDataType(ConstAttrHolderAdapter &&obj, const std::string &name, ge::DataType &value); + + static OpDescPtr CloneOpDesc(const ConstOpDescPtr &org_op_desc); + + static OpDescPtr CopyOpDesc(const ConstOpDescPtr &org_op_desc); + static std::string GetAllAttrsStr(ConstAttrHolderAdapter &&obj); + static std::string GetAllAttrsStr(const std::map &attr_map); + static std::map GetAllAttrs(ConstAttrHolderAdapter &&obj); + static std::map GetAllAttrsWithFilter(ConstAttrHolderAdapter &&obj, + const AttrNameFilter &attr_filter); + static std::string GetAttrsStrAfterRid(ConstAttrHolderAdapter 
&&obj, const std::set &un_compute_attrs); + static bool ClearAllAttrs(AttrHolderAdapter &&obj); + static std::string ValueTypeToSerialString(const AnyValue::ValueType value_type); + static AnyValue::ValueType SerialStringToValueType(const std::string &value_type_string); + + class AttrHolderAdapter { + public: + AttrHolderAdapter(AttrHolder *const obj) : obj_(obj) {} + ~AttrHolderAdapter() {} + template + AttrHolderAdapter(const std::shared_ptr &obj) : obj_(obj.get()) {} + AttrHolderAdapter(const AttrHolderAdapter &obj) : obj_(obj.get()) {} + AttrHolderAdapter &operator=(const AttrHolderAdapter &obj) { + if (&obj != this) { + obj_ = obj.get(); + } + return *this; + } + AttrHolderAdapter(AttrHolder &obj) : obj_(&obj) {} + operator bool() const { return obj_ != nullptr; } + AttrHolder *operator->() const { return obj_; } + AttrHolder *get() const { return obj_; } + + private: + AttrHolder *obj_; + }; + + class ConstAttrHolderAdapter { + public: + ConstAttrHolderAdapter(const AttrHolder *const obj) : obj_(obj) {} + ~ConstAttrHolderAdapter() {} + template + ConstAttrHolderAdapter(const std::shared_ptr obj) : obj_(obj.get()) {} + ConstAttrHolderAdapter(const ConstAttrHolderAdapter &obj) : obj_(obj.get()) {} + ConstAttrHolderAdapter &operator=(const ConstAttrHolderAdapter &obj) { + if (&obj != this) { + obj_ = obj.get(); + } + return *this; + } + ConstAttrHolderAdapter(const AttrHolder &obj) : obj_(&obj) {} + operator bool() const { return obj_ != nullptr; } + const AttrHolder *operator->() const { return obj_; } + const AttrHolder *get() const { return obj_; } + + private: + const AttrHolder *obj_; + }; +}; +} // namespace ge +/*lint +e148*/ +#endif // INC_GRAPH_UTILS_ATTR_UTILS_H_ diff --git a/inc/metadef/graph/utils/connection_matrix.h b/inc/metadef/graph/utils/connection_matrix.h new file mode 100644 index 0000000000000000000000000000000000000000..bbcfd2979c0bf2acabd189e3fd2a3a80847f812e --- /dev/null +++ b/inc/metadef/graph/utils/connection_matrix.h @@ -0,0 +1,50 @@ 
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef GRAPH_CONNECTION_MATRIX_H_ +#define GRAPH_CONNECTION_MATRIX_H_ + +#include "graph/node.h" +#include "graph/graph.h" +#include "graph/compute_graph.h" + +namespace ge { +class ConnectionMatrixImpl; +using ConnectionMatrixImplPtr = std::shared_ptr; + +class ConnectionMatrix { +public: + explicit ConnectionMatrix(const ComputeGraphPtr &graph); + ~ConnectionMatrix() = default; + + bool IsConnected(const NodePtr &a, const NodePtr &b) const; + + // inputs are all input nodes of parameter node. + // if there is a path between A->B, then B will own A's + // connectivity. The reason is --- + // If some node can reach A, than it can also reach B. + void SetConnectivity(const Node::Vistor &inputs, const NodePtr &node); + + /* Computes the connectivity between two nodes in the + * computation. The returned ConnectivityMatrix is constructed such that + * ConnectivityMatrix::IsConnected(a, b) returns true iff there exists a + * directed path (from producer to consumer) from 'a' to 'b'. Both data + * connection and control connection are considered for connectivity. + * A node is connected to itself. */ + graphStatus Generate(const ComputeGraphPtr &graph); + + // update reachablity map for fused nodes. 
+ void Update(const ComputeGraphPtr &graph, const std::vector &fusion_nodes); + + void ExpandAndUpdate(const vector &fusion_nodes, const std::string &node_name); +private: + ConnectionMatrixImplPtr impl_{nullptr}; +}; +} +#endif // GRAPH_CONNECTION_MATRIX_H_ diff --git a/inc/metadef/graph/utils/constant_utils.h b/inc/metadef/graph/utils/constant_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..fdf423af7d4b66dc6f5781d30e69837404ccce2e --- /dev/null +++ b/inc/metadef/graph/utils/constant_utils.h @@ -0,0 +1,40 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef COMMON_GRAPH_UTILS_CONSTANT_UTILS_H_ +#define COMMON_GRAPH_UTILS_CONSTANT_UTILS_H_ +#include "graph/node.h" +#include "graph/op_desc.h" + +namespace ge { +class ConstantUtils { + public: + // check is constant + static bool IsConstant(const NodePtr &node); + static bool IsConstant(const OpDescPtr &op_desc); + static bool IsPotentialConst(const OpDescPtr &op_desc); + static bool IsRealConst(const OpDescPtr &op_desc); + // get/set weight + static bool GetWeight(const OpDescPtr &op_desc, const uint32_t index, ConstGeTensorPtr &weight); + static bool MutableWeight(const OpDescPtr &op_desc, const uint32_t index, GeTensorPtr &weight); + static bool SetWeight(const OpDescPtr &op_desc, const uint32_t index, const GeTensorPtr weight); + static bool MarkPotentialConst(const OpDescPtr &op_desc, const std::vector indices, + const std::vector weights); + static bool UnMarkPotentialConst(const OpDescPtr &op_desc); + // for fileconstant + static bool GetWeightFromFile(const OpDescPtr &op_desc, ConstGeTensorPtr &weight); + private: + static bool GetPotentialWeight(const OpDescPtr &op_desc, std::vector &weight_indices, + std::vector &weights); + static bool MutablePotentialWeight(const OpDescPtr &op_desc, std::vector &weight_indices, + std::vector &weights); +}; +} + +#endif // COMMON_GRAPH_UTILS_CONSTANT_UTILS_H_ diff --git a/inc/metadef/graph/utils/cycle_detector.h b/inc/metadef/graph/utils/cycle_detector.h new file mode 100644 index 0000000000000000000000000000000000000000..5bc86e68aef16ae275bbc183f24b3562ebf37d56 --- /dev/null +++ b/inc/metadef/graph/utils/cycle_detector.h @@ -0,0 +1,70 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. 
You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef GRAPH_CYCLE_DETECTOR_H_ +#define GRAPH_CYCLE_DETECTOR_H_ + +#include "graph/node.h" +#include "graph/compute_graph.h" +#include "connection_matrix.h" + +namespace ge { +class CycleDetector { + friend class GraphUtils; +public: + CycleDetector() = default; + ~CycleDetector() = default; + /* Detect whether there are cycles in graph + * after fusing all nodes in param fusion_nodes. + * Before call this func, you should call GenerateConnectionMatrix frist + * to generate connection_matrix based on current graph. + * + * Compared with Cycle Detection + * @param fusion_nodes: each vector in fusion_nodes + * will be fused into an entity(which could contains + * more than one node). The caller should put all original + * nodes which are expected to be fused into one larger node + * into each sub-vector of fusion_nodes. + * + * This function can tell whether there are a cycle after + * fusing all nodes in fusion_nodes. Each vector in 2-d + * vector fusion_nodes will be fused into an entity. + * + * + * This interface cannot detect whether there are cycles + * inside the fused nodes. + * + * e.g. {a, b, c, d} -> {e, f} + * Because the edge information is not given for e and f + * so this function we cannot tell if e and f are in a + * cycle. + * */ + bool HasDetectedCycle(const std::vector> &fusion_nodes); + + /** + * Update connection matrix based on graph. + * Connection matrix is served for cycle detection. 
+ * + * The first param graph, it should be the same one graph when contribue cycle_detector + */ + void Update(const ComputeGraphPtr &graph, const std::vector &fusion_nodes); + + /** + * Expand dim and update connection matrix based on graph. + */ + void ExpandAndUpdate(const vector &fusion_nodes, const std::string &node_name); +private: + graphStatus Init(const ComputeGraphPtr &graph); + std::unique_ptr connectivity_{nullptr}; +}; + +using CycleDetectorPtr = std::unique_ptr; +using CycleDetectorSharedPtr = std::shared_ptr; +} +#endif // GRAPH_CYCLE_DETECTOR_H_ diff --git a/inc/metadef/graph/utils/dumper/ge_graph_dumper.h b/inc/metadef/graph/utils/dumper/ge_graph_dumper.h new file mode 100644 index 0000000000000000000000000000000000000000..8145fd4019be4230734c8677a16b7812f3aa6064 --- /dev/null +++ b/inc/metadef/graph/utils/dumper/ge_graph_dumper.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef GRAPH_UTILS_DUMPER_GE_GRAPH_DUMPER_H_ +#define GRAPH_UTILS_DUMPER_GE_GRAPH_DUMPER_H_ + +#include "graph/compute_graph.h" + +namespace ge { +struct GeGraphDumper { + GeGraphDumper() = default; + GeGraphDumper(const GeGraphDumper &) = delete; + GeGraphDumper &operator=(const GeGraphDumper &) = delete; + GeGraphDumper(GeGraphDumper &&) = delete; + GeGraphDumper &operator=(GeGraphDumper &&) = delete; + virtual void Dump(const ge::ComputeGraphPtr &graph, const std::string &suffix) = 0; + virtual ~GeGraphDumper() = default; +}; + +struct GraphDumperRegistry { + static GeGraphDumper &GetDumper(); + static void Register(GeGraphDumper &dumper); + static void Unregister(); +}; + +} // namespace ge + +#endif diff --git a/inc/metadef/graph/utils/execute_graph_adapter.h b/inc/metadef/graph/utils/execute_graph_adapter.h new file mode 100644 index 0000000000000000000000000000000000000000..486e6af646116bba3f1d290c2c5ec4ab6356434b --- /dev/null +++ b/inc/metadef/graph/utils/execute_graph_adapter.h @@ -0,0 +1,40 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_GRAPH_UTILS_EXECUTE_GRAPH_ADAPTER_H +#define INC_GRAPH_UTILS_EXECUTE_GRAPH_ADAPTER_H + +#include "graph/fast_graph/execute_graph.h" +#include "external/graph/ge_error_codes.h" + +namespace ge { +class ExecuteGraphAdapter { + public: + ~ExecuteGraphAdapter() = default; + ExecuteGraphAdapter(const ExecuteGraphAdapter &adapter) = delete; + ExecuteGraphAdapter &operator=(const ExecuteGraphAdapter &adapter) = delete; + + // 返回的ComputeGraph复用了原图src_graph的OpDesc对象,返回后src_graph不能释放 + // src_graph和返回的ComputeGraph的生命周期需要保证一致 + static ComputeGraphPtr ConvertExecuteGraphToComputeGraph(ExecuteGraph *src_graph); + + private: + ExecuteGraphAdapter() = default; + static graphStatus ConvertExecuteGraphToComputeGraph(ExecuteGraph *src_graph, const ComputeGraphPtr &dst_graph, + const int32_t depth); + static graphStatus CopyOpAndSubgraph(ExecuteGraph *src_graph, const ComputeGraphPtr &dst_graph, + std::unordered_map &all_new_nodes, const int32_t depth); + static graphStatus RelinkGraphEdges(FastNode *old_node, + const std::unordered_map &all_new_nodes); + static graphStatus CopyMembers(ExecuteGraph *src_graph, const ComputeGraphPtr &dst_graph, + const std::unordered_map &all_new_nodes); + static void InheritOriginalAttr(ExecuteGraph *src_graph, const ComputeGraphPtr &dst_graph); +}; +} // namespace ge +#endif // INC_GRAPH_UTILS_EXECUTE_GRAPH_ADAPTER_H diff --git a/inc/metadef/graph/utils/execute_graph_utils.h b/inc/metadef/graph/utils/execute_graph_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..3cb48848562889db3c81145e45532a6a6e29cbe6 --- /dev/null +++ b/inc/metadef/graph/utils/execute_graph_utils.h @@ -0,0 +1,195 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). 
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_GRAPH_UTILS_EXECUTE_GRAPH_UTILS_H
+#define INC_GRAPH_UTILS_EXECUTE_GRAPH_UTILS_H
+#include "graph/fast_graph/fast_node.h"
+#include "graph/fast_graph/execute_graph.h"
+
+namespace ge {
+class ExecuteGraphUtils {
+ public:
+  /**
+   * Find the root graph of `exe_graph`; if the current graph is the root graph or has no parent graph, it is returned
+   * @param exe_graph
+   * @return
+   */
+  static ExecuteGraph *FindRootGraph(ExecuteGraph *exe_graph);
+
+  /**
+   * Find the node named `name` in `exe_graph`, subgraphs included
+   * @param exe_graph
+   * @return
+   */
+  static FastNode *FindNodeFromAllNodes(ExecuteGraph *exe_graph, const char_t *const name);
+
+  /**
+   * Find the nodes whose type is `type` in `exe_graph`, subgraphs included
+   * @param exe_graph
+   * @return
+   */
+  static std::vector FindNodesByTypeFromAllNodes(ExecuteGraph *exe_graph, const char_t *const type);
+
+  /**
+   * Find the first node whose type is `type` in `exe_graph`, subgraphs excluded
+   * @param exe_graph
+   * @param type
+   * @return
+   */
+  static FastNode *FindFirstNodeMatchType(ExecuteGraph *exe_graph, const char_t *const type);
+
+  /**
+   * Inserts `insert_node` between the source-node output endpoint `src` and the destination-node input endpoints `dsts`.
+   * By default data input endpoint `0` and output endpoint `0` of `insert_node` take part in the new edges. After the
+   * insertion, `src_node` and `insert_node` as a whole have the same control and data relations as the original `src_node`
+   * @param src source data output endpoint
+   * @param dsts destination data input endpoints fed by the source output endpoint; a vector because one source output endpoint may feed several destination input endpoints
+   * @param insert_node the node to insert
+   * @param input_index which data input endpoint of the inserted node is connected to src; defaults to 0 when omitted
+   * @param output_index which data output endpoint of the inserted node is connected to each of dsts in turn; defaults to 0 when omitted
+   * @return GRAPH_SUCCESS if the insertion succeeds, GRAPH_FAILED otherwise
+   */
+  static graphStatus InsertNodeAfter(const EdgeSrcEndpoint &src, const std::vector &dsts,
+                                     FastNode *insert_node, const uint32_t input_index = 0U,
+                                     const uint32_t output_index = 0U);
+
+  /**
+   * Inserts `insert_node` between the destination-node data input endpoint `dst` and its peer source-node output endpoint.
+   * By default data input endpoint `0` and data output endpoint `0` of `insert_node` take part in the new edges. After the
+   * insertion, `dst_node` and `insert_node` as a whole have the same control and data relations as the original `dst_node`
+   * @param dst destination data input endpoint. NOTE(review): declared as EdgeSrcEndpoint although documented as a destination endpoint — confirm
+   * @param insert_node the node to insert
+   * @param input_index which data input endpoint of the inserted node is connected to the peer src output endpoint of dst; defaults to 0 when omitted
+   * @param output_index which data output endpoint of the inserted node is connected to dst; defaults to 0 when omitted
+   * @return GRAPH_SUCCESS if the insertion succeeds, GRAPH_FAILED otherwise
+   */
+  static graphStatus InsertNodeBefore(const EdgeSrcEndpoint &dst, FastNode *insert_node,
+                                      const uint32_t input_index = 0U, const uint32_t output_index = 0U);
+
+  /**
+   * Move the input control edges of `src_node` onto `dst_node`
+   * @param src_node
+   * @param dst_node
+   * @return
+   */
+  static graphStatus MoveInCtrlEdges(const FastNode *src_node, FastNode *dst_node);
+
+  /**
+   * Move the output control edges of `src_node` onto `dst_node`
+   * @param src_node
+   * @param dst_node
+   * @return
+   */
+  static graphStatus MoveOutCtrlEdges(const FastNode *src_node, FastNode *dst_node);
+
+  /**
+   * Disconnect all input and output edges of node and move it to dst_graph
+   * @param dst_graph destination graph
+   * @param node the node to move
+   * @return ge::GRAPH_SUCCESS on success
+   */
+  static graphStatus MoveNodeToGraph(FastNode *node, ExecuteGraph *dst_graph);
+
+  /**
+   * `Moves` the data relations of `old_node` onto `new_node` as directed by `inputs_map` and `outputs_map`. Concretely, the
+   * data relation of the `inputs_map[i]`-th / `outputs_map[i]`-th data input / output endpoint of `old_node` replaces that of the `i`-th
+   * input / output endpoint of `new_node`, for `i` in [0, element count of `inputs_map` / `outputs_map`). If `inputs_map[i]` / `outputs_map[i]`
+   * is negative or outside the input / output endpoint range of `old_node`, the `i`-th data input / output endpoint of `new_node` keeps its data relation
+   * @param new_node
+   * @param old_node
+   * @param inputs_map guides the replacement of input data endpoints; its element count should not exceed the number of input endpoints of `new_node`
+   * @param outputs_map guides the replacement of output endpoints; its element count should not exceed the number of output endpoints of `new_node`
+   * @param graph the graph that contains `new_node` and `old_node`; when not passed it is obtained through `new_node`
+   * @return
+   */
+  static graphStatus ReplaceNodeDataEdges(FastNode *new_node, FastNode *old_node,
+                                          const std::initializer_list inputs_map,
+                                          const std::initializer_list outputs_map,
+                                          ExecuteGraph *graph = nullptr);
+  static graphStatus ReplaceNodeDataEdges(FastNode *new_node, FastNode *old_node,
+                                          const std::vector &inputs_map,
+                                          const std::vector &outputs_map, ExecuteGraph *graph = nullptr);
+
+  /**
+   * Handles data relations exactly like `ReplaceNodeDataEdges` and, on top of that,
+   * copies all control relations of `old_node` onto `new_node`. Note the difference:
+   * `data` relations are `moved`, `control` relations are `copied`
+   * @param new_node
+   * @param old_node
+   * @param inputs_map guides the replacement of input data endpoints; its element count should not exceed the number of input endpoints of `new_node`
+   * @param outputs_map guides the replacement of output endpoints; its element count should not exceed the number of output endpoints of `new_node`
+   * @return GRAPH_SUCCESS if the replacement succeeds, GRAPH_FAILED otherwise
+   */
+  static graphStatus ReplaceNodeEdges(FastNode *new_node, FastNode *old_node,
+                                      const std::initializer_list inputs_map,
+                                      const std::initializer_list outputs_map);
+  static graphStatus ReplaceNodeEdges(FastNode *new_node, FastNode *old_node, const std::vector &inputs_map,
+                                      const std::vector &outputs_map);
+
+  /**
+   * Isolates `node`: reconnects its input/output data edges according to `io_map`, and adds the control edges needed so that
+   * all input nodes of `node` execute before the output nodes of `node`
+   * @param node
+   * @param io_map connects the peer output of the `io_map[i]`-th input to the peer inputs of the `i`-th output, so the element count of `io_map` should
+   * equal the number of output endpoints of `node`; if `io_map[i]` is negative, only the edges from the `i`-th output endpoint to its peers are disconnected
+   * @return
+   */
+  static graphStatus IsolateNode(FastNode *node, const std::initializer_list &io_map);
+  static graphStatus IsolateNode(FastNode *node, const std::vector &io_map);
+
+  /**
+   * Replace the `src` of `old_edge` with the `node` and `index` indicated by `new_src`
+   * @param old_edge
+   * @param new_src
+   * @return GRAPH_SUCCESS if the replacement succeeds, GRAPH_FAILED otherwise
+   */
+  static graphStatus ReplaceEdgeSrc(FastEdge *old_edge, const EdgeSrcEndpoint &new_src);
+
+  /**
+   * Remove from `execute_graph` every subgraph object whose `direct` or `indirect` parent node is remove_node
+   * @param execute_graph
+   * @param remove_node
+   * @return GRAPH_SUCCESS on success, GRAPH_FAILED on failure
+   */
+  static graphStatus RemoveSubgraphRecursively(ExecuteGraph *execute_graph, FastNode *remove_node);
+
+  /**
+   * Remove all relations of `node` from `execute_graph`: subgraph relations, ownership, and its role as an input or output of `execute_graph`;
+   * removal only: no edges are re-linked, and control relations between the nodes before and after it are not guaranteed to be passed on
+   * @param execute_graph
+   * @param node
+   * @return GRAPH_SUCCESS on success, GRAPH_FAILED on failure
+   */
+  static graphStatus RemoveNodeWithoutRelink(ExecuteGraph *execute_graph, FastNode *node);
+
+  /**
+   * Copy the input control edges of `src_node` onto `dst_node`
+   * @param src_node
+   * @param dst_node
+   * @return GRAPH_SUCCESS on success, GRAPH_FAILED on failure
+   */
+  static graphStatus CopyInCtrlEdges(const FastNode *src_node, FastNode *dst_node);
+
+  /**
+   * Copy the output control edges of `src_node` onto `dst_node`
+   * @param src_node
+   * @param dst_node
+   * @return GRAPH_SUCCESS on success, GRAPH_FAILED on failure
+   */
+  static graphStatus CopyOutCtrlEdges(const FastNode *src_node, FastNode *dst_node);
+
+  /**
+   * Build and return a mapping from node name to node for all nodes in the root graph, nodes in subgraphs included
+   * @param exe_graph
+   * @return mapping from node name to node
+   */
+  static std::unordered_map GetNodeMapFromAllNodes(ExecuteGraph *exe_graph);
+};
+} // namespace ge
+#endif // INC_GRAPH_UTILS_EXECUTE_GRAPH_UTILS_H
diff --git a/inc/metadef/graph/utils/fast_node_utils.h b/inc/metadef/graph/utils/fast_node_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..d42b8f58cdccb77b8ead5d6a150048c73ea35dbb
--- /dev/null
+++ b/inc/metadef/graph/utils/fast_node_utils.h
@@ -0,0 +1,168 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_GRAPH_UTILS_FAST_NODE_UTILS_H
+#define INC_GRAPH_UTILS_FAST_NODE_UTILS_H
+
+#include "graph/fast_graph/fast_node.h"
+#include "graph/fast_graph/edge.h"
+#include "graph/fast_graph/execute_graph.h"
+#include "external/graph/ge_error_codes.h"
+
+namespace ge {
+class FastNodeUtils {
+ public:
+  // node utils
+  /**
+   * @brief Get the parent-node input of the given node.
+   *
+   * @param node Pointer to the node to query.
+   * @return Pointer to the input of the parent node, or nullptr if the parent node does not exist.
+   */
+  static FastNode *GetParentInput(const FastNode *const node);
+
+  /**
+   * @brief Get the input data node at the given index.
+   *
+   * @param node Pointer to the node whose input data node is accessed.
+   * @param index Index of the input data node to retrieve.
+   * @return Pointer to the input data node at the given index if it is found; otherwise nullptr.
+   */
+  static FastNode *GetInDataNodeByIndex(const FastNode * const node, const int32_t index);
+
+  /**
+   * @brief Tell whether the given node is a constant Op node.
+   *
+   * @param node Pointer to the node to query.
+   * @return true if the given node is a constant Op node; otherwise false.
+   */
+  static bool GetConstOpType(const FastNode *const node);
+
+  // subgraph utils
+  /**
+   * @brief Append a subgraph to the given node.
+   *
+   * @param node Pointer to the node the subgraph is added to.
+   * @param subgraph_name Name of the subgraph.
+   * @param subgraph Smart pointer to the subgraph to add; the subgraph lifetime is managed by root_graph.
+   * @return Status of adding the subgraph.
+   */
+  static graphStatus AppendSubgraphToNode(FastNode *const node, const std::string &subgraph_name,
+                                          const ExecuteGraphPtr &subgraph);
+
+  /**
+   * @brief Get the subgraph at the given index of the given node.
+   *
+   * @param node Pointer to the node whose subgraph is queried.
+   * @param index Index of the subgraph.
+   * @return Pointer to the subgraph, or nullptr if the index is out of range or the subgraph does not exist.
+   */
+  static ExecuteGraph *GetSubgraphFromNode(const FastNode *const node, const uint32_t index);
+
+  /**
+   * @brief Mount a subgraph at index `index` of the given node. If a subgraph already exists there, it is replaced under
+   * node, but the original subgraph is not removed from root_graph. To remove it, call ExecuteGraph::RemoveSubGraph.
+   *
+   * @param node Pointer to the node the subgraph is set on.
+   * @param index Index of the subgraph.
+   * @param subgraph Smart pointer to the subgraph to set.
+   * @return Status of setting the subgraph.
+   */
+  static graphStatus MountSubgraphToNode(FastNode *const node, const uint32_t index, const ExecuteGraphPtr &subgraph);
+
+  // edge utils
+  /**
+   * @brief Append input edge info to the given node until the number of input edge infos reaches num.
+   * Note: this interface does not create real input edges; to connect nodes, call ExecuteGraph::AddEdge.
+   *
+   * @param node Pointer to the node the input edge info is appended to.
+   * @param num Number of input edge infos owned by node after the append.
+   * @return Status of appending input edge info.
+   */
+  static graphStatus AppendInputEdgeInfo(FastNode *const node, const uint32_t num);
+
+  /**
+   * @brief Append output edge info to the given node until the number of output edge infos reaches num.
+   * Note: this interface does not create real output edges; to connect nodes, call ExecuteGraph::AddEdge.
+   *
+   * @param node Pointer to the node the output edge info is appended to.
+   * @param num Number of output edge infos owned by node after the append.
+   * @return Status of appending output edge info.
+   */
+  static graphStatus AppendOutputEdgeInfo(FastNode *const node, const uint32_t num);
+
+  /**
+   * @brief Clear the InputDesc at the given index of the given OpDesc.
+   *
+   * @param op_desc Pointer to the OpDesc.
+   * @param index Index of the InputDesc to clear.
+   * @return true if the InputDesc is cleared successfully; otherwise false.
+   */
+  static bool ClearInputDesc(const OpDesc *const op_desc, const uint32_t index);
+
+  /**
+   * @brief Remove input edge info from the given node until the number of input edge infos drops to num.
+   * Note: this interface does not remove input edges from the execute graph; to remove them, call ExecuteGraph::RemoveEdge.
+   *
+   * @param node Pointer to the node the input edge info is removed from.
+   * @param num Number of input edge infos owned by node after the removal.
+   * @return Status of removing input edge info.
+   */
+  static graphStatus RemoveInputEdgeInfo(FastNode *const node, const uint32_t num);
+
+  /**
+   * @brief Disconnect the input and output edges between the given node and all nodes connected to it.
+   *
+   * @param node Target node.
+   */
+  static void UnlinkAll(FastNode *const node);
+
+  /**
+   * @brief Get the input endpoint of the given edge, containing the dst node pointer and the input index.
+   * (SrcNode:[OutEndpoint])->Edge->([InEndpoint]:DstNode)
+   *
+   * @param edge Pointer to the edge whose input endpoint is queried.
+   * @return The input endpoint.
+   */
+  static EdgeDstEndpoint GetDstEndpoint(const FastEdge *const edge);
+
+  /**
+   * @brief Get the output endpoint of the given edge, containing the src node pointer and the output index.
+   * (SrcNode:[OutEndpoint])->Edge->([InEndpoint]:DstNode)
+   *
+   * @param edge Pointer to the edge whose output endpoint is queried.
+   * @return The output endpoint.
+   */
+  static EdgeSrcEndpoint GetSrcEndpoint(const FastEdge *const edge);
+};
+
+struct FastNodeCompareKey {  // "less-than" functor: orders by node name, ties broken by owner-graph name
+  bool operator()(const FastNode *const n0, const FastNode *const n1) const {
+    if ((n0 == nullptr) || (n1 == nullptr)) {
+      return false;  // NOTE(review): a null operand compares equivalent to every node — confirm this is intended
+    }
+    if (n0->GetName() == n1->GetName()) {  // same node name: fall back to the owner graph's name
+      const ExtendInfo *const extend_info0 = n0->GetExtendInfo();
+      const ExtendInfo *const extend_info1 = n1->GetExtendInfo();
+      if ((extend_info0 == nullptr) || (extend_info1 == nullptr)) {
+        return false;
+      }
+      const ExecuteGraph *const g0 = extend_info0->GetOwnerGraphBarePtr();
+      const ExecuteGraph *const g1 = extend_info1->GetOwnerGraphBarePtr();
+      if ((g0 == nullptr) || (g1 == nullptr)) {
+        return false;
+      }
+      return (g0->GetName() < g1->GetName());
+    }
+    return (n0->GetName() < n1->GetName());
+  }
+};
+} // namespace ge
+
+#endif // INC_GRAPH_UTILS_FAST_NODE_UTILS_H
diff --git a/inc/metadef/graph/utils/ffts_graph_utils.h b/inc/metadef/graph/utils/ffts_graph_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..ba19ae24c0ba6f1b90000736721e9cfbb8490470
--- /dev/null
+++ b/inc/metadef/graph/utils/ffts_graph_utils.h
@@ -0,0 +1,84 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_GRAPH_UTILS_FFTS_GRAPH_UTILS_H_
+#define INC_GRAPH_UTILS_FFTS_GRAPH_UTILS_H_
+
+#include "graph/anchor.h"
+#include "graph/compute_graph.h"
+#include "graph/graph.h"
+#include "graph/node.h"
+
+namespace ge {
+class FftsGraphUtils {  // static helpers only; the two GraphPartition overloads are the whole public surface
+ public:
+  using CalcFunc = std::function(const NodePtr &)>;  // NOTE(review): template argument list is incomplete (stripped?) — restore
+  static graphStatus GraphPartition(ComputeGraph &graph, const std::set &unsupported_nodes);
+
+  static graphStatus GraphPartition(ComputeGraph &graph,
+                                    const CalcFunc &calc_func,
+                                    const std::vector &upper_limit);
+ private:
+  static graphStatus CollectClipNodesAndGraphs(const ComputeGraphPtr &graph,
+                                               const std::set &unsupported_nodes,
+                                               std::unordered_set &nodes_need_clip,
+                                               std::unordered_set &graphs_need_split);
+
+  static bool IsGraphNeedSplit(const ComputeGraphPtr &graph, const std::unordered_set &nodes_need_clip);
+
+  static graphStatus SplitNodesWithCheck(const ComputeGraphPtr &graph,
+                                         const std::unordered_set &nodes_need_clip,
+                                         std::vector>> &split_nodes);
+
+  static void SplitNodes(const std::set &calc_nodes, const std::function &is_cur_stage,
+                         std::set &visited_nodes, std::set &cur_nodes, std::set &next_nodes);
+
+  static graphStatus SplitSubgraph(const ComputeGraphPtr &subgraph,
+                                   const std::vector>> &split_nodes);
+
+  static graphStatus BuildFftsPlusSubgraphWithAllNodes(const ComputeGraphPtr &subgraph);
+
+  static void CollectCalcNodeInSubgraph(const ComputeGraphPtr &subgraph, std::set &calc_nodes);
+
+  static void CollectEndNodeInSubgraph(const ComputeGraphPtr &subgraph, const std::set &ctrl_goto_types,
+                                       std::set &edge_nodes);
+
+  static ComputeGraphPtr GetFftsPlusGraph(ComputeGraph &graph);
+
+  static graphStatus SetAttrForFftsPlusSubgraph(const ComputeGraphPtr &subgraph);
+
+  static graphStatus Calculate(const ComputeGraphPtr &graph,
+                               const CalcFunc &calc_func,
+                               std::map> &node_value,
+                               std::map> &graph_value,
+                               const uint32_t recursive_depth = 1U);
+
+  static std::vector Calculate(const NodePtr &node, const CalcFunc &calc_func,
+                               std::map> &node_value,
+                               std::map> &graph_value,
+                               const uint32_t recursive_depth);
+
+  static bool IsValueValid(const ComputeGraphPtr &graph, const std::vector &upper_limit,
+                           const std::map> &node_value,
+                           const std::map> &graph_value);
+
+  static graphStatus PartitionGraphWithLimit(const ComputeGraphPtr &graph,
+                                             std::map> &node_value,
+                                             std::map> &graph_value,
+                                             const std::vector &upper_limit,
+                                             const uint32_t recursive_depth = 1U);
+
+  static graphStatus SplitFuncNode(const std::vector exceed_single_node,  // the only vector here taken by value
+                                   std::map> &node_value,
+                                   std::map> &graph_value,
+                                   const std::vector &upper_limit,
+                                   const uint32_t recursive_depth);
+};
+} // namespace ge
+#endif // INC_GRAPH_UTILS_FFTS_GRAPH_UTILS_H_
diff --git a/inc/metadef/graph/utils/file_utils.h b/inc/metadef/graph/utils/file_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..5d9cb5ecd0151bef087dd400481e72dfcbdee425
--- /dev/null
+++ b/inc/metadef/graph/utils/file_utils.h
@@ -0,0 +1,131 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef COMMON_GRAPH_UTILS_FILE_UTILS_H_
+#define COMMON_GRAPH_UTILS_FILE_UTILS_H_
+
+#include  // NOTE(review): header name missing on this and the next line (angle-bracket text looks stripped) — restore
+#include
+#include "external/graph/types.h"
+#include "graph/ge_error_codes.h"
+#include "ge_common/ge_api_types.h"
+#include "mmpa/mmpa_api.h"
+
+namespace ge {
+
+/**
+ * @ingroup domi_common
+ * @brief Obtain the absolute path of a file
+ * @param [in] path path of the input file
+ * @return string. Absolute path of a file. If the absolute path cannot be
+ * obtained, an empty string is returned
+ */
+std::string RealPath(const char_t *path);
+
+/**
+ * @ingroup domi_common
+ * @brief Recursively create a directory
+ * @param [in] directory_path Path, which can be a multi-level directory.
+ * @return 0 success, 1- fail. NOTE(review): "1-" is ambiguous (1 or -1?) — confirm against the implementation
+ */
+int32_t CreateDir(const std::string &directory_path);
+
+/**
+ * @ingroup domi_common
+ * @brief Recursively create a directory with the given mode
+ * @param [in] directory_path Path, which can be a multi-level directory.
+ * @param [in] mode dir mode, e.g. 0700
+ * @return 0 success, 1- fail. NOTE(review): "1-" is ambiguous (1 or -1?) — confirm against the implementation
+ */
+int32_t CreateDir(const std::string &directory_path, uint32_t mode);
+
+/**
+ * @ingroup domi_common
+ * @brief Recursively create a directory; deprecated, use CreateDir instead
+ * @param [in] directory_path Path, which can be a multi-level directory.
+ * @return 0 success, 1- fail. NOTE(review): "1-" is ambiguous (1 or -1?) — confirm against the implementation
+ */
+int32_t CreateDirectory(const std::string &directory_path);
+
+std::unique_ptr GetBinFromFile(std::string &path, uint32_t &data_len);
+
+std::unique_ptr GetBinDataFromFile(const std::string &path, uint32_t &data_len);
+
+/**
+ * @ingroup domi_common
+ * @brief Get binary data from file
+ * @param [in] path file path.
+ * @param [in] offset offset of the file data
+ * @param [in] data_len data len to read
+ * @return buffer char[] used to store file data
+ */
+std::unique_ptr GetBinFromFile(const std::string &path, size_t offset, size_t data_len);
+
+graphStatus WriteBinToFile(std::string &path, char_t *data, uint32_t &data_len);
+
+/**
+ * @ingroup domi_common
+ * @brief Get a regulated name in which special characters are replaced
+ * @param [in] name origin name.
+ * @return string. name with special characters replaced by _.
+ */
+std::string GetRegulatedName(const std::string name);
+
+/**
+ * @ingroup domi_common
+ * @brief Get binary data from file
+ * @param [in] path file path.
+ * @param [out] buffer char[] used to store file data
+ * @param [out] data_len store read size
+ * @return graphStatus GRAPH_SUCCESS: success, OTHERS: fail.
+ */
+graphStatus GetBinFromFile(const std::string &path, char_t *buffer, size_t &data_len);
+
+/**
+ * @ingroup domi_common
+ * @brief Write binary to file
+ * @param [in] fd file descriptor.
+ * @param [in] data char[] used to write to file
+ * @param [in] data_len store write size
+ * @return graphStatus GRAPH_SUCCESS: success, OTHERS: fail.
+ */
+graphStatus WriteBinToFile(const int32_t fd, const char_t * const data, size_t data_len);
+
+/**
+ * @ingroup domi_common
+ * @brief Save data to file
+ * @param [in] file_path file path.
+ * @param [in] data char[] holding the data to save
+ * @param [in] length size of the data
+ * @return graphStatus GRAPH_SUCCESS: success, OTHERS: fail.
+ */
+graphStatus SaveBinToFile(const char * const data, size_t length, const std::string &file_path);
+
+/**
+ * @ingroup domi_common
+ * @brief split file path to directory path and file name
+ * @param [in] file_path file path.
+ * @param [out] dir_path directory path
+ * @param [out] file_name file name
+ * @return none (the function is declared void; results are delivered through the [out] parameters)
+ */
+void SplitFilePath(const std::string &file_path, std::string &dir_path, std::string &file_name);
+
+/**
+ * @ingroup domi_common
+ * @brief Get ASCEND_WORK_PATH environment variable
+ * @param [out] ascend_work_path ASCEND_WORK_PATH's value.
+ * @return Status SUCCESS: success, OTHERS: fail.
+ */
+Status GetAscendWorkPath(std::string &ascend_work_path);
+
+int32_t Scandir(const CHAR *path, mmDirent ***entry_list, mmFilter filter_func, mmSort sort);
+}  // namespace ge
+
+#endif // end COMMON_GRAPH_UTILS_FILE_UTILS_H_
diff --git a/inc/metadef/graph/utils/graph_dump_utils.h b/inc/metadef/graph/utils/graph_dump_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..3206ae41408bbd700a312503464e48a1ad6aa28e
--- /dev/null
+++ b/inc/metadef/graph/utils/graph_dump_utils.h
@@ -0,0 +1,57 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef INC_GRAPH_UTILS_GRAPH_DUMP_UTILS_H +#define INC_GRAPH_UTILS_GRAPH_DUMP_UTILS_H + +#include "graph/compute_graph.h" +#include "graph/fast_graph/execute_graph.h" +#include "graph/utils/execute_graph_adapter.h" +#include "graph/utils/graph_utils.h" +#include "mmpa/mmpa_api.h" + +namespace ge { +namespace { +const char_t *const kDumpGeGraph = "DUMP_GE_GRAPH"; +} +/** + * 将ComputeGraph落盘成文件 + * @param compute_graph 要落盘的对象 + * @param name 落盘的文件名,会拼接上默认的前后缀 + * @return + */ +inline void DumpGraph(const ComputeGraphPtr &compute_graph, const char_t *const name) { + ge::GraphUtils::DumpGEGraph(compute_graph, name); + ge::GraphUtils::DumpGEGraphToOnnx(*compute_graph, name); + uint64_t i = 0U; + for (const auto &sub_graph_func : compute_graph->GetAllSubgraphs()) { + const auto sub_graph_func_name = std::string(name) + std::string("_sub_graph_") + std::to_string(i++); + ge::GraphUtils::DumpGEGraph(sub_graph_func, sub_graph_func_name); + ge::GraphUtils::DumpGEGraphToOnnx(*sub_graph_func, sub_graph_func_name); + } +} + +/** + * 将ExecuteGraph落盘成文件 + * @param execute_graph 要落盘的对象 + * @param name 落盘的文件名,会拼接上默认的前后缀 + * @return + */ +inline void DumpGraph(ExecuteGraph *execute_graph, const char_t *const name) { + char_t dump_ge_graph[MMPA_MAX_PATH] = {'\0'}; + if (mmGetEnv(kDumpGeGraph, &(dump_ge_graph[0U]), static_cast(MMPA_MAX_PATH)) != EN_OK) { + return; + } + const auto compute_graph = ExecuteGraphAdapter::ConvertExecuteGraphToComputeGraph(execute_graph); + if (compute_graph != nullptr) { + DumpGraph(compute_graph, name); + } +} +} // namespace ge +#endif // INC_GRAPH_UTILS_GRAPH_DUMP_UTILS_H diff --git a/inc/metadef/graph/utils/graph_utils.h b/inc/metadef/graph/utils/graph_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..98a78619e13d409d835f8509d7cf4a46844afbd0 --- /dev/null +++ b/inc/metadef/graph/utils/graph_utils.h @@ 
-0,0 +1,1191 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_GRAPH_UTILS_GRAPH_UTILS_H_
+#define INC_GRAPH_UTILS_GRAPH_UTILS_H_
+
+#include  // NOTE(review): header names missing on these seven lines (angle-bracket text looks stripped) — restore
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "graph/anchor.h"
+#include "graph/compute_graph.h"
+#include "graph/graph.h"
+#include "graph/model.h"
+#include "graph/node.h"
+#include "graph/utils/anchor_utils.h"
+#include "cycle_detector.h"
+
+/**
+ * Graph dump interface: serializes the `compute_graph` object to a file, written to the current path by default;
+ * if `compute_graph` has subgraph objects mounted, dumping them is attempted as well
+ * Dumping is controlled by the `DUMP_GE_GRAPH`, `DUMP_GRAPH_LEVEL` and `DUMP_GRAPH_PATH` environment variables
+ * Meaning of DUMP_GE_GRAPH:
+ * 1 - full dump
+ * 2 - basic dump without weights and other data
+ * 3 - compact dump that shows node relations only
+ * Meaning of DUMP_GRAPH_LEVEL:
+ * 1 - dump all graphs
+ * 2 - dump all graphs except subgraphs
+ * 3 - dump the graph generated at the last stage
+ * 4 - dump the graph generated at the entry stage
+ * a|b - dump the graphs that match substring a or b; allows a custom selection of graphs to dump
+ * Meaning of DUMP_GRAPH_PATH:
+ * controls the path the graphs are written to
+ * @param compute_graph graph to dump; the macro expands this argument several times, so pass a side-effect-free expression
+ * @param name used to build the file name; also expanded several times
+ */
+#define GE_DUMP(compute_graph, name) \
+  do { \
+    ge::GraphUtils::DumpGEGraph((compute_graph), (name)); \
+    ge::GraphUtils::DumpGEGraphToOnnx(*(compute_graph), (name)); \
+    uint64_t i = 0U; \
+    for (const auto &sub_graph_func : (compute_graph)->GetAllSubgraphs()) { \
+      const auto sub_graph_func_name = std::string(name) + std::string("_sub_graph_") + std::to_string(i++); \
+      ge::GraphUtils::DumpGEGraph(sub_graph_func, sub_graph_func_name); \
+      ge::GraphUtils::DumpGEGraphToOnnx(*sub_graph_func, sub_graph_func_name); \
+    } \
+  } while (false)
+
+namespace ge {
+enum class DumpLevel { NO_DUMP = 0, DUMP_ALL = 1, DUMP_WITH_OUT_DATA = 2, DUMP_WITH_OUT_DESC = 3, DUMP_LEVEL_END = 4 };
+enum class MemType { OUTPUT_MEM, WORKSPACE_MEM };
+
+struct MemReuseInfo {
+  NodePtr node;
+  MemType mem_type;
+  uint32_t index;
+};
+
+enum IOType { kIn, kOut };
+class NodeIndexIO {  // one input/output slot of a node; value_ caches "<node name>_in_<index>" or "<node name>_out_<index>"
+ public:
+  NodeIndexIO(const NodePtr &node, const uint32_t index, const IOType io_type)
+      : node_(node), index_(index), io_type_(io_type), node_ptr_(node.get()) {
+    ToValue();
+  }
+  NodeIndexIO(const NodePtr &node, const int32_t index, const IOType io_type)
+      : node_(node), index_(static_cast(index)), io_type_(io_type), node_ptr_(node.get()) {
+    ToValue();  // NOTE(review): the static_cast above has no target type (angle-bracket text looks stripped) — restore
+  }
+  NodeIndexIO(const NodePtr &node, const int64_t index, const IOType io_type)
+      : node_(node), index_(static_cast(index)), io_type_(io_type), node_ptr_(node.get()) {
+    ToValue();  // NOTE(review): same missing static_cast target type as the int32_t overload
+  }
+  NodeIndexIO(const Node *node, const uint32_t index, const IOType io_type)
+      : node_(nullptr), index_(index), io_type_(io_type), node_ptr_(node) {  // raw-pointer form: node_ stays null
+    ToValue();
+  }
+  ~NodeIndexIO() {}
+
+  const std::string &ToString() const {
+    return value_;
+  }
+
+  void ToValue() {
+    if (node_ptr_ != nullptr) {  // value_ is left empty when there is no node
+      value_ = node_ptr_->GetName() + ((io_type_ == kOut) ? "_out_" : "_in_") + std::to_string(index_);
+    }
+  }
+
+  bool operator==(const NodeIndexIO& other) const {
+    return value_ == other.value_;  // equality compares the cached strings only
+  }
+
+  NodePtr node_ = nullptr;
+  uint32_t index_ = 0U;
+  IOType io_type_ = kOut;
+  std::string value_;
+  const Node *node_ptr_ = nullptr;
+};
+
+// Symbol: nodeName, out or in and index, for example: Sqrt0trans_Cast_4_in_0.
@see NodeIndexIO::ToValue() +using SymbolToAnchors = std::unordered_map>; +using AnchorToSymbol = std::unordered_map; // symbol and its peer out anchor's symbol + +class GraphUtils { + public: + /** + * pipline拆分场景获取`compute_graph`的`PARTITIONEDCALL`子图 + * @param compute_graph + * @param independent_compile_subgraphs:出参,pipline拆分场景返回子图对象,非拆分场景返回`compute_graph`本身 + * @return 成功返回GRAPH_SUCCESS, 失败返回GRAPH_FAILED + */ + static graphStatus GetIndependentCompileGraphs(const ComputeGraphPtr &compute_graph, + std::vector &independent_compile_subgraphs); + + /** + * `src`和`dst`进行连边,`dst`作为InDataAnchorPtr, 最多允许一个对端OutDataAnchorPtr + * @param src + * @param dst + * @return 如果`dst`已经有数据输入,则返回GRAPH_FAILED + */ + static graphStatus AddEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); + + static graphStatus AddEdge(const AnchorPtr &src, const AnchorPtr &dst); + + static graphStatus AddEdge(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst); + + static graphStatus AddEdge(const OutDataAnchorPtr &src, const InControlAnchorPtr &dst); + + /** + * `src`和`dst`进行断边 + * @param src + * @param dst + * @return 如果`src`和`dst`没有连边关系,则返回GRAPH_FAILED + */ + static graphStatus RemoveEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); + + static graphStatus RemoveEdge(const AnchorPtr &src, const AnchorPtr &dst); + + static graphStatus RemoveEdge(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst); + + static graphStatus RemoveEdge(const OutDataAnchorPtr &src, const InControlAnchorPtr &dst); + + /** + * 替换`dst`的对端`src`为`new_src` + * @param src + * @param dst + * @param new_src + * @return 替换成功返回GRAPH_SUCCESS, 替换失败返回GRAPH_FAILED + */ + static graphStatus ReplaceEdgeSrc(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, + const OutDataAnchorPtr &new_src); + /** + * 替换`dst`的对端`src`为`new_src` + * @param src + * @param dst + * @param new_src + * @return 替换成功返回GRAPH_SUCCESS, 替换失败返回GRAPH_FAILED + */ + static graphStatus ReplaceEdgeSrc(const 
OutControlAnchorPtr &src, const InControlAnchorPtr &dst, + const OutControlAnchorPtr &new_src); + /** + * 替换`src`的对端`dst`为`new_dst` + * @param src + * @param dst + * @param new_dst + * @return 替换成功返回GRAPH_SUCCESS, 替换失败返回GRAPH_FAILED + */ + static graphStatus ReplaceEdgeDst(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, + const InDataAnchorPtr &new_dst); + /** + * 替换`src`的对端`dst`为`new_dst` + * @param src + * @param dst + * @param new_dst + * @return 替换成功返回GRAPH_SUCCESS, 替换失败返回GRAPH_FAILED + */ + static graphStatus ReplaceEdgeDst(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst, + const InControlAnchorPtr &new_dst); + /** + * 在`src`所属的node对象和`dst`所属的node对象之间插入`new_node`, 行为等价于替换`src`的对端`dst`为`new_node`的第`0`个 + * InDataAnchor, 同时替换`dst`的对端`src`为`new_node`的第`0`个OutDataAnchor + * @param src + * @param dst + * @param new_node + * @return 替换成功返回GRAPH_SUCCESS, 替换失败返回GRAPH_FAILED + */ + static graphStatus InsertNodeBetweenDataAnchors(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, + const NodePtr &new_node); + /** + * 从`compute_graph`上删除`直接`或者`间接`父节点为remove_node的所有子图对象 + * @param compute_graph + * @param remove_node + * @return 成功返回GRAPH_SUCCESS, 失败返回GRAPH_FAILED + */ + static graphStatus RemoveSubgraphRecursively(const ComputeGraphPtr &compute_graph, const NodePtr &remove_node); + + /** + * 从`compute_graph`中删除算子类型为`node_type`的所有关系,包括子图关系,从属关系,作为`compute_graph`的输入,输出的关系; + * 仅删除,不进行断边连边,不保证删除后节点前后的控制关系传递 + * @param compute_graph + * @param node_type + * @return 成功返回GRAPH_SUCCESS, 失败返回GRAPH_FAILED + */ + static graphStatus RemoveNodesByTypeWithoutRelink(const ComputeGraphPtr &compute_graph, const std::string &node_type); + + /** + * 从`compute_graph`中删除`node`对象的所有关系,包括子图关系,从属关系,作为`compute_graph`的输入,输出的关系; + * 仅删除,不进行断边连边,不保证删除后节点前后的控制关系传递 + * @param compute_graph + * @param node + * @return 成功返回GRAPH_SUCCESS, 失败返回GRAPH_FAILED + */ + static graphStatus RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const NodePtr &node); + + /** + 
* 从`compute_graph`中删除`nodes`对象们的所有关系,包括子图关系,从属关系,作为`compute_graph`的输入,输出的关系; + * 仅删除,不进行断边连边,不保证删除后节点前后的控制关系传递 + * 此接口在图规模比较大的时候,比遍历`nodes`节点依次调用`RemoveNodeWithoutRelink`更高效一些 + * @param compute_graph + * @param nodes + * @return + */ + static graphStatus RemoveNodesWithoutRelink(const ComputeGraphPtr &compute_graph, + const std::unordered_set &nodes); + /** + * ComputeGraph图对象的全量深拷贝接口 + * @param src_compute_graph 需要是根图对象 + * @param dst_compute_graph + * @return + */ + static graphStatus CopyComputeGraph(const ComputeGraphPtr &src_compute_graph, ComputeGraphPtr &dst_compute_graph); + + /** + * ComputeGraph图对象的深拷贝接口 + * @param src_compute_graph 需要是根图对象 + * @param node_filter 节点拷贝白名单过滤器,可以通过传递此参数实现满足条件的节点的复制,不传递时代表全量拷贝 + * @param graph_filter 子图拷贝白名单过滤器,可以通过传递此参数实现满足条件的子图的复制,不传递时代表全量拷贝 + * @param attr_filter 节点上属性拷贝白名单过滤器,可以通过传递此参数实现满足条件的属性复制,不传递时代表全量拷贝 + * @param dst_compute_graph + * @return + */ + static graphStatus CopyComputeGraph(const ComputeGraphPtr &src_compute_graph, const NodeFilter &node_filter, + const GraphFilter &graph_filter, const AttrFilter &attr_filter, + ComputeGraphPtr &dst_compute_graph); + + /** + * ComputeGraph图对象的深拷贝接口 + * @param src_compute_graph + * @param node_filter 节点拷贝白名单过滤器,可以通过传递此参数实现满足条件的节点的复制,不传递时代表全量拷贝 + * @param graph_filter 子图拷贝白名单过滤器,可以通过传递此参数实现满足条件的子图的复制,不传递时代表全量拷贝 + * @param attr_filter 节点上属性拷贝白名单过滤器,可以通过传递此参数实现满足条件的属性复制,不传递时代表全量拷贝 + * @param dst_compute_graph + * @param node_old_2_new 新旧节点映射关系 + * @param op_desc_old_2_new 新旧节点描述信息的映射关系 + * @param depth 子图拷贝深度, 最大支持为10 + * @return + */ + static graphStatus CopyComputeGraph(const ComputeGraphPtr &src_compute_graph, const NodeFilter &node_filter, + const GraphFilter &graph_filter, const AttrFilter &attr_filter, + ComputeGraphPtr &dst_compute_graph, + std::map &node_old_2_new, + std::map &op_desc_old_2_new, const int32_t depth); + + /** + * ComputeGraph图对象的深拷贝接口 + * @param src_compute_graph + * @param dst_compute_graph + * @param node_old_2_new 新旧节点映射关系 + * @param 
op_desc_old_2_new 新旧节点描述信息的映射关系 + * @param depth 子图拷贝深度, 最大支持为10 + * @return + */ + static graphStatus CopyComputeGraph(const ComputeGraphPtr &src_compute_graph, ComputeGraphPtr &dst_compute_graph, + std::map &node_old_2_new, + std::map &op_desc_old_2_new, const int32_t depth); + /** + * 拷贝OpDesc对象,跟`CopyOpDesc`方法的区别是`CloneOpDesc`的拷贝内容精简一些 + * 注意:此接口性能较差,且拷贝内容存在丢失,为了考虑兼容目前保留此接口,推荐直接使用OpDesc + * 的拷贝构造函数来实现拷贝,拷贝构造函数性能好且内容不存在丢失 + * @param org_op_desc + * @return + */ + static OpDescPtr CloneOpDesc(const ConstOpDescPtr &org_op_desc); + /** + * 拷贝OpDesc对象,跟`CloneOpDesc`方法的区别是`CopyOpDesc`的拷贝内容多一些, + * 包括函数指针成员等 + * 注意:此接口性能较差,且拷贝内容存在丢失,为了考虑兼容目前保留此接口,推荐直接使用OpDesc + * 的拷贝构造函数来实现拷贝,拷贝构造函数性能好且内容不存在丢失 + * @param org_op_desc + * @return + */ + static OpDescPtr CopyOpDesc(const ConstOpDescPtr &org_op_desc); + /** + * `CopyOpDesc`的重载接口 + * @param org_op_desc + * @param attr_filter 节点上属性拷贝白名单过滤器,可以通过传递此参数实现满足条件的属性复制,不传递时代表全量拷贝 + * @return + */ + static OpDescPtr CopyOpDesc(const ConstOpDescPtr &org_op_desc, const AttrFilter &attr_filter); + + /** + * 接口行为是在数据`src`锚点所属的`src_node`节点和数据`dsts`锚点所属的`dst_node`节点们之间插入一个`insert_node`节点, + * 默认是`insert_node`的`0`号数据输入锚点和`0`号输出数据锚点参与连边,`insert_node`插入之后, `src_node`和`insert_node` + * 作为一个整体与原来的`src_node`具备等价的控制和数据关系 + * @param src 源数据输出锚点 + * @param dsts 源数据输出锚点连接的目的数据输入锚点,使用vector的原因是存在一个源锚点给到多个目的锚点的情况 + * @param insert_node 表示要插入的节点 + * @param input_index 表示插入节点的哪个数据输入锚点要跟src相连,如果不传递,默认取0 + * @param output_index 表示插入节点的哪个数据输出锚点要跟dsts依次相连,如果不传递,默认取0 + * @return 如果插入成功返回GRAPH_SUCCESS,失败返回GRAPH_FAILED + */ + static graphStatus InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector &dsts, + const NodePtr &insert_node, const uint32_t input_index = 0U, + const uint32_t output_index = 0U); + + /** + * 接口行为是通过insert_op在图中生成一个insert_node节点, + * 在数据`src`锚点所属的`src_node`节点和数据`dsts`锚点所属的`dst_node`节点们之间插入一个`insert_node`节点, + * 默认是`insert_node`的`0`号数据输入锚点和`0`号输出数据锚点参与连边 + * `insert_node`插入之后, `src_node`和`insert_node` + * 
作为一个整体与原来的`src_node`具备等价的控制和数据关系 + * @param src 源数据输出锚点 + * @param dsts 源数据输出锚点连接的目的数据输入锚点,使用vector的原因是存在一个源锚点给到多个目的锚点的情况 + * @param insert_op 表示要插入的opDesc,需要用其在src_node的图上生成一个node + * @param input_index 表示插入节点的哪个数据输入锚点要跟src相连,如果不传递,默认取0 + * @param output_index 表示插入节点的哪个数据输出锚点要跟dsts依次相连,如果不传递,默认取0 + * @return 如果插入成功返回insert_node,失败返回nullptr + */ + static NodePtr InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector &dsts, + const OpDescPtr &insert_op, const uint32_t input_index = 0U, + const uint32_t output_index = 0U); + + /** + * 接口行为是在数据`dst`锚点所属的`dst_node`节点和其对端`src_node`节点之间插入一个`insert_node`节点, + * 默认是`insert_node`的`0`号数据输入锚点和`0`号数据输出数据锚点参与连边,`insert_node`插入之后, + * `dst_node`和`insert_node`作为一个整体与原来的`dst_node`具备等价的控制和数据关系 + * @param dst 目的数据输入锚点 + * @param insert_node 表示要插入的节点 + * @param input_index 表示插入节点的哪个数据输入锚点要跟dst的对端src锚点相连,如果不传递,默认取0 + * @param output_index 表示插入节点的哪个数据输出锚点要跟dst相连,如果不传递,默认取0 + * @return 如果插入成功返回GRAPH_SUCCESS,失败返回GRAPH_FAILED + */ + static graphStatus InsertNodeBefore(const InDataAnchorPtr &dst, const NodePtr &insert_node, + const uint32_t input_index = 0U, const uint32_t output_index = 0U); + + /** + * 接口行为是通过insert_op在图中生成一个insert_node节点, + * 在数据`dst`锚点所属的`dst_node`节点和其对端`src_node`节点之间插入一个`insert_node`节点, + * 默认是`insert_node`的`0`号数据输入锚点和`0`号数据输出数据锚点参与连边,`insert_node`插入之后, + * `dst_node`和`insert_node`作为一个整体与原来的`dst_node`具备等价的控制和数据关系 + * @param dst 目的数据输入锚点 + * @param insert_op 表示要插入的opDesc,需要用其在src_node的图上生成一个node + * @param input_index 表示插入节点的哪个数据输入锚点要跟dst的对端src锚点相连,如果不传递,默认取0 + * @param output_index 表示插入节点的哪个数据输出锚点要跟dst相连,如果不传递,默认取0 + * @return 如果插入成功返回insert_node,失败返回nullptr + */ + static NodePtr InsertNodeBefore(const InDataAnchorPtr &dst, const OpDescPtr &insert_op, + const uint32_t input_index = 0U, const uint32_t output_index = 0U); + + /** + * 从`compute_graph`智能指针管理的图对象的包含的nodes列表中删除`node`节点,仅仅是删除节点, + * 不包含对node的断边和重新连边等操作 + * @param compute_graph + * @param node + * @return 如果删除成功返回GRAPH_SUCCESS,失败返回GRAPH_FAILED + */ + 
static graphStatus RemoveJustNode(const ComputeGraphPtr compute_graph, const NodePtr &node); + + /** + * `RemoveJustNode`的批量删除接口,节点个数较多的时候,调用批量删除接口的性能比遍历`nodes`节点依次调用`RemoveJustNode`更高效一些 + * @param compute_graph + * @param nodes + * @return + */ + static graphStatus RemoveJustNodes(const ComputeGraphPtr &compute_graph, const std::unordered_set &nodes); + + /** + * 从`compute_graph`图对象的包含的nodes列表中删除`node`节点,仅仅是删除节点,不包含对`node`的断边和重新连边等操作 + * @param compute_graph + * @param node + * @return 如果删除成功返回GRAPH_SUCCESS,失败返回GRAPH_FAILED + */ + static graphStatus RemoveJustNode(ComputeGraph &compute_graph, const NodePtr &node); + + /** + * 记录`original_nodes`的原始name到node上 + * @param original_nodes + * @param node + */ + static void RecordOriginalNames(const std::vector original_nodes, const ge::NodePtr &node); + /** + * 记录`names_tmp`中的字段到node上 + * @param original_nodes + * @param node + */ + static void RecordOriginalNames(std::vector names_tmp, const ge::NodePtr &node); + /** + * 图dump接口,用于把`graph`对象序列化到文件,默认落盘到当前路径 + * @param graph + * @param suffix 用于拼接文件的名称 + * @param is_always_dump 如果值为true,则接口行为不受环境变量约束 + * @param user_graph_name 用于指定落盘的文件名和文件路径 + */ + static void DumpGEGraph(const ge::ComputeGraphPtr &graph, const std::string &suffix, + const bool is_always_dump = false, const std::string &user_graph_name = ""); + /** + * 图dump接口,用于把`graph`对象序列化到文件,落盘到`path`指定的路径 + * @param graph + * @param path 指定落盘的路径 + * @param suffix 用于拼接文件的名称 + */ + static void DumpGEGrph(const ge::ComputeGraphPtr &graph, const std::string &path, const std::string &suffix); + /** + * 图dump接口,用于把`graph`对象序列化到文件,落盘到`file_path` + * @param graph + * @param file_path 路径+文件名 + * @param dump_level DUMP_GE_GRAPH环境变量以函数入参的表达 + * @return + */ + static graphStatus DumpGEGraphByPath(const ge::ComputeGraphPtr &graph, const std::string &file_path, + const ge::DumpLevel dump_level); + static graphStatus DumpGEGraphByPath(const ge::ComputeGraphPtr &graph, const std::string &file_path, + const int64_t dump_level); + 
/** + * 从`file`文件反序列化得到`compute_graph`的图对象 + * @param file + * @param compute_graph + * @return + */ + static bool LoadGEGraph(const char_t *const file, ge::ComputeGraph &compute_graph); + /** + * 从`file`文件反序列化得到`compute_graph`的智能指针对象 + * @param file + * @param compute_graph + * @return + */ + static bool LoadGEGraph(const char_t *const file, ge::ComputeGraphPtr &compute_graph); + + /** + * 图dump接口,用于把`graph`对象按照onnx的格式序列化到文件,默认落盘到当前路径 + * @param compute_graph + * @param suffix 用于拼接文件名 + */ + static void DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &suffix); + + /** + * 图dump接口,用于把`graph`对象按照onnx的格式序列化到文件,默认落盘到当前路径 + * @param compute_graph + * @param suffix 用于拼接文件名 + * @param is_always_dump 如果值为true,则接口行为不受环境变量约束 + */ + static void DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &suffix, + bool is_always_dump); + /** + * 图dump接口,用于把`graph`对象按照onnx的格式序列化到文件,默认落盘到`path`路径 + * @param compute_graph + * @param path 路径名 + * @param suffix 拼接的文件名 + */ + static void DumpGrphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &path, const std::string &suffix); + + static bool ReadProtoFromTextFile(const char_t *const file, google::protobuf::Message *const proto); + + static void WriteProtoToTextFile(const google::protobuf::Message &proto, const char_t *const real_path); + + static graphStatus AppendInputNode(const ComputeGraphPtr &graph, const NodePtr &node); + + /** + * 孤立`node`, 根据`io_map`完成node的输入输出数据边的重连;同时会添加必要的控制边保证`node`的所有输入节点 + * 均在`node`的输出节点之前执行 + * @param node + * @param io_map 把第`io_map[i]`个输入的对端输出,连接到第`i`个输出的对端输入。因此`io_map`的元素个数应该与 + * `node`的输出锚点的个数相等,如果`io_map[i]`小于0,则仅断开第`i`个输出锚点到对端的所有连边 + * @return + */ + static graphStatus IsolateNode(const NodePtr &node, const std::initializer_list &io_map); + static graphStatus IsolateNode(const NodePtr &node, const std::vector &io_map); + /** + * `node`应该是单输入单输出的节点,接口行为等价于`IsolateNode(node, {0})` + * @param node + * @return + */ + static graphStatus 
IsolateNodeOneIO(const NodePtr &node); + + /** + * 此接口对数据关系的处理与`ReplaceNodeDataAnchors`的处理行为一致, 在此基础上, + * 复制了`old_node`的所有控制关系到`new_node`上,这也是要注意的一点: + * `数据`关系是`移动`操作,`控制`关系是`复制`操作 + * @param new_node + * @param old_node + * @param inputs_map 用于指导输入数据锚点的替换,注意元素个数不应该超过`new_node`的输入锚点总个数 + * @param outputs_map 用于指导输出锚点的替换,注意元素个数不应该超过`new_node`的输出锚点总个数 + * @return 如果替换成功返回GRAPH_SUCCESS,失败返回GRAPH_FAILED + */ + static graphStatus ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node, + const std::initializer_list inputs_map, + const std::initializer_list outputs_map); + + /** + * `ReplaceNodeAnchors`的重载接口 + * @param new_node + * @param old_node + * @param inputs_map + * @param outputs_map + * @return + */ + static graphStatus ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node, + const std::vector &inputs_map, + const std::vector &outputs_map); + + /** + * 接口行为是根据`inputs_map`和`outputs_map`把`old_node`上的数据关系`移动`到`new_node`上;具体操作是 + * 把`old_node`的第`inputs_map[i]`/`outputs_map[i]`个数据锚点的数据关系替换到`new_node`的第`i`个 + * 数据锚点上, `i`的取值范围是[0, `inputs_map`/`outputs_map`的元素个数); 如果`inputs_map[i]`/`outputs_map[i]` + * 的值小于0或者不在`old_node`的锚点索引范围之内,那么`new_node`的第`i`个数据锚点的数据关系保持原样 + * @param new_node + * @param old_node + * @param inputs_map 用于指导输入数据锚点的替换,注意元素个数不应该超过`new_node`的输入锚点总个数 + * @param outputs_map 用于指导输出锚点的替换,注意元素个数不应该超过`new_node`的输出锚点总个数 + * @return + */ + static graphStatus ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node, + const std::initializer_list inputs_map, + const std::initializer_list outputs_map); + + static graphStatus ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node, + const std::vector &inputs_map, + const std::vector &outputs_map); + + /** + * 拷贝`src_node`的控制输入边到`dst_node`上 + * @param src_node + * @param dst_node + * @return + */ + static graphStatus CopyInCtrlEdges(const NodePtr &src_node, const NodePtr &dst_node); + /** + * 移动`src_node`的控制输入边到`dst_node`上 + * @param src_node + * @param 
dst_node + * @return + */ + static graphStatus MoveInCtrlEdges(const NodePtr &src_node, const NodePtr &dst_node); + /** + * 拷贝`src_node`的控制输出边到`dst_node`上 + * @param src_node + * @param dst_node + * @return + */ + static graphStatus CopyOutCtrlEdges(const NodePtr &src_node, const NodePtr &dst_node); + /** + * 选择性拷贝`src_node`的控制输出边到`dst_node`上 + * @param src_node + * @param dst_node + * @param node_filter 控制边拷贝白名单过滤器,可以通过传递此参数实现满足条件的边的复制,不传递时代表全量拷贝 + * @return + */ + static graphStatus CopyOutCtrlEdges(const NodePtr &src_node, const NodePtr &dst_node, const NodeFilter &node_filter); + /** + * 移动`src_node`的控制输出边到`dst_node`上 + * @param src_node + * @param dst_node + * @return + */ + static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); + + /** + * 查找`graph`的根图,如果当前图就是根图或者当前图没有父图,则返回当前图 + * @param graph + * @return + */ + static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph); + /** + * 浅拷贝`graph`,并不会同步拷贝`graph`的子图 + * @param graph + * @param suffix 用于拼接克隆图中的节点名称 + * @param input_nodes 返回克隆图中的输入节点 + * @param output_nodes 返回克隆图中的输出节点 + * @return + */ + static ComputeGraphPtr CloneGraph(const ComputeGraphPtr &graph, const std::string &suffix, + std::vector &input_nodes, std::vector &output_nodes); + /** + * 拷贝`src_compute_graph`图上的attr属性到`dst_compute_graph`上, + * 需要注意的是 如果`dst_compute_graph`图上已经存在了某些同名属性,则会跳过这些属性的值的拷贝 + * @param src_compute_graph + * @param dst_compute_graph + */ + static void InheritOriginalAttr(const ComputeGraphPtr &src_compute_graph, ComputeGraphPtr &dst_compute_graph); + /** + * 拷贝`src_node`节点及其所有有效的输入输出tensor上的attr属性到`dst_desc`上 + * @param dst_desc 目的OpDesc对象 + * @param src_node 源Node对象 + * @return 拷贝成功返回GRAPH_SUCCESS, 拷贝失败返回GRAPH_FAILED + */ + static graphStatus CopyTensorAttrs(const OpDescPtr &dst_desc, const NodePtr &src_node); + + /** + * 获取当前图里面的所有的节点的输入输出tensor的复用关系 + * @param graph + * @param symbol_to_anchors + * @param anchor_to_symbol + * @return + */ + static graphStatus GetRefMapping(const ComputeGraphPtr 
&graph, SymbolToAnchors &symbol_to_anchors, + AnchorToSymbol &anchor_to_symbol); + + /// Determine if the graph is a UNKNOWN_SHAPE graph based on whether the graph and all subgraphs + /// of the graph have UNKNOWN_SHAPE operators or not. + /// Note: This function will only look 'down' from the graph, not 'up'. For example, the following + /// scenario (K for known shape, U for unknown shape), ROOT graph is UNKNOWN_SHAPE while SUB graph is KNOWN_SHAPE + /// ROOT graph: A -----> B -----> C + /// K subgraph U + /// | + /// V + /// SUB graph: D --> E --> F + /// K K K + /// @param [in] graph + /// @return bool + static bool IsUnknownShapeGraph(const ComputeGraphPtr &graph); + + static NodePtr FindNodeFromAllNodes(ComputeGraphPtr &graph, const std::string &name); + static std::vector FindNodesByTypeFromAllNodes(ComputeGraphPtr &graph, const std::string &type); + static std::vector FindBareNodesByTypeFromAllNodes(ComputeGraphPtr &graph, const char_t *const type); + /** + * 判断当前`out_data_anchor`是否复用了输入anchor的内存 + * @param out_data_anchor + * @param reuse_in_index 复用的输入anchor的index + * @return 如果存在复用关系,返回true, 否则返回false + */ + static bool IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index); + /** + * 判断当前`out_data_anchor`是否引用RefData的输出 + * @param out_data_anchor + * @param ref_data 复用的RefData节点 + * @param 如果存在复用关系,返回true, 否则返回false + * @return status + */ + static graphStatus CheckIsRefFromOther(const OutDataAnchorPtr &out_data_anchor, NodePtr &refed_node, + bool &is_ref_from_other); + /** + * 针对含有`ATTR_NAME_NOPADDING_CONTINUOUS_INPUT`和`ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT`类型的节点 + * 单独封装的复用接口 + * @param out_data_anchor + * @param reuse_in_index 出参,如果存在复用,值为0 + * @return 如果存在复用,返回true,负责返回false + */ + static bool IsNoPaddingRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index); + /** + * 用于判断`node`是否直接或者间接从属于`graph`, `间接`的一种含义是`node`的父图是`graph` + * @param graph + * @param node + * @return + */ + static bool 
IsNodeInGraphRecursively(const ComputeGraphPtr &graph, const Node &node); + + /** + * 获取所有`直接`父图为`graph`和`间接`父图为`graph`的子图对象合集 + * @param graph + * @param subgraphs 子图对象的合集 + * @return 成功返回GRAPH_SUCCESS,失败返回GRAPH_FAILED + */ + static graphStatus GetSubgraphsRecursively(const ComputeGraphPtr &graph, std::vector &subgraphs); + + /** + * 创建以`subgraph_name`拼接命名的子图对象`subgraph`,把`nodes`中的节点从`graph`中抽取出来放在`subgraph`中, + * 完成图归属和节点连边关系的重建,`nodes`作为一个整体与`subgraph`父节点等价 + * @param graph + * @param nodes + * @param subgraph_name + * @return 成功返回GRAPH_SUCCESS,失败返回GRAPH_FAILED + */ + static ComputeGraphPtr BuildSubgraphWithNodes(const ComputeGraphPtr &graph, const std::set &nodes, + const std::string &subgraph_name); + /** + * `BuildSubgraphWithNodes`的重载接口 + * @param graph + * @param nodes + * @param subgraph_name + * @return 成功返回GRAPH_SUCCESS,失败返回GRAPH_FAILED + */ + static ComputeGraphPtr BuildSubgraphWithNodes(ComputeGraph &graph, const std::set &nodes, + const std::string &subgraph_name); + /** + * 作为`BuildSubgraphWithNodes`函数的逆向操作,会把`graph`展开其父图上, + * 此接口支持子图的递归展开操作 + * @param graph 要展开的子图 + * @param filter 子图过滤器,用于过滤子图的子图是否要展开,不传递时不进行递归操作 + * @return 成功返回GRAPH_SUCCESS,失败返回GRAPH_FAILED + */ + static graphStatus UnfoldSubgraph(const ComputeGraphPtr &graph, + const std::function &filter); + /** + * 作为`UnfoldSubgraph`的高阶版本,支持没有父子关系的图的展开操作,展开`graph`到`target_graph`上, + * `graph`作为一个整体,等价替换掉`target_graph`中的`target_node`; + * 此接口支持子图的递归展开操作 + * @param graph 要展开的子图 + * @param target_graph 展开到的目标图 + * @param target_node 子图要替换的目标节点 + * @param filter 图过滤器,用于过滤子图的子图是否要展开,不传递时不进行递归操作 + * @param depth + * @return 成功返回GRAPH_SUCCESS,失败返回GRAPH_FAILED + */ + static graphStatus UnfoldGraph(const ComputeGraphPtr &graph, const ComputeGraphPtr &target_graph, + const NodePtr &target_node, const function &filter, + int32_t depth = 0); + + static bool IsSingleOpScene(const ComputeGraphPtr &graph); + + static CycleDetectorPtr CreateCycleDetector(const ComputeGraphPtr &graph); + + static 
CycleDetectorSharedPtr CreateSharedCycleDetector(const ComputeGraphPtr &graph); + + /** + * 将node所有的输入、输出边断开,并移动到dst_graph + * @param dst_graph 目的Graph, + * @param node 需要移动的Node + * @return 成功时,返回ge::GRAPH_SUCCESS + */ + static graphStatus MoveNodeToGraph(const NodePtr &node, ComputeGraph &dst_graph); + + /** + * 判断当前节点的某个输出是否可以复用输入 + * @param node 目标节点, + * @param out_index_to_refable_in_indexes 返回节点的输出-输入(单输出多输入)复用关系 + * @return 成功时,返回SUCCESS + */ + static graphStatus GetSupportInplaceOutput(const NodePtr &node, + std::map> &out_index_to_refable_in_indexes); + + /** + * 用expand_graph图替换target_node节点,先将expand_graph中的算子插入到图中,再建立连边关系,最后删掉target_node + * @param target_node 目标节点 + * @param expand_graph 需要被展开的图 + * @return 成功时,返回SUCCESS 失败返回FAILED + */ + static graphStatus ExpandNodeWithGraph(const NodePtr &target_node, const ComputeGraphPtr &expand_graph); + + private: + class GraphInfo { + public: + GraphInfo() = default; + ~GraphInfo() = default; + + private: + std::set nodes_; + std::map>> data_inputs_; + std::map>> data_outputs_; + std::list> ctrl_inputs_; + std::list> ctrl_outputs_; + std::list> inner_data_edges_; + std::list> inner_ctrl_edges_; + friend class GraphUtils; + }; + + /** + * 将src_graph中所有节点插入到target_graph中,插入节点在列表中的位置在target_node之后 + * @param target_graph 需要被插入算子的图 + * @param target_node 新插入的算子需要被插在target_node的后面,且target_node需要在target_graph上 + * @param insert_nodes 表示要插入的节点 + * @return 如果插入成功返回GRAPH_SUCCESS,失败返回GRAPH_FAILED + */ + static graphStatus MoveNodesToGraphAfterTargetNode(const ComputeGraphPtr &target_graph, + const NodePtr &target_node, + const ComputeGraphPtr &src_graph); + /** + * 创建节点输入tensor的内存符号,正常情况下,应该跟其对端的输出tensor的内存符号相同,因为二者是一块地址 + * @param graph + * @param node + * @param symbol_to_anchors + * @param anchor_to_symbol + * @return + */ + static graphStatus HandleInAnchorMapping(const ComputeGraphPtr &graph, const NodePtr &node, + SymbolToAnchors &symbol_to_anchors, AnchorToSymbol &anchor_to_symbol); + + /** + * 创建节点输出tensor的内存符号 + 
* @param node + * @param symbol_to_anchors + * @param anchor_to_symbol + * @return + */ + static graphStatus HandleOutAnchorMapping(const NodePtr &node, SymbolToAnchors &symbol_to_anchors, + AnchorToSymbol &anchor_to_symbol); + + /** + * 创建子图内Data节点的输出tensor的内存符号,正常应该跟父节点对应的输入tensor内存符号相同 + * @param node + * @param symbol_to_anchors + * @param anchor_to_symbol + * @return + */ + static graphStatus HandleSubgraphInput(const NodePtr &node, SymbolToAnchors &symbol_to_anchors, + AnchorToSymbol &anchor_to_symbol); + + /** + * 创建merge节点的内存符号,merge是特殊的v1控制算子,其输出和多个输入存在复用关系,因此单独封装函数处理 + * @param node + * @param symbol_to_anchors + * @param anchor_to_symbol + * @return + */ + static graphStatus HandleMergeInput(const NodePtr &node, SymbolToAnchors &symbol_to_anchors, + AnchorToSymbol &anchor_to_symbol); + + /** + * 创建子图内Netoutput节点的输出tensor的内存符号,正常应该跟父节点对应的输出tensor内存符号相同 + * @param node + * @param symbol_to_anchors + * @param anchor_to_symbol + * @return + */ + static graphStatus HandleSubgraphOutput(const NodePtr &node, SymbolToAnchors &symbol_to_anchors, + AnchorToSymbol &anchor_to_symbol); + + /** + * 合并代表同一块地址的不同符号 + * @param exist_node_info1 + * @param exist_node_info2 + * @param symbol_to_anchors + * @param anchor_to_symbol + * @param symbol + * @return + */ + static graphStatus UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, + SymbolToAnchors &symbol_to_anchors, AnchorToSymbol &anchor_to_symbol, + std::string &symbol); + + /** + * 对于同一块地址,使用已有tensor的符号设置当前tensor的符号 + * @param cur_node_info + * @param exist_node_info + * @param symbol_to_anchors + * @param anchor_to_symbol + * @return + */ + static graphStatus UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info, + SymbolToAnchors &symbol_to_anchors, AnchorToSymbol &anchor_to_symbol); + + /// Relink all edges for cloned ComputeGraph. + /// @param [in] node: original node. + /// @param [in] suffix: node name suffix of new node. 
+ /// @param [in] all_nodes: all nodes in new graph. + /// @return success: GRAPH_SUCESS + static graphStatus RelinkGraphEdges(const NodePtr &node, const std::string &suffix, + const std::unordered_map &all_nodes); + static void BuildGraphInfoFromNodes(const std::set &nodes, GraphInfo &graph_info); + + static void BuildInDataEdgesFromNode(const NodePtr &node, const std::set &nodes, + std::map &data_input_index_map, GraphInfo &graph_info); + + static NodePtr BuildSubgraphNode(ComputeGraph &graph, const std::string &graph_name, const GraphInfo &graph_info); + + static ComputeGraphPtr BuildSubgraph(const NodePtr &subgraph_node, const GraphInfo &graph_info, + const std::string &subgraph_name); + + static graphStatus RelinkDataEdges(const NodePtr &subgraph_node, const GraphInfo &graph_info); + + static graphStatus RelinkCtrlEdges(const NodePtr &subgraph_node, const GraphInfo &graph_info); + + static graphStatus MergeInputNodes(const ComputeGraphPtr &graph, const NodePtr &target_node); + + static graphStatus MergeNetOutputNode(const ComputeGraphPtr &graph, const NodePtr &target_node); + + static bool NoNeedDumpGraphBySuffix(const std::string &suffix); + + static graphStatus CopyOpAndSubgraph(const ComputeGraphPtr &src_compute_graph, const NodeFilter &node_filter, + const GraphFilter &graph_filter, const AttrFilter &attr_filter, + ComputeGraphPtr &dst_compute_graph, + std::map &node_old_2_new, + std::map &op_desc_old_2_new, + std::unordered_map &all_new_nodes, const int32_t depth); + + static graphStatus CopyMembers(const ComputeGraphPtr &src_compute_graph, ComputeGraphPtr &dst_compute_graph, + const std::unordered_map &all_new_nodes); + + static graphStatus CopyGraphImpl(const Graph &src_graph, Graph &dst_graph, + const std::map &node_old_2_new, + const std::map &op_desc_old_2_new); + + /** + * 改图接口,将图中的“FileConstant”节点用属性中对应的权重文件恢复成“Const”节点 + * @param compute_graph + * @return + */ + static graphStatus ConvertFileConstToConst(const ComputeGraphPtr &graph); + static 
graphStatus RecoverConstByWeightFile(const OpDescPtr &op_desc, const GeTensorPtr &weight); +}; + +/** + * 黑匣子图dump接口,用于把`compute_graph`对象序列化到文件,默认落盘到当前路径; + * 如果`compute_graph`挂载了子图对象,子图对象也尝试进行落盘 + * 黑匣子图的落盘行为不受`DUMP_GE_GRAPH`和`DUMP_GRAPH_LEVEL`环境变量的控制,用于异常场景保留现场 + * @param compute_graph + * @param name 用于拼接文件的名称 + */ +inline void GE_DUMP_BLACK_BOX(const ComputeGraphPtr &compute_graph, const std::string &name) { + GraphUtils::DumpGEGraph((compute_graph), (name), true); + GraphUtils::DumpGEGraphToOnnx(*(compute_graph), (name), true); + uint64_t i = 0U; + for (const auto &sub_graph_func: (compute_graph)->GetAllSubgraphs()) { + const auto sub_graph_func_name = std::string(name) + std::string("_sub_graph_") + std::to_string(i++); + GraphUtils::DumpGEGraph(sub_graph_func, sub_graph_func_name, true); + GraphUtils::DumpGEGraphToOnnx(*sub_graph_func, sub_graph_func_name, true); + } +} + +class ComputeGraphBuilder { + public: + ComputeGraphBuilder() : owner_graph_(nullptr) {} + virtual ~ComputeGraphBuilder() = default; + + /// @brief Add node to graph + /// @param [in] op_desc + /// @return ComputeGraphBuilder + virtual ComputeGraphBuilder &AddNode(const OpDescPtr &op_desc); + + /// @brief Add data-link among nodes in graph + /// @param [in] src_name + /// @param [in] out_anchor_ind + /// @param [in] dst_name + /// @param [in] in_anchor_ind + /// @return ComputeGraphBuilder + virtual ComputeGraphBuilder &AddDataLink(const std::string &src_name, const uint32_t out_anchor_ind, + const std::string &dst_name, const uint32_t in_anchor_ind); + + /// @brief Add ctrl-link among nodes in graph + /// @param [in] src_name + /// @param [in] dst_name + /// @return ComputeGraphBuilder + virtual ComputeGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name); + + /// @brief Build graph + /// @param [out] error_code + /// @param [out] error_msg + /// @return ComputeGraphPtr + virtual ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) = 0; + 
+ /// @brief Get node with name + /// @param [in] name + /// @return NodePtr + NodePtr GetNode(const std::string &name); + + /// @brief Get all nodes + /// @return std::vector + std::vector GetAllNodes(); + + protected: + /// @brief Build nodes + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + void BuildNodes(graphStatus &error_code, std::string &error_msg); + + /// @brief Build data-links + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + void BuildDataLinks(graphStatus &error_code, std::string &error_msg); + + /// @brief Build ctrl-links + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + void BuildCtrlLinks(graphStatus &error_code, std::string &error_msg); + + private: + ComputeGraphBuilder(const ComputeGraphBuilder &) = delete; + ComputeGraphBuilder &operator=(const ComputeGraphBuilder &) = delete; + ComputeGraphBuilder(const ComputeGraphBuilder &&) = delete; + ComputeGraphBuilder &operator=(const ComputeGraphBuilder &&) = delete; + + ComputeGraphPtr owner_graph_; + // node_name -> node + std::map node_names_; + std::vector nodes_; + // -> + std::vector, std::pair>> data_links_; + // src_node_name -> dst_node_name + std::vector> ctrl_links_; + + friend class CompleteGraphBuilder; + friend class PartialGraphBuilder; +}; + +class CompleteGraphBuilder : public ComputeGraphBuilder { + public: + explicit CompleteGraphBuilder(const std::string name, const bool retval_flag = true) + : ComputeGraphBuilder(), name_(name), parent_node_(nullptr), retval_flag_(retval_flag) {} + CompleteGraphBuilder(const CompleteGraphBuilder &) = delete; + CompleteGraphBuilder &operator=(const CompleteGraphBuilder &) = delete; + CompleteGraphBuilder(const CompleteGraphBuilder &&) = delete; + CompleteGraphBuilder &operator=(const CompleteGraphBuilder &&) = delete; + ~CompleteGraphBuilder() = default; + + /// @brief Add node to graph + /// @param [in] op_desc + /// @return CompleteGraphBuilder + 
CompleteGraphBuilder &AddNode(const OpDescPtr &op_desc) override; + + /// @brief Add data-link among nodes in graph + /// @param [in] src_name + /// @param [in] out_anchor_ind + /// @param [in] dst_name + /// @param [in] in_anchor_ind + /// @return CompleteGraphBuilder + CompleteGraphBuilder &AddDataLink(const std::string &src_name, const uint32_t out_anchor_ind, + const std::string &dst_name, const uint32_t in_anchor_ind) override; + + /// @brief Add ctrl-link among nodes in graph + /// @param [in] src_name + /// @param [in] dst_name + /// @return CompleteGraphBuilder + CompleteGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override; + + /// @brief Set index_th input anchor for graph + /// @param [in] index + /// @param [in] node_names + /// @param [in] anchor_inds + /// @return CompleteGraphBuilder + CompleteGraphBuilder &SetInput(const uint32_t index, const std::vector &node_names, + const std::vector &anchor_inds); + + /// @brief Set index_th input of graph as useless + /// @param [in] index + /// @return CompleteGraphBuilder + CompleteGraphBuilder &SetUselessInput(const uint32_t index); + + /// @brief Add output anchor for graph + /// @param [in] owner_node_name + /// @param [in] anchor_ind + /// @return CompleteGraphBuilder + CompleteGraphBuilder &AddOutput(const std::string &owner_node_name, uint32_t anchor_ind); + + /// @brief Add target for graph + /// @param [in] target_name + /// @return CompleteGraphBuilder + CompleteGraphBuilder &AddTarget(const std::string &target_name); + + /// @brief Set parent-node of graph + /// @param [in] parent_node + /// @return CompleteGraphBuilder + CompleteGraphBuilder &SetParentNode(const NodePtr &parent_node); + + /// @brief Set mapping-relation of parent-node in_anchor_ind & Data-node + /// @param [in] input_mapping: index_of_graph_input -> in_anchor_index_of_parent_node + /// @return CompleteGraphBuilder + CompleteGraphBuilder &SetInputMapping(const std::map &input_mapping); + + 
/// @brief Set mapping-relation of parent-node out_anchor_ind & NetOutput-node out_anchor_ind + /// @param [in] output_mapping: index_of_graph_output -> out_anchor_index_of_parent_node + /// @return CompleteGraphBuilder + CompleteGraphBuilder &SetOutputMapping(const std::map &output_mapping); + + /// @brief Build graph + /// @param [out] error_code + /// @param [out] error_msg + /// @return ComputeGraphPtr + ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override; + + private: + /// @brief Add data nodes + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + void AddDataNodes(graphStatus &error_code, std::string &error_msg); + + /// @brief Add data node + /// @param [in] index + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + NodePtr AddDataNode(const uint32_t index, graphStatus &error_code, std::string &error_msg); + + /// @brief Add RetVal nodes + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + void AddRetValNodes(graphStatus &error_code, std::string &error_msg); + + /// @brief Build target-nodes for graph + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + void BuildGraphTargets(graphStatus &error_code, std::string &error_msg); + + /// @brief Add NetOutput node + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + void AddNetOutputNode(graphStatus &error_code, std::string &error_msg); + + /// @brief Build NetOutput nodes with data & ctrl edges + /// @param [in] net_output_desc + /// @param [in] peer_out_anchors + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + void BuildNetOutputNodeWithLink(const OpDescPtr &net_output_desc, + const std::vector &peer_out_anchors, graphStatus &error_code, + std::string &error_msg); + + /// @brief process after build + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + void PostProcess(graphStatus &error_code, 
std::string &error_msg); + + std::string name_; + NodePtr parent_node_; + bool retval_flag_; + std::map, std::vector>> graph_inputs_; + std::vector> graph_outputs_; + std::vector graph_targets_; + + // index_of_graph_input -> in_anchor_index_of_parent_node + std::map input_mapping_; + // index_of_graph_output -> out_anchor_index_of_parent_node + std::map output_mapping_; +}; + +class PartialGraphBuilder : public ComputeGraphBuilder { + public: + PartialGraphBuilder() = default; + PartialGraphBuilder(const PartialGraphBuilder &) = delete; + PartialGraphBuilder &operator=(const PartialGraphBuilder &) = delete; + PartialGraphBuilder(const PartialGraphBuilder &&) = delete; + PartialGraphBuilder &operator=(const PartialGraphBuilder &&) = delete; + ~PartialGraphBuilder() = default; + + /// @brief Add node to graph + /// @param [in] op_desc + /// @return PartialGraphBuilder + PartialGraphBuilder &AddNode(const OpDescPtr &op_desc) override; + + /// @brief Add data-link among nodes in graph + /// @param [in] src_name + /// @param [in] out_anchor_ind + /// @param [in] dst_name + /// @param [in] in_anchor_ind + /// @return PartialGraphBuilder + PartialGraphBuilder &AddDataLink(const std::string &src_name, const uint32_t out_anchor_ind, + const std::string &dst_name, const uint32_t in_anchor_ind) override; + + /// @brief Add ctrl-link among nodes in graph + /// @param [in] src_name + /// @param [in] dst_name + /// @return PartialGraphBuilder + PartialGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override; + + /// @brief Set owner graph + /// @param [in] graph + /// @return PartialGraphBuilder + PartialGraphBuilder &SetOwnerGraph(const ComputeGraphPtr &graph); + + /// @brief Add exist node + /// @param [in] node + /// @return PartialGraphBuilder + PartialGraphBuilder &AddExistNode(const NodePtr &exist_node); + + /// @brief Build multi nodes with links + /// @param [out] error_code + /// @param [out] error_msg + /// @return 
ComputeGraphPtr + ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override; + + private: + /// @brief Build exist nodes + /// @param [out] error_code + /// @param [out] error_msg + /// @return void + void BuildExistNodes(graphStatus &error_code, std::string &error_msg); + + std::vector exist_nodes_; +}; +} // namespace ge + +#endif // INC_GRAPH_UTILS_GRAPH_UTILS_H_ diff --git a/inc/metadef/graph/utils/graph_utils_ex.h b/inc/metadef/graph/utils/graph_utils_ex.h new file mode 100644 index 0000000000000000000000000000000000000000..670abcc20d56c9d75fe874fd70694b6dde05608b --- /dev/null +++ b/inc/metadef/graph/utils/graph_utils_ex.h @@ -0,0 +1,57 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef __INC_METADEF_GRAPH_UTILS_EX_H +#define __INC_METADEF_GRAPH_UTILS_EX_H + +#include "graph/node.h" +#include "graph/compute_graph.h" +#include "external/graph/graph.h" + +namespace ge { +class GraphUtilsEx { + public: + // Detach from ComputeGraph + static graphStatus InferOriginFormat(const ComputeGraphPtr &graph); + static graphStatus InferShapeInNeed(const ComputeGraphPtr &graph); + + // Detach from GraphUtils + static ComputeGraphPtr GetComputeGraph(const Graph &graph); + static ComputeGraphPtr CreateGraphFromOperator(const std::string &name, const std::vector &inputs); + + /** + * 使用ops中的算子为graph对象构造图,且构造出来的图中的算子需要按照ops中的顺序排序 + * @param graph 需要构造图的graph对象 + * @param ops 用于生成计算图的算子 + * @return 计算图指针,成功时,返回生成的ComputeGraph指针 失败返回nullptr + */ + static graphStatus CreateGraphFromOperatorWithStableTopo(Graph &graph, + const std::vector &ops); + /** + * 使用ops中的算子构造计算图,且构造出来的图中的算子需要按照ops中的顺序排序 + * @param graph 需要构造图的graph对象 + * @param ops 用于生成计算图的算子 + * @return 计算图指针,成功时,返回生成的ComputeGraph指针 失败返回nullptr + */ + static ComputeGraphPtr CreateComputeGraphFromOperatorWithStableTopo(const std::string &name, + const std::vector &ops); + + static Graph CreateGraphFromComputeGraph(const ComputeGraphPtr compute_graph); + static GraphPtr CreateGraphPtrFromComputeGraph(const ComputeGraphPtr compute_graph); + static void BreakConnect(const std::map &all_nodes_infos); + static graphStatus RecoverGraphOperators(const Graph &graph); + static graphStatus CopyGraph(const Graph &src_graph, Graph &dst_graph); + + private: + static graphStatus CopyGraphImpl(const Graph &src_graph, Graph &dst_graph, + const std::map &node_old_2_new, + const std::map &op_desc_old_2_new); +}; +} // namespace ge +#endif // __INC_METADEF_GRAPH_UTILS_EX_H diff --git a/inc/metadef/graph/utils/hash_utils.h b/inc/metadef/graph/utils/hash_utils.h new file mode 100644 index 
0000000000000000000000000000000000000000..9430d6ef3bd3d51d6f41ffd00824ee452f96fbc6 --- /dev/null +++ b/inc/metadef/graph/utils/hash_utils.h @@ -0,0 +1,71 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef GRAPH_COMPILE_CACHE_POLICY_HASH_UTILS_H_ +#define GRAPH_COMPILE_CACHE_POLICY_HASH_UTILS_H_ + +#include +#include +#include +#include "graph/small_vector.h" +#include "graph/types.h" + +namespace ge { +using CacheHashKey = uint64_t; +class HashUtils { /* Static helpers that fold one or more values into a single CacheHashKey (uint64_t). */ +public: + static constexpr CacheHashKey HASH_SEED = 0x7863a7deUL; /* value returned by the zero-argument MultiHash() */ + static constexpr CacheHashKey COMBINE_KEY = 0x9e3779b9UL; /* constant added in every HashCombine mixing step */ + /* Generic overload: returns seed after seed ^= hasher(value) + COMBINE_KEY + (seed << 6) + (seed >> 2). */ + template + static inline CacheHashKey HashCombine(CacheHashKey seed, const T &value) { + const std::hash hasher; + seed ^= hasher(value) + COMBINE_KEY + (seed << 6U) + (seed >> 2U); + return seed; + } + /* Format overload: same mixing step, applied to hasher(static_cast(value)); the cast target type is not visible in this text. */ + static inline CacheHashKey HashCombine(CacheHashKey seed, const Format &value) { + const std::hash hasher; + seed ^= hasher(static_cast(value)) + COMBINE_KEY + (seed << 6U) + (seed >> 2U); + return seed; + } + /* DataType overload: same mixing step, applied to hasher(static_cast(value)); the cast target type is not visible in this text. */ + static inline CacheHashKey HashCombine(CacheHashKey seed, const DataType &value) { + const std::hash hasher; + seed ^= hasher(static_cast(value)) + COMBINE_KEY + (seed << 6U) + (seed >> 2U); + return seed; + } + /* SmallVector overload: folds every element into seed, in iteration order, through HashCombine. */ + template + static inline CacheHashKey HashCombine(CacheHashKey seed, const SmallVector
&values) { + for (const auto &val : values) { + seed = HashCombine(seed, val); + } + return seed; + } + /* std::vector overload: folds every element into seed, in iteration order, through HashCombine. */ + template + static inline CacheHashKey HashCombine(CacheHashKey seed, const std::vector &values) { + for (const auto &val : values) { + seed = HashCombine(seed, val); + } + return seed; + } + /* Recursion base case: with no arguments the hash is HASH_SEED. */ + static inline CacheHashKey MultiHash() { + return HASH_SEED; + } + /* Hashes the trailing args first (recursively, ending at HASH_SEED) and combines `value` last, so the first argument is the last one mixed in. */ + template + static inline CacheHashKey MultiHash(const T &value, const M... args) { + return HashCombine(MultiHash(args...), value); + } +}; +} // namespace ge +#endif diff --git a/inc/metadef/graph/utils/math_util.h b/inc/metadef/graph/utils/math_util.h new file mode 100644 index 0000000000000000000000000000000000000000..24da966fe3f2562190c0af040911c2ab79618643 --- /dev/null +++ b/inc/metadef/graph/utils/math_util.h @@ -0,0 +1,178 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_GRAPH_UTILS_MATH_UTIL_H_ +#define METADEF_CXX_INC_GRAPH_UTILS_MATH_UTIL_H_ + +#include +#include +#include +#include +#include +#include + +#include "graph/def_types.h" +#include "external/utils/extern_math_util.h" +#include "common/ge_common/debug/log.h" + +namespace ge { +constexpr uint32_t kDiv16RightShiftBits = 4U; /* shift used by CeilDiv16 */ +constexpr uint32_t kDiv32RightShiftBits = 5U; /* shift used by CeilDiv32 */ +/** + * @ingroup domi_calibration + * @brief Initializes an input array to a specified value + * @param [in] n array initialization length + * @param [in] alpha initialization value + * @param [out] output array to be initialized + * @return Status: SUCCESS after every element has been assigned; PARAM_INVALID when a memset_s call fails (a null output is handled by GE_CHECK_NOTNULL, defined elsewhere) + */ +template +Status NnSet(const int32_t n, const Dtype alpha, Dtype *const output) { + GE_CHECK_NOTNULL(output); + /* alpha == 0 path: bulk-clear with memset_s first, in chunks when the byte size reaches SECUREC_MEM_MAX_LEN */ + if (std::equal_to()(alpha, static_cast(0))) { + if ((sizeof(Dtype) * static_cast(n)) < SECUREC_MEM_MAX_LEN) { + const errno_t err = + memset_s(output, sizeof(Dtype) * static_cast(n), 0, sizeof(Dtype) * static_cast(n)); + GE_CHK_BOOL_RET_STATUS(err == EOK, PARAM_INVALID, "memset_s err"); + } else { + const uint64_t size = static_cast(sizeof(Dtype) * static_cast(n)); + const uint64_t step = SECUREC_MEM_MAX_LEN - (SECUREC_MEM_MAX_LEN % sizeof(Dtype)); /* largest multiple of sizeof(Dtype) not above SECUREC_MEM_MAX_LEN */ + const uint64_t times = size / step; + const uint64_t remainder = size % step; + uint64_t i = 0U; + while (i < times) { /* NOTE(review): the offset below is i * (step / sizeof(Dtype)), an element count, added to PtrToValue(output), presumably a byte address - confirm; the assignment loop at the end writes every element either way */ + const errno_t err = memset_s(ValueToPtr(PtrToValue(output) + (i * (step / sizeof(Dtype)))), step, 0, step); + GE_CHK_BOOL_RET_STATUS(err == EOK, PARAM_INVALID, "memset_s err"); + i++; + } + if (remainder != 0U) { + const errno_t err = + memset_s(ValueToPtr(PtrToValue(output) + (i * (step / sizeof(Dtype)))), remainder, 0, remainder); + GE_CHK_BOOL_RET_STATUS(err == EOK, PARAM_INVALID, "memset_s err"); + } + } + } + /* runs for every alpha, including 0, so each of the n elements is assigned here even after the memset_s path */ + for (int32_t i = 0; i < n; ++i) { + output[i] = alpha; + } + return SUCCESS; +} + /* RoundUpOverflow: rounds `value` up to a multiple of `multiple_of` and stores the result in `ret`. Returns true on failure: multiple_of == 0 (ret is set to 0), value failing IntegerChecker::Compat, or whatever AddOverflow returns (presumably true on overflow - confirm); returns false when value is already a multiple and was stored. */ +template +bool RoundUpOverflow(T
value, T multiple_of, TR &ret) { + if (multiple_of == 0) { + ret = 0; + return true; + } + auto remainder = value % multiple_of; + if (remainder == 0) { + if (!IntegerChecker::Compat(value)) { + return true; + } + ret = static_cast(value); + return false; + } + return AddOverflow(value - remainder, multiple_of, ret); +} /* CeilDiv16: (n >> 4) plus 1 when any of the low four bits of n is set. */ +template +T CeilDiv16(const T n) { + if (n & 0xF) { + return (n >> kDiv16RightShiftBits) + 1; + } else { + return n >> kDiv16RightShiftBits; + } +} + /* CeilDiv32: (n >> 5) plus 1 when any of the low five bits of n is set. */ +template +T CeilDiv32(const T n) { + if (n & 0x1F) { + return (n >> kDiv32RightShiftBits) + 1; + } else { + return n >> kDiv32RightShiftBits; + } +} + /* CeilDiv: returns 0 when n1 == 0, returns n1 unchanged when n2 == 0, otherwise ((n1 - 1) / n2) + 1. */ +template +T CeilDiv(const T n1, const T n2) { + if (n1 == 0) { + return 0; + } + return (n2 != 0) ? (((n1 - 1) / n2) + 1) : n1; +} + /* FloorDiv: returns u_value unchanged when d_value == 0, otherwise u_value / d_value. */ +template +T FloorDiv(const T u_value, const T d_value) { + if (d_value == 0) { + return u_value; + } + return u_value / d_value; +} + /* RoundUp: uint64_t wrapper over RoundUpOverflow; returns 0 when RoundUpOverflow reports failure, so 0 is also the result for multiple_of == 0. */ +inline uint64_t RoundUp(const uint64_t origin_value, const uint64_t multiple_of) { + uint64_t ret = 0U; + if (RoundUpOverflow(origin_value, multiple_of, ret)) { + return 0U; + } + return ret; +} + /* GeMemcpy: copies src_size bytes from src_ptr to dst_ptr in chunks of at most SECUREC_MEM_MAX_LEN using memcpy_s. Returns PARAM_INVALID when dst_size < src_size or when a memcpy_s call fails, SUCCESS otherwise. NOTE(review): the do/while body runs at least once, so src_size == 0 still issues one memcpy_s call with size 0 - confirm memcpy_s accepts that. */ +inline Status GeMemcpy(uint8_t *dst_ptr, size_t dst_size, const uint8_t *src_ptr, const size_t src_size) { + GE_CHK_BOOL_RET_STATUS((dst_size >= src_size), PARAM_INVALID, "memcpy_s verify fail, src size %zu, dst size %zu", + src_size, dst_size); + size_t offset = 0U; + size_t remain_size = src_size; + do { + size_t copy_size = (remain_size > SECUREC_MEM_MAX_LEN) ?
SECUREC_MEM_MAX_LEN : remain_size; + const auto err = memcpy_s((dst_ptr + offset), copy_size, (src_ptr + offset), copy_size); /* destination capacity passed to memcpy_s is copy_size, not the remaining dst_size; the up-front dst_size >= src_size check covers the bound */ + GE_CHK_BOOL_RET_STATUS(err == EOK, PARAM_INVALID, + "memcpy_s err, src ptr %p size %zu, dst ptr %p size %zu, offset %zu, err %d", + src_ptr, copy_size, dst_ptr, (dst_size - offset), offset, err); /* NOTE(review): the "src ... size" field logs the current chunk size (copy_size), not src_size */ + offset += copy_size; + remain_size -= copy_size; + } while (remain_size > 0U); + return SUCCESS; +} +} // end namespace ge + /* REQUIRE_COMPAT(T, v): when IntegerChecker::Compat((v)) fails, logs the value and the numeric_limits range through GELOGE and returns ge::FAILED from the enclosing function. */ +#define REQUIRE_COMPAT(T, v) \ + do { \ + if (!IntegerChecker::Compat((v))) { \ + std::stringstream ss; \ + ss << #v << " value " << (v) << " out of " << #T << " range [" << std::numeric_limits::min() << "," \ + << std::numeric_limits::max() << "]"; \ + GELOGE(ge::FAILED, "%s", ss.str().c_str()); \ + return ge::FAILED; \ + } \ + } while (false) + /* REQUIRE_COMPAT_FOR_CHAR(T, v): same check, log and return as REQUIRE_COMPAT, but the value and limits are cast before streaming; presumably so 8-bit types print as numbers rather than characters (the cast target is not visible in this text). */ +#define REQUIRE_COMPAT_FOR_CHAR(T, v) \ + do { \ + if (!IntegerChecker::Compat((v))) { \ + std::stringstream ss; \ + ss << #v << " value " << static_cast(v) << " out of " << #T << " range [" \ + << static_cast(std::numeric_limits::min()) << "," \ + << static_cast(std::numeric_limits::max()) << "]"; \ + GELOGE(ge::FAILED, "%s", ss.str().c_str()); \ + return ge::FAILED; \ + } \ + } while (false) + /* Per-type shorthands: the 8-bit variants use REQUIRE_COMPAT_FOR_CHAR, all others use REQUIRE_COMPAT. */ +#define REQUIRE_COMPAT_INT8(v) REQUIRE_COMPAT_FOR_CHAR(int8_t, (v)) +#define REQUIRE_COMPAT_UINT8(v) REQUIRE_COMPAT_FOR_CHAR(uint8_t, (v)) +#define REQUIRE_COMPAT_INT16(v) REQUIRE_COMPAT(int16_t, (v)) +#define REQUIRE_COMPAT_UINT16(v) REQUIRE_COMPAT(uint16_t, (v)) +#define REQUIRE_COMPAT_INT32(v) REQUIRE_COMPAT(int32_t, (v)) +#define REQUIRE_COMPAT_UINT32(v) REQUIRE_COMPAT(uint32_t, (v)) +#define REQUIRE_COMPAT_INT64(v) REQUIRE_COMPAT(int64_t, (v)) +#define REQUIRE_COMPAT_UINT64(v) REQUIRE_COMPAT(uint64_t, (v)) +#define REQUIRE_COMPAT_SIZE_T(v) REQUIRE_COMPAT(size_t, (v)) + +#endif // METADEF_CXX_INC_GRAPH_UTILS_MATH_UTIL_H_ diff --git a/inc/metadef/graph/utils/node_adapter.h b/inc/metadef/graph/utils/node_adapter.h new file mode 100644 index
0000000000000000000000000000000000000000..9b8a60c4e9792be319391e3871cffbc9f2fe94c7 --- /dev/null +++ b/inc/metadef/graph/utils/node_adapter.h @@ -0,0 +1,24 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_UTILS_NODE_ADAPTER_H_ +#define INC_GRAPH_UTILS_NODE_ADAPTER_H_ + +#include "graph/gnode.h" +#include "graph/node.h" + +namespace ge { /* NodeAdapter: static conversions between NodePtr and GNode / GNodePtr. Only declarations are visible here; ownership and null handling are defined by the implementation. */ +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodeAdapter { + public: + static GNode Node2GNode(const NodePtr &node); /* NodePtr to GNode, by value */ + static NodePtr GNode2Node(const GNode &node); /* GNode to NodePtr */ + static GNodePtr Node2GNodePtr(const NodePtr &node); /* NodePtr to GNodePtr */ +}; +} // namespace ge +#endif // INC_GRAPH_UTILS_NODE_ADAPTER_H_ diff --git a/inc/metadef/graph/utils/node_utils.h b/inc/metadef/graph/utils/node_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..df901e8db6117a7322c87e94ee24dee57a0c32a4 --- /dev/null +++ b/inc/metadef/graph/utils/node_utils.h @@ -0,0 +1,272 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_UTILS_NODE_UTILS_H_ +#define INC_GRAPH_UTILS_NODE_UTILS_H_ + +#include +#include +#include +#include +#include "external/graph/operator.h" +#include "external/graph/types.h" +#include "graph/anchor.h" +#include "graph/node.h" +#include "graph/compute_graph.h" + +/*lint -e148*/ +namespace ge { +// Op types of Const-like ops. +extern const std::set kConstOpTypes; + +// Op types of Enter-like ops. +extern const std::set kEnterOpTypes; +// Op types of Merge-like ops. +extern const std::set kMergeOpTypes; +// Op types of Switch-like ops. +extern const std::set kSwitchOpTypes; +// Op types of NextIteration-like ops. +extern const std::set kNextIterationOpTypes; +// Op types of Exit-like ops. +extern const std::set kExitOpTypes; + +// Op types of If-like ops. +extern const std::set kIfOpTypes; +// Op types of While-like ops. +extern const std::set kWhileOpTypes; +// Op types of Case-like ops. +extern const std::set kCaseOpTypes; +// Op types of For-like ops.
+extern const std::set kForOpTypes; + +class NodeUtils { /* Static-only helpers operating on Node / NodePtr: anchors, subgraphs, control nodes, weights. Only declarations are visible in this header. */ + public: + static graphStatus ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor); + static graphStatus SetAllAnchorStatus(const NodePtr &node_ptr); + static graphStatus SetAllAnchorStatus(Node &node); + static bool IsAnchorStatusSet(const NodePtr &node_ptr); + static bool IsAnchorStatusSet(const Node &node); + + static graphStatus MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node); + + static void UpdateIsInputConst(const NodePtr &node_ptr); + static void UpdateIsInputConst(Node &node); + static bool IsConst(const Node &node); + static void UnlinkAll(const Node &node); + + static bool ClearInputDesc(const OpDescPtr &op_desc, const uint32_t index); + static bool ClearOutputDesc(const OpDescPtr &op_desc, const uint32_t index); + + static graphStatus AppendInputAnchor(const NodePtr &node, const uint32_t num); + static graphStatus RemoveInputAnchor(const NodePtr &node, const uint32_t num); + + static graphStatus AppendOutputAnchor(const NodePtr &node, const uint32_t num); + static graphStatus RemoveOutputAnchor(const NodePtr &node, const uint32_t num); + + static GeTensorDesc GetOutputDesc(const Node &node, const uint32_t index); + // Check whether the node has an unknown shape. If the node shape contains -1 or -2, out param "is_unknow" is set to true; + // for a func op the subgraphs are checked as well: if the shape of any subgraph node contains -1 or -2, + // out param "is_unknow" is set to true too + static graphStatus GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow); + + static std::string GetNodeType(const Node &node); + static std::string GetNodeType(const NodePtr &node); + + static graphStatus GetDirectSubgraphs(const NodePtr &node, std::vector &subgraphs); + static ComputeGraphPtr GetSubgraph(const Node &node, const uint32_t index); + static graphStatus SetSubgraph(Node &node, const uint32_t index, const ComputeGraphPtr &subgraph); + static graphStatus AddSubgraph(Node
&node, const std::string &subgraph_name, const ComputeGraphPtr &subgraph); + static NodePtr CreatNodeWithoutGraph(const OpDescPtr op_desc); + /// Check if node is input of subgraph + /// @param [in] node + /// @return bool + static bool IsSubgraphInput(const NodePtr &node); + static bool IsSubgraphInput(const Node *const node); + + /// Check if node is output of subgraph + /// @param [in] node + /// @return bool + static bool IsSubgraphOutput(const NodePtr &node); + + /// @brief Get subgraph original input node. + /// @param [in] node + /// @return Node + static NodePtr GetParentInput(const Node &node); + static NodePtr GetParentInput(const NodePtr &node); + /// @brief Get subgraph original input node and corresponding out_anchor. + /// @param [in] node + /// @return NodeToOutAnchor node and out_anchor which linked to in_param node + static NodeToOutAnchor GetParentInputAndAnchor(const NodePtr &node); + /// @brief Get subgraph original input node and corresponding out_anchor cross subgraph. + /// @param [in] node + /// @return NodeToOutAnchor node and out_anchor which linked to in_param node + static NodeToOutAnchor GetParentInputAndAnchorCrossSubgraph(const NodePtr &node); + + /// @brief Get is dynamic shape graph from node. + /// @param [in] node + /// @return bool + static bool IsDynamicShape(const Node &node); + static bool IsDynamicShape(const NodePtr &node); + + /// @brief Check is varying_input for while node + /// @param [in] node: Data node for subgraph + /// @return bool + static bool IsWhileVaryingInput(const ge::NodePtr &node); + + /// @brief Get subgraph input is constant. + /// @param [in] node + /// @param [out] string + /// @return bool + static bool GetConstOpType(const NodePtr &node, std::string &type); + + /// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph. + /// @param [in] node + /// @return return GRAPH_SUCCESS if remove successfully, other for failed.
+ static graphStatus RemoveSubgraphsOnNode(const NodePtr &node); + + /** + * Collect, across all subgraphs attached to `node`, the Data nodes whose index is `index`; + * each subgraph contains at most one Data node matching `index` + * @param node + * @param index + * @return + */ + static std::vector GetSubgraphDataNodesByIndex(const Node &node, const int32_t index); + + /** + * Collect the NetOutput nodes of all subgraphs attached to `node`; + * each subgraph has exactly one NetOutput node + * @param node + * @return + */ + static std::vector GetSubgraphOutputNodes(const Node &node); + + /** + * Get the root graph of the graph that contains `node` + * @param node + * @return + */ + static ComputeGraphPtr FindRootGraph(const Node &node); + + /** + * Get the output nodes controlled by `node`, selected by `node_filter` + * @param node + * @param node_filter whitelist filter for control-edge copying; pass it to retrieve only the output nodes that satisfy the condition + * @return + */ + static std::vector GetOutControlNodes(const Node &node, const NodeFilter &node_filter); + + /** + * Get the input nodes that control `node`, selected by `node_filter` + * @param node + * @param node_filter whitelist filter for control-edge copying; pass it to retrieve only the input nodes that satisfy the condition + * @return + */ + static std::vector GetInControlNodes(const Node &node, const NodeFilter &node_filter); + + static NodePtr GetInDataNodeByIndex(const Node &node, const int32_t index); + static std::pair GetInDataNodeAndAnchorByIndex(const Node &node, const int32_t index); + + static std::vector> GetOutDataNodesWithAnchorByIndex(const Node &node, + const int32_t index); + + /** + * When `node` is a Data placeholder node inside a subgraph, get the type of the actual input node of the corresponding parent node in the root graph; + * in all other cases return the node type of `node` itself + * @param node + * @return + */ + static std::string GetInConstNodeTypeCrossSubgraph(const ge::NodePtr &node); + + /** +* When `node` is a Data placeholder node inside a subgraph, get the actual input node object of the corresponding parent node in the root graph; +* in all other cases return `node` itself +* @param node +* @return +*/ + static NodePtr GetInNodeCrossSubgraph(const ge::NodePtr &node); + + /// @brief Get peer input node; crossing PartitionedCall boundaries is supported.
+ /// @param [in] node, current node + /// @param [in] index, current node the index'th input, if it is PartitionedCall's subgraph Data, please assign 0 + /// @param [out] peer_node, + /// A(PartitionedCall_0)->B(PartitionedCall_1) + /// PartitionedCall_0's subgraph: Data->A->Netoutput + /// PartitionedCall_1's subgraph: Data1->B->Netoutput + /// If it is called like GetInNodeCrossPartionedCallNode(B,0,peer_node)or(Data1,0,peer_node), peer_node is A + /// @param [out] peer_out_anchor_index, peer_node's corresponding out anchor's index + /// @return [graphStatus] running result of this function + static graphStatus GetInNodeCrossPartionedCallNode(const NodePtr &node, uint32_t index, NodePtr &peer_node); + static graphStatus GetInNodeCrossPartionedCallNode(const NodePtr &node, uint32_t index, NodePtr &peer_node, + int32_t &peer_out_anchor_index); + + static graphStatus SetNodeParallelGroup(Node &node, const char_t *const group_name); + + static graphStatus UpdateInputOriginalShapeAndShape(const Node &node, const uint32_t index, const GeShape &shape); + static graphStatus UpdateOutputOriginalShapeAndShape(const Node &node, const uint32_t index, const GeShape &shape); + static bool IsDtResourceNode(const NodePtr &node); + static bool IsLikeAtomicClean(const NodePtr &node); + /** + * Determine whether an identity node is used to enforce read-before-write ordering; if it is, + * graph optimization must not blindly delete the identity node to improve performance + * @param node_ptr + * @return + */ + static bool IsIdentityUsefulForRWControl(const NodePtr &node_ptr); + /** + * Try to obtain the weight through the actual const node that corresponds to a placeholder (pld) node + * @param node_ptr the placeholder node; commonly the input-node type of graphs in the intermediate state of graph partitioning + * @param ge_tensor holder for the weight; set to non-null when the weight is obtained + * @return failure indicates an internal processing error; success does not guarantee that a weight was obtained + */ + static graphStatus TryGetWeightByPlaceHolderNode(const NodePtr &node_ptr, ConstGeTensorPtr &ge_tensor); + /** + * Try to obtain the weight through the actual const node that corresponds to a Data placeholder node + * @param node_ptr the Data placeholder node; commonly the input-node type of subgraphs + * @param ge_tensor holder for the weight; set to non-null when the weight is obtained + * @return failure indicates an internal processing error; success does not guarantee that a weight was obtained + */ + static graphStatus TryGetWeightByDataNode(const NodePtr
&node_ptr, ConstGeTensorPtr &ge_tensor); + /** + * Check whether the name of `node` is `name` + * @param node + * @param name + * @return true if it is, otherwise false + */ + static bool IsNameEqual(const NodePtr &node, const ge::char_t *const name); + /** + * Check whether the type of `node` is `type` + * @param node + * @param type + * @return + */ + static bool IsTypeEqual(const NodePtr &node, const ge::char_t *const type); +}; + /* NodeCompareKey: comparator that orders nodes by name (strcmp on GetNamePtr()) and, when the names are equal, by the name of their owner compute graph. */ +struct NodeCompareKey { + bool operator()(const NodePtr &n0, const NodePtr &n1) const { + if ((n0 == nullptr) || (n1 == nullptr)) { /* NOTE(review): returning false whenever either side is null makes a null node compare equivalent to every other node, which does not look like a strict weak ordering - confirm null nodes are never inserted */ + return false; + } + int32_t comp_res = strcmp(n0->GetNamePtr(), n1->GetNamePtr()); + if (comp_res == 0) { /* same name: fall back to the owner graph name */ + const auto graph0 = n0->GetOwnerComputeGraph(); + const auto graph1 = n1->GetOwnerComputeGraph(); + if ((graph0 == nullptr) || (graph1 == nullptr)) { /* same-named nodes are treated as equivalent when either has no owner graph */ + return false; + } + return (graph0->GetName() < graph1->GetName()); + } + return (comp_res < 0); + } +}; +using OrderedNodeSet = std::set; /* the set's template arguments are not visible in this text; presumably a NodePtr set ordered by NodeCompareKey - confirm */ +} // namespace ge +/*lint +e148*/ +#endif // INC_GRAPH_UTILS_NODE_UTILS_H_ diff --git a/inc/metadef/graph/utils/node_utils_ex.h b/inc/metadef/graph/utils/node_utils_ex.h new file mode 100644 index 0000000000000000000000000000000000000000..7e702139fee03be5430bdcc143089f67e40e066e --- /dev/null +++ b/inc/metadef/graph/utils/node_utils_ex.h @@ -0,0 +1,30 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef __INC_METADEF_NODE_UTILS_EX_H +#define __INC_METADEF_NODE_UTILS_EX_H + +#include "graph/node.h" +#include "graph/op_desc.h" + +namespace ge { /* NodeUtilsEx: static-only helper class; per the comments inside it, its members were detached from Node and NodeUtils. Only declarations are visible in this header. */ +class NodeUtilsEx { + public: + // Detach from Node + static graphStatus Verify(const NodePtr &node); + static graphStatus InferShapeAndType(const NodePtr &node); + static graphStatus InferOriginFormat(const NodePtr &node); + // Detach from NodeUtils + static ConstNodePtr GetNodeFromOperator(const Operator &op); + static graphStatus SetNodeToOperator(Operator &op, const ConstNodePtr &node); + private: + static graphStatus IsInputsValid(const NodePtr &node); +}; +} // namespace ge +#endif // __INC_METADEF_NODE_UTILS_EX_H diff --git a/inc/metadef/graph/utils/object_pool.h b/inc/metadef/graph/utils/object_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..02cec46e799734b2f398a5041598cea3cb8c85f2 --- /dev/null +++ b/inc/metadef/graph/utils/object_pool.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef EXECUTE_GRAPH_OBJECT_POOL_H +#define EXECUTE_GRAPH_OBJECT_POOL_H + +#include +#include +namespace ge { +constexpr size_t kDefaultPoolSize = 100UL; /* not referenced in the visible text; presumably the default for the stripped template parameter N - confirm */ + /* ObjectPool: keeps up to N released objects in a FIFO queue and hands them back out before allocating new ones. The visible code contains no locking. */ +template +class ObjectPool { + public: + ObjectPool() = default; + ~ObjectPool() = default; + ObjectPool(ObjectPool &) = delete; /* NOTE(review): takes a non-const reference, unlike the const reference used by the deleted copy assignment below - confirm this is intended */ + ObjectPool(ObjectPool &&) = delete; + ObjectPool &operator=(const ObjectPool &) = delete; + ObjectPool &operator=(ObjectPool &&) = delete; + /* Acquire: returns the oldest pooled object when the queue is non-empty; in that case `args` are ignored and the object keeps whatever state it had when released. Otherwise constructs a new T from `args` with nothrow new, so the result is null when allocation fails. NOTE(review): `args` are passed as `args...` without std::forward, so rvalue arguments reach the constructor as lvalues - confirm this is intended. */ + template + std::unique_ptr Acquire(Args &&...args) { + if (!handlers_.empty()) { + std::unique_ptr tmp(std::move(handlers_.front())); + handlers_.pop(); + return tmp; + } + return std::unique_ptr(new (std::nothrow) T(args...)); + } + /* Release: stores `ptr` for reuse when it is non-null and the queue holds fewer than N objects; otherwise `ptr` is simply destroyed when the by-value parameter goes out of scope. */ + void Release(std::unique_ptr ptr) { + if ((handlers_.size() < N) && (ptr != nullptr)) { + handlers_.push(std::move(ptr)); + } + } + /* IsEmpty: true when no released object is currently pooled. */ + bool IsEmpty() const { + return handlers_.empty(); + } + /* IsFull: true when N or more released objects are pooled, i.e. the next Release would not store its argument. */ + bool IsFull() const { + return handlers_.size() >= N; + } + + private: + std::queue> handlers_; /* released objects, oldest first */ +}; +} // namespace ge +#endif // EXECUTE_GRAPH_OBJECT_POOL_H diff --git a/inc/metadef/graph/utils/op_desc_utils.h b/inc/metadef/graph/utils/op_desc_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..6f0931dff0c8999ffb13d755276b27a10b5a9544 --- /dev/null +++ b/inc/metadef/graph/utils/op_desc_utils.h @@ -0,0 +1,124 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_UTILS_OP_DESC_UTILS_H_ +#define INC_GRAPH_UTILS_OP_DESC_UTILS_H_ + +#include +#include +#include "graph/def_types.h" +#include "graph/node.h" +#include "graph/operator.h" +#include "graph/runtime_inference_context.h" + +/*lint -e148*/ +namespace ge { +using ConstGeTensorBarePtr = const GeTensor *; /* raw, non-owning pointer to a const GeTensor */ +class OpDescUtils { /* Static helpers around OpDesc, Node and Operator: weights, const / non-const inputs, operator creation and copying, IR-index mapping. Only declarations are visible in this header. */ + public: + template + using Vistor = RangeVistor>; + using GetConstInputOnRuntimeFun = + std::function; + + OpDescUtils() = default; + ~OpDescUtils() = default; + static bool HasQuantizeFactorParams(const OpDescPtr& op_desc); + static bool HasQuantizeFactorParams(const OpDesc& op_desc); + static std::vector GetConstInputNode(const ge::Node& node); + static std::vector GetConstInputNodeAndAnchor(const ge::Node &node); + static std::vector GetInputData(const std::vector& input_nodes); + static std::vector GetWeightsFromNodes( + const std::vector& input_nodes_2_out_anchors); + + static std::vector GetWeights(const ge::Node& node); + static std::vector GetWeights(const ge::ConstNodePtr& node); + static std::vector MutableWeights(const ge::Node& node); + static std::vector MutableWeights(const ge::NodePtr node); + static graphStatus SetWeights(ge::Node& node, const std::vector& weights); + static graphStatus SetWeights(ge::NodePtr node, const std::vector &weights); + static graphStatus SetWeights(ge::Node &node, const std::map &weights_map); + static graphStatus ClearWeights(const ge::NodePtr node); + static graphStatus SetNoneConstNodeWeights(ge::Node &node, const std::map &weights_map); + static graphStatus SetNoneConstNodeWeights(ge::Node &node, const std::vector &weights); + static bool ClearInputDesc(const ge::OpDescPtr op_desc, const uint32_t index); + static bool ClearInputDesc(const ge::NodePtr& node); + static bool
ClearOutputDesc(const ge::OpDescPtr& op_desc, const uint32_t index); + static bool ClearOutputDesc(const ge::NodePtr& node); + static std::vector GetConstInputs(const ge::Node& node, const uint32_t depth = 64U); + static std::vector GetConstInputs(const ge::ConstNodePtr& node); + static size_t GetNonConstInputsSize(const ge::Node& node); + static size_t GetNonConstInputsSize(const ge::ConstNodePtr node); + // Index: Indicates the index of all non const inputs + static GeTensorDesc GetNonConstInputTensorDesc(const ge::Node& node, const size_t index_non_const = 0UL); + static GeTensorDesc GetNonConstInputTensorDesc(const ge::ConstNodePtr& node, const size_t index_non_const = 0UL); + static bool GetNonConstInputIndex(const ge::Node& node, const size_t index_non_const, size_t& index); + static bool GetNonConstInputIndex(const ge::ConstNodePtr& node, const size_t index_non_const, size_t& index); + // Index: Indicates the index of all inputs + static bool IsNonConstInput(const ge::Node& node, const size_t index = 0UL); + static bool IsNonConstInput(const ge::ConstNodePtr& node, const size_t index = 0UL); + + static std::vector GetNonConstTensorDesc(const ge::ConstNodePtr& node); + static graphStatus AddConstOpToAnchor(const InDataAnchorPtr in_anchor, const GeTensorPtr& tensor_ptr); + + static Operator CreateOperatorFromOpDesc(OpDescPtr op_desc); + static Operator CreateOperatorFromNode(ge::ConstNodePtr node_ptr); + static OpDescPtr GetOpDescFromOperator(const Operator& oprt); + static graphStatus CopyOperatorLinks(const std::map &src_op_list, + std::map &dst_op_list); + static graphStatus CopyOperators(const ComputeGraphPtr &dst_compute_graph, + const std::map &node_old_2_new, + const std::map &op_desc_old_2_new, + const std::map &src_op_list, + std::map &dst_op_list); + static OpDescPtr CloneOpDesc(const ConstOpDescPtr &org_op_desc); + static OpDescPtr CopyOpDesc(const ConstOpDescPtr &org_op_desc); + static OpDescPtr CreateConstOp(const GeTensorPtr& tensor_ptr); +
static OpDescPtr CreateConstOp(const GeTensorPtr& tensor_ptr, const bool copy); + static OpDescPtr CreateConstOpZeroCopy(const GeTensorPtr& tensor_ptr); + + static graphStatus SetSubgraphInstanceName(const std::string &subgraph_name, + const std::string &subgraph_instance_name, OpDescPtr &op_desc); + static ConstGeTensorBarePtr GetInputConstData(const Operator &op, const uint32_t idx); + // deprecated + static void SetRuntimeContextToOperator(const Operator &op, RuntimeInferenceContext *const context); + static void SetCallbackGetConstInputFuncToOperator(const Operator &op, + GetConstInputOnRuntimeFun get_const_input_func); + static bool HasCallbackGetConstInputFunc(const Operator &op); + static std::map> GetInputIrIndexes2InstanceIndexesPairMap(const OpDescPtr &op_desc); + static std::map> GetOutputIrIndexes2InstanceIndexesPairMap( + const OpDescPtr &op_desc); + + static graphStatus GetIrInputInstanceDescRange(const OpDescPtr &op, + std::map> &ir_input_2_range); + + static graphStatus GetIrInputRawDescRange(const OpDescPtr &op, + std::map> &ir_input_2_range); + + static graphStatus GetIrOutputDescRange(const OpDescPtr &op, + std::map> &ir_output_2_range); + + static ge::graphStatus GetInputIrIndexByInstanceIndex(const OpDescPtr &op_desc, + size_t instance_index, size_t &ir_index); + static ge::graphStatus GetInstanceNum(const OpDescPtr &op_desc, size_t ir_index, size_t start_index, + const std::map &valid_index_2_names, + size_t &instance_num); + static ge::graphStatus GetPromoteIrInputList(const OpDescPtr &op_desc, + std::vector> &promote_index_list); + static ge::graphStatus GetPromoteInstanceInputList(const OpDescPtr &op_desc, + std::vector> &promote_index_list); + private: + static GeTensorPtr MutableWeights(ge::OpDesc& op_desc); + static GeTensorPtr MutableWeights(const ge::OpDescPtr op_desc); + static graphStatus SetWeights(ge::OpDesc& op_desc, const GeTensorPtr weight); + static graphStatus SetWeights(ge::OpDescPtr op_desc, const GeTensorPtr weight); +}; +}
// namespace ge +/*lint +e148*/ +#endif // INC_GRAPH_UTILS_OP_DESC_UTILS_H_ diff --git a/inc/metadef/graph/utils/op_desc_utils_ex.h b/inc/metadef/graph/utils/op_desc_utils_ex.h new file mode 100644 index 0000000000000000000000000000000000000000..276b79f9a8a9f785d799611e3bfd237be129d661 --- /dev/null +++ b/inc/metadef/graph/utils/op_desc_utils_ex.h @@ -0,0 +1,36 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef __INC_METADEF_OP_DESC_UTILS_EX_H +#define __INC_METADEF_OP_DESC_UTILS_EX_H + +#include "graph/op_desc.h" + +namespace ge { /* OpDescUtilsEx: static-only helper class; per the comment inside it, its members were detached from OpDesc. Only declarations are visible in this header. */ +class OpDescUtilsEx { + public: + // Detach from OpDesc + static graphStatus CallInferFunc(const OpDescPtr &op_desc, Operator &op); + static graphStatus CallInferFormatFunc(const OpDescPtr &op_desc, Operator &op); + static graphStatus CallInferValueRangeFunc(const OpDescPtr &op_desc, Operator &op); + static graphStatus OpVerify(const OpDescPtr &op_desc); + static graphStatus InferShapeAndType(const OpDescPtr &op_desc); + static graphStatus InferDataSlice(const OpDescPtr &op_desc); + static void SetType(OpDescPtr &op_desc, const std::string &type); + static void ResetFuncHandle(OpDescPtr &op_desc); + static void SetTypeAndResetFuncHandle(OpDescPtr &op_desc, const std::string &type); + static void UpdateShapeAndDType(const GeTensorDescPtr &src, const
GeTensorDescPtr &dst); + + private: + static graphStatus CallInferFuncV1(const OpDescPtr &op_desc, Operator &op); + static graphStatus CallInferFuncV2(const OpDescPtr &op_desc, Operator &op); + static graphStatus InferShapeByOutputShapesAttr(const OpDescPtr &op_desc); +}; +} // namespace ge +#endif // __INC_METADEF_OP_DESC_UTILS_EX_H diff --git a/inc/metadef/graph/utils/op_type_utils.h b/inc/metadef/graph/utils/op_type_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..8344d8ff514d5f2812ae03178d0627d14f4b2fd3 --- /dev/null +++ b/inc/metadef/graph/utils/op_type_utils.h @@ -0,0 +1,29 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef __INC_METADEF_OP_TYPE_UTILS_H +#define __INC_METADEF_OP_TYPE_UTILS_H +#include +#include "graph/node.h" + +namespace ge { +class OpTypeUtils { + public: + static bool IsDataNode(const std::string &type); + static bool IsInputRefData(const ge::OpDescPtr &op_desc); + static bool IsVariableNode(const std::string &type); + static bool IsVarLikeNode(const std::string &type); + static bool IsAssignLikeNode(const std::string &type); + static bool IsIdentityLikeNode(const std::string &type); + static bool IsConstPlaceHolderNode(const std::string &type); + static graphStatus GetOriginalType(const ge::OpDescPtr &op_desc, std::string &type); + static bool IsSubgraphInnerData(const ge::OpDescPtr &op_desc); +}; +} // namespace ge +#endif // __INC_METADEF_OP_TYPE_UTILS_H diff --git a/inc/metadef/graph/utils/profiler.h b/inc/metadef/graph/utils/profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..1b98e58199d89142ee321a99de94083291806c40 --- /dev/null +++ b/inc/metadef/graph/utils/profiler.h @@ -0,0 +1,79 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef METADEF_CXX_PROFILER_H +#define METADEF_CXX_PROFILER_H +#include +#include +#include +#include +#include +#include "external/graph/types.h" + +namespace ge { +namespace profiling { +constexpr size_t kMaxStrLen = 256UL; +constexpr int64_t kMaxStrIndex = 1024 * 1024; +constexpr size_t kMaxRecordNum = 10UL * 1024UL * 1024UL; +enum class EventType { + kEventStart, + kEventEnd, + kEventTimestamp, + kEventTypeEnd +}; +struct ProfilingRecord { + int64_t element; + int64_t thread; + int64_t event; + EventType et; + std::chrono::time_point timestamp; +}; + +struct StrHash { + char_t str[kMaxStrLen]; + uint64_t hash; +}; + +class Profiler { + public: + static std::unique_ptr Create(); + void UpdateHashByIndex(const int64_t index, const uint64_t hash); + void RegisterString(const int64_t index, const std::string &str); + void RegisterStringHash(const int64_t index, const uint64_t hash, const std::string &str); + void Record(const int64_t element, const int64_t thread, const int64_t event, const EventType et, + const std::chrono::time_point time_point); + void RecordCurrentThread(const int64_t element, const int64_t event, const EventType et); + void RecordCurrentThread(const int64_t element, const int64_t event, const EventType et, + const std::chrono::time_point time_point); + + void Reset(); + void Dump(std::ostream &out_stream) const; + + size_t GetRecordNum() const noexcept; + const ProfilingRecord *GetRecords() const; + + using ConstStringHashesPointer = StrHash const(*); + using StringHashesPointer = StrHash (*); + ConstStringHashesPointer GetStringHashes() const; + StringHashesPointer GetStringHashes() ; + + ~Profiler(); + Profiler(); + + private: + void DumpByIndex(const int64_t index, std::ostream &out_stream) const; + + private: + std::atomic record_size_; + std::array records_; + StrHash indexes_to_str_hashes_[kMaxStrIndex]; +}; +} +} 
+#endif // METADEF_CXX_PROFILER_H diff --git a/inc/metadef/graph/utils/tensor_adapter.h b/inc/metadef/graph/utils/tensor_adapter.h new file mode 100644 index 0000000000000000000000000000000000000000..89e8836f54966632c7cec115ddf91f31e81ccbc4 --- /dev/null +++ b/inc/metadef/graph/utils/tensor_adapter.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_UTILS_TENSOR_ADAPTER_H_ +#define INC_GRAPH_UTILS_TENSOR_ADAPTER_H_ + +#include +#include "graph/ge_tensor.h" +#include "graph/tensor.h" +#include "graph/ge_attr_value.h" + +namespace ge { +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorAdapter { + public: + static GeTensorDesc TensorDesc2GeTensorDesc(const TensorDesc &tensor_desc); + static TensorDesc GeTensorDesc2TensorDesc(const GeTensorDesc &ge_tensor_desc); + static Tensor GeTensor2Tensor(const ConstGeTensorPtr &ge_tensor); + + static ConstGeTensorPtr AsGeTensorPtr(const Tensor &tensor); // Share value + static GeTensorPtr AsGeTensorPtr(Tensor &tensor); // Share value + static const GeTensor AsGeTensor(const Tensor &tensor); // Share value + static GeTensor AsGeTensor(Tensor &tensor); // Share value + static const Tensor AsTensor(const GeTensor &ge_tensor); // Share value + static Tensor AsTensor(GeTensor &ge_tensor); // Share value + static GeTensor 
AsGeTensorShared(const Tensor &tensor); + static GeTensor NormalizeGeTensor(const GeTensor &tensor); + static void NormalizeGeTensorDesc(GeTensorDesc &tensor_desc); + static const GeTensor* AsBareGeTensorPtr(const Tensor &tensor); +}; +} // namespace ge +#endif // INC_GRAPH_UTILS_TENSOR_ADAPTER_H_ diff --git a/inc/metadef/graph/utils/tensor_utils.h b/inc/metadef/graph/utils/tensor_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..f11e692fde36e3444e9716bd19fb7443956b70d4 --- /dev/null +++ b/inc/metadef/graph/utils/tensor_utils.h @@ -0,0 +1,71 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_GRAPH_UTILS_TENSOR_UTILS_H_ +#define INC_GRAPH_UTILS_TENSOR_UTILS_H_ + +#include +#include "graph/attr_value_serializable.h" +#include "graph/def_types.h" +#include "graph/ge_error_codes.h" +#include "graph/ge_tensor.h" + +namespace ge { +class TensorUtils { + public: + static GeTensor CreateShareTensor(const GeTensor &other); + static GeTensor CreateShareTensor(const GeTensorDesc &tensor_desc, + std::shared_ptr aligned_ptr, + const size_t size); + static void ShareTensor(const GeTensor &from, GeTensor &to); + static TensorData CreateShareTensorData(const TensorData &other); + static void ShareTensorData(const TensorData &from, TensorData &to); + static void ShareAlignedPtr(std::shared_ptr ptr, const size_t size, TensorData &to); + static void ShareAlignedPtr(std::shared_ptr ptr, const size_t size, GeTensor &to); + static void CopyTensor(const GeTensor &from, GeTensor &to); + static ge::graphStatus GetSize(const GeTensorDesc &tensor_desc, int64_t &size); + static void SetSize(GeTensorDesc &tensor_desc, const int64_t size); + static int64_t GetWeightSize(const ConstGeTensorPtr &tensor_ptr); + static int64_t GetWeightSize(const GeTensor &tensor); + static int64_t GetWeightSize(const GeTensorDesc &tensor_desc); + static uint8_t *GetWeightAddr(const ConstGeTensorPtr &tensor_ptr, const uint8_t *const base); + static uint8_t *GetWeightAddr(const GeTensor &tensor, const uint8_t *const base); + static void SetWeightSize(GeTensorDesc &tensor_desc, const int64_t size); + static ge::graphStatus GetReuseInput(const GeTensorDesc &tensor_desc, bool &flag); + static void SetReuseInput(GeTensorDesc &tensor_desc, const bool flag); + static ge::graphStatus GetOutputTensor(const GeTensorDesc &tensor_desc, bool &flag); + static void SetOutputTensor(GeTensorDesc &tensor_desc, const bool flag); + static graphStatus GetDeviceType(const GeTensorDesc 
&tensor_desc, DeviceType &type); + static void SetDeviceType(GeTensorDesc &tensor_desc, const DeviceType type); + static ge::graphStatus GetInputTensor(const GeTensorDesc &tensor_desc, bool &flag); + static void SetInputTensor(GeTensorDesc &tensor_desc, const bool flag); + static ge::graphStatus GetRealDimCnt(const GeTensorDesc &tensor_desc, uint32_t &cnt); + static void SetRealDimCnt(GeTensorDesc &tensor_desc, const uint32_t cnt); + static ge::graphStatus GetReuseInputIndex(const GeTensorDesc &tensor_desc, uint32_t &idx); + static void SetReuseInputIndex(GeTensorDesc &tensor_desc, const uint32_t idx); + static ge::graphStatus GetDataOffset(const GeTensorDesc &tensor_desc, int64_t &offset); + static void SetDataOffset(GeTensorDesc &tensor_desc, const int64_t offset); + static ge::graphStatus GetRC(const GeTensorDesc &tensor_desc, uint32_t &rc); + static void SetRC(GeTensorDesc &tensor_desc, const uint32_t rc); + static bool IsOriginShapeInited(const GeTensorDesc &tensor_desc); + + static ge::graphStatus CalcTensorMemSize(const GeShape &shape, const Format format, + const DataType data_type, int64_t &mem_size); + static ge::graphStatus CalcTensorMemSizeForNoTiling(const GeTensorDesc &tensor, + const Format format, + const DataType data_type, + int64_t &mem_size); + static ge::graphStatus GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp); + static ge::graphStatus GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp); + static ge::graphStatus CheckShapeByShapeRange(const GeShape &shape, + const std::vector> &shape_range); + static bool IsShapeEqual(const GeShape &src, const GeShape &dst); +}; +} // namespace ge +#endif // INC_GRAPH_UTILS_TENSOR_UTILS_H_ diff --git a/inc/metadef/graph/utils/type_utils_inner.h b/inc/metadef/graph/utils/type_utils_inner.h new file mode 100644 index 0000000000000000000000000000000000000000..485d975a677b54091573dfcf7a020f49ef52f867 --- /dev/null +++ b/inc/metadef/graph/utils/type_utils_inner.h 
@@ -0,0 +1,36 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_GRAPH_UTILS_TYPE_UTILS_INNER_H_ +#define INC_GRAPH_UTILS_TYPE_UTILS_INNER_H_ + +#include +#include "graph/def_types.h" +#include "graph/ge_error_codes.h" +#include "graph/types.h" +#include "register/register_types.h" +#include "external/register/register_fmk_types.h" + +namespace ge { +class TypeUtilsInner { + public: + static bool IsDataTypeValid(const DataType dt); + static bool IsFormatValid(const Format format); + static bool IsDataTypeValid(std::string dt); // for user json input + static bool IsFormatValid(std::string format); // for user json input + static bool IsInternalFormat(const Format format); + + static std::string ImplyTypeToSerialString(const domi::ImplyType imply_type); + static graphStatus SplitFormatFromStr(const std::string &str, std::string &primary_format_str, int32_t &sub_format); + static Format DomiFormatToFormat(const domi::domiTensorFormat_t domi_format); + static std::string FmkTypeToSerialString(const domi::FrameworkType fmk_type); + static bool CheckUint64MulOverflow(const uint64_t a, const uint32_t b); +}; +} // namespace ge +#endif // INC_GRAPH_UTILS_TYPE_UTILS_INNER_H_ diff --git a/inc/metadef/graph/yuv_subformat.h b/inc/metadef/graph/yuv_subformat.h new file mode 100644 index 
0000000000000000000000000000000000000000..ad177bf6d92e505ba8fee5ae04b6a8776d3ebc7c --- /dev/null +++ b/inc/metadef/graph/yuv_subformat.h @@ -0,0 +1,29 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_YUV_SUBFORMAT_H_ +#define INC_YUV_SUBFORMAT_H_ +namespace ge { +enum YUVSubFormat { + YUV420_SP = 1, + YVU420_SP, + YUV422_SP, + YVU422_SP, + YUV440_SP, + YVU440_SP, + YUV444_SP, + YVU444_SP, + YUYV422_PACKED, + YVYU422_PACKED, + YUV444_PACKED, + YVU444_PACKED, + YUV400 +}; +} // namespace ge +#endif // INC_YUV_SUBFORMAT_H_ diff --git a/inc/metadef/register/custom_pass_helper.h b/inc/metadef/register/custom_pass_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..34a937b9c141aa6c800388e7e6779fbf70e66f02 --- /dev/null +++ b/inc/metadef/register/custom_pass_helper.h @@ -0,0 +1,40 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_REGISTER_CUSTOM_PASS_HELPER_H_ +#define INC_REGISTER_CUSTOM_PASS_HELPER_H_ + +#include +#include "external/ge_common/ge_api_error_codes.h" +#include "external/register/register_custom_pass.h" +#include "external/register/register_types.h" + +namespace ge { +class CustomPassHelper { + public: + static CustomPassHelper &Instance(); + + void Insert(const PassRegistrationData ®_data); + + Status Load(); + + Status Unload(); + + Status Run(GraphPtr &graph, CustomPassContext &custom_pass_context) const; + + ~CustomPassHelper() = default; + + private: + CustomPassHelper() = default; + std::vector registration_datas_; + std::vector handles_; +}; +} // namespace ge + +#endif // INC_REGISTER_CUSTOM_PASS_HELPER_H_ diff --git a/inc/metadef/register/ffts_node_converter_registry.h b/inc/metadef/register/ffts_node_converter_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..9b034ef545287d18d3bbea4ab7f934899b57b47e --- /dev/null +++ b/inc/metadef/register/ffts_node_converter_registry.h @@ -0,0 +1,132 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef AIR_CXX_RUNTIME_V2_LOWERING_FFTS_NODE_CONVERTER_REGISTRY_H_ +#define AIR_CXX_RUNTIME_V2_LOWERING_FFTS_NODE_CONVERTER_REGISTRY_H_ +#include +#include +#include +#include "common/checker.h" +#include "node_converter_registry.h" +#include "graph/node.h" +#include "exe_graph/lowering/dev_mem_value_holder.h" +#include "exe_graph/lowering/lowering_global_data.h" + +namespace gert { +using FFTSPreThreadFunc = std::function &input_shapes, std::vector &output)>; +using FFTSPreThreadFuncNew = std::function &input_shapes, const std::vector &input_addrs, + std::vector &output)>; +using FFTSThreadFunc = std::function &input_shapes, + const std::vector &output_shapes, const bg::ValueHolderPtr thread_dim, + std::vector &output)>; + +struct SkipCtxRecord { + bool Init() { + ctx_id_v = std::unique_ptr>(new(std::nothrow) std::vector); + GE_ASSERT_NOTNULL(ctx_id_v); + ctx_type_v = std::unique_ptr>(new(std::nothrow) std::vector); + GE_ASSERT_NOTNULL(ctx_type_v); + return true; + } + size_t GetCtxNum() { + if (ctx_id_v == nullptr) { + return 0; + } + return ctx_id_v->size(); + } + bool SetSkipCtx(uint32_t ctx_id, uint32_t ctx_type) { + if (ctx_id_v == nullptr || ctx_type_v == nullptr) { + return false; + } + ctx_id_v->emplace_back(ctx_id); + ctx_type_v->emplace_back(ctx_type); + return true; + } + bool GetSkipCtx(size_t idx, uint32_t &ctx_id, uint32_t &ctx_type) { + if (ctx_id_v == nullptr || ctx_type_v == nullptr) { + return false; + } + if (idx >= ctx_id_v->size() || idx >= ctx_type_v->size()) { + return false; + } + ctx_id = ctx_id_v->at(idx); + ctx_type = ctx_type_v->at(idx); + return true; + } + void ClearRecord() { + if (ctx_id_v != nullptr) { + ctx_id_v->clear(); + } + if (ctx_type_v != nullptr) { + ctx_type_v->clear(); + } + return; + } + private: + 
std::unique_ptr> ctx_id_v{nullptr}; + std::unique_ptr> ctx_type_v{nullptr}; +}; + +struct FFTSLowerInput { + std::vector input_shapes; + std::vector input_addrs; + std::vector mem_pool_types; + LoweringGlobalData *global_data; + bg::ValueHolderPtr task_info; + bg::ValueHolderPtr thread_dim; + bg::ValueHolderPtr window_size; + bg::ValueHolderPtr args_para; + bg::ValueHolderPtr ffts_mem_allocator; + FFTSThreadFunc ffts_thread_fun; + bg::ValueHolderPtr skip_ctx_holder; +}; +class FFTSNodeConverterRegistry { + public: + using NodeConverter = LowerResult (*)(const ge::NodePtr &node, const FFTSLowerInput &lower_input); + struct ConverterRegisterData { + NodeConverter converter; + int32_t require_placement; + }; + static FFTSNodeConverterRegistry &GetInstance(); + NodeConverter FindNodeConverter(const std::string &func_name); + const ConverterRegisterData *FindRegisterData(const std::string &func_name) const; + void RegisterNodeConverter(const std::string &func_name, NodeConverter func); + void Register(const std::string &func_name, const ConverterRegisterData &data); + + private: + std::unordered_map names_to_register_data_; +}; + +class FFTSNodeConverterRegister { + public: + FFTSNodeConverterRegister(const char *lower_func_name, FFTSNodeConverterRegistry::NodeConverter func) noexcept; + FFTSNodeConverterRegister(const char *lower_func_name, int32_t require_placement, + FFTSNodeConverterRegistry::NodeConverter func) noexcept; +}; +} // namespace gert + +#ifdef __GNUC__ +#define ATTRIBUTE_USED __attribute__((used)) +#else +#define ATTRIBUTE_USED +#endif + +#define GERT_REGISTER_FFTS_NODE_CONVERTER_COUNTER2(type, placement, func, counter) \ + static const gert::FFTSNodeConverterRegister g_register_node_converter_##counter ATTRIBUTE_USED = \ + gert::FFTSNodeConverterRegister(type, placement, func) +#define GERT_REGISTER_FFTS_NODE_CONVERTER_COUNTER(type, placement, func, counter) \ + GERT_REGISTER_FFTS_NODE_CONVERTER_COUNTER2(type, placement, func, counter) +#define 
FFTS_REGISTER_NODE_CONVERTER_PLACEMENT(type, placement, func) \ + GERT_REGISTER_FFTS_NODE_CONVERTER_COUNTER(type, placement, func, __COUNTER__) +#define FFTS_REGISTER_NODE_CONVERTER(type, func) FFTS_REGISTER_NODE_CONVERTER_PLACEMENT(type, -1, func) + +#endif // AIR_CXX_RUNTIME_V2_LOWERING_FFTS_NODE_CONVERTER_REGISTRY_H_ diff --git a/inc/metadef/register/ffts_plus_engine_update.h b/inc/metadef/register/ffts_plus_engine_update.h new file mode 100644 index 0000000000000000000000000000000000000000..aaa83e56f376aceab7b0813cb2f137c72561fabd --- /dev/null +++ b/inc/metadef/register/ffts_plus_engine_update.h @@ -0,0 +1,25 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_REGISTER_FFTS_PLUS_ENGINE_UPDATE_H_ +#define INC_REGISTER_FFTS_PLUS_ENGINE_UPDATE_H_ +#include "graph/utils/graph_utils.h" +#include "graph/utils/tensor_utils.h" +#include "runtime/rt_ffts_plus.h" +#include "common/sgt_slice_type.h" +namespace ffts { +class FFTSPlusEngineUpdate { +public: + FFTSPlusEngineUpdate(); + ~FFTSPlusEngineUpdate(); + static bool UpdateCommonCtx(ge::ComputeGraphPtr &sgt_graph, rtFftsPlusTaskInfo_t &task_info); + static ThreadSliceMapDyPtr slice_info_ptr_; +}; +}; +#endif // INC_REGISTER_FFTS_PLUS_ENGINE_UPDATE_H_ diff --git a/inc/metadef/register/ffts_plus_task_update.h b/inc/metadef/register/ffts_plus_task_update.h new file mode 100644 index 0000000000000000000000000000000000000000..65c6228d5904433be77176434c447e6ba8e65cab --- /dev/null +++ b/inc/metadef/register/ffts_plus_task_update.h @@ -0,0 +1,87 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_REGISTER_FFTS_PLUS_TASK_UPDATE_H_ +#define INC_REGISTER_FFTS_PLUS_TASK_UPDATE_H_ + +#include + +#include "graph/node.h" +#include "register/op_tiling_registry.h" +#include "runtime/rt_ffts_plus.h" +#include "external/ge_common/ge_api_error_codes.h" + +namespace ge { +struct AutoThreadSubTaskFlush { + int32_t device_id{0}; + void *args_base{nullptr}; + std::vector op_run_info; + + uintptr_t aic_non_tail_task_start_pc{0U}; + uintptr_t aic_tail_task_start_pc{0U}; + uint32_t aic_icache_prefetch_cnt{0U}; + + uintptr_t aiv_non_tail_task_start_pc{0U}; + uintptr_t aiv_tail_task_start_pc{0U}; + uint32_t aiv_icache_prefetch_cnt{0U}; + + // Task I/O Addrs. + std::vector input_addr_base; + std::vector output_addr_base; +}; + +struct AutoThreadParam { + uint16_t thread_dim{0U}; // thread dim after Pre-Thread + uint32_t input_output_num{0U}; // input + output + std::vector task_addr_offset; // input + output + workspace + + // Task Thread Dims. 
+ std::vector>> *task_input_shape{nullptr}; // thread + std::vector>> *task_output_shape{nullptr}; // thread +}; + +class FFTSPlusTaskUpdate { + public: + FFTSPlusTaskUpdate() = default; + virtual ~FFTSPlusTaskUpdate() = default; + + virtual Status GetAutoThreadParam(const NodePtr &node, const std::vector &op_run_info, + AutoThreadParam &auto_thread_param) { + (void)node; + (void)op_run_info; + (void)auto_thread_param; + return SUCCESS; + } + + virtual Status UpdateSubTaskAndCache(const NodePtr &node, const AutoThreadSubTaskFlush &sub_task_flush, + rtFftsPlusTaskInfo_t &ffts_plus_task_info) { + (void)node; + (void)sub_task_flush; + (void)ffts_plus_task_info; + return SUCCESS; + } + + virtual Status UpdateCommonCtx(const ComputeGraphPtr &sgt_graph, rtFftsPlusTaskInfo_t &task_info) { + (void)sgt_graph; + (void)task_info; + return SUCCESS; + } + + virtual Status UpdateStaticDataCtx(size_t ctx_num, std::vector &io_addrs, size_t align_offset, + size_t host_io_base, std::map> &ctx_ids_map) { + (void)ctx_num; + (void)io_addrs; + (void)align_offset; + (void)host_io_base; + (void)ctx_ids_map; + return SUCCESS; + } +}; +} // namespace ge +#endif // INC_REGISTER_FFTS_PLUS_TASK_UPDATE_H_ diff --git a/inc/metadef/register/ffts_plus_update_manager.h b/inc/metadef/register/ffts_plus_update_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..798598224f453845eb5d3133211b49ae4282b397 --- /dev/null +++ b/inc/metadef/register/ffts_plus_update_manager.h @@ -0,0 +1,76 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_REGISTER_FFTS_PLUS_UPDATE_MANAGER_H_ +#define INC_REGISTER_FFTS_PLUS_UPDATE_MANAGER_H_ + +#include +#include +#include +#include +#include + +#include "register/ffts_plus_task_update.h" + +namespace ge { +using FftsCtxUpdatePtr = std::shared_ptr; +using FftsCtxUpdateCreatorFun = std::function; +class PluginManager; + +class FftsPlusUpdateManager { + public: + static FftsPlusUpdateManager &Instance(); + /** + * For load so to register FFTSPlusTaskUpdate subclass constructor. + */ + Status Initialize(); + + /** + * Get FFTS Plus context by core type. + * @param core_type: core type of Node + * @return FFTS Plus Update instance. + */ + FftsCtxUpdatePtr GetUpdater(const std::string &core_type) const; + + class FftsPlusUpdateRegistrar { + public: + FftsPlusUpdateRegistrar(const std::string &core_type, const FftsCtxUpdateCreatorFun &creator) { + FftsPlusUpdateManager::Instance().RegisterCreator(core_type, creator); + } + ~FftsPlusUpdateRegistrar() = default; + }; + + private: + FftsPlusUpdateManager() = default; + ~FftsPlusUpdateManager(); + + /** + * Register FFTS Plus context update executor. + * @param core_type: core type of Node + * @param creator: FFTS Plus Update instance Creator. 
+ */ + void RegisterCreator(const std::string &core_type, const FftsCtxUpdateCreatorFun &creator); + + std::map creators_; + std::unique_ptr plugin_manager_; +}; +} // namespace ge + +#define REGISTER_FFTS_PLUS_CTX_UPDATER(core_type, task_clazz) \ + REGISTER_FFTS_PLUS_CTX_TASK_UPDATER_UNIQ_HELPER(__COUNTER__, core_type, task_clazz) + +#define REGISTER_FFTS_PLUS_CTX_TASK_UPDATER_UNIQ_HELPER(ctr, type, clazz) \ + REGISTER_FFTS_PLUS_CTX_TASK_UPDATER_UNIQ(ctr, type, clazz) + +#define REGISTER_FFTS_PLUS_CTX_TASK_UPDATER_UNIQ(ctr, type, clazz) \ + ge::FftsPlusUpdateManager::FftsPlusUpdateRegistrar g_##type##_creator##ctr((type), []() { \ + return std::shared_ptr(new(std::nothrow) (clazz)()); \ + }) + +#endif // INC_REGISTER_FFTS_PLUS_UPDATE_MANAGER_H_ diff --git a/inc/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_constant.h b/inc/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_constant.h new file mode 100644 index 0000000000000000000000000000000000000000..f2e418c95bc66f851d7da5db4f6aa08dcb327f4a --- /dev/null +++ b/inc/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_constant.h @@ -0,0 +1,93 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_CONSTANT_H_ +#define INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_CONSTANT_H_ +#include +#include + +namespace fe { +const std::string MODIFY_GRAPH_IN_UB_FUSION_PASS = "ModifyGraphInUBFusionPass"; +const std::string UB_FUSION_OP_TYPE = "_ub_fusion_op_type"; +// add the op pattern +const std::string TBE_PATTERN_INPUT_NODE = "InputData"; +const std::string TBE_PATTERN_OP_TYPE_ANY = "OpTypeAny"; +const std::string TBE_PATTERN_OUTPUT_NODE = "OutputData"; +const std::string OP_PATTERN_ELEMWISE = "ElemWise"; +const std::string OP_PATTERN_COMMONREDUCE = "CommReduce"; +const std::string OP_PATTERN_BROAD_CAST = "Broadcast"; +const std::string OP_PATTERN_BROAD_CAST_NZ = "Broadcast_Nz"; +const std::string OP_PATTERN_SEGMENT = "Segment"; +const std::string OP_PATTERN_MAXPOOL = "MaxPool"; +const std::string OP_PATTERN_CONV = "Convolution"; +const std::string OP_PATTERN_MATMUL = "Matmul"; +const std::string OP_PATTERN_BNUPDATE = "bn_update"; +const std::string OP_PATTERN_BNREDUCE = "bn_reduce"; +const std::string OP_PATTERN_BNTUPLEREDUCE = "TupleReduce"; +const std::string OP_PATTERN_CONV_BACKPROP_INPUT = "Conv2d_backprop_input"; +const std::string OP_PATTERN_DEPTHWISE_CONV = "DepthwiseConvolution"; +const std::string OP_PATTERN_QUANT = "quant"; +const std::string OP_PATTERN_DEQUANT = "dequant"; +const std::string OP_PATTERN_REQUANT = "requant"; +const std::string OP_PATTERN_POOL2D = "Pool2d"; +const std::string OP_PATTERN_ANTIQUANT = "anti_quant"; +const std::string OP_PATTERN_STRIDED_WRITE = "strided_write"; +const std::string OP_PATTERN_STRIDED_READ = "strided_read"; +const std::string OP_PATTERN_AIPP = "aipp"; +const std::string OP_PATTERN_CONFUSION_TRANSPOSE = "confusiontranspose"; +const std::string OP_PATTERN_DEQUANTS16 = "dequant_s16"; +const std::string OP_PATTERN_REQUANTS16 = "requant_s16"; +const 
std::string OP_PATTERN_READ_SELECT = "read_select"; +const std::string OP_PATTERN_WRITE_SELECT = "write_select"; +const std::string OP_PATTERN_BATCH_MATMUL = "BatchMatmul"; +const std::string OP_PATTERN_CONV3D = "Conv3d"; +const std::string OP_PATTERN_DROPOUTDOMASKV3D = "DropOutDoMaskV3D"; +const std::string OP_PATTERN_CONV3D_BACKPROP_INPUT = "Conv3d_backprop_input"; +const std::string OP_PATTERN_CONV_BACKPROP_FILTER = "Conv2d_backprop_filter"; +const std::string OP_PATTERN_GEMM = "GEMM"; +const std::string OP_PATTERN_FIXPIPE = "fixpipe"; +const std::string OP_PATTERN_AVGPOOLUPDATE = "AvgPoolUpdate"; +const std::vector OP_PATTERN_VEC{OP_PATTERN_ELEMWISE, + OP_PATTERN_COMMONREDUCE, + OP_PATTERN_BROAD_CAST, + OP_PATTERN_BROAD_CAST_NZ, + OP_PATTERN_SEGMENT, + OP_PATTERN_MAXPOOL, + OP_PATTERN_CONV, + OP_PATTERN_MATMUL, + OP_PATTERN_BNUPDATE, + OP_PATTERN_BNREDUCE, + OP_PATTERN_BNTUPLEREDUCE, + OP_PATTERN_CONV_BACKPROP_INPUT, + OP_PATTERN_DEPTHWISE_CONV, + OP_PATTERN_QUANT, + OP_PATTERN_DEQUANT, + OP_PATTERN_REQUANT, + OP_PATTERN_POOL2D, + OP_PATTERN_ANTIQUANT, + OP_PATTERN_STRIDED_WRITE, + OP_PATTERN_STRIDED_READ, + OP_PATTERN_AIPP, + OP_PATTERN_CONFUSION_TRANSPOSE, + OP_PATTERN_DEQUANTS16, + OP_PATTERN_REQUANTS16, + OP_PATTERN_READ_SELECT, + OP_PATTERN_WRITE_SELECT, + OP_PATTERN_BATCH_MATMUL, + OP_PATTERN_CONV3D, + OP_PATTERN_DROPOUTDOMASKV3D, + OP_PATTERN_CONV3D_BACKPROP_INPUT, + OP_PATTERN_CONV_BACKPROP_FILTER, + OP_PATTERN_GEMM, + OP_PATTERN_FIXPIPE, + OP_PATTERN_AVGPOOLUPDATE +}; +} // namespace fe + +#endif // INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_CONSTANT_H_ diff --git a/inc/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.h b/inc/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.h new file mode 100644 index 0000000000000000000000000000000000000000..6e076cb1e63739003132a56d45e9f55a37b291eb --- /dev/null +++ b/inc/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.h @@ -0,0 +1,61 @@ +/* Copyright 
(c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_PASS_BASE_H_ +#define INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_PASS_BASE_H_ + +#include +#include +#include +#include "register/graph_optimizer/buffer_fusion/buffer_fusion_constant.h" +#include "register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.h" +#include "register/graph_optimizer/graph_optimize_register_error_codes.h" +#include "register/graph_optimizer/fusion_common/op_slice_info.h" + +namespace fe { +enum BufferFusionPassType { + BUILT_IN_AI_CORE_BUFFER_FUSION_PASS, + BUILT_IN_VECTOR_CORE_BUFFER_FUSION_PASS, + CUSTOM_AI_CORE_BUFFER_FUSION_PASS, + CUSTOM_VECTOR_CORE_BUFFER_FUSION_PASS, + BUFFER_FUSION_PASS_TYPE_RESERVED +}; + +class BufferFusionPassBase { + public: + explicit BufferFusionPassBase(); + virtual ~BufferFusionPassBase(); + virtual std::vector DefinePatterns() = 0; + virtual Status GetFusionNodes(const BufferFusionMapping &mapping, std::vector &fusion_nodes); + virtual Status GetMixl2FusionNodes(const BufferFusionMapping &mapping, std::vector &fusion_nodes); + virtual Status PostFusion(const ge::NodePtr &fused_node); + virtual Status CalcFusionOpSliceInfo(std::vector &fusion_nodes, OpCalcInfo &op_slice_info); + virtual Status CheckNodeCanFusion(const BufferFusionNodeDescMap 
&fusion_nodes, const ge::NodePtr &next_node); + static std::vector GetMatchedNodes(const BufferFusionMapping &mapping); + static std::vector GetMatchedNodesByDescName(const std::string &desc_name, + const BufferFusionMapping &mapping); + static ge::NodePtr GetMatchedHeadNode(const std::vector &matched_nodes); + static bool CheckNodeIsDynamicImpl(const ge::NodePtr &node); + static bool CheckTwoNodesImplConsistent(const ge::NodePtr &src_node, const ge::NodePtr &dst_node); + static bool CheckNodesImplConsistent(const BufferFusionMapping &mapping); + static bool CheckNodesImplConsistent(const std::vector &fusion_nodes); + static bool CheckNodeIsDynamicShape(const ge::NodePtr& node); + static bool CheckNodesIncDynamicShape(const BufferFusionMapping &mapping); + static bool CheckNodesIncDynamicShape(const std::vector &fusion_nodes); + void SetName(const std::string &name) { name_ = name; } + + std::string GetName() { return name_; } + + private: + std::string name_; +}; + +} // namespace fe + +#endif // INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_PASS_BASE_H_ diff --git a/inc/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.h b/inc/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..6f7dd189139da5b82db02be67765d5a358966d52 --- /dev/null +++ b/inc/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pass_registry.h @@ -0,0 +1,67 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_BUFFER_FUSION_PASS_REGISTRY_H_ +#define INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_BUFFER_FUSION_PASS_REGISTRY_H_ +#include +#include +#include +#include +#include "register/graph_optimizer/fusion_common/fusion_pass_desc.h" +#include "register/graph_optimizer/buffer_fusion/buffer_fusion_pass_base.h" + +namespace fe { +class BufferFusionPassRegistry { + public: + using CreateFn = BufferFusionPassBase *(*)(); + struct PassDesc { + PassAttr attr; + CreateFn create_fn; + }; + ~BufferFusionPassRegistry(); + + static BufferFusionPassRegistry &GetInstance(); + + void RegisterPass(const BufferFusionPassType &pass_type, const std::string &pass_name, + const CreateFn &create_fn, PassAttr attr); + + std::map GetPassDesc(const BufferFusionPassType &pass_type); + + std::map GetCreateFnByType(const BufferFusionPassType &pass_type); + + private: + BufferFusionPassRegistry(); + class BufferFusionPassRegistryImpl; + std::unique_ptr impl_; +}; + +class BufferFusionPassRegistrar { + public: + BufferFusionPassRegistrar(const BufferFusionPassType &pass_type, const std::string &pass_name, + BufferFusionPassBase *(*create_fun)(), PassAttr attr); + + ~BufferFusionPassRegistrar() {} +}; + +#define REGISTER_BUFFER_FUSION_PASS(pass_name, pass_type, pass_class) \ + REG_BUFFER_FUSION_PASS(pass_name, pass_type, pass_class, 0) + +#define REG_BUFFER_FUSION_PASS(pass_name, pass_type, pass_class, attr) \ + REG_BUFFER_FUSION_PASS_UNIQ_HELPER(__COUNTER__, pass_name, pass_type, pass_class, attr) + +#define REG_BUFFER_FUSION_PASS_UNIQ_HELPER(ctr, pass_name, pass_type, pass_class, attr) \ + REG_BUFFER_FUSION_PASS_UNIQ(ctr, pass_name, pass_type, pass_class, attr) + +#define REG_BUFFER_FUSION_PASS_UNIQ(ctr, pass_name, pass_type, pass_class, 
attr) \ + static ::fe::BufferFusionPassRegistrar register_buffer_fusion_##ctr __attribute__((unused)) = \ + ::fe::BufferFusionPassRegistrar( \ + (pass_type), (pass_name), \ + []() -> ::fe::BufferFusionPassBase * { return new (std::nothrow) pass_class();}, (attr)) +} // namespace fe +#endif // INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_BUFFER_FUSION_PASS_REGISTRY_H_ diff --git a/inc/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.h b/inc/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.h new file mode 100644 index 0000000000000000000000000000000000000000..1091b9b6834b98252492b19e6bc7a626af17e626 --- /dev/null +++ b/inc/metadef/register/graph_optimizer/buffer_fusion/buffer_fusion_pattern.h @@ -0,0 +1,151 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_PATTERN_H_ +#define INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_PATTERN_H_ +#include +#include +#include +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/graph_utils.h" + +namespace fe { +extern const int64_t TBE_FUSION_OP_NUM_MAX; +extern const int64_t TBE_PATTERN_NUM_MAX; +extern const int64_t TBE_PATTERN_NUM_NONE; +extern const int64_t TBE_PATTERN_NUM_DEFAULT; +extern const int64_t TBE_OUTPUT_BRANCH_DEFAULT; +extern const int64_t TBE_OUTPUT_BRANCH_SINGLE; +extern const int64_t TBE_OUTPUT_BRANCH_MULTI; +extern const int64_t TBE_PATTERN_GROUPID_INVALID; +extern const int32_t TBE_OUTPUT_MAX_NUM_LIMIT; + +enum SkipStatus { DISABLED = 0, AVAILABLE = 1, SKIPPED = 2 }; + +enum ShapeTypeRule { IGNORE_SHAPE_TYPE = 0, ONLY_SUPPORT_STATIC, ONLY_SUPPORT_DYNAMIC }; + +enum class PatternRelation { RELATIVE_POSITION_CONSISTENT = 0 }; + +extern const std::map kShapeTypeRuleToStr; + +struct BufferFusionOpDesc { + std::string desc_name; // description name + std::vector types; // description type + std::vector inputs; // all input op + std::vector outputs; // all output op + int64_t out_branch_type; // out desc type, 1:single, 2: multi + int64_t repeate_min; // opdesc min repeat num + int64_t repeate_max; // opdesc max repeat num + int64_t repeate_curr; // opdesc current repeat num + bool match_status; + bool not_pattern; + int64_t group_id; // record desc groupid, need one desc matched at least in + // the same group + std::vector shape_type_rules; + bool ignore_input_num; + bool ignore_output_num; + bool is_allow_series; // whether the nodes with the same pattern can be series in match graph + int32_t output_max_limit; + // used for two connected op, first opdesc has optional multiple nodes and + // ignore_output_num is true, second opdesc is same pattern 
type and + // out_branch_type is TBE_OUTPUT_BRANCH_MULTI + std::map multi_output_skip_status; + std::vector> relations; +}; + +struct MappingCmpKey { + bool operator() (const BufferFusionOpDesc *key1, const BufferFusionOpDesc *key2) const { + return (key1->desc_name) < (key2->desc_name); + } +}; +using BufferFusionMapping = std::map, MappingCmpKey>; +using BufferFusionNodeDescMap = std::unordered_map; + +class BufferFusionPattern { + public: + explicit BufferFusionPattern(std::string name = "", int64_t op_max_count = TBE_FUSION_OP_NUM_MAX); + + virtual ~BufferFusionPattern(); + + /* + * types vector use one ShapeTypeRule + */ + BufferFusionPattern &AddOpDesc(const std::string &desc_name, const std::vector &types, + const int64_t repeat_min = TBE_PATTERN_NUM_DEFAULT, + const int64_t repeat_max = TBE_PATTERN_NUM_DEFAULT, + const int64_t group_id = TBE_PATTERN_GROUPID_INVALID, + const ShapeTypeRule shape_type_rule = ONLY_SUPPORT_STATIC, + const bool not_pattern = false, const bool is_allow_series = true); + +/** + * add node desc + * @param desc_name + * @param types + * @param repeat_min + * @param repeat_max + * @param is_allow_series + * @return ref + */ + BufferFusionPattern &AddOpDesc(const std::string &desc_name, const std::vector &types, + const int64_t repeat_min, const int64_t repeat_max, const bool is_allow_series); + + /* + * types vector use ShapeTypeRule vector, and size should be same or ShapeTypeRule size equal 1 + */ + BufferFusionPattern &AddOpDescTypeRules(const std::string &desc_name, const std::vector &types, + const int64_t repeat_min, const int64_t repeat_max, const int64_t group_id, + const std::vector &shape_type_rules, + const bool not_pattern = false, const bool is_allow_series = true); + + BufferFusionPattern &SetOutputs(const std::string &desc_name, const std::vector &output_ids, + int64_t relation = TBE_OUTPUT_BRANCH_SINGLE, bool ignore_input_num = false, + bool ignore_output_num = false, int32_t output_max_limit = 
TBE_OUTPUT_MAX_NUM_LIMIT); + + BufferFusionPattern &SetHead(const std::vector &head_ids); + + /** + * add node desc + * @param desc_name + * @param types + * @param repeat_min + * @param repeat_max + * @param is_allow_series + * @return ref + */ + BufferFusionPattern &SetRelation(const std::string &src_desc_name, const std::string &dst_desc_name, + const PatternRelation pattern_relation); + + const std::string& GetName() const; + int64_t GetOpMaxCount() const; + const std::vector& GetOpDescs() const; + const std::vector& GetHead() const; + int64_t GetErrorCnt() const; + void SetGraphModType(int64_t graph_mod_type); + int64_t GetGraphModType() const; + bool GetOutputs(BufferFusionOpDesc *op_desc, std::vector &outputs, bool ignore_repeat = false); + + private: + bool IsOpDescValid(const std::string &desc_name, int64_t repeat_min, int64_t repeat_max) const; + bool IsShapeRulesSizeValid(const size_t &types_size, const size_t &rules_size) const; + BufferFusionOpDesc *GetOpDesc(const std::string &desc_name) const; + void UpdateSkipStatus(const BufferFusionOpDesc *op_desc) const; + + std::string name_; + int64_t op_max_count_; + std::vector ops_; + std::map op_map_; + std::vector head_; + int64_t error_count_; + // 0: this pattern will not modify graph(default) + // 1: this pattern will modify graph + int64_t graph_mod_type_; +}; +} // namespace fe +#endif // INC_REGISTER_GRAPH_OPTIMIZER_BUFFER_FUSION_PATTERN_H_ diff --git a/inc/metadef/register/graph_optimizer/fusion_common/fusion_pass_desc.h b/inc/metadef/register/graph_optimizer/fusion_common/fusion_pass_desc.h new file mode 100644 index 0000000000000000000000000000000000000000..f2d0dd06d6ef78118eda4cda70b42844bf245f92 --- /dev/null +++ b/inc/metadef/register/graph_optimizer/fusion_common/fusion_pass_desc.h @@ -0,0 +1,52 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). 
+ * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_FUSION_COMMON_FUSION_PASS_DESC_H_ +#define INC_REGISTER_GRAPH_OPTIMIZER_FUSION_COMMON_FUSION_PASS_DESC_H_ +#include +#include "graph/option/optimization_option_info.h" +namespace fe { +using PassAttr = uint64_t; +const PassAttr FORBIDDEN_CLOSE = 0x01UL; // forbidden close, can not be closed by fusion switch +const PassAttr NEED_SORT = 0x02UL; // need topological sorting before executing +const PassAttr SINGLE_SCENE_OPEN = 0x04UL; // open for single op scene, can be close by fusion switch +const PassAttr FE_PASS = 0x08UL; // graph passes and ub passes in air project +const PassAttr ENABLE_AUTO_FUSION = 0x10UL; // whether using auto match fusion frame +const PassAttr ALWAYS_GENERALIZE = 0x20UL; +const PassAttr PRUNING = 0x40UL; // whether the pass need to pre run before graph fusion +const PassAttr ENABLE_FUSION_CHECK = 0x80UL; // enable do fusion check for matched fusion nodes during ub fusion +/* + * Compile level reg, if reg multi level, the lowest level will take effect. 
+ * */ +const PassAttr COMPILE_LEVEL_O0 = 0x100UL; // pure dynamic +const PassAttr COMPILE_LEVEL_O1 = 0x200UL; // static functional optimize +const PassAttr COMPILE_LEVEL_O2 = 0x400UL; // no time and space balance optimize +const PassAttr COMPILE_LEVEL_O3 = 0x800UL; // open all optimize +const PassAttr PASS_BIT_MASK = 0x1UL; // check if the loweset bit of pass is 1 + +enum class PassAttrType { + FRBDN_CLOSE = 0, // Mark those passes that cannot be turned off in graph mode + NEED_TOPO_SORT = 1, // Mark those graph fusion passes that need topological sorting before executing + SINGLE_OP_SCENE_MUST_ON = 2, // Mark those passes that must be turned on in single-op mode or jit_compile=false + FE_PASS_FLAG = 3, // Mark those passes that belong to FE + AUTO_FUSION_FLAG = 4, // Using auto match fusion frame + /* The OpDescs in the patterns of this kind fusion pass are able to be generalized in all scenarios. + * For example, they can ignore the value dependency restrict. */ + ALWAYS_GENERALIZE_FLAG = 5, + PRUNING_FLAG = 6, + FUSION_CHECK_FLAG = 7, // Do fusion check for matched fusion nodes during ub fusion + COMPILE_O0 = 8, + COMPILE_O1 = 9, + COMPILE_O2 = 10, + COMPILE_O3 = 11, +}; +bool IsPassAttrTypeOn(PassAttr pass_attr, PassAttrType attr_type); +void RegPassCompileLevel(const std::string &pass_name, PassAttr pass_attr); +} // namespace fe +#endif // INC_REGISTER_GRAPH_OPTIMIZER_FUSION_COMMON_FUSION_PASS_DESC_H_ diff --git a/inc/metadef/register/graph_optimizer/fusion_common/fusion_statistic_recorder.h b/inc/metadef/register/graph_optimizer/fusion_common/fusion_statistic_recorder.h new file mode 100644 index 0000000000000000000000000000000000000000..76759668df468d87a600072b579e475ca2fcab35 --- /dev/null +++ b/inc/metadef/register/graph_optimizer/fusion_common/fusion_statistic_recorder.h @@ -0,0 +1,98 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_FUSION_STATISTIC_RECORDER_H +#define INC_REGISTER_GRAPH_OPTIMIZER_FUSION_STATISTIC_RECORDER_H + +#include +#include +#include +#include +#include +#include + +namespace fe { + +class FusionInfo { + public: + explicit FusionInfo(const uint64_t session_id = 0, const std::string graph_id = "", const std::string pass_name = "", + const int32_t match_times = 0, const int32_t effect_times = 0, const int32_t repo_hit_times = 0); + + virtual ~FusionInfo(); + + void AddMatchTimes(const int32_t match_times); + + void AddEffectTimes(const int32_t effect_times); + + int32_t GetMatchTimes() const; + + void SetMatchTimes(const int32_t match_times); + + int32_t GetEffectTimes() const; + + void SetEffectTimes(const int32_t effect_times); + + int32_t GetRepoHitTimes() const; + + void SetRepoHitTimes(const int32_t repo_hit_times); + + std::string GetGraphId() const; + + std::string GetPassName() const; + + uint64_t GetSessionId() const; + + private: + uint64_t session_id_; + std::string graph_id_; + std::string pass_name_; + int32_t match_times_; + int32_t effect_times_; + int32_t repo_hit_times_; +}; + +using FusionStatisticMap = std::map>; + +class FusionStatisticRecorder { + public: + FusionStatisticRecorder(const FusionStatisticRecorder &) = delete; + + FusionStatisticRecorder &operator=(const FusionStatisticRecorder &) = delete; 
+ + static FusionStatisticRecorder &Instance(); + + void UpdateGraphFusionMatchTimes(const FusionInfo &fusion_info); + + void UpdateGraphFusionEffectTimes(const FusionInfo &fusion_info); + + void UpdateBufferFusionMatchTimes(const FusionInfo &fusion_info); + + void UpdateBufferFusionEffectTimes(const FusionInfo &fusion_info); + + void GetAndClearFusionInfo(const std::string &session_graph_id, + std::map &graph_fusion_info_map, + std::map &buffer_fusion_info_map); + + void GetFusionInfo(const std::string &session_graph_id, std::map &graph_fusion_info_map, + std::map &buffer_fusion_info_map); + + void GetAllSessionAndGraphIdList(std::vector &session_graph_id_vec); + + private: + FusionStatisticRecorder(); + virtual ~FusionStatisticRecorder(); + FusionStatisticMap graph_fusion_info_map_; + FusionStatisticMap buffer_fusion_info_map_; + void ClearFusionInfo(const std::string& session_graph_id); + + std::recursive_mutex mutex_; +}; +} // namespace fe + +#endif // INC_REGISTER_GRAPH_OPTIMIZER_FUSION_STATISTIC_RECORDER_H diff --git a/inc/metadef/register/graph_optimizer/fusion_common/fusion_turbo.h b/inc/metadef/register/graph_optimizer/fusion_common/fusion_turbo.h new file mode 100644 index 0000000000000000000000000000000000000000..2366c530dc37cfaf391266e5e7a07a809e4f8c38 --- /dev/null +++ b/inc/metadef/register/graph_optimizer/fusion_common/fusion_turbo.h @@ -0,0 +1,249 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_FUSION_COMMON_FUSION_TURBO_H +#define INC_REGISTER_GRAPH_OPTIMIZER_FUSION_COMMON_FUSION_TURBO_H +#include +#include +#include "graph/anchor.h" +#include "graph/compute_graph.h" +#include "graph/graph.h" +#include "graph/model.h" +#include "graph/node.h" +#include "graph/utils/anchor_utils.h" +#include "register/graph_optimizer/graph_optimize_register_error_codes.h" +#include "register/graph_optimizer/fusion_common/fusion_turbo_utils.h" + +namespace fe { +enum TensorUptType { + UPDATE_NONE = 0, + UPDATE_THIS = 1, + UPDATE_PEER, +}; + +struct WeightInfo { + ge::GeShape shape; + ge::GeShape ori_shape; + ge::DataType datatype; + ge::DataType ori_datatype; + ge::Format format; + ge::Format ori_format; + uint8_t *data; + int64_t shape_size; + size_t total_data_size; // data_size * sizeof(datatype). !!!Could be zero!!! + inline void CalcTotalDataSize() { + if (shape.GetDimNum() == 0) { + shape_size = 1; + } else { + shape_size = shape.GetShapeSize(); + } + + if ((shape_size > 0) && (datatype < data_type_size.size())) { + total_data_size = (static_cast(shape_size)) * data_type_size[datatype]; + } else { + total_data_size = 0; + } + } + + WeightInfo(const ge::GeTensorDesc &tensor_desc, + void *data_p); + + WeightInfo(const ge::NodePtr &node, const int32_t &index, + void *data_p); + + WeightInfo(const ge::GeShape &shape_p, const ge::GeShape &ori_shape_p, + const ge::DataType &datatype_p, const ge::DataType &ori_datatype_p, + const ge::Format &format_p, const ge::Format &ori_format_p, void *data_p); + + WeightInfo(ge::GeShape &&shape_p, ge::GeShape &&ori_shape_p, + const ge::DataType &datatype_p, const ge::DataType &ori_datatype_p, + const ge::Format &format_p, const ge::Format &ori_format_p, void *data_p); + + WeightInfo(const ge::GeShape &shape_p, const ge::DataType &datatype_p, + const ge::Format &format_p, void 
*data_p); + + WeightInfo(ge::GeShape &&shape_p, const ge::DataType &datatype_p, + const ge::Format &format_p, void *data_p); +}; + +class FusionTurbo { + public: + explicit FusionTurbo(const ge::ComputeGraphPtr &graph); + + explicit FusionTurbo(ge::ComputeGraph &graph); + + ~FusionTurbo(); + + static Status BreakInput(const ge::NodePtr &node, + const vector &input_index); + + static Status BreakOutput(const ge::NodePtr &node, + const vector &output_index); + + static Status BreakAllInput(const ge::NodePtr &node); + + static Status BreakAllOutput(const ge::NodePtr &node); + + Status RemoveNodeWithRelink(const ge::NodePtr &node, const std::initializer_list &io_map = {}); + + Status RemoveNodeWithRelink(const ge::NodePtr &node, const std::vector &io_map = {}); + + Status RemoveNodeOnly(const ge::NodePtr &node); + + /* If the node has no subsequent nodes, remove it. + * If the node has subsequent nodes, just return. + * Parameter include_control_nodes: + * If only_care_data_nodes = true, then we will ignore the control outputs. */ + Status RemoveDanglingNode(const ge::NodePtr &node, const bool &only_care_data_nodes = false); + + Status RemoveMultiNodesOnly(const std::vector &nodes); + + ge::NodePtr UpdateConst(const ge::NodePtr &node, const int32_t &index, const WeightInfo &w_info) const; + + /* 1. If index is larger than or equalt to the input size of node, add a weight + * tensor and node as the last input of node. + * 2. If index is less than the input size of node and: + * 2.1 If the peer node of this input index is nullptr, we add a const node + * as input and update tensor desc. ---> Call AddConstNode. + * 2.2 If the peer node of this input index is Const, we substitute the data + * of current Const and update tensor desc. ---> Call UpdateConst + * 2.3 If the peer node of this input is other type, we just skip it. */ + ge::NodePtr AddWeight(const ge::NodePtr &node, const int32_t &index, const WeightInfo &w_info) const; + + /* Add weight after one output of node. 
For example: + * NodeA----> NodeB + * \----> NodeC + * After calling AddWeightAfter(NodeA, 0, w_info), the graph will be like: + * NewWeight----> NodeB + * \----> NodeC + * NodeA(will be dangling) + * The rule is adding weight in front of every peer out node of NodeA. + */ + ge::NodePtr AddWeightAfter(const ge::NodePtr &node, const int32_t &index, const WeightInfo &w_info) const; + + ge::NodePtr AddWeight(const ge::NodePtr &node, const string& tensor_name, const WeightInfo &w_info) const; + + /* Add a weight tensor and node as the last input of node. */ + ge::NodePtr AddWeight(const ge::NodePtr &node, const WeightInfo &w_info) const; + + std::vector AddWeights(const ge::NodePtr &node, + const vector &w_infos) const; + + static ge::GeTensorPtr MutableWeight(const ge::NodePtr &node, int32_t index); + + ge::NodePtr AddNodeOnly(const string &op_name, const string &op_type) const; + + static ge::NodePtr AddNodeOnly(ge::ComputeGraph &graph, const string &op_name, const string &op_type); + + ge::NodePtr AddNodeOnly(const string &op_name, const string &op_type, + size_t dynamic_num) const; + + static ge::NodePtr AddNodeOnly(ge::ComputeGraph &graph, const string &op_name, const string &op_type, + size_t dynamic_num); + + ge::NodePtr InsertNodeOnly(const string &op_name, const string &op_type, + const ge::NodePtr &origin_node, + const size_t dynamic_num = 0UL) const; + + static ge::NodePtr InsertNodeOnly(ge::ComputeGraph &graph, const string &op_name, const string &op_type, + const ge::NodePtr &origin_node, + const size_t dynamic_num = 0UL); + + static ge::OpDescPtr CreateOpDesc(const string &op_name, + const string &op_type, const size_t dynamic_num); + + static Status TransferOutCtrlEdges(const std::vector &nodes, + const ge::NodePtr &new_node); + + static Status TransferInCtrlEdges(const std::vector &nodes, + const ge::NodePtr &new_node); + + ge::NodePtr InsertNodeBefore(const string &op_name, const string &op_type, + const ge::NodePtr &base_node, const int32_t 
&base_input_index, + const int32_t &input_index = 0, + const int32_t &output_index = 0) const; + + ge::NodePtr InsertNodeAfter(const string &op_name, const string &op_type, + const ge::NodePtr &base_node, const int32_t &base_output_index, + const int32_t &input_index = 0, const int32_t &output_index = 0) const; + + static Status LinkInput(Relations &input_relations, + const ge::NodePtr &dst_node, + const TensorUptType &update_tensor = UPDATE_THIS); + + static Status LinkOutput(Relations &output_relations, + const ge::NodePtr &src_node, + const TensorUptType &update_tensor = UPDATE_THIS); + + static ge::NodePtr GetPeerOutNode(const ge::NodePtr &node, const int32_t &this_node_input_index); + + static std::vector GetPeerInNodes(const ge::NodePtr &node, const int32_t &this_node_output_index); + + /* Check whether there is a path from [node1's] output [index1] to [node2]. + * The default value is -1 and -1 means any output is ok. */ + static bool CheckConnected(const ge::NodePtr &node1, const ge::NodePtr &node2, + const int32_t &index1 = -1); + + /* Default update input 0 of node. 
*/ + Status UpdateInputByPeer(const ge::NodePtr &node, const int32_t &index, + const ge::NodePtr &peer_node, const int32_t &peer_index) const; + + Status UpdateOutputByPeer(const ge::NodePtr &node, const int32_t &index, + const ge::NodePtr &peer_node, const int32_t &peer_index) const; + + static bool IsUnknownShape(const ge::NodePtr &node, const int32_t &index, const bool &is_input = true); + + static bool IsUnknownOriShape(const ge::NodePtr &node, const int32_t &index, const bool &is_input = true); + + ge::NodePtr MultiInOne(const string &node_name, const string &node_type, + Relations &input_relations, + Relations &output_relations, + const std::vector &old_nodes = {}, + const bool &remove_old = true); + + Status MultiInOne(const ge::NodePtr &new_node, + Relations &input_relations, + Relations &output_relations, + const std::vector &old_nodes = {}, + const bool &remove_old = true); + + static bool HasControl(const ge::NodePtr &node); + + static bool HasInControl(const ge::NodePtr &node); + + static bool HasOutControl(const ge::NodePtr &node); + + static bool HasOutData(const ge::NodePtr &node); + + static Status MoveDataOutputUp(const ge::NodePtr &node, int32_t index); + + /* move node to pre node if pre node has subgraph + * @param node current need move node + * @param index node move input index + **/ + Status GraphNodeUpMigration(const ge::NodePtr &node, const int32_t index); + + /* move node to next node if next node has subgraph + * @param node current need move node + * @param index node move output index + **/ + Status GraphNodeDownMigration(const ge::NodePtr &node, const int32_t index); + + static NodeIndex GetPeerInFirstPair(const ge::NodePtr &node, int32_t index); + + static NodeIndex GetPeerOutPair(const ge::NodePtr &node, int32_t index); + private: + /* AddWeight will do either AddConstNode or UpdateConst. 
*/ + ge::NodePtr AddConstNode(const ge::NodePtr &node, const int32_t &index, + const WeightInfo &w_info) const; + + ge::ComputeGraphPtr graph_; +}; +} +#endif // INC_REGISTER_GRAPH_OPTIMIZER_FUSION_COMMON_FUSION_TURBO_H diff --git a/inc/metadef/register/graph_optimizer/fusion_common/fusion_turbo_utils.h b/inc/metadef/register/graph_optimizer/fusion_common/fusion_turbo_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..8c81b202023897c9437bb8e202714fdaf2b1dcc6 --- /dev/null +++ b/inc/metadef/register/graph_optimizer/fusion_common/fusion_turbo_utils.h @@ -0,0 +1,147 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_FUSION_COMMON_FUSION_TURBO_UTILS_H +#define INC_REGISTER_GRAPH_OPTIMIZER_FUSION_COMMON_FUSION_TURBO_UTILS_H +#include +#include "graph/utils/op_desc_utils.h" +#include "graph/debug/ge_log.h" + +#define FUSION_TURBO_NOTNULL(val, ret) \ + do { \ + if ((val) == nullptr) { \ + GELOGD("Parameter[%s] must not be null.", #val); \ + return ret; \ + } \ + } while (0) + +namespace fe { +enum Direction { + CURRENT = 0, /* 表示NodeIndex指示的是当前节点的对应输入输出。 */ + /* 当连接输入的场景,PEER模式下会获取的对端输出节点和对端index。 */ + /* 当连接输出的场景,PEER模式下会获取的所有对端输入节点和所有对端index。 */ + PEER = 1, + /* 当连接输入的场景,PEER_SINGLE模式下会获取的对端输出节点和对端index。和PEER一致。 */ + /* 当连接输出的场景,PEER_SINGLE模式下会获取的第一个对端输出节点和对端index。 */ + PEER_SINGLE = 2 +}; + +struct NodeIndex { + ge::NodePtr node; + int32_t index; + Direction direction = CURRENT; + NodeIndex() { + node = nullptr; + index = -1; + } + NodeIndex(const ge::NodePtr &node_param, int32_t index_param) { + node = node_param; + index = index_param; + } + + NodeIndex(const std::pair &node_index_pair) { + node = node_index_pair.first; + index = node_index_pair.second; + } + + NodeIndex(const ge::NodePtr &node_param, int32_t index_param, Direction direction_param) { + node = node_param; + index = index_param; + direction = direction_param; + } +}; +using NodeIndices = std::vector; + +using ThisIndex = int32_t; + +class Relations { + public: + Relations(); + + Relations(const std::initializer_list &peer_indices); + + explicit Relations(const std::map &relations_param); + + explicit Relations(std::map &&relations_param); + + Relations(const Relations &relations_param); + + Relations(Relations &&relations_param) noexcept; + + Relations(ThisIndex this_index, const NodeIndex &peer_index); + + Relations(ThisIndex this_index, const NodeIndices &peer_indices); + + Relations(ThisIndex this_index, NodeIndex &&peer_index); + + 
Relations(ThisIndex this_index, NodeIndices &&peer_indices); + + Relations(const std::initializer_list> &peer_indices); + + Relations(const std::initializer_list>> &peer_indices_vec); + + /****** Interface Add from here. ******/ + Relations& Add(ThisIndex this_index, const NodeIndex &peer_index); + + Relations& Add(ThisIndex this_index, const std::initializer_list &peer_indices); + + Relations& Add(ThisIndex this_index, const NodeIndices &peer_indices); + + Relations& Add(ThisIndex this_index, NodeIndex &&peer_index); + + Relations& Add(ThisIndex this_index, NodeIndices &&peer_indices); + + /* 由于NodeIndex当连接输入或输出是完全不一样的,我们需要根据原始relations计算作为 + * 输入和输出的真正的对端节点,所以要求必须通过接口来修改relations。 */ + Relations& UpdatePeerIndex(ThisIndex this_index, const NodeIndices &peer_indices); + + Relations& UpdatePeerIndex(ThisIndex this_index, NodeIndices &&peer_indices); + + Relations& UpdatePeerIndex(const std::map &peer_indices); + + Relations& UpdatePeerIndex(std::map &&peer_indices); + + const std::map& GetRelations(); + + const std::map& GetInRelations(); + + const std::map& GetOutRelations(); + + Relations& operator=(const Relations &relations_param); + + Relations& operator=(Relations &&relations_param) noexcept; + private: + NodeIndex GetPeerInFirstPair(ThisIndex relation_index, const ge::NodePtr &node, int32_t index); + + void AppendPeerInAllPairs(ThisIndex relation_index, const ge::NodePtr &node, int32_t index); + + void PreProcessOneNodeIndex(ThisIndex index, const NodeIndex &node_index); + void PreProcessNodeIndices(ThisIndex index, const NodeIndices &node_indices); + + void PreProcess(); + /* 我们在添加ori_relations的时候就把两个方向的节点都计算好。 */ + std::map in_relations; + + std::map out_relations; + + /* 如果key是输出的index,那么vector里存放的就是对端输入的index;在单输出多引用场景, + * 对端输入的index可能有多个。 + * 如果key是输入的index,那么vector里存放的就是对端输出的index。对端输出只会有一个。 */ + std::map ori_relations; +}; + +extern const std::array(ge::DT_MAX + 1)> data_type_size; +class FusionTurboUtils { + public: + static NodeIndex 
GetPeerInFirstPair(const ge::NodePtr &node, int32_t index); + static NodeIndex GetPeerOutPair(const ge::NodePtr &node, int32_t index); + static ge::NodePtr GetConstInput(const ge::NodePtr &node, int32_t index); +}; +} +#endif // INC_REGISTER_GRAPH_OPTIMIZER_FUSION_COMMON_FUSION_TURBO_UTILS_H diff --git a/inc/metadef/register/graph_optimizer/fusion_common/graph_pass_util.h b/inc/metadef/register/graph_optimizer/fusion_common/graph_pass_util.h new file mode 100644 index 0000000000000000000000000000000000000000..e42596e362074990c9b9527ecde885abab15d430 --- /dev/null +++ b/inc/metadef/register/graph_optimizer/fusion_common/graph_pass_util.h @@ -0,0 +1,167 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_PASS_UTIL_H_ +#define INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_PASS_UTIL_H_ +#include "graph/compute_graph.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/node_utils.h" +#include "graph/utils/type_utils.h" +#include "register/graph_optimizer/graph_optimize_register_error_codes.h" +#include "external/graph/types.h" + +#include +#include +#include +#include +#include + +namespace fe { +enum class BackWardInheritMode { + kInsertNode = 0, + kFusedNode = 1, + kInheritTrue = 2, + kDoNotInherit = 3 +}; + +using NodeTypeMap = std::unordered_map>; +using NodeTypeMapPtr = std::shared_ptr; +struct NodeMapInfo { + int64_t run_count; + NodeTypeMapPtr node_type_map; +}; +using NodeMapInfoPtr = std::shared_ptr; +/** @brief define graph pass, which provides two interface: 1. run pass; +* 2. 
record op names before fusion */ +class GraphPassUtil { + public: + using OriginOpAttrsVec = std::vector>; + using UnorderedMapping = std::unordered_map; + /** set outputdesc attr for data dump + * + * @param origin_index,usually is origin node output index + * + * @param fusion_index,usually is fusion node output index + * + * @param origin_node, usually is origin node + * + * @param fusion_node, usually is fusion node + */ + static void SetOutputDescAttr(const uint32_t &origin_index, const uint32_t &fusion_index, + const ge::NodePtr &origin_node, const ge::NodePtr &fusion_node); + + static void SetOutputDescAttr(ge::ConstGeTensorDescPtr &origin_tensor_desc, const int64_t origin_index, + const ge::OpDescPtr &origin_op_desc, const ge::GeTensorDescPtr &target_tensor_desc); + + /** get origin format for data dump + * + * @param tensor_desc,usually is output_desc + * + * @return format of this tensor_desc + */ + static ge::Format GetDataDumpOriginFormat(const ge::GeTensorDescPtr &tensor_desc); + + static ge::Format GetDataDumpOriginFormat(ge::ConstGeTensorDescPtr &tensor_desc); + + /** set origin format for data dump + * + * @param origin format + * + * @param tensor_desc,usually is output_desc + */ + static void SetDataDumpOriginFormat(const ge::Format &origin_format, const ge::GeTensorDescPtr &tensor_desc); + + /** set origin datatype for data dump + * + * @param origin datatype + * + * @param tensor_desc,usually is output_desc + */ + static void SetDataDumpOriginDataType(const ge::DataType origin_data_type, const ge::GeTensorDescPtr &tensor_desc); + + /** get origin datatype for data dump + * + * @param tensor_desc,usually is output_desc + * + * @return format of this tensor_desc + */ + static ge::DataType GetDataDumpOriginDataType(const ge::GeTensorDescPtr &tensor_desc); + + static ge::DataType GetDataDumpOriginDataType(ge::ConstGeTensorDescPtr &tensor_desc); + + static void AddNodeFromOpTypeMap(const NodeMapInfoPtr &node_map_info, const ge::NodePtr &node_ptr); + 
+ static Status GetOpTypeMapToGraph(NodeMapInfoPtr &node_map_info, const ge::ComputeGraph &graph); + + static void RecordPassnameAndOriginalAttrs(const std::vector &original_nodes, + std::vector &fus_nodes, const string &pass_name, + const OriginOpAttrsVec &origin_op_attrs = OriginOpAttrsVec()); + + static Status StoreAndUpdataOriginFusionPassName(const ge::OpDescPtr &op_desc, + const std::vector &original_nodes, + const std::string &pass_name); + + static void GetBackWardAttr(const std::vector &original_nodes, + bool &backward, BackWardInheritMode inherit_mode); + + static void InheritGraphRelatedAttr(const std::vector &original_nodes, + const std::vector &fusion_nodes, + BackWardInheritMode inherit_mode); + + /* If one of the original node has attribute like keep_dtype, the fused node + * will inherit that attribute. + * param inherit_mode: if fusion_nodes are newly inserted after original_nodes, + * backward attr will only care about its farther nodes(pass farther nodes in + * param original_nodes). + * And if fusion_nodes are fused by a bunch of original_nodes, the backward attr + * will not only care about original_nodes but also the input nodes of original_nodes. 
*/ + static void InheritAttrFromOriNodes(const std::vector &original_nodes, + const std::vector &fusion_nodes, + BackWardInheritMode inherit_mode = BackWardInheritMode::kFusedNode); + + static void RecordOriginalOpAttrs(const std::vector &original_nodes, + const ge::OpDescPtr &op_desc, const string &pass_name, + const OriginOpAttrsVec &origin_op_attrs = OriginOpAttrsVec()); + + static void RecordOriginalNames(const std::vector &original_nodes, const ge::NodePtr &node); + + static void AddNodeToNodeTypeMap(const NodeTypeMapPtr &node_type_map, const std::string &op_type, + const ge::NodePtr &node_ptr); + + static void RemoveNodeFromNodeTypeMap(NodeTypeMapPtr &node_type_map, const std::string &op_type, + const ge::NodePtr &node_ptr); + + static void GetNodesFromNodeTypeMap(NodeTypeMapPtr &node_type_map, const std::string &op_type, + std::vector &nodes); + + static void GetOpCustomImplModeFromOriNode(const std::vector &original_nodes, + std::set &op_impl_mode_priority_set, + std::map &origin_node_impl_mode_map); + + static void SetOpCustomImplModeToFusNode(const ge::OpDescPtr &fusion_op, + const std::map &origin_node_impl_mode_map, + const std::set &op_impl_mode_priority_set); + + static void GetOpCustomGroupIdFromOriginNodes(const std::vector &original_nodes, + uint32_t ¶llel_group_id); + + static void SetOpCustomGroupIdToFusNode(const ge::OpDescPtr &fusion_op, const uint32_t ¶llel_group_id); + + static ge::OutDataAnchorPtr GetPeerOutAnchorNotInDeleteList(const ge::NodePtr &node, size_t idx); + + static void SetPairTensorIntAttr(const ge::NodePtr &node, size_t idx, const std::map &attr_val); + + static void SetPairTensorAttr(const ge::NodePtr &node, size_t idx, const std::map &attr_val, + bool is_input = true); +}; + +} // namespace fe + +#endif // INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_PASS_UTIL_H_ diff --git a/inc/metadef/register/graph_optimizer/fusion_common/op_slice_info.h b/inc/metadef/register/graph_optimizer/fusion_common/op_slice_info.h new file mode 100644 index 
0000000000000000000000000000000000000000..ce7cefbf85d0a92b1fbbf4d66f1a1c59ec691af3 --- /dev/null +++ b/inc/metadef/register/graph_optimizer/fusion_common/op_slice_info.h @@ -0,0 +1,191 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_COMMON_UTILS_AI_CORE_OP_SLICE_INFO_H +#define INC_COMMON_UTILS_AI_CORE_OP_SLICE_INFO_H + +#include +#include +#include + +namespace fe { +enum OpReduceType { REDUCE_MEAN = 0, REDUCE_ADD, REDUCE_MAX, REDUCE_MIN }; +enum OpL1FusionType { L1FUSION_DISABLE = 0, L1FUSION_BASIC, L1FUSION_INPUT_CTR }; + +class InputSplitInfoImpl; +using InputSplitInfoImplPtr = std::shared_ptr; +class InputSplitInfo; +using InputSplitInfoPtr = std::shared_ptr; +class OutputSplitInfoImpl; +using OutputSplitInfoImplPtr = std::shared_ptr; +class OutputSplitInfo; +using OutputSplitInfoPtr = std::shared_ptr; +class InputReduceInfoImpl; +using InputReduceInfoImplPtr = std::shared_ptr; +class InputReduceInfo; +using InputReduceInfoPtr = std::shared_ptr; +class OutputReduceInfoImpl; +using OutputReduceInfoImplPtr = std::shared_ptr; +class OutputReduceInfo; +using OutputReduceInfoPtr = std::shared_ptr; +class AxisSplitMapImpl; +using AxisSplitMapImplPtr = std::shared_ptr; +class AxisSplitMap; +using AxisSplitMapPtr = std::shared_ptr; +class AxisReduceMapImpl; +using AxisReduceMapImplPtr = 
std::shared_ptr; +class AxisReduceMap; +using AxisReduceMapPtr = std::shared_ptr; +class OpCalcInfoImpl; +using OpCalcInfoImplPtr = std::shared_ptr; +class OpCalcInfo; +using OpCalcInfoPtr = std::shared_ptr; + +class InputSplitInfo { + public: + InputSplitInfo(); + InputSplitInfo(const InputSplitInfo &input_split_info); + InputSplitInfo &operator = (const InputSplitInfo &input_split_info); + ~InputSplitInfo(); + bool Initialize(); + size_t GetIndex() const; + std::vector GetAxis() const; + std::vector GetHeadOverLap() const; + std::vector GetTailOverLap() const; + void SetIndex(const size_t& idx); + void SetAxis(std::vector& axis); + void SetHeadOverLap(std::vector& head_over_lap); + void SetTailOverLap(std::vector& tail_over_lap); + bool IsPtrNull() const; + private: + InputSplitInfoImplPtr split_impl_{nullptr}; +}; + +class OutputSplitInfo { + public: + OutputSplitInfo(); + OutputSplitInfo(const OutputSplitInfo &output_split_info); + OutputSplitInfo &operator = (const OutputSplitInfo &output_split_info); + ~OutputSplitInfo(); + bool Initialize(); + size_t GetIndex() const; + std::vector GetAxis() const; + void SetIndex(const size_t& idx); + void SetAxis(std::vector& axis); + bool IsPtrNull() const; + private: + OutputSplitInfoImplPtr split_impl_{nullptr}; +}; + +class InputReduceInfo { + public: + InputReduceInfo(); + InputReduceInfo(const InputReduceInfo &input_reduce_info); + InputReduceInfo &operator = (const InputReduceInfo &input_reduce_info); + ~InputReduceInfo(); + bool Initialize(); + size_t GetIndex() const; + std::vector GetAxis() const; + void SetIndex(const size_t& idx); + void SetAxis(std::vector& axis); + bool IsPtrNull() const; + private: + InputReduceInfoImplPtr reduce_impl_{nullptr}; +}; + +class OutputReduceInfo { + public: + OutputReduceInfo(); + OutputReduceInfo(const OutputReduceInfo &output_reduce_info); + OutputReduceInfo &operator = (const OutputReduceInfo &output_reduce_info); + ~OutputReduceInfo(); + bool Initialize(); + size_t 
GetIndex() const; + OpReduceType GetReduceType() const; + bool GetIsAtomic() const; + void SetIndex(const size_t& idx); + void SetReduceType(const OpReduceType& reduce_type); + void SetIsAtomic(const bool& is_atomic); + bool IsPtrNull() const; + private: + OutputReduceInfoImplPtr reduce_impl_{nullptr}; +}; + +class AxisSplitMap { + public: + friend class AxisSplitMapImpl; + AxisSplitMap(); + AxisSplitMap(const AxisSplitMap &axis_split_map); + AxisSplitMap &operator = (const AxisSplitMap &axis_split_map); + ~AxisSplitMap(); + bool Initialize(); + std::vector GetInputSplitInfos() const; + std::vector GetOutputSplitInfos() const; + std::vector GetInputSplitInfoVec() const; + std::vector GetOutputSplitInfoVec() const; + void AddInputSplitInfo(InputSplitInfo& input_split_info); + void SetInputSplitInfos(std::vector& input_split_vec); + void SetInputSplitInfos(std::vector& input_split_vec); + void AddOutputSplitInfo(OutputSplitInfo& output_split_info); + void SetOutputSplitInfos(std::vector& output_split_vec); + void SetOutputSplitInfos(std::vector& output_split_vec); + bool IsPtrNull() const; + private: + AxisSplitMapImplPtr aixs_split_impl_{nullptr}; +}; + +class AxisReduceMap { + public: + AxisReduceMap(); + AxisReduceMap(const AxisReduceMap &axis_reduce_map); + AxisReduceMap &operator = (const AxisReduceMap &axis_reduce_map); + ~AxisReduceMap(); + bool Initialize(); + friend class AxisReduceMapImpl; + std::vector GetInputReduceInfos() const; + std::vector GetOutputReduceInfos() const; + std::vector GetInputReduceInfoVec() const; + std::vector GetOutputReduceInfoVec() const; + void AddInputReduceInfo(InputReduceInfo& input_reduce_info); + void SetInputReduceInfos(std::vector& input_reduce_vec); + void SetInputReduceInfos(std::vector& input_reduce_vec); + void AddOutputReduceInfo(OutputReduceInfo& output_reduce_info); + void SetOutputReduceInfos(std::vector& output_reduce_vec); + void SetOutputReduceInfos(std::vector& output_reduce_vec); + bool IsPtrNull() const; + 
private: + AxisReduceMapImplPtr aixs_reduce_impl_{nullptr}; +}; + +class OpCalcInfo { + public: + OpCalcInfo(); + ~OpCalcInfo(); + bool Initialize(); + std::vector GetAxisSplitMaps() const; + std::vector GetAxisReduceMaps() const; + std::vector GetAxisSplitMapVec() const; + std::vector GetAxisReduceMapVec() const; + OpL1FusionType GetL1FusionEnable() const; + int64_t GetMinTbeL1Space() const; + void AddAxisSplitMap(AxisSplitMap& axis_split_map); + void SetAxisSplitMaps(std::vector& axis_split_vec); + void SetAxisSplitMaps(std::vector& axis_split_vec); + void AddAxisReduceMap(AxisReduceMap& axis_reduce_map); + void SetAxisReduceMaps(std::vector& axis_reduce_vec); + void SetAxisReduceMaps(std::vector& axis_reduce_vec); + void SetL1FusionEnable(const OpL1FusionType& l1_fusion_enable); + void SetMinTbeL1Space(const int64_t& min_tbe_l1_space); + void DelAxisSplitMapBaseAxis(std::vector& axis); + bool IsPtrNull() const; + private: + OpCalcInfoImplPtr op_calc_info_impl_{nullptr}; +}; +} // namespace fe +#endif // INC_COMMON_UTILS_AI_CORE_OP_SLICE_INFO_H diff --git a/inc/metadef/register/graph_optimizer/fusion_common/pattern_fusion_base_pass.h b/inc/metadef/register/graph_optimizer/fusion_common/pattern_fusion_base_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..af2432bf14e87463537facd4da0d0150d29701f6 --- /dev/null +++ b/inc/metadef/register/graph_optimizer/fusion_common/pattern_fusion_base_pass.h @@ -0,0 +1,197 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_PATTERN_FUSION_BASE_PASS_H_ +#define INC_REGISTER_GRAPH_OPTIMIZER_PATTERN_FUSION_BASE_PASS_H_ + +#include +#include +#include +#include +#include +#include + +#include "common/opskernel/ops_kernel_info_store.h" +#include "register/graph_optimizer/graph_fusion/fusion_pattern.h" +#include "register/graph_optimizer/graph_fusion/graph_pass.h" +#include "register/graph_optimizer/graph_fusion/connection_matrix.h" + +namespace fe { +using std::initializer_list; +using std::map; +using std::string; +using std::vector; +using namespace std; + +using OpsKernelInfoStorePtr = std::shared_ptr; +class PatternFusionBasePassImpl; +using PatternFusionBasePassImplPtr = std::shared_ptr; + +/** Pass based on pattern + * @ingroup FUSION_PASS_GROUP + * @note New virtual methods should be append at the end of this class + */ +class PatternFusionBasePass : public GraphPass { + public: + using OpDesc = FusionPattern::OpDesc; + using Mapping = std::map, std::vector, CmpKey>; + using Mappings = std::vector; + + PatternFusionBasePass(); + ~PatternFusionBasePass() override; + + + /** execute pass + * + * @param [in] graph, the graph waiting for pass level optimization + * @return SUCCESS, successfully optimized the graph by the pass + * @return NOT_CHANGED, the graph did not change + * @return FAILED, fail to modify graph + */ + virtual Status Run(ge::ComputeGraph &graph) override; + + /** execute pass + * + * @param [in] graph, the graph waiting for pass level optimization + * @param [ops_kernel_info_store_ptr, OP info kernel instance + * @return SUCCESS, successfully optimized the graph by the pass + * @return NOT_CHANGED, the graph did not change + * @return FAILED, fail to modify graph + */ + virtual Status Run(ge::ComputeGraph &graph, 
OpsKernelInfoStorePtr ops_kernel_info_store_ptr); + + /* Detect whether there are cycles in graph + * after fusing all nodes in param fusion_nodes. + * + * Compared with Cycle Detection + * @param fusion_nodes: each vector in fusion_nodes + * will be fused into an entity(which could contains + * more than one node). The caller should put all original + * nodes which are expected to be fused into one larger node + * into each sub-vector of fusion_nodes. + * + * This function can tell whether there are a cycle after + * fusing all nodes in fusion_nodes. Each vector in 2-d + * vector fusion_nodes will be fused into an entity. + * + * + * This interface cannot detect whether there are cycles + * inside the fused nodes. + * + * e.g. {a, b, c, d} -> {e, f} + * Because the edge information is not given for e and f + * so this function we cannot tell if e and f are in a + * cycle. + * */ + bool CycleDetection(const ge::ComputeGraph &graph, const std::vector> &fusion_nodes); + + bool CycleDetection(const ge::ComputeGraph &graph, const std::vector &fusion_nodes); + + void GetConnectionMatrix(std::unique_ptr &connection_matrix); + + void SetConnectionMatrix(std::unique_ptr &connection_matrix); + + const std::vector &GetPatterns(); + + const std::vector &GetInnerPatterns(); + + bool MatchFromOutput(const ge::NodePtr &output_node, const std::shared_ptr &output_op_desc, Mapping &mapping); + + protected: + virtual std::vector DefinePatterns() = 0; + + virtual Status Fusion(ge::ComputeGraph &graph, Mapping &mapping, std::vector &new_nodes) = 0; + + virtual std::vector DefineInnerPatterns(); + + virtual void SetDataDumpAttr(const std::vector &fused_nodes, + const std::vector &fusion_nodes); + + virtual void SetOriginalOutputDumpAttr(const std::vector &fused_nodes, + const std::vector &fusion_nodes); + + virtual void SetOriginalOpDumpAttr(const std::vector &fused_nodes, + const std::vector &fusion_nodes); + + void SetActualFusedNodes(const std::vector &fused_nodes); + + std::vector 
GetNodesFromMapping(const Mapping &mapping) const; + ge::NodePtr GetNodeFromMapping(const std::string &id, const Mapping &mapping) const; + + void RecordOutputAnchorMap(ge::NodePtr output_node); + + void ClearOutputAnchorMap(); + + bool CheckOpSupported(const ge::OpDescPtr &op_desc_ptr) const; + + bool CheckOpSupported(const ge::NodePtr &node) const; + + bool CheckAccuracySupported(const ge::NodePtr &node) const; + + /** check whether the input graph is Cyclic + * + * @param graph need to be checked + * @return false or true + */ + bool CheckGraphCycle(ge::ComputeGraph &graph) const; + + void DumpMapping(const FusionPattern &pattern, const Mapping &mapping) const; + + private: + /** match all nodes in graph according to pattern + * + * @param pattern fusion pattern defined + * @param mappings match result + * @return SUCCESS, successfully add edge + * @return FAILED, fail + */ + bool MatchAll(const ge::ComputeGraph &graph, const FusionPattern &pattern, Mappings &mappings); + + Status RunOnePattern(ge::ComputeGraph &graph, const FusionPattern &pattern, bool &changed); + + /* Check whether there are cycles after fusing scope_nodes as an + * entity. The algorithm is: + * If one of the output node of scope nodes has an edged linked to + * the scope nodes again, there will be a cycle. + * e.g. + * A + * / \ + * B \ + * / \ + * D------->C + * | | + * After fusion A/B/C, the graph looks like: + * <--- + * / \ + * ABC--->D + * There obviously a cycle in the fused graph. + * */ + bool DetectOneScope(const std::vector &scope_nodes) const; + + bool CheckEachPeerOut(const ge::NodePtr &node, + const std::unordered_set &scope_nodes_set, + const std::vector &scope_nodes) const; + + void StoreOriginOpNames(const Mapping &mapping, std::vector &origin_op_names) const; + + /** Internal implement class ptr */ + std::shared_ptr pattern_fusion_base_pass_impl_ptr_; + + std::unordered_map> origin_op_anchors_map_; + + /* For detecting cycles, we will only build connectivity once. 
+ * One time generation of connectivity needs O(n+e) where n is + * total number of nodes and e is total number of edges, which is + * not tolerable. And this requires one pass only executed once. + * */ + std::unique_ptr connectivity_{nullptr}; +}; +} // namespace fe + +#endif // INC_REGISTER_GRAPH_OPTIMIZER_PATTERN_FUSION_BASE_PASS_H_ diff --git a/inc/metadef/register/graph_optimizer/fusion_common/unknown_shape_utils.h b/inc/metadef/register/graph_optimizer/fusion_common/unknown_shape_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..238eb351ee33367b9f1201c0188f71ba2d9c46da --- /dev/null +++ b/inc/metadef/register/graph_optimizer/fusion_common/unknown_shape_utils.h @@ -0,0 +1,43 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_UNKNOWN_SHAPE_UTILS_H +#define INC_REGISTER_GRAPH_OPTIMIZER_UNKNOWN_SHAPE_UTILS_H +#include "graph/utils/graph_utils.h" +namespace fe { +class UnknownShapeUtils { +public: + /* + * @ingroup fe + * @brief check whether the node is unknown shape. + * @param [in] input or output tensor. + * @return true: unknown; false: known + */ + static bool IsUnknownShapeOp(const ge::OpDesc &op_desc); + + /* + * @ingroup fe + * @brief check whether the input or output shape contains -2. 
+ * @param op_desc input or output desc. + * @return true: contains; false: not contains + */ + static bool IsContainUnknownDimNum(const ge::OpDesc &op_desc); + + /* + * @brief check whether the value is -1 or -2 + * @param input or ourput shape dim + * @return true: contains; false: not contains + */ + static bool IsUnknownShapeValue(const int64_t &value); +private: + static bool IsUnKnownShapeTensor(const ge::OpDesc &op_desc); +}; +} // namespace fe + +#endif diff --git a/inc/metadef/register/graph_optimizer/graph_fusion/connection_matrix.h b/inc/metadef/register/graph_optimizer/graph_fusion/connection_matrix.h new file mode 100644 index 0000000000000000000000000000000000000000..3995e5b156cfa4eb683058467950a109fb57b451 --- /dev/null +++ b/inc/metadef/register/graph_optimizer/graph_fusion/connection_matrix.h @@ -0,0 +1,80 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_FUSION_CONNECTION_MATRIX_H_ +#define INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_FUSION_CONNECTION_MATRIX_H_ + +#include "graph/debug/ge_attr_define.h" +#include "graph/node.h" +#include "graph/graph.h" +#include "graph/compute_graph.h" +#include "common/large_bm.h" +#include "register/graph_optimizer/graph_optimize_register_error_codes.h" + +namespace fe { +class ConnectionMatrix { +public: + ConnectionMatrix(); + explicit ConnectionMatrix(bool enable_data_flow); + explicit ConnectionMatrix(const ge::ComputeGraph &graph); + + ~ConnectionMatrix(); + + bool IsConnected(const ge::NodePtr &a, const ge::NodePtr &b) const; + + // inputs are all input nodes of parameter node. + // if there is a path between A->B, then B will own A's + // connectivity. The reason is --- + // If some node can reach A, than it can also reach B. + void SetConnectivity(const ge::Node::Vistor &inputs, const ge::NodePtr &node); + + bool IsDataConnected(const ge::NodePtr &a, const ge::NodePtr &b) const; + + /* Computes the connectivity between two nodes in the + * computation. The returned ConnectivityMatrix is constructed such that + * ConnectivityMatrix::IsConnected(a, b) returns true iff there exists a + * directed path (from producer to consumer) from 'a' to 'b'. Both data + * connection and control connection are considered for connectivity. + * A node is connected to itself. */ + void Generate(const ge::ComputeGraph &graph); + + // update reachablity map for fused nodes. 
+ void Update(const ge::ComputeGraph &graph, const std::vector &fusion_nodes); + + void BackupBitMap(); + + void RestoreBitMap(); + +private: + int64_t GetIndex(const ge::NodePtr &node) const; + + const ge::LargeBitmap &GetBitMap(const ge::NodePtr &node) const; + + ge::LargeBitmap &GetBitMap(const ge::NodePtr &node); + + ge::LargeBitmap &GetBitMap(uint64_t index); + + const ge::LargeBitmap &GetDataBitMap(const ge::NodePtr &node) const; + + ge::LargeBitmap &GetDataBitMap(const ge::NodePtr &node); + + ge::LargeBitmap &GetDataBitMap(uint64_t index); + + void SetDataConnectivity(const ge::Node::Vistor &inputs, const ge::NodePtr &node); + + bool enable_data_flow_; + size_t size_; + std::vector bit_maps; + std::vector bit_maps_back_up_; + std::vector data_bit_maps_; + std::vector data_bit_maps_back_up_; + std::map name_to_index_; +}; +} +#endif // INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_FUSION_CONNECTION_MATRIX_H_ diff --git a/inc/metadef/register/graph_optimizer/graph_fusion/fusion_pass_manager/fusion_pass_registry.h b/inc/metadef/register/graph_optimizer/graph_fusion/fusion_pass_manager/fusion_pass_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..85bbff0155c7d9afbfb6685283b860fc1bc93d0e --- /dev/null +++ b/inc/metadef/register/graph_optimizer/graph_fusion/fusion_pass_manager/fusion_pass_registry.h @@ -0,0 +1,66 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_FUSION_FUSION_PASS_MANAGER_FUSION_PASS_REGISTRY_H_ +#define INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_FUSION_FUSION_PASS_MANAGER_FUSION_PASS_REGISTRY_H_ + +#include +#include +#include +#include +#include "register/graph_optimizer/fusion_common/fusion_pass_desc.h" +#include "register/graph_optimizer/graph_fusion/graph_fusion_pass_base.h" + +namespace fe { +class FusionPassRegistry { + public: + using CreateFn = GraphPass *(*)(); + struct PassDesc { + PassAttr attr; + CreateFn create_fn; + }; + ~FusionPassRegistry(); + + static FusionPassRegistry &GetInstance(); + + void RegisterPass(const GraphFusionPassType &pass_type, const std::string &pass_name, CreateFn create_fn, + PassAttr attr) const; + + std::map GetPassDesc(const GraphFusionPassType &pass_type); + + std::map GetCreateFnByType(const GraphFusionPassType &pass_type); + + private: + FusionPassRegistry(); + class FusionPassRegistryImpl; + std::unique_ptr impl_; +}; + +class FusionPassRegistrar { + public: + FusionPassRegistrar(const GraphFusionPassType &pass_type, const std::string &pass_name, + GraphPass *(*create_fn)(), PassAttr attr); + + ~FusionPassRegistrar() {} +}; + +#define REGISTER_PASS(pass_name, pass_type, pass_class) \ + REG_PASS(pass_name, pass_type, pass_class, 0) + +#define REG_PASS(pass_name, pass_type, pass_class, attr) \ + REG_PASS_UNIQ_HELPER(__COUNTER__, pass_name, pass_type, pass_class, attr) + +#define REG_PASS_UNIQ_HELPER(ctr, pass_name, pass_type, pass_class, attr) \ + REG_PASS_UNIQ(ctr, pass_name, pass_type, pass_class, attr) + +#define REG_PASS_UNIQ(ctr, pass_name, pass_type, pass_class, attr) \ + static ::fe::FusionPassRegistrar register_fusion_pass##ctr __attribute__((unused)) = ::fe::FusionPassRegistrar( \ + pass_type, pass_name, []() -> ::fe::GraphPass * { return new (std::nothrow) pass_class(); }, attr) +} // namespace fe 
+#endif // INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_FUSION_FUSION_PASS_MANAGER_FUSION_PASS_REGISTRY_H_ diff --git a/inc/metadef/register/graph_optimizer/graph_fusion/fusion_pattern.h b/inc/metadef/register/graph_optimizer/graph_fusion/fusion_pattern.h new file mode 100644 index 0000000000000000000000000000000000000000..25999fcd8eb55651acf84ce95f15cfac5864a8ac --- /dev/null +++ b/inc/metadef/register/graph_optimizer/graph_fusion/fusion_pattern.h @@ -0,0 +1,213 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_FUSION_PATTERN_H_ +#define INC_REGISTER_GRAPH_OPTIMIZER_FUSION_PATTERN_H_ +#include +#include +#include +#include +#include + +namespace fe { +extern const uint32_t kFuzzyOutIndex; +/** Fusion pattern + * @ingroup FUSION_PASS_GROUP + * Describe Pattern of Ops waiting for fusion(Op type, etc) + */ +class FusionPattern { + public: + struct OpDesc; + using OpDescPtr = std::shared_ptr; + using OutputMapVecStr = std::initializer_list>>; + using OutputMapStr = std::initializer_list>; + using OutputMapDesc = std::map>; + /** + * @ingroup fe + * @brief description of Ops + */ + struct OpDesc { + std::string id; // Identifier + std::vector types; // the Op types of Ops + std::vector inputs; // all input Ops + OutputMapDesc outputs; // all output Ops + bool repeatable; // flag to show if match multiple Ops or not + bool check_unique; // flag op desc can be matched by only one node + bool is_output; // flag to show if the op is output node + bool is_output_fullmatch; // flag to match all output + size_t output_size; // all output size + bool allow_dumpable; + }; + + public: + explicit FusionPattern(const std::string name = ""); + ~FusionPattern(); + + /** set pattern name + * + * @param name pattern name + * @return FusionPattern + */ + FusionPattern &SetName(const std::string &name); + + /** add Op description with unknown number of args + * + * @param id pattern id + * @param types op type list + * @return FusionPattern + */ + FusionPattern &AddOpDesc(const std::string &id, const std::initializer_list &types = {}, + const bool allow_dumpable = false, const bool check_unique = false); + + /** add Op description with vector + * + * @param id pattern id + * @param types op type list + * + * @return FusionPattern + */ + FusionPattern &AddOpDesc(const std::string &id, const std::vector &types, + const bool 
allow_dumpable = false, const bool check_unique = false); + + /** set input Ops with unknown number of args + * + * @param id pattern id + * + * @param input_ids inputs to id op + * + * @return FusionPattern + */ + FusionPattern &SetInputs(const std::string &id, const std::initializer_list &input_ids); + + /** set input Ops with unknown number of args + * + * @param id pattern id + * + * @param input_ids inputs to id op + * + * @return FusionPattern + */ + FusionPattern &SetInputs(const std::string &id, const std::vector &input_ids); + + /** set output Ops with unknown number of args + * + * @param id pattern id + * + * @param output_map output map + * + * @param is_fullmatched flag of output full matched + * + * @return FusionPattern + */ + FusionPattern &SetOutputs(const std::string &id, const OutputMapStr &output_map, bool is_fullmatched = true); + + /** set output Ops with unknown number of args + * + * @param id pattern id + * + * @param output_map output map + * + * @param is_fullmatched flag of output full matched + * + * @return FusionPattern + */ + FusionPattern &SetOutputs(const std::string &id, const OutputMapVecStr &output_map, bool is_fullmatched = true); + + /** set output Op + * + * @param id pattern id + * + * @return FusionPattern + */ + FusionPattern &SetOutput(const std::string &id); + + /** build pattern and check if error exists + * + * @return True or False + */ + bool Build(); + + /** get pattern name + * + * @param id pattern id + * + * @return fusion pattern name + */ + const std::string &GetName() const; + + /** get the OpDesc of input Ops (const) + * + * @param op_desc op_desc for getting inputs + * + * @return op_desc's iniput opdesc list + */ + static const std::vector> *GetInputs(const std::shared_ptr op_desc); + + /** get the OpDesc of output Ops (const) + * + * @param op_desc op_desc for getting outputs + * + * @return op_desc's output opdesc map + */ + static const OutputMapDesc &GetOutputs(const OpDescPtr op_desc); + + /** get the 
OpDesc of output size + * + * @param op_desc op_desc for getting output size + * + * @return op_desc's output size + */ + static size_t GetOutputSize(const OpDescPtr op_desc); + + /** get the OpDesc of output Op + * + * @return pattern's output opdesc list + */ + const std::shared_ptr GetOutput() const; + + /** print pattern + * + */ + void Dump() const; + + /** get OpDesc based on ID, return nullptr if failed + * + * @param id pattern id + * + * @return pattern's output opdesc list + */ + std::shared_ptr GetOpDesc(const std::string &id) const; + + const std::vector> &GetOpDescs() const; + private: + FusionPattern(const FusionPattern &) = default; + FusionPattern &operator=(const FusionPattern &) = default; + + void SetError(); + + private: + std::string name_; + + std::vector> ops_; + + std::map> op_map_; + + std::shared_ptr output_; + + bool has_error_; +}; +struct CmpKey { + bool operator() (const std::shared_ptr &key1, + const std::shared_ptr &key2) const { + return (key1->id) < (key2->id); + } +}; +} // namespace fe + +#endif // INC_REGISTER_GRAPH_OPTIMIZER_FUSION_PATTERN_H_ diff --git a/inc/metadef/register/graph_optimizer/graph_fusion/fusion_quant_util.h b/inc/metadef/register/graph_optimizer/graph_fusion/fusion_quant_util.h new file mode 100644 index 0000000000000000000000000000000000000000..50a1455709a5c1c816effd764858cefff4091c42 --- /dev/null +++ b/inc/metadef/register/graph_optimizer/graph_fusion/fusion_quant_util.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_FUSION_QUANT_UTIL_H_ +#define INC_FUSION_QUANT_UTIL_H_ +#include "graph/node.h" +#include "register/graph_optimizer/graph_optimize_register_error_codes.h" +#include + +struct BiasOptimizeEdges { + ge::InDataAnchorPtr quant_scale; + ge::InDataAnchorPtr quant_offset; + ge::InDataAnchorPtr cube_weight; + ge::InDataAnchorPtr cube_bias; + ge::InDataAnchorPtr deq_scale; + bool isValid() { + return !(cube_weight == nullptr || cube_bias == nullptr); + } +}; + +namespace fe { +struct QuantParam { + float quant_scale; + float quant_offset; +}; + +enum class WeightMode { + WEIGHTWITH2D = 0, + WEIGHTWITH5D = 1, + RESERVED +}; + +class QuantUtil { + public: + static Status BiasOptimizeByEdge(BiasOptimizeEdges ¶m, std::vector &fusion_nodes); + static Status BiasOptimizeByEdge(ge::NodePtr &quant_node, BiasOptimizeEdges ¶m, + std::vector &fusion_nodes); + static Status BiasOptimizeByEdge(QuantParam &quant_param, BiasOptimizeEdges ¶m, + std::vector &fusion_nodes, + WeightMode cube_type = WeightMode::RESERVED); + static Status InsertFixpipeDequantScaleConvert(ge::InDataAnchorPtr deq_scale, std::vector &fusion_nodes); + static Status InsertFixpipeDequantScaleConvert(ge::InDataAnchorPtr &deq_scale, ge::InDataAnchorPtr &quant_offset, + std::vector &fusion_nodes); + static Status InsertQuantScaleConvert(ge::InDataAnchorPtr &quant_scale, ge::InDataAnchorPtr &quant_offset, + std::vector &fusion_nodes); + static Status InsertRequantScaleConvert(ge::InDataAnchorPtr &req_scale, ge::InDataAnchorPtr &quant_offset, + ge::InDataAnchorPtr &cuba_bias, std::vector &fusion_nodes); +}; +} // namespace fe +#endif diff --git a/inc/metadef/register/graph_optimizer/graph_fusion/graph_fusion_pass_base.h b/inc/metadef/register/graph_optimizer/graph_fusion/graph_fusion_pass_base.h new 
file mode 100644 index 0000000000000000000000000000000000000000..3e7e772e40cf299784bd0fe3ea44e0afb80d7447 --- /dev/null +++ b/inc/metadef/register/graph_optimizer/graph_fusion/graph_fusion_pass_base.h @@ -0,0 +1,118 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_FUSION_PASS_BASE_H_ +#define INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_FUSION_PASS_BASE_H_ + +#include +#include +#include +#include +#include + +#include "register/graph_optimizer/graph_fusion/fusion_pattern.h" +#include "register/graph_optimizer/graph_fusion/graph_pass.h" + +namespace fe { +using std::initializer_list; +using std::map; +using std::string; +using std::vector; +using namespace std; + +enum GraphFusionPassType { + BUILT_IN_GRAPH_PASS = 0, + BUILT_IN_VECTOR_CORE_GRAPH_PASS, + CUSTOM_AI_CORE_GRAPH_PASS, + CUSTOM_VECTOR_CORE_GRAPH_PASS, + SECOND_ROUND_BUILT_IN_GRAPH_PASS, + BUILT_IN_BEFORE_TRANSNODE_INSERTION_GRAPH_PASS, + BUILT_IN_PREPARE_GRAPH_PASS, + BUILT_IN_BEFORE_QUANT_OPTIMIZATION_GRAPH_PASS, + BUILT_IN_TF_TAG_NO_CONST_FODING_GRAPH_PASS, + BUILT_IN_TF_MERGE_SUB_GRAPH_PASS, + BUILT_IN_QUANT_OPTIMIZATION_GRAPH_PASS, + BUILT_IN_EN_ISA_ARCH_EXC_V300_AND_V220_GRAPH_PASS, + BUILT_IN_EN_ISA_ARCH_V100_GRAPH_PASS, + BUILT_IN_EN_ISA_ARCH_V200_GRAPH_PASS, + 
BUILT_IN_DELETE_NO_CONST_FOLDING_GRAPH_PASS, + BUILT_IN_AFTER_MULTI_DIMS_PASS, + BUILT_IN_AFTER_OPTIMIZE_STAGE1, + BUILT_IN_AFTER_OP_JUDGE, + BUILT_IN_AFTER_BUFFER_OPTIMIZE, + GRAPH_FUSION_PASS_TYPE_RESERVED +}; +class PatternFusionBasePassImpl; +using PatternFusionBasePassImplPtr = std::shared_ptr; + +/** Pass based on pattern + * @ingroup FUSION_PASS_GROUP + * @note New virtual methods should be append at the end of this class + */ +class GraphFusionPassBase : public GraphPass { + public: + using OpDesc = FusionPattern::OpDesc; + using Mapping = std::map, std::vector, CmpKey>; + using Mappings = std::vector; + + GraphFusionPassBase(); + virtual ~GraphFusionPassBase() override; + + /** execute pass + * + * @param [in] graph, the graph waiting for pass level optimization + * @return SUCCESS, successfully optimized the graph by the pass + * @return NOT_CHANGED, the graph did not change + * @return FAILED, fail to modify graph + */ + virtual Status Run(ge::ComputeGraph &graph) override; + + protected: + /** define pattern + * + * @return NA + */ + virtual std::vector DefinePatterns() = 0; + + /** do fusion according to nodes matched + * + * @param graph the graph waiting for pass level optimization + * @param new_nodes fusion result node(s) + * @return SUCCESS, successfully optimized the graph by the pass + * @return NOT_CHANGED, the graph did not change + * @return FAILED, fail to modify graph + */ + virtual Status Fusion(ge::ComputeGraph &graph, Mapping &mapping, std::vector &new_nodes) = 0; + + /** get nodes from matched result + * + * @param mapping match result + * @return nodes result + */ + static ge::NodePtr GetNodeFromMapping(const std::string &id, const Mapping &mapping); + + private: + /** match all nodes in graph according to pattern + * + * @param pattern fusion pattern defined + * @param mappings match result + * @return SUCCESS, successfully add edge + * @return FAILED, fail + */ + bool MatchAll(const ge::ComputeGraph &graph, const FusionPattern 
&pattern, Mappings &mappings) const; + + Status RunOnePattern(ge::ComputeGraph &graph, const FusionPattern &pattern, bool &changed); + + /** Internal implement class ptr */ + std::shared_ptr pattern_fusion_base_pass_impl_ptr_; +}; + +} // namespace fe + +#endif // INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_FUSION_PASS_BASE_H_ diff --git a/inc/metadef/register/graph_optimizer/graph_fusion/graph_pass.h b/inc/metadef/register/graph_optimizer/graph_fusion/graph_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..149faf0a14aaec93b35fc2009609ec4dd07367c1 --- /dev/null +++ b/inc/metadef/register/graph_optimizer/graph_fusion/graph_pass.h @@ -0,0 +1,35 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_PASS_H_ +#define INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_PASS_H_ + +#include "register/graph_optimizer/graph_fusion/pass.h" + +namespace fe { + +/** graph pass + * @ingroup GRAPH_PASS_GROUP + * graph level pass + */ +class GraphPass : public Pass { + public: + /** execute pass + * + * @param [in] graph, the graph waiting for pass level optimization + * @return SUCCESS, successfully optimized the graph by the pass + * @return NOT_CHANGED, the graph did not change + * @return FAILED, fail to modify graph + */ + virtual Status Run(ge::ComputeGraph &graph) override = 0; +}; + +} // namespace fe + +#endif // INC_REGISTER_GRAPH_OPTIMIZER_GRAPH_PASS_H_ diff --git a/inc/metadef/register/graph_optimizer/graph_fusion/pass.h b/inc/metadef/register/graph_optimizer/graph_fusion/pass.h new file mode 100644 index 0000000000000000000000000000000000000000..643c2e93d18c18261055464986368eb8c248445d --- /dev/null +++ b/inc/metadef/register/graph_optimizer/graph_fusion/pass.h @@ -0,0 +1,48 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +/** @defgroup FUSION_PASS_GROUP Fusion Pass Interface */ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZER_PASS_H_ +#define INC_REGISTER_GRAPH_OPTIMIZER_PASS_H_ + +#include "graph/compute_graph.h" +#include "register/graph_optimizer/graph_optimize_register_error_codes.h" + +namespace fe { + +/** fusion pass + * @ingroup GRAPH_PASS_GROUP + * network level pass + */ +template +class Pass { + public: + virtual ~Pass() {} + + /** execute pass + * + * @param [in] graph, the graph waiting for pass level optimization + * @return SUCCESS, successfully optimized the graph by the pass + * @return NOT_CHANGED, the graph did not change + * @return FAILED, fail to modify graph + */ + virtual Status Run(ge::ComputeGraph &graph) = 0; + + void SetName(const std::string &name) { name_ = name; } + + std::string GetName() { return name_; } + + private: + std::string name_; +}; + +} // namespace fe + +#endif // INC_REGISTER_GRAPH_OPTIMIZER_PASS_H_ diff --git a/inc/metadef/register/graph_optimizer/graph_optimize_register_error_codes.h b/inc/metadef/register/graph_optimizer/graph_optimize_register_error_codes.h new file mode 100644 index 0000000000000000000000000000000000000000..2b7054b646e4036d62b63ffacf549e9bc17f45e5 --- /dev/null +++ b/inc/metadef/register/graph_optimizer/graph_optimize_register_error_codes.h @@ -0,0 +1,48 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 
+ * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_REGISTER_GRAPH_OPTIMIZE_REGISTER_ERROR_CODES_H_ +#define INC_REGISTER_GRAPH_OPTIMIZE_REGISTER_ERROR_CODES_H_ + +#include +#include +#include + +namespace fe { + +/** Assigned SYS ID */ +const uint8_t SYSID_FE = 3; + +/** Common module ID */ +const uint8_t FE_MODID_COMMON = 50; + +/** FE error code definiton Macro +* Build error code +*/ +#define FE_DEF_ERRORNO(sysid, modid, name, value, desc) \ + static constexpr fe::Status name = \ + ((((static_cast((0xFF) & (static_cast(sysid)))) << 24) | \ + ((static_cast((0xFF) & (static_cast(modid)))) << 16)) | \ + ((0xFFFF) & (static_cast(value)))); + +using Status = uint32_t; + +#define FE_DEF_ERRORNO_COMMON(name, value, desc) \ + FE_DEF_ERRORNO(SYSID_FE, FE_MODID_COMMON, name, value, desc) + +using Status = uint32_t; + +FE_DEF_ERRORNO(0, 0, SUCCESS, 0, "success"); +FE_DEF_ERRORNO(0xFF, 0xFF, FAILED, 0xFFFF, "failed"); +FE_DEF_ERRORNO_COMMON(NOT_CHANGED, 201, "The nodes of the graph not changed."); +FE_DEF_ERRORNO_COMMON(PARAM_INVALID, 1, "Parameter's invalid!"); +FE_DEF_ERRORNO_COMMON(GRAPH_FUSION_CYCLE, 301, "Graph is cycle after fusion!"); + +} // namespace fe +#endif // INC_REGISTER_GRAPH_OPTIMIZE_REGISTER_ERROR_CODES_H_ diff --git a/inc/metadef/register/host_cpu_context.h b/inc/metadef/register/host_cpu_context.h new file mode 100644 index 0000000000000000000000000000000000000000..e28dadc941f2c6966b0bfd300cc007b3935026d0 --- /dev/null +++ b/inc/metadef/register/host_cpu_context.h @@ -0,0 +1,27 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_REGISTER_HOST_CPU_CONTEXT_H_ +#define INC_REGISTER_HOST_CPU_CONTEXT_H_ + +#include "external/ge_common/ge_api_error_codes.h" +#include "register/register_types.h" + +namespace ge { +class HostCpuContext { + public: + HostCpuContext() = default; + ~HostCpuContext() = default; + private: + class Impl; + Impl *impl_; +}; +} // namespace ge + +#endif // INC_REGISTER_HOST_CPU_CONTEXT_H_ diff --git a/inc/metadef/register/infer_axis_slice_registry.h b/inc/metadef/register/infer_axis_slice_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..2cf89e0f85f4c296f6065240baea1dd8ccfe5a48 --- /dev/null +++ b/inc/metadef/register/infer_axis_slice_registry.h @@ -0,0 +1,81 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_REGISTER_INFER_AXIS_SLICE_REGISTRY_H_ +#define INC_REGISTER_INFER_AXIS_SLICE_REGISTRY_H_ + +#include "external/graph/ge_error_codes.h" +#include "external/graph/operator.h" +#include "external/graph/types.h" +#include "graph/axis_type_info.h" + +namespace ge { +// cut tensor : axis index : slice range +using DataSliceInfo = std::vector>>; +using InferAxisTypeInfoFunc = std::function &)>; +using InferAxisSliceFunc = std::function; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferAxisTypeInfoFuncRegister { + public: + InferAxisTypeInfoFuncRegister(const std::string &operator_type, + const InferAxisTypeInfoFunc &infer_axis_type_info_func); + InferAxisTypeInfoFuncRegister(const char_t *const operator_type, + const InferAxisTypeInfoFunc &infer_axis_type_info_func); + ~InferAxisTypeInfoFuncRegister() = default; +}; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferAxisSliceFuncRegister { + public: + InferAxisSliceFuncRegister(const std::string &operator_type, const InferAxisSliceFunc &infer_axis_slice_func); + InferAxisSliceFuncRegister(const char_t *const operator_type, const InferAxisSliceFunc &infer_axis_slice_func); + ~InferAxisSliceFuncRegister() = default; +}; +} + +#define PASTE(g_register, y) g_register##y + +// infer axis type info func register +#define IMPLEMT_COMMON_INFER_AXIS_TYPE_INFO(func_name) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static ge::graphStatus (func_name)(ge::Operator &op, \ + std::vector &axis_type) + +#define IMPLEMT_INFER_AXIS_TYPE_INFO(op_name, func_name) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static ge::graphStatus (func_name)((op_name) &op, \ + std::vector &axis_type) + +#define INFER_AXIS_TYPE_INFO_FUNC(op_name, x) [](ge::Operator &v, std::vector &axis_type) \ + { return (x)(v, axis_type); } + +#define INFER_AXIS_TYPE_INFO_REG_IMPL(op_name, x, n) \ + static 
const ge::InferAxisTypeInfoFuncRegister PASTE(ids_register, n)(#op_name, (x)) + +#define INFER_AXIS_TYPE_INFO_REG(op_name, x) \ + INFER_AXIS_TYPE_INFO_REG_IMPL(op_name, INFER_AXIS_TYPE_INFO_FUNC(op_name, x), __COUNTER__) + +// infer axis slice func register +#define IMPLEMT_COMMON_INFER_AXIS_SLICE(func_name) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static ge::graphStatus (func_name)(ge::Operator &op, \ + const ge::AxisTypeInfo &axis_info, const ge::DataSliceInfo &output_param, ge::DataSliceInfo &input_param) + +#define IMPLEMT_INFER_AXIS_SLICE(op_name, func_name) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static ge::graphStatus (func_name)((op_name) &op, \ + const ge::AxisTypeInfo &axis_info, const ge::DataSliceInfo &output_param, ge::DataSliceInfo &input_param) + +#define INFER_AXIS_SLICE_FUNC(op_name, x) [](ge::Operator &v, const ge::AxisTypeInfo &axis_info, \ + const ge::DataSliceInfo &output_param, ge::DataSliceInfo &input_param) \ + { return (x)(v, axis_info, output_param, input_param); } + +#define INFER_AXIS_SLICE_FUNC_REG_IMPL(op_name, x, n) \ + static const ge::InferAxisSliceFuncRegister PASTE(ids_register, n)(#op_name, (x)) + +#define INFER_AXIS_SLICE_FUNC_REG(op_name, x) \ + INFER_AXIS_SLICE_FUNC_REG_IMPL(op_name, INFER_AXIS_SLICE_FUNC(op_name, x), __COUNTER__) + +#endif // INC_REGISTER_INFER_AXIS_SLICE_REGISTRY_H_ diff --git a/inc/metadef/register/infer_data_slice_registry.h b/inc/metadef/register/infer_data_slice_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..c2c7140c30783c2bce1f8ed0e2714681d3b2fbe5 --- /dev/null +++ b/inc/metadef/register/infer_data_slice_registry.h @@ -0,0 +1,42 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_REGISTER_INFER_DATA_SLICE_REGISTRY_H_ +#define INC_REGISTER_INFER_DATA_SLICE_REGISTRY_H_ + +#include "external/graph/ge_error_codes.h" +#include "external/graph/operator.h" +#include "external/graph/types.h" + +namespace ge { +using InferDataSliceFunc = std::function; + +class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferDataSliceFuncRegister { + public: + InferDataSliceFuncRegister(const char_t *const operator_type, const InferDataSliceFunc &infer_data_slice_func); + ~InferDataSliceFuncRegister() = default; +}; +} // namespace ge + +// infer data slice func register +#define IMPLEMT_COMMON_INFER_DATA_SLICE(func_name) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus (func_name)(Operator &op) + +#define IMPLEMT_INFER_DATA_SLICE(op_name, func_name) \ + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY static graphStatus (func_name)(op::op_name &op) + +#define INFER_DATA_SLICE_FUNC(op_name, x) [](Operator &v) { return (x)((op::op_name &)v); } + +#define INFER_DATA_SLICE_FUNC_REG_IMPL(op_name, x, n) \ + static const InferDataSliceFuncRegister PASTE(ids_register, n)(#op_name, (x)) + +#define INFER_DATA_SLICE_FUNC_REG(op_name, x) \ + INFER_DATA_SLICE_FUNC_REG_IMPL(op_name, INFER_DATA_SLICE_FUNC(op_name, x), __COUNTER__) + +#endif // INC_REGISTER_INFER_DATA_SLICE_REGISTRY_H_ diff --git a/inc/metadef/register/kernel_registry.h b/inc/metadef/register/kernel_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..4a29665eb903a49d1a012151a90848639633724a --- /dev/null +++ 
b/inc/metadef/register/kernel_registry.h @@ -0,0 +1,91 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef B369E37D560547C2B8DC137404F9713E_H +#define B369E37D560547C2B8DC137404F9713E_H +#include +#include +#include +#include +#include "graph/ge_error_codes.h" +#include "graph/types.h" +#include "graph/fast_graph/fast_node.h" +#include "exe_graph/runtime/base_type.h" +#include "exe_graph/runtime/kernel_context.h" +#include "exe_graph/runtime/dfx_info_filler.h" + +namespace ge { +class Node; +} // namespace ge + +namespace gert { +class KernelRegistry { + public: + static KernelRegistry &GetInstance(); + static void ReplaceKernelRegistry(std::shared_ptr registry); + + using CreateOutputsFunc = UINT32 (*)(const ge::FastNode *, KernelContext *); + using KernelFunc = UINT32 (*)(KernelContext *context); + using TracePrinter = std::vector (*)(const KernelContext *); + using ProfilingInfoFiller = ge::graphStatus (*)(const KernelContext *, ProfilingInfoWrapper &); + using DataDumpInfoFiller = ge::graphStatus (*)(const KernelContext *, DataDumpInfoWrapper &); + using ExceptionDumpInfoFiller = ge::graphStatus (*)(const KernelContext *, ExceptionDumpInfoWrapper &); + + struct KernelFuncs { + KernelFunc run_func; + CreateOutputsFunc outputs_creator; + TracePrinter trace_printer; + ProfilingInfoFiller 
profiling_info_filler; + DataDumpInfoFiller data_dump_info_filler; + ExceptionDumpInfoFiller exception_dump_info_filler; + }; + + struct KernelInfo { + KernelFuncs func; + std::string critical_section; + }; + + virtual ~KernelRegistry() = default; + virtual const KernelFuncs *FindKernelFuncs(const std::string &kernel_type) const = 0; + virtual const KernelInfo *FindKernelInfo(const std::string &kernel_type) const = 0; + virtual void RegisterKernel(std::string kernel_type, KernelInfo kernel_infos) { + (void) kernel_type; + (void) kernel_infos; + }; +}; + +class KernelRegisterData; +class KernelRegisterV2 { + public: + explicit KernelRegisterV2(const ge::char_t *kernel_type); + KernelRegisterV2(const KernelRegisterV2 &other); + ~KernelRegisterV2(); + KernelRegisterV2 &operator=(const KernelRegisterV2 &other) = delete; + KernelRegisterV2 &operator=(KernelRegisterV2 &&other) = delete; + KernelRegisterV2(KernelRegisterV2 &&other) = delete; + + KernelRegisterV2 &RunFunc(KernelRegistry::KernelFunc func); + KernelRegisterV2 &ConcurrentCriticalSectionKey(const std::string &critical_section_key); + + KernelRegisterV2 &OutputsCreator(KernelRegistry::CreateOutputsFunc func); + KernelRegisterV2 &TracePrinter(KernelRegistry::TracePrinter func); + KernelRegisterV2 &ProfilingInfoFiller(KernelRegistry::ProfilingInfoFiller func); + KernelRegisterV2 &DataDumpInfoFiller(KernelRegistry::DataDumpInfoFiller func); + KernelRegisterV2 &ExceptionDumpInfoFiller(KernelRegistry::ExceptionDumpInfoFiller func); + + private: + std::unique_ptr register_data_; +}; +} // namespace gert + +#define REGISTER_KERNEL_COUNTER2(type, counter) static auto g_register_kernel_##counter = gert::KernelRegisterV2(#type) +#define REGISTER_KERNEL_COUNTER(type, counter) REGISTER_KERNEL_COUNTER2(type, counter) +#define REGISTER_KERNEL(type) REGISTER_KERNEL_COUNTER(type, __COUNTER__) + +#endif diff --git a/inc/metadef/register/kernel_registry_impl.h b/inc/metadef/register/kernel_registry_impl.h new file mode 100644 
index 0000000000000000000000000000000000000000..71c91dc8b3351fef09a5987c6e3c7f447f847ef7
--- /dev/null
+++ b/inc/metadef/register/kernel_registry_impl.h
@@ -0,0 +1,32 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_EXTERNAL_REGISTER_KERNEL_REGISTER_IMPL_H_
+#define INC_EXTERNAL_REGISTER_KERNEL_REGISTER_IMPL_H_
+#include
+#include
+
+#include "kernel_registry.h"
+
+namespace gert {
+class KernelRegistryImpl : public KernelRegistry {  // Concrete KernelRegistry backed by the kernel_infos_ unordered_map.
+ public:
+  static KernelRegistryImpl &GetInstance();
+  void RegisterKernel(std::string kernel_type, KernelInfo kernel_infos) override;  // replaces the base-class no-op
+  const KernelFuncs *FindKernelFuncs(const std::string &kernel_type) const override;
+  const KernelInfo *FindKernelInfo(const std::string &kernel_type) const override;
+
+  const std::unordered_map &GetAll() const;  // presumably exposes kernel_infos_ read-only; confirm in the .cc
+
+ private:
+  std::unordered_map kernel_infos_;
+};
+}  // namespace gert
+
+#endif  // INC_EXTERNAL_REGISTER_KERNEL_REGISTER_IMPL_H_
diff --git a/inc/metadef/register/node_converter_registry.h b/inc/metadef/register/node_converter_registry.h
new file mode 100644
index 0000000000000000000000000000000000000000..8fdd84c357eb2c89b040319fcb800e4f35f0460f
--- /dev/null
+++ b/inc/metadef/register/node_converter_registry.h
@@ -0,0 +1,73 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef AIR_CXX_RUNTIME_V2_LOWERING_NODE_CONVERTER_REGISTRY_H_
+#define AIR_CXX_RUNTIME_V2_LOWERING_NODE_CONVERTER_REGISTRY_H_
+#include
+#include
+#include
+
+#include "graph/node.h"
+#include "external/graph/types.h"
+#include "exe_graph/lowering/dev_mem_value_holder.h"
+#include "exe_graph/lowering/lowering_global_data.h"
+
+namespace gert {
+struct LowerInput {  // Argument bundle passed to a NodeConverter.
+  std::vector input_shapes;
+  std::vector input_addrs;
+  LoweringGlobalData *global_data;
+};
+struct LowerResult {  // Value returned by a NodeConverter: a status plus holder/shape/address vectors.
+  HyperStatus result;
+  std::vector order_holders;
+  std::vector out_shapes;
+  std::vector out_addrs;
+};
+class NodeConverterRegistry {  // Maps a lowering-function name to its ConverterRegisterData.
+ public:
+  using NodeConverter = LowerResult (*)(const ge::NodePtr &node, const LowerInput &lower_input);
+  struct ConverterRegisterData {
+    NodeConverter converter;
+    int32_t require_placement;  // REGISTER_NODE_CONVERTER passes -1 here; meaning of -1 is not shown in this header
+  };
+  static NodeConverterRegistry &GetInstance();
+  NodeConverter FindNodeConverter(const std::string &func_name);
+  const ConverterRegisterData *FindRegisterData(const std::string &func_name) const;
+  void RegisterNodeConverter(const std::string &func_name, NodeConverter func);
+  void Register(const std::string &func_name, const ConverterRegisterData &data);
+
+ private:
+  std::unordered_map names_to_register_data_;
+};
+
+class NodeConverterRegister {  // Helper type instantiated by the REGISTER_NODE_CONVERTER* macros below.
+ public:
+  NodeConverterRegister(const ge::char_t *lower_func_name, NodeConverterRegistry::NodeConverter func) noexcept;
+  NodeConverterRegister(const ge::char_t *lower_func_name, int32_t require_placement,
+                        NodeConverterRegistry::NodeConverter func) noexcept;
+};
+}  // namespace gert
+
+#ifdef __GNUC__
+#define ATTRIBUTE_USED __attribute__((used))
+#else
+#define ATTRIBUTE_USED
+#endif
+// Each use defines a file-static const gert::NodeConverterRegister named g_register_node_converter_<__COUNTER__>.
+#define GERT_REGISTER_NODE_CONVERTER_COUNTER2(type, placement, func, counter) \
+  static const gert::NodeConverterRegister g_register_node_converter_##counter ATTRIBUTE_USED = \
+      gert::NodeConverterRegister(type, placement, func)
+#define GERT_REGISTER_NODE_CONVERTER_COUNTER(type, placement, func, counter) \
+  GERT_REGISTER_NODE_CONVERTER_COUNTER2(type, placement, func, counter)
+#define REGISTER_NODE_CONVERTER_PLACEMENT(type, placement, func) \
+  GERT_REGISTER_NODE_CONVERTER_COUNTER(type, placement, func, __COUNTER__)
+#define REGISTER_NODE_CONVERTER(type, func) REGISTER_NODE_CONVERTER_PLACEMENT(type, -1, func)
+
+#endif  // AIR_CXX_RUNTIME_V2_LOWERING_NODE_CONVERTER_REGISTRY_H_
diff --git a/inc/metadef/register/op_check_register.h b/inc/metadef/register/op_check_register.h
new file mode 100644
index 0000000000000000000000000000000000000000..7a1712751e6514dbb9b54728d6805be90e513233
--- /dev/null
+++ b/inc/metadef/register/op_check_register.h
@@ -0,0 +1,66 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef INC_REGISTER_OP_CHECK_REGISTER_H_ +#define INC_REGISTER_OP_CHECK_REGISTER_H_ + +#include + +#include "graph/ascend_string.h" +#include "graph/operator.h" +#include "graph/ge_error_codes.h" +#include "register/op_def.h" + +namespace optiling { +struct ReplayFuncParam { + int32_t block_dim = 0; + const char *tiling_data = nullptr; + const char *kernel_name = nullptr; + const char *entry_file = nullptr; + int32_t gentype = 0; + const char *output_kernel_file = nullptr; + char **objptr = nullptr; + int32_t task_ration = 0; + int32_t tiling_key = 0; +}; + +using REPLAY_FUNC = int32_t (*)(ReplayFuncParam ¶m, const int32_t core_type); +using GEN_SIMPLIFIEDKEY_FUNC = bool (*)(const ge::Operator &op, ge::AscendString &result); + +class OpCheckFuncRegistry { +public: + static void RegisterOpCapability(const ge::AscendString &check_type, const ge::AscendString &op_type, + OP_CHECK_FUNC func); + + static OP_CHECK_FUNC GetOpCapability(const ge::AscendString &check_type, const ge::AscendString &op_type); + + static void RegisterGenSimplifiedKeyFunc(const ge::AscendString &op_type, GEN_SIMPLIFIEDKEY_FUNC func); + + static GEN_SIMPLIFIEDKEY_FUNC GetGenSimplifiedKeyFun(const ge::AscendString &op_type); + + static PARAM_GENERALIZE_FUNC GetParamGeneralize(const ge::AscendString &op_type); + + static void RegisterParamGeneralize(const ge::AscendString &op_type, PARAM_GENERALIZE_FUNC func); + + static void RegisterReplay(const ge::AscendString &op_type, const ge::AscendString &soc_version, REPLAY_FUNC func); + static REPLAY_FUNC GetReplay(const ge::AscendString &op_type, const ge::AscendString &soc_version); + +private: + static std::map> check_op_capability_instance_; + static std::map gen_simplifiedkey_instance_; + static std::map param_generalize_instance_; + static std::map> replay_instance_; +}; + +class ReplayFuncHelper { +public: + 
ReplayFuncHelper(const ge::AscendString &op_type, const ge::AscendString &soc_version, REPLAY_FUNC func); +}; +} // end of namespace optiling +#endif // INC_REGISTER_OP_CHECK_REGISTER_H_ diff --git a/inc/metadef/register/op_ext_calc_param_registry.h b/inc/metadef/register/op_ext_calc_param_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..c22d4d4164db64ae0639665e99881a8c27b80c0b --- /dev/null +++ b/inc/metadef/register/op_ext_calc_param_registry.h @@ -0,0 +1,51 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_REGISTER_OP_EXTRA_CALC_PARAM_REGISTRY_H_ +#define INC_REGISTER_OP_EXTRA_CALC_PARAM_REGISTRY_H_ + +#include +#include +#include +#include "graph/node.h" +#include "external/ge_common/ge_api_types.h" + +namespace fe { +using OpExtCalcParamFunc = ge::Status (*)(const ge::Node &node); +class OpExtCalcParamRegistry { + public: + OpExtCalcParamRegistry() {}; + ~OpExtCalcParamRegistry() {}; + static OpExtCalcParamRegistry &GetInstance(); + OpExtCalcParamFunc FindRegisterFunc(const std::string &op_type) const; + void Register(const std::string &op_type, const OpExtCalcParamFunc func); + + private: + std::unordered_map names_to_register_func_; + }; +class OpExtGenCalcParamRegister { + public: + OpExtGenCalcParamRegister(const char *op_type, OpExtCalcParamFunc func) noexcept; +}; +} // namespace fe + +#ifdef __GNUC__ +#define ATTRIBUTE_USED __attribute__((used)) +#else +#define ATTRIBUTE_USED +#endif + +#define REGISTER_NODE_EXT_CALC_PARAM_COUNTER2(type, func, counter) \ + static const fe::OpExtGenCalcParamRegister g_reg_op_ext_gentask_##counter ATTRIBUTE_USED = \ + fe::OpExtGenCalcParamRegister(type, func) +#define REGISTER_NODE_EXT_CALC_PARAM_COUNTER(type, func, counter) \ + REGISTER_NODE_EXT_CALC_PARAM_COUNTER2(type, func, counter) +#define REGISTER_NODE_EXT_CALC_PARAM(type, func) \ + REGISTER_NODE_EXT_CALC_PARAM_COUNTER(type, func, __COUNTER__) +#endif // INC_REGISTER_OP_EXTRA_CALC_PARAM_REGISTRY_H_ diff --git a/inc/metadef/register/op_ext_gentask_registry.h b/inc/metadef/register/op_ext_gentask_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..f036832b3bcc2eafcfc04def448f2a34af1b4f83 --- /dev/null +++ b/inc/metadef/register/op_ext_gentask_registry.h @@ -0,0 +1,53 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_REGISTER_OP_EXTRA_GENTASK_REGISTRY_H
+#define INC_REGISTER_OP_EXTRA_GENTASK_REGISTRY_H
+#include
+#include
+#include
+#include "graph/node.h"
+#include "proto/task.pb.h"
+#include "external/ge_common/ge_api_types.h"
+#include "common/opskernel/ops_kernel_info_types.h"
+namespace fe {
+using OpExtGenTaskFunc = ge::Status (*)(const ge::Node &node,
+                                        ge::RunContext &context, std::vector &tasks);
+class OpExtGenTaskRegistry {  // Maps an op-type string to its OpExtGenTaskFunc.
+ public:
+  OpExtGenTaskRegistry() {};
+  ~OpExtGenTaskRegistry() {};
+  static OpExtGenTaskRegistry &GetInstance();
+  OpExtGenTaskFunc FindRegisterFunc(const std::string &op_type) const;
+  void Register(const std::string &op_type, const OpExtGenTaskFunc func);
+
+ private:
+  std::unordered_map names_to_register_func_;
+};
+
+class OpExtGenTaskRegister {  // Helper type instantiated by REGISTER_NODE_EXT_GENTASK below.
+public:
+  OpExtGenTaskRegister(const char *op_type, OpExtGenTaskFunc func) noexcept;
+};
+}  // namespace fe
+
+#ifdef __GNUC__
+#define ATTRIBUTE_USED __attribute__((used))
+#else
+#define ATTRIBUTE_USED
+#endif
+// Each use defines a file-static const fe::OpExtGenTaskRegister named g_reg_op_ext_gentask_<__COUNTER__>.
+#define REGISTER_NODE_EXT_GENTASK_COUNTER2(type, func, counter) \
+  static const fe::OpExtGenTaskRegister g_reg_op_ext_gentask_##counter ATTRIBUTE_USED = \
+      fe::OpExtGenTaskRegister(type, func)
+#define REGISTER_NODE_EXT_GENTASK_COUNTER(type, func, counter) \
+  REGISTER_NODE_EXT_GENTASK_COUNTER2(type, func, counter)
+#define REGISTER_NODE_EXT_GENTASK(type, func) \
+  REGISTER_NODE_EXT_GENTASK_COUNTER(type, func, __COUNTER__)
+#endif  // INC_REGISTER_OP_EXTRA_GENTASK_REGISTRY_H
diff --git a/inc/metadef/register/op_impl_kernel_registry.h b/inc/metadef/register/op_impl_kernel_registry.h
new file mode 100644
index 0000000000000000000000000000000000000000..48c471f5c8e94de455fa10c1387a028d07791d3a
--- /dev/null
+++ b/inc/metadef/register/op_impl_kernel_registry.h
@@ -0,0 +1,165 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/ + +#ifndef INC_50EA5B1AAF3341A28036E698708ADB64_H +#define INC_50EA5B1AAF3341A28036E698708ADB64_H +#include +#include +#include +#include "graph/ge_error_codes.h" +#include "exe_graph/runtime/base_type.h" +#include "exe_graph/runtime/infer_shape_context.h" +#include "exe_graph/runtime/infer_shape_range_context.h" +#include "exe_graph/runtime/tiling_context.h" +#include "exe_graph/runtime/op_execute_context.h" +#include "exe_graph/runtime/infer_datatype_context.h" +#include "graph/ascend_string.h" +#include "register/op_impl_registry.h" + +namespace gert { +class TilingParseContext; +struct OpImplKernelRegistry { + // for other code repo, they use those alias, but will delete later + using InferShapeKernelFunc = UINT32 (*)(InferShapeContext *); + using InferShapeRangeKernelFunc = UINT32 (*)(InferShapeRangeContext *); + using TilingKernelFunc = UINT32 (*)(TilingContext *); + using InferDataTypeKernelFunc = UINT32 (*)(InferDataTypeContext *); + using GenSimplifiedKeyKernelFunc = UINT32 (*)(TilingContext *, ge::char_t *); + // aclnn接口的原型,入参含义: + // OpExecuteContext:保存算子的Input,Output,Attr信息 + using OpExecuteFunc = UINT32 (*)(OpExecuteContext *); + using OpType = ge::AscendString; + using PrivateAttrList = std::vector>; + using PrivateAttrSet = std::unordered_set; + using CompileInfoCreatorFunc = void *(*) (); + using CompileInfoDeleterFunc = void (*)(void *); + using KernelFunc = UINT32 (*)(KernelContext *context); + using TilingParseFunc = UINT32 (*)(TilingParseContext *context); + + struct OpImplFunctions { + bool HasDataDependency() const { + return (inputs_dependency != 0U); + } + /* + * param index: must be ir index + */ + bool IsInputDataDependency(const size_t index) const { + if (index >= sizeof(inputs_dependency) * kByteBitCount) { + return false; + } + return static_cast(inputs_dependency & static_cast(1) << index); + } + ge::graphStatus 
SetInputDataDependency(const size_t index) { + if (index >= sizeof(inputs_dependency) * kByteBitCount) { + return ge::GRAPH_FAILED; + } + inputs_dependency |= 1UL << index; + return ge::GRAPH_SUCCESS; + } + + bool HasHostInput() const { + return (host_inputs != 0UL); + } + /* + * param index: must be ir index + */ + bool IsHostInput(const size_t index) const { + if (index >= (sizeof(host_inputs) * kByteBitCount)) { + return false; + } + return static_cast(host_inputs & (static_cast(1) << index)); + } + ge::graphStatus SetHostInputs(const size_t index) { + if (index >= (sizeof(host_inputs) * kByteBitCount)) { + return ge::GRAPH_FAILED; + } + host_inputs |= 1UL << index; + return ge::GRAPH_SUCCESS; + } + + bool HasTilingInputDataDependency() const { + return (tiling_dependency != 0UL); + } + /* + * param index: must be ir index + */ + bool IsTilingInputDataDependency(const size_t index) const { + if (index >= (sizeof(tiling_dependency) * kByteBitCount)) { + return false; + } + return static_cast(tiling_dependency & (static_cast(1) << index)); + } + ge::graphStatus SetTilingInputDataDependency(const size_t index) { + if (index >= (sizeof(tiling_dependency) * kByteBitCount)) { + return ge::GRAPH_FAILED; + } + tiling_dependency |= 1UL << index; + return ge::GRAPH_SUCCESS; + } + + bool IsSupportTilingDependencyPlacement(const uint32_t placement) const { + if (static_cast(placement) >= (sizeof(tiling_dependency_placements) * kByteBitCount)) { + return false; + } + + return static_cast(tiling_dependency_placements & (static_cast(1U) << placement)); + } + + ge::graphStatus SetTilingDependencyPlacement(const uint32_t placement) { + if (static_cast(placement) >= (sizeof(tiling_dependency_placements) * kByteBitCount)) { + return ge::GRAPH_FAILED; + } + tiling_dependency_placements |= (static_cast(1U) << placement); + return ge::GRAPH_SUCCESS; + } + + bool IsOutputShapeDependOnCompute() const { + return (output_shape_depend_compute != 0UL); + } + /* + * param index: must be ir 
index + */ + bool IsOutputShapeDependOnCompute(const size_t index) const { + if (index >= (sizeof(output_shape_depend_compute) * kByteBitCount)) { + return false; + } + return static_cast(output_shape_depend_compute & (static_cast(1) << index)); + } + + ge::graphStatus SetOutputShapeDependOnCompute(const size_t index) { + if (index >= (sizeof(output_shape_depend_compute) * kByteBitCount)) { + return ge::GRAPH_FAILED; + } + output_shape_depend_compute |= 1UL << index; + return ge::GRAPH_SUCCESS; + } + + OpImplRegisterV2::InferShapeKernelFunc infer_shape; + OpImplRegisterV2::InferShapeRangeKernelFunc infer_shape_range; + OpImplRegisterV2::InferDataTypeKernelFunc infer_datatype; + OpImplRegisterV2::TilingKernelFunc tiling; + OpImplRegisterV2::KernelFunc tiling_parse; + OpImplRegisterV2::CompileInfoCreatorFunc compile_info_creator; + OpImplRegisterV2::CompileInfoDeleterFunc compile_info_deleter; + size_t max_tiling_data_size = 0UL; + uint64_t inputs_dependency = 0UL; + static constexpr size_t kByteBitCount = 8UL; + OpImplRegisterV2::PrivateAttrList private_attrs; + // todo 去重和registry没关系,下一步从这里删除,移动到register中实现 + OpImplRegisterV2::PrivateAttrSet unique_private_attrs; + uint64_t host_inputs = 0UL; + OpImplRegisterV2::OpExecFunc op_execute_func; + uint64_t tiling_dependency = 0UL; + OpImplRegisterV2::GenSimplifiedKeyKernelFunc gen_simplifiedkey; + uint8_t tiling_dependency_placements = 0U; + uint64_t output_shape_depend_compute = 0UL; + }; +}; +} // namespace gert +#endif diff --git a/inc/metadef/register/op_impl_registry_api.h b/inc/metadef/register/op_impl_registry_api.h new file mode 100644 index 0000000000000000000000000000000000000000..48518422dc20738fdef961c7273243dfe946371b --- /dev/null +++ b/inc/metadef/register/op_impl_registry_api.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). 
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_EXTERNAL_REGISTER_OP_IMPL_REGISTRY_API_H_
+#define INC_EXTERNAL_REGISTER_OP_IMPL_REGISTRY_API_H_
+
+#include
+#include "op_impl_kernel_registry.h"
+
+struct TypesToImpl {  // One (op type name, function table) entry; element type of the GetOpImplFunctions buffer.
+  const char *op_type;
+  gert::OpImplKernelRegistry::OpImplFunctions funcs;
+};
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#ifdef __GNUC__
+#define METADEF_FUNC_VISIBILITY __attribute__((visibility("default")))
+#else
+#define METADEF_FUNC_VISIBILITY
+#endif
+// C-linkage entry points; under __GNUC__ they are exported with default symbol visibility.
+METADEF_FUNC_VISIBILITY size_t GetRegisteredOpNum(void);
+METADEF_FUNC_VISIBILITY int32_t GetOpImplFunctions(TypesToImpl *impl, size_t impl_num);  // presumably fills impl[0..impl_num) - confirm
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif  // INC_EXTERNAL_REGISTER_OP_IMPL_REGISTRY_API_H_
diff --git a/inc/metadef/register/op_impl_registry_base.h b/inc/metadef/register/op_impl_registry_base.h
new file mode 100644
index 0000000000000000000000000000000000000000..5a3195c4ca482defa1f94ee4ec72370d5678ba7e
--- /dev/null
+++ b/inc/metadef/register/op_impl_registry_base.h
@@ -0,0 +1,35 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_EXTERNAL_REGISTER_OP_IMPL_REGISTRY_BASE_H_
+#define INC_EXTERNAL_REGISTER_OP_IMPL_REGISTRY_BASE_H_
+#include "register/op_impl_kernel_registry.h"
+namespace gert {
+struct OpImplRegistryBase : public OpImplKernelRegistry {  // Abstract lookup interface keyed by op-type C string.
+  virtual ~OpImplRegistryBase() = default;
+  virtual const OpImplFunctions *GetOpImpl(const ge::char_t *op_type) const = 0;
+  virtual const OpImplRegisterV2::PrivateAttrList &GetPrivateAttrs(const ge::char_t *op_type) const = 0;
+};
+
+class OpImplRegistry : public OpImplRegistryBase {  // Concrete registry backed by the types_to_impl_ map.
+ public:
+  static OpImplRegistry &GetInstance();
+  OpImplFunctions &CreateOrGetOpImpl(const ge::char_t *op_type);
+  const OpImplFunctions *GetOpImpl(const ge::char_t *op_type) const override;
+  const OpImplRegisterV2::PrivateAttrList &GetPrivateAttrs(const ge::char_t *op_type) const override;
+  const std::map &GetAllTypesToImpl() const;
+  std::map &GetAllTypesToImpl();  // non-const overload hands out a mutable reference
+
+ private:
+  std::map types_to_impl_;
+  uint8_t reserved_[40] = {0U};  // Reserved field, 32+8, do not directly use when only 8-byte left
+};
+}  // namespace gert
+
+#endif
diff --git a/inc/metadef/register/op_impl_registry_holder_manager.h b/inc/metadef/register/op_impl_registry_holder_manager.h
new file mode 100644
index 0000000000000000000000000000000000000000..1ec6c38b3fe1b7dca0b1b05128f347eac3e78622
--- /dev/null
+++ b/inc/metadef/register/op_impl_registry_holder_manager.h
@@ -0,0 +1,118 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_OP_IMPL_REGISTRY_HOLDER_MANAGER_H_ +#define INC_OP_IMPL_REGISTRY_HOLDER_MANAGER_H_ + +#include +#include +#include +#include +#include "graph/op_so_bin.h" +#include "register/op_impl_registry_api.h" +#include "register/op_ct_impl_registry_api.h" + +namespace gert { + +enum class ImplType { + RT_TYPE = 0, + CT_TYPE = 1, + END_TYPE +}; + +class OpImplRegistryHolder { + public: + OpImplRegistryHolder() = default; + + virtual ~OpImplRegistryHolder(); + + std::map &GetTypesToImpl() { + return types_to_impl_; + } + + std::map &GetTypesToCtImpl() { + return types_to_ct_impl_; + } + + void SetHandle(const void *handle) { handle_ = const_cast(handle); } + + ge::graphStatus GetOpImplFunctionsByHandle(const void *handle, const std::string &so_path) const; + + std::unique_ptr GetOpImplFunctionsByHandle(const void *handle, + const std::string &so_path, + size_t &impl_num) const; + void AddTypesToImpl(const gert::OpImplRegisterV2::OpType op_type, + const gert::OpImplKernelRegistry::OpImplFunctions funcs); + + protected: + std::map types_to_impl_; + std::map types_to_ct_impl_; + std::vector impl_map_vec_{&types_to_impl_, &types_to_ct_impl_}; + void *handle_ = nullptr; +}; +using OpImplRegistryHolderPtr = std::shared_ptr; + +class OmOpImplRegistryHolder : public OpImplRegistryHolder { + public: + OmOpImplRegistryHolder() = default; + + virtual 
~OmOpImplRegistryHolder() override = default; + + ge::graphStatus LoadSo(const std::shared_ptr &so_bin); + + private: + ge::graphStatus CreateOmOppDir(std::string &opp_dir) const; + + ge::graphStatus RmOmOppDir(const std::string &opp_dir) const; + + ge::graphStatus SaveToFile(const std::shared_ptr &so_bin, const std::string &opp_path) const; +}; + +class OpImplRegistryHolderManager { + public: + OpImplRegistryHolderManager() = default; + + ~OpImplRegistryHolderManager(); + + static OpImplRegistryHolderManager &GetInstance(); + + void AddRegistry(std::string &so_data, const std::shared_ptr ®istry_holder); + + void UpdateOpImplRegistries(); + + const OpImplRegistryHolderPtr GetOpImplRegistryHolder(std::string &so_data); + + OpImplRegistryHolderPtr GetOrCreateOpImplRegistryHolder(std::string &so_data, + const std::string &so_name, + const ge::SoInOmInfo &so_info, + const std::function create_func); + + size_t GetOpImplRegistrySize() { + const std::lock_guard lock(map_mutex_); + return op_impl_registries_.size(); + } + + void ClearOpImplRegistries() { + const std::lock_guard lock(map_mutex_); + op_impl_registries_.clear(); + } + + private: + /** + * 背景:当前加载的so里边包含了算子原型等自注册(如OperatorFactoryImpl::operator_infer_axis_type_info_funcs等static变量,在进程退出前析构)。 + * 原方案使用weak_ptr管理OpImplRegistryHolder,不影响生命周期,引用计数一旦减为0,将触发析构,并关闭so的句柄; + * 如果OpImplRegistryHolder在比较早的时机析构,那么进程退出时,operator_infer_axis_type_info_funcs这些static变量将无法正常析构。 + * 此处临时改为shared_ptr,使OpImplRegistryHolder与本类单例的生命周期一致,从而能够确保在进程退出前才触发析构,规避上述问题。 + * todo 此处是规避方案,后续将继续使用weak_ptr来管理OpImplRegistryHolder,后续将梳理上述自注册机制,修改成space_registry注册机制 + * */ + std::unordered_map> op_impl_registries_; + std::mutex map_mutex_; +}; +} // namespace gert +#endif // INC_OP_IMPL_REGISTRY_HOLDER_MANAGER_H_ diff --git a/inc/metadef/register/op_impl_space_registry.h b/inc/metadef/register/op_impl_space_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..4196ce81356e296b9b6a492e3c24672c58dd29f1 --- /dev/null +++ 
b/inc/metadef/register/op_impl_space_registry.h @@ -0,0 +1,86 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_OP_IMPL_SPACE_REGISTRY_H_ +#define INC_OP_IMPL_SPACE_REGISTRY_H_ + +#include "register/op_impl_registry_holder_manager.h" +#include "graph/op_desc.h" + +namespace gert { +using OpTypesToImplMap = std::map; +using OpTypesToCtImplMap = std::map; +class OpImplSpaceRegistry : public std::enable_shared_from_this { + public: + OpImplSpaceRegistry() = default; + + ~OpImplSpaceRegistry() = default; + + ge::graphStatus GetOrCreateRegistry(const std::vector &bins, const ge::SoInOmInfo &so_info); + // 兼容air,待air合入后删除 + ge::graphStatus GetOrCreateRegistry(const std::vector &bins, const ge::SoInOmInfo &so_info, + const std::string &opp_path_identifier); + + ge::graphStatus AddRegistry(const std::shared_ptr ®istry_holder); + + const OpImplKernelRegistry::OpImplFunctions *GetOpImpl(const std::string &op_type) const; + + const OpCtImplKernelRegistry::OpCtImplFunctions *GetOpCtImpl(const std::string &op_type) const; + + const OpImplRegisterV2::PrivateAttrList &GetPrivateAttrs(const std::string &op_type) const; + + static ge::graphStatus LoadSoAndSaveToRegistry(const std::string &so_path); + + static ge::graphStatus ConvertSoToRegistry(const std::string &so_path, ge::OppImplVersion 
opp_impl_version); + + private: + void MergeTypesToImpl(OpTypesToImplMap &merged_impl, OpTypesToImplMap &src_impl) const; + void MergeTypesToCtImpl(OpTypesToCtImplMap &merged_impl, OpTypesToCtImplMap &src_impl) const; + void MergeFunctions(OpImplKernelRegistry::OpImplFunctions &merged_funcs, + const OpImplKernelRegistry::OpImplFunctions &src_funcs, const std::string &op_type) const; + void MergeCtFunctions(OpCtImplKernelRegistry::OpCtImplFunctions &merged_funcs, + const OpCtImplKernelRegistry::OpCtImplFunctions &src_funcs, const std::string &op_type) const; + private: + std::vector> op_impl_registries_; + OpTypesToImplMap merged_types_to_impl_; + OpTypesToCtImplMap merged_types_to_ct_impl_; +}; +using OpImplSpaceRegistryPtr = std::shared_ptr; +using OpImplSpaceRegistryArray = + std::array(ge::OppImplVersion::kVersionEnd)>; + +class DefaultOpImplSpaceRegistry { + public: + DefaultOpImplSpaceRegistry() = default; + + ~DefaultOpImplSpaceRegistry() = default; + + static DefaultOpImplSpaceRegistry &GetInstance(); + + const OpImplSpaceRegistryArray &GetDefaultSpaceRegistries() const { + return space_registries_; + } + + OpImplSpaceRegistryPtr &GetDefaultSpaceRegistry(ge::OppImplVersion opp_impl_version = ge::OppImplVersion::kOpp) { + if (opp_impl_version >= ge::OppImplVersion::kVersionEnd) { + opp_impl_version = ge::OppImplVersion::kOpp; + } + return space_registries_[static_cast(opp_impl_version)]; + } + + void SetDefaultSpaceRegistry(const OpImplSpaceRegistryPtr &space_registry, + ge::OppImplVersion opp_impl_version = ge::OppImplVersion::kOpp) { + space_registries_[static_cast(opp_impl_version)] = space_registry; + } + + private: + OpImplSpaceRegistryArray space_registries_; +}; +} // namespace gert +#endif // INC_OP_IMPL_SPACE_REGISTRY_H_ diff --git a/inc/metadef/register/op_kernel_registry.h b/inc/metadef/register/op_kernel_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..3fdd765c9b3a0db4c9f05e7db7a8f10623840d69 --- /dev/null +++ 
b/inc/metadef/register/op_kernel_registry.h
@@ -0,0 +1,39 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * ===================================================================================================================*/
+
+#ifndef INC_REGISTER_OP_KERNEL_REGISTRY_H_
+#define INC_REGISTER_OP_KERNEL_REGISTRY_H_
+#include
+#include
+#include "register/register_types.h"
+#include "register.h"
+
+namespace ge {
+class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpKernelRegistry {  // Registry of HostCpuOp factory functions keyed by op type.
+ public:
+  using CreateFn = HostCpuOp* (*)();  // factory: no arguments, returns a raw HostCpuOp pointer
+  ~OpKernelRegistry();
+
+  static OpKernelRegistry& GetInstance();  // the constructor is private, so this is the only public way to get an instance
+
+  bool IsRegistered(const std::string &op_type) const;
+
+  void RegisterHostCpuOp(const std::string &op_type, const CreateFn create_fn);
+
+  std::unique_ptr CreateHostCpuOp(const std::string &op_type) const;
+
+ private:
+  OpKernelRegistry();
+  class OpKernelRegistryImpl;  // defined out of this header; state is kept behind impl_
+  /*lint -e148*/
+  std::unique_ptr impl_;
+};
+}  // namespace ge
+
+#endif  // INC_REGISTER_OP_KERNEL_REGISTRY_H_
diff --git a/inc/metadef/register/op_registry.h b/inc/metadef/register/op_registry.h
new file mode 100644
index 0000000000000000000000000000000000000000..509fe7156d1163fd5b95bb53b5f680f11c257da4
--- /dev/null
+++ b/inc/metadef/register/op_registry.h
@@ -0,0 +1,90 @@
+/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_REGISTER_OP_REGISTRY_H_ +#define INC_REGISTER_OP_REGISTRY_H_ + +#include +#include +#include +#include +#include +#include + +#include "register/register.h" + +namespace domi { +enum class RemoveInputType : uint16_t { + OMG_MOVE_TYPE_DTYPE = 0, + OMG_MOVE_TYPE_VALUE = 1, + OMG_MOVE_TYPE_SHAPE = 2, + OMG_MOVE_TYPE_FORMAT = 3, + OMG_MOVE_TYPE_AXIS = 4, + OMG_MOVE_TYPE_SCALAR_VALUE = 5, + OMG_REMOVE_TYPE_WITH_COND = 1000, + OMG_REMOVE_INPUT_WITH_ORIGINAL_TYPE = 1001, + OMG_INPUT_REORDER = 1002, +}; + +struct RemoveInputConfigure { + int32_t inputIdx = INT_MAX; + std::string attrName; + RemoveInputType moveType; + bool attrValue = false; + std::string originalType; + std::vector input_order; +}; + +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistry { + public: + static OpRegistry *Instance(); + + std::vector registrationDatas; + + bool Register(const OpRegistrationData ®_data); + + domi::ImplyType GetImplyType(const std::string &op_type); + + void GetOpTypeByImplyType(std::vector &vec_op_type, const domi::ImplyType imply_type) const; + + domi::ParseParamFunc GetParseParamFunc(const std::string &op_type, const std::string &ori_type); + + domi::ParseParamByOpFunc GetParseParamByOperatorFunc(const std::string &ori_type); + + domi::FusionParseParamFunc GetFusionParseParamFunc(const std::string &op_type, const std::string &ori_type); + 
+ domi::FusionParseParamByOpFunc GetFusionParseParamByOpFunc(const std::string &op_type, + const std::string &ori_type); + + domi::ParseSubgraphFunc GetParseSubgraphPostFunc(const std::string &op_type); + + Status GetParseSubgraphPostFunc(const std::string &op_type, domi::ParseSubgraphFuncV2 &parse_subgraph_func); + + domi::ImplyType GetImplyTypeByOriOpType(const std::string &ori_optype); + + const std::vector &GetRemoveInputConfigure(const std::string &ori_optype) const; + + bool GetOmTypeByOriOpType(const std::string &ori_optype, std::string &om_type); + + ParseOpToGraphFunc GetParseOpToGraphFunc(const std::string &op_type, const std::string &ori_type); + + private: + std::unordered_map op_run_mode_map_; + std::unordered_map op_parse_params_fn_map_; + std::unordered_map parse_params_by_op_func_map_; + std::unordered_map fusion_op_parse_params_fn_map_; + std::unordered_map fusion_parse_params_by_op_fn_map_; + std::unordered_map op_types_to_parse_subgraph_post_func_; + std::unordered_map> remove_input_configure_map_; + std::map origin_type_to_om_type_; + std::unordered_map parse_op_to_graph_fn_map_; + std::unordered_map op_types_to_parse_subgraph_post_func_v2_; +}; +} // namespace domi +#endif // INC_REGISTER_OP_REGISTRY_H_ diff --git a/inc/metadef/register/op_tiling.h b/inc/metadef/register/op_tiling.h new file mode 100644 index 0000000000000000000000000000000000000000..18640179f31235a617dfebbc370a00492f95bd3d --- /dev/null +++ b/inc/metadef/register/op_tiling.h @@ -0,0 +1,22 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. 
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_REGISTER_OP_TILING_H +#define INC_REGISTER_OP_TILING_H + +#include "graph/debug/ge_attr_define.h" +#include "graph/node.h" +#include "register/op_tiling_registry.h" + +namespace optiling { +extern "C" ge::graphStatus OpParaCalculateV2(const ge::Operator &op, OpRunInfoV2 &run_info); +extern "C" ge::graphStatus OpAtomicCalculateV2(const ge::Node &node, OpRunInfoV2 &run_info); +extern "C" ge::graphStatus OpFftsPlusCalculate(const ge::Operator &op, std::vector &op_run_info); +} // namespace optiling +#endif // INC_REGISTER_OP_TILING_H diff --git a/inc/metadef/register/op_tiling/op_compile_info_manager.h b/inc/metadef/register/op_tiling/op_compile_info_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..5d0ecf6e360c4e54acd4d8d448abadf522b9ea13 --- /dev/null +++ b/inc/metadef/register/op_tiling/op_compile_info_manager.h @@ -0,0 +1,52 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef REGISTER_OP_TILING_COMPILE_INFO_MANAGER_H_ +#define REGISTER_OP_TILING_COMPILE_INFO_MANAGER_H_ + +#include +#include +#include + +#include "register/op_compile_info_base.h" + +namespace optiling { +class CompileInfoManager { +public: + CompileInfoManager(const CompileInfoManager &) = delete; + CompileInfoManager &operator=(const CompileInfoManager &) = delete; + static CompileInfoManager& Instance(); + bool HasCompileInfo(const std::string &key); + CompileInfoPtr GetCompileInfo(const std::string &key); + void SetCompileInfo(const std::string &key, CompileInfoPtr compile_info_ptr); + +private: + CompileInfoManager(); + ~CompileInfoManager(); + mutable std::mutex compile_info_mutex_; + std::unordered_map compile_info_map_; +}; + +class CompileInfoCache { +public: + CompileInfoCache(const CompileInfoCache &) = delete; + CompileInfoCache &operator=(const CompileInfoCache &) = delete; + static CompileInfoCache& Instance(); + bool HasCompileInfo(const std::string &key); + void* GetCompileInfo(const std::string &key); + void SetCompileInfo(const std::string &key, void* value); + +private: + CompileInfoCache(); + ~CompileInfoCache(); + mutable std::mutex compile_info_mutex_; + std::unordered_map compile_info_map_; +}; +} // namespace optiling +#endif // REGISTER_OP_TILING_COMPILE_INFO_MANAGER_H_ diff --git a/inc/metadef/register/op_tiling/op_tiling_constants.h b/inc/metadef/register/op_tiling/op_tiling_constants.h new file mode 100644 index 0000000000000000000000000000000000000000..9b088288b6a3838da67e9640020b3df1c8849d4f --- /dev/null +++ b/inc/metadef/register/op_tiling/op_tiling_constants.h @@ -0,0 +1,49 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. 
You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef REGISTER_OP_TILING_OP_TILING_CONSTANTS_H_ +#define REGISTER_OP_TILING_OP_TILING_CONSTANTS_H_ + +#include +#include +#include "graph/types.h" + +namespace optiling { +const std::string COMPILE_INFO_JSON = "compile_info_json"; +const std::string COMPILE_INFO_KEY = "compile_info_key"; +const std::string COMPILE_INFO_WORKSPACE_SIZE_LIST = "_workspace_size_list"; +const std::string ATOMIC_COMPILE_INFO_JSON = "_atomic_compile_info_json"; +const std::string ATOMIC_COMPILE_INFO_KEY = "_atomic_compile_info_key"; +const std::string ATTR_NAME_ATOMIC_CLEAN_WORKSPACE = "_optiling_atomic_add_mem_size"; +const std::string ATTR_NAME_OP_INFER_DEPENDS = "_op_infer_depends"; +const std::string OP_TYPE_DYNAMIC_ATOMIC_ADDR_CLEAN = "DynamicAtomicAddrClean"; +const std::string OP_TYPE_AUTO_TILING = "AutoTiling"; +constexpr const char* kMemoryCheck = "_memcheck"; +constexpr const char* kOriOpParaSize = "ori_op_para_size"; +const std::map DATATYPE_STRING_MAP { + {ge::DT_FLOAT, "float32"}, + {ge::DT_FLOAT16, "float16"}, + {ge::DT_INT8, "int8"}, + {ge::DT_INT16, "int16"}, + {ge::DT_INT32, "int32"}, + {ge::DT_INT64, "int64"}, + {ge::DT_UINT8, "uint8"}, + {ge::DT_UINT16, "uint16"}, + {ge::DT_UINT32, "uint32"}, + {ge::DT_UINT64, "uint64"}, + {ge::DT_BOOL, "bool"}, + {ge::DT_DOUBLE, "double"}, + {ge::DT_DUAL, "dual"}, + {ge::DT_DUAL_SUB_INT8, "dual_sub_int8"}, + {ge::DT_DUAL_SUB_UINT8, "dual_sub_uint8"} +}; + +} // namespace optiling + +#endif // REGISTER_OP_TILING_OP_TILING_CONSTANTS_H_ diff --git 
a/inc/metadef/register/op_tiling/op_tiling_utils.h b/inc/metadef/register/op_tiling/op_tiling_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..3383643493f664a0b17d3f3d4a1858ac90413256 --- /dev/null +++ b/inc/metadef/register/op_tiling/op_tiling_utils.h @@ -0,0 +1,33 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef REGISTER_OP_TILING_OP_TILING_UTILS_H_ +#define REGISTER_OP_TILING_OP_TILING_UTILS_H_ + +#include +#include +#include "graph/op_desc.h" +#include "graph/debug/ge_log.h" + +namespace optiling { +void ReplaceEmptyShapeOfTensorDesc(const ge::OpDescPtr &op_desc, std::vector &indexes); +void RecoveryEmptyShapeOfTensorDesc(const ge::OpDescPtr &op_desc, const std::vector &indexes); + +#define OP_TILING_MAKE_SHARED(exec_expr0, exec_expr1) \ + do { \ + try { \ + exec_expr0; \ + } catch (...) 
{ \ + GE_LOGE("Make shared failed"); \ + exec_expr1; \ + } \ + } while (0) + +} // namespace optiling +#endif // REGISTER_OP_TILING_OP_TILING_UTILS_H_ diff --git a/inc/metadef/register/ops_kernel_builder_registry.h b/inc/metadef/register/ops_kernel_builder_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..b282bf02e8cc8976d512cd25a78be8e33d319a28 --- /dev/null +++ b/inc/metadef/register/ops_kernel_builder_registry.h @@ -0,0 +1,62 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_REGISTER_OPS_KERNEL_BUILDER_REGISTRY_H +#define INC_REGISTER_OPS_KERNEL_BUILDER_REGISTRY_H + +#include +#include "register/register_types.h" +#include "common/opskernel/ops_kernel_builder.h" + +namespace ge { +using OpsKernelBuilderPtr = std::shared_ptr; + +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpsKernelBuilderRegistry { + public: + ~OpsKernelBuilderRegistry() noexcept; + static OpsKernelBuilderRegistry &GetInstance(); + + void Register(const std::string &lib_name, const OpsKernelBuilderPtr &instance); + + void Unregister(const std::string &lib_name); + + void UnregisterAll(); + + const std::map &GetAll() const; + + private: + OpsKernelBuilderRegistry() = default; + std::map kernel_builders_; +}; + +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpsKernelBuilderRegistrar { + public: + using CreateFn = OpsKernelBuilder *(*)(); + OpsKernelBuilderRegistrar(const std::string &kernel_lib_name, const CreateFn fn); + ~OpsKernelBuilderRegistrar() noexcept; + +private: + std::string kernel_lib_name_; +}; +} // namespace ge + +#define REGISTER_OPS_KERNEL_BUILDER(kernel_lib_name, builder) \ + REGISTER_OPS_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_lib_name, builder) + +#define REGISTER_OPS_KERNEL_BUILDER_UNIQ_HELPER(ctr, kernel_lib_name, builder) \ + REGISTER_OPS_KERNEL_BUILDER_UNIQ(ctr, kernel_lib_name, builder) + +#define REGISTER_OPS_KERNEL_BUILDER_UNIQ(ctr, kernel_lib_name, builder) \ + static ::ge::OpsKernelBuilderRegistrar register_op_kernel_builder_##ctr \ + __attribute__((unused)) = \ + ::ge::OpsKernelBuilderRegistrar((kernel_lib_name), []()->::ge::OpsKernelBuilder* { \ + return new (std::nothrow) (builder)(); \ + }) + +#endif // INC_REGISTER_OPS_KERNEL_BUILDER_REGISTRY_H diff --git a/inc/metadef/register/prototype_pass_registry.h b/inc/metadef/register/prototype_pass_registry.h new file mode 100644 index 
0000000000000000000000000000000000000000..c90f3e9f8adda813287d88d7d45cd36c1048e1ce --- /dev/null +++ b/inc/metadef/register/prototype_pass_registry.h @@ -0,0 +1,73 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef METADEF_PROTOTYPE_PASS_REGISTRY_H +#define METADEF_PROTOTYPE_PASS_REGISTRY_H + +#include + +#include +#include +#include + +#include "external/ge_common/ge_api_error_codes.h" +#include "external/graph/types.h" +#include "register/register_error_codes.h" +#include "register/register_fmk_types.h" + +namespace ge { +class ProtoTypeBasePass { + public: + ProtoTypeBasePass() = default; + virtual Status Run(google::protobuf::Message *message) = 0; + virtual ~ProtoTypeBasePass() = default; + + private: + ProtoTypeBasePass(const ProtoTypeBasePass &) = delete; + ProtoTypeBasePass &operator=(const ProtoTypeBasePass &) & = delete; +}; + +class ProtoTypePassRegistry { + public: + using CreateFn = std::function; + ~ProtoTypePassRegistry(); + + static ProtoTypePassRegistry &GetInstance(); + + void RegisterProtoTypePass(const char_t *const pass_name, const CreateFn &create_fn, + const domi::FrameworkType fmk_type); + + std::vector> GetCreateFnByType(const domi::FrameworkType fmk_type) const; + + private: + ProtoTypePassRegistry(); + class ProtoTypePassRegistryImpl; + std::unique_ptr impl_; +}; 
+ +class ProtoTypePassRegistrar { + public: + ProtoTypePassRegistrar(const char_t *const pass_name, ProtoTypeBasePass *(*const create_fn)(), + const domi::FrameworkType fmk_type); + ~ProtoTypePassRegistrar() = default; +}; +} // namespace ge + +#define REGISTER_PROTOTYPE_PASS(pass_name, pass, fmk_type) \ + REGISTER_PROTOTYPE_PASS_UNIQ_HELPER(__COUNTER__, pass_name, pass, fmk_type) + +#define REGISTER_PROTOTYPE_PASS_UNIQ_HELPER(ctr, pass_name, pass, fmk_type) \ + REGISTER_PROTOTYPE_PASS_UNIQ(ctr, pass_name, pass, fmk_type) + +#define REGISTER_PROTOTYPE_PASS_UNIQ(ctr, pass_name, pass, fmk_type) \ + static ::ge::ProtoTypePassRegistrar register_prototype_pass##ctr __attribute__((unused)) = \ + ::ge::ProtoTypePassRegistrar( \ + (pass_name), []()->::ge::ProtoTypeBasePass * { return new (std::nothrow) pass(); }, (fmk_type)) + +#endif // METADEF_PROTOTYPE_PASS_REGISTRY_H diff --git a/inc/metadef/register/register.h b/inc/metadef/register/register.h new file mode 100644 index 0000000000000000000000000000000000000000..edc01187c91335a9deeb3226b897c8b24cbab2e0 --- /dev/null +++ b/inc/metadef/register/register.h @@ -0,0 +1,53 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_REGISTER_REGISTRY_H_ +#define INC_REGISTER_REGISTRY_H_ + +#include "external/register/register.h" +#include "external/ge_common/ge_api_error_codes.h" +#include "graph/ge_error_codes.h" + +namespace ge { +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOp { + public: + HostCpuOp() = default; + HostCpuOp(HostCpuOp &&) = delete; + HostCpuOp &operator=(HostCpuOp &&) & = delete; + virtual ~HostCpuOp() = default; + virtual graphStatus Compute(Operator &op, + const std::map &inputs, + std::map &outputs) = 0; + + private: + HostCpuOp(const HostCpuOp &) = delete; + HostCpuOp &operator=(const HostCpuOp &) & = delete; +}; + +class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOpRegistrar { + public: + HostCpuOpRegistrar(const char_t *const op_type, HostCpuOp *(*const create_fn)()); + ~HostCpuOpRegistrar() = default; +}; +} // namespace ge + +#define REGISTER_HOST_CPU_OP_BUILDER(name, op) \ + REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(__COUNTER__, name, op) + +#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(ctr, name, op) \ + REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) + +#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) \ + static ::ge::HostCpuOpRegistrar register_host_cpu_op##ctr \ + __attribute__((unused)) = \ + ::ge::HostCpuOpRegistrar((name), []()->::ge::HostCpuOp* { \ + return new (std::nothrow) (op)(); \ + }) + +#endif // INC_REGISTER_REGISTRY_H_ diff --git a/inc/metadef/register/register_utils.h b/inc/metadef/register/register_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..e4865ef073c44678d9d72efdd8129d8c335edc12 --- /dev/null +++ b/inc/metadef/register/register_utils.h @@ -0,0 +1,22 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef REGISTER_REGISTER_UTILS_H +#define REGISTER_REGISTER_UTILS_H + +#include +#include "external/register/register_types.h" +#include "external/register/register_error_codes.h" +#include "external/register/register.h" +#include "external/graph/operator.h" + +namespace domi { +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OperatorAutoMapping(const Message *op_src, ge::Operator &op); +} // namespace domi +#endif // REGISTER_REGISTER_UTILS_H diff --git a/inc/metadef/register/scope/scope_graph_impl.h b/inc/metadef/register/scope/scope_graph_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..90fa5d86478a0e929e0ac5eef8622f944d42c142 --- /dev/null +++ b/inc/metadef/register/scope/scope_graph_impl.h @@ -0,0 +1,194 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef REGISTER_SCOPE_SCOPE_GRAPH_IMPL_H +#define REGISTER_SCOPE_SCOPE_GRAPH_IMPL_H + +#include "external/register/scope/scope_fusion_pass_register.h" +#include "external/graph/types.h" +#include "graph/operator_factory.h" +#include "proto/tensorflow/graph.pb.h" +#include "proto/tensorflow/node_def.pb.h" +#include "graph/utils/type_utils.h" + +namespace ge { +using FusionInnerNodesInfo = std::vector>, // inputs + std::vector>, // outputs + const ge::Operator *>>; // operator + +class Scope::ScopeImpl { + public: + ScopeImpl() = default; + Status Init(const std::string &name, const std::string &sub_type = "", Scope *const father_scope = nullptr); + ~ScopeImpl(); + + const std::string &Name() const { return name_; } + const std::string &SubType() const { return sub_type_; } + void SetSubType(const std::string &sub_type) { sub_type_ = sub_type; } + void ClearTypeAndSubType(); + void AddNode(ge::OperatorPtr &node_def); + const std::vector &Nodes() const { return nodes_; } + const std::unordered_map &AllNodesMap(); + void AddSubScope(Scope *const scope); + Scope *GetSubScope(const std::string &scope_name) const; + const std::unordered_map &GetSubScopes() const { return sub_scopes_; } + const std::vector &GetAllSubScopes(); + int32_t GetOpTypeNum(const std::string &op_type) const; + void OpsNumInc(const std::string &op_type); + const std::string LastName() const; + const Scope *GetFatherScope() const { return father_scope_; } + // trim scope_index + static std::string TrimScopeIndex(const std::string &scope_name); + + private: + std::string name_; + std::string sub_type_; + Scope *father_scope_ = nullptr; + std::map op_nums_; + std::unordered_map sub_scopes_; + std::vector nodes_; + std::unordered_map all_nodes_map_; + std::map all_nodes_map_new_; + std::vector all_sub_scopes_; +}; + +class FusionScopesResult::InnerNodeInfo::InnerNodeInfoImpl { + 
public: + explicit InnerNodeInfoImpl(const std::string &fusion_node_name) : fusion_node_name_(fusion_node_name) {} + InnerNodeInfoImpl(const std::string &fusion_node_name, const std::string &name, const std::string &type) + : fusion_node_name_(fusion_node_name), name_(name), type_(type) { + SetName(name); + } + ~InnerNodeInfoImpl() noexcept; + std::string GetFullNodeName(const std::string &relative_name); + void SetName(const std::string &name) { name_ = GetFullNodeName(name); } + void SetType(const std::string &type) { type_ = type; } + void InsertInput(const std::string &input_node, int32_t peer_out_idx); + void InsertOutput(const std::string &output_node, int32_t peer_in_idx); + ge::graphStatus BuildOperator(); + ge::graphStatus SetInputFormat(const std::string &input_name, const std::string &format) ; + ge::graphStatus SetOutputFormat(const std::string &output_name, const std::string &format); + ge::graphStatus SetDynamicInputFormat(const std::string &input_name, const uint32_t index, const std::string &format); + ge::graphStatus SetDynamicOutputFormat(const std::string &output_name, + const uint32_t index, + const std::string &format); + std::string GetName() const { return name_; } + std::string GetType() const { return type_; } + std::vector> GetInputs() const { return inner_node_inputs_; } + std::vector> GetOutputs() const { return inner_node_outputs_; } + ge::Operator *MutableOperator() { return &operator_; } + + private: + ge::Operator operator_; + std::string fusion_node_name_; + std::string name_; + std::string type_; + std::vector> inner_node_inputs_; + std::vector> inner_node_outputs_; +}; + +class FusionScopesResult::FusionScopesResultImpl { + public: + FusionScopesResultImpl() {} + ~FusionScopesResultImpl() = default; + void SetName(const std::string &name) { name_ = name; } + void SetType(const std::string &type) { type_ = type; } + void SetDescription(const std::string &description) { description_ = description; } + const std::string &Name() const 
{ return name_; } + const std::string &Type() const { return type_; } + const std::string &Description() const { return description_; } + void AddNodes(const std::vector &nodes); + const std::vector &Nodes() const { return nodes_; } + void AddScopes(const std::vector &scopes) { + (void)scopes_.insert(scopes_.cend(), scopes.cbegin(), scopes.cend()); + } + const std::vector &Scopes() const { return scopes_; } + const std::map> &GetInputs() const { return inputs_; } + const std::map> &GetOutputs() const { return outputs_; } + void InsertInputs(const std::string &inner_op_name, const std::vector &index_map); + void InsertOutputs(const std::string &inner_op_name, const std::vector &index_map); + bool FindNodes(const std::string &node_name) const; + bool FindScopes(const std::string &scope_name) const; + + InnerNodeInfo *AddInnerNode(const std::string &name, const std::string &type); + InnerNodeInfo *MutableRecentInnerNode(); + InnerNodeInfo *MutableInnerNode(uint32_t index); + FusionInnerNodesInfo GetInnerNodesInfo(); + ge::graphStatus CheckInnerNodesInfo(); + + private: + std::string name_; + std::string type_; + std::string description_; + std::vector scopes_; + std::vector nodes_; + std::map> inputs_; + std::map> outputs_; + std::vector inner_node_infos_; +}; + +class ScopeTree::ScopeTreeImpl { + public: + ScopeTreeImpl() = default; + ScopeTreeImpl(const ScopeTreeImpl &) = delete; + ScopeTreeImpl &operator=(const ScopeTreeImpl &) & = delete; + Status Init(); + ~ScopeTreeImpl(); + + void AddNodeToScope(ge::OperatorPtr &node_def); + const std::vector &GetAllScopes() const { return scopes_; } + const Scope *Root() const { return root_; } + + private: + std::vector SplitNodeName(const std::string &node_name, const char_t delim) const; + Scope *root_ = nullptr; + std::vector scopes_; +}; + +struct ScopeFusionOpInfo { + std::string node_name; + std::string fusion_node_name; + std::string fusion_op_type; + std::string description; + bool scope_pass = true; +}; + +class 
ScopeGraph::ScopeGraphImpl { + public: + ScopeGraphImpl() = default; + ScopeGraphImpl(const ScopeGraphImpl &) = delete; + ScopeGraphImpl &operator=(const ScopeGraphImpl &) & = delete; + Status Init(); + ~ScopeGraphImpl() noexcept; + + const ScopeTree *GetScopeTree() const { return scope_tree_; } + void BuildScopeGraph(domi::tensorflow::GraphDef *graph_def); + void AddFusionScopesResult(FusionScopesResult *result); + const std::unordered_map &FusionScopesResults() const { return fusion_results_; } + FusionScopesResult *GetFusionScopesResults(const domi::tensorflow::NodeDef *const node_def) const; + FusionScopesResult *GetFusionScopesResults(const std::string &node_name) const; + const std::unordered_map &GetNodesMap() const { return nodes_map_; } + const std::map &GetNodesMapNew() const { return nodes_map_new_; } + bool IsFusionOpChild(const std::string &node_name, std::vector &info_list); + bool FusionOpChildIgnore(const ScopeFusionOpInfo &info); + bool IsFusionOp(const domi::tensorflow::NodeDef *const node_def); + Status GetInputOrOutputIndex(const ScopeFusionOpInfo &info, const int32_t old_index, + const bool input, int32_t &new_index); + + private: + std::vector GetFusionResultInputOrOutput(const ScopeFusionOpInfo &info, + const bool input); // input:true,output:false + std::unordered_map fusion_results_; + std::unordered_map nodes_map_; + std::map nodes_map_new_; + ScopeTree *scope_tree_ = nullptr; +}; +} // namespace ge +#endif // REGISTER_SCOPE_SCOPE_GRAPH_IMPL_H diff --git a/inc/metadef/register/scope/scope_pass_impl.h b/inc/metadef/register/scope/scope_pass_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..5b0654b3234ec4660c9299b7eed9758976e566f2 --- /dev/null +++ b/inc/metadef/register/scope/scope_pass_impl.h @@ -0,0 +1,54 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). 
+ * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef REGISTER_SCOPE_SCOPE_PASS_IMPL_H +#define REGISTER_SCOPE_SCOPE_PASS_IMPL_H + +#include "external/register/scope/scope_fusion_pass_register.h" + +namespace ge { +class ScopesResult::ScopesResultImpl { + public: + void SetScopes(const std::vector &scopes) { scopes_ = scopes; } + const std::vector &GetScopes() const { return scopes_; } + void SetNodes(const std::vector &nodes) { nodes_ = nodes; } + const std::vector &GetNodes() const { return nodes_; } + + private: + std::vector scopes_; // multiple scopes + std::vector nodes_; // op outside of scope +}; + +class ScopeBasePass::ScopeBasePassImpl { + public: + explicit ScopeBasePassImpl(ScopeBasePass *const parent) : parent_(parent) {} + virtual ~ScopeBasePassImpl(); + + Status Run(std::shared_ptr &scope_graph); + + private: + Status AddFusionScopesResultToScopeGraph(const std::shared_ptr &scope_graph, + std::vector &scope_results) const; + // Match rules one by one, support multiple sets of matching rules, and finally output a single scope + // Note: This function does not have to be rewritten. + // In order to match the fusion rules designed by you better, + // you can implement your specific versions separately. 
+ bool MatchAllBatches(const ScopeTree *scope_tree, std::vector &results); + + bool MatchOneBatch(const ScopeTree *const scope_tree, const std::vector &patternlist, + std::vector &results) const; + bool MatchOneScope(const ScopePattern *pattern, Scope *scope, std::vector &results) const; + Status PrintFusionScopeInfo(std::shared_ptr &scope_graph) const; + + private: + std::vector patterns_; + ScopeBasePass *parent_; +}; +} // namespace ge +#endif // REGISTER_SCOPE_SCOPE_PASS_IMPL_H diff --git a/inc/metadef/register/scope/scope_pass_registry_impl.h b/inc/metadef/register/scope/scope_pass_registry_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..46636a6869adb5dbd883a605b2cd4fee183d7f02 --- /dev/null +++ b/inc/metadef/register/scope/scope_pass_registry_impl.h @@ -0,0 +1,33 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef REGISTER_SCOPE_SCOPE_PASS_REGISTRY_IMPL_H +#define REGISTER_SCOPE_SCOPE_PASS_REGISTRY_IMPL_H + +#include +#include "external/register/scope/scope_fusion_pass_register.h" + +namespace ge { +struct CreatePassFnPack; +class ScopeFusionPassRegistry::ScopeFusionPassRegistryImpl { + public: + void RegisterScopeFusionPass(const std::string &pass_name, ScopeFusionPassRegistry::CreateFn create_fn, + bool is_general); + ScopeFusionPassRegistry::CreateFn GetCreateFn(const std::string &pass_name); + std::unique_ptr CreateScopeFusionPass(const std::string &pass_name); + std::vector GetAllRegisteredPasses(); + bool SetPassEnableFlag(const std::string pass_name, const bool flag); + + private: + std::mutex mu_; + std::vector pass_names_; // In the order of user registration + std::map create_fn_packs_; +}; +} // namespace ge +#endif // REGISTER_SCOPE_SCOPE_PASS_REGISTRY_IMPL_H diff --git a/inc/metadef/register/scope/scope_pattern_impl.h b/inc/metadef/register/scope/scope_pattern_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..70ad7fbeddbd5d94678b17f497e194d562ff6559 --- /dev/null +++ b/inc/metadef/register/scope/scope_pattern_impl.h @@ -0,0 +1,126 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef REGISTER_SCOPE_SCOPE_PATTERN_IMPL_H +#define REGISTER_SCOPE_SCOPE_PATTERN_IMPL_H + +#include +#include +#include +#include "external/register/scope/scope_fusion_pass_register.h" +#include "external/graph/types.h" +#include "graph/compute_graph.h" + +namespace ge { +constexpr float32_t kCompareRatio = 2.0F; +class ScopeAttrValue::ScopeAttrValueImpl { + public: + ScopeAttrValueImpl() = default; + ~ScopeAttrValueImpl() = default; + + void SetIntValue(const int64_t value) { int_value_ = value; } + void SetFloatValue(const float32_t value) { float_value_ = value; } + void SetStringValue(const std::string &value) { string_value_ = value; } + void SetBoolValue(const bool value) { bool_value_ = value; } + const int64_t &GetIntValue() const { return int_value_; } + const float32_t &GetFloatValue() const { return float_value_; } + const std::string &GetStrValue() const { return string_value_; } + const bool &GetBoolValue() const { return bool_value_; } + + private: + int64_t int_value_ = 0; + float32_t float_value_ = 0.0F; + std::string string_value_ = ""; + bool bool_value_ = false; +}; + + +class NodeOpTypeFeature::NodeOpTypeFeatureImpl : ScopeBaseFeature { + public: + NodeOpTypeFeatureImpl(const std::string &nodeType, const int64_t num, const int64_t step) + : ScopeBaseFeature(), node_type_(nodeType), num_(num), step_(step) {} + ~NodeOpTypeFeatureImpl() override = default; + bool Match(const Scope *const scope) override; + +private: + std::string node_type_; // Node type + int64_t num_; // Node number + int64_t step_; // step + friend class NodeOpTypeFeature; +}; + +class NodeAttrFeature::NodeAttrFeatureImpl : ScopeBaseFeature { + public: + NodeAttrFeatureImpl(const std::string &nodeType, const std::string &attr_name, const ge::DataType datatype, + const ScopeAttrValue &attr_value) + : ScopeBaseFeature(), node_type_(nodeType), 
attr_name_(attr_name), datatype_(datatype), + attr_value_(attr_value) {} + ~NodeAttrFeatureImpl() override = default; + bool Match(const Scope *scope) override; + Status CheckNodeAttrFeatureData(const bool init_value, const ge::OpDescPtr &op_desc, const Scope *const scope); + Status CheckNodeAttrFeatureData(const std::string &init_value, + const ge::OpDescPtr &op_desc, const Scope *const scope); + Status CheckNodeAttrFeatureData(const int64_t init_value, const ge::OpDescPtr &op_desc, const Scope *const scope); + Status CheckNodeAttrFeatureData(const float32_t init_value, const ge::OpDescPtr &op_desc, const Scope *const scope); + template + typename std::enable_if::is_integer, bool>::type FloatIsEqual(const T x, const T y) const + { + // It is used for floating point comparisons. + // It mainly uses relative precision to judge whether floating-point numbers are equal. + // the 2 is ULPs + return (std::fabs(x - y) <= (std::numeric_limits::epsilon() * std::fabs(x + y) * kCompareRatio)) || + (std::fabs(x - y) < std::numeric_limits::min()); + } + +private: + std::string node_type_; // Node type + std::string attr_name_; // attribute name + ge::DataType datatype_; // datatype + ScopeAttrValue attr_value_; // AttrValue + friend class NodeAttrFeature; +}; + +class ScopeFeature::ScopeFeatureImpl : ScopeBaseFeature { + public: + ScopeFeatureImpl(const std::string &sub_type, const int32_t num, const std::string &suffix, + const std::string &sub_scope_mask, const int64_t step) + : ScopeBaseFeature(), sub_type_(sub_type), num_(num), suffix_(suffix), sub_scope_mask_(sub_scope_mask), + step_(step) {} + ~ScopeFeatureImpl() override = default; + bool Match(const Scope *const scope) override; + bool SubScopesMatch(const std::vector &scopes); + + private: + std::string sub_type_; + int32_t num_; + std::string suffix_; + std::string sub_scope_mask_; + int64_t step_; + friend class ScopeFeature; +}; + +class ScopePattern::ScopePatternImpl { + public: + ScopePatternImpl() {} + 
~ScopePatternImpl() = default; + bool Match(const Scope *scope) const; + void SetSubType(const std::string &sub_type); + const std::string &SubType() const { return sub_type_; } + void AddNodeOpTypeFeature(NodeOpTypeFeature &feature); + void AddNodeAttrFeature(NodeAttrFeature &feature); + void AddScopeFeature(ScopeFeature &feature); + + private: + std::string sub_type_; // get Scope sub type + std::vector node_optype_features_; + std::vector node_attr_features_; + std::vector scopes_features_; +}; +} // namespace ge +#endif // REGISTER_SCOPE_SCOPE_PATTERN_IMPL_H diff --git a/inc/metadef/register/shape_inference.h b/inc/metadef/register/shape_inference.h new file mode 100644 index 0000000000000000000000000000000000000000..01d7be699127a1886027669d649d38f60aabbe70 --- /dev/null +++ b/inc/metadef/register/shape_inference.h @@ -0,0 +1,19 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef METADEF_CXX_INC_EXE_GRAPH_RUNTIME_SHAPE_INFERENCE_H_ +#define METADEF_CXX_INC_EXE_GRAPH_RUNTIME_SHAPE_INFERENCE_H_ + +#include "graph/op_desc.h" +namespace gert { +extern ge::graphStatus InferDataTypeOnCompile(const ge::OpDescPtr &op_desc); +extern ge::graphStatus InferShapeRangeOnCompile(const ge::Operator &op, const ge::OpDescPtr &op_desc); +extern ge::graphStatus InferShapeOnCompile(const ge::Operator &op, const ge::OpDescPtr &op_desc); +} // namespace gert +#endif // METADEF_CXX_INC_EXE_GRAPH_RUNTIME_SHAPE_INFERENCE_H_ diff --git a/inc/metadef/register/stream_manage_func_registry.h b/inc/metadef/register/stream_manage_func_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..b053d633f8992c541ea83357040dd7ec4ac516f6 --- /dev/null +++ b/inc/metadef/register/stream_manage_func_registry.h @@ -0,0 +1,73 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef INC_REGISTER_STREAM_MANAGE_FUNC_REGISTRY_H +#define INC_REGISTER_STREAM_MANAGE_FUNC_REGISTRY_H + +#include +#include +#include "runtime/rt.h" +#include "common/ge_common/debug/ge_log.h" + +namespace ge { +// acl action types +enum class MngActionType : uint32_t { + DESTROY_STREAM, + DESTROY_CONTEXT, + RESET_DEVICE, +}; + +typedef union { + rtStream_t stream; + rtContext_t context; + int32_t device_id; +} MngResourceHandle; + +enum class StreamMngFuncType : uint32_t { + ACLNN_STREAM_CALLBACK, // aclnn callback function for destroying sub-stream +}; + +using StreamMngFunc = uint32_t (*)(MngActionType action_type, MngResourceHandle handle); + +class StreamMngFuncRegistry { + public: + static StreamMngFuncRegistry &GetInstance(); + Status TryCallStreamMngFunc(const StreamMngFuncType func_type, MngActionType action_type, MngResourceHandle handle); + void Register(const StreamMngFuncType func_type, StreamMngFunc const manage_func); + StreamMngFunc LookUpStreamMngFunc(const StreamMngFuncType func_type); + + StreamMngFuncRegistry(const StreamMngFuncRegistry &other) = delete; + StreamMngFuncRegistry &operator=(const StreamMngFuncRegistry &other) = delete; + + private: + StreamMngFuncRegistry() = default; + ~StreamMngFuncRegistry() = default; + + std::mutex mutex_; + std::map type_to_func_; +}; + +class StreamMngFuncRegister { + public: + StreamMngFuncRegister(const StreamMngFuncType func_type, StreamMngFunc const manage_func); +}; +} // namespace ge + +#ifdef __GNUC__ +#define ATTRIBUTE_USED __attribute__((used)) +#else +#define ATTRIBUTE_USED +#endif +#define REG_STREAM_MNG_FUNC(type, func) REG_STREAM_MNG_FUNC_UNIQ_HELPER(type, func, __COUNTER__) +#define REG_STREAM_MNG_FUNC_UNIQ_HELPER(type, func, counter) REG_STREAM_MNG_FUNC_UNIQ(type, func, counter) +#define REG_STREAM_MNG_FUNC_UNIQ(type, func, counter) \ + static 
::ge::StreamMngFuncRegister register_stream_mng_func_##counter ATTRIBUTE_USED = \ + ge::StreamMngFuncRegister(type, func) + +#endif // INC_REGISTER_STREAM_MANAGE_FUNC_REGISTRY_H diff --git a/inc/metadef/register/tensor_assign.h b/inc/metadef/register/tensor_assign.h new file mode 100644 index 0000000000000000000000000000000000000000..d3e08360f6ef7496e473e107b72a7ae9c583acbe --- /dev/null +++ b/inc/metadef/register/tensor_assign.h @@ -0,0 +1,133 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef TENSOR_ASSIGN_H +#define TENSOR_ASSIGN_H + +#include +#include "graph/ge_tensor.h" +#include "graph/def_types.h" +#include "common/checker.h" +#include "common/ge_common/debug/ge_log.h" +#include "external/register/register_error_codes.h" +#include "external/utils/extern_math_util.h" +#include "proto/tensorflow/tensor.pb.h" + +namespace domi { +using GeTensorPtr = std::shared_ptr; +using Status = uint32_t; +constexpr int64_t kComplexWidth = 2; + +class TensorAssign { + public: + static Status SetGeTensor(const domi::tensorflow::TensorProto &tensor, GeTensorPtr &weight); + + static Status SetGeTensorDataType(const int64_t data_type, GeTensorPtr &weight); + + static ge::DataType ConvertTensorflowDataType(const uint32_t tf_data_type); + + private: + static bool CheckBoolVal(const tensorflow::DataType data_type); + + static bool CheckHalfVal(const tensorflow::DataType data_type); + + static bool CheckFloatVal(const tensorflow::DataType data_type); + + static bool CheckDoubleVal(const tensorflow::DataType data_type); + + static bool CheckComplex32Val(const tensorflow::DataType data_type); + + static bool CheckComplex64Val(const tensorflow::DataType data_type); + + static bool CheckComplex128Val(const tensorflow::DataType data_type); + + static bool CheckStringVal(const tensorflow::DataType data_type); + + static bool CheckByte(const tensorflow::DataType data_type); + + static bool CheckDoubleByte(const tensorflow::DataType data_type); + + static bool CheckSignedFourByte(const tensorflow::DataType data_type); + + static bool CheckUnsignedFourByte(const tensorflow::DataType data_type); + + static bool CheckSignedEightByte(const tensorflow::DataType data_type); + + static bool CheckUnsignedEightByte(const tensorflow::DataType data_type); + + static Status GetDoubleByteVal(const int64_t val_size, + const google::protobuf::RepeatedField 
&val_vector, + const int64_t count, GeTensorPtr &weight); + static Status GetByteVal(const int64_t val_size, + const google::protobuf::RepeatedField &val_vector, + const int64_t count, GeTensorPtr &weight); + + static Status GetStringVal(const int64_t val_size, const google::protobuf::RepeatedPtrField &val_vector, + const int64_t count, GeTensorPtr &weight); + + static void SetGeTensorWeightData(const domi::tensorflow::TensorProto &tensor, const int64_t val_size, + const int64_t count, GeTensorPtr &weight); + + static void SetWeightData(const tensorflow::DataType data_type, const int64_t count, + const std::string &tensor_content, GeTensorPtr &weight); + + template + static Status GetVal(const int64_t val_size, const google::protobuf::RepeatedField &val_vector, + const int64_t count, GeTensorPtr &weight, const bool is_complex = false) { + // val_size must be even, and complex value should be an integer multiple of 2 + if (is_complex && ((val_size % kComplexWidth) != 0)) { + GELOGE(FAILED, "complex value should be an integer multiple of 2."); + return FAILED; + } + const std::unique_ptr addr(new (std::nothrow) T[count]()); // Zero init default value + GE_CHECK_NOTNULL(addr); + if (val_size == 0) { + (void)weight->SetData(ge::PtrToPtr(addr.get()), static_cast(count) * sizeof(T)); + return SUCCESS; + } + // Complex numbers are made up of real and imaginary numbers + const bool zerosLike = ((count != val_size) && ((val_size == 1) || (is_complex && (val_size == 2)))); + if ((!zerosLike) && (val_size <= count)) { + for (size_t i = 0UL; i < static_cast(val_size); i++) { + addr[i] = val_vector.Get(static_cast(i)); + } + const int64_t value_r = val_size - 1; + GE_ASSERT_EQ(ge::IntegerChecker::Compat(value_r), true); + if (is_complex) { + // val_vector format is real value, complex value..., here is getting the corresponding value. + // real value and complex value are stored spaced apart, so use 2 and 1 to store in the correct addr. 
+ const int64_t value_l = val_size - kComplexWidth; + GE_ASSERT_EQ(ge::IntegerChecker::Compat(value_l), true); + for (int64_t i = val_size; i < count; i += kComplexWidth) { + addr[static_cast(i)] = val_vector.Get(static_cast(value_l)); + addr[static_cast(i) + 1UL] = val_vector.Get(static_cast(value_r)); + } + } else { + for (int64_t i = val_size; i < count; i++) { + addr[static_cast(i)] = val_vector.Get(static_cast(value_r)); + } + } + } else { + if (is_complex) { + for (int64_t i = 0; i < count; i += kComplexWidth) { + addr[static_cast(i)] = val_vector.Get(0); + addr[static_cast(i) + 1UL] = val_vector.Get(1); + } + } else { + for (int64_t i = 0; i < count; i++) { + addr[static_cast(i)] = val_vector.Get(0); + } + } + } + (void)weight->SetData(ge::PtrToPtr(addr.get()), static_cast(count) * sizeof(T)); + return SUCCESS; + } +}; +} // namespace domi +#endif // TENSOR_ASSIGN_H diff --git a/inc/metadef/transformer/axis_util.h b/inc/metadef/transformer/axis_util.h new file mode 100644 index 0000000000000000000000000000000000000000..433b1a4c0eb5f873f0d38a12a69d2800448f98e6 --- /dev/null +++ b/inc/metadef/transformer/axis_util.h @@ -0,0 +1,103 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef COMMON_UTILS_TRANSFORMER_INC_AXIS_UTIL_H_ +#define COMMON_UTILS_TRANSFORMER_INC_AXIS_UTIL_H_ + +#include +#include + +#include "external/graph/types.h" +#include "graph/utils/math_util.h" +#include "exe_graph/runtime/shape.h" + +namespace transformer { +#define CHECK(cond, log_func, return_expr) \ + do { \ + if (cond) { \ + log_func; \ + return_expr; \ + } \ + } while (0) + +#define INT64_ZEROCHECK(a) \ + if (a == 0) { \ + return false; \ + } + +#define MUL_OVERFLOW(x, y, z) \ + if (ge::MulOverflow((x), (y), (z))) { \ + return false; \ + } \ + +enum AxisValueType { + AXIS_N = 0, + AXIS_C = 1, + AXIS_H = 2, + AXIS_W = 3, + AXIS_C1 = 4, + AXIS_C0 = 5, + AXIS_Co = 6, + AXIS_D = 7, + AXIS_G = 8, + AXIS_M0 = 9, + AXIS_INPUT_SIZE = 10, + AXIS_HIDEEN_SIZE = 11, + AXIS_STATE_SIZE = 12, + AXIS_BOTTOM = 13 +}; + +using AxisValue = std::array(AXIS_BOTTOM)>; + +inline int64_t DivisionCeiling(int64_t dividend, int64_t divisor) { + if (divisor == 0) { + return 0; + } else if (dividend < 0) { + return -1; + } else { + return (dividend + divisor - 1) / divisor; + } +} + +class AxisUtil { + public: + AxisUtil() {}; + ~AxisUtil() {}; + static bool GetAxisValueByOriginFormat(const ge::Format &format, const gert::Shape &shape, AxisValue &axis_value); + static int32_t GetAxisIndexByFormat(const ge::Format &format, const string &axis); + static int32_t GetAxisIndexByFormat(const ge::Format &format, const string &axis, + const std::map &valid_axis_map); + static std::vector GetAxisVecByFormat(const ge::Format &format); + static std::vector GetReshapeTypeAxisVec(const ge::Format &format, const int64_t &reshape_type_mask); + static std::map GetReshapeTypeAxisMap(const ge::Format &format, + const int64_t &reshape_type_mask); + static std::vector GetSplitOrConcatAxisByFormat(const ge::Format &format, const std::string &axis); +private: + static bool 
GetAxisValueByNCHW(const gert::Shape &shape, AxisValue &axis_value); + + static bool GetAxisValueByNHWC(const gert::Shape &shape, AxisValue &axis_value); + + static bool GetAxisValueByHWCN(const gert::Shape &shape, AxisValue &axis_value); + + static bool GetAxisValueByND(const gert::Shape &shape, AxisValue &axis_value); + + static bool GetAxisValueByNDHWC(const gert::Shape &shape, AxisValue &axis_value); + + static bool GetAxisValueByNCDHW(const gert::Shape &shape, AxisValue &axis_value); + + static bool GetAxisValueByDHWCN(const gert::Shape &shape, AxisValue &axis_value); + + static bool GetAxisValueByDHWNC(const gert::Shape &shape, AxisValue &axis_value); + + static bool GetAxisValueByNC1HWC0(const gert::Shape &shape, AxisValue &axis_value); + + static bool GetAxisValueByC1HWNCoC0(const gert::Shape &shape, AxisValue &axis_value); +}; +} // namespace transformer +#endif // COMMON_UTILS_TRANSFORMER_INC_AXIS_UTIL_H_ diff --git a/inc/metadef/transformer/expand_dimension.h b/inc/metadef/transformer/expand_dimension.h new file mode 100644 index 0000000000000000000000000000000000000000..8b0b3b3e8226ce5dbae14a09c760c8a85496330a --- /dev/null +++ b/inc/metadef/transformer/expand_dimension.h @@ -0,0 +1,60 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef COMMON_UTILS_TRANSFORMER_INC_EXPAND_DIMENSION_H_ +#define COMMON_UTILS_TRANSFORMER_INC_EXPAND_DIMENSION_H_ + +#include +#include +#include "graph/types.h" +#include "graph/ge_tensor.h" +#include "exe_graph/runtime/shape.h" + +namespace transformer { +/* Pad dimension according to reshape type */ +bool ExpandDimension(const std::string &op_type, const ge::Format &original_format, const ge::Format &final_format, + const uint32_t &tensor_index, const std::string &reshape_type, ge::GeShape &shape); + +bool ExpandRangeDimension(const std::string &op_type, const ge::Format &original_format, const ge::Format &final_format, + const uint32_t &tensor_index, const std::string &reshape_type, + std::vector> &ranges); + +class ExpandDimension { +public: + ExpandDimension(); + ~ExpandDimension(); + + static int64_t GenerateReshapeType(const ge::Format &origin_format, const ge::Format &format, + const size_t &origin_dim_size, const std::string &reshape_type); + static bool GenerateReshapeType(const ge::Format &origin_format, const ge::Format &format, + const size_t &origin_dim_size, const std::string &reshape_type, + int64_t &reshape_type_mask); + static bool GenerateReshapeTypeByMask(const ge::Format &origin_format, const size_t &origin_dim_size, + const int64_t &reshape_type_mask, std::string &reshape_type, + std::string &failed_reason); + static void ExpandDims(const int64_t &reshape_type, ge::GeShape &shape); + static void ExpandDims(const int64_t &reshape_type, const ge::GeShape &origin_shape, ge::GeShape &shape); + static void ExpandDims(const int64_t &reshape_type, gert::Shape &shape); + static void ExpandDims(const int64_t &reshape_type, const gert::Shape &origin_shape, gert::Shape &shape); + static bool GetDefaultReshapeType(const ge::Format &origin_format, const size_t &origin_dim_size, + std::string &reshape_type); + static int32_t 
GetAxisIndexByName(char ch, const ge::Format &format); + static int64_t GetReshapeAxicValue(const int64_t &reshape_type_mask, + const ge::GeShape &shape, int32_t axis_index); + static int64_t GetReshapeAxicValueByName(const int64_t &reshape_type_mask, char ch, + const ge::GeShape &shape, const ge::Format &format); + static bool GetFormatFullSize(const ge::Format &format, size_t &full_size); +private: + static bool IsNeedExpand(const ge::Format &origin_format, const ge::Format &format, + const size_t &origin_dim_size, const size_t &full_size, const std::string &reshape_type); + static bool IsReshapeTypeValid(const ge::Format &origin_format, const size_t &origin_dim_size, + const std::string &reshape_type); +}; +} // namespace transformer +#endif //COMMON_UTILS_TRANSFORMER_INC_EXPAND_DIMENSION_H_ diff --git a/inc/metadef/transformer/transfer_range_according_to_format.h b/inc/metadef/transformer/transfer_range_according_to_format.h new file mode 100644 index 0000000000000000000000000000000000000000..10ae453ab99eba381990e3c998631fd2c50db21a --- /dev/null +++ b/inc/metadef/transformer/transfer_range_according_to_format.h @@ -0,0 +1,55 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef COMMON_UTILS_TRANSFORMER_INC_TRANSFER_RANGE_ACCORDING_TO_FORMAT_H_ +#define COMMON_UTILS_TRANSFORMER_INC_TRANSFER_RANGE_ACCORDING_TO_FORMAT_H_ + +#include +#include "transfer_shape_according_to_format.h" + +namespace transformer { +struct RangeAndFormatInfo { + ge::GeShape old_shape; + std::vector> old_range; + std::vector> &new_range; + ge::Format old_format; + ge::Format new_format; + ge::DataType current_data_type; + CalcShapeExtraAttr extra_attr; + RangeAndFormatInfo(ge::GeShape old_shape, std::vector> old_range, + std::vector> &new_range, ge::Format old_format, + ge::Format new_format, ge::DataType current_data_type) : + old_shape(old_shape), old_range(old_range), new_range(new_range), old_format(old_format), + new_format(new_format), current_data_type(current_data_type), extra_attr(CalcShapeExtraAttr()) {} + RangeAndFormatInfo(ge::GeShape old_shape, std::vector> old_range, + std::vector> &new_range, ge::Format old_format, + ge::Format new_format, ge::DataType current_data_type, CalcShapeExtraAttr extra_attr) : + old_shape(old_shape), old_range(old_range), new_range(new_range), old_format(old_format), + new_format(new_format), current_data_type(current_data_type), extra_attr(extra_attr) {} +}; + +using RangeAndFormat = struct RangeAndFormatInfo; + +class RangeTransferAccordingToFormat { + public: + RangeTransferAccordingToFormat() = default; + + ~RangeTransferAccordingToFormat() = default; + + RangeTransferAccordingToFormat(const RangeTransferAccordingToFormat &) = delete; + + RangeTransferAccordingToFormat &operator=(const RangeTransferAccordingToFormat &) = delete; + + static bool GetRangeAccordingToFormat(RangeAndFormat &range_and_format_info); + + static bool GetRangeAccordingToFormat(const ge::OpDescPtr &op_desc, RangeAndFormat &range_and_format_info); +}; +} // namespace transformer + +#endif // 
COMMON_UTILS_TRANSFORMER_INC_TRANSFER_RANGE_ACCORDING_TO_FORMAT_H_ diff --git a/inc/metadef/transformer/transfer_shape_according_to_format.h b/inc/metadef/transformer/transfer_shape_according_to_format.h new file mode 100644 index 0000000000000000000000000000000000000000..842ee2f5cca09b10759b349d6ec8ea2d5fe593a1 --- /dev/null +++ b/inc/metadef/transformer/transfer_shape_according_to_format.h @@ -0,0 +1,69 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef COMMON_UTILS_TRANSFORMER_INC_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_ +#define COMMON_UTILS_TRANSFORMER_INC_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_ + +#include +#include "graph/types.h" +#include "graph/ge_tensor.h" +#include "graph/op_desc.h" +#include "transfer_shape_utils.h" + +namespace transformer { +struct CalcShapeExtraAttr { + int64_t hidden_size; + int64_t input_size; + int64_t state_size; +}; + +struct ShapeAndFormatInfo { + ge::GeShape &oldShape; + const ge::Format &oldFormat; + const ge::Format &newFormat; + const ge::DataType &currentDataType; + CalcShapeExtraAttr extra_attr; + ShapeAndFormatInfo(ge::GeShape &old_shape, const ge::Format &old_format, const ge::Format &new_format, + const ge::DataType &data_type) + : oldShape(old_shape), oldFormat(old_format), newFormat(new_format), currentDataType(data_type), + extra_attr({1, 1, -1}) {} +}; + +using ShapeAndFormat = struct ShapeAndFormatInfo; + +class ShapeTransferAccordingToFormat { + public: + ShapeTransferAccordingToFormat(); + + ~ShapeTransferAccordingToFormat() {}; + + ShapeTransferAccordingToFormat(const ShapeTransferAccordingToFormat&) = delete; + + ShapeTransferAccordingToFormat &operator=(const ShapeTransferAccordingToFormat&) = delete; + + static bool GetShapeAccordingToFormat(ShapeAndFormat &shapeAndFormatInfo); + + static bool GetShapeAccordingToFormat(const ge::OpDescPtr &op_desc, ShapeAndFormat &shapeAndFormatInfo); + + static bool TransferShape(const ge::Format &origin_format, const ge::Format &format, const ge::DataType &data_type, + const ExtAxisValue &ext_axis, ge::GeShape &shape); + + static bool TransferShape(const ge::Format &origin_format, const ge::Format &format, const ge::DataType &data_type, + const ExtAxisValue &ext_axis, const ge::GeShape &origin_shape, ge::GeShape &shape); + + static bool TransferShape(const ge::Format &origin_format, const ge::Format 
&format, const ge::DataType &data_type, + gert::Shape &shape, const ge::OpDescPtr op_desc = nullptr, const fe::PlatFormInfos *platform_infos_ptr = nullptr); + + static bool TransferShape(const ge::Format &origin_format, const ge::Format &format, const ge::DataType &data_type, + const gert::Shape &origin_shape, gert::Shape &shape, const ge::OpDescPtr op_desc = nullptr); + + static void InitExtAxisValue(const ge::OpDescPtr &op_desc, ExtAxisValue &ext_axis); +}; +} // namespace transformer +#endif // COMMON_UTILS_TRANSFORMER_INC_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_ diff --git a/inc/metadef/transformer/transfer_shape_utils.h b/inc/metadef/transformer/transfer_shape_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..d4e1630890e58aa5f06e012f02a62c11856af678 --- /dev/null +++ b/inc/metadef/transformer/transfer_shape_utils.h @@ -0,0 +1,222 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef TRANSFORMER_INC_TRANSFER_SHAPE_UTILS_H_ +#define TRANSFORMER_INC_TRANSFER_SHAPE_UTILS_H_ + +#include +#include "axis_util.h" +#include "external/platform/platform_info.h" + + +namespace transformer { +namespace { +const size_t ORIGIN_FORMAT_DIM_SIZE = 5; +const size_t EXT_AXIS_SIZE = 4; +} + +enum class TransferShapeType { + ND_TO_ND = 0, + ND_TO_NZ, + FULL_SIZE, + NOT_FULL_SIZE, + INVALID +}; + +struct AlignShapeInfo { + ge::Format src_format; + ge::Format dst_format; + gert::Shape src_shape; + ge::DataType data_type; + int64_t reshape_type_mask; +}; + +struct TransferDimsInfo { + ge::Format src_format; + ge::Format dst_format; + gert::Shape src_shape; + int64_t reshape_type_mask; +}; + +struct AxisIndexMapping { + std::vector> src_to_dst_transfer_dims; + std::vector> dst_to_src_transfer_dims; +}; + +using FormatIndex = std::array; +using ExtAxisValue = std::array; +using GetAlignedShapeFunc = std::function; +using TransferDimsFunc = std::function; + +class TransferShapeUtils { +public: + TransferShapeUtils() {} + ~TransferShapeUtils() {} + static bool InitPlatformInfo(const fe::PlatFormInfos *platform_infos_ptr = nullptr); + static bool TransferShape(const ge::Format &origin_format, const ge::Format &format, const ge::DataType &data_type, + const ExtAxisValue &ext_axis, gert::Shape &shape, const fe::PlatFormInfos *platforminfos = nullptr); + static bool TransferShape(const ge::Format &origin_format, const ge::Format &format, const ge::DataType &data_type, + const ExtAxisValue &ext_axis, const gert::Shape &origin_shape, gert::Shape &shape, const fe::PlatFormInfos *platforminfos = nullptr); + static int64_t GetC0ByDtype(const ge::DataType &data_type); + static int64_t GetM0ByDtype(const ge::DataType &data_type); + static int64_t GetN0ByDtype(const ge::DataType &data_type); + static bool GetAlignedShape(const AlignShapeInfo 
&align_shape_info, gert::Shape &aligned_shape); + static bool TransferDims(const TransferDimsInfo &transfer_dims_info, AxisIndexMapping &axis_index_mapping); + +private: + static bool InitM0K0CO(const fe::PlatFormInfos *platform_infos); + static bool TransferShapeByFormat(const ge::Format &primary_format, const AxisValue &axis_value, + gert::Shape &shape); + static bool TransferShapeByAxisValue(const ge::Format &primary_format, const AxisValue &axis_value, + gert::Shape &shape); + static bool TransferShapeByOriginShape(const ge::Format &primary_format, const int64_t &c0, const int64_t &m0, + const ExtAxisValue &ext_axis, const gert::Shape &origin_shape, + gert::Shape &shape); + static bool TransferShapeByFormatIndex(const ge::Format &origin_format, const ge::Format &format, const int64_t &c0, + const gert::Shape &origin_shape, gert::Shape &shape); + static bool IsNeedTransferShape(const ge::Format &origin_format, const ge::Format &format, const gert::Shape &shape); + static bool CheckInputParam(const ge::Format &origin_format, const ge::Format &primary_format, + const ge::DataType &data_type); + static bool IsNeedAxisValue(const ge::Format &format, const size_t &origin_dim_size); + static int64_t GetC0Value(const ge::DataType &data_type, const ge::Format &format); + + /* ----------Below is the function of getting new shape by axis value---------------------- */ + static bool GetNCHWShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); + + static bool GetNHWCShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); + + static bool GetHWCNShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); + + static bool GetCHWNShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); + + static bool GetNDHWCShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); + + static bool GetNCDHWShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); + + static bool GetDHWCNShapeByAxisValue(const AxisValue &axis_value, 
gert::Shape &shape); + + static bool GetDHWNCShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); + + static bool GetNC1HWC0ShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); + + static bool GetC1HWC0ShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); + + static bool GetNDC1HWC0ShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); + + static bool GetC1HWNCoC0ShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); + + static bool GetNzShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); + + static bool GetFzShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); + + static bool GetFz3DShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); + + static bool GetFz3DTransposeShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); + + static bool GetFzLstmShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); + + static bool GetFzC04ShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); + + static bool GetFznRNNShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); + + static bool GetNDRNNShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); + + /* ----------Below is the function of getting new shape by origin shape---------------------- */ + static bool GetNCHWShape(const FormatIndex& format_index, const gert::Shape &origin_shape, gert::Shape &shape); + + static bool GetNHWCShape(const FormatIndex& format_index, const gert::Shape &origin_shape, gert::Shape &shape); + + static bool GetHWCNShape(const FormatIndex& format_index, const gert::Shape &origin_shape, gert::Shape &shape); + + static bool GetCHWNShape(const FormatIndex& format_index, const gert::Shape &origin_shape, gert::Shape &shape); + + static bool GetNDHWCShape(const FormatIndex& format_index, const gert::Shape &origin_shape, gert::Shape &shape); + + static bool GetNCDHWShape(const FormatIndex& format_index, const gert::Shape &origin_shape, 
gert::Shape &shape); + + static bool GetDHWCNShape(const FormatIndex& format_index, const gert::Shape &origin_shape, gert::Shape &shape); + + static bool GetDHWNCShape(const FormatIndex& format_index, const gert::Shape &origin_shape, gert::Shape &shape); + + static bool GetNC1HWC0Shape(const FormatIndex& format_index, const int64_t &c0, const gert::Shape &origin_shape, + gert::Shape &shape); + + static bool GetC1HWC0Shape(const FormatIndex& format_index, const int64_t &c0, const gert::Shape &origin_shape, + gert::Shape &shape); + + static bool GetNDC1HWC0Shape(const FormatIndex& format_index, const int64_t &c0, const gert::Shape &origin_shape, + gert::Shape &shape); + + static bool GetC1HWNCoC0Shape(const FormatIndex& format_index, const int64_t &c0, const gert::Shape &origin_shape, + gert::Shape &shape); + + static bool GetFractalNzShape(const int64_t &c0, const int64_t &m0, + const gert::Shape &origin_shape, gert::Shape &shape); + + static bool GetFractalZShape(const int64_t &c0, const gert::Shape &origin_shape, gert::Shape &shape); + + static bool GetFractalZShape(const FormatIndex& format_index, const int64_t &c0, const int64_t &group, + const gert::Shape &origin_shape, gert::Shape &shape); + + static bool GetFractalZ3DShape(const FormatIndex& format_index, const int64_t &c0, const int64_t &group, + const gert::Shape &origin_shape, gert::Shape &shape); + + static bool GetFractalZ3DTransposeShape(const FormatIndex& format_index, const int64_t &c0, + const gert::Shape &origin_shape, gert::Shape &shape); + + static bool GetFractalZLstmShape(const FormatIndex& format_index, const gert::Shape &origin_shape, + gert::Shape &shape); + + static bool GetFractalZC04Shape(const FormatIndex& format_index, const int64_t &c0, const gert::Shape &origin_shape, + gert::Shape &shape); + + static bool GetFractalZnRnnShape(const ExtAxisValue &ext_axis, const int64_t &c0, const gert::Shape &origin_shape, + gert::Shape &shape); + + static bool GetNdRnnBiasShape(const ExtAxisValue 
&ext_axis, const int64_t &c0, const gert::Shape &origin_shape, + gert::Shape &shape); + + static bool GetNYUVShape(gert::Shape &shape); + + static bool GetFzWinoShapeByAxisValue(const AxisValue &axis_value, gert::Shape &shape); + + static bool GetFractalZWinoShape(const FormatIndex& format_index, const int64_t &c0, + const gert::Shape &origin_shape, gert::Shape &shape); + + static TransferShapeType GetTransferShapeType(const ge::Format &src_format, const ge::Format &dst_format, + const gert::Shape &src_shape); + + static bool GetNdToNdAlignedShape(const AlignShapeInfo &align_shape_info, gert::Shape &aligned_shape); + + static bool GetNdToNzAlignedShape(const AlignShapeInfo &align_shape_info, gert::Shape &aligned_shape); + + static bool GetFullSizeAlignedShape(const AlignShapeInfo &align_shape_info, gert::Shape &aligned_shape); + + static bool GetNotFullSizeAlignedShape(const AlignShapeInfo &align_shape_info, gert::Shape &aligned_shape); + + static bool GetNdToNdAxisIndexMapping(const TransferDimsInfo &transfer_dims_info, + AxisIndexMapping &axis_index_mapping); + + static bool GetNdToNzAxisIndexMapping(const TransferDimsInfo &transfer_dims_info, + AxisIndexMapping &axis_index_mapping); + + static bool GetFullSizeAxisIndexMapping(const TransferDimsInfo &transfer_dims_info, + AxisIndexMapping &axis_index_mapping); + + static bool GetNotFullSizeAxisIndexMapping(const TransferDimsInfo &transfer_dims_info, + AxisIndexMapping &axis_index_mapping); + + static std::array(ge::DataType::DT_MAX)> m0_list_; + static std::array(ge::DataType::DT_MAX)> k0_list_; + static std::array(ge::DataType::DT_MAX)> n0_list_; + static const std::map get_aligned_shape_func_map; + static const std::map transfer_dims_func_map; +}; +} +#endif // TRANSFORMER_INC_TRANSFER_SHAPE_UTILS_H_ diff --git a/inc/mmpa/mmpa_api.h b/inc/mmpa/mmpa_api.h new file mode 100644 index 0000000000000000000000000000000000000000..f7eb743577d0b16c02487374c164f3d05c19fc4f --- /dev/null +++ b/inc/mmpa/mmpa_api.h @@ -0,0 
+1,136 @@ +/* +* @file mmpa_api.h +* +* Copyright (C) Huawei Technologies Co., Ltd. 2019-2021. All Rights Reserved. +* +* This program is distributed in the hope that it will be useful, +* but WITHOUT ANY WARRANTY; without even the implied warranty of +* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +*/ + +#ifndef _MMPA_API_H_ +#define _MMPA_API_H_ + +#define LINUX 0 +#define WIN 1 + +#if(OS_TYPE == LINUX) //lint !e553 + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif + +#ifdef FUNC_VISIBILITY +#define MMPA_FUNC_VISIBILITY __attribute__((visibility("default"))) +#else +#define MMPA_FUNC_VISIBILITY +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "securec.h" + +#include "./sub_inc/mmpa_typedef_linux.h" +#include "./sub_inc/mmpa_linux.h" + +#endif + + +#if(OS_TYPE == WIN) //lint !e553 + +#ifdef FUNC_VISIBILITY +#define MMPA_FUNC_VISIBILITY _declspec(dllexport) +#else +#define MMPA_FUNC_VISIBILITY +#endif + +#include +#include +#include "Windows.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "shlwapi.h" +#include +#include +#include +#include +#include +#include +#include +#include + +#include "securec.h" + +#include "sub_inc/mmpa_typedef_win.h" +#include "sub_inc/mmpa_win.h" + +#pragma comment(lib, "ws2_32.lib") +#pragma comment(lib, "mswsock.lib") +#pragma comment(lib, "Kernel32.lib") +#pragma comment(lib, "shlwapi.lib") +#pragma comment(lib, "wbemuuid.lib") +#pragma comment(lib, "Iphlpapi.lib") +#endif + +#endif 
// MMPA_API_H_ + diff --git a/inc/mmpa/sub_inc/mmpa_linux.h b/inc/mmpa/sub_inc/mmpa_linux.h new file mode 100644 index 0000000000000000000000000000000000000000..23af7fa0b9f7f4463b079133e6d52c5c7a92952f --- /dev/null +++ b/inc/mmpa/sub_inc/mmpa_linux.h @@ -0,0 +1,558 @@ +/* +* @file mmpa_linux.h +* +* Copyright (C) Huawei Technologies Co., Ltd. 2019-2021. All Rights Reserved. +* +* This program is distributed in the hope that it will be useful, +* but WITHOUT ANY WARRANTY; without even the implied warranty of +* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +*/ + +#ifndef MMPA_LINUX_MMPA_LINUX_H +#define MMPA_LINUX_MMPA_LINUX_H + +#ifdef __cplusplus +#if __cplusplus +extern "C" { +#endif // __cpluscplus +#endif // __cpluscplus + +#define MMPA_MACINFO_DEFAULT_SIZE 18 +#define MMPA_CPUDESC_DEFAULT_SIZE 64 + +typedef pthread_t mmThread; +typedef pthread_mutex_t mmMutex_t; +typedef pthread_cond_t mmCond; +typedef pthread_mutex_t mmMutexFC; +typedef pthread_rwlock_t mmRWLock_t; +typedef signed int mmProcess; +typedef int mmPollHandle; +typedef int mmPipeHandle; +typedef int mmFileHandle; +typedef int mmComPletionKey; +typedef int mmCompletionHandle; +typedef int mmErrorMsg; +typedef int mmFd_t; + +typedef VOID *mmExitCode; +typedef key_t mmKey_t; +typedef int mmMsgid; +typedef struct dirent mmDirent; +typedef struct dirent mmDirent2; +typedef struct shmid_ds mmshmId_ds; +typedef int (*mmFilter)(const mmDirent *entry); +typedef int (*mmFilter2)(const mmDirent2 *entry); +typedef int (*mmSort)(const mmDirent **a, const mmDirent **b); +typedef int (*mmSort2)(const mmDirent2 **a, const mmDirent2 **b); +typedef size_t mmSize_t; //lint !e410 !e1051 +typedef off_t mmOfft_t; +typedef pid_t mmPid_t; +typedef long MM_LONG; + +typedef VOID *(*userProcFunc)(VOID *pulArg); + +typedef struct { + userProcFunc procFunc; // Callback function pointer + VOID *pulArg; // Callback function parameters +} mmUserBlock_t; + +typedef struct { + const CHAR *dli_fname; + VOID *dli_fbase; + 
const CHAR *dli_sname; + VOID *dli_saddr; + size_t dli_size; /* ELF only */ + INT32 dli_bind; /* ELF only */ + INT32 dli_type; +} mmDlInfo; + +typedef struct { + INT32 wSecond; // Seconds. [0-60] (1 leap second) + INT32 wMinute; // Minutes. [0-59] + INT32 wHour; // Hours. [0-23] + INT32 wDay; // Day. [1-31] + INT32 wMonth; // Month. [1-12] + INT32 wYear; // Year + INT32 wDayOfWeek; // Day of week. [0-6] + INT32 tm_yday; // Days in year.[0-365] + INT32 tm_isdst; // DST. [-1/0/1] + LONG wMilliseconds; // milliseconds +} mmSystemTime_t; + +typedef sem_t mmSem_t; +typedef struct sockaddr mmSockAddr; +typedef socklen_t mmSocklen_t; +typedef int mmSockHandle; +typedef timer_t mmTimer; +typedef pthread_key_t mmThreadKey; + +typedef int mmOverLap; + +typedef ssize_t mmSsize_t; +typedef size_t mmSize; // size + +typedef struct { + UINT32 createFlag; + INT32 oaFlag; +} mmCreateFlag; + +typedef struct { + VOID *sendBuf; + INT32 sendLen; +} mmIovSegment; +typedef struct in_addr mmInAddr; + +typedef struct { + VOID *inbuf; + INT32 inbufLen; + VOID *outbuf; + INT32 outbufLen; + mmOverLap *oa; +} mmIoctlBuf; + +typedef int mmAtomicType; +typedef int mmAtomicType64; + +typedef enum { + pollTypeRead = 1, // pipe read + pollTypeRecv, // socket recv + pollTypeIoctl, // ioctl +} mmPollType; + +typedef struct { + mmPollHandle handle; // The file descriptor or handle of poll is required + mmPollType pollType; // Operation type requiring poll + // read or recv or ioctl + INT32 ioctlCode; // IOCTL operation code, dedicated to IOCTL + mmComPletionKey completionKey; // The default value is blank, which is used in windows + // The data used to receive the difference between which handle is readable +} mmPollfd; + +typedef struct { + VOID *priv; // User defined private content + mmPollHandle bufHandle; // Value of handle corresponding to buf + mmPollType bufType; // Data types polled to + VOID *buf; // Data used in poll + UINT32 bufLen; // Data length used in poll + UINT32 bufRes; // Actual 
return length +} mmPollData, *pmmPollData; + +typedef VOID (*mmPollBack)(pmmPollData); + +typedef struct { + INT32 tz_minuteswest; // Difference from Greenwich time, in minutes + INT32 tz_dsttime; // type of DST correction +} mmTimezone; + +typedef struct { + INT64 tv_sec; + INT64 tv_usec; +} mmTimeval; + +typedef struct { + MM_LONG tv_sec; + MM_LONG tv_nsec; +} mmTimespec; + +typedef struct { + ULONGLONG totalSize; + ULONGLONG freeSize; + ULONGLONG availSize; +} mmDiskSize; + +#define mmTLS __thread +typedef struct stat mmStat_t; +typedef struct stat64 mmStat64_t; +typedef mode_t mmMode_t; + +typedef struct option mmStructOption; + +typedef struct { + CHAR addr[MMPA_MACINFO_DEFAULT_SIZE]; // ex:aa-bb-cc-dd-ee-ff\0 +} mmMacInfo; + +typedef struct { + CHAR **argv; + INT32 argvCount; + CHAR **envp; + INT32 envpCount; +} mmArgvEnv; + +typedef struct { + CHAR arch[MMPA_CPUDESC_DEFAULT_SIZE]; + CHAR manufacturer[MMPA_CPUDESC_DEFAULT_SIZE]; // vendor + CHAR version[MMPA_CPUDESC_DEFAULT_SIZE]; // modelname + INT32 frequency; // cpu frequency + INT32 maxFrequency; // max speed + INT32 ncores; // cpu cores + INT32 nthreads; // cpu thread count + INT32 ncounts; // logical cpu nums +} mmCpuDesc; + +typedef mode_t MODE; + +typedef struct { + INT32 detachFlag; // Determine whether to set separation property 0, not to separate 1 + INT32 priorityFlag; // Determine whether to set priority 0 and not set 1 + INT32 priority; // Priority value to set, range 1-99 + INT32 policyFlag; // Whether to set the scheduling policy: 0 = do not set, 1 = set + INT32 policy; // Scheduling policy value, one of: + // MMPA_THREAD_SCHED_RR + // MMPA_THREAD_SCHED_OTHER + // MMPA_THREAD_SCHED_FIFO + INT32 stackFlag; // Whether to set the stack size: 0 = do not set, 1 = set + UINT32 stackSize; // Stack size to set, in bytes; cannot be less than MMPA_THREAD_MIN_STACK_SIZE +} mmThreadAttr; + +#ifdef __ANDROID__ +#define S_IREAD S_IRUSR +#define S_IWRITE S_IWUSR +#endif + +#define mm_no_argument no_argument +#define 
mm_required_argument required_argument +#define mm_optional_argument optional_argument + +#define M_FILE_RDONLY O_RDONLY +#define M_FILE_WRONLY O_WRONLY +#define M_FILE_RDWR O_RDWR +#define M_FILE_CREAT O_CREAT + +#define M_RDONLY O_RDONLY +#define M_WRONLY O_WRONLY +#define M_RDWR O_RDWR +#define M_CREAT O_CREAT +#define M_BINARY O_RDONLY +#define M_TRUNC O_TRUNC +#define M_IRWXU S_IRWXU +#define M_APPEND O_APPEND + +#define M_IN_CREATE IN_CREATE +#define M_IN_CLOSE_WRITE IN_CLOSE_WRITE +#define M_IN_IGNORED IN_IGNORED + +#define M_OUT_CREATE IN_CREATE +#define M_OUT_CLOSE_WRITE IN_CLOSE_WRITE +#define M_OUT_IGNORED IN_IGNORED +#define M_OUT_ISDIR IN_ISDIR + +#define M_IREAD S_IREAD +#define M_IRUSR S_IRUSR +#define M_IWRITE S_IWRITE +#define M_IWUSR S_IWUSR +#define M_IXUSR S_IXUSR +#define FDSIZE 64 +#define M_MSG_CREAT IPC_CREAT +#define M_MSG_EXCL (IPC_CREAT | IPC_EXCL) +#define M_MSG_NOWAIT IPC_NOWAIT + +#define M_WAIT_NOHANG WNOHANG // Non blocking waiting +#define M_WAIT_UNTRACED \ + WUNTRACED // If the subprocess enters the suspended state, it will return immediately + // But the end state of the subprocess is ignored +#define M_UMASK_USRREAD S_IRUSR +#define M_UMASK_GRPREAD S_IRGRP +#define M_UMASK_OTHREAD S_IROTH + +#define M_UMASK_USRWRITE S_IWUSR +#define M_UMASK_GRPWRITE S_IWGRP +#define M_UMASK_OTHWRITE S_IWOTH + +#define M_UMASK_USREXEC S_IXUSR +#define M_UMASK_GRPEXEC S_IXGRP +#define M_UMASK_OTHEXEC S_IXOTH + +#define mmConstructor(x) __attribute__((constructor)) VOID x() +#define mmDestructor(x) __attribute__((destructor)) VOID x() + +#define MMPA_NO_ARGUMENT 0 +#define MMPA_REQUIRED_ARGUMENT 1 +#define MMPA_OPTIONAL_ARGUMENT 2 + +#define MMPA_MAX_PATH PATH_MAX +#define M_NAME_MAX MAX_FNAME + +#define M_F_OK F_OK +#define M_X_OK X_OK +#define M_W_OK W_OK +#define M_R_OK R_OK + + +#define MM_DT_DIR DT_DIR +#define MM_DT_REG DT_REG + +#define MMPA_STDIN STDIN_FILENO +#define MMPA_STDOUT STDOUT_FILENO +#define MMPA_STDERR STDERR_FILENO + +#define 
MMPA_RTLD_NOW RTLD_NOW +#define MMPA_RTLD_GLOBAL RTLD_GLOBAL +#define MMPA_RTLD_LAZY RTLD_LAZY +#define MMPA_RTLD_NODELETE RTLD_NODELETE + +#define MMPA_DL_EXT_NAME ".so" + +MMPA_FUNC_VISIBILITY INT32 mmCreateTask(mmThread *threadHandle, mmUserBlock_t *funcBlock); +MMPA_FUNC_VISIBILITY INT32 mmJoinTask(mmThread *threadHandle); +MMPA_FUNC_VISIBILITY INT32 mmMutexInit(mmMutex_t *mutex); +MMPA_FUNC_VISIBILITY INT32 mmMutexLock(mmMutex_t *mutex); +MMPA_FUNC_VISIBILITY INT32 mmMutexTryLock(mmMutex_t *mutex); +MMPA_FUNC_VISIBILITY INT32 mmMutexUnLock(mmMutex_t *mutex); +MMPA_FUNC_VISIBILITY INT32 mmMutexDestroy(mmMutex_t *mutex); +MMPA_FUNC_VISIBILITY INT32 mmCondInit(mmCond *cond); +MMPA_FUNC_VISIBILITY INT32 mmCondLockInit(mmMutexFC *mutex); +MMPA_FUNC_VISIBILITY INT32 mmCondLock(mmMutexFC *mutex); +MMPA_FUNC_VISIBILITY INT32 mmCondUnLock(mmMutexFC *mutex); +MMPA_FUNC_VISIBILITY INT32 mmCondLockDestroy(mmMutexFC *mutex); +MMPA_FUNC_VISIBILITY INT32 mmRWLockInit(mmRWLock_t *rwLock); +MMPA_FUNC_VISIBILITY INT32 mmRWLockRDLock(mmRWLock_t *rwLock); +MMPA_FUNC_VISIBILITY INT32 mmRWLockTryRDLock(mmRWLock_t *rwLock); +MMPA_FUNC_VISIBILITY INT32 mmRWLockWRLock(mmRWLock_t *rwLock); +MMPA_FUNC_VISIBILITY INT32 mmRWLockTryWRLock(mmRWLock_t *rwLock); +MMPA_FUNC_VISIBILITY INT32 mmRDLockUnLock(mmRWLock_t *rwLock); +MMPA_FUNC_VISIBILITY INT32 mmWRLockUnLock(mmRWLock_t *rwLock); +MMPA_FUNC_VISIBILITY INT32 mmRWLockDestroy(mmRWLock_t *rwLock); +MMPA_FUNC_VISIBILITY INT32 mmCondWait(mmCond *cond, mmMutexFC *mutex); +MMPA_FUNC_VISIBILITY INT32 mmCondTimedWait(mmCond *cond, mmMutexFC *mutex, UINT32 milliSecond); +MMPA_FUNC_VISIBILITY INT32 mmCondNotify(mmCond *cond); +MMPA_FUNC_VISIBILITY INT32 mmCondNotifyAll(mmCond *cond); +MMPA_FUNC_VISIBILITY INT32 mmCondDestroy(mmCond *cond); +MMPA_FUNC_VISIBILITY INT32 mmGetPid(); +MMPA_FUNC_VISIBILITY INT32 mmGetTid(); +MMPA_FUNC_VISIBILITY INT32 mmGetPidHandle(mmProcess *processHandle); +MMPA_FUNC_VISIBILITY INT32 mmGetLocalTime(mmSystemTime_t 
*sysTimePtr); +MMPA_FUNC_VISIBILITY INT32 mmGetSystemTime(mmSystemTime_t *sysTimePtr); + +MMPA_FUNC_VISIBILITY INT32 mmSemInit(mmSem_t *sem, UINT32 value); +MMPA_FUNC_VISIBILITY INT32 mmSemWait(mmSem_t *sem); +MMPA_FUNC_VISIBILITY INT32 mmSemPost(mmSem_t *sem); +MMPA_FUNC_VISIBILITY INT32 mmSemDestroy(mmSem_t *sem); +MMPA_FUNC_VISIBILITY INT32 mmOpen(const CHAR *pathName, INT32 flags); +MMPA_FUNC_VISIBILITY INT32 mmOpen2(const CHAR *pathName, INT32 flags, MODE mode); +MMPA_FUNC_VISIBILITY FILE *mmPopen(CHAR *command, CHAR *type); +MMPA_FUNC_VISIBILITY INT32 mmClose(INT32 fd); +MMPA_FUNC_VISIBILITY INT32 mmPclose(FILE *stream); +MMPA_FUNC_VISIBILITY mmSsize_t mmWrite(INT32 fd, VOID *buf, UINT32 bufLen); +MMPA_FUNC_VISIBILITY mmSsize_t mmRead(INT32 fd, VOID *buf, UINT32 bufLen); +MMPA_FUNC_VISIBILITY mmSockHandle mmSocket(INT32 sockFamily, INT32 type, INT32 protocol); +MMPA_FUNC_VISIBILITY INT32 mmBind(mmSockHandle sockFd, mmSockAddr *addr, mmSocklen_t addrLen); +MMPA_FUNC_VISIBILITY INT32 mmListen(mmSockHandle sockFd, INT32 backLog); +MMPA_FUNC_VISIBILITY mmSockHandle mmAccept(mmSockHandle sockFd, mmSockAddr *addr, mmSocklen_t *addrLen); +MMPA_FUNC_VISIBILITY INT32 mmConnect(mmSockHandle sockFd, mmSockAddr *addr, mmSocklen_t addrLen); +MMPA_FUNC_VISIBILITY INT32 mmCloseSocket(mmSockHandle sockFd); +MMPA_FUNC_VISIBILITY mmSsize_t mmSocketSend(mmSockHandle sockFd, VOID *sendBuf, INT32 sendLen, INT32 sendFlag); +MMPA_FUNC_VISIBILITY mmSsize_t mmSocketRecv(mmSockHandle sockFd, VOID *recvBuf, INT32 recvLen, INT32 recvFlag); +MMPA_FUNC_VISIBILITY INT32 mmSocketSendTo(mmSockHandle sockFd, + VOID *sendMsg, + INT32 sendLen, + UINT32 sendFlag, + const mmSockAddr* addr, + INT32 tolen); +MMPA_FUNC_VISIBILITY mmSsize_t mmSocketRecvFrom(mmSockHandle sockFd, + VOID *recvBuf, + mmSize recvLen, + UINT32 recvFlag, + mmSockAddr* addr, + mmSocklen_t *FromLen); +MMPA_FUNC_VISIBILITY INT32 mmSAStartup(); +MMPA_FUNC_VISIBILITY INT32 mmSACleanup(); +MMPA_FUNC_VISIBILITY VOID 
*mmDlopen(const CHAR *fileName, INT32 mode); +MMPA_FUNC_VISIBILITY INT32 mmDladdr(VOID *addr, mmDlInfo *info); +MMPA_FUNC_VISIBILITY VOID *mmDlsym(VOID *handle, const CHAR *funcName); +MMPA_FUNC_VISIBILITY INT32 mmDlclose(VOID *handle); +MMPA_FUNC_VISIBILITY CHAR *mmDlerror(); +MMPA_FUNC_VISIBILITY INT32 mmCreateAndSetTimer(mmTimer *timerHandle, + mmUserBlock_t *timerBlock, + UINT milliSecond, + UINT period); +MMPA_FUNC_VISIBILITY INT32 mmDeleteTimer(mmTimer timerHandle); +MMPA_FUNC_VISIBILITY INT32 mmStatGet(const CHAR *path, mmStat_t *buffer); +MMPA_FUNC_VISIBILITY INT32 mmStat64Get(const CHAR *path, mmStat64_t *buffer); +MMPA_FUNC_VISIBILITY INT32 mmFStatGet(INT32 fd, mmStat_t *buffer); +MMPA_FUNC_VISIBILITY INT32 mmMkdir(const CHAR *pathName, mmMode_t mode); +MMPA_FUNC_VISIBILITY INT32 mmSleep(UINT32 milliSecond); + +MMPA_FUNC_VISIBILITY INT32 mmCreateTaskWithAttr(mmThread *threadHandle, mmUserBlock_t *funcBlock); +MMPA_FUNC_VISIBILITY INT32 mmGetProcessPrio(mmProcess pid); +MMPA_FUNC_VISIBILITY INT32 mmSetProcessPrio(mmProcess pid, INT32 processPrio); +MMPA_FUNC_VISIBILITY INT32 mmGetThreadPrio(mmThread *threadHandle); +MMPA_FUNC_VISIBILITY INT32 mmSetThreadPrio(mmThread *threadHandle, INT32 threadPrio); +MMPA_FUNC_VISIBILITY INT32 mmAccess(const CHAR *pathName); +MMPA_FUNC_VISIBILITY INT32 mmAccess2(const CHAR *pathName, INT32 mode); +MMPA_FUNC_VISIBILITY INT32 mmRmdir(const CHAR *pathName); + +MMPA_FUNC_VISIBILITY INT32 mmIoctl(mmProcess fd, INT32 ioctlCode, mmIoctlBuf *bufPtr); +MMPA_FUNC_VISIBILITY INT32 mmSemTimedWait(mmSem_t *sem, INT32 timeout); +MMPA_FUNC_VISIBILITY mmSsize_t mmWritev(mmProcess fd, mmIovSegment *iov, INT32 iovcnt); +MMPA_FUNC_VISIBILITY VOID mmMb(); +MMPA_FUNC_VISIBILITY INT32 mmInetAton(const CHAR *addrStr, mmInAddr *addr); + +MMPA_FUNC_VISIBILITY mmProcess mmOpenFile(const CHAR *fileName, UINT32 accessFlag, mmCreateFlag fileFlag); +MMPA_FUNC_VISIBILITY mmSsize_t mmReadFile(mmProcess fileId, VOID *buffer, INT32 len); 
+MMPA_FUNC_VISIBILITY mmSsize_t mmWriteFile(mmProcess fileId, VOID *buffer, INT32 len); +MMPA_FUNC_VISIBILITY INT32 mmCloseFile(mmProcess fileId); + +MMPA_FUNC_VISIBILITY mmAtomicType mmSetData(mmAtomicType *ptr, mmAtomicType value); +MMPA_FUNC_VISIBILITY mmAtomicType mmValueInc(mmAtomicType *ptr, mmAtomicType value); +MMPA_FUNC_VISIBILITY mmAtomicType mmValueSub(mmAtomicType *ptr, mmAtomicType value); +MMPA_FUNC_VISIBILITY mmAtomicType64 mmSetData64(mmAtomicType64 *ptr, mmAtomicType64 value); +MMPA_FUNC_VISIBILITY mmAtomicType64 mmValueInc64(mmAtomicType64 *ptr, mmAtomicType64 value); +MMPA_FUNC_VISIBILITY mmAtomicType64 mmValueSub64(mmAtomicType64 *ptr, mmAtomicType64 value); +MMPA_FUNC_VISIBILITY INT32 mmCreateTaskWithDetach(mmThread *threadHandle, mmUserBlock_t *funcBlock); + +// The following 3 interfaces are to be deleted +MMPA_FUNC_VISIBILITY INT32 mmCreateNamedPipe(mmPipeHandle pipeHandle[], CHAR *pipeName[], INT32 waitMode); +MMPA_FUNC_VISIBILITY INT32 mmOpenNamePipe(mmPipeHandle pipeHandle[], CHAR *pipeName[], INT32 waitMode); +MMPA_FUNC_VISIBILITY VOID mmCloseNamedPipe(mmPipeHandle namedPipe[]); + +MMPA_FUNC_VISIBILITY INT32 mmCreatePipe(mmPipeHandle pipeHandle[], CHAR *pipeName[], UINT32 pipeCount, INT32 waitMode); +MMPA_FUNC_VISIBILITY INT32 mmOpenPipe(mmPipeHandle pipeHandle[], CHAR *pipeName[], UINT32 pipeCount, INT32 waitMode); +MMPA_FUNC_VISIBILITY VOID mmClosePipe(mmPipeHandle pipeHandle[], UINT32 pipeCount); + +// Poll related interface +MMPA_FUNC_VISIBILITY mmCompletionHandle mmCreateCompletionPort(); +MMPA_FUNC_VISIBILITY VOID mmCloseCompletionPort(mmCompletionHandle handle); +MMPA_FUNC_VISIBILITY INT32 mmPoll(mmPollfd *fds, + INT32 fdCount, + INT32 timeout, + mmCompletionHandle handleIOCP, + pmmPollData polledData, + mmPollBack pollBack); +MMPA_FUNC_VISIBILITY INT32 mmGetErrorCode(); +MMPA_FUNC_VISIBILITY CHAR *mmGetErrorFormatMessage(mmErrorMsg errnum, CHAR *buf, mmSize size); +MMPA_FUNC_VISIBILITY INT32 mmGetTimeOfDay(mmTimeval *timeVal, 
mmTimezone *timeZone); +MMPA_FUNC_VISIBILITY mmTimespec mmGetTickCount(); +MMPA_FUNC_VISIBILITY INT32 mmGetRealPath(CHAR *path, CHAR *realPath); +MMPA_FUNC_VISIBILITY INT32 mmRealPath(const CHAR *path, CHAR *realPath, INT32 realPathLen); + +MMPA_FUNC_VISIBILITY INT32 mmDup2(INT32 oldFd, INT32 newFd); + +MMPA_FUNC_VISIBILITY INT32 mmDup(INT32 fd); + +MMPA_FUNC_VISIBILITY INT32 mmUnlink(const CHAR *filename); + +MMPA_FUNC_VISIBILITY INT32 mmChmod(const CHAR *filename, INT32 mode); + +MMPA_FUNC_VISIBILITY INT32 mmFileno(FILE *stream); + +MMPA_FUNC_VISIBILITY INT32 mmScandir(const CHAR *path, mmDirent ***entryList, mmFilter filterFunc, mmSort sort); +MMPA_FUNC_VISIBILITY INT32 mmScandir2(const CHAR *path, mmDirent2 ***entryList, mmFilter2 filterFunc, mmSort2 sort); + +MMPA_FUNC_VISIBILITY VOID mmScandirFree(mmDirent **entryList, INT32 count); +MMPA_FUNC_VISIBILITY VOID mmScandirFree2(mmDirent2 **entryList, INT32 count); + +MMPA_FUNC_VISIBILITY mmMsgid mmMsgCreate(mmKey_t key, INT32 msgFlag); + +MMPA_FUNC_VISIBILITY mmMsgid mmMsgOpen(mmKey_t key, INT32 msgFlag); + +MMPA_FUNC_VISIBILITY INT32 mmMsgSnd(mmMsgid msqid, VOID *buf, INT32 bufLen, INT32 msgFlag); + +MMPA_FUNC_VISIBILITY INT32 mmMsgRcv(mmMsgid msqid, VOID *buf, INT32 bufLen, INT32 msgFlag); + +MMPA_FUNC_VISIBILITY INT32 mmMsgClose(mmMsgid msqid); + +MMPA_FUNC_VISIBILITY INT32 mmLocalTimeR(const time_t *timep, struct tm *result); + +MMPA_FUNC_VISIBILITY INT32 mmGetOptErr(); +MMPA_FUNC_VISIBILITY VOID mmSetOptErr(INT32 mmOptErr); +MMPA_FUNC_VISIBILITY INT32 mmGetOptInd(); +MMPA_FUNC_VISIBILITY VOID mmSetOptInd(INT32 mmOptInd); +MMPA_FUNC_VISIBILITY INT32 mmGetOptOpt(); +MMPA_FUNC_VISIBILITY VOID mmSetOpOpt(INT32 mmOptOpt); +MMPA_FUNC_VISIBILITY CHAR *mmGetOptArg(); +MMPA_FUNC_VISIBILITY VOID mmSetOptArg(CHAR *mmOptArg); +MMPA_FUNC_VISIBILITY INT32 mmGetOpt(INT32 argc, CHAR *const *argv, const CHAR *opts); +MMPA_FUNC_VISIBILITY INT32 mmGetOptLong(INT32 argc, + CHAR *const *argv, + const CHAR *opts, + const 
mmStructOption *longOpts, + INT32 *longIndex); + +MMPA_FUNC_VISIBILITY LONG mmLseek(INT32 fd, INT64 offset, INT32 seekFlag); +MMPA_FUNC_VISIBILITY INT32 mmFtruncate(mmProcess fd, UINT32 length); + +MMPA_FUNC_VISIBILITY INT32 mmTlsCreate(mmThreadKey *key, VOID (*destructor)(VOID *)); +MMPA_FUNC_VISIBILITY INT32 mmTlsSet(mmThreadKey key, const VOID *value); +MMPA_FUNC_VISIBILITY VOID *mmTlsGet(mmThreadKey key); +MMPA_FUNC_VISIBILITY INT32 mmTlsDelete(mmThreadKey key); +MMPA_FUNC_VISIBILITY INT32 mmGetOsType(); + +MMPA_FUNC_VISIBILITY INT32 mmFsync(mmProcess fd); +MMPA_FUNC_VISIBILITY INT32 mmFsync2(INT32 fd); +MMPA_FUNC_VISIBILITY INT32 mmChdir(const CHAR *path); +MMPA_FUNC_VISIBILITY INT32 mmUmask(INT32 pmode); +MMPA_FUNC_VISIBILITY INT32 mmWaitPid(mmProcess pid, INT32 *status, INT32 options); + +MMPA_FUNC_VISIBILITY INT32 mmGetCwd(CHAR *buffer, INT32 maxLen); +MMPA_FUNC_VISIBILITY INT32 mmGetEnv(const CHAR *name, CHAR *value, UINT32 len); +MMPA_FUNC_VISIBILITY INT32 mmSetEnv(const CHAR *name, const CHAR *value, INT32 overwrite); +MMPA_FUNC_VISIBILITY CHAR *mmStrTokR(CHAR *str, const CHAR *delim, CHAR **saveptr); +MMPA_FUNC_VISIBILITY CHAR *mmDirName(CHAR *path); +MMPA_FUNC_VISIBILITY CHAR *mmBaseName(CHAR *path); +MMPA_FUNC_VISIBILITY INT32 mmGetDiskFreeSpace(const CHAR *path, mmDiskSize *diskSize); + +/* + * Function: set the thread name created by mmcreatetask + * Input: pstThreadHandle: thread ID + * name: thread name, the actual length of name must be < MMPA_THREADNAME_SIZE + * The input parameter error returns EN_INVALID_PARAM, the execution success returns EN_OK, and the + * execution failure returns EN_ERROR + */ +MMPA_FUNC_VISIBILITY INT32 mmSetThreadName(mmThread *threadHandle, const CHAR *name); + +/* + * Function: get thread name + * Input: pstThreadHandle: thread ID + * size: Cache length of thread name + * name:User allocated cache for thread name, Cache length must be >= MMPA_THREADNAME_SIZE + * The input parameter error returns EN_INVALID_PARAM, the 
execution success returns EN_OK, and the + * execution failure returns EN_ERROR + */ +MMPA_FUNC_VISIBILITY INT32 mmGetThreadName(mmThread *threadHandle, CHAR *name, INT32 size); +/* + * Function: set the name of the currently executing thread; call it from inside the thread body + * Input: name: thread name to set + * Returns EN_INVALID_PARAM on an invalid input parameter, EN_OK on success, and + * EN_ERROR on execution failure + */ +MMPA_FUNC_VISIBILITY INT32 mmSetCurrentThreadName(const CHAR *name); +/* + * Function: get the name of the currently executing thread; call it from inside the thread body + * Input: name: user-allocated buffer that receives the thread name; size must be >= MMPA_THREADNAME_SIZE + * Returns EN_INVALID_PARAM on an invalid input parameter, EN_OK on success, and + * EN_ERROR on execution failure + */ +MMPA_FUNC_VISIBILITY INT32 mmGetCurrentThreadName(CHAR *name, INT32 size); +MMPA_FUNC_VISIBILITY INT32 mmGetFileSize(const CHAR *fileName, ULONGLONG *length); +MMPA_FUNC_VISIBILITY INT32 mmIsDir(const CHAR *fileName); +MMPA_FUNC_VISIBILITY INT32 mmGetOsName(CHAR *name, INT32 nameSize); +MMPA_FUNC_VISIBILITY INT32 mmGetOsVersion(CHAR *versionInfo, INT32 versionLength); +MMPA_FUNC_VISIBILITY INT32 mmGetMac(mmMacInfo **list, INT32 *count); +MMPA_FUNC_VISIBILITY INT32 mmGetMacFree(mmMacInfo *list, INT32 count); +MMPA_FUNC_VISIBILITY INT32 mmGetCpuInfo(mmCpuDesc **cpuInfo, INT32 *count); +MMPA_FUNC_VISIBILITY INT32 mmCpuInfoFree(mmCpuDesc *cpuInfo, INT32 count); +MMPA_FUNC_VISIBILITY INT32 mmCreateProcess(const CHAR *fileName, + const mmArgvEnv *env, + const CHAR *stdoutRedirectFile, + mmProcess *id); + +MMPA_FUNC_VISIBILITY INT32 mmCreateTaskWithThreadAttr(mmThread *threadHandle, + const mmUserBlock_t *funcBlock, + const mmThreadAttr *threadAttr); +MMPA_FUNC_VISIBILITY mmFileHandle mmShmOpen(const CHAR *name, INT32 oflag, mmMode_t mode); +MMPA_FUNC_VISIBILITY INT32 mmShmUnlink(const CHAR *name); 
+MMPA_FUNC_VISIBILITY VOID *mmMmap(mmFd_t fd, mmSize_t size, mmOfft_t offset, mmFd_t *extra, INT32 prot, INT32 flags); +MMPA_FUNC_VISIBILITY INT32 mmMunMap(VOID *data, mmSize_t size, mmFd_t *extra); + +MMPA_FUNC_VISIBILITY mmSize mmGetPageSize(); +MMPA_FUNC_VISIBILITY VOID *mmAlignMalloc(mmSize mallocSize, mmSize alignSize); +MMPA_FUNC_VISIBILITY VOID mmAlignFree(VOID *addr); +#define MMPA_DLL_API + +#ifdef __cplusplus +#if __cplusplus +} +#endif /* __cpluscplus */ +#endif // __cpluscplus + +#endif // MMPA_LINUX_MMPA_LINUX_H_ diff --git a/inc/mmpa/sub_inc/mmpa_typedef_linux.h b/inc/mmpa/sub_inc/mmpa_typedef_linux.h new file mode 100644 index 0000000000000000000000000000000000000000..9c6f6499e0d393a7b1ec3082a93d1bf0a54490ab --- /dev/null +++ b/inc/mmpa/sub_inc/mmpa_typedef_linux.h @@ -0,0 +1,101 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MMPA_TYPEDEF_LINUX_H +#define MMPA_TYPEDEF_LINUX_H + +#ifdef __cplusplus +#if __cplusplus +extern "C" { +#endif // __cplusplus +#endif // __cplusplus + +#ifndef FALSE +#define FALSE 0 +#endif + +#ifndef TRUE +#define TRUE 1 +#endif + +typedef unsigned char UINT8; +typedef signed char INT8; +typedef unsigned short UINT16; +typedef signed short INT16; +typedef unsigned int UINT32; +typedef signed int INT32; +typedef unsigned long long UINT64; +typedef signed long long INT64; +typedef float FLOAT; +typedef double DOUBLE; +typedef void VOID; +typedef unsigned char UCHAR; +typedef char CHAR; +typedef unsigned short USHORT; +typedef short SHORT; +typedef unsigned int UINT; +typedef int INT; +typedef unsigned long ULONG; +typedef unsigned long long ULONGLONG; + +typedef long LONG; + +#define HANDLE_INVALID_VALUE (-1) +#define MMPA_MEM_MAX_LEN (0x7fffffff) +#define MMPA_PROCESS_ERROR (0x7fffffff) +#define PATH_SIZE 256 +#define MAX_IOVEC_SIZE 32 +#define MMPA_MAX_SLEEP_MILLSECOND 4294967 +#define MAX_PIPE_COUNT 2 +#define MMPA_PIPE_COUNT 2 +#define MMPA_THREADNAME_SIZE 16 +#define MMPA_MIN_OS_NAME_SIZE 64 +#define MMPA_MIN_OS_VERSION_SIZE 128 + +#define MMPA_ONE_THOUSAND 1000 +#define MMPA_ONE_BILLION 1000000000 +#define MMPA_COMPUTER_BEGIN_YEAR 1900 +#define MMPA_ZERO 0 +#define MMPA_MAX_THREAD_PIO 99 +#define MMPA_MIN_THREAD_PIO 1 +#define MMPA_DEFAULT_PIPE_PERMISSION 0777 +#define MMPA_DEFAULT_MSG_TYPE 1 + +#define MMPA_THREAD_SCHED_RR SCHED_RR +#define MMPA_THREAD_SCHED_FIFO SCHED_FIFO +#define MMPA_THREAD_SCHED_OTHER SCHED_OTHER +#define MMPA_THREAD_MIN_STACK_SIZE PTHREAD_STACK_MIN + +#define MMPA_PATH_SEPARATOR_STR "/" +#define MMPA_PATH_SEPARATOR_CHAR '/' + +#define MM_MUTEX_INITIALIZER PTHREAD_MUTEX_INITIALIZER + +#define MMPA_MAX_NI 19 +#define MMPA_MIN_NI (-20) + +#define EN_OK 0 +#define EN_ERR 1 +#define EN_ERROR (-1) +#define EN_INVALID_PARAM (-2) +#define EN_TIMEOUT (-3) + +#ifdef __cplusplus +#if __cplusplus +} +#endif // __cplusplus 
+#endif // __cplusplus +#endif // MMPA_TYPEDEF_LINUX_H diff --git a/inc/msprof/toolchain/prof_acl_api.h b/inc/msprof/toolchain/prof_acl_api.h new file mode 100644 index 0000000000000000000000000000000000000000..105748d560401ac70060b1652cf278fb32d11aee --- /dev/null +++ b/inc/msprof/toolchain/prof_acl_api.h @@ -0,0 +1,163 @@ +/** + * @file prof_acl_api.h + * + * Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. + * + */ + +#ifndef MSPROFILER_API_PROF_ACL_API_H +#define MSPROFILER_API_PROF_ACL_API_H + +#include +#include +#include "prof_data_config.h" + +constexpr int32_t PROF_MAX_DEV_NUM = 64; // maximum number of devices +constexpr int32_t PROF_DEFAULT_HOST_ID = PROF_MAX_DEV_NUM; + +/** + * @name ProfAicoreMetrics + * @brief aicore metrics enum + */ +enum ProfAicoreMetrics { + PROF_AICORE_ARITHMETIC_UTILIZATION = 0, + PROF_AICORE_PIPE_UTILIZATION = 1, + PROF_AICORE_MEMORY_BANDWIDTH = 2, + PROF_AICORE_L0B_AND_WIDTH = 3, + PROF_AICORE_RESOURCE_CONFLICT_RATIO = 4, + PROF_AICORE_MEMORY_UB = 5, + PROF_AICORE_L2_CACHE = 6, + PROF_AICORE_PIPE_EXECUTE_UTILIZATION = 7, + PROF_AICORE_METRICS_COUNT, + PROF_AICORE_NONE = 0xFF, +}; + +/** + * @name ProfConfig + * @brief struct of aclprofStart/aclprofStop + */ +struct ProfConfig { + uint32_t devNums; // length of device id list + uint32_t devIdList[PROF_MAX_DEV_NUM + 1]; // physical device id list + ProfAicoreMetrics aicoreMetrics; // aicore metric + uint64_t dataTypeConfig; // data type to start profiling +}; +using PROF_CONF_CONST_PTR = const ProfConfig *; + +/** + * @name ProfSubscribeConfig + * @brief config of subscribe api + */ +struct ProfSubscribeConfig { + bool timeInfo; // subscribe op time + ProfAicoreMetrics aicoreMetrics; // subscribe ai core metrics + void* fd; // pipe fd +}; +using 
PROF_SUB_CONF_CONST_PTR = const ProfSubscribeConfig *; + +/** + * @name aclprofSubscribeConfig + * @brief config of the common subscribe api; wraps a ProfSubscribeConfig + */ +struct aclprofSubscribeConfig { + struct ProfSubscribeConfig config; +}; +using ACL_PROF_SUB_CONFIG_PTR = aclprofSubscribeConfig *; +using ACL_PROF_SUB_CINFIG_CONST_PTR = const aclprofSubscribeConfig *; + +#if (defined(_WIN32) || defined(_WIN64) || defined(_MSC_VER)) +#define MSVP_PROF_API __declspec(dllexport) +#else +#define MSVP_PROF_API __attribute__((visibility("default"))) +#endif + +using PROFAPI_SUBSCRIBECONFIG_CONST_PTR = const void *; +namespace Msprofiler { +namespace Api { +/** + * @name ProfGetOpExecutionTime + * @brief get op execution time of specific part of data + * @param data [IN] data read from pipe + * @param len [IN] data length + * @param index [IN] index of part(op) + * @return op execution time (us) + */ +MSVP_PROF_API uint64_t ProfGetOpExecutionTime(const void *data, uint32_t len, uint32_t index); +} +} + +#ifdef __cplusplus +extern "C" { +#endif + +MSVP_PROF_API uint64_t ProfGetOpExecutionTime(const void *data, uint32_t len, uint32_t index); +MSVP_PROF_API int32_t ProfOpSubscribe(uint32_t devId, PROFAPI_SUBSCRIBECONFIG_CONST_PTR profSubscribeConfig); +MSVP_PROF_API int32_t ProfOpUnSubscribe(uint32_t devId); + +using Status = int32_t; +typedef struct aclprofSubscribeConfig aclprofSubscribeConfig1; +/** + * @ingroup AscendCL + * @brief subscribe profiling data of graph + * @param [in] graphId: the graph id subscribed + * @param [in] profSubscribeConfig: pointer to config of model subscribe + * @return Status result of function + */ +MSVP_PROF_API Status aclgrphProfGraphSubscribe(const uint32_t graphId, + const aclprofSubscribeConfig1 *profSubscribeConfig); +/** + * @ingroup AscendCL + * @brief unsubscribe profiling data of graph + * @param [in] graphId: the graph id subscribed + * @return Status result of function + */ +MSVP_PROF_API Status aclgrphProfGraphUnSubscribe(const uint32_t graphId); + +/** 
+ * @ingroup AscendCL + * @brief get graph id from subscription data + * + * @param opInfo [IN] pointer to subscription data + * @param opInfoLen [IN] memory size of subscription data + * @param index [IN] undocumented in the original header; presumably the index of the op record (cf. ProfGetOpExecutionTime) - TODO confirm + * @retval graph id of subscription data + * @retval 0 on failure + */ +MSVP_PROF_API size_t aclprofGetGraphId(const void *opInfo, size_t opInfoLen, uint32_t index); + +/** +* @ingroup AscendCL +* @brief set stamp payload +* +* +* @retval int status code; the exact values are not documented in this header +*/ +MSVP_PROF_API int aclprofSetStampPayload(void *stamp, const int32_t type, void *value); + +/** +* @ingroup AscendCL +* @brief set category and name +* +* +* @retval int status code; the exact values are not documented in this header +*/ +MSVP_PROF_API int aclprofSetCategoryName(uint32_t category, const char *categoryName); + +/** +* @ingroup AscendCL +* @brief set category to stamp +* +* +* @retval int status code; the exact values are not documented in this header +*/ +MSVP_PROF_API int aclprofSetStampCategory(void *stamp, uint32_t category); + +#ifdef __cplusplus +} +#endif + +#endif // MSPROFILER_API_PROF_ACL_API_H diff --git a/inc/msprof/toolchain/prof_api.h b/inc/msprof/toolchain/prof_api.h new file mode 100644 index 0000000000000000000000000000000000000000..1d92ee2721dff3745ad48044d2926d1437d0634a --- /dev/null +++ b/inc/msprof/toolchain/prof_api.h @@ -0,0 +1,275 @@ +/** + * @file prof_api.h + * + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
+ * + */ + +#ifndef PROF_API_H +#define PROF_API_H + +#include +#include +#include +#include "prof_common.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +#if (defined(_WIN32) || defined(_WIN64) || defined(_MSC_VER)) +#define MSVP_PROF_API __declspec(dllexport) +#else +#define MSVP_PROF_API __attribute__((visibility("default"))) +#endif + +/* + * @ingroup libprofapi + * @name profRegReporterCallback + * @brief register report callback interface for atlas + * @param [in] reporter: reporter callback handle + * @return 0:SUCCESS, !0:FAILED + */ +MSVP_PROF_API int32_t profRegReporterCallback(MsprofReportHandle reporter); + +/* + * @ingroup libprofapi + * @name profRegCtrlCallback + * @brief register control callback, interface for atlas + * @param [in] handle: control callback handle + * @return 0:SUCCESS, !0:FAILED + */ +MSVP_PROF_API int32_t profRegCtrlCallback(MsprofCtrlHandle handle); + +/* + * @ingroup libprofapi + * @name profRegDeviceStateCallback + * @brief register device state notify callback, interface for atlas + * @param [in] handle: handle of ProfNotifySetDevice + * @return 0:SUCCESS, !0:FAILED + */ +MSVP_PROF_API int32_t profRegDeviceStateCallback(MsprofSetDeviceHandle handle); + +/* + * @ingroup libprofapi + * @name profGetDeviceIdByGeModelIdx + * @brief get device id by model id, interface for atlas + * @param [in] modelIdx: ge model id + * @param [out] deviceId: device id + * @return 0:SUCCESS, !0:FAILED + */ +MSVP_PROF_API int32_t profGetDeviceIdByGeModelIdx(const uint32_t modelIdx, uint32_t *deviceId); + +/* + * @ingroup libprofapi + * @name profSetProfCommand + * @brief register set profiling command, interface for atlas + * @param [in] command: pointer to the profiling command data + * @param [in] len: length of the command data (presumably in bytes - TODO confirm) + * @return 0:SUCCESS, !0:FAILED + */ +MSVP_PROF_API int32_t profSetProfCommand(VOID_PTR command, uint32_t len); + +/* + * @ingroup libprofapi + * @name profSetStepInfo + * @brief set step info for torch, interface for 
atlas + * @param [in] indexId: id of iteration index + * @param [in] tagId: id of tag + * @param [in] stream: api of timestamp data + * @return 0:SUCCESS, !0:FAILED + */ +MSVP_PROF_API int32_t profSetStepInfo(const uint64_t indexId, const uint16_t tagId, void* const stream); + +/* + * @ingroup libprofapi + * @name MsprofRegisterProfileCallback + * @brief register profile callback by callback type, interface for atlas + * @param [in] callbackType: type of callback(reporter/ctrl/device state/command) + * @param [in] callback: callback of profile + * @param [in] len: callback length + * @return 0:SUCCESS, !0:FAILED + */ +MSVP_PROF_API int32_t MsprofRegisterProfileCallback(int32_t callbackType, VOID_PTR callback, uint32_t len); + +/** + * @ingroup libprofapi + * @name MsprofInit + * @brief Profiling module init + * @param [in] dataType: profiling type: ACL Env/ACL Json/GE Option + * @param [in] data: profiling switch data + * @param [in] dataLen: Length of data + * @return 0:SUCCESS, >0:FAILED + */ +MSVP_PROF_API int32_t MsprofInit(uint32_t dataType, VOID_PTR data, uint32_t dataLen); + +/** + * @ingroup libprofapi + * @name MsprofSetConfig + * @brief Set profiling config + * @return 0:SUCCESS, !0:FAILED + */ +MSVP_PROF_API int32_t MsprofSetConfig(uint32_t configType, const char *config, size_t configLength); + +/** + * @ingroup libprofapi + * @name MsprofRegisterCallback + * @brief register profiling switch callback for module + * @param [in] moduleId: module id + * @param [in] handle: profiling switch callback (ProfCommandHandle) + * @return 0:SUCCESS, !0:FAILED + */ +MSVP_PROF_API int32_t MsprofRegisterCallback(uint32_t moduleId, ProfCommandHandle handle); + +/** + * @ingroup libprofapi + * @name MsprofReportData + * @brief report profiling data of module + * @param [in] moduleId: module id + * @param [in] type: report type(init/uninit/max length/hash) + * @param [in] data: profiling data + * @param [in] len: length of profiling data + * @return 0:SUCCESS, !0:FAILED + */ 
+MSVP_PROF_API int32_t MsprofReportData(uint32_t moduleId, uint32_t type, VOID_PTR data, uint32_t len); + +/* + * @ingroup libprofapi + * @name MsprofReportApi + * @brief report api timestamp + * @param [in] agingFlag: 0 isn't aging, !0 is aging + * @param [in] api: api of timestamp data + * @return 0:SUCCESS, !0:FAILED + */ +MSVP_PROF_API int32_t MsprofReportApi(uint32_t agingFlag, const struct MsprofApi *api); + +/* + * @ingroup libprofapi + * @name MsprofReportEvent + * @brief report event timestamp + * @param [in] agingFlag: 0 isn't aging, !0 is aging + * @param [in] event: event of timestamp data + * @return 0:SUCCESS, !0:FAILED + */ +MSVP_PROF_API int32_t MsprofReportEvent(uint32_t agingFlag, const struct MsprofEvent *event); + +/* + * @ingroup libprofapi + * @name MsprofReportCompactInfo + * @brief report profiling compact information + * @param [in] agingFlag: 0 isn't aging, !0 is aging + * @param [in] data: profiling data of compact information + * @param [in] length: length of profiling data + * @return 0:SUCCESS, !0:FAILED + */ +MSVP_PROF_API int32_t MsprofReportCompactInfo(uint32_t agingFlag, const VOID_PTR data, uint32_t length); + +/* + * @ingroup libprofapi + * @name MsprofReportAdditionalInfo + * @brief report profiling additional information + * @param [in] agingFlag: 0 isn't aging, !0 is aging + * @param [in] data: profiling data of additional information + * @param [in] length: length of profiling data + * @return 0:SUCCESS, !0:FAILED + */ +MSVP_PROF_API int32_t MsprofReportAdditionalInfo(uint32_t agingFlag, const VOID_PTR data, uint32_t length); + +/* + * @ingroup libprofapi + * @name MsprofRegTypeInfo + * @brief reg mapping info of type id and type name + * @param [in] level: level is the report struct's level + * @param [in] typeId: type id is the report struct's type + * @param [in] typeName: label of type id for presenting user + * @return 0:SUCCESS, !0:FAILED + */ +MSVP_PROF_API int32_t MsprofRegTypeInfo(uint16_t level, uint32_t typeId, const 
char *typeName); + +/* + * @ingroup libprofapi + * @name MsprofGetHashId + * @brief return hash id of hash info + * @param [in] hashInfo: information to be hashed + * @param [in] length: the length of information to be hashed + * @return hash id + */ +MSVP_PROF_API uint64_t MsprofGetHashId(const char *hashInfo, size_t length); + +/** + * @ingroup libprofapi + * @name MsprofSetDeviceIdByGeModelIdx + * @brief insert device id by model id + * @param [in] geModelIdx: ge model id + * @param [in] deviceId: device id + * @return 0:SUCCESS, !0:FAILED + */ +MSVP_PROF_API int32_t MsprofSetDeviceIdByGeModelIdx(const uint32_t geModelIdx, const uint32_t deviceId); + +/** + * @ingroup libprofapi + * @name MsprofUnsetDeviceIdByGeModelIdx + * @brief delete device id by model id + * @param [in] geModelIdx: ge model id + * @param [in] deviceId: device id + * @return 0:SUCCESS, !0:FAILED + */ +MSVP_PROF_API int32_t MsprofUnsetDeviceIdByGeModelIdx(const uint32_t geModelIdx, const uint32_t deviceId); + +/** + * @ingroup libprofapi + * @name MsprofNotifySetDevice + * @brief notify a device state change (see isOpen); description inferred from the signature - TODO confirm + * @param [in] chipId: multi die's chip + * @param [in] deviceId: device id + * @param [in] isOpen: device is open + * @return 0:SUCCESS, !0:FAILED + */ +MSVP_PROF_API int32_t MsprofNotifySetDevice(uint32_t chipId, uint32_t deviceId, bool isOpen); + +/** + * @ingroup libprofapi + * @name MsprofFinalize + * @brief profiling finalize + * @return 0:SUCCESS, !0:FAILED + */ +MSVP_PROF_API int32_t MsprofFinalize(void); + +/* + * @ingroup libprofapi + * @name MsprofSysCycleTime + * @brief get system cycle time of CPU + * @return system cycle time of CPU + */ +MSVP_PROF_API uint64_t MsprofSysCycleTime(void); + +/** + * @ingroup libprofapi + * @name MsprofStart + * @brief profiling start + * @param dataType: MsprofCtrlCallbackType + * @param data: MsprofConfig + * @param length: length of MsprofConfig + * @return 0:SUCCESS, !0:FAILED + */ +int32_t MsprofStart(uint32_t dataType, const void 
*data, uint32_t length); + +/** + * @ingroup libprofapi + * @name MsprofStop + * @brief profiling stop + * @param dataType: MsprofCtrlCallbackType + * @param data: MsprofConfig + * @param length: length of MsprofConfig + * @return 0:SUCCESS, !0:FAILED + */ +int32_t MsprofStop(uint32_t dataType, const void *data, uint32_t length); +#ifdef __cplusplus +} +#endif + +#endif diff --git a/inc/msprof/toolchain/prof_common.h b/inc/msprof/toolchain/prof_common.h new file mode 100644 index 0000000000000000000000000000000000000000..d5dcdb9db40f54f51aba960d5bb0e2b3325c1b4a --- /dev/null +++ b/inc/msprof/toolchain/prof_common.h @@ -0,0 +1,1184 @@ +/** + * @file prof_common.h + * + * Copyright (c) Huawei Technologies Co., Ltd. 2019-2024. All rights reserved. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. + * + */ +#ifndef MSPROFILER_PROF_COMMON_H +#define MSPROFILER_PROF_COMMON_H + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +#define MSPROF_DATA_HEAD_MAGIC_NUM 0x5A5AU +#define MSPROF_REPORT_DATA_MAGIC_NUM 0x5A5AU +#define MSPROF_TASK_TIME_L0 0x00000800ULL // means PROF_TASK_TIME +#define MSPROF_EVENT_FLAG 0xFFFFFFFFFFFFFFFFULL +typedef void* VOID_PTR; +typedef const void* ConstVoidPtr; +typedef int32_t (*MsprofReportHandle)(uint32_t moduleId, uint32_t type, VOID_PTR data, uint32_t len); +typedef int32_t (*MsprofCtrlHandle)(uint32_t type, VOID_PTR data, uint32_t len); +typedef int32_t (*MsprofSetDeviceHandle)(VOID_PTR data, uint32_t len); +typedef int32_t (*MsprofCtrlCallback)(uint32_t type, void *data, uint32_t len); +typedef int32_t (*MsprofReporterCallback)(uint32_t moduleId, uint32_t type, void *data, uint32_t len); + +/** + * @name ProfCommandHandle + * @brief callback to start/stop profiling + * @param type [IN] enum call back type + * @param data [IN] callback data + * @param len [IN] 
callback data size + * @return enum MsprofErrorCode + */ +typedef int32_t (*ProfCommandHandle)(uint32_t type, VOID_PTR data, uint32_t len); + +/* Msprof report level */ +#define MSPROF_REPORT_PYTORCH_LEVEL 30000U +#define MSPROF_REPORT_PTA_LEVEL 25000U +#define MSPROF_REPORT_TX_LEVEL 20500U +#define MSPROF_REPORT_ACL_LEVEL 20000U +#define MSPROF_REPORT_MODEL_LEVEL 15000U +#define MSPROF_REPORT_NODE_LEVEL 10000U +#define MSPROF_REPORT_AICPU_LEVEL 6000U +#define MSPROF_REPORT_HCCL_NODE_LEVEL 5500U +#define MSPROF_REPORT_RUNTIME_LEVEL 5000U +#define MSPROF_REPORT_PROF_LEVEL 4500U +#define MSPROF_REPORT_DPU_LEVEL 4000U +#define MSPROF_REPORT_AIC_LEVEL 3000U + +/* Msprof report type of tx(20500) level, offset: 0x000000 */ +#define MSPROF_REPORT_TX_BASE_TYPE 0x000000U + +/* Msprof report type of acl(20000) level(acl), offset: 0x010000 */ +#define MSPROF_REPORT_ACL_OP_BASE_TYPE 0x010000U +#define MSPROF_REPORT_ACL_MODEL_BASE_TYPE 0x020000U +#define MSPROF_REPORT_ACL_RUNTIME_BASE_TYPE 0x030000U +#define MSPROF_REPORT_ACL_OTHERS_BASE_TYPE 0x040000U + + +/* Msprof report type of acl(20000) level(host api), offset: 0x050000 */ +#define MSPROF_REPORT_ACL_NN_BASE_TYPE 0x050000U +#define MSPROF_REPORT_ACL_ASCENDC_TYPE 0x060000U +#define MSPROF_REPORT_ACL_HOST_HCCL_BASE_TYPE 0x070000U +#define MSPROF_REPORT_ACL_DVPP_BASE_TYPE 0x090000U +#define MSPROF_REPORT_ACL_GRAPH_BASE_TYPE 0x0A0000U +#define MSPROF_REPORT_ACL_ATB_BASE_TYPE 0x0B0000U + +/* Msprof report type of model(15000) level, offset: 0x000000 */ +#define MSPROF_REPORT_MODEL_GRAPH_ID_MAP_TYPE 0U /* type info: graph_id_map */ +#define MSPROF_REPORT_MODEL_EXECUTE_TYPE 1U /* type info: execute */ +#define MSPROF_REPORT_MODEL_LOAD_TYPE 2U /* type info: load */ +#define MSPROF_REPORT_MODEL_INPUT_COPY_TYPE 3U /* type info: IntputCopy */ +#define MSPROF_REPORT_MODEL_OUTPUT_COPY_TYPE 4U /* type info: OutputCopy */ +#define MSPROF_REPORT_MODEL_LOGIC_STREAM_TYPE 7U /* type info: logic_stream_info */ +#define 
MSPROF_REPORT_MODEL_EXEOM_TYPE 8U /* type info: exeom */ +#define MSPROF_REPORT_MODEL_UDF_BASE_TYPE 0x010000U /* type info: udf_info */ +#define MSPROF_REPORT_MODEL_AICPU_BASE_TYPE 0x020000U /* type info: aicpu */ + +/* Msprof report type of node(10000) level, offset: 0x000000 */ +#define MSPROF_REPORT_NODE_BASIC_INFO_TYPE 0U /* type info: node_basic_info */ +#define MSPROF_REPORT_NODE_TENSOR_INFO_TYPE 1U /* type info: tensor_info */ +#define MSPROF_REPORT_NODE_FUSION_OP_INFO_TYPE 2U /* type info: funsion_op_info */ +#define MSPROF_REPORT_NODE_CONTEXT_ID_INFO_TYPE 4U /* type info: context_id_info */ +#define MSPROF_REPORT_NODE_LAUNCH_TYPE 5U /* type info: launch */ +#define MSPROF_REPORT_NODE_TASK_MEMORY_TYPE 6U /* type info: task_memory_info */ +#define MSPROF_REPORT_NODE_HOST_OP_EXEC_TYPE 8U /* type info: op exec */ +#define MSPROF_REPORT_NODE_ATTR_INFO_TYPE 9U /* type info: node_attr_info */ +#define MSPROF_REPORT_NODE_HCCL_OP_INFO_TYPE 10U /* type info: hccl_op_info */ +#define MSPROF_REPORT_NODE_STATIC_OP_MEM_TYPE 11U /* type info: static_op_mem */ +#define MSPROF_REPORT_NODE_MC2_COMMINFO_TYPE 12U /* type info: mc2_comm_info */ + +/* Msprof report type of node(10000) level(ge api), offset: 0x010000 */ +#define MSPROF_REPORT_NODE_GE_API_BASE_TYPE 0x010000U /* type info: ge api */ +#define MSPROF_REPORT_NODE_HCCL_BASE_TYPE 0x020000U /* type info: hccl api */ +#define MSPROF_REPORT_NODE_DVPP_API_BASE_TYPE 0x030000U /* type info: dvpp api */ +/* Msprof report type of aicpu(6000), offset: 0x000000 */ +#define MSPROF_REPORT_AICPU_NODE_TYPE 0U /* type info: DATA_PREPROCESS.AICPU */ +#define MSPROF_REPORT_AICPU_DP_TYPE 1U /* type info: DATA_PREPROCESS.DP */ +#define MSPROF_REPORT_AICPU_MODEL_TYPE 2U /* type info: DATA_PREPROCESS.AICPU_MODEL */ +#define MSPROF_REPORT_AICPU_MI_TYPE 3U /* type info: DATA_PREPROCESS.AICPUMI */ +#define MSPROF_REPORT_AICPU_MC2_EXECUTE_COMM_TIME 4U /* communication timing info */ +#define MSPROF_REPORT_AICPU_MC2_EXECUTE_COMP_TIME 5U /* computation timing info */ +#define 
MSPROF_REPORT_AICPU_MC2_HCCL_INFO 6U /* task info */ +#define MSPROF_REPORT_AICPU_HCCL_OP_INFO 10U /* task info: hccl_op_info */ +#define MSPROF_REPORT_AICPU_FILP_TASK 11U /* flip task */ +#define MSPROF_REPORT_AICPU_HCCL_FLAG_TASK 12U /* first and last tasks of the hccl main stream expanded on aicpu */ + +/* Msprof report type of hccl(5500) level(op api), offset: 0x010000 */ +#define MSPROF_REPORT_HCCL_NODE_BASE_TYPE 0x010000U +#define MSPROF_REPORT_HCCL_MASTER_TYPE 0x010001U +#define MSPROF_REPORT_HCCL_SLAVE_TYPE 0x010002U + +/* Msprof report type of dpu(4000) level, offset: 0x000000 */ +#define MSPROF_REPORT_DPU_TRACK_TYPE 0U /* type info: dpu_track */ + +/* use with AdprofCheckFeatureIsOn */ +#define ADPROF_TASK_TIME_L0 0x00000008ULL +#define ADPROF_TASK_TIME_L1 0x00000010ULL +#define ADPROF_TASK_TIME_L2 0x00000020ULL + +/* Msprof report type of profiling(4500) */ +#define MSPROF_REPORT_DIAGNOSTIC_INFO_TYPE 0x010000U + +/* Msprof report ai core type of profiling(3000) */ +#define MSPROF_REPORT_AIC_TIMESTAMP_TYPE 0x0U + +enum ProfileCallbackType { + PROFILE_CTRL_CALLBACK = 0, + PROFILE_DEVICE_STATE_CALLBACK, + PROFILE_REPORT_API_CALLBACK, + PROFILE_REPORT_EVENT_CALLBACK, + PROFILE_REPORT_COMPACT_CALLBACK, + PROFILE_REPORT_ADDITIONAL_CALLBACK, + PROFILE_REPORT_REG_TYPE_INFO_CALLBACK, + PROFILE_REPORT_GET_HASH_ID_CALLBACK, + PROFILE_HOST_FREQ_IS_ENABLE_CALLBACK, + PROFILE_REPORT_API_C_CALLBACK, + PROFILE_REPORT_EVENT_C_CALLBACK, + PROFILE_REPORT_REG_TYPE_INFO_C_CALLBACK, + PROFILE_REPORT_GET_HASH_ID_C_CALLBACK, + PROFILE_HOST_FREQ_IS_ENABLE_C_CALLBACK, +}; + +enum MsprofDataTag { + MSPROF_ACL_DATA_TAG = 0, // acl data tag, range: 0~19 + MSPROF_GE_DATA_TAG_MODEL_LOAD = 20, // ge data tag, range: 20~39 + MSPROF_GE_DATA_TAG_FUSION = 21, + MSPROF_GE_DATA_TAG_INFER = 22, + MSPROF_GE_DATA_TAG_TASK = 23, + MSPROF_GE_DATA_TAG_TENSOR = 24, + MSPROF_GE_DATA_TAG_STEP = 25, + MSPROF_GE_DATA_TAG_ID_MAP = 26, + MSPROF_GE_DATA_TAG_HOST_SCH = 27, + MSPROF_RUNTIME_DATA_TAG_API = 40, // runtime data tag, range: 40~59 
+ MSPROF_RUNTIME_DATA_TAG_TRACK = 41, + MSPROF_AICPU_DATA_TAG = 60, // aicpu data tag, range: 60~79 + MSPROF_AICPU_MODEL_TAG = 61, + MSPROF_HCCL_DATA_TAG = 80, // hccl data tag, range: 80~99 + MSPROF_DP_DATA_TAG = 100, // dp data tag, range: 100~119 + MSPROF_MSPROFTX_DATA_TAG = 120, // msproftx data tag, range: 120~139 + MSPROF_DATA_TAG_MAX = 65536, // exclusive upper bound: data tag values are stored as uint16_t +}; + +enum MsprofMindsporeNodeTag { + GET_NEXT_DEQUEUE_WAIT = 1, +}; + +/** + * @brief struct of mixed data + */ +#define MSPROF_MIX_DATA_RESERVE_BYTES 7 +#define MSPROF_MIX_DATA_STRING_LEN 120 +enum MsprofMixDataType { + MSPROF_MIX_DATA_HASH_ID = 0, + MSPROF_MIX_DATA_STRING, +}; +struct MsprofMixData { + uint8_t type; // MsprofMixDataType + uint8_t rsv[MSPROF_MIX_DATA_RESERVE_BYTES]; + union { + uint64_t hashId; + char dataStr[MSPROF_MIX_DATA_STRING_LEN]; + } data; +}; + +#define PATH_LEN_MAX 1023 +#define PARAM_LEN_MAX 4095 +struct MsprofCommandHandleParams { + uint32_t pathLen; + uint32_t storageLimit; // MB + uint32_t profDataLen; + char path[PATH_LEN_MAX + 1]; + char profData[PARAM_LEN_MAX + 1]; +}; + +/** + * @brief profiling command info + */ +#define MSPROF_MAX_DEV_NUM 64 +struct MsprofCommandHandle { + uint64_t profSwitch; + uint64_t profSwitchHi; + uint32_t devNums; + uint32_t devIdList[MSPROF_MAX_DEV_NUM]; + uint32_t modelId; + uint32_t type; + uint32_t cacheFlag; + struct MsprofCommandHandleParams params; +}; + +/** + * @brief profiling MsprofStart config + */ +#define MAX_DUMP_PATH_LEN 1024 +#define MAX_SAMPLE_CONFIG_LEN 4096 +struct MsprofConfig { + uint64_t profSwitch; + uint64_t profSwitchHi; + uint32_t devNums; + uint32_t devIdList[MSPROF_MAX_DEV_NUM + 1]; + uint32_t modelId; + uint32_t type; + uint32_t cacheFlag; + uint32_t storageLimit; + uint32_t metrics; + uintptr_t fd; + char dumpPath[MAX_DUMP_PATH_LEN]; + char sampleConfig[MAX_SAMPLE_CONFIG_LEN]; +}; + +/** + * @brief struct of data reported by acl + */ +#define MSPROF_ACL_DATA_RESERVE_BYTES 32 +#define 
MSPROF_ACL_API_NAME_LEN 64 +enum MsprofAclApiType { + MSPROF_ACL_API_TYPE_OP = 1, + MSPROF_ACL_API_TYPE_MODEL, + MSPROF_ACL_API_TYPE_RUNTIME, + MSPROF_ACL_API_TYPE_OTHERS, +}; +struct MsprofAclProfData { +#ifdef __cplusplus + uint16_t magicNumber = MSPROF_DATA_HEAD_MAGIC_NUM; + uint16_t dataTag = MSPROF_ACL_DATA_TAG; +#else + uint16_t magicNumber; + uint16_t dataTag; +#endif + uint32_t apiType; // enum MsprofAclApiType + uint64_t beginTime; + uint64_t endTime; + uint32_t processId; + uint32_t threadId; + char apiName[MSPROF_ACL_API_NAME_LEN]; + uint8_t reserve[MSPROF_ACL_DATA_RESERVE_BYTES]; +}; + +/** + * @brief structs of data reported by GE (one struct per MSPROF_GE_DATA_TAG_* value) + */ +#define MSPROF_GE_MODELLOAD_DATA_RESERVE_BYTES 104 +struct MsprofGeProfModelLoadData { +#ifdef __cplusplus + uint16_t magicNumber = MSPROF_DATA_HEAD_MAGIC_NUM; + uint16_t dataTag = MSPROF_GE_DATA_TAG_MODEL_LOAD; +#else + uint16_t magicNumber; + uint16_t dataTag; +#endif + uint32_t modelId; + struct MsprofMixData modelName; + uint64_t startTime; + uint64_t endTime; + uint8_t reserve[MSPROF_GE_MODELLOAD_DATA_RESERVE_BYTES]; +}; + +#define MSPROF_GE_FUSION_DATA_RESERVE_BYTES 8 +#define MSPROF_GE_FUSION_OP_NUM 8 +struct MsprofGeProfFusionData { +#ifdef __cplusplus + uint16_t magicNumber = MSPROF_DATA_HEAD_MAGIC_NUM; + uint16_t dataTag = MSPROF_GE_DATA_TAG_FUSION; +#else + uint16_t magicNumber; + uint16_t dataTag; +#endif + uint32_t modelId; + struct MsprofMixData fusionName; + uint64_t inputMemSize; + uint64_t outputMemSize; + uint64_t weightMemSize; + uint64_t workspaceMemSize; + uint64_t totalMemSize; + uint64_t fusionOpNum; + uint64_t fusionOp[MSPROF_GE_FUSION_OP_NUM]; + uint8_t reserve[MSPROF_GE_FUSION_DATA_RESERVE_BYTES]; +}; + +#define MSPROF_GE_INFER_DATA_RESERVE_BYTES 64 +struct MsprofGeProfInferData { +#ifdef __cplusplus + uint16_t magicNumber = MSPROF_DATA_HEAD_MAGIC_NUM; + uint16_t dataTag = MSPROF_GE_DATA_TAG_INFER; +#else + uint16_t magicNumber; + uint16_t dataTag; +#endif + uint32_t modelId; + struct 
MsprofMixData modelName; + uint32_t requestId; + uint32_t threadId; + uint64_t inputDataStartTime; + uint64_t inputDataEndTime; + uint64_t inferStartTime; + uint64_t inferEndTime; + uint64_t outputDataStartTime; + uint64_t outputDataEndTime; + uint8_t reserve[MSPROF_GE_INFER_DATA_RESERVE_BYTES]; +}; + +#define MSPROF_GE_TASK_DATA_RESERVE_BYTES 12 +#define MSPROF_GE_OP_TYPE_LEN 56 +enum MsprofGeTaskType { + MSPROF_GE_TASK_TYPE_AI_CORE = 0, + MSPROF_GE_TASK_TYPE_AI_CPU, + MSPROF_GE_TASK_TYPE_AIV, + MSPROF_GE_TASK_TYPE_WRITE_BACK, + MSPROF_GE_TASK_TYPE_MIX_AIC, + MSPROF_GE_TASK_TYPE_MIX_AIV, + MSPROF_GE_TASK_TYPE_FFTS_PLUS, + MSPROF_GE_TASK_TYPE_DSA, + MSPROF_GE_TASK_TYPE_DVPP, + MSPROF_GE_TASK_TYPE_HCCL, + MSPROF_GE_TASK_TYPE_FUSION, + MSPROF_GE_TASK_TYPE_INVALID +}; + +enum MsprofGeShapeType { + MSPROF_GE_SHAPE_TYPE_STATIC = 0, + MSPROF_GE_SHAPE_TYPE_DYNAMIC, +}; +struct MsprofGeOpType { + uint8_t type; // MsprofMixDataType + uint8_t rsv[MSPROF_MIX_DATA_RESERVE_BYTES]; + union { + uint64_t hashId; + char dataStr[MSPROF_GE_OP_TYPE_LEN]; + } data; +}; +struct MsprofGeProfTaskData { +#ifdef __cplusplus + uint16_t magicNumber = MSPROF_DATA_HEAD_MAGIC_NUM; + uint16_t dataTag = MSPROF_GE_DATA_TAG_TASK; +#else + uint16_t magicNumber; + uint16_t dataTag; +#endif + uint32_t taskType; // MsprofGeTaskType + struct MsprofMixData opName; + struct MsprofGeOpType opType; + uint64_t curIterNum; + uint64_t timeStamp; + uint32_t shapeType; // MsprofGeShapeType + uint32_t blockDims; + uint32_t modelId; + uint32_t streamId; + uint32_t taskId; + uint32_t threadId; + uint32_t contextId; + uint8_t reserve[MSPROF_GE_TASK_DATA_RESERVE_BYTES]; +}; + +#define MSPROF_GE_TENSOR_DATA_RESERVE_BYTES 8 +#define MSPROF_GE_TENSOR_DATA_SHAPE_LEN 8 +#define MSPROF_GE_TENSOR_DATA_NUM 5 +enum MsprofGeTensorType { + MSPROF_GE_TENSOR_TYPE_INPUT = 0, + MSPROF_GE_TENSOR_TYPE_OUTPUT, +}; +struct MsprofGeTensorData { + uint32_t tensorType; // MsprofGeTensorType + uint32_t format; + uint32_t dataType; + 
uint32_t shape[MSPROF_GE_TENSOR_DATA_SHAPE_LEN]; +}; + +struct MsprofGeProfTensorData { +#ifdef __cplusplus + uint16_t magicNumber = MSPROF_DATA_HEAD_MAGIC_NUM; + uint16_t dataTag = MSPROF_GE_DATA_TAG_TENSOR; +#else + uint16_t magicNumber; + uint16_t dataTag; +#endif + uint32_t modelId; + uint64_t curIterNum; + uint32_t streamId; + uint32_t taskId; + uint32_t tensorNum; + struct MsprofGeTensorData tensorData[MSPROF_GE_TENSOR_DATA_NUM]; + uint8_t reserve[MSPROF_GE_TENSOR_DATA_RESERVE_BYTES]; +}; + +#define MSPROF_GE_STEP_DATA_RESERVE_BYTES 27 +enum MsprofGeStepTag { + MSPROF_GE_STEP_TAG_BEGIN = 0, + MSPROF_GE_STEP_TAG_END, +}; +struct MsprofGeProfStepData { +#ifdef __cplusplus + uint16_t magicNumber = MSPROF_DATA_HEAD_MAGIC_NUM; + uint16_t dataTag = MSPROF_GE_DATA_TAG_STEP; +#else + uint16_t magicNumber; + uint16_t dataTag; +#endif + uint32_t modelId; + uint32_t streamId; + uint32_t taskId; + uint64_t timeStamp; + uint64_t curIterNum; + uint32_t threadId; + uint8_t tag; // MsprofGeStepTag + uint8_t reserve[MSPROF_GE_STEP_DATA_RESERVE_BYTES]; +}; + +#define MSPROF_GE_ID_MAP_DATA_RESERVE_BYTES 6 +struct MsprofGeProfIdMapData { +#ifdef __cplusplus + uint16_t magicNumber = MSPROF_DATA_HEAD_MAGIC_NUM; + uint16_t dataTag = MSPROF_GE_DATA_TAG_ID_MAP; +#else + uint16_t magicNumber; + uint16_t dataTag; +#endif + uint32_t graphId; + uint32_t modelId; + uint32_t sessionId; + uint64_t timeStamp; + uint16_t mode; + uint8_t reserve[MSPROF_GE_ID_MAP_DATA_RESERVE_BYTES]; +}; + +#define MSPROF_GE_HOST_SCH_DATA_RESERVE_BYTES 24 +struct MsprofGeProfHostSchData { +#ifdef __cplusplus + uint16_t magicNumber = MSPROF_DATA_HEAD_MAGIC_NUM; + uint16_t dataTag = MSPROF_GE_DATA_TAG_HOST_SCH; +#else + uint16_t magicNumber; + uint16_t dataTag; +#endif + uint32_t threadId; // recorded in the start event + uint64_t element; + uint64_t event; + uint64_t startTime; // recorded in the start event + uint64_t endTime; // recorded in the end event + uint8_t reserve[MSPROF_GE_HOST_SCH_DATA_RESERVE_BYTES]; +}; + +/** + * 
@brief struct of data reported by RunTime + */ +#define MSPROF_AICPU_DATA_RESERVE_BYTES 9 +struct MsprofAicpuProfData { +#ifdef __cplusplus + uint16_t magicNumber = MSPROF_DATA_HEAD_MAGIC_NUM; + uint16_t dataTag = MSPROF_AICPU_DATA_TAG; +#else + uint16_t magicNumber; + uint16_t dataTag; +#endif + uint16_t streamId; + uint16_t taskId; + uint64_t runStartTime; + uint64_t runStartTick; + uint64_t computeStartTime; + uint64_t memcpyStartTime; + uint64_t memcpyEndTime; + uint64_t runEndTime; + uint64_t runEndTick; + uint32_t threadId; + uint32_t deviceId; + uint64_t submitTick; + uint64_t scheduleTick; + uint64_t tickBeforeRun; + uint64_t tickAfterRun; + uint32_t kernelType; + uint32_t dispatchTime; + uint32_t totalTime; + uint16_t fftsThreadId; + uint8_t version; + uint8_t reserve[MSPROF_AICPU_DATA_RESERVE_BYTES]; +}; + +struct MsprofAicpuModelProfData { +#ifdef __cplusplus + uint16_t magicNumber = MSPROF_DATA_HEAD_MAGIC_NUM; + uint16_t dataTag = MSPROF_AICPU_MODEL_TAG; +#else + uint16_t magicNumber; + uint16_t dataTag; +#endif + uint32_t rsv; // Ensure 8-byte alignment + uint64_t timeStamp; + uint64_t indexId; + uint32_t modelId; + uint16_t tagId; + uint16_t rsv1; + uint64_t eventId; + uint8_t reserve[24]; +}; + +/** + * @brief struct of data reported by DP + */ +#define MSPROF_DP_DATA_RESERVE_BYTES 16 +#define MSPROF_DP_DATA_ACTION_LEN 16 +#define MSPROF_DP_DATA_SOURCE_LEN 64 +#define MSPROF_CTX_ID_MAX_NUM 55 + +struct MsprofDpProfData { +#ifdef __cplusplus + uint16_t magicNumber = MSPROF_DATA_HEAD_MAGIC_NUM; + uint16_t dataTag = MSPROF_DP_DATA_TAG; +#else + uint16_t magicNumber; + uint16_t dataTag; +#endif + uint32_t rsv; // Ensure 8-byte alignment + uint64_t timeStamp; + char action[MSPROF_DP_DATA_ACTION_LEN]; + char source[MSPROF_DP_DATA_SOURCE_LEN]; + uint64_t index; + uint64_t size; + uint8_t reserve[MSPROF_DP_DATA_RESERVE_BYTES]; +}; + +struct MsprofAicpuNodeAdditionalData { + uint16_t streamId; + uint16_t taskId; + uint64_t runStartTime; + uint64_t 
runStartTick; + uint64_t computeStartTime; + uint64_t memcpyStartTime; + uint64_t memcpyEndTime; + uint64_t runEndTime; + uint64_t runEndTick; + uint32_t threadId; + uint32_t deviceId; + uint64_t submitTick; + uint64_t scheduleTick; + uint64_t tickBeforeRun; + uint64_t tickAfterRun; + uint32_t kernelType; + uint32_t dispatchTime; + uint32_t totalTime; + uint16_t fftsThreadId; + uint8_t version; + uint8_t reserve[MSPROF_AICPU_DATA_RESERVE_BYTES]; +}; + +struct MsprofAicpuModelAdditionalData { + uint64_t indexId; + uint32_t modelId; + uint16_t tagId; + uint16_t rsv1; + uint64_t eventId; + uint8_t reserve[24]; +}; + +struct MsprofAicpuDpAdditionalData { + char action[MSPROF_DP_DATA_ACTION_LEN]; + char source[MSPROF_DP_DATA_SOURCE_LEN]; + uint64_t index; + uint64_t size; + uint8_t reserve[MSPROF_DP_DATA_RESERVE_BYTES]; +}; + +struct MsprofAicpuMiAdditionalData { + uint32_t nodeTag; // MsprofMindsporeNodeTag:1 + uint32_t reserve; + uint64_t queueSize; + uint64_t runStartTime; + uint64_t runEndTime; +}; + +// AICPU kfc算子执行时间 +struct AicpuKfcProfCommTurn { + uint64_t waitNotifyStartTime; // 开始等待通信参数 + uint64_t kfcAlgExeStartTime; // 开始通信算法执行 + uint64_t sendTaskStartTime; // 开始下发task + uint64_t waitActiveStartTime; // 开始等待激活 + uint64_t acitveStartTime; // 开始激活处理 + uint64_t waitExeEndStartTime; // 开始等待任务执行结束 + uint64_t rtsqExeEndTime; // 任务执行结束时间 + uint64_t dataLen; // 本轮通信数据长度 + uint32_t deviceId; + uint16_t streamId; + uint16_t taskId; + uint8_t version; + uint8_t commTurn; // 总通信轮次 + uint8_t currentTurn; + uint8_t reserve[5]; +}; + +// Aicore算子执行时间 +struct AicpuKfcProfComputeTurn { + uint64_t waitComputeStartTime; // 开始等待计算 + uint64_t computeStartTime; // 开始计算 + uint64_t computeExeEndTime; // 计算执行结束 + uint64_t dataLen; // 本轮计算数据长度 + uint32_t deviceId; + uint16_t streamId; + uint16_t taskId; + uint8_t version; + uint8_t computeTurn; // 总计算轮次 + uint8_t currentTurn; + uint8_t reserve[5]; +}; + +// 翻转task的上报 +struct MsporfAicpuFlipTask { + uint16_t streamId; + uint16_t 
taskId; // 值无特殊要求 + uint32_t flipNum; + uint32_t reserve[2]; +}; + +struct MsprofAicpuHcclMainStreamTask { + uint16_t aicpuStreamId; + uint16_t aicpuTaskId; + uint16_t streamId; + uint16_t taskId; + uint16_t type; // 0是头 1是尾 + uint16_t reserve[3]; +}; + +/** + * @brief struct of data reported by HCCL + */ +#pragma pack(4) +struct MsprofHcclProfNotify { + uint32_t taskID; + uint64_t notifyID; + uint32_t stage; + uint32_t remoteRank; + uint32_t transportType; + uint32_t role; // role {0: dst, 1:src} + double durationEstimated; +}; + +struct MsprofHcclProfReduce { + uint32_t taskID; + uint64_t src; + uint64_t dst; + uint64_t size; + uint32_t op; // {0: sum, 1: mul, 2: max, 3: min} + uint32_t dataType; // data type {0: INT8, 1: INT16, 2: INT32, 3: FP16, 4:FP32, 5:INT64, 6:UINT64} + uint32_t linkType; // link type {0: 'OnChip', 1: 'HCCS', 2: 'PCIe', 3: 'RoCE'} + uint32_t remoteRank; + uint32_t transportType; // transport type {0: SDMA, 1: RDMA, 2:LOCAL} + uint32_t role; // role {0: dst, 1:src} + double durationEstimated; +}; + +struct MsprofHcclProfRDMA { + uint32_t taskID; + uint64_t src; + uint64_t dst; + uint64_t size; + uint64_t notifyID; + uint32_t linkType; // link type {0: 'OnChip', 1: 'HCCS', 2: 'PCIe', 3: 'RoCE'} + uint32_t remoteRank; + uint32_t transportType; // transport type {0: RDMA, 1:SDMA, 2:LOCAL} + uint32_t role; // role {0: dst, 1:src} + uint32_t type; // RDMA type {0: RDMASendNotify, 1:RDMASendPayload} + double durationEstimated; +}; + +struct MsprofHcclProfMemcpy { + uint32_t taskID; + uint64_t src; + uint64_t dst; + uint64_t size; + uint64_t notifyID; + uint32_t linkType; // link type {0: 'OnChip', 1: 'HCCS', 2: 'PCIe', 3: 'RoCE'} + uint32_t remoteRank; + uint32_t transportType; // transport type {0: RDMA, 1:SDMA, 2:LOCAL} + uint32_t role; // role {0: dst, 1:src} + double durationEstimated; +}; + +struct MsprofHcclProfStageStep { + uint32_t rank; + uint32_t rankSize; +}; + +struct MsprofHcclProfFlag { + uint64_t cclTag; + uint64_t groupName; + 
uint32_t localRank; + uint32_t workFlowMode; +}; + +#define MSPROF_HCCL_INVALID_UINT 0xFFFFFFFFU +struct MsprofHcclInfo { + uint64_t itemId; + uint64_t cclTag; + uint64_t groupName; + uint32_t localRank; + uint32_t remoteRank; + uint32_t rankSize; + uint32_t workFlowMode; + uint32_t planeID; + uint32_t ctxId; + uint64_t notifyID; + uint32_t stage; + uint32_t role; // role {0: dst, 1:src} + double durationEstimated; + uint64_t srcAddr; + uint64_t dstAddr; + uint64_t dataSize; // bytes + uint32_t opType; // {0: sum, 1: mul, 2: max, 3: min} + uint32_t dataType; // data type {0: INT8, 1: INT16, 2: INT32, 3: FP16, 4:FP32, 5:INT64, 6:UINT64} + uint32_t linkType; // link type {0: 'OnChip', 1: 'HCCS', 2: 'PCIe', 3: 'RoCE'} + uint32_t transportType; // transport type {0: SDMA, 1: RDMA, 2:LOCAL} + uint32_t rdmaType; // RDMA type {0: RDMASendNotify, 1:RDMASendPayload} + uint32_t reserve2; +#ifdef __cplusplus + MsprofHcclInfo() : role(MSPROF_HCCL_INVALID_UINT), opType(MSPROF_HCCL_INVALID_UINT), + dataType(MSPROF_HCCL_INVALID_UINT), linkType(MSPROF_HCCL_INVALID_UINT), + transportType(MSPROF_HCCL_INVALID_UINT), rdmaType(MSPROF_HCCL_INVALID_UINT) + { + } +#endif +}; + +struct MsprofAicpuMC2HcclInfo { + uint64_t itemId; + uint64_t cclTag; + uint64_t groupName; + uint32_t localRank; + uint32_t remoteRank; + uint32_t rankSize; + uint32_t workFlowMode; + uint32_t planeID; + uint32_t ctxId; + uint64_t notifyID; + uint32_t stage; + uint32_t role; // role {0: dst, 1:src} + double durationEstimated; + uint64_t srcAddr; + uint64_t dstAddr; + uint64_t dataSize; // bytes + uint32_t opType; // {0: sum, 1: mul, 2: max, 3: min} + uint32_t dataType; // data type {0: INT8, 1: INT16, 2: INT32, 3: FP16, 4:FP32, 5:INT64, 6:UINT64} + uint32_t linkType; // link type {0: 'OnChip', 1: 'HCCS', 2: 'PCIe', 3: 'RoCE'} + uint32_t transportType; // transport type {0: SDMA, 1: RDMA, 2:LOCAL} + uint32_t rdmaType; // RDMA type {0: RDMASendNotify, 1:RDMASendPayload} + uint32_t taskId; + uint16_t streamId; + 
uint16_t reserve[3]; +}; + +struct ProfilingDeviceCommResInfo { + uint64_t groupName; // 通信域 + uint32_t rankSize; // 通信域内rank总数 + uint32_t rankId; // 当前device rankId,通信域内编号 + uint32_t usrRankId; // 当前device rankId,全局编号 + uint32_t aicpuKfcStreamId; // MC2中launch aicpu kfc算子的stream + uint32_t commStreamSize; // 当前device侧使用的通信stream数量 + uint32_t commStreamIds[8]; // 具体streamId + uint32_t reserve; +}; + +#define MSPROF_MULTI_THREAD_MAX_NUM 25 +struct MsprofMultiThread { + uint32_t threadNum; + uint32_t threadId[MSPROF_MULTI_THREAD_MAX_NUM]; +}; +#pragma pack() + + +#pragma pack(1) +struct MsprofNodeBasicInfo { + uint64_t opName; + uint32_t taskType; + uint64_t opType; + uint32_t blockDim; + uint32_t opFlag; +}; + +enum AttrType { + OP_ATTR = 0, +}; + +struct MsprofAttrInfo { + uint64_t opName; + uint32_t attrType; + uint64_t hashId; +}; + +struct MsrofTensorData { + uint32_t tensorType; + uint32_t format; + uint32_t dataType; + uint32_t shape[MSPROF_GE_TENSOR_DATA_SHAPE_LEN]; +}; + +struct MsprofTensorInfo { + uint64_t opName; + uint32_t tensorNum; + struct MsrofTensorData tensorData[MSPROF_GE_TENSOR_DATA_NUM]; +}; + +struct MsprofHCCLOPInfo { // for MsprofReportCompactInfo buffer data + uint8_t relay : 1; // 借轨通信 + uint8_t retry : 1; // 重传标识 + uint8_t dataType; // 跟HcclDataType类型保存一致 + uint64_t algType; // 通信算子使用的算法,hash的key,其值是以"-"分隔的字符串 + uint64_t count; // 发送数据个数 + uint64_t groupName; // group hash id +}; + +struct MsprofAicpuHCCLOPInfo { + uint8_t relay : 1; // 借轨通信 + uint8_t retry : 1; // 重传标识 + uint8_t dataType; // 跟HcclDataType类型保存一致 + uint64_t algType; // 通信算子使用的算法,hash的key,其值是以"-"分隔的字符串 + uint64_t count; // 发送数据个数 + uint64_t groupName; // group hash id + uint32_t ranksize; + uint16_t streamId; + uint32_t taskId; +}; + +struct ProfFusionOpInfo { +uint64_t opName; +uint32_t fusionOpNum; +uint64_t inputMemsize; +uint64_t outputMemsize; +uint64_t weightMemSize; +uint64_t workspaceMemSize; +uint64_t totalMemSize; +uint64_t fusionOpId[MSPROF_GE_FUSION_OP_NUM]; +}; 
+ +struct MsprofContextIdInfo { + uint64_t opName; + uint32_t ctxIdNum; + uint32_t ctxIds[MSPROF_CTX_ID_MAX_NUM]; +}; + +struct MsprofGraphIdInfo { + uint64_t modelName; + uint32_t graphId; + uint32_t modelId; +}; + +struct MsprofMemoryInfo { + uint64_t addr; + int64_t size; + uint64_t nodeId; // op name hash id + uint64_t totalAllocateMemory; + uint64_t totalReserveMemory; + uint32_t deviceId; + uint32_t deviceType; +}; + +/** + * @name MsprofStampInfo + * @brief struct of data reported by msproftx + */ +#define PAYLOAD_VALUE_LEN 2 +#define MAX_MESSAGE_LEN 128 +struct MsprofStampInfo { + uint16_t magicNumber; + uint16_t dataTag; + uint32_t processId; + uint32_t threadId; + uint32_t category; // marker category + uint32_t eventType; + int32_t payloadType; + union PayloadValue { + uint64_t ullValue; + int64_t llValue; + double dValue; + uint32_t uiValue[PAYLOAD_VALUE_LEN]; + int32_t iValue[PAYLOAD_VALUE_LEN]; + float fValue[PAYLOAD_VALUE_LEN]; + } payload; // payload info for marker + uint64_t startTime; + uint64_t endTime; + uint64_t markId; + int32_t messageType; + char message[MAX_MESSAGE_LEN]; +}; + +#define MSPROF_TX_VALUE_MAX_LEN 224 // 224 + 8 = 232: additional data len +struct MsprofTxInfo { + uint16_t infoType; // 0: Mark; 1: MarkEx + uint16_t res0; + uint32_t res1; + union { + struct MsprofStampInfo stampInfo; + uint8_t data[MSPROF_TX_VALUE_MAX_LEN]; + } value; +}; + +struct MsprofStaticOpMem { + int64_t size; // op memory size + uint64_t opName; // op name hash id + uint64_t lifeStart; // serial number of op memory used + uint64_t lifeEnd; // serial number of op memory used + uint64_t totalAllocateMemory; // static graph total allocate memory + uint64_t dynOpName; // 0: invalid, other: dynamic op name of root + uint32_t graphId; // multipe model +}; + +#define MSPROF_PHYSIC_STREAM_ID_MAX_NUM 56 +struct MsprofLogicStreamInfo { + uint32_t logicStreamId; + uint32_t physicStreamNum; + uint32_t physicStreamId[MSPROF_PHYSIC_STREAM_ID_MAX_NUM]; +}; + +struct 
MsprofExeomLoadInfo { + uint32_t modelId; + uint32_t reserve; + uint64_t modelName; /* name hash */ +}; +#pragma pack() + +struct MsprofApi { // for MsprofReportApi +#ifdef __cplusplus + uint16_t magicNumber = MSPROF_REPORT_DATA_MAGIC_NUM; +#else + uint16_t magicNumber; +#endif + uint16_t level; + uint32_t type; + uint32_t threadId; + uint32_t reserve; + uint64_t beginTime; + uint64_t endTime; + uint64_t itemId; +}; + +struct MsprofEvent { // for MsprofReportEvent +#ifdef __cplusplus + uint16_t magicNumber = MSPROF_REPORT_DATA_MAGIC_NUM; +#else + uint16_t magicNumber; +#endif + uint16_t level; + uint32_t type; + uint32_t threadId; + uint32_t requestId; // 0xFFFF means single event + uint64_t timeStamp; +#ifdef __cplusplus + uint64_t eventFlag = MSPROF_EVENT_FLAG; +#else + uint64_t eventFlag; +#endif + uint64_t itemId; +}; + +struct MsprofRuntimeTrack { // for MsprofReportCompactInfo buffer data + uint16_t deviceId; + uint16_t streamId; + uint32_t taskId; + uint64_t taskType; // task message hash id + uint64_t kernelName; // kernelname hash id +}; + +struct MsprofDpuTrack { // for MsprofReportCompactInfo buffer data + uint16_t deviceId; // high 4 bits, devType: dpu: 1, low 12 bits device id + uint16_t streamId; + uint32_t taskId; + uint32_t taskType; // task type enum + uint32_t res; + uint64_t startTime; // start time +}; + +struct MsprofAicTimeStampInfo { + uint64_t syscyc; // dotting timestamp with system cycle + uint32_t blockId; // core block id + uint32_t descId; // dot Id for description + uint64_t curPc; // currrent pc for source line +}; + +#define MSPROF_COMPACT_INFO_DATA_LENGTH (40) +struct MsprofCompactInfo { // for MsprofReportCompactInfo buffer data +#ifdef __cplusplus + uint16_t magicNumber = MSPROF_REPORT_DATA_MAGIC_NUM; +#else + uint16_t magicNumber; +#endif + uint16_t level; + uint32_t type; + uint32_t threadId; + uint32_t dataLen; + uint64_t timeStamp; + union { + uint8_t info[MSPROF_COMPACT_INFO_DATA_LENGTH]; + struct MsprofRuntimeTrack 
runtimeTrack; + struct MsprofNodeBasicInfo nodeBasicInfo; + struct MsprofHCCLOPInfo hcclopInfo; + struct MsprofDpuTrack dpuTack; + } data; +}; + +#define MSPROF_ADDTIONAL_INFO_DATA_LENGTH (232) +struct MsprofAdditionalInfo { // for MsprofReportAdditionalInfo buffer data +#ifdef __cplusplus + uint16_t magicNumber = MSPROF_REPORT_DATA_MAGIC_NUM; +#else + uint16_t magicNumber; +#endif + uint16_t level; + uint32_t type; + uint32_t threadId; + uint32_t dataLen; + uint64_t timeStamp; + uint8_t data[MSPROF_ADDTIONAL_INFO_DATA_LENGTH]; +}; + +/** + * @name MsprofErrorCode + * @brief error code + */ +enum MsprofErrorCode { + MSPROF_ERROR_NONE = 0, + MSPROF_ERROR_MEM_NOT_ENOUGH, + MSPROF_ERROR_GET_ENV, + MSPROF_ERROR_CONFIG_INVALID, + MSPROF_ERROR_ACL_JSON_OFF, + MSPROF_ERROR, + MSPROF_ERROR_UNINITIALIZE, +}; + +#define MSPROF_ENGINE_MAX_TAG_LEN (63) + +/** + * @name ReporterData + * @brief struct of data to report + */ +struct ReporterData { + char tag[MSPROF_ENGINE_MAX_TAG_LEN + 1]; // the sub-type of the module, data with different tag will be writen + int32_t deviceId; // the index of device + size_t dataLen; // the length of send data + uint8_t *data; // the data content +}; + +/** + * @name MsprofHashData + * @brief struct of data to hash + */ +struct MsprofHashData { + int32_t deviceId; // the index of device + size_t dataLen; // the length of data + uint8_t *data; // the data content + uint64_t hashId; // the id of hashed data +}; + +enum MsprofConfigParamType { + DEV_CHANNEL_RESOURCE = 0, // device channel resource + HELPER_HOST_SERVER // helper host server +}; + +/** + * @name MsprofConfigParam + * @brief struct of set config + */ +struct MsprofConfigParam { + uint32_t deviceId; // the index of device + uint32_t type; // DEV_CHANNEL_RESOURCE; HELPER_HOST_SERVER + uint32_t value; // DEV_CHANNEL_RESOURCE: 1 off; HELPER_HOST_SERVER: 1 on +}; + +/** + * @name MsprofReporterModuleId + * @brief module id of data to report + */ +enum MsprofReporterModuleId { + 
MSPROF_MODULE_DATA_PREPROCESS = 0, // DATA_PREPROCESS + MSPROF_MODULE_HCCL, // HCCL + MSPROF_MODULE_ACL, // AclModule + MSPROF_MODULE_FRAMEWORK, // Framework + MSPROF_MODULE_RUNTIME, // runtime + MSPROF_MODULE_MSPROF // msprofTx +}; + +/** + * @name MsprofReporterCallbackType + * @brief reporter callback request type + */ +enum MsprofReporterCallbackType { + MSPROF_REPORTER_REPORT = 0, // report data + MSPROF_REPORTER_INIT, // init reporter + MSPROF_REPORTER_UNINIT, // uninit reporter + MSPROF_REPORTER_DATA_MAX_LEN, // data max length for calling report callback + MSPROF_REPORTER_HASH // hash data to id +}; + +#define MSPROF_OPTIONS_DEF_LEN_MAX (2048U) + +/** + * @name MsprofGeOptions + * @brief struct of MSPROF_CTRL_INIT_GE_OPTIONS + */ +struct MsprofGeOptions { + char jobId[MSPROF_OPTIONS_DEF_LEN_MAX]; + char options[MSPROF_OPTIONS_DEF_LEN_MAX]; +}; + +/** + * @name MsprofCtrlCallbackType + * @brief ctrl callback request type + */ +enum MsprofCtrlCallbackType { + MSPROF_CTRL_INIT_ACL_ENV = 0, // start profiling with acl env + MSPROF_CTRL_INIT_ACL_JSON = 1, // start pro with acl.json + MSPROF_CTRL_INIT_GE_OPTIONS = 2, // start profiling with ge env and options + MSPROF_CTRL_FINALIZE = 3, // stop profiling + MSPROF_CTRL_INIT_HELPER = 4, // start profiling in helper device + MSPROF_CTRL_INIT_PURE_CPU = 5, // start profiling in pure cpu + MSPROF_CTRL_INIT_DYNA = 0xFF, // start profiling for dynamic profiling +}; + +enum MsprofCommandHandleType { + PROF_COMMANDHANDLE_TYPE_INIT = 0, + PROF_COMMANDHANDLE_TYPE_START, + PROF_COMMANDHANDLE_TYPE_STOP, + PROF_COMMANDHANDLE_TYPE_FINALIZE, + PROF_COMMANDHANDLE_TYPE_MODEL_SUBSCRIBE, + PROF_COMMANDHANDLE_TYPE_MODEL_UNSUBSCRIBE, + PROF_COMMANDHANDLE_TYPE_MAX +}; + +enum MsprofConfigType { + MSPROF_CONFIG_HELPER_HOST = 0 +}; + +/** + * @brief profiling command type + */ +enum ProfCtrlType { + PROF_CTRL_INVALID = 0, + PROF_CTRL_SWITCH, + PROF_CTRL_REPORTER, + PROF_CTRL_STEPINFO, + PROF_CTRL_BUTT +}; + +/** + * @brief Prof Chip ID + 
*/ +enum Prof_Chip_ID { + PROF_CHIP_ID0 = 0 +}; + +/** + * @brief the struct of profiling set step info + */ +typedef struct ProfStepInfoCmd { + uint64_t index_id; + uint16_t tag_id; + void *stream; +} ProfStepInfoCmd_t; + +#ifdef __cplusplus +} +#endif +#endif // MSPROFILER_PROF_COMMON_H_ diff --git a/inc/msprof/toolchain/prof_data_config.h b/inc/msprof/toolchain/prof_data_config.h new file mode 100644 index 0000000000000000000000000000000000000000..d0e42132de26f54a0986a433233bf60872cb69b4 --- /dev/null +++ b/inc/msprof/toolchain/prof_data_config.h @@ -0,0 +1,101 @@ +/** + * @file prof_data_config.h + * + * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. + * + */ + +#ifndef MSPROFILER_API_PROF_DATA_CONFIG_H +#define MSPROFILER_API_PROF_DATA_CONFIG_H + +// DataTypeConfig +#define PROF_ACL_API 0x00000001ULL +#define PROF_TASK_TIME_L1 0x00000002ULL +#define PROF_AICORE_METRICS 0x00000004ULL +#define PROF_AICPU_TRACE 0x00000008ULL +#define PROF_L2CACHE 0x00000010ULL +#define PROF_HCCL_TRACE 0x00000020ULL +#define PROF_TRAINING_TRACE 0x00000040ULL +#define PROF_MSPROFTX 0x00000080ULL +#define PROF_RUNTIME_API 0x00000100ULL +#define PROF_FWK_SCHEDULE_L0 0x00000200ULL +#define PROF_TASK_TSFW 0x00000400ULL +#define PROF_TASK_TIME 0x00000800ULL +#define PROF_TASK_MEMORY 0x00001000ULL +#define PROF_TASK_TIME_L2 0x00002000ULL +#define PROF_OP_ATTR 0x00004000ULL + +// system profiling switch +#define PROF_CPU 0x00010000ULL +#define PROF_HARDWARE_MEMORY 0x00020000ULL +#define PROF_IO 0x00040000ULL +#define PROF_INTER_CONNECTION 0x00080000ULL +#define PROF_DVPP 0x00100000ULL +#define PROF_SYS_AICORE_SAMPLE 0x00200000ULL +#define PROF_AIVECTORCORE_SAMPLE 0x00400000ULL +#define PROF_INSTR 0x00800000ULL + +#define PROF_FWK_SCHEDULE_L1 0x0000001000000ULL
+#define PROF_PURE_CPU 0x0000002000000ULL +#define PROF_RUNTIME_TRACE 0x0000004000000ULL +#define PROF_SCHEDULE_TIMELINE 0x0000008000000ULL +#define PROF_SCHEDULE_TRACE 0x0000010000000ULL +#define PROF_AIVECTORCORE_METRICS 0x0000020000000ULL +#define PROF_SUBTASK_TIME 0x0000040000000ULL +#define PROF_OP_DETAIL 0x0000080000000ULL +#define PROF_OP_TIMESTAMP 0x0000100000000ULL + +#define PROF_AICPU_MODEL 0x4000000000000000ULL +#define PROF_MODEL_LOAD 0x8000000000000000ULL + +#define PROF_TASK_TRACE (PROF_RUNTIME_TRACE | PROF_TASK_TIME | PROF_TRAINING_TRACE | PROF_HCCL_TRACE | PROF_TASK_TIME_L1) + +// DataTypeConfig MASK +#define PROF_ACL_API_MASK 0x00000001ULL +#define PROF_TASK_TIME_L1_MASK 0x00000002ULL +#define PROF_AICORE_METRICS_MASK 0x00000004ULL +#define PROF_AICPU_TRACE_MASK 0x00000008ULL +#define PROF_L2CACHE_MASK 0x00000010ULL +#define PROF_HCCL_TRACE_MASK 0x00000020ULL +#define PROF_TRAINING_TRACE_MASK 0x00000040ULL +#define PROF_MSPROFTX_MASK 0x00000080ULL +#define PROF_RUNTIME_API_MASK 0x00000100ULL +#define PROF_TASK_FRAMEWORK_MASK 0x00000200ULL +#define PROF_FWK_SCHEDULE_L0_MASK 0x00000200ULL +#define PROF_TASK_TSFW_MASK 0x00000400ULL +#define PROF_TASK_TIME_MASK 0x00000800ULL +#define PROF_TASK_MEMORY_MASK 0x00001000ULL +#define PROF_TASK_TIME_L2_MASK 0x00002000ULL +#define PROF_OP_ATTR_MASK 0x00004000ULL + +// system profiling mask +#define PROF_CPU_MASK 0x00010000ULL +#define PROF_HARDWARE_MEMORY_MASK 0x00020000ULL +#define PROF_IO_MASK 0x00040000ULL +#define PROF_INTER_CONNECTION_MASK 0x00080000ULL +#define PROF_DVPP_MASK 0x00100000ULL +#define PROF_SYS_AICORE_SAMPLE_MASK 0x00200000ULL +#define PROF_AIVECTORCORE_SAMPLE_MASK 0x00400000ULL +#define PROF_INSTR_MASK 0x00800000ULL + +#define PROF_MODEL_EXECUTE_MASK 0x0000001000000ULL +#define PROF_FWK_SCHEDULE_L1_MASK 0x0000001000000ULL +#define PROF_RUNTIME_TRACE_MASK 0x0000004000000ULL +#define PROF_SCHEDULE_TIMELINE_MASK 0x0000008000000ULL +#define PROF_SCHEDULE_TRACE_MASK 0x0000010000000ULL +#define
PROF_AIVECTORCORE_METRICS_MASK 0x0000020000000ULL +#define PROF_SUBTASK_TIME_MASK 0x0000040000000ULL +#define PROF_OP_DETAIL_MASK 0x0000080000000ULL +#define PROF_OP_TIMESTAMP_MASK 0x0000100000000ULL + +#define PROF_AICPU_MODEL_MASK 0x4000000000000000ULL +#define PROF_MODEL_LOAD_MASK 0x8000000000000000ULL + +// profSwitchHi config, upper 16 bits will be sent to aicpu by runtime +#define PROF_HI_AICPU_CHANNEL 0x1000000000000ULL // bit0 of the upper 16 bits means aicpu channel switch + +#endif // MSPROFILER_API_PROF_DATA_CONFIG_H \ No newline at end of file diff --git a/inc/msprof/toolchain/prof_engine.h b/inc/msprof/toolchain/prof_engine.h new file mode 100644 index 0000000000000000000000000000000000000000..b0e185ad620b1bde16ffe005cbb1521f33b839d9 --- /dev/null +++ b/inc/msprof/toolchain/prof_engine.h @@ -0,0 +1,186 @@ +/** + * @file prof_engine.h + * + * Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+ * + */ + +#ifndef MSPROF_ENGINE_PROF_ENGINE_H +#define MSPROF_ENGINE_PROF_ENGINE_H +#define MSVP_PROF_API __attribute__((visibility("default"))) + +#include <map> +#include <string> +#include "prof_reporter.h" + +/** + * @file prof_engine.h + * @defgroup ModuleJobConfig the ModuleJobConfig group + * This is the ModuleJobConfig group + */ +namespace Msprof { +namespace Engine { +/** + * @ingroup ModuleJobConfig + * @brief struct ModuleJobConfig + * record config info + */ +struct ModuleJobConfig { + std::map<std::string, std::string> switches; /**< key is the config name, value is the config value(on or off) */ +}; + +/** + * @defgroup PluginIntf the PluginIntf group + * This is the PluginIntf group + */ + +/** + * @ingroup PluginIntf + * @brief class PluginIntf + */ +class MSVP_PROF_API PluginIntf { +public: + virtual ~PluginIntf() {} + +public: +/** + * @ingroup PluginIntf + * @name : Init + * @brief : API of user plugin, libmsprof call this API to send a Reporter to user plugin + * @par description : + * API of user plugin, libmsprof call this API to send a Reporter to user plugin.
+ * @param reporter [IN] const Reporter* the Reporter from libmsprof + * @retval PROFILING_SUCCESS 0 (success) + * @retval PROFILING_FAILED -1 (failed) + * + * @par depend: + * @li libmsprof + * @li prof_engine.h + * @since c60 + * @see UnInit + */ + virtual int Init(const Reporter *reporter) = 0; + +/** + * @ingroup PluginIntf + * @name : OnNewConfig + * @brief : API of user plugin, libmsprof call this API to send config info to user plugin \n + If the user plugin needn't config, no need to redefine this function + * @param config [IN] const ModuleJobConfig * the config from libmsprof + * @retval PROFILING_SUCCESS 0 (success) + * @retval PROFILING_FAILED -1 (failed) + * + * @par depend: + * @li libmsprof + * @li prof_engine.h + * @since c60 + * @see Init | UnInit + */ + virtual int OnNewConfig(const ModuleJobConfig *config) = 0; + +/** + * @ingroup PluginIntf + * @name : UnInit + * @brief : API of user plugin, libmsprof call this API to notify plugin stop to send data + * @retval PROFILING_SUCCESS 0 (success) + * @retval PROFILING_FAILED -1 (failed) + * + * @par depend: + * @li libmsprof + * @li prof_engine.h + * @since c60 + * @see Init + */ + virtual int UnInit() = 0; +}; + +/** + * @defgroup EngineIntf the EngineIntf group + * This is the EngineIntf group + */ + +/** + * @ingroup EngineIntf + * @brief class EngineIntf + */ +class MSVP_PROF_API EngineIntf { +public: + virtual ~EngineIntf() {} + +public: +/** + * @ingroup EngineIntf + * @name : CreatePlugin + * @brief : API of user engine, libmsprof call this API to get a plugin + * @retval PluginIntf * The pointer of the new plugin + * + * @par depend: + * @li libmsprof + * @li prof_engine.h + * @since c60 + * @see ReleasePlugin + */ + virtual PluginIntf *CreatePlugin() = 0; + + /** + * @ingroup EngineIntf + * @name : ReleasePlugin + * @brief : API of user engine, libmsprof call this API to release a plugin + * @param plugin [IN] PluginIntf * the plugin to release + * @retval PROFILING_SUCCESS 0 (success) + *
@retval PROFILING_FAILED -1 (failed) +* + * @par depend: + * @li libmsprof + * @li prof_engine.h + * @since c60 + * @see CreatePlugin + */ + virtual int ReleasePlugin(PluginIntf *plugin) = 0; +}; + +/** + * @defgroup EngineMgr the EngineMgr group + * This is the EngineMgr group + */ + +/** + * @ingroup EngineMgr + * @name : Init + * @brief : API of libmsprof, init an engine with a name + * @param module [IN] const std::string the name of plugin + * @param engine [IN] const EngineIntf* the engine + * @retval PROFILING_SUCCESS 0 (success) + * @retval PROFILING_FAILED -1 (failed) + * + * @par depend: + * @li libmsprof + * @li prof_engine.h + * @since c60 + * @see UnInit + */ +MSVP_PROF_API int Init(const std::string &module, const EngineIntf *engine); + +/** + * @ingroup EngineMgr + * @name : UnInit + * @brief : API of libmsprof, uninit an engine with a name + * @param module [IN] const std::string the name of plugin + * @retval PROFILING_SUCCESS 0 (success) + * @retval PROFILING_FAILED -1 (failed) + * + * @par depend: + * @li libmsprof + * @li prof_engine.h + * @since c60 + * @see Init + */ +MSVP_PROF_API int UnInit(const std::string &module); +} // namespace Engine +} // namespace Msprof + +#endif // MSPROF_ENGINE_PROF_ENGINE_H \ No newline at end of file diff --git a/inc/msprof/toolchain/prof_reporter.h b/inc/msprof/toolchain/prof_reporter.h new file mode 100644 index 0000000000000000000000000000000000000000..28396844b43d0aaee6e273163d0dc851992a4763 --- /dev/null +++ b/inc/msprof/toolchain/prof_reporter.h @@ -0,0 +1,79 @@ +/** + * @file prof_reporter.h + * + * Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
+ * + */ + +#ifndef MSPROF_ENGINE_PROF_REPORTER_H +#define MSPROF_ENGINE_PROF_REPORTER_H + +#if (defined(_WIN32) || defined(_WIN64) || defined(_MSC_VER)) +#define MSVP_PROF_API __declspec(dllexport) +#else +#define MSVP_PROF_API __attribute__((visibility("default"))) +#endif + +#include "prof_api.h" + +/** + * @file prof_reporter.h + * @defgroup reporter the reporter group + * This is the reporter group + */ +namespace Msprof { +namespace Engine { +/** + * @ingroup reporter + * @brief class Reporter + * the Reporter class, used to send data to profiling + */ +class MSVP_PROF_API Reporter { +public: + virtual ~Reporter() {} + +public: + /** + * @ingroup reporter + * @name : Report + * @brief : API of libmsprof, report data to libmsprof, it's a non-blocking function \n + The data will be firstly appended to cache, if the cache is full, data will be ignored + * @param data [IN] const ReporterData * the data sent to libmsprof + * @retval PROFILING_SUCCESS 0 (success) + * @retval PROFILING_FAILED -1 (failed) + * + * @par depend: + * @li libmsprof + * @li prof_reporter.h + * @since c60 + * @see Flush + */ + virtual int Report(const ReporterData *data) = 0; + + /** + * @ingroup reporter + * @name : Flush + * @brief : API of libmsprof, notify libmsprof send data over, it's a blocking function \n + All data in the cache will be written to file or sent to host + * @retval PROFILING_SUCCESS 0 (success) + * @retval PROFILING_FAILED -1 (failed) + * + * @par depend: + * @li libmsprof + * @li prof_reporter.h + * @since c60 + * @see ProfMgrStop + */ + virtual int32_t Flush() = 0; + + virtual uint32_t GetReportDataMaxLen() const = 0; +}; + +} // namespace Engine +} // namespace Msprof + +#endif // MSPROF_ENGINE_PROF_REPORTER_H diff --git a/inc/parser/external/parser/caffe_parser.h b/inc/parser/external/parser/caffe_parser.h new file mode 100644 index 0000000000000000000000000000000000000000..6c7e69251acb49f111eff74a1dd481f037ee63dd --- /dev/null +++
b/inc/parser/external/parser/caffe_parser.h @@ -0,0 +1,44 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_ACL_GRAPH_CAFFE_H_ +#define INC_EXTERNAL_ACL_GRAPH_CAFFE_H_ + +#if defined(_MSC_VER) +#ifdef FUNC_VISIBILITY +#define PARSER_FUNC_VISIBILITY _declspec(dllexport) +#else +#define PARSER_FUNC_VISIBILITY +#endif +#else +#ifdef FUNC_VISIBILITY +#define PARSER_FUNC_VISIBILITY __attribute__((visibility("default"))) +#else +#define PARSER_FUNC_VISIBILITY +#endif +#endif + +#include +#include +#include +#include +#include "graph/ascend_string.h" +#include "graph/ge_error_codes.h" +#include "graph/graph.h" + +namespace ge { +PARSER_FUNC_VISIBILITY graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, + ge::Graph &graph); + +PARSER_FUNC_VISIBILITY graphStatus aclgrphParseCaffe(const char *model_file, const char *weights_file, + const std::map &parser_params, + ge::Graph &graph); +} // namespace ge + +#endif // INC_EXTERNAL_ACL_GRAPH_CAFFE_H_ diff --git a/inc/parser/external/parser/onnx_parser.h b/inc/parser/external/parser/onnx_parser.h new file mode 100644 index 0000000000000000000000000000000000000000..498dd54f5958a8bec51ea49ab800a6f16a988338 --- /dev/null +++ b/inc/parser/external/parser/onnx_parser.h @@ -0,0 +1,40 @@ +/* Copyright (c) 2024 Huawei 
Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_PARSER_ONNX_PARSER_H_ +#define INC_EXTERNAL_PARSER_ONNX_PARSER_H_ + +#if defined(_MSC_VER) +#ifdef FUNC_VISIBILITY +#define PARSER_FUNC_VISIBILITY _declspec(dllexport) +#else +#define PARSER_FUNC_VISIBILITY +#endif +#else +#ifdef FUNC_VISIBILITY +#define PARSER_FUNC_VISIBILITY __attribute__((visibility("default"))) +#else +#define PARSER_FUNC_VISIBILITY +#endif +#endif + +#include +#include "graph/ascend_string.h" +#include "graph/ge_error_codes.h" +#include "graph/graph.h" + +namespace ge { +PARSER_FUNC_VISIBILITY graphStatus aclgrphParseONNX(const char *model_file, + const std::map &parser_params, ge::Graph &graph); + +PARSER_FUNC_VISIBILITY graphStatus aclgrphParseONNXFromMem(const char *buffer, size_t size, + const std::map &parser_params, ge::Graph &graph); +} // namespace ge + +#endif // INC_EXTERNAL_PARSER_ONNX_PARSER_H_ diff --git a/inc/parser/external/parser/tensorflow_parser.h b/inc/parser/external/parser/tensorflow_parser.h new file mode 100644 index 0000000000000000000000000000000000000000..022442b8d362d752b03a03f5541ff98356c60bb9 --- /dev/null +++ b/inc/parser/external/parser/tensorflow_parser.h @@ -0,0 +1,42 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +#ifndef INC_EXTERNAL_ACL_PARSER_TENSORFLOW_H_ +#define INC_EXTERNAL_ACL_PARSER_TENSORFLOW_H_ + +#if defined(_MSC_VER) +#ifdef FUNC_VISIBILITY +#define PARSER_FUNC_VISIBILITY _declspec(dllexport) +#else +#define PARSER_FUNC_VISIBILITY +#endif +#else +#ifdef FUNC_VISIBILITY +#define PARSER_FUNC_VISIBILITY __attribute__((visibility("default"))) +#else +#define PARSER_FUNC_VISIBILITY +#endif +#endif + +#include +#include +#include +#include +#include +#include "graph/ascend_string.h" +#include "graph/ge_error_codes.h" +#include "graph/graph.h" + +namespace ge { +PARSER_FUNC_VISIBILITY graphStatus aclgrphParseTensorFlow(const char *model_file, ge::Graph &graph); +PARSER_FUNC_VISIBILITY graphStatus aclgrphParseTensorFlow(const char *model_file, + const std::map &parser_params, ge::Graph &graph); +} // namespace ge + +#endif // INC_EXTERNAL_ACL_PARSER_TENSORFLOW_H_ diff --git a/inc/parser/parser/common/convert/pb2json.h b/inc/parser/parser/common/convert/pb2json.h new file mode 100644 index 0000000000000000000000000000000000000000..001e017877f2b551fb8bbd0b3581a17664e46057 --- /dev/null +++ b/inc/parser/parser/common/convert/pb2json.h @@ -0,0 +1,69 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). 
+ * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ===================================================================================================================*/ + +// File: pb2json.h +// Description: This header file for protobuf message and json interconversion + +#ifndef PARSER_COMMON_CONVERT_PB2JSON_H_ +#define PARSER_COMMON_CONVERT_PB2JSON_H_ +#include +#include +#include +#include +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "nlohmann/json.hpp" + +namespace ge { +using Json = nlohmann::json; +using ProtobufMsg = ::google::protobuf::Message; +using ProtobufReflection = ::google::protobuf::Reflection; +using ProtobufFieldDescriptor = ::google::protobuf::FieldDescriptor; +using ProtobufDescriptor = ::google::protobuf::Descriptor; +using ProtobufEnumValueDescriptor = ::google::protobuf::EnumValueDescriptor; + +class Pb2Json { + public: + /** + * @ingroup domi_omg + * @brief Transfer protobuf object to JSON object + * @param [out] json Converted JSON object + * @return void success + * @author + */ + static void Message2Json(const ProtobufMsg &message, const std::set &black_fields, Json &json, + bool enum2str = false, int depth = 0); + + static void RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, + const ProtobufReflection *reflection, const std::set &black_fields, + Json &json, bool enum2str, int depth = 0); + + static void EnumJson2Json(Json &json); + +protected: + static void Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field, + bool enum2str, Json &json); + + 
static void RepeatedEnum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, bool enum2str, Json &json); + + static void OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, + const ProtobufReflection *reflection, const std::set &black_fields, Json &json, + bool enum2str, int depth); + + static std::string TypeBytes2String(std::string &field_name, std::string &type_bytes); + + static int DictInit(Json &json, std::vector &idx2name, + std::vector &idx2value, std::vector &use_string_val); + + static int AttrReplaceKV(Json &json, const std::vector &idx2name, + const std::vector &idx2value, const std::vector &use_string_val); +}; +} // namespace ge + +#endif // PARSER_COMMON_CONVERT_PB2JSON_H_ diff --git a/inc/parser/parser/common/pre_checker.h b/inc/parser/parser/common/pre_checker.h new file mode 100644 index 0000000000000000000000000000000000000000..402c9c6d3d8930dda6033cfe07dfdbb288f20af1 --- /dev/null +++ b/inc/parser/parser/common/pre_checker.h @@ -0,0 +1,187 @@ +/* Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. 
+ * ===================================================================================================================*/ + +#ifndef PARSER_COMMON_PRE_CHECKER_H_ +#define PARSER_COMMON_PRE_CHECKER_H_ + +#include +#include +#include "framework/omg/parser/parser_types.h" +#include "omg/omg_inner_types.h" + +namespace ge { +using std::map; +using std::string; +using std::vector; +using Status = domi::Status; +/** + * @ingroup domi_omg + * @brief pre_check + * @author + */ +class PreChecker { + public: + /** + * @ingroup domi_omg + * @brief Operator unique identification + */ + using OpId = const void *; + + /** + * @ingroup domi_omg + * @brief error code, 1~99:Error, 100~199:Waring。 + */ + enum class ErrorCode { + // no error + OK = 0, + + // type unsupported + TYPE_UNSUPPORTED = 1, + + // param invalid + PARAM_INVALID = 2, + + // type ambiguous + TYPE_AMBIGUOUS = 8, + + // name repeated + NAME_REPEATED = 9 + }; + + /** + * @ingroup domi_omg + * @brief Operator error description + */ + struct Cause { + // error code + ErrorCode code; + + // error message + string message; + }; + + public: + /** + * @ingroup domi_omg + * @brief instance interface + */ + static PreChecker &Instance(); + + /** + * @ingroup domi_omg + * @brief set model name + */ + void SetModelName(const string &name); + + /** + * @ingroup domi_omg + * @brief add op information + */ + Status AddOp(OpId id, const string &name, const string &type); + + /** + * @ingroup domi_omg + * @brief Judge whether the operator name is duplicate + */ + Status CheckName(OpId id); + + /** + * @ingroup domi_omg + * @brief check operation type + * 1、Check whether the operator type supports according to the global frameworktype + * 2、Check if the operator type is ambiguous + */ + Status CheckType(OpId id, bool is_tensorflow = false); + + void RefreshErrorMessageByName(const string &op_name, ErrorCode code, const string& msg); + + /** + * @ingroup domi_omg + * @brief Add custom error description + */ + Status AddCause(OpId 
id, ErrorCode code, const string &msg); + + /** + * @ingroup domi_omg + * @brief Add custom error description + */ + Status AddCause(OpId id, const Cause &cause); + + /** + * @ingroup domi_omg + * @brief Clear all operator information + */ + void Clear(); + + /** + * @ingroup domi_omg + * @brief Clear the error information of the specified operator + */ + Status Clear(OpId id, const string &message = ""); + + /** + * @ingroup domi_omg + * @brief Determine if an error has been detected + */ + bool HasError(); + + /** + * @ingroup domi_omg + * @brief Save inspection results(JSON) + */ + Status Save(const string &file); + + private: + /** + * @ingroup domi_omg + * @brief operation information + */ + struct Info { + // Operator identifier + OpId id; + + // Operator name + string name; + + // Operator type + string type; + + // Error description, which may contain multiple (for example, both name and type are illegal) + vector causes; + }; + + PreChecker(); + ~PreChecker(); + PreChecker(const PreChecker &); + PreChecker &operator=(const PreChecker &); + + // Initialize internal data + void Init(); + + // Judge whether the type is supported + Status CheckTypeSupported(OpId id, const string &type, const string &name, bool is_tensorflow); + + // Determine if an error has been detected + bool HasError(OpId id); + + private: + // model name + string model_name_; + + // Save operator check results + map op_map_; + + // Save operator list in original order + vector ops_; + + // save frame related operator types + map *fmk_op_types_; +}; +} // namespace ge +#endif // PARSER_COMMON_PRE_CHECKER_H_ diff --git a/inc/runtime/external/aclnn/acl_meta.h b/inc/runtime/external/aclnn/acl_meta.h new file mode 100644 index 0000000000000000000000000000000000000000..27138ca4c410c5767f9cc3230333eec362a9a2a2 --- /dev/null +++ b/inc/runtime/external/aclnn/acl_meta.h @@ -0,0 +1,142 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef OP_API_OP_API_COMMON_INC_EXTERNAL_ACL_META_H +#define OP_API_OP_API_COMMON_INC_EXTERNAL_ACL_META_H + +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#if defined(_MSC_VER) +#ifdef FUNC_VISIBILITY +#define ACL_FUNC_VISIBILITY _declspec(dllexport) +#else +#define ACL_FUNC_VISIBILITY +#endif +#else +#ifdef FUNC_VISIBILITY +#define ACL_FUNC_VISIBILITY __attribute__((visibility("default"))) +#else +#define ACL_FUNC_VISIBILITY +#endif +#endif + +#ifdef __GNUC__ +#define ACL_DEPRECATED __attribute__((deprecated)) +#define ACL_DEPRECATED_MESSAGE(message) __attribute__((deprecated(message))) +#elif defined(_MSC_VER) +#define ACL_DEPRECATED __declspec(deprecated) +#define ACL_DEPRECATED_MESSAGE(message) __declspec(deprecated(message)) +#else +#define ACL_DEPRECATED +#define ACL_DEPRECATED_MESSAGE(message) +#endif + +#ifndef ACLNN_META +#define ACLNN_META +typedef struct aclOpExecutor aclOpExecutor; +typedef struct aclTensor aclTensor; +typedef struct aclScalar aclScalar; +typedef struct aclIntArray aclIntArray; +typedef struct aclFloatArray aclFloatArray; +typedef struct aclBoolArray aclBoolArray; +typedef struct aclTensorList aclTensorList; +typedef struct aclScalarList aclScalarList; + +typedef int32_t aclnnStatus; +constexpr aclnnStatus OK = 0; +#endif + +ACL_FUNC_VISIBILITY aclTensor *aclCreateTensor(const int64_t *viewDims, uint64_t viewDimsNum, aclDataType dataType, + 
const int64_t *stride, int64_t offset, aclFormat format, + const int64_t *storageDims, uint64_t storageDimsNum, + void *tensorData); + +ACL_FUNC_VISIBILITY aclScalar *aclCreateScalar(void *value, aclDataType dataType); +ACL_FUNC_VISIBILITY aclIntArray *aclCreateIntArray(const int64_t *value, uint64_t size); +ACL_FUNC_VISIBILITY aclFloatArray *aclCreateFloatArray(const float *value, uint64_t size); +ACL_FUNC_VISIBILITY aclBoolArray *aclCreateBoolArray(const bool *value, uint64_t size); +ACL_FUNC_VISIBILITY aclTensorList *aclCreateTensorList(const aclTensor *const *value, uint64_t size); +ACL_FUNC_VISIBILITY aclScalarList *aclCreateScalarList(const aclScalar *const *value, uint64_t size); + +ACL_FUNC_VISIBILITY aclnnStatus aclDestroyTensor(const aclTensor *tensor); +ACL_FUNC_VISIBILITY aclnnStatus aclDestroyScalar(const aclScalar *scalar); +ACL_FUNC_VISIBILITY aclnnStatus aclDestroyIntArray(const aclIntArray *array); +ACL_FUNC_VISIBILITY aclnnStatus aclDestroyFloatArray(const aclFloatArray *array); +ACL_FUNC_VISIBILITY aclnnStatus aclDestroyBoolArray(const aclBoolArray *array); +ACL_FUNC_VISIBILITY aclnnStatus aclDestroyTensorList(const aclTensorList *array); +ACL_FUNC_VISIBILITY aclnnStatus aclDestroyScalarList(const aclScalarList *array); +ACL_FUNC_VISIBILITY aclnnStatus aclGetViewShape(const aclTensor *tensor, int64_t **viewDims, uint64_t *viewDimsNum); +ACL_FUNC_VISIBILITY aclnnStatus aclGetStorageShape(const aclTensor *tensor, + int64_t **storageDims, + uint64_t *storageDimsNum); +ACL_FUNC_VISIBILITY aclnnStatus aclGetViewStrides(const aclTensor *tensor, + int64_t **stridesValue, + uint64_t *stridesNum); +ACL_FUNC_VISIBILITY aclnnStatus aclGetViewOffset(const aclTensor *tensor, int64_t *offset); +ACL_FUNC_VISIBILITY aclnnStatus aclGetFormat(const aclTensor *tensor, aclFormat *format); +ACL_FUNC_VISIBILITY aclnnStatus aclGetDataType(const aclTensor *tensor, aclDataType *dataType); +ACL_FUNC_VISIBILITY aclnnStatus aclGetIntArraySize(const aclIntArray *array, 
uint64_t *size); +ACL_FUNC_VISIBILITY aclnnStatus aclGetFloatArraySize(const aclFloatArray *array, uint64_t *size); +ACL_FUNC_VISIBILITY aclnnStatus aclGetBoolArraySize(const aclBoolArray *array, uint64_t *size); +ACL_FUNC_VISIBILITY aclnnStatus aclGetTensorListSize(const aclTensorList *tensorList, uint64_t *size); +ACL_FUNC_VISIBILITY aclnnStatus aclGetScalarListSize(const aclScalarList *scalarList, uint64_t *size); + +ACL_FUNC_VISIBILITY aclnnStatus aclInitTensor(aclTensor *tensor, const int64_t *viewDims, uint64_t viewDimsNum, + aclDataType dataType, const int64_t *stride, int64_t offset, + aclFormat format, const int64_t *storageDims, uint64_t storageDimsNum, + void *tensorDataAddr); +ACL_FUNC_VISIBILITY aclnnStatus aclSetAclOpExecutorRepeatable(aclOpExecutor *executor); +ACL_FUNC_VISIBILITY aclnnStatus aclDestroyAclOpExecutor(aclOpExecutor *executor); +ACL_FUNC_VISIBILITY aclnnStatus AclSetInputTensorAddr(aclOpExecutor *executor, const size_t index, + aclTensor *tensor, void *addr); +ACL_FUNC_VISIBILITY aclnnStatus AclSetOutputTensorAddr(aclOpExecutor *executor, const size_t index, + aclTensor *tensor, void *addr); +ACL_FUNC_VISIBILITY aclnnStatus AclSetDynamicInputTensorAddr(aclOpExecutor *executor, size_t irIndex, + const size_t relativeIndex, + aclTensorList *tensors, void *addr); +ACL_FUNC_VISIBILITY aclnnStatus AclSetDynamicOutputTensorAddr(aclOpExecutor *executor, size_t irIndex, + const size_t relativeIndex, + aclTensorList *tensors, void *addr); +ACL_FUNC_VISIBILITY aclnnStatus AclSetTensorAddr(aclOpExecutor *executor, const size_t index, + aclTensor *tensor, void *addr); +ACL_FUNC_VISIBILITY aclnnStatus AclSetDynamicTensorAddr(aclOpExecutor *executor, size_t irIndex, + const size_t relativeIndex, + aclTensorList *tensors, void *addr); +ACL_FUNC_VISIBILITY aclnnStatus aclSetInputTensorAddr(aclOpExecutor *executor, const size_t index, + aclTensor *tensor, void *addr); +ACL_FUNC_VISIBILITY aclnnStatus aclSetOutputTensorAddr(aclOpExecutor *executor, const 
size_t index, + aclTensor *tensor, void *addr); +ACL_FUNC_VISIBILITY aclnnStatus aclSetDynamicInputTensorAddr(aclOpExecutor *executor, size_t irIndex, + const size_t relativeIndex, + aclTensorList *tensors, void *addr); +ACL_FUNC_VISIBILITY aclnnStatus aclSetDynamicOutputTensorAddr(aclOpExecutor *executor, size_t irIndex, + const size_t relativeIndex, + aclTensorList *tensors, void *addr); +ACL_FUNC_VISIBILITY aclnnStatus aclSetTensorAddr(aclOpExecutor *executor, const size_t index, + aclTensor *tensor, void *addr); +ACL_FUNC_VISIBILITY aclnnStatus aclSetDynamicTensorAddr(aclOpExecutor *executor, size_t irIndex, + const size_t relativeIndex, + aclTensorList *tensors, void *addr); +#ifdef __cplusplus +} +#endif + +#endif // OP_API_OP_API_COMMON_INC_EXTERNAL_ACL_META_H diff --git a/inc/runtime/external/aclnn/aclnn_base.h b/inc/runtime/external/aclnn/aclnn_base.h new file mode 100644 index 0000000000000000000000000000000000000000..a5a722a148007b35704cdba2878ffe9023989b2a --- /dev/null +++ b/inc/runtime/external/aclnn/aclnn_base.h @@ -0,0 +1,59 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef OP_API_OP_API_COMMON_INC_EXTERNAL_ACLNN_BASE_H +#define OP_API_OP_API_COMMON_INC_EXTERNAL_ACLNN_BASE_H + +#include +#include +#include "aclnn/acl_meta.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#if defined(_MSC_VER) +#ifdef FUNC_VISIBILITY +#define ACL_FUNC_VISIBILITY _declspec(dllexport) +#else +#define ACL_FUNC_VISIBILITY +#endif +#else +#ifdef FUNC_VISIBILITY +#define ACL_FUNC_VISIBILITY __attribute__((visibility("default"))) +#else +#define ACL_FUNC_VISIBILITY +#endif +#endif + +#ifdef __GNUC__ +#define ACL_DEPRECATED __attribute__((deprecated)) +#define ACL_DEPRECATED_MESSAGE(message) __attribute__((deprecated(message))) +#elif defined(_MSC_VER) +#define ACL_DEPRECATED __declspec(deprecated) +#define ACL_DEPRECATED_MESSAGE(message) __declspec(deprecated(message)) +#else +#define ACL_DEPRECATED +#define ACL_DEPRECATED_MESSAGE(message) +#endif + +ACL_FUNC_VISIBILITY aclnnStatus aclnnInit(const char *configPath); +ACL_FUNC_VISIBILITY aclnnStatus aclnnFinalize(); + +#ifdef __cplusplus +} +#endif + +#endif // OP_API_OP_API_COMMON_INC_EXTERNAL_ACLNN_BASE_H diff --git a/inc/runtime/external/runtime/rt_error_codes.h b/inc/runtime/external/runtime/rt_error_codes.h new file mode 100644 index 0000000000000000000000000000000000000000..94262169ba5e614b69be404e69096776326bb42d --- /dev/null +++ b/inc/runtime/external/runtime/rt_error_codes.h @@ -0,0 +1,128 @@ +/** +* @file rt_error_codes.h +* +* Copyright (C) Huawei Technologies Co., Ltd. 2019-2020. All Rights Reserved. +* +* This program is distributed in the hope that it will be useful, +* but WITHOUT ANY WARRANTY; without even the implied warranty of +* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
+*/ + +#ifndef __INC_EXTERNEL_RT_ERROR_CODES_H__ +#define __INC_EXTERNEL_RT_ERROR_CODES_H__ + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define ACL_RT_SUCCESS 0 // success +#define ACL_ERROR_RT_PARAM_INVALID 107000 // param invalid +#define ACL_ERROR_RT_INVALID_DEVICEID 107001 // invalid device id +#define ACL_ERROR_RT_CONTEXT_NULL 107002 // current context null +#define ACL_ERROR_RT_STREAM_CONTEXT 107003 // stream not in current context +#define ACL_ERROR_RT_MODEL_CONTEXT 107004 // model not in current context +#define ACL_ERROR_RT_STREAM_MODEL 107005 // stream not in model +#define ACL_ERROR_RT_EVENT_TIMESTAMP_INVALID 107006 // event timestamp invalid +#define ACL_ERROR_RT_EVENT_TIMESTAMP_REVERSAL 107007 // event timestamp reversal +#define ACL_ERROR_RT_ADDR_UNALIGNED 107008 // memory address unaligned +#define ACL_ERROR_RT_FILE_OPEN 107009 // open file failed +#define ACL_ERROR_RT_FILE_WRITE 107010 // write file failed +#define ACL_ERROR_RT_STREAM_SUBSCRIBE 107011 // error subscribe stream +#define ACL_ERROR_RT_THREAD_SUBSCRIBE 107012 // error subscribe thread +#define ACL_ERROR_RT_GROUP_NOT_SET 107013 // group not set +#define ACL_ERROR_RT_GROUP_NOT_CREATE 107014 // group not create +#define ACL_ERROR_RT_STREAM_NO_CB_REG 107015 // callback not register to stream +#define ACL_ERROR_RT_INVALID_MEMORY_TYPE 107016 // invalid memory type +#define ACL_ERROR_RT_INVALID_HANDLE 107017 // invalid handle +#define ACL_ERROR_RT_INVALID_MALLOC_TYPE 107018 // invalid malloc type +#define ACL_ERROR_RT_WAIT_TIMEOUT 107019 // wait timeout +#define ACL_ERROR_RT_TASK_TIMEOUT 107020 // task timeout +#define ACL_ERROR_RT_SYSPARAMOPT_NOT_SET 107021 // not set sysparamopt +#define ACL_ERROR_RT_DEVICE_TASK_ABORT 107022 // device task aborting +#define ACL_ERROR_RT_STREAM_ABORT 107023 // stream aborting + +#define ACL_ERROR_RT_FEATURE_NOT_SUPPORT 207000 // feature not support +#define ACL_ERROR_RT_MEMORY_ALLOCATION 207001 // memory allocation error +#define 
ACL_ERROR_RT_MEMORY_FREE 207002 // memory free error +#define ACL_ERROR_RT_AICORE_OVER_FLOW 207003 // aicore over flow +#define ACL_ERROR_RT_NO_DEVICE 207004 // no device +#define ACL_ERROR_RT_RESOURCE_ALLOC_FAIL 207005 // resource alloc fail +#define ACL_ERROR_RT_NO_PERMISSION 207006 // no permission +#define ACL_ERROR_RT_NO_EVENT_RESOURCE 207007 // no event resource +#define ACL_ERROR_RT_NO_STREAM_RESOURCE 207008 // no stream resource +#define ACL_ERROR_RT_NO_NOTIFY_RESOURCE 207009 // no notify resource +#define ACL_ERROR_RT_NO_MODEL_RESOURCE 207010 // no model resource +#define ACL_ERROR_RT_NO_CDQ_RESOURCE 207011 // no cdq resource +#define ACL_ERROR_RT_OVER_LIMIT 207012 // over limit +#define ACL_ERROR_RT_QUEUE_EMPTY 207013 // queue is empty +#define ACL_ERROR_RT_QUEUE_FULL 207014 // queue is full +#define ACL_ERROR_RT_REPEATED_INIT 207015 // repeated init +#define ACL_ERROR_RT_AIVEC_OVER_FLOW 207016 // aivec over flow +#define ACL_ERROR_RT_OVER_FLOW 207017 // common over flow +#define ACL_ERROR_RT_DEVICE_OOM 207018 // device oom + +#define ACL_ERROR_RT_INTERNAL_ERROR 507000 // runtime internal error +#define ACL_ERROR_RT_TS_ERROR 507001 // ts internel error +#define ACL_ERROR_RT_STREAM_TASK_FULL 507002 // task full in stream +#define ACL_ERROR_RT_STREAM_TASK_EMPTY 507003 // task empty in stream +#define ACL_ERROR_RT_STREAM_NOT_COMPLETE 507004 // stream not complete +#define ACL_ERROR_RT_END_OF_SEQUENCE 507005 // end of sequence +#define ACL_ERROR_RT_EVENT_NOT_COMPLETE 507006 // event not complete +#define ACL_ERROR_RT_CONTEXT_RELEASE_ERROR 507007 // context release error +#define ACL_ERROR_RT_SOC_VERSION 507008 // soc version error +#define ACL_ERROR_RT_TASK_TYPE_NOT_SUPPORT 507009 // task type not support +#define ACL_ERROR_RT_LOST_HEARTBEAT 507010 // ts lost heartbeat +#define ACL_ERROR_RT_MODEL_EXECUTE 507011 // model execute failed +#define ACL_ERROR_RT_REPORT_TIMEOUT 507012 // report timeout +#define ACL_ERROR_RT_SYS_DMA 507013 // sys dma error +#define 
ACL_ERROR_RT_AICORE_TIMEOUT 507014 // aicore timeout +#define ACL_ERROR_RT_AICORE_EXCEPTION 507015 // aicore exception +#define ACL_ERROR_RT_AICORE_TRAP_EXCEPTION 507016 // aicore trap exception +#define ACL_ERROR_RT_AICPU_TIMEOUT 507017 // aicpu timeout +#define ACL_ERROR_RT_AICPU_EXCEPTION 507018 // aicpu exception +#define ACL_ERROR_RT_AICPU_DATADUMP_RSP_ERR 507019 // aicpu datadump response error +#define ACL_ERROR_RT_AICPU_MODEL_RSP_ERR 507020 // aicpu model operate response error +#define ACL_ERROR_RT_PROFILING_ERROR 507021 // profiling error +#define ACL_ERROR_RT_IPC_ERROR 507022 // ipc error +#define ACL_ERROR_RT_MODEL_ABORT_NORMAL 507023 // model abort normal +#define ACL_ERROR_RT_KERNEL_UNREGISTERING 507024 // kernel unregistering +#define ACL_ERROR_RT_RINGBUFFER_NOT_INIT 507025 // ringbuffer not init +#define ACL_ERROR_RT_RINGBUFFER_NO_DATA 507026 // ringbuffer no data +#define ACL_ERROR_RT_KERNEL_LOOKUP 507027 // kernel lookup error +#define ACL_ERROR_RT_KERNEL_DUPLICATE 507028 // kernel register duplicate +#define ACL_ERROR_RT_DEBUG_REGISTER_FAIL 507029 // debug register failed +#define ACL_ERROR_RT_DEBUG_UNREGISTER_FAIL 507030 // debug unregister failed +#define ACL_ERROR_RT_LABEL_CONTEXT 507031 // label not in current context +#define ACL_ERROR_RT_PROGRAM_USE_OUT 507032 // program register num use out +#define ACL_ERROR_RT_DEV_SETUP_ERROR 507033 // device setup error +#define ACL_ERROR_RT_VECTOR_CORE_TIMEOUT 507034 // vector core timeout +#define ACL_ERROR_RT_VECTOR_CORE_EXCEPTION 507035 // vector core exception +#define ACL_ERROR_RT_VECTOR_CORE_TRAP_EXCEPTION 507036 // vector core trap exception +#define ACL_ERROR_RT_CDQ_BATCH_ABNORMAL 507037 // cdq alloc batch abnormal +#define ACL_ERROR_RT_DIE_MODE_CHANGE_ERROR 507038 // can not change die mode +#define ACL_ERROR_RT_DIE_SET_ERROR 507039 // single die mode can not set die +#define ACL_ERROR_RT_INVALID_DIEID 507040 // invalid die id +#define ACL_ERROR_RT_DIE_MODE_NOT_SET 507041 // die mode not set 
+#define ACL_ERROR_RT_AICORE_TRAP_READ_OVERFLOW 507042 // aic trap read overflow +#define ACL_ERROR_RT_AICORE_TRAP_WRITE_OVERFLOW 507043 // aic trap write overflow +#define ACL_ERROR_RT_VECTOR_CORE_TRAP_READ_OVERFLOW 507044 // aiv trap read overflow +#define ACL_ERROR_RT_VECTOR_CORE_TRAP_WRITE_OVERFLOW 507045 // aiv trap write overflow +#define ACL_ERROR_RT_STREAM_SYNC_TIMEOUT 507046 // stream sync time out +#define ACL_ERROR_RT_EVENT_SYNC_TIMEOUT 507047 // event sync time out +#define ACL_ERROR_RT_FFTS_PLUS_TIMEOUT 507048 // ffts+ timeout +#define ACL_ERROR_RT_FFTS_PLUS_EXCEPTION 507049 // ffts+ exception +#define ACL_ERROR_RT_FFTS_PLUS_TRAP_EXCEPTION 507050 // ffts+ trap exception +#define ACL_ERROR_RT_SEND_MSG 507051 // hdc send msg fail +#define ACL_ERROR_RT_COPY_DATA 507052 // copy data fail +#define ACL_ERROR_RT_DEVICE_MEM_ERROR 507053 // device MEM ERROR +#define ACL_ERROR_RT_DRV_INTERNAL_ERROR 507899 // drv internal error +#define ACL_ERROR_RT_AICPU_INTERNAL_ERROR 507900 // aicpu internal error +#define ACL_ERROR_RT_SOCKET_CLOSE 507901 // hdc disconnect +#define ACL_ERROR_RT_AICPU_INFO_LOAD_RSP_ERR 507902 // aicpu info load response error + +#ifdef __cplusplus +} +#endif +#endif // __INC_EXTERNEL_RT_ERROR_CODES_H__ diff --git a/inc/runtime/runtime/__clang_cce_runtime.h b/inc/runtime/runtime/__clang_cce_runtime.h new file mode 100644 index 0000000000000000000000000000000000000000..8b8d2665ce7107a7a374080b632e0a12c6681427 --- /dev/null +++ b/inc/runtime/runtime/__clang_cce_runtime.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + * Description: __clang_cce_runtime.h + * Create: 2020-01-01 + */ + +#ifndef __CLANG_CCE_RUNTIME_H__ +#define __CLANG_CCE_RUNTIME_H__ +#if defined(__cplusplus) +extern "C" { +#endif // __cplusplus + +// This interface is provided by runtime, it needs to be kept the same as their. 
+/** + * @ingroup dvrt_clang_cce_runtime + * @brief Config kernel launch parameters + * @param [in] numBlocks block dimemsions + * @param [in|out] smDesc L2 memory usage control information + * @param [in|out] stream associated stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +#ifdef __cplusplus +uint32_t rtConfigureCall(uint32_t numBlocks, void *smDesc = nullptr, void *stream = nullptr); +#else // __cplusplus +uint32_t rtConfigureCall(uint32_t numBlocks, void *smDesc, void *stream); +#endif + +#if defined(__cplusplus) +} +#endif // __cplusplus + +#endif // __CLANG_CCE_RUNTIME_H__ diff --git a/inc/runtime/runtime/base.h b/inc/runtime/runtime/base.h new file mode 100644 index 0000000000000000000000000000000000000000..4462e0ca18280277cabe6200827cf2cee8e8fd4c --- /dev/null +++ b/inc/runtime/runtime/base.h @@ -0,0 +1,680 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + * Description: base.h + * Create: 2020-01-01 + */ + +#ifndef CCE_RUNTIME_BASE_H +#define CCE_RUNTIME_BASE_H + +#include +#include +#include "toolchain/prof_api.h" + +#if defined(__cplusplus) +extern "C" { +#endif + +// If you need export the function of this library in Win32 dll, use __declspec(dllexport) +#ifndef RTS_API +#ifdef RTS_DLL_EXPORT +#define RTS_API __declspec(dllexport) +#else +#define RTS_API +#endif +#endif + +typedef int32_t rtError_t; +static const int32_t RT_ERROR_NONE = 0; // success + +#ifndef char_t +typedef char char_t; +#endif + +#ifndef float32_t +typedef float float32_t; +#endif + +#ifndef float64_t +typedef double float64_t; +#endif + +/** + * @ingroup dvrt_base + * @brief device mode. + */ +typedef enum tagRtDeviceMode { + RT_DEVICE_MODE_SINGLE_DIE = 0, + RT_DEVICE_MODE_MULTI_DIE, + RT_DEVICE_MODE_RESERVED +} rtDeviceMode; + +/** + * @ingroup dvrt_base + * @brief device status. 
 + */ +typedef enum tagRtDeviceStatus { + RT_DEVICE_STATUS_NORMAL = 0, + RT_DEVICE_STATUS_ABNORMAL, + RT_DEVICE_STATUS_END = 0xFFFF +} rtDeviceStatus; + +/** + * @ingroup dvrt_base + * @brief runtime exception numbers. + */ +typedef enum tagRtExceptionType { + RT_EXCEPTION_NONE = 0, + RT_EXCEPTION_TS_DOWN = 1, + RT_EXCEPTION_TASK_TIMEOUT = 2, + RT_EXCEPTION_TASK_FAILURE = 3, + RT_EXCEPTION_DEV_RUNNING_DOWN = 4, + RT_EXCEPTION_STREAM_ID_FREE_FAILED = 5 +} rtExceptionType; + +/** + * @ingroup dvrt_base + * @brief Switch type. + */ +typedef enum tagRtCondition { + RT_EQUAL = 0, + RT_NOT_EQUAL, + RT_GREATER, + RT_GREATER_OR_EQUAL, + RT_LESS, + RT_LESS_OR_EQUAL +} rtCondition_t; + +typedef enum schemModeType { + RT_SCHEM_MODE_NORMAL = 0, + RT_SCHEM_MODE_BATCH, + RT_SCHEM_MODE_SYNC, + RT_SCHEM_MODE_END +} rtschemModeType_t; + +typedef enum tagSysParamOpt { + SYS_OPT_DETERMINISTIC = 0, // value: 0:non-DETERMINISTIC, 1:DETERMINISTIC + SYS_OPT_ENABLE_DEBUG_KERNEL = 1, // value: 0:disable, 1:enable + SYS_OPT_RESERVED = 2, +} rtSysParamOpt; + +typedef enum tagSysParamValue { + SYS_OPT_DISABLE = 0, // sys param opt disable + SYS_OPT_ENABLE = 1, // sys param opt enable + SYS_OPT_MAX = 2, +} rtSysParamValue; + +typedef struct tagRtTaskCfgInfo { + uint8_t qos; + uint8_t partId; + uint8_t schemMode; // rtschemModeType_t 0:normal;1:batch;2:sync + bool d2dCrossFlag; // d2dCrossFlag true:D2D_CROSS false:D2D_INNER + uint32_t blockDimOffset; + uint8_t dumpflag; // dumpflag 0:fault 2:RT_KERNEL_DUMPFLAG 4:RT_FUSION_KERNEL_DUMPFLAG + uint8_t rev[3]; + uint32_t localMemorySize; // for simt ub_size +} rtTaskCfgInfo_t; + +/** + * @ingroup dvrt_base + * @brief Data Type of Extensible Switch Task. 
 + */ +typedef enum tagRtSwitchDataType { + RT_SWITCH_INT32 = 0, + RT_SWITCH_INT64 = 1, +} rtSwitchDataType_t; + +typedef enum tagRtStreamFlagType { + RT_HEAD_STREAM = 0, // first stream + RT_INVALID_FLAG = 0x7FFFFFFF, +} rtStreamFlagType_t; + +typedef enum tagRtLimitType { + RT_LIMIT_TYPE_LOW_POWER_TIMEOUT = 0, // timeout for power down, ms +} rtLimitType_t; + +typedef enum tagRtFloatOverflowMode { + RT_OVERFLOW_MODE_SATURATION = 0, + RT_OVERFLOW_MODE_INFNAN, + RT_OVERFLOW_MODE_UNDEF, +} rtFloatOverflowMode_t; + +typedef enum tagRtExceptionExpandType { + RT_EXCEPTION_INVALID = 0, + RT_EXCEPTION_FFTS_PLUS, +} rtExceptionExpandType_t; + +typedef struct rtFftsPlusExDetailInfo { + uint16_t contextId; + uint16_t threadId; +} rtFftsPlusExDetailInfo_t; + +typedef struct rtArgsSizeInfo { + void *infoAddr; /* info : atomicIndex|input num input offset|size|size */ + uint32_t atomicIndex; +} rtArgsSizeInfo_t; + +typedef struct rtExceptionKernelInfo { + uint32_t binSize; + const void *bin; + uint32_t kernelNameSize; + const char *kernelName; + const void *dfxAddr; + uint16_t dfxSize; + uint8_t reserved[2]; // padding to keep four-byte alignment + int32_t elfDataFlag; +} rtExceptionKernelInfo_t; + +typedef struct rtExceptionArgsInfo { + uint32_t argsize; + void *argAddr; + rtArgsSizeInfo_t sizeInfo; + rtExceptionKernelInfo_t exceptionKernelInfo; // newly added struct; mind compatibility issues +}rtExceptionArgsInfo_t; + +typedef struct rtExceptionExpandInfo { + rtExceptionExpandType_t type; + union { + rtFftsPlusExDetailInfo_t fftsPlusInfo; + } u; +} rtExceptionExpandInfo_t; + +typedef struct rtExceptionInfo { + uint32_t taskid; + uint32_t streamid; + uint32_t tid; + uint32_t deviceid; + uint32_t retcode; + rtExceptionExpandInfo_t expandInfo; + rtExceptionArgsInfo_t exceptionArgs; +} rtExceptionInfo_t; + +typedef enum { + RT_DEVICE_ABORT = 0, + RT_DEVICE_KILL, + RT_DEVICE_CLEAN +} rtTaskAbortStage_t; + +/** + * @ingroup dvrt_base + * @brief stream handle. 
+ */ +typedef void *rtStream_t; + +typedef void (*rtErrorCallback)(rtExceptionType); + +typedef int32_t (*rtTaskAbortCallBack)(uint32_t devId, rtTaskAbortStage_t stage, uint32_t timeout, void *args); + +typedef void (*rtTaskFailCallback)(rtExceptionInfo_t *exceptionInfo); + +typedef void (*rtDeviceStateCallback)(uint32_t devId, bool isOpen); + +typedef void (*rtStreamStateCallback)(rtStream_t stm, const bool isCreate); +/** + * @ingroup profiling_base + * @brief dataType: rtProfCtrlType_t + * @brief data: data swtich or reporter function + * @brief dataLen: length of data + */ +typedef rtError_t (*rtProfCtrlHandle)(uint32_t dataType, void *data, uint32_t dataLen); + +/** + * @ingroup dvrt_base + * @brief Program handle. + */ +typedef void *rtBinHandle; + +/** + * @ingroup dvrt_base + * @brief Kernel handle. + */ +typedef void *rtFuncHandle; + +/** + * @ingroup dvrt_base + * @brief Args handle. + */ +typedef void *rtLaunchArgsHandle; + +/** + * @ingroup dvrt_base + * @brief runtime event handle. + */ +typedef void *rtEvent_t; + +/** + * @ingroup dvrt_base + * @brief label handle. + */ +typedef void *rtLabel_t; + +/** + * @ingroup dvrt_base + * @brief model handle. 
+ */ +typedef void *rtModel_t; + +#define RT_PROF_MAX_DEV_NUM 64 + +#define PATH_LEN_MAX 1023 +#define PARAM_LEN_MAX 4095 +typedef struct rtCommandHandleParams { + uint32_t pathLen; + uint32_t storageLimit; // MB + uint32_t profDataLen; + char_t path[PATH_LEN_MAX + 1]; + char_t profData[PARAM_LEN_MAX + 1]; +} rtCommandHandleParams_t; + +/** + * @ingroup profiling_base + * @brief profiling command info + */ +typedef struct rtProfCommandHandle { + uint64_t profSwitch; + uint64_t profSwitchHi; + uint32_t devNums; + uint32_t devIdList[RT_PROF_MAX_DEV_NUM]; + uint32_t modelId; + uint32_t type; + uint32_t cacheFlag; + rtCommandHandleParams_t commandHandleParams; +} rtProfCommandHandle_t; + +/** + * @ingroup profiling_base + * @brief type of app register profiling switch or reporter callback + */ +typedef enum { + RT_PROF_CTRL_INVALID = 0, + RT_PROF_CTRL_SWITCH, + RT_PROF_CTRL_REPORTER, + RT_PROF_CTRL_BUTT +} rtProfCtrlType_t; + +typedef enum { + RT_UTIL_TYPE_AICORE = 0, + RT_UTIL_TYPE_AIVECTOR, + RT_UTIL_TYPE_AICPU, + RT_UTIL_TYPE_MAX +} rtTypeUtil_t; + +/** + * @ingroup profiling_base + * @brief runtime handle. + */ +RTS_API rtError_t rtSetProfDirEx(const char_t *profDir, const char_t *address, const char_t *jobCtx); + +/** + * @ingroup profiling_base + * @brief init profiler object. + */ +RTS_API rtError_t rtProfilerInit(const char_t *profDir, const char_t *address, const char_t *jobCtx); + +/** + * @ingroup profiling_base + * @brief config rts profiler. + */ +RTS_API rtError_t rtProfilerConfig(uint16_t profConfig); + +/** + * @ingroup profiling_base + * @brief ts send keypoint profiler log. + */ +RTS_API rtError_t rtProfilerTrace(uint64_t id, bool notify, uint32_t flags, rtStream_t stm); + +/** + * @ingroup profiling_base + * @brief ts send keypoint profiler log. + */ +RTS_API rtError_t rtProfilerTraceEx(uint64_t id, uint64_t modelId, uint16_t tagId, rtStream_t stm); + +/** + * @ingroup profiling_base + * @brief ts set profiling reporter callback. 
+ */ +RTS_API rtError_t rtSetMsprofReporterCallback(MsprofReporterCallback callback); + +/** + * @ingroup profiling_base + * @brief add the map of deviceId and GE model index, called by ge + * @param [in] geModelIdx The index of GE model + * @param [in] deviceId The id of device + * @return RT_ERROR_NONE for ok + * @return ACL_ERROR_RT_PARAM_INVALID for error input + */ +RTS_API rtError_t rtSetDeviceIdByGeModelIdx(uint32_t geModelIdx, uint32_t deviceId); + +/** + * @ingroup profiling_base + * @brief del the map of deviceId and GE model index, called by ge + * @param [in] geModelIdx The index of GE model + * @param [in] deviceId The id of device + * @return RT_ERROR_NONE for ok + * @return ACL_ERROR_RT_PARAM_INVALID for error input + */ +RTS_API rtError_t rtUnsetDeviceIdByGeModelIdx(uint32_t geModelIdx, uint32_t deviceId); + +/** + * @ingroup profiling_base + * @brief find deviceId by GE model index, called by profiling + * @param [in] geModelIdx The index of GE model + * @param [out] deviceId The id of device + * @return RT_ERROR_NONE for ok + * @return ACL_ERROR_RT_PARAM_INVALID for error input + * @return ACL_ERROR_RT_INTERNAL_ERROR for can't find deviceId by geModelIdx + */ +RTS_API rtError_t rtGetDeviceIdByGeModelIdx(uint32_t geModelIdx, uint32_t *deviceId); + +/** + * @ingroup profiling_base + * @brief set profling switch, called by profiling + * @param [in] data rtProfCommandHandle + * @param [out] len length of data + * @return RT_ERROR_NONE for ok + * @return ACL_ERROR_RT_PARAM_INVALID for error input + */ +RTS_API rtError_t rtProfSetProSwitch(void *data, uint32_t len); + +/** + * @ingroup profiling_base + * @brief register callback of upper app, called by ge or acl + * @param [in] moduleId of APP + * @param [in] callback function when switch or reporter change + * @return RT_ERROR_NONE for ok + * @return ACL_ERROR_RT_PARAM_INVALID for error input + */ +RTS_API rtError_t rtProfRegisterCtrlCallback(uint32_t moduleId, rtProfCtrlHandle callback); + +/** + * 
@ingroup profiling_base + * @brief set profling switch, called by profiling + * @param [in] data rtProfilingCommandHandle + * @param [in] len length of data + * @return RT_ERROR_NONE for ok + * @return ACL_ERROR_RT_PARAM_INVALID for error input + */ +RTS_API rtError_t rtProfilingCommandHandle(uint32_t type, void *data, uint32_t len); + +/** + * @ingroup dvrt_base + * @brief register callback for error code + * @param [out] NA + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtSetExceptCallback(rtErrorCallback callback); + +/** + * @ingroup dvrt_base + * @brief register callback for error code + * @param [out] NA + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtSetTaskAbortCallBack(const char *moduleName, rtTaskAbortCallBack callback, void *args); + +/** + * @ingroup dvrt_base + * @brief register callback for task fail + * @param [out] NA + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtSetTaskFailCallback(rtTaskFailCallback callback); + +typedef enum DevCallBackDir { + DEV_CB_POS_FRONT = 1, + DEV_CB_POS_BACK = 2, + DEV_CB_POS_END +} rtDevCallBackDir_t; +/** + * @ingroup dvrt_base + * @brief register callback for deviceid by position + * @param [in] regName unique register name, can't be null + * @param [in] callback Device state callback function + * @param [in] notifyPos callback notify Postion + * @param [out] NA + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtRegDeviceStateCallbackEx(const char_t *regName, rtDeviceStateCallback callback, + const rtDevCallBackDir_t notifyPos); + +/** + * @ingroup dvrt_base + * @brief register callback for fail task + * @param [in] uniName unique register name, can't be null + * @param [in] callback fail task callback function + * @param [out] NA + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtRegTaskFailCallbackByModule(const char_t *moduleName, rtTaskFailCallback callback); + +/** + * @ingroup dvrt_base + * @brief notify handle. 
+ */ +typedef void *rtNotify_t; + +/** + * @ingroup dvrt_base + * @brief create label instance + * @param [out] lbl created label + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtLabelCreate(rtLabel_t *lbl); + +/** + * @ingroup dvrt_base + * @brief create label instance + * @param [out] lbl created label + * @param [in] mdl label set model + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtLabelCreateV2(rtLabel_t *lbl, rtModel_t mdl); + +/** + * @ingroup dvrt_base + * @brief set label and stream instance + * @param [in] lbl set label + * @param [in] stm set stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtLabelSet(rtLabel_t lbl, rtStream_t stm); + +/** + * @ingroup dvrt_base + * @brief destroy label instance + * @param [in] lbl label to destroy + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtLabelDestroy(rtLabel_t lbl); + +/** + * @ingroup dvrt_base + * @brief goto label instance + * @param [in] lbl goto label + * @param [in] stm to submit label_goto task + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtLabelGoto(rtLabel_t lbl, rtStream_t stm); + +/** + * @ingroup dvrt_base + * @brief label switch by index + * @param [in] ptr index value ptr + * @param [in] maxValue index max value + * @param [in] labelInfoPtr label content info ptr + * @param [in] stm set stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtLabelSwitchByIndex(void *ptr, uint32_t maxValue, void *labelInfoPtr, rtStream_t stm); + +/** + * @ingroup dvrt_base + * @brief stream goto label + * @param [in] lbl goto label + * @param [in] stm stream to submit label_goto task + * @return RT_ERROR_NONE for ok + * @return 
RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtLabelGotoEx(rtLabel_t lbl, rtStream_t stm); + +/** + * @ingroup dvrt_base + * @brief labels to dev info + * @param [in] lbl model label list + * @param [in] labelNumber label number + * @param [in] dst device ptr + * @param [in] dstMax dst size + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtLabelListCpy(rtLabel_t *lbl, uint32_t labelNumber, void *dst, uint32_t dstMax); + +/** + * @ingroup dvrt_base + * @brief labels to dev info + * @param [out] lbl created label handle + * @param [in] stm label bind stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtLabelCreateEx(rtLabel_t *lbl, rtStream_t stm); + +/** + * @ingroup dvrt_base + * @brief labels to dev info + * @param [out] lbl created label handle + * @param [in] mdl label bind model + * @param [in] stm label bind stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtLabelCreateExV2(rtLabel_t *lbl, rtModel_t mdl, rtStream_t stm); + +/** + * @ingroup dvrt_base + * @brief get current thread last stream id and task id + * @param [out] stm id and task id + * @param [in] null + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for input null ptr + */ +RTS_API rtError_t rtGetTaskIdAndStreamID(uint32_t *taskId, uint32_t *streamId); + +/** + * @ingroup dvrt_base + * @brief get max model num + * @param [out] max model num + * @param [in] null + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtGetMaxModelNum(uint32_t *maxModelCount); + +/** + * @ingroup dvrt_base + * @brief set stream mode + * @param [in] stm stream needed to be set mode + * @param [in] stmMode mode + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtStreamSetMode(rtStream_t stm, const uint64_t stmMode); + +/** + * @ingroup dvrt_base + * @brief get stream mode + * @param [in] 
stm stream needed to get its mode + * @param [out] stmMode mode pointer + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtStreamGetMode(rtStream_t const stm, uint64_t * const stmMode); + +#define RT_PROCESS_SIGN_LENGTH (49) + +typedef enum tagRtDevDrvProcessType { + RT_DEVDRV_PROCESS_CP1 = 0, /* aicpu_scheduler */ + RT_DEVDRV_PROCESS_CP2, /* custom_process */ + RT_DEVDRV_PROCESS_DEV_ONLY, /* TDT */ + RT_DEVDRV_PROCESS_QS, /* queue_scheduler */ + RT_DEVDRV_PROCESS_HCCP, /* hccp server */ + RT_DEVDRV_PROCESS_USER, /* user proc, can bind many on host or device */ + RT_DEVDRV_PROCESS_CPTYPE_MAX +} rtDevDrvProcessType_t; + +typedef struct tagRtBindHostpidInfo { + int32_t hostPid; + uint32_t vfId; + uint32_t chipId; + int32_t mode; + rtDevDrvProcessType_t cpType; + uint32_t len; + char sign[RT_PROCESS_SIGN_LENGTH]; +} rtBindHostpidInfo; + +/** + * @ingroup dvrt_base + * @brief Bind Device custom-process to aicpu-process. + * @param [in] info The Information about the bound hostid. + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtBindHostPid(rtBindHostpidInfo info); + +/** + * @ingroup dvrt_base + * @brief Unbind Device custom-process to aicpu-process. + * @param [in] info The Information about the bound hostid. + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtUnbindHostPid(rtBindHostpidInfo info); + +/** + * @ingroup dvrt_base + * @brief Query the binding information of the devpid. 
+ * @param [in] pid: dev pid + * @param [in] chipId chip id + * @param [in] vfId vf id + * @param [in] hostPid host pid + * @param [in] cpType type of custom-process + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtQueryProcessHostPid(int32_t pid, uint32_t *chipId, uint32_t *vfId, uint32_t *hostPid, + uint32_t *cpType); +/** + * @ingroup dvrt_base + * @brief Sets the SSID of the shinared notify. + * @param [in] name share id name to be set + * @param [in] sdid whitelisted sdid + * @param [in] pid whitelisted process + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtSetIpcNotifySuperPodPid(const char *name, uint32_t sdid, int32_t pid); + +/** + * @ingroup dvrt_base + * @brief Setting SSIDs of Shared Memory in Batches. + * @param [in] name Name used for sharing between processes + * @param [in] sdid whitelisted sdid + * @param [in] pid host pid whitelist array + * @param [in] num number of pid arrays + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtSetIpcMemorySuperPodPid(const char *name, uint32_t sdid, int32_t pid[], int32_t num); +#if defined(__cplusplus) +} +#endif + +#endif // CCE_RUNTIME_BASE_H diff --git a/inc/runtime/runtime/config.h b/inc/runtime/runtime/config.h new file mode 100644 index 0000000000000000000000000000000000000000..8c11f48c2e4f56c686ed67f9dba5ac3a337cebcb --- /dev/null +++ b/inc/runtime/runtime/config.h @@ -0,0 +1,313 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. 
+ * Description: config.h + * Create: 2020-01-01 + */ + +#ifndef CCE_RUNTIME_CONFIG_H +#define CCE_RUNTIME_CONFIG_H + +#include "base.h" + +#if defined(__cplusplus) +extern "C" { +#endif + +#define PLAT_COMBINE(arch, chip, ver) (((arch) << 16U) | ((chip) << 8U) | (ver)) +#define PLAT_GET_ARCH(type) (((type) >> 16U) & 0xffffU) +#define PLAT_GET_CHIP(type) (((type) >> 8U) & 0xffU) +#define PLAT_GET_VER(type) ((type) & 0xffU) + +typedef enum tagRtArchType { + ARCH_BEGIN = 0, + ARCH_V100 = ARCH_BEGIN, + ARCH_V200 = 1, + ARCH_V300 = 2, + ARCH_C100 = 3, /* Ascend910 */ + ARCH_C220 = 4, /* Ascend910B & Ascend910_93 */ + ARCH_M100 = 5, /* Ascend310 */ + ARCH_M200 = 6, /* Ascend310P & Ascend610 */ + ARCH_M201 = 7, /* BS9SX1A */ + ARCH_T300 = 8, /* Tiny */ + ARCH_N350 = 9, /* Nano */ + ARCH_M300 = 10, /* Ascend310B & AS31XM1X */ + ARCH_M310 = 11, /* Ascend610Lite */ + ARCH_S200 = 12, /* Hi3796CV300ES & TsnsE */ + ARCH_S202 = 13, /* Hi3796CV300CS & OPTG & SD3403 &TsnsC */ + ARCH_END, +} rtArchType_t; + +typedef enum tagRtChipType { + CHIP_BEGIN = 0, + CHIP_MINI = CHIP_BEGIN, + CHIP_CLOUD = 1, + CHIP_MDC = 2, + CHIP_LHISI = 3, + CHIP_DC = 4, + CHIP_CLOUD_V2 = 5, + CHIP_NO_DEVICE = 6, + CHIP_MINI_V3 = 7, + CHIP_5612 = 8, /* 1910b tiny */ + CHIP_NANO = 9, + CHIP_1636 = 10, + CHIP_AS31XM1 = 11, + CHIP_610LITE = 12, + CHIP_DAVID = 13, + CHIP_END +} rtChipType_t; + +typedef enum tagRtAicpuScheType { + SCHEDULE_SOFTWARE = 0, /* Software Schedule */ + SCHEDULE_SOFTWARE_OPT, + SCHEDULE_HARDWARE, /* HWTS Schedule */ +} rtAicpuScheType; + +typedef enum tagRtDeviceCapabilityType { + RT_SCHEDULE_SOFTWARE = 0, // Software Schedule + RT_SCHEDULE_SOFTWARE_OPT, + RT_SCHEDULE_HARDWARE, // HWTS Schedule + RT_AICPU_BLOCKING_OP_NOT_SUPPORT, + RT_AICPU_BLOCKING_OP_SUPPORT, // 1910/1980/51 ts support AICPU blocking operation + RT_MODE_NO_FFTS, // no ffts + RT_MODE_FFTS, // 81 get ffts work mode, ffts + RT_MODE_FFTS_PLUS, // 81 get ffts work mode, ffts plus + RT_DEV_CAP_SUPPORT, // Capability 
Support + RT_DEV_CAP_NOT_SUPPORT, // Capability not support +} rtDeviceCapabilityType; + +typedef enum tagRtVersion { + VER_BEGIN = 0, + VER_NA = VER_BEGIN, + VER_ES = 1, + VER_CS = 2, + VER_SD3403 = 3, + VER_LITE = 4, + VER_310M1 = 5, + VER_END = 6, +} rtVersion_t; + +/* match rtChipType_t */ +typedef enum tagRtPlatformType { + PLATFORM_BEGIN = 0, + PLATFORM_MINI_V1 = PLATFORM_BEGIN, + PLATFORM_CLOUD_V1 = 1, + PLATFORM_MINI_V2 = 2, + PLATFORM_LHISI_ES = 3, + PLATFORM_LHISI_CS = 4, + PLATFORM_DC = 5, + PLATFORM_CLOUD_V2 = 6, + PLATFORM_LHISI_SD3403 = 7, + PLATFORM_MINI_V3 = 8, + PLATFORM_MINI_5612 = 9, + PLATFORM_CLOUD_V2_910B1 = 10, + PLATFORM_CLOUD_V2_910B2 = 11, + PLATFORM_CLOUD_V2_910B3 = 12, + PLATFORM_MDC_BS9SX1A = 16, + PLATFORM_MDC_AS31XM1X = 17, + PLATFORM_MDC_LITE = 18, + PLATFORM_MINI_V3_B1 = 19, + PLATFORM_MINI_V3_B2 = 20, + PLATFORM_MINI_V3_B3 = 21, + PLATFORM_MINI_V3_B4 = 22, + PLATFORM_1636 = 23, + PLATFORM_CLOUD_V2_910B2C = 24, + PLATFORM_DAVID_910_95 = 26, + PLATFORM_CLOUD_V2_910B4_1 = 27, + PLATFORM_END +} rtPlatformType_t; + +typedef enum tagRtCubeFracMKNFp16 { + RT_CUBE_MKN_FP16_2_16_16 = 0, + RT_CUBE_MKN_FP16_4_16_16, + RT_CUBE_MKN_FP16_16_16_16, + RT_CUBE_MKN_FP16_Default, +} rtCubeFracMKNFp16_t; + +typedef enum tagRtCubeFracMKNInt8 { + RT_CUBE_MKN_INT8_2_32_16 = 0, + RT_CUBE_MKN_INT8_4_32_4, + RT_CUBE_MKN_INT8_4_32_16, + RT_CUBE_MKN_INT8_16_32_16, + RT_CUBE_MKN_INT8_Default, +} rtCubeFracMKNInt8_t; + +typedef enum tagRtVecFracVmulMKNFp16 { + RT_VEC_VMUL_MKN_FP16_1_16_16 = 0, + RT_VEC_VMUL_MKN_FP16_Default, +} rtVecFracVmulMKNFp16_t; + +typedef enum tagRtVecFracVmulMKNInt8 { + RT_VEC_VMUL_MKN_INT8_1_32_16 = 0, + RT_VEC_VMUL_MKN_INT8_Default, +} rtVecFracVmulMKNInt8_t; + +typedef struct tagRtAiCoreSpec { + uint32_t cubeFreq; + uint32_t cubeMSize; + uint32_t cubeKSize; + uint32_t cubeNSize; + rtCubeFracMKNFp16_t cubeFracMKNFp16; + rtCubeFracMKNInt8_t cubeFracMKNInt8; + rtVecFracVmulMKNFp16_t vecFracVmulMKNFp16; + rtVecFracVmulMKNInt8_t 
vecFracVmulMKNInt8; +} rtAiCoreSpec_t; + +typedef struct tagRtAiCoreRatesPara { + uint32_t ddrRate; + uint32_t l2Rate; + uint32_t l2ReadRate; + uint32_t l2WriteRate; + uint32_t l1ToL0ARate; + uint32_t l1ToL0BRate; + uint32_t l0CToUBRate; + uint32_t ubToL2; + uint32_t ubToDDR; + uint32_t ubToL1; +} rtAiCoreMemoryRates_t; + +typedef struct tagRtMemoryConfig { + uint32_t flowtableSize; + uint32_t compilerSize; +} rtMemoryConfig_t; + +typedef struct tagRtPlatformConfig { + uint32_t platformConfig; +} rtPlatformConfig_t; + +typedef enum tagRTTaskTimeoutType { + RT_TIMEOUT_TYPE_OP_WAIT = 0, + RT_TIMEOUT_TYPE_OP_EXECUTE, +} rtTaskTimeoutType_t; + +typedef enum tagRTLastErrLevel { + RT_THREAD_LEVEL = 0, + RT_CONTEXT_LEVEL, +} rtLastErrLevel_t; +/** + * @ingroup + * @brief get AI core count + * @param [in] aiCoreCnt + * @return aiCoreCnt + */ +RTS_API rtError_t rtGetAiCoreCount(uint32_t *aiCoreCnt); + +/** + * @ingroup + * @brief get AI cpu count + * @param [in] aiCpuCnt + * @return aiCpuCnt + */ +RTS_API rtError_t rtGetAiCpuCount(uint32_t *aiCpuCnt); + +/** + * @ingroup + * @brief get AI core frequency + * @param [in] aiCoreSpec + * @return aiCoreSpec + */ +RTS_API rtError_t rtGetAiCoreSpec(rtAiCoreSpec_t *aiCoreSpec); + +/** + * @ingroup + * @brief AI get core band Info + * @param [in] aiCoreMemoryRates + * @return aiCoreMemoryRates + */ +RTS_API rtError_t rtGetAiCoreMemoryRates(rtAiCoreMemoryRates_t *aiCoreMemoryRates); + +/** + * @ingroup + * @brief AI get core buffer Info,FlowTable Size,Compiler Size + * @param [in] memoryConfig + * @return memoryConfig + */ +RTS_API rtError_t rtGetMemoryConfig(rtMemoryConfig_t *memoryConfig); + +/** + * @ingroup + * @brief get l2 buffer Info,virtual baseaddr,Size + * @param [in] stm + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtMemGetL2Info(rtStream_t stm, void **ptr, uint32_t *size); + +/** + * @ingroup + * @brief get runtime version. The version is returned as (1000 major + 10 minor). 
For example, RUNTIME 9.2 would be + * represented by 9020. + * @param [out] runtimeVersion + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetRuntimeVersion(uint32_t *runtimeVersion); + + +/** + * @ingroup + * @brief get device feature ability by device id, such as task schedule ability. + * @param [in] deviceId + * @param [in] moduleType + * @param [in] featureType + * @param [out] val + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetDeviceCapability(int32_t deviceId, int32_t moduleType, int32_t featureType, int32_t *val); + +/** + * @ingroup + * @brief set event wait task timeout time. + * @param [in] timeout + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetOpWaitTimeOut(uint32_t timeout); + +/** + * @ingroup + * @brief set op execute task timeout time. + * @param [in] timeout + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetOpExecuteTimeOut(uint32_t timeout); + +/** + * @ingroup + * @brief get op execute task timeout time. + * @param [out] timeout + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetOpExecuteTimeOut(uint16_t * const timeout); + +/** + * @ingroup + * @brief get is Heterogenous. + * @param [out] heterogenous=1 Heterogenous Mode: read isHeterogenous=1 in ini file. + * @param [out] heterogenous=0 NOT Heterogenous Mode: + * 1:not found ini file, 2:error when reading ini, 3:Heterogenous value is not 1 + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtGetIsHeterogenous(int32_t *heterogenous); + +/** + * @ingroup + * @brief get latest errcode and clear it. + * @param [in] level error level for this api. 
+ * @return return for error code + */ +RTS_API rtError_t rtGetLastError(rtLastErrLevel_t level); + +/** + * @ingroup + * @brief get latest errcode. + * @param [in] level error level for this api. + * @return return for error code + */ +RTS_API rtError_t rtPeekAtLastError(rtLastErrLevel_t level); +#if defined(__cplusplus) +} +#endif + +#endif // CCE_RUNTIME_CONFIG_H diff --git a/inc/runtime/runtime/context.h b/inc/runtime/runtime/context.h new file mode 100644 index 0000000000000000000000000000000000000000..bb72913b3d10a38ffdd0187c01a85c01b73b8888 --- /dev/null +++ b/inc/runtime/runtime/context.h @@ -0,0 +1,185 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + * Description: context.h + * Create: 2020-01-01 + */ + +#ifndef CCE_RUNTIME_CONTEXT_H +#define CCE_RUNTIME_CONTEXT_H + +#include "base.h" + +#if defined(__cplusplus) +extern "C" { +#endif + +/** + * @ingroup rt_context + * @brief runtime context handle. + */ +typedef void *rtContext_t; + +typedef enum tagDryRunFlag { + RT_DRYRUN_FLAG_FALSE = 0, + RT_DRYRUN_FLAG_TRUE = 1, +} rtDryRunFlag_t; + +typedef enum tagCtxMode { + RT_CTX_NORMAL_MODE = 0, + RT_CTX_GEN_MODE = 1, +} rtCtxMode_t; + +typedef struct tagRtGroupInfo { + int32_t groupId; + uint32_t flag; + uint32_t aicoreNum; + uint32_t aicpuNum; + uint32_t aivectorNum; + uint32_t sdmaNum; + uint32_t activeStreamNum; + void *extrPtr; +} rtGroupInfo_t; + +/** + * @ingroup rt_context + * @brief create context and associates it with the calling thread + * @param [out] createCtx created context + * @param [in] flags context creation flag. set to 0. + * @param [in] devId device to create context on + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtCtxCreate(rtContext_t *createCtx, uint32_t flags, int32_t devId); + +/** + * @ingroup rt_context + * @brief create context and associates it with the calling thread + * @param [out] createCtx created context + * @param [in] flags context creation flag. set to 0. 
+ * @param [in] devId device to create context on + * @param [in] deviceMode the device mode + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtCtxCreateV2(rtContext_t *createCtx, uint32_t flags, int32_t devId, rtDeviceMode deviceMode); + +/** + * @ingroup rt_context + * @brief create context and associates it with the calling thread + * @param [out] createCtx created context + * @param [in] flags context creation flag. set to 0. + * @param [in] devId device to create context on + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtCtxCreateEx(rtContext_t *createCtx, uint32_t flags, int32_t devId); + +/** + * @ingroup rt_context + * @brief destroy context instance + * @param [in] destroyCtx context to destroy + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtCtxDestroy(rtContext_t destroyCtx); + +/** + * @ingroup rt_context + * @brief destroy context instance + * @param [in] destroyCtx context to destroy + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtCtxDestroyEx(rtContext_t destroyCtx); + +/** + * @ingroup rt_context + * @brief binds context to the calling CPU thread. + * @param [in] currentCtx context to bind. if NULL, unbind current context. + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtCtxSetCurrent(rtContext_t currentCtx); + +/** + * @ingroup rt_context + * @brief returns the context bound to the calling CPU thread. + * @param [out] currentCtx returned context + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtCtxGetCurrent(rtContext_t *currentCtx); + +/** + * @ingroup rt_context + * @brief returns the primary context of device. 
+ * @param [out] primaryCtx returned context + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtGetPriCtxByDeviceId(int32_t devId, rtContext_t *primaryCtx); + +/** + * @ingroup rt_context + * @brief returns the device ID for the current context + * @param [out] devId returned device id + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtCtxGetDevice(int32_t *devId); + +/** + * @ingroup + * @brief set group id + * @param [in] groupid + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtSetGroup(int32_t groupId); + +/** + * @ingroup + * @brief get group info + * @param [in] groupid count + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtGetGroupInfo(int32_t groupId, rtGroupInfo_t *groupInfo, uint32_t cnt); + +/** + * @ingroup + * @brief get group count + * @param [in] groupid count + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtGetGroupCount(uint32_t *cnt); + +/** + * @ingroup rt_context + * @brief set context INF mode + * @param [in] infMode + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtSetCtxINFMode(bool infMode); + +/** + * @ingroup rt_context + * @brief set current context system param option + * @param [in] configOpt system option to be set + * @param [in] configVal system option's value to be set + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtCtxSetSysParamOpt(const rtSysParamOpt configOpt, const int64_t configVal); + +/** + * @ingroup rt_context + * @brief get current context system param option's value + * @param [in] configOpt system option to be get value + * @param [out] configVal system option's value to be get + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtCtxGetSysParamOpt(const rtSysParamOpt configOpt, int64_t * const configVal); + +/** + * @ingroup rt_context + * @brief get context overflowAddr + * @param [out] overflowAddr current ctx's overflowAddr to be get + * @return RT_ERROR_NONE 
for ok, errno for failed + */ +RTS_API rtError_t rtCtxGetOverflowAddr(void **overflowAddr); + +#if defined(__cplusplus) +} +#endif + + +#endif // CCE_RUNTIME_CONTEXT_H diff --git a/inc/runtime/runtime/dev.h b/inc/runtime/runtime/dev.h new file mode 100644 index 0000000000000000000000000000000000000000..f6a401979ace6eab462c3d91b3bf8da804079eac --- /dev/null +++ b/inc/runtime/runtime/dev.h @@ -0,0 +1,633 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + * Description: dev.h + * Create: 2020-01-01 + */ + +#ifndef CCE_RUNTIME_DEVICE_H +#define CCE_RUNTIME_DEVICE_H + +#include "base.h" + +#if defined(__cplusplus) +extern "C" { +#endif + +#define RT_CAPABILITY_SUPPORT (0x1U) +#define RT_CAPABILITY_NOT_SUPPORT (0x0U) +#define MEMORY_INFO_TS_4G_LIMITED (0x0U) // for compatibility + +#define INFO_TYPE_CUBE_NUM (0x775A5A5AU) // for cube core num + +typedef struct tagRTDeviceInfo { + uint8_t env_type; // 0: FPGA 1: EMU 2: ESL + uint32_t ctrl_cpu_ip; + uint32_t ctrl_cpu_id; + uint32_t ctrl_cpu_core_num; + uint32_t ctrl_cpu_endian_little; + uint32_t ts_cpu_core_num; + uint32_t ai_cpu_core_num; + uint32_t ai_core_num; + uint32_t ai_core_freq; + uint32_t ai_cpu_core_id; + uint32_t ai_core_id; + uint32_t aicpu_occupy_bitmap; + uint32_t hardware_version; + uint32_t ts_num; +} rtDeviceInfo_t; + +typedef enum tagRtRunMode { + RT_RUN_MODE_OFFLINE = 0, + RT_RUN_MODE_ONLINE, + RT_RUN_MODE_AICPU_SCHED, + RT_RUN_MODE_RESERVED +} rtRunMode; + +typedef enum tagRtAicpuDeployType { + AICPU_DEPLOY_CROSS_OS = 0x0, + AICPU_DEPLOY_CROSS_PROCESS, + AICPU_DEPLOY_CROSS_THREAD, + AICPU_DEPLOY_RESERVED +} rtAicpuDeployType_t; + +typedef enum tagRtFeatureType { + FEATURE_TYPE_MEMCPY = 0, + FEATURE_TYPE_MEMORY, + FEATURE_TYPE_UPDATE_SQE, + FEATURE_TYPE_RSV +} rtFeatureType_t; + +typedef enum tagRtDeviceFeatureType { + FEATURE_TYPE_SCHE, + FEATURE_TYPE_BLOCKING_OPERATOR, + FEATURE_TYPE_FFTS_MODE, + FEATURE_TYPE_MEMQ_EVENT_CROSS_DEV, + FEATURE_TYPE_MODEL_TASK_UPDATE, 
+ FEATURE_TYPE_END, +} rtDeviceFeatureType_t; + +typedef enum tagMemcpyInfo { + MEMCPY_INFO_SUPPORT_ZEROCOPY = 0, + MEMCPY_INFO_RSV +} rtMemcpyInfo_t; + +typedef enum tagMemoryInfo { + MEMORY_INFO_TS_LIMITED = 0, + MEMORY_INFO_RSV +} rtMemoryInfo_t; + +typedef enum tagUpdateSQEInfo { + UPDATE_SQE_SUPPORT_DSA = 0 +} rtUpdateSQEInfo_t; + +typedef enum tagRtDeviceModuleType { + RT_MODULE_TYPE_SYSTEM = 0, /**< system info*/ + RT_MODULE_TYPE_AICPU, /** < aicpu info*/ + RT_MODULE_TYPE_CCPU, /**< ccpu_info*/ + RT_MODULE_TYPE_DCPU, /**< dcpu info*/ + RT_MODULE_TYPE_AICORE, /**< AI CORE info*/ + RT_MODULE_TYPE_TSCPU, /**< tscpu info*/ + RT_MODULE_TYPE_PCIE, /**< PCIE info*/ + RT_MODULE_TYPE_VECTOR_CORE, /**< VECTOR CORE info*/ + RT_MODULE_TYPE_HOST_AICPU, /**< HOST AICPU info*/ + RT_MODULE_TYPE_QOS, /** */ + RT_MODULE_TYPE_MEMORY /**phyIdSrc. + * @param [in] devIdDes the logical device id + * @param [in] phyIdSrc the physical device id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtEnableP2P(uint32_t devIdDes, uint32_t phyIdSrc, uint32_t flag); + +/** + * @ingroup dvrt_dev + * @brief disable direction:devIdDes---->phyIdSrc. + * @param [in] devIdDes the logical device id + * @param [in] phyIdSrc the physical device id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDisableP2P(uint32_t devIdDes, uint32_t phyIdSrc); + +/** + * @ingroup dvrt_dev + * @brief get cability of P2P omemry copy betwen device and peeredevic. 
+ * @param [in] devId the logical device id + * @param [in] peerDevice the physical device id + * @param [outv] *canAccessPeer 1:enable 0:disable + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDeviceCanAccessPeer(int32_t *canAccessPeer, uint32_t devId, uint32_t peerDevice); + +/** + * @ingroup dvrt_dev + * @brief get status + * @param [in] devIdDes the logical device id + * @param [in] phyIdSrc the physical device id + * @param [in|out] status status value + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetP2PStatus(uint32_t devIdDes, uint32_t phyIdSrc, uint32_t *status); + +/** + * @ingroup dvrt_dev + * @brief get value of current thread + * @param [in|out] pid value of pid + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtDeviceGetBareTgid(uint32_t *pid); + +/** + * @ingroup dvrt_dev + * @brief get target device of current thread + * @param [in|out] devId the device id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetDevice(int32_t *devId); + +/** + * @ingroup dvrt_dev + * @brief reset all opened device + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDeviceReset(int32_t devId); + +/** + * @ingroup dvrt_dev + * @brief reset opened device + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDeviceResetEx(int32_t devId); + +/** + * @ingroup dvrt_dev + * @brief get total device infomation. 
+ * @param [in] devId the device id + * @param [in] type limit type RT_LIMIT_TYPE_LOW_POWER_TIMEOUT=0 + * @param [in] val limit value + * @param [out] info the device info + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDeviceSetLimit(int32_t devId, rtLimitType_t type, uint32_t val); + +/** + * @ingroup dvrt_dev + * @brief Wait for compute device to finish + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDeviceSynchronize(void); + +/** + * @ingroup dvrt_dev + * @brief Wait for compute device to finish and set timeout + * @param [in] timeout timeout value,the unit is milliseconds + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDeviceSynchronizeWithTimeout(int32_t timeout); + +/** + * @ingroup dvrt_dev + * @brief Wait for compute device to finish + * @param [in] devId the device id + * @param [in] timeout the time for hadle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDeviceTaskAbort(int32_t devId, uint32_t timeout); + +/** + * @ingroup dvrt_dev + * @brief get priority range of current device + * @param [in|out] leastPriority least priority + * @param [in|out] greatestPriority greatest priority + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDeviceGetStreamPriorityRange(int32_t *leastPriority, int32_t *greatestPriority); + +/** + * @ingroup dvrt_dev + * @brief Setting Scheduling Type of Graph + * @param [in] tsId the ts id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetTSDevice(uint32_t tsId); + +/** + * @ingroup dvrt_dev + * @brief init aicpu executor + * @param [out] runtime run mode + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_DRV_ERR for can not get run mode + */ +RTS_API rtError_t 
rtGetRunMode(rtRunMode *runMode); + +/** + * @ingroup dvrt_dev + * @brief get aicpu deploy + * @param [out] aicpu deploy + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_DRV_ERR for can not get aicpu deploy + */ +RTS_API rtError_t rtGetAicpuDeploy(rtAicpuDeployType_t *deployType); + +/** + * @ingroup dvrt_dev + * @brief set chipType + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtSetSocVersion(const char_t *ver); + +/** + * @ingroup dvrt_dev + * @brief get chipType + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtGetSocVersion(char_t *ver, const uint32_t maxLen); + +/** + * @ingroup dvrt_dev + * @brief check socversion + * @param [in] omSocVersion OM SocVersion + * @param [in] omArchVersion OM ArchVersion + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtModelCheckCompatibility(const char_t *omSocVersion, const char_t *omArchVersion); + +/** + * @ingroup dvrt_dev + * @brief get device status + * @param [in] devId the device id + * @param [out] deviceStatus the device statue + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtDeviceStatusQuery(const uint32_t devId, rtDeviceStatus *deviceStatus); + +/** + * @ingroup dvrt_dev + * @brief get status + * @param [in] devId the logical device id + * @param [in] otherDevId the other logical device id + * @param [in] infoType info type + * @param [in|out] val pair info + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtGetPairDevicesInfo(uint32_t devId, uint32_t otherDevId, int32_t infoType, int64_t *val); + +/** + * @ingroup dvrt_dev + * @brief get status + * @param [in] devId the physic device id + * @param [in] otherDevId the other physic device id + * @param [in] infoType info type + * @param [in|out] val pair info + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtGetPairPhyDevicesInfo(uint32_t devId, uint32_t otherDevId, int32_t infoType, int64_t *val); + +/** + * @ingroup dvrt_dev + * @brief get capability infomation. 
+ * @param [in] featureType feature type + typedef enum tagRtFeatureType { + FEATURE_TYPE_MEMCPY = 0, + FEATURE_TYPE_RSV, + } rtFeatureType_t; + * @param [in] featureInfo info type + typedef enum tagMemcpyInfo { + MEMCPY_INFO_SUPPORT_ZEROCOPY = 0, + MEMCPY_INFO_RSV, + } rtMemcpyInfo_t; + * @param [out] val the capability info RT_CAPABILITY_SUPPORT or RT_CAPABILITY_NOT_SUPPORT + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtGetRtCapability(rtFeatureType_t featureType, int32_t featureInfo, int64_t *val); + +/** + * @ingroup dvrt_dev + * @brief set target device for current thread + * @param [in] devId the device id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetDeviceWithoutTsd(int32_t devId); + +/** + * @ingroup dvrt_dev + * @brief reset all opened device + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDeviceResetWithoutTsd(int32_t devId); + +/** + * @ingroup dvrt_dev + * @brief get device message + * @param [in] getMsgType msg type + * @param [in] callback acl callback function + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetDevMsg(rtGetDevMsgType_t getMsgType, rtGetMsgCallback callback); + +/** + * @ingroup dvrt_dev + * @brief get ts mem type + * @param [in] featureType mem request feature type + * @param [in] memSize mem request size + * @return RT_MEMORY_TS, RT_MEMORY_HBM, RT_MEMORY_TS | RT_MEMORY_POLICY_HUGE_PAGE_ONLY + */ +RTS_API uint32_t rtGetTsMemType(rtMemRequestFeature_t featureType, uint32_t memSize); + +/** + * @ingroup + * @brief set saturation mode for current device. + * @param [in] floatOverflowMode saturation mode. + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtSetDeviceSatMode(rtFloatOverflowMode_t floatOverflowMode); + +/** + * @ingroup + * @brief get saturation mode for current device.
+ * @param [out] floatOverflowMode saturation mode. + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtGetDeviceSatMode(rtFloatOverflowMode_t *floatOverflowMode); + +/** + * @ingroup + * @brief get saturation mode for target stream. + * @param [in] stm target stream + * @param [out] floatOverflowMode saturation mode. + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtGetDeviceSatModeForStream(rtStream_t stm, rtFloatOverflowMode_t *floatOverflowMode); + +/** + * @ingroup + * @brief get visible deviceid of current process by logic deviceid + * @param [in] logicDeviceId user deviceid + * @param [out] *visibleDeviceId deviceid configured using the environment variable ASCEND_RT_VISIBLE_DEVICES + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetVisibleDeviceIdByLogicDeviceId(const int32_t logicDeviceId, int32_t * const visibleDeviceId); +/** + * @ingroup + * @brief get aicore/aivector/aicpu utilizations + * @param [in] devId the device id + * @param [in] kind util type + * @param [out] *util for aicore/aivector/aicpu + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetAllUtilizations(const int32_t devId, const rtTypeUtil_t kind, uint8_t * const util); +/** + * @ingroup + * @brief get serverid by sdid + * @param [in] sdid sdid + * @param [out] *srvId serverid + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetServerIDBySDID(uint32_t sdid, uint32_t *srvId); + +/** + * @ingroup dvrt_dev + * @brief set default device id + * @param [in] deviceId deviceId + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetDefaultDeviceId(const int32_t deviceId); + +/** + * @ingroup dvrt_dev + * @brief force reset all opened device + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDeviceResetForce(int32_t devId); + +/** + * @ingroup + *
@brief set failure mode for current device + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtSetDeviceFailureMode(uint64_t failureMode); + +/** + * @ingroup + * @brief get system param option's value + * @param [in] configOpt system option to be get value + * @param [out] configVal system option's value to be get + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtGetSysParamOpt(const rtSysParamOpt configOpt, int64_t * const configVal); + +/** + * @ingroup + * @brief set system param option + * @param [in] configOpt system option to be set + * @param [in] configVal system option's value to be set + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtSetSysParamOpt(const rtSysParamOpt configOpt, const int64_t configVal); + +#if defined(__cplusplus) +} +#endif + +#endif // CCE_RUNTIME_DEVICE_H diff --git a/inc/runtime/runtime/event.h b/inc/runtime/runtime/event.h new file mode 100644 index 0000000000000000000000000000000000000000..426cc59085cedf140f04112ae7640cae47d28168 --- /dev/null +++ b/inc/runtime/runtime/event.h @@ -0,0 +1,501 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. 
+ * Description: event.h + * Create: 2020-01-01 + */ + +#ifndef CCE_RUNTIME_EVENT_H +#define CCE_RUNTIME_EVENT_H + +#include "base.h" + +#if defined(__cplusplus) +extern "C" { +#endif + +#define RT_IPCINT_MSGLEN_MAX (0x8U) + +typedef enum rtEventWaitStatus { + EVENT_STATUS_COMPLETE = 0, + EVENT_STATUS_NOT_READY = 1, + EVENT_STATUS_MAX = 2, +} rtEventWaitStatus_t; + +typedef enum rtEventStatus { + RT_EVENT_INIT = 0, + RT_EVENT_RECORDED = 1, +} rtEventStatus_t; + +typedef struct tagIpcIntNoticeInfo { + uint32_t ntcPid; + uint16_t ntcGrpId; + uint16_t ntcTid; + uint16_t ntcSubEventId; + uint8_t ntcEventId; + uint8_t msgLen; + uint8_t msg[RT_IPCINT_MSGLEN_MAX]; + uint16_t phyDevId; +} rtIpcIntNoticeInfo_t; + +typedef enum { + RT_STREAM_ID = 0, + RT_EVENT_ID, + RT_MODEL_ID, + RT_NOTIFY_ID, + RT_CMO_ID, + RT_CNT_NOTIFY_ID, /* add start ascend910d */ + RT_INVALID_ID, +} rtIdType_t; + +/** + * @ingroup event_flags + * @brief event op bit flags + */ +#define RT_EVENT_DDSYNC_NS 0x01U +#define RT_EVENT_STREAM_MARK 0x02U +#define RT_EVENT_DDSYNC 0x04U +#define RT_EVENT_TIME_LINE 0x08U +#define RT_EVENT_MC2 0x10U // RT_EVENT_MC2 does not support OR with other flags + +#define RT_EVENT_DEFAULT (RT_EVENT_DDSYNC | RT_EVENT_TIME_LINE | RT_EVENT_STREAM_MARK) +#define RT_EVENT_WITH_FLAG (RT_EVENT_DDSYNC_NS) + +#define RT_NOTIFY_FLAG_DEFAULT (0x00U) +#define RT_NOTIFY_FLAG_DOWNLOAD_TO_DEV (0x01U) // RT_NOTIFY_FLAG_DOWNLOAD_TO_DEV does not support OR with other flags +#define RT_NOTIFY_FLAG_SHR_ID_SHADOW (0x1U << 6) +#define RT_NOTIFY_FLAG_MAX \ + (RT_NOTIFY_FLAG_DOWNLOAD_TO_DEV |RT_NOTIFY_FLAG_SHR_ID_SHADOW) +/** + * @ingroup notify_flags + * @brief notify op bit flags + */ +#define RT_NOTIFY_DEFAULT 0x00U +#define RT_NOTIFY_MC2 0x01U // RT_NOTIFY_MC2 does not support OR with other flags +#define RT_DMS_MAX_EVENT_NAME_LENGTH 256 +#define RT_DMS_MAX_EVENT_DATA_LENGTH 32 +#define RT_DMS_MAX_EVENT_RESV_LENGTH 32 +#define RT_DSM_EVENT_FILTER_FLAG_PID (1UL << 3) +#define 
RT_MAX_RECORD_PA_NUM_PER_DEV 20U + +typedef struct { + uint64_t ptr; + uint64_t len; +} rtMemRepairAddr; + +typedef struct { + uint32_t devid; + uint32_t count; + rtMemRepairAddr repairAddr[RT_MAX_RECORD_PA_NUM_PER_DEV]; +} rtMemUceInfo; + +typedef struct tagDmsEventFilter { + uint64_t filterFlag; + uint32_t eventId; + unsigned char severity; + unsigned char nodeType; + unsigned char resv[RT_DMS_MAX_EVENT_RESV_LENGTH]; /**< reserve 32byte */ +} rtDmsEventFilter; + +typedef struct tagDmsFaultEvent { + uint64_t alarmRaisedTime; + uint32_t eventId; + int32_t tgid; + int32_t eventSerialNum; + int32_t notifySerialNum; + uint16_t deviceId; + uint16_t nodeType; + uint16_t subNodeType; + unsigned char nodeId; + unsigned char subNodeId; + unsigned char severity; + unsigned char assertion; + char eventName[RT_DMS_MAX_EVENT_NAME_LENGTH]; + char additionalInfo[RT_DMS_MAX_EVENT_DATA_LENGTH]; + unsigned char osId; + unsigned char resv[RT_DMS_MAX_EVENT_RESV_LENGTH]; /* reserve 32byte */ +} rtDmsFaultEvent; + +typedef struct tagNotifyPhyInfo { + uint32_t phyId; /* phy id */ + uint32_t tsId; /* ts id */ + uint32_t idType; /* SHR_ID_NOTIFY_TYPE */ + uint32_t shrId; /* notify id */ + uint32_t flag; /* RT_NOTIFY_FLAG_SHR_ID_SHADOW for remote id or shadow node */ + uint32_t rsv[3]; +} rtNotifyPhyInfo; + +/** + * @ingroup dvrt_event + * @brief create event instance + * @param [in|out] event created event + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtEventCreate(rtEvent_t *evt); + +/** + * @ingroup dvrt_event + * @brief create event instance with flag + * @param [in|out] event created event flag event op flag + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtEventCreateWithFlag(rtEvent_t *evt, uint32_t flag); + +/** + * @ingroup dvrt_event + * @brief create event instance with flag for single mode + * @param [in|out] event created event flag event op flag + * 
@return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtEventCreateExWithFlag(rtEvent_t *evt, uint32_t flag); + +/** + * @ingroup dvrt_event + * @brief destroy event instance + * @param [in] evt event to destroy + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtEventDestroy(rtEvent_t evt); + +/** + * @ingroup dvrt_event + * @brief synchronize destroy event instance + * @param [in] evt event to destroy + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtEventDestroySync(rtEvent_t evt); + +/** + * @ingroup dvrt_event + * @brief get event id + * @param [in] evt event to be get + * @param [in|out] evtId event id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetEventID(rtEvent_t evt, uint32_t *evtId); + +/** + * @ingroup dvrt_event + * @brief event record + * @param [in] evt event to record + * @param [in] stm stream handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtEventRecord(rtEvent_t evt, rtStream_t stm); + +/** + * @ingroup dvrt_event + * @brief event reset + * @param [in] evt event to reset + * @param [in] stm stream handle + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtEventReset(rtEvent_t evt, rtStream_t stm); + +/** + * @ingroup dvrt_event + * @brief wait event to be complete + * @param [in] evt event to wait + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtEventSynchronize(rtEvent_t evt); + +/** + * @ingroup dvrt_event + * @brief wait event to be complete + * @param [in] evt event to wait + * @param [in] timeout event wait timeout + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_EVENT_SYNC_TIMEOUT for timeout + */ +RTS_API
rtError_t rtEventSynchronizeWithTimeout(rtEvent_t evt, const int32_t timeout); + +/** + * @ingroup dvrt_event + * @brief Queries an event's status + * @param [in] evt event to query + * @return RT_ERROR_NONE for complete + * @return RT_ERROR_EVENT_NOT_COMPLETE for not complete + */ +RTS_API rtError_t rtEventQuery(rtEvent_t evt); + +/** + * @ingroup dvrt_event + * @brief Queries an event's wait status + * @param [in] evt event to query + * @param [in out] EVENT_WAIT_STATUS status + * @return EVENT_STATUS_COMPLETE for complete + * @return EVENT_STATUS_NOT_READY for not complete + */ +RTS_API rtError_t rtEventQueryWaitStatus(rtEvent_t evt, rtEventWaitStatus_t *status); + +/** + * @ingroup dvrt_event + * @brief Queries an event's status + * @param [in] evt event to query + * @param [in out] rtEventStatus_t status + * @return RT_EVENT_RECORDED for recorded + * @return RT_EVENT_INIT for not recorded + */ +RTS_API rtError_t rtEventQueryStatus(rtEvent_t evt, rtEventStatus_t *status); + +/** + * @ingroup dvrt_event + * @brief computes the elapsed time between events. + * @param [in] timeInterval time between start and end in ms + * @param [in] startEvent starting event + * @param [in] endEvent ending event + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtEventElapsedTime(float32_t *timeInterval, rtEvent_t startEvent, rtEvent_t endEvent); + +/** + * @ingroup dvrt_event + * @brief get the elapsed time from a event after event recorded. 
+ * @param [in] timeStamp time in ms + * @param [in] evt event handle + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtEventGetTimeStamp(uint64_t *timeStamp, rtEvent_t evt); + +/** + * @ingroup dvrt_event + * @brief name an event + * @param [in] evt event to be named + * @param [in] name identification name + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input of event, name + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtNameEvent(rtEvent_t evt, const char_t *name); + +/** + * @ingroup + * @brief get fault event . + * @param [in] deviceId device id + * @param [in] filter filter condition:PID + * @param [in] len output length + * @param [out] dmsEvent return dms event struct array + * @param [out] eventCount return event count + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetFaultEvent(const int32_t deviceId, rtDmsEventFilter *filter, rtDmsFaultEvent *dmsEvent, + uint32_t len, uint32_t *eventCount); + +/** +* @ingroup dvrt_mem +* @brief Get memUceInfo. +* @attention Only support ONLINE scene. +* @param [in] deviceId device id. +* @param [out] memUceInfo Returned memUceInfo. +* @return DRV_ERROR_NONE : success +* @return DV_ERROR_XXX : fail +*/ +RTS_API rtError_t rtGetMemUceInfo(const uint32_t deviceId, rtMemUceInfo *memUceInfo); + +/** +* @ingroup dvrt_mem +* @brief Repair Ucemem. +* @attention Only support ONLINE scene. +* @param [in] deviceId device id. +* @param [out] memUceInfo Returned memUceInfo. 
+* @return DRV_ERROR_NONE : success +* @return DV_ERROR_XXX : fail +*/ +RTS_API rtError_t rtMemUceRepair(const uint32_t deviceId, rtMemUceInfo *memUceInfo); + +/** + * @ingroup dvrt_event + * @brief Create a notify + * @param [in] device_id device id + * @param [in|out] notify_ notify to be created + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtNotifyCreate(int32_t deviceId, rtNotify_t *notify); + + +/** + * @ingroup dvrt_event + * @brief Create a notify + * @param [in] device_id device id + * @param [in|out] notify_ notify to be created + * @param [in] flag flag notify flag + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtNotifyCreateWithFlag(int32_t deviceId, rtNotify_t *notify, uint32_t flag); + +/** + * @ingroup dvrt_event + * @brief Destroy a notify + * @param [in] notify_ notify to be destroyed + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtNotifyDestroy(rtNotify_t notify); + +/** + * @ingroup dvrt_event + * @brief Record a notify + * @param [in] notify_ notify to be recorded + * @param [in] stream_ input stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_STREAM_CONTEXT for stream is not in current ctx + */ +RTS_API rtError_t rtNotifyRecord(rtNotify_t notify, rtStream_t stm); + +/** + * @ingroup dvrt_event + * @brief Reset a notify + * @param [in] notify_ notify to be destroyed + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtNotifyReset(rtNotify_t notify); + +/** + * @ingroup dvrt_event + * @brief Resource clean + * @param [in] devId deviceId + * @param [in] type typeId + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtResourceClean(int32_t devId, 
rtIdType_t type); + +/** + * @ingroup dvrt_event + * @brief Wait for a notify + * @param [in] notify notify to wait for + * @param [in] stm input stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_STREAM_CONTEXT for stream is not in current ctx + */ +RTS_API rtError_t rtNotifyWait(rtNotify_t notify, rtStream_t stm); + +/** + * @ingroup dvrt_event + * @brief Wait for a notify with time out + * @param [in] notify notify to wait for + * @param [in] stm input stream + * @param [in] timeOut input timeOut + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_STREAM_CONTEXT for stream is not in current ctx + */ +RTS_API rtError_t rtNotifyWaitWithTimeOut(rtNotify_t notify, rtStream_t stm, uint32_t timeOut); + +/** + * @ingroup dvrt_event + * @brief get notify id + * @param [in] notify notify to be get + * @param [in|out] notifyId notify id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetNotifyID(rtNotify_t notify, uint32_t *notifyId); + +/** + * @ingroup dvrt_event + * @brief Get notify phy info + * @param [in] notify the created/opened notify + * @param [out] phyDevId phy device id + * @param [out] tsId ts id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtNotifyGetPhyInfo(rtNotify_t notify, uint32_t *phyDevId, uint32_t *tsId); + +/** + * @ingroup dvrt_event + * @brief Get notify phy and pod info + * @param [in] notify the created/opened notify + * @param [out] notifyInfo notify phy info (rtNotifyPhyInfo) + * NOTE(review): which fields are filled is not visible in this header - confirm + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtNotifyGetPhyInfoExt(rtNotify_t notify, rtNotifyPhyInfo *notifyInfo); + +/** + * @ingroup dvrt_event + * @brief Set a notify to IPC notify + * @param [in] notify notify to be set to IPC notify + * @param [in] name
identification name + * @param [in] len length of name + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtIpcSetNotifyName(rtNotify_t notify, char_t *name, uint32_t len); + +/** + * @ingroup dvrt_event + * @brief Open IPC notify + * @param [out] notify the opened notify + * @param [in] name identification name + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtIpcOpenNotify(rtNotify_t *notify, const char_t *name); + +/** + * @ingroup dvrt_event + * @brief Open IPC notify + * @param [out] notify the opened notify + * @param [in] name identification name + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtIpcOpenNotifyWithFlag(rtNotify_t *notify, const char_t *name, uint32_t flag); + +/** + * @ingroup dvrt_event + * @brief Get the physical address corresponding to notify + * @param [in] notify notify to be queried + * @param [in] devAddrOffset device physical address offset + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtNotifyGetAddrOffset(rtNotify_t notify, uint64_t *devAddrOffset); + +/** + * @ingroup dvrt_event + * @brief Ipc set notify pid + * @param [in] name name to be queried + * @param [in] pid process id + * @param [in] num length of pid[] + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtSetIpcNotifyPid(const char_t *name, int32_t pid[], int32_t num); + +/** + * @ingroup dvrt_event + * @brief Create an ipc interrupt notice task + * @param [in] ipcIntNoticeInfo ipcIntNotice info to be sent + * @param [in] stm stream handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtIpcIntNotice(const rtIpcIntNoticeInfo_t * 
const ipcIntNoticeInfo, rtStream_t stm); + +#if defined(__cplusplus) +} +#endif + +#endif // CCE_RUNTIME_EVENT_H diff --git a/inc/runtime/runtime/kernel.h b/inc/runtime/runtime/kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..4532496b46e49cec83ad4ba929c7ca82f4ea6966 --- /dev/null +++ b/inc/runtime/runtime/kernel.h @@ -0,0 +1,1180 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + * Description: kernel.h + * Create: 2020-01-01 + */ + +#ifndef CCE_RUNTIME_KERNEL_H +#define CCE_RUNTIME_KERNEL_H + +#include "base.h" +#include "stream.h" +#include "rt_stars_define.h" + +#if defined(__cplusplus) +extern "C" { +#endif + +/** + * @ingroup rt_kernel + * @brief shared memory data control + */ +typedef struct tagRtSmData { + uint64_t L2_mirror_addr; // preload or swap source addr + uint32_t L2_data_section_size; // every data size + uint8_t L2_preload; // 1 - preload from mirrorAddr, 0 - no preload + uint8_t modified; // 1 - data will be modified by kernel, 0 - no modified + uint8_t priority; // data priority + int8_t prev_L2_page_offset_base; // remap source section offset + uint8_t L2_page_offset_base; // remap destination section offset + uint8_t L2_load_to_ddr; // 1 - need load out, 0 - no need + uint8_t reserved[2]; // reserved +} rtSmData_t; + +/** + * @ingroup rt_kernel + * @brief shared memory description + */ +typedef struct tagRtSmCtrl { + rtSmData_t data[8]; // data description + uint64_t size; // max page Num + uint8_t remap[64]; /* just using for static remap mode, default:0xFF + array index: virtual l2 page id, array value: physic l2 page id */ + uint8_t l2_in_main; // 0-DDR, 1-L2, default:0xFF + uint8_t reserved[3]; +} rtSmDesc_t; + +typedef rtSmDesc_t rtL2Ctrl_t; + +/** + * @ingroup rt_kernel + * @brief device binary type + */ +typedef struct tagRtDevBinary { + uint32_t magic; // magic number + uint32_t version; // version of binary + const void *data; // binary data + uint64_t length; // 
binary length +} rtDevBinary_t; + +/** + * @ingroup rt_kernel + * @brief function mode type + */ +#define ONLINE_PROF_MAX_PMU_NUM (8) + +typedef struct ProfilefDataInfo { + const void *stubFunc; + uint32_t blockDim; + const void *args; + uint32_t argsSize; + rtSmDesc_t *smDesc; + rtStream_t stream; + uint64_t totalcycle; + uint64_t ovcycle; + uint64_t pmu_cnt[ONLINE_PROF_MAX_PMU_NUM]; +} rtProfDataInfo_t; + +/** + * @ingroup rt_kernel + * @brief function mode type + */ +typedef enum { + FUNC_MODE_NORMAL = 0, + FUNC_MODE_PCTRACE_USERPROFILE_RECORDLOOP, + FUNC_MODE_PCTRACE_USERPROFILE_SKIPLOOP, + FUNC_MODE_PCTRACE_CYCLECNT_RECORDLOOP, + FUNC_MODE_PCTRACE_CYCLECNT_SKIPLOOP, + FUNC_MODE_BUTT +} rtFuncModeType_t; + +/** + * @ingroup rt_kernel + * @brief kernel info + */ +typedef struct rtKernelInfo { + uint64_t task_offset; // kernel offset in module + /* flowtable */ + void *arg; // launch kernel arg + uint32_t arg_size; + /* module */ + void *module_addr; // module::baseaddr_ + uint32_t module_size; +} *rtKernelInfo_t; + +/** + * @ingroup rt_kernel + * @brief op name + */ +typedef struct rtKernelLaunchNames { + const char_t *soName; // defined for so name + const char_t *kernelName; // defined for kernel type name + const char_t *opName; // defined for operator name +} rtKernelLaunchNames_t; + +typedef struct rtFunctionInfo { + void *pcAddr; + uint32_t prefetchCnt; + uint8_t mixType; // 0:NO_MIX; 1:MIX_AIC; 2:MIX_AIV; 3:MIX_AIC_AIV + uint8_t reserved[3]; +} rtFunctionInfo_t; + +typedef struct tagRtKernelInfo { + uint8_t functionInfoNum; + uint8_t reserved[3]; + rtFunctionInfo_t functionInfo[2]; +} rtKernelDetailInfo_t; + +/** + * @ingroup rt_kernel + * @brief args struct + */ +typedef struct tagRtArgsWithTiling { + void *args; // args host mem addr + uint32_t argsSize; // input + output + tiling addr size + tiling data size + uint32_t argsSizeWithoutTiling; // input + output + tiling addr size + uint16_t tilingAddrOffset; // tiling addr offset + uint16_t 
tilingDataOffset; // tiling data offset + uint16_t hostInputAddrOffset; // index of host_memory input in inputs_addrs list + uint16_t hostInputDataOffset; // host_mem input data offset + uint8_t hasHostMemInput; // has host_memory input data in args or not: 0 means no host_memory input data, + // others means has host_memory input data. + uint8_t isNoNeedH2DCopy; // is no need host to device copy: 0 means need H2D copy, + // others means doesn't need H2D copy. + uint8_t reserved[6]; +} rtArgsWithTiling_t; + +/** + * @ingroup rt_kernel + * @brief host memory input struct + */ +typedef struct rtHostInputInfo { + uint32_t addrOffset; + uint32_t dataOffset; +} rtHostInputInfo_t; + +/** + * @ingroup rt_kernel + * @brief args struct + */ +typedef struct tagRtArgsEx { + void *args; // args host mem addr + rtHostInputInfo_t *hostInputInfoPtr; // nullptr means no host mem input + uint32_t argsSize; // input + output + tiling addr size + tiling data size + host mem + uint32_t tilingAddrOffset; // tiling addr offset + uint32_t tilingDataOffset; // tiling data offset + uint16_t hostInputInfoNum; // hostInputInfo num + uint8_t hasTiling; // if has tiling: 0 means no tiling + uint8_t isNoNeedH2DCopy; // is no need host to device copy: 0 means need H2D copy, + // others means doesn't need H2D copy. 
+ uint8_t reserved[4]; +} rtArgsEx_t; + +typedef struct tagRtAicpuArgsEx { + void *args; // args host mem addr + rtHostInputInfo_t *hostInputInfoPtr; // nullptr means no host mem input + rtHostInputInfo_t *kernelOffsetInfoPtr; // KernelOffsetInfo, it is different for CCE Kernel and fwk kernel + uint32_t argsSize; + uint16_t hostInputInfoNum; // hostInputInfo num + uint16_t kernelOffsetInfoNum; // KernelOffsetInfo num + uint32_t soNameAddrOffset; // just for CCE Kernel, default value is 0xffff for FWK kernel + uint32_t kernelNameAddrOffset; // just for CCE Kernel, default value is 0xffff for FWK kernel + bool isNoNeedH2DCopy; // is no need host to device copy: 0 means need H2D copy, + // other means doesn't need H2D copy. + uint16_t timeout; // timeout for aicpu exit + uint8_t reserved[1]; +} rtAicpuArgsEx_t; + +typedef struct tagRtDvppTaskDesc { + rtStarsCommonSqe_t sqe; + uint16_t aicpuTaskPos ; // rtsq max dep is 1024 + uint16_t reserved; +} rtDvppTaskDesc_t; + +typedef struct tagRtAicpuTaskDesc { + rtKernelLaunchNames_t kernelLaunchNames; + uint16_t blockDim; + uint16_t isUnderstudyOp : 1; // dvpp op exist, set 1; otherwise set 0 + uint16_t resverved : 15; + rtArgsEx_t argsInfo; +} rtAicpuTaskDesc_t; + +typedef enum tagRtMultipleTaskType { + RT_MULTIPLE_TASK_TYPE_DVPP = 0, + RT_MULTIPLE_TASK_TYPE_AICPU = 1, + RT_MULTIPLE_TASK_TYPE_MAX +} rtMultipleTaskType_t; + +typedef struct tagRtTaskDesc { + rtMultipleTaskType_t type; // only support AICPU or DVPP, will be checked in runtime api_error. 
+ union { + rtDvppTaskDesc_t dvppTaskDesc; + rtAicpuTaskDesc_t aicpuTaskDesc; + } u; +} rtTaskDesc_t; + +typedef struct tagRtMultipleTaskInfo { + uint32_t taskNum; + rtTaskDesc_t *taskDesc; // must memset0 after new obj +} rtMultipleTaskInfo_t; + +/** + * @ingroup rt_KernelConfigDump + * @brief device dump type + */ +typedef enum tagRtDumpKind { + RT_DATA_DUMP_KIND_INVALID = -1, + RT_DATA_DUMP_KIND_DUMP = 0, + RT_DATA_DUMP_KIND_RESERVED = 1, +} rtDumpKind_t; + +/** + * @ingroup rt_kernel + * @brief rt kernel type + */ +typedef enum rtKernelType { + KERNEL_TYPE_CCE = 0, + KERNEL_TYPE_FWK = 1, + KERNEL_TYPE_AICPU = 2, + KERNEL_TYPE_AICPU_CUSTOM = 4, + KERNEL_TYPE_AICPU_KFC = 5, + KERNEL_TYPE_HWTS = 10, + KERNEL_TYPE_RESERVED = 99, +} rtKernelType_t; + +/** + * @ingroup rt_kernel + * @brief report callback + */ +typedef rtError_t (*rtKernelReportCallback)(rtStream_t stm, rtKernelInfo_t kernelInfo); + +/** + * @ingroup rt_kernel + * @brief stream report callback + */ +typedef void (*rtCallback_t)(void *fnData); + +/** + * @ingroup rt_kernel + * @brief magic number of plain binary for aicore + */ +#define RT_DEV_BINARY_MAGIC_PLAIN 0xabceed50U + +/** + * @ingroup rt_kernel + * @brief magic number of plain binary for aicpu + */ +#define RT_DEV_BINARY_MAGIC_PLAIN_AICPU 0xabceed51U + +/** + * @ingroup rt_kernel + * @brief magic number of plain binary for aivector + */ +#define RT_DEV_BINARY_MAGIC_PLAIN_AIVEC 0xabceed52U + +/** + * @ingroup rt_kernel + * @brief magic number of elf binary for aicore + */ +#define RT_DEV_BINARY_MAGIC_ELF 0x43554245U + +/** + * @ingroup rt_kernel + * @brief magic number of elf binary for aicpu + */ +#define RT_DEV_BINARY_MAGIC_ELF_AICPU 0x41415243U + +/** + * @ingroup rt_kernel + * @brief magic number of elf binary for aivector + */ +#define RT_DEV_BINARY_MAGIC_ELF_AIVEC 0x41415246U + +/** + * @ingroup rt_kernel + * @brief magic number of elf binary for aicube + */ +#define RT_DEV_BINARY_MAGIC_ELF_AICUBE 0x41494343U + +/** + * @ingroup 
rt_kernel_flags + * @brief kernel op bit flags + */ +#define RT_KERNEL_DEFAULT (0x00U) +#define RT_KERNEL_CONVERT (0x01U) +#define RT_KERNEL_DUMPFLAG (0x02U) +#define RT_FUSION_KERNEL_DUMPFLAG (0x04U) +#define RT_KERNEL_CUSTOM_AICPU (0x08U) +#define RT_KERNEL_FFTSPLUS_DYNAMIC_SHAPE_DUMPFLAG (0x10U) +#define RT_KERNEL_FFTSPLUS_STATIC_SHAPE_DUMPFLAG (0x20U) +// cmdlist does not need to be released by the runtime. +#define RT_KERNEL_CMDLIST_NOT_FREE (0x40U) +#define RT_KERNEL_USE_SPECIAL_TIMEOUT (0x100U) + +// STARS topic scheduler sqe : topic_type +#define RT_KERNEL_DEVICE_FIRST (0x10U) +#define RT_KERNEL_HOST_ONLY (0x20U) +#define RT_KERNEL_HOST_FIRST (0x40U) +#define RT_KERNEL_BIUPERF_FLAG (0x80U) + +/** + * @ingroup rt_kernel + * @brief kernel mode +**/ +#define RT_DEFAULT_KERNEL_MODE (0x00U) +#define RT_NORMAL_KERNEL_MODE (0x01U) +#define RT_ALL_KERNEL_MODE (0x02U) + +/** + * @ingroup rt_kernel + * @brief SHAPE kernel type +**/ +#define RT_STATIC_SHAPE_KERNEL (0x00U) +#define RT_DYNAMIC_SHAPE_KERNEL (0x01U) + +/** + * @ingroup rt_kernel + * @brief kernel L1 Fusion Dump bit flags + */ +#define RT_DDR_ADDR (0x0U) + +/** + * @ingroup rt_kernel + * @brief register device binary + * @param [in] bin device binary description + * @param [out] hdl device binary handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDevBinaryRegister(const rtDevBinary_t *bin, void **hdl); + +RTS_API rtError_t rtGetNotifyAddress(rtNotify_t notify, uint64_t * const notifyAddres); + +/** + * @ingroup rt_kernel + * @brief register device binary with all kernel + * @param [in] bin device binary description + * @param [out] hdl device binary handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtRegisterAllKernel(const rtDevBinary_t *bin, void **hdl); + +/** + * @ingroup rt_kernel + * @brief register fast memeory device binary + * @param [in] hdl device binary handle + * 
@return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtBinaryRegisterToFastMemory(void *hdl); + +/** + * @ingroup rt_kernel + * @brief unregister device binary + * @param [in] hdl device binary handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDevBinaryUnRegister(void *hdl); + +/** + * @ingroup rt_kernel + * @brief register device binary metadata + * @param [in] hdl device binary description + * @param [in] metadata device binary metadata + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMetadataRegister(void *hdl, const char_t *metadata); + +/** + * @ingroup rt_kernel + * @brief register device binary dependency + * @param [in] mHandle master device binary description + * @param [in] sHandle slave device binary description + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDependencyRegister(void *mHandle, void *sHandle); + +/** + * @ingroup rt_kernel + * @brief register device function + * @param [in] binHandle device binary handle + * @param [in] stubFunc stub function + * @param [in] stubName stub function name + * @param [in] kernelInfoExt kernel Info extension. device function description or tiling key, + * depending static shape or dynmaic shape. 
+ * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtFunctionRegister(void *binHandle, const void *stubFunc, const char_t *stubName, + const void *kernelInfoExt, uint32_t funcMode); + +/** + * @ingroup rt_kernel + * @brief find stub function by name + * @param [in] stubName stub function name + * @param [out] stubFunc stub function + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetFunctionByName(const char_t *stubName, void **stubFunc); + +/** + * @ingroup rt_kernel + * @brief find addr by stub func + * @param [in] stubFunc stub function + * @param [out] addr + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetAddrByFun(const void *stubFunc, void **addr); +/** + * @ingroup rt_kernel + * @brief query registered or not by stubName + * @param [in] stubName stub function name + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtQueryFunctionRegistered(const char_t *stubName); + +/** + * @ingroup rt_kernel + * @brief config data dump + * @param [in] dumpSizePerBlock dump size + * @param [in] blockDim block dimentions + * @param [in] dumpBaseAddr dump base addr + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtKernelConfigDump(uint32_t kind, uint32_t dumpSizePerBlock, uint32_t blockDim, void **dumpBaseAddr, + rtStream_t stm); + +/** +* @ingroup rt_kernel +* @brief get kernel address and prefetchCnt +* @param [in] hdl program for dynamic shape +* @param [in] tilingKey tilingKey for dynamic shape +* @param [in] stubFunc stubFunc for static shape +* @param [in] flag flag for distinguishing between dynamic shape and static shape +* @param [out] addr address of kernel function +* @param [out] prefetchCnt prefetchCnt of kernel function +* @return RT_ERROR_NONE for ok +* @return 
RT_ERROR_INVALID_VALUE for error input +*/ +RTS_API rtError_t rtKernelGetAddrAndPrefCnt(void *hdl, const uint64_t tilingKey, const void * const stubFunc, + const uint32_t flag, void **addr, uint32_t *prefetchCnt); + +/** +* @ingroup rt_kernel +* @brief get kernel address and prefetchCnt +* @param [in] hdl program for dynamic shape +* @param [in] tilingKey tilingKey for dynamic shape +* @param [in] stubFunc stubFunc for static shape +* @param [in] flag flag for distinguishing between dynamic shape and static shape +* @param [out] kernelInfo address & prefetchCnt of kernel function +* @return RT_ERROR_NONE for ok +* @return RT_ERROR_INVALID_VALUE for error input +*/ +RTS_API rtError_t rtKernelGetAddrAndPrefCntV2(void *hdl, const uint64_t tilingKey, const void * const stubFunc, + const uint32_t flag, rtKernelDetailInfo_t *kernelInfo); + +/** +* @ingroup rt_kernel +* @brief set input argments size for exception +* @param [in] sizeInfo argments size info +* @return RT_ERROR_NONE for ok +* @return RT_ERROR_INVALID_VALUE for error input +*/ +RTS_API rtError_t rtSetExceptionExtInfo(const rtArgsSizeInfo_t * const sizeInfo); + +/** + * @ingroup rt_kernel + * @brief launch kernel to device + * @param [in] stubFunc stub function + * @param [in] blockDim block dimentions + * @param [in] args argments address for kernel function + * @param [in] argsSize argements size + * @param [in] smDesc shared memory description + * @param [in] stm associated stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtKernelLaunch(const void *stubFunc, uint32_t blockDim, void *args, uint32_t argsSize, + rtSmDesc_t *smDesc, rtStream_t stm); + +/** + * @ingroup rt_kernel + * @brief launch kernel with handle to device + * @param [in] hdl program + * @param [in] tilingKey tilingKey + * @param [in] blockDim block dimentions + * @param [in] argsInfo argments address for kernel function + * @param [in] smDesc shared memory description + * 
@param [in] stm associated stream + * @param [in] kernelInfo kernel info + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtKernelLaunchWithHandle(void *hdl, const uint64_t tilingKey, uint32_t blockDim, + rtArgsEx_t *argsInfo, rtSmDesc_t *smDesc, rtStream_t stm, + const void *kernelInfo); + +/** + * @ingroup rt_kernel + * @brief launch kernel with handle to device + * @param [in] hdl program + * @param [in] tilingKey tilingKey + * @param [in] blockDim block dimentions + * @param [in] argsInfo argments address for kernel function + * @param [in] smDesc shared memory description + * @param [in] stm associated stream + * @param [in] cfgInfo task config + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtKernelLaunchWithHandleV2(void *hdl, const uint64_t tilingKey, uint32_t blockDim, + rtArgsEx_t *argsInfo, rtSmDesc_t *smDesc, rtStream_t stm, const rtTaskCfgInfo_t *cfgInfo); + +/** + * @ingroup rtKernelLaunchWithFlag + * @brief launch kernel to device + * @param [in] stubFunc stub function + * @param [in] blockDim block dimentions + * @param [in] argsInfo argments address for kernel function + * @param [in] smDesc shared memory description + * @param [in] stm associated stream + * @param [in] flags dump flag + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtKernelLaunchWithFlag(const void *stubFunc, uint32_t blockDim, rtArgsEx_t *argsInfo, + rtSmDesc_t *smDesc, rtStream_t stm, uint32_t flags); + +/** + * @ingroup rtKernelLaunchWithFlag + * @brief launch kernel to device + * @param [in] stubFunc stub function + * @param [in] blockDim block dimentions + * @param [in] argsInfo argments address for kernel function + * @param [in] smDesc shared memory description + * @param [in] stm associated stream + * @param [in] flags dump flag + * @param [in] cfgInfo task config info + * @return RT_ERROR_NONE 
for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtKernelLaunchWithFlagV2(const void *stubFunc, uint32_t blockDim, rtArgsEx_t *argsInfo, + rtSmDesc_t *smDesc, rtStream_t stm, uint32_t flags, const rtTaskCfgInfo_t *cfgInfo); + +/** + * @ingroup rt_kernel(abandoned) + * @brief launch kernel to device + * @param [in] args arguments address for kernel function + * @param [in] argsSize arguments size + * @param [in] flags launch flags + * @param [in] stm associated stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtKernelLaunchEx(void *args, uint32_t argsSize, uint32_t flags, rtStream_t stm); + +/** + * @ingroup rt_kernel(in use) + * @brief launch kernel to device + * @param [in] opName opkernel name + * @param [in] args arguments address for kernel function + * @param [in] argsSize arguments size + * @param [in] flags launch flags + * @param [in] rtStream associated stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtKernelLaunchFwk(const char_t *opName, void *args, uint32_t argsSize, uint32_t flags, + rtStream_t rtStream); + +/** + * @ingroup rtCpuKernelLaunchWithFlag(abandoned) + * @brief launch cpu kernel to device with dump identifier + * @param [in] soName so name + * @param [in] kernelName kernel name + * @param [in] blockDim block dimensions + * @param [in] argsInfo arguments address for kernel function + * @param [in] smDesc shared memory description + * @param [in] stm associated stream + * @param [in] flags dump flag or other function flags + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtCpuKernelLaunchWithFlag(const void *soName, const void *kernelName, uint32_t blockDim, + const rtArgsEx_t *argsInfo, rtSmDesc_t *smDesc, rtStream_t stm, + uint32_t flags); + +/** + * @ingroup rtAicpuKernelLaunchWithFlag(in use) + * @brief launch cpu kernel to 
device with dump identifier + * @param [in] launchNames names for kernel launch + * @param [in] blockDim block dimensions + * @param [in] argsInfo arguments address for kernel function + * @param [in] smDesc shared memory description + * @param [in] stm associated stream + * @param [in] flags dump flag or other function flags + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtAicpuKernelLaunchWithFlag(const rtKernelLaunchNames_t *launchNames, uint32_t blockDim, + const rtArgsEx_t *argsInfo, rtSmDesc_t *smDesc, rtStream_t stm, + uint32_t flags); + +/** + * @ingroup rtAicpuKernelLaunchExWithArgs + * @brief launch cpu kernel to device with dump identifier and kernelType + * @param [in] kernelType aicpu kernel type + * @param [in] opName address of op name + * @param [in] blockDim block dimensions + * @param [in] argsInfo arguments address for kernel function + * @param [in] smDesc shared memory description + * @param [in] stm associated stream + * @param [in] flags dump flag or other function flags + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtAicpuKernelLaunchExWithArgs(const uint32_t kernelType, const char_t * const opName, + const uint32_t blockDim, const rtAicpuArgsEx_t *argsInfo, + rtSmDesc_t * const smDesc, const rtStream_t stm, + const uint32_t flags); + +/** + * @ingroup rt_kernel + * @brief L1 fusion dump addr transferred to device + * @param [in] mdl handle info + * @param [in] addr ddr address of L1 Fusion Dump + * @param [in] dumpSize memory size + * @param [in] flag memory flag + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDumpAddrSet(rtModel_t mdl, void *addr, uint32_t dumpSize, uint32_t flag); + +/** + * @ingroup rt_kernel + * @brief load dump info to aicpu + * @param [in] dumpInfo dump info + * @param [in] length length of dump info + * @return RT_ERROR_NONE for ok + * 
@return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDatadumpInfoLoad(const void *dumpInfo, uint32_t length); + +/** + * @ingroup rt_kernel + * @brief load aicpu info + * @param [in] aicpuInfo aicpu info + * @param [in] length length of aicpu info + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtAicpuInfoLoad(const void *aicpuInfo, uint32_t length); + +/** + * @ingroup rt_kernel + * @brief load dump info to aicpu + * @param [in] dumpInfo dump info + * @param [in] length length of dump info + * @param [in] flag RT_KERNEL_DEFAULT or RT_KERNEL_CUSTOM_AICPU + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDatadumpInfoLoadWithFlag(const void *dumpInfo, const uint32_t length, const uint32_t flag); + +/** + * @ingroup rt_kernel + * @brief launch npu get float status task + * @param [in] outputAddrPtr pointer to op output addr + * @param [in] outputSize op output size + * @param [in] checkMode check mode + * @param [in] stm associated stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtNpuGetFloatStatus(void *outputAddrPtr, uint64_t outputSize, uint32_t checkMode, rtStream_t stm); + +/** + * @ingroup rt_kernel + * @brief launch npu clear float status task + * @param [in] checkMode check mode + * @param [in] stm associated stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtNpuClearFloatStatus(uint32_t checkMode, rtStream_t stm); + +/** + * @ingroup rt_kernel + * @brief launch npu get float status task + * @param [in] outputAddrPtr pointer to op output addr + * @param [in] outputSize op output size + * @param [in] checkMode check mode + * @param [in] stm associated stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtNpuGetFloatDebugStatus(void 
*outputAddrPtr, uint64_t outputSize, uint32_t checkMode, + rtStream_t stm); + +/** + * @ingroup rt_kernel + * @brief launch npu clear float status task + * @param [in] checkMode check mode + * @param [in] stm associated stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtNpuClearFloatDebugStatus(uint32_t checkMode, rtStream_t stm); + +#ifndef __CLANG_CCE_RUNTIME_H__ +#define __CLANG_CCE_RUNTIME_H__ +/** + * @ingroup rt_kernel + * @brief configure call argment for next rtLaunch in current thread + * @param [in] numBlocks block dimentions + * @param [in] smDesc shared memory description + * @param [in] stm associated stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +#ifdef __cplusplus +RTS_API rtError_t rtConfigureCall(uint32_t numBlocks, rtSmDesc_t *smDesc = nullptr, rtStream_t stm = nullptr); +#else +RTS_API rtError_t rtConfigureCall(uint32_t numBlocks, rtSmDesc_t *smDesc, rtStream_t stm); + +#endif +#endif // __CLANG_CCE_RUNTIME_H__ + +/** + * @ingroup rt_kernel + * @brief setup argment for next rtLaunch in current thread + * @param [in] args argment address for kernel function + * @param [in] size argment size + * @param [in] offset argment table offset + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetupArgument(const void *args, uint32_t size, uint32_t offset); + +/** + * @ingroup rt_kernel + * @brief launch kernel to device with previous setting kernel argment + * and call argment + * @param [in] stubFunc stub function + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtLaunch(const void *stubFunc); + +/** + * @ingroup rt_kernel + * @brief implicitly transfered data to device. + * lifecycle end after next kernel task finish + * @param [in] ptr host memory + * @param [in] size host memory size + * @param [in] flag reserved. 
set to 0 + * @param [out] args returned arg, used as an argument of the next kernel. + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtKernelConfigTransArg(const void *ptr, uint64_t size, uint32_t flag, void **args); + +/** + * @ingroup rt_kernel + * @brief start fusion kernels. + * @param [in] stm stream for fusion kernels + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtKernelFusionStart(rtStream_t stm); + +/** + * @ingroup rt_kernel + * @brief end fusion kernels. + * @param [in] stm stream for fusion kernels + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtKernelFusionEnd(rtStream_t stm); + +/** + * @ingroup rt_kernel + * @brief set kernelinfo callback + * @param [in] callBack kernel report callback (rtKernelReportCallback) to set + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetKernelReportCallback(rtKernelReportCallback callBack); + +/** + * @ingroup rt_kernel + * @brief subscribe stream callback report. + * @param [in] threadId thread id for stream + * @param [in] stm stream to subscribe + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSubscribeReport(uint64_t threadId, rtStream_t stm); + +/** + * @ingroup rt_kernel + * @brief add callback launch task in stream. + * @param [in] callBackFunc app callback function + * @param [in] fnData user data + * @param [in] stm subscribed stream + * @param [in] isBlock specifies whether the callback function invocation blocks the device, + * 0: non-blocking; 1: blocking + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtCallbackLaunch(rtCallback_t callBackFunc, void *fnData, rtStream_t stm, bool isBlock); + +/** + * @ingroup rt_kernel + * @brief process callback report. 
+ * @param [in] timeout if timeout=-1, while(1); else timeout + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtProcessReport(int32_t timeout); + +/** + * @ingroup rt_kernel + * @brief unsubscribe callback report. + * @param [in] threadId thread id for stream + * @param [in] stm stream for subscribe + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtUnSubscribeReport(uint64_t threadId, rtStream_t stm); + +/** + * @ingroup profiling_base + * @brief start online prof. + * @param [in] stm associated stream + * @param [in] sampleNum number of sample + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStartOnlineProf(rtStream_t stm, uint32_t sampleNum); + +/** + * @ingroup profiling_base + * @brief stop online prof. + * @param [in] stm associated stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStopOnlineProf(rtStream_t stm); + +/** + * @ingroup profiling_base + * @brief get online prof. + * @param [in] stm associated stream + * @param [in] profDataNum number of obtained prof data records, should be in (0, MAX_ONLINEPROF_NUM] + * @param [out] pProfData obtained prof data + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetOnlineProfData(rtStream_t stm, rtProfDataInfo_t *pProfData, uint32_t profDataNum); + +/** + * @ingroup profiling_base + * @brief start mdc profiler. + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStartMDCProfiler(void **addr, uint32_t length); + +/** + * @ingroup profiling_base + * @brief stop mdc profiler. 
+ * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStopMDCProfiler(void *addr); + +/** + * @ingroup rt_kernel + * @brief Calculate ArgsSize. + * @param [in] argsSize args Size + * @param [in] hostInfoTotalSize hostInfoTotal Size + * @param [in] hostInfoNum hostInfo num + * @param [out] launchArgsSize launch Args Size + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +rtError_t rtCalcLaunchArgsSize(size_t argsSize, size_t hostInfoTotalSize, size_t hostInfoNum, + size_t *launchArgsSize); + +/** + * @ingroup rt_kernel + * @brief Create Args Handle. + * @param [in] argsSize args Size + * @param [in] hostInfoTotalSize hostInfoTotal Size + * @param [in] hostInfoNum hostInfo num + * @param [in] argsData args Data + * @param [out] argsHandle args Handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +rtError_t rtCreateLaunchArgs(size_t argsSize, size_t hostInfoTotalSize, size_t hostInfoNum, + void* argsData, rtLaunchArgsHandle* argsHandle); + +/** + * @ingroup rt_kernel + * @brief Destroy Args Handle. + * @param [in] argsHandle args Handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +rtError_t rtDestroyLaunchArgs(rtLaunchArgsHandle argsHandle); + +/** + * @ingroup rt_kernel + * @brief Reset Args Handle Info. + * @param [in] argsHandle args Handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +rtError_t rtResetLaunchArgs(rtLaunchArgsHandle argsHandle); + +/** + * @ingroup rt_kernel + * @brief Append address info to Args Handle. + * @param [in] argsHandle args Handle + * @param [in] addrInfo address info + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +rtError_t rtAppendLaunchAddrInfo(rtLaunchArgsHandle argsHandle, void *addrInfo); + +/** + * @ingroup rt_kernel + * @brief Append Host info to args Handle. 
+ * @param [in] argsHandle args Handle + * @param [in] hostInfoSize host Info Size + * @param [out] hostInfo host info + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +rtError_t rtAppendLaunchHostInfo(rtLaunchArgsHandle argsHandle, size_t hostInfoSize, void **hostInfo); + +/** + * @ingroup rt_kernel + * @brief Registers and parses the bin file and loads it to the device. + * @param [in] bin device binary description + * @param [out] binHandle device binary handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +rtError_t rtBinaryLoad(const rtDevBinary_t *bin, rtBinHandle *binHandle); + +/** + * @ingroup rt_kernel + * @brief Find funcHandle based on binHandle and tilingKey. + * @param [in] binHandle binary handle + * @param [in] tilingKey tilingKey + * @param [out] funcHandle function handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +rtError_t rtBinaryGetFunction(const rtBinHandle binHandle, const uint64_t tilingKey, rtFuncHandle *funcHandle); + +/** + * @ingroup rt_kernel + * @brief UnLoad binary + * @param [in] binHandle Binary Handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +rtError_t rtBinaryUnLoad(rtBinHandle binHandle); + +/** + * @ingroup rt_kernel + * @brief Kernel Launch to device + * @param [in] funcHandle function Handle + * @param [in] blockDim block dimensions + * @param [in] argsHandle args Handle + * @param [in] stm associated stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +rtError_t rtLaunchKernelByFuncHandle(rtFuncHandle funcHandle, uint32_t blockDim, rtLaunchArgsHandle argsHandle, + rtStream_t stm); + +/** + * @ingroup rt_kernel + * @brief Kernel Launch to device + * @param [in] funcHandle function Handle + * @param [in] blockDim block dimensions + * @param [in] argsHandle args Handle + * @param [in] stm associated stream + * @param 
[in] cfgInfo task config info + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +rtError_t rtLaunchKernelByFuncHandleV2(rtFuncHandle funcHandle, uint32_t blockDim, rtLaunchArgsHandle argsHandle, + rtStream_t stm, const rtTaskCfgInfo_t *cfgInfo); + +/** + * @ingroup rt_kernel + * @brief get Saturation Status task + * @param [in] outputAddrPtr pointer to op output addr + * @param [in] outputSize op output size + * @param [in] stm associated stream + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtGetDeviceSatStatus(void * const outputAddrPtr, const uint64_t outputSize, rtStream_t stm); + +/** + * @ingroup rt_kernel + * @brief clear Saturation Status task + * @param [in] stm associated stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtCleanDeviceSatStatus(rtStream_t stm); + +/** + * @ingroup rt_kernel + * @brief Get Kernel Bin + * @param [in] binFileName binFileName + * @param [out] buffer bin buffer + * @param [out] length buffer length + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetKernelBin(const char_t * const binFileName, char_t **const buffer, uint32_t *length); + +/** + * @ingroup rt_kernel + * @brief Free Kernel Bin + * @param [in] buffer bin buffer + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtFreeKernelBin(char_t * const buffer); + +/** + * @ingroup dvrt_mem + * @brief HCCL copy ffts args + * @param [in] stm task stream + * @param [in] argsInfo args info + * @param [out] devArgsAddr device mem addr for args + * @param [out] argsHandle copy handler + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtGetDevArgsAddr(rtStream_t stm, rtArgsEx_t *argsInfo, void **devArgsAddr, void **argsHandle); + +/** + 
 * @ingroup rt_kernel + * @brief subscribe stream for hostFunc thread. + * @param [in] threadId thread id for stream + * @param [in] stream stream to subscribe + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSubscribeHostFunc(uint64_t threadId, rtStream_t stream); + +/** + * @ingroup rt_kernel + * @brief process hostFunc callback report and hostFunc callback function. + * @param [in] timeout if timeout=-1, while(1); else timeout + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtProcessHostFunc(int32_t timeout); + +/** + * @ingroup rt_kernel + * @brief unsubscribe hostFunc callback report. + * @param [in] threadId thread id for stream + * @param [in] stream stream to unsubscribe + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtUnSubscribeHostFunc(uint64_t threadId, rtStream_t stream); + +/** + * @ingroup rt_kernel + * @brief Find funcHandle based on binHandle and kernelName. + * @param [in] binHandle binary handle + * @param [in] kernelName kernel name + * @param [out] funcHandle function handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +rtError_t rtBinaryGetFunctionByName(const rtBinHandle binHandle, const char *kernelName, rtFuncHandle *funcHandle); + +/** + * @ingroup rt_kernel + * @brief Registers and parses the bin file and loads it to the device. 
+ * @param [in] data device binary data description + * @param [in] length device binary data length + * @param [out] binHandle device binary handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +rtError_t rtBinaryLoadWithoutTilingKey(const void *data, const uint64_t length, rtBinHandle *binHandle); + +/** + * @ingroup rt_kernel + * @brief Kernel Launch to device + * @param [in] funcHandle function Handle + * @param [in] blockDim block dimensions + * @param [in] argsInfo args info + * @param [in] stm associated stream + * @param [in] cfgInfo task config info + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +rtError_t rtLaunchKernelByFuncHandleV3(rtFuncHandle funcHandle, uint32_t blockDim, const rtArgsEx_t * const argsInfo, + rtStream_t stm, const rtTaskCfgInfo_t * const cfgInfo); + +/** + * @ingroup rt_kernel + * @brief launch vector kernel with handle to device + * @param [in] hdl program + * @param [in] tilingKey tilingKey + * @param [in] blockDim block dimensions + * @param [in] argsInfo arguments address for kernel function + * @param [in] smDesc shared memory description + * @param [in] stm associated stream + * @param [in] cfgInfo task config + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtVectorCoreKernelLaunchWithHandle(void *hdl, const uint64_t tilingKey, uint32_t blockDim, + rtArgsEx_t *argsInfo, rtSmDesc_t *smDesc, rtStream_t stm, const rtTaskCfgInfo_t *cfgInfo); + +/** + * @ingroup rt_kernel + * @brief Vector Kernel Launch to device + * @param [in] stubFunc stub function + * @param [in] blockDim block dimensions + * @param [in] argsInfo arguments address for kernel function + * @param [in] smDesc shared memory description + * @param [in] stm associated stream + * @param [in] flags launch flags (TODO confirm semantics; rtKernelLaunchWithFlagV2 documents its flags as dump flag) + * @param [in] cfgInfo task config info + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtVectorCoreKernelLaunch(const void *stubFunc, uint32_t blockDim, 
rtArgsEx_t *argsInfo, + rtSmDesc_t *smDesc, rtStream_t stm, uint32_t flags, const rtTaskCfgInfo_t *cfgInfo); + +#if defined(__cplusplus) +} +#endif + +#endif // CCE_RUNTIME_KERNEL_H + diff --git a/inc/runtime/runtime/mem.h b/inc/runtime/runtime/mem.h new file mode 100644 index 0000000000000000000000000000000000000000..bad8f311d644824e860c0e9b2bf8267bec93786b --- /dev/null +++ b/inc/runtime/runtime/mem.h @@ -0,0 +1,1076 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + * Description: mem.h + * Create: 2020-01-01 + */ + +#ifndef CCE_RUNTIME_MEM_H +#define CCE_RUNTIME_MEM_H + +#include +#include "base.h" +#include "config.h" +#include "stream.h" + +#if defined(__cplusplus) +extern "C" { +#endif + +/** + * @ingroup dvrt_mem + * @brief memory type + */ +#define RT_MEMORY_DEFAULT (0x0U) // default memory on device +#define RT_MEMORY_HBM (0x2U) // HBM memory on device +#define RT_MEMORY_RDMA_HBM (0x3U) // RDMA-HBM memory on device +#define RT_MEMORY_DDR (0x4U) // DDR memory on device +#define RT_MEMORY_SPM (0x8U) // shared physical memory on device +#define RT_MEMORY_P2P_HBM (0x10U) // HBM memory on other 4P device +#define RT_MEMORY_P2P_DDR (0x11U) // DDR memory on other device +#define RT_MEMORY_DDR_NC (0x20U) // DDR memory of non-cache +#define RT_MEMORY_TS (0x40U) // Used for Ts memory +#define RT_MEMORY_TS_4G (0x40U) // Used for Ts memory(only 51) +#define RT_MEMORY_HOST (0x81U) // Memory on host +#define RT_MEMORY_SVM (0x90U) // Memory for SVM +#define RT_MEMORY_HOST_SVM (0x90U) // Memory for host SVM +#define RT_MEMORY_RESERVED (0x100U) + +#define RT_MEMORY_L1 (0x1U << 16U) +#define RT_MEMORY_L2 (0x1U << 17U) + +/** + * @ingroup dvrt_mem + * @brief memory info type + */ +#define RT_MEM_INFO_TYPE_DDR_SIZE (0x1U) +#define RT_MEM_INFO_TYPE_HBM_SIZE (0x2U) +#define RT_MEM_INFO_TYPE_DDR_P2P_SIZE (0x3U) +#define RT_MEM_INFO_TYPE_HBM_P2P_SIZE (0x4U) + +/** + * @ingroup dvrt_mem + * @brief memory Policy + */ +#define 
RT_MEMORY_POLICY_NONE (0x0U) // Malloc prefers huge pages, then default pages +#define RT_MEMORY_POLICY_HUGE_PAGE_FIRST (0x400U) // Malloc prefers huge pages, then default pages, 0x1U << 10U +#define RT_MEMORY_POLICY_HUGE_PAGE_ONLY (0x800U) // Malloc uses huge pages only, 0x1U << 11U +#define RT_MEMORY_POLICY_DEFAULT_PAGE_ONLY (0x1000U) // Malloc uses default pages only, 0x1U << 12U +// Malloc prefers huge pages, then default pages, for p2p, 0x1U << 13U +#define RT_MEMORY_POLICY_HUGE_PAGE_FIRST_P2P (0x2000U) +#define RT_MEMORY_POLICY_HUGE_PAGE_ONLY_P2P (0x4000U) // Malloc uses huge pages only, for p2p, 0x1U << 14U +#define RT_MEMORY_POLICY_DEFAULT_PAGE_ONLY_P2P (0x8000U) // Malloc uses default pages only, for p2p, 0x1U << 15U + +/** + * @ingroup dvrt_mem + * @brief memory attribute + */ +#define RT_MEMORY_ATTRIBUTE_DEFAULT (0x0U) +// read-only memory attribute, currently only supported for dvpp memory. +#define RT_MEMORY_ATTRIBUTE_READONLY (0x100000U) // Malloc readonly, 1<<20. + +#define MEM_ALLOC_TYPE_BIT (0x3FFU) // mask of the memory type bits [0, 9] + +/** + * @ingroup dvrt_mem + * @brief memory type | memory Policy + */ +typedef uint32_t rtMemType_t; + +/** + * @ingroup dvrt_mem + * @brief memory advise type + */ +#define RT_MEMORY_ADVISE_EXE (0x02U) +#define RT_MEMORY_ADVISE_THP (0x04U) +#define RT_MEMORY_ADVISE_PLE (0x08U) +#define RT_MEMORY_ADVISE_PIN (0x16U) // NOTE(review): 0x16 is not a single bit (it overlaps 0x02|0x04); 0x10 may have been intended - confirm against the runtime before changing + +/** + * @ingroup dvrt_mem + * @brief memory copy type + */ +typedef enum tagRtMemcpyKind { + RT_MEMCPY_HOST_TO_HOST = 0, // host to host + RT_MEMCPY_HOST_TO_DEVICE, // host to device + RT_MEMCPY_DEVICE_TO_HOST, // device to host + RT_MEMCPY_DEVICE_TO_DEVICE, // device to device, 1P && P2P + RT_MEMCPY_MANAGED, // managed memory + RT_MEMCPY_ADDR_DEVICE_TO_DEVICE, + RT_MEMCPY_HOST_TO_DEVICE_EX, // host to device ex (only used for 8 bytes) + RT_MEMCPY_DEVICE_TO_HOST_EX, // device to host ex + RT_MEMCPY_DEFAULT, // auto infer copy dir + RT_MEMCPY_RESERVED, +} rtMemcpyKind_t; + +typedef enum tagRtMemInfoType { + 
RT_MEMORYINFO_DDR, + RT_MEMORYINFO_HBM, + RT_MEMORYINFO_DDR_HUGE, // Hugepage memory of DDR + RT_MEMORYINFO_DDR_NORMAL, // Normal memory of DDR + RT_MEMORYINFO_HBM_HUGE, // Hugepage memory of HBM + RT_MEMORYINFO_HBM_NORMAL, // Normal memory of HBM + RT_MEMORYINFO_DDR_P2P_HUGE, // Hugepage memory of DDR + RT_MEMORYINFO_DDR_P2P_NORMAL, // Normal memory of DDR + RT_MEMORYINFO_HBM_P2P_HUGE, // Hugepage memory of HBM + RT_MEMORYINFO_HBM_P2P_NORMAL, // Normal memory of HBM +} rtMemInfoType_t; + +typedef enum tagRtRecudeKind { + RT_MEMCPY_SDMA_AUTOMATIC_ADD = 10, // D2D, SDMA inline reduce, include 1P, and P2P + RT_MEMCPY_SDMA_AUTOMATIC_MAX = 11, + RT_MEMCPY_SDMA_AUTOMATIC_MIN = 12, + RT_MEMCPY_SDMA_AUTOMATIC_EQUAL = 13, + RT_RECUDE_KIND_END = 14, +} rtRecudeKind_t; + +typedef enum tagRtDataType { + RT_DATA_TYPE_FP32 = 0, // fp32 + RT_DATA_TYPE_FP16 = 1, // fp16 + RT_DATA_TYPE_INT16 = 2, // int16 + RT_DATA_TYPE_INT4 = 3, // int4 + RT_DATA_TYPE_INT8 = 4, // int8 + RT_DATA_TYPE_INT32 = 5, // int32 + RT_DATA_TYPE_BFP16 = 6, // bfp16 + RT_DATA_TYPE_BFP32 = 7, // bfp32 + RT_DATA_TYPE_UINT8 = 8, // uint8 + RT_DATA_TYPE_UINT16 = 9, // uint16 + RT_DATA_TYPE_UINT32 = 10, // uint32 + RT_DATA_TYPE_END = 11, +} rtDataType_t; + +typedef enum rtMemcpyAttributeId { + RT_MEMCPY_ATTRIBUTE_RSV = 0, + RT_MEMCPY_ATTRIBUTE_CHECK = 1, + RT_MEMCPY_ATTRIBUTE_MAX = 2, +} rtMemcpyAttributeId_t; + +typedef union rtMemcpyAttributeValue_union { + uint32_t rsv[4]; + uint32_t checkBitmap; // bit0:Do not check for matching between address and kind;bit1:check addr is page-lock +} rtMemcpyAttributeValue_t; + +typedef struct rtMemcpyAttribute { + rtMemcpyAttributeId_t id; + rtMemcpyAttributeValue_t value; +} rtMemcpyAttribute_t; + +typedef struct rtMemcpyConfig { + rtMemcpyAttribute_t* attrs; + uint32_t numAttrs; +} rtMemcpyConfig_t; + +/** + * @ingroup dvrt_mem + * @brief memory copy channel type + */ +typedef enum tagRtMemcpyChannelType { + RT_MEMCPY_CHANNEL_TYPE_INNER = 0, // 1P + 
RT_MEMCPY_CHANNEL_TYPE_PCIe, + RT_MEMCPY_CHANNEL_TYPE_HCCs, // not support now + RT_MEMCPY_CHANNEL_TYPE_RESERVED, +} rtMemcpyChannelType_t; + +/** + * @ingroup rt_kernel + * @brief ai core memory size + */ +typedef struct rtAiCoreMemorySize { + uint32_t l0ASize; + uint32_t l0BSize; + uint32_t l0CSize; + uint32_t l1Size; + uint32_t ubSize; + uint32_t l2Size; + uint32_t l2PageNum; + uint32_t blockSize; + uint64_t bankSize; + uint64_t bankNum; + uint64_t burstInOneBlock; + uint64_t bankGroupNum; +} rtAiCoreMemorySize_t; + +/** + * @ingroup dvrt_mem + * @brief memory type + */ +typedef enum tagRtMemoryType { + RT_MEMORY_TYPE_HOST = 1, + RT_MEMORY_TYPE_DEVICE = 2, + RT_MEMORY_TYPE_SVM = 3, + RT_MEMORY_TYPE_DVPP = 4, + RT_MEMORY_TYPE_USER = 5 // by user malloc, unkown memory +} rtMemoryType_t; + +/** + * @ingroup dvrt_mem + * @brief ipc type + */ +typedef enum tagRtIpcMemAttrType { + RT_IPC_ATTR_SIO = 0, + RT_IPC_ATTR_HCCS = 1, + RT_IPC_ATTR_MAX +}rtIpcMemAttrType; + +/** + * @ingroup dvrt_mem + * @brief memory attribute + */ +typedef struct tagRtPointerAttributes { + rtMemoryType_t memoryType; // host memory or device memory + rtMemoryType_t locationType; + uint32_t deviceID; // device ID + uint32_t pageSize; +} rtPointerAttributes_t; + + +typedef struct { + const char_t *name; + const uint64_t size; + uint32_t flag; +} rtMallocHostSharedMemoryIn; + +typedef struct { + int32_t fd; + void *ptr; + void *devPtr; +} rtMallocHostSharedMemoryOut; + +typedef struct { + const char_t *name; + const uint64_t size; + int32_t fd; + void *ptr; + void *devPtr; +} rtFreeHostSharedMemoryIn; + + +/** + * @ingroup dvrt_mem + * @brief alloc device memory + * @param [in|out] devPtr memory pointer + * @param [in] size memory size + * @param [in] type memory type + * @param [in] moduleid alloc memory module id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMalloc(void **devPtr, uint64_t size, rtMemType_t type, const uint16_t 
moduleId); + +/** + * @ingroup dvrt_mem + * @brief free device memory + * @param [in|out] devPtr memory pointer + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtFree(void *devPtr); + +/** + * @ingroup dvrt_mem + * @brief alloc device memory for dvpp + * @param [in|out] devPtr memory pointer + * @param [in] size memory size + * @param [in] moduleid alloc memory module id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDvppMalloc(void **devPtr, uint64_t size, const uint16_t moduleId); + +/** + * @ingroup dvrt_mem + * @brief alloc device memory for dvpp, support set flag + * @param [in|out] devPtr memory pointer + * @param [in] size memory size + * @param [in] flag mem flag, can use mem attribute set read only. + * @param [in] moduleid alloc memory module id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return others is error + */ +RTS_API rtError_t rtDvppMallocWithFlag(void **devPtr, uint64_t size, uint32_t flag, const uint16_t moduleId); + +/** + * @ingroup dvrt_mem + * @brief free device memory for dvpp + * @param [in|out] devPtr memory pointer + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDvppFree(void *devPtr); + +/** + * @ingroup dvrt_mem + * @brief alloc host memory + * @param [in|out] hostPtr memory pointer + * @param [in] size memory size + * @param [in] moduleid alloc memory module id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMallocHost(void **hostPtr, uint64_t size, const uint16_t moduleId); + +/** + * @ingroup dvrt_mem + * @brief free host memory + * @param [in] hostPtr memory pointer + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtFreeHost(void *hostPtr); + +/** + * @ingroup dvrt_mem + * @brief alloc 
host shared memory + * @param [in] in alloc host shared memory inputPara pointer + * @param [in] out alloc host shared memory outputInfo pointer + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ + +RTS_API rtError_t rtMallocHostSharedMemory(rtMallocHostSharedMemoryIn *in, + rtMallocHostSharedMemoryOut *out); + +/** + * @ingroup dvrt_mem + * @brief free host shared memory + * @param [in] in free host shared memory inputPara pointer + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ + +RTS_API rtError_t rtFreeHostSharedMemory(rtFreeHostSharedMemoryIn *in); + +/** + * @ingroup dvrt_mem + * @brief alloc managed memory + * @param [in|out] ptr memory pointer + * @param [in] size memory size + * @param [in] flag reserved, set to 0. + * @param [in] moduleId alloc memory module id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMemAllocManaged(void **ptr, uint64_t size, uint32_t flag, const uint16_t moduleId); + +/** + * @ingroup dvrt_mem + * @brief free managed memory + * @param [in] ptr memory pointer + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMemFreeManaged(void *ptr); + +/** + * @ingroup dvrt_mem + * @brief alloc cached device memory + * @param [in|out] devPtr memory pointer + * @param [in] size memory size + * @param [in] type memory type + * @param [in] moduleId alloc memory module id + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtMallocCached(void **devPtr, uint64_t size, rtMemType_t type, const uint16_t moduleId); + +/** + * @ingroup dvrt_mem + * @brief flush device memory + * @param [in] base virtual base addr + * @param [in] len memory size + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtFlushCache(void *base, size_t len); + +/** + * @ingroup dvrt_mem + * @brief invalid device memory + * @param [in] base virtual base addr + * @param 
[in] len memory size + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtInvalidCache(void *base, size_t len); + +/** + * @ingroup dvrt_mem + * @brief synchronized memcpy + * @param [in] dst destination address pointer + * @param [in] destMax length of destination address memory + * @param [in] src source address pointer + * @param [in] cnt the number of byte to copy + * @param [in] kind memcpy type + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMemcpy(void *dst, uint64_t destMax, const void *src, uint64_t cnt, rtMemcpyKind_t kind); + +/** + * @ingroup dvrt_mem for mbuff + * @brief synchronized memcpy + * @param [in] dst destination address pointer + * @param [in] destMax length of destination address memory + * @param [in] src source address pointer + * @param [in] cnt the number of byte to copy + * @param [in] kind memcpy type + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMemcpyEx(void *dst, uint64_t destMax, const void *src, uint64_t cnt, rtMemcpyKind_t kind); + +/** + * @ingroup dvrt_mem + * @brief host task memcpy + * @param [in] dst destination address pointer + * @param [in] destMax length of destination address memory + * @param [in] src source address pointer + * @param [in] cnt the number of byte to copy + * @param [in] kind memcpy type + * @param [in] stm task stream + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtMemcpyHostTask(void * const dst, const uint64_t destMax, const void * const src, + const uint64_t cnt, rtMemcpyKind_t kind, rtStream_t stm); + +/** + * @ingroup dvrt_mem + * @brief asynchronized memcpy + * @param [in] dst destination address pointer + * @param [in] destMax length of destination address memory + * @param [in] src source address pointer + * @param [in] cnt the number of byte to copy + * @param [in] kind memcpy type + * @param [in] stm asynchronized 
task stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMemcpyAsync(void *dst, uint64_t destMax, const void *src, uint64_t cnt, rtMemcpyKind_t kind, + rtStream_t stm); + +/** + * @ingroup dvrt_mem + * @brief asynchronized memcpy + * @param [in] dst destination address pointer + * @param [in] destMax length of destination address memory + * @param [in] src source address pointer + * @param [in] cnt the number of byte to copy + * @param [in] kind memcpy type, not check + * @param [in] stm asynchronized task stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMemcpyAsyncWithoutCheckKind(void *dst, uint64_t destMax, const void *src, uint64_t cnt, + rtMemcpyKind_t kind, rtStream_t stm); + +/** + * @ingroup dvrt_mem + * @brief dsa update memcpy + * @param [in] streamId dsa streamId + * @param [in] taskId dsa + * @param [in] src source device address pointer + * @param [in] cnt the number of byte to copy + * @param [in] stm asynchronized task stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtLaunchSqeUpdateTask(uint32_t streamId, uint32_t taskId, void *src, uint64_t cnt, + rtStream_t stm); + +/** + * @ingroup dvrt_mem + * @brief asynchronized memcpy + * @param [in] dst destination address pointer + * @param [in] destMax length of destination address memory + * @param [in] src source address pointer + * @param [in] cnt the number of byte to copy + * @param [in] kind memcpy type + * @param [in] stm asynchronized task stream + * @param [in] memcpyConfig memory copy config + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMemcpyAsyncEx(void *dst, uint64_t destMax, const void *src, uint64_t cnt, + rtMemcpyKind_t kind, rtStream_t stm, rtMemcpyConfig_t *memcpyConfig); + +/** + * @ingroup dvrt_mem + * @brief asynchronized 
memcpy + * @param [in] dst destination address pointer + * @param [in] destMax length of destination address memory + * @param [in] src source address pointer + * @param [in] cnt the number of byte to copy + * @param [in] kind memcpy type + * @param [in] stm asynchronized task stream + * @param [in] qosCfg asynchronized task qosCfg + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMemcpyAsyncWithCfg(void *dst, uint64_t destMax, const void *src, uint64_t cnt, + rtMemcpyKind_t kind, rtStream_t stm, uint32_t qosCfg); + +/** + * @ingroup dvrt_mem + * @brief asynchronized memcpy + * @param [in] dst destination address pointer + * @param [in] destMax length of destination address memory + * @param [in] src source address pointer + * @param [in] cnt the number of byte to copy + * @param [in] kind memcpy type + * @param [in] stm asynchronized task stream + * @param [in] cfgInfo task config + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMemcpyAsyncWithCfgV2(void *dst, uint64_t destMax, const void *src, uint64_t cnt, + rtMemcpyKind_t kind, rtStream_t stm, const rtTaskCfgInfo_t *cfgInfo); + +typedef struct { + uint32_t resv0; + uint32_t resv1; + uint32_t resv2; + uint32_t len; + uint64_t src; + uint64_t dst; +} rtMemcpyAddrInfo; + +RTS_API rtError_t rtMemcpyAsyncPtr(void *memcpyAddrInfo, uint64_t destMax, uint64_t count, + rtMemcpyKind_t kind, rtStream_t stream, uint32_t qosCfg); + +RTS_API rtError_t rtMemcpyAsyncPtrV2(void *memcpyAddrInfo, uint64_t destMax, uint64_t count, + rtMemcpyKind_t kind, rtStream_t stream, const rtTaskCfgInfo_t *cfgInfo); + +/** + * @ingroup dvrt_mem + * @brief asynchronized memcpy + * @param [in] dst destination address pointer + * @param [in] dstMax length of destination address memory + * @param [in] dstOffset + * @param [in] src source address pointer + * @param [in] cnt the number of byte to copy + * @param [in] srcOffset + 
* @param [in] stm asynchronized task stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMemcpyD2DAddrAsync(void *dst, uint64_t dstMax, uint64_t dstOffset, const void *src, + uint64_t cnt, uint64_t srcOffset, rtStream_t stm); + +/** + * @ingroup dvrt_mem + * @brief asynchronized reduce memcpy + * @param [in] dst destination address pointer + * @param [in] destMax length of destination address memory + * @param [in] src source address pointer + * @param [in] cnt the number of byte to copy + * @param [in] kind memcpy type + * @param [in] type data type + * @param [in] stm asynchronized task stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtReduceAsync(void *dst, uint64_t destMax, const void *src, uint64_t cnt, rtRecudeKind_t kind, + rtDataType_t type, rtStream_t stm); + +/** + * @ingroup dvrt_mem + * @brief asynchronized reduce memcpy + * @param [in] dst destination address pointer + * @param [in] destMax length of destination address memory + * @param [in] src source address pointer + * @param [in] count the number of byte to copy + * @param [in] kind memcpy type + * @param [in] type data type + * @param [in] stm asynchronized task stream + * @param [in] qosCfg asynchronized task qosCfg + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtReduceAsyncWithCfg(void *dst, uint64_t destMax, const void *src, uint64_t cnt, rtRecudeKind_t kind, + rtDataType_t type, rtStream_t stm, uint32_t qosCfg); + +/** + * @ingroup dvrt_mem + * @brief asynchronized reduce memcpy + * @param [in] dst destination address pointer + * @param [in] destMax length of destination address memory + * @param [in] src source address pointer + * @param [in] count the number of byte to copy + * @param [in] kind memcpy type + * @param [in] type data type + * @param [in] stm asynchronized task stream + * @param [in] 
cfgInfo task config + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtReduceAsyncWithCfgV2(void *dst, uint64_t destMax, const void *src, uint64_t cnt, + rtRecudeKind_t kind, rtDataType_t type, rtStream_t stm, const rtTaskCfgInfo_t *cfgInfo); + +/** + * @ingroup dvrt_mem + * @brief asynchronized reduce memcpy + * @param [in] dst destination address pointer + * @param [in] destMax length of destination address memory + * @param [in] src source address pointer + * @param [in] count the number of byte to copy + * @param [in] kind memcpy type + * @param [in] type data type + * @param [in] stm asynchronized task stream + * @param [in] overflowAddr addr of overflow flag + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtReduceAsyncV2(void *dst, uint64_t destMax, const void *src, uint64_t count, rtRecudeKind_t kind, + rtDataType_t type, rtStream_t stm, void *overflowAddr); + +/** + * @ingroup dvrt_mem + * @brief synchronized memcpy2D + * @param [in] dst destination address pointer + * @param [in] dstPitch pitch of destination memory + * @param [in] src source address pointer + * @param [in] srcPitch pitch of source memory + * @param [in] width width of matrix transfer + * @param [in] height height of matrix transfer + * @param [in] kind memcpy type + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMemcpy2d(void *dst, uint64_t dstPitch, const void *src, uint64_t srcPitch, uint64_t width, + uint64_t height, rtMemcpyKind_t kind); + +/** + * @ingroup dvrt_mem + * @brief asynchronized memcpy2D + * @param [in] dst destination address pointer + * @param [in] dstPitch length of destination address memory + * @param [in] src source address pointer + * @param [in] srcPitch length of destination address memory + * @param [in] width width of matrix transfer + * @param [in] height height of matrix transfer + * 
@param [in] kind memcpy type + * @param [in] stm asynchronized task stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMemcpy2dAsync(void *dst, uint64_t dstPitch, const void *src, uint64_t srcPitch, uint64_t width, + uint64_t height, rtMemcpyKind_t kind, rtStream_t stm); + +/** + * @ingroup dvrt_mem + * @brief query memory size + * @param [in] aiCoreMemorySize + * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtAiCoreMemorySizes(rtAiCoreMemorySize_t *aiCoreMemorySize); + +/** + * @ingroup dvrt_mem + * @brief set memory size, Setting before model reasoning, Bright screen to prevent model can not be fully + integrated network due to memory limitations.Requirement come from JiaMinHu.Only use for Tiny. + * @param [in] aiCoreMemorySize + * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetAiCoreMemorySizes(rtAiCoreMemorySize_t *aiCoreMemorySize); + +/** + * @ingroup dvrt_mem + * @brief Specifies how memory is use + * @param [in] devPtr memory pointer + * @param [in] count memory count + * @param [in] advise reserved, set to 1 + * @return RT_ERROR_NONE for ok + * @return others for error + */ +RTS_API rtError_t rtMemAdvise(void *devPtr, uint64_t count, uint32_t advise); +/** + * @ingroup dvrt_mem + * @brief set memory with uint32_t value + * @param [in] devPtr + * @param [in] destMax length of destination address memory + * @param [in] val + * @param [in] cnt byte num + * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMemset(void *devPtr, uint64_t destMax, uint32_t val, uint64_t cnt); + +/** + * @ingroup dvrt_mem + * @brief set memory with uint32_t value async + * @param [in] devPtr + * @param [in] destMax length of destination address memory + * @param [in] val + * @param [in] 
cnt byte num + * @param [in] stm + * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMemsetAsync(void *ptr, uint64_t destMax, uint32_t val, uint64_t cnt, rtStream_t stm); + +/** + * @ingroup dvrt_mem + * @brief get current device memory total and free + * @param [out] freeSize + * @param [out] totalSize + * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMemGetInfo(size_t *freeSize, size_t *totalSize); + +/** + * @ingroup dvrt_mem + * @brief get current device memory total and free + * @param [in] memInfoType + * @param [out] freeSize + * @param [out] totalSize + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtMemGetInfoEx(rtMemInfoType_t memInfoType, size_t *freeSize, size_t *totalSize); + +/** + * @ingroup dvrt_mem + * @brief set memory with uint32_t value + * @param [in] devPtr + * @param [in] len + * @param [in] devId + * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMemPrefetchToDevice(void *devPtr, uint64_t len, int32_t devId); + +/** + * @ingroup dvrt_mem + * @brief get memory attribute:Host or Device + * @param [in] ptr + * @param [out] attributes + * @return RT_ERROR_NONE for ok, errno for failed + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtPointerGetAttributes(rtPointerAttributes_t *attributes, const void *ptr); + +/** + * @ingroup dvrt_mem + * @brief make memory shared interprocess and assigned a name + * @param [in] ptr device memory address pointer + * @param [in] name identification name + * @param [in] byteCount identification byteCount + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtIpcSetMemoryName(const void *ptr, uint64_t byteCount, char_t *name, uint32_t 
len); + +/** + * @ingroup dvrt_mem + * @brief set the attribute of shared memory + * @param [in] name identification name + * @param [in] type shared memory mapping type + * @param [in] attr shared memory attribute + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error +*/ +RTS_API rtError_t rtIpcSetMemoryAttr (const char *name, uint32_t type, uint64_t attr); + +/** + * @ingroup dvrt_mem + * @brief destroy an interprocess shared memory + * @param [in] name identification name + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtIpcDestroyMemoryName(const char_t *name); + +/** + * @ingroup dvrt_mem + * @brief open an interprocess shared memory + * @param [in|out] ptr device memory address pointer + * @param [in] name identification name + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtIpcOpenMemory(void **ptr, const char_t *name); + +/** + * @ingroup dvrt_mem + * @brief close an interprocess shared memory + * @param [in] ptr device memory address pointer + * @param [in] name identification name + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtIpcCloseMemory(const void *ptr); + +/** + * @ingroup dvrt_mem + * @brief HCCL Async memory cpy + * @param [in] sqIndex sq index + * @param [in] wqeIndex module index + * @param [in] stm asynchronized task stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtRDMASend(uint32_t sqIndex, uint32_t wqeIndex, rtStream_t stm); + +/** + * @ingroup dvrt_mem + * @brief Ipc set mem pid + * @param [in] name name to be queried + * @param [in] pid 
process id + * @param [in] num length of pid[] + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtSetIpcMemPid(const char_t *name, int32_t pid[], int32_t num); + +/** + * @ingroup dvrt_mem + * @brief HCCL Async memory cpy + * @param [in] dbindex single device 0 + * @param [in] dbinfo doorbell info + * @param [in] stm asynchronized task stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtRDMADBSend(uint32_t dbIndex, uint64_t dbInfo, rtStream_t stm); + + +typedef struct DrvMemHandle { + int32_t id; + uint32_t side; + uint32_t devid; + uint64_t pg_num; +} rtDrvMemHandle_t; + +typedef void* rtDrvMemHandle; + +typedef struct DrvMemProp { + uint32_t side; + uint32_t devid; + uint32_t module_id; + + uint32_t pg_type; + uint32_t mem_type; + uint64_t reserve; +} rtDrvMemProp_t; + +typedef enum DrvMemHandleType { + RT_MEM_HANDLE_TYPE_NONE = 0x0, +} rtDrvMemHandleType; + +typedef enum DrvMemAttrType { + RT_ATTR_TYPE_MEM_MAP = 0, + RT_ATTR_TYPE_MAX, +}rtDrvMemAttrType; + +typedef enum DrvMemGranularityOptions { + RT_MEM_ALLOC_GRANULARITY_MINIMUM = 0x0, + RT_MEM_ALLOC_GRANULARITY_RECOMMENDED, + RT_MEM_ALLOC_GRANULARITY_INVALID, +} rtDrvMemGranularityOptions; + +/** + * @ingroup dvrt_mem + * @brief This command is used to reserve a virtual address range + * @attention Only support ONLINE scene + * @param [in] devPtr Resulting pointer to start of virtual address range allocated. + * @param [in] size Size of the reserved virtual address range requested. + * @param [in] alignment Alignment of the reserved virtual address range requested, Currently unused, must be zero. + * @param [in] devAddr Expected virtual address space start address Currently, Currently unused, must be zero. + * @param [in] flags currently unused, must be zero. 
+ * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtReserveMemAddress(void** devPtr, size_t size, size_t alignment, void *devAddr, uint64_t flags); + +/** + * @ingroup dvrt_mem + * @brief This command is used to free a virtual address range reserved by halMemAddressReserve. + * @attention Only support ONLINE scene. + * @param [in] devPtr Starting address of the virtual address range to free. + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtReleaseMemAddress(void* devPtr); + +/** + * @ingroup dvrt_mem + * @brief This command is used to alloc physical memory. + * @attention Only support ONLINE scene. + * @param [out] handle Value of handle returned,all operations on this allocation are to be performed using this handle. + * @param [in] size Size of the allocation requested. + * @param [in] prop Properties of the allocation to create. + * @param [in] flags Currently unused, must be zero. + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtMallocPhysical(rtDrvMemHandle* handle, size_t size, rtDrvMemProp_t* prop, uint64_t flags); + +/** + * @ingroup dvrt_mem + * @brief This command is used to alloc physical memory. + * @attention Only support ONLINE scene. + * @param [out] handle Value of handle returned,all operations on this allocation are to be performed using this handle. + * @param [in] size Size of the allocation requested. + * @param [in] prop Properties of the allocation to create. + * @param [in] flags Currently unused, must be zero. 
+ * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtMallocPhysicalV2(rtDrvMemHandle_t** handle, size_t size, rtDrvMemProp_t* prop, uint64_t flags); + +/** + * @ingroup dvrt_mem + * @brief This command is used to free physical memory. + * @attention Only support ONLINE scene. + * @param [in] handle Value of handle which was returned previously by halMemCreate. + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtFreePhysical(rtDrvMemHandle handle); + +/** + * @ingroup dvrt_mem + * @brief This command is used to free physical memory. + * @attention Only support ONLINE scene. + * @param [in] handle Value of handle which was returned previously by halMemCreate. + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtFreePhysicalV2(rtDrvMemHandle_t* handle); + +/** + * @ingroup dvrt_mem + * @brief This command is used to map an allocation handle to a reserved virtual address range. + * @attention Only support ONLINE scene. + * @param [in] devPtr Address where memory will be mapped. + * @param [in] size Size of the memory mapping. + * @param [in] offset Currently unused, must be zero. + * @param [in] handle Value of handle which was returned previously by halMemCreate. + * @param [in] flag Currently unused, must be zero. + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtMapMem(void* devPtr, size_t size, size_t offset, rtDrvMemHandle handle, uint64_t flags); + +/** + * @ingroup dvrt_mem + * @brief This command is used to map an allocation handle to a reserved virtual address range. + * @attention Only support ONLINE scene. 
+ * @param [in] devPtr Address where memory will be mapped. + * @param [in] size Size of the memory mapping. + * @param [in] offset Currently unused, must be zero. + * @param [in] handle Value of handle which was returned previously by halMemCreate. + * @param [in] flag Currently unused, must be zero. + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtMapMemV2(void* devPtr, size_t size, size_t offset, rtDrvMemHandle_t* handle, uint64_t flags); + +/** + * @ingroup dvrt_mem + * @brief This command is used to unmap the backing memory of a given address range. + * @attention Only support ONLINE scene. + * @param [in] devPtr Starting address for the virtual address range to unmap. + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_DRV_ERR for driver error + */ +RTS_API rtError_t rtUnmapMem(void* devPtr); + +/** +* @ingroup dvrt_mem +* @brief This command is used to export an allocation to a shareable handle. +* @attention Only support ONLINE scene. Not support compute group. +* @param [in] handle Handle for the memory allocation. +* @param [in] handleType Currently unused, must be MEM_HANDLE_TYPE_NONE. +* @param [in] flags Currently unused, must be zero. +* @param [out] shareableHandle Export a shareable handle. +* @return DRV_ERROR_NONE : success +* @return DV_ERROR_XXX : fail +*/ +RTS_API rtError_t rtMemExportToShareableHandle(rtDrvMemHandle handle, rtDrvMemHandleType handleType, + uint64_t flags, uint64_t *shareableHandle); + +/** +* @ingroup dvrt_mem +* @brief This command is used to export an allocation to a shareable handle. +* @attention Only support ONLINE scene. Not support compute group. +* @param [in] handle Handle for the memory allocation. +* @param [in] handleType Currently unused, must be MEM_HANDLE_TYPE_NONE. +* @param [in] flags Currently unused, must be zero. 
+* @param [out] shareableHandle Export a shareable handle. +* @return DRV_ERROR_NONE : success +* @return DV_ERROR_XXX : fail +*/ +RTS_API rtError_t rtMemExportToShareableHandleV2(rtDrvMemHandle_t* handle, rtDrvMemHandleType handleType, + uint64_t flags, uint64_t *shareableHandle); + +/** +* @ingroup dvrt_mem +* @brief This command is used to import an allocation from a shareable handle. +* @attention Only support ONLINE scene. Not support compute group. +* @param [in] shareableHandle Import a shareable handle. +* @param [in] devId Device id. +* @param [out] handle Value of handle returned, all operations on this allocation are to be performed using this handle. +* @return DRV_ERROR_NONE : success +* @return DV_ERROR_XXX : fail +*/ +RTS_API rtError_t rtMemImportFromShareableHandle(uint64_t shareableHandle, int32_t devId, rtDrvMemHandle* handle); + +/** +* @ingroup dvrt_mem +* @brief This command is used to import an allocation from a shareable handle. +* @attention Only support ONLINE scene. Not support compute group. +* @param [in] shareableHandle Import a shareable handle. +* @param [in] devId Device id. +* @param [out] handle Value of handle returned, all operations on this allocation are to be performed using this handle. +* @return DRV_ERROR_NONE : success +* @return DV_ERROR_XXX : fail +*/ +RTS_API rtError_t rtMemImportFromShareableHandleV2(uint64_t shareableHandle, int32_t devId, rtDrvMemHandle_t** handle); + +/** +* @ingroup dvrt_mem +* @brief This command is used to configure the process whitelist which can use shareable handle. +* @attention Only support ONLINE scene. Not support compute group. +* @param [in] shareableHandle A shareable handle. +* @param [in] pid Host pid whitelist array. +* @param [in] pidNum Number of entries in the pid array.
+* @return DRV_ERROR_NONE : success +* @return DV_ERROR_XXX : fail +*/ +RTS_API rtError_t rtMemSetPidToShareableHandle(uint64_t shareableHandle, int pid[], uint32_t pidNum); + +/** +* @ingroup dvrt_mem +* @brief This command is used to calculate either the minimal or recommended granularity. +* @attention Only support ONLINE scene. +* @param [in] prop Properties of the allocation. +* @param [in] option Determines which granularity to return. +* @param [out] granularity Returned granularity. +* @return DRV_ERROR_NONE : success +* @return DV_ERROR_XXX : fail +*/ +RTS_API rtError_t rtMemGetAllocationGranularity(rtDrvMemProp_t *prop, rtDrvMemGranularityOptions option, + size_t *granularity); + +#if defined(__cplusplus) +} +#endif + +#endif // CCE_RUNTIME_MEM_H diff --git a/inc/runtime/runtime/rt.h b/inc/runtime/runtime/rt.h new file mode 100644 index 0000000000000000000000000000000000000000..de1b768b7b5a3bb401af17db30d91366b4a13aec --- /dev/null +++ b/inc/runtime/runtime/rt.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + * Description: rt.h + * Create: 2020-01-01 + */ + +#ifndef CCE_RUNTIME_RT_H +#define CCE_RUNTIME_RT_H + +#include "base.h" +#include "config.h" +#include "context.h" +#include "dev.h" +#include "event.h" +#include "kernel.h" +#include "mem.h" +#include "rt_model.h" +#include "stream.h" +#include "rt_stars.h" +#include "rt_ffts.h" +#include "rt_ffts_plus.h" +#include "rt_dfx.h" +#include "rt_mem_queue.h" + +#endif // CCE_RUNTIME_RT_H diff --git a/inc/runtime/runtime/rt_dfx.h b/inc/runtime/runtime/rt_dfx.h new file mode 100644 index 0000000000000000000000000000000000000000..71215f801f31e8f1dfa1207eaaf74f4bab263729 --- /dev/null +++ b/inc/runtime/runtime/rt_dfx.h @@ -0,0 +1,44 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved.
+ * Description: dfx interface + */ + +#ifndef CCE_RUNTIME_RT_DFX_H +#define CCE_RUNTIME_RT_DFX_H + +#include "base.h" + +#if defined(__cplusplus) +extern "C" { +#endif + +// max task tag buffer is 1024(include '\0') +#define TASK_TAG_MAX_LEN 1024U + +/** + * @brief set task tag. + * once set is only used by one task and thread local. + * attention: + * 1. it's used for dump current task in active stream now. + * 2. it must be called before task submit and will be invalid after task submit. + * @param [in] taskTag task tag, usually it can be node name or task name. + * must end with '\0' and max len is TASK_TAG_MAX_LEN. + * @return RT_ERROR_NONE for ok + * @return other failed + */ +RTS_API rtError_t rtSetTaskTag(const char_t *taskTag); + +/** + * @brief set aicpu device attribute. + * it is used for aicpu device to be aware of environment config + * @param [in] key attribute key. + * @param [in] val attribute value. + * @return RT_ERROR_NONE for ok + * @return other failed + */ +RTS_API rtError_t rtSetAicpuAttr(const char_t *key, const char_t *val); + +#if defined(__cplusplus) +} +#endif +#endif // CCE_RUNTIME_RT_DFX_H diff --git a/inc/runtime/runtime/rt_ffts.h b/inc/runtime/runtime/rt_ffts.h new file mode 100644 index 0000000000000000000000000000000000000000..5c0ab9714cb996072ed81fae14081f886cf7fcbe --- /dev/null +++ b/inc/runtime/runtime/rt_ffts.h @@ -0,0 +1,190 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
+ * Description: ffts interface + */ + +#ifndef CCE_RUNTIME_RT_FFTS_H +#define CCE_RUNTIME_RT_FFTS_H + +#include "base.h" + +#if defined(__cplusplus) +extern "C" { +#endif + +#define RT_FFTS_MAX_SUB_TASK_NUM 32U +#define RT_FFTS_MAX_TICKET_CACHE_NUM 64U +#define RT_FFTS_MAX_MANUAL_THREAD_NUM 16U +#define RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK 8U +#define RT_FFTS_MANUAL_SRC_DEPEND_TBL_LEN 32U + +typedef enum tagFftsType { + RT_FFTS_TYPE_AUTO_THREAD = 2, // ffts auto thread mode, same as ffts define + RT_FFTS_TYPE_MANUAL_THREAD = 3, // ffts manual thread mode, same as ffts define +} rtFftsType_t; + +typedef enum tagFftsSubTaskType { + RT_FFTS_SUB_TASK_TYPE_AIC = 0, + RT_FFTS_SUB_TASK_TYPE_AIV = 1, + RT_FFTS_SUB_TASK_TYPE_NOP = 2, + RT_FFTS_SUB_TASK_TYPE_NOTIFY_WAIT = 3, + RT_FFTS_SUB_TASK_TYPE_NOTIFY_RECORD = 4, + RT_FFTS_SUB_TASK_TYPE_WRITE_VALUE = 5, + RT_FFTS_SUB_TASK_TYPE_MIX_AIC = 6, + RT_FFTS_SUB_TASK_TYPE_MIX_AIV = 7, + RT_FFTS_SUB_TASK_TYPE_SDMA = 8, + RT_FFTS_SUB_TASK_TYPE_RESERVED = 9, +} rtFftsSubTaskType_t; + +typedef struct tagManualThreadDmuInfo { + uint64_t dataAddr; // device mem + uint16_t numOuter; + uint16_t numInner; + uint32_t strideOuter; + uint32_t lenInner; + uint32_t strideInner; +} rtManualThreadDmuInfo_t; + +typedef struct tagManualThreadDependency { + uint8_t dependency[RT_FFTS_MANUAL_SRC_DEPEND_TBL_LEN]; +} rtManualThreadDependency_t; + +typedef struct tagManualThreadAicAivInfo { + uint64_t taskParamAddr; // device mem + uint16_t taskParamOffset; + // when satMode=1 and FP16 computation with none INF inputs overflows/underflows, results will be +/-INF of FP16 + // when satMode=0 and FP16 computation with none INF inputs overflows/underflows, + // results will be saturated to +/-MAX of FP16 + uint8_t satMode; + uint8_t scheduleMode; // 0:normal mode, 1:batch mode, 2:sync mode 3:reserved + uint8_t iCachePrefetchCnt; // units is 2K + uint8_t prefetchEnableBitmap; // 8 bit bitmap 1 0 1 0 + uint8_t prefetchOnceBitmap; // 8 bit bitmap 1 0 1 0 + 
uint16_t prefetchOnceDmuNum; // prefetch_once_dmu_descriptor_index in ffts + // num: thread0_prefetch_dmu_descriptor_index – prefetch_once_dmu_descriptor_index + uint16_t threadPrefetchDmuIdx[RT_FFTS_MAX_MANUAL_THREAD_NUM]; // max valid is threadDim + uint16_t threadBlkDim[RT_FFTS_MAX_MANUAL_THREAD_NUM]; + const char_t *threadTaskFuncStub[RT_FFTS_MAX_MANUAL_THREAD_NUM]; + + rtManualThreadDmuInfo_t *prefetchList; // dmu desc 0-64k, length is the last threadPrefetchDmuIdx[threadDim-1] + rtManualThreadDependency_t srcDepTbl[RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK]; +} rtManualThreadAicAivInfo_t; + +typedef struct tagAutoThreadPrefetch { + uint64_t dataAddr; // device mem + uint32_t dataAddrOffset; + uint32_t nonTailDataLen; + uint32_t tailDataLen; +} rtAutoThreadPrefetch_t; + +typedef struct tagAutoThreadAicAivInfo { + uint64_t taskParamAddr; // device mem + uint16_t taskParamOffset; + /* + * when satMode=1 and FP16 computation with none INF inputs overflows/underflows, results will be +/-INF of FP16 + * when satMode=0 and FP16 computation with none INF inputs overflows/underflows, results will be saturated to + * +/-MAX of FP16 + */ + uint8_t satMode; + uint8_t scheduleMode; // 0:normal mode, 1:batch mode, 2:sync mode 3:reserved + uint8_t iCachePrefetchCnt; // units is 2K + uint8_t prefetchEnableBitmap; // 8 bit bitmap + uint8_t prefetchOnceBitmap; // 8 bit bitmap + + uint16_t tailBlkDim; + uint16_t nonTailBlkDim; + + const char_t *nonTailTaskFuncStub; + const char_t *tailTaskFuncStub; + + // for prefetch, valid num is prefetchEnableBitmap bit count. 
+ // if prefetchEnableBitmap='00010011', need prefetch number is 3, srcPrefetch is only 0, 1, 2 is valid + rtAutoThreadPrefetch_t srcPrefetch[RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK]; +} rtAutoThreadAicAivInfo_t; + +typedef struct tagAutoThreadCacheInfo { + uint64_t dataAddr; // device mem + uint32_t dataAddrOffset; + uint32_t nonTailDataLen; + uint32_t tailDataLen; + uint16_t ticketCacheRefCnt; +} rtAutoThreadCacheInfo_t; + +typedef struct tagManualThreadCacheInfo { + rtManualThreadDmuInfo_t *dmuList; // 0-64k + uint16_t dmuNum; + uint16_t sliceDmuIdx[RT_FFTS_MAX_MANUAL_THREAD_NUM]; + uint16_t ticketCacheRefCntTbl[RT_FFTS_MAX_MANUAL_THREAD_NUM]; +} rtManualThreadCacheInfo_t; + +typedef enum tagCacheOp { + RT_CACHE_OP_NONE = 0, + RT_CACHE_OP_FLUSH = 1, + RT_CACHE_OP_INVALIDATE = 2, + RT_CACHE_OP_WRITE_BACK = 3, +} rtCacheOp_t; + +typedef struct tagTicketCache { + rtCacheOp_t cacheOption; + uint8_t ticketCacheWindow; + union { + rtAutoThreadCacheInfo_t autoThreadCache; + rtManualThreadCacheInfo_t manualThreadCache; + } custom; +} rtTicketCache_t; + +typedef struct tagManualThreadNopInfo { + // depend srcTickCacheVldBitmap in rtFftsSubTaskInfo_t + rtManualThreadDependency_t srcDepTbl[RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK]; +} rtManualThreadNopInfo_t; + +typedef struct tagFftsSubTaskInfo { + rtFftsSubTaskType_t subTaskType; + uint16_t threadDim; + uint8_t dstTickCacheVldBitmap; + uint8_t srcTickCacheVldBitmap; + uint8_t srcDataOutOfSubGraphBitmap; + uint8_t dstTickCacheID[RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK]; + uint8_t srcTickCacheID[RT_FFTS_MAX_TICKET_CACHE_PER_SUBTASK]; + union { + rtAutoThreadAicAivInfo_t autoThreadAicAiv; + rtManualThreadAicAivInfo_t manualThreadAicAiv; + rtManualThreadNopInfo_t manualThreadNop; + } custom; +} rtFftsSubTaskInfo_t; + +typedef struct tagFftsDescInfo { + uint8_t tm; // thread subtask kickstart mode, 0:order, 1:disorder + uint8_t di; // discard invalidate + uint8_t dw; // discard write back + uint8_t df; // discard flush + uint8_t 
dataSplitUnit; // split source or ticket cache by 2^dataSplitUnit MB + uint8_t prefetchOstNum; + uint8_t cacheMaintainOstNum; + uint8_t aicPrefetchUpper; + uint8_t aicPrefetchLower; + uint8_t aivPrefetchUpper; + uint8_t aivPrefetchLower; +} rtFftsDescInfo_t; + +typedef struct tagFftsTaskInfo { + rtFftsType_t fftsType; + uint16_t subTaskNum; + uint16_t tickCacheNum; + rtFftsDescInfo_t fftsDesc; + // sub task desc, real num is subTaskNum + rtFftsSubTaskInfo_t subTask[RT_FFTS_MAX_SUB_TASK_NUM]; + + // ticket cache, real number is tickCacheNum. + rtTicketCache_t ticketCache[RT_FFTS_MAX_TICKET_CACHE_NUM]; +} rtFftsTaskInfo_t; + +RTS_API rtError_t rtFftsTaskLaunch(rtFftsTaskInfo_t *fftsTaskInfo, rtStream_t stm); +RTS_API rtError_t rtGetC2cCtrlAddr(uint64_t *addr, uint32_t *len); + +RTS_API rtError_t rtFftsTaskLaunchWithFlag(rtFftsTaskInfo_t *fftsTaskInfo, rtStream_t stm, uint32_t flag); + +#if defined(__cplusplus) +} +#endif +#endif // CCE_RUNTIME_RT_FFTS_H diff --git a/inc/runtime/runtime/rt_ffts_plus.h b/inc/runtime/runtime/rt_ffts_plus.h new file mode 100644 index 0000000000000000000000000000000000000000..618bdfe132e48e695bb39617808a5beba7e79061 --- /dev/null +++ b/inc/runtime/runtime/rt_ffts_plus.h @@ -0,0 +1,53 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved. 
+ * Description: ffts plus interface + */ + +#ifndef CCE_RUNTIME_RT_FFTS_PLUS_H +#define CCE_RUNTIME_RT_FFTS_PLUS_H + +#include "base.h" +#include "rt_ffts_plus_define.h" +#include "rt_stars_define.h" +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +extern "C" { +#endif + +#pragma pack(push) +#pragma pack (1) + +// context desc addr type +#define RT_FFTS_PLUS_CTX_DESC_ADDR_TYPE_HOST (0x0U) +#define RT_FFTS_PLUS_CTX_DESC_ADDR_TYPE_DEVICE (0x1U) + +typedef struct tagFftsPlusDumpInfo { + const void *loadDumpInfo; + const void *unloadDumpInfo; + uint32_t loadDumpInfolen; + uint32_t unloadDumpInfolen; +} rtFftsPlusDumpInfo_t; + +typedef struct tagFftsPlusTaskInfo { + const rtFftsPlusSqe_t *fftsPlusSqe; + const void *descBuf; // include total context + size_t descBufLen; // the length of descBuf + rtFftsPlusDumpInfo_t fftsPlusDumpInfo; // used only in the dynamic shape + uint32_t descAddrType; // 0:host addr 1:device addr + uint32_t argsHandleInfoNum; + void **argsHandleInfoPtr; +} rtFftsPlusTaskInfo_t; + +#pragma pack(pop) + +RTS_API rtError_t rtGetAddrAndPrefCntWithHandle(void *hdl, const void *kernelInfoExt, void **addr, + uint32_t *prefetchCnt); + +RTS_API rtError_t rtFftsPlusTaskLaunch(rtFftsPlusTaskInfo_t *fftsPlusTaskInfo, rtStream_t stm); + +RTS_API rtError_t rtFftsPlusTaskLaunchWithFlag(rtFftsPlusTaskInfo_t *fftsPlusTaskInfo, rtStream_t stm, + uint32_t flag); + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +} +#endif +#endif // CCE_RUNTIME_RT_FFTS_PLUS_H diff --git a/inc/runtime/runtime/rt_ffts_plus_define.h b/inc/runtime/runtime/rt_ffts_plus_define.h new file mode 100644 index 0000000000000000000000000000000000000000..0ead3d7134f3338933ab3c08cd4d3bdd579f93ba --- /dev/null +++ b/inc/runtime/runtime/rt_ffts_plus_define.h @@ -0,0 +1,806 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved. 
+ * Description: the definition of ffts plus + */ + +#ifndef CCE_RUNTIME_RT_FFTS_PLUS_DEFINE_H +#define CCE_RUNTIME_RT_FFTS_PLUS_DEFINE_H + +#include "base.h" + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +extern "C" { +#endif + +#pragma pack(push) +#pragma pack (1) + +// hardware context type +typedef enum tagFftsPlusHwType { + RT_HW_CTX_TYPE_AIC = 0, + RT_HW_CTX_TYPE_AIV = 1, + RT_HW_CTX_TYPE_NOTIFY_WAIT = 3, + RT_HW_CTX_TYPE_NOTIFY_RECORD = 4, + RT_HW_CTX_TYPE_WRITE_VALUE = 5, + RT_HW_CTX_TYPE_MIX_AIC = 6, + RT_HW_CTX_TYPE_MIX_AIV = 7, + RT_HW_CTX_TYPE_SDMA = 8, + RT_HW_CTX_TYPE_FLUSH_DATA = 9, + RT_HW_CTX_TYPE_INVALIDATE_DATA = 10, + RT_HW_CTX_TYPE_WRITEBACK_DATA = 11, + RT_HW_CTX_TYPE_AICPU = 12, + RT_HW_CTX_TYPE_LOAD = 13, + RT_HW_CTX_TYPE_MAX = 14, +} rtFftsPlusHwType_t; + +// software context type +typedef enum tagFftsPlusSoftType { + RT_SOFT_CTX_TYPE_COND_SWITCH = 1, + RT_SOFT_CTX_TYPE_CASE_SWITCH = 2, + RT_SOFT_CTX_TYPE_AT_START = 3, + RT_SOFT_CTX_TYPE_AT_END = 4, + RT_SOFT_CTX_TYPE_LABEL = 5, + RT_SOFT_CTX_PERSISTENT_CACHE = 6, + RT_SOFT_CTX_DSA = 7, + RT_SOFT_CTX_TYPE_MAX = 8, +} rtFftsPlusSoftType_t; + +typedef enum tagFftsPlusContextType { + RT_CTX_TYPE_AICORE = 0x0000, + RT_CTX_TYPE_AIV = 0x0001, + RT_CTX_TYPE_NOTIFY_WAIT = 0x0003, + RT_CTX_TYPE_NOTIFY_RECORD = 0x0004, + RT_CTX_TYPE_WRITE_VALUE = 0x0005, + RT_CTX_TYPE_MIX_AIC = 0x0006, + RT_CTX_TYPE_MIX_AIV = 0x0007, + RT_CTX_TYPE_SDMA = 0x0008, + RT_CTX_TYPE_FLUSH_DATA = 0x0009, + RT_CTX_TYPE_INVALIDATE_DATA = 0x000A, + RT_CTX_TYPE_WRITEBACK_DATA = 0x000B, + RT_CTX_TYPE_AICPU = 0x000C, + RT_CTX_TYPE_COND_SWITCH = 0x010D, + RT_CTX_TYPE_CASE_SWITCH = 0x020D, + RT_CTX_TYPE_AT_START = 0x0300, + RT_CTX_TYPE_AT_END = 0x0400, + RT_CTX_TYPE_LABEL = 0x0500, + RT_CTX_TYPE_PERSISTENT_CACHE = 0x0600, + RT_CTX_TYPE_DSA = 0x0700, +}rtFftsPlusContextType_t; + +// condition type +typedef enum tagFftsPlusCondType { + RT_COND_TYPE_EQUAL = 0, + RT_COND_TYPE_NOTEQUAL = 1, + RT_COND_TYPE_GREATER = 2, +
RT_COND_TYPE_GREATER_OR_EQUAL = 3, + RT_COND_TYPE_LESS = 4, + RT_COND_TYPE_LESS_OR_EQUAL = 5, + RT_COND_TYPE_MAX = 6, +} rtFftsPlusCondType_t; + +// the definition of ffts plus context + +#define RT_CTX_SUCCESSOR_NUM 26 + +// ffts plus common context +typedef struct tagFftsPlusComCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t successorNum; + uint8_t rsv1 : 7; + uint8_t aten : 1; + // 4-7 + uint8_t rsv2; + uint8_t rsv3; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t rsv4; + // 12-63 + uint16_t successorList[RT_CTX_SUCCESSOR_NUM]; + // 64-71 + uint32_t rsv5[2]; + // 72-75 + uint16_t threadId; + uint16_t threadDim; + // 76-127 + uint32_t res6[13]; +} rtFftsPlusComCtx_t; + +// aic/aiv context +typedef struct tagFftsPlusAicAivCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t successorNum; + uint8_t resv : 6; + uint8_t dumpSwitch : 1; + uint8_t aten : 1; + // 4-7 + uint8_t prefetchConfig; + uint8_t resv1; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint16_t resv2; + uint16_t policyPri; + // 12-63 + uint16_t successorList[RT_CTX_SUCCESSOR_NUM]; + // 64-67 + uint16_t resv3 : 1; + uint16_t schem : 2; + uint16_t icachePrefetchCnt : 5; + uint16_t resv4 : 7; + uint16_t atm : 1; + uint16_t prefetchEnableBitmap : 4; + uint16_t res6 : 4; + uint16_t prefetchOnceBitmap : 4; + uint16_t res7 : 4; + // 68-71 + uint16_t pmg : 2; + uint16_t ns : 1; + uint16_t partId : 8; + uint16_t res8 : 1; + uint16_t qos : 4; + uint16_t res9; + // 72-75 + uint16_t threadId; + uint16_t threadDim; + // 76-79 + uint16_t nonTailBlockdim; + uint16_t tailBlockdim; + // 80-83 + uint32_t taskParamPtrBaseL; + // 84-87 + uint16_t taskParamPtrBaseH; + uint16_t taskParamPtrOffset; + // 88-95 + uint32_t res10; + uint32_t res11; + // 96-103 + uint32_t nonTailTaskStartPcL; + uint16_t nonTailTaskStartPcH; + uint16_t res12; + // 104-111 + uint32_t tailTaskStartPcL; + uint16_t tailTaskStartPcH; + uint16_t res13; + // 112-119 + uint32_t res14; + uint32_t res15; + // 120-127 + 
uint16_t srcSlot[4]; // src_slot0-3(context ID for source data which is out of subgraph) +} rtFftsPlusAicAivCtx_t; + +// mix aic/aiv context +typedef struct tagFftsPlusMixAicAivCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t successorNum; + uint8_t reserved1 : 6; + uint8_t dumpSwitch : 1; + uint8_t aten : 1; + // 4-7 + uint8_t prefetchConfig; + uint8_t reserved2; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint16_t reserved3; + uint16_t policyPri; + // 12-63 + uint16_t successorList[RT_CTX_SUCCESSOR_NUM]; + // 64-67 + uint16_t reserved4 : 1; + uint16_t schem : 2; + uint16_t aicIcachePrefetchCnt : 5; + uint16_t aivIcachePrefetchCnt : 5; + uint16_t reserved5 : 2; + uint16_t atm : 1; + uint16_t prefetchEnableBitmap : 4; + uint16_t reserved6 : 4; + uint16_t prefetchOnceBitmap : 4; + uint16_t reserved7 : 4; + // 68-71 + uint16_t pmg : 2; + uint16_t ns : 1; + uint16_t partId : 8; + uint16_t reserved8 : 1; + uint16_t qos : 4; + uint8_t nonTailBlockRatioN; + uint8_t tailBlockRatioN; + // 72-75 + uint16_t threadId; + uint16_t threadDim; + // 76-79 + uint16_t nonTailBlockdim; + uint16_t tailBlockdim; + // 80-87 + uint32_t aicTaskParamPtrL; + uint16_t aicTaskParamPtrH; + uint16_t aicTaskParamPtrOffset; + // 88-95 + uint32_t aivTaskParamPtrL; + uint16_t aivTaskParamPtrH; + uint16_t aivTaskParamPtrOffset; + // 96-103 + uint32_t nonTailAicTaskStartPcL; + uint16_t nonTailAicTaskStartPcH; + uint16_t tailAicTaskStartPcH; + // 104-111 + uint32_t tailAicTaskStartPcL; + uint32_t nonTailAivTaskStartPcL; + // 112-119 + uint16_t nonTailAivTaskStartPcH; + uint16_t tailAivTaskStartPcH; + uint32_t tailAivTaskStartPcL; + // 120-127 + uint16_t srcSlot[4]; // src_slot0-3(context ID for source data which is out of subgraph) +} rtFftsPlusMixAicAivCtx_t; + +// sdma context +typedef struct tagFftsPlusSdmaCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t successorNum; + uint8_t res1 : 6; + uint8_t dumpSwitch : 1; + uint8_t aten : 1; + // 4-7 + uint8_t res2; + uint8_t res3; + 
uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t res4; + // 12-63 + uint16_t successorList[RT_CTX_SUCCESSOR_NUM]; + // 64-67 + uint8_t res5; + uint8_t res6 : 7; + uint8_t atm : 1; + uint16_t res7; + // 68-71 + uint16_t pmg : 2; + uint16_t ns : 1; + uint16_t partId : 8; + uint16_t res8 : 1; + uint16_t qos : 4; + uint16_t res9; + // 72-75 + uint16_t threadId; + uint16_t threadDim; + // 76-79 + uint32_t sdmaSqeHeader; // (FORMAT/MPAMNS/PARTID/DRO/SRO/QOS/DNS/SNS/DSSV/SSSV/IE/UPCODE) + // 80-83 + uint16_t sourceStreamId; + uint16_t sourceSubstreamId; + // 84-87 + uint16_t destinationStreamId; + uint16_t destinationSubstreamId; + // 88-127 + uint32_t sourceAddressBaseL; + uint32_t sourceAddressBaseH; + uint32_t sourceAddressOffset; + uint32_t destinationAddressBaseL; + uint32_t destinationAddressBaseH; + uint32_t destinationAddressOffset; + uint32_t nonTailDataLength; + uint32_t tailDataLength; + uint32_t res10[2]; +} rtFftsPlusSdmaCtx_t; + +// ffts plus notify record/wait context +typedef struct tagFftsPlusNotifyCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t successorNum; + uint8_t res : 7; + uint8_t aten : 1; + // 4-7 + uint8_t res1; + uint8_t res2; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t res3; + // 12-63 + uint16_t successorList[RT_CTX_SUCCESSOR_NUM]; + // 64-67 + uint16_t res4 : 14; + uint16_t satm : 1; + uint16_t atm : 1; + uint16_t res6; + // 68-71 + uint32_t res7; + // 72-75 + uint16_t threadId; + uint16_t threadDim; + // 76-79 + uint16_t notifyIdBase; + uint8_t autoWindow; + uint8_t res8; + // 80-127 + uint32_t res9[4]; + uint16_t notifyId[16]; +} rtFftsPlusNotifyCtx_t; + +// write Value context +typedef struct tagFftsPlusWriteValueCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t successorNum; + uint8_t resv1 : 7; + uint8_t aten : 1; + // 4-7 + uint8_t resv2; + uint8_t resv3; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t resv4; + // 12-63 + uint16_t successorList[RT_CTX_SUCCESSOR_NUM]; + // 64-67 
+ uint16_t resv5 : 15; + uint16_t atm : 1; + uint16_t resv6; + // 68-71 + uint32_t resv7; + // 72-75 + uint16_t threadId; + uint16_t threadDim; + // 76-79 + uint8_t awSize : 3; + uint8_t awSnoop : 1; + uint8_t resv8 : 4; + uint8_t awCache : 4; + uint8_t awProt : 3; + uint8_t awVa : 1; + + uint8_t arSize : 3; + uint8_t arSnoop : 1; + uint8_t resv9 : 4; + uint8_t arCache : 4; + uint8_t arProt : 3; + uint8_t arVa : 1; + // 80-83 + uint32_t writeAddressBaseL; + // 84-87 + uint32_t writeAddressBaseH : 17; + uint32_t res10 : 15; + // 88-91 + uint32_t writeAddressOffset; + // 92-95 + uint32_t res11; + // 96-111 + uint32_t writeValue[4]; // write_value_00 -> write_value_03 + // 112-127 + uint32_t res12[4]; +} rtFftsPlusWriteValueCtx_t; + +// ai cpu context +typedef struct tagFftsPlusAiCpuCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t successorNum; + uint8_t res1 : 6; + uint8_t dumpSwitch : 1; + uint8_t aten : 1; + // 4-7 + uint8_t res2; + uint8_t res3; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t res4; + // 12-63 + uint16_t successorContextID[RT_CTX_SUCCESSOR_NUM]; + // 64-67 + uint16_t res5 : 15; + uint16_t atm : 1; + uint16_t res6; + // 68-71 + uint16_t sqeIndex; + uint8_t kernelType : 7; + uint8_t bm : 1; + uint8_t topicType : 4; + uint8_t qos : 3; + uint8_t res7 : 1; + // 72-75 + uint16_t threadId; + uint16_t threadDim; + // 76-79 + uint16_t nonTailBlockdim; + uint16_t tailBlockdim; + // 80-115 + uint32_t usrData[9]; // usr_data0 -> usr_data8 usr_data2(task_param_base_l) usr_data3(task_param_base_h) + // 116--119 + uint32_t res8; + // 120-123 + uint32_t subtopicId : 12; + uint32_t topicId : 6; + uint32_t groupId : 6; + uint32_t usrDataLength : 8; + // 124-127 + uint32_t taskParamOffset; +} rtFftsPlusAiCpuCtx_t; + +// data context +typedef struct tagFftsPlusDataCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t successorNum; + uint8_t res1 : 7; + uint8_t aten : 1; + // 4-7 + uint8_t res2; + uint8_t res3; + uint8_t cntInit; // cons_cnt_init 
/ prod_cnt_init + uint8_t cnt; // cons_cnt / prod_cnt + // 8-11 + uint32_t res4; + // 12-63 + uint16_t successorList[RT_CTX_SUCCESSOR_NUM]; + // 64-67 + uint16_t res5 : 15; + uint16_t atm : 1; + uint16_t res6; + // 68-71 + uint16_t pmg : 2; + uint16_t ns : 1; + uint16_t partId : 8; + uint16_t res7 : 1; + uint16_t qos : 4; + uint16_t res8; + // 72-75 + uint16_t threadId; + uint16_t threadDim; + // 76-79 + uint16_t origConsumerCounter; + uint16_t runConsumerCounter; + // 80-83 + uint32_t addressBaseL; + // 84-87 + uint32_t addressBaseH; + // 88-91 + uint32_t addressOffset; + // 92-95 + uint32_t res9; + // 96-99 + uint16_t nonTailNumOutter; + uint16_t nonTailNumInner; + // 100-103 + uint32_t nonTailLengthInner; + // 104-107 + uint32_t nonTailStrideOutter; + // 108-111 + uint32_t nonTailStrideInner; + // 112-115 + uint16_t tailNumOutter; + uint16_t tailNumInner; + // 116-119 + uint32_t tailLengthInner; + // 120-123 + uint32_t tailStrideOutter; + // 124-127 + uint32_t tailStrideInner; +} rtFftsPlusDataCtx_t; + +// at start context +typedef struct tagFftsPlusAtStartCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t successorNum; + uint8_t rs1 : 7; + uint8_t aten : 1; + // 4-7 + uint8_t rs2; + uint8_t rs3; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t rs4; + // 12-63 + uint16_t successorList[RT_CTX_SUCCESSOR_NUM]; + // 64-67 + uint16_t rs5; + uint16_t rs6; + // 68-71 + uint16_t rs7; + uint16_t rs8; + // 72-75 + uint16_t threadId; + uint16_t threadDim; + // 76-79 + uint16_t threadIdInit; + uint16_t threadWindowSize; + // 80-127 + uint32_t res9[12]; +} rtFftsPlusAtStartCtx_t; + +// at end context +#define RT_CTX_SUCC_AT_START_SLOT_NUM 12 +#define RT_CTX_SUCC_OUT_LABEL_SLOT_NUM 12 + +typedef struct tagFftsPlusAtEndCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t atStartSlotNumber; + uint8_t outLabelSlotNumber : 7; + uint8_t aten : 1; + // 4-7 + uint8_t res1; + uint8_t res2; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t res3; + // 
12-59 + uint16_t succAtStartSlot[RT_CTX_SUCC_AT_START_SLOT_NUM]; + uint16_t succOutLabelSlot[RT_CTX_SUCC_OUT_LABEL_SLOT_NUM]; + // 60-63 + uint16_t res4; + uint16_t res5; + // 64-67 + uint16_t res6; + uint16_t res7; + // 68-71 + uint16_t res8; + uint16_t res9; + // 72-75 + uint16_t threadId; + uint16_t res10; + // 76-79 + uint16_t res11; + uint16_t res12; + // 80-127 + uint32_t res13[12]; +} rtFftsPlusAtEndCtx_t; + +// label context +typedef struct tagFftsPlusLabelCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t successorNum; + uint8_t res1 : 7; + uint8_t aten : 1; + // 4-7 + uint8_t res2; + uint8_t res3; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t res4; + // 12-63 + uint16_t successorList[RT_CTX_SUCCESSOR_NUM]; + // 64-71 + uint32_t res5[2]; + // 72-75 + uint16_t threadId; + uint16_t threadDim; + + // 76-127 + uint32_t res6[13]; +} rtFftsPlusLabelCtx_t; + +// case switch context +typedef struct tagFftsPlusCaseSwitchCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t successorNum; + uint8_t resv0 : 7; + uint8_t aten : 1; + // 4-7 + uint8_t startLabelId; + uint8_t labelListLen; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t resv1; + // 12-63 + uint16_t successorList[RT_CTX_SUCCESSOR_NUM]; + // 64-67 + uint16_t resv2 : 15; + uint16_t atm : 1; + uint16_t resv3; + // 68-71 + uint32_t resv4; + // 72-75 + uint16_t threadId; + uint16_t threadDim; + // 76-79 + uint8_t arSize : 3; + uint8_t snoop : 1; + uint8_t resv5 : 4; + uint8_t arCache : 4; + uint8_t arProt : 3; + uint8_t va : 1; + uint16_t resv6; + // 80-83 + uint32_t loadAddress0BaseL; + // 84-87 + uint32_t loadAddress0BaseH : 17; + uint32_t resv7 : 14; + uint32_t ld0En : 1; + // 88-91 + uint32_t loadAddress0Offset; + // 92-95 + uint32_t resv8; + // 96-99 + uint32_t loadAddress1BaseL; + // 100-103 + uint32_t loadAddress1BaseH : 17; + uint32_t resv9 : 14; + uint32_t ld1En : 1; + // 104-107 + uint32_t loadAddress1Offset; + // 108-127 + uint32_t resv10[5]; +} 
rtFftsPlusCaseSwitchCtx_t; + +// case default context +typedef struct tagFftsPlusCaseDefCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t successorNum; + uint8_t rs0 : 7; + uint8_t aten : 1; + // 4-7 + uint8_t startLabelId; + uint8_t labelListLen; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t rs1; + // 12-63 + uint16_t successorList[RT_CTX_SUCCESSOR_NUM]; + // 64-67 + uint16_t rs2; + uint16_t rs3; + // 68-127 + uint32_t rs4[15]; +} rtFftsPlusCaseDefCtx_t; + +// condition switch context +#define RT_CTX_TRUE_SUCCESSOR_NUM 13 +#define RT_CTX_FALSE_SUCCESSOR_NUM 13 + +typedef struct tagFftsPlusCondSwitchCtx { + // 0-3 bytes + uint16_t contextType; + uint8_t trueSuccessorNum; + uint8_t falseSuccessorNum : 7; + uint8_t aten : 1; + // 4-7 + uint8_t condition; + uint8_t res1; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t res2; + // 12-63 + uint16_t trueSuccessorList[RT_CTX_TRUE_SUCCESSOR_NUM]; + uint16_t falseSuccessorList[RT_CTX_FALSE_SUCCESSOR_NUM]; + // 64-67 + uint16_t res3 : 15; + uint16_t atm : 1; + uint16_t res4; + // 68-71 + uint32_t res5; + // 72-75 + uint16_t threadId; + uint16_t threadDim; + // 76-79 + uint8_t arSize : 3; + uint8_t snoop : 1; + uint8_t res6 : 4; + uint8_t arCache : 4; + uint8_t arProt : 3; + uint8_t va : 1; + uint16_t res7; + // 80-83 + uint32_t loadAddress0BaseL; + // 84-87 + uint32_t loadAddress0BaseH : 17; + uint32_t res8 : 14; + uint32_t ld0En : 1; + // 88-91 + uint32_t loadAddress0Offset; + // 92-95 + uint32_t res9; + // 96-99 + uint32_t loadAddress1BaseL; + // 100-103 + uint32_t loadAddress1BaseH : 17; + uint32_t res10 : 14; + uint32_t ld1En : 1; + // 104-107 + uint32_t loadAddress1Offset; + // 108-127 + uint32_t res11[3]; + uint32_t cmpValue1; + uint32_t cmpValue2; +} rtFftsPlusCondSwitchCtx_t; + +// ffts plus persistent cache context +typedef struct tagFftsPlusPersistentCacheCtx { + // 0- 3bytes + uint16_t contextType; + uint8_t successorNum; + uint8_t res1 : 7; + uint8_t aten : 1; + // 4-7 + 
uint8_t res2[2]; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint8_t res3[4]; + // 12-63 + uint16_t successorList[RT_CTX_SUCCESSOR_NUM]; + // 64-67 + uint8_t persistentEnable : 1; + uint8_t res4 : 7; + uint8_t res5; + uint16_t persistentSize; + // 68-71 + uint32_t persistentId; + // 72-127 + uint32_t res6[14]; +} rtFftsPlusPersistentCacheCtx_t; + + +typedef struct tagFftsPlusDsaCtx { + // 0-3bytes + uint16_t contextType; + uint8_t successorNum; + uint8_t res1 : 6; + uint8_t dumpSwitch : 1; + uint8_t aten : 1; + // 4-7 + uint8_t res2[2]; + uint8_t predCntInit; + uint8_t predCnt; + // 8-11 + uint32_t res3; + // 12-63 + uint16_t successorList[RT_CTX_SUCCESSOR_NUM]; + + // bottom half 64B + // 0-3 bytes + uint16_t numberOffset; + uint16_t addressOffset; + // 4-7 bytes + uint32_t res5; + // 8-11 bytes + uint16_t threadId; + uint16_t threadDim; + // 12-15 bytes + uint32_t start : 1; + uint32_t functionType : 3; + uint32_t dataType : 3; + uint32_t algoType : 3; + uint32_t paramVldBitmap : 5; + uint32_t paramAddrValBitmap : 7; + uint32_t res6 : 10; + // 16-31 bytes + uint32_t dsaCfgResultAddrLow; + uint32_t dsaCfgResultAddrHigh; + uint32_t dsaCfgStateAddrLow; + uint32_t dsaCfgStateAddrHigh; + // 32-47 bytes + uint32_t dsaCfgParamAddrLow; + uint32_t dsaCfgParamAddrHigh; + uint32_t dsaCfgSeedLow; + uint32_t dsaCfgSeedHigh; + // 48-63 bytes + uint32_t dsaCfgNumberLow; + uint32_t dsaCfgNumberHigh; + uint32_t res7[2]; +} rtFftsPlusDsaCtx_t; + +#pragma pack(pop) + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +} +#endif +#endif // CCE_RUNTIME_RT_FFTS_PLUS_DEFINE_H diff --git a/inc/runtime/runtime/rt_mem_queue.h b/inc/runtime/runtime/rt_mem_queue.h new file mode 100644 index 0000000000000000000000000000000000000000..cc653faef6fdbe2350761b66af20fd11783b93b2 --- /dev/null +++ b/inc/runtime/runtime/rt_mem_queue.h @@ -0,0 +1,879 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved. 
+ * Description: mbuf and queue interface + */ + +#ifndef CCE_RUNTIME_RT_MEM_QUEUE_H +#define CCE_RUNTIME_RT_MEM_QUEUE_H + +#include "base.h" + +#if defined(__cplusplus) +extern "C" { +#endif + +#define RT_MQ_MAX_NAME_LEN 128 // same as driver's +#define RT_MQ_DEPTH_MIN 2U +#define RT_MQ_MODE_PUSH 1 +#define RT_MQ_MODE_PULL 2 +#define RT_MQ_MODE_DEFAULT RT_MQ_MODE_PUSH +#define RT_EVENT_SUMMARY_RSV 4 +#define RT_EVENT_MAX_MSG_LEN 128 +#define RT_MQ_LOCAL_QUEUE_DEPLOY 1U +#define RT_MQ_CLIENT_QUEUE_DEPLOY 0U + +#define RT_BUFF_SP_HUGEPAGE_ONLY (1 << 1) + +typedef struct tagMemQueueInfo { + int32_t id; + int32_t size; + uint32_t depth; + int32_t status; +} rtMemQueueInfo_t; + +typedef struct tagMemQueueAttr { + char_t name[RT_MQ_MAX_NAME_LEN]; + uint32_t depth; + uint32_t workMode; + uint32_t flowCtrlDropTime; + bool flowCtrlFlag; + bool overWriteFlag; + uint32_t deployType : 1; + uint32_t resv : 31; +} rtMemQueueAttr_t; + +typedef enum tagMemQueueSetCmdType { + RT_MQ_QUEUE_SET_WORK_MODE, + RT_MQ_QUEUE_ENABLE_LOCAL_QUEUE, + RT_MQ_QUEUE_SET_CMD_MAX, +} rtMemQueueSetCmdType; + +typedef struct tagMemQueueSetInputPara { + void *inBuff; + uint32_t inLen; +} rtMemQueueSetInputPara; + +typedef struct tagMemQueueShareAttr { + uint32_t manage : 1; + uint32_t read : 1; + uint32_t write : 1; + uint32_t rsv : 29; +} rtMemQueueShareAttr_t; + +typedef struct tagMemQueueBuffInfo { + void *addr; + size_t len; +} rtMemQueueBuffInfo; + +typedef struct tagMemQueueBuff { + void *contextAddr; + size_t contextLen; + rtMemQueueBuffInfo *buffInfo; + uint32_t buffCount; +} rtMemQueueBuff_t; + + +typedef enum tagMemQueueQueryCmd { + RT_MQ_QUERY_QUE_ATTR_OF_CUR_PROC = 0, // input is qid(4bytes), output is rtMemQueueShareAttr_t + RT_MQ_QUERY_QUES_OF_CUR_PROC = 1, + RT_MQ_QUERY_CMD_MAX = 2 +} rtMemQueueQueryCmd_t; + +#define RT_MQ_EVENT_QS_MSG 27 // same as driver's + +#define RT_MQ_SCHED_PRIORITY_LEVEL0 0 // same as driver's +#define RT_MQ_SCHED_PRIORITY_LEVEL1 1 +#define 
RT_MQ_SCHED_PRIORITY_LEVEL2 2 +#define RT_MQ_SCHED_PRIORITY_LEVEL3 3 +#define RT_MQ_SCHED_PRIORITY_LEVEL4 4 +#define RT_MQ_SCHED_PRIORITY_LEVEL5 5 +#define RT_MQ_SCHED_PRIORITY_LEVEL6 6 +#define RT_MQ_SCHED_PRIORITY_LEVEL7 7 + +/* Events can be released between different systems. This parameter specifies the destination type of events + to be released. The destination type is defined based on the CPU type of the destination system. */ +#define RT_MQ_DST_ENGINE_ACPU_DEVICE 0 // device AICPU, same as driver's +#define RT_MQ_DST_ENGINE_ACPU_HOST 1 // Host AICPU +#define RT_MQ_DST_ENGINE_CCPU_DEVICE 2 // device CtrlCPU +#define RT_MQ_DST_ENGINE_CCPU_HOST 3 // Host CtrlCPU +#define RT_MQ_DST_ENGINE_DCPU_DEVICE 4 // device DataCPU +#define RT_MQ_DST_ENGINE_TS_CPU 5 // device TS CPU +#define RT_MQ_DST_ENGINE_DVPP_CPU 6 // device DVPP CPU + +#define RT_MQ_SCHED_EVENT_QS_MSG 25 // same as driver's EVENT_QS_MSG +#define RT_MQ_SCHED_EVENT_DRV_CUSTOM_MSG 56 // driver's custom msg event + +/* When the destination engine is AICPU, select a policy. + ONLY: The command is executed only on the local AICPU. + FIRST: The local AICPU is preferentially executed. If the local AICPU is busy, the remote AICPU can be used.
*/ +#define RT_SCHEDULE_POLICY_ONLY 0 // same as driver's schedule_policy +#define RT_SCHEDULE_POLICY_FIRST 1 // same as driver's schedule_policy + + +typedef struct tagEschedEventSummary { + int32_t pid; // dst PID + uint32_t grpId; + int32_t eventId; // only RT_MQ_SCHED_EVENT_QS_MSG is supported + uint32_t subeventId; + uint32_t msgLen; + char_t *msg; + uint32_t dstEngine; // dst system cpu type + int32_t policy; // RT_SCHEDULE_POLICY_ONLY or RT_SCHEDULE_POLICY_FIRST +} rtEschedEventSummary_t; + +typedef struct tagEschedEventReply { + char_t *buf; + uint32_t bufLen; + uint32_t replyLen; // output, ack msg len, same with msgLen in halEschedAckEvent +} rtEschedEventReply_t; + +#define RT_DEV_PROCESS_CP1 0 +#define RT_DEV_PROCESS_CP2 1 +#define RT_DEV_PROCESS_DEV_ONLY 2 +#define RT_DEV_PROCESS_QS 3 +#define RT_DEV_PROCESS_SIGN_LENGTH 49 + +typedef struct tagBindHostpidInfo { + int32_t hostPid; + uint32_t vfid; + uint32_t chipId; + int32_t cpType; // type of custom-process, see RT_DEV_PROCESS_XXX +} rtBindHostpidInfo_t; + +#define RT_MEM_BUFF_MAX_CFG_NUM 64 + +typedef struct { + uint32_t cfgId; // cfg id, start from 0 + uint32_t totalSize; // one zone total size + uint32_t blkSize; // blk size, 2^n (0, 2M] + uint32_t maxBufSize; // max size can alloc from zone + uint32_t pageType; // page type, small page / huge page + int32_t elasticEnable; // elastic enable + int32_t elasticRate; + int32_t elasticRateMax; + int32_t elasticHighLevel; + int32_t elasticLowLevel; +} rtMemZoneCfg_t; + +typedef struct { + rtMemZoneCfg_t cfg[RT_MEM_BUFF_MAX_CFG_NUM]; +}rtMemBuffCfg_t; + +typedef enum rt_queue_work_mode { + RT_QUEUE_MODE_PUSH = 1, + RT_QUEUE_MODE_PULL, +} RT_QUEUE_WORK_MODE; + +typedef void *rtMbufPtr_t; + +typedef enum rtEventIdType { + RT_EVENT_RANDOM_KERNEL, /* Random operator event */ + RT_EVENT_DVPP_MSG, /* operator events committed by DVPP */ + RT_EVENT_FR_MSG, /* operator events committed by Feature retrieves */ + RT_EVENT_TS_HWTS_KERNEL, /* operator events committed
by ts/hwts */ + RT_EVENT_AICPU_MSG, /* aicpu activates its own stream events */ + RT_EVENT_TS_CTRL_MSG, /* controls message events of TS */ + RT_EVENT_QUEUE_ENQUEUE, /* entry event of Queue(consumer) */ + RT_EVENT_QUEUE_FULL_TO_NOT_FULL, /* full to non-full events of Queue(producers) */ + RT_EVENT_QUEUE_EMPTY_TO_NOT_EMPTY, /* empty to non-empty event of Queue(consumer) */ + RT_EVENT_TDT_ENQUEUE, /* data entry event of TDT */ + RT_EVENT_TIMER, /* ros timer */ + RT_EVENT_HCFI_SCHED_MSG, /* scheduling events of HCFI */ + RT_EVENT_HCFI_EXEC_MSG, /* performs the event of HCFI */ + RT_EVENT_ROS_MSG_LEVEL0, + RT_EVENT_ROS_MSG_LEVEL1, + RT_EVENT_ROS_MSG_LEVEL2, + RT_EVENT_ACPU_MSG_TYPE0, + RT_EVENT_ACPU_MSG_TYPE1, + RT_EVENT_ACPU_MSG_TYPE2, + RT_EVENT_CCPU_CTRL_MSG, + RT_EVENT_SPLIT_KERNEL, + RT_EVENT_DVPP_MPI_MSG, + RT_EVENT_CDQ_MSG, + /* Add a new event here */ + RT_EVENT_TEST, /* Reserve for test */ + RT_EVENT_MAX_NUM +} rtEventIdType_t; + +typedef enum rtGroupType { + /* Bound to a AICPU, multiple threads can be woken up simultaneously within a group */ + RT_GRP_TYPE_BIND_DP_CPU = 1, + RT_GRP_TYPE_BIND_CP_CPU, /* Bind to the control CPU */ + RT_GRP_TYPE_BIND_DP_CPU_EXCLUSIVE /* Bound to a AICPU, intra-group threads are mutex awakened */ +} rtGroupType_t; + +typedef struct tagInitFlowGwInfo { + const char_t *groupName; + uint64_t schedPolicy; + uint64_t reschedInterval; + char_t rsv[128]; +} rtInitFlowGwInfo_t; + +typedef enum tagBuffGetCmdType { + RT_BUFF_GET_MBUF_TIMEOUT_INFO = 0, + RT_BUFF_GET_MBUF_USE_INFO = 1, + RT_BUFF_GET_MBUF_TYPE_INFO = 2, + RT_BUFF_GET_MBUF_BUILD_INFO = 3, + RT_BUFF_GET_MAX +} rtBuffGetCmdType; + +typedef struct tagBuffBuildInfo { + uint32_t status; /* 0: buff unbuild 1: buff build */ +} rtBuffBuildInfo; + +RTS_API rtError_t rtBuffGetInfo(rtBuffGetCmdType type, const void * const inBuff, uint32_t inLen, + void * const outBuff, uint32_t * const outLen); + +/** + * @ingroup rt_mem_queue + * @brief init queue schedule + * @param [in] devId the 
logical device id + * @param [in] grpName the name of group, can be nullptr + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtMemQueueInitQS(int32_t devId, const char_t *grpName); + +/** + * @ingroup rt_mem_queue + * @brief init flow gateway + * @param [in] devId the logical device id + * @param [in] initInfo Initialization parameters + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtMemQueueInitFlowGw(int32_t devId, const rtInitFlowGwInfo_t * const initInfo); + +/** + * @ingroup rt_mem_queue + * @brief create mbuf queue + * @param [in] devId the logical device id + * @param [in] queAttr attribute of queue + * @param [out] qid queue id + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtMemQueueCreate(int32_t devId, const rtMemQueueAttr_t *queAttr, uint32_t *qid); + +/** + * @ingroup rt_mem_queue + * @brief mbuf queue set + * @param [in] devId the logical device id + * @param [in] cmd cmd type of queue set + * @param [in] input input param of queue set + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtMemQueueSet(int32_t devId, rtMemQueueSetCmdType cmd, const rtMemQueueSetInputPara *input); + +/** + * @ingroup rt_mem_queue + * @brief destroy mbuf queue + * @param [in] devId the logical device id + * @param [in] qid queue id + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtMemQueueDestroy(int32_t devId, uint32_t qid); + +/** + * @ingroup rt_mem_queue + * @brief mbuf queue init + * @param [in] devId the logical device id + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtMemQueueInit(int32_t devId); + +/** +* @ingroup rt_mem_queue +* @brief queue reset +* @attention null +* @param [in] qid: qid +* @param [in] devId: logic devid +* @return 0 for success, others for fail +**/ +RTS_API rtError_t rtMemQueueReset(int32_t devId, uint32_t qid); + +/** + * @ingroup rt_mem_queue + * @brief enqueue memBuf + * @param [in] devId the logical device id + * @param [in] qid queue id + * @param [in] memBuf enqueue memBuf + *
@return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtMemQueueEnQueue(int32_t devId, uint32_t qid, void *memBuf); + + +/** + * @ingroup rt_mem_queue + * @brief dequeue memBuf + * @param [in] devId the logical device id + * @param [in] qid queue id + * @param [out] memBuf dequeue memBuf + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtMemQueueDeQueue(int32_t devId, uint32_t qid, void **memBuf); + +/** + * @ingroup rt_mem_queue + * @brief queue peek + * @param [in] devId the logical device id + * @param [in] qid queue id + * @param [out] bufLen length of mbuf in queue + * @param [in] timeout peek timeout (ms), -1: wait all the time until peeking success + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtMemQueuePeek(int32_t devId, uint32_t qid, size_t *bufLen, int32_t timeout); + +/** + * @ingroup rt_mem_queue + * @brief enqueue buff + * @param [in] devId the logical device id + * @param [in] qid queue id + * @param [in] inBuf enqueue buff + * @param [in] timeout enqueue timeout (ms), -1: wait all the time until enqueue success + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtMemQueueEnQueueBuff(int32_t devId, uint32_t qid, rtMemQueueBuff_t *inBuf, int32_t timeout); + +/** + * @ingroup rt_mem_queue + * @brief dequeue buff + * @param [in] devId the logical device id + * @param [in] qid queue id + * @param [out] outBuf dequeue buff + * @param [in] timeout dequeue timeout (ms), -1: wait all the time until dequeue success + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtMemQueueDeQueueBuff(int32_t devId, uint32_t qid, rtMemQueueBuff_t *outBuf, int32_t timeout); + +/** + * @ingroup rt_mem_queue + * @brief query current queue info + * @param [in] devId the logical device id + * @param [in] qid queue id + * @param [out] queInfo current queue info + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtMemQueueQueryInfo(int32_t devId, uint32_t qid, rtMemQueueInfo_t *queInfo); + +/** +* @ingroup rt_mem_queue +* @brief query queue
status +* @param [in] devId: the logical device id +* @param [in] cmd: query cmd +* @param [in] inBuff: input buff +* @param [in] inLen: the length of input +* @param [in|out] outBuff: output buff +* @param [in|out] outLen: the length of output +* @return RT_ERROR_NONE for ok +*/ +RTS_API rtError_t rtMemQueueQuery(int32_t devId, rtMemQueueQueryCmd_t cmd, const void *inBuff, uint32_t inLen, + void *outBuff, uint32_t *outLen); + +/** +* @ingroup rt_mem_queue +* @brief grant queue +* @param [in] devId: logic devid +* @param [in] qid: queue id +* @param [in] pid: pid +* @param [in] attr: queue share attr +* @return RT_ERROR_NONE for ok +*/ +RTS_API rtError_t rtMemQueueGrant(int32_t devId, uint32_t qid, int32_t pid, rtMemQueueShareAttr_t *attr); + +/** +* @ingroup rt_mem_queue +* @brief attach queue +* @param [in] devId: logic devid +* @param [in] qid: queue id +* @param [in] timeOut: timeOut +* @return RT_ERROR_NONE for ok +*/ +RTS_API rtError_t rtMemQueueAttach(int32_t devId, uint32_t qid, int32_t timeOut); + +/** +* @ingroup rt_mem_queue +* @brief Commit the event to a specific process +* @param [in] devId: logic devid +* @param [in] evt: event summary info +* @param [out] ack: event reply info +* @return RT_ERROR_NONE for ok +*/ +RTS_API rtError_t rtEschedSubmitEventSync(int32_t devId, rtEschedEventSummary_t *evt, + rtEschedEventReply_t *ack); + +/** +* @ingroup rt_mem_queue +* @brief query device process id +* @param [in] info: see struct rtBindHostpidInfo_t +* @param [out] devPid: device process id +* @return RT_ERROR_NONE for ok +*/ +RTS_API rtError_t rtQueryDevPid(rtBindHostpidInfo_t *info, int32_t *devPid); + +/** +* @ingroup rt_mem_queue +* @brief device buff init +* @param [in] cfg, init cfg +* @return RT_ERROR_NONE for ok +*/ +RTS_API rtError_t rtMbufInit(rtMemBuffCfg_t *cfg); + +/** +* @ingroup rt_mem_queue +* @brief alloc buff +* @param [out] buff: The buff id the shared memory pointer applied by calling halBuffAlloc and halBuffAllocByPool +* @param [in]
size: The amount of memory space requested +* @return RT_ERROR_NONE for ok +*/ +RTS_API rtError_t rtBuffAlloc(const uint64_t size, void **buff); + +/** +* @ingroup rt_mem_queue +* @brief determine whether buff id is the shared memory pointer applied by calling halBuffAlloc and halBuffAllocByPool +* @param [in] buff: The buff id the shared memory pointer applied by calling halBuffAlloc and halBuffAllocByPool +* @param [in] size: The amount of memory space requested +* @return RT_ERROR_NONE for ok +*/ +RTS_API rtError_t rtBuffConfirm(void *buff, const uint64_t size); + +/** +* @ingroup rt_mem_queue +* @brief alloc buff +* @param [out] mbufPtr: buff addr alloced +* @param [in] buff: The buff must be the shared memory pointer applied by calling halBuffAlloc and halBuffAllocByPool +* @param [in] size: The amount of memory space requested +* @return RT_ERROR_NONE for ok +*/ +RTS_API rtError_t rtMbufBuild(void *buff, const uint64_t size, rtMbufPtr_t *mbufPtr); + +/** +* @ingroup rt_mem_queue +* @brief alloc buff +* @param [out] memBuf: buff addr alloced +* @param [in] size: The amount of memory space requested +* @return RT_ERROR_NONE for ok +*/ +RTS_API rtError_t rtMbufAlloc(rtMbufPtr_t *memBuf, uint64_t size); + +/** +* @ingroup rt_mem_queue +* @brief alloc buff +* @param [out] memBuf: buff addr alloced +* @param [in] size: The amount of memory space requested +* @param [in] flag: Huge page flag(bit0~31: mem type, bit32~bit35: devid, bit36~63: resv) +* @param [in] grpId: group id +* @return RT_ERROR_NONE for ok +*/ +RTS_API rtError_t rtMbufAllocEx(rtMbufPtr_t *memBuf, uint64_t size, uint64_t flag, int32_t grpId); + +/** +* @ingroup rt_mem_queue +* @brief free buff +* @param [in] buff: The buff id the shared memory pointer applied by calling halBuffAlloc and halBuffAllocByPool +* @return RT_ERROR_NONE for ok +*/ +RTS_API rtError_t rtBuffFree(void *buff); + +/** +* @ingroup rt_mem_queue +* @brief free the head of mbufPtr +* @param [in] mbufPtr: buff addr alloced +* @param 
[out] buff: The buffer of mbuPtr +* @param [out] size: The amount of memory space of buffer +* @return RT_ERROR_NONE for ok +*/ +RTS_API rtError_t rtMbufUnBuild(const rtMbufPtr_t mbufPtr, void **buff, uint64_t *size); + +/** +* @ingroup rt_mem_queue +* @brief get mbuffer +* @param [in] mbufPtr: buff addr alloced +* @param [out] buff: The buffer of mbufPtr +* @param [in] size: The amount of memory space of buffer +* @return RT_ERROR_NONE for ok +*/ +RTS_API rtError_t rtBuffGet(const rtMbufPtr_t mbufPtr, void *buff, const uint64_t size); + +/** +* @ingroup rt_mem_queue +* @brief put mbuffer +* @param [in] mbufPtr: buff addr alloced +* @param [out] buff: The buffer of mbufPtr +* @param [out] size: The amount of memory space of buffer +* @return RT_ERROR_NONE for ok +*/ +RTS_API rtError_t rtBuffPut(const rtMbufPtr_t mbufPtr, void *buff); + + +/** +* @ingroup rt_mem_queue +* @brief free buff +* @param [in] memBuf: buff addr to be freed +* @return RT_ERROR_NONE for ok +*/ +RTS_API rtError_t rtMbufFree(rtMbufPtr_t memBuf); + +/** +* @ingroup rt_mem_queue +* @brief set Data len of Mbuf +* @param [in] memBuf: Mbuf addr +* @param [in] len: data len +* @return RT_ERROR_NONE for success, others for fail +*/ +RTS_API rtError_t rtMbufSetDataLen(rtMbufPtr_t memBuf, uint64_t len); + +/** +* @ingroup rt_mem_queue +* @brief get Data len of Mbuf +* @param [in] memBuf: Mbuf addr +* @param [out] len: data len +* @return RT_ERROR_NONE for success, others for fail +*/ +RTS_API rtError_t rtMbufGetDataLen(rtMbufPtr_t memBuf, uint64_t *len); + + +/** +* @ingroup rt_mem_queue +* @brief get Data addr of Mbuf +* @param [in] memBuf: Mbuf addr +* @param [out] buf: Mbuf data addr +* @return RT_ERROR_NONE for ok +*/ +RTS_API rtError_t rtMbufGetBuffAddr(rtMbufPtr_t memBuf, void **buf); + +/** +* @ingroup rt_mem_queue +* @brief get total Buffer size of Mbuf +* @param [in] memBuf: Mbuf addr +* @param [out] totalSize: total buffer size of Mbuf +* @return RT_ERROR_NONE for ok +*/ +RTS_API rtError_t
rtMbufGetBuffSize(rtMbufPtr_t memBuf, uint64_t *totalSize); + +/** +* @ingroup rt_mem_queue +* @brief Get the address and length of its user_data from the specified Mbuf +* @param [in] memBuf: Mbuf addr +* @param [out] priv: address of its user_data +* @param [out] size: length of its user_data +* @return RT_ERROR_NONE for ok +*/ +RTS_API rtError_t rtMbufGetPrivInfo(rtMbufPtr_t memBuf, void **priv, uint64_t *size); + +/** +* @ingroup rt_mem_queue +* @brief copy buf ref +* @param [in] memBuf: src buff addr +* @param [out] newMemBuf: dst buff addr +* @return RT_ERROR_NONE for ok +*/ +RTS_API rtError_t rtMbufCopyBufRef(rtMbufPtr_t memBuf, rtMbufPtr_t *newMemBuf); + +/** +* @ingroup rt_mem_queue +* @brief append mbuf to mbuf chain +* @param [inout] memBufChainHead, the mbuf chain head +* @param [in] memBuf, the mbuf to append +* @return RT_ERROR_NONE for ok +*/ +RTS_API rtError_t rtMbufChainAppend(rtMbufPtr_t memBufChainHead, rtMbufPtr_t memBuf); + +/** +* @ingroup rt_mem_queue +* @brief get mbuf num in mbuf chain +* @param [in] memBufChainHead, the mbuf chain head +* @param [out] num, the mbuf chain size +* @return RT_ERROR_NONE for ok +*/ +RTS_API rtError_t rtMbufChainGetMbufNum(rtMbufPtr_t memBufChainHead, uint32_t *num); + +/** +* @ingroup rt_mem_queue +* @brief get mbuf in mbuf chain +* @param [in] memBufChainHead, the mbuf chain head +* @param [in] index, the mbuf index which to get in chain +* @param [out] memBuf, the mbuf to get +* @return RT_ERROR_NONE for ok +*/ +RTS_API rtError_t rtMbufChainGetMbuf(rtMbufPtr_t memBufChainHead, uint32_t index, rtMbufPtr_t *memBuf); + +#define RT_MEM_GRP_NAME_LEN 32 // it must be same as driver define BUFF_GRP_NAME_LEN +#define RT_MEM_CACHE_MAX_NUM 1024 // it must be same as driver define BUFF_CACHE_MAX_NUM + +// mem group +typedef struct { + uint64_t maxMemSize; // max buf size in grp, in KB.
= 0 means no limit + uint32_t cacheAllocFlag; + uint32_t addGrpTimeout; + int32_t rsv[RT_MEM_GRP_NAME_LEN - 2]; +} rtMemGrpConfig_t; + +typedef struct { + uint32_t admin : 1; // admin permission, can add other proc to grp + uint32_t read : 1; // read only permission + uint32_t write : 1; // read and write permission + uint32_t alloc : 1; // alloc permission (have read and write permission) + uint32_t rsv : 28; +} rtMemGrpShareAttr_t; + +typedef enum tagGroupQueryCmdType { + RT_MEM_GRP_QUERY_GROUP, /* query grp info include proc and permission */ + RT_MEM_GRP_QUERY_GROUPS_OF_PROCESS, /* query process all grp */ + RT_MEM_GRP_QUERY_GROUP_ID, /* query grp ID by grp name */ + RT_MEM_GRP_QUERY_GROUP_ADDR_INFO, /* query group addr info */ + RT_MEM_GRP_QUERY_CMD_MAX +} rtGroupQueryCmdType; + +typedef struct { + int32_t pid; +} rtMemGrpQueryByProc_t; // cmd: RT_MEM_GRP_QUERY_GROUPS_OF_PROCESS + +typedef struct { + char grpName[RT_MEM_GRP_NAME_LEN]; +} rtMemGrpQueryGroupId_t; // cmd: RT_MEM_GRP_QUERY_GROUP_ID + +typedef struct { + char grpName[RT_MEM_GRP_NAME_LEN]; + uint32_t devId; +} rtMemGrpQueryGroupAddrPara_t; /* cmd: RT_MEM_GRP_QUERY_GROUP_ADDR_INFO */ + +typedef struct { + int32_t cmd; + union { + rtMemGrpQueryByProc_t grpQueryByProc; // cmd: RT_MEM_GRP_QUERY_GROUPS_OF_PROCESS + rtMemGrpQueryGroupId_t grpQueryGroupId; // cmd: RT_MEM_GRP_QUERY_GROUP_ID + rtMemGrpQueryGroupAddrPara_t grpQueryGroupAddrPara; // cmd: RT_MEM_GRP_QUERY_GROUP_ADDR_INFO + }; +} rtMemGrpQueryInput_t; + +typedef struct { + char_t groupName[RT_MEM_GRP_NAME_LEN]; // group name + rtMemGrpShareAttr_t attr; // process in group attribute +} rtMemGrpOfProc_t; // cmd: RT_MEM_GRP_QUERY_GROUPS_OF_PROCESS + +typedef struct { + int32_t groupId; // group id +} rtMemGrpQueryGroupIdInfo_t; // cmd: RT_MEM_GRP_QUERY_GROUP_ID + +typedef struct { + uint64_t addr; /* cache memory addr */ + uint64_t size; /* cache memory size */ +} rtMemGrpQueryGroupAddrInfo_t; /* cmd: RT_MEM_GRP_QUERY_GROUP_ADDR_INFO */ + +typedef struct { + size_t
maxNum; // max number of result + size_t resultNum; // if the number of results exceeds 'maxNum', only 'maxNum' results are filled in buffer + union { + rtMemGrpOfProc_t *groupsOfProc; // cmd: RT_MEM_GRP_QUERY_GROUPS_OF_PROCESS + rtMemGrpQueryGroupIdInfo_t *groupIdInfo; // cmd: RT_MEM_GRP_QUERY_GROUP_ID + rtMemGrpQueryGroupAddrInfo_t *groupAddrInfo; // cmd: RT_MEM_GRP_QUERY_GROUP_ADDR_INFO + }; +} rtMemGrpQueryOutput_t; + +typedef struct { + uint64_t memSize; + uint32_t memFlag; + int32_t rsv[RT_MEM_CACHE_MAX_NUM]; +} rtMemGrpCacheAllocPara; + +/** +* @ingroup rt_mem_queue +* @brief create mem group +* @attention null +* @param [in] name, group name +* @param [in] cfg, group cfg +* @return 0 for success, others for fail +*/ +RTS_API rtError_t rtMemGrpCreate(const char_t *name, const rtMemGrpConfig_t *cfg); + +/** +* @ingroup rt_mem_queue +* @brief alloc mem group cache +* @attention null +* @param [in] name, group name +* @param [in] devId, device id +* @param [in] para, mem group cache alloc para +* @return 0 for success, others for fail +*/ +RTS_API rtError_t rtMemGrpCacheAlloc(const char_t *name, int32_t devId, const rtMemGrpCacheAllocPara *para); + +/** +* @ingroup rt_mem_queue +* @brief add process to group +* @param [in] name, group name +* @param [in] pid, process id +* @param [in] attr, process permission in group +* @return 0 for success, others for fail +*/ +RTS_API rtError_t rtMemGrpAddProc(const char_t *name, int32_t pid, const rtMemGrpShareAttr_t *attr); + +/** +* @ingroup rt_mem_queue +* @brief attach process to check permission in group +* @param [in] name, group name +* @param [in] timeout, time out ms +* @return 0 for success, others for fail +*/ +RTS_API rtError_t rtMemGrpAttach(const char_t *name, int32_t timeout); + +/** +* @ingroup rt_mem_queue +* @brief buff group query +* @param [in] input, query input +* @param [in|out] output, query output +* @return 0 for success, others for fail +*/ +RTS_API rtError_t rtMemGrpQuery(const rtMemGrpQueryInput_t
*input, rtMemGrpQueryOutput_t *output); + +/** +* @ingroup rt_mem_queue +* @brief get queue id by name +* @param [in] devId, device id +* @param [in] name, group name +* @param [out] qId, queue id +* @return 0 for success, others for fail +*/ +RTS_API rtError_t rtMemQueueGetQidByName(int32_t devId, const char_t *name, uint32_t *qId); + +/** +* @ingroup rt_mem_queue +* @brief esched attach device +* @param [in] devId, device id +* @return 0 for success, others for fail +*/ +RTS_API rtError_t rtEschedAttachDevice(int32_t devId); + +/** +* @ingroup rt_mem_queue +* @brief esched detach device +* @param [in] devId, device id +* @return 0 for success, others for fail +*/ +RTS_API rtError_t rtEschedDettachDevice(int32_t devId); + +/** +* @ingroup rt_mem_queue +* @brief esched wait event +* @param [in] devId, device id +* @param [in] grpId, group id +* @param [in] threadId, thread id +* @param [in] timeout +* @param [in] evt +* @return 0 for success, others for fail +*/ +RTS_API rtError_t rtEschedWaitEvent(int32_t devId, uint32_t grpId, uint32_t threadId, + int32_t timeout, rtEschedEventSummary_t *evt); + +/** +* @ingroup rt_mem_queue +* @brief esched create group +* @param [in] devId, device id +* @param [in] grpId, group id +* @param [in] type, group type +* @return 0 for success, others for fail +*/ +RTS_API rtError_t rtEschedCreateGrp(int32_t devId, uint32_t grpId, rtGroupType_t type); + +/** +* @ingroup rt_mem_queue +* @brief esched submit event +* @param [in] devId, device id +* @param [in] evt +* @return 0 for success, others for fail +*/ +RTS_API rtError_t rtEschedSubmitEvent(int32_t devId, rtEschedEventSummary_t *evt); + +/** +* @ingroup rt_mem_queue +* @brief esched subscribe event +* @param [in] devId, device id +* @param [in] grpId, group id +* @param [in] threadId, thread id +* @param [in] eventBitmap +* @return 0 for success, others for fail +*/ +RTS_API rtError_t rtEschedSubscribeEvent(int32_t devId, uint32_t grpId, uint32_t threadId, uint64_t eventBitmap); +
+/** +* @ingroup rtEschedAckEvent +* @brief esched ack event +* @param [in] devId, device id +* @param [in] evtId, event type +* @param [in] subEvtId, sub event type +* @param [in] msg, message info +* @param [in] len, message length +* @return 0 for success, others for fail +*/ +RTS_API rtError_t rtEschedAckEvent(int32_t devId, rtEventIdType_t evtId, + uint32_t subEvtId, char_t *msg, uint32_t len); + +/** +* @ingroup rtQueueSubF2NFEvent +* @brief full to not full event +* @param [in] devId, device id +* @param [in] qid, queue id +* @param [in] groupId, group id +* @return 0 for success, others for fail +*/ +RTS_API rtError_t rtQueueSubF2NFEvent(int32_t devId, uint32_t qId, uint32_t groupId); + +/** +* @ingroup rtQueueSubscribe +* @brief queue subscribe +* @param [in] devId, device id +* @param [in] qid, queue id +* @param [in] groupId, group id +* @param [in] type + +* @return 0 for success, others for fail +*/ +RTS_API rtError_t rtQueueSubscribe(int32_t devId, uint32_t qId, uint32_t groupId, int32_t type); + +/** +* @ingroup rtBufEventTrigger +* @brief buf event trigger +* @param [in] name, group name +* @return 0 for success, others for fail +*/ +RTS_API rtError_t rtBufEventTrigger(const char_t *name); + +// EschedQueryInfo group +#define EVENT_MAX_GRP_NAME_LEN 16 +typedef enum tagEschedQueryType { + RT_QUERY_TYPE_LOCAL_GRP_ID, + RT_QUERY_TYPE_REMOTE_GRP_ID, + RT_QUERY_TYPE_MAX +} rtEschedQueryType; + +typedef struct tagEschedInputInfo { + void *inBuff; + unsigned int inLen; +} rtEschedInputInfo; + +typedef struct tagEschedOutputInfo { + void *outBuff; + unsigned int outLen; +} rtEschedOutputInfo; + +typedef struct tagEschedQueryGidInput { + int pid; + char grpName[EVENT_MAX_GRP_NAME_LEN]; +} rtEschedQueryGidInput; + +typedef struct tagEschedQueryGidOutput { + unsigned int grpId; +} rtEschedQueryGidOutput; + +/** +* @ingroup rtEschedQueryInfo +* @brief query esched info, such as grpid. 
+* @param [in] devId: logic devid +* @param [in] type: query info type +* @param [in] inPut: Input the corresponding data structure based on the type. +* @param [out] outPut: OutPut the corresponding data structure based on the type. +* @return 0 for success, others for fail +*/ +RTS_API rtError_t rtEschedQueryInfo(const uint32_t devId, const rtEschedQueryType type, + rtEschedInputInfo *inPut, rtEschedOutputInfo *outPut); +#if defined(__cplusplus) +} +#endif +#endif // CCE_RUNTIME_RT_MEM_QUEUE_H diff --git a/inc/runtime/runtime/rt_model.h b/inc/runtime/runtime/rt_model.h new file mode 100644 index 0000000000000000000000000000000000000000..51719f978092c8ab6a024545cb3a0d171e2db6ba --- /dev/null +++ b/inc/runtime/runtime/rt_model.h @@ -0,0 +1,733 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. + * Description: rt_model.h + * Create: 2020-01-01 + */ + +#ifndef CCE_RUNTIME_RT_MODEL_H +#define CCE_RUNTIME_RT_MODEL_H + +#include "base.h" +#include "rt_ffts_plus.h" + +#if defined(__cplusplus) +extern "C" { +#endif + +typedef enum tagModelTaskType { + RT_MODEL_TASK_KERNEL = 0, + RT_MODEL_TASK_EVENT_RECORD, + RT_MODEL_TASK_EVENT_WAIT, + RT_MODEL_TASK_FUSION_START, + RT_MODEL_TASK_FUSION_END, + RT_MODEL_TASK_KERNEL_EX, + RT_MODEL_TASK_HCCL, + RT_MODEL_TASK_STREAM_SWITCH, + RT_MODEL_TASK_STREAM_ACTIVE, + RT_MODEL_TASK_LABEL_SET, + RT_MODEL_TASK_LABEL_SWITCH, + RT_MODEL_TASK_LABEL_GOTO, + RT_MODEL_TASK_PROFILER_TRACE, + RT_MODEL_TASK_MEMCPY_ASYNC, + RT_MODEL_TASK_NOTIFY_RECORD, + RT_MODEL_TASK_NOTIFY_WAIT, + RT_MODEL_TASK_REDUCE_ASYNC, + RT_MODEL_TASK_RDMA_SEND, + RT_MODEL_TASK_EVENT_RESET, + RT_MODEL_TASK_MODEL_END_GRAPH, + RT_MODEL_TASK_STREAM_SWITCH_N, + RT_MODEL_TASK_RDMA_DB_SEND, + RT_MODEL_TASK_MEMCPY_ADDR_ASYNC, + RT_MODEL_TASK_STREAM_LABEL_SWITCH_BY_INDEX, + RT_MODEL_TASK_STREAM_LABEL_GOTO, + RT_MODEL_TASK_MODEL_EXIT, + RT_MODEL_TASK_ALL_KERNEL, + RT_MODEL_TASK_PROFILER_TRACE_EX, + RT_MODEL_TASK_FFTS_TASK, + 
RT_MODEL_TASK_FFTS_PLUS_TASK, + RT_MODEL_TASK_DSA_TASK, + RT_MODEL_TASK_CMO, + RT_MODEL_TASK_BARRIER, + RT_MODEL_TASK_NPU_GET_FLOAT_STATUS, + RT_MODEL_TASK_NPU_CLEAR_FLOAT_STATUS, + RT_MODEL_TASK_DVPP, + RT_MODEL_TASK_NPU_GET_DEBUG_FLOAT_STATUS, + RT_MODEL_TASK_NPU_CLEAR_DEBUG_FLOAT_STATUS, + RT_MODEL_TASK_CMO_ADDR, + RT_MODEL_TASK_VECTOR_KERNEL, + RT_MODEL_TASK_VECTOR_ALL_KERNEL, + RT_MODEL_TASK_UPDATE, + RT_MODEL_TASK_NOP, + RT_MODEL_TASK_PREPROCESS_KERNEL, +} rtModelTaskType_t; + +typedef enum tagModelStreamType { + RT_MODEL_HEAD_STREAM = 0, + RT_MODEL_WAIT_ACTIVE_STREAM = 1 +} rtModelStreamType_t; + +typedef enum tagModelQueueFlag { + RT_MODEL_INPUT_QUEUE = 0, + RT_MODEL_OUTPUT_QUEUE = 1 +} rtModelQueueFlag_t; + +#define EXECUTOR_NONE (0x0U) +#define EXECUTOR_TS (0x01U) +#define EXECUTOR_AICPU (0x02U) + +/* + * @ingroup rt_model + * @brief debug flag for kernel exception dump + */ +#define RT_DEBUG_FLAG_AICORE_OVERFLOW (0x1U << 0U) +#define RT_DEBUG_FLAG_ATOMIC_ADD_OVERFLOW (0x1U << 1U) + +/** + * @ingroup rt_model + * @brief the type definition of aicpu model task command + */ +typedef enum tagTsAicpuModelCmd { + TS_AICPU_MODEL_LOAD = 1, + TS_AICPU_MODEL_EXECUTE, + TS_AICPU_MODEL_DESTROY, + TS_AICPU_MODEL_ABORT, + TS_AICPU_MODEL_RESERVED, +} tsAicpuModelCmd; + +typedef struct tagAicpuTaskInfo { + uint32_t taskID; + uint32_t streamID; + uint32_t kernelType; + uint64_t kernelName; + uint64_t kernelSo; + uint64_t paraBase; + uint32_t taskFlag; +} rtAicpuTaskInfo_t; + +typedef struct tagModelStreamInfo { + uint32_t streamID; + uint32_t streamFlag; +} rtModelStreamInfo_t; + +typedef struct tagModelQueueInfo { + uint32_t queueID; + uint32_t flag; +} rtModelQueueInfo_t; + +typedef struct tagAicpuModelInfo { + uint32_t moduleID; + uint32_t tsId; + uint16_t streamInfoNum; + uint16_t aicpuTaskNum; + uint64_t streamInfoPtr; + uint64_t aicpuTaskPtr; + uint16_t queueSize; + uint64_t queueInfoPtr; +} rtAicpuModelInfo_t; + +typedef struct tagKernelTaskInfo { + uint16_t blockDim; +
uint16_t argsCount; + uint16_t argsSize; + uint16_t reserved; + const char_t *stubFunc; + uint8_t *smDesc; + const uint8_t *args; + uint16_t *argsOffset; +} rtKernelTaskInfo_t; + +typedef struct tagAllKernelTaskInfo { + uint16_t blockDim; + uint16_t argsCount; + uint16_t argsSize; + uint16_t reserved; + uint64_t tilingKey; + void *handle; + uint8_t *smDesc; + const uint8_t *args; + uint16_t *argsOffset; +} rtAllKernelTaskInfo_t; + +typedef struct tagKernelTaskInfoEx { + uint32_t flags; + uint32_t argsSize; + const void *args; + uint32_t reserved[6]; +} rtKernelTaskInfoEx_t; + +typedef struct tagEventTaskInfo { + uint32_t eventID; + uint32_t reserved[9]; +} rtEventTaskInfo_t; + +typedef struct tagStreamSwitchTaskInfo { + int64_t value; + uint64_t pValuePtr; + uint32_t trueStreamID; + uint32_t dataType; + uint32_t reserved[4]; +} rtStreamSwitchTaskInfo_t; + +typedef struct tagStreamSwitchNTaskInfo { + uint64_t pValuePtr; + uint64_t pTrueStreamPtr; + uint32_t size; + uint32_t elementSize; + uint32_t dataType; + uint32_t reserved[3]; +} rtStreamSwitchNTaskInfo_t; + +typedef struct tagStreamActiveTaskInfo { + uint32_t activeStreamID; + uint32_t reserved[9]; +} rtStreamActiveTaskInfo_t; + +typedef struct tagSetTaskInfo { + uint16_t labelId; + uint32_t reserved[9]; +} rtLabelSetTaskInfo_t; + +typedef struct tagSwitchTaskInfo { + uint32_t value; + uint32_t reserved[9]; +} rtLabelSwitchTaskInfo_t; + +typedef struct tagLabelGotoTaskInfo { + uint16_t labelId; + uint32_t reserved[9]; +} rtLabelGotoTaskInfo_t; + +typedef struct tagProfilerTraceTaskInfo { + uint64_t profilerTraceId; + uint32_t notify : 8; + uint32_t reserved_ : 24; + uint32_t flags; + uint32_t reserved[6]; +} rtProfilerTrace_t; + +typedef struct tagProfilerTraceExTaskInfo { + uint64_t profilerTraceId; + uint64_t modelId; + uint16_t tagId; + uint8_t reserved[22]; +} rtProfilerTraceEx_t; + +typedef struct tagrtMemcpyAsyncTaskInfo { + const void *dst; + uint64_t destMax; + const void *src; + uint64_t count; + 
uint32_t kind; + uint32_t reserved; +} rtMemcpyAsyncTaskInfo_t; + +typedef struct tagrtCmoAddrTaskInfo { + const void *src; + uint32_t len_inner; + uint16_t num_outer; + uint16_t num_inner; + uint32_t stride_outer; + uint32_t stride_inner; + uint8_t cmoOpCode; + uint8_t reserved[8]; +} rtCmoAddrTaskInfo_t; + +typedef struct tagrtNotifyTaskInfo { + uint32_t notifyID; + uint32_t reserved[9]; +} rtNotifyTaskInfo_t; + +typedef struct tagrtReduceAsyncTaskInfo { + const void *dst; + uint64_t destMax; + const void *src; + uint64_t count; + uint32_t kind; + uint32_t type; +} rtReduceAsyncTaskInfo_t; + +typedef struct tagrtRdmaSendTaskInfo { + uint32_t index; + uint32_t wqe_index; + uint32_t reserved[8]; +} rtRdmaSendTaskInfo_t; + +typedef struct tagrtRdmaDbSendTaskInfo { + uint64_t dbInfo; + uint32_t dbIndex; + uint32_t reserved[7]; // offset 7 +} rtRdmaDbSendTaskInfo_t; + +typedef struct tagrtModelEndGraphTaskInfo { + uint32_t modelId; + uint32_t executorFlag; + uint32_t reserved[8]; +} rtModelEndGraphTaskInfo_t; + +typedef struct tagrtModelExitInfo { + uint32_t modelId; + uint32_t streamId; + uint32_t reserved[8]; +} rtModelExitTaskInfo_t; + + +typedef struct tagrtStreamLabelSwitchByIndexTask_t { + uint64_t indexPtr; + uint64_t labelInfoPtr; + uint32_t max; + uint8_t reserved[20]; +} rtStreamLabelSwitchByIndexTask_t; + +typedef struct tagrtStreamLabelGotoTask_t { + uint16_t labelId; + uint16_t modelId; + uint8_t reserved[36]; +} rtStreamLabelGotoTask_t; + +typedef struct tagrtNpuGetFloatStatusTask_t { + uint64_t outputAddr; + uint64_t outputSize; + uint32_t checkMode; + uint8_t reserved[20]; +} rtNpuGetFloatStatusTask_t; + +typedef struct tagrtNpuClearFloatStatusTask_t { + uint32_t checkMode; + uint8_t reserved[36]; +} rtNpuClearFloatStatusTask_t; + +typedef struct tagrtNpuGetFloatDebugStatusTask_t { + uint64_t outputAddr; + uint64_t outputSize; + uint32_t checkMode; + uint8_t reserved[20]; +} rtNpuGetFloatDebugStatusTask_t; + +typedef struct 
tagrtNpuClearFloatDebugStatusTask_t { + uint32_t checkMode; + uint8_t reserved[36]; +} rtNpuClearFloatDebugStatusTask_t; + +typedef struct tagTaskInfo { + uint32_t type; + uint32_t streamID; + union { + rtKernelTaskInfoEx_t kernelTaskEx; + rtKernelTaskInfo_t kernelTask; + rtAllKernelTaskInfo_t allKernelTask; + rtEventTaskInfo_t eventTask; + rtStreamSwitchTaskInfo_t streamSwitchTask; + rtStreamActiveTaskInfo_t streamActiveTask; + rtLabelSetTaskInfo_t labelSetTask; + rtLabelSwitchTaskInfo_t labelSwitchTask; + rtLabelGotoTaskInfo_t labelGotoTask; + rtProfilerTrace_t profilertraceTask; + rtProfilerTraceEx_t profilertraceExTask; + rtMemcpyAsyncTaskInfo_t memcpyAsyncTask; + rtNotifyTaskInfo_t notifyTask; + rtReduceAsyncTaskInfo_t reduceAsyncTask; + rtRdmaSendTaskInfo_t rdmaSendTask; + rtRdmaDbSendTaskInfo_t rdmaDbSendTask; + rtModelEndGraphTaskInfo_t modelEndGraphTask; + rtModelExitTaskInfo_t modelExitTask; + rtStreamSwitchNTaskInfo_t streamSwitchNTask; + rtStreamLabelSwitchByIndexTask_t streamLabelSwitchIndexTask; + rtStreamLabelGotoTask_t streamLabelGotoTask; + rtNpuGetFloatStatusTask_t npuGetFloatStatusTask; + rtNpuClearFloatStatusTask_t npuClearFloatStatusTask; + rtNpuGetFloatDebugStatusTask_t npuGetFloatDebugStatusTask; + rtNpuClearFloatDebugStatusTask_t npuClearFloatDebugStatusTask; + rtCmoAddrTaskInfo_t cmoAddrTask; + uint32_t reserved[10]; + } u; +} rtTaskInfo_t; + +typedef struct tagNodeInfo_t { + uint32_t nodeIdx; + uint32_t reserved[1]; +} rtNodeInfo; + +typedef struct tagHwtsInfo_t { + uint16_t taskId; + uint16_t sqExeHead; + uint16_t streamExeHead; + uint16_t reserved[2]; +} rtHwtsInfo; + +typedef struct tagLabelDevInfo_t { + uint16_t modelId; + uint16_t streamId; + uint16_t labelId; + union { + rtNodeInfo nodeInfo; + rtHwtsInfo hwtsInfo; + uint16_t reserved[5]; + }u; +}rtLabelDevInfo; + +typedef struct tagMdlLoad { + uint8_t overflow_en; + uint16_t totalTaskNum; + void *taskDescBaseAddr; + void *pcBaseAddr; + void *paramBaseAddr; + void *weightBaseAddr; + 
uint8_t weightPrefetch; +} rtMdlLoad_t; + +typedef struct tagMdlExecute { + void *ioaSrcAddr; + void *dynamicTaskPtr; + void *workPtr; + bool sync; + uint16_t vld; + uint16_t taskProf; + uint8_t mid; + uint32_t ioaSize; + uint32_t sqid; + uint8_t meType; + uintptr_t cbFn; + void *cbData; + size_t mpamId; + size_t aicQos; + size_t aicOst; + size_t mecTimeThreshHold; +} rtMdlExecute_t; + +typedef struct tagMdlTaskUpdateInfo { + uint64_t *tilingKeyAddr; + uint64_t *blockDimAddr; + void *hdl; + rtFftsPlusTaskInfo_t *fftsPlusTaskInfo; +} rtMdlTaskUpdateInfo_t; + +typedef rtError_t (*rtTaskGenCallback)(rtModel_t mdl, rtTaskInfo_t *taskInfo); + +/** + * @ingroup rt_model + * @brief nano model load + * @param [out] phyModelId drv create model id + * @param [in] modelLoad model load param + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtNanoModelLoad(rtMdlLoad_t *modelLoad, uint32_t *phyModelId); + +/** + * @ingroup rt_model + * @brief nano model execute + * @param [in] modelExec model execute param + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtNanoModelExecute(rtMdlExecute_t *modelExec); + +/** + * @ingroup rtMsgSend + * @brief nano msg send + * @param [in] tId rcv thread id + * @param [in] sendTid send thread id + * @param [in] timeout time out + * @param [in] sendInfo tlv info + * @param [in] size tlv size + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMsgSend(uint32_t tId, uint32_t sendTid, int32_t timeout, void *sendInfo, uint32_t size); + +/** + * @ingroup rtSetTaskDescDumpFlag + * @brief nano set taskdesc dump flag + * @param [in] taskDescBaseAddr TaskDesc Base Addr + * @param [in] taskDescSize Static TaskDesc Partition size + * @param [in] taskId task id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t 
rtSetTaskDescDumpFlag(void *taskDescBaseAddr, size_t taskDescSize, uint32_t taskId); + +/** + * @ingroup rt_dump_Init + * @brief nano dump init + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtDumpInit(void); + +/** + * @ingroup rt_dump_deInit + * @brief nano dump deinit + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtDumpDeInit(void); + +/** + * @ingroup rt_model + * @brief nano destroy model instance + * @param [in] phyMdlId model to destroy + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtNanoModelDestroy(uint32_t phyMdlId); + +/** + * @ingroup rt_model + * @brief set callback for generate model + * @param [in] callback callback function + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetTaskGenCallback(rtTaskGenCallback callback); + +/** + * @ingroup rt_model + * @brief create model instance + * @param [out] mdl created model + * @param [in] flag reserved + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelCreate(rtModel_t *mdl, uint32_t flag); + +/** + * @ingroup rt_model + * @brief set ge model id to aicpu + * @param [in] mdl aicpu model + * @param [in] extId ge model id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +rtError_t rtModelSetExtId(rtModel_t mdl, uint32_t extId); + +/** + * @ingroup rt_model + * @brief destroy model instance + * @param [in] mdl model to destroy + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelDestroy(rtModel_t mdl); + +/** + * @ingroup rt_model + * @brief bind model and stream instance + * @param [in] mdl bound model + * @param [in] stm bound stream + * @param [in] flag reserved + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelBindStream(rtModel_t mdl, 
rtStream_t stm, uint32_t flag); + +/** + * @ingroup rt_model + * @brief unbind model and stream instance + * @param [in] mdl unbound model + * @param [in] stm unbound stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelUnbindStream(rtModel_t mdl, rtStream_t stm); + +/** + * @ingroup rt_model + * @brief tell runtime Model has been Loaded + * @param [in] mdl model to execute + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtModelLoadComplete(rtModel_t mdl); + +/** + * @ingroup rt_model + * @brief execute model instance + * @param [in] mdl model to execute + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelExecute(rtModel_t mdl, rtStream_t stm, uint32_t flag); + +/** + * @ingroup rt_model + * @brief execute model instance + * @param [in] mdl model to execute, timeout for sync + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelExecuteSync(rtModel_t mdl, rtStream_t stm, uint32_t flag, int32_t timeout); + +/** + * @ingroup rt_model + * @brief get model the last persist task id + * @param [in] mdl model to execute + * @param [out] taskId last task id of the model + * @param [out] streamId last stream id of the model + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelGetTaskId(rtModel_t mdl, uint32_t *taskId, uint32_t *streamId); + +/** + * @ingroup rt_model + * @brief add an end graph task to stream + * @param [in] mdl model to execute + * @param [in] stm end graph stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtEndGraph(rtModel_t mdl, rtStream_t stm); + +/** + * @ingroup rt_model + * @brief add an end graph task with flag to stream + * @param [in] mdl model to execute + * @param [in] stm end graph stream + * @param [in] flags AICPU datadump + * 
@return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtEndGraphEx(rtModel_t mdl, rtStream_t stm, uint32_t flags); + +/** + * @ingroup rt_model + * @brief set model executor + * @param [in] mdl model to execute + * @param [in] flags EXECUTOR_TS | EXECUTOR_AICPU + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelExecutorSet(rtModel_t mdl, uint8_t flags); + +/** + * @ingroup rt_model + * @brief abort model + * @param [in] mdl model to abort + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelAbort(rtModel_t mdl); + +/** + * @ingroup rt_model + * @brief end graph task to model default stream + * @param [in] mdl model to execute + * @param [in] stm end graph stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelExit(rtModel_t mdl, rtStream_t stm); + +/** + * @ingroup rt_model + * @brief bind queue + * @param [in] mdl model to bind + * @param [in] queueId queueId to bind + * @param [in] flag + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelBindQueue(rtModel_t mdl, uint32_t queueId, rtModelQueueFlag_t flag); + +/** + * @ingroup rt_model + * @brief get model id + * @param [in] mdl + * @param [out] modelId model id + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelGetId(rtModel_t mdl, uint32_t *modelId); + +/** + * @ingroup rt_model + * @brief enable debug for dump overflow exception + * @param [in] addr: ddr address of kernel exception dumped + * @param [in] mdl: model handle + * @param [in] flag: debug flag + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDebugRegister(rtModel_t mdl, uint32_t flag, const void *addr, + uint32_t 
*streamId, uint32_t *taskId); + +/** + * @ingroup rt_model + * @brief disable debug for dump overflow exception + * @param [in] mdl: model handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDebugUnRegister(rtModel_t mdl); + +/** + * @ingroup rt_model + * @brief set model group id + * @param [in] mdl model + * @param [in] schGrpId group id in [0, 4]: 0 is the default (invalid) value, 1-4 are valid; at most 4 groups are supported + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelSetSchGroupId(rtModel_t mdl, const int16_t schGrpId); + +/** + * @ingroup rt_model + * @brief add stream sq lock task + * @param [in] stm stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetStreamSqLock(rtStream_t stm); + +/** + * @ingroup rt_model + * @brief add stream sq unlock task + * @param [in] stm stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetStreamSqUnlock(rtStream_t stm); + +/** + * @ingroup rt_model + * @brief support reuse or not + * @param [out] bSupport + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error + */ +RTS_API rtError_t rtSupportModelStreamReuse(bool *bSupport); + +/** + * @ingroup rt_model + * @brief update model task info + * @param [in] desStm + * @param [in] desTaskId + * @param [in] sinkStm + * @param [in] para + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtModelTaskUpdate(rtStream_t desStm, uint32_t desTaskId, rtStream_t sinkStm, + rtMdlTaskUpdateInfo_t *para); + +/** + * @ingroup rt_model + * @brief no operation task + * @param [in] stm + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtNopTask(rtStream_t stm); + +/** + * @ingroup rt_model + * @brief set model name + * @param [in] 
mdl + * @param [in] mdlName + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetModelName(rtModel_t mdl, const char_t *mdlName); + +#if defined(__cplusplus) +} +#endif + +#endif // CCE_RUNTIME_RT_MODEL_H diff --git a/inc/runtime/runtime/rt_preload_task.h b/inc/runtime/runtime/rt_preload_task.h new file mode 100644 index 0000000000000000000000000000000000000000..ee9bf22bafdce5d3d65baa66cd5d12479d9ddc45 --- /dev/null +++ b/inc/runtime/runtime/rt_preload_task.h @@ -0,0 +1,183 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * Description: parma rt_preload_task.h + * Create: 2023-06-02 + */ +#ifndef CCE_RUNTIME_RT_PRELOAD_TASK_H +#define CCE_RUNTIME_RT_PRELOAD_TASK_H + +#include "runtime/base.h" + +#if defined(__cplusplus) +extern "C" { +#endif + +typedef int32_t (*rtKernelLaunchFillFunc)(void* cfgAddr, uint32_t cfglen); + +typedef enum tagRtTaskBuffType { + HWTS_STATIC_TASK_DESC = 0, /**< static task */ + HWTS_DYNAMIC_TASK_DESC = 1, /**< dynamic task */ + PARAM_TASK_INFO_DESC = 2, /**< parma task */ + MAX_TASK, +} rtTaskBuffType_t; + +typedef enum tagRtTaskType { + RT_TASK_TYPE_KERNEL_AICORE = 0, + RT_TASK_TYPE_KERNEL_AICPU = 1, + RT_TASK_TYPE_KERNEL_NANO_AICORE = 2, + RT_TASK_TYPE_KERNEL_NANO_AICPU_HOSTFUNC = 3, + RT_TASK_TYPE_MAX, +} rtTaskType_t; + +typedef struct { + uint64_t kernelBinOffset; + uint64_t argsOffset; // need add rtTaskInput_t.argOffset + uint32_t literalBuffLen; + uint16_t blockDim; + uint8_t kernelFlag; +} rtAicoreTaskParam_t; + +typedef struct { + uint16_t type; + uint16_t preP; + uint16_t posP; + uint16_t dump; + uint16_t conds; + uint16_t uf; + uint16_t sw; + uint16_t prefetchNum; + uint16_t softUser; + uint16_t kernelCredit; + uint32_t taskParamOffset; // need add rtTaskInput_t.argOffset + uint8_t swapOut; +} rtHwtsStaticTaskDesc_t; + +typedef struct { + uint8_t vld; + uint32_t codeSize; + uint32_t dynTaskDescSize; + uint32_t blockDim; + 
uint32_t taskPcOffset; +} rtHwtsDynamicTaskDesc_t; + +typedef struct { + uint32_t opType; + uint32_t dataSize; + uint32_t dstOffset; + uint32_t srcOffset; +} rtPrefetchBufInfo_t; + +// The size of the Param Buffer is variable and depends on the number of inputs and +// outputs of the operator in actual applications. +// It is estimated that the size of the Param Buffer does not exceed 128. +#define PARMA_PARAM_BUFFER_MAX_N 128U +typedef struct { + rtPrefetchBufInfo_t prefetchBufInfo[PARMA_PARAM_BUFFER_MAX_N]; + uint16_t prefetchBufSize; + uint64_t paramBufInfo[PARMA_PARAM_BUFFER_MAX_N]; + uint16_t paramBufSize; + uint32_t bufSize; + void* bufInfo; +} rtParamBufDesc_t; + +typedef struct { + rtTaskBuffType_t type; + union { + rtHwtsStaticTaskDesc_t hwtsTaskDesc; + rtHwtsDynamicTaskDesc_t hwtsDynamicTaskDesc; + rtParamBufDesc_t paramBufDesc; + }u; +} rtNanoDefaultTaskParam_t; + +typedef struct { + rtTaskType_t taskType; + rtTaskBuffType_t bufType; + uint16_t streamId; + union { + rtAicoreTaskParam_t aicoreTask; + rtNanoDefaultTaskParam_t nanoAicoreTask; + rtNanoDefaultTaskParam_t nanoHostFuncTask; + }u; +} rtCompilerPartinfo_t; + +typedef struct { + void* dataBuffer; // current write addr + uint32_t bufferLen; // the space of dataBuffer left + rtCompilerPartinfo_t compilerInfo; // task info for compile + uint64_t argOffset; // args offset +} rtTaskInput_t; + +/** + * @ingroup rt_preload_task + * @brief get task buffer length + * @param [in] type task buffer type + * @param [out] bufferLen the space of dataBuffer left + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetTaskBufferLen(const rtTaskBuffType_t type, uint32_t * const bufferLen); + +/** + * @ingroup rt_preload_task + * @brief exeom task build + * @param [in] taskInput task info for compile + * @param [out] taskLen current tasklen + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtTaskBuild(const 
rtTaskInput_t * const taskInput, uint32_t* taskLen); + +/** + * @ingroup rt_preload_task + * @brief get elf header offset + * @param [in] elfData kernel bin addr + * @param [in] elfLen + * @param [out] offset elf header offset + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetElfOffset(void * const elfData, const uint32_t elfLen, uint32_t* offset); + +/** + * @ingroup rt_preload_task + * @brief register kernel launch fill callback function + * @param [in] symbol symbol_name + * @param [in] callback kernel launch fill function + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtRegKernelLaunchFillFunc(const char* symbol, rtKernelLaunchFillFunc callback); + +/** + * @ingroup rt_preload_task + * @brief unregister kernel launch fill callback function + * @param [in] symbol symbol_name + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtUnRegKernelLaunchFillFunc(const char* symbol); + +/** + * @ingroup rt_preload_task + * @brief get l2cache offset + * @param [in] deviceId device id + * @param [out] offset l2cache offset + * @return 0 for success, others for fail + */ +RTS_API rtError_t rtGetL2CacheOffset(uint32_t deviceId, uint64_t *offset); + +/** + * @ingroup rt_preload_task + * @brief not used now, return RT_ERROR_NONE + * @param [in] isHuge + * @param [out] bufferLen + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetStreamBufferLen(const bool isHuge, uint32_t * const bufferLen); + +#if defined(__cplusplus) +} +#endif + +#endif // CCE_RUNTIME_RT_PRELOAD_TASK_H diff --git a/inc/runtime/runtime/rt_stars.h b/inc/runtime/runtime/rt_stars.h new file mode 100644 index 0000000000000000000000000000000000000000..bc83dd32200554107fcdb8884523ad6df3267263 --- /dev/null +++ b/inc/runtime/runtime/rt_stars.h @@ -0,0 +1,281 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2021. 
All rights reserved. + * Description: the definition of stars + */ + +#ifndef CCE_RUNTIME_RT_STARS_H +#define CCE_RUNTIME_RT_STARS_H + +#include "base.h" +#include "rt_stars_define.h" +#if defined(__cplusplus) +extern "C" { +#endif + +/** + * @ingroup rt_stars + * @brief launch stars task. + * used to send a stars sqe directly. + * @param [in] taskSqe stars task sqe + * @param [in] sqeLen stars task sqe length + * @param [in] stm associated stream + * @return RT_ERROR_NONE for ok, others failed + */ +RTS_API rtError_t rtStarsTaskLaunch(const void *taskSqe, uint32_t sqeLen, rtStream_t stm); + +/** + * @ingroup rt_stars + * @brief launch stars task. + * used to send a stars sqe directly. + * @param [in] taskSqe stars task sqe + * @param [in] sqeLen stars task sqe length + * @param [in] stm associated stream + * @param [in] flag dump flag + * @return RT_ERROR_NONE for ok, others failed + */ +RTS_API rtError_t rtStarsTaskLaunchWithFlag(const void *taskSqe, uint32_t sqeLen, rtStream_t stm, uint32_t flag); + +/** + * @ingroup rt_stars + * @brief create cdq instance. + * @param [in] batchNum batch number + * @param [in] batchSize batch size + * @param [in] queName cdq name + * @return RT_ERROR_NONE for ok, ACL_ERROR_RT_NO_CDQ_RESOURCE for no cdq resources + */ +RTS_API rtError_t rtCdqCreate(uint32_t batchNum, uint32_t batchSize, const char_t *queName); + +/** + * @ingroup rt_stars + * @brief destroy cdq instance. + * @param [in] queName cdq name + * @return RT_ERROR_NONE for ok, others failed + */ +RTS_API rtError_t rtCdqDestroy(const char_t *queName); + +/** + * @ingroup rt_stars + * @brief get free batch in the queue. + * @param [in] queName cdq name + * @param [in] timeout wait timeout + * @param [out] batchId batch index + * @return RT_ERROR_NONE for ok, ACL_ERROR_RT_WAIT_TIMEOUT for timeout + */ +RTS_API rtError_t rtCdqAllocBatch(const char_t *queName, int32_t timeout, uint32_t *batchId); + +/** + * @ingroup rt_stars + * @brief launch a write_cdqm task on the stream. 
+ * When the task is executed, the data information will be inserted into the cdqe index position of the queue. + * @param [in] queName cdq name + * @param [in] cdqeIndex cdqe index + * @param [in] data cdqe information + * @param [in] dataSize data size + * @param [in] stm launch task on the stream + * @return RT_ERROR_NONE for ok, others failed + */ +RTS_API rtError_t rtCdqEnQueue(const char_t *queName, uint32_t cdqeIndex, void *data, uint32_t dataSize, + rtStream_t stm); + +/** + * @ingroup rt_stars + * @brief launch a write_cdqm task on the stream. + * When the task is executed, the data information will be inserted into the cdqe index position of the queue. + * @param [in] queName cdq name + * @param [in] cdqeIndex cdqe index + * @param [in] ptrAddr cdqe information address (NOTE(review): meaning inferred from the parameter name; confirm) + * @note this pointer-mode variant takes no dataSize argument + * @param [in] stm launch task on the stream + * @return RT_ERROR_NONE for ok, others failed + */ +RTS_API rtError_t rtCdqEnQueuePtrMode(const char_t *queName, uint32_t cdqeIndex, const void *ptrAddr, + rtStream_t stm); + +/** + * @ingroup rt_stars + * @brief launch common cmo task on the stream. + * @param [in] taskInfo cmo task info + * @param [in] stm launch task on the stream + * @param [in] flag flag + * @return RT_ERROR_NONE for ok, others failed + */ +RTS_API rtError_t rtCmoTaskLaunch(rtCmoTaskInfo_t *taskInfo, rtStream_t stm, uint32_t flag); + +/** + * @ingroup dvrt_mem + * @brief launch common cmo task on the stream. + * @param [in] cmoAddrInfo cmo task info + * @param [in] destMax destMax + * @param [in] cmoOpCode opcode + * @param [in] stm launch task on the stream + * @param [in] flag flag + * @return RT_ERROR_NONE for ok, others failed + */ +RTS_API rtError_t rtCmoAddrTaskLaunch(void *cmoAddrInfo, uint64_t destMax, rtCmoOpCode_t cmoOpCode, + rtStream_t stm, uint32_t flag); + +/** + * @ingroup rt_stars + * @brief launch common cmo task on the stream. 
+ * @param [in] srcAddrPtr prefetch addrs + * @param [in] srcLen prefetch addrs load + * @param [in] cmoType opcode + * @param [in] stm stream + * @return RT_ERROR_NONE for ok, others failed + */ +RTS_API rtError_t rtCmoAsync(void *srcAddrPtr, size_t srcLen, rtCmoOpCode_t cmoType, rtStream_t stm); + +/** + * @ingroup rt_stars + * @brief launch barrier cmo task on the stream. + * @param [in] taskInfo barrier task info + * @param [in] stm launch task on the stream + * @param [in] flag flag + * @return RT_ERROR_NONE for ok, others failed + */ +RTS_API rtError_t rtBarrierTaskLaunch(rtBarrierTaskInfo_t *taskInfo, rtStream_t stm, uint32_t flag); + +/** + * @ingroup rt_stars + * @brief dvpp group handle. + */ +typedef void *rtDvppGrp_t; + +typedef struct tagDvppGrpRptInfo { + uint32_t deviceId; + uint32_t streamId; + uint32_t taskId; + uint8_t sqeType; + uint8_t cqeErrorCode; + uint8_t reserve[2]; + uint32_t accErrorCode; +} rtDvppGrpRptInfo_t; + +typedef void (*rtDvppGrpCallback)(rtDvppGrpRptInfo_t *rptInfo); + +/** + * @ingroup rt_stars + * @brief create dvpp group. + * @param [in] flags group flag, reserved parameter + * @param [out] grp group handle + * @return RT_ERROR_NONE for ok, others failed + */ +RTS_API rtError_t rtDvppGroupCreate(rtDvppGrp_t *grp, uint32_t flags); + +/** + * @ingroup rt_stars + * @brief destroy dvpp group. 
+ * @param [in] grp group handle + * @return RT_ERROR_NONE for ok, others failed + */ +RTS_API rtError_t rtDvppGroupDestory(rtDvppGrp_t grp); + +/** + * @ingroup rt_stars + * @brief create stream with grp handle + * @param [in|out] stm created stream + * @param [in] priority stream priority + * @param [in] flags stream op flags + * @param [in] grp grp handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + * @return RT_ERROR_NONE for ok, others failed + */ +RTS_API rtError_t rtStreamCreateByGrp(rtStream_t *stm, int32_t priority, uint32_t flags, rtDvppGrp_t grp); + +/** + * @ingroup rt_stars + * @brief wait report by grp + * @param [in] grp group handle + * @param [in] callBackFunc callback + * @param [in] timeout wait timeout config, ms, -1: wait forever + * @return RT_ERROR_NONE for ok, others failed + */ +RTS_API rtError_t rtDvppWaitGroupReport(rtDvppGrp_t grp, rtDvppGrpCallback callBackFunc, int32_t timeout); + +/* + * @ingroup dvrt_stream + * @brief set stream geOpTag + * @param [in] stm: stream handle + * @param [in] geOpTag + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetStreamTag(rtStream_t stm, uint32_t geOpTag); + +/* + * @ingroup rt_stars + * @brief build multiple task + * @param [in] taskInfo(rtMultipleTaskInfo_t) + * @param [in] stm: stream handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMultipleTaskInfoLaunch(const void *taskInfo, rtStream_t stm); + + /* + * @ingroup rt_stars + * @brief build multiple task + * @param [in] taskInfo(rtMultipleTaskInfo_t) + * @param [in] stm: stream handle + * @param [in] flag + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtMultipleTaskInfoLaunchWithFlag(const void *taskInfo, rtStream_t stm, const uint32_t flag); + +// general ctrl type +typedef enum tagGeneralCtrlType { + 
RT_GNL_CTRL_TYPE_MEMCPY_ASYNC_CFG = 0, + RT_GNL_CTRL_TYPE_REDUCE_ASYNC_CFG = 1, + RT_GNL_CTRL_TYPE_FFTS_PLUS_FLAG = 2, + RT_GNL_CTRL_TYPE_FFTS_PLUS = 3, + RT_GNL_CTRL_TYPE_NPU_GET_FLOAT_STATUS = 4, + RT_GNL_CTRL_TYPE_NPU_CLEAR_FLOAT_STATUS = 5, + RT_GNL_CTRL_TYPE_STARS_TSK = 6, + RT_GNL_CTRL_TYPE_CDQ_EN_QU = 7, + RT_GNL_CTRL_TYPE_CDQ_EN_QU_PTR = 8, + RT_GNL_CTRL_TYPE_CMO_TSK = 9, + RT_GNL_CTRL_TYPE_BARRIER_TSK = 10, + RT_GNL_CTRL_TYPE_STARS_TSK_FLAG = 11, + RT_GNL_CTRL_TYPE_SET_STREAM_TAG = 12, + RT_GNL_CTRL_TYPE_MULTIPLE_TSK = 13, + RT_GNL_CTRL_TYPE_NPU_GET_FLOAT_DEBUG_STATUS = 14, + RT_GNL_CTRL_TYPE_NPU_CLEAR_FLOAT_DEBUG_STATUS = 15, + RT_GNL_CTRL_TYPE_MULTIPLE_TSK_FLAG = 16, // invoke rtMultipleTaskInfoLaunchWithFlag + RT_GNL_CTRL_TYPE_MAX +} rtGeneralCtrlType_t; + +/** + * @ingroup rt_stars + * @brief general ctrl interface + * @param [in] ctrl ctl input + * @param [in] num ctl input num + * @param [in] type ctl type + * @return RT_ERROR_NONE for ok, others failed + */ +RTS_API rtError_t rtGeneralCtrl(uintptr_t *ctrl, uint32_t num, uint32_t type); + +/** + * @ingroup rt_stars + * @brief 5612(tiny) need translate addr + * @param [out] need whether address translation is needed + * @return RT_ERROR_NONE for ok, others failed + */ +RTS_API rtError_t rtNeedDevVA2PA(bool *need); + +/** + * @ingroup rt_stars + * @brief translate addr from va to pa + * @param [in] devAddr translate addr + * @param [in] len addr len + * @param [in] stm stream + * @param [in] isAsync async or sync + * @return RT_ERROR_NONE for ok, others failed + */ +RTS_API rtError_t rtDevVA2PA(uint64_t devAddr, uint64_t len, rtStream_t stm, bool isAsync); + +#if defined(__cplusplus) +} +#endif +#endif // CCE_RUNTIME_RT_STARS_H diff --git a/inc/runtime/runtime/rt_stars_define.h b/inc/runtime/runtime/rt_stars_define.h new file mode 100644 index 0000000000000000000000000000000000000000..cb54e5604afc210a5132be16846b1631bcc064f9 --- /dev/null +++ b/inc/runtime/runtime/rt_stars_define.h @@ -0,0 +1,246 @@ +/* + * Copyright (c) Huawei 
Technologies Co., Ltd. 2021. All rights reserved. + * Description: the definition of stars + */ + +#ifndef CCE_RUNTIME_RT_STARS_DEFINE_H +#define CCE_RUNTIME_RT_STARS_DEFINE_H + +#include "base.h" + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +extern "C" { +#endif + +#pragma pack(push) +#pragma pack (1) + +typedef struct tagStarsSqeHeader { + uint8_t type : 6; + uint8_t l1Lock : 1; + uint8_t l1Unlock : 1; + + uint8_t ie : 2; + uint8_t preP : 2; + uint8_t postP : 2; + uint8_t wrCqe : 1; + uint8_t reserved : 1; + + uint16_t blockDim; + + uint16_t rtStreamId; + uint16_t taskId; +} rtStarsSqeHeader_t; + +typedef struct tagDavidStarsSqeHeader { + /* word0 */ + uint8_t type : 6; + uint8_t lock : 1; + uint8_t unlock : 1; + uint8_t ie : 1; + uint8_t preP : 1; + uint8_t postP : 1; + uint8_t wrCqe : 1; + uint8_t ptrMode : 1; + uint8_t rttMode : 1; + uint8_t headUpdate : 1; + uint8_t res0 : 1; + uint16_t blockDim; + + /* word1 */ + uint16_t rtStreamId; + uint16_t taskId; +} rtDavidStarsSqeHeader_t; + +typedef struct tagRtStarsCommonSqe { + rtStarsSqeHeader_t sqeHeader; // word 0-1 + uint32_t commandCustom[14]; // word 2-15 is custom define by command. +} rtStarsCommonSqe_t; + +typedef struct tagRtDavidStarsCommonSqe { + rtDavidStarsSqeHeader_t sqeHeader; // word 0-1 + uint32_t commandCustom[14]; // word 2-15 is custom define by command. 
+} rtDavidStarsCommonSqe_t; + +typedef struct tagStarsDsaSqe { + // 0-7 bytes + rtStarsSqeHeader_t sqeHeader; + // 8-11 bytes + uint32_t start : 1; + uint32_t functionType : 3; + uint32_t dataType : 3; + uint32_t algoType : 3; + uint32_t paramVldBitmap : 5; + uint32_t paramAddrValBitmap : 7; + uint32_t reserved0 : 10; + // 12-15 bytes + uint16_t sqeIndex; + uint8_t kernelCredit; + uint8_t reserved1; + // 16-31 bytes + uint32_t dsaCfgResultAddrLow; + uint32_t dsaCfgResultAddrHigh; + uint32_t dsaCfgStateAddrLow; + uint32_t dsaCfgStateAddrHigh; + // 32-47 bytes + uint32_t dsaCfgParamAddrLow; + uint32_t dsaCfgParamAddrHigh; + uint32_t dsaCfgSeedLow; + uint32_t dsaCfgSeedHigh; + // 48-63 bytes + uint32_t dsaCfgNumberLow; + uint32_t dsaCfgNumberHigh; + uint32_t reserved2[2]; +} rtStarsDsaSqe_t; + +// ffts+ type +typedef enum tagFftsPlusType { + RT_FFTS_PLUS_TYPE_RES1 = 2, // Reserved + RT_FFTS_PLUS_TYPE_RES2 = 3, // Reserved + RT_FFTS_PLUS_TYPE = 4, // FFTS+ mode +} rtFftsPlusType_t; + +typedef struct tagStarsFftsPlusHeader { + uint8_t type : 6; + uint8_t l1Lock : 1; + uint8_t l1Unlock : 1; + + uint8_t ie : 2; + uint8_t preP : 2; + uint8_t postP : 2; + uint8_t wrCqe : 1; + /* tell mcu if this subgraph is overflow-enabled and mcu will send this flag to aicpu when aicpu ctx is excuted */ + uint8_t overflowEn : 1; + + uint16_t blockDim; + + uint16_t rtStreamId; + uint16_t taskId; +} rtStarsFftsPlusHeader_t; +// ffts+ sqe +typedef struct tagFftsPlusSqe { + // 0-7 bytes + rtStarsSqeHeader_t sqeHeader; // use rtStarsFftsPlusHeader_t instead + // 8-11 bytes + uint16_t fftsType : 3; + uint16_t cmo : 1; + uint16_t scheduleDfxFlag : 1; + uint16_t reserved1 : 7; + uint16_t wrrRatio : 4; + uint16_t dsaSqId : 11; + uint16_t reserved2 : 5; + // 12-15 bytes + uint16_t sqeIndex; + uint8_t kernelCredit; + uint8_t subType; + // 16-23 bytes + uint32_t stackPhyBaseL; + uint32_t stackPhyBaseH; + // 24-31 bytes + uint16_t totalContextNum; + uint16_t readyContextNum; + uint16_t 
preloadContextNum; + uint16_t timeout; + // 32-35 bytes + uint16_t reserved6; + uint16_t prefetchOstNum : 5; + uint16_t reserved9 : 3; + uint16_t cmaintOstNum : 5; + uint16_t reserved10 : 3; + // 36-39 bytes + uint16_t aicPrefetchLower : 5; + uint16_t reserved11 : 3; + uint16_t aicPrefetchUpper : 5; + uint16_t reserved12 : 3; + uint16_t aivPrefetchLower : 5; + uint16_t reserved13 : 3; + uint16_t aivPrefetchUpper : 5; + uint16_t reserved14 : 3; + // 40-47 bytes + uint32_t contextAddressBaseL; + uint32_t contextAddressBaseH : 17; + uint32_t reserved15 : 15; + // 48-63 bytes:48-51 use for pid + // use reserved16[1] bit 4 for l2cache + uint32_t reserved16[4]; +} rtFftsPlusSqe_t; + +typedef enum { + RECORD_STORE_MODE = 0x0U, + RECORD_ADD_MODE = 0x1U, + RECORD_WRITE_BIT_MODE = 0x2U, + RECORD_INVALID_MODE = 0x3U, // invalid, cannot set + RECORD_CLEAR_BIT_MODE = 0x4U, + RECORD_MODE_BUT +} rtCntNotifyRecordMode_t; + +typedef enum { + WAIT_LESS_MODE = 0x0U, + WAIT_EQUAL_MODE = 0x1U, + WAIT_BIGGER_MODE = 0x2U, + WAIT_BIGGER_OR_EQUAL_MODE = 0x3U, + WAIT_BITMAP_MODE = 0x4U, + WAIT_MODE_BUT +} rtCntNotifyWaitMode_t; + +typedef enum { + NOTIFY_TABLE_SLICE = 0U, + NOTIFY_CNT_ST_SLICE = 1U, + NOTIFY_CNT_ADD_SLICE = 2U, + NOTIFY_CNT_BIT_WR_SLICE = 3U, + NOTIFY_CNT_BIT_CLR_SLICE = 4U, + NOTIFY_TYPE_BUFF +} rtNotifyType_t; + +typedef struct tagCmoTaskInfo { + uint8_t qos; + uint8_t partId; + uint8_t pmg; + uint8_t reserved; + uint16_t cmoType; + uint16_t opCode; // 6: Preload; 7: Prewriteback; 8: invalid; 9: flush; + uint16_t numInner; + uint16_t numOuter; + uint32_t logicId; + uint32_t lengthInner; + uint64_t sourceAddr; + uint32_t striderOuter; + uint32_t striderInner; +} rtCmoTaskInfo_t; + +typedef struct tagBarrierCmoInfo { + uint16_t cmoType; // 0 is barrier, 1 is invalid, Prefetch is 2, Write_back is 3, FE/GE only use invalid type. 
+ uint32_t logicId; +} rtBarrierCmoInfo_t; + +#define RT_CMO_MAX_BARRIER_NUM 6U // 6U is max support +typedef struct tagBarrierTaskInfo { + uint8_t logicIdNum; + rtBarrierCmoInfo_t cmoInfo[RT_CMO_MAX_BARRIER_NUM]; +} rtBarrierTaskInfo_t; + +typedef enum tagRtCmoType { + RT_CMO_PREFETCH = 6, // Preload + RT_CMO_WRITEBACK, // Prewriteback + RT_CMO_INVALID, // invalid + RT_CMO_FLUSH, // flush + RT_CMO_RESERVED, +} rtCmoOpCode_t; + +typedef struct { + uint32_t resv0; + uint32_t resv1; + uint16_t num_outer; + uint16_t num_inner; + uint32_t len_inner; + uint64_t src; + uint32_t stride_outer; + uint32_t stride_inner; +} rtCmoAddrInfo; + +#pragma pack(pop) + +#if defined(__cplusplus) && !defined(COMPILE_OMG_PACKAGE) +} +#endif +#endif // CCE_RUNTIME_RT_STARS_DEFINE_H diff --git a/inc/runtime/runtime/stream.h b/inc/runtime/runtime/stream.h new file mode 100644 index 0000000000000000000000000000000000000000..45b838421cab6d5f629464b34126d87fdf7b681c --- /dev/null +++ b/inc/runtime/runtime/stream.h @@ -0,0 +1,389 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2020-2021. All rights reserved. 
+ * Description: stream.h + * Create: 2020-01-01 + */ + +#ifndef CCE_RUNTIME_STREAM_H +#define CCE_RUNTIME_STREAM_H + +#include + +#include "base.h" +#include "event.h" + +#if defined(__cplusplus) +extern "C" { +#endif + +/** + * @ingroup stream_flags + * @brief stream op bit flags + */ +#define RT_STREAM_DEFAULT (0x00U) +#define RT_STREAM_PERSISTENT (0x01U) +#define RT_STREAM_FORCE_COPY (0x02U) +#define RT_STREAM_HUGE (0x04U) +#define RT_STREAM_AICPU (0x08U) +#define RT_STREAM_FORBIDDEN_DEFAULT (0x10U) +#define RT_STREAM_HEAD (0x20U) +#define RT_STREAM_PRIMARY_DEFAULT (0x40U) +#define RT_STREAM_PRIMARY_FIRST_DEFAULT (0x80U) +#define RT_STREAM_OVERFLOW (0x100U) +#define RT_STREAM_FAST_LAUNCH (0x200U) +#define RT_STREAM_FAST_SYNC (0x400U) +#define RT_STREAM_CP_PROCESS_USE (0x800U) // RT_STREAM_CP_PROCESS_USE does not support OR with other flags +#define RT_STREAM_VECTOR_CORE_USE (0x1000U) + +/** + * @ingroup stream_config + * @brief stream config params + */ +typedef struct TagStreamConfigHandle { + void *workPtr; + size_t workSize; + size_t flag; + uint32_t priority; +} rtStreamConfigHandle; + +typedef enum tagRtClearStep { + RT_STREAM_STOP = 0, + RT_STREAM_CLEAR, +} rtClearStep_t; +/** + * @ingroup stream_type + * @brief stream type + */ +#define RT_NORMAL_STREAM (0x00U) +#define RT_HUGE_STREAM (0x01U) + +/** + * priority level default value when create a stream + */ +#define RT_STREAM_PRIORITY_DEFAULT (0U) + +#define RT_MAX_MODELS_IN_ONE_STREAM (256) + +/** + * @ingroup dvrt_stream + * @brief create stream instance + * @param [in|out] stm created stream + * @param [in] priority stream priority + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStreamCreate(rtStream_t *stm, int32_t priority); + +/** + * @ingroup dvrt_stream + * @brief create stream instance for external api + * @param [in|out] stm created stream + * @param [in] priority stream priority + * @param [in] flags stream op flags + * @return 
RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStreamCreateWithFlagsExternal(rtStream_t *stm, int32_t priority, uint32_t flags); + +/** + * @ingroup dvrt_stream + * @brief create stream instance + * @param [in|out] stm created stream + * @param [in] priority stream priority + * @param [in] flags stream op flags + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStreamCreateWithFlags(rtStream_t *stm, int32_t priority, uint32_t flags); + +/** + * @ingroup dvrt_stream + * @brief create stream instance + * @param [in|out] stm created stream + * @param [in] handle stream config + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStreamCreateWithConfig(rtStream_t *stm, rtStreamConfigHandle *handle); + +/** + * @ingroup dvrt_stream + * @brief get stream sqId + * @param [in] stm stream handle + * @param [out] sqId stream op sqId + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStreamGetSqid(const rtStream_t stm, uint32_t *sqId); + +/** + * @ingroup dvrt_stream + * @brief get stream cq info + * @param [in] stm stream handle + * @param [out] cqId stream op cqId + * @param [out] logicCqId stream op logic cqId + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStreamGetCqid(const rtStream_t stm, uint32_t *cqId, uint32_t *logicCqId); + +/** + * @ingroup dvrt_stream + * @brief get stream workspace + * @param [in] stm stream handle + * @param [out] workaddr workaddr on stream + * @param [out] worksize worksize on stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStreamGetWorkspace(const rtStream_t stm, void **workaddr, size_t *worksize); + +/** + * @ingroup dvrt_stream + * @brief destroy stream instance. 
+ * @param [in] stm the stream to destroy + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStreamDestroy(rtStream_t stm); + +/** + * @ingroup dvrt_stream + * @brief force destroy stream instance. + * @param [in] stm the stream to destroy + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStreamDestroyForce(rtStream_t stm); + +/** + * @ingroup dvrt_stream + * @brief wait a recorded event for stream + * @param [in] stm the wait stream + * @param [in] evt the event to wait + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStreamWaitEvent(rtStream_t stm, rtEvent_t evt); + +/** + * @ingroup dvrt_stream + * @brief wait a recorded event for stream + * @param [in] stm the wait stream + * @param [in] evt the event to wait + * @param [in] timeout timeout value + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStreamWaitEventWithTimeout(rtStream_t stm, rtEvent_t evt, uint32_t timeout); + +/** + * @ingroup dvrt_stream + * @brief wait stream to be complete + * @param [in] stm stream to wait + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStreamSynchronize(rtStream_t stm); + +/** + * @ingroup dvrt_stream + * @brief wait stream to be complete and set timeout + * @param [in] stm stream to wait + * @param [in] timeout timeout value,the unit is milliseconds + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStreamSynchronizeWithTimeout(rtStream_t stm, int32_t timeout); + +/** + * @ingroup dvrt_stream + * @brief queries an asynchronous stream for completion status + * @param [in] stm stream to query + * @return RT_ERROR_NONE for complete + * @return RT_ERROR_STREAM_NOT_COMPLETE for not complete + */ +RTS_API rtError_t 
rtStreamQuery(rtStream_t stm); + +/** + * @ingroup dvrt_stream + * @brief get stream id from a stream handle + * @param [in] stm stream handle + * @param [out] streamId stream id + * @return RT_ERROR_NONE for complete + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetStreamId(rtStream_t stm, int32_t *streamId); + +/** + * @ingroup dvrt_stream + * @brief inquire max stream count and max task count per stream + * @param [in] streamType Stream Type + * @param [out] maxStrCount max stream count + * @param [out] maxTaskCount max task count per stream + * @return RT_ERROR_NONE for complete + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetMaxStreamAndTask(uint32_t streamType, uint32_t *maxStrCount, uint32_t *maxTaskCount); + +/** + * @ingroup dvrt_stream + * @brief inquire available stream count + * @param [in] streamType Stream Type + * @param [out] streamCount available streamCount + * @return RT_ERROR_NONE for complete + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetAvailStreamNum(const uint32_t streamType, uint32_t * const streamCount); + +/** + * @ingroup dvrt_stream + * @brief inquire available event count + * @param [out] eventCount available event Count + * @return RT_ERROR_NONE for complete + * @return RT_ERROR_INVALID_VALUE for error input + */ +rtError_t rtGetAvailEventNum(uint32_t * const eventCount); + +/** + * @ingroup dvrt_stream + * @brief Name a stream + * @param [in] stm stream to be named + * @param [in] name identification name + * @return RT_ERROR_NONE for complete + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtNameStream(rtStream_t stm, const char_t *name); + +/** + * @brief execute extensible stream switch task + * @param [in] ptr pointer of value + * @param [in] condition judge condition + * @param [in] valuePtr pointer of target value + * @param [in] trueStream stream to be activated when value is not zero + * @param [in] stm 
stream id + * @param [in] dataType data type of target value + * @return RT_ERROR_NONE for complete + */ +RTS_API rtError_t rtStreamSwitchEx(void *ptr, rtCondition_t condition, void *valuePtr, rtStream_t trueStream, + rtStream_t stm, rtSwitchDataType_t dataType); + +/** + * @ingroup dvrt_stream + * @brief Active a stream + * @param [in] activeStream stream to be activated + * @param [in] stm input stream to init task + * @return RT_ERROR_NONE for complete + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStreamActive(rtStream_t activeStream, rtStream_t stm); + +/** + * @brief execute extensible stream case switch task + * @param [in] ptr pointer of value + * @param [in] size pointer num of value + * @param [in] valuePtr pointer of target value, length = size * elementSize + * @param [in] trueStreamPtr streams to be activated + * @param [in] elementSize size of to be activated true streams + * @param [in] stm input stream to init task + * @param [in] dataType data type of target value + * @return RT_ERROR_NONE for complete + */ +RTS_API rtError_t rtStreamSwitchN(void *ptr, uint32_t size, void *valuePtr, rtStream_t *trueStreamPtr, + uint32_t elementSize, rtStream_t stm, rtSwitchDataType_t dataType); + +/* + * @ingroup dvrt_stream + * @brief enable debug for dump overflow exception with stream + * @param [in] addr: ddr address of kernel exception dumpped + * @param [in] stm: stream handle + * @param [in] flag: debug flag + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDebugRegisterForStream(rtStream_t stm, uint32_t flag, const void *addr, + uint32_t *streamId, uint32_t *taskId); + +/* + * @ingroup rt_model + * @brief disable debug for dump overflow exception with stream + * @param [in] stm: stream handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtDebugUnRegisterForStream(rtStream_t stm); + +/* + * @ingroup 
dvrt_stream + * @brief enable or disable stream overflow + * @param [in] stm: stream handle + * @param [in] flags: 0:disable others:enable + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtSetStreamOverflowSwitch(rtStream_t stm, uint32_t flags); + +/* + * @ingroup dvrt_stream + * @brief get whether overflow of the stream is enabled or disabled + * @param [in] stm: stream handle + * @param [out] flags: 0:disable others:enable + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetStreamOverflowSwitch(rtStream_t stm, uint32_t *flags); + +/* + * @ingroup dvrt_stream + * @brief get stream geOpTag + * @param [in] stm: stream handle + * @param [out] geOpTag + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtGetStreamTag(rtStream_t stm, uint32_t *geOpTag); + +/** + * @ingroup dvrt_stream + * @brief clearing tasks on stream for mc2 + * @param [in] stm: stream handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtStreamClear(rtStream_t stm, rtClearStep_t step); + +/** + * @ingroup dvrt_stream + * @brief get current default stream + * @param [out] stm stream + * @param [out] default stream + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtCtxGetCurrentDefaultStream(rtStream_t *stm); + +/** + * @ingroup dvrt_stream + * @brief reg stream state callback func + * @param [in] regName register name + * @param [in] callback callback func + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input + */ +RTS_API rtError_t rtRegStreamStateCallback(const char_t *regName, const rtStreamStateCallback callback); + +/** + * @ingroup dvrt_stream + * @brief abort stream + * @param [in] stm: stream handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for 
error input + */ +RTS_API rtError_t rtStreamAbort(rtStream_t stm); + + +#if defined(__cplusplus) +} +#endif + +#endif // CCE_RUNTIME_STREAM_H diff --git a/inc/slog/toolchain/log_types.h b/inc/slog/toolchain/log_types.h new file mode 100644 index 0000000000000000000000000000000000000000..36ad66bb490d2db14b9f62ecbf088af66d2601c1 --- /dev/null +++ b/inc/slog/toolchain/log_types.h @@ -0,0 +1,130 @@ +/** + * @log_types.h + * + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. + */ + +#ifndef LOG_TYPES_H_ +#define LOG_TYPES_H_ + +#ifndef LINUX +#define LINUX 0 +#endif // LINUX + +#ifndef WIN +#define WIN 1 +#endif + +#ifndef OS_TYPE +#define OS_TYPE 0 +#endif // OS_TYPE + +#if (OS_TYPE == LINUX) +#define DLL_EXPORT __attribute__((visibility("default"))) +#else +#define DLL_EXPORT _declspec(dllexport) +#endif + +/** + * @ingroup slog + * + * log level id + */ +#define DLOG_DEBUG 0x0 // debug level id +#define DLOG_INFO 0x1 // info level id +#define DLOG_WARN 0x2 // warning level id +#define DLOG_ERROR 0x3 // error level id + +/** + * @ingroup slog + * + * log type + */ +enum { + DLOG_TYPE_DEBUG = 0, + DLOG_TYPE_RUN = 1, + DLOG_TYPE_MAX +}; + +/** + * @ingroup slog + * + * max log length + */ +#define MSG_LENGTH 1024 + +/** + * @ingroup slog + * + * module id + * if a module needs to be added, add the module at the end and before INVLID_MOUDLE_ID + */ +enum { + SLOG = 0, /* Slog module */ + IDEDD = 1, /* IDE daemon device */ + HCCL = 3, /* HCCL */ + FMK = 4, /* Adapter */ + DVPP = 6, /* DVPP */ + RUNTIME = 7, /* Runtime */ + CCE = 8, /* CCE */ + HDC = 9, /* HDC */ + DRV = 10, /* Driver */ + DEVMM = 22, /* Dlog memory managent */ + KERNEL = 23, /* Kernel */ + LIBMEDIA = 24, /* Libmedia */ + CCECPU = 25, /* aicpu shedule */ + ROS = 27, /* ROS */ + HCCP = 28, + 
ROCE = 29, + TEFUSION = 30, + PROFILING =31, + DP = 32, /* Data Preprocess */ + APP = 33, /* User Application */ + TS = 34, /* Task Schedule */ + TSDUMP = 35, + AICPU = 36, + LP = 37, /* Low Power */ + TDT = 38, /* tsdaemon or aicpu shedule */ + FE = 39, + MD = 40, + MB = 41, + ME = 42, + IMU = 43, + IMP = 44, + GE = 45, /* Fmk */ + CAMERA = 47, + ASCENDCL = 48, + TEEOS = 49, + ISP = 50, + SIS = 51, + HSM = 52, + DSS = 53, + PROCMGR = 54, /* Process Manager, Base Platform */ + BBOX = 55, + AIVECTOR = 56, + TBE = 57, + FV = 58, + TUNE = 60, + HSS = 61, /* helper */ + FFTS = 62, + OP = 63, + UDF = 64, + HICAID = 65, + TSYNC = 66, + AUDIO = 67, + TPRT = 68, + ASCENDCKERNEL = 69, + ASYS = 70, + ATRACE = 71, + RTC = 72, + SYSMONITOR = 73, + AML = 74, + ADETECT = 75, + INVLID_MOUDLE_ID = 76 /* add new module before INVLID_MOUDLE_ID */ +}; + +#endif // LOG_TYPES_H_ \ No newline at end of file diff --git a/inc/slog/toolchain/plog.h b/inc/slog/toolchain/plog.h new file mode 100644 index 0000000000000000000000000000000000000000..01a05b59107f56d0ea9d88627bb4d934281e3d2f --- /dev/null +++ b/inc/slog/toolchain/plog.h @@ -0,0 +1,70 @@ +/** + * @plog.h + * + * Copyright (c) Huawei Technologies Co., Ltd. 2019-2024. All rights reserved. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. + */ + +#ifndef PLOG_H_ +#define PLOG_H_ + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +#ifndef LINUX +#define LINUX 0 +#endif // LINUX + +#ifndef WIN +#define WIN 1 +#endif + +#ifndef OS_TYPE +#define OS_TYPE 0 +#endif // OS_TYPE + +#if (OS_TYPE == LINUX) +#define DLL_EXPORT __attribute__((visibility("default"))) +#else +#define DLL_EXPORT _declspec(dllexport) +#endif + +/** + * @ingroup plog + * @brief DlogReportInitialize: init log in service process before all device setting. 
+ * @return: 0: SUCCEED, others: FAILED + */ +DLL_EXPORT int DlogReportInitialize(void); + +/** + * @ingroup plog + * @brief DlogReportFinalize: release log resource in service process after all device reset. + * @return: 0: SUCCEED, others: FAILED + */ +DLL_EXPORT int DlogReportFinalize(void); + +/** + * @ingroup : plog + * @brief : create thread to recv log from device + * @param[in] : devId device id + * @param[in] : mode use macro LOG_SAVE_MODE_XXX in slog.h + * @return : 0: SUCCEED, others: FAILED + */ +DLL_EXPORT int DlogReportStart(int devId, int mode); + +/** + * @ingroup : plog + * @brief : stop recv thread + * @param[in] : devId device id + */ +DLL_EXPORT void DlogReportStop(int devId); + + +#ifdef __cplusplus +} +#endif // __cplusplus +#endif // PLOG_H_ diff --git a/inc/slog/toolchain/slog.h b/inc/slog/toolchain/slog.h new file mode 100644 index 0000000000000000000000000000000000000000..7508eb5e56953b3b5648ee0df806b473ef5ed1e9 --- /dev/null +++ b/inc/slog/toolchain/slog.h @@ -0,0 +1,331 @@ +/** + * @slog.h + * + * Copyright (c) Huawei Technologies Co., Ltd. 2019-2024. All rights reserved. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
+ */ + +#ifndef D_SYSLOG_H_ +#define D_SYSLOG_H_ + +#include +#include +#include "log_types.h" +static const int32_t TMP_LOG = 0; + +#ifdef __cplusplus +#ifndef LOG_CPP +extern "C" { +#endif +#endif // __cplusplus + +/** + * @ingroup slog + * + * log level id + */ +#define DLOG_NULL 0x4 // don't print log +#define DLOG_EVENT 0x10 // event log print level id + +/** + * @ingroup slog + * + * log mask + */ +#define DEBUG_LOG_MASK (0x00010000U) // print log to directory debug +#define SECURITY_LOG_MASK (0x00100000U) // print log to directory security +#define RUN_LOG_MASK (0x01000000U) // print log to directory run +#define STDOUT_LOG_MASK (0x10000000U) // print log to stdout + +#define LOG_SAVE_MODE_DEF (0x0U) // default +#define LOG_SAVE_MODE_UNI (0xFE756E69U) // unify save mode +#define LOG_SAVE_MODE_SEP (0xFE736570U) // separate save mode + +typedef enum { + APPLICATION = 0, + SYSTEM +} ProcessType; + +typedef struct { + ProcessType type; // process type + unsigned int pid; // pid + unsigned int deviceId; // device id + unsigned int mode; // log save mode + char reserved[48]; // reserve 48 bytes, align to 64 bytes +} LogAttr; + +/** + * @ingroup slog + * + * module id + * if a module needs to be added, add the module at the end and before INVLID_MOUDLE_ID + */ +#define ALL_MODULE (0x0000FFFFU) + +/** + * @ingroup slog + * @brief External log interface, which called by modules + */ +DLL_EXPORT void dlog_init(void); + +/** + * @ingroup slog + * @brief dlog_getlevel: get module debug loglevel and enableEvent + * + * @param [in]moduleId: moudule id(see slog.h, eg: CCE), others: invalid + * @param [out]enableEvent: 1: enable; 0: disable + * @return: module level(0: debug, 1: info, 2: warning, 3: error, 4: null output) + */ +DLL_EXPORT int32_t dlog_getlevel(int32_t moduleId, int32_t *enableEvent); + +/** + * @ingroup slog + * @brief dlog_setlevel: set module loglevel and enableEvent + * + * @param [in]moduleId: moudule id(see slog.h, eg: CCE), -1: all modules, others: 
invalid + * @param [in]level: log level(0: debug, 1: info, 2: warning, 3: error, 4: null output) + * @param [in]enableEvent: 1: enable; 0: disable, others:invalid + * @return: 0: SUCCEED, others: FAILED + */ +DLL_EXPORT int32_t dlog_setlevel(int32_t moduleId, int32_t level, int32_t enableEvent); + +/** + * @ingroup slog + * @brief CheckLogLevel: check module level enable or not + * users no need to call it because all dlog interface(include inner interface) has already called + * + * @param [in]moduleId: module id, eg: CCE + * @param [in]logLevel: eg: DLOG_EVENT/DLOG_ERROR/DLOG_WARN/DLOG_INFO/DLOG_DEBUG + * @return: 1:enable, 0:disable + */ +DLL_EXPORT int32_t CheckLogLevel(int32_t moduleId, int32_t logLevel); + +/** + * @ingroup : slog + * @brief : set log attr, default pid is 0, default device id is 0, default process type is APPLICATION + * @param [in] : logAttrInfo attr info, include pid(must be larger than 0), process type and device id(chip ID) + * @return : 0: SUCCEED, others: FAILED + */ +DLL_EXPORT int32_t DlogSetAttr(LogAttr logAttrInfo); + +/** + * @ingroup : slog + * @brief : print log, need va_list variable, exec CheckLogLevel() before call this function + * @param[in] : moduleId module id, eg: CCE + * @param[in] : level (0: debug, 1: info, 2: warning, 3: error, 16: event) + * @param[in] : fmt log content + * @param[in] : list variable list of log content + */ +DLL_EXPORT void DlogVaList(int32_t moduleId, int32_t level, const char *fmt, va_list list); + +/** + * @ingroup slog + * @brief DlogFlush: flush log buffer to file + */ +DLL_EXPORT void DlogFlush(void); + +/** + * @ingroup slog + * @brief dlog_error: print error log + * + * @param [in]moduleId: module id, eg: CCE + * @param [in]fmt: log content + */ +#define dlog_error(moduleId, fmt, ...) 
\ + do { \ + DlogRecord(moduleId, DLOG_ERROR, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ + } while (TMP_LOG != 0) + +/** + * @ingroup slog + * @brief dlog_warn: print warning log + * call CheckLogLevel in advance to optimize performance, call interface with fmt input take time + * + * @param [in]moduleId: module id, eg: CCE + * @param [in]fmt: log content + */ +#define dlog_warn(moduleId, fmt, ...) \ + do { \ + if (CheckLogLevel(moduleId, DLOG_WARN) == 1) { \ + DlogRecord(moduleId, DLOG_WARN, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ + } \ + } while (TMP_LOG != 0) + +/** + * @ingroup slog + * @brief dlog_info: print info log + * call CheckLogLevel in advance to optimize performance, call interface with fmt input take time + * + * @param [in]moduleId: module id, eg: CCE + * @param [in]fmt: log content + */ +#define dlog_info(moduleId, fmt, ...) \ + do { \ + if (CheckLogLevel(moduleId, DLOG_INFO) == 1) { \ + DlogRecord(moduleId, DLOG_INFO, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ + } \ + } while (TMP_LOG != 0) + +/** + * @ingroup slog + * @brief dlog_debug: print debug log + * call CheckLogLevel in advance to optimize performance, call interface with fmt input take time + * + * @param [in]moduleId: module id, eg: CCE + * @param [in]fmt: log content + */ +#define dlog_debug(moduleId, fmt, ...) \ + do { \ + if (CheckLogLevel(moduleId, DLOG_DEBUG) == 1) { \ + DlogRecord(moduleId, DLOG_DEBUG, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ + } \ + } while (TMP_LOG != 0) + +/** + * @ingroup slog + * @brief dlog_event: print event log + * + * @param [in]moduleId: module id, eg: CCE + * @param [in]fmt: log content + */ +#define dlog_event(moduleId, fmt, ...) 
\ + do { \ + DlogRecord(moduleId, DLOG_EVENT, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ + } while (TMP_LOG != 0) + +/** + * @ingroup slog + * @brief Dlog: print log, need caller to specify level + * call CheckLogLevel in advance to optimize performance, call interface with fmt input take time + * + * @param [in]moduleId: module id, eg: CCE + * @param [in]level(0: debug, 1: info, 2: warning, 3: error, 16: event) + * @param [in]fmt: log content + */ +#define Dlog(moduleId, level, fmt, ...) \ + do { \ + if (CheckLogLevel(moduleId, level) == 1) { \ + DlogRecord(moduleId, level, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ + } \ + } while (TMP_LOG != 0) + +/** + * @ingroup slog + * @brief DlogSub: print log, need caller to specify level and submodule + * call CheckLogLevel in advance to optimize performance, call interface with fmt input take time + * + * @param [in]moduleId: module id, eg: CCE + * @param [in]submodule: eg: engine + * @param [in]level(0: debug, 1: info, 2: warning, 3: error, 16: event) + * @param [in]fmt: log content + */ +#define DlogSub(moduleId, submodule, level, fmt, ...) \ + do { \ + if (CheckLogLevel(moduleId, level) == 1) { \ + DlogRecord(moduleId, level, "[%s:%d][%s]" fmt, __FILE__, __LINE__, submodule, ##__VA_ARGS__); \ + } \ + } while (TMP_LOG != 0) + +// log interface +DLL_EXPORT void DlogRecord(int32_t moduleId, int32_t level, const char *fmt, ...) 
__attribute((weak)); + +#ifdef __cplusplus +#ifndef LOG_CPP +} +#endif // LOG_CPP +#endif // __cplusplus + +#ifdef LOG_CPP +#ifdef __cplusplus +extern "C" { +#endif +/** + * @ingroup slog + * @brief DlogGetlevelForC: get module debug loglevel and enableEvent + * + * @param [in]moduleId: moudule id(see slog.h, eg: CCE), others: invalid + * @param [out]enableEvent: 1: enable; 0: disable + * @return: module level(0: debug, 1: info, 2: warning, 3: error, 4: null output) + */ +DLL_EXPORT int DlogGetlevelForC(int moduleId, int *enableEvent); + +/** + * @ingroup slog + * @brief DlogSetlevelForC: set module loglevel and enableEvent + * + * @param [in]moduleId: moudule id(see slog.h, eg: CCE), -1: all modules, others: invalid + * @param [in]level: log level(0: debug, 1: info, 2: warning, 3: error, 4: null output) + * @param [in]enableEvent: 1: enable; 0: disable, others:invalid + * @return: 0: SUCCEED, others: FAILED + */ +DLL_EXPORT int32_t DlogSetlevelForC(int32_t moduleId, int32_t level, int32_t enableEvent); + +/** + * @ingroup slog + * @brief CheckLogLevelForC: check module level enable or not + * users no need to call it because all dlog interface(include inner interface) has already called + * + * @param [in]moduleId: module id, eg: CCE + * @param [in]logLevel: eg: DLOG_EVENT/DLOG_ERROR/DLOG_WARN/DLOG_INFO/DLOG_DEBUG + * @return: 1:enable, 0:disable + */ +DLL_EXPORT int32_t CheckLogLevelForC(int32_t moduleId, int32_t logLevel); + +/** + * @ingroup slog + * @brief DlogSetAttrForC: set log attr, default pid is 0, default device id is 0, default process type is APPLICATION + * @param [in]logAttrInfo: attr info, include pid(must be larger than 0), process type and device id(chip ID) + * @return: 0: SUCCEED, others: FAILED + */ +DLL_EXPORT int32_t DlogSetAttrForC(LogAttr logAttrInfo); + +/** + * @ingroup slog + * @brief DlogForC: print log, need caller to specify level + * call CheckLogLevelForC in advance to optimize performance, call interface with fmt input take time + 
* + * @param [in]moduleId: module id, eg: CCE + * @param [in]level(0: debug, 1: info, 2: warning, 3: error, 16: event) + * @param [in]fmt: log content + */ +#define DlogForC(moduleId, level, fmt, ...) \ + do { \ + if (CheckLogLevelForC(moduleId, level) == 1) { \ + DlogRecordForC(moduleId, level, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ + } \ + } while (TMP_LOG != 0) + +/** + * @ingroup slog + * @brief DlogSubForC: print log, need caller to specify level and submodule + * call CheckLogLevelForC in advance to optimize performance, call interface with fmt input take time + * + * @param [in]moduleId: module id, eg: CCE + * @param [in]submodule: eg: engine + * @param [in]level(0: debug, 1: info, 2: warning, 3: error, 16: event) + * @param [in]fmt: log content + */ +#define DlogSubForC(moduleId, submodule, level, fmt, ...) \ + do { \ + if (CheckLogLevelForC(moduleId, level) == 1) { \ + DlogRecordForC(moduleId, level, "[%s:%d][%s]" fmt, __FILE__, __LINE__, submodule, ##__VA_ARGS__); \ + } \ + } while (TMP_LOG != 0) + +/** + * @ingroup slog + * @brief DlogFlushForC: flush log buffer to file + */ +DLL_EXPORT void DlogFlushForC(void); + +// log interface +DLL_EXPORT void DlogRecordForC(int32_t moduleId, int32_t level, const char *fmt, ...) __attribute((weak)); + +#ifdef __cplusplus +} +#endif +#endif // LOG_CPP +#endif // D_SYSLOG_H_ diff --git a/inc/tdt/data_common.h b/inc/tdt/data_common.h new file mode 100644 index 0000000000000000000000000000000000000000..a9b347c4a4e3dc2604b26884eef8294aca4817b5 --- /dev/null +++ b/inc/tdt/data_common.h @@ -0,0 +1,92 @@ +/** +* @file data_common.h +* +* Copyright (C) Huawei Technologies Co., Ltd. 2019-2020. All Rights Reserved. +* +* This program is used to data structure +*/ + +#ifndef HOST_INNER_INC_DATA_COMMON_H_ +#define HOST_INNER_INC_DATA_COMMON_H_ +#include + +namespace tdt { +#ifndef TDT_DATA_TYPE +#define TDT_DATA_TYPE + +/** + * @ingroup Tdt data. + * + * Tdt data type. 
+ */ +enum TdtDataType { + TDT_IMAGE_LABEL = 0, /**< Image label*/ + TDT_TFRECORD, /**< TF Record*/ + TDT_DATA_LABEL, /**< Data label*/ + TDT_END_OF_SEQUENCE, /**< End of Sequence*/ + TDT_TENSOR, /**< Tensor*/ + TDT_ABNORMAL, /**< ABNORMAL*/ + TDT_DATATYPE_MAX /**< Max*/ +}; +#endif + +/** + * @ingroup Tdt data. + * + * Tdt push data between host and device. + */ +struct TdtDataItem { + TdtDataType dataType_; /**< Input data type*/ + uint64_t label_; /**< Input data label*/ + uint64_t dataLen_; /**< Input data type length*/ + uint64_t realDataLen_; /**< Real Input data type length*/ + std::string tensorShape_; /**< Tensor shape*/ + std::string tensorType_; /**< Tensor type*/ + uint32_t cnt_; /**< Data count*/ + uint32_t currentCnt_; /**< Data current count*/ + uint64_t index_; /**< Data index*/ + std::string tensorName_; /**< Tensor name*/ + uint64_t md5ValueHead_; /**< Data md5*/ + uint64_t md5ValueTail_; /**< Data md5*/ + std::shared_ptr dataPtr_; /**< Data pointer*/ + std::string headMD5_; /**< MD5 header, 8byte*/ + std::string tailMD5_; /**< MD5 tail, 8byte*/ +}; + +/** + * @ingroup Tdt data. + * + * Tdt push data for queuedataset or mind-data. + */ +struct DataItem { + TdtDataType dataType_; /**< Input data type*/ + std::string tensorName_; /**< Tensor name*/ + std::string tensorShape_; /**< Tensor shape*/ + std::string tensorType_; /**< Tensor type*/ + uint64_t dataLen_; /**< Input data type length*/ + std::shared_ptr dataPtr_; /**< Data pointer*/ +}; + +/** + * @ingroup Tsdclient. + * + * tsdclient func type; + */ +enum TsdCmdType { + TSDCLOSE = 0, + TSDOPEN = 1 +}; + +/** + * @ingroup Tsdclient. + * + * tsdclient func input value object. 
+ */ +enum InputItem { + OPEN_DEVICEID = 0, + OPEN_RANKSIZE, + CLOSE_DEVICEID +}; + +} // namespace tdt +#endif // HOST_INNER_INC_DATA_COMMON_H_ diff --git a/inc/tdt/tdt_host_interface.h b/inc/tdt/tdt_host_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..ea23211cae5d59ed4f977733164aacd0ee9fa315 --- /dev/null +++ b/inc/tdt/tdt_host_interface.h @@ -0,0 +1,202 @@ +/** +* @file tdt_host_interface.h +* +* Copyright (C) Huawei Technologies Co., Ltd. 2019-2020. All Rights Reserved. +* +* This program is used to host server +*/ + +#ifndef HOST_INNER_INC_TDT_HOST_INTERFACE_H_ +#define HOST_INNER_INC_TDT_HOST_INTERFACE_H_ + +#include +#include +#include +#include "tdt/data_common.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +namespace tdt { +/** +* @ingroup TdtHostInit +* @brief Initialize the interface, start and initialize various general thread, log and other services +* +* @par Function +* Initialize the interface, start and initialize various general thread, log and other services +* +* @param deviceId [IN] type #unsigned int. Physical device ID +* @retval #0 Success +* @retval #Not 0 Fail +* +* @par Dependency +* @li libtsdclient.so: Library to which the interface belongs. +* @li tdt_host_interface.h: Header file where the interface declaration is located. +*/ +int32_t TdtHostInit(uint32_t deviceId); + +/** +* @ingroup TdtHostPushData +* @brief Blocking queue. When the queue is full, the Push interface will block. +* +* @par Function +* Blocking queue. When the queue is full, the Push interface will block. +* +* @param channelName [IN] type #String. queue channel name +* @param items [IN] type #vector DataItem is defined in data_common.h. input data +* @retval 0 Success +* @retval OtherValues 0 Fail +* +* @par Dependency +* @li libtsdclient.so: Library to which the interface belongs. +* @li tdt_host_interface.h: Header file where the interface declaration is located. 
+* @li data_common.h: Header file where 'DataItem' defined +*/ +int32_t TdtHostPushData(const std::string &channelName, const std::vector &item, uint32_t deviceId = 0); + +/** +* @ingroup TdtHostDestroy +* @brief Notify TDT component to close related resources +* +* @par Function +* Notify TDT component to close related resources +* +* @param NA +* @retval 0 Success +* @retval OtherValues Fail +* +* @par Dependency +* @li libtsdclient.so: Library to which the interface belongs. +* @li tdt_host_interface.h: Header file where the interface declaration is located. +*/ +int32_t TdtHostDestroy(); + +/** +* @ingroup TdtHostPreparePopData +* @brief Prepare pop data from Tdt data storage queue +* +* @par Function +* Prepare pop data from Tdt data storage queue +* +* @param NA +* @retval 0 Success +* @retval OtherValues 0 Fail +* +* @par Dependency +* @li libtsdclient.so: Library to which the interface belongs. +* @li tdt_host_interface.h: Header file where the interface declaration is located. +* @li data_common.h: Header file where 'DataItem' defined +*/ +int32_t TdtHostPreparePopData(); + +/** +* @ingroup TdtHostPopData +* @brief POP data from Tdt data storage queue +* +* @par Function +* POP data from Tdt data storage queue +* +* @param channelName [IN] type #String. queue channel name +* @param items [IN] type #vector DataItem is defined in data_common.h. input data +* @retval 0 Success +* @retval OtherValues 0 Fail +* +* @par Dependency +* @li libtsdclient.so: Library to which the interface belongs. +* @li tdt_host_interface.h: Header file where the interface declaration is located. 
+* @li data_common.h: Header file where 'DataItem' defined +*/ +int32_t TdtHostPopData(const std::string &channelName, std::vector &item); + +/** +* @ingroup TdtHostStop +* @brief Activate the thread that reads data externally from Tdt and +* send end of sequence data so that the external thread can exit +* +* @par Function +* Activate the thread that reads data externally from Tdt and send +* end of sequence data so that the external thread can exit +* +* @param channelName [IN] type #String. queue channel name +* @retval 0 Success +* @retval OtherValues Fail +* +* @par Dependency +* @li libtsdclient.so: Library to which the interface belongs. +* @li tdt_host_interface.h: Header file where the interface declaration is located. +*/ +int32_t TdtHostStop(const std::string &channelName); + +/** +* @ingroup TdtInFeedInit +* @brief Initialize the interface, start and initialize various general thread, log and other services +* +* @par Function +* Initialize the interface, start and initialize various general thread, log and other services +* +* @param deviceId [IN] type #unsigned int. logic device ID +* @retval #0 Success +* @retval #Not 0 Fail +* +* @par Dependency +* @li libtsdclient.so: Library to which the interface belongs. +* @li tdt_host_interface.h: Header file where the interface declaration is located. +*/ +int32_t TdtInFeedInit(uint32_t deviceId); + +/** +* @ingroup TdtOutFeedInit +* @brief Initialize the interface, start and initialize various general thread, log and other services +* +* @par Function +* Initialize the interface, start and initialize various general thread, log and other services +* +* @param deviceId [IN] type #unsigned int. logic device ID +* @retval #0 Success +* @retval #Not 0 Fail +* +* @par Dependency +* @li libtsdclient.so: Library to which the interface belongs. +* @li tdt_host_interface.h: Header file where the interface declaration is located. 
+*/ +int32_t TdtOutFeedInit(uint32_t deviceId); + +/** +* @ingroup TdtInFeedDestroy +* @brief Notify TDT component to close related resources +* +* @par Function +* Notify TDT component to close related resources +* +* @param deviceId [IN] type #unsigned int. logic device ID +* @retval 0 Success +* @retval OtherValues Fail +* +* @par Dependency +* @li libtsdclient.so: Library to which the interface belongs. +* @li tdt_host_interface.h: Header file where the interface declaration is located. +*/ +int32_t TdtInFeedDestroy(uint32_t deviceId); + +/** +* @ingroup TdtOutFeedDestroy +* @brief Notify TDT component to close related resources +* +* @par Function +* Notify TDT component to close related resources +* +* @param NA +* @retval 0 Success +* @retval OtherValues Fail +* +* @par Dependency +* @li libtsdclient.so: Library to which the interface belongs. +* @li tdt_host_interface.h: Header file where the interface declaration is located. +*/ +int32_t TdtOutFeedDestroy(); +} // namespace tdt +#ifdef __cplusplus +} +#endif // __cplusplus +#endif // HOST_INNER_INC_TDT_HOST_INTERFACE_H_ diff --git a/tf_adapter_2.x/CI_Build b/tf_adapter_2.x/CI_Build old mode 100644 new mode 100755 diff --git a/tf_adapter_2.x/CMakeLists.txt b/tf_adapter_2.x/CMakeLists.txt index fe80f7db0c5a2fcc4dd9528e120f9f7444a3e25e..b7184a296a2cad6e957f6d039f8c689f2efbbe1c 100644 --- a/tf_adapter_2.x/CMakeLists.txt +++ b/tf_adapter_2.x/CMakeLists.txt @@ -11,6 +11,11 @@ IF (${CMAKE_CXX_COMPILER_ID} STREQUAL GNU) add_definitions(-Wno-builtin-macro-redefined) ENDIF () +if(DEFINED ENV{D_PKG_SERVER}) + set(TF_PKG_SERVER $ENV{D_PKG_SERVER}) + message("Download packages from PKG server") +endif() + if (DEFINED ASCEND_CI_BUILD_DIR) set(CMAKE_C_FLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 ${CMAKE_C_FLAGS}") set(CMAKE_CXX_FLAGS "-D_GLIBCXX_USE_CXX11_ABI=0 ${CMAKE_CXX_FLAGS}") diff --git a/tf_adapter_2.x/build_tf2.sh b/tf_adapter_2.x/build_tf2.sh new file mode 100755 index 0000000000000000000000000000000000000000..1465381be330684efc730c33c57109752ba9feac --- /dev/null +++
b/tf_adapter_2.x/build_tf2.sh @@ -0,0 +1,28 @@ +#!/bin/bash +# Configure and build the TF2.x adapter; an optional first argument names the python interpreter (looked up via which, default python3). +CONFIGURE_DIR=$(dirname "$0") +cd "${CONFIGURE_DIR}" || exit 1 # abort instead of configuring/building from the caller's directory + +if [ "$1x" != "x" ];then + python_version=$1 +fi + +if [ -z "$python_version" ]; then + PYTHON_BIN_PATH=$(which python3) +else + PYTHON_BIN_PATH=$(which "$python_version") +fi + +export ADAPTER_TARGET_PYTHON_PATH="$PYTHON_BIN_PATH" +export ASCEND_INSTALLED_PATH="$ASCEND_HOME_PATH" + +echo "ADAPTER_TARGET_PYTHON_PATH set to $ADAPTER_TARGET_PYTHON_PATH" + +"$PYTHON_BIN_PATH" "configure.py" "$@" || exit 1 # also stops here when no interpreter was found + +echo "Configuration finished" + +mkdir -p build +cd build || exit 1 +cmake .. || exit 1 # do not run make after a failed configure +make -j8 diff --git a/tf_adapter_2.x/cmake/aoe/module.cmake b/tf_adapter_2.x/cmake/aoe/module.cmake index 513e2c5e56e396f8e4086824b8a689af2e9463fd..7b21ab0a80993be17cb76140bb4f33551a4d6f82 100644 --- a/tf_adapter_2.x/cmake/aoe/module.cmake +++ b/tf_adapter_2.x/cmake/aoe/module.cmake @@ -1,7 +1,7 @@ add_library(aoe_libs INTERFACE) if (DEFINED ASCEND_INSTALLED_PATH) - include_directories(${ASCEND_INSTALLED_PATH}/opensdk/opensdk/include/aoe) + include_directories(${CMAKE_SOURCE_DIR}/../inc/aoe) target_link_libraries(aoe_libs INTERFACE ${ASCEND_INSTALLED_PATH}/tools/aoe/lib64/libaoe_tuning.so) else () @@ -18,4 +18,4 @@ else () add_library(aoe_tuning SHARED ${fake_sources}) target_link_libraries(aoe_libs INTERFACE aoe_tuning) -endif () \ No newline at end of file +endif () diff --git a/tf_adapter_2.x/cmake/graph_engine/module.cmake b/tf_adapter_2.x/cmake/graph_engine/module.cmake index 7f042725edee54f7add11b0aa97ad4adec4b2aaf..e3e5013aac650594ebc9d2fa98d480453a0f71ac 100644 --- a/tf_adapter_2.x/cmake/graph_engine/module.cmake +++ b/tf_adapter_2.x/cmake/graph_engine/module.cmake @@ -1,10 +1,10 @@ add_library(ge_libs INTERFACE) if(DEFINED ASCEND_INSTALLED_PATH) - include_directories(${ASCEND_INSTALLED_PATH}/opensdk/opensdk/include/air) - include_directories(${ASCEND_INSTALLED_PATH}/opensdk/opensdk/include/air/external) - include_directories(${ASCEND_INSTALLED_PATH}/opensdk/opensdk/include/metadef)
- include_directories(${ASCEND_INSTALLED_PATH}/opensdk/opensdk/include/metadef/external) + include_directories(${CMAKE_SOURCE_DIR}/../inc/air) + include_directories(${CMAKE_SOURCE_DIR}/../inc/air/external) + include_directories(${CMAKE_SOURCE_DIR}/../inc/metadef) + include_directories(${CMAKE_SOURCE_DIR}/../inc/metadef/external) target_link_libraries(ge_libs INTERFACE ${ASCEND_INSTALLED_PATH}/compiler/lib64/libge_runner.so ${ASCEND_INSTALLED_PATH}/compiler/lib64/libfmk_parser.so) diff --git a/tf_adapter_2.x/cmake/secure_c.cmake b/tf_adapter_2.x/cmake/secure_c.cmake index 0801ef458c6b3f7111fe9835d26487742b6d8d59..14a7df3c7d5d4fd172c28b69ed114883c6dc8983 100644 --- a/tf_adapter_2.x/cmake/secure_c.cmake +++ b/tf_adapter_2.x/cmake/secure_c.cmake @@ -3,7 +3,7 @@ if (DEFINED ASCEND_INSTALLED_PATH) if (TF_PKG_SERVER) FetchContent_Declare( secure_c - URL ${TF_PKG_SERVER}/libs/securec/v1.1.10.tar.gz + URL "${TF_PKG_SERVER}/libs/securec/v1.1.10.tar.gz" ) else () FetchContent_Declare( @@ -19,4 +19,4 @@ if (DEFINED ASCEND_INSTALLED_PATH) else () include_directories(${ASCEND_CI_BUILD_DIR}/libc_sec/include/) include_directories(${ASCEND_CI_BUILD_DIR}/abl/libc_sec/include/) -endif () \ No newline at end of file +endif () diff --git a/tf_adapter_2.x/compat_v1/CMakeLists.txt b/tf_adapter_2.x/compat_v1/CMakeLists.txt index b75925d17a9e02b6bd71624835b92d86226783cc..90652937f93fe5971b037d54d23b46a24252f93f 100644 --- a/tf_adapter_2.x/compat_v1/CMakeLists.txt +++ b/tf_adapter_2.x/compat_v1/CMakeLists.txt @@ -61,11 +61,10 @@ else () ${CMAKE_CURRENT_BINARY_DIR} ${ADAPTER_ROOT}/ ${ADAPTER_ROOT}/inc/ - ${ASCEND_INSTALLED_PATH}/opensdk/opensdk/include - ${ASCEND_INSTALLED_PATH}/opensdk/opensdk/include/aoe - ${ASCEND_INSTALLED_PATH}/opensdk/opensdk/include/slog - ${ASCEND_INSTALLED_PATH}/opensdk/opensdk/include/runtime - ${ASCEND_INSTALLED_PATH}/opensdk/opensdk/include/msprof + ${ADAPTER_ROOT}/inc/aoe + ${ADAPTER_ROOT}/inc/slog + ${ADAPTER_ROOT}/inc/runtime + ${ADAPTER_ROOT}/inc/msprof ) 
endif () diff --git a/tf_adapter_2.x/tests/CMakeLists.txt b/tf_adapter_2.x/tests/CMakeLists.txt index 4ca02cae93dbe72a2ce66acb021f4f029514358e..a0725a09f2d4dbd6d63eeeb8da843c8cb314b526 100644 --- a/tf_adapter_2.x/tests/CMakeLists.txt +++ b/tf_adapter_2.x/tests/CMakeLists.txt @@ -30,8 +30,8 @@ foreach (COMPILE_FLAG ${CUSTOM_COMPILE_FLAGS}) set(CMAKE_CXX_FLAGS "${COMPILE_FLAG} ${CMAKE_CXX_FLAGS}") endforeach (COMPILE_FLAG) -include_directories(${ASCEND_INSTALLED_PATH}/opensdk/opensdk/include/ascendcl/external) # just for acl -include_directories(${ASCEND_INSTALLED_PATH}/opensdk/opensdk/include/aoe) +include_directories(${ASCEND_INSTALLED_PATH}/include/acl) # just for acl +include_directories(${CMAKE_SOURCE_DIR}/../inc/aoe) include_directories(${CMAKE_CURRENT_LIST_DIR}/stub/include) include_directories(${CMAKE_CURRENT_LIST_DIR}/../npu_device/core) diff --git a/tf_adapter_2.x/tests/cmake/acl/module.cmake b/tf_adapter_2.x/tests/cmake/acl/module.cmake index 5a05d64d2d0e5d9adb69fb622d10e52b30ffcfe6..b5124ea53352cd1b001a953c2bd9fe0d16e7d758 100644 --- a/tf_adapter_2.x/tests/cmake/acl/module.cmake +++ b/tf_adapter_2.x/tests/cmake/acl/module.cmake @@ -3,7 +3,7 @@ add_library(acl_libs INTERFACE) add_library(acl_stub STATIC ${CMAKE_CURRENT_LIST_DIR}/../../stub/acl_stub.cpp) target_include_directories(acl_stub PRIVATE - ${ASCEND_INSTALLED_PATH}/opensdk/opensdk/include/ascendcl/external + ${ASCEND_INSTALLED_PATH}/include/acl ${CMAKE_CURRENT_LIST_DIR}/../../stub/include) target_link_libraries(acl_libs INTERFACE acl_stub) diff --git a/tf_adapter_2.x/tests/cmake/gtest.cmake b/tf_adapter_2.x/tests/cmake/gtest.cmake index 2419f01e71ed6ff8e86fe5b4bacac051a119b428..81d0f56b68f3e1d8a1d0450b37110e3737607232 100644 --- a/tf_adapter_2.x/tests/cmake/gtest.cmake +++ b/tf_adapter_2.x/tests/cmake/gtest.cmake @@ -2,8 +2,7 @@ include(FetchContent) if(TF_PKG_SERVER) FetchContent_Declare( googletest - URL ${TF_PKG_SERVER}/libs/ge_gtest/release-1.8.1.tar.gz - URL_HASH 
MD5=2e6fbeb6a91310a16efe181886c59596 + URL "${TF_PKG_SERVER}/libs/ge_gtest/release-1.8.1.tar.gz" ) else() FetchContent_Declare( diff --git a/tf_adapter_2.x/tests/cmake/secure_c.cmake b/tf_adapter_2.x/tests/cmake/secure_c.cmake index 0bcd07c2cc416cab2d6ab66a4234f6c8c4f0df8b..aac116ee78da3ed5ad5fac82ddac9e7984b5dca2 100644 --- a/tf_adapter_2.x/tests/cmake/secure_c.cmake +++ b/tf_adapter_2.x/tests/cmake/secure_c.cmake @@ -2,7 +2,7 @@ include(FetchContent) if (TF_PKG_SERVER) FetchContent_Declare( secure_c - URL ${TF_PKG_SERVER}/libs/securec/v1.1.10.tar.gz + URL "${TF_PKG_SERVER}/libs/securec/v1.1.10.tar.gz" ) else () FetchContent_Declare( @@ -17,4 +17,4 @@ if (NOT secure_c_POPULATED) include_directories(${secure_c_SOURCE_DIR}/include) file(GLOB SEC_C_SRCS ${secure_c_SOURCE_DIR}/src/*.c) add_library(c_sec SHARED ${SEC_C_SRCS}) -endif () \ No newline at end of file +endif ()