From 15dc14679a106bef9ffa66b9f31cc1d9b4aaf502 Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 31 Jul 2025 22:26:56 +0800 Subject: [PATCH 1/8] =?UTF-8?q?=E5=AE=8C=E5=96=84=E9=87=8D=E5=A4=8D?= =?UTF-8?q?=E5=90=AF=E5=8A=A8=E6=B7=BB=E5=8A=A0=E6=9C=AC=E5=9C=B0=E7=94=A8?= =?UTF-8?q?=E6=88=B7=E5=A4=B1=E8=B4=A5=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/main.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/apps/main.py b/apps/main.py index 1646f671a..58e2bd40b 100644 --- a/apps/main.py +++ b/apps/main.py @@ -97,10 +97,13 @@ async def add_no_auth_user() -> None: username = os.environ.get('USER') # 适用于 Linux 和 macOS 系统 if not username: username = "admin" - await user_collection.insert_one(User( - _id=username, - is_admin=True, - ).model_dump(by_alias=True)) + try: + await user_collection.insert_one(User( + _id=username, + is_admin=True, + ).model_dump(by_alias=True)) + except Exception as e: + logging.warning(f"添加无认证用户失败: {e}") async def init_resources() -> None: -- Gitee From 65953620cfa12b8153732a0c2bb05b78610be1a4 Mon Sep 17 00:00:00 2001 From: zxstty Date: Fri, 1 Aug 2025 11:13:52 +0800 Subject: [PATCH 2/8] fix buh --- apps/scheduler/scheduler/scheduler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index 91930f8cd..f63253693 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -229,7 +229,8 @@ class Scheduler: # 初始化Executor logger.info("[Scheduler] 初始化Executor") - + logger.error(f"{flow_data}") + logger.error(f"{self.task}") flow_exec = FlowExecutor( flow_id=flow_id, flow=flow_data, -- Gitee From 3b1fb85f828acd35c1e3b4c68afb092455b04868 Mon Sep 17 00:00:00 2001 From: zxstty Date: Fri, 1 Aug 2025 11:38:56 +0800 Subject: [PATCH 3/8] fix bug --- apps/scheduler/executor/flow.py | 2 +- apps/schemas/enum_var.py | 1 + apps/schemas/task.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index d8400fdf3..ebd4da8ec 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -55,7 +55,7 @@ class FlowExecutor(BaseExecutor): """从数据库中加载FlowExecutor的状态""" logger.info("[FlowExecutor] 加载Executor状态") # 尝试恢复State - if self.task.state: + if self.task.state and self.task.state.flow_status != FlowStatus.INIT: self.task.context = await TaskManager.get_context_by_task_id(self.task.id) else: # 创建ExecutorState diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py index 5ff3381c7..3fb650287 100644 --- a/apps/schemas/enum_var.py +++ b/apps/schemas/enum_var.py @@ -27,6 +27,7 @@ class FlowStatus(str, Enum): """Flow状态""" UNKNOWN = "unknown" + INIT = "init" WAITING = "waiting" RUNNING = "running" SUCCESS = "success" diff --git a/apps/schemas/task.py b/apps/schemas/task.py index 9e9f15317..602123e71 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -40,7 +40,7 @@ class ExecutorState(BaseModel): flow_id: str = Field(description="Flow ID", default="") flow_name: str = Field(description="Flow名称", default="") description: str = Field(description="Flow描述", default="") - flow_status: FlowStatus = Field(description="Flow状态", default=FlowStatus.UNKNOWN) + flow_status: FlowStatus = Field(description="Flow状态", default=FlowStatus.INIT) # 任务级数据 step_id: str = Field(description="当前步骤ID", default="") step_name: str = Field(description="当前步骤名称", default="") -- Gitee From 6168945d3c707522480dbbecc9c9eda18e844871 Mon Sep 17 00:00:00 2001 From: zxstty Date: Fri, 1 Aug 2025 11:42:59 +0800 Subject: [PATCH 4/8] fix bug --- apps/scheduler/executor/step.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 377a4c6e5..5a95e4079 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -260,7 +260,7 @@ class StepExecutor(BaseExecutor): input_data=self.obj.input, output_data=output_data, ) - self.task.context.append(history.model_dump(exclude_none=True, by_alias=True)) + self.task.context.append(history) # 推送输出 await self.push_message(EventType.STEP_OUTPUT.value, output_data) -- Gitee From 9a0c015e79f3dc5302f1669cfed1b46347b3fcf7 Mon Sep 17 00:00:00 2001 From: zxstty Date: Fri, 1 Aug 2025 11:48:26 +0800 Subject: [PATCH 5/8] fix bug --- apps/services/task.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/apps/services/task.py b/apps/services/task.py index 93604a933..d133cbfba 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -94,7 +94,7 @@ class TaskManager: return flow_context_list @staticmethod - async def get_context_by_task_id(task_id: str, length: int = 0) -> list[dict[str, Any]]: + async def get_context_by_task_id(task_id: str, length: int = 0) -> list[FlowStepHistory]: """根据task_id获取flow信息""" flow_context_collection = MongoDB().get_collection("flow_context") @@ -105,7 +105,8 @@ class TaskManager: ).sort( "created_at", -1, ).limit(length): - flow_context += [history] + for i in range(len(flow_context)): + flow_context.append(FlowStepHistory.model_validate(history)) except Exception: logger.exception("[TaskManager] 获取task_id的flow信息失败") return [] @@ -113,7 +114,7 @@ class TaskManager: return flow_context @staticmethod - async def save_flow_context(task_id: str, flow_context: list[dict[str, Any]]) -> None: + async def save_flow_context(task_id: str, flow_context: list[FlowStepHistory]) -> None: """保存flow信息到flow_context""" flow_context_collection = MongoDB().get_collection("flow_context") try: @@ -121,12 +122,12 @@ class TaskManager: # 查找是否存在 current_context = await flow_context_collection.find_one({ "task_id": task_id, - "_id": history["_id"], + "_id": history.id, }) if current_context: await flow_context_collection.update_one( {"_id": current_context["_id"]}, - {"$set": history}, + {"$set": history.model_dump(exclude_none=True, by_alias=True)}, ) else: await flow_context_collection.insert_one(history) -- Gitee From c83aced55db60209d6c8d5cad3e33088e93121b6 Mon Sep 17 00:00:00 2001 From: zxstty Date: Fri, 1 Aug 2025 11:53:33 +0800 Subject: [PATCH 6/8] fix bug --- apps/services/task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/services/task.py b/apps/services/task.py index d133cbfba..a8cb18482 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -130,7 +130,7 @@ class TaskManager: {"$set": history.model_dump(exclude_none=True, by_alias=True)}, ) else: - await flow_context_collection.insert_one(history) + await flow_context_collection.insert_one(history.model_dump(exclude_none=True, by_alias=True)) except Exception: logger.exception("[TaskManager] 保存flow执行记录失败") -- Gitee From 75f1670dcdeb04676b27e0bb7ee68ffeb52c8d91 Mon Sep 17 00:00:00 2001 From: zxstty Date: Fri, 1 Aug 2025 17:43:45 +0800 Subject: [PATCH 7/8] =?UTF-8?q?=E5=AE=8C=E5=96=84mcp=20agent?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/routers/auth.py | 5 ++-- apps/routers/user.py | 26 ++++++++++++++++++-- apps/scheduler/executor/agent.py | 42 +++++++++++++++++++++++++++++--- apps/scheduler/mcp/host.py | 4 +-- apps/scheduler/mcp_agent/host.py | 6 ++--- apps/schemas/collection.py | 1 + apps/schemas/message.py | 6 +++++ apps/schemas/request_data.py | 9 ++++++- apps/schemas/response_data.py | 1 + apps/schemas/task.py | 2 +- apps/schemas/user.py | 1 + apps/services/record.py | 5 +++- apps/services/user.py | 22 ++++++++++++++++- tests/manager/test_user.py | 2 +- 14 files changed, 115 insertions(+), 17 deletions(-) diff --git a/apps/routers/auth.py b/apps/routers/auth.py index 4a3f82939..72416fa35 100644 --- a/apps/routers/auth.py +++ b/apps/routers/auth.py @@ -74,7 +74,7 @@ async def oidc_login(request: Request, code: str) -> HTMLResponse: status_code=status.HTTP_403_FORBIDDEN, ) - await UserManager.update_userinfo_by_user_sub(user_sub) + await UserManager.update_refresh_revision_by_user_sub(user_sub) current_session = await SessionManager.create_session(user_host, user_sub) @@ -177,6 +177,7 @@ async def userinfo( user_sub=user_sub, revision=user.is_active, is_admin=user.is_admin, + auto_execute=user.auto_execute, ), ).model_dump(exclude_none=True, by_alias=True), ) @@ -192,7 +193,7 @@ async def userinfo( ) async def update_revision_number(request: Request, user_sub: Annotated[str, Depends(get_user)]) -> JSONResponse: # noqa: ARG001 """更新用户协议信息""" - ret: bool = await UserManager.update_userinfo_by_user_sub(user_sub, refresh_revision=True) + ret: bool = await UserManager.update_refresh_revision_by_user_sub(user_sub, refresh_revision=True) if not ret: return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, diff --git a/apps/routers/user.py b/apps/routers/user.py index 54e12f444..8c197204e 100644 --- a/apps/routers/user.py +++ b/apps/routers/user.py @@ -3,10 +3,11 @@ from typing import Annotated -from fastapi import APIRouter, Depends, status +from fastapi import APIRouter, Body, Depends, status from fastapi.responses import JSONResponse from apps.dependency import get_user +from apps.schemas.request_data import UserUpdateRequest from apps.schemas.response_data import UserGetMsp, UserGetRsp from apps.schemas.user import UserInfo from apps.services.user import UserManager @@ -18,7 +19,7 @@ router = APIRouter( @router.get("") -async def chat( +async def get_user_sub( user_sub: Annotated[str, Depends(get_user)], ) -> JSONResponse: """查询所有用户接口""" @@ -42,3 +43,24 @@ async def chat( result=UserGetMsp(userInfoList=user_info_list), ).model_dump(exclude_none=True, by_alias=True), ) + + +@router.post("") +async def update_user_info( + user_sub: Annotated[str, Depends(get_user)], + *, + data: Annotated[UserUpdateRequest, Body(..., description="用户更新信息")], +) -> JSONResponse: + """更新用户信息接口""" + # 更新用户信息 + + result = await UserManager.update_userinfo_by_user_sub(user_sub, data) + if not result: + return JSONResponse( + status_code=status.HTTP_200_OK, + content={"code": status.HTTP_200_OK, "message": "用户信息更新成功"}, + ) + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={"code": status.HTTP_404_NOT_FOUND, "message": "用户信息更新失败"}, + ) diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index cb8e183e9..d6ed5917b 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -6,25 +6,61 @@ import logging from pydantic import Field from apps.scheduler.executor.base import BaseExecutor +from apps.schemas.enum_var import EventType, SpecialCallType, FlowStatus, StepStatus from apps.scheduler.mcp_agent import host, plan, select +from apps.schemas.mcp import MCPServerConfig, MCPTool from apps.schemas.task import ExecutorState, StepQueueItem +from apps.schemas.message import param from apps.services.task import TaskManager - +from apps.services.appcenter import AppCenterManager +from apps.services.mcp_service import MCPServiceManager logger = logging.getLogger(__name__) class MCPAgentExecutor(BaseExecutor): """MCP Agent执行器""" - question: str = Field(description="用户输入") max_steps: int = Field(default=20, description="最大步数") servers_id: list[str] = Field(description="MCP server id") agent_id: str = Field(default="", description="Agent ID") agent_description: str = Field(default="", description="Agent描述") + mcp_list: list[MCPServerConfig] = Field(description="MCP服务器列表", default=[]) + tool_list: list[MCPTool] = Field(description="MCP工具列表", default=[]) + params: param | None = Field( + default=None, description="流执行过程中的参数补充", alias="params" + ) + + async def load_mcp_list(self) -> None: + """加载MCP服务器列表""" + logger.info("[MCPAgentExecutor] 加载MCP服务器列表") + # 获取MCP服务器列表 + app = await AppCenterManager.fetch_app_data_by_id(self.agent_id) + mcp_ids = app.mcp_service + for mcp_id in mcp_ids: + self.mcp_list.append( + await MCPServiceManager.get_mcp_service(mcp_id) + ) + + async def load_tools(self) -> None: + """加载MCP工具列表""" + logger.info("[MCPAgentExecutor] 加载MCP工具列表") + # 获取工具列表 + mcp_ids = [mcp.id for mcp in self.mcp_list] + for mcp_id in mcp_ids: + if not await MCPServiceManager.is_mcp_enabled(mcp_id, self.agent_id): + logger.warning("MCP %s 未启用,跳过工具加载", mcp_id) + continue + # 获取MCP工具 + tools = await MCPServiceManager.get_mcp_tools(mcp_id) + self.tool_list.extend(tools) async def load_state(self) -> None: """从数据库中加载FlowExecutor的状态""" logger.info("[FlowExecutor] 加载Executor状态") # 尝试恢复State - if self.task.state: + if self.task.state and self.task.state.flow_status != FlowStatus.INIT: self.task.context = await TaskManager.get_context_by_task_id(self.task.id) + + async def run(self) -> None: + """执行MCP Agent的主逻辑""" + # 初始化MCP服务 diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py index aa1961127..8e1108390 100644 --- a/apps/scheduler/mcp/host.py +++ b/apps/scheduler/mcp/host.py @@ -67,7 +67,7 @@ class MCPHost: context_list = [] for ctx_id in self._context_list: - context = next((ctx for ctx in task.context if ctx["_id"] == ctx_id), None) + context = next((ctx for ctx in task.context if ctx.id == ctx_id), None) if not context: continue context_list.append(context) @@ -118,7 +118,7 @@ class MCPHost: logger.error("任务 %s 不存在", self._task_id) return {} self._context_list.append(context.id) - task.context.append(context.model_dump(by_alias=True, exclude_none=True)) + task.context.append(context.model_dump(exclude_none=True, by_alias=True)) await TaskManager.save_task(self._task_id, task) return output_data diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py index a8ebec7b4..f05d8957e 100644 --- a/apps/scheduler/mcp_agent/host.py +++ b/apps/scheduler/mcp_agent/host.py @@ -67,10 +67,10 @@ class MCPHost: context_list = [] for ctx_id in self._context_list: - context = next((ctx for ctx in task.context if ctx["_id"] == ctx_id), None) + context = next((ctx for ctx in task.context if ctx.id == ctx_id), None) if not context: continue - context_list.append(context) + context_list.append(context.model_dump(exclude_none=True, by_alias=True)) return self._env.from_string(MEMORY_TEMPLATE).render( context_list=context_list, @@ -118,7 +118,7 @@ class MCPHost: logger.error("任务 %s 不存在", self._task_id) return {} self._context_list.append(context.id) - task.context.append(context.model_dump(by_alias=True, exclude_none=True)) + task.context.append(context) await TaskManager.save_task(self._task_id, task) return output_data diff --git a/apps/schemas/collection.py b/apps/schemas/collection.py index 0ff66c72b..a2991f851 100644 --- a/apps/schemas/collection.py +++ b/apps/schemas/collection.py @@ -61,6 +61,7 @@ class User(BaseModel): fav_apps: list[str] = [] fav_services: list[str] = [] is_admin: bool = Field(default=False, description="是否为管理员") + auto_execute: bool = Field(default=True, description="是否自动执行任务") class LLM(BaseModel): diff --git a/apps/schemas/message.py b/apps/schemas/message.py index cf70a82b8..1e58fb1b7 100644 --- a/apps/schemas/message.py +++ b/apps/schemas/message.py @@ -9,6 +9,12 @@ from apps.schemas.enum_var import EventType, StepStatus from apps.schemas.record import RecordMetadata +class param(BaseModel): + """流执行过程中的参数补充""" + content: dict[str, Any] | bool = Field(default={}, description="流执行过程中的参数补充内容") + description: str = Field(default="", description="流执行过程中的参数补充描述") + + class HeartbeatData(BaseModel): """心跳事件的数据结构""" diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index 793ff4563..7bb5ac767 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -10,6 +10,7 @@ from apps.schemas.appcenter import AppData from apps.schemas.enum_var import CommentType from apps.schemas.flow_topology import FlowItem from apps.schemas.mcp import MCPType +from apps.schemas.message import param class RequestDataApp(BaseModel): @@ -17,7 +18,7 @@ class RequestDataApp(BaseModel): app_id: str = Field(description="应用ID", alias="appId") flow_id: str | None = Field(default=None, description="Flow ID", alias="flowId") - params: dict[str, Any] | None = Field(default=None, description="插件参数") + params: param | None = Field(default=None, description="流执行过程中的参数补充", alias="params") class MockRequestData(BaseModel): @@ -185,3 +186,9 @@ class UpdateKbReq(BaseModel): """更新知识库请求体""" kb_ids: list[str] = Field(description="知识库ID列表", alias="kbIds", default=[]) + + +class UserUpdateRequest(BaseModel): + """更新用户信息请求体""" + + auto_execute: bool = Field(default=False, description="是否自动执行", alias="autoExecute") diff --git a/apps/schemas/response_data.py b/apps/schemas/response_data.py index b1dc77b75..7f162326d 100644 --- a/apps/schemas/response_data.py +++ b/apps/schemas/response_data.py @@ -55,6 +55,7 @@ class AuthUserMsg(BaseModel): user_sub: str revision: bool is_admin: bool + auto_execute: bool class AuthUserRsp(ResponseData): diff --git a/apps/schemas/task.py b/apps/schemas/task.py index 602123e71..b08da0a8d 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -50,7 +50,7 @@ class ExecutorState(BaseModel): error_message: str = Field(description="当前步骤错误信息", default="") app_id: str = Field(description="应用ID", default="") slot: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={}) - error_info: dict[str, Any] = Field(description="错误信息", default={}) + error_info: str = Field(description="错误信息", default="") class TaskIds(BaseModel): diff --git a/apps/schemas/user.py b/apps/schemas/user.py index 61aa2587b..debb3446e 100644 --- a/apps/schemas/user.py +++ b/apps/schemas/user.py @@ -9,3 +9,4 @@ class UserInfo(BaseModel): user_sub: str = Field(alias="userSub", default="") user_name: str = Field(alias="userName", default="") + auto_execute: bool = Field(alias="autoExecute", default=False) diff --git a/apps/services/record.py b/apps/services/record.py index 5c8a89dfd..e1925ef6c 100644 --- a/apps/services/record.py +++ b/apps/services/record.py @@ -49,6 +49,10 @@ class RecordManager: mongo = MongoDB() group_collection = mongo.get_collection("record_group") try: + await group_collection.update_one( + {"_id": group_id, "user_sub": user_sub}, + {"$pull": {"records": {"id": record.id}}} + ) await group_collection.update_one( {"_id": group_id, "user_sub": user_sub}, {"$push": {"records": record.model_dump(by_alias=True)}}, @@ -151,7 +155,6 @@ class RecordManager: logger.exception("[RecordManager] 验证记录是否在组中失败") return False - @staticmethod async def check_group_id(group_id: str, user_sub: str) -> bool: """检查group_id是否存在""" diff --git a/apps/services/user.py b/apps/services/user.py index 2721d3773..476f1ef43 100644 --- a/apps/services/user.py +++ b/apps/services/user.py @@ -4,6 +4,7 @@ import logging from datetime import UTC, datetime +from apps.schemas.request_data import UserUpdateRequest from apps.common.mongo import MongoDB from apps.schemas.collection import User from apps.services.conversation import ConversationManager @@ -52,7 +53,26 @@ class UserManager: return User(**user_data) if user_data else None @staticmethod - async def update_userinfo_by_user_sub(user_sub: str, *, refresh_revision: bool = False) -> bool: + async def update_userinfo_by_user_sub(user_sub: str, data: UserUpdateRequest) -> bool: + """ + 根据用户sub更新用户信息 + + :param user_sub: 用户sub + :param data: 用户更新信息 + :return: 是否更新成功 + """ + mongo = MongoDB() + user_collection = mongo.get_collection("user") + update_dict = { + "$set": { + "auto_execute": data.auto_execute, + } + } + result = await user_collection.update_one({"_id": user_sub}, update_dict) + return result.modified_count > 0 + + @staticmethod + async def update_refresh_revision_by_user_sub(user_sub: str, *, refresh_revision: bool = False) -> bool: """ 根据用户sub更新用户信息 diff --git a/tests/manager/test_user.py b/tests/manager/test_user.py index d350ca307..ab6eeab4c 100644 --- a/tests/manager/test_user.py +++ b/tests/manager/test_user.py @@ -73,7 +73,7 @@ class TestUserManager(unittest.TestCase): mock_mysql_db_instance.get_session.return_value = mock_session # 调用被测方法 - updated_userinfo = UserManager.update_userinfo_by_user_sub(userinfo, refresh_revision=True) + updated_userinfo = UserManager.update_refresh_revision_by_user_sub(userinfo, refresh_revision=True) # 断言返回的用户信息的 revision_number 是否与原始用户信息一致 self.assertEqual(updated_userinfo.revision_number, userinfo.revision_number) -- Gitee From dc80695ae21a3fc59058fb03a4cc49d3c587c73c Mon Sep 17 00:00:00 2001 From: zxstty Date: Sat, 2 Aug 2025 17:20:11 +0800 Subject: [PATCH 8/8] =?UTF-8?q?=E5=AE=8C=E5=96=84task=E6=9C=BA=E5=88=B6?= =?UTF-8?q?=E5=92=8Cmcp=20agent?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/queue.py | 2 + apps/routers/chat.py | 35 +++--- apps/routers/record.py | 5 +- apps/scheduler/executor/agent.py | 58 +++++----- apps/scheduler/mcp_agent/host.py | 165 ++++++++++------------------ apps/scheduler/mcp_agent/plan.py | 26 ++--- apps/scheduler/mcp_agent/prompt.py | 101 ++++++++++++++--- apps/scheduler/scheduler/context.py | 15 ++- apps/schemas/message.py | 4 +- apps/schemas/record.py | 3 +- apps/schemas/request_data.py | 1 + apps/schemas/task.py | 7 +- apps/services/conversation.py | 2 +- apps/services/record.py | 18 ++- apps/services/task.py | 85 +++++++------- 15 files changed, 289 insertions(+), 238 deletions(-) diff --git a/apps/common/queue.py b/apps/common/queue.py index 911485b31..089d475e6 100644 --- a/apps/common/queue.py +++ b/apps/common/queue.py @@ -56,6 +56,8 @@ class MessageQueue: flow = MessageFlow( appId=task.state.app_id, flowId=task.state.flow_id, + flowName=task.state.flow_name, + flowStatus=task.state.flow_status, stepId=task.state.step_id, stepName=task.state.step_name, stepStatus=task.state.step_status diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 26a874816..06bc2dd78 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -22,6 +22,8 @@ from apps.schemas.task import Task from apps.services.activity import Activity from apps.services.blacklist import QuestionBlacklistManager, UserBlacklistManager from apps.services.flow import FlowManager +from apps.services.conversation import ConversationManager +from apps.services.record import RecordManager from apps.services.task import TaskManager RECOMMEND_TRES = 5 @@ -32,25 +34,33 @@ router = APIRouter( ) -async def init_task(post_body: RequestData, user_sub: str, session_id: str) -> Task: +async def init_task(post_body: RequestData, user_sub: str) -> Task: """初始化Task""" # 生成group_id if not post_body.group_id: post_body.group_id = str(uuid.uuid4()) - if post_body.new_task: - # 创建或还原Task - task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub) - if task: - await TaskManager.delete_task_by_task_id(task.id) - task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub) + # 更改信息并刷新数据库 if post_body.new_task: - task.runtime.question = post_body.question - task.ids.group_id = post_body.group_id + conversation = await ConversationManager.get_conversation_by_conversation_id( + user_sub=user_sub, + conversation_id=post_body.conversation_id, + ) + if not conversation: + err = "[Chat] 用户没有权限访问该对话!" + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=err) + task_ids = await TaskManager.delete_tasks_by_conversation_id(post_body.conversation_id) + await RecordManager.update_record_flow_status_to_cancelled_by_task_ids(task_ids) + task = await TaskManager.init_new_task(user_sub=user_sub, conversation_id=post_body.conversation_id, post_body=post_body) + else: + if not post_body.task_id: + err = "[Chat] task_id 不可为空!" + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="task_id cannot be empty") + task = await TaskManager.get_task_by_conversation_id(post_body.task_id) return task -async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) -> AsyncGenerator[str, None]: +async def chat_generator(post_body: RequestData, user_sub: str) -> AsyncGenerator[str, None]: """进行实际问答,并从MQ中获取消息""" try: await Activity.set_active(user_sub) @@ -62,7 +72,7 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) await Activity.remove_active(user_sub) return - task = await init_task(post_body, user_sub, session_id) + task = await init_task(post_body, user_sub) # 创建queue;由Scheduler进行关闭 queue = MessageQueue() @@ -120,7 +130,6 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) async def chat( post_body: RequestData, user_sub: Annotated[str, Depends(get_user)], - session_id: Annotated[str, Depends(get_session)], ) -> StreamingResponse: """LLM流式对话接口""" # 问题黑名单检测 @@ -133,7 +142,7 @@ async def chat( if await Activity.is_active(user_sub): raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many requests") - res = chat_generator(post_body, user_sub, session_id) + res = chat_generator(post_body, user_sub) return StreamingResponse( content=res, media_type="text/event-stream", diff --git a/apps/routers/record.py b/apps/routers/record.py index f73e6c983..8d0105ff5 100644 --- a/apps/routers/record.py +++ b/apps/routers/record.py @@ -65,7 +65,7 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_ tmp_record = RecordData( id=record.id, groupId=record_group.id, - taskId=record_group.task_id, + taskId=record.task_id, conversationId=conversation_id, content=record_data, metadata=record.metadata @@ -87,8 +87,9 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_ tmp_record.flow = RecordFlow( id=record.flow.flow_id, # TODO: 此处前端应该用name recordId=record.id, - flowStatus=record.flow.flow_staus, flowId=record.flow.flow_id, + flowName=record.flow.flow_name, + flowStatus=record.flow.flow_staus, stepNum=len(flow_step_list), steps=[], ) diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index d6ed5917b..603ea65b6 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -5,10 +5,13 @@ import logging from pydantic import Field +from apps.llm.reasoning import ReasoningLLM from apps.scheduler.executor.base import BaseExecutor from apps.schemas.enum_var import EventType, SpecialCallType, FlowStatus, StepStatus -from apps.scheduler.mcp_agent import host, plan, select -from apps.schemas.mcp import MCPServerConfig, MCPTool +from apps.scheduler.mcp_agent.host import MCPHost +from apps.scheduler.mcp_agent.plan import MCPPlanner +from apps.scheduler.pool.mcp.client import MCPClient +from apps.schemas.mcp import MCPCollection, MCPTool from apps.schemas.task import ExecutorState, StepQueueItem from apps.schemas.message import param from apps.services.task import TaskManager @@ -24,43 +27,48 @@ class MCPAgentExecutor(BaseExecutor): servers_id: list[str] = Field(description="MCP server id") agent_id: str = Field(default="", description="Agent ID") agent_description: str = Field(default="", description="Agent描述") - mcp_list: list[MCPServerConfig] = Field(description="MCP服务器列表", default=[]) + mcp_list: list[MCPCollection] = Field(description="MCP服务器列表", default=[]) + mcp_client: dict[str, MCPClient] = Field( + description="MCP客户端列表,key为mcp_id", default={} + ) tool_list: list[MCPTool] = Field(description="MCP工具列表", default=[]) params: param | None = Field( default=None, description="流执行过程中的参数补充", alias="params" ) + resoning_llm: ReasoningLLM = Field( + default=ReasoningLLM(), + description="推理大模型", + ) + + async def load_state(self) -> None: + """从数据库中加载FlowExecutor的状态""" + logger.info("[FlowExecutor] 加载Executor状态") + # 尝试恢复State + if self.task.state and self.task.state.flow_status != FlowStatus.INIT: + self.task.context = await TaskManager.get_context_by_task_id(self.task.id) - async def load_mcp_list(self) -> None: + async def load_mcp(self) -> None: """加载MCP服务器列表""" logger.info("[MCPAgentExecutor] 加载MCP服务器列表") # 获取MCP服务器列表 app = await AppCenterManager.fetch_app_data_by_id(self.agent_id) mcp_ids = app.mcp_service for mcp_id in mcp_ids: - self.mcp_list.append( - await MCPServiceManager.get_mcp_service(mcp_id) - ) - - async def load_tools(self) -> None: - """加载MCP工具列表""" - logger.info("[MCPAgentExecutor] 加载MCP工具列表") - # 获取工具列表 - mcp_ids = [mcp.id for mcp in self.mcp_list] - for mcp_id in mcp_ids: - if not await MCPServiceManager.is_mcp_enabled(mcp_id, self.agent_id): - logger.warning("MCP %s 未启用,跳过工具加载", mcp_id) + mcp_service = await MCPServiceManager.get_mcp_service(mcp_id) + if self.task.ids.user_sub not in mcp_service.activated: + logger.warning( + "[MCPAgentExecutor] 用户 %s 未启用MCP %s", + self.task.ids.user_sub, + mcp_id, + ) continue - # 获取MCP工具 - tools = await MCPServiceManager.get_mcp_tools(mcp_id) - self.tool_list.extend(tools) - async def load_state(self) -> None: - """从数据库中加载FlowExecutor的状态""" - logger.info("[FlowExecutor] 加载Executor状态") - # 尝试恢复State - if self.task.state and self.task.state.flow_status != FlowStatus.INIT: - self.task.context = await TaskManager.get_context_by_task_id(self.task.id) + self.mcp_list.append(mcp_service) + self.mcp_client[mcp_id] = await MCPHost.get_client(self.task.ids.user_sub, mcp_id) + self.tool_list.extend(mcp_service.tools) async def run(self) -> None: """执行MCP Agent的主逻辑""" # 初始化MCP服务 + self.load_state() + self.load_mcp() diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py index f05d8957e..3217f5393 100644 --- a/apps/scheduler/mcp_agent/host.py +++ b/apps/scheduler/mcp_agent/host.py @@ -14,116 +14,53 @@ from apps.llm.function import JsonGenerator from apps.scheduler.mcp.prompt import MEMORY_TEMPLATE from apps.scheduler.pool.mcp.client import MCPClient from apps.scheduler.pool.mcp.pool import MCPPool +from apps.scheduler.mcp_agent.prompt import REPAIR_PARAMS from apps.schemas.enum_var import StepStatus from apps.schemas.mcp import MCPPlanItem, MCPTool -from apps.schemas.task import FlowStepHistory +from apps.schemas.task import Task, FlowStepHistory from apps.services.task import TaskManager logger = logging.getLogger(__name__) +_env = SandboxedEnvironment( + loader=BaseLoader, + autoescape=True, + trim_blocks=True, + lstrip_blocks=True, +) + class MCPHost: """MCP宿主服务""" - def __init__(self, user_sub: str, task_id: str, runtime_id: str, runtime_name: str) -> None: - """初始化MCP宿主""" - self._user_sub = user_sub - self._task_id = task_id - # 注意:runtime在工作流中是flow_id和step_description,在Agent中可为标识Agent的id和description - self._runtime_id = runtime_id - self._runtime_name = runtime_name - self._context_list = [] - self._env = SandboxedEnvironment( - loader=BaseLoader(), - autoescape=False, - trim_blocks=True, - lstrip_blocks=True, - ) - - async def get_client(self, mcp_id: str) -> MCPClient | None: + @staticmethod + async def get_client(user_sub, mcp_id: str) -> MCPClient | None: """获取MCP客户端""" mongo = MongoDB() mcp_collection = mongo.get_collection("mcp") # 检查用户是否启用了这个mcp - mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": self._user_sub}) + mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": user_sub}) if not mcp_db_result: - logger.warning("用户 %s 未启用MCP %s", self._user_sub, mcp_id) + logger.warning("用户 %s 未启用MCP %s", user_sub, mcp_id) return None # 获取MCP配置 try: - return await MCPPool().get(mcp_id, self._user_sub) + return await MCPPool().get(mcp_id, user_sub) except KeyError: - logger.warning("用户 %s 的MCP %s 没有运行中的实例,请检查环境", self._user_sub, mcp_id) + logger.warning("用户 %s 的MCP %s 没有运行中的实例,请检查环境", user_sub, mcp_id) return None - async def assemble_memory(self) -> str: + @staticmethod + async def assemble_memory(task: Task) -> str: """组装记忆""" - task = await TaskManager.get_task_by_task_id(self._task_id) - if not task: - logger.error("任务 %s 不存在", self._task_id) - return "" - - context_list = [] - for ctx_id in self._context_list: - context = next((ctx for ctx in task.context if ctx.id == ctx_id), None) - if not context: - continue - context_list.append(context.model_dump(exclude_none=True, by_alias=True)) - return self._env.from_string(MEMORY_TEMPLATE).render( - context_list=context_list, - ) - - async def _save_memory( - self, - tool: MCPTool, - plan_item: MCPPlanItem, - input_data: dict[str, Any], - result: str, - ) -> dict[str, Any]: - """保存记忆""" - try: - output_data = json.loads(result) - except Exception: # noqa: BLE001 - logger.warning("[MCPHost] 得到的数据不是dict格式!尝试转换为str") - output_data = { - "message": result, - } - - if not isinstance(output_data, dict): - output_data = { - "message": result, - } - - # 创建context;注意用法 - context = FlowStepHistory( - task_id=self._task_id, - flow_id=self._runtime_id, - flow_name=self._runtime_name, - flow_status=StepStatus.SUCCESS, - step_id=tool.name, - step_name=tool.name, - # description是规划的实际内容 - step_description=plan_item.content, - step_status=StepStatus.SUCCESS, - input_data=input_data, - output_data=output_data, + return _env.from_string(MEMORY_TEMPLATE).render( + context_list=task.context, ) - # 保存到task - task = await TaskManager.get_task_by_task_id(self._task_id) - if not task: - logger.error("任务 %s 不存在", self._task_id) - return {} - self._context_list.append(context.id) - task.context.append(context) - await TaskManager.save_task(self._task_id, task) - - return output_data - - async def _fill_params(self, schema: dict[str, Any], query: str) -> dict[str, Any]: + async def _get_first_input_params(schema: dict[str, Any], query: str) -> dict[str, Any]: """填充工具参数""" # 更清晰的输入·指令,这样可以调用generate llm_query = rf""" @@ -137,23 +74,48 @@ class MCPHost: llm_query, [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": await self.assemble_memory()}, + {"role": "user", "content": await MCPHost.assemble_memory()}, + ], + schema, + ) + return await json_generator.generate() + + async def _fill_params(mcp_tool: MCPTool, schema: dict[str, Any], + current_input: dict[str, Any], + error_message: str = "", params: dict[str, Any] = {}, + params_description: str = "") -> dict[str, Any]: + llm_query = "请生成修复之后的工具参数" + prompt = _env.from_string(REPAIR_PARAMS).render( + tool_name=mcp_tool.name, + tool_description=mcp_tool.description, + input_schema=schema, + current_input=current_input, + error_message=error_message, + params=params, + params_description=params_description, + ) + + json_generator = JsonGenerator( + llm_query, + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, ], schema, ) return await json_generator.generate() - async def call_tool(self, tool: MCPTool, plan_item: MCPPlanItem) -> list[dict[str, Any]]: + async def call_tool(user_sub: str, tool: MCPTool, plan_item: MCPPlanItem) -> list[dict[str, Any]]: """调用工具""" # 拿到Client - client = await MCPPool().get(tool.mcp_id, self._user_sub) + client = await MCPPool().get(tool.mcp_id, user_sub) if client is None: err = f"[MCPHost] MCP Server不合法: {tool.mcp_id}" logger.error(err) raise ValueError(err) # 填充参数 - params = await self._fill_params(tool, plan_item.instruction) + params = await MCPHost._fill_params(tool, plan_item.instruction) # 调用工具 result = await client.call_tool(tool.name, params) # 保存记忆 @@ -162,29 +124,12 @@ class MCPHost: if not isinstance(item, TextContent): logger.error("MCP结果类型不支持: %s", item) continue - processed_result.append(await self._save_memory(tool, plan_item, params, item.text)) - - return processed_result - - async def get_tool_list(self, mcp_id_list: list[str]) -> list[MCPTool]: - """获取工具列表""" - mongo = MongoDB() - mcp_collection = mongo.get_collection("mcp") - - # 获取工具列表 - tool_list = [] - for mcp_id in mcp_id_list: - # 检查用户是否启用了这个mcp - mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": self._user_sub}) - if not mcp_db_result: - logger.warning("用户 %s 未启用MCP %s", self._user_sub, mcp_id) - continue - # 获取MCP工具配置 + result = item.text try: - for tool in mcp_db_result["tools"]: - tool_list.extend([MCPTool.model_validate(tool)]) - except KeyError: - logger.warning("用户 %s 的MCP Tool %s 配置错误", self._user_sub, mcp_id) + json_result = json.loads(result) + except Exception as e: + logger.error("MCP结果解析失败: %s, 错误: %s", result, e) continue + processed_result.append(json_result) - return tool_list + return processed_result diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py index a7c1f132d..13e7a98dc 100644 --- a/apps/scheduler/mcp_agent/plan.py +++ b/apps/scheduler/mcp_agent/plan.py @@ -1,6 +1,6 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """MCP 用户目标拆解与规划""" -from typing import Any +from typing import Any, AsyncGenerator from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment @@ -57,8 +57,8 @@ class MCPPlanner: result += chunk # 保存token用量 - self.input_tokens = self.resoning_llm.input_tokens - self.output_tokens = self.resoning_llm.output_tokens + self.input_tokens += self.resoning_llm.input_tokens + self.output_tokens += self.resoning_llm.output_tokens return result async def _parse_result(self, result: str, schema: dict[str, Any]) -> str: @@ -171,7 +171,8 @@ class MCPPlanner: """获取推理大模型的风险评估结果""" template = self._env.from_string(RISK_EVALUATE) prompt = template.render( - tool=tool, + tool_name=tool.name, + tool_description=tool.description, input_param=input_param, additional_info=additional_info, ) @@ -194,7 +195,8 @@ class MCPPlanner: schema_with_null = slot.add_null_to_basic_types() template = self._env.from_string(GET_MISSING_PARAMS) prompt = template.render( - tool=tool, + tool_name=tool.name, + tool_description=tool.description, input_param=input_param, schema=schema_with_null, error_message=error_message, @@ -204,7 +206,7 @@ class MCPPlanner: input_param_with_null = await self._parse_result(result, schema_with_null) return input_param_with_null - async def generate_answer(self, plan: MCPPlan, memory: str) -> str: + async def generate_answer(self, plan: MCPPlan, memory: str) -> AsyncGenerator[str, None]: """生成最终回答""" template = self._env.from_string(FINAL_ANSWER) prompt = template.render( @@ -213,16 +215,12 @@ class MCPPlanner: goal=self.user_goal, ) - llm = ReasoningLLM() - result = "" - async for chunk in llm.call( + async for chunk in self.resoning_llm.call( [{"role": "user", "content": prompt}], streaming=False, temperature=0.07, ): - result += chunk + yield chunk - self.input_tokens = llm.input_tokens - self.output_tokens = llm.output_tokens - - return result + self.input_tokens = self.resoning_llm.input_tokens + self.output_tokens = self.resoning_llm.output_tokens diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py index cf05afc87..9cbc2f5bc 100644 --- a/apps/scheduler/mcp_agent/prompt.py +++ b/apps/scheduler/mcp_agent/prompt.py @@ -421,8 +421,8 @@ RISK_EVALUATE = dedent(r""" ``` # 工具 - {{ tool.name }} - {{ tool.description }} + {{ tool_name }} + {{ tool_description }} # 工具入参 {{ input_param }} @@ -464,7 +464,7 @@ JUDGE_NEXT_STEP = dedent(r""" { "content": "任务执行完成,端口扫描结果为Result[2]", "tool": "Final", - "instruction": "" + "instruction": "" } ]} ## 当前使用的工具 @@ -489,8 +489,8 @@ JUDGE_NEXT_STEP = dedent(r""" {{ current_plan }} # 当前使用的工具 - {{ tool.name }} - {{ tool.description }} + {{ tool_name }} + {{ tool_description }} # 工具入参 {{ input_param }} @@ -571,8 +571,8 @@ GET_MISSING_PARAMS = dedent(r""" ``` # 工具 < tool > - < name > {{tool.name}} < /name > - < description > {{tool.description}} < /description > + < name > {{tool_name}} < /name > + < description > {{tool_description}} < /description > < / tool > # 工具入参 {{input_param}} @@ -583,32 +583,107 @@ GET_MISSING_PARAMS = dedent(r""" # 输出 """ ) +REPAIR_PARAMS = dedent(r""" + 你是一个工具参数修复器。 + 你的任务是根据当前的工具信息、工具入参的schema、工具当前的入参、工具的报错、补充的参数和补充的参数描述,修复当前工具的入参。 + + # 样例 + ## 工具信息 + + mysql_analyzer + 分析MySQL数据库性能 + + ## 工具入参的schema + { + "type": "object", + "properties": { + "host": { + "type": "string", + "description": "MySQL数据库的主机地址" + }, + "port": { + "type": "integer", + "description": "MySQL数据库的端口号" + }, + "username": { + "type": "string", + "description": "MySQL数据库的用户名" + }, + "password": { + "type": "string", + "description": "MySQL数据库的密码" + } + }, + "required": ["host", "port", "username", "password"] + } + ## 工具当前的入参 + { + "host": "192.0.0.1", + "port": 3306, + "username": "root", + "password": "password" + } + ## 工具的报错 + 执行端口扫描命令时,出现了错误:`password is not correct`。 + ## 补充的参数 + { + "username": "admin", + "password": "admin123" + } + ## 补充的参数描述 + 用户希望使用admin用户和admin123密码来连接MySQL数据库。 + # 输出 + ```json + { + "host": "192.0.0.1", + "port": 3306, + "username": "admin", + "password": "admin123" + } + ``` + # 工具 + + {{tool_name}} + {{tool_description}} + + # 工具入参scheme + {{input_schema}} + # 工具入参 + {{input_param}} + # 运行报错 + {{error_message}} + # 补充的参数 + {{params}} + # 补充的参数描述 + {{params_description}} + # 输出 + """ + ) FINAL_ANSWER = dedent(r""" 综合理解计划执行结果和背景信息,向用户报告目标的完成情况。 # 用户目标 - {{goal}} + {{ goal }} # 计划执行情况 为了完成上述目标,你实施了以下计划: - {{memory}} + {{ memory }} # 其他背景信息: - {{status}} + {{ status }} # 现在,请根据以上信息,向用户报告目标的完成情况: """) - MEMORY_TEMPLATE = dedent(r""" - { % for ctx in context_list % } + {% for ctx in context_list % } - 第{{loop.index}}步:{{ctx.step_description}} - 调用工具 `{{ctx.step_id}}`,并提供参数 `{{ctx.input_data}}` - 执行状态:{{ctx.status}} - 得到数据:`{{ctx.output_data}}` - { % endfor % } + {% endfor % } """) diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index 72f64a710..3b26f42f2 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -199,21 +199,21 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: # 检查是否存在group_id if not await RecordManager.check_group_id(task.ids.group_id, user_sub): - record_group = await RecordManager.create_record_group( - task.ids.group_id, user_sub, post_body.conversation_id, task.id, + record_group_id = await RecordManager.create_record_group( + task.ids.group_id, user_sub, post_body.conversation_id ) - if not record_group: + if not record_group_id: logger.error("[Scheduler] 创建问答组失败") return else: - record_group = task.ids.group_id + record_group_id = task.ids.group_id # 修改文件状态 - await DocumentManager.change_doc_status(user_sub, post_body.conversation_id, record_group) + await DocumentManager.change_doc_status(user_sub, post_body.conversation_id, record_group_id) # 保存Record - await RecordManager.insert_record_data_into_record_group(user_sub, record_group, record) + await RecordManager.insert_record_data_into_record_group(user_sub, record_group_id, record) # 保存与答案关联的文件 - await DocumentManager.save_answer_doc(user_sub, record_group, used_docs) + await DocumentManager.save_answer_doc(user_sub, record_group_id, used_docs) if post_body.app and post_body.app.app_id: # 更新最近使用的应用 @@ -223,5 +223,4 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: if not task.state or task.state.flow_status == StepStatus.SUCCESS or task.state.flow_status == StepStatus.ERROR or task.state.flow_status == StepStatus.CANCELLED: await TaskManager.delete_task_by_task_id(task.id) else: - # 更新Task await TaskManager.save_task(task.id, task) diff --git a/apps/schemas/message.py b/apps/schemas/message.py index 1e58fb1b7..e73413242 100644 --- a/apps/schemas/message.py +++ b/apps/schemas/message.py @@ -5,7 +5,7 @@ from typing import Any from datetime import UTC, datetime from pydantic import BaseModel, Field -from apps.schemas.enum_var import EventType, StepStatus +from apps.schemas.enum_var import EventType, FlowStatus, StepStatus from apps.schemas.record import RecordMetadata @@ -28,6 +28,8 @@ class MessageFlow(BaseModel): app_id: str = Field(description="插件ID", alias="appId") flow_id: str = Field(description="Flow ID", alias="flowId") + flow_name: str = Field(description="Flow名称", alias="flowName") + flow_status: FlowStatus = Field(description="Flow状态", alias="flowStatus", default=FlowStatus.UNKNOWN) step_id: str = Field(description="当前步骤ID", alias="stepId") step_name: str = Field(description="当前步骤名称", alias="stepName") sub_step_id: str | None = Field(description="当前子步骤ID", alias="subStepId", default=None) diff --git a/apps/schemas/record.py b/apps/schemas/record.py index d32227625..6a394375b 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -44,6 +44,7 @@ class RecordFlow(BaseModel): id: str record_id: str = Field(alias="recordId") flow_id: str = Field(alias="flowId") + flow_name: str = Field(alias="flowName", default="") flow_status: StepStatus = Field(alias="flowStatus", default=StepStatus.SUCCESS) step_num: int = Field(alias="stepNum") steps: list[RecordFlowStep] @@ -129,6 +130,7 @@ class Record(RecordData): user_sub: str key: dict[str, Any] = {} + task_id: str content: str comment: RecordComment = Field(default=RecordComment()) flow: FlowHistory = Field( @@ -149,5 +151,4 @@ class RecordGroup(BaseModel): records: list[Record] = [] docs: list[RecordGroupDocument] = [] # 问题不变,所用到的文档不变 conversation_id: str - task_id: str created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index 7bb5ac767..8719c2e98 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -47,6 +47,7 @@ class RequestData(BaseModel): files: list[str] = Field(default=[], description="文件列表") app: RequestDataApp | None = Field(default=None, description="应用") debug: bool = Field(default=False, description="是否调试") + task_id: str | None = Field(default=None, alias="taskId", description="任务ID") new_task: bool = Field(default=True, description="是否新建任务") diff --git a/apps/schemas/task.py b/apps/schemas/task.py index b08da0a8d..eccc95a56 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -46,11 +46,12 @@ class ExecutorState(BaseModel): step_name: str = Field(description="当前步骤名称", default="") step_status: StepStatus = Field(description="当前步骤状态", default=StepStatus.UNKNOWN) step_description: str = Field(description="当前步骤描述", default="") - retry_times: int = Field(description="当前步骤重试次数", default=0) - error_message: str = Field(description="当前步骤错误信息", default="") app_id: str = Field(description="应用ID", default="") - slot: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={}) + current_input: dict[str, Any] = Field(description="当前输入数据", default={}) + params: dict[str, Any] = Field(description="补充的参数", default={}) + params_description: str = Field(description="补充的参数描述", default="") error_info: str = Field(description="错误信息", default="") + retry_times: int = Field(description="当前步骤重试次数", default=0) class TaskIds(BaseModel): diff --git a/apps/services/conversation.py b/apps/services/conversation.py index bac964db1..6bacb7272 100644 --- a/apps/services/conversation.py +++ b/apps/services/conversation.py @@ -140,4 +140,4 @@ class ConversationManager: await record_group_collection.delete_many({"conversation_id": conversation_id}, session=session) await session.commit_transaction() - await TaskManager.delete_tasks_by_conversation_id(conversation_id) + await TaskManager.delete_tasks_and_flow_context_by_conversation_id(conversation_id) diff --git a/apps/services/record.py b/apps/services/record.py index e1925ef6c..6b61f91ec 100644 --- a/apps/services/record.py +++ b/apps/services/record.py @@ -9,7 +9,7 @@ from apps.schemas.record import ( Record, RecordGroup, ) - +from apps.schemas.enum_var import FlowStatus logger = logging.getLogger(__name__) @@ -17,7 +17,7 @@ class RecordManager: """问答对相关操作""" @staticmethod - async def create_record_group(group_id: str, user_sub: str, conversation_id: str, task_id: str) -> str | None: + async def create_record_group(group_id: str, user_sub: str, conversation_id: str) -> str | None: """创建问答组""" mongo = MongoDB() record_group_collection = mongo.get_collection("record_group") @@ -26,7 +26,6 @@ class RecordManager: _id=group_id, user_sub=user_sub, conversation_id=conversation_id, - task_id=task_id, ) try: @@ -137,6 +136,19 @@ class RecordManager: logger.exception("[RecordManager] 查询问答组失败") return [] + @staticmethod + async def update_record_flow_status_to_cancelled_by_task_ids(task_ids: list[str]) -> None: + """更新Record关联的Flow状态""" + record_group_collection = MongoDB().get_collection("record_group") + try: + await record_group_collection.update_many( + {"records.flow.flow_id": {"$in": task_ids}, "records.flow.flow_status": {"$nin": [FlowStatus.ERROR.value, FlowStatus.SUCCESS.value]}}, + {"$set": {"records.$[elem].flow.flow_status": FlowStatus.CANCELLED}}, + array_filters=[{"elem.flow.flow_id": {"$in": task_ids}}], + ) + except Exception: + logger.exception("[RecordManager] 更新Record关联的Flow状态失败") + @staticmethod async def verify_record_in_group(group_id: str, record_id: str, user_sub: str) -> bool: """ diff --git a/apps/services/task.py b/apps/services/task.py index a8cb18482..2f75a8c3a 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -113,6 +113,28 @@ class TaskManager: else: return flow_context + @staticmethod + async def init_new_task( + cls, + user_sub: str, + session_id: str | None = None, + post_body: RequestData | None = None, + ) -> Task: + """获取任务块""" + return Task( + _id=str(uuid.uuid4()), + ids=TaskIds( + user_sub=user_sub if user_sub else "", + session_id=session_id if session_id else "", + conversation_id=post_body.conversation_id, + group_id=post_body.group_id if post_body.group_id else "", + ), + question=post_body.question if post_body else "", + group_id=post_body.group_id if post_body else "", + tokens=TaskTokens(), + runtime=TaskRuntime(), + ) + @staticmethod async def save_flow_context(task_id: str, flow_context: list[FlowStepHistory]) -> None: """保存flow信息到flow_context""" @@ -145,7 +167,25 @@ class TaskManager: await task_collection.delete_one({"_id": task_id}) @staticmethod - async def delete_tasks_by_conversation_id(conversation_id: str) -> None: + async def delete_tasks_by_conversation_id(conversation_id: str) -> list[str]: + """通过ConversationID删除Task信息""" + mongo = MongoDB() + task_collection = mongo.get_collection("task") + task_ids = [] + try: + async for task in task_collection.find( + {"conversation_id": conversation_id}, + {"_id": 1}, + ): + task_ids.append(task["_id"]) + if task_ids: + await task_collection.delete_many({"conversation_id": conversation_id}) + except Exception: + logger.exception("[TaskManager] 删除ConversationID的Task信息失败") + return [] + + @staticmethod + async def delete_tasks_and_flow_context_by_conversation_id(conversation_id: str) -> None: """通过ConversationID删除Task信息""" mongo = MongoDB() task_collection = mongo.get_collection("task") @@ -162,49 +202,6 @@ class TaskManager: await task_collection.delete_many({"conversation_id": conversation_id}, session=session) await flow_context_collection.delete_many({"task_id": {"$in": task_ids}}, session=session) - @classmethod - async def get_task( - cls, - task_id: str | None = None, - session_id: str | None = None, - post_body: RequestData | None = None, - user_sub: str | None = None, - ) -> Task: - """获取任务块""" - if task_id: - try: - task = await cls.get_task_by_task_id(task_id) - if task: - return task - except Exception: - logger.exception("[TaskManager] 通过task_id获取任务失败") - - logger.info("[TaskManager] 未提供task_id,通过session_id获取任务") - if not session_id or not post_body: - err = ( - "session_id 和 conversation_id 或 group_id 和 conversation_id 是恢复/创建任务的必要条件。" - ) - raise ValueError(err) - - if post_body.group_id: - task = await cls.get_task_by_group_id(post_body.group_id, post_body.conversation_id) - else: - task = await cls.get_task_by_conversation_id(post_body.conversation_id) - - if task: - return task - return Task( - _id=str(uuid.uuid4()), - ids=TaskIds( - user_sub=user_sub if user_sub else "", - session_id=session_id if session_id else "", - conversation_id=post_body.conversation_id, - group_id=post_body.group_id if post_body.group_id else "", - ), - tokens=TaskTokens(), - runtime=TaskRuntime(), - ) - @classmethod async def save_task(cls, task_id: str, task: Task) -> None: """保存任务块""" -- Gitee