From 680cc0384fba8f08dcda242bde6855666839ba7b Mon Sep 17 00:00:00 2001 From: z30057876 Date: Wed, 6 Aug 2025 11:37:25 +0800 Subject: [PATCH] =?UTF-8?q?=E5=90=88=E5=B9=B6Agent=E5=88=86=E6=94=AF?= =?UTF-8?q?=E6=94=B9=E5=8A=A8=E8=87=B3=20!614?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/routers/chat.py | 5 +- apps/routers/mcp_service.py | 2 +- apps/routers/record.py | 8 +- apps/scheduler/call/choice/choice.py | 2 +- .../call/choice/condition_handler.py | 59 +-- apps/scheduler/call/choice/schema.py | 1 + apps/scheduler/call/core.py | 4 +- apps/scheduler/executor/flow.py | 36 +- apps/scheduler/mcp/host.py | 2 +- apps/scheduler/mcp_agent/host.py | 4 +- apps/scheduler/mcp_agent/plan.py | 188 ++++++-- apps/scheduler/mcp_agent/prompt.py | 440 ++++++++++++++++-- apps/scheduler/mcp_agent/select.py | 29 +- apps/scheduler/pool/loader/mcp.py | 16 +- apps/scheduler/scheduler/context.py | 10 +- apps/scheduler/scheduler/message.py | 26 +- apps/scheduler/scheduler/scheduler.py | 56 ++- apps/scheduler/slot/slot.py | 190 ++++++-- apps/schemas/enum_var.py | 7 +- apps/schemas/flow_topology.py | 6 +- apps/schemas/mcp.py | 27 +- apps/schemas/record.py | 14 +- apps/schemas/request_data.py | 1 + apps/schemas/task.py | 20 +- apps/services/flow_validate.py | 34 +- apps/services/mcp_service.py | 4 +- apps/services/node.py | 4 + apps/services/parameter.py | 15 +- apps/services/rag.py | 3 - apps/services/task.py | 3 +- 30 files changed, 985 insertions(+), 231 deletions(-) diff --git a/apps/routers/chat.py b/apps/routers/chat.py index b1e92b8aa..840c2fb9a 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -13,6 +13,7 @@ from apps.common.wordscheck import WordsCheck from apps.dependency import verify_personal_token, verify_session from apps.scheduler.scheduler import Scheduler from apps.scheduler.scheduler.context import save_data +from apps.schemas.enum_var import FlowStatus from apps.schemas.request_data import RequestData from apps.schemas.response_data import ResponseData from apps.schemas.task import Task @@ -82,8 +83,8 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) # 获取最终答案 task = scheduler.task - if not task.runtime.answer: - logger.error("[Chat] 答案为空") + if task.state.flow_status == FlowStatus.ERROR: + logger.error("[Chat] 生成答案失败") yield "data: [ERROR]\n\n" await Activity.remove_active(user_sub) return diff --git a/apps/routers/mcp_service.py b/apps/routers/mcp_service.py index ab01d52c5..dd0c0765c 100644 --- a/apps/routers/mcp_service.py +++ b/apps/routers/mcp_service.py @@ -281,7 +281,7 @@ async def active_or_deactivate_mcp_service( """激活/取消激活mcp""" try: if data.active: - await MCPServiceManager.active_mcpservice(request.state.user_sub, mcpId) + await MCPServiceManager.active_mcpservice(request.state.user_sub, service_id, data.mcp_env) else: await MCPServiceManager.deactive_mcpservice(request.state.user_sub, mcpId) except Exception as e: diff --git a/apps/routers/record.py b/apps/routers/record.py index 4f424e814..02c629d02 100644 --- a/apps/routers/record.py +++ b/apps/routers/record.py @@ -84,16 +84,14 @@ async def get_record(request: Request, conversationId: Annotated[uuid.UUID, Path # 获得Record关联的文档 tmp_record.document = await DocumentManager.get_used_docs_by_record(user_sub, record_group.id) - # 获得Record关联的flow数据 flow_step_list = await TaskManager.get_context_by_record_id(record_group.id, record.id) if flow_step_list: - first_step_history = FlowStepHistory.model_validate(flow_step_list[0]) tmp_record.flow = RecordFlow( - id=first_step_history.flow_name, # TODO: 此处前端应该用name + id=record.flow.flow_id, # TODO: 此处前端应该用name recordId=record.id, - flowStatus=first_step_history.flow_status, - flowId=first_step_history.id, + flowStatus=record.flow.flow_staus, + flowId=record.flow.flow_id, stepNum=len(flow_step_list), steps=[], ) diff --git a/apps/scheduler/call/choice/choice.py b/apps/scheduler/call/choice/choice.py index 0d9ace600..c78cd4ab5 100644 --- a/apps/scheduler/call/choice/choice.py +++ b/apps/scheduler/call/choice/choice.py @@ -113,7 +113,7 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): valid_conditions.append(condition) # 如果所有条件都无效,抛出异常 - if not valid_conditions: + if not valid_conditions and not choice.is_default: msg = "分支没有有效条件" logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") continue diff --git a/apps/scheduler/call/choice/condition_handler.py b/apps/scheduler/call/choice/condition_handler.py index f8a9e40e3..3ee048252 100644 --- a/apps/scheduler/call/choice/condition_handler.py +++ b/apps/scheduler/call/choice/condition_handler.py @@ -1,12 +1,16 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """处理条件分支的工具""" - import logging from pydantic import BaseModel -from apps.scheduler.call.choice.schema import ChoiceBranch, Condition, Logic, Value +from apps.scheduler.call.choice.schema import ( + ChoiceBranch, + Condition, + Logic, + Value, +) from apps.schemas.parameters import ( BoolOperate, DictOperate, @@ -69,15 +73,17 @@ class ConditionHandler(BaseModel): @staticmethod def handler(choices: list[ChoiceBranch]) -> str: """处理条件""" - default_branch = [c for c in choices if c.is_default] - - for block_judgement in choices: + for block_judgement in choices[::-1]: results = [] if block_judgement.is_default: - continue + return block_judgement.branch_id for condition in block_judgement.conditions: result = ConditionHandler._judge_condition(condition) - results.append(result) + if result is not None: + results.append(result) + if not results: + logger.warning(f"[Choice] 分支 {block_judgement.branch_id} 条件处理失败: 没有有效的条件") + continue if block_judgement.logic == Logic.AND: final_result = all(results) elif block_judgement.logic == Logic.OR: @@ -85,10 +91,6 @@ class ConditionHandler(BaseModel): if final_result: return block_judgement.branch_id - - # 如果没有匹配的分支,选择默认分支 - if default_branch: - return default_branch[0].branch_id return "" @staticmethod @@ -112,7 +114,7 @@ class ConditionHandler(BaseModel): if value_type == Type.STRING: result = ConditionHandler._judge_string_condition(left, operate, right) elif value_type == Type.NUMBER: - result = ConditionHandler._judge_int_condition(left, operate, right) + result = ConditionHandler._judge_number_condition(left, operate, right) elif value_type == Type.BOOL: result = ConditionHandler._judge_bool_condition(left, operate, right) elif value_type == Type.LIST: @@ -120,9 +122,9 @@ class ConditionHandler(BaseModel): elif value_type == Type.DICT: result = ConditionHandler._judge_dict_condition(left, operate, right) else: - logger.error("不支持的数据类型: %s", value_type) msg = f"不支持的数据类型: {value_type}" - raise ValueError(msg) + logger.error(f"[Choice] 条件处理失败: {msg}") + return None return result @staticmethod @@ -141,11 +143,10 @@ class ConditionHandler(BaseModel): """ left_value = left.value if not isinstance(left_value, str): - logger.error("左值不是字符串类型: %s", left_value) - msg = "左值必须是字符串类型" - raise TypeError(msg) + msg = f"左值必须是字符串类型 ({left_value})" + logger.warning(msg) + return None right_value = right.value - result = False if operate == StringOperate.EQUAL: return left_value == right_value elif operate == StringOperate.NOT_EQUAL: @@ -189,9 +190,9 @@ class ConditionHandler(BaseModel): """ left_value = left.value if not isinstance(left_value, (int, float)): - logger.error("左值不是数字类型: %s", left_value) - msg = "左值必须是数字类型" - raise TypeError(msg) + msg = f"左值必须是数字类型 ({left_value})" + logger.warning(msg) + return None right_value = right.value if operate == NumberOperate.EQUAL: return left_value == right_value @@ -223,9 +224,9 @@ class ConditionHandler(BaseModel): """ left_value = left.value if not isinstance(left_value, bool): - logger.error("左值不是布尔类型: %s", left_value) msg = "左值必须是布尔类型" - raise TypeError(msg) + logger.warning(msg) + return None right_value = right.value if operate == BoolOperate.EQUAL: return left_value == right_value @@ -253,9 +254,9 @@ class ConditionHandler(BaseModel): """ left_value = left.value if not isinstance(left_value, list): - logger.error("左值不是列表类型: %s", left_value) - msg = "左值必须是列表类型" - raise TypeError(msg) + msg = f"左值必须是列表类型 ({left_value})" + logger.warning(msg) + return None right_value = right.value if operate == ListOperate.EQUAL: return left_value == right_value @@ -293,9 +294,9 @@ class ConditionHandler(BaseModel): """ left_value = left.value if not isinstance(left_value, dict): - logger.error("左值不是字典类型: %s", left_value) - msg = "左值必须是字典类型" - raise TypeError(msg) + msg = f"左值必须是字典类型 ({left_value})" + logger.warning(msg) + return None right_value = right.value if operate == DictOperate.EQUAL: return left_value == right_value diff --git a/apps/scheduler/call/choice/schema.py b/apps/scheduler/call/choice/schema.py index 9e886c577..1be23761b 100644 --- a/apps/scheduler/call/choice/schema.py +++ b/apps/scheduler/call/choice/schema.py @@ -28,6 +28,7 @@ class Value(DataBase): step_id: str | None = Field(description="步骤id", default=None) type: Type | None = Field(description="值的类型", default=None) + name: str | None = Field(description="值的名称", default=None) value: str | float | int | bool | list | dict | None = Field(description="值", default=None) diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index 91f22409f..ec51c45f3 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -131,7 +131,7 @@ class CoreCall(BaseModel): :return: 变量 """ split_path = path.split("/") - if len(split_path) < 2: + if len(split_path) < 1: err = f"[CoreCall] 路径格式错误: {path}" logger.error(err) return None @@ -141,7 +141,7 @@ class CoreCall(BaseModel): return None data = history[split_path[0]].output_data - for key in split_path[2:]: + for key in split_path[1:]: if key not in data: err = f"[CoreCall] 输出Key {key} 不存在" logger.error(err) diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index 89f8326b8..0f4978a1e 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -47,6 +47,10 @@ class FlowExecutor(BaseExecutor): flow_id: str = Field(description="Flow ID") question: str = Field(description="用户输入") post_body_app: RequestDataApp = Field(description="请求体中的app信息") + current_step: StepQueueItem | None = Field( + description="当前执行的步骤", + default=None, + ) async def load_state(self) -> None: @@ -73,13 +77,13 @@ class FlowExecutor(BaseExecutor): self.step_queue: deque[StepQueueItem] = deque() - async def _invoke_runner(self, queue_item: StepQueueItem) -> None: + async def _invoke_runner(self) -> None: """单一Step执行""" # 创建步骤Runner step_runner = StepExecutor( msg_queue=self.msg_queue, task=self.task, - step=queue_item, + step=self.current_step, background=self.background, question=self.question, ) @@ -97,12 +101,12 @@ class FlowExecutor(BaseExecutor): """执行当前queue里面的所有步骤(在用户看来是单一Step)""" while True: try: - queue_item = self.step_queue.pop() + self.current_step = self.step_queue.pop() except IndexError: break # 执行Step - await self._invoke_runner(queue_item) + await self._invoke_runner() async def _find_next_id(self, step_id: str) -> list[str]: @@ -119,17 +123,17 @@ class FlowExecutor(BaseExecutor): # 如果当前步骤为结束,则直接返回 if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type] return [] - if self.task.state.step_name == "Choice": + if self.current_step.step.type == SpecialCallType.CHOICE.value: # 如果是choice节点,获取分支ID branch_id = self.task.context[-1]["output_data"]["branch_id"] if branch_id: - self.task.state.step_id = self.task.state.step_id + "." + branch_id + next_steps = await self._find_next_id(self.task.state.step_id + "." + branch_id) logger.info("[FlowExecutor] 分支ID:%s", branch_id) else: logger.warning("[FlowExecutor] 没有找到分支ID,返回空列表") return [] - - next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type] + else: + next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type] # 如果step没有任何出边,直接跳到end if not next_steps: return [ @@ -180,6 +184,7 @@ class FlowExecutor(BaseExecutor): self.task.state.flow_status = FlowStatus.RUNNING # type: ignore[arg-type] # 运行Flow(未达终点) + is_error = False while not self._reached_end: # 如果当前步骤出错,执行错误处理步骤 if self.task.state.step_status == StepStatus.ERROR: # type: ignore[arg-type] @@ -202,7 +207,7 @@ class FlowExecutor(BaseExecutor): enable_filling=False, to_user=False, )) - self.task.state.flow_status = FlowStatus.ERROR # type: ignore[arg-type] + is_error = True # 错误处理后结束 self._reached_end = True @@ -217,6 +222,12 @@ class FlowExecutor(BaseExecutor): for step in next_step: self.step_queue.append(step) + # 更新Task状态 + if is_error: + self.task.state.flow_status = FlowStatus.ERROR # type: ignore[arg-type] + else: + self.task.state.flow_status = FlowStatus.SUCCESS # type: ignore[arg-type] + # 尾插运行结束后的系统步骤 for step in FIXED_STEPS_AFTER_END: self.step_queue.append(StepQueueItem( @@ -228,6 +239,7 @@ class FlowExecutor(BaseExecutor): # FlowStop需要返回总时间,需要倒推最初的开始时间(当前时间减去当前已用总时间) self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.full_time # 推送Flow停止消息 - await self.push_message(EventType.FLOW_STOP.value) - # 更新Flow状态 - self.task.state.flow_status = FlowStatus.SUCCESS # type: ignore[arg-type] + if is_error: + await self.push_message(EventType.FLOW_FAILED.value) + else: + await self.push_message(EventType.FLOW_SUCCESS.value) diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py index 941e8822f..6b78e65db 100644 --- a/apps/scheduler/mcp/host.py +++ b/apps/scheduler/mcp/host.py @@ -101,7 +101,7 @@ class MCPHost: task_id=self._task_id, flow_id=self._runtime_id, flow_name=self._runtime_name, - flow_status=StepStatus.SUCCESS, + flow_status=StepStatus.RUNNING, step_id=tool.id, step_name=tool.toolName, # description是规划的实际内容 diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py index acdd4871e..a8ebec7b4 100644 --- a/apps/scheduler/mcp_agent/host.py +++ b/apps/scheduler/mcp_agent/host.py @@ -123,7 +123,7 @@ class MCPHost: return output_data - async def _fill_params(self, tool: MCPTool, query: str) -> dict[str, Any]: + async def _fill_params(self, schema: dict[str, Any], query: str) -> dict[str, Any]: """填充工具参数""" # 更清晰的输入·指令,这样可以调用generate llm_query = rf""" @@ -139,7 +139,7 @@ class MCPHost: {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": await self.assemble_memory()}, ], - tool.input_schema, + schema, ) return await json_generator.generate() diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py index cd4f5975e..ec90f57b8 100644 --- a/apps/scheduler/mcp_agent/plan.py +++ b/apps/scheduler/mcp_agent/plan.py @@ -1,19 +1,29 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """MCP 用户目标拆解与规划""" +from typing import Any from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment from apps.llm.function import JsonGenerator from apps.llm.reasoning import ReasoningLLM -from apps.scheduler.mcp.prompt import CREATE_PLAN, FINAL_ANSWER -from apps.schemas.mcp import MCPPlan, MCPTool +from apps.scheduler.mcp_agent.prompt import ( + CREATE_PLAN, + EVALUATE_GOAL, + FINAL_ANSWER, + GENERATE_FLOW_NAME, + GET_MISSING_PARAMS, + RECREATE_PLAN, + RISK_EVALUATE, +) +from apps.scheduler.slot.slot import Slot +from apps.schemas.mcp import GoalEvaluationResult, MCPPlan, MCPTool, ToolRisk class MCPPlanner: """MCP 用户目标拆解与规划""" - def __init__(self, user_goal: str) -> None: + def __init__(self, user_goal: str, resoning_llm: ReasoningLLM = None) -> None: """初始化MCP规划器""" self.user_goal = user_goal self._env = SandboxedEnvironment( @@ -22,37 +32,19 @@ class MCPPlanner: trim_blocks=True, lstrip_blocks=True, ) + self.resoning_llm = resoning_llm or ReasoningLLM() self.input_tokens = 0 self.output_tokens = 0 - - async def create_plan(self, tool_list: list[MCPTool], max_steps: int = 6) -> MCPPlan: - """规划下一步的执行流程,并输出""" - # 获取推理结果 - result = await self._get_reasoning_plan(tool_list, max_steps) - - # 解析为结构化数据 - return await self._parse_plan_result(result, max_steps) - - - async def _get_reasoning_plan(self, tool_list: list[MCPTool], max_steps: int) -> str: - """获取推理大模型的结果""" - # 格式化Prompt - template = self._env.from_string(CREATE_PLAN) - prompt = template.render( - goal=self.user_goal, - tools=tool_list, - max_num=max_steps, - ) - + async def get_resoning_result(self, prompt: str) -> str: + """获取推理结果""" # 调用推理大模型 message = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}, ] - reasoning_llm = ReasoningLLM() result = "" - async for chunk in reasoning_llm.call( + async for chunk in self.resoning_llm.call( message, streaming=False, temperature=0.07, @@ -61,19 +53,12 @@ class MCPPlanner: result += chunk # 保存token用量 - self.input_tokens = reasoning_llm.input_tokens - self.output_tokens = reasoning_llm.output_tokens - + self.input_tokens = self.resoning_llm.input_tokens + self.output_tokens = self.resoning_llm.output_tokens return result - - async def _parse_plan_result(self, result: str, max_steps: int) -> MCPPlan: - """将推理结果解析为结构化数据""" - # 格式化Prompt - schema = MCPPlan.model_json_schema() - schema["properties"]["plans"]["maxItems"] = max_steps - - # 使用Function模型解析结果 + async def _parse_result(self, result: str, schema: dict[str, Any]) -> str: + """解析推理结果""" json_generator = JsonGenerator( result, [ @@ -82,9 +67,138 @@ class MCPPlanner: ], schema, ) - plan = await json_generator.generate() + json_result = await json_generator.generate() + return json_result + + async def evaluate_goal(self, tool_list: list[MCPTool]) -> GoalEvaluationResult: + """评估用户目标的可行性""" + # 获取推理结果 + result = await self._get_reasoning_evaluation(tool_list) + + # 解析为结构化数据 + evaluation = await self._parse_evaluation_result(result) + + # 返回评估结果 + return evaluation + + async def _get_reasoning_evaluation(self, tool_list: list[MCPTool]) -> str: + """获取推理大模型的评估结果""" + template = self._env.from_string(EVALUATE_GOAL) + prompt = template.render( + goal=self.user_goal, + tools=tool_list, + ) + result = await self.get_resoning_result(prompt) + return result + + async def _parse_evaluation_result(self, result: str) -> GoalEvaluationResult: + """将推理结果解析为结构化数据""" + schema = GoalEvaluationResult.model_json_schema() + evaluation = await self._parse_result(result, schema) + # 使用GoalEvaluationResult模型解析结果 + return GoalEvaluationResult.model_validate(evaluation) + + async def get_flow_name(self) -> str: + """获取当前流程的名称""" + result = await self._get_reasoning_flow_name() + return result + + async def _get_reasoning_flow_name(self) -> str: + """获取推理大模型的流程名称""" + template = self._env.from_string(GENERATE_FLOW_NAME) + prompt = template.render(goal=self.user_goal) + result = await self.get_resoning_result(prompt) + return result + + async def create_plan(self, tool_list: list[MCPTool], max_steps: int = 6) -> MCPPlan: + """规划下一步的执行流程,并输出""" + # 获取推理结果 + result = await self._get_reasoning_plan(tool_list, max_steps) + + # 解析为结构化数据 + return await self._parse_plan_result(result, max_steps) + + async def _get_reasoning_plan( + self, is_replan: bool = False, error_message: str = "", current_plan: MCPPlan = MCPPlan(), + tool_list: list[MCPTool] = [], + max_steps: int = 10) -> str: + """获取推理大模型的结果""" + # 格式化Prompt + if is_replan: + template = self._env.from_string(RECREATE_PLAN) + prompt = template.render( + current_plan=current_plan, + error_message=error_message, + goal=self.user_goal, + tools=tool_list, + max_num=max_steps, + ) + else: + template = self._env.from_string(CREATE_PLAN) + prompt = template.render( + goal=self.user_goal, + tools=tool_list, + max_num=max_steps, + ) + result = await self.get_resoning_result(prompt) + return result + + async def _parse_plan_result(self, result: str, max_steps: int) -> MCPPlan: + """将推理结果解析为结构化数据""" + # 格式化Prompt + schema = MCPPlan.model_json_schema() + schema["properties"]["plans"]["maxItems"] = max_steps + plan = await self._parse_result(result, schema) + # 使用Function模型解析结果 return MCPPlan.model_validate(plan) + async def get_tool_risk(self, tool: MCPTool, input_parm: dict[str, Any], additional_info: str = "") -> ToolRisk: + """获取MCP工具的风险评估结果""" + # 获取推理结果 + result = await self._get_reasoning_risk(tool, input_parm, additional_info) + + # 解析为结构化数据 + risk = await self._parse_risk_result(result) + + # 返回风险评估结果 + return risk + + async def _get_reasoning_risk(self, tool: MCPTool, input_param: dict[str, Any], additional_info: str) -> str: + """获取推理大模型的风险评估结果""" + template = self._env.from_string(RISK_EVALUATE) + prompt = template.render( + tool=tool, + input_param=input_param, + additional_info=additional_info, + ) + result = await self.get_resoning_result(prompt) + return result + + async def _parse_risk_result(self, result: str) -> ToolRisk: + """将推理结果解析为结构化数据""" + schema = ToolRisk.model_json_schema() + risk = await self._parse_result(result, schema) + # 使用ToolRisk模型解析结果 + return ToolRisk.model_validate(risk) + + async def get_missing_param( + self, tool: MCPTool, schema: dict[str, Any], + input_param: dict[str, Any], + error_message: str) -> list[str]: + """获取缺失的参数""" + slot = Slot(schema=schema) + schema_with_null = slot.add_null_to_basic_types() + template = self._env.from_string(GET_MISSING_PARAMS) + prompt = template.render( + tool=tool, + input_param=input_param, + schema=schema_with_null, + error_message=error_message, + ) + result = await self.get_resoning_result(prompt) + # 解析为结构化数据 + input_param_with_null = await self._parse_result(result, schema_with_null) + return input_param_with_null async def generate_answer(self, plan: MCPPlan, memory: str) -> str: """生成最终回答""" diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py index b322fb088..cf05afc87 100644 --- a/apps/scheduler/mcp_agent/prompt.py +++ b/apps/scheduler/mcp_agent/prompt.py @@ -62,6 +62,76 @@ MCP_SELECT = dedent(r""" ### 请一步一步思考: """) +EVALUATE_GOAL = dedent(r""" + 你是一个计划评估器。 + 请根据用户的目标和当前的工具集合以及一些附加信息,判断基于当前的工具集合,是否能够完成用户的目标。 + 如果能够完成,请返回`true`,否则返回`false`。 + 推理过程必须清晰明了,能够让人理解你的判断依据。 + 必须按照以下格式回答: + ```json + { + "can_complete": true/false, + "resoning": "你的推理过程" + } + ``` + + # 样例 + ## 目标 + 我需要扫描当前mysql数据库,分析性能瓶颈,并调优 + + ## 工具集合 + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + + - mysql_analyzer分析MySQL数据库性能 + - performance_tuner调优数据库性能 + - Final结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。 + + + ## 附加信息 + 1. 当前MySQL数据库的版本是8.0.26 + 2. 当前MySQL数据库的配置文件路径是/etc/my.cnf + + ## + ```json + { + "can_complete": true, + "resoning": "当前的工具集合中包含mysql_analyzer和performance_tuner,能够完成对MySQL数据库的性能分析和调优,因此可以完成用户的目标。" + } + ``` + + # 目标 + {{ goal }} + + # 工具集合 + + {% for tool in tools %} + - {{ tool.id }}{{tool.name}};{{ tool.description }} + {% endfor %} + + + # 附加信息 + {{ additional_info }} + +""") +GENERATE_FLOW_NAME = dedent(r""" + 你是一个智能助手,你的任务是根据用户的目标,生成一个合适的流程名称。 + + # 生成流程名称时的注意事项: + 1. 流程名称应该简洁明了,能够准确表达达成用户目标的过程。 + 2. 流程名称应该包含关键的操作或步骤,例如“扫描”、“分析”、“调优”等。 + 3. 流程名称应该避免使用过于复杂或专业的术语,以便用户能够理解。 + 4. 流程名称应该尽量简短,小于20个字或者单词。 + 5. 只输出流程名称,不要输出其他内容。 + # 样例 + ## 目标 + 我需要扫描当前mysql数据库,分析性能瓶颈,并调优 + ## 输出 + 扫描MySQL数据库并分析性能瓶颈,进行调优 + # 现在开始生成流程名称: + # 目标 + {{ goal }} + # 输出 + """) CREATE_PLAN = dedent(r""" 你是一个计划生成器。 请分析用户的目标,并生成一个计划。你后续将根据这个计划,一步一步地完成用户的目标。 @@ -106,8 +176,6 @@ CREATE_PLAN = dedent(r""" {% for tool in tools %} - {{ tool.id }}{{tool.name}};{{ tool.description }} {% endfor %} - - Final结束步骤,当执行到这一步时,\ -表示计划执行结束,所得到的结果将作为最终结果。 # 样例 @@ -163,78 +231,384 @@ CREATE_PLAN = dedent(r""" # 计划 """) -EVALUATE_PLAN = dedent(r""" - 你是一个计划评估器。 - 请根据给定的计划,和当前计划执行的实际情况,分析当前计划是否合理和完整,并生成改进后的计划。 +RECREATE_PLAN = dedent(r""" + 你是一个计划重建器。 + 请根据用户的目标、当前计划和运行报错,重新生成一个计划。 # 一个好的计划应该: 1. 能够成功完成用户的目标 2. 计划中的每一个步骤必须且只能使用一个工具。 3. 计划中的步骤必须具有清晰和逻辑的步骤,没有冗余或不必要的步骤。 - 4. 计划中的最后一步必须是Final工具,以确保计划执行结束。 + 4. 你的计划必须避免之前的错误,并且能够成功执行。 + 5. 计划中的最后一步必须是Final工具,以确保计划执行结束。 - # 你此前的计划是: + # 生成计划时的注意事项: - {{ plan }} + - 每一条计划包含3个部分: + - 计划内容:描述单个计划步骤的大致内容 + - 工具ID:必须从下文的工具列表中选择 + - 工具指令:改写用户的目标,使其更符合工具的输入要求 + - 必须按照如下格式生成计划,不要输出任何额外数据: - # 这个计划的执行情况是: + ```json + { + "plans": [ + { + "content": "计划内容", + "tool": "工具ID", + "instruction": "工具指令" + } + ] + } + ``` - 计划的执行情况将放置在 XML标签中。 + - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。\ +思考过程应放置在 XML标签中。 + - 计划内容中,可以使用"Result[]"来引用之前计划步骤的结果。例如:"Result[3]"表示引用第三条计划执行后的结果。 + - 计划不得多于{{ max_num }}条,且每条计划内容应少于150字。 - - {{ memory }} - + # 样例 - # 进行评估时的注意事项: + ## 目标 - - 请一步一步思考,解析用户的目标,并指导你接下来的生成。思考过程应放置在 XML标签中。 - - 评估结果分为两个部分: - - 计划评估的结论 - - 改进后的计划 - - 请按照以下JSON格式输出评估结果: + 请帮我扫描一下192.168.1.1的这台机器的端口,看看有哪些端口开放。 + ## 工具 + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + + - command_generator生成命令行指令 + - tool_selector选择合适的工具 + - command_executor执行命令行指令 + - Final结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。 + + ## 当前计划 ```json { - "evaluation": "评估结果", "plans": [ { - "content": "改进后的计划内容", - "tool": "工具ID", - "instruction": "工具指令" + "content": "生成端口扫描命令", + "tool": "command_generator", + "instruction": "生成端口扫描命令:扫描192.168.1.1的开放端口" + }, + { + "content": "在执行Result[0]生成的命令", + "tool": "command_executor", + "instruction": "执行端口扫描命令" + }, + { + "content": "任务执行完成,端口扫描结果为Result[2]", + "tool": "Final", + "instruction": "" } ] } ``` + ## 运行报错 + 执行端口扫描命令时,出现了错误:`-bash: curl: command not found`。 + ## 重新生成的计划 - # 现在开始评估计划: + + 1. 这个目标需要使用网络扫描工具来完成,首先需要选择合适的网络扫描工具 + 2. 目标可以拆解为以下几个部分: + - 生成端口扫描命令 + - 执行端口扫描命令 + 3.但是在执行端口扫描命令时,出现了错误:`-bash: curl: command not found`。 + 4.我将计划调整为: + - 需要先生成一个命令,查看当前机器支持哪些网络扫描工具 + - 执行这个命令,查看当前机器支持哪些网络扫描工具 + - 然后从中选择一个网络扫描工具 + - 基于选择的网络扫描工具,生成端口扫描命令 + - 执行端口扫描命令 + + ```json + { + "plans": [ + { + "content": "需要生成一条命令查看当前机器支持哪些网络扫描工具", + "tool": "command_generator", + "instruction": "选择一个前机器支持哪些网络扫描工具" + }, + { + "content": "执行Result[0]中生成的命令,查看当前机器支持哪些网络扫描工具", + "tool": "command_executor", + "instruction": "执行Result[0]中生成的命令" + }, + { + "content": "从Result[1]中选择一个网络扫描工具,生成端口扫描命令", + "tool": "tool_selector", + "instruction": "选择一个网络扫描工具,生成端口扫描命令" + }, + { + "content": "基于result[2]中选择的网络扫描工具,生成端口扫描命令", + "tool": "command_generator", + "instruction": "生成端口扫描命令:扫描192.168.1.1的开放端口" + }, + { + "content": "在Result[0]的MCP Server上执行Result[3]生成的命令", + "tool": "command_executor", + "instruction": "执行端口扫描命令" + }, + { + "content": "任务执行完成,端口扫描结果为Result[4]", + "tool": "Final", + "instruction": "" + } + ] + } + ``` + + # 现在开始重新生成计划: + + # 目标 + + {{goal}} + + # 工具 + + 你可以访问并使用一些工具,这些工具将在 XML标签中给出。 + + + {% for tool in tools %} + - {{ tool.id }}{{tool.name}};{{ tool.description }} + {% endfor %} + + + # 当前计划 + {{ current_plan }} + + # 运行报错 + {{ error_message }} + + # 重新生成的计划 """) +RISK_EVALUATE = dedent(r""" + 你是一个工具执行计划评估器。 + 你的任务是根据当前工具的名称、描述和入参以及附加信息,判断当前工具执行的风险并输出提示。 + ```json + { + "risk": "low/medium/high", + "message": "提示信息" + } + ``` + # 样例 + ## 工具名称 + mysql_analyzer + ## 工具描述 + 分析MySQL数据库性能 + ## 工具入参 + { + "host": "192.0.0.1", + "port": 3306, + "username": "root", + "password": "password" + } + ## 附加信息 + 1. 当前MySQL数据库的版本是8.0.26 + 2. 当前MySQL数据库的配置文件路径是/etc/my.cnf,并含有以下配置项 + ```ini + [mysqld] + innodb_buffer_pool_size=1G + innodb_log_file_size=256M + ``` + ## 输出 + ```json + { + "risk": "中", + "message": "当前工具将连接到MySQL数据库并分析性能,可能会对数据库性能产生一定影响。请确保在非生产环境中执行此操作。" + } + ``` + # 工具 + + {{ tool.name }} + {{ tool.description }} + + # 工具入参 + {{ input_param }} + # 附加信息 + {{ additional_info }} + # 输出 + """ + ) +# 根据当前计划和报错信息决定下一步执行,具体计划有需要用户补充工具入参、重计划当前步骤、重计划接下来的所有计划 +JUDGE_NEXT_STEP = dedent(r""" + 你是一个计划决策器。 + 你的任务是根据当前计划、当前使用的工具、工具入参和工具运行报错,决定下一步执行的操作。 + 请根据以下规则进行判断: + 1. 仅通过补充工具入参来解决问题的,返回 fill_params; + 2. 需要重计划当前步骤的,返回 replan_current_step; + 3. 需要重计划接下来的所有计划的,返回 replan_all_steps; + 你的输出要以json格式返回,格式如下: + ```json + { + "next_step": "fill_params/replan_current_step/replan_all_steps", + "reason": "你的判断依据" + } + ``` + 注意: + reason字段必须清晰明了,能够让人理解你的判断依据,并且不超过50个中文字或者100个英文单词。 + # 样例 + ## 当前计划 + {"plans": [ + { + "content": "生成端口扫描命令", + "tool": "command_generator", + "instruction": "生成端口扫描命令:扫描192.168.1.1的开放端口" + }, + { + "content": "在执行Result[0]生成的命令", + "tool": "command_executor", + "instruction": "执行端口扫描命令" + }, + { + "content": "任务执行完成,端口扫描结果为Result[2]", + "tool": "Final", + "instruction": "" + } + ]} + ## 当前使用的工具 + + command_executor + 执行命令行指令 + + ## 工具入参 + { + "command": "nmap -sS -p--open 192.168.1.1" + } + ## 工具运行报错 + 执行端口扫描命令时,出现了错误:`-bash: nmap: command not found`。 + ## 输出 + ```json + { + "next_step": "replan_all_steps", + "reason": "当前工具执行报错,提示nmap命令未找到,需要增加command_generator和command_executor的步骤,生成nmap安装命令并执行,之后再生成端口扫描命令并执行。" + } + ``` + # 当前计划 + {{ current_plan }} + # 当前使用的工具 + + {{ tool.name }} + {{ tool.description }} + + # 工具入参 + {{ input_param }} + # 工具运行报错 + {{ error_message }} + # 输出 + """ + ) +# 获取缺失的参数的json结构体 +GET_MISSING_PARAMS = dedent(r""" + 你是一个工具参数获取器。 + 你的任务是根据当前工具的名称、描述和入参和入参的schema以及运行报错,将当前缺失的参数设置为null,并输出一个JSON格式的字符串。 + ```json + { + "host": "请补充主机地址", + "port": "请补充端口号", + "username": "请补充用户名", + "password": "请补充密码" + } + ``` + # 样例 + # 工具名称 + mysql_analyzer + # 工具描述 + 分析MySQL数据库性能 + # 工具入参 + { + "host": "192.0.0.1", + "port": 3306, + "username": "root", + "password": "password" + } + # 工具入参schema + { + "type": "object", + "properties": { + "host": { + "anyOf": [ + {"type": "string"}, + {"type": "null"} + ], + "description": "MySQL数据库的主机地址(可以为字符串或null)" + }, + "port": { + "anyOf": [ + {"type": "string"}, + {"type": "null"} + ], + "description": "MySQL数据库的端口号(可以是数字、字符串或null)" + }, + "username": { + "anyOf": [ + {"type": "string"}, + {"type": "null"} + ], + "description": "MySQL数据库的用户名(可以为字符串或null)" + }, + "password": { + "anyOf": [ + {"type": "string"}, + {"type": "null"} + ], + "description": "MySQL数据库的密码(可以为字符串或null)" + } + }, + "required": ["host", "port", "username", "password"] + } + # 运行报错 + 执行端口扫描命令时,出现了错误:`password is not correct`。 + # 输出 + ```json + { + "host": "192.0.0.1", + "port": 3306, + "username": null, + "password": null + } + ``` + # 工具 + < tool > + < name > {{tool.name}} < /name > + < description > {{tool.description}} < /description > + < / tool > + # 工具入参 + {{input_param}} + # 工具入参schema(部分字段允许为null) + {{input_schema}} + # 运行报错 + {{error_message}} + # 输出 + """ + ) FINAL_ANSWER = dedent(r""" 综合理解计划执行结果和背景信息,向用户报告目标的完成情况。 # 用户目标 - {{ goal }} + {{goal}} # 计划执行情况 为了完成上述目标,你实施了以下计划: - {{ memory }} + {{memory}} # 其他背景信息: - {{ status }} + {{status}} # 现在,请根据以上信息,向用户报告目标的完成情况: """) + MEMORY_TEMPLATE = dedent(r""" - {% for ctx in context_list %} - - 第{{ loop.index }}步:{{ ctx.step_description }} - - 调用工具 `{{ ctx.step_id }}`,并提供参数 `{{ ctx.input_data }}` - - 执行状态:{{ ctx.status }} - - 得到数据:`{{ ctx.output_data }}` - {% endfor %} + { % for ctx in context_list % } + - 第{{loop.index}}步:{{ctx.step_description}} + - 调用工具 `{{ctx.step_id}}`,并提供参数 `{{ctx.input_data}}` + - 执行状态:{{ctx.status}} + - 得到数据:`{{ctx.output_data}}` + { % endfor % } """) diff --git a/apps/scheduler/mcp_agent/select.py b/apps/scheduler/mcp_agent/select.py index 2ff503447..7198940bc 100644 --- a/apps/scheduler/mcp_agent/select.py +++ b/apps/scheduler/mcp_agent/select.py @@ -2,6 +2,8 @@ """选择MCP Server及其工具""" import logging +import uuid +from collections.abc import AsyncGenerator from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment @@ -26,8 +28,9 @@ logger = logging.getLogger(__name__) class MCPSelector: """MCP选择器""" - def __init__(self) -> None: + def __init__(self, resoning_llm: ReasoningLLM = None) -> None: """初始化助手类""" + self.resoning_llm = resoning_llm or ReasoningLLM() self.input_tokens = 0 self.output_tokens = 0 @@ -101,20 +104,15 @@ class MCPSelector: return await self._call_function_mcp(result, mcp_ids) - async def _call_reasoning(self, prompt: str) -> str: + async def _call_reasoning(self, prompt: str) -> AsyncGenerator[str, None]: """调用大模型进行推理""" logger.info("[MCPHelper] 调用推理大模型") - llm = ReasoningLLM() message = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}, ] - result = "" - async for chunk in llm.call(message): - result += chunk - self.input_tokens += llm.input_tokens - self.output_tokens += llm.output_tokens - return result + async for chunk in self.resoning_llm.call(message): + yield chunk async def _call_function_mcp(self, reasoning_result: str, mcp_ids: list[str]) -> MCPSelectResult: @@ -182,4 +180,17 @@ class MCPSelector: tool_obj = MCPTool.model_validate(tool) llm_tool_list.append(tool_obj) + llm_tool_list.append( + MCPTool( + id="00000000-0000-0000-0000-000000000000", + name="Final", + description="It is the final step, indicating the end of the plan execution."), + ) + llm_tool_list.append( + MCPTool( + id="00000000-0000-0000-0000-000000000001", + name="Chat", + description="It is a chat tool to communicate with the user."), + ) + return llm_tool_list diff --git a/apps/scheduler/pool/loader/mcp.py b/apps/scheduler/pool/loader/mcp.py index 00f064f33..16b1f6669 100644 --- a/apps/scheduler/pool/loader/mcp.py +++ b/apps/scheduler/pool/loader/mcp.py @@ -4,6 +4,7 @@ import json import logging import shutil +from typing import Any import asyncer from anyio import Path @@ -329,7 +330,7 @@ class MCPLoader(metaclass=SingletonMeta): @staticmethod - async def user_active_template(user_sub: str, mcp_id: str) -> None: + async def user_active_template(user_sub: str, mcp_id: str, mcp_env: dict[str, Any]) -> None: """ 用户激活MCP模板 @@ -348,6 +349,8 @@ class MCPLoader(metaclass=SingletonMeta): err = f"MCP模板“{mcp_id}”已存在或有同名文件,无法激活" raise FileExistsError(err) + mcp_config = await MCPLoader.get_config(mcp_id) + mcp_config.mcpServers[mcp_id].env.update(mcp_env) # 拷贝文件 await asyncer.asyncify(shutil.copytree)( template_path.as_posix(), @@ -356,6 +359,17 @@ class MCPLoader(metaclass=SingletonMeta): symlinks=True, ) + user_config_path = user_path / "config.json" + # 更新用户配置 + f = await user_config_path.open("w", encoding="utf-8", errors="ignore") + await f.write( + json.dumps( + mcp_config.model_dump(by_alias=True, exclude_none=True), + indent=4, + ensure_ascii=False, + ), + ) + await f.aclose() # 更新数据库 async with postgres.session() as session: await session.merge(MCPActivated( diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index 61be756b4..272a8c60e 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -11,6 +11,7 @@ from apps.models.document import Document from apps.models.record import Record, RecordFootNote, RecordMetadata from apps.schemas.enum_var import StepStatus from apps.schemas.record import ( + FlowHistory, RecordContent, RecordDocument, RecordGroupDocument, @@ -174,7 +175,7 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: id=task.ids.record_id, conversationId=task.ids.conversation_id, taskId=task.id, - user_sub=user_sub, + userSub=user_sub, content=encrypt_data, key=encrypt_config, metadata=RecordMetadata( @@ -185,7 +186,12 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: footNoteMetadataList=footnote_list, ), createdAt=current_time, - flow=[i["_id"] for i in task.context], + flow=FlowHistory( + flow_id=task.state.flow_id, + flow_name=task.state.flow_name, + flow_status=task.state.flow_status, + history_ids=[context["_id"] for context in task.context], + ), ) # 修改文件状态 diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index bbb02d3ef..5383c419f 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -8,7 +8,7 @@ from textwrap import dedent from apps.common.config import config from apps.common.queue import MessageQueue from apps.models.document import Document -from apps.schemas.enum_var import EventType +from apps.schemas.enum_var import EventType, FlowStatus from apps.schemas.message import ( DocumentAddContent, InitContent, @@ -61,22 +61,26 @@ async def push_init_message( async def push_rag_message( task: Task, queue: MessageQueue, user_sub: str, llm: LLM, history: list[dict[str, str]], doc_ids: list[str], - rag_data: RAGQueryReq,) -> Task: + rag_data: RAGQueryReq,) -> None: """推送RAG消息""" full_answer = "" - async for chunk in RAG.chat_with_llm_base_on_rag(user_sub, llm, history, doc_ids, rag_data): - task, content_obj = await _push_rag_chunk(task, queue, chunk) - if content_obj.event_type == EventType.TEXT_ADD.value: - # 如果是文本消息,直接拼接到答案中 - full_answer += content_obj.content - elif content_obj.event_type == EventType.DOCUMENT_ADD.value: - task.runtime.documents.append(content_obj.content) - + try: + async for chunk in RAG.chat_with_llm_base_on_rag(user_sub, llm, history, doc_ids, rag_data): + task, content_obj = await _push_rag_chunk(task, queue, chunk) + if content_obj.event_type == EventType.TEXT_ADD.value: + # 如果是文本消息,直接拼接到答案中 + full_answer += content_obj.content + elif content_obj.event_type == EventType.DOCUMENT_ADD.value: + task.runtime.documents.append(content_obj.content) + task.state.flow_status = FlowStatus.SUCCESS + except Exception as e: + logger.error(f"[Scheduler] RAG服务发生错误: {e}") + task.state.flow_status = FlowStatus.ERROR # 保存答案 task.runtime.answer = full_answer + task.tokens.full_time = round(datetime.now(UTC).timestamp(), 2) - task.tokens.time await TaskManager.save_task(task.id, task) - return task async def _push_rag_chunk(task: Task, queue: MessageQueue, content: str) -> tuple[Task, RAGEventData]: diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index ef0a747d7..d35bbb218 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """Scheduler模块""" +import asyncio import logging from datetime import UTC, datetime @@ -16,11 +17,12 @@ from apps.scheduler.scheduler.message import ( push_init_message, push_rag_message, ) -from apps.schemas.enum_var import AppType, EventType +from apps.schemas.enum_var import AppType, EventType, FlowStatus from apps.schemas.rag_data import RAGQueryReq from apps.schemas.request_data import RequestData from apps.schemas.scheduler import ExecutorBackground from apps.schemas.task import Task +from apps.services.activity import Activity from apps.services.appcenter import AppCenterManager from apps.services.knowledge import KnowledgeBaseManager from apps.services.llm import LLMManager @@ -43,6 +45,29 @@ class Scheduler: self.queue = queue self.post_body = post_body + + async def _monitor_activity(self, kill_event, user_sub): + """监控用户活动状态,不活跃时终止工作流""" + try: + check_interval = 0.5 # 每0.5秒检查一次 + + while not kill_event.is_set(): + # 检查用户活动状态 + is_active = await Activity.is_active(user_sub) + + if not is_active: + logger.warning("[Scheduler] 用户 %s 不活跃,终止工作流", user_sub) + kill_event.set() + break + + # 控制检查频率 + await asyncio.sleep(check_interval) + except asyncio.CancelledError: + logger.info("[Scheduler] 活动监控任务已取消") + except Exception as e: + logger.error(f"[Scheduler] 活动监控过程中发生错误: {e}") + + async def run(self) -> None: # noqa: PLR0911 """运行调度器""" try: @@ -92,6 +117,9 @@ class Scheduler: # 如果是智能问答,直接执行 logger.info("[Scheduler] 开始执行") + # 创建用于通信的事件 + kill_event = asyncio.Event() + monitor = asyncio.create_task(self._monitor_activity(kill_event, self.task.ids.user_sub)) if not self.post_body.app or self.post_body.app.app_id == "": self.task = await push_init_message(self.task, self.queue, 3, is_flow=False) rag_data = RAGQueryReq( @@ -99,8 +127,9 @@ class Scheduler: query=self.post_body.question, tokensLimit=llm.max_tokens, ) - self.task = await push_rag_message(self.task, self.queue, self.task.ids.user_sub, llm, history, doc_ids, rag_data) - self.task.tokens.full_time = round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.time + # 启动监控任务和主任务 + main_task = asyncio.create_task(push_rag_message( + self.task, self.queue, self.task.ids.user_sub, llm, history, doc_ids, rag_data)) else: # 查找对应的App元数据 app_data = await AppCenterManager.fetch_app_data_by_id(self.post_body.app.app_id) @@ -124,7 +153,26 @@ class Scheduler: conversation=context, facts=facts, ) - await self.run_executor(self.queue, self.post_body, executor_background) + # 启动监控任务和主任务 + main_task = asyncio.create_task(self.run_executor(self.queue, self.post_body, executor_background)) + # 等待任一任务完成 + done, pending = await asyncio.wait( + [main_task, monitor], + return_when=asyncio.FIRST_COMPLETED + ) + + # 如果是监控任务触发,终止主任务 + if kill_event.is_set(): + logger.warning("[Scheduler] 用户活动状态检测不活跃,正在终止工作流执行...") + main_task.cancel() + need_change_cancel_flow_state = [FlowStatus.RUNNING, FlowStatus.WAITING] + if self.task.state.flow_status in need_change_cancel_flow_state: + self.task.state.flow_status = FlowStatus.CANCELLED + try: + await main_task + logger.info("[Scheduler] 工作流执行已被终止") + except Exception as e: + logger.error(f"[Scheduler] 终止工作流时发生错误: {e}") # 更新Task,发送结束消息 logger.info("[Scheduler] 发送结束消息") diff --git a/apps/scheduler/slot/slot.py b/apps/scheduler/slot/slot.py index bb2fde49a..bda44a3cf 100644 --- a/apps/scheduler/slot/slot.py +++ b/apps/scheduler/slot/slot.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """参数槽位管理""" +import copy import json import logging import traceback @@ -13,15 +14,14 @@ from jsonschema.protocols import Validator from jsonschema.validators import extend from apps.scheduler.call.choice.schema import Type -from apps.schemas.response_data import ParamsNode - -from .parser import ( +from apps.scheduler.slot.parser import ( SlotConstParser, SlotDateParser, SlotDefaultParser, SlotTimestampParser, ) -from .util import escape_path, patch_json +from apps.scheduler.slot.util import escape_path, patch_json +from apps.schemas.response_data import ParamsNode logger = logging.getLogger(__name__) @@ -129,7 +129,7 @@ class Slot: # Schema标准 return [_process_json_value(item, spec_data["items"]) for item in json_value] if spec_data["type"] == "object" and isinstance(json_value, dict): - # 若Schema不标准,则不进行处理 + # 若Schema不标准,则不进行处理F if "properties" not in spec_data: return json_value # Schema标准 @@ -157,35 +157,60 @@ class Slot: @staticmethod def _generate_example(schema_node: dict) -> Any: # noqa: PLR0911 """根据schema生成示例值""" + if "anyOf" in schema_node or "oneOf" in schema_node: + # 如果有anyOf,随机返回一个示例 + for item in schema_node["anyOf"] if "anyOf" in schema_node else schema_node["oneOf"]: + example = Slot._generate_example(item) + if example is not None: + return example + + if "allOf" in schema_node: + # 如果有allOf,返回所有示例的合并 + example = None + for item in schema_node["allOf"]: + if example is None: + example = Slot._generate_example(item) + else: + other_example = Slot._generate_example(item) + if isinstance(example, dict) and isinstance(other_example, dict): + example.update(other_example) + else: + example = None + break + return example + if "default" in schema_node: return schema_node["default"] if "type" not in schema_node: return None - + type_value = schema_node["type"] + if isinstance(type_value, list): + # 如果是多类型,随机返回一个示例 + if len(type_value) > 1: + type_value = type_value[0] # 处理类型为 object 的节点 - if schema_node["type"] == "object": + if type_value == "object": data = {} properties = schema_node.get("properties", {}) for name, schema in properties.items(): data[name] = Slot._generate_example(schema) return data - # 处理类型为 array 的节点 - if schema_node["type"] == "array": + elif type_value == "array": items_schema = schema_node.get("items", {}) return [Slot._generate_example(items_schema)] # 处理类型为 string 的节点 - if schema_node["type"] == "string": + elif type_value == "string": return "" # 处理类型为 number 或 integer 的节点 - if schema_node["type"] in ["number", "integer"]: + elif type_value in ["number", "integer"]: return 0 # 处理类型为 boolean 的节点 - if schema_node["type"] == "boolean": + elif type_value == "boolean": return False # 处理其他类型或未定义类型 @@ -199,29 +224,69 @@ class Slot: """从JSON Schema中提取类型描述""" def _extract_type_desc(schema_node: dict[str, Any]) -> dict[str, Any]: - if "type" not in schema_node and "anyOf" not in schema_node: - return {} - data = {"type": schema_node.get("type", ""), "description": schema_node.get("description", "")} - if "anyOf" in schema_node: - data["type"] = "anyOf" - # 处理类型为 object 的节点 - if "anyOf" in schema_node: - data["items"] = {} - type_index = 0 - for type_index, sub_schema in enumerate(schema_node["anyOf"]): - sub_result = _extract_type_desc(sub_schema) - if sub_result: - data["items"]["type_"+str(type_index)] = sub_result - if schema_node.get("type", "") == "object": - data["items"] = {} + # 处理组合关键字 + special_keys = ["anyOf", "allOf", "oneOf"] + for key in special_keys: + if key in schema_node: + data = { + "type": key, + "description": schema_node.get("description", ""), + "items": {}, + } + type_index = 0 + for item in schema_node[key]: + if isinstance(item, dict): + data["items"][f"item_{type_index}"] = _extract_type_desc(item) + else: + data["items"][f"item_{type_index}"] = {"type": item, "description": ""} + type_index += 1 + return data + # 处理基本类型 + type_val = schema_node.get("type", "") + description = schema_node.get("description", "") + + # 处理多类型数组 + if isinstance(type_val, list): + if len(type_val) > 1: + data = {"type": "union", "description": description, "items": {}} + type_index = 0 + for t in type_val: + if t == "object": + tmp_dict = {} + for key, val in schema_node.get("properties", {}).items(): + tmp_dict[key] = _extract_type_desc(val) + data["items"][f"item_{type_index}"] = tmp_dict + elif t == "array": + items_schema = schema_node.get("items", {}) + data["items"][f"item_{type_index}"] = _extract_type_desc(items_schema) + else: + data["items"][f"item_{type_index}"] = {"type": t, "description": description} + type_index += 1 + return data + elif len(type_val) == 1: + type_val = type_val[0] + else: + type_val = "" + + data = {"type": type_val, "description": description, "items": {}} + + # 递归处理对象和数组 + if type_val == "object": for key, val in schema_node.get("properties", {}).items(): data["items"][key] = _extract_type_desc(val) - - # 处理类型为 array 的节点 - if schema_node.get("type", "") == "array": + elif type_val == "array": items_schema = schema_node.get("items", {}) - data["items"] = _extract_type_desc(items_schema) + if isinstance(items_schema, list): + item_index = 0 + for item in items_schema: + data["items"][f"item_{item_index}"] = _extract_type_desc(item) + item_index += 1 + else: + data["items"]["item"] = _extract_type_desc(items_schema) + if data["items"] == {}: + del data["items"] return data + return _extract_type_desc(self._schema) def get_params_node_from_schema(self, root: str = "") -> ParamsNode: @@ -232,13 +297,15 @@ class Slot: return None param_type = schema_node["type"] + if isinstance(param_type, list): + return None # 不支持多类型 if param_type == "object": param_type = Type.DICT elif param_type == "array": param_type = Type.LIST elif param_type == "string": param_type = Type.STRING - elif param_type == "number": + elif param_type in ["number", "integer"]: param_type = Type.NUMBER elif param_type == "boolean": param_type = Type.BOOL @@ -247,9 +314,11 @@ class Slot: return None sub_params = [] - if param_type == "object" and "properties" in schema_node: + if param_type == Type.DICT and "properties" in schema_node: for key, value in schema_node["properties"].items(): - sub_params.append(_extract_params_node(value, name=key, path=f"{path}/{key}")) + sub_param = _extract_params_node(value, name=key, path=f"{path}/{key}") + if sub_param: + sub_params.append(sub_param) else: # 对于非对象类型,直接返回空子参数 sub_params = None @@ -318,7 +387,6 @@ class Slot: logger.exception("[Slot] 错误schema不合法: %s", error.schema) return {}, [] - def _assemble_patch( self, key: str, @@ -371,7 +439,6 @@ class Slot: logger.info("[Slot] 组装patch: %s", patch_list) return patch_list - def convert_json(self, json_data: str | dict[str, Any]) -> dict[str, Any]: """将用户手动填充的参数专为真实JSON""" json_dict = json.loads(json_data) if isinstance(json_data, str) else json_data @@ -415,3 +482,54 @@ class Slot: return schema_template return {} + + def add_null_to_basic_types(self) -> dict[str, Any]: + """ + 递归地为 JSON Schema 中的基础类型(bool、number等)添加 null 选项 + """ + def add_null_to_basic_types(schema: dict[str, Any]) -> dict[str, Any]: + """ + 递归地为 JSON Schema 中的基础类型(bool、number等)添加 null 选项 + + 参数: + schema (dict): 原始 JSON Schema + + 返回: + dict: 修改后的 JSON Schema + """ + # 如果不是字典类型(schema),直接返回 + if not isinstance(schema, dict): + return schema + + # 处理当前节点的 type 字段 + if 'type' in schema: + # 处理单一类型字符串 + if isinstance(schema['type'], str): + if schema['type'] in ['boolean', 'number', 'string', 'integer']: + schema['type'] = [schema['type'], 'null'] + + # 处理类型数组 + elif isinstance(schema['type'], list): + for i, t in enumerate(schema['type']): + if isinstance(t, str) and t in ['boolean', 'number', 'string', 'integer']: + if 'null' not in schema['type']: + schema['type'].append('null') + break + + # 递归处理 properties 字段(对象类型) + if 'properties' in schema: + for prop, prop_schema in schema['properties'].items(): + schema['properties'][prop] = add_null_to_basic_types(prop_schema) + + # 递归处理 items 字段(数组类型) + if 'items' in schema: + schema['items'] = add_null_to_basic_types(schema['items']) + + # 递归处理 anyOf, oneOf, allOf 字段 + for keyword in ['anyOf', 'oneOf', 'allOf']: + if keyword in schema: + schema[keyword] = [add_null_to_basic_types(sub_schema) for sub_schema in schema[keyword]] + + return schema + schema_copy = copy.deepcopy(self._schema) + return add_null_to_basic_types(schema_copy) \ No newline at end of file diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py index a6a2d8124..d46f52890 100644 --- a/apps/schemas/enum_var.py +++ b/apps/schemas/enum_var.py @@ -15,6 +15,7 @@ class SlotType(str, Enum): class StepStatus(str, Enum): """步骤状态""" + UNKNOWN = "unknown" WAITING = "waiting" RUNNING = "running" SUCCESS = "success" @@ -26,6 +27,7 @@ class StepStatus(str, Enum): class FlowStatus(str, Enum): """Flow状态""" + UNKNOWN = "unknown" WAITING = "waiting" RUNNING = "running" SUCCESS = "success" @@ -57,6 +59,7 @@ class EventType(str, Enum): STEP_OUTPUT = "step.output" FLOW_STOP = "flow.stop" FLOW_FAILED = "flow.failed" + FLOW_SUCCESS = "flow.success" FLOW_CANCELLED = "flow.cancelled" DONE = "done" @@ -90,7 +93,7 @@ class NodeType(str, Enum): START = "start" END = "end" NORMAL = "normal" - CHOICE = "choice" + CHOICE = "Choice" class PermissionType(str, Enum): @@ -145,7 +148,7 @@ class SpecialCallType(str, Enum): LLM = "LLM" START = "start" END = "end" - CHOICE = "choice" + CHOICE = "Choice" class CommentType(str, Enum): diff --git a/apps/schemas/flow_topology.py b/apps/schemas/flow_topology.py index 098393739..768a1c4d1 100644 --- a/apps/schemas/flow_topology.py +++ b/apps/schemas/flow_topology.py @@ -6,6 +6,8 @@ from typing import Any from pydantic import BaseModel, Field +from apps.schemas.enum_var import SpecialCallType + from .enum_var import EdgeType @@ -48,7 +50,7 @@ class NodeItem(BaseModel): service_id: str = Field(alias="serviceId", default="") node_id: str = Field(alias="nodeId", default="") name: str = Field(default="") - call_id: str = Field(alias="callId", default="Empty") + call_id: str = Field(alias="callId", default=SpecialCallType.EMPTY.value) description: str = Field(default="") parameters: dict[str, Any] = Field(default={}) position: PositionItem = Field(default=PositionItem()) @@ -74,6 +76,6 @@ class FlowItem(BaseModel): nodes: list[NodeItem] = Field(default=[]) edges: list[EdgeItem] = Field(default=[]) created_at: float | None = Field(alias="createdAt", default=0) - connectivity: bool = Field(default=False,description="图的开始节点和结束节点是否联通,并且除结束节点都有出边") + connectivity: bool = Field(default=False, description="图的开始节点和结束节点是否联通,并且除结束节点都有出边") focus_point: PositionItem = Field(alias="focusPoint", default=PositionItem()) debug: bool = Field(default=False) diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py index 41f5b3034..c31542bc6 100644 --- a/apps/schemas/mcp.py +++ b/apps/schemas/mcp.py @@ -1,7 +1,6 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """MCP 相关数据结构""" -import uuid from enum import Enum from pydantic import BaseModel, Field @@ -58,6 +57,28 @@ class MCPServerConfig(MCPServerItem): author: str = Field(description="MCP 服务器上传者", default="") +class GoalEvaluationResult(BaseModel): + """MCP 目标评估结果""" + + can_complete: bool = Field(description="是否可以完成目标") + reason: str = Field(description="评估原因") + + +class Risk(str, Enum): + """MCP工具风险类型""" + + LOW = "low" + MEDIUM = "medium" + HIGH = "high" + + +class ToolRisk(BaseModel): + """MCP工具风险评估结果""" + + risk: Risk = Field(description="风险类型", default=Risk.LOW) + reason: str = Field(description="风险原因", default="") + + class MCPSelectResult(BaseModel): """MCP选择结果""" @@ -67,7 +88,7 @@ class MCPSelectResult(BaseModel): class MCPPlanItem(BaseModel): """MCP 计划""" - id: str = Field(default_factory=lambda: str(uuid.uuid4())) + step_id: str = Field(description="步骤的ID", default="") content: str = Field(description="计划内容") tool: str = Field(description="工具名称") instruction: str = Field(description="工具指令") @@ -76,4 +97,4 @@ class MCPPlanItem(BaseModel): class MCPPlan(BaseModel): """MCP 计划""" - plans: list[MCPPlanItem] = Field(description="计划列表") + plans: list[MCPPlanItem] = Field(description="计划列表", default=[]) diff --git a/apps/schemas/record.py b/apps/schemas/record.py index 1ea12c39b..578f567ce 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -7,7 +7,7 @@ from typing import Any, Literal from pydantic import BaseModel, Field -from .enum_var import CommentType, StepStatus +from apps.schemas.enum_var import CommentType, FlowStatus, StepStatus class RecordDocument(BaseModel): @@ -104,6 +104,15 @@ class RecordGroupDocument(BaseModel): created_at: float = Field(default=0.0, description="文档创建时间") +class FlowHistory(BaseModel): + """Flow执行历史""" + + flow_id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") + flow_name: str = Field(default="", description="Flow名称") + flow_staus: FlowStatus = Field(default=FlowStatus.SUCCESS, description="Flow执行状态") + history_ids: list[str] = Field(default=[], description="Flow执行历史ID列表") + + class Record(RecordData): """问答,用于保存在MongoDB中""" @@ -111,4 +120,5 @@ class Record(RecordData): key: dict[str, Any] = {} content: str comment: RecordComment = Field(default=RecordComment()) - flow: list[str] = Field(default=[]) + flow: FlowHistory = Field( + default=FlowHistory(), description="Flow执行历史信息") diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index ceca9e749..b369ddf1d 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -92,6 +92,7 @@ class ActiveMCPServiceRequest(BaseModel): """POST /api/mcp/{serviceId} 请求数据结构""" active: bool = Field(description="是否激活mcp服务") + mcp_env: dict[str, Any] = Field(default={}, description="MCP服务环境变量", alias="mcpEnv") class UpdateServiceRequest(BaseModel): diff --git a/apps/schemas/task.py b/apps/schemas/task.py index 4d49e1c10..e0a2b39ef 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -38,16 +38,18 @@ class ExecutorState(BaseModel): """FlowExecutor状态""" # 执行器级数据 - flow_id: str = Field(description="Flow ID") - flow_name: str = Field(description="Flow名称") - description: str = Field(description="Flow描述") - flow_status: FlowStatus = Field(description="Flow状态") + flow_id: str = Field(description="Flow ID", default="") + flow_name: str = Field(description="Flow名称", default="") + description: str = Field(description="Flow描述", default="") + flow_status: FlowStatus = Field(description="Flow状态", default=FlowStatus.UNKNOWN) # 任务级数据 - step_id: str = Field(description="当前步骤ID") - step_name: str = Field(description="当前步骤名称") - step_status: StepStatus = Field(description="当前步骤状态") + step_id: str = Field(description="当前步骤ID", default="") + step_name: str = Field(description="当前步骤名称", default="") + step_status: StepStatus = Field(description="当前步骤状态", default=StepStatus.UNKNOWN) step_description: str = Field(description="当前步骤描述", default="") - app_id: str = Field(description="应用ID") + retry_times: int = Field(description="当前步骤重试次数", default=0) + error_message: str = Field(description="当前步骤错误信息", default="") + app_id: str = Field(description="应用ID", default="") slot: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={}) error_info: dict[str, Any] = Field(description="错误信息", default={}) @@ -92,7 +94,7 @@ class Task(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") ids: TaskIds = Field(description="任务涉及的各种ID") context: list[dict[str, Any]] = Field(description="Flow的步骤执行信息", default=[]) - state: ExecutorState | None = Field(description="Flow的状态", default=None) + state: ExecutorState = Field(description="Flow的状态", default=ExecutorState()) tokens: TaskTokens = Field(description="Token信息") runtime: TaskRuntime = Field(description="任务运行时数据") created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) diff --git a/apps/services/flow_validate.py b/apps/services/flow_validate.py index b0de757ae..619a121fb 100644 --- a/apps/services/flow_validate.py +++ b/apps/services/flow_validate.py @@ -7,7 +7,7 @@ from typing import TYPE_CHECKING from apps.exceptions import FlowBranchValidationError, FlowEdgeValidationError, FlowNodeValidationError from apps.scheduler.pool.pool import Pool -from apps.schemas.enum_var import NodeType +from apps.schemas.enum_var import NodeType, SpecialCallType from apps.schemas.flow_topology import EdgeItem, FlowItem, NodeItem if TYPE_CHECKING: @@ -43,32 +43,40 @@ class FlowService: """移除流程图中的多余结构""" node_branch_map = {} for node in flow_item.nodes: - if node.node_id not in {"start", "end", "Empty"}: + if node.node_id not in {"start", "end", SpecialCallType.EMPTY.value}: try: call_class: type[BaseModel] = await Pool().get_call(node.call_id) if not call_class: - node.node_id = "Empty" + node.node_id = SpecialCallType.EMPTY.value node.description = "【对应的api工具被删除!节点不可用!请联系相关人员!】\n\n"+node.description except Exception: - node.node_id = "Empty" + node.node_id = SpecialCallType.EMPTY.value node.description = "【对应的api工具被删除!节点不可用!请联系相关人员!】\n\n"+node.description logger.exception("[FlowService] 获取步骤的call_id失败%s", node.call_id) node_branch_map[node.step_id] = set() if node.call_id == NodeType.CHOICE.value: - node.parameters = node.parameters["input_parameters"] - if "choices" not in node.parameters: - node.parameters["choices"] = [] - for choice in node.parameters["choices"]: - if choice["branchId"] in node_branch_map[node.step_id]: - err = f"[FlowService] 节点{node.name}的分支{choice['branchId']}重复" + input_parameters = node.parameters["input_parameters"] + if "choices" not in input_parameters: + logger.error(f"[FlowService] 节点{node.name}的分支字段缺失") + raise FlowBranchValidationError(f"[FlowService] 节点{node.name}的分支字段缺失") + if not input_parameters["choices"]: + logger.error(f"[FlowService] 节点{node.name}的分支字段为空") + raise FlowBranchValidationError(f"[FlowService] 节点{node.name}的分支字段为空") + for choice in input_parameters["choices"]: + if "branch_id" not in choice: + err = f"[FlowService] 节点{node.name}的分支choice缺少branch_id字段" + logger.error(err) + raise FlowBranchValidationError(err) + if choice["branch_id"] in node_branch_map[node.step_id]: + err = f"[FlowService] 节点{node.name}的分支{choice['branch_id']}重复" logger.error(err) raise ValueError(err) for illegal_char in BRANCH_ILLEGAL_CHARS: - if illegal_char in choice["branchId"]: - err = f"[FlowService] 节点{node.name}的分支{choice['branchId']}名称中含有非法字符" + if illegal_char in choice["branch_id"]: + err = f"[FlowService] 节点{node.name}的分支{choice['branch_id']}名称中含有非法字符" logger.error(err) raise ValueError(err) - node_branch_map[node.step_id].add(choice["branchId"]) + node_branch_map[node.step_id].add(choice["branch_id"]) else: node_branch_map[node.step_id].add("") valid_edges = [] diff --git a/apps/services/mcp_service.py b/apps/services/mcp_service.py index dc561eba8..a08d54551 100644 --- a/apps/services/mcp_service.py +++ b/apps/services/mcp_service.py @@ -4,6 +4,7 @@ import logging import re import uuid +from typing import Any import magic from fastapi import UploadFile @@ -313,6 +314,7 @@ class MCPServiceManager: async def active_mcpservice( user_sub: str, mcp_id: str, + mcp_env: dict[str, Any] = {}, ) -> None: """ 激活MCP服务 @@ -334,7 +336,7 @@ class MCPServiceManager: userSub=user_sub, )) await session.commit() - await MCPLoader.user_active_template(user_sub, mcp_id) + await MCPLoader.user_active_template(user_sub, mcp_id, mcp_env) @staticmethod diff --git a/apps/services/node.py b/apps/services/node.py index e1f3ff5e5..12438cc4e 100644 --- a/apps/services/node.py +++ b/apps/services/node.py @@ -9,6 +9,7 @@ from sqlalchemy import select from apps.common.postgres import postgres from apps.models.node import NodeInfo from apps.scheduler.pool.pool import Pool +from apps.schemas.enum_var import SpecialCallType if TYPE_CHECKING: from pydantic import BaseModel @@ -59,6 +60,9 @@ class NodeManager: async def get_node_params(node_id: str) -> tuple[dict[str, Any], dict[str, Any]]: """获取Node数据""" # 查找Node信息 + if node_id == SpecialCallType.EMPTY.value: + # 如果是空节点,返回空Schema + return {}, {} logger.info("[NodeManager] 获取节点 %s", node_id) node_data = await NodeManager.get_node(node_id) call_id = node_data.callId diff --git a/apps/services/parameter.py b/apps/services/parameter.py index 3105af851..322224dc7 100644 --- a/apps/services/parameter.py +++ b/apps/services/parameter.py @@ -41,7 +41,7 @@ class ParameterManager: for item in operate: result.append(OperateAndBindType( operate=item, - bind_type=ConditionHandler.get_value_type_from_operate(item))) + bind_type=(await ConditionHandler.get_value_type_from_operate(item)))) return result @staticmethod @@ -51,8 +51,10 @@ class ParameterManager: q = [step_id] in_edges = {} step_id_to_node_id = {} + step_id_to_node_name = {} for step in flow.nodes: step_id_to_node_id[step.step_id] = step.node_id + step_id_to_node_name[step.step_id] = step.name for edge in flow.edges: if edge.target_node not in in_edges: in_edges[edge.target_node] = [] @@ -65,16 +67,17 @@ class ParameterManager: if pre_node_id not in q: q.append(pre_node_id) pre_step_params = [] - for step_id in q: + for i in range(1, len(q)): + step_id = q[i] node_id = step_id_to_node_id.get(step_id) params_schema, output_schema = await NodeManager.get_node_params(node_id) slot = Slot(output_schema) - params_node = slot.get_params_node_from_schema(root='/output') + params_node = slot.get_params_node_from_schema() pre_step_params.append( StepParams( - stepId=node_id, - name=params_schema.get("name", ""), + stepId=step_id, + name=step_id_to_node_name.get(step_id), paramsNode=params_node, - ) + ), ) return pre_step_params diff --git a/apps/services/rag.py b/apps/services/rag.py index 148e39cf8..d1e91ac30 100644 --- a/apps/services/rag.py +++ b/apps/services/rag.py @@ -17,7 +17,6 @@ from apps.llm.token import TokenCalculator from apps.schemas.config import LLMConfig from apps.schemas.enum_var import EventType from apps.schemas.rag_data import RAGQueryReq -from apps.services.activity import Activity from apps.services.session import SessionManager logger = logging.getLogger(__name__) @@ -259,8 +258,6 @@ class RAG: result_only=False, model=llm.model_name, ): - if not await Activity.is_active(user_sub): - return current_chunk = buffer + chunk # 防止脚注被截断 if len(chunk) >= 2 and chunk[-2:] != "]]": diff --git a/apps/services/task.py b/apps/services/task.py index 28ab72d54..ec8436b18 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -85,7 +85,7 @@ class TaskManager: return [] flow_context_list = [] - for flow_context_id in records[0]["records"]["flow"]: + for flow_context_id in records[0]["records"]["flow"]["history_ids"]: flow_context = await flow_context_collection.find_one({"_id": flow_context_id}) if flow_context: flow_context_list.append(flow_context) @@ -207,7 +207,6 @@ class TaskManager: conversation_id=post_body.conversation_id, group_id=post_body.group_id if post_body.group_id else "", ), - state=None, tokens=TaskTokens(), runtime=TaskRuntime(), ) -- Gitee