diff --git a/frameworks/native/neural_network_core/backend_manager.cpp b/frameworks/native/neural_network_core/backend_manager.cpp index c392795f5a87f44a436cadf11c6ee33811b92cad..078d0665729f17a58e75944a0d1d4db7c124810c 100644 --- a/frameworks/native/neural_network_core/backend_manager.cpp +++ b/frameworks/native/neural_network_core/backend_manager.cpp @@ -23,21 +23,13 @@ namespace NeuralNetworkRuntime { BackendManager::~BackendManager() { m_backends.clear(); + m_backendNames.clear(); m_backendIDs.clear(); } -std::vector BackendManager::GetAllBackendsID() +const std::vector& BackendManager::GetAllBackendsID() { - std::vector tmpBackendIds; - std::shared_ptr backend {nullptr}; - for (auto iter = m_backends.begin(); iter != m_backends.end(); ++iter) { - backend = iter->second; - if (!IsValidBackend(backend)) { - continue; - } - tmpBackendIds.emplace_back(iter->first); - } - return tmpBackendIds; + return m_backendIDs; } std::shared_ptr BackendManager::GetBackend(size_t backendID) const @@ -62,32 +54,27 @@ std::shared_ptr BackendManager::GetBackend(size_t backendID) const return iter->second; } -std::string BackendManager::GetBackendName(size_t backendID) +const std::string& BackendManager::GetBackendName(size_t backendID) { - std::string tmpBackendName; - if (m_backends.empty()) { + std::string emptyName; + if (m_backendNames.empty()) { LOGE("[BackendManager] GetBackendName failed, there is no registered backend can be used."); - return tmpBackendName; + return emptyName; } - auto iter = m_backends.begin(); + auto iter = m_backendNames.begin(); if (backendID == static_cast(0)) { LOGI("[BackendManager] the backendID is 0, default return 1st backend."); } else { - iter = m_backends.find(backendID); + iter = m_backendNames.find(backendID); } - if (iter == m_backends.end()) { + if (iter == m_backendNames.end()) { LOGE("[BackendManager] GetBackendName failed, backendID %{public}zu is not registered.", backendID); - return tmpBackendName; + return emptyName; } - OH_NN_ReturnCode ret = iter->second->GetBackendName(tmpBackendName); - if (ret != OH_NN_SUCCESS) { - LOGE("[BackendManager] GetBackendName failed, fail to get backendName from backend."); - } - - return tmpBackendName; + return iter->second; } OH_NN_ReturnCode BackendManager::RegisterBackend(std::function()> creator) @@ -106,14 +93,22 @@ OH_NN_ReturnCode BackendManager::RegisterBackend(std::functionGetBackendID(); const std::lock_guard lock(m_mtx); - auto setResult = m_backendIDs.emplace(backendID); - if (!setResult.second) { + auto iter = std::find(m_backendIDs.begin(), m_backendIDs.end(), backendID); + if (iter != m_backendIDs.end()) { LOGE("[BackendManager] RegisterBackend failed, backend already exists, cannot register again. " "backendID=%{public}zu", backendID); return OH_NN_FAILED; } + std::string tmpBackendName; + auto ret = regBackend->GetBackendName(tmpBackendName); + if (ret != OH_NN_SUCCESS) { + LOGE("[BackendManager] RegisterBackend failed, fail to get backend name."); + return OH_NN_FAILED; + } m_backends.emplace(backendID, regBackend); + m_backendIDs.emplace_back(backendID); + m_backendNames.emplace(backendID, tmpBackendName); return OH_NN_SUCCESS; } diff --git a/frameworks/native/neural_network_core/backend_manager.h b/frameworks/native/neural_network_core/backend_manager.h index 936f254ff4f429010ac7db7bc273e639c54d2cc4..656d47cad4841b026bda66a14e2e8d39286e3e0f 100644 --- a/frameworks/native/neural_network_core/backend_manager.h +++ b/frameworks/native/neural_network_core/backend_manager.h @@ -32,9 +32,9 @@ namespace OHOS { namespace NeuralNetworkRuntime { class BackendManager { public: - std::vector GetAllBackendsID(); + const std::vector& GetAllBackendsID(); std::shared_ptr GetBackend(size_t backendID) const; - std::string GetBackendName(size_t backendID); + const std::string& GetBackendName(size_t backendID); // Register backend by C++ API OH_NN_ReturnCode RegisterBackend(std::function()> creator); @@ -60,7 +60,8 @@ private: bool IsValidBackend(std::shared_ptr backend) const; private: - std::unordered_set m_backendIDs; + std::vector m_backendIDs; + std::unordered_map m_backendNames; // key is the name of backend. std::unordered_map> m_backends; std::mutex m_mtx;