diff --git a/apps/main.py b/apps/main.py index d0aea344f570fc5fcb923a23a767362b92e8f347..caaa31fc640a9f0d62c198cb72c36371c510b478 100644 --- a/apps/main.py +++ b/apps/main.py @@ -153,6 +153,9 @@ async def init_resources() -> None: from apps.services.predecessor_cache_service import PredecessorCacheService, periodic_cleanup_background_tasks await PredecessorCacheService.initialize_redis() + # 项目启动时清空所有前置节点缓存,确保使用最新的算法逻辑 + await PredecessorCacheService.clear_all_predecessor_cache() + # 启动定期清理任务 global _cleanup_task _cleanup_task = asyncio.create_task(start_periodic_cleanup()) diff --git a/apps/routers/variable.py b/apps/routers/variable.py index 0c3ee420687437e81acfbb9b1fd50f1a71081981..ea1a22e0ff579aa8ead43f8bd7efa1f65ced7aa5 100644 --- a/apps/routers/variable.py +++ b/apps/routers/variable.py @@ -2,7 +2,8 @@ """FastAPI 变量管理 API""" import logging -from typing import Annotated, List, Optional, Dict +from typing import Annotated, List, Optional, Dict, Any +from datetime import datetime, UTC from fastapi import APIRouter, Body, Depends, HTTPException, Query, status from fastapi.responses import JSONResponse @@ -16,6 +17,9 @@ from apps.scheduler.variable.parser import VariableParser from apps.schemas.response_data import ResponseData from apps.services.flow import FlowManager +import json +import re + logger = logging.getLogger(__name__) router = APIRouter( @@ -136,14 +140,14 @@ class CreateVariableRequest(BaseModel): name: str = Field(description="变量名称") var_type: VariableType = Field(description="变量类型") scope: VariableScope = Field(description="变量作用域") - value: Optional[str] = Field(default=None, description="变量值") + value: Optional[Any] = Field(default=None, description="变量值") description: Optional[str] = Field(default=None, description="变量描述") flow_id: Optional[str] = Field(default=None, description="流程ID(环境级和对话级变量必需)") class UpdateVariableRequest(BaseModel): """更新变量请求""" - value: Optional[str] = Field(default=None, description="新的变量值") + value: Optional[Any] = 
Field(default=None, description="新的变量值") var_type: Optional[VariableType] = Field(default=None, description="新的变量类型") description: Optional[str] = Field(default=None, description="新的变量描述") @@ -206,6 +210,17 @@ async def create_variable( detail="不允许创建系统级变量" ) + # 类型转换和验证 + converted_value = None + if request.value is not None: + try: + converted_value = convert_value_by_type(request.value, request.var_type) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"变量值类型转换失败: {str(e)}" + ) + pool_manager = await get_pool_manager() # 根据作用域获取合适的变量池 @@ -252,7 +267,7 @@ async def create_variable( variable = await pool.add_conversation_template( name=request.name, var_type=request.var_type, - default_value=request.value, + default_value=converted_value, description=request.description, created_by=user_sub ) @@ -261,7 +276,7 @@ async def create_variable( variable = await pool.add_variable( name=request.name, var_type=request.var_type, - value=request.value, + value=converted_value, description=request.description, created_by=user_sub ) @@ -353,10 +368,34 @@ async def update_variable( detail="无法获取变量池" ) + # 类型转换和验证(仅当提供了新值和新类型时) + converted_value = request.value + if request.value is not None and request.var_type is not None: + try: + converted_value = convert_value_by_type(request.value, request.var_type) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"变量值类型转换失败: {str(e)}" + ) + elif request.value is not None: + # 如果只提供了值但没有类型,需要获取现有变量的类型 + try: + existing_variable = await pool.get_variable(name) + converted_value = convert_value_by_type(request.value, existing_variable.metadata.var_type) + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"变量值类型转换失败: {str(e)}" + ) + except Exception: + # 如果获取现有变量失败,使用原值 + converted_value = request.value + # 更新变量 variable = await pool.update_variable( name=name, - value=request.value, + 
value=converted_value, var_type=request.var_type, description=request.description ) @@ -965,4 +1004,342 @@ async def _create_node_output_variables(node, user_sub: str) -> List: except Exception as e: logger.error(f"创建节点输出变量失败: {e}") - return [] \ No newline at end of file + return [] + + +def convert_string_value_by_type(value: str, var_type: VariableType) -> Any: + """ + 根据变量类型将字符串值转换为对应的Python类型 + + Args: + value: 字符串格式的值 + var_type: 变量类型 + + Returns: + 转换后的值 + + Raises: + ValueError: 当值无法转换为指定类型时 + """ + if value is None: + return None + + try: + match var_type: + case VariableType.STRING | VariableType.SECRET: + return str(value) + + case VariableType.NUMBER: + # 尝试转换为数字 + if isinstance(value, str) and value.strip() == "": + return 0 + # 如果包含小数点,转换为float,否则转换为int + if '.' in str(value): + return float(value) + else: + return int(value) + + case VariableType.BOOLEAN: + # 处理布尔值转换 + if isinstance(value, bool): + return value + if isinstance(value, str): + lower_value = value.lower().strip() + if lower_value in ('true', '1', 'yes', 'on'): + return True + elif lower_value in ('false', '0', 'no', 'off'): + return False + else: + raise ValueError(f"无法将 '{value}' 转换为布尔值") + return bool(value) + + case VariableType.OBJECT: + # 处理对象类型 + if isinstance(value, dict): + return value + if isinstance(value, str): + try: + return json.loads(value) + except json.JSONDecodeError as e: + raise ValueError(f"无法解析JSON对象: {e}") + return dict(value) + + case VariableType.ARRAY_STRING: + # 处理字符串数组 + if isinstance(value, list): + return [str(item) for item in value] + if isinstance(value, str): + try: + parsed = json.loads(value) + if isinstance(parsed, list): + return [str(item) for item in parsed] + else: + # 尝试按逗号分割 + return [item.strip() for item in value.split(',') if item.strip()] + except json.JSONDecodeError: + # 按逗号分割 + return [item.strip() for item in value.split(',') if item.strip()] + return list(value) + + case VariableType.ARRAY_NUMBER: + # 处理数字数组 + if isinstance(value, 
list): + return [float(item) if '.' in str(item) else int(item) for item in value] + if isinstance(value, str): + try: + parsed = json.loads(value) + if isinstance(parsed, list): + result = [] + for item in parsed: + if isinstance(item, (int, float)): + result.append(item) + else: + result.append(float(item) if '.' in str(item) else int(item)) + return result + else: + raise ValueError("期望数组格式") + except (json.JSONDecodeError, ValueError) as e: + raise ValueError(f"无法解析数字数组: {e}") + return list(value) + + case VariableType.ARRAY_BOOLEAN: + # 处理布尔数组 + if isinstance(value, list): + return [convert_string_value_by_type(str(item), VariableType.BOOLEAN) for item in value] + if isinstance(value, str): + try: + parsed = json.loads(value) + if isinstance(parsed, list): + return [convert_string_value_by_type(str(item), VariableType.BOOLEAN) for item in parsed] + else: + raise ValueError("期望数组格式") + except json.JSONDecodeError as e: + raise ValueError(f"无法解析布尔数组: {e}") + return list(value) + + case VariableType.ARRAY_OBJECT: + # 处理对象数组 + if isinstance(value, list): + return [convert_string_value_by_type(json.dumps(item) if not isinstance(item, str) else item, VariableType.OBJECT) for item in value] + if isinstance(value, str): + try: + parsed = json.loads(value) + if isinstance(parsed, list): + return parsed # 已经是解析后的对象数组 + else: + raise ValueError("期望数组格式") + except json.JSONDecodeError as e: + raise ValueError(f"无法解析对象数组: {e}") + return list(value) + + case VariableType.FILE | VariableType.ARRAY_FILE: + # 文件类型暂时保持字符串格式 + return str(value) + + case VariableType.ARRAY_SECRET: + # 密钥数组 + if isinstance(value, list): + return [str(item) for item in value] + if isinstance(value, str): + try: + parsed = json.loads(value) + if isinstance(parsed, list): + return [str(item) for item in parsed] + else: + raise ValueError("期望数组格式") + except json.JSONDecodeError as e: + raise ValueError(f"无法解析密钥数组: {e}") + return list(value) + + case _: + # 默认返回字符串 + return str(value) + + except 
(ValueError, TypeError, json.JSONDecodeError) as e: + raise ValueError(f"无法将值 '{value}' 转换为类型 '{var_type.value}': {str(e)}") + + +def convert_value_by_type(value: Any, var_type: VariableType) -> Any: + """ + 根据变量类型将任意类型的值转换为对应的Python类型 + + Args: + value: 任意类型的值 + var_type: 变量类型 + + Returns: + 转换后的值 + + Raises: + ValueError: 当值无法转换为指定类型时 + """ + if value is None: + return None + + try: + match var_type: + case VariableType.STRING | VariableType.SECRET: + # 如果已经是字符串,直接返回 + if isinstance(value, str): + return value + # 如果是复杂类型,JSON序列化 + elif isinstance(value, (dict, list)): + return json.dumps(value) + else: + return str(value) + + case VariableType.NUMBER: + # 如果已经是数字,直接返回 + if isinstance(value, (int, float)): + return value + # 尝试转换为数字 + if isinstance(value, str): + if value.strip() == "": + return 0 + # 如果包含小数点,转换为float,否则转换为int + if '.' in value or 'e' in value.lower(): + return float(value) + else: + return int(value) + elif isinstance(value, bool): + return int(value) + else: + return float(value) + + case VariableType.BOOLEAN: + # 如果已经是布尔值,直接返回 + if isinstance(value, bool): + return value + # 处理布尔值转换 + if isinstance(value, str): + lower_value = value.lower().strip() + if lower_value in ('true', '1', 'yes', 'on'): + return True + elif lower_value in ('false', '0', 'no', 'off'): + return False + else: + return bool(value) + else: + return bool(value) + + case VariableType.OBJECT: + # 如果已经是字典,直接返回 + if isinstance(value, dict): + return value + # 处理对象类型 + if isinstance(value, str): + try: + return json.loads(value) + except json.JSONDecodeError as e: + raise ValueError(f"无法解析JSON对象: {e}") + else: + # 尝试转换为字典 + if hasattr(value, '__dict__'): + return value.__dict__ + else: + return dict(value) + + case VariableType.ARRAY_STRING: + # 如果已经是列表,检查元素类型 + if isinstance(value, list): + return [str(item) for item in value] + if isinstance(value, str): + try: + parsed = json.loads(value) + if isinstance(parsed, list): + return [str(item) for item in parsed] + else: + # 
尝试按逗号分割 + return [item.strip() for item in value.split(',') if item.strip()] + except json.JSONDecodeError: + # 按逗号分割 + return [item.strip() for item in value.split(',') if item.strip()] + else: + return [str(value)] + + case VariableType.ARRAY_NUMBER: + # 如果已经是列表,检查元素类型 + if isinstance(value, list): + result = [] + for item in value: + if isinstance(item, (int, float)): + result.append(item) + else: + result.append(float(item) if '.' in str(item) else int(item)) + return result + if isinstance(value, str): + try: + parsed = json.loads(value) + if isinstance(parsed, list): + result = [] + for item in parsed: + if isinstance(item, (int, float)): + result.append(item) + else: + result.append(float(item) if '.' in str(item) else int(item)) + return result + else: + raise ValueError("期望数组格式") + except (json.JSONDecodeError, ValueError) as e: + raise ValueError(f"无法解析数字数组: {e}") + else: + return [value] + + case VariableType.ARRAY_BOOLEAN: + # 如果已经是列表,检查元素类型 + if isinstance(value, list): + return [convert_value_by_type(item, VariableType.BOOLEAN) for item in value] + if isinstance(value, str): + try: + parsed = json.loads(value) + if isinstance(parsed, list): + return [convert_value_by_type(item, VariableType.BOOLEAN) for item in parsed] + else: + raise ValueError("期望数组格式") + except json.JSONDecodeError as e: + raise ValueError(f"无法解析布尔数组: {e}") + else: + return [value] + + case VariableType.ARRAY_OBJECT: + # 如果已经是列表,检查元素类型 + if isinstance(value, list): + return [convert_value_by_type(item, VariableType.OBJECT) for item in value] + if isinstance(value, str): + try: + parsed = json.loads(value) + if isinstance(parsed, list): + return parsed # 已经是解析后的对象数组 + else: + raise ValueError("期望数组格式") + except json.JSONDecodeError as e: + raise ValueError(f"无法解析对象数组: {e}") + else: + return [value] + + case VariableType.FILE | VariableType.ARRAY_FILE: + # 文件类型暂时保持原样 + return value + + case VariableType.ARRAY_SECRET: + # 密钥数组 + if isinstance(value, list): + return [str(item) for item 
in value] + if isinstance(value, str): + try: + parsed = json.loads(value) + if isinstance(parsed, list): + return [str(item) for item in parsed] + else: + raise ValueError("期望数组格式") + except json.JSONDecodeError as e: + raise ValueError(f"无法解析密钥数组: {e}") + else: + return [str(value)] + + case _: + # 默认返回字符串 + return str(value) + + except (ValueError, TypeError, json.JSONDecodeError) as e: + raise ValueError(f"无法将值 '{value}' (类型: {type(value).__name__}) 转换为类型 '{var_type.value}': {str(e)}") \ No newline at end of file diff --git a/apps/scheduler/call/choice/choice.py b/apps/scheduler/call/choice/choice.py index ee759c8e64ace67b9c806e04cb43cfbd554cade0..ab7d581da1cfc2d84d3d36805cd358fdcb09bf10 100644 --- a/apps/scheduler/call/choice/choice.py +++ b/apps/scheduler/call/choice/choice.py @@ -83,12 +83,118 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): else: resolved_value = value - # 对于非引用类型,验证值类型 + # 对于非引用类型,先进行类型转换,然后验证值类型 + try: + resolved_value = self._convert_value_to_expected_type(resolved_value) + except Exception as e: + return False, None, f"{value_position}类型转换失败: {e}" + if not ConditionHandler.check_value_type(resolved_value): return False, None, f"{value_position}类型不匹配: {resolved_value.value} 应为 {resolved_value.type}" return True, resolved_value, "" + def _convert_value_to_expected_type(self, value: Value) -> Value: + """根据期望的类型转换值 + + Args: + value: 需要转换的Value对象 + + Returns: + Value: 转换后的Value对象 + + Raises: + ValueError: 当值无法转换为指定类型时 + """ + if value.value is None: + return value + + # 如果已经是正确的类型,直接返回 + if ConditionHandler.check_value_type(value): + return value + + # 创建新的Value对象进行转换 + converted_value = Value(type=value.type, value=value.value) + + try: + if value.type == ValueType.NUMBER: + # 转换为数字 + if isinstance(value.value, str): + if value.value.strip() == "": + converted_value.value = 0 + elif '.' 
in value.value or 'e' in value.value.lower(): + converted_value.value = float(value.value) + else: + converted_value.value = int(value.value) + elif isinstance(value.value, bool): + converted_value.value = int(value.value) + elif isinstance(value.value, (int, float)): + converted_value.value = value.value + else: + converted_value.value = float(value.value) + + elif value.type == ValueType.STRING: + # 转换为字符串 + if isinstance(value.value, (dict, list)): + import json + converted_value.value = json.dumps(value.value) + else: + converted_value.value = str(value.value) + + elif value.type == ValueType.BOOL: + # 转换为布尔值 + if isinstance(value.value, str): + lower_value = value.value.lower().strip() + if lower_value in ('true', '1', 'yes', 'on'): + converted_value.value = True + elif lower_value in ('false', '0', 'no', 'off'): + converted_value.value = False + else: + converted_value.value = bool(value.value) + else: + converted_value.value = bool(value.value) + + elif value.type == ValueType.LIST: + # 转换为列表 + if isinstance(value.value, str): + import json + try: + converted_value.value = json.loads(value.value) + if not isinstance(converted_value.value, list): + # 如果不是列表,尝试按逗号分割 + converted_value.value = [item.strip() for item in value.value.split(',') if item.strip()] + except json.JSONDecodeError: + # 按逗号分割 + converted_value.value = [item.strip() for item in value.value.split(',') if item.strip()] + elif isinstance(value.value, list): + converted_value.value = value.value + else: + converted_value.value = [value.value] + + elif value.type == ValueType.DICT: + # 转换为字典 + if isinstance(value.value, str): + import json + try: + converted_value.value = json.loads(value.value) + if not isinstance(converted_value.value, dict): + raise ValueError(f"解析后的值不是字典类型: {type(converted_value.value)}") + except json.JSONDecodeError as e: + raise ValueError(f"无法解析JSON字典: {e}") + elif isinstance(value.value, dict): + converted_value.value = value.value + else: + raise ValueError(f"无法将 
{type(value.value)} 转换为字典") + + else: + # 对于其他类型,保持原值 + converted_value.value = value.value + + except (ValueError, TypeError, ImportError) as e: + raise ValueError(f"无法将值 '{value.value}' 转换为类型 '{value.type.value}': {str(e)}") + + return converted_value + async def _process_condition(self, condition: Condition, call_vars: CallVars, branch_id: str) -> tuple[bool, str]: """处理单个条件 diff --git a/apps/scheduler/call/code/code.py b/apps/scheduler/call/code/code.py index 18b6577b4589d548bfdf4213b31eee80b1d371af..5a331941eb94eeab47e322d7865ad7f5ec21e9af 100644 --- a/apps/scheduler/call/code/code.py +++ b/apps/scheduler/call/code/code.py @@ -57,13 +57,30 @@ class Code(CoreCall, input_model=CodeInput, output_model=CodeOutput): "permissions": ["execute"] } - # 处理输入参数 - 使用基类的变量解析功能 + # 处理输入参数 - 直接从input_parameters读取 input_arg = {} + + logger.info(f"[Code] 开始处理输入参数") + logger.info(f"[Code] self.input_parameters: {self.input_parameters}") + logger.info(f"[Code] self.output_parameters: {self.output_parameters}") + if self.input_parameters: + logger.info(f"[Code] 开始解析 {len(self.input_parameters)} 个输入参数") # 解析每个输入参数 for param_name, param_config in self.input_parameters.items(): - resolved_value = await self._resolve_variables_in_config(param_config, call_vars) - input_arg[param_name] = resolved_value + logger.info(f"[Code] 正在解析参数: {param_name}, 配置: {param_config}") + try: + resolved_value = await self._resolve_variables_in_config(param_config, call_vars) + input_arg[param_name] = resolved_value + logger.info(f"[Code] 参数 {param_name} 解析成功: {resolved_value}") + except Exception as e: + logger.error(f"[Code] 参数 {param_name} 解析失败: {e}") + # 使用原始配置作为降级 + input_arg[param_name] = param_config + else: + logger.warning("[Code] input_parameters 为空或未配置") + + logger.info(f"[Code] 最终的 input_arg: {input_arg}") return CodeInput( code=self.code, diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index 
2ff4f3d39633b26d792b0b2c2a257107b4874d30..1a3876ec35885c3570bd12245f639d90c5d3831c 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -27,4 +27,6 @@ class MCPAgentExecutor(BaseExecutor): logger.info("[FlowExecutor] 加载Executor状态") # 尝试恢复State if self.task.state: - self.task.context = await TaskManager.get_context_by_task_id(self.task.id) + context_objects = await TaskManager.get_context_by_task_id(self.task.id) + # 将对象转换为字典以保持与系统其他部分的一致性 + self.task.context = [context.model_dump(exclude_none=True, by_alias=True) for context in context_objects] diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index b73e4be328817937d788aea4336988ac285227ae..b6ec9d750752b8797861bafd05f953b12bbf259f 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -54,7 +54,9 @@ class FlowExecutor(BaseExecutor): logger.info("[FlowExecutor] 加载Executor状态") # 尝试恢复State if self.task.state: - self.task.context = await TaskManager.get_context_by_task_id(self.task.id) + context_objects = await TaskManager.get_context_by_task_id(self.task.id) + # 将对象转换为字典以保持与系统其他部分的一致性 + self.task.context = [context.model_dump(exclude_none=True, by_alias=True) for context in context_objects] else: # 创建ExecutorState self.task.state = ExecutorState( diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 466abc575d7f0bc109b5fd252a493f2c4687f2b4..fe5e65d5b9f6b6e820df22cd51437798e9243b54 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -7,6 +7,7 @@ import uuid from collections.abc import AsyncGenerator from datetime import UTC, datetime from typing import Any +import json from pydantic import ConfigDict @@ -119,7 +120,7 @@ class StepExecutor(BaseExecutor): params.update(self.step.step.params) # 对于需要扁平化处理的Call类型,将input_parameters中的内容提取到顶级 - if self._call_id in ["Choice", "Code", "DirectReply"] and "input_parameters" in params: + if self._call_id in ["Choice", 
"DirectReply"] and "input_parameters" in params: # 提取input_parameters中的所有字段到顶级 input_params = params.get("input_parameters", {}) if isinstance(input_params, dict): @@ -219,45 +220,54 @@ class StepExecutor(BaseExecutor): async def _save_output_parameters_to_variables(self, output_data: str | dict[str, Any]) -> None: - """保存节点输出参数到变量池""" + """保存输出参数到变量池,并进行类型验证""" try: - # 检查是否有output_parameters配置 + # 获取当前步骤的output_parameters配置 output_parameters = None if self.step.step.params and isinstance(self.step.step.params, dict): output_parameters = self.step.step.params.get("output_parameters", {}) + elif hasattr(self.step, 'output_parameters'): + output_parameters = self.step.output_parameters if not output_parameters or not isinstance(output_parameters, dict): + logger.debug(f"[StepExecutor] 步骤 {self.step.step_id} 没有配置output_parameters") return - # 确保output_data是字典格式 + # 解析输出数据 if isinstance(output_data, str): - # 如果是字符串,包装成字典 - data_dict = {"text": output_data} - else: - data_dict = output_data if isinstance(output_data, dict) else {} - - # 确定变量名前缀(根据配置决定是否使用直接格式) - use_direct_format = should_use_direct_conversation_format( - call_id=self._call_id, - step_name=self.step.step.name, - step_id=self.step.step_id - ) - - if use_direct_format: - # 配置允许的节点类型保持原有格式:conversation.key - var_prefix = "" + try: + data_dict = json.loads(output_data) + except json.JSONDecodeError: + logger.warning(f"[StepExecutor] 无法解析输出数据为JSON: {output_data}") + data_dict = {"raw_output": output_data} else: - # 其他节点使用格式:conversation.node_id.key - var_prefix = f"{self.step.step_id}." + data_dict = output_data + + # 构造变量名前缀 + var_prefix = f"{self.step.step_id}." 
- # 保存每个output_parameter到变量池 + # 保存每个output_parameter到变量池,并进行类型验证 saved_count = 0 + failed_params = [] + for param_name, param_config in output_parameters.items(): try: # 获取参数值 param_value = self._extract_value_from_output_data(param_name, data_dict, param_config) if param_value is not None: + # 获取期望的类型 + expected_type = param_config.get("type", "string") + + # 进行类型验证 + if not self._validate_output_value_type(param_value, expected_type): + error_msg = (f"输出参数 '{param_name}' 类型不匹配。" + f"期望: {expected_type}, " + f"实际: {type(param_value).__name__}({param_value})") + logger.error(f"[StepExecutor] {error_msg}") + failed_params.append(f"{param_name}: {error_msg}") + continue + # 构造变量名 var_name = f"{var_prefix}{param_name}" @@ -265,7 +275,7 @@ class StepExecutor(BaseExecutor): success = await VariableIntegration.save_conversation_variable( var_name=var_name, value=param_value, - var_type=param_config.get("type", "string"), + var_type=expected_type, description=param_config.get("description", ""), user_sub=self.task.ids.user_sub, flow_id=self.task.state.flow_id, # type: ignore[arg-type] @@ -276,16 +286,43 @@ class StepExecutor(BaseExecutor): saved_count += 1 logger.debug(f"[StepExecutor] 已保存输出参数变量: conversation.{var_name} = {param_value}") else: - logger.warning(f"[StepExecutor] 保存输出参数变量失败: {var_name}") + error_msg = f"保存输出参数变量失败: {var_name}" + logger.warning(f"[StepExecutor] {error_msg}") + failed_params.append(f"{param_name}: {error_msg}") except Exception as e: + error_msg = f"处理输出参数失败: {str(e)}" logger.warning(f"[StepExecutor] 保存输出参数 {param_name} 失败: {e}") - + failed_params.append(f"{param_name}: {error_msg}") + + # 如果有失败的参数,将步骤状态设置为失败 + if failed_params: + from apps.schemas.enum_var import StepStatus + self.task.state.status = StepStatus.FAILED # type: ignore[assignment] + + failure_msg = f"输出参数类型验证失败:\n" + "\n".join(failed_params) + logger.error(f"[StepExecutor] 步骤 {self.step.step_id} 执行失败: {failure_msg}") + + # 保存错误信息到任务状态 + if not hasattr(self.task.state, 
def _validate_output_value_type(self, value: Any, expected_type: str) -> bool:
    """Check whether an output value matches its declared parameter type.

    Args:
        value: The value extracted from the step output.
        expected_type: Declared type name: "string", "number", "integer",
            "boolean", "array", "object" or "any".

    Returns:
        True when the value is acceptable for the declared type. An unknown
        type name is logged and validated as a string.
    """
    # bool is a subclass of int in Python, so True/False would otherwise be
    # accepted as numbers; exclude it explicitly for the numeric types.
    is_bool = isinstance(value, bool)

    if expected_type == "string":
        return isinstance(value, str)
    if expected_type == "number":
        return isinstance(value, (int, float)) and not is_bool
    if expected_type == "integer":
        if is_bool:
            return False
        # Accept whole-valued floats (3.0), reject fractional ones (3.5).
        return isinstance(value, int) or (isinstance(value, float) and value.is_integer())
    if expected_type == "boolean":
        return is_bool
    if expected_type == "array":
        return isinstance(value, list)
    if expected_type == "object":
        return isinstance(value, dict)
    if expected_type == "any":
        return True

    # Unknown type names fall back to string validation.
    logger.warning(f"[StepExecutor] 未知的期望类型: {expected_type},默认验证为字符串")
    return isinstance(value, str)
