diff --git a/apps/dependency/user.py b/apps/dependency/user.py index da6d2bfc8bf68d08985c6449c6db8ea67a41c9ab..545b5a9593932fa251e2e78af1b2ec84946a2d82 100644 --- a/apps/dependency/user.py +++ b/apps/dependency/user.py @@ -56,7 +56,6 @@ async def verify_session(request: HTTPConnection) -> None: detail="Session ID 鉴权失败", ) request.state.user_sub = user - logger.info("Session鉴权成功,user_sub=%s", user) async def verify_personal_token(request: HTTPConnection) -> None: @@ -91,7 +90,6 @@ async def verify_personal_token(request: HTTPConnection) -> None: user_sub = await PersonalTokenManager.get_user_by_personal_token(token) if user_sub is not None: request.state.user_sub = user_sub - logger.info("Personal Token鉴权成功,user_sub=%s", user_sub) else: # Personal Token无效,抛出401 logger.warning("Personal Token鉴权失败:无效的token") diff --git a/apps/llm/__init__.py b/apps/llm/__init__.py index 08bfd4543031282d7f27be8580cd2d456da8c055..4938632956ed55b0c295cc487712f49a674af25d 100644 --- a/apps/llm/__init__.py +++ b/apps/llm/__init__.py @@ -2,14 +2,13 @@ """模型调用模块""" from .embedding import Embedding -from .function import FunctionLLM, JsonGenerator -from .reasoning import ReasoningLLM +from .generator import JsonGenerator +from .llm import LLM from .token import token_calculator __all__ = [ + "LLM", "Embedding", - "FunctionLLM", "JsonGenerator", - "ReasoningLLM", "token_calculator", ] diff --git a/apps/llm/function.py b/apps/llm/function.py deleted file mode 100644 index 63e9ec1db59b872235e7d68b3ddf80a5f01da59e..0000000000000000000000000000000000000000 --- a/apps/llm/function.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. -"""用于FunctionCall的大模型""" - -import json -import logging -import re -from textwrap import dedent -from typing import Any - -from jinja2 import BaseLoader -from jinja2.sandbox import SandboxedEnvironment -from jsonschema import Draft7Validator - -from apps.constants import JSON_GEN_MAX_TRIAL -from apps.models import LLMData, LLMProvider - -from .prompt import JSON_GEN_BASIC -from .providers import ( - BaseProvider, - OllamaProvider, - OpenAIProvider, -) - -_logger = logging.getLogger(__name__) -_CLASS_DICT: dict[LLMProvider, type[BaseProvider]] = { - LLMProvider.OLLAMA: OllamaProvider, - LLMProvider.OPENAI: OpenAIProvider, -} - - -class FunctionLLM: - """用于FunctionCall的模型""" - - def __init__(self, llm_config: LLMData | None = None) -> None: - """初始化大模型客户端""" - if not llm_config: - err = "[FunctionLLM] 未设置LLM配置" - _logger.error(err) - raise RuntimeError(err) - - if llm_config.provider not in _CLASS_DICT: - err = "[FunctionLLM] 未支持的LLM类型: %s", llm_config.provider - _logger.error(err) - raise RuntimeError(err) - - self._provider = _CLASS_DICT[llm_config.provider](llm_config) - - @staticmethod - async def process_response(response: str) -> str: - """处理大模型的输出""" - # 尝试解析JSON - response = dedent(response).strip() - error_flag = False - try: - json.loads(response) - except Exception: # noqa: BLE001 - error_flag = True - - if not error_flag: - return response - - # 尝试提取```json中的JSON - _logger.warning("[FunctionCall] 直接解析失败!尝试提取```json中的JSON") - try: - json_str = re.findall(r"```json(.*)```", response, re.DOTALL)[-1] - json_str = dedent(json_str).strip() - json.loads(json_str) - except Exception: # noqa: BLE001 - # 尝试直接通过括号匹配JSON - _logger.warning("[FunctionCall] 提取失败!尝试正则提取JSON") - try: - json_str = re.findall(r"\{.*\}", response, re.DOTALL)[-1] - json_str = dedent(json_str).strip() - json.loads(json_str) - except Exception: # noqa: BLE001 - json_str = "{}" - - return json_str - - - async def call( - self, - messages: list[dict[str, Any]], - schema: dict[str, Any], - max_tokens: int | None = None, - temperature: float | None = None, - ) -> dict[str, Any]: - """ - 调用FunctionCall小模型 - - 不开放流式输出 - """ - try: - return json.loads(json_str) - except Exception: # noqa: BLE001 - _logger.error("[FunctionCall] 大模型JSON解析失败:%s", json_str) # noqa: TRY400 - return {} - - -class JsonGenerator: - """JSON生成器""" - - def xml_parser(self, xml: str) -> dict[str, Any]: - """XML解析器""" - return {} - - def __init__( - self, llm: FunctionLLM, query: str, conversation: list[dict[str, str]], schema: dict[str, Any], - ) -> None: - """初始化JSON生成器""" - self._query = query - self._conversation = conversation - self._schema = schema - self._llm = llm - - self._trial = {} - self._count = 0 - self._env = SandboxedEnvironment( - loader=BaseLoader(), - autoescape=False, - trim_blocks=True, - lstrip_blocks=True, - extensions=["jinja2.ext.loopcontrols"], - ) - self._err_info = "" - - async def _single_trial(self, max_tokens: int | None = None, temperature: float | None = None) -> dict[str, Any]: - """单次尝试""" - # 检查类型 - - # 渲染模板 - template = self._env.from_string(JSON_GEN_BASIC) - return template.render( - query=self._query, - conversation=self._conversation, - previous_trial=self._trial, - schema=self._schema, - function_call=function_call, - err_info=self._err_info, - ) - - messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt}, - ] - return await self._llm.call(messages, self._schema, max_tokens, temperature) - - - async def generate(self) -> dict[str, Any]: - """生成JSON""" - Draft7Validator.check_schema(self._schema) - validator = Draft7Validator(self._schema) - - while self._count < JSON_GEN_MAX_TRIAL: - self._count += 1 - result = await self._single_trial() - try: - validator.validate(result) - except Exception as err: # noqa: BLE001 - err_info = str(err) - err_info = err_info.split("\n\n")[0] - self._err_info = err_info - _logger.info("[JSONGenerator] 验证失败:%s", self._err_info) - continue - return result - - return {} diff --git a/apps/llm/generator.py b/apps/llm/generator.py new file mode 100644 index 0000000000000000000000000000000000000000..045076fcfcc3246053e62a253ffb1c007bdaf615 --- /dev/null +++ b/apps/llm/generator.py @@ -0,0 +1,165 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""综合Json生成器""" + +import json +import logging +import re +from typing import Any + +from jinja2 import BaseLoader +from jinja2.sandbox import SandboxedEnvironment +from jsonschema import Draft7Validator + +from apps.models import LLMType +from apps.schemas.llm import LLMFunctions +from apps.schemas.scheduler import LLMConfig + +from .llm import LLM +from .prompt import JSON_GEN_BASIC, JSON_NO_FUNCTION_CALL + +_logger = logging.getLogger(__name__) + +JSON_GEN_MAX_TRIAL = 3 # 最大尝试次数 + +class JsonGenerator: + """综合Json生成器""" + + def __init__( + self, llm_config: LLMConfig, query: str, conversation: list[dict[str, str]], function: dict[str, Any], + ) -> None: + """初始化JSON生成器;function使用OpenAI标准Function格式""" + self._query = query + self._function = function + + # 选择LLM:优先使用Function模型(如果存在且支持FunctionCall),否则回退到Reasoning模型 + self._llm, self._support_function_call = self._select_llm(llm_config) + + self._context = [ + { + "role": "system", + "content": "You are a helpful assistant that can use tools to help answer user queries.", + }, + ] + if conversation: + self._context.extend(conversation) + + self._count = 0 + self._env = SandboxedEnvironment( + loader=BaseLoader(), + autoescape=False, + trim_blocks=True, + lstrip_blocks=True, + extensions=["jinja2.ext.loopcontrols"], + ) + + def _select_llm(self, llm_config: LLMConfig) -> tuple[LLM, bool]: + """选择LLM:优先使用Function模型(如果存在且支持FunctionCall),否则回退到Reasoning模型""" + if llm_config.function is not None and LLMType.FUNCTION in llm_config.function.config.llmType: + _logger.info("[JSONGenerator] 使用Function模型,支持FunctionCall") + return llm_config.function, True + + _logger.info("[JSONGenerator] Function模型不可用或不支持FunctionCall,回退到Reasoning模型") + return llm_config.reasoning, False + + async def _single_trial(self, max_tokens: int | None = None, temperature: float | None = None) -> dict[str, Any]: + """单次尝试,包含校验逻辑""" + # 获取schema并创建验证器 + schema = self._function["parameters"] + validator = Draft7Validator(schema) + + # 执行生成 + if self._support_function_call: + # 如果支持FunctionCall,使用provider的function调用逻辑 + result = await self._call_with_function() + else: + # 如果不支持FunctionCall,使用JSON_GEN_BASIC和provider的call进行调用 + result = await self._call_without_function(max_tokens, temperature) + + # 校验结果 + try: + validator.validate(result) + except Exception as err: + # 捕获校验异常信息 + err_info = str(err) + err_info = err_info.split("\n\n")[0] + _logger.info("[JSONGenerator] 验证失败:%s", err_info) + + # 将错误信息添加到上下文中 + self._context.append({ + "role": "assistant", + "content": f"Attempted to use tool but validation failed: {err_info}", + }) + raise + else: + return result + + async def _call_with_function(self) -> dict[str, Any]: + """使用FunctionCall方式调用""" + # 直接使用传入的function构建工具定义 + tool = LLMFunctions( + name=self._function["name"], + description=self._function["description"], + param_schema=self._function["parameters"], + ) + + messages = self._context.copy() + messages.append({"role": "user", "content": self._query}) + + # 调用LLM的call方法,传入tools + tool_call_result = {} + async for chunk in self._llm.call(messages, include_thinking=False, streaming=True, tools=[tool]): + if chunk.tool_call: + tool_call_result.update(chunk.tool_call) + + # 从tool_call结果中提取JSON,使用function中的函数名 + function_name = self._function["name"] + if tool_call_result and function_name in tool_call_result: + return json.loads(tool_call_result[function_name]) + + return {} + + async def _call_without_function(self, max_tokens: int | None = None, temperature: float | None = None) -> dict[str, Any]: # noqa: E501 + """不使用FunctionCall方式调用""" + # 渲染模板 + template = self._env.from_string(JSON_GEN_BASIC + "\n\n" + JSON_NO_FUNCTION_CALL) + prompt = template.render( + query=self._query, + conversation=self._context[1:] if self._context else [], + schema=self._function["parameters"], + ) + + messages = [ + self._context[0], + {"role": "user", "content": prompt}, + ] + + # 使用LLM的call方法获取响应 + full_response = "" + async for chunk in self._llm.call(messages, include_thinking=False, streaming=True): + if chunk.content: + full_response += chunk.content + + # 从响应中提取JSON + # 查找第一个 { 和最后一个 } + json_match = re.search(r"\{.*\}", full_response, re.DOTALL) + if json_match: + return json.loads(json_match.group(0)) + + return {} + + + async def generate(self) -> dict[str, Any]: + """生成JSON""" + # 检查schema格式是否正确 + schema = self._function["parameters"] + Draft7Validator.check_schema(schema) + + while self._count < JSON_GEN_MAX_TRIAL: + self._count += 1 + try: + return await self._single_trial() + except Exception: # noqa: BLE001 + # 校验失败,_single_trial已经将错误信息记录到self._err_info中并记录日志 + continue + + return {} diff --git a/apps/llm/reasoning.py b/apps/llm/llm.py similarity index 81% rename from apps/llm/reasoning.py rename to apps/llm/llm.py index efa8d23b0df1d00d5d18cfa8429a0e70903479d1..fc005cd7fe73cdc5c81e3e757fac794a542f7da9 100644 --- a/apps/llm/reasoning.py +++ b/apps/llm/llm.py @@ -5,7 +5,7 @@ import logging from collections.abc import AsyncGenerator from apps.models import LLMData, LLMProvider -from apps.schemas.llm import LLMChunk +from apps.schemas.llm import LLMChunk, LLMFunctions from .providers import ( BaseProvider, @@ -20,7 +20,7 @@ _CLASS_DICT: dict[LLMProvider, type[BaseProvider]] = { } -class ReasoningLLM: +class LLM: """调用用于问答的大模型""" def __init__(self, llm_config: LLMData | None) -> None: @@ -43,14 +43,23 @@ class ReasoningLLM: *, include_thinking: bool = True, streaming: bool = True, + tools: list[LLMFunctions] | None = None, ) -> AsyncGenerator[LLMChunk, None]: """调用大模型,分为流式和非流式两种""" if streaming: - async for chunk in await self._provider.chat(messages, include_thinking=include_thinking): + async for chunk in await self._provider.chat( + messages, + include_thinking=include_thinking, + tools=tools, + ): yield chunk else: # 非流式模式下,需要收集所有内容 - async for _ in await self._provider.chat(messages, include_thinking=include_thinking): + async for _ in await self._provider.chat( + messages, + include_thinking=include_thinking, + tools=tools, + ): continue # 直接yield最终结果 yield LLMChunk( diff --git a/apps/llm/prompt.py b/apps/llm/prompt.py index ebe44c0ee2bd70a9d393727a4b45554e7aeb4737..9851b282292f16a4fdbc556ff9428e61e1a4bf26 100644 --- a/apps/llm/prompt.py +++ b/apps/llm/prompt.py @@ -4,51 +4,88 @@ from textwrap import dedent JSON_GEN_BASIC = dedent(r""" - Respond to the query according to the background information provided. - - # User query - - User query is given in XML tags. + + + You are an intelligent assistant who can use tools to help answer user queries. + Your task is to respond to the query according to the background information and available tools. + + Note: + - You have access to a set of tools that can help you gather information. + - You can use one tool at a time and will receive the result in the user's response. + - Use tools step-by-step to respond to the user's query, with each tool use informed by the \ +result of the previous tool use. + - The user's query is provided in the tags. + {% if previous_trial %}- Review the previous trial information in \ +tags to avoid repeating mistakes.{% endif %} + + - {{ query }} + {{ query }} - {% if previous_trial %} - # Previous Trial - You tried to answer the query with one function, but the arguments are incorrect. - - The arguments you provided are: - - ```json - {{ previous_trial }} - ``` - And the error information is: - - ``` - {{ err_info }} - ``` + + + You previously attempted to answer the query by calling a tool, but the arguments were incorrect. + + + {{ previous_trial }} + + + {{ err_info }} + + {% endif %} - # Tools - You have access to a set of tools. You can use one tool and will receive the result of \ -that tool use in the user's response. You use tools step-by-step to respond to the user's \ -query, with each tool use informed by the result of the previous tool use. + + You have access to a set of tools. You can use one tool and will receive the result of that tool \ +use in the user's response. + """) - JSON_NO_FUNCTION_CALL = dedent(r""" **Tool Use Formatting:** - Tool uses are formatted using XML-style tags. The tool name itself becomes the XML tag name. Each \ -parameter is enclosed within its own set of tags. Here's the structure: - - - value1 - value2 - ... - - - Always use the actual tool name as the XML tag name for proper parsing and execution. + Tool uses are formatted using XML-style tags. The tool name itself becomes the root XML tag name. \ +Each parameter is enclosed within its own set of tags according to the parameter schema provided below. + + **Basic Structure:** + + value + + + **Parameter Schema:** + The available tools and their parameter schemas are provided in the following format: + - Tool name: The name to use as the root XML tag + - Parameters: Each parameter has a name, type, and description + - Required parameters must be included + - Optional parameters can be omitted + + **XML Generation Rules:** + 1. Use the exact tool name as the root XML tag + 2. For each parameter, create a nested tag with the parameter name + 3. Place the parameter value inside the corresponding tag + 4. For string values: text value + 5. For numeric values: 123 + 6. For boolean values: true or false + 7. For array values: wrap each item in the parameter tag + item1 + item2 + 8. For object values: nest the object properties as sub-tags + + value1 + value2 + + + **Example:** + If you need to use a tool named "search" with parameters query (string) and limit (number): + + + your search text + 10 + + + Always use the actual tool name as the root XML tag and match parameter names exactly as specified \ +in the schema for proper parsing and execution. """) diff --git a/apps/llm/providers/openai.py b/apps/llm/providers/openai.py index 49e9890982e0056a7e2a588d583bb58788bdce02..0c2d3d4584b5cdfefde6a1c2f59fb3ce08d1902c 100644 --- a/apps/llm/providers/openai.py +++ b/apps/llm/providers/openai.py @@ -3,15 +3,12 @@ import logging from collections.abc import AsyncGenerator +from typing import cast from openai import AsyncOpenAI, AsyncStream from openai.types.chat import ( - ChatCompletionAssistantMessageParam, ChatCompletionChunk, ChatCompletionMessageParam, - ChatCompletionSystemMessageParam, - ChatCompletionToolMessageParam, - ChatCompletionUserMessageParam, ) from typing_extensions import override @@ -187,35 +184,23 @@ class OpenAIProvider(BaseProvider): def _convert_messages(self, messages: list[dict[str, str]]) -> list[ChatCompletionMessageParam]: - """将list[dict]形式的消息转换为list[ChatCompletionMessageParam]形式""" - converted_messages = [] + """确保消息格式符合OpenAI API要求,特别是tool消息需要tool_call_id字段""" + result: list[dict[str, str]] = [] for msg in messages: role = msg.get("role", "user") - content = msg.get("content", "") - - if role == "system": - converted_messages.append( - ChatCompletionSystemMessageParam(role="system", content=content), - ) - elif role == "user": - converted_messages.append( - ChatCompletionUserMessageParam(role="user", content=content), - ) - elif role == "assistant": - converted_messages.append( - ChatCompletionAssistantMessageParam(role="assistant", content=content), - ) - elif role == "tool": - tool_call_id = msg.get("tool_call_id", "") - converted_messages.append( - ChatCompletionToolMessageParam( - role="tool", - content=content, - tool_call_id=tool_call_id, - ), - ) - else: + + # 验证角色的有效性 + if role not in ("system", "user", "assistant", "tool"): err = f"[OpenAIProvider] 未知角色: {role}" _logger.error(err) raise ValueError(err) - return converted_messages + + # 对于tool角色,确保有tool_call_id字段 + if role == "tool" and "tool_call_id" not in msg: + result.append({**msg, "tool_call_id": ""}) + else: + result.append(msg) + + # 使用cast告诉类型检查器这是ChatCompletionMessageParam类型 + # OpenAI库在运行时接受dict格式,无需显式构造TypedDict对象 + return cast("list[ChatCompletionMessageParam]", result) diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py index ee0d28995dc169eeb6244f41ec8fb08d4559b9ee..53c4894fdbfcc6c79c97a5dc7a6d9e3a15b9721b 100644 --- a/apps/scheduler/call/slot/slot.py +++ b/apps/scheduler/call/slot/slot.py @@ -9,7 +9,7 @@ from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment from pydantic import Field -from apps.llm import FunctionLLM +from apps.llm import JsonGenerator from apps.models import LanguageType, NodeInfo from apps.scheduler.call.core import CoreCall from apps.scheduler.slot.slot import Slot as SlotProcessor @@ -75,7 +75,7 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput): answer = "" async for chunk in self._llm(messages=conversation, streaming=True): answer += chunk - answer = await FunctionLLM.process_response(answer) + answer = await JsonGenerator.process_response(answer) try: data = json.loads(answer) except Exception: # noqa: BLE001 diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py index 18603aaa38e38fc23bd50bf1529f3282b86f5e72..bb3aa7d5c1d78d0860fc6197a9eeaaaae3ae7420 100644 --- a/apps/scheduler/mcp/host.py +++ b/apps/scheduler/mcp/host.py @@ -114,16 +114,20 @@ class MCPHost: {query} """ + function_definition = { + "name": tool.toolName, + "description": tool.description, + "parameters": tool.inputSchema, + } # 进行生成 json_generator = JsonGenerator( - self._llm.function, + self._llm, llm_query, [ - {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": await self.assemble_memory()}, ], - tool.inputSchema, + function_definition, ) return await json_generator.generate() diff --git a/apps/scheduler/mcp/plan.py b/apps/scheduler/mcp/plan.py index f1fbc3818ffab293e492caeadaf6a1aac8d471c3..e00648a14ecd289d8c9751a7e5ff1bf2640fca1e 100644 --- a/apps/scheduler/mcp/plan.py +++ b/apps/scheduler/mcp/plan.py @@ -71,19 +71,27 @@ class MCPPlanner: _logger.error(err) raise RuntimeError(err) - # 格式化Prompt + # 构造标准 OpenAI Function 格式 schema = MCPPlan.model_json_schema() schema["properties"]["plans"]["maxItems"] = max_steps + function_def = { + "name": "parse_mcp_plan", + "description": ( + "Parse the reasoning result into a structured MCP plan with multiple steps. " + "Each step should include a step ID, content description, tool name, and instruction." + ), + "parameters": schema, + } + # 使用Function模型解析结果 json_generator = JsonGenerator( - self._llm.function, + self._llm, result, [ - {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": result}, ], - schema, + function_def, ) plan = await json_generator.generate() return MCPPlan.model_validate(plan) diff --git a/apps/scheduler/mcp/prompt.py b/apps/scheduler/mcp/prompt.py index 2d9a67ac362676a4450106096cb5a07dbe5a2a10..bd6246919b69904d5872f0a73e42a989e7d2ad6c 100644 --- a/apps/scheduler/mcp/prompt.py +++ b/apps/scheduler/mcp/prompt.py @@ -513,3 +513,36 @@ MEMORY_TEMPLATE: dict[str, str] = { """, ), } + +MCP_FUNCTION_SELECT: dict[LanguageType, str] = { + LanguageType.CHINESE: dedent( + r""" + 你是一个专业的MCP (Model Context Protocol) 选择器。 + 你的任务是分析推理结果并选择最合适的MCP服务器ID。 + + --- 推理结果 --- + {reasoning_result} + --- 推理结果结束 --- + + 可用的MCP服务器ID: {mcp_ids} + + 请仔细分析推理结果,选择最符合需求的MCP服务器。 + 重点关注推理中描述的能力和需求,将其与正确的MCP服务器进行匹配。 + """, + ), + LanguageType.ENGLISH: dedent( + r""" + You are an expert MCP (Model Context Protocol) selector. + Your task is to analyze the reasoning result and select the most appropriate MCP server ID. + + --- Reasoning Result --- + {reasoning_result} + --- End of Reasoning --- + + Available MCP Server IDs: {mcp_ids} + + Please analyze the reasoning carefully and select the MCP server that best matches the requirements. + Focus on matching the capabilities and requirements described in the reasoning with the correct MCP server. + """, + ), +} diff --git a/apps/scheduler/mcp/select.py b/apps/scheduler/mcp/select.py index 0714393149a7cd0f59379b392ab4a8dab6192394..17dd699f3e8132fd5ac0a1fe94dbca03113dac34 100644 --- a/apps/scheduler/mcp/select.py +++ b/apps/scheduler/mcp/select.py @@ -7,20 +7,23 @@ from sqlalchemy import select from apps.common.postgres import postgres from apps.llm import JsonGenerator -from apps.models import MCPTools +from apps.models import LanguageType, MCPTools from apps.schemas.mcp import MCPSelectResult from apps.schemas.scheduler import LLMConfig from apps.services.mcp_service import MCPServiceManager +from .prompt import MCP_FUNCTION_SELECT + logger = logging.getLogger(__name__) class MCPSelector: """MCP选择器""" - def __init__(self, llm: LLMConfig) -> None: + def __init__(self, llm: LLMConfig, language: LanguageType = LanguageType.CHINESE) -> None: """初始化MCP选择器""" self._llm = llm + self._language = language async def _call_reasoning(self, prompt: str) -> str: """调用大模型进行推理""" @@ -47,16 +50,25 @@ class MCPSelector: # schema中加入选项 schema["properties"]["mcp_id"]["enum"] = mcp_ids + # 组装OpenAI FunctionCall格式的dict + function = { + "name": "select_mcp", + "description": "Select the most appropriate MCP server based on the reasoning result", + "parameters": schema, + } + + # 构建优化的提示词 + user_prompt = MCP_FUNCTION_SELECT[self._language].format( + reasoning_result=reasoning_result, + mcp_ids=", ".join(mcp_ids), + ) + # 使用JsonGenerator生成JSON - conversation = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": reasoning_result}, - ] generator = JsonGenerator( - llm=self._llm.function, - query=reasoning_result, - conversation=conversation, - schema=schema, + llm_config=self._llm, + query=user_prompt, + conversation=[], + function=function, ) result = await generator.generate() diff --git a/apps/scheduler/mcp_agent/base.py b/apps/scheduler/mcp_agent/base.py index 04326edf341a4536fe7f015473f6ac64f4f141c9..1a4d168db2f959277cc14f0199abb246a9560e70 100644 --- a/apps/scheduler/mcp_agent/base.py +++ b/apps/scheduler/mcp_agent/base.py @@ -39,27 +39,18 @@ class MCPBase: message, streaming=False, ): - result += chunk + result += chunk.content or "" return result - async def _parse_result(self, result: str, schema: dict[str, Any]) -> dict[str, Any]: - """解析推理结果""" - if not self._llm.function: - err = "[MCPBase] 未找到函数调用模型" - _logger.error(err) - raise RuntimeError(err) - - json_result = await JsonGenerator.parse_result_by_stack(result, schema) - if json_result is not None: - return json_result + async def get_json_result(self, result: str, function: dict[str, Any]) -> dict[str, Any]: + """解析推理结果;function使用OpenAI标准Function格式""" json_generator = JsonGenerator( - self._llm.function, + self._llm, "Please provide a JSON response based on the above information and schema.\n\n", [ - {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": result}, ], - schema, + function, ) return await json_generator.generate() diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py index 5179cc66ce020b57dc8afac978a7fe30557f685f..a6f850c705a28625d2311136957220c92bbfe7c8 100644 --- a/apps/scheduler/mcp_agent/host.py +++ b/apps/scheduler/mcp_agent/host.py @@ -59,10 +59,9 @@ class MCPHost(MCPBase): background_info=await self.assemble_memory(runtime, context), ) _logger.info("[MCPHost] 填充工具参数: %s", prompt) - result = await self.get_reasoning_result(prompt) # 使用JsonGenerator解析结果 - return await self._parse_result( - result, + return await self.get_json_result( + prompt, mcp_tool.inputSchema, ) @@ -94,13 +93,19 @@ class MCPHost(MCPBase): params_description=params_description, ) + # 组装OpenAI Function标准的Function结构 + function = { + "name": mcp_tool.toolName, + "description": mcp_tool.description, + "parameters": mcp_tool.inputSchema, + } + json_generator = JsonGenerator( - self._llm.function, + self._llm, llm_query, [ - {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}, ], - mcp_tool.inputSchema, + function, ) return await json_generator.generate() diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py index 82b098484bad35a1d230c56e1abeaf023bf4e1d9..e11b7286ff49efe0ca670382379f38232e9e29ce 100644 --- a/apps/scheduler/mcp_agent/plan.py +++ b/apps/scheduler/mcp_agent/plan.py @@ -46,16 +46,22 @@ class MCPPlanner(MCPBase): """获取当前流程的名称""" template = _env.from_string(GENERATE_FLOW_NAME[self._language]) prompt = template.render(goal=self._goal) - result = await self.get_reasoning_result(prompt) - result = await self._parse_result(result, FlowName.model_json_schema()) + + # 组装OpenAI标准Function格式 + function = { + "name": "get_flow_name", + "description": "Generate a descriptive name for the current workflow based on the user's goal", + "parameters": FlowName.model_json_schema(), + } + + result = await self.get_json_result(prompt, function) return FlowName.model_validate(result) async def create_next_step(self, history: str, tools: list[MCPTools]) -> Step: """创建下一步的执行步骤""" - # 获取推理结果 + # 构建提示词 template = _env.from_string(GEN_STEP[self._language]) prompt = template.render(goal=self._goal, history=history, tools=tools) - result = await self.get_reasoning_result(prompt) # 解析为结构化数据 schema = Step.model_json_schema() @@ -63,7 +69,18 @@ class MCPPlanner(MCPBase): schema["properties"]["tool_id"]["enum"] = [] for tool in tools: schema["properties"]["tool_id"]["enum"].append(tool.id) - step = await self._parse_result(result, schema) + + # 组装OpenAI标准Function格式 + function = { + "name": "create_next_step", + "description": ( + "Create the next execution step in the workflow by selecting " + "an appropriate tool and defining its parameters" + ), + "parameters": schema, + } + + step = await self.get_json_result(prompt, function) logger.info("[MCPPlanner] 创建下一步的执行步骤: %s", step) # 使用Step模型解析结果 return Step.model_validate(step) @@ -75,8 +92,15 @@ class MCPPlanner(MCPBase): goal=self._goal, tools=tools, ) - result = await self.get_reasoning_result(prompt) - result = await self._parse_result(result, FlowRisk.model_json_schema()) + + # 组装OpenAI标准Function格式 + function = { + "name": "evaluate_flow_execution_risk", + "description": "Evaluate the potential risks and safety concerns of executing the entire workflow", + "parameters": FlowRisk.model_json_schema(), + } + + result = await self.get_json_result(prompt, function) # 使用FlowRisk模型解析结果 return FlowRisk.model_validate(result) @@ -87,7 +111,7 @@ class MCPPlanner(MCPBase): additional_info: str = "", ) -> ToolRisk: """获取MCP工具的风险评估结果""" - # 获取推理结果 + # 构建提示词 template = _env.from_string(RISK_EVALUATE[self._language]) prompt = template.render( tool_name=tool.toolName, @@ -95,10 +119,18 @@ class MCPPlanner(MCPBase): input_param=input_param, additional_info=additional_info, ) - result = await self.get_reasoning_result(prompt) - schema = ToolRisk.model_json_schema() - risk = await self._parse_result(result, schema) + # 组装OpenAI标准Function格式 + function = { + "name": "evaluate_tool_risk", + "description": ( + "Evaluate the risk level and safety concerns of executing " + "a specific tool with given parameters" + ), + "parameters": ToolRisk.model_json_schema(), + } + + risk = await self.get_json_result(prompt, function) # 返回风险评估结果 return ToolRisk.model_validate(risk) @@ -122,10 +154,15 @@ class MCPPlanner(MCPBase): input_params=input_params, error_message=error_message, ) - result = await self.get_reasoning_result(prompt) - # 解析为结构化数据 - schema = IsParamError.model_json_schema() - is_param_error = await self._parse_result(result, schema) + + # 组装OpenAI标准Function格式 + function = { + "name": "check_parameter_error", + "description": "Determine whether an error message indicates a parameter-related error", + "parameters": IsParamError.model_json_schema(), + } + + is_param_error = await self.get_json_result(prompt, function) # 使用IsParamError模型解析结果 return IsParamError.model_validate(is_param_error) @@ -157,9 +194,16 @@ class MCPPlanner(MCPBase): schema=schema_with_null, error_message=error_message, ) - result = await self.get_reasoning_result(prompt) + + # 组装OpenAI标准Function格式 + function = { + "name": "get_missing_parameters", + "description": "Extract and provide the missing or incorrect parameters based on error feedback", + "parameters": schema_with_null, + } + # 解析为结构化数据 - return await self._parse_result(result, schema_with_null) + return await self.get_json_result(prompt, function) async def generate_answer(self, memory: str) -> AsyncGenerator[LLMChunk, None]: """生成最终回答""" diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index 8295eb44c2a31a2022525325db005d2eae945a94..a08cb4547f0ec795f3bdde6eceb5ffd67158be80 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -11,7 +11,7 @@ from jinja2.sandbox import SandboxedEnvironment from apps.common.queue import MessageQueue from apps.common.security import Security -from apps.llm import Embedding, FunctionLLM, JsonGenerator, ReasoningLLM +from apps.llm import LLM, Embedding, JsonGenerator from apps.models import ( AppType, Conversation, @@ -275,7 +275,7 @@ class Scheduler: err = "[Scheduler] 获取问答用大模型ID失败" _logger.error(err) raise ValueError(err) - reasoning_llm = ReasoningLLM(reasoning_llm) + reasoning_llm = LLM(reasoning_llm) # 获取功能性的大模型信息 function_llm = None @@ -289,7 +289,7 @@ class Scheduler: self.user.userSub, self.user.functionLLM, ) else: - function_llm = FunctionLLM(function_llm) + function_llm = LLM(function_llm) embedding_llm = None if not self.user.embeddingLLM: diff --git a/apps/schemas/scheduler.py b/apps/schemas/scheduler.py index 72ad05e54822d2d670ca9f2ac0c3ad8b52058937..9c7595ea3599dfbdce55c4a5be483a8794d9ef39 100644 --- a/apps/schemas/scheduler.py +++ b/apps/schemas/scheduler.py @@ -6,7 +6,7 @@ from typing import Any from pydantic import BaseModel, Field -from apps.llm import Embedding, FunctionLLM, ReasoningLLM +from apps.llm import LLM, Embedding from apps.models import ExecutorHistory, LanguageType from .enum_var import CallOutputType @@ -15,8 +15,8 @@ from .enum_var import CallOutputType class LLMConfig(BaseModel): """LLM配置""" - reasoning: ReasoningLLM = Field(description="推理LLM") - function: FunctionLLM | None = Field(description="函数LLM") + reasoning: LLM = Field(description="推理LLM") + function: LLM | None = Field(description="函数LLM") embedding: Embedding | None = Field(description="Embedding") class Config: diff --git a/apps/services/task.py b/apps/services/task.py index 9a543e01bba9d6d6ce086c88897ed198babc94ec..6f9560511dc31596e3ef3ad689ec7941607b8061 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -154,14 +154,86 @@ class TaskManager: await session.commit() - @classmethod - async def save_task(cls, task_id: uuid.UUID, task: Task) -> None: - """保存任务块""" - task_collection = MongoDB().get_collection("task") - - # 更新已有的Task记录 - await task_collection.update_one( - {"_id": task_id}, - {"$set": task.model_dump(by_alias=True, exclude_none=True)}, - upsert=True, - ) + @staticmethod + async def save_task(task_id: uuid.UUID, task: Task) -> None: + """保存Task数据到PostgreSQL""" + async with postgres.session() as session: + # 查询是否存在该Task + existing_task = (await session.scalars( + select(Task).where(Task.id == task_id), + )).one_or_none() + + if existing_task: + # 更新现有Task + for key, value in task.__dict__.items(): + if not key.startswith("_"): + setattr(existing_task, key, value) + else: + # 插入新Task + session.add(task) + + await session.commit() + + + @staticmethod + async def save_task_runtime(task_runtime: TaskRuntime) -> None: + """保存TaskRuntime数据到PostgreSQL""" + async with postgres.session() as session: + # 查询是否存在该TaskRuntime + existing_runtime = (await session.scalars( + select(TaskRuntime).where(TaskRuntime.taskId == task_runtime.taskId), + )).one_or_none() + + if existing_runtime: + # 更新现有TaskRuntime + for key, value in task_runtime.__dict__.items(): + if not key.startswith("_"): + setattr(existing_runtime, key, value) + else: + # 插入新TaskRuntime + session.add(task_runtime) + + await session.commit() + + + @staticmethod + async def save_executor_checkpoint(checkpoint: ExecutorCheckpoint) -> None: + """保存ExecutorCheckpoint数据到PostgreSQL""" + async with postgres.session() as session: + # 查询是否存在该Checkpoint + existing_checkpoint = (await session.scalars( + select(ExecutorCheckpoint).where(ExecutorCheckpoint.taskId == checkpoint.taskId), + )).one_or_none() + + if existing_checkpoint: + # 更新现有Checkpoint + for key, value in checkpoint.__dict__.items(): + if not key.startswith("_"): + setattr(existing_checkpoint, key, value) + else: + # 插入新Checkpoint + session.add(checkpoint) + + await session.commit() + + + @staticmethod + async def save_flow_context(context: list[ExecutorHistory]) -> None: + """保存Flow上下文信息到PostgreSQL""" + async with postgres.session() as session: + for ctx in context: + # 查询是否存在该ExecutorHistory + existing_history = (await session.scalars( + select(ExecutorHistory).where(ExecutorHistory.id == ctx.id), + )).one_or_none() + + if existing_history: + # 更新现有History + for key, value in ctx.__dict__.items(): + if not key.startswith("_"): + setattr(existing_history, key, value) + else: + # 插入新History + session.add(ctx) + + await session.commit() diff --git a/design/call/slot.md b/design/call/slot.md index 0a48e6bdabbf6b28788dd0820c05b12e698f1fc6..251f14f0a474c6a13deea554715c9afc5258399f 100644 --- a/design/call/slot.md +++ b/design/call/slot.md @@ -2,7 +2,7 @@ ## 概述 -Slot工具是一个智能参数自动填充工具,它通过分析历史步骤数据、背景信息和用户问题,自动生成符合JSON Schema要求的参数对象。该工具使用FunctionCall进行精确的剩余参数填充。 +Slot工具是一个智能参数自动填充工具,它通过分析历史步骤数据、背景信息和用户问题,自动生成符合JSON Schema要求的参数对象。该工具使用JsonGenerator进行精确的剩余参数填充。 ## 功能特性 @@ -34,7 +34,7 @@ Slot工具的核心实现类,继承自`CoreCall`基类,负责参数自动填 - `_init()`: 初始化工具输入,处理历史数据和Schema - `_exec()`: 执行参数填充逻辑 - `_llm_slot_fill()`: 使用大语言模型填充参数 -- `_function_slot_fill()`: 使用FunctionCall填充剩余参数 +- `_function_slot_fill()`: 使用JsonGenerator填充剩余参数 ## 数据结构 @@ -100,7 +100,7 @@ graph LR B --> C[SlotProcessor] C --> D[SlotInput] D --> E[LLM填充] - E --> F[FunctionCall填充] + E --> F[JsonGenerator填充] F --> G[SlotOutput] subgraph "输入数据" @@ -115,7 +115,7 @@ graph LR B1[历史数据处理] B2[Schema验证] B3[LLM推理填充] - B4[FunctionCall精确填充] + B4[JsonGenerator精确填充] B5[数据转换和验证] end @@ -207,7 +207,7 @@ graph TD I --> J[检查剩余Schema] J --> K{仍有剩余Schema?} K -->|否| L[输出最终结果] - K -->|是| M[第二阶段:FunctionCall填充] + K -->|是| M[第二阶段:JsonGenerator填充] M --> N[数据转换] N --> O[最终Schema检查] O --> L @@ -220,9 +220,9 @@ graph TD G4[解析JSON响应] end - subgraph "FunctionCall填充阶段" + subgraph "JsonGenerator填充阶段" M1[构建对话上下文] - M2[调用JSON函数] + M2[调用JSON生成器] M3[获取结构化数据] end @@ -247,7 +247,7 @@ sequenceDiagram participant SP as SlotProcessor participant T as Jinja2Template participant L as LLM - participant F as FunctionLLM + participant J as JsonGenerator participant V as JSONValidator E->>S: instance(executor, node) @@ -268,8 +268,8 @@ sequenceDiagram S->>L: 调用LLM生成参数 L-->>S: 流式返回JSON响应 - S->>F: process_response(answer) - F-->>S: 处理后的JSON字符串 + S->>J: process_response(answer) + J-->>S: 处理后的JSON字符串 S->>V: json.loads(answer) V-->>S: 解析的JSON数据 @@ -279,8 +279,8 @@ sequenceDiagram SP-->>S: 剩余Schema alt 仍有剩余Schema - S->>F: _json(messages, schema) - F-->>S: 结构化JSON数据 + S->>J: _json(messages, schema) + J-->>S: 结构化JSON数据 S->>SP: convert_json(slot_data) SP-->>S: 转换后的数据 S->>SP: check_json(slot_data) @@ -298,15 +298,15 @@ sequenceDiagram ```mermaid graph LR A[原始Schema] --> B[LLM填充阶段] - B --> C[FunctionCall填充阶段] + B --> C[JsonGenerator填充阶段] C --> D[最终结果] - + B --> B1[生成自然语言提示词] B --> B2[LLM推理生成JSON] B --> B3[解析和验证JSON] - + C --> C1[构建结构化对话] - C --> C2[调用JSON生成函数] + C --> C2[调用JSON生成器] C --> C3[获取精确JSON数据] D --> D1[slot_data: 填充数据]