diff --git a/apps/common/queue.py b/apps/common/queue.py index 5601c93aa3ef72011f2e2d4b126f782d83e85b78..6e82af915e2a70f743c629902c149b7ebc0f7a0e 100644 --- a/apps/common/queue.py +++ b/apps/common/queue.py @@ -56,9 +56,12 @@ class MessageQueue: flow = MessageFlow( appId=task.state.app_id, flowId=task.state.flow_id, + flowName=task.state.flow_name, + flowStatus=task.state.flow_status, stepId=task.state.step_id, stepName=task.state.step_name, - stepStatus=task.state.status, + stepDescription=task.state.step_description, + stepStatus=task.state.step_status ) else: flow = None diff --git a/apps/constants.py b/apps/constants.py index 58158b33c3aa4ea5aa3595f5642bf6444e7d76cb..20cb79b54ac8db5450fdbb296cb400b47a29adf0 100644 --- a/apps/constants.py +++ b/apps/constants.py @@ -11,7 +11,7 @@ from apps.common.config import Config # 新对话默认标题 NEW_CHAT = "新对话" # 滑动窗口限流 默认窗口期 -SLIDE_WINDOW_TIME = 60 +SLIDE_WINDOW_TIME = 15 # OIDC 访问Token 过期时间(分钟) OIDC_ACCESS_TOKEN_EXPIRE_TIME = 30 # OIDC 刷新Token 过期时间(分钟) diff --git a/apps/dependency/user.py b/apps/dependency/user.py index fce67e51e00dd76b35bec2f7787dfe8039111589..8f5848a3d2961ef21b242d2eee931c623c6481c7 100644 --- a/apps/dependency/user.py +++ b/apps/dependency/user.py @@ -1,14 +1,16 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """用户鉴权""" - +import os import logging from fastapi import Depends from fastapi.security import OAuth2PasswordBearer +import secrets from starlette import status from starlette.exceptions import HTTPException from starlette.requests import HTTPConnection +from apps.common.config import Config from apps.services.api_key import ApiKeyManager from apps.services.session import SessionManager @@ -48,6 +50,9 @@ async def get_session(request: HTTPConnection) -> str: :param request: HTTP请求 :return: Session ID """ + if Config().get_config().no_auth.enable: + # 如果启用了无认证访问,直接返回调试用户 + return secrets.token_hex(16) session_id = await _get_session_id_from_request(request) if not session_id: raise HTTPException( @@ -69,6 +74,12 @@ async def get_user(request: HTTPConnection) -> str: :param request: HTTP请求体 :return: 用户sub """ + if Config().get_config().no_auth.enable: + # 如果启用了无认证访问,直接返回当前操作系统用户的名称 + username = os.environ.get('USERNAME') # 适用于 Windows 系统 + if not username: + username = os.environ.get('USER') # 适用于 Linux 和 macOS 系统 + return username or "admin" session_id = await _get_session_id_from_request(request) if not session_id: raise HTTPException( diff --git a/apps/llm/embedding.py b/apps/llm/embedding.py index 28ab86b49a96d19cc6e2db83930d5aff48f82260..df13044a1c9ce0e11968d5dd7c3f0f8ab2345394 100644 --- a/apps/llm/embedding.py +++ b/apps/llm/embedding.py @@ -1,8 +1,9 @@ """Embedding模型""" import httpx - +import logging from apps.common.config import Config +logger = logging.getLogger(__name__) class Embedding: @@ -15,7 +16,6 @@ class Embedding: embedding = await cls.get_embedding(["测试文本"]) return len(embedding[0]) - @classmethod async def _get_openai_embedding(cls, text: list[str]) -> list[list[float]]: """访问OpenAI兼容的Embedding API,获得向量化数据""" @@ -75,10 +75,18 @@ class Embedding: :param text: 待向量化文本(多条文本组成List) :return: 文本对应的向量(顺序与text一致,也为List) """ - if Config().get_config().embedding.type == "openai": - return await cls._get_openai_embedding(text) - if Config().get_config().embedding.type == "mindie": - return await cls._get_tei_embedding(text) + try: + if Config().get_config().embedding.type == "openai": + return await cls._get_openai_embedding(text) + if Config().get_config().embedding.type == "mindie": + return await cls._get_tei_embedding(text) - err = f"不支持的Embedding API类型: {Config().get_config().embedding.type}" - raise ValueError(err) + err = f"不支持的Embedding API类型: {Config().get_config().embedding.type}" + raise ValueError(err) + except Exception as e: + err = f"获取Embedding失败: {e}" + logger.error(err) + rt = [] + for i in range(len(text)): + rt.append([0.0]*1024) + return rt diff --git a/apps/llm/function.py b/apps/llm/function.py index 1f995fe7ba187cead03aa6fc62a4cbce1ec05a65..6165dc455aae038aeb1ad9bcd0840b418c3dc8e8 100644 --- a/apps/llm/function.py +++ b/apps/llm/function.py @@ -68,7 +68,6 @@ class FunctionLLM: api_key=self._config.api_key, ) - async def _call_openai( self, messages: list[dict[str, str]], @@ -123,7 +122,7 @@ class FunctionLLM: }, ] - response = await self._client.chat.completions.create(**self._params) # type: ignore[arg-type] + response = await self._client.chat.completions.create(**self._params) # type: ignore[arg-type] try: logger.info("[FunctionCall] 大模型输出:%s", response.choices[0].message.tool_calls[0].function.arguments) return response.choices[0].message.tool_calls[0].function.arguments @@ -132,7 +131,6 @@ class FunctionLLM: logger.info("[FunctionCall] 大模型输出:%s", ans) return await FunctionLLM.process_response(ans) - @staticmethod async def process_response(response: str) -> str: """处理大模型的输出""" @@ -169,7 +167,6 @@ class FunctionLLM: return json_str - async def _call_ollama( self, messages: list[dict[str, str]], @@ -196,10 +193,9 @@ class FunctionLLM: "format": schema, }) - response = await self._client.chat(**self._params) # type: ignore[arg-type] + response = await self._client.chat(**self._params) # type: ignore[arg-type] return await self.process_response(response.message.content or "") - async def call( self, messages: list[dict[str, Any]], @@ -254,7 +250,6 @@ class JsonGenerator: ) self._err_info = "" - async def _assemble_message(self) -> str: """组装消息""" # 检查类型 @@ -275,13 +270,12 @@ class JsonGenerator: """单次尝试""" prompt = await self._assemble_message() messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt}, + {"role": "system", "content": prompt}, + {"role": "user", "content": "please generate a JSON response based on the above information and schema."}, ] function = FunctionLLM() return await function.call(messages, self._schema, max_tokens, temperature) - async def generate(self) -> dict[str, Any]: """生成JSON""" Draft7Validator.check_schema(self._schema) diff --git a/apps/llm/prompt.py b/apps/llm/prompt.py index 7cf555be390bc72301323cd103c1a7ec0ab7a085..7702f6d0912947612b64145402398f222e01a148 100644 --- a/apps/llm/prompt.py +++ b/apps/llm/prompt.py @@ -19,7 +19,7 @@ JSON_GEN_BASIC = dedent(r""" Background information is given in XML tags. - Here are the conversations between you and the user: + Here are the background information between you and the user: {% if conversation|length > 0 %} {% for message in conversation %} diff --git a/apps/main.py b/apps/main.py index c4ca2bfb116db11624f4e4fbbe4e876a86e5f1eb..99b7de0cd927f937abc95971cb878dd87e878abb 100644 --- a/apps/main.py +++ b/apps/main.py @@ -36,9 +36,10 @@ from apps.routers import ( record, service, user, + parameter ) from apps.scheduler.pool.pool import Pool - +logger = logging.getLogger(__name__) # 定义FastAPI app app = FastAPI(redoc_url=None) # 定义FastAPI全局中间件 @@ -66,6 +67,7 @@ app.include_router(llm.router) app.include_router(mcp_service.router) app.include_router(flow.router) app.include_router(user.router) +app.include_router(parameter.router) # logger配置 LOGGER_FORMAT = "%(funcName)s() - %(message)s" @@ -81,13 +83,49 @@ logging.basicConfig( ) +async def add_no_auth_user() -> None: + """ + 添加无认证用户 + """ + from apps.common.mongo import MongoDB + from apps.schemas.collection import User + import os + mongo = MongoDB() + user_collection = mongo.get_collection("user") + username = os.environ.get('USERNAME') # 适用于 Windows 系统 + if not username: + username = os.environ.get('USER') # 适用于 Linux 和 macOS 系统 + if not username: + username = "admin" + try: + await user_collection.insert_one(User( + _id=username, + is_admin=True, + auto_execute=False + ).model_dump(by_alias=True)) + except Exception as e: + logger.error(f"[add_no_auth_user] 默认用户 {username} 已存在") + + +async def clear_user_activity() -> None: + """清除所有用户的活跃状态""" + from apps.services.activity import Activity + from apps.common.mongo import MongoDB + mongo = MongoDB() + activity_collection = mongo.get_collection("activity") + await activity_collection.delete_many({}) + logging.info("清除所有用户活跃状态完成") + + async def init_resources() -> None: """初始化必要资源""" WordsCheck() await LanceDB().init() await Pool.init() TokenCalculator() - + if Config().get_config().no_auth.enable: + await add_no_auth_user() + await clear_user_activity() # 运行 if __name__ == "__main__": # 初始化必要资源 diff --git a/apps/routers/api_key.py b/apps/routers/api_key.py index 158cfc13a5be17f16e2d73ec5fa44c4a18e4f392..51366a2192eaa429a3f1e1d40672b9e915bec424 100644 --- a/apps/routers/api_key.py +++ b/apps/routers/api_key.py @@ -6,6 +6,7 @@ from typing import Annotated from fastapi import APIRouter, Depends, status from fastapi.responses import JSONResponse + from apps.dependency.user import get_user, verify_user from apps.schemas.api_key import GetAuthKeyRsp, PostAuthKeyMsg, PostAuthKeyRsp from apps.schemas.response_data import ResponseData diff --git a/apps/routers/appcenter.py b/apps/routers/appcenter.py index 0ec4db9155a06fe16cbeddff223ccd9a764d8b3d..36815743349c2436564e30e4cb319caa8843dd99 100644 --- a/apps/routers/appcenter.py +++ b/apps/routers/appcenter.py @@ -13,6 +13,8 @@ from apps.schemas.appcenter import AppFlowInfo, AppPermissionData from apps.schemas.enum_var import AppFilterType, AppType from apps.schemas.request_data import CreateAppRequest, ModFavAppRequest from apps.schemas.response_data import ( + AppMcpServiceInfo, + LLMIteam, BaseAppOperationMsg, BaseAppOperationRsp, GetAppListMsg, @@ -25,7 +27,8 @@ from apps.schemas.response_data import ( ResponseData, ) from apps.services.appcenter import AppCenterManager - +from apps.services.llm import LLMManager +from apps.services.mcp_service import MCPServiceManager logger = logging.getLogger(__name__) router = APIRouter( prefix="/api/app", @@ -214,6 +217,24 @@ async def get_application( ) for flow in app_data.flows ] + mcp_service = [] + if app_data.mcp_service: + for service in app_data.mcp_service: + mcp_collection = await MCPServiceManager.get_mcp_service(service) + mcp_service.append(AppMcpServiceInfo( + id=mcp_collection.id, + name=mcp_collection.name, + description=mcp_collection.description, + )) + if app_data.llm_id == "empty": + llm_item = LLMIteam() + else: + llm_collection = await LLMManager.get_llm_by_id(app_data.user_sub, app_data.llm_id) + llm_item = LLMIteam( + llmId=llm_collection.id, + modelName=llm_collection.model_name, + icon=llm_collection.icon + ) return JSONResponse( status_code=status.HTTP_200_OK, content=GetAppPropertyRsp( @@ -234,7 +255,8 @@ async def get_application( authorizedUsers=app_data.permission.users, ), workflows=workflows, - mcpService=app_data.mcp_service, + mcpService=mcp_service, + llm=llm_item, ), ).model_dump(exclude_none=True, by_alias=True), ) diff --git a/apps/routers/auth.py b/apps/routers/auth.py index 1cba5ed629d90776b59ae3bf652381318dcabf0e..72416fa35b1c23d2975497427da35a4912695fa4 100644 --- a/apps/routers/auth.py +++ b/apps/routers/auth.py @@ -9,6 +9,7 @@ from fastapi import APIRouter, Depends, Request, status from fastapi.responses import HTMLResponse, JSONResponse from fastapi.templating import Jinja2Templates +from apps.common.config import Config from apps.common.oidc import oidc_provider from apps.dependency import get_session, get_user, verify_user from apps.schemas.collection import Audit @@ -47,8 +48,6 @@ async def oidc_login(request: Request, code: str) -> HTMLResponse: user_info = await oidc_provider.get_oidc_user(token["access_token"]) user_sub: str | None = user_info.get("user_sub", None) - if user_sub: - await oidc_provider.set_token(user_sub, token["access_token"], token["refresh_token"]) except Exception as e: logger.exception("User login failed") status_code = status.HTTP_400_BAD_REQUEST if "auth error" in str(e) else status.HTTP_403_FORBIDDEN @@ -75,7 +74,7 @@ async def oidc_login(request: Request, code: str) -> HTMLResponse: status_code=status.HTTP_403_FORBIDDEN, ) - await UserManager.update_userinfo_by_user_sub(user_sub) + await UserManager.update_refresh_revision_by_user_sub(user_sub) current_session = await SessionManager.create_session(user_host, user_sub) @@ -178,6 +177,7 @@ async def userinfo( user_sub=user_sub, revision=user.is_active, is_admin=user.is_admin, + auto_execute=user.auto_execute, ), ).model_dump(exclude_none=True, by_alias=True), ) @@ -193,7 +193,7 @@ async def userinfo( ) async def update_revision_number(request: Request, user_sub: Annotated[str, Depends(get_user)]) -> JSONResponse: # noqa: ARG001 """更新用户协议信息""" - ret: bool = await UserManager.update_userinfo_by_user_sub(user_sub, refresh_revision=True) + ret: bool = await UserManager.update_refresh_revision_by_user_sub(user_sub, refresh_revision=True) if not ret: return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 7fe5162c0f7db804c6d723207ffa327d9394e3eb..888fbd04f67c0ad2ad9482db8d37016e0d71447d 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -7,20 +7,23 @@ import uuid from collections.abc import AsyncGenerator from typing import Annotated -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, status, Query from fastapi.responses import JSONResponse, StreamingResponse from apps.common.queue import MessageQueue from apps.common.wordscheck import WordsCheck from apps.dependency import get_session, get_user +from apps.schemas.enum_var import FlowStatus from apps.scheduler.scheduler import Scheduler from apps.scheduler.scheduler.context import save_data -from apps.schemas.request_data import RequestData +from apps.schemas.request_data import RequestData, RequestDataApp from apps.schemas.response_data import ResponseData from apps.schemas.task import Task from apps.services.activity import Activity from apps.services.blacklist import QuestionBlacklistManager, UserBlacklistManager from apps.services.flow import FlowManager +from apps.services.conversation import ConversationManager +from apps.services.record import RecordManager from apps.services.task import TaskManager RECOMMEND_TRES = 5 @@ -36,28 +39,49 @@ async def init_task(post_body: RequestData, user_sub: str, session_id: str) -> T # 生成group_id if not post_body.group_id: post_body.group_id = str(uuid.uuid4()) - # 创建或还原Task - task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub) + # 更改信息并刷新数据库 - task.runtime.question = post_body.question - task.ids.group_id = post_body.group_id + if post_body.task_id is None: + conversation = await ConversationManager.get_conversation_by_conversation_id( + user_sub=user_sub, + conversation_id=post_body.conversation_id, + ) + if not conversation: + err = "[Chat] 用户没有权限访问该对话!" + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=err) + task_ids = await TaskManager.delete_tasks_by_conversation_id(post_body.conversation_id) + await RecordManager.update_record_flow_status_to_cancelled_by_task_ids(task_ids) + task = await TaskManager.init_new_task(user_sub=user_sub, session_id=session_id, post_body=post_body) + task.runtime.question = post_body.question + task.ids.group_id = post_body.group_id + task.state.app_id = post_body.app.app_id if post_body.app else "" + else: + if not post_body.task_id: + err = "[Chat] task_id 不可为空!" + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="task_id cannot be empty") + task = await TaskManager.get_task_by_task_id(post_body.task_id) + post_body.app = RequestDataApp(appId=task.state.app_id) + post_body.group_id = task.ids.group_id + post_body.conversation_id = task.ids.conversation_id + post_body.language = task.language + post_body.question = task.runtime.question return task async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) -> AsyncGenerator[str, None]: """进行实际问答,并从MQ中获取消息""" try: - await Activity.set_active(user_sub) + active_id = await Activity.set_active(user_sub) # 敏感词检查 if await WordsCheck().check(post_body.question) != 1: yield "data: [SENSITIVE]\n\n" logger.info("[Chat] 问题包含敏感词!") - await Activity.remove_active(user_sub) + await Activity.remove_active(active_id) return task = await init_task(post_body, user_sub, session_id) - + task.ids.active_id = active_id # 创建queue;由Scheduler进行关闭 queue = MessageQueue() await queue.init() @@ -77,19 +101,18 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) # 获取最终答案 task = scheduler.task - if not task.runtime.answer: - logger.error("[Chat] 答案为空") + if task.state.flow_status == FlowStatus.ERROR: + logger.error("[Chat] 生成答案失败") yield "data: [ERROR]\n\n" - await Activity.remove_active(user_sub) + await Activity.remove_active(active_id) return # 对结果进行敏感词检查 if await WordsCheck().check(task.runtime.answer) != 1: yield "data: [SENSITIVE]\n\n" logger.info("[Chat] 答案包含敏感词!") - await Activity.remove_active(user_sub) + await Activity.remove_active(active_id) return - # 创建新Record,存入数据库 await save_data(task, user_sub, post_body) @@ -107,7 +130,7 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) yield "data: [ERROR]\n\n" finally: - await Activity.remove_active(user_sub) + await Activity.remove_active(active_id) @router.post("/chat") @@ -118,15 +141,11 @@ async def chat( ) -> StreamingResponse: """LLM流式对话接口""" # 问题黑名单检测 - if not await QuestionBlacklistManager.check_blacklisted_questions(input_question=post_body.question): + if post_body.question is not None and not await QuestionBlacklistManager.check_blacklisted_questions(input_question=post_body.question): # 用户扣分 await UserBlacklistManager.change_blacklisted_users(user_sub, -10) raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="question is blacklisted") - # 限流检查 - if await Activity.is_active(user_sub): - raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many requests") - res = chat_generator(post_body, user_sub, session_id) return StreamingResponse( content=res, @@ -138,9 +157,12 @@ async def chat( @router.post("/stop", response_model=ResponseData) -async def stop_generation(user_sub: Annotated[str, Depends(get_user)]): # noqa: ANN201 +async def stop_generation(user_sub: Annotated[str, Depends(get_user)], + task_id: Annotated[str, Query(..., alias="taskId")] = "") -> JSONResponse: """停止生成""" - await Activity.remove_active(user_sub) + task = await TaskManager.get_task_by_task_id(task_id) + if task: + await Activity.remove_active(task.activity_id) return JSONResponse( status_code=status.HTTP_200_OK, content=ResponseData( diff --git a/apps/routers/flow.py b/apps/routers/flow.py index fc38c1bfb9879c1d846d700e91b812c83866c2c4..68c0c9f3bd2fbdb434164d64ac4c3d1cf31a376b 100644 --- a/apps/routers/flow.py +++ b/apps/routers/flow.py @@ -2,6 +2,7 @@ """FastAPI Flow拓扑结构展示API""" from typing import Annotated +import logging from fastapi import APIRouter, Body, Depends, Query, status from fastapi.responses import JSONResponse @@ -25,6 +26,8 @@ from apps.services.application import AppManager from apps.services.flow import FlowManager from apps.services.flow_validate import FlowService +logger = logging.getLogger(__name__) + router = APIRouter( prefix="/api/flow", tags=["flow"], @@ -130,8 +133,11 @@ async def put_flow( ).model_dump(exclude_none=True, by_alias=True), ) put_body.flow = await FlowService.remove_excess_structure_from_flow(put_body.flow) + logger.error(f'{put_body.flow}') await FlowService.validate_flow_illegal(put_body.flow) + logger.error(f'{put_body.flow}') put_body.flow.connectivity = await FlowService.validate_flow_connectivity(put_body.flow) + logger.error(f'{put_body.flow}') result = await FlowManager.put_flow_by_app_and_flow_id(app_id, flow_id, put_body.flow) if result is None: return JSONResponse( diff --git a/apps/routers/mcp_service.py b/apps/routers/mcp_service.py index a845a3761ea2ab179a1b71d4674d0d5fd3b5a760..99cc6dfdc312607e6bcf42274e7ad748d9dd928b 100644 --- a/apps/routers/mcp_service.py +++ b/apps/routers/mcp_service.py @@ -36,6 +36,7 @@ router = APIRouter( dependencies=[Depends(verify_user)], ) + async def _check_user_admin(user_sub: str) -> None: user = await UserManager.get_userinfo_by_user_sub(user_sub) if not user: @@ -52,6 +53,8 @@ async def get_mcpservice_list( ] = SearchType.ALL, keyword: Annotated[str | None, Query(..., alias="keyword", description="搜索关键字")] = None, page: Annotated[int, Query(..., alias="page", ge=1, description="页码")] = 1, + is_install: Annotated[bool | None, Query(..., alias="isInstall", description="是否已安装")] = None, + is_active: Annotated[bool | None, Query(..., alias="isActive", description="是否激活")] = None, ) -> JSONResponse: """获取服务列表""" try: @@ -60,6 +63,8 @@ async def get_mcpservice_list( user_sub, keyword, page, + is_install, + is_active ) except Exception as e: err = f"[MCPServiceCenter] 获取MCP服务列表失败: {e}" @@ -89,7 +94,7 @@ async def get_mcpservice_list( @router.post("", response_model=UpdateMCPServiceRsp) async def create_or_update_mcpservice( user_sub: Annotated[str, Depends(get_user)], # TODO: get_user直接获取所有用户信息 - data: UpdateMCPServiceRequest, + data: UpdateMCPServiceRequest ) -> JSONResponse: """新建或更新MCP服务""" await _check_user_admin(user_sub) @@ -130,6 +135,35 @@ async def create_or_update_mcpservice( ).model_dump(exclude_none=True, by_alias=True)) +@router.post("/{serviceId}/install") +async def install_mcp_service( + user_sub: Annotated[str, Depends(get_user)], + service_id: Annotated[str, Path(..., alias="serviceId", description="服务ID")], + install: Annotated[bool, Query(..., description="是否安装")] = True, +) -> JSONResponse: + try: + await MCPServiceManager.install_mcpservice(user_sub, service_id, install) + except Exception as e: + err = f"[MCPService] 安装mcp服务失败: {e!s}" if install else f"[MCPService] 卸载mcp服务失败: {e!s}" + logger.exception(err) + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content=ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message=err, + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=ResponseData( + code=status.HTTP_200_OK, + message="OK", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ) + + @router.get("/{serviceId}", response_model=GetMCPServiceDetailRsp) async def get_service_detail( user_sub: Annotated[str, Depends(get_user)], @@ -282,7 +316,7 @@ async def active_or_deactivate_mcp_service( """激活/取消激活mcp""" try: if data.active: - await MCPServiceManager.active_mcpservice(user_sub, service_id) + await MCPServiceManager.active_mcpservice(user_sub, service_id, data.mcp_env) else: await MCPServiceManager.deactive_mcpservice(user_sub, service_id) except Exception as e: diff --git a/apps/routers/parameter.py b/apps/routers/parameter.py new file mode 100644 index 0000000000000000000000000000000000000000..6edbe2e142cb6589be8947c28ae2eb4a7287baa1 --- /dev/null +++ b/apps/routers/parameter.py @@ -0,0 +1,77 @@ +from typing import Annotated + +from fastapi import APIRouter, Depends, Query, status +from fastapi.responses import JSONResponse + +from apps.dependency import get_user +from apps.dependency.user import verify_user +from apps.services.parameter import ParameterManager +from apps.schemas.response_data import ( + GetOperaRsp, + GetParamsRsp +) +from apps.services.application import AppManager +from apps.services.flow import FlowManager + +router = APIRouter( + prefix="/api/parameter", + tags=["parameter"], + dependencies=[ + Depends(verify_user), + ], +) + + +@router.get("", response_model=GetParamsRsp) +async def get_parameters( + user_sub: Annotated[str, Depends(get_user)], + app_id: Annotated[str, Query(alias="appId")], + flow_id: Annotated[str, Query(alias="flowId")], + step_id: Annotated[str, Query(alias="stepId")], +) -> JSONResponse: + """Get parameters for node choice.""" + if not await AppManager.validate_user_app_access(user_sub, app_id): + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content=GetParamsRsp( + code=status.HTTP_403_FORBIDDEN, + message="用户没有权限访问该流", + result=[], + ).model_dump(exclude_none=True, by_alias=True), + ) + flow = await FlowManager.get_flow_by_app_and_flow_id(app_id, flow_id) + if not flow: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content=GetParamsRsp( + code=status.HTTP_404_NOT_FOUND, + message="未找到该流", + result=[], + ).model_dump(exclude_none=True, by_alias=True), + ) + result = await ParameterManager.get_pre_params_by_flow_and_step_id(flow, step_id) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=GetParamsRsp( + code=status.HTTP_200_OK, + message="获取参数成功", + result=result + ).model_dump(exclude_none=True, by_alias=True), + ) + + +@router.get("/operate", response_model=GetOperaRsp) +async def get_operate_parameters( + user_sub: Annotated[str, Depends(get_user)], + param_type: Annotated[str, Query(alias="ParamType")], +) -> JSONResponse: + """Get parameters for node choice.""" + result = await ParameterManager.get_operate_and_bind_type(param_type) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=GetOperaRsp( + code=status.HTTP_200_OK, + message="获取操作成功", + result=result + ).model_dump(exclude_none=True, by_alias=True), + ) diff --git a/apps/routers/record.py b/apps/routers/record.py index 7384793b4b2abea5117462cb630c5b13c8f76071..ed994f6fe2e611562fa092db39265d1d57dfb9bc 100644 --- a/apps/routers/record.py +++ b/apps/routers/record.py @@ -65,7 +65,7 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_ tmp_record = RecordData( id=record.id, groupId=record_group.id, - taskId=record_group.task_id, + taskId=record.task_id, conversationId=conversation_id, content=record_data, metadata=record.metadata @@ -81,26 +81,26 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_ # 获得Record关联的文档 tmp_record.document = await DocumentManager.get_used_docs_by_record_group(user_sub, record_group.id) - # 获得Record关联的flow数据 - flow_list = await TaskManager.get_context_by_record_id(record_group.id, record.id) - if flow_list: - first_flow = FlowStepHistory.model_validate(flow_list[0]) + flow_step_list = await TaskManager.get_context_by_record_id(record_group.id, record.id) + if flow_step_list: tmp_record.flow = RecordFlow( - id=first_flow.flow_name, #TODO: 此处前端应该用name + id=record.flow.flow_id, # TODO: 此处前端应该用name recordId=record.id, - flowId=first_flow.id, - stepNum=len(flow_list), + flowId=record.flow.flow_id, + flowName=record.flow.flow_name, + flowStatus=record.flow.flow_staus, + stepNum=len(flow_step_list), steps=[], ) - for flow in flow_list: - flow_step = FlowStepHistory.model_validate(flow) + for flow_step in flow_step_list: tmp_record.flow.steps.append( RecordFlowStep( - stepId=flow_step.step_name, #TODO: 此处前端应该用name - stepStatus=flow_step.status, + stepId=flow_step.step_name, # TODO: 此处前端应该用name + stepStatus=flow_step.step_status, input=flow_step.input_data, output=flow_step.output_data, + exData=flow_step.ex_data ), ) diff --git a/apps/routers/user.py b/apps/routers/user.py index 54e12f444181b56e408f7face5c4ed37be76008b..537f1bf3adb95e8b48ef1b3d384376bae86ac802 100644 --- a/apps/routers/user.py +++ b/apps/routers/user.py @@ -3,10 +3,11 @@ from typing import Annotated -from fastapi import APIRouter, Depends, status +from fastapi import APIRouter, Body, Depends, status, Query from fastapi.responses import JSONResponse from apps.dependency import get_user +from apps.schemas.request_data import UserUpdateRequest from apps.schemas.response_data import UserGetMsp, UserGetRsp from apps.schemas.user import UserInfo from apps.services.user import UserManager @@ -17,12 +18,14 @@ router = APIRouter( ) -@router.get("") -async def chat( +@router.get("", response_model=UserGetRsp) +async def get_user_sub( user_sub: Annotated[str, Depends(get_user)], + page_size: Annotated[int, Query(description="每页用户数量")] = 20, + page_cnt: Annotated[int, Query(description="当前页码")] = 1, ) -> JSONResponse: """查询所有用户接口""" - user_list = await UserManager.get_all_user_sub() + user_list, total = await UserManager.get_all_user_sub(page_cnt=page_cnt, page_size=page_size, filter_user_subs=[user_sub]) user_info_list = [] for user in user_list: # user_info = await UserManager.get_userinfo_by_user_sub(user) 暂时不需要查询user_name @@ -39,6 +42,22 @@ async def chat( content=UserGetRsp( code=status.HTTP_200_OK, message="用户数据详细信息获取成功", - result=UserGetMsp(userInfoList=user_info_list), + result=UserGetMsp(userInfoList=user_info_list, total=total), ).model_dump(exclude_none=True, by_alias=True), ) + + +@router.post("") +async def update_user_info( + user_sub: Annotated[str, Depends(get_user)], + *, + data: Annotated[UserUpdateRequest, Body(..., description="用户更新信息")], +) -> JSONResponse: + """更新用户信息接口""" + # 更新用户信息 + + await UserManager.update_userinfo_by_user_sub(user_sub, data) + return JSONResponse( + status_code=status.HTTP_200_OK, + content={"code": status.HTTP_200_OK, "message": "用户信息更新成功"}, + ) diff --git a/apps/scheduler/call/__init__.py b/apps/scheduler/call/__init__.py index 2ee8b862885b0d88e519bf00a1678a49863df2bd..c5a6f0549c6891bcdf74052ddb8152fb550653b2 100644 --- a/apps/scheduler/call/__init__.py +++ b/apps/scheduler/call/__init__.py @@ -8,7 +8,7 @@ from apps.scheduler.call.mcp.mcp import MCP from apps.scheduler.call.rag.rag import RAG from apps.scheduler.call.sql.sql import SQL from apps.scheduler.call.suggest.suggest import Suggestion - +from apps.scheduler.call.choice.choice import Choice # 只包含需要在编排界面展示的工具 __all__ = [ "API", @@ -18,4 +18,5 @@ __all__ = [ "SQL", "Graph", "Suggestion", + "Choice" ] diff --git a/apps/scheduler/call/choice/choice.py b/apps/scheduler/call/choice/choice.py index a5edf21afb2eeb308dd909a40696500751c9a086..01ac7106fbdae673d12488a49d831f1d7e7fc13d 100644 --- a/apps/scheduler/call/choice/choice.py +++ b/apps/scheduler/call/choice/choice.py @@ -1,19 +1,150 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """使用大模型或使用程序做出判断""" -from enum import Enum - -from apps.scheduler.call.choice.schema import ChoiceInput, ChoiceOutput -from apps.scheduler.call.core import CoreCall +import ast +import copy +import logging +from collections.abc import AsyncGenerator +from typing import Any +from pydantic import Field -class Operator(str, Enum): - """Choice工具支持的运算符""" +from apps.scheduler.call.choice.condition_handler import ConditionHandler +from apps.scheduler.call.choice.schema import ( + Condition, + ChoiceBranch, + ChoiceInput, + ChoiceOutput, + Logic, +) +from apps.schemas.parameters import Type +from apps.scheduler.call.core import CoreCall +from apps.schemas.enum_var import CallOutputType +from apps.schemas.scheduler import ( + CallError, + CallInfo, + CallOutputChunk, + CallVars, +) - pass +logger = logging.getLogger(__name__) class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): """Choice工具""" - pass + to_user: bool = Field(default=False) + choices: list[ChoiceBranch] = Field(description="分支", default=[ChoiceBranch(), + ChoiceBranch(conditions=[Condition()], is_default=False)]) + + @classmethod + def info(cls) -> CallInfo: + """返回Call的名称和描述""" + return CallInfo(name="选择器", description="使用大模型或使用程序做出判断") + + async def _prepare_message(self, call_vars: CallVars) -> list[dict[str, Any]]: + """替换choices中的系统变量""" + valid_choices = [] + + for choice in self.choices: + try: + # 验证逻辑运算符 + if choice.logic not in [Logic.AND, Logic.OR]: + msg = f"无效的逻辑运算符: {choice.logic}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + + valid_conditions = [] + for i in range(len(choice.conditions)): + condition = copy.deepcopy(choice.conditions[i]) + # 处理左值 + if condition.left.step_id is not None: + condition.left.value = self._extract_history_variables( + condition.left.step_id+'/'+condition.left.value, call_vars.history) + # 检查历史变量是否成功提取 + if condition.left.value is None: + msg = f"步骤 {condition.left.step_id} 的历史变量不存在" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + if not ConditionHandler.check_value_type( + condition.left.value, condition.left.type): + msg = f"左值类型不匹配: {condition.left.value} 应为 {condition.left.type.value}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + else: + msg = "左侧变量缺少step_id" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + # 处理右值 + if condition.right.step_id is not None: + condition.right.value = self._extract_history_variables( + condition.right.step_id+'/'+condition.right.value, call_vars.history) + # 检查历史变量是否成功提取 + if condition.right.value is None: + msg = f"步骤 {condition.right.step_id} 的历史变量不存在" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + if not ConditionHandler.check_value_type( + condition.right.value, condition.right.type): + msg = f"右值类型不匹配: {condition.right.value} 应为 {condition.right.type.value}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + else: + # 如果右值没有step_id,尝试从call_vars中获取 + right_value_type = await ConditionHandler.get_value_type_from_operate( + condition.operate) + if right_value_type is None: + msg = f"不支持的运算符: {condition.operate}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + if condition.right.type != right_value_type: + msg = f"右值类型不匹配: {condition.right.value} 应为 {right_value_type.value}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + if right_value_type == Type.STRING: + condition.right.value = str(condition.right.value) + else: + condition.right.value = ast.literal_eval(condition.right.value) + if not ConditionHandler.check_value_type( + condition.right.value, condition.right.type): + msg = f"右值类型不匹配: {condition.right.value} 应为 {condition.right.type.value}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + valid_conditions.append(condition) + + # 如果所有条件都无效,抛出异常 + if not valid_conditions and not choice.is_default: + msg = "分支没有有效条件" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + + # 更新有效条件 + choice.conditions = valid_conditions + valid_choices.append(choice) + + except ValueError as e: + logger.warning("分支 %s 处理失败: %s,已跳过", choice.branch_id, str(e)) + continue + + return valid_choices + + async def _init(self, call_vars: CallVars) -> ChoiceInput: + """初始化Choice工具""" + return ChoiceInput( + choices=await self._prepare_message(call_vars), + ) + + async def _exec( + self, input_data: dict[str, Any] + ) -> AsyncGenerator[CallOutputChunk, None]: + """执行Choice工具""" + # 解析输入数据 + data = ChoiceInput(**input_data) + try: + branch_id = ConditionHandler.handler(data.choices) + yield CallOutputChunk( + type=CallOutputType.DATA, + content=ChoiceOutput(branch_id=branch_id).model_dump(exclude_none=True, by_alias=True), + ) + except Exception as e: + raise CallError(message=f"选择工具调用失败:{e!s}", data={}) from e diff --git a/apps/scheduler/call/choice/condition_handler.py b/apps/scheduler/call/choice/condition_handler.py new file mode 100644 index 0000000000000000000000000000000000000000..261809ad57fdd6d8c5267c59a7b87404dd7befa5 --- /dev/null +++ b/apps/scheduler/call/choice/condition_handler.py @@ -0,0 +1,309 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""处理条件分支的工具""" + + +import logging + +from pydantic import BaseModel + +from apps.schemas.parameters import ( + Type, + NumberOperate, + StringOperate, + ListOperate, + BoolOperate, + DictOperate, +) + +from apps.scheduler.call.choice.schema import ( + ChoiceBranch, + Condition, + Logic, + Value +) + +logger = logging.getLogger(__name__) + + +class ConditionHandler(BaseModel): + """条件分支处理器""" + @staticmethod + async def get_value_type_from_operate(operate: NumberOperate | StringOperate | ListOperate | + BoolOperate | DictOperate) -> Type: + """获取右值的类型""" + if isinstance(operate, NumberOperate): + return Type.NUMBER + if operate in [ + StringOperate.EQUAL, StringOperate.NOT_EQUAL, StringOperate.CONTAINS, StringOperate.NOT_CONTAINS, + StringOperate.STARTS_WITH, StringOperate.ENDS_WITH, StringOperate.REGEX_MATCH]: + return Type.STRING + if operate in [StringOperate.LENGTH_EQUAL, StringOperate.LENGTH_GREATER_THAN, + StringOperate.LENGTH_GREATER_THAN_OR_EQUAL, StringOperate.LENGTH_LESS_THAN, + StringOperate.LENGTH_LESS_THAN_OR_EQUAL]: + return Type.NUMBER + if operate in [ListOperate.EQUAL, ListOperate.NOT_EQUAL]: + return Type.LIST + if operate in [ListOperate.CONTAINS, ListOperate.NOT_CONTAINS]: + return Type.STRING + if operate in [ListOperate.LENGTH_EQUAL, ListOperate.LENGTH_GREATER_THAN, + ListOperate.LENGTH_GREATER_THAN_OR_EQUAL, ListOperate.LENGTH_LESS_THAN, + ListOperate.LENGTH_LESS_THAN_OR_EQUAL]: + return Type.NUMBER + if operate in [BoolOperate.EQUAL, BoolOperate.NOT_EQUAL]: + return Type.BOOL + if operate in [DictOperate.EQUAL, DictOperate.NOT_EQUAL]: + return Type.DICT + if operate in [DictOperate.CONTAINS_KEY, DictOperate.NOT_CONTAINS_KEY]: + return Type.STRING + return None + + @staticmethod + def check_value_type(value: Value, expected_type: Type) -> bool: + """检查值的类型是否符合预期""" + if expected_type == Type.STRING and isinstance(value.value, str): + return True + if expected_type == Type.NUMBER and isinstance(value.value, (int, float)): + return True + if expected_type == Type.LIST and isinstance(value.value, list): + return True + if expected_type == Type.DICT and isinstance(value.value, dict): + return True + if expected_type == Type.BOOL and isinstance(value.value, bool): + return True + return False + + @staticmethod + def handler(choices: list[ChoiceBranch]) -> str: + """处理条件""" + + for block_judgement in choices[::-1]: + results = [] + if block_judgement.is_default: + return block_judgement.branch_id + for condition in block_judgement.conditions: + result = ConditionHandler._judge_condition(condition) + if result is not None: + results.append(result) + if not results: + logger.warning(f"[Choice] 分支 {block_judgement.branch_id} 条件处理失败: 没有有效的条件") + continue + if block_judgement.logic == Logic.AND: + final_result = all(results) + elif block_judgement.logic == Logic.OR: + final_result = any(results) + + if final_result: + return block_judgement.branch_id + + return "" + + @staticmethod + def _judge_condition(condition: Condition) -> bool: + """ + 判断条件是否成立。 + + Args: + condition (Condition): 'left', 'operate', 'right', 'type' + + Returns: + bool + + """ + left = condition.left + operate = condition.operate + right = condition.right + value_type = condition.type + + result = None + if value_type == Type.STRING: + result = ConditionHandler._judge_string_condition(left, operate, right) + elif value_type == Type.NUMBER: + result = ConditionHandler._judge_number_condition(left, operate, right) + elif value_type == Type.BOOL: + result = ConditionHandler._judge_bool_condition(left, operate, right) + elif value_type == Type.LIST: + result = ConditionHandler._judge_list_condition(left, operate, right) + elif value_type == Type.DICT: + result = ConditionHandler._judge_dict_condition(left, operate, right) + else: + msg = f"不支持的数据类型: {value_type}" + logger.error(f"[Choice] 条件处理失败: {msg}") + return None + return result + + @staticmethod + def _judge_string_condition(left: Value, operate: StringOperate, right: Value) -> bool: + """ + 判断字符串类型的条件。 + + Args: + left (Value): 左值,包含 'value' 键。 + operate (Operate): 操作符 + right (Value): 右值,包含 'value' 键。 + + Returns: + bool + + """ + left_value = left.value + if not isinstance(left_value, str): + msg = f"左值必须是字符串类型 ({left_value})" + logger.warning(msg) + return None + right_value = right.value + if operate == StringOperate.EQUAL: + return left_value == right_value + elif operate == StringOperate.NOT_EQUAL: + return left_value != right_value + elif operate == StringOperate.CONTAINS: + return right_value in left_value + elif operate == StringOperate.NOT_CONTAINS: + return right_value not in left_value + elif operate == StringOperate.STARTS_WITH: + return left_value.startswith(right_value) + elif operate == StringOperate.ENDS_WITH: + return left_value.endswith(right_value) + elif operate == StringOperate.REGEX_MATCH: + import re + return bool(re.match(right_value, left_value)) + elif operate == StringOperate.LENGTH_EQUAL: + return len(left_value) == right_value + elif operate == StringOperate.LENGTH_GREATER_THAN: + return len(left_value) > right_value + elif operate == StringOperate.LENGTH_GREATER_THAN_OR_EQUAL: + return len(left_value) >= right_value + elif operate == StringOperate.LENGTH_LESS_THAN: + return len(left_value) < right_value + elif operate == StringOperate.LENGTH_LESS_THAN_OR_EQUAL: + return len(left_value) <= right_value + return False + + @staticmethod + def _judge_number_condition(left: Value, operate: NumberOperate, right: Value) -> bool: # noqa: PLR0911 + """ + 判断数字类型的条件。 + + Args: + left (Value): 左值,包含 'value' 键。 + operate (Operate): 操作符 + right (Value): 右值,包含 'value' 键。 + + Returns: + bool + + """ + left_value = left.value + if not isinstance(left_value, (int, float)): + msg = f"左值必须是数字类型 ({left_value})" + logger.warning(msg) + return None + right_value = right.value + if operate == NumberOperate.EQUAL: + return left_value == right_value + elif operate == NumberOperate.NOT_EQUAL: + return left_value != right_value + elif operate == NumberOperate.GREATER_THAN: + return left_value > right_value + elif operate == NumberOperate.LESS_THAN: # noqa: PLR2004 + return left_value < right_value + elif operate == NumberOperate.GREATER_THAN_OR_EQUAL: + return left_value >= right_value + elif operate == NumberOperate.LESS_THAN_OR_EQUAL: + return left_value <= right_value + return False + + @staticmethod + def _judge_bool_condition(left: Value, operate: BoolOperate, right: Value) -> bool: + """ + 判断布尔类型的条件。 + + Args: + left (Value): 左值,包含 'value' 键。 + operate (Operate): 操作符 + right (Value): 右值,包含 'value' 键。 + + Returns: + bool + + """ + left_value = left.value + if not isinstance(left_value, bool): + msg = "左值必须是布尔类型" + logger.warning(msg) + return None + right_value = right.value + if operate == BoolOperate.EQUAL: + return left_value == right_value + elif operate == BoolOperate.NOT_EQUAL: + return left_value != right_value + return False + + @staticmethod + def _judge_list_condition(left: Value, operate: ListOperate, right: Value): + """ + 判断列表类型的条件。 + + Args: + left (Value): 左值,包含 'value' 键。 + operate (Operate): 操作符 + right (Value): 右值,包含 'value' 键。 + + Returns: + bool + + """ + left_value = left.value + if not isinstance(left_value, list): + msg = f"左值必须是列表类型 ({left_value})" + logger.warning(msg) + return None + right_value = right.value + if operate == ListOperate.EQUAL: + return left_value == right_value + elif operate == ListOperate.NOT_EQUAL: + return left_value != right_value + elif operate == ListOperate.CONTAINS: + return right_value in left_value + elif operate == ListOperate.NOT_CONTAINS: + return right_value not in left_value + elif operate == ListOperate.LENGTH_EQUAL: + return len(left_value) == right_value + elif operate == ListOperate.LENGTH_GREATER_THAN: + return len(left_value) > right_value + elif operate == ListOperate.LENGTH_GREATER_THAN_OR_EQUAL: + return len(left_value) >= right_value + elif operate == ListOperate.LENGTH_LESS_THAN: + return len(left_value) < right_value + elif operate == ListOperate.LENGTH_LESS_THAN_OR_EQUAL: + return len(left_value) <= right_value + return False + + @staticmethod + def _judge_dict_condition(left: Value, operate: DictOperate, right: Value): + """ + 判断字典类型的条件。 + + Args: + left (Value): 左值,包含 'value' 键。 + operate (Operate): 操作符 + right (Value): 右值,包含 'value' 键。 + + Returns: + bool + + """ + left_value = left.value + if not isinstance(left_value, dict): + msg = f"左值必须是字典类型 ({left_value})" + logger.warning(msg) + return None + right_value = right.value + if operate == DictOperate.EQUAL: + return left_value == right_value + elif operate == DictOperate.NOT_EQUAL: + return left_value != right_value + elif operate == DictOperate.CONTAINS_KEY: + return right_value in left_value + elif operate == DictOperate.NOT_CONTAINS_KEY: + return right_value not in left_value + return False diff --git a/apps/scheduler/call/choice/schema.py b/apps/scheduler/call/choice/schema.py index 60b62d09fd66adbf32295f44ec86398a537f38d5..d97a0c8d7e9efcef255ced76c8d1d6b39dfed241 100644 --- a/apps/scheduler/call/choice/schema.py +++ b/apps/scheduler/call/choice/schema.py @@ -1,12 +1,64 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """Choice Call的输入和输出""" +import uuid +from enum import Enum + +from pydantic import BaseModel, Field + +from apps.schemas.parameters import ( + Type, + NumberOperate, + StringOperate, + ListOperate, + BoolOperate, + DictOperate, +) from apps.scheduler.call.core import DataBase +class Logic(str, Enum): + """Choice 工具支持的逻辑运算符""" + + AND = "and" + OR = "or" + + +class Value(DataBase): + """值的结构""" + + step_id: str | None = Field(description="步骤id", default=None) + type: Type | None = Field(description="值的类型", default=None) + name: str | None = Field(description="值的名称", default=None) + value: str | float | int | bool | list | dict | None = Field(description="值", default=None) + + +class Condition(DataBase): + """单个条件""" + + left: Value = Field(description="左值", default=Value()) + right: Value = Field(description="右值", default=Value()) + operate: NumberOperate | StringOperate | ListOperate | BoolOperate | DictOperate | None = Field( + description="运算符", default=None) + id: str = Field(description="条件ID", default_factory=lambda: str(uuid.uuid4())) + + +class ChoiceBranch(DataBase): + """子分支""" + + branch_id: str = Field(description="分支ID", default_factory=lambda: str(uuid.uuid4())) + logic: Logic = Field(description="逻辑运算符", default=Logic.AND) + conditions: list[Condition] = Field(description="条件列表", default=[]) + is_default: bool = Field(description="是否为默认分支", default=True) + + class ChoiceInput(DataBase): """Choice Call的输入""" + choices: list[ChoiceBranch] = Field(description="分支", default=[]) + class ChoiceOutput(DataBase): """Choice Call的输出""" + + branch_id: str = Field(description="分支ID", default="") diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index 2b1cbba83b9345e8df84e8f23f893137cb46532b..5bed80309fbdb993aacc378ab0319ed291f1ffb3 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -76,21 +76,18 @@ class CoreCall(BaseModel): extra="allow", ) - def __init_subclass__(cls, input_model: type[DataBase], output_model: type[DataBase], **kwargs: Any) -> None: """初始化子类""" super().__init_subclass__(**kwargs) cls.input_model = input_model cls.output_model = output_model - @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" err = "[CoreCall] 必须手动实现info方法" raise NotImplementedError(err) - @staticmethod def _assemble_call_vars(executor: "StepExecutor") -> CallVars: """组装CallVars""" @@ -120,7 +117,6 @@ class CoreCall(BaseModel): summary=executor.task.runtime.summary, ) - @staticmethod def _extract_history_variables(path: str, history: dict[str, FlowStepHistory]) -> Any: """ @@ -131,16 +127,14 @@ class CoreCall(BaseModel): :return: 变量 """ split_path = path.split("/") + if len(split_path) < 1: + err = f"[CoreCall] 路径格式错误: {path}" + logger.error(err) + return None if split_path[0] not in history: err = f"[CoreCall] 步骤{split_path[0]}不存在" logger.error(err) - raise CallError( - message=err, - data={ - "step_id": split_path[0], - }, - ) - + return None data = history[split_path[0]].output_data for key in split_path[1:]: if key not in data: @@ -156,7 +150,6 @@ class CoreCall(BaseModel): data = data[key] return data - @classmethod async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self: """实例化Call类""" @@ -170,36 +163,30 @@ class CoreCall(BaseModel): await obj._set_input(executor) return obj - async def _set_input(self, executor: "StepExecutor") -> None: """获取Call的输入""" self._sys_vars = self._assemble_call_vars(executor) input_data = await self._init(self._sys_vars) self.input = input_data.model_dump(by_alias=True, exclude_none=True) - async def _init(self, call_vars: CallVars) -> DataBase: """初始化Call类,并返回Call的输入""" err = "[CoreCall] 初始化方法必须手动实现" raise NotImplementedError(err) - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """Call类实例的流式输出方法""" yield CallOutputChunk(type=CallOutputType.TEXT, content="") - async def _after_exec(self, input_data: dict[str, Any]) -> None: """Call类实例的执行后方法""" - async def exec(self, executor: "StepExecutor", input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """Call类实例的执行方法""" async for chunk in self._exec(input_data): yield chunk await self._after_exec(input_data) - async def _llm(self, messages: list[dict[str, Any]]) -> str: """Call可直接使用的LLM非流式调用""" result = "" @@ -210,7 +197,6 @@ class CoreCall(BaseModel): self.output_tokens = llm.output_tokens return result - async def _json(self, messages: list[dict[str, Any]], schema: type[BaseModel]) -> BaseModel: """Call可直接使用的JSON生成""" json = FunctionLLM() diff --git a/apps/scheduler/call/facts/facts.py b/apps/scheduler/call/facts/facts.py index f8aebcd748d92de4109553d0f68fecac43363f58..2b9df0c6f8f1002e76bda24eb099f47a6084b2af 100644 --- a/apps/scheduler/call/facts/facts.py +++ b/apps/scheduler/call/facts/facts.py @@ -30,13 +30,11 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): answer: str = Field(description="用户输入") - @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" return CallInfo(name="提取事实", description="从对话上下文和文档片段中提取事实。") - @classmethod async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self: """初始化工具""" @@ -51,7 +49,6 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): await obj._set_input(executor) return obj - async def _init(self, call_vars: CallVars) -> FactsInput: """初始化工具""" # 组装必要变量 @@ -65,7 +62,6 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): message=message, ) - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """执行工具""" data = FactsInput(**input_data) @@ -83,7 +79,7 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): facts_obj: FactsGen = await self._json([ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": facts_prompt}, - ], FactsGen) # type: ignore[arg-type] + ], FactsGen) # type: ignore[arg-type] # 更新用户画像 domain_tpl = env.from_string(DOMAIN_PROMPT) @@ -91,7 +87,7 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): domain_list: DomainGen = await self._json([ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": domain_prompt}, - ], DomainGen) # type: ignore[arg-type] + ], DomainGen) # type: ignore[arg-type] for domain in domain_list.keywords: await UserDomainManager.update_user_domain_by_user_sub_and_domain_name(data.user_sub, domain) @@ -104,7 +100,6 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): ).model_dump(by_alias=True, exclude_none=True), ) - async def exec(self, executor: "StepExecutor", input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """执行工具""" async for chunk in self._exec(input_data): diff --git a/apps/scheduler/call/mcp/mcp.py b/apps/scheduler/call/mcp/mcp.py index 661e9ada74da76e6d78a0f6cba42cef08c562335..bd0257b49fb6ea24fbbc429cb11f7b8e6be364b9 100644 --- a/apps/scheduler/call/mcp/mcp.py +++ b/apps/scheduler/call/mcp/mcp.py @@ -31,11 +31,10 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): """MCP工具""" mcp_list: list[str] = Field(description="MCP Server ID列表", max_length=5, min_length=1) - max_steps: int = Field(description="最大步骤数", default=6) + max_steps: int = Field(description="最大步骤数", default=20) text_output: bool = Field(description="是否将结果以文本形式返回", default=True) to_user: bool = Field(description="是否将结果返回给用户", default=True) - @classmethod def info(cls) -> CallInfo: """ @@ -46,7 +45,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): """ return CallInfo(name="MCP", description="调用MCP Server,执行工具") - async def _init(self, call_vars: CallVars) -> MCPInput: """初始化MCP""" # 获取MCP交互类 @@ -63,7 +61,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): return MCPInput(avaliable_tools=avaliable_tools, max_steps=self.max_steps) - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """执行MCP""" # 生成计划 @@ -80,7 +77,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): async for chunk in self._generate_answer(): yield chunk - async def _generate_plan(self) -> AsyncGenerator[CallOutputChunk, None]: """生成执行计划""" # 开始提示 @@ -103,7 +99,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): data=self._plan.model_dump(), ) - async def _execute_plan_item(self, plan_item: MCPPlanItem) -> AsyncGenerator[CallOutputChunk, None]: """执行单个计划项""" # 判断是否为Final @@ -141,7 +136,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): }, ) - async def _generate_answer(self) -> AsyncGenerator[CallOutputChunk, None]: """生成总结""" # 提示开始总结 @@ -163,7 +157,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): ).model_dump(), ) - def _create_output( self, text: str, diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py index 4f8e1010cc0bd88f22e050778bce236d7d3515e0..d24e1661e9dd7facc4bf590b24409052b6ec9104 100644 --- a/apps/scheduler/call/slot/slot.py +++ b/apps/scheduler/call/slot/slot.py @@ -32,13 +32,11 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput): facts: list[str] = Field(description="事实信息", default=[]) step_num: int = Field(description="历史步骤数", default=1) - @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" return CallInfo(name="参数自动填充", description="根据步骤历史,自动填充参数") - async def _llm_slot_fill(self, remaining_schema: dict[str, Any]) -> tuple[str, dict[str, Any]]: """使用大模型填充参数;若大模型解析度足够,则直接返回结果""" env = SandboxedEnvironment( @@ -106,7 +104,6 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput): await obj._set_input(executor) return obj - async def _init(self, call_vars: CallVars) -> SlotInput: """初始化""" self._flow_history = [] @@ -126,7 +123,6 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput): remaining_schema=remaining_schema, ) - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """执行参数填充""" data = SlotInput(**input_data) diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index f6814dd364636d304382dc258f87fdd9f06eb9a8..3997d89ad039a056ec38ce1cbb5cebfa8fb39671 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -1,40 +1,510 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """MCP Agent执行器""" +import anyio import logging - +import uuid from pydantic import Field - +from typing import Any +from mcp.types import TextContent +from apps.llm.patterns.rewrite import QuestionRewrite +from apps.llm.reasoning import ReasoningLLM from apps.scheduler.executor.base import BaseExecutor -from apps.scheduler.mcp_agent.agent.mcp import MCPAgent - +from apps.schemas.enum_var import EventType, SpecialCallType, FlowStatus, StepStatus +from apps.scheduler.mcp_agent.host import MCPHost +from apps.scheduler.mcp_agent.plan import MCPPlanner +from apps.scheduler.mcp_agent.select import FINAL_TOOL_ID, MCPSelector +from apps.scheduler.pool.mcp.client import MCPClient +from apps.schemas.mcp import ( + GoalEvaluationResult, + RestartStepIndex, + ToolRisk, + ErrorType, + ToolExcutionErrorType, + MCPPlan, + MCPCollection, + MCPTool, + Step +) +from apps.scheduler.pool.mcp.pool import MCPPool +from apps.schemas.task import ExecutorState, FlowStepHistory, StepQueueItem +from apps.schemas.message import param +from apps.services.task import TaskManager +from apps.services.appcenter import AppCenterManager +from apps.services.mcp_service import MCPServiceManager +from apps.services.user import UserManager logger = logging.getLogger(__name__) class MCPAgentExecutor(BaseExecutor): """MCP Agent执行器""" - question: str = Field(description="用户输入") - max_steps: int = Field(default=10, description="最大步数") + max_steps: int = Field(default=20, description="最大步数") servers_id: list[str] = Field(description="MCP server id") agent_id: str = Field(default="", description="Agent ID") agent_description: str = Field(default="", description="Agent描述") + mcp_list: list[MCPCollection] = Field(description="MCP服务器列表", default=[]) + mcp_pool: MCPPool = Field(description="MCP池", default=MCPPool()) + tools: dict[str, MCPTool] = Field( + description="MCP工具列表,key为tool_id", default={} + ) + tool_list: list[MCPTool] = Field( + description="MCP工具列表,包含所有MCP工具", default=[] + ) + params: param | bool | None = Field( + default=None, description="流执行过程中的参数补充", alias="params" + ) + resoning_llm: ReasoningLLM = Field( + default=ReasoningLLM(), + description="推理大模型", + ) - async def run(self) -> None: - """运行MCP Agent""" - agent = await MCPAgent.create( - servers_id=self.servers_id, - max_steps=self.max_steps, - task=self.task, - msg_queue=self.msg_queue, - question=self.question, - agent_id=self.agent_id, - description=self.agent_description, + async def update_tokens(self) -> None: + """更新令牌数""" + self.task.tokens.input_tokens = self.resoning_llm.input_tokens + self.task.tokens.output_tokens = self.resoning_llm.output_tokens + await TaskManager.save_task(self.task.id, self.task) + + async def load_state(self) -> None: + """从数据库中加载FlowExecutor的状态""" + logger.info("[FlowExecutor] 加载Executor状态") + # 尝试恢复State + if self.task.state and self.task.state.flow_status != FlowStatus.INIT: + self.task.context = await TaskManager.get_context_by_task_id(self.task.id) + + async def load_mcp(self) -> None: + """加载MCP服务器列表""" + logger.info("[MCPAgentExecutor] 加载MCP服务器列表") + # 获取MCP服务器列表 + app = await AppCenterManager.fetch_app_data_by_id(self.agent_id) + mcp_ids = app.mcp_service + for mcp_id in mcp_ids: + mcp_service = await MCPServiceManager.get_mcp_service(mcp_id) + if self.task.ids.user_sub not in mcp_service.activated: + logger.warning( + "[MCPAgentExecutor] 用户 %s 未启用MCP %s", + self.task.ids.user_sub, + mcp_id, + ) + continue + + self.mcp_list.append(mcp_service) + await self.mcp_pool._init_mcp(mcp_id, self.task.ids.user_sub) + for tool in mcp_service.tools: + self.tools[tool.id] = tool + self.tool_list.extend(mcp_service.tools) + self.tools[FINAL_TOOL_ID] = MCPTool( + id=FINAL_TOOL_ID, + name="Final Tool", + description="结束流程的工具", + mcp_id="", + input_schema={} + ) + self.tool_list.append(MCPTool(id=FINAL_TOOL_ID, name="Final Tool", + description="结束流程的工具", mcp_id="", input_schema={})) + + async def get_tool_input_param(self, is_first: bool) -> None: + if is_first: + # 获取第一个输入参数 + mcp_tool = self.tools[self.task.state.tool_id] + self.task.state.current_input = await MCPHost._get_first_input_params(mcp_tool, self.task.runtime.question, self.task.state.step_description, self.task) + else: + # 获取后续输入参数 + if isinstance(self.params, param): + params = self.params.content + params_description = self.params.description + else: + params = {} + params_description = "" + mcp_tool = self.tools[self.task.state.tool_id] + self.task.state.current_input = await MCPHost._fill_params(mcp_tool, self.task.runtime.question, self.task.state.step_description, self.task.state.current_input, self.task.state.error_message, params, params_description) + + async def confirm_before_step(self) -> None: + # 发送确认消息 + mcp_tool = self.tools[self.task.state.tool_id] + confirm_message = await MCPPlanner.get_tool_risk(mcp_tool, self.task.state.current_input, "", self.resoning_llm) + await self.update_tokens() + await self.push_message(EventType.STEP_WAITING_FOR_START, confirm_message.model_dump( + exclude_none=True, by_alias=True)) + await self.push_message(EventType.FLOW_STOP, {}) + self.task.state.flow_status = FlowStatus.WAITING + self.task.state.step_status = StepStatus.WAITING + self.task.context.append( + FlowStepHistory( + task_id=self.task.id, + step_id=self.task.state.step_id, + step_name=self.task.state.step_name, + step_description=self.task.state.step_description, + step_status=self.task.state.step_status, + flow_id=self.task.state.flow_id, + flow_name=self.task.state.flow_name, + flow_status=self.task.state.flow_status, + input_data={}, + output_data={}, + ex_data=confirm_message.model_dump(exclude_none=True, by_alias=True), + ) ) + async def run_step(self): + """执行步骤""" + self.task.state.flow_status = FlowStatus.RUNNING + self.task.state.step_status = StepStatus.RUNNING + mcp_tool = self.tools[self.task.state.tool_id] + mcp_client = (await self.mcp_pool.get(mcp_tool.mcp_id, self.task.ids.user_sub)) + try: + output_params = await mcp_client.call_tool(mcp_tool.name, self.task.state.current_input) + except anyio.ClosedResourceError as e: + import traceback + logger.error("[MCPAgentExecutor] MCP客户端连接已关闭: %s, 错误: %s", mcp_tool.mcp_id, traceback.format_exc()) + await self.mcp_pool.stop(mcp_tool.mcp_id, self.task.ids.user_sub) + await self.mcp_pool._init_mcp(mcp_tool.mcp_id, self.task.ids.user_sub) + logger.error("[MCPAgentExecutor] MCP客户端连接已关闭: %s, 错误: %s", mcp_tool.mcp_id, str(e)) + self.task.state.step_status = StepStatus.ERROR + return + except Exception as e: + import traceback + logger.exception("[MCPAgentExecutor] 执行步骤 %s 时发生错误: %s", mcp_tool.name, traceback.format_exc()) + self.task.state.step_status = StepStatus.ERROR + self.task.state.error_message = str(e) + return + if output_params.isError: + err = "" + for output in output_params.content: + if isinstance(output, TextContent): + err += output.text + self.task.state.step_status = StepStatus.ERROR + self.task.state.error_message = err + return + message = "" + for output in output_params.content: + if isinstance(output, TextContent): + message += output.text + output_params = { + "message": message, + } + + await self.update_tokens() + await self.push_message( + EventType.STEP_INPUT, + self.task.state.current_input + ) + await self.push_message( + EventType.STEP_OUTPUT, + output_params + ) + self.task.context.append( + FlowStepHistory( + task_id=self.task.id, + step_id=self.task.state.step_id, + step_name=self.task.state.step_name, + step_description=self.task.state.step_description, + step_status=StepStatus.SUCCESS, + flow_id=self.task.state.flow_id, + flow_name=self.task.state.flow_name, + flow_status=self.task.state.flow_status, + input_data=self.task.state.current_input, + output_data=output_params, + ) + ) + self.task.state.step_status = StepStatus.SUCCESS + + async def generate_params_with_null(self) -> None: + """生成参数补充""" + mcp_tool = self.tools[self.task.state.tool_id] + params_with_null = await MCPPlanner.get_missing_param( + mcp_tool, + self.task.state.current_input, + self.task.state.error_message, + self.resoning_llm + ) + await self.update_tokens() + error_message = await MCPPlanner.change_err_message_to_description( + error_message=self.task.state.error_message, + tool=mcp_tool, + input_params=self.task.state.current_input, + reasoning_llm=self.resoning_llm + ) + await self.push_message( + EventType.STEP_WAITING_FOR_PARAM, + data={ + "message": error_message, + "params": params_with_null + } + ) + await self.push_message( + EventType.FLOW_STOP, + data={} + ) + self.task.state.flow_status = FlowStatus.WAITING + self.task.state.step_status = StepStatus.PARAM + self.task.context.append( + FlowStepHistory( + task_id=self.task.id, + step_id=self.task.state.step_id, + step_name=self.task.state.step_name, + step_description=self.task.state.step_description, + step_status=self.task.state.step_status, + flow_id=self.task.state.flow_id, + flow_name=self.task.state.flow_name, + flow_status=self.task.state.flow_status, + input_data={}, + output_data={}, + ex_data={ + "message": error_message, + "params": params_with_null + } + ) + ) + + async def get_next_step(self) -> None: + if self.task.state.step_cnt < self.max_steps: + self.task.state.step_cnt += 1 + history = await MCPHost.assemble_memory(self.task) + max_retry = 3 + step = None + for i in range(max_retry): + step = await MCPPlanner.create_next_step(self.task.runtime.question, history, self.tool_list) + if step.tool_id in self.tools.keys(): + break + if step is None or step.tool_id not in self.tools.keys(): + step = Step( + tool_id=FINAL_TOOL_ID, + description=FINAL_TOOL_ID + ) + tool_id = step.tool_id + if tool_id == FINAL_TOOL_ID: + step_name = FINAL_TOOL_ID + else: + step_name = self.tools[tool_id].name + step_description = step.description + self.task.state.step_id = str(uuid.uuid4()) + self.task.state.tool_id = tool_id + self.task.state.step_name = step_name + self.task.state.step_description = step_description + self.task.state.step_status = StepStatus.INIT + self.task.state.current_input = {} + else: + # 没有下一步了,结束流程 + self.task.state.tool_id = FINAL_TOOL_ID + return + + async def error_handle_after_step(self) -> None: + """步骤执行失败后的错误处理""" + self.task.state.step_status = StepStatus.ERROR + self.task.state.flow_status = FlowStatus.ERROR + await self.push_message( + EventType.FLOW_FAILED, + data={} + ) + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + del self.task.context[-1] + self.task.context.append( + FlowStepHistory( + task_id=self.task.id, + step_id=self.task.state.step_id, + step_name=self.task.state.step_name, + step_description=self.task.state.step_description, + step_status=self.task.state.step_status, + flow_id=self.task.state.flow_id, + flow_name=self.task.state.flow_name, + flow_status=self.task.state.flow_status, + input_data={}, + output_data={}, + ) + ) + + async def work(self) -> None: + """执行当前步骤""" + if self.task.state.step_status == StepStatus.INIT: + await self.push_message( + EventType.STEP_INIT, + data={} + ) + await self.get_tool_input_param(is_first=True) + user_info = await UserManager.get_userinfo_by_user_sub(self.task.ids.user_sub) + if not user_info.auto_execute: + # 等待用户确认 + await self.confirm_before_step() + return + self.task.state.step_status = StepStatus.RUNNING + elif self.task.state.step_status in [StepStatus.PARAM, StepStatus.WAITING, StepStatus.RUNNING]: + if self.task.state.step_status == StepStatus.PARAM: + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + del self.task.context[-1] + elif self.task.state.step_status == StepStatus.WAITING: + if self.params: + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + del self.task.context[-1] + else: + self.task.state.flow_status = FlowStatus.CANCELLED + self.task.state.step_status = StepStatus.CANCELLED + await self.push_message( + EventType.STEP_CANCEL, + data={} + ) + await self.push_message( + EventType.FLOW_CANCEL, + data={} + ) + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + self.task.context[-1].step_status = StepStatus.CANCELLED + if self.task.state.step_status == StepStatus.PARAM: + await self.get_tool_input_param(is_first=False) + max_retry = 5 + for i in range(max_retry): + if i != 0: + await self.get_tool_input_param(is_first=True) + await self.run_step() + if self.task.state.step_status == StepStatus.SUCCESS: + break + elif self.task.state.step_status == StepStatus.ERROR: + # 错误处理 + if self.task.state.retry_times >= 3: + await self.error_handle_after_step() + else: + user_info = await UserManager.get_userinfo_by_user_sub(self.task.ids.user_sub) + if user_info.auto_execute: + self.push_message( + EventType.STEP_ERROR, + data={ + "message": self.task.state.error_message, + } + ) + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + self.task.context[-1].step_status = StepStatus.ERROR + self.task.context[-1].output_data = { + "message": self.task.state.error_message, + } + await self.get_next_step() + else: + mcp_tool = self.tools[self.task.state.tool_id] + is_param_error = await MCPPlanner.is_param_error( + self.task.runtime.question, + await MCPHost.assemble_memory(self.task), + self.task.state.error_message, + mcp_tool, + self.task.state.step_description, + self.task.state.current_input, + ) + if is_param_error.is_param_error: + # 如果是参数错误,生成参数补充 + await self.generate_params_with_null() + else: + self.push_message( + EventType.STEP_ERROR, + data={ + "message": self.task.state.error_message, + } + ) + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + self.task.context[-1].step_status = StepStatus.ERROR + self.task.context[-1].output_data = { + "message": self.task.state.error_message, + } + await self.get_next_step() + elif self.task.state.step_status == StepStatus.SUCCESS: + await self.get_next_step() + + async def summarize(self) -> None: + async for chunk in MCPPlanner.generate_answer( + self.task.runtime.question, + (await MCPHost.assemble_memory(self.task)), + self.resoning_llm + ): + await self.push_message( + EventType.TEXT_ADD, + data=chunk + ) + self.task.runtime.answer += chunk + + async def run(self) -> None: + """执行MCP Agent的主逻辑""" + # 初始化MCP服务 + await self.load_state() + await self.load_mcp() + if self.task.state.flow_status == FlowStatus.INIT: + # 初始化状态 + try: + self.task.state.flow_id = str(uuid.uuid4()) + self.task.state.flow_name = await MCPPlanner.get_flow_name(self.task.runtime.question, self.resoning_llm) + await TaskManager.save_task(self.task.id, self.task) + await self.get_next_step() + except Exception as e: + import traceback + logger.error("[MCPAgentExecutor] 初始化失败: %s", traceback.format_exc()) + logger.error("[MCPAgentExecutor] 初始化失败: %s", str(e)) + self.task.state.flow_status = FlowStatus.ERROR + self.task.state.error_message = str(e) + await self.push_message( + EventType.FLOW_FAILED, + data={} + ) + return + self.task.state.flow_status = FlowStatus.RUNNING + await self.push_message( + EventType.FLOW_START, + data={} + ) + if self.task.state.tool_id == FINAL_TOOL_ID: + # 如果已经是最后一步,直接结束 + self.task.state.flow_status = FlowStatus.SUCCESS + await self.push_message( + EventType.FLOW_SUCCESS, + data={} + ) + await self.summarize() + return try: - answer = await agent.run(self.question) - self.task = agent.task - self.task.runtime.answer = answer + while self.task.state.flow_status == FlowStatus.RUNNING: + if self.task.state.tool_id == FINAL_TOOL_ID: + break + await self.work() + await TaskManager.save_task(self.task.id, self.task) + tool_id = self.task.state.tool_id + if tool_id == FINAL_TOOL_ID: + # 如果已经是最后一步,直接结束 + self.task.state.flow_status = FlowStatus.SUCCESS + self.task.state.step_status = StepStatus.SUCCESS + await self.push_message( + EventType.FLOW_SUCCESS, + data={} + ) + await self.summarize() except Exception as e: - logger.error(f"Error: {str(e)}") + import traceback + logger.error("[MCPAgentExecutor] 执行过程中发生错误: %s", traceback.format_exc()) + logger.error("[MCPAgentExecutor] 执行过程中发生错误: %s", str(e)) + self.task.state.flow_status = FlowStatus.ERROR + self.task.state.error_message = str(e) + self.task.state.step_status = StepStatus.ERROR + await self.push_message( + EventType.STEP_ERROR, + data={} + ) + await self.push_message( + EventType.FLOW_FAILED, + data={} + ) + if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: + del self.task.context[-1] + self.task.context.append( + FlowStepHistory( + task_id=self.task.id, + step_id=self.task.state.step_id, + step_name=self.task.state.step_name, + step_description=self.task.state.step_description, + step_status=self.task.state.step_status, + flow_id=self.task.state.flow_id, + flow_name=self.task.state.flow_name, + flow_status=self.task.state.flow_status, + input_data={}, + output_data={}, + ) + ) + finally: + for mcp_service in self.mcp_list: + try: + await self.mcp_pool.stop(mcp_service.id, self.task.ids.user_sub) + except Exception as e: + import traceback + logger.error("[MCPAgentExecutor] 停止MCP客户端时发生错误: %s", traceback.format_exc()) diff --git a/apps/scheduler/executor/base.py b/apps/scheduler/executor/base.py index cf2f4e6838f5d8d5b6578db2750543ddfd22ee75..877341f08e776c168b618ff5cbea8e17c9cfc956 100644 --- a/apps/scheduler/executor/base.py +++ b/apps/scheduler/executor/base.py @@ -44,15 +44,8 @@ class BaseExecutor(BaseModel, ABC): :param event_type: 事件类型 :param data: 消息数据,如果是FLOW_START事件且data为None,则自动构建FlowStartContent """ - if event_type == EventType.FLOW_START.value and isinstance(data, dict): - data = FlowStartContent( - question=self.question, - params=self.task.runtime.filled, - ).model_dump(exclude_none=True, by_alias=True) - elif event_type == EventType.FLOW_STOP.value: - data = {} - elif event_type == EventType.TEXT_ADD.value and isinstance(data, str): - data=TextAddContent(text=data).model_dump(exclude_none=True, by_alias=True) + if event_type == EventType.TEXT_ADD.value and isinstance(data, str): + data = TextAddContent(text=data).model_dump(exclude_none=True, by_alias=True) if data is None: data = {} @@ -62,7 +55,7 @@ class BaseExecutor(BaseModel, ABC): await self.msg_queue.push_output( self.task, event_type=event_type, - data=data, # type: ignore[arg-type] + data=data, # type: ignore[arg-type] ) @abstractmethod diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index a70d0d7073c50ec2290c3fd4c797bbb3efcd4501..ebd4da8ecb6e55c5220a9cd4e0f63c72842c9828 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -11,7 +11,7 @@ from pydantic import Field from apps.scheduler.call.llm.prompt import LLM_ERROR_PROMPT from apps.scheduler.executor.base import BaseExecutor from apps.scheduler.executor.step import StepExecutor -from apps.schemas.enum_var import EventType, SpecialCallType, StepStatus +from apps.schemas.enum_var import EventType, SpecialCallType, FlowStatus, StepStatus from apps.schemas.flow import Flow, Step from apps.schemas.request_data import RequestDataApp from apps.schemas.task import ExecutorState, StepQueueItem @@ -46,21 +46,25 @@ class FlowExecutor(BaseExecutor): flow_id: str = Field(description="Flow ID") question: str = Field(description="用户输入") post_body_app: RequestDataApp = Field(description="请求体中的app信息") - + current_step: StepQueueItem | None = Field( + description="当前执行的步骤", + default=None + ) async def load_state(self) -> None: """从数据库中加载FlowExecutor的状态""" logger.info("[FlowExecutor] 加载Executor状态") # 尝试恢复State - if self.task.state: + if self.task.state and self.task.state.flow_status != FlowStatus.INIT: self.task.context = await TaskManager.get_context_by_task_id(self.task.id) else: # 创建ExecutorState self.task.state = ExecutorState( flow_id=str(self.flow_id), flow_name=self.flow.name, + flow_status=FlowStatus.RUNNING, description=str(self.flow.description), - status=StepStatus.RUNNING, + step_status=StepStatus.RUNNING, app_id=str(self.post_body_app.app_id), step_id="start", step_name="开始", @@ -70,14 +74,13 @@ class FlowExecutor(BaseExecutor): self._reached_end: bool = False self.step_queue: deque[StepQueueItem] = deque() - - async def _invoke_runner(self, queue_item: StepQueueItem) -> None: + async def _invoke_runner(self) -> None: """单一Step执行""" # 创建步骤Runner step_runner = StepExecutor( msg_queue=self.msg_queue, task=self.task, - step=queue_item, + step=self.current_step, background=self.background, question=self.question, ) @@ -85,23 +88,21 @@ class FlowExecutor(BaseExecutor): # 初始化步骤 await step_runner.init() # 运行Step - await step_runner.run() + await step_runner.run() # 更新Task(已存过库) self.task = step_runner.task - async def _step_process(self) -> None: """执行当前queue里面的所有步骤(在用户看来是单一Step)""" while True: try: - queue_item = self.step_queue.pop() + self.current_step = self.step_queue.pop() except IndexError: break # 执行Step - await self._invoke_runner(queue_item) - + await self._invoke_runner() async def _find_next_id(self, step_id: str) -> list[str]: """查找下一个节点""" @@ -111,14 +112,22 @@ class FlowExecutor(BaseExecutor): next_ids += [edge.edge_to] return next_ids - async def _find_flow_next(self) -> list[StepQueueItem]: """在当前步骤执行前,尝试获取下一步""" # 如果当前步骤为结束,则直接返回 - if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type] + if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type] return [] - - next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type] + if self.current_step.step.type == SpecialCallType.CHOICE.value: + # 如果是choice节点,获取分支ID + branch_id = self.task.context[-1].output_data["branch_id"] + if branch_id: + next_steps = await self._find_next_id(self.task.state.step_id + "." + branch_id) + logger.info("[FlowExecutor] 分支ID:%s", branch_id) + else: + logger.warning("[FlowExecutor] 没有找到分支ID,返回空列表") + return [] + else: + next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type] # 如果step没有任何出边,直接跳到end if not next_steps: return [ @@ -137,7 +146,6 @@ class FlowExecutor(BaseExecutor): for next_step in next_steps ] - async def run(self) -> None: """ 运行流,返回各步骤结果,直到无法继续执行 @@ -150,8 +158,8 @@ class FlowExecutor(BaseExecutor): # 获取首个步骤 first_step = StepQueueItem( - step_id=self.task.state.step_id, # type: ignore[arg-type] - step=self.flow.steps[self.task.state.step_id], # type: ignore[arg-type] + step_id=self.task.state.step_id, # type: ignore[arg-type] + step=self.flow.steps[self.task.state.step_id], # type: ignore[arg-type] ) # 头插开始前的系统步骤,并执行 @@ -166,11 +174,12 @@ class FlowExecutor(BaseExecutor): # 插入首个步骤 self.step_queue.append(first_step) - + self.task.state.flow_status = FlowStatus.RUNNING # type: ignore[arg-type] # 运行Flow(未达终点) + is_error = False while not self._reached_end: # 如果当前步骤出错,执行错误处理步骤 - if self.task.state.status == StepStatus.ERROR: # type: ignore[arg-type] + if self.task.state.step_status == StepStatus.ERROR: # type: ignore[arg-type] logger.warning("[FlowExecutor] Executor出错,执行错误处理步骤") self.step_queue.clear() self.step_queue.appendleft(StepQueueItem( @@ -183,13 +192,14 @@ class FlowExecutor(BaseExecutor): params={ "user_prompt": LLM_ERROR_PROMPT.replace( "{{ error_info }}", - self.task.state.error_info["err_msg"], # type: ignore[arg-type] + self.task.state.error_info["err_msg"], # type: ignore[arg-type] ), }, ), enable_filling=False, to_user=False, )) + is_error = True # 错误处理后结束 self._reached_end = True @@ -204,6 +214,12 @@ class FlowExecutor(BaseExecutor): for step in next_step: self.step_queue.append(step) + # 更新Task状态 + if is_error: + self.task.state.flow_status = FlowStatus.ERROR # type: ignore[arg-type] + else: + self.task.state.flow_status = FlowStatus.SUCCESS # type: ignore[arg-type] + # 尾插运行结束后的系统步骤 for step in FIXED_STEPS_AFTER_END: self.step_queue.append(StepQueueItem( @@ -215,4 +231,7 @@ class FlowExecutor(BaseExecutor): # FlowStop需要返回总时间,需要倒推最初的开始时间(当前时间减去当前已用总时间) self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.full_time # 推送Flow停止消息 - await self.push_message(EventType.FLOW_STOP.value) + if is_error: + await self.push_message(EventType.FLOW_FAILED.value) + else: + await self.push_message(EventType.FLOW_SUCCESS.value) diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 6b3451fa9ccd92b8f9b2d497f3983a64ff3a6981..5a95e407989f4dfc9c26c2157ccbec5d3ecdad21 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -86,8 +86,8 @@ class StepExecutor(BaseExecutor): logger.info("[StepExecutor] 初始化步骤 %s", self.step.step.name) # State写入ID和运行状态 - self.task.state.step_id = self.step.step_id # type: ignore[arg-type] - self.task.state.step_name = self.step.step.name # type: ignore[arg-type] + self.task.state.step_id = self.step.step_id # type: ignore[arg-type] + self.task.state.step_name = self.step.step.name # type: ignore[arg-type] # 获取并验证Call类 node_id = self.step.step.node @@ -119,7 +119,6 @@ class StepExecutor(BaseExecutor): logger.exception("[StepExecutor] 初始化Call失败") raise - async def _run_slot_filling(self) -> None: """运行自动参数填充;相当于特殊Step,但是不存库""" # 判断是否需要进行自动参数填充 @@ -127,13 +126,13 @@ class StepExecutor(BaseExecutor): return # 暂存旧数据 - current_step_id = self.task.state.step_id # type: ignore[arg-type] - current_step_name = self.task.state.step_name # type: ignore[arg-type] + current_step_id = self.task.state.step_id # type: ignore[arg-type] + current_step_name = self.task.state.step_name # type: ignore[arg-type] # 更新State - self.task.state.step_id = str(uuid.uuid4()) # type: ignore[arg-type] - self.task.state.step_name = "自动参数填充" # type: ignore[arg-type] - self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] + self.task.state.step_id = str(uuid.uuid4()) # type: ignore[arg-type] + self.task.state.step_name = "自动参数填充" # type: ignore[arg-type] + self.task.state.step_status = StepStatus.RUNNING # type: ignore[arg-type] self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) # 初始化填参 @@ -156,21 +155,20 @@ class StepExecutor(BaseExecutor): # 如果没有填全,则状态设置为待填参 if result.remaining_schema: - self.task.state.status = StepStatus.PARAM # type: ignore[arg-type] + self.task.state.step_status = StepStatus.PARAM # type: ignore[arg-type] else: - self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] + self.task.state.step_status = StepStatus.SUCCESS # type: ignore[arg-type] await self.push_message(EventType.STEP_OUTPUT.value, result.model_dump(by_alias=True, exclude_none=True)) # 更新输入 self.obj.input.update(result.slot_data) # 恢复State - self.task.state.step_id = current_step_id # type: ignore[arg-type] - self.task.state.step_name = current_step_name # type: ignore[arg-type] + self.task.state.step_id = current_step_id # type: ignore[arg-type] + self.task.state.step_name = current_step_name # type: ignore[arg-type] self.task.tokens.input_tokens += self.obj.tokens.input_tokens self.task.tokens.output_tokens += self.obj.tokens.output_tokens - async def _process_chunk( self, iterator: AsyncGenerator[CallOutputChunk, None], @@ -202,7 +200,6 @@ class StepExecutor(BaseExecutor): return content - async def run(self) -> None: """运行单个步骤""" self.validate_flow_state(self.task) @@ -212,7 +209,7 @@ class StepExecutor(BaseExecutor): await self._run_slot_filling() # 更新状态 - self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] + self.task.state.step_status = StepStatus.RUNNING # type: ignore[arg-type] self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) # 推送输入 await self.push_message(EventType.STEP_INPUT.value, self.obj.input) @@ -224,22 +221,22 @@ class StepExecutor(BaseExecutor): content = await self._process_chunk(iterator, to_user=self.obj.to_user) except Exception as e: logger.exception("[StepExecutor] 运行步骤失败,进行异常处理步骤") - self.task.state.status = StepStatus.ERROR # type: ignore[arg-type] + self.task.state.step_status = StepStatus.ERROR # type: ignore[arg-type] await self.push_message(EventType.STEP_OUTPUT.value, {}) if isinstance(e, CallError): - self.task.state.error_info = { # type: ignore[arg-type] + self.task.state.error_info = { # type: ignore[arg-type] "err_msg": e.message, "data": e.data, } else: - self.task.state.error_info = { # type: ignore[arg-type] + self.task.state.error_info = { # type: ignore[arg-type] "err_msg": str(e), "data": {}, } return # 更新执行状态 - self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] + self.task.state.step_status = StepStatus.SUCCESS # type: ignore[arg-type] self.task.tokens.input_tokens += self.obj.tokens.input_tokens self.task.tokens.output_tokens += self.obj.tokens.output_tokens self.task.tokens.full_time += round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.time @@ -253,16 +250,17 @@ class StepExecutor(BaseExecutor): # 更新context history = FlowStepHistory( task_id=self.task.id, - flow_id=self.task.state.flow_id, # type: ignore[arg-type] - flow_name=self.task.state.flow_name, # type: ignore[arg-type] + flow_id=self.task.state.flow_id, # type: ignore[arg-type] + flow_name=self.task.state.flow_name, # type: ignore[arg-type] + flow_status=self.task.state.flow_status, # type: ignore[arg-type] step_id=self.step.step_id, step_name=self.step.step.name, step_description=self.step.step.description, - status=self.task.state.status, # type: ignore[arg-type] + step_status=self.task.state.step_status, input_data=self.obj.input, output_data=output_data, ) - self.task.context.append(history.model_dump(exclude_none=True, by_alias=True)) + self.task.context.append(history) # 推送输出 await self.push_message(EventType.STEP_OUTPUT.value, output_data) diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py index 78aa7bc3ee869e8710e1fb02a2d9fb438d04be34..8e11083900c20c29fcaf71b6535f3b442fc2313a 100644 --- a/apps/scheduler/mcp/host.py +++ b/apps/scheduler/mcp/host.py @@ -40,7 +40,6 @@ class MCPHost: lstrip_blocks=True, ) - async def get_client(self, mcp_id: str) -> MCPClient | None: """获取MCP客户端""" mongo = MongoDB() @@ -59,7 +58,6 @@ class MCPHost: logger.warning("用户 %s 的MCP %s 没有运行中的实例,请检查环境", self._user_sub, mcp_id) return None - async def assemble_memory(self) -> str: """组装记忆""" task = await TaskManager.get_task_by_task_id(self._task_id) @@ -69,7 +67,7 @@ class MCPHost: context_list = [] for ctx_id in self._context_list: - context = next((ctx for ctx in task.context if ctx["_id"] == ctx_id), None) + context = next((ctx for ctx in task.context if ctx.id == ctx_id), None) if not context: continue context_list.append(context) @@ -78,7 +76,6 @@ class MCPHost: context_list=context_list, ) - async def _save_memory( self, tool: MCPTool, @@ -105,11 +102,12 @@ class MCPHost: task_id=self._task_id, flow_id=self._runtime_id, flow_name=self._runtime_name, + flow_status=StepStatus.RUNNING, step_id=tool.name, step_name=tool.name, # description是规划的实际内容 step_description=plan_item.content, - status=StepStatus.SUCCESS, + step_status=StepStatus.SUCCESS, input_data=input_data, output_data=output_data, ) @@ -120,12 +118,11 @@ class MCPHost: logger.error("任务 %s 不存在", self._task_id) return {} self._context_list.append(context.id) - task.context.append(context.model_dump(by_alias=True, exclude_none=True)) + task.context.append(context.model_dump(exclude_none=True, by_alias=True)) await TaskManager.save_task(self._task_id, task) return output_data - async def _fill_params(self, tool: MCPTool, query: str) -> dict[str, Any]: """填充工具参数""" # 更清晰的输入·指令,这样可以调用generate @@ -146,7 +143,6 @@ class MCPHost: ) return await json_generator.generate() - async def call_tool(self, tool: MCPTool, plan_item: MCPPlanItem) -> list[dict[str, Any]]: """调用工具""" # 拿到Client @@ -170,7 +166,6 @@ class MCPHost: return processed_result - async def get_tool_list(self, mcp_id_list: list[str]) -> list[MCPTool]: """获取工具列表""" mongo = MongoDB() diff --git a/apps/scheduler/mcp/prompt.py b/apps/scheduler/mcp/prompt.py index b322fb0883e8ed935243389cb86066845a549631..b906fb8f75dfafb2546aaef6eb36370645c57598 100644 --- a/apps/scheduler/mcp/prompt.py +++ b/apps/scheduler/mcp/prompt.py @@ -72,7 +72,8 @@ CREATE_PLAN = dedent(r""" 2. 计划中的每一个步骤必须且只能使用一个工具。 3. 计划中的步骤必须具有清晰和逻辑的步骤,没有冗余或不必要的步骤。 4. 计划中的最后一步必须是Final工具,以确保计划执行结束。 - + 5.生成的计划必须要覆盖用户的目标,不能遗漏任何用户目标中的内容。 + # 生成计划时的注意事项: - 每一条计划包含3个部分: @@ -233,8 +234,8 @@ FINAL_ANSWER = dedent(r""" MEMORY_TEMPLATE = dedent(r""" {% for ctx in context_list %} - 第{{ loop.index }}步:{{ ctx.step_description }} - - 调用工具 `{{ ctx.step_id }}`,并提供参数 `{{ ctx.input_data }}` - - 执行状态:{{ ctx.status }} - - 得到数据:`{{ ctx.output_data }}` + - 调用工具 `{{ ctx.step_name }}`,并提供参数 `{{ ctx.input_data|tojson }}`。 + - 执行状态:{{ ctx.step_status }} + - 得到数据:`{{ ctx.output_data|tojson }}` {% endfor %} """) diff --git a/apps/scheduler/mcp/select.py b/apps/scheduler/mcp/select.py index 2ff5034471c5e9c38f166c6187b76dfb4596f734..f3d6e0d4ae3aea9508d48f33499ac2bcf3de98a9 100644 --- a/apps/scheduler/mcp/select.py +++ b/apps/scheduler/mcp/select.py @@ -39,7 +39,6 @@ class MCPSelector: sql += f"'{mcp_id}', " return sql.rstrip(", ") + ")" - async def _get_top_mcp_by_embedding( self, query: str, @@ -72,7 +71,6 @@ class MCPSelector: }]) return llm_mcp_list - async def _get_mcp_by_llm( self, query: str, @@ -100,7 +98,6 @@ class MCPSelector: # 使用小模型提取JSON return await self._call_function_mcp(result, mcp_ids) - async def _call_reasoning(self, prompt: str) -> str: """调用大模型进行推理""" logger.info("[MCPHelper] 调用推理大模型") @@ -116,7 +113,6 @@ class MCPSelector: self.output_tokens += llm.output_tokens return result - async def _call_function_mcp(self, reasoning_result: str, mcp_ids: list[str]) -> MCPSelectResult: """调用结构化输出小模型提取JSON""" logger.info("[MCPHelper] 调用结构化输出小模型") @@ -136,7 +132,6 @@ class MCPSelector: raise return result - async def select_top_mcp( self, query: str, @@ -153,7 +148,6 @@ class MCPSelector: # 通过LLM选择最合适的 return await self._get_mcp_by_llm(query, llm_mcp_list, mcp_list) - @staticmethod async def select_top_tool(query: str, mcp_list: list[str], top_n: int = 10) -> list[MCPTool]: """选择最合适的工具""" diff --git a/apps/scheduler/mcp_agent/__init__.py b/apps/scheduler/mcp_agent/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..12f5cb68c12e4d19d830a8155eaeb0851fce897d --- /dev/null +++ b/apps/scheduler/mcp_agent/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""Scheduler MCP 模块""" + +from apps.scheduler.mcp.host import MCPHost +from apps.scheduler.mcp.plan import MCPPlanner +from apps.scheduler.mcp.select import MCPSelector + +__all__ = ["MCPHost", "MCPPlanner", "MCPSelector"] diff --git a/apps/scheduler/mcp_agent/agent/base.py b/apps/scheduler/mcp_agent/agent/base.py deleted file mode 100644 index eccb58a9ce8466f67ac76040a53469edac455109..0000000000000000000000000000000000000000 --- a/apps/scheduler/mcp_agent/agent/base.py +++ /dev/null @@ -1,196 +0,0 @@ -"""MCP Agent基类""" -import logging -from abc import ABC, abstractmethod -from contextlib import asynccontextmanager - -from pydantic import BaseModel, Field, model_validator - -from apps.common.queue import MessageQueue -from apps.schemas.enum_var import AgentState -from apps.schemas.task import Task -from apps.llm.reasoning import ReasoningLLM -from apps.scheduler.mcp_agent.schema import Memory, Message, Role -from apps.services.activity import Activity - -logger = logging.getLogger(__name__) - - -class BaseAgent(BaseModel, ABC): - """ - 用于管理代理状态和执行的抽象基类。 - - 为状态转换、内存管理、 - 以及分步执行循环。子类必须实现`step`方法。 - """ - - msg_queue: MessageQueue - task: Task - name: str = Field(..., description="Agent名称") - agent_id: str = Field(default="", description="Agent ID") - description: str = Field(default="", description="Agent描述") - question: str - # Prompts - next_step_prompt: str | None = Field( - None, description="判断下一步动作的提示" - ) - - # Dependencies - llm: ReasoningLLM = Field(default_factory=ReasoningLLM, description="大模型实例") - memory: Memory = Field(default_factory=Memory, description="Agent记忆库") - state: AgentState = Field( - default=AgentState.IDLE, description="Agent状态" - ) - servers_id: list[str] = Field(default_factory=list, description="MCP server id") - - # Execution control - max_steps: int = Field(default=10, description="终止前的最大步长") - current_step: int = Field(default=0, description="执行中的当前步骤") - - duplicate_threshold: int = 2 - - user_prompt: str = r""" - 当前步骤:{step} 工具输出结果:{result} - 请总结当前正在执行的步骤和对应的工具输出结果,内容包括当前步骤是多少,执行的工具是什么,输出是什么。 - 最终以报告的形式展示。 - 如果工具输出结果中执行的工具为terminate,请按照状态输出本次交互过程最终结果并完成对整个报告的总结,不需要输出你的分析过程。 - """ - """用户提示词""" - - class Config: - arbitrary_types_allowed = True - extra = "allow" # Allow extra fields for flexibility in subclasses - - @model_validator(mode="after") - def initialize_agent(self) -> "BaseAgent": - """初始化Agent""" - if self.llm is None or not isinstance(self.llm, ReasoningLLM): - self.llm = ReasoningLLM() - if not isinstance(self.memory, Memory): - self.memory = Memory() - return self - - @asynccontextmanager - async def state_context(self, new_state: AgentState): - """ - Agent状态转换上下文管理器 - - Args: - new_state: 要转变的状态 - - :return: None - :raise ValueError: 如果new_state无效 - """ - if not isinstance(new_state, AgentState): - raise ValueError(f"无效状态: {new_state}") - - previous_state = self.state - self.state = new_state - try: - yield - except Exception as e: - self.state = AgentState.ERROR # Transition to ERROR on failure - raise e - finally: - self.state = previous_state # Revert to previous state - - def update_memory( - self, - role: Role, - content: str, - **kwargs, - ) -> None: - """添加信息到Agent的memory中""" - message_map = { - "user": Message.user_message, - "system": Message.system_message, - "assistant": Message.assistant_message, - "tool": lambda content, **kw: Message.tool_message(content, **kw), - } - - if role not in message_map: - raise ValueError(f"不支持的消息角色: {role}") - - # Create message with appropriate parameters based on role - kwargs = {**(kwargs if role == "tool" else {})} - self.memory.add_message(message_map[role](content, **kwargs)) - - async def run(self, request: str | None = None) -> str: - """异步执行Agent的主循环""" - self.task.runtime.question = request - if self.state != AgentState.IDLE: - raise RuntimeError(f"无法从以下状态运行智能体: {self.state}") - - if request: - self.update_memory("user", request) - - results: list[str] = [] - async with self.state_context(AgentState.RUNNING): - while ( - self.current_step < self.max_steps and self.state != AgentState.FINISHED - ): - if not await Activity.is_active(self.task.ids.user_sub): - logger.info("用户终止会话,任务停止!") - return "" - self.current_step += 1 - logger.info(f"执行步骤{self.current_step}/{self.max_steps}") - step_result = await self.step() - - # Check for stuck state - if self.is_stuck(): - self.handle_stuck_state() - result = f"Step {self.current_step}: {step_result}" - results.append(result) - - if self.current_step >= self.max_steps: - self.current_step = 0 - self.state = AgentState.IDLE - result = f"任务终止: 已达到最大步数 ({self.max_steps})" - await self.msg_queue.push_output( - self.task, - event_type="text.add", - data={"text": result}, # type: ignore[arg-type] - ) - results.append(result) - return "\n".join(results) if results else "未执行任何步骤" - - @abstractmethod - async def step(self) -> str: - """ - 执行代理工作流程中的单个步骤。 - - 必须由子类实现,以定义具体的行为。 - """ - - def handle_stuck_state(self): - """通过添加更改策略的提示来处理卡住状态""" - stuck_prompt = "\ - 观察到重复响应。考虑新策略,避免重复已经尝试过的无效路径" - self.next_step_prompt = f"{stuck_prompt}\n{self.next_step_prompt}" - logger.warning(f"检测到智能体处于卡住状态。新增提示:{stuck_prompt}") - - def is_stuck(self) -> bool: - """通过检测重复内容来检查代理是否卡在循环中""" - if len(self.memory.messages) < 2: - return False - - last_message = self.memory.messages[-1] - if not last_message.content: - return False - - duplicate_count = sum( - 1 - for msg in reversed(self.memory.messages[:-1]) - if msg.role == "assistant" and msg.content == last_message.content - ) - - return duplicate_count >= self.duplicate_threshold - - @property - def messages(self) -> list[Message]: - """从Agent memory中检索消息列表""" - return self.memory.messages - - @messages.setter - def messages(self, value: list[Message]) -> None: - """设置Agent memory的消息列表""" - self.memory.messages = value diff --git a/apps/scheduler/mcp_agent/agent/mcp.py b/apps/scheduler/mcp_agent/agent/mcp.py deleted file mode 100644 index 378da368aca02d0d628352fcc4816b98b2921d01..0000000000000000000000000000000000000000 --- a/apps/scheduler/mcp_agent/agent/mcp.py +++ /dev/null @@ -1,81 +0,0 @@ -"""MCP Agent""" -import logging - -from pydantic import Field - -from apps.scheduler.mcp.host import MCPHost -from apps.scheduler.mcp_agent.agent.toolcall import ToolCallAgent -from apps.scheduler.mcp_agent.tool import Terminate, ToolCollection - -logger = logging.getLogger(__name__) - - -class MCPAgent(ToolCallAgent): - """ - 用于与MCP(模型上下文协议)服务器交互。 - - 使用SSE或stdio传输连接到MCP服务器 - 并使服务器的工具 - """ - - name: str = "MCPAgent" - description: str = "一个多功能的智能体,能够使用多种工具(包括基于MCP的工具)解决各种任务" - - # Add general-purpose tools to the tool collection - available_tools: ToolCollection = Field( - default_factory=lambda: ToolCollection( - Terminate(), - ), - ) - - special_tool_names: list[str] = Field(default_factory=lambda: [Terminate().name]) - - _initialized: bool = False - - @classmethod - async def create(cls, **kwargs) -> "MCPAgent": # noqa: ANN003 - """创建并初始化MCP Agent实例""" - instance = cls(**kwargs) - await instance.initialize_mcp_servers() - instance._initialized = True - return instance - - async def initialize_mcp_servers(self) -> None: - """初始化与已配置的MCP服务器的连接""" - mcp_host = MCPHost( - self.task.ids.user_sub, - self.task.id, - self.agent_id, - self.description, - ) - mcps = {} - for mcp_id in self.servers_id: - client = await mcp_host.get_client(mcp_id) - if client: - mcps[mcp_id] = client - - for mcp_id, mcp_client in mcps.items(): - new_tools = [] - for tool in mcp_client.tools: - original_name = tool.name - # Always prefix with server_id to ensure uniqueness - tool_name = f"mcp_{mcp_id}_{original_name}" - - server_tool = MCPClientTool( - name=tool_name, - description=tool.description, - parameters=tool.inputSchema, - session=mcp_client.session, - server_id=mcp_id, - original_name=original_name, - ) - new_tools.append(server_tool) - self.available_tools.add_tools(*new_tools) - - async def think(self) -> bool: - """使用适当的上下文处理当前状态并决定下一步操作""" - if not self._initialized: - await self.initialize_mcp_servers() - self._initialized = True - - return await super().think() diff --git a/apps/scheduler/mcp_agent/agent/react.py b/apps/scheduler/mcp_agent/agent/react.py deleted file mode 100644 index b56efd8b195eb36c4d5718711cc5f07b5a49812f..0000000000000000000000000000000000000000 --- a/apps/scheduler/mcp_agent/agent/react.py +++ /dev/null @@ -1,35 +0,0 @@ -from abc import ABC, abstractmethod - -from pydantic import Field - -from apps.schemas.enum_var import AgentState -from apps.llm.reasoning import ReasoningLLM -from apps.scheduler.mcp_agent.agent.base import BaseAgent -from apps.scheduler.mcp_agent.schema import Memory - - -class ReActAgent(BaseAgent, ABC): - name: str - description: str | None = None - - system_prompt: str | None = None - next_step_prompt: str | None = None - - llm: ReasoningLLM | None = Field(default_factory=ReasoningLLM) - memory: Memory = Field(default_factory=Memory) - state: AgentState = AgentState.IDLE - - @abstractmethod - async def think(self) -> bool: - """处理当前状态并决定下一步操作""" - - @abstractmethod - async def act(self) -> str: - """执行已决定的行动""" - - async def step(self) -> str: - """执行一个步骤:思考和行动""" - should_act = await self.think() - if not should_act: - return "思考完成-无需采取任何行动" - return await self.act() diff --git a/apps/scheduler/mcp_agent/agent/toolcall.py b/apps/scheduler/mcp_agent/agent/toolcall.py deleted file mode 100644 index 1e22099ce1d2e2f2f54a3bc018511acf887a91a1..0000000000000000000000000000000000000000 --- a/apps/scheduler/mcp_agent/agent/toolcall.py +++ /dev/null @@ -1,238 +0,0 @@ -import asyncio -import json -import logging -from typing import Any, Optional - -from pydantic import Field - -from apps.schemas.enum_var import AgentState -from apps.llm.function import JsonGenerator -from apps.llm.patterns import Select -from apps.scheduler.mcp_agent.agent.react import ReActAgent -from apps.scheduler.mcp_agent.schema import Function, Message, ToolCall -from apps.scheduler.mcp_agent.tool import Terminate, ToolCollection - -logger = logging.getLogger(__name__) - - -class ToolCallAgent(ReActAgent): - """用于处理工具/函数调用的基本Agent类""" - - name: str = "toolcall" - description: str = "可以执行工具调用的智能体" - - available_tools: ToolCollection = ToolCollection( - Terminate(), - ) - tool_choices: str = "auto" - special_tool_names: list[str] = Field(default_factory=lambda: [Terminate().name]) - - tool_calls: list[ToolCall] = Field(default_factory=list) - _current_base64_image: str | None = None - - max_observe: int | bool | None = None - - async def think(self) -> bool: - """使用工具处理当前状态并决定下一步行动""" - messages = [] - for message in self.messages: - if isinstance(message, Message): - message = message.to_dict() - messages.append(message) - try: - # 通过工具获得响应 - select_obj = Select() - choices = [] - for available_tool in self.available_tools.to_params(): - choices.append(available_tool.get("function")) - - tool = await select_obj.generate(question=self.question, choices=choices) - if tool in self.available_tools.tool_map: - schema = self.available_tools.tool_map[tool].parameters - json_generator = JsonGenerator( - query="根据跟定的信息,获取工具参数", - conversation=messages, - schema=schema, - ) # JsonGenerator - parameters = await json_generator.generate() - - else: - raise ValueError(f"尝试调用不存在的工具: {tool}") - except Exception as e: - raise - self.tool_calls = tool_calls = [ToolCall(id=tool, function=Function(name=tool, arguments=parameters))] - content = f"选择的执行工具为:{tool}, 参数为{parameters}" - - logger.info( - f"{self.name} 选择 {len(tool_calls) if tool_calls else 0}个工具执行" - ) - if tool_calls: - logger.info( - f"准备使用的工具: {[call.function.name for call in tool_calls]}" - ) - logger.info(f"工具参数: {tool_calls[0].function.arguments}") - - try: - - assistant_msg = ( - Message.from_tool_calls(content=content, tool_calls=self.tool_calls) - if self.tool_calls - else Message.assistant_message(content) - ) - self.memory.add_message(assistant_msg) - - if not self.tool_calls: - return bool(content) - - return bool(self.tool_calls) - except Exception as e: - logger.error(f"{self.name}的思考过程遇到了问题:: {e}") - self.memory.add_message( - Message.assistant_message( - f"处理时遇到错误: {str(e)}" - ) - ) - return False - - async def act(self) -> str: - """执行工具调用并处理其结果""" - if not self.tool_calls: - # 如果没有工具调用,则返回最后的消息内容 - return self.messages[-1].content or "没有要执行的内容或命令" - - results = [] - for command in self.tool_calls: - await self.msg_queue.push_output( - self.task, - event_type="text.add", - data={"text": f"正在执行工具{command.function.name}"} - ) - - self._current_base64_image = None - - result = await self.execute_tool(command) - - if self.max_observe: - result = result[: self.max_observe] - - push_result = "" - async for chunk in self.llm.call( - messages=[{"role": "system", "content": "You are a helpful asistant."}, - {"role": "user", "content": self.user_prompt.format( - step=self.current_step, - result=result, - )}, ], streaming=False - ): - push_result += chunk - self.task.tokens.input_tokens += self.llm.input_tokens - self.task.tokens.output_tokens += self.llm.output_tokens - await self.msg_queue.push_output( - self.task, - event_type="text.add", - data={"text": push_result}, # type: ignore[arg-type] - ) - - await self.msg_queue.push_output( - self.task, - event_type="text.add", - data={"text": f"工具{command.function.name}执行完成"}, # type: ignore[arg-type] - ) - - logger.info( - f"工具'{command.function.name}'执行完成! 执行结果为: {result}" - ) - - # 将工具响应添加到内存 - tool_msg = Message.tool_message( - content=result, - tool_call_id=command.id, - name=command.function.name, - ) - self.memory.add_message(tool_msg) - results.append(result) - self.question += ( - f"\n已执行工具{command.function.name}, " - f"作用是{self.available_tools.tool_map[command.function.name].description},结果为{result}" - ) - - return "\n\n".join(results) - - async def execute_tool(self, command: ToolCall) -> str: - """执行单个工具调用""" - if not command or not command.function or not command.function.name: - return "错误:无效的命令格式" - - name = command.function.name - if name not in self.available_tools.tool_map: - return f"错误:未知工具 '{name}'" - - try: - # 解析参数 - args = command.function.arguments - # 执行工具 - logger.info(f"激活工具:'{name}'...") - result = await self.available_tools.execute(name=name, tool_input=args) - - # 执行特殊工具 - await self._handle_special_tool(name=name, result=result) - - # 格式化结果 - observation = ( - f"观察到执行的工具 `{name}`的输出:\n{str(result)}" - if result - else f"工具 `{name}` 已完成,无输出" - ) - - return observation - except json.JSONDecodeError: - error_msg = f"解析{name}的参数时出错:JSON格式无效" - logger.error( - f"{name}”的参数没有意义-无效的JSON,参数:{command.function.arguments}" - ) - return f"错误: {error_msg}" - except Exception as e: - error_msg = f"工具 '{name}' 遇到问题: {str(e)}" - logger.exception(error_msg) - return f"错误: {error_msg}" - - async def _handle_special_tool(self, name: str, result: Any, **kwargs): - """处理特殊工具的执行和状态变化""" - if not self._is_special_tool(name): - return - - if self._should_finish_execution(name=name, result=result, **kwargs): - # 将智能体状态设为finished - logger.info(f"特殊工具'{name}'已完成任务!") - self.state = AgentState.FINISHED - - @staticmethod - def _should_finish_execution(**kwargs) -> bool: - """确定工具执行是否应完成""" - return True - - def _is_special_tool(self, name: str) -> bool: - """检查工具名称是否在特殊工具列表中""" - return name.lower() in [n.lower() for n in self.special_tool_names] - - async def cleanup(self): - """清理Agent工具使用的资源。""" - logger.info(f"正在清理智能体的资源'{self.name}'...") - for tool_name, tool_instance in self.available_tools.tool_map.items(): - if hasattr(tool_instance, "cleanup") and asyncio.iscoroutinefunction( - tool_instance.cleanup - ): - try: - logger.debug(f"清理工具: {tool_name}") - await tool_instance.cleanup() - except Exception as e: - logger.error( - f"清理工具时发生错误'{tool_name}': {e}", exc_info=True - ) - logger.info(f"智能体清理完成'{self.name}'.") - - async def run(self, request: Optional[str] = None) -> str: - """运行Agent""" - try: - return await super().run(request) - finally: - await self.cleanup() diff --git a/apps/scheduler/mcp_agent/base.py b/apps/scheduler/mcp_agent/base.py new file mode 100644 index 0000000000000000000000000000000000000000..ac3829b418be54bfc5df6ae3d61db82fe0fc129e --- /dev/null +++ b/apps/scheduler/mcp_agent/base.py @@ -0,0 +1,61 @@ +from typing import Any +import json +from jsonschema import validate +import logging +from apps.llm.function import JsonGenerator +from apps.llm.reasoning import ReasoningLLM + +logger = logging.getLogger(__name__) + + +class McpBase: + @staticmethod + async def get_resoning_result(prompt: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> str: + """获取推理结果""" + # 调用推理大模型 + message = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ] + result = "" + async for chunk in resoning_llm.call( + message, + streaming=False, + temperature=0.07, + result_only=True, + ): + result += chunk + + return result + + @staticmethod + async def _parse_result(result: str, schema: dict[str, Any], left_str: str = '{', right_str: str = '}') -> str: + """解析推理结果""" + left_index = result.find(left_str) + right_index = result.rfind(right_str) + flag = True + if left_str == -1 or right_str == -1: + flag = False + + if left_index > right_index: + flag = False + if flag: + try: + tmp_js = json.loads(result[left_index:right_index + 1]) + validate(instance=tmp_js, schema=schema) + except Exception as e: + logger.error("[McpBase] 解析结果失败: %s", e) + flag = False + if not flag: + json_generator = JsonGenerator( + "请提取下面内容中的json\n\n", + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "请提取下面内容中的json\n\n"+result}, + ], + schema, + ) + json_result = await json_generator.generate() + else: + json_result = json.loads(result[left_index:right_index + 1]) + return json_result diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py new file mode 100644 index 0000000000000000000000000000000000000000..44b04d1fff502037e9c35f1232dfce6959a57009 --- /dev/null +++ b/apps/scheduler/mcp_agent/host.py @@ -0,0 +1,104 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""MCP宿主""" + +import json +import logging +from typing import Any + +from jinja2 import BaseLoader +from jinja2.sandbox import SandboxedEnvironment +from mcp.types import TextContent + +from apps.common.mongo import MongoDB +from apps.llm.function import JsonGenerator +from apps.scheduler.mcp.prompt import MEMORY_TEMPLATE +from apps.scheduler.pool.mcp.client import MCPClient +from apps.scheduler.pool.mcp.pool import MCPPool +from apps.scheduler.mcp_agent.prompt import REPAIR_PARAMS +from apps.schemas.enum_var import StepStatus +from apps.schemas.mcp import MCPPlanItem, MCPTool +from apps.schemas.task import Task, FlowStepHistory +from apps.services.task import TaskManager + +logger = logging.getLogger(__name__) + +_env = SandboxedEnvironment( + loader=BaseLoader, + autoescape=False, + trim_blocks=True, + lstrip_blocks=True, +) + + +def tojson_filter(value): + return json.dumps(value, ensure_ascii=False, separators=(',', ':')) + + +_env.filters['tojson'] = tojson_filter + + +class MCPHost: + """MCP宿主服务""" + + @staticmethod + async def assemble_memory(task: Task) -> str: + """组装记忆""" + + return _env.from_string(MEMORY_TEMPLATE).render( + context_list=task.context, + ) + + async def _get_first_input_params(mcp_tool: MCPTool, goal: str, query: str, task: Task) -> dict[str, Any]: + """填充工具参数""" + # 更清晰的输入·指令,这样可以调用generate + llm_query = rf""" + 你是一个MCP工具参数生成器,你的任务是根据用户的目标和查询,生成MCP工具的输入参数。 + 请充分利用背景中的信息,生成满足以下目标的工具参数: + {query} + 注意: + 1. 你需要根据工具的输入schema来生成参数。 + 2. 如果工具的输入schema中有枚举类型的字段,请确保生成的参数符合这些枚举值。 + 下面是工具的简介: + 工具名称:{mcp_tool.name} + 工具描述:{mcp_tool.description} + """ + + # 进行生成 + json_generator = JsonGenerator( + llm_query, + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "用户的总目标是:\n"+goal+(await MCPHost.assemble_memory(task))}, + ], + mcp_tool.input_schema, + ) + return await json_generator.generate() + + async def _fill_params(mcp_tool: MCPTool, + goal: str, + current_goal: str, + current_input: dict[str, Any], + error_message: str = "", params: dict[str, Any] = {}, + params_description: str = "") -> dict[str, Any]: + llm_query = "请生成修复之后的工具参数" + prompt = _env.from_string(REPAIR_PARAMS).render( + tool_name=mcp_tool.name, + goal=goal, + current_goal=current_goal, + tool_description=mcp_tool.description, + input_schema=mcp_tool.input_schema, + current_input=current_input, + error_message=error_message, + params=params, + params_description=params_description, + ) + + json_generator = JsonGenerator( + llm_query, + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + mcp_tool.input_schema, + ) + return await json_generator.generate() diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py new file mode 100644 index 0000000000000000000000000000000000000000..84c08d60227147849f8b9f4b3d91f79dc43297b0 --- /dev/null +++ b/apps/scheduler/mcp_agent/plan.py @@ -0,0 +1,368 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""MCP 用户目标拆解与规划""" +from typing import Any, AsyncGenerator +from jinja2 import BaseLoader +from jinja2.sandbox import SandboxedEnvironment + +from apps.llm.reasoning import ReasoningLLM +from apps.llm.function import JsonGenerator +from apps.scheduler.mcp_agent.base import McpBase +from apps.scheduler.mcp_agent.prompt import ( + EVALUATE_GOAL, + GENERATE_FLOW_NAME, + GET_REPLAN_START_STEP_INDEX, + CREATE_PLAN, + RECREATE_PLAN, + GEN_STEP, + TOOL_SKIP, + RISK_EVALUATE, + TOOL_EXECUTE_ERROR_TYPE_ANALYSIS, + IS_PARAM_ERROR, + CHANGE_ERROR_MESSAGE_TO_DESCRIPTION, + GET_MISSING_PARAMS, + FINAL_ANSWER +) +from apps.schemas.task import Task +from apps.schemas.mcp import ( + GoalEvaluationResult, + RestartStepIndex, + ToolSkip, + ToolRisk, + IsParamError, + ToolExcutionErrorType, + MCPPlan, + Step, + MCPPlanItem, + MCPTool +) +from apps.scheduler.slot.slot import Slot + +_env = SandboxedEnvironment( + loader=BaseLoader, + autoescape=False, + trim_blocks=True, + lstrip_blocks=True, +) + + +class MCPPlanner(McpBase): + """MCP 用户目标拆解与规划""" + + @staticmethod + async def evaluate_goal( + goal: str, + tool_list: list[MCPTool], + resoning_llm: ReasoningLLM = ReasoningLLM()) -> GoalEvaluationResult: + """评估用户目标的可行性""" + # 获取推理结果 + result = await MCPPlanner._get_reasoning_evaluation(goal, tool_list, resoning_llm) + + # 解析为结构化数据 + evaluation = await MCPPlanner._parse_evaluation_result(result) + + # 返回评估结果 + return evaluation + + @staticmethod + async def _get_reasoning_evaluation( + goal, tool_list: list[MCPTool], + resoning_llm: ReasoningLLM = ReasoningLLM()) -> str: + """获取推理大模型的评估结果""" + template = _env.from_string(EVALUATE_GOAL) + prompt = template.render( + goal=goal, + tools=tool_list, + ) + result = await MCPPlanner.get_resoning_result(prompt, resoning_llm) + return result + + @staticmethod + async def _parse_evaluation_result(result: str) -> GoalEvaluationResult: + """将推理结果解析为结构化数据""" + schema = GoalEvaluationResult.model_json_schema() + evaluation = await MCPPlanner._parse_result(result, schema) + # 使用GoalEvaluationResult模型解析结果 + return GoalEvaluationResult.model_validate(evaluation) + + async def get_flow_name(user_goal: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> str: + """获取当前流程的名称""" + result = await MCPPlanner._get_reasoning_flow_name(user_goal, resoning_llm) + return result + + @staticmethod + async def _get_reasoning_flow_name(user_goal: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> str: + """获取推理大模型的流程名称""" + template = _env.from_string(GENERATE_FLOW_NAME) + prompt = template.render(goal=user_goal) + result = await MCPPlanner.get_resoning_result(prompt, resoning_llm) + return result + + @staticmethod + async def get_replan_start_step_index( + user_goal: str, error_message: str, current_plan: MCPPlan | None = None, + history: str = "", + reasoning_llm: ReasoningLLM = ReasoningLLM()) -> RestartStepIndex: + """获取重新规划的步骤索引""" + # 获取推理结果 + template = _env.from_string(GET_REPLAN_START_STEP_INDEX) + prompt = template.render( + goal=user_goal, + error_message=error_message, + current_plan=current_plan.model_dump(exclude_none=True, by_alias=True), + history=history, + ) + result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) + # 解析为结构化数据 + schema = RestartStepIndex.model_json_schema() + schema["properties"]["start_index"]["maximum"] = len(current_plan.plans) - 1 + schema["properties"]["start_index"]["minimum"] = 0 + restart_index = await MCPPlanner._parse_result(result, schema) + # 使用RestartStepIndex模型解析结果 + return RestartStepIndex.model_validate(restart_index) + + @staticmethod + async def create_plan( + user_goal: str, is_replan: bool = False, error_message: str = "", current_plan: MCPPlan | None = None, + tool_list: list[MCPTool] = [], + max_steps: int = 6, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> MCPPlan: + """规划下一步的执行流程,并输出""" + # 获取推理结果 + result = await MCPPlanner._get_reasoning_plan(user_goal, is_replan, error_message, current_plan, tool_list, max_steps, reasoning_llm) + + # 解析为结构化数据 + return await MCPPlanner._parse_plan_result(result, max_steps) + + @staticmethod + async def _get_reasoning_plan( + user_goal: str, is_replan: bool = False, error_message: str = "", current_plan: MCPPlan | None = None, + tool_list: list[MCPTool] = [], + max_steps: int = 10, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> str: + """获取推理大模型的结果""" + # 格式化Prompt + tool_ids = [tool.id for tool in tool_list] + if is_replan: + template = _env.from_string(RECREATE_PLAN) + prompt = template.render( + current_plan=current_plan.model_dump(exclude_none=True, by_alias=True), + error_message=error_message, + goal=user_goal, + tools=tool_list, + max_num=max_steps, + ) + else: + template = _env.from_string(CREATE_PLAN) + prompt = template.render( + goal=user_goal, + tools=tool_list, + max_num=max_steps, + ) + result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) + return result + + @staticmethod + async def _parse_plan_result(result: str, max_steps: int) -> MCPPlan: + """将推理结果解析为结构化数据""" + # 格式化Prompt + schema = MCPPlan.model_json_schema() + schema["properties"]["plans"]["maxItems"] = max_steps + plan = await MCPPlanner._parse_result(result, schema) + # 使用Function模型解析结果 + return MCPPlan.model_validate(plan) + + @staticmethod + async def create_next_step( + goal: str, history: str, tools: list[MCPTool], + reasoning_llm: ReasoningLLM = ReasoningLLM()) -> Step: + """创建下一步的执行步骤""" + # 获取推理结果 + template = _env.from_string(GEN_STEP) + prompt = template.render(goal=goal, history=history, tools=tools) + result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) + + # 解析为结构化数据 + schema = Step.model_json_schema() + if "enum" not in schema["properties"]["tool_id"]: + schema["properties"]["tool_id"]["enum"] = [] + for tool in tools: + schema["properties"]["tool_id"]["enum"].append(tool.id) + step = await MCPPlanner._parse_result(result, schema) + # 使用Step模型解析结果 + return Step.model_validate(step) + + @staticmethod + async def tool_skip( + task: Task, step_id: str, step_name: str, step_instruction: str, step_content: str, + reasoning_llm: ReasoningLLM = ReasoningLLM()) -> ToolSkip: + """判断当前步骤是否需要跳过""" + # 获取推理结果 + template = _env.from_string(TOOL_SKIP) + from apps.scheduler.mcp_agent.host import MCPHost + history = await MCPHost.assemble_memory(task) + prompt = template.render( + step_id=step_id, + step_name=step_name, + step_instruction=step_instruction, + step_content=step_content, + history=history, + goal=task.runtime.question + ) + result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) + + # 解析为结构化数据 + schema = ToolSkip.model_json_schema() + skip_result = await MCPPlanner._parse_result(result, schema) + # 使用ToolSkip模型解析结果 + return ToolSkip.model_validate(skip_result) + + @staticmethod + async def get_tool_risk( + tool: MCPTool, input_parm: dict[str, Any], + additional_info: str = "", resoning_llm: ReasoningLLM = ReasoningLLM()) -> ToolRisk: + """获取MCP工具的风险评估结果""" + # 获取推理结果 + result = await MCPPlanner._get_reasoning_risk(tool, input_parm, additional_info, resoning_llm) + + # 解析为结构化数据 + risk = await MCPPlanner._parse_risk_result(result) + + # 返回风险评估结果 + return risk + + @staticmethod + async def _get_reasoning_risk( + tool: MCPTool, input_param: dict[str, Any], + additional_info: str, resoning_llm: ReasoningLLM) -> str: + """获取推理大模型的风险评估结果""" + template = _env.from_string(RISK_EVALUATE) + prompt = template.render( + tool_name=tool.name, + tool_description=tool.description, + input_param=input_param, + additional_info=additional_info, + ) + result = await MCPPlanner.get_resoning_result(prompt, resoning_llm) + return result + + @staticmethod + async def _parse_risk_result(result: str) -> ToolRisk: + """将推理结果解析为结构化数据""" + schema = ToolRisk.model_json_schema() + risk = await MCPPlanner._parse_result(result, schema) + # 使用ToolRisk模型解析结果 + return ToolRisk.model_validate(risk) + + @staticmethod + async def _get_reasoning_tool_execute_error_type( + user_goal: str, current_plan: MCPPlan, + tool: MCPTool, input_param: dict[str, Any], + error_message: str, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> str: + """获取推理大模型的工具执行错误类型""" + template = _env.from_string(TOOL_EXECUTE_ERROR_TYPE_ANALYSIS) + prompt = template.render( + goal=user_goal, + current_plan=current_plan.model_dump(exclude_none=True, by_alias=True), + tool_name=tool.name, + tool_description=tool.description, + input_param=input_param, + error_message=error_message, + ) + result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) + return result + + @staticmethod + async def _parse_tool_execute_error_type_result(result: str) -> ToolExcutionErrorType: + """将推理结果解析为工具执行错误类型""" + schema = ToolExcutionErrorType.model_json_schema() + error_type = await MCPPlanner._parse_result(result, schema) + # 使用ToolExcutionErrorType模型解析结果 + return ToolExcutionErrorType.model_validate(error_type) + + @staticmethod + async def get_tool_execute_error_type( + user_goal: str, current_plan: MCPPlan, + tool: MCPTool, input_param: dict[str, Any], + error_message: str, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> ToolExcutionErrorType: + """获取MCP工具执行错误类型""" + # 获取推理结果 + result = await MCPPlanner._get_reasoning_tool_execute_error_type( + user_goal, current_plan, tool, input_param, error_message, reasoning_llm) + error_type = await MCPPlanner._parse_tool_execute_error_type_result(result) + # 返回工具执行错误类型 + return error_type + + @staticmethod + async def is_param_error( + goal: str, history: str, error_message: str, tool: MCPTool, step_description: str, input_params: dict + [str, Any], + reasoning_llm: ReasoningLLM = ReasoningLLM()) -> IsParamError: + """判断错误信息是否是参数错误""" + tmplate = _env.from_string(IS_PARAM_ERROR) + prompt = tmplate.render( + goal=goal, + history=history, + step_id=tool.id, + step_name=tool.name, + step_description=step_description, + input_params=input_params, + error_message=error_message, + ) + result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) + # 解析为结构化数据 + schema = IsParamError.model_json_schema() + is_param_error = await MCPPlanner._parse_result(result, schema) + # 使用IsParamError模型解析结果 + return IsParamError.model_validate(is_param_error) + + @staticmethod + async def change_err_message_to_description( + error_message: str, tool: MCPTool, input_params: dict[str, Any], + reasoning_llm: ReasoningLLM = ReasoningLLM()) -> str: + """将错误信息转换为工具描述""" + template = _env.from_string(CHANGE_ERROR_MESSAGE_TO_DESCRIPTION) + prompt = template.render( + error_message=error_message, + tool_name=tool.name, + tool_description=tool.description, + input_schema=tool.input_schema, + input_params=input_params, + ) + result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) + return result + + @staticmethod + async def get_missing_param( + tool: MCPTool, + input_param: dict[str, Any], + error_message: str, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> list[str]: + """获取缺失的参数""" + slot = Slot(schema=tool.input_schema) + template = _env.from_string(GET_MISSING_PARAMS) + schema_with_null = slot.add_null_to_basic_types() + prompt = template.render( + tool_name=tool.name, + tool_description=tool.description, + input_param=input_param, + schema=schema_with_null, + error_message=error_message, + ) + result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) + # 解析为结构化数据 + input_param_with_null = await MCPPlanner._parse_result(result, schema_with_null) + return input_param_with_null + + @staticmethod + async def generate_answer( + user_goal: str, memory: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> AsyncGenerator[ + str, None]: + """生成最终回答""" + template = _env.from_string(FINAL_ANSWER) + prompt = template.render( + memory=memory, + goal=user_goal, + ) + async for chunk in resoning_llm.call( + [{"role": "user", "content": prompt}], + streaming=True, + temperature=0.07, + ): + yield chunk diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..31e8ce5415dad9bca6ef8e5069a125975070d658 --- /dev/null +++ b/apps/scheduler/mcp_agent/prompt.py @@ -0,0 +1,1083 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""MCP相关的大模型Prompt""" + +from textwrap import dedent + +MCP_SELECT = dedent(r""" + 你是一个乐于助人的智能助手。 + 你的任务是:根据当前目标,选择最合适的MCP Server。 + + ## 选择MCP Server时的注意事项: + + 1. 确保充分理解当前目标,选择最合适的MCP Server。 + 2. 请在给定的MCP Server列表中选择,不要自己生成MCP Server。 + 3. 请先给出你选择的理由,再给出你的选择。 + 4. 当前目标将在下面给出,MCP Server列表也会在下面给出。 + 请将你的思考过程放在"思考过程"部分,将你的选择放在"选择结果"部分。 + 5. 选择必须是JSON格式,严格按照下面的模板,不要输出任何其他内容: + + ```json + { + "mcp": "你选择的MCP Server的名称" + } + ``` + + 6. 下面的示例仅供参考,不要将示例中的内容作为选择MCP Server的依据。 + + ## 示例 + + ### 目标 + + 我需要一个MCP Server来完成一个任务。 + + ### MCP Server列表 + + - **mcp_1**: "MCP Server 1";MCP Server 1的描述 + - **mcp_2**: "MCP Server 2";MCP Server 2的描述 + + ### 请一步一步思考: + + 因为当前目标需要一个MCP Server来完成一个任务,所以选择mcp_1。 + + ### 选择结果 + + ```json + { + "mcp": "mcp_1" + } + ``` + + ## 现在开始! + + ### 目标 + + {{goal}} + + ### MCP Server列表 + + {% for mcp in mcp_list %} + - **{{mcp.id}}**: "{{mcp.name}}";{{mcp.description}} + {% endfor %} + + ### 请一步一步思考: + +""") +TOOL_SELECT = dedent(r""" + 你是一个乐于助人的智能助手。 + 你的任务是:根据当前目标,附加信息,选择最合适的MCP工具。 + ## 选择MCP工具时的注意事项: + 1. 确保充分理解当前目标,选择实现目标所需的MCP工具。 + 2. 不要选择不存在的工具。 + 3. 可以选择一些辅助工具,但必须确保这些工具与当前目标相关。 + 4. 注意,返回的工具ID必须是MCP工具的ID,而不是名称。 + 5. 可以多选择一些工具,用于应对不同的情况。 + 必须按照以下格式生成选择结果,不要输出任何其他内容: + ```json + { + "tool_ids": ["工具ID1", "工具ID2", ...] + } + ``` + + # 示例 + ## 目标 + 调优mysql性能 + ## MCP工具列表 + + - mcp_tool_1 MySQL链接池工具;用于优化MySQL链接池 + - mcp_tool_2 MySQL性能调优工具;用于分析MySQL性能瓶颈 + - mcp_tool_3 MySQL查询优化工具;用于优化MySQL查询语句 + - mcp_tool_4 MySQL索引优化工具;用于优化MySQL索引 + - mcp_tool_5 文件存储工具;用于存储文件 + - mcp_tool_6 mongoDB工具;用于操作MongoDB数据库 + + ## 附加信息 + 1. 当前MySQL数据库的版本是8.0.26 + 2. 当前MySQL数据库的配置文件路径是/etc/my.cnf,并含有以下配置项 + ```json + { + "max_connections": 1000, + "innodb_buffer_pool_size": "1G", + "query_cache_size": "64M" + } + ##输出 + ```json + { + "tool_ids": ["mcp_tool_1", "mcp_tool_2", "mcp_tool_3", "mcp_tool_4"] + } + ``` + # 现在开始! + ## 目标 + {{goal}} + ## MCP工具列表 + + {% for tool in tools %} + - {{tool.id}} {{tool.name}};{{tool.description}} + {% endfor %} + + ## 附加信息 + {{additional_info}} + # 输出 + """ + ) + +EVALUATE_GOAL = dedent(r""" + 你是一个计划评估器。 + 请根据用户的目标和当前的工具集合以及一些附加信息,判断基于当前的工具集合,是否能够完成用户的目标。 + 如果能够完成,请返回`true`,否则返回`false`。 + 推理过程必须清晰明了,能够让人理解你的判断依据。 + 必须按照以下格式回答: + ```json + { + "can_complete": true/false, + "resoning": "你的推理过程" + } + ``` + + # 样例 + # 目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 + + # 工具集合 + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + + - mysql_analyzer 分析MySQL数据库性能 + - performance_tuner 调优数据库性能 + - Final 结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。 + + + # 附加信息 + 1. 当前MySQL数据库的版本是8.0.26 + 2. 当前MySQL数据库的配置文件路径是/etc/my.cnf + + ## + ```json + { + "can_complete": true, + "resoning": "当前的工具集合中包含mysql_analyzer和performance_tuner,能够完成对MySQL数据库的性能分析和调优,因此可以完成用户的目标。" + } + ``` + + # 目标 + {{goal}} + + # 工具集合 + + {% for tool in tools %} + - {{tool.id}} {{tool.name}};{{tool.description}} + {% endfor %} + + + # 附加信息 + {{additional_info}} + +""") +GENERATE_FLOW_NAME = dedent(r""" + 你是一个智能助手,你的任务是根据用户的目标,生成一个合适的流程名称。 + + # 生成流程名称时的注意事项: + 1. 流程名称应该简洁明了,能够准确表达达成用户目标的过程。 + 2. 流程名称应该包含关键的操作或步骤,例如“扫描”、“分析”、“调优”等。 + 3. 流程名称应该避免使用过于复杂或专业的术语,以便用户能够理解。 + 4. 流程名称应该尽量简短,小于20个字或者单词。 + 5. 只输出流程名称,不要输出其他内容。 + # 样例 + # 目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 + # 输出 + 扫描MySQL数据库并分析性能瓶颈,进行调优 + # 现在开始生成流程名称: + # 目标 + {{goal}} + # 输出 + """) +GET_REPLAN_START_STEP_INDEX = dedent(r""" + 你是一个智能助手,你的任务是根据用户的目标、报错信息和当前计划和历史,获取重新规划的步骤起始索引。 + + # 样例 + # 目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 + # 报错信息 + 执行端口扫描命令时,出现了错误:`- bash: curl: command not found`。 + # 当前计划 + ```json + { + "plans": [ + { + "step_id": "step_1", + "content": "生成端口扫描命令", + "tool": "command_generator", + "instruction": "生成端口扫描命令:扫描 + }, + { + "step_id": "step_2", + "content": "在执行Result[0]生成的命令", + "tool": "command_executor", + "instruction": "执行端口扫描命令" + } + ] + } + # 历史 + [ + { + id: "0", + task_id: "task_1", + flow_id: "flow_1", + flow_name: "MYSQL性能调优", + flow_status: "RUNNING", + step_id: "step_1", + step_name: "生成端口扫描命令", + step_description: "生成端口扫描命令:扫描当前MySQL数据库的端口", + step_status: "FAILED", + input_data: { + "command": "nmap -p 3306 + "target": "localhost" + }, + output_data: { + "error": "- bash: curl: command not found" + } + } + ] + # 输出 + { + "start_index": 0, + "reasoning": "当前计划的第一步就失败了,报错信息显示curl命令未找到,可能是因为没有安装curl工具,因此需要从第一步重新规划。" + } + # 现在开始获取重新规划的步骤起始索引: + # 目标 + {{goal}} + # 报错信息 + {{error_message}} + # 当前计划 + {{current_plan}} + # 历史 + {{history}} + # 输出 + """) + +CREATE_PLAN = dedent(r""" + 你是一个计划生成器。 + 请分析用户的目标,并生成一个计划。你后续将根据这个计划,一步一步地完成用户的目标。 + + # 一个好的计划应该: + + 1. 能够成功完成用户的目标 + 2. 计划中的每一个步骤必须且只能使用一个工具。 + 3. 计划中的步骤必须具有清晰和逻辑的步骤,没有冗余或不必要的步骤。 + 4. 不要选择不存在的工具。 + 5. 计划中的最后一步必须是Final工具,以确保计划执行结束。 + 6. 生成的计划必须要覆盖用户的目标,当然需要考虑一些意外情况,可以有一定的冗余步骤。 + + # 生成计划时的注意事项: + + - 每一条计划包含3个部分: + - 计划内容:描述单个计划步骤的大致内容 + - 工具ID:必须从下文的工具列表中选择 + - 工具指令:改写用户的目标,使其更符合工具的输入要求 + - 必须按照如下格式生成计划,不要输出任何额外数据: + + ```json + { + "plans": [ + { + "content": "计划内容", + "tool": "工具ID", + "instruction": "工具指令" + } + ] + } + ``` + + - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。 +思考过程应放置在 XML标签中。 + - 计划内容中,可以使用"Result[]"来引用之前计划步骤的结果。例如:"Result[3]"表示引用第三条计划执行后的结果。 + - 计划不得多于{{max_num}}条,且每条计划内容应少于150字。 + + # 工具 + + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + + + {% for tool in tools %} + - {{tool.id}} {{tool.name}};{{tool.description}} + {% endfor %} + + + # 样例 + + # 目标 + + 在后台运行一个新的alpine: latest容器,将主机/root文件夹挂载至/data,并执行top命令。 + + # 计划 + + + 1. 这个目标需要使用Docker来完成, 首先需要选择合适的MCP Server + 2. 目标可以拆解为以下几个部分: + - 运行alpine: latest容器 + - 挂载主机目录 + - 在后台运行 + - 执行top命令 + 3. 需要先选择MCP Server, 然后生成Docker命令, 最后执行命令 + ```json + { + "plans": [ + { + "content": "选择一个支持Docker的MCP Server", + "tool": "mcp_selector", + "instruction": "需要一个支持Docker容器运行的MCP Server" + }, + { + "content": "使用第一步选择的MCP Server,生成Docker命令", + "tool": "command_generator", + "instruction": "生成Docker命令:在后台运行alpine:latest容器,挂载/root到/data,执行top命令" + }, + { + "content": "执行第二步生成的Docker命令", + "tool": "command_executor", + "instruction": "执行Docker命令" + }, + { + "content": "任务执行完成,容器已在后台运行", + "tool": "Final", + "instruction": "" + } + ] + } + ``` + + # 现在开始生成计划: + + # 目标 + + {{goal}} + + # 计划 +""") +RECREATE_PLAN = dedent(r""" + 你是一个计划重建器。 + 请根据用户的目标、当前计划和运行报错,重新生成一个计划。 + + # 一个好的计划应该: + + 1. 能够成功完成用户的目标 + 2. 计划中的每一个步骤必须且只能使用一个工具。 + 3. 计划中的步骤必须具有清晰和逻辑的步骤,没有冗余或不必要的步骤。 + 4. 你的计划必须避免之前的错误,并且能够成功执行。 + 5. 不要选择不存在的工具。 + 6. 计划中的最后一步必须是Final工具,以确保计划执行结束。 + 7. 生成的计划必须要覆盖用户的目标,当然需要考虑一些意外情况,可以有一定的冗余步骤。 + # 生成计划时的注意事项: + + - 每一条计划包含3个部分: + - 计划内容:描述单个计划步骤的大致内容 + - 工具ID:必须从下文的工具列表中选择 + - 工具指令:改写用户的目标,使其更符合工具的输入要求 + - 必须按照如下格式生成计划,不要输出任何额外数据: + + ```json + { + "plans": [ + { + "content": "计划内容", + "tool": "工具ID", + "instruction": "工具指令" + } + ] + } + ``` + + - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。 +思考过程应放置在 XML标签中。 + - 计划内容中,可以使用"Result[]"来引用之前计划步骤的结果。例如:"Result[3]"表示引用第三条计划执行后的结果。 + - 计划不得多于{{max_num}}条,且每条计划内容应少于150字。 + + # 样例 + + # 目标 + + 请帮我扫描一下192.168.1.1的这台机器的端口,看看有哪些端口开放。 + # 工具 + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + + - command_generator 生成命令行指令 + - tool_selector 选择合适的工具 + - command_executor 执行命令行指令 + - Final 结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。 + # 当前计划 + ```json + { + "plans": [ + { + "content": "生成端口扫描命令", + "tool": "command_generator", + "instruction": "生成端口扫描命令:扫描192.168.1.1的开放端口" + }, + { + "content": "在执行第一步生成的命令", + "tool": "command_executor", + "instruction": "执行端口扫描命令" + }, + { + "content": "任务执行完成", + "tool": "Final", + "instruction": "" + } + ] + } + ``` + # 运行报错 + 执行端口扫描命令时,出现了错误:`- bash: curl: command not found`。 + # 重新生成的计划 + + + 1. 这个目标需要使用网络扫描工具来完成, 首先需要选择合适的网络扫描工具 + 2. 目标可以拆解为以下几个部分: + - 生成端口扫描命令 + - 执行端口扫描命令 + 3.但是在执行端口扫描命令时,出现了错误:`- bash: curl: command not found`。 + 4.我将计划调整为: + - 需要先生成一个命令,查看当前机器支持哪些网络扫描工具 + - 执行这个命令,查看当前机器支持哪些网络扫描工具 + - 然后从中选择一个网络扫描工具 + - 基于选择的网络扫描工具,生成端口扫描命令 + - 执行端口扫描命令 + ```json + { + "plans": [ + { + "content": "需要生成一条命令查看当前机器支持哪些网络扫描工具", + "tool": "command_generator", + "instruction": "选择一个前机器支持哪些网络扫描工具" + }, + { + "content": "执行第一步中生成的命令,查看当前机器支持哪些网络扫描工具", + "tool": "command_executor", + "instruction": "执行第一步中生成的命令" + }, + { + "content": "从第二步执行结果中选择一个网络扫描工具,生成端口扫描命令", + "tool": "tool_selector", + "instruction": "选择一个网络扫描工具,生成端口扫描命令" + }, + { + "content": "基于第三步中选择的网络扫描工具,生成端口扫描命令", + "tool": "command_generator", + "instruction": "生成端口扫描命令:扫描192.168.1.1的开放端口" + }, + { + "content": "执行第四步中生成的端口扫描命令", + "tool": "command_executor", + "instruction": "执行端口扫描命令" + }, + { + "content": "任务执行完成", + "tool": "Final", + "instruction": "" + } + ] + } + ``` + + # 现在开始重新生成计划: + + # 目标 + + {{goal}} + + # 工具 + + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + + + {% for tool in tools %} + - {{tool.id}} {{tool.name}};{{tool.description}} + {% endfor %} + + + # 当前计划 + {{current_plan}} + + # 运行报错 + {{error_message}} + + # 重新生成的计划 +""") +GEN_STEP = dedent(r""" + 你是一个计划生成器。 + 请根据用户的目标、当前计划和历史,生成一个新的步骤。 + + # 一个好的计划步骤应该: + 1.使用最适合的工具来完成当前步骤。 + 2.能够基于当前的计划和历史,完成阶段性的任务。 + 3.不要选择不存在的工具。 + 4.如果你认为当前已经达成了用户的目标,可以直接返回Final工具,表示计划执行结束。 + + # 样例 1 + # 目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优,我的ip是192.168.1.1,数据库端口是3306,用户名是root,密码是password + # 历史记录 + 第1步:生成端口扫描命令 + - 调用工具 `command_generator`,并提供参数 `帮我生成一个mysql端口扫描命令` + - 执行状态:成功 + - 得到数据:`{"command": "nmap -sS -p--open 192.168.1.1"}` + 第2步:执行端口扫描命令 + - 调用工具 `command_executor`,并提供参数 `{"command": "nmap -sS -p--open 192.168.1.1"}` + - 执行状态:成功 + - 得到数据:`{"result": "success"}` + # 工具 + + - mcp_tool_1 mysql_analyzer;用于分析数据库性能/description> + - mcp_tool_2 文件存储工具;用于存储文件 + - mcp_tool_3 mongoDB工具;用于操作MongoDB数据库 + - Final 结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。 + + # 输出 + ```json + { + "tool_id": "mcp_tool_1", // 选择的工具ID + "description": "扫描ip为192.168.1.1的MySQL数据库,端口为3306,用户名为root,密码为password的数据库性能", + } + ``` + # 样例二 + # 目标 + 计划从杭州到北京的旅游计划 + # 历史记录 + 第1步:将杭州转换为经纬度坐标 + - 调用工具 `maps_geo_planner`,并提供参数 `{"city_from": "杭州", "address": "西湖"}` + - 执行状态:成功 + - 得到数据:`{"location": "123.456, 78.901"}` + 第2步:查询杭州的天气 + - 调用工具 `weather_query`,并提供参数 `{"location": "123.456, 78.901"}` + - 执行状态:成功 + - 得到数据:`{"weather": "晴", "temperature": "25°C"}` + 第3步:将北京转换为经纬度坐标 + - 调用工具 `maps_geo_planner`,并提供参数 `{"city_from": "北京", "address": "天安门"}` + - 执行状态:成功 + - 得到数据:`{"location": "123.456, 78.901"}` + 第4步:查询北京的天气 + - 调用工具 `weather_query`,并提供参数 `{"location": "123.456, 78.901"}` + - 执行状态:成功 + - 得到数据:`{"weather": "晴", "temperature": "25°C"}` + # 工具 + + - mcp_tool_4 maps_geo_planner;将详细的结构化地址转换为经纬度坐标。支持对地标性名胜景区、建筑物名称解析为经纬度坐标 + - mcp_tool_5 weather_query;天气查询,用于查询天气信息 + - mcp_tool_6 maps_direction_transit_integrated;根据用户起终点经纬度坐标规划综合各类公共(火车、公交、地铁)交通方式的通勤方案,并且返回通勤方案的数据,跨城场景下必须传起点城市与终点城市 + - Final Final;结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。 + + # 输出 + ```json + { + "tool_id": "mcp_tool_6", // 选择的工具ID + "description": "规划从杭州到北京的综合公共交通方式的通勤方案" + } + ``` + # 现在开始生成步骤: + # 目标 + {{goal}} + # 历史记录 + {{history}} + # 工具 + + {% for tool in tools %} + - {{tool.id}} {{tool.name}};{{tool.description}} + {% endfor %} + +""") + +TOOL_SKIP = dedent(r""" + 你是一个计划执行器。 + 你的任务是根据当前的计划和用户目标,判断当前步骤是否需要跳过。 + 如果需要跳过,请返回`true`,否则返回`false`。 + 必须按照以下格式回答: + ```json + { + "skip": true/false, + } + ``` + 注意: + 1.你的判断要谨慎,在历史消息中有足够的上下文信息时,才可以判断是否跳过当前步骤。 + # 样例 + # 用户目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 + # 历史 + 第1步:生成端口扫描命令 + - 调用工具 `command_generator`,并提供参数 `{"command": "nmap -sS -p--open 192.168.1.1"}` + - 执行状态:成功 + - 得到数据:`{"command": "nmap -sS -p--open 192.168.1.1"}` + 第2步:执行端口扫描命令 + - 调用工具 `command_executor`,并提供参数 `{"command": "nmap -sS -p--open 192.168.1.1"}` + - 执行状态:成功 + - 得到数据:`{"result": "success"}` + 第3步:分析端口扫描结果 + - 调用工具 `mysql_analyzer`,并提供参数 `{"host": "192.168.1.1", "port": 3306, "username": "root", "password": "password"}` + - 执行状态:成功 + - 得到数据:`{"performance": "good", "bottleneck": "none"}` + # 当前步骤 + + step_4 + command_generator + 生成MySQL性能调优命令 + 生成MySQL性能调优命令:调优MySQL数据库性能 + + # 输出 + ```json + { + "skip": true + } + ``` + # 用户目标 + {{goal}} + # 历史 + {{history}} + # 当前步骤 + + {{step_id}} + {{step_name}} + {{step_instruction}} + {{step_content}} + + # 输出 + """ + ) +RISK_EVALUATE = dedent(r""" + 你是一个工具执行计划评估器。 + 你的任务是根据当前工具的名称、描述和入参以及附加信息,判断当前工具执行的风险并输出提示。 + ```json + { + "risk": "low/medium/high", + "reason": "提示信息" + } + ``` + # 样例 + # 工具名称 + mysql_analyzer + # 工具描述 + 分析MySQL数据库性能 + # 工具入参 + { + "host": "192.0.0.1", + "port": 3306, + "username": "root", + "password": "password" + } + # 附加信息 + 1. 当前MySQL数据库的版本是8.0.26 + 2. 当前MySQL数据库的配置文件路径是/etc/my.cnf,并含有以下配置项 + ```ini + [mysqld] + innodb_buffer_pool_size=1G + innodb_log_file_size=256M + ``` + # 输出 + ```json + { + "risk": "中", + "reason": "当前工具将连接到MySQL数据库并分析性能,可能会对数据库性能产生一定影响。请确保在非生产环境中执行此操作。" + } + ``` + # 工具 + < tool > + < name > {{tool_name}} < /name > + < description > {{tool_description}} < /description > + < / tool > + # 工具入参 + {{input_param}} + # 附加信息 + {{additional_info}} + # 输出 + """ + ) +# 根据当前计划和报错信息决定下一步执行,具体计划有需要用户补充工具入参、重计划当前步骤、重计划接下来的所有计划 +TOOL_EXECUTE_ERROR_TYPE_ANALYSIS = dedent(r""" + 你是一个计划决策器。 + 你的任务是根据用户目标、当前计划、当前使用的工具、工具入参和工具运行报错,决定下一步执行的操作。 + 请根据以下规则进行判断: + 1. 仅通过补充工具入参来解决问题的,返回 missing_param; + 2. 需要重计划当前步骤的,返回 decorrect_plan + 3.推理过程必须清晰明了,能够让人理解你的判断依据,并且不超过100字。 + 你的输出要以json格式返回,格式如下: + ```json + { + "error_type": "missing_param/decorrect_plan, + "reason": "你的推理过程" + } + ``` + # 样例 + # 用户目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 + # 当前计划 + {"plans": [ + { + "content": "生成端口扫描命令", + "tool": "command_generator", + "instruction": "生成端口扫描命令:扫描192.168.1.1的开放端口" + }, + { + "content": "在执行Result[0]生成的命令", + "tool": "command_executor", + "instruction": "执行端口扫描命令" + }, + { + "content": "任务执行完成,端口扫描结果为Result[2]", + "tool": "Final", + "instruction": "" + } + ]} + # 当前使用的工具 + < tool > + < name > command_executor < /name > + < description > 执行命令行指令 < /description > + < / tool > + # 工具入参 + { + "command": "nmap -sS -p--open 192.168.1.1" + } + # 工具运行报错 + 执行端口扫描命令时,出现了错误:`- bash: nmap: command not found`。 + # 输出 + ```json + { + "error_type": "decorrect_plan", + "reason": "当前计划的第二步执行失败,报错信息显示nmap命令未找到,可能是因为没有安装nmap工具,因此需要重计划当前步骤。" + } + ``` + # 用户目标 + {{goal}} + # 当前计划 + {{current_plan}} + # 当前使用的工具 + < tool > + < name > {{tool_name}} < /name > + < description > {{tool_description}} < /description > + < / tool > + # 工具入参 + {{input_param}} + # 工具运行报错 + {{error_message}} + # 输出 + """ + ) +IS_PARAM_ERROR = dedent(r""" + 你是一个计划执行专家,你的任务是判断当前的步骤执行失败是否是因为参数错误导致的, + 如果是,请返回`true`,否则返回`false`。 + 必须按照以下格式回答: + ```json + { + "is_param_error": true/false, + } + ``` + # 样例 + # 用户目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 + # 历史 + 第1步:生成端口扫描命令 + - 调用工具 `command_generator`,并提供参数 `{"command": "nmap -sS -p--open 192.168.1.1"}` + - 执行状态:成功 + - 得到数据:`{"command": "nmap -sS -p--open 192.168.1.1"}` + 第2步:执行端口扫描命令 + - 调用工具 `command_executor`,并提供参数 `{"command": "nmap -sS -p--open 192.168.1.1"}` + - 执行状态:成功 + - 得到数据:`{"result": "success"}` + # 当前步骤 + + step_3 + mysql_analyzer + 分析MySQL数据库性能 + + # 工具入参 + { + "host": "192.0.0.1", + "port": 3306, + "username": "root", + "password": "password" + } + # 工具运行报错 + 执行MySQL性能分析命令时,出现了错误:`host is not correct`。 + + # 输出 + ```json + { + "is_param_error": true + } + ``` + # 用户目标 + {{goal}} + # 历史 + {{history}} + # 当前步骤 + + {{step_id}} + {{step_name}} + {{step_instruction}} + + # 工具入参 + {{input_param}} + # 工具运行报错 + {{error_message}} + # 输出 + """ + ) +# 将当前程序运行的报错转换为自然语言 +CHANGE_ERROR_MESSAGE_TO_DESCRIPTION = dedent(r""" + 你是一个智能助手,你的任务是将当前程序运行的报错转换为自然语言描述。 + 请根据以下规则进行转换: + 1. 将报错信息转换为自然语言描述,描述应该简洁明了,能够让人理解报错的原因和影响。 + 2. 描述应该包含报错的具体内容和可能的解决方案。 + 3. 描述应该避免使用过于专业的术语,以便用户能够理解。 + 4. 描述应该尽量简短,控制在50字以内。 + 5. 只输出自然语言描述,不要输出其他内容。 + # 样例 + # 工具信息 + < tool > + < name > port_scanner < /name > + < description > 扫描主机端口 < /description > + < input_schema > + { + "type": "object", + "properties": { + "host": { + "type": "string", + "description": "主机地址" + }, + "port": { + "type": "integer", + "description": "端口号" + }, + "username": { + "type": "string", + "description": "用户名" + }, + "password": { + "type": "string", + "description": "密码" + } + }, + "required": ["host", "port", "username", "password"] + } + < /input_schema > + < / tool > + # 工具入参 + { + "host": "192.0.0.1", + "port": 3306, + "username": "root", + "password": "password" + } + # 报错信息 + 执行端口扫描命令时,出现了错误:`password is not correct`。 + # 输出 + 扫描端口时发生错误:密码不正确。请检查输入的密码是否正确,并重试。 + # 现在开始转换报错信息: + # 工具信息 + < tool > + < name > {{tool_name}} < /name > + < description > {{tool_description}} < /description > + < input_schema > + {{input_schema}} + < /input_schema > + < / tool > + # 工具入参 + {{input_params}} + # 报错信息 + {{error_message}} + # 输出 + """) +# 获取缺失的参数的json结构体 +GET_MISSING_PARAMS = dedent(r""" + 你是一个工具参数获取器。 + 你的任务是根据当前工具的名称、描述和入参和入参的schema以及运行报错,将当前缺失的参数设置为null,并输出一个JSON格式的字符串。 + ```json + { + "host": "请补充主机地址", + "port": "请补充端口号", + "username": "请补充用户名", + "password": "请补充密码" + } + ``` + # 样例 + # 工具名称 + mysql_analyzer + # 工具描述 + 分析MySQL数据库性能 + # 工具入参 + { + "host": "192.0.0.1", + "port": 3306, + "username": "root", + "password": "password" + } + # 工具入参schema + { + "type": "object", + "properties": { + "host": { + "anyOf": [ + {"type": "string"}, + {"type": "null"} + ], + "description": "MySQL数据库的主机地址(可以为字符串或null)" + }, + "port": { + "anyOf": [ + {"type": "string"}, + {"type": "null"} + ], + "description": "MySQL数据库的端口号(可以是数字、字符串或null)" + }, + "username": { + "anyOf": [ + {"type": "string"}, + {"type": "null"} + ], + "description": "MySQL数据库的用户名(可以为字符串或null)" + }, + "password": { + "anyOf": [ + {"type": "string"}, + {"type": "null"} + ], + "description": "MySQL数据库的密码(可以为字符串或null)" + } + }, + "required": ["host", "port", "username", "password"] + } + # 运行报错 + 执行端口扫描命令时,出现了错误:`password is not correct`。 + # 输出 + ```json + { + "host": "192.0.0.1", + "port": 3306, + "username": null, + "password": null + } + ``` + # 工具 + < tool > + < name > {{tool_name}} < /name > + < description > {{tool_description}} < /description > + < / tool > + # 工具入参 + {{input_param}} + # 工具入参schema(部分字段允许为null) + {{input_schema}} + # 运行报错 + {{error_message}} + # 输出 + """ + ) +REPAIR_PARAMS = dedent(r""" + 你是一个工具参数修复器。 + 你的任务是根据当前的工具信息、目标、工具入参的schema、工具当前的入参、工具的报错、补充的参数和补充的参数描述,修复当前工具的入参。 + + 注意: + 1.最终修复的参数要符合目标和工具入参的schema。 + + # 样例 + # 工具信息 + < tool > + < name > mysql_analyzer < /name > + < description > 分析MySQL数据库性能 < /description > + < / tool > + # 总目标 + 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 + # 当前阶段目标 + 我要连接MySQL数据库,分析性能瓶颈,并调优。 + # 工具入参的schema + { + "type": "object", + "properties": { + "host": { + "type": "string", + "description": "MySQL数据库的主机地址" + }, + "port": { + "type": "integer", + "description": "MySQL数据库的端口号" + }, + "username": { + "type": "string", + "description": "MySQL数据库的用户名" + }, + "password": { + "type": "string", + "description": "MySQL数据库的密码" + } + }, + "required": ["host", "port", "username", "password"] + } + # 工具当前的入参 + { + "host": "192.0.0.1", + "port": 3306, + "username": "root", + "password": "password" + } + # 工具的报错 + 执行端口扫描命令时,出现了错误:`password is not correct`。 + # 补充的参数 + { + "username": "admin", + "password": "admin123" + } + # 补充的参数描述 + 用户希望使用admin用户和admin123密码来连接MySQL数据库。 + # 输出 + ```json + { + "host": "192.0.0.1", + "port": 3306, + "username": "admin", + "password": "admin123" + } + ``` + # 工具 + < tool > + < name > {{tool_name}} < /name > + < description > {{tool_description}} < /description > + < / tool > + # 总目标 + {{goal}} + # 当前阶段目标 + {{current_goal}} + # 工具入参scheme + {{input_schema}} + # 工具入参 + {{input_param}} + # 工具描述 + {{tool_description}} + # 运行报错 + {{error_message}} + # 补充的参数 + {{params}} + # 补充的参数描述 + {{params_description}} + # 输出 + """ + ) +FINAL_ANSWER = dedent(r""" + 综合理解计划执行结果和背景信息,向用户报告目标的完成情况。 + + # 用户目标 + + {{goal}} + + # 计划执行情况 + + 为了完成上述目标,你实施了以下计划: + + {{memory}} + + + # 现在,请根据以上信息,向用户报告目标的完成情况: + +""") +MEMORY_TEMPLATE = dedent(r""" + { % for ctx in context_list % } + - 第{{loop.index}}步:{{ctx.step_description}} + - 调用工具 `{{ctx.step_id}}`,并提供参数 `{{ctx.input_data}}` + - 执行状态:{{ctx.step_status}} + - 得到数据:`{{ctx.output_data}}` + { % endfor % } +""") diff --git a/apps/scheduler/mcp_agent/schema.py b/apps/scheduler/mcp_agent/schema.py deleted file mode 100644 index 614139074382daf128a19a320c949e5e46803c4d..0000000000000000000000000000000000000000 --- a/apps/scheduler/mcp_agent/schema.py +++ /dev/null @@ -1,148 +0,0 @@ -"""MCP Agent执行数据结构""" -from typing import Any, Self - -from pydantic import BaseModel, Field - -from apps.schemas.enum_var import Role - - -class Function(BaseModel): - """工具函数""" - - name: str - arguments: dict[str, Any] - - -class ToolCall(BaseModel): - """Represents a tool/function call in a message""" - - id: str - type: str = "function" - function: Function - - -class Message(BaseModel): - """Represents a chat message in the conversation""" - - role: Role = Field(...) - content: str | None = Field(default=None) - tool_calls: list[ToolCall] | None = Field(default=None) - name: str | None = Field(default=None) - tool_call_id: str | None = Field(default=None) - - def __add__(self, other) -> list["Message"]: - """支持 Message + list 或 Message + Message 的操作""" - if isinstance(other, list): - return [self] + other - elif isinstance(other, Message): - return [self, other] - else: - raise TypeError( - f"unsupported operand type(s) for +: '{type(self).__name__}' and '{type(other).__name__}'" - ) - - def __radd__(self, other) -> list["Message"]: - """支持 list + Message 的操作""" - if isinstance(other, list): - return other + [self] - else: - raise TypeError( - f"unsupported operand type(s) for +: '{type(other).__name__}' and '{type(self).__name__}'" - ) - - def to_dict(self) -> dict: - """Convert message to dictionary format""" - message = {"role": self.role} - if self.content is not None: - message["content"] = self.content - if self.tool_calls is not None: - message["tool_calls"] = [tool_call.dict() for tool_call in self.tool_calls] - if self.name is not None: - message["name"] = self.name - if self.tool_call_id is not None: - message["tool_call_id"] = self.tool_call_id - return message - - @classmethod - def user_message(cls, content: str) -> Self: - """Create a user message""" - return cls(role=Role.USER, content=content) - - @classmethod - def system_message(cls, content: str) -> Self: - """Create a system message""" - return cls(role=Role.SYSTEM, content=content) - - @classmethod - def assistant_message( - cls, content: str | None = None, - ) -> Self: - """Create an assistant message""" - return cls(role=Role.ASSISTANT, content=content) - - @classmethod - def tool_message( - cls, content: str, name: str, tool_call_id: str, - ) -> Self: - """Create a tool message""" - return cls( - role=Role.TOOL, - content=content, - name=name, - tool_call_id=tool_call_id, - ) - - @classmethod - def from_tool_calls( - cls, - tool_calls: list[Any], - content: str | list[str] = "", - **kwargs, # noqa: ANN003 - ) -> Self: - """Create ToolCallsMessage from raw tool calls. - - Args: - tool_calls: Raw tool calls from LLM - content: Optional message content - """ - formatted_calls = [ - {"id": call.id, "function": call.function.model_dump(), "type": "function"} - for call in tool_calls - ] - return cls( - role=Role.ASSISTANT, - content=content, - tool_calls=formatted_calls, - **kwargs, - ) - - -class Memory(BaseModel): - messages: list[Message] = Field(default_factory=list) - max_messages: int = Field(default=100) - - def add_message(self, message: Message) -> None: - """Add a message to memory""" - self.messages.append(message) - # Optional: Implement message limit - if len(self.messages) > self.max_messages: - self.messages = self.messages[-self.max_messages:] - - def add_messages(self, messages: list[Message]) -> None: - """Add multiple messages to memory""" - self.messages.extend(messages) - # Optional: Implement message limit - if len(self.messages) > self.max_messages: - self.messages = self.messages[-self.max_messages:] - - def clear(self) -> None: - """Clear all messages""" - self.messages.clear() - - def get_recent_messages(self, n: int) -> list[Message]: - """Get n most recent messages""" - return self.messages[-n:] - - def to_dict_list(self) -> list[dict]: - """Convert messages to list of dicts""" - return [msg.to_dict() for msg in self.messages] diff --git a/apps/scheduler/mcp_agent/select.py b/apps/scheduler/mcp_agent/select.py new file mode 100644 index 0000000000000000000000000000000000000000..075e08f00cb2b5a2976d5372c623193b04c42412 --- /dev/null +++ b/apps/scheduler/mcp_agent/select.py @@ -0,0 +1,109 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""选择MCP Server及其工具""" + +import logging +import random +from jinja2 import BaseLoader +from jinja2.sandbox import SandboxedEnvironment +from typing import AsyncGenerator, Any + +from apps.llm.function import JsonGenerator +from apps.llm.reasoning import ReasoningLLM +from apps.common.lance import LanceDB +from apps.common.mongo import MongoDB +from apps.llm.embedding import Embedding +from apps.llm.function import FunctionLLM +from apps.llm.reasoning import ReasoningLLM +from apps.llm.token import TokenCalculator +from apps.scheduler.mcp_agent.base import McpBase +from apps.scheduler.mcp_agent.prompt import TOOL_SELECT +from apps.schemas.mcp import ( + BaseModel, + MCPCollection, + MCPSelectResult, + MCPTool, + MCPToolIdsSelectResult +) +from apps.common.config import Config +logger = logging.getLogger(__name__) + +_env = SandboxedEnvironment( + loader=BaseLoader, + autoescape=True, + trim_blocks=True, + lstrip_blocks=True, +) + +FINAL_TOOL_ID = "FIANL" +SUMMARIZE_TOOL_ID = "SUMMARIZE" + + +class MCPSelector(McpBase): + """MCP选择器""" + + @staticmethod + async def select_top_tool( + goal: str, tool_list: list[MCPTool], + additional_info: str | None = None, top_n: int | None = None, + reasoning_llm: ReasoningLLM | None = None) -> list[MCPTool]: + """选择最合适的工具""" + random.shuffle(tool_list) + max_tokens = reasoning_llm._config.max_tokens + template = _env.from_string(TOOL_SELECT) + token_calculator = TokenCalculator() + if token_calculator.calculate_token_length( + messages=[{"role": "user", "content": template.render( + goal=goal, tools=[], additional_info=additional_info + )}], + pure_text=True) > max_tokens: + logger.warning("[MCPSelector] 工具选择模板长度超过最大令牌数,无法进行选择") + return [] + current_index = 0 + tool_ids = [] + while current_index < len(tool_list): + index = current_index + sub_tools = [] + while index < len(tool_list): + tool = tool_list[index] + tokens = token_calculator.calculate_token_length( + messages=[{"role": "user", "content": template.render( + goal=goal, tools=[tool], + additional_info=additional_info + )}], + pure_text=True + ) + if tokens > max_tokens: + continue + sub_tools.append(tool) + + tokens = token_calculator.calculate_token_length(messages=[{"role": "user", "content": template.render( + goal=goal, tools=sub_tools, additional_info=additional_info)}, ], pure_text=True) + if tokens > max_tokens: + del sub_tools[-1] + break + else: + index += 1 + current_index = index + if sub_tools: + schema = MCPToolIdsSelectResult.model_json_schema() + if "items" not in schema["properties"]["tool_ids"]: + schema["properties"]["tool_ids"]["items"] = {} + # 将enum添加到items中,限制数组元素的可选值 + schema["properties"]["tool_ids"]["items"]["enum"] = [tool.id for tool in sub_tools] + result = await MCPSelector.get_resoning_result(template.render(goal=goal, tools=sub_tools, additional_info="请根据目标选择对应的工具"), reasoning_llm) + result = await MCPSelector._parse_result(result, schema) + try: + result = MCPToolIdsSelectResult.model_validate(result) + tool_ids.extend(result.tool_ids) + except Exception: + logger.exception("[MCPSelector] 解析MCP工具ID选择结果失败") + continue + mcp_tools = [tool for tool in tool_list if tool.id in tool_ids] + + if top_n is not None: + mcp_tools = mcp_tools[:top_n] + mcp_tools.append(MCPTool(id=FINAL_TOOL_ID, name="Final", + description="终止", mcp_id=FINAL_TOOL_ID, input_schema={})) + # mcp_tools.append(MCPTool(id=SUMMARIZE_TOOL_ID, name="Summarize", + # description="总结工具", mcp_id=SUMMARIZE_TOOL_ID, input_schema={})) + return mcp_tools diff --git a/apps/scheduler/mcp_agent/tool/__init__.py b/apps/scheduler/mcp_agent/tool/__init__.py deleted file mode 100644 index 4593f31742fee21b2e3ec1c7c18ff8e3cfea2110..0000000000000000000000000000000000000000 --- a/apps/scheduler/mcp_agent/tool/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from apps.scheduler.mcp_agent.tool.base import BaseTool -from apps.scheduler.mcp_agent.tool.terminate import Terminate -from apps.scheduler.mcp_agent.tool.tool_collection import ToolCollection - -__all__ = [ - "BaseTool", - "Terminate", - "ToolCollection", -] diff --git a/apps/scheduler/mcp_agent/tool/base.py b/apps/scheduler/mcp_agent/tool/base.py deleted file mode 100644 index 04ad45c47a3eecb25efdf5b2ce52beb6965b2fbd..0000000000000000000000000000000000000000 --- a/apps/scheduler/mcp_agent/tool/base.py +++ /dev/null @@ -1,73 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Dict, Optional - -from pydantic import BaseModel, Field - - -class BaseTool(ABC, BaseModel): - name: str - description: str - parameters: Optional[dict] = None - - class Config: - arbitrary_types_allowed = True - - async def __call__(self, **kwargs) -> Any: - return await self.execute(**kwargs) - - @abstractmethod - async def execute(self, **kwargs) -> Any: - """使用给定的参数执行工具""" - - def to_param(self) -> Dict: - """将工具转换为函数调用格式""" - return { - "type": "function", - "function": { - "name": self.name, - "description": self.description, - "parameters": self.parameters, - }, - } - - -class ToolResult(BaseModel): - """表示工具执行的结果""" - - output: Any = Field(default=None) - error: Optional[str] = Field(default=None) - system: Optional[str] = Field(default=None) - - class Config: - arbitrary_types_allowed = True - - def __bool__(self): - return any(getattr(self, field) for field in self.__fields__) - - def __add__(self, other: "ToolResult"): - def combine_fields( - field: Optional[str], other_field: Optional[str], concatenate: bool = True - ): - if field and other_field: - if concatenate: - return field + other_field - raise ValueError("Cannot combine tool results") - return field or other_field - - return ToolResult( - output=combine_fields(self.output, other.output), - error=combine_fields(self.error, other.error), - system=combine_fields(self.system, other.system), - ) - - def __str__(self): - return f"Error: {self.error}" if self.error else self.output - - def replace(self, **kwargs): - """返回一个新的ToolResult,其中替换了给定的字段""" - # return self.copy(update=kwargs) - return type(self)(**{**self.dict(), **kwargs}) - - -class ToolFailure(ToolResult): - """表示失败的ToolResult""" diff --git a/apps/scheduler/mcp_agent/tool/terminate.py b/apps/scheduler/mcp_agent/tool/terminate.py deleted file mode 100644 index 84aa120316de1123f985eebfd82c471e8ceec990..0000000000000000000000000000000000000000 --- a/apps/scheduler/mcp_agent/tool/terminate.py +++ /dev/null @@ -1,25 +0,0 @@ -from apps.scheduler.mcp_agent.tool.base import BaseTool - - -_TERMINATE_DESCRIPTION = """当请求得到满足或助理无法继续处理任务时,终止交互。 -当您完成所有任务后,调用此工具结束工作。""" - - -class Terminate(BaseTool): - name: str = "terminate" - description: str = _TERMINATE_DESCRIPTION - parameters: dict = { - "type": "object", - "properties": { - "status": { - "type": "string", - "description": "交互的完成状态", - "enum": ["success", "failure"], - } - }, - "required": ["status"], - } - - async def execute(self, status: str) -> str: - """Finish the current execution""" - return f"交互已完成,状态为: {status}" diff --git a/apps/scheduler/mcp_agent/tool/tool_collection.py b/apps/scheduler/mcp_agent/tool/tool_collection.py deleted file mode 100644 index 95bda317805abdecc256af0737091b46adf77b1a..0000000000000000000000000000000000000000 --- a/apps/scheduler/mcp_agent/tool/tool_collection.py +++ /dev/null @@ -1,55 +0,0 @@ -"""用于管理多个工具的集合类""" -import logging -from typing import Any - -from apps.scheduler.mcp_agent.tool.base import BaseTool, ToolFailure, ToolResult - -logger = logging.getLogger(__name__) - - -class ToolCollection: - """定义工具的集合""" - - class Config: - arbitrary_types_allowed = True - - def __init__(self, *tools: BaseTool): - self.tools = tools - self.tool_map = {tool.name: tool for tool in tools} - - def __iter__(self): - return iter(self.tools) - - def to_params(self) -> list[dict[str, Any]]: - return [tool.to_param() for tool in self.tools] - - async def execute( - self, *, name: str, tool_input: dict[str, Any] = None - ) -> ToolResult: - tool = self.tool_map.get(name) - if not tool: - return ToolFailure(error=f"Tool {name} is invalid") - try: - result = await tool(**tool_input) - return result - except Exception as e: - return ToolFailure(error=f"Failed to execute tool {name}: {e}") - - def add_tool(self, tool: BaseTool): - """ - 将单个工具添加到集合中。 - - 如果已存在同名工具,则将跳过该工具并记录警告。 - """ - if tool.name in self.tool_map: - logger.warning(f"Tool {tool.name} already exists in collection, skipping") - return self - - self.tools += (tool,) - self.tool_map[tool.name] = tool - return self - - def add_tools(self, *tools: BaseTool): - for tool in tools: - self.add_tool(tool) - return self diff --git a/apps/scheduler/pool/loader/app.py b/apps/scheduler/pool/loader/app.py index a85395bfc7723af9d009a1945fe4617cf390b938..49407ba95358d4442c714087fd48b1aaca219f8d 100644 --- a/apps/scheduler/pool/loader/app.py +++ b/apps/scheduler/pool/loader/app.py @@ -78,6 +78,7 @@ class AppLoader: # 加载模型 try: metadata = AgentAppMetadata.model_validate(metadata) + logger.info(f"[AppLoader] Agent应用元数据验证成功: {metadata}") except Exception as e: err = "[AppLoader] Agent应用元数据验证失败" logger.exception(err) @@ -102,7 +103,6 @@ class AppLoader: await file_checker.diff_one(app_path) await self.load(app_id, file_checker.hashes[f"app/{app_id}"]) - @staticmethod async def delete(app_id: str, *, is_reload: bool = False) -> None: """ @@ -157,5 +157,7 @@ class AppLoader: }, upsert=True, ) + app_pool = await app_collection.find_one({"_id": metadata.id}) + logger.error(f"[AppLoader] 更新 MongoDB 成功: {app_pool}") except Exception: logger.exception("[AppLoader] 更新 MongoDB 失败") diff --git a/apps/scheduler/pool/loader/mcp.py b/apps/scheduler/pool/loader/mcp.py index 1463d0a1a9141bca271e29d7387759748aa8bd43..1da06b53630d3188ca55c16969823f3d1e063516 100644 --- a/apps/scheduler/pool/loader/mcp.py +++ b/apps/scheduler/pool/loader/mcp.py @@ -11,6 +11,7 @@ import shutil import asyncer from anyio import Path from sqids.sqids import Sqids +from typing import Any from apps.common.lance import LanceDB from apps.common.mongo import MongoDB @@ -91,48 +92,47 @@ class MCPLoader(metaclass=SingletonMeta): :param MCPServerConfig config: MCP配置 :return: 无 """ - if not config.config.auto_install: - print(f"[Installer] MCP模板无需安装: {mcp_id}") # noqa: T201 - - elif isinstance(config.config, MCPServerStdioConfig): - print(f"[Installer] Stdio方式的MCP模板,开始自动安装: {mcp_id}") # noqa: T201 - if "uv" in config.config.command: - new_config = await install_uvx(mcp_id, config.config) - elif "npx" in config.config.command: - new_config = await install_npx(mcp_id, config.config) - - if new_config is None: - logger.error("[MCPLoader] MCP模板安装失败: %s", mcp_id) - await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.FAILED) - return - - config.config = new_config - - # 重新保存config - template_config = MCP_PATH / "template" / mcp_id / "config.json" - f = await template_config.open("w+", encoding="utf-8") - config_data = config.model_dump(by_alias=True, exclude_none=True) - await f.write(json.dumps(config_data, indent=4, ensure_ascii=False)) - await f.aclose() - - else: - print(f"[Installer] SSE/StreamableHTTP方式的MCP模板,无需安装: {mcp_id}") # noqa: T201 - config.config.auto_install = False - - print(f"[Installer] MCP模板安装成功: {mcp_id}") # noqa: T201 - await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.READY) - await MCPLoader._insert_template_tool(mcp_id, config) + try: + if not config.config.auto_install: + print(f"[Installer] MCP模板无需安装: {mcp_id}") # noqa: T201 + elif isinstance(config.config, MCPServerStdioConfig): + print(f"[Installer] Stdio方式的MCP模板,开始自动安装: {mcp_id}") # noqa: T201 + if "uv" in config.config.command: + new_config = await install_uvx(mcp_id, config.config) + elif "npx" in config.config.command: + new_config = await install_npx(mcp_id, config.config) + + if new_config is None: + logger.error("[MCPLoader] MCP模板安装失败: %s", mcp_id) + await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.FAILED) + return + + config.config = new_config + + # 重新保存config + template_config = MCP_PATH / "template" / mcp_id / "config.json" + f = await template_config.open("w+", encoding="utf-8") + config_data = config.model_dump(by_alias=True, exclude_none=True) + await f.write(json.dumps(config_data, indent=4, ensure_ascii=False)) + await f.aclose() + + else: + logger.info(f"[Installer] SSE/StreamableHTTP方式的MCP模板,无需安装: {mcp_id}") # noqa: T201 + config.config.auto_install = False + + await MCPLoader._insert_template_tool(mcp_id, config) + await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.READY) + logger.info(f"[Installer] MCP模板安装成功: {mcp_id}") # noqa: T201 + except Exception as e: + logger.error("[MCPLoader] MCP模板安装失败: %s, 错误: %s", mcp_id, e) + await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.FAILED) + raise @staticmethod - async def init_one_template(mcp_id: str, config: MCPServerConfig) -> None: + async def clear_ready_or_failed_mcp_installation() -> None: """ - 初始化单个MCP模板 - - :param str mcp_id: MCP模板ID - :param MCPServerConfig config: MCP配置 - :return: 无 + 清除状态为ready或failed的MCP安装任务 """ - # 删除完成或者失败的MCP安装任务 mcp_collection = MongoDB().get_collection("mcp") mcp_ids = ProcessHandler.get_all_task_ids() # 检索_id在mcp_ids且状态为ready或者failed的MCP的内容 @@ -147,8 +147,17 @@ class MCPLoader(metaclass=SingletonMeta): continue ProcessHandler.remove_task(item.id) logger.info("[MCPLoader] 删除已完成或失败的MCP安装进程: %s", item.id) - # 插入数据库;这里用旧的config就可以 - await MCPLoader._insert_template_db(mcp_id, config) + + @staticmethod + async def init_one_template(mcp_id: str, config: MCPServerConfig) -> None: + """ + 初始化单个MCP模板 + + :param str mcp_id: MCP模板ID + :param MCPServerConfig config: MCP配置 + :return: 无 + """ + await MCPLoader.clear_ready_or_failed_mcp_installation() # 检查目录 template_path = MCP_PATH / "template" / mcp_id @@ -158,37 +167,31 @@ class MCPLoader(metaclass=SingletonMeta): err = f"安装任务无法执行,请稍后重试: {mcp_id}" logger.error(err) raise RuntimeError(err) + # 将installing状态的安装任务的状态变为cancelled @staticmethod - async def _init_all_template() -> None: + async def cancel_all_installing_task() -> None: """ - 初始化所有MCP模板 - - 遍历 ``template`` 目录下的所有MCP模板,并初始化。在Framework启动时进行此流程,确保所有MCP均可正常使用。 - 这一过程会与数据库内的条目进行对比,若发生修改,则重新创建数据库条目。 + 取消正在安装的MCP模板任务 """ template_path = MCP_PATH / "template" logger.info("[MCPLoader] 初始化所有MCP模板: %s", template_path) - + mongo = MongoDB() + mcp_collection = mongo.get_collection("mcp") # 遍历所有模板 + mcp_ids = [] async for mcp_dir in template_path.iterdir(): # 不是目录 if not await mcp_dir.is_dir(): logger.warning("[MCPLoader] 跳过非目录: %s", mcp_dir.as_posix()) continue - # 检查配置文件是否存在 - config_path = mcp_dir / "config.json" - if not await config_path.exists(): - logger.warning("[MCPLoader] 跳过没有配置文件的MCP模板: %s", mcp_dir.as_posix()) - continue - - # 读取配置并加载 - config = await MCPLoader._load_config(config_path) - - # 初始化第一个MCP Server - logger.info("[MCPLoader] 初始化MCP模板: %s", mcp_dir.as_posix()) - await MCPLoader.init_one_template(mcp_dir.name, config) + mcp_ids.append(mcp_dir.name) + # 更新数据库状态 + await mcp_collection.update_many( + {"_id": {"$in": mcp_ids}, "status": MCPInstallStatus.INSTALLING}, + {"$set": {"status": MCPInstallStatus.CANCELLED}}, + ) @staticmethod async def _get_template_tool( @@ -384,6 +387,7 @@ class MCPLoader(metaclass=SingletonMeta): # 更新数据库 mongo = MongoDB() mcp_collection = mongo.get_collection("mcp") + logger.info("[MCPLoader] 更新MCP模板状态: %s -> %s", mcp_id, status) await mcp_collection.update_one( {"_id": mcp_id}, {"$set": {"status": status}}, @@ -391,7 +395,7 @@ class MCPLoader(metaclass=SingletonMeta): ) @staticmethod - async def user_active_template(user_sub: str, mcp_id: str) -> None: + async def user_active_template(user_sub: str, mcp_id: str, mcp_env: dict[str, Any] | None = None) -> None: """ 用户激活MCP模板 @@ -409,7 +413,7 @@ class MCPLoader(metaclass=SingletonMeta): if await user_path.exists(): err = f"MCP模板“{mcp_id}”已存在或有同名文件,无法激活" raise FileExistsError(err) - + mcp_config = await MCPLoader.get_config(mcp_id) # 拷贝文件 await asyncer.asyncify(shutil.copytree)( template_path.as_posix(), @@ -417,7 +421,32 @@ class MCPLoader(metaclass=SingletonMeta): dirs_exist_ok=True, symlinks=True, ) - + if mcp_env is not None: + mcp_config.config.env.update(mcp_env) + if mcp_config.type == MCPType.STDIO: + index = None + for i in range(len(mcp_config.config.args)): + if mcp_config.config.args[i] == "--directory": + index = i + 1 + break + if index is not None: + if index < len(mcp_config.config.args): + mcp_config.config.args[index] = str(user_path)+'/project' + else: + mcp_config.config.args.append(str(user_path)+'/project') + else: + mcp_config.config.args = ["--directory", str(user_path)+'/project'] + mcp_config.config.args + user_config_path = user_path / "config.json" + # 更新用户配置 + f = await user_config_path.open("w", encoding="utf-8", errors="ignore") + await f.write( + json.dumps( + mcp_config.model_dump(by_alias=True, exclude_none=True), + indent=4, + ensure_ascii=False, + ) + ) + await f.aclose() # 更新数据库 mongo = MongoDB() mcp_collection = mongo.get_collection("mcp") @@ -468,6 +497,26 @@ class MCPLoader(metaclass=SingletonMeta): logger.info("[MCPLoader] 这些MCP在文件系统中被删除: %s", deleted_mcp_list) return deleted_mcp_list + @staticmethod + async def cancel_installing_task(cancel_mcp_list: list[str]) -> None: + """ + 取消正在安装的MCP模板任务 + + :param list[str] cancel_mcp_list: 需要取消的MCP列表 + :return: 无 + """ + mongo = MongoDB() + mcp_collection = mongo.get_collection("mcp") + # 更新数据库状态 + cancel_mcp_list = await mcp_collection.distinct("_id", {"_id": {"$in": cancel_mcp_list}, "status": MCPInstallStatus.INSTALLING}) + await mcp_collection.update_many( + {"_id": {"$in": cancel_mcp_list}, "status": MCPInstallStatus.INSTALLING}, + {"$set": {"status": MCPInstallStatus.CANCELLED}}, + ) + for mcp_id in cancel_mcp_list: + ProcessHandler.remove_task(mcp_id) + logger.info("[MCPLoader] 取消这些正在安装的MCP模板任务: %s", cancel_mcp_list) + @staticmethod async def remove_deleted_mcp(deleted_mcp_list: list[str]) -> None: """ @@ -573,8 +622,8 @@ class MCPLoader(metaclass=SingletonMeta): # 检查目录 await MCPLoader._check_dir() - # 初始化所有模板 - await MCPLoader._init_all_template() + # 暂停所有安装任务 + await MCPLoader.cancel_all_installing_task() # 加载用户MCP await MCPLoader._load_user_mcp() diff --git a/apps/scheduler/pool/loader/service.py b/apps/scheduler/pool/loader/service.py index 2b9060461fc0f3baaece19e88c71776449e9752a..2d84069c434de99e45156452a1e5156ba03b2b3e 100644 --- a/apps/scheduler/pool/loader/service.py +++ b/apps/scheduler/pool/loader/service.py @@ -3,6 +3,7 @@ import asyncio import logging +import os import shutil from anyio import Path @@ -30,6 +31,9 @@ class ServiceLoader: """加载单个Service""" service_path = BASE_PATH / service_id # 载入元数据 + if not os.path.exists(service_path / "metadata.yaml"): + logger.error("[ServiceLoader] Service %s 的元数据不存在", service_id) + return metadata = await MetadataLoader().load_one(service_path / "metadata.yaml") if not isinstance(metadata, ServiceMetadata): err = f"[ServiceLoader] 元数据类型错误: {service_path}/metadata.yaml" @@ -48,7 +52,6 @@ class ServiceLoader: # 更新数据库 await self._update_db(nodes, metadata) - async def save(self, service_id: str, metadata: ServiceMetadata, data: dict) -> None: """在文件系统上保存Service,并更新数据库""" service_path = BASE_PATH / service_id @@ -67,7 +70,6 @@ class ServiceLoader: await file_checker.diff_one(service_path) await self.load(service_id, file_checker.hashes[f"service/{service_id}"]) - async def delete(self, service_id: str, *, is_reload: bool = False) -> None: """删除Service,并更新数据库""" mongo = MongoDB() @@ -95,7 +97,6 @@ class ServiceLoader: if await path.exists(): shutil.rmtree(path) - async def _update_db(self, nodes: list[NodePool], metadata: ServiceMetadata) -> None: # noqa: C901, PLR0912, PLR0915 """更新数据库""" if not metadata.hashes: @@ -197,4 +198,3 @@ class ServiceLoader: await asyncio.sleep(0.01) else: raise - diff --git a/apps/scheduler/pool/mcp/client.py b/apps/scheduler/pool/mcp/client.py index 092bac8909635a5e0c846dddef3456d8ad3be43c..0ced05e8cc68bb773cfe61eb9268a4f8baab8fa5 100644 --- a/apps/scheduler/pool/mcp/client.py +++ b/apps/scheduler/pool/mcp/client.py @@ -29,6 +29,7 @@ class MCPClient: mcp_id: str task: asyncio.Task ready_sign: asyncio.Event + error_sign: asyncio.Event stop_sign: asyncio.Event client: ClientSession status: MCPStatus @@ -54,9 +55,10 @@ class MCPClient: """ # 创建Client if isinstance(config, MCPServerSSEConfig): + env = config.env or {} client = sse_client( url=config.url, - headers=config.env, + headers=env, ) elif isinstance(config, MCPServerStdioConfig): if user_sub: @@ -64,7 +66,6 @@ class MCPClient: else: cwd = MCP_PATH / "template" / mcp_id / "project" await cwd.mkdir(parents=True, exist_ok=True) - client = stdio_client(server=StdioServerParameters( command=config.command, args=config.args, @@ -72,6 +73,7 @@ class MCPClient: cwd=cwd.as_posix(), )) else: + self.error_sign.set() err = f"[MCPClient] MCP {mcp_id}:未知的MCP服务类型“{config.type}”" logger.error(err) raise TypeError(err) @@ -85,23 +87,24 @@ class MCPClient: # 初始化Client await session.initialize() except Exception: + self.error_sign.set() + self.status = MCPStatus.STOPPED logger.exception("[MCPClient] MCP %s:初始化失败", mcp_id) raise self.ready_sign.set() self.status = MCPStatus.RUNNING - # 等待关闭信号 await self.stop_sign.wait() + logger.error("[MCPClient] MCP %s:收到停止信号,正在关闭", mcp_id) # 关闭Client try: - await exit_stack.aclose() # type: ignore[attr-defined] + await exit_stack.aclose() # type: ignore[attr-defined] self.status = MCPStatus.STOPPED except Exception: logger.exception("[MCPClient] MCP %s:关闭失败", mcp_id) - async def init(self, user_sub: str | None, mcp_id: str, config: MCPServerSSEConfig | MCPServerStdioConfig) -> None: """ 初始化 MCP Client类 @@ -116,27 +119,34 @@ class MCPClient: # 初始化变量 self.mcp_id = mcp_id self.ready_sign = asyncio.Event() + self.error_sign = asyncio.Event() self.stop_sign = asyncio.Event() # 创建协程 self.task = asyncio.create_task(self._main_loop(user_sub, mcp_id, config)) # 等待初始化完成 - await self.ready_sign.wait() - - # 获取工具列表 + done, pending = await asyncio.wait( + [asyncio.create_task(self.ready_sign.wait()), + asyncio.create_task(self.error_sign.wait())], + return_when=asyncio.FIRST_COMPLETED + ) + if self.error_sign.is_set(): + self.status = MCPStatus.ERROR + logger.error("[MCPClient] MCP %s:初始化失败", mcp_id) + raise Exception(f"MCP {mcp_id} 初始化失败") + + # 获取工具列表 self.tools = (await self.client.list_tools()).tools - async def call_tool(self, tool_name: str, params: dict) -> "CallToolResult": """调用MCP Server的工具""" return await self.client.call_tool(tool_name, params) - async def stop(self) -> None: """停止MCP Client""" self.stop_sign.set() try: await self.task - except Exception: - logger.exception("[MCPClient] MCP %s:停止失败", self.mcp_id) + except Exception as e: + logger.warning("[MCPClient] MCP %s:停止时发生异常:%s", self.mcp_id, e) diff --git a/apps/scheduler/pool/mcp/install.py b/apps/scheduler/pool/mcp/install.py index 1b6c3edeb3a042b7716d0d7748e9c3e50b01af74..2b15cd690bfd8b1939326728647e27ab517f89ff 100644 --- a/apps/scheduler/pool/mcp/install.py +++ b/apps/scheduler/pool/mcp/install.py @@ -3,12 +3,16 @@ from asyncio import subprocess from typing import TYPE_CHECKING - +import logging +import os +import shutil from apps.constants import MCP_PATH if TYPE_CHECKING: from apps.schemas.mcp import MCPServerStdioConfig +logger = logging.getLogger(__name__) + async def install_uvx(mcp_id: str, config: "MCPServerStdioConfig") -> "MCPServerStdioConfig | None": """ @@ -23,27 +27,35 @@ async def install_uvx(mcp_id: str, config: "MCPServerStdioConfig") -> "MCPServer :rtype: MCPServerStdioConfig :raises ValueError: 未找到MCP Server对应的Python包 """ - # 创建文件夹 - mcp_path = MCP_PATH / "template" / mcp_id / "project" - await mcp_path.mkdir(parents=True, exist_ok=True) - + uv_path = shutil.which('uv') + if uv_path is None: + error = "[Installer] 未找到uv命令,请先安装uv包管理器: pip install uv" + logging.error(error) + raise Exception(error) # 找到包名 - package = "" + package = None for arg in config.args: if not arg.startswith("-"): package = arg break - + logger.error(f"[Installer] MCP包名: {package}") if not package: print("[Installer] 未找到包名") # noqa: T201 return None - + # 创建文件夹 + mcp_path = MCP_PATH / "template" / mcp_id / "project" + logger.error(f"[Installer] MCP安装路径: {mcp_path}") + await mcp_path.mkdir(parents=True, exist_ok=True) # 如果有pyproject.toml文件,则使用sync + flag = await (mcp_path / "pyproject.toml").exists() + logger.error(f"[Installer] MCP安装标志: {flag}") if await (mcp_path / "pyproject.toml").exists(): + shell_command = f"{uv_path} venv; {uv_path} sync --index-url https://pypi.tuna.tsinghua.edu.cn/simple --active --no-install-project --no-cache" + logger.error(f"[Installer] MCP安装命令: {shell_command}") pipe = await subprocess.create_subprocess_shell( ( - "uv venv; " - "uv sync --index-url https://pypi.tuna.tsinghua.edu.cn/simple --active " + f"{uv_path} venv; " + f"{uv_path} sync --index-url https://pypi.tuna.tsinghua.edu.cn/simple --active " "--no-install-project --no-cache" ), stdout=subprocess.PIPE, @@ -57,19 +69,20 @@ async def install_uvx(mcp_id: str, config: "MCPServerStdioConfig") -> "MCPServer return None print(f"[Installer] 检查依赖成功: {mcp_path}; {stdout.decode() if stdout else '(无输出信息)'}") # noqa: T201 - config.command = "uv" - config.args = ["run", *config.args] + config.command = uv_path + if "run" not in config.args: + config.args = ["run", *config.args] config.auto_install = False - + logger.error(f"[Installer] MCP安装配置更新成功: {config}") return config # 否则,初始化uv项目 pipe = await subprocess.create_subprocess_shell( ( - f"uv init; " - f"uv venv; " - f"uv add --index-url https://pypi.tuna.tsinghua.edu.cn/simple {package}; " - f"uv sync --index-url https://pypi.tuna.tsinghua.edu.cn/simple --active " + f"{uv_path} init; " + f"{uv_path} venv; " + f"{uv_path} add --index-url https://pypi.tuna.tsinghua.edu.cn/simple {package}; " + f"{uv_path} sync --index-url https://pypi.tuna.tsinghua.edu.cn/simple --active " f"--no-install-project --no-cache" ), stdout=subprocess.PIPE, @@ -84,8 +97,9 @@ async def install_uvx(mcp_id: str, config: "MCPServerStdioConfig") -> "MCPServer print(f"[Installer] 安装 {package} 成功: {mcp_path}; {stdout.decode() if stdout else '(无输出信息)'}") # noqa: T201 # 更新配置 - config.command = "uv" - config.args = ["run", *config.args] + config.command = uv_path + if "run" not in config.args: + config.args = ["run", *config.args] config.auto_install = False return config @@ -103,17 +117,13 @@ async def install_npx(mcp_id: str, config: "MCPServerStdioConfig") -> "MCPServer :rtype: MCPServerStdioConfig :raises ValueError: 未找到MCP Server对应的npm包 """ - mcp_path = MCP_PATH / "template" / mcp_id / "project" - await mcp_path.mkdir(parents=True, exist_ok=True) - - # 如果有node_modules文件夹,则认为已安装 - if await (mcp_path / "node_modules").exists(): - config.command = "npm" - config.args = ["exec", *config.args] - return config - + npm_path = shutil.which('npm') + if npm_path is None: + error = "[Installer] 未找到npm命令,请先安装Node.js和npm" + logging.error(error) + raise Exception(error) # 查找package name - package = "" + package = None for arg in config.args: if not arg.startswith("-"): package = arg @@ -122,10 +132,18 @@ async def install_npx(mcp_id: str, config: "MCPServerStdioConfig") -> "MCPServer if not package: print("[Installer] 未找到包名") # noqa: T201 return None + mcp_path = MCP_PATH / "template" / mcp_id / "project" + await mcp_path.mkdir(parents=True, exist_ok=True) + # 如果有node_modules文件夹,则认为已安装 + if await (mcp_path / "node_modules").exists(): + config.command = npm_path + if "exec" not in config.args: + config.args = ["exec", *config.args] + return config # 安装NPM包 pipe = await subprocess.create_subprocess_shell( - f"npm install {package}", + f"{npm_path} install {package}", stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=mcp_path, @@ -137,8 +155,9 @@ async def install_npx(mcp_id: str, config: "MCPServerStdioConfig") -> "MCPServer print(f"[Installer] 安装 {package} 成功: {mcp_path}; {stdout.decode() if stdout else '(无输出信息)'}") # noqa: T201 # 更新配置 - config.command = "npm" - config.args = ["exec", *config.args] + config.command = npm_path + if "exec" not in config.args: + config.args = ["exec", *config.args] config.auto_install = False return config diff --git a/apps/scheduler/pool/mcp/pool.py b/apps/scheduler/pool/mcp/pool.py index 91cde4d9d6c83fd5088328e12cfbb6bf06d3c7cb..bf0320f429a9ef45864ba6548b1bb28e3d874b59 100644 --- a/apps/scheduler/pool/mcp/pool.py +++ b/apps/scheduler/pool/mcp/pool.py @@ -21,16 +21,13 @@ class MCPPool(metaclass=SingletonMeta): """初始化MCP池""" self.pool = {} - async def _init_mcp(self, mcp_id: str, user_sub: str) -> MCPClient | None: """初始化MCP池""" - mcp_math = MCP_USER_PATH / user_sub / mcp_id / "project" config_path = MCP_USER_PATH / user_sub / mcp_id / "config.json" - - if not await mcp_math.exists() or not await mcp_math.is_dir(): - logger.warning("[MCPPool] 用户 %s 的MCP %s 未激活", user_sub, mcp_id) + flag = (await config_path.exists()) + if not flag: + logger.warning("[MCPPool] 用户 %s 的MCP %s 配置文件不存在", user_sub, mcp_id) return None - config = MCPServerConfig.model_validate_json(await config_path.read_text()) if config.type in (MCPType.SSE, MCPType.STDIO): @@ -40,9 +37,11 @@ class MCPPool(metaclass=SingletonMeta): return None await client.init(user_sub, mcp_id, config.config) + if user_sub not in self.pool: + self.pool[user_sub] = {} + self.pool[user_sub][mcp_id] = client return client - async def _get_from_dict(self, mcp_id: str, user_sub: str) -> MCPClient | None: """从字典中获取MCP客户端""" if user_sub not in self.pool: @@ -53,7 +52,6 @@ class MCPPool(metaclass=SingletonMeta): return self.pool[user_sub][mcp_id] - async def _validate_user(self, mcp_id: str, user_sub: str) -> bool: """验证用户是否已激活""" mongo = MongoDB() @@ -61,7 +59,6 @@ class MCPPool(metaclass=SingletonMeta): mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": user_sub}) return mcp_db_result is not None - async def get(self, mcp_id: str, user_sub: str) -> MCPClient | None: """获取MCP客户端""" item = await self._get_from_dict(mcp_id, user_sub) @@ -83,7 +80,6 @@ class MCPPool(metaclass=SingletonMeta): return item - async def stop(self, mcp_id: str, user_sub: str) -> None: """停止MCP客户端""" await self.pool[user_sub][mcp_id].stop() diff --git a/apps/scheduler/pool/pool.py b/apps/scheduler/pool/pool.py index 7710d24dc102fe02c06d8cbc14f126bd088db2d4..ead552fc2994150137121b103ff013d8ea4d66a5 100644 --- a/apps/scheduler/pool/pool.py +++ b/apps/scheduler/pool/pool.py @@ -60,7 +60,6 @@ class Pool: await Path(root_dir + "mcp").unlink(missing_ok=True) await Path(root_dir + "mcp").mkdir(parents=True, exist_ok=True) - @staticmethod async def init() -> None: """ @@ -121,13 +120,15 @@ class Pool: for app in changed_app: hash_key = Path("app/" + app).as_posix() if hash_key in checker.hashes: - await app_loader.load(app, checker.hashes[hash_key]) - + try: + await app_loader.load(app, checker.hashes[hash_key]) + except Exception as e: + await app_loader.delete(app, is_reload=True) + logger.warning("[Pool] 加载App %s 失败: %s", app, e) # 载入MCP logger.info("[Pool] 载入MCP") await MCPLoader.init() - async def get_flow_metadata(self, app_id: str) -> list[AppFlow]: """从数据库中获取特定App的全部Flow的元数据""" mongo = MongoDB() @@ -145,14 +146,12 @@ class Pool: else: return flow_metadata_list - async def get_flow(self, app_id: str, flow_id: str) -> Flow | None: """从文件系统中获取单个Flow的全部数据""" logger.info("[Pool] 获取工作流 %s", flow_id) flow_loader = FlowLoader() return await flow_loader.load(app_id, flow_id) - async def get_call(self, call_id: str) -> Any: """[Exception] 拿到Call的信息""" # 从MongoDB里拿到数据 diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index b7088d8da45ad01e94e0cf7208db25c5f56c3b22..6e737314822c7d9c5a9671c8670bd5a38db40e3d 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -10,6 +10,7 @@ from apps.llm.patterns.facts import Facts from apps.schemas.collection import Document from apps.schemas.enum_var import StepStatus from apps.schemas.record import ( + FlowHistory, Record, RecordContent, RecordDocument, @@ -114,11 +115,14 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: used_docs.append( RecordGroupDocument( _id=docs["id"], + author=docs.get("author", ""), + order=docs.get("order", 0), name=docs["name"], abstract=docs.get("abstract", ""), extension=docs.get("extension", ""), size=docs.get("size", 0), associated="answer", + created_at=docs.get("created_at", round(datetime.now(UTC).timestamp(), 3)), ) ) if docs.get("order") is not None: @@ -154,7 +158,6 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: facts=task.runtime.facts, data={}, ) - try: # 加密Record数据 encrypt_data, encrypt_config = Security.encrypt(record_content.model_dump_json(by_alias=True)) @@ -185,34 +188,37 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: feature={}, ), createdAt=current_time, - flow=[i["_id"] for i in task.context], + flow=FlowHistory( + flow_id=task.state.flow_id, + flow_name=task.state.flow_name, + flow_status=task.state.flow_status, + history_ids=[context.id for context in task.context], + ) ) # 检查是否存在group_id if not await RecordManager.check_group_id(task.ids.group_id, user_sub): - record_group = await RecordManager.create_record_group( - task.ids.group_id, user_sub, post_body.conversation_id, task.id, + record_group_id = await RecordManager.create_record_group( + task.ids.group_id, user_sub, post_body.conversation_id ) - if not record_group: + if not record_group_id: logger.error("[Scheduler] 创建问答组失败") return else: - record_group = task.ids.group_id + record_group_id = task.ids.group_id # 修改文件状态 - await DocumentManager.change_doc_status(user_sub, post_body.conversation_id, record_group) + await DocumentManager.change_doc_status(user_sub, post_body.conversation_id, record_group_id) # 保存Record - await RecordManager.insert_record_data_into_record_group(user_sub, record_group, record) + await RecordManager.insert_record_data_into_record_group(user_sub, record_group_id, record) # 保存与答案关联的文件 - await DocumentManager.save_answer_doc(user_sub, record_group, used_docs) + await DocumentManager.save_answer_doc(user_sub, record_group_id, used_docs) if post_body.app and post_body.app.app_id: # 更新最近使用的应用 await AppCenterManager.update_recent_app(user_sub, post_body.app.app_id) - # 若状态为成功,删除Task - if not task.state or task.state.status == StepStatus.SUCCESS: + if not task.state or task.state.flow_status == StepStatus.SUCCESS or task.state.flow_status == StepStatus.ERROR or task.state.flow_status == StepStatus.CANCELLED: await TaskManager.delete_task_by_task_id(task.id) else: - # 更新Task await TaskManager.save_task(task.id, task) diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index c89fdd1014ba4714d0777742412e98c0802ca68c..d43ba7fb891d2b1ed7ecdb3378a8ad8e0e512286 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -15,6 +15,7 @@ from apps.schemas.message import ( InitContentFeature, TextAddContent, ) +from apps.schemas.enum_var import FlowStatus from apps.schemas.rag_data import RAGEventData, RAGQueryReq from apps.schemas.record import RecordDocument from apps.schemas.task import Task @@ -61,28 +62,25 @@ async def push_init_message( async def push_rag_message( task: Task, queue: MessageQueue, user_sub: str, llm: LLM, history: list[dict[str, str]], doc_ids: list[str], - rag_data: RAGQueryReq,) -> Task: + rag_data: RAGQueryReq,) -> None: """推送RAG消息""" full_answer = "" - - async for chunk in RAG.chat_with_llm_base_on_rag(user_sub, llm, history, doc_ids, rag_data): - task, content_obj = await _push_rag_chunk(task, queue, chunk) - if content_obj.event_type == EventType.TEXT_ADD.value: - # 如果是文本消息,直接拼接到答案中 - full_answer += content_obj.content - elif content_obj.event_type == EventType.DOCUMENT_ADD.value: - task.runtime.documents.append({ - "id": content_obj.content.get("id", ""), - "order": content_obj.content.get("order", 0), - "name": content_obj.content.get("name", ""), - "abstract": content_obj.content.get("abstract", ""), - "extension": content_obj.content.get("extension", ""), - "size": content_obj.content.get("size", 0), - }) + try: + async for chunk in RAG.chat_with_llm_base_on_rag(user_sub, llm, history, doc_ids, rag_data): + task, content_obj = await _push_rag_chunk(task, queue, chunk) + if content_obj.event_type == EventType.TEXT_ADD.value: + # 如果是文本消息,直接拼接到答案中 + full_answer += content_obj.content + elif content_obj.event_type == EventType.DOCUMENT_ADD.value: + task.runtime.documents.append(content_obj.content) + task.state.flow_status = FlowStatus.SUCCESS + except Exception as e: + logger.error(f"[Scheduler] RAG服务发生错误: {e}") + task.state.flow_status = FlowStatus.ERROR # 保存答案 task.runtime.answer = full_answer + task.tokens.full_time = round(datetime.now(UTC).timestamp(), 2) - task.tokens.time await TaskManager.save_task(task.id, task) - return task async def _push_rag_chunk(task: Task, queue: MessageQueue, content: str) -> tuple[Task, RAGEventData]: @@ -115,10 +113,12 @@ async def _push_rag_chunk(task: Task, queue: MessageQueue, content: str) -> tupl data=DocumentAddContent( documentId=content_obj.content.get("id", ""), documentOrder=content_obj.content.get("order", 0), + documentAuthor=content_obj.content.get("author", ""), documentName=content_obj.content.get("name", ""), documentAbstract=content_obj.content.get("abstract", ""), documentType=content_obj.content.get("extension", ""), documentSize=content_obj.content.get("size", 0), + createdAt=round(datetime.now(tz=UTC).timestamp(), 3), ).model_dump(exclude_none=True, by_alias=True), ) except Exception: diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index ed73638ced241909f55989597f8bb747b50f1945..b524d249a6f55caf779a473ef39db35a463525b9 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -1,9 +1,12 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """Scheduler模块""" +import asyncio import logging from datetime import UTC, datetime - +from apps.llm.reasoning import ReasoningLLM +from apps.schemas.config import LLMConfig +from apps.llm.patterns.rewrite import QuestionRewrite from apps.common.config import Config from apps.common.mongo import MongoDB from apps.common.queue import MessageQueue @@ -17,11 +20,12 @@ from apps.scheduler.scheduler.message import ( push_rag_message, ) from apps.schemas.collection import LLM -from apps.schemas.enum_var import AppType, EventType +from apps.schemas.enum_var import FlowStatus, AppType, EventType from apps.schemas.pool import AppPool from apps.schemas.rag_data import RAGQueryReq from apps.schemas.request_data import RequestData from apps.schemas.scheduler import ExecutorBackground +from apps.services.activity import Activity from apps.schemas.task import Task from apps.services.appcenter import AppCenterManager from apps.services.knowledge import KnowledgeBaseManager @@ -41,12 +45,31 @@ class Scheduler: """初始化""" self.used_docs = [] self.task = task - self.queue = queue self.post_body = post_body - async def run(self) -> None: # noqa: PLR0911 - """运行调度器""" + async def _monitor_activity(self, kill_event): + """监控用户活动状态,不活跃时终止工作流""" + try: + check_interval = 0.5 # 每0.5秒检查一次 + + while not kill_event.is_set(): + # 检查用户活动状态 + is_active = await Activity.is_active(self.task.ids.active_id) + if not is_active: + logger.warning("[Scheduler] 用户 %s 不活跃,终止工作流", self.task.ids.user_sub) + kill_event.set() + break + + # 控制检查频率 + await asyncio.sleep(check_interval) + except asyncio.CancelledError: + logger.info("[Scheduler] 活动监控任务已取消") + except Exception as e: + logger.error(f"[Scheduler] 活动监控过程中发生错误: {e}") + + async def get_llm_use_in_chat_with_rag(self) -> LLM: + """获取RAG大模型""" try: # 获取当前会话使用的大模型 llm_id = await LLMManager.get_llm_id_by_conversation_id( @@ -54,8 +77,7 @@ class Scheduler: ) if not llm_id: logger.error("[Scheduler] 获取大模型ID失败") - await self.queue.close() - return + return None if llm_id == "empty": llm = LLM( _id="empty", @@ -65,24 +87,31 @@ class Scheduler: model_name=Config().get_config().llm.model, max_tokens=Config().get_config().llm.max_tokens, ) + return llm else: llm = await LLMManager.get_llm_by_id(self.task.ids.user_sub, llm_id) if not llm: logger.error("[Scheduler] 获取大模型失败") - await self.queue.close() - return + return None + return llm except Exception: logger.exception("[Scheduler] 获取大模型失败") - await self.queue.close() - return + return None + + async def get_kb_ids_use_in_chat_with_rag(self) -> list[str]: + """获取知识库ID列表""" try: - # 获取当前会话使用的知识库 kb_ids = await KnowledgeBaseManager.get_kb_ids_by_conversation_id( - self.task.ids.user_sub, self.task.ids.conversation_id) + self.task.ids.user_sub, self.task.ids.conversation_id, + ) + return kb_ids except Exception: logger.exception("[Scheduler] 获取知识库ID失败") await self.queue.close() - return + return [] + + async def run(self) -> None: # noqa: PLR0911 + """运行调度器""" try: # 获取当前问答可供关联的文档 docs, doc_ids = await get_docs(self.task.ids.user_sub, self.post_body) @@ -92,18 +121,30 @@ class Scheduler: return history, _ = await get_context(self.task.ids.user_sub, self.post_body, 3) # 已使用文档 - # 如果是智能问答,直接执行 logger.info("[Scheduler] 开始执行") - if not self.post_body.app or self.post_body.app.app_id == "": + # 创建用于通信的事件 + kill_event = asyncio.Event() + monitor = asyncio.create_task(self._monitor_activity(kill_event)) + rag_method = True + if self.post_body.app and self.post_body.app.app_id: + rag_method = False + if self.task.state.app_id: + rag_method = False + if rag_method: + llm = await self.get_llm_use_in_chat_with_rag() + kb_ids = await self.get_kb_ids_use_in_chat_with_rag() self.task = await push_init_message(self.task, self.queue, 3, is_flow=False) rag_data = RAGQueryReq( kbIds=kb_ids, query=self.post_body.question, tokensLimit=llm.max_tokens, ) - self.task = await push_rag_message(self.task, self.queue, self.task.ids.user_sub, llm, history, doc_ids, rag_data) - self.task.tokens.full_time = round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.time + + # 启动监控任务和主任务 + main_task = asyncio.create_task(push_rag_message( + self.task, self.queue, self.task.ids.user_sub, llm, history, doc_ids, rag_data)) + else: # 查找对应的App元数据 app_data = await AppCenterManager.fetch_app_data_by_id(self.post_body.app.app_id) @@ -127,8 +168,27 @@ class Scheduler: conversation=context, facts=facts, ) - await self.run_executor(self.queue, self.post_body, executor_background) + # 启动监控任务和主任务 + main_task = asyncio.create_task(self.run_executor(self.queue, self.post_body, executor_background)) + # 等待任一任务完成 + done, pending = await asyncio.wait( + [main_task, monitor], + return_when=asyncio.FIRST_COMPLETED + ) + + # 如果是监控任务触发,终止主任务 + if kill_event.is_set(): + logger.warning("[Scheduler] 用户活动状态检测不活跃,正在终止工作流执行...") + main_task.cancel() + need_change_cancel_flow_state = [FlowStatus.RUNNING, FlowStatus.WAITING] + if self.task.state.flow_status in need_change_cancel_flow_state: + self.task.state.flow_status = FlowStatus.CANCELLED + try: + await main_task + logger.info("[Scheduler] 工作流执行已被终止") + except Exception as e: + logger.error(f"[Scheduler] 终止工作流时发生错误: {e}") # 更新Task,发送结束消息 logger.info("[Scheduler] 发送结束消息") await self.queue.push_output(self.task, event_type=EventType.DONE.value, data={}) @@ -152,6 +212,37 @@ class Scheduler: if not app_metadata: logger.error("[Scheduler] 未找到Agent应用") return + if app_metadata.llm_id == "empty": + llm = LLM( + _id="empty", + user_sub=self.task.ids.user_sub, + openai_base_url=Config().get_config().llm.endpoint, + openai_api_key=Config().get_config().llm.key, + model_name=Config().get_config().llm.model, + max_tokens=Config().get_config().llm.max_tokens, + ) + else: + llm = await LLMManager.get_llm_by_id( + self.task.ids.user_sub, app_metadata.llm_id, + ) + if not llm: + logger.error("[Scheduler] 获取大模型失败") + await self.queue.close() + return + reasion_llm = ReasoningLLM( + LLMConfig( + endpoint=llm.openai_base_url, + key=llm.openai_api_key, + model=llm.model_name, + max_tokens=llm.max_tokens, + ) + ) + if background.conversation and self.task.state.flow_status == FlowStatus.INIT: + try: + question_obj = QuestionRewrite() + post_body.question = await question_obj.generate(history=background.conversation, question=post_body.question, llm=reasion_llm) + except Exception: + logger.exception("[Scheduler] 问题重写失败") if app_metadata.app_type == AppType.FLOW.value: logger.info("[Scheduler] 获取工作流元数据") flow_info = await Pool().get_flow_metadata(app_info.app_id) @@ -182,7 +273,6 @@ class Scheduler: # 初始化Executor logger.info("[Scheduler] 初始化Executor") - flow_exec = FlowExecutor( flow_id=flow_id, flow=flow_data, @@ -206,10 +296,11 @@ class Scheduler: task=self.task, msg_queue=queue, question=post_body.question, - max_steps=app_metadata.history_len, + history_len=app_metadata.history_len, servers_id=servers_id, background=background, agent_id=app_info.app_id, + params=post_body.params ) # 开始运行 logger.info("[Scheduler] 运行Executor") diff --git a/apps/scheduler/slot/slot.py b/apps/scheduler/slot/slot.py index 4caab4d2dc945ec5ac14b7769770f94c39980a9f..7a6313a1c14c3122fab87905c07ce1b4bc6a0d84 100644 --- a/apps/scheduler/slot/slot.py +++ b/apps/scheduler/slot/slot.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """参数槽位管理""" +import copy import json import logging import traceback @@ -12,6 +13,8 @@ from jsonschema.exceptions import ValidationError from jsonschema.protocols import Validator from jsonschema.validators import extend +from apps.schemas.response_data import ParamsNode +from apps.scheduler.call.choice.schema import Type from apps.scheduler.slot.parser import ( SlotConstParser, SlotDateParser, @@ -126,7 +129,7 @@ class Slot: # Schema标准 return [_process_json_value(item, spec_data["items"]) for item in json_value] if spec_data["type"] == "object" and isinstance(json_value, dict): - # 若Schema不标准,则不进行处理 + # 若Schema不标准,则不进行处理F if "properties" not in spec_data: return json_value # Schema标准 @@ -154,35 +157,60 @@ class Slot: @staticmethod def _generate_example(schema_node: dict) -> Any: # noqa: PLR0911 """根据schema生成示例值""" + if "anyOf" in schema_node or "oneOf" in schema_node: + # 如果有anyOf,随机返回一个示例 + for item in schema_node["anyOf"] if "anyOf" in schema_node else schema_node["oneOf"]: + example = Slot._generate_example(item) + if example is not None: + return example + + if "allOf" in schema_node: + # 如果有allOf,返回所有示例的合并 + example = None + for item in schema_node["allOf"]: + if example is None: + example = Slot._generate_example(item) + else: + other_example = Slot._generate_example(item) + if isinstance(example, dict) and isinstance(other_example, dict): + example.update(other_example) + else: + example = None + break + return example + if "default" in schema_node: return schema_node["default"] if "type" not in schema_node: return None - + type_value = schema_node["type"] + if isinstance(type_value, list): + # 如果是多类型,随机返回一个示例 + if len(type_value) > 1: + type_value = type_value[0] # 处理类型为 object 的节点 - if schema_node["type"] == "object": + if type_value == "object": data = {} properties = schema_node.get("properties", {}) for name, schema in properties.items(): data[name] = Slot._generate_example(schema) return data - # 处理类型为 array 的节点 - if schema_node["type"] == "array": + elif type_value == "array": items_schema = schema_node.get("items", {}) return [Slot._generate_example(items_schema)] # 处理类型为 string 的节点 - if schema_node["type"] == "string": + elif type_value == "string": return "" # 处理类型为 number 或 integer 的节点 - if schema_node["type"] in ["number", "integer"]: + elif type_value in ["number", "integer"]: return 0 # 处理类型为 boolean 的节点 - if schema_node["type"] == "boolean": + elif type_value == "boolean": return False # 处理其他类型或未定义类型 @@ -196,31 +224,114 @@ class Slot: """从JSON Schema中提取类型描述""" def _extract_type_desc(schema_node: dict[str, Any]) -> dict[str, Any]: - if "type" not in schema_node and "anyOf" not in schema_node: - return {} - data = {"type": schema_node.get("type", ""), "description": schema_node.get("description", "")} - if "anyOf" in schema_node: - data["type"] = "anyOf" - # 处理类型为 object 的节点 - if "anyOf" in schema_node: - data["items"] = {} - type_index = 0 - for type_index, sub_schema in enumerate(schema_node["anyOf"]): - sub_result = _extract_type_desc(sub_schema) - if sub_result: - data["items"]["type_"+str(type_index)] = sub_result - if schema_node.get("type", "") == "object": - data["items"] = {} + # 处理组合关键字 + special_keys = ["anyOf", "allOf", "oneOf"] + for key in special_keys: + if key in schema_node: + data = { + "type": key, + "description": schema_node.get("description", ""), + "items": {}, + } + type_index = 0 + for item in schema_node[key]: + if isinstance(item, dict): + data["items"][f"item_{type_index}"] = _extract_type_desc(item) + else: + data["items"][f"item_{type_index}"] = {"type": item, "description": ""} + type_index += 1 + return data + # 处理基本类型 + type_val = schema_node.get("type", "") + description = schema_node.get("description", "") + + # 处理多类型数组 + if isinstance(type_val, list): + if len(type_val) > 1: + data = {"type": "union", "description": description, "items": {}} + type_index = 0 + for t in type_val: + if t == "object": + tmp_dict = {} + for key, val in schema_node.get("properties", {}).items(): + tmp_dict[key] = _extract_type_desc(val) + data["items"][f"item_{type_index}"] = tmp_dict + elif t == "array": + items_schema = schema_node.get("items", {}) + data["items"][f"item_{type_index}"] = _extract_type_desc(items_schema) + else: + data["items"][f"item_{type_index}"] = {"type": t, "description": description} + type_index += 1 + return data + elif len(type_val) == 1: + type_val = type_val[0] + else: + type_val = "" + + data = {"type": type_val, "description": description, "items": {}} + + # 递归处理对象和数组 + if type_val == "object": for key, val in schema_node.get("properties", {}).items(): data["items"][key] = _extract_type_desc(val) - - # 处理类型为 array 的节点 - if schema_node.get("type", "") == "array": + elif type_val == "array": items_schema = schema_node.get("items", {}) - data["items"] = _extract_type_desc(items_schema) + if isinstance(items_schema, list): + item_index = 0 + for item in items_schema: + data["items"][f"item_{item_index}"] = _extract_type_desc(item) + item_index += 1 + else: + data["items"]["item"] = _extract_type_desc(items_schema) + if data["items"] == {}: + del data["items"] return data + return _extract_type_desc(self._schema) + def get_params_node_from_schema(self, root: str = "") -> ParamsNode: + """从JSON Schema中提取ParamsNode""" + def _extract_params_node(schema_node: dict[str, Any], name: str = "", path: str = "") -> ParamsNode: + """递归提取ParamsNode""" + if "type" not in schema_node: + return None + + param_type = schema_node["type"] + if isinstance(param_type, list): + return None # 不支持多类型 + if param_type == "object": + param_type = Type.DICT + elif param_type == "array": + param_type = Type.LIST + elif param_type == "string": + param_type = Type.STRING + elif param_type in ["number", "integer"]: + param_type = Type.NUMBER + elif param_type == "boolean": + param_type = Type.BOOL + else: + logger.warning(f"[Slot] 不支持的参数类型: {param_type}") + return None + sub_params = [] + + if param_type == Type.DICT and "properties" in schema_node: + for key, value in schema_node["properties"].items(): + sub_param = _extract_params_node(value, name=key, path=f"{path}/{key}") + if sub_param: + sub_params.append(sub_param) + else: + # 对于非对象类型,直接返回空子参数 + sub_params = None + return ParamsNode(paramName=name, + paramPath=path, + paramType=param_type, + subParams=sub_params) + try: + return _extract_params_node(self._schema, name=root, path=root) + except Exception as e: + logger.error(f"[Slot] 提取ParamsNode失败: {e!s}\n{traceback.format_exc()}") + return None + def _flatten_schema(self, schema: dict[str, Any]) -> tuple[dict[str, Any], list[str]]: """将JSON Schema扁平化""" result = {} @@ -276,7 +387,6 @@ class Slot: logger.exception("[Slot] 错误schema不合法: %s", error.schema) return {}, [] - def _assemble_patch( self, key: str, @@ -329,7 +439,6 @@ class Slot: logger.info("[Slot] 组装patch: %s", patch_list) return patch_list - def convert_json(self, json_data: str | dict[str, Any]) -> dict[str, Any]: """将用户手动填充的参数专为真实JSON""" json_dict = json.loads(json_data) if isinstance(json_data, str) else json_data @@ -373,3 +482,54 @@ class Slot: return schema_template return {} + + def add_null_to_basic_types(self) -> dict[str, Any]: + """ + 递归地为 JSON Schema 中的基础类型(bool、number等)添加 null 选项 + """ + def add_null_to_basic_types(schema: dict[str, Any]) -> dict[str, Any]: + """ + 递归地为 JSON Schema 中的基础类型(bool、number等)添加 null 选项 + + 参数: + schema (dict): 原始 JSON Schema + + 返回: + dict: 修改后的 JSON Schema + """ + # 如果不是字典类型(schema),直接返回 + if not isinstance(schema, dict): + return schema + + # 处理当前节点的 type 字段 + if 'type' in schema: + # 处理单一类型字符串 + if isinstance(schema['type'], str): + if schema['type'] in ['boolean', 'number', 'string', 'integer']: + schema['type'] = [schema['type'], 'null'] + + # 处理类型数组 + elif isinstance(schema['type'], list): + for i, t in enumerate(schema['type']): + if isinstance(t, str) and t in ['boolean', 'number', 'string', 'integer']: + if 'null' not in schema['type']: + schema['type'].append('null') + break + + # 递归处理 properties 字段(对象类型) + if 'properties' in schema: + for prop, prop_schema in schema['properties'].items(): + schema['properties'][prop] = add_null_to_basic_types(prop_schema) + + # 递归处理 items 字段(数组类型) + if 'items' in schema: + schema['items'] = add_null_to_basic_types(schema['items']) + + # 递归处理 anyOf, oneOf, allOf 字段 + for keyword in ['anyOf', 'oneOf', 'allOf']: + if keyword in schema: + schema[keyword] = [add_null_to_basic_types(sub_schema) for sub_schema in schema[keyword]] + + return schema + schema_copy = copy.deepcopy(self._schema) + return add_null_to_basic_types(schema_copy) diff --git a/apps/schemas/agent.py b/apps/schemas/agent.py index b52f5e1c3315fb873acddb2bcc4e27937498d561..16e818e4d588d4f25ea4854a7979e2dc5fe90ecb 100644 --- a/apps/schemas/agent.py +++ b/apps/schemas/agent.py @@ -17,6 +17,7 @@ class AgentAppMetadata(MetadataBase): app_type: AppType = Field(default=AppType.AGENT, description="应用类型", frozen=True) published: bool = Field(description="是否发布", default=False) history_len: int = Field(description="对话轮次", default=3, le=10) - mcp_service: list[str] = Field(default=[], alias="mcpService", description="MCP服务id列表") + mcp_service: list[str] = Field(default=[], description="MCP服务id列表") + llm_id: str = Field(default="empty", description="大模型ID") permission: Permission | None = Field(description="应用权限配置", default=None) version: str = Field(description="元数据版本") diff --git a/apps/schemas/appcenter.py b/apps/schemas/appcenter.py index a89f39df18083d90b988cc764c0e88f90500d1f3..e3bb896eba2361e0c28ee914b84f387d7da76b87 100644 --- a/apps/schemas/appcenter.py +++ b/apps/schemas/appcenter.py @@ -45,9 +45,9 @@ class AppFlowInfo(BaseModel): """应用工作流数据结构""" id: str = Field(..., description="工作流ID") - name: str = Field(..., description="工作流名称") - description: str = Field(..., description="工作流简介") - debug: bool = Field(..., description="是否经过调试") + name: str = Field(default="", description="工作流名称") + description: str = Field(default="", description="工作流简介") + debug: bool = Field(default=False, description="是否经过调试") class AppData(BaseModel): @@ -61,6 +61,7 @@ class AppData(BaseModel): first_questions: list[str] = Field( default=[], alias="recommendedQuestions", description="推荐问题", max_length=3) history_len: int = Field(3, alias="dialogRounds", ge=1, le=10, description="对话轮次(1~10)") + llm: str = Field(default="empty", description="大模型ID") permission: AppPermissionData = Field( default_factory=lambda: AppPermissionData(authorizedUsers=None), description="权限配置") workflows: list[AppFlowInfo] = Field(default=[], description="工作流信息列表") diff --git a/apps/schemas/collection.py b/apps/schemas/collection.py index 0ff66c72bbe30b7cbb55517952c8b14744d69cf2..a2991f851a85c8d60afb6d60e76cbb766c9da94e 100644 --- a/apps/schemas/collection.py +++ b/apps/schemas/collection.py @@ -61,6 +61,7 @@ class User(BaseModel): fav_apps: list[str] = [] fav_services: list[str] = [] is_admin: bool = Field(default=False, description="是否为管理员") + auto_execute: bool = Field(default=True, description="是否自动执行任务") class LLM(BaseModel): diff --git a/apps/schemas/config.py b/apps/schemas/config.py index b88a81f1afb018663713a3bf4c0b9b62436e59d4..675a9ba7f26531206a5ea5f8a7598d4344ae2124 100644 --- a/apps/schemas/config.py +++ b/apps/schemas/config.py @@ -6,6 +6,12 @@ from typing import Literal from pydantic import BaseModel, Field +class NoauthConfig(BaseModel): + """无认证配置""" + + enable: bool = Field(description="是否启用无认证访问", default=False) + + class DeployConfig(BaseModel): """部署配置""" @@ -122,7 +128,7 @@ class ExtraConfig(BaseModel): class ConfigModel(BaseModel): """配置文件的校验Class""" - + no_auth: NoauthConfig = Field(description="无认证配置", default=NoauthConfig()) deploy: DeployConfig login: LoginConfig embedding: EmbeddingConfig diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py index 9a20ba84d4805bcd502d9d4ca0f9ba3c49a7bb4c..3ae7c42528d27c3e0d17970ba5e959db8e1bed44 100644 --- a/apps/schemas/enum_var.py +++ b/apps/schemas/enum_var.py @@ -14,11 +14,26 @@ class SlotType(str, Enum): class StepStatus(str, Enum): """步骤状态""" - + UNKNOWN = "unknown" + INIT = "init" + WAITING = "waiting" RUNNING = "running" SUCCESS = "success" ERROR = "error" PARAM = "param" + CANCELLED = "cancelled" + + +class FlowStatus(str, Enum): + """Flow状态""" + + UNKNOWN = "unknown" + INIT = "init" + WAITING = "waiting" + RUNNING = "running" + SUCCESS = "success" + ERROR = "error" + CANCELLED = "cancelled" class DocumentStatus(str, Enum): @@ -34,14 +49,22 @@ class EventType(str, Enum): """事件类型""" HEARTBEAT = "heartbeat" - INIT = "init", + INIT = "init" TEXT_ADD = "text.add" GRAPH = "graph" DOCUMENT_ADD = "document.add" + STEP_WAITING_FOR_START = "step.waiting_for_start" + STEP_WAITING_FOR_PARAM = "step.waiting_for_param" FLOW_START = "flow.start" + STEP_INIT = "step.init" STEP_INPUT = "step.input" STEP_OUTPUT = "step.output" + STEP_CANCEL = "step.cancel" + STEP_ERROR = "step.error" FLOW_STOP = "flow.stop" + FLOW_FAILED = "flow.failed" + FLOW_SUCCESS = "flow.success" + FLOW_CANCEL = "flow.cancel" DONE = "done" @@ -81,7 +104,7 @@ class NodeType(str, Enum): START = "start" END = "end" NORMAL = "normal" - CHOICE = "choice" + CHOICE = "Choice" class SaveType(str, Enum): @@ -144,7 +167,7 @@ class SpecialCallType(str, Enum): LLM = "LLM" START = "start" END = "end" - CHOICE = "choice" + CHOICE = "Choice" class CommentType(str, Enum): diff --git a/apps/schemas/flow.py b/apps/schemas/flow.py index 2646d04390099fd92988d78c00d1b1471780d6a9..dfffd1f14c779952d9bb761d2e2bccd85528bc0f 100644 --- a/apps/schemas/flow.py +++ b/apps/schemas/flow.py @@ -136,6 +136,7 @@ class AppMetadata(MetadataBase): published: bool = Field(description="是否发布", default=False) links: list[AppLink] = Field(description="相关链接", default=[]) first_questions: list[str] = Field(description="首次提问", default=[]) + llm_id: str = Field(description="大模型ID", default="empty") history_len: int = Field(description="对话轮次", default=3, le=10) permission: Permission | None = Field(description="应用权限配置", default=None) flows: list[AppFlow] = Field(description="Flow列表", default=[]) diff --git a/apps/schemas/flow_topology.py b/apps/schemas/flow_topology.py index d0ab666a9902ec7e0d81bb922f973d655f29eaf2..fecf21606634982ad6480069a7adc6714a612db2 100644 --- a/apps/schemas/flow_topology.py +++ b/apps/schemas/flow_topology.py @@ -5,6 +5,7 @@ from typing import Any from pydantic import BaseModel, Field +from apps.schemas.enum_var import SpecialCallType from apps.schemas.enum_var import EdgeType @@ -51,7 +52,7 @@ class NodeItem(BaseModel): service_id: str = Field(alias="serviceId", default="") node_id: str = Field(alias="nodeId", default="") name: str = Field(default="") - call_id: str = Field(alias="callId", default="Empty") + call_id: str = Field(alias="callId", default=SpecialCallType.EMPTY.value) description: str = Field(default="") enable: bool = Field(default=True) parameters: dict[str, Any] = Field(default={}) @@ -81,6 +82,6 @@ class FlowItem(BaseModel): nodes: list[NodeItem] = Field(default=[]) edges: list[EdgeItem] = Field(default=[]) created_at: float | None = Field(alias="createdAt", default=0) - connectivity: bool = Field(default=False,description="图的开始节点和结束节点是否联通,并且除结束节点都有出边") + connectivity: bool = Field(default=False, description="图的开始节点和结束节点是否联通,并且除结束节点都有出边") focus_point: PositionItem = Field(alias="focusPoint", default=PositionItem()) debug: bool = Field(default=False) diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py index 44021b0ed5f5f6c372ac0a472bb3ac28ff5bbddb..d27a5b469877f24639d1566f62f8b216636d3fd7 100644 --- a/apps/schemas/mcp.py +++ b/apps/schemas/mcp.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """MCP 相关数据结构""" +import uuid from enum import Enum from typing import Any @@ -10,8 +11,9 @@ from pydantic import BaseModel, Field class MCPInstallStatus(str, Enum): """MCP 服务状态""" - + INIT = "init" INSTALLING = "installing" + CANCELLED = "cancelled" READY = "ready" FAILED = "failed" @@ -22,6 +24,7 @@ class MCPStatus(str, Enum): UNINITIALIZED = "uninitialized" RUNNING = "running" STOPPED = "stopped" + ERROR = "error" class MCPType(str, Enum): @@ -38,7 +41,9 @@ class MCPBasicConfig(BaseModel): env: dict[str, str] = Field(description="MCP 服务器环境变量", default={}) auto_approve: list[str] = Field(description="自动批准的MCP工具ID列表", default=[], alias="autoApprove") disabled: bool = Field(description="MCP 服务器是否禁用", default=False) - auto_install: bool = Field(description="是否自动安装MCP服务器", default=True, alias="autoInstall") + auto_install: bool = Field(description="是否自动安装MCP服务器", default=True) + timeout: int = Field(description="MCP 服务器超时时间(秒)", default=60, alias="timeout") + description: str = Field(description="MCP 服务器自然语言描述", default="") class MCPServerStdioConfig(MCPBasicConfig): @@ -84,7 +89,7 @@ class MCPCollection(BaseModel): type: MCPType = Field(description="MCP 类型", default=MCPType.SSE) activated: list[str] = Field(description="激活该MCP的用户ID列表", default=[]) tools: list[MCPTool] = Field(description="MCP工具列表", default=[]) - status: MCPInstallStatus = Field(description="MCP服务状态", default=MCPInstallStatus.INSTALLING) + status: MCPInstallStatus = Field(description="MCP服务状态", default=MCPInstallStatus.INIT) author: str = Field(description="MCP作者", default="") @@ -103,6 +108,60 @@ class MCPToolVector(LanceModel): embedding: Vector(dim=1024) = Field(description="MCP工具描述的向量信息") # type: ignore[call-arg] +class GoalEvaluationResult(BaseModel): + """MCP 目标评估结果""" + + can_complete: bool = Field(description="是否可以完成目标") + reason: str = Field(description="评估原因") + + +class RestartStepIndex(BaseModel): + """MCP重新规划的步骤索引""" + + start_index: int = Field(description="重新规划的起始步骤索引") + reasoning: str = Field(description="重新规划的原因") + + +class Risk(str, Enum): + """MCP工具风险类型""" + + LOW = "low" + MEDIUM = "medium" + HIGH = "high" + + +class ToolSkip(BaseModel): + """MCP工具跳过执行结果""" + + skip: bool = Field(description="是否跳过当前步骤", default=False) + + +class ToolRisk(BaseModel): + """MCP工具风险评估结果""" + + risk: Risk = Field(description="风险类型", default=Risk.LOW) + reason: str = Field(description="风险原因", default="") + + +class ErrorType(str, Enum): + """MCP工具错误类型""" + + MISSING_PARAM = "missing_param" + DECORRECT_PLAN = "decorrect_plan" + + +class ToolExcutionErrorType(BaseModel): + """MCP工具执行错误""" + + type: ErrorType = Field(description="错误类型", default=ErrorType.MISSING_PARAM) + reason: str = Field(description="错误原因", default="") + + +class IsParamError(BaseModel): + """MCP工具参数错误""" + + is_param_error: bool = Field(description="是否是参数错误", default=False) + class MCPSelectResult(BaseModel): """MCP选择结果""" @@ -115,9 +174,15 @@ class MCPToolSelectResult(BaseModel): name: str = Field(description="工具名称") +class MCPToolIdsSelectResult(BaseModel): + """MCP工具ID选择结果""" + + tool_ids: list[str] = Field(description="工具ID列表") + + class MCPPlanItem(BaseModel): """MCP 计划""" - + step_id: str = Field(description="步骤的ID", default="") content: str = Field(description="计划内容") tool: str = Field(description="工具名称") instruction: str = Field(description="工具指令") @@ -126,4 +191,11 @@ class MCPPlanItem(BaseModel): class MCPPlan(BaseModel): """MCP 计划""" - plans: list[MCPPlanItem] = Field(description="计划列表") + plans: list[MCPPlanItem] = Field(description="计划列表", default=[]) + + +class Step(BaseModel): + """MCP步骤""" + + tool_id: str = Field(description="工具ID") + description: str = Field(description="步骤描述") diff --git a/apps/schemas/message.py b/apps/schemas/message.py index d0661224e0817ff6f25e341d65c88903020d18e8..5b465ee54321bbd4c649753911025bff41840186 100644 --- a/apps/schemas/message.py +++ b/apps/schemas/message.py @@ -2,13 +2,19 @@ """队列中的消息结构""" from typing import Any - +from datetime import UTC, datetime from pydantic import BaseModel, Field -from apps.schemas.enum_var import EventType, StepStatus +from apps.schemas.enum_var import EventType, FlowStatus, StepStatus from apps.schemas.record import RecordMetadata +class param(BaseModel): + """流执行过程中的参数补充""" + content: dict[str, Any] = Field(default={}, description="流执行过程中的参数补充内容") + description: str = Field(default="", description="流执行过程中的参数补充描述") + + class HeartbeatData(BaseModel): """心跳事件的数据结构""" @@ -22,8 +28,17 @@ class MessageFlow(BaseModel): app_id: str = Field(description="插件ID", alias="appId") flow_id: str = Field(description="Flow ID", alias="flowId") + flow_name: str = Field(description="Flow名称", alias="flowName") + flow_status: FlowStatus = Field(description="Flow状态", alias="flowStatus", default=FlowStatus.UNKNOWN) step_id: str = Field(description="当前步骤ID", alias="stepId") step_name: str = Field(description="当前步骤名称", alias="stepName") + sub_step_id: str | None = Field(description="当前子步骤ID", alias="subStepId", default=None) + sub_step_name: str | None = Field(description="当前子步骤名称", alias="subStepName", default=None) + step_description: str | None = Field( + description="当前步骤描述", + alias="stepDescription", + default=None, + ) step_status: StepStatus = Field(description="当前步骤状态", alias="stepStatus") @@ -60,17 +75,21 @@ class DocumentAddContent(BaseModel): document_id: str = Field(description="文档UUID", alias="documentId") document_order: int = Field(description="文档在对话中的顺序,从1开始", alias="documentOrder") + document_author: str = Field(description="文档作者", alias="documentAuthor", default="") document_name: str = Field(description="文档名称", alias="documentName") document_abstract: str = Field(description="文档摘要", alias="documentAbstract", default="") document_type: str = Field(description="文档MIME类型", alias="documentType", default="") document_size: float = Field(ge=0, description="文档大小,单位是KB,保留两位小数", alias="documentSize", default=0) + created_at: float = Field( + description="文档创建时间,单位是秒", alias="createdAt", default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3) + ) class FlowStartContent(BaseModel): """flow.start消息的content""" question: str = Field(description="用户问题") - params: dict[str, Any] = Field(description="预先提供的参数") + params: dict[str, Any] | None = Field(description="预先提供的参数", default=None) class MessageBase(HeartbeatData): @@ -81,5 +100,5 @@ class MessageBase(HeartbeatData): conversation_id: str = Field(min_length=36, max_length=36, alias="conversationId") task_id: str = Field(min_length=36, max_length=36, alias="taskId") flow: MessageFlow | None = None - content: dict[str, Any] = {} + content: Any | None = Field(default=None, description="消息内容") metadata: MessageMetadata diff --git a/apps/schemas/parameters.py b/apps/schemas/parameters.py new file mode 100644 index 0000000000000000000000000000000000000000..bd908d2375415c798f4804c4eabde797f6a4f7a0 --- /dev/null +++ b/apps/schemas/parameters.py @@ -0,0 +1,69 @@ +from enum import Enum + + +class NumberOperate(str, Enum): + """Choice 工具支持的数字运算符""" + + EQUAL = "number_equal" + NOT_EQUAL = "number_not_equal" + GREATER_THAN = "number_greater_than" + LESS_THAN = "number_less_than" + GREATER_THAN_OR_EQUAL = "number_greater_than_or_equal" + LESS_THAN_OR_EQUAL = "number_less_than_or_equal" + + +class StringOperate(str, Enum): + """Choice 工具支持的字符串运算符""" + + EQUAL = "string_equal" + NOT_EQUAL = "string_not_equal" + CONTAINS = "string_contains" + NOT_CONTAINS = "string_not_contains" + STARTS_WITH = "string_starts_with" + ENDS_WITH = "string_ends_with" + LENGTH_EQUAL = "string_length_equal" + LENGTH_GREATER_THAN = "string_length_greater_than" + LENGTH_GREATER_THAN_OR_EQUAL = "string_length_greater_than_or_equal" + LENGTH_LESS_THAN = "string_length_less_than" + LENGTH_LESS_THAN_OR_EQUAL = "string_length_less_than_or_equal" + REGEX_MATCH = "string_regex_match" + + +class ListOperate(str, Enum): + """Choice 工具支持的列表运算符""" + + EQUAL = "list_equal" + NOT_EQUAL = "list_not_equal" + CONTAINS = "list_contains" + NOT_CONTAINS = "list_not_contains" + LENGTH_EQUAL = "list_length_equal" + LENGTH_GREATER_THAN = "list_length_greater_than" + LENGTH_GREATER_THAN_OR_EQUAL = "list_length_greater_than_or_equal" + LENGTH_LESS_THAN = "list_length_less_than" + LENGTH_LESS_THAN_OR_EQUAL = "list_length_less_than_or_equal" + + +class BoolOperate(str, Enum): + """Choice 工具支持的布尔运算符""" + + EQUAL = "bool_equal" + NOT_EQUAL = "bool_not_equal" + + +class DictOperate(str, Enum): + """Choice 工具支持的字典运算符""" + + EQUAL = "dict_equal" + NOT_EQUAL = "dict_not_equal" + CONTAINS_KEY = "dict_contains_key" + NOT_CONTAINS_KEY = "dict_not_contains_key" + + +class Type(str, Enum): + """Choice 工具支持的类型""" + + STRING = "string" + NUMBER = "number" + LIST = "list" + DICT = "dict" + BOOL = "bool" diff --git a/apps/schemas/pool.py b/apps/schemas/pool.py index 27e16b370ec83acc11e1f435ac1da296fe2a9560..3532b6e2ffd31739e25e53414245b5d54fd395d0 100644 --- a/apps/schemas/pool.py +++ b/apps/schemas/pool.py @@ -109,4 +109,7 @@ class AppPool(BaseData): permission: Permission = Field(description="应用权限配置", default=Permission()) flows: list[AppFlow] = Field(description="Flow列表", default=[]) hashes: dict[str, str] = Field(description="关联文件的hash值", default={}) - mcp_service: list[str] = Field(default=[], alias="mcpService", description="MCP服务id列表") + mcp_service: list[str] = Field(default=[], description="MCP服务id列表") + llm_id: str = Field( + default="empty", description="应用使用的大模型ID(如果有的话)" + ) diff --git a/apps/schemas/record.py b/apps/schemas/record.py index d7acd368205d0aa5a2b368edf4af3a31ed4d1ff6..b6b70dd6fbd5b89aafb06af5146b8feb1fb84862 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -10,15 +10,17 @@ from pydantic import BaseModel, Field from apps.schemas.collection import ( Document, ) -from apps.schemas.enum_var import CommentType, StepStatus +from apps.schemas.enum_var import CommentType, FlowStatus, StepStatus class RecordDocument(Document): """GET /api/record/{conversation_id} Result中的document数据结构""" id: str = Field(alias="_id", default="") + order: int = Field(default=0, description="文档顺序") abstract: str = Field(default="", description="文档摘要") user_sub: None = None + author: str = Field(default="", description="文档作者") associated: Literal["question", "answer"] class Config: @@ -34,6 +36,7 @@ class RecordFlowStep(BaseModel): step_status: StepStatus = Field(alias="stepStatus") input: dict[str, Any] output: dict[str, Any] + ex_data: dict[str, Any] | None = Field(default=None, alias="exData") class RecordFlow(BaseModel): @@ -42,6 +45,8 @@ class RecordFlow(BaseModel): id: str record_id: str = Field(alias="recordId") flow_id: str = Field(alias="flowId") + flow_name: str = Field(alias="flowName", default="") + flow_status: StepStatus = Field(alias="flowStatus", default=StepStatus.SUCCESS) step_num: int = Field(alias="stepNum") steps: list[RecordFlowStep] @@ -90,7 +95,7 @@ class RecordData(BaseModel): id: str group_id: str = Field(alias="groupId") conversation_id: str = Field(alias="conversationId") - task_id: str = Field(alias="taskId") + task_id: str | None = Field(default=None, alias="taskId") document: list[RecordDocument] = [] flow: RecordFlow | None = None content: RecordContent @@ -103,11 +108,22 @@ class RecordGroupDocument(BaseModel): """RecordGroup关联的文件""" id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") + order: int = Field(default=0, description="文档顺序") + author: str = Field(default="", description="文档作者") name: str = Field(default="", description="文档名称") abstract: str = Field(default="", description="文档摘要") extension: str = Field(default="", description="文档扩展名") size: int = Field(default=0, description="文档大小,单位是KB") associated: Literal["question", "answer"] + created_at: float = Field(default=0.0, description="文档创建时间") + + +class FlowHistory(BaseModel): + """Flow执行历史""" + flow_id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") + flow_name: str = Field(default="", description="Flow名称") + flow_staus: FlowStatus = Field(default=FlowStatus.SUCCESS, description="Flow执行状态") + history_ids: list[str] = Field(default=[], description="Flow执行历史ID列表") class Record(RecordData): @@ -115,9 +131,11 @@ class Record(RecordData): user_sub: str key: dict[str, Any] = {} - content: str + task_id: str | None = Field(default=None, description="任务ID") + content: str = Field(default="", description="Record内容,已加密") comment: RecordComment = Field(default=RecordComment()) - flow: list[str] = Field(default=[]) + flow: FlowHistory = Field( + default=FlowHistory(), description="Flow执行历史信息") class RecordGroup(BaseModel): @@ -134,5 +152,4 @@ class RecordGroup(BaseModel): records: list[Record] = [] docs: list[RecordGroupDocument] = [] # 问题不变,所用到的文档不变 conversation_id: str - task_id: str created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index 2305dd93dfe33969691dcc42928e578e7f1c4207..3fd5a67fa29af6e1685d55285ee1ebd458cabce2 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -10,14 +10,14 @@ from apps.schemas.appcenter import AppData from apps.schemas.enum_var import CommentType from apps.schemas.flow_topology import FlowItem from apps.schemas.mcp import MCPType +from apps.schemas.message import param class RequestDataApp(BaseModel): """模型对话中包含的app信息""" app_id: str = Field(description="应用ID", alias="appId") - flow_id: str = Field(description="Flow ID", alias="flowId") - params: dict[str, Any] = Field(description="插件参数") + flow_id: str | None = Field(default=None, description="Flow ID", alias="flowId") class MockRequestData(BaseModel): @@ -39,13 +39,15 @@ class RequestDataFeatures(BaseModel): class RequestData(BaseModel): """POST /api/chat 请求的总的数据结构""" - question: str = Field(max_length=2000, description="用户输入") - conversation_id: str = Field(default="", alias="conversationId", description="聊天ID") + question: str | None = Field(default=None, max_length=2000, description="用户输入") + conversation_id: str | None = Field(default=None, alias="conversationId", description="聊天ID") group_id: str | None = Field(default=None, alias="groupId", description="问答组ID") language: str = Field(default="zh", description="语言") files: list[str] = Field(default=[], description="文件列表") app: RequestDataApp | None = Field(default=None, description="应用") debug: bool = Field(default=False, description="是否调试") + task_id: str | None = Field(default=None, alias="taskId", description="任务ID") + params: param | bool | None = Field(default=None, description="流执行过程中的参数补充", alias="params") class QuestionBlacklistRequest(BaseModel): @@ -98,7 +100,7 @@ class UpdateMCPServiceRequest(BaseModel): name: str = Field(..., description="MCP服务名称") description: str = Field(..., description="MCP服务描述") overview: str = Field(..., description="MCP服务概述") - config: str = Field(..., description="MCP服务配置") + config: dict[str, Any] = Field(..., description="MCP服务配置") mcp_type: MCPType = Field(description="MCP传输协议(Stdio/SSE/Streamable)", default=MCPType.STDIO, alias="mcpType") @@ -106,6 +108,7 @@ class ActiveMCPServiceRequest(BaseModel): """POST /api/mcp/{serviceId} 请求数据结构""" active: bool = Field(description="是否激活mcp服务") + mcp_env: dict[str, Any] | None = Field(default=None, description="MCP服务环境变量", alias="mcpEnv") class UpdateServiceRequest(BaseModel): @@ -183,3 +186,9 @@ class UpdateKbReq(BaseModel): """更新知识库请求体""" kb_ids: list[str] = Field(description="知识库ID列表", alias="kbIds", default=[]) + + +class UserUpdateRequest(BaseModel): + """更新用户信息请求体""" + + auto_execute: bool = Field(default=False, description="是否自动执行", alias="autoExecute") diff --git a/apps/schemas/response_data.py b/apps/schemas/response_data.py index b2a1872918638b153e1e94ff4db85ef00cfc0502..542dac88bb134ca79aa7934d8781c1dbd8a93cc3 100644 --- a/apps/schemas/response_data.py +++ b/apps/schemas/response_data.py @@ -14,10 +14,19 @@ from apps.schemas.flow_topology import ( NodeServiceItem, PositionItem, ) +from apps.schemas.parameters import ( + Type, + NumberOperate, + StringOperate, + ListOperate, + BoolOperate, + DictOperate, +) from apps.schemas.mcp import MCPInstallStatus, MCPTool, MCPType from apps.schemas.record import RecordData from apps.schemas.user import UserInfo from apps.templates.generate_llm_operator_config import llm_provider_dict +from apps.common.config import Config class ResponseData(BaseModel): @@ -47,6 +56,7 @@ class AuthUserMsg(BaseModel): user_sub: str revision: bool is_admin: bool + auto_execute: bool class AuthUserRsp(ResponseData): @@ -90,7 +100,7 @@ class LLMIteam(BaseModel): icon: str = Field(default=llm_provider_dict["ollama"]["icon"]) llm_id: str = Field(alias="llmId", default="empty") - model_name: str = Field(alias="modelName", default="Ollama LLM") + model_name: str = Field(alias="modelName", default=Config().get_config().llm.model) class KbIteam(BaseModel): @@ -272,11 +282,21 @@ class BaseAppOperationRsp(ResponseData): result: BaseAppOperationMsg +class AppMcpServiceInfo(BaseModel): + """应用关联的MCP服务信息""" + + id: str = Field(..., description="MCP服务ID") + name: str = Field(default="", description="MCP服务名称") + description: str = Field(default="", description="MCP服务简介") + + class GetAppPropertyMsg(AppData): """GET /api/app/{appId} Result数据结构""" app_id: str = Field(..., alias="appId", description="应用ID") published: bool = Field(..., description="是否已发布") + mcp_service: list[AppMcpServiceInfo] = Field(default=[], alias="mcpService", description="MCP服务信息列表") + llm: LLMIteam | None = Field(alias="llm", default=None) class GetAppPropertyRsp(ResponseData): @@ -573,7 +593,7 @@ class FlowStructureDeleteRsp(ResponseData): class UserGetMsp(BaseModel): """GET /api/user result""" - + total: int = Field(default=0) user_info_list: list[UserInfo] = Field(alias="userInfoList", default=[]) @@ -628,3 +648,42 @@ class ListLLMRsp(ResponseData): """GET /api/llm 返回数据结构""" result: list[LLMProviderInfo] = Field(default=[], title="Result") + + +class ParamsNode(BaseModel): + """参数数据结构""" + param_name: str = Field(..., description="参数名称", alias="paramName") + param_path: str = Field(..., description="参数路径", alias="paramPath") + param_type: Type = Field(..., description="参数类型", alias="paramType") + sub_params: list["ParamsNode"] | None = Field( + default=None, description="子参数列表", alias="subParams" + ) + + +class StepParams(BaseModel): + """参数数据结构""" + step_id: str = Field(..., description="步骤ID", alias="stepId") + name: str = Field(..., description="Step名称") + params_node: ParamsNode | None = Field( + default=None, description="参数节点", alias="paramsNode") + + +class GetParamsRsp(ResponseData): + """GET /api/params 返回数据结构""" + + result: list[StepParams] = Field( + default=[], description="参数列表", alias="result" + ) + + +class OperateAndBindType(BaseModel): + """操作和绑定类型数据结构""" + + operate: NumberOperate | StringOperate | ListOperate | BoolOperate | DictOperate = Field(description="操作类型") + bind_type: Type = Field(description="绑定类型") + + +class GetOperaRsp(ResponseData): + """GET /api/operate 返回数据结构""" + + result: list[OperateAndBindType] = Field(..., title="Result") diff --git a/apps/schemas/task.py b/apps/schemas/task.py index 8efcb59914d568478c65d611670b251d019b38ed..d3e7b036c35f8ecf53a204b8b8a28e4600738184 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -7,8 +7,9 @@ from typing import Any from pydantic import BaseModel, Field -from apps.schemas.enum_var import StepStatus +from apps.schemas.enum_var import FlowStatus, StepStatus from apps.schemas.flow import Step +from apps.schemas.mcp import MCPPlan class FlowStepHistory(BaseModel): @@ -22,12 +23,14 @@ class FlowStepHistory(BaseModel): task_id: str = Field(description="任务ID") flow_id: str = Field(description="FlowID") flow_name: str = Field(description="Flow名称") + flow_status: FlowStatus = Field(description="Flow状态") step_id: str = Field(description="当前步骤名称") step_name: str = Field(description="当前步骤名称") - step_description: str = Field(description="当前步骤描述") - status: StepStatus = Field(description="当前步骤状态") + step_description: str = Field(description="当前步骤描述", default="") + step_status: StepStatus = Field(description="当前步骤状态") input_data: dict[str, Any] = Field(description="当前Step执行的输入", default={}) output_data: dict[str, Any] = Field(description="当前Step执行后的结果", default={}) + ex_data: dict[str, Any] | None = Field(description="额外数据", default=None) created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) @@ -35,16 +38,21 @@ class ExecutorState(BaseModel): """FlowExecutor状态""" # 执行器级数据 - flow_id: str = Field(description="Flow ID") - flow_name: str = Field(description="Flow名称") - description: str = Field(description="Flow描述") - status: StepStatus = Field(description="Flow执行状态") - # 附加信息 - step_id: str = Field(description="当前步骤ID") - step_name: str = Field(description="当前步骤名称") - app_id: str = Field(description="应用ID") - slot: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={}) - error_info: dict[str, Any] = Field(description="错误信息", default={}) + flow_id: str = Field(description="Flow ID", default="") + flow_name: str = Field(description="Flow名称", default="") + description: str = Field(description="Flow描述", default="") + flow_status: FlowStatus = Field(description="Flow状态", default=FlowStatus.INIT) + # 任务级数据 + step_cnt: int = Field(description="当前步骤数量", default=0) + step_id: str = Field(description="当前步骤ID", default="") + tool_id: str = Field(description="当前工具ID", default="") + step_name: str = Field(description="当前步骤名称", default="") + step_status: StepStatus = Field(description="当前步骤状态", default=StepStatus.UNKNOWN) + step_description: str = Field(description="当前步骤描述", default="") + app_id: str = Field(description="应用ID", default="") + current_input: dict[str, Any] = Field(description="当前输入数据", default={}) + error_message: str = Field(description="错误信息", default="") + retry_times: int = Field(description="当前步骤重试次数", default=0) class TaskIds(BaseModel): @@ -55,6 +63,7 @@ class TaskIds(BaseModel): conversation_id: str = Field(description="对话ID") record_id: str = Field(description="记录ID", default_factory=lambda: str(uuid.uuid4())) user_sub: str = Field(description="用户ID") + active_id: str = Field(description="活动ID", default_factory=lambda: str(uuid.uuid4())) class TaskTokens(BaseModel): @@ -86,10 +95,11 @@ class Task(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") ids: TaskIds = Field(description="任务涉及的各种ID") - context: list[dict[str, Any]] = Field(description="Flow的步骤执行信息", default=[]) - state: ExecutorState | None = Field(description="Flow的状态", default=None) + context: list[FlowStepHistory] = Field(description="Flow的步骤执行信息", default=[]) + state: ExecutorState = Field(description="Flow的状态", default=ExecutorState()) tokens: TaskTokens = Field(description="Token信息") runtime: TaskRuntime = Field(description="任务运行时数据") + language: str = Field(description="语言", default="zh") created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) diff --git a/apps/schemas/user.py b/apps/schemas/user.py index 61aa2587b8ae6255dc3885b04a96f12310f70773..ebea66924c062f4e680ad2e414768b307b056f4a 100644 --- a/apps/schemas/user.py +++ b/apps/schemas/user.py @@ -9,3 +9,4 @@ class UserInfo(BaseModel): user_sub: str = Field(alias="userSub", default="") user_name: str = Field(alias="userName", default="") + auto_execute: bool | None = Field(alias="autoExecute", default=False) diff --git a/apps/services/activity.py b/apps/services/activity.py index 299a49a640664b29a3e4724c5842f12bd289e3bf..7ba20cbb4300a5fa5cb60e37cdb887f38d54232f 100644 --- a/apps/services/activity.py +++ b/apps/services/activity.py @@ -3,70 +3,65 @@ import uuid from datetime import UTC, datetime - +import logging from apps.common.mongo import MongoDB from apps.constants import SLIDE_WINDOW_QUESTION_COUNT, SLIDE_WINDOW_TIME from apps.exceptions import ActivityError +logger = logging.getLogger(__name__) + class Activity: - """用户活动控制,限制单用户同一时间只能提问一个问题""" + """用户活动控制,限制单用户同一时间只能有SLIDE_WINDOW_QUESTION_COUNT个请求""" @staticmethod - async def is_active(user_sub: str) -> bool: + async def is_active(active_id: str) -> bool: """ 判断当前用户是否正在提问(占用GPU资源) :param user_sub: 用户实体ID :return: 判断结果,正在提问则返回True """ - time = round(datetime.now(UTC).timestamp(), 3) - - # 检查窗口内总请求数 - count = await MongoDB().get_collection("activity").count_documents( - {"timestamp": {"$gte": time - SLIDE_WINDOW_TIME, "$lte": time}}, - ) - if count >= SLIDE_WINDOW_QUESTION_COUNT: - return True # 检查用户是否正在提问 active = await MongoDB().get_collection("activity").find_one( - {"user_sub": user_sub}, + {"_id": active_id}, ) return bool(active) @staticmethod - async def set_active(user_sub: str) -> None: + async def set_active(user_sub: str) -> str: """设置用户的活跃标识""" time = round(datetime.now(UTC).timestamp(), 3) # 设置用户活跃状态 collection = MongoDB().get_collection("activity") - active = await collection.find_one({"user_sub": user_sub}) - if active: - err = "用户正在提问" + # 查看用户活跃标识是否在滑动窗口内 + + if await collection.count_documents({"user_sub": user_sub, "timestamp": {"$gt": time - SLIDE_WINDOW_TIME}}) >= SLIDE_WINDOW_QUESTION_COUNT: + err = "[Activity] 用户在滑动窗口内提问次数超过限制,请稍后再试。" raise ActivityError(err) + await collection.delete_many( + {"user_sub": user_sub, "timestamp": {"$lte": time - SLIDE_WINDOW_TIME}}, + ) + # 插入新的活跃记录 + tmp_record = { + "_id": str(uuid.uuid4()), + "user_sub": user_sub, + "timestamp": time, + } await collection.insert_one( - { - "_id": str(uuid.uuid4()), - "user_sub": user_sub, - "timestamp": time, - }, + tmp_record ) + return tmp_record["_id"] @staticmethod - async def remove_active(user_sub: str) -> None: + async def remove_active(active_id: str) -> None: """ 清除用户的活跃标识,释放GPU资源 :param user_sub: 用户实体ID """ - time = round(datetime.now(UTC).timestamp(), 3) # 清除用户当前活动标识 await MongoDB().get_collection("activity").delete_one( - {"user_sub": user_sub}, - ) - - # 清除超出窗口范围的请求记录 - await MongoDB().get_collection("activity").delete_many( - {"timestamp": {"$lte": time - SLIDE_WINDOW_TIME}}, + {"_id": active_id}, ) diff --git a/apps/services/appcenter.py b/apps/services/appcenter.py index e256ab55ab13bf8c5b9e40cb36a3c616a1d39596..4e0d00d06254973d284b149faaa632d969700981 100644 --- a/apps/services/appcenter.py +++ b/apps/services/appcenter.py @@ -59,7 +59,6 @@ class AppCenterManager: } user_favorite_app_ids = await AppCenterManager._get_favorite_app_ids_by_user(user_sub) - if filter_type == AppFilterType.ALL: # 获取所有已发布的应用 filters["published"] = True @@ -72,7 +71,6 @@ class AppCenterManager: "_id": {"$in": user_favorite_app_ids}, "published": True, } - # 添加关键字搜索条件 if keyword: filters["$or"] = [ @@ -84,8 +82,8 @@ class AppCenterManager: # 添加应用类型过滤条件 if app_type is not None: filters["app_type"] = app_type.value - # 获取应用列表 + logger.error(f"[AppCenterManager] 搜索条件: {filters}, 页码: {page}, 每页大小: {SERVICE_PAGE_SIZE}") apps, total_apps = await AppCenterManager._search_apps_by_filter(filters, page, SERVICE_PAGE_SIZE) # 构建返回的应用卡片列表 @@ -484,7 +482,19 @@ class AppCenterManager: else: # 在预期的条件下,如果在 data 或 app_data 中找不到 mcp_service,则默认回退为空列表。 metadata.mcp_service = [] - + # 处理llm_id字段 + if data is not None and hasattr(data, "llm") and data.llm: + # 创建应用场景,验证传入的 llm_id 状态 (create_app) + metadata.llm_id = data.llm if data.llm else "empty" + elif data is not None and hasattr(data, "llm_id"): + # 更新应用场景,使用 data 中的 llm_id (update_app) + metadata.llm_id = data.llm if data.llm else "empty" + elif app_data is not None and hasattr(app_data, "llm_id"): + # 更新应用发布状态场景,使用 app_data 中的 llm_id (update_app_publish_status) + metadata.llm_id = app_data.llm_id if app_data.llm_id else "empty" + else: + # 在预期的条件下,如果在 data 或 app_data 中找不到 llm_id,则默认回退为 "empty"。 + metadata.llm_id = "empty" # Agent 应用的发布状态逻辑 if published is not None: # 从 update_app_publish_status 调用,'published' 参数已提供 metadata.published = published diff --git a/apps/services/conversation.py b/apps/services/conversation.py index 4bcade45c757ea2b3dca6c58863b2852af98c15b..6bacb72720dbe43ff6835d4714c720f2778d85b1 100644 --- a/apps/services/conversation.py +++ b/apps/services/conversation.py @@ -59,7 +59,11 @@ class ConversationManager: model_name=llm.model_name, ) kb_item_list = [] - team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None) + try: + team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None) + except: + logger.error("[ConversationManager] 获取团队知识库列表失败") + team_kb_list = [] for team_kb in team_kb_list: for kb in team_kb["kbList"]: if str(kb["kbId"]) in kb_ids: @@ -136,4 +140,4 @@ class ConversationManager: await record_group_collection.delete_many({"conversation_id": conversation_id}, session=session) await session.commit_transaction() - await TaskManager.delete_tasks_by_conversation_id(conversation_id) + await TaskManager.delete_tasks_and_flow_context_by_conversation_id(conversation_id) diff --git a/apps/services/document.py b/apps/services/document.py index 203162da1137cdd69d900ceea65a3029d2f4fd8e..451423a9d8f752c2abbf97f5b2a1df38e6d7fefe 100644 --- a/apps/services/document.py +++ b/apps/services/document.py @@ -2,6 +2,7 @@ """文件Manager""" import base64 +from datetime import UTC, datetime import logging import uuid @@ -131,12 +132,15 @@ class DocumentManager: return [ RecordDocument( _id=doc.id, + order=doc.order, + author=doc.author, abstract=doc.abstract, name=doc.name, type=doc.extension, size=doc.size, conversation_id=record_group.get("conversation_id", ""), associated=doc.associated, + created_at=doc.created_at or round(datetime.now(tz=UTC).timestamp(), 3) ) for doc in docs if type is None or doc.associated == type ] diff --git a/apps/services/flow.py b/apps/services/flow.py index 9275fc60c75199e8e17ebcba8f164083190b7211..4d682e5a85677b3a526ec9d6d2db521c238357b4 100644 --- a/apps/services/flow.py +++ b/apps/services/flow.py @@ -20,7 +20,6 @@ from apps.schemas.flow_topology import ( PositionItem, ) from apps.services.node import NodeManager - logger = logging.getLogger(__name__) @@ -258,10 +257,7 @@ class FlowManager: ) for node_id, node_config in flow_config.steps.items(): input_parameters = node_config.params - if node_config.node not in ("Empty"): - _, output_parameters = await NodeManager.get_node_params(node_config.node) - else: - output_parameters = {} + _, output_parameters = await NodeManager.get_node_params(node_config.node) parameters = { "input_parameters": input_parameters, "output_parameters": Slot(output_parameters).extract_type_desc_from_schema(), @@ -414,6 +410,7 @@ class FlowManager: flow_config.debug = await FlowManager.is_flow_config_equal(old_flow_config, flow_config) else: flow_config.debug = False + logger.error(f'{flow_config}') await flow_loader.save(app_id, flow_id, flow_config) except Exception: logger.exception("[FlowManager] 存储/更新流失败") diff --git a/apps/services/flow_validate.py b/apps/services/flow_validate.py index 78e8d340e955115a51e24ac8fc9081fd8ff00e7e..16c23053705c158c4d4f7e0faa438fe3927c0d63 100644 --- a/apps/services/flow_validate.py +++ b/apps/services/flow_validate.py @@ -4,6 +4,7 @@ import collections import logging +from apps.schemas.enum_var import SpecialCallType from apps.exceptions import FlowBranchValidationError, FlowEdgeValidationError, FlowNodeValidationError from apps.schemas.enum_var import NodeType from apps.schemas.flow_topology import EdgeItem, FlowItem, NodeItem @@ -38,32 +39,40 @@ class FlowService: for node in flow_item.nodes: from apps.scheduler.pool.pool import Pool from pydantic import BaseModel - if node.node_id != 'start' and node.node_id != 'end' and node.node_id != 'Empty': + if node.node_id != 'start' and node.node_id != 'end' and node.node_id != SpecialCallType.EMPTY.value: try: call_class: type[BaseModel] = await Pool().get_call(node.call_id) if not call_class: - node.node_id = 'Empty' + node.node_id = SpecialCallType.EMPTY.value node.description = '【对应的api工具被删除!节点不可用!请联系相关人员!】\n\n'+node.description except Exception as e: - node.node_id = 'Empty' + node.node_id = SpecialCallType.EMPTY.value node.description = '【对应的api工具被删除!节点不可用!请联系相关人员!】\n\n'+node.description logger.error(f"[FlowService] 获取步骤的call_id失败{node.call_id}由于:{e}") node_branch_map[node.step_id] = set() if node.call_id == NodeType.CHOICE.value: - node.parameters = node.parameters["input_parameters"] - if "choices" not in node.parameters: - node.parameters["choices"] = [] - for choice in node.parameters["choices"]: - if choice["branchId"] in node_branch_map[node.step_id]: - err = f"[FlowService] 节点{node.name}的分支{choice['branchId']}重复" + input_parameters = node.parameters["input_parameters"] + if "choices" not in input_parameters: + logger.error(f"[FlowService] 节点{node.name}的分支字段缺失") + raise FlowBranchValidationError(f"[FlowService] 节点{node.name}的分支字段缺失") + if not input_parameters["choices"]: + logger.error(f"[FlowService] 节点{node.name}的分支字段为空") + raise FlowBranchValidationError(f"[FlowService] 节点{node.name}的分支字段为空") + for choice in input_parameters["choices"]: + if "branch_id" not in choice: + err = f"[FlowService] 节点{node.name}的分支choice缺少branch_id字段" + logger.error(err) + raise FlowBranchValidationError(err) + if choice["branch_id"] in node_branch_map[node.step_id]: + err = f"[FlowService] 节点{node.name}的分支{choice['branch_id']}重复" logger.error(err) raise Exception(err) for illegal_char in branch_illegal_chars: - if illegal_char in choice["branchId"]: - err = f"[FlowService] 节点{node.name}的分支{choice['branchId']}名称中含有非法字符" + if illegal_char in choice["branch_id"]: + err = f"[FlowService] 节点{node.name}的分支{choice['branch_id']}名称中含有非法字符" logger.error(err) raise Exception(err) - node_branch_map[node.step_id].add(choice["branchId"]) + node_branch_map[node.step_id].add(choice["branch_id"]) else: node_branch_map[node.step_id].add("") valid_edges = [] @@ -133,7 +142,6 @@ class FlowService: branches = {} in_deg = {} out_deg = {} - for e in edges: if e.edge_id in ids: err = f"[FlowService] 边{e.edge_id}的id重复" diff --git a/apps/services/knowledge.py b/apps/services/knowledge.py index 9b4077f92cf44bca42d321cc20975d7ea2e5cadd..bd8dfc9e808f642a8eecf493a96f762dad8ae7b9 100644 --- a/apps/services/knowledge.py +++ b/apps/services/knowledge.py @@ -138,7 +138,11 @@ class KnowledgeBaseManager: return [] kb_ids_update_success = [] kb_item_dict_list = [] - team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None) + try: + team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None) + except Exception as e: + logger.error(f"[KnowledgeBaseManager] 获取团队知识库列表失败: {e}") + team_kb_list = [] for team_kb in team_kb_list: for kb in team_kb["kbList"]: if str(kb["kbId"]) in kb_ids: diff --git a/apps/services/mcp_service.py b/apps/services/mcp_service.py index 7cb880c0f08b1cc7727bde528be39ac748f07ae2..3c7ab3c5b42615a6f4005d5c032dd81481933511 100644 --- a/apps/services/mcp_service.py +++ b/apps/services/mcp_service.py @@ -2,6 +2,7 @@ """MCP服务管理器""" import logging +from logging import config import random import re from typing import Any @@ -28,8 +29,10 @@ from apps.schemas.mcp import ( MCPTool, MCPType, ) +from apps.services.user import UserManager from apps.schemas.request_data import UpdateMCPServiceRequest from apps.schemas.response_data import MCPServiceCardItem +from apps.constants import MCP_PATH logger = logging.getLogger(__name__) sqids = Sqids(min_length=6) @@ -66,10 +69,8 @@ class MCPServiceManager: mcp_list = await mcp_collection.find({"_id": mcp_id}, {"status": True}).to_list(None) for db_item in mcp_list: status = db_item.get("status") - if MCPInstallStatus.READY.value == status: - return MCPInstallStatus.READY - if MCPInstallStatus.INSTALLING.value == status: - return MCPInstallStatus.INSTALLING + if status in MCPInstallStatus.__members__.values(): + return status return MCPInstallStatus.FAILED @staticmethod @@ -78,6 +79,8 @@ class MCPServiceManager: user_sub: str, keyword: str | None, page: int, + is_install: bool | None = None, + is_active: bool | None = None, ) -> list[MCPServiceCardItem]: """ 获取所有MCP服务列表 @@ -89,6 +92,17 @@ class MCPServiceManager: :return: MCP服务列表 """ filters = MCPServiceManager._build_filters(search_type, keyword) + if is_active is not None: + if is_active: + filters["activated"] = {"$in": [user_sub]} + else: + filters["activated"] = {"$nin": [user_sub]} + if not is_install: + user_info = await UserManager.get_userinfo_by_user_sub(user_sub) + if not user_info.is_admin: + filters["status"] = MCPInstallStatus.READY.value + else: + filters["status"] = MCPInstallStatus.READY.value mcpservice_pools = await MCPServiceManager._search_mcpservice(filters, page) return [ MCPServiceCardItem( @@ -198,7 +212,6 @@ class MCPServiceManager: base_filters = {"author": {"$regex": keyword, "$options": "i"}} return base_filters - @staticmethod async def create_mcpservice(data: UpdateMCPServiceRequest, user_sub: str) -> str: """ @@ -209,9 +222,9 @@ class MCPServiceManager: """ # 检查config if data.mcp_type == MCPType.SSE: - config = MCPServerSSEConfig.model_validate_json(data.config) + config = MCPServerSSEConfig.model_validate(data.config) else: - config = MCPServerStdioConfig.model_validate_json(data.config) + config = MCPServerStdioConfig.model_validate(data.config) # 构造Server mcp_server = MCPServerConfig( @@ -233,8 +246,24 @@ class MCPServiceManager: # 保存并载入配置 logger.info("[MCPServiceManager] 创建mcp:%s", mcp_server.name) + mcp_path = MCP_PATH / "template" / mcp_id / "project" + if isinstance(config, MCPServerStdioConfig): + index = None + for i in range(len(config.args)): + if not config.args[i] == "--directory": + continue + index = i + 1 + break + if index is not None: + if index >= len(config.args): + config.args.append(str(mcp_path)) + else: + config.args[index] = str(mcp_path) + else: + config.args += ["--directory", str(mcp_path)] + await MCPLoader._insert_template_db(mcp_id=mcp_id, config=mcp_server) await MCPLoader.save_one(mcp_id, mcp_server) - await MCPLoader.init_one_template(mcp_id=mcp_id, config=mcp_server) + await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.INIT) return mcp_id @staticmethod @@ -258,19 +287,21 @@ class MCPServiceManager: db_service = MCPCollection.model_validate(db_service) for user_id in db_service.activated: await MCPServiceManager.deactive_mcpservice(user_sub=user_id, service_id=data.service_id) - - await MCPLoader.init_one_template(mcp_id=data.service_id, config=MCPServerConfig( + mcp_config = MCPServerConfig( name=data.name, overview=data.overview, description=data.description, - config=MCPServerStdioConfig.model_validate_json( + config=MCPServerStdioConfig.model_validate( data.config, - ) if data.mcp_type == MCPType.STDIO else MCPServerSSEConfig.model_validate_json( + ) if data.mcp_type == MCPType.STDIO else MCPServerSSEConfig.model_validate( data.config, ), type=data.mcp_type, author=user_sub, - )) + ) + await MCPLoader._insert_template_db(mcp_id=data.service_id, config=mcp_config) + await MCPLoader.save_one(mcp_id=data.service_id, config=mcp_config) + await MCPLoader.update_template_status(data.service_id, MCPInstallStatus.INIT) # 返回服务ID return data.service_id @@ -297,6 +328,7 @@ class MCPServiceManager: async def active_mcpservice( user_sub: str, service_id: str, + mcp_env: dict[str, Any] | None = None, ) -> None: """ 激活MCP服务 @@ -310,7 +342,7 @@ class MCPServiceManager: for item in status: mcp_status = item.get("status", MCPInstallStatus.INSTALLING) if mcp_status == MCPInstallStatus.READY: - await MCPLoader.user_active_template(user_sub, service_id) + await MCPLoader.user_active_template(user_sub, service_id, mcp_env) else: err = "[MCPServiceManager] MCP服务未准备就绪" raise RuntimeError(err) @@ -371,3 +403,28 @@ class MCPServiceManager: image.save(MCP_ICON_PATH / f"{service_id}.png", format="PNG", optimize=True, compress_level=9) return f"/static/mcp/{service_id}.png" + + @staticmethod + async def install_mcpservice(user_sub: str, service_id: str, install: bool) -> None: + """ + 安装或卸载MCP服务 + + :param user_sub: str: 用户ID + :param service_id: str: MCP服务ID + :param install: bool: 是否安装 + :return: 无 + """ + service_collection = MongoDB().get_collection("mcp") + db_service = await service_collection.find_one({"_id": service_id, "author": user_sub}) + db_service = MCPCollection.model_validate(db_service) + if install: + if db_service.status == MCPInstallStatus.INSTALLING or db_service.status == MCPInstallStatus.READY: + err = "[MCPServiceManager] MCP服务已处于安装中或已准备就绪" + raise Exception(err) + mcp_config = await MCPLoader.get_config(service_id) + await MCPLoader.init_one_template(mcp_id=service_id, config=mcp_config) + else: + if db_service.status != MCPInstallStatus.INSTALLING: + err = "[MCPServiceManager] 只能卸载处于安装中的MCP服务" + raise Exception(err) + await MCPLoader.cancel_installing_task([service_id]) diff --git a/apps/services/node.py b/apps/services/node.py index 6f0d492edacdd7611324a38953417e0fa769249b..3fb311bf67b14de543567bf559540ced7aa8b11e 100644 --- a/apps/services/node.py +++ b/apps/services/node.py @@ -4,6 +4,7 @@ import logging from typing import TYPE_CHECKING, Any +from apps.schemas.enum_var import SpecialCallType from apps.common.mongo import MongoDB from apps.schemas.node import APINode from apps.schemas.pool import NodePool @@ -16,6 +17,7 @@ NODE_TYPE_MAP = { "API": APINode, } + class NodeManager: """Node管理器""" @@ -29,7 +31,6 @@ class NodeManager: raise ValueError(err) return node["call_id"] - @staticmethod async def get_node(node_id: str) -> NodePool: """获取Node的类型""" @@ -40,7 +41,6 @@ class NodeManager: raise ValueError(err) return NodePool.model_validate(node) - @staticmethod async def get_node_name(node_id: str) -> str: """获取node的名称""" @@ -52,7 +52,6 @@ class NodeManager: return "" return node_doc["name"] - @staticmethod def merge_params_schema(params_schema: dict[str, Any], known_params: dict[str, Any]) -> dict[str, Any]: """递归合并参数Schema,将known_params中的值填充到params_schema的对应位置""" @@ -75,12 +74,13 @@ class NodeManager: return params_schema - @staticmethod async def get_node_params(node_id: str) -> tuple[dict[str, Any], dict[str, Any]]: """获取Node数据""" from apps.scheduler.pool.pool import Pool - + if node_id == SpecialCallType.EMPTY.value: + # 如果是空节点,返回空Schema + return {}, {} # 查找Node信息 logger.info("[NodeManager] 获取节点 %s", node_id) node_collection = MongoDB().get_collection("node") @@ -100,7 +100,6 @@ class NodeManager: err = f"[NodeManager] Call {call_id} 不存在" logger.error(err) raise ValueError(err) - # 返回参数Schema return ( NodeManager.merge_params_schema(call_class.model_json_schema(), node_data.known_params or {}), diff --git a/apps/services/parameter.py b/apps/services/parameter.py new file mode 100644 index 0000000000000000000000000000000000000000..2b85c62e341ebaedb3af025ca59d869e82d9a46c --- /dev/null +++ b/apps/services/parameter.py @@ -0,0 +1,89 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""flow Manager""" + +import logging + +from pymongo import ASCENDING + +from apps.services.node import NodeManager +from apps.schemas.flow_topology import FlowItem +from apps.scheduler.slot.slot import Slot +from apps.scheduler.call.choice.condition_handler import ConditionHandler +from apps.scheduler.call.choice.schema import ( + NumberOperate, + StringOperate, + ListOperate, + BoolOperate, + DictOperate, + Type +) +from apps.schemas.response_data import ( + OperateAndBindType, + ParamsNode, + StepParams, +) +from apps.services.node import NodeManager +logger = logging.getLogger(__name__) + + +class ParameterManager: + """Parameter Manager""" + @staticmethod + async def get_operate_and_bind_type(param_type: Type) -> list[OperateAndBindType]: + """Get operate and bind type""" + result = [] + operate = None + if param_type == Type.NUMBER: + operate = NumberOperate + elif param_type == Type.STRING: + operate = StringOperate + elif param_type == Type.LIST: + operate = ListOperate + elif param_type == Type.BOOL: + operate = BoolOperate + elif param_type == Type.DICT: + operate = DictOperate + if operate: + for item in operate: + result.append(OperateAndBindType( + operate=item, + bind_type=(await ConditionHandler.get_value_type_from_operate(item)))) + return result + + @staticmethod + async def get_pre_params_by_flow_and_step_id(flow: FlowItem, step_id: str) -> list[StepParams]: + """Get pre params by flow and step id""" + index = 0 + q = [step_id] + in_edges = {} + step_id_to_node_id = {} + step_id_to_node_name = {} + for step in flow.nodes: + step_id_to_node_id[step.step_id] = step.node_id + step_id_to_node_name[step.step_id] = step.name + for edge in flow.edges: + if edge.target_node not in in_edges: + in_edges[edge.target_node] = [] + in_edges[edge.target_node].append(edge.source_node) + while index < len(q): + tmp_step_id = q[index] + index += 1 + for i in range(len(in_edges.get(tmp_step_id, []))): + pre_node_id = in_edges[tmp_step_id][i] + if pre_node_id not in q: + q.append(pre_node_id) + pre_step_params = [] + for i in range(1, len(q)): + step_id = q[i] + node_id = step_id_to_node_id.get(step_id) + _, output_schema = await NodeManager.get_node_params(node_id) + slot = Slot(output_schema) + params_node = slot.get_params_node_from_schema() + pre_step_params.append( + StepParams( + stepId=step_id, + name=step_id_to_node_name.get(step_id), + paramsNode=params_node + ) + ) + return pre_step_params diff --git a/apps/services/rag.py b/apps/services/rag.py index 6b6c843dc6fc09a14809ef0a11bc44b2da434f29..8b95d2b1b3c9fd20df54290b1b2534f8114a2e56 100644 --- a/apps/services/rag.py +++ b/apps/services/rag.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """对接Euler Copilot RAG""" +from datetime import UTC, datetime import json import logging from collections.abc import AsyncGenerator @@ -17,7 +18,6 @@ from apps.schemas.collection import LLM from apps.schemas.config import LLMConfig from apps.schemas.enum_var import EventType from apps.schemas.rag_data import RAGQueryReq -from apps.services.activity import Activity from apps.services.session import SessionManager logger = logging.getLogger(__name__) @@ -156,9 +156,11 @@ class RAG: "id": doc_chunk["docId"], "order": doc_cnt, "name": doc_chunk.get("docName", ""), + "author": doc_chunk.get("docAuthor", ""), "extension": doc_chunk.get("docExtension", ""), "abstract": doc_chunk.get("docAbstract", ""), "size": doc_chunk.get("docSize", 0), + "created_at": doc_chunk.get("docCreatedAt", round(datetime.now(UTC).timestamp(), 3)), }) doc_id_map[doc_chunk["docId"]] = doc_cnt doc_index = doc_id_map[doc_chunk["docId"]] @@ -254,8 +256,6 @@ class RAG: result_only=False, model=llm.model_name, ): - if not await Activity.is_active(user_sub): - return chunk = buffer + chunk # 防止脚注被截断 if len(chunk) >= 2 and chunk[-2:] != "]]": diff --git a/apps/services/record.py b/apps/services/record.py index 5c8a89dfd224166ac822e2eeb47fb79f1fefe7e4..cf8373b042ab409662dcf1d2e08ebc8ac3666d1b 100644 --- a/apps/services/record.py +++ b/apps/services/record.py @@ -9,7 +9,7 @@ from apps.schemas.record import ( Record, RecordGroup, ) - +from apps.schemas.enum_var import FlowStatus logger = logging.getLogger(__name__) @@ -17,7 +17,7 @@ class RecordManager: """问答对相关操作""" @staticmethod - async def create_record_group(group_id: str, user_sub: str, conversation_id: str, task_id: str) -> str | None: + async def create_record_group(group_id: str, user_sub: str, conversation_id: str) -> str | None: """创建问答组""" mongo = MongoDB() record_group_collection = mongo.get_collection("record_group") @@ -26,7 +26,6 @@ class RecordManager: _id=group_id, user_sub=user_sub, conversation_id=conversation_id, - task_id=task_id, ) try: @@ -49,6 +48,10 @@ class RecordManager: mongo = MongoDB() group_collection = mongo.get_collection("record_group") try: + await group_collection.update_one( + {"_id": group_id, "user_sub": user_sub}, + {"$pull": {"records": {"id": record.id}}} + ) await group_collection.update_one( {"_id": group_id, "user_sub": user_sub}, {"$push": {"records": record.model_dump(by_alias=True)}}, @@ -133,6 +136,19 @@ class RecordManager: logger.exception("[RecordManager] 查询问答组失败") return [] + @staticmethod + async def update_record_flow_status_to_cancelled_by_task_ids(task_ids: list[str]) -> None: + """更新Record关联的Flow状态""" + record_group_collection = MongoDB().get_collection("record_group") + try: + await record_group_collection.update_many( + {"records.task_id": {"$in": task_ids}, "records.flow.flow_status": {"$nin": [FlowStatus.ERROR.value, FlowStatus.SUCCESS.value]}}, + {"$set": {"records.$[elem].flow.flow_status": FlowStatus.CANCELLED}}, + array_filters=[{"elem.flow.flow_id": {"$in": task_ids}}], + ) + except Exception: + logger.exception("[RecordManager] 更新Record关联的Flow状态失败") + @staticmethod async def verify_record_in_group(group_id: str, record_id: str, user_sub: str) -> bool: """ @@ -151,7 +167,6 @@ class RecordManager: logger.exception("[RecordManager] 验证记录是否在组中失败") return False - @staticmethod async def check_group_id(group_id: str, user_sub: str) -> bool: """检查group_id是否存在""" diff --git a/apps/services/task.py b/apps/services/task.py index 1e672be690a6f17896ab01bc7149aa353987ecf7..49c9f4b037cb95408588ce304d0f483af88b4ed8 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -13,6 +13,7 @@ from apps.schemas.task import ( TaskIds, TaskRuntime, TaskTokens, + FlowStepHistory ) from apps.services.record import RecordManager @@ -45,7 +46,6 @@ class TaskManager: return Task.model_validate(task) - @staticmethod async def get_task_by_group_id(group_id: str, conversation_id: str) -> Task | None: """获取组ID的最后一条问答组关联的任务""" @@ -58,7 +58,6 @@ class TaskManager: task = await task_collection.find_one({"_id": record_group_obj.task_id}) return Task.model_validate(task) - @staticmethod async def get_task_by_task_id(task_id: str) -> Task | None: """根据task_id获取任务""" @@ -68,9 +67,8 @@ class TaskManager: return None return Task.model_validate(task) - @staticmethod - async def get_context_by_record_id(record_group_id: str, record_id: str) -> list[dict[str, Any]]: + async def get_context_by_record_id(record_group_id: str, record_id: str) -> list[FlowStepHistory]: """根据record_group_id获取flow信息""" record_group_collection = MongoDB().get_collection("record_group") flow_context_collection = MongoDB().get_collection("flow_context") @@ -85,59 +83,72 @@ class TaskManager: return [] flow_context_list = [] - for flow_context_id in records[0]["records"]["flow"]: + for flow_context_id in records[0]["records"]["flow"]["history_ids"]: flow_context = await flow_context_collection.find_one({"_id": flow_context_id}) if flow_context: - flow_context_list.append(flow_context) + flow_context_list.append(FlowStepHistory.model_validate(flow_context)) except Exception: logger.exception("[TaskManager] 获取record_id的flow信息失败") return [] else: return flow_context_list - @staticmethod - async def get_context_by_task_id(task_id: str, length: int = 0) -> list[dict[str, Any]]: + async def get_context_by_task_id(task_id: str, length: int | None = None) -> list[FlowStepHistory]: """根据task_id获取flow信息""" flow_context_collection = MongoDB().get_collection("flow_context") flow_context = [] try: - async for history in flow_context_collection.find( - {"task_id": task_id}, - ).sort( - "created_at", -1, - ).limit(length): - flow_context += [history] + if length is None: + async for context in flow_context_collection.find({"task_id": task_id}): + flow_context.append(FlowStepHistory.model_validate(context)) + else: + async for context in flow_context_collection.find({"task_id": task_id}).limit(length): + flow_context.append(FlowStepHistory.model_validate(context)) except Exception: logger.exception("[TaskManager] 获取task_id的flow信息失败") return [] else: return flow_context + @staticmethod + async def init_new_task( + user_sub: str, + session_id: str | None = None, + post_body: RequestData | None = None, + ) -> Task: + """获取任务块""" + return Task( + _id=str(uuid.uuid4()), + ids=TaskIds( + user_sub=user_sub if user_sub else "", + session_id=session_id if session_id else "", + conversation_id=post_body.conversation_id, + group_id=post_body.group_id if post_body.group_id else "", + ), + question=post_body.question if post_body else "", + group_id=post_body.group_id if post_body else "", + tokens=TaskTokens(), + runtime=TaskRuntime(), + ) @staticmethod - async def save_flow_context(task_id: str, flow_context: list[dict[str, Any]]) -> None: + async def save_flow_context(task_id: str, flow_context: list[FlowStepHistory]) -> None: """保存flow信息到flow_context""" flow_context_collection = MongoDB().get_collection("flow_context") try: - for history in flow_context: - # 查找是否存在 - current_context = await flow_context_collection.find_one({ - "task_id": task_id, - "_id": history["_id"], - }) - if current_context: - await flow_context_collection.update_one( - {"_id": current_context["_id"]}, - {"$set": history}, - ) - else: - await flow_context_collection.insert_one(history) + # 删除旧的flow_context + await flow_context_collection.delete_many({"task_id": task_id}) + if not flow_context: + return + await flow_context_collection.insert_many( + [history.model_dump(exclude_none=True, by_alias=True) for history in flow_context], + ordered=False, + ) except Exception: logger.exception("[TaskManager] 保存flow执行记录失败") - @staticmethod async def delete_task_by_task_id(task_id: str) -> None: """通过task_id删除Task信息""" @@ -148,9 +159,27 @@ class TaskManager: if task: await task_collection.delete_one({"_id": task_id}) + @staticmethod + async def delete_tasks_by_conversation_id(conversation_id: str) -> list[str]: + """通过ConversationID删除Task信息""" + mongo = MongoDB() + task_collection = mongo.get_collection("task") + task_ids = [] + try: + async for task in task_collection.find( + {"conversation_id": conversation_id}, + {"_id": 1}, + ): + task_ids.append(task["_id"]) + if task_ids: + await task_collection.delete_many({"conversation_id": conversation_id}) + return task_ids + except Exception: + logger.exception("[TaskManager] 删除ConversationID的Task信息失败") + return [] @staticmethod - async def delete_tasks_by_conversation_id(conversation_id: str) -> None: + async def delete_tasks_and_flow_context_by_conversation_id(conversation_id: str) -> None: """通过ConversationID删除Task信息""" mongo = MongoDB() task_collection = mongo.get_collection("task") @@ -167,52 +196,6 @@ class TaskManager: await task_collection.delete_many({"conversation_id": conversation_id}, session=session) await flow_context_collection.delete_many({"task_id": {"$in": task_ids}}, session=session) - - @classmethod - async def get_task( - cls, - task_id: str | None = None, - session_id: str | None = None, - post_body: RequestData | None = None, - user_sub: str | None = None, - ) -> Task: - """获取任务块""" - if task_id: - try: - task = await cls.get_task_by_task_id(task_id) - if task: - return task - except Exception: - logger.exception("[TaskManager] 通过task_id获取任务失败") - - logger.info("[TaskManager] 未提供task_id,通过session_id获取任务") - if not session_id or not post_body: - err = ( - "session_id 和 conversation_id 或 group_id 和 conversation_id 是恢复/创建任务的必要条件。" - ) - raise ValueError(err) - - if post_body.group_id: - task = await cls.get_task_by_group_id(post_body.group_id, post_body.conversation_id) - else: - task = await cls.get_task_by_conversation_id(post_body.conversation_id) - - if task: - return task - return Task( - _id=str(uuid.uuid4()), - ids=TaskIds( - user_sub=user_sub if user_sub else "", - session_id=session_id if session_id else "", - conversation_id=post_body.conversation_id, - group_id=post_body.group_id if post_body.group_id else "", - ), - state=None, - tokens=TaskTokens(), - runtime=TaskRuntime(), - ) - - @classmethod async def save_task(cls, task_id: str, task: Task) -> None: """保存任务块""" diff --git a/apps/services/user.py b/apps/services/user.py index 2721d3773bd43e2a8fda5b371d6336fcf0b9b7f3..1b96df18143f5a87b102f68e2bb4e2df62753f05 100644 --- a/apps/services/user.py +++ b/apps/services/user.py @@ -4,6 +4,7 @@ import logging from datetime import UTC, datetime +from apps.schemas.request_data import UserUpdateRequest from apps.common.mongo import MongoDB from apps.schemas.collection import User from apps.services.conversation import ConversationManager @@ -28,7 +29,7 @@ class UserManager: ).model_dump(by_alias=True)) @staticmethod - async def get_all_user_sub() -> list[str]: + async def get_all_user_sub(page_size: int = 20, page_cnt: int = 1, filter_user_subs: list[str] = []) -> tuple[list[str], int]: """ 获取所有用户的sub @@ -36,7 +37,13 @@ class UserManager: """ mongo = MongoDB() user_collection = mongo.get_collection("user") - return [user["_id"] async for user in user_collection.find({}, {"_id": 1})] + total = await user_collection.count_documents({}) - len(filter_user_subs) + + users = await user_collection.find( + {"_id": {"$nin": filter_user_subs}}, + {"_id": 1}, + ).skip((page_cnt - 1) * page_size).limit(page_size).to_list(length=page_size) + return [user["_id"] for user in users], total @staticmethod async def get_userinfo_by_user_sub(user_sub: str) -> User | None: @@ -52,7 +59,25 @@ class UserManager: return User(**user_data) if user_data else None @staticmethod - async def update_userinfo_by_user_sub(user_sub: str, *, refresh_revision: bool = False) -> bool: + async def update_userinfo_by_user_sub(user_sub: str, data: UserUpdateRequest) -> None: + """ + 根据用户sub更新用户信息 + + :param user_sub: 用户sub + :param data: 用户更新信息 + :return: 是否更新成功 + """ + mongo = MongoDB() + user_collection = mongo.get_collection("user") + update_dict = { + "$set": { + "auto_execute": data.auto_execute, + } + } + await user_collection.update_one({"_id": user_sub}, update_dict) + + @staticmethod + async def update_refresh_revision_by_user_sub(user_sub: str, *, refresh_revision: bool = False) -> bool: """ 根据用户sub更新用户信息 diff --git a/deploy/chart/euler_copilot/configs/rag/.env b/deploy/chart/euler_copilot/configs/rag/.env index 5e83f3f415091b156644b33bef9291a2d99d1931..0708fd0762d1f8c789ee5a4a873fb06880ce1463 100644 --- a/deploy/chart/euler_copilot/configs/rag/.env +++ b/deploy/chart/euler_copilot/configs/rag/.env @@ -52,7 +52,7 @@ HALF_KEY3=${halfKey3} #LLM config MODEL_NAME={{ .Values.models.answer.name }} -OPENAI_API_BASE={{ .Values.models.answer.endpoint }}/v1 +OPENAI_API_BASE={{ .Values.models.answer.endpoint }} OPENAI_API_KEY={{ default "" .Values.models.answer.key }} MAX_TOKENS={{ default 2048 .Values.models.answer.maxTokens }} diff --git a/perf.data b/perf.data new file mode 100644 index 0000000000000000000000000000000000000000..f319e212c75a402bbf0ff2e488552bad93ff5b8b Binary files /dev/null and b/perf.data differ diff --git a/pyproject.toml b/pyproject.toml index 681ea9fb4552650da9148a22dd55d2063dca2410..22b07eacd07d19ac7478b6af20fb0f972da742db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,4 +50,4 @@ dev = [ "ruff==0.11.2", "sphinx==8.2.3", "sphinx-rtd-theme==3.0.2", -] +] \ No newline at end of file diff --git a/tests/common/test_queue.py b/tests/common/test_queue.py index 5375180a3f453a9d9d18d17b95b0ed9c39d8c5c4..db1f5ead746f35dd510b30698746e1a25f08d02d 100644 --- a/tests/common/test_queue.py +++ b/tests/common/test_queue.py @@ -74,8 +74,8 @@ async def test_push_output_with_flow(message_queue, mock_task): mock_task.state.flow_id = "flow_id" mock_task.state.step_id = "step_id" mock_task.state.step_name = "step_name" - mock_task.state.status = "running" - + mock_task.state.step_status = "running" + await message_queue.init("test_task") await message_queue.push_output(mock_task, EventType.TEXT_ADD, {}) diff --git a/tests/manager/test_user.py b/tests/manager/test_user.py index d350ca307a9f0aa2d6532cae614b8633b1ae0f81..ab6eeab4c498a2619c4f08c5b548a6096611aeaa 100644 --- a/tests/manager/test_user.py +++ b/tests/manager/test_user.py @@ -73,7 +73,7 @@ class TestUserManager(unittest.TestCase): mock_mysql_db_instance.get_session.return_value = mock_session # 调用被测方法 - updated_userinfo = UserManager.update_userinfo_by_user_sub(userinfo, refresh_revision=True) + updated_userinfo = UserManager.update_refresh_revision_by_user_sub(userinfo, refresh_revision=True) # 断言返回的用户信息的 revision_number 是否与原始用户信息一致 self.assertEqual(updated_userinfo.revision_number, userinfo.revision_number)