From 06d9ee4b47891713f9648ffad3b82cf6a997bcba Mon Sep 17 00:00:00 2001 From: edwardcaoyue Date: Wed, 22 Nov 2023 12:02:53 +0800 Subject: [PATCH] Rectification of auth mode Signed-off-by: edwardcaoyue --- bundle.json | 2 +- frameworks/concurrent_task_client/BUILD.gn | 1 - .../include/concurrent_task_type.h | 1 - services/BUILD.gn | 2 +- services/include/concurrent_task_controller.h | 10 ++- services/src/concurrent_task_controller.cpp | 69 +++++++++---------- test/BUILD.gn | 1 - .../phone/concurrent_task_controller_test.cpp | 27 ++------ 8 files changed, 45 insertions(+), 68 deletions(-) diff --git a/bundle.json b/bundle.json index 4962b7f..22658b9 100644 --- a/bundle.json +++ b/bundle.json @@ -23,8 +23,8 @@ "components": [ "ability_base", "ability_runtime", + "access_token", "c_utils", - "eventhandler", "frame_aware_sched", "hilog", "hitrace", diff --git a/frameworks/concurrent_task_client/BUILD.gn b/frameworks/concurrent_task_client/BUILD.gn index b4e8e32..e40fea6 100644 --- a/frameworks/concurrent_task_client/BUILD.gn +++ b/frameworks/concurrent_task_client/BUILD.gn @@ -53,7 +53,6 @@ ohos_shared_library("concurrent_task_client") { external_deps = [ "c_utils:utils", - "eventhandler:libeventhandler", "hilog:libhilog", "ipc:ipc_single", "samgr:samgr_proxy", diff --git a/frameworks/concurrent_task_client/include/concurrent_task_type.h b/frameworks/concurrent_task_client/include/concurrent_task_type.h index 520c13e..c9194e0 100644 --- a/frameworks/concurrent_task_client/include/concurrent_task_type.h +++ b/frameworks/concurrent_task_client/include/concurrent_task_type.h @@ -28,7 +28,6 @@ enum MsgType { MSG_CONTINUOUS_TASK_END, MSG_GET_FOCUS, MSG_LOSE_FOCUS, - MSG_AUTH_REQUEST, MSG_SYSTEM_MAX, MSG_APP_START_TYPE = 100, MSG_REG_RENDER = MSG_APP_START_TYPE, diff --git a/services/BUILD.gn b/services/BUILD.gn index 779ba63..8155dde 100644 --- a/services/BUILD.gn +++ b/services/BUILD.gn @@ -60,8 +60,8 @@ ohos_shared_library("concurrentsvc") { } external_deps = [ + "access_token:libaccesstoken_sdk", "c_utils:utils", - "eventhandler:libeventhandler", "frame_aware_sched:rtg_interface", "hilog:libhilog", "hitrace:hitrace_meter", diff --git a/services/include/concurrent_task_controller.h b/services/include/concurrent_task_controller.h index 13cbd77..74d813a 100644 --- a/services/include/concurrent_task_controller.h +++ b/services/include/concurrent_task_controller.h @@ -42,13 +42,12 @@ public: int CreateNewRtgGrp(int prioType, int rtNum); private: - bool CheckUid(pid_t uid); void TypeMapInit(); void QosApplyInit(); void TryCreateRsGroup(); void QueryUi(pid_t uid, IntervalReply& queryRs); void QueryRender(pid_t uid, IntervalReply& queryRs); - void QueryRenderService(pid_t uid, IntervalReply& queryRs); + void QueryRenderService(pid_t uid, pid_t pid, IntervalReply& queryRs); void QueryHwc(pid_t uid, IntervalReply& queryRs); int GetRequestType(std::string strRequstType); void DealSystemRequest(int requestType, const Json::Value& payload); @@ -58,7 +57,7 @@ private: void AppKilled(int uid, int pid); void ContinuousTaskProcess(int uid, int pid, int status); void FocusStatusProcess(int uid, int pid, int status); - void AuthRequestProcess(int uid, int pid); + int AuthSystemProcess(int pid); bool ModifySystemRate(const Json::Value& payload); void SetAppRate(const Json::Value& payload); int FindRateFromInfo(int uiTid, const Json::Value& payload); @@ -68,6 +67,7 @@ private: std::list::iterator GetRecordOfPid(int pid); void PrintInfo(); bool ParsePayload(const Json::Value& payload, int& uid, int& pid); + std::string GetProcessNameByToken(); std::mutex appInfoLock_; std::list foregroundApp_ = {}; @@ -78,6 +78,10 @@ private: int rsTid_ = -1; int systemRate_ = 0; bool rtgEnabled_ = false; + bool rsAuthed_ = false; + + const std::string RENDER_SERVICE_PROCESS_NAME = "render_service"; + const std::string RESOURCE_SCHEDULE_PROCESS_NAME = "resource_schedule_service"; }; class ForegroundAppRecord { diff --git a/services/src/concurrent_task_controller.cpp b/services/src/concurrent_task_controller.cpp index 7eb9617..b4d778e 100644 --- a/services/src/concurrent_task_controller.cpp +++ b/services/src/concurrent_task_controller.cpp @@ -20,15 +20,15 @@ #include #include #include - +#include "accesstoken_kit.h" #include "concurrent_task_log.h" #include "rtg_interface.h" #include "ipc_skeleton.h" #include "parameters.h" #include "concurrent_task_controller.h" -constexpr int RS_UID = 1003; using namespace OHOS::RME; +using namespace OHOS::Security::AccessToken; namespace OHOS { namespace ConcurrentTask { @@ -54,8 +54,8 @@ TaskController& TaskController::GetInstance() void TaskController::ReportData(uint32_t resType, int64_t value, const Json::Value& payload) { pid_t uid = IPCSkeleton::GetInstance().GetCallingUid(); - if (!CheckUid(uid)) { - CONCUR_LOGE("only system call can be allowed"); + if (GetProcessNameByToken() != RESOURCE_SCHEDULE_PROCESS_NAME) { + CONCUR_LOGE("Invalid uid %{public}d, only RSS can call ReportData", uid); return; } if (!CheckJsonValid(payload)) { @@ -80,10 +80,7 @@ void TaskController::ReportData(uint32_t resType, int64_t value, const Json::Val void TaskController::QueryInterval(int queryItem, IntervalReply& queryRs) { pid_t uid = IPCSkeleton::GetInstance().GetCallingUid(); - if (uid == 0) { - CONCUR_LOGE("Uid is 0, error query"); - return; - } + pid_t pid = IPCSkeleton::GetInstance().GetCallingPid(); switch (queryItem) { case QUERY_UI: QueryUi(uid, queryRs); @@ -92,7 +89,7 @@ void TaskController::QueryInterval(int queryItem, IntervalReply& queryRs) QueryRender(uid, queryRs); break; case QUERY_RENDER_SERVICE: - QueryRenderService(uid, queryRs); + QueryRenderService(uid, pid, queryRs); break; case QUERY_COMPOSER: QueryHwc(uid, queryRs); @@ -102,11 +99,18 @@ void TaskController::QueryInterval(int queryItem, IntervalReply& queryRs) } } -void TaskController::QueryUi(int uid, IntervalReply& queryRs) +std::string TaskController::GetProcessNameByToken() { - if (uid == SYSTEM_UID) { - return; + AccessTokenID tokenID = IPCSkeleton::GetInstance().GetCallingTokenID(); + NativeTokenInfo tokenInfo; + if (AccessTokenKit::GetNativeTokenInfo(tokenID, tokenInfo) != AccessTokenKitRet::RET_SUCCESS) { + return ""; } + return tokenInfo.processName; +} + +void TaskController::QueryUi(int uid, IntervalReply& queryRs) +{ pid_t pid = IPCSkeleton::GetInstance().GetCallingPid(); auto iter = GetRecordOfPid(pid); if (iter == foregroundApp_.end()) { @@ -124,9 +128,6 @@ void TaskController::QueryUi(int uid, IntervalReply& queryRs) void TaskController::QueryRender(int uid, IntervalReply& queryRs) { - if (uid == SYSTEM_UID) { - return; - } pid_t pid = IPCSkeleton::GetInstance().GetCallingPid(); auto iter = GetRecordOfPid(pid); if (iter == foregroundApp_.end()) { @@ -142,8 +143,18 @@ void TaskController::QueryRender(int uid, IntervalReply& queryRs) } } -void TaskController::QueryRenderService(int uid, IntervalReply& queryRs) +void TaskController::QueryRenderService(int uid, int pid, IntervalReply& queryRs) { + if (GetProcessNameByToken() != RENDER_SERVICE_PROCESS_NAME) { + return; + } + + if (!rsAuthed_) { + if (AuthSystemProcess(pid) != 0) { + return; + } + rsAuthed_ = true; + } if (renderServiceGrpId_ <= 0) { TryCreateRsGroup(); CONCUR_LOGI("uid %{public}d query rs group failed and create %{public}d.", uid, renderServiceGrpId_); @@ -165,9 +176,6 @@ void TaskController::QueryRenderService(int uid, IntervalReply& queryRs) void TaskController::QueryHwc(int uid, IntervalReply& queryRs) { - if (uid == SYSTEM_UID) { - return; - } pid_t pid = IPCSkeleton::GetInstance().GetCallingPid(); auto iter = GetRecordOfPid(pid); if (iter == foregroundApp_.end()) { @@ -209,7 +217,6 @@ void TaskController::TypeMapInit() msgType_.insert(pair("appKilled", MSG_APP_KILLED)); msgType_.insert(pair("continuousStart", MSG_CONTINUOUS_TASK_START)); msgType_.insert(pair("continuousEnd", MSG_CONTINUOUS_TASK_END)); - msgType_.insert(pair("authRequest", MSG_AUTH_REQUEST)); msgType_.insert(pair("getFocus", MSG_GET_FOCUS)); msgType_.insert(pair("loseFocus", MSG_LOSE_FOCUS)); } @@ -244,14 +251,6 @@ int TaskController::GetRequestType(std::string strRequstType) return msgType_[strRequstType]; } -bool TaskController::CheckUid(pid_t uid) -{ - if ((uid != SYSTEM_UID) && (uid != 0) && (uid != RS_UID)) { - return false; - } - return true; -} - bool TaskController::ParsePayload(const Json::Value& payload, int& uid, int& pid) { try { @@ -295,9 +294,6 @@ void TaskController::DealSystemRequest(int requestType, const Json::Value& paylo case MSG_LOSE_FOCUS: FocusStatusProcess(uid, pid, requestType); break; - case MSG_AUTH_REQUEST: - AuthRequestProcess(uid, pid); - break; default: CONCUR_LOGE("Unknown system request"); break; @@ -428,12 +424,8 @@ void TaskController::AppKilled(int uid, int pid) } } -void TaskController::AuthRequestProcess(int uid, int pid) +int TaskController::AuthSystemProcess(int pid) { - if (uid != RS_UID) { - CONCUR_LOGE("uid %{public}d cannot request auth", uid); - return; - } unsigned int uaFlag = AF_RTG_ALL; unsigned int status = static_cast(AuthStatus::AUTH_STATUS_SYSTEM_SERVER); int ret = AuthEnable(pid, uaFlag, status); @@ -442,6 +434,7 @@ void TaskController::AuthRequestProcess(int uid, int pid) } else { CONCUR_LOGI("auth process %{public}d failed, ret %{public}d", pid, ret); } + return ret; } void TaskController::ContinuousTaskProcess(int uid, int pid, int status) @@ -477,8 +470,8 @@ void TaskController::FocusStatusProcess(int uid, int pid, int status) void TaskController::QueryDeadline(int queryItem, DeadlineReply& ddlReply, const Json::Value& payload) { pid_t uid = IPCSkeleton::GetInstance().GetCallingUid(); - if ((uid != RS_UID) && (uid != ROOT_UID)) { - CONCUR_LOGE("only render service call can be allowed, but uid is %{public}d", uid); + if (GetProcessNameByToken() != RENDER_SERVICE_PROCESS_NAME) { + CONCUR_LOGE("Invalid uid %{public}d, only render service can call QueryDeadline", uid); return; } switch (queryItem) { diff --git a/test/BUILD.gn b/test/BUILD.gn index 47ab634..5049499 100644 --- a/test/BUILD.gn +++ b/test/BUILD.gn @@ -48,7 +48,6 @@ ohos_unittest("concurrent_svc_intf_test") { deps = [ "../frameworks/concurrent_task_client:concurrent_task_client" ] external_deps = [ "c_utils:utils", - "eventhandler:libeventhandler", "hilog:libhilog", "ipc:ipc_single", "safwk:system_ability_fwk", diff --git a/test/unittest/phone/concurrent_task_controller_test.cpp b/test/unittest/phone/concurrent_task_controller_test.cpp index a6d4b7e..bf994ba 100644 --- a/test/unittest/phone/concurrent_task_controller_test.cpp +++ b/test/unittest/phone/concurrent_task_controller_test.cpp @@ -103,24 +103,6 @@ HWTEST_F(ConcurrentTaskControllerTest, InitTest, TestSize.Level1) TaskController::GetInstance().Init(); } -/** - * @tc.name: PushTaskTest - * @tc.desc: Test whether the PushTask interface are normal. - * @tc.type: FUNC - */ -HWTEST_F(ConcurrentTaskControllerTest, CheckUidTest, TestSize.Level1) -{ - int uid = SYSTEM_UID; - bool ret = TaskController::GetInstance().CheckUid(uid); - EXPECT_EQ(ret, true); - uid = 0; - ret = TaskController::GetInstance().CheckUid(uid); - EXPECT_EQ(ret, true); - uid = 100; - ret = TaskController::GetInstance().CheckUid(uid); - EXPECT_EQ(ret, false); -} - /** * @tc.name: PushTaskTest * @tc.desc: Test whether the PushTask interface are normal. @@ -152,15 +134,16 @@ HWTEST_F(ConcurrentTaskControllerTest, TryCreateRsGroupTest, TestSize.Level1) HWTEST_F(ConcurrentTaskControllerTest, QueryRenderServiceTest, TestSize.Level1) { int uid = SYSTEM_UID; + int pid = getpid(); IntervalReply queryRs = {87, 657, 357, 214}; - TaskController::GetInstance().QueryRenderService(uid, queryRs); + TaskController::GetInstance().QueryRenderService(uid, pid, queryRs); int flag = TaskController::GetInstance().renderServiceGrpId_; TaskController::GetInstance().renderServiceGrpId_ = 1; - TaskController::GetInstance().QueryRenderService(uid, queryRs); + TaskController::GetInstance().QueryRenderService(uid, pid, queryRs); TaskController::GetInstance().renderServiceGrpId_ = -1; - TaskController::GetInstance().QueryRenderService(uid, queryRs); + TaskController::GetInstance().QueryRenderService(uid, pid, queryRs); TaskController::GetInstance().renderServiceGrpId_ = flag; - TaskController::GetInstance().QueryRenderService(uid, queryRs); + TaskController::GetInstance().QueryRenderService(uid, pid, queryRs); } /** -- Gitee