diff --git a/frameworks/native/neural_network_core/neural_network_core.cpp b/frameworks/native/neural_network_core/neural_network_core.cpp index 9674e13b01b9341b3cc15f46f92a2b9b7d2e1f18..c35a71d15887193cbcc9b39a1dc01ac30948521d 100644 --- a/frameworks/native/neural_network_core/neural_network_core.cpp +++ b/frameworks/native/neural_network_core/neural_network_core.cpp @@ -618,6 +618,44 @@ OH_NN_ReturnCode Authentication(Compilation** compilation) return OH_NN_SUCCESS; } +namespace { +OH_NN_ReturnCode GetNnrtModelId(Compilation* compilationImpl, NNRtServiceApi& nnrtService) +{ + std::string modelName; + OH_NN_ReturnCode retCode = compilationImpl->compiler->GetModelName(modelName); + if (retCode != OH_NN_SUCCESS) { + LOGE("GetModelName is failed."); + return retCode; + } + + if (compilationImpl->nnModel != nullptr) { + compilationImpl->nnrtModelID = nnrtService.GetNNRtModelIDFromCache(compilationImpl->cachePath, + modelName.c_str()); + if (compilationImpl->nnrtModelID == 0) { + compilationImpl->nnrtModelID = nnrtService.GetNNRtModelIDFromModel(compilationImpl->nnModel); + } + } else if (compilationImpl->offlineModelPath != nullptr) { + compilationImpl->nnrtModelID = nnrtService.GetNNRtModelIDFromPath(compilationImpl->offlineModelPath); + } else if (compilationImpl->cachePath != nullptr) { + compilationImpl->nnrtModelID = + nnrtService.GetNNRtModelIDFromCache(compilationImpl->cachePath, modelName.c_str()); + } else if ((compilationImpl->offlineModelBuffer.first != nullptr) && \ + (compilationImpl->offlineModelBuffer.second != size_t(0))) { + compilationImpl->nnrtModelID = nnrtService.GetNNRtModelIDFromBuffer( + compilationImpl->offlineModelBuffer.first, compilationImpl->offlineModelBuffer.second); + } else if ((compilationImpl->cacheBuffer.first != nullptr) && \ + (compilationImpl->cacheBuffer.second != size_t(0))) { + compilationImpl->nnrtModelID = nnrtService.GetNNRtModelIDFromBuffer( + compilationImpl->cacheBuffer.first, compilationImpl->cacheBuffer.second); + } else { + LOGE("GetModelId failed, no available model to set modelId, please check."); + return OH_NN_INVALID_PARAMETER; + } + + return OH_NN_SUCCESS; +} +} + OH_NN_ReturnCode GetModelId(Compilation** compilation) { if (compilation == nullptr) { @@ -652,26 +690,9 @@ OH_NN_ReturnCode GetModelId(Compilation** compilation) return OH_NN_INVALID_PARAMETER; } - if (compilationImpl->nnModel != nullptr) { - compilationImpl->nnrtModelID = nnrtService.GetNNRtModelIDFromModel(compilationImpl->nnModel); - } else if (compilationImpl->offlineModelPath != nullptr) { - compilationImpl->nnrtModelID = nnrtService.GetNNRtModelIDFromPath(compilationImpl->offlineModelPath); - } else if (compilationImpl->cachePath != nullptr) { - std::string modelName; - compilationImpl->compiler->GetModelName(modelName); - compilationImpl->nnrtModelID = - nnrtService.GetNNRtModelIDFromCache(compilationImpl->cachePath, modelName.c_str()); - } else if ((compilationImpl->offlineModelBuffer.first != nullptr) && \ - (compilationImpl->offlineModelBuffer.second != size_t(0))) { - compilationImpl->nnrtModelID = nnrtService.GetNNRtModelIDFromBuffer( - compilationImpl->offlineModelBuffer.first, compilationImpl->offlineModelBuffer.second); - } else if ((compilationImpl->cacheBuffer.first != nullptr) && \ - (compilationImpl->cacheBuffer.second != size_t(0))) { - compilationImpl->nnrtModelID = nnrtService.GetNNRtModelIDFromBuffer( - compilationImpl->cacheBuffer.first, compilationImpl->cacheBuffer.second); - } else { - LOGE("GetModelId failed, no available model to set modelId, please check."); - return OH_NN_INVALID_PARAMETER; + auto ret = GetNnrtModelId(compilationImpl, nnrtService); + if (ret != OH_NN_SUCCESS) { + LOGE("GetNnrtModelId failed."); } return OH_NN_SUCCESS;