task.context], + flow=[context.id if hasattr(context, 'id') else context.get('_id', context.get('id', '')) for context in task.context], ) # 检查是否存在group_id diff --git a/apps/scheduler/variable/integration.py b/apps/scheduler/variable/integration.py index 56c0f612d1f9002ec46a64f3677296e280fc7959..3bfb1b8da179ca1853289445c6983651a5211f6f 100644 --- a/apps/scheduler/variable/integration.py +++ b/apps/scheduler/variable/integration.py @@ -1,10 +1,10 @@ -"""变量解析与工作流调度器集成""" +"""变量系统与外部组件的集成接口""" import logging from typing import Any, Dict, List, Optional, Tuple, Union -from apps.scheduler.variable.parser import VariableParser from apps.scheduler.variable.pool_manager import get_pool_manager +from apps.scheduler.variable.parser import VariableParser from apps.scheduler.variable.type import VariableScope logger = logging.getLogger(__name__) @@ -202,163 +202,6 @@ class VariableIntegration: logger.warning(f"解析模板字符串失败: {e}") # 如果解析失败,返回原始模板 return template - - @staticmethod - async def add_conversation_variable(name: str, - value: Any, - conversation_id: str, - var_type_str: str = "string") -> bool: - """添加对话级变量 - - Args: - name: 变量名 - value: 变量值 - conversation_id: 对话ID(在内部作为flow_id使用) - var_type_str: 变量类型字符串 - - Returns: - bool: 是否添加成功 - """ - try: - from apps.scheduler.variable.type import VariableType - - # 转换变量类型 - var_type = VariableType(var_type_str) - - pool_manager = await get_pool_manager() - # 获取对话变量池(如果不存在会抛出异常) - conversation_pool = await pool_manager.get_conversation_pool(conversation_id) - if not conversation_pool: - logger.error(f"对话变量池不存在: {conversation_id}") - return False - - await conversation_pool.add_variable( - name=name, - var_type=var_type, - value=value, - description=f"对话变量: {name}" - ) - - logger.debug(f"已添加对话变量: {name} = {value}") - return True - except Exception as e: - logger.error(f"添加对话变量失败: {e}") - return False - - @staticmethod - async def update_conversation_variable(name: str, - value: Any, - conversation_id: str) -> bool: - """更新对话级变量 - - 
Args: - name: 变量名 - value: 新值 - conversation_id: 对话ID - - Returns: - bool: 是否更新成功 - """ - try: - pool_manager = await get_pool_manager() - # 获取对话变量池 - conversation_pool = await pool_manager.get_conversation_pool(conversation_id) - if not conversation_pool: - logger.error(f"对话变量池不存在: {conversation_id}") - return False - - await conversation_pool.update_variable( - name=name, - value=value - ) - - logger.debug(f"已更新对话变量: {name} = {value}") - return True - except Exception as e: - logger.error(f"更新对话变量失败: {e}") - return False - - @staticmethod - async def extract_output_variables(output_data: Dict[str, Any], - conversation_id: str, - step_name: str) -> None: - """从步骤输出中提取变量并设置为对话级变量 - - Args: - output_data: 步骤输出数据 - conversation_id: 对话ID - step_name: 步骤名称 - """ - try: - # 将整个输出作为对象变量存储 - await VariableIntegration.add_conversation_variable( - name=f"step_{step_name}_output", - value=output_data, - conversation_id=conversation_id, - var_type_str="object" - ) - - # 如果输出中有特定的变量定义,也可以单独提取 - if isinstance(output_data, dict): - for key, value in output_data.items(): - if key.startswith("var_"): - # 以 var_ 开头的字段被视为变量定义 - var_name = key[4:] # 移除 var_ 前缀 - await VariableIntegration.add_conversation_variable( - name=var_name, - value=value, - conversation_id=conversation_id, - var_type_str="string" - ) - - except Exception as e: - logger.error(f"提取输出变量失败: {e}") - - @staticmethod - async def clear_conversation_context(conversation_id: str) -> None: - """清理对话上下文中的变量 - - Args: - conversation_id: 对话ID - """ - try: - pool_manager = await get_pool_manager() - # 移除对话变量池 - success = await pool_manager.remove_conversation_pool(conversation_id) - if success: - logger.info(f"已清理对话 {conversation_id} 的变量") - else: - logger.warning(f"对话变量池不存在: {conversation_id}") - except Exception as e: - logger.error(f"清理对话变量失败: {e}") - - @staticmethod - async def validate_variable_references(template: str, - user_sub: str, - flow_id: Optional[str] = None, - conversation_id: Optional[str] = None) -> 
tuple[bool, list[str]]: - """验证模板中的变量引用是否有效 - - Args: - template: 模板字符串 - user_sub: 用户ID - flow_id: 流程ID - conversation_id: 对话ID - - Returns: - tuple[bool, list[str]]: (是否全部有效, 无效的变量引用列表) - """ - try: - parser = VariableParser( - user_sub=user_sub, - flow_id=flow_id, - conversation_id=conversation_id - ) - - return await parser.validate_template(template) - except Exception as e: - logger.error(f"验证变量引用失败: {e}") - return False, [str(e)] # 注意:原本的 monkey_patch_scheduler 和相关扩展类已被移除 diff --git a/apps/scheduler/variable/parser.py b/apps/scheduler/variable/parser.py index cd0c58358689faf859cafcc2b3c34940b25c8c49..6c5f11df69392cae6d1dabfc6d6d9a5eacc7c905 100644 --- a/apps/scheduler/variable/parser.py +++ b/apps/scheduler/variable/parser.py @@ -90,7 +90,7 @@ class VariableParser: pool_manager = await self._get_pool_manager() # 解析作用域和变量名 - parts = reference.split(".", 1) + parts = reference.strip().split(".", 1) if len(parts) != 2: raise ValueError(f"无效的变量引用格式: {reference}") diff --git a/apps/services/predecessor_cache_service.py b/apps/services/predecessor_cache_service.py index 967077ea96b06982c543cf9ad6306b300cea48d5..b300e07560a22d10ac10e3ef86125f40f9026378 100644 --- a/apps/services/predecessor_cache_service.py +++ b/apps/services/predecessor_cache_service.py @@ -385,22 +385,43 @@ class PredecessorCacheService: @staticmethod def _find_predecessor_nodes(flow_item, current_step_id: str) -> List: - """在工作流中查找前置节点""" + """在工作流中查找所有前置节点(使用BFS算法直到start节点)""" try: + from collections import deque + predecessor_nodes = [] + visited = set() # 避免重复访问节点 + queue = deque([current_step_id]) # BFS队列 + visited.add(current_step_id) - # 遍历边,找到指向当前节点的边 + # 构建边的反向映射,便于快速查找前置节点 + reverse_edges = {} # target_node -> [source_node1, source_node2, ...] 
@staticmethod
async def clear_all_predecessor_cache():
    """Delete every predecessor-node cache entry from Redis (run at startup).

    Removes the keys matching ``predecessor_vars:*``, ``parsing_status:*``
    and ``flow_hash:*``. Does nothing when Redis is not connected. This is
    best effort: a failure on one pattern is logged and the remaining
    patterns are still processed; no exception is propagated to the caller.
    """
    # Number of keys removed per DELETE call; keeps each command bounded.
    batch_size = 500

    try:
        if not redis_cache.is_connected():
            logger.warning("Redis未连接,跳过清空缓存")
            return

        logger.info("开始清空所有前置节点缓存...")

        cache_patterns = (
            "predecessor_vars:*",  # cached variable data
            "parsing_status:*",    # parsing status
            "flow_hash:*",         # flow hash values
        )

        total_deleted = 0
        for pattern in cache_patterns:
            try:
                # SCAN walks the keyspace incrementally, whereas KEYS blocks the
                # Redis server for the whole scan of a large database.
                # NOTE(review): assumes redis_cache._redis is a redis-py asyncio
                # compatible client exposing scan_iter() - confirm.
                pending = []
                async for key in redis_cache._redis.scan_iter(match=pattern, count=batch_size):
                    pending.append(key)
                    if len(pending) >= batch_size:
                        # DELETE returns how many keys were actually removed.
                        total_deleted += await redis_cache._redis.delete(*pending)
                        pending.clear()
                if pending:
                    total_deleted += await redis_cache._redis.delete(*pending)
            except Exception as e:
                logger.error(f"清空缓存模式 {pattern} 失败: {e}")

        logger.info(f"前置节点缓存清空完成,共删除 {total_deleted} 个缓存key")

    except Exception as e:
        logger.error(f"清空所有前置节点缓存失败: {e}")