From 9fe4239661c94c832840330b53fc5ac9f9c4ee7f Mon Sep 17 00:00:00 2001 From: kai-ma Date: Fri, 4 Jul 2025 14:09:01 +0800 Subject: [PATCH] add csrc/acl --- .../msprobe/csrc/acl/core/AclApi.cpp | 411 ++++++++++++++++++ .../msprobe/csrc/acl/include/AclApi.h | 152 +++++++ 2 files changed, 563 insertions(+) create mode 100644 accuracy_tools/msprobe/csrc/acl/core/AclApi.cpp create mode 100644 accuracy_tools/msprobe/csrc/acl/include/AclApi.h diff --git a/accuracy_tools/msprobe/csrc/acl/core/AclApi.cpp b/accuracy_tools/msprobe/csrc/acl/core/AclApi.cpp new file mode 100644 index 00000000000..06ad54738ef --- /dev/null +++ b/accuracy_tools/msprobe/csrc/acl/core/AclApi.cpp @@ -0,0 +1,411 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "acl/include/AclApi.h" + +#include "utils/Exception.h" +#include "utils/Log.h" + +namespace Ascendcl { + constexpr const char *ascendAclName = "libascendcl.so"; + + using aclInitFuncType = aclError (*)(const char *); + using aclrtSetDeviceFuncType = aclError (*)(int32_t); + using aclrtCreateContextFuncType = aclError (*)(aclrtContext *, int32_t); + using aclmdlLoadFromFileFuncType = aclError (*)(const char *, uint32_t *); + using aclmdlCreateDescFuncType = aclmdlDesc *(*)(); + using aclmdlGetDescFuncType = aclError (*)(aclmdlDesc *, uint32_t); + using aclmdlGetInputSizeByIndexFuncType = size_t (*)(aclmdlDesc *, size_t); + using aclmdlGetOutputSizeByIndexFuncType = size_t (*)(aclmdlDesc *, size_t); + using aclmdlGetNumInputsFuncType = size_t (*)(aclmdlDesc *); + using aclmdlGetNumOutputsFuncType = size_t (*)(aclmdlDesc *); + using aclmdlGetInputNameByIndexFuncType = const char *(*)(const aclmdlDesc *, size_t); + using aclmdlGetInputDataTypeFuncType = aclDataType (*)(const aclmdlDesc *, size_t); + using aclmdlGetInputDimsFuncType = aclError (*)(const aclmdlDesc *, size_t, aclmdlIODims *); + using aclrtMallocFuncType = aclError (*)(void **, size_t, aclrtMemMallocPolicy); + using aclmdlCreateDatasetFuncType = aclmdlDataset *(*)(); + using aclCreateDataBufferFuncType = aclDataBuffer *(*)(void *, size_t); + using aclmdlAddDatasetBufferFuncType = aclError (*)(aclmdlDataset *, aclDataBuffer *); + using aclmdlExecuteFuncType = aclError (*)(uint32_t, const aclmdlDataset *, aclmdlDataset *); + using aclrtMemcpyFuncType = aclError (*)(void *, size_t, const void *, size_t, aclrtMemcpyKind); + using aclmdlGetDatasetNumBuffersFuncType = size_t (*)(const aclmdlDataset *); + using aclmdlGetDatasetBufferFuncType = aclDataBuffer *(*)(const aclmdlDataset *, size_t); + using aclDestroyDataBufferFuncType = aclError (*)(const aclDataBuffer *); + using aclmdlDestroyDatasetFuncType = aclError (*)(const aclmdlDataset *); + using aclrtFreeFuncType = aclError (*)(void *); + using aclFinalizeFuncType = aclError (*)(); + using aclmdlUnloadFuncType = aclError (*)(uint32_t); + using aclmdlDestroyDescFuncType = aclError (*)(aclmdlDesc *); + using aclrtDestroyContextFuncType = aclError (*)(aclrtContext); + using aclrtResetDeviceFuncType = aclError (*)(int32_t); + using aclmdlInitDumpFuncType = aclError (*)(); + using aclmdlSetDumpFuncType = aclError (*)(const char *); + using aclmdlFinalizeDumpFuncType = aclError (*)(); + using acldumpRegCallbackFuncType = aclError (*)(AclDumpCallbackFuncType, int32_t); + using acldumpUnregCallbackFuncType = void (*)(); + + static aclInitFuncType aclInitFunc = nullptr; + static aclrtSetDeviceFuncType aclrtSetDeviceFunc = nullptr; + static aclrtCreateContextFuncType aclrtCreateContextFunc = nullptr; + static aclmdlLoadFromFileFuncType aclmdlLoadFromFileFunc = nullptr; + static aclmdlCreateDescFuncType aclmdlCreateDescFunc = nullptr; + static aclmdlGetDescFuncType aclmdlGetDescFunc = nullptr; + static aclmdlGetInputSizeByIndexFuncType aclmdlGetInputSizeByIndexFunc = nullptr; + static aclmdlGetOutputSizeByIndexFuncType aclmdlGetOutputSizeByIndexFunc = nullptr; + static aclmdlGetNumInputsFuncType aclmdlGetNumInputsFunc = nullptr; + static aclmdlGetNumOutputsFuncType aclmdlGetNumOutputsFunc = nullptr; + static aclmdlGetInputNameByIndexFuncType aclmdlGetInputNameByIndexFunc = nullptr; + static aclmdlGetInputDataTypeFuncType aclmdlGetInputDataTypeFunc = nullptr; + static aclmdlGetInputDimsFuncType aclmdlGetInputDimsFunc = nullptr; + static aclrtMallocFuncType aclrtMallocFunc = nullptr; + static aclmdlCreateDatasetFuncType aclmdlCreateDatasetFunc = nullptr; + static aclCreateDataBufferFuncType aclCreateDataBufferFunc = nullptr; + static aclmdlAddDatasetBufferFuncType aclmdlAddDatasetBufferFunc = nullptr; + static aclmdlExecuteFuncType aclmdlExecuteFunc = nullptr; + static aclrtMemcpyFuncType aclrtMemcpyFunc = nullptr; + static aclmdlGetDatasetNumBuffersFuncType aclmdlGetDatasetNumBuffersFunc = nullptr; + static aclmdlGetDatasetBufferFuncType aclmdlGetDatasetBufferFunc = nullptr; + static aclDestroyDataBufferFuncType aclDestroyDataBufferFunc = nullptr; + static aclmdlDestroyDatasetFuncType aclmdlDestroyDatasetFunc = nullptr; + static aclrtFreeFuncType aclrtFreeFunc = nullptr; + static aclFinalizeFuncType aclFinalizeFunc = nullptr; + static aclmdlUnloadFuncType aclmdlUnloadFunc = nullptr; + static aclmdlDestroyDescFuncType aclmdlDestroyDescFunc = nullptr; + static aclrtDestroyContextFuncType aclrtDestroyContextFunc = nullptr; + static aclrtResetDeviceFuncType aclrtResetDeviceFunc = nullptr; + static aclmdlInitDumpFuncType aclmdlInitDumpFunc = nullptr; + static aclmdlSetDumpFuncType aclmdlSetDumpFunc = nullptr; + static aclmdlFinalizeDumpFuncType aclmdlFinalizeDumpFunc = nullptr; + static acldumpRegCallbackFuncType acldumpRegCallbackFunc = nullptr; + static acldumpUnregCallbackFuncType acldumpUnregCallbackFunc = nullptr; + + const std::map functionMap = { + {"aclInit", reinterpret_cast(&aclInitFunc)}, + {"aclrtSetDevice", reinterpret_cast(&aclrtSetDeviceFunc)}, + {"aclrtCreateContext", reinterpret_cast(&aclrtCreateContextFunc)}, + {"aclmdlLoadFromFile", reinterpret_cast(&aclmdlLoadFromFileFunc)}, + {"aclmdlCreateDesc", reinterpret_cast(&aclmdlCreateDescFunc)}, + {"aclmdlGetDesc", reinterpret_cast(&aclmdlGetDescFunc)}, + {"aclmdlGetNumInputs", reinterpret_cast(&aclmdlGetNumInputsFunc)}, + {"aclmdlGetNumOutputs", reinterpret_cast(&aclmdlGetNumOutputsFunc)}, + {"aclmdlGetInputNameByIndex", reinterpret_cast(&aclmdlGetInputNameByIndexFunc)}, + {"aclmdlGetInputSizeByIndex", reinterpret_cast(&aclmdlGetInputSizeByIndexFunc)}, + {"aclmdlGetInputDataType", reinterpret_cast(&aclmdlGetInputDataTypeFunc)}, + {"aclmdlGetInputDims", reinterpret_cast(&aclmdlGetInputDimsFunc)}, + {"aclmdlGetOutputSizeByIndex", reinterpret_cast(&aclmdlGetOutputSizeByIndexFunc)}, + {"aclrtMalloc", reinterpret_cast(&aclrtMallocFunc)}, + {"aclmdlCreateDataset", reinterpret_cast(&aclmdlCreateDatasetFunc)}, + {"aclCreateDataBuffer", reinterpret_cast(&aclCreateDataBufferFunc)}, + {"aclmdlAddDatasetBuffer", reinterpret_cast(&aclmdlAddDatasetBufferFunc)}, + {"aclmdlExecute", reinterpret_cast(&aclmdlExecuteFunc)}, + {"aclrtMemcpy", reinterpret_cast(&aclrtMemcpyFunc)}, + {"aclmdlGetDatasetNumBuffers", reinterpret_cast(&aclmdlGetDatasetNumBuffersFunc)}, + {"aclmdlGetDatasetBuffer", reinterpret_cast(&aclmdlGetDatasetBufferFunc)}, + {"aclDestroyDataBuffer", reinterpret_cast(&aclDestroyDataBufferFunc)}, + {"aclmdlDestroyDataset", reinterpret_cast(&aclmdlDestroyDatasetFunc)}, + {"aclrtFree", reinterpret_cast(&aclrtFreeFunc)}, + {"aclFinalize", reinterpret_cast(&aclFinalizeFunc)}, + {"aclmdlUnload", reinterpret_cast(&aclmdlUnloadFunc)}, + {"aclmdlDestroyDesc", reinterpret_cast(&aclmdlDestroyDescFunc)}, + {"aclrtDestroyContext", reinterpret_cast(&aclrtDestroyContextFunc)}, + {"aclrtResetDevice", reinterpret_cast(&aclrtResetDeviceFunc)}, + {"aclmdlInitDump", reinterpret_cast(&aclmdlInitDumpFunc)}, + {"aclmdlSetDump", reinterpret_cast(&aclmdlSetDumpFunc)}, + {"aclmdlFinalizeDump", reinterpret_cast(&aclmdlFinalizeDumpFunc)}, + {"acldumpRegCallback", reinterpret_cast(&acldumpRegCallbackFunc)}, + {"acldumpUnregCallback", reinterpret_cast(&acldumpUnregCallbackFunc)}, + }; + + AclApi &AclApi::GetInstance() { + static AclApi instance; + return instance; + } + + AclApi::AclApi() { + LoadAclApi(); + } + + void AclApi::LoadAclApi() { + static void *libAscendcl = nullptr; + + if (libAscendcl != nullptr) { + LOG_ERROR << "No need to load acl api again."; + return; + } + libAscendcl = dlopen(ascendAclName, RTLD_LAZY); + if (libAscendcl == nullptr) { + LOG_ERROR << "Failed to search libascendcl.so. " << dlerror(); + return; + } + for (auto &iter : functionMap) { + if (*(iter.second) != nullptr) { + continue; + } + *(iter.second) = dlsym(libAscendcl, iter.first); + if (*(iter.second) == nullptr) { + LOG_ERROR << "Failed to load function " << iter.first << " from libascendcl.so. " << dlerror(); + dlclose(libAscendcl); + libAscendcl = nullptr; + return; + } + LOG_DEBUG << "Load function " << iter.first << " from libascendcl.so."; + } + } + + aclError AclApi::ACLAPI_AclInit(const char *cfg) { + if (aclInitFunc == nullptr) { + throw Utility::MsprobeException("API aclInit does not have a definition."); + } + return aclInitFunc(cfg); + } + + aclError AclApi::ACLAPI_AclRtSetDevice(int32_t deviceId) { + if (aclrtSetDeviceFunc == nullptr) { + throw Utility::MsprobeException("API aclrtSetDevice does not have a definition."); + } + return aclrtSetDeviceFunc(deviceId); + } + + aclError AclApi::ACLAPI_AclRtCreateContext(aclrtContext *context, int32_t deviceId) { + if (aclrtCreateContextFunc == nullptr) { + throw Utility::MsprobeException("API aclrtCreateContext does not have a definition."); + } + return aclrtCreateContextFunc(context, deviceId); + } + + LoadFileResult AclApi::ACLAPI_AclMdlLoadFromFile(const char *modelPath) { + uint32_t modelId; + if (aclmdlLoadFromFileFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlLoadFromFile does not have a definition."); + } + aclError ret = aclmdlLoadFromFileFunc(modelPath, &modelId); + LoadFileResult loadFileResult; + loadFileResult.modelId = modelId; + loadFileResult.ret = ret; + return loadFileResult; + } + + aclmdlDesc *AclApi::ACLAPI_AclMdlCreateDesc() { + if (aclmdlCreateDescFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlCreateDesc does not have a definition."); + } + return aclmdlCreateDescFunc(); + } + + aclError AclApi::ACLAPI_AclMdlGetDesc(aclmdlDesc *modelDesc, uint32_t modelId) { + if (aclmdlGetDescFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlGetDesc does not have a definition."); + } + return aclmdlGetDescFunc(modelDesc, modelId); + } + + size_t AclApi::ACLAPI_AclMdlGetNumInputs(aclmdlDesc *modelDesc) { + if (aclmdlGetNumInputsFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlGetNumInputs does not have a definition."); + } + return aclmdlGetNumInputsFunc(modelDesc); + } + + size_t AclApi::ACLAPI_AclMdlGetNumOutputs(aclmdlDesc *modelDesc) { + if (aclmdlGetNumOutputsFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlGetNumInputs does not have a definition."); + } + return aclmdlGetNumOutputsFunc(modelDesc); + } + + const char *AclApi::ACLAPI_AclMdlGetInputNameByIndex(const aclmdlDesc *modelDesc, size_t index) { + if (aclmdlGetInputNameByIndexFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlGetInputNameByIndex does not have a definition."); + } + return aclmdlGetInputNameByIndexFunc(modelDesc, index); + } + + size_t AclApi::ACLAPI_AclMdlGetInputSizeByIndex(aclmdlDesc *modelDesc, size_t index) { + if (aclmdlGetInputSizeByIndexFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlGetInputSizeByIndex does not have a definition."); + } + return aclmdlGetInputSizeByIndexFunc(modelDesc, index); + } + + size_t AclApi::ACLAPI_AclMdlGetOutputSizeByIndex(aclmdlDesc *modelDesc, size_t index) { + if (aclmdlGetOutputSizeByIndexFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlGetInputSizeByIndex does not have a definition."); + } + return aclmdlGetOutputSizeByIndexFunc(modelDesc, index); + } + + aclDataType AclApi::ACLAPI_AclMdlGetInputDataType(const aclmdlDesc *modelDesc, size_t index) { + if (aclmdlGetInputDataTypeFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlGetInputDataType does not have a definition."); + } + return aclmdlGetInputDataTypeFunc(modelDesc, index); + } + + aclError AclApi::ACLAPI_AclMdlGetInputDims(const aclmdlDesc *modelDesc, size_t index, aclmdlIODims *dims) { + if (aclmdlGetInputDimsFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlGetInputDims does not have a definition."); + } + return aclmdlGetInputDimsFunc(modelDesc, index, dims); + } + + aclError AclApi::ACLAPI_AclRtMalloc(void **devPtr, size_t size, aclrtMemMallocPolicy policy) { + if (aclrtMallocFunc == nullptr) { + throw Utility::MsprobeException("API aclrtMalloc does not have a definition."); + } + return aclrtMallocFunc(devPtr, size, policy); + } + + aclmdlDataset *AclApi::ACLAPI_AclMdlCreateDataset() { + if (aclmdlCreateDatasetFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlCreateDataset does not have a definition."); + } + return aclmdlCreateDatasetFunc(); + } + + aclDataBuffer *AclApi::ACLAPI_AclCreateDataBuffer(void *data, size_t size) { + if (aclCreateDataBufferFunc == nullptr) { + throw Utility::MsprobeException("API aclCreateDataBuffer does not have a definition."); + } + return aclCreateDataBufferFunc(data, size); + } + + aclError AclApi::ACLAPI_AclMdlAddDatasetBuffer(aclmdlDataset *dataset, aclDataBuffer *databuffer) { + if (aclmdlAddDatasetBufferFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlAddDatasetBuffer does not have a definition."); + } + return aclmdlAddDatasetBufferFunc(dataset, databuffer); + } + + aclError AclApi::ACLAPI_AclMdlExecute(uint32_t modelId, const aclmdlDataset *input, aclmdlDataset *output) { + if (aclmdlExecuteFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlExecute does not have a definition."); + } + return aclmdlExecuteFunc(modelId, input, output); + } + + aclError + AclApi::ACLAPI_AclRtMemcpy(void *dst, size_t destMax, const void *src, size_t count, aclrtMemcpyKind kind) { + if (aclrtMemcpyFunc == nullptr) { + throw Utility::MsprobeException("API aclrtMemcpy does not have a definition."); + } + return aclrtMemcpyFunc(dst, destMax, src, count, kind); + } + + size_t AclApi::ACLAPI_AclMdlGetDatasetNumBuffers(const aclmdlDataset *dataset) { + if (aclmdlGetDatasetNumBuffersFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlGetDatasetNumBuffers does not have a definition."); + } + return aclmdlGetDatasetNumBuffersFunc(dataset); + } + + aclDataBuffer *AclApi::ACLAPI_AclMdlGetDatasetBuffer(const aclmdlDataset *dataset, size_t index) { + if (aclmdlGetDatasetBufferFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlGetDatasetBuffer does not have a definition."); + } + return aclmdlGetDatasetBufferFunc(dataset, index); + } + + aclError AclApi::ACLAPI_AclDestroyDataBuffer(const aclDataBuffer *dataBuffer) { + if (aclDestroyDataBufferFunc == nullptr) { + throw Utility::MsprobeException("API aclDestroyDataBuffer does not have a definition."); + } + return aclDestroyDataBufferFunc(dataBuffer); + } + + aclError AclApi::ACLAPI_AclMdlDestroyDataset(const aclmdlDataset *dataset) { + if (aclmdlDestroyDatasetFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlDestroyDataset does not have a definition."); + } + return aclmdlDestroyDatasetFunc(dataset); + } + + aclError AclApi::ACLAPI_AclRtFree(void *devPtr) { + if (aclrtFreeFunc == nullptr) { + throw Utility::MsprobeException("API aclrtFree does not have a definition."); + } + return aclrtFreeFunc(devPtr); + } + + aclError AclApi::ACLAPI_AclFinalize() { + if (aclFinalizeFunc == nullptr) { + throw Utility::MsprobeException("API aclFinalize does not have a definition."); + } + return aclFinalizeFunc(); + } + + aclError AclApi::ACLAPI_AclMdlUnload(uint32_t modelId) { + if (aclmdlUnloadFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlUnload does not have a definition."); + } + return aclmdlUnloadFunc(modelId); + } + + aclError AclApi::ACLAPI_AclMdlDestroyDesc(aclmdlDesc *modelDesc) { + if (aclmdlDestroyDescFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlDestroyDesc does not have a definition."); + } + return aclmdlDestroyDescFunc(modelDesc); + } + + aclError AclApi::ACLAPI_AclRtDestroyContext(aclrtContext context) { + if (aclrtDestroyContextFunc == nullptr) { + throw Utility::MsprobeException("API aclrtDestroyContext does not have a definition."); + } + return aclrtDestroyContextFunc(context); + } + + aclError AclApi::ACLAPI_AclRtResetDevice(int32_t deviceId) { + if (aclrtResetDeviceFunc == nullptr) { + throw Utility::MsprobeException("API aclrtResetDevice does not have a definition."); + } + return aclrtResetDeviceFunc(deviceId); + } + + aclError AclApi::ACLAPI_AclInitDump() { + if (aclmdlInitDumpFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlInitDump does not have a definition."); + } + return aclmdlInitDumpFunc(); + } + + aclError AclApi::ACLAPI_AclSetDump(const char *dumpCfgPath) { + if (aclmdlSetDumpFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlSetDump does not have a definition."); + } + return aclmdlSetDumpFunc(dumpCfgPath); + } + + aclError AclApi::ACLAPI_AclFinalizeDump() { + if (aclmdlFinalizeDumpFunc == nullptr) { + throw Utility::MsprobeException("API aclmdlFinalizeDump does not have a definition."); + } + return aclmdlFinalizeDumpFunc(); + } + + aclError AclApi::ACLAPI_AclDumpRegCallBack(AclDumpCallbackFuncType messageCallback, int32_t flag) { + if (acldumpRegCallbackFunc == nullptr) { + throw Utility::MsprobeException("API acldumpRegCallback does not have a definition."); + } + return acldumpRegCallbackFunc(messageCallback, flag); + } + + void AclApi::ACLAPI_AclDumpUnregCallBack() { + if (acldumpUnregCallbackFunc == nullptr) { + throw Utility::MsprobeException("API acldumpUnregCallback does not have a definition."); + } + acldumpUnregCallbackFunc(); + } +} // namespace Ascendcl diff --git a/accuracy_tools/msprobe/csrc/acl/include/AclApi.h b/accuracy_tools/msprobe/csrc/acl/include/AclApi.h new file mode 100644 index 00000000000..967733ac322 --- /dev/null +++ b/accuracy_tools/msprobe/csrc/acl/include/AclApi.h @@ -0,0 +1,152 @@ +/* + * Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ACLDUMP_ACLAPI_H +#define ACLDUMP_ACLAPI_H + +#include +#include +#include +#include +#include + +#include "utils/Path.h" + +#define ACL_MAX_DIM_CNT 128 +#define ACL_MAX_TENSOR_NAME_LEN 128 + +extern "C" { +using aclError = int; +using aclrtContext = void *; +using aclmdlDesc = struct aclmdlDesc; +using aclmdlDataset = struct aclmdlDataset; +using aclDataBuffer = struct aclDataBuffer; + +typedef enum { + ACL_DT_UNDEFINED = -1, // 未知数据类型,默认值 + ACL_FLOAT = 0, + ACL_FLOAT16 = 1, + ACL_INT8 = 2, + ACL_INT32 = 3, + ACL_UINT8 = 4, + ACL_INT16 = 6, + ACL_UINT16 = 7, + ACL_UINT32 = 8, + ACL_INT64 = 9, + ACL_UINT64 = 10, + ACL_DOUBLE = 11, + ACL_BOOL = 12, + ACL_STRING = 13, + ACL_COMPLEX64 = 16, + ACL_COMPLEX128 = 17, + ACL_BF16 = 27, + ACL_INT4 = 29, + ACL_UINT1 = 30, + ACL_COMPLEX32 = 33, +} aclDataType; + +typedef enum aclrtMemMallocPolicy { + ACL_MEM_MALLOC_HUGE_FIRST, + ACL_MEM_MALLOC_HUGE_ONLY, + ACL_MEM_MALLOC_NORMAL_ONLY, + ACL_MEM_MALLOC_HUGE_FIRST_P2P, + ACL_MEM_MALLOC_HUGE_ONLY_P2P, + ACL_MEM_MALLOC_NORMAL_ONLY_P2P, + ACL_MEM_TYPE_LOW_BAND_WIDTH = 0x0100, + ACL_MEM_TYPE_HIGH_BAND_WIDTH = 0x1000 +} aclrtMemMallocPolicy; + +typedef enum aclrtMemcpyKind { + ACL_MEMCPY_HOST_TO_HOST, // Host内的内存复制 + ACL_MEMCPY_HOST_TO_DEVICE, // Host到Device的内存复制 + ACL_MEMCPY_DEVICE_TO_HOST, // Device到Host的内存复制 + ACL_MEMCPY_DEVICE_TO_DEVICE, // Device内或Device间的内存复制 + ACL_MEMCPY_DEFAULT, // 由系统根据源、目的内存地址自行判断拷贝方向 +} aclrtMemcpyKind; + +typedef struct aclmdlIODims { + char name[ACL_MAX_TENSOR_NAME_LEN]; /**< tensor name */ + size_t dimCount; /**Shape中的维度个数,如果为标量,则维度个数为0*/ + int64_t dims[ACL_MAX_DIM_CNT]; /**< 维度信息 */ +} aclmdlIODims; + +typedef struct acldumpChunk { + char fileName[Utility::SafePath::MAX_PATH_LENGTH]; // 待落盘的Dump数据文件名 + uint32_t bufLen; // dataBuf数据长度,单位Byte + uint32_t isLastChunk; // 标识Dump数据是否为最后一个分片,0表示不是最后一个分片,1表示最后一个分片 + int64_t offset; // Dump数据文件内容的偏移,其中-1表示文件追加内容 + int32_t flag; // 预留Dump数据标识,当前数据无标识 + uint8_t dataBuf[0]; // Dump数据的内存地址 +} acldumpChunk; +} + +using AclDumpCallbackFuncType = int32_t (*)(const acldumpChunk *, int32_t); + +namespace Ascendcl { + typedef struct LoadFileResult { + uint32_t modelId; + aclError ret; + } LoadFileResult; + + class AclApi { + public: + static AclApi &GetInstance(); + AclApi(); + aclError ACLAPI_AclInit(const char *cfg = nullptr); + aclError ACLAPI_AclRtSetDevice(int32_t deviceId); + aclError ACLAPI_AclRtCreateContext(aclrtContext *context, int32_t deviceId); + LoadFileResult ACLAPI_AclMdlLoadFromFile(const char *modelPath); + aclmdlDesc *ACLAPI_AclMdlCreateDesc(); + aclError ACLAPI_AclMdlGetDesc(aclmdlDesc *modelDesc, uint32_t modelId); + aclDataBuffer *ACLAPI_AclCreateDataBuffer(void *data, size_t size); + size_t ACLAPI_AclMdlGetInputSizeByIndex(aclmdlDesc *modelDesc, size_t index); + size_t ACLAPI_AclMdlGetOutputSizeByIndex(aclmdlDesc *modelDesc, size_t index); + size_t ACLAPI_AclMdlGetNumInputs(aclmdlDesc *modelDesc); + size_t ACLAPI_AclMdlGetNumOutputs(aclmdlDesc *modelDesc); + const char *ACLAPI_AclMdlGetInputNameByIndex(const aclmdlDesc *modelDesc, size_t index); + aclDataType ACLAPI_AclMdlGetInputDataType(const aclmdlDesc *modelDesc, size_t index); + aclError ACLAPI_AclMdlGetInputDims(const aclmdlDesc *modelDesc, size_t index, aclmdlIODims *dims); + aclError + ACLAPI_AclRtMalloc(void **devPtr, size_t size, aclrtMemMallocPolicy policy = ACL_MEM_MALLOC_HUGE_FIRST); + aclmdlDataset *ACLAPI_AclMdlCreateDataset(); + aclError ACLAPI_AclMdlAddDatasetBuffer(aclmdlDataset *dataset, aclDataBuffer *databuffer); + aclError ACLAPI_AclMdlExecute(uint32_t modelId, const aclmdlDataset *input, aclmdlDataset *output); + aclError ACLAPI_AclRtMemcpy(void *dst, size_t destMax, const void *src, size_t count, aclrtMemcpyKind kind); + size_t ACLAPI_AclMdlGetDatasetNumBuffers(const aclmdlDataset *dataset); + aclDataBuffer *ACLAPI_AclMdlGetDatasetBuffer(const aclmdlDataset *dataset, size_t index); + aclError ACLAPI_AclDestroyDataBuffer(const aclDataBuffer *dataBuffer); + aclError ACLAPI_AclMdlDestroyDataset(const aclmdlDataset *dataset); + aclError ACLAPI_AclRtFree(void *devPtr); + aclError ACLAPI_AclFinalize(); + aclError ACLAPI_AclMdlUnload(uint32_t modelId); + aclError ACLAPI_AclMdlDestroyDesc(aclmdlDesc *modelDesc); + aclError ACLAPI_AclRtDestroyContext(aclrtContext context); + aclError ACLAPI_AclRtResetDevice(int32_t deviceId); + aclError ACLAPI_AclInitDump(); + aclError ACLAPI_AclSetDump(const char *dumpCfgPath); + aclError ACLAPI_AclFinalizeDump(); + aclError ACLAPI_AclDumpRegCallBack(AclDumpCallbackFuncType messageCallback, int32_t flag); + void ACLAPI_AclDumpUnregCallBack(); + + private: + void LoadAclApi(); + }; + +#define CALL_ACL_API(func, ...) Ascendcl::AclApi::GetInstance().ACLAPI_##func(__VA_ARGS__) + +} // namespace Ascendcl + +#endif -- Gitee