diff --git a/apps/dependency/user.py b/apps/dependency/user.py
index da6d2bfc8bf68d08985c6449c6db8ea67a41c9ab..545b5a9593932fa251e2e78af1b2ec84946a2d82 100644
--- a/apps/dependency/user.py
+++ b/apps/dependency/user.py
@@ -56,7 +56,6 @@ async def verify_session(request: HTTPConnection) -> None:
detail="Session ID 鉴权失败",
)
request.state.user_sub = user
- logger.info("Session鉴权成功,user_sub=%s", user)
async def verify_personal_token(request: HTTPConnection) -> None:
@@ -91,7 +90,6 @@ async def verify_personal_token(request: HTTPConnection) -> None:
user_sub = await PersonalTokenManager.get_user_by_personal_token(token)
if user_sub is not None:
request.state.user_sub = user_sub
- logger.info("Personal Token鉴权成功,user_sub=%s", user_sub)
else:
# Personal Token无效,抛出401
logger.warning("Personal Token鉴权失败:无效的token")
diff --git a/apps/llm/__init__.py b/apps/llm/__init__.py
index 08bfd4543031282d7f27be8580cd2d456da8c055..4938632956ed55b0c295cc487712f49a674af25d 100644
--- a/apps/llm/__init__.py
+++ b/apps/llm/__init__.py
@@ -2,14 +2,13 @@
"""模型调用模块"""
from .embedding import Embedding
-from .function import FunctionLLM, JsonGenerator
-from .reasoning import ReasoningLLM
+from .generator import JsonGenerator
+from .llm import LLM
from .token import token_calculator
__all__ = [
+ "LLM",
"Embedding",
- "FunctionLLM",
"JsonGenerator",
- "ReasoningLLM",
"token_calculator",
]
diff --git a/apps/llm/function.py b/apps/llm/function.py
deleted file mode 100644
index 63e9ec1db59b872235e7d68b3ddf80a5f01da59e..0000000000000000000000000000000000000000
--- a/apps/llm/function.py
+++ /dev/null
@@ -1,167 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""用于FunctionCall的大模型"""
-
-import json
-import logging
-import re
-from textwrap import dedent
-from typing import Any
-
-from jinja2 import BaseLoader
-from jinja2.sandbox import SandboxedEnvironment
-from jsonschema import Draft7Validator
-
-from apps.constants import JSON_GEN_MAX_TRIAL
-from apps.models import LLMData, LLMProvider
-
-from .prompt import JSON_GEN_BASIC
-from .providers import (
- BaseProvider,
- OllamaProvider,
- OpenAIProvider,
-)
-
-_logger = logging.getLogger(__name__)
-_CLASS_DICT: dict[LLMProvider, type[BaseProvider]] = {
- LLMProvider.OLLAMA: OllamaProvider,
- LLMProvider.OPENAI: OpenAIProvider,
-}
-
-
-class FunctionLLM:
- """用于FunctionCall的模型"""
-
- def __init__(self, llm_config: LLMData | None = None) -> None:
- """初始化大模型客户端"""
- if not llm_config:
- err = "[FunctionLLM] 未设置LLM配置"
- _logger.error(err)
- raise RuntimeError(err)
-
- if llm_config.provider not in _CLASS_DICT:
- err = "[FunctionLLM] 未支持的LLM类型: %s", llm_config.provider
- _logger.error(err)
- raise RuntimeError(err)
-
- self._provider = _CLASS_DICT[llm_config.provider](llm_config)
-
- @staticmethod
- async def process_response(response: str) -> str:
- """处理大模型的输出"""
- # 尝试解析JSON
- response = dedent(response).strip()
- error_flag = False
- try:
- json.loads(response)
- except Exception: # noqa: BLE001
- error_flag = True
-
- if not error_flag:
- return response
-
- # 尝试提取```json中的JSON
- _logger.warning("[FunctionCall] 直接解析失败!尝试提取```json中的JSON")
- try:
- json_str = re.findall(r"```json(.*)```", response, re.DOTALL)[-1]
- json_str = dedent(json_str).strip()
- json.loads(json_str)
- except Exception: # noqa: BLE001
- # 尝试直接通过括号匹配JSON
- _logger.warning("[FunctionCall] 提取失败!尝试正则提取JSON")
- try:
- json_str = re.findall(r"\{.*\}", response, re.DOTALL)[-1]
- json_str = dedent(json_str).strip()
- json.loads(json_str)
- except Exception: # noqa: BLE001
- json_str = "{}"
-
- return json_str
-
-
- async def call(
- self,
- messages: list[dict[str, Any]],
- schema: dict[str, Any],
- max_tokens: int | None = None,
- temperature: float | None = None,
- ) -> dict[str, Any]:
- """
- 调用FunctionCall小模型
-
- 不开放流式输出
- """
- try:
- return json.loads(json_str)
- except Exception: # noqa: BLE001
- _logger.error("[FunctionCall] 大模型JSON解析失败:%s", json_str) # noqa: TRY400
- return {}
-
-
-class JsonGenerator:
- """JSON生成器"""
-
- def xml_parser(self, xml: str) -> dict[str, Any]:
- """XML解析器"""
- return {}
-
- def __init__(
- self, llm: FunctionLLM, query: str, conversation: list[dict[str, str]], schema: dict[str, Any],
- ) -> None:
- """初始化JSON生成器"""
- self._query = query
- self._conversation = conversation
- self._schema = schema
- self._llm = llm
-
- self._trial = {}
- self._count = 0
- self._env = SandboxedEnvironment(
- loader=BaseLoader(),
- autoescape=False,
- trim_blocks=True,
- lstrip_blocks=True,
- extensions=["jinja2.ext.loopcontrols"],
- )
- self._err_info = ""
-
- async def _single_trial(self, max_tokens: int | None = None, temperature: float | None = None) -> dict[str, Any]:
- """单次尝试"""
- # 检查类型
-
- # 渲染模板
- template = self._env.from_string(JSON_GEN_BASIC)
- return template.render(
- query=self._query,
- conversation=self._conversation,
- previous_trial=self._trial,
- schema=self._schema,
- function_call=function_call,
- err_info=self._err_info,
- )
-
- messages = [
- {"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": prompt},
- ]
- return await self._llm.call(messages, self._schema, max_tokens, temperature)
-
-
- async def generate(self) -> dict[str, Any]:
- """生成JSON"""
- Draft7Validator.check_schema(self._schema)
- validator = Draft7Validator(self._schema)
-
- while self._count < JSON_GEN_MAX_TRIAL:
- self._count += 1
- result = await self._single_trial()
- try:
- validator.validate(result)
- except Exception as err: # noqa: BLE001
- err_info = str(err)
- err_info = err_info.split("\n\n")[0]
- self._err_info = err_info
- _logger.info("[JSONGenerator] 验证失败:%s", self._err_info)
- continue
- return result
-
- return {}
diff --git a/apps/llm/generator.py b/apps/llm/generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..045076fcfcc3246053e62a253ffb1c007bdaf615
--- /dev/null
+++ b/apps/llm/generator.py
@@ -0,0 +1,165 @@
+# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""综合Json生成器"""
+
+import json
+import logging
+import re
+from typing import Any
+
+from jinja2 import BaseLoader
+from jinja2.sandbox import SandboxedEnvironment
+from jsonschema import Draft7Validator
+
+from apps.models import LLMType
+from apps.schemas.llm import LLMFunctions
+from apps.schemas.scheduler import LLMConfig
+
+from .llm import LLM
+from .prompt import JSON_GEN_BASIC, JSON_NO_FUNCTION_CALL
+
+_logger = logging.getLogger(__name__)
+
+JSON_GEN_MAX_TRIAL = 3 # 最大尝试次数
+
+class JsonGenerator:
+ """综合Json生成器"""
+
+ def __init__(
+ self, llm_config: LLMConfig, query: str, conversation: list[dict[str, str]], function: dict[str, Any],
+ ) -> None:
+ """初始化JSON生成器;function使用OpenAI标准Function格式"""
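+        # function参数示例(仅作说明,结构与MCPHost等调用方传入的一致):
+        # {
+        #     "name": "tool_name",         # 工具名称,用于FunctionCall与结果提取
+        #     "description": "工具描述",
+        #     "parameters": {...},         # 参数JSON Schema,生成结果将用Draft7Validator校验
+        # }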
+ self._query = query
+ self._function = function
+
+ # 选择LLM:优先使用Function模型(如果存在且支持FunctionCall),否则回退到Reasoning模型
+ self._llm, self._support_function_call = self._select_llm(llm_config)
+
+ self._context = [
+ {
+ "role": "system",
+ "content": "You are a helpful assistant that can use tools to help answer user queries.",
+ },
+ ]
+ if conversation:
+ self._context.extend(conversation)
+
+ self._count = 0
+ self._env = SandboxedEnvironment(
+ loader=BaseLoader(),
+ autoescape=False,
+ trim_blocks=True,
+ lstrip_blocks=True,
+ extensions=["jinja2.ext.loopcontrols"],
+ )
+
+ def _select_llm(self, llm_config: LLMConfig) -> tuple[LLM, bool]:
+ """选择LLM:优先使用Function模型(如果存在且支持FunctionCall),否则回退到Reasoning模型"""
+ if llm_config.function is not None and LLMType.FUNCTION in llm_config.function.config.llmType:
+ _logger.info("[JSONGenerator] 使用Function模型,支持FunctionCall")
+ return llm_config.function, True
+
+ _logger.info("[JSONGenerator] Function模型不可用或不支持FunctionCall,回退到Reasoning模型")
+ return llm_config.reasoning, False
+
+ async def _single_trial(self, max_tokens: int | None = None, temperature: float | None = None) -> dict[str, Any]:
+ """单次尝试,包含校验逻辑"""
+ # 获取schema并创建验证器
+ schema = self._function["parameters"]
+ validator = Draft7Validator(schema)
+
+ # 执行生成
+ if self._support_function_call:
+ # 如果支持FunctionCall,使用provider的function调用逻辑
+ result = await self._call_with_function()
+ else:
+ # 如果不支持FunctionCall,使用JSON_GEN_BASIC和provider的call进行调用
+ result = await self._call_without_function(max_tokens, temperature)
+
+ # 校验结果
+ try:
+ validator.validate(result)
+ except Exception as err:
+ # 捕获校验异常信息
+ err_info = str(err)
+ err_info = err_info.split("\n\n")[0]
+ _logger.info("[JSONGenerator] 验证失败:%s", err_info)
+
+ # 将错误信息添加到上下文中
+ self._context.append({
+ "role": "assistant",
+ "content": f"Attempted to use tool but validation failed: {err_info}",
+ })
+ raise
+ else:
+ return result
+
+ async def _call_with_function(self) -> dict[str, Any]:
+ """使用FunctionCall方式调用"""
+ # 直接使用传入的function构建工具定义
+ tool = LLMFunctions(
+ name=self._function["name"],
+ description=self._function["description"],
+ param_schema=self._function["parameters"],
+ )
+
+ messages = self._context.copy()
+ messages.append({"role": "user", "content": self._query})
+
+ # 调用LLM的call方法,传入tools
+ tool_call_result = {}
+ async for chunk in self._llm.call(messages, include_thinking=False, streaming=True, tools=[tool]):
+ if chunk.tool_call:
+ tool_call_result.update(chunk.tool_call)
+
+ # 从tool_call结果中提取JSON,使用function中的函数名
+ function_name = self._function["name"]
+ if tool_call_result and function_name in tool_call_result:
+ return json.loads(tool_call_result[function_name])
+
+ return {}
+
+ async def _call_without_function(self, max_tokens: int | None = None, temperature: float | None = None) -> dict[str, Any]: # noqa: E501
+ """不使用FunctionCall方式调用"""
+ # 渲染模板
+ template = self._env.from_string(JSON_GEN_BASIC + "\n\n" + JSON_NO_FUNCTION_CALL)
+ prompt = template.render(
+ query=self._query,
+ conversation=self._context[1:] if self._context else [],
+ schema=self._function["parameters"],
+ )
+
+ messages = [
+ self._context[0],
+ {"role": "user", "content": prompt},
+ ]
+
+ # 使用LLM的call方法获取响应
+ full_response = ""
+ async for chunk in self._llm.call(messages, include_thinking=False, streaming=True):
+ if chunk.content:
+ full_response += chunk.content
+
+ # 从响应中提取JSON
+ # 查找第一个 { 和最后一个 }
+ json_match = re.search(r"\{.*\}", full_response, re.DOTALL)
+ if json_match:
+ return json.loads(json_match.group(0))
+
+ return {}
+
+
+ async def generate(self) -> dict[str, Any]:
+ """生成JSON"""
+ # 检查schema格式是否正确
+ schema = self._function["parameters"]
+ Draft7Validator.check_schema(schema)
+
+ while self._count < JSON_GEN_MAX_TRIAL:
+ self._count += 1
+ try:
+ return await self._single_trial()
+ except Exception: # noqa: BLE001
+                # 校验失败,_single_trial已将错误信息追加到对话上下文并记录日志
+ continue
+
+ return {}
diff --git a/apps/llm/reasoning.py b/apps/llm/llm.py
similarity index 81%
rename from apps/llm/reasoning.py
rename to apps/llm/llm.py
index efa8d23b0df1d00d5d18cfa8429a0e70903479d1..fc005cd7fe73cdc5c81e3e757fac794a542f7da9 100644
--- a/apps/llm/reasoning.py
+++ b/apps/llm/llm.py
@@ -5,7 +5,7 @@ import logging
from collections.abc import AsyncGenerator
from apps.models import LLMData, LLMProvider
-from apps.schemas.llm import LLMChunk
+from apps.schemas.llm import LLMChunk, LLMFunctions
from .providers import (
BaseProvider,
@@ -20,7 +20,7 @@ _CLASS_DICT: dict[LLMProvider, type[BaseProvider]] = {
}
-class ReasoningLLM:
+class LLM:
"""调用用于问答的大模型"""
def __init__(self, llm_config: LLMData | None) -> None:
@@ -43,14 +43,23 @@ class ReasoningLLM:
*,
include_thinking: bool = True,
streaming: bool = True,
+ tools: list[LLMFunctions] | None = None,
) -> AsyncGenerator[LLMChunk, None]:
"""调用大模型,分为流式和非流式两种"""
if streaming:
- async for chunk in await self._provider.chat(messages, include_thinking=include_thinking):
+ async for chunk in await self._provider.chat(
+ messages,
+ include_thinking=include_thinking,
+ tools=tools,
+ ):
yield chunk
else:
# 非流式模式下,需要收集所有内容
- async for _ in await self._provider.chat(messages, include_thinking=include_thinking):
+ async for _ in await self._provider.chat(
+ messages,
+ include_thinking=include_thinking,
+ tools=tools,
+ ):
continue
# 直接yield最终结果
yield LLMChunk(
diff --git a/apps/llm/prompt.py b/apps/llm/prompt.py
index ebe44c0ee2bd70a9d393727a4b45554e7aeb4737..9851b282292f16a4fdbc556ff9428e61e1a4bf26 100644
--- a/apps/llm/prompt.py
+++ b/apps/llm/prompt.py
@@ -4,51 +4,88 @@
from textwrap import dedent
JSON_GEN_BASIC = dedent(r"""
- Respond to the query according to the background information provided.
-
- # User query
-
- User query is given in XML tags.
+
+    <instructions>
+    You are an intelligent assistant who can use tools to help answer user queries.
+    Your task is to respond to the query according to the background information and available tools.
+
+    Note:
+    - You have access to a set of tools that can help you gather information.
+    - You can use one tool at a time and will receive the result in the user's response.
+    - Use tools step-by-step to respond to the user's query, with each tool use informed by the \
+result of the previous tool use.
+    - The user's query is provided in the <query> tags.
+    {% if previous_trial %}- Review the previous trial information in <previous_trial> \
+tags to avoid repeating mistakes.{% endif %}
+    </instructions>
+
+    <query>
+    {{ query }}
+    </query>
-
{% if previous_trial %}
- # Previous Trial
- You tried to answer the query with one function, but the arguments are incorrect.
-
- The arguments you provided are:
-
- ```json
- {{ previous_trial }}
- ```
- And the error information is:
-
- ```
- {{ err_info }}
- ```
+
+    <previous_trial>
+    You previously attempted to answer the query by calling a tool, but the arguments were incorrect.
+
+    <arguments>
+    {{ previous_trial }}
+    </arguments>
+    <error>
+    {{ err_info }}
+    </error>
+    </previous_trial>
{% endif %}
- # Tools
- You have access to a set of tools. You can use one tool and will receive the result of \
-that tool use in the user's response. You use tools step-by-step to respond to the user's \
-query, with each tool use informed by the result of the previous tool use.
+
+ You have access to a set of tools. You can use one tool and will receive the result of that tool \
+use in the user's response.
+
""")
-
JSON_NO_FUNCTION_CALL = dedent(r"""
**Tool Use Formatting:**
- Tool uses are formatted using XML-style tags. The tool name itself becomes the XML tag name. Each \
-parameter is enclosed within its own set of tags. Here's the structure:
-
-    <tool_name>
-    <parameter1_name>value1</parameter1_name>
-    <parameter2_name>value2</parameter2_name>
-    ...
-    </tool_name>
-
- Always use the actual tool name as the XML tag name for proper parsing and execution.
+ Tool uses are formatted using XML-style tags. The tool name itself becomes the root XML tag name. \
+Each parameter is enclosed within its own set of tags according to the parameter schema provided below.
+
+ **Basic Structure:**
+    <tool_name>
+    <parameter_name>value</parameter_name>
+    </tool_name>
+
+ **Parameter Schema:**
+ The available tools and their parameter schemas are provided in the following format:
+ - Tool name: The name to use as the root XML tag
+ - Parameters: Each parameter has a name, type, and description
+ - Required parameters must be included
+ - Optional parameters can be omitted
+
+ **XML Generation Rules:**
+ 1. Use the exact tool name as the root XML tag
+ 2. For each parameter, create a nested tag with the parameter name
+ 3. Place the parameter value inside the corresponding tag
+    4. For string values: <parameter_name>text value</parameter_name>
+    5. For numeric values: <parameter_name>123</parameter_name>
+    6. For boolean values: <parameter_name>true</parameter_name> or <parameter_name>false</parameter_name>
+    7. For array values: wrap each item in the parameter tag
+    <parameter_name>item1</parameter_name>
+    <parameter_name>item2</parameter_name>
+    8. For object values: nest the object properties as sub-tags
+    <parameter_name>
+    <sub_parameter1>value1</sub_parameter1>
+    <sub_parameter2>value2</sub_parameter2>
+    </parameter_name>
+
+ **Example:**
+ If you need to use a tool named "search" with parameters query (string) and limit (number):
+
+    <search>
+    <query>your search text</query>
+    <limit>10</limit>
+    </search>
+
+ Always use the actual tool name as the root XML tag and match parameter names exactly as specified \
+in the schema for proper parsing and execution.
""")
diff --git a/apps/llm/providers/openai.py b/apps/llm/providers/openai.py
index 49e9890982e0056a7e2a588d583bb58788bdce02..0c2d3d4584b5cdfefde6a1c2f59fb3ce08d1902c 100644
--- a/apps/llm/providers/openai.py
+++ b/apps/llm/providers/openai.py
@@ -3,15 +3,12 @@
import logging
from collections.abc import AsyncGenerator
+from typing import cast
from openai import AsyncOpenAI, AsyncStream
from openai.types.chat import (
- ChatCompletionAssistantMessageParam,
ChatCompletionChunk,
ChatCompletionMessageParam,
- ChatCompletionSystemMessageParam,
- ChatCompletionToolMessageParam,
- ChatCompletionUserMessageParam,
)
from typing_extensions import override
@@ -187,35 +184,23 @@ class OpenAIProvider(BaseProvider):
def _convert_messages(self, messages: list[dict[str, str]]) -> list[ChatCompletionMessageParam]:
- """将list[dict]形式的消息转换为list[ChatCompletionMessageParam]形式"""
- converted_messages = []
+ """确保消息格式符合OpenAI API要求,特别是tool消息需要tool_call_id字段"""
+ result: list[dict[str, str]] = []
for msg in messages:
role = msg.get("role", "user")
- content = msg.get("content", "")
-
- if role == "system":
- converted_messages.append(
- ChatCompletionSystemMessageParam(role="system", content=content),
- )
- elif role == "user":
- converted_messages.append(
- ChatCompletionUserMessageParam(role="user", content=content),
- )
- elif role == "assistant":
- converted_messages.append(
- ChatCompletionAssistantMessageParam(role="assistant", content=content),
- )
- elif role == "tool":
- tool_call_id = msg.get("tool_call_id", "")
- converted_messages.append(
- ChatCompletionToolMessageParam(
- role="tool",
- content=content,
- tool_call_id=tool_call_id,
- ),
- )
- else:
+
+ # 验证角色的有效性
+ if role not in ("system", "user", "assistant", "tool"):
err = f"[OpenAIProvider] 未知角色: {role}"
_logger.error(err)
raise ValueError(err)
- return converted_messages
+
+ # 对于tool角色,确保有tool_call_id字段
+ if role == "tool" and "tool_call_id" not in msg:
+ result.append({**msg, "tool_call_id": ""})
+ else:
+ result.append(msg)
+
+ # 使用cast告诉类型检查器这是ChatCompletionMessageParam类型
+ # OpenAI库在运行时接受dict格式,无需显式构造TypedDict对象
+ return cast("list[ChatCompletionMessageParam]", result)
diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py
index ee0d28995dc169eeb6244f41ec8fb08d4559b9ee..53c4894fdbfcc6c79c97a5dc7a6d9e3a15b9721b 100644
--- a/apps/scheduler/call/slot/slot.py
+++ b/apps/scheduler/call/slot/slot.py
@@ -9,7 +9,7 @@ from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment
from pydantic import Field
-from apps.llm import FunctionLLM
+from apps.llm import JsonGenerator
from apps.models import LanguageType, NodeInfo
from apps.scheduler.call.core import CoreCall
from apps.scheduler.slot.slot import Slot as SlotProcessor
@@ -75,7 +75,7 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput):
answer = ""
async for chunk in self._llm(messages=conversation, streaming=True):
answer += chunk
- answer = await FunctionLLM.process_response(answer)
+ answer = await JsonGenerator.process_response(answer)
try:
data = json.loads(answer)
except Exception: # noqa: BLE001
diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py
index 18603aaa38e38fc23bd50bf1529f3282b86f5e72..bb3aa7d5c1d78d0860fc6197a9eeaaaae3ae7420 100644
--- a/apps/scheduler/mcp/host.py
+++ b/apps/scheduler/mcp/host.py
@@ -114,16 +114,20 @@ class MCPHost:
{query}
"""
+ function_definition = {
+ "name": tool.toolName,
+ "description": tool.description,
+ "parameters": tool.inputSchema,
+ }
# 进行生成
json_generator = JsonGenerator(
- self._llm.function,
+ self._llm,
llm_query,
[
- {"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": await self.assemble_memory()},
],
- tool.inputSchema,
+ function_definition,
)
return await json_generator.generate()
diff --git a/apps/scheduler/mcp/plan.py b/apps/scheduler/mcp/plan.py
index f1fbc3818ffab293e492caeadaf6a1aac8d471c3..e00648a14ecd289d8c9751a7e5ff1bf2640fca1e 100644
--- a/apps/scheduler/mcp/plan.py
+++ b/apps/scheduler/mcp/plan.py
@@ -71,19 +71,27 @@ class MCPPlanner:
_logger.error(err)
raise RuntimeError(err)
- # 格式化Prompt
+ # 构造标准 OpenAI Function 格式
schema = MCPPlan.model_json_schema()
schema["properties"]["plans"]["maxItems"] = max_steps
+ function_def = {
+ "name": "parse_mcp_plan",
+ "description": (
+ "Parse the reasoning result into a structured MCP plan with multiple steps. "
+ "Each step should include a step ID, content description, tool name, and instruction."
+ ),
+ "parameters": schema,
+ }
+
# 使用Function模型解析结果
json_generator = JsonGenerator(
- self._llm.function,
+ self._llm,
result,
[
- {"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": result},
],
- schema,
+ function_def,
)
plan = await json_generator.generate()
return MCPPlan.model_validate(plan)
diff --git a/apps/scheduler/mcp/prompt.py b/apps/scheduler/mcp/prompt.py
index 2d9a67ac362676a4450106096cb5a07dbe5a2a10..bd6246919b69904d5872f0a73e42a989e7d2ad6c 100644
--- a/apps/scheduler/mcp/prompt.py
+++ b/apps/scheduler/mcp/prompt.py
@@ -513,3 +513,36 @@ MEMORY_TEMPLATE: dict[str, str] = {
""",
),
}
+
+MCP_FUNCTION_SELECT: dict[LanguageType, str] = {
+ LanguageType.CHINESE: dedent(
+ r"""
+ 你是一个专业的MCP (Model Context Protocol) 选择器。
+ 你的任务是分析推理结果并选择最合适的MCP服务器ID。
+
+ --- 推理结果 ---
+ {reasoning_result}
+ --- 推理结果结束 ---
+
+ 可用的MCP服务器ID: {mcp_ids}
+
+ 请仔细分析推理结果,选择最符合需求的MCP服务器。
+ 重点关注推理中描述的能力和需求,将其与正确的MCP服务器进行匹配。
+ """,
+ ),
+ LanguageType.ENGLISH: dedent(
+ r"""
+ You are an expert MCP (Model Context Protocol) selector.
+ Your task is to analyze the reasoning result and select the most appropriate MCP server ID.
+
+ --- Reasoning Result ---
+ {reasoning_result}
+ --- End of Reasoning ---
+
+ Available MCP Server IDs: {mcp_ids}
+
+ Please analyze the reasoning carefully and select the MCP server that best matches the requirements.
+ Focus on matching the capabilities and requirements described in the reasoning with the correct MCP server.
+ """,
+ ),
+}
diff --git a/apps/scheduler/mcp/select.py b/apps/scheduler/mcp/select.py
index 0714393149a7cd0f59379b392ab4a8dab6192394..17dd699f3e8132fd5ac0a1fe94dbca03113dac34 100644
--- a/apps/scheduler/mcp/select.py
+++ b/apps/scheduler/mcp/select.py
@@ -7,20 +7,23 @@ from sqlalchemy import select
from apps.common.postgres import postgres
from apps.llm import JsonGenerator
-from apps.models import MCPTools
+from apps.models import LanguageType, MCPTools
from apps.schemas.mcp import MCPSelectResult
from apps.schemas.scheduler import LLMConfig
from apps.services.mcp_service import MCPServiceManager
+from .prompt import MCP_FUNCTION_SELECT
+
logger = logging.getLogger(__name__)
class MCPSelector:
"""MCP选择器"""
- def __init__(self, llm: LLMConfig) -> None:
+ def __init__(self, llm: LLMConfig, language: LanguageType = LanguageType.CHINESE) -> None:
"""初始化MCP选择器"""
self._llm = llm
+ self._language = language
async def _call_reasoning(self, prompt: str) -> str:
"""调用大模型进行推理"""
@@ -47,16 +50,25 @@ class MCPSelector:
# schema中加入选项
schema["properties"]["mcp_id"]["enum"] = mcp_ids
+ # 组装OpenAI FunctionCall格式的dict
+ function = {
+ "name": "select_mcp",
+ "description": "Select the most appropriate MCP server based on the reasoning result",
+ "parameters": schema,
+ }
+
+ # 构建优化的提示词
+ user_prompt = MCP_FUNCTION_SELECT[self._language].format(
+ reasoning_result=reasoning_result,
+ mcp_ids=", ".join(mcp_ids),
+ )
+
# 使用JsonGenerator生成JSON
- conversation = [
- {"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": reasoning_result},
- ]
generator = JsonGenerator(
- llm=self._llm.function,
- query=reasoning_result,
- conversation=conversation,
- schema=schema,
+ llm_config=self._llm,
+ query=user_prompt,
+ conversation=[],
+ function=function,
)
result = await generator.generate()
diff --git a/apps/scheduler/mcp_agent/base.py b/apps/scheduler/mcp_agent/base.py
index 04326edf341a4536fe7f015473f6ac64f4f141c9..1a4d168db2f959277cc14f0199abb246a9560e70 100644
--- a/apps/scheduler/mcp_agent/base.py
+++ b/apps/scheduler/mcp_agent/base.py
@@ -39,27 +39,18 @@ class MCPBase:
message,
streaming=False,
):
- result += chunk
+ result += chunk.content or ""
return result
- async def _parse_result(self, result: str, schema: dict[str, Any]) -> dict[str, Any]:
- """解析推理结果"""
- if not self._llm.function:
- err = "[MCPBase] 未找到函数调用模型"
- _logger.error(err)
- raise RuntimeError(err)
-
- json_result = await JsonGenerator.parse_result_by_stack(result, schema)
- if json_result is not None:
- return json_result
+ async def get_json_result(self, result: str, function: dict[str, Any]) -> dict[str, Any]:
+ """解析推理结果;function使用OpenAI标准Function格式"""
json_generator = JsonGenerator(
- self._llm.function,
+ self._llm,
"Please provide a JSON response based on the above information and schema.\n\n",
[
- {"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": result},
],
- schema,
+ function,
)
return await json_generator.generate()
diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py
index 5179cc66ce020b57dc8afac978a7fe30557f685f..a6f850c705a28625d2311136957220c92bbfe7c8 100644
--- a/apps/scheduler/mcp_agent/host.py
+++ b/apps/scheduler/mcp_agent/host.py
@@ -59,10 +59,9 @@ class MCPHost(MCPBase):
background_info=await self.assemble_memory(runtime, context),
)
_logger.info("[MCPHost] 填充工具参数: %s", prompt)
- result = await self.get_reasoning_result(prompt)
# 使用JsonGenerator解析结果
-        return await self._parse_result(
-            result,
-            mcp_tool.inputSchema,
-        )
+        # 组装OpenAI Function标准的Function结构
+        function = {
+            "name": mcp_tool.toolName,
+            "description": mcp_tool.description,
+            "parameters": mcp_tool.inputSchema,
+        }
+        return await self.get_json_result(
+            prompt,
+            function,
+        )
@@ -94,13 +93,19 @@ class MCPHost(MCPBase):
params_description=params_description,
)
+ # 组装OpenAI Function标准的Function结构
+ function = {
+ "name": mcp_tool.toolName,
+ "description": mcp_tool.description,
+ "parameters": mcp_tool.inputSchema,
+ }
+
json_generator = JsonGenerator(
- self._llm.function,
+ self._llm,
llm_query,
[
- {"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
],
- mcp_tool.inputSchema,
+ function,
)
return await json_generator.generate()
diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py
index 82b098484bad35a1d230c56e1abeaf023bf4e1d9..e11b7286ff49efe0ca670382379f38232e9e29ce 100644
--- a/apps/scheduler/mcp_agent/plan.py
+++ b/apps/scheduler/mcp_agent/plan.py
@@ -46,16 +46,22 @@ class MCPPlanner(MCPBase):
"""获取当前流程的名称"""
template = _env.from_string(GENERATE_FLOW_NAME[self._language])
prompt = template.render(goal=self._goal)
- result = await self.get_reasoning_result(prompt)
- result = await self._parse_result(result, FlowName.model_json_schema())
+
+ # 组装OpenAI标准Function格式
+ function = {
+ "name": "get_flow_name",
+ "description": "Generate a descriptive name for the current workflow based on the user's goal",
+ "parameters": FlowName.model_json_schema(),
+ }
+
+ result = await self.get_json_result(prompt, function)
return FlowName.model_validate(result)
async def create_next_step(self, history: str, tools: list[MCPTools]) -> Step:
"""创建下一步的执行步骤"""
- # 获取推理结果
+ # 构建提示词
template = _env.from_string(GEN_STEP[self._language])
prompt = template.render(goal=self._goal, history=history, tools=tools)
- result = await self.get_reasoning_result(prompt)
# 解析为结构化数据
schema = Step.model_json_schema()
@@ -63,7 +69,18 @@ class MCPPlanner(MCPBase):
schema["properties"]["tool_id"]["enum"] = []
for tool in tools:
schema["properties"]["tool_id"]["enum"].append(tool.id)
- step = await self._parse_result(result, schema)
+
+ # 组装OpenAI标准Function格式
+ function = {
+ "name": "create_next_step",
+ "description": (
+ "Create the next execution step in the workflow by selecting "
+ "an appropriate tool and defining its parameters"
+ ),
+ "parameters": schema,
+ }
+
+ step = await self.get_json_result(prompt, function)
logger.info("[MCPPlanner] 创建下一步的执行步骤: %s", step)
# 使用Step模型解析结果
return Step.model_validate(step)
@@ -75,8 +92,15 @@ class MCPPlanner(MCPBase):
goal=self._goal,
tools=tools,
)
- result = await self.get_reasoning_result(prompt)
- result = await self._parse_result(result, FlowRisk.model_json_schema())
+
+ # 组装OpenAI标准Function格式
+ function = {
+ "name": "evaluate_flow_execution_risk",
+ "description": "Evaluate the potential risks and safety concerns of executing the entire workflow",
+ "parameters": FlowRisk.model_json_schema(),
+ }
+
+ result = await self.get_json_result(prompt, function)
# 使用FlowRisk模型解析结果
return FlowRisk.model_validate(result)
@@ -87,7 +111,7 @@ class MCPPlanner(MCPBase):
additional_info: str = "",
) -> ToolRisk:
"""获取MCP工具的风险评估结果"""
- # 获取推理结果
+ # 构建提示词
template = _env.from_string(RISK_EVALUATE[self._language])
prompt = template.render(
tool_name=tool.toolName,
@@ -95,10 +119,18 @@ class MCPPlanner(MCPBase):
input_param=input_param,
additional_info=additional_info,
)
- result = await self.get_reasoning_result(prompt)
- schema = ToolRisk.model_json_schema()
- risk = await self._parse_result(result, schema)
+ # 组装OpenAI标准Function格式
+ function = {
+ "name": "evaluate_tool_risk",
+ "description": (
+ "Evaluate the risk level and safety concerns of executing "
+ "a specific tool with given parameters"
+ ),
+ "parameters": ToolRisk.model_json_schema(),
+ }
+
+ risk = await self.get_json_result(prompt, function)
# 返回风险评估结果
return ToolRisk.model_validate(risk)
@@ -122,10 +154,15 @@ class MCPPlanner(MCPBase):
input_params=input_params,
error_message=error_message,
)
- result = await self.get_reasoning_result(prompt)
- # 解析为结构化数据
- schema = IsParamError.model_json_schema()
- is_param_error = await self._parse_result(result, schema)
+
+ # 组装OpenAI标准Function格式
+ function = {
+ "name": "check_parameter_error",
+ "description": "Determine whether an error message indicates a parameter-related error",
+ "parameters": IsParamError.model_json_schema(),
+ }
+
+ is_param_error = await self.get_json_result(prompt, function)
# 使用IsParamError模型解析结果
return IsParamError.model_validate(is_param_error)
@@ -157,9 +194,16 @@ class MCPPlanner(MCPBase):
schema=schema_with_null,
error_message=error_message,
)
- result = await self.get_reasoning_result(prompt)
+
+ # 组装OpenAI标准Function格式
+ function = {
+ "name": "get_missing_parameters",
+ "description": "Extract and provide the missing or incorrect parameters based on error feedback",
+ "parameters": schema_with_null,
+ }
+
# 解析为结构化数据
- return await self._parse_result(result, schema_with_null)
+ return await self.get_json_result(prompt, function)
async def generate_answer(self, memory: str) -> AsyncGenerator[LLMChunk, None]:
"""生成最终回答"""
diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py
index 8295eb44c2a31a2022525325db005d2eae945a94..a08cb4547f0ec795f3bdde6eceb5ffd67158be80 100644
--- a/apps/scheduler/scheduler/scheduler.py
+++ b/apps/scheduler/scheduler/scheduler.py
@@ -11,7 +11,7 @@ from jinja2.sandbox import SandboxedEnvironment
from apps.common.queue import MessageQueue
from apps.common.security import Security
-from apps.llm import Embedding, FunctionLLM, JsonGenerator, ReasoningLLM
+from apps.llm import LLM, Embedding, JsonGenerator
from apps.models import (
AppType,
Conversation,
@@ -275,7 +275,7 @@ class Scheduler:
err = "[Scheduler] 获取问答用大模型ID失败"
_logger.error(err)
raise ValueError(err)
- reasoning_llm = ReasoningLLM(reasoning_llm)
+ reasoning_llm = LLM(reasoning_llm)
# 获取功能性的大模型信息
function_llm = None
@@ -289,7 +289,7 @@ class Scheduler:
self.user.userSub, self.user.functionLLM,
)
else:
- function_llm = FunctionLLM(function_llm)
+ function_llm = LLM(function_llm)
embedding_llm = None
if not self.user.embeddingLLM:
diff --git a/apps/schemas/scheduler.py b/apps/schemas/scheduler.py
index 72ad05e54822d2d670ca9f2ac0c3ad8b52058937..9c7595ea3599dfbdce55c4a5be483a8794d9ef39 100644
--- a/apps/schemas/scheduler.py
+++ b/apps/schemas/scheduler.py
@@ -6,7 +6,7 @@ from typing import Any
from pydantic import BaseModel, Field
-from apps.llm import Embedding, FunctionLLM, ReasoningLLM
+from apps.llm import LLM, Embedding
from apps.models import ExecutorHistory, LanguageType
from .enum_var import CallOutputType
@@ -15,8 +15,8 @@ from .enum_var import CallOutputType
class LLMConfig(BaseModel):
"""LLM配置"""
- reasoning: ReasoningLLM = Field(description="推理LLM")
- function: FunctionLLM | None = Field(description="函数LLM")
+ reasoning: LLM = Field(description="推理LLM")
+ function: LLM | None = Field(description="函数LLM")
embedding: Embedding | None = Field(description="Embedding")
class Config:
diff --git a/apps/services/task.py b/apps/services/task.py
index 9a543e01bba9d6d6ce086c88897ed198babc94ec..6f9560511dc31596e3ef3ad689ec7941607b8061 100644
--- a/apps/services/task.py
+++ b/apps/services/task.py
@@ -154,14 +154,86 @@ class TaskManager:
await session.commit()
- @classmethod
- async def save_task(cls, task_id: uuid.UUID, task: Task) -> None:
- """保存任务块"""
- task_collection = MongoDB().get_collection("task")
-
- # 更新已有的Task记录
- await task_collection.update_one(
- {"_id": task_id},
- {"$set": task.model_dump(by_alias=True, exclude_none=True)},
- upsert=True,
- )
+ @staticmethod
+ async def save_task(task_id: uuid.UUID, task: Task) -> None:
+ """保存Task数据到PostgreSQL"""
+ async with postgres.session() as session:
+ # 查询是否存在该Task
+ existing_task = (await session.scalars(
+ select(Task).where(Task.id == task_id),
+ )).one_or_none()
+
+ if existing_task:
+ # 更新现有Task
+ for key, value in task.__dict__.items():
+ if not key.startswith("_"):
+ setattr(existing_task, key, value)
+ else:
+ # 插入新Task
+ session.add(task)
+
+ await session.commit()
+
+
+ @staticmethod
+ async def save_task_runtime(task_runtime: TaskRuntime) -> None:
+ """保存TaskRuntime数据到PostgreSQL"""
+ async with postgres.session() as session:
+ # 查询是否存在该TaskRuntime
+ existing_runtime = (await session.scalars(
+ select(TaskRuntime).where(TaskRuntime.taskId == task_runtime.taskId),
+ )).one_or_none()
+
+ if existing_runtime:
+ # 更新现有TaskRuntime
+ for key, value in task_runtime.__dict__.items():
+ if not key.startswith("_"):
+ setattr(existing_runtime, key, value)
+ else:
+ # 插入新TaskRuntime
+ session.add(task_runtime)
+
+ await session.commit()
+
+
+ @staticmethod
+ async def save_executor_checkpoint(checkpoint: ExecutorCheckpoint) -> None:
+ """保存ExecutorCheckpoint数据到PostgreSQL"""
+ async with postgres.session() as session:
+ # 查询是否存在该Checkpoint
+ existing_checkpoint = (await session.scalars(
+ select(ExecutorCheckpoint).where(ExecutorCheckpoint.taskId == checkpoint.taskId),
+ )).one_or_none()
+
+ if existing_checkpoint:
+ # 更新现有Checkpoint
+ for key, value in checkpoint.__dict__.items():
+ if not key.startswith("_"):
+ setattr(existing_checkpoint, key, value)
+ else:
+ # 插入新Checkpoint
+ session.add(checkpoint)
+
+ await session.commit()
+
+
+ @staticmethod
+ async def save_flow_context(context: list[ExecutorHistory]) -> None:
+ """保存Flow上下文信息到PostgreSQL"""
+ async with postgres.session() as session:
+ for ctx in context:
+ # 查询是否存在该ExecutorHistory
+ existing_history = (await session.scalars(
+ select(ExecutorHistory).where(ExecutorHistory.id == ctx.id),
+ )).one_or_none()
+
+ if existing_history:
+ # 更新现有History
+ for key, value in ctx.__dict__.items():
+ if not key.startswith("_"):
+ setattr(existing_history, key, value)
+ else:
+ # 插入新History
+ session.add(ctx)
+
+ await session.commit()
diff --git a/design/call/slot.md b/design/call/slot.md
index 0a48e6bdabbf6b28788dd0820c05b12e698f1fc6..251f14f0a474c6a13deea554715c9afc5258399f 100644
--- a/design/call/slot.md
+++ b/design/call/slot.md
@@ -2,7 +2,7 @@
## 概述
-Slot工具是一个智能参数自动填充工具,它通过分析历史步骤数据、背景信息和用户问题,自动生成符合JSON Schema要求的参数对象。该工具使用FunctionCall进行精确的剩余参数填充。
+Slot工具是一个智能参数自动填充工具,它通过分析历史步骤数据、背景信息和用户问题,自动生成符合JSON Schema要求的参数对象。该工具使用JsonGenerator进行精确的剩余参数填充。
## 功能特性
@@ -34,7 +34,7 @@ Slot工具的核心实现类,继承自`CoreCall`基类,负责参数自动填
- `_init()`: 初始化工具输入,处理历史数据和Schema
- `_exec()`: 执行参数填充逻辑
- `_llm_slot_fill()`: 使用大语言模型填充参数
-- `_function_slot_fill()`: 使用FunctionCall填充剩余参数
+- `_function_slot_fill()`: 使用JsonGenerator填充剩余参数
## 数据结构
@@ -100,7 +100,7 @@ graph LR
B --> C[SlotProcessor]
C --> D[SlotInput]
D --> E[LLM填充]
- E --> F[FunctionCall填充]
+ E --> F[JsonGenerator填充]
F --> G[SlotOutput]
subgraph "输入数据"
@@ -115,7 +115,7 @@ graph LR
B1[历史数据处理]
B2[Schema验证]
B3[LLM推理填充]
- B4[FunctionCall精确填充]
+ B4[JsonGenerator精确填充]
B5[数据转换和验证]
end
@@ -207,7 +207,7 @@ graph TD
I --> J[检查剩余Schema]
J --> K{仍有剩余Schema?}
K -->|否| L[输出最终结果]
- K -->|是| M[第二阶段:FunctionCall填充]
+ K -->|是| M[第二阶段:JsonGenerator填充]
M --> N[数据转换]
N --> O[最终Schema检查]
O --> L
@@ -220,9 +220,9 @@ graph TD
G4[解析JSON响应]
end
- subgraph "FunctionCall填充阶段"
+ subgraph "JsonGenerator填充阶段"
M1[构建对话上下文]
- M2[调用JSON函数]
+ M2[调用JSON生成器]
M3[获取结构化数据]
end
@@ -247,7 +247,7 @@ sequenceDiagram
participant SP as SlotProcessor
participant T as Jinja2Template
participant L as LLM
- participant F as FunctionLLM
+ participant J as JsonGenerator
participant V as JSONValidator
E->>S: instance(executor, node)
@@ -268,8 +268,8 @@ sequenceDiagram
S->>L: 调用LLM生成参数
L-->>S: 流式返回JSON响应
- S->>F: process_response(answer)
- F-->>S: 处理后的JSON字符串
+ S->>J: process_response(answer)
+ J-->>S: 处理后的JSON字符串
S->>V: json.loads(answer)
V-->>S: 解析的JSON数据
@@ -279,8 +279,8 @@ sequenceDiagram
SP-->>S: 剩余Schema
alt 仍有剩余Schema
- S->>F: _json(messages, schema)
- F-->>S: 结构化JSON数据
+ S->>J: _json(messages, schema)
+ J-->>S: 结构化JSON数据
S->>SP: convert_json(slot_data)
SP-->>S: 转换后的数据
S->>SP: check_json(slot_data)
@@ -298,15 +298,15 @@ sequenceDiagram
```mermaid
graph LR
A[原始Schema] --> B[LLM填充阶段]
- B --> C[FunctionCall填充阶段]
+ B --> C[JsonGenerator填充阶段]
C --> D[最终结果]
-
+
B --> B1[生成自然语言提示词]
B --> B2[LLM推理生成JSON]
B --> B3[解析和验证JSON]
-
+
C --> C1[构建结构化对话]
- C --> C2[调用JSON生成函数]
+ C --> C2[调用JSON生成器]
C --> C3[获取精确JSON数据]
D --> D1[slot_data: 填充数据]