From 78e087e256824b30d73e2835342b37444d92ab72 Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 23 Jul 2025 12:16:10 +0800 Subject: [PATCH 1/8] =?UTF-8?q?=E9=80=82=E9=85=8Drag=E7=9A=84=E4=BD=9C?= =?UTF-8?q?=E8=80=85=E5=92=8C=E5=88=9B=E5=BB=BA=E6=97=B6=E9=97=B4=E7=9A=84?= =?UTF-8?q?=E8=BF=94=E5=9B=9E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/scheduler/context.py | 3 +++ apps/scheduler/scheduler/message.py | 9 +-------- apps/schemas/record.py | 6 +++++- apps/services/document.py | 4 ++++ apps/services/rag.py | 3 +++ 5 files changed, 16 insertions(+), 9 deletions(-) diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index b7088d8da..4c2c4cf0b 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -114,11 +114,14 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: used_docs.append( RecordGroupDocument( _id=docs["id"], + author=docs.get("author", ""), + order=docs.get("order", 0), name=docs["name"], abstract=docs.get("abstract", ""), extension=docs.get("extension", ""), size=docs.get("size", 0), associated="answer", + created_at=docs.get("created_at", round(datetime.now(UTC).timestamp(), 3)), ) ) if docs.get("order") is not None: diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index c89fdd101..5997b48fb 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -71,14 +71,7 @@ async def push_rag_message( # 如果是文本消息,直接拼接到答案中 full_answer += content_obj.content elif content_obj.event_type == EventType.DOCUMENT_ADD.value: - task.runtime.documents.append({ - "id": content_obj.content.get("id", ""), - "order": content_obj.content.get("order", 0), - "name": content_obj.content.get("name", ""), - "abstract": content_obj.content.get("abstract", ""), - "extension": content_obj.content.get("extension", ""), - "size": content_obj.content.get("size", 0), - }) + task.runtime.documents.append(content_obj.content) # 保存答案 task.runtime.answer = full_answer await TaskManager.save_task(task.id, task) diff --git a/apps/schemas/record.py b/apps/schemas/record.py index d7acd3682..e1a995f8e 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -17,8 +17,9 @@ class RecordDocument(Document): """GET /api/record/{conversation_id} Result中的document数据结构""" id: str = Field(alias="_id", default="") + order: int = Field(default=0, description="文档顺序") abstract: str = Field(default="", description="文档摘要") - user_sub: None = None + author: str = Field(default="", description="文档作者") associated: Literal["question", "answer"] class Config: @@ -103,11 +104,14 @@ class RecordGroupDocument(BaseModel): """RecordGroup关联的文件""" id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") + order: int = Field(default=0, description="文档顺序") + author: str = Field(default="", description="文档作者") name: str = Field(default="", description="文档名称") abstract: str = Field(default="", description="文档摘要") extension: str = Field(default="", description="文档扩展名") size: int = Field(default=0, description="文档大小,单位是KB") associated: Literal["question", "answer"] + created_at: float = Field(default=0.0, description="文档创建时间") class Record(RecordData): diff --git a/apps/services/document.py b/apps/services/document.py index 203162da1..451423a9d 100644 --- a/apps/services/document.py +++ b/apps/services/document.py @@ -2,6 +2,7 @@ """文件Manager""" import base64 +from datetime import UTC, datetime import logging import uuid @@ -131,12 +132,15 @@ class DocumentManager: return [ RecordDocument( _id=doc.id, + order=doc.order, + author=doc.author, abstract=doc.abstract, name=doc.name, type=doc.extension, size=doc.size, conversation_id=record_group.get("conversation_id", ""), associated=doc.associated, + created_at=doc.created_at or round(datetime.now(tz=UTC).timestamp(), 3) ) for doc in docs if type is None or doc.associated == type ] diff --git a/apps/services/rag.py b/apps/services/rag.py index 6b6c843dc..efbdfe940 100644 --- a/apps/services/rag.py +++ b/apps/services/rag.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """对接Euler Copilot RAG""" +from datetime import UTC, datetime import json import logging from collections.abc import AsyncGenerator @@ -156,9 +157,11 @@ class RAG: "id": doc_chunk["docId"], "order": doc_cnt, "name": doc_chunk.get("docName", ""), + "author": doc_chunk.get("docAuthor", ""), "extension": doc_chunk.get("docExtension", ""), "abstract": doc_chunk.get("docAbstract", ""), "size": doc_chunk.get("docSize", 0), + "created_at": doc_chunk.get("docCreatedAt", round(datetime.now(UTC).timestamp(), 3)), }) doc_id_map[doc_chunk["docId"]] = doc_cnt doc_index = doc_id_map[doc_chunk["docId"]] -- Gitee From c930151094eef40aa9f4e01e1ec59bd56ef75a21 Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 23 Jul 2025 15:40:30 +0800 Subject: [PATCH 2/8] =?UTF-8?q?=E5=AF=B9rag=E7=9A=84team=E5=92=8C=E5=9B=A2?= =?UTF-8?q?=E9=98=9F=E8=8E=B7=E5=8F=96=E5=A2=9E=E5=8A=A0=E5=BC=82=E5=B8=B8?= =?UTF-8?q?=E6=8D=95=E8=8E=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/services/conversation.py | 6 +++++- apps/services/knowledge.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/apps/services/conversation.py b/apps/services/conversation.py index 4bcade45c..bac964db1 100644 --- a/apps/services/conversation.py +++ b/apps/services/conversation.py @@ -59,7 +59,11 @@ class ConversationManager: model_name=llm.model_name, ) kb_item_list = [] - team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None) + try: + team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None) + except: + logger.error("[ConversationManager] 获取团队知识库列表失败") + team_kb_list = [] for team_kb in team_kb_list: for kb in team_kb["kbList"]: if str(kb["kbId"]) in kb_ids: diff --git a/apps/services/knowledge.py b/apps/services/knowledge.py index 9b4077f92..bd8dfc9e8 100644 --- a/apps/services/knowledge.py +++ b/apps/services/knowledge.py @@ -138,7 +138,11 @@ class KnowledgeBaseManager: return [] kb_ids_update_success = [] kb_item_dict_list = [] - team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None) + try: + team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None) + except Exception as e: + logger.error(f"[KnowledgeBaseManager] 获取团队知识库列表失败: {e}") + team_kb_list = [] for team_kb in team_kb_list: for kb in team_kb["kbList"]: if str(kb["kbId"]) in kb_ids: -- Gitee From 059bf35ef464b8165ea1c816566c9c35c4bc742b Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 23 Jul 2025 15:53:07 +0800 Subject: [PATCH 3/8] =?UTF-8?q?=E5=AE=8C=E5=96=84RecordDocument=E7=BB=93?= =?UTF-8?q?=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/schemas/record.py | 1 + 1 file changed, 1 insertion(+) diff --git a/apps/schemas/record.py b/apps/schemas/record.py index e1a995f8e..6e9617a33 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -19,6 +19,7 @@ class RecordDocument(Document): id: str = Field(alias="_id", default="") order: int = Field(default=0, description="文档顺序") abstract: str = Field(default="", description="文档摘要") + user_sub: None | None author: str = Field(default="", description="文档作者") associated: Literal["question", "answer"] -- Gitee From eefe717400bca7606c0c722d5e506fd9d9704c14 Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 23 Jul 2025 15:56:50 +0800 Subject: [PATCH 4/8] =?UTF-8?q?=E5=AE=8C=E5=96=84RecordDocument=E7=BB=93?= =?UTF-8?q?=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/schemas/record.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/schemas/record.py b/apps/schemas/record.py index 6e9617a33..b5e1b0c55 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -19,7 +19,7 @@ class RecordDocument(Document): id: str = Field(alias="_id", default="") order: int = Field(default=0, description="文档顺序") abstract: str = Field(default="", description="文档摘要") - user_sub: None | None + user_sub: None = None author: str = Field(default="", description="文档作者") associated: Literal["question", "answer"] -- Gitee From 60ea73002a36aa3452997caa649c9b80993ea2bd Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 23 Jul 2025 16:21:07 +0800 Subject: [PATCH 5/8] =?UTF-8?q?=E5=AE=8C=E5=96=84DocumentAddContent?= =?UTF-8?q?=E7=9A=84=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/scheduler/message.py | 2 ++ apps/schemas/message.py | 6 +++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index 5997b48fb..a2a45e414 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -108,10 +108,12 @@ async def _push_rag_chunk(task: Task, queue: MessageQueue, content: str) -> tupl data=DocumentAddContent( documentId=content_obj.content.get("id", ""), documentOrder=content_obj.content.get("order", 0), + documentAuthor=content_obj.content.get("author", ""), documentName=content_obj.content.get("name", ""), documentAbstract=content_obj.content.get("abstract", ""), documentType=content_obj.content.get("extension", ""), documentSize=content_obj.content.get("size", 0), + createdAt=round(datetime.now(tz=UTC).timestamp(), 3), ).model_dump(exclude_none=True, by_alias=True), ) except Exception: diff --git a/apps/schemas/message.py b/apps/schemas/message.py index d0661224e..5f0aee8ab 100644 --- a/apps/schemas/message.py +++ b/apps/schemas/message.py @@ -2,7 +2,7 @@ """队列中的消息结构""" from typing import Any - +from datetime import UTC, datetime from pydantic import BaseModel, Field from apps.schemas.enum_var import EventType, StepStatus @@ -60,10 +60,14 @@ class DocumentAddContent(BaseModel): document_id: str = Field(description="文档UUID", alias="documentId") document_order: int = Field(description="文档在对话中的顺序,从1开始", alias="documentOrder") + document_author: str = Field(description="文档作者", alias="documentAuthor", default="") document_name: str = Field(description="文档名称", alias="documentName") document_abstract: str = Field(description="文档摘要", alias="documentAbstract", default="") document_type: str = Field(description="文档MIME类型", alias="documentType", default="") document_size: float = Field(ge=0, description="文档大小,单位是KB,保留两位小数", alias="documentSize", default=0) + created_at: float = Field( + description="文档创建时间,单位是秒", alias="createdAt", default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3) + ) class FlowStartContent(BaseModel): -- Gitee From 7b57aa351af2a882d5059e5f5b3a618d2238deec Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 24 Jul 2025 19:22:36 +0800 Subject: [PATCH 6/8] =?UTF-8?q?=E5=AE=8C=E5=96=84=E6=97=A0=E9=89=B4?= =?UTF-8?q?=E6=9D=83=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/dependency/user.py | 8 +++++++ apps/routers/api_key.py | 1 + apps/routers/auth.py | 3 +-- apps/scheduler/executor/step.py | 38 ++++++++++++++++----------------- apps/schemas/config.py | 9 +++++++- apps/schemas/message.py | 2 ++ apps/schemas/request_data.py | 4 ++-- 7 files changed, 41 insertions(+), 24 deletions(-) diff --git a/apps/dependency/user.py b/apps/dependency/user.py index fce67e51e..87cbd2902 100644 --- a/apps/dependency/user.py +++ b/apps/dependency/user.py @@ -5,10 +5,12 @@ import logging from fastapi import Depends from fastapi.security import OAuth2PasswordBearer +import secrets from starlette import status from starlette.exceptions import HTTPException from starlette.requests import HTTPConnection +from apps.common.config import Config from apps.services.api_key import ApiKeyManager from apps.services.session import SessionManager @@ -48,6 +50,9 @@ async def get_session(request: HTTPConnection) -> str: :param request: HTTP请求 :return: Session ID """ + if Config().get_config().no_auth.enable: + # 如果启用了无认证访问,直接返回调试用户 + return secrets.token_hex(16) session_id = await _get_session_id_from_request(request) if not session_id: raise HTTPException( @@ -69,6 +74,9 @@ async def get_user(request: HTTPConnection) -> str: :param request: HTTP请求体 :return: 用户sub """ + if Config().get_config().no_auth.enable: + # 如果启用了无认证访问,直接返回调试用户 + return Config().get_config().no_auth.user_sub session_id = await _get_session_id_from_request(request) if not session_id: raise HTTPException( diff --git a/apps/routers/api_key.py b/apps/routers/api_key.py index 158cfc13a..51366a219 100644 --- a/apps/routers/api_key.py +++ b/apps/routers/api_key.py @@ -6,6 +6,7 @@ from typing import Annotated from fastapi import APIRouter, Depends, status from fastapi.responses import JSONResponse + from apps.dependency.user import get_user, verify_user from apps.schemas.api_key import GetAuthKeyRsp, PostAuthKeyMsg, PostAuthKeyRsp from apps.schemas.response_data import ResponseData diff --git a/apps/routers/auth.py b/apps/routers/auth.py index 1cba5ed62..4a3f82939 100644 --- a/apps/routers/auth.py +++ b/apps/routers/auth.py @@ -9,6 +9,7 @@ from fastapi import APIRouter, Depends, Request, status from fastapi.responses import HTMLResponse, JSONResponse from fastapi.templating import Jinja2Templates +from apps.common.config import Config from apps.common.oidc import oidc_provider from apps.dependency import get_session, get_user, verify_user from apps.schemas.collection import Audit @@ -47,8 +48,6 @@ async def oidc_login(request: Request, code: str) -> HTMLResponse: user_info = await oidc_provider.get_oidc_user(token["access_token"]) user_sub: str | None = user_info.get("user_sub", None) - if user_sub: - await oidc_provider.set_token(user_sub, token["access_token"], token["refresh_token"]) except Exception as e: logger.exception("User login failed") status_code = status.HTTP_400_BAD_REQUEST if "auth error" in str(e) else status.HTTP_403_FORBIDDEN diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 6b3451fa9..506f3bb10 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -86,8 +86,8 @@ class StepExecutor(BaseExecutor): logger.info("[StepExecutor] 初始化步骤 %s", self.step.step.name) # State写入ID和运行状态 - self.task.state.step_id = self.step.step_id # type: ignore[arg-type] - self.task.state.step_name = self.step.step.name # type: ignore[arg-type] + self.task.state.step_id = self.step.step_id # type: ignore[arg-type] + self.task.state.step_name = self.step.step.name # type: ignore[arg-type] # 获取并验证Call类 node_id = self.step.step.node @@ -127,13 +127,13 @@ class StepExecutor(BaseExecutor): return # 暂存旧数据 - current_step_id = self.task.state.step_id # type: ignore[arg-type] - current_step_name = self.task.state.step_name # type: ignore[arg-type] + current_step_id = self.task.state.step_id # type: ignore[arg-type] + current_step_name = self.task.state.step_name # type: ignore[arg-type] # 更新State - self.task.state.step_id = str(uuid.uuid4()) # type: ignore[arg-type] - self.task.state.step_name = "自动参数填充" # type: ignore[arg-type] - self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] + self.task.state.step_id = str(uuid.uuid4()) # type: ignore[arg-type] + self.task.state.step_name = "自动参数填充" # type: ignore[arg-type] + self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) # 初始化填参 @@ -156,17 +156,17 @@ class StepExecutor(BaseExecutor): # 如果没有填全,则状态设置为待填参 if result.remaining_schema: - self.task.state.status = StepStatus.PARAM # type: ignore[arg-type] + self.task.state.status = StepStatus.PARAM # type: ignore[arg-type] else: - self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] + self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] await self.push_message(EventType.STEP_OUTPUT.value, result.model_dump(by_alias=True, exclude_none=True)) # 更新输入 self.obj.input.update(result.slot_data) # 恢复State - self.task.state.step_id = current_step_id # type: ignore[arg-type] - self.task.state.step_name = current_step_name # type: ignore[arg-type] + self.task.state.step_id = current_step_id # type: ignore[arg-type] + self.task.state.step_name = current_step_name # type: ignore[arg-type] self.task.tokens.input_tokens += self.obj.tokens.input_tokens self.task.tokens.output_tokens += self.obj.tokens.output_tokens @@ -212,7 +212,7 @@ class StepExecutor(BaseExecutor): await self._run_slot_filling() # 更新状态 - self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] + self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) # 推送输入 await self.push_message(EventType.STEP_INPUT.value, self.obj.input) @@ -224,22 +224,22 @@ class StepExecutor(BaseExecutor): content = await self._process_chunk(iterator, to_user=self.obj.to_user) except Exception as e: logger.exception("[StepExecutor] 运行步骤失败,进行异常处理步骤") - self.task.state.status = StepStatus.ERROR # type: ignore[arg-type] + self.task.state.status = StepStatus.ERROR # type: ignore[arg-type] await self.push_message(EventType.STEP_OUTPUT.value, {}) if isinstance(e, CallError): - self.task.state.error_info = { # type: ignore[arg-type] + self.task.state.error_info = { # type: ignore[arg-type] "err_msg": e.message, "data": e.data, } else: - self.task.state.error_info = { # type: ignore[arg-type] + self.task.state.error_info = { # type: ignore[arg-type] "err_msg": str(e), "data": {}, } return # 更新执行状态 - self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] + self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] self.task.tokens.input_tokens += self.obj.tokens.input_tokens self.task.tokens.output_tokens += self.obj.tokens.output_tokens self.task.tokens.full_time += round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.time @@ -253,12 +253,12 @@ class StepExecutor(BaseExecutor): # 更新context history = FlowStepHistory( task_id=self.task.id, - flow_id=self.task.state.flow_id, # type: ignore[arg-type] - flow_name=self.task.state.flow_name, # type: ignore[arg-type] + flow_id=self.task.state.flow_id, # type: ignore[arg-type] + flow_name=self.task.state.flow_name, # type: ignore[arg-type] step_id=self.step.step_id, step_name=self.step.step.name, step_description=self.step.step.description, - status=self.task.state.status, # type: ignore[arg-type] + status=self.task.state.status, # type: ignore[arg-type] input_data=self.obj.input, output_data=output_data, ) diff --git a/apps/schemas/config.py b/apps/schemas/config.py index b88a81f1a..99bcccde8 100644 --- a/apps/schemas/config.py +++ b/apps/schemas/config.py @@ -6,6 +6,13 @@ from typing import Literal from pydantic import BaseModel, Field +class NoauthConfig(BaseModel): + """无认证配置""" + + enable: bool = Field(description="是否启用无认证访问", default=False) + user_sub: str = Field(description="调试用户的sub", default="admin") + + class DeployConfig(BaseModel): """部署配置""" @@ -122,7 +129,7 @@ class ExtraConfig(BaseModel): class ConfigModel(BaseModel): """配置文件的校验Class""" - + no_auth: NoauthConfig deploy: DeployConfig login: LoginConfig embedding: EmbeddingConfig diff --git a/apps/schemas/message.py b/apps/schemas/message.py index 5f0aee8ab..cf70a82b8 100644 --- a/apps/schemas/message.py +++ b/apps/schemas/message.py @@ -24,6 +24,8 @@ class MessageFlow(BaseModel): flow_id: str = Field(description="Flow ID", alias="flowId") step_id: str = Field(description="当前步骤ID", alias="stepId") step_name: str = Field(description="当前步骤名称", alias="stepName") + sub_step_id: str | None = Field(description="当前子步骤ID", alias="subStepId", default=None) + sub_step_name: str | None = Field(description="当前子步骤名称", alias="subStepName", default=None) step_status: StepStatus = Field(description="当前步骤状态", alias="stepStatus") diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index 2305dd93d..5d6dc7fcc 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -16,8 +16,8 @@ class RequestDataApp(BaseModel): """模型对话中包含的app信息""" app_id: str = Field(description="应用ID", alias="appId") - flow_id: str = Field(description="Flow ID", alias="flowId") - params: dict[str, Any] = Field(description="插件参数") + flow_id: str | None = Field(default=None, description="Flow ID", alias="flowId") + params: dict[str, Any] | None = Field(default=None, description="插件参数") class MockRequestData(BaseModel): -- Gitee From f02748f597d63b3b072b0749681bf1827283987e Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 24 Jul 2025 21:13:04 +0800 Subject: [PATCH 7/8] =?UTF-8?q?=E5=AE=8C=E5=96=84Agent=E7=9A=84=E5=BC=80?= =?UTF-8?q?=E5=8F=91&=E4=BF=AE=E5=A4=8Dmcp=E6=B3=A8=E5=86=8C=E6=97=B6?= =?UTF-8?q?=E5=80=99=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/routers/chat.py | 11 ++++++++--- apps/scheduler/call/mcp/mcp.py | 7 ------- apps/scheduler/executor/agent.py | 28 +++++++++------------------ apps/scheduler/pool/loader/mcp.py | 1 - apps/scheduler/scheduler/scheduler.py | 2 +- apps/schemas/enum_var.py | 3 +++ apps/schemas/mcp.py | 3 ++- apps/schemas/request_data.py | 1 + apps/schemas/task.py | 3 +++ apps/services/task.py | 9 --------- 10 files changed, 27 insertions(+), 41 deletions(-) diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 7fe5162c0..589000be3 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -36,11 +36,16 @@ async def init_task(post_body: RequestData, user_sub: str, session_id: str) -> T # 生成group_id if not post_body.group_id: post_body.group_id = str(uuid.uuid4()) - # 创建或还原Task + if post_body.new_task: + # 创建或还原Task + task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub) + if task: + await TaskManager.delete_task_by_task_id(task.id) task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub) # 更改信息并刷新数据库 - task.runtime.question = post_body.question - task.ids.group_id = post_body.group_id + if post_body.new_task: + task.runtime.question = post_body.question + task.ids.group_id = post_body.group_id return task diff --git a/apps/scheduler/call/mcp/mcp.py b/apps/scheduler/call/mcp/mcp.py index 661e9ada7..4e6a1bb73 100644 --- a/apps/scheduler/call/mcp/mcp.py +++ b/apps/scheduler/call/mcp/mcp.py @@ -35,7 +35,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): text_output: bool = Field(description="是否将结果以文本形式返回", default=True) to_user: bool = Field(description="是否将结果返回给用户", default=True) - @classmethod def info(cls) -> CallInfo: """ @@ -46,7 +45,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): """ return CallInfo(name="MCP", description="调用MCP Server,执行工具") - async def _init(self, call_vars: CallVars) -> MCPInput: """初始化MCP""" # 获取MCP交互类 @@ -63,7 +61,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): return MCPInput(avaliable_tools=avaliable_tools, max_steps=self.max_steps) - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """执行MCP""" # 生成计划 @@ -80,7 +77,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): async for chunk in self._generate_answer(): yield chunk - async def _generate_plan(self) -> AsyncGenerator[CallOutputChunk, None]: """生成执行计划""" # 开始提示 @@ -103,7 +99,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): data=self._plan.model_dump(), ) - async def _execute_plan_item(self, plan_item: MCPPlanItem) -> AsyncGenerator[CallOutputChunk, None]: """执行单个计划项""" # 判断是否为Final @@ -141,7 +136,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): }, ) - async def _generate_answer(self) -> AsyncGenerator[CallOutputChunk, None]: """生成总结""" # 提示开始总结 @@ -163,7 +157,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): ).model_dump(), ) - def _create_output( self, text: str, diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index f6814dd36..2ff4f3d39 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -7,6 +7,8 @@ from pydantic import Field from apps.scheduler.executor.base import BaseExecutor from apps.scheduler.mcp_agent.agent.mcp import MCPAgent +from apps.schemas.task import ExecutorState, StepQueueItem +from apps.services.task import TaskManager logger = logging.getLogger(__name__) @@ -15,26 +17,14 @@ class MCPAgentExecutor(BaseExecutor): """MCP Agent执行器""" question: str = Field(description="用户输入") - max_steps: int = Field(default=10, description="最大步数") + max_steps: int = Field(default=20, description="最大步数") servers_id: list[str] = Field(description="MCP server id") agent_id: str = Field(default="", description="Agent ID") agent_description: str = Field(default="", description="Agent描述") - async def run(self) -> None: - """运行MCP Agent""" - agent = await MCPAgent.create( - servers_id=self.servers_id, - max_steps=self.max_steps, - task=self.task, - msg_queue=self.msg_queue, - question=self.question, - agent_id=self.agent_id, - description=self.agent_description, - ) - - try: - answer = await agent.run(self.question) - self.task = agent.task - self.task.runtime.answer = answer - except Exception as e: - logger.error(f"Error: {str(e)}") + async def load_state(self) -> None: + """从数据库中加载FlowExecutor的状态""" + logger.info("[FlowExecutor] 加载Executor状态") + # 尝试恢复State + if self.task.state: + self.task.context = await TaskManager.get_context_by_task_id(self.task.id) diff --git a/apps/scheduler/pool/loader/mcp.py b/apps/scheduler/pool/loader/mcp.py index 66a516e77..1463d0a1a 100644 --- a/apps/scheduler/pool/loader/mcp.py +++ b/apps/scheduler/pool/loader/mcp.py @@ -153,7 +153,6 @@ class MCPLoader(metaclass=SingletonMeta): # 检查目录 template_path = MCP_PATH / "template" / mcp_id await Path.mkdir(template_path, parents=True, exist_ok=True) - ProcessHandler.clear_finished_tasks() # 安装MCP模板 if not ProcessHandler.add_task(mcp_id, MCPLoader._install_template_task, mcp_id, config): err = f"安装任务无法执行,请稍后重试: {mcp_id}" diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index ed73638ce..417f93d28 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -206,7 +206,7 @@ class Scheduler: task=self.task, msg_queue=queue, question=post_body.question, - max_steps=app_metadata.history_len, + history_len=app_metadata.history_len, servers_id=servers_id, background=background, agent_id=app_info.app_id, diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py index 9a20ba84d..a84dc3a35 100644 --- a/apps/schemas/enum_var.py +++ b/apps/schemas/enum_var.py @@ -15,6 +15,7 @@ class SlotType(str, Enum): class StepStatus(str, Enum): """步骤状态""" + WAITING = "waiting" RUNNING = "running" SUCCESS = "success" ERROR = "error" @@ -38,6 +39,8 @@ class EventType(str, Enum): TEXT_ADD = "text.add" GRAPH = "graph" DOCUMENT_ADD = "document.add" + STEP_WAITING_FOR_START = "step.waiting_for_start" + STEP_WAITING_FOR_PARAM = "step.waiting_for_param" FLOW_START = "flow.start" STEP_INPUT = "step.input" STEP_OUTPUT = "step.output" diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py index 44021b0ed..60c8f17b4 100644 --- a/apps/schemas/mcp.py +++ b/apps/schemas/mcp.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """MCP 相关数据结构""" +import uuid from enum import Enum from typing import Any @@ -117,7 +118,7 @@ class MCPToolSelectResult(BaseModel): class MCPPlanItem(BaseModel): """MCP 计划""" - + id: str = Field(default_factory=lambda: str(uuid.uuid4())) content: str = Field(description="计划内容") tool: str = Field(description="工具名称") instruction: str = Field(description="工具指令") diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index 5d6dc7fcc..a3a8848c3 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -46,6 +46,7 @@ class RequestData(BaseModel): files: list[str] = Field(default=[], description="文件列表") app: RequestDataApp | None = Field(default=None, description="应用") debug: bool = Field(default=False, description="是否调试") + new_task: bool = Field(default=True, description="是否新建任务") class QuestionBlacklistRequest(BaseModel): diff --git a/apps/schemas/task.py b/apps/schemas/task.py index 8efcb5991..37fdebbfb 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -9,6 +9,7 @@ from pydantic import BaseModel, Field from apps.schemas.enum_var import StepStatus from apps.schemas.flow import Step +from apps.schemas.mcp import MCPPlan class FlowStepHistory(BaseModel): @@ -42,6 +43,7 @@ class ExecutorState(BaseModel): # 附加信息 step_id: str = Field(description="当前步骤ID") step_name: str = Field(description="当前步骤名称") + step_description: str = Field(description="当前步骤描述", default="") app_id: str = Field(description="应用ID") slot: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={}) error_info: dict[str, Any] = Field(description="错误信息", default={}) @@ -75,6 +77,7 @@ class TaskRuntime(BaseModel): summary: str = Field(description="摘要", default="") filled: dict[str, Any] = Field(description="填充的槽位", default={}) documents: list[dict[str, Any]] = Field(description="文档列表", default=[]) + temporary_plans: MCPPlan | None = Field(description="临时计划列表", default=None) class Task(BaseModel): diff --git a/apps/services/task.py b/apps/services/task.py index 1e672be69..2456d96bf 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -45,7 +45,6 @@ class TaskManager: return Task.model_validate(task) - @staticmethod async def get_task_by_group_id(group_id: str, conversation_id: str) -> Task | None: """获取组ID的最后一条问答组关联的任务""" @@ -58,7 +57,6 @@ class TaskManager: task = await task_collection.find_one({"_id": record_group_obj.task_id}) return Task.model_validate(task) - @staticmethod async def get_task_by_task_id(task_id: str) -> Task | None: """根据task_id获取任务""" @@ -68,7 +66,6 @@ class TaskManager: return None return Task.model_validate(task) - @staticmethod async def get_context_by_record_id(record_group_id: str, record_id: str) -> list[dict[str, Any]]: """根据record_group_id获取flow信息""" @@ -95,7 +92,6 @@ class TaskManager: else: return flow_context_list - @staticmethod async def get_context_by_task_id(task_id: str, length: int = 0) -> list[dict[str, Any]]: """根据task_id获取flow信息""" @@ -115,7 +111,6 @@ class TaskManager: else: return flow_context - @staticmethod async def save_flow_context(task_id: str, flow_context: list[dict[str, Any]]) -> None: """保存flow信息到flow_context""" @@ -137,7 +132,6 @@ class TaskManager: except Exception: logger.exception("[TaskManager] 保存flow执行记录失败") - @staticmethod async def delete_task_by_task_id(task_id: str) -> None: """通过task_id删除Task信息""" @@ -148,7 +142,6 @@ class TaskManager: if task: await task_collection.delete_one({"_id": task_id}) - @staticmethod async def delete_tasks_by_conversation_id(conversation_id: str) -> None: """通过ConversationID删除Task信息""" @@ -167,7 +160,6 @@ class TaskManager: await task_collection.delete_many({"conversation_id": conversation_id}, session=session) await flow_context_collection.delete_many({"task_id": {"$in": task_ids}}, session=session) - @classmethod async def get_task( cls, @@ -212,7 +204,6 @@ class TaskManager: runtime=TaskRuntime(), ) - @classmethod async def save_task(cls, task_id: str, task: Task) -> None: """保存任务块""" -- Gitee From e4dfb413f4147e9cee2fb570a89b05f9fdc57331 Mon Sep 17 00:00:00 2001 From: zxstty Date: Fri, 25 Jul 2025 19:09:10 +0800 Subject: [PATCH 8/8] =?UTF-8?q?=E5=AE=8C=E5=96=84mcp=20agent=E7=9A=84?= =?UTF-8?q?=E5=BC=80=E5=8F=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/queue.py | 2 +- apps/routers/record.py | 21 +- apps/scheduler/call/mcp/mcp.py | 2 +- apps/scheduler/executor/agent.py | 2 +- apps/scheduler/executor/flow.py | 28 +- apps/scheduler/executor/step.py | 18 +- apps/scheduler/mcp/host.py | 9 +- apps/scheduler/mcp_agent/__init__.py | 8 + apps/scheduler/mcp_agent/agent/base.py | 196 -------------- apps/scheduler/mcp_agent/agent/mcp.py | 81 ------ apps/scheduler/mcp_agent/agent/react.py | 35 --- apps/scheduler/mcp_agent/agent/toolcall.py | 238 ----------------- apps/scheduler/mcp_agent/host.py | 190 ++++++++++++++ apps/scheduler/mcp_agent/plan.py | 110 ++++++++ apps/scheduler/mcp_agent/prompt.py | 240 ++++++++++++++++++ apps/scheduler/mcp_agent/schema.py | 148 ----------- apps/scheduler/mcp_agent/select.py | 185 ++++++++++++++ apps/scheduler/mcp_agent/tool/__init__.py | 9 - apps/scheduler/mcp_agent/tool/base.py | 73 ------ apps/scheduler/mcp_agent/tool/terminate.py | 25 -- .../mcp_agent/tool/tool_collection.py | 55 ---- apps/scheduler/scheduler/context.py | 2 +- apps/schemas/enum_var.py | 11 + apps/schemas/record.py | 1 + apps/schemas/task.py | 12 +- tests/common/test_queue.py | 4 +- 26 files changed, 792 insertions(+), 913 deletions(-) create mode 100644 apps/scheduler/mcp_agent/__init__.py delete mode 100644 apps/scheduler/mcp_agent/agent/base.py delete mode 100644 apps/scheduler/mcp_agent/agent/mcp.py delete mode 100644 apps/scheduler/mcp_agent/agent/react.py delete mode 100644 apps/scheduler/mcp_agent/agent/toolcall.py create mode 100644 apps/scheduler/mcp_agent/host.py create mode 100644 apps/scheduler/mcp_agent/plan.py create mode 100644 apps/scheduler/mcp_agent/prompt.py delete mode 100644 apps/scheduler/mcp_agent/schema.py create mode 100644 apps/scheduler/mcp_agent/select.py delete mode 100644 apps/scheduler/mcp_agent/tool/__init__.py delete mode 100644 apps/scheduler/mcp_agent/tool/base.py delete mode 100644 apps/scheduler/mcp_agent/tool/terminate.py delete mode 100644 apps/scheduler/mcp_agent/tool/tool_collection.py diff --git a/apps/common/queue.py b/apps/common/queue.py index 5601c93aa..911485b31 100644 --- a/apps/common/queue.py +++ b/apps/common/queue.py @@ -58,7 +58,7 @@ class MessageQueue: flowId=task.state.flow_id, stepId=task.state.step_id, stepName=task.state.step_name, - stepStatus=task.state.status, + stepStatus=task.state.step_status ) else: flow = None diff --git a/apps/routers/record.py b/apps/routers/record.py index 7384793b4..f357f0de9 100644 --- a/apps/routers/record.py +++ b/apps/routers/record.py @@ -83,22 +83,23 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_ tmp_record.document = await DocumentManager.get_used_docs_by_record_group(user_sub, record_group.id) # 获得Record关联的flow数据 - flow_list = await TaskManager.get_context_by_record_id(record_group.id, record.id) - if flow_list: - first_flow = FlowStepHistory.model_validate(flow_list[0]) + flow_step_list = await TaskManager.get_context_by_record_id(record_group.id, record.id) + if flow_step_list: + first_step_history = FlowStepHistory.model_validate(flow_step_list[0]) tmp_record.flow = RecordFlow( - id=first_flow.flow_name, #TODO: 此处前端应该用name + id=first_step_history.flow_name, # TODO: 此处前端应该用name recordId=record.id, - flowId=first_flow.id, - stepNum=len(flow_list), + flowStatus=first_step_history.flow_status, + flowId=first_step_history.id, + stepNum=len(flow_step_list), steps=[], ) - for flow in flow_list: - flow_step = FlowStepHistory.model_validate(flow) + for flow_step in flow_step_list: + flow_step = FlowStepHistory.model_validate(flow_step) tmp_record.flow.steps.append( RecordFlowStep( - stepId=flow_step.step_name, #TODO: 此处前端应该用name - stepStatus=flow_step.status, + stepId=flow_step.step_name, # TODO: 此处前端应该用name + stepStatus=flow_step.step_status, input=flow_step.input_data, output=flow_step.output_data, ), diff --git a/apps/scheduler/call/mcp/mcp.py b/apps/scheduler/call/mcp/mcp.py index 4e6a1bb73..bd0257b49 100644 --- a/apps/scheduler/call/mcp/mcp.py +++ b/apps/scheduler/call/mcp/mcp.py @@ -31,7 +31,7 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): """MCP工具""" mcp_list: list[str] = Field(description="MCP Server ID列表", max_length=5, min_length=1) - max_steps: int = Field(description="最大步骤数", default=6) + max_steps: int = Field(description="最大步骤数", default=20) text_output: bool = Field(description="是否将结果以文本形式返回", default=True) to_user: bool = Field(description="是否将结果返回给用户", default=True) diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index 2ff4f3d39..cb8e183e9 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -6,7 +6,7 @@ import logging from pydantic import Field from apps.scheduler.executor.base import BaseExecutor -from apps.scheduler.mcp_agent.agent.mcp import MCPAgent +from apps.scheduler.mcp_agent import host, plan, select from apps.schemas.task import ExecutorState, StepQueueItem from apps.services.task import TaskManager diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index a70d0d707..e5381a4ac 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -11,7 +11,7 @@ from pydantic import Field from apps.scheduler.call.llm.prompt import LLM_ERROR_PROMPT from apps.scheduler.executor.base import BaseExecutor from apps.scheduler.executor.step import StepExecutor -from apps.schemas.enum_var import EventType, SpecialCallType, StepStatus +from apps.schemas.enum_var import EventType, SpecialCallType, FlowStatus, StepStatus from apps.schemas.flow import Flow, Step from apps.schemas.request_data import RequestDataApp from apps.schemas.task import ExecutorState, StepQueueItem @@ -47,7 +47,6 @@ class FlowExecutor(BaseExecutor): question: str = Field(description="用户输入") post_body_app: RequestDataApp = Field(description="请求体中的app信息") - async def load_state(self) -> None: """从数据库中加载FlowExecutor的状态""" logger.info("[FlowExecutor] 加载Executor状态") @@ -59,8 +58,9 @@ class FlowExecutor(BaseExecutor): self.task.state = ExecutorState( flow_id=str(self.flow_id), flow_name=self.flow.name, + flow_status=FlowStatus.RUNNING, description=str(self.flow.description), - status=StepStatus.RUNNING, + step_status=StepStatus.RUNNING, app_id=str(self.post_body_app.app_id), step_id="start", step_name="开始", @@ -70,7 +70,6 @@ class FlowExecutor(BaseExecutor): self._reached_end: bool = False self.step_queue: deque[StepQueueItem] = deque() - async def _invoke_runner(self, queue_item: StepQueueItem) -> None: """单一Step执行""" # 创建步骤Runner @@ -90,7 +89,6 @@ class FlowExecutor(BaseExecutor): # 更新Task(已存过库) self.task = step_runner.task - async def _step_process(self) -> None: """执行当前queue里面的所有步骤(在用户看来是单一Step)""" while True: @@ -102,7 +100,6 @@ class FlowExecutor(BaseExecutor): # 执行Step await self._invoke_runner(queue_item) - async def _find_next_id(self, step_id: str) -> list[str]: """查找下一个节点""" next_ids = [] @@ -111,14 +108,13 @@ class FlowExecutor(BaseExecutor): next_ids += [edge.edge_to] return next_ids - async def _find_flow_next(self) -> list[StepQueueItem]: """在当前步骤执行前,尝试获取下一步""" # 如果当前步骤为结束,则直接返回 - if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type] + if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type] return [] - next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type] + next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type] # 如果step没有任何出边,直接跳到end if not next_steps: return [ @@ -137,7 +133,6 @@ class FlowExecutor(BaseExecutor): for next_step in next_steps ] - async def run(self) -> None: """ 运行流,返回各步骤结果,直到无法继续执行 @@ -150,8 +145,8 @@ class FlowExecutor(BaseExecutor): # 获取首个步骤 first_step = StepQueueItem( - step_id=self.task.state.step_id, # type: ignore[arg-type] - step=self.flow.steps[self.task.state.step_id], # type: ignore[arg-type] + step_id=self.task.state.step_id, # type: ignore[arg-type] + step=self.flow.steps[self.task.state.step_id], # type: ignore[arg-type] ) # 头插开始前的系统步骤,并执行 @@ -166,11 +161,11 @@ class FlowExecutor(BaseExecutor): # 插入首个步骤 self.step_queue.append(first_step) - + self.task.state.flow_status = FlowStatus.RUNNING # type: ignore[arg-type] # 运行Flow(未达终点) while not self._reached_end: # 如果当前步骤出错,执行错误处理步骤 - if self.task.state.status == StepStatus.ERROR: # type: ignore[arg-type] + if self.task.state.step_status == StepStatus.ERROR: # type: ignore[arg-type] logger.warning("[FlowExecutor] Executor出错,执行错误处理步骤") self.step_queue.clear() self.step_queue.appendleft(StepQueueItem( @@ -183,13 +178,14 @@ class FlowExecutor(BaseExecutor): params={ "user_prompt": LLM_ERROR_PROMPT.replace( "{{ error_info }}", - self.task.state.error_info["err_msg"], # type: ignore[arg-type] + self.task.state.error_info["err_msg"], # type: ignore[arg-type] ), }, ), enable_filling=False, to_user=False, )) + self.task.state.flow_status = FlowStatus.ERROR # type: ignore[arg-type] # 错误处理后结束 self._reached_end = True @@ -216,3 +212,5 @@ class FlowExecutor(BaseExecutor): self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.full_time # 推送Flow停止消息 await self.push_message(EventType.FLOW_STOP.value) + # 更新Task状态 + self.task.state.flow_status = FlowStatus.SUCCESS # type: ignore[arg-type] diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 506f3bb10..377a4c6e5 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -119,7 +119,6 @@ class StepExecutor(BaseExecutor): logger.exception("[StepExecutor] 初始化Call失败") raise - async def _run_slot_filling(self) -> None: """运行自动参数填充;相当于特殊Step,但是不存库""" # 判断是否需要进行自动参数填充 @@ -133,7 +132,7 @@ class StepExecutor(BaseExecutor): # 更新State self.task.state.step_id = str(uuid.uuid4()) # type: ignore[arg-type] self.task.state.step_name = "自动参数填充" # type: ignore[arg-type] - self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] + self.task.state.step_status = StepStatus.RUNNING # type: ignore[arg-type] self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) # 初始化填参 @@ -156,9 +155,9 @@ class StepExecutor(BaseExecutor): # 如果没有填全,则状态设置为待填参 if result.remaining_schema: - self.task.state.status = StepStatus.PARAM # type: ignore[arg-type] + self.task.state.step_status = StepStatus.PARAM # type: ignore[arg-type] else: - self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] + self.task.state.step_status = StepStatus.SUCCESS # type: ignore[arg-type] await self.push_message(EventType.STEP_OUTPUT.value, result.model_dump(by_alias=True, exclude_none=True)) # 更新输入 @@ -170,7 +169,6 @@ class StepExecutor(BaseExecutor): self.task.tokens.input_tokens += self.obj.tokens.input_tokens self.task.tokens.output_tokens += self.obj.tokens.output_tokens - async def _process_chunk( self, iterator: AsyncGenerator[CallOutputChunk, None], @@ -202,7 +200,6 @@ class StepExecutor(BaseExecutor): return content - async def run(self) -> None: """运行单个步骤""" self.validate_flow_state(self.task) @@ -212,7 +209,7 @@ class StepExecutor(BaseExecutor): await self._run_slot_filling() # 更新状态 - self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] + self.task.state.step_status = StepStatus.RUNNING # type: ignore[arg-type] self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) # 推送输入 await self.push_message(EventType.STEP_INPUT.value, self.obj.input) @@ -224,7 +221,7 @@ class StepExecutor(BaseExecutor): content = await self._process_chunk(iterator, to_user=self.obj.to_user) except Exception as e: logger.exception("[StepExecutor] 运行步骤失败,进行异常处理步骤") - self.task.state.status = StepStatus.ERROR # type: ignore[arg-type] + self.task.state.step_status = StepStatus.ERROR # type: ignore[arg-type] await self.push_message(EventType.STEP_OUTPUT.value, {}) if isinstance(e, CallError): self.task.state.error_info = { # type: ignore[arg-type] @@ -239,7 +236,7 @@ class StepExecutor(BaseExecutor): return # 更新执行状态 - self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] + self.task.state.step_status = StepStatus.SUCCESS # type: ignore[arg-type] self.task.tokens.input_tokens += self.obj.tokens.input_tokens self.task.tokens.output_tokens += self.obj.tokens.output_tokens self.task.tokens.full_time += round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.time @@ -255,10 +252,11 @@ class StepExecutor(BaseExecutor): task_id=self.task.id, flow_id=self.task.state.flow_id, # type: ignore[arg-type] flow_name=self.task.state.flow_name, # type: ignore[arg-type] + flow_status=self.task.state.flow_status, # type: ignore[arg-type] step_id=self.step.step_id, step_name=self.step.step.name, step_description=self.step.step.description, - status=self.task.state.status, # type: ignore[arg-type] + step_status=self.task.state.step_status, input_data=self.obj.input, output_data=output_data, ) diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py index 78aa7bc3e..acdd4871e 100644 --- a/apps/scheduler/mcp/host.py +++ b/apps/scheduler/mcp/host.py @@ -40,7 +40,6 @@ class MCPHost: lstrip_blocks=True, ) - async def get_client(self, mcp_id: str) -> MCPClient | None: """获取MCP客户端""" mongo = MongoDB() @@ -59,7 +58,6 @@ class MCPHost: logger.warning("用户 %s 的MCP %s 没有运行中的实例,请检查环境", self._user_sub, mcp_id) return None - async def assemble_memory(self) -> str: """组装记忆""" task = await TaskManager.get_task_by_task_id(self._task_id) @@ -78,7 +76,6 @@ class MCPHost: context_list=context_list, ) - async def _save_memory( self, tool: MCPTool, @@ -105,11 +102,12 @@ class MCPHost: task_id=self._task_id, flow_id=self._runtime_id, flow_name=self._runtime_name, + flow_status=StepStatus.SUCCESS, step_id=tool.name, step_name=tool.name, # description是规划的实际内容 step_description=plan_item.content, - status=StepStatus.SUCCESS, + step_status=StepStatus.SUCCESS, input_data=input_data, output_data=output_data, ) @@ -125,7 +123,6 @@ class MCPHost: return output_data - async def _fill_params(self, tool: MCPTool, query: str) -> dict[str, Any]: """填充工具参数""" # 更清晰的输入·指令,这样可以调用generate @@ -146,7 +143,6 @@ class MCPHost: ) return await json_generator.generate() - async def call_tool(self, tool: MCPTool, plan_item: MCPPlanItem) -> list[dict[str, Any]]: """调用工具""" # 拿到Client @@ -170,7 +166,6 @@ class MCPHost: return processed_result - async def get_tool_list(self, mcp_id_list: list[str]) -> list[MCPTool]: """获取工具列表""" mongo = MongoDB() diff --git a/apps/scheduler/mcp_agent/__init__.py b/apps/scheduler/mcp_agent/__init__.py new file mode 100644 index 000000000..12f5cb68c --- /dev/null +++ b/apps/scheduler/mcp_agent/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""Scheduler MCP 模块""" + +from apps.scheduler.mcp.host import MCPHost +from apps.scheduler.mcp.plan import MCPPlanner +from apps.scheduler.mcp.select import MCPSelector + +__all__ = ["MCPHost", "MCPPlanner", "MCPSelector"] diff --git a/apps/scheduler/mcp_agent/agent/base.py b/apps/scheduler/mcp_agent/agent/base.py deleted file mode 100644 index eccb58a9c..000000000 --- a/apps/scheduler/mcp_agent/agent/base.py +++ /dev/null @@ -1,196 +0,0 @@ -"""MCP Agent基类""" -import logging -from abc import ABC, abstractmethod -from contextlib import asynccontextmanager - -from pydantic import BaseModel, Field, model_validator - -from apps.common.queue import MessageQueue -from apps.schemas.enum_var import AgentState -from apps.schemas.task import Task -from apps.llm.reasoning import ReasoningLLM -from apps.scheduler.mcp_agent.schema import Memory, Message, Role -from apps.services.activity import Activity - -logger = logging.getLogger(__name__) - - -class BaseAgent(BaseModel, ABC): - """ - 用于管理代理状态和执行的抽象基类。 - - 为状态转换、内存管理、 - 以及分步执行循环。子类必须实现`step`方法。 - """ - - msg_queue: MessageQueue - task: Task - name: str = Field(..., description="Agent名称") - agent_id: str = Field(default="", description="Agent ID") - description: str = Field(default="", description="Agent描述") - question: str - # Prompts - next_step_prompt: str | None = Field( - None, description="判断下一步动作的提示" - ) - - # Dependencies - llm: ReasoningLLM = Field(default_factory=ReasoningLLM, description="大模型实例") - memory: Memory = Field(default_factory=Memory, description="Agent记忆库") - state: AgentState = Field( - default=AgentState.IDLE, description="Agent状态" - ) - servers_id: list[str] = Field(default_factory=list, description="MCP server id") - - # Execution control - max_steps: int = Field(default=10, description="终止前的最大步长") - current_step: int = Field(default=0, description="执行中的当前步骤") - - duplicate_threshold: int = 2 - - user_prompt: str = r""" - 当前步骤:{step} 工具输出结果:{result} - 请总结当前正在执行的步骤和对应的工具输出结果,内容包括当前步骤是多少,执行的工具是什么,输出是什么。 - 最终以报告的形式展示。 - 如果工具输出结果中执行的工具为terminate,请按照状态输出本次交互过程最终结果并完成对整个报告的总结,不需要输出你的分析过程。 - """ - """用户提示词""" - - class Config: - arbitrary_types_allowed = True - extra = "allow" # Allow extra fields for flexibility in subclasses - - @model_validator(mode="after") - def initialize_agent(self) -> "BaseAgent": - """初始化Agent""" - if self.llm is None or not isinstance(self.llm, ReasoningLLM): - self.llm = ReasoningLLM() - if not isinstance(self.memory, Memory): - self.memory = Memory() - return self - - @asynccontextmanager - async def state_context(self, new_state: AgentState): - """ - Agent状态转换上下文管理器 - - Args: - new_state: 要转变的状态 - - :return: None - :raise ValueError: 如果new_state无效 - """ - if not isinstance(new_state, AgentState): - raise ValueError(f"无效状态: {new_state}") - - previous_state = self.state - self.state = new_state - try: - yield - except Exception as e: - self.state = AgentState.ERROR # Transition to ERROR on failure - raise e - finally: - self.state = previous_state # Revert to previous state - - def update_memory( - self, - role: Role, - content: str, - **kwargs, - ) -> None: - """添加信息到Agent的memory中""" - message_map = { - "user": Message.user_message, - "system": Message.system_message, - "assistant": Message.assistant_message, - "tool": lambda content, **kw: Message.tool_message(content, **kw), - } - - if role not in message_map: - raise ValueError(f"不支持的消息角色: {role}") - - # Create message with appropriate parameters based on role - kwargs = {**(kwargs if role == "tool" else {})} - self.memory.add_message(message_map[role](content, **kwargs)) - - async def run(self, request: str | None = None) -> str: - """异步执行Agent的主循环""" - self.task.runtime.question = request - if self.state != AgentState.IDLE: - raise RuntimeError(f"无法从以下状态运行智能体: {self.state}") - - if request: - self.update_memory("user", request) - - results: list[str] = [] - async with self.state_context(AgentState.RUNNING): - while ( - self.current_step < self.max_steps and self.state != AgentState.FINISHED - ): - if not await Activity.is_active(self.task.ids.user_sub): - logger.info("用户终止会话,任务停止!") - return "" - self.current_step += 1 - logger.info(f"执行步骤{self.current_step}/{self.max_steps}") - step_result = await self.step() - - # Check for stuck state - if self.is_stuck(): - self.handle_stuck_state() - result = f"Step {self.current_step}: {step_result}" - results.append(result) - - if self.current_step >= self.max_steps: - self.current_step = 0 - self.state = AgentState.IDLE - result = f"任务终止: 已达到最大步数 ({self.max_steps})" - await self.msg_queue.push_output( - self.task, - event_type="text.add", - data={"text": result}, # type: ignore[arg-type] - ) - results.append(result) - return "\n".join(results) if results else "未执行任何步骤" - - @abstractmethod - async def step(self) -> str: - """ - 执行代理工作流程中的单个步骤。 - - 必须由子类实现,以定义具体的行为。 - """ - - def handle_stuck_state(self): - """通过添加更改策略的提示来处理卡住状态""" - stuck_prompt = "\ - 观察到重复响应。考虑新策略,避免重复已经尝试过的无效路径" - self.next_step_prompt = f"{stuck_prompt}\n{self.next_step_prompt}" - logger.warning(f"检测到智能体处于卡住状态。新增提示:{stuck_prompt}") - - def is_stuck(self) -> bool: - """通过检测重复内容来检查代理是否卡在循环中""" - if len(self.memory.messages) < 2: - return False - - last_message = self.memory.messages[-1] - if not last_message.content: - return False - - duplicate_count = sum( - 1 - for msg in reversed(self.memory.messages[:-1]) - if msg.role == "assistant" and msg.content == last_message.content - ) - - return duplicate_count >= self.duplicate_threshold - - @property - def messages(self) -> list[Message]: - """从Agent memory中检索消息列表""" - return self.memory.messages - - @messages.setter - def messages(self, value: list[Message]) -> None: - """设置Agent memory的消息列表""" - self.memory.messages = value diff --git a/apps/scheduler/mcp_agent/agent/mcp.py b/apps/scheduler/mcp_agent/agent/mcp.py deleted file mode 100644 index 378da368a..000000000 --- a/apps/scheduler/mcp_agent/agent/mcp.py +++ /dev/null @@ -1,81 +0,0 @@ -"""MCP Agent""" -import logging - -from pydantic import Field - -from apps.scheduler.mcp.host import MCPHost -from apps.scheduler.mcp_agent.agent.toolcall import ToolCallAgent -from apps.scheduler.mcp_agent.tool import Terminate, ToolCollection - -logger = logging.getLogger(__name__) - - -class MCPAgent(ToolCallAgent): - """ - 用于与MCP(模型上下文协议)服务器交互。 - - 使用SSE或stdio传输连接到MCP服务器 - 并使服务器的工具 - """ - - name: str = "MCPAgent" - description: str = "一个多功能的智能体,能够使用多种工具(包括基于MCP的工具)解决各种任务" - - # Add general-purpose tools to the tool collection - available_tools: ToolCollection = Field( - default_factory=lambda: ToolCollection( - Terminate(), - ), - ) - - special_tool_names: list[str] = Field(default_factory=lambda: [Terminate().name]) - - _initialized: bool = False - - @classmethod - async def create(cls, **kwargs) -> "MCPAgent": # noqa: ANN003 - """创建并初始化MCP Agent实例""" - instance = cls(**kwargs) - await instance.initialize_mcp_servers() - instance._initialized = True - return instance - - async def initialize_mcp_servers(self) -> None: - """初始化与已配置的MCP服务器的连接""" - mcp_host = MCPHost( - self.task.ids.user_sub, - self.task.id, - self.agent_id, - self.description, - ) - mcps = {} - for mcp_id in self.servers_id: - client = await mcp_host.get_client(mcp_id) - if client: - mcps[mcp_id] = client - - for mcp_id, mcp_client in mcps.items(): - new_tools = [] - for tool in mcp_client.tools: - original_name = tool.name - # Always prefix with server_id to ensure uniqueness - tool_name = f"mcp_{mcp_id}_{original_name}" - - server_tool = MCPClientTool( - name=tool_name, - description=tool.description, - parameters=tool.inputSchema, - session=mcp_client.session, - server_id=mcp_id, - original_name=original_name, - ) - new_tools.append(server_tool) - self.available_tools.add_tools(*new_tools) - - async def think(self) -> bool: - """使用适当的上下文处理当前状态并决定下一步操作""" - if not self._initialized: - await self.initialize_mcp_servers() - self._initialized = True - - return await super().think() diff --git a/apps/scheduler/mcp_agent/agent/react.py b/apps/scheduler/mcp_agent/agent/react.py deleted file mode 100644 index b56efd8b1..000000000 --- a/apps/scheduler/mcp_agent/agent/react.py +++ /dev/null @@ -1,35 +0,0 @@ -from abc import ABC, abstractmethod - -from pydantic import Field - -from apps.schemas.enum_var import AgentState -from apps.llm.reasoning import ReasoningLLM -from apps.scheduler.mcp_agent.agent.base import BaseAgent -from apps.scheduler.mcp_agent.schema import Memory - - -class ReActAgent(BaseAgent, ABC): - name: str - description: str | None = None - - system_prompt: str | None = None - next_step_prompt: str | None = None - - llm: ReasoningLLM | None = Field(default_factory=ReasoningLLM) - memory: Memory = Field(default_factory=Memory) - state: AgentState = AgentState.IDLE - - @abstractmethod - async def think(self) -> bool: - """处理当前状态并决定下一步操作""" - - @abstractmethod - async def act(self) -> str: - """执行已决定的行动""" - - async def step(self) -> str: - """执行一个步骤:思考和行动""" - should_act = await self.think() - if not should_act: - return "思考完成-无需采取任何行动" - return await self.act() diff --git a/apps/scheduler/mcp_agent/agent/toolcall.py b/apps/scheduler/mcp_agent/agent/toolcall.py deleted file mode 100644 index 1e22099ce..000000000 --- a/apps/scheduler/mcp_agent/agent/toolcall.py +++ /dev/null @@ -1,238 +0,0 @@ -import asyncio -import json -import logging -from typing import Any, Optional - -from pydantic import Field - -from apps.schemas.enum_var import AgentState -from apps.llm.function import JsonGenerator -from apps.llm.patterns import Select -from apps.scheduler.mcp_agent.agent.react import ReActAgent -from apps.scheduler.mcp_agent.schema import Function, Message, ToolCall -from apps.scheduler.mcp_agent.tool import Terminate, ToolCollection - -logger = logging.getLogger(__name__) - - -class ToolCallAgent(ReActAgent): - """用于处理工具/函数调用的基本Agent类""" - - name: str = "toolcall" - description: str = "可以执行工具调用的智能体" - - available_tools: ToolCollection = ToolCollection( - Terminate(), - ) - tool_choices: str = "auto" - special_tool_names: list[str] = Field(default_factory=lambda: [Terminate().name]) - - tool_calls: list[ToolCall] = Field(default_factory=list) - _current_base64_image: str | None = None - - max_observe: int | bool | None = None - - async def think(self) -> bool: - """使用工具处理当前状态并决定下一步行动""" - messages = [] - for message in self.messages: - if isinstance(message, Message): - message = message.to_dict() - messages.append(message) - try: - # 通过工具获得响应 - select_obj = Select() - choices = [] - for available_tool in self.available_tools.to_params(): - choices.append(available_tool.get("function")) - - tool = await select_obj.generate(question=self.question, choices=choices) - if tool in self.available_tools.tool_map: - schema = self.available_tools.tool_map[tool].parameters - json_generator = JsonGenerator( - query="根据跟定的信息,获取工具参数", - conversation=messages, - schema=schema, - ) # JsonGenerator - parameters = await json_generator.generate() - - else: - raise ValueError(f"尝试调用不存在的工具: {tool}") - except Exception as e: - raise - self.tool_calls = tool_calls = [ToolCall(id=tool, function=Function(name=tool, arguments=parameters))] - content = f"选择的执行工具为:{tool}, 参数为{parameters}" - - logger.info( - f"{self.name} 选择 {len(tool_calls) if tool_calls else 0}个工具执行" - ) - if tool_calls: - logger.info( - f"准备使用的工具: {[call.function.name for call in tool_calls]}" - ) - logger.info(f"工具参数: {tool_calls[0].function.arguments}") - - try: - - assistant_msg = ( - Message.from_tool_calls(content=content, tool_calls=self.tool_calls) - if self.tool_calls - else Message.assistant_message(content) - ) - self.memory.add_message(assistant_msg) - - if not self.tool_calls: - return bool(content) - - return bool(self.tool_calls) - except Exception as e: - logger.error(f"{self.name}的思考过程遇到了问题:: {e}") - self.memory.add_message( - Message.assistant_message( - f"处理时遇到错误: {str(e)}" - ) - ) - return False - - async def act(self) -> str: - """执行工具调用并处理其结果""" - if not self.tool_calls: - # 如果没有工具调用,则返回最后的消息内容 - return self.messages[-1].content or "没有要执行的内容或命令" - - results = [] - for command in self.tool_calls: - await self.msg_queue.push_output( - self.task, - event_type="text.add", - data={"text": f"正在执行工具{command.function.name}"} - ) - - self._current_base64_image = None - - result = await self.execute_tool(command) - - if self.max_observe: - result = result[: self.max_observe] - - push_result = "" - async for chunk in self.llm.call( - messages=[{"role": "system", "content": "You are a helpful asistant."}, - {"role": "user", "content": self.user_prompt.format( - step=self.current_step, - result=result, - )}, ], streaming=False - ): - push_result += chunk - self.task.tokens.input_tokens += self.llm.input_tokens - self.task.tokens.output_tokens += self.llm.output_tokens - await self.msg_queue.push_output( - self.task, - event_type="text.add", - data={"text": push_result}, # type: ignore[arg-type] - ) - - await self.msg_queue.push_output( - self.task, - event_type="text.add", - data={"text": f"工具{command.function.name}执行完成"}, # type: ignore[arg-type] - ) - - logger.info( - f"工具'{command.function.name}'执行完成! 执行结果为: {result}" - ) - - # 将工具响应添加到内存 - tool_msg = Message.tool_message( - content=result, - tool_call_id=command.id, - name=command.function.name, - ) - self.memory.add_message(tool_msg) - results.append(result) - self.question += ( - f"\n已执行工具{command.function.name}, " - f"作用是{self.available_tools.tool_map[command.function.name].description},结果为{result}" - ) - - return "\n\n".join(results) - - async def execute_tool(self, command: ToolCall) -> str: - """执行单个工具调用""" - if not command or not command.function or not command.function.name: - return "错误:无效的命令格式" - - name = command.function.name - if name not in self.available_tools.tool_map: - return f"错误:未知工具 '{name}'" - - try: - # 解析参数 - args = command.function.arguments - # 执行工具 - logger.info(f"激活工具:'{name}'...") - result = await self.available_tools.execute(name=name, tool_input=args) - - # 执行特殊工具 - await self._handle_special_tool(name=name, result=result) - - # 格式化结果 - observation = ( - f"观察到执行的工具 `{name}`的输出:\n{str(result)}" - if result - else f"工具 `{name}` 已完成,无输出" - ) - - return observation - except json.JSONDecodeError: - error_msg = f"解析{name}的参数时出错:JSON格式无效" - logger.error( - f"{name}”的参数没有意义-无效的JSON,参数:{command.function.arguments}" - ) - return f"错误: {error_msg}" - except Exception as e: - error_msg = f"工具 '{name}' 遇到问题: {str(e)}" - logger.exception(error_msg) - return f"错误: {error_msg}" - - async def _handle_special_tool(self, name: str, result: Any, **kwargs): - """处理特殊工具的执行和状态变化""" - if not self._is_special_tool(name): - return - - if self._should_finish_execution(name=name, result=result, **kwargs): - # 将智能体状态设为finished - logger.info(f"特殊工具'{name}'已完成任务!") - self.state = AgentState.FINISHED - - @staticmethod - def _should_finish_execution(**kwargs) -> bool: - """确定工具执行是否应完成""" - return True - - def _is_special_tool(self, name: str) -> bool: - """检查工具名称是否在特殊工具列表中""" - return name.lower() in [n.lower() for n in self.special_tool_names] - - async def cleanup(self): - """清理Agent工具使用的资源。""" - logger.info(f"正在清理智能体的资源'{self.name}'...") - for tool_name, tool_instance in self.available_tools.tool_map.items(): - if hasattr(tool_instance, "cleanup") and asyncio.iscoroutinefunction( - tool_instance.cleanup - ): - try: - logger.debug(f"清理工具: {tool_name}") - await tool_instance.cleanup() - except Exception as e: - logger.error( - f"清理工具时发生错误'{tool_name}': {e}", exc_info=True - ) - logger.info(f"智能体清理完成'{self.name}'.") - - async def run(self, request: Optional[str] = None) -> str: - """运行Agent""" - try: - return await super().run(request) - finally: - await self.cleanup() diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py new file mode 100644 index 000000000..acdd4871e --- /dev/null +++ b/apps/scheduler/mcp_agent/host.py @@ -0,0 +1,190 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""MCP宿主""" + +import json +import logging +from typing import Any + +from jinja2 import BaseLoader +from jinja2.sandbox import SandboxedEnvironment +from mcp.types import TextContent + +from apps.common.mongo import MongoDB +from apps.llm.function import JsonGenerator +from apps.scheduler.mcp.prompt import MEMORY_TEMPLATE +from apps.scheduler.pool.mcp.client import MCPClient +from apps.scheduler.pool.mcp.pool import MCPPool +from apps.schemas.enum_var import StepStatus +from apps.schemas.mcp import MCPPlanItem, MCPTool +from apps.schemas.task import FlowStepHistory +from apps.services.task import TaskManager + +logger = logging.getLogger(__name__) + + +class MCPHost: + """MCP宿主服务""" + + def __init__(self, user_sub: str, task_id: str, runtime_id: str, runtime_name: str) -> None: + """初始化MCP宿主""" + self._user_sub = user_sub + self._task_id = task_id + # 注意:runtime在工作流中是flow_id和step_description,在Agent中可为标识Agent的id和description + self._runtime_id = runtime_id + self._runtime_name = runtime_name + self._context_list = [] + self._env = SandboxedEnvironment( + loader=BaseLoader(), + autoescape=False, + trim_blocks=True, + lstrip_blocks=True, + ) + + async def get_client(self, mcp_id: str) -> MCPClient | None: + """获取MCP客户端""" + mongo = MongoDB() + mcp_collection = mongo.get_collection("mcp") + + # 检查用户是否启用了这个mcp + mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": self._user_sub}) + if not mcp_db_result: + logger.warning("用户 %s 未启用MCP %s", self._user_sub, mcp_id) + return None + + # 获取MCP配置 + try: + return await MCPPool().get(mcp_id, self._user_sub) + except KeyError: + logger.warning("用户 %s 的MCP %s 没有运行中的实例,请检查环境", self._user_sub, mcp_id) + return None + + async def assemble_memory(self) -> str: + """组装记忆""" + task = await TaskManager.get_task_by_task_id(self._task_id) + if not task: + logger.error("任务 %s 不存在", self._task_id) + return "" + + context_list = [] + for ctx_id in self._context_list: + context = next((ctx for ctx in task.context if ctx["_id"] == ctx_id), None) + if not context: + continue + context_list.append(context) + + return self._env.from_string(MEMORY_TEMPLATE).render( + context_list=context_list, + ) + + async def _save_memory( + self, + tool: MCPTool, + plan_item: MCPPlanItem, + input_data: dict[str, Any], + result: str, + ) -> dict[str, Any]: + """保存记忆""" + try: + output_data = json.loads(result) + except Exception: # noqa: BLE001 + logger.warning("[MCPHost] 得到的数据不是dict格式!尝试转换为str") + output_data = { + "message": result, + } + + if not isinstance(output_data, dict): + output_data = { + "message": result, + } + + # 创建context;注意用法 + context = FlowStepHistory( + task_id=self._task_id, + flow_id=self._runtime_id, + flow_name=self._runtime_name, + flow_status=StepStatus.SUCCESS, + step_id=tool.name, + step_name=tool.name, + # description是规划的实际内容 + step_description=plan_item.content, + step_status=StepStatus.SUCCESS, + input_data=input_data, + output_data=output_data, + ) + + # 保存到task + task = await TaskManager.get_task_by_task_id(self._task_id) + if not task: + logger.error("任务 %s 不存在", self._task_id) + return {} + self._context_list.append(context.id) + task.context.append(context.model_dump(by_alias=True, exclude_none=True)) + await TaskManager.save_task(self._task_id, task) + + return output_data + + async def _fill_params(self, tool: MCPTool, query: str) -> dict[str, Any]: + """填充工具参数""" + # 更清晰的输入·指令,这样可以调用generate + llm_query = rf""" + 请使用参数生成工具,生成满足以下目标的工具参数: + + {query} + """ + + # 进行生成 + json_generator = JsonGenerator( + llm_query, + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": await self.assemble_memory()}, + ], + tool.input_schema, + ) + return await json_generator.generate() + + async def call_tool(self, tool: MCPTool, plan_item: MCPPlanItem) -> list[dict[str, Any]]: + """调用工具""" + # 拿到Client + client = await MCPPool().get(tool.mcp_id, self._user_sub) + if client is None: + err = f"[MCPHost] MCP Server不合法: {tool.mcp_id}" + logger.error(err) + raise ValueError(err) + + # 填充参数 + params = await self._fill_params(tool, plan_item.instruction) + # 调用工具 + result = await client.call_tool(tool.name, params) + # 保存记忆 + processed_result = [] + for item in result.content: + if not isinstance(item, TextContent): + logger.error("MCP结果类型不支持: %s", item) + continue + processed_result.append(await self._save_memory(tool, plan_item, params, item.text)) + + return processed_result + + async def get_tool_list(self, mcp_id_list: list[str]) -> list[MCPTool]: + """获取工具列表""" + mongo = MongoDB() + mcp_collection = mongo.get_collection("mcp") + + # 获取工具列表 + tool_list = [] + for mcp_id in mcp_id_list: + # 检查用户是否启用了这个mcp + mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": self._user_sub}) + if not mcp_db_result: + logger.warning("用户 %s 未启用MCP %s", self._user_sub, mcp_id) + continue + # 获取MCP工具配置 + try: + for tool in mcp_db_result["tools"]: + tool_list.extend([MCPTool.model_validate(tool)]) + except KeyError: + logger.warning("用户 %s 的MCP Tool %s 配置错误", self._user_sub, mcp_id) + continue + + return tool_list diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py new file mode 100644 index 000000000..cd4f5975e --- /dev/null +++ b/apps/scheduler/mcp_agent/plan.py @@ -0,0 +1,110 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""MCP 用户目标拆解与规划""" + +from jinja2 import BaseLoader +from jinja2.sandbox import SandboxedEnvironment + +from apps.llm.function import JsonGenerator +from apps.llm.reasoning import ReasoningLLM +from apps.scheduler.mcp.prompt import CREATE_PLAN, FINAL_ANSWER +from apps.schemas.mcp import MCPPlan, MCPTool + + +class MCPPlanner: + """MCP 用户目标拆解与规划""" + + def __init__(self, user_goal: str) -> None: + """初始化MCP规划器""" + self.user_goal = user_goal + self._env = SandboxedEnvironment( + loader=BaseLoader, + autoescape=True, + trim_blocks=True, + lstrip_blocks=True, + ) + self.input_tokens = 0 + self.output_tokens = 0 + + + async def create_plan(self, tool_list: list[MCPTool], max_steps: int = 6) -> MCPPlan: + """规划下一步的执行流程,并输出""" + # 获取推理结果 + result = await self._get_reasoning_plan(tool_list, max_steps) + + # 解析为结构化数据 + return await self._parse_plan_result(result, max_steps) + + + async def _get_reasoning_plan(self, tool_list: list[MCPTool], max_steps: int) -> str: + """获取推理大模型的结果""" + # 格式化Prompt + template = self._env.from_string(CREATE_PLAN) + prompt = template.render( + goal=self.user_goal, + tools=tool_list, + max_num=max_steps, + ) + + # 调用推理大模型 + message = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ] + reasoning_llm = ReasoningLLM() + result = "" + async for chunk in reasoning_llm.call( + message, + streaming=False, + temperature=0.07, + result_only=True, + ): + result += chunk + + # 保存token用量 + self.input_tokens = reasoning_llm.input_tokens + self.output_tokens = reasoning_llm.output_tokens + + return result + + + async def _parse_plan_result(self, result: str, max_steps: int) -> MCPPlan: + """将推理结果解析为结构化数据""" + # 格式化Prompt + schema = MCPPlan.model_json_schema() + schema["properties"]["plans"]["maxItems"] = max_steps + + # 使用Function模型解析结果 + json_generator = JsonGenerator( + result, + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": result}, + ], + schema, + ) + plan = await json_generator.generate() + return MCPPlan.model_validate(plan) + + + async def generate_answer(self, plan: MCPPlan, memory: str) -> str: + """生成最终回答""" + template = self._env.from_string(FINAL_ANSWER) + prompt = template.render( + plan=plan, + memory=memory, + goal=self.user_goal, + ) + + llm = ReasoningLLM() + result = "" + async for chunk in llm.call( + [{"role": "user", "content": prompt}], + streaming=False, + temperature=0.07, + ): + result += chunk + + self.input_tokens = llm.input_tokens + self.output_tokens = llm.output_tokens + + return result diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py new file mode 100644 index 000000000..b322fb088 --- /dev/null +++ b/apps/scheduler/mcp_agent/prompt.py @@ -0,0 +1,240 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""MCP相关的大模型Prompt""" + +from textwrap import dedent + +MCP_SELECT = dedent(r""" + 你是一个乐于助人的智能助手。 + 你的任务是:根据当前目标,选择最合适的MCP Server。 + + ## 选择MCP Server时的注意事项: + + 1. 确保充分理解当前目标,选择最合适的MCP Server。 + 2. 请在给定的MCP Server列表中选择,不要自己生成MCP Server。 + 3. 请先给出你选择的理由,再给出你的选择。 + 4. 当前目标将在下面给出,MCP Server列表也会在下面给出。 + 请将你的思考过程放在"思考过程"部分,将你的选择放在"选择结果"部分。 + 5. 选择必须是JSON格式,严格按照下面的模板,不要输出任何其他内容: + + ```json + { + "mcp": "你选择的MCP Server的名称" + } + ``` + + 6. 下面的示例仅供参考,不要将示例中的内容作为选择MCP Server的依据。 + + ## 示例 + + ### 目标 + + 我需要一个MCP Server来完成一个任务。 + + ### MCP Server列表 + + - **mcp_1**: "MCP Server 1";MCP Server 1的描述 + - **mcp_2**: "MCP Server 2";MCP Server 2的描述 + + ### 请一步一步思考: + + 因为当前目标需要一个MCP Server来完成一个任务,所以选择mcp_1。 + + ### 选择结果 + + ```json + { + "mcp": "mcp_1" + } + ``` + + ## 现在开始! + + ### 目标 + + {{goal}} + + ### MCP Server列表 + + {% for mcp in mcp_list %} + - **{{mcp.id}}**: "{{mcp.name}}";{{mcp.description}} + {% endfor %} + + ### 请一步一步思考: + +""") +CREATE_PLAN = dedent(r""" + 你是一个计划生成器。 + 请分析用户的目标,并生成一个计划。你后续将根据这个计划,一步一步地完成用户的目标。 + + # 一个好的计划应该: + + 1. 能够成功完成用户的目标 + 2. 计划中的每一个步骤必须且只能使用一个工具。 + 3. 计划中的步骤必须具有清晰和逻辑的步骤,没有冗余或不必要的步骤。 + 4. 计划中的最后一步必须是Final工具,以确保计划执行结束。 + + # 生成计划时的注意事项: + + - 每一条计划包含3个部分: + - 计划内容:描述单个计划步骤的大致内容 + - 工具ID:必须从下文的工具列表中选择 + - 工具指令:改写用户的目标,使其更符合工具的输入要求 + - 必须按照如下格式生成计划,不要输出任何额外数据: + + ```json + { + "plans": [ + { + "content": "计划内容", + "tool": "工具ID", + "instruction": "工具指令" + } + ] + } + ``` + + - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。\ +思考过程应放置在 XML标签中。 + - 计划内容中,可以使用"Result[]"来引用之前计划步骤的结果。例如:"Result[3]"表示引用第三条计划执行后的结果。 + - 计划不得多于{{ max_num }}条,且每条计划内容应少于150字。 + + # 工具 + + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + + + {% for tool in tools %} + - {{ tool.id }}{{tool.name}};{{ tool.description }} + {% endfor %} + - Final结束步骤,当执行到这一步时,\ +表示计划执行结束,所得到的结果将作为最终结果。 + + + # 样例 + + ## 目标 + + 在后台运行一个新的alpine:latest容器,将主机/root文件夹挂载至/data,并执行top命令。 + + ## 计划 + + + 1. 这个目标需要使用Docker来完成,首先需要选择合适的MCP Server + 2. 目标可以拆解为以下几个部分: + - 运行alpine:latest容器 + - 挂载主机目录 + - 在后台运行 + - 执行top命令 + 3. 需要先选择MCP Server,然后生成Docker命令,最后执行命令 + + + ```json + { + "plans": [ + { + "content": "选择一个支持Docker的MCP Server", + "tool": "mcp_selector", + "instruction": "需要一个支持Docker容器运行的MCP Server" + }, + { + "content": "使用Result[0]中选择的MCP Server,生成Docker命令", + "tool": "command_generator", + "instruction": "生成Docker命令:在后台运行alpine:latest容器,挂载/root到/data,执行top命令" + }, + { + "content": "在Result[0]的MCP Server上执行Result[1]生成的命令", + "tool": "command_executor", + "instruction": "执行Docker命令" + }, + { + "content": "任务执行完成,容器已在后台运行,结果为Result[2]", + "tool": "Final", + "instruction": "" + } + ] + } + ``` + + # 现在开始生成计划: + + ## 目标 + + {{goal}} + + # 计划 +""") +EVALUATE_PLAN = dedent(r""" + 你是一个计划评估器。 + 请根据给定的计划,和当前计划执行的实际情况,分析当前计划是否合理和完整,并生成改进后的计划。 + + # 一个好的计划应该: + + 1. 能够成功完成用户的目标 + 2. 计划中的每一个步骤必须且只能使用一个工具。 + 3. 计划中的步骤必须具有清晰和逻辑的步骤,没有冗余或不必要的步骤。 + 4. 计划中的最后一步必须是Final工具,以确保计划执行结束。 + + # 你此前的计划是: + + {{ plan }} + + # 这个计划的执行情况是: + + 计划的执行情况将放置在 XML标签中。 + + + {{ memory }} + + + # 进行评估时的注意事项: + + - 请一步一步思考,解析用户的目标,并指导你接下来的生成。思考过程应放置在 XML标签中。 + - 评估结果分为两个部分: + - 计划评估的结论 + - 改进后的计划 + - 请按照以下JSON格式输出评估结果: + + ```json + { + "evaluation": "评估结果", + "plans": [ + { + "content": "改进后的计划内容", + "tool": "工具ID", + "instruction": "工具指令" + } + ] + } + ``` + + # 现在开始评估计划: + +""") +FINAL_ANSWER = dedent(r""" + 综合理解计划执行结果和背景信息,向用户报告目标的完成情况。 + + # 用户目标 + + {{ goal }} + + # 计划执行情况 + + 为了完成上述目标,你实施了以下计划: + + {{ memory }} + + # 其他背景信息: + + {{ status }} + + # 现在,请根据以上信息,向用户报告目标的完成情况: + +""") +MEMORY_TEMPLATE = dedent(r""" + {% for ctx in context_list %} + - 第{{ loop.index }}步:{{ ctx.step_description }} + - 调用工具 `{{ ctx.step_id }}`,并提供参数 `{{ ctx.input_data }}` + - 执行状态:{{ ctx.status }} + - 得到数据:`{{ ctx.output_data }}` + {% endfor %} +""") diff --git a/apps/scheduler/mcp_agent/schema.py b/apps/scheduler/mcp_agent/schema.py deleted file mode 100644 index 614139074..000000000 --- a/apps/scheduler/mcp_agent/schema.py +++ /dev/null @@ -1,148 +0,0 @@ -"""MCP Agent执行数据结构""" -from typing import Any, Self - -from pydantic import BaseModel, Field - -from apps.schemas.enum_var import Role - - -class Function(BaseModel): - """工具函数""" - - name: str - arguments: dict[str, Any] - - -class ToolCall(BaseModel): - """Represents a tool/function call in a message""" - - id: str - type: str = "function" - function: Function - - -class Message(BaseModel): - """Represents a chat message in the conversation""" - - role: Role = Field(...) - content: str | None = Field(default=None) - tool_calls: list[ToolCall] | None = Field(default=None) - name: str | None = Field(default=None) - tool_call_id: str | None = Field(default=None) - - def __add__(self, other) -> list["Message"]: - """支持 Message + list 或 Message + Message 的操作""" - if isinstance(other, list): - return [self] + other - elif isinstance(other, Message): - return [self, other] - else: - raise TypeError( - f"unsupported operand type(s) for +: '{type(self).__name__}' and '{type(other).__name__}'" - ) - - def __radd__(self, other) -> list["Message"]: - """支持 list + Message 的操作""" - if isinstance(other, list): - return other + [self] - else: - raise TypeError( - f"unsupported operand type(s) for +: '{type(other).__name__}' and '{type(self).__name__}'" - ) - - def to_dict(self) -> dict: - """Convert message to dictionary format""" - message = {"role": self.role} - if self.content is not None: - message["content"] = self.content - if self.tool_calls is not None: - message["tool_calls"] = [tool_call.dict() for tool_call in self.tool_calls] - if self.name is not None: - message["name"] = self.name - if self.tool_call_id is not None: - message["tool_call_id"] = self.tool_call_id - return message - - @classmethod - def user_message(cls, content: str) -> Self: - """Create a user message""" - return cls(role=Role.USER, content=content) - - @classmethod - def system_message(cls, content: str) -> Self: - """Create a system message""" - return cls(role=Role.SYSTEM, content=content) - - @classmethod - def assistant_message( - cls, content: str | None = None, - ) -> Self: - """Create an assistant message""" - return cls(role=Role.ASSISTANT, content=content) - - @classmethod - def tool_message( - cls, content: str, name: str, tool_call_id: str, - ) -> Self: - """Create a tool message""" - return cls( - role=Role.TOOL, - content=content, - name=name, - tool_call_id=tool_call_id, - ) - - @classmethod - def from_tool_calls( - cls, - tool_calls: list[Any], - content: str | list[str] = "", - **kwargs, # noqa: ANN003 - ) -> Self: - """Create ToolCallsMessage from raw tool calls. - - Args: - tool_calls: Raw tool calls from LLM - content: Optional message content - """ - formatted_calls = [ - {"id": call.id, "function": call.function.model_dump(), "type": "function"} - for call in tool_calls - ] - return cls( - role=Role.ASSISTANT, - content=content, - tool_calls=formatted_calls, - **kwargs, - ) - - -class Memory(BaseModel): - messages: list[Message] = Field(default_factory=list) - max_messages: int = Field(default=100) - - def add_message(self, message: Message) -> None: - """Add a message to memory""" - self.messages.append(message) - # Optional: Implement message limit - if len(self.messages) > self.max_messages: - self.messages = self.messages[-self.max_messages:] - - def add_messages(self, messages: list[Message]) -> None: - """Add multiple messages to memory""" - self.messages.extend(messages) - # Optional: Implement message limit - if len(self.messages) > self.max_messages: - self.messages = self.messages[-self.max_messages:] - - def clear(self) -> None: - """Clear all messages""" - self.messages.clear() - - def get_recent_messages(self, n: int) -> list[Message]: - """Get n most recent messages""" - return self.messages[-n:] - - def to_dict_list(self) -> list[dict]: - """Convert messages to list of dicts""" - return [msg.to_dict() for msg in self.messages] diff --git a/apps/scheduler/mcp_agent/select.py b/apps/scheduler/mcp_agent/select.py new file mode 100644 index 000000000..2ff503447 --- /dev/null +++ b/apps/scheduler/mcp_agent/select.py @@ -0,0 +1,185 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""选择MCP Server及其工具""" + +import logging + +from jinja2 import BaseLoader +from jinja2.sandbox import SandboxedEnvironment + +from apps.common.lance import LanceDB +from apps.common.mongo import MongoDB +from apps.llm.embedding import Embedding +from apps.llm.function import FunctionLLM +from apps.llm.reasoning import ReasoningLLM +from apps.scheduler.mcp.prompt import ( + MCP_SELECT, +) +from apps.schemas.mcp import ( + MCPCollection, + MCPSelectResult, + MCPTool, +) + +logger = logging.getLogger(__name__) + + +class MCPSelector: + """MCP选择器""" + + def __init__(self) -> None: + """初始化助手类""" + self.input_tokens = 0 + self.output_tokens = 0 + + @staticmethod + def _assemble_sql(mcp_list: list[str]) -> str: + """组装SQL""" + sql = "(" + for mcp_id in mcp_list: + sql += f"'{mcp_id}', " + return sql.rstrip(", ") + ")" + + + async def _get_top_mcp_by_embedding( + self, + query: str, + mcp_list: list[str], + ) -> list[dict[str, str]]: + """通过向量检索获取Top5 MCP Server""" + logger.info("[MCPHelper] 查询MCP Server向量: %s, %s", query, mcp_list) + mcp_table = await LanceDB().get_table("mcp") + query_embedding = await Embedding.get_embedding([query]) + mcp_vecs = await (await mcp_table.search( + query=query_embedding, + vector_column_name="embedding", + )).where(f"id IN {MCPSelector._assemble_sql(mcp_list)}").limit(5).to_list() + + # 拿到名称和description + logger.info("[MCPHelper] 查询MCP Server名称和描述: %s", mcp_vecs) + mcp_collection = MongoDB().get_collection("mcp") + llm_mcp_list: list[dict[str, str]] = [] + for mcp_vec in mcp_vecs: + mcp_id = mcp_vec["id"] + mcp_data = await mcp_collection.find_one({"_id": mcp_id}) + if not mcp_data: + logger.warning("[MCPHelper] 查询MCP Server名称和描述失败: %s", mcp_id) + continue + mcp_data = MCPCollection.model_validate(mcp_data) + llm_mcp_list.extend([{ + "id": mcp_id, + "name": mcp_data.name, + "description": mcp_data.description, + }]) + return llm_mcp_list + + + async def _get_mcp_by_llm( + self, + query: str, + mcp_list: list[dict[str, str]], + mcp_ids: list[str], + ) -> MCPSelectResult: + """通过LLM选择最合适的MCP Server""" + # 初始化jinja2环境 + env = SandboxedEnvironment( + loader=BaseLoader, + autoescape=True, + trim_blocks=True, + lstrip_blocks=True, + ) + template = env.from_string(MCP_SELECT) + # 渲染模板 + mcp_prompt = template.render( + mcp_list=mcp_list, + goal=query, + ) + + # 调用大模型进行推理 + result = await self._call_reasoning(mcp_prompt) + + # 使用小模型提取JSON + return await self._call_function_mcp(result, mcp_ids) + + + async def _call_reasoning(self, prompt: str) -> str: + """调用大模型进行推理""" + logger.info("[MCPHelper] 调用推理大模型") + llm = ReasoningLLM() + message = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ] + result = "" + async for chunk in llm.call(message): + result += chunk + self.input_tokens += llm.input_tokens + self.output_tokens += llm.output_tokens + return result + + + async def _call_function_mcp(self, reasoning_result: str, mcp_ids: list[str]) -> MCPSelectResult: + """调用结构化输出小模型提取JSON""" + logger.info("[MCPHelper] 调用结构化输出小模型") + llm = FunctionLLM() + message = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": reasoning_result}, + ] + schema = MCPSelectResult.model_json_schema() + # schema中加入选项 + schema["properties"]["mcp_id"]["enum"] = mcp_ids + result = await llm.call(messages=message, schema=schema) + try: + result = MCPSelectResult.model_validate(result) + except Exception: + logger.exception("[MCPHelper] 解析MCP Select Result失败") + raise + return result + + + async def select_top_mcp( + self, + query: str, + mcp_list: list[str], + ) -> MCPSelectResult: + """ + 选择最合适的MCP Server + + 先通过Embedding选择Top5,然后通过LLM选择Top 1 + """ + # 通过向量检索获取Top5 + llm_mcp_list = await self._get_top_mcp_by_embedding(query, mcp_list) + + # 通过LLM选择最合适的 + return await self._get_mcp_by_llm(query, llm_mcp_list, mcp_list) + + + @staticmethod + async def select_top_tool(query: str, mcp_list: list[str], top_n: int = 10) -> list[MCPTool]: + """选择最合适的工具""" + tool_vector = await LanceDB().get_table("mcp_tool") + query_embedding = await Embedding.get_embedding([query]) + tool_vecs = await (await tool_vector.search( + query=query_embedding, + vector_column_name="embedding", + )).where(f"mcp_id IN {MCPSelector._assemble_sql(mcp_list)}").limit(top_n).to_list() + + # 拿到工具 + tool_collection = MongoDB().get_collection("mcp") + llm_tool_list = [] + + for tool_vec in tool_vecs: + # 到MongoDB里找对应的工具 + logger.info("[MCPHelper] 查询MCP Tool名称和描述: %s", tool_vec["mcp_id"]) + tool_data = await tool_collection.aggregate([ + {"$match": {"_id": tool_vec["mcp_id"]}}, + {"$unwind": "$tools"}, + {"$match": {"tools.id": tool_vec["id"]}}, + {"$project": {"_id": 0, "tools": 1}}, + {"$replaceRoot": {"newRoot": "$tools"}}, + ]) + async for tool in tool_data: + tool_obj = MCPTool.model_validate(tool) + llm_tool_list.append(tool_obj) + + return llm_tool_list diff --git a/apps/scheduler/mcp_agent/tool/__init__.py b/apps/scheduler/mcp_agent/tool/__init__.py deleted file mode 100644 index 4593f3174..000000000 --- a/apps/scheduler/mcp_agent/tool/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from apps.scheduler.mcp_agent.tool.base import BaseTool -from apps.scheduler.mcp_agent.tool.terminate import Terminate -from apps.scheduler.mcp_agent.tool.tool_collection import ToolCollection - -__all__ = [ - "BaseTool", - "Terminate", - "ToolCollection", -] diff --git a/apps/scheduler/mcp_agent/tool/base.py b/apps/scheduler/mcp_agent/tool/base.py deleted file mode 100644 index 04ad45c47..000000000 --- a/apps/scheduler/mcp_agent/tool/base.py +++ /dev/null @@ -1,73 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Dict, Optional - -from pydantic import BaseModel, Field - - -class BaseTool(ABC, BaseModel): - name: str - description: str - parameters: Optional[dict] = None - - class Config: - arbitrary_types_allowed = True - - async def __call__(self, **kwargs) -> Any: - return await self.execute(**kwargs) - - @abstractmethod - async def execute(self, **kwargs) -> Any: - """使用给定的参数执行工具""" - - def to_param(self) -> Dict: - """将工具转换为函数调用格式""" - return { - "type": "function", - "function": { - "name": self.name, - "description": self.description, - "parameters": self.parameters, - }, - } - - -class ToolResult(BaseModel): - """表示工具执行的结果""" - - output: Any = Field(default=None) - error: Optional[str] = Field(default=None) - system: Optional[str] = Field(default=None) - - class Config: - arbitrary_types_allowed = True - - def __bool__(self): - return any(getattr(self, field) for field in self.__fields__) - - def __add__(self, other: "ToolResult"): - def combine_fields( - field: Optional[str], other_field: Optional[str], concatenate: bool = True - ): - if field and other_field: - if concatenate: - return field + other_field - raise ValueError("Cannot combine tool results") - return field or other_field - - return ToolResult( - output=combine_fields(self.output, other.output), - error=combine_fields(self.error, other.error), - system=combine_fields(self.system, other.system), - ) - - def __str__(self): - return f"Error: {self.error}" if self.error else self.output - - def replace(self, **kwargs): - """返回一个新的ToolResult,其中替换了给定的字段""" - # return self.copy(update=kwargs) - return type(self)(**{**self.dict(), **kwargs}) - - -class ToolFailure(ToolResult): - """表示失败的ToolResult""" diff --git a/apps/scheduler/mcp_agent/tool/terminate.py b/apps/scheduler/mcp_agent/tool/terminate.py deleted file mode 100644 index 84aa12031..000000000 --- a/apps/scheduler/mcp_agent/tool/terminate.py +++ /dev/null @@ -1,25 +0,0 @@ -from apps.scheduler.mcp_agent.tool.base import BaseTool - - -_TERMINATE_DESCRIPTION = """当请求得到满足或助理无法继续处理任务时,终止交互。 -当您完成所有任务后,调用此工具结束工作。""" - - -class Terminate(BaseTool): - name: str = "terminate" - description: str = _TERMINATE_DESCRIPTION - parameters: dict = { - "type": "object", - "properties": { - "status": { - "type": "string", - "description": "交互的完成状态", - "enum": ["success", "failure"], - } - }, - "required": ["status"], - } - - async def execute(self, status: str) -> str: - """Finish the current execution""" - return f"交互已完成,状态为: {status}" diff --git a/apps/scheduler/mcp_agent/tool/tool_collection.py b/apps/scheduler/mcp_agent/tool/tool_collection.py deleted file mode 100644 index 95bda3178..000000000 --- a/apps/scheduler/mcp_agent/tool/tool_collection.py +++ /dev/null @@ -1,55 +0,0 @@ -"""用于管理多个工具的集合类""" -import logging -from typing import Any - -from apps.scheduler.mcp_agent.tool.base import BaseTool, ToolFailure, ToolResult - -logger = logging.getLogger(__name__) - - -class ToolCollection: - """定义工具的集合""" - - class Config: - arbitrary_types_allowed = True - - def __init__(self, *tools: BaseTool): - self.tools = tools - self.tool_map = {tool.name: tool for tool in tools} - - def __iter__(self): - return iter(self.tools) - - def to_params(self) -> list[dict[str, Any]]: - return [tool.to_param() for tool in self.tools] - - async def execute( - self, *, name: str, tool_input: dict[str, Any] = None - ) -> ToolResult: - tool = self.tool_map.get(name) - if not tool: - return ToolFailure(error=f"Tool {name} is invalid") - try: - result = await tool(**tool_input) - return result - except Exception as e: - return ToolFailure(error=f"Failed to execute tool {name}: {e}") - - def add_tool(self, tool: BaseTool): - """ - 将单个工具添加到集合中。 - - 如果已存在同名工具,则将跳过该工具并记录警告。 - """ - if tool.name in self.tool_map: - logger.warning(f"Tool {tool.name} already exists in collection, skipping") - return self - - self.tools += (tool,) - self.tool_map[tool.name] = tool - return self - - def add_tools(self, *tools: BaseTool): - for tool in tools: - self.add_tool(tool) - return self diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index 4c2c4cf0b..32331cf3d 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -214,7 +214,7 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: await AppCenterManager.update_recent_app(user_sub, post_body.app.app_id) # 若状态为成功,删除Task - if not task.state or task.state.status == StepStatus.SUCCESS: + if not task.state or task.state.flow_status == StepStatus.SUCCESS or task.state.flow_status == StepStatus.ERROR or task.state.flow_status == StepStatus.CANCELLED: await TaskManager.delete_task_by_task_id(task.id) else: # 更新Task diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py index a84dc3a35..20e9c0f9e 100644 --- a/apps/schemas/enum_var.py +++ b/apps/schemas/enum_var.py @@ -20,6 +20,17 @@ class StepStatus(str, Enum): SUCCESS = "success" ERROR = "error" PARAM = "param" + CANCELLED = "cancelled" + + +class FlowStatus(str, Enum): + """Flow状态""" + + WAITING = "waiting" + RUNNING = "running" + SUCCESS = "success" + ERROR = "error" + CANCELLED = "cancelled" class DocumentStatus(str, Enum): diff --git a/apps/schemas/record.py b/apps/schemas/record.py index b5e1b0c55..0c3d71850 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -44,6 +44,7 @@ class RecordFlow(BaseModel): id: str record_id: str = Field(alias="recordId") flow_id: str = Field(alias="flowId") + flow_status: StepStatus = Field(alias="flowStatus", default=StepStatus.SUCCESS) step_num: int = Field(alias="stepNum") steps: list[RecordFlowStep] diff --git a/apps/schemas/task.py b/apps/schemas/task.py index 37fdebbfb..98d8c6b3b 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -7,7 +7,7 @@ from typing import Any from pydantic import BaseModel, Field -from apps.schemas.enum_var import StepStatus +from apps.schemas.enum_var import FlowStatus, StepStatus from apps.schemas.flow import Step from apps.schemas.mcp import MCPPlan @@ -23,10 +23,11 @@ class FlowStepHistory(BaseModel): task_id: str = Field(description="任务ID") flow_id: str = Field(description="FlowID") flow_name: str = Field(description="Flow名称") + flow_status: FlowStatus = Field(description="Flow状态") step_id: str = Field(description="当前步骤名称") step_name: str = Field(description="当前步骤名称") - step_description: str = Field(description="当前步骤描述") - status: StepStatus = Field(description="当前步骤状态") + step_description: str = Field(description="当前步骤描述", default="") + step_status: StepStatus = Field(description="当前步骤状态") input_data: dict[str, Any] = Field(description="当前Step执行的输入", default={}) output_data: dict[str, Any] = Field(description="当前Step执行后的结果", default={}) created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) @@ -39,10 +40,11 @@ class ExecutorState(BaseModel): flow_id: str = Field(description="Flow ID") flow_name: str = Field(description="Flow名称") description: str = Field(description="Flow描述") - status: StepStatus = Field(description="Flow执行状态") - # 附加信息 + flow_status: FlowStatus = Field(description="Flow状态") + # 任务级数据 step_id: str = Field(description="当前步骤ID") step_name: str = Field(description="当前步骤名称") + step_status: StepStatus = Field(description="当前步骤状态") step_description: str = Field(description="当前步骤描述", default="") app_id: str = Field(description="应用ID") slot: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={}) diff --git a/tests/common/test_queue.py b/tests/common/test_queue.py index 5375180a3..db1f5ead7 100644 --- a/tests/common/test_queue.py +++ b/tests/common/test_queue.py @@ -74,8 +74,8 @@ async def test_push_output_with_flow(message_queue, mock_task): mock_task.state.flow_id = "flow_id" mock_task.state.step_id = "step_id" mock_task.state.step_name = "step_name" - mock_task.state.status = "running" - + mock_task.state.step_status = "running" + await message_queue.init("test_task") await message_queue.push_output(mock_task, EventType.TEXT_ADD, {}) -- Gitee