diff --git a/apps/scheduler/call/choice/choice.py b/apps/scheduler/call/choice/choice.py index 1d987dec48559563addd9435da1d0dbfdf565a00..ee759c8e64ace67b9c806e04cb43cfbd554cade0 100644 --- a/apps/scheduler/call/choice/choice.py +++ b/apps/scheduler/call/choice/choice.py @@ -113,13 +113,17 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): if not success: return False, error_msg + # 更新条件对象中的解析后的值 + condition.left = _left + condition.right = _right + # 检查运算符是否有效 if condition.operate is None: return False, "条件缺少运算符" # 根据运算符确定期望的值类型 try: - expected_type = await ConditionHandler.get_value_type_from_operate(condition.operate) + expected_type = ConditionHandler.get_value_type_from_operate(condition.operate) except Exception as e: return False, f"不支持的运算符: {condition.operate}" @@ -176,7 +180,6 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): list[ChoiceBranch]: 处理后的有效分支列表 """ valid_choices = [] - for choice in self.choices: try: success, error_messages = await self._process_branch(choice, call_vars) diff --git a/apps/scheduler/call/choice/condition_handler.py b/apps/scheduler/call/choice/condition_handler.py index 6151b55d407dc99884888523bb9201f13b063af5..7048c2e98d0d1a984729ea09b8c59e5f827c931c 100644 --- a/apps/scheduler/call/choice/condition_handler.py +++ b/apps/scheduler/call/choice/condition_handler.py @@ -28,8 +28,8 @@ logger = logging.getLogger(__name__) class ConditionHandler(BaseModel): """条件分支处理器""" @staticmethod - async def get_value_type_from_operate(operate: NumberOperate | StringOperate | ListOperate | - BoolOperate | DictOperate) -> ValueType: + def get_value_type_from_operate(operate: NumberOperate | StringOperate | ListOperate | + BoolOperate | DictOperate) -> ValueType: """根据逻辑运算符获取值的类型""" if isinstance(operate, NumberOperate): return ValueType.NUMBER @@ -75,23 +75,35 @@ class ConditionHandler(BaseModel): @staticmethod def handler(choices: list[ChoiceBranch]) -> str: """处理条件""" + logger.error(choices) default_branch = [c for c in choices if c.is_default] - for block_judgement in choices[::-1]: + # 先处理所有非默认分支 + for block_judgement in choices: results = [] + # 跳过默认分支,先处理有条件的分支 if block_judgement.is_default: - return default_branch[0].branch_id + continue + for condition in block_judgement.conditions: result = ConditionHandler._judge_condition(condition) results.append(result) + if block_judgement.logic == Logic.AND: final_result = all(results) elif block_judgement.logic == Logic.OR: final_result = any(results) + else: + # 如果没有逻辑运算符但有条件,默认使用AND逻辑 + final_result = all(results) if results else False if final_result: return block_judgement.branch_id + # 如果所有非默认分支都不满足条件,返回默认分支 + if default_branch: + return default_branch[0].branch_id + return "" @staticmethod @@ -100,7 +112,7 @@ class ConditionHandler(BaseModel): 判断条件是否成立。 Args: - condition (Condition): 'left', 'operate', 'right', 'type' + condition (Condition): 'left', 'operate', 'right' Returns: bool @@ -109,7 +121,14 @@ class ConditionHandler(BaseModel): left = condition.left operate = condition.operate right = condition.right - value_type = condition.type + + # 根据操作符动态推断值类型 + try: + value_type = ConditionHandler.get_value_type_from_operate(operate) + except Exception as e: + logger.error("无法推断操作符 %s 的值类型: %s", operate, e) + msg = f"无法推断操作符 {operate} 的值类型: {e}" + raise ValueError(msg) result = None if value_type == ValueType.STRING: diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index 65338a14c6f7a2b0c44b467bc54902c7cdc1a600..3c53d91e7b38816644178544f569e9ac76831a19 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -281,7 +281,7 @@ class CoreCall(BaseModel): # 检查是否包含变量引用语法 match = re.match(r'^\{\{.*?\}\}$', value.value) - if not value.type == ValueType.REFERENCE and match: + if not value.type == ValueType.REFERENCE or not match: return value try: diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index d87932bd4afb866d218717f5dd51f6b98fba9ee0..b73e4be328817937d788aea4336988ac285227ae 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -11,6 +11,7 @@ from pydantic import Field from apps.scheduler.call.llm.prompt import LLM_ERROR_PROMPT from apps.scheduler.executor.base import BaseExecutor from apps.scheduler.executor.step import StepExecutor +from apps.scheduler.variable.integration import VariableIntegration from apps.schemas.enum_var import EventType, SpecialCallType, StepStatus from apps.schemas.flow import Flow, Step from apps.schemas.request_data import RequestDataApp @@ -114,13 +115,36 @@ class FlowExecutor(BaseExecutor): if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type] return [] if self.current_step.step.type == SpecialCallType.CHOICE.value: - # 如果是choice节点,获取分支ID - branch_id = self.task.context[-1].output_data.get("branch_id", "") - if branch_id: - next_steps = await self._find_next_id(self.task.state.step_id+"."+branch_id) - logger.info("[FlowExecutor] 分支ID:%s", branch_id) - else: - logger.warning("[FlowExecutor] 没有找到分支ID,返回空列表") + # 如果是choice节点,从变量池中获取分支ID + try: + # Choice 节点保存的变量格式:conversation.{step_id}.branch_id + branch_id_var_reference = f"conversation.{self.task.state.step_id}.branch_id" + branch_id, _ = await VariableIntegration.resolve_variable_reference( + reference=branch_id_var_reference, + user_sub=self.task.ids.user_sub, + flow_id=self.task.state.flow_id, + conversation_id=self.task.ids.conversation_id + ) + + if branch_id: + # 构建带分支ID的edge_from + edge_from = f"{self.task.state.step_id}.{branch_id}" + logger.info("[FlowExecutor] 从变量池获取分支ID:%s,查找边:%s", branch_id, edge_from) + + # 在edges中查找对应的下一个节点 + next_steps = [] + for edge in self.flow.edges: + if edge.edge_from == edge_from: + next_steps.append(edge.edge_to) + logger.info("[FlowExecutor] 找到下一个节点:%s", edge.edge_to) + + if not next_steps: + logger.warning("[FlowExecutor] 没有找到分支 %s 对应的边", edge_from) + else: + logger.warning("[FlowExecutor] 没有找到分支ID变量") + return [] + except Exception as e: + logger.error("[FlowExecutor] 从变量池获取分支ID失败: %s", e) return [] else: next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type] diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 106203d1c35fc82676a81a5b331ec813410e2d7f..466abc575d7f0bc109b5fd252a493f2c4687f2b4 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -118,6 +118,18 @@ class StepExecutor(BaseExecutor): if self.step.step.params: params.update(self.step.step.params) + # 对于需要扁平化处理的Call类型,将input_parameters中的内容提取到顶级 + if self._call_id in ["Choice", "Code", "DirectReply"] and "input_parameters" in params: + # 提取input_parameters中的所有字段到顶级 + input_params = params.get("input_parameters", {}) + if isinstance(input_params, dict): + # 将input_parameters中的字段提取到顶级 + for key, value in input_params.items(): + params[key] = value + # 移除input_parameters,避免重复 + params.pop("input_parameters", None) + logger.info(f"[StepExecutor] 对 {self._call_id} 节点进行参数扁平化处理") + try: self.obj = await call_cls.instance(self, self.node, **params) except Exception: @@ -234,11 +246,9 @@ class StepExecutor(BaseExecutor): if use_direct_format: # 配置允许的节点类型保持原有格式:conversation.key var_prefix = "" - logger.debug(f"[StepExecutor] 节点 {self.step.step.name}({self._call_id}) 使用直接变量格式") else: # 其他节点使用格式:conversation.node_id.key var_prefix = f"{self.step.step_id}." - logger.debug(f"[StepExecutor] 节点 {self.step.step.name}({self._call_id}) 使用带前缀变量格式") # 保存每个output_parameter到变量池 saved_count = 0