diff --git a/apps/routers/chat.py b/apps/routers/chat.py
index b1e92b8aa7c6a833ee310e2899115c15b26480a8..840c2fb9aa386b47a19e609f5aa17592695a40d4 100644
--- a/apps/routers/chat.py
+++ b/apps/routers/chat.py
@@ -13,6 +13,7 @@ from apps.common.wordscheck import WordsCheck
from apps.dependency import verify_personal_token, verify_session
from apps.scheduler.scheduler import Scheduler
from apps.scheduler.scheduler.context import save_data
+from apps.schemas.enum_var import FlowStatus
from apps.schemas.request_data import RequestData
from apps.schemas.response_data import ResponseData
from apps.schemas.task import Task
@@ -82,8 +83,8 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str)
# 获取最终答案
task = scheduler.task
- if not task.runtime.answer:
- logger.error("[Chat] 答案为空")
+ if task.state.flow_status == FlowStatus.ERROR:
+ logger.error("[Chat] 生成答案失败")
yield "data: [ERROR]\n\n"
await Activity.remove_active(user_sub)
return
diff --git a/apps/routers/mcp_service.py b/apps/routers/mcp_service.py
index ab01d52c59523b59849c1fb6fbc0873f9744d6c7..dd0c0765c1a122c7b0d7b49cd218fc9002765d2b 100644
--- a/apps/routers/mcp_service.py
+++ b/apps/routers/mcp_service.py
@@ -281,7 +281,7 @@ async def active_or_deactivate_mcp_service(
"""激活/取消激活mcp"""
try:
if data.active:
- await MCPServiceManager.active_mcpservice(request.state.user_sub, mcpId)
+ await MCPServiceManager.active_mcpservice(request.state.user_sub, service_id, data.mcp_env)
else:
await MCPServiceManager.deactive_mcpservice(request.state.user_sub, mcpId)
except Exception as e:
diff --git a/apps/routers/record.py b/apps/routers/record.py
index 4f424e8145c9aa8b333830ce5c6dae7739a8e31d..02c629d02623681cdeedcfaa33ed9562f5c5300b 100644
--- a/apps/routers/record.py
+++ b/apps/routers/record.py
@@ -84,16 +84,14 @@ async def get_record(request: Request, conversationId: Annotated[uuid.UUID, Path
# 获得Record关联的文档
tmp_record.document = await DocumentManager.get_used_docs_by_record(user_sub, record_group.id)
-
# 获得Record关联的flow数据
flow_step_list = await TaskManager.get_context_by_record_id(record_group.id, record.id)
if flow_step_list:
- first_step_history = FlowStepHistory.model_validate(flow_step_list[0])
tmp_record.flow = RecordFlow(
- id=first_step_history.flow_name, # TODO: 此处前端应该用name
+ id=record.flow.flow_id, # TODO: 此处前端应该用name
recordId=record.id,
- flowStatus=first_step_history.flow_status,
- flowId=first_step_history.id,
+ flowStatus=record.flow.flow_staus,
+ flowId=record.flow.flow_id,
stepNum=len(flow_step_list),
steps=[],
)
diff --git a/apps/scheduler/call/choice/choice.py b/apps/scheduler/call/choice/choice.py
index 0d9ace600e4109de2d36325fd08863b88a596e7c..c78cd4ab5bda2bf71ad968480444bcad1c8be544 100644
--- a/apps/scheduler/call/choice/choice.py
+++ b/apps/scheduler/call/choice/choice.py
@@ -113,7 +113,7 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput):
valid_conditions.append(condition)
# 如果所有条件都无效,抛出异常
- if not valid_conditions:
+ if not valid_conditions and not choice.is_default:
msg = "分支没有有效条件"
logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}")
continue
diff --git a/apps/scheduler/call/choice/condition_handler.py b/apps/scheduler/call/choice/condition_handler.py
index f8a9e40e363a88a07a67bf198c24cc9cd849157f..3ee048252b2134f69310476e9806bce8d3d6df9f 100644
--- a/apps/scheduler/call/choice/condition_handler.py
+++ b/apps/scheduler/call/choice/condition_handler.py
@@ -1,12 +1,16 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""处理条件分支的工具"""
-
import logging
from pydantic import BaseModel
-from apps.scheduler.call.choice.schema import ChoiceBranch, Condition, Logic, Value
+from apps.scheduler.call.choice.schema import (
+ ChoiceBranch,
+ Condition,
+ Logic,
+ Value,
+)
from apps.schemas.parameters import (
BoolOperate,
DictOperate,
@@ -69,15 +73,17 @@ class ConditionHandler(BaseModel):
@staticmethod
def handler(choices: list[ChoiceBranch]) -> str:
"""处理条件"""
- default_branch = [c for c in choices if c.is_default]
-
- for block_judgement in choices:
+ for block_judgement in choices[::-1]:
results = []
if block_judgement.is_default:
- continue
+ return block_judgement.branch_id
for condition in block_judgement.conditions:
result = ConditionHandler._judge_condition(condition)
- results.append(result)
+ if result is not None:
+ results.append(result)
+ if not results:
+ logger.warning(f"[Choice] 分支 {block_judgement.branch_id} 条件处理失败: 没有有效的条件")
+ continue
if block_judgement.logic == Logic.AND:
final_result = all(results)
elif block_judgement.logic == Logic.OR:
@@ -85,10 +91,6 @@ class ConditionHandler(BaseModel):
if final_result:
return block_judgement.branch_id
-
- # 如果没有匹配的分支,选择默认分支
- if default_branch:
- return default_branch[0].branch_id
return ""
@staticmethod
@@ -112,7 +114,7 @@ class ConditionHandler(BaseModel):
if value_type == Type.STRING:
result = ConditionHandler._judge_string_condition(left, operate, right)
elif value_type == Type.NUMBER:
- result = ConditionHandler._judge_int_condition(left, operate, right)
+ result = ConditionHandler._judge_number_condition(left, operate, right)
elif value_type == Type.BOOL:
result = ConditionHandler._judge_bool_condition(left, operate, right)
elif value_type == Type.LIST:
@@ -120,9 +122,9 @@ class ConditionHandler(BaseModel):
elif value_type == Type.DICT:
result = ConditionHandler._judge_dict_condition(left, operate, right)
else:
- logger.error("不支持的数据类型: %s", value_type)
msg = f"不支持的数据类型: {value_type}"
- raise ValueError(msg)
+ logger.error(f"[Choice] 条件处理失败: {msg}")
+ return None
return result
@staticmethod
@@ -141,11 +143,10 @@ class ConditionHandler(BaseModel):
"""
left_value = left.value
if not isinstance(left_value, str):
- logger.error("左值不是字符串类型: %s", left_value)
- msg = "左值必须是字符串类型"
- raise TypeError(msg)
+ msg = f"左值必须是字符串类型 ({left_value})"
+ logger.warning(msg)
+ return None
right_value = right.value
- result = False
if operate == StringOperate.EQUAL:
return left_value == right_value
elif operate == StringOperate.NOT_EQUAL:
@@ -189,9 +190,9 @@ class ConditionHandler(BaseModel):
"""
left_value = left.value
if not isinstance(left_value, (int, float)):
- logger.error("左值不是数字类型: %s", left_value)
- msg = "左值必须是数字类型"
- raise TypeError(msg)
+ msg = f"左值必须是数字类型 ({left_value})"
+ logger.warning(msg)
+ return None
right_value = right.value
if operate == NumberOperate.EQUAL:
return left_value == right_value
@@ -223,9 +224,9 @@ class ConditionHandler(BaseModel):
"""
left_value = left.value
if not isinstance(left_value, bool):
- logger.error("左值不是布尔类型: %s", left_value)
msg = "左值必须是布尔类型"
- raise TypeError(msg)
+ logger.warning(msg)
+ return None
right_value = right.value
if operate == BoolOperate.EQUAL:
return left_value == right_value
@@ -253,9 +254,9 @@ class ConditionHandler(BaseModel):
"""
left_value = left.value
if not isinstance(left_value, list):
- logger.error("左值不是列表类型: %s", left_value)
- msg = "左值必须是列表类型"
- raise TypeError(msg)
+ msg = f"左值必须是列表类型 ({left_value})"
+ logger.warning(msg)
+ return None
right_value = right.value
if operate == ListOperate.EQUAL:
return left_value == right_value
@@ -293,9 +294,9 @@ class ConditionHandler(BaseModel):
"""
left_value = left.value
if not isinstance(left_value, dict):
- logger.error("左值不是字典类型: %s", left_value)
- msg = "左值必须是字典类型"
- raise TypeError(msg)
+ msg = f"左值必须是字典类型 ({left_value})"
+ logger.warning(msg)
+ return None
right_value = right.value
if operate == DictOperate.EQUAL:
return left_value == right_value
diff --git a/apps/scheduler/call/choice/schema.py b/apps/scheduler/call/choice/schema.py
index 9e886c57729315e4b8c54c8a8af52098590e450b..1be23761b32f71e2e170c0f388bc46a37fc798f1 100644
--- a/apps/scheduler/call/choice/schema.py
+++ b/apps/scheduler/call/choice/schema.py
@@ -28,6 +28,7 @@ class Value(DataBase):
step_id: str | None = Field(description="步骤id", default=None)
type: Type | None = Field(description="值的类型", default=None)
+ name: str | None = Field(description="值的名称", default=None)
value: str | float | int | bool | list | dict | None = Field(description="值", default=None)
diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py
index 91f22409f111448878f7773c8fdf3bf48f975e2a..ec51c45f3e8e1b2672dc3f552b5c0e272b228d91 100644
--- a/apps/scheduler/call/core.py
+++ b/apps/scheduler/call/core.py
@@ -131,7 +131,7 @@ class CoreCall(BaseModel):
:return: 变量
"""
split_path = path.split("/")
- if len(split_path) < 2:
+ if len(split_path) < 1:
err = f"[CoreCall] 路径格式错误: {path}"
logger.error(err)
return None
@@ -141,7 +141,7 @@ class CoreCall(BaseModel):
return None
data = history[split_path[0]].output_data
- for key in split_path[2:]:
+ for key in split_path[1:]:
if key not in data:
err = f"[CoreCall] 输出Key {key} 不存在"
logger.error(err)
diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py
index 89f8326b831b63b9071c04cde118bb37e851dbc1..0f4978a1eec3f1a05d0e829defb455d4088e82ab 100644
--- a/apps/scheduler/executor/flow.py
+++ b/apps/scheduler/executor/flow.py
@@ -47,6 +47,10 @@ class FlowExecutor(BaseExecutor):
flow_id: str = Field(description="Flow ID")
question: str = Field(description="用户输入")
post_body_app: RequestDataApp = Field(description="请求体中的app信息")
+ current_step: StepQueueItem | None = Field(
+ description="当前执行的步骤",
+ default=None,
+ )
async def load_state(self) -> None:
@@ -73,13 +77,13 @@ class FlowExecutor(BaseExecutor):
self.step_queue: deque[StepQueueItem] = deque()
- async def _invoke_runner(self, queue_item: StepQueueItem) -> None:
+ async def _invoke_runner(self) -> None:
"""单一Step执行"""
# 创建步骤Runner
step_runner = StepExecutor(
msg_queue=self.msg_queue,
task=self.task,
- step=queue_item,
+ step=self.current_step,
background=self.background,
question=self.question,
)
@@ -97,12 +101,12 @@ class FlowExecutor(BaseExecutor):
"""执行当前queue里面的所有步骤(在用户看来是单一Step)"""
while True:
try:
- queue_item = self.step_queue.pop()
+ self.current_step = self.step_queue.pop()
except IndexError:
break
# 执行Step
- await self._invoke_runner(queue_item)
+ await self._invoke_runner()
async def _find_next_id(self, step_id: str) -> list[str]:
@@ -119,17 +123,17 @@ class FlowExecutor(BaseExecutor):
# 如果当前步骤为结束,则直接返回
if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type]
return []
- if self.task.state.step_name == "Choice":
+ if self.current_step.step.type == SpecialCallType.CHOICE.value:
# 如果是choice节点,获取分支ID
branch_id = self.task.context[-1]["output_data"]["branch_id"]
if branch_id:
- self.task.state.step_id = self.task.state.step_id + "." + branch_id
+ next_steps = await self._find_next_id(self.task.state.step_id + "." + branch_id)
logger.info("[FlowExecutor] 分支ID:%s", branch_id)
else:
logger.warning("[FlowExecutor] 没有找到分支ID,返回空列表")
return []
-
- next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type]
+ else:
+ next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type]
# 如果step没有任何出边,直接跳到end
if not next_steps:
return [
@@ -180,6 +184,7 @@ class FlowExecutor(BaseExecutor):
self.task.state.flow_status = FlowStatus.RUNNING # type: ignore[arg-type]
# 运行Flow(未达终点)
+ is_error = False
while not self._reached_end:
# 如果当前步骤出错,执行错误处理步骤
if self.task.state.step_status == StepStatus.ERROR: # type: ignore[arg-type]
@@ -202,7 +207,7 @@ class FlowExecutor(BaseExecutor):
enable_filling=False,
to_user=False,
))
- self.task.state.flow_status = FlowStatus.ERROR # type: ignore[arg-type]
+ is_error = True
# 错误处理后结束
self._reached_end = True
@@ -217,6 +222,12 @@ class FlowExecutor(BaseExecutor):
for step in next_step:
self.step_queue.append(step)
+ # 更新Task状态
+ if is_error:
+ self.task.state.flow_status = FlowStatus.ERROR # type: ignore[arg-type]
+ else:
+ self.task.state.flow_status = FlowStatus.SUCCESS # type: ignore[arg-type]
+
# 尾插运行结束后的系统步骤
for step in FIXED_STEPS_AFTER_END:
self.step_queue.append(StepQueueItem(
@@ -228,6 +239,7 @@ class FlowExecutor(BaseExecutor):
# FlowStop需要返回总时间,需要倒推最初的开始时间(当前时间减去当前已用总时间)
self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.full_time
# 推送Flow停止消息
- await self.push_message(EventType.FLOW_STOP.value)
- # 更新Flow状态
- self.task.state.flow_status = FlowStatus.SUCCESS # type: ignore[arg-type]
+ if is_error:
+ await self.push_message(EventType.FLOW_FAILED.value)
+ else:
+ await self.push_message(EventType.FLOW_SUCCESS.value)
diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py
index 941e8822fe26ae663e31981c1c14867370d8db77..6b78e65dbb66d2e2ce3782543f2a0698327df373 100644
--- a/apps/scheduler/mcp/host.py
+++ b/apps/scheduler/mcp/host.py
@@ -101,7 +101,7 @@ class MCPHost:
task_id=self._task_id,
flow_id=self._runtime_id,
flow_name=self._runtime_name,
- flow_status=StepStatus.SUCCESS,
+ flow_status=StepStatus.RUNNING,
step_id=tool.id,
step_name=tool.toolName,
# description是规划的实际内容
diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py
index acdd4871e58989022ab0d2fd6aae68c9aa671929..a8ebec7b4d1182684a02e5b9f9ce68189af6c2e4 100644
--- a/apps/scheduler/mcp_agent/host.py
+++ b/apps/scheduler/mcp_agent/host.py
@@ -123,7 +123,7 @@ class MCPHost:
return output_data
- async def _fill_params(self, tool: MCPTool, query: str) -> dict[str, Any]:
+ async def _fill_params(self, schema: dict[str, Any], query: str) -> dict[str, Any]:
"""填充工具参数"""
# 更清晰的输入·指令,这样可以调用generate
llm_query = rf"""
@@ -139,7 +139,7 @@ class MCPHost:
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": await self.assemble_memory()},
],
- tool.input_schema,
+ schema,
)
return await json_generator.generate()
diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py
index cd4f5975eea3f023a92626966081c2d1eb33bdb7..ec90f57b8bb490db64e3a4652933ad3b369ad4d9 100644
--- a/apps/scheduler/mcp_agent/plan.py
+++ b/apps/scheduler/mcp_agent/plan.py
@@ -1,19 +1,29 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""MCP 用户目标拆解与规划"""
+from typing import Any
from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment
from apps.llm.function import JsonGenerator
from apps.llm.reasoning import ReasoningLLM
-from apps.scheduler.mcp.prompt import CREATE_PLAN, FINAL_ANSWER
-from apps.schemas.mcp import MCPPlan, MCPTool
+from apps.scheduler.mcp_agent.prompt import (
+ CREATE_PLAN,
+ EVALUATE_GOAL,
+ FINAL_ANSWER,
+ GENERATE_FLOW_NAME,
+ GET_MISSING_PARAMS,
+ RECREATE_PLAN,
+ RISK_EVALUATE,
+)
+from apps.scheduler.slot.slot import Slot
+from apps.schemas.mcp import GoalEvaluationResult, MCPPlan, MCPTool, ToolRisk
class MCPPlanner:
"""MCP 用户目标拆解与规划"""
- def __init__(self, user_goal: str) -> None:
+ def __init__(self, user_goal: str, resoning_llm: ReasoningLLM = None) -> None:
"""初始化MCP规划器"""
self.user_goal = user_goal
self._env = SandboxedEnvironment(
@@ -22,37 +32,19 @@ class MCPPlanner:
trim_blocks=True,
lstrip_blocks=True,
)
+ self.resoning_llm = resoning_llm or ReasoningLLM()
self.input_tokens = 0
self.output_tokens = 0
-
- async def create_plan(self, tool_list: list[MCPTool], max_steps: int = 6) -> MCPPlan:
- """规划下一步的执行流程,并输出"""
- # 获取推理结果
- result = await self._get_reasoning_plan(tool_list, max_steps)
-
- # 解析为结构化数据
- return await self._parse_plan_result(result, max_steps)
-
-
- async def _get_reasoning_plan(self, tool_list: list[MCPTool], max_steps: int) -> str:
- """获取推理大模型的结果"""
- # 格式化Prompt
- template = self._env.from_string(CREATE_PLAN)
- prompt = template.render(
- goal=self.user_goal,
- tools=tool_list,
- max_num=max_steps,
- )
-
+ async def get_resoning_result(self, prompt: str) -> str:
+ """获取推理结果"""
# 调用推理大模型
message = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
]
- reasoning_llm = ReasoningLLM()
result = ""
- async for chunk in reasoning_llm.call(
+ async for chunk in self.resoning_llm.call(
message,
streaming=False,
temperature=0.07,
@@ -61,19 +53,12 @@ class MCPPlanner:
result += chunk
# 保存token用量
- self.input_tokens = reasoning_llm.input_tokens
- self.output_tokens = reasoning_llm.output_tokens
-
+ self.input_tokens = self.resoning_llm.input_tokens
+ self.output_tokens = self.resoning_llm.output_tokens
return result
-
- async def _parse_plan_result(self, result: str, max_steps: int) -> MCPPlan:
- """将推理结果解析为结构化数据"""
- # 格式化Prompt
- schema = MCPPlan.model_json_schema()
- schema["properties"]["plans"]["maxItems"] = max_steps
-
- # 使用Function模型解析结果
+ async def _parse_result(self, result: str, schema: dict[str, Any]) -> str:
+ """解析推理结果"""
json_generator = JsonGenerator(
result,
[
@@ -82,9 +67,138 @@ class MCPPlanner:
],
schema,
)
- plan = await json_generator.generate()
+ json_result = await json_generator.generate()
+ return json_result
+
+ async def evaluate_goal(self, tool_list: list[MCPTool]) -> GoalEvaluationResult:
+ """评估用户目标的可行性"""
+ # 获取推理结果
+ result = await self._get_reasoning_evaluation(tool_list)
+
+ # 解析为结构化数据
+ evaluation = await self._parse_evaluation_result(result)
+
+ # 返回评估结果
+ return evaluation
+
+ async def _get_reasoning_evaluation(self, tool_list: list[MCPTool]) -> str:
+ """获取推理大模型的评估结果"""
+ template = self._env.from_string(EVALUATE_GOAL)
+ prompt = template.render(
+ goal=self.user_goal,
+ tools=tool_list,
+ )
+ result = await self.get_resoning_result(prompt)
+ return result
+
+ async def _parse_evaluation_result(self, result: str) -> GoalEvaluationResult:
+ """将推理结果解析为结构化数据"""
+ schema = GoalEvaluationResult.model_json_schema()
+ evaluation = await self._parse_result(result, schema)
+ # 使用GoalEvaluationResult模型解析结果
+ return GoalEvaluationResult.model_validate(evaluation)
+
+ async def get_flow_name(self) -> str:
+ """获取当前流程的名称"""
+ result = await self._get_reasoning_flow_name()
+ return result
+
+ async def _get_reasoning_flow_name(self) -> str:
+ """获取推理大模型的流程名称"""
+ template = self._env.from_string(GENERATE_FLOW_NAME)
+ prompt = template.render(goal=self.user_goal)
+ result = await self.get_resoning_result(prompt)
+ return result
+
+ async def create_plan(self, tool_list: list[MCPTool], max_steps: int = 6) -> MCPPlan:
+ """规划下一步的执行流程,并输出"""
+ # 获取推理结果
+ result = await self._get_reasoning_plan(tool_list, max_steps)
+
+ # 解析为结构化数据
+ return await self._parse_plan_result(result, max_steps)
+
+ async def _get_reasoning_plan(
+ self, is_replan: bool = False, error_message: str = "", current_plan: MCPPlan = MCPPlan(),
+ tool_list: list[MCPTool] = [],
+ max_steps: int = 10) -> str:
+ """获取推理大模型的结果"""
+ # 格式化Prompt
+ if is_replan:
+ template = self._env.from_string(RECREATE_PLAN)
+ prompt = template.render(
+ current_plan=current_plan,
+ error_message=error_message,
+ goal=self.user_goal,
+ tools=tool_list,
+ max_num=max_steps,
+ )
+ else:
+ template = self._env.from_string(CREATE_PLAN)
+ prompt = template.render(
+ goal=self.user_goal,
+ tools=tool_list,
+ max_num=max_steps,
+ )
+ result = await self.get_resoning_result(prompt)
+ return result
+
+ async def _parse_plan_result(self, result: str, max_steps: int) -> MCPPlan:
+ """将推理结果解析为结构化数据"""
+ # 格式化Prompt
+ schema = MCPPlan.model_json_schema()
+ schema["properties"]["plans"]["maxItems"] = max_steps
+ plan = await self._parse_result(result, schema)
+ # 使用Function模型解析结果
return MCPPlan.model_validate(plan)
+ async def get_tool_risk(self, tool: MCPTool, input_parm: dict[str, Any], additional_info: str = "") -> ToolRisk:
+ """获取MCP工具的风险评估结果"""
+ # 获取推理结果
+ result = await self._get_reasoning_risk(tool, input_parm, additional_info)
+
+ # 解析为结构化数据
+ risk = await self._parse_risk_result(result)
+
+ # 返回风险评估结果
+ return risk
+
+ async def _get_reasoning_risk(self, tool: MCPTool, input_param: dict[str, Any], additional_info: str) -> str:
+ """获取推理大模型的风险评估结果"""
+ template = self._env.from_string(RISK_EVALUATE)
+ prompt = template.render(
+ tool=tool,
+ input_param=input_param,
+ additional_info=additional_info,
+ )
+ result = await self.get_resoning_result(prompt)
+ return result
+
+ async def _parse_risk_result(self, result: str) -> ToolRisk:
+ """将推理结果解析为结构化数据"""
+ schema = ToolRisk.model_json_schema()
+ risk = await self._parse_result(result, schema)
+ # 使用ToolRisk模型解析结果
+ return ToolRisk.model_validate(risk)
+
+ async def get_missing_param(
+ self, tool: MCPTool, schema: dict[str, Any],
+ input_param: dict[str, Any],
+ error_message: str) -> list[str]:
+ """获取缺失的参数"""
+ slot = Slot(schema=schema)
+ schema_with_null = slot.add_null_to_basic_types()
+ template = self._env.from_string(GET_MISSING_PARAMS)
+ prompt = template.render(
+ tool=tool,
+ input_param=input_param,
+ schema=schema_with_null,
+ error_message=error_message,
+ )
+ result = await self.get_resoning_result(prompt)
+ # 解析为结构化数据
+ input_param_with_null = await self._parse_result(result, schema_with_null)
+ return input_param_with_null
async def generate_answer(self, plan: MCPPlan, memory: str) -> str:
"""生成最终回答"""
diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py
index b322fb0883e8ed935243389cb86066845a549631..cf05afc877e9ce08d790799d08e45c1f321d2d21 100644
--- a/apps/scheduler/mcp_agent/prompt.py
+++ b/apps/scheduler/mcp_agent/prompt.py
@@ -62,6 +62,76 @@ MCP_SELECT = dedent(r"""
### 请一步一步思考:
""")
+EVALUATE_GOAL = dedent(r"""
+ 你是一个计划评估器。
+ 请根据用户的目标和当前的工具集合以及一些附加信息,判断基于当前的工具集合,是否能够完成用户的目标。
+ 如果能够完成,请返回`true`,否则返回`false`。
+ 推理过程必须清晰明了,能够让人理解你的判断依据。
+ 必须按照以下格式回答:
+ ```json
+ {
+ "can_complete": true/false,
+ "resoning": "你的推理过程"
+ }
+ ```
+
+ # 样例
+ ## 目标
+ 我需要扫描当前mysql数据库,分析性能瓶颈,并调优
+
+ ## 工具集合
+ 你可以访问并使用一些工具,这些工具将在 XML标签中给出。
+
+ - mysql_analyzer分析MySQL数据库性能
+ - performance_tuner调优数据库性能
+ - Final结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。
+
+
+ ## 附加信息
+ 1. 当前MySQL数据库的版本是8.0.26
+ 2. 当前MySQL数据库的配置文件路径是/etc/my.cnf
+
+ ##
+ ```json
+ {
+ "can_complete": true,
+ "resoning": "当前的工具集合中包含mysql_analyzer和performance_tuner,能够完成对MySQL数据库的性能分析和调优,因此可以完成用户的目标。"
+ }
+ ```
+
+ # 目标
+ {{ goal }}
+
+ # 工具集合
+
+ {% for tool in tools %}
+ - {{ tool.id }}{{tool.name}};{{ tool.description }}
+ {% endfor %}
+
+
+ # 附加信息
+ {{ additional_info }}
+
+""")
+GENERATE_FLOW_NAME = dedent(r"""
+ 你是一个智能助手,你的任务是根据用户的目标,生成一个合适的流程名称。
+
+ # 生成流程名称时的注意事项:
+ 1. 流程名称应该简洁明了,能够准确表达达成用户目标的过程。
+ 2. 流程名称应该包含关键的操作或步骤,例如“扫描”、“分析”、“调优”等。
+ 3. 流程名称应该避免使用过于复杂或专业的术语,以便用户能够理解。
+ 4. 流程名称应该尽量简短,小于20个字或者单词。
+ 5. 只输出流程名称,不要输出其他内容。
+ # 样例
+ ## 目标
+ 我需要扫描当前mysql数据库,分析性能瓶颈,并调优
+ ## 输出
+ 扫描MySQL数据库并分析性能瓶颈,进行调优
+ # 现在开始生成流程名称:
+ # 目标
+ {{ goal }}
+ # 输出
+ """)
CREATE_PLAN = dedent(r"""
你是一个计划生成器。
请分析用户的目标,并生成一个计划。你后续将根据这个计划,一步一步地完成用户的目标。
@@ -106,8 +176,6 @@ CREATE_PLAN = dedent(r"""
{% for tool in tools %}
- {{ tool.id }}{{tool.name}};{{ tool.description }}
{% endfor %}
- - Final结束步骤,当执行到这一步时,\
-表示计划执行结束,所得到的结果将作为最终结果。
# 样例
@@ -163,78 +231,384 @@ CREATE_PLAN = dedent(r"""
# 计划
""")
-EVALUATE_PLAN = dedent(r"""
- 你是一个计划评估器。
- 请根据给定的计划,和当前计划执行的实际情况,分析当前计划是否合理和完整,并生成改进后的计划。
+RECREATE_PLAN = dedent(r"""
+ 你是一个计划重建器。
+ 请根据用户的目标、当前计划和运行报错,重新生成一个计划。
# 一个好的计划应该:
1. 能够成功完成用户的目标
2. 计划中的每一个步骤必须且只能使用一个工具。
3. 计划中的步骤必须具有清晰和逻辑的步骤,没有冗余或不必要的步骤。
- 4. 计划中的最后一步必须是Final工具,以确保计划执行结束。
+ 4. 你的计划必须避免之前的错误,并且能够成功执行。
+ 5. 计划中的最后一步必须是Final工具,以确保计划执行结束。
- # 你此前的计划是:
+ # 生成计划时的注意事项:
- {{ plan }}
+ - 每一条计划包含3个部分:
+ - 计划内容:描述单个计划步骤的大致内容
+ - 工具ID:必须从下文的工具列表中选择
+ - 工具指令:改写用户的目标,使其更符合工具的输入要求
+ - 必须按照如下格式生成计划,不要输出任何额外数据:
- # 这个计划的执行情况是:
+ ```json
+ {
+ "plans": [
+ {
+ "content": "计划内容",
+ "tool": "工具ID",
+ "instruction": "工具指令"
+ }
+ ]
+ }
+ ```
- 计划的执行情况将放置在 XML标签中。
+ - 在生成计划之前,请一步一步思考,解析用户的目标,并指导你接下来的生成。\
+思考过程应放置在 XML标签中。
+ - 计划内容中,可以使用"Result[]"来引用之前计划步骤的结果。例如:"Result[3]"表示引用第三条计划执行后的结果。
+ - 计划不得多于{{ max_num }}条,且每条计划内容应少于150字。
-
- {{ memory }}
-
+ # 样例
- # 进行评估时的注意事项:
+ ## 目标
- - 请一步一步思考,解析用户的目标,并指导你接下来的生成。思考过程应放置在 XML标签中。
- - 评估结果分为两个部分:
- - 计划评估的结论
- - 改进后的计划
- - 请按照以下JSON格式输出评估结果:
+ 请帮我扫描一下192.168.1.1的这台机器的端口,看看有哪些端口开放。
+ ## 工具
+ 你可以访问并使用一些工具,这些工具将在 XML标签中给出。
+
+ - command_generator生成命令行指令
+ - tool_selector选择合适的工具
+ - command_executor执行命令行指令
+ - Final结束步骤,当执行到这一步时,表示计划执行结束,所得到的结果将作为最终结果。
+
+ ## 当前计划
```json
{
- "evaluation": "评估结果",
"plans": [
{
- "content": "改进后的计划内容",
- "tool": "工具ID",
- "instruction": "工具指令"
+ "content": "生成端口扫描命令",
+ "tool": "command_generator",
+ "instruction": "生成端口扫描命令:扫描192.168.1.1的开放端口"
+ },
+ {
+ "content": "在执行Result[0]生成的命令",
+ "tool": "command_executor",
+ "instruction": "执行端口扫描命令"
+ },
+ {
+ "content": "任务执行完成,端口扫描结果为Result[2]",
+ "tool": "Final",
+ "instruction": ""
}
]
}
```
+ ## 运行报错
+ 执行端口扫描命令时,出现了错误:`-bash: curl: command not found`。
+ ## 重新生成的计划
- # 现在开始评估计划:
+
+ 1. 这个目标需要使用网络扫描工具来完成,首先需要选择合适的网络扫描工具
+ 2. 目标可以拆解为以下几个部分:
+ - 生成端口扫描命令
+ - 执行端口扫描命令
+ 3.但是在执行端口扫描命令时,出现了错误:`-bash: curl: command not found`。
+ 4.我将计划调整为:
+ - 需要先生成一个命令,查看当前机器支持哪些网络扫描工具
+ - 执行这个命令,查看当前机器支持哪些网络扫描工具
+ - 然后从中选择一个网络扫描工具
+ - 基于选择的网络扫描工具,生成端口扫描命令
+ - 执行端口扫描命令
+
+ ```json
+ {
+ "plans": [
+ {
+ "content": "需要生成一条命令查看当前机器支持哪些网络扫描工具",
+ "tool": "command_generator",
+ "instruction": "选择一个前机器支持哪些网络扫描工具"
+ },
+ {
+ "content": "执行Result[0]中生成的命令,查看当前机器支持哪些网络扫描工具",
+ "tool": "command_executor",
+ "instruction": "执行Result[0]中生成的命令"
+ },
+ {
+ "content": "从Result[1]中选择一个网络扫描工具,生成端口扫描命令",
+ "tool": "tool_selector",
+ "instruction": "选择一个网络扫描工具,生成端口扫描命令"
+ },
+ {
+ "content": "基于result[2]中选择的网络扫描工具,生成端口扫描命令",
+ "tool": "command_generator",
+ "instruction": "生成端口扫描命令:扫描192.168.1.1的开放端口"
+ },
+ {
+ "content": "在Result[0]的MCP Server上执行Result[3]生成的命令",
+ "tool": "command_executor",
+ "instruction": "执行端口扫描命令"
+ },
+ {
+ "content": "任务执行完成,端口扫描结果为Result[4]",
+ "tool": "Final",
+ "instruction": ""
+ }
+ ]
+ }
+ ```
+
+ # 现在开始重新生成计划:
+
+ # 目标
+
+ {{goal}}
+
+ # 工具
+
+ 你可以访问并使用一些工具,这些工具将在 XML标签中给出。
+
+
+ {% for tool in tools %}
+ - {{ tool.id }}{{tool.name}};{{ tool.description }}
+ {% endfor %}
+
+
+ # 当前计划
+ {{ current_plan }}
+
+ # 运行报错
+ {{ error_message }}
+
+ # 重新生成的计划
""")
+RISK_EVALUATE = dedent(r"""
+ 你是一个工具执行计划评估器。
+ 你的任务是根据当前工具的名称、描述和入参以及附加信息,判断当前工具执行的风险并输出提示。
+ ```json
+ {
+ "risk": "low/medium/high",
+ "message": "提示信息"
+ }
+ ```
+ # 样例
+ ## 工具名称
+ mysql_analyzer
+ ## 工具描述
+ 分析MySQL数据库性能
+ ## 工具入参
+ {
+ "host": "192.0.0.1",
+ "port": 3306,
+ "username": "root",
+ "password": "password"
+ }
+ ## 附加信息
+ 1. 当前MySQL数据库的版本是8.0.26
+ 2. 当前MySQL数据库的配置文件路径是/etc/my.cnf,并含有以下配置项
+ ```ini
+ [mysqld]
+ innodb_buffer_pool_size=1G
+ innodb_log_file_size=256M
+ ```
+ ## 输出
+ ```json
+ {
+ "risk": "中",
+ "message": "当前工具将连接到MySQL数据库并分析性能,可能会对数据库性能产生一定影响。请确保在非生产环境中执行此操作。"
+ }
+ ```
+ # 工具
+
+ {{ tool.name }}
+ {{ tool.description }}
+
+ # 工具入参
+ {{ input_param }}
+ # 附加信息
+ {{ additional_info }}
+ # 输出
+ """
+ )
+# 根据当前计划和报错信息决定下一步执行,具体计划有需要用户补充工具入参、重计划当前步骤、重计划接下来的所有计划
+JUDGE_NEXT_STEP = dedent(r"""
+ 你是一个计划决策器。
+ 你的任务是根据当前计划、当前使用的工具、工具入参和工具运行报错,决定下一步执行的操作。
+ 请根据以下规则进行判断:
+ 1. 仅通过补充工具入参来解决问题的,返回 fill_params;
+ 2. 需要重计划当前步骤的,返回 replan_current_step;
+ 3. 需要重计划接下来的所有计划的,返回 replan_all_steps;
+ 你的输出要以json格式返回,格式如下:
+ ```json
+ {
+ "next_step": "fill_params/replan_current_step/replan_all_steps",
+ "reason": "你的判断依据"
+ }
+ ```
+ 注意:
+ reason字段必须清晰明了,能够让人理解你的判断依据,并且不超过50个中文字或者100个英文单词。
+ # 样例
+ ## 当前计划
+ {"plans": [
+ {
+ "content": "生成端口扫描命令",
+ "tool": "command_generator",
+ "instruction": "生成端口扫描命令:扫描192.168.1.1的开放端口"
+ },
+ {
+ "content": "在执行Result[0]生成的命令",
+ "tool": "command_executor",
+ "instruction": "执行端口扫描命令"
+ },
+ {
+ "content": "任务执行完成,端口扫描结果为Result[2]",
+ "tool": "Final",
+ "instruction": ""
+ }
+ ]}
+ ## 当前使用的工具
+
+ command_executor
+ 执行命令行指令
+
+ ## 工具入参
+ {
+ "command": "nmap -sS -p--open 192.168.1.1"
+ }
+ ## 工具运行报错
+ 执行端口扫描命令时,出现了错误:`-bash: nmap: command not found`。
+ ## 输出
+ ```json
+ {
+ "next_step": "replan_all_steps",
+ "reason": "当前工具执行报错,提示nmap命令未找到,需要增加command_generator和command_executor的步骤,生成nmap安装命令并执行,之后再生成端口扫描命令并执行。"
+ }
+ ```
+ # 当前计划
+ {{ current_plan }}
+ # 当前使用的工具
+
+ {{ tool.name }}
+ {{ tool.description }}
+
+ # 工具入参
+ {{ input_param }}
+ # 工具运行报错
+ {{ error_message }}
+ # 输出
+ """
+ )
+# 获取缺失的参数的json结构体
+GET_MISSING_PARAMS = dedent(r"""
+ 你是一个工具参数获取器。
+ 你的任务是根据当前工具的名称、描述和入参和入参的schema以及运行报错,将当前缺失的参数设置为null,并输出一个JSON格式的字符串。
+ ```json
+ {
+ "host": "请补充主机地址",
+ "port": "请补充端口号",
+ "username": "请补充用户名",
+ "password": "请补充密码"
+ }
+ ```
+ # 样例
+ # 工具名称
+ mysql_analyzer
+ # 工具描述
+ 分析MySQL数据库性能
+ # 工具入参
+ {
+ "host": "192.0.0.1",
+ "port": 3306,
+ "username": "root",
+ "password": "password"
+ }
+ # 工具入参schema
+ {
+ "type": "object",
+ "properties": {
+ "host": {
+ "anyOf": [
+ {"type": "string"},
+ {"type": "null"}
+ ],
+ "description": "MySQL数据库的主机地址(可以为字符串或null)"
+ },
+ "port": {
+ "anyOf": [
+ {"type": "string"},
+ {"type": "null"}
+ ],
+ "description": "MySQL数据库的端口号(可以是数字、字符串或null)"
+ },
+ "username": {
+ "anyOf": [
+ {"type": "string"},
+ {"type": "null"}
+ ],
+ "description": "MySQL数据库的用户名(可以为字符串或null)"
+ },
+ "password": {
+ "anyOf": [
+ {"type": "string"},
+ {"type": "null"}
+ ],
+ "description": "MySQL数据库的密码(可以为字符串或null)"
+ }
+ },
+ "required": ["host", "port", "username", "password"]
+ }
+ # 运行报错
+ 执行端口扫描命令时,出现了错误:`password is not correct`。
+ # 输出
+ ```json
+ {
+ "host": "192.0.0.1",
+ "port": 3306,
+ "username": null,
+ "password": null
+ }
+ ```
+ # 工具
+ < tool >
+ < name > {{tool.name}} < /name >
+ < description > {{tool.description}} < /description >
+ < / tool >
+ # 工具入参
+ {{input_param}}
+ # 工具入参schema(部分字段允许为null)
+ {{input_schema}}
+ # 运行报错
+ {{error_message}}
+ # 输出
+ """
+ )
FINAL_ANSWER = dedent(r"""
综合理解计划执行结果和背景信息,向用户报告目标的完成情况。
# 用户目标
- {{ goal }}
+ {{goal}}
# 计划执行情况
为了完成上述目标,你实施了以下计划:
- {{ memory }}
+ {{memory}}
# 其他背景信息:
- {{ status }}
+ {{status}}
# 现在,请根据以上信息,向用户报告目标的完成情况:
""")
+
MEMORY_TEMPLATE = dedent(r"""
- {% for ctx in context_list %}
- - 第{{ loop.index }}步:{{ ctx.step_description }}
- - 调用工具 `{{ ctx.step_id }}`,并提供参数 `{{ ctx.input_data }}`
- - 执行状态:{{ ctx.status }}
- - 得到数据:`{{ ctx.output_data }}`
- {% endfor %}
+ { % for ctx in context_list % }
+ - 第{{loop.index}}步:{{ctx.step_description}}
+ - 调用工具 `{{ctx.step_id}}`,并提供参数 `{{ctx.input_data}}`
+ - 执行状态:{{ctx.status}}
+ - 得到数据:`{{ctx.output_data}}`
+ { % endfor % }
""")
diff --git a/apps/scheduler/mcp_agent/select.py b/apps/scheduler/mcp_agent/select.py
index 2ff5034471c5e9c38f166c6187b76dfb4596f734..7198940bc52b31e9dc5d760084d5e80f3af435dd 100644
--- a/apps/scheduler/mcp_agent/select.py
+++ b/apps/scheduler/mcp_agent/select.py
@@ -2,6 +2,8 @@
"""选择MCP Server及其工具"""
import logging
+import uuid
+from collections.abc import AsyncGenerator
from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment
@@ -26,8 +28,9 @@ logger = logging.getLogger(__name__)
class MCPSelector:
"""MCP选择器"""
- def __init__(self) -> None:
+ def __init__(self, resoning_llm: ReasoningLLM = None) -> None:
"""初始化助手类"""
+ self.resoning_llm = resoning_llm or ReasoningLLM()
self.input_tokens = 0
self.output_tokens = 0
@@ -101,20 +104,15 @@ class MCPSelector:
return await self._call_function_mcp(result, mcp_ids)
- async def _call_reasoning(self, prompt: str) -> str:
+ async def _call_reasoning(self, prompt: str) -> AsyncGenerator[str, None]:
"""调用大模型进行推理"""
logger.info("[MCPHelper] 调用推理大模型")
- llm = ReasoningLLM()
message = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
]
- result = ""
- async for chunk in llm.call(message):
- result += chunk
- self.input_tokens += llm.input_tokens
- self.output_tokens += llm.output_tokens
- return result
+ async for chunk in self.resoning_llm.call(message):
+ yield chunk
async def _call_function_mcp(self, reasoning_result: str, mcp_ids: list[str]) -> MCPSelectResult:
@@ -182,4 +180,17 @@ class MCPSelector:
tool_obj = MCPTool.model_validate(tool)
llm_tool_list.append(tool_obj)
+ llm_tool_list.append(
+ MCPTool(
+ id="00000000-0000-0000-0000-000000000000",
+ name="Final",
+ description="It is the final step, indicating the end of the plan execution."),
+ )
+ llm_tool_list.append(
+ MCPTool(
+ id="00000000-0000-0000-0000-000000000001",
+ name="Chat",
+ description="It is a chat tool to communicate with the user."),
+ )
+
return llm_tool_list
diff --git a/apps/scheduler/pool/loader/mcp.py b/apps/scheduler/pool/loader/mcp.py
index 00f064f33fc6e49d2a252259ce35e13904f3a760..16b1f6669a5a9781e07007c570324803e62d2740 100644
--- a/apps/scheduler/pool/loader/mcp.py
+++ b/apps/scheduler/pool/loader/mcp.py
@@ -4,6 +4,7 @@
import json
import logging
import shutil
+from typing import Any
import asyncer
from anyio import Path
@@ -329,7 +330,7 @@ class MCPLoader(metaclass=SingletonMeta):
@staticmethod
- async def user_active_template(user_sub: str, mcp_id: str) -> None:
+ async def user_active_template(user_sub: str, mcp_id: str, mcp_env: dict[str, Any]) -> None:
"""
用户激活MCP模板
@@ -348,6 +349,8 @@ class MCPLoader(metaclass=SingletonMeta):
err = f"MCP模板“{mcp_id}”已存在或有同名文件,无法激活"
raise FileExistsError(err)
+ mcp_config = await MCPLoader.get_config(mcp_id)
+ mcp_config.mcpServers[mcp_id].env.update(mcp_env)
# 拷贝文件
await asyncer.asyncify(shutil.copytree)(
template_path.as_posix(),
@@ -356,6 +359,17 @@ class MCPLoader(metaclass=SingletonMeta):
symlinks=True,
)
+ user_config_path = user_path / "config.json"
+ # 更新用户配置
+ f = await user_config_path.open("w", encoding="utf-8", errors="ignore")
+ await f.write(
+ json.dumps(
+ mcp_config.model_dump(by_alias=True, exclude_none=True),
+ indent=4,
+ ensure_ascii=False,
+ ),
+ )
+ await f.aclose()
# 更新数据库
async with postgres.session() as session:
await session.merge(MCPActivated(
diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py
index 61be756b4eb4175ddb2f74e18c899394677c9659..272a8c60e451437494ada813433ad4a58578030a 100644
--- a/apps/scheduler/scheduler/context.py
+++ b/apps/scheduler/scheduler/context.py
@@ -11,6 +11,7 @@ from apps.models.document import Document
from apps.models.record import Record, RecordFootNote, RecordMetadata
from apps.schemas.enum_var import StepStatus
from apps.schemas.record import (
+ FlowHistory,
RecordContent,
RecordDocument,
RecordGroupDocument,
@@ -174,7 +175,7 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None:
id=task.ids.record_id,
conversationId=task.ids.conversation_id,
taskId=task.id,
- user_sub=user_sub,
+ userSub=user_sub,
content=encrypt_data,
key=encrypt_config,
metadata=RecordMetadata(
@@ -185,7 +186,12 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None:
footNoteMetadataList=footnote_list,
),
createdAt=current_time,
- flow=[i["_id"] for i in task.context],
+ flow=FlowHistory(
+ flow_id=task.state.flow_id,
+ flow_name=task.state.flow_name,
+ flow_status=task.state.flow_status,
+ history_ids=[context["_id"] for context in task.context],
+ ),
)
# 修改文件状态
diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py
index bbb02d3ef56880fcf0a7ba132d5c037a9b3992df..5383c419f6de7071d307515af804a3967633d920 100644
--- a/apps/scheduler/scheduler/message.py
+++ b/apps/scheduler/scheduler/message.py
@@ -8,7 +8,7 @@ from textwrap import dedent
from apps.common.config import config
from apps.common.queue import MessageQueue
from apps.models.document import Document
-from apps.schemas.enum_var import EventType
+from apps.schemas.enum_var import EventType, FlowStatus
from apps.schemas.message import (
DocumentAddContent,
InitContent,
@@ -61,22 +61,26 @@ async def push_init_message(
async def push_rag_message(
task: Task, queue: MessageQueue, user_sub: str, llm: LLM, history: list[dict[str, str]],
doc_ids: list[str],
- rag_data: RAGQueryReq,) -> Task:
+ rag_data: RAGQueryReq,) -> None:
"""推送RAG消息"""
full_answer = ""
- async for chunk in RAG.chat_with_llm_base_on_rag(user_sub, llm, history, doc_ids, rag_data):
- task, content_obj = await _push_rag_chunk(task, queue, chunk)
- if content_obj.event_type == EventType.TEXT_ADD.value:
- # 如果是文本消息,直接拼接到答案中
- full_answer += content_obj.content
- elif content_obj.event_type == EventType.DOCUMENT_ADD.value:
- task.runtime.documents.append(content_obj.content)
-
+ try:
+ async for chunk in RAG.chat_with_llm_base_on_rag(user_sub, llm, history, doc_ids, rag_data):
+ task, content_obj = await _push_rag_chunk(task, queue, chunk)
+ if content_obj.event_type == EventType.TEXT_ADD.value:
+ # 如果是文本消息,直接拼接到答案中
+ full_answer += content_obj.content
+ elif content_obj.event_type == EventType.DOCUMENT_ADD.value:
+ task.runtime.documents.append(content_obj.content)
+ task.state.flow_status = FlowStatus.SUCCESS
+ except Exception as e:
+ logger.error(f"[Scheduler] RAG服务发生错误: {e}")
+ task.state.flow_status = FlowStatus.ERROR
# 保存答案
task.runtime.answer = full_answer
+ task.tokens.full_time = round(datetime.now(UTC).timestamp(), 2) - task.tokens.time
await TaskManager.save_task(task.id, task)
- return task
async def _push_rag_chunk(task: Task, queue: MessageQueue, content: str) -> tuple[Task, RAGEventData]:
diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py
index ef0a747d77c6c6d46ec5e1b66bbc05f5c1b4cacf..d35bbb21828cf849746e4c3d1bf935fe544e3da5 100644
--- a/apps/scheduler/scheduler/scheduler.py
+++ b/apps/scheduler/scheduler/scheduler.py
@@ -1,6 +1,7 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""Scheduler模块"""
+import asyncio
import logging
from datetime import UTC, datetime
@@ -16,11 +17,12 @@ from apps.scheduler.scheduler.message import (
push_init_message,
push_rag_message,
)
-from apps.schemas.enum_var import AppType, EventType
+from apps.schemas.enum_var import AppType, EventType, FlowStatus
from apps.schemas.rag_data import RAGQueryReq
from apps.schemas.request_data import RequestData
from apps.schemas.scheduler import ExecutorBackground
from apps.schemas.task import Task
+from apps.services.activity import Activity
from apps.services.appcenter import AppCenterManager
from apps.services.knowledge import KnowledgeBaseManager
from apps.services.llm import LLMManager
@@ -43,6 +45,29 @@ class Scheduler:
self.queue = queue
self.post_body = post_body
+
+ async def _monitor_activity(self, kill_event, user_sub):
+ """监控用户活动状态,不活跃时终止工作流"""
+ try:
+ check_interval = 0.5 # 每0.5秒检查一次
+
+ while not kill_event.is_set():
+ # 检查用户活动状态
+ is_active = await Activity.is_active(user_sub)
+
+ if not is_active:
+ logger.warning("[Scheduler] 用户 %s 不活跃,终止工作流", user_sub)
+ kill_event.set()
+ break
+
+ # 控制检查频率
+ await asyncio.sleep(check_interval)
+ except asyncio.CancelledError:
+ logger.info("[Scheduler] 活动监控任务已取消")
+ except Exception as e:
+ logger.error(f"[Scheduler] 活动监控过程中发生错误: {e}")
+
+
async def run(self) -> None: # noqa: PLR0911
"""运行调度器"""
try:
@@ -92,6 +117,9 @@ class Scheduler:
# 如果是智能问答,直接执行
logger.info("[Scheduler] 开始执行")
+ # 创建用于通信的事件
+ kill_event = asyncio.Event()
+ monitor = asyncio.create_task(self._monitor_activity(kill_event, self.task.ids.user_sub))
if not self.post_body.app or self.post_body.app.app_id == "":
self.task = await push_init_message(self.task, self.queue, 3, is_flow=False)
rag_data = RAGQueryReq(
@@ -99,8 +127,9 @@ class Scheduler:
query=self.post_body.question,
tokensLimit=llm.max_tokens,
)
- self.task = await push_rag_message(self.task, self.queue, self.task.ids.user_sub, llm, history, doc_ids, rag_data)
- self.task.tokens.full_time = round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.time
+ # 启动监控任务和主任务
+ main_task = asyncio.create_task(push_rag_message(
+ self.task, self.queue, self.task.ids.user_sub, llm, history, doc_ids, rag_data))
else:
# 查找对应的App元数据
app_data = await AppCenterManager.fetch_app_data_by_id(self.post_body.app.app_id)
@@ -124,7 +153,26 @@ class Scheduler:
conversation=context,
facts=facts,
)
- await self.run_executor(self.queue, self.post_body, executor_background)
+ # 启动监控任务和主任务
+ main_task = asyncio.create_task(self.run_executor(self.queue, self.post_body, executor_background))
+ # 等待任一任务完成
+ done, pending = await asyncio.wait(
+ [main_task, monitor],
+ return_when=asyncio.FIRST_COMPLETED
+ )
+
+ # 如果是监控任务触发,终止主任务
+ if kill_event.is_set():
+ logger.warning("[Scheduler] 用户活动状态检测不活跃,正在终止工作流执行...")
+ main_task.cancel()
+ need_change_cancel_flow_state = [FlowStatus.RUNNING, FlowStatus.WAITING]
+ if self.task.state.flow_status in need_change_cancel_flow_state:
+ self.task.state.flow_status = FlowStatus.CANCELLED
+ try:
+ await main_task
+ logger.info("[Scheduler] 工作流执行已被终止")
+ except Exception as e:
+ logger.error(f"[Scheduler] 终止工作流时发生错误: {e}")
# 更新Task,发送结束消息
logger.info("[Scheduler] 发送结束消息")
diff --git a/apps/scheduler/slot/slot.py b/apps/scheduler/slot/slot.py
index bb2fde49a95b6f738429893137b7c7d7d2a826a8..bda44a3cf26111facd70d3beef71c8244d0d7970 100644
--- a/apps/scheduler/slot/slot.py
+++ b/apps/scheduler/slot/slot.py
@@ -1,6 +1,7 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""参数槽位管理"""
+import copy
import json
import logging
import traceback
@@ -13,15 +14,14 @@ from jsonschema.protocols import Validator
from jsonschema.validators import extend
from apps.scheduler.call.choice.schema import Type
-from apps.schemas.response_data import ParamsNode
-
-from .parser import (
+from apps.scheduler.slot.parser import (
SlotConstParser,
SlotDateParser,
SlotDefaultParser,
SlotTimestampParser,
)
-from .util import escape_path, patch_json
+from apps.scheduler.slot.util import escape_path, patch_json
+from apps.schemas.response_data import ParamsNode
logger = logging.getLogger(__name__)
@@ -129,7 +129,7 @@ class Slot:
# Schema标准
return [_process_json_value(item, spec_data["items"]) for item in json_value]
if spec_data["type"] == "object" and isinstance(json_value, dict):
- # 若Schema不标准,则不进行处理
+ # 若Schema不标准,则不进行处理F
if "properties" not in spec_data:
return json_value
# Schema标准
@@ -157,35 +157,60 @@ class Slot:
@staticmethod
def _generate_example(schema_node: dict) -> Any: # noqa: PLR0911
"""根据schema生成示例值"""
+ if "anyOf" in schema_node or "oneOf" in schema_node:
+ # 如果有anyOf,随机返回一个示例
+ for item in schema_node["anyOf"] if "anyOf" in schema_node else schema_node["oneOf"]:
+ example = Slot._generate_example(item)
+ if example is not None:
+ return example
+
+ if "allOf" in schema_node:
+ # 如果有allOf,返回所有示例的合并
+ example = None
+ for item in schema_node["allOf"]:
+ if example is None:
+ example = Slot._generate_example(item)
+ else:
+ other_example = Slot._generate_example(item)
+ if isinstance(example, dict) and isinstance(other_example, dict):
+ example.update(other_example)
+ else:
+ example = None
+ break
+ return example
+
if "default" in schema_node:
return schema_node["default"]
if "type" not in schema_node:
return None
-
+ type_value = schema_node["type"]
+ if isinstance(type_value, list):
+ # 如果是多类型,随机返回一个示例
+ if len(type_value) > 1:
+ type_value = type_value[0]
# 处理类型为 object 的节点
- if schema_node["type"] == "object":
+ if type_value == "object":
data = {}
properties = schema_node.get("properties", {})
for name, schema in properties.items():
data[name] = Slot._generate_example(schema)
return data
-
# 处理类型为 array 的节点
- if schema_node["type"] == "array":
+ elif type_value == "array":
items_schema = schema_node.get("items", {})
return [Slot._generate_example(items_schema)]
# 处理类型为 string 的节点
- if schema_node["type"] == "string":
+ elif type_value == "string":
return ""
# 处理类型为 number 或 integer 的节点
- if schema_node["type"] in ["number", "integer"]:
+ elif type_value in ["number", "integer"]:
return 0
# 处理类型为 boolean 的节点
- if schema_node["type"] == "boolean":
+ elif type_value == "boolean":
return False
# 处理其他类型或未定义类型
@@ -199,29 +224,69 @@ class Slot:
"""从JSON Schema中提取类型描述"""
def _extract_type_desc(schema_node: dict[str, Any]) -> dict[str, Any]:
- if "type" not in schema_node and "anyOf" not in schema_node:
- return {}
- data = {"type": schema_node.get("type", ""), "description": schema_node.get("description", "")}
- if "anyOf" in schema_node:
- data["type"] = "anyOf"
- # 处理类型为 object 的节点
- if "anyOf" in schema_node:
- data["items"] = {}
- type_index = 0
- for type_index, sub_schema in enumerate(schema_node["anyOf"]):
- sub_result = _extract_type_desc(sub_schema)
- if sub_result:
- data["items"]["type_"+str(type_index)] = sub_result
- if schema_node.get("type", "") == "object":
- data["items"] = {}
+ # 处理组合关键字
+ special_keys = ["anyOf", "allOf", "oneOf"]
+ for key in special_keys:
+ if key in schema_node:
+ data = {
+ "type": key,
+ "description": schema_node.get("description", ""),
+ "items": {},
+ }
+ type_index = 0
+ for item in schema_node[key]:
+ if isinstance(item, dict):
+ data["items"][f"item_{type_index}"] = _extract_type_desc(item)
+ else:
+ data["items"][f"item_{type_index}"] = {"type": item, "description": ""}
+ type_index += 1
+ return data
+ # 处理基本类型
+ type_val = schema_node.get("type", "")
+ description = schema_node.get("description", "")
+
+ # 处理多类型数组
+ if isinstance(type_val, list):
+ if len(type_val) > 1:
+ data = {"type": "union", "description": description, "items": {}}
+ type_index = 0
+ for t in type_val:
+ if t == "object":
+ tmp_dict = {}
+ for key, val in schema_node.get("properties", {}).items():
+ tmp_dict[key] = _extract_type_desc(val)
+ data["items"][f"item_{type_index}"] = tmp_dict
+ elif t == "array":
+ items_schema = schema_node.get("items", {})
+ data["items"][f"item_{type_index}"] = _extract_type_desc(items_schema)
+ else:
+ data["items"][f"item_{type_index}"] = {"type": t, "description": description}
+ type_index += 1
+ return data
+ elif len(type_val) == 1:
+ type_val = type_val[0]
+ else:
+ type_val = ""
+
+ data = {"type": type_val, "description": description, "items": {}}
+
+ # 递归处理对象和数组
+ if type_val == "object":
for key, val in schema_node.get("properties", {}).items():
data["items"][key] = _extract_type_desc(val)
-
- # 处理类型为 array 的节点
- if schema_node.get("type", "") == "array":
+ elif type_val == "array":
items_schema = schema_node.get("items", {})
- data["items"] = _extract_type_desc(items_schema)
+ if isinstance(items_schema, list):
+ item_index = 0
+ for item in items_schema:
+ data["items"][f"item_{item_index}"] = _extract_type_desc(item)
+ item_index += 1
+ else:
+ data["items"]["item"] = _extract_type_desc(items_schema)
+ if data["items"] == {}:
+ del data["items"]
return data
+
return _extract_type_desc(self._schema)
def get_params_node_from_schema(self, root: str = "") -> ParamsNode:
@@ -232,13 +297,15 @@ class Slot:
return None
param_type = schema_node["type"]
+ if isinstance(param_type, list):
+ return None # 不支持多类型
if param_type == "object":
param_type = Type.DICT
elif param_type == "array":
param_type = Type.LIST
elif param_type == "string":
param_type = Type.STRING
- elif param_type == "number":
+ elif param_type in ["number", "integer"]:
param_type = Type.NUMBER
elif param_type == "boolean":
param_type = Type.BOOL
@@ -247,9 +314,11 @@ class Slot:
return None
sub_params = []
- if param_type == "object" and "properties" in schema_node:
+ if param_type == Type.DICT and "properties" in schema_node:
for key, value in schema_node["properties"].items():
- sub_params.append(_extract_params_node(value, name=key, path=f"{path}/{key}"))
+ sub_param = _extract_params_node(value, name=key, path=f"{path}/{key}")
+ if sub_param:
+ sub_params.append(sub_param)
else:
# 对于非对象类型,直接返回空子参数
sub_params = None
@@ -318,7 +387,6 @@ class Slot:
logger.exception("[Slot] 错误schema不合法: %s", error.schema)
return {}, []
-
def _assemble_patch(
self,
key: str,
@@ -371,7 +439,6 @@ class Slot:
logger.info("[Slot] 组装patch: %s", patch_list)
return patch_list
-
def convert_json(self, json_data: str | dict[str, Any]) -> dict[str, Any]:
"""将用户手动填充的参数专为真实JSON"""
json_dict = json.loads(json_data) if isinstance(json_data, str) else json_data
@@ -415,3 +482,54 @@ class Slot:
return schema_template
return {}
+
+ def add_null_to_basic_types(self) -> dict[str, Any]:
+ """
+ 递归地为 JSON Schema 中的基础类型(bool、number等)添加 null 选项
+ """
+ def add_null_to_basic_types(schema: dict[str, Any]) -> dict[str, Any]:
+ """
+ 递归地为 JSON Schema 中的基础类型(bool、number等)添加 null 选项
+
+ 参数:
+ schema (dict): 原始 JSON Schema
+
+ 返回:
+ dict: 修改后的 JSON Schema
+ """
+ # 如果不是字典类型(schema),直接返回
+ if not isinstance(schema, dict):
+ return schema
+
+ # 处理当前节点的 type 字段
+ if 'type' in schema:
+ # 处理单一类型字符串
+ if isinstance(schema['type'], str):
+ if schema['type'] in ['boolean', 'number', 'string', 'integer']:
+ schema['type'] = [schema['type'], 'null']
+
+ # 处理类型数组
+ elif isinstance(schema['type'], list):
+ for i, t in enumerate(schema['type']):
+ if isinstance(t, str) and t in ['boolean', 'number', 'string', 'integer']:
+ if 'null' not in schema['type']:
+ schema['type'].append('null')
+ break
+
+ # 递归处理 properties 字段(对象类型)
+ if 'properties' in schema:
+ for prop, prop_schema in schema['properties'].items():
+ schema['properties'][prop] = add_null_to_basic_types(prop_schema)
+
+ # 递归处理 items 字段(数组类型)
+ if 'items' in schema:
+ schema['items'] = add_null_to_basic_types(schema['items'])
+
+ # 递归处理 anyOf, oneOf, allOf 字段
+ for keyword in ['anyOf', 'oneOf', 'allOf']:
+ if keyword in schema:
+ schema[keyword] = [add_null_to_basic_types(sub_schema) for sub_schema in schema[keyword]]
+
+ return schema
+ schema_copy = copy.deepcopy(self._schema)
+ return add_null_to_basic_types(schema_copy)
\ No newline at end of file
diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py
index a6a2d812434d6cf2bcd41a4211e3f50710f9ceb6..d46f5289061b002f64d2df37d4059112c4b3d3fc 100644
--- a/apps/schemas/enum_var.py
+++ b/apps/schemas/enum_var.py
@@ -15,6 +15,7 @@ class SlotType(str, Enum):
class StepStatus(str, Enum):
"""步骤状态"""
+ UNKNOWN = "unknown"
WAITING = "waiting"
RUNNING = "running"
SUCCESS = "success"
@@ -26,6 +27,7 @@ class StepStatus(str, Enum):
class FlowStatus(str, Enum):
"""Flow状态"""
+ UNKNOWN = "unknown"
WAITING = "waiting"
RUNNING = "running"
SUCCESS = "success"
@@ -57,6 +59,7 @@ class EventType(str, Enum):
STEP_OUTPUT = "step.output"
FLOW_STOP = "flow.stop"
FLOW_FAILED = "flow.failed"
+ FLOW_SUCCESS = "flow.success"
FLOW_CANCELLED = "flow.cancelled"
DONE = "done"
@@ -90,7 +93,7 @@ class NodeType(str, Enum):
START = "start"
END = "end"
NORMAL = "normal"
- CHOICE = "choice"
+ CHOICE = "Choice"
class PermissionType(str, Enum):
@@ -145,7 +148,7 @@ class SpecialCallType(str, Enum):
LLM = "LLM"
START = "start"
END = "end"
- CHOICE = "choice"
+ CHOICE = "Choice"
class CommentType(str, Enum):
diff --git a/apps/schemas/flow_topology.py b/apps/schemas/flow_topology.py
index 098393739b859fe298210e3f6da02adda8ce93d8..768a1c4d1b8c7c836e0059b1ac1589d1b858777e 100644
--- a/apps/schemas/flow_topology.py
+++ b/apps/schemas/flow_topology.py
@@ -6,6 +6,8 @@ from typing import Any
from pydantic import BaseModel, Field
+from apps.schemas.enum_var import SpecialCallType
+
from .enum_var import EdgeType
@@ -48,7 +50,7 @@ class NodeItem(BaseModel):
service_id: str = Field(alias="serviceId", default="")
node_id: str = Field(alias="nodeId", default="")
name: str = Field(default="")
- call_id: str = Field(alias="callId", default="Empty")
+ call_id: str = Field(alias="callId", default=SpecialCallType.EMPTY.value)
description: str = Field(default="")
parameters: dict[str, Any] = Field(default={})
position: PositionItem = Field(default=PositionItem())
@@ -74,6 +76,6 @@ class FlowItem(BaseModel):
nodes: list[NodeItem] = Field(default=[])
edges: list[EdgeItem] = Field(default=[])
created_at: float | None = Field(alias="createdAt", default=0)
- connectivity: bool = Field(default=False,description="图的开始节点和结束节点是否联通,并且除结束节点都有出边")
+ connectivity: bool = Field(default=False, description="图的开始节点和结束节点是否联通,并且除结束节点都有出边")
focus_point: PositionItem = Field(alias="focusPoint", default=PositionItem())
debug: bool = Field(default=False)
diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py
index 41f5b303457c4bf9bf61782016ebfd0210ad40ee..c31542bc614d99148a1fff8cf4d4e0c1e0cc6735 100644
--- a/apps/schemas/mcp.py
+++ b/apps/schemas/mcp.py
@@ -1,7 +1,6 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""MCP 相关数据结构"""
-import uuid
from enum import Enum
from pydantic import BaseModel, Field
@@ -58,6 +57,28 @@ class MCPServerConfig(MCPServerItem):
author: str = Field(description="MCP 服务器上传者", default="")
+class GoalEvaluationResult(BaseModel):
+ """MCP 目标评估结果"""
+
+ can_complete: bool = Field(description="是否可以完成目标")
+ reason: str = Field(description="评估原因")
+
+
+class Risk(str, Enum):
+ """MCP工具风险类型"""
+
+ LOW = "low"
+ MEDIUM = "medium"
+ HIGH = "high"
+
+
+class ToolRisk(BaseModel):
+ """MCP工具风险评估结果"""
+
+ risk: Risk = Field(description="风险类型", default=Risk.LOW)
+ reason: str = Field(description="风险原因", default="")
+
+
class MCPSelectResult(BaseModel):
"""MCP选择结果"""
@@ -67,7 +88,7 @@ class MCPSelectResult(BaseModel):
class MCPPlanItem(BaseModel):
"""MCP 计划"""
- id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+ step_id: str = Field(description="步骤的ID", default="")
content: str = Field(description="计划内容")
tool: str = Field(description="工具名称")
instruction: str = Field(description="工具指令")
@@ -76,4 +97,4 @@ class MCPPlanItem(BaseModel):
class MCPPlan(BaseModel):
"""MCP 计划"""
- plans: list[MCPPlanItem] = Field(description="计划列表")
+ plans: list[MCPPlanItem] = Field(description="计划列表", default=[])
diff --git a/apps/schemas/record.py b/apps/schemas/record.py
index 1ea12c39b70843484bee8d98e0a78a695e07c089..578f567ce78657fe1bb86af44629d405d8bb8a5f 100644
--- a/apps/schemas/record.py
+++ b/apps/schemas/record.py
@@ -7,7 +7,7 @@ from typing import Any, Literal
from pydantic import BaseModel, Field
-from .enum_var import CommentType, StepStatus
+from apps.schemas.enum_var import CommentType, FlowStatus, StepStatus
class RecordDocument(BaseModel):
@@ -104,6 +104,15 @@ class RecordGroupDocument(BaseModel):
created_at: float = Field(default=0.0, description="文档创建时间")
+class FlowHistory(BaseModel):
+ """Flow执行历史"""
+
+ flow_id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id")
+ flow_name: str = Field(default="", description="Flow名称")
+ flow_staus: FlowStatus = Field(default=FlowStatus.SUCCESS, description="Flow执行状态")
+ history_ids: list[str] = Field(default=[], description="Flow执行历史ID列表")
+
+
class Record(RecordData):
"""问答,用于保存在MongoDB中"""
@@ -111,4 +120,5 @@ class Record(RecordData):
key: dict[str, Any] = {}
content: str
comment: RecordComment = Field(default=RecordComment())
- flow: list[str] = Field(default=[])
+ flow: FlowHistory = Field(
+ default=FlowHistory(), description="Flow执行历史信息")
diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py
index ceca9e7496bf7577a66cb08369eafcb132f6e900..b369ddf1d3258534c941d47a4d9634da0f2ba7a8 100644
--- a/apps/schemas/request_data.py
+++ b/apps/schemas/request_data.py
@@ -92,6 +92,7 @@ class ActiveMCPServiceRequest(BaseModel):
"""POST /api/mcp/{serviceId} 请求数据结构"""
active: bool = Field(description="是否激活mcp服务")
+ mcp_env: dict[str, Any] = Field(default={}, description="MCP服务环境变量", alias="mcpEnv")
class UpdateServiceRequest(BaseModel):
diff --git a/apps/schemas/task.py b/apps/schemas/task.py
index 4d49e1c10112822ac3a21a03fcf8b4c661940a45..e0a2b39ef1342665cf8beda40ed9682113c86f9b 100644
--- a/apps/schemas/task.py
+++ b/apps/schemas/task.py
@@ -38,16 +38,18 @@ class ExecutorState(BaseModel):
"""FlowExecutor状态"""
# 执行器级数据
- flow_id: str = Field(description="Flow ID")
- flow_name: str = Field(description="Flow名称")
- description: str = Field(description="Flow描述")
- flow_status: FlowStatus = Field(description="Flow状态")
+ flow_id: str = Field(description="Flow ID", default="")
+ flow_name: str = Field(description="Flow名称", default="")
+ description: str = Field(description="Flow描述", default="")
+ flow_status: FlowStatus = Field(description="Flow状态", default=FlowStatus.UNKNOWN)
# 任务级数据
- step_id: str = Field(description="当前步骤ID")
- step_name: str = Field(description="当前步骤名称")
- step_status: StepStatus = Field(description="当前步骤状态")
+ step_id: str = Field(description="当前步骤ID", default="")
+ step_name: str = Field(description="当前步骤名称", default="")
+ step_status: StepStatus = Field(description="当前步骤状态", default=StepStatus.UNKNOWN)
step_description: str = Field(description="当前步骤描述", default="")
- app_id: str = Field(description="应用ID")
+ retry_times: int = Field(description="当前步骤重试次数", default=0)
+ error_message: str = Field(description="当前步骤错误信息", default="")
+ app_id: str = Field(description="应用ID", default="")
slot: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={})
error_info: dict[str, Any] = Field(description="错误信息", default={})
@@ -92,7 +94,7 @@ class Task(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id")
ids: TaskIds = Field(description="任务涉及的各种ID")
context: list[dict[str, Any]] = Field(description="Flow的步骤执行信息", default=[])
- state: ExecutorState | None = Field(description="Flow的状态", default=None)
+ state: ExecutorState = Field(description="Flow的状态", default=ExecutorState())
tokens: TaskTokens = Field(description="Token信息")
runtime: TaskRuntime = Field(description="任务运行时数据")
created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3))
diff --git a/apps/services/flow_validate.py b/apps/services/flow_validate.py
index b0de757aed4c845ff328b38c9e4f621502744b25..619a121fb38cf66ecb2c510065495ea48689c625 100644
--- a/apps/services/flow_validate.py
+++ b/apps/services/flow_validate.py
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING
from apps.exceptions import FlowBranchValidationError, FlowEdgeValidationError, FlowNodeValidationError
from apps.scheduler.pool.pool import Pool
-from apps.schemas.enum_var import NodeType
+from apps.schemas.enum_var import NodeType, SpecialCallType
from apps.schemas.flow_topology import EdgeItem, FlowItem, NodeItem
if TYPE_CHECKING:
@@ -43,32 +43,40 @@ class FlowService:
"""移除流程图中的多余结构"""
node_branch_map = {}
for node in flow_item.nodes:
- if node.node_id not in {"start", "end", "Empty"}:
+ if node.node_id not in {"start", "end", SpecialCallType.EMPTY.value}:
try:
call_class: type[BaseModel] = await Pool().get_call(node.call_id)
if not call_class:
- node.node_id = "Empty"
+ node.node_id = SpecialCallType.EMPTY.value
node.description = "【对应的api工具被删除!节点不可用!请联系相关人员!】\n\n"+node.description
except Exception:
- node.node_id = "Empty"
+ node.node_id = SpecialCallType.EMPTY.value
node.description = "【对应的api工具被删除!节点不可用!请联系相关人员!】\n\n"+node.description
logger.exception("[FlowService] 获取步骤的call_id失败%s", node.call_id)
node_branch_map[node.step_id] = set()
if node.call_id == NodeType.CHOICE.value:
- node.parameters = node.parameters["input_parameters"]
- if "choices" not in node.parameters:
- node.parameters["choices"] = []
- for choice in node.parameters["choices"]:
- if choice["branchId"] in node_branch_map[node.step_id]:
- err = f"[FlowService] 节点{node.name}的分支{choice['branchId']}重复"
+ input_parameters = node.parameters["input_parameters"]
+ if "choices" not in input_parameters:
+ logger.error(f"[FlowService] 节点{node.name}的分支字段缺失")
+ raise FlowBranchValidationError(f"[FlowService] 节点{node.name}的分支字段缺失")
+ if not input_parameters["choices"]:
+ logger.error(f"[FlowService] 节点{node.name}的分支字段为空")
+ raise FlowBranchValidationError(f"[FlowService] 节点{node.name}的分支字段为空")
+ for choice in input_parameters["choices"]:
+ if "branch_id" not in choice:
+ err = f"[FlowService] 节点{node.name}的分支choice缺少branch_id字段"
+ logger.error(err)
+ raise FlowBranchValidationError(err)
+ if choice["branch_id"] in node_branch_map[node.step_id]:
+ err = f"[FlowService] 节点{node.name}的分支{choice['branch_id']}重复"
logger.error(err)
raise ValueError(err)
for illegal_char in BRANCH_ILLEGAL_CHARS:
- if illegal_char in choice["branchId"]:
- err = f"[FlowService] 节点{node.name}的分支{choice['branchId']}名称中含有非法字符"
+ if illegal_char in choice["branch_id"]:
+ err = f"[FlowService] 节点{node.name}的分支{choice['branch_id']}名称中含有非法字符"
logger.error(err)
raise ValueError(err)
- node_branch_map[node.step_id].add(choice["branchId"])
+ node_branch_map[node.step_id].add(choice["branch_id"])
else:
node_branch_map[node.step_id].add("")
valid_edges = []
diff --git a/apps/services/mcp_service.py b/apps/services/mcp_service.py
index dc561eba80eb6ac113dca77727a29b1fd889f873..a08d54551ee9806d92cb3e451c207f25c3a97d9e 100644
--- a/apps/services/mcp_service.py
+++ b/apps/services/mcp_service.py
@@ -4,6 +4,7 @@
import logging
import re
import uuid
+from typing import Any
import magic
from fastapi import UploadFile
@@ -313,6 +314,7 @@ class MCPServiceManager:
async def active_mcpservice(
user_sub: str,
mcp_id: str,
+ mcp_env: dict[str, Any] = {},
) -> None:
"""
激活MCP服务
@@ -334,7 +336,7 @@ class MCPServiceManager:
userSub=user_sub,
))
await session.commit()
- await MCPLoader.user_active_template(user_sub, mcp_id)
+ await MCPLoader.user_active_template(user_sub, mcp_id, mcp_env)
@staticmethod
diff --git a/apps/services/node.py b/apps/services/node.py
index e1f3ff5e5fe36f9bee33138911933d7c6cac8f56..12438cc4e95c83feaa9a43aba45e98aefa05697d 100644
--- a/apps/services/node.py
+++ b/apps/services/node.py
@@ -9,6 +9,7 @@ from sqlalchemy import select
from apps.common.postgres import postgres
from apps.models.node import NodeInfo
from apps.scheduler.pool.pool import Pool
+from apps.schemas.enum_var import SpecialCallType
if TYPE_CHECKING:
from pydantic import BaseModel
@@ -59,6 +60,9 @@ class NodeManager:
async def get_node_params(node_id: str) -> tuple[dict[str, Any], dict[str, Any]]:
"""获取Node数据"""
# 查找Node信息
+ if node_id == SpecialCallType.EMPTY.value:
+ # 如果是空节点,返回空Schema
+ return {}, {}
logger.info("[NodeManager] 获取节点 %s", node_id)
node_data = await NodeManager.get_node(node_id)
call_id = node_data.callId
diff --git a/apps/services/parameter.py b/apps/services/parameter.py
index 3105af85191219eb5d431aceca0b3d5485945c41..322224dc757ce9eedfb859d85f7a4dd8bf57f1a1 100644
--- a/apps/services/parameter.py
+++ b/apps/services/parameter.py
@@ -41,7 +41,7 @@ class ParameterManager:
for item in operate:
result.append(OperateAndBindType(
operate=item,
- bind_type=ConditionHandler.get_value_type_from_operate(item)))
+ bind_type=(await ConditionHandler.get_value_type_from_operate(item))))
return result
@staticmethod
@@ -51,8 +51,10 @@ class ParameterManager:
q = [step_id]
in_edges = {}
step_id_to_node_id = {}
+ step_id_to_node_name = {}
for step in flow.nodes:
step_id_to_node_id[step.step_id] = step.node_id
+ step_id_to_node_name[step.step_id] = step.name
for edge in flow.edges:
if edge.target_node not in in_edges:
in_edges[edge.target_node] = []
@@ -65,16 +67,17 @@ class ParameterManager:
if pre_node_id not in q:
q.append(pre_node_id)
pre_step_params = []
- for step_id in q:
+ for i in range(1, len(q)):
+ step_id = q[i]
node_id = step_id_to_node_id.get(step_id)
params_schema, output_schema = await NodeManager.get_node_params(node_id)
slot = Slot(output_schema)
- params_node = slot.get_params_node_from_schema(root='/output')
+ params_node = slot.get_params_node_from_schema()
pre_step_params.append(
StepParams(
- stepId=node_id,
- name=params_schema.get("name", ""),
+ stepId=step_id,
+ name=step_id_to_node_name.get(step_id),
paramsNode=params_node,
- )
+ ),
)
return pre_step_params
diff --git a/apps/services/rag.py b/apps/services/rag.py
index 148e39cf8c8092ef9bf344035eb5073d6efcc055..d1e91ac307f2f2436b1ccf1da0d33c3b88a015c3 100644
--- a/apps/services/rag.py
+++ b/apps/services/rag.py
@@ -17,7 +17,6 @@ from apps.llm.token import TokenCalculator
from apps.schemas.config import LLMConfig
from apps.schemas.enum_var import EventType
from apps.schemas.rag_data import RAGQueryReq
-from apps.services.activity import Activity
from apps.services.session import SessionManager
logger = logging.getLogger(__name__)
@@ -259,8 +258,6 @@ class RAG:
result_only=False,
model=llm.model_name,
):
- if not await Activity.is_active(user_sub):
- return
current_chunk = buffer + chunk
# 防止脚注被截断
if len(chunk) >= 2 and chunk[-2:] != "]]":
diff --git a/apps/services/task.py b/apps/services/task.py
index 28ab72d54c121e786f7b782115acc83a27f9f6a7..ec8436b18cd2a8ec6b9162387dcc8c3635a3cedd 100644
--- a/apps/services/task.py
+++ b/apps/services/task.py
@@ -85,7 +85,7 @@ class TaskManager:
return []
flow_context_list = []
- for flow_context_id in records[0]["records"]["flow"]:
+ for flow_context_id in records[0]["records"]["flow"]["history_ids"]:
flow_context = await flow_context_collection.find_one({"_id": flow_context_id})
if flow_context:
flow_context_list.append(flow_context)
@@ -207,7 +207,6 @@ class TaskManager:
conversation_id=post_body.conversation_id,
group_id=post_body.group_id if post_body.group_id else "",
),
- state=None,
tokens=TaskTokens(),
runtime=TaskRuntime(),
)