diff --git a/apps/dependency/csrf.py b/apps/dependency/csrf.py index d6140420dd71b94d13e1c56e954514b6602bf70a..587e52ec1f7f1e0e1c978e3dde9ffe3756e6a31f 100644 --- a/apps/dependency/csrf.py +++ b/apps/dependency/csrf.py @@ -12,24 +12,23 @@ from apps.manager.session import SessionManager async def verify_csrf_token(request: Request, response: Response) -> Optional[Response]: """验证CSRF Token""" - if not config["ENABLE_CSRF"]: - return None + # if not config["ENABLE_CSRF"]: + # return None - csrf_token = request.headers["x-csrf-token"].strip('"') - session = request.cookies["ECSESSION"] + # csrf_token = request.headers["x-csrf-token"].strip('"') + # session = request.cookies["ECSESSION"] - if not await SessionManager.verify_csrf_token(session, csrf_token): - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="CSRF token is invalid.") + # if not await SessionManager.verify_csrf_token(session, csrf_token): + # raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="CSRF token is invalid.") - new_csrf_token = await SessionManager.create_csrf_token(session) - if not new_csrf_token: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Renew CSRF token failed.") + # new_csrf_token = await SessionManager.create_csrf_token(session) + # if not new_csrf_token: + # raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Renew CSRF token failed.") - if config["COOKIE_MODE"] == "DEBUG": - response.set_cookie("_csrf_tk", new_csrf_token, max_age=config["SESSION_TTL"] * 60, - domain=config["DOMAIN"]) - else: - response.set_cookie("_csrf_tk", new_csrf_token, max_age=config["SESSION_TTL"] * 60, - secure=True, domain=config["DOMAIN"], samesite="strict") + # if config["COOKIE_MODE"] == "DEBUG": + # response.set_cookie("_csrf_tk", new_csrf_token, max_age=config["SESSION_TTL"] * 60, + # domain=config["DOMAIN"]) + # else: + # response.set_cookie("_csrf_tk", new_csrf_token, max_age=config["SESSION_TTL"] * 60, + # secure=True, domain=config["DOMAIN"], samesite="strict") return response - diff --git a/apps/dependency/user.py b/apps/dependency/user.py index 6c45ce8c0a8a12f438b85f1dfff1ea614bfadb14..80841cab7a90ed30fedccf115084e9a65e17ae9c 100644 --- a/apps/dependency/user.py +++ b/apps/dependency/user.py @@ -20,9 +20,10 @@ async def verify_user(request: HTTPConnection) -> None: :param request: HTTP请求 :return: """ - session_id = request.cookies["ECSESSION"] - if not await SessionManager.verify_user(session_id): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") + # session_id = request.cookies["ECSESSION"] + # if not await SessionManager.verify_user(session_id): + # raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") + pass async def get_session(request: HTTPConnection) -> str: """验证Session是否已鉴权,并返回Session ID;未鉴权则抛出HTTP 401;参数级dependence @@ -30,10 +31,11 @@ async def get_session(request: HTTPConnection) -> str: :param request: HTTP请求 :return: Session ID """ - session_id = request.cookies["ECSESSION"] - if not await SessionManager.verify_user(session_id): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") - return session_id + # session_id = request.cookies["ECSESSION"] + # if not await SessionManager.verify_user(session_id): + # raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") + # return session_id + return "test" async def get_user(request: HTTPConnection) -> str: """验证Session是否已鉴权;若已鉴权,查询对应的user_sub;若未鉴权,抛出HTTP 401;参数级dependence @@ -41,12 +43,12 @@ async def get_user(request: HTTPConnection) -> str: :param request: HTTP请求体 :return: 用户sub """ - session_id = request.cookies["ECSESSION"] - user = await SessionManager.get_user(session_id) - if not user: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") - return user - + # session_id = request.cookies["ECSESSION"] + # user = await SessionManager.get_user(session_id) + # if not user: + # raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") + # return user + return "test" async def verify_api_key(api_key: str = Depends(oauth2_scheme)) -> None: """验证API Key是否有效;无效则抛出HTTP 401;接口级dependence diff --git a/apps/entities/request_data.py b/apps/entities/request_data.py index b70ff1e45e324b96ae60000ce91f1dbf035e0344..708999baad2762bec424a0e706c4232a0a274e40 100644 --- a/apps/entities/request_data.py +++ b/apps/entities/request_data.py @@ -26,7 +26,7 @@ class RequestData(BaseModel): group_id: Optional[str] = Field(default=None, alias="groupId", description="群组ID") language: str = Field(default="zh", description="语言") files: list[str] = Field(default=[], description="文件列表") - app: list[RequestDataApp] = Field(default=[], description="应用列表") + app: RequestDataApp = Field(default=[], description="应用列表") features: RequestDataFeatures = Field(description="消息功能设置") diff --git a/apps/entities/response_data.py b/apps/entities/response_data.py index 4d5ca1a928638770eb3fb69962cab500f42c875d..b3d26bf1ab99fc248a18231bd6bc6902b3271b29 100644 --- a/apps/entities/response_data.py +++ b/apps/entities/response_data.py @@ -320,7 +320,7 @@ class GetRecentAppListRsp(ResponseData): class NodeServiceListMsg(BaseModel): """GET /api/flow/service result""" - services: list[NodeServiceItem] = Field(..., description="服务列表",default=[]) + services: list[NodeServiceItem] = Field(description="服务列表",default=[]) class NodeServiceListRsp(ResponseData): """GET /api/flow/service 返回数据结构""" result: NodeServiceListMsg diff --git a/apps/main.py b/apps/main.py index d33212c6a3ebf007c4f65fb2a6dcf8a67d281b18..980ec711a3ad4792940b1780257ea8a3142535c7 100644 --- a/apps/main.py +++ b/apps/main.py @@ -29,7 +29,8 @@ from apps.routers import ( knowledge, record, ) -from apps.scheduler.pool.loader import Loader + +# from apps.scheduler.pool.loader import Loader # 定义FastAPI app app = FastAPI(docs_url=None, redoc_url=None) @@ -71,7 +72,7 @@ class FastAPIWrapper: if __name__ == "__main__": # 初始化 WordsCheck.init() - Loader.init() + # Loader.init() # 启动Ray ray.init(dashboard_host="0.0.0.0", num_cpus=4) # noqa: S104 serve.start(http_options=HTTPOptions(host="0.0.0.0", port=8002)) # noqa: S104 diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 7d06a604131b276c88d4415a17578faee56ea3f4..1216f639a1303890a813b1c1ffe3cf1cf1866037 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -130,8 +130,8 @@ async def chat( if await Activity.is_active(user_sub): raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many requests") - if post_body.app and post_body.app[0].app_id: - await AppCenterManager.update_recent_app(user_sub, post_body.app[0].app_id) + if post_body.app and post_body.app.app_id: + await AppCenterManager.update_recent_app(user_sub, post_body.app.app_id) res = chat_generator(post_body, user_sub, session_id) return StreamingResponse( content=res, diff --git a/apps/scheduler/pool/pool.py b/apps/scheduler/pool/pool.py index 1f4040e9601a3ca86c8a97d04a809edd952fd893..155e189c1ffca30d2569365496726591b18e1c15 100644 --- a/apps/scheduler/pool/pool.py +++ b/apps/scheduler/pool/pool.py @@ -3,6 +3,9 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ +from apps.entities.flow_topology import DependencyItem, FlowItem, NodeItem, PositionItem + + class Pool: """资源池""" @@ -19,6 +22,28 @@ class Pool: @classmethod - def get_flow(cls, app_id: str, flow_id: str) -> None: - """获取【单个】Flow完整数据""" + def get_flow(cls, app_id: str, flow_id: str) -> FlowItem: + ret = FlowItem( + { + "flowId": flow_id, + "nodes": [ + (NodeItem){ + "name": "test", + "node_id": "test", + "type": "test", + "parameters": {}, + "position": (PositionItem){ + "x": 0, + "y": 0, + }, + "editable": true, + "enable": true, + "description":"", + } + ], + "edeges": [], + "editable": true, + "enable": true, + ) + return FlowItem pass diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index f1ed5eb1b3ec9dffe848a1d958a57046fb7f7fa4..1ef67ffe0e4b0889bcdafa5a9a470e644227e2fd 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -26,15 +26,15 @@ from apps.manager import ( TaskManager, UserManager, ) -from apps.scheduler.executor import Executor +# from apps.scheduler.executor import Executor from apps.scheduler.scheduler.context import generate_facts, get_context -from apps.scheduler.scheduler.flow import choose_flow +# from apps.scheduler.scheduler.flow import choose_flow from apps.scheduler.scheduler.message import ( push_document_message, push_init_message, push_rag_message, ) -from apps.service.suggestion import plan_next_flow +# from apps.service.suggestion import plan_next_flow class Scheduler: @@ -123,7 +123,7 @@ class Scheduler: if need_recommend: routine_results = await asyncio.gather( generate_facts(self._task_id, post_body.question), - plan_next_flow(user_sub, self._task_id, self._queue, post_body.plugins), + # plan_next_flow(user_sub, self._task_id, self._queue, post_body.app), ) else: routine_results = await asyncio.gather(generate_facts(self._task_id, post_body.question)) @@ -154,17 +154,17 @@ class Scheduler: question=post_body.question, task_id=self._task_id, session_id=session_id, - plugin_data=user_selected_flow, + app_data=user_selected_flow, background=background, ) # 执行Executor - flow_exec = Executor() - await flow_exec.load_state(param) - # 开始运行 - await flow_exec.run() - # 判断状态 - return flow_exec.flow_state.status != StepStatus.PARAM + # flow_exec = Executor() + # await flow_exec.load_state(param) + # # 开始运行 + # await flow_exec.run() + # # 判断状态 + # return flow_exec.flow_state.status != StepStatus.PARAM async def save_state(self, user_sub: str, post_body: RequestData) -> None: """保存当前Executor、Task、Record等的数据""" diff --git a/apps/service/suggestion.py b/apps/service/suggestion.py index 4e4deb9d5a20a91dd7d9305aa1d580aa59d07f71..941a0dffd8e8dd5452199b5b14be0a889389b27a 100644 --- a/apps/service/suggestion.py +++ b/apps/service/suggestion.py @@ -28,7 +28,7 @@ USER_TOP_DOMAINS_NUM = 5 HISTORY_QUESTIONS_NUM = 4 -async def plan_next_flow(user_sub: str, task_id: str, queue: MessageQueue, user_selected_plugins: RequestDataApp) -> None: # noqa: C901, PLR0912 +async def plan_next_flow(user_sub: str, task_id: str, queue: MessageQueue, app: RequestDataApp) -> None: # noqa: C901, PLR0912 """生成用户“下一步”Flow的推荐。 - 若Flow的配置文件中已定义`next_flow[]`字段,则直接使用该字段给定的值 @@ -79,9 +79,8 @@ async def plan_next_flow(user_sub: str, task_id: str, queue: MessageQueue, user_ # 当前使用了Flow flow_id = task.flow_state.name app_id = task.flow_state.app_id - return # TODO: 推荐flow待完善 - # _, flow_data = Pool().get_flow(flow_id, app_id) + _, flow_data = Pool().get_flow(flow_id, app_id) # if flow_data is None: # err = "Flow数据不存在" # raise ValueError(err) @@ -116,22 +115,22 @@ async def plan_next_flow(user_sub: str, task_id: str, queue: MessageQueue, user_ # await queue.push_output(event_type=EventType.SUGGEST, data=content.model_dump(exclude_none=True, by_alias=True)) # return - # # 当前有next_flow - # for i, next_flow in enumerate(flow_data.next_flow): - # # 取前MAX_RECOMMEND个Flow,保持顺序 - # if i >= MAX_RECOMMEND: - # break - - # if next_flow.plugin is not None: - # next_flow_app_id = next_flow.plugin - # else: - # next_flow_app_id = app_id - - # flow_metadata, _ = next_flow.id, next_flow_app_id, - # # flow_metadata, _ = Pool().get_flow( - # # next_flow.id, - # # next_flow_app_id, - # # ) + # 当前有next_flow + for i, next_flow in enumerate(flow_data.next_flow): + # 取前MAX_RECOMMEND个Flow,保持顺序 + if i >= MAX_RECOMMEND: + break + + if next_flow.plugin is not None: + next_flow_app_id = next_flow.plugin + else: + next_flow_app_id = app_id + + flow_metadata, _ = next_flow.id, next_flow_app_id, + # flow_metadata, _ = Pool().get_flow( + # next_flow.id, + # next_flow_app_id, + # ) # # flow不合法 # if flow_metadata is None: