From cf24629104bc696bec491ddfb771d4dc7c8906d9 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Mon, 25 Aug 2025 17:03:51 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9B=9E=E5=90=88Agent=E5=88=86=E6=94=AF?= =?UTF-8?q?=E6=94=B9=E5=8A=A8=20&=20=E8=BF=81=E7=A7=BB=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/llm/function.py | 18 ++- apps/llm/reasoning.py | 3 + apps/models/service.py | 10 +- apps/models/user.py | 2 +- apps/scheduler/executor/agent.py | 127 ++++++++-------- apps/scheduler/mcp/host.py | 24 +-- apps/scheduler/mcp/prompt.py | 219 ++++++++++++++++++++++++++++ apps/scheduler/mcp_agent/base.py | 18 +++ apps/scheduler/mcp_agent/host.py | 26 ++-- apps/scheduler/mcp_agent/plan.py | 16 +- apps/scheduler/pool/mcp/default.py | 66 +++++---- apps/scheduler/pool/mcp/pool.py | 21 ++- apps/scheduler/scheduler/message.py | 22 ++- apps/schemas/flow_topology.py | 6 +- apps/schemas/request_data.py | 1 + apps/schemas/response_data.py | 2 +- apps/services/flow.py | 8 +- apps/services/llm.py | 14 +- apps/services/rag.py | 54 +++---- apps/services/record.py | 9 +- 20 files changed, 477 insertions(+), 189 deletions(-) diff --git a/apps/llm/function.py b/apps/llm/function.py index 869dbb43..74d5543a 100644 --- a/apps/llm/function.py +++ b/apps/llm/function.py @@ -24,6 +24,8 @@ logger = logging.getLogger(__name__) class FunctionLLM: """用于FunctionCall的模型""" + timeout: float = 30.0 + def __init__(self) -> None: """ 初始化用于FunctionCall的模型 @@ -48,20 +50,28 @@ class FunctionLLM: } if self._config.backend == "ollama" and not self._config.api_key: - self._client = ollama.AsyncClient(host=self._config.endpoint) + self._client = ollama.AsyncClient( + host=self._config.endpoint, + timeout=self.timeout, + ) elif self._config.backend == "ollama" and self._config.api_key: self._client = ollama.AsyncClient( host=self._config.endpoint, headers={ "Authorization": f"Bearer {self._config.api_key}", }, + timeout=self.timeout, ) elif self._config.backend != "ollama" and not self._config.api_key: - self._client = openai.AsyncOpenAI(base_url=self._config.endpoint) + self._client = openai.AsyncOpenAI( + base_url=self._config.endpoint, + timeout=self.timeout, + ) elif self._config.backend != "ollama" and self._config.api_key: self._client = openai.AsyncOpenAI( base_url=self._config.endpoint, api_key=self._config.api_key, + timeout=self.timeout, ) @@ -235,7 +245,7 @@ class JsonGenerator: """JSON生成器""" @staticmethod - async def parse_result_by_stack(result: str, schema: dict[str, Any]) -> str: # noqa: C901, PLR0912 + async def parse_result_by_stack(result: str, schema: dict[str, Any]) -> dict[str, Any]: # noqa: C901, PLR0912 """解析推理结果""" validator = Draft7Validator(schema) left_index = result.find("{") @@ -289,7 +299,7 @@ class JsonGenerator: else: return tmp_js - return "" + return {} def __init__(self, query: str, conversation: list[dict[str, str]], schema: dict[str, Any]) -> None: """初始化JSON生成器""" diff --git a/apps/llm/reasoning.py b/apps/llm/reasoning.py index b765988c..f5d396a3 100644 --- a/apps/llm/reasoning.py +++ b/apps/llm/reasoning.py @@ -93,6 +93,7 @@ class ReasoningLLM: input_tokens: int = 0 output_tokens: int = 0 + timeout: float = 30.0 def __init__(self, llm_config: LLMConfig | None = None) -> None: """判断配置文件里用了哪种大模型;初始化大模型客户端""" @@ -108,12 +109,14 @@ class ReasoningLLM: if not self._config.key: self._client = AsyncOpenAI( base_url=self._config.endpoint, + timeout=self.timeout, ) return self._client = AsyncOpenAI( 
api_key=self._config.key, base_url=self._config.endpoint, + timeout=self.timeout, ) @staticmethod diff --git a/apps/models/service.py b/apps/models/service.py index 2341a5a2..6a8080e2 100644 --- a/apps/models/service.py +++ b/apps/models/service.py @@ -1,8 +1,10 @@ """插件 数据库表""" +import uuid from datetime import UTC, datetime from sqlalchemy import BigInteger, DateTime, Enum, ForeignKey, String +from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import Mapped, mapped_column from apps.schemas.enum_var import PermissionType @@ -14,14 +16,14 @@ class Service(Base): """插件""" __tablename__ = "framework_service" - id: Mapped[str] = mapped_column(String(255), primary_key=True) - """插件ID""" name: Mapped[str] = mapped_column(String(255), nullable=False) """插件名称""" description: Mapped[str] = mapped_column(String(2000), nullable=False) """插件描述""" author: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.userSub"), nullable=False) """插件作者""" + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + """插件ID""" updatedAt: Mapped[datetime] = mapped_column( # noqa: N815 DateTime(timezone=True), default_factory=lambda: datetime.now(tz=UTC), @@ -41,7 +43,7 @@ class ServiceACL(Base): """插件权限""" __tablename__ = "framework_service_acl" - serviceId: Mapped[str] = mapped_column(String(255), ForeignKey("framework_service.id"), nullable=False) # noqa: N815 + serviceId: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("framework_service.id"), nullable=False) # noqa: N815 """关联的插件ID""" userSub: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.userSub"), nullable=False) # noqa: N815 """用户名""" @@ -57,7 +59,7 @@ class ServiceHashes(Base): __tablename__ = "framework_service_hashes" id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False) """主键ID""" - serviceId: Mapped[str] = mapped_column(String(255), ForeignKey("framework_service.id"), nullable=False) # noqa: N815 + serviceId: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("framework_service.id"), nullable=False) # noqa: N815 """关联的插件ID""" filePath: Mapped[str] = mapped_column(String(255), nullable=False) # noqa: N815 """文件路径""" diff --git a/apps/models/user.py b/apps/models/user.py index 22e5fcba..e18918c9 100644 --- a/apps/models/user.py +++ b/apps/models/user.py @@ -36,7 +36,7 @@ class User(Base): """用户个人令牌""" selectedKB: Mapped[list[uuid.UUID]] = mapped_column(ARRAY(UUID), default=[], nullable=False) # noqa: N815 """用户选择的知识库的ID""" - defaultLLM: Mapped[uuid.UUID | None] = mapped_column(UUID(as_uuid=True), default=None, nullable=True) # noqa: N815 + defaultLLM: Mapped[str | None] = mapped_column(String(255), default=None, nullable=True) # noqa: N815 """用户选择的大模型ID""" autoExecute: Mapped[bool | None] = mapped_column(Boolean, default=False, nullable=True) # noqa: N815 """Agent是否自动执行""" diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index 2fe3b88a..d2730c4d 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -57,6 +57,7 @@ class MCPAgentExecutor(BaseExecutor): async def init(self) -> None: """初始化MCP Agent""" self.planner = MCPPlanner(self.runtime.userInput, self.resoning_llm, self.runtime.language) + self.host = MCPHost(self.task.userSub, self.resoning_llm) async def load_state(self) -> None: """从数据库中加载FlowExecutor的状态""" @@ -99,7 +100,7 @@ class MCPAgentExecutor(BaseExecutor): """获取工具输入参数""" if is_first: # 获取第一个输入参数 - mcp_tool = 
self.tools[self.task.state.tool_id] + mcp_tool = self.tools[self.state.toolId] self.state.currentInput = await MCPHost._get_first_input_params( mcp_tool, self.runtime.userInput, self.state.stepDescription, self.task, ) @@ -135,9 +136,9 @@ class MCPAgentExecutor(BaseExecutor): EventType.STEP_WAITING_FOR_START, confirm_message.model_dump(exclude_none=True, by_alias=True), ) await self.push_message(EventType.FLOW_STOP, {}) - self.task.state.flow_status = FlowStatus.WAITING - self.task.state.step_status = StepStatus.WAITING - self.task.context.append( + self.state.executorStatus = ExecutorStatus.WAITING + self.state.stepStatus = StepStatus.WAITING + self.context.append( ExecutorHistory( task_id=self.task.id, step_id=self.task.state.step_id, @@ -155,22 +156,22 @@ class MCPAgentExecutor(BaseExecutor): async def run_step(self) -> None: """执行步骤""" - self.task.state.flow_status = FlowStatus.RUNNING - self.task.state.step_status = StepStatus.RUNNING - mcp_tool = self.tools[self.task.state.tool_id] + self.state.executorStatus = ExecutorStatus.RUNNING + self.state.stepStatus = StepStatus.RUNNING + mcp_tool = self.tools[self.state.toolId] mcp_client = (await self.mcp_pool.get(mcp_tool.mcp_id, self.task.ids.user_sub)) try: - output_params = await mcp_client.call_tool(mcp_tool.name, self.task.state.current_input) + output_params = await mcp_client.call_tool(mcp_tool.name, self.state.currentInput) except anyio.ClosedResourceError: logger.exception("[MCPAgentExecutor] MCP客户端连接已关闭: %s", mcp_tool.mcp_id) await self.mcp_pool.stop(mcp_tool.mcp_id, self.task.ids.user_sub) await self.mcp_pool.init_mcp(mcp_tool.mcp_id, self.task.ids.user_sub) - self.task.state.step_status = StepStatus.ERROR + self.state.stepStatus = StepStatus.ERROR return except Exception as e: logger.exception("[MCPAgentExecutor] 执行步骤 %s 时发生错误", mcp_tool.name) - self.task.state.step_status = StepStatus.ERROR - self.task.state.error_message = str(e) + self.state.stepStatus = StepStatus.ERROR + self.state.errorMessage = str(e) return logger.error(f"当前工具名称: {mcp_tool.name}, 输出参数: {output_params}") if output_params.isError: @@ -223,7 +224,7 @@ class MCPAgentExecutor(BaseExecutor): error_message = await self.planner.change_err_message_to_description( error_message=self.state.errorMessage, tool=mcp_tool, - input_params=self.task.state.current_input, + input_params=self.state.currentInput, ) await self.push_message( EventType.STEP_WAITING_FOR_PARAM, data={"message": error_message, "params": params_with_null}, @@ -242,8 +243,8 @@ class MCPAgentExecutor(BaseExecutor): executorName=self.state.executorName, executorStatus=self.state.executorStatus, inputData={}, - output_data={}, - ex_data={ + outputData={}, + extraData={ "message": error_message, "params": params_with_null, }, @@ -292,8 +293,8 @@ class MCPAgentExecutor(BaseExecutor): ) if len(self.context) and self.context[-1].step_id == self.state.step_id: del self.context[-1] - self.task.context.append( - FlowStepHistory( + self.context.append( + ExecutorHistory( task_id=self.task.id, step_id=self.task.state.step_id, step_name=self.task.state.step_name, @@ -330,8 +331,8 @@ class MCPAgentExecutor(BaseExecutor): if len(self.context) and self.context[-1].stepId == self.state.stepId: del self.context[-1] else: - self.task.state.flow_status = FlowStatus.CANCELLED - self.task.state.step_status = StepStatus.CANCELLED + self.state.executorStatus = ExecutorStatus.CANCELLED + self.state.stepStatus = StepStatus.CANCELLED await self.push_message( EventType.STEP_CANCEL, data={}, @@ -340,17 +341,17 @@ class 
MCPAgentExecutor(BaseExecutor): EventType.FLOW_CANCEL, data={}, ) - if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: - self.task.context[-1].step_status = StepStatus.CANCELLED + if len(self.context) and self.context[-1].stepId == self.state.stepId: + self.context[-1].stepStatus = StepStatus.CANCELLED return max_retry = 5 for i in range(max_retry): if i != 0: await self.get_tool_input_param(is_first=True) await self.run_step() - if self.task.state.step_status == StepStatus.SUCCESS: + if self.state.stepStatus == StepStatus.SUCCESS: break - elif self.task.state.step_status == StepStatus.ERROR: + elif self.state.stepStatus == StepStatus.ERROR: # 错误处理 if self.task.state.retry_times >= 3: await self.error_handle_after_step() @@ -360,17 +361,17 @@ class MCPAgentExecutor(BaseExecutor): await self.push_message( EventType.STEP_ERROR, data={ - "message": self.task.state.error_message, + "message": self.state.errorMessage, } ) - if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: - self.task.context[-1].step_status = StepStatus.ERROR - self.task.context[-1].output_data = { - "message": self.task.state.error_message, + if len(self.context) and self.context[-1].stepId == self.state.stepId: + self.context[-1].stepStatus = StepStatus.ERROR + self.context[-1].outputData = { + "message": self.state.errorMessage, } else: - self.task.context.append( - FlowStepHistory( + self.context.append( + ExecutorHistory( task_id=self.task.id, step_id=self.task.state.step_id, step_name=self.task.state.step_name, @@ -387,15 +388,15 @@ class MCPAgentExecutor(BaseExecutor): ) await self.get_next_step() else: - mcp_tool = self.tools[self.task.state.tool_id] + mcp_tool = self.tools[self.state.toolId] is_param_error = await MCPPlanner.is_param_error( - self.task.runtime.question, + self.runtime.question, await MCPHost.assemble_memory(self.task), - self.task.state.error_message, + self.state.errorMessage, mcp_tool, - self.task.state.step_description, - self.task.state.current_input, - language=self.task.language, + self.state.stepDescription, + self.state.currentInput, + language=self.runtime.language, ) if is_param_error.is_param_error: # 如果是参数错误,生成参数补充 @@ -404,13 +405,13 @@ class MCPAgentExecutor(BaseExecutor): await self.push_message( EventType.STEP_ERROR, data={ - "message": self.task.state.error_message, + "message": self.state.errorMessage, } ) - if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: - self.task.context[-1].step_status = StepStatus.ERROR - self.task.context[-1].output_data = { - "message": self.task.state.error_message, + if len(self.context) and self.context[-1].stepId == self.state.stepId: + self.context[-1].stepStatus = StepStatus.ERROR + self.context[-1].outputData = { + "message": self.state.errorMessage, } else: self.task.context.append( @@ -430,33 +431,33 @@ class MCPAgentExecutor(BaseExecutor): ), ) await self.get_next_step() - elif self.task.state.step_status == StepStatus.SUCCESS: + elif self.state.stepStatus == StepStatus.SUCCESS: await self.get_next_step() async def summarize(self) -> None: """总结""" async for chunk in MCPPlanner.generate_answer( - self.task.runtime.question, + self.runtime.userInput, (await MCPHost.assemble_memory(self.task)), self.resoning_llm, - self.task.language, + self.runtime.language, ): await self.push_message( EventType.TEXT_ADD, data=chunk, ) - self.task.runtime.answer += chunk + self.runtime.fullAnswer += chunk async def run(self) -> None: """执行MCP Agent的主逻辑""" 
# 初始化MCP服务
         await self.load_state()
         await self.load_mcp()
-        if self.task.state.flow_status == FlowStatus.INIT:
+        if self.state.executorStatus == ExecutorStatus.INIT:
             # 初始化状态
             try:
-                self.task.state.flow_id = str(uuid.uuid4())
-                self.task.state.flow_name = (await MCPPlanner.get_flow_name(
+                self.state.executorId = str(uuid.uuid4())
+                self.state.executorName = (await MCPPlanner.get_flow_name(
                     self.task.runtime.question, self.resoning_llm, self.task.language
                 )).flow_name
                 await TaskManager.save_task(self.task.id, self.task)
@@ -470,14 +471,14 @@
                     data={},
                 )
                 return
-        self.task.state.flow_status = FlowStatus.RUNNING
+        self.state.executorStatus = ExecutorStatus.RUNNING
         await self.push_message(
             EventType.FLOW_START,
             data={},
         )
-        if self.task.state.tool_id == FINAL_TOOL_ID:
+        if self.state.toolId == FINAL_TOOL_ID:
             # 如果已经是最后一步,直接结束
-            self.task.state.flow_status = FlowStatus.SUCCESS
+            self.state.executorStatus = ExecutorStatus.SUCCESS
             await self.push_message(
                 EventType.FLOW_SUCCESS,
                 data={},
@@ -485,16 +486,15 @@
             await self.summarize()
             return
         try:
-            while self.task.state.flow_status == FlowStatus.RUNNING:
-                if self.task.state.tool_id == FINAL_TOOL_ID:
+            while self.state.executorStatus == ExecutorStatus.RUNNING:
+                if self.state.toolId == FINAL_TOOL_ID:
                     break
                 await self.work()
             await TaskManager.save_task(self.task.id, self.task)
-            tool_id = self.task.state.tool_id
-            if tool_id == FINAL_TOOL_ID:
+            if self.state.toolId == FINAL_TOOL_ID:
                 # 如果已经是最后一步,直接结束
-                self.task.state.flow_status = FlowStatus.SUCCESS
-                self.task.state.step_status = StepStatus.SUCCESS
+                self.state.executorStatus = ExecutorStatus.SUCCESS
+                self.state.stepStatus = StepStatus.SUCCESS
                 await self.push_message(
                     EventType.FLOW_SUCCESS,
                     data={},
@@ -502,9 +502,9 @@
                 await self.summarize()
         except Exception as e:
             logger.exception("[MCPAgentExecutor] 执行过程中发生错误")
-            self.task.state.flow_status = FlowStatus.ERROR
-            self.task.state.error_message = str(e)
-            self.task.state.step_status = StepStatus.ERROR
+            self.state.executorStatus = ExecutorStatus.ERROR
+            self.state.errorMessage = str(e)
+            self.state.stepStatus = StepStatus.ERROR
             await self.push_message(
                 EventType.STEP_ERROR,
                 data={},
@@ -513,10 +513,10 @@
                 EventType.FLOW_FAILED,
                 data={},
             )
-            if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id:
-                del self.task.context[-1]
-            self.task.context.append(
-                FlowStepHistory(
+            if len(self.context) and self.context[-1].stepId == self.state.stepId:
+                del self.context[-1]
+            self.context.append(
+                ExecutorHistory(
                     task_id=self.task.id,
                     step_id=self.task.state.step_id,
                     step_name=self.task.state.step_name,
@@ -534,5 +534,4 @@
         try:
             await self.mcp_pool.stop(mcp_service.id, self.task.ids.user_sub)
         except Exception as e:
-            import traceback
             logger.error("[MCPAgentExecutor] 停止MCP客户端时发生错误: %s", traceback.format_exc())
diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py
index c1dccbaa..d28dc684 100644
--- a/apps/scheduler/mcp/host.py
+++ b/apps/scheduler/mcp/host.py
@@ -102,18 +102,18 @@ class MCPHost:
         }
 
         # 创建context;注意用法
-        context = FlowStepHistory(
-            task_id=self._task_id,
-            flow_id=self._runtime_id,
-            flow_name=self._runtime_name,
-            flow_status=ExecutorStatus.RUNNING,
-            step_id=tool.id,
-            step_name=tool.toolName,
+        context = ExecutorHistory(
+            taskId=self._task_id,
+            executorId=self._runtime_id,
+            executorName=self._runtime_name,
+            
executorStatus=ExecutorStatus.RUNNING, + stepId=tool.id, + stepName=tool.toolName, # description是规划的实际内容 - step_description=plan_item.content, - step_status=StepStatus.SUCCESS, - input_data=input_data, - output_data=output_data, + stepDescription=plan_item.content, + stepStatus=StepStatus.SUCCESS, + inputData=input_data, + outputData=output_data, ) # 保存到task @@ -122,7 +122,7 @@ class MCPHost: logger.error("任务 %s 不存在", self._task_id) return {} self._context_list.append(context.id) - task.context.append(context.model_dump(exclude_none=True, by_alias=True)) + context.append(context.model_dump(exclude_none=True, by_alias=True)) await TaskManager.save_task(self._task_id, task) return output_data diff --git a/apps/scheduler/mcp/prompt.py b/apps/scheduler/mcp/prompt.py index 100571c5..78b3c772 100644 --- a/apps/scheduler/mcp/prompt.py +++ b/apps/scheduler/mcp/prompt.py @@ -175,6 +175,225 @@ additional content: ), } +CREATE_PLAN: dict[str, str] = { + LanguageType.CHINESE: dedent( + r""" + 你是一个计划生成器。 + 请分析用户的目标,并生成一个计划。你后续将根据这个计划,一步一步地完成用户的目标。 + + # 一个好的计划应该: + + 1. 能够成功完成用户的目标 + 2. 计划中的每一个步骤必须且只能使用一个工具。 + 3. 计划中的步骤必须具有清晰和逻辑的步骤,没有冗余或不必要的步骤。 + 4. 计划中的最后一步必须是Final工具,以确保计划执行结束。 + 5.生成的计划必须要覆盖用户的目标,不能遗漏任何用户目标中的内容。 + + # 生成计划时的注意事项: + + - 每一条计划包含3个部分: + - 计划内容:描述单个计划步骤的大致内容 + - 工具ID:必须从下文的工具列表中选择 + - 工具指令:改写用户的目标,使其更符合工具的输入要求 + - 必须按照如下格式生成计划,不要输出任何额外数据: + + ```json + { + "plans": [ + { + "content": "计划内容", + "tool": "工具ID", + "instruction": "工具指令" + } + ] + } + ``` + + - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。思考过程应按步骤顺序放置在\ + XML标签中。 + - 计划内容中,可以使用"Result[]"来引用之前计划步骤的结果。例如:"Result[3]"表示引用第三条计划执行后的结果。 + - 计划不得多于{{ max_num }}条,且每条计划内容应少于150字。 + + # 工具 + + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + + + {% for tool in tools %} + - {{ tool.id }}{{tool.name}};{{ tool.description }} + {% endfor %} + - Final结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终\ +结果。 + + + # 样例 + + ## 目标 + + 在后台运行一个新的alpine:latest容器,将主机/root文件夹挂载至/data,并执行top命令。 + + ## 计划 + + + 1. 这个目标需要使用Docker来完成,首先需要选择合适的MCP Server + 2. 目标可以拆解为以下几个部分: + - 运行alpine:latest容器 + - 挂载主机目录 + - 在后台运行 + - 执行top命令 + 3. 需要先选择MCP Server,然后生成Docker命令,最后执行命令 + + + ```json + { + "plans": [ + { + "content": "选择一个支持Docker的MCP Server", + "tool": "mcp_selector", + "instruction": "需要一个支持Docker容器运行的MCP Server" + }, + { + "content": "使用Result[0]中选择的MCP Server,生成Docker命令", + "tool": "command_generator", + "instruction": "生成Docker命令:在后台运行alpine:latest容器,挂载/root到/data,执行top命令" + }, + { + "content": "在Result[0]的MCP Server上执行Result[1]生成的命令", + "tool": "command_executor", + "instruction": "执行Docker命令" + }, + { + "content": "任务执行完成,容器已在后台运行,结果为Result[2]", + "tool": "Final", + "instruction": "" + } + ] + } + ``` + + # 现在开始生成计划: + + ## 目标 + + {{goal}} + + # 计划 + """, + ), + LanguageType.ENGLISH: dedent( + r""" + You are a plan generator. + Please analyze the user's goal and generate a plan. You will then follow this plan to achieve the user's \ +goal step by step. + + # A good plan should: + + 1. Be able to successfully achieve the user's goal. + 2. Each step in the plan must use only one tool. + 3. The steps in the plan must have clear and logical steps, without redundant or unnecessary steps. + 4. The last step in the plan must be a Final tool to ensure that the plan is executed. 
+ + # Things to note when generating plans: + + - Each plan contains three parts: + - Plan content: Describes the general content of a single plan step + - Tool ID: Must be selected from the tool list below + - Tool instructions: Rewrite the user's goal to make it more consistent with the tool's input \ +requirements + - Plans must be generated in the following format. Do not output any additional data: + + ```json + { + "plans": [ + { + "content":"Plan content", + "tool":"Tool ID", + "instruction":"Tool instructions" + } + ] + } + ``` + + - Before generating a plan, please think step by step, analyze the user's goal, and guide your next \ +steps. The thought process should be placed in sequential steps within XML tags. + - In the plan content, you can use "Result[]" to reference the results of the previous plan steps. \ +For example: "Result[3]" refers to the result after the third plan is executed. + - The plan should not have more than {{ max_num }} items, and each plan content should be less than \ +150 words. + + # Tools + + You can access and use some tools, which will be given in the XML tags. + + + {% for tool in tools %} + - {{ tool.id }}{{tool.name}}; {{ tool.description }} + {% endfor %} + - FinalEnd step. When this step is executed, \ + Indicates that the plan execution is completed and the result obtained will be used as the final \ +result. + + + # Example + + ## Target + + Run a new alpine:latest container in the background, mount the host/root folder to /data, and execute the \ +top command. + + ## Plan + + + 1. This goal needs to be completed using Docker. First, you need to select a suitable MCP Server. + 2. The goal can be broken down into the following parts: + - Run the alpine:latest container + - Mount the host directory + - Run in the background + - Execute the top command + 3. You need to select an MCP Server first, then generate the Docker command, and finally execute the \ +command. 
+ + + ```json + { + "plans": [ + { + "content": "Select an MCP Server that supports Docker", + "tool": "mcp_selector", + "instruction": "You need an MCP Server that supports running Docker containers" + }, + { + "content": "Use the MCP Server selected in Result[0] to generate Docker commands", + "tool": "command_generator", + "instruction": "Generate Docker command: Run the alpine:latest container in the background, \ +mount /root to /data, and execute the top command" + }, + { + "content": "Execute the command generated by Result[1] on the MCP Server of Result[0]", + "tool": "command_executor", + "instruction": "Execute Docker command" + }, + { + "content": "Task execution completed, the container is running in the background, the result \ +is Result[2]", + "tool": "Final", + "instruction": "" + } + ] + } + ``` + + # Now start generating the plan: + + ## Goal + + {{goal}} + + # Plan + """, + ), +} + EVALUATE_PLAN: dict[LanguageType, str] = { LanguageType.CHINESE: dedent( r""" diff --git a/apps/scheduler/mcp_agent/base.py b/apps/scheduler/mcp_agent/base.py index 4b9807ad..1b273ba8 100644 --- a/apps/scheduler/mcp_agent/base.py +++ b/apps/scheduler/mcp_agent/base.py @@ -2,7 +2,9 @@ """MCP基类""" import logging +from typing import Any +from apps.llm.function import JsonGenerator from apps.llm.reasoning import ReasoningLLM logger = logging.getLogger(__name__) @@ -30,3 +32,19 @@ class MCPBase: result += chunk return result + + @staticmethod + async def _parse_result(result: str, schema: dict[str, Any]) -> dict[str, Any]: + """解析推理结果""" + json_result = await JsonGenerator.parse_result_by_stack(result, schema) + if json_result is not None: + return json_result + json_generator = JsonGenerator( + "Please provide a JSON response based on the above information and schema.\n\n", + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": result}, + ], + schema, + ) + return await json_generator.generate() diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py index 88c0cbb0..3ba5a866 100644 --- a/apps/scheduler/mcp_agent/host.py +++ b/apps/scheduler/mcp_agent/host.py @@ -10,12 +10,12 @@ from jinja2.sandbox import SandboxedEnvironment from apps.llm.function import JsonGenerator from apps.llm.reasoning import ReasoningLLM -from apps.models.task import TaskRuntime +from apps.models.mcp import MCPTools +from apps.models.task import ExecutorHistory, TaskRuntime from apps.scheduler.mcp.prompt import MEMORY_TEMPLATE from apps.scheduler.mcp_agent.base import MCPBase from apps.scheduler.mcp_agent.prompt import GEN_PARAMS, REPAIR_PARAMS from apps.schemas.enum_var import LanguageType -from apps.schemas.mcp import MCPTool logger = logging.getLogger(__name__) _env = SandboxedEnvironment( @@ -48,36 +48,36 @@ class MCPHost(MCPBase): self.llm = llm @staticmethod - async def assemble_memory(runtime: TaskRuntime) -> str: + async def assemble_memory(runtime: TaskRuntime, context: list[ExecutorHistory]) -> str: """组装记忆""" return _env.from_string(MEMORY_TEMPLATE[runtime.language]).render( - context_list=runtime.context, + context_list=context, ) async def get_first_input_params( - self, mcp_tool: MCPTool, current_goal: str, runtime: TaskRuntime, + self, mcp_tool: MCPTools, current_goal: str, runtime: TaskRuntime, context: list[ExecutorHistory], ) -> dict[str, Any]: """填充工具参数""" # 更清晰的输入·指令,这样可以调用generate prompt = _env.from_string(GEN_PARAMS[runtime.language]).render( - tool_name=mcp_tool.name, + tool_name=mcp_tool.toolName, 
tool_description=mcp_tool.description, goal=self.goal, current_goal=current_goal, - input_schema=mcp_tool.input_schema, - background_info=await MCPHost.assemble_memory(runtime), + input_schema=mcp_tool.inputSchema, + background_info=await self.assemble_memory(runtime, context), ) logger.info("[MCPHost] 填充工具参数: %s", prompt) result = await self.get_resoning_result(prompt) # 使用JsonGenerator解析结果 return await MCPHost._parse_result( result, - mcp_tool.input_schema, + mcp_tool.inputSchema, ) async def fill_params( # noqa: D102, PLR0913 self, - mcp_tool: MCPTool, + mcp_tool: MCPTools, current_goal: str, current_input: dict[str, Any], language: LanguageType, @@ -87,11 +87,11 @@ class MCPHost(MCPBase): ) -> dict[str, Any]: llm_query = LLM_QUERY_FIX[language] prompt = _env.from_string(REPAIR_PARAMS[language]).render( - tool_name=mcp_tool.name, + tool_name=mcp_tool.toolName, goal=self.goal, current_goal=current_goal, tool_description=mcp_tool.description, - input_schema=mcp_tool.input_schema, + input_schema=mcp_tool.inputSchema, current_input=current_input, error_message=error_message, params=params, @@ -104,6 +104,6 @@ class MCPHost(MCPBase): {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}, ], - mcp_tool.input_schema, + mcp_tool.inputSchema, ) return await json_generator.generate() diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py index 98b76946..c658aa7f 100644 --- a/apps/scheduler/mcp_agent/plan.py +++ b/apps/scheduler/mcp_agent/plan.py @@ -87,7 +87,7 @@ class MCPPlanner(MCPBase): # 获取推理结果 template = _env.from_string(RISK_EVALUATE[self.language]) prompt = template.render( - tool_name=tool.name, + tool_name=tool.toolName, tool_description=tool.description, input_param=input_param, additional_info=additional_info, @@ -114,7 +114,7 @@ class MCPPlanner(MCPBase): goal=self.goal, history=history, step_id=tool.id, - step_name=tool.name, + step_name=tool.toolName, step_description=step_description, input_params=input_params, error_message=error_message, @@ -133,20 +133,22 @@ class MCPPlanner(MCPBase): template = _env.from_string(CHANGE_ERROR_MESSAGE_TO_DESCRIPTION[self.language]) prompt = template.render( error_message=error_message, - tool_name=tool.name, + tool_name=tool.toolName, tool_description=tool.description, - input_schema=tool.input_schema, + input_schema=tool.inputSchema, input_params=input_params, ) return await self.get_resoning_result(prompt) - async def get_missing_param(self, tool: MCPTools, input_param: dict[str, Any], error_message: str) -> dict[str, Any]: + async def get_missing_param( + self, tool: MCPTools, input_param: dict[str, Any], error_message: str, + ) -> dict[str, Any]: """获取缺失的参数""" - slot = Slot(schema=tool.input_schema) + slot = Slot(schema=tool.inputSchema) template = _env.from_string(GET_MISSING_PARAMS[self.language]) schema_with_null = slot.add_null_to_basic_types() prompt = template.render( - tool_name=tool.name, + tool_name=tool.toolName, tool_description=tool.description, input_param=input_param, schema=schema_with_null, diff --git a/apps/scheduler/pool/mcp/default.py b/apps/scheduler/pool/mcp/default.py index bf9b7d94..e97b6b46 100644 --- a/apps/scheduler/pool/mcp/default.py +++ b/apps/scheduler/pool/mcp/default.py @@ -12,34 +12,42 @@ from apps.schemas.mcp import ( ) logger = logging.getLogger(__name__) -DEFAULT_STDIO = MCPServerConfig( - name="MCP服务_" + uuid.uuid4().hex[:6], - description="MCP服务描述", - mcpType=MCPType.STDIO, - config=MCPServerStdioConfig( - command="uvx", - args=[ - 
"your_package", - ], - env={ - "EXAMPLE_ENV": "example_value", - }, - ), + +def _default_stdio_config(hex_id: str) -> MCPServerConfig: + """默认的Stdio协议MCP Server配置""" + return MCPServerConfig( + name="MCP服务_" + hex_id, + description="MCP服务描述", + mcpType=MCPType.STDIO, + mcpServers={ + f"MCP_{hex_id}": MCPServerStdioConfig( + command="uvx", + args=[ + "your_package", + ], + env={ + "EXAMPLE_ENV": "example_value", + }, + ), + }, ) -"""默认的Stdio协议MCP Server配置""" - -DEFAULT_SSE = MCPServerConfig( - name="MCP服务_" + uuid.uuid4().hex[:6], - description="MCP服务描述", - mcpType=MCPType.SSE, - config=MCPServerSSEConfig( - url="http://test.domain/sse", - env={ - "EXAMPLE_HEADER": "example_value", + + +def _default_sse_config(hex_id: str) -> MCPServerConfig: + """默认的SSE协议MCP Server配置""" + return MCPServerConfig( + name="MCP服务_" + hex_id, + description="MCP服务描述", + mcpType=MCPType.SSE, + mcpServers={ + f"MCP_{hex_id}": MCPServerSSEConfig( + url="http://test.domain/sse", + env={ + "EXAMPLE_HEADER": "example_value", + }, + ), }, - ), -) -"""默认的SSE协议MCP Server配置""" + ) async def get_default(mcp_type: MCPType) -> MCPServerConfig: @@ -51,10 +59,12 @@ async def get_default(mcp_type: MCPType) -> MCPServerConfig: :rtype: MCPConfig :raises ValueError: 未找到默认的 MCP 配置 """ + random_id = uuid.uuid4().hex[:6] + if mcp_type == MCPType.STDIO: - return DEFAULT_STDIO + return _default_stdio_config(random_id) if mcp_type == MCPType.SSE: - return DEFAULT_SSE + return _default_sse_config(random_id) err = f"未找到默认的 MCP 配置: {mcp_type}" logger.error(err) raise ValueError(err) diff --git a/apps/scheduler/pool/mcp/pool.py b/apps/scheduler/pool/mcp/pool.py index cea8d6b5..fdd3e53f 100644 --- a/apps/scheduler/pool/mcp/pool.py +++ b/apps/scheduler/pool/mcp/pool.py @@ -3,9 +3,12 @@ import logging +from sqlalchemy import and_, select + +from apps.common.postgres import postgres from apps.common.singleton import SingletonMeta from apps.constants import MCP_PATH -from apps.models.mcp import MCPType +from apps.models.mcp import MCPActivated, MCPType from apps.schemas.mcp import MCPServerConfig from .client import MCPClient @@ -39,7 +42,7 @@ class MCPPool(metaclass=SingletonMeta): logger.warning("[MCPPool] 用户 %s 的MCP %s 类型错误", user_sub, mcp_id) return None - await client.init(user_sub, mcp_id, config.config) + await client.init(user_sub, mcp_id, config.mcpServers[mcp_id]) if user_sub not in self.pool: self.pool[user_sub] = {} self.pool[user_sub][mcp_id] = client @@ -59,10 +62,16 @@ class MCPPool(metaclass=SingletonMeta): async def _validate_user(self, mcp_id: str, user_sub: str) -> bool: """验证用户是否已激活""" - mongo = MongoDB() - mcp_collection = mongo.get_collection("mcp") - mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": user_sub}) - return mcp_db_result is not None + async with postgres.session() as session: + result = (await session.scalars( + select(MCPActivated).where( + and_( + MCPActivated.mcpId == mcp_id, + MCPActivated.userSub == user_sub, + ), + ).limit(1), + )).one_or_none() + return result is not None async def get(self, mcp_id: str, user_sub: str) -> MCPClient | None: diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index 9d21433b..25f2f6d4 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -61,13 +61,15 @@ async def push_init_message( async def push_rag_message( task: Task, queue: MessageQueue, user_sub: str, llm: LLM, history: list[dict[str, str]], doc_ids: list[str], - rag_data: RAGQueryReq,) -> None: + rag_data: RAGQueryReq) 
-> None: """推送RAG消息""" full_answer = "" try: async for chunk in RAG.chat_with_llm_base_on_rag(user_sub, llm, history, doc_ids, rag_data): task, content_obj = await _push_rag_chunk(task, queue, chunk) + if not content_obj: + continue if content_obj.event_type == EventType.TEXT_ADD.value: # 如果是文本消息,直接拼接到答案中 full_answer += content_obj.content @@ -83,20 +85,24 @@ async def push_rag_message( await TaskManager.save_task(task.id, task) -async def _push_rag_chunk(task: Task, queue: MessageQueue, content: str) -> tuple[Task, RAGEventData]: +async def _push_rag_chunk(task: Task, queue: MessageQueue, content: str) -> tuple[Task, RAGEventData | None]: """推送RAG单个消息块""" # 如果是换行 if not content or not content.rstrip().rstrip("\n"): - return task, "" + return task, None try: content_obj = RAGEventData.model_validate_json(dedent(content[6:]).rstrip("\n")) # 如果是空消息 if not content_obj.content: - return task, "" + return task, None - task.tokens.input_tokens = content_obj.input_tokens - task.tokens.output_tokens = content_obj.output_tokens + await TaskManager.update_task_token( + task.id, + content_obj.input_tokens, + content_obj.output_tokens, + override=True, + ) await TaskManager.save_task(task.id, task) # 推送消息 @@ -118,11 +124,11 @@ async def _push_rag_chunk(task: Task, queue: MessageQueue, content: str) -> tupl documentAbstract=content_obj.content.get("abstract", ""), documentType=content_obj.content.get("extension", ""), documentSize=content_obj.content.get("size", 0), - createdAt=round(datetime.now(tz=UTC).timestamp(), 3), + createdAt=round(content_obj.content.get("created_at", datetime.now(tz=UTC).timestamp()), 3), ).model_dump(exclude_none=True, by_alias=True), ) except Exception: logger.exception("[Scheduler] RAG服务返回错误数据") - return task, "" + return task, None else: return task, content_obj diff --git a/apps/schemas/flow_topology.py b/apps/schemas/flow_topology.py index 768a1c4d..da3f0b07 100644 --- a/apps/schemas/flow_topology.py +++ b/apps/schemas/flow_topology.py @@ -46,8 +46,8 @@ class PositionItem(BaseModel): class NodeItem(BaseModel): """请求/响应中的节点变量类""" - step_id: str = Field(alias="stepId", default="") - service_id: str = Field(alias="serviceId", default="") + step_id: uuid.UUID = Field(alias="stepId", default=uuid.uuid4()) + service_id: uuid.UUID = Field(alias="serviceId", default=uuid.uuid4()) node_id: str = Field(alias="nodeId", default="") name: str = Field(default="") call_id: str = Field(alias="callId", default=SpecialCallType.EMPTY.value) @@ -69,7 +69,7 @@ class EdgeItem(BaseModel): class FlowItem(BaseModel): """请求/响应中的流变量类""" - flow_id: uuid.UUID = Field(alias="flowId", default=uuid.UUID("00000000-0000-0000-0000-000000000000")) + flow_id: str = Field(alias="flowId", default="") name: str = Field(default="工作流名称") description: str = Field(default="工作流描述") enable: bool = Field(default=True) diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index 4aa17731..c03e63ea 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -147,6 +147,7 @@ class PutFlowReq(BaseModel): class UpdateLLMReq(BaseModel): """更新大模型请求体""" + llm_id: str | None = Field(default=None, description="大模型ID", alias="id") icon: str = Field(description="图标", default="") openai_base_url: str = Field(default="", description="OpenAI API Base URL", alias="openaiBaseUrl") openai_api_key: str = Field(default="", description="OpenAI API Key", alias="openaiApiKey") diff --git a/apps/schemas/response_data.py b/apps/schemas/response_data.py index 2ba8f82f..d672b831 100644 --- 
a/apps/schemas/response_data.py +++ b/apps/schemas/response_data.py @@ -570,7 +570,7 @@ class ListLLMProviderRsp(ResponseData): class LLMProviderInfo(BaseModel): """LLM数据结构""" - llm_id: uuid.UUID = Field(alias="llmId", description="LLM ID") + llm_id: str = Field(alias="llmId", description="LLM ID") icon: str = Field(default="", description="LLM图标", max_length=25536) openai_base_url: str = Field( default="https://api.openai.com/v1", diff --git a/apps/services/flow.py b/apps/services/flow.py index 92052a92..5a6cc0f6 100644 --- a/apps/services/flow.py +++ b/apps/services/flow.py @@ -161,7 +161,7 @@ class FlowManager: @staticmethod - async def get_flow_by_app_and_flow_id(app_id: uuid.UUID, flow_id: uuid.UUID) -> FlowItem | None: + async def get_flow_by_app_and_flow_id(app_id: uuid.UUID, flow_id: str) -> FlowItem | None: """ 通过appId flowId获取flow config的路径和focus,并通过flow config的路径获取flow config,并将其转换为flow item。 @@ -282,7 +282,7 @@ class FlowManager: @staticmethod async def put_flow_by_app_and_flow_id( - app_id: uuid.UUID, flow_id: uuid.UUID, flow_item: FlowItem, + app_id: uuid.UUID, flow_id: str, flow_item: FlowItem, ) -> None: """ 存储/更新flow的数据库数据和配置文件 @@ -345,7 +345,7 @@ class FlowManager: @staticmethod - async def delete_flow_by_app_and_flow_id(app_id: uuid.UUID, flow_id: uuid.UUID) -> None: + async def delete_flow_by_app_and_flow_id(app_id: uuid.UUID, flow_id: str) -> None: """ 删除flow的数据库数据和配置文件 @@ -375,7 +375,7 @@ class FlowManager: @staticmethod - async def update_flow_debug_by_app_and_flow_id(app_id: uuid.UUID, flow_id: uuid.UUID, *, debug: bool) -> bool: + async def update_flow_debug_by_app_and_flow_id(app_id: uuid.UUID, flow_id: str, *, debug: bool) -> bool: """ 更新flow的debug状态 diff --git a/apps/services/llm.py b/apps/services/llm.py index 256a87a1..dd095fe0 100644 --- a/apps/services/llm.py +++ b/apps/services/llm.py @@ -2,7 +2,6 @@ """大模型管理""" import logging -import uuid from sqlalchemy import select @@ -41,7 +40,7 @@ class LLMManager: @staticmethod - async def get_user_default_llm(user_sub: str) -> uuid.UUID | None: + async def get_user_default_llm(user_sub: str) -> str | None: """ 通过用户ID获取大模型ID @@ -60,7 +59,7 @@ class LLMManager: @staticmethod - async def get_llm(llm_id: uuid.UUID) -> LLMData | None: + async def get_llm(llm_id: str) -> LLMData | None: """ 通过ID获取大模型 @@ -81,7 +80,7 @@ class LLMManager: @staticmethod - async def list_llm(llm_id: uuid.UUID | None) -> list[LLMProviderInfo]: + async def list_llm(llm_id: str | None) -> list[LLMProviderInfo]: """ 获取大模型列表 @@ -119,7 +118,7 @@ class LLMManager: @staticmethod - async def update_llm(llm_id: uuid.UUID | None, req: UpdateLLMReq) -> uuid.UUID: + async def update_llm(llm_id: str, req: UpdateLLMReq) -> str: """ 创建大模型 @@ -144,6 +143,7 @@ class LLMManager: await session.commit() else: llm = LLMData( + id=llm_id, icon=req.icon, openaiBaseUrl=req.openai_base_url, openaiAPIKey=req.openai_api_key, @@ -156,7 +156,7 @@ class LLMManager: @staticmethod - async def delete_llm(user_sub: str, llm_id: uuid.UUID | None) -> None: + async def delete_llm(user_sub: str, llm_id: str | None) -> None: """ 删除大模型 @@ -192,7 +192,7 @@ class LLMManager: @staticmethod async def update_user_default_llm( user_sub: str, - llm_id: uuid.UUID, + llm_id: str, ) -> None: """更新用户的默认LLM""" async with postgres.session() as session: diff --git a/apps/services/rag.py b/apps/services/rag.py index f73ce3fd..caf2930c 100644 --- a/apps/services/rag.py +++ b/apps/services/rag.py @@ -70,7 +70,7 @@ class RAG: openEuler社区的目标是为用户提供一个稳定、安全、高效的操作系统平台,并且支持多种硬件架构。[[1]] - + {bac_info} @@ 
-173,7 +173,10 @@ class RAG: return doc_chunk_list @staticmethod - async def assemble_doc_info(doc_chunk_list: list[dict[str, Any]], max_tokens: int) -> str: + async def assemble_doc_info( + doc_chunk_list: list[dict[str, Any]], + max_tokens: int, + ) -> tuple[str, list[dict[str, Any]]]: """组装文档信息""" bac_info = "" doc_info_list = [] @@ -251,9 +254,9 @@ class RAG: return bac_info, doc_info_list @staticmethod - async def chat_with_llm_base_on_rag( + async def chat_with_llm_base_on_rag( # noqa: C901, PLR0913 user_sub: str, - llm: LLM, + llm: LLMData, history: list[dict[str, str]], doc_ids: list[str], data: RAGQueryReq, @@ -262,10 +265,10 @@ class RAG: """获取RAG服务的结果""" reasion_llm = ReasoningLLM( LLMConfig( - endpoint=llm.openai_base_url, - key=llm.openai_api_key, - model=llm.model_name, - max_tokens=llm.max_tokens, + endpoint=llm.openaiBaseUrl, + key=llm.openaiAPIKey, + model=llm.modelName, + max_tokens=llm.maxToken, ), ) if history: @@ -277,9 +280,9 @@ class RAG: except Exception: logger.exception("[RAG] 问题重写失败") doc_chunk_list = await RAG.get_doc_info_from_rag( - user_sub=user_sub, max_tokens=llm.max_tokens, doc_ids=doc_ids, data=data) + user_sub=user_sub, max_tokens=llm.maxToken, doc_ids=doc_ids, data=data) bac_info, doc_info_list = await RAG.assemble_doc_info( - doc_chunk_list=doc_chunk_list, max_tokens=llm.max_tokens) + doc_chunk_list=doc_chunk_list, max_tokens=llm.maxToken) messages = [ *history, { @@ -296,7 +299,7 @@ class RAG: ] input_tokens = TokenCalculator().calculate_token_length(messages=messages) output_tokens = 0 - doc_cnt = 0 + doc_cnt: int = 0 for doc_info in doc_info_list: doc_cnt = max(doc_cnt, doc_info["order"]) yield ( @@ -313,40 +316,41 @@ class RAG: + "\n\n" ) max_footnote_length = 4 - while doc_cnt > 0: - doc_cnt //= 10 + tmp_doc_cnt = doc_cnt + while tmp_doc_cnt > 0: + tmp_doc_cnt //= 10 max_footnote_length += 1 buffer = "" async for chunk in reasion_llm.call( messages, - max_tokens=llm.max_tokens, + max_tokens=llm.maxToken, streaming=True, temperature=0.7, result_only=False, - model=llm.model_name, + model=llm.modelName, ): - chunk = buffer + chunk + tmp_chunk = buffer + chunk # 防止脚注被截断 - if len(chunk) >= 2 and chunk[-2:] != "]]": - index = len(chunk) - 1 - while index >= max(0, len(chunk) - max_footnote_length) and chunk[index] != "]": + if len(tmp_chunk) >= 2 and tmp_chunk[-2:] != "]]": + index = len(tmp_chunk) - 1 + while index >= max(0, len(tmp_chunk) - max_footnote_length) and tmp_chunk[index] != "]": index -= 1 if index >= 0: - buffer = chunk[index + 1:] - chunk = chunk[:index + 1] + buffer = tmp_chunk[index + 1:] + tmp_chunk = tmp_chunk[:index + 1] else: buffer = "" # 匹配脚注 - footnotes = re.findall(r"\[\[\d+\]\]", chunk) + footnotes = re.findall(r"\[\[\d+\]\]", tmp_chunk) # 去除编号大于doc_cnt的脚注 footnotes = [fn for fn in footnotes if int(fn[2:-2]) > doc_cnt] footnotes = list(set(footnotes)) # 去重 if footnotes: for fn in footnotes: - chunk = chunk.replace(fn, "") + tmp_chunk = tmp_chunk.replace(fn, "") output_tokens += TokenCalculator().calculate_token_length( messages=[ - {"role": "assistant", "content": chunk}, + {"role": "assistant", "content": tmp_chunk}, ], pure_text=True, ) @@ -355,7 +359,7 @@ class RAG: + json.dumps( { "event_type": EventType.TEXT_ADD.value, - "content": chunk, + "content": tmp_chunk, "input_tokens": input_tokens, "output_tokens": output_tokens, }, diff --git a/apps/services/record.py b/apps/services/record.py index ae788518..c38ffac9 100644 --- a/apps/services/record.py +++ b/apps/services/record.py @@ -80,9 +80,14 @@ class RecordManager: 
order: Literal["desc", "asc"] = "desc", ) -> list[Record]: """查询ConversationID的最后n条问答对""" - sort_order = -1 if order == "desc" else 1 - async with postgres.session() as session: + sql = select(PgRecord).where( + PgRecord.conversationId == conversation_id, + PgRecord.userSub == user_sub, + ).order_by(PgRecord.createdAt.desc() if order == "desc" else PgRecord.createdAt.asc()) + if total_pairs is not None: + sql = sql.limit(total_pairs) + result = await session.scalars(sql) # 得到conversation的全部record_group id record_groups = await record_group_collection.aggregate( [ -- Gitee