From 04e779804ebf9d41885487ffd9c6b39ecd80dbae Mon Sep 17 00:00:00 2001 From: z30057876 Date: Mon, 7 Apr 2025 11:23:22 +0800 Subject: [PATCH 1/7] =?UTF-8?q?CSRF=E4=BD=BF=E7=94=A8=E6=96=B0=E7=89=88Con?= =?UTF-8?q?fig?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/dependency/__init__.py | 5 +++-- apps/dependency/csrf.py | 23 ++++++++++++----------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/apps/dependency/__init__.py b/apps/dependency/__init__.py index 7bb63d13a..824f0b436 100644 --- a/apps/dependency/__init__.py +++ b/apps/dependency/__init__.py @@ -1,6 +1,7 @@ -"""FastAPI 依赖注入模块 +""" +FastAPI 依赖注入模块 -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ from apps.dependency.csrf import verify_csrf_token diff --git a/apps/dependency/csrf.py b/apps/dependency/csrf.py index 054d94b3c..fb747b5ed 100644 --- a/apps/dependency/csrf.py +++ b/apps/dependency/csrf.py @@ -1,18 +1,18 @@ -"""CSRF Token校验 +""" +CSRF Token校验 -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ -from typing import Optional from fastapi import HTTPException, Request, Response, status -from apps.common.config import config +from apps.common.config import Config from apps.manager.session import SessionManager -async def verify_csrf_token(request: Request, response: Response) -> Optional[Response]: +async def verify_csrf_token(request: Request, response: Response) -> Response | None: """验证CSRF Token""" - if not config["ENABLE_CSRF"]: + if not Config().get_config().fastapi.csrf: return None csrf_token = request.headers["x-csrf-token"].strip('"') @@ -25,10 +25,11 @@ async def verify_csrf_token(request: Request, response: Response) -> Optional[Re if not new_csrf_token: raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Renew CSRF token failed.") - if config["COOKIE_MODE"] == "DEBUG": - response.set_cookie("_csrf_tk", new_csrf_token, max_age=config["SESSION_TTL"] * 60, - domain=config["DOMAIN"]) + if Config().get_config().deploy.cookie == "DEBUG": + response.set_cookie("_csrf_tk", new_csrf_token, max_age=Config().get_config().fastapi.session_ttl * 60, + domain=Config().get_config().fastapi.domain) else: - response.set_cookie("_csrf_tk", new_csrf_token, max_age=config["SESSION_TTL"] * 60, - secure=True, domain=config["DOMAIN"], samesite="strict") + response.set_cookie("_csrf_tk", new_csrf_token, max_age=Config().get_config().fastapi.session_ttl * 60, + secure=True, domain=Config().get_config().fastapi.domain, samesite="strict") return response + -- Gitee From 6c08cf3631d630bef5a9683ff6bbedc408cb295c Mon Sep 17 00:00:00 2001 From: z30057876 Date: Mon, 7 Apr 2025 11:24:57 +0800 Subject: [PATCH 2/7] =?UTF-8?q?API=20Key=E7=9B=B8=E5=85=B3=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E7=BB=93=E6=9E=84=E4=BD=9C=E4=B8=BA=E5=8D=95=E7=8B=AC?= =?UTF-8?q?=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 今后将陆续把数据结构按照功能分块 --- apps/entities/api_key.py | 29 ++++++++++++++++++++++++++ apps/entities/response_data.py | 37 +++++++--------------------------- 2 files changed, 36 insertions(+), 30 deletions(-) create mode 100644 apps/entities/api_key.py diff --git a/apps/entities/api_key.py b/apps/entities/api_key.py new file mode 100644 index 000000000..a2f2566da --- /dev/null +++ b/apps/entities/api_key.py @@ -0,0 +1,29 @@ +"""API密钥相关数据结构""" + +from pydantic import BaseModel + +from apps.entities.response_data import ResponseData + + +class _GetAuthKeyMsg(BaseModel): + """GET /api/auth/key Result数据结构""" + + api_key_exists: bool + + +class GetAuthKeyRsp(ResponseData): + """GET /api/auth/key 返回数据结构""" + + result: _GetAuthKeyMsg + + +class PostAuthKeyMsg(BaseModel): + """POST /api/auth/key Result数据结构""" + + api_key: str + + +class PostAuthKeyRsp(ResponseData): + """POST /api/auth/key 返回数据结构""" + + result: PostAuthKeyMsg diff --git a/apps/entities/response_data.py b/apps/entities/response_data.py index 6ea6f5f77..decb55dc5 100644 --- a/apps/entities/response_data.py +++ b/apps/entities/response_data.py @@ -1,9 +1,10 @@ -"""FastAPI 返回数据结构 +""" +FastAPI 返回数据结构 -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -28,35 +29,11 @@ class ResponseData(BaseModel): result: Any -class _GetAuthKeyMsg(BaseModel): - """GET /api/auth/key Result数据结构""" - - api_key_exists: bool - - -class GetAuthKeyRsp(ResponseData): - """GET /api/auth/key 返回数据结构""" - - result: _GetAuthKeyMsg - - -class PostAuthKeyMsg(BaseModel): - """POST /api/auth/key Result数据结构""" - - api_key: str - - -class PostAuthKeyRsp(ResponseData): - """POST /api/auth/key 返回数据结构""" - - result: PostAuthKeyMsg - - class PostClientSessionMsg(BaseModel): """POST /api/client/session Result数据结构""" session_id: str - user_sub: Optional[str] = None + user_sub: str | None = None class PostClientSessionRsp(ResponseData): @@ -378,8 +355,8 @@ class GetServiceDetailMsg(BaseModel): service_id: str = Field(..., alias="serviceId", description="服务ID") name: str = Field(..., description="服务名称") - apis: Optional[list[ServiceApiData]] = Field(default=None, description="解析后的接口列表") - data: Optional[dict[str, Any]] = Field(default=None, description="YAML 内容数据对象") + apis: list[ServiceApiData] | None = Field(default=None, description="解析后的接口列表") + data: dict[str, Any] | None = Field(default=None, description="YAML 内容数据对象") class GetServiceDetailRsp(ResponseData): -- Gitee From 11561c61728779ea12a17e17dde6dee04379a1c7 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Mon, 7 Apr 2025 11:25:27 +0800 Subject: [PATCH 3/7] =?UTF-8?q?=E5=A2=9E=E5=8A=A0RequestDataApp=E7=B1=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/entities/request_data.py | 38 +++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/apps/entities/request_data.py b/apps/entities/request_data.py index e8ae5fbfc..de9ab62a9 100644 --- a/apps/entities/request_data.py +++ b/apps/entities/request_data.py @@ -1,16 +1,24 @@ -"""FastAPI 请求体 +""" +FastAPI 请求体 -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field -from apps.common.config import config +from apps.common.config import Config from apps.entities.appcenter import AppData -from apps.entities.flow_topology import FlowItem, PositionItem -from apps.entities.task import RequestDataApp +from apps.entities.flow_topology import FlowItem + + +class RequestDataApp(BaseModel): + """模型对话中包含的app信息""" + + app_id: str = Field(description="应用ID", alias="appId") + flow_id: str = Field(description="Flow ID", alias="flowId") + params: dict[str, Any] = Field(description="插件参数") class MockRequestData(BaseModel): @@ -25,7 +33,7 @@ class MockRequestData(BaseModel): class RequestDataFeatures(BaseModel): """POST /api/chat的features字段数据""" - max_tokens: int = Field(default=config["LLM_MAX_TOKENS"], description="最大生成token数", ge=0) + max_tokens: int = Field(default=Config().get_config().llm.max_tokens, description="最大生成token数", ge=0) context_num: int = Field(default=5, description="上下文消息数量", le=10, ge=0) @@ -33,11 +41,11 @@ class RequestData(BaseModel): """POST /api/chat 请求的总的数据结构""" question: str = Field(max_length=2000, description="用户输入") - conversation_id: str = Field(default=None, alias="conversationId", description="聊天ID") - group_id: Optional[str] = Field(default=None, alias="groupId", description="问答组ID") + conversation_id: str = Field(default="", alias="conversationId", description="聊天ID") + group_id: str | None = Field(default=None, alias="groupId", description="问答组ID") language: str = Field(default="zh", description="语言") files: list[str] = Field(default=[], description="文件列表") - app: Optional[RequestDataApp] = Field(default=None, description="应用") + app: RequestDataApp | None = Field(default=None, description="应用") debug: bool = Field(default=False, description="是否调试") @@ -75,7 +83,7 @@ class AbuseProcessRequest(BaseModel): class CreateAppRequest(AppData): """POST /api/app 请求数据结构""" - app_id: Optional[str] = Field(None, alias="appId", description="应用ID") + app_id: str | None = Field(None, alias="appId", description="应用ID") class ModFavAppRequest(BaseModel): @@ -87,7 +95,7 @@ class ModFavAppRequest(BaseModel): class UpdateServiceRequest(BaseModel): """POST /api/service 请求数据结构""" - service_id: Optional[str] = Field(None, alias="serviceId", description="服务ID(更新时传递)") + service_id: str | None = Field(None, alias="serviceId", description="服务ID(更新时传递)") data: dict[str, Any] = Field(..., description="对应 YAML 内容的数据对象") @@ -100,7 +108,7 @@ class ModFavServiceRequest(BaseModel): class ClientSessionData(BaseModel): """客户端Session信息""" - session_id: Optional[str] = Field(default=None) + session_id: str | None = Field(default=None) class ModifyConversationData(BaseModel): @@ -122,8 +130,8 @@ class AddCommentData(BaseModel): group_id: str is_like: bool = Field(...) dislike_reason: list[str] = Field(default=[], max_length=10) - reason_link: str = Field(default=None, max_length=200) - reason_description: str = Field(default=None, max_length=500) + reason_link: str = Field(default="", max_length=200) + reason_description: str = Field(default="", max_length=500) class PostDomainData(BaseModel): -- Gitee From d50e2c1fe87768df30cc003497a3fcaa955e6c74 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Mon, 7 Apr 2025 11:26:10 +0800 Subject: [PATCH 4/7] =?UTF-8?q?=E5=8E=BB=E9=99=A4FlowStopContent=EF=BC=8Cf?= =?UTF-8?q?low.stop=E5=8F=98=E4=B8=BA=E6=8F=90=E7=A4=BA=E6=80=A7Message?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/entities/message.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/apps/entities/message.py b/apps/entities/message.py index 75bc490f1..796c5905c 100644 --- a/apps/entities/message.py +++ b/apps/entities/message.py @@ -1,13 +1,15 @@ -"""队列中的消息结构 +""" +队列中的消息结构 -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ -from typing import Any, Optional + +from typing import Any from pydantic import BaseModel, Field -from apps.entities.collection import RecordMetadata -from apps.entities.enum_var import EventType, FlowOutputType, StepStatus +from apps.entities.enum_var import EventType, StepStatus +from apps.entities.record import RecordMetadata class HeartbeatData(BaseModel): @@ -81,13 +83,6 @@ class FlowStartContent(BaseModel): params: dict[str, Any] = Field(description="预先提供的参数") -class FlowStopContent(BaseModel): - """flow.stop消息的content""" - - type: Optional[FlowOutputType] = Field(description="Flow输出的类型", default=None) - data: Optional[dict[str, Any]] = Field(description="Flow输出的数据", default=None) - - class MessageBase(HeartbeatData): """基础消息事件结构""" @@ -95,6 +90,6 @@ class MessageBase(HeartbeatData): group_id: str = Field(min_length=36, max_length=36, alias="groupId") conversation_id: str = Field(min_length=36, max_length=36, alias="conversationId") task_id: str = Field(min_length=36, max_length=36, alias="taskId") - flow: Optional[MessageFlow] = None + flow: MessageFlow | None = None content: dict[str, Any] = {} metadata: MessageMetadata -- Gitee From 1aa6fd822cb0bbdf5e6b5d6b7537c64eddaca1b0 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Mon, 7 Apr 2025 11:26:58 +0800 Subject: [PATCH 5/7] =?UTF-8?q?Scheduler=E7=9A=84Conversation=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=E5=8E=9F=E5=A7=8B=E6=A0=BC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/entities/scheduler.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/apps/entities/scheduler.py b/apps/entities/scheduler.py index 8f0cf69a5..119891350 100644 --- a/apps/entities/scheduler.py +++ b/apps/entities/scheduler.py @@ -1,33 +1,34 @@ -"""插件、工作流、步骤相关数据结构定义 +""" +插件、工作流、步骤相关数据结构定义 -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ + from typing import Any from pydantic import BaseModel, Field +from apps.entities.enum_var import CallOutputType from apps.entities.task import FlowStepHistory class CallVars(BaseModel): - """所有Call都需要接受的参数。包含用户输入、上下文信息、Step的输出记录等 - - 这一部分的参数由Executor填充,用户无法修改 - """ + """由Executor填充的变量,即“系统变量”""" summary: str = Field(description="上下文信息") user_sub: str = Field(description="用户ID") question: str = Field(description="改写后的用户输入") - history: dict[str, FlowStepHistory] = Field(description="Executor中历史工具的结构化数据", default=[]) + history: dict[str, FlowStepHistory] = Field(description="Executor中历史工具的结构化数据", default={}) task_id: str = Field(description="任务ID") flow_id: str = Field(description="Flow ID") session_id: str = Field(description="当前用户的Session ID") + service_id: str = Field(description="语义接口ID") class ExecutorBackground(BaseModel): """Executor的背景信息""" - conversation: str = Field(description="当前Executor的背景信息") + conversation: list[dict[str, str]] = Field(description="对话记录") facts: list[str] = Field(description="当前Executor的背景信息") @@ -38,3 +39,10 @@ class CallError(Exception): """获取Call错误中的数据""" self.message = message self.data = data + + +class CallOutputChunk(BaseModel): + """Call的输出""" + + type: CallOutputType = Field(description="输出类型") + content: str | dict[str, Any] = Field(description="输出内容") -- Gitee From 704f8587b79449ebec48001e6a9320b6149d662a Mon Sep 17 00:00:00 2001 From: z30057876 Date: Mon, 7 Apr 2025 11:28:07 +0800 Subject: [PATCH 6/7] =?UTF-8?q?=E8=8E=B7=E5=8F=96=E7=94=A8=E6=88=B7?= =?UTF-8?q?=E6=97=B6=E6=A3=80=E6=9F=A5=E6=98=AF=E5=90=A6=E5=AD=98=E5=9C=A8?= =?UTF-8?q?OIDC=20Token=EF=BC=8C=E5=AD=98=E5=9C=A8=E5=B0=B1=E8=B7=B3?= =?UTF-8?q?=E8=BF=87=E7=99=BB=E5=BD=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/dependency/user.py | 110 +++++++++++++++++++++++++++++++++------- 1 file changed, 93 insertions(+), 17 deletions(-) diff --git a/apps/dependency/user.py b/apps/dependency/user.py index 9be898dd6..4699edb47 100644 --- a/apps/dependency/user.py +++ b/apps/dependency/user.py @@ -1,32 +1,101 @@ -"""用户鉴权 +""" +用户鉴权 -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ -from fastapi import Depends + +import logging + +from fastapi import Depends, Response from fastapi.security import OAuth2PasswordBearer from starlette import status from starlette.exceptions import HTTPException from starlette.requests import HTTPConnection +from apps.common.config import Config +from apps.common.oidc import oidc_provider from apps.manager.api_key import ApiKeyManager from apps.manager.session import SessionManager +logger = logging.getLogger(__name__) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") -async def verify_user(request: HTTPConnection) -> None: - """验证Session是否已鉴权;未鉴权则抛出HTTP 401;接口级dependence +async def _verify_oidc_auth(request: HTTPConnection, response: Response) -> str: + """ + 验证OIDC认证状态并获取用户信息 :param request: HTTP请求 - :return: + :return: 用户信息字典 + :raises: HTTPException 当OIDC验证失败时 + """ + try: + tokens = await oidc_provider.get_login_status(request.cookies) + except Exception as err: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="[OIDC] 检查OIDC登录状态失败") from err + + if not tokens: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="[OIDC] 检查OIDC登录状态失败") + + try: + user_info = await oidc_provider.get_oidc_user(tokens["access_token"]) + except Exception as err: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="[OIDC] 获取用户信息失败") from err + + if not user_info: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="[OIDC] 获取用户信息失败") + + # 创建新的session + if request.client is None: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="[OIDC] 获取登录IP失败") + + user_sub = user_info["user_sub"] + user_host = request.client.host + try: + current_session = request.cookies["ECSESSION"] + await SessionManager.delete_session(current_session) + except Exception: + logger.exception("[VerifySessionMiddleware] 删除session失败") + + current_session = await SessionManager.create_session(user_host, user_sub) + + # 设置cookie + if Config().get_config().deploy.cookie == "DEBUG": + response.set_cookie( + "ECSESSION", + current_session, + ) + else: + response.set_cookie( + "ECSESSION", + current_session, + max_age=Config().get_config().fastapi.session_ttl * 60, + secure=True, + domain=Config().get_config().fastapi.domain, + httponly=True, + samesite="strict", + ) + + return user_sub + + +async def verify_user(request: HTTPConnection, response: Response) -> None: + """ + 验证Session是否已鉴权;未鉴权则抛出HTTP 401;接口级dependence + + :param request: HTTP请求 + :return: None """ session_id = request.cookies["ECSESSION"] - if not await SessionManager.verify_user(session_id): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") - # pass + if await SessionManager.verify_user(session_id): + return + + await _verify_oidc_auth(request, response) + async def get_session(request: HTTPConnection) -> str: - """验证Session是否已鉴权,并返回Session ID;未鉴权则抛出HTTP 401;参数级dependence + """ + 验证Session是否已鉴权,并返回Session ID;未鉴权则抛出HTTP 401;参数级dependence :param request: HTTP请求 :return: Session ID @@ -36,20 +105,26 @@ async def get_session(request: HTTPConnection) -> str: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") return session_id -async def get_user(request: HTTPConnection) -> str: - """验证Session是否已鉴权;若已鉴权,查询对应的user_sub;若未鉴权,抛出HTTP 401;参数级dependence + +async def get_user(request: HTTPConnection, response: Response) -> str: + """ + 验证Session是否已鉴权;若已鉴权,查询对应的user_sub;若未鉴权,抛出HTTP 401;参数级dependence :param request: HTTP请求体 :return: 用户sub """ session_id = request.cookies["ECSESSION"] + user = await SessionManager.get_user(session_id) - if not user: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") - return user + if user: + return user + + return await _verify_oidc_auth(request, response) + async def verify_api_key(api_key: str = Depends(oauth2_scheme)) -> None: - """验证API Key是否有效;无效则抛出HTTP 401;接口级dependence + """ + 验证API Key是否有效;无效则抛出HTTP 401;接口级dependence :param api_key: API Key :return: @@ -59,7 +134,8 @@ async def verify_api_key(api_key: str = Depends(oauth2_scheme)) -> None: async def get_user_by_api_key(api_key: str = Depends(oauth2_scheme)) -> str: - """验证API Key是否有效;若有效,返回对应的user_sub;若无效,抛出HTTP 401;参数级dependence + """ + 验证API Key是否有效;若有效,返回对应的user_sub;若无效,抛出HTTP 401;参数级dependence :param api_key: API Key :return: 用户sub -- Gitee From f86327ecb168b0d7ed590d8bf490ef36880ada33 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Mon, 7 Apr 2025 11:28:38 +0800 Subject: [PATCH 7/7] =?UTF-8?q?=E6=8B=86=E5=88=86Session=E9=80=BB=E8=BE=91?= =?UTF-8?q?=E6=88=90=E5=A4=9A=E4=B8=AA=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/dependency/session.py | 107 ++++++++++++++++++++++--------------- 1 file changed, 65 insertions(+), 42 deletions(-) diff --git a/apps/dependency/session.py b/apps/dependency/session.py index 05b1bf9af..e83a6a97e 100644 --- a/apps/dependency/session.py +++ b/apps/dependency/session.py @@ -1,12 +1,16 @@ -"""浏览器Session校验 +""" +浏览器Session校验 -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ + +from typing import Any + from fastapi import Response from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.requests import Request -from apps.common.config import config +from apps.common.config import Config from apps.manager.session import SessionManager BYPASS_LIST = [ @@ -19,49 +23,68 @@ BYPASS_LIST = [ class VerifySessionMiddleware(BaseHTTPMiddleware): """浏览器Session校验中间件""" - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: # noqa: C901, PLR0912 + def _check_bypass_list(self, path: str) -> bool: + """检查请求路径是否需要跳过验证""" + return path in BYPASS_LIST + + def _validate_client(self, request: Request) -> str: + """验证客户端信息并返回主机地址""" + if request.client is None or request.client.host is None: + err = "[VerifySessionMiddleware] 无法检测请求来源IP!" + raise ValueError(err) + return request.client.host + + def _update_cookie_header(self, request: Request, session_id: str) -> None: + """更新请求头中的cookie信息""" + cookie_str = "" + for item in request.scope["headers"]: + if item[0] == b"cookie": + cookie_str = item[1].decode() + request.scope["headers"].remove(item) + break + + all_cookies = "" + if cookie_str: + other_headers = cookie_str.split(";") + all_cookies = "; ".join(item for item in other_headers if "ECSESSION" not in item) + + all_cookies = f"{all_cookies}; ECSESSION={session_id}" if all_cookies else f"ECSESSION={session_id}" + request.scope["headers"].append((b"cookie", all_cookies.encode())) + + def _set_response_cookie(self, response: Response, session_id: str) -> None: + """设置响应cookie""" + # 检查 是否其他dependence 设置过cookie + if "ECSESSION" in response.headers.get("set-cookie", ""): + return + + cookie_params: dict[str, Any] = { + "key": "ECSESSION", + "value": session_id, + "domain": Config().get_config().fastapi.domain, + } + + if Config().get_config().deploy.cookie != "DEBUG": + cookie_params.update({ + "httponly": True, + "secure": True, + "samesite": "strict", + "max_age": Config().get_config().fastapi.session_ttl * 60, + }) + + response.set_cookie(**cookie_params) + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: """浏览器Session校验中间件""" - if request.url.path in BYPASS_LIST: + if self._check_bypass_list(request.url.path): return await call_next(request) - # TODO: 加入apikey校验 + host = self._validate_client(request) cookie = request.cookies.get("ECSESSION", "") - if request.client is None or request.client.host is None: - err = "无法检测请求来源IP!" - raise ValueError(err) - host = request.client.host session_id = await SessionManager.get_session(cookie, host) - if session_id != request.cookies.get("ECSESSION", ""): - cookie_str = "" - - for item in request.scope["headers"]: - if item[0] == b"cookie": - cookie_str = item[1].decode() - request.scope["headers"].remove(item) - break - - all_cookies = "" - if cookie_str != "": - other_headers = cookie_str.split(";") - for item in other_headers: - if "ECSESSION" not in item: - all_cookies += f"{item}; " - - all_cookies += f"ECSESSION={session_id}" - request.scope["headers"].append((b"cookie", all_cookies.encode())) - - response = await call_next(request) - if config["COOKIE_MODE"] == "DEBUG": - response.set_cookie("ECSESSION", session_id, domain=config["DOMAIN"]) - else: - response.set_cookie("ECSESSION", session_id, httponly=True, secure=True, samesite="strict", - max_age=config["SESSION_TTL"] * 60, domain=config["DOMAIN"]) - else: - response = await call_next(request) - if config["COOKIE_MODE"] == "DEBUG": - response.set_cookie("ECSESSION", session_id, domain=config["DOMAIN"]) - else: - response.set_cookie("ECSESSION", session_id, httponly=True, secure=True, samesite="strict", - max_age=config["SESSION_TTL"] * 60, domain=config["DOMAIN"]) + if session_id != cookie: + self._update_cookie_header(request, session_id) + + response = await call_next(request) + self._set_response_cookie(response, session_id) return response -- Gitee