diff --git a/apps/common/queue.py b/apps/common/queue.py
index 5601c93aa3ef72011f2e2d4b126f782d83e85b78..6e82af915e2a70f743c629902c149b7ebc0f7a0e 100644
--- a/apps/common/queue.py
+++ b/apps/common/queue.py
@@ -56,9 +56,12 @@ class MessageQueue:
flow = MessageFlow(
appId=task.state.app_id,
flowId=task.state.flow_id,
+ flowName=task.state.flow_name,
+ flowStatus=task.state.flow_status,
stepId=task.state.step_id,
stepName=task.state.step_name,
- stepStatus=task.state.status,
+ stepDescription=task.state.step_description,
+ stepStatus=task.state.step_status
)
else:
flow = None
diff --git a/apps/constants.py b/apps/constants.py
index 58158b33c3aa4ea5aa3595f5642bf6444e7d76cb..20cb79b54ac8db5450fdbb296cb400b47a29adf0 100644
--- a/apps/constants.py
+++ b/apps/constants.py
@@ -11,7 +11,7 @@ from apps.common.config import Config
# 新对话默认标题
NEW_CHAT = "新对话"
# 滑动窗口限流 默认窗口期
-SLIDE_WINDOW_TIME = 60
+SLIDE_WINDOW_TIME = 15
# OIDC 访问Token 过期时间(分钟)
OIDC_ACCESS_TOKEN_EXPIRE_TIME = 30
# OIDC 刷新Token 过期时间(分钟)
diff --git a/apps/dependency/user.py b/apps/dependency/user.py
index fce67e51e00dd76b35bec2f7787dfe8039111589..8f5848a3d2961ef21b242d2eee931c623c6481c7 100644
--- a/apps/dependency/user.py
+++ b/apps/dependency/user.py
@@ -1,14 +1,16 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""用户鉴权"""
-
+import os
import logging
from fastapi import Depends
from fastapi.security import OAuth2PasswordBearer
+import secrets
from starlette import status
from starlette.exceptions import HTTPException
from starlette.requests import HTTPConnection
+from apps.common.config import Config
from apps.services.api_key import ApiKeyManager
from apps.services.session import SessionManager
@@ -48,6 +50,9 @@ async def get_session(request: HTTPConnection) -> str:
:param request: HTTP请求
:return: Session ID
"""
+ if Config().get_config().no_auth.enable:
+ # 如果启用了无认证访问,直接返回调试用户
+ return secrets.token_hex(16)
session_id = await _get_session_id_from_request(request)
if not session_id:
raise HTTPException(
@@ -69,6 +74,12 @@ async def get_user(request: HTTPConnection) -> str:
:param request: HTTP请求体
:return: 用户sub
"""
+ if Config().get_config().no_auth.enable:
+ # 如果启用了无认证访问,直接返回当前操作系统用户的名称
+ username = os.environ.get('USERNAME') # 适用于 Windows 系统
+ if not username:
+ username = os.environ.get('USER') # 适用于 Linux 和 macOS 系统
+ return username or "admin"
session_id = await _get_session_id_from_request(request)
if not session_id:
raise HTTPException(
diff --git a/apps/llm/embedding.py b/apps/llm/embedding.py
index 28ab86b49a96d19cc6e2db83930d5aff48f82260..df13044a1c9ce0e11968d5dd7c3f0f8ab2345394 100644
--- a/apps/llm/embedding.py
+++ b/apps/llm/embedding.py
@@ -1,8 +1,9 @@
"""Embedding模型"""
import httpx
-
+import logging
from apps.common.config import Config
+logger = logging.getLogger(__name__)
class Embedding:
@@ -15,7 +16,6 @@ class Embedding:
embedding = await cls.get_embedding(["测试文本"])
return len(embedding[0])
-
@classmethod
async def _get_openai_embedding(cls, text: list[str]) -> list[list[float]]:
"""访问OpenAI兼容的Embedding API,获得向量化数据"""
@@ -75,10 +75,18 @@ class Embedding:
:param text: 待向量化文本(多条文本组成List)
:return: 文本对应的向量(顺序与text一致,也为List)
"""
- if Config().get_config().embedding.type == "openai":
- return await cls._get_openai_embedding(text)
- if Config().get_config().embedding.type == "mindie":
- return await cls._get_tei_embedding(text)
+ try:
+ if Config().get_config().embedding.type == "openai":
+ return await cls._get_openai_embedding(text)
+ if Config().get_config().embedding.type == "mindie":
+ return await cls._get_tei_embedding(text)
- err = f"不支持的Embedding API类型: {Config().get_config().embedding.type}"
- raise ValueError(err)
+ err = f"不支持的Embedding API类型: {Config().get_config().embedding.type}"
+ raise ValueError(err)
+ except Exception as e:
+ err = f"获取Embedding失败: {e}"
+ logger.error(err)
+ rt = []
+ for i in range(len(text)):
+ rt.append([0.0]*1024)
+ return rt
diff --git a/apps/llm/function.py b/apps/llm/function.py
index 1f995fe7ba187cead03aa6fc62a4cbce1ec05a65..6165dc455aae038aeb1ad9bcd0840b418c3dc8e8 100644
--- a/apps/llm/function.py
+++ b/apps/llm/function.py
@@ -68,7 +68,6 @@ class FunctionLLM:
api_key=self._config.api_key,
)
-
async def _call_openai(
self,
messages: list[dict[str, str]],
@@ -123,7 +122,7 @@ class FunctionLLM:
},
]
- response = await self._client.chat.completions.create(**self._params) # type: ignore[arg-type]
+ response = await self._client.chat.completions.create(**self._params) # type: ignore[arg-type]
try:
logger.info("[FunctionCall] 大模型输出:%s", response.choices[0].message.tool_calls[0].function.arguments)
return response.choices[0].message.tool_calls[0].function.arguments
@@ -132,7 +131,6 @@ class FunctionLLM:
logger.info("[FunctionCall] 大模型输出:%s", ans)
return await FunctionLLM.process_response(ans)
-
@staticmethod
async def process_response(response: str) -> str:
"""处理大模型的输出"""
@@ -169,7 +167,6 @@ class FunctionLLM:
return json_str
-
async def _call_ollama(
self,
messages: list[dict[str, str]],
@@ -196,10 +193,9 @@ class FunctionLLM:
"format": schema,
})
- response = await self._client.chat(**self._params) # type: ignore[arg-type]
+ response = await self._client.chat(**self._params) # type: ignore[arg-type]
return await self.process_response(response.message.content or "")
-
async def call(
self,
messages: list[dict[str, Any]],
@@ -254,7 +250,6 @@ class JsonGenerator:
)
self._err_info = ""
-
async def _assemble_message(self) -> str:
"""组装消息"""
# 检查类型
@@ -275,13 +270,12 @@ class JsonGenerator:
"""单次尝试"""
prompt = await self._assemble_message()
messages = [
- {"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": prompt},
+ {"role": "system", "content": prompt},
+ {"role": "user", "content": "please generate a JSON response based on the above information and schema."},
]
function = FunctionLLM()
return await function.call(messages, self._schema, max_tokens, temperature)
-
async def generate(self) -> dict[str, Any]:
"""生成JSON"""
Draft7Validator.check_schema(self._schema)
diff --git a/apps/llm/prompt.py b/apps/llm/prompt.py
index 7cf555be390bc72301323cd103c1a7ec0ab7a085..7702f6d0912947612b64145402398f222e01a148 100644
--- a/apps/llm/prompt.py
+++ b/apps/llm/prompt.py
@@ -19,7 +19,7 @@ JSON_GEN_BASIC = dedent(r"""
Background information is given in XML tags.
- Here are the conversations between you and the user:
+ Here are the background information between you and the user:
{% if conversation|length > 0 %}
{% for message in conversation %}
diff --git a/apps/main.py b/apps/main.py
index c4ca2bfb116db11624f4e4fbbe4e876a86e5f1eb..99b7de0cd927f937abc95971cb878dd87e878abb 100644
--- a/apps/main.py
+++ b/apps/main.py
@@ -36,9 +36,10 @@ from apps.routers import (
record,
service,
user,
+ parameter
)
from apps.scheduler.pool.pool import Pool
-
+logger = logging.getLogger(__name__)
# 定义FastAPI app
app = FastAPI(redoc_url=None)
# 定义FastAPI全局中间件
@@ -66,6 +67,7 @@ app.include_router(llm.router)
app.include_router(mcp_service.router)
app.include_router(flow.router)
app.include_router(user.router)
+app.include_router(parameter.router)
# logger配置
LOGGER_FORMAT = "%(funcName)s() - %(message)s"
@@ -81,13 +83,49 @@ logging.basicConfig(
)
+async def add_no_auth_user() -> None:
+ """
+ 添加无认证用户
+ """
+ from apps.common.mongo import MongoDB
+ from apps.schemas.collection import User
+ import os
+ mongo = MongoDB()
+ user_collection = mongo.get_collection("user")
+ username = os.environ.get('USERNAME') # 适用于 Windows 系统
+ if not username:
+ username = os.environ.get('USER') # 适用于 Linux 和 macOS 系统
+ if not username:
+ username = "admin"
+ try:
+ await user_collection.insert_one(User(
+ _id=username,
+ is_admin=True,
+ auto_execute=False
+ ).model_dump(by_alias=True))
+ except Exception as e:
+ logger.error(f"[add_no_auth_user] 默认用户 {username} 已存在")
+
+
+async def clear_user_activity() -> None:
+ """清除所有用户的活跃状态"""
+ from apps.services.activity import Activity
+ from apps.common.mongo import MongoDB
+ mongo = MongoDB()
+ activity_collection = mongo.get_collection("activity")
+ await activity_collection.delete_many({})
+ logging.info("清除所有用户活跃状态完成")
+
+
async def init_resources() -> None:
"""初始化必要资源"""
WordsCheck()
await LanceDB().init()
await Pool.init()
TokenCalculator()
-
+ if Config().get_config().no_auth.enable:
+ await add_no_auth_user()
+ await clear_user_activity()
# 运行
if __name__ == "__main__":
# 初始化必要资源
diff --git a/apps/routers/api_key.py b/apps/routers/api_key.py
index 158cfc13a5be17f16e2d73ec5fa44c4a18e4f392..51366a2192eaa429a3f1e1d40672b9e915bec424 100644
--- a/apps/routers/api_key.py
+++ b/apps/routers/api_key.py
@@ -6,6 +6,7 @@ from typing import Annotated
from fastapi import APIRouter, Depends, status
from fastapi.responses import JSONResponse
+
from apps.dependency.user import get_user, verify_user
from apps.schemas.api_key import GetAuthKeyRsp, PostAuthKeyMsg, PostAuthKeyRsp
from apps.schemas.response_data import ResponseData
diff --git a/apps/routers/appcenter.py b/apps/routers/appcenter.py
index 0ec4db9155a06fe16cbeddff223ccd9a764d8b3d..36815743349c2436564e30e4cb319caa8843dd99 100644
--- a/apps/routers/appcenter.py
+++ b/apps/routers/appcenter.py
@@ -13,6 +13,8 @@ from apps.schemas.appcenter import AppFlowInfo, AppPermissionData
from apps.schemas.enum_var import AppFilterType, AppType
from apps.schemas.request_data import CreateAppRequest, ModFavAppRequest
from apps.schemas.response_data import (
+ AppMcpServiceInfo,
+ LLMIteam,
BaseAppOperationMsg,
BaseAppOperationRsp,
GetAppListMsg,
@@ -25,7 +27,8 @@ from apps.schemas.response_data import (
ResponseData,
)
from apps.services.appcenter import AppCenterManager
-
+from apps.services.llm import LLMManager
+from apps.services.mcp_service import MCPServiceManager
logger = logging.getLogger(__name__)
router = APIRouter(
prefix="/api/app",
@@ -214,6 +217,24 @@ async def get_application(
)
for flow in app_data.flows
]
+ mcp_service = []
+ if app_data.mcp_service:
+ for service in app_data.mcp_service:
+ mcp_collection = await MCPServiceManager.get_mcp_service(service)
+ mcp_service.append(AppMcpServiceInfo(
+ id=mcp_collection.id,
+ name=mcp_collection.name,
+ description=mcp_collection.description,
+ ))
+ if app_data.llm_id == "empty":
+ llm_item = LLMIteam()
+ else:
+ llm_collection = await LLMManager.get_llm_by_id(app_data.user_sub, app_data.llm_id)
+ llm_item = LLMIteam(
+ llmId=llm_collection.id,
+ modelName=llm_collection.model_name,
+ icon=llm_collection.icon
+ )
return JSONResponse(
status_code=status.HTTP_200_OK,
content=GetAppPropertyRsp(
@@ -234,7 +255,8 @@ async def get_application(
authorizedUsers=app_data.permission.users,
),
workflows=workflows,
- mcpService=app_data.mcp_service,
+ mcpService=mcp_service,
+ llm=llm_item,
),
).model_dump(exclude_none=True, by_alias=True),
)
diff --git a/apps/routers/auth.py b/apps/routers/auth.py
index 1cba5ed629d90776b59ae3bf652381318dcabf0e..72416fa35b1c23d2975497427da35a4912695fa4 100644
--- a/apps/routers/auth.py
+++ b/apps/routers/auth.py
@@ -9,6 +9,7 @@ from fastapi import APIRouter, Depends, Request, status
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
+from apps.common.config import Config
from apps.common.oidc import oidc_provider
from apps.dependency import get_session, get_user, verify_user
from apps.schemas.collection import Audit
@@ -47,8 +48,6 @@ async def oidc_login(request: Request, code: str) -> HTMLResponse:
user_info = await oidc_provider.get_oidc_user(token["access_token"])
user_sub: str | None = user_info.get("user_sub", None)
- if user_sub:
- await oidc_provider.set_token(user_sub, token["access_token"], token["refresh_token"])
except Exception as e:
logger.exception("User login failed")
status_code = status.HTTP_400_BAD_REQUEST if "auth error" in str(e) else status.HTTP_403_FORBIDDEN
@@ -75,7 +74,7 @@ async def oidc_login(request: Request, code: str) -> HTMLResponse:
status_code=status.HTTP_403_FORBIDDEN,
)
- await UserManager.update_userinfo_by_user_sub(user_sub)
+ await UserManager.update_refresh_revision_by_user_sub(user_sub)
current_session = await SessionManager.create_session(user_host, user_sub)
@@ -178,6 +177,7 @@ async def userinfo(
user_sub=user_sub,
revision=user.is_active,
is_admin=user.is_admin,
+ auto_execute=user.auto_execute,
),
).model_dump(exclude_none=True, by_alias=True),
)
@@ -193,7 +193,7 @@ async def userinfo(
)
async def update_revision_number(request: Request, user_sub: Annotated[str, Depends(get_user)]) -> JSONResponse: # noqa: ARG001
"""更新用户协议信息"""
- ret: bool = await UserManager.update_userinfo_by_user_sub(user_sub, refresh_revision=True)
+ ret: bool = await UserManager.update_refresh_revision_by_user_sub(user_sub, refresh_revision=True)
if not ret:
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
diff --git a/apps/routers/chat.py b/apps/routers/chat.py
index 7fe5162c0f7db804c6d723207ffa327d9394e3eb..888fbd04f67c0ad2ad9482db8d37016e0d71447d 100644
--- a/apps/routers/chat.py
+++ b/apps/routers/chat.py
@@ -7,20 +7,23 @@ import uuid
from collections.abc import AsyncGenerator
from typing import Annotated
-from fastapi import APIRouter, Depends, HTTPException, status
+from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi.responses import JSONResponse, StreamingResponse
from apps.common.queue import MessageQueue
from apps.common.wordscheck import WordsCheck
from apps.dependency import get_session, get_user
+from apps.schemas.enum_var import FlowStatus
from apps.scheduler.scheduler import Scheduler
from apps.scheduler.scheduler.context import save_data
-from apps.schemas.request_data import RequestData
+from apps.schemas.request_data import RequestData, RequestDataApp
from apps.schemas.response_data import ResponseData
from apps.schemas.task import Task
from apps.services.activity import Activity
from apps.services.blacklist import QuestionBlacklistManager, UserBlacklistManager
from apps.services.flow import FlowManager
+from apps.services.conversation import ConversationManager
+from apps.services.record import RecordManager
from apps.services.task import TaskManager
RECOMMEND_TRES = 5
@@ -36,28 +39,49 @@ async def init_task(post_body: RequestData, user_sub: str, session_id: str) -> T
# 生成group_id
if not post_body.group_id:
post_body.group_id = str(uuid.uuid4())
- # 创建或还原Task
- task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub)
+
# 更改信息并刷新数据库
- task.runtime.question = post_body.question
- task.ids.group_id = post_body.group_id
+ if post_body.task_id is None:
+ conversation = await ConversationManager.get_conversation_by_conversation_id(
+ user_sub=user_sub,
+ conversation_id=post_body.conversation_id,
+ )
+ if not conversation:
+ err = "[Chat] 用户没有权限访问该对话!"
+ raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=err)
+ task_ids = await TaskManager.delete_tasks_by_conversation_id(post_body.conversation_id)
+ await RecordManager.update_record_flow_status_to_cancelled_by_task_ids(task_ids)
+ task = await TaskManager.init_new_task(user_sub=user_sub, session_id=session_id, post_body=post_body)
+ task.runtime.question = post_body.question
+ task.ids.group_id = post_body.group_id
+ task.state.app_id = post_body.app.app_id if post_body.app else ""
+ else:
+ if not post_body.task_id:
+ err = "[Chat] task_id 不可为空!"
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="task_id cannot be empty")
+ task = await TaskManager.get_task_by_task_id(post_body.task_id)
+ post_body.app = RequestDataApp(appId=task.state.app_id)
+ post_body.group_id = task.ids.group_id
+ post_body.conversation_id = task.ids.conversation_id
+ post_body.language = task.language
+ post_body.question = task.runtime.question
return task
async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) -> AsyncGenerator[str, None]:
"""进行实际问答,并从MQ中获取消息"""
try:
- await Activity.set_active(user_sub)
+ active_id = await Activity.set_active(user_sub)
# 敏感词检查
if await WordsCheck().check(post_body.question) != 1:
yield "data: [SENSITIVE]\n\n"
logger.info("[Chat] 问题包含敏感词!")
- await Activity.remove_active(user_sub)
+ await Activity.remove_active(active_id)
return
task = await init_task(post_body, user_sub, session_id)
-
+ task.ids.active_id = active_id
# 创建queue;由Scheduler进行关闭
queue = MessageQueue()
await queue.init()
@@ -77,19 +101,18 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str)
# 获取最终答案
task = scheduler.task
- if not task.runtime.answer:
- logger.error("[Chat] 答案为空")
+ if task.state.flow_status == FlowStatus.ERROR:
+ logger.error("[Chat] 生成答案失败")
yield "data: [ERROR]\n\n"
- await Activity.remove_active(user_sub)
+ await Activity.remove_active(active_id)
return
# 对结果进行敏感词检查
if await WordsCheck().check(task.runtime.answer) != 1:
yield "data: [SENSITIVE]\n\n"
logger.info("[Chat] 答案包含敏感词!")
- await Activity.remove_active(user_sub)
+ await Activity.remove_active(active_id)
return
-
# 创建新Record,存入数据库
await save_data(task, user_sub, post_body)
@@ -107,7 +130,7 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str)
yield "data: [ERROR]\n\n"
finally:
- await Activity.remove_active(user_sub)
+ await Activity.remove_active(active_id)
@router.post("/chat")
@@ -118,15 +141,11 @@ async def chat(
) -> StreamingResponse:
"""LLM流式对话接口"""
# 问题黑名单检测
- if not await QuestionBlacklistManager.check_blacklisted_questions(input_question=post_body.question):
+ if post_body.question is not None and not await QuestionBlacklistManager.check_blacklisted_questions(input_question=post_body.question):
# 用户扣分
await UserBlacklistManager.change_blacklisted_users(user_sub, -10)
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="question is blacklisted")
- # 限流检查
- if await Activity.is_active(user_sub):
- raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many requests")
-
res = chat_generator(post_body, user_sub, session_id)
return StreamingResponse(
content=res,
@@ -138,9 +157,12 @@ async def chat(
@router.post("/stop", response_model=ResponseData)
-async def stop_generation(user_sub: Annotated[str, Depends(get_user)]): # noqa: ANN201
+async def stop_generation(user_sub: Annotated[str, Depends(get_user)],
+ task_id: Annotated[str, Query(..., alias="taskId")] = "") -> JSONResponse:
"""停止生成"""
- await Activity.remove_active(user_sub)
+ task = await TaskManager.get_task_by_task_id(task_id)
+ if task:
+ await Activity.remove_active(task.activity_id)
return JSONResponse(
status_code=status.HTTP_200_OK,
content=ResponseData(
diff --git a/apps/routers/flow.py b/apps/routers/flow.py
index fc38c1bfb9879c1d846d700e91b812c83866c2c4..68c0c9f3bd2fbdb434164d64ac4c3d1cf31a376b 100644
--- a/apps/routers/flow.py
+++ b/apps/routers/flow.py
@@ -2,6 +2,7 @@
"""FastAPI Flow拓扑结构展示API"""
from typing import Annotated
+import logging
from fastapi import APIRouter, Body, Depends, Query, status
from fastapi.responses import JSONResponse
@@ -25,6 +26,8 @@ from apps.services.application import AppManager
from apps.services.flow import FlowManager
from apps.services.flow_validate import FlowService
+logger = logging.getLogger(__name__)
+
router = APIRouter(
prefix="/api/flow",
tags=["flow"],
@@ -130,8 +133,11 @@ async def put_flow(
).model_dump(exclude_none=True, by_alias=True),
)
put_body.flow = await FlowService.remove_excess_structure_from_flow(put_body.flow)
+ logger.error(f'{put_body.flow}')
await FlowService.validate_flow_illegal(put_body.flow)
+ logger.error(f'{put_body.flow}')
put_body.flow.connectivity = await FlowService.validate_flow_connectivity(put_body.flow)
+ logger.error(f'{put_body.flow}')
result = await FlowManager.put_flow_by_app_and_flow_id(app_id, flow_id, put_body.flow)
if result is None:
return JSONResponse(
diff --git a/apps/routers/mcp_service.py b/apps/routers/mcp_service.py
index a845a3761ea2ab179a1b71d4674d0d5fd3b5a760..99cc6dfdc312607e6bcf42274e7ad748d9dd928b 100644
--- a/apps/routers/mcp_service.py
+++ b/apps/routers/mcp_service.py
@@ -36,6 +36,7 @@ router = APIRouter(
dependencies=[Depends(verify_user)],
)
+
async def _check_user_admin(user_sub: str) -> None:
user = await UserManager.get_userinfo_by_user_sub(user_sub)
if not user:
@@ -52,6 +53,8 @@ async def get_mcpservice_list(
] = SearchType.ALL,
keyword: Annotated[str | None, Query(..., alias="keyword", description="搜索关键字")] = None,
page: Annotated[int, Query(..., alias="page", ge=1, description="页码")] = 1,
+ is_install: Annotated[bool | None, Query(..., alias="isInstall", description="是否已安装")] = None,
+ is_active: Annotated[bool | None, Query(..., alias="isActive", description="是否激活")] = None,
) -> JSONResponse:
"""获取服务列表"""
try:
@@ -60,6 +63,8 @@ async def get_mcpservice_list(
user_sub,
keyword,
page,
+ is_install,
+ is_active
)
except Exception as e:
err = f"[MCPServiceCenter] 获取MCP服务列表失败: {e}"
@@ -89,7 +94,7 @@ async def get_mcpservice_list(
@router.post("", response_model=UpdateMCPServiceRsp)
async def create_or_update_mcpservice(
user_sub: Annotated[str, Depends(get_user)], # TODO: get_user直接获取所有用户信息
- data: UpdateMCPServiceRequest,
+ data: UpdateMCPServiceRequest
) -> JSONResponse:
"""新建或更新MCP服务"""
await _check_user_admin(user_sub)
@@ -130,6 +135,35 @@ async def create_or_update_mcpservice(
).model_dump(exclude_none=True, by_alias=True))
+@router.post("/{serviceId}/install")
+async def install_mcp_service(
+ user_sub: Annotated[str, Depends(get_user)],
+ service_id: Annotated[str, Path(..., alias="serviceId", description="服务ID")],
+ install: Annotated[bool, Query(..., description="是否安装")] = True,
+) -> JSONResponse:
+ try:
+ await MCPServiceManager.install_mcpservice(user_sub, service_id, install)
+ except Exception as e:
+ err = f"[MCPService] 安装mcp服务失败: {e!s}" if install else f"[MCPService] 卸载mcp服务失败: {e!s}"
+ logger.exception(err)
+ return JSONResponse(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ content=ResponseData(
+ code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ message=err,
+ result={},
+ ).model_dump(exclude_none=True, by_alias=True),
+ )
+ return JSONResponse(
+ status_code=status.HTTP_200_OK,
+ content=ResponseData(
+ code=status.HTTP_200_OK,
+ message="OK",
+ result={},
+ ).model_dump(exclude_none=True, by_alias=True),
+ )
+
+
@router.get("/{serviceId}", response_model=GetMCPServiceDetailRsp)
async def get_service_detail(
user_sub: Annotated[str, Depends(get_user)],
@@ -282,7 +316,7 @@ async def active_or_deactivate_mcp_service(
"""激活/取消激活mcp"""
try:
if data.active:
- await MCPServiceManager.active_mcpservice(user_sub, service_id)
+ await MCPServiceManager.active_mcpservice(user_sub, service_id, data.mcp_env)
else:
await MCPServiceManager.deactive_mcpservice(user_sub, service_id)
except Exception as e:
diff --git a/apps/routers/parameter.py b/apps/routers/parameter.py
new file mode 100644
index 0000000000000000000000000000000000000000..6edbe2e142cb6589be8947c28ae2eb4a7287baa1
--- /dev/null
+++ b/apps/routers/parameter.py
@@ -0,0 +1,77 @@
+from typing import Annotated
+
+from fastapi import APIRouter, Depends, Query, status
+from fastapi.responses import JSONResponse
+
+from apps.dependency import get_user
+from apps.dependency.user import verify_user
+from apps.services.parameter import ParameterManager
+from apps.schemas.response_data import (
+ GetOperaRsp,
+ GetParamsRsp
+)
+from apps.services.application import AppManager
+from apps.services.flow import FlowManager
+
+router = APIRouter(
+ prefix="/api/parameter",
+ tags=["parameter"],
+ dependencies=[
+ Depends(verify_user),
+ ],
+)
+
+
+@router.get("", response_model=GetParamsRsp)
+async def get_parameters(
+ user_sub: Annotated[str, Depends(get_user)],
+ app_id: Annotated[str, Query(alias="appId")],
+ flow_id: Annotated[str, Query(alias="flowId")],
+ step_id: Annotated[str, Query(alias="stepId")],
+) -> JSONResponse:
+ """Get parameters for node choice."""
+ if not await AppManager.validate_user_app_access(user_sub, app_id):
+ return JSONResponse(
+ status_code=status.HTTP_403_FORBIDDEN,
+ content=GetParamsRsp(
+ code=status.HTTP_403_FORBIDDEN,
+ message="用户没有权限访问该流",
+ result=[],
+ ).model_dump(exclude_none=True, by_alias=True),
+ )
+ flow = await FlowManager.get_flow_by_app_and_flow_id(app_id, flow_id)
+ if not flow:
+ return JSONResponse(
+ status_code=status.HTTP_404_NOT_FOUND,
+ content=GetParamsRsp(
+ code=status.HTTP_404_NOT_FOUND,
+ message="未找到该流",
+ result=[],
+ ).model_dump(exclude_none=True, by_alias=True),
+ )
+ result = await ParameterManager.get_pre_params_by_flow_and_step_id(flow, step_id)
+ return JSONResponse(
+ status_code=status.HTTP_200_OK,
+ content=GetParamsRsp(
+ code=status.HTTP_200_OK,
+ message="获取参数成功",
+ result=result
+ ).model_dump(exclude_none=True, by_alias=True),
+ )
+
+
+@router.get("/operate", response_model=GetOperaRsp)
+async def get_operate_parameters(
+ user_sub: Annotated[str, Depends(get_user)],
+ param_type: Annotated[str, Query(alias="ParamType")],
+) -> JSONResponse:
+ """Get parameters for node choice."""
+ result = await ParameterManager.get_operate_and_bind_type(param_type)
+ return JSONResponse(
+ status_code=status.HTTP_200_OK,
+ content=GetOperaRsp(
+ code=status.HTTP_200_OK,
+ message="获取操作成功",
+ result=result
+ ).model_dump(exclude_none=True, by_alias=True),
+ )
diff --git a/apps/routers/record.py b/apps/routers/record.py
index 7384793b4b2abea5117462cb630c5b13c8f76071..ed994f6fe2e611562fa092db39265d1d57dfb9bc 100644
--- a/apps/routers/record.py
+++ b/apps/routers/record.py
@@ -65,7 +65,7 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_
tmp_record = RecordData(
id=record.id,
groupId=record_group.id,
- taskId=record_group.task_id,
+ taskId=record.task_id,
conversationId=conversation_id,
content=record_data,
metadata=record.metadata
@@ -81,26 +81,26 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_
# 获得Record关联的文档
tmp_record.document = await DocumentManager.get_used_docs_by_record_group(user_sub, record_group.id)
-
# 获得Record关联的flow数据
- flow_list = await TaskManager.get_context_by_record_id(record_group.id, record.id)
- if flow_list:
- first_flow = FlowStepHistory.model_validate(flow_list[0])
+ flow_step_list = await TaskManager.get_context_by_record_id(record_group.id, record.id)
+ if flow_step_list:
tmp_record.flow = RecordFlow(
- id=first_flow.flow_name, #TODO: 此处前端应该用name
+ id=record.flow.flow_id, # TODO: 此处前端应该用name
recordId=record.id,
- flowId=first_flow.id,
- stepNum=len(flow_list),
+ flowId=record.flow.flow_id,
+ flowName=record.flow.flow_name,
+ flowStatus=record.flow.flow_staus,
+ stepNum=len(flow_step_list),
steps=[],
)
- for flow in flow_list:
- flow_step = FlowStepHistory.model_validate(flow)
+ for flow_step in flow_step_list:
tmp_record.flow.steps.append(
RecordFlowStep(
- stepId=flow_step.step_name, #TODO: 此处前端应该用name
- stepStatus=flow_step.status,
+ stepId=flow_step.step_name, # TODO: 此处前端应该用name
+ stepStatus=flow_step.step_status,
input=flow_step.input_data,
output=flow_step.output_data,
+ exData=flow_step.ex_data
),
)
diff --git a/apps/routers/user.py b/apps/routers/user.py
index 54e12f444181b56e408f7face5c4ed37be76008b..537f1bf3adb95e8b48ef1b3d384376bae86ac802 100644
--- a/apps/routers/user.py
+++ b/apps/routers/user.py
@@ -3,10 +3,11 @@
from typing import Annotated
-from fastapi import APIRouter, Depends, status
+from fastapi import APIRouter, Body, Depends, status, Query
from fastapi.responses import JSONResponse
from apps.dependency import get_user
+from apps.schemas.request_data import UserUpdateRequest
from apps.schemas.response_data import UserGetMsp, UserGetRsp
from apps.schemas.user import UserInfo
from apps.services.user import UserManager
@@ -17,12 +18,14 @@ router = APIRouter(
)
-@router.get("")
-async def chat(
+@router.get("", response_model=UserGetRsp)
+async def get_user_sub(
user_sub: Annotated[str, Depends(get_user)],
+ page_size: Annotated[int, Query(description="每页用户数量")] = 20,
+ page_cnt: Annotated[int, Query(description="当前页码")] = 1,
) -> JSONResponse:
"""查询所有用户接口"""
- user_list = await UserManager.get_all_user_sub()
+ user_list, total = await UserManager.get_all_user_sub(page_cnt=page_cnt, page_size=page_size, filter_user_subs=[user_sub])
user_info_list = []
for user in user_list:
# user_info = await UserManager.get_userinfo_by_user_sub(user) 暂时不需要查询user_name
@@ -39,6 +42,22 @@ async def chat(
content=UserGetRsp(
code=status.HTTP_200_OK,
message="用户数据详细信息获取成功",
- result=UserGetMsp(userInfoList=user_info_list),
+ result=UserGetMsp(userInfoList=user_info_list, total=total),
).model_dump(exclude_none=True, by_alias=True),
)
+
+
+@router.post("")
+async def update_user_info(
+ user_sub: Annotated[str, Depends(get_user)],
+ *,
+ data: Annotated[UserUpdateRequest, Body(..., description="用户更新信息")],
+) -> JSONResponse:
+ """更新用户信息接口"""
+ # 更新用户信息
+
+ await UserManager.update_userinfo_by_user_sub(user_sub, data)
+ return JSONResponse(
+ status_code=status.HTTP_200_OK,
+ content={"code": status.HTTP_200_OK, "message": "用户信息更新成功"},
+ )
diff --git a/apps/scheduler/call/__init__.py b/apps/scheduler/call/__init__.py
index 2ee8b862885b0d88e519bf00a1678a49863df2bd..c5a6f0549c6891bcdf74052ddb8152fb550653b2 100644
--- a/apps/scheduler/call/__init__.py
+++ b/apps/scheduler/call/__init__.py
@@ -8,7 +8,7 @@ from apps.scheduler.call.mcp.mcp import MCP
from apps.scheduler.call.rag.rag import RAG
from apps.scheduler.call.sql.sql import SQL
from apps.scheduler.call.suggest.suggest import Suggestion
-
+from apps.scheduler.call.choice.choice import Choice
# 只包含需要在编排界面展示的工具
__all__ = [
"API",
@@ -18,4 +18,5 @@ __all__ = [
"SQL",
"Graph",
"Suggestion",
+ "Choice"
]
diff --git a/apps/scheduler/call/choice/choice.py b/apps/scheduler/call/choice/choice.py
index a5edf21afb2eeb308dd909a40696500751c9a086..01ac7106fbdae673d12488a49d831f1d7e7fc13d 100644
--- a/apps/scheduler/call/choice/choice.py
+++ b/apps/scheduler/call/choice/choice.py
@@ -1,19 +1,150 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""使用大模型或使用程序做出判断"""
-from enum import Enum
-
-from apps.scheduler.call.choice.schema import ChoiceInput, ChoiceOutput
-from apps.scheduler.call.core import CoreCall
+import ast
+import copy
+import logging
+from collections.abc import AsyncGenerator
+from typing import Any
+from pydantic import Field
-class Operator(str, Enum):
- """Choice工具支持的运算符"""
+from apps.scheduler.call.choice.condition_handler import ConditionHandler
+from apps.scheduler.call.choice.schema import (
+ Condition,
+ ChoiceBranch,
+ ChoiceInput,
+ ChoiceOutput,
+ Logic,
+)
+from apps.schemas.parameters import Type
+from apps.scheduler.call.core import CoreCall
+from apps.schemas.enum_var import CallOutputType
+from apps.schemas.scheduler import (
+ CallError,
+ CallInfo,
+ CallOutputChunk,
+ CallVars,
+)
- pass
+logger = logging.getLogger(__name__)
class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput):
"""Choice工具"""
- pass
+ to_user: bool = Field(default=False)
+ choices: list[ChoiceBranch] = Field(description="分支", default=[ChoiceBranch(),
+ ChoiceBranch(conditions=[Condition()], is_default=False)])
+
+ @classmethod
+ def info(cls) -> CallInfo:
+ """返回Call的名称和描述"""
+ return CallInfo(name="选择器", description="使用大模型或使用程序做出判断")
+
+ async def _prepare_message(self, call_vars: CallVars) -> list[dict[str, Any]]:
+ """替换choices中的系统变量"""
+ valid_choices = []
+
+ for choice in self.choices:
+ try:
+ # 验证逻辑运算符
+ if choice.logic not in [Logic.AND, Logic.OR]:
+ msg = f"无效的逻辑运算符: {choice.logic}"
+ logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}")
+ continue
+
+ valid_conditions = []
+ for i in range(len(choice.conditions)):
+ condition = copy.deepcopy(choice.conditions[i])
+ # 处理左值
+ if condition.left.step_id is not None:
+ condition.left.value = self._extract_history_variables(
+ condition.left.step_id+'/'+condition.left.value, call_vars.history)
+ # 检查历史变量是否成功提取
+ if condition.left.value is None:
+ msg = f"步骤 {condition.left.step_id} 的历史变量不存在"
+ logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}")
+ continue
+ if not ConditionHandler.check_value_type(
+ condition.left.value, condition.left.type):
+ msg = f"左值类型不匹配: {condition.left.value} 应为 {condition.left.type.value}"
+ logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}")
+ continue
+ else:
+ msg = "左侧变量缺少step_id"
+ logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}")
+ continue
+ # 处理右值
+ if condition.right.step_id is not None:
+ condition.right.value = self._extract_history_variables(
+ condition.right.step_id+'/'+condition.right.value, call_vars.history)
+ # 检查历史变量是否成功提取
+ if condition.right.value is None:
+ msg = f"步骤 {condition.right.step_id} 的历史变量不存在"
+ logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}")
+ continue
+ if not ConditionHandler.check_value_type(
+ condition.right.value, condition.right.type):
+ msg = f"右值类型不匹配: {condition.right.value} 应为 {condition.right.type.value}"
+ logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}")
+ continue
+ else:
+ # 如果右值没有step_id,尝试从call_vars中获取
+ right_value_type = await ConditionHandler.get_value_type_from_operate(
+ condition.operate)
+ if right_value_type is None:
+ msg = f"不支持的运算符: {condition.operate}"
+ logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}")
+ continue
+ if condition.right.type != right_value_type:
+ msg = f"右值类型不匹配: {condition.right.value} 应为 {right_value_type.value}"
+ logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}")
+ continue
+ if right_value_type == Type.STRING:
+ condition.right.value = str(condition.right.value)
+ else:
+ condition.right.value = ast.literal_eval(condition.right.value)
+ if not ConditionHandler.check_value_type(
+ condition.right.value, condition.right.type):
+ msg = f"右值类型不匹配: {condition.right.value} 应为 {condition.right.type.value}"
+ logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}")
+ continue
+ valid_conditions.append(condition)
+
+ # 如果所有条件都无效,抛出异常
+ if not valid_conditions and not choice.is_default:
+ msg = "分支没有有效条件"
+ logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}")
+ continue
+
+ # 更新有效条件
+ choice.conditions = valid_conditions
+ valid_choices.append(choice)
+
+ except ValueError as e:
+ logger.warning("分支 %s 处理失败: %s,已跳过", choice.branch_id, str(e))
+ continue
+
+ return valid_choices
+
+ async def _init(self, call_vars: CallVars) -> ChoiceInput:
+ """初始化Choice工具"""
+ return ChoiceInput(
+ choices=await self._prepare_message(call_vars),
+ )
+
+ async def _exec(
+ self, input_data: dict[str, Any]
+ ) -> AsyncGenerator[CallOutputChunk, None]:
+ """执行Choice工具"""
+ # 解析输入数据
+ data = ChoiceInput(**input_data)
+ try:
+ branch_id = ConditionHandler.handler(data.choices)
+ yield CallOutputChunk(
+ type=CallOutputType.DATA,
+ content=ChoiceOutput(branch_id=branch_id).model_dump(exclude_none=True, by_alias=True),
+ )
+ except Exception as e:
+ raise CallError(message=f"选择工具调用失败:{e!s}", data={}) from e
diff --git a/apps/scheduler/call/choice/condition_handler.py b/apps/scheduler/call/choice/condition_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..261809ad57fdd6d8c5267c59a7b87404dd7befa5
--- /dev/null
+++ b/apps/scheduler/call/choice/condition_handler.py
@@ -0,0 +1,309 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""处理条件分支的工具"""
+
+
+import logging
+
+from pydantic import BaseModel
+
+from apps.schemas.parameters import (
+ Type,
+ NumberOperate,
+ StringOperate,
+ ListOperate,
+ BoolOperate,
+ DictOperate,
+)
+
+from apps.scheduler.call.choice.schema import (
+ ChoiceBranch,
+ Condition,
+ Logic,
+ Value
+)
+
+logger = logging.getLogger(__name__)
+
+
+class ConditionHandler(BaseModel):
+ """条件分支处理器"""
+ @staticmethod
+ async def get_value_type_from_operate(operate: NumberOperate | StringOperate | ListOperate |
+ BoolOperate | DictOperate) -> Type:
+ """获取右值的类型"""
+ if isinstance(operate, NumberOperate):
+ return Type.NUMBER
+ if operate in [
+ StringOperate.EQUAL, StringOperate.NOT_EQUAL, StringOperate.CONTAINS, StringOperate.NOT_CONTAINS,
+ StringOperate.STARTS_WITH, StringOperate.ENDS_WITH, StringOperate.REGEX_MATCH]:
+ return Type.STRING
+ if operate in [StringOperate.LENGTH_EQUAL, StringOperate.LENGTH_GREATER_THAN,
+ StringOperate.LENGTH_GREATER_THAN_OR_EQUAL, StringOperate.LENGTH_LESS_THAN,
+ StringOperate.LENGTH_LESS_THAN_OR_EQUAL]:
+ return Type.NUMBER
+ if operate in [ListOperate.EQUAL, ListOperate.NOT_EQUAL]:
+ return Type.LIST
+ if operate in [ListOperate.CONTAINS, ListOperate.NOT_CONTAINS]:
+ return Type.STRING
+ if operate in [ListOperate.LENGTH_EQUAL, ListOperate.LENGTH_GREATER_THAN,
+ ListOperate.LENGTH_GREATER_THAN_OR_EQUAL, ListOperate.LENGTH_LESS_THAN,
+ ListOperate.LENGTH_LESS_THAN_OR_EQUAL]:
+ return Type.NUMBER
+ if operate in [BoolOperate.EQUAL, BoolOperate.NOT_EQUAL]:
+ return Type.BOOL
+ if operate in [DictOperate.EQUAL, DictOperate.NOT_EQUAL]:
+ return Type.DICT
+ if operate in [DictOperate.CONTAINS_KEY, DictOperate.NOT_CONTAINS_KEY]:
+ return Type.STRING
+ return None
+
+ @staticmethod
+ def check_value_type(value: Value, expected_type: Type) -> bool:
+ """检查值的类型是否符合预期"""
+ if expected_type == Type.STRING and isinstance(value.value, str):
+ return True
+ if expected_type == Type.NUMBER and isinstance(value.value, (int, float)):
+ return True
+ if expected_type == Type.LIST and isinstance(value.value, list):
+ return True
+ if expected_type == Type.DICT and isinstance(value.value, dict):
+ return True
+ if expected_type == Type.BOOL and isinstance(value.value, bool):
+ return True
+ return False
+
+ @staticmethod
+ def handler(choices: list[ChoiceBranch]) -> str:
+ """处理条件"""
+
+ for block_judgement in choices[::-1]:
+ results = []
+ if block_judgement.is_default:
+ return block_judgement.branch_id
+ for condition in block_judgement.conditions:
+ result = ConditionHandler._judge_condition(condition)
+ if result is not None:
+ results.append(result)
+ if not results:
+ logger.warning(f"[Choice] 分支 {block_judgement.branch_id} 条件处理失败: 没有有效的条件")
+ continue
+ if block_judgement.logic == Logic.AND:
+ final_result = all(results)
+ elif block_judgement.logic == Logic.OR:
+ final_result = any(results)
+
+ if final_result:
+ return block_judgement.branch_id
+
+ return ""
+
+ @staticmethod
+ def _judge_condition(condition: Condition) -> bool:
+ """
+ 判断条件是否成立。
+
+ Args:
+ condition (Condition): 'left', 'operate', 'right', 'type'
+
+ Returns:
+ bool
+
+ """
+ left = condition.left
+ operate = condition.operate
+ right = condition.right
+ value_type = condition.type
+
+ result = None
+ if value_type == Type.STRING:
+ result = ConditionHandler._judge_string_condition(left, operate, right)
+ elif value_type == Type.NUMBER:
+ result = ConditionHandler._judge_number_condition(left, operate, right)
+ elif value_type == Type.BOOL:
+ result = ConditionHandler._judge_bool_condition(left, operate, right)
+ elif value_type == Type.LIST:
+ result = ConditionHandler._judge_list_condition(left, operate, right)
+ elif value_type == Type.DICT:
+ result = ConditionHandler._judge_dict_condition(left, operate, right)
+ else:
+ msg = f"不支持的数据类型: {value_type}"
+ logger.error(f"[Choice] 条件处理失败: {msg}")
+ return None
+ return result
+
+ @staticmethod
+ def _judge_string_condition(left: Value, operate: StringOperate, right: Value) -> bool:
+ """
+ 判断字符串类型的条件。
+
+ Args:
+ left (Value): 左值,包含 'value' 键。
+ operate (Operate): 操作符
+ right (Value): 右值,包含 'value' 键。
+
+ Returns:
+ bool
+
+ """
+ left_value = left.value
+ if not isinstance(left_value, str):
+ msg = f"左值必须是字符串类型 ({left_value})"
+ logger.warning(msg)
+ return None
+ right_value = right.value
+ if operate == StringOperate.EQUAL:
+ return left_value == right_value
+ elif operate == StringOperate.NOT_EQUAL:
+ return left_value != right_value
+ elif operate == StringOperate.CONTAINS:
+ return right_value in left_value
+ elif operate == StringOperate.NOT_CONTAINS:
+ return right_value not in left_value
+ elif operate == StringOperate.STARTS_WITH:
+ return left_value.startswith(right_value)
+ elif operate == StringOperate.ENDS_WITH:
+ return left_value.endswith(right_value)
+ elif operate == StringOperate.REGEX_MATCH:
+ import re
+ return bool(re.match(right_value, left_value))
+ elif operate == StringOperate.LENGTH_EQUAL:
+ return len(left_value) == right_value
+ elif operate == StringOperate.LENGTH_GREATER_THAN:
+ return len(left_value) > right_value
+ elif operate == StringOperate.LENGTH_GREATER_THAN_OR_EQUAL:
+ return len(left_value) >= right_value
+ elif operate == StringOperate.LENGTH_LESS_THAN:
+ return len(left_value) < right_value
+ elif operate == StringOperate.LENGTH_LESS_THAN_OR_EQUAL:
+ return len(left_value) <= right_value
+ return False
+
+ @staticmethod
+ def _judge_number_condition(left: Value, operate: NumberOperate, right: Value) -> bool: # noqa: PLR0911
+ """
+ 判断数字类型的条件。
+
+ Args:
+ left (Value): 左值,包含 'value' 键。
+ operate (Operate): 操作符
+ right (Value): 右值,包含 'value' 键。
+
+ Returns:
+ bool
+
+ """
+ left_value = left.value
+ if not isinstance(left_value, (int, float)):
+ msg = f"左值必须是数字类型 ({left_value})"
+ logger.warning(msg)
+ return None
+ right_value = right.value
+ if operate == NumberOperate.EQUAL:
+ return left_value == right_value
+ elif operate == NumberOperate.NOT_EQUAL:
+ return left_value != right_value
+ elif operate == NumberOperate.GREATER_THAN:
+ return left_value > right_value
+ elif operate == NumberOperate.LESS_THAN: # noqa: PLR2004
+ return left_value < right_value
+ elif operate == NumberOperate.GREATER_THAN_OR_EQUAL:
+ return left_value >= right_value
+ elif operate == NumberOperate.LESS_THAN_OR_EQUAL:
+ return left_value <= right_value
+ return False
+
+ @staticmethod
+ def _judge_bool_condition(left: Value, operate: BoolOperate, right: Value) -> bool:
+ """
+ 判断布尔类型的条件。
+
+ Args:
+ left (Value): 左值,包含 'value' 键。
+ operate (Operate): 操作符
+ right (Value): 右值,包含 'value' 键。
+
+ Returns:
+ bool
+
+ """
+ left_value = left.value
+ if not isinstance(left_value, bool):
+ msg = "左值必须是布尔类型"
+ logger.warning(msg)
+ return None
+ right_value = right.value
+ if operate == BoolOperate.EQUAL:
+ return left_value == right_value
+ elif operate == BoolOperate.NOT_EQUAL:
+ return left_value != right_value
+ return False
+
+ @staticmethod
+ def _judge_list_condition(left: Value, operate: ListOperate, right: Value):
+ """
+ 判断列表类型的条件。
+
+ Args:
+ left (Value): 左值,包含 'value' 键。
+ operate (Operate): 操作符
+ right (Value): 右值,包含 'value' 键。
+
+ Returns:
+ bool
+
+ """
+ left_value = left.value
+ if not isinstance(left_value, list):
+ msg = f"左值必须是列表类型 ({left_value})"
+ logger.warning(msg)
+ return None
+ right_value = right.value
+ if operate == ListOperate.EQUAL:
+ return left_value == right_value
+ elif operate == ListOperate.NOT_EQUAL:
+ return left_value != right_value
+ elif operate == ListOperate.CONTAINS:
+ return right_value in left_value
+ elif operate == ListOperate.NOT_CONTAINS:
+ return right_value not in left_value
+ elif operate == ListOperate.LENGTH_EQUAL:
+ return len(left_value) == right_value
+ elif operate == ListOperate.LENGTH_GREATER_THAN:
+ return len(left_value) > right_value
+ elif operate == ListOperate.LENGTH_GREATER_THAN_OR_EQUAL:
+ return len(left_value) >= right_value
+ elif operate == ListOperate.LENGTH_LESS_THAN:
+ return len(left_value) < right_value
+ elif operate == ListOperate.LENGTH_LESS_THAN_OR_EQUAL:
+ return len(left_value) <= right_value
+ return False
+
+ @staticmethod
+ def _judge_dict_condition(left: Value, operate: DictOperate, right: Value):
+ """
+ 判断字典类型的条件。
+
+ Args:
+ left (Value): 左值,包含 'value' 键。
+ operate (Operate): 操作符
+ right (Value): 右值,包含 'value' 键。
+
+ Returns:
+ bool
+
+ """
+ left_value = left.value
+ if not isinstance(left_value, dict):
+ msg = f"左值必须是字典类型 ({left_value})"
+ logger.warning(msg)
+ return None
+ right_value = right.value
+ if operate == DictOperate.EQUAL:
+ return left_value == right_value
+ elif operate == DictOperate.NOT_EQUAL:
+ return left_value != right_value
+ elif operate == DictOperate.CONTAINS_KEY:
+ return right_value in left_value
+ elif operate == DictOperate.NOT_CONTAINS_KEY:
+ return right_value not in left_value
+ return False
diff --git a/apps/scheduler/call/choice/schema.py b/apps/scheduler/call/choice/schema.py
index 60b62d09fd66adbf32295f44ec86398a537f38d5..d97a0c8d7e9efcef255ced76c8d1d6b39dfed241 100644
--- a/apps/scheduler/call/choice/schema.py
+++ b/apps/scheduler/call/choice/schema.py
@@ -1,12 +1,64 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""Choice Call的输入和输出"""
+import uuid
+from enum import Enum
+
+from pydantic import BaseModel, Field
+
+from apps.schemas.parameters import (
+ Type,
+ NumberOperate,
+ StringOperate,
+ ListOperate,
+ BoolOperate,
+ DictOperate,
+)
from apps.scheduler.call.core import DataBase
+class Logic(str, Enum):
+ """Choice 工具支持的逻辑运算符"""
+
+ AND = "and"
+ OR = "or"
+
+
+class Value(DataBase):
+ """值的结构"""
+
+ step_id: str | None = Field(description="步骤id", default=None)
+ type: Type | None = Field(description="值的类型", default=None)
+ name: str | None = Field(description="值的名称", default=None)
+ value: str | float | int | bool | list | dict | None = Field(description="值", default=None)
+
+
+class Condition(DataBase):
+ """单个条件"""
+
+ left: Value = Field(description="左值", default=Value())
+ right: Value = Field(description="右值", default=Value())
+ operate: NumberOperate | StringOperate | ListOperate | BoolOperate | DictOperate | None = Field(
+ description="运算符", default=None)
+ id: str = Field(description="条件ID", default_factory=lambda: str(uuid.uuid4()))
+
+
+class ChoiceBranch(DataBase):
+ """子分支"""
+
+ branch_id: str = Field(description="分支ID", default_factory=lambda: str(uuid.uuid4()))
+ logic: Logic = Field(description="逻辑运算符", default=Logic.AND)
+ conditions: list[Condition] = Field(description="条件列表", default=[])
+ is_default: bool = Field(description="是否为默认分支", default=True)
+
+
class ChoiceInput(DataBase):
"""Choice Call的输入"""
+ choices: list[ChoiceBranch] = Field(description="分支", default=[])
+
class ChoiceOutput(DataBase):
"""Choice Call的输出"""
+
+ branch_id: str = Field(description="分支ID", default="")
diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py
index 2b1cbba83b9345e8df84e8f23f893137cb46532b..5bed80309fbdb993aacc378ab0319ed291f1ffb3 100644
--- a/apps/scheduler/call/core.py
+++ b/apps/scheduler/call/core.py
@@ -76,21 +76,18 @@ class CoreCall(BaseModel):
extra="allow",
)
-
def __init_subclass__(cls, input_model: type[DataBase], output_model: type[DataBase], **kwargs: Any) -> None:
"""初始化子类"""
super().__init_subclass__(**kwargs)
cls.input_model = input_model
cls.output_model = output_model
-
@classmethod
def info(cls) -> CallInfo:
"""返回Call的名称和描述"""
err = "[CoreCall] 必须手动实现info方法"
raise NotImplementedError(err)
-
@staticmethod
def _assemble_call_vars(executor: "StepExecutor") -> CallVars:
"""组装CallVars"""
@@ -120,7 +117,6 @@ class CoreCall(BaseModel):
summary=executor.task.runtime.summary,
)
-
@staticmethod
def _extract_history_variables(path: str, history: dict[str, FlowStepHistory]) -> Any:
"""
@@ -131,16 +127,14 @@ class CoreCall(BaseModel):
:return: 变量
"""
split_path = path.split("/")
+ if len(split_path) < 1:
+ err = f"[CoreCall] 路径格式错误: {path}"
+ logger.error(err)
+ return None
if split_path[0] not in history:
err = f"[CoreCall] 步骤{split_path[0]}不存在"
logger.error(err)
- raise CallError(
- message=err,
- data={
- "step_id": split_path[0],
- },
- )
-
+ return None
data = history[split_path[0]].output_data
for key in split_path[1:]:
if key not in data:
@@ -156,7 +150,6 @@ class CoreCall(BaseModel):
data = data[key]
return data
-
@classmethod
async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self:
"""实例化Call类"""
@@ -170,36 +163,30 @@ class CoreCall(BaseModel):
await obj._set_input(executor)
return obj
-
async def _set_input(self, executor: "StepExecutor") -> None:
"""获取Call的输入"""
self._sys_vars = self._assemble_call_vars(executor)
input_data = await self._init(self._sys_vars)
self.input = input_data.model_dump(by_alias=True, exclude_none=True)
-
async def _init(self, call_vars: CallVars) -> DataBase:
"""初始化Call类,并返回Call的输入"""
err = "[CoreCall] 初始化方法必须手动实现"
raise NotImplementedError(err)
-
async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
"""Call类实例的流式输出方法"""
yield CallOutputChunk(type=CallOutputType.TEXT, content="")
-
async def _after_exec(self, input_data: dict[str, Any]) -> None:
"""Call类实例的执行后方法"""
-
async def exec(self, executor: "StepExecutor", input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
"""Call类实例的执行方法"""
async for chunk in self._exec(input_data):
yield chunk
await self._after_exec(input_data)
-
async def _llm(self, messages: list[dict[str, Any]]) -> str:
"""Call可直接使用的LLM非流式调用"""
result = ""
@@ -210,7 +197,6 @@ class CoreCall(BaseModel):
self.output_tokens = llm.output_tokens
return result
-
async def _json(self, messages: list[dict[str, Any]], schema: type[BaseModel]) -> BaseModel:
"""Call可直接使用的JSON生成"""
json = FunctionLLM()
diff --git a/apps/scheduler/call/facts/facts.py b/apps/scheduler/call/facts/facts.py
index f8aebcd748d92de4109553d0f68fecac43363f58..2b9df0c6f8f1002e76bda24eb099f47a6084b2af 100644
--- a/apps/scheduler/call/facts/facts.py
+++ b/apps/scheduler/call/facts/facts.py
@@ -30,13 +30,11 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput):
answer: str = Field(description="用户输入")
-
@classmethod
def info(cls) -> CallInfo:
"""返回Call的名称和描述"""
return CallInfo(name="提取事实", description="从对话上下文和文档片段中提取事实。")
-
@classmethod
async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self:
"""初始化工具"""
@@ -51,7 +49,6 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput):
await obj._set_input(executor)
return obj
-
async def _init(self, call_vars: CallVars) -> FactsInput:
"""初始化工具"""
# 组装必要变量
@@ -65,7 +62,6 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput):
message=message,
)
-
async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
"""执行工具"""
data = FactsInput(**input_data)
@@ -83,7 +79,7 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput):
facts_obj: FactsGen = await self._json([
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": facts_prompt},
- ], FactsGen) # type: ignore[arg-type]
+ ], FactsGen) # type: ignore[arg-type]
# 更新用户画像
domain_tpl = env.from_string(DOMAIN_PROMPT)
@@ -91,7 +87,7 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput):
domain_list: DomainGen = await self._json([
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": domain_prompt},
- ], DomainGen) # type: ignore[arg-type]
+ ], DomainGen) # type: ignore[arg-type]
for domain in domain_list.keywords:
await UserDomainManager.update_user_domain_by_user_sub_and_domain_name(data.user_sub, domain)
@@ -104,7 +100,6 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput):
).model_dump(by_alias=True, exclude_none=True),
)
-
async def exec(self, executor: "StepExecutor", input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
"""执行工具"""
async for chunk in self._exec(input_data):
diff --git a/apps/scheduler/call/mcp/mcp.py b/apps/scheduler/call/mcp/mcp.py
index 661e9ada74da76e6d78a0f6cba42cef08c562335..bd0257b49fb6ea24fbbc429cb11f7b8e6be364b9 100644
--- a/apps/scheduler/call/mcp/mcp.py
+++ b/apps/scheduler/call/mcp/mcp.py
@@ -31,11 +31,10 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput):
"""MCP工具"""
mcp_list: list[str] = Field(description="MCP Server ID列表", max_length=5, min_length=1)
- max_steps: int = Field(description="最大步骤数", default=6)
+ max_steps: int = Field(description="最大步骤数", default=20)
text_output: bool = Field(description="是否将结果以文本形式返回", default=True)
to_user: bool = Field(description="是否将结果返回给用户", default=True)
-
@classmethod
def info(cls) -> CallInfo:
"""
@@ -46,7 +45,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput):
"""
return CallInfo(name="MCP", description="调用MCP Server,执行工具")
-
async def _init(self, call_vars: CallVars) -> MCPInput:
"""初始化MCP"""
# 获取MCP交互类
@@ -63,7 +61,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput):
return MCPInput(avaliable_tools=avaliable_tools, max_steps=self.max_steps)
-
async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
"""执行MCP"""
# 生成计划
@@ -80,7 +77,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput):
async for chunk in self._generate_answer():
yield chunk
-
async def _generate_plan(self) -> AsyncGenerator[CallOutputChunk, None]:
"""生成执行计划"""
# 开始提示
@@ -103,7 +99,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput):
data=self._plan.model_dump(),
)
-
async def _execute_plan_item(self, plan_item: MCPPlanItem) -> AsyncGenerator[CallOutputChunk, None]:
"""执行单个计划项"""
# 判断是否为Final
@@ -141,7 +136,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput):
},
)
-
async def _generate_answer(self) -> AsyncGenerator[CallOutputChunk, None]:
"""生成总结"""
# 提示开始总结
@@ -163,7 +157,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput):
).model_dump(),
)
-
def _create_output(
self,
text: str,
diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py
index 4f8e1010cc0bd88f22e050778bce236d7d3515e0..d24e1661e9dd7facc4bf590b24409052b6ec9104 100644
--- a/apps/scheduler/call/slot/slot.py
+++ b/apps/scheduler/call/slot/slot.py
@@ -32,13 +32,11 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput):
facts: list[str] = Field(description="事实信息", default=[])
step_num: int = Field(description="历史步骤数", default=1)
-
@classmethod
def info(cls) -> CallInfo:
"""返回Call的名称和描述"""
return CallInfo(name="参数自动填充", description="根据步骤历史,自动填充参数")
-
async def _llm_slot_fill(self, remaining_schema: dict[str, Any]) -> tuple[str, dict[str, Any]]:
"""使用大模型填充参数;若大模型解析度足够,则直接返回结果"""
env = SandboxedEnvironment(
@@ -106,7 +104,6 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput):
await obj._set_input(executor)
return obj
-
async def _init(self, call_vars: CallVars) -> SlotInput:
"""初始化"""
self._flow_history = []
@@ -126,7 +123,6 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput):
remaining_schema=remaining_schema,
)
-
async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
"""执行参数填充"""
data = SlotInput(**input_data)
diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py
index f6814dd364636d304382dc258f87fdd9f06eb9a8..3997d89ad039a056ec38ce1cbb5cebfa8fb39671 100644
--- a/apps/scheduler/executor/agent.py
+++ b/apps/scheduler/executor/agent.py
@@ -1,40 +1,510 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""MCP Agent执行器"""
+import anyio
import logging
-
+import uuid
from pydantic import Field
-
+from typing import Any
+from mcp.types import TextContent
+from apps.llm.patterns.rewrite import QuestionRewrite
+from apps.llm.reasoning import ReasoningLLM
from apps.scheduler.executor.base import BaseExecutor
-from apps.scheduler.mcp_agent.agent.mcp import MCPAgent
-
+from apps.schemas.enum_var import EventType, SpecialCallType, FlowStatus, StepStatus
+from apps.scheduler.mcp_agent.host import MCPHost
+from apps.scheduler.mcp_agent.plan import MCPPlanner
+from apps.scheduler.mcp_agent.select import FINAL_TOOL_ID, MCPSelector
+from apps.scheduler.pool.mcp.client import MCPClient
+from apps.schemas.mcp import (
+ GoalEvaluationResult,
+ RestartStepIndex,
+ ToolRisk,
+ ErrorType,
+ ToolExcutionErrorType,
+ MCPPlan,
+ MCPCollection,
+ MCPTool,
+ Step
+)
+from apps.scheduler.pool.mcp.pool import MCPPool
+from apps.schemas.task import ExecutorState, FlowStepHistory, StepQueueItem
+from apps.schemas.message import param
+from apps.services.task import TaskManager
+from apps.services.appcenter import AppCenterManager
+from apps.services.mcp_service import MCPServiceManager
+from apps.services.user import UserManager
logger = logging.getLogger(__name__)
class MCPAgentExecutor(BaseExecutor):
"""MCP Agent执行器"""
- question: str = Field(description="用户输入")
- max_steps: int = Field(default=10, description="最大步数")
+ max_steps: int = Field(default=20, description="最大步数")
servers_id: list[str] = Field(description="MCP server id")
agent_id: str = Field(default="", description="Agent ID")
agent_description: str = Field(default="", description="Agent描述")
+ mcp_list: list[MCPCollection] = Field(description="MCP服务器列表", default=[])
+ mcp_pool: MCPPool = Field(description="MCP池", default=MCPPool())
+ tools: dict[str, MCPTool] = Field(
+ description="MCP工具列表,key为tool_id", default={}
+ )
+ tool_list: list[MCPTool] = Field(
+ description="MCP工具列表,包含所有MCP工具", default=[]
+ )
+ params: param | bool | None = Field(
+ default=None, description="流执行过程中的参数补充", alias="params"
+ )
+ resoning_llm: ReasoningLLM = Field(
+ default=ReasoningLLM(),
+ description="推理大模型",
+ )
- async def run(self) -> None:
- """运行MCP Agent"""
- agent = await MCPAgent.create(
- servers_id=self.servers_id,
- max_steps=self.max_steps,
- task=self.task,
- msg_queue=self.msg_queue,
- question=self.question,
- agent_id=self.agent_id,
- description=self.agent_description,
+ async def update_tokens(self) -> None:
+ """更新令牌数"""
+ self.task.tokens.input_tokens = self.resoning_llm.input_tokens
+ self.task.tokens.output_tokens = self.resoning_llm.output_tokens
+ await TaskManager.save_task(self.task.id, self.task)
+
+ async def load_state(self) -> None:
+ """从数据库中加载FlowExecutor的状态"""
+ logger.info("[FlowExecutor] 加载Executor状态")
+ # 尝试恢复State
+ if self.task.state and self.task.state.flow_status != FlowStatus.INIT:
+ self.task.context = await TaskManager.get_context_by_task_id(self.task.id)
+
+ async def load_mcp(self) -> None:
+ """加载MCP服务器列表"""
+ logger.info("[MCPAgentExecutor] 加载MCP服务器列表")
+ # 获取MCP服务器列表
+ app = await AppCenterManager.fetch_app_data_by_id(self.agent_id)
+ mcp_ids = app.mcp_service
+ for mcp_id in mcp_ids:
+ mcp_service = await MCPServiceManager.get_mcp_service(mcp_id)
+ if self.task.ids.user_sub not in mcp_service.activated:
+ logger.warning(
+ "[MCPAgentExecutor] 用户 %s 未启用MCP %s",
+ self.task.ids.user_sub,
+ mcp_id,
+ )
+ continue
+
+ self.mcp_list.append(mcp_service)
+ await self.mcp_pool._init_mcp(mcp_id, self.task.ids.user_sub)
+ for tool in mcp_service.tools:
+ self.tools[tool.id] = tool
+ self.tool_list.extend(mcp_service.tools)
+ self.tools[FINAL_TOOL_ID] = MCPTool(
+ id=FINAL_TOOL_ID,
+ name="Final Tool",
+ description="结束流程的工具",
+ mcp_id="",
+ input_schema={}
+ )
+ self.tool_list.append(MCPTool(id=FINAL_TOOL_ID, name="Final Tool",
+ description="结束流程的工具", mcp_id="", input_schema={}))
+
+ async def get_tool_input_param(self, is_first: bool) -> None:
+ if is_first:
+ # 获取第一个输入参数
+ mcp_tool = self.tools[self.task.state.tool_id]
+ self.task.state.current_input = await MCPHost._get_first_input_params(mcp_tool, self.task.runtime.question, self.task.state.step_description, self.task)
+ else:
+ # 获取后续输入参数
+ if isinstance(self.params, param):
+ params = self.params.content
+ params_description = self.params.description
+ else:
+ params = {}
+ params_description = ""
+ mcp_tool = self.tools[self.task.state.tool_id]
+ self.task.state.current_input = await MCPHost._fill_params(mcp_tool, self.task.runtime.question, self.task.state.step_description, self.task.state.current_input, self.task.state.error_message, params, params_description)
+
+ async def confirm_before_step(self) -> None:
+ # 发送确认消息
+ mcp_tool = self.tools[self.task.state.tool_id]
+ confirm_message = await MCPPlanner.get_tool_risk(mcp_tool, self.task.state.current_input, "", self.resoning_llm)
+ await self.update_tokens()
+ await self.push_message(EventType.STEP_WAITING_FOR_START, confirm_message.model_dump(
+ exclude_none=True, by_alias=True))
+ await self.push_message(EventType.FLOW_STOP, {})
+ self.task.state.flow_status = FlowStatus.WAITING
+ self.task.state.step_status = StepStatus.WAITING
+ self.task.context.append(
+ FlowStepHistory(
+ task_id=self.task.id,
+ step_id=self.task.state.step_id,
+ step_name=self.task.state.step_name,
+ step_description=self.task.state.step_description,
+ step_status=self.task.state.step_status,
+ flow_id=self.task.state.flow_id,
+ flow_name=self.task.state.flow_name,
+ flow_status=self.task.state.flow_status,
+ input_data={},
+ output_data={},
+ ex_data=confirm_message.model_dump(exclude_none=True, by_alias=True),
+ )
)
+ async def run_step(self):
+ """执行步骤"""
+ self.task.state.flow_status = FlowStatus.RUNNING
+ self.task.state.step_status = StepStatus.RUNNING
+ mcp_tool = self.tools[self.task.state.tool_id]
+ mcp_client = (await self.mcp_pool.get(mcp_tool.mcp_id, self.task.ids.user_sub))
+ try:
+ output_params = await mcp_client.call_tool(mcp_tool.name, self.task.state.current_input)
+ except anyio.ClosedResourceError as e:
+ import traceback
+ logger.error("[MCPAgentExecutor] MCP客户端连接已关闭: %s, 错误: %s", mcp_tool.mcp_id, traceback.format_exc())
+ await self.mcp_pool.stop(mcp_tool.mcp_id, self.task.ids.user_sub)
+ await self.mcp_pool._init_mcp(mcp_tool.mcp_id, self.task.ids.user_sub)
+ logger.error("[MCPAgentExecutor] MCP客户端连接已关闭: %s, 错误: %s", mcp_tool.mcp_id, str(e))
+ self.task.state.step_status = StepStatus.ERROR
+ return
+ except Exception as e:
+ import traceback
+ logger.exception("[MCPAgentExecutor] 执行步骤 %s 时发生错误: %s", mcp_tool.name, traceback.format_exc())
+ self.task.state.step_status = StepStatus.ERROR
+ self.task.state.error_message = str(e)
+ return
+ if output_params.isError:
+ err = ""
+ for output in output_params.content:
+ if isinstance(output, TextContent):
+ err += output.text
+ self.task.state.step_status = StepStatus.ERROR
+ self.task.state.error_message = err
+ return
+ message = ""
+ for output in output_params.content:
+ if isinstance(output, TextContent):
+ message += output.text
+ output_params = {
+ "message": message,
+ }
+
+ await self.update_tokens()
+ await self.push_message(
+ EventType.STEP_INPUT,
+ self.task.state.current_input
+ )
+ await self.push_message(
+ EventType.STEP_OUTPUT,
+ output_params
+ )
+ self.task.context.append(
+ FlowStepHistory(
+ task_id=self.task.id,
+ step_id=self.task.state.step_id,
+ step_name=self.task.state.step_name,
+ step_description=self.task.state.step_description,
+ step_status=StepStatus.SUCCESS,
+ flow_id=self.task.state.flow_id,
+ flow_name=self.task.state.flow_name,
+ flow_status=self.task.state.flow_status,
+ input_data=self.task.state.current_input,
+ output_data=output_params,
+ )
+ )
+ self.task.state.step_status = StepStatus.SUCCESS
+
+ async def generate_params_with_null(self) -> None:
+ """生成参数补充"""
+ mcp_tool = self.tools[self.task.state.tool_id]
+ params_with_null = await MCPPlanner.get_missing_param(
+ mcp_tool,
+ self.task.state.current_input,
+ self.task.state.error_message,
+ self.resoning_llm
+ )
+ await self.update_tokens()
+ error_message = await MCPPlanner.change_err_message_to_description(
+ error_message=self.task.state.error_message,
+ tool=mcp_tool,
+ input_params=self.task.state.current_input,
+ reasoning_llm=self.resoning_llm
+ )
+ await self.push_message(
+ EventType.STEP_WAITING_FOR_PARAM,
+ data={
+ "message": error_message,
+ "params": params_with_null
+ }
+ )
+ await self.push_message(
+ EventType.FLOW_STOP,
+ data={}
+ )
+ self.task.state.flow_status = FlowStatus.WAITING
+ self.task.state.step_status = StepStatus.PARAM
+ self.task.context.append(
+ FlowStepHistory(
+ task_id=self.task.id,
+ step_id=self.task.state.step_id,
+ step_name=self.task.state.step_name,
+ step_description=self.task.state.step_description,
+ step_status=self.task.state.step_status,
+ flow_id=self.task.state.flow_id,
+ flow_name=self.task.state.flow_name,
+ flow_status=self.task.state.flow_status,
+ input_data={},
+ output_data={},
+ ex_data={
+ "message": error_message,
+ "params": params_with_null
+ }
+ )
+ )
+
+ async def get_next_step(self) -> None:
+ if self.task.state.step_cnt < self.max_steps:
+ self.task.state.step_cnt += 1
+ history = await MCPHost.assemble_memory(self.task)
+ max_retry = 3
+ step = None
+ for i in range(max_retry):
+ step = await MCPPlanner.create_next_step(self.task.runtime.question, history, self.tool_list)
+ if step.tool_id in self.tools.keys():
+ break
+ if step is None or step.tool_id not in self.tools.keys():
+ step = Step(
+ tool_id=FINAL_TOOL_ID,
+ description=FINAL_TOOL_ID
+ )
+ tool_id = step.tool_id
+ if tool_id == FINAL_TOOL_ID:
+ step_name = FINAL_TOOL_ID
+ else:
+ step_name = self.tools[tool_id].name
+ step_description = step.description
+ self.task.state.step_id = str(uuid.uuid4())
+ self.task.state.tool_id = tool_id
+ self.task.state.step_name = step_name
+ self.task.state.step_description = step_description
+ self.task.state.step_status = StepStatus.INIT
+ self.task.state.current_input = {}
+ else:
+ # 没有下一步了,结束流程
+ self.task.state.tool_id = FINAL_TOOL_ID
+ return
+
+ async def error_handle_after_step(self) -> None:
+ """步骤执行失败后的错误处理"""
+ self.task.state.step_status = StepStatus.ERROR
+ self.task.state.flow_status = FlowStatus.ERROR
+ await self.push_message(
+ EventType.FLOW_FAILED,
+ data={}
+ )
+ if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id:
+ del self.task.context[-1]
+ self.task.context.append(
+ FlowStepHistory(
+ task_id=self.task.id,
+ step_id=self.task.state.step_id,
+ step_name=self.task.state.step_name,
+ step_description=self.task.state.step_description,
+ step_status=self.task.state.step_status,
+ flow_id=self.task.state.flow_id,
+ flow_name=self.task.state.flow_name,
+ flow_status=self.task.state.flow_status,
+ input_data={},
+ output_data={},
+ )
+ )
+
+ async def work(self) -> None:
+ """执行当前步骤"""
+ if self.task.state.step_status == StepStatus.INIT:
+ await self.push_message(
+ EventType.STEP_INIT,
+ data={}
+ )
+ await self.get_tool_input_param(is_first=True)
+ user_info = await UserManager.get_userinfo_by_user_sub(self.task.ids.user_sub)
+ if not user_info.auto_execute:
+ # 等待用户确认
+ await self.confirm_before_step()
+ return
+ self.task.state.step_status = StepStatus.RUNNING
+ elif self.task.state.step_status in [StepStatus.PARAM, StepStatus.WAITING, StepStatus.RUNNING]:
+ if self.task.state.step_status == StepStatus.PARAM:
+ if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id:
+ del self.task.context[-1]
+ elif self.task.state.step_status == StepStatus.WAITING:
+ if self.params:
+ if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id:
+ del self.task.context[-1]
+ else:
+ self.task.state.flow_status = FlowStatus.CANCELLED
+ self.task.state.step_status = StepStatus.CANCELLED
+ await self.push_message(
+ EventType.STEP_CANCEL,
+ data={}
+ )
+ await self.push_message(
+ EventType.FLOW_CANCEL,
+ data={}
+ )
+ if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id:
+ self.task.context[-1].step_status = StepStatus.CANCELLED
+ if self.task.state.step_status == StepStatus.PARAM:
+ await self.get_tool_input_param(is_first=False)
+ max_retry = 5
+ for i in range(max_retry):
+ if i != 0:
+ await self.get_tool_input_param(is_first=True)
+ await self.run_step()
+ if self.task.state.step_status == StepStatus.SUCCESS:
+ break
+ elif self.task.state.step_status == StepStatus.ERROR:
+ # 错误处理
+ if self.task.state.retry_times >= 3:
+ await self.error_handle_after_step()
+ else:
+ user_info = await UserManager.get_userinfo_by_user_sub(self.task.ids.user_sub)
+ if user_info.auto_execute:
+ self.push_message(
+ EventType.STEP_ERROR,
+ data={
+ "message": self.task.state.error_message,
+ }
+ )
+ if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id:
+ self.task.context[-1].step_status = StepStatus.ERROR
+ self.task.context[-1].output_data = {
+ "message": self.task.state.error_message,
+ }
+ await self.get_next_step()
+ else:
+ mcp_tool = self.tools[self.task.state.tool_id]
+ is_param_error = await MCPPlanner.is_param_error(
+ self.task.runtime.question,
+ await MCPHost.assemble_memory(self.task),
+ self.task.state.error_message,
+ mcp_tool,
+ self.task.state.step_description,
+ self.task.state.current_input,
+ )
+ if is_param_error.is_param_error:
+ # 如果是参数错误,生成参数补充
+ await self.generate_params_with_null()
+ else:
+ self.push_message(
+ EventType.STEP_ERROR,
+ data={
+ "message": self.task.state.error_message,
+ }
+ )
+ if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id:
+ self.task.context[-1].step_status = StepStatus.ERROR
+ self.task.context[-1].output_data = {
+ "message": self.task.state.error_message,
+ }
+ await self.get_next_step()
+ elif self.task.state.step_status == StepStatus.SUCCESS:
+ await self.get_next_step()
+
+ async def summarize(self) -> None:
+ async for chunk in MCPPlanner.generate_answer(
+ self.task.runtime.question,
+ (await MCPHost.assemble_memory(self.task)),
+ self.resoning_llm
+ ):
+ await self.push_message(
+ EventType.TEXT_ADD,
+ data=chunk
+ )
+ self.task.runtime.answer += chunk
+
+ async def run(self) -> None:
+ """执行MCP Agent的主逻辑"""
+ # 初始化MCP服务
+ await self.load_state()
+ await self.load_mcp()
+ if self.task.state.flow_status == FlowStatus.INIT:
+ # 初始化状态
+ try:
+ self.task.state.flow_id = str(uuid.uuid4())
+ self.task.state.flow_name = await MCPPlanner.get_flow_name(self.task.runtime.question, self.resoning_llm)
+ await TaskManager.save_task(self.task.id, self.task)
+ await self.get_next_step()
+ except Exception as e:
+ import traceback
+ logger.error("[MCPAgentExecutor] 初始化失败: %s", traceback.format_exc())
+ logger.error("[MCPAgentExecutor] 初始化失败: %s", str(e))
+ self.task.state.flow_status = FlowStatus.ERROR
+ self.task.state.error_message = str(e)
+ await self.push_message(
+ EventType.FLOW_FAILED,
+ data={}
+ )
+ return
+ self.task.state.flow_status = FlowStatus.RUNNING
+ await self.push_message(
+ EventType.FLOW_START,
+ data={}
+ )
+ if self.task.state.tool_id == FINAL_TOOL_ID:
+ # 如果已经是最后一步,直接结束
+ self.task.state.flow_status = FlowStatus.SUCCESS
+ await self.push_message(
+ EventType.FLOW_SUCCESS,
+ data={}
+ )
+ await self.summarize()
+ return
try:
- answer = await agent.run(self.question)
- self.task = agent.task
- self.task.runtime.answer = answer
+ while self.task.state.flow_status == FlowStatus.RUNNING:
+ if self.task.state.tool_id == FINAL_TOOL_ID:
+ break
+ await self.work()
+ await TaskManager.save_task(self.task.id, self.task)
+ tool_id = self.task.state.tool_id
+ if tool_id == FINAL_TOOL_ID:
+ # 如果已经是最后一步,直接结束
+ self.task.state.flow_status = FlowStatus.SUCCESS
+ self.task.state.step_status = StepStatus.SUCCESS
+ await self.push_message(
+ EventType.FLOW_SUCCESS,
+ data={}
+ )
+ await self.summarize()
except Exception as e:
- logger.error(f"Error: {str(e)}")
+ import traceback
+ logger.error("[MCPAgentExecutor] 执行过程中发生错误: %s", traceback.format_exc())
+ logger.error("[MCPAgentExecutor] 执行过程中发生错误: %s", str(e))
+ self.task.state.flow_status = FlowStatus.ERROR
+ self.task.state.error_message = str(e)
+ self.task.state.step_status = StepStatus.ERROR
+ await self.push_message(
+ EventType.STEP_ERROR,
+ data={}
+ )
+ await self.push_message(
+ EventType.FLOW_FAILED,
+ data={}
+ )
+ if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id:
+ del self.task.context[-1]
+ self.task.context.append(
+ FlowStepHistory(
+ task_id=self.task.id,
+ step_id=self.task.state.step_id,
+ step_name=self.task.state.step_name,
+ step_description=self.task.state.step_description,
+ step_status=self.task.state.step_status,
+ flow_id=self.task.state.flow_id,
+ flow_name=self.task.state.flow_name,
+ flow_status=self.task.state.flow_status,
+ input_data={},
+ output_data={},
+ )
+ )
+ finally:
+ for mcp_service in self.mcp_list:
+ try:
+ await self.mcp_pool.stop(mcp_service.id, self.task.ids.user_sub)
+ except Exception as e:
+ import traceback
+ logger.error("[MCPAgentExecutor] 停止MCP客户端时发生错误: %s", traceback.format_exc())
diff --git a/apps/scheduler/executor/base.py b/apps/scheduler/executor/base.py
index cf2f4e6838f5d8d5b6578db2750543ddfd22ee75..877341f08e776c168b618ff5cbea8e17c9cfc956 100644
--- a/apps/scheduler/executor/base.py
+++ b/apps/scheduler/executor/base.py
@@ -44,15 +44,8 @@ class BaseExecutor(BaseModel, ABC):
:param event_type: 事件类型
:param data: 消息数据,如果是FLOW_START事件且data为None,则自动构建FlowStartContent
"""
- if event_type == EventType.FLOW_START.value and isinstance(data, dict):
- data = FlowStartContent(
- question=self.question,
- params=self.task.runtime.filled,
- ).model_dump(exclude_none=True, by_alias=True)
- elif event_type == EventType.FLOW_STOP.value:
- data = {}
- elif event_type == EventType.TEXT_ADD.value and isinstance(data, str):
- data=TextAddContent(text=data).model_dump(exclude_none=True, by_alias=True)
+ if event_type == EventType.TEXT_ADD.value and isinstance(data, str):
+ data = TextAddContent(text=data).model_dump(exclude_none=True, by_alias=True)
if data is None:
data = {}
@@ -62,7 +55,7 @@ class BaseExecutor(BaseModel, ABC):
await self.msg_queue.push_output(
self.task,
event_type=event_type,
- data=data, # type: ignore[arg-type]
+ data=data, # type: ignore[arg-type]
)
@abstractmethod
diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py
index a70d0d7073c50ec2290c3fd4c797bbb3efcd4501..ebd4da8ecb6e55c5220a9cd4e0f63c72842c9828 100644
--- a/apps/scheduler/executor/flow.py
+++ b/apps/scheduler/executor/flow.py
@@ -11,7 +11,7 @@ from pydantic import Field
from apps.scheduler.call.llm.prompt import LLM_ERROR_PROMPT
from apps.scheduler.executor.base import BaseExecutor
from apps.scheduler.executor.step import StepExecutor
-from apps.schemas.enum_var import EventType, SpecialCallType, StepStatus
+from apps.schemas.enum_var import EventType, SpecialCallType, FlowStatus, StepStatus
from apps.schemas.flow import Flow, Step
from apps.schemas.request_data import RequestDataApp
from apps.schemas.task import ExecutorState, StepQueueItem
@@ -46,21 +46,25 @@ class FlowExecutor(BaseExecutor):
flow_id: str = Field(description="Flow ID")
question: str = Field(description="用户输入")
post_body_app: RequestDataApp = Field(description="请求体中的app信息")
-
+ current_step: StepQueueItem | None = Field(
+ description="当前执行的步骤",
+ default=None
+ )
async def load_state(self) -> None:
"""从数据库中加载FlowExecutor的状态"""
logger.info("[FlowExecutor] 加载Executor状态")
# 尝试恢复State
- if self.task.state:
+ if self.task.state and self.task.state.flow_status != FlowStatus.INIT:
self.task.context = await TaskManager.get_context_by_task_id(self.task.id)
else:
# 创建ExecutorState
self.task.state = ExecutorState(
flow_id=str(self.flow_id),
flow_name=self.flow.name,
+ flow_status=FlowStatus.RUNNING,
description=str(self.flow.description),
- status=StepStatus.RUNNING,
+ step_status=StepStatus.RUNNING,
app_id=str(self.post_body_app.app_id),
step_id="start",
step_name="开始",
@@ -70,14 +74,13 @@ class FlowExecutor(BaseExecutor):
self._reached_end: bool = False
self.step_queue: deque[StepQueueItem] = deque()
-
- async def _invoke_runner(self, queue_item: StepQueueItem) -> None:
+ async def _invoke_runner(self) -> None:
"""单一Step执行"""
# 创建步骤Runner
step_runner = StepExecutor(
msg_queue=self.msg_queue,
task=self.task,
- step=queue_item,
+ step=self.current_step,
background=self.background,
question=self.question,
)
@@ -85,23 +88,21 @@ class FlowExecutor(BaseExecutor):
# 初始化步骤
await step_runner.init()
# 运行Step
- await step_runner.run()
+ await step_runner.run()
# 更新Task(已存过库)
self.task = step_runner.task
-
async def _step_process(self) -> None:
"""执行当前queue里面的所有步骤(在用户看来是单一Step)"""
while True:
try:
- queue_item = self.step_queue.pop()
+ self.current_step = self.step_queue.pop()
except IndexError:
break
# 执行Step
- await self._invoke_runner(queue_item)
-
+ await self._invoke_runner()
async def _find_next_id(self, step_id: str) -> list[str]:
"""查找下一个节点"""
@@ -111,14 +112,22 @@ class FlowExecutor(BaseExecutor):
next_ids += [edge.edge_to]
return next_ids
-
async def _find_flow_next(self) -> list[StepQueueItem]:
"""在当前步骤执行前,尝试获取下一步"""
# 如果当前步骤为结束,则直接返回
- if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type]
+ if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type]
return []
-
- next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type]
+ if self.current_step.step.type == SpecialCallType.CHOICE.value:
+ # 如果是choice节点,获取分支ID
+ branch_id = self.task.context[-1].output_data["branch_id"]
+ if branch_id:
+ next_steps = await self._find_next_id(self.task.state.step_id + "." + branch_id)
+ logger.info("[FlowExecutor] 分支ID:%s", branch_id)
+ else:
+ logger.warning("[FlowExecutor] 没有找到分支ID,返回空列表")
+ return []
+ else:
+ next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type]
# 如果step没有任何出边,直接跳到end
if not next_steps:
return [
@@ -137,7 +146,6 @@ class FlowExecutor(BaseExecutor):
for next_step in next_steps
]
-
async def run(self) -> None:
"""
运行流,返回各步骤结果,直到无法继续执行
@@ -150,8 +158,8 @@ class FlowExecutor(BaseExecutor):
# 获取首个步骤
first_step = StepQueueItem(
- step_id=self.task.state.step_id, # type: ignore[arg-type]
- step=self.flow.steps[self.task.state.step_id], # type: ignore[arg-type]
+ step_id=self.task.state.step_id, # type: ignore[arg-type]
+ step=self.flow.steps[self.task.state.step_id], # type: ignore[arg-type]
)
# 头插开始前的系统步骤,并执行
@@ -166,11 +174,12 @@ class FlowExecutor(BaseExecutor):
# 插入首个步骤
self.step_queue.append(first_step)
-
+ self.task.state.flow_status = FlowStatus.RUNNING # type: ignore[arg-type]
# 运行Flow(未达终点)
+ is_error = False
while not self._reached_end:
# 如果当前步骤出错,执行错误处理步骤
- if self.task.state.status == StepStatus.ERROR: # type: ignore[arg-type]
+ if self.task.state.step_status == StepStatus.ERROR: # type: ignore[arg-type]
logger.warning("[FlowExecutor] Executor出错,执行错误处理步骤")
self.step_queue.clear()
self.step_queue.appendleft(StepQueueItem(
@@ -183,13 +192,14 @@ class FlowExecutor(BaseExecutor):
params={
"user_prompt": LLM_ERROR_PROMPT.replace(
"{{ error_info }}",
- self.task.state.error_info["err_msg"], # type: ignore[arg-type]
+ self.task.state.error_info["err_msg"], # type: ignore[arg-type]
),
},
),
enable_filling=False,
to_user=False,
))
+ is_error = True
# 错误处理后结束
self._reached_end = True
@@ -204,6 +214,12 @@ class FlowExecutor(BaseExecutor):
for step in next_step:
self.step_queue.append(step)
+ # 更新Task状态
+ if is_error:
+ self.task.state.flow_status = FlowStatus.ERROR # type: ignore[arg-type]
+ else:
+ self.task.state.flow_status = FlowStatus.SUCCESS # type: ignore[arg-type]
+
# 尾插运行结束后的系统步骤
for step in FIXED_STEPS_AFTER_END:
self.step_queue.append(StepQueueItem(
@@ -215,4 +231,7 @@ class FlowExecutor(BaseExecutor):
# FlowStop需要返回总时间,需要倒推最初的开始时间(当前时间减去当前已用总时间)
self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.full_time
# 推送Flow停止消息
- await self.push_message(EventType.FLOW_STOP.value)
+ if is_error:
+ await self.push_message(EventType.FLOW_FAILED.value)
+ else:
+ await self.push_message(EventType.FLOW_SUCCESS.value)
diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py
index 6b3451fa9ccd92b8f9b2d497f3983a64ff3a6981..5a95e407989f4dfc9c26c2157ccbec5d3ecdad21 100644
--- a/apps/scheduler/executor/step.py
+++ b/apps/scheduler/executor/step.py
@@ -86,8 +86,8 @@ class StepExecutor(BaseExecutor):
logger.info("[StepExecutor] 初始化步骤 %s", self.step.step.name)
# State写入ID和运行状态
- self.task.state.step_id = self.step.step_id # type: ignore[arg-type]
- self.task.state.step_name = self.step.step.name # type: ignore[arg-type]
+ self.task.state.step_id = self.step.step_id # type: ignore[arg-type]
+ self.task.state.step_name = self.step.step.name # type: ignore[arg-type]
# 获取并验证Call类
node_id = self.step.step.node
@@ -119,7 +119,6 @@ class StepExecutor(BaseExecutor):
logger.exception("[StepExecutor] 初始化Call失败")
raise
-
async def _run_slot_filling(self) -> None:
"""运行自动参数填充;相当于特殊Step,但是不存库"""
# 判断是否需要进行自动参数填充
@@ -127,13 +126,13 @@ class StepExecutor(BaseExecutor):
return
# 暂存旧数据
- current_step_id = self.task.state.step_id # type: ignore[arg-type]
- current_step_name = self.task.state.step_name # type: ignore[arg-type]
+ current_step_id = self.task.state.step_id # type: ignore[arg-type]
+ current_step_name = self.task.state.step_name # type: ignore[arg-type]
# 更新State
- self.task.state.step_id = str(uuid.uuid4()) # type: ignore[arg-type]
- self.task.state.step_name = "自动参数填充" # type: ignore[arg-type]
- self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type]
+ self.task.state.step_id = str(uuid.uuid4()) # type: ignore[arg-type]
+ self.task.state.step_name = "自动参数填充" # type: ignore[arg-type]
+ self.task.state.step_status = StepStatus.RUNNING # type: ignore[arg-type]
self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2)
# 初始化填参
@@ -156,21 +155,20 @@ class StepExecutor(BaseExecutor):
# 如果没有填全,则状态设置为待填参
if result.remaining_schema:
- self.task.state.status = StepStatus.PARAM # type: ignore[arg-type]
+ self.task.state.step_status = StepStatus.PARAM # type: ignore[arg-type]
else:
- self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type]
+ self.task.state.step_status = StepStatus.SUCCESS # type: ignore[arg-type]
await self.push_message(EventType.STEP_OUTPUT.value, result.model_dump(by_alias=True, exclude_none=True))
# 更新输入
self.obj.input.update(result.slot_data)
# 恢复State
- self.task.state.step_id = current_step_id # type: ignore[arg-type]
- self.task.state.step_name = current_step_name # type: ignore[arg-type]
+ self.task.state.step_id = current_step_id # type: ignore[arg-type]
+ self.task.state.step_name = current_step_name # type: ignore[arg-type]
self.task.tokens.input_tokens += self.obj.tokens.input_tokens
self.task.tokens.output_tokens += self.obj.tokens.output_tokens
-
async def _process_chunk(
self,
iterator: AsyncGenerator[CallOutputChunk, None],
@@ -202,7 +200,6 @@ class StepExecutor(BaseExecutor):
return content
-
async def run(self) -> None:
"""运行单个步骤"""
self.validate_flow_state(self.task)
@@ -212,7 +209,7 @@ class StepExecutor(BaseExecutor):
await self._run_slot_filling()
# 更新状态
- self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type]
+ self.task.state.step_status = StepStatus.RUNNING # type: ignore[arg-type]
self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2)
# 推送输入
await self.push_message(EventType.STEP_INPUT.value, self.obj.input)
@@ -224,22 +221,22 @@ class StepExecutor(BaseExecutor):
content = await self._process_chunk(iterator, to_user=self.obj.to_user)
except Exception as e:
logger.exception("[StepExecutor] 运行步骤失败,进行异常处理步骤")
- self.task.state.status = StepStatus.ERROR # type: ignore[arg-type]
+ self.task.state.step_status = StepStatus.ERROR # type: ignore[arg-type]
await self.push_message(EventType.STEP_OUTPUT.value, {})
if isinstance(e, CallError):
- self.task.state.error_info = { # type: ignore[arg-type]
+ self.task.state.error_info = { # type: ignore[arg-type]
"err_msg": e.message,
"data": e.data,
}
else:
- self.task.state.error_info = { # type: ignore[arg-type]
+ self.task.state.error_info = { # type: ignore[arg-type]
"err_msg": str(e),
"data": {},
}
return
# 更新执行状态
- self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type]
+ self.task.state.step_status = StepStatus.SUCCESS # type: ignore[arg-type]
self.task.tokens.input_tokens += self.obj.tokens.input_tokens
self.task.tokens.output_tokens += self.obj.tokens.output_tokens
self.task.tokens.full_time += round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.time
@@ -253,16 +250,17 @@ class StepExecutor(BaseExecutor):
# 更新context
history = FlowStepHistory(
task_id=self.task.id,
- flow_id=self.task.state.flow_id, # type: ignore[arg-type]
- flow_name=self.task.state.flow_name, # type: ignore[arg-type]
+ flow_id=self.task.state.flow_id, # type: ignore[arg-type]
+ flow_name=self.task.state.flow_name, # type: ignore[arg-type]
+ flow_status=self.task.state.flow_status, # type: ignore[arg-type]
step_id=self.step.step_id,
step_name=self.step.step.name,
step_description=self.step.step.description,
- status=self.task.state.status, # type: ignore[arg-type]
+ step_status=self.task.state.step_status,
input_data=self.obj.input,
output_data=output_data,
)
- self.task.context.append(history.model_dump(exclude_none=True, by_alias=True))
+ self.task.context.append(history)
# 推送输出
await self.push_message(EventType.STEP_OUTPUT.value, output_data)
diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py
index 78aa7bc3ee869e8710e1fb02a2d9fb438d04be34..8e11083900c20c29fcaf71b6535f3b442fc2313a 100644
--- a/apps/scheduler/mcp/host.py
+++ b/apps/scheduler/mcp/host.py
@@ -40,7 +40,6 @@ class MCPHost:
lstrip_blocks=True,
)
-
async def get_client(self, mcp_id: str) -> MCPClient | None:
"""获取MCP客户端"""
mongo = MongoDB()
@@ -59,7 +58,6 @@ class MCPHost:
logger.warning("用户 %s 的MCP %s 没有运行中的实例,请检查环境", self._user_sub, mcp_id)
return None
-
async def assemble_memory(self) -> str:
"""组装记忆"""
task = await TaskManager.get_task_by_task_id(self._task_id)
@@ -69,7 +67,7 @@ class MCPHost:
context_list = []
for ctx_id in self._context_list:
- context = next((ctx for ctx in task.context if ctx["_id"] == ctx_id), None)
+ context = next((ctx for ctx in task.context if ctx.id == ctx_id), None)
if not context:
continue
context_list.append(context)
@@ -78,7 +76,6 @@ class MCPHost:
context_list=context_list,
)
-
async def _save_memory(
self,
tool: MCPTool,
@@ -105,11 +102,12 @@ class MCPHost:
task_id=self._task_id,
flow_id=self._runtime_id,
flow_name=self._runtime_name,
+ flow_status=StepStatus.RUNNING,
step_id=tool.name,
step_name=tool.name,
# description是规划的实际内容
step_description=plan_item.content,
- status=StepStatus.SUCCESS,
+ step_status=StepStatus.SUCCESS,
input_data=input_data,
output_data=output_data,
)
@@ -120,12 +118,11 @@ class MCPHost:
logger.error("任务 %s 不存在", self._task_id)
return {}
self._context_list.append(context.id)
- task.context.append(context.model_dump(by_alias=True, exclude_none=True))
+ task.context.append(context.model_dump(exclude_none=True, by_alias=True))
await TaskManager.save_task(self._task_id, task)
return output_data
-
async def _fill_params(self, tool: MCPTool, query: str) -> dict[str, Any]:
"""填充工具参数"""
# 更清晰的输入·指令,这样可以调用generate
@@ -146,7 +143,6 @@ class MCPHost:
)
return await json_generator.generate()
-
async def call_tool(self, tool: MCPTool, plan_item: MCPPlanItem) -> list[dict[str, Any]]:
"""调用工具"""
# 拿到Client
@@ -170,7 +166,6 @@ class MCPHost:
return processed_result
-
async def get_tool_list(self, mcp_id_list: list[str]) -> list[MCPTool]:
"""获取工具列表"""
mongo = MongoDB()
diff --git a/apps/scheduler/mcp/prompt.py b/apps/scheduler/mcp/prompt.py
index b322fb0883e8ed935243389cb86066845a549631..b906fb8f75dfafb2546aaef6eb36370645c57598 100644
--- a/apps/scheduler/mcp/prompt.py
+++ b/apps/scheduler/mcp/prompt.py
@@ -72,7 +72,8 @@ CREATE_PLAN = dedent(r"""
2. 计划中的每一个步骤必须且只能使用一个工具。
3. 计划中的步骤必须具有清晰和逻辑的步骤,没有冗余或不必要的步骤。
4. 计划中的最后一步必须是Final工具,以确保计划执行结束。
-
+ 5.生成的计划必须要覆盖用户的目标,不能遗漏任何用户目标中的内容。
+
# 生成计划时的注意事项:
- 每一条计划包含3个部分:
@@ -233,8 +234,8 @@ FINAL_ANSWER = dedent(r"""
MEMORY_TEMPLATE = dedent(r"""
{% for ctx in context_list %}
- 第{{ loop.index }}步:{{ ctx.step_description }}
- - 调用工具 `{{ ctx.step_id }}`,并提供参数 `{{ ctx.input_data }}`
- - 执行状态:{{ ctx.status }}
- - 得到数据:`{{ ctx.output_data }}`
+ - 调用工具 `{{ ctx.step_name }}`,并提供参数 `{{ ctx.input_data|tojson }}`。
+ - 执行状态:{{ ctx.step_status }}
+ - 得到数据:`{{ ctx.output_data|tojson }}`
{% endfor %}
""")
diff --git a/apps/scheduler/mcp/select.py b/apps/scheduler/mcp/select.py
index 2ff5034471c5e9c38f166c6187b76dfb4596f734..f3d6e0d4ae3aea9508d48f33499ac2bcf3de98a9 100644
--- a/apps/scheduler/mcp/select.py
+++ b/apps/scheduler/mcp/select.py
@@ -39,7 +39,6 @@ class MCPSelector:
sql += f"'{mcp_id}', "
return sql.rstrip(", ") + ")"
-
async def _get_top_mcp_by_embedding(
self,
query: str,
@@ -72,7 +71,6 @@ class MCPSelector:
}])
return llm_mcp_list
-
async def _get_mcp_by_llm(
self,
query: str,
@@ -100,7 +98,6 @@ class MCPSelector:
# 使用小模型提取JSON
return await self._call_function_mcp(result, mcp_ids)
-
async def _call_reasoning(self, prompt: str) -> str:
"""调用大模型进行推理"""
logger.info("[MCPHelper] 调用推理大模型")
@@ -116,7 +113,6 @@ class MCPSelector:
self.output_tokens += llm.output_tokens
return result
-
async def _call_function_mcp(self, reasoning_result: str, mcp_ids: list[str]) -> MCPSelectResult:
"""调用结构化输出小模型提取JSON"""
logger.info("[MCPHelper] 调用结构化输出小模型")
@@ -136,7 +132,6 @@ class MCPSelector:
raise
return result
-
async def select_top_mcp(
self,
query: str,
@@ -153,7 +148,6 @@ class MCPSelector:
# 通过LLM选择最合适的
return await self._get_mcp_by_llm(query, llm_mcp_list, mcp_list)
-
@staticmethod
async def select_top_tool(query: str, mcp_list: list[str], top_n: int = 10) -> list[MCPTool]:
"""选择最合适的工具"""
diff --git a/apps/scheduler/mcp_agent/__init__.py b/apps/scheduler/mcp_agent/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..12f5cb68c12e4d19d830a8155eaeb0851fce897d
--- /dev/null
+++ b/apps/scheduler/mcp_agent/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""Scheduler MCP 模块"""
+
+from apps.scheduler.mcp.host import MCPHost
+from apps.scheduler.mcp.plan import MCPPlanner
+from apps.scheduler.mcp.select import MCPSelector
+
+__all__ = ["MCPHost", "MCPPlanner", "MCPSelector"]
diff --git a/apps/scheduler/mcp_agent/agent/base.py b/apps/scheduler/mcp_agent/agent/base.py
deleted file mode 100644
index eccb58a9ce8466f67ac76040a53469edac455109..0000000000000000000000000000000000000000
--- a/apps/scheduler/mcp_agent/agent/base.py
+++ /dev/null
@@ -1,196 +0,0 @@
-"""MCP Agent基类"""
-import logging
-from abc import ABC, abstractmethod
-from contextlib import asynccontextmanager
-
-from pydantic import BaseModel, Field, model_validator
-
-from apps.common.queue import MessageQueue
-from apps.schemas.enum_var import AgentState
-from apps.schemas.task import Task
-from apps.llm.reasoning import ReasoningLLM
-from apps.scheduler.mcp_agent.schema import Memory, Message, Role
-from apps.services.activity import Activity
-
-logger = logging.getLogger(__name__)
-
-
-class BaseAgent(BaseModel, ABC):
- """
- 用于管理代理状态和执行的抽象基类。
-
- 为状态转换、内存管理、
- 以及分步执行循环。子类必须实现`step`方法。
- """
-
- msg_queue: MessageQueue
- task: Task
- name: str = Field(..., description="Agent名称")
- agent_id: str = Field(default="", description="Agent ID")
- description: str = Field(default="", description="Agent描述")
- question: str
- # Prompts
- next_step_prompt: str | None = Field(
- None, description="判断下一步动作的提示"
- )
-
- # Dependencies
- llm: ReasoningLLM = Field(default_factory=ReasoningLLM, description="大模型实例")
- memory: Memory = Field(default_factory=Memory, description="Agent记忆库")
- state: AgentState = Field(
- default=AgentState.IDLE, description="Agent状态"
- )
- servers_id: list[str] = Field(default_factory=list, description="MCP server id")
-
- # Execution control
- max_steps: int = Field(default=10, description="终止前的最大步长")
- current_step: int = Field(default=0, description="执行中的当前步骤")
-
- duplicate_threshold: int = 2
-
- user_prompt: str = r"""
- 当前步骤:{step} 工具输出结果:{result}
- 请总结当前正在执行的步骤和对应的工具输出结果,内容包括当前步骤是多少,执行的工具是什么,输出是什么。
- 最终以报告的形式展示。
- 如果工具输出结果中执行的工具为terminate,请按照状态输出本次交互过程最终结果并完成对整个报告的总结,不需要输出你的分析过程。
- """
- """用户提示词"""
-
- class Config:
- arbitrary_types_allowed = True
- extra = "allow" # Allow extra fields for flexibility in subclasses
-
- @model_validator(mode="after")
- def initialize_agent(self) -> "BaseAgent":
- """初始化Agent"""
- if self.llm is None or not isinstance(self.llm, ReasoningLLM):
- self.llm = ReasoningLLM()
- if not isinstance(self.memory, Memory):
- self.memory = Memory()
- return self
-
- @asynccontextmanager
- async def state_context(self, new_state: AgentState):
- """
- Agent状态转换上下文管理器
-
- Args:
- new_state: 要转变的状态
-
- :return: None
- :raise ValueError: 如果new_state无效
- """
- if not isinstance(new_state, AgentState):
- raise ValueError(f"无效状态: {new_state}")
-
- previous_state = self.state
- self.state = new_state
- try:
- yield
- except Exception as e:
- self.state = AgentState.ERROR # Transition to ERROR on failure
- raise e
- finally:
- self.state = previous_state # Revert to previous state
-
- def update_memory(
- self,
- role: Role,
- content: str,
- **kwargs,
- ) -> None:
- """添加信息到Agent的memory中"""
- message_map = {
- "user": Message.user_message,
- "system": Message.system_message,
- "assistant": Message.assistant_message,
- "tool": lambda content, **kw: Message.tool_message(content, **kw),
- }
-
- if role not in message_map:
- raise ValueError(f"不支持的消息角色: {role}")
-
- # Create message with appropriate parameters based on role
- kwargs = {**(kwargs if role == "tool" else {})}
- self.memory.add_message(message_map[role](content, **kwargs))
-
- async def run(self, request: str | None = None) -> str:
- """异步执行Agent的主循环"""
- self.task.runtime.question = request
- if self.state != AgentState.IDLE:
- raise RuntimeError(f"无法从以下状态运行智能体: {self.state}")
-
- if request:
- self.update_memory("user", request)
-
- results: list[str] = []
- async with self.state_context(AgentState.RUNNING):
- while (
- self.current_step < self.max_steps and self.state != AgentState.FINISHED
- ):
- if not await Activity.is_active(self.task.ids.user_sub):
- logger.info("用户终止会话,任务停止!")
- return ""
- self.current_step += 1
- logger.info(f"执行步骤{self.current_step}/{self.max_steps}")
- step_result = await self.step()
-
- # Check for stuck state
- if self.is_stuck():
- self.handle_stuck_state()
- result = f"Step {self.current_step}: {step_result}"
- results.append(result)
-
- if self.current_step >= self.max_steps:
- self.current_step = 0
- self.state = AgentState.IDLE
- result = f"任务终止: 已达到最大步数 ({self.max_steps})"
- await self.msg_queue.push_output(
- self.task,
- event_type="text.add",
- data={"text": result}, # type: ignore[arg-type]
- )
- results.append(result)
- return "\n".join(results) if results else "未执行任何步骤"
-
- @abstractmethod
- async def step(self) -> str:
- """
- 执行代理工作流程中的单个步骤。
-
- 必须由子类实现,以定义具体的行为。
- """
-
- def handle_stuck_state(self):
- """通过添加更改策略的提示来处理卡住状态"""
- stuck_prompt = "\
- 观察到重复响应。考虑新策略,避免重复已经尝试过的无效路径"
- self.next_step_prompt = f"{stuck_prompt}\n{self.next_step_prompt}"
- logger.warning(f"检测到智能体处于卡住状态。新增提示:{stuck_prompt}")
-
- def is_stuck(self) -> bool:
- """通过检测重复内容来检查代理是否卡在循环中"""
- if len(self.memory.messages) < 2:
- return False
-
- last_message = self.memory.messages[-1]
- if not last_message.content:
- return False
-
- duplicate_count = sum(
- 1
- for msg in reversed(self.memory.messages[:-1])
- if msg.role == "assistant" and msg.content == last_message.content
- )
-
- return duplicate_count >= self.duplicate_threshold
-
- @property
- def messages(self) -> list[Message]:
- """从Agent memory中检索消息列表"""
- return self.memory.messages
-
- @messages.setter
- def messages(self, value: list[Message]) -> None:
- """设置Agent memory的消息列表"""
- self.memory.messages = value
diff --git a/apps/scheduler/mcp_agent/agent/mcp.py b/apps/scheduler/mcp_agent/agent/mcp.py
deleted file mode 100644
index 378da368aca02d0d628352fcc4816b98b2921d01..0000000000000000000000000000000000000000
--- a/apps/scheduler/mcp_agent/agent/mcp.py
+++ /dev/null
@@ -1,81 +0,0 @@
-"""MCP Agent"""
-import logging
-
-from pydantic import Field
-
-from apps.scheduler.mcp.host import MCPHost
-from apps.scheduler.mcp_agent.agent.toolcall import ToolCallAgent
-from apps.scheduler.mcp_agent.tool import Terminate, ToolCollection
-
-logger = logging.getLogger(__name__)
-
-
-class MCPAgent(ToolCallAgent):
- """
- 用于与MCP(模型上下文协议)服务器交互。
-
- 使用SSE或stdio传输连接到MCP服务器
- 并使服务器的工具
- """
-
- name: str = "MCPAgent"
- description: str = "一个多功能的智能体,能够使用多种工具(包括基于MCP的工具)解决各种任务"
-
- # Add general-purpose tools to the tool collection
- available_tools: ToolCollection = Field(
- default_factory=lambda: ToolCollection(
- Terminate(),
- ),
- )
-
- special_tool_names: list[str] = Field(default_factory=lambda: [Terminate().name])
-
- _initialized: bool = False
-
- @classmethod
- async def create(cls, **kwargs) -> "MCPAgent": # noqa: ANN003
- """创建并初始化MCP Agent实例"""
- instance = cls(**kwargs)
- await instance.initialize_mcp_servers()
- instance._initialized = True
- return instance
-
- async def initialize_mcp_servers(self) -> None:
- """初始化与已配置的MCP服务器的连接"""
- mcp_host = MCPHost(
- self.task.ids.user_sub,
- self.task.id,
- self.agent_id,
- self.description,
- )
- mcps = {}
- for mcp_id in self.servers_id:
- client = await mcp_host.get_client(mcp_id)
- if client:
- mcps[mcp_id] = client
-
- for mcp_id, mcp_client in mcps.items():
- new_tools = []
- for tool in mcp_client.tools:
- original_name = tool.name
- # Always prefix with server_id to ensure uniqueness
- tool_name = f"mcp_{mcp_id}_{original_name}"
-
- server_tool = MCPClientTool(
- name=tool_name,
- description=tool.description,
- parameters=tool.inputSchema,
- session=mcp_client.session,
- server_id=mcp_id,
- original_name=original_name,
- )
- new_tools.append(server_tool)
- self.available_tools.add_tools(*new_tools)
-
- async def think(self) -> bool:
- """使用适当的上下文处理当前状态并决定下一步操作"""
- if not self._initialized:
- await self.initialize_mcp_servers()
- self._initialized = True
-
- return await super().think()
diff --git a/apps/scheduler/mcp_agent/agent/react.py b/apps/scheduler/mcp_agent/agent/react.py
deleted file mode 100644
index b56efd8b195eb36c4d5718711cc5f07b5a49812f..0000000000000000000000000000000000000000
--- a/apps/scheduler/mcp_agent/agent/react.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from abc import ABC, abstractmethod
-
-from pydantic import Field
-
-from apps.schemas.enum_var import AgentState
-from apps.llm.reasoning import ReasoningLLM
-from apps.scheduler.mcp_agent.agent.base import BaseAgent
-from apps.scheduler.mcp_agent.schema import Memory
-
-
-class ReActAgent(BaseAgent, ABC):
- name: str
- description: str | None = None
-
- system_prompt: str | None = None
- next_step_prompt: str | None = None
-
- llm: ReasoningLLM | None = Field(default_factory=ReasoningLLM)
- memory: Memory = Field(default_factory=Memory)
- state: AgentState = AgentState.IDLE
-
- @abstractmethod
- async def think(self) -> bool:
- """处理当前状态并决定下一步操作"""
-
- @abstractmethod
- async def act(self) -> str:
- """执行已决定的行动"""
-
- async def step(self) -> str:
- """执行一个步骤:思考和行动"""
- should_act = await self.think()
- if not should_act:
- return "思考完成-无需采取任何行动"
- return await self.act()
diff --git a/apps/scheduler/mcp_agent/agent/toolcall.py b/apps/scheduler/mcp_agent/agent/toolcall.py
deleted file mode 100644
index 1e22099ce1d2e2f2f54a3bc018511acf887a91a1..0000000000000000000000000000000000000000
--- a/apps/scheduler/mcp_agent/agent/toolcall.py
+++ /dev/null
@@ -1,238 +0,0 @@
-import asyncio
-import json
-import logging
-from typing import Any, Optional
-
-from pydantic import Field
-
-from apps.schemas.enum_var import AgentState
-from apps.llm.function import JsonGenerator
-from apps.llm.patterns import Select
-from apps.scheduler.mcp_agent.agent.react import ReActAgent
-from apps.scheduler.mcp_agent.schema import Function, Message, ToolCall
-from apps.scheduler.mcp_agent.tool import Terminate, ToolCollection
-
-logger = logging.getLogger(__name__)
-
-
-class ToolCallAgent(ReActAgent):
- """用于处理工具/函数调用的基本Agent类"""
-
- name: str = "toolcall"
- description: str = "可以执行工具调用的智能体"
-
- available_tools: ToolCollection = ToolCollection(
- Terminate(),
- )
- tool_choices: str = "auto"
- special_tool_names: list[str] = Field(default_factory=lambda: [Terminate().name])
-
- tool_calls: list[ToolCall] = Field(default_factory=list)
- _current_base64_image: str | None = None
-
- max_observe: int | bool | None = None
-
- async def think(self) -> bool:
- """使用工具处理当前状态并决定下一步行动"""
- messages = []
- for message in self.messages:
- if isinstance(message, Message):
- message = message.to_dict()
- messages.append(message)
- try:
- # 通过工具获得响应
- select_obj = Select()
- choices = []
- for available_tool in self.available_tools.to_params():
- choices.append(available_tool.get("function"))
-
- tool = await select_obj.generate(question=self.question, choices=choices)
- if tool in self.available_tools.tool_map:
- schema = self.available_tools.tool_map[tool].parameters
- json_generator = JsonGenerator(
- query="根据跟定的信息,获取工具参数",
- conversation=messages,
- schema=schema,
- ) # JsonGenerator
- parameters = await json_generator.generate()
-
- else:
- raise ValueError(f"尝试调用不存在的工具: {tool}")
- except Exception as e:
- raise
- self.tool_calls = tool_calls = [ToolCall(id=tool, function=Function(name=tool, arguments=parameters))]
- content = f"选择的执行工具为:{tool}, 参数为{parameters}"
-
- logger.info(
- f"{self.name} 选择 {len(tool_calls) if tool_calls else 0}个工具执行"
- )
- if tool_calls:
- logger.info(
- f"准备使用的工具: {[call.function.name for call in tool_calls]}"
- )
- logger.info(f"工具参数: {tool_calls[0].function.arguments}")
-
- try:
-
- assistant_msg = (
- Message.from_tool_calls(content=content, tool_calls=self.tool_calls)
- if self.tool_calls
- else Message.assistant_message(content)
- )
- self.memory.add_message(assistant_msg)
-
- if not self.tool_calls:
- return bool(content)
-
- return bool(self.tool_calls)
- except Exception as e:
- logger.error(f"{self.name}的思考过程遇到了问题:: {e}")
- self.memory.add_message(
- Message.assistant_message(
- f"处理时遇到错误: {str(e)}"
- )
- )
- return False
-
- async def act(self) -> str:
- """执行工具调用并处理其结果"""
- if not self.tool_calls:
- # 如果没有工具调用,则返回最后的消息内容
- return self.messages[-1].content or "没有要执行的内容或命令"
-
- results = []
- for command in self.tool_calls:
- await self.msg_queue.push_output(
- self.task,
- event_type="text.add",
- data={"text": f"正在执行工具{command.function.name}"}
- )
-
- self._current_base64_image = None
-
- result = await self.execute_tool(command)
-
- if self.max_observe:
- result = result[: self.max_observe]
-
- push_result = ""
- async for chunk in self.llm.call(
- messages=[{"role": "system", "content": "You are a helpful asistant."},
- {"role": "user", "content": self.user_prompt.format(
- step=self.current_step,
- result=result,
- )}, ], streaming=False
- ):
- push_result += chunk
- self.task.tokens.input_tokens += self.llm.input_tokens
- self.task.tokens.output_tokens += self.llm.output_tokens
- await self.msg_queue.push_output(
- self.task,
- event_type="text.add",
- data={"text": push_result}, # type: ignore[arg-type]
- )
-
- await self.msg_queue.push_output(
- self.task,
- event_type="text.add",
- data={"text": f"工具{command.function.name}执行完成"}, # type: ignore[arg-type]
- )
-
- logger.info(
- f"工具'{command.function.name}'执行完成! 执行结果为: {result}"
- )
-
- # 将工具响应添加到内存
- tool_msg = Message.tool_message(
- content=result,
- tool_call_id=command.id,
- name=command.function.name,
- )
- self.memory.add_message(tool_msg)
- results.append(result)
- self.question += (
- f"\n已执行工具{command.function.name}, "
- f"作用是{self.available_tools.tool_map[command.function.name].description},结果为{result}"
- )
-
- return "\n\n".join(results)
-
- async def execute_tool(self, command: ToolCall) -> str:
- """执行单个工具调用"""
- if not command or not command.function or not command.function.name:
- return "错误:无效的命令格式"
-
- name = command.function.name
- if name not in self.available_tools.tool_map:
- return f"错误:未知工具 '{name}'"
-
- try:
- # 解析参数
- args = command.function.arguments
- # 执行工具
- logger.info(f"激活工具:'{name}'...")
- result = await self.available_tools.execute(name=name, tool_input=args)
-
- # 执行特殊工具
- await self._handle_special_tool(name=name, result=result)
-
- # 格式化结果
- observation = (
- f"观察到执行的工具 `{name}`的输出:\n{str(result)}"
- if result
- else f"工具 `{name}` 已完成,无输出"
- )
-
- return observation
- except json.JSONDecodeError:
- error_msg = f"解析{name}的参数时出错:JSON格式无效"
- logger.error(
- f"{name}”的参数没有意义-无效的JSON,参数:{command.function.arguments}"
- )
- return f"错误: {error_msg}"
- except Exception as e:
- error_msg = f"工具 '{name}' 遇到问题: {str(e)}"
- logger.exception(error_msg)
- return f"错误: {error_msg}"
-
- async def _handle_special_tool(self, name: str, result: Any, **kwargs):
- """处理特殊工具的执行和状态变化"""
- if not self._is_special_tool(name):
- return
-
- if self._should_finish_execution(name=name, result=result, **kwargs):
- # 将智能体状态设为finished
- logger.info(f"特殊工具'{name}'已完成任务!")
- self.state = AgentState.FINISHED
-
- @staticmethod
- def _should_finish_execution(**kwargs) -> bool:
- """确定工具执行是否应完成"""
- return True
-
- def _is_special_tool(self, name: str) -> bool:
- """检查工具名称是否在特殊工具列表中"""
- return name.lower() in [n.lower() for n in self.special_tool_names]
-
- async def cleanup(self):
- """清理Agent工具使用的资源。"""
- logger.info(f"正在清理智能体的资源'{self.name}'...")
- for tool_name, tool_instance in self.available_tools.tool_map.items():
- if hasattr(tool_instance, "cleanup") and asyncio.iscoroutinefunction(
- tool_instance.cleanup
- ):
- try:
- logger.debug(f"清理工具: {tool_name}")
- await tool_instance.cleanup()
- except Exception as e:
- logger.error(
- f"清理工具时发生错误'{tool_name}': {e}", exc_info=True
- )
- logger.info(f"智能体清理完成'{self.name}'.")
-
- async def run(self, request: Optional[str] = None) -> str:
- """运行Agent"""
- try:
- return await super().run(request)
- finally:
- await self.cleanup()
diff --git a/apps/scheduler/mcp_agent/base.py b/apps/scheduler/mcp_agent/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac3829b418be54bfc5df6ae3d61db82fe0fc129e
--- /dev/null
+++ b/apps/scheduler/mcp_agent/base.py
@@ -0,0 +1,61 @@
+from typing import Any
+import json
+from jsonschema import validate
+import logging
+from apps.llm.function import JsonGenerator
+from apps.llm.reasoning import ReasoningLLM
+
+logger = logging.getLogger(__name__)
+
+
+class McpBase:
+ @staticmethod
+ async def get_resoning_result(prompt: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> str:
+ """获取推理结果"""
+ # 调用推理大模型
+ message = [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": prompt},
+ ]
+ result = ""
+ async for chunk in resoning_llm.call(
+ message,
+ streaming=False,
+ temperature=0.07,
+ result_only=True,
+ ):
+ result += chunk
+
+ return result
+
+ @staticmethod
+ async def _parse_result(result: str, schema: dict[str, Any], left_str: str = '{', right_str: str = '}') -> str:
+ """解析推理结果"""
+ left_index = result.find(left_str)
+ right_index = result.rfind(right_str)
+ flag = True
+ if left_str == -1 or right_str == -1:
+ flag = False
+
+ if left_index > right_index:
+ flag = False
+ if flag:
+ try:
+ tmp_js = json.loads(result[left_index:right_index + 1])
+ validate(instance=tmp_js, schema=schema)
+ except Exception as e:
+ logger.error("[McpBase] 解析结果失败: %s", e)
+ flag = False
+ if not flag:
+ json_generator = JsonGenerator(
+ "请提取下面内容中的json\n\n",
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "请提取下面内容中的json\n\n"+result},
+ ],
+ schema,
+ )
+ json_result = await json_generator.generate()
+ else:
+ json_result = json.loads(result[left_index:right_index + 1])
+ return json_result
diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py
new file mode 100644
index 0000000000000000000000000000000000000000..44b04d1fff502037e9c35f1232dfce6959a57009
--- /dev/null
+++ b/apps/scheduler/mcp_agent/host.py
@@ -0,0 +1,104 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""MCP宿主"""
+
+import json
+import logging
+from typing import Any
+
+from jinja2 import BaseLoader
+from jinja2.sandbox import SandboxedEnvironment
+from mcp.types import TextContent
+
+from apps.common.mongo import MongoDB
+from apps.llm.function import JsonGenerator
+from apps.scheduler.mcp.prompt import MEMORY_TEMPLATE
+from apps.scheduler.pool.mcp.client import MCPClient
+from apps.scheduler.pool.mcp.pool import MCPPool
+from apps.scheduler.mcp_agent.prompt import REPAIR_PARAMS
+from apps.schemas.enum_var import StepStatus
+from apps.schemas.mcp import MCPPlanItem, MCPTool
+from apps.schemas.task import Task, FlowStepHistory
+from apps.services.task import TaskManager
+
+logger = logging.getLogger(__name__)
+
+_env = SandboxedEnvironment(
+ loader=BaseLoader,
+ autoescape=False,
+ trim_blocks=True,
+ lstrip_blocks=True,
+)
+
+
+def tojson_filter(value):
+ return json.dumps(value, ensure_ascii=False, separators=(',', ':'))
+
+
+_env.filters['tojson'] = tojson_filter
+
+
+class MCPHost:
+ """MCP宿主服务"""
+
+ @staticmethod
+ async def assemble_memory(task: Task) -> str:
+ """组装记忆"""
+
+ return _env.from_string(MEMORY_TEMPLATE).render(
+ context_list=task.context,
+ )
+
+ async def _get_first_input_params(mcp_tool: MCPTool, goal: str, query: str, task: Task) -> dict[str, Any]:
+ """填充工具参数"""
+ # 更清晰的输入·指令,这样可以调用generate
+ llm_query = rf"""
+ 你是一个MCP工具参数生成器,你的任务是根据用户的目标和查询,生成MCP工具的输入参数。
+ 请充分利用背景中的信息,生成满足以下目标的工具参数:
+ {query}
+ 注意:
+ 1. 你需要根据工具的输入schema来生成参数。
+ 2. 如果工具的输入schema中有枚举类型的字段,请确保生成的参数符合这些枚举值。
+ 下面是工具的简介:
+ 工具名称:{mcp_tool.name}
+ 工具描述:{mcp_tool.description}
+ """
+
+ # 进行生成
+ json_generator = JsonGenerator(
+ llm_query,
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "用户的总目标是:\n"+goal+(await MCPHost.assemble_memory(task))},
+ ],
+ mcp_tool.input_schema,
+ )
+ return await json_generator.generate()
+
+ async def _fill_params(mcp_tool: MCPTool,
+ goal: str,
+ current_goal: str,
+ current_input: dict[str, Any],
+ error_message: str = "", params: dict[str, Any] = {},
+ params_description: str = "") -> dict[str, Any]:
+ llm_query = "请生成修复之后的工具参数"
+ prompt = _env.from_string(REPAIR_PARAMS).render(
+ tool_name=mcp_tool.name,
+ goal=goal,
+ current_goal=current_goal,
+ tool_description=mcp_tool.description,
+ input_schema=mcp_tool.input_schema,
+ current_input=current_input,
+ error_message=error_message,
+ params=params,
+ params_description=params_description,
+ )
+
+ json_generator = JsonGenerator(
+ llm_query,
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": prompt},
+ ],
+ mcp_tool.input_schema,
+ )
+ return await json_generator.generate()
diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py
new file mode 100644
index 0000000000000000000000000000000000000000..84c08d60227147849f8b9f4b3d91f79dc43297b0
--- /dev/null
+++ b/apps/scheduler/mcp_agent/plan.py
@@ -0,0 +1,368 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""MCP 用户目标拆解与规划"""
+from typing import Any, AsyncGenerator
+from jinja2 import BaseLoader
+from jinja2.sandbox import SandboxedEnvironment
+
+from apps.llm.reasoning import ReasoningLLM
+from apps.llm.function import JsonGenerator
+from apps.scheduler.mcp_agent.base import McpBase
+from apps.scheduler.mcp_agent.prompt import (
+ EVALUATE_GOAL,
+ GENERATE_FLOW_NAME,
+ GET_REPLAN_START_STEP_INDEX,
+ CREATE_PLAN,
+ RECREATE_PLAN,
+ GEN_STEP,
+ TOOL_SKIP,
+ RISK_EVALUATE,
+ TOOL_EXECUTE_ERROR_TYPE_ANALYSIS,
+ IS_PARAM_ERROR,
+ CHANGE_ERROR_MESSAGE_TO_DESCRIPTION,
+ GET_MISSING_PARAMS,
+ FINAL_ANSWER
+)
+from apps.schemas.task import Task
+from apps.schemas.mcp import (
+ GoalEvaluationResult,
+ RestartStepIndex,
+ ToolSkip,
+ ToolRisk,
+ IsParamError,
+ ToolExcutionErrorType,
+ MCPPlan,
+ Step,
+ MCPPlanItem,
+ MCPTool
+)
+from apps.scheduler.slot.slot import Slot
+
+_env = SandboxedEnvironment(
+ loader=BaseLoader,
+ autoescape=False,
+ trim_blocks=True,
+ lstrip_blocks=True,
+)
+
+
+class MCPPlanner(McpBase):
+ """MCP 用户目标拆解与规划"""
+
+ @staticmethod
+ async def evaluate_goal(
+ goal: str,
+ tool_list: list[MCPTool],
+ resoning_llm: ReasoningLLM = ReasoningLLM()) -> GoalEvaluationResult:
+ """评估用户目标的可行性"""
+ # 获取推理结果
+ result = await MCPPlanner._get_reasoning_evaluation(goal, tool_list, resoning_llm)
+
+ # 解析为结构化数据
+ evaluation = await MCPPlanner._parse_evaluation_result(result)
+
+ # 返回评估结果
+ return evaluation
+
+ @staticmethod
+ async def _get_reasoning_evaluation(
+ goal, tool_list: list[MCPTool],
+ resoning_llm: ReasoningLLM = ReasoningLLM()) -> str:
+ """获取推理大模型的评估结果"""
+ template = _env.from_string(EVALUATE_GOAL)
+ prompt = template.render(
+ goal=goal,
+ tools=tool_list,
+ )
+ result = await MCPPlanner.get_resoning_result(prompt, resoning_llm)
+ return result
+
+ @staticmethod
+ async def _parse_evaluation_result(result: str) -> GoalEvaluationResult:
+ """将推理结果解析为结构化数据"""
+ schema = GoalEvaluationResult.model_json_schema()
+ evaluation = await MCPPlanner._parse_result(result, schema)
+ # 使用GoalEvaluationResult模型解析结果
+ return GoalEvaluationResult.model_validate(evaluation)
+
+ async def get_flow_name(user_goal: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> str:
+ """获取当前流程的名称"""
+ result = await MCPPlanner._get_reasoning_flow_name(user_goal, resoning_llm)
+ return result
+
+ @staticmethod
+ async def _get_reasoning_flow_name(user_goal: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> str:
+ """获取推理大模型的流程名称"""
+ template = _env.from_string(GENERATE_FLOW_NAME)
+ prompt = template.render(goal=user_goal)
+ result = await MCPPlanner.get_resoning_result(prompt, resoning_llm)
+ return result
+
+ @staticmethod
+ async def get_replan_start_step_index(
+ user_goal: str, error_message: str, current_plan: MCPPlan | None = None,
+ history: str = "",
+ reasoning_llm: ReasoningLLM = ReasoningLLM()) -> RestartStepIndex:
+ """获取重新规划的步骤索引"""
+ # 获取推理结果
+ template = _env.from_string(GET_REPLAN_START_STEP_INDEX)
+ prompt = template.render(
+ goal=user_goal,
+ error_message=error_message,
+ current_plan=current_plan.model_dump(exclude_none=True, by_alias=True),
+ history=history,
+ )
+ result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm)
+ # 解析为结构化数据
+ schema = RestartStepIndex.model_json_schema()
+ schema["properties"]["start_index"]["maximum"] = len(current_plan.plans) - 1
+ schema["properties"]["start_index"]["minimum"] = 0
+ restart_index = await MCPPlanner._parse_result(result, schema)
+ # 使用RestartStepIndex模型解析结果
+ return RestartStepIndex.model_validate(restart_index)
+
+ @staticmethod
+ async def create_plan(
+ user_goal: str, is_replan: bool = False, error_message: str = "", current_plan: MCPPlan | None = None,
+ tool_list: list[MCPTool] = [],
+ max_steps: int = 6, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> MCPPlan:
+ """规划下一步的执行流程,并输出"""
+ # 获取推理结果
+ result = await MCPPlanner._get_reasoning_plan(user_goal, is_replan, error_message, current_plan, tool_list, max_steps, reasoning_llm)
+
+ # 解析为结构化数据
+ return await MCPPlanner._parse_plan_result(result, max_steps)
+
+ @staticmethod
+ async def _get_reasoning_plan(
+ user_goal: str, is_replan: bool = False, error_message: str = "", current_plan: MCPPlan | None = None,
+ tool_list: list[MCPTool] = [],
+ max_steps: int = 10, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> str:
+ """获取推理大模型的结果"""
+ # 格式化Prompt
+ tool_ids = [tool.id for tool in tool_list]
+ if is_replan:
+ template = _env.from_string(RECREATE_PLAN)
+ prompt = template.render(
+ current_plan=current_plan.model_dump(exclude_none=True, by_alias=True),
+ error_message=error_message,
+ goal=user_goal,
+ tools=tool_list,
+ max_num=max_steps,
+ )
+ else:
+ template = _env.from_string(CREATE_PLAN)
+ prompt = template.render(
+ goal=user_goal,
+ tools=tool_list,
+ max_num=max_steps,
+ )
+ result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm)
+ return result
+
+ @staticmethod
+ async def _parse_plan_result(result: str, max_steps: int) -> MCPPlan:
+ """将推理结果解析为结构化数据"""
+ # 格式化Prompt
+ schema = MCPPlan.model_json_schema()
+ schema["properties"]["plans"]["maxItems"] = max_steps
+ plan = await MCPPlanner._parse_result(result, schema)
+ # 使用Function模型解析结果
+ return MCPPlan.model_validate(plan)
+
+ @staticmethod
+ async def create_next_step(
+ goal: str, history: str, tools: list[MCPTool],
+ reasoning_llm: ReasoningLLM = ReasoningLLM()) -> Step:
+ """创建下一步的执行步骤"""
+ # 获取推理结果
+ template = _env.from_string(GEN_STEP)
+ prompt = template.render(goal=goal, history=history, tools=tools)
+ result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm)
+
+ # 解析为结构化数据
+ schema = Step.model_json_schema()
+ if "enum" not in schema["properties"]["tool_id"]:
+ schema["properties"]["tool_id"]["enum"] = []
+ for tool in tools:
+ schema["properties"]["tool_id"]["enum"].append(tool.id)
+ step = await MCPPlanner._parse_result(result, schema)
+ # 使用Step模型解析结果
+ return Step.model_validate(step)
+
+ @staticmethod
+ async def tool_skip(
+ task: Task, step_id: str, step_name: str, step_instruction: str, step_content: str,
+ reasoning_llm: ReasoningLLM = ReasoningLLM()) -> ToolSkip:
+ """判断当前步骤是否需要跳过"""
+ # 获取推理结果
+ template = _env.from_string(TOOL_SKIP)
+ from apps.scheduler.mcp_agent.host import MCPHost
+ history = await MCPHost.assemble_memory(task)
+ prompt = template.render(
+ step_id=step_id,
+ step_name=step_name,
+ step_instruction=step_instruction,
+ step_content=step_content,
+ history=history,
+ goal=task.runtime.question
+ )
+ result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm)
+
+ # 解析为结构化数据
+ schema = ToolSkip.model_json_schema()
+ skip_result = await MCPPlanner._parse_result(result, schema)
+ # 使用ToolSkip模型解析结果
+ return ToolSkip.model_validate(skip_result)
+
+ @staticmethod
+ async def get_tool_risk(
+ tool: MCPTool, input_parm: dict[str, Any],
+ additional_info: str = "", resoning_llm: ReasoningLLM = ReasoningLLM()) -> ToolRisk:
+ """获取MCP工具的风险评估结果"""
+ # 获取推理结果
+ result = await MCPPlanner._get_reasoning_risk(tool, input_parm, additional_info, resoning_llm)
+
+ # 解析为结构化数据
+ risk = await MCPPlanner._parse_risk_result(result)
+
+ # 返回风险评估结果
+ return risk
+
+ @staticmethod
+ async def _get_reasoning_risk(
+ tool: MCPTool, input_param: dict[str, Any],
+ additional_info: str, resoning_llm: ReasoningLLM) -> str:
+ """获取推理大模型的风险评估结果"""
+ template = _env.from_string(RISK_EVALUATE)
+ prompt = template.render(
+ tool_name=tool.name,
+ tool_description=tool.description,
+ input_param=input_param,
+ additional_info=additional_info,
+ )
+ result = await MCPPlanner.get_resoning_result(prompt, resoning_llm)
+ return result
+
+ @staticmethod
+ async def _parse_risk_result(result: str) -> ToolRisk:
+ """将推理结果解析为结构化数据"""
+ schema = ToolRisk.model_json_schema()
+ risk = await MCPPlanner._parse_result(result, schema)
+ # 使用ToolRisk模型解析结果
+ return ToolRisk.model_validate(risk)
+
+ @staticmethod
+ async def _get_reasoning_tool_execute_error_type(
+ user_goal: str, current_plan: MCPPlan,
+ tool: MCPTool, input_param: dict[str, Any],
+ error_message: str, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> str:
+ """获取推理大模型的工具执行错误类型"""
+ template = _env.from_string(TOOL_EXECUTE_ERROR_TYPE_ANALYSIS)
+ prompt = template.render(
+ goal=user_goal,
+ current_plan=current_plan.model_dump(exclude_none=True, by_alias=True),
+ tool_name=tool.name,
+ tool_description=tool.description,
+ input_param=input_param,
+ error_message=error_message,
+ )
+ result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm)
+ return result
+
+ @staticmethod
+ async def _parse_tool_execute_error_type_result(result: str) -> ToolExcutionErrorType:
+ """将推理结果解析为工具执行错误类型"""
+ schema = ToolExcutionErrorType.model_json_schema()
+ error_type = await MCPPlanner._parse_result(result, schema)
+ # 使用ToolExcutionErrorType模型解析结果
+ return ToolExcutionErrorType.model_validate(error_type)
+
+ @staticmethod
+ async def get_tool_execute_error_type(
+ user_goal: str, current_plan: MCPPlan,
+ tool: MCPTool, input_param: dict[str, Any],
+ error_message: str, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> ToolExcutionErrorType:
+ """获取MCP工具执行错误类型"""
+ # 获取推理结果
+ result = await MCPPlanner._get_reasoning_tool_execute_error_type(
+ user_goal, current_plan, tool, input_param, error_message, reasoning_llm)
+ error_type = await MCPPlanner._parse_tool_execute_error_type_result(result)
+ # 返回工具执行错误类型
+ return error_type
+
+ @staticmethod
+ async def is_param_error(
+ goal: str, history: str, error_message: str, tool: MCPTool, step_description: str, input_params: dict
+ [str, Any],
+ reasoning_llm: ReasoningLLM = ReasoningLLM()) -> IsParamError:
+ """判断错误信息是否是参数错误"""
+ tmplate = _env.from_string(IS_PARAM_ERROR)
+ prompt = tmplate.render(
+ goal=goal,
+ history=history,
+ step_id=tool.id,
+ step_name=tool.name,
+ step_description=step_description,
+ input_params=input_params,
+ error_message=error_message,
+ )
+ result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm)
+ # 解析为结构化数据
+ schema = IsParamError.model_json_schema()
+ is_param_error = await MCPPlanner._parse_result(result, schema)
+ # 使用IsParamError模型解析结果
+ return IsParamError.model_validate(is_param_error)
+
+ @staticmethod
+ async def change_err_message_to_description(
+ error_message: str, tool: MCPTool, input_params: dict[str, Any],
+ reasoning_llm: ReasoningLLM = ReasoningLLM()) -> str:
+ """将错误信息转换为工具描述"""
+ template = _env.from_string(CHANGE_ERROR_MESSAGE_TO_DESCRIPTION)
+ prompt = template.render(
+ error_message=error_message,
+ tool_name=tool.name,
+ tool_description=tool.description,
+ input_schema=tool.input_schema,
+ input_params=input_params,
+ )
+ result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm)
+ return result
+
+ @staticmethod
+ async def get_missing_param(
+ tool: MCPTool,
+ input_param: dict[str, Any],
+ error_message: str, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> list[str]:
+ """获取缺失的参数"""
+ slot = Slot(schema=tool.input_schema)
+ template = _env.from_string(GET_MISSING_PARAMS)
+ schema_with_null = slot.add_null_to_basic_types()
+ prompt = template.render(
+ tool_name=tool.name,
+ tool_description=tool.description,
+ input_param=input_param,
+ schema=schema_with_null,
+ error_message=error_message,
+ )
+ result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm)
+ # 解析为结构化数据
+ input_param_with_null = await MCPPlanner._parse_result(result, schema_with_null)
+ return input_param_with_null
+
+ @staticmethod
+ async def generate_answer(
+ user_goal: str, memory: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> AsyncGenerator[
+ str, None]:
+ """生成最终回答"""
+ template = _env.from_string(FINAL_ANSWER)
+ prompt = template.render(
+ memory=memory,
+ goal=user_goal,
+ )
+ async for chunk in resoning_llm.call(
+ [{"role": "user", "content": prompt}],
+ streaming=True,
+ temperature=0.07,
+ ):
+ yield chunk
diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py
new file mode 100644
index 0000000000000000000000000000000000000000..31e8ce5415dad9bca6ef8e5069a125975070d658
--- /dev/null
+++ b/apps/scheduler/mcp_agent/prompt.py
@@ -0,0 +1,1083 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""MCP相关的大模型Prompt"""
+
+from textwrap import dedent
+
+MCP_SELECT = dedent(r"""
+ 你是一个乐于助人的智能助手。
+ 你的任务是:根据当前目标,选择最合适的MCP Server。
+
+ ## 选择MCP Server时的注意事项:
+
+ 1. 确保充分理解当前目标,选择最合适的MCP Server。
+ 2. 请在给定的MCP Server列表中选择,不要自己生成MCP Server。
+ 3. 请先给出你选择的理由,再给出你的选择。
+ 4. 当前目标将在下面给出,MCP Server列表也会在下面给出。
+ 请将你的思考过程放在"思考过程"部分,将你的选择放在"选择结果"部分。
+ 5. 选择必须是JSON格式,严格按照下面的模板,不要输出任何其他内容:
+
+ ```json
+ {
+ "mcp": "你选择的MCP Server的名称"
+ }
+ ```
+
+ 6. 下面的示例仅供参考,不要将示例中的内容作为选择MCP Server的依据。
+
+ ## 示例
+
+ ### 目标
+
+ 我需要一个MCP Server来完成一个任务。
+
+ ### MCP Server列表
+
+ - **mcp_1**: "MCP Server 1";MCP Server 1的描述
+ - **mcp_2**: "MCP Server 2";MCP Server 2的描述
+
+ ### 请一步一步思考:
+
+ 因为当前目标需要一个MCP Server来完成一个任务,所以选择mcp_1。
+
+ ### 选择结果
+
+ ```json
+ {
+ "mcp": "mcp_1"
+ }
+ ```
+
+ ## 现在开始!
+
+ ### 目标
+
+ {{goal}}
+
+ ### MCP Server列表
+
+ {% for mcp in mcp_list %}
+ - **{{mcp.id}}**: "{{mcp.name}}";{{mcp.description}}
+ {% endfor %}
+
+ ### 请一步一步思考:
+
+""")
+TOOL_SELECT = dedent(r"""
+ 你是一个乐于助人的智能助手。
+ 你的任务是:根据当前目标,附加信息,选择最合适的MCP工具。
+ ## 选择MCP工具时的注意事项:
+ 1. 确保充分理解当前目标,选择实现目标所需的MCP工具。
+ 2. 不要选择不存在的工具。
+ 3. 可以选择一些辅助工具,但必须确保这些工具与当前目标相关。
+ 4. 注意,返回的工具ID必须是MCP工具的ID,而不是名称。
+ 5. 可以多选择一些工具,用于应对不同的情况。
+ 必须按照以下格式生成选择结果,不要输出任何其他内容:
+ ```json
+ {
+ "tool_ids": ["工具ID1", "工具ID2", ...]
+ }
+ ```
+
+ # 示例
+ ## 目标
+ 调优mysql性能
+ ## MCP工具列表
+
+ - mcp_tool_1 MySQL链接池工具;用于优化MySQL链接池
+ - mcp_tool_2 MySQL性能调优工具;用于分析MySQL性能瓶颈
+ - mcp_tool_3 MySQL查询优化工具;用于优化MySQL查询语句
+ - mcp_tool_4 MySQL索引优化工具;用于优化MySQL索引
+ - mcp_tool_5 文件存储工具;用于存储文件
+ - mcp_tool_6 mongoDB工具;用于操作MongoDB数据库
+
+ ## 附加信息
+ 1. 当前MySQL数据库的版本是8.0.26
+ 2. 当前MySQL数据库的配置文件路径是/etc/my.cnf,并含有以下配置项
+ ```json
+ {
+ "max_connections": 1000,
+ "innodb_buffer_pool_size": "1G",
+ "query_cache_size": "64M"
+ }
+ ##输出
+ ```json
+ {
+ "tool_ids": ["mcp_tool_1", "mcp_tool_2", "mcp_tool_3", "mcp_tool_4"]
+ }
+ ```
+ # 现在开始!
+ ## 目标
+ {{goal}}
+ ## MCP工具列表
+
+ {% for tool in tools %}
+ - {{tool.id}} {{tool.name}};{{tool.description}}
+ {% endfor %}
+
+ ## 附加信息
+ {{additional_info}}
+ # 输出
+ """
+ )
+
+EVALUATE_GOAL = dedent(r"""
+ 你是一个计划评估器。
+ 请根据用户的目标和当前的工具集合以及一些附加信息,判断基于当前的工具集合,是否能够完成用户的目标。
+ 如果能够完成,请返回`true`,否则返回`false`。
+ 推理过程必须清晰明了,能够让人理解你的判断依据。
+ 必须按照以下格式回答:
+ ```json
+ {
+ "can_complete": true/false,
+ "resoning": "你的推理过程"
+ }
+ ```
+
+ # 样例
+ # 目标
+ 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优
+
+ # 工具集合
+ 你可以访问并使用一些工具,这些工具将在 XML标签中给出。
+
+ - mysql_analyzer 分析MySQL数据库性能
+ - performance_tuner 调优数据库性能
+ - Final 结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。
+
+
+ # 附加信息
+ 1. 当前MySQL数据库的版本是8.0.26
+ 2. 当前MySQL数据库的配置文件路径是/etc/my.cnf
+
+ ##
+ ```json
+ {
+ "can_complete": true,
+ "resoning": "当前的工具集合中包含mysql_analyzer和performance_tuner,能够完成对MySQL数据库的性能分析和调优,因此可以完成用户的目标。"
+ }
+ ```
+
+ # 目标
+ {{goal}}
+
+ # 工具集合
+
+ {% for tool in tools %}
+ - {{tool.id}} {{tool.name}};{{tool.description}}
+ {% endfor %}
+
+
+ # 附加信息
+ {{additional_info}}
+
+""")
+GENERATE_FLOW_NAME = dedent(r"""
+ 你是一个智能助手,你的任务是根据用户的目标,生成一个合适的流程名称。
+
+ # 生成流程名称时的注意事项:
+ 1. 流程名称应该简洁明了,能够准确表达达成用户目标的过程。
+ 2. 流程名称应该包含关键的操作或步骤,例如“扫描”、“分析”、“调优”等。
+ 3. 流程名称应该避免使用过于复杂或专业的术语,以便用户能够理解。
+ 4. 流程名称应该尽量简短,小于20个字或者单词。
+ 5. 只输出流程名称,不要输出其他内容。
+ # 样例
+ # 目标
+ 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优
+ # 输出
+ 扫描MySQL数据库并分析性能瓶颈,进行调优
+ # 现在开始生成流程名称:
+ # 目标
+ {{goal}}
+ # 输出
+ """)
+GET_REPLAN_START_STEP_INDEX = dedent(r"""
+ 你是一个智能助手,你的任务是根据用户的目标、报错信息和当前计划和历史,获取重新规划的步骤起始索引。
+
+ # 样例
+ # 目标
+ 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优
+ # 报错信息
+ 执行端口扫描命令时,出现了错误:`- bash: curl: command not found`。
+ # 当前计划
+ ```json
+ {
+ "plans": [
+ {
+ "step_id": "step_1",
+ "content": "生成端口扫描命令",
+ "tool": "command_generator",
+ "instruction": "生成端口扫描命令:扫描
+ },
+ {
+ "step_id": "step_2",
+ "content": "在执行Result[0]生成的命令",
+ "tool": "command_executor",
+ "instruction": "执行端口扫描命令"
+ }
+ ]
+ }
+ # 历史
+ [
+ {
+ id: "0",
+ task_id: "task_1",
+ flow_id: "flow_1",
+ flow_name: "MYSQL性能调优",
+ flow_status: "RUNNING",
+ step_id: "step_1",
+ step_name: "生成端口扫描命令",
+ step_description: "生成端口扫描命令:扫描当前MySQL数据库的端口",
+ step_status: "FAILED",
+ input_data: {
+ "command": "nmap -p 3306
+ "target": "localhost"
+ },
+ output_data: {
+ "error": "- bash: curl: command not found"
+ }
+ }
+ ]
+ # 输出
+ {
+ "start_index": 0,
+ "reasoning": "当前计划的第一步就失败了,报错信息显示curl命令未找到,可能是因为没有安装curl工具,因此需要从第一步重新规划。"
+ }
+ # 现在开始获取重新规划的步骤起始索引:
+ # 目标
+ {{goal}}
+ # 报错信息
+ {{error_message}}
+ # 当前计划
+ {{current_plan}}
+ # 历史
+ {{history}}
+ # 输出
+ """)
+
+CREATE_PLAN = dedent(r"""
+ 你是一个计划生成器。
+ 请分析用户的目标,并生成一个计划。你后续将根据这个计划,一步一步地完成用户的目标。
+
+ # 一个好的计划应该:
+
+ 1. 能够成功完成用户的目标
+ 2. 计划中的每一个步骤必须且只能使用一个工具。
+ 3. 计划中的步骤必须具有清晰和逻辑的步骤,没有冗余或不必要的步骤。
+ 4. 不要选择不存在的工具。
+ 5. 计划中的最后一步必须是Final工具,以确保计划执行结束。
+ 6. 生成的计划必须要覆盖用户的目标,当然需要考虑一些意外情况,可以有一定的冗余步骤。
+
+ # 生成计划时的注意事项:
+
+ - 每一条计划包含3个部分:
+ - 计划内容:描述单个计划步骤的大致内容
+ - 工具ID:必须从下文的工具列表中选择
+ - 工具指令:改写用户的目标,使其更符合工具的输入要求
+ - 必须按照如下格式生成计划,不要输出任何额外数据:
+
+ ```json
+ {
+ "plans": [
+ {
+ "content": "计划内容",
+ "tool": "工具ID",
+ "instruction": "工具指令"
+ }
+ ]
+ }
+ ```
+
+ - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。
+思考过程应放置在 XML标签中。
+ - 计划内容中,可以使用"Result[]"来引用之前计划步骤的结果。例如:"Result[3]"表示引用第三条计划执行后的结果。
+ - 计划不得多于{{max_num}}条,且每条计划内容应少于150字。
+
+ # 工具
+
+ 你可以访问并使用一些工具,这些工具将在 XML标签中给出。
+
+
+ {% for tool in tools %}
+ - {{tool.id}} {{tool.name}};{{tool.description}}
+ {% endfor %}
+
+
+ # 样例
+
+ # 目标
+
+ 在后台运行一个新的alpine: latest容器,将主机/root文件夹挂载至/data,并执行top命令。
+
+ # 计划
+
+
+ 1. 这个目标需要使用Docker来完成, 首先需要选择合适的MCP Server
+ 2. 目标可以拆解为以下几个部分:
+ - 运行alpine: latest容器
+ - 挂载主机目录
+ - 在后台运行
+ - 执行top命令
+ 3. 需要先选择MCP Server, 然后生成Docker命令, 最后执行命令
+ ```json
+ {
+ "plans": [
+ {
+ "content": "选择一个支持Docker的MCP Server",
+ "tool": "mcp_selector",
+ "instruction": "需要一个支持Docker容器运行的MCP Server"
+ },
+ {
+ "content": "使用第一步选择的MCP Server,生成Docker命令",
+ "tool": "command_generator",
+ "instruction": "生成Docker命令:在后台运行alpine:latest容器,挂载/root到/data,执行top命令"
+ },
+ {
+ "content": "执行第二步生成的Docker命令",
+ "tool": "command_executor",
+ "instruction": "执行Docker命令"
+ },
+ {
+ "content": "任务执行完成,容器已在后台运行",
+ "tool": "Final",
+ "instruction": ""
+ }
+ ]
+ }
+ ```
+
+ # 现在开始生成计划:
+
+ # 目标
+
+ {{goal}}
+
+ # 计划
+""")
+RECREATE_PLAN = dedent(r"""
+ 你是一个计划重建器。
+ 请根据用户的目标、当前计划和运行报错,重新生成一个计划。
+
+ # 一个好的计划应该:
+
+ 1. 能够成功完成用户的目标
+ 2. 计划中的每一个步骤必须且只能使用一个工具。
+ 3. 计划中的步骤必须具有清晰和逻辑的步骤,没有冗余或不必要的步骤。
+ 4. 你的计划必须避免之前的错误,并且能够成功执行。
+ 5. 不要选择不存在的工具。
+ 6. 计划中的最后一步必须是Final工具,以确保计划执行结束。
+ 7. 生成的计划必须要覆盖用户的目标,当然需要考虑一些意外情况,可以有一定的冗余步骤。
+ # 生成计划时的注意事项:
+
+ - 每一条计划包含3个部分:
+ - 计划内容:描述单个计划步骤的大致内容
+ - 工具ID:必须从下文的工具列表中选择
+ - 工具指令:改写用户的目标,使其更符合工具的输入要求
+ - 必须按照如下格式生成计划,不要输出任何额外数据:
+
+ ```json
+ {
+ "plans": [
+ {
+ "content": "计划内容",
+ "tool": "工具ID",
+ "instruction": "工具指令"
+ }
+ ]
+ }
+ ```
+
+ - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。
+思考过程应放置在 XML标签中。
+ - 计划内容中,可以使用"Result[]"来引用之前计划步骤的结果。例如:"Result[3]"表示引用第三条计划执行后的结果。
+ - 计划不得多于{{max_num}}条,且每条计划内容应少于150字。
+
+ # 样例
+
+ # 目标
+
+ 请帮我扫描一下192.168.1.1的这台机器的端口,看看有哪些端口开放。
+ # 工具
+ 你可以访问并使用一些工具,这些工具将在 XML标签中给出。
+
+ - command_generator 生成命令行指令
+ - tool_selector 选择合适的工具
+ - command_executor 执行命令行指令
+ - Final 结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。
+ # 当前计划
+ ```json
+ {
+ "plans": [
+ {
+ "content": "生成端口扫描命令",
+ "tool": "command_generator",
+ "instruction": "生成端口扫描命令:扫描192.168.1.1的开放端口"
+ },
+ {
+ "content": "在执行第一步生成的命令",
+ "tool": "command_executor",
+ "instruction": "执行端口扫描命令"
+ },
+ {
+ "content": "任务执行完成",
+ "tool": "Final",
+ "instruction": ""
+ }
+ ]
+ }
+ ```
+ # 运行报错
+ 执行端口扫描命令时,出现了错误:`- bash: curl: command not found`。
+ # 重新生成的计划
+
+
+ 1. 这个目标需要使用网络扫描工具来完成, 首先需要选择合适的网络扫描工具
+ 2. 目标可以拆解为以下几个部分:
+ - 生成端口扫描命令
+ - 执行端口扫描命令
+ 3.但是在执行端口扫描命令时,出现了错误:`- bash: curl: command not found`。
+ 4.我将计划调整为:
+ - 需要先生成一个命令,查看当前机器支持哪些网络扫描工具
+ - 执行这个命令,查看当前机器支持哪些网络扫描工具
+ - 然后从中选择一个网络扫描工具
+ - 基于选择的网络扫描工具,生成端口扫描命令
+ - 执行端口扫描命令
+ ```json
+ {
+ "plans": [
+ {
+ "content": "需要生成一条命令查看当前机器支持哪些网络扫描工具",
+ "tool": "command_generator",
+ "instruction": "选择一个前机器支持哪些网络扫描工具"
+ },
+ {
+ "content": "执行第一步中生成的命令,查看当前机器支持哪些网络扫描工具",
+ "tool": "command_executor",
+ "instruction": "执行第一步中生成的命令"
+ },
+ {
+ "content": "从第二步执行结果中选择一个网络扫描工具,生成端口扫描命令",
+ "tool": "tool_selector",
+ "instruction": "选择一个网络扫描工具,生成端口扫描命令"
+ },
+ {
+ "content": "基于第三步中选择的网络扫描工具,生成端口扫描命令",
+ "tool": "command_generator",
+ "instruction": "生成端口扫描命令:扫描192.168.1.1的开放端口"
+ },
+ {
+ "content": "执行第四步中生成的端口扫描命令",
+ "tool": "command_executor",
+ "instruction": "执行端口扫描命令"
+ },
+ {
+ "content": "任务执行完成",
+ "tool": "Final",
+ "instruction": ""
+ }
+ ]
+ }
+ ```
+
+ # 现在开始重新生成计划:
+
+ # 目标
+
+ {{goal}}
+
+ # 工具
+
+ 你可以访问并使用一些工具,这些工具将在 XML标签中给出。
+
+
+ {% for tool in tools %}
+ - {{tool.id}} {{tool.name}};{{tool.description}}
+ {% endfor %}
+
+
+ # 当前计划
+ {{current_plan}}
+
+ # 运行报错
+ {{error_message}}
+
+ # 重新生成的计划
+""")
+GEN_STEP = dedent(r"""
+ 你是一个计划生成器。
+ 请根据用户的目标、当前计划和历史,生成一个新的步骤。
+
+ # 一个好的计划步骤应该:
+ 1.使用最适合的工具来完成当前步骤。
+ 2.能够基于当前的计划和历史,完成阶段性的任务。
+ 3.不要选择不存在的工具。
+ 4.如果你认为当前已经达成了用户的目标,可以直接返回Final工具,表示计划执行结束。
+
+ # 样例 1
+ # 目标
+ 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优,我的ip是192.168.1.1,数据库端口是3306,用户名是root,密码是password
+ # 历史记录
+ 第1步:生成端口扫描命令
+ - 调用工具 `command_generator`,并提供参数 `帮我生成一个mysql端口扫描命令`
+ - 执行状态:成功
+ - 得到数据:`{"command": "nmap -sS -p--open 192.168.1.1"}`
+ 第2步:执行端口扫描命令
+ - 调用工具 `command_executor`,并提供参数 `{"command": "nmap -sS -p--open 192.168.1.1"}`
+ - 执行状态:成功
+ - 得到数据:`{"result": "success"}`
+ # 工具
+
+ - mcp_tool_1 mysql_analyzer;用于分析数据库性能/description>
+ - mcp_tool_2 文件存储工具;用于存储文件
+ - mcp_tool_3 mongoDB工具;用于操作MongoDB数据库
+ - Final 结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。
+
+ # 输出
+ ```json
+ {
+ "tool_id": "mcp_tool_1", // 选择的工具ID
+ "description": "扫描ip为192.168.1.1的MySQL数据库,端口为3306,用户名为root,密码为password的数据库性能",
+ }
+ ```
+ # 样例二
+ # 目标
+ 计划从杭州到北京的旅游计划
+ # 历史记录
+ 第1步:将杭州转换为经纬度坐标
+ - 调用工具 `maps_geo_planner`,并提供参数 `{"city_from": "杭州", "address": "西湖"}`
+ - 执行状态:成功
+ - 得到数据:`{"location": "123.456, 78.901"}`
+ 第2步:查询杭州的天气
+ - 调用工具 `weather_query`,并提供参数 `{"location": "123.456, 78.901"}`
+ - 执行状态:成功
+ - 得到数据:`{"weather": "晴", "temperature": "25°C"}`
+ 第3步:将北京转换为经纬度坐标
+ - 调用工具 `maps_geo_planner`,并提供参数 `{"city_from": "北京", "address": "天安门"}`
+ - 执行状态:成功
+ - 得到数据:`{"location": "123.456, 78.901"}`
+ 第4步:查询北京的天气
+ - 调用工具 `weather_query`,并提供参数 `{"location": "123.456, 78.901"}`
+ - 执行状态:成功
+ - 得到数据:`{"weather": "晴", "temperature": "25°C"}`
+ # 工具
+
+ - mcp_tool_4 maps_geo_planner;将详细的结构化地址转换为经纬度坐标。支持对地标性名胜景区、建筑物名称解析为经纬度坐标
+ - mcp_tool_5 weather_query;天气查询,用于查询天气信息
+ - mcp_tool_6 maps_direction_transit_integrated;根据用户起终点经纬度坐标规划综合各类公共(火车、公交、地铁)交通方式的通勤方案,并且返回通勤方案的数据,跨城场景下必须传起点城市与终点城市
+ - Final Final;结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。
+
+ # 输出
+ ```json
+ {
+ "tool_id": "mcp_tool_6", // 选择的工具ID
+ "description": "规划从杭州到北京的综合公共交通方式的通勤方案"
+ }
+ ```
+ # 现在开始生成步骤:
+ # 目标
+ {{goal}}
+ # 历史记录
+ {{history}}
+ # 工具
+
+ {% for tool in tools %}
+ - {{tool.id}} {{tool.name}};{{tool.description}}
+ {% endfor %}
+
+""")
+
+TOOL_SKIP = dedent(r"""
+ 你是一个计划执行器。
+ 你的任务是根据当前的计划和用户目标,判断当前步骤是否需要跳过。
+ 如果需要跳过,请返回`true`,否则返回`false`。
+ 必须按照以下格式回答:
+ ```json
+ {
+ "skip": true/false,
+ }
+ ```
+ 注意:
+ 1.你的判断要谨慎,在历史消息中有足够的上下文信息时,才可以判断是否跳过当前步骤。
+ # 样例
+ # 用户目标
+ 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优
+ # 历史
+ 第1步:生成端口扫描命令
+ - 调用工具 `command_generator`,并提供参数 `{"command": "nmap -sS -p--open 192.168.1.1"}`
+ - 执行状态:成功
+ - 得到数据:`{"command": "nmap -sS -p--open 192.168.1.1"}`
+ 第2步:执行端口扫描命令
+ - 调用工具 `command_executor`,并提供参数 `{"command": "nmap -sS -p--open 192.168.1.1"}`
+ - 执行状态:成功
+ - 得到数据:`{"result": "success"}`
+ 第3步:分析端口扫描结果
+ - 调用工具 `mysql_analyzer`,并提供参数 `{"host": "192.168.1.1", "port": 3306, "username": "root", "password": "password"}`
+ - 执行状态:成功
+ - 得到数据:`{"performance": "good", "bottleneck": "none"}`
+ # 当前步骤
+
+ step_4
+ command_generator
+ 生成MySQL性能调优命令
+ 生成MySQL性能调优命令:调优MySQL数据库性能
+
+ # 输出
+ ```json
+ {
+ "skip": true
+ }
+ ```
+ # 用户目标
+ {{goal}}
+ # 历史
+ {{history}}
+ # 当前步骤
+
+ {{step_id}}
+ {{step_name}}
+ {{step_instruction}}
+ {{step_content}}
+
+ # 输出
+ """
+ )
+RISK_EVALUATE = dedent(r"""
+ 你是一个工具执行计划评估器。
+ 你的任务是根据当前工具的名称、描述和入参以及附加信息,判断当前工具执行的风险并输出提示。
+ ```json
+ {
+ "risk": "low/medium/high",
+ "reason": "提示信息"
+ }
+ ```
+ # 样例
+ # 工具名称
+ mysql_analyzer
+ # 工具描述
+ 分析MySQL数据库性能
+ # 工具入参
+ {
+ "host": "192.0.0.1",
+ "port": 3306,
+ "username": "root",
+ "password": "password"
+ }
+ # 附加信息
+ 1. 当前MySQL数据库的版本是8.0.26
+ 2. 当前MySQL数据库的配置文件路径是/etc/my.cnf,并含有以下配置项
+ ```ini
+ [mysqld]
+ innodb_buffer_pool_size=1G
+ innodb_log_file_size=256M
+ ```
+ # 输出
+ ```json
+ {
+ "risk": "中",
+ "reason": "当前工具将连接到MySQL数据库并分析性能,可能会对数据库性能产生一定影响。请确保在非生产环境中执行此操作。"
+ }
+ ```
+ # 工具
+ < tool >
+ < name > {{tool_name}} < /name >
+ < description > {{tool_description}} < /description >
+ < / tool >
+ # 工具入参
+ {{input_param}}
+ # 附加信息
+ {{additional_info}}
+ # 输出
+ """
+ )
+# 根据当前计划和报错信息决定下一步执行,具体计划有需要用户补充工具入参、重计划当前步骤、重计划接下来的所有计划
+TOOL_EXECUTE_ERROR_TYPE_ANALYSIS = dedent(r"""
+ 你是一个计划决策器。
+ 你的任务是根据用户目标、当前计划、当前使用的工具、工具入参和工具运行报错,决定下一步执行的操作。
+ 请根据以下规则进行判断:
+ 1. 仅通过补充工具入参来解决问题的,返回 missing_param;
+ 2. 需要重计划当前步骤的,返回 decorrect_plan
+ 3.推理过程必须清晰明了,能够让人理解你的判断依据,并且不超过100字。
+ 你的输出要以json格式返回,格式如下:
+ ```json
+ {
+ "error_type": "missing_param/decorrect_plan,
+ "reason": "你的推理过程"
+ }
+ ```
+ # 样例
+ # 用户目标
+ 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优
+ # 当前计划
+ {"plans": [
+ {
+ "content": "生成端口扫描命令",
+ "tool": "command_generator",
+ "instruction": "生成端口扫描命令:扫描192.168.1.1的开放端口"
+ },
+ {
+ "content": "在执行Result[0]生成的命令",
+ "tool": "command_executor",
+ "instruction": "执行端口扫描命令"
+ },
+ {
+ "content": "任务执行完成,端口扫描结果为Result[2]",
+ "tool": "Final",
+ "instruction": ""
+ }
+ ]}
+ # 当前使用的工具
+ < tool >
+ < name > command_executor < /name >
+ < description > 执行命令行指令 < /description >
+ < / tool >
+ # 工具入参
+ {
+ "command": "nmap -sS -p--open 192.168.1.1"
+ }
+ # 工具运行报错
+ 执行端口扫描命令时,出现了错误:`- bash: nmap: command not found`。
+ # 输出
+ ```json
+ {
+ "error_type": "decorrect_plan",
+ "reason": "当前计划的第二步执行失败,报错信息显示nmap命令未找到,可能是因为没有安装nmap工具,因此需要重计划当前步骤。"
+ }
+ ```
+ # 用户目标
+ {{goal}}
+ # 当前计划
+ {{current_plan}}
+ # 当前使用的工具
+ < tool >
+ < name > {{tool_name}} < /name >
+ < description > {{tool_description}} < /description >
+ < / tool >
+ # 工具入参
+ {{input_param}}
+ # 工具运行报错
+ {{error_message}}
+ # 输出
+ """
+ )
+IS_PARAM_ERROR = dedent(r"""
+ 你是一个计划执行专家,你的任务是判断当前的步骤执行失败是否是因为参数错误导致的,
+ 如果是,请返回`true`,否则返回`false`。
+ 必须按照以下格式回答:
+ ```json
+ {
+ "is_param_error": true/false,
+ }
+ ```
+ # 样例
+ # 用户目标
+ 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优
+ # 历史
+ 第1步:生成端口扫描命令
+ - 调用工具 `command_generator`,并提供参数 `{"command": "nmap -sS -p--open 192.168.1.1"}`
+ - 执行状态:成功
+ - 得到数据:`{"command": "nmap -sS -p--open 192.168.1.1"}`
+ 第2步:执行端口扫描命令
+ - 调用工具 `command_executor`,并提供参数 `{"command": "nmap -sS -p--open 192.168.1.1"}`
+ - 执行状态:成功
+ - 得到数据:`{"result": "success"}`
+ # 当前步骤
+
+ step_3
+ mysql_analyzer
+ 分析MySQL数据库性能
+
+ # 工具入参
+ {
+ "host": "192.0.0.1",
+ "port": 3306,
+ "username": "root",
+ "password": "password"
+ }
+ # 工具运行报错
+ 执行MySQL性能分析命令时,出现了错误:`host is not correct`。
+
+ # 输出
+ ```json
+ {
+ "is_param_error": true
+ }
+ ```
+ # 用户目标
+ {{goal}}
+ # 历史
+ {{history}}
+ # 当前步骤
+
+ {{step_id}}
+ {{step_name}}
+ {{step_instruction}}
+
+ # 工具入参
+ {{input_param}}
+ # 工具运行报错
+ {{error_message}}
+ # 输出
+ """
+ )
+# 将当前程序运行的报错转换为自然语言
+CHANGE_ERROR_MESSAGE_TO_DESCRIPTION = dedent(r"""
+ 你是一个智能助手,你的任务是将当前程序运行的报错转换为自然语言描述。
+ 请根据以下规则进行转换:
+ 1. 将报错信息转换为自然语言描述,描述应该简洁明了,能够让人理解报错的原因和影响。
+ 2. 描述应该包含报错的具体内容和可能的解决方案。
+ 3. 描述应该避免使用过于专业的术语,以便用户能够理解。
+ 4. 描述应该尽量简短,控制在50字以内。
+ 5. 只输出自然语言描述,不要输出其他内容。
+ # 样例
+ # 工具信息
+ < tool >
+ < name > port_scanner < /name >
+ < description > 扫描主机端口 < /description >
+ < input_schema >
+ {
+ "type": "object",
+ "properties": {
+ "host": {
+ "type": "string",
+ "description": "主机地址"
+ },
+ "port": {
+ "type": "integer",
+ "description": "端口号"
+ },
+ "username": {
+ "type": "string",
+ "description": "用户名"
+ },
+ "password": {
+ "type": "string",
+ "description": "密码"
+ }
+ },
+ "required": ["host", "port", "username", "password"]
+ }
+ < /input_schema >
+ < / tool >
+ # 工具入参
+ {
+ "host": "192.0.0.1",
+ "port": 3306,
+ "username": "root",
+ "password": "password"
+ }
+ # 报错信息
+ 执行端口扫描命令时,出现了错误:`password is not correct`。
+ # 输出
+ 扫描端口时发生错误:密码不正确。请检查输入的密码是否正确,并重试。
+ # 现在开始转换报错信息:
+ # 工具信息
+ < tool >
+ < name > {{tool_name}} < /name >
+ < description > {{tool_description}} < /description >
+ < input_schema >
+ {{input_schema}}
+ < /input_schema >
+ < / tool >
+ # 工具入参
+ {{input_params}}
+ # 报错信息
+ {{error_message}}
+ # 输出
+ """)
+# 获取缺失的参数的json结构体
+GET_MISSING_PARAMS = dedent(r"""
+ 你是一个工具参数获取器。
+ 你的任务是根据当前工具的名称、描述和入参和入参的schema以及运行报错,将当前缺失的参数设置为null,并输出一个JSON格式的字符串。
+ ```json
+ {
+ "host": "请补充主机地址",
+ "port": "请补充端口号",
+ "username": "请补充用户名",
+ "password": "请补充密码"
+ }
+ ```
+ # 样例
+ # 工具名称
+ mysql_analyzer
+ # 工具描述
+ 分析MySQL数据库性能
+ # 工具入参
+ {
+ "host": "192.0.0.1",
+ "port": 3306,
+ "username": "root",
+ "password": "password"
+ }
+ # 工具入参schema
+ {
+ "type": "object",
+ "properties": {
+ "host": {
+ "anyOf": [
+ {"type": "string"},
+ {"type": "null"}
+ ],
+ "description": "MySQL数据库的主机地址(可以为字符串或null)"
+ },
+ "port": {
+ "anyOf": [
+ {"type": "string"},
+ {"type": "null"}
+ ],
+ "description": "MySQL数据库的端口号(可以是数字、字符串或null)"
+ },
+ "username": {
+ "anyOf": [
+ {"type": "string"},
+ {"type": "null"}
+ ],
+ "description": "MySQL数据库的用户名(可以为字符串或null)"
+ },
+ "password": {
+ "anyOf": [
+ {"type": "string"},
+ {"type": "null"}
+ ],
+ "description": "MySQL数据库的密码(可以为字符串或null)"
+ }
+ },
+ "required": ["host", "port", "username", "password"]
+ }
+ # 运行报错
+ 执行端口扫描命令时,出现了错误:`password is not correct`。
+ # 输出
+ ```json
+ {
+ "host": "192.0.0.1",
+ "port": 3306,
+ "username": null,
+ "password": null
+ }
+ ```
+ # 工具
+ < tool >
+ < name > {{tool_name}} < /name >
+ < description > {{tool_description}} < /description >
+ < / tool >
+ # 工具入参
+ {{input_param}}
+ # 工具入参schema(部分字段允许为null)
+ {{input_schema}}
+ # 运行报错
+ {{error_message}}
+ # 输出
+ """
+ )
+REPAIR_PARAMS = dedent(r"""
+ 你是一个工具参数修复器。
+ 你的任务是根据当前的工具信息、目标、工具入参的schema、工具当前的入参、工具的报错、补充的参数和补充的参数描述,修复当前工具的入参。
+
+ 注意:
+ 1.最终修复的参数要符合目标和工具入参的schema。
+
+ # 样例
+ # 工具信息
+ < tool >
+ < name > mysql_analyzer < /name >
+ < description > 分析MySQL数据库性能 < /description >
+ < / tool >
+ # 总目标
+ 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优
+ # 当前阶段目标
+ 我要连接MySQL数据库,分析性能瓶颈,并调优。
+ # 工具入参的schema
+ {
+ "type": "object",
+ "properties": {
+ "host": {
+ "type": "string",
+ "description": "MySQL数据库的主机地址"
+ },
+ "port": {
+ "type": "integer",
+ "description": "MySQL数据库的端口号"
+ },
+ "username": {
+ "type": "string",
+ "description": "MySQL数据库的用户名"
+ },
+ "password": {
+ "type": "string",
+ "description": "MySQL数据库的密码"
+ }
+ },
+ "required": ["host", "port", "username", "password"]
+ }
+ # 工具当前的入参
+ {
+ "host": "192.0.0.1",
+ "port": 3306,
+ "username": "root",
+ "password": "password"
+ }
+ # 工具的报错
+ 执行端口扫描命令时,出现了错误:`password is not correct`。
+ # 补充的参数
+ {
+ "username": "admin",
+ "password": "admin123"
+ }
+ # 补充的参数描述
+ 用户希望使用admin用户和admin123密码来连接MySQL数据库。
+ # 输出
+ ```json
+ {
+ "host": "192.0.0.1",
+ "port": 3306,
+ "username": "admin",
+ "password": "admin123"
+ }
+ ```
+ # 工具
+ < tool >
+ < name > {{tool_name}} < /name >
+ < description > {{tool_description}} < /description >
+ < / tool >
+ # 总目标
+ {{goal}}
+ # 当前阶段目标
+ {{current_goal}}
+ # 工具入参scheme
+ {{input_schema}}
+ # 工具入参
+ {{input_param}}
+ # 工具描述
+ {{tool_description}}
+ # 运行报错
+ {{error_message}}
+ # 补充的参数
+ {{params}}
+ # 补充的参数描述
+ {{params_description}}
+ # 输出
+ """
+ )
+FINAL_ANSWER = dedent(r"""
+ 综合理解计划执行结果和背景信息,向用户报告目标的完成情况。
+
+ # 用户目标
+
+ {{goal}}
+
+ # 计划执行情况
+
+ 为了完成上述目标,你实施了以下计划:
+
+ {{memory}}
+
+
+ # 现在,请根据以上信息,向用户报告目标的完成情况:
+
+""")
+MEMORY_TEMPLATE = dedent(r"""
+ { % for ctx in context_list % }
+ - 第{{loop.index}}步:{{ctx.step_description}}
+ - 调用工具 `{{ctx.step_id}}`,并提供参数 `{{ctx.input_data}}`
+ - 执行状态:{{ctx.step_status}}
+ - 得到数据:`{{ctx.output_data}}`
+ { % endfor % }
+""")
diff --git a/apps/scheduler/mcp_agent/schema.py b/apps/scheduler/mcp_agent/schema.py
deleted file mode 100644
index 614139074382daf128a19a320c949e5e46803c4d..0000000000000000000000000000000000000000
--- a/apps/scheduler/mcp_agent/schema.py
+++ /dev/null
@@ -1,148 +0,0 @@
-"""MCP Agent执行数据结构"""
-from typing import Any, Self
-
-from pydantic import BaseModel, Field
-
-from apps.schemas.enum_var import Role
-
-
-class Function(BaseModel):
- """工具函数"""
-
- name: str
- arguments: dict[str, Any]
-
-
-class ToolCall(BaseModel):
- """Represents a tool/function call in a message"""
-
- id: str
- type: str = "function"
- function: Function
-
-
-class Message(BaseModel):
- """Represents a chat message in the conversation"""
-
- role: Role = Field(...)
- content: str | None = Field(default=None)
- tool_calls: list[ToolCall] | None = Field(default=None)
- name: str | None = Field(default=None)
- tool_call_id: str | None = Field(default=None)
-
- def __add__(self, other) -> list["Message"]:
- """支持 Message + list 或 Message + Message 的操作"""
- if isinstance(other, list):
- return [self] + other
- elif isinstance(other, Message):
- return [self, other]
- else:
- raise TypeError(
- f"unsupported operand type(s) for +: '{type(self).__name__}' and '{type(other).__name__}'"
- )
-
- def __radd__(self, other) -> list["Message"]:
- """支持 list + Message 的操作"""
- if isinstance(other, list):
- return other + [self]
- else:
- raise TypeError(
- f"unsupported operand type(s) for +: '{type(other).__name__}' and '{type(self).__name__}'"
- )
-
- def to_dict(self) -> dict:
- """Convert message to dictionary format"""
- message = {"role": self.role}
- if self.content is not None:
- message["content"] = self.content
- if self.tool_calls is not None:
- message["tool_calls"] = [tool_call.dict() for tool_call in self.tool_calls]
- if self.name is not None:
- message["name"] = self.name
- if self.tool_call_id is not None:
- message["tool_call_id"] = self.tool_call_id
- return message
-
- @classmethod
- def user_message(cls, content: str) -> Self:
- """Create a user message"""
- return cls(role=Role.USER, content=content)
-
- @classmethod
- def system_message(cls, content: str) -> Self:
- """Create a system message"""
- return cls(role=Role.SYSTEM, content=content)
-
- @classmethod
- def assistant_message(
- cls, content: str | None = None,
- ) -> Self:
- """Create an assistant message"""
- return cls(role=Role.ASSISTANT, content=content)
-
- @classmethod
- def tool_message(
- cls, content: str, name: str, tool_call_id: str,
- ) -> Self:
- """Create a tool message"""
- return cls(
- role=Role.TOOL,
- content=content,
- name=name,
- tool_call_id=tool_call_id,
- )
-
- @classmethod
- def from_tool_calls(
- cls,
- tool_calls: list[Any],
- content: str | list[str] = "",
- **kwargs, # noqa: ANN003
- ) -> Self:
- """Create ToolCallsMessage from raw tool calls.
-
- Args:
- tool_calls: Raw tool calls from LLM
- content: Optional message content
- """
- formatted_calls = [
- {"id": call.id, "function": call.function.model_dump(), "type": "function"}
- for call in tool_calls
- ]
- return cls(
- role=Role.ASSISTANT,
- content=content,
- tool_calls=formatted_calls,
- **kwargs,
- )
-
-
-class Memory(BaseModel):
- messages: list[Message] = Field(default_factory=list)
- max_messages: int = Field(default=100)
-
- def add_message(self, message: Message) -> None:
- """Add a message to memory"""
- self.messages.append(message)
- # Optional: Implement message limit
- if len(self.messages) > self.max_messages:
- self.messages = self.messages[-self.max_messages:]
-
- def add_messages(self, messages: list[Message]) -> None:
- """Add multiple messages to memory"""
- self.messages.extend(messages)
- # Optional: Implement message limit
- if len(self.messages) > self.max_messages:
- self.messages = self.messages[-self.max_messages:]
-
- def clear(self) -> None:
- """Clear all messages"""
- self.messages.clear()
-
- def get_recent_messages(self, n: int) -> list[Message]:
- """Get n most recent messages"""
- return self.messages[-n:]
-
- def to_dict_list(self) -> list[dict]:
- """Convert messages to list of dicts"""
- return [msg.to_dict() for msg in self.messages]
diff --git a/apps/scheduler/mcp_agent/select.py b/apps/scheduler/mcp_agent/select.py
new file mode 100644
index 0000000000000000000000000000000000000000..075e08f00cb2b5a2976d5372c623193b04c42412
--- /dev/null
+++ b/apps/scheduler/mcp_agent/select.py
@@ -0,0 +1,109 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""选择MCP Server及其工具"""
+
+import logging
+import random
+from jinja2 import BaseLoader
+from jinja2.sandbox import SandboxedEnvironment
+from typing import AsyncGenerator, Any
+
+from apps.llm.function import JsonGenerator
+from apps.llm.reasoning import ReasoningLLM
+from apps.common.lance import LanceDB
+from apps.common.mongo import MongoDB
+from apps.llm.embedding import Embedding
+from apps.llm.function import FunctionLLM
+from apps.llm.reasoning import ReasoningLLM
+from apps.llm.token import TokenCalculator
+from apps.scheduler.mcp_agent.base import McpBase
+from apps.scheduler.mcp_agent.prompt import TOOL_SELECT
+from apps.schemas.mcp import (
+ BaseModel,
+ MCPCollection,
+ MCPSelectResult,
+ MCPTool,
+ MCPToolIdsSelectResult
+)
+from apps.common.config import Config
+logger = logging.getLogger(__name__)
+
+_env = SandboxedEnvironment(
+ loader=BaseLoader,
+ autoescape=True,
+ trim_blocks=True,
+ lstrip_blocks=True,
+)
+
+FINAL_TOOL_ID = "FIANL"
+SUMMARIZE_TOOL_ID = "SUMMARIZE"
+
+
+class MCPSelector(McpBase):
+ """MCP选择器"""
+
+ @staticmethod
+ async def select_top_tool(
+ goal: str, tool_list: list[MCPTool],
+ additional_info: str | None = None, top_n: int | None = None,
+ reasoning_llm: ReasoningLLM | None = None) -> list[MCPTool]:
+ """选择最合适的工具"""
+ random.shuffle(tool_list)
+ max_tokens = reasoning_llm._config.max_tokens
+ template = _env.from_string(TOOL_SELECT)
+ token_calculator = TokenCalculator()
+ if token_calculator.calculate_token_length(
+ messages=[{"role": "user", "content": template.render(
+ goal=goal, tools=[], additional_info=additional_info
+ )}],
+ pure_text=True) > max_tokens:
+ logger.warning("[MCPSelector] 工具选择模板长度超过最大令牌数,无法进行选择")
+ return []
+ current_index = 0
+ tool_ids = []
+ while current_index < len(tool_list):
+ index = current_index
+ sub_tools = []
+ while index < len(tool_list):
+ tool = tool_list[index]
+ tokens = token_calculator.calculate_token_length(
+ messages=[{"role": "user", "content": template.render(
+ goal=goal, tools=[tool],
+ additional_info=additional_info
+ )}],
+ pure_text=True
+ )
+ if tokens > max_tokens:
+ continue
+ sub_tools.append(tool)
+
+ tokens = token_calculator.calculate_token_length(messages=[{"role": "user", "content": template.render(
+ goal=goal, tools=sub_tools, additional_info=additional_info)}, ], pure_text=True)
+ if tokens > max_tokens:
+ del sub_tools[-1]
+ break
+ else:
+ index += 1
+ current_index = index
+ if sub_tools:
+ schema = MCPToolIdsSelectResult.model_json_schema()
+ if "items" not in schema["properties"]["tool_ids"]:
+ schema["properties"]["tool_ids"]["items"] = {}
+ # 将enum添加到items中,限制数组元素的可选值
+ schema["properties"]["tool_ids"]["items"]["enum"] = [tool.id for tool in sub_tools]
+ result = await MCPSelector.get_resoning_result(template.render(goal=goal, tools=sub_tools, additional_info="请根据目标选择对应的工具"), reasoning_llm)
+ result = await MCPSelector._parse_result(result, schema)
+ try:
+ result = MCPToolIdsSelectResult.model_validate(result)
+ tool_ids.extend(result.tool_ids)
+ except Exception:
+ logger.exception("[MCPSelector] 解析MCP工具ID选择结果失败")
+ continue
+ mcp_tools = [tool for tool in tool_list if tool.id in tool_ids]
+
+ if top_n is not None:
+ mcp_tools = mcp_tools[:top_n]
+ mcp_tools.append(MCPTool(id=FINAL_TOOL_ID, name="Final",
+ description="终止", mcp_id=FINAL_TOOL_ID, input_schema={}))
+ # mcp_tools.append(MCPTool(id=SUMMARIZE_TOOL_ID, name="Summarize",
+ # description="总结工具", mcp_id=SUMMARIZE_TOOL_ID, input_schema={}))
+ return mcp_tools
diff --git a/apps/scheduler/mcp_agent/tool/__init__.py b/apps/scheduler/mcp_agent/tool/__init__.py
deleted file mode 100644
index 4593f31742fee21b2e3ec1c7c18ff8e3cfea2110..0000000000000000000000000000000000000000
--- a/apps/scheduler/mcp_agent/tool/__init__.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from apps.scheduler.mcp_agent.tool.base import BaseTool
-from apps.scheduler.mcp_agent.tool.terminate import Terminate
-from apps.scheduler.mcp_agent.tool.tool_collection import ToolCollection
-
-__all__ = [
- "BaseTool",
- "Terminate",
- "ToolCollection",
-]
diff --git a/apps/scheduler/mcp_agent/tool/base.py b/apps/scheduler/mcp_agent/tool/base.py
deleted file mode 100644
index 04ad45c47a3eecb25efdf5b2ce52beb6965b2fbd..0000000000000000000000000000000000000000
--- a/apps/scheduler/mcp_agent/tool/base.py
+++ /dev/null
@@ -1,73 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Any, Dict, Optional
-
-from pydantic import BaseModel, Field
-
-
-class BaseTool(ABC, BaseModel):
- name: str
- description: str
- parameters: Optional[dict] = None
-
- class Config:
- arbitrary_types_allowed = True
-
- async def __call__(self, **kwargs) -> Any:
- return await self.execute(**kwargs)
-
- @abstractmethod
- async def execute(self, **kwargs) -> Any:
- """使用给定的参数执行工具"""
-
- def to_param(self) -> Dict:
- """将工具转换为函数调用格式"""
- return {
- "type": "function",
- "function": {
- "name": self.name,
- "description": self.description,
- "parameters": self.parameters,
- },
- }
-
-
-class ToolResult(BaseModel):
- """表示工具执行的结果"""
-
- output: Any = Field(default=None)
- error: Optional[str] = Field(default=None)
- system: Optional[str] = Field(default=None)
-
- class Config:
- arbitrary_types_allowed = True
-
- def __bool__(self):
- return any(getattr(self, field) for field in self.__fields__)
-
- def __add__(self, other: "ToolResult"):
- def combine_fields(
- field: Optional[str], other_field: Optional[str], concatenate: bool = True
- ):
- if field and other_field:
- if concatenate:
- return field + other_field
- raise ValueError("Cannot combine tool results")
- return field or other_field
-
- return ToolResult(
- output=combine_fields(self.output, other.output),
- error=combine_fields(self.error, other.error),
- system=combine_fields(self.system, other.system),
- )
-
- def __str__(self):
- return f"Error: {self.error}" if self.error else self.output
-
- def replace(self, **kwargs):
- """返回一个新的ToolResult,其中替换了给定的字段"""
- # return self.copy(update=kwargs)
- return type(self)(**{**self.dict(), **kwargs})
-
-
-class ToolFailure(ToolResult):
- """表示失败的ToolResult"""
diff --git a/apps/scheduler/mcp_agent/tool/terminate.py b/apps/scheduler/mcp_agent/tool/terminate.py
deleted file mode 100644
index 84aa120316de1123f985eebfd82c471e8ceec990..0000000000000000000000000000000000000000
--- a/apps/scheduler/mcp_agent/tool/terminate.py
+++ /dev/null
@@ -1,25 +0,0 @@
-from apps.scheduler.mcp_agent.tool.base import BaseTool
-
-
-_TERMINATE_DESCRIPTION = """当请求得到满足或助理无法继续处理任务时,终止交互。
-当您完成所有任务后,调用此工具结束工作。"""
-
-
-class Terminate(BaseTool):
- name: str = "terminate"
- description: str = _TERMINATE_DESCRIPTION
- parameters: dict = {
- "type": "object",
- "properties": {
- "status": {
- "type": "string",
- "description": "交互的完成状态",
- "enum": ["success", "failure"],
- }
- },
- "required": ["status"],
- }
-
- async def execute(self, status: str) -> str:
- """Finish the current execution"""
- return f"交互已完成,状态为: {status}"
diff --git a/apps/scheduler/mcp_agent/tool/tool_collection.py b/apps/scheduler/mcp_agent/tool/tool_collection.py
deleted file mode 100644
index 95bda317805abdecc256af0737091b46adf77b1a..0000000000000000000000000000000000000000
--- a/apps/scheduler/mcp_agent/tool/tool_collection.py
+++ /dev/null
@@ -1,55 +0,0 @@
-"""用于管理多个工具的集合类"""
-import logging
-from typing import Any
-
-from apps.scheduler.mcp_agent.tool.base import BaseTool, ToolFailure, ToolResult
-
-logger = logging.getLogger(__name__)
-
-
-class ToolCollection:
- """定义工具的集合"""
-
- class Config:
- arbitrary_types_allowed = True
-
- def __init__(self, *tools: BaseTool):
- self.tools = tools
- self.tool_map = {tool.name: tool for tool in tools}
-
- def __iter__(self):
- return iter(self.tools)
-
- def to_params(self) -> list[dict[str, Any]]:
- return [tool.to_param() for tool in self.tools]
-
- async def execute(
- self, *, name: str, tool_input: dict[str, Any] = None
- ) -> ToolResult:
- tool = self.tool_map.get(name)
- if not tool:
- return ToolFailure(error=f"Tool {name} is invalid")
- try:
- result = await tool(**tool_input)
- return result
- except Exception as e:
- return ToolFailure(error=f"Failed to execute tool {name}: {e}")
-
- def add_tool(self, tool: BaseTool):
- """
- 将单个工具添加到集合中。
-
- 如果已存在同名工具,则将跳过该工具并记录警告。
- """
- if tool.name in self.tool_map:
- logger.warning(f"Tool {tool.name} already exists in collection, skipping")
- return self
-
- self.tools += (tool,)
- self.tool_map[tool.name] = tool
- return self
-
- def add_tools(self, *tools: BaseTool):
- for tool in tools:
- self.add_tool(tool)
- return self
diff --git a/apps/scheduler/pool/loader/app.py b/apps/scheduler/pool/loader/app.py
index a85395bfc7723af9d009a1945fe4617cf390b938..49407ba95358d4442c714087fd48b1aaca219f8d 100644
--- a/apps/scheduler/pool/loader/app.py
+++ b/apps/scheduler/pool/loader/app.py
@@ -78,6 +78,7 @@ class AppLoader:
# 加载模型
try:
metadata = AgentAppMetadata.model_validate(metadata)
+ logger.info(f"[AppLoader] Agent应用元数据验证成功: {metadata}")
except Exception as e:
err = "[AppLoader] Agent应用元数据验证失败"
logger.exception(err)
@@ -102,7 +103,6 @@ class AppLoader:
await file_checker.diff_one(app_path)
await self.load(app_id, file_checker.hashes[f"app/{app_id}"])
-
@staticmethod
async def delete(app_id: str, *, is_reload: bool = False) -> None:
"""
@@ -157,5 +157,7 @@ class AppLoader:
},
upsert=True,
)
+ app_pool = await app_collection.find_one({"_id": metadata.id})
+ logger.error(f"[AppLoader] 更新 MongoDB 成功: {app_pool}")
except Exception:
logger.exception("[AppLoader] 更新 MongoDB 失败")
diff --git a/apps/scheduler/pool/loader/mcp.py b/apps/scheduler/pool/loader/mcp.py
index 1463d0a1a9141bca271e29d7387759748aa8bd43..1da06b53630d3188ca55c16969823f3d1e063516 100644
--- a/apps/scheduler/pool/loader/mcp.py
+++ b/apps/scheduler/pool/loader/mcp.py
@@ -11,6 +11,7 @@ import shutil
import asyncer
from anyio import Path
from sqids.sqids import Sqids
+from typing import Any
from apps.common.lance import LanceDB
from apps.common.mongo import MongoDB
@@ -91,48 +92,47 @@ class MCPLoader(metaclass=SingletonMeta):
:param MCPServerConfig config: MCP配置
:return: 无
"""
- if not config.config.auto_install:
- print(f"[Installer] MCP模板无需安装: {mcp_id}") # noqa: T201
-
- elif isinstance(config.config, MCPServerStdioConfig):
- print(f"[Installer] Stdio方式的MCP模板,开始自动安装: {mcp_id}") # noqa: T201
- if "uv" in config.config.command:
- new_config = await install_uvx(mcp_id, config.config)
- elif "npx" in config.config.command:
- new_config = await install_npx(mcp_id, config.config)
-
- if new_config is None:
- logger.error("[MCPLoader] MCP模板安装失败: %s", mcp_id)
- await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.FAILED)
- return
-
- config.config = new_config
-
- # 重新保存config
- template_config = MCP_PATH / "template" / mcp_id / "config.json"
- f = await template_config.open("w+", encoding="utf-8")
- config_data = config.model_dump(by_alias=True, exclude_none=True)
- await f.write(json.dumps(config_data, indent=4, ensure_ascii=False))
- await f.aclose()
-
- else:
- print(f"[Installer] SSE/StreamableHTTP方式的MCP模板,无需安装: {mcp_id}") # noqa: T201
- config.config.auto_install = False
-
- print(f"[Installer] MCP模板安装成功: {mcp_id}") # noqa: T201
- await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.READY)
- await MCPLoader._insert_template_tool(mcp_id, config)
+ try:
+ if not config.config.auto_install:
+ print(f"[Installer] MCP模板无需安装: {mcp_id}") # noqa: T201
+ elif isinstance(config.config, MCPServerStdioConfig):
+ print(f"[Installer] Stdio方式的MCP模板,开始自动安装: {mcp_id}") # noqa: T201
+ if "uv" in config.config.command:
+ new_config = await install_uvx(mcp_id, config.config)
+ elif "npx" in config.config.command:
+ new_config = await install_npx(mcp_id, config.config)
+
+ if new_config is None:
+ logger.error("[MCPLoader] MCP模板安装失败: %s", mcp_id)
+ await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.FAILED)
+ return
+
+ config.config = new_config
+
+ # 重新保存config
+ template_config = MCP_PATH / "template" / mcp_id / "config.json"
+ f = await template_config.open("w+", encoding="utf-8")
+ config_data = config.model_dump(by_alias=True, exclude_none=True)
+ await f.write(json.dumps(config_data, indent=4, ensure_ascii=False))
+ await f.aclose()
+
+ else:
+ logger.info(f"[Installer] SSE/StreamableHTTP方式的MCP模板,无需安装: {mcp_id}") # noqa: T201
+ config.config.auto_install = False
+
+ await MCPLoader._insert_template_tool(mcp_id, config)
+ await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.READY)
+ logger.info(f"[Installer] MCP模板安装成功: {mcp_id}") # noqa: T201
+ except Exception as e:
+ logger.error("[MCPLoader] MCP模板安装失败: %s, 错误: %s", mcp_id, e)
+ await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.FAILED)
+ raise
@staticmethod
- async def init_one_template(mcp_id: str, config: MCPServerConfig) -> None:
+ async def clear_ready_or_failed_mcp_installation() -> None:
"""
- 初始化单个MCP模板
-
- :param str mcp_id: MCP模板ID
- :param MCPServerConfig config: MCP配置
- :return: 无
+ 清除状态为ready或failed的MCP安装任务
"""
- # 删除完成或者失败的MCP安装任务
mcp_collection = MongoDB().get_collection("mcp")
mcp_ids = ProcessHandler.get_all_task_ids()
# 检索_id在mcp_ids且状态为ready或者failed的MCP的内容
@@ -147,8 +147,17 @@ class MCPLoader(metaclass=SingletonMeta):
continue
ProcessHandler.remove_task(item.id)
logger.info("[MCPLoader] 删除已完成或失败的MCP安装进程: %s", item.id)
- # 插入数据库;这里用旧的config就可以
- await MCPLoader._insert_template_db(mcp_id, config)
+
+ @staticmethod
+ async def init_one_template(mcp_id: str, config: MCPServerConfig) -> None:
+ """
+ 初始化单个MCP模板
+
+ :param str mcp_id: MCP模板ID
+ :param MCPServerConfig config: MCP配置
+ :return: 无
+ """
+ await MCPLoader.clear_ready_or_failed_mcp_installation()
# 检查目录
template_path = MCP_PATH / "template" / mcp_id
@@ -158,37 +167,31 @@ class MCPLoader(metaclass=SingletonMeta):
err = f"安装任务无法执行,请稍后重试: {mcp_id}"
logger.error(err)
raise RuntimeError(err)
+ # 将installing状态的安装任务的状态变为cancelled
@staticmethod
- async def _init_all_template() -> None:
+ async def cancel_all_installing_task() -> None:
"""
- 初始化所有MCP模板
-
- 遍历 ``template`` 目录下的所有MCP模板,并初始化。在Framework启动时进行此流程,确保所有MCP均可正常使用。
- 这一过程会与数据库内的条目进行对比,若发生修改,则重新创建数据库条目。
+ 取消正在安装的MCP模板任务
"""
template_path = MCP_PATH / "template"
logger.info("[MCPLoader] 初始化所有MCP模板: %s", template_path)
-
+ mongo = MongoDB()
+ mcp_collection = mongo.get_collection("mcp")
# 遍历所有模板
+ mcp_ids = []
async for mcp_dir in template_path.iterdir():
# 不是目录
if not await mcp_dir.is_dir():
logger.warning("[MCPLoader] 跳过非目录: %s", mcp_dir.as_posix())
continue
- # 检查配置文件是否存在
- config_path = mcp_dir / "config.json"
- if not await config_path.exists():
- logger.warning("[MCPLoader] 跳过没有配置文件的MCP模板: %s", mcp_dir.as_posix())
- continue
-
- # 读取配置并加载
- config = await MCPLoader._load_config(config_path)
-
- # 初始化第一个MCP Server
- logger.info("[MCPLoader] 初始化MCP模板: %s", mcp_dir.as_posix())
- await MCPLoader.init_one_template(mcp_dir.name, config)
+ mcp_ids.append(mcp_dir.name)
+ # 更新数据库状态
+ await mcp_collection.update_many(
+ {"_id": {"$in": mcp_ids}, "status": MCPInstallStatus.INSTALLING},
+ {"$set": {"status": MCPInstallStatus.CANCELLED}},
+ )
@staticmethod
async def _get_template_tool(
@@ -384,6 +387,7 @@ class MCPLoader(metaclass=SingletonMeta):
# 更新数据库
mongo = MongoDB()
mcp_collection = mongo.get_collection("mcp")
+ logger.info("[MCPLoader] 更新MCP模板状态: %s -> %s", mcp_id, status)
await mcp_collection.update_one(
{"_id": mcp_id},
{"$set": {"status": status}},
@@ -391,7 +395,7 @@ class MCPLoader(metaclass=SingletonMeta):
)
@staticmethod
- async def user_active_template(user_sub: str, mcp_id: str) -> None:
+ async def user_active_template(user_sub: str, mcp_id: str, mcp_env: dict[str, Any] | None = None) -> None:
"""
用户激活MCP模板
@@ -409,7 +413,7 @@ class MCPLoader(metaclass=SingletonMeta):
if await user_path.exists():
err = f"MCP模板“{mcp_id}”已存在或有同名文件,无法激活"
raise FileExistsError(err)
-
+ mcp_config = await MCPLoader.get_config(mcp_id)
# 拷贝文件
await asyncer.asyncify(shutil.copytree)(
template_path.as_posix(),
@@ -417,7 +421,32 @@ class MCPLoader(metaclass=SingletonMeta):
dirs_exist_ok=True,
symlinks=True,
)
-
+ if mcp_env is not None:
+ mcp_config.config.env.update(mcp_env)
+ if mcp_config.type == MCPType.STDIO:
+ index = None
+ for i in range(len(mcp_config.config.args)):
+ if mcp_config.config.args[i] == "--directory":
+ index = i + 1
+ break
+ if index is not None:
+ if index < len(mcp_config.config.args):
+ mcp_config.config.args[index] = str(user_path)+'/project'
+ else:
+ mcp_config.config.args.append(str(user_path)+'/project')
+ else:
+ mcp_config.config.args = ["--directory", str(user_path)+'/project'] + mcp_config.config.args
+ user_config_path = user_path / "config.json"
+ # 更新用户配置
+ f = await user_config_path.open("w", encoding="utf-8", errors="ignore")
+ await f.write(
+ json.dumps(
+ mcp_config.model_dump(by_alias=True, exclude_none=True),
+ indent=4,
+ ensure_ascii=False,
+ )
+ )
+ await f.aclose()
# 更新数据库
mongo = MongoDB()
mcp_collection = mongo.get_collection("mcp")
@@ -468,6 +497,26 @@ class MCPLoader(metaclass=SingletonMeta):
logger.info("[MCPLoader] 这些MCP在文件系统中被删除: %s", deleted_mcp_list)
return deleted_mcp_list
+ @staticmethod
+ async def cancel_installing_task(cancel_mcp_list: list[str]) -> None:
+ """
+ 取消正在安装的MCP模板任务
+
+ :param list[str] cancel_mcp_list: 需要取消的MCP列表
+ :return: 无
+ """
+ mongo = MongoDB()
+ mcp_collection = mongo.get_collection("mcp")
+ # 更新数据库状态
+ cancel_mcp_list = await mcp_collection.distinct("_id", {"_id": {"$in": cancel_mcp_list}, "status": MCPInstallStatus.INSTALLING})
+ await mcp_collection.update_many(
+ {"_id": {"$in": cancel_mcp_list}, "status": MCPInstallStatus.INSTALLING},
+ {"$set": {"status": MCPInstallStatus.CANCELLED}},
+ )
+ for mcp_id in cancel_mcp_list:
+ ProcessHandler.remove_task(mcp_id)
+ logger.info("[MCPLoader] 取消这些正在安装的MCP模板任务: %s", cancel_mcp_list)
+
@staticmethod
async def remove_deleted_mcp(deleted_mcp_list: list[str]) -> None:
"""
@@ -573,8 +622,8 @@ class MCPLoader(metaclass=SingletonMeta):
# 检查目录
await MCPLoader._check_dir()
- # 初始化所有模板
- await MCPLoader._init_all_template()
+ # 暂停所有安装任务
+ await MCPLoader.cancel_all_installing_task()
# 加载用户MCP
await MCPLoader._load_user_mcp()
diff --git a/apps/scheduler/pool/loader/service.py b/apps/scheduler/pool/loader/service.py
index 2b9060461fc0f3baaece19e88c71776449e9752a..2d84069c434de99e45156452a1e5156ba03b2b3e 100644
--- a/apps/scheduler/pool/loader/service.py
+++ b/apps/scheduler/pool/loader/service.py
@@ -3,6 +3,7 @@
import asyncio
import logging
+import os
import shutil
from anyio import Path
@@ -30,6 +31,9 @@ class ServiceLoader:
"""加载单个Service"""
service_path = BASE_PATH / service_id
# 载入元数据
+ if not os.path.exists(service_path / "metadata.yaml"):
+ logger.error("[ServiceLoader] Service %s 的元数据不存在", service_id)
+ return
metadata = await MetadataLoader().load_one(service_path / "metadata.yaml")
if not isinstance(metadata, ServiceMetadata):
err = f"[ServiceLoader] 元数据类型错误: {service_path}/metadata.yaml"
@@ -48,7 +52,6 @@ class ServiceLoader:
# 更新数据库
await self._update_db(nodes, metadata)
-
async def save(self, service_id: str, metadata: ServiceMetadata, data: dict) -> None:
"""在文件系统上保存Service,并更新数据库"""
service_path = BASE_PATH / service_id
@@ -67,7 +70,6 @@ class ServiceLoader:
await file_checker.diff_one(service_path)
await self.load(service_id, file_checker.hashes[f"service/{service_id}"])
-
async def delete(self, service_id: str, *, is_reload: bool = False) -> None:
"""删除Service,并更新数据库"""
mongo = MongoDB()
@@ -95,7 +97,6 @@ class ServiceLoader:
if await path.exists():
shutil.rmtree(path)
-
async def _update_db(self, nodes: list[NodePool], metadata: ServiceMetadata) -> None: # noqa: C901, PLR0912, PLR0915
"""更新数据库"""
if not metadata.hashes:
@@ -197,4 +198,3 @@ class ServiceLoader:
await asyncio.sleep(0.01)
else:
raise
-
diff --git a/apps/scheduler/pool/mcp/client.py b/apps/scheduler/pool/mcp/client.py
index 092bac8909635a5e0c846dddef3456d8ad3be43c..0ced05e8cc68bb773cfe61eb9268a4f8baab8fa5 100644
--- a/apps/scheduler/pool/mcp/client.py
+++ b/apps/scheduler/pool/mcp/client.py
@@ -29,6 +29,7 @@ class MCPClient:
mcp_id: str
task: asyncio.Task
ready_sign: asyncio.Event
+ error_sign: asyncio.Event
stop_sign: asyncio.Event
client: ClientSession
status: MCPStatus
@@ -54,9 +55,10 @@ class MCPClient:
"""
# 创建Client
if isinstance(config, MCPServerSSEConfig):
+ env = config.env or {}
client = sse_client(
url=config.url,
- headers=config.env,
+ headers=env,
)
elif isinstance(config, MCPServerStdioConfig):
if user_sub:
@@ -64,7 +66,6 @@ class MCPClient:
else:
cwd = MCP_PATH / "template" / mcp_id / "project"
await cwd.mkdir(parents=True, exist_ok=True)
-
client = stdio_client(server=StdioServerParameters(
command=config.command,
args=config.args,
@@ -72,6 +73,7 @@ class MCPClient:
cwd=cwd.as_posix(),
))
else:
+ self.error_sign.set()
err = f"[MCPClient] MCP {mcp_id}:未知的MCP服务类型“{config.type}”"
logger.error(err)
raise TypeError(err)
@@ -85,23 +87,24 @@ class MCPClient:
# 初始化Client
await session.initialize()
except Exception:
+ self.error_sign.set()
+ self.status = MCPStatus.STOPPED
logger.exception("[MCPClient] MCP %s:初始化失败", mcp_id)
raise
self.ready_sign.set()
self.status = MCPStatus.RUNNING
-
# 等待关闭信号
await self.stop_sign.wait()
+ logger.error("[MCPClient] MCP %s:收到停止信号,正在关闭", mcp_id)
# 关闭Client
try:
- await exit_stack.aclose() # type: ignore[attr-defined]
+ await exit_stack.aclose() # type: ignore[attr-defined]
self.status = MCPStatus.STOPPED
except Exception:
logger.exception("[MCPClient] MCP %s:关闭失败", mcp_id)
-
async def init(self, user_sub: str | None, mcp_id: str, config: MCPServerSSEConfig | MCPServerStdioConfig) -> None:
"""
初始化 MCP Client类
@@ -116,27 +119,34 @@ class MCPClient:
# 初始化变量
self.mcp_id = mcp_id
self.ready_sign = asyncio.Event()
+ self.error_sign = asyncio.Event()
self.stop_sign = asyncio.Event()
# 创建协程
self.task = asyncio.create_task(self._main_loop(user_sub, mcp_id, config))
# 等待初始化完成
- await self.ready_sign.wait()
-
- # 获取工具列表
+ done, pending = await asyncio.wait(
+ [asyncio.create_task(self.ready_sign.wait()),
+ asyncio.create_task(self.error_sign.wait())],
+ return_when=asyncio.FIRST_COMPLETED
+ )
+ if self.error_sign.is_set():
+ self.status = MCPStatus.ERROR
+ logger.error("[MCPClient] MCP %s:初始化失败", mcp_id)
+ raise Exception(f"MCP {mcp_id} 初始化失败")
+
+ # 获取工具列表
self.tools = (await self.client.list_tools()).tools
-
async def call_tool(self, tool_name: str, params: dict) -> "CallToolResult":
"""调用MCP Server的工具"""
return await self.client.call_tool(tool_name, params)
-
async def stop(self) -> None:
"""停止MCP Client"""
self.stop_sign.set()
try:
await self.task
- except Exception:
- logger.exception("[MCPClient] MCP %s:停止失败", self.mcp_id)
+ except Exception as e:
+ logger.warning("[MCPClient] MCP %s:停止时发生异常:%s", self.mcp_id, e)
diff --git a/apps/scheduler/pool/mcp/install.py b/apps/scheduler/pool/mcp/install.py
index 1b6c3edeb3a042b7716d0d7748e9c3e50b01af74..2b15cd690bfd8b1939326728647e27ab517f89ff 100644
--- a/apps/scheduler/pool/mcp/install.py
+++ b/apps/scheduler/pool/mcp/install.py
@@ -3,12 +3,16 @@
from asyncio import subprocess
from typing import TYPE_CHECKING
-
+import logging
+import os
+import shutil
from apps.constants import MCP_PATH
if TYPE_CHECKING:
from apps.schemas.mcp import MCPServerStdioConfig
+logger = logging.getLogger(__name__)
+
async def install_uvx(mcp_id: str, config: "MCPServerStdioConfig") -> "MCPServerStdioConfig | None":
"""
@@ -23,27 +27,35 @@ async def install_uvx(mcp_id: str, config: "MCPServerStdioConfig") -> "MCPServer
:rtype: MCPServerStdioConfig
:raises ValueError: 未找到MCP Server对应的Python包
"""
- # 创建文件夹
- mcp_path = MCP_PATH / "template" / mcp_id / "project"
- await mcp_path.mkdir(parents=True, exist_ok=True)
-
+ uv_path = shutil.which('uv')
+ if uv_path is None:
+ error = "[Installer] 未找到uv命令,请先安装uv包管理器: pip install uv"
+ logging.error(error)
+ raise Exception(error)
# 找到包名
- package = ""
+ package = None
for arg in config.args:
if not arg.startswith("-"):
package = arg
break
-
+ logger.error(f"[Installer] MCP包名: {package}")
if not package:
print("[Installer] 未找到包名") # noqa: T201
return None
-
+ # 创建文件夹
+ mcp_path = MCP_PATH / "template" / mcp_id / "project"
+ logger.error(f"[Installer] MCP安装路径: {mcp_path}")
+ await mcp_path.mkdir(parents=True, exist_ok=True)
# 如果有pyproject.toml文件,则使用sync
+ flag = await (mcp_path / "pyproject.toml").exists()
+ logger.error(f"[Installer] MCP安装标志: {flag}")
if await (mcp_path / "pyproject.toml").exists():
+ shell_command = f"{uv_path} venv; {uv_path} sync --index-url https://pypi.tuna.tsinghua.edu.cn/simple --active --no-install-project --no-cache"
+ logger.error(f"[Installer] MCP安装命令: {shell_command}")
pipe = await subprocess.create_subprocess_shell(
(
- "uv venv; "
- "uv sync --index-url https://pypi.tuna.tsinghua.edu.cn/simple --active "
+ f"{uv_path} venv; "
+ f"{uv_path} sync --index-url https://pypi.tuna.tsinghua.edu.cn/simple --active "
"--no-install-project --no-cache"
),
stdout=subprocess.PIPE,
@@ -57,19 +69,20 @@ async def install_uvx(mcp_id: str, config: "MCPServerStdioConfig") -> "MCPServer
return None
print(f"[Installer] 检查依赖成功: {mcp_path}; {stdout.decode() if stdout else '(无输出信息)'}") # noqa: T201
- config.command = "uv"
- config.args = ["run", *config.args]
+ config.command = uv_path
+ if "run" not in config.args:
+ config.args = ["run", *config.args]
config.auto_install = False
-
+ logger.error(f"[Installer] MCP安装配置更新成功: {config}")
return config
# 否则,初始化uv项目
pipe = await subprocess.create_subprocess_shell(
(
- f"uv init; "
- f"uv venv; "
- f"uv add --index-url https://pypi.tuna.tsinghua.edu.cn/simple {package}; "
- f"uv sync --index-url https://pypi.tuna.tsinghua.edu.cn/simple --active "
+ f"{uv_path} init; "
+ f"{uv_path} venv; "
+ f"{uv_path} add --index-url https://pypi.tuna.tsinghua.edu.cn/simple {package}; "
+ f"{uv_path} sync --index-url https://pypi.tuna.tsinghua.edu.cn/simple --active "
f"--no-install-project --no-cache"
),
stdout=subprocess.PIPE,
@@ -84,8 +97,9 @@ async def install_uvx(mcp_id: str, config: "MCPServerStdioConfig") -> "MCPServer
print(f"[Installer] 安装 {package} 成功: {mcp_path}; {stdout.decode() if stdout else '(无输出信息)'}") # noqa: T201
# 更新配置
- config.command = "uv"
- config.args = ["run", *config.args]
+ config.command = uv_path
+ if "run" not in config.args:
+ config.args = ["run", *config.args]
config.auto_install = False
return config
@@ -103,17 +117,13 @@ async def install_npx(mcp_id: str, config: "MCPServerStdioConfig") -> "MCPServer
:rtype: MCPServerStdioConfig
:raises ValueError: 未找到MCP Server对应的npm包
"""
- mcp_path = MCP_PATH / "template" / mcp_id / "project"
- await mcp_path.mkdir(parents=True, exist_ok=True)
-
- # 如果有node_modules文件夹,则认为已安装
- if await (mcp_path / "node_modules").exists():
- config.command = "npm"
- config.args = ["exec", *config.args]
- return config
-
+ npm_path = shutil.which('npm')
+ if npm_path is None:
+ error = "[Installer] 未找到npm命令,请先安装Node.js和npm"
+ logging.error(error)
+ raise Exception(error)
# 查找package name
- package = ""
+ package = None
for arg in config.args:
if not arg.startswith("-"):
package = arg
@@ -122,10 +132,18 @@ async def install_npx(mcp_id: str, config: "MCPServerStdioConfig") -> "MCPServer
if not package:
print("[Installer] 未找到包名") # noqa: T201
return None
+ mcp_path = MCP_PATH / "template" / mcp_id / "project"
+ await mcp_path.mkdir(parents=True, exist_ok=True)
+ # 如果有node_modules文件夹,则认为已安装
+ if await (mcp_path / "node_modules").exists():
+ config.command = npm_path
+ if "exec" not in config.args:
+ config.args = ["exec", *config.args]
+ return config
# 安装NPM包
pipe = await subprocess.create_subprocess_shell(
- f"npm install {package}",
+ f"{npm_path} install {package}",
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
cwd=mcp_path,
@@ -137,8 +155,9 @@ async def install_npx(mcp_id: str, config: "MCPServerStdioConfig") -> "MCPServer
print(f"[Installer] 安装 {package} 成功: {mcp_path}; {stdout.decode() if stdout else '(无输出信息)'}") # noqa: T201
# 更新配置
- config.command = "npm"
- config.args = ["exec", *config.args]
+ config.command = npm_path
+ if "exec" not in config.args:
+ config.args = ["exec", *config.args]
config.auto_install = False
return config
diff --git a/apps/scheduler/pool/mcp/pool.py b/apps/scheduler/pool/mcp/pool.py
index 91cde4d9d6c83fd5088328e12cfbb6bf06d3c7cb..bf0320f429a9ef45864ba6548b1bb28e3d874b59 100644
--- a/apps/scheduler/pool/mcp/pool.py
+++ b/apps/scheduler/pool/mcp/pool.py
@@ -21,16 +21,13 @@ class MCPPool(metaclass=SingletonMeta):
"""初始化MCP池"""
self.pool = {}
-
async def _init_mcp(self, mcp_id: str, user_sub: str) -> MCPClient | None:
"""初始化MCP池"""
- mcp_math = MCP_USER_PATH / user_sub / mcp_id / "project"
config_path = MCP_USER_PATH / user_sub / mcp_id / "config.json"
-
- if not await mcp_math.exists() or not await mcp_math.is_dir():
- logger.warning("[MCPPool] 用户 %s 的MCP %s 未激活", user_sub, mcp_id)
+ flag = (await config_path.exists())
+ if not flag:
+ logger.warning("[MCPPool] 用户 %s 的MCP %s 配置文件不存在", user_sub, mcp_id)
return None
-
config = MCPServerConfig.model_validate_json(await config_path.read_text())
if config.type in (MCPType.SSE, MCPType.STDIO):
@@ -40,9 +37,11 @@ class MCPPool(metaclass=SingletonMeta):
return None
await client.init(user_sub, mcp_id, config.config)
+ if user_sub not in self.pool:
+ self.pool[user_sub] = {}
+ self.pool[user_sub][mcp_id] = client
return client
-
async def _get_from_dict(self, mcp_id: str, user_sub: str) -> MCPClient | None:
"""从字典中获取MCP客户端"""
if user_sub not in self.pool:
@@ -53,7 +52,6 @@ class MCPPool(metaclass=SingletonMeta):
return self.pool[user_sub][mcp_id]
-
async def _validate_user(self, mcp_id: str, user_sub: str) -> bool:
"""验证用户是否已激活"""
mongo = MongoDB()
@@ -61,7 +59,6 @@ class MCPPool(metaclass=SingletonMeta):
mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": user_sub})
return mcp_db_result is not None
-
async def get(self, mcp_id: str, user_sub: str) -> MCPClient | None:
"""获取MCP客户端"""
item = await self._get_from_dict(mcp_id, user_sub)
@@ -83,7 +80,6 @@ class MCPPool(metaclass=SingletonMeta):
return item
-
async def stop(self, mcp_id: str, user_sub: str) -> None:
"""停止MCP客户端"""
await self.pool[user_sub][mcp_id].stop()
diff --git a/apps/scheduler/pool/pool.py b/apps/scheduler/pool/pool.py
index 7710d24dc102fe02c06d8cbc14f126bd088db2d4..ead552fc2994150137121b103ff013d8ea4d66a5 100644
--- a/apps/scheduler/pool/pool.py
+++ b/apps/scheduler/pool/pool.py
@@ -60,7 +60,6 @@ class Pool:
await Path(root_dir + "mcp").unlink(missing_ok=True)
await Path(root_dir + "mcp").mkdir(parents=True, exist_ok=True)
-
@staticmethod
async def init() -> None:
"""
@@ -121,13 +120,15 @@ class Pool:
for app in changed_app:
hash_key = Path("app/" + app).as_posix()
if hash_key in checker.hashes:
- await app_loader.load(app, checker.hashes[hash_key])
-
+ try:
+ await app_loader.load(app, checker.hashes[hash_key])
+ except Exception as e:
+ await app_loader.delete(app, is_reload=True)
+ logger.warning("[Pool] 加载App %s 失败: %s", app, e)
# 载入MCP
logger.info("[Pool] 载入MCP")
await MCPLoader.init()
-
async def get_flow_metadata(self, app_id: str) -> list[AppFlow]:
"""从数据库中获取特定App的全部Flow的元数据"""
mongo = MongoDB()
@@ -145,14 +146,12 @@ class Pool:
else:
return flow_metadata_list
-
async def get_flow(self, app_id: str, flow_id: str) -> Flow | None:
"""从文件系统中获取单个Flow的全部数据"""
logger.info("[Pool] 获取工作流 %s", flow_id)
flow_loader = FlowLoader()
return await flow_loader.load(app_id, flow_id)
-
async def get_call(self, call_id: str) -> Any:
"""[Exception] 拿到Call的信息"""
# 从MongoDB里拿到数据
diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py
index b7088d8da45ad01e94e0cf7208db25c5f56c3b22..6e737314822c7d9c5a9671c8670bd5a38db40e3d 100644
--- a/apps/scheduler/scheduler/context.py
+++ b/apps/scheduler/scheduler/context.py
@@ -10,6 +10,7 @@ from apps.llm.patterns.facts import Facts
from apps.schemas.collection import Document
from apps.schemas.enum_var import StepStatus
from apps.schemas.record import (
+ FlowHistory,
Record,
RecordContent,
RecordDocument,
@@ -114,11 +115,14 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None:
used_docs.append(
RecordGroupDocument(
_id=docs["id"],
+ author=docs.get("author", ""),
+ order=docs.get("order", 0),
name=docs["name"],
abstract=docs.get("abstract", ""),
extension=docs.get("extension", ""),
size=docs.get("size", 0),
associated="answer",
+ created_at=docs.get("created_at", round(datetime.now(UTC).timestamp(), 3)),
)
)
if docs.get("order") is not None:
@@ -154,7 +158,6 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None:
facts=task.runtime.facts,
data={},
)
-
try:
# 加密Record数据
encrypt_data, encrypt_config = Security.encrypt(record_content.model_dump_json(by_alias=True))
@@ -185,34 +188,37 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None:
feature={},
),
createdAt=current_time,
- flow=[i["_id"] for i in task.context],
+ flow=FlowHistory(
+ flow_id=task.state.flow_id,
+ flow_name=task.state.flow_name,
+ flow_status=task.state.flow_status,
+ history_ids=[context.id for context in task.context],
+ )
)
# 检查是否存在group_id
if not await RecordManager.check_group_id(task.ids.group_id, user_sub):
- record_group = await RecordManager.create_record_group(
- task.ids.group_id, user_sub, post_body.conversation_id, task.id,
+ record_group_id = await RecordManager.create_record_group(
+ task.ids.group_id, user_sub, post_body.conversation_id
)
- if not record_group:
+ if not record_group_id:
logger.error("[Scheduler] 创建问答组失败")
return
else:
- record_group = task.ids.group_id
+ record_group_id = task.ids.group_id
# 修改文件状态
- await DocumentManager.change_doc_status(user_sub, post_body.conversation_id, record_group)
+ await DocumentManager.change_doc_status(user_sub, post_body.conversation_id, record_group_id)
# 保存Record
- await RecordManager.insert_record_data_into_record_group(user_sub, record_group, record)
+ await RecordManager.insert_record_data_into_record_group(user_sub, record_group_id, record)
# 保存与答案关联的文件
- await DocumentManager.save_answer_doc(user_sub, record_group, used_docs)
+ await DocumentManager.save_answer_doc(user_sub, record_group_id, used_docs)
if post_body.app and post_body.app.app_id:
# 更新最近使用的应用
await AppCenterManager.update_recent_app(user_sub, post_body.app.app_id)
-
# 若状态为成功,删除Task
- if not task.state or task.state.status == StepStatus.SUCCESS:
+ if not task.state or task.state.flow_status == StepStatus.SUCCESS or task.state.flow_status == StepStatus.ERROR or task.state.flow_status == StepStatus.CANCELLED:
await TaskManager.delete_task_by_task_id(task.id)
else:
- # 更新Task
await TaskManager.save_task(task.id, task)
diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py
index c89fdd1014ba4714d0777742412e98c0802ca68c..d43ba7fb891d2b1ed7ecdb3378a8ad8e0e512286 100644
--- a/apps/scheduler/scheduler/message.py
+++ b/apps/scheduler/scheduler/message.py
@@ -15,6 +15,7 @@ from apps.schemas.message import (
InitContentFeature,
TextAddContent,
)
+from apps.schemas.enum_var import FlowStatus
from apps.schemas.rag_data import RAGEventData, RAGQueryReq
from apps.schemas.record import RecordDocument
from apps.schemas.task import Task
@@ -61,28 +62,25 @@ async def push_init_message(
async def push_rag_message(
task: Task, queue: MessageQueue, user_sub: str, llm: LLM, history: list[dict[str, str]],
doc_ids: list[str],
- rag_data: RAGQueryReq,) -> Task:
+ rag_data: RAGQueryReq,) -> None:
"""推送RAG消息"""
full_answer = ""
-
- async for chunk in RAG.chat_with_llm_base_on_rag(user_sub, llm, history, doc_ids, rag_data):
- task, content_obj = await _push_rag_chunk(task, queue, chunk)
- if content_obj.event_type == EventType.TEXT_ADD.value:
- # 如果是文本消息,直接拼接到答案中
- full_answer += content_obj.content
- elif content_obj.event_type == EventType.DOCUMENT_ADD.value:
- task.runtime.documents.append({
- "id": content_obj.content.get("id", ""),
- "order": content_obj.content.get("order", 0),
- "name": content_obj.content.get("name", ""),
- "abstract": content_obj.content.get("abstract", ""),
- "extension": content_obj.content.get("extension", ""),
- "size": content_obj.content.get("size", 0),
- })
+ try:
+ async for chunk in RAG.chat_with_llm_base_on_rag(user_sub, llm, history, doc_ids, rag_data):
+ task, content_obj = await _push_rag_chunk(task, queue, chunk)
+ if content_obj.event_type == EventType.TEXT_ADD.value:
+ # 如果是文本消息,直接拼接到答案中
+ full_answer += content_obj.content
+ elif content_obj.event_type == EventType.DOCUMENT_ADD.value:
+ task.runtime.documents.append(content_obj.content)
+ task.state.flow_status = FlowStatus.SUCCESS
+ except Exception as e:
+ logger.error(f"[Scheduler] RAG服务发生错误: {e}")
+ task.state.flow_status = FlowStatus.ERROR
# 保存答案
task.runtime.answer = full_answer
+ task.tokens.full_time = round(datetime.now(UTC).timestamp(), 2) - task.tokens.time
await TaskManager.save_task(task.id, task)
- return task
async def _push_rag_chunk(task: Task, queue: MessageQueue, content: str) -> tuple[Task, RAGEventData]:
@@ -115,10 +113,12 @@ async def _push_rag_chunk(task: Task, queue: MessageQueue, content: str) -> tupl
data=DocumentAddContent(
documentId=content_obj.content.get("id", ""),
documentOrder=content_obj.content.get("order", 0),
+ documentAuthor=content_obj.content.get("author", ""),
documentName=content_obj.content.get("name", ""),
documentAbstract=content_obj.content.get("abstract", ""),
documentType=content_obj.content.get("extension", ""),
documentSize=content_obj.content.get("size", 0),
+ createdAt=round(datetime.now(tz=UTC).timestamp(), 3),
).model_dump(exclude_none=True, by_alias=True),
)
except Exception:
diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py
index ed73638ced241909f55989597f8bb747b50f1945..b524d249a6f55caf779a473ef39db35a463525b9 100644
--- a/apps/scheduler/scheduler/scheduler.py
+++ b/apps/scheduler/scheduler/scheduler.py
@@ -1,9 +1,12 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""Scheduler模块"""
+import asyncio
import logging
from datetime import UTC, datetime
-
+from apps.llm.reasoning import ReasoningLLM
+from apps.schemas.config import LLMConfig
+from apps.llm.patterns.rewrite import QuestionRewrite
from apps.common.config import Config
from apps.common.mongo import MongoDB
from apps.common.queue import MessageQueue
@@ -17,11 +20,12 @@ from apps.scheduler.scheduler.message import (
push_rag_message,
)
from apps.schemas.collection import LLM
-from apps.schemas.enum_var import AppType, EventType
+from apps.schemas.enum_var import FlowStatus, AppType, EventType
from apps.schemas.pool import AppPool
from apps.schemas.rag_data import RAGQueryReq
from apps.schemas.request_data import RequestData
from apps.schemas.scheduler import ExecutorBackground
+from apps.services.activity import Activity
from apps.schemas.task import Task
from apps.services.appcenter import AppCenterManager
from apps.services.knowledge import KnowledgeBaseManager
@@ -41,12 +45,31 @@ class Scheduler:
"""初始化"""
self.used_docs = []
self.task = task
-
self.queue = queue
self.post_body = post_body
- async def run(self) -> None: # noqa: PLR0911
- """运行调度器"""
+ async def _monitor_activity(self, kill_event):
+ """监控用户活动状态,不活跃时终止工作流"""
+ try:
+ check_interval = 0.5 # 每0.5秒检查一次
+
+ while not kill_event.is_set():
+ # 检查用户活动状态
+ is_active = await Activity.is_active(self.task.ids.active_id)
+ if not is_active:
+ logger.warning("[Scheduler] 用户 %s 不活跃,终止工作流", self.task.ids.user_sub)
+ kill_event.set()
+ break
+
+ # 控制检查频率
+ await asyncio.sleep(check_interval)
+ except asyncio.CancelledError:
+ logger.info("[Scheduler] 活动监控任务已取消")
+ except Exception as e:
+ logger.error(f"[Scheduler] 活动监控过程中发生错误: {e}")
+
+ async def get_llm_use_in_chat_with_rag(self) -> LLM:
+ """获取RAG大模型"""
try:
# 获取当前会话使用的大模型
llm_id = await LLMManager.get_llm_id_by_conversation_id(
@@ -54,8 +77,7 @@ class Scheduler:
)
if not llm_id:
logger.error("[Scheduler] 获取大模型ID失败")
- await self.queue.close()
- return
+ return None
if llm_id == "empty":
llm = LLM(
_id="empty",
@@ -65,24 +87,31 @@ class Scheduler:
model_name=Config().get_config().llm.model,
max_tokens=Config().get_config().llm.max_tokens,
)
+ return llm
else:
llm = await LLMManager.get_llm_by_id(self.task.ids.user_sub, llm_id)
if not llm:
logger.error("[Scheduler] 获取大模型失败")
- await self.queue.close()
- return
+ return None
+ return llm
except Exception:
logger.exception("[Scheduler] 获取大模型失败")
- await self.queue.close()
- return
+ return None
+
+ async def get_kb_ids_use_in_chat_with_rag(self) -> list[str]:
+ """获取知识库ID列表"""
try:
- # 获取当前会话使用的知识库
kb_ids = await KnowledgeBaseManager.get_kb_ids_by_conversation_id(
- self.task.ids.user_sub, self.task.ids.conversation_id)
+ self.task.ids.user_sub, self.task.ids.conversation_id,
+ )
+ return kb_ids
except Exception:
logger.exception("[Scheduler] 获取知识库ID失败")
await self.queue.close()
- return
+ return []
+
+ async def run(self) -> None: # noqa: PLR0911
+ """运行调度器"""
try:
# 获取当前问答可供关联的文档
docs, doc_ids = await get_docs(self.task.ids.user_sub, self.post_body)
@@ -92,18 +121,30 @@ class Scheduler:
return
history, _ = await get_context(self.task.ids.user_sub, self.post_body, 3)
# 已使用文档
-
# 如果是智能问答,直接执行
logger.info("[Scheduler] 开始执行")
- if not self.post_body.app or self.post_body.app.app_id == "":
+ # 创建用于通信的事件
+ kill_event = asyncio.Event()
+ monitor = asyncio.create_task(self._monitor_activity(kill_event))
+ rag_method = True
+ if self.post_body.app and self.post_body.app.app_id:
+ rag_method = False
+ if self.task.state.app_id:
+ rag_method = False
+ if rag_method:
+ llm = await self.get_llm_use_in_chat_with_rag()
+ kb_ids = await self.get_kb_ids_use_in_chat_with_rag()
self.task = await push_init_message(self.task, self.queue, 3, is_flow=False)
rag_data = RAGQueryReq(
kbIds=kb_ids,
query=self.post_body.question,
tokensLimit=llm.max_tokens,
)
- self.task = await push_rag_message(self.task, self.queue, self.task.ids.user_sub, llm, history, doc_ids, rag_data)
- self.task.tokens.full_time = round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.time
+
+ # 启动监控任务和主任务
+ main_task = asyncio.create_task(push_rag_message(
+ self.task, self.queue, self.task.ids.user_sub, llm, history, doc_ids, rag_data))
+
else:
# 查找对应的App元数据
app_data = await AppCenterManager.fetch_app_data_by_id(self.post_body.app.app_id)
@@ -127,8 +168,27 @@ class Scheduler:
conversation=context,
facts=facts,
)
- await self.run_executor(self.queue, self.post_body, executor_background)
+ # 启动监控任务和主任务
+ main_task = asyncio.create_task(self.run_executor(self.queue, self.post_body, executor_background))
+ # 等待任一任务完成
+ done, pending = await asyncio.wait(
+ [main_task, monitor],
+ return_when=asyncio.FIRST_COMPLETED
+ )
+
+ # 如果是监控任务触发,终止主任务
+ if kill_event.is_set():
+ logger.warning("[Scheduler] 用户活动状态检测不活跃,正在终止工作流执行...")
+ main_task.cancel()
+ need_change_cancel_flow_state = [FlowStatus.RUNNING, FlowStatus.WAITING]
+ if self.task.state.flow_status in need_change_cancel_flow_state:
+ self.task.state.flow_status = FlowStatus.CANCELLED
+ try:
+ await main_task
+ logger.info("[Scheduler] 工作流执行已被终止")
+ except Exception as e:
+ logger.error(f"[Scheduler] 终止工作流时发生错误: {e}")
# 更新Task,发送结束消息
logger.info("[Scheduler] 发送结束消息")
await self.queue.push_output(self.task, event_type=EventType.DONE.value, data={})
@@ -152,6 +212,37 @@ class Scheduler:
if not app_metadata:
logger.error("[Scheduler] 未找到Agent应用")
return
+ if app_metadata.llm_id == "empty":
+ llm = LLM(
+ _id="empty",
+ user_sub=self.task.ids.user_sub,
+ openai_base_url=Config().get_config().llm.endpoint,
+ openai_api_key=Config().get_config().llm.key,
+ model_name=Config().get_config().llm.model,
+ max_tokens=Config().get_config().llm.max_tokens,
+ )
+ else:
+ llm = await LLMManager.get_llm_by_id(
+ self.task.ids.user_sub, app_metadata.llm_id,
+ )
+ if not llm:
+ logger.error("[Scheduler] 获取大模型失败")
+ await self.queue.close()
+ return
+ reasion_llm = ReasoningLLM(
+ LLMConfig(
+ endpoint=llm.openai_base_url,
+ key=llm.openai_api_key,
+ model=llm.model_name,
+ max_tokens=llm.max_tokens,
+ )
+ )
+ if background.conversation and self.task.state.flow_status == FlowStatus.INIT:
+ try:
+ question_obj = QuestionRewrite()
+ post_body.question = await question_obj.generate(history=background.conversation, question=post_body.question, llm=reasion_llm)
+ except Exception:
+ logger.exception("[Scheduler] 问题重写失败")
if app_metadata.app_type == AppType.FLOW.value:
logger.info("[Scheduler] 获取工作流元数据")
flow_info = await Pool().get_flow_metadata(app_info.app_id)
@@ -182,7 +273,6 @@ class Scheduler:
# 初始化Executor
logger.info("[Scheduler] 初始化Executor")
-
flow_exec = FlowExecutor(
flow_id=flow_id,
flow=flow_data,
@@ -206,10 +296,11 @@ class Scheduler:
task=self.task,
msg_queue=queue,
question=post_body.question,
- max_steps=app_metadata.history_len,
+ history_len=app_metadata.history_len,
servers_id=servers_id,
background=background,
agent_id=app_info.app_id,
+ params=post_body.params
)
# 开始运行
logger.info("[Scheduler] 运行Executor")
diff --git a/apps/scheduler/slot/slot.py b/apps/scheduler/slot/slot.py
index 4caab4d2dc945ec5ac14b7769770f94c39980a9f..7a6313a1c14c3122fab87905c07ce1b4bc6a0d84 100644
--- a/apps/scheduler/slot/slot.py
+++ b/apps/scheduler/slot/slot.py
@@ -1,6 +1,7 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""参数槽位管理"""
+import copy
import json
import logging
import traceback
@@ -12,6 +13,8 @@ from jsonschema.exceptions import ValidationError
from jsonschema.protocols import Validator
from jsonschema.validators import extend
+from apps.schemas.response_data import ParamsNode
+from apps.scheduler.call.choice.schema import Type
from apps.scheduler.slot.parser import (
SlotConstParser,
SlotDateParser,
@@ -126,7 +129,7 @@ class Slot:
# Schema标准
return [_process_json_value(item, spec_data["items"]) for item in json_value]
if spec_data["type"] == "object" and isinstance(json_value, dict):
- # 若Schema不标准,则不进行处理
+ # 若Schema不标准,则不进行处理F
if "properties" not in spec_data:
return json_value
# Schema标准
@@ -154,35 +157,60 @@ class Slot:
@staticmethod
def _generate_example(schema_node: dict) -> Any: # noqa: PLR0911
"""根据schema生成示例值"""
+ if "anyOf" in schema_node or "oneOf" in schema_node:
+ # 如果有anyOf,随机返回一个示例
+ for item in schema_node["anyOf"] if "anyOf" in schema_node else schema_node["oneOf"]:
+ example = Slot._generate_example(item)
+ if example is not None:
+ return example
+
+ if "allOf" in schema_node:
+ # 如果有allOf,返回所有示例的合并
+ example = None
+ for item in schema_node["allOf"]:
+ if example is None:
+ example = Slot._generate_example(item)
+ else:
+ other_example = Slot._generate_example(item)
+ if isinstance(example, dict) and isinstance(other_example, dict):
+ example.update(other_example)
+ else:
+ example = None
+ break
+ return example
+
if "default" in schema_node:
return schema_node["default"]
if "type" not in schema_node:
return None
-
+ type_value = schema_node["type"]
+ if isinstance(type_value, list):
+ # 如果是多类型,随机返回一个示例
+ if len(type_value) > 1:
+ type_value = type_value[0]
# 处理类型为 object 的节点
- if schema_node["type"] == "object":
+ if type_value == "object":
data = {}
properties = schema_node.get("properties", {})
for name, schema in properties.items():
data[name] = Slot._generate_example(schema)
return data
-
# 处理类型为 array 的节点
- if schema_node["type"] == "array":
+ elif type_value == "array":
items_schema = schema_node.get("items", {})
return [Slot._generate_example(items_schema)]
# 处理类型为 string 的节点
- if schema_node["type"] == "string":
+ elif type_value == "string":
return ""
# 处理类型为 number 或 integer 的节点
- if schema_node["type"] in ["number", "integer"]:
+ elif type_value in ["number", "integer"]:
return 0
# 处理类型为 boolean 的节点
- if schema_node["type"] == "boolean":
+ elif type_value == "boolean":
return False
# 处理其他类型或未定义类型
@@ -196,31 +224,114 @@ class Slot:
"""从JSON Schema中提取类型描述"""
def _extract_type_desc(schema_node: dict[str, Any]) -> dict[str, Any]:
- if "type" not in schema_node and "anyOf" not in schema_node:
- return {}
- data = {"type": schema_node.get("type", ""), "description": schema_node.get("description", "")}
- if "anyOf" in schema_node:
- data["type"] = "anyOf"
- # 处理类型为 object 的节点
- if "anyOf" in schema_node:
- data["items"] = {}
- type_index = 0
- for type_index, sub_schema in enumerate(schema_node["anyOf"]):
- sub_result = _extract_type_desc(sub_schema)
- if sub_result:
- data["items"]["type_"+str(type_index)] = sub_result
- if schema_node.get("type", "") == "object":
- data["items"] = {}
+ # 处理组合关键字
+ special_keys = ["anyOf", "allOf", "oneOf"]
+ for key in special_keys:
+ if key in schema_node:
+ data = {
+ "type": key,
+ "description": schema_node.get("description", ""),
+ "items": {},
+ }
+ type_index = 0
+ for item in schema_node[key]:
+ if isinstance(item, dict):
+ data["items"][f"item_{type_index}"] = _extract_type_desc(item)
+ else:
+ data["items"][f"item_{type_index}"] = {"type": item, "description": ""}
+ type_index += 1
+ return data
+ # 处理基本类型
+ type_val = schema_node.get("type", "")
+ description = schema_node.get("description", "")
+
+ # 处理多类型数组
+ if isinstance(type_val, list):
+ if len(type_val) > 1:
+ data = {"type": "union", "description": description, "items": {}}
+ type_index = 0
+ for t in type_val:
+ if t == "object":
+ tmp_dict = {}
+ for key, val in schema_node.get("properties", {}).items():
+ tmp_dict[key] = _extract_type_desc(val)
+ data["items"][f"item_{type_index}"] = tmp_dict
+ elif t == "array":
+ items_schema = schema_node.get("items", {})
+ data["items"][f"item_{type_index}"] = _extract_type_desc(items_schema)
+ else:
+ data["items"][f"item_{type_index}"] = {"type": t, "description": description}
+ type_index += 1
+ return data
+ elif len(type_val) == 1:
+ type_val = type_val[0]
+ else:
+ type_val = ""
+
+ data = {"type": type_val, "description": description, "items": {}}
+
+ # 递归处理对象和数组
+ if type_val == "object":
for key, val in schema_node.get("properties", {}).items():
data["items"][key] = _extract_type_desc(val)
-
- # 处理类型为 array 的节点
- if schema_node.get("type", "") == "array":
+ elif type_val == "array":
items_schema = schema_node.get("items", {})
- data["items"] = _extract_type_desc(items_schema)
+ if isinstance(items_schema, list):
+ item_index = 0
+ for item in items_schema:
+ data["items"][f"item_{item_index}"] = _extract_type_desc(item)
+ item_index += 1
+ else:
+ data["items"]["item"] = _extract_type_desc(items_schema)
+ if data["items"] == {}:
+ del data["items"]
return data
+
return _extract_type_desc(self._schema)
+ def get_params_node_from_schema(self, root: str = "") -> ParamsNode:
+ """从JSON Schema中提取ParamsNode"""
+ def _extract_params_node(schema_node: dict[str, Any], name: str = "", path: str = "") -> ParamsNode:
+ """递归提取ParamsNode"""
+ if "type" not in schema_node:
+ return None
+
+ param_type = schema_node["type"]
+ if isinstance(param_type, list):
+ return None # 不支持多类型
+ if param_type == "object":
+ param_type = Type.DICT
+ elif param_type == "array":
+ param_type = Type.LIST
+ elif param_type == "string":
+ param_type = Type.STRING
+ elif param_type in ["number", "integer"]:
+ param_type = Type.NUMBER
+ elif param_type == "boolean":
+ param_type = Type.BOOL
+ else:
+ logger.warning(f"[Slot] 不支持的参数类型: {param_type}")
+ return None
+ sub_params = []
+
+ if param_type == Type.DICT and "properties" in schema_node:
+ for key, value in schema_node["properties"].items():
+ sub_param = _extract_params_node(value, name=key, path=f"{path}/{key}")
+ if sub_param:
+ sub_params.append(sub_param)
+ else:
+ # 对于非对象类型,直接返回空子参数
+ sub_params = None
+ return ParamsNode(paramName=name,
+ paramPath=path,
+ paramType=param_type,
+ subParams=sub_params)
+ try:
+ return _extract_params_node(self._schema, name=root, path=root)
+ except Exception as e:
+ logger.error(f"[Slot] 提取ParamsNode失败: {e!s}\n{traceback.format_exc()}")
+ return None
+
def _flatten_schema(self, schema: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
"""将JSON Schema扁平化"""
result = {}
@@ -276,7 +387,6 @@ class Slot:
logger.exception("[Slot] 错误schema不合法: %s", error.schema)
return {}, []
-
def _assemble_patch(
self,
key: str,
@@ -329,7 +439,6 @@ class Slot:
logger.info("[Slot] 组装patch: %s", patch_list)
return patch_list
-
def convert_json(self, json_data: str | dict[str, Any]) -> dict[str, Any]:
"""将用户手动填充的参数专为真实JSON"""
json_dict = json.loads(json_data) if isinstance(json_data, str) else json_data
@@ -373,3 +482,54 @@ class Slot:
return schema_template
return {}
+
+ def add_null_to_basic_types(self) -> dict[str, Any]:
+ """
+ 递归地为 JSON Schema 中的基础类型(bool、number等)添加 null 选项
+ """
+ def add_null_to_basic_types(schema: dict[str, Any]) -> dict[str, Any]:
+ """
+ 递归地为 JSON Schema 中的基础类型(bool、number等)添加 null 选项
+
+ 参数:
+ schema (dict): 原始 JSON Schema
+
+ 返回:
+ dict: 修改后的 JSON Schema
+ """
+ # 如果不是字典类型(schema),直接返回
+ if not isinstance(schema, dict):
+ return schema
+
+ # 处理当前节点的 type 字段
+ if 'type' in schema:
+ # 处理单一类型字符串
+ if isinstance(schema['type'], str):
+ if schema['type'] in ['boolean', 'number', 'string', 'integer']:
+ schema['type'] = [schema['type'], 'null']
+
+ # 处理类型数组
+ elif isinstance(schema['type'], list):
+ for i, t in enumerate(schema['type']):
+ if isinstance(t, str) and t in ['boolean', 'number', 'string', 'integer']:
+ if 'null' not in schema['type']:
+ schema['type'].append('null')
+ break
+
+ # 递归处理 properties 字段(对象类型)
+ if 'properties' in schema:
+ for prop, prop_schema in schema['properties'].items():
+ schema['properties'][prop] = add_null_to_basic_types(prop_schema)
+
+ # 递归处理 items 字段(数组类型)
+ if 'items' in schema:
+ schema['items'] = add_null_to_basic_types(schema['items'])
+
+ # 递归处理 anyOf, oneOf, allOf 字段
+ for keyword in ['anyOf', 'oneOf', 'allOf']:
+ if keyword in schema:
+ schema[keyword] = [add_null_to_basic_types(sub_schema) for sub_schema in schema[keyword]]
+
+ return schema
+ schema_copy = copy.deepcopy(self._schema)
+ return add_null_to_basic_types(schema_copy)
diff --git a/apps/schemas/agent.py b/apps/schemas/agent.py
index b52f5e1c3315fb873acddb2bcc4e27937498d561..16e818e4d588d4f25ea4854a7979e2dc5fe90ecb 100644
--- a/apps/schemas/agent.py
+++ b/apps/schemas/agent.py
@@ -17,6 +17,7 @@ class AgentAppMetadata(MetadataBase):
app_type: AppType = Field(default=AppType.AGENT, description="应用类型", frozen=True)
published: bool = Field(description="是否发布", default=False)
history_len: int = Field(description="对话轮次", default=3, le=10)
- mcp_service: list[str] = Field(default=[], alias="mcpService", description="MCP服务id列表")
+ mcp_service: list[str] = Field(default=[], description="MCP服务id列表")
+ llm_id: str = Field(default="empty", description="大模型ID")
permission: Permission | None = Field(description="应用权限配置", default=None)
version: str = Field(description="元数据版本")
diff --git a/apps/schemas/appcenter.py b/apps/schemas/appcenter.py
index a89f39df18083d90b988cc764c0e88f90500d1f3..e3bb896eba2361e0c28ee914b84f387d7da76b87 100644
--- a/apps/schemas/appcenter.py
+++ b/apps/schemas/appcenter.py
@@ -45,9 +45,9 @@ class AppFlowInfo(BaseModel):
"""应用工作流数据结构"""
id: str = Field(..., description="工作流ID")
- name: str = Field(..., description="工作流名称")
- description: str = Field(..., description="工作流简介")
- debug: bool = Field(..., description="是否经过调试")
+ name: str = Field(default="", description="工作流名称")
+ description: str = Field(default="", description="工作流简介")
+ debug: bool = Field(default=False, description="是否经过调试")
class AppData(BaseModel):
@@ -61,6 +61,7 @@ class AppData(BaseModel):
first_questions: list[str] = Field(
default=[], alias="recommendedQuestions", description="推荐问题", max_length=3)
history_len: int = Field(3, alias="dialogRounds", ge=1, le=10, description="对话轮次(1~10)")
+ llm: str = Field(default="empty", description="大模型ID")
permission: AppPermissionData = Field(
default_factory=lambda: AppPermissionData(authorizedUsers=None), description="权限配置")
workflows: list[AppFlowInfo] = Field(default=[], description="工作流信息列表")
diff --git a/apps/schemas/collection.py b/apps/schemas/collection.py
index 0ff66c72bbe30b7cbb55517952c8b14744d69cf2..a2991f851a85c8d60afb6d60e76cbb766c9da94e 100644
--- a/apps/schemas/collection.py
+++ b/apps/schemas/collection.py
@@ -61,6 +61,7 @@ class User(BaseModel):
fav_apps: list[str] = []
fav_services: list[str] = []
is_admin: bool = Field(default=False, description="是否为管理员")
+ auto_execute: bool = Field(default=True, description="是否自动执行任务")
class LLM(BaseModel):
diff --git a/apps/schemas/config.py b/apps/schemas/config.py
index b88a81f1afb018663713a3bf4c0b9b62436e59d4..675a9ba7f26531206a5ea5f8a7598d4344ae2124 100644
--- a/apps/schemas/config.py
+++ b/apps/schemas/config.py
@@ -6,6 +6,12 @@ from typing import Literal
from pydantic import BaseModel, Field
+class NoauthConfig(BaseModel):
+ """无认证配置"""
+
+ enable: bool = Field(description="是否启用无认证访问", default=False)
+
+
class DeployConfig(BaseModel):
"""部署配置"""
@@ -122,7 +128,7 @@ class ExtraConfig(BaseModel):
class ConfigModel(BaseModel):
"""配置文件的校验Class"""
-
+ no_auth: NoauthConfig = Field(description="无认证配置", default=NoauthConfig())
deploy: DeployConfig
login: LoginConfig
embedding: EmbeddingConfig
diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py
index 9a20ba84d4805bcd502d9d4ca0f9ba3c49a7bb4c..3ae7c42528d27c3e0d17970ba5e959db8e1bed44 100644
--- a/apps/schemas/enum_var.py
+++ b/apps/schemas/enum_var.py
@@ -14,11 +14,26 @@ class SlotType(str, Enum):
class StepStatus(str, Enum):
"""步骤状态"""
-
+ UNKNOWN = "unknown"
+ INIT = "init"
+ WAITING = "waiting"
RUNNING = "running"
SUCCESS = "success"
ERROR = "error"
PARAM = "param"
+ CANCELLED = "cancelled"
+
+
+class FlowStatus(str, Enum):
+ """Flow状态"""
+
+ UNKNOWN = "unknown"
+ INIT = "init"
+ WAITING = "waiting"
+ RUNNING = "running"
+ SUCCESS = "success"
+ ERROR = "error"
+ CANCELLED = "cancelled"
class DocumentStatus(str, Enum):
@@ -34,14 +49,22 @@ class EventType(str, Enum):
"""事件类型"""
HEARTBEAT = "heartbeat"
- INIT = "init",
+ INIT = "init"
TEXT_ADD = "text.add"
GRAPH = "graph"
DOCUMENT_ADD = "document.add"
+ STEP_WAITING_FOR_START = "step.waiting_for_start"
+ STEP_WAITING_FOR_PARAM = "step.waiting_for_param"
FLOW_START = "flow.start"
+ STEP_INIT = "step.init"
STEP_INPUT = "step.input"
STEP_OUTPUT = "step.output"
+ STEP_CANCEL = "step.cancel"
+ STEP_ERROR = "step.error"
FLOW_STOP = "flow.stop"
+ FLOW_FAILED = "flow.failed"
+ FLOW_SUCCESS = "flow.success"
+ FLOW_CANCEL = "flow.cancel"
DONE = "done"
@@ -81,7 +104,7 @@ class NodeType(str, Enum):
START = "start"
END = "end"
NORMAL = "normal"
- CHOICE = "choice"
+ CHOICE = "Choice"
class SaveType(str, Enum):
@@ -144,7 +167,7 @@ class SpecialCallType(str, Enum):
LLM = "LLM"
START = "start"
END = "end"
- CHOICE = "choice"
+ CHOICE = "Choice"
class CommentType(str, Enum):
diff --git a/apps/schemas/flow.py b/apps/schemas/flow.py
index 2646d04390099fd92988d78c00d1b1471780d6a9..dfffd1f14c779952d9bb761d2e2bccd85528bc0f 100644
--- a/apps/schemas/flow.py
+++ b/apps/schemas/flow.py
@@ -136,6 +136,7 @@ class AppMetadata(MetadataBase):
published: bool = Field(description="是否发布", default=False)
links: list[AppLink] = Field(description="相关链接", default=[])
first_questions: list[str] = Field(description="首次提问", default=[])
+ llm_id: str = Field(description="大模型ID", default="empty")
history_len: int = Field(description="对话轮次", default=3, le=10)
permission: Permission | None = Field(description="应用权限配置", default=None)
flows: list[AppFlow] = Field(description="Flow列表", default=[])
diff --git a/apps/schemas/flow_topology.py b/apps/schemas/flow_topology.py
index d0ab666a9902ec7e0d81bb922f973d655f29eaf2..fecf21606634982ad6480069a7adc6714a612db2 100644
--- a/apps/schemas/flow_topology.py
+++ b/apps/schemas/flow_topology.py
@@ -5,6 +5,7 @@ from typing import Any
from pydantic import BaseModel, Field
+from apps.schemas.enum_var import SpecialCallType
from apps.schemas.enum_var import EdgeType
@@ -51,7 +52,7 @@ class NodeItem(BaseModel):
service_id: str = Field(alias="serviceId", default="")
node_id: str = Field(alias="nodeId", default="")
name: str = Field(default="")
- call_id: str = Field(alias="callId", default="Empty")
+ call_id: str = Field(alias="callId", default=SpecialCallType.EMPTY.value)
description: str = Field(default="")
enable: bool = Field(default=True)
parameters: dict[str, Any] = Field(default={})
@@ -81,6 +82,6 @@ class FlowItem(BaseModel):
nodes: list[NodeItem] = Field(default=[])
edges: list[EdgeItem] = Field(default=[])
created_at: float | None = Field(alias="createdAt", default=0)
- connectivity: bool = Field(default=False,description="图的开始节点和结束节点是否联通,并且除结束节点都有出边")
+ connectivity: bool = Field(default=False, description="图的开始节点和结束节点是否联通,并且除结束节点都有出边")
focus_point: PositionItem = Field(alias="focusPoint", default=PositionItem())
debug: bool = Field(default=False)
diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py
index 44021b0ed5f5f6c372ac0a472bb3ac28ff5bbddb..d27a5b469877f24639d1566f62f8b216636d3fd7 100644
--- a/apps/schemas/mcp.py
+++ b/apps/schemas/mcp.py
@@ -1,6 +1,7 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""MCP 相关数据结构"""
+import uuid
from enum import Enum
from typing import Any
@@ -10,8 +11,9 @@ from pydantic import BaseModel, Field
class MCPInstallStatus(str, Enum):
"""MCP 服务状态"""
-
+ INIT = "init"
INSTALLING = "installing"
+ CANCELLED = "cancelled"
READY = "ready"
FAILED = "failed"
@@ -22,6 +24,7 @@ class MCPStatus(str, Enum):
UNINITIALIZED = "uninitialized"
RUNNING = "running"
STOPPED = "stopped"
+ ERROR = "error"
class MCPType(str, Enum):
@@ -38,7 +41,9 @@ class MCPBasicConfig(BaseModel):
env: dict[str, str] = Field(description="MCP 服务器环境变量", default={})
auto_approve: list[str] = Field(description="自动批准的MCP工具ID列表", default=[], alias="autoApprove")
disabled: bool = Field(description="MCP 服务器是否禁用", default=False)
- auto_install: bool = Field(description="是否自动安装MCP服务器", default=True, alias="autoInstall")
+ auto_install: bool = Field(description="是否自动安装MCP服务器", default=True)
+ timeout: int = Field(description="MCP 服务器超时时间(秒)", default=60, alias="timeout")
+ description: str = Field(description="MCP 服务器自然语言描述", default="")
class MCPServerStdioConfig(MCPBasicConfig):
@@ -84,7 +89,7 @@ class MCPCollection(BaseModel):
type: MCPType = Field(description="MCP 类型", default=MCPType.SSE)
activated: list[str] = Field(description="激活该MCP的用户ID列表", default=[])
tools: list[MCPTool] = Field(description="MCP工具列表", default=[])
- status: MCPInstallStatus = Field(description="MCP服务状态", default=MCPInstallStatus.INSTALLING)
+ status: MCPInstallStatus = Field(description="MCP服务状态", default=MCPInstallStatus.INIT)
author: str = Field(description="MCP作者", default="")
@@ -103,6 +108,60 @@ class MCPToolVector(LanceModel):
embedding: Vector(dim=1024) = Field(description="MCP工具描述的向量信息") # type: ignore[call-arg]
+class GoalEvaluationResult(BaseModel):
+ """MCP 目标评估结果"""
+
+ can_complete: bool = Field(description="是否可以完成目标")
+ reason: str = Field(description="评估原因")
+
+
+class RestartStepIndex(BaseModel):
+ """MCP重新规划的步骤索引"""
+
+ start_index: int = Field(description="重新规划的起始步骤索引")
+ reasoning: str = Field(description="重新规划的原因")
+
+
+class Risk(str, Enum):
+ """MCP工具风险类型"""
+
+ LOW = "low"
+ MEDIUM = "medium"
+ HIGH = "high"
+
+
+class ToolSkip(BaseModel):
+ """MCP工具跳过执行结果"""
+
+ skip: bool = Field(description="是否跳过当前步骤", default=False)
+
+
+class ToolRisk(BaseModel):
+ """MCP工具风险评估结果"""
+
+ risk: Risk = Field(description="风险类型", default=Risk.LOW)
+ reason: str = Field(description="风险原因", default="")
+
+
+class ErrorType(str, Enum):
+ """MCP工具错误类型"""
+
+ MISSING_PARAM = "missing_param"
+ DECORRECT_PLAN = "decorrect_plan"
+
+
+class ToolExcutionErrorType(BaseModel):
+ """MCP工具执行错误"""
+
+ type: ErrorType = Field(description="错误类型", default=ErrorType.MISSING_PARAM)
+ reason: str = Field(description="错误原因", default="")
+
+
+class IsParamError(BaseModel):
+ """MCP工具参数错误"""
+
+ is_param_error: bool = Field(description="是否是参数错误", default=False)
+
class MCPSelectResult(BaseModel):
"""MCP选择结果"""
@@ -115,9 +174,15 @@ class MCPToolSelectResult(BaseModel):
name: str = Field(description="工具名称")
+class MCPToolIdsSelectResult(BaseModel):
+ """MCP工具ID选择结果"""
+
+ tool_ids: list[str] = Field(description="工具ID列表")
+
+
class MCPPlanItem(BaseModel):
"""MCP 计划"""
-
+ step_id: str = Field(description="步骤的ID", default="")
content: str = Field(description="计划内容")
tool: str = Field(description="工具名称")
instruction: str = Field(description="工具指令")
@@ -126,4 +191,11 @@ class MCPPlanItem(BaseModel):
class MCPPlan(BaseModel):
"""MCP 计划"""
- plans: list[MCPPlanItem] = Field(description="计划列表")
+ plans: list[MCPPlanItem] = Field(description="计划列表", default=[])
+
+
+class Step(BaseModel):
+ """MCP步骤"""
+
+ tool_id: str = Field(description="工具ID")
+ description: str = Field(description="步骤描述")
diff --git a/apps/schemas/message.py b/apps/schemas/message.py
index d0661224e0817ff6f25e341d65c88903020d18e8..5b465ee54321bbd4c649753911025bff41840186 100644
--- a/apps/schemas/message.py
+++ b/apps/schemas/message.py
@@ -2,13 +2,19 @@
"""队列中的消息结构"""
from typing import Any
-
+from datetime import UTC, datetime
from pydantic import BaseModel, Field
-from apps.schemas.enum_var import EventType, StepStatus
+from apps.schemas.enum_var import EventType, FlowStatus, StepStatus
from apps.schemas.record import RecordMetadata
+class param(BaseModel):
+ """流执行过程中的参数补充"""
+ content: dict[str, Any] = Field(default={}, description="流执行过程中的参数补充内容")
+ description: str = Field(default="", description="流执行过程中的参数补充描述")
+
+
class HeartbeatData(BaseModel):
"""心跳事件的数据结构"""
@@ -22,8 +28,17 @@ class MessageFlow(BaseModel):
app_id: str = Field(description="插件ID", alias="appId")
flow_id: str = Field(description="Flow ID", alias="flowId")
+ flow_name: str = Field(description="Flow名称", alias="flowName")
+ flow_status: FlowStatus = Field(description="Flow状态", alias="flowStatus", default=FlowStatus.UNKNOWN)
step_id: str = Field(description="当前步骤ID", alias="stepId")
step_name: str = Field(description="当前步骤名称", alias="stepName")
+ sub_step_id: str | None = Field(description="当前子步骤ID", alias="subStepId", default=None)
+ sub_step_name: str | None = Field(description="当前子步骤名称", alias="subStepName", default=None)
+ step_description: str | None = Field(
+ description="当前步骤描述",
+ alias="stepDescription",
+ default=None,
+ )
step_status: StepStatus = Field(description="当前步骤状态", alias="stepStatus")
@@ -60,17 +75,21 @@ class DocumentAddContent(BaseModel):
document_id: str = Field(description="文档UUID", alias="documentId")
document_order: int = Field(description="文档在对话中的顺序,从1开始", alias="documentOrder")
+ document_author: str = Field(description="文档作者", alias="documentAuthor", default="")
document_name: str = Field(description="文档名称", alias="documentName")
document_abstract: str = Field(description="文档摘要", alias="documentAbstract", default="")
document_type: str = Field(description="文档MIME类型", alias="documentType", default="")
document_size: float = Field(ge=0, description="文档大小,单位是KB,保留两位小数", alias="documentSize", default=0)
+ created_at: float = Field(
+ description="文档创建时间,单位是秒", alias="createdAt", default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)
+ )
class FlowStartContent(BaseModel):
"""flow.start消息的content"""
question: str = Field(description="用户问题")
- params: dict[str, Any] = Field(description="预先提供的参数")
+ params: dict[str, Any] | None = Field(description="预先提供的参数", default=None)
class MessageBase(HeartbeatData):
@@ -81,5 +100,5 @@ class MessageBase(HeartbeatData):
conversation_id: str = Field(min_length=36, max_length=36, alias="conversationId")
task_id: str = Field(min_length=36, max_length=36, alias="taskId")
flow: MessageFlow | None = None
- content: dict[str, Any] = {}
+ content: Any | None = Field(default=None, description="消息内容")
metadata: MessageMetadata
diff --git a/apps/schemas/parameters.py b/apps/schemas/parameters.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd908d2375415c798f4804c4eabde797f6a4f7a0
--- /dev/null
+++ b/apps/schemas/parameters.py
@@ -0,0 +1,69 @@
+from enum import Enum
+
+
+class NumberOperate(str, Enum):
+ """Choice 工具支持的数字运算符"""
+
+ EQUAL = "number_equal"
+ NOT_EQUAL = "number_not_equal"
+ GREATER_THAN = "number_greater_than"
+ LESS_THAN = "number_less_than"
+ GREATER_THAN_OR_EQUAL = "number_greater_than_or_equal"
+ LESS_THAN_OR_EQUAL = "number_less_than_or_equal"
+
+
+class StringOperate(str, Enum):
+ """Choice 工具支持的字符串运算符"""
+
+ EQUAL = "string_equal"
+ NOT_EQUAL = "string_not_equal"
+ CONTAINS = "string_contains"
+ NOT_CONTAINS = "string_not_contains"
+ STARTS_WITH = "string_starts_with"
+ ENDS_WITH = "string_ends_with"
+ LENGTH_EQUAL = "string_length_equal"
+ LENGTH_GREATER_THAN = "string_length_greater_than"
+ LENGTH_GREATER_THAN_OR_EQUAL = "string_length_greater_than_or_equal"
+ LENGTH_LESS_THAN = "string_length_less_than"
+ LENGTH_LESS_THAN_OR_EQUAL = "string_length_less_than_or_equal"
+ REGEX_MATCH = "string_regex_match"
+
+
+class ListOperate(str, Enum):
+ """Choice 工具支持的列表运算符"""
+
+ EQUAL = "list_equal"
+ NOT_EQUAL = "list_not_equal"
+ CONTAINS = "list_contains"
+ NOT_CONTAINS = "list_not_contains"
+ LENGTH_EQUAL = "list_length_equal"
+ LENGTH_GREATER_THAN = "list_length_greater_than"
+ LENGTH_GREATER_THAN_OR_EQUAL = "list_length_greater_than_or_equal"
+ LENGTH_LESS_THAN = "list_length_less_than"
+ LENGTH_LESS_THAN_OR_EQUAL = "list_length_less_than_or_equal"
+
+
+class BoolOperate(str, Enum):
+ """Choice 工具支持的布尔运算符"""
+
+ EQUAL = "bool_equal"
+ NOT_EQUAL = "bool_not_equal"
+
+
+class DictOperate(str, Enum):
+ """Choice 工具支持的字典运算符"""
+
+ EQUAL = "dict_equal"
+ NOT_EQUAL = "dict_not_equal"
+ CONTAINS_KEY = "dict_contains_key"
+ NOT_CONTAINS_KEY = "dict_not_contains_key"
+
+
+class Type(str, Enum):
+ """Choice 工具支持的类型"""
+
+ STRING = "string"
+ NUMBER = "number"
+ LIST = "list"
+ DICT = "dict"
+ BOOL = "bool"
diff --git a/apps/schemas/pool.py b/apps/schemas/pool.py
index 27e16b370ec83acc11e1f435ac1da296fe2a9560..3532b6e2ffd31739e25e53414245b5d54fd395d0 100644
--- a/apps/schemas/pool.py
+++ b/apps/schemas/pool.py
@@ -109,4 +109,7 @@ class AppPool(BaseData):
permission: Permission = Field(description="应用权限配置", default=Permission())
flows: list[AppFlow] = Field(description="Flow列表", default=[])
hashes: dict[str, str] = Field(description="关联文件的hash值", default={})
- mcp_service: list[str] = Field(default=[], alias="mcpService", description="MCP服务id列表")
+ mcp_service: list[str] = Field(default=[], description="MCP服务id列表")
+ llm_id: str = Field(
+ default="empty", description="应用使用的大模型ID(如果有的话)"
+ )
diff --git a/apps/schemas/record.py b/apps/schemas/record.py
index d7acd368205d0aa5a2b368edf4af3a31ed4d1ff6..b6b70dd6fbd5b89aafb06af5146b8feb1fb84862 100644
--- a/apps/schemas/record.py
+++ b/apps/schemas/record.py
@@ -10,15 +10,17 @@ from pydantic import BaseModel, Field
from apps.schemas.collection import (
Document,
)
-from apps.schemas.enum_var import CommentType, StepStatus
+from apps.schemas.enum_var import CommentType, FlowStatus, StepStatus
class RecordDocument(Document):
"""GET /api/record/{conversation_id} Result中的document数据结构"""
id: str = Field(alias="_id", default="")
+ order: int = Field(default=0, description="文档顺序")
abstract: str = Field(default="", description="文档摘要")
user_sub: None = None
+ author: str = Field(default="", description="文档作者")
associated: Literal["question", "answer"]
class Config:
@@ -34,6 +36,7 @@ class RecordFlowStep(BaseModel):
step_status: StepStatus = Field(alias="stepStatus")
input: dict[str, Any]
output: dict[str, Any]
+ ex_data: dict[str, Any] | None = Field(default=None, alias="exData")
class RecordFlow(BaseModel):
@@ -42,6 +45,8 @@ class RecordFlow(BaseModel):
id: str
record_id: str = Field(alias="recordId")
flow_id: str = Field(alias="flowId")
+ flow_name: str = Field(alias="flowName", default="")
+ flow_status: StepStatus = Field(alias="flowStatus", default=StepStatus.SUCCESS)
step_num: int = Field(alias="stepNum")
steps: list[RecordFlowStep]
@@ -90,7 +95,7 @@ class RecordData(BaseModel):
id: str
group_id: str = Field(alias="groupId")
conversation_id: str = Field(alias="conversationId")
- task_id: str = Field(alias="taskId")
+ task_id: str | None = Field(default=None, alias="taskId")
document: list[RecordDocument] = []
flow: RecordFlow | None = None
content: RecordContent
@@ -103,11 +108,22 @@ class RecordGroupDocument(BaseModel):
"""RecordGroup关联的文件"""
id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id")
+ order: int = Field(default=0, description="文档顺序")
+ author: str = Field(default="", description="文档作者")
name: str = Field(default="", description="文档名称")
abstract: str = Field(default="", description="文档摘要")
extension: str = Field(default="", description="文档扩展名")
size: int = Field(default=0, description="文档大小,单位是KB")
associated: Literal["question", "answer"]
+ created_at: float = Field(default=0.0, description="文档创建时间")
+
+
+class FlowHistory(BaseModel):
+ """Flow执行历史"""
+ flow_id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id")
+ flow_name: str = Field(default="", description="Flow名称")
+ flow_staus: FlowStatus = Field(default=FlowStatus.SUCCESS, description="Flow执行状态")
+ history_ids: list[str] = Field(default=[], description="Flow执行历史ID列表")
class Record(RecordData):
@@ -115,9 +131,11 @@ class Record(RecordData):
user_sub: str
key: dict[str, Any] = {}
- content: str
+ task_id: str | None = Field(default=None, description="任务ID")
+ content: str = Field(default="", description="Record内容,已加密")
comment: RecordComment = Field(default=RecordComment())
- flow: list[str] = Field(default=[])
+ flow: FlowHistory = Field(
+ default=FlowHistory(), description="Flow执行历史信息")
class RecordGroup(BaseModel):
@@ -134,5 +152,4 @@ class RecordGroup(BaseModel):
records: list[Record] = []
docs: list[RecordGroupDocument] = [] # 问题不变,所用到的文档不变
conversation_id: str
- task_id: str
created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3))
diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py
index 2305dd93dfe33969691dcc42928e578e7f1c4207..3fd5a67fa29af6e1685d55285ee1ebd458cabce2 100644
--- a/apps/schemas/request_data.py
+++ b/apps/schemas/request_data.py
@@ -10,14 +10,14 @@ from apps.schemas.appcenter import AppData
from apps.schemas.enum_var import CommentType
from apps.schemas.flow_topology import FlowItem
from apps.schemas.mcp import MCPType
+from apps.schemas.message import param
class RequestDataApp(BaseModel):
"""模型对话中包含的app信息"""
app_id: str = Field(description="应用ID", alias="appId")
- flow_id: str = Field(description="Flow ID", alias="flowId")
- params: dict[str, Any] = Field(description="插件参数")
+ flow_id: str | None = Field(default=None, description="Flow ID", alias="flowId")
class MockRequestData(BaseModel):
@@ -39,13 +39,15 @@ class RequestDataFeatures(BaseModel):
class RequestData(BaseModel):
"""POST /api/chat 请求的总的数据结构"""
- question: str = Field(max_length=2000, description="用户输入")
- conversation_id: str = Field(default="", alias="conversationId", description="聊天ID")
+ question: str | None = Field(default=None, max_length=2000, description="用户输入")
+ conversation_id: str | None = Field(default=None, alias="conversationId", description="聊天ID")
group_id: str | None = Field(default=None, alias="groupId", description="问答组ID")
language: str = Field(default="zh", description="语言")
files: list[str] = Field(default=[], description="文件列表")
app: RequestDataApp | None = Field(default=None, description="应用")
debug: bool = Field(default=False, description="是否调试")
+ task_id: str | None = Field(default=None, alias="taskId", description="任务ID")
+ params: param | bool | None = Field(default=None, description="流执行过程中的参数补充", alias="params")
class QuestionBlacklistRequest(BaseModel):
@@ -98,7 +100,7 @@ class UpdateMCPServiceRequest(BaseModel):
name: str = Field(..., description="MCP服务名称")
description: str = Field(..., description="MCP服务描述")
overview: str = Field(..., description="MCP服务概述")
- config: str = Field(..., description="MCP服务配置")
+ config: dict[str, Any] = Field(..., description="MCP服务配置")
mcp_type: MCPType = Field(description="MCP传输协议(Stdio/SSE/Streamable)", default=MCPType.STDIO, alias="mcpType")
@@ -106,6 +108,7 @@ class ActiveMCPServiceRequest(BaseModel):
"""POST /api/mcp/{serviceId} 请求数据结构"""
active: bool = Field(description="是否激活mcp服务")
+ mcp_env: dict[str, Any] | None = Field(default=None, description="MCP服务环境变量", alias="mcpEnv")
class UpdateServiceRequest(BaseModel):
@@ -183,3 +186,9 @@ class UpdateKbReq(BaseModel):
"""更新知识库请求体"""
kb_ids: list[str] = Field(description="知识库ID列表", alias="kbIds", default=[])
+
+
+class UserUpdateRequest(BaseModel):
+ """更新用户信息请求体"""
+
+ auto_execute: bool = Field(default=False, description="是否自动执行", alias="autoExecute")
diff --git a/apps/schemas/response_data.py b/apps/schemas/response_data.py
index b2a1872918638b153e1e94ff4db85ef00cfc0502..542dac88bb134ca79aa7934d8781c1dbd8a93cc3 100644
--- a/apps/schemas/response_data.py
+++ b/apps/schemas/response_data.py
@@ -14,10 +14,19 @@ from apps.schemas.flow_topology import (
NodeServiceItem,
PositionItem,
)
+from apps.schemas.parameters import (
+ Type,
+ NumberOperate,
+ StringOperate,
+ ListOperate,
+ BoolOperate,
+ DictOperate,
+)
from apps.schemas.mcp import MCPInstallStatus, MCPTool, MCPType
from apps.schemas.record import RecordData
from apps.schemas.user import UserInfo
from apps.templates.generate_llm_operator_config import llm_provider_dict
+from apps.common.config import Config
class ResponseData(BaseModel):
@@ -47,6 +56,7 @@ class AuthUserMsg(BaseModel):
user_sub: str
revision: bool
is_admin: bool
+ auto_execute: bool
class AuthUserRsp(ResponseData):
@@ -90,7 +100,7 @@ class LLMIteam(BaseModel):
icon: str = Field(default=llm_provider_dict["ollama"]["icon"])
llm_id: str = Field(alias="llmId", default="empty")
- model_name: str = Field(alias="modelName", default="Ollama LLM")
+ model_name: str = Field(alias="modelName", default=Config().get_config().llm.model)
class KbIteam(BaseModel):
@@ -272,11 +282,21 @@ class BaseAppOperationRsp(ResponseData):
result: BaseAppOperationMsg
+class AppMcpServiceInfo(BaseModel):
+ """应用关联的MCP服务信息"""
+
+ id: str = Field(..., description="MCP服务ID")
+ name: str = Field(default="", description="MCP服务名称")
+ description: str = Field(default="", description="MCP服务简介")
+
+
class GetAppPropertyMsg(AppData):
"""GET /api/app/{appId} Result数据结构"""
app_id: str = Field(..., alias="appId", description="应用ID")
published: bool = Field(..., description="是否已发布")
+ mcp_service: list[AppMcpServiceInfo] = Field(default=[], alias="mcpService", description="MCP服务信息列表")
+ llm: LLMIteam | None = Field(alias="llm", default=None)
class GetAppPropertyRsp(ResponseData):
@@ -573,7 +593,7 @@ class FlowStructureDeleteRsp(ResponseData):
class UserGetMsp(BaseModel):
"""GET /api/user result"""
-
+ total: int = Field(default=0)
user_info_list: list[UserInfo] = Field(alias="userInfoList", default=[])
@@ -628,3 +648,42 @@ class ListLLMRsp(ResponseData):
"""GET /api/llm 返回数据结构"""
result: list[LLMProviderInfo] = Field(default=[], title="Result")
+
+
+class ParamsNode(BaseModel):
+ """参数数据结构"""
+ param_name: str = Field(..., description="参数名称", alias="paramName")
+ param_path: str = Field(..., description="参数路径", alias="paramPath")
+ param_type: Type = Field(..., description="参数类型", alias="paramType")
+ sub_params: list["ParamsNode"] | None = Field(
+ default=None, description="子参数列表", alias="subParams"
+ )
+
+
+class StepParams(BaseModel):
+ """参数数据结构"""
+ step_id: str = Field(..., description="步骤ID", alias="stepId")
+ name: str = Field(..., description="Step名称")
+ params_node: ParamsNode | None = Field(
+ default=None, description="参数节点", alias="paramsNode")
+
+
+class GetParamsRsp(ResponseData):
+ """GET /api/params 返回数据结构"""
+
+ result: list[StepParams] = Field(
+ default=[], description="参数列表", alias="result"
+ )
+
+
+class OperateAndBindType(BaseModel):
+ """操作和绑定类型数据结构"""
+
+ operate: NumberOperate | StringOperate | ListOperate | BoolOperate | DictOperate = Field(description="操作类型")
+ bind_type: Type = Field(description="绑定类型")
+
+
+class GetOperaRsp(ResponseData):
+ """GET /api/operate 返回数据结构"""
+
+ result: list[OperateAndBindType] = Field(..., title="Result")
diff --git a/apps/schemas/task.py b/apps/schemas/task.py
index 8efcb59914d568478c65d611670b251d019b38ed..d3e7b036c35f8ecf53a204b8b8a28e4600738184 100644
--- a/apps/schemas/task.py
+++ b/apps/schemas/task.py
@@ -7,8 +7,9 @@ from typing import Any
from pydantic import BaseModel, Field
-from apps.schemas.enum_var import StepStatus
+from apps.schemas.enum_var import FlowStatus, StepStatus
from apps.schemas.flow import Step
+from apps.schemas.mcp import MCPPlan
class FlowStepHistory(BaseModel):
@@ -22,12 +23,14 @@ class FlowStepHistory(BaseModel):
task_id: str = Field(description="任务ID")
flow_id: str = Field(description="FlowID")
flow_name: str = Field(description="Flow名称")
+ flow_status: FlowStatus = Field(description="Flow状态")
step_id: str = Field(description="当前步骤名称")
step_name: str = Field(description="当前步骤名称")
- step_description: str = Field(description="当前步骤描述")
- status: StepStatus = Field(description="当前步骤状态")
+ step_description: str = Field(description="当前步骤描述", default="")
+ step_status: StepStatus = Field(description="当前步骤状态")
input_data: dict[str, Any] = Field(description="当前Step执行的输入", default={})
output_data: dict[str, Any] = Field(description="当前Step执行后的结果", default={})
+ ex_data: dict[str, Any] | None = Field(description="额外数据", default=None)
created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3))
@@ -35,16 +38,21 @@ class ExecutorState(BaseModel):
"""FlowExecutor状态"""
# 执行器级数据
- flow_id: str = Field(description="Flow ID")
- flow_name: str = Field(description="Flow名称")
- description: str = Field(description="Flow描述")
- status: StepStatus = Field(description="Flow执行状态")
- # 附加信息
- step_id: str = Field(description="当前步骤ID")
- step_name: str = Field(description="当前步骤名称")
- app_id: str = Field(description="应用ID")
- slot: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={})
- error_info: dict[str, Any] = Field(description="错误信息", default={})
+ flow_id: str = Field(description="Flow ID", default="")
+ flow_name: str = Field(description="Flow名称", default="")
+ description: str = Field(description="Flow描述", default="")
+ flow_status: FlowStatus = Field(description="Flow状态", default=FlowStatus.INIT)
+ # 任务级数据
+ step_cnt: int = Field(description="当前步骤数量", default=0)
+ step_id: str = Field(description="当前步骤ID", default="")
+ tool_id: str = Field(description="当前工具ID", default="")
+ step_name: str = Field(description="当前步骤名称", default="")
+ step_status: StepStatus = Field(description="当前步骤状态", default=StepStatus.UNKNOWN)
+ step_description: str = Field(description="当前步骤描述", default="")
+ app_id: str = Field(description="应用ID", default="")
+ current_input: dict[str, Any] = Field(description="当前输入数据", default={})
+ error_message: str = Field(description="错误信息", default="")
+ retry_times: int = Field(description="当前步骤重试次数", default=0)
class TaskIds(BaseModel):
@@ -55,6 +63,7 @@ class TaskIds(BaseModel):
conversation_id: str = Field(description="对话ID")
record_id: str = Field(description="记录ID", default_factory=lambda: str(uuid.uuid4()))
user_sub: str = Field(description="用户ID")
+ active_id: str = Field(description="活动ID", default_factory=lambda: str(uuid.uuid4()))
class TaskTokens(BaseModel):
@@ -86,10 +95,11 @@ class Task(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id")
ids: TaskIds = Field(description="任务涉及的各种ID")
- context: list[dict[str, Any]] = Field(description="Flow的步骤执行信息", default=[])
- state: ExecutorState | None = Field(description="Flow的状态", default=None)
+ context: list[FlowStepHistory] = Field(description="Flow的步骤执行信息", default=[])
+ state: ExecutorState = Field(description="Flow的状态", default=ExecutorState())
tokens: TaskTokens = Field(description="Token信息")
runtime: TaskRuntime = Field(description="任务运行时数据")
+ language: str = Field(description="语言", default="zh")
created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3))
diff --git a/apps/schemas/user.py b/apps/schemas/user.py
index 61aa2587b8ae6255dc3885b04a96f12310f70773..ebea66924c062f4e680ad2e414768b307b056f4a 100644
--- a/apps/schemas/user.py
+++ b/apps/schemas/user.py
@@ -9,3 +9,4 @@ class UserInfo(BaseModel):
user_sub: str = Field(alias="userSub", default="")
user_name: str = Field(alias="userName", default="")
+ auto_execute: bool | None = Field(alias="autoExecute", default=False)
diff --git a/apps/services/activity.py b/apps/services/activity.py
index 299a49a640664b29a3e4724c5842f12bd289e3bf..7ba20cbb4300a5fa5cb60e37cdb887f38d54232f 100644
--- a/apps/services/activity.py
+++ b/apps/services/activity.py
@@ -3,70 +3,65 @@
import uuid
from datetime import UTC, datetime
-
+import logging
from apps.common.mongo import MongoDB
from apps.constants import SLIDE_WINDOW_QUESTION_COUNT, SLIDE_WINDOW_TIME
from apps.exceptions import ActivityError
+logger = logging.getLogger(__name__)
+
class Activity:
- """用户活动控制,限制单用户同一时间只能提问一个问题"""
+ """用户活动控制,限制单用户同一时间只能有SLIDE_WINDOW_QUESTION_COUNT个请求"""
@staticmethod
- async def is_active(user_sub: str) -> bool:
+ async def is_active(active_id: str) -> bool:
"""
判断当前用户是否正在提问(占用GPU资源)
:param user_sub: 用户实体ID
:return: 判断结果,正在提问则返回True
"""
- time = round(datetime.now(UTC).timestamp(), 3)
-
- # 检查窗口内总请求数
- count = await MongoDB().get_collection("activity").count_documents(
- {"timestamp": {"$gte": time - SLIDE_WINDOW_TIME, "$lte": time}},
- )
- if count >= SLIDE_WINDOW_QUESTION_COUNT:
- return True
# 检查用户是否正在提问
active = await MongoDB().get_collection("activity").find_one(
- {"user_sub": user_sub},
+ {"_id": active_id},
)
return bool(active)
@staticmethod
- async def set_active(user_sub: str) -> None:
+ async def set_active(user_sub: str) -> str:
"""设置用户的活跃标识"""
time = round(datetime.now(UTC).timestamp(), 3)
# 设置用户活跃状态
collection = MongoDB().get_collection("activity")
- active = await collection.find_one({"user_sub": user_sub})
- if active:
- err = "用户正在提问"
+ # 查看用户活跃标识是否在滑动窗口内
+
+ if await collection.count_documents({"user_sub": user_sub, "timestamp": {"$gt": time - SLIDE_WINDOW_TIME}}) >= SLIDE_WINDOW_QUESTION_COUNT:
+ err = "[Activity] 用户在滑动窗口内提问次数超过限制,请稍后再试。"
raise ActivityError(err)
+ await collection.delete_many(
+ {"user_sub": user_sub, "timestamp": {"$lte": time - SLIDE_WINDOW_TIME}},
+ )
+ # 插入新的活跃记录
+ tmp_record = {
+ "_id": str(uuid.uuid4()),
+ "user_sub": user_sub,
+ "timestamp": time,
+ }
await collection.insert_one(
- {
- "_id": str(uuid.uuid4()),
- "user_sub": user_sub,
- "timestamp": time,
- },
+ tmp_record
)
+ return tmp_record["_id"]
@staticmethod
- async def remove_active(user_sub: str) -> None:
+ async def remove_active(active_id: str) -> None:
"""
清除用户的活跃标识,释放GPU资源
:param user_sub: 用户实体ID
"""
- time = round(datetime.now(UTC).timestamp(), 3)
# 清除用户当前活动标识
await MongoDB().get_collection("activity").delete_one(
- {"user_sub": user_sub},
- )
-
- # 清除超出窗口范围的请求记录
- await MongoDB().get_collection("activity").delete_many(
- {"timestamp": {"$lte": time - SLIDE_WINDOW_TIME}},
+ {"_id": active_id},
)
diff --git a/apps/services/appcenter.py b/apps/services/appcenter.py
index e256ab55ab13bf8c5b9e40cb36a3c616a1d39596..4e0d00d06254973d284b149faaa632d969700981 100644
--- a/apps/services/appcenter.py
+++ b/apps/services/appcenter.py
@@ -59,7 +59,6 @@ class AppCenterManager:
}
user_favorite_app_ids = await AppCenterManager._get_favorite_app_ids_by_user(user_sub)
-
if filter_type == AppFilterType.ALL:
# 获取所有已发布的应用
filters["published"] = True
@@ -72,7 +71,6 @@ class AppCenterManager:
"_id": {"$in": user_favorite_app_ids},
"published": True,
}
-
# 添加关键字搜索条件
if keyword:
filters["$or"] = [
@@ -84,8 +82,8 @@ class AppCenterManager:
# 添加应用类型过滤条件
if app_type is not None:
filters["app_type"] = app_type.value
-
# 获取应用列表
+ logger.error(f"[AppCenterManager] 搜索条件: {filters}, 页码: {page}, 每页大小: {SERVICE_PAGE_SIZE}")
apps, total_apps = await AppCenterManager._search_apps_by_filter(filters, page, SERVICE_PAGE_SIZE)
# 构建返回的应用卡片列表
@@ -484,7 +482,19 @@ class AppCenterManager:
else:
# 在预期的条件下,如果在 data 或 app_data 中找不到 mcp_service,则默认回退为空列表。
metadata.mcp_service = []
-
+ # 处理llm_id字段
+ if data is not None and hasattr(data, "llm") and data.llm:
+ # 创建应用场景,验证传入的 llm_id 状态 (create_app)
+ metadata.llm_id = data.llm if data.llm else "empty"
+ elif data is not None and hasattr(data, "llm_id"):
+ # 更新应用场景,使用 data 中的 llm_id (update_app)
+ metadata.llm_id = data.llm if data.llm else "empty"
+ elif app_data is not None and hasattr(app_data, "llm_id"):
+ # 更新应用发布状态场景,使用 app_data 中的 llm_id (update_app_publish_status)
+ metadata.llm_id = app_data.llm_id if app_data.llm_id else "empty"
+ else:
+ # 在预期的条件下,如果在 data 或 app_data 中找不到 llm_id,则默认回退为 "empty"。
+ metadata.llm_id = "empty"
# Agent 应用的发布状态逻辑
if published is not None: # 从 update_app_publish_status 调用,'published' 参数已提供
metadata.published = published
diff --git a/apps/services/conversation.py b/apps/services/conversation.py
index 4bcade45c757ea2b3dca6c58863b2852af98c15b..6bacb72720dbe43ff6835d4714c720f2778d85b1 100644
--- a/apps/services/conversation.py
+++ b/apps/services/conversation.py
@@ -59,7 +59,11 @@ class ConversationManager:
model_name=llm.model_name,
)
kb_item_list = []
- team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None)
+ try:
+ team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None)
+ except:
+ logger.error("[ConversationManager] 获取团队知识库列表失败")
+ team_kb_list = []
for team_kb in team_kb_list:
for kb in team_kb["kbList"]:
if str(kb["kbId"]) in kb_ids:
@@ -136,4 +140,4 @@ class ConversationManager:
await record_group_collection.delete_many({"conversation_id": conversation_id}, session=session)
await session.commit_transaction()
- await TaskManager.delete_tasks_by_conversation_id(conversation_id)
+ await TaskManager.delete_tasks_and_flow_context_by_conversation_id(conversation_id)
diff --git a/apps/services/document.py b/apps/services/document.py
index 203162da1137cdd69d900ceea65a3029d2f4fd8e..451423a9d8f752c2abbf97f5b2a1df38e6d7fefe 100644
--- a/apps/services/document.py
+++ b/apps/services/document.py
@@ -2,6 +2,7 @@
"""文件Manager"""
import base64
+from datetime import UTC, datetime
import logging
import uuid
@@ -131,12 +132,15 @@ class DocumentManager:
return [
RecordDocument(
_id=doc.id,
+ order=doc.order,
+ author=doc.author,
abstract=doc.abstract,
name=doc.name,
type=doc.extension,
size=doc.size,
conversation_id=record_group.get("conversation_id", ""),
associated=doc.associated,
+ created_at=doc.created_at or round(datetime.now(tz=UTC).timestamp(), 3)
)
for doc in docs if type is None or doc.associated == type
]
diff --git a/apps/services/flow.py b/apps/services/flow.py
index 9275fc60c75199e8e17ebcba8f164083190b7211..4d682e5a85677b3a526ec9d6d2db521c238357b4 100644
--- a/apps/services/flow.py
+++ b/apps/services/flow.py
@@ -20,7 +20,6 @@ from apps.schemas.flow_topology import (
PositionItem,
)
from apps.services.node import NodeManager
-
logger = logging.getLogger(__name__)
@@ -258,10 +257,7 @@ class FlowManager:
)
for node_id, node_config in flow_config.steps.items():
input_parameters = node_config.params
- if node_config.node not in ("Empty"):
- _, output_parameters = await NodeManager.get_node_params(node_config.node)
- else:
- output_parameters = {}
+ _, output_parameters = await NodeManager.get_node_params(node_config.node)
parameters = {
"input_parameters": input_parameters,
"output_parameters": Slot(output_parameters).extract_type_desc_from_schema(),
@@ -414,6 +410,7 @@ class FlowManager:
flow_config.debug = await FlowManager.is_flow_config_equal(old_flow_config, flow_config)
else:
flow_config.debug = False
+ logger.error(f'{flow_config}')
await flow_loader.save(app_id, flow_id, flow_config)
except Exception:
logger.exception("[FlowManager] 存储/更新流失败")
diff --git a/apps/services/flow_validate.py b/apps/services/flow_validate.py
index 78e8d340e955115a51e24ac8fc9081fd8ff00e7e..16c23053705c158c4d4f7e0faa438fe3927c0d63 100644
--- a/apps/services/flow_validate.py
+++ b/apps/services/flow_validate.py
@@ -4,6 +4,7 @@
import collections
import logging
+from apps.schemas.enum_var import SpecialCallType
from apps.exceptions import FlowBranchValidationError, FlowEdgeValidationError, FlowNodeValidationError
from apps.schemas.enum_var import NodeType
from apps.schemas.flow_topology import EdgeItem, FlowItem, NodeItem
@@ -38,32 +39,40 @@ class FlowService:
for node in flow_item.nodes:
from apps.scheduler.pool.pool import Pool
from pydantic import BaseModel
- if node.node_id != 'start' and node.node_id != 'end' and node.node_id != 'Empty':
+ if node.node_id != 'start' and node.node_id != 'end' and node.node_id != SpecialCallType.EMPTY.value:
try:
call_class: type[BaseModel] = await Pool().get_call(node.call_id)
if not call_class:
- node.node_id = 'Empty'
+ node.node_id = SpecialCallType.EMPTY.value
node.description = '【对应的api工具被删除!节点不可用!请联系相关人员!】\n\n'+node.description
except Exception as e:
- node.node_id = 'Empty'
+ node.node_id = SpecialCallType.EMPTY.value
node.description = '【对应的api工具被删除!节点不可用!请联系相关人员!】\n\n'+node.description
logger.error(f"[FlowService] 获取步骤的call_id失败{node.call_id}由于:{e}")
node_branch_map[node.step_id] = set()
if node.call_id == NodeType.CHOICE.value:
- node.parameters = node.parameters["input_parameters"]
- if "choices" not in node.parameters:
- node.parameters["choices"] = []
- for choice in node.parameters["choices"]:
- if choice["branchId"] in node_branch_map[node.step_id]:
- err = f"[FlowService] 节点{node.name}的分支{choice['branchId']}重复"
+ input_parameters = node.parameters["input_parameters"]
+ if "choices" not in input_parameters:
+ logger.error(f"[FlowService] 节点{node.name}的分支字段缺失")
+ raise FlowBranchValidationError(f"[FlowService] 节点{node.name}的分支字段缺失")
+ if not input_parameters["choices"]:
+ logger.error(f"[FlowService] 节点{node.name}的分支字段为空")
+ raise FlowBranchValidationError(f"[FlowService] 节点{node.name}的分支字段为空")
+ for choice in input_parameters["choices"]:
+ if "branch_id" not in choice:
+ err = f"[FlowService] 节点{node.name}的分支choice缺少branch_id字段"
+ logger.error(err)
+ raise FlowBranchValidationError(err)
+ if choice["branch_id"] in node_branch_map[node.step_id]:
+ err = f"[FlowService] 节点{node.name}的分支{choice['branch_id']}重复"
logger.error(err)
raise Exception(err)
for illegal_char in branch_illegal_chars:
- if illegal_char in choice["branchId"]:
- err = f"[FlowService] 节点{node.name}的分支{choice['branchId']}名称中含有非法字符"
+ if illegal_char in choice["branch_id"]:
+ err = f"[FlowService] 节点{node.name}的分支{choice['branch_id']}名称中含有非法字符"
logger.error(err)
raise Exception(err)
- node_branch_map[node.step_id].add(choice["branchId"])
+ node_branch_map[node.step_id].add(choice["branch_id"])
else:
node_branch_map[node.step_id].add("")
valid_edges = []
@@ -133,7 +142,6 @@ class FlowService:
branches = {}
in_deg = {}
out_deg = {}
-
for e in edges:
if e.edge_id in ids:
err = f"[FlowService] 边{e.edge_id}的id重复"
diff --git a/apps/services/knowledge.py b/apps/services/knowledge.py
index 9b4077f92cf44bca42d321cc20975d7ea2e5cadd..bd8dfc9e808f642a8eecf493a96f762dad8ae7b9 100644
--- a/apps/services/knowledge.py
+++ b/apps/services/knowledge.py
@@ -138,7 +138,11 @@ class KnowledgeBaseManager:
return []
kb_ids_update_success = []
kb_item_dict_list = []
- team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None)
+ try:
+ team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None)
+ except Exception as e:
+ logger.error(f"[KnowledgeBaseManager] 获取团队知识库列表失败: {e}")
+ team_kb_list = []
for team_kb in team_kb_list:
for kb in team_kb["kbList"]:
if str(kb["kbId"]) in kb_ids:
diff --git a/apps/services/mcp_service.py b/apps/services/mcp_service.py
index 7cb880c0f08b1cc7727bde528be39ac748f07ae2..3c7ab3c5b42615a6f4005d5c032dd81481933511 100644
--- a/apps/services/mcp_service.py
+++ b/apps/services/mcp_service.py
@@ -2,6 +2,7 @@
"""MCP服务管理器"""
import logging
+from logging import config
import random
import re
from typing import Any
@@ -28,8 +29,10 @@ from apps.schemas.mcp import (
MCPTool,
MCPType,
)
+from apps.services.user import UserManager
from apps.schemas.request_data import UpdateMCPServiceRequest
from apps.schemas.response_data import MCPServiceCardItem
+from apps.constants import MCP_PATH
logger = logging.getLogger(__name__)
sqids = Sqids(min_length=6)
@@ -66,10 +69,8 @@ class MCPServiceManager:
mcp_list = await mcp_collection.find({"_id": mcp_id}, {"status": True}).to_list(None)
for db_item in mcp_list:
status = db_item.get("status")
- if MCPInstallStatus.READY.value == status:
- return MCPInstallStatus.READY
- if MCPInstallStatus.INSTALLING.value == status:
- return MCPInstallStatus.INSTALLING
+ if status in MCPInstallStatus.__members__.values():
+ return status
return MCPInstallStatus.FAILED
@staticmethod
@@ -78,6 +79,8 @@ class MCPServiceManager:
user_sub: str,
keyword: str | None,
page: int,
+ is_install: bool | None = None,
+ is_active: bool | None = None,
) -> list[MCPServiceCardItem]:
"""
获取所有MCP服务列表
@@ -89,6 +92,17 @@ class MCPServiceManager:
:return: MCP服务列表
"""
filters = MCPServiceManager._build_filters(search_type, keyword)
+ if is_active is not None:
+ if is_active:
+ filters["activated"] = {"$in": [user_sub]}
+ else:
+ filters["activated"] = {"$nin": [user_sub]}
+ if not is_install:
+ user_info = await UserManager.get_userinfo_by_user_sub(user_sub)
+ if not user_info.is_admin:
+ filters["status"] = MCPInstallStatus.READY.value
+ else:
+ filters["status"] = MCPInstallStatus.READY.value
mcpservice_pools = await MCPServiceManager._search_mcpservice(filters, page)
return [
MCPServiceCardItem(
@@ -198,7 +212,6 @@ class MCPServiceManager:
base_filters = {"author": {"$regex": keyword, "$options": "i"}}
return base_filters
-
@staticmethod
async def create_mcpservice(data: UpdateMCPServiceRequest, user_sub: str) -> str:
"""
@@ -209,9 +222,9 @@ class MCPServiceManager:
"""
# 检查config
if data.mcp_type == MCPType.SSE:
- config = MCPServerSSEConfig.model_validate_json(data.config)
+ config = MCPServerSSEConfig.model_validate(data.config)
else:
- config = MCPServerStdioConfig.model_validate_json(data.config)
+ config = MCPServerStdioConfig.model_validate(data.config)
# 构造Server
mcp_server = MCPServerConfig(
@@ -233,8 +246,24 @@ class MCPServiceManager:
# 保存并载入配置
logger.info("[MCPServiceManager] 创建mcp:%s", mcp_server.name)
+ mcp_path = MCP_PATH / "template" / mcp_id / "project"
+ if isinstance(config, MCPServerStdioConfig):
+ index = None
+ for i in range(len(config.args)):
+ if not config.args[i] == "--directory":
+ continue
+ index = i + 1
+ break
+ if index is not None:
+ if index >= len(config.args):
+ config.args.append(str(mcp_path))
+ else:
+ config.args[index] = str(mcp_path)
+ else:
+ config.args += ["--directory", str(mcp_path)]
+ await MCPLoader._insert_template_db(mcp_id=mcp_id, config=mcp_server)
await MCPLoader.save_one(mcp_id, mcp_server)
- await MCPLoader.init_one_template(mcp_id=mcp_id, config=mcp_server)
+ await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.INIT)
return mcp_id
@staticmethod
@@ -258,19 +287,21 @@ class MCPServiceManager:
db_service = MCPCollection.model_validate(db_service)
for user_id in db_service.activated:
await MCPServiceManager.deactive_mcpservice(user_sub=user_id, service_id=data.service_id)
-
- await MCPLoader.init_one_template(mcp_id=data.service_id, config=MCPServerConfig(
+ mcp_config = MCPServerConfig(
name=data.name,
overview=data.overview,
description=data.description,
- config=MCPServerStdioConfig.model_validate_json(
+ config=MCPServerStdioConfig.model_validate(
data.config,
- ) if data.mcp_type == MCPType.STDIO else MCPServerSSEConfig.model_validate_json(
+ ) if data.mcp_type == MCPType.STDIO else MCPServerSSEConfig.model_validate(
data.config,
),
type=data.mcp_type,
author=user_sub,
- ))
+ )
+ await MCPLoader._insert_template_db(mcp_id=data.service_id, config=mcp_config)
+ await MCPLoader.save_one(mcp_id=data.service_id, config=mcp_config)
+ await MCPLoader.update_template_status(data.service_id, MCPInstallStatus.INIT)
# 返回服务ID
return data.service_id
@@ -297,6 +328,7 @@ class MCPServiceManager:
async def active_mcpservice(
user_sub: str,
service_id: str,
+ mcp_env: dict[str, Any] | None = None,
) -> None:
"""
激活MCP服务
@@ -310,7 +342,7 @@ class MCPServiceManager:
for item in status:
mcp_status = item.get("status", MCPInstallStatus.INSTALLING)
if mcp_status == MCPInstallStatus.READY:
- await MCPLoader.user_active_template(user_sub, service_id)
+ await MCPLoader.user_active_template(user_sub, service_id, mcp_env)
else:
err = "[MCPServiceManager] MCP服务未准备就绪"
raise RuntimeError(err)
@@ -371,3 +403,28 @@ class MCPServiceManager:
image.save(MCP_ICON_PATH / f"{service_id}.png", format="PNG", optimize=True, compress_level=9)
return f"/static/mcp/{service_id}.png"
+
+ @staticmethod
+ async def install_mcpservice(user_sub: str, service_id: str, install: bool) -> None:
+ """
+ 安装或卸载MCP服务
+
+ :param user_sub: str: 用户ID
+ :param service_id: str: MCP服务ID
+ :param install: bool: 是否安装
+ :return: 无
+ """
+ service_collection = MongoDB().get_collection("mcp")
+ db_service = await service_collection.find_one({"_id": service_id, "author": user_sub})
+ db_service = MCPCollection.model_validate(db_service)
+ if install:
+ if db_service.status == MCPInstallStatus.INSTALLING or db_service.status == MCPInstallStatus.READY:
+ err = "[MCPServiceManager] MCP服务已处于安装中或已准备就绪"
+ raise Exception(err)
+ mcp_config = await MCPLoader.get_config(service_id)
+ await MCPLoader.init_one_template(mcp_id=service_id, config=mcp_config)
+ else:
+ if db_service.status != MCPInstallStatus.INSTALLING:
+ err = "[MCPServiceManager] 只能卸载处于安装中的MCP服务"
+ raise Exception(err)
+ await MCPLoader.cancel_installing_task([service_id])
diff --git a/apps/services/node.py b/apps/services/node.py
index 6f0d492edacdd7611324a38953417e0fa769249b..3fb311bf67b14de543567bf559540ced7aa8b11e 100644
--- a/apps/services/node.py
+++ b/apps/services/node.py
@@ -4,6 +4,7 @@
import logging
from typing import TYPE_CHECKING, Any
+from apps.schemas.enum_var import SpecialCallType
from apps.common.mongo import MongoDB
from apps.schemas.node import APINode
from apps.schemas.pool import NodePool
@@ -16,6 +17,7 @@ NODE_TYPE_MAP = {
"API": APINode,
}
+
class NodeManager:
"""Node管理器"""
@@ -29,7 +31,6 @@ class NodeManager:
raise ValueError(err)
return node["call_id"]
-
@staticmethod
async def get_node(node_id: str) -> NodePool:
"""获取Node的类型"""
@@ -40,7 +41,6 @@ class NodeManager:
raise ValueError(err)
return NodePool.model_validate(node)
-
@staticmethod
async def get_node_name(node_id: str) -> str:
"""获取node的名称"""
@@ -52,7 +52,6 @@ class NodeManager:
return ""
return node_doc["name"]
-
@staticmethod
def merge_params_schema(params_schema: dict[str, Any], known_params: dict[str, Any]) -> dict[str, Any]:
"""递归合并参数Schema,将known_params中的值填充到params_schema的对应位置"""
@@ -75,12 +74,13 @@ class NodeManager:
return params_schema
-
@staticmethod
async def get_node_params(node_id: str) -> tuple[dict[str, Any], dict[str, Any]]:
"""获取Node数据"""
from apps.scheduler.pool.pool import Pool
-
+ if node_id == SpecialCallType.EMPTY.value:
+ # 如果是空节点,返回空Schema
+ return {}, {}
# 查找Node信息
logger.info("[NodeManager] 获取节点 %s", node_id)
node_collection = MongoDB().get_collection("node")
@@ -100,7 +100,6 @@ class NodeManager:
err = f"[NodeManager] Call {call_id} 不存在"
logger.error(err)
raise ValueError(err)
-
# 返回参数Schema
return (
NodeManager.merge_params_schema(call_class.model_json_schema(), node_data.known_params or {}),
diff --git a/apps/services/parameter.py b/apps/services/parameter.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b85c62e341ebaedb3af025ca59d869e82d9a46c
--- /dev/null
+++ b/apps/services/parameter.py
@@ -0,0 +1,89 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""flow Manager"""
+
+import logging
+
+from pymongo import ASCENDING
+
+from apps.services.node import NodeManager
+from apps.schemas.flow_topology import FlowItem
+from apps.scheduler.slot.slot import Slot
+from apps.scheduler.call.choice.condition_handler import ConditionHandler
+from apps.scheduler.call.choice.schema import (
+ NumberOperate,
+ StringOperate,
+ ListOperate,
+ BoolOperate,
+ DictOperate,
+ Type
+)
+from apps.schemas.response_data import (
+ OperateAndBindType,
+ ParamsNode,
+ StepParams,
+)
+from apps.services.node import NodeManager
+logger = logging.getLogger(__name__)
+
+
+class ParameterManager:
+ """Parameter Manager"""
+ @staticmethod
+ async def get_operate_and_bind_type(param_type: Type) -> list[OperateAndBindType]:
+ """Get operate and bind type"""
+ result = []
+ operate = None
+ if param_type == Type.NUMBER:
+ operate = NumberOperate
+ elif param_type == Type.STRING:
+ operate = StringOperate
+ elif param_type == Type.LIST:
+ operate = ListOperate
+ elif param_type == Type.BOOL:
+ operate = BoolOperate
+ elif param_type == Type.DICT:
+ operate = DictOperate
+ if operate:
+ for item in operate:
+ result.append(OperateAndBindType(
+ operate=item,
+ bind_type=(await ConditionHandler.get_value_type_from_operate(item))))
+ return result
+
+ @staticmethod
+ async def get_pre_params_by_flow_and_step_id(flow: FlowItem, step_id: str) -> list[StepParams]:
+ """Get pre params by flow and step id"""
+ index = 0
+ q = [step_id]
+ in_edges = {}
+ step_id_to_node_id = {}
+ step_id_to_node_name = {}
+ for step in flow.nodes:
+ step_id_to_node_id[step.step_id] = step.node_id
+ step_id_to_node_name[step.step_id] = step.name
+ for edge in flow.edges:
+ if edge.target_node not in in_edges:
+ in_edges[edge.target_node] = []
+ in_edges[edge.target_node].append(edge.source_node)
+ while index < len(q):
+ tmp_step_id = q[index]
+ index += 1
+ for i in range(len(in_edges.get(tmp_step_id, []))):
+ pre_node_id = in_edges[tmp_step_id][i]
+ if pre_node_id not in q:
+ q.append(pre_node_id)
+ pre_step_params = []
+ for i in range(1, len(q)):
+ step_id = q[i]
+ node_id = step_id_to_node_id.get(step_id)
+ _, output_schema = await NodeManager.get_node_params(node_id)
+ slot = Slot(output_schema)
+ params_node = slot.get_params_node_from_schema()
+ pre_step_params.append(
+ StepParams(
+ stepId=step_id,
+ name=step_id_to_node_name.get(step_id),
+ paramsNode=params_node
+ )
+ )
+ return pre_step_params
diff --git a/apps/services/rag.py b/apps/services/rag.py
index 6b6c843dc6fc09a14809ef0a11bc44b2da434f29..8b95d2b1b3c9fd20df54290b1b2534f8114a2e56 100644
--- a/apps/services/rag.py
+++ b/apps/services/rag.py
@@ -1,6 +1,7 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""对接Euler Copilot RAG"""
+from datetime import UTC, datetime
import json
import logging
from collections.abc import AsyncGenerator
@@ -17,7 +18,6 @@ from apps.schemas.collection import LLM
from apps.schemas.config import LLMConfig
from apps.schemas.enum_var import EventType
from apps.schemas.rag_data import RAGQueryReq
-from apps.services.activity import Activity
from apps.services.session import SessionManager
logger = logging.getLogger(__name__)
@@ -156,9 +156,11 @@ class RAG:
"id": doc_chunk["docId"],
"order": doc_cnt,
"name": doc_chunk.get("docName", ""),
+ "author": doc_chunk.get("docAuthor", ""),
"extension": doc_chunk.get("docExtension", ""),
"abstract": doc_chunk.get("docAbstract", ""),
"size": doc_chunk.get("docSize", 0),
+ "created_at": doc_chunk.get("docCreatedAt", round(datetime.now(UTC).timestamp(), 3)),
})
doc_id_map[doc_chunk["docId"]] = doc_cnt
doc_index = doc_id_map[doc_chunk["docId"]]
@@ -254,8 +256,6 @@ class RAG:
result_only=False,
model=llm.model_name,
):
- if not await Activity.is_active(user_sub):
- return
chunk = buffer + chunk
# 防止脚注被截断
if len(chunk) >= 2 and chunk[-2:] != "]]":
diff --git a/apps/services/record.py b/apps/services/record.py
index 5c8a89dfd224166ac822e2eeb47fb79f1fefe7e4..cf8373b042ab409662dcf1d2e08ebc8ac3666d1b 100644
--- a/apps/services/record.py
+++ b/apps/services/record.py
@@ -9,7 +9,7 @@ from apps.schemas.record import (
Record,
RecordGroup,
)
-
+from apps.schemas.enum_var import FlowStatus
logger = logging.getLogger(__name__)
@@ -17,7 +17,7 @@ class RecordManager:
"""问答对相关操作"""
@staticmethod
- async def create_record_group(group_id: str, user_sub: str, conversation_id: str, task_id: str) -> str | None:
+ async def create_record_group(group_id: str, user_sub: str, conversation_id: str) -> str | None:
"""创建问答组"""
mongo = MongoDB()
record_group_collection = mongo.get_collection("record_group")
@@ -26,7 +26,6 @@ class RecordManager:
_id=group_id,
user_sub=user_sub,
conversation_id=conversation_id,
- task_id=task_id,
)
try:
@@ -49,6 +48,10 @@ class RecordManager:
mongo = MongoDB()
group_collection = mongo.get_collection("record_group")
try:
+ await group_collection.update_one(
+ {"_id": group_id, "user_sub": user_sub},
+ {"$pull": {"records": {"id": record.id}}}
+ )
await group_collection.update_one(
{"_id": group_id, "user_sub": user_sub},
{"$push": {"records": record.model_dump(by_alias=True)}},
@@ -133,6 +136,19 @@ class RecordManager:
logger.exception("[RecordManager] 查询问答组失败")
return []
+ @staticmethod
+ async def update_record_flow_status_to_cancelled_by_task_ids(task_ids: list[str]) -> None:
+ """更新Record关联的Flow状态"""
+ record_group_collection = MongoDB().get_collection("record_group")
+ try:
+ await record_group_collection.update_many(
+ {"records.task_id": {"$in": task_ids}, "records.flow.flow_status": {"$nin": [FlowStatus.ERROR.value, FlowStatus.SUCCESS.value]}},
+ {"$set": {"records.$[elem].flow.flow_status": FlowStatus.CANCELLED}},
+ array_filters=[{"elem.flow.flow_id": {"$in": task_ids}}],
+ )
+ except Exception:
+ logger.exception("[RecordManager] 更新Record关联的Flow状态失败")
+
@staticmethod
async def verify_record_in_group(group_id: str, record_id: str, user_sub: str) -> bool:
"""
@@ -151,7 +167,6 @@ class RecordManager:
logger.exception("[RecordManager] 验证记录是否在组中失败")
return False
-
@staticmethod
async def check_group_id(group_id: str, user_sub: str) -> bool:
"""检查group_id是否存在"""
diff --git a/apps/services/task.py b/apps/services/task.py
index 1e672be690a6f17896ab01bc7149aa353987ecf7..49c9f4b037cb95408588ce304d0f483af88b4ed8 100644
--- a/apps/services/task.py
+++ b/apps/services/task.py
@@ -13,6 +13,7 @@ from apps.schemas.task import (
TaskIds,
TaskRuntime,
TaskTokens,
+ FlowStepHistory
)
from apps.services.record import RecordManager
@@ -45,7 +46,6 @@ class TaskManager:
return Task.model_validate(task)
-
@staticmethod
async def get_task_by_group_id(group_id: str, conversation_id: str) -> Task | None:
"""获取组ID的最后一条问答组关联的任务"""
@@ -58,7 +58,6 @@ class TaskManager:
task = await task_collection.find_one({"_id": record_group_obj.task_id})
return Task.model_validate(task)
-
@staticmethod
async def get_task_by_task_id(task_id: str) -> Task | None:
"""根据task_id获取任务"""
@@ -68,9 +67,8 @@ class TaskManager:
return None
return Task.model_validate(task)
-
@staticmethod
- async def get_context_by_record_id(record_group_id: str, record_id: str) -> list[dict[str, Any]]:
+ async def get_context_by_record_id(record_group_id: str, record_id: str) -> list[FlowStepHistory]:
"""根据record_group_id获取flow信息"""
record_group_collection = MongoDB().get_collection("record_group")
flow_context_collection = MongoDB().get_collection("flow_context")
@@ -85,59 +83,72 @@ class TaskManager:
return []
flow_context_list = []
- for flow_context_id in records[0]["records"]["flow"]:
+ for flow_context_id in records[0]["records"]["flow"]["history_ids"]:
flow_context = await flow_context_collection.find_one({"_id": flow_context_id})
if flow_context:
- flow_context_list.append(flow_context)
+ flow_context_list.append(FlowStepHistory.model_validate(flow_context))
except Exception:
logger.exception("[TaskManager] 获取record_id的flow信息失败")
return []
else:
return flow_context_list
-
@staticmethod
- async def get_context_by_task_id(task_id: str, length: int = 0) -> list[dict[str, Any]]:
+ async def get_context_by_task_id(task_id: str, length: int | None = None) -> list[FlowStepHistory]:
"""根据task_id获取flow信息"""
flow_context_collection = MongoDB().get_collection("flow_context")
flow_context = []
try:
- async for history in flow_context_collection.find(
- {"task_id": task_id},
- ).sort(
- "created_at", -1,
- ).limit(length):
- flow_context += [history]
+ if length is None:
+ async for context in flow_context_collection.find({"task_id": task_id}):
+ flow_context.append(FlowStepHistory.model_validate(context))
+ else:
+ async for context in flow_context_collection.find({"task_id": task_id}).limit(length):
+ flow_context.append(FlowStepHistory.model_validate(context))
except Exception:
logger.exception("[TaskManager] 获取task_id的flow信息失败")
return []
else:
return flow_context
+ @staticmethod
+ async def init_new_task(
+ user_sub: str,
+ session_id: str | None = None,
+ post_body: RequestData | None = None,
+ ) -> Task:
+ """获取任务块"""
+ return Task(
+ _id=str(uuid.uuid4()),
+ ids=TaskIds(
+ user_sub=user_sub if user_sub else "",
+ session_id=session_id if session_id else "",
+ conversation_id=post_body.conversation_id,
+ group_id=post_body.group_id if post_body.group_id else "",
+ ),
+ question=post_body.question if post_body else "",
+ group_id=post_body.group_id if post_body else "",
+ tokens=TaskTokens(),
+ runtime=TaskRuntime(),
+ )
@staticmethod
- async def save_flow_context(task_id: str, flow_context: list[dict[str, Any]]) -> None:
+ async def save_flow_context(task_id: str, flow_context: list[FlowStepHistory]) -> None:
"""保存flow信息到flow_context"""
flow_context_collection = MongoDB().get_collection("flow_context")
try:
- for history in flow_context:
- # 查找是否存在
- current_context = await flow_context_collection.find_one({
- "task_id": task_id,
- "_id": history["_id"],
- })
- if current_context:
- await flow_context_collection.update_one(
- {"_id": current_context["_id"]},
- {"$set": history},
- )
- else:
- await flow_context_collection.insert_one(history)
+ # 删除旧的flow_context
+ await flow_context_collection.delete_many({"task_id": task_id})
+ if not flow_context:
+ return
+ await flow_context_collection.insert_many(
+ [history.model_dump(exclude_none=True, by_alias=True) for history in flow_context],
+ ordered=False,
+ )
except Exception:
logger.exception("[TaskManager] 保存flow执行记录失败")
-
@staticmethod
async def delete_task_by_task_id(task_id: str) -> None:
"""通过task_id删除Task信息"""
@@ -148,9 +159,27 @@ class TaskManager:
if task:
await task_collection.delete_one({"_id": task_id})
+ @staticmethod
+ async def delete_tasks_by_conversation_id(conversation_id: str) -> list[str]:
+ """通过ConversationID删除Task信息"""
+ mongo = MongoDB()
+ task_collection = mongo.get_collection("task")
+ task_ids = []
+ try:
+ async for task in task_collection.find(
+ {"conversation_id": conversation_id},
+ {"_id": 1},
+ ):
+ task_ids.append(task["_id"])
+ if task_ids:
+ await task_collection.delete_many({"conversation_id": conversation_id})
+ return task_ids
+ except Exception:
+ logger.exception("[TaskManager] 删除ConversationID的Task信息失败")
+ return []
@staticmethod
- async def delete_tasks_by_conversation_id(conversation_id: str) -> None:
+ async def delete_tasks_and_flow_context_by_conversation_id(conversation_id: str) -> None:
"""通过ConversationID删除Task信息"""
mongo = MongoDB()
task_collection = mongo.get_collection("task")
@@ -167,52 +196,6 @@ class TaskManager:
await task_collection.delete_many({"conversation_id": conversation_id}, session=session)
await flow_context_collection.delete_many({"task_id": {"$in": task_ids}}, session=session)
-
- @classmethod
- async def get_task(
- cls,
- task_id: str | None = None,
- session_id: str | None = None,
- post_body: RequestData | None = None,
- user_sub: str | None = None,
- ) -> Task:
- """获取任务块"""
- if task_id:
- try:
- task = await cls.get_task_by_task_id(task_id)
- if task:
- return task
- except Exception:
- logger.exception("[TaskManager] 通过task_id获取任务失败")
-
- logger.info("[TaskManager] 未提供task_id,通过session_id获取任务")
- if not session_id or not post_body:
- err = (
- "session_id 和 conversation_id 或 group_id 和 conversation_id 是恢复/创建任务的必要条件。"
- )
- raise ValueError(err)
-
- if post_body.group_id:
- task = await cls.get_task_by_group_id(post_body.group_id, post_body.conversation_id)
- else:
- task = await cls.get_task_by_conversation_id(post_body.conversation_id)
-
- if task:
- return task
- return Task(
- _id=str(uuid.uuid4()),
- ids=TaskIds(
- user_sub=user_sub if user_sub else "",
- session_id=session_id if session_id else "",
- conversation_id=post_body.conversation_id,
- group_id=post_body.group_id if post_body.group_id else "",
- ),
- state=None,
- tokens=TaskTokens(),
- runtime=TaskRuntime(),
- )
-
-
@classmethod
async def save_task(cls, task_id: str, task: Task) -> None:
"""保存任务块"""
diff --git a/apps/services/user.py b/apps/services/user.py
index 2721d3773bd43e2a8fda5b371d6336fcf0b9b7f3..1b96df18143f5a87b102f68e2bb4e2df62753f05 100644
--- a/apps/services/user.py
+++ b/apps/services/user.py
@@ -4,6 +4,7 @@
import logging
from datetime import UTC, datetime
+from apps.schemas.request_data import UserUpdateRequest
from apps.common.mongo import MongoDB
from apps.schemas.collection import User
from apps.services.conversation import ConversationManager
@@ -28,7 +29,7 @@ class UserManager:
).model_dump(by_alias=True))
@staticmethod
- async def get_all_user_sub() -> list[str]:
+ async def get_all_user_sub(page_size: int = 20, page_cnt: int = 1, filter_user_subs: list[str] = []) -> tuple[list[str], int]:
"""
获取所有用户的sub
@@ -36,7 +37,13 @@ class UserManager:
"""
mongo = MongoDB()
user_collection = mongo.get_collection("user")
- return [user["_id"] async for user in user_collection.find({}, {"_id": 1})]
+ total = await user_collection.count_documents({}) - len(filter_user_subs)
+
+ users = await user_collection.find(
+ {"_id": {"$nin": filter_user_subs}},
+ {"_id": 1},
+ ).skip((page_cnt - 1) * page_size).limit(page_size).to_list(length=page_size)
+ return [user["_id"] for user in users], total
@staticmethod
async def get_userinfo_by_user_sub(user_sub: str) -> User | None:
@@ -52,7 +59,25 @@ class UserManager:
return User(**user_data) if user_data else None
@staticmethod
- async def update_userinfo_by_user_sub(user_sub: str, *, refresh_revision: bool = False) -> bool:
+ async def update_userinfo_by_user_sub(user_sub: str, data: UserUpdateRequest) -> None:
+ """
+ 根据用户sub更新用户信息
+
+ :param user_sub: 用户sub
+ :param data: 用户更新信息
+ :return: 是否更新成功
+ """
+ mongo = MongoDB()
+ user_collection = mongo.get_collection("user")
+ update_dict = {
+ "$set": {
+ "auto_execute": data.auto_execute,
+ }
+ }
+ await user_collection.update_one({"_id": user_sub}, update_dict)
+
+ @staticmethod
+ async def update_refresh_revision_by_user_sub(user_sub: str, *, refresh_revision: bool = False) -> bool:
"""
根据用户sub更新用户信息
diff --git a/deploy/chart/euler_copilot/configs/rag/.env b/deploy/chart/euler_copilot/configs/rag/.env
index 5e83f3f415091b156644b33bef9291a2d99d1931..0708fd0762d1f8c789ee5a4a873fb06880ce1463 100644
--- a/deploy/chart/euler_copilot/configs/rag/.env
+++ b/deploy/chart/euler_copilot/configs/rag/.env
@@ -52,7 +52,7 @@ HALF_KEY3=${halfKey3}
#LLM config
MODEL_NAME={{ .Values.models.answer.name }}
-OPENAI_API_BASE={{ .Values.models.answer.endpoint }}/v1
+OPENAI_API_BASE={{ .Values.models.answer.endpoint }}
OPENAI_API_KEY={{ default "" .Values.models.answer.key }}
MAX_TOKENS={{ default 2048 .Values.models.answer.maxTokens }}
diff --git a/perf.data b/perf.data
new file mode 100644
index 0000000000000000000000000000000000000000..f319e212c75a402bbf0ff2e488552bad93ff5b8b
Binary files /dev/null and b/perf.data differ
diff --git a/pyproject.toml b/pyproject.toml
index 681ea9fb4552650da9148a22dd55d2063dca2410..22b07eacd07d19ac7478b6af20fb0f972da742db 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -50,4 +50,4 @@ dev = [
"ruff==0.11.2",
"sphinx==8.2.3",
"sphinx-rtd-theme==3.0.2",
-]
+]
\ No newline at end of file
diff --git a/tests/common/test_queue.py b/tests/common/test_queue.py
index 5375180a3f453a9d9d18d17b95b0ed9c39d8c5c4..db1f5ead746f35dd510b30698746e1a25f08d02d 100644
--- a/tests/common/test_queue.py
+++ b/tests/common/test_queue.py
@@ -74,8 +74,8 @@ async def test_push_output_with_flow(message_queue, mock_task):
mock_task.state.flow_id = "flow_id"
mock_task.state.step_id = "step_id"
mock_task.state.step_name = "step_name"
- mock_task.state.status = "running"
-
+ mock_task.state.step_status = "running"
+
await message_queue.init("test_task")
await message_queue.push_output(mock_task, EventType.TEXT_ADD, {})
diff --git a/tests/manager/test_user.py b/tests/manager/test_user.py
index d350ca307a9f0aa2d6532cae614b8633b1ae0f81..ab6eeab4c498a2619c4f08c5b548a6096611aeaa 100644
--- a/tests/manager/test_user.py
+++ b/tests/manager/test_user.py
@@ -73,7 +73,7 @@ class TestUserManager(unittest.TestCase):
mock_mysql_db_instance.get_session.return_value = mock_session
# 调用被测方法
- updated_userinfo = UserManager.update_userinfo_by_user_sub(userinfo, refresh_revision=True)
+ updated_userinfo = UserManager.update_refresh_revision_by_user_sub(userinfo, refresh_revision=True)
# 断言返回的用户信息的 revision_number 是否与原始用户信息一致
self.assertEqual(updated_userinfo.revision_number, userinfo.revision_number)