diff --git a/services/distributeddataservice/app/src/session_manager/route_head_handler_impl.cpp b/services/distributeddataservice/app/src/session_manager/route_head_handler_impl.cpp index 40d0403a46da7c34c55a1d748319d066ce14be69..bcce7f808625bd1fcd5cd7eebc24680adedb523f 100644 --- a/services/distributeddataservice/app/src/session_manager/route_head_handler_impl.cpp +++ b/services/distributeddataservice/app/src/session_manager/route_head_handler_impl.cpp @@ -76,7 +76,8 @@ void RouteHeadHandlerImpl::Init() } } SessionPoint localPoint { DmAdapter::GetInstance().GetLocalDevice().uuid, - static_cast(atoi(userId_.c_str())), appId_, storeId_ }; + static_cast(atoi(userId_.c_str())), appId_, storeId_, + AccountDelegate::GetInstance()->GetCurrentAccountId() }; session_ = SessionManager::GetInstance().GetSession(localPoint, deviceId_); ZLOGD("valid session:appId:%{public}s, srcDevId:%{public}s, srcUser:%{public}u, trgDevId:%{public}s,", session_.appId.c_str(), Anonymous::Change(session_.sourceDeviceId).c_str(), @@ -112,9 +113,9 @@ DistributedDB::DBStatus RouteHeadHandlerImpl::GetHeadDataSize(uint32_t &headSize ZLOGI("no valid session to peer device"); return DistributedDB::DB_ERROR; } - size_t expectSize = sizeof(RouteHead) + sizeof(SessionDevicePair) + sizeof(SessionUserPair) - + session_.targetUserIds.size() * sizeof(int) + sizeof(SessionAppId) + session_.appId.size() - + sizeof(SessionStoreId) + session_.storeId.size(); + size_t expectSize = sizeof(RouteHead) + sizeof(SessionDevicePair) + sizeof(SessionUserPair) + + session_.targetUserIds.size() * sizeof(int) + sizeof(SessionAppId) + session_.appId.size() + + sizeof(SessionStoreId) + session_.storeId.size() + sizeof(SessionAccountId) + session_.accountId.size(); // align message uint width headSize = GET_ALIGNED_SIZE(expectSize, ALIGN_WIDTH); @@ -194,28 +195,57 @@ bool RouteHeadHandlerImpl::PackDataBody(uint8_t *data, uint32_t totalLen) } ptr += (sizeof(SessionUserPair) + session_.targetUserIds.size() * sizeof(int)); - SessionAppId *appPair = reinterpret_cast(ptr); + const uint8_t *end = data + totalLen; + if (!PackAppId(&ptr, end) || !PackStoreId(&ptr, end) || !PackAccountId(&ptr, end)) { + return false; + } + return true; +} + +bool RouteHeadHandlerImpl::PackAppId(uint8_t **data, const uint8_t *end) +{ + SessionAppId *appPair = reinterpret_cast(*data); uint32_t appIdSize = session_.appId.size(); appPair->len = HostToNet(appIdSize); - uint8_t *end = data + totalLen; - ptr += sizeof(SessionAppId); - ret = memcpy_s(appPair->appId, end - ptr, session_.appId.c_str(), appIdSize); - if (ret != 0) { + *data += sizeof(SessionAppId); + auto ret = memcpy_s(appPair->appId, end - *data, session_.appId.c_str(), appIdSize); + if (ret != EOK) { ZLOGE("memcpy for app id failed, ret is %{public}d, leftSize is %{public}u, appIdSize is %{public}u", - ret, static_cast(end - ptr), appIdSize); + ret, static_cast(end - *data), appIdSize); return false; } - ptr += appIdSize; + *data += appIdSize; + return true; +} - SessionStoreId *storePair = reinterpret_cast(ptr); +bool RouteHeadHandlerImpl::PackStoreId(uint8_t **data, const uint8_t *end) +{ + SessionStoreId *storePair = reinterpret_cast(*data); uint32_t storeIdSize = session_.storeId.size(); - ret = memcpy_s(storePair->storeId, end - ptr, session_.storeId.data(), storeIdSize); - if (ret != 0) { + storePair->len = HostToNet(storeIdSize); + *data += sizeof(SessionStoreId); + auto ret = memcpy_s(storePair->storeId, end - *data, session_.storeId.data(), storeIdSize); + if (ret != EOK) { ZLOGE("memcpy for store id failed, ret is %{public}d, leftSize is %{public}u, storeIdSize is %{public}u", - ret, static_cast(end - ptr), storeIdSize); + ret, static_cast(end - *data), storeIdSize); + return false; + } + *data += storeIdSize; + return true; +} + +bool RouteHeadHandlerImpl::PackAccountId(uint8_t **data, const uint8_t *end) +{ + SessionAccountId *accountPair = reinterpret_cast(*data); + uint32_t accountIdSize = session_.accountId.size(); + accountPair->len = HostToNet(accountIdSize); + *data += sizeof(SessionAccountId); + auto ret = memcpy_s(accountPair->accountId, end - *data, session_.accountId.data(), accountIdSize); + if (ret != EOK) { + ZLOGE("memcpy for account id failed, ret is %{public}d, leftSize is %{public}u, storeIdSize is %{public}u", + ret, static_cast(end - *data), accountIdSize); return false; } - storePair->len = HostToNet(storeIdSize); return true; } @@ -283,13 +313,21 @@ bool RouteHeadHandlerImpl::ParseHeadDataUser(const uint8_t *data, uint32_t total // flip the local and peer ends SessionPoint local { .deviceId = session_.targetDeviceId, .appId = session_.appId }; - SessionPoint peer { .deviceId = session_.sourceDeviceId, .userId = session_.sourceUserId, .appId = session_.appId }; + SessionPoint peer { .deviceId = session_.sourceDeviceId, .userId = session_.sourceUserId, .appId = session_.appId, + .accountId = session_.accountId }; ZLOGD("valid session:appId:%{public}s, srcDevId:%{public}s, srcUser:%{public}u, trgDevId:%{public}s,", session_.appId.c_str(), Anonymous::Change(session_.sourceDeviceId).c_str(), session_.sourceUserId, Anonymous::Change(session_.targetDeviceId).c_str()); + bool flag = false; + auto peerCap = UpgradeManager::GetInstance().GetCapability(session_.sourceDeviceId, flag); + if (!flag) { + ZLOGI("get peer cap failed, peer deviceId:%{public}s", Anonymous::Change(session_.sourceDeviceId).c_str()); + return false; + } + bool accountFlag = peerCap.version >= CapMetaData::ACCOUNT_VERSION; for (const auto &item : session_.targetUserIds) { local.userId = item; - if (SessionManager::GetInstance().CheckSession(local, peer)) { + if (SessionManager::GetInstance().CheckSession(local, peer, accountFlag)) { UserInfo userInfo = { .receiveUser = std::to_string(item) }; userInfos.emplace_back(userInfo); } @@ -340,7 +378,7 @@ bool RouteHeadHandlerImpl::UnPackDataHead(const uint8_t *data, uint32_t totalLen bool RouteHeadHandlerImpl::UnPackDataBody(const uint8_t *data, uint32_t totalLen) { - const uint8_t *ptr = data; + uint8_t *ptr = const_cast(data); uint32_t leftSize = totalLen; if (leftSize < sizeof(SessionDevicePair)) { @@ -373,39 +411,61 @@ bool RouteHeadHandlerImpl::UnPackDataBody(const uint8_t *data, uint32_t totalLen ptr += userPairSize; leftSize -= userPairSize; + if (!UnPackAppId(&ptr, leftSize) || !UnPackStoreId(&ptr, leftSize) || !UnPackAccountId(&ptr, leftSize)) { + return false; + } + return true; +} + +bool RouteHeadHandlerImpl::UnPackAppId(uint8_t **data, uint32_t leftSize) +{ if (leftSize < sizeof(SessionAppId)) { - ZLOGE("failed to parse app id, leftSize : %{public}d", leftSize); + ZLOGE("failed to parse app id, leftSize:%{public}d.", leftSize); return false; } - const SessionAppId *appId = reinterpret_cast(ptr); + const SessionAppId *appId = reinterpret_cast(*data); auto appIdLen = NetToHost(appId->len); if (leftSize - sizeof(SessionAppId) < appIdLen) { ZLOGE("failed to parse app id, appIdLen:%{public}d, leftSize:%{public}d.", appIdLen, leftSize); return false; } - session_.appId.append(appId->appId, appIdLen); + session_.appId = std::string(appId->appId, appIdLen); leftSize -= (sizeof(SessionAppId) + appIdLen); - if (leftSize > 0) { - ptr += (sizeof(SessionAppId) + appIdLen); - return UnPackStoreId(ptr, leftSize); - } + *data += (sizeof(SessionAppId) + appIdLen); return true; } -bool RouteHeadHandlerImpl::UnPackStoreId(const uint8_t *data, uint32_t leftSize) +bool RouteHeadHandlerImpl::UnPackStoreId(uint8_t **data, uint32_t leftSize) { if (leftSize < sizeof(SessionStoreId)) { ZLOGE("failed to parse store id, leftSize:%{public}d.", leftSize); return false; } - const uint8_t *ptr = data; - const SessionStoreId *storeId = reinterpret_cast(ptr); + const SessionStoreId *storeId = reinterpret_cast(*data); auto storeIdLen = NetToHost(storeId->len); if (leftSize - sizeof(SessionStoreId) < storeIdLen) { ZLOGE("failed to parse store id, storeIdLen:%{public}d, leftSize:%{public}d.", storeIdLen, leftSize); return false; } session_.storeId = std::string(storeId->storeId, storeIdLen); + leftSize -= (sizeof(SessionAppId) + storeIdLen); + *data += (sizeof(SessionAppId) + storeIdLen); + return true; +} + +bool RouteHeadHandlerImpl::UnPackAccountId(uint8_t **data, uint32_t leftSize) +{ + if (leftSize < sizeof(SessionAccountId)) { + ZLOGE("failed to parse account id, leftSize:%{public}d.", leftSize); + return false; + } + const SessionAccountId *accountId = reinterpret_cast(*data); + auto accountIdLen = NetToHost(accountId->len); + if (leftSize - sizeof(SessionAccountId) < accountIdLen) { + ZLOGE("failed to parse account id, accountIdLen:%{public}d, leftSize:%{public}d.", accountIdLen, leftSize); + return false; + } + session_.accountId = std::string(accountId->accountId, accountIdLen); return true; } } // namespace OHOS::DistributedData \ No newline at end of file diff --git a/services/distributeddataservice/app/src/session_manager/route_head_handler_impl.h b/services/distributeddataservice/app/src/session_manager/route_head_handler_impl.h index 11583b7b1d4c712c8129b23f60f7d29dffa456a9..bd3ff4261929bb97747f397b2aa94ad51f5f82b5 100644 --- a/services/distributeddataservice/app/src/session_manager/route_head_handler_impl.h +++ b/services/distributeddataservice/app/src/session_manager/route_head_handler_impl.h @@ -58,6 +58,11 @@ struct SessionStoreId { uint32_t len; char storeId[0]; }; + +struct SessionAccountId { + uint32_t len; + char accountId[0]; +}; #pragma pack() class RouteHeadHandlerImpl : public DistributedData::RouteHeadHandler { @@ -75,11 +80,16 @@ private: bool PackData(uint8_t *data, uint32_t totalLen); bool PackDataHead(uint8_t *data, uint32_t totalLen); bool PackDataBody(uint8_t *data, uint32_t totalLen); + bool PackAppId(uint8_t **data, const uint8_t *end); + bool PackStoreId(uint8_t **data, const uint8_t *end); + bool PackAccountId(uint8_t **data, const uint8_t *end); bool UnPackData(const uint8_t *data, uint32_t totalLen, uint32_t &unpackedSize); bool UnPackDataHead(const uint8_t *data, uint32_t totalLen, RouteHead &routeHead); bool UnPackDataBody(const uint8_t *data, uint32_t totalLen); + bool UnPackAppId(uint8_t **data, uint32_t leftSize); + bool UnPackStoreId(uint8_t **data, uint32_t leftSize); + bool UnPackAccountId(uint8_t **data, uint32_t leftSize); std::string ParseStoreId(const std::string &deviceId, const std::string &label); - bool UnPackStoreId(const uint8_t *data, uint32_t leftSize); std::string userId_; std::string appId_; diff --git a/services/distributeddataservice/app/src/session_manager/session_manager.cpp b/services/distributeddataservice/app/src/session_manager/session_manager.cpp index 1ef5c834ab26b19bc412b6335aa028cf9890edcd..cb435f6fb7e4e66e16f64ddff865e6b03bec3c9e 100644 --- a/services/distributeddataservice/app/src/session_manager/session_manager.cpp +++ b/services/distributeddataservice/app/src/session_manager/session_manager.cpp @@ -48,6 +48,7 @@ Session SessionManager::GetSession(const SessionPoint &local, const std::string session.sourceUserId = local.userId; session.sourceDeviceId = local.deviceId; session.targetDeviceId = targetDeviceId; + session.accountId = local.accountId; auto users = UserDelegate::GetInstance().GetRemoteUserStatus(targetDeviceId); // system service if (local.userId == UserDelegate::SYSTEM_USER) { @@ -94,7 +95,7 @@ bool SessionManager::GetSendAuthParams(const SessionPoint &local, const std::str for (const auto &storeMeta : metaData) { if (storeMeta.appId == local.appId && storeMeta.storeId == local.storeId) { aclParams.accCaller.bundleName = storeMeta.bundleName; - aclParams.accCaller.accountId = AccountDelegate::GetInstance()->GetCurrentAccountId(); + aclParams.accCaller.accountId = local.accountId; aclParams.accCaller.userId = local.userId; aclParams.accCaller.networkId = DmAdapter::GetInstance().ToNetworkID(local.deviceId); @@ -109,13 +110,13 @@ bool SessionManager::GetSendAuthParams(const SessionPoint &local, const std::str return false; } -bool SessionManager::GetRecvAuthParams(const SessionPoint &local, const std::string &targetDeviceId, - AclParams &aclParams, int32_t peerUser) const +bool SessionManager::GetRecvAuthParams(const SessionPoint &local, const SessionPoint &peer, bool accountFlag, + AclParams &aclParams) const { std::vector metaData; - if (!MetaDataManager::GetInstance().LoadMeta(StoreMetaData::GetPrefix({ targetDeviceId }), metaData)) { - ZLOGE("load meta failed, deviceId:%{public}s, user:%{public}d", Anonymous::Change(targetDeviceId).c_str(), - peerUser); + if (!MetaDataManager::GetInstance().LoadMeta(StoreMetaData::GetPrefix({ peer.deviceId }), metaData)) { + ZLOGE("load meta failed, deviceId:%{public}s, user:%{public}d", Anonymous::Change(peer.deviceId).c_str(), + peer.userId); return false; } for (const auto &storeMeta : metaData) { @@ -126,23 +127,23 @@ bool SessionManager::GetRecvAuthParams(const SessionPoint &local, const std::str aclParams.accCaller.userId = local.userId; aclParams.accCaller.networkId = DmAdapter::GetInstance().ToNetworkID(local.deviceId); - aclParams.accCallee.accountId = accountId; - aclParams.accCallee.userId = peerUser; - aclParams.accCallee.networkId = DmAdapter::GetInstance().ToNetworkID(targetDeviceId); + aclParams.accCallee.accountId = accountFlag ? peer.accountId : accountId; + aclParams.accCallee.userId = peer.userId; + aclParams.accCallee.networkId = DmAdapter::GetInstance().ToNetworkID(peer.deviceId); aclParams.authType = storeMeta.authType; return true; } } ZLOGE("get params failed,appId:%{public}s,tarDevid:%{public}s,user:%{public}d,peer:%{public}d", - local.appId.c_str(), Anonymous::Change(targetDeviceId).c_str(), local.userId, peerUser); + local.appId.c_str(), Anonymous::Change(peer.deviceId).c_str(), local.userId, peer.userId); return false; } -bool SessionManager::CheckSession(const SessionPoint &local, const SessionPoint &peer) const +bool SessionManager::CheckSession(const SessionPoint &local, const SessionPoint &peer, bool accountFlag) const { AclParams aclParams; - if (!GetRecvAuthParams(local, peer.deviceId, aclParams, peer.userId)) { + if (!GetRecvAuthParams(local, peer, accountFlag, aclParams)) { ZLOGE("get recv auth params failed:%{public}s", Anonymous::Change(peer.deviceId).c_str()); return false; } @@ -165,6 +166,7 @@ bool Session::Marshal(json &node) const ret = SetValue(node[GET_NAME(targetUserIds)], targetUserIds) && ret; ret = SetValue(node[GET_NAME(appId)], appId) && ret; ret = SetValue(node[GET_NAME(storeId)], storeId) && ret; + ret = SetValue(node[GET_NAME(accountId)], accountId) && ret; return ret; } @@ -177,6 +179,7 @@ bool Session::Unmarshal(const json &node) ret = GetValue(node, GET_NAME(targetUserIds), targetUserIds) && ret; ret = GetValue(node, GET_NAME(appId), appId) && ret; ret = GetValue(node, GET_NAME(storeId), storeId) && ret; + ret = GetValue(node, GET_NAME(accountId), accountId) && ret; return ret; } } // namespace OHOS::DistributedData diff --git a/services/distributeddataservice/app/src/session_manager/session_manager.h b/services/distributeddataservice/app/src/session_manager/session_manager.h index c72c1bef195f8b24c351fae585b0e6e4754c97a8..c41e30c2381316ba24f4d9d574b47d2245e6a143 100644 --- a/services/distributeddataservice/app/src/session_manager/session_manager.h +++ b/services/distributeddataservice/app/src/session_manager/session_manager.h @@ -30,6 +30,7 @@ struct SessionPoint { uint32_t userId; std::string appId; std::string storeId; + std::string accountId; }; class Session : public Serializable { @@ -40,6 +41,7 @@ public: std::vector targetUserIds; std::string appId; std::string storeId; + std::string accountId; bool Marshal(json &node) const override; bool Unmarshal(const json &node) override; inline bool IsValid() @@ -52,12 +54,12 @@ class SessionManager { public: static SessionManager &GetInstance(); Session GetSession(const SessionPoint &local, const std::string &targetDeviceId) const; - bool CheckSession(const SessionPoint &local, const SessionPoint &peer) const; + bool CheckSession(const SessionPoint &local, const SessionPoint &peer, bool accountFlag) const; private: bool GetSendAuthParams(const SessionPoint &local, const std::string &targetDeviceId, AclParams &aclParams) const; - bool GetRecvAuthParams(const SessionPoint &local, const std::string &targetDeviceId, - AclParams &aclParams, int peerUser) const; + bool GetRecvAuthParams(const SessionPoint &local, const SessionPoint &peer, bool accountFlag, + AclParams &aclParams) const; }; } // namespace OHOS::DistributedData diff --git a/services/distributeddataservice/app/test/unittest/session_manager_test.cpp b/services/distributeddataservice/app/test/unittest/session_manager_test.cpp index e009b52ad7ba4d4f7c0f678354eb313d01ae9fac..9b0706a085d31196a3179fcb24baca45cf6c1789 100644 --- a/services/distributeddataservice/app/test/unittest/session_manager_test.cpp +++ b/services/distributeddataservice/app/test/unittest/session_manager_test.cpp @@ -628,6 +628,7 @@ HWTEST_F(SessionManagerTest, CheckSession, TestSize.Level1) localNormal.appId = "test_app"; localNormal.deviceId = "local_device"; localNormal.storeId = "test_store"; + localNormal.accountId = "test_account"; std::vector datas; CreateStoreMetaData(datas, localSys); EXPECT_CALL(*metaDataMock, LoadMeta(_, _, _)) @@ -636,11 +637,11 @@ HWTEST_F(SessionManagerTest, CheckSession, TestSize.Level1) EXPECT_CALL(AuthHandlerMock::GetInstance(), CheckAccess(_, _, _, _)) .WillOnce(Return(std::pair(false, true))) .WillOnce(Return(std::pair(true, false))); - bool result = SessionManager::GetInstance().CheckSession(localSys, localNormal); + bool result = SessionManager::GetInstance().CheckSession(localSys, localNormal, true); EXPECT_FALSE(result); - result = SessionManager::GetInstance().CheckSession(localSys, localNormal); + result = SessionManager::GetInstance().CheckSession(localSys, localNormal, true); EXPECT_FALSE(result); - result = SessionManager::GetInstance().CheckSession(localNormal, localSys); + result = SessionManager::GetInstance().CheckSession(localNormal, localSys, true); EXPECT_TRUE(result); } } // namespace \ No newline at end of file diff --git a/services/distributeddataservice/framework/include/metadata/capability_meta_data.h b/services/distributeddataservice/framework/include/metadata/capability_meta_data.h index 702a362a110eca69b46c64fe9806de63836425a6..50b6f25eae785417433c4854c34c9c1cc742475d 100644 --- a/services/distributeddataservice/framework/include/metadata/capability_meta_data.h +++ b/services/distributeddataservice/framework/include/metadata/capability_meta_data.h @@ -20,8 +20,10 @@ namespace OHOS::DistributedData { class API_EXPORT CapMetaData final : public Serializable { public: - static constexpr int32_t CURRENT_VERSION = 1; + // 1->2 add accountId to session + static constexpr int32_t CURRENT_VERSION = 2; static constexpr int32_t INVALID_VERSION = -1; + static constexpr int32_t ACCOUNT_VERSION = 2; int32_t version = INVALID_VERSION; API_EXPORT bool Marshal(json &node) const override;