diff --git a/apps/common/queue.py b/apps/common/queue.py
index 96a8707302ac3641356d234758cdf53fedba6db1..cba88fe55f05a8392362ef6e8cd4abfe9e0b973c 100644
--- a/apps/common/queue.py
+++ b/apps/common/queue.py
@@ -56,6 +56,8 @@ class MessageQueue:
flow = MessageFlow(
appId=task.state.app_id,
flowId=task.state.flow_id,
+ flowName=task.state.flow_name,
+ flowStatus=task.state.flow_status,
stepId=task.state.step_id,
stepName=task.state.step_name,
stepStatus=task.state.step_status,
diff --git a/apps/models/user.py b/apps/models/user.py
index 185d1184cb68791f2b34dd4e3a210d7140b765f6..56b352ac0b68a7a60d2343d5c4486173729b1bf7 100644
--- a/apps/models/user.py
+++ b/apps/models/user.py
@@ -38,6 +38,8 @@ class User(Base):
"""用户选择的知识库的ID"""
defaultLLM: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), default=None, nullable=True) # noqa: N815
"""用户选择的大模型ID"""
+ autoExecute: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) # noqa: N815
+ """Agent是否自动执行"""
class UserFavoriteType(str, enum.Enum):
diff --git a/apps/routers/appcenter.py b/apps/routers/appcenter.py
index 192e3eb04a1fd73032e82899a28494eed09fe437..d0893a0843a66f015bfa60f8c8c487b5c24d4e33 100644
--- a/apps/routers/appcenter.py
+++ b/apps/routers/appcenter.py
@@ -10,7 +10,7 @@ from fastapi.responses import JSONResponse
from apps.dependency.user import verify_personal_token, verify_session
from apps.exceptions import InstancePermissionError
-from apps.schemas.appcenter import AppFlowInfo, AppPermissionData
+from apps.schemas.appcenter import AppFlowInfo, AppMcpServiceInfo, AppPermissionData
from apps.schemas.enum_var import AppFilterType, AppType
from apps.schemas.request_data import ChangeFavouriteAppRequest, CreateAppRequest
from apps.schemas.response_data import (
@@ -26,6 +26,7 @@ from apps.schemas.response_data import (
ResponseData,
)
from apps.services.appcenter import AppCenterManager
+from apps.services.mcp_service import MCPServiceManager
logger = logging.getLogger(__name__)
router = APIRouter(
@@ -218,6 +219,15 @@ async def get_application(appId: Annotated[uuid.UUID, Path()]) -> JSONResponse:
)
for flow in app_data.flows
]
+ mcp_service = []
+ if app_data.mcp_service:
+ for service in app_data.mcp_service:
+ mcp_collection = await MCPServiceManager.get_mcp_service(service)
+ mcp_service.append(AppMcpServiceInfo(
+ id=mcp_collection.id,
+ name=mcp_collection.name,
+ description=mcp_collection.description,
+ ))
return JSONResponse(
status_code=status.HTTP_200_OK,
content=GetAppPropertyRsp(
@@ -238,7 +248,7 @@ async def get_application(appId: Annotated[uuid.UUID, Path()]) -> JSONResponse:
authorizedUsers=app_data.permission.users,
),
workflows=workflows,
- mcpService=app_data.mcp_service,
+ mcpService=mcp_service,
),
).model_dump(exclude_none=True, by_alias=True),
)
diff --git a/apps/routers/auth.py b/apps/routers/auth.py
index 1d5cc5160a52f90881648d43729216d1ae5ee56b..10b895b973eb9a3e5f3888d99747867c27398e5c 100644
--- a/apps/routers/auth.py
+++ b/apps/routers/auth.py
@@ -164,6 +164,7 @@ async def userinfo(request: Request) -> JSONResponse:
user_sub=request.state.user_sub,
revision=user.isActive,
is_admin=is_admin,
+ auto_execute=user.autoExecute,
),
).model_dump(exclude_none=True, by_alias=True),
)
diff --git a/apps/routers/chat.py b/apps/routers/chat.py
index 840c2fb9aa386b47a19e609f5aa17592695a40d4..860e71f2b61713e6f296051c82959a15f0333571 100644
--- a/apps/routers/chat.py
+++ b/apps/routers/chat.py
@@ -19,7 +19,9 @@ from apps.schemas.response_data import ResponseData
from apps.schemas.task import Task
from apps.services.activity import Activity
from apps.services.blacklist import QuestionBlacklistManager, UserBlacklistManager
+from apps.services.conversation import ConversationManager
from apps.services.flow import FlowManager
+from apps.services.record import RecordManager
from apps.services.task import TaskManager
RECOMMEND_TRES = 5
@@ -36,17 +38,25 @@ router = APIRouter(
async def init_task(post_body: RequestData, user_sub: str, session_id: str) -> Task:
"""初始化Task"""
- if post_body.new_task:
- # 创建或还原Task
- task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub)
- if task:
- await TaskManager.delete_task_by_task_id(task.id)
- # 创建或还原Task
- task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub)
# 更改信息并刷新数据库
- if post_body.new_task:
+    if post_body.task_id is None:
+        # No task_id supplied: start a new task, but only for a conversation this user can access.
+        conversation = await ConversationManager.get_conversation_by_conversation_id(
+            user_sub=user_sub,
+            conversation_id=post_body.conversation_id,
+        )
+        if not conversation:
+            err = "[Chat] 用户没有权限访问该对话!"
+            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=err)
+        # Remove the conversation's previous tasks, then hand their ids to RecordManager
+        # so the matching records are updated (presumably to a cancelled flow status).
+        task_ids = await TaskManager.delete_tasks_by_conversation_id(post_body.conversation_id)
+        await RecordManager.update_record_flow_status_to_cancelled_by_task_ids(task_ids)
+        task = await TaskManager.init_new_task(user_sub=user_sub, session_id=session_id, post_body=post_body)
task.runtime.question = post_body.question
task.ids.group_id = post_body.group_id
+    else:
+        if not post_body.task_id:
+            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="task_id cannot be empty")
+        # Resume an existing task: look it up by task id. The previous code passed the
+        # task id to the conversation-id lookup.
+        task = await TaskManager.get_task_by_task_id(post_body.task_id)
+        if not task:
+            # The lookup can come back empty; fail here instead of returning None as a Task.
+            err = "[Chat] 任务不存在!"
+            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=err)
return task
@@ -70,6 +80,7 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str)
# 在单独Task中运行Scheduler,拉齐queue.get的时机
scheduler = Scheduler(task, queue, post_body)
+ logger.info(f"[Chat] 用户是否活跃: {await Activity.is_active(user_sub)}")
scheduler_task = asyncio.create_task(scheduler.run())
# 处理每一条消息
@@ -119,8 +130,8 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str)
@router.post("/chat")
async def chat(request: Request, post_body: RequestData) -> StreamingResponse:
"""LLM流式对话接口"""
- user_sub = request.state.user_sub
session_id = request.state.session_id
+ user_sub = request.state.user_sub
# 问题黑名单检测
if not await QuestionBlacklistManager.check_blacklisted_questions(input_question=post_body.question):
# 用户扣分
diff --git a/apps/routers/mcp_service.py b/apps/routers/mcp_service.py
index dd0c0765c1a122c7b0d7b49cd218fc9002765d2b..3007d5609e69e9e6f9c43756c366cdb2bc9a6015 100644
--- a/apps/routers/mcp_service.py
+++ b/apps/routers/mcp_service.py
@@ -54,6 +54,8 @@ async def get_mcpservice_list(
searchType: SearchType = SearchType.ALL, # noqa: N803
keyword: str | None = None,
page: Annotated[int, Query(ge=1)] = 1,
+ *,
+ is_active: bool | None = None,
) -> JSONResponse:
"""获取服务列表"""
user_sub = request.state.user_sub
@@ -63,6 +65,7 @@ async def get_mcpservice_list(
user_sub,
keyword,
page,
+ is_active=is_active,
)
except Exception as e:
err = f"[MCPServiceCenter] 获取MCP服务列表失败: {e}"
diff --git a/apps/routers/record.py b/apps/routers/record.py
index 02c629d02623681cdeedcfaa33ed9562f5c5300b..81fa3003dd6b2efe384f7f540819f220e5be09ed 100644
--- a/apps/routers/record.py
+++ b/apps/routers/record.py
@@ -68,7 +68,7 @@ async def get_record(request: Request, conversationId: Annotated[uuid.UUID, Path
tmp_record = RecordData(
id=record.id,
groupId=record_group.id,
- taskId=record_group.task_id,
+ taskId=record.task_id,
conversationId=conversationId,
content=record_data,
metadata=record.metadata
@@ -90,13 +90,13 @@ async def get_record(request: Request, conversationId: Annotated[uuid.UUID, Path
tmp_record.flow = RecordFlow(
id=record.flow.flow_id, # TODO: 此处前端应该用name
recordId=record.id,
- flowStatus=record.flow.flow_staus,
flowId=record.flow.flow_id,
+ flowName=record.flow.flow_name,
+ flowStatus=record.flow.flow_staus,
stepNum=len(flow_step_list),
steps=[],
)
for flow_step in flow_step_list:
- flow_step = FlowStepHistory.model_validate(flow_step)
tmp_record.flow.steps.append(
RecordFlowStep(
stepId=flow_step.step_name, # TODO: 此处前端应该用name
diff --git a/apps/routers/user.py b/apps/routers/user.py
index da3efce9ea686919c40384e6d4f7d711b7fc9333..6451c4b6ad7268d3e24ef8f4933b1c017e8c0ccd 100644
--- a/apps/routers/user.py
+++ b/apps/routers/user.py
@@ -40,3 +40,20 @@ async def list_user(
result=UserGetMsp(userInfoList=user_info_list),
).model_dump(exclude_none=True, by_alias=True),
)
+
+
+@router.post("")
+async def update_user_info(request: Request, data: UserUpdateRequest) -> JSONResponse:
+    """
+    Update the current user's settings.
+
+    :param request: incoming request; the user is taken from ``request.state.user_sub``
+    :param data: fields to update
+    :return: JSON body with ``code`` and ``message``; the HTTP status equals ``code``
+    """
+    result = await UserManager.update_userinfo_by_user_sub(request.state.user_sub, data)
+    # NOTE(review): a falsy result is reported as success and a truthy one as failure.
+    # That is only right if the manager returns None/False on success - confirm its contract.
+    if result:
+        code = status.HTTP_404_NOT_FOUND
+        message = "用户信息更新失败"
+    else:
+        code = status.HTTP_200_OK
+        message = "用户信息更新成功"
+    return JSONResponse(status_code=code, content={"code": code, "message": message})
diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py
index cb8e183e9ac6036ac1a170d99b0da52b6e2a2d53..4a88ecb6cccf617f90761671b799c3faa7ade55a 100644
--- a/apps/scheduler/executor/agent.py
+++ b/apps/scheduler/executor/agent.py
@@ -2,13 +2,35 @@
"""MCP Agent执行器"""
import logging
+import uuid
+from typing import Any
from pydantic import Field
+from apps.llm.patterns.rewrite import QuestionRewrite
+from apps.llm.reasoning import ReasoningLLM
from apps.scheduler.executor.base import BaseExecutor
-from apps.scheduler.mcp_agent import host, plan, select
-from apps.schemas.task import ExecutorState, StepQueueItem
+from apps.scheduler.mcp_agent.host import MCPHost
+from apps.scheduler.mcp_agent.plan import MCPPlanner
+from apps.scheduler.mcp_agent.select import FINAL_TOOL_ID, MCPSelector
+from apps.scheduler.pool.mcp.client import MCPClient
+from apps.schemas.enum_var import EventType, FlowStatus, SpecialCallType, StepStatus
+from apps.schemas.mcp import (
+ ErrorType,
+ GoalEvaluationResult,
+ MCPCollection,
+ MCPPlan,
+ MCPTool,
+ RestartStepIndex,
+ ToolExcutionErrorType,
+ ToolRisk,
+)
+from apps.schemas.message import param
+from apps.schemas.task import ExecutorState, FlowStepHistory, StepQueueItem
+from apps.services.appcenter import AppCenterManager
+from apps.services.mcp_service import MCPServiceManager
from apps.services.task import TaskManager
+from apps.services.user import UserManager
logger = logging.getLogger(__name__)
@@ -16,15 +38,416 @@ logger = logging.getLogger(__name__)
class MCPAgentExecutor(BaseExecutor):
"""MCP Agent执行器"""
- question: str = Field(description="用户输入")
max_steps: int = Field(default=20, description="最大步数")
servers_id: list[str] = Field(description="MCP server id")
agent_id: str = Field(default="", description="Agent ID")
agent_description: str = Field(default="", description="Agent描述")
+ mcp_list: list[MCPCollection] = Field(description="MCP服务器列表", default=[])
+ mcp_client: dict[str, MCPClient] = Field(
+ description="MCP客户端列表,key为mcp_id", default={},
+ )
+ tools: dict[str, MCPTool] = Field(
+ description="MCP工具列表,key为tool_id", default={},
+ )
+ params: param | None = Field(
+ default=None, description="流执行过程中的参数补充", alias="params",
+ )
+ resoning_llm: ReasoningLLM = Field(
+ default=ReasoningLLM(),
+ description="推理大模型",
+ )
+
+    async def update_tokens(self) -> None:
+        """Copy the reasoning LLM's token counters onto the task, then save the task."""
+        llm = self.resoning_llm
+        counters = self.task.tokens
+        counters.input_tokens = llm.input_tokens
+        counters.output_tokens = llm.output_tokens
+        await TaskManager.save_task(self.task.id, self.task)
async def load_state(self) -> None:
"""从数据库中加载FlowExecutor的状态"""
logger.info("[FlowExecutor] 加载Executor状态")
# 尝试恢复State
- if self.task.state:
+ if self.task.state and self.task.state.flow_status != FlowStatus.INIT:
self.task.context = await TaskManager.get_context_by_task_id(self.task.id)
+
+    async def load_mcp(self) -> None:
+        """
+        Load the MCP services attached to this agent and register their clients and tools.
+
+        A service is skipped when the current user has not activated it, or when no client
+        can be obtained for it (``MCPHost.get_client`` returns None in that case). Skipping
+        keeps None out of ``self.mcp_client`` and keeps tools of unreachable servers out of
+        ``self.tools``.
+        """
+        logger.info("[MCPAgentExecutor] 加载MCP服务器列表")
+        user_sub = self.task.ids.user_sub
+        app = await AppCenterManager.fetch_app_data_by_id(self.agent_id)
+        for mcp_id in app.mcp_service:
+            mcp_service = await MCPServiceManager.get_mcp_service(mcp_id)
+            if user_sub not in mcp_service.activated:
+                logger.warning(
+                    "[MCPAgentExecutor] 用户 %s 未启用MCP %s",
+                    user_sub,
+                    mcp_id,
+                )
+                continue
+
+            client = await MCPHost.get_client(user_sub, mcp_id)
+            if client is None:
+                logger.warning(
+                    "[MCPAgentExecutor] 用户 %s 的MCP %s 没有可用的客户端,已跳过",
+                    user_sub,
+                    mcp_id,
+                )
+                continue
+
+            self.mcp_list.append(mcp_service)
+            self.mcp_client[mcp_id] = client
+            for tool in mcp_service.tools:
+                self.tools[tool.id] = tool
+
+    async def plan(self, is_replan: bool = False, start_index: int | None = None) -> None:
+        """
+        Create the initial plan, or regenerate the tail of the current plan after an error.
+
+        :param is_replan: False builds a new plan from the question; True keeps
+            ``plans[:start_index]`` and asks the planner for replacement steps, passing
+            the last error message
+        :param start_index: first plan index to regenerate when replanning; when None the
+            planner is asked to choose it
+        """
+        question = self.task.runtime.question
+        if is_replan:
+            error_message = "之前的计划遇到以下报错\n\n" + self.task.state.error_message
+        else:
+            error_message = "初始化计划"
+        # NOTE(review): the result is passed on as a plain tool list; if select_top_tool is a
+        # coroutine function it needs an await here - confirm against MCPSelector.
+        tools = MCPSelector.select_top_tool(
+            question, list(self.tools.values()),
+            additional_info=error_message, top_n=40)
+        if is_replan:
+            logger.info("[MCPAgentExecutor] 重新规划流程")
+            # Compare against None so an explicit start_index of 0 is honoured.
+            if start_index is None:
+                # NOTE(review): the value is used below as a slice index; confirm the planner
+                # returns a plain int here.
+                start_index = await MCPPlanner.get_replan_start_step_index(
+                    question,
+                    self.task.state.error_message,
+                    self.task.runtime.temporary_plans,
+                    self.resoning_llm,
+                )
+            old_plans = self.task.runtime.temporary_plans.plans
+            error_message = self.task.state.error_message
+            new_plan = await MCPPlanner.create_plan(
+                question,
+                is_replan=is_replan,
+                error_message=error_message,
+                current_plan=old_plans[start_index:],
+                tool_list=tools,
+                max_steps=self.max_steps - start_index - 1,
+                reasoning_llm=self.resoning_llm,
+            )
+            # update_tokens and push_message are coroutines; without await they never run.
+            await self.update_tokens()
+            await self.push_message(
+                EventType.STEP_CANCEL,
+                data={},
+            )
+            if self.task.context and self.task.context[-1].step_id == self.task.state.step_id:
+                self.task.context[-1].step_status = StepStatus.CANCELLED
+            # Replace the step list inside the existing plan object. Assigning the bare list
+            # to temporary_plans would break the `.plans` access in the loop below.
+            self.task.runtime.temporary_plans.plans = old_plans[:start_index] + new_plan.plans
+            self.task.state.step_index = start_index
+        else:
+            start_index = 0
+            self.task.runtime.temporary_plans = await MCPPlanner.create_plan(
+                question,
+                tool_list=tools,
+                max_steps=self.max_steps,
+                reasoning_llm=self.resoning_llm,
+            )
+        # Give every (re)generated step a fresh unique id.
+        for plan_item in self.task.runtime.temporary_plans.plans[start_index:]:
+            plan_item.step_id = str(uuid.uuid4())
+
+    async def get_tool_input_param(self, is_first: bool) -> None:
+        """
+        Generate the input arguments for the current step's tool into ``state.current_input``.
+
+        :param is_first: True derives the arguments from the question and the task; False
+            repairs the current arguments using the last error message and any
+            user-supplied ``self.params``
+        """
+        state = self.task.state
+        if is_first:
+            state.current_input = await MCPHost._get_first_input_params(
+                self.tools[state.step_id], self.task.runtime.question, self.task,
+            )
+            return
+
+        # Only a real `param` instance contributes user-supplied values.
+        supplied = self.params if isinstance(self.params, param) else None
+        extra = supplied.content if supplied is not None else {}
+        extra_description = supplied.description if supplied is not None else ""
+        state.current_input = await MCPHost._fill_params(
+            self.tools[state.step_id],
+            state.current_input,
+            state.error_message,
+            extra,
+            extra_description,
+        )
+
+    async def reset_step_to_index(self, start_index: int) -> None:
+        """
+        Point the executor state at the plan step with the given index.
+
+        When a plan exists, the step id, name and description are copied from
+        ``plans[start_index]``, the retry counter is cleared and the flow is marked running.
+        Without a plan the flow is marked successful on the final tool id.
+
+        :param start_index: index into ``temporary_plans.plans`` to resume from
+        """
+        logger.info("[MCPAgentExecutor] 重置步骤到索引 %d", start_index)
+        state = self.task.state
+        if self.task.runtime.temporary_plans:
+            step = self.task.runtime.temporary_plans.plans[start_index]
+            state.flow_status = FlowStatus.RUNNING
+            state.step_id = step.step_id
+            # Keep the index in step with the step just selected. It used to be forced to 0,
+            # which desynchronised it from step_id after a replan.
+            state.step_index = start_index
+            state.step_name = step.tool
+            state.step_description = step.content
+            # NOTE(review): RUNNING makes work() skip its INIT branch, where the tool input
+            # is first generated - confirm this is intended.
+            state.step_status = StepStatus.RUNNING
+            state.retry_times = 0
+        else:
+            state.flow_status = FlowStatus.SUCCESS
+            state.step_id = FINAL_TOOL_ID
+
+    async def confirm_before_step(self) -> None:
+        """
+        Ask the user to confirm the current step before it runs.
+
+        Obtains a risk assessment for the tool call, pushes it as a STEP_WAITING_FOR_START
+        event followed by FLOW_STOP, marks flow and step as waiting, and records a history
+        entry carrying the assessment in ``ex_data``.
+        """
+        state = self.task.state
+        logger.info("[MCPAgentExecutor] 等待用户确认步骤 %d", state.step_index)
+        confirm_message = await MCPPlanner.get_tool_risk(
+            self.tools[state.step_id], state.current_input, "", self.resoning_llm,
+        )
+        # These are coroutines; without await the events and token update never happen.
+        await self.update_tokens()
+        await self.push_message(
+            EventType.STEP_WAITING_FOR_START,
+            confirm_message.model_dump(exclude_none=True, by_alias=True),
+        )
+        await self.push_message(EventType.FLOW_STOP, {})
+        state.flow_status = FlowStatus.WAITING
+        state.step_status = StepStatus.WAITING
+        self.task.context.append(
+            FlowStepHistory(
+                step_id=state.step_id,
+                step_name=state.step_name,
+                step_description=state.step_description,
+                step_status=state.step_status,
+                flow_id=state.flow_id,
+                flow_name=state.flow_name,
+                flow_status=state.flow_status,
+                input_data={},
+                output_data={},
+                ex_data=confirm_message.model_dump(exclude_none=True, by_alias=True),
+            ),
+        )
+
+    async def run_step(self) -> None:
+        """
+        Call the current step's tool once and record the outcome.
+
+        On success the input and output are pushed as events, a history entry is appended
+        and the step is marked SUCCESS. On any failure the traceback is stored in
+        ``state.error_message`` and the step is marked ERROR. If no client is registered
+        for the tool's MCP server, the flow and step are marked ERROR and nothing is called.
+        """
+        import traceback
+
+        state = self.task.state
+        state.flow_status = FlowStatus.RUNNING
+        state.step_status = StepStatus.RUNNING
+        logger.info("[MCPAgentExecutor] 执行步骤 %d", state.step_index)
+        mcp_tool = self.tools[state.step_id]
+        # .get() so an unregistered server takes the error path below instead of raising KeyError.
+        mcp_client = self.mcp_client.get(mcp_tool.mcp_id)
+        if not mcp_client:
+            logger.error("[MCPAgentExecutor] MCP客户端未找到: %s", mcp_tool.mcp_id)
+            error = "[MCPAgentExecutor] MCP客户端未找到: {}".format(mcp_tool.mcp_id)
+            state.flow_status = FlowStatus.ERROR
+            state.step_status = StepStatus.ERROR
+            state.error_message = error
+            # Stop here: falling through would call a method on None and overwrite this
+            # message with an unrelated AttributeError traceback.
+            return
+        try:
+            output_params = await mcp_client.call_tool(mcp_tool.name, state.current_input)
+            await self.update_tokens()
+            await self.push_message(
+                EventType.STEP_INPUT,
+                state.current_input,
+            )
+            await self.push_message(
+                EventType.STEP_OUTPUT,
+                output_params,
+            )
+            self.task.context.append(
+                FlowStepHistory(
+                    step_id=state.step_id,
+                    step_name=state.step_name,
+                    step_description=state.step_description,
+                    step_status=StepStatus.SUCCESS,
+                    flow_id=state.flow_id,
+                    flow_name=state.flow_name,
+                    flow_status=state.flow_status,
+                    input_data=state.current_input,
+                    output_data=output_params,
+                ),
+            )
+            state.step_status = StepStatus.SUCCESS
+        except Exception as e:
+            # Tool failures are expected here; they are turned into step state for the
+            # caller's retry/replan logic rather than propagated.
+            logger.warning("[MCPAgentExecutor] 执行步骤 %s 失败: %s", mcp_tool.name, str(e))
+            state.error_message = traceback.format_exc()
+            state.step_status = StepStatus.ERROR
+
+    async def generate_params_with_null(self) -> None:
+        """
+        Ask the user for the parameters the failed tool call was missing.
+
+        Obtains the missing-parameter structure from the planner, pushes it with the error
+        text as a STEP_WAITING_FOR_PARAM event followed by FLOW_STOP, marks the flow WAITING
+        and the step PARAM, and records a history entry carrying the same payload.
+        """
+        state = self.task.state
+        params_with_null = await MCPPlanner.get_missing_param(
+            self.tools[state.step_id],
+            state.current_input,
+            state.error_message,
+            self.resoning_llm,
+        )
+        # These are coroutines; without await the events and token update never happen.
+        await self.update_tokens()
+        payload = {
+            "message": "当运行产生如下报错:\n" + state.error_message,
+            "params": params_with_null,
+        }
+        await self.push_message(
+            EventType.STEP_WAITING_FOR_PARAM,
+            data=payload,
+        )
+        await self.push_message(
+            EventType.FLOW_STOP,
+            data={},
+        )
+        state.flow_status = FlowStatus.WAITING
+        state.step_status = StepStatus.PARAM
+        self.task.context.append(
+            FlowStepHistory(
+                step_id=state.step_id,
+                step_name=state.step_name,
+                step_description=state.step_description,
+                step_status=state.step_status,
+                flow_id=state.flow_id,
+                flow_name=state.flow_name,
+                flow_status=state.flow_status,
+                input_data={},
+                output_data={},
+                # Separate dict so the history entry does not alias the pushed event payload.
+                ex_data=dict(payload),
+            ),
+        )
+
+    async def get_next_step(self) -> None:
+        """
+        Advance the state to the next planned step, or finish the flow.
+
+        The flow is marked SUCCESS and a FLOW_SUCCESS event is pushed when the plan is
+        exhausted or the next step is the final tool. Otherwise the state is loaded from
+        the next plan item with status INIT and a STEP_INIT event is pushed.
+        """
+        state = self.task.state
+        # Measure the step list; the plan object itself is indexed via `.plans` everywhere else.
+        plans = self.task.runtime.temporary_plans.plans
+        state.step_index += 1
+        next_step = plans[state.step_index] if state.step_index < len(plans) else None
+        # NOTE(review): plan() overwrites every step_id with a random UUID, so this comparison
+        # may never match FINAL_TOOL_ID; the plan item's `tool` field may be what is meant.
+        if next_step is None or next_step.step_id == FINAL_TOOL_ID:
+            state.flow_status = FlowStatus.SUCCESS
+            state.step_status = StepStatus.SUCCESS
+            await self.push_message(
+                EventType.FLOW_SUCCESS,
+                data={},
+            )
+            return
+        state.step_id = next_step.step_id
+        state.step_name = next_step.tool
+        state.step_description = next_step.content
+        state.step_status = StepStatus.INIT
+        state.current_input = {}
+        await self.push_message(
+            EventType.STEP_INIT,
+            data={},
+        )
+
+    async def error_handle_after_step(self) -> None:
+        """
+        Mark the step and the flow as failed after a step could not be recovered.
+
+        Pushes a FLOW_FAILED event and replaces the current step's latest history entry,
+        if there is one, with an ERROR entry.
+        """
+        state = self.task.state
+        state.step_status = StepStatus.ERROR
+        state.flow_status = FlowStatus.ERROR
+        # push_message is a coroutine; without await the event is never sent.
+        await self.push_message(
+            EventType.FLOW_FAILED,
+            data={},
+        )
+        if self.task.context and self.task.context[-1].step_id == state.step_id:
+            del self.task.context[-1]
+        self.task.context.append(
+            FlowStepHistory(
+                step_id=state.step_id,
+                step_name=state.step_name,
+                step_description=state.step_description,
+                step_status=state.step_status,
+                flow_id=state.flow_id,
+                flow_name=state.flow_name,
+                flow_status=state.flow_status,
+                input_data={},
+                output_data={},
+            ),
+        )
+
+    async def work(self) -> None:
+        """
+        Process the current step once, dispatching on its status.
+
+        - INIT: generate the tool input; unless the user enabled auto-execute, ask for
+          confirmation and stop.
+        - PARAM / WAITING / RUNNING: drop the pending history entry, or cancel the flow if a
+          waiting step was not confirmed, then run the tool with up to five attempts,
+          repairing the input between attempts.
+        - ERROR: give up after three retries; otherwise classify the error and either
+          replan or ask the user for missing parameters.
+        - SUCCESS: move on to the next step.
+        """
+        state = self.task.state
+        context = self.task.context
+        if state.step_status == StepStatus.INIT:
+            await self.get_tool_input_param(is_first=True)
+            user_info = await UserManager.get_userinfo_by_user_sub(self.task.ids.user_sub)
+            # NOTE(review): the User model column added in this change is `autoExecute`;
+            # confirm the returned object really exposes `auto_execute`.
+            if not user_info.auto_execute:
+                # Wait for the user's confirmation.
+                await self.confirm_before_step()
+                return
+            state.step_status = StepStatus.RUNNING
+        elif state.step_status in [StepStatus.PARAM, StepStatus.WAITING, StepStatus.RUNNING]:
+            # Guard against an empty history before looking at its last entry.
+            last_is_current = bool(context) and context[-1].step_id == state.step_id
+            if context and context[-1].step_status == StepStatus.PARAM:
+                if last_is_current:
+                    del context[-1]
+            elif state.step_status == StepStatus.WAITING:
+                # No params object at all counts as "not confirmed".
+                if self.params is not None and self.params.content:
+                    if last_is_current:
+                        del context[-1]
+                else:
+                    state.flow_status = FlowStatus.CANCELLED
+                    state.step_status = StepStatus.CANCELLED
+                    await self.push_message(
+                        EventType.STEP_CANCEL,
+                        data={},
+                    )
+                    await self.push_message(
+                        EventType.FLOW_CANCEL,
+                        data={},
+                    )
+                    if last_is_current:
+                        context[-1].step_status = StepStatus.CANCELLED
+                    # Cancelled: do not fall through and execute the tool anyway.
+                    return
+            if state.step_status == StepStatus.PARAM:
+                await self.get_tool_input_param(is_first=False)
+            max_retry = 5
+            for attempt in range(max_retry):
+                if attempt != 0:
+                    await self.get_tool_input_param(is_first=False)
+                await self.run_step()
+                if state.step_status == StepStatus.SUCCESS:
+                    break
+        elif state.step_status == StepStatus.ERROR:
+            # Error handling.
+            if state.retry_times >= 3:
+                await self.error_handle_after_step()
+            else:
+                user_info = await UserManager.get_userinfo_by_user_sub(self.task.ids.user_sub)
+                mcp_tool = self.tools[state.step_id]
+                error_type = await MCPPlanner.get_tool_execute_error_type(
+                    self.task.runtime.question,
+                    self.task.runtime.temporary_plans,
+                    mcp_tool,
+                    state.current_input,
+                    state.error_message,
+                    self.resoning_llm,
+                )
+                if error_type.type == ErrorType.DECORRECT_PLAN or user_info.auto_execute:
+                    await self.plan(is_replan=True)
+                    # reset_step_to_index is a coroutine; it must be awaited to take effect.
+                    await self.reset_step_to_index(state.step_index)
+                elif error_type.type == ErrorType.MISSING_PARAM:
+                    await self.generate_params_with_null()
+        elif state.step_status == StepStatus.SUCCESS:
+            await self.get_next_step()
+
+    async def summarize(self) -> None:
+        """
+        Stream the final answer to the client.
+
+        Each chunk produced by the planner is pushed as a TEXT_ADD event and appended to
+        ``task.runtime.answer``.
+        """
+        memory = await MCPHost.assemble_memory(self.task)
+        async for chunk in MCPPlanner.generate_answer(
+            self.task.runtime.question,
+            self.task.runtime.temporary_plans,
+            memory,
+            self.resoning_llm,
+        ):
+            # push_message is a coroutine; without await no chunk reaches the client.
+            await self.push_message(
+                EventType.TEXT_ADD,
+                data=chunk,
+            )
+            self.task.runtime.answer += chunk
+
+    async def run(self) -> None:
+        """
+        Main loop of the MCP agent.
+
+        Loads saved state and the MCP services, creates a plan when the flow is still in
+        INIT, then calls ``work()`` repeatedly while steps remain and the flow is RUNNING,
+        saving the task after each call. Any exception marks the flow and step as ERROR,
+        pushes STEP_ERROR and FLOW_FAILED, and replaces the current step's history entry
+        with an ERROR entry.
+        """
+        # Every call below marked `await` is a coroutine; un-awaited it would silently not run.
+        await self.load_state()
+        await self.load_mcp()
+        if self.task.state.flow_status == FlowStatus.INIT:
+            # First run: name the flow and build the initial plan.
+            self.task.state.flow_id = str(uuid.uuid4())
+            self.task.state.flow_name = await MCPPlanner.get_flow_name(
+                self.task.runtime.question, self.resoning_llm,
+            )
+            await self.plan(is_replan=False)
+            await self.reset_step_to_index(0)
+            await TaskManager.save_task(self.task.id, self.task)
+        self.task.state.flow_status = FlowStatus.RUNNING
+        await self.push_message(
+            EventType.FLOW_START,
+            data={},
+        )
+        try:
+            # Compare against the step list; the plan object is indexed via `.plans` elsewhere.
+            while (
+                self.task.state.step_index < len(self.task.runtime.temporary_plans.plans)
+                and self.task.state.flow_status == FlowStatus.RUNNING
+            ):
+                await self.work()
+                await TaskManager.save_task(self.task.id, self.task)
+        except Exception as e:
+            # Top-level boundary of the agent run: record the failure instead of propagating.
+            logger.exception("[MCPAgentExecutor] 执行过程中发生错误: %s", str(e))
+            state = self.task.state
+            state.flow_status = FlowStatus.ERROR
+            state.error_message = str(e)
+            state.step_status = StepStatus.ERROR
+            await self.push_message(
+                EventType.STEP_ERROR,
+                data={},
+            )
+            await self.push_message(
+                EventType.FLOW_FAILED,
+                data={},
+            )
+            if self.task.context and self.task.context[-1].step_id == state.step_id:
+                del self.task.context[-1]
+            self.task.context.append(
+                FlowStepHistory(
+                    step_id=state.step_id,
+                    step_name=state.step_name,
+                    step_description=state.step_description,
+                    step_status=state.step_status,
+                    flow_id=state.flow_id,
+                    flow_name=state.flow_name,
+                    flow_status=state.flow_status,
+                    input_data={},
+                    output_data={},
+                ),
+            )
diff --git a/apps/scheduler/executor/base.py b/apps/scheduler/executor/base.py
index cf2f4e6838f5d8d5b6578db2750543ddfd22ee75..86e72483521794242b730a58af48e1dedd7fbebb 100644
--- a/apps/scheduler/executor/base.py
+++ b/apps/scheduler/executor/base.py
@@ -49,10 +49,8 @@ class BaseExecutor(BaseModel, ABC):
question=self.question,
params=self.task.runtime.filled,
).model_dump(exclude_none=True, by_alias=True)
- elif event_type == EventType.FLOW_STOP.value:
- data = {}
elif event_type == EventType.TEXT_ADD.value and isinstance(data, str):
- data=TextAddContent(text=data).model_dump(exclude_none=True, by_alias=True)
+ data = TextAddContent(text=data).model_dump(exclude_none=True, by_alias=True)
if data is None:
data = {}
diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py
index 0f4978a1eec3f1a05d0e829defb455d4088e82ab..08b34eccae6f1412ad0ef860621bcd332fb91919 100644
--- a/apps/scheduler/executor/flow.py
+++ b/apps/scheduler/executor/flow.py
@@ -57,7 +57,7 @@ class FlowExecutor(BaseExecutor):
"""从数据库中加载FlowExecutor的状态"""
logger.info("[FlowExecutor] 加载Executor状态")
# 尝试恢复State
- if self.task.state:
+ if self.task.state and self.task.state.flow_status != FlowStatus.INIT:
self.task.context = await TaskManager.get_context_by_task_id(self.task.id)
else:
# 创建ExecutorState
@@ -125,7 +125,7 @@ class FlowExecutor(BaseExecutor):
return []
if self.current_step.step.type == SpecialCallType.CHOICE.value:
# 如果是choice节点,获取分支ID
- branch_id = self.task.context[-1]["output_data"]["branch_id"]
+ branch_id = self.task.context[-1].output_data["branch_id"]
if branch_id:
next_steps = await self._find_next_id(self.task.state.step_id + "." + branch_id)
logger.info("[FlowExecutor] 分支ID:%s", branch_id)
diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py
index c9d30e543e131b7d382a00171e9278f0ad57ab85..233c8999af5aa882b23b3f962a20b6a57a473f11 100644
--- a/apps/scheduler/executor/step.py
+++ b/apps/scheduler/executor/step.py
@@ -263,7 +263,7 @@ class StepExecutor(BaseExecutor):
input_data=self.obj.input,
output_data=output_data,
)
- self.task.context.append(history.model_dump(exclude_none=True, by_alias=True))
+ self.task.context.append(history)
# 推送输出
await self.push_message(EventType.STEP_OUTPUT.value, output_data)
diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py
index 6b78e65dbb66d2e2ce3782543f2a0698327df373..443fd29d6e9d126bd71c960599cbc68864952e50 100644
--- a/apps/scheduler/mcp/host.py
+++ b/apps/scheduler/mcp/host.py
@@ -65,7 +65,7 @@ class MCPHost:
context_list = []
for ctx_id in self._context_list:
- context = next((ctx for ctx in task.context if ctx["_id"] == ctx_id), None)
+ context = next((ctx for ctx in task.context if ctx.id == ctx_id), None)
if not context:
continue
context_list.append(context)
@@ -117,7 +117,7 @@ class MCPHost:
logger.error("任务 %s 不存在", self._task_id)
return {}
self._context_list.append(context.id)
- task.context.append(context.model_dump(by_alias=True, exclude_none=True))
+ task.context.append(context.model_dump(exclude_none=True, by_alias=True))
await TaskManager.save_task(self._task_id, task)
return output_data
diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py
index a8ebec7b4d1182684a02e5b9f9ce68189af6c2e4..a56ce03bb5331bb3d4c7cf9fc20bcb9dda564128 100644
--- a/apps/scheduler/mcp_agent/host.py
+++ b/apps/scheduler/mcp_agent/host.py
@@ -12,118 +12,55 @@ from mcp.types import TextContent
from apps.common.mongo import MongoDB
from apps.llm.function import JsonGenerator
from apps.scheduler.mcp.prompt import MEMORY_TEMPLATE
+from apps.scheduler.mcp_agent.prompt import REPAIR_PARAMS
from apps.scheduler.pool.mcp.client import MCPClient
from apps.scheduler.pool.mcp.pool import MCPPool
from apps.schemas.enum_var import StepStatus
from apps.schemas.mcp import MCPPlanItem, MCPTool
-from apps.schemas.task import FlowStepHistory
+from apps.schemas.task import FlowStepHistory, Task
from apps.services.task import TaskManager
logger = logging.getLogger(__name__)
+# Module-level Jinja2 environment used to render the prompt templates in this module.
+# The loader must be an instance; passing the BaseLoader class itself only works as long as
+# nothing ever asks the environment to load a template by name.
+# NOTE(review): autoescape=True HTML-escapes rendered values (e.g. quotes in JSON/schema
+# text become entities). The per-instance environment this replaces used autoescape=False;
+# confirm escaping is wanted for LLM prompts.
+_env = SandboxedEnvironment(
+    loader=BaseLoader(),
+    autoescape=True,
+    trim_blocks=True,
+    lstrip_blocks=True,
+)
+
class MCPHost:
"""MCP宿主服务"""
- def __init__(self, user_sub: str, task_id: str, runtime_id: str, runtime_name: str) -> None:
- """初始化MCP宿主"""
- self._user_sub = user_sub
- self._task_id = task_id
- # 注意:runtime在工作流中是flow_id和step_description,在Agent中可为标识Agent的id和description
- self._runtime_id = runtime_id
- self._runtime_name = runtime_name
- self._context_list = []
- self._env = SandboxedEnvironment(
- loader=BaseLoader(),
- autoescape=False,
- trim_blocks=True,
- lstrip_blocks=True,
- )
-
- async def get_client(self, mcp_id: str) -> MCPClient | None:
+ @staticmethod
+ async def get_client(user_sub, mcp_id: str) -> MCPClient | None:
"""获取MCP客户端"""
mongo = MongoDB()
mcp_collection = mongo.get_collection("mcp")
# 检查用户是否启用了这个mcp
- mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": self._user_sub})
+ mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": user_sub})
if not mcp_db_result:
- logger.warning("用户 %s 未启用MCP %s", self._user_sub, mcp_id)
+ logger.warning("用户 %s 未启用MCP %s", user_sub, mcp_id)
return None
# 获取MCP配置
try:
- return await MCPPool().get(mcp_id, self._user_sub)
+ return await MCPPool().get(mcp_id, user_sub)
except KeyError:
- logger.warning("用户 %s 的MCP %s 没有运行中的实例,请检查环境", self._user_sub, mcp_id)
+ logger.warning("用户 %s 的MCP %s 没有运行中的实例,请检查环境", user_sub, mcp_id)
return None
- async def assemble_memory(self) -> str:
+ @staticmethod
+ async def assemble_memory(task: Task) -> str:
"""组装记忆"""
- task = await TaskManager.get_task_by_task_id(self._task_id)
- if not task:
- logger.error("任务 %s 不存在", self._task_id)
- return ""
-
- context_list = []
- for ctx_id in self._context_list:
- context = next((ctx for ctx in task.context if ctx["_id"] == ctx_id), None)
- if not context:
- continue
- context_list.append(context)
- return self._env.from_string(MEMORY_TEMPLATE).render(
- context_list=context_list,
- )
-
- async def _save_memory(
- self,
- tool: MCPTool,
- plan_item: MCPPlanItem,
- input_data: dict[str, Any],
- result: str,
- ) -> dict[str, Any]:
- """保存记忆"""
- try:
- output_data = json.loads(result)
- except Exception: # noqa: BLE001
- logger.warning("[MCPHost] 得到的数据不是dict格式!尝试转换为str")
- output_data = {
- "message": result,
- }
-
- if not isinstance(output_data, dict):
- output_data = {
- "message": result,
- }
-
- # 创建context;注意用法
- context = FlowStepHistory(
- task_id=self._task_id,
- flow_id=self._runtime_id,
- flow_name=self._runtime_name,
- flow_status=StepStatus.SUCCESS,
- step_id=tool.name,
- step_name=tool.name,
- # description是规划的实际内容
- step_description=plan_item.content,
- step_status=StepStatus.SUCCESS,
- input_data=input_data,
- output_data=output_data,
+ return _env.from_string(MEMORY_TEMPLATE).render(
+ context_list=task.context,
)
- # 保存到task
- task = await TaskManager.get_task_by_task_id(self._task_id)
- if not task:
- logger.error("任务 %s 不存在", self._task_id)
- return {}
- self._context_list.append(context.id)
- task.context.append(context.model_dump(by_alias=True, exclude_none=True))
- await TaskManager.save_task(self._task_id, task)
-
- return output_data
-
- async def _fill_params(self, schema: dict[str, Any], query: str) -> dict[str, Any]:
+ async def _get_first_input_params(mcp_tool: MCPTool, query: str, task: Task) -> dict[str, Any]:
"""填充工具参数"""
# 更清晰的输入·指令,这样可以调用generate
llm_query = rf"""
@@ -137,23 +74,48 @@ class MCPHost:
llm_query,
[
{"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": await self.assemble_memory()},
+ {"role": "user", "content": await MCPHost.assemble_memory(task)},
],
- schema,
+ mcp_tool.input_schema,
)
return await json_generator.generate()
- async def call_tool(self, tool: MCPTool, plan_item: MCPPlanItem) -> list[dict[str, Any]]:
+    @staticmethod
+    async def _fill_params(mcp_tool: MCPTool,
+                           current_input: dict[str, Any],
+                           error_message: str = "", params: dict[str, Any] | None = None,
+                           params_description: str = "") -> dict[str, Any]:
+        """
+        Ask the LLM for a repaired set of arguments for ``mcp_tool``.
+
+        Renders the ``REPAIR_PARAMS`` prompt and lets ``JsonGenerator`` produce JSON
+        constrained by ``mcp_tool.input_schema``.
+
+        :param mcp_tool: tool whose arguments are being repaired
+        :param current_input: the tool's current arguments
+        :param error_message: error text to show the model (may be empty)
+        :param params: supplementary values for the repair; ``None`` means none
+        :param params_description: free-text description of ``params``
+        :return: the result of ``JsonGenerator.generate()`` for the tool's input schema
+        """
+        # A ``{}`` default in the signature would be one dict shared by every call.
+        supplied_params = {} if params is None else params
+        llm_query = "请生成修复之后的工具参数"
+        prompt = _env.from_string(REPAIR_PARAMS).render(
+            tool_name=mcp_tool.name,
+            tool_description=mcp_tool.description,
+            input_schema=mcp_tool.input_schema,
+            # REPAIR_PARAMS reads the current arguments from ``input_param``;
+            # any other keyword renders as an empty string.
+            input_param=current_input,
+            error_message=error_message,
+            params=supplied_params,
+            params_description=params_description,
+        )
+
+        json_generator = JsonGenerator(
+            llm_query,
+            [
+                {"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": prompt},
+            ],
+            mcp_tool.input_schema,
+        )
+        return await json_generator.generate()
+
+ async def call_tool(user_sub: str, tool: MCPTool, plan_item: MCPPlanItem) -> list[dict[str, Any]]:
"""调用工具"""
# 拿到Client
- client = await MCPPool().get(tool.mcp_id, self._user_sub)
+ client = await MCPPool().get(tool.mcp_id, user_sub)
if client is None:
err = f"[MCPHost] MCP Server不合法: {tool.mcp_id}"
logger.error(err)
raise ValueError(err)
# 填充参数
- params = await self._fill_params(tool, plan_item.instruction)
+ params = await MCPHost._fill_params(tool, plan_item.instruction)
# 调用工具
result = await client.call_tool(tool.name, params)
# 保存记忆
@@ -162,29 +124,12 @@ class MCPHost:
if not isinstance(item, TextContent):
logger.error("MCP结果类型不支持: %s", item)
continue
- processed_result.append(await self._save_memory(tool, plan_item, params, item.text))
-
- return processed_result
-
- async def get_tool_list(self, mcp_id_list: list[str]) -> list[MCPTool]:
- """获取工具列表"""
- mongo = MongoDB()
- mcp_collection = mongo.get_collection("mcp")
-
- # 获取工具列表
- tool_list = []
- for mcp_id in mcp_id_list:
- # 检查用户是否启用了这个mcp
- mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": self._user_sub})
- if not mcp_db_result:
- logger.warning("用户 %s 未启用MCP %s", self._user_sub, mcp_id)
- continue
- # 获取MCP工具配置
+ result = item.text
try:
- for tool in mcp_db_result["tools"]:
- tool_list.extend([MCPTool.model_validate(tool)])
- except KeyError:
- logger.warning("用户 %s 的MCP Tool %s 配置错误", self._user_sub, mcp_id)
+ json_result = json.loads(result)
+ except Exception as e:
+ logger.error("MCP结果解析失败: %s, 错误: %s", result, e)
continue
+ processed_result.append(json_result)
- return tool_list
+ return processed_result
\ No newline at end of file
diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py
index ec90f57b8bb490db64e3a4652933ad3b369ad4d9..474a7010c3b3739516c747a3520d43d27fbb460f 100644
--- a/apps/scheduler/mcp_agent/plan.py
+++ b/apps/scheduler/mcp_agent/plan.py
@@ -1,6 +1,6 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""MCP 用户目标拆解与规划"""
-from typing import Any
+from typing import Any, AsyncGenerator
from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment
@@ -13,30 +13,26 @@ from apps.scheduler.mcp_agent.prompt import (
FINAL_ANSWER,
GENERATE_FLOW_NAME,
GET_MISSING_PARAMS,
+ GET_REPLAN_START_STEP_INDEX,
RECREATE_PLAN,
RISK_EVALUATE,
+ TOOL_EXECUTE_ERROR_TYPE_ANALYSIS,
)
from apps.scheduler.slot.slot import Slot
-from apps.schemas.mcp import GoalEvaluationResult, MCPPlan, MCPTool, ToolRisk
+from apps.schemas.mcp import GoalEvaluationResult, MCPPlan, MCPTool, RestartStepIndex, ToolExcutionErrorType, ToolRisk
+
+# Module-level sandboxed Jinja2 environment; the MCPPlanner methods below render
+# their prompt templates through `_env.from_string(...)`.
+# NOTE(review): autoescape=True HTML-escapes rendered values (quotes, <, >),
+# which may be unwanted in LLM prompts -- confirm this is intended.
+_env = SandboxedEnvironment(
+ loader=BaseLoader,
+ autoescape=True,
+ trim_blocks=True,
+ lstrip_blocks=True,
+)
class MCPPlanner:
"""MCP 用户目标拆解与规划"""
-
- def __init__(self, user_goal: str, resoning_llm: ReasoningLLM = None) -> None:
- """初始化MCP规划器"""
- self.user_goal = user_goal
- self._env = SandboxedEnvironment(
- loader=BaseLoader,
- autoescape=True,
- trim_blocks=True,
- lstrip_blocks=True,
- )
- self.resoning_llm = resoning_llm or ReasoningLLM()
- self.input_tokens = 0
- self.output_tokens = 0
-
- async def get_resoning_result(self, prompt: str) -> str:
+ @staticmethod
+ async def get_resoning_result(prompt: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> str:
"""获取推理结果"""
# 调用推理大模型
message = [
@@ -44,7 +40,7 @@ class MCPPlanner:
{"role": "user", "content": prompt},
]
result = ""
- async for chunk in self.resoning_llm.call(
+ async for chunk in resoning_llm.call(
message,
streaming=False,
temperature=0.07,
@@ -52,12 +48,10 @@ class MCPPlanner:
):
result += chunk
- # 保存token用量
- self.input_tokens = self.resoning_llm.input_tokens
- self.output_tokens = self.resoning_llm.output_tokens
return result
- async def _parse_result(self, result: str, schema: dict[str, Any]) -> str:
+ @staticmethod
+ async def _parse_result(result: str, schema: dict[str, Any]) -> str:
"""解析推理结果"""
json_generator = JsonGenerator(
result,
@@ -70,155 +64,237 @@ class MCPPlanner:
json_result = await json_generator.generate()
return json_result
- async def evaluate_goal(self, tool_list: list[MCPTool]) -> GoalEvaluationResult:
+ @staticmethod
+ async def evaluate_goal(
+ tool_list: list[MCPTool],
+ resoning_llm: ReasoningLLM = ReasoningLLM()) -> GoalEvaluationResult:
"""评估用户目标的可行性"""
# 获取推理结果
- result = await self._get_reasoning_evaluation(tool_list)
+ result = await MCPPlanner._get_reasoning_evaluation(tool_list, resoning_llm)
# 解析为结构化数据
- evaluation = await self._parse_evaluation_result(result)
+ evaluation = await MCPPlanner._parse_evaluation_result(result)
# 返回评估结果
return evaluation
- async def _get_reasoning_evaluation(self, tool_list: list[MCPTool]) -> str:
+ @staticmethod
+ async def _get_reasoning_evaluation(
+ goal, tool_list: list[MCPTool],
+ resoning_llm: ReasoningLLM = ReasoningLLM()) -> str:
"""获取推理大模型的评估结果"""
- template = self._env.from_string(EVALUATE_GOAL)
+ template = _env.from_string(EVALUATE_GOAL)
prompt = template.render(
- goal=self.user_goal,
+ goal=goal,
tools=tool_list,
)
- result = await self.get_resoning_result(prompt)
+ result = await MCPPlanner.get_resoning_result(prompt, resoning_llm)
return result
- async def _parse_evaluation_result(self, result: str) -> GoalEvaluationResult:
+ @staticmethod
+ async def _parse_evaluation_result(result: str) -> GoalEvaluationResult:
"""将推理结果解析为结构化数据"""
schema = GoalEvaluationResult.model_json_schema()
- evaluation = await self._parse_result(result, schema)
+ evaluation = await MCPPlanner._parse_result(result, schema)
# 使用GoalEvaluationResult模型解析结果
return GoalEvaluationResult.model_validate(evaluation)
- async def get_flow_name(self) -> str:
+ async def get_flow_name(user_goal: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> str:
"""获取当前流程的名称"""
- result = await self._get_reasoning_flow_name()
+ result = await MCPPlanner._get_reasoning_flow_name(user_goal, resoning_llm)
return result
- async def _get_reasoning_flow_name(self) -> str:
+ @staticmethod
+ async def _get_reasoning_flow_name(user_goal: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> str:
"""获取推理大模型的流程名称"""
- template = self._env.from_string(GENERATE_FLOW_NAME)
- prompt = template.render(goal=self.user_goal)
- result = await self.get_resoning_result(prompt)
+ template = _env.from_string(GENERATE_FLOW_NAME)
+ prompt = template.render(goal=user_goal)
+ result = await MCPPlanner.get_resoning_result(prompt, resoning_llm)
return result
- async def create_plan(self, tool_list: list[MCPTool], max_steps: int = 6) -> MCPPlan:
+    @staticmethod
+    async def get_replan_start_step_index(
+            user_goal: str, error_message: str, current_plan: MCPPlan | None = None,
+            history: str = "",
+            reasoning_llm: ReasoningLLM | None = None) -> RestartStepIndex:
+        """
+        Ask the LLM which step of ``current_plan`` re-planning should start from.
+
+        :param user_goal: the user's goal, rendered into the prompt
+        :param error_message: error text rendered into the prompt
+        :param current_plan: the plan being re-planned; required despite the ``None`` default
+        :param history: execution history text rendered into the prompt
+        :param reasoning_llm: LLM to use; a new ``ReasoningLLM`` is created when omitted
+        :return: the validated ``RestartStepIndex``
+        :raises ValueError: if ``current_plan`` is ``None``
+        """
+        if current_plan is None:
+            # The plan is dumped into the prompt and bounds the index below,
+            # so fail clearly instead of raising AttributeError on ``None``.
+            err = "[MCPPlanner] 获取重新规划的步骤索引时必须提供当前计划"
+            raise ValueError(err)
+        # Build the default per call: a ``ReasoningLLM()`` default in the signature
+        # is created once at definition time and shared by every caller.
+        llm = ReasoningLLM() if reasoning_llm is None else reasoning_llm
+
+        # Get the reasoning result
+        template = _env.from_string(GET_REPLAN_START_STEP_INDEX)
+        prompt = template.render(
+            goal=user_goal,
+            error_message=error_message,
+            current_plan=current_plan.model_dump(exclude_none=True, by_alias=True),
+            history=history,
+        )
+        result = await MCPPlanner.get_resoning_result(prompt, llm)
+
+        # Parse into structured data, limiting the index to existing steps.
+        schema = RestartStepIndex.model_json_schema()
+        # Clamp at 0 so an empty plan does not yield maximum < minimum.
+        schema["properties"]["start_index"]["maximum"] = max(len(current_plan.plans) - 1, 0)
+        schema["properties"]["start_index"]["minimum"] = 0
+        restart_index = await MCPPlanner._parse_result(result, schema)
+        return RestartStepIndex.model_validate(restart_index)
+
+ @staticmethod
+ async def create_plan(
+ user_goal: str, is_replan: bool = False, error_message: str = "", current_plan: MCPPlan | None = None,
+ tool_list: list[MCPTool] = [],
+ max_steps: int = 6, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> MCPPlan:
"""规划下一步的执行流程,并输出"""
# 获取推理结果
- result = await self._get_reasoning_plan(tool_list, max_steps)
+ result = await MCPPlanner._get_reasoning_plan(user_goal, is_replan, error_message, current_plan, tool_list, max_steps, reasoning_llm)
# 解析为结构化数据
- return await self._parse_plan_result(result, max_steps)
+ return await MCPPlanner._parse_plan_result(result, max_steps)
+ @staticmethod
async def _get_reasoning_plan(
- self, is_replan: bool = False, error_message: str = "", current_plan: MCPPlan = MCPPlan(),
+ user_goal: str, is_replan: bool = False, error_message: str = "", current_plan: MCPPlan | None = None,
tool_list: list[MCPTool] = [],
- max_steps: int = 10) -> str:
+ max_steps: int = 10, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> str:
"""获取推理大模型的结果"""
# 格式化Prompt
if is_replan:
- template = self._env.from_string(RECREATE_PLAN)
+ template = _env.from_string(RECREATE_PLAN)
prompt = template.render(
- current_plan=current_plan,
+ current_plan=current_plan.model_dump(exclude_none=True, by_alias=True),
error_message=error_message,
- goal=self.user_goal,
+ goal=user_goal,
tools=tool_list,
max_num=max_steps,
)
else:
- template = self._env.from_string(CREATE_PLAN)
+ template = _env.from_string(CREATE_PLAN)
prompt = template.render(
- goal=self.user_goal,
+ goal=user_goal,
tools=tool_list,
max_num=max_steps,
)
- result = await self.get_resoning_result(prompt)
+ result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm)
return result
- async def _parse_plan_result(self, result: str, max_steps: int) -> MCPPlan:
+ @staticmethod
+ async def _parse_plan_result(result: str, max_steps: int) -> MCPPlan:
"""将推理结果解析为结构化数据"""
# 格式化Prompt
schema = MCPPlan.model_json_schema()
schema["properties"]["plans"]["maxItems"] = max_steps
- plan = await self._parse_result(result, schema)
+ plan = await MCPPlanner._parse_result(result, schema)
# 使用Function模型解析结果
return MCPPlan.model_validate(plan)
- async def get_tool_risk(self, tool: MCPTool, input_parm: dict[str, Any], additional_info: str = "") -> ToolRisk:
+ @staticmethod
+ async def get_tool_risk(
+ tool: MCPTool, input_parm: dict[str, Any],
+ additional_info: str = "", resoning_llm: ReasoningLLM = ReasoningLLM()) -> ToolRisk:
"""获取MCP工具的风险评估结果"""
# 获取推理结果
- result = await self._get_reasoning_risk(tool, input_parm, additional_info)
+ result = await MCPPlanner._get_reasoning_risk(tool, input_parm, additional_info, resoning_llm)
# 解析为结构化数据
- risk = await self._parse_risk_result(result)
+ risk = await MCPPlanner._parse_risk_result(result)
# 返回风险评估结果
return risk
- async def _get_reasoning_risk(self, tool: MCPTool, input_param: dict[str, Any], additional_info: str) -> str:
+ @staticmethod
+ async def _get_reasoning_risk(
+ tool: MCPTool, input_param: dict[str, Any],
+ additional_info: str, resoning_llm: ReasoningLLM) -> str:
"""获取推理大模型的风险评估结果"""
- template = self._env.from_string(RISK_EVALUATE)
+ template = _env.from_string(RISK_EVALUATE)
prompt = template.render(
- tool=tool,
+ tool_name=tool.name,
+ tool_description=tool.description,
input_param=input_param,
additional_info=additional_info,
)
- result = await self.get_resoning_result(prompt)
+ result = await MCPPlanner.get_resoning_result(prompt, resoning_llm)
return result
- async def _parse_risk_result(self, result: str) -> ToolRisk:
+ @staticmethod
+ async def _parse_risk_result(result: str) -> ToolRisk:
"""将推理结果解析为结构化数据"""
schema = ToolRisk.model_json_schema()
- risk = await self._parse_result(result, schema)
+ risk = await MCPPlanner._parse_result(result, schema)
# 使用ToolRisk模型解析结果
return ToolRisk.model_validate(risk)
+    @staticmethod
+    async def _get_reasoning_tool_execute_error_type(
+            user_goal: str, current_plan: MCPPlan,
+            tool: MCPTool, input_param: dict[str, Any],
+            error_message: str, reasoning_llm: ReasoningLLM | None = None) -> str:
+        """
+        Get the reasoning LLM's raw analysis of a tool execution error.
+
+        :param user_goal: the user's goal, rendered into the prompt
+        :param current_plan: the current plan, dumped to a dict for the prompt
+        :param tool: the tool in use; its name and description are rendered
+        :param input_param: the arguments given to the tool
+        :param error_message: the error text to analyse
+        :param reasoning_llm: LLM to use; a new ``ReasoningLLM`` is created when omitted
+        :return: the unparsed text returned by ``MCPPlanner.get_resoning_result``
+        """
+        # Build the default per call: a ``ReasoningLLM()`` default in the signature
+        # is created once at definition time and shared by every caller.
+        llm = ReasoningLLM() if reasoning_llm is None else reasoning_llm
+        template = _env.from_string(TOOL_EXECUTE_ERROR_TYPE_ANALYSIS)
+        prompt = template.render(
+            goal=user_goal,
+            current_plan=current_plan.model_dump(exclude_none=True, by_alias=True),
+            tool_name=tool.name,
+            tool_description=tool.description,
+            input_param=input_param,
+            error_message=error_message,
+        )
+        return await MCPPlanner.get_resoning_result(prompt, llm)
+
+    @staticmethod
+    async def _parse_tool_execute_error_type_result(result: str) -> ToolExcutionErrorType:
+        """
+        Turn raw reasoning text into a ``ToolExcutionErrorType``.
+
+        The text is first converted to JSON matching the model's JSON schema via
+        ``MCPPlanner._parse_result`` and then validated against the model.
+        """
+        parsed = await MCPPlanner._parse_result(
+            result,
+            ToolExcutionErrorType.model_json_schema(),
+        )
+        return ToolExcutionErrorType.model_validate(parsed)
+
+    @staticmethod
+    async def get_tool_execute_error_type(
+            user_goal: str, current_plan: MCPPlan,
+            tool: MCPTool, input_param: dict[str, Any],
+            error_message: str, reasoning_llm: ReasoningLLM | None = None) -> ToolExcutionErrorType:
+        """
+        Classify a tool execution error.
+
+        Runs the reasoning step and then parses its text into a structured result.
+
+        :param user_goal: the user's goal
+        :param current_plan: the current plan
+        :param tool: the tool in use
+        :param input_param: the arguments given to the tool
+        :param error_message: the error text to analyse
+        :param reasoning_llm: LLM to use; a new ``ReasoningLLM`` is created when omitted
+        :return: the validated ``ToolExcutionErrorType``
+        """
+        # Build the default per call: a ``ReasoningLLM()`` default in the signature
+        # is created once at definition time and shared by every caller.
+        llm = ReasoningLLM() if reasoning_llm is None else reasoning_llm
+        result = await MCPPlanner._get_reasoning_tool_execute_error_type(
+            user_goal, current_plan, tool, input_param, error_message, llm)
+        return await MCPPlanner._parse_tool_execute_error_type_result(result)
+
+ @staticmethod
async def get_missing_param(
- self, tool: MCPTool, schema: dict[str, Any],
+ tool: MCPTool,
input_param: dict[str, Any],
- error_message: str) -> list[str]:
+ error_message: str, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> list[str]:
"""获取缺失的参数"""
- slot = Slot(schema=schema)
+ slot = Slot(schema=tool.input_schema)
+ template = _env.from_string(GET_MISSING_PARAMS)
schema_with_null = slot.add_null_to_basic_types()
- template = self._env.from_string(GET_MISSING_PARAMS)
prompt = template.render(
- tool=tool,
+ tool_name=tool.name,
+ tool_description=tool.description,
input_param=input_param,
schema=schema_with_null,
error_message=error_message,
)
- result = await self.get_resoning_result(prompt)
+ result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm)
# 解析为结构化数据
- input_param_with_null = await self._parse_result(result, schema_with_null)
+ input_param_with_null = await MCPPlanner._parse_result(result, schema_with_null)
return input_param_with_null
- async def generate_answer(self, plan: MCPPlan, memory: str) -> str:
+ @staticmethod
+ async def generate_answer(
+ user_goal: str, plan: MCPPlan, memory: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> AsyncGenerator[
+ str, None]:
"""生成最终回答"""
- template = self._env.from_string(FINAL_ANSWER)
+ template = _env.from_string(FINAL_ANSWER)
prompt = template.render(
- plan=plan,
+ plan=plan.model_dump(exclude_none=True, by_alias=True),
memory=memory,
- goal=self.user_goal,
+ goal=user_goal,
)
- llm = ReasoningLLM()
- result = ""
- async for chunk in llm.call(
+ async for chunk in resoning_llm.call(
[{"role": "user", "content": prompt}],
streaming=False,
temperature=0.07,
):
- result += chunk
-
- self.input_tokens = llm.input_tokens
- self.output_tokens = llm.output_tokens
-
- return result
+ yield chunk
\ No newline at end of file
diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py
index cf05afc877e9ce08d790799d08e45c1f321d2d21..b5bc085c642eb1cf800a3a37684e594e041a535b 100644
--- a/apps/scheduler/mcp_agent/prompt.py
+++ b/apps/scheduler/mcp_agent/prompt.py
@@ -62,6 +62,62 @@ MCP_SELECT = dedent(r"""
### 请一步一步思考:
""")
+TOOL_SELECT = dedent(r"""
+ 你是一个乐于助人的智能助手。
+ 你的任务是:根据当前目标,附加信息,选择最合适的MCP工具。
+ ## 选择MCP工具时的注意事项:
+ 1. 确保充分理解当前目标,选择实现目标所需的MCP工具。
+ 2. 请在给定的MCP工具列表中选择,不要自己生成MCP工具。
+ 3. 可以选择一些辅助工具,但必须确保这些工具与当前目标相关。
+ 必须按照以下格式生成选择结果,不要输出任何其他内容:
+ ```json
+ {
+ "tool_ids": ["工具ID1", "工具ID2", ...]
+ }
+ ```
+
+ # 示例
+ ## 目标
+ 调优mysql性能
+ ## MCP工具列表
+
+ - mcp_tool_1 MySQL链接池工具;用于优化MySQL链接池
+ - mcp_tool_2 MySQL性能调优工具;用于分析MySQL性能瓶颈
+ - mcp_tool_3 MySQL查询优化工具;用于优化MySQL查询语句
+ - mcp_tool_4 MySQL索引优化工具;用于优化MySQL索引
+ - mcp_tool_5 文件存储工具;用于存储文件
+ - mcp_tool_6 mongoDB工具;用于操作MongoDB数据库
+
+ ## 附加信息
+ 1. 当前MySQL数据库的版本是8.0.26
+ 2. 当前MySQL数据库的配置文件路径是/etc/my.cnf,并含有以下配置项
+ ```json
+ {
+ "max_connections": 1000,
+ "innodb_buffer_pool_size": "1G",
+ "query_cache_size": "64M"
+ }
+ ## 输出
+ ```json
+ {
+ "tool_ids": ["mcp_tool_1", "mcp_tool_2", "mcp_tool_3", "mcp_tool_4"]
+ }
+ ```
+ # 现在开始!
+ ## 目标
+ {{goal}}
+ ## MCP工具列表
+
+ {% for tool in tools %}
+ - {{tool.id}} {{tool.name}};{{tool.description}}
+ {% endfor %}
+
+ ## 附加信息
+ {{additional_info}}
+ # 输出
+ """
+ )
+
EVALUATE_GOAL = dedent(r"""
你是一个计划评估器。
请根据用户的目标和当前的工具集合以及一些附加信息,判断基于当前的工具集合,是否能够完成用户的目标。
@@ -76,18 +132,18 @@ EVALUATE_GOAL = dedent(r"""
```
# 样例
- ## 目标
- 我需要扫描当前mysql数据库,分析性能瓶颈,并调优
+ # 目标
+ 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优
- ## 工具集合
- 你可以访问并使用一些工具,这些工具将在 XML标签中给出。
+ # 工具集合
+ 你可以访问并使用一些工具,这些工具将在 XML标签中给出。
- - mysql_analyzer分析MySQL数据库性能
- - performance_tuner调优数据库性能
- - Final结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。
+ - mysql_analyzer 分析MySQL数据库性能
+ - performance_tuner 调优数据库性能
+ - Final 结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。
- ## 附加信息
+ # 附加信息
1. 当前MySQL数据库的版本是8.0.26
2. 当前MySQL数据库的配置文件路径是/etc/my.cnf
@@ -100,17 +156,17 @@ EVALUATE_GOAL = dedent(r"""
```
# 目标
- {{ goal }}
+ {{goal}}
# 工具集合
- {% for tool in tools %}
- - {{ tool.id }}{{tool.name}};{{ tool.description }}
- {% endfor %}
+ {% for tool in tools %}
+ - {{tool.id}} {{tool.name}};{{tool.description}}
+ {% endfor %}
# 附加信息
- {{ additional_info }}
+ {{additional_info}}
""")
GENERATE_FLOW_NAME = dedent(r"""
@@ -123,15 +179,79 @@ GENERATE_FLOW_NAME = dedent(r"""
4. 流程名称应该尽量简短,小于20个字或者单词。
5. 只输出流程名称,不要输出其他内容。
# 样例
- ## 目标
- 我需要扫描当前mysql数据库,分析性能瓶颈,并调优
- ## 输出
+ # 目标
+ 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优
+ # 输出
扫描MySQL数据库并分析性能瓶颈,进行调优
# 现在开始生成流程名称:
# 目标
- {{ goal }}
+ {{goal}}
# 输出
""")
+GET_REPLAN_START_STEP_INDEX = dedent(r"""
+ 你是一个智能助手,你的任务是根据用户的目标、报错信息和当前计划和历史,获取重新规划的步骤起始索引。
+
+ # 样例
+ # 目标
+ 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优
+ # 报错信息
+ 执行端口扫描命令时,出现了错误:`- bash: curl: command not found`。
+ # 当前计划
+ ```json
+ {
+ "plans": [
+ {
+ "step_id": "step_1",
+ "content": "生成端口扫描命令",
+ "tool": "command_generator",
+ "instruction": "生成端口扫描命令:扫描当前MySQL数据库的端口"
+ },
+ {
+ "step_id": "step_2",
+ "content": "在执行Result[0]生成的命令",
+ "tool": "command_executor",
+ "instruction": "执行端口扫描命令"
+ }
+ ]
+ }
+ # 历史
+ [
+ {
+ id: "0",
+ task_id: "task_1",
+ flow_id: "flow_1",
+ flow_name: "MYSQL性能调优",
+ flow_status: "RUNNING",
+ step_id: "step_1",
+ step_name: "生成端口扫描命令",
+ step_description: "生成端口扫描命令:扫描当前MySQL数据库的端口",
+ step_status: "FAILED",
+ input_data: {
+ "command": "nmap -p 3306",
+ "target": "localhost"
+ },
+ output_data: {
+ "error": "- bash: curl: command not found"
+ }
+ }
+ ]
+ # 输出
+ {
+ "start_index": 0,
+ "reasoning": "当前计划的第一步就失败了,报错信息显示curl命令未找到,可能是因为没有安装curl工具,因此需要从第一步重新规划。"
+ }
+ # 现在开始获取重新规划的步骤起始索引:
+ # 目标
+ {{goal}}
+ # 报错信息
+ {{error_message}}
+ # 当前计划
+ {{current_plan}}
+ # 历史
+ {{history}}
+ # 输出
+ """)
+
CREATE_PLAN = dedent(r"""
你是一个计划生成器。
请分析用户的目标,并生成一个计划。你后续将根据这个计划,一步一步地完成用户的目标。
@@ -163,40 +283,38 @@ CREATE_PLAN = dedent(r"""
}
```
- - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。\
-思考过程应放置在 XML标签中。
+ - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。
+思考过程应放置在 XML标签中。
- 计划内容中,可以使用"Result[]"来引用之前计划步骤的结果。例如:"Result[3]"表示引用第三条计划执行后的结果。
- - 计划不得多于{{ max_num }}条,且每条计划内容应少于150字。
+ - 计划不得多于{{max_num}}条,且每条计划内容应少于150字。
# 工具
- 你可以访问并使用一些工具,这些工具将在 XML标签中给出。
+ 你可以访问并使用一些工具,这些工具将在 XML标签中给出。
- {% for tool in tools %}
- - {{ tool.id }}{{tool.name}};{{ tool.description }}
- {% endfor %}
+ {% for tool in tools %}
+ - {{tool.id}} {{tool.name}};{{tool.description}}
+ {% endfor %}
# 样例
- ## 目标
+ # 目标
- 在后台运行一个新的alpine:latest容器,将主机/root文件夹挂载至/data,并执行top命令。
+ 在后台运行一个新的alpine: latest容器,将主机/root文件夹挂载至/data,并执行top命令。
- ## 计划
+ # 计划
- 1. 这个目标需要使用Docker来完成,首先需要选择合适的MCP Server
+ 1. 这个目标需要使用Docker来完成, 首先需要选择合适的MCP Server
2. 目标可以拆解为以下几个部分:
- - 运行alpine:latest容器
+ - 运行alpine: latest容器
- 挂载主机目录
- 在后台运行
- 执行top命令
- 3. 需要先选择MCP Server,然后生成Docker命令,最后执行命令
-
-
- ```json
+ 3. 需要先选择MCP Server, 然后生成Docker命令, 最后执行命令
+ ```json
{
"plans": [
{
@@ -225,7 +343,7 @@ CREATE_PLAN = dedent(r"""
# 现在开始生成计划:
- ## 目标
+ # 目标
{{goal}}
@@ -263,26 +381,24 @@ RECREATE_PLAN = dedent(r"""
}
```
- - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。\
-思考过程应放置在 XML标签中。
+ - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。
+思考过程应放置在 XML标签中。
- 计划内容中,可以使用"Result[]"来引用之前计划步骤的结果。例如:"Result[3]"表示引用第三条计划执行后的结果。
- - 计划不得多于{{ max_num }}条,且每条计划内容应少于150字。
+ - 计划不得多于{{max_num}}条,且每条计划内容应少于150字。
# 样例
- ## 目标
+ # 目标
请帮我扫描一下192.168.1.1的这台机器的端口,看看有哪些端口开放。
- ## 工具
- 你可以访问并使用一些工具,这些工具将在 XML标签中给出。
+ # 工具
+ 你可以访问并使用一些工具,这些工具将在 XML标签中给出。
- - command_generator生成命令行指令
- - tool_selector选择合适的工具
- - command_executor执行命令行指令
- - Final结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。
-
-
- ## 当前计划
+ - command_generator 生成命令行指令
+ - tool_selector 选择合适的工具
+ - command_executor 执行命令行指令
+ - Final 结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。
+ # 当前计划
```json
{
"plans": [
@@ -304,25 +420,23 @@ RECREATE_PLAN = dedent(r"""
]
}
```
- ## 运行报错
- 执行端口扫描命令时,出现了错误:`-bash: curl: command not found`。
- ## 重新生成的计划
+ # 运行报错
+ 执行端口扫描命令时,出现了错误:`- bash: curl: command not found`。
+ # 重新生成的计划
- 1. 这个目标需要使用网络扫描工具来完成,首先需要选择合适的网络扫描工具
+ 1. 这个目标需要使用网络扫描工具来完成, 首先需要选择合适的网络扫描工具
2. 目标可以拆解为以下几个部分:
- 生成端口扫描命令
- 执行端口扫描命令
- 3.但是在执行端口扫描命令时,出现了错误:`-bash: curl: command not found`。
+ 3.但是在执行端口扫描命令时,出现了错误:`- bash: curl: command not found`。
4.我将计划调整为:
- 需要先生成一个命令,查看当前机器支持哪些网络扫描工具
- 执行这个命令,查看当前机器支持哪些网络扫描工具
- 然后从中选择一个网络扫描工具
- 基于选择的网络扫描工具,生成端口扫描命令
- 执行端口扫描命令
-
-
- ```json
+ ```json
{
"plans": [
{
@@ -367,19 +481,19 @@ RECREATE_PLAN = dedent(r"""
# 工具
- 你可以访问并使用一些工具,这些工具将在 XML标签中给出。
+ 你可以访问并使用一些工具,这些工具将在 XML标签中给出。
- {% for tool in tools %}
- - {{ tool.id }}{{tool.name}};{{ tool.description }}
- {% endfor %}
+ {% for tool in tools %}
+ - {{tool.id}} {{tool.name}};{{tool.description}}
+ {% endfor %}
# 当前计划
- {{ current_plan }}
+ {{current_plan}}
# 运行报错
- {{ error_message }}
+ {{error_message}}
# 重新生成的计划
""")
@@ -393,18 +507,18 @@ RISK_EVALUATE = dedent(r"""
}
```
# 样例
- ## 工具名称
+ # 工具名称
mysql_analyzer
- ## 工具描述
+ # 工具描述
分析MySQL数据库性能
- ## 工具入参
+ # 工具入参
{
"host": "192.0.0.1",
"port": 3306,
"username": "root",
"password": "password"
}
- ## 附加信息
+ # 附加信息
1. 当前MySQL数据库的版本是8.0.26
2. 当前MySQL数据库的配置文件路径是/etc/my.cnf,并含有以下配置项
```ini
@@ -412,7 +526,7 @@ RISK_EVALUATE = dedent(r"""
innodb_buffer_pool_size=1G
innodb_log_file_size=256M
```
- ## 输出
+ # 输出
```json
{
"risk": "中",
@@ -421,35 +535,35 @@ RISK_EVALUATE = dedent(r"""
```
# 工具
- {{ tool.name }}
- {{ tool.description }}
+ {{tool_name}}
+ {{tool_description}}
# 工具入参
- {{ input_param }}
+ {{input_param}}
# 附加信息
- {{ additional_info }}
+ {{additional_info}}
# 输出
"""
)
# 根据当前计划和报错信息决定下一步执行,具体计划有需要用户补充工具入参、重计划当前步骤、重计划接下来的所有计划
-JUDGE_NEXT_STEP = dedent(r"""
+TOOL_EXECUTE_ERROR_TYPE_ANALYSIS = dedent(r"""
你是一个计划决策器。
- 你的任务是根据当前计划、当前使用的工具、工具入参和工具运行报错,决定下一步执行的操作。
+ 你的任务是根据用户目标、当前计划、当前使用的工具、工具入参和工具运行报错,决定下一步执行的操作。
请根据以下规则进行判断:
- 1. 仅通过补充工具入参来解决问题的,返回 fill_params;
- 2. 需要重计划当前步骤的,返回 replan_current_step;
- 3. 需要重计划接下来的所有计划的,返回 replan_all_steps;
+ 1. 仅通过补充工具入参来解决问题的,返回 missing_param;
+ 2. 需要重计划当前步骤的,返回 decorrect_plan;
+ 3. 推理过程必须清晰明了,能够让人理解你的判断依据,并且不超过100字。
你的输出要以json格式返回,格式如下:
```json
{
- "next_step": "fill_params/replan_current_step/replan_all_steps",
- "reason": "你的判断依据"
+ "error_type": "missing_param/decorrect_plan",
+ "reason": "你的推理过程"
}
```
- 注意:
- reason字段必须清晰明了,能够让人理解你的判断依据,并且不超过50个中文字或者100个英文单词。
# 样例
- ## 当前计划
+ # 用户目标
+ 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优
+ # 当前计划
{"plans": [
{
"content": "生成端口扫描命令",
@@ -464,41 +578,43 @@ JUDGE_NEXT_STEP = dedent(r"""
{
"content": "任务执行完成,端口扫描结果为Result[2]",
"tool": "Final",
- "instruction": ""
+ "instruction": ""
}
]}
- ## 当前使用的工具
+ # 当前使用的工具
- command_executor
- 执行命令行指令
+ command_executor
+ 执行命令行指令
- ## 工具入参
+ # 工具入参
{
"command": "nmap -sS -p--open 192.168.1.1"
}
- ## 工具运行报错
- 执行端口扫描命令时,出现了错误:`-bash: nmap: command not found`。
- ## 输出
+ # 工具运行报错
+ 执行端口扫描命令时,出现了错误:`- bash: nmap: command not found`。
+ # 输出
```json
{
- "next_step": "replan_all_steps",
- "reason": "当前工具执行报错,提示nmap命令未找到,需要增加command_generator和command_executor的步骤,生成nmap安装命令并执行,之后再生成端口扫描命令并执行。"
+ "error_type": "decorrect_plan",
+ "reason": "当前计划的第二步执行失败,报错信息显示nmap命令未找到,可能是因为没有安装nmap工具,因此需要重计划当前步骤。"
}
```
+ # 用户目标
+ {{goal}}
# 当前计划
- {{ current_plan }}
+ {{current_plan}}
# 当前使用的工具
- {{ tool.name }}
- {{ tool.description }}
+ {{tool_name}}
+ {{tool_description}}
# 工具入参
- {{ input_param }}
+ {{input_param}}
# 工具运行报错
- {{ error_message }}
+ {{error_message}}
# 输出
"""
- )
+ )
# 获取缺失的参数的json结构体
GET_MISSING_PARAMS = dedent(r"""
你是一个工具参数获取器。
@@ -570,10 +686,10 @@ GET_MISSING_PARAMS = dedent(r"""
}
```
# 工具
- < tool >
- < name > {{tool.name}} < /name >
- < description > {{tool.description}} < /description >
- < / tool >
+
+ {{tool_name}}
+ {{tool_description}}
+
# 工具入参
{{input_param}}
# 工具入参schema(部分字段允许为null)
@@ -583,6 +699,82 @@ GET_MISSING_PARAMS = dedent(r"""
# 输出
"""
)
+REPAIR_PARAMS = dedent(r"""
+ 你是一个工具参数修复器。
+ 你的任务是根据当前的工具信息、工具入参的schema、工具当前的入参、工具的报错、补充的参数和补充的参数描述,修复当前工具的入参。
+
+ # 样例
+ # 工具信息
+
+ mysql_analyzer
+ 分析MySQL数据库性能
+
+ # 工具入参的schema
+ {
+ "type": "object",
+ "properties": {
+ "host": {
+ "type": "string",
+ "description": "MySQL数据库的主机地址"
+ },
+ "port": {
+ "type": "integer",
+ "description": "MySQL数据库的端口号"
+ },
+ "username": {
+ "type": "string",
+ "description": "MySQL数据库的用户名"
+ },
+ "password": {
+ "type": "string",
+ "description": "MySQL数据库的密码"
+ }
+ },
+ "required": ["host", "port", "username", "password"]
+ }
+ # 工具当前的入参
+ {
+ "host": "192.0.0.1",
+ "port": 3306,
+ "username": "root",
+ "password": "password"
+ }
+ # 工具的报错
+ 连接MySQL数据库时,出现了错误:`password is not correct`。
+ # 补充的参数
+ {
+ "username": "admin",
+ "password": "admin123"
+ }
+ # 补充的参数描述
+ 用户希望使用admin用户和admin123密码来连接MySQL数据库。
+ # 输出
+ ```json
+ {
+ "host": "192.0.0.1",
+ "port": 3306,
+ "username": "admin",
+ "password": "admin123"
+ }
+ ```
+ # 工具
+
+ {{tool_name}}
+ {{tool_description}}
+
+ # 工具入参schema
+ {{input_schema}}
+ # 工具入参
+ {{input_param}}
+ # 运行报错
+ {{error_message}}
+ # 补充的参数
+ {{params}}
+ # 补充的参数描述
+ {{params_description}}
+ # 输出
+ """
+ )
FINAL_ANSWER = dedent(r"""
综合理解计划执行结果和背景信息,向用户报告目标的完成情况。
@@ -603,12 +795,11 @@ FINAL_ANSWER = dedent(r"""
# 现在,请根据以上信息,向用户报告目标的完成情况:
""")
-
MEMORY_TEMPLATE = dedent(r"""
- { % for ctx in context_list % }
+ {% for ctx in context_list %}
- 第{{loop.index}}步:{{ctx.step_description}}
- 调用工具 `{{ctx.step_id}}`,并提供参数 `{{ctx.input_data}}`
- 执行状态:{{ctx.status}}
- 得到数据:`{{ctx.output_data}}`
- { % endfor % }
+ {% endfor %}
""")
diff --git a/apps/scheduler/mcp_agent/select.py b/apps/scheduler/mcp_agent/select.py
index 7198940bc52b31e9dc5d760084d5e80f3af435dd..f917a38ae78b0603c9b09c2d3526af637343a0ce 100644
--- a/apps/scheduler/mcp_agent/select.py
+++ b/apps/scheduler/mcp_agent/select.py
@@ -2,195 +2,100 @@
"""选择MCP Server及其工具"""
import logging
-import uuid
-from collections.abc import AsyncGenerator
+import random
+from typing import AsyncGenerator
from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment
+from apps.common.config import Config
from apps.common.lance import LanceDB
from apps.common.mongo import MongoDB
from apps.llm.embedding import Embedding
from apps.llm.function import FunctionLLM
from apps.llm.reasoning import ReasoningLLM
-from apps.scheduler.mcp.prompt import (
- MCP_SELECT,
-)
-from apps.schemas.mcp import (
- MCPCollection,
- MCPSelectResult,
- MCPTool,
-)
+from apps.llm.token import TokenCalculator
+from apps.scheduler.mcp_agent.prompt import TOOL_SELECT
+from apps.schemas.mcp import BaseModel, MCPCollection, MCPSelectResult, MCPTool, MCPToolIdsSelectResult
logger = logging.getLogger(__name__)
+# Module-level sandboxed Jinja2 environment; MCPSelector renders the TOOL_SELECT
+# prompt through `_env.from_string(...)`.
+# NOTE(review): autoescape=True HTML-escapes rendered values (quotes, <, >),
+# which may be unwanted in LLM prompts -- confirm this is intended.
+_env = SandboxedEnvironment(
+ loader=BaseLoader,
+ autoescape=True,
+ trim_blocks=True,
+ lstrip_blocks=True,
+)
-class MCPSelector:
- """MCP选择器"""
-
- def __init__(self, resoning_llm: ReasoningLLM = None) -> None:
- """初始化助手类"""
- self.resoning_llm = resoning_llm or ReasoningLLM()
- self.input_tokens = 0
- self.output_tokens = 0
-
- @staticmethod
- def _assemble_sql(mcp_list: list[str]) -> str:
- """组装SQL"""
- sql = "("
- for mcp_id in mcp_list:
- sql += f"'{mcp_id}', "
- return sql.rstrip(", ") + ")"
-
-
- async def _get_top_mcp_by_embedding(
- self,
- query: str,
- mcp_list: list[str],
- ) -> list[dict[str, str]]:
- """通过向量检索获取Top5 MCP Server"""
- logger.info("[MCPHelper] 查询MCP Server向量: %s, %s", query, mcp_list)
- mcp_table = await LanceDB().get_table("mcp")
- query_embedding = await Embedding.get_embedding([query])
- mcp_vecs = await (await mcp_table.search(
- query=query_embedding,
- vector_column_name="embedding",
- )).where(f"id IN {MCPSelector._assemble_sql(mcp_list)}").limit(5).to_list()
-
- # 拿到名称和description
- logger.info("[MCPHelper] 查询MCP Server名称和描述: %s", mcp_vecs)
- mcp_collection = MongoDB().get_collection("mcp")
- llm_mcp_list: list[dict[str, str]] = []
- for mcp_vec in mcp_vecs:
- mcp_id = mcp_vec["id"]
- mcp_data = await mcp_collection.find_one({"_id": mcp_id})
- if not mcp_data:
- logger.warning("[MCPHelper] 查询MCP Server名称和描述失败: %s", mcp_id)
- continue
- mcp_data = MCPCollection.model_validate(mcp_data)
- llm_mcp_list.extend([{
- "id": mcp_id,
- "name": mcp_data.name,
- "description": mcp_data.description,
- }])
- return llm_mcp_list
-
-
- async def _get_mcp_by_llm(
- self,
- query: str,
- mcp_list: list[dict[str, str]],
- mcp_ids: list[str],
- ) -> MCPSelectResult:
- """通过LLM选择最合适的MCP Server"""
- # 初始化jinja2环境
- env = SandboxedEnvironment(
- loader=BaseLoader,
- autoescape=True,
- trim_blocks=True,
- lstrip_blocks=True,
- )
- template = env.from_string(MCP_SELECT)
- # 渲染模板
- mcp_prompt = template.render(
- mcp_list=mcp_list,
- goal=query,
- )
-
- # 调用大模型进行推理
- result = await self._call_reasoning(mcp_prompt)
-
- # 使用小模型提取JSON
- return await self._call_function_mcp(result, mcp_ids)
-
-
- async def _call_reasoning(self, prompt: str) -> AsyncGenerator[str, None]:
- """调用大模型进行推理"""
- logger.info("[MCPHelper] 调用推理大模型")
- message = [
- {"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": prompt},
- ]
- async for chunk in self.resoning_llm.call(message):
- yield chunk
-
-
- async def _call_function_mcp(self, reasoning_result: str, mcp_ids: list[str]) -> MCPSelectResult:
- """调用结构化输出小模型提取JSON"""
- logger.info("[MCPHelper] 调用结构化输出小模型")
- llm = FunctionLLM()
- message = [
- {"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": reasoning_result},
- ]
- schema = MCPSelectResult.model_json_schema()
- # schema中加入选项
- schema["properties"]["mcp_id"]["enum"] = mcp_ids
- result = await llm.call(messages=message, schema=schema)
- try:
- result = MCPSelectResult.model_validate(result)
- except Exception:
- logger.exception("[MCPHelper] 解析MCP Select Result失败")
- raise
- return result
-
-
- async def select_top_mcp(
- self,
- query: str,
- mcp_list: list[str],
- ) -> MCPSelectResult:
- """
- 选择最合适的MCP Server
-
- 先通过Embedding选择Top5,然后通过LLM选择Top 1
- """
- # 通过向量检索获取Top5
- llm_mcp_list = await self._get_top_mcp_by_embedding(query, mcp_list)
+# Reserved tool IDs -- presumably for the built-in "final" and "summarize"
+# pseudo-tools; confirm against their usages.
+# NOTE(review): the value was previously misspelled "FIANL"; check that nothing
+# stored or hard-coded elsewhere relies on the old spelling.
+FINAL_TOOL_ID = "FINAL"
+SUMMARIZE_TOOL_ID = "SUMMARIZE"
- # 通过LLM选择最合适的
- return await self._get_mcp_by_llm(query, llm_mcp_list, mcp_list)
+class MCPSelector:
+ """MCP选择器"""
@staticmethod
- async def select_top_tool(query: str, mcp_list: list[str], top_n: int = 10) -> list[MCPTool]:
+ async def select_top_tool(
+ goal: str, tool_list: list[MCPTool],
+ additional_info: str | None = None, top_n: int | None = None) -> list[MCPTool]:
"""选择最合适的工具"""
- tool_vector = await LanceDB().get_table("mcp_tool")
- query_embedding = await Embedding.get_embedding([query])
- tool_vecs = await (await tool_vector.search(
- query=query_embedding,
- vector_column_name="embedding",
- )).where(f"mcp_id IN {MCPSelector._assemble_sql(mcp_list)}").limit(top_n).to_list()
-
- # 拿到工具
- tool_collection = MongoDB().get_collection("mcp")
- llm_tool_list = []
-
- for tool_vec in tool_vecs:
- # 到MongoDB里找对应的工具
- logger.info("[MCPHelper] 查询MCP Tool名称和描述: %s", tool_vec["mcp_id"])
- tool_data = await tool_collection.aggregate([
- {"$match": {"_id": tool_vec["mcp_id"]}},
- {"$unwind": "$tools"},
- {"$match": {"tools.id": tool_vec["id"]}},
- {"$project": {"_id": 0, "tools": 1}},
- {"$replaceRoot": {"newRoot": "$tools"}},
- ])
- async for tool in tool_data:
- tool_obj = MCPTool.model_validate(tool)
- llm_tool_list.append(tool_obj)
-
- llm_tool_list.append(
- MCPTool(
- id="00000000-0000-0000-0000-000000000000",
- name="Final",
- description="It is the final step, indicating the end of the plan execution."),
- )
- llm_tool_list.append(
- MCPTool(
- id="00000000-0000-0000-0000-000000000001",
- name="Chat",
- description="It is a chat tool to communicate with the user."),
- )
-
- return llm_tool_list
+        # Selection runs in two phases: pack the candidate tools into batches whose
+        # rendered prompt fits the function-call model's token budget, ask the model
+        # which tool ids of each batch serve the goal, then return the chosen tools
+        # followed by the synthetic "Final" tool.
+        max_tokens = Config().get_config().function_call.max_tokens
+        template = _env.from_string(TOOL_SELECT)
+
+        def _render(tools: list[MCPTool]) -> str:
+            """Render the selection prompt for *tools*."""
+            return template.render(goal=goal, tools=tools, additional_info=additional_info)
+
+        def _prompt_tokens(tools: list[MCPTool]) -> int:
+            """Return the token length of the selection prompt rendered for *tools*."""
+            return TokenCalculator.calculate_token_length(
+                messages=[{"role": "user", "content": _render(tools)}],
+                pure_text=True,
+            )
+
+        # Even without any tool the prompt must fit, otherwise no batch can be sent.
+        if _prompt_tokens([]) > max_tokens:
+            logger.warning("[MCPSelector] 工具选择模板长度超过最大令牌数,无法进行选择")
+            return []
+
+        # Shuffle a copy so the caller's list is not reordered as a side effect.
+        candidates = list(tool_list)
+        random.shuffle(candidates)
+
+        llm = FunctionLLM()
+        selected_ids: set[str] = set()
+        current_index = 0
+        while current_index < len(candidates):
+            index = current_index
+            sub_tools: list[MCPTool] = []
+            while index < len(candidates):
+                tool = candidates[index]
+                if _prompt_tokens([tool]) > max_tokens:
+                    # This tool cannot fit even on its own: skip it. The index has to
+                    # advance here, otherwise the loop would revisit it forever.
+                    index += 1
+                    continue
+                if _prompt_tokens([*sub_tools, tool]) > max_tokens:
+                    # The batch is full; this tool opens the next batch.
+                    break
+                sub_tools.append(tool)
+                index += 1
+            current_index = index
+            if not sub_tools:
+                continue
+
+            message = [
+                {"role": "system", "content": "You are a helpful assistant."},
+                # Send the same prompt that was measured above (goal and extra info included).
+                {"role": "user", "content": _render(sub_tools)},
+            ]
+            schema = MCPToolIdsSelectResult.model_json_schema()
+            # tool_ids is an array, so the allowed ids constrain its items, not the array itself.
+            schema["properties"]["tool_ids"].setdefault("items", {})["enum"] = [tool.id for tool in sub_tools]
+            result = await llm.call(messages=message, schema=schema)
+            try:
+                selected_ids.update(MCPToolIdsSelectResult.model_validate(result).tool_ids)
+            except Exception:
+                # Best effort: an unparsable answer for one batch must not abort the others.
+                logger.exception("[MCPSelector] 解析MCP工具ID选择结果失败")
+
+        mcp_tools = [tool for tool in candidates if tool.id in selected_ids]
+        if top_n is not None:
+            mcp_tools = mcp_tools[:top_n]
+        # The terminator tool is always appended, even when top_n truncated the list.
+        mcp_tools.append(MCPTool(id=FINAL_TOOL_ID, name="Final",
+                                 description="终止", mcp_id=FINAL_TOOL_ID, input_schema={}))
+        return mcp_tools
\ No newline at end of file
diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py
index 272a8c60e451437494ada813433ad4a58578030a..3371cbdcfb4bbafaaf3bde0582e88a5cf03b34d4 100644
--- a/apps/scheduler/scheduler/context.py
+++ b/apps/scheduler/scheduler/context.py
@@ -190,7 +190,7 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None:
flow_id=task.state.flow_id,
flow_name=task.state.flow_name,
flow_status=task.state.flow_status,
- history_ids=[context["_id"] for context in task.context],
+ history_ids=[context.id for context in task.context],
),
)
diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py
index d35bbb21828cf849746e4c3d1bf935fe544e3da5..0c58a4107b81775f8eeefc98019a57683847a256 100644
--- a/apps/scheduler/scheduler/scheduler.py
+++ b/apps/scheduler/scheduler/scheduler.py
@@ -7,6 +7,8 @@ from datetime import UTC, datetime
from apps.common.config import config
from apps.common.queue import MessageQueue
+from apps.llm.patterns.rewrite import QuestionRewrite
+from apps.llm.reasoning import ReasoningLLM
from apps.models.app import App
from apps.scheduler.executor.agent import MCPAgentExecutor
from apps.scheduler.executor.flow import FlowExecutor
@@ -17,6 +19,7 @@ from apps.scheduler.scheduler.message import (
push_init_message,
push_rag_message,
)
+from apps.schemas.config import LLMConfig
from apps.schemas.enum_var import AppType, EventType, FlowStatus
from apps.schemas.rag_data import RAGQueryReq
from apps.schemas.request_data import RequestData
@@ -68,44 +71,49 @@ class Scheduler:
logger.error(f"[Scheduler] 活动监控过程中发生错误: {e}")
- async def run(self) -> None: # noqa: PLR0911
- """运行调度器"""
+    async def get_llm_use_in_chat_with_rag(self) -> LLM | None:
+        """
+        Resolve the LLM to use for RAG chat in the current conversation.
+
+        The LLM id is looked up for this task's user and conversation. The id "empty"
+        yields an LLM built from the global llm configuration; any other id is loaded
+        through LLMManager.
+
+        :return: the resolved LLM, or None when the id or the LLM cannot be obtained
+            or an exception occurs (failures are logged)
+        """
try:
# 获取当前会话使用的大模型
- llm_id = await LLMManager.get_user_default_llm(
+ llm_id = await LLMManager.get_llm_id_by_conversation_id(
self.task.ids.user_sub, self.task.ids.conversation_id,
)
if not llm_id:
logger.error("[Scheduler] 获取大模型ID失败")
- await self.queue.close()
- return
+ return None
if llm_id == "empty":
- llm = LLM(
+ return LLM(
_id="empty",
user_sub=self.task.ids.user_sub,
- openai_base_url=config.llm.endpoint,
- openai_api_key=config.llm.key,
- model_name=config.llm.model,
- max_tokens=config.llm.max_tokens,
+ openai_base_url=Config().get_config().llm.endpoint,
+ openai_api_key=Config().get_config().llm.key,
+ model_name=Config().get_config().llm.model,
+ max_tokens=Config().get_config().llm.max_tokens,
)
- else:
- llm = await LLMManager.get_llm(self.task.ids.user_sub, llm_id)
- if not llm:
- logger.error("[Scheduler] 获取大模型失败")
- await self.queue.close()
- return
+ llm = await LLMManager.get_llm_by_id(self.task.ids.user_sub, llm_id)
+ if not llm:
+ logger.error("[Scheduler] 获取大模型失败")
+ return None
+ return llm
except Exception:
logger.exception("[Scheduler] 获取大模型失败")
- await self.queue.close()
- return
+ return None
+
+
+    async def get_kb_ids_use_in_chat_with_rag(self) -> list[str]:
+        """
+        Return the knowledge-base IDs for this task's user and conversation.
+
+        The lookup is delegated to KnowledgeBaseManager.get_kb_ids_by_conversation_id.
+        On any exception the error is logged and an empty list is returned.
+        """
+        # NOTE(review): the failure path also closes self.queue before returning [],
+        # while run() keeps using the queue after this call — confirm this is intended.
try:
- # 获取当前会话使用的知识库
- kb_ids = await KnowledgeBaseManager.get_kb_ids_by_conversation_id(
- self.task.ids.user_sub, self.task.ids.conversation_id)
+ return await KnowledgeBaseManager.get_kb_ids_by_conversation_id(
+ self.task.ids.user_sub, self.task.ids.conversation_id,
+ )
except Exception:
logger.exception("[Scheduler] 获取知识库ID失败")
await self.queue.close()
- return
+ return []
+
+
+ async def run(self) -> None:
+ """运行调度器"""
try:
# 获取当前问答可供关联的文档
docs, doc_ids = await get_docs(self.task.ids.user_sub, self.post_body)
@@ -121,6 +129,8 @@ class Scheduler:
kill_event = asyncio.Event()
monitor = asyncio.create_task(self._monitor_activity(kill_event, self.task.ids.user_sub))
if not self.post_body.app or self.post_body.app.app_id == "":
+ llm = await self.get_llm_use_in_chat_with_rag()
+ kb_ids = await self.get_kb_ids_use_in_chat_with_rag()
self.task = await push_init_message(self.task, self.queue, 3, is_flow=False)
rag_data = RAGQueryReq(
kbIds=kb_ids,
@@ -197,6 +207,27 @@ class Scheduler:
if not app_metadata:
logger.error("[Scheduler] 未找到Agent应用")
return
+ llm = await LLMManager.get_llm_by_id(
+ self.task.ids.user_sub, app_metadata.llm_id,
+ )
+ if not llm:
+ logger.error("[Scheduler] 获取大模型失败")
+ await self.queue.close()
+ return
+ reasion_llm = ReasoningLLM(
+ LLMConfig(
+ endpoint=llm.openai_base_url,
+ key=llm.openai_api_key,
+ model=llm.model_name,
+ max_tokens=llm.max_tokens,
+ ),
+ )
+ if background.conversation:
+ try:
+ question_obj = QuestionRewrite()
+ post_body.question = await question_obj.generate(history=background.conversation, question=post_body.question, llm=reasion_llm)
+ except Exception:
+ logger.exception("[Scheduler] 问题重写失败")
if app_metadata.app_type == AppType.FLOW.value:
logger.info("[Scheduler] 获取工作流元数据")
flow_info = await Pool().get_flow_metadata(app_info.app_id)
@@ -255,6 +286,7 @@ class Scheduler:
servers_id=servers_id,
background=background,
agent_id=app_info.app_id,
+ params=post_body.app.params,
)
# 开始运行
logger.info("[Scheduler] 运行Executor")
diff --git a/apps/schemas/appcenter.py b/apps/schemas/appcenter.py
index b37716f3fb39556d94aff4053877309846dc1c92..f6ccd6099dde873fcda90d78f29551d3b6df9143 100644
--- a/apps/schemas/appcenter.py
+++ b/apps/schemas/appcenter.py
@@ -52,6 +52,14 @@ class AppFlowInfo(BaseModel):
debug: bool = Field(..., description="是否经过调试")
+class AppMcpServiceInfo(BaseModel):
+    """应用关联的MCP服务信息"""
+
+    # Summary of one MCP service attached to an application; AppData.mcp_service is a
+    # list of these. All three fields are required.
+    # The docstring and Field descriptions stay verbatim: pydantic emits them in the
+    # generated schema, so they are runtime text rather than plain comments.
+    id: str = Field(..., description="MCP服务ID")  # MCP service id
+    name: str = Field(..., description="MCP服务名称")  # MCP service name
+    description: str = Field(..., description="MCP服务简介")  # short description of the service
+
+
class AppData(BaseModel):
"""应用信息数据结构"""
@@ -66,4 +74,4 @@ class AppData(BaseModel):
permission: AppPermissionData = Field(
default_factory=lambda: AppPermissionData(authorizedUsers=None), description="权限配置")
workflows: list[AppFlowInfo] = Field(default=[], description="工作流信息列表")
- mcp_service: list[str] = Field(default=[], alias="mcpService", description="MCP服务id列表")
+ mcp_service: list[AppMcpServiceInfo] = Field(default=[], alias="mcpService", description="MCP服务id列表")
diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py
index d46f5289061b002f64d2df37d4059112c4b3d3fc..f80a0c4a7c49df005d8a9dda24beb7a25ca834b1 100644
--- a/apps/schemas/enum_var.py
+++ b/apps/schemas/enum_var.py
@@ -16,6 +16,7 @@ class StepStatus(str, Enum):
"""步骤状态"""
UNKNOWN = "unknown"
+ INIT = "init"
WAITING = "waiting"
RUNNING = "running"
SUCCESS = "success"
@@ -28,6 +29,7 @@ class FlowStatus(str, Enum):
"""Flow状态"""
UNKNOWN = "unknown"
+ INIT = "init"
WAITING = "waiting"
RUNNING = "running"
SUCCESS = "success"
@@ -55,12 +57,15 @@ class EventType(str, Enum):
STEP_WAITING_FOR_START = "step.waiting_for_start"
STEP_WAITING_FOR_PARAM = "step.waiting_for_param"
FLOW_START = "flow.start"
+ STEP_INIT = "step.init"
STEP_INPUT = "step.input"
STEP_OUTPUT = "step.output"
+ STEP_CANCEL = "step.cancel"
+ STEP_ERROR = "step.error"
FLOW_STOP = "flow.stop"
FLOW_FAILED = "flow.failed"
FLOW_SUCCESS = "flow.success"
- FLOW_CANCELLED = "flow.cancelled"
+ FLOW_CANCEL = "flow.cancel"
DONE = "done"
diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py
index c31542bc614d99148a1fff8cf4d4e0c1e0cc6735..226d6849b7b15a051d25c34825e9d77b6535ddba 100644
--- a/apps/schemas/mcp.py
+++ b/apps/schemas/mcp.py
@@ -64,6 +64,13 @@ class GoalEvaluationResult(BaseModel):
reason: str = Field(description="评估原因")
+class RestartStepIndex(BaseModel):
+    """MCP重新规划的步骤索引"""
+
+    # Structured result naming the step from which an MCP plan is re-planned.
+    # The docstring and Field descriptions stay verbatim: pydantic emits them in the
+    # generated schema, so they are runtime text rather than plain comments.
+    start_index: int = Field(description="重新规划的起始步骤索引")  # index of the first step to re-plan
+    reasoning: str = Field(description="重新规划的原因")  # why re-planning starts there
+
+
class Risk(str, Enum):
"""MCP工具风险类型"""
@@ -79,12 +86,38 @@ class ToolRisk(BaseModel):
reason: str = Field(description="风险原因", default="")
+class ErrorType(str, Enum):
+    """MCP工具错误类型"""
+
+    # Error categories for a failed MCP tool call; members are plain strings (str mixin).
+    # NOTE(review): "DECORRECT_PLAN" / "decorrect_plan" presumably means "incorrect plan".
+    # Renaming the member or its value would change what other code and stored data
+    # see — confirm usages before correcting the spelling.
+    MISSING_PARAM = "missing_param"
+    DECORRECT_PLAN = "decorrect_plan"
+
+
+class ToolExcutionErrorType(BaseModel):
+    """MCP工具执行错误"""
+
+    # Classification of an MCP tool execution error: a category plus a free-text reason.
+    # NOTE(review): the class name spells "Excution" (sic); renaming it would break
+    # importers, so it is left unchanged here.
+    # The docstring and Field descriptions stay verbatim: pydantic emits them in the
+    # generated schema, so they are runtime text rather than plain comments.
+    type: ErrorType = Field(description="错误类型", default=ErrorType.MISSING_PARAM)  # defaults to MISSING_PARAM
+    reason: str = Field(description="错误原因", default="")  # defaults to an empty string
+
+
class MCPSelectResult(BaseModel):
"""MCP选择结果"""
mcp_id: str = Field(description="MCP Server的ID")
+class MCPToolSelectResult(BaseModel):
+    """MCP工具选择结果"""
+
+    # Structured result carrying the name of a single selected MCP tool.
+    # The docstring and Field description stay verbatim: pydantic emits them in the
+    # generated schema, so they are runtime text rather than plain comments.
+    name: str = Field(description="工具名称")  # tool name (required)
+
+
+class MCPToolIdsSelectResult(BaseModel):
+    """MCP工具ID选择结果"""
+
+    # Structured result carrying the ids of the selected MCP tools.
+    # MCPSelector.select_top_tool passes this model's JSON schema to the LLM call and
+    # validates the answer against it, so the docstring and Field description (which
+    # pydantic emits in that schema) stay verbatim.
+    tool_ids: list[str] = Field(description="工具ID列表")  # selected tool ids (required)
+
+
class MCPPlanItem(BaseModel):
"""MCP 计划"""
diff --git a/apps/schemas/message.py b/apps/schemas/message.py
index 0a441dd273676d93d7620c7c7c27f779b23e5945..0b9ef51862ecc41362885b232d50479f91047e8a 100644
--- a/apps/schemas/message.py
+++ b/apps/schemas/message.py
@@ -7,10 +7,18 @@ from typing import Any
from pydantic import BaseModel, Field
-from .enum_var import EventType, StepStatus
+from apps.schemas.enum_var import EventType, FlowStatus, StepStatus
+
from .record import RecordMetadata
+class param(BaseModel):
+    """流执行过程中的参数补充"""
+
+    # Supplementary parameters supplied while a flow is executing; used as the type of
+    # RequestDataApp.params.
+    # NOTE(review): the lowercase class name breaks the PascalCase convention, but
+    # request_data.py imports it under this name, so renaming needs a coordinated change.
+    # The docstring and Field descriptions stay verbatim: pydantic emits them in the
+    # generated schema, so they are runtime text rather than plain comments.
+    content: dict[str, Any] | bool = Field(default={}, description="流执行过程中的参数补充内容")  # dict or bool; defaults to {}
+    description: str = Field(default="", description="流执行过程中的参数补充描述")  # defaults to an empty string
+
+
class HeartbeatData(BaseModel):
"""心跳事件的数据结构"""
@@ -24,6 +32,8 @@ class MessageFlow(BaseModel):
app_id: str = Field(description="插件ID", alias="appId")
flow_id: str = Field(description="Flow ID", alias="flowId")
+ flow_name: str = Field(description="Flow名称", alias="flowName")
+ flow_status: FlowStatus = Field(description="Flow状态", alias="flowStatus", default=FlowStatus.UNKNOWN)
step_id: str = Field(description="当前步骤ID", alias="stepId")
step_name: str = Field(description="当前步骤名称", alias="stepName")
sub_step_id: str | None = Field(description="当前子步骤ID", alias="subStepId", default=None)
@@ -78,7 +88,7 @@ class FlowStartContent(BaseModel):
"""flow.start消息的content"""
question: str = Field(description="用户问题")
- params: dict[str, Any] = Field(description="预先提供的参数")
+ params: dict[str, Any] | None = Field(description="预先提供的参数", default=None)
class MessageBase(HeartbeatData):
@@ -89,5 +99,5 @@ class MessageBase(HeartbeatData):
conversation_id: uuid.UUID = Field(min_length=36, max_length=36, alias="conversationId")
task_id: str = Field(min_length=36, max_length=36, alias="taskId")
flow: MessageFlow | None = None
- content: dict[str, Any] = {}
+ content: Any | None = Field(default=None, description="消息内容")
metadata: MessageMetadata
diff --git a/apps/schemas/record.py b/apps/schemas/record.py
index 578f567ce78657fe1bb86af44629d405d8bb8a5f..119a80823f5daab98c8617a5fcd305b4a38dc3b5 100644
--- a/apps/schemas/record.py
+++ b/apps/schemas/record.py
@@ -41,6 +41,7 @@ class RecordFlow(BaseModel):
id: str
record_id: str = Field(alias="recordId")
flow_id: str = Field(alias="flowId")
+ flow_name: str = Field(alias="flowName", default="")
flow_status: StepStatus = Field(alias="flowStatus", default=StepStatus.SUCCESS)
step_num: int = Field(alias="stepNum")
steps: list[RecordFlowStep]
@@ -81,7 +82,7 @@ class RecordData(BaseModel):
id: uuid.UUID
conversation_id: uuid.UUID = Field(alias="conversationId")
- task_id: uuid.UUID = Field(alias="taskId")
+ task_id: uuid.UUID | None = Field(alias="taskId", default=None)
document: list[RecordDocument] = []
flow: RecordFlow | None = None
content: RecordContent
@@ -118,7 +119,8 @@ class Record(RecordData):
user_sub: str
key: dict[str, Any] = {}
- content: str
+ task_id: uuid.UUID | None = Field(default=None, description="任务ID")
+ content: str = Field(default="", description="Record内容,已加密")
comment: RecordComment = Field(default=RecordComment())
flow: FlowHistory = Field(
default=FlowHistory(), description="Flow执行历史信息")
diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py
index b369ddf1d3258534c941d47a4d9634da0f2ba7a8..f271f6ee80032007477be67f38ecf3bc5629931f 100644
--- a/apps/schemas/request_data.py
+++ b/apps/schemas/request_data.py
@@ -10,6 +10,7 @@ from .appcenter import AppData
from .enum_var import CommentType
from .flow_topology import FlowItem
from .mcp import MCPType
+from .message import param
class RequestDataApp(BaseModel):
@@ -17,7 +18,7 @@ class RequestDataApp(BaseModel):
app_id: uuid.UUID = Field(description="应用ID", alias="appId")
flow_id: str | None = Field(default=None, description="Flow ID", alias="flowId")
- params: dict[str, Any] | None = Field(default=None, description="插件参数")
+ params: param | None = Field(default=None, description="流执行过程中的参数补充", alias="params")
class RequestData(BaseModel):
@@ -31,7 +32,7 @@ class RequestData(BaseModel):
files: list[str] = Field(default=[], description="文件列表")
app: RequestDataApp | None = Field(default=None, description="应用")
debug: bool = Field(default=False, description="是否调试")
- new_task: bool = Field(default=True, description="是否新建任务")
+ task_id: str | None = Field(default=None, alias="taskId", description="任务ID")
class QuestionBlacklistRequest(BaseModel):
@@ -163,3 +164,9 @@ class UpdateUserKnowledgebaseReq(BaseModel):
"""更新知识库请求体"""
kb_ids: list[uuid.UUID] = Field(description="知识库ID列表", alias="kbIds", default=[])
+
+
+class UserUpdateRequest(BaseModel):
+    """更新用户信息请求体"""
+
+    # Request body for updating user settings.
+    # The docstring and Field description stay verbatim: pydantic emits them in the
+    # generated schema, so they are runtime text rather than plain comments.
+    # auto_execute: whether to execute automatically; wire name "autoExecute", default False.
+    auto_execute: bool = Field(default=False, description="是否自动执行", alias="autoExecute")
diff --git a/apps/schemas/response_data.py b/apps/schemas/response_data.py
index f41fdb02742e1d3f6189a935161a930919c678ba..3373fda330f456acb3ef3282a400759aebe78282 100644
--- a/apps/schemas/response_data.py
+++ b/apps/schemas/response_data.py
@@ -43,6 +43,7 @@ class AuthUserMsg(BaseModel):
user_sub: str
revision: bool
is_admin: bool
+ auto_execute: bool
class AuthUserRsp(ResponseData):
diff --git a/apps/schemas/task.py b/apps/schemas/task.py
index e0a2b39ef1342665cf8beda40ed9682113c86f9b..2e0a65d2390a3ed045b7e4aa122e7816830a2a66 100644
--- a/apps/schemas/task.py
+++ b/apps/schemas/task.py
@@ -31,6 +31,7 @@ class FlowStepHistory(BaseModel):
step_status: StepStatus = Field(description="当前步骤状态")
input_data: dict[str, Any] = Field(description="当前Step执行的输入", default={})
output_data: dict[str, Any] = Field(description="当前Step执行后的结果", default={})
+ ex_data: dict[str, Any] | None = Field(description="额外数据", default=None)
created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3))
@@ -41,17 +42,17 @@ class ExecutorState(BaseModel):
flow_id: str = Field(description="Flow ID", default="")
flow_name: str = Field(description="Flow名称", default="")
description: str = Field(description="Flow描述", default="")
- flow_status: FlowStatus = Field(description="Flow状态", default=FlowStatus.UNKNOWN)
+ flow_status: FlowStatus = Field(description="Flow状态", default=FlowStatus.INIT)
# 任务级数据
step_id: str = Field(description="当前步骤ID", default="")
+ step_index: int = Field(description="当前步骤索引", default=0)
step_name: str = Field(description="当前步骤名称", default="")
step_status: StepStatus = Field(description="当前步骤状态", default=StepStatus.UNKNOWN)
step_description: str = Field(description="当前步骤描述", default="")
- retry_times: int = Field(description="当前步骤重试次数", default=0)
- error_message: str = Field(description="当前步骤错误信息", default="")
app_id: str = Field(description="应用ID", default="")
- slot: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={})
- error_info: dict[str, Any] = Field(description="错误信息", default={})
+ current_input: dict[str, Any] = Field(description="当前输入数据", default={})
+ error_message: str = Field(description="错误信息", default="")
+ retry_times: int = Field(description="当前步骤重试次数", default=0)
class TaskIds(BaseModel):
@@ -93,7 +94,7 @@ class Task(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id")
ids: TaskIds = Field(description="任务涉及的各种ID")
- context: list[dict[str, Any]] = Field(description="Flow的步骤执行信息", default=[])
+ context: list[FlowStepHistory] = Field(description="Flow的步骤执行信息", default=[])
state: ExecutorState = Field(description="Flow的状态", default=ExecutorState())
tokens: TaskTokens = Field(description="Token信息")
runtime: TaskRuntime = Field(description="任务运行时数据")
diff --git a/apps/services/conversation.py b/apps/services/conversation.py
index ade7fb6bf2cc227e148947064153e68f7c3b289b..47fa28d2a05f3983b41a1317b7528789e17856aa 100644
--- a/apps/services/conversation.py
+++ b/apps/services/conversation.py
@@ -153,7 +153,7 @@ class ConversationManager:
await session.delete(conv)
await session.commit()
- await TaskManager.delete_tasks_by_conversation_id(conversation_id)
+ await TaskManager.delete_tasks_and_flow_context_by_conversation_id(conversation_id)
@staticmethod
diff --git a/apps/services/mcp_service.py b/apps/services/mcp_service.py
index a08d54551ee9806d92cb3e451c207f25c3a97d9e..dc6619abac40594f093110ea4342cd43d05de6fc 100644
--- a/apps/services/mcp_service.py
+++ b/apps/services/mcp_service.py
@@ -95,6 +95,8 @@ class MCPServiceManager:
user_sub: str,
keyword: str | None,
page: int,
+ *,
+ is_active: bool | None = None,
) -> list[MCPServiceCardItem]:
"""
获取所有MCP服务列表
@@ -105,7 +107,7 @@ class MCPServiceManager:
:param page: int: 页码
:return: MCP服务列表
"""
- mcpservice_pools = await MCPServiceManager._search_mcpservice(search_type, keyword, page)
+ mcpservice_pools = await MCPServiceManager._search_mcpservice(search_type, keyword, page, is_active)
return [
MCPServiceCardItem(
mcpserviceId=item.id,
@@ -162,6 +164,8 @@ class MCPServiceManager:
search_type: SearchType,
keyword: str | None,
page: int,
+ *,
+ is_active: bool | None = None,
) -> list[MCPInfo]:
"""
基于输入条件搜索MCP服务
diff --git a/apps/services/record.py b/apps/services/record.py
index 3b5151087f71a9e4f6b529a6e245cc66a68851df..d62603b5ae4a6444dba3e26c913a5bbc9c1c910e 100644
--- a/apps/services/record.py
+++ b/apps/services/record.py
@@ -11,6 +11,7 @@ from sqlalchemy import and_, select
from apps.common.postgres import postgres
from apps.models.conversation import Conversation
from apps.models.record import Record as PgRecord
+from apps.schemas.enum_var import FlowStatus
from apps.schemas.record import Record
logger = logging.getLogger(__name__)
@@ -116,3 +117,17 @@ class RecordManager:
return []
else:
return records
+
+
+    @staticmethod
+    async def update_record_flow_status_to_cancelled_by_task_ids(task_ids: list[str]) -> None:
+        """
+        Mark the flows of the records that belong to the given tasks as cancelled.
+
+        Only records whose flow has not already ended in ERROR or SUCCESS are changed.
+        Database errors are logged and swallowed (best effort).
+
+        :param task_ids: IDs of the tasks whose records should be updated
+        """
+        if not task_ids:
+            # Nothing can match an empty $in list; skip the round trip.
+            return
+        # NOTE(review): MongoDB is not among the imports of this module visible in the
+        # diff — confirm it is imported.
+        record_group_collection = MongoDB().get_collection("record_group")
+        finished = [FlowStatus.ERROR.value, FlowStatus.SUCCESS.value]
+        try:
+            await record_group_collection.update_many(
+                {"records.task_id": {"$in": task_ids}, "records.flow.flow_status": {"$nin": finished}},
+                {"$set": {"records.$[elem].flow.flow_status": FlowStatus.CANCELLED.value}},
+                # The array filter decides which elements are rewritten, so it must match
+                # on the record's task_id (not the flow_id) and repeat the status guard.
+                array_filters=[{"elem.task_id": {"$in": task_ids}, "elem.flow.flow_status": {"$nin": finished}}],
+            )
+        except Exception:
+            logger.exception("[RecordManager] 更新Record关联的Flow状态失败")
diff --git a/apps/services/task.py b/apps/services/task.py
index ec8436b18cd2a8ec6b9162387dcc8c3635a3cedd..34e6bf75cc0b654af67c77d189bdc772d0292ab6 100644
--- a/apps/services/task.py
+++ b/apps/services/task.py
@@ -3,11 +3,11 @@
import logging
import uuid
-from typing import Any
from apps.common.postgres import postgres
from apps.schemas.request_data import RequestData
from apps.schemas.task import (
+ FlowStepHistory,
Task,
TaskIds,
TaskRuntime,
@@ -70,7 +70,7 @@ class TaskManager:
@staticmethod
- async def get_context_by_record_id(record_group_id: str, record_id: str) -> list[dict[str, Any]]:
+ async def get_context_by_record_id(record_group_id: str, record_id: str) -> list[FlowStepHistory]:
"""根据record_group_id获取flow信息"""
record_group_collection = MongoDB().get_collection("record_group")
flow_context_collection = MongoDB().get_collection("flow_context")
@@ -88,7 +88,7 @@ class TaskManager:
for flow_context_id in records[0]["records"]["flow"]["history_ids"]:
flow_context = await flow_context_collection.find_one({"_id": flow_context_id})
if flow_context:
- flow_context_list.append(flow_context)
+ flow_context_list.append(FlowStepHistory.model_validate(flow_context))
except Exception:
logger.exception("[TaskManager] 获取record_id的flow信息失败")
return []
@@ -97,7 +97,7 @@ class TaskManager:
@staticmethod
- async def get_context_by_task_id(task_id: str, length: int = 0) -> list[dict[str, Any]]:
+ async def get_context_by_task_id(task_id: str, length: int = 0) -> list[FlowStepHistory]:
"""根据task_id获取flow信息"""
flow_context_collection = MongoDB().get_collection("flow_context")
@@ -108,7 +108,8 @@ class TaskManager:
).sort(
"created_at", -1,
).limit(length):
- flow_context += [history]
+                # Convert each raw flow_context document into a typed step-history entry,
+                # exactly once per document.
+                flow_context.append(FlowStepHistory.model_validate(history))
except Exception:
logger.exception("[TaskManager] 获取task_id的flow信息失败")
return []
@@ -117,7 +118,28 @@ class TaskManager:
@staticmethod
- async def save_flow_context(task_id: str, flow_context: list[dict[str, Any]]) -> None:
+    async def init_new_task(
+        user_sub: str,
+        session_id: str | None = None,
+        post_body: RequestData | None = None,
+    ) -> Task:
+        """
+        Build a new Task object for a request; nothing is persisted here.
+
+        :param user_sub: user identifier; a falsy value is stored as ""
+        :param session_id: session identifier; a falsy value is stored as ""
+        :param post_body: request payload; required because it supplies the conversation ID
+        :return: a Task with a freshly generated UUID and empty token/runtime data
+        :raises ValueError: if post_body is not provided
+        """
+        if post_body is None:
+            # Fail with a clear message instead of an AttributeError further down.
+            err = "post_body is required to create a task"
+            raise ValueError(err)
+        return Task(
+            _id=str(uuid.uuid4()),
+            ids=TaskIds(
+                user_sub=user_sub or "",
+                session_id=session_id or "",
+                conversation_id=post_body.conversation_id,
+                group_id=post_body.group_id or "",
+            ),
+            # NOTE(review): question and group_id are not among the Task fields visible
+            # in this diff — confirm the model accepts them.
+            question=post_body.question,
+            group_id=post_body.group_id,
+            tokens=TaskTokens(),
+            runtime=TaskRuntime(),
+        )
+
+ @staticmethod
+ async def save_flow_context(task_id: str, flow_context: list[FlowStepHistory]) -> None:
"""保存flow信息到flow_context"""
flow_context_collection = MongoDB().get_collection("flow_context")
try:
@@ -125,15 +147,15 @@ class TaskManager:
# 查找是否存在
current_context = await flow_context_collection.find_one({
"task_id": task_id,
- "_id": history["_id"],
+ "_id": history.id,
})
if current_context:
await flow_context_collection.update_one(
{"_id": current_context["_id"]},
- {"$set": history},
+ {"$set": history.model_dump(exclude_none=True, by_alias=True)},
)
else:
- await flow_context_collection.insert_one(history)
+ await flow_context_collection.insert_one(history.model_dump(exclude_none=True, by_alias=True))
except Exception:
logger.exception("[TaskManager] 保存flow执行记录失败")
@@ -150,7 +172,26 @@ class TaskManager:
@staticmethod
- async def delete_tasks_by_conversation_id(conversation_id: uuid.UUID) -> None:
+    async def delete_tasks_by_conversation_id(conversation_id: str) -> list[str]:
+        """
+        Delete every task of a conversation from the "task" collection.
+
+        :param conversation_id: ID of the conversation whose tasks are removed
+        :return: the _id values of the tasks found for the conversation (deleted when
+            the list is non-empty); an empty list if none exist or an error occurred
+        """
+        collection = MongoDB().get_collection("task")
+        try:
+            # Collect the ids first so they can be reported back to the caller.
+            found_ids = [
+                doc["_id"]
+                async for doc in collection.find({"conversation_id": conversation_id}, {"_id": 1})
+            ]
+            if found_ids:
+                await collection.delete_many({"conversation_id": conversation_id})
+        except Exception:
+            logger.exception("[TaskManager] 删除ConversationID的Task信息失败")
+            return []
+        return found_ids
+
+ @staticmethod
+ async def delete_tasks_and_flow_context_by_conversation_id(conversation_id: str) -> None:
"""通过ConversationID删除Task信息"""
mongo = MongoDB()
task_collection = mongo.get_collection("task")
@@ -168,50 +209,6 @@ class TaskManager:
await flow_context_collection.delete_many({"task_id": {"$in": task_ids}}, session=session)
- @classmethod
- async def get_task(
- cls,
- task_id: str | None = None,
- session_id: str | None = None,
- post_body: RequestData | None = None,
- user_sub: str | None = None,
- ) -> Task:
- """获取任务块"""
- if task_id:
- try:
- task = await cls.get_task_by_task_id(task_id)
- if task:
- return task
- except Exception:
- logger.exception("[TaskManager] 通过task_id获取任务失败")
-
- logger.info("[TaskManager] 未提供task_id,通过session_id获取任务")
- if not session_id or not post_body:
- err = (
- "session_id 和 conversation_id 或 group_id 和 conversation_id 是恢复/创建任务的必要条件。"
- )
- raise ValueError(err)
-
- if post_body.group_id:
- task = await cls.get_task_by_group_id(post_body.group_id, post_body.conversation_id)
- else:
- task = await cls.get_task_by_conversation_id(post_body.conversation_id)
-
- if task:
- return task
- return Task(
- _id=str(uuid.uuid4()),
- ids=TaskIds(
- user_sub=user_sub if user_sub else "",
- session_id=session_id if session_id else "",
- conversation_id=post_body.conversation_id,
- group_id=post_body.group_id if post_body.group_id else "",
- ),
- tokens=TaskTokens(),
- runtime=TaskRuntime(),
- )
-
-
@classmethod
async def save_task(cls, task_id: str, task: Task) -> None:
"""保存任务块"""