diff --git a/apps/routers/variable.py b/apps/routers/variable.py index 7a68e69d8034d335bea4ccb29219ec4fc147e73b..0c3ee420687437e81acfbb9b1fd50f1a71081981 100644 --- a/apps/routers/variable.py +++ b/apps/routers/variable.py @@ -67,25 +67,10 @@ async def _get_predecessor_node_variables( if node_id != current_step_id: # 不是当前节点的变量 variables.append(var) else: - # 配置阶段:优先使用缓存,降级到实时解析 try: - # 尝试使用优化的缓存服务 + # 使用缓存获取变量列表 from apps.services.predecessor_cache_service import PredecessorCacheService - # 1. 先从flow池中查找已存在的前置节点变量 - flow_pool = await pool_manager.get_flow_pool(flow_id) - if flow_pool: - flow_conversation_vars = await flow_pool.list_variables() - - # 筛选出前置节点的输出变量(格式为 node_id.key) - for var in flow_conversation_vars: - var_name = var.name - if "." in var_name and not var_name.startswith("system."): - node_id = var_name.split(".")[0] - if node_id != current_step_id: - variables.append(var) - - # 2. 使用优化的缓存服务获取前置节点变量 cached_var_data = await PredecessorCacheService.get_predecessor_variables_optimized( flow_id, current_step_id, user_sub, max_wait_time=5 ) diff --git a/apps/scheduler/call/choice/choice.py b/apps/scheduler/call/choice/choice.py index c050482c6aac1b2b9ac94f8e305403002e95915f..1d987dec48559563addd9435da1d0dbfdf565a00 100644 --- a/apps/scheduler/call/choice/choice.py +++ b/apps/scheduler/call/choice/choice.py @@ -62,7 +62,7 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): return True async def _process_condition_value(self, value: Value, call_vars: CallVars, - branch_id: str, value_position: str) -> tuple[bool, str]: + branch_id: str, value_position: str) -> tuple[bool, Value, str]: """处理条件值(左值或右值) Args: @@ -72,24 +72,22 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): value_position: 值的位置描述("左值"或"右值") Returns: - tuple[bool, str]: (是否成功, 错误消息) + tuple[bool, Value, str]: (是否成功, 处理后的Value对象,错误消息) """ # 处理reference类型 if value.type == ValueType.REFERENCE: try: - value.value = await self._resolve_variables_in_config(value.value, call_vars) + resolved_value = await self._resolve_single_value(value, call_vars) except Exception as e: - return False, f"{value_position}引用解析失败: {e}" + return False, None, f"{value_position}引用解析失败: {e}" + else: + resolved_value = value - # 检查解析后的值类型 - if value.value is None: - return False, f"{value_position}引用解析后为空" - - # 验证值类型 - if not ConditionHandler.check_value_type(value.value, value.type): - return False, f"{value_position}类型不匹配: {value.value} 应为 {value.type.value}" + # 对于非引用类型,验证值类型 + if not ConditionHandler.check_value_type(resolved_value): + return False, None, f"{value_position}类型不匹配: {resolved_value.value} 应为 {resolved_value.type}" - return True, "" + return True, resolved_value, "" async def _process_condition(self, condition: Condition, call_vars: CallVars, branch_id: str) -> tuple[bool, str]: @@ -104,16 +102,30 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): tuple[bool, str]: (是否成功, 错误消息) """ # 处理左值 - success, error_msg = await self._process_condition_value( + success, _left, error_msg = await self._process_condition_value( condition.left, call_vars, branch_id, "左值") if not success: return False, error_msg # 处理右值 - success, error_msg = await self._process_condition_value( + success, _right, error_msg = await self._process_condition_value( condition.right, call_vars, branch_id, "右值") if not success: return False, error_msg + + # 检查运算符是否有效 + if condition.operate is None: + return False, "条件缺少运算符" + + # 根据运算符确定期望的值类型 + try: + expected_type = await ConditionHandler.get_value_type_from_operate(condition.operate) + except Exception as e: + return False, f"不支持的运算符: {condition.operate}" + + # 检查类型是否与运算符匹配 + if not expected_type == _left.type == _right.type: + return False, f"左值类型 {_left.type.value} 与运算符 {condition.operate} 不匹配" return True, "" diff --git a/apps/scheduler/call/choice/condition_handler.py b/apps/scheduler/call/choice/condition_handler.py index 6b8fa3fc93bc8ec5ff5b1fc5f7641b4c7e89370b..6151b55d407dc99884888523bb9201f13b063af5 100644 --- a/apps/scheduler/call/choice/condition_handler.py +++ b/apps/scheduler/call/choice/condition_handler.py @@ -30,7 +30,7 @@ class ConditionHandler(BaseModel): @staticmethod async def get_value_type_from_operate(operate: NumberOperate | StringOperate | ListOperate | BoolOperate | DictOperate) -> ValueType: - """获取右值的类型""" + """根据逻辑运算符获取值的类型""" if isinstance(operate, NumberOperate): return ValueType.NUMBER if operate in [ @@ -58,17 +58,17 @@ class ConditionHandler(BaseModel): return None @staticmethod - def check_value_type(value: Value, expected_type: ValueType) -> bool: + def check_value_type(value: Value) -> bool: """检查值的类型是否符合预期""" - if expected_type == ValueType.STRING and isinstance(value.value, str): + if value.type == ValueType.STRING and isinstance(value.value, str): return True - if expected_type == ValueType.NUMBER and isinstance(value.value, (int, float)): + if value.type == ValueType.NUMBER and isinstance(value.value, (int, float)): return True - if expected_type == ValueType.LIST and isinstance(value.value, list): + if value.type == ValueType.LIST and isinstance(value.value, list): return True - if expected_type == ValueType.DICT and isinstance(value.value, dict): + if value.type == ValueType.DICT and isinstance(value.value, dict): return True - if expected_type == ValueType.BOOL and isinstance(value.value, bool): + if value.type == ValueType.BOOL and isinstance(value.value, bool): return True return False diff --git a/apps/scheduler/call/choice/schema.py b/apps/scheduler/call/choice/schema.py index 332a33c704a99e1791c78b4678e405a754e93fdc..6435b6a34c12e2a5f01ab7334c0b03152ad75291 100644 --- a/apps/scheduler/call/choice/schema.py +++ b/apps/scheduler/call/choice/schema.py @@ -15,6 +15,7 @@ from apps.schemas.parameters import ( ValueType, ) from apps.scheduler.call.core import DataBase +from apps.scheduler.variable.type import VariableType class Logic(str, Enum): @@ -60,3 +61,30 @@ class ChoiceOutput(DataBase): """Choice Call的输出""" branch_id: str = Field(description="分支ID", default="") + + +def convert_variable_type_to_value_type(var_type: VariableType) -> ValueType | None: + """将VariableType转换为ValueType + + Args: + var_type: 变量类型 + + Returns: + ValueType: 对应的值类型,如果无法转换则返回None + """ + + type_mapping = { + VariableType.STRING: ValueType.STRING, + VariableType.NUMBER: ValueType.NUMBER, + VariableType.BOOLEAN: ValueType.BOOL, + VariableType.OBJECT: ValueType.DICT, + VariableType.ARRAY_ANY: ValueType.LIST, + VariableType.ARRAY_STRING: ValueType.LIST, + VariableType.ARRAY_NUMBER: ValueType.LIST, + VariableType.ARRAY_OBJECT: ValueType.LIST, + VariableType.ARRAY_FILE: ValueType.LIST, + VariableType.ARRAY_BOOLEAN: ValueType.LIST, + VariableType.ARRAY_SECRET: ValueType.LIST, + } + + return type_mapping.get(var_type) diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index c5322b0485a082b2ac1f5dfac3be7c6bbd366f96..65338a14c6f7a2b0c44b467bc54902c7cdc1a600 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -18,6 +18,7 @@ from apps.llm.reasoning import ReasoningLLM from apps.scheduler.variable.integration import VariableIntegration from apps.schemas.enum_var import CallOutputType from apps.schemas.pool import NodePool +from apps.schemas.parameters import ValueType from apps.schemas.scheduler import ( CallError, CallIds, @@ -30,6 +31,7 @@ from apps.schemas.task import FlowStepHistory if TYPE_CHECKING: from apps.scheduler.executor.step import StepExecutor + from apps.scheduler.call.choice.schema import Value logger = logging.getLogger(__name__) @@ -194,7 +196,7 @@ class CoreCall(BaseModel): if isinstance(config, dict): if "reference" in config: # 解析变量引用 - resolved_value = await VariableIntegration.resolve_variable_reference( + resolved_value, _ = await VariableIntegration.resolve_variable_reference( config["reference"], user_sub=call_vars.ids.user_sub, flow_id=call_vars.ids.flow_id, @@ -249,7 +251,7 @@ class CoreCall(BaseModel): for match in matches: try: # 解析变量引用 - resolved_value = await VariableIntegration.resolve_variable_reference( + resolved_value, _ = await VariableIntegration.resolve_variable_reference( match.strip(), user_sub=call_vars.ids.user_sub, flow_id=call_vars.ids.flow_id, @@ -264,6 +266,36 @@ class CoreCall(BaseModel): return resolved_text + async def _resolve_single_value(self, value: "Value", call_vars: CallVars) -> "Value": + """解析Value对象中的变量引用({{...}} 语法) + + Args: + value: Value对象 + call_vars: Call变量 + + Returns: + 解析后的Value对象 + """ + # 运行时导入避免循环依赖 + from apps.scheduler.call.choice.schema import Value + + # 检查是否包含变量引用语法 + match = re.match(r'^\{\{.*?\}\}$', value.value) + if not value.type == ValueType.REFERENCE and match: + return value + + try: + # 解析变量引用 + resolved_value, resolved_type = await VariableIntegration.resolve_variable_reference( + match.group().strip(), + user_sub=call_vars.ids.user_sub, + flow_id=call_vars.ids.flow_id, + conversation_id=call_vars.ids.conversation_id + ) + except Exception as e: + logger.warning(f"[CoreCall] 解析变量引用 '{match.group()}' 失败: {e}") + + return Value(value=resolved_value, type=resolved_type) async def _set_input(self, executor: "StepExecutor") -> None: """获取Call的输入""" diff --git a/apps/scheduler/variable/integration.py b/apps/scheduler/variable/integration.py index a33385a176513a8717d98a59bcaf5f028ce8a47f..56c0f612d1f9002ec46a64f3677296e280fc7959 100644 --- a/apps/scheduler/variable/integration.py +++ b/apps/scheduler/variable/integration.py @@ -1,7 +1,7 @@ """变量解析与工作流调度器集成""" import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union from apps.scheduler.variable.parser import VariableParser from apps.scheduler.variable.pool_manager import get_pool_manager @@ -74,7 +74,7 @@ class VariableIntegration: user_sub: str, flow_id: Optional[str] = None, conversation_id: Optional[str] = None - ) -> Any: + ) -> Tuple[Any, Any]: """解析单个变量引用 Args: @@ -84,7 +84,7 @@ class VariableIntegration: conversation_id: 对话ID Returns: - Any: 解析后的变量值 + Tuple[Any, Any]: 解析后的变量值和变量类型 """ try: parser = VariableParser( @@ -97,9 +97,9 @@ class VariableIntegration: clean_reference = reference.strip("{}") # 使用解析器解析变量引用 - resolved_value = await parser._resolve_variable_reference(clean_reference) + resolved_value, resolved_type = await parser._resolve_variable_reference(clean_reference) - return resolved_value + return resolved_value, resolved_type except Exception as e: logger.error(f"解析变量引用失败: {reference}, 错误: {e}") diff --git a/apps/scheduler/variable/parser.py b/apps/scheduler/variable/parser.py index 9eb0dcfaff90d5a0254c4f13917197b12c13ef95..cd0c58358689faf859cafcc2b3c34940b25c8c49 100644 --- a/apps/scheduler/variable/parser.py +++ b/apps/scheduler/variable/parser.py @@ -78,14 +78,14 @@ class VariableParser: return result - async def _resolve_variable_reference(self, reference: str) -> Any: + async def _resolve_variable_reference(self, reference: str) -> Tuple[Any, Any]: """解析变量引用 Args: reference: 变量引用字符串(不含花括号) Returns: - Any: 变量值 + Tuple[Any, Any]: 变量值, 变量类型 """ pool_manager = await self._get_pool_manager() @@ -125,7 +125,7 @@ class VariableParser: conversation_id=self.conversation_id if scope in [VariableScope.SYSTEM, VariableScope.CONVERSATION] else None ) if variable: - return variable.value + return variable.value, variable.var_type except: pass # 如果找不到,继续使用原有逻辑 @@ -147,6 +147,7 @@ class VariableParser: # 获取变量值 value = variable.value + var_type = variable.var_type # 如果有嵌套路径,继续解析 for path_part in path_parts[1:]: @@ -160,7 +161,7 @@ class VariableParser: else: raise ValueError(f"无法访问路径: {var_path}") - return value + return value, var_type async def extract_variables(self, template: str) -> List[str]: """提取模板中的所有变量引用 diff --git a/apps/scheduler/variable/pool_base.py b/apps/scheduler/variable/pool_base.py index 83acdffdceaa4409142d56c9c96675ebedf93680..b73ba3f1b3c1b036c282df0af4d0ae28e7264eea 100644 --- a/apps/scheduler/variable/pool_base.py +++ b/apps/scheduler/variable/pool_base.py @@ -1,7 +1,7 @@ import logging import asyncio from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Set +from typing import Any, Dict, List, Optional, Set, Tuple from datetime import datetime, UTC from apps.common.mongo import MongoDB @@ -250,37 +250,6 @@ class BaseVariablePool(ABC): def _add_pool_query_conditions(self, query: Dict[str, Any], variable: BaseVariable): """添加池特定的查询条件""" pass - - async def resolve_variable_reference(self, reference: str) -> Any: - """解析变量引用""" - # 移除 {{ 和 }} - clean_ref = reference.strip("{}").strip() - - # 解析变量路径 - path_parts = clean_ref.split(".") - var_name = path_parts[0] - - # 获取变量 - variable = await self.get_variable(var_name) - if not variable: - raise ValueError(f"变量不存在: {var_name}") - - # 获取变量值 - value = variable.value - - # 处理嵌套路径 - for path_part in path_parts[1:]: - if isinstance(value, dict): - value = value.get(path_part) - elif isinstance(value, list) and path_part.isdigit(): - try: - value = value[int(path_part)] - except IndexError: - value = None - else: - raise ValueError(f"无法访问路径: {clean_ref}") - - return value class UserVariablePool(BaseVariablePool): diff --git a/apps/services/flow.py b/apps/services/flow.py index f84f5b832c02fd4bd6a704b4a25dec80db0dfada..74639d1ad8f0db154bd979c7b7704b87a2783121 100644 --- a/apps/services/flow.py +++ b/apps/services/flow.py @@ -257,8 +257,8 @@ class FlowManager: debug=flow_config.debug, ) for node_id, node_config in flow_config.steps.items(): - # 对于Code节点,直接使用保存的完整params作为parameters - if node_config.type == "Code": + # TODO 新增标识位区分Node是否允许用户定义output parameters,如果output固定由节点生成,则走else逻辑 + if node_config.type == "Code" or node_config.type == "DirectReply" or node_config.type == "Choice": parameters = node_config.params # 直接使用保存的完整params else: # 其他节点:使用原有逻辑 @@ -389,11 +389,7 @@ class FlowManager: debug=flow_item.debug, ) for node_item in flow_item.nodes: - # 对于Code节点,保存完整的parameters;其他节点只保存input_parameters - if node_item.call_id == "Code": - params = node_item.parameters # 保存完整的parameters(包含input_parameters、output_parameters以及code配置) - else: - params = node_item.parameters.get("input_parameters", {}) # 其他节点只保存input_parameters + params = node_item.parameters flow_config.steps[node_item.step_id] = Step( type=node_item.call_id,