From e8155964f14717f5aa75dbea9dd431a4e6887d21 Mon Sep 17 00:00:00 2001 From: Loshawn <2428333123@qq.com> Date: Fri, 15 Aug 2025 20:03:20 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9B=BD=E9=99=85=E5=8C=96-Agent?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/mcp_agent/base.py | 4 +- apps/scheduler/mcp_agent/host.py | 62 +++---- apps/scheduler/mcp_agent/plan.py | 261 ++++++++++++++++++----------- apps/scheduler/mcp_agent/select.py | 89 ++++++---- 4 files changed, 252 insertions(+), 164 deletions(-) diff --git a/apps/scheduler/mcp_agent/base.py b/apps/scheduler/mcp_agent/base.py index ac3829b41..6cd60a934 100644 --- a/apps/scheduler/mcp_agent/base.py +++ b/apps/scheduler/mcp_agent/base.py @@ -8,7 +8,9 @@ from apps.llm.reasoning import ReasoningLLM logger = logging.getLogger(__name__) -class McpBase: +class MCPBase: + """MCP基类""" + @staticmethod async def get_resoning_result(prompt: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> str: """获取推理结果""" diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py index a06104f3b..b78272785 100644 --- a/apps/scheduler/mcp_agent/host.py +++ b/apps/scheduler/mcp_agent/host.py @@ -7,20 +7,15 @@ from typing import Any from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment -from mcp.types import TextContent -from apps.common.mongo import MongoDB -from apps.llm.reasoning import ReasoningLLM from apps.llm.function import JsonGenerator -from apps.scheduler.mcp_agent.base import McpBase +from apps.llm.reasoning import ReasoningLLM from apps.scheduler.mcp.prompt import MEMORY_TEMPLATE -from apps.scheduler.pool.mcp.client import MCPClient -from apps.scheduler.pool.mcp.pool import MCPPool +from apps.scheduler.mcp_agent.base import MCPBase from apps.scheduler.mcp_agent.prompt import GEN_PARAMS, REPAIR_PARAMS -from apps.schemas.enum_var import StepStatus -from apps.schemas.mcp import MCPPlanItem, MCPTool -from apps.schemas.task import Task, FlowStepHistory -from apps.services.task import TaskManager +from apps.schemas.mcp import MCPTool +from apps.schemas.task import Task +from apps.schemas.enum_var import LanguageType logger = logging.getLogger(__name__) @@ -36,26 +31,36 @@ def tojson_filter(value): return json.dumps(value, ensure_ascii=False, separators=(',', ':')) -_env.filters['tojson'] = tojson_filter +_env.filters["tojson"] = tojson_filter +LLM_QUERY_FIX = { + LanguageType.CHINESE: "请生成修复之后的工具参数", + LanguageType.ENGLISH: "Please generate the tool parameters after repair", +} -class MCPHost(McpBase): + +class MCPHost(MCPBase): """MCP宿主服务""" @staticmethod async def assemble_memory(task: Task) -> str: """组装记忆""" - return _env.from_string(MEMORY_TEMPLATE).render( + return _env.from_string(MEMORY_TEMPLATE[task.language]).render( context_list=task.context, ) @staticmethod - async def _get_first_input_params(mcp_tool: MCPTool, goal: str, current_goal: str, task: Task, - resoning_llm: ReasoningLLM = ReasoningLLM()) -> dict[str, Any]: + async def _get_first_input_params( + mcp_tool: MCPTool, + goal: str, + current_goal: str, + task: Task, + resoning_llm: ReasoningLLM = ReasoningLLM(), + ) -> dict[str, Any]: """填充工具参数""" # 更清晰的输入·指令,这样可以调用generate - prompt = _env.from_string(GEN_PARAMS).render( + prompt = _env.from_string(GEN_PARAMS[task.language]).render( tool_name=mcp_tool.name, tool_description=mcp_tool.description, goal=goal, @@ -63,10 +68,7 @@ class MCPHost(McpBase): input_schema=mcp_tool.input_schema, background_info=await MCPHost.assemble_memory(task), ) - 
result = await MCPHost.get_resoning_result( - prompt, - resoning_llm - ) + result = await MCPHost.get_resoning_result(prompt, resoning_llm) # 使用JsonGenerator解析结果 result = await MCPHost._parse_result( result, @@ -75,14 +77,18 @@ class MCPHost(McpBase): return result @staticmethod - async def _fill_params(mcp_tool: MCPTool, - goal: str, - current_goal: str, - current_input: dict[str, Any], - error_message: str = "", params: dict[str, Any] = {}, - params_description: str = "") -> dict[str, Any]: - llm_query = "请生成修复之后的工具参数" - prompt = _env.from_string(REPAIR_PARAMS).render( + async def _fill_params( + mcp_tool: MCPTool, + goal: str, + current_goal: str, + current_input: dict[str, Any], + error_message: str = "", + params: dict[str, Any] = {}, + params_description: str = "", + language: LanguageType = LanguageType.CHINESE, + ) -> dict[str, Any]: + llm_query = LLM_QUERY_FIX[language] + prompt = _env.from_string(REPAIR_PARAMS[language]).render( tool_name=mcp_tool.name, goal=goal, current_goal=current_goal, diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py index aaacc8442..31b3c7c05 100644 --- a/apps/scheduler/mcp_agent/plan.py +++ b/apps/scheduler/mcp_agent/plan.py @@ -1,41 +1,44 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """MCP 用户目标拆解与规划""" -from typing import Any, AsyncGenerator + +import logging +from collections.abc import AsyncGenerator +from typing import Any + from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment -import logging + from apps.llm.reasoning import ReasoningLLM -from apps.llm.function import JsonGenerator -from apps.scheduler.mcp_agent.base import McpBase +from apps.scheduler.mcp_agent.base import MCPBase from apps.scheduler.mcp_agent.prompt import ( + CHANGE_ERROR_MESSAGE_TO_DESCRIPTION, + CREATE_PLAN, EVALUATE_GOAL, + FINAL_ANSWER, + GEN_STEP, GENERATE_FLOW_NAME, + GET_MISSING_PARAMS, GET_REPLAN_START_STEP_INDEX, - CREATE_PLAN, + IS_PARAM_ERROR, RECREATE_PLAN, - GEN_STEP, - TOOL_SKIP, RISK_EVALUATE, TOOL_EXECUTE_ERROR_TYPE_ANALYSIS, - IS_PARAM_ERROR, - CHANGE_ERROR_MESSAGE_TO_DESCRIPTION, - GET_MISSING_PARAMS, - FINAL_ANSWER + TOOL_SKIP, ) -from apps.schemas.task import Task +from apps.schemas.enum_var import LanguageType +from apps.scheduler.slot.slot import Slot from apps.schemas.mcp import ( GoalEvaluationResult, - RestartStepIndex, - ToolSkip, - ToolRisk, IsParamError, - ToolExcutionErrorType, MCPPlan, + MCPTool, + RestartStepIndex, Step, - MCPPlanItem, - MCPTool + ToolExcutionErrorType, + ToolRisk, + ToolSkip, ) -from apps.scheduler.slot.slot import Slot +from apps.schemas.task import Task _env = SandboxedEnvironment( loader=BaseLoader, @@ -46,36 +49,31 @@ _env = SandboxedEnvironment( logger = logging.getLogger(__name__) -class MCPPlanner(McpBase): +class MCPPlanner(MCPBase): """MCP 用户目标拆解与规划""" @staticmethod async def evaluate_goal( - goal: str, - tool_list: list[MCPTool], - resoning_llm: ReasoningLLM = ReasoningLLM()) -> GoalEvaluationResult: + goal: str, tool_list: list[MCPTool], resoning_llm: ReasoningLLM = ReasoningLLM(), language: LanguageType = LanguageType.CHINESE, + ) -> GoalEvaluationResult: """评估用户目标的可行性""" # 获取推理结果 - result = await MCPPlanner._get_reasoning_evaluation(goal, tool_list, resoning_llm) - - # 解析为结构化数据 - evaluation = await MCPPlanner._parse_evaluation_result(result) + result = await MCPPlanner._get_reasoning_evaluation(goal, tool_list, resoning_llm, language) # 返回评估结果 - return evaluation + return await MCPPlanner._parse_evaluation_result(result) 
@staticmethod async def _get_reasoning_evaluation( - goal, tool_list: list[MCPTool], - resoning_llm: ReasoningLLM = ReasoningLLM()) -> str: + goal, tool_list: list[MCPTool], resoning_llm: ReasoningLLM = ReasoningLLM(), language: LanguageType = LanguageType.CHINESE, + ) -> str: """获取推理大模型的评估结果""" - template = _env.from_string(EVALUATE_GOAL) + template = _env.from_string(EVALUATE_GOAL[language]) prompt = template.render( goal=goal, tools=tool_list, ) - result = await MCPPlanner.get_resoning_result(prompt, resoning_llm) - return result + return await MCPPlanner.get_resoning_result(prompt, resoning_llm) @staticmethod async def _parse_evaluation_result(result: str) -> GoalEvaluationResult: @@ -85,27 +83,38 @@ class MCPPlanner(McpBase): # 使用GoalEvaluationResult模型解析结果 return GoalEvaluationResult.model_validate(evaluation) - async def get_flow_name(user_goal: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> str: + async def get_flow_name( + user_goal: str, + resoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> str: """获取当前流程的名称""" - result = await MCPPlanner._get_reasoning_flow_name(user_goal, resoning_llm) - return result + + return await MCPPlanner._get_reasoning_flow_name(user_goal, resoning_llm, language) @staticmethod - async def _get_reasoning_flow_name(user_goal: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> str: + async def _get_reasoning_flow_name( + user_goal: str, + resoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> str: """获取推理大模型的流程名称""" - template = _env.from_string(GENERATE_FLOW_NAME) + template = _env.from_string(GENERATE_FLOW_NAME[language]) prompt = template.render(goal=user_goal) - result = await MCPPlanner.get_resoning_result(prompt, resoning_llm) - return result + return await MCPPlanner.get_resoning_result(prompt, resoning_llm) @staticmethod async def get_replan_start_step_index( - user_goal: str, error_message: str, current_plan: MCPPlan | None = None, - history: str = "", - reasoning_llm: ReasoningLLM = ReasoningLLM()) -> RestartStepIndex: + user_goal: str, + error_message: str, + current_plan: MCPPlan | None = None, + history: str = "", + reasoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> RestartStepIndex: """获取重新规划的步骤索引""" # 获取推理结果 - template = _env.from_string(GET_REPLAN_START_STEP_INDEX) + template = _env.from_string(GET_REPLAN_START_STEP_INDEX[language]) prompt = template.render( goal=user_goal, error_message=error_message, @@ -123,26 +132,40 @@ class MCPPlanner(McpBase): @staticmethod async def create_plan( - user_goal: str, is_replan: bool = False, error_message: str = "", current_plan: MCPPlan | None = None, - tool_list: list[MCPTool] = [], - max_steps: int = 6, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> MCPPlan: + user_goal: str, + is_replan: bool = False, + error_message: str = "", + current_plan: MCPPlan | None = None, + tool_list: list[MCPTool] = [], + max_steps: int = 6, + reasoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> MCPPlan: """规划下一步的执行流程,并输出""" # 获取推理结果 - result = await MCPPlanner._get_reasoning_plan(user_goal, is_replan, error_message, current_plan, tool_list, max_steps, reasoning_llm) + result = await MCPPlanner._get_reasoning_plan( + user_goal, is_replan, error_message, current_plan, tool_list, max_steps, reasoning_llm, language + ) # 解析为结构化数据 return await MCPPlanner._parse_plan_result(result, max_steps) @staticmethod async def 
_get_reasoning_plan( - user_goal: str, is_replan: bool = False, error_message: str = "", current_plan: MCPPlan | None = None, - tool_list: list[MCPTool] = [], - max_steps: int = 10, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> str: + user_goal: str, + is_replan: bool = False, + error_message: str = "", + current_plan: MCPPlan | None = None, + tool_list: list[MCPTool] = [], + max_steps: int = 10, + reasoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> str: """获取推理大模型的结果""" # 格式化Prompt tool_ids = [tool.id for tool in tool_list] if is_replan: - template = _env.from_string(RECREATE_PLAN) + template = _env.from_string(RECREATE_PLAN[language]) prompt = template.render( current_plan=current_plan.model_dump(exclude_none=True, by_alias=True), error_message=error_message, @@ -151,14 +174,13 @@ class MCPPlanner(McpBase): max_num=max_steps, ) else: - template = _env.from_string(CREATE_PLAN) + template = _env.from_string(CREATE_PLAN[language]) prompt = template.render( goal=user_goal, tools=tool_list, max_num=max_steps, ) - result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) - return result + return await MCPPlanner.get_resoning_result(prompt, reasoning_llm) @staticmethod async def _parse_plan_result(result: str, max_steps: int) -> MCPPlan: @@ -172,11 +194,15 @@ class MCPPlanner(McpBase): @staticmethod async def create_next_step( - goal: str, history: str, tools: list[MCPTool], - reasoning_llm: ReasoningLLM = ReasoningLLM()) -> Step: + goal: str, + history: str, + tools: list[MCPTool], + reasoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> Step: """创建下一步的执行步骤""" # 获取推理结果 - template = _env.from_string(GEN_STEP) + template = _env.from_string(GEN_STEP[language]) prompt = template.render(goal=goal, history=history, tools=tools) result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) @@ -198,11 +224,17 @@ class MCPPlanner(McpBase): @staticmethod async def tool_skip( - task: Task, step_id: str, step_name: str, step_instruction: str, step_content: str, - reasoning_llm: ReasoningLLM = ReasoningLLM()) -> ToolSkip: + task: Task, + step_id: str, + step_name: str, + step_instruction: str, + step_content: str, + reasoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> ToolSkip: """判断当前步骤是否需要跳过""" # 获取推理结果 - template = _env.from_string(TOOL_SKIP) + template = _env.from_string(TOOL_SKIP[language]) from apps.scheduler.mcp_agent.host import MCPHost history = await MCPHost.assemble_memory(task) prompt = template.render( @@ -223,32 +255,38 @@ class MCPPlanner(McpBase): @staticmethod async def get_tool_risk( - tool: MCPTool, input_parm: dict[str, Any], - additional_info: str = "", resoning_llm: ReasoningLLM = ReasoningLLM()) -> ToolRisk: + tool: MCPTool, + input_parm: dict[str, Any], + additional_info: str = "", + resoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> ToolRisk: """获取MCP工具的风险评估结果""" # 获取推理结果 - result = await MCPPlanner._get_reasoning_risk(tool, input_parm, additional_info, resoning_llm) - - # 解析为结构化数据 - risk = await MCPPlanner._parse_risk_result(result) + result = await MCPPlanner._get_reasoning_risk( + tool, input_parm, additional_info, resoning_llm, language + ) # 返回风险评估结果 - return risk + return await MCPPlanner._parse_risk_result(result) @staticmethod async def _get_reasoning_risk( - tool: MCPTool, input_param: dict[str, Any], - additional_info: str, resoning_llm: ReasoningLLM) -> str: 
+ tool: MCPTool, + input_param: dict[str, Any], + additional_info: str, + resoning_llm: ReasoningLLM, + language: LanguageType = LanguageType.CHINESE, + ) -> str: """获取推理大模型的风险评估结果""" - template = _env.from_string(RISK_EVALUATE) + template = _env.from_string(RISK_EVALUATE[language]) prompt = template.render( tool_name=tool.name, tool_description=tool.description, input_param=input_param, additional_info=additional_info, ) - result = await MCPPlanner.get_resoning_result(prompt, resoning_llm) - return result + return await MCPPlanner.get_resoning_result(prompt, resoning_llm) @staticmethod async def _parse_risk_result(result: str) -> ToolRisk: @@ -260,11 +298,16 @@ class MCPPlanner(McpBase): @staticmethod async def _get_reasoning_tool_execute_error_type( - user_goal: str, current_plan: MCPPlan, - tool: MCPTool, input_param: dict[str, Any], - error_message: str, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> str: + user_goal: str, + current_plan: MCPPlan, + tool: MCPTool, + input_param: dict[str, Any], + error_message: str, + reasoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> str: """获取推理大模型的工具执行错误类型""" - template = _env.from_string(TOOL_EXECUTE_ERROR_TYPE_ANALYSIS) + template = _env.from_string(TOOL_EXECUTE_ERROR_TYPE_ANALYSIS[language]) prompt = template.render( goal=user_goal, current_plan=current_plan.model_dump(exclude_none=True, by_alias=True), @@ -273,8 +316,7 @@ class MCPPlanner(McpBase): input_param=input_param, error_message=error_message, ) - result = await MCPPlanner.get_resoning_result(prompt, reasoning_llm) - return result + return await MCPPlanner.get_resoning_result(prompt, reasoning_llm) @staticmethod async def _parse_tool_execute_error_type_result(result: str) -> ToolExcutionErrorType: @@ -286,24 +328,35 @@ class MCPPlanner(McpBase): @staticmethod async def get_tool_execute_error_type( - user_goal: str, current_plan: MCPPlan, - tool: MCPTool, input_param: dict[str, Any], - error_message: str, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> ToolExcutionErrorType: + user_goal: str, + current_plan: MCPPlan, + tool: MCPTool, + input_param: dict[str, Any], + error_message: str, + reasoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> ToolExcutionErrorType: """获取MCP工具执行错误类型""" # 获取推理结果 result = await MCPPlanner._get_reasoning_tool_execute_error_type( - user_goal, current_plan, tool, input_param, error_message, reasoning_llm) - error_type = await MCPPlanner._parse_tool_execute_error_type_result(result) + user_goal, current_plan, tool, input_param, error_message, reasoning_llm, language + ) # 返回工具执行错误类型 - return error_type + return await MCPPlanner._parse_tool_execute_error_type_result(result) @staticmethod async def is_param_error( - goal: str, history: str, error_message: str, tool: MCPTool, step_description: str, input_params: dict - [str, Any], - reasoning_llm: ReasoningLLM = ReasoningLLM()) -> IsParamError: + goal: str, + history: str, + error_message: str, + tool: MCPTool, + step_description: str, + input_params: dict[str, Any], + reasoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> IsParamError: """判断错误信息是否是参数错误""" - tmplate = _env.from_string(IS_PARAM_ERROR) + tmplate = _env.from_string(IS_PARAM_ERROR[language]) prompt = tmplate.render( goal=goal, history=history, @@ -322,10 +375,14 @@ class MCPPlanner(McpBase): @staticmethod async def change_err_message_to_description( - error_message: str, tool: MCPTool, input_params: 
dict[str, Any], - reasoning_llm: ReasoningLLM = ReasoningLLM()) -> str: + error_message: str, + tool: MCPTool, + input_params: dict[str, Any], + reasoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> str: """将错误信息转换为工具描述""" - template = _env.from_string(CHANGE_ERROR_MESSAGE_TO_DESCRIPTION) + template = _env.from_string(CHANGE_ERROR_MESSAGE_TO_DESCRIPTION[language]) prompt = template.render( error_message=error_message, tool_name=tool.name, @@ -338,12 +395,15 @@ class MCPPlanner(McpBase): @staticmethod async def get_missing_param( - tool: MCPTool, - input_param: dict[str, Any], - error_message: str, reasoning_llm: ReasoningLLM = ReasoningLLM()) -> list[str]: + tool: MCPTool, + input_param: dict[str, Any], + error_message: str, + reasoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> list[str]: """获取缺失的参数""" slot = Slot(schema=tool.input_schema) - template = _env.from_string(GET_MISSING_PARAMS) + template = _env.from_string(GET_MISSING_PARAMS[language]) schema_with_null = slot.add_null_to_basic_types() prompt = template.render( tool_name=tool.name, @@ -359,10 +419,13 @@ class MCPPlanner(McpBase): @staticmethod async def generate_answer( - user_goal: str, memory: str, resoning_llm: ReasoningLLM = ReasoningLLM()) -> AsyncGenerator[ - str, None]: + user_goal: str, + memory: str, + resoning_llm: ReasoningLLM = ReasoningLLM(), + language: LanguageType = LanguageType.CHINESE, + ) -> AsyncGenerator[str, None]: """生成最终回答""" - template = _env.from_string(FINAL_ANSWER) + template = _env.from_string(FINAL_ANSWER[language]) prompt = template.render( memory=memory, goal=user_goal, diff --git a/apps/scheduler/mcp_agent/select.py b/apps/scheduler/mcp_agent/select.py index 075e08f00..aea6cd11f 100644 --- a/apps/scheduler/mcp_agent/select.py +++ b/apps/scheduler/mcp_agent/select.py @@ -3,28 +3,17 @@ import logging import random + from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment -from typing import AsyncGenerator, Any -from apps.llm.function import JsonGenerator -from apps.llm.reasoning import ReasoningLLM -from apps.common.lance import LanceDB -from apps.common.mongo import MongoDB -from apps.llm.embedding import Embedding -from apps.llm.function import FunctionLLM from apps.llm.reasoning import ReasoningLLM from apps.llm.token import TokenCalculator -from apps.scheduler.mcp_agent.base import McpBase +from apps.scheduler.mcp_agent.base import MCPBase from apps.scheduler.mcp_agent.prompt import TOOL_SELECT -from apps.schemas.mcp import ( - BaseModel, - MCPCollection, - MCPSelectResult, - MCPTool, - MCPToolIdsSelectResult -) -from apps.common.config import Config +from apps.schemas.mcp import MCPTool, MCPToolIdsSelectResult +from apps.schemas.enum_var import LanguageType + logger = logging.getLogger(__name__) _env = SandboxedEnvironment( @@ -38,24 +27,35 @@ FINAL_TOOL_ID = "FIANL" SUMMARIZE_TOOL_ID = "SUMMARIZE" -class MCPSelector(McpBase): +class MCPSelector(MCPBase): """MCP选择器""" @staticmethod async def select_top_tool( - goal: str, tool_list: list[MCPTool], - additional_info: str | None = None, top_n: int | None = None, - reasoning_llm: ReasoningLLM | None = None) -> list[MCPTool]: + goal: str, + tool_list: list[MCPTool], + additional_info: str | None = None, + top_n: int | None = None, + reasoning_llm: ReasoningLLM | None = None, + language: LanguageType = LanguageType.CHINESE, + ) -> list[MCPTool]: """选择最合适的工具""" random.shuffle(tool_list) max_tokens = reasoning_llm._config.max_tokens 
- template = _env.from_string(TOOL_SELECT) + template = _env.from_string(TOOL_SELECT[language]) token_calculator = TokenCalculator() - if token_calculator.calculate_token_length( - messages=[{"role": "user", "content": template.render( - goal=goal, tools=[], additional_info=additional_info - )}], - pure_text=True) > max_tokens: + if ( + token_calculator.calculate_token_length( + messages=[ + { + "role": "user", + "content": template.render(goal=goal, tools=[], additional_info=additional_info), + } + ], + pure_text=True, + ) + > max_tokens + ): logger.warning("[MCPSelector] 工具选择模板长度超过最大令牌数,无法进行选择") return [] current_index = 0 @@ -66,18 +66,31 @@ class MCPSelector(McpBase): while index < len(tool_list): tool = tool_list[index] tokens = token_calculator.calculate_token_length( - messages=[{"role": "user", "content": template.render( - goal=goal, tools=[tool], - additional_info=additional_info - )}], - pure_text=True + messages=[ + { + "role": "user", + "content": template.render( + goal=goal, tools=[tool], additional_info=additional_info + ), + } + ], + pure_text=True, ) if tokens > max_tokens: continue sub_tools.append(tool) - tokens = token_calculator.calculate_token_length(messages=[{"role": "user", "content": template.render( - goal=goal, tools=sub_tools, additional_info=additional_info)}, ], pure_text=True) + tokens = token_calculator.calculate_token_length( + messages=[ + { + "role": "user", + "content": template.render( + goal=goal, tools=sub_tools, additional_info=additional_info + ), + }, + ], + pure_text=True, + ) if tokens > max_tokens: del sub_tools[-1] break @@ -90,7 +103,10 @@ class MCPSelector(McpBase): schema["properties"]["tool_ids"]["items"] = {} # 将enum添加到items中,限制数组元素的可选值 schema["properties"]["tool_ids"]["items"]["enum"] = [tool.id for tool in sub_tools] - result = await MCPSelector.get_resoning_result(template.render(goal=goal, tools=sub_tools, additional_info="请根据目标选择对应的工具"), reasoning_llm) + result = await MCPSelector.get_resoning_result( + template.render(goal=goal, tools=sub_tools, additional_info="请根据目标选择对应的工具"), + reasoning_llm, + ) result = await MCPSelector._parse_result(result, schema) try: result = MCPToolIdsSelectResult.model_validate(result) @@ -102,8 +118,9 @@ class MCPSelector(McpBase): if top_n is not None: mcp_tools = mcp_tools[:top_n] - mcp_tools.append(MCPTool(id=FINAL_TOOL_ID, name="Final", - description="终止", mcp_id=FINAL_TOOL_ID, input_schema={})) + mcp_tools.append( + MCPTool(id=FINAL_TOOL_ID, name="Final", description="终止", mcp_id=FINAL_TOOL_ID, input_schema={}) + ) # mcp_tools.append(MCPTool(id=SUMMARIZE_TOOL_ID, name="Summarize", # description="总结工具", mcp_id=SUMMARIZE_TOOL_ID, input_schema={})) return mcp_tools -- Gitee
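
Note on the pattern above: every prompt constant touched by this patch (EVALUATE_GOAL, CREATE_PLAN, GEN_PARAMS, REPAIR_PARAMS, TOOL_SELECT, MEMORY_TEMPLATE, ...) becomes a mapping keyed by LanguageType, fixed user-facing strings move into dicts such as LLM_QUERY_FIX, and each planner/host/selector method gains a `language: LanguageType = LanguageType.CHINESE` parameter that selects the template before rendering. The sketch below shows that mechanism in isolation; it is not code from the repository — the enum member values, the sample EVALUATE_GOAL text, and the render_prompt helper are assumptions made for the example (the LLM_QUERY_FIX strings are copied from the patch).

from enum import Enum
from typing import Any

from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment


class LanguageType(str, Enum):
    """Stand-in for apps.schemas.enum_var.LanguageType; the real member values may differ."""

    CHINESE = "zh"
    ENGLISH = "en"


_env = SandboxedEnvironment(loader=BaseLoader, autoescape=False, trim_blocks=True, lstrip_blocks=True)

# After the patch, each prompt constant maps language -> template source.
EVALUATE_GOAL: dict[LanguageType, str] = {
    LanguageType.CHINESE: "目标:{{ goal }}\n可用工具:{{ tools | join(', ') }}",
    LanguageType.ENGLISH: "Goal: {{ goal }}\nAvailable tools: {{ tools | join(', ') }}",
}

# Fixed strings are localized the same way (cf. LLM_QUERY_FIX in host.py).
LLM_QUERY_FIX: dict[LanguageType, str] = {
    LanguageType.CHINESE: "请生成修复之后的工具参数",
    LanguageType.ENGLISH: "Please generate the tool parameters after repair",
}


def render_prompt(
    templates: dict[LanguageType, str],
    language: LanguageType = LanguageType.CHINESE,
    **context: Any,
) -> str:
    """Render the template for the requested language, falling back to Chinese."""
    source = templates.get(language, templates[LanguageType.CHINESE])
    return _env.from_string(source).render(**context)


if __name__ == "__main__":
    print(render_prompt(EVALUATE_GOAL, LanguageType.ENGLISH, goal="list pods", tools=["kubectl_get"]))
    print(LLM_QUERY_FIX[LanguageType.ENGLISH])

Falling back to the default language keeps a missing translation from raising a KeyError, which is a sensible guard while new prompts are still being added in one language first; whether the project wants that fallback or a hard failure is a design choice the patch itself does not show.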
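The reflowed select_top_tool also exposes the selector's batching strategy: shuffle the candidate tools, greedily pack them into sub-lists whose rendered TOOL_SELECT prompt stays under the model's max_tokens, then ask the LLM to pick tool ids per batch, with the JSON schema's tool_ids.items constrained by an enum of the offered ids. Below is a simplified, self-contained sketch of the batching step only; count_tokens and batch_tools_by_budget are illustrative stand-ins, not the project's API (the real code uses TokenCalculator.calculate_token_length and a Jinja template).

import random


def count_tokens(text: str) -> int:
    # Rough stand-in for TokenCalculator.calculate_token_length; not the real tokenizer.
    return max(1, len(text) // 4)


def batch_tools_by_budget(goal: str, tools: list[str], max_tokens: int) -> list[list[str]]:
    """Greedily split tools into batches whose rendered prompt stays under the budget."""
    random.shuffle(tools)  # select_top_tool shuffles before batching
    batches: list[list[str]] = []
    current: list[str] = []
    for tool in tools:
        candidate = current + [tool]
        prompt = f"Goal: {goal}\nTools: {', '.join(candidate)}"  # stands in for the TOOL_SELECT template
        if count_tokens(prompt) > max_tokens and current:
            batches.append(current)  # flush the filled batch and start a new one with this tool
            current = [tool]
        else:
            # also covers a tool that alone exceeds the budget: it still forms its own batch
            current = candidate
    if current:
        batches.append(current)
    return batches


if __name__ == "__main__":
    tools = [f"tool_{i}" for i in range(12)]
    for batch in batch_tools_by_budget("scale the nginx deployment", tools, max_tokens=30):
        print(batch)

Restricting the schema's item enum to ids present in the current batch is what keeps MCPToolIdsSelectResult from returning tool ids that were never offered to the model.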