diff --git a/frameworks/native/neural_network_runtime/neural_network_runtime.cpp b/frameworks/native/neural_network_runtime/neural_network_runtime.cpp index 2cb6e1dbf92c1cac699b52bf70844558798db973..39f79d236c32e024918f511e310bded47169dfa2 100644 --- a/frameworks/native/neural_network_runtime/neural_network_runtime.cpp +++ b/frameworks/native/neural_network_runtime/neural_network_runtime.cpp @@ -46,7 +46,8 @@ const std::string EXTENSION_KEY_FM_SHARED = "NPU_FM_SHARED"; const std::string EXTENSION_KEY_IS_EXCEED_RAMLIMIT = "isExceedRamLimit"; const std::string NULL_HARDWARE_NAME = "default"; -const std::string HARDWARE_NAME = "const.ai.nnrt_deivce"; +const std::string NNRT_DEVICE_NAME = "const.ai.nnrt_deivce"; +const std::string HARDWARE_NAME = "ohos.boot.hardware"; const std::string HARDWARE_VERSION = "v5_0"; constexpr size_t HARDWARE_NAME_MAX_LENGTH = 128; constexpr size_t FILE_NUMBER_MAX = 100; // 限制cache文件数量最大为100 @@ -565,7 +566,7 @@ NNRT_API OH_NN_ReturnCode OH_NNModel_BuildFromLiteGraph(OH_NNModel *model, const } namespace { -OH_NN_ReturnCode CheckCacheFileExtension(const std::string& content, int64_t& fileNumber, int64_t& cacheVersion) +OH_NN_ReturnCode CheckCacheFileExtension(const std::string& content, int64_t& fileNumber, int64_t& cacheVersion, int64_t& deviceId) { if (!nlohmann::json::accept(content)) { LOGE("OH_NNModel_HasCache CheckCacheFile JSON parse error"); @@ -578,6 +579,12 @@ OH_NN_ReturnCode CheckCacheFileExtension(const std::string& content, int64_t& fi return OH_NN_INVALID_FILE; } + if (j["data"].find("deviceId") == j["data"].end()) { + LOGE("OH_NNModel_HasCache read deviceId from cache info file failed."); + return OH_NN_INVALID_FILE; + } + deviceId = j["data"]["deviceId"].get(); + if (j["data"].find("fileNumber") == j["data"].end()) { LOGE("OH_NNModel_HasCache read fileNumber from cache info file failed."); return OH_NN_INVALID_FILE; @@ -610,7 +617,7 @@ OH_NN_ReturnCode CheckCacheFileExtension(const std::string& content, int64_t& fi return OH_NN_SUCCESS; } -OH_NN_ReturnCode CheckCacheFile(const std::string& cacheInfoPath, int64_t& fileNumber, int64_t& cacheVersion) +OH_NN_ReturnCode CheckCacheFile(const std::string& cacheInfoPath, int64_t& fileNumber, int64_t& cacheVersion, int64_t& deviceId) { char path[PATH_MAX]; if (realpath(cacheInfoPath.c_str(), path) == nullptr) { @@ -633,7 +640,26 @@ OH_NN_ReturnCode CheckCacheFile(const std::string& cacheInfoPath, int64_t& fileN std::string content((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); ifs.close(); - return CheckCacheFileExtension(content, fileNumber, cacheVersion); + return CheckCacheFileExtension(content, fileNumber, cacheVersion, deviceId); +} + +OH_NN_ReturnCode CheckDeviceId(int64_t& deviceId) +{ + std::string deviceName; + char cName[HARDWARE_NAME_MAX_LENGTH]; + int ret = GetParameter(HARDWARE_NAME.c_str(), NULL_HARDWARE_NAME.c_str(), cName, HARDWARE_NAME_MAX_LENGTH); + if (ret <= 0) { + LOGE("OH_NNModel_HasCache failed to get parameter."); + return OH_NN_FAILED; + } + + deviceName = HARDWARE_NAME + "." + cName; + if (deviceId != std::hash{}(deviceName)) { + LOGE("OH_NNModel_HasCache the deviceId in the cache files is different from current deviceId."); + return OH_NN_FAILED; + } + + return OH_NN_SUCCESS; } } @@ -657,15 +683,23 @@ NNRT_API bool OH_NNModel_HasCache(const char *cacheDir, const char *modelName, u return false; } + int64_t deviceId{0}; int64_t fileNumber{0}; int64_t cacheVersion{0}; - OH_NN_ReturnCode returnCode = CheckCacheFile(cacheInfoPath, fileNumber, cacheVersion); + OH_NN_ReturnCode returnCode = CheckCacheFile(cacheInfoPath, fileNumber, cacheVersion, deviceId); if (returnCode != OH_NN_SUCCESS) { LOGE("OH_NNModel_HasCache get fileNumber or cacheVersion fail."); std::filesystem::remove_all(cacheInfoPath); return false; } + returnCode = CheckDeviceId(deviceId); + if (returnCode != OH_NN_SUCCESS) { + LOGE("OH_NNModel_HasCache check deviceId fail."); + std::filesystem::remove_all(cacheInfoPath); + return false; + } + if (fileNumber <= 0 || fileNumber > FILE_NUMBER_MAX) { LOGE("OH_NNModel_HasCache fileNumber is invalid or more than 100"); std::filesystem::remove_all(cacheInfoPath); @@ -804,7 +838,7 @@ NNRT_API OH_NN_ReturnCode OH_NN_GetDeviceID(char *nnrtDevice, size_t len) } char cName[HARDWARE_NAME_MAX_LENGTH] = {0}; - int ret = GetParameter(HARDWARE_NAME.c_str(), NULL_HARDWARE_NAME.c_str(), cName, HARDWARE_NAME_MAX_LENGTH); + int ret = GetParameter(NNRT_DEVICE_NAME.c_str(), NULL_HARDWARE_NAME.c_str(), cName, HARDWARE_NAME_MAX_LENGTH); // 如果成功获取返回值为硬件名称的字节数 if (ret <= 0) { LOGE("GetNNRtDeviceName failed, failed to get parameter.");