diff --git a/apps/dependency/__init__.py b/apps/dependency/__init__.py index 7bb63d13a6a4595056598a0b16a7b8a9f0dc68ae..824f0b4369a2b95697b53f786ec821362b00473d 100644 --- a/apps/dependency/__init__.py +++ b/apps/dependency/__init__.py @@ -1,6 +1,7 @@ -"""FastAPI 依赖注入模块 +""" +FastAPI 依赖注入模块 -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ from apps.dependency.csrf import verify_csrf_token diff --git a/apps/dependency/csrf.py b/apps/dependency/csrf.py index 054d94b3cb95daa10282724dfd1f0b38dab214d8..fb747b5ede895863f2625d4ebb17cf50b1854a70 100644 --- a/apps/dependency/csrf.py +++ b/apps/dependency/csrf.py @@ -1,18 +1,18 @@ -"""CSRF Token校验 +""" +CSRF Token校验 -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ -from typing import Optional from fastapi import HTTPException, Request, Response, status -from apps.common.config import config +from apps.common.config import Config from apps.manager.session import SessionManager -async def verify_csrf_token(request: Request, response: Response) -> Optional[Response]: +async def verify_csrf_token(request: Request, response: Response) -> Response | None: """验证CSRF Token""" - if not config["ENABLE_CSRF"]: + if not Config().get_config().fastapi.csrf: return None csrf_token = request.headers["x-csrf-token"].strip('"') @@ -25,10 +25,11 @@ async def verify_csrf_token(request: Request, response: Response) -> Optional[Re if not new_csrf_token: raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Renew CSRF token failed.") - if config["COOKIE_MODE"] == "DEBUG": - response.set_cookie("_csrf_tk", new_csrf_token, max_age=config["SESSION_TTL"] * 60, - domain=config["DOMAIN"]) + if Config().get_config().deploy.cookie == "DEBUG": + response.set_cookie("_csrf_tk", new_csrf_token, max_age=Config().get_config().fastapi.session_ttl * 60, + domain=Config().get_config().fastapi.domain) else: - response.set_cookie("_csrf_tk", new_csrf_token, max_age=config["SESSION_TTL"] * 60, - secure=True, domain=config["DOMAIN"], samesite="strict") + response.set_cookie("_csrf_tk", new_csrf_token, max_age=Config().get_config().fastapi.session_ttl * 60, + secure=True, domain=Config().get_config().fastapi.domain, samesite="strict") return response + diff --git a/apps/dependency/session.py b/apps/dependency/session.py index 05b1bf9af553fdffeb5169f242a80d40f38c0212..e83a6a97ebedc03cd5c3f1fb729ebcbadd5ce502 100644 --- a/apps/dependency/session.py +++ b/apps/dependency/session.py @@ -1,12 +1,16 @@ -"""浏览器Session校验 +""" +浏览器Session校验 -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ + +from typing import Any + from fastapi import Response from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request -from apps.common.config import config +from apps.common.config import Config from apps.manager.session import SessionManager BYPASS_LIST = [ @@ -19,49 +23,68 @@ BYPASS_LIST = [ class VerifySessionMiddleware(BaseHTTPMiddleware): """浏览器Session校验中间件""" - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: # noqa: C901, PLR0912 + def _check_bypass_list(self, path: str) -> bool: + """检查请求路径是否需要跳过验证""" + return path in BYPASS_LIST + + def _validate_client(self, request: Request) -> str: + """验证客户端信息并返回主机地址""" + if request.client is None or request.client.host is None: + err = "[VerifySessionMiddleware] 无法检测请求来源IP!" + raise ValueError(err) + return request.client.host + + def _update_cookie_header(self, request: Request, session_id: str) -> None: + """更新请求头中的cookie信息""" + cookie_str = "" + for item in request.scope["headers"]: + if item[0] == b"cookie": + cookie_str = item[1].decode() + request.scope["headers"].remove(item) + break + + all_cookies = "" + if cookie_str: + other_headers = cookie_str.split(";") + all_cookies = "; ".join(item for item in other_headers if "ECSESSION" not in item) + + all_cookies = f"{all_cookies}; ECSESSION={session_id}" if all_cookies else f"ECSESSION={session_id}" + request.scope["headers"].append((b"cookie", all_cookies.encode())) + + def _set_response_cookie(self, response: Response, session_id: str) -> None: + """设置响应cookie""" + # 检查 是否其他dependence 设置过cookie + if "ECSESSION" in response.headers.get("set-cookie", ""): + return + + cookie_params: dict[str, Any] = { + "key": "ECSESSION", + "value": session_id, + "domain": Config().get_config().fastapi.domain, + } + + if Config().get_config().deploy.cookie != "DEBUG": + cookie_params.update({ + "httponly": True, + "secure": True, + "samesite": "strict", + "max_age": Config().get_config().fastapi.session_ttl * 60, + }) + + response.set_cookie(**cookie_params) + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: """浏览器Session校验中间件""" - if request.url.path in BYPASS_LIST: + if self._check_bypass_list(request.url.path): return await call_next(request) - # TODO: 加入apikey校验 + host = self._validate_client(request) cookie = request.cookies.get("ECSESSION", "") - if request.client is None or request.client.host is None: - err = "无法检测请求来源IP!" - raise ValueError(err) - host = request.client.host session_id = await SessionManager.get_session(cookie, host) - if session_id != request.cookies.get("ECSESSION", ""): - cookie_str = "" - - for item in request.scope["headers"]: - if item[0] == b"cookie": - cookie_str = item[1].decode() - request.scope["headers"].remove(item) - break - - all_cookies = "" - if cookie_str != "": - other_headers = cookie_str.split(";") - for item in other_headers: - if "ECSESSION" not in item: - all_cookies += f"{item}; " - - all_cookies += f"ECSESSION={session_id}" - request.scope["headers"].append((b"cookie", all_cookies.encode())) - - response = await call_next(request) - if config["COOKIE_MODE"] == "DEBUG": - response.set_cookie("ECSESSION", session_id, domain=config["DOMAIN"]) - else: - response.set_cookie("ECSESSION", session_id, httponly=True, secure=True, samesite="strict", - max_age=config["SESSION_TTL"] * 60, domain=config["DOMAIN"]) - else: - response = await call_next(request) - if config["COOKIE_MODE"] == "DEBUG": - response.set_cookie("ECSESSION", session_id, domain=config["DOMAIN"]) - else: - response.set_cookie("ECSESSION", session_id, httponly=True, secure=True, samesite="strict", - max_age=config["SESSION_TTL"] * 60, domain=config["DOMAIN"]) + if session_id != cookie: + self._update_cookie_header(request, session_id) + + response = await call_next(request) + self._set_response_cookie(response, session_id) return response diff --git a/apps/dependency/user.py b/apps/dependency/user.py index 9be898dd6931513d1c517a620e78abf645c89a45..4699edb47ad19e968262075c8df56139de31fc11 100644 --- a/apps/dependency/user.py +++ b/apps/dependency/user.py @@ -1,32 +1,101 @@ -"""用户鉴权 +""" +用户鉴权 -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ -from fastapi import Depends + +import logging + +from fastapi import Depends, Response from fastapi.security import OAuth2PasswordBearer from starlette import status from starlette.exceptions import HTTPException from starlette.requests import HTTPConnection +from apps.common.config import Config +from apps.common.oidc import oidc_provider from apps.manager.api_key import ApiKeyManager from apps.manager.session import SessionManager +logger = logging.getLogger(__name__) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") -async def verify_user(request: HTTPConnection) -> None: - """验证Session是否已鉴权;未鉴权则抛出HTTP 401;接口级dependence +async def _verify_oidc_auth(request: HTTPConnection, response: Response) -> str: + """ + 验证OIDC认证状态并获取用户信息 :param request: HTTP请求 - :return: + :return: 用户信息字典 + :raises: HTTPException 当OIDC验证失败时 + """ + try: + tokens = await oidc_provider.get_login_status(request.cookies) + except Exception as err: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="[OIDC] 检查OIDC登录状态失败") from err + + if not tokens: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="[OIDC] 检查OIDC登录状态失败") + + try: + user_info = await oidc_provider.get_oidc_user(tokens["access_token"]) + except Exception as err: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="[OIDC] 获取用户信息失败") from err + + if not user_info: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="[OIDC] 获取用户信息失败") + + # 创建新的session + if request.client is None: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="[OIDC] 获取登录IP失败") + + user_sub = user_info["user_sub"] + user_host = request.client.host + try: + current_session = request.cookies["ECSESSION"] + await SessionManager.delete_session(current_session) + except Exception: + logger.exception("[VerifySessionMiddleware] 删除session失败") + + current_session = await SessionManager.create_session(user_host, user_sub) + + # 设置cookie + if Config().get_config().deploy.cookie == "DEBUG": + response.set_cookie( + "ECSESSION", + current_session, + ) + else: + response.set_cookie( + "ECSESSION", + current_session, + max_age=Config().get_config().fastapi.session_ttl * 60, + secure=True, + domain=Config().get_config().fastapi.domain, + httponly=True, + samesite="strict", + ) + + return user_sub + + +async def verify_user(request: HTTPConnection, response: Response) -> None: + """ + 验证Session是否已鉴权;未鉴权则抛出HTTP 401;接口级dependence + + :param request: HTTP请求 + :return: None """ session_id = request.cookies["ECSESSION"] - if not await SessionManager.verify_user(session_id): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") - # pass + if await SessionManager.verify_user(session_id): + return + + await _verify_oidc_auth(request, response) + async def get_session(request: HTTPConnection) -> str: - """验证Session是否已鉴权,并返回Session ID;未鉴权则抛出HTTP 401;参数级dependence + """ + 验证Session是否已鉴权,并返回Session ID;未鉴权则抛出HTTP 401;参数级dependence :param request: HTTP请求 :return: Session ID @@ -36,20 +105,26 @@ async def get_session(request: HTTPConnection) -> str: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") return session_id -async def get_user(request: HTTPConnection) -> str: - """验证Session是否已鉴权;若已鉴权,查询对应的user_sub;若未鉴权,抛出HTTP 401;参数级dependence + +async def get_user(request: HTTPConnection, response: Response) -> str: + """ + 验证Session是否已鉴权;若已鉴权,查询对应的user_sub;若未鉴权,抛出HTTP 401;参数级dependence :param request: HTTP请求体 :return: 用户sub """ session_id = request.cookies["ECSESSION"] + user = await SessionManager.get_user(session_id) - if not user: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") - return user + if user: + return user + + return await _verify_oidc_auth(request, response) + async def verify_api_key(api_key: str = Depends(oauth2_scheme)) -> None: - """验证API Key是否有效;无效则抛出HTTP 401;接口级dependence + """ + 验证API Key是否有效;无效则抛出HTTP 401;接口级dependence :param api_key: API Key :return: @@ -59,7 +134,8 @@ async def verify_api_key(api_key: str = Depends(oauth2_scheme)) -> None: async def get_user_by_api_key(api_key: str = Depends(oauth2_scheme)) -> str: - """验证API Key是否有效;若有效,返回对应的user_sub;若无效,抛出HTTP 401;参数级dependence + """ + 验证API Key是否有效;若有效,返回对应的user_sub;若无效,抛出HTTP 401;参数级dependence :param api_key: API Key :return: 用户sub diff --git a/apps/entities/api_key.py b/apps/entities/api_key.py new file mode 100644 index 0000000000000000000000000000000000000000..a2f2566da4895f7e56705e4f4d15c7dbc33b14c8 --- /dev/null +++ b/apps/entities/api_key.py @@ -0,0 +1,29 @@ +"""API密钥相关数据结构""" + +from pydantic import BaseModel + +from apps.entities.response_data import ResponseData + + +class _GetAuthKeyMsg(BaseModel): + """GET /api/auth/key Result数据结构""" + + api_key_exists: bool + + +class GetAuthKeyRsp(ResponseData): + """GET /api/auth/key 返回数据结构""" + + result: _GetAuthKeyMsg + + +class PostAuthKeyMsg(BaseModel): + """POST /api/auth/key Result数据结构""" + + api_key: str + + +class PostAuthKeyRsp(ResponseData): + """POST /api/auth/key 返回数据结构""" + + result: PostAuthKeyMsg diff --git a/apps/entities/message.py b/apps/entities/message.py index 75bc490f17e943fbc4a8455b72b83b525102e5e9..796c5905c183543af62b8afd0e42e778e68d5b36 100644 --- a/apps/entities/message.py +++ b/apps/entities/message.py @@ -1,13 +1,15 @@ -"""队列中的消息结构 +""" +队列中的消息结构 -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ -from typing import Any, Optional + +from typing import Any from pydantic import BaseModel, Field -from apps.entities.collection import RecordMetadata -from apps.entities.enum_var import EventType, FlowOutputType, StepStatus +from apps.entities.enum_var import EventType, StepStatus +from apps.entities.record import RecordMetadata class HeartbeatData(BaseModel): @@ -81,13 +83,6 @@ class FlowStartContent(BaseModel): params: dict[str, Any] = Field(description="预先提供的参数") -class FlowStopContent(BaseModel): - """flow.stop消息的content""" - - type: Optional[FlowOutputType] = Field(description="Flow输出的类型", default=None) - data: Optional[dict[str, Any]] = Field(description="Flow输出的数据", default=None) - - class MessageBase(HeartbeatData): """基础消息事件结构""" @@ -95,6 +90,6 @@ class MessageBase(HeartbeatData): group_id: str = Field(min_length=36, max_length=36, alias="groupId") conversation_id: str = Field(min_length=36, max_length=36, alias="conversationId") task_id: str = Field(min_length=36, max_length=36, alias="taskId") - flow: Optional[MessageFlow] = None + flow: MessageFlow | None = None content: dict[str, Any] = {} metadata: MessageMetadata diff --git a/apps/entities/request_data.py b/apps/entities/request_data.py index e8ae5fbfcd3847fd686e2a5a770d4a7ce7fcfdb5..de9ab62a950b732076e17f4fb30f0c4ffcdae555 100644 --- a/apps/entities/request_data.py +++ b/apps/entities/request_data.py @@ -1,16 +1,24 @@ -"""FastAPI 请求体 +""" +FastAPI 请求体 -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field -from apps.common.config import config +from apps.common.config import Config from apps.entities.appcenter import AppData -from apps.entities.flow_topology import FlowItem, PositionItem -from apps.entities.task import RequestDataApp +from apps.entities.flow_topology import FlowItem + + +class RequestDataApp(BaseModel): + """模型对话中包含的app信息""" + + app_id: str = Field(description="应用ID", alias="appId") + flow_id: str = Field(description="Flow ID", alias="flowId") + params: dict[str, Any] = Field(description="插件参数") class MockRequestData(BaseModel): @@ -25,7 +33,7 @@ class MockRequestData(BaseModel): class RequestDataFeatures(BaseModel): """POST /api/chat的features字段数据""" - max_tokens: int = Field(default=config["LLM_MAX_TOKENS"], description="最大生成token数", ge=0) + max_tokens: int = Field(default=Config().get_config().llm.max_tokens, description="最大生成token数", ge=0) context_num: int = Field(default=5, description="上下文消息数量", le=10, ge=0) @@ -33,11 +41,11 @@ class RequestData(BaseModel): """POST /api/chat 请求的总的数据结构""" question: str = Field(max_length=2000, description="用户输入") - conversation_id: str = Field(default=None, alias="conversationId", description="聊天ID") - group_id: Optional[str] = Field(default=None, alias="groupId", description="问答组ID") + conversation_id: str = Field(default="", alias="conversationId", description="聊天ID") + group_id: str | None = Field(default=None, alias="groupId", description="问答组ID") language: str = Field(default="zh", description="语言") files: list[str] = Field(default=[], description="文件列表") - app: Optional[RequestDataApp] = Field(default=None, description="应用") + app: RequestDataApp | None = Field(default=None, description="应用") debug: bool = Field(default=False, description="是否调试") @@ -75,7 +83,7 @@ class AbuseProcessRequest(BaseModel): class CreateAppRequest(AppData): """POST /api/app 请求数据结构""" - app_id: Optional[str] = Field(None, alias="appId", description="应用ID") + app_id: str | None = Field(None, alias="appId", description="应用ID") class ModFavAppRequest(BaseModel): @@ -87,7 +95,7 @@ class ModFavAppRequest(BaseModel): class UpdateServiceRequest(BaseModel): """POST /api/service 请求数据结构""" - service_id: Optional[str] = Field(None, alias="serviceId", description="服务ID(更新时传递)") + service_id: str | None = Field(None, alias="serviceId", description="服务ID(更新时传递)") data: dict[str, Any] = Field(..., description="对应 YAML 内容的数据对象") @@ -100,7 +108,7 @@ class ModFavServiceRequest(BaseModel): class ClientSessionData(BaseModel): """客户端Session信息""" - session_id: Optional[str] = Field(default=None) + session_id: str | None = Field(default=None) class ModifyConversationData(BaseModel): @@ -122,8 +130,8 @@ class AddCommentData(BaseModel): group_id: str is_like: bool = Field(...) dislike_reason: list[str] = Field(default=[], max_length=10) - reason_link: str = Field(default=None, max_length=200) - reason_description: str = Field(default=None, max_length=500) + reason_link: str = Field(default="", max_length=200) + reason_description: str = Field(default="", max_length=500) class PostDomainData(BaseModel): diff --git a/apps/entities/response_data.py b/apps/entities/response_data.py index 6ea6f5f778e04a727b0748e17659f91e180773c3..decb55dc5502a6676e684724f477e53e238b9be3 100644 --- a/apps/entities/response_data.py +++ b/apps/entities/response_data.py @@ -1,9 +1,10 @@ -"""FastAPI 返回数据结构 +""" +FastAPI 返回数据结构 -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -28,35 +29,11 @@ class ResponseData(BaseModel): result: Any -class _GetAuthKeyMsg(BaseModel): - """GET /api/auth/key Result数据结构""" - - api_key_exists: bool - - -class GetAuthKeyRsp(ResponseData): - """GET /api/auth/key 返回数据结构""" - - result: _GetAuthKeyMsg - - -class PostAuthKeyMsg(BaseModel): - """POST /api/auth/key Result数据结构""" - - api_key: str - - -class PostAuthKeyRsp(ResponseData): - """POST /api/auth/key 返回数据结构""" - - result: PostAuthKeyMsg - - class PostClientSessionMsg(BaseModel): """POST /api/client/session Result数据结构""" session_id: str - user_sub: Optional[str] = None + user_sub: str | None = None class PostClientSessionRsp(ResponseData): @@ -378,8 +355,8 @@ class GetServiceDetailMsg(BaseModel): service_id: str = Field(..., alias="serviceId", description="服务ID") name: str = Field(..., description="服务名称") - apis: Optional[list[ServiceApiData]] = Field(default=None, description="解析后的接口列表") - data: Optional[dict[str, Any]] = Field(default=None, description="YAML 内容数据对象") + apis: list[ServiceApiData] | None = Field(default=None, description="解析后的接口列表") + data: dict[str, Any] | None = Field(default=None, description="YAML 内容数据对象") class GetServiceDetailRsp(ResponseData): diff --git a/apps/entities/scheduler.py b/apps/entities/scheduler.py index 8f0cf69a50d3b77d3d2bdfe31e26fab3f1d48b9e..119891350e954838403c57458d376bec32865e7b 100644 --- a/apps/entities/scheduler.py +++ b/apps/entities/scheduler.py @@ -1,33 +1,34 @@ -"""插件、工作流、步骤相关数据结构定义 +""" +插件、工作流、步骤相关数据结构定义 -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ + from typing import Any from pydantic import BaseModel, Field +from apps.entities.enum_var import CallOutputType from apps.entities.task import FlowStepHistory class CallVars(BaseModel): - """所有Call都需要接受的参数。包含用户输入、上下文信息、Step的输出记录等 - - 这一部分的参数由Executor填充,用户无法修改 - """ + """由Executor填充的变量,即“系统变量”""" summary: str = Field(description="上下文信息") user_sub: str = Field(description="用户ID") question: str = Field(description="改写后的用户输入") - history: dict[str, FlowStepHistory] = Field(description="Executor中历史工具的结构化数据", default=[]) + history: dict[str, FlowStepHistory] = Field(description="Executor中历史工具的结构化数据", default={}) task_id: str = Field(description="任务ID") flow_id: str = Field(description="Flow ID") session_id: str = Field(description="当前用户的Session ID") + service_id: str = Field(description="语义接口ID") class ExecutorBackground(BaseModel): """Executor的背景信息""" - conversation: str = Field(description="当前Executor的背景信息") + conversation: list[dict[str, str]] = Field(description="对话记录") facts: list[str] = Field(description="当前Executor的背景信息") @@ -38,3 +39,10 @@ class CallError(Exception): """获取Call错误中的数据""" self.message = message self.data = data + + +class CallOutputChunk(BaseModel): + """Call的输出""" + + type: CallOutputType = Field(description="输出类型") + content: str | dict[str, Any] = Field(description="输出内容")