From 4fb484240da47956c4f9e1572da4dac8b099f4e0 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Wed, 6 Aug 2025 16:50:23 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9B=9E=E5=90=88Agent=E6=94=B9=E5=8A=A8?= =?UTF-8?q?=E8=87=B3=20!630=20(=E9=99=A4=20!624=20=E5=A4=96)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/queue.py | 2 + apps/models/user.py | 2 + apps/routers/appcenter.py | 14 +- apps/routers/auth.py | 1 + apps/routers/chat.py | 29 +- apps/routers/mcp_service.py | 3 + apps/routers/record.py | 6 +- apps/routers/user.py | 17 + apps/scheduler/executor/agent.py | 431 +++++++++++++++++++++++++- apps/scheduler/executor/base.py | 4 +- apps/scheduler/executor/flow.py | 4 +- apps/scheduler/executor/step.py | 2 +- apps/scheduler/mcp/host.py | 4 +- apps/scheduler/mcp_agent/host.py | 167 ++++------ apps/scheduler/mcp_agent/plan.py | 234 +++++++++----- apps/scheduler/mcp_agent/prompt.py | 397 ++++++++++++++++++------ apps/scheduler/mcp_agent/select.py | 253 +++++---------- apps/scheduler/scheduler/context.py | 2 +- apps/scheduler/scheduler/scheduler.py | 76 +++-- apps/schemas/appcenter.py | 10 +- apps/schemas/enum_var.py | 7 +- apps/schemas/mcp.py | 33 ++ apps/schemas/message.py | 16 +- apps/schemas/record.py | 6 +- apps/schemas/request_data.py | 11 +- apps/schemas/response_data.py | 1 + apps/schemas/task.py | 13 +- apps/services/conversation.py | 2 +- apps/services/mcp_service.py | 6 +- apps/services/record.py | 15 + apps/services/task.py | 105 +++---- 31 files changed, 1286 insertions(+), 587 deletions(-) diff --git a/apps/common/queue.py b/apps/common/queue.py index 96a87073..cba88fe5 100644 --- a/apps/common/queue.py +++ b/apps/common/queue.py @@ -56,6 +56,8 @@ class MessageQueue: flow = MessageFlow( appId=task.state.app_id, flowId=task.state.flow_id, + flowName=task.state.flow_name, + flowStatus=task.state.flow_status, stepId=task.state.step_id, stepName=task.state.step_name, stepStatus=task.state.step_status, diff --git a/apps/models/user.py b/apps/models/user.py index 185d1184..56b352ac 100644 --- a/apps/models/user.py +++ b/apps/models/user.py @@ -38,6 +38,8 @@ class User(Base): """用户选择的知识库的ID""" defaultLLM: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), default=None, nullable=True) # noqa: N815 """用户选择的大模型ID""" + autoExecute: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) # noqa: N815 + """Agent是否自动执行""" class UserFavoriteType(str, enum.Enum): diff --git a/apps/routers/appcenter.py b/apps/routers/appcenter.py index 192e3eb0..d0893a08 100644 --- a/apps/routers/appcenter.py +++ b/apps/routers/appcenter.py @@ -10,7 +10,7 @@ from fastapi.responses import JSONResponse from apps.dependency.user import verify_personal_token, verify_session from apps.exceptions import InstancePermissionError -from apps.schemas.appcenter import AppFlowInfo, AppPermissionData +from apps.schemas.appcenter import AppFlowInfo, AppMcpServiceInfo, AppPermissionData from apps.schemas.enum_var import AppFilterType, AppType from apps.schemas.request_data import ChangeFavouriteAppRequest, CreateAppRequest from apps.schemas.response_data import ( @@ -26,6 +26,7 @@ from apps.schemas.response_data import ( ResponseData, ) from apps.services.appcenter import AppCenterManager +from apps.services.mcp_service import MCPServiceManager logger = logging.getLogger(__name__) router = APIRouter( @@ -218,6 +219,15 @@ async def get_application(appId: Annotated[uuid.UUID, Path()]) -> JSONResponse: ) for flow in app_data.flows ] + mcp_service = [] + if app_data.mcp_service: + for service in app_data.mcp_service: + mcp_collection = await MCPServiceManager.get_mcp_service(service) + mcp_service.append(AppMcpServiceInfo( + id=mcp_collection.id, + name=mcp_collection.name, + description=mcp_collection.description, + )) return JSONResponse( status_code=status.HTTP_200_OK, content=GetAppPropertyRsp( @@ -238,7 +248,7 @@ async def get_application(appId: Annotated[uuid.UUID, Path()]) -> JSONResponse: authorizedUsers=app_data.permission.users, ), workflows=workflows, - mcpService=app_data.mcp_service, + mcpService=mcp_service, ), ).model_dump(exclude_none=True, by_alias=True), ) diff --git a/apps/routers/auth.py b/apps/routers/auth.py index 1d5cc516..10b895b9 100644 --- a/apps/routers/auth.py +++ b/apps/routers/auth.py @@ -164,6 +164,7 @@ async def userinfo(request: Request) -> JSONResponse: user_sub=request.state.user_sub, revision=user.isActive, is_admin=is_admin, + auto_execute=user.autoExecute, ), ).model_dump(exclude_none=True, by_alias=True), ) diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 840c2fb9..860e71f2 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -19,7 +19,9 @@ from apps.schemas.response_data import ResponseData from apps.schemas.task import Task from apps.services.activity import Activity from apps.services.blacklist import QuestionBlacklistManager, UserBlacklistManager +from apps.services.conversation import ConversationManager from apps.services.flow import FlowManager +from apps.services.record import RecordManager from apps.services.task import TaskManager RECOMMEND_TRES = 5 @@ -36,17 +38,25 @@ router = APIRouter( async def init_task(post_body: RequestData, user_sub: str, session_id: str) -> Task: """初始化Task""" - if post_body.new_task: - # 创建或还原Task - task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub) - if task: - await TaskManager.delete_task_by_task_id(task.id) - # 创建或还原Task - task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub) # 更改信息并刷新数据库 - if post_body.new_task: + if post_body.task_id is None: + conversation = await ConversationManager.get_conversation_by_conversation_id( + user_sub=user_sub, + conversation_id=post_body.conversation_id, + ) + if not conversation: + err = "[Chat] 用户没有权限访问该对话!" + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=err) + task_ids = await TaskManager.delete_tasks_by_conversation_id(post_body.conversation_id) + await RecordManager.update_record_flow_status_to_cancelled_by_task_ids(task_ids) + task = await TaskManager.init_new_task(user_sub=user_sub, session_id=session_id, post_body=post_body) task.runtime.question = post_body.question task.ids.group_id = post_body.group_id + else: + if not post_body.task_id: + err = "[Chat] task_id 不可为空!" + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="task_id cannot be empty") + task = await TaskManager.get_task_by_conversation_id(post_body.task_id) return task @@ -70,6 +80,7 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) # 在单独Task中运行Scheduler,拉齐queue.get的时机 scheduler = Scheduler(task, queue, post_body) + logger.info(f"[Chat] 用户是否活跃: {await Activity.is_active(user_sub)}") scheduler_task = asyncio.create_task(scheduler.run()) # 处理每一条消息 @@ -119,8 +130,8 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) @router.post("/chat") async def chat(request: Request, post_body: RequestData) -> StreamingResponse: """LLM流式对话接口""" - user_sub = request.state.user_sub session_id = request.state.session_id + user_sub = request.state.user_sub # 问题黑名单检测 if not await QuestionBlacklistManager.check_blacklisted_questions(input_question=post_body.question): # 用户扣分 diff --git a/apps/routers/mcp_service.py b/apps/routers/mcp_service.py index dd0c0765..3007d560 100644 --- a/apps/routers/mcp_service.py +++ b/apps/routers/mcp_service.py @@ -54,6 +54,8 @@ async def get_mcpservice_list( searchType: SearchType = SearchType.ALL, # noqa: N803 keyword: str | None = None, page: Annotated[int, Query(ge=1)] = 1, + *, + is_active: bool | None = None, ) -> JSONResponse: """获取服务列表""" user_sub = request.state.user_sub @@ -63,6 +65,7 @@ async def get_mcpservice_list( user_sub, keyword, page, + is_active=is_active, ) except Exception as e: err = f"[MCPServiceCenter] 获取MCP服务列表失败: {e}" diff --git a/apps/routers/record.py b/apps/routers/record.py index 02c629d0..81fa3003 100644 --- a/apps/routers/record.py +++ b/apps/routers/record.py @@ -68,7 +68,7 @@ async def get_record(request: Request, conversationId: Annotated[uuid.UUID, Path tmp_record = RecordData( id=record.id, groupId=record_group.id, - taskId=record_group.task_id, + taskId=record.task_id, conversationId=conversationId, content=record_data, metadata=record.metadata @@ -90,13 +90,13 @@ async def get_record(request: Request, conversationId: Annotated[uuid.UUID, Path tmp_record.flow = RecordFlow( id=record.flow.flow_id, # TODO: 此处前端应该用name recordId=record.id, - flowStatus=record.flow.flow_staus, flowId=record.flow.flow_id, + flowName=record.flow.flow_name, + flowStatus=record.flow.flow_staus, stepNum=len(flow_step_list), steps=[], ) for flow_step in flow_step_list: - flow_step = FlowStepHistory.model_validate(flow_step) tmp_record.flow.steps.append( RecordFlowStep( stepId=flow_step.step_name, # TODO: 此处前端应该用name diff --git a/apps/routers/user.py b/apps/routers/user.py index da3efce9..6451c4b6 100644 --- a/apps/routers/user.py +++ b/apps/routers/user.py @@ -40,3 +40,20 @@ async def list_user( result=UserGetMsp(userInfoList=user_info_list), ).model_dump(exclude_none=True, by_alias=True), ) + + +@router.post("") +async def update_user_info(request: Request, data: UserUpdateRequest) -> JSONResponse: + """更新用户信息接口""" + # 更新用户信息 + + result = await UserManager.update_userinfo_by_user_sub(request.state.user_sub, data) + if not result: + return JSONResponse( + status_code=status.HTTP_200_OK, + content={"code": status.HTTP_200_OK, "message": "用户信息更新成功"}, + ) + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={"code": status.HTTP_404_NOT_FOUND, "message": "用户信息更新失败"}, + ) diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index cb8e183e..4a88ecb6 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -2,13 +2,35 @@ """MCP Agent执行器""" import logging +import uuid +from typing import Any from pydantic import Field +from apps.llm.patterns.rewrite import QuestionRewrite +from apps.llm.reasoning import ReasoningLLM from apps.scheduler.executor.base import BaseExecutor -from apps.scheduler.mcp_agent import host, plan, select -from apps.schemas.task import ExecutorState, StepQueueItem +from apps.scheduler.mcp_agent.host import MCPHost +from apps.scheduler.mcp_agent.plan import MCPPlanner +from apps.scheduler.mcp_agent.select import FINAL_TOOL_ID, MCPSelector +from apps.scheduler.pool.mcp.client import MCPClient +from apps.schemas.enum_var import EventType, FlowStatus, SpecialCallType, StepStatus +from apps.schemas.mcp import ( + ErrorType, + GoalEvaluationResult, + MCPCollection, + MCPPlan, + MCPTool, + RestartStepIndex, + ToolExcutionErrorType, + ToolRisk, +) +from apps.schemas.message import param +from apps.schemas.task import ExecutorState, FlowStepHistory, StepQueueItem +from apps.services.appcenter import AppCenterManager +from apps.services.mcp_service import MCPServiceManager from apps.services.task import TaskManager +from apps.services.user import UserManager logger = logging.getLogger(__name__) @@ -16,15 +38,416 @@ logger = logging.getLogger(__name__) class MCPAgentExecutor(BaseExecutor): """MCP Agent执行器""" - question: str = Field(description="用户输入") max_steps: int = Field(default=20, description="最大步数") servers_id: list[str] = Field(description="MCP server id") agent_id: str = Field(default="", description="Agent ID") agent_description: str = Field(default="", description="Agent描述") + mcp_list: list[MCPCollection] = Field(description="MCP服务器列表", default=[]) + mcp_client: dict[str, MCPClient] = Field( + description="MCP客户端列表,key为mcp_id", default={}, + ) + tools: dict[str, MCPTool] = Field( + description="MCP工具列表,key为tool_id", default={}, + ) + params: param | None = Field( + default=None, description="流执行过程中的参数补充", alias="params", + ) + resoning_llm: ReasoningLLM = Field( + default=ReasoningLLM(), + description="推理大模型", + ) + + async def update_tokens(self) -> None: + """更新令牌数""" + self.task.tokens.input_tokens = self.resoning_llm.input_tokens + self.task.tokens.output_tokens = self.resoning_llm.output_tokens + await TaskManager.save_task(self.task.id, self.task) async def load_state(self) -> None: """从数据库中加载FlowExecutor的状态""" logger.info("[FlowExecutor] 加载Executor状态") # 尝试恢复State - if self.task.state: + if self.task.state and self.task.state.flow_status != FlowStatus.INIT: self.task.context = await TaskManager.get_context_by_task_id(self.task.id) + + async def load_mcp(self) -> None: + """加载MCP服务器列表""" + logger.info("[MCPAgentExecutor] 加载MCP服务器列表") + # 获取MCP服务器列表 + app = await AppCenterManager.fetch_app_data_by_id(self.agent_id) + mcp_ids = app.mcp_service + for mcp_id in mcp_ids: + mcp_service = await MCPServiceManager.get_mcp_service(mcp_id) + if self.task.ids.user_sub not in mcp_service.activated: + logger.warning( + "[MCPAgentExecutor] 用户 %s 未启用MCP %s", + self.task.ids.user_sub, + mcp_id, + ) + continue + + self.mcp_list.append(mcp_service) + self.mcp_client[mcp_id] = await MCPHost.get_client(self.task.ids.user_sub, mcp_id) + for tool in mcp_service.tools: + self.tools[tool.id] = tool + + async def plan(self, is_replan: bool = False, start_index: int | None = None) -> None: + if is_replan: + error_message = "之前的计划遇到以下报错\n\n"+self.task.state.error_message + else: + error_message = "初始化计划" + tools = MCPSelector.select_top_tool( + self.task.runtime.question, list(self.tools.values()), + additional_info=error_message, top_n=40) + if is_replan: + logger.info("[MCPAgentExecutor] 重新规划流程") + if not start_index: + start_index = await MCPPlanner.get_replan_start_step_index(self.task.runtime.question, + self.task.state.error_message, + self.task.runtime.temporary_plans, + self.resoning_llm) + current_plan = self.task.runtime.temporary_plans.plans[start_index:] + error_message = self.task.state.error_message + temporary_plans = await MCPPlanner.create_plan(self.task.runtime.question, + is_replan=is_replan, + error_message=error_message, + current_plan=current_plan, + tool_list=tools, + max_steps=self.max_steps-start_index-1, + reasoning_llm=self.resoning_llm + ) + self.update_tokens() + self.push_message( + EventType.STEP_CANCEL, + data={} + ) + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + self.task.context[-1].step_status = StepStatus.CANCELLED + self.task.runtime.temporary_plans = self.task.runtime.temporary_plans.plans[:start_index] + temporary_plans.plans + self.task.state.step_index = start_index + else: + start_index = 0 + self.task.runtime.temporary_plans = await MCPPlanner.create_plan(self.task.runtime.question, tool_list=tools, max_steps=self.max_steps, reasoning_llm=self.resoning_llm) + for i in range(start_index, len(self.task.runtime.temporary_plans.plans)): + self.task.runtime.temporary_plans.plans[i].step_id = str(uuid.uuid4()) + + async def get_tool_input_param(self, is_first: bool) -> None: + if is_first: + # 获取第一个输入参数 + self.task.state.current_input = await MCPHost._get_first_input_params(self.tools[self.task.state.step_id], self.task.runtime.question, self.task) + else: + # 获取后续输入参数 + if isinstance(self.params, param): + params = self.params.content + params_description = self.params.description + else: + params = {} + params_description = "" + self.task.state.current_input = await MCPHost._fill_params(self.tools[self.task.state.step_id], self.task.state.current_input, self.task.state.error_message, params, params_description) + + async def reset_step_to_index(self, start_index: int) -> None: + """重置步骤到开始""" + logger.info("[MCPAgentExecutor] 重置步骤到索引 %d", start_index) + if self.task.runtime.temporary_plans: + self.task.state.flow_status = FlowStatus.RUNNING + self.task.state.step_id = self.task.runtime.temporary_plans.plans[start_index].step_id + self.task.state.step_index = 0 + self.task.state.step_name = self.task.runtime.temporary_plans.plans[start_index].tool + self.task.state.step_description = self.task.runtime.temporary_plans.plans[start_index].content + self.task.state.step_status = StepStatus.RUNNING + self.task.state.retry_times = 0 + else: + self.task.state.flow_status = FlowStatus.SUCCESS + self.task.state.step_id = FINAL_TOOL_ID + + async def confirm_before_step(self) -> None: + logger.info("[MCPAgentExecutor] 等待用户确认步骤 %d", self.task.state.step_index) + # 发送确认消息 + confirm_message = await MCPPlanner.get_tool_risk(self.tools[self.task.state.step_id], self.task.state.current_input, "", self.resoning_llm) + self.update_tokens() + self.push_message(EventType.STEP_WAITING_FOR_START, confirm_message.model_dump( + exclude_none=True, by_alias=True)) + self.push_message(EventType.FLOW_STOP, {}) + self.task.state.flow_status = FlowStatus.WAITING + self.task.state.step_status = StepStatus.WAITING + self.task.context.append( + FlowStepHistory( + step_id=self.task.state.step_id, + step_name=self.task.state.step_name, + step_description=self.task.state.step_description, + step_status=self.task.state.step_status, + flow_id=self.task.state.flow_id, + flow_name=self.task.state.flow_name, + flow_status=self.task.state.flow_status, + input_data={}, + output_data={}, + ex_data=confirm_message.model_dump(exclude_none=True, by_alias=True), + ) + ) + + async def run_step(self): + """执行步骤""" + self.task.state.flow_status = FlowStatus.RUNNING + self.task.state.step_status = StepStatus.RUNNING + logger.info("[MCPAgentExecutor] 执行步骤 %d", self.task.state.step_index) + # 获取MCP客户端 + mcp_tool = self.tools[self.task.state.step_id] + mcp_client = self.mcp_client[mcp_tool.mcp_id] + if not mcp_client: + logger.error("[MCPAgentExecutor] MCP客户端未找到: %s", mcp_tool.mcp_id) + self.task.state.flow_status = FlowStatus.ERROR + error = "[MCPAgentExecutor] MCP客户端未找到: {}".format(mcp_tool.mcp_id) + self.task.state.error_message = error + try: + output_params = await mcp_client.call_tool(mcp_tool.name, self.task.state.current_input) + self.update_tokens() + self.push_message( + EventType.STEP_INPUT, + self.task.state.current_input + ) + self.push_message( + EventType.STEP_OUTPUT, + output_params + ) + self.task.context.append( + FlowStepHistory( + step_id=self.task.state.step_id, + step_name=self.task.state.step_name, + step_description=self.task.state.step_description, + step_status=StepStatus.SUCCESS, + flow_id=self.task.state.flow_id, + flow_name=self.task.state.flow_name, + flow_status=self.task.state.flow_status, + input_data=self.task.state.current_input, + output_data=output_params, + ) + ) + self.task.state.step_status = StepStatus.SUCCESS + except Exception as e: + logger.warning("[MCPAgentExecutor] 执行步骤 %s 失败: %s", mcp_tool.name, str(e)) + import traceback + self.task.state.error_message = traceback.format_exc() + self.task.state.step_status = StepStatus.ERROR + + async def generate_params_with_null(self) -> None: + """生成参数补充""" + mcp_tool = self.tools[self.task.state.step_id] + params_with_null = await MCPPlanner.get_missing_param( + mcp_tool, + self.task.state.current_input, + self.task.state.error_message, + self.resoning_llm + ) + self.update_tokens() + self.push_message( + EventType.STEP_WAITING_FOR_PARAM, + data={ + "message": "当运行产生如下报错:\n" + self.task.state.error_message, + "params": params_with_null + } + ) + self.push_message( + EventType.FLOW_STOP, + data={} + ) + self.task.state.flow_status = FlowStatus.WAITING + self.task.state.step_status = StepStatus.PARAM + self.task.context.append( + FlowStepHistory( + step_id=self.task.state.step_id, + step_name=self.task.state.step_name, + step_description=self.task.state.step_description, + step_status=self.task.state.step_status, + flow_id=self.task.state.flow_id, + flow_name=self.task.state.flow_name, + flow_status=self.task.state.flow_status, + input_data={}, + output_data={}, + ex_data={ + "message": "当运行产生如下报错:\n" + self.task.state.error_message, + "params": params_with_null + } + ) + ) + + async def get_next_step(self) -> None: + self.task.state.step_index += 1 + if self.task.state.step_index < len(self.task.runtime.temporary_plans): + if self.task.runtime.temporary_plans.plans[self.task.state.step_index].step_id == FINAL_TOOL_ID: + # 最后一步 + self.task.state.flow_status = FlowStatus.SUCCESS + self.task.state.step_status = StepStatus.SUCCESS + self.push_message( + EventType.FLOW_SUCCESS, + data={} + ) + return + self.task.state.step_id = self.task.runtime.temporary_plans.plans[self.task.state.step_index].step_id + self.task.state.step_name = self.task.runtime.temporary_plans.plans[self.task.state.step_index].tool + self.task.state.step_description = self.task.runtime.temporary_plans.plans[self.task.state.step_index].content + self.task.state.step_status = StepStatus.INIT + self.task.state.current_input = {} + self.push_message( + EventType.STEP_INIT, + data={} + ) + else: + # 没有下一步了,结束流程 + self.task.state.flow_status = FlowStatus.SUCCESS + self.task.state.step_status = StepStatus.SUCCESS + self.push_message( + EventType.FLOW_SUCCESS, + data={} + ) + return + + async def error_handle_after_step(self) -> None: + """步骤执行失败后的错误处理""" + self.task.state.step_status = StepStatus.ERROR + self.task.state.flow_status = FlowStatus.ERROR + self.push_message( + EventType.FLOW_FAILED, + data={} + ) + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + del self.task.context[-1] + self.task.context.append( + FlowStepHistory( + step_id=self.task.state.step_id, + step_name=self.task.state.step_name, + step_description=self.task.state.step_description, + step_status=self.task.state.step_status, + flow_id=self.task.state.flow_id, + flow_name=self.task.state.flow_name, + flow_status=self.task.state.flow_status, + input_data={}, + output_data={}, + ) + ) + + async def work(self) -> None: + """执行当前步骤""" + if self.task.state.step_status == StepStatus.INIT: + await self.get_tool_input_param(is_first=True) + user_info = await UserManager.get_userinfo_by_user_sub(self.task.ids.user_sub) + if not user_info.auto_execute: + # 等待用户确认 + await self.confirm_before_step() + return + self.task.state.step_status = StepStatus.RUNNING + elif self.task.state.step_status in [StepStatus.PARAM, StepStatus.WAITING, StepStatus.RUNNING]: + if self.task.context[-1].step_status == StepStatus.PARAM: + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + del self.task.context[-1] + elif self.task.state.step_status == StepStatus.WAITING: + if self.params.content: + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + del self.task.context[-1] + else: + self.task.state.flow_status = FlowStatus.CANCELLED + self.task.state.step_status = StepStatus.CANCELLED + self.push_message( + EventType.STEP_CANCEL, + data={} + ) + self.push_message( + EventType.FLOW_CANCEL, + data={} + ) + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + self.task.context[-1].step_status = StepStatus.CANCELLED + if self.task.state.step_status == StepStatus.PARAM: + await self.get_tool_input_param(is_first=False) + max_retry = 5 + for i in range(max_retry): + if i != 0: + await self.get_tool_input_param(is_first=False) + await self.run_step() + if self.task.state.step_status == StepStatus.SUCCESS: + break + elif self.task.state.step_status == StepStatus.ERROR: + # 错误处理 + if self.task.state.retry_times >= 3: + await self.error_handle_after_step() + else: + user_info = await UserManager.get_userinfo_by_user_sub(self.task.ids.user_sub) + mcp_tool = self.tools[self.task.state.step_id] + error_type = await MCPPlanner.get_tool_execute_error_type( + self.task.runtime.question, + self.task.runtime.temporary_plans, + mcp_tool, + self.task.state.current_input, + self.task.state.error_message, + self.resoning_llm + ) + if error_type.type == ErrorType.DECORRECT_PLAN or user_info.auto_execute: + await self.plan(is_replan=True) + self.reset_step_to_index(self.task.state.step_index) + elif error_type.type == ErrorType.MISSING_PARAM: + await self.generate_params_with_null() + elif self.task.state.step_status == StepStatus.SUCCESS: + await self.get_next_step() + + async def summarize(self) -> None: + async for chunk in MCPPlanner.generate_answer( + self.task.runtime.question, + self.task.runtime.temporary_plans, + (await MCPHost.assemble_memory(self.task)), + self.resoning_llm + ): + self.push_message( + EventType.TEXT_ADD, + data=chunk + ) + self.task.runtime.answer += chunk + + async def run(self) -> None: + """执行MCP Agent的主逻辑""" + # 初始化MCP服务 + self.load_state() + self.load_mcp() + if self.task.state.flow_status == FlowStatus.INIT: + # 初始化状态 + self.task.state.flow_id = str(uuid.uuid4()) + self.task.state.flow_name = await MCPPlanner.get_flow_name(self.task.runtime.question, self.resoning_llm) + await self.plan(is_replan=False) + self.reset_step_to_index(0) + TaskManager.save_task(self.task.id, self.task) + self.task.state.flow_status = FlowStatus.RUNNING + self.push_message( + EventType.FLOW_START, + data={} + ) + try: + while self.task.state.step_index < len(self.task.runtime.temporary_plans) and \ + self.task.state.flow_status == FlowStatus.RUNNING: + await self.work() + TaskManager.save_task(self.task.id, self.task) + except Exception as e: + logger.error("[MCPAgentExecutor] 执行过程中发生错误: %s", str(e)) + self.task.state.flow_status = FlowStatus.ERROR + self.task.state.error_message = str(e) + self.task.state.step_status = StepStatus.ERROR + self.push_message( + EventType.STEP_ERROR, + data={} + ) + self.push_message( + EventType.FLOW_FAILED, + data={} + ) + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + del self.task.context[-1] + self.task.context.append( + FlowStepHistory( + step_id=self.task.state.step_id, + step_name=self.task.state.step_name, + step_description=self.task.state.step_description, + step_status=self.task.state.step_status, + flow_id=self.task.state.flow_id, + flow_name=self.task.state.flow_name, + flow_status=self.task.state.flow_status, + input_data={}, + output_data={}, + ) + ) diff --git a/apps/scheduler/executor/base.py b/apps/scheduler/executor/base.py index cf2f4e68..86e72483 100644 --- a/apps/scheduler/executor/base.py +++ b/apps/scheduler/executor/base.py @@ -49,10 +49,8 @@ class BaseExecutor(BaseModel, ABC): question=self.question, params=self.task.runtime.filled, ).model_dump(exclude_none=True, by_alias=True) - elif event_type == EventType.FLOW_STOP.value: - data = {} elif event_type == EventType.TEXT_ADD.value and isinstance(data, str): - data=TextAddContent(text=data).model_dump(exclude_none=True, by_alias=True) + data = TextAddContent(text=data).model_dump(exclude_none=True, by_alias=True) if data is None: data = {} diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index 0f4978a1..08b34ecc 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -57,7 +57,7 @@ class FlowExecutor(BaseExecutor): """从数据库中加载FlowExecutor的状态""" logger.info("[FlowExecutor] 加载Executor状态") # 尝试恢复State - if self.task.state: + if self.task.state and self.task.state.flow_status != FlowStatus.INIT: self.task.context = await TaskManager.get_context_by_task_id(self.task.id) else: # 创建ExecutorState @@ -125,7 +125,7 @@ class FlowExecutor(BaseExecutor): return [] if self.current_step.step.type == SpecialCallType.CHOICE.value: # 如果是choice节点,获取分支ID - branch_id = self.task.context[-1]["output_data"]["branch_id"] + branch_id = self.task.context[-1].output_data["branch_id"] if branch_id: next_steps = await self._find_next_id(self.task.state.step_id + "." + branch_id) logger.info("[FlowExecutor] 分支ID:%s", branch_id) diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index c9d30e54..233c8999 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -263,7 +263,7 @@ class StepExecutor(BaseExecutor): input_data=self.obj.input, output_data=output_data, ) - self.task.context.append(history.model_dump(exclude_none=True, by_alias=True)) + self.task.context.append(history) # 推送输出 await self.push_message(EventType.STEP_OUTPUT.value, output_data) diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py index 6b78e65d..443fd29d 100644 --- a/apps/scheduler/mcp/host.py +++ b/apps/scheduler/mcp/host.py @@ -65,7 +65,7 @@ class MCPHost: context_list = [] for ctx_id in self._context_list: - context = next((ctx for ctx in task.context if ctx["_id"] == ctx_id), None) + context = next((ctx for ctx in task.context if ctx.id == ctx_id), None) if not context: continue context_list.append(context) @@ -117,7 +117,7 @@ class MCPHost: logger.error("任务 %s 不存在", self._task_id) return {} self._context_list.append(context.id) - task.context.append(context.model_dump(by_alias=True, exclude_none=True)) + task.context.append(context.model_dump(exclude_none=True, by_alias=True)) await TaskManager.save_task(self._task_id, task) return output_data diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py index a8ebec7b..a56ce03b 100644 --- a/apps/scheduler/mcp_agent/host.py +++ b/apps/scheduler/mcp_agent/host.py @@ -12,118 +12,55 @@ from mcp.types import TextContent from apps.common.mongo import MongoDB from apps.llm.function import JsonGenerator from apps.scheduler.mcp.prompt import MEMORY_TEMPLATE +from apps.scheduler.mcp_agent.prompt import REPAIR_PARAMS from apps.scheduler.pool.mcp.client import MCPClient from apps.scheduler.pool.mcp.pool import MCPPool from apps.schemas.enum_var import StepStatus from apps.schemas.mcp import MCPPlanItem, MCPTool -from apps.schemas.task import FlowStepHistory +from apps.schemas.task import FlowStepHistory, Task from apps.services.task import TaskManager logger = logging.getLogger(__name__) +_env = SandboxedEnvironment( + loader=BaseLoader, + autoescape=True, + trim_blocks=True, + lstrip_blocks=True, +) + class MCPHost: """MCP宿主服务""" - def __init__(self, user_sub: str, task_id: str, runtime_id: str, runtime_name: str) -> None: - """初始化MCP宿主""" - self._user_sub = user_sub - self._task_id = task_id - # 注意:runtime在工作流中是flow_id和step_description,在Agent中可为标识Agent的id和description - self._runtime_id = runtime_id - self._runtime_name = runtime_name - self._context_list = [] - self._env = SandboxedEnvironment( - loader=BaseLoader(), - autoescape=False, - trim_blocks=True, - lstrip_blocks=True, - ) - - async def get_client(self, mcp_id: str) -> MCPClient | None: + @staticmethod + async def get_client(user_sub, mcp_id: str) -> MCPClient | None: """获取MCP客户端""" mongo = MongoDB() mcp_collection = mongo.get_collection("mcp") # 检查用户是否启用了这个mcp - mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": self._user_sub}) + mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": user_sub}) if not mcp_db_result: - logger.warning("用户 %s 未启用MCP %s", self._user_sub, mcp_id) + logger.warning("用户 %s 未启用MCP %s", user_sub, mcp_id) return None # 获取MCP配置 try: - return await MCPPool().get(mcp_id, self._user_sub) + return await MCPPool().get(mcp_id, user_sub) except KeyError: - logger.warning("用户 %s 的MCP %s 没有运行中的实例,请检查环境", self._user_sub, mcp_id) + logger.warning("用户 %s 的MCP %s 没有运行中的实例,请检查环境", user_sub, mcp_id) return None - async def assemble_memory(self) -> str: + @staticmethod + async def assemble_memory(task: Task) -> str: """组装记忆""" - task = await TaskManager.get_task_by_task_id(self._task_id) - if not task: - logger.error("任务 %s 不存在", self._task_id) - return "" - - context_list = [] - for ctx_id in self._context_list: - context = next((ctx for ctx in task.context if ctx["_id"] == ctx_id), None) - if not context: - continue - context_list.append(context) - return self._env.from_string(MEMORY_TEMPLATE).render( - context_list=context_list, - ) - - async def _save_memory( - self, - tool: MCPTool, - plan_item: MCPPlanItem, - input_data: dict[str, Any], - result: str, - ) -> dict[str, Any]: - """保存记忆""" - try: - output_data = json.loads(result) - except Exception: # noqa: BLE001 - logger.warning("[MCPHost] 得到的数据不是dict格式!尝试转换为str") - output_data = { - "message": result, - } - - if not isinstance(output_data, dict): - output_data = { - "message": result, - } - - # 创建context;注意用法 - context = FlowStepHistory( - task_id=self._task_id, - flow_id=self._runtime_id, - flow_name=self._runtime_name, - flow_status=StepStatus.SUCCESS, - step_id=tool.name, - step_name=tool.name, - # description是规划的实际内容 - step_description=plan_item.content, - step_status=StepStatus.SUCCESS, - input_data=input_data, - output_data=output_data, + return _env.from_string(MEMORY_TEMPLATE).render( + context_list=task.context, ) - # 保存到task - task = await TaskManager.get_task_by_task_id(self._task_id) - if not task: - logger.error("任务 %s 不存在", self._task_id) - return {} - self._context_list.append(context.id) - task.context.append(context.model_dump(by_alias=True, exclude_none=True)) - await TaskManager.save_task(self._task_id, task) - - return output_data - - async def _fill_params(self, schema: dict[str, Any], query: str) -> dict[str, Any]: + async def _get_first_input_params(mcp_tool: MCPTool, query: str, task: Task) -> dict[str, Any]: """填充工具参数""" # 更清晰的输入·指令,这样可以调用generate llm_query = rf""" @@ -137,23 +74,48 @@ class MCPHost: llm_query, [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": await self.assemble_memory()}, + {"role": "user", "content": await MCPHost.assemble_memory(task)}, ], - schema, + mcp_tool.input_schema, ) return await json_generator.generate() - async def call_tool(self, tool: MCPTool, plan_item: MCPPlanItem) -> list[dict[str, Any]]: + async def _fill_params(mcp_tool: MCPTool, + current_input: dict[str, Any], + error_message: str = "", params: dict[str, Any] = {}, + params_description: str = "") -> dict[str, Any]: + llm_query = "请生成修复之后的工具参数" + prompt = _env.from_string(REPAIR_PARAMS).render( + tool_name=mcp_tool.name, + tool_description=mcp_tool.description, + input_schema=mcp_tool.input_schema, + current_input=current_input, + error_message=error_message, + params=params, + params_description=params_description, + ) + + json_generator = JsonGenerator( + llm_query, + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + mcp_tool.input_schema, + ) + return await json_generator.generate() + + async def call_tool(user_sub: str, tool: MCPTool, plan_item: MCPPlanItem) -> list[dict[str, Any]]: """调用工具""" # 拿到Client - client = await MCPPool().get(tool.mcp_id, self._user_sub) + client = await MCPPool().get(tool.mcp_id, user_sub) if client is None: err = f"[MCPHost] MCP Server不合法: {tool.mcp_id}" logger.error(err) raise ValueError(err) # 填充参数 - params = await self._fill_params(tool, plan_item.instruction) + params = await MCPHost._fill_params(tool, plan_item.instruction) # 调用工具 result = await client.call_tool(tool.name, params) # 保存记忆 @@ -162,29 +124,12 @@ class MCPHost: if not isinstance(item, TextContent): logger.error("MCP结果类型不支持: %s", item) continue - processed_result.append(await self._save_memory(tool, plan_item, params, item.text)) - - return processed_result - - async def get_tool_list(self, mcp_id_list: list[str]) -> list[MCPTool]: - """获取工具列表""" - mongo = MongoDB() - mcp_collection = mongo.get_collection("mcp") - - # 获取工具列表 - tool_list = [] - for mcp_id in mcp_id_list: - # 检查用户是否启用了这个mcp - mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": self._user_sub}) - if not mcp_db_result: - logger.warning("用户 %s 未启用MCP %s", self._user_sub, mcp_id) - continue - # 获取MCP工具配置 + result = item.text try: - for tool in mcp_db_result["tools"]: - tool_list.extend([MCPTool.model_validate(tool)]) - except KeyError: - logger.warning("用户 %s 的MCP Tool %s 配置错误", self._user_sub, mcp_id) + json_result = json.loads(result) + except Exception as e: + logger.error("MCP结果解析失败: %s, 错误: %s", result, e) continue + processed_result.append(json_result) - return tool_list + return processed_result \ No newline at end of file diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py index ec90f57b..474a7010 100644 --- a/apps/scheduler/mcp_agent/plan.py +++ b/apps/scheduler/mcp_agent/plan.py @@ -1,6 +1,6 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """MCP 用户目标拆解与规划""" -from typing import Any +from typing import Any, AsyncGenerator from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment @@ -13,30 +13,26 @@ from apps.scheduler.mcp_agent.prompt import ( FINAL_ANSWER, GENERATE_FLOW_NAME, GET_MISSING_PARAMS, + GET_REPLAN_START_STEP_INDEX, RECREATE_PLAN, RISK_EVALUATE, + TOOL_EXECUTE_ERROR_TYPE_ANALYSIS, ) from apps.scheduler.slot.slot import Slot -from apps.schemas.mcp import GoalEvaluationResult, MCPPlan, MCPTool, ToolRisk +from apps.schemas.mcp import GoalEvaluationResult, MCPPlan, MCPTool, RestartStepIndex, ToolExcutionErrorType, ToolRisk + +_env = SandboxedEnvironment( + loader=BaseLoader, + autoescape=True, + trim_blocks=True, + lstrip_blocks=True, +) class MCPPlanner: """MCP 用户目标拆解与规划""" - - def __init__(self, user_goal: str, resoning_llm: ReasoningLLM = None) -> None: - """初始化MCP规划器""" - self.user_goal = user_goal - self._env = SandboxedEnvironment( - loader=BaseLoader, - autoescape=True, - trim_blocks=True, - lstrip_blocks=True, - ) - self.resoning_llm = resoning_llm or ReasoningLLM() - self.input_tokens = 0 - self.output_tokens = 0 - - async def get_resoning_result(self, prompt: str) -> str: + @staticmethod + async def get_resoning_result(prompt: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> str: """获取推理结果""" # 调用推理大模型 message = [ @@ -44,7 +40,7 @@ class MCPPlanner: {"role": "user", "content": prompt}, ] result = "" - async for chunk in self.resoning_llm.call( + async for chunk in resoning_llm.call( message, streaming=False, temperature=0.07, @@ -52,12 +48,10 @@ class MCPPlanner: ): result += chunk - # 保存token用量 - self.input_tokens = self.resoning_llm.input_tokens - self.output_tokens = self.resoning_llm.output_tokens return result - async def _parse_result(self, result: str, schema: dict[str, Any]) -> str: + @staticmethod + async def _parse_result(result: str, schema: dict[str, Any]) -> str: """解析推理结果""" json_generator = JsonGenerator( result, @@ -70,155 +64,237 @@ class MCPPlanner: json_result = await json_generator.generate() return json_result - async def evaluate_goal(self, tool_list: list[MCPTool]) -> GoalEvaluationResult: + @staticmethod + async def evaluate_goal( + tool_list: list[MCPTool], + resoning_llm: ReasoningLLM = ReasoningLLM()) -> GoalEvaluationResult: """评估用户目标的可行性""" # 获取推理结果 - result = await self._get_reasoning_evaluation(tool_list) + result = await MCPPlanner._get_reasoning_evaluation(tool_list, resoning_llm) # 解析为结构化数据 - evaluation = await self._parse_evaluation_result(result) + evaluation = await MCPPlanner._parse_evaluation_result(result) # 返回评估结果 return evaluation - async def _get_reasoning_evaluation(self, tool_list: list[MCPTool]) -> str: + @staticmethod + async def _get_reasoning_evaluation( + goal, tool_list: list[MCPTool], + resoning_llm: ReasoningLLM = ReasoningLLM()) -> str: """获取推理大模型的评估结果""" - template = self._env.from_string(EVALUATE_GOAL) + template = _env.from_string(EVALUATE_GOAL) prompt = template.render( - goal=self.user_goal, + goal=goal, tools=tool_list, ) - result = await self.get_resoning_result(prompt) + result = await MCPPlanner.get_resoning_result(prompt, resoning_llm) return result - async def _parse_evaluation_result(self, result: str) -> GoalEvaluationResult: + @staticmethod + async def _parse_evaluation_result(result: str) -> GoalEvaluationResult: """将推理结果解析为结构化数据""" schema = GoalEvaluationResult.model_json_schema() - evaluation = await self._parse_result(result, schema) + evaluation = await MCPPlanner._parse_result(result, schema) # 使用GoalEvaluationResult模型解析结果 return GoalEvaluationResult.model_validate(evaluation) - async def get_flow_name(self) -> str: + async def get_flow_name(user_goal: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> str: """获取当前流程的名称""" - result = await self._get_reasoning_flow_name() + result = await MCPPlanner._get_reasoning_flow_name(user_goal, resoning_llm) return result - async def _get_reasoning_flow_name(self) -> str: + @staticmethod + async def _get_reasoning_flow_name(user_goal: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> str: """获取推理大模型的流程名称""" - template = self._env.from_string(GENERATE_FLOW_NAME) - prompt = template.render(goal=self.user_goal) - result = await self.get_resoning_result(prompt) + template = _env.from_string(GENERATE_FLOW_NAME) + prompt = template.render(goal=user_goal) + result = await MCPPlanner.get_resoning_result(prompt, resoning_llm) return result - async def create_plan(self, tool_list: list[MCPTool], max_steps: int = 6) -> MCPPlan: + @staticmethod + async def get_replan_start_step_index( + user_goal: str, error_message: str, current_plan: MCPPlan | None = None, + history: str = "", + reasoning_llm: ReasoningLLM = ReasoningLLM()) -> MCPPlan: + """获取重新规划的步骤索引""" + # 获取推理结果 + template = _env.from_string(GET_REPLAN_START_STEP_INDEX) + prompt = template.render( + goal=user_goal, + error_message=error_message, + current_plan=current_plan.model_dump(exclude_none=True, by_alias=True), + history=history, + ) + result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) + # 解析为结构化数据 + schema = RestartStepIndex.model_json_schema() + schema["properties"]["start_index"]["maximum"] = len(current_plan.plans) - 1 + schema["properties"]["start_index"]["minimum"] = 0 + restart_index = await MCPPlanner._parse_result(result, schema) + # 使用RestartStepIndex模型解析结果 + return RestartStepIndex.model_validate(restart_index) + + @staticmethod + async def create_plan( + user_goal: str, is_replan: bool = False, error_message: str = "", current_plan: MCPPlan | None = None, + tool_list: list[MCPTool] = [], + max_steps: int = 6, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> MCPPlan: """规划下一步的执行流程,并输出""" # 获取推理结果 - result = await self._get_reasoning_plan(tool_list, max_steps) + result = await MCPPlanner._get_reasoning_plan(user_goal, is_replan, error_message, current_plan, tool_list, max_steps, reasoning_llm) # 解析为结构化数据 - return await self._parse_plan_result(result, max_steps) + return await MCPPlanner._parse_plan_result(result, max_steps) + @staticmethod async def _get_reasoning_plan( - self, is_replan: bool = False, error_message: str = "", current_plan: MCPPlan = MCPPlan(), + user_goal: str, is_replan: bool = False, error_message: str = "", current_plan: MCPPlan | None = None, tool_list: list[MCPTool] = [], - max_steps: int = 10) -> str: + max_steps: int = 10, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> str: """获取推理大模型的结果""" # 格式化Prompt if is_replan: - template = self._env.from_string(RECREATE_PLAN) + template = _env.from_string(RECREATE_PLAN) prompt = template.render( - current_plan=current_plan, + current_plan=current_plan.model_dump(exclude_none=True, by_alias=True), error_message=error_message, - goal=self.user_goal, + goal=user_goal, tools=tool_list, max_num=max_steps, ) else: - template = self._env.from_string(CREATE_PLAN) + template = _env.from_string(CREATE_PLAN) prompt = template.render( - goal=self.user_goal, + goal=user_goal, tools=tool_list, max_num=max_steps, ) - result = await self.get_resoning_result(prompt) + result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) return result - async def _parse_plan_result(self, result: str, max_steps: int) -> MCPPlan: + @staticmethod + async def _parse_plan_result(result: str, max_steps: int) -> MCPPlan: """将推理结果解析为结构化数据""" # 格式化Prompt schema = MCPPlan.model_json_schema() schema["properties"]["plans"]["maxItems"] = max_steps - plan = await self._parse_result(result, schema) + plan = await MCPPlanner._parse_result(result, schema) # 使用Function模型解析结果 return MCPPlan.model_validate(plan) - async def get_tool_risk(self, tool: MCPTool, input_parm: dict[str, Any], additional_info: str = "") -> ToolRisk: + @staticmethod + async def get_tool_risk( + tool: MCPTool, input_parm: dict[str, Any], + additional_info: str = "", resoning_llm: ReasoningLLM = ReasoningLLM()) -> ToolRisk: """获取MCP工具的风险评估结果""" # 获取推理结果 - result = await self._get_reasoning_risk(tool, input_parm, additional_info) + result = await MCPPlanner._get_reasoning_risk(tool, input_parm, additional_info, resoning_llm) # 解析为结构化数据 - risk = await self._parse_risk_result(result) + risk = await MCPPlanner._parse_risk_result(result) # 返回风险评估结果 return risk - async def _get_reasoning_risk(self, tool: MCPTool, input_param: dict[str, Any], additional_info: str) -> str: + @staticmethod + async def _get_reasoning_risk( + tool: MCPTool, input_param: dict[str, Any], + additional_info: str, resoning_llm: ReasoningLLM) -> str: """获取推理大模型的风险评估结果""" - template = self._env.from_string(RISK_EVALUATE) + template = _env.from_string(RISK_EVALUATE) prompt = template.render( - tool=tool, + tool_name=tool.name, + tool_description=tool.description, input_param=input_param, additional_info=additional_info, ) - result = await self.get_resoning_result(prompt) + result = await MCPPlanner.get_resoning_result(prompt, resoning_llm) return result - async def _parse_risk_result(self, result: str) -> ToolRisk: + @staticmethod + async def _parse_risk_result(result: str) -> ToolRisk: """将推理结果解析为结构化数据""" schema = ToolRisk.model_json_schema() - risk = await self._parse_result(result, schema) + risk = await MCPPlanner._parse_result(result, schema) # 使用ToolRisk模型解析结果 return ToolRisk.model_validate(risk) + @staticmethod + async def _get_reasoning_tool_execute_error_type( + user_goal: str, current_plan: MCPPlan, + tool: MCPTool, input_param: dict[str, Any], + error_message: str, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> str: + """获取推理大模型的工具执行错误类型""" + template = _env.from_string(TOOL_EXECUTE_ERROR_TYPE_ANALYSIS) + prompt = template.render( + goal=user_goal, + current_plan=current_plan.model_dump(exclude_none=True, by_alias=True), + tool_name=tool.name, + tool_description=tool.description, + input_param=input_param, + error_message=error_message, + ) + result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) + return result + + @staticmethod + async def _parse_tool_execute_error_type_result(result: str) -> ToolExcutionErrorType: + """将推理结果解析为工具执行错误类型""" + schema = ToolExcutionErrorType.model_json_schema() + error_type = await MCPPlanner._parse_result(result, schema) + # 使用ToolExcutionErrorType模型解析结果 + return ToolExcutionErrorType.model_validate(error_type) + + @staticmethod + async def get_tool_execute_error_type( + user_goal: str, current_plan: MCPPlan, + tool: MCPTool, input_param: dict[str, Any], + error_message: str, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> ToolExcutionErrorType: + """获取MCP工具执行错误类型""" + # 获取推理结果 + result = await MCPPlanner._get_reasoning_tool_execute_error_type( + user_goal, current_plan, tool, input_param, error_message, reasoning_llm) + error_type = await MCPPlanner._parse_tool_execute_error_type_result(result) + # 返回工具执行错误类型 + return error_type + + @staticmethod async def get_missing_param( - self, tool: MCPTool, schema: dict[str, Any], + tool: MCPTool, input_param: dict[str, Any], - error_message: str) -> list[str]: + error_message: str, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> list[str]: """获取缺失的参数""" - slot = Slot(schema=schema) + slot = Slot(schema=tool.input_schema) + template = _env.from_string(GET_MISSING_PARAMS) schema_with_null = slot.add_null_to_basic_types() - template = self._env.from_string(GET_MISSING_PARAMS) prompt = template.render( - tool=tool, + tool_name=tool.name, + tool_description=tool.description, input_param=input_param, schema=schema_with_null, error_message=error_message, ) - result = await self.get_resoning_result(prompt) + result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) # 解析为结构化数据 - input_param_with_null = await self._parse_result(result, schema_with_null) + input_param_with_null = await MCPPlanner._parse_result(result, schema_with_null) return input_param_with_null - async def generate_answer(self, plan: MCPPlan, memory: str) -> str: + @staticmethod + async def generate_answer( + user_goal: str, plan: MCPPlan, memory: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> AsyncGenerator[ + str, None]: """生成最终回答""" - template = self._env.from_string(FINAL_ANSWER) + template = _env.from_string(FINAL_ANSWER) prompt = template.render( - plan=plan, + plan=plan.model_dump(exclude_none=True, by_alias=True), memory=memory, - goal=self.user_goal, + goal=user_goal, ) - llm = ReasoningLLM() - result = "" - async for chunk in llm.call( + async for chunk in resoning_llm.call( [{"role": "user", "content": prompt}], streaming=False, temperature=0.07, ): - result += chunk - - self.input_tokens = llm.input_tokens - self.output_tokens = llm.output_tokens - - return result + yield chunk \ No newline at end of file diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py index cf05afc8..b5bc085c 100644 --- a/apps/scheduler/mcp_agent/prompt.py +++ b/apps/scheduler/mcp_agent/prompt.py @@ -62,6 +62,62 @@ MCP_SELECT = dedent(r""" ### 请一步一步思考: """) +TOOL_SELECT = dedent(r""" + 你是一个乐于助人的智能助手。 + 你的任务是:根据当前目标,附加信息,选择最合适的MCP工具。 + ## 选择MCP工具时的注意事项: + 1. 确保充分理解当前目标,选择实现目标所需的MCP工具。 + 2. 请在给定的MCP工具列表中选择,不要自己生成MCP工具。 + 3. 可以选择一些辅助工具,但必须确保这些工具与当前目标相关。 + 必须按照以下格式生成选择结果,不要输出任何其他内容: + ```json + { + "tool_ids": ["工具ID1", "工具ID2", ...] + } + ``` + + # 示例 + ## 目标 + 调优mysql性能 + ## MCP工具列表 + + - mcp_tool_1 MySQL链接池工具;用于优化MySQL链接池 + - mcp_tool_2 MySQL性能调优工具;用于分析MySQL性能瓶颈 + - mcp_tool_3 MySQL查询优化工具;用于优化MySQL查询语句 + - mcp_tool_4 MySQL索引优化工具;用于优化MySQL索引 + - mcp_tool_5 文件存储工具;用于存储文件 + - mcp_tool_6 mongoDB工具;用于操作MongoDB数据库 + + ## 附加信息 + 1. 当前MySQL数据库的版本是8.0.26 + 2. 当前MySQL数据库的配置文件路径是/etc/my.cnf,并含有以下配置项 + ```json + { + "max_connections": 1000, + "innodb_buffer_pool_size": "1G", + "query_cache_size": "64M" + } + ##输出 + ```json + { + "tool_ids": ["mcp_tool_1", "mcp_tool_2", "mcp_tool_3", "mcp_tool_4"] + } + ``` + # 现在开始! + ## 目标 + {{goal}} + ## MCP工具列表 + + {% for tool in tools %} + - {{tool.id}} {{tool.name}};{{tool.description}} + {% endfor %} + + ## 附加信息 + {{additional_info}} + # 输出 + """ + ) + EVALUATE_GOAL = dedent(r""" 你是一个计划评估器。 请根据用户的目标和当前的工具集合以及一些附加信息,判断基于当前的工具集合,是否能够完成用户的目标。 @@ -76,18 +132,18 @@ EVALUATE_GOAL = dedent(r""" ``` # 样例 - ## 目标 - 我需要扫描当前mysql数据库,分析性能瓶颈,并调优 + # 目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 - ## 工具集合 - 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + # 工具集合 + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 - - mysql_analyzer分析MySQL数据库性能 - - performance_tuner调优数据库性能 - - Final结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。 + - mysql_analyzer 分析MySQL数据库性能 + - performance_tuner 调优数据库性能 + - Final 结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。 - ## 附加信息 + # 附加信息 1. 当前MySQL数据库的版本是8.0.26 2. 当前MySQL数据库的配置文件路径是/etc/my.cnf @@ -100,17 +156,17 @@ EVALUATE_GOAL = dedent(r""" ``` # 目标 - {{ goal }} + {{goal}} # 工具集合 - {% for tool in tools %} - - {{ tool.id }}{{tool.name}};{{ tool.description }} - {% endfor %} + { % for tool in tools % } + - {{tool.id}} {{tool.name}};{{tool.description}} + { % endfor % } # 附加信息 - {{ additional_info }} + {{additional_info}} """) GENERATE_FLOW_NAME = dedent(r""" @@ -123,15 +179,79 @@ GENERATE_FLOW_NAME = dedent(r""" 4. 流程名称应该尽量简短,小于20个字或者单词。 5. 只输出流程名称,不要输出其他内容。 # 样例 - ## 目标 - 我需要扫描当前mysql数据库,分析性能瓶颈,并调优 - ## 输出 + # 目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 + # 输出 扫描MySQL数据库并分析性能瓶颈,进行调优 # 现在开始生成流程名称: # 目标 - {{ goal }} + {{goal}} # 输出 """) +GET_REPLAN_START_STEP_INDEX = dedent(r""" + 你是一个智能助手,你的任务是根据用户的目标、报错信息和当前计划和历史,获取重新规划的步骤起始索引。 + + # 样例 + # 目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 + # 报错信息 + 执行端口扫描命令时,出现了错误:`- bash: curl: command not found`。 + # 当前计划 + ```json + { + "plans": [ + { + "step_id": "step_1", + "content": "生成端口扫描命令", + "tool": "command_generator", + "instruction": "生成端口扫描命令:扫描 + }, + { + "step_id": "step_2", + "content": "在执行Result[0]生成的命令", + "tool": "command_executor", + "instruction": "执行端口扫描命令" + } + ] + } + # 历史 + [ + { + id: "0", + task_id: "task_1", + flow_id: "flow_1", + flow_name: "MYSQL性能调优", + flow_status: "RUNNING", + step_id: "step_1", + step_name: "生成端口扫描命令", + step_description: "生成端口扫描命令:扫描当前MySQL数据库的端口", + step_status: "FAILED", + input_data: { + "command": "nmap -p 3306 + "target": "localhost" + }, + output_data: { + "error": "- bash: curl: command not found" + } + } + ] + # 输出 + { + "start_index": 0, + "reasoning": "当前计划的第一步就失败了,报错信息显示curl命令未找到,可能是因为没有安装curl工具,因此需要从第一步重新规划。" + } + # 现在开始获取重新规划的步骤起始索引: + # 目标 + {{goal}} + # 报错信息 + {{error_message}} + # 当前计划 + {{current_plan}} + # 历史 + {{history}} + # 输出 + """) + CREATE_PLAN = dedent(r""" 你是一个计划生成器。 请分析用户的目标,并生成一个计划。你后续将根据这个计划,一步一步地完成用户的目标。 @@ -163,40 +283,38 @@ CREATE_PLAN = dedent(r""" } ``` - - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。\ -思考过程应放置在 XML标签中。 + - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。 +思考过程应放置在 XML标签中。 - 计划内容中,可以使用"Result[]"来引用之前计划步骤的结果。例如:"Result[3]"表示引用第三条计划执行后的结果。 - - 计划不得多于{{ max_num }}条,且每条计划内容应少于150字。 + - 计划不得多于{{max_num}}条,且每条计划内容应少于150字。 # 工具 - 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 - {% for tool in tools %} - - {{ tool.id }}{{tool.name}};{{ tool.description }} - {% endfor %} + { % for tool in tools % } + - {{tool.id}} {{tool.name}};{{tool.description}} + { % endfor % } # 样例 - ## 目标 + # 目标 - 在后台运行一个新的alpine:latest容器,将主机/root文件夹挂载至/data,并执行top命令。 + 在后台运行一个新的alpine: latest容器,将主机/root文件夹挂载至/data,并执行top命令。 - ## 计划 + # 计划 - 1. 这个目标需要使用Docker来完成,首先需要选择合适的MCP Server + 1. 这个目标需要使用Docker来完成, 首先需要选择合适的MCP Server 2. 目标可以拆解为以下几个部分: - - 运行alpine:latest容器 + - 运行alpine: latest容器 - 挂载主机目录 - 在后台运行 - 执行top命令 - 3. 需要先选择MCP Server,然后生成Docker命令,最后执行命令 - - - ```json + 3. 需要先选择MCP Server, 然后生成Docker命令, 最后执行命令 + ```json { "plans": [ { @@ -225,7 +343,7 @@ CREATE_PLAN = dedent(r""" # 现在开始生成计划: - ## 目标 + # 目标 {{goal}} @@ -263,26 +381,24 @@ RECREATE_PLAN = dedent(r""" } ``` - - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。\ -思考过程应放置在 XML标签中。 + - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。 +思考过程应放置在 XML标签中。 - 计划内容中,可以使用"Result[]"来引用之前计划步骤的结果。例如:"Result[3]"表示引用第三条计划执行后的结果。 - - 计划不得多于{{ max_num }}条,且每条计划内容应少于150字。 + - 计划不得多于{{max_num}}条,且每条计划内容应少于150字。 # 样例 - ## 目标 + # 目标 请帮我扫描一下192.168.1.1的这台机器的端口,看看有哪些端口开放。 - ## 工具 - 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + # 工具 + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 - - command_generator生成命令行指令 - - tool_selector选择合适的工具 - - command_executor执行命令行指令 - - Final结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。 - - - ## 当前计划 + - command_generator 生成命令行指令 + - tool_selector 选择合适的工具 + - command_executor 执行命令行指令 + - Final 结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。 + # 当前计划 ```json { "plans": [ @@ -304,25 +420,23 @@ RECREATE_PLAN = dedent(r""" ] } ``` - ## 运行报错 - 执行端口扫描命令时,出现了错误:`-bash: curl: command not found`。 - ## 重新生成的计划 + # 运行报错 + 执行端口扫描命令时,出现了错误:`- bash: curl: command not found`。 + # 重新生成的计划 - 1. 这个目标需要使用网络扫描工具来完成,首先需要选择合适的网络扫描工具 + 1. 这个目标需要使用网络扫描工具来完成, 首先需要选择合适的网络扫描工具 2. 目标可以拆解为以下几个部分: - 生成端口扫描命令 - 执行端口扫描命令 - 3.但是在执行端口扫描命令时,出现了错误:`-bash: curl: command not found`。 + 3.但是在执行端口扫描命令时,出现了错误:`- bash: curl: command not found`。 4.我将计划调整为: - 需要先生成一个命令,查看当前机器支持哪些网络扫描工具 - 执行这个命令,查看当前机器支持哪些网络扫描工具 - 然后从中选择一个网络扫描工具 - 基于选择的网络扫描工具,生成端口扫描命令 - 执行端口扫描命令 - - - ```json + ```json { "plans": [ { @@ -367,19 +481,19 @@ RECREATE_PLAN = dedent(r""" # 工具 - 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 - {% for tool in tools %} - - {{ tool.id }}{{tool.name}};{{ tool.description }} - {% endfor %} + { % for tool in tools % } + - {{tool.id}} {{tool.name}};{{tool.description}} + { % endfor % } # 当前计划 - {{ current_plan }} + {{current_plan}} # 运行报错 - {{ error_message }} + {{error_message}} # 重新生成的计划 """) @@ -393,18 +507,18 @@ RISK_EVALUATE = dedent(r""" } ``` # 样例 - ## 工具名称 + # 工具名称 mysql_analyzer - ## 工具描述 + # 工具描述 分析MySQL数据库性能 - ## 工具入参 + # 工具入参 { "host": "192.0.0.1", "port": 3306, "username": "root", "password": "password" } - ## 附加信息 + # 附加信息 1. 当前MySQL数据库的版本是8.0.26 2. 当前MySQL数据库的配置文件路径是/etc/my.cnf,并含有以下配置项 ```ini @@ -412,7 +526,7 @@ RISK_EVALUATE = dedent(r""" innodb_buffer_pool_size=1G innodb_log_file_size=256M ``` - ## 输出 + # 输出 ```json { "risk": "中", @@ -421,35 +535,35 @@ RISK_EVALUATE = dedent(r""" ``` # 工具 - {{ tool.name }} - {{ tool.description }} + {{tool_name}} + {{tool_description}} # 工具入参 - {{ input_param }} + {{input_param}} # 附加信息 - {{ additional_info }} + {{additional_info}} # 输出 """ ) # 根据当前计划和报错信息决定下一步执行,具体计划有需要用户补充工具入参、重计划当前步骤、重计划接下来的所有计划 -JUDGE_NEXT_STEP = dedent(r""" +TOOL_EXECUTE_ERROR_TYPE_ANALYSIS = dedent(r""" 你是一个计划决策器。 - 你的任务是根据当前计划、当前使用的工具、工具入参和工具运行报错,决定下一步执行的操作。 + 你的任务是根据用户目标、当前计划、当前使用的工具、工具入参和工具运行报错,决定下一步执行的操作。 请根据以下规则进行判断: - 1. 仅通过补充工具入参来解决问题的,返回 fill_params; - 2. 需要重计划当前步骤的,返回 replan_current_step; - 3. 需要重计划接下来的所有计划的,返回 replan_all_steps; + 1. 仅通过补充工具入参来解决问题的,返回 missing_param; + 2. 需要重计划当前步骤的,返回 decorrect_plan + 3.推理过程必须清晰明了,能够让人理解你的判断依据,并且不超过100字。 你的输出要以json格式返回,格式如下: ```json { - "next_step": "fill_params/replan_current_step/replan_all_steps", - "reason": "你的判断依据" + "error_type": "missing_param/decorrect_plan, + "reason": "你的推理过程" } ``` - 注意: - reason字段必须清晰明了,能够让人理解你的判断依据,并且不超过50个中文字或者100个英文单词。 # 样例 - ## 当前计划 + # 用户目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 + # 当前计划 {"plans": [ { "content": "生成端口扫描命令", @@ -464,41 +578,43 @@ JUDGE_NEXT_STEP = dedent(r""" { "content": "任务执行完成,端口扫描结果为Result[2]", "tool": "Final", - "instruction": "" + "instruction": "" } ]} - ## 当前使用的工具 + # 当前使用的工具 - command_executor - 执行命令行指令 + command_executor + 执行命令行指令 - ## 工具入参 + # 工具入参 { "command": "nmap -sS -p--open 192.168.1.1" } - ## 工具运行报错 - 执行端口扫描命令时,出现了错误:`-bash: nmap: command not found`。 - ## 输出 + # 工具运行报错 + 执行端口扫描命令时,出现了错误:`- bash: nmap: command not found`。 + # 输出 ```json { - "next_step": "replan_all_steps", - "reason": "当前工具执行报错,提示nmap命令未找到,需要增加command_generator和command_executor的步骤,生成nmap安装命令并执行,之后再生成端口扫描命令并执行。" + "error_type": "decorrect_plan", + "reason": "当前计划的第二步执行失败,报错信息显示nmap命令未找到,可能是因为没有安装nmap工具,因此需要重计划当前步骤。" } ``` + # 用户目标 + {{goal}} # 当前计划 - {{ current_plan }} + {{current_plan}} # 当前使用的工具 - {{ tool.name }} - {{ tool.description }} + {{tool_name}} + {{tool_description}} # 工具入参 - {{ input_param }} + {{input_param}} # 工具运行报错 - {{ error_message }} + {{error_message}} # 输出 """ - ) + ) # 获取缺失的参数的json结构体 GET_MISSING_PARAMS = dedent(r""" 你是一个工具参数获取器。 @@ -570,10 +686,10 @@ GET_MISSING_PARAMS = dedent(r""" } ``` # 工具 - < tool > - < name > {{tool.name}} < /name > - < description > {{tool.description}} < /description > - < / tool > + + {{tool_name}} + {{tool_description}} + # 工具入参 {{input_param}} # 工具入参schema(部分字段允许为null) @@ -583,6 +699,82 @@ GET_MISSING_PARAMS = dedent(r""" # 输出 """ ) +REPAIR_PARAMS = dedent(r""" + 你是一个工具参数修复器。 + 你的任务是根据当前的工具信息、工具入参的schema、工具当前的入参、工具的报错、补充的参数和补充的参数描述,修复当前工具的入参。 + + # 样例 + # 工具信息 + + mysql_analyzer + 分析MySQL数据库性能 + + # 工具入参的schema + { + "type": "object", + "properties": { + "host": { + "type": "string", + "description": "MySQL数据库的主机地址" + }, + "port": { + "type": "integer", + "description": "MySQL数据库的端口号" + }, + "username": { + "type": "string", + "description": "MySQL数据库的用户名" + }, + "password": { + "type": "string", + "description": "MySQL数据库的密码" + } + }, + "required": ["host", "port", "username", "password"] + } + # 工具当前的入参 + { + "host": "192.0.0.1", + "port": 3306, + "username": "root", + "password": "password" + } + # 工具的报错 + 执行端口扫描命令时,出现了错误:`password is not correct`。 + # 补充的参数 + { + "username": "admin", + "password": "admin123" + } + # 补充的参数描述 + 用户希望使用admin用户和admin123密码来连接MySQL数据库。 + # 输出 + ```json + { + "host": "192.0.0.1", + "port": 3306, + "username": "admin", + "password": "admin123" + } + ``` + # 工具 + + {{tool_name}} + {{tool_description}} + + # 工具入参scheme + {{input_schema}} + # 工具入参 + {{input_param}} + # 运行报错 + {{error_message}} + # 补充的参数 + {{params}} + # 补充的参数描述 + {{params_description}} + # 输出 + """ + ) FINAL_ANSWER = dedent(r""" 综合理解计划执行结果和背景信息,向用户报告目标的完成情况。 @@ -603,12 +795,11 @@ FINAL_ANSWER = dedent(r""" # 现在,请根据以上信息,向用户报告目标的完成情况: """) - MEMORY_TEMPLATE = dedent(r""" - { % for ctx in context_list % } + {% for ctx in context_list % } - 第{{loop.index}}步:{{ctx.step_description}} - 调用工具 `{{ctx.step_id}}`,并提供参数 `{{ctx.input_data}}` - 执行状态:{{ctx.status}} - 得到数据:`{{ctx.output_data}}` - { % endfor % } + {% endfor % } """) diff --git a/apps/scheduler/mcp_agent/select.py b/apps/scheduler/mcp_agent/select.py index 7198940b..f917a38a 100644 --- a/apps/scheduler/mcp_agent/select.py +++ b/apps/scheduler/mcp_agent/select.py @@ -2,195 +2,100 @@ """选择MCP Server及其工具""" import logging -import uuid -from collections.abc import AsyncGenerator +import random +from typing import AsyncGenerator from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment +from apps.common.config import Config from apps.common.lance import LanceDB from apps.common.mongo import MongoDB from apps.llm.embedding import Embedding from apps.llm.function import FunctionLLM from apps.llm.reasoning import ReasoningLLM -from apps.scheduler.mcp.prompt import ( - MCP_SELECT, -) -from apps.schemas.mcp import ( - MCPCollection, - MCPSelectResult, - MCPTool, -) +from apps.llm.token import TokenCalculator +from apps.scheduler.mcp_agent.prompt import TOOL_SELECT +from apps.schemas.mcp import BaseModel, MCPCollection, MCPSelectResult, MCPTool, MCPToolIdsSelectResult logger = logging.getLogger(__name__) +_env = SandboxedEnvironment( + loader=BaseLoader, + autoescape=True, + trim_blocks=True, + lstrip_blocks=True, +) -class MCPSelector: - """MCP选择器""" - - def __init__(self, resoning_llm: ReasoningLLM = None) -> None: - """初始化助手类""" - self.resoning_llm = resoning_llm or ReasoningLLM() - self.input_tokens = 0 - self.output_tokens = 0 - - @staticmethod - def _assemble_sql(mcp_list: list[str]) -> str: - """组装SQL""" - sql = "(" - for mcp_id in mcp_list: - sql += f"'{mcp_id}', " - return sql.rstrip(", ") + ")" - - - async def _get_top_mcp_by_embedding( - self, - query: str, - mcp_list: list[str], - ) -> list[dict[str, str]]: - """通过向量检索获取Top5 MCP Server""" - logger.info("[MCPHelper] 查询MCP Server向量: %s, %s", query, mcp_list) - mcp_table = await LanceDB().get_table("mcp") - query_embedding = await Embedding.get_embedding([query]) - mcp_vecs = await (await mcp_table.search( - query=query_embedding, - vector_column_name="embedding", - )).where(f"id IN {MCPSelector._assemble_sql(mcp_list)}").limit(5).to_list() - - # 拿到名称和description - logger.info("[MCPHelper] 查询MCP Server名称和描述: %s", mcp_vecs) - mcp_collection = MongoDB().get_collection("mcp") - llm_mcp_list: list[dict[str, str]] = [] - for mcp_vec in mcp_vecs: - mcp_id = mcp_vec["id"] - mcp_data = await mcp_collection.find_one({"_id": mcp_id}) - if not mcp_data: - logger.warning("[MCPHelper] 查询MCP Server名称和描述失败: %s", mcp_id) - continue - mcp_data = MCPCollection.model_validate(mcp_data) - llm_mcp_list.extend([{ - "id": mcp_id, - "name": mcp_data.name, - "description": mcp_data.description, - }]) - return llm_mcp_list - - - async def _get_mcp_by_llm( - self, - query: str, - mcp_list: list[dict[str, str]], - mcp_ids: list[str], - ) -> MCPSelectResult: - """通过LLM选择最合适的MCP Server""" - # 初始化jinja2环境 - env = SandboxedEnvironment( - loader=BaseLoader, - autoescape=True, - trim_blocks=True, - lstrip_blocks=True, - ) - template = env.from_string(MCP_SELECT) - # 渲染模板 - mcp_prompt = template.render( - mcp_list=mcp_list, - goal=query, - ) - - # 调用大模型进行推理 - result = await self._call_reasoning(mcp_prompt) - - # 使用小模型提取JSON - return await self._call_function_mcp(result, mcp_ids) - - - async def _call_reasoning(self, prompt: str) -> AsyncGenerator[str, None]: - """调用大模型进行推理""" - logger.info("[MCPHelper] 调用推理大模型") - message = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt}, - ] - async for chunk in self.resoning_llm.call(message): - yield chunk - - - async def _call_function_mcp(self, reasoning_result: str, mcp_ids: list[str]) -> MCPSelectResult: - """调用结构化输出小模型提取JSON""" - logger.info("[MCPHelper] 调用结构化输出小模型") - llm = FunctionLLM() - message = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": reasoning_result}, - ] - schema = MCPSelectResult.model_json_schema() - # schema中加入选项 - schema["properties"]["mcp_id"]["enum"] = mcp_ids - result = await llm.call(messages=message, schema=schema) - try: - result = MCPSelectResult.model_validate(result) - except Exception: - logger.exception("[MCPHelper] 解析MCP Select Result失败") - raise - return result - - - async def select_top_mcp( - self, - query: str, - mcp_list: list[str], - ) -> MCPSelectResult: - """ - 选择最合适的MCP Server - - 先通过Embedding选择Top5,然后通过LLM选择Top 1 - """ - # 通过向量检索获取Top5 - llm_mcp_list = await self._get_top_mcp_by_embedding(query, mcp_list) +FINAL_TOOL_ID = "FIANL" +SUMMARIZE_TOOL_ID = "SUMMARIZE" - # 通过LLM选择最合适的 - return await self._get_mcp_by_llm(query, llm_mcp_list, mcp_list) +class MCPSelector: + """MCP选择器""" @staticmethod - async def select_top_tool(query: str, mcp_list: list[str], top_n: int = 10) -> list[MCPTool]: + async def select_top_tool( + goal: str, tool_list: list[MCPTool], + additional_info: str | None = None, top_n: int | None = None) -> list[MCPTool]: """选择最合适的工具""" - tool_vector = await LanceDB().get_table("mcp_tool") - query_embedding = await Embedding.get_embedding([query]) - tool_vecs = await (await tool_vector.search( - query=query_embedding, - vector_column_name="embedding", - )).where(f"mcp_id IN {MCPSelector._assemble_sql(mcp_list)}").limit(top_n).to_list() - - # 拿到工具 - tool_collection = MongoDB().get_collection("mcp") - llm_tool_list = [] - - for tool_vec in tool_vecs: - # 到MongoDB里找对应的工具 - logger.info("[MCPHelper] 查询MCP Tool名称和描述: %s", tool_vec["mcp_id"]) - tool_data = await tool_collection.aggregate([ - {"$match": {"_id": tool_vec["mcp_id"]}}, - {"$unwind": "$tools"}, - {"$match": {"tools.id": tool_vec["id"]}}, - {"$project": {"_id": 0, "tools": 1}}, - {"$replaceRoot": {"newRoot": "$tools"}}, - ]) - async for tool in tool_data: - tool_obj = MCPTool.model_validate(tool) - llm_tool_list.append(tool_obj) - - llm_tool_list.append( - MCPTool( - id="00000000-0000-0000-0000-000000000000", - name="Final", - description="It is the final step, indicating the end of the plan execution."), - ) - llm_tool_list.append( - MCPTool( - id="00000000-0000-0000-0000-000000000001", - name="Chat", - description="It is a chat tool to communicate with the user."), - ) - - return llm_tool_list + random.shuffle(tool_list) + max_tokens = Config().get_config().function_call.max_tokens + template = _env.from_string(TOOL_SELECT) + if TokenCalculator.calculate_token_length( + messages=[{"role": "user", "content": template.render( + goal=goal, tools=[], additional_info=additional_info + )}], + pure_text=True) > max_tokens: + logger.warning("[MCPSelector] 工具选择模板长度超过最大令牌数,无法进行选择") + return [] + llm = FunctionLLM() + current_index = 0 + tool_ids = [] + while current_index < len(tool_list): + index = current_index + sub_tools = [] + while index < len(tool_list): + tool = tool_list[index] + tokens = TokenCalculator.calculate_token_length( + messages=[{"role": "user", "content": template.render( + goal=goal, tools=[tool], + additional_info=additional_info + )}], + pure_text=True + ) + if tokens > max_tokens: + continue + sub_tools.append(tool) + + tokens = TokenCalculator.calculate_token_length(messages=[{"role": "user", "content": template.render( + goal=goal, tools=sub_tools, additional_info=additional_info)}, ], pure_text=True) + if tokens > max_tokens: + del sub_tools[-1] + break + else: + index += 1 + current_index = index + if sub_tools: + message = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": template.render(tools=sub_tools)}, + ] + schema = MCPToolIdsSelectResult.model_json_schema() + schema["properties"]["tool_ids"]["enum"] = [tool.id for tool in sub_tools] + result = await llm.call(messages=message, schema=schema) + try: + result = MCPToolIdsSelectResult.model_validate(result) + tool_ids.extend(result.tool_ids) + except Exception: + logger.exception("[MCPSelector] 解析MCP工具ID选择结果失败") + continue + mcp_tools = [tool for tool in tool_list if tool.id in tool_ids] + + if top_n is not None: + mcp_tools = mcp_tools[:top_n] + mcp_tools.append(MCPTool(id=FINAL_TOOL_ID, name="Final", + description="终止", mcp_id=FINAL_TOOL_ID, input_schema={})) + # mcp_tools.append(MCPTool(id=SUMMARIZE_TOOL_ID, name="Summarize", + # description="总结工具", mcp_id=SUMMARIZE_TOOL_ID, input_schema={})) + return mcp_tools \ No newline at end of file diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index 272a8c60..3371cbdc 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -190,7 +190,7 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: flow_id=task.state.flow_id, flow_name=task.state.flow_name, flow_status=task.state.flow_status, - history_ids=[context["_id"] for context in task.context], + history_ids=[context.id for context in task.context], ), ) diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index d35bbb21..0c58a410 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -7,6 +7,8 @@ from datetime import UTC, datetime from apps.common.config import config from apps.common.queue import MessageQueue +from apps.llm.patterns.rewrite import QuestionRewrite +from apps.llm.reasoning import ReasoningLLM from apps.models.app import App from apps.scheduler.executor.agent import MCPAgentExecutor from apps.scheduler.executor.flow import FlowExecutor @@ -17,6 +19,7 @@ from apps.scheduler.scheduler.message import ( push_init_message, push_rag_message, ) +from apps.schemas.config import LLMConfig from apps.schemas.enum_var import AppType, EventType, FlowStatus from apps.schemas.rag_data import RAGQueryReq from apps.schemas.request_data import RequestData @@ -68,44 +71,49 @@ class Scheduler: logger.error(f"[Scheduler] 活动监控过程中发生错误: {e}") - async def run(self) -> None: # noqa: PLR0911 - """运行调度器""" + async def get_llm_use_in_chat_with_rag(self) -> LLM: + """获取RAG大模型""" try: # 获取当前会话使用的大模型 - llm_id = await LLMManager.get_user_default_llm( + llm_id = await LLMManager.get_llm_id_by_conversation_id( self.task.ids.user_sub, self.task.ids.conversation_id, ) if not llm_id: logger.error("[Scheduler] 获取大模型ID失败") - await self.queue.close() - return + return None if llm_id == "empty": - llm = LLM( + return LLM( _id="empty", user_sub=self.task.ids.user_sub, - openai_base_url=config.llm.endpoint, - openai_api_key=config.llm.key, - model_name=config.llm.model, - max_tokens=config.llm.max_tokens, + openai_base_url=Config().get_config().llm.endpoint, + openai_api_key=Config().get_config().llm.key, + model_name=Config().get_config().llm.model, + max_tokens=Config().get_config().llm.max_tokens, ) - else: - llm = await LLMManager.get_llm(self.task.ids.user_sub, llm_id) - if not llm: - logger.error("[Scheduler] 获取大模型失败") - await self.queue.close() - return + llm = await LLMManager.get_llm_by_id(self.task.ids.user_sub, llm_id) + if not llm: + logger.error("[Scheduler] 获取大模型失败") + return None + return llm except Exception: logger.exception("[Scheduler] 获取大模型失败") - await self.queue.close() - return + return None + + + async def get_kb_ids_use_in_chat_with_rag(self) -> list[str]: + """获取知识库ID列表""" try: - # 获取当前会话使用的知识库 - kb_ids = await KnowledgeBaseManager.get_kb_ids_by_conversation_id( - self.task.ids.user_sub, self.task.ids.conversation_id) + return await KnowledgeBaseManager.get_kb_ids_by_conversation_id( + self.task.ids.user_sub, self.task.ids.conversation_id, + ) except Exception: logger.exception("[Scheduler] 获取知识库ID失败") await self.queue.close() - return + return [] + + + async def run(self) -> None: + """运行调度器""" try: # 获取当前问答可供关联的文档 docs, doc_ids = await get_docs(self.task.ids.user_sub, self.post_body) @@ -121,6 +129,8 @@ class Scheduler: kill_event = asyncio.Event() monitor = asyncio.create_task(self._monitor_activity(kill_event, self.task.ids.user_sub)) if not self.post_body.app or self.post_body.app.app_id == "": + llm = await self.get_llm_use_in_chat_with_rag() + kb_ids = await self.get_kb_ids_use_in_chat_with_rag() self.task = await push_init_message(self.task, self.queue, 3, is_flow=False) rag_data = RAGQueryReq( kbIds=kb_ids, @@ -197,6 +207,27 @@ class Scheduler: if not app_metadata: logger.error("[Scheduler] 未找到Agent应用") return + llm = await LLMManager.get_llm_by_id( + self.task.ids.user_sub, app_metadata.llm_id, + ) + if not llm: + logger.error("[Scheduler] 获取大模型失败") + await self.queue.close() + return + reasion_llm = ReasoningLLM( + LLMConfig( + endpoint=llm.openai_base_url, + key=llm.openai_api_key, + model=llm.model_name, + max_tokens=llm.max_tokens, + ), + ) + if background.conversation: + try: + question_obj = QuestionRewrite() + post_body.question = await question_obj.generate(history=background.conversation, question=post_body.question, llm=reasion_llm) + except Exception: + logger.exception("[Scheduler] 问题重写失败") if app_metadata.app_type == AppType.FLOW.value: logger.info("[Scheduler] 获取工作流元数据") flow_info = await Pool().get_flow_metadata(app_info.app_id) @@ -255,6 +286,7 @@ class Scheduler: servers_id=servers_id, background=background, agent_id=app_info.app_id, + params=post_body.app.params, ) # 开始运行 logger.info("[Scheduler] 运行Executor") diff --git a/apps/schemas/appcenter.py b/apps/schemas/appcenter.py index b37716f3..f6ccd609 100644 --- a/apps/schemas/appcenter.py +++ b/apps/schemas/appcenter.py @@ -52,6 +52,14 @@ class AppFlowInfo(BaseModel): debug: bool = Field(..., description="是否经过调试") +class AppMcpServiceInfo(BaseModel): + """应用关联的MCP服务信息""" + + id: str = Field(..., description="MCP服务ID") + name: str = Field(..., description="MCP服务名称") + description: str = Field(..., description="MCP服务简介") + + class AppData(BaseModel): """应用信息数据结构""" @@ -66,4 +74,4 @@ class AppData(BaseModel): permission: AppPermissionData = Field( default_factory=lambda: AppPermissionData(authorizedUsers=None), description="权限配置") workflows: list[AppFlowInfo] = Field(default=[], description="工作流信息列表") - mcp_service: list[str] = Field(default=[], alias="mcpService", description="MCP服务id列表") + mcp_service: list[AppMcpServiceInfo] = Field(default=[], alias="mcpService", description="MCP服务id列表") diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py index d46f5289..f80a0c4a 100644 --- a/apps/schemas/enum_var.py +++ b/apps/schemas/enum_var.py @@ -16,6 +16,7 @@ class StepStatus(str, Enum): """步骤状态""" UNKNOWN = "unknown" + INIT = "init" WAITING = "waiting" RUNNING = "running" SUCCESS = "success" @@ -28,6 +29,7 @@ class FlowStatus(str, Enum): """Flow状态""" UNKNOWN = "unknown" + INIT = "init" WAITING = "waiting" RUNNING = "running" SUCCESS = "success" @@ -55,12 +57,15 @@ class EventType(str, Enum): STEP_WAITING_FOR_START = "step.waiting_for_start" STEP_WAITING_FOR_PARAM = "step.waiting_for_param" FLOW_START = "flow.start" + STEP_INIT = "step.init" STEP_INPUT = "step.input" STEP_OUTPUT = "step.output" + STEP_CANCEL = "step.cancel" + STEP_ERROR = "step.error" FLOW_STOP = "flow.stop" FLOW_FAILED = "flow.failed" FLOW_SUCCESS = "flow.success" - FLOW_CANCELLED = "flow.cancelled" + FLOW_CANCEL = "flow.cancel" DONE = "done" diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py index c31542bc..226d6849 100644 --- a/apps/schemas/mcp.py +++ b/apps/schemas/mcp.py @@ -64,6 +64,13 @@ class GoalEvaluationResult(BaseModel): reason: str = Field(description="评估原因") +class RestartStepIndex(BaseModel): + """MCP重新规划的步骤索引""" + + start_index: int = Field(description="重新规划的起始步骤索引") + reasoning: str = Field(description="重新规划的原因") + + class Risk(str, Enum): """MCP工具风险类型""" @@ -79,12 +86,38 @@ class ToolRisk(BaseModel): reason: str = Field(description="风险原因", default="") +class ErrorType(str, Enum): + """MCP工具错误类型""" + + MISSING_PARAM = "missing_param" + DECORRECT_PLAN = "decorrect_plan" + + +class ToolExcutionErrorType(BaseModel): + """MCP工具执行错误""" + + type: ErrorType = Field(description="错误类型", default=ErrorType.MISSING_PARAM) + reason: str = Field(description="错误原因", default="") + + class MCPSelectResult(BaseModel): """MCP选择结果""" mcp_id: str = Field(description="MCP Server的ID") +class MCPToolSelectResult(BaseModel): + """MCP工具选择结果""" + + name: str = Field(description="工具名称") + + +class MCPToolIdsSelectResult(BaseModel): + """MCP工具ID选择结果""" + + tool_ids: list[str] = Field(description="工具ID列表") + + class MCPPlanItem(BaseModel): """MCP 计划""" diff --git a/apps/schemas/message.py b/apps/schemas/message.py index 0a441dd2..0b9ef518 100644 --- a/apps/schemas/message.py +++ b/apps/schemas/message.py @@ -7,10 +7,18 @@ from typing import Any from pydantic import BaseModel, Field -from .enum_var import EventType, StepStatus +from apps.schemas.enum_var import EventType, FlowStatus, StepStatus + from .record import RecordMetadata +class param(BaseModel): + """流执行过程中的参数补充""" + + content: dict[str, Any] | bool = Field(default={}, description="流执行过程中的参数补充内容") + description: str = Field(default="", description="流执行过程中的参数补充描述") + + class HeartbeatData(BaseModel): """心跳事件的数据结构""" @@ -24,6 +32,8 @@ class MessageFlow(BaseModel): app_id: str = Field(description="插件ID", alias="appId") flow_id: str = Field(description="Flow ID", alias="flowId") + flow_name: str = Field(description="Flow名称", alias="flowName") + flow_status: FlowStatus = Field(description="Flow状态", alias="flowStatus", default=FlowStatus.UNKNOWN) step_id: str = Field(description="当前步骤ID", alias="stepId") step_name: str = Field(description="当前步骤名称", alias="stepName") sub_step_id: str | None = Field(description="当前子步骤ID", alias="subStepId", default=None) @@ -78,7 +88,7 @@ class FlowStartContent(BaseModel): """flow.start消息的content""" question: str = Field(description="用户问题") - params: dict[str, Any] = Field(description="预先提供的参数") + params: dict[str, Any] | None = Field(description="预先提供的参数", default=None) class MessageBase(HeartbeatData): @@ -89,5 +99,5 @@ class MessageBase(HeartbeatData): conversation_id: uuid.UUID = Field(min_length=36, max_length=36, alias="conversationId") task_id: str = Field(min_length=36, max_length=36, alias="taskId") flow: MessageFlow | None = None - content: dict[str, Any] = {} + content: Any | None = Field(default=None, description="消息内容") metadata: MessageMetadata diff --git a/apps/schemas/record.py b/apps/schemas/record.py index 578f567c..119a8082 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -41,6 +41,7 @@ class RecordFlow(BaseModel): id: str record_id: str = Field(alias="recordId") flow_id: str = Field(alias="flowId") + flow_name: str = Field(alias="flowName", default="") flow_status: StepStatus = Field(alias="flowStatus", default=StepStatus.SUCCESS) step_num: int = Field(alias="stepNum") steps: list[RecordFlowStep] @@ -81,7 +82,7 @@ class RecordData(BaseModel): id: uuid.UUID conversation_id: uuid.UUID = Field(alias="conversationId") - task_id: uuid.UUID = Field(alias="taskId") + task_id: uuid.UUID | None = Field(alias="taskId", default=None) document: list[RecordDocument] = [] flow: RecordFlow | None = None content: RecordContent @@ -118,7 +119,8 @@ class Record(RecordData): user_sub: str key: dict[str, Any] = {} - content: str + task_id: uuid.UUID | None = Field(default=None, description="任务ID") + content: str = Field(default="", description="Record内容,已加密") comment: RecordComment = Field(default=RecordComment()) flow: FlowHistory = Field( default=FlowHistory(), description="Flow执行历史信息") diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index b369ddf1..f271f6ee 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -10,6 +10,7 @@ from .appcenter import AppData from .enum_var import CommentType from .flow_topology import FlowItem from .mcp import MCPType +from .message import param class RequestDataApp(BaseModel): @@ -17,7 +18,7 @@ class RequestDataApp(BaseModel): app_id: uuid.UUID = Field(description="应用ID", alias="appId") flow_id: str | None = Field(default=None, description="Flow ID", alias="flowId") - params: dict[str, Any] | None = Field(default=None, description="插件参数") + params: param | None = Field(default=None, description="流执行过程中的参数补充", alias="params") class RequestData(BaseModel): @@ -31,7 +32,7 @@ class RequestData(BaseModel): files: list[str] = Field(default=[], description="文件列表") app: RequestDataApp | None = Field(default=None, description="应用") debug: bool = Field(default=False, description="是否调试") - new_task: bool = Field(default=True, description="是否新建任务") + task_id: str | None = Field(default=None, alias="taskId", description="任务ID") class QuestionBlacklistRequest(BaseModel): @@ -163,3 +164,9 @@ class UpdateUserKnowledgebaseReq(BaseModel): """更新知识库请求体""" kb_ids: list[uuid.UUID] = Field(description="知识库ID列表", alias="kbIds", default=[]) + + +class UserUpdateRequest(BaseModel): + """更新用户信息请求体""" + + auto_execute: bool = Field(default=False, description="是否自动执行", alias="autoExecute") diff --git a/apps/schemas/response_data.py b/apps/schemas/response_data.py index f41fdb02..3373fda3 100644 --- a/apps/schemas/response_data.py +++ b/apps/schemas/response_data.py @@ -43,6 +43,7 @@ class AuthUserMsg(BaseModel): user_sub: str revision: bool is_admin: bool + auto_execute: bool class AuthUserRsp(ResponseData): diff --git a/apps/schemas/task.py b/apps/schemas/task.py index e0a2b39e..2e0a65d2 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -31,6 +31,7 @@ class FlowStepHistory(BaseModel): step_status: StepStatus = Field(description="当前步骤状态") input_data: dict[str, Any] = Field(description="当前Step执行的输入", default={}) output_data: dict[str, Any] = Field(description="当前Step执行后的结果", default={}) + ex_data: dict[str, Any] | None = Field(description="额外数据", default=None) created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) @@ -41,17 +42,17 @@ class ExecutorState(BaseModel): flow_id: str = Field(description="Flow ID", default="") flow_name: str = Field(description="Flow名称", default="") description: str = Field(description="Flow描述", default="") - flow_status: FlowStatus = Field(description="Flow状态", default=FlowStatus.UNKNOWN) + flow_status: FlowStatus = Field(description="Flow状态", default=FlowStatus.INIT) # 任务级数据 step_id: str = Field(description="当前步骤ID", default="") + step_index: int = Field(description="当前步骤索引", default=0) step_name: str = Field(description="当前步骤名称", default="") step_status: StepStatus = Field(description="当前步骤状态", default=StepStatus.UNKNOWN) step_description: str = Field(description="当前步骤描述", default="") - retry_times: int = Field(description="当前步骤重试次数", default=0) - error_message: str = Field(description="当前步骤错误信息", default="") app_id: str = Field(description="应用ID", default="") - slot: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={}) - error_info: dict[str, Any] = Field(description="错误信息", default={}) + current_input: dict[str, Any] = Field(description="当前输入数据", default={}) + error_message: str = Field(description="错误信息", default="") + retry_times: int = Field(description="当前步骤重试次数", default=0) class TaskIds(BaseModel): @@ -93,7 +94,7 @@ class Task(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") ids: TaskIds = Field(description="任务涉及的各种ID") - context: list[dict[str, Any]] = Field(description="Flow的步骤执行信息", default=[]) + context: list[FlowStepHistory] = Field(description="Flow的步骤执行信息", default=[]) state: ExecutorState = Field(description="Flow的状态", default=ExecutorState()) tokens: TaskTokens = Field(description="Token信息") runtime: TaskRuntime = Field(description="任务运行时数据") diff --git a/apps/services/conversation.py b/apps/services/conversation.py index ade7fb6b..47fa28d2 100644 --- a/apps/services/conversation.py +++ b/apps/services/conversation.py @@ -153,7 +153,7 @@ class ConversationManager: await session.delete(conv) await session.commit() - await TaskManager.delete_tasks_by_conversation_id(conversation_id) + await TaskManager.delete_tasks_and_flow_context_by_conversation_id(conversation_id) @staticmethod diff --git a/apps/services/mcp_service.py b/apps/services/mcp_service.py index a08d5455..dc6619ab 100644 --- a/apps/services/mcp_service.py +++ b/apps/services/mcp_service.py @@ -95,6 +95,8 @@ class MCPServiceManager: user_sub: str, keyword: str | None, page: int, + *, + is_active: bool | None = None, ) -> list[MCPServiceCardItem]: """ 获取所有MCP服务列表 @@ -105,7 +107,7 @@ class MCPServiceManager: :param page: int: 页码 :return: MCP服务列表 """ - mcpservice_pools = await MCPServiceManager._search_mcpservice(search_type, keyword, page) + mcpservice_pools = await MCPServiceManager._search_mcpservice(search_type, keyword, page, is_active) return [ MCPServiceCardItem( mcpserviceId=item.id, @@ -162,6 +164,8 @@ class MCPServiceManager: search_type: SearchType, keyword: str | None, page: int, + *, + is_active: bool | None = None, ) -> list[MCPInfo]: """ 基于输入条件搜索MCP服务 diff --git a/apps/services/record.py b/apps/services/record.py index 3b515108..d62603b5 100644 --- a/apps/services/record.py +++ b/apps/services/record.py @@ -11,6 +11,7 @@ from sqlalchemy import and_, select from apps.common.postgres import postgres from apps.models.conversation import Conversation from apps.models.record import Record as PgRecord +from apps.schemas.enum_var import FlowStatus from apps.schemas.record import Record logger = logging.getLogger(__name__) @@ -116,3 +117,17 @@ class RecordManager: return [] else: return records + + + @staticmethod + async def update_record_flow_status_to_cancelled_by_task_ids(task_ids: list[str]) -> None: + """更新Record关联的Flow状态""" + record_group_collection = MongoDB().get_collection("record_group") + try: + await record_group_collection.update_many( + {"records.task_id": {"$in": task_ids}, "records.flow.flow_status": {"$nin": [FlowStatus.ERROR.value, FlowStatus.SUCCESS.value]}}, + {"$set": {"records.$[elem].flow.flow_status": FlowStatus.CANCELLED}}, + array_filters=[{"elem.flow.flow_id": {"$in": task_ids}}], + ) + except Exception: + logger.exception("[RecordManager] 更新Record关联的Flow状态失败") diff --git a/apps/services/task.py b/apps/services/task.py index ec8436b1..34e6bf75 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -3,11 +3,11 @@ import logging import uuid -from typing import Any from apps.common.postgres import postgres from apps.schemas.request_data import RequestData from apps.schemas.task import ( + FlowStepHistory, Task, TaskIds, TaskRuntime, @@ -70,7 +70,7 @@ class TaskManager: @staticmethod - async def get_context_by_record_id(record_group_id: str, record_id: str) -> list[dict[str, Any]]: + async def get_context_by_record_id(record_group_id: str, record_id: str) -> list[FlowStepHistory]: """根据record_group_id获取flow信息""" record_group_collection = MongoDB().get_collection("record_group") flow_context_collection = MongoDB().get_collection("flow_context") @@ -88,7 +88,7 @@ class TaskManager: for flow_context_id in records[0]["records"]["flow"]["history_ids"]: flow_context = await flow_context_collection.find_one({"_id": flow_context_id}) if flow_context: - flow_context_list.append(flow_context) + flow_context_list.append(FlowStepHistory.model_validate(flow_context)) except Exception: logger.exception("[TaskManager] 获取record_id的flow信息失败") return [] @@ -97,7 +97,7 @@ class TaskManager: @staticmethod - async def get_context_by_task_id(task_id: str, length: int = 0) -> list[dict[str, Any]]: + async def get_context_by_task_id(task_id: str, length: int = 0) -> list[FlowStepHistory]: """根据task_id获取flow信息""" flow_context_collection = MongoDB().get_collection("flow_context") @@ -108,7 +108,8 @@ class TaskManager: ).sort( "created_at", -1, ).limit(length): - flow_context += [history] + for i in range(len(flow_context)): + flow_context.append(FlowStepHistory.model_validate(history)) except Exception: logger.exception("[TaskManager] 获取task_id的flow信息失败") return [] @@ -117,7 +118,28 @@ class TaskManager: @staticmethod - async def save_flow_context(task_id: str, flow_context: list[dict[str, Any]]) -> None: + async def init_new_task( + user_sub: str, + session_id: str | None = None, + post_body: RequestData | None = None, + ) -> Task: + """获取任务块""" + return Task( + _id=str(uuid.uuid4()), + ids=TaskIds( + user_sub=user_sub if user_sub else "", + session_id=session_id if session_id else "", + conversation_id=post_body.conversation_id, + group_id=post_body.group_id if post_body.group_id else "", + ), + question=post_body.question if post_body else "", + group_id=post_body.group_id if post_body else "", + tokens=TaskTokens(), + runtime=TaskRuntime(), + ) + + @staticmethod + async def save_flow_context(task_id: str, flow_context: list[FlowStepHistory]) -> None: """保存flow信息到flow_context""" flow_context_collection = MongoDB().get_collection("flow_context") try: @@ -125,15 +147,15 @@ class TaskManager: # 查找是否存在 current_context = await flow_context_collection.find_one({ "task_id": task_id, - "_id": history["_id"], + "_id": history.id, }) if current_context: await flow_context_collection.update_one( {"_id": current_context["_id"]}, - {"$set": history}, + {"$set": history.model_dump(exclude_none=True, by_alias=True)}, ) else: - await flow_context_collection.insert_one(history) + await flow_context_collection.insert_one(history.model_dump(exclude_none=True, by_alias=True)) except Exception: logger.exception("[TaskManager] 保存flow执行记录失败") @@ -150,7 +172,26 @@ class TaskManager: @staticmethod - async def delete_tasks_by_conversation_id(conversation_id: uuid.UUID) -> None: + async def delete_tasks_by_conversation_id(conversation_id: str) -> list[str]: + """通过ConversationID删除Task信息""" + mongo = MongoDB() + task_collection = mongo.get_collection("task") + task_ids = [] + try: + async for task in task_collection.find( + {"conversation_id": conversation_id}, + {"_id": 1}, + ): + task_ids.append(task["_id"]) + if task_ids: + await task_collection.delete_many({"conversation_id": conversation_id}) + return task_ids + except Exception: + logger.exception("[TaskManager] 删除ConversationID的Task信息失败") + return [] + + @staticmethod + async def delete_tasks_and_flow_context_by_conversation_id(conversation_id: str) -> None: """通过ConversationID删除Task信息""" mongo = MongoDB() task_collection = mongo.get_collection("task") @@ -168,50 +209,6 @@ class TaskManager: await flow_context_collection.delete_many({"task_id": {"$in": task_ids}}, session=session) - @classmethod - async def get_task( - cls, - task_id: str | None = None, - session_id: str | None = None, - post_body: RequestData | None = None, - user_sub: str | None = None, - ) -> Task: - """获取任务块""" - if task_id: - try: - task = await cls.get_task_by_task_id(task_id) - if task: - return task - except Exception: - logger.exception("[TaskManager] 通过task_id获取任务失败") - - logger.info("[TaskManager] 未提供task_id,通过session_id获取任务") - if not session_id or not post_body: - err = ( - "session_id 和 conversation_id 或 group_id 和 conversation_id 是恢复/创建任务的必要条件。" - ) - raise ValueError(err) - - if post_body.group_id: - task = await cls.get_task_by_group_id(post_body.group_id, post_body.conversation_id) - else: - task = await cls.get_task_by_conversation_id(post_body.conversation_id) - - if task: - return task - return Task( - _id=str(uuid.uuid4()), - ids=TaskIds( - user_sub=user_sub if user_sub else "", - session_id=session_id if session_id else "", - conversation_id=post_body.conversation_id, - group_id=post_body.group_id if post_body.group_id else "", - ), - tokens=TaskTokens(), - runtime=TaskRuntime(), - ) - - @classmethod async def save_task(cls, task_id: str, task: Task) -> None: """保存任务块""" -- Gitee