diff --git a/apps/scheduler/pool/mcp/client.py b/apps/scheduler/pool/mcp/client.py index b672690536bc1f03923de51bef6b1ed88c785a04..fe9fe5a2de0cc2764d5158b35dcfc4515f7d5623 100644 --- a/apps/scheduler/pool/mcp/client.py +++ b/apps/scheduler/pool/mcp/client.py @@ -5,11 +5,12 @@ import asyncio import logging from contextlib import AsyncExitStack from typing import TYPE_CHECKING - +from datetime import timedelta from mcp import ClientSession, StdioServerParameters from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client +from apps.common.config import Config from apps.constants import MCP_PATH from apps.schemas.mcp import ( MCPServerSSEConfig, @@ -58,7 +59,9 @@ class MCPClient: headers = config.headers or {} client = sse_client( url=config.url, - headers=headers + headers=headers, + timeout=Config().get_config().mcp_config.sse_client_init_timeout, + sse_read_timeout=Config().get_config().mcp_config.sse_client_read_timeout, ) elif isinstance(config, MCPServerStdioConfig): if user_sub: @@ -123,7 +126,8 @@ class MCPClient: self.stop_sign = asyncio.Event() # 创建协程 - self.task = asyncio.create_task(self._main_loop(user_sub, mcp_id, config)) + self.task = asyncio.create_task( + self._main_loop(user_sub, mcp_id, config)) # 等待初始化完成 done, pending = await asyncio.wait( @@ -141,7 +145,7 @@ class MCPClient: async def call_tool(self, tool_name: str, params: dict) -> "CallToolResult": """调用MCP Server的工具""" - return await self.client.call_tool(tool_name, params) + return await self.client.call_tool(tool_name, params, read_timeout_seconds=timedelta(seconds=3600)) async def stop(self) -> None: """停止MCP Client""" diff --git a/apps/schemas/config.py b/apps/schemas/config.py index 675a9ba7f26531206a5ea5f8a7598d4344ae2124..9f0815e6eb528fbcc3b7e7a00107909bc129706e 100644 --- a/apps/schemas/config.py +++ b/apps/schemas/config.py @@ -102,8 +102,10 @@ class FunctionCallConfig(BaseModel): api_key: str = Field(description="Function Call API密钥") max_tokens: int | None = Field(description="Function Call 最大Token数", default=None) temperature: float | None = Field(description="Function Call 温度", default=None) - - +class McpConfig(BaseModel): + """MCP配置""" + sse_client_init_timeout: int = Field(description="MCP SSE连接超时时间,单位秒", default=60) + sse_client_read_timeout: int = Field(description="MCP SSE读取超时时间,单位秒", default=3600) class SecurityConfig(BaseModel): """安全配置""" @@ -138,6 +140,7 @@ class ConfigModel(BaseModel): mongodb: MongoDBConfig llm: LLMConfig function_call: FunctionCallConfig + mcp_config: McpConfig security: SecurityConfig check: CheckConfig extra: ExtraConfig