From 7b0d5c13dbd9557f560fcd1667f2c8e47d366711 Mon Sep 17 00:00:00 2001 From: wangchuanxia Date: Mon, 5 Feb 2024 23:08:15 +0800 Subject: [PATCH] fix backend id bug Signed-off-by: wangchuanxia --- .../neural_network_core/backend_manager.cpp | 49 +++++++++---------- .../neural_network_core/backend_manager.h | 7 +-- 2 files changed, 26 insertions(+), 30 deletions(-) diff --git a/frameworks/native/neural_network_core/backend_manager.cpp b/frameworks/native/neural_network_core/backend_manager.cpp index c392795..078d066 100644 --- a/frameworks/native/neural_network_core/backend_manager.cpp +++ b/frameworks/native/neural_network_core/backend_manager.cpp @@ -23,21 +23,13 @@ namespace NeuralNetworkRuntime { BackendManager::~BackendManager() { m_backends.clear(); + m_backendNames.clear(); m_backendIDs.clear(); } -std::vector BackendManager::GetAllBackendsID() +const std::vector& BackendManager::GetAllBackendsID() { - std::vector tmpBackendIds; - std::shared_ptr backend {nullptr}; - for (auto iter = m_backends.begin(); iter != m_backends.end(); ++iter) { - backend = iter->second; - if (!IsValidBackend(backend)) { - continue; - } - tmpBackendIds.emplace_back(iter->first); - } - return tmpBackendIds; + return m_backendIDs; } std::shared_ptr BackendManager::GetBackend(size_t backendID) const @@ -62,32 +54,27 @@ std::shared_ptr BackendManager::GetBackend(size_t backendID) const return iter->second; } -std::string BackendManager::GetBackendName(size_t backendID) +const std::string& BackendManager::GetBackendName(size_t backendID) { - std::string tmpBackendName; - if (m_backends.empty()) { + std::string emptyName; + if (m_backendNames.empty()) { LOGE("[BackendManager] GetBackendName failed, there is no registered backend can be used."); - return tmpBackendName; + return emptyName; } - auto iter = m_backends.begin(); + auto iter = m_backendNames.begin(); if (backendID == static_cast(0)) { LOGI("[BackendManager] the backendID is 0, default return 1st backend."); } else { - iter = m_backends.find(backendID); + iter = m_backendNames.find(backendID); } - if (iter == m_backends.end()) { + if (iter == m_backendNames.end()) { LOGE("[BackendManager] GetBackendName failed, backendID %{public}zu is not registered.", backendID); - return tmpBackendName; + return emptyName; } - OH_NN_ReturnCode ret = iter->second->GetBackendName(tmpBackendName); - if (ret != OH_NN_SUCCESS) { - LOGE("[BackendManager] GetBackendName failed, fail to get backendName from backend."); - } - - return tmpBackendName; + return iter->second; } OH_NN_ReturnCode BackendManager::RegisterBackend(std::function()> creator) @@ -106,14 +93,22 @@ OH_NN_ReturnCode BackendManager::RegisterBackend(std::functionGetBackendID(); const std::lock_guard lock(m_mtx); - auto setResult = m_backendIDs.emplace(backendID); - if (!setResult.second) { + auto iter = std::find(m_backendIDs.begin(), m_backendIDs.end(), backendID); + if (iter != m_backendIDs.end()) { LOGE("[BackendManager] RegisterBackend failed, backend already exists, cannot register again. " "backendID=%{public}zu", backendID); return OH_NN_FAILED; } + std::string tmpBackendName; + auto ret = regBackend->GetBackendName(tmpBackendName); + if (ret != OH_NN_SUCCESS) { + LOGE("[BackendManager] RegisterBackend failed, fail to get backend name."); + return OH_NN_FAILED; + } m_backends.emplace(backendID, regBackend); + m_backendIDs.emplace_back(backendID); + m_backendNames.emplace(backendID, tmpBackendName); return OH_NN_SUCCESS; } diff --git a/frameworks/native/neural_network_core/backend_manager.h b/frameworks/native/neural_network_core/backend_manager.h index 936f254..656d47c 100644 --- a/frameworks/native/neural_network_core/backend_manager.h +++ b/frameworks/native/neural_network_core/backend_manager.h @@ -32,9 +32,9 @@ namespace OHOS { namespace NeuralNetworkRuntime { class BackendManager { public: - std::vector GetAllBackendsID(); + const std::vector& GetAllBackendsID(); std::shared_ptr GetBackend(size_t backendID) const; - std::string GetBackendName(size_t backendID); + const std::string& GetBackendName(size_t backendID); // Register backend by C++ API OH_NN_ReturnCode RegisterBackend(std::function()> creator); @@ -60,7 +60,8 @@ private: bool IsValidBackend(std::shared_ptr backend) const; private: - std::unordered_set m_backendIDs; + std::vector m_backendIDs; + std::unordered_map m_backendNames; // key is the name of backend. std::unordered_map> m_backends; std::mutex m_mtx; -- Gitee