From 1c7b5f8c976736c00618bc56999cd2f8ebb09240 Mon Sep 17 00:00:00 2001 From: w30052974 Date: Mon, 12 Aug 2024 17:52:30 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9model=20id?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: w30052974 --- .../neural_network_core.cpp | 61 +++++++++++++------ 1 file changed, 41 insertions(+), 20 deletions(-) diff --git a/frameworks/native/neural_network_core/neural_network_core.cpp b/frameworks/native/neural_network_core/neural_network_core.cpp index 9674e13..c35a71d 100644 --- a/frameworks/native/neural_network_core/neural_network_core.cpp +++ b/frameworks/native/neural_network_core/neural_network_core.cpp @@ -618,6 +618,44 @@ OH_NN_ReturnCode Authentication(Compilation** compilation) return OH_NN_SUCCESS; } +namespace { +OH_NN_ReturnCode GetNnrtModelId(Compilation* compilationImpl, NNRtServiceApi& nnrtService) +{ + std::string modelName; + OH_NN_ReturnCode retCode = compilationImpl->compiler->GetModelName(modelName); + if (retCode != OH_NN_SUCCESS) { + LOGE("GetModelName is failed."); + return retCode; + } + + if (compilationImpl->nnModel != nullptr) { + compilationImpl->nnrtModelID = nnrtService.GetNNRtModelIDFromCache(compilationImpl->cachePath, + modelName.c_str()); + if (compilationImpl->nnrtModelID == 0) { + compilationImpl->nnrtModelID = nnrtService.GetNNRtModelIDFromModel(compilationImpl->nnModel); + } + } else if (compilationImpl->offlineModelPath != nullptr) { + compilationImpl->nnrtModelID = nnrtService.GetNNRtModelIDFromPath(compilationImpl->offlineModelPath); + } else if (compilationImpl->cachePath != nullptr) { + compilationImpl->nnrtModelID = + nnrtService.GetNNRtModelIDFromCache(compilationImpl->cachePath, modelName.c_str()); + } else if ((compilationImpl->offlineModelBuffer.first != nullptr) && \ + (compilationImpl->offlineModelBuffer.second != size_t(0))) { + compilationImpl->nnrtModelID = nnrtService.GetNNRtModelIDFromBuffer( + compilationImpl->offlineModelBuffer.first, compilationImpl->offlineModelBuffer.second); + } else if ((compilationImpl->cacheBuffer.first != nullptr) && \ + (compilationImpl->cacheBuffer.second != size_t(0))) { + compilationImpl->nnrtModelID = nnrtService.GetNNRtModelIDFromBuffer( + compilationImpl->cacheBuffer.first, compilationImpl->cacheBuffer.second); + } else { + LOGE("GetModelId failed, no available model to set modelId, please check."); + return OH_NN_INVALID_PARAMETER; + } + + return OH_NN_SUCCESS; +} +} + OH_NN_ReturnCode GetModelId(Compilation** compilation) { if (compilation == nullptr) { @@ -652,26 +690,9 @@ OH_NN_ReturnCode GetModelId(Compilation** compilation) return OH_NN_INVALID_PARAMETER; } - if (compilationImpl->nnModel != nullptr) { - compilationImpl->nnrtModelID = nnrtService.GetNNRtModelIDFromModel(compilationImpl->nnModel); - } else if (compilationImpl->offlineModelPath != nullptr) { - compilationImpl->nnrtModelID = nnrtService.GetNNRtModelIDFromPath(compilationImpl->offlineModelPath); - } else if (compilationImpl->cachePath != nullptr) { - std::string modelName; - compilationImpl->compiler->GetModelName(modelName); - compilationImpl->nnrtModelID = - nnrtService.GetNNRtModelIDFromCache(compilationImpl->cachePath, modelName.c_str()); - } else if ((compilationImpl->offlineModelBuffer.first != nullptr) && \ - (compilationImpl->offlineModelBuffer.second != size_t(0))) { - compilationImpl->nnrtModelID = nnrtService.GetNNRtModelIDFromBuffer( - compilationImpl->offlineModelBuffer.first, compilationImpl->offlineModelBuffer.second); - } else if ((compilationImpl->cacheBuffer.first != nullptr) && \ - (compilationImpl->cacheBuffer.second != size_t(0))) { - compilationImpl->nnrtModelID = nnrtService.GetNNRtModelIDFromBuffer( - compilationImpl->cacheBuffer.first, compilationImpl->cacheBuffer.second); - } else { - LOGE("GetModelId failed, no available model to set modelId, please check."); - return OH_NN_INVALID_PARAMETER; + auto ret = GetNnrtModelId(compilationImpl, nnrtService); + if (ret != OH_NN_SUCCESS) { + LOGE("GetNnrtModelId failed."); } return OH_NN_SUCCESS; -- Gitee