diff --git a/apps/dependency/user.py b/apps/dependency/user.py index 4867df3a9346189c2ce64c1b0463b294643822ca..97692a2eb20652174eff207ae965256c37c77eef 100644 --- a/apps/dependency/user.py +++ b/apps/dependency/user.py @@ -46,11 +46,10 @@ async def verify_session(request: HTTPConnection) -> None: return # 作为Session ID校验 - session_id = token - request.state.session_id = session_id - user_id = await SessionManager.get_user(session_id) + request.state.session_id = token + user_id = await SessionManager.get_user(token) if not user_id: - logger.warning("Session ID鉴权失败:无效的session_id=%s", session_id) + logger.warning("Session ID鉴权失败:无效的session_id=%s", token) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Session ID 鉴权失败", @@ -89,6 +88,7 @@ async def verify_personal_token(request: HTTPConnection) -> None: # 验证是否为合法的Personal Token user_id = await PersonalTokenManager.get_user_by_personal_token(token) if user_id is not None: + request.state.personal_token = token request.state.user_id = user_id else: # Personal Token无效,抛出401 diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 213afc55c82fb761391c90430f852380509f4a80..530c9c137af374c6d42dd9c3f786489b10a21681 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -62,7 +62,7 @@ async def chat_generator(post_body: RequestData, user_id: str, session_id: str) # 获取最终答案 task = scheduler.task - if task.state.executorStatus == ExecutorStatus.ERROR: + if task.state and task.state.executorStatus == ExecutorStatus.ERROR: _logger.error("[Chat] 生成答案失败") yield "data: [ERROR]\n\n" await Activity.remove_active(user_id) diff --git a/apps/routers/conversation.py b/apps/routers/conversation.py index 1076a0569dafd25e0316397e5d1270309b7188be..e1340e01ac8b7715ef20539cc7740871f509addf 100644 --- a/apps/routers/conversation.py +++ b/apps/routers/conversation.py @@ -137,9 +137,10 @@ async def delete_conversation( conversation_id, ) # 删除对话对应的文件 + auth_header = getattr(request.session, "session_id", None) or request.state.personal_token await DocumentManager.delete_document_by_conversation_id( - request.state.user_id, conversation_id, + auth_header, ) deleted_conversation.append(conversation_id) diff --git a/apps/routers/document.py b/apps/routers/document.py index aaae234b38eec168ca1d01de63ebe0fb66582f1c..3ca496294bd82d3a673447475849a709bba7ef5a 100644 --- a/apps/routers/document.py +++ b/apps/routers/document.py @@ -44,7 +44,8 @@ async def document_upload( ) -> JSONResponse: """POST /document/{conversation_id}: 上传文档到指定对话""" result = await DocumentManager.storage_docs(request.state.user_id, conversation_id, documents) - await KnowledgeBaseService.send_file_to_rag(request.state.session_id, result) + auth_header = getattr(request.session, "session_id", None) or request.state.personal_token + await KnowledgeBaseService.send_file_to_rag(auth_header, result) # 返回所有Framework已知的文档 succeed_document: list[BaseDocumentItem] = [ @@ -107,13 +108,13 @@ async def _process_used_documents(conversation_id: uuid.UUID) -> list[Conversati async def _process_unused_documents( conversation_id: uuid.UUID, - session_id: str, + auth_header: str, ) -> list[ConversationDocumentItem]: """处理未使用的文档列表""" result = [] unused_docs = await DocumentManager.get_unused_docs(conversation_id) doc_status = await KnowledgeBaseService.get_doc_status_from_rag( - session_id, [item.id for item in unused_docs], + auth_header, [item.id for item in unused_docs], ) for current_doc in unused_docs: @@ -172,7 +173,8 @@ async def get_document_list( result.extend(await _process_used_documents(conversation_id)) if unused: - result.extend(await _process_unused_documents(conversation_id, request.state.session_id)) + auth_header = getattr(request.session, "session_id", None) or request.state.personal_token + result.extend(await _process_unused_documents(conversation_id, auth_header)) # 对外展示的时候用id,不用alias return JSONResponse( @@ -202,7 +204,8 @@ async def delete_single_document( ).model_dump(exclude_none=True, by_alias=False), ) # 在RAG侧删除 - result = await KnowledgeBaseService.delete_doc_from_rag(request.state.session_id, [document_id]) + auth_header = getattr(request.session, "session_id", None) or request.state.personal_token + result = await KnowledgeBaseService.delete_doc_from_rag(auth_header, [document_id]) if not result: return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, diff --git a/apps/scheduler/call/api/api.py b/apps/scheduler/call/api/api.py index 1eed0288f5a9ab48da402787c8df6b98796492a7..743cb40a6332cdbfa0ab199f4cd3e1d041f1dfc9 100644 --- a/apps/scheduler/call/api/api.py +++ b/apps/scheduler/call/api/api.py @@ -27,7 +27,7 @@ from apps.services.token import TokenManager from .schema import APIInput, APIOutput -logger = logging.getLogger(__name__) +_logger = logging.getLogger(__name__) SUCCESS_HTTP_CODES = [ status.HTTP_200_OK, status.HTTP_201_CREATED, @@ -75,7 +75,7 @@ class API(CoreCall, input_model=APIInput, output_model=APIOutput): async def _init(self, call_vars: CallVars) -> APIInput: """初始化API调用工具""" self._service_id = None - self._session_id = call_vars.ids.session_id + self._user_id = call_vars.ids.user_id self._auth = None if not self.node: @@ -172,11 +172,6 @@ class API(CoreCall, input_model=APIInput, output_model=APIOutput): async def _apply_auth(self) -> tuple[dict[str, str], dict[str, str], dict[str, str]]: """应用认证信息到请求参数中""" - if not self._session_id: - err = "[API] 未设置Session ID" - logger.error(err) - raise CallError(message=err, data={}) - # self._auth可能是None或ServiceApiAuth类型 # ServiceApiAuth类型包含header、cookie、query和oidc属性 req_header = {} @@ -198,9 +193,14 @@ class API(CoreCall, input_model=APIInput, output_model=APIOutput): req_params[item.name] = item.value # 如果oidc配置存在 if self._service_id and self._auth.oidc: + if not self._user_id: + err = "[API] 未设置User ID" + _logger.error(err) + raise CallError(message=err, data={}) + token = await TokenManager.get_plugin_token( self._service_id, - self._session_id, + self._user_id, await oidc_provider.get_access_token_url(), 30, ) @@ -211,14 +211,14 @@ class API(CoreCall, input_model=APIInput, output_model=APIOutput): async def _call_api(self, final_data: APIInput) -> APIOutput: """实际调用API,并处理返回值""" # 获取必要参数 - logger.info("[API] 调用接口 %s,请求数据为 %s", self.url, final_data) + _logger.info("[API] 调用接口 %s,请求数据为 %s", self.url, final_data) files = {} # httpx需要使用字典格式的files参数 response = await self._make_api_call(final_data, files) if response.status_code not in SUCCESS_HTTP_CODES: text = f"API发生错误:API返回状态码{response.status_code}, 原因为{response.reason_phrase}。" - logger.error(text) + _logger.error(text) raise CallError( message=text, data={"api_response_data": response.text}, @@ -227,7 +227,7 @@ class API(CoreCall, input_model=APIInput, output_model=APIOutput): response_status = response.status_code response_data = response.text - logger.info("[API] 调用接口 %s,结果为 %s", self.url, response_data) + _logger.info("[API] 调用接口 %s,结果为 %s", self.url, response_data) # 如果没有返回结果 if not response_data: diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index 6d1dc396395b01b1c85db24ab59585d2776df8d1..fca0d4ab344b89dced98c7aab740e49279705590 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -100,7 +100,7 @@ class CoreCall(BaseModel): ids=CallIds( task_id=executor.task.metadata.id, executor_id=executor.task.state.executorId, - session_id=executor.task.runtime.sessionId, + auth_header=executor.task.runtime.sessionId, user_id=executor.task.metadata.userId, app_id=executor.task.state.appId, conversation_id=executor.task.metadata.conversationId, diff --git a/apps/scheduler/call/rag/rag.py b/apps/scheduler/call/rag/rag.py index 4058a3e2b26b54676ef52ebd3016b07641cdbff4..73e88cb936fbddf54d4f65cc98e2136458af05d6 100644 --- a/apps/scheduler/call/rag/rag.py +++ b/apps/scheduler/call/rag/rag.py @@ -70,8 +70,8 @@ class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput): async def _init(self, call_vars: CallVars) -> RAGInput: """初始化RAG工具""" - if not call_vars.ids.session_id: - err = "[RAG] 未设置Session ID" + if not call_vars.ids.auth_header: + err = "[RAG] 未设置Auth Header" _logger.error(err) raise CallError(message=err, data={}) @@ -94,7 +94,7 @@ class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput): url = config.rag.rag_service.rstrip("/") + "/chunk/search" headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {self._sys_vars.ids.session_id}", + "Authorization": f"Bearer {self._sys_vars.ids.auth_header}", } doc_chunk_list = [] diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index 8e1d6012a77008ea9cf9bf8dd8a8952014540659..2f782f2b53ea6c8f55e1d9e826056c383232abb1 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -345,12 +345,8 @@ class MCPAgentExecutor(BaseExecutor): self._current_input, state.errorMessage, ) - error_msg_str = self._get_error_message_str(state.errorMessage) - error_message = await self._planner.change_err_message_to_description( - error_message=error_msg_str, - tool=current_tool, - input_params=self._current_input, - ) + # TODO + error_msg = self._get_error_message_str(state.errorMessage) # 先更新状态 state.executorStatus = ExecutorStatus.WAITING @@ -358,13 +354,13 @@ class MCPAgentExecutor(BaseExecutor): self.task.state = state await self._push_message( - EventType.STEP_WAITING_FOR_PARAM, data={"message": error_message, "params": params_with_null}, + EventType.STEP_WAITING_FOR_PARAM, data={"message": error_msg, "params": params_with_null}, ) self.task.context.append( self._create_executor_history( step_status=state.stepStatus, extra_data={ - "message": error_message, + "message": error_msg, "params": params_with_null, }, ), diff --git a/apps/scheduler/mcp/plan.py b/apps/scheduler/mcp/plan.py index f0996726688378700d14b8aea628262484c0ae66..3c16e0c41427ba1a51a0e6ab97475c46f0a3458d 100644 --- a/apps/scheduler/mcp/plan.py +++ b/apps/scheduler/mcp/plan.py @@ -52,6 +52,7 @@ class MCPPlanner: plan = await json_generator.generate( function=function_def, conversation=[ + {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}, ], language=self._language, diff --git a/apps/scheduler/mcp/prompt.py b/apps/scheduler/mcp/prompt.py index 85d0fb7e033d2d1ca16786ca0bad9ea01619e06e..5bf1ed24bf8bacd511b6272fcac94156c7716e4b 100644 --- a/apps/scheduler/mcp/prompt.py +++ b/apps/scheduler/mcp/prompt.py @@ -420,36 +420,111 @@ MEMORY_TEMPLATE: dict[str, str] = { MCP_FUNCTION_SELECT: dict[LanguageType, str] = { LanguageType.CHINESE: dedent( r""" - 你是一个专业的MCP (Model Context Protocol) 选择器。 - 你的任务是分析推理结果并选择最合适的MCP服务器ID。 + ## 任务描述 - --- 推理结果 --- - {reasoning_result} - --- 推理结果结束 --- + 你是一个专业的 MCP (Model Context Protocol) 选择器。 + 你的任务是分析推理结果并选择最合适的 MCP 服务器 ID。 - 可用的MCP服务器ID: {mcp_ids} + ## 推理结果 - 请仔细分析推理结果,选择最符合需求的MCP服务器。 - 重点关注推理中描述的能力和需求,将其与正确的MCP服务器进行匹配。 + 以下是需要分析的推理结果: + + ``` + {{ reasoning_result }} + ``` + + ## 可用的 MCP 服务器 + + 可选的 MCP 服务器 ID 列表: + + {% for mcp_id in mcp_ids %} + - `{{ mcp_id }}` + {% endfor %} + + ## 选择要求 + + 请仔细分析推理结果,选择最符合需求的 MCP 服务器。重点关注: + + 1. **能力匹配**:推理中描述的能力需求与 MCP 服务器功能的对应关系 + 2. **场景适配**:MCP 服务器是否适合当前应用场景 + 3. **最优选择**:从可用列表中选择最合适的服务器 ID + + ## 输出格式 + + 请使用 `select_mcp` 工具进行选择,提供选中的 MCP 服务器 ID。 """, ), LanguageType.ENGLISH: dedent( r""" + ## Task Description + You are an expert MCP (Model Context Protocol) selector. Your task is to analyze the reasoning result and select the most appropriate MCP server ID. - --- Reasoning Result --- - {reasoning_result} - --- End of Reasoning --- + ## Reasoning Result + + Below is the reasoning result that needs to be analyzed: + + ``` + {{ reasoning_result }} + ``` + + ## Available MCP Servers + + List of available MCP server IDs: + + {% for mcp_id in mcp_ids %} + - `{{ mcp_id }}` + {% endfor %} + + ## Selection Requirements + + Please analyze the reasoning result carefully and select the MCP server that best matches the \ +requirements. Focus on: + + 1. **Capability Matching**: Align the capability requirements described in the reasoning with the MCP \ +server functions + 2. **Scenario Fit**: Ensure the MCP server is suitable for the current application scenario + 3. **Optimal Choice**: Select the most appropriate server ID from the available list - Available MCP Server IDs: {mcp_ids} + ## Output Format - Please analyze the reasoning carefully and select the MCP server that best matches the requirements. - Focus on matching the capabilities and requirements described in the reasoning with the correct MCP server. + Please use the `select_mcp` tool to make your selection by providing the chosen MCP server ID. """, ), } +SELECT_MCP_FUNCTION: dict[LanguageType, dict] = { + LanguageType.CHINESE: { + "name": "select_mcp", + "description": "根据推理结果选择最合适的MCP服务器", + "parameters": { + "type": "object", + "properties": { + "mcp_id": { + "type": "string", + "description": "MCP Server的ID", + }, + }, + "required": ["mcp_id"], + }, + }, + LanguageType.ENGLISH: { + "name": "select_mcp", + "description": "Select the most appropriate MCP server based on the reasoning result", + "parameters": { + "type": "object", + "properties": { + "mcp_id": { + "type": "string", + "description": "The ID of the MCP Server", + }, + }, + "required": ["mcp_id"], + }, + }, +} + FILL_PARAMS_QUERY: dict[LanguageType, str] = { LanguageType.CHINESE: dedent( r""" diff --git a/apps/scheduler/mcp/select.py b/apps/scheduler/mcp/select.py index 65c5f0d12840f329fc724c51cf5d08b4711db62c..372ac399acb20f66fe116f60c4064d57db813b19 100644 --- a/apps/scheduler/mcp/select.py +++ b/apps/scheduler/mcp/select.py @@ -1,8 +1,11 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """选择MCP Server及其工具""" +import copy import logging +from jinja2 import BaseLoader +from jinja2.sandbox import SandboxedEnvironment from sqlalchemy import select from apps.common.postgres import postgres @@ -11,7 +14,7 @@ from apps.models import LanguageType, MCPTools from apps.schemas.mcp import MCPSelectResult from apps.services.mcp_service import MCPServiceManager -from .prompt import MCP_FUNCTION_SELECT +from .prompt import MCP_FUNCTION_SELECT, SELECT_MCP_FUNCTION logger = logging.getLogger(__name__) @@ -23,6 +26,12 @@ class MCPSelector: """初始化MCP选择器""" self._llm = llm self._language = language + self._env = SandboxedEnvironment( + loader=BaseLoader, + autoescape=True, + trim_blocks=True, + lstrip_blocks=True, + ) async def _call_reasoning(self, prompt: str) -> str: """调用大模型进行推理""" @@ -39,27 +48,21 @@ class MCPSelector: async def _call_function_mcp(self, reasoning_result: str, mcp_ids: list[str]) -> MCPSelectResult: """调用结构化输出小模型提取JSON""" - schema = MCPSelectResult.model_json_schema() - # schema中加入选项 - schema["properties"]["mcp_id"]["enum"] = mcp_ids - - # 组装OpenAI FunctionCall格式的dict - function = { - "name": "select_mcp", - "description": "Select the most appropriate MCP server based on the reasoning result", - "parameters": schema, - } - - # 构建优化的提示词 - user_prompt = MCP_FUNCTION_SELECT[self._language].format( + function = copy.deepcopy(SELECT_MCP_FUNCTION[self._language]) + function["parameters"]["properties"]["mcp_id"]["enum"] = mcp_ids + + # 使用jinja2格式化模板 + template = self._env.from_string(MCP_FUNCTION_SELECT[self._language]) + user_prompt = template.render( reasoning_result=reasoning_result, - mcp_ids=", ".join(mcp_ids), + mcp_ids=mcp_ids, ) # 使用json_generator生成JSON result = await json_generator.generate( function=function, conversation=[ + {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": user_prompt}, ], language=self._language, diff --git a/apps/scheduler/mcp_agent/base.py b/apps/scheduler/mcp_agent/base.py index 12d58f09c58fd164b450f7198e6ddbf26157b418..432c6a63634d18f5a87445ae6de4dab761f199be 100644 --- a/apps/scheduler/mcp_agent/base.py +++ b/apps/scheduler/mcp_agent/base.py @@ -2,9 +2,8 @@ """MCP基类""" import logging -from typing import Any -from apps.llm import LLM, json_generator +from apps.llm import LLM from apps.models import LanguageType from apps.schemas.task import TaskData @@ -41,17 +40,3 @@ class MCPBase: result += chunk.content or "" return result - - async def get_json_result(self, result: str, function: dict[str, Any]) -> dict[str, Any]: - """解析推理结果;function使用OpenAI标准Function格式""" - return await json_generator.generate( - function=function, - conversation=[ - {"role": "user", "content": result}, - { - "role": "user", - "content": "Please provide a JSON response based on the above information and schema.", - }, - ], - language=self._language, - ) diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py index ef2f446ccbb2ed2bda1d528cd468910af343c002..3fbe09ac1b70c8e1867a4790c7ed897652481c19 100644 --- a/apps/scheduler/mcp_agent/host.py +++ b/apps/scheduler/mcp_agent/host.py @@ -65,9 +65,13 @@ class MCPHost(MCPBase): "description": mcp_tool.description, "parameters": mcp_tool.inputSchema, } - return await self.get_json_result( - prompt, - function, + return await json_generator.generate( + function=function, + conversation=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + language=self._language, ) async def fill_params( # noqa: D102, PLR0913 diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py index 572e7cf2701bb758fae0bee997cb8feb5595fd49..bb6e798202b952212a882c685f90371bd8fab4fc 100644 --- a/apps/scheduler/mcp_agent/plan.py +++ b/apps/scheduler/mcp_agent/plan.py @@ -3,20 +3,26 @@ import logging from collections.abc import AsyncGenerator +from copy import deepcopy from typing import Any from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment +from apps.llm import json_generator from apps.models import MCPTools from apps.scheduler.mcp_agent.base import MCPBase from apps.scheduler.mcp_agent.prompt import ( - CHANGE_ERROR_MESSAGE_TO_DESCRIPTION, + CREATE_NEXT_STEP_FUNCTION, + EVALUATE_TOOL_RISK_FUNCTION, FINAL_ANSWER, GEN_STEP, GENERATE_FLOW_NAME, + GET_FLOW_NAME_FUNCTION, GET_MISSING_PARAMS, + GET_MISSING_PARAMS_FUNCTION, IS_PARAM_ERROR, + IS_PARAM_ERROR_FUNCTION, RISK_EVALUATE, ) from apps.scheduler.slot.slot import Slot @@ -45,14 +51,14 @@ class MCPPlanner(MCPBase): template = _env.from_string(GENERATE_FLOW_NAME[self._language]) prompt = template.render(goal=self._goal) - # 组装OpenAI标准Function格式 - function = { - "name": "get_flow_name", - "description": "Generate a descriptive name for the current workflow based on the user's goal", - "parameters": FlowName.model_json_schema(), - } - - result = await self.get_json_result(prompt, function) + result = await json_generator.generate( + function=GET_FLOW_NAME_FUNCTION[self._language], + conversation=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + language=self._language, + ) return FlowName.model_validate(result) async def create_next_step(self, history: str, tools: list[MCPTools]) -> Step: @@ -61,24 +67,18 @@ class MCPPlanner(MCPBase): template = _env.from_string(GEN_STEP[self._language]) prompt = template.render(goal=self._goal, history=history, tools=tools) - # 解析为结构化数据 - schema = Step.model_json_schema() - if "enum" not in schema["properties"]["tool_id"]: - schema["properties"]["tool_id"]["enum"] = [] - for tool in tools: - schema["properties"]["tool_id"]["enum"].append(tool.id) - - # 组装OpenAI标准Function格式 - function = { - "name": "create_next_step", - "description": ( - "Create the next execution step in the workflow by selecting " - "an appropriate tool and defining its parameters" - ), - "parameters": schema, - } - - step = await self.get_json_result(prompt, function) + # 获取函数定义并动态设置tool_id的enum + function = deepcopy(CREATE_NEXT_STEP_FUNCTION[self._language]) + function["parameters"]["properties"]["tool_id"]["enum"] = [tool.id for tool in tools] + + step = await json_generator.generate( + function=function, + conversation=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + language=self._language, + ) logger.info("[MCPPlanner] 创建下一步的执行步骤: %s", step) # 使用Step模型解析结果 return Step.model_validate(step) @@ -99,17 +99,14 @@ class MCPPlanner(MCPBase): additional_info=additional_info, ) - # 组装OpenAI标准Function格式 - function = { - "name": "evaluate_tool_risk", - "description": ( - "Evaluate the risk level and safety concerns of executing " - "a specific tool with given parameters" - ), - "parameters": ToolRisk.model_json_schema(), - } - - risk = await self.get_json_result(prompt, function) + risk = await json_generator.generate( + function=EVALUATE_TOOL_RISK_FUNCTION[self._language], + conversation=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + language=self._language, + ) # 返回风险评估结果 return ToolRisk.model_validate(risk) @@ -134,31 +131,17 @@ class MCPPlanner(MCPBase): error_message=error_message, ) - # 组装OpenAI标准Function格式 - function = { - "name": "check_parameter_error", - "description": "Determine whether an error message indicates a parameter-related error", - "parameters": IsParamError.model_json_schema(), - } - - is_param_error = await self.get_json_result(prompt, function) + is_param_error = await json_generator.generate( + function=IS_PARAM_ERROR_FUNCTION[self._language], + conversation=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + language=self._language, + ) # 使用IsParamError模型解析结果 return IsParamError.model_validate(is_param_error) - async def change_err_message_to_description( - self, error_message: str, tool: MCPTools, input_params: dict[str, Any], - ) -> str: - """将错误信息转换为工具描述""" - template = _env.from_string(CHANGE_ERROR_MESSAGE_TO_DESCRIPTION[self._language]) - prompt = template.render( - error_message=error_message, - tool_name=tool.toolName, - tool_description=tool.description, - input_schema=tool.inputSchema, - input_params=input_params, - ) - return await self.get_reasoning_result(prompt) - async def get_missing_param( self, tool: MCPTools, input_param: dict[str, Any], error_message: dict[str, Any], ) -> dict[str, Any]: @@ -174,15 +157,19 @@ class MCPPlanner(MCPBase): error_message=error_message, ) - # 组装OpenAI标准Function格式 - function = { - "name": "get_missing_parameters", - "description": "Extract and provide the missing or incorrect parameters based on error feedback", - "parameters": schema_with_null, - } + # 获取函数定义并设置parameters为schema_with_null + function = deepcopy(GET_MISSING_PARAMS_FUNCTION[self._language]) + function["parameters"] = schema_with_null # 解析为结构化数据 - return await self.get_json_result(prompt, function) + return await json_generator.generate( + function=function, + conversation=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + language=self._language, + ) async def generate_answer(self, memory: str) -> AsyncGenerator[LLMChunk, None]: """生成最终回答,返回LLMChunk""" diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py index 07d924ecf7cbbded08a6dd0905ffa7bcf8681d4a..c4fb46b71f313b71236ee532927a262dba188cd6 100644 --- a/apps/scheduler/mcp_agent/prompt.py +++ b/apps/scheduler/mcp_agent/prompt.py @@ -8,656 +8,449 @@ from apps.models import LanguageType GENERATE_FLOW_NAME: dict[LanguageType, str] = { LanguageType.CHINESE: dedent( r""" - 你是一个智能助手,你的任务是根据用户的目标,生成一个合适的流程名称。 - - # 生成流程名称时的注意事项: - 1. 流程名称应该简洁明了,能够准确表达达成用户目标的过程。 - 2. 流程名称应该包含关键的操作或步骤,例如“扫描”、“分析”、“调优”等。 - 3. 流程名称应该避免使用过于复杂或专业的术语,以便用户能够理解。 - 4. 流程名称应该尽量简短,小于20个字或者单词。 - 5. 只输出流程名称,不要输出其他内容。 - # 样例 - # 目标 - 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 - # 输出 - { - "flow_name": "扫描MySQL数据库并分析性能瓶颈,进行调优" - } - # 现在开始生成流程名称: - # 目标 - {{goal}} - # 输出 + ## 任务 + 根据用户目标生成描述性的工作流程名称。 + + ## 要求 + - **简洁明了**:准确表达达成目标的过程,长度控制在20字以内 + - **包含关键操作**:体现核心步骤(如"扫描"、"分析"、"调优"等) + - **通俗易懂**:避免过于专业的术语 + + ## 示例 + **用户目标**:我需要扫描当前mysql数据库,分析性能瓶颈,并调优 + + **输出**:扫描MySQL数据库并分析性能瓶颈,进行调优 + + --- + **用户目标**:{{goal}} """, ).strip("\n"), LanguageType.ENGLISH: dedent( r""" - You are an intelligent assistant, your task is to generate a suitable flow name based on the user's goal. - - # Notes when generating flow names: - 1. The flow name should be concise and clear, accurately expressing the process of achieving the \ -user's goal. - 2. The flow name should include key operations or steps, such as "scan", "analyze", "tune", etc. - 3. The flow name should avoid using overly complex or professional terms, so that users can understand. - 4. The flow name should be as short as possible, less than 20 characters or words. - 5. Only output the flow name, do not output other content. - # Example - # Goal - I need to scan the current MySQL database, analyze performance bottlenecks, and optimize it. - # Output - { - "flow_name": "Scan MySQL database and analyze performance bottlenecks, and optimize it." - } - # Now start generating the flow name: - # Goal - {{goal}} - # Output + ## Task + Generate a descriptive workflow name based on the user's goal. + + ## Requirements + - **Concise and clear**: Accurately express the process, keep under 20 words + - **Include key operations**: Reflect core steps (e.g., "scan", "analyze", "optimize") + - **Easy to understand**: Avoid overly technical terminology + + ## Example + **User Goal**: I need to scan the current MySQL database, analyze performance bottlenecks, and optimize it. + + **Output**: Scan MySQL database, analyze performance bottlenecks, and optimize + + --- + **User Goal**: {{goal}} """, - ), + ).strip("\n"), } -GEN_STEP: dict[LanguageType, str] = { - LanguageType.CHINESE: dedent( - r""" - 你是一个计划生成器。 - 请根据用户的目标、当前计划和历史,生成一个新的步骤。 - - # 一个好的计划步骤应该: - 1.使用最适合的工具来完成当前步骤。 - 2.能够基于当前的计划和历史,完成阶段性的任务。 - 3.不要选择不存在的工具。 - 4.如果你认为当前已经达成了用户的目标,可以直接返回Final工具,表示计划执行结束。 - 5.tool_id中的工具ID必须是当前工具集合中存在的工具ID,而不是工具的名称。 - 6.工具在 XML标签中给出,工具的id在 下的 XML标签中给出。 - - # 样例 1 - # 目标 - 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优,我的ip是192.168.1.1,数据库端口是3306,\ -用户名是root,密码是password - # 历史记录 - 第1步:生成端口扫描命令 - - 调用工具 `command_generator`,并提供参数 `帮我生成一个mysql端口扫描命令` - - 执行状态:成功 - - 得到数据:`{"command": "nmap -sS -p--open 192.168.1.1"}` - 第2步:执行端口扫描命令 - - 调用工具 `command_executor`,并提供参数 `{"command": "nmap -sS -p--open 192.168.1.1"}` - - 执行状态:成功 - - 得到数据:`{"result": "success"}` - # 工具 - - - mcp_tool_1 mysql分析工具,用于分析数据库性能/description> - - mcp_tool_2 文件存储工具,用于存储文件 - - mcp_tool_3 mongoDB工具,用于操作MongoDB数据库 - - Final 结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终\ -结果。 - - # 输出 - ```json +GET_FLOW_NAME_FUNCTION: dict[LanguageType, dict] = { + LanguageType.CHINESE: { + "name": "get_flow_name", + "description": "根据用户目标生成当前工作流程的描述性名称", + "parameters": { + "type": "object", + "properties": { + "flow_name": { + "type": "string", + "description": "MCP 流程名称", + "default": "", + }, + }, + "required": ["flow_name"], + }, + "examples": [ + { + "flow_name": "扫描MySQL数据库并分析性能瓶颈,进行调优", + }, + ], + }, + LanguageType.ENGLISH: { + "name": "get_flow_name", + "description": "Generate a descriptive name for the current workflow based on the user's goal", + "parameters": { + "type": "object", + "properties": { + "flow_name": { + "type": "string", + "description": "MCP workflow name", + "default": "", + }, + }, + "required": ["flow_name"], + }, + "examples": [ + { + "flow_name": "Scan MySQL database, analyze performance bottlenecks, and optimize", + }, + ], + }, +} + +CREATE_NEXT_STEP_FUNCTION: dict[LanguageType, dict] = { + LanguageType.CHINESE: { + "name": "create_next_step", + "description": "根据当前的目标、计划和历史,选择合适的工具并定义其参数来创建下一个执行步骤", + "parameters": { + "type": "object", + "properties": { + "tool_id": { + "type": "string", + "description": "工具ID", + "enum": [], + }, + "description": { + "type": "string", + "description": "步骤描述", + }, + }, + "required": ["tool_id", "description"], + }, + "examples": [ { "tool_id": "mcp_tool_1", - "description": "扫描ip为192.168.1.1的MySQL数据库,端口为3306,用户名为root,密码为password的数据库性能", - } - ``` - # 样例二 - # 目标 - 计划从杭州到北京的旅游计划 - # 历史记录 - 第1步:将杭州转换为经纬度坐标 - - 调用工具 `经纬度工具`,并提供参数 `{"city_from": "杭州", "address": "西湖"}` - - 执行状态:成功 - - 得到数据:`{"location": "123.456, 78.901"}` - 第2步:查询杭州的天气 - - 调用工具 `天气查询工具`,并提供参数 `{"location": "123.456, 78.901"}` - - 执行状态:成功 - - 得到数据:`{"weather": "晴", "temperature": "25°C"}` - 第3步:将北京转换为经纬度坐标 - - 调用工具 `经纬度工具`,并提供参数 `{"city_from": "北京", "address": "天安门"}` - - 执行状态:成功 - - 得到数据:`{"location": "123.456, 78.901"}` - 第4步:查询北京的天气 - - 调用工具 `天气查询工具`,并提供参数 `{"location": "123.456, 78.901"}` - - 执行状态:成功 - - 得到数据:`{"weather": "晴", "temperature": "25°C"}` - # 工具 - - - mcp_tool_4 maps_geo_planner;将详细的结构化地址转换为经纬度坐标。\ -支持对地标性名胜景区、建筑物名称解析为经纬度坐标 - - mcp_tool_5 weather_query;天气查询,用于查询天气信息 - - mcp_tool_6 maps_direction_transit_integrated;根据用户起终点经纬度坐标规划\ -综合各类公共(火车、公交、地铁)交通方式的通勤方案,并且返回通勤方案的数据,跨城场景下必须传起点城市与终点城市 - - Final Final;结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将\ -作为最终结果。 - - # 输出 - ```json + "description": "扫描ip为192.168.1.1的MySQL数据库,端口为3306,用户名为root,密码为password的数据库性能", + }, + ], + }, + LanguageType.ENGLISH: { + "name": "create_next_step", + "description": ( + "Create the next execution step in the workflow by selecting " + "an appropriate tool and defining its parameters" + ), + "parameters": { + "type": "object", + "properties": { + "tool_id": { + "type": "string", + "description": "Tool ID", + "enum": [], + }, + "description": { + "type": "string", + "description": "Step description", + }, + }, + "required": ["tool_id", "description"], + }, + "examples": [ { - "tool_id": "mcp_tool_6", - "description": "规划从杭州到北京的综合公共交通方式的通勤方案" - } - ``` - # 现在开始生成步骤: - # 目标 - {{goal}} - # 历史记录 + "tool_id": "mcp_tool_1", + "description": "Scan MySQL database performance at 192.168.1.1:3306 with user root", + }, + ], + }, +} + +GEN_STEP: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent( + r""" + 根据用户目标、执行历史和可用工具,生成下一个执行步骤。 + + ## 任务要求 + + 作为计划生成器,你需要: + - **选择最合适的工具**:从可用工具集中选择当前阶段最适合的工具 + - **推进目标完成**:基于历史记录,制定能完成阶段性任务的步骤 + - **严格使用工具ID**:工具ID必须精确匹配可用工具列表中的ID + - **判断完成状态**:若目标已达成,选择`Final`工具结束流程 + + ## 示例 + + 假设用户需要扫描一个MySQL数据库(地址192.168.1.1,端口3306,使用root账户和password密码),分析其性能瓶颈并进行调优。 + 之前已经完成了端口扫描,确认了3306端口开放。现在需要选择下一步操作。 + 查看可用工具列表后,发现有MySQL性能分析工具(mcp_tool_1)、文件存储工具(mcp_tool_2)和结束工具(Final)。 + 此时应选择MySQL性能分析工具,并描述这一步要做什么:使用提供的数据库连接信息(192.168.1.1:3306,root/password)来扫描和分析数据库性能。 + + --- + + ## 当前任务 + + **目标**:{{goal}} + + **历史记录**: {{history}} - # 工具 - + + **可用工具**: {% for tool in tools %} - - {{tool.id}} {{tool.description}} + - **{{tool.id}}**:{{tool.description}} {% endfor %} - """, ), LanguageType.ENGLISH: dedent( r""" - You are a plan generator. - Please generate a new step based on the user's goal, current plan, and history. - - # A good plan step should: - 1. Use the most appropriate tool for the current step. - 2. Complete the tasks at each stage based on the current plan and history. - 3. Do not select a tool that does not exist. - 4. If you believe the user's goal has been achieved, return to the Final tool to complete the \ -plan execution. - - # Example 1 - # Objective - I need to scan the current MySQL database, analyze performance bottlenecks, and optimize it. My IP \ -address is 192.168.1.1, the database port is 3306, my username is root, and my password is password. - # History - Step 1: Generate a port scan command - - Call the `command_generator` tool and provide the `help me generate a MySQL port scan command` \ -parameter. - - Execution status: Success. - - Result: `{"command": "nmap -sS -p --open 192.168.1.1"}` - Step 2: Execute the port scan command - - Call the `command_executor` tool and provide the `{"command": "nmap -sS -p --open 192.168.1.1"}` \ -parameter. - - Execution status: Success. - - Result: `{"result": "success"}` - # Tools - - - mcp_tool_1 mysql_analyzer; used for analyzing database performance. - - mcp_tool_2 File storage tool; used for storing files. - - mcp_tool_3 MongoDB tool; used for operating MongoDB databases. - - Final This step completes the plan execution and the result is used as the \ -final result. - - # Output - ```json - { - "tool_id": "mcp_tool_1", - "description": "Scan the database performance of the MySQL database with IP address 192.168.1.1, \ -port 3306, username root, and password password", - } - ``` - # Example 2 - # Objective - Plan a trip from Hangzhou to Beijing - # History - Step 1: Convert Hangzhou to latitude and longitude coordinates - - Call the `maps_geo_planner` tool and provide `{"city_from": "Hangzhou", "address": "West Lake"}` - - Execution status: Success - - Result: `{"location": "123.456, 78.901"}` - Step 2: Query the weather in Hangzhou - - Call the `weather_query` tool and provide `{"location": "123.456, 78.901"}` - - Execution Status: Success - - Result: `{"weather": "Sunny", "temperature": "25°C"}` - Step 3: Convert Beijing to latitude and longitude coordinates - - Call the `maps_geo_planner` tool and provide `{"city_from": "Beijing", "address": "Tiananmen"}` - - Execution Status: Success - - Result: `{"location": "123.456, 78.901"}` - Step 4: Query the weather in Beijing - - Call the `weather_query` tool and provide `{"location": "123.456, 78.901"}` - - Execution Status: Success - - Result: `{"weather": "Sunny", "temperature": "25°C"}` - # Tools - - - mcp_tool_4 maps_geo_planner; Converts a detailed structured address \ -into longitude and latitude coordinates. Supports parsing landmarks, scenic spots, and building names into \ -longitude and latitude coordinates. - - mcp_tool_5 weather_query; Weather query, used to query weather \ -information. - - mcp_tool_6 maps_direction_transit_integrated; Plans a commuting plan \ -based on the user's starting and ending longitude and latitude coordinates, integrating various public \ -transportation modes (train, bus, subway), and returns the commuting plan data. For cross-city scenarios, \ -both the starting and ending cities must be provided. - - Final Final; Final step. When this step is reached, plan execution \ -is complete, and the resulting result is used as the final result. - - # Output - ```json - { - "tool_id": "mcp_tool_6", - "description": "Plan a comprehensive public transportation commute from Hangzhou to Beijing" - } - ``` - # Now start generating steps: - # Goal - {{goal}} - # History + Generate the next execution step based on user goals, execution history, and available tools. + + ## Task Requirements + + As a plan generator, you need to: + - **Select the most appropriate tool**: Choose the best tool from available tools for the current stage + - **Advance goal completion**: Based on execution history, formulate steps to complete phased tasks + - **Strictly use tool IDs**: Tool ID must exactly match the ID in the available tools list + - **Determine completion status**: If goal is achieved, select `Final` tool to end the workflow + + ## Example + + Suppose the user needs to scan a MySQL database (at 192.168.1.1:3306, using root account with password), + analyze its performance bottlenecks, and optimize it. + Previously, a port scan was completed and confirmed that port 3306 is open. Now we need to select the + next action. + Looking at the available tools list, there is a MySQL performance analysis tool (mcp_tool_1), a file + storage tool (mcp_tool_2), and a final tool (Final). + At this point, we should select the MySQL performance analysis tool and describe what this step will do: + use the provided database connection information (192.168.1.1:3306, root/password) to scan and analyze + database performance. + + --- + + ## Current Task + + **Goal**: {{goal}} + + **Execution History**: {{history}} - # Tools - + + **Available Tools**: {% for tool in tools %} - - {{tool.id}} {{tool.description}} + - **{{tool.id}}**: {{tool.description}} {% endfor %} - """, ), } -RISK_EVALUATE: dict[LanguageType, str] = { - LanguageType.CHINESE: dedent( - r""" - 你是一个工具执行计划评估器。 - 你的任务是根据当前工具的名称、描述和入参以及附加信息,判断当前工具执行的风险并输出提示。 - ```json - { - "risk": "low/medium/high", - "reason": "提示信息" - } - ``` - # 样例 - ## 工具 - - mysql_analyzer - 分析MySQL数据库性能 - - ## 工具入参 - { - "host": "192.0.0.1", - "port": 3306, - "username": "root", - "password": "password" - } - ## 附加信息 - 1. 当前MySQL数据库的版本是8.0.26 - 2. 当前MySQL数据库的配置文件路径是/etc/my.cnf,并含有以下配置项 - ```ini - [mysqld] - innodb_buffer_pool_size=1G - innodb_log_file_size=256M - ``` - ## 输出 - ```json - { - "risk": "中", - "reason": "当前工具将连接到MySQL数据库并分析性能,可能会对数据库性能产生一定影响。\ -请确保在非生产环境中执行此操作。" - } - ``` - - # 现在开始评估工具执行风险: - ## 工具 - - {{tool_name}} - {{tool_description}} - - ## 工具入参 - {{input_param}} - ## 附加信息 - {{additional_info}} - ## 输出 - """, - ), - LanguageType.ENGLISH: dedent( - r""" - You are a tool execution plan evaluator. - Your task is to determine the risk of executing the current tool based on its name, description, \ -input parameters, and additional information, and output a warning. - ```json - { - "risk": "low/medium/high", - "reason": "prompt message" - } - ``` - # Example - ## Tool - - mysql_analyzer - Analyzes MySQL database performance - - ## Tool input +EVALUATE_TOOL_RISK_FUNCTION: dict[LanguageType, dict] = { + LanguageType.CHINESE: { + "name": "evaluate_tool_risk", + "description": "评估工具执行的风险等级和安全问题", + "parameters": { + "type": "object", + "properties": { + "risk": { + "type": "string", + "enum": ["low", "medium", "high"], + "description": "风险等级:低(low)、中(medium)、高(high)", + }, + "reason": { + "type": "string", + "description": "风险评估的原因说明", + "default": "", + }, + }, + "required": ["risk", "reason"], + }, + "examples": [ { - "host": "192.0.0.1", - "port": 3306, - "username": "root", - "password": "password" - } - ## Additional information - 1. The current MySQL database version is 8.0.26 - 2. The current MySQL database configuration file path is /etc/my.cnf and contains the following \ -configuration items - ```ini - [mysqld] - innodb_buffer_pool_size=1G - innodb_log_file_size=256M - ``` - ## Output - ```json + "risk": "medium", + "reason": ( + "当前工具将连接到MySQL数据库并分析性能,可能会对数据库性能产生一定影响。" + "请确保在非生产环境中执行此操作。" + ), + }, + ], + }, + LanguageType.ENGLISH: { + "name": "evaluate_tool_risk", + "description": ( + "Evaluate the risk level and safety concerns of executing " + "a specific tool with given parameters" + ), + "parameters": { + "type": "object", + "properties": { + "risk": { + "type": "string", + "enum": ["low", "medium", "high"], + "description": "Risk level: low, medium, or high", + }, + "reason": { + "type": "string", + "description": "Explanation of the risk assessment", + "default": "", + }, + }, + "required": ["risk", "reason"], + }, + "examples": [ { "risk": "medium", - "reason": "This tool will connect to a MySQL database and analyze performance, which may impact \ -database performance. This operation should only be performed in a non-production environment." - } - ``` - - # Now start evaluating the tool execution risk: - ## Tool - - {{tool_name}} - {{tool_description}} - - ## Tool Input Parameters - {{input_param}} - ## Additional Information - {{additional_info}} - ## Output - """, - ), + "reason": ( + "This tool will connect to a MySQL database and analyze performance, " + "which may impact database performance. This operation should only be " + "performed in a non-production environment." + ), + }, + ], + }, } -IS_PARAM_ERROR: dict[LanguageType, str] = { +RISK_EVALUATE: dict[LanguageType, str] = { LanguageType.CHINESE: dedent( r""" - 你是一个计划执行专家,你的任务是判断当前的步骤执行失败是否是因为参数错误导致的, - 如果是,请返回`true`,否则返回`false`。 - 必须按照以下格式回答: - ```json - { - "is_param_error": true/false, - } - ``` + 评估以下工具执行的安全风险等级,并说明风险原因。 - # 样例 - ## 用户目标 - 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优 + ## 评估要求 + - **风险等级**: 从 `low` (低)、`medium` (中)、`high` (高) 中选择 + - **风险分析**: 说明可能的安全隐患、性能影响或操作风险 + - **建议**: 如有必要,提供安全执行建议 - ## 历史 - 第1步:生成端口扫描命令 - - 调用工具 `command_generator`,并提供参数 `{"command": "nmap -sS -p--open 192.168.1.1"}` - - 执行状态:成功 - - 得到数据:`{"command": "nmap -sS -p--open 192.168.1.1"}` - 第2步:执行端口扫描命令 - - 调用工具 `command_executor`,并提供参数 `{"command": "nmap -sS -p--open 192.168.1.1"}` - - 执行状态:成功 - - 得到数据:`{"result": "success"}` - - ## 当前步骤 - - step_3 - mysql_analyzer - 分析MySQL数据库性能 - - - ## 工具入参 - { - "host": "192.0.0.1", - "port": 3306, - "username": "root", - "password": "password" - } + ### 示例 + 假设评估MySQL性能分析工具(`mysql_analyzer`),该工具将连接到生产环境数据库(MySQL 8.0.26)并执行性能分析。 + 由于工具会执行查询和统计操作,可能短暂增加数据库负载,因此风险等级为**中等**。建议在业务低峰期执行此操作,避免影响生产服务的正常运行。 - ## 工具运行报错 - 执行MySQL性能分析命令时,出现了错误:`host is not correct`。 + --- - ## 输出 - ```json - { - "is_param_error": true - } - ``` - - # 现在开始判断工具执行失败是否是因为参数错误导致的: - ## 用户目标 - {{goal}} - - ## 历史 - {{history}} - - ## 当前步骤 - - {{step_id}} - {{step_name}} - {{step_instruction}} - + ## 工具信息 + **名称**: {{tool_name}} + **描述**: {{tool_description}} - ## 工具入参 + ## 输入参数 + ```json {{input_param}} + ``` - ## 工具运行报错 - {{error_message}} - - ## 输出 + {% if additional_info %} + ## 上下文信息 + {{additional_info}} + {% endif %} """, ), LanguageType.ENGLISH: dedent( r""" - You are a plan execution expert. Your task is to determine whether the current step execution failure is \ -due to parameter errors. - If so, return `true`; otherwise, return `false`. - The answer must be in the following format: - ```json - { - "is_param_error": true/false, - } - ``` + Evaluate the security risk level for executing the following tool and explain the reasons. - # Example - ## User Goal - I need to scan the current MySQL database, analyze performance bottlenecks, and optimize it. - - ## History - Step 1: Generate a port scan command - - Call the `command_generator` tool and provide `{"command": "nmap -sS -p--open 192.168.1.1"}` - - Execution Status: Success - - Result: `{"command": "nmap -sS -p--open 192.168.1.1"}` - Step 2: Execute the port scan command - - Call the `command_executor` tool and provide `{"command": "nmap -sS -p--open 192.168.1.1"}` - - Execution Status: Success - - Result: `{"result": "success"}` - - ## Current step - - step_3 - mysql_analyzer - Analyze MySQL database performance - - - ## Tool input parameters - { - "host": "192.0.0.1", - "port": 3306, - "username": "root", - "password": "password" - } + ## Assessment Requirements + - **Risk Level**: Choose from `low`, `medium`, or `high` + - **Risk Analysis**: Explain potential security risks, performance impacts, or operational concerns + - **Recommendations**: Provide safe execution guidance if necessary - ## Tool execution error - When executing the MySQL performance analysis command, an error occurred: `host is not correct`. + ### Example + Consider evaluating a MySQL performance analysis tool (`mysql_analyzer`) that will connect to a production \ +database (MySQL 8.0.26) and perform performance analysis. + Since the tool will execute queries and statistical operations, it may temporarily increase database load, \ +resulting in a **medium** risk level. It is recommended to execute this operation during off-peak business hours to \ +avoid impacting normal production services. - ## Output - ```json - { - "is_param_error": true - } - ``` - - # Now start judging whether the tool execution failure is due to parameter errors: - ## User goal - {{goal}} - - ## History - {{history}} + --- - ## Current step - - {{step_id}} - {{step_name}} - {{step_instruction}} - + ## Tool Information + **Name**: {{tool_name}} + **Description**: {{tool_description}} - ## Tool input parameters + ## Input Parameters + ```json {{input_param}} + ``` - ## Tool error - {{error_message}} - - ## Output + {% if additional_info %} + ## Context Information + {{additional_info}} + {% endif %} """, ), } -# 将当前程序运行的报错转换为自然语言 -CHANGE_ERROR_MESSAGE_TO_DESCRIPTION: dict[LanguageType, str] = { +IS_PARAM_ERROR_FUNCTION: dict[LanguageType, dict] = { + LanguageType.CHINESE: { + "name": "check_parameter_error", + "description": "判断错误信息是否是参数相关的错误", + "parameters": { + "type": "object", + "properties": { + "is_param_error": { + "type": "boolean", + "description": "是否是参数错误", + "default": False, + }, + }, + "required": ["is_param_error"], + }, + "examples": [ + { + "is_param_error": True, + }, + ], + }, + LanguageType.ENGLISH: { + "name": "check_parameter_error", + "description": "Determine whether an error message indicates a parameter-related error", + "parameters": { + "type": "object", + "properties": { + "is_param_error": { + "type": "boolean", + "description": "Whether it is a parameter error", + "default": False, + }, + }, + "required": ["is_param_error"], + }, + "examples": [ + { + "is_param_error": True, + }, + ], + }, +} + +IS_PARAM_ERROR: dict[LanguageType, str] = { LanguageType.CHINESE: dedent( r""" - 你是一个智能助手,你的任务是将当前程序运行的报错转换为自然语言描述。 - 请根据以下规则进行转换: - 1. 将报错信息转换为自然语言描述,描述应该简洁明了,能够让人理解报错的原因和影响。 - 2. 描述应该包含报错的具体内容和可能的解决方案。 - 3. 描述应该避免使用过于专业的术语,以便用户能够理解。 - 4. 描述应该尽量简短,控制在50字以内。 - 5. 只输出自然语言描述,不要输出其他内容。 - - # 样例 - ## 工具信息 - - port_scanner - 扫描主机端口 - - { - "type": "object", - "properties": { - "host": { - "type": "string", - "description": "主机地址" - }, - "port": { - "type": "integer", - "description": "端口号" - }, - "username": { - "type": "string", - "description": "用户名" - }, - "password": { - "type": "string", - "description": "密码" - } - }, - "required": ["host", "port", "username", "password"] - } - - + 判断以下工具执行失败是否由参数错误导致,并调用 check_parameter_error 工具返回结果。 - ## 工具入参 - { - "host": "192.0.0.1", - "port": 3306, - "username": "root", - "password": "password" - } - - ## 报错信息 - 执行端口扫描命令时,出现了错误:`password is not correct`。 + **判断标准**: + - **参数错误**:缺失必需参数、参数值不正确、参数格式/类型错误等 + - **非参数错误**:权限问题、网络故障、系统异常、业务逻辑错误等 - ## 输出 - 扫描端口时发生错误:密码不正确。请检查输入的密码是否正确,并重试。 + **示例**:当mysql_analyzer工具入参为 {"host": "192.0.0.1", ...},报错"host is not correct"时, + 错误明确指出host参数值不正确,属于**参数错误**,应调用工具返回 {"is_param_error": true} - # 现在开始转换报错信息: - ## 工具信息 - - {{tool_name}} - {{tool_description}} - - {{input_schema}} - - + --- - # 工具入参 - {{input_params}} + **背景信息**: + - 用户目标:{{goal}} + - 执行历史:{{history}} - # 报错信息 - {{error_message}} + **当前失败步骤**(步骤{{step_id}}): + - 工具:{{step_name}} + - 说明:{{step_instruction}} + - 入参:{{input_param}} + - 报错:{{error_message}} - ## 输出 + 请基于报错信息和上下文综合判断,调用 check_parameter_error 工具返回判断结果。 """, ), LanguageType.ENGLISH: dedent( r""" - You are an intelligent assistant. Your task is to convert the error message generated by the current \ -program into a natural language description. - Please follow the following rules for conversion: - 1. Convert the error message into a natural language description. The description should be concise \ -and clear, allowing users to understand the cause and impact of the error. - 2. The description should include the specific content of the error and possible solutions. - 3. The description should avoid using overly technical terms so that users can understand it. - 4. The description should be as brief as possible, within 50 words. - 5. Only output the natural language description, do not output other content. - - # Example - ## Tool Information - - port_scanner - Scan host ports - - { - "type": "object", - "properties": { - "host": { - "type": "string", - "description": "Host address" - }, - "port": { - "type": "integer", - "description": "Port number" - }, - "username": { - "type": "string", - "description": "Username" - }, - "password": { - "type": "string", - "description": "Password" - } - }, - "required": ["host", "port", "username", "password"] - } - - - - ## Tool input - { - "host": "192.0.0.1", - "port": 3306, - "username": "root", - "password": "password" - } + Determine whether the following tool execution failure is caused by parameter errors, and call the \ +check_parameter_error tool to return the result. - ## Error message - An error occurred while executing the port scan command: `password is not correct`. + **Judgment Criteria**: + - **Parameter errors**: missing required parameters, incorrect parameter values, parameter format/type \ +errors, etc. + - **Non-parameter errors**: permission issues, network failures, system exceptions, business logic \ +errors, etc. - ## Output - An error occurred while scanning the port: The password is incorrect. Please check that the password \ -you entered is correct and try again. + **Example**: When the mysql_analyzer tool has input {"host": "192.0.0.1", ...} and error "host is not \ +correct", the error explicitly indicates the host parameter value is incorrect, which is a **parameter error**, \ +should call the tool to return {"is_param_error": true} - # Now start converting the error message: - ## Tool information - - {{tool_name}} - {{tool_description}} - - {{input_schema}} - - + --- - ## Tool input parameters - {{input_params}} + **Background**: + - User goal: {{goal}} + - Execution history: {{history}} - ## Error message - {{error_message}} + **Current Failed Step** (Step {{step_id}}): + - Tool: {{step_name}} + - Instruction: {{step_instruction}} + - Input: {{input_param}} + - Error: {{error_message}} - ## Output + Please make a comprehensive judgment based on the error message and context, and call the \ +check_parameter_error tool to return the result. """, ), } @@ -666,199 +459,130 @@ you entered is correct and try again. GET_MISSING_PARAMS: dict[LanguageType, str] = { LanguageType.CHINESE: dedent( r""" - 你是一个工具参数获取器。 - 你的任务是根据当前工具的名称、描述和入参和入参的schema以及运行报错,将当前缺失的参数设置为null,并输出一个JSON格式的字符串。 - ```json - { - "host": "请补充主机地址", - "port": "请补充端口号", - "username": "请补充用户名", - "password": "请补充密码" - } - ``` + 根据工具执行报错,识别缺失或错误的参数,并将其设置为null,保留正确参数的值。 + **请调用 `get_missing_parameters` 工具返回结果**。 - # 样例 - ## 工具 - - mysql_analyzer - 分析MySQL数据库性能 - + ## 任务要求 + - **分析报错信息**:定位哪些参数导致了错误 + - **标记问题参数**:将缺失或错误的参数值设为`null` + - **保留有效值**:未出错的参数保持原值 + - **调用工具返回**:必须调用`get_missing_parameters`工具,将参数JSON作为工具入参返回 - ## 工具入参 - { - "host": "192.0.0.1", - "port": 3306, - "username": "root", - "password": "password" - } + ## 示例 - ## 工具入参schema - { - "type": "object", - "properties": { - "host": { - "anyOf": [ - {"type": "string"}, - {"type": "null"} - ], - "description": "MySQL数据库的主机地址(可以为字符串或null)" - }, - "port": { - "anyOf": [ - {"type": "string"}, - {"type": "null"} - ], - "description": "MySQL数据库的端口号(可以是数字、字符串或null)" - }, - "username": { - "anyOf": [ - {"type": "string"}, - {"type": "null"} - ], - "description": "MySQL数据库的用户名(可以为字符串或null)" - }, - "password": { - "anyOf": [ - {"type": "string"}, - {"type": "null"} - ], - "description": "MySQL数据库的密码(可以为字符串或null)" - } - }, - "required": ["host", "port", "username", "password"] - } + **工具**: `mysql_analyzer` - 分析MySQL数据库性能 - #$ 运行报错 - {"err_msg": "执行端口扫描命令时,出现了错误:`password is not correct`。", "data": {} } + **当前入参**: + ```json + {"host": "192.0.0.1", "port": 3306, "username": "root", "password": "password"} + ``` - ## 输出 + **参数Schema**: ```json { - "host": "192.0.0.1", - "port": 3306, - "username": null, - "password": null + "properties": { + "host": {"anyOf": [{"type": "string"}, {"type": "null"}], "description": "主机地址"}, + "port": {"anyOf": [{"type": "integer"}, {"type": "null"}], "description": "端口号"}, + "username": {"anyOf": [{"type": "string"}, {"type": "null"}], "description": "用户名"}, + "password": {"anyOf": [{"type": "string"}, {"type": "null"}], "description": "密码"} + }, + "required": ["host", "port", "username", "password"] } ``` - # 现在开始获取缺失的参数: - ## 工具 - - {{tool_name}} - {{tool_description}} - + **运行报错**: `password is not correct` + + **应调用工具**: + ``` + get_missing_parameters({"host": "192.0.0.1", "port": 3306, "username": null, "password": null}) + ``` + + > 分析:报错提示密码错误,因此将`password`设为`null`;同时将`username`也设为`null`以便用户重新提供凭证 + + --- + + ## 当前任务 - ## 工具入参 + **工具**: `{{tool_name}}` - {{tool_description}} + + **当前入参**: + ```json {{input_param}} + ``` - ## 工具入参schema(部分字段允许为null) + **参数Schema**: + ```json {{input_schema}} + ``` - ## 运行报错 - {{error_message}} + **运行报错**: {{error_message}} - ## 输出 + **请调用 `get_missing_parameters` 工具返回分析结果**。 """, ), LanguageType.ENGLISH: dedent( r""" - You are a tool parameter getter. - Your task is to set missing parameters to null based on the current tool's name, description, \ -input parameters, input parameter schema, and runtime errors, and output a JSON-formatted string. - ```json - { - "host": "Please provide the host address", - "port": "Please provide the port number", - "username": "Please provide the username", - "password": "Please provide the password" - } - ``` + Based on tool execution errors, identify missing or incorrect parameters, set them to null, and retain \ +correct parameter values. + **Please call the `get_missing_parameters` tool to return the result**. - # Example - ## Tool - - mysql_analyzer - Analyze MySQL database performance - + ## Task Requirements + - **Analyze error messages**: Identify which parameters caused the error + - **Mark problematic parameters**: Set missing or incorrect parameter values to `null` + - **Retain valid values**: Keep original values for parameters that didn't cause errors + - **Call tool to return**: Must call the `get_missing_parameters` tool with the parameter JSON as input - ## Tool Input Parameters + ## Example + + **Tool**: `mysql_analyzer` - Analyze MySQL database performance + + **Current Input**: ```json - { - "host": "192.0.0.1", - "port": 3306, - "username": "root", - "password": "password" - } + {"host": "192.0.0.1", "port": 3306, "username": "root", "password": "password"} ``` - ## Tool Input Parameter Schema + **Parameter Schema**: ```json { - "type": "object", - "properties": { - "host": { - "anyOf": [ - {"type": "string"}, - {"type": "null"} - ], - "description": "MySQL database host address (can be a string or null)" - }, - "port": { - "anyOf": [ - {"type": "string"}, - {"type": "null"} - ], - "description": "MySQL database port number (can be a number, a string, or null)" - }, - "username": { - "anyOf": [ - {"type": "string"}, - {"type": "null"} - ], - "description": "MySQL database username (can be a string or null)" - }, - "password": { - "anyOf": [ - {"type": "string"}, - {"type": "null"} - ], - "description": "MySQL database password (can be a string or null)" - } - }, - "required": ["host", "port", "username", "password"] + "properties": { + "host": {"anyOf": [{"type": "string"}, {"type": "null"}], "description": "Host address"}, + "port": {"anyOf": [{"type": "integer"}, {"type": "null"}], "description": "Port number"}, + "username": {"anyOf": [{"type": "string"}, {"type": "null"}], "description": "Username"}, + "password": {"anyOf": [{"type": "string"}, {"type": "null"}], "description": "Password"} + }, + "required": ["host", "port", "username", "password"] } + ``` - ## Error info - {"err_msg": "When executing the port scan command, an error occurred: `password is not correct`.", \ -"data": {} } + **Error**: `password is not correct` - ## Output - ```json - { - "host": "192.0.0.1", - "port": 3306, - "username": null, - "password": null - } + **Should call tool**: + ``` + get_missing_parameters({"host": "192.0.0.1", "port": 3306, "username": null, "password": null}) ``` - # Now start getting the missing parameters: - ## Tool - - {{tool_name}} - {{tool_description}} - + > Analysis: Error indicates incorrect password, so set `password` to `null`; also set `username` to `null` \ +for user to provide credentials again + + --- + + ## Current Task + + **Tool**: `{{tool_name}}` - {{tool_description}} - ## Tool input parameters + **Current Input**: + ```json {{input_param}} + ``` - ## Tool input parameter schema (some fields can be null) + **Parameter Schema**: + ```json {{input_schema}} + ``` - ## Error info - {{error_message}} + **Error**: {{error_message}} - ## Output + **Please call the `get_missing_parameters` tool to return the analysis result**. """, ), } @@ -866,187 +590,153 @@ input parameters, input parameter schema, and runtime errors, and output a JSON- GEN_PARAMS: dict[LanguageType, str] = { LanguageType.CHINESE: dedent( r""" - 你是一个工具参数生成器。 - 你的任务是根据总的目标、阶段性的目标、工具信息、工具入参的schema和背景信息生成工具的入参。 - 注意: - 1.生成的参数在格式上必须符合工具入参的schema。 - 2.总的目标、阶段性的目标和背景信息必须被充分理解,利用其中的信息来生成工具入参。 - 3.生成的参数必须符合阶段性目标。 + 根据总体目标、阶段目标、工具信息和背景信息,生成符合schema的工具参数。 - # 样例 - ## 工具信息 - - mysql_analyzer - 分析MySQL数据库性能 - + ## 任务要求 + - **严格遵循Schema**:生成的参数必须完全符合工具入参schema的类型和格式规范 + - **充分理解上下文**:全面分析总体目标、阶段目标和背景信息,提取所有可用的参数值 + - **匹配阶段目标**:确保生成的参数能够完成当前阶段目标 - ## 总目标 - 我需要扫描当前mysql数据库,分析性能瓶颈, 并调优,ip地址是192.168.1.1,端口是3306,用户名是root,\ -密码是password。 + ## 示例 - ## 当前阶段目标 - 我要连接MySQL数据库,分析性能瓶颈,并调优。 + **工具**: + - **名称**:`mysql_analyzer` + - **描述**:分析MySQL数据库性能 - ## 工具入参的schema + **总体目标**:扫描MySQL数据库(IP: 192.168.1.1,端口: 3306,用户名: root,密码: password),\ +分析性能瓶颈并调优 + + **阶段目标**:连接MySQL数据库并分析性能 + + **参数Schema**: + ```json { - "type": "object", - "properties": { - "host": { - "type": "string", - "description": "MySQL数据库的主机地址" - }, - "port": { - "type": "integer", - "description": "MySQL数据库的端口号" - }, - "username": { - "type": "string", - "description": "MySQL数据库的用户名" - }, - "password": { - "type": "string", - "description": "MySQL数据库的密码" - } - }, - "required": ["host", "port", "username", "password"] + "type": "object", + "properties": { + "host": {"type": "string", "description": "主机地址"}, + "port": {"type": "integer", "description": "端口号"}, + "username": {"type": "string", "description": "用户名"}, + "password": {"type": "string", "description": "密码"} + }, + "required": ["host", "port", "username", "password"] } + ``` - ## 背景信息 - 第1步:生成端口扫描命令 - - 调用工具 `command_generator`,并提供参数 `帮我生成一个mysql端口扫描命令` - - 执行状态:成功 - - 得到数据:`{"command": "nmap -sS -p--open 192.168.1.1"}` - 第2步:执行端口扫描命令 - - 调用工具 `command_executor`,并提供参数 `{"command": "nmap -sS -p--open 192.168.1.1"}` - - 执行状态:成功 - - 得到数据:`{"result": "success"}` + **背景信息**: + - **步骤1**:生成端口扫描命令 → 成功,输出:`{"command": "nmap -sS -p --open 192.168.1.1"}` + - **步骤2**:执行端口扫描 → 成功,确认3306端口开放 - ## 输出 + **输出**: ```json { - "host": "192.168.1.1", - "port": 3306, - "username": "root", - "password": "password" + "host": "192.168.1.1", + "port": 3306, + "username": "root", + "password": "password" } ``` - # 现在开始生成工具入参: - ## 工具 - - {{tool_name}} - {{tool_description}} - + --- + + ## 当前任务 + + **工具**: + - **名称**:`{{tool_name}}` + - **描述**:{{tool_description}} - # 总目标 + **总体目标**: {{goal}} - # 当前阶段目标 + **阶段目标**: {{current_goal}} - # 工具入参scheme + **参数Schema**: + ```json {{input_schema}} + ``` - # 背景信息 + **背景信息**: {{background_info}} - # 输出 + **输出**: """, ).strip("\n"), LanguageType.ENGLISH: dedent( r""" - You are a tool parameter generator. - Your task is to generate tool input parameters based on the overall goal, phased goals, tool information, \ -tool input parameter schema, and background information. - Note: - 1. The generated parameters must conform to the tool input parameter schema. - 2. The overall goal, phased goals, and background information must be fully understood and used to \ -generate tool input parameters. - 3. The generated parameters must conform to the phased goals. + Generate tool parameters that conform to the schema based on the overall goal, phase goal, tool \ +information, and background context. - # Example - ## Tool Information - - mysql_analyzer - Analyze MySQL Database Performance - + ## Task Requirements + - **Strictly Follow Schema**: Generated parameters must fully conform to the tool input schema's type and \ +format specifications + - **Fully Understand Context**: Comprehensively analyze the overall goal, phase goal, and background \ +information to extract all available parameter values + - **Match Phase Goal**: Ensure the generated parameters can accomplish the current phase goal - ## Overall Goal - I need to scan the current MySQL database, analyze performance bottlenecks, and optimize it. The IP \ -address is 192.168.1.1, the port is 3306, the username is root, and the password is password. + ## Example - ## Current Phase Goal - I need to connect to the MySQL database, analyze performance bottlenecks, and optimize it. + **Tool**: + - **Name**: `mysql_analyzer` + - **Description**: Analyze MySQL database performance - ## Tool input schema + **Overall Goal**: Scan MySQL database (IP: 192.168.1.1, Port: 3306, Username: root, Password: password), \ +analyze performance bottlenecks, and optimize + + **Phase Goal**: Connect to MySQL database and analyze performance + + **Parameter Schema**: + ```json { - "type": "object", - "properties": { - "host": { - "type": "string", - "description": "MySQL database host address" - }, - "port": { - "type": "integer", - "description": "MySQL database port number" - }, - "username": { - "type": "string", - "description": "MySQL database username" - }, - "password": { - "type": "string", - "description": "MySQL database password" - } - }, - "required": ["host", "port", "username", "password"] + "type": "object", + "properties": { + "host": {"type": "string", "description": "Host address"}, + "port": {"type": "integer", "description": "Port number"}, + "username": {"type": "string", "description": "Username"}, + "password": {"type": "string", "description": "Password"} + }, + "required": ["host", "port", "username", "password"] } + ``` - ## Background information - Step 1: Generate a port scan command - - Call the `command_generator` tool and provide the `Help me generate a MySQL port scan \ -command` parameter - - Execution status: Success - - Received data: `{"command": "nmap -sS -p --open 192.168.1.1"}` - - Step 2: Execute the port scan command - - Call the `command_executor` tool and provide the parameters `{"command": "nmap -sS -p --open \ -192.168.1.1"}` - - Execution status: Success - - Received data: `{"result": "success"}` + **Background Information**: + - **Step 1**: Generate port scan command → Success, output: `{"command": "nmap -sS -p --open 192.168.1.1"}` + - **Step 2**: Execute port scan → Success, confirmed port 3306 is open - ## Output + **Output**: ```json { - "host": "192.168.1.1", - "port": 3306, - "username": "root", - "password": "password" + "host": "192.168.1.1", + "port": 3306, + "username": "root", + "password": "password" } ``` - # Now start generating tool input parameters: - ## Tool - - {{tool_name}} - {{tool_description}} - + --- - ## Overall goal + ## Current Task + + **Tool**: + - **Name**: `{{tool_name}}` + - **Description**: {{tool_description}} + + **Overall Goal**: {{goal}} - ## Current stage goal + **Phase Goal**: {{current_goal}} - ## Tool input scheme + **Parameter Schema**: + ```json {{input_schema}} + ``` - ## Background information + **Background Information**: {{background_info}} - ## Output + **Output**: """, - ), + ).strip("\n"), } REPAIR_PARAMS: dict[LanguageType, str] = { @@ -1251,6 +941,19 @@ parameter descriptions. ), } +GET_MISSING_PARAMS_FUNCTION: dict[LanguageType, dict] = { + LanguageType.CHINESE: { + "name": "get_missing_parameters", + "description": "根据错误反馈提取并提供缺失或错误的参数", + "parameters": None, + }, + LanguageType.ENGLISH: { + "name": "get_missing_parameters", + "description": "Extract and provide the missing or incorrect parameters based on error feedback", + "parameters": None, + }, +} + FINAL_ANSWER: dict[LanguageType, str] = { LanguageType.CHINESE: dedent( r""" diff --git a/apps/scheduler/scheduler/flow.py b/apps/scheduler/scheduler/flow.py index ab1d5ca946e895483d25a9eab5403b43f6c2dc89..12e87acf2f559a2b08f6310f73781385c0d42a1e 100644 --- a/apps/scheduler/scheduler/flow.py +++ b/apps/scheduler/scheduler/flow.py @@ -2,6 +2,7 @@ """Flow相关的Mixin类""" import logging +from copy import deepcopy from jinja2.sandbox import SandboxedEnvironment @@ -11,7 +12,7 @@ from apps.schemas.request_data import RequestData from apps.schemas.scheduler import TopFlow from apps.schemas.task import TaskData -from .prompt import FLOW_SELECT +from .prompt import FLOW_SELECT, FLOW_SELECT_FUNCTION _logger = logging.getLogger(__name__) @@ -48,19 +49,13 @@ class FlowMixin: question=self.post_body.question, choice_list=choices, ) - schema = TopFlow.model_json_schema() - schema["properties"]["choice"]["enum"] = [choice["name"] for choice in choices] - function = { - "name": "select_flow", - "description": "Select the appropriate flow", - "parameters": schema, - } + function = deepcopy(FLOW_SELECT_FUNCTION) + function["parameters"]["properties"]["choice"]["enum"] = [choice["name"] for choice in choices] result_str = await json_generator.generate( function=function, conversation=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}, - {"role": "user", "content": self.post_body.question}, ], language=self.task.runtime.language, ) diff --git a/apps/scheduler/scheduler/prompt.py b/apps/scheduler/scheduler/prompt.py index 12eb4e08075dfe53351d49a6250a5fb23833949b..395ee2d1919cfe6274bbb4bf03ce34dddb030297 100644 --- a/apps/scheduler/scheduler/prompt.py +++ b/apps/scheduler/scheduler/prompt.py @@ -5,107 +5,73 @@ from apps.models import LanguageType FLOW_SELECT: dict[LanguageType, str] = { LanguageType.CHINESE: r""" - - - 根据历史对话(包括工具调用结果)和用户问题,从给出的选项列表中,选出最符合要求的那一项。 - 在输出之前,请先思考,并使用“”标签给出思考过程。 - 结果需要使用JSON格式输出,输出格式为:{{ "choice": "选项名称" }} - - - - - 使用天气API,查询明天杭州的天气信息 - - - - API - HTTP请求,获得返回的JSON数据 - - - SQL - 查询数据库,获得数据库表中的数据 - - - - - - API 工具可以通过 API 来获取外部数据,而天气信息可能就存储在外部数据中,由于用户说明中明确 \ -提到了天气 API 的使用,因此应该优先使用 API 工具。SQL 工具用于从数据库中获取信息,考虑到天气数据的可变性和动态性\ -,不太可能存储在数据库中,因此 SQL 工具的优先级相对较低,最佳选择似乎是“API:请求特定 API,获取返回的 JSON 数据”。 - - - - {{ "choice": "API" }} - - - - - - - {{question}} - - - - {{choice_list}} - - - - - 让我们一步一步思考。 + ## 任务说明 + + 根据对话历史和用户查询,从可用选项中选择最匹配的一项。 + + ## 示例 + + **用户查询:** + > 使用天气API,查询明天杭州的天气信息 + + **可用选项:** + - **API**:HTTP请求,获取返回的JSON数据 + - **SQL**:查询数据库,获取数据库表中的数据 + + **回答:** + 用户明确提到使用天气API,天气数据通常通过外部API获取而非数据库存储,因此选择 API 选项。 + + --- + + ## 当前任务 + + **用户查询:** + {{question}} + + **可用选项:** + {{choice_list}} """, LanguageType.ENGLISH: r""" - - - Based on the historical dialogue (including tool call results) and user question, select the \ -most suitable option from the given option list. - Before outputting, please think carefully and use the "" tag to give the thinking \ -process. - The output needs to be in JSON format, the output format is: {{ "choice": "option name" }} - - - - - Use the weather API to query the weather information of Hangzhou \ -tomorrow - - - - API - HTTP request, get the returned JSON data - - - SQL - Query the database, get the data in the database table - - - - - - The API tool can get external data through API, and the weather information may be stored \ -in external data. Since the user clearly mentioned the use of weather API, it should be given priority to the API \ -tool. The SQL tool is used to get information from the database, considering the variability and dynamism of weather \ -data, it is unlikely to be stored in the database, so the priority of the SQL tool is relatively low, \ -The best choice seems to be "API: request a specific API, get the returned JSON data". - - - - {{ "choice": "API" }} - - - - - - {{question}} - - - - {{choice_list}} - - - - - Let's think step by step. - - + ## Task Description + + Based on conversation history and user query, select the most matching option from available choices. + + ## Example + + **User Query:** + > Use the weather API to query the weather information of Hangzhou tomorrow + + **Available Options:** + - **API**: HTTP request, get the returned JSON data + - **SQL**: Query the database, get the data in the database table + + **Answer:** + The user explicitly mentioned using weather API. Weather data is typically accessed via external APIs rather \ +than database storage, so the API option is selected. + + --- + + ## Current Task + + **User Query:** + {{question}} + + **Available Options:** + {{choice_list}} """, } + +FLOW_SELECT_FUNCTION = { + "name": "select_flow", + "description": "Select the appropriate flow", + "parameters": { + "type": "object", + "properties": { + "choice": { + "type": "string", + "description": "最匹配用户输入的Flow的名称", + }, + }, + "required": ["choice"], + }, +} diff --git a/apps/schemas/scheduler.py b/apps/schemas/scheduler.py index 3f44772331885fb79c85997397d07ecadc64b583..0d8e4748dc473aa2fcdd82e3db05e58b6a36bece 100644 --- a/apps/schemas/scheduler.py +++ b/apps/schemas/scheduler.py @@ -23,7 +23,7 @@ class CallIds(BaseModel): task_id: uuid.UUID = Field(description="任务ID") executor_id: str = Field(description="Flow ID") - session_id: str | None = Field(description="当前用户的Session ID") + auth_header: str | None = Field(description="当前用户的Authorization Header") app_id: uuid.UUID | None = Field(description="当前应用的ID") user_id: str = Field(description="当前用户的用户ID") conversation_id: uuid.UUID | None = Field(description="当前对话的ID") diff --git a/apps/services/document.py b/apps/services/document.py index 94b8818a28197d435329bbed67ede1af2b99ae02..09c85a8f4d129dd9312f1e2d4eedb902311baf3c 100644 --- a/apps/services/document.py +++ b/apps/services/document.py @@ -21,7 +21,6 @@ from apps.models import ( ) from .knowledge_service import KnowledgeBaseService -from .session import SessionManager logger = logging.getLogger(__name__) @@ -213,7 +212,9 @@ class DocumentManager: @staticmethod - async def delete_document_by_conversation_id(user_id: str, conversation_id: uuid.UUID) -> list[str]: + async def delete_document_by_conversation_id( + conversation_id: uuid.UUID, auth_header: str, + ) -> list[str]: """通过ConversationID删除文件""" doc_ids = [] @@ -227,11 +228,7 @@ class DocumentManager: await session.delete(doc) await session.commit() - session_id = await SessionManager.get_session_by_user(user_id) - if not session_id: - logger.error("[DocumentManager] Session不存在: %s", user_id) - return [] - await KnowledgeBaseService.delete_doc_from_rag(session_id, doc_ids) + await KnowledgeBaseService.delete_doc_from_rag(auth_header, doc_ids) return doc_ids diff --git a/apps/services/knowledge_service.py b/apps/services/knowledge_service.py index 7e47f1b50b521e6df74373bc7b4eb61c34ef407c..df63235d04183a10bccfe05f3758a87bc993543d 100644 --- a/apps/services/knowledge_service.py +++ b/apps/services/knowledge_service.py @@ -26,11 +26,11 @@ class KnowledgeBaseService: """知识库服务""" @staticmethod - async def send_file_to_rag(session_id: str, docs: list[Document]) -> list[str]: + async def send_file_to_rag(auth_header: str, docs: list[Document]) -> list[str]: """上传文件给RAG,进行处理和向量化""" headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {session_id}", + "Authorization": f"Bearer {auth_header}", } rag_docs = [RAGFileParseReqItem( id=str(doc.id), @@ -51,11 +51,11 @@ class KnowledgeBaseService: return resp_data["result"] @staticmethod - async def delete_doc_from_rag(session_id: str, doc_ids: list[str]) -> list[str]: + async def delete_doc_from_rag(auth_header: str, doc_ids: list[str]) -> list[str]: """删除文件""" headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {session_id}", + "Authorization": f"Bearer {auth_header}", } delete_data = {"ids": doc_ids} async with httpx.AsyncClient() as client: @@ -66,11 +66,11 @@ class KnowledgeBaseService: return resp_data["result"] @staticmethod - async def get_doc_status_from_rag(session_id: str, doc_ids: list[uuid.UUID]) -> list[RAGFileStatusRspItem]: + async def get_doc_status_from_rag(auth_header: str, doc_ids: list[uuid.UUID]) -> list[RAGFileStatusRspItem]: """获取文件状态""" headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {session_id}", + "Authorization": f"Bearer {auth_header}", } post_data = {"ids": doc_ids} async with httpx.AsyncClient() as client: diff --git a/apps/services/token.py b/apps/services/token.py index a5d0eadbcb8cd8199a40df5a2869fe0512bbbe9d..a4c335d69e5482189c8ef7e52356ef94e76eab8c 100644 --- a/apps/services/token.py +++ b/apps/services/token.py @@ -16,8 +16,6 @@ from apps.constants import OIDC_ACCESS_TOKEN_EXPIRE_TIME from apps.models import Session, SessionType from apps.schemas.config import OIDCConfig -from .session import SessionManager - logger = logging.getLogger(__name__) @@ -27,12 +25,11 @@ class TokenManager: @staticmethod async def get_plugin_token( plugin_id: uuid.UUID, - session_id: str, + user_id: str, access_token_url: str, expire_time: int, ) -> str: """获取插件Token""" - user_id = await SessionManager.get_user(session_id=session_id) if not user_id: err = "用户不存在!" raise ValueError(err) @@ -56,7 +53,6 @@ class TokenManager: token = await TokenManager.generate_plugin_token( plugin_id, - session_id, user_id, access_token_url, expire_time, @@ -88,7 +84,6 @@ class TokenManager: @staticmethod async def generate_plugin_token( plugin_name: uuid.UUID, - session_id: str, user_id: str, access_token_url: str, expire_time: int, @@ -125,7 +120,6 @@ class TokenManager: ).one_or_none() if not refresh_token or not refresh_token.token: - await SessionManager.delete_session(session_id) err = "Refresh token均过期,需要重新登录" raise RuntimeError(err) diff --git a/design/services/conversation.md b/design/services/conversation.md index e64409cdf100fbde6da9bd8751d46bfbc5eddbda..7bceef4b86dcf582af4133716c7b542c66c43069 100644 --- a/design/services/conversation.md +++ b/design/services/conversation.md @@ -218,7 +218,8 @@ sequenceDiagram CM-->>R: 删除完成 - R->>DM: delete_document_by_conversation_id(user_id, id) + R->>R: auth_header = request.session.session_id || request.state.personal_token + R->>DM: delete_document_by_conversation_id(id, auth_header) DM-->>R: 文档删除完成 end diff --git a/design/services/document.md b/design/services/document.md index d882539a53d9e13da6020533b775db9fbb12e7e5..e1477b9b2d411b08dc7f117c397dedab63c1cd78 100644 --- a/design/services/document.md +++ b/design/services/document.md @@ -510,7 +510,7 @@ sequenceDiagram API->>Auth: verify_session() Auth-->>API: ✓ user_id API->>Auth: verify_personal_token() - Auth-->>API: ✓ session_id + Auth-->>API: ✓ personal_token API->>DM: storage_docs(user_id, conv_id, files) @@ -524,7 +524,8 @@ sequenceDiagram DM-->>API: 返回文档列表 - API->>RAG: send_file_to_rag(session_id, docs) + API->>API: auth_header = request.session.session_id || request.state.personal_token + API->>RAG: send_file_to_rag(auth_header, docs) RAG-->>API: ✓ 接收成功 API->>User: 200 OK + 文档列表 @@ -569,7 +570,8 @@ sequenceDiagram DB-->>DM: 文档详情 DM-->>API: 未使用文档列表 - API->>RAG: get_doc_status_from_rag(session_id, doc_ids) + API->>API: auth_header = request.session.session_id || request.state.personal_token + API->>RAG: get_doc_status_from_rag(auth_header, doc_ids) RAG-->>API: 文档处理状态列表 loop 每个未使用文档 @@ -620,7 +622,8 @@ sequenceDiagram MinIO-->>DM: ✓ DM-->>API: 删除成功 - API->>RAG: delete_doc_from_rag(session_id, [doc_id]) + API->>API: auth_header = request.session.session_id || request.state.personal_token + API->>RAG: delete_doc_from_rag(auth_header, [doc_id]) RAG-->>API: 删除结果 alt RAG删除失败 @@ -793,19 +796,24 @@ erDiagram ### 9.1 与 RAG 系统集成 +**认证方式**: + +所有 RAG 系统调用使用 `auth_header` 参数进行认证。 +`auth_header` 的值优先使用 `request.session.session_id`,若不存在则使用 `request.state.personal_token`。 + **文档上传后发送到 RAG**: -调用 KnowledgeBaseService.send_file_to_rag 方法, +调用 `KnowledgeBaseService.send_file_to_rag(auth_header, docs)` 方法, 将文档发送到 RAG 系统进行异步处理(解析、向量化、索引)。 **查询文档处理状态**: -调用 KnowledgeBaseService.get_doc_status_from_rag 方法, +调用 `KnowledgeBaseService.get_doc_status_from_rag(auth_header, doc_ids)` 方法, 获取文档在 RAG 系统中的处理状态。 **从 RAG 删除文档**: -调用 KnowledgeBaseService.delete_doc_from_rag 方法, +调用 `KnowledgeBaseService.delete_doc_from_rag(auth_header, doc_ids)` 方法, 从 RAG 系统中删除文档的索引数据。 ### 9.2 RAG 状态映射