diff --git a/apps/scheduler/pool/mcp/client.py b/apps/scheduler/pool/mcp/client.py index 35600c8c6326e4549f984b47431496b5068f8e8d..9b9167fb0bc3d2e89172acc3a6cfccee5a6eb5be 100644 --- a/apps/scheduler/pool/mcp/client.py +++ b/apps/scheduler/pool/mcp/client.py @@ -6,11 +6,13 @@ import logging from contextlib import AsyncExitStack from typing import TYPE_CHECKING +from datetime import timedelta import httpx from mcp import ClientSession, StdioServerParameters from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client +from apps.common.config import Config from apps.constants import MCP_PATH from apps.schemas.mcp import ( MCPServerSSEConfig, @@ -59,56 +61,69 @@ class MCPClient: headers = config.headers or {} # 添加超时配置 timeout = getattr(config, 'timeout', 30) # 默认30秒超时 - logger.info("[MCPClient] MCP %s:尝试连接SSE端点 %s,超时时间: %s秒", mcp_id, config.url, timeout) - + logger.info("[MCPClient] MCP %s:尝试连接SSE端点 %s,超时时间: %s秒", + mcp_id, config.url, timeout) + try: # 先测试端点可达性 - 对于SSE端点,我们只检查连接性,不读取内容 async with httpx.AsyncClient(timeout=httpx.Timeout(connect=5.0, read=3.0)) as test_client: try: # 首先尝试HEAD请求 response = await test_client.head(config.url, headers=headers) - logger.info("[MCPClient] MCP %s:端点预检查响应状态 %s", mcp_id, response.status_code) - + logger.info("[MCPClient] MCP %s:端点预检查响应状态 %s", + mcp_id, response.status_code) + # 如果HEAD请求返回404,尝试流式GET请求验证连接性 if response.status_code == 404: - logger.info("[MCPClient] MCP %s:HEAD请求返回404,尝试流式连接验证", mcp_id) + logger.info( + "[MCPClient] MCP %s:HEAD请求返回404,尝试流式连接验证", mcp_id) try: # 使用stream=True避免读取完整响应,只验证连接 async with test_client.stream('GET', config.url, headers=headers) as stream_response: if stream_response.status_code == 200: - logger.info("[MCPClient] MCP %s:流式连接成功,端点可用", mcp_id) + logger.info( + "[MCPClient] MCP %s:流式连接成功,端点可用", mcp_id) # 立即关闭流,不读取内容 else: - logger.warning("[MCPClient] MCP %s:流式连接返回状态 %s", mcp_id, stream_response.status_code) + logger.warning( + "[MCPClient] MCP %s:流式连接返回状态 %s", mcp_id, 
stream_response.status_code) except httpx.ReadTimeout: # 对于SSE端点,读取超时是正常的,说明连接成功但在等待流数据 - logger.info("[MCPClient] MCP %s:连接成功但读取超时(SSE端点正常行为)", mcp_id) + logger.info( + "[MCPClient] MCP %s:连接成功但读取超时(SSE端点正常行为)", mcp_id) except Exception as get_e: - logger.error("[MCPClient] MCP %s:流式连接失败: %s", mcp_id, get_e) - raise ConnectionError(f"MCP端点不可用: {config.url}") - + logger.error( + "[MCPClient] MCP %s:流式连接失败: %s", mcp_id, get_e) + raise ConnectionError( + f"MCP端点不可用: {config.url}") + except httpx.ConnectTimeout: logger.error("[MCPClient] MCP %s:连接超时", mcp_id) raise ConnectionError(f"无法连接到MCP端点 {config.url}: 连接超时") except httpx.RequestError as e: - logger.error("[MCPClient] MCP %s:端点预检查失败: %s", mcp_id, e) + logger.error( + "[MCPClient] MCP %s:端点预检查失败: %s", mcp_id, e) raise ConnectionError(f"无法连接到MCP端点 {config.url}: {e}") except httpx.HTTPStatusError as e: - logger.warning("[MCPClient] MCP %s:端点返回HTTP错误 %s", mcp_id, e.response.status_code) + logger.warning( + "[MCPClient] MCP %s:端点返回HTTP错误 %s", mcp_id, e.response.status_code) # 对于SSE端点,某些HTTP错误是可以接受的 - + except ConnectionError: # 重新抛出连接错误 self.error_sign.set() self.status = MCPStatus.ERROR raise except Exception as e: - logger.warning("[MCPClient] MCP %s:连接预检查遇到异常,但继续尝试连接: %s", mcp_id, e) + logger.warning( + "[MCPClient] MCP %s:连接预检查遇到异常,但继续尝试连接: %s", mcp_id, e) # 对于其他异常,记录警告但不阻止连接尝试 - + client = sse_client( url=config.url, - headers=headers + headers=headers, + timeout=Config().get_config().mcp_config.sse_client_init_timeout, + sse_read_timeout=Config().get_config().mcp_config.sse_client_read_timeout ) elif isinstance(config, MCPServerStdioConfig): if user_sub: @@ -116,7 +131,8 @@ class MCPClient: else: cwd = MCP_PATH / "template" / mcp_id / "project" await cwd.mkdir(parents=True, exist_ok=True) - logger.info("[MCPClient] MCP %s:创建Stdio客户端,工作目录: %s", mcp_id, cwd.as_posix()) + logger.info("[MCPClient] MCP %s:创建Stdio客户端,工作目录: %s", + mcp_id, cwd.as_posix()) client = stdio_client(server=StdioServerParameters( command=config.command, 
args=config.args, @@ -133,17 +149,17 @@ class MCPClient: exit_stack = AsyncExitStack() try: logger.info("[MCPClient] MCP %s:开始建立连接", mcp_id) - + # 设置超时时间 timeout_duration = getattr(config, 'timeout', 30) read, write = await asyncio.wait_for( exit_stack.enter_async_context(client), timeout=timeout_duration ) - + self.client = ClientSession(read, write) session = await exit_stack.enter_async_context(self.client) - + # 初始化Client logger.info("[MCPClient] MCP %s:开始初始化会话", mcp_id) await asyncio.wait_for( @@ -151,7 +167,7 @@ class MCPClient: timeout=timeout_duration ) logger.info("[MCPClient] MCP %s:初始化成功", mcp_id) - + except asyncio.TimeoutError: self.error_sign.set() self.status = MCPStatus.ERROR @@ -175,7 +191,7 @@ class MCPClient: self.ready_sign.set() self.status = MCPStatus.RUNNING - + try: # 等待关闭信号 await self.stop_sign.wait() @@ -197,10 +213,12 @@ class MCPClient: if "cancel scope" in str(e).lower() or "different task" in str(e).lower(): # 这是已知的TaskGroup问题,记录警告但不影响功能 self.status = MCPStatus.STOPPED - logger.warning("[MCPClient] MCP %s:关闭时遇到TaskGroup问题(已知问题,忽略)", mcp_id) + logger.warning( + "[MCPClient] MCP %s:关闭时遇到TaskGroup问题(已知问题,忽略)", mcp_id) else: self.status = MCPStatus.ERROR - logger.warning("[MCPClient] MCP %s:关闭时发生运行时错误: %s", mcp_id, e) + logger.warning( + "[MCPClient] MCP %s:关闭时发生运行时错误: %s", mcp_id, e) except Exception as e: self.status = MCPStatus.ERROR logger.warning("[MCPClient] MCP %s:关闭时发生异常: %s", mcp_id, e) @@ -223,7 +241,8 @@ class MCPClient: self.stop_sign = asyncio.Event() # 创建协程 - self.task = asyncio.create_task(self._main_loop(user_sub, mcp_id, config)) + self.task = asyncio.create_task( + self._main_loop(user_sub, mcp_id, config)) # 等待初始化完成 try: @@ -232,7 +251,7 @@ class MCPClient: asyncio.create_task(self.error_sign.wait())], return_when=asyncio.FIRST_COMPLETED ) - + # 取消未完成的任务 for task in pending: task.cancel() @@ -240,7 +259,7 @@ class MCPClient: await task except asyncio.CancelledError: pass - + if self.error_sign.is_set(): self.status = 
MCPStatus.ERROR logger.error("[MCPClient] MCP %s:初始化失败", mcp_id) @@ -249,7 +268,8 @@ class MCPClient: try: self.task.result() # 这会重新抛出任务中的异常 except Exception as task_exc: - logger.error("[MCPClient] MCP %s:主任务异常: %s", mcp_id, task_exc) + logger.error( + "[MCPClient] MCP %s:主任务异常: %s", mcp_id, task_exc) raise task_exc raise Exception(f"MCP {mcp_id} 初始化失败") except Exception as e: @@ -267,17 +287,17 @@ class MCPClient: async def call_tool(self, tool_name: str, params: dict) -> "CallToolResult": """调用MCP Server的工具""" - return await self.client.call_tool(tool_name, params) + return await self.client.call_tool(tool_name, params, read_timeout_seconds=timedelta(seconds=3600)) async def stop(self) -> None: """停止MCP Client""" if not hasattr(self, 'stop_sign') or not hasattr(self, 'task'): logger.warning("[MCPClient] 客户端未初始化,无需停止") return - + logger.info("[MCPClient] MCP %s:开始停止客户端", self.mcp_id) self.stop_sign.set() - + try: # 等待任务完成,不设置超时以避免取消作用域问题 await self.task diff --git a/apps/schemas/config.py b/apps/schemas/config.py index c9b0c507d43971b20345b2a9c595c528e13ac52a..a4302baa0b48bd8bfd6f809fc10f315836d23f3f 100644 --- a/apps/schemas/config.py +++ b/apps/schemas/config.py @@ -152,6 +152,14 @@ class FunctionCallConfig(BaseModel): description="Function Call 温度", default=None) +class McpConfig(BaseModel): + """MCP配置""" + sse_client_init_timeout: int = Field( + description="MCP SSE连接超时时间,单位秒", default=60) + sse_client_read_timeout: int = Field( + description="MCP SSE读取超时时间,单位秒", default=3600) + + class SecurityConfig(BaseModel): """安全配置""" @@ -195,6 +203,7 @@ class ConfigModel(BaseModel): redis: RedisConfig llm: LLMConfig function_call: FunctionCallConfig + mcp_config: McpConfig = Field(default_factory=McpConfig) security: SecurityConfig check: CheckConfig sandbox: SandboxConfig