From bc2560e247376c7108bf4b7d026b9e788e936600 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Mon, 7 Apr 2025 14:53:48 +0800 Subject: [PATCH 1/5] =?UTF-8?q?=E5=AE=8C=E6=88=90=E5=9F=BA=E6=9C=ACReWOO?= =?UTF-8?q?=E6=A8=A1=E6=9D=BF=E5=B9=B6=E8=BD=AC=E4=B8=BA=E4=B8=AD=E6=96=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/llm/patterns/rewoo.py | 147 ++++++++++++++++++++++++++----------- 1 file changed, 106 insertions(+), 41 deletions(-) diff --git a/apps/llm/patterns/rewoo.py b/apps/llm/patterns/rewoo.py index 9898d471..ff126257 100644 --- a/apps/llm/patterns/rewoo.py +++ b/apps/llm/patterns/rewoo.py @@ -1,8 +1,8 @@ -"""规划生成命令行 +""" +规划生成命令行 -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ -from typing import Optional from apps.llm.patterns.core import CorePattern from apps.llm.reasoning import ReasoningLLM @@ -12,32 +12,28 @@ class InitPlan(CorePattern): """规划生成命令行""" system_prompt: str = r""" - You are a plan generator. For the given objective, **make a simple plan** that can generate \ - proper commandline arguments and flags step by step. + 你是一个计划生成器。对于给定的目标,**制定一个简单的计划**,该计划可以逐步生成合适的命令行参数和标志。 - You will be given a "command prefix", which is the part of the command that has been determined and generated. \ - You need to complete the command based on this prefix using flags and arguments. + 你会收到一个"命令前缀",这是已经确定和生成的命令部分。你需要基于这个前缀使用标志和参数来完成命令。 - In each step, indicate which external tool together with tool input to retrieve evidences. + 在每一步中,指明使用哪个外部工具以及工具输入来获取证据。 - Tools can be one of the following: - (1) Option["directive"]: Query the most similar command-line flag. Takes only one input parameter, \ - and "directive" must be the search string. The search string should be in detail and contain essential data. - (2) Argument[name]: Place the data in the task to a specific position in the command-line. \ - Takes exactly two input parameters. + 工具可以是以下之一: + (1) Option["指令"]:查询最相似的命令行标志。只接受一个输入参数,"指令"必须是搜索字符串。搜索字符串应该详细且包含必要的数据。 + (2) Argument[名称]<值>:将任务中的数据放置到命令行的特定位置。接受两个输入参数。 - All steps must begin with "Plan: ", and less than 150 words. - Do not add any superfluous steps. - Make sure that each step has all the information needed - do not skip steps. - Do not add any extra data behind the evidence. + 所有步骤必须以"Plan: "开头,且少于150个单词。 + 不要添加任何多余的步骤。 + 确保每个步骤都包含所需的所有信息 - 不要跳过步骤。 + 不要在证据后面添加任何额外数据。 - BEGIN EXAMPLE + 开始示例 - Task: 在后台运行一个新的alpine:latest容器,将主机/root文件夹挂载至/data,并执行top命令。 - Prefix: `docker run` - Usage: `docker run ${OPTS} ${image} ${command}`. 这是一个Python模板字符串。OPTS是所有标志的占位符。参数必须是 \ + 任务:在后台运行一个新的alpine:latest容器,将主机/root文件夹挂载至/data,并执行top命令。 + 前缀:`docker run` + 用法:`docker run ${OPTS} ${image} ${command}`。这是一个Python模板字符串。OPTS是所有标志的占位符。参数必须是 \ ["image", "command"] 其中之一。 - Prefix Description: 二进制程序`docker`的描述为“Docker容器平台”,`run`子命令的描述为“从镜像创建并运行一个新的容器”。 + 前缀描述:二进制程序`docker`的描述为"Docker容器平台",`run`子命令的描述为"从镜像创建并运行一个新的容器"。 Plan: 我需要一个标志使容器在后台运行。 #E1 = Option[在后台运行单个容器] Plan: 我需要一个标志,将主机/root目录挂载至容器内/data目录。 #E2 = Option[挂载主机/root目录至/data目录] @@ -45,24 +41,24 @@ class InitPlan(CorePattern): Plan: 我需要指定容器中运行的命令。 #E4 = Argument[command] Final: 组装上述线索,生成最终命令。 #F - END OF EXAMPLE + 示例结束 - Let's Begin! + 让我们开始! """ """系统提示词""" user_prompt: str = r""" - Task: {instruction} - Prefix: `{binary_name} {subcmd_name}` - Usage: `{subcmd_usage}`. 这是一个Python模板字符串。OPTS是所有标志的占位符。参数必须是 {argument_list} 其中之一。 - Prefix Description: 二进制程序`{binary_name}`的描述为“{binary_description}”,`{subcmd_name}`子命令的描述为\ - “{subcmd_description}”。 + 任务:{instruction} + 前缀:`{binary_name} {subcmd_name}` + 用法:`{subcmd_usage}`。这是一个Python模板字符串。OPTS是所有标志的占位符。参数必须是 {argument_list} 其中之一。 + 前缀描述:二进制程序`{binary_name}`的描述为"{binary_description}",`{subcmd_name}`子命令的描述为\ + "{subcmd_description}"。 - Generate your plan accordingly. + 请生成相应的计划。 """ """用户提示词""" - def __init__(self, system_prompt: Optional[str] = None, user_prompt: Optional[str] = None) -> None: + def __init__(self, system_prompt: str | None = None, user_prompt: str | None = None) -> None: """处理Prompt""" super().__init__(system_prompt, user_prompt) @@ -103,32 +99,101 @@ class InitPlan(CorePattern): class PlanEvaluator(CorePattern): """计划评估器""" - system_prompt = "" + system_prompt: str = r""" + 你是一个计划评估器。你的任务是评估给定的计划是否合理和完整。 + + 一个好的计划应该: + 1. 涵盖原始任务的所有要求 + 2. 使用适当的工具收集必要的信息 + 3. 具有清晰和逻辑的步骤 + 4. 没有冗余或不必要的步骤 + + 对于计划中的每个步骤,评估: + 1. 工具选择是否适当 + 2. 输入参数是否清晰和充分 + 3. 该步骤是否有助于实现最终目标 + + 请回复: + "VALID" - 如果计划良好且完整 + "INVALID: <原因>" - 如果计划有问题,请解释原因 + """ """系统提示词""" - user_prompt = """""" + + user_prompt: str = r""" + 任务:{instruction} + 计划:{plan} + + 评估计划并回复"VALID"或"INVALID: <原因>"。 + """ """用户提示词""" - def __init__(self, system_prompt: Optional[str] = None, user_prompt: Optional[str] = None) -> None: + def __init__(self, system_prompt: str | None = None, user_prompt: str | None = None) -> None: """初始化Prompt""" super().__init__(system_prompt, user_prompt) - async def generate(self, **_kwargs) -> str: + async def generate(self, **kwargs) -> str: """生成计划评估结果""" - pass + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": self.user_prompt.format( + instruction=kwargs["instruction"], + plan=kwargs["plan"], + )}, + ] + + result = "" + async for chunk in ReasoningLLM().call(kwargs["task_id"], messages, streaming=False): + result += chunk + + return result class RePlanner(CorePattern): """重新规划器""" - system_prompt = "" + system_prompt: str = r""" + 你是一个计划重新规划器。当计划被评估为无效时,你需要生成一个新的、改进的计划。 + + 新计划应该: + 1. 解决评估中提到的所有问题 + 2. 保持与原始计划相同的格式 + 3. 更加精确和完整 + 4. 为每个步骤使用适当的工具 + + 遵循与原始计划相同的格式: + - 每个步骤应以"Plan: "开头 + - 包含带有适当参数的工具使用 + - 保持步骤简洁和重点突出 + - 以"Final"步骤结束 + """ """系统提示词""" - user_prompt = """""" + + user_prompt: str = r""" + 任务:{instruction} + 原始计划:{plan} + 评估:{evaluation} + + 生成一个新的、改进的计划,解决评估中提到的所有问题。 + """ """用户提示词""" - def __init__(self, system_prompt: Optional[str] = None, user_prompt: Optional[str] = None) -> None: + def __init__(self, system_prompt: str | None = None, user_prompt: str | None = None) -> None: """初始化Prompt""" super().__init__(system_prompt, user_prompt) - async def generate(self, **_kwargs) -> str: + async def generate(self, **kwargs) -> str: """生成重新规划结果""" - pass + messages = [ + {"role": "system", "content": self.system_prompt}, + {"role": "user", "content": self.user_prompt.format( + instruction=kwargs["instruction"], + plan=kwargs["plan"], + evaluation=kwargs["evaluation"], + )}, + ] + + result = "" + async for chunk in ReasoningLLM().call(kwargs["task_id"], messages, streaming=False): + result += chunk + + return result -- Gitee From 0e66d855299f8897644ed20dfa6c660d3740f8a5 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Mon, 7 Apr 2025 14:54:20 +0800 Subject: [PATCH 2/5] =?UTF-8?q?=E8=B0=83=E6=95=B4Task=E7=BB=93=E6=9E=84?= =?UTF-8?q?=EF=BC=8C=E5=8E=BB=E9=99=A4=E5=86=97=E4=BD=99=E5=AD=97=E6=AE=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/entities/task.py | 82 +++++++++++++++++++++++-------------------- 1 file changed, 44 insertions(+), 38 deletions(-) diff --git a/apps/entities/task.py b/apps/entities/task.py index 4d61e32c..8289cdb9 100644 --- a/apps/entities/task.py +++ b/apps/entities/task.py @@ -1,19 +1,21 @@ -"""Task相关数据结构定义 +""" +Task相关数据结构定义 -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ + import uuid -from datetime import datetime, timezone -from typing import Any, Optional +from datetime import UTC, datetime +from typing import Any from pydantic import BaseModel, Field from apps.entities.enum_var import StepStatus -from apps.entities.record import RecordData class FlowStepHistory(BaseModel): - """任务执行历史;每个Executor每个步骤执行后都会创建 + """ + 任务执行历史;每个Executor每个步骤执行后都会创建 Collection: flow_history """ @@ -25,7 +27,7 @@ class FlowStepHistory(BaseModel): status: StepStatus = Field(description="当前步骤状态") input_data: dict[str, Any] = Field(description="当前Step执行的输入", default={}) output_data: dict[str, Any] = Field(description="当前Step执行后的结果", default={}) - created_at: float = Field(default_factory=lambda: round(datetime.now(tz=timezone.utc).timestamp(), 3)) + created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) class ExecutorState(BaseModel): @@ -39,47 +41,51 @@ class ExecutorState(BaseModel): step_id: str = Field(description="当前步骤ID") step_name: str = Field(description="当前步骤名称") app_id: str = Field(description="应用ID") - # 运行时数据 - ai_summary: str = Field(description="大模型的思考内容", default="") - filled_data: dict[str, Any] = Field(description="待使用的参数", default={}) - remaining_schema: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={}) + slot: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={}) -class TaskBlock(BaseModel): - """内存中的Task块,不存储在数据库中""" +class TaskIds(BaseModel): + """任务涉及的各种ID""" - session_id: str = Field(description="浏览器会话ID") - user_sub: str = Field(description="用户ID", default="") - record: RecordData = Field(description="当前任务执行过程关联的Record") - flow_state: Optional[ExecutorState] = Field(description="Flow的状态", default=None) - flow_context: dict[str, FlowStepHistory] = Field(description="Flow的执行信息", default={}) - new_context: list[str] = Field(description="Flow的执行信息(增量ID)", default=[]) + session_id: str = Field(description="会话ID") + group_id: str = Field(description="组ID") + conversation_id: str = Field(description="对话ID") + record_id: str = Field(description="记录ID", default_factory=lambda: str(uuid.uuid4())) + user_sub: str = Field(description="用户ID") -class RequestDataApp(BaseModel): - """模型对话中包含的app信息""" +class TaskTokens(BaseModel): + """任务Token""" - app_id: str = Field(description="应用ID", alias="appId") - flow_id: str = Field(description="Flow ID", alias="flowId") - params: dict[str, Any] = Field(description="插件参数") + input_tokens: int = Field(description="输入Token", default=0) + input_delta: int = Field(description="输入Token增量", default=0) + output_tokens: int = Field(description="输出Token", default=0) + output_delta: int = Field(description="输出Token增量", default=0) + time: float = Field(description="时间成本", default=0.0) + time_delta: float = Field(description="时间成本增量", default=0.0) -class TaskData(BaseModel): - """任务信息 +class TaskRuntime(BaseModel): + """任务运行时数据""" - Collection: task - 外键:task - record_group - """ + question: str = Field(description="用户问题", default="") + answer: str = Field(description="模型回答", default="") + facts: list[str] = Field(description="记忆", default=[]) + summary: str = Field(description="摘要", default="") + filled: dict[str, Any] = Field(description="填充的槽位", default={}) - id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") - conversation_id: str - record_groups: list[str] = [] - state: Optional[ExecutorState] = Field(description="Flow的状态", default=None) - ended: bool = False - updated_at: float = Field(default_factory=lambda: round(datetime.now(tz=timezone.utc).timestamp(), 3)) +class Task(BaseModel): + """ + 任务信息 -class SchedulerResult(BaseModel): - """调度器返回结果""" + Collection: task + """ - used_docs: list[str] = Field(description="已使用的文档ID列表") + id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") + ids: TaskIds = Field(description="任务涉及的各种ID") + context: dict[str, FlowStepHistory] = Field(description="Flow的步骤执行信息", default={}) + state: ExecutorState | None = Field(description="Flow的状态", default=None) + tokens: TaskTokens = Field(description="Token信息") + runtime: TaskRuntime = Field(description="任务运行时数据") + created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) -- Gitee From 6c54e822bff1cf8548cd3e3d91f44f7148b62f9b Mon Sep 17 00:00:00 2001 From: z30057876 Date: Mon, 7 Apr 2025 14:54:51 +0800 Subject: [PATCH 3/5] =?UTF-8?q?=E8=B0=83=E6=95=B4SessionManager=EF=BC=8CRe?= =?UTF-8?q?dis=E8=BD=AC=E4=B8=BAMongoDB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/manager/session.py | 199 +++++++++++++++++++++------------------- 1 file changed, 105 insertions(+), 94 deletions(-) diff --git a/apps/manager/session.py b/apps/manager/session.py index 28551298..b143b478 100644 --- a/apps/manager/session.py +++ b/apps/manager/session.py @@ -1,68 +1,76 @@ -"""浏览器Session Manager +""" +浏览器Session Manager -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ + import base64 import hashlib import hmac import logging import secrets -from typing import Any, Optional +from datetime import UTC, datetime, timedelta -from apps.common.config import config +from apps.common.config import Config +from apps.entities.config import FixedUserConfig +from apps.entities.session import Session +from apps.exceptions import LoginSettingsError, SessionError from apps.manager.blacklist import UserBlacklistManager -from apps.models.redis import RedisConnectionPool +from apps.models.mongo import MongoDB -logger = logging.getLogger("ray") +logger = logging.getLogger(__name__) class SessionManager: """浏览器Session管理""" @staticmethod - async def create_session(ip: Optional[str] = None, extra_keys: Optional[dict[str, Any]] = None) -> str: + async def create_session(ip: str | None = None, user_sub: str | None = None) -> str: """创建浏览器Session""" if not ip: err = "用户IP错误!" raise ValueError(err) session_id = secrets.token_hex(16) - data = { - "ip": ip, - } - if config["DISABLE_LOGIN"]: - data.update({ - "user_sub": config["DEFAULT_USER"], - }) - - if extra_keys is not None: - data.update(extra_keys) - async with RedisConnectionPool.get_redis_connection().pipeline(transaction=True) as pipe: - try: - pipe.hmset(session_id, data) - pipe.expire(session_id, config["SESSION_TTL"] * 60) - await pipe.execute() - except Exception: - logger.exception("[SessionManager] 创建浏览器Session失败") + data = Session( + _id=session_id, + ip=ip, + expired_at=datetime.now(UTC) + timedelta(minutes=Config().get_config().fastapi.session_ttl), + ) + if Config().get_config().login.provider == "disable": + login_settings = Config().get_config().login.settings + if not isinstance(login_settings, FixedUserConfig): + err = "固定用户配置错误!" + raise LoginSettingsError(err) + data.user_sub = login_settings.user_id + + if user_sub is not None: + data.user_sub = user_sub + + try: + collection = MongoDB().get_collection("session") + await collection.insert_one(data.model_dump(exclude_none=True, by_alias=True)) + await collection.create_index( + "expired_at", expireAfterSeconds=0, + ) + except Exception as e: + err = "创建浏览器Session失败" + logger.exception("[SessionManager] %s", err) + raise SessionError(err) from e return session_id @staticmethod - async def delete_session(session_id: str) -> bool: + async def delete_session(session_id: str) -> None: """删除浏览器Session""" if not session_id: - return True - async with RedisConnectionPool.get_redis_connection().pipeline(transaction=True) as pipe: - try: - pipe.exists(session_id) - result = await pipe.execute() - if not result[0]: - return True - pipe.delete(session_id) - result = await pipe.execute() - return result[0] != 1 - except Exception: - logger.exception("[SessionManager] 删除浏览器Session失败") - return False + return + try: + collection = MongoDB().get_collection("session") + await collection.delete_one({"_id": session_id}) + except Exception as e: + err = "删除浏览器Session失败" + logger.exception("[SessionManager] %s", err) + raise SessionError(err) from e @staticmethod async def get_session(session_id: str, session_ip: str) -> str: @@ -71,14 +79,16 @@ class SessionManager: return await SessionManager.create_session(session_ip) ip = None - async with RedisConnectionPool.get_redis_connection().pipeline(transaction=True) as pipe: - try: - pipe.hget(session_id, "ip") - pipe.expire(session_id, config["SESSION_TTL"] * 60) - result = await pipe.execute() - ip = result[0].decode() - except Exception: - logger.exception("[SessionManager] 读取浏览器Session失败") + try: + collection = MongoDB().get_collection("session") + data = await collection.find_one({"_id": session_id}) + if not data: + return await SessionManager.create_session(session_ip) + ip = Session(**data).ip + except Exception as e: + err = "读取浏览器Session失败" + logger.exception("[SessionManager] %s", err) + raise SessionError(err) from e if not ip or ip != session_ip: return await SessionManager.create_session(session_ip) @@ -87,41 +97,41 @@ class SessionManager: @staticmethod async def verify_user(session_id: str) -> bool: """验证用户是否在Session中""" - async with RedisConnectionPool.get_redis_connection().pipeline(transaction=True) as pipe: - try: - pipe.hexists(session_id, "user_sub") - pipe.expire(session_id, config["SESSION_TTL"] * 60) - result = await pipe.execute() - return result[0] - except Exception: - logger.exception("[SessionManager] 用户不在Session中") + try: + collection = MongoDB().get_collection("session") + data = await collection.find_one({"_id": session_id}) + if not data: return False + return Session(**data).user_sub is not None + except Exception as e: + err = "用户不在Session中" + logger.exception("[SessionManager] %s", err) + raise SessionError(err) from e @staticmethod - async def get_user(session_id: str) -> Optional[str]: + async def get_user(session_id: str) -> str | None: """从Session中获取用户""" - async with RedisConnectionPool.get_redis_connection().pipeline(transaction=True) as pipe: - try: - pipe.hget(session_id, "user_sub") - pipe.expire(session_id, config["SESSION_TTL"] * 60) - result = await pipe.execute() - user_sub = result[0].decode() - except Exception: - logger.exception("[SessionManager] 从Session中获取用户失败") + try: + collection = MongoDB().get_collection("session") + data = await collection.find_one({"_id": session_id}) + if not data: return None + user_sub = Session(**data).user_sub + except Exception as e: + err = "从Session中获取用户失败" + logger.exception("[SessionManager] %s", err) + raise SessionError(err) from e # 查询黑名单 - if await UserBlacklistManager.check_blacklisted_users(user_sub): + if user_sub and await UserBlacklistManager.check_blacklisted_users(user_sub): logger.error("用户在Session黑名单中") - async with RedisConnectionPool.get_redis_connection().pipeline(transaction=True) as pipe: - try: - pipe.hdel(session_id, "user_sub") - pipe.expire(session_id, config["SESSION_TTL"] * 60) - await pipe.execute() - return None - except Exception: - logger.exception("[SessionManager] 从Session中删除用户失败") - return None + try: + await collection.delete_one({"_id": session_id}) + except Exception as e: + err = "从Session中删除用户失败" + logger.exception("[SessionManager] %s", err) + raise SessionError(err) from e + return None return user_sub @@ -130,19 +140,18 @@ class SessionManager: """创建CSRF Token""" rand = secrets.token_hex(8) - async with RedisConnectionPool.get_redis_connection().pipeline(transaction=True) as pipe: - try: - pipe.hset(session_id, "nonce", rand) - pipe.expire(session_id, config["SESSION_TTL"] * 60) - await pipe.execute() - except Exception as e: - err = f"Create csrf token from session error: {e}" - raise RuntimeError(err) from e + try: + collection = MongoDB().get_collection("session") + await collection.update_one({"_id": session_id}, {"$set": {"nonce": rand}}) + except Exception as e: + err = "创建CSRF Token失败" + logger.exception("[SessionManager] %s", err) + raise SessionError(err) from e csrf_value = f"{session_id}{rand}" csrf_b64 = base64.b64encode(bytes.fromhex(csrf_value)) - jwt_key = base64.b64decode(config["JWT_KEY"]) + jwt_key = base64.b64decode(Config().get_config().security.jwt_key) hmac_processor = hmac.new(key=jwt_key, msg=csrf_b64, digestmod=hashlib.sha256) signature = base64.b64encode(hmac_processor.digest()) @@ -167,18 +176,20 @@ class SessionManager: return False current_nonce = first_part[32:] - async with RedisConnectionPool.get_redis_connection().pipeline(transaction=True) as pipe: - try: - pipe.hget(current_session_id, "nonce") - pipe.expire(current_session_id, config["SESSION_TTL"] * 60) - result = await pipe.execute() - nonce = result[0].decode() - if nonce != current_nonce: - return False - except Exception: - logger.exception("[SessionManager] 从Session中获取CSRF Token失败") - - jwt_key = base64.b64decode(config["JWT_KEY"]) + try: + collection = MongoDB().get_collection("session") + data = await collection.find_one({"_id": session_id}) + if not data: + return False + nonce = Session(**data).nonce + if nonce != current_nonce: + return False + except Exception as e: + err = "从Session中获取CSRF Token失败" + logger.exception("[SessionManager] %s", err) + raise SessionError(err) from e + + jwt_key = base64.b64decode(Config().get_config().security.jwt_key) hmac_obj = hmac.new(key=jwt_key, msg=token_msg[0].encode("utf-8"), digestmod=hashlib.sha256) signature = hmac_obj.digest() current_signature = base64.b64decode(token_msg[1]) -- Gitee From d6b21b5166d03cc1e02e42e3e6495d7144c1c818 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Mon, 7 Apr 2025 14:55:13 +0800 Subject: [PATCH 4/5] =?UTF-8?q?=E8=B0=83=E6=95=B4=E6=8F=92=E4=BB=B6?= =?UTF-8?q?=E7=9A=84TokenManager?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/manager/token.py | 146 +++++++++++++++++++++++++++--------------- 1 file changed, 94 insertions(+), 52 deletions(-) diff --git a/apps/manager/token.py b/apps/manager/token.py index b0bf9885..eb329783 100644 --- a/apps/manager/token.py +++ b/apps/manager/token.py @@ -1,18 +1,23 @@ -"""Token Manager +""" +Token Manager -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ + import logging -from typing import Optional +from datetime import UTC, datetime, timedelta import aiohttp from fastapi import status -from apps.common.config import config +from apps.common.config import Config +from apps.common.oidc import oidc_provider +from apps.constants import OIDC_ACCESS_TOKEN_EXPIRE_TIME +from apps.entities.config import OIDCConfig from apps.manager.session import SessionManager -from apps.models.redis import RedisConnectionPool +from apps.models.mongo import MongoDB -logger = logging.getLogger("ray") +logger = logging.getLogger(__name__) class TokenManager: @@ -25,11 +30,14 @@ class TokenManager: if not user_sub: err = "用户不存在!" raise ValueError(err) - async with RedisConnectionPool.get_redis_connection().pipeline(transaction=True) as pipe: - pipe.get(f"{user_sub}_token_{plugin_name}") - result = await pipe.execute() - if result[0] is not None: - return result[0].decode() + + collection = MongoDB().get_collection("session") + token_data = await collection.find_one({ + "_id": f"{plugin_name}_token_{user_sub}", + }) + + if token_data: + return token_data["token"] token = await TokenManager.generate_plugin_token( plugin_name, @@ -43,6 +51,15 @@ class TokenManager: raise RuntimeError(err) return token + @staticmethod + def _get_login_config() -> OIDCConfig: + """获取并验证登录配置""" + login_config = Config().get_config().login.settings + if not isinstance(login_config, OIDCConfig): + err = "Authhub OIDC配置错误" + raise TypeError(err) + return login_config + @staticmethod async def generate_plugin_token( plugin_name: str, @@ -50,61 +67,86 @@ class TokenManager: user_sub: str, access_token_url: str, expire_time: int, - ) -> Optional[str]: + ) -> str | None: """生成插件Token""" - async with RedisConnectionPool.get_redis_connection().pipeline(transaction=True) as pipe: - pipe.get(f"{user_sub}_oidc_access_token") - pipe.get(f"{user_sub}_oidc_refresh_token") - result = await pipe.execute() - if result[0]: - oidc_access_token = result[0] - elif result[1]: - # access token 过期的时候,重新获取 - url = config["OIDC_REFRESH_TOKEN_URL"] - async with aiohttp.ClientSession() as session: - response = await session.post( - url=url, - json={ - "refresh_token": result[1].decode(), - "client_id": config["OIDC_APP_ID"], - }, - ) - ret = await response.json() - if response.status != status.HTTP_200_OK: - logger.error("[TokenManager] 获取OIDC Access token 失败: %s", ret) - return None - oidc_access_token = ret["data"]["access_token"] - pipe.set(f"{user_sub}_oidc_access_token", oidc_access_token, int(config["OIDC_ACCESS_TOKEN_EXPIRE_TIME"]) * 60) - await pipe.execute() + collection = MongoDB().get_collection("session") + + # 获取OIDC token + oidc_token = await collection.find_one({ + "_id": f"access_token_{user_sub}", + }) + + if oidc_token: + oidc_access_token = oidc_token["token"] else: - await SessionManager.delete_session(session_id) - err = "Refresh token均过期,需要重新登录" - raise RuntimeError(err) + # 检查是否有refresh token + refresh_token = await collection.find_one({ + "_id": f"refresh_token_{user_sub}", + }) + + if refresh_token: + # access token 过期的时候,重新获取 + oidc_config = TokenManager._get_login_config() + try: + token_info = await oidc_provider.get_oidc_token(refresh_token["token_value"]) + oidc_access_token = token_info["access_token"] + + # 更新OIDC token + await collection.update_one( + {"_id": f"access_token_{user_sub}"}, + {"$set": { + "token": oidc_access_token, + "expired_at": datetime.now(UTC) + timedelta(minutes=OIDC_ACCESS_TOKEN_EXPIRE_TIME), + }}, + upsert=True, + ) + await collection.create_index("expired_at", expireAfterSeconds=0) + except Exception: + logger.exception("[TokenManager] 获取OIDC Access token 失败") + return None + else: + await SessionManager.delete_session(session_id) + err = "Refresh token均过期,需要重新登录" + raise RuntimeError(err) async with aiohttp.ClientSession() as session: response = await session.post( url=access_token_url, json={ - "client_id": config["OIDC_APP_ID"], - "access_token": oidc_access_token.decode(), + "client_id": oidc_config.app_id, + "access_token": oidc_access_token, }, ) ret = await response.json() if response.status != status.HTTP_200_OK: logger.error("[TokenManager] 获取 %s 插件所需的token失败", plugin_name) return None - async with RedisConnectionPool.get_redis_connection().pipeline(transaction=True) as pipe: - pipe.set(f"{plugin_name}_{user_sub}_token", ret["data"]["access_token"], int(expire_time) * 60) - await pipe.execute() - return ret["data"]["access_token"] + + # 保存插件token + await collection.update_one( + { + "_id": f"{plugin_name}_token_{user_sub}", + }, + {"$set": { + "token": ret["access_token"], + "expired_at": datetime.now(UTC) + timedelta(minutes=expire_time), + }}, + upsert=True, + ) + + # 创建TTL索引 + await collection.create_index("expired_at", expireAfterSeconds=0) + return ret["access_token"] @staticmethod async def delete_plugin_token(user_sub: str) -> None: """删除插件token""" - async with RedisConnectionPool.get_redis_connection().pipeline(transaction=True) as pipe: - # 删除 oidc related token - pipe.delete(f"{user_sub}_oidc_access_token") - pipe.delete(f"{user_sub}_oidc_refresh_token") - pipe.delete(f"aops_{user_sub}_token") - await pipe.execute() - + collection = MongoDB().get_collection("token") + await collection.delete_many({ + "user_sub": user_sub, + "$or": [ + {"token_type": "oidc"}, + {"token_type": "oidc_refresh"}, + {"token_type": "plugin"}, + ], + }) -- Gitee From 7102a401a0a39f716b794f38706a950752eb5104 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Mon, 7 Apr 2025 14:56:12 +0800 Subject: [PATCH 5/5] =?UTF-8?q?CleanCode=20&=20=E4=BF=AE=E6=AD=A3Record?= =?UTF-8?q?=E5=AD=97=E6=AE=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/manager/conversation.py | 43 +++++++++++++-------- apps/manager/record.py | 75 ++++++++++++++++++++++-------------- 2 files changed, 74 insertions(+), 44 deletions(-) diff --git a/apps/manager/conversation.py b/apps/manager/conversation.py index b5b3a8d6..ba3e09cb 100644 --- a/apps/manager/conversation.py +++ b/apps/manager/conversation.py @@ -1,17 +1,19 @@ -"""对话 Manager +""" +对话 Manager -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ + import logging import uuid -from datetime import datetime, timezone -from typing import Any, Optional +from datetime import UTC, datetime +from typing import Any from apps.entities.collection import Conversation from apps.manager.task import TaskManager from apps.models.mongo import MongoDB -logger = logging.getLogger("ray") +logger = logging.getLogger(__name__) class ConversationManager: @@ -22,13 +24,16 @@ class ConversationManager: """根据用户ID获取对话列表,按时间由近到远排序""" try: conv_collection = MongoDB.get_collection("conversation") - return [Conversation(**conv) async for conv in conv_collection.find({"user_sub": user_sub,"debug": False}).sort({"created_at": 1})] + return [ + Conversation(**conv) + async for conv in conv_collection.find({"user_sub": user_sub, "debug": False}).sort({"created_at": 1}) + ] except Exception: logger.exception("[ConversationManager] 通过用户ID获取对话失败") return [] @staticmethod - async def get_conversation_by_conversation_id(user_sub: str, conversation_id: str) -> Optional[Conversation]: + async def get_conversation_by_conversation_id(user_sub: str, conversation_id: str) -> Conversation | None: """通过ConversationID查询对话信息""" try: conv_collection = MongoDB.get_collection("conversation") @@ -41,7 +46,7 @@ class ConversationManager: return None @staticmethod - async def add_conversation_by_user_sub(user_sub: str, app_id: str, *, debug: bool) -> Optional[Conversation]: + async def add_conversation_by_user_sub(user_sub: str, app_id: str, *, debug: bool) -> Conversation | None: """通过用户ID新建对话""" conversation_id = str(uuid.uuid4()) conv = Conversation( @@ -59,9 +64,11 @@ class ConversationManager: "$push": {"conversations": conversation_id}, } if app_id: - # 非调试模式下更新应用使用情况 + # 非调试模式下更新应用使用情况 if not debug: - update_data["$set"] = {f"app_usage.{app_id}.last_used": round(datetime.now(timezone.utc).timestamp(), 3)} + update_data["$set"] = { + f"app_usage.{app_id}.last_used": round(datetime.now(UTC).timestamp(), 3), + } update_data["$inc"] = {f"app_usage.{app_id}.count": 1} await user_collection.update_one( {"_id": user_sub}, @@ -83,10 +90,11 @@ class ConversationManager: {"_id": conversation_id, "user_sub": user_sub}, {"$set": data}, ) - return result.modified_count > 0 except Exception: logger.exception("[ConversationManager] 更新对话失败") return False + else: + return result.modified_count > 0 @staticmethod async def delete_conversation_by_conversation_id(user_sub: str, conversation_id: str) -> bool: @@ -96,15 +104,20 @@ class ConversationManager: record_group_collection = MongoDB.get_collection("record_group") try: async with MongoDB.get_session() as session, await session.start_transaction(): - conversation_data = await conv_collection.find_one_and_delete({"_id": conversation_id, "user_sub": user_sub}, session=session) + conversation_data = await conv_collection.find_one_and_delete( + {"_id": conversation_id, "user_sub": user_sub}, session=session, + ) if not conversation_data: return False - await user_collection.update_one({"_id": user_sub}, {"$pull": {"conversations": conversation_id}}, session=session) + await user_collection.update_one( + {"_id": user_sub}, {"$pull": {"conversations": conversation_id}}, session=session, + ) await record_group_collection.delete_many({"conversation_id": conversation_id}, session=session) await session.commit_transaction() - await TaskManager.delete_tasks_by_conversation_id(conversation_id) - return True except Exception: logger.exception("[ConversationManager] 删除对话失败") return False + else: + await TaskManager.delete_tasks_by_conversation_id(conversation_id) + return True diff --git a/apps/manager/record.py b/apps/manager/record.py index 5aab177f..ae4f7703 100644 --- a/apps/manager/record.py +++ b/apps/manager/record.py @@ -1,25 +1,27 @@ -"""问答对Manager +""" +问答对Manager -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ + import logging import uuid -from typing import Literal, Optional +from typing import Literal -from apps.entities.collection import ( +from apps.entities.record import ( Record, RecordGroup, ) from apps.models.mongo import MongoDB -logger = logging.getLogger("ray") +logger = logging.getLogger(__name__) class RecordManager: """问答对相关操作""" @staticmethod - async def create_record_group(user_sub: str, conversation_id: str, task_id: str) -> Optional[str]: + async def create_record_group(user_sub: str, conversation_id: str, task_id: str) -> str | None: """创建问答组""" group_id = str(uuid.uuid4()) record_group_collection = MongoDB.get_collection("record_group") @@ -36,16 +38,17 @@ class RecordManager: # RecordGroup里面加一条记录 await record_group_collection.insert_one(record_group.model_dump(by_alias=True), session=session) # Conversation里面加一个ID - await conversation_collection.update_one({"_id": conversation_id}, {"$push": {"record_groups": group_id}}, session=session) + await conversation_collection.update_one( + {"_id": conversation_id}, {"$push": {"record_groups": group_id}}, session=session, + ) except Exception: logger.exception("[RecordManager] 创建问答组失败") return None return group_id - @staticmethod - async def insert_record_data_into_record_group(user_sub: str, group_id: str, record: Record) -> Optional[str]: + async def insert_record_data_into_record_group(user_sub: str, group_id: str, record: Record) -> str | None: """加密问答对,并插入MongoDB中的特定问答组""" group_collection = MongoDB.get_collection("record_group") try: @@ -57,13 +60,17 @@ class RecordManager: logger.exception("[RecordManager] 插入加密问答对失败") return None else: - return record.record_id + return record.id @staticmethod async def query_record_by_conversation_id( - user_sub: str, conversation_id: str, total_pairs: Optional[int] = None, order: Literal["desc", "asc"] = "desc", + user_sub: str, + conversation_id: str, + total_pairs: int | None = None, + order: Literal["desc", "asc"] = "desc", ) -> list[Record]: - """查询ConversationID的最后n条问答对 + """ + 查询ConversationID的最后n条问答对 每个record_group只取最后一条record """ @@ -72,22 +79,26 @@ class RecordManager: record_group_collection = MongoDB.get_collection("record_group") try: # 得到conversation的全部record_group id - record_groups = await record_group_collection.aggregate([ - {"$match": {"conversation_id": conversation_id, "user_sub": user_sub}}, - {"$sort": {"created_at": sort_order}}, - {"$project": {"_id": 1}}, - {"$limit": total_pairs} if total_pairs is not None else {}, - ]) + record_groups = await record_group_collection.aggregate( + [ + {"$match": {"conversation_id": conversation_id, "user_sub": user_sub}}, + {"$sort": {"created_at": sort_order}}, + {"$project": {"_id": 1}}, + {"$limit": total_pairs} if total_pairs is not None else {}, + ], + ) records = [] async for record_group_id in record_groups: - record = await record_group_collection.aggregate([ - {"$match": {"_id": record_group_id["_id"]}}, - {"$project": {"records": 1}}, - {"$unwind": "$records"}, - {"$sort": {"records.created_at": -1}}, - {"$limit": 1}, - ]) + record = await record_group_collection.aggregate( + [ + {"$match": {"_id": record_group_id["_id"]}}, + {"$project": {"records": 1}}, + {"$unwind": "$records"}, + {"$sort": {"records.created_at": -1}}, + {"$limit": 1}, + ], + ) record = await record.to_list(length=1) if not record: logger.info("[RecordManager] 问答组 %s 没有问答对", record_group_id) @@ -101,8 +112,11 @@ class RecordManager: return records @staticmethod - async def query_record_group_by_conversation_id(conversation_id: str, total_pairs: Optional[int] = None) -> list[RecordGroup]: - """查询对话ID的最后n条问答组 + async def query_record_group_by_conversation_id( + conversation_id: str, total_pairs: int | None = None, + ) -> list[RecordGroup]: + """ + 查询对话ID的最后n条问答组 包含全部record_group及其关联的record """ @@ -123,14 +137,17 @@ class RecordManager: @staticmethod async def verify_record_in_group(group_id: str, record_id: str, user_sub: str) -> bool: - """验证记录是否在组中 + """ + 验证记录是否在组中 :param record_id: 记录ID,设置了则会去查询指定记录ID的记录 :return: 记录是否存在 """ try: record_group_collection = MongoDB.get_collection("record_group") - record_data = await record_group_collection.find_one({"_id": group_id, "user_sub": user_sub, "records._id": record_id}) + record_data = await record_group_collection.find_one( + {"_id": group_id, "user_sub": user_sub, "records._id": record_id}, + ) return bool(record_data) except Exception: logger.exception("[RecordManager] 验证记录是否在组中失败") -- Gitee