From 0da06339e748e419ecf8697ea01eb67aae97b151 Mon Sep 17 00:00:00 2001 From: Ethan-Zhang Date: Wed, 23 Jul 2025 10:03:52 +0800 Subject: [PATCH] =?UTF-8?q?Feat:=20=E8=8A=82=E7=82=B9=E5=88=86=E7=B1=BB&&?= =?UTF-8?q?=E5=8F=98=E9=87=8F=E6=A8=A1=E5=9D=97=E5=AE=8C=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CONVERSATION_VARIABLE_FIX.md | 179 ++++ apps/common/redis_cache.py | 271 ++++++ apps/main.py | 102 ++- apps/routers/flow.py | 17 + apps/routers/variable.py | 624 +++++++++++-- apps/scheduler/call/api/api.py | 8 +- apps/scheduler/call/code/code.py | 184 +++- apps/scheduler/call/code/schema.py | 2 +- apps/scheduler/call/convert/convert.py | 8 +- apps/scheduler/call/core.py | 115 +++ apps/scheduler/call/empty.py | 8 +- apps/scheduler/call/facts/facts.py | 8 +- apps/scheduler/call/graph/graph.py | 8 +- apps/scheduler/call/llm/llm.py | 8 +- apps/scheduler/call/mcp/mcp.py | 8 +- apps/scheduler/call/rag/rag.py | 8 +- apps/scheduler/call/reply/direct_reply.py | 24 +- apps/scheduler/call/slot/slot.py | 8 +- apps/scheduler/call/sql/sql.py | 8 +- apps/scheduler/call/suggest/suggest.py | 8 +- apps/scheduler/call/summary/summary.py | 8 +- apps/scheduler/executor/step.py | 107 +++ apps/scheduler/executor/step_config.py | 79 ++ apps/scheduler/pool/loader/call.py | 4 +- apps/scheduler/pool/loader/flow.py | 97 +- apps/scheduler/variable/README.md | 350 +++++--- apps/scheduler/variable/__init__.py | 8 +- apps/scheduler/variable/base.py | 7 + apps/scheduler/variable/integration.py | 281 +++--- apps/scheduler/variable/parser.py | 212 ++++- apps/scheduler/variable/pool.py | 802 ----------------- apps/scheduler/variable/pool_base.py | 828 ++++++++++++++++++ apps/scheduler/variable/pool_manager.py | 353 ++++++++ apps/scheduler/variable/security.py | 14 +- .../variable/system_variables_example.py | 220 +++++ apps/schemas/config.py | 15 + apps/schemas/enum_var.py | 6 +- apps/schemas/flow_topology.py | 3 +- apps/schemas/pool.py | 4 +- apps/schemas/scheduler.py | 6 +- apps/services/flow.py | 3 +- apps/services/predecessor_cache_service.py | 488 +++++++++++ assets/.config.example.toml | 11 + docs/variable_configuration.md | 133 +++ pyproject.toml | 1 + 45 files changed, 4397 insertions(+), 1249 deletions(-) create mode 100644 CONVERSATION_VARIABLE_FIX.md create mode 100644 apps/common/redis_cache.py create mode 100644 apps/scheduler/executor/step_config.py delete mode 100644 apps/scheduler/variable/pool.py create mode 100644 apps/scheduler/variable/pool_base.py create mode 100644 apps/scheduler/variable/pool_manager.py create mode 100644 apps/scheduler/variable/system_variables_example.py create mode 100644 apps/services/predecessor_cache_service.py create mode 100644 docs/variable_configuration.md diff --git a/CONVERSATION_VARIABLE_FIX.md b/CONVERSATION_VARIABLE_FIX.md new file mode 100644 index 000000000..2a84ca946 --- /dev/null +++ b/CONVERSATION_VARIABLE_FIX.md @@ -0,0 +1,179 @@ +# 对话变量模板问题修复总结 + +## 🔍 **问题描述** + +用户报告创建对话级变量`test`成功后,无法通过API接口查询到: +``` +GET /api/variable/list?scope=conversation&flow_id=52e069c7-5556-42af-bdfc-63f4dc2dcd28 +``` + +## 🔧 **根本原因分析** + +经过分析发现,问题出现在我之前重构变量架构时的遗漏: + +### 1. **创建API工作正常** +- ✅ 对话变量正确存储到FlowVariablePool的`_conversation_templates`字典中 +- ✅ 数据库持久化成功 + +### 2. **查询API有缺陷** +- ❌ `pool_manager.py`中`list_variables_from_any_pool`方法只处理了`conversation_id`参数 +- ❌ 没有处理`scope=conversation&flow_id=xxx`的查询情况 +- ❌ `get_variable_from_any_pool`方法也有同样问题 + +### 3. 
**更新删除API有问题** +- ❌ FlowVariablePool的`update_variable`和`delete_variable`方法只在`_variables`字典中查找 +- ❌ 找不到存储在`_conversation_templates`字典中的对话变量模板 + +## 🛠️ **修复方案** + +### 1. **修复查询逻辑** + +#### `list_variables_from_any_pool`方法 +**修复前**: +```python +elif scope == VariableScope.CONVERSATION and conversation_id: + pool = await self.get_conversation_pool(conversation_id) + if pool: + return await pool.list_variables(include_system=False) + return [] +``` + +**修复后**: +```python +elif scope == VariableScope.CONVERSATION: + if conversation_id: + # 使用conversation_id查询对话变量实例 + pool = await self.get_conversation_pool(conversation_id) + if pool: + return await pool.list_variables(include_system=False) + elif flow_id: + # 使用flow_id查询对话变量模板 + flow_pool = await self.get_flow_pool(flow_id) + if flow_pool: + return await flow_pool.list_conversation_templates() + return [] +``` + +#### `get_variable_from_any_pool`方法 +类似的修复,支持通过`flow_id`查询对话变量模板。 + +### 2. **修复创建逻辑** + +#### 修改`create_variable`路由 +**修复前**: +```python +# 创建变量 +variable = await pool.add_variable(...) +``` + +**修复后**: +```python +# 根据作用域创建不同类型的变量 +if request.scope == VariableScope.CONVERSATION: + # 创建对话变量模板 + variable = await pool.add_conversation_template(...) +else: + # 创建其他类型的变量 + variable = await pool.add_variable(...) +``` + +### 3. **增强FlowVariablePool功能** + +为FlowVariablePool添加了重写的方法,支持多字典操作: + +#### `update_variable`方法 +- 在环境变量、系统变量模板、对话变量模板中按顺序查找 +- 找到变量后执行更新操作 +- 正确持久化到数据库 + +#### `delete_variable`方法 +- 支持删除存储在不同字典中的变量 +- 保留权限检查(系统变量模板不允许删除) + +#### `get_variable`方法 +- 统一的变量查找接口 +- 支持跨字典查找 + +## ✅ **修复验证** + +### 现在支持的完整工作流程: + +#### 1. **Flow级别操作**(变量模板管理) +```bash +# 创建对话变量模板 +POST /api/variable/create +{ + "name": "test", + "var_type": "string", + "scope": "conversation", + "value": "123", + "description": "321", + "flow_id": "52e069c7-5556-42af-bdfc-63f4dc2dcd28" +} + +# 查询对话变量模板 +GET /api/variable/list?scope=conversation&flow_id=52e069c7-5556-42af-bdfc-63f4dc2dcd28 + +# 更新对话变量模板 +PUT /api/variable/update?name=test&scope=conversation&flow_id=52e069c7-5556-42af-bdfc-63f4dc2dcd28 + +# 删除对话变量模板 +DELETE /api/variable/delete?name=test&scope=conversation&flow_id=52e069c7-5556-42af-bdfc-63f4dc2dcd28 +``` + +#### 2. **Conversation级别操作**(变量实例管理) +```bash +# 查询对话变量实例 +GET /api/variable/list?scope=conversation&conversation_id=conv123 + +# 更新对话变量实例值 +PUT /api/variable/update?name=test&scope=conversation&conversation_id=conv123 +``` + +## 🎯 **测试建议** + +### 立即测试 +现在可以重新测试原来失败的API调用: +```bash +curl "http://10.211.55.10:8002/api/variable/list?scope=conversation&flow_id=52e069c7-5556-42af-bdfc-63f4dc2dcd28" +``` + +### 完整测试流程 +1. **创建对话变量模板**(前端已测试成功) +2. **查询对话变量模板**(现在应该能查到) +3. **更新对话变量模板** +4. **删除对话变量模板** + +### 自动化测试 +运行测试脚本验证: +```bash +cd euler-copilot-framework +python test_conversation_variables.py +``` + +## 📊 **架构完整性验证** + +现在所有变量类型的查询都应该正常工作: + +### Flow级别查询 +- ✅ 系统变量模板:`GET /api/variable/list?scope=system&flow_id=xxx` +- ✅ 对话变量模板:`GET /api/variable/list?scope=conversation&flow_id=xxx` +- ✅ 环境变量:`GET /api/variable/list?scope=environment&flow_id=xxx` + +### Conversation级别查询 +- ✅ 系统变量实例:`GET /api/variable/list?scope=system&conversation_id=xxx` +- ✅ 对话变量实例:`GET /api/variable/list?scope=conversation&conversation_id=xxx` + +### User级别查询 +- ✅ 用户变量:`GET /api/variable/list?scope=user` + +## 🎉 **预期结果** + +修复后,你的前端应该能够: + +1. **成功创建对话变量模板**(已验证) +2. **成功查询对话变量模板**(修复的核心问题) +3. **成功更新对话变量模板** +4. 
**成功删除对话变量模板** + +所有操作都在Flow级别进行,符合你的设计需求:**Flow级别管理模板定义,Conversation级别操作实际数据**。 \ No newline at end of file diff --git a/apps/common/redis_cache.py b/apps/common/redis_cache.py new file mode 100644 index 000000000..019d0be91 --- /dev/null +++ b/apps/common/redis_cache.py @@ -0,0 +1,271 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""Redis缓存模块 - 用于前置节点变量预解析缓存""" + +import json +import logging +import asyncio +from typing import List, Dict, Any, Optional +from datetime import datetime, UTC + +import redis.asyncio as redis +from apps.common.singleton import SingletonMeta + +logger = logging.getLogger(__name__) + + +class RedisCache(metaclass=SingletonMeta): + """Redis缓存管理器""" + + def __init__(self): + self._redis: Optional[redis.Redis] = None + self._connected = False + + async def init(self, redis_config=None, redis_url: str = None): + """初始化Redis连接 + + Args: + redis_config: Redis配置对象(优先级更高) + redis_url: Redis连接URL(降级选项) + """ + try: + if redis_config: + # 使用配置对象构建连接,添加连接池和超时参数 + self._redis = redis.Redis( + host=redis_config.host, + port=redis_config.port, + password=redis_config.password if redis_config.password else None, + db=redis_config.database, + decode_responses=redis_config.decode_responses, + # 连接池配置 + max_connections=redis_config.max_connections, + # 超时配置 + socket_timeout=redis_config.socket_timeout, + socket_connect_timeout=redis_config.socket_connect_timeout, + socket_keepalive=True, + socket_keepalive_options={}, + # 连接重试 + retry_on_timeout=True, + retry_on_error=[ConnectionError, TimeoutError], + health_check_interval=redis_config.health_check_interval + ) + logger.info(f"使用配置连接Redis: {redis_config.host}:{redis_config.port}, 数据库: {redis_config.database}") + elif redis_url: + # 降级使用URL连接 + self._redis = redis.from_url( + redis_url, + decode_responses=True, + socket_timeout=5.0, + socket_connect_timeout=5.0, + max_connections=10 + ) + logger.info(f"使用URL连接Redis: {redis_url}") + else: + raise ValueError("必须提供redis_config或redis_url参数") + + # 测试连接 + logger.info("正在测试Redis连接...") + ping_result = await self._redis.ping() + logger.info(f"Redis ping结果: {ping_result}") + + # 测试基本操作 + test_key = "__redis_test__" + await self._redis.set(test_key, "test", ex=10) + test_value = await self._redis.get(test_key) + await self._redis.delete(test_key) + logger.info(f"Redis读写测试成功: {test_value}") + + self._connected = True + logger.info("Redis连接初始化成功") + except ConnectionError as e: + logger.error(f"Redis连接错误: {e}") + self._connected = False + except TimeoutError as e: + logger.error(f"Redis连接超时: {e}") + self._connected = False + except Exception as e: + logger.error(f"Redis连接初始化失败: {e}") + logger.error(f"错误类型: {type(e).__name__}") + self._connected = False + + def is_connected(self) -> bool: + """检查Redis连接状态""" + return self._connected and self._redis is not None + + async def close(self): + """关闭Redis连接""" + if self._redis: + await self._redis.close() + self._connected = False + + +class PredecessorVariableCache: + """前置节点变量预解析缓存管理器""" + + def __init__(self, redis_cache: RedisCache): + self.redis = redis_cache + self.CACHE_PREFIX = "predecessor_vars" + self.PARSING_STATUS_PREFIX = "parsing_status" + self.CACHE_TTL = 3600 * 24 # 缓存24小时 + + def _get_cache_key(self, flow_id: str, step_id: str) -> str: + """生成缓存key""" + return f"{self.CACHE_PREFIX}:{flow_id}:{step_id}" + + def _get_status_key(self, flow_id: str, step_id: str) -> str: + """生成解析状态key""" + return f"{self.PARSING_STATUS_PREFIX}:{flow_id}:{step_id}" + + def _get_flow_hash_key(self, flow_id: str) -> 
str: + """生成Flow哈希key,用于存储Flow的拓扑结构哈希值""" + return f"flow_hash:{flow_id}" + + async def get_cached_variables(self, flow_id: str, step_id: str) -> Optional[List[Dict[str, Any]]]: + """获取缓存的前置节点变量""" + if not self.redis.is_connected(): + return None + + try: + cache_key = self._get_cache_key(flow_id, step_id) + cached_data = await self.redis._redis.get(cache_key) + + if cached_data: + data = json.loads(cached_data) + logger.info(f"从缓存获取前置节点变量: {flow_id}:{step_id}, 数量: {len(data.get('variables', []))}") + return data.get('variables', []) + + except Exception as e: + logger.error(f"获取缓存的前置节点变量失败: {e}") + + return None + + async def set_cached_variables(self, flow_id: str, step_id: str, variables: List[Dict[str, Any]], flow_hash: str): + """设置缓存的前置节点变量""" + if not self.redis.is_connected(): + return False + + try: + cache_key = self._get_cache_key(flow_id, step_id) + cache_data = { + 'variables': variables, + 'flow_hash': flow_hash, + 'cached_at': datetime.now(UTC).isoformat(), + 'step_count': len(variables) + } + + await self.redis._redis.setex( + cache_key, + self.CACHE_TTL, + json.dumps(cache_data, default=str) + ) + + logger.info(f"缓存前置节点变量成功: {flow_id}:{step_id}, 数量: {len(variables)}") + return True + + except Exception as e: + logger.error(f"缓存前置节点变量失败: {e}") + return False + + async def is_parsing_in_progress(self, flow_id: str, step_id: str) -> bool: + """检查是否正在解析中""" + if not self.redis.is_connected(): + return False + + try: + status_key = self._get_status_key(flow_id, step_id) + status = await self.redis._redis.get(status_key) + return status == "parsing" + except Exception as e: + logger.error(f"检查解析状态失败: {e}") + return False + + async def set_parsing_status(self, flow_id: str, step_id: str, status: str, ttl: int = 300): + """设置解析状态 (parsing, completed, failed)""" + if not self.redis.is_connected(): + return False + + try: + status_key = self._get_status_key(flow_id, step_id) + await self.redis._redis.setex(status_key, ttl, status) + return True + except Exception as e: + logger.error(f"设置解析状态失败: {e}") + return False + + async def wait_for_parsing_completion(self, flow_id: str, step_id: str, max_wait_time: int = 30) -> bool: + """等待解析完成""" + if not self.redis.is_connected(): + return False + + start_time = datetime.now(UTC) + + while (datetime.now(UTC) - start_time).total_seconds() < max_wait_time: + if not await self.is_parsing_in_progress(flow_id, step_id): + # 检查是否有缓存结果 + cached_vars = await self.get_cached_variables(flow_id, step_id) + return cached_vars is not None + + await asyncio.sleep(0.5) # 等待500ms后重试 + + logger.warning(f"等待解析完成超时: {flow_id}:{step_id}") + return False + + async def invalidate_flow_cache(self, flow_id: str): + """使某个Flow的所有缓存失效""" + if not self.redis.is_connected(): + return + + try: + # 查找所有相关的缓存key + pattern = f"{self.CACHE_PREFIX}:{flow_id}:*" + keys = await self.redis._redis.keys(pattern) + + # 同时删除解析状态key + status_pattern = f"{self.PARSING_STATUS_PREFIX}:{flow_id}:*" + status_keys = await self.redis._redis.keys(status_pattern) + + all_keys = keys + status_keys + + if all_keys: + await self.redis._redis.delete(*all_keys) + logger.info(f"清除Flow缓存: {flow_id}, 删除key数量: {len(all_keys)}") + + except Exception as e: + logger.error(f"清除Flow缓存失败: {e}") + + async def get_flow_hash(self, flow_id: str) -> Optional[str]: + """获取Flow的拓扑结构哈希值""" + if not self.redis.is_connected(): + return None + + try: + hash_key = self._get_flow_hash_key(flow_id) + return await self.redis._redis.get(hash_key) + except Exception as e: + logger.error(f"获取Flow哈希失败: {e}") + return None 
+ + async def set_flow_hash(self, flow_id: str, flow_hash: str): + """设置Flow的拓扑结构哈希值""" + if not self.redis.is_connected(): + return False + + try: + # 检查事件循环是否仍然活跃 + import asyncio + try: + asyncio.get_running_loop() + except RuntimeError: + logger.warning(f"事件循环已关闭,跳过设置Flow哈希: {flow_id}") + return False + + hash_key = self._get_flow_hash_key(flow_id) + await self.redis._redis.setex(hash_key, self.CACHE_TTL, flow_hash) + return True + except Exception as e: + logger.error(f"设置Flow哈希失败: {e}") + return False + + +# 全局实例 +redis_cache = RedisCache() +predecessor_cache = PredecessorVariableCache(redis_cache) \ No newline at end of file diff --git a/apps/main.py b/apps/main.py index ee3fa741a..717311037 100644 --- a/apps/main.py +++ b/apps/main.py @@ -8,6 +8,10 @@ from __future__ import annotations import asyncio import logging +import logging.config +import signal +import sys +from contextlib import asynccontextmanager import uvicorn from fastapi import FastAPI @@ -39,9 +43,59 @@ from apps.routers import ( parameter ) from apps.scheduler.pool.pool import Pool +from apps.services.predecessor_cache_service import cleanup_background_tasks + +# 全局变量用于跟踪后台任务 +_cleanup_task = None + +async def cleanup_on_shutdown(): + """应用关闭时的清理函数""" + logger = logging.getLogger(__name__) + logger.info("开始清理应用资源...") + + try: + # 取消定期清理任务 + global _cleanup_task + if _cleanup_task and not _cleanup_task.done(): + _cleanup_task.cancel() + try: + await _cleanup_task + except asyncio.CancelledError: + logger.info("定期清理任务已取消") + + # 清理后台任务 + await cleanup_background_tasks() + + # 关闭Redis连接 + from apps.common.redis_cache import RedisCache + redis_cache = RedisCache() + if redis_cache.is_connected(): + await redis_cache.close() + logger.info("Redis连接已关闭") + + except Exception as e: + logger.error(f"清理应用资源时出错: {e}") + + logger.info("应用资源清理完成") + +@asynccontextmanager +async def lifespan(app: FastAPI): + """应用生命周期管理""" + # 启动时的初始化 + await init_resources() + + yield + + # 关闭时的清理 + await cleanup_on_shutdown() # 定义FastAPI app -app = FastAPI(redoc_url=None) +app = FastAPI( + title="Euler Copilot Framework", + description="AI-powered automation framework", + version="1.0.0", + lifespan=lifespan, +) # 定义FastAPI全局中间件 app.add_middleware( CORSMiddleware, @@ -90,14 +144,48 @@ async def init_resources() -> None: await Pool.init() TokenCalculator() - # 初始化变量系统 - from apps.scheduler.variable.pool import get_variable_pool - await get_variable_pool() + # 初始化变量池管理器 + from apps.scheduler.variable.pool_manager import initialize_pool_manager + await initialize_pool_manager() + + # 初始化前置节点变量缓存服务 + try: + from apps.services.predecessor_cache_service import PredecessorCacheService, periodic_cleanup_background_tasks + await PredecessorCacheService.initialize_redis() + + # 启动定期清理任务 + global _cleanup_task + _cleanup_task = asyncio.create_task(start_periodic_cleanup()) + + logging.info("前置节点变量缓存服务初始化成功") + except Exception as e: + logging.warning(f"前置节点变量缓存服务初始化失败(将降级使用实时解析): {e}") + +async def start_periodic_cleanup(): + """启动定期清理任务""" + try: + from apps.services.predecessor_cache_service import periodic_cleanup_background_tasks + while True: + # 每60秒清理一次已完成的后台任务 + await asyncio.sleep(60) + await periodic_cleanup_background_tasks() + except asyncio.CancelledError: + logging.info("定期清理任务已取消") + raise # 重新抛出CancelledError + except Exception as e: + logging.error(f"定期清理任务异常: {e}") # 运行 if __name__ == "__main__": - # 初始化必要资源 - asyncio.run(init_resources()) - + def signal_handler(signum, frame): + """信号处理器""" + logger = logging.getLogger(__name__) + 
logger.info(f"收到信号 {signum},准备关闭应用...") + sys.exit(0) + + # 注册信号处理器 + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + # 启动FastAPI uvicorn.run(app, host="0.0.0.0", port=8002, log_level="info", log_config=None) diff --git a/apps/routers/flow.py b/apps/routers/flow.py index fc38c1bfb..646213a88 100644 --- a/apps/routers/flow.py +++ b/apps/routers/flow.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """FastAPI Flow拓扑结构展示API""" +import logging from typing import Annotated from fastapi import APIRouter, Body, Depends, Query, status @@ -25,6 +26,8 @@ from apps.services.application import AppManager from apps.services.flow import FlowManager from apps.services.flow_validate import FlowService +logger = logging.getLogger(__name__) + router = APIRouter( prefix="/api/flow", tags=["flow"], @@ -153,6 +156,20 @@ async def put_flow( result=FlowStructurePutMsg(), ).model_dump(exclude_none=True, by_alias=True), ) + + # 触发前置节点变量预解析(异步执行,不阻塞响应) + try: + from apps.services.predecessor_cache_service import PredecessorCacheService + import asyncio + + # 在后台异步触发预解析 + asyncio.create_task( + PredecessorCacheService.trigger_flow_parsing(flow_id, force_refresh=True) + ) + logger.info(f"已触发Flow前置节点变量预解析: {flow_id}") + except Exception as trigger_error: + logger.warning(f"触发Flow前置节点变量预解析失败: {flow_id}, 错误: {trigger_error}") + return JSONResponse( status_code=status.HTTP_200_OK, content=FlowStructurePutRsp( diff --git a/apps/routers/variable.py b/apps/routers/variable.py index bf930584f..7a68e69d8 100644 --- a/apps/routers/variable.py +++ b/apps/routers/variable.py @@ -1,7 +1,8 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """FastAPI 变量管理 API""" -from typing import Annotated, List, Optional +import logging +from typing import Annotated, List, Optional, Dict from fastapi import APIRouter, Body, Depends, HTTPException, Query, status from fastapi.responses import JSONResponse @@ -9,10 +10,13 @@ from pydantic import BaseModel, Field from apps.dependency import get_user from apps.dependency.user import verify_user -from apps.scheduler.variable.pool import get_variable_pool +from apps.scheduler.variable.pool_manager import get_pool_manager from apps.scheduler.variable.type import VariableType, VariableScope from apps.scheduler.variable.parser import VariableParser from apps.schemas.response_data import ResponseData +from apps.services.flow import FlowManager + +logger = logging.getLogger(__name__) router = APIRouter( prefix="/api/variable", @@ -23,6 +27,124 @@ router = APIRouter( ) +async def _get_predecessor_node_variables( + user_sub: str, + flow_id: str, + conversation_id: Optional[str], + current_step_id: str +) -> List: + """获取前置节点的输出变量(优化版本,使用缓存) + + Args: + user_sub: 用户ID + flow_id: 流程ID + conversation_id: 对话ID(可选,配置阶段可能为None) + current_step_id: 当前步骤ID + + Returns: + List: 前置节点的输出变量列表 + """ + try: + variables = [] + pool_manager = await get_pool_manager() + + if conversation_id: + # 运行阶段:从对话池获取实际的前置节点变量 + conversation_pool = await pool_manager.get_conversation_pool(conversation_id) + if conversation_pool: + # 获取所有对话变量 + all_conversation_vars = await conversation_pool.list_variables() + + # 筛选出前置节点的输出变量(格式为 node_id.key) + for var in all_conversation_vars: + var_name = var.name + # 检查是否为节点输出变量格式(包含.且不是系统变量) + if "." 
in var_name and not var_name.startswith("system."): + # 提取节点ID + node_id = var_name.split(".")[0] + + # 检查是否为前置节点(这里可以根据需要添加更精确的前置判断逻辑) + if node_id != current_step_id: # 不是当前节点的变量 + variables.append(var) + else: + # 配置阶段:优先使用缓存,降级到实时解析 + try: + # 尝试使用优化的缓存服务 + from apps.services.predecessor_cache_service import PredecessorCacheService + + # 1. 先从flow池中查找已存在的前置节点变量 + flow_pool = await pool_manager.get_flow_pool(flow_id) + if flow_pool: + flow_conversation_vars = await flow_pool.list_variables() + + # 筛选出前置节点的输出变量(格式为 node_id.key) + for var in flow_conversation_vars: + var_name = var.name + if "." in var_name and not var_name.startswith("system."): + node_id = var_name.split(".")[0] + if node_id != current_step_id: + variables.append(var) + + # 2. 使用优化的缓存服务获取前置节点变量 + cached_var_data = await PredecessorCacheService.get_predecessor_variables_optimized( + flow_id, current_step_id, user_sub, max_wait_time=5 + ) + + # 将缓存的变量数据转换为Variable对象 + for var_data in cached_var_data: + try: + from apps.scheduler.variable.variables import create_variable + from apps.scheduler.variable.base import VariableMetadata + from apps.scheduler.variable.type import VariableType, VariableScope + from datetime import datetime + + # 创建变量元数据 + metadata = VariableMetadata( + name=var_data['name'], + var_type=VariableType(var_data['var_type']), + scope=VariableScope(var_data['scope']), + description=var_data.get('description', ''), + created_by=user_sub, + created_at=datetime.fromisoformat(var_data['created_at'].replace('Z', '+00:00')), + updated_at=datetime.fromisoformat(var_data['updated_at'].replace('Z', '+00:00')) + ) + + # 创建变量对象,并附加缓存的节点信息 + variable = create_variable(metadata, var_data.get('value', '')) + + # 将节点信息附加到变量对象上(用于后续响应格式化) + if hasattr(variable, '_cache_data'): + variable._cache_data = var_data + else: + # 如果对象不支持动态属性,我们可以创建一个包装类或者在响应时处理 + setattr(variable, '_cache_data', var_data) + + variables.append(variable) + + except Exception as var_create_error: + logger.warning(f"创建缓存变量对象失败: {var_create_error}") + continue + + logger.info(f"配置阶段:为节点 {current_step_id} 找到前置节点变量总数: {len([v for v in variables if hasattr(v, 'name') and '.' 
in v.name and not v.name.startswith('system.')])}") + + except Exception as flow_error: + logger.warning(f"配置阶段获取前置节点变量失败,降级到实时解析: {flow_error}") + # 降级到原有的实时解析逻辑 + predecessor_vars = await _get_predecessor_variables_from_topology( + flow_id, current_step_id, user_sub + ) + variables.extend(predecessor_vars) + + return variables + + except Exception as e: + logger.error(f"获取前置节点变量失败: {e}") + return [] + + + + + # 请求和响应模型 class CreateVariableRequest(BaseModel): """创建变量请求""" @@ -50,6 +172,8 @@ class VariableResponse(BaseModel): description: Optional[str] = Field(description="变量描述") created_at: str = Field(description="创建时间") updated_at: str = Field(description="更新时间") + step: Optional[str] = Field(default=None, description="节点名称(前置节点变量专用)") + step_id: Optional[str] = Field(default=None, description="节点ID(前置节点变量专用)") class VariableListResponse(BaseModel): @@ -97,22 +221,66 @@ async def create_variable( detail="不允许创建系统级变量" ) - pool = await get_variable_pool() + pool_manager = await get_pool_manager() + + # 根据作用域获取合适的变量池 + if request.scope == VariableScope.USER: + # 用户级变量需要user_sub参数 + if not user_sub: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="用户级变量需要用户身份" + ) + pool = await pool_manager.get_user_pool(user_sub) + elif request.scope == VariableScope.ENVIRONMENT: + # 环境级变量需要flow_id参数 + if not request.flow_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="环境级变量需要flow_id参数" + ) + pool = await pool_manager.get_flow_pool(request.flow_id) + elif request.scope == VariableScope.CONVERSATION: + # 对话级变量需要flow_id参数,用于创建模板 + if not request.flow_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="对话级变量需要flow_id参数" + ) + # 对话级变量模板在流程池中定义 + pool = await pool_manager.get_flow_pool(request.flow_id) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"不支持的变量作用域: {request.scope.value}" + ) - # 创建变量 - variable = await pool.add_variable( - name=request.name, - var_type=request.var_type, - scope=request.scope, - value=request.value, - description=request.description, - user_sub=None, # 不再支持用户级变量 - flow_id=request.flow_id if request.scope in [VariableScope.ENVIRONMENT, VariableScope.CONVERSATION] else None, - ) + if not pool: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="无法获取变量池" + ) - # 如果是对话级变量,刷新缓存确保数据一致性 - if request.scope == VariableScope.CONVERSATION and request.flow_id: - await pool.refresh_conversation_cache(request.flow_id) + # 根据作用域创建不同类型的变量 + if request.scope == VariableScope.CONVERSATION: + # 创建对话变量模板 + variable = await pool.add_conversation_template( + name=request.name, + var_type=request.var_type, + default_value=request.value, + description=request.description, + created_by=user_sub + ) + else: + # 创建其他类型的变量 + variable = await pool.add_variable( + name=request.name, + var_type=request.var_type, + value=request.value, + description=request.description, + created_by=user_sub + ) + return ResponseData( code=200, @@ -146,27 +314,68 @@ async def update_variable( name: str = Query(..., description="变量名称"), scope: VariableScope = Query(..., description="变量作用域"), flow_id: Optional[str] = Query(default=None, description="流程ID(环境级和对话级变量必需)"), + conversation_id: Optional[str] = Query(default=None, description="对话ID(对话级变量运行时必需)"), request: UpdateVariableRequest = Body(...), ) -> ResponseData: """更新变量值""" try: - pool = await get_variable_pool() + pool_manager = await get_pool_manager() + + # 根据作用域获取合适的变量池 + if scope == VariableScope.USER: + if not 
user_sub: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="用户级变量需要用户身份" + ) + pool = await pool_manager.get_user_pool(user_sub) + elif scope == VariableScope.ENVIRONMENT: + if not flow_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="环境级变量需要flow_id参数" + ) + pool = await pool_manager.get_flow_pool(flow_id) + elif scope == VariableScope.CONVERSATION: + if conversation_id: + # 运行时:使用对话池,如果不存在则创建 + if not flow_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="对话级变量运行时需要conversation_id和flow_id参数" + ) + pool = await pool_manager.get_conversation_pool(conversation_id) + if not pool: + # 对话池不存在,自动创建 + pool = await pool_manager.create_conversation_pool(conversation_id, flow_id) + elif flow_id: + # 配置时:使用流程池 + pool = await pool_manager.get_flow_pool(flow_id) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="对话级变量需要conversation_id(运行时)或flow_id(配置时)参数" + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"不支持的变量作用域: {scope.value}" + ) + + if not pool: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="无法获取变量池" + ) # 更新变量 variable = await pool.update_variable( name=name, - scope=scope, value=request.value, var_type=request.var_type, - description=request.description, - user_sub=None, # 不再支持用户级变量 - flow_id=flow_id, + description=request.description ) - # 如果是对话级变量,刷新缓存确保数据一致性 - if scope == VariableScope.CONVERSATION and flow_id: - await pool.refresh_conversation_cache(flow_id) - return ResponseData( code=200, message="变量更新成功", @@ -203,18 +412,61 @@ async def delete_variable( name: str = Query(..., description="变量名称"), scope: VariableScope = Query(..., description="变量作用域"), flow_id: Optional[str] = Query(default=None, description="流程ID(环境级和对话级变量必需)"), + conversation_id: Optional[str] = Query(default=None, description="对话ID(对话级变量运行时必需)"), ) -> ResponseData: """删除变量""" try: - pool = await get_variable_pool() + pool_manager = await get_pool_manager() + + # 根据作用域获取合适的变量池 + if scope == VariableScope.USER: + if not user_sub: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="用户级变量需要用户身份" + ) + pool = await pool_manager.get_user_pool(user_sub) + elif scope == VariableScope.ENVIRONMENT: + if not flow_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="环境级变量需要flow_id参数" + ) + pool = await pool_manager.get_flow_pool(flow_id) + elif scope == VariableScope.CONVERSATION: + if conversation_id: + # 运行时:使用对话池,如果不存在则创建 + if not flow_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="对话级变量运行时需要conversation_id和flow_id参数" + ) + pool = await pool_manager.get_conversation_pool(conversation_id) + if not pool: + # 对话池不存在,自动创建 + pool = await pool_manager.create_conversation_pool(conversation_id, flow_id) + elif flow_id: + # 配置时:使用流程池 + pool = await pool_manager.get_flow_pool(flow_id) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="对话级变量需要conversation_id(运行时)或flow_id(配置时)参数" + ) + else: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"不支持的变量作用域: {scope.value}" + ) + + if not pool: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="无法获取变量池" + ) # 删除变量 - success = await pool.delete_variable( - name=name, - scope=scope, - user_sub=None, # 不再支持用户级变量 - flow_id=flow_id, - ) + success = await pool.delete_variable(name) if not success: raise HTTPException( @@ -222,10 
+474,6 @@ async def delete_variable( detail="变量不存在" ) - # 如果是对话级变量,刷新缓存确保数据一致性 - if scope == VariableScope.CONVERSATION and flow_id: - await pool.refresh_conversation_cache(flow_id) - return ResponseData( code=200, message="变量删除成功", @@ -261,17 +509,19 @@ async def get_variable( name: str = Query(..., description="变量名称"), scope: VariableScope = Query(..., description="变量作用域"), flow_id: Optional[str] = Query(default=None, description="流程ID(环境级和对话级变量必需)"), + conversation_id: Optional[str] = Query(default=None, description="对话ID(系统级和对话级变量必需)"), ) -> VariableResponse: """获取单个变量""" try: - pool = await get_variable_pool() + pool_manager = await get_pool_manager() - # 获取变量 - variable = await pool.get_variable( + # 根据作用域获取变量 + variable = await pool_manager.get_variable_from_any_pool( name=name, scope=scope, - user_sub=None, # 不再支持用户级变量 - flow_id=flow_id if scope in [VariableScope.ENVIRONMENT, VariableScope.CONVERSATION] else None, + user_id=user_sub if scope == VariableScope.USER else None, + flow_id=flow_id if scope in [VariableScope.SYSTEM, VariableScope.ENVIRONMENT, VariableScope.CONVERSATION] else None, + conversation_id=conversation_id if scope in [VariableScope.SYSTEM, VariableScope.CONVERSATION] else None ) if not variable: @@ -318,32 +568,92 @@ async def list_variables( user_sub: Annotated[str, Depends(get_user)], scope: VariableScope = Query(..., description="变量作用域"), flow_id: Optional[str] = Query(default=None, description="流程ID(环境级和对话级变量必需)"), + conversation_id: Optional[str] = Query(default=None, description="对话ID(系统级和对话级变量必需)"), + current_step_id: Optional[str] = Query(default=None, description="当前步骤ID(用于获取前置节点变量)"), ) -> VariableListResponse: """列出指定作用域的变量""" try: - pool = await get_variable_pool() + pool_manager = await get_pool_manager() # 获取变量列表 - variables = await pool.list_variables( + variables = await pool_manager.list_variables_from_any_pool( scope=scope, - user_sub=None, # 不再支持用户级变量 - flow_id=flow_id if scope in [VariableScope.ENVIRONMENT, VariableScope.CONVERSATION] else None, + user_id=user_sub if scope == VariableScope.USER else None, + flow_id=flow_id if scope in [VariableScope.SYSTEM, VariableScope.ENVIRONMENT, VariableScope.CONVERSATION] else None, + conversation_id=conversation_id if scope in [VariableScope.SYSTEM, VariableScope.CONVERSATION] else None ) + # 如果是对话级变量且提供了current_step_id,则额外获取前置节点的输出变量 + if scope == VariableScope.CONVERSATION and current_step_id and flow_id: + predecessor_variables = await _get_predecessor_node_variables( + user_sub, flow_id, conversation_id, current_step_id + ) + variables.extend(predecessor_variables) + # 过滤权限并构建响应 filtered_variables = [] for variable in variables: if variable.can_access(user_sub): var_dict = variable.to_dict() - filtered_variables.append(VariableResponse( - name=variable.name, - var_type=variable.var_type.value, - scope=variable.scope.value, - value=str(var_dict["value"]) if var_dict["value"] is not None else "", - description=variable.metadata.description, - created_at=variable.metadata.created_at.isoformat(), - updated_at=variable.metadata.updated_at.isoformat(), - )) + + # 检查是否为前置节点变量 + is_predecessor_var = ( + "." 
in variable.name and + not variable.name.startswith("system.") and + scope == VariableScope.CONVERSATION and + flow_id + ) + + if is_predecessor_var: + # 前置节点变量特殊处理 + parts = variable.name.split(".", 1) + if len(parts) == 2: + step_id, var_name = parts + + # 优先使用缓存数据中的节点信息 + if hasattr(variable, '_cache_data') and variable._cache_data: + cache_data = variable._cache_data + step_name = cache_data.get('step_name', step_id) + step_id_from_cache = cache_data.get('step_id', step_id) + else: + # 降级到实时获取节点信息 + node_info = await _get_node_info_by_step_id(flow_id, step_id) + step_name = node_info["name"] + step_id_from_cache = node_info["step_id"] + + filtered_variables.append(VariableResponse( + name=var_name, # 只保留变量名部分 + var_type=variable.var_type.value, + scope=variable.scope.value, + value=str(var_dict["value"]) if var_dict["value"] is not None else "", + description=variable.metadata.description, + created_at=variable.metadata.created_at.isoformat(), + updated_at=variable.metadata.updated_at.isoformat(), + step=step_name, # 节点名称 + step_id=step_id_from_cache # 节点ID + )) + else: + # 降级处理,如果格式不符合预期 + filtered_variables.append(VariableResponse( + name=variable.name, + var_type=variable.var_type.value, + scope=variable.scope.value, + value=str(var_dict["value"]) if var_dict["value"] is not None else "", + description=variable.metadata.description, + created_at=variable.metadata.created_at.isoformat(), + updated_at=variable.metadata.updated_at.isoformat(), + )) + else: + # 普通变量 + filtered_variables.append(VariableResponse( + name=variable.name, + var_type=variable.var_type.value, + scope=variable.scope.value, + value=str(var_dict["value"]) if var_dict["value"] is not None else "", + description=variable.metadata.description, + created_at=variable.metadata.created_at.isoformat(), + updated_at=variable.metadata.updated_at.isoformat(), + )) return VariableListResponse( variables=filtered_variables, @@ -372,7 +682,7 @@ async def parse_template( try: # 创建变量解析器 parser = VariableParser( - user_sub=user_sub, + user_id=user_sub, flow_id=request.flow_id, conversation_id=None, # 不再使用conversation_id ) @@ -410,7 +720,7 @@ async def validate_template( try: # 创建变量解析器 parser = VariableParser( - user_sub=user_sub, + user_id=user_sub, flow_id=request.flow_id, conversation_id=None, # 不再使用conversation_id ) @@ -460,9 +770,9 @@ async def clear_conversation_variables( ) -> ResponseData: """清空指定工作流的对话级变量""" try: - pool = await get_variable_pool() + pool_manager = await get_pool_manager() # 清空工作流的对话级变量 - await pool.clear_conversation_variables(flow_id) + await pool_manager.clear_conversation_variables(flow_id) return ResponseData( code=200, @@ -474,4 +784,200 @@ async def clear_conversation_variables( raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"清空对话变量失败: {str(e)}" - ) \ No newline at end of file + ) + + +async def _get_node_info_by_step_id(flow_id: str, step_id: str) -> Dict[str, str]: + """根据step_id获取节点信息""" + try: + flow_item = await _get_flow_by_flow_id(flow_id) + if not flow_item: + return {"name": step_id, "step_id": step_id} # 降级返回step_id作为名称 + + # 查找对应的节点 + for node in flow_item.nodes: + if node.step_id == step_id: + return { + "name": node.name or step_id, # 如果没有名称则使用step_id + "step_id": step_id + } + + # 如果没有找到节点,返回默认值 + return {"name": step_id, "step_id": step_id} + + except Exception as e: + logger.error(f"获取节点信息失败: {e}") + return {"name": step_id, "step_id": step_id} + + +async def _get_predecessor_variables_from_topology( + flow_id: str, + current_step_id: str, + user_sub: 
str +) -> List: + """通过工作流拓扑分析获取前置节点变量""" + try: + variables = [] + + # 直接通过flow_id获取工作流拓扑信息 + flow_item = await _get_flow_by_flow_id(flow_id) + if not flow_item: + logger.warning(f"无法获取工作流信息: flow_id={flow_id}") + return variables + + # 分析前置节点 + predecessor_nodes = _find_predecessor_nodes(flow_item, current_step_id) + + # 为每个前置节点创建潜在的输出变量 + for node in predecessor_nodes: + node_vars = await _create_node_output_variables(node, user_sub) + variables.extend(node_vars) + + logger.info(f"通过拓扑分析为节点 {current_step_id} 创建了 {len(variables)} 个前置节点变量") + return variables + + except Exception as e: + logger.error(f"通过拓扑分析获取前置节点变量失败: {e}") + return [] + + +async def _get_flow_by_flow_id(flow_id: str): + """直接通过flow_id获取工作流信息""" + try: + from apps.common.mongo import MongoDB + + app_collection = MongoDB().get_collection("app") + + # 查询包含此flow_id的app,同时获取app_id + app_record = await app_collection.find_one( + {"flows.id": flow_id}, + {"_id": 1} + ) + + if not app_record: + logger.warning(f"未找到包含flow_id {flow_id} 的应用") + return None + + app_id = app_record["_id"] + + # 使用现有的FlowManager方法获取flow + flow_item = await FlowManager.get_flow_by_app_and_flow_id(app_id, flow_id) + return flow_item + + except Exception as e: + logger.error(f"通过flow_id获取工作流失败: {e}") + return None + + +def _find_predecessor_nodes(flow_item, current_step_id: str) -> List: + """在工作流中查找前置节点""" + try: + predecessor_nodes = [] + + # 遍历边,找到指向当前节点的边 + for edge in flow_item.edges: + if edge.target_node == current_step_id: + # 找到前置节点 + source_node = next( + (node for node in flow_item.nodes if node.step_id == edge.source_node), + None + ) + if source_node: + predecessor_nodes.append(source_node) + + logger.info(f"为节点 {current_step_id} 找到 {len(predecessor_nodes)} 个前置节点") + return predecessor_nodes + + except Exception as e: + logger.error(f"查找前置节点失败: {e}") + return [] + + +async def _create_node_output_variables(node, user_sub: str) -> List: + """根据节点的output_parameters配置创建输出变量""" + try: + from apps.scheduler.variable.variables import create_variable + from apps.scheduler.variable.base import VariableMetadata + from datetime import datetime, UTC + + variables = [] + node_id = node.step_id + + # 调试:输出节点的完整参数信息 + logger.info(f"节点 {node_id} 的参数结构: {node.parameters}") + + # 统一从节点的output_parameters创建变量 + output_params = {} + if hasattr(node, 'parameters') and node.parameters: + # 尝试不同的访问方式 + if isinstance(node.parameters, dict): + output_params = node.parameters.get('output_parameters', {}) + logger.info(f"从字典中获取output_parameters: {output_params}") + else: + output_params = getattr(node.parameters, 'output_parameters', {}) + logger.info(f"从对象属性中获取output_parameters: {output_params}") + + # 如果没有配置output_parameters,跳过此节点 + if not output_params: + logger.info(f"节点 {node_id} 没有配置output_parameters,跳过创建输出变量") + return variables + + # 遍历output_parameters中的每个key-value对,创建对应的变量 + for param_name, param_config in output_params.items(): + # 解析参数配置 + if isinstance(param_config, dict): + param_type = param_config.get('type', 'string') + description = param_config.get('description', '') + else: + # 如果param_config不是字典,可能是简单的类型字符串 + param_type = str(param_config) if param_config else 'string' + description = '' + + # 确定变量类型 + var_type = VariableType.STRING # 默认类型 + if param_type == 'number': + var_type = VariableType.NUMBER + elif param_type == 'boolean': + var_type = VariableType.BOOLEAN + elif param_type == 'object': + var_type = VariableType.OBJECT + elif param_type == 'array' or param_type == 'array[any]': + var_type = VariableType.ARRAY_ANY + elif param_type == 
'array[string]': + var_type = VariableType.ARRAY_STRING + elif param_type == 'array[number]': + var_type = VariableType.ARRAY_NUMBER + elif param_type == 'array[object]': + var_type = VariableType.ARRAY_OBJECT + elif param_type == 'array[boolean]': + var_type = VariableType.ARRAY_BOOLEAN + elif param_type == 'array[file]': + var_type = VariableType.ARRAY_FILE + elif param_type == 'array[secret]': + var_type = VariableType.ARRAY_SECRET + elif param_type == 'file': + var_type = VariableType.FILE + elif param_type == 'secret': + var_type = VariableType.SECRET + + # 创建变量元数据 + metadata = VariableMetadata( + name=f"{node_id}.{param_name}", + var_type=var_type, + scope=VariableScope.CONVERSATION, + description=description or f"来自节点 {node_id} 的输出参数 {param_name}", + created_by=user_sub, + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC) + ) + + # 创建变量对象 + variable = create_variable(metadata, "") # 配置阶段的潜在变量,值为空 + variables.append(variable) + + logger.info(f"为节点 {node_id} 创建了 {len(variables)} 个输出变量: {[v.name for v in variables]}") + return variables + + except Exception as e: + logger.error(f"创建节点输出变量失败: {e}") + return [] \ No newline at end of file diff --git a/apps/scheduler/call/api/api.py b/apps/scheduler/call/api/api.py index e1891f725..1aec22c82 100644 --- a/apps/scheduler/call/api/api.py +++ b/apps/scheduler/call/api/api.py @@ -15,7 +15,7 @@ from pydantic.json_schema import SkipJsonSchema from apps.common.oidc import oidc_provider from apps.scheduler.call.api.schema import APIInput, APIOutput from apps.scheduler.call.core import CoreCall -from apps.schemas.enum_var import CallOutputType, ContentType, HTTPMethod +from apps.schemas.enum_var import CallOutputType, CallType, ContentType, HTTPMethod from apps.schemas.scheduler import ( CallError, CallInfo, @@ -62,7 +62,11 @@ class API(CoreCall, input_model=APIInput, output_model=APIOutput): @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" - return CallInfo(name="API调用", description="向某一个API接口发送HTTP请求,获取数据。") + return CallInfo( + name="API调用", + type=CallType.TOOL, + description="向某一个API接口发送HTTP请求,获取数据。" + ) async def _init(self, call_vars: CallVars) -> APIInput: """初始化API调用工具""" diff --git a/apps/scheduler/call/code/code.py b/apps/scheduler/call/code/code.py index ff2bbdc24..18b6577b4 100644 --- a/apps/scheduler/call/code/code.py +++ b/apps/scheduler/call/code/code.py @@ -1,7 +1,6 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
"""代码执行工具""" -import json import logging from collections.abc import AsyncGenerator from typing import Any @@ -12,7 +11,7 @@ from pydantic import Field from apps.common.config import Config from apps.scheduler.call.code.schema import CodeInput, CodeOutput from apps.scheduler.call.core import CoreCall -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.scheduler import ( CallError, CallInfo, @@ -35,12 +34,18 @@ class Code(CoreCall, input_model=CodeInput, output_model=CodeOutput): timeout_seconds: int = Field(description="超时时间(秒)", default=30, ge=1, le=300) memory_limit_mb: int = Field(description="内存限制(MB)", default=128, ge=1, le=1024) cpu_limit: float = Field(description="CPU限制", default=0.5, ge=0.1, le=2.0) + input_parameters: dict[str, Any] = Field(description="输入参数配置", default={}) + output_parameters: dict[str, Any] = Field(description="输出参数配置", default={}) @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" - return CallInfo(name="代码执行", description="在安全的沙箱环境中执行Python、JavaScript、Bash代码。") + return CallInfo( + name="代码执行", + type=CallType.TRANSFORM, + description="在安全的沙箱环境中执行Python、JavaScript、Bash代码。" + ) async def _init(self, call_vars: CallVars) -> CodeInput: @@ -52,6 +57,14 @@ class Code(CoreCall, input_model=CodeInput, output_model=CodeOutput): "permissions": ["execute"] } + # 处理输入参数 - 使用基类的变量解析功能 + input_arg = {} + if self.input_parameters: + # 解析每个输入参数 + for param_name, param_config in self.input_parameters.items(): + resolved_value = await self._resolve_variables_in_config(param_config, call_vars) + input_arg[param_name] = resolved_value + return CodeInput( code=self.code, code_type=self.code_type, @@ -60,6 +73,7 @@ class Code(CoreCall, input_model=CodeInput, output_model=CodeOutput): timeout_seconds=self.timeout_seconds, memory_limit_mb=self.memory_limit_mb, cpu_limit=self.cpu_limit, + input_arg=input_arg, ) @@ -81,6 +95,7 @@ class Code(CoreCall, input_model=CodeInput, output_model=CodeOutput): "timeout_seconds": data.timeout_seconds, "memory_limit_mb": data.memory_limit_mb, "cpu_limit": data.cpu_limit, + "input_arg": data.input_arg, } # 发送执行请求 @@ -98,22 +113,63 @@ class Code(CoreCall, input_model=CodeInput, output_model=CodeOutput): ) result = response.json() - task_id = result.get("task_id", "") + logger.info(f"Sandbox service response: {result}") + + # 检查请求是否成功 + success = result.get("success", False) + message = result.get("message", "") + timestamp = result.get("timestamp", "") + + if not success: + raise CallError( + message=f"代码执行服务返回错误: {message}", + data={"response": result} + ) + + # 提取任务信息 + data = result.get("data", {}) + task_id = data.get("task_id", "") + estimated_wait_time = data.get("estimated_wait_time", 0) + queue_position = data.get("queue_position", 0) + + logger.info(f"Task submitted successfully - task_id: {task_id}, estimated_wait_time: {estimated_wait_time}s, queue_position: {queue_position}, timestamp: {timestamp}") # 轮询获取结果 if task_id: + # 有task_id,需要轮询获取最终结果 result = await self._wait_for_result(sandbox_url, task_id) + else: + # 没有task_id,可能是同步执行,直接使用初始响应 + # 但需要确保结果格式正确 + if "output" not in result and "error" not in result: + # 如果初始响应没有包含执行结果,可能是异步但没有返回task_id的错误情况 + result = { + "status": "error", + "error": "服务器没有返回task_id且没有执行结果", + "output": "", + "result": {} + } - # 返回结果 + # 处理sandbox返回的结果,提取output_parameters指定的数据 + extracted_data = await self._process_sandbox_result(result) + + # 构建最终输出内容 + final_content = CodeOutput( + task_id=task_id, + 
status=result.get("status", "unknown"), + output=result.get("output") or "", + error=result.get("error") or "", + ).model_dump(by_alias=True, exclude_none=True) + + # 如果成功提取到数据,将其合并到输出中 + if extracted_data and result.get("status") == "completed": + final_content.update(extracted_data) + logger.info(f"[Code] 已将提取的数据合并到输出: {list(extracted_data.keys())}") + + # 返回最终结果 yield CallOutputChunk( type=CallOutputType.DATA, - content=CodeOutput( - task_id=task_id, - status=result.get("status", "unknown"), - result=result.get("result", {}), - output=result.get("output", ""), - error=result.get("error", ""), - ).model_dump(by_alias=True, exclude_none=True), + content=final_content, ) except httpx.TimeoutException: @@ -124,6 +180,85 @@ class Code(CoreCall, input_model=CodeInput, output_model=CodeOutput): raise CallError(message=f"代码执行失败: {e!s}", data={}) + async def _process_sandbox_result(self, result: dict[str, Any]) -> dict[str, Any] | None: + """处理sandbox返回的结果,根据output_parameters提取数据""" + try: + # 检查是否有output_parameters配置 + if not hasattr(self, 'output_parameters') or not self.output_parameters: + logger.debug("[Code] 无output_parameters配置,跳过数据提取") + return None + + # 获取sandbox返回的output + sandbox_output = result.get("output") + if not sandbox_output: + logger.warning("[Code] sandbox返回的结果中没有output字段") + return None + + # 确保output是字典类型 + if isinstance(sandbox_output, str): + # 尝试解析JSON字符串 + try: + import json + sandbox_output = json.loads(sandbox_output) + except json.JSONDecodeError: + logger.warning(f"[Code] sandbox返回的output不是有效的JSON格式: {sandbox_output}") + return None + + if not isinstance(sandbox_output, dict): + logger.warning(f"[Code] sandbox返回的output不是字典类型: {type(sandbox_output)}") + return None + + # 根据output_parameters提取对应的kv对 + extracted_data = {} + for param_name, param_config in self.output_parameters.items(): + try: + # 支持多种提取方式 + if param_name in sandbox_output: + # 直接键匹配 + extracted_data[param_name] = sandbox_output[param_name] + elif isinstance(param_config, dict) and "path" in param_config: + # 路径提取 + path = param_config["path"] + value = self._extract_value_by_path(sandbox_output, path) + if value is not None: + extracted_data[param_name] = value + elif isinstance(param_config, dict) and param_config.get("source") == "full_output": + # 使用完整输出 + extracted_data[param_name] = sandbox_output + elif isinstance(param_config, dict) and "default" in param_config: + # 使用默认值 + extracted_data[param_name] = param_config["default"] + else: + logger.debug(f"[Code] 无法提取参数 {param_name},在output中未找到对应值") + + except Exception as e: + logger.warning(f"[Code] 提取参数 {param_name} 失败: {e}") + + if extracted_data: + logger.info(f"[Code] 成功提取 {len(extracted_data)} 个输出参数: {list(extracted_data.keys())}") + return extracted_data + else: + logger.debug("[Code] 未能提取到任何输出参数") + return None + + except Exception as e: + logger.error(f"[Code] 处理sandbox结果失败: {e}") + return None + + def _extract_value_by_path(self, data: dict, path: str) -> Any: + """根据路径提取值 (例如: 'result.data.value')""" + try: + current = data + for key in path.split('.'): + if isinstance(current, dict) and key in current: + current = current[key] + else: + return None + return current + except Exception: + return None + + async def _wait_for_result(self, sandbox_url: str, task_id: str, max_attempts: int = 30) -> dict[str, Any]: """等待任务执行完成""" import asyncio @@ -135,18 +270,39 @@ class Code(CoreCall, input_model=CodeInput, output_model=CodeOutput): response = await client.get(f"{sandbox_url.rstrip('/')}/task/{task_id}/status") if response.status_code == 
200: status_result = response.json() - status = status_result.get("status", "") + logger.info(f"Task status response: {status_result}") + + # 检查响应是否成功 + success = status_result.get("success", False) + if not success: + message = status_result.get("message", "获取任务状态失败") + logger.warning(f"Failed to get task status: {message}") + await asyncio.sleep(1) + continue + + # 提取任务状态 + data = status_result.get("data", {}) + status = data.get("status", "") # 如果任务完成,获取结果 if status in ["completed", "failed", "cancelled"]: result_response = await client.get(f"{sandbox_url.rstrip('/')}/task/{task_id}/result") if result_response.status_code == 200: - return result_response.json() + result_data = result_response.json() + logger.info(f"Task result response: {result_data}") + + # 检查获取结果是否成功 + if result_data.get("success", False): + # 返回实际的执行结果 + return result_data.get("data", {}) + else: + return {"status": status, "error": result_data.get("message", "获取结果失败")} else: return {"status": status, "error": "无法获取结果"} # 如果任务仍在运行,继续等待 if status in ["pending", "running"]: + logger.debug(f"Task {task_id} still {status}, waiting...") await asyncio.sleep(1) continue diff --git a/apps/scheduler/call/code/schema.py b/apps/scheduler/call/code/schema.py index 71722d84a..93648a811 100644 --- a/apps/scheduler/call/code/schema.py +++ b/apps/scheduler/call/code/schema.py @@ -18,6 +18,7 @@ class CodeInput(DataBase): timeout_seconds: int = Field(description="超时时间(秒)", default=30, ge=1, le=300) memory_limit_mb: int = Field(description="内存限制(MB)", default=128, ge=1, le=1024) cpu_limit: float = Field(description="CPU限制", default=0.5, ge=0.1, le=2.0) + input_arg: dict[str, Any] = Field(description="传递给main函数的输入参数", default={}) class CodeOutput(DataBase): @@ -25,6 +26,5 @@ class CodeOutput(DataBase): task_id: str = Field(description="任务ID") status: str = Field(description="任务状态") - result: dict[str, Any] = Field(description="执行结果", default={}) output: str = Field(description="执行输出", default="") error: str = Field(description="错误信息", default="") \ No newline at end of file diff --git a/apps/scheduler/call/convert/convert.py b/apps/scheduler/call/convert/convert.py index 27980bd8a..bbe0dbe80 100644 --- a/apps/scheduler/call/convert/convert.py +++ b/apps/scheduler/call/convert/convert.py @@ -12,7 +12,7 @@ from pydantic import Field from apps.scheduler.call.convert.schema import ConvertInput, ConvertOutput from apps.scheduler.call.core import CallOutputChunk, CoreCall -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.scheduler import ( CallInfo, CallOutputChunk, @@ -30,7 +30,11 @@ class Convert(CoreCall, input_model=ConvertInput, output_model=ConvertOutput): @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" - return CallInfo(name="模板转换", description="使用jinja2语法和jsonnet语法,将自然语言信息和原始数据进行格式化。") + return CallInfo( + name="模板转换", + type=CallType.TRANSFORM, + description="使用jinja2语法和jsonnet语法,将自然语言信息和原始数据进行格式化。" + ) async def _init(self, call_vars: CallVars) -> ConvertInput: """初始化工具""" diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index af28c6a34..c5322b048 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -6,6 +6,7 @@ Core Call类是定义了所有Call都应具有的方法和参数的PyDantic类 """ import logging +import re from collections.abc import AsyncGenerator from typing import TYPE_CHECKING, Any, ClassVar, Self @@ -14,6 +15,7 @@ from pydantic.json_schema import SkipJsonSchema from apps.llm.function import FunctionLLM from 
apps.llm.reasoning import ReasoningLLM +from apps.scheduler.variable.integration import VariableIntegration from apps.schemas.enum_var import CallOutputType from apps.schemas.pool import NodePool from apps.schemas.scheduler import ( @@ -70,6 +72,7 @@ class CoreCall(BaseModel): ) to_user: bool = Field(description="是否需要将输出返回给用户", default=False) + enable_variable_resolution: bool = Field(description="是否启用自动变量解析", default=True) model_config = ConfigDict( arbitrary_types_allowed=True, @@ -108,6 +111,7 @@ class CoreCall(BaseModel): task_id=executor.task.id, flow_id=executor.task.state.flow_id, session_id=executor.task.ids.session_id, + conversation_id=executor.task.ids.conversation_id, user_sub=executor.task.ids.user_sub, app_id=executor.task.state.app_id, ), @@ -163,9 +167,113 @@ class CoreCall(BaseModel): await obj._set_input(executor) return obj + async def _initialize_variable_context(self, call_vars: CallVars) -> dict[str, Any]: + """初始化变量解析上下文并初始化系统变量""" + context = { + "question": call_vars.question, + "user_sub": call_vars.ids.user_sub, + "flow_id": call_vars.ids.flow_id, + "session_id": call_vars.ids.session_id, + "app_id": call_vars.ids.app_id, + "conversation_id": call_vars.ids.conversation_id, + } + + await VariableIntegration.initialize_system_variables(context) + return context + + async def _resolve_variables_in_config(self, config: Any, call_vars: CallVars) -> Any: + """解析配置中的变量引用 + + Args: + config: 配置值,可能包含变量引用 + call_vars: Call变量 + + Returns: + 解析后的配置值 + """ + if isinstance(config, dict): + if "reference" in config: + # 解析变量引用 + resolved_value = await VariableIntegration.resolve_variable_reference( + config["reference"], + user_sub=call_vars.ids.user_sub, + flow_id=call_vars.ids.flow_id, + conversation_id=call_vars.ids.conversation_id + ) + return resolved_value + elif "value" in config: + # 使用默认值 + return config["value"] + else: + # 递归解析字典中的所有值 + resolved_dict = {} + for key, value in config.items(): + resolved_dict[key] = await self._resolve_variables_in_config(value, call_vars) + return resolved_dict + elif isinstance(config, list): + # 递归解析列表中的所有值 + resolved_list = [] + for item in config: + resolved_item = await self._resolve_variables_in_config(item, call_vars) + resolved_list.append(resolved_item) + return resolved_list + elif isinstance(config, str): + # 解析字符串中的变量引用 + return await self._resolve_variables_in_text(config, call_vars) + else: + # 直接返回配置值 + return config + + async def _resolve_variables_in_text(self, text: str, call_vars: CallVars) -> str: + """解析文本中的变量引用({{...}} 语法) + + Args: + text: 包含变量引用的文本 + call_vars: Call变量 + + Returns: + 解析后的文本 + """ + if not isinstance(text, str): + return text + + # 检查是否包含变量引用语法 + if not re.search(r'\{\{.*?\}\}', text): + return text + + # 提取所有变量引用并逐一解析替换 + variable_pattern = r'\{\{(.*?)\}\}' + matches = re.findall(variable_pattern, text) + + resolved_text = text + for match in matches: + try: + # 解析变量引用 + resolved_value = await VariableIntegration.resolve_variable_reference( + match.strip(), + user_sub=call_vars.ids.user_sub, + flow_id=call_vars.ids.flow_id, + conversation_id=call_vars.ids.conversation_id + ) + # 替换原始文本中的变量引用 + resolved_text = resolved_text.replace(f'{{{{{match}}}}}', str(resolved_value)) + except Exception as e: + logger.warning(f"[CoreCall] 解析变量引用 '{match}' 失败: {e}") + # 如果解析失败,保留原始的变量引用 + continue + + return resolved_text + + async def _set_input(self, executor: "StepExecutor") -> None: """获取Call的输入""" self._sys_vars = self._assemble_call_vars(executor) + self._step_id = executor.step.step_id # 存储 step_id 
用于变量名构造 + + # 如果启用了变量解析,初始化变量上下文 + if self.enable_variable_resolution: + await self._initialize_variable_context(self._sys_vars) + input_data = await self._init(self._sys_vars) self.input = input_data.model_dump(by_alias=True, exclude_none=True) @@ -181,10 +289,17 @@ class CoreCall(BaseModel): async def _after_exec(self, input_data: dict[str, Any]) -> None: """Call类实例的执行后方法""" + async def exec(self, executor: "StepExecutor", input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """Call类实例的执行方法""" + self._last_output_data = {} # 初始化输出数据存储 + async for chunk in self._exec(input_data): + # 捕获最后的输出数据 + if chunk.type == CallOutputType.DATA and isinstance(chunk.content, dict): + self._last_output_data = chunk.content yield chunk + await self._after_exec(input_data) async def _llm(self, messages: list[dict[str, Any]]) -> str: diff --git a/apps/scheduler/call/empty.py b/apps/scheduler/call/empty.py index 5865bc7e8..a66aac525 100644 --- a/apps/scheduler/call/empty.py +++ b/apps/scheduler/call/empty.py @@ -5,7 +5,7 @@ from collections.abc import AsyncGenerator from typing import Any from apps.scheduler.call.core import CoreCall, DataBase -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.scheduler import CallInfo, CallOutputChunk, CallVars @@ -20,7 +20,11 @@ class Empty(CoreCall, input_model=DataBase, output_model=DataBase): :return: Call的名称和描述 :rtype: CallInfo """ - return CallInfo(name="空白", description="空白节点,用于占位") + return CallInfo( + name="空白", + type=CallType.DEFAULT, + description="空白节点,用于占位" + ) async def _init(self, call_vars: CallVars) -> DataBase: diff --git a/apps/scheduler/call/facts/facts.py b/apps/scheduler/call/facts/facts.py index f8aebcd74..10241d85a 100644 --- a/apps/scheduler/call/facts/facts.py +++ b/apps/scheduler/call/facts/facts.py @@ -16,7 +16,7 @@ from apps.scheduler.call.facts.schema import ( FactsInput, FactsOutput, ) -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.pool import NodePool from apps.schemas.scheduler import CallInfo, CallOutputChunk, CallVars from apps.services.user_domain import UserDomainManager @@ -34,7 +34,11 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" - return CallInfo(name="提取事实", description="从对话上下文和文档片段中提取事实。") + return CallInfo( + name="提取事实", + type=CallType.DEFAULT, + description="从对话上下文和文档片段中提取事实。" + ) @classmethod diff --git a/apps/scheduler/call/graph/graph.py b/apps/scheduler/call/graph/graph.py index c2728f179..7383b2f2c 100644 --- a/apps/scheduler/call/graph/graph.py +++ b/apps/scheduler/call/graph/graph.py @@ -11,7 +11,7 @@ from pydantic import Field from apps.scheduler.call.core import CoreCall from apps.scheduler.call.graph.schema import RenderFormat, RenderInput, RenderOutput from apps.scheduler.call.graph.style import RenderStyle -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.scheduler import ( CallError, CallInfo, @@ -29,7 +29,11 @@ class Graph(CoreCall, input_model=RenderInput, output_model=RenderOutput): @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" - return CallInfo(name="图表", description="将SQL查询出的数据转换为图表") + return CallInfo( + name="图表", + type=CallType.TRANSFORM, + description="将SQL查询出的数据转换为图表" + ) async def _init(self, call_vars: CallVars) -> RenderInput: diff --git 
a/apps/scheduler/call/llm/llm.py b/apps/scheduler/call/llm/llm.py index 6a679dce9..7f35310e5 100644 --- a/apps/scheduler/call/llm/llm.py +++ b/apps/scheduler/call/llm/llm.py @@ -15,7 +15,7 @@ from apps.llm.reasoning import ReasoningLLM from apps.scheduler.call.core import CoreCall from apps.scheduler.call.llm.prompt import LLM_CONTEXT_PROMPT, LLM_DEFAULT_PROMPT from apps.scheduler.call.llm.schema import LLMInput, LLMOutput -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.scheduler import ( CallError, CallInfo, @@ -42,7 +42,11 @@ class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput): @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" - return CallInfo(name="大模型", description="以指定的提示词和上下文信息调用大模型,并获得输出。") + return CallInfo( + name="大模型", + type=CallType.DEFAULT, + description="以指定的提示词和上下文信息调用大模型,并获得输出。" + ) async def _prepare_message(self, call_vars: CallVars) -> list[dict[str, Any]]: diff --git a/apps/scheduler/call/mcp/mcp.py b/apps/scheduler/call/mcp/mcp.py index 4e6a1bb73..9c78a1836 100644 --- a/apps/scheduler/call/mcp/mcp.py +++ b/apps/scheduler/call/mcp/mcp.py @@ -16,7 +16,7 @@ from apps.scheduler.call.mcp.schema import ( MCPOutput, ) from apps.scheduler.mcp import MCPHost, MCPPlanner, MCPSelector -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.mcp import MCPPlanItem from apps.schemas.scheduler import ( CallInfo, @@ -43,7 +43,11 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): :return: Call的名称和描述 :rtype: CallInfo """ - return CallInfo(name="MCP", description="调用MCP Server,执行工具") + return CallInfo( + name="MCP", + type=CallType.DEFAULT, + description="调用MCP Server,执行工具" + ) async def _init(self, call_vars: CallVars) -> MCPInput: """初始化MCP""" diff --git a/apps/scheduler/call/rag/rag.py b/apps/scheduler/call/rag/rag.py index e27327d8a..7c1c48940 100644 --- a/apps/scheduler/call/rag/rag.py +++ b/apps/scheduler/call/rag/rag.py @@ -13,7 +13,7 @@ from apps.common.config import Config from apps.llm.patterns.rewrite import QuestionRewrite from apps.scheduler.call.core import CoreCall from apps.scheduler.call.rag.schema import RAGInput, RAGOutput, SearchMethod -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.scheduler import ( CallError, CallInfo, @@ -40,7 +40,11 @@ class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput): @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" - return CallInfo(name="知识库", description="查询知识库,从文档中获取必要信息") + return CallInfo( + name="知识库", + type=CallType.DEFAULT, + description="查询知识库,从文档中获取必要信息" + ) async def _init(self, call_vars: CallVars) -> RAGInput: """初始化RAG工具""" diff --git a/apps/scheduler/call/reply/direct_reply.py b/apps/scheduler/call/reply/direct_reply.py index d1f67f285..b5de998a1 100644 --- a/apps/scheduler/call/reply/direct_reply.py +++ b/apps/scheduler/call/reply/direct_reply.py @@ -9,8 +9,7 @@ from pydantic import Field from apps.scheduler.call.core import CoreCall from apps.scheduler.call.reply.schema import DirectReplyInput, DirectReplyOutput -from apps.scheduler.variable.integration import VariableIntegration -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.scheduler import ( CallError, CallInfo, @@ -29,7 +28,11 @@ class DirectReply(CoreCall, input_model=DirectReplyInput, 
output_model=DirectRep @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" - return CallInfo(name="直接回复", description="直接回复用户输入的内容,支持变量插入") + return CallInfo( + name="直接回复", + type=CallType.DEFAULT, + description="直接回复用户输入的内容,支持变量插入" + ) async def _init(self, call_vars: CallVars) -> DirectReplyInput: """初始化DirectReply工具""" @@ -41,19 +44,8 @@ class DirectReply(CoreCall, input_model=DirectReplyInput, output_model=DirectRep data = DirectReplyInput(**input_data) try: - # 解析answer中的变量引用语法 - parsed_answer = await VariableIntegration.parse_call_input( - data.answer, - user_sub=self._sys_vars.ids.user_sub, - flow_id=self._sys_vars.ids.flow_id, - conversation_id=self._sys_vars.ids.session_id - ) - - # 如果解析结果是字符串,直接使用;否则转换为字符串 - if isinstance(parsed_answer, str): - final_answer = parsed_answer - else: - final_answer = str(parsed_answer) + # 使用基类的变量解析功能处理文本中的变量引用 + final_answer = await self._resolve_variables_in_text(data.answer, self._sys_vars) logger.info(f"[DirectReply] 原始答案: {data.answer}") logger.info(f"[DirectReply] 解析后答案: {final_answer}") diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py index 4f8e1010c..69c2bfbcd 100644 --- a/apps/scheduler/call/slot/slot.py +++ b/apps/scheduler/call/slot/slot.py @@ -15,7 +15,7 @@ from apps.scheduler.call.core import CoreCall from apps.scheduler.call.slot.prompt import SLOT_GEN_PROMPT from apps.scheduler.call.slot.schema import SlotInput, SlotOutput from apps.scheduler.slot.slot import Slot as SlotProcessor -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.pool import NodePool from apps.schemas.scheduler import CallInfo, CallOutputChunk, CallVars @@ -36,7 +36,11 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput): @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" - return CallInfo(name="参数自动填充", description="根据步骤历史,自动填充参数") + return CallInfo( + name="参数自动填充", + type=CallType.TRANSFORM, + description="根据步骤历史,自动填充参数" + ) async def _llm_slot_fill(self, remaining_schema: dict[str, Any]) -> tuple[str, dict[str, Any]]: diff --git a/apps/scheduler/call/sql/sql.py b/apps/scheduler/call/sql/sql.py index 3e24301de..2f4ef97fb 100644 --- a/apps/scheduler/call/sql/sql.py +++ b/apps/scheduler/call/sql/sql.py @@ -12,7 +12,7 @@ from pydantic import Field from apps.common.config import Config from apps.scheduler.call.core import CoreCall from apps.scheduler.call.sql.schema import SQLInput, SQLOutput -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.scheduler import ( CallError, CallInfo, @@ -35,7 +35,11 @@ class SQL(CoreCall, input_model=SQLInput, output_model=SQLOutput): @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" - return CallInfo(name="SQL查询", description="使用大模型生成SQL语句,用于查询数据库中的结构化数据") + return CallInfo( + name="SQL查询", + type=CallType.TOOL, + description="使用大模型生成SQL语句,用于查询数据库中的结构化数据" + ) async def _init(self, call_vars: CallVars) -> SQLInput: diff --git a/apps/scheduler/call/suggest/suggest.py b/apps/scheduler/call/suggest/suggest.py index 1788fa0f4..663e1d9c9 100644 --- a/apps/scheduler/call/suggest/suggest.py +++ b/apps/scheduler/call/suggest/suggest.py @@ -20,7 +20,7 @@ from apps.scheduler.call.suggest.schema import ( SuggestionInput, SuggestionOutput, ) -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.pool import NodePool from apps.schemas.record import 
RecordContent from apps.schemas.scheduler import ( @@ -50,7 +50,11 @@ class Suggestion(CoreCall, input_model=SuggestionInput, output_model=SuggestionO @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" - return CallInfo(name="问题推荐", description="在答案下方显示推荐的下一个问题") + return CallInfo( + name="问题推荐", + type=CallType.DEFAULT, + description="在答案下方显示推荐的下一个问题" + ) @classmethod diff --git a/apps/scheduler/call/summary/summary.py b/apps/scheduler/call/summary/summary.py index b605204e1..7f6ff062d 100644 --- a/apps/scheduler/call/summary/summary.py +++ b/apps/scheduler/call/summary/summary.py @@ -9,7 +9,7 @@ from pydantic import Field from apps.llm.patterns.executor import ExecutorSummary from apps.scheduler.call.core import CoreCall, DataBase from apps.scheduler.call.summary.schema import SummaryOutput -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.pool import NodePool from apps.schemas.scheduler import ( CallInfo, @@ -31,7 +31,11 @@ class Summary(CoreCall, input_model=DataBase, output_model=SummaryOutput): @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" - return CallInfo(name="理解上下文", description="使用大模型,理解对话上下文") + return CallInfo( + name="理解上下文", + type=CallType.DEFAULT, + description="使用大模型,理解对话上下文" + ) @classmethod async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self: diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index c9f215fe2..106203d1c 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -18,7 +18,9 @@ from apps.scheduler.call.slot.schema import SlotOutput from apps.scheduler.call.slot.slot import Slot from apps.scheduler.call.summary.summary import Summary from apps.scheduler.executor.base import BaseExecutor +from apps.scheduler.executor.step_config import should_use_direct_conversation_format from apps.scheduler.pool.pool import Pool +from apps.scheduler.variable.integration import VariableIntegration from apps.schemas.enum_var import ( EventType, SpecialCallType, @@ -203,6 +205,108 @@ class StepExecutor(BaseExecutor): return content + + async def _save_output_parameters_to_variables(self, output_data: str | dict[str, Any]) -> None: + """保存节点输出参数到变量池""" + try: + # 检查是否有output_parameters配置 + output_parameters = None + if self.step.step.params and isinstance(self.step.step.params, dict): + output_parameters = self.step.step.params.get("output_parameters", {}) + + if not output_parameters or not isinstance(output_parameters, dict): + return + + # 确保output_data是字典格式 + if isinstance(output_data, str): + # 如果是字符串,包装成字典 + data_dict = {"text": output_data} + else: + data_dict = output_data if isinstance(output_data, dict) else {} + + # 确定变量名前缀(根据配置决定是否使用直接格式) + use_direct_format = should_use_direct_conversation_format( + call_id=self._call_id, + step_name=self.step.step.name, + step_id=self.step.step_id + ) + + if use_direct_format: + # 配置允许的节点类型保持原有格式:conversation.key + var_prefix = "" + logger.debug(f"[StepExecutor] 节点 {self.step.step.name}({self._call_id}) 使用直接变量格式") + else: + # 其他节点使用格式:conversation.node_id.key + var_prefix = f"{self.step.step_id}." 
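+                # Illustrative note (hypothetical values): with a step_id of "node_123", an output
+                # parameter named "result" is saved as conversation.node_123.result in this branch,
+                # whereas the direct-format branch above saves it as conversation.result.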
+ logger.debug(f"[StepExecutor] 节点 {self.step.step.name}({self._call_id}) 使用带前缀变量格式") + + # 保存每个output_parameter到变量池 + saved_count = 0 + for param_name, param_config in output_parameters.items(): + try: + # 获取参数值 + param_value = self._extract_value_from_output_data(param_name, data_dict, param_config) + + if param_value is not None: + # 构造变量名 + var_name = f"{var_prefix}{param_name}" + + # 保存到对话变量池 + success = await VariableIntegration.save_conversation_variable( + var_name=var_name, + value=param_value, + var_type=param_config.get("type", "string"), + description=param_config.get("description", ""), + user_sub=self.task.ids.user_sub, + flow_id=self.task.state.flow_id, # type: ignore[arg-type] + conversation_id=self.task.ids.conversation_id + ) + + if success: + saved_count += 1 + logger.debug(f"[StepExecutor] 已保存输出参数变量: conversation.{var_name} = {param_value}") + else: + logger.warning(f"[StepExecutor] 保存输出参数变量失败: {var_name}") + + except Exception as e: + logger.warning(f"[StepExecutor] 保存输出参数 {param_name} 失败: {e}") + + if saved_count > 0: + logger.info(f"[StepExecutor] 已保存 {saved_count} 个输出参数到变量池") + + except Exception as e: + logger.error(f"[StepExecutor] 保存输出参数到变量池失败: {e}") + + def _extract_value_from_output_data(self, param_name: str, output_data: dict[str, Any], param_config: dict) -> Any: + """从输出数据中提取参数值""" + # 支持多种提取方式 + + # 1. 直接从输出数据中获取同名key + if param_name in output_data: + return output_data[param_name] + + # 2. 支持路径提取(例如:result.data.value) + if "path" in param_config: + path = param_config["path"] + current_data = output_data + for key in path.split("."): + if isinstance(current_data, dict) and key in current_data: + current_data = current_data[key] + else: + return None + return current_data + + # 3. 支持默认值 + if "default" in param_config: + return param_config["default"] + + # 4. 如果参数配置为"full_output",返回完整输出 + if param_config.get("source") == "full_output": + return output_data + + return None + + async def run(self) -> None: """运行单个步骤""" self.validate_flow_state(self.task) @@ -250,6 +354,9 @@ class StepExecutor(BaseExecutor): else: output_data = content + # 保存output_parameters到变量池 + await self._save_output_parameters_to_variables(output_data) + # 更新context history = FlowStepHistory( task_id=self.task.id, diff --git a/apps/scheduler/executor/step_config.py b/apps/scheduler/executor/step_config.py new file mode 100644 index 000000000..1abcf4496 --- /dev/null +++ b/apps/scheduler/executor/step_config.py @@ -0,0 +1,79 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
+"""步骤执行器配置""" + +from typing import Set + +# 可以直接写入 conversation.key 格式的节点类型配置 +# 这些节点类型的输出变量不会添加 step_id 前缀 +# +# 说明: +# - 添加新的节点类型时,直接在这个集合中添加对应的 call_id 或节点名称即可 +# - 支持大小写敏感匹配,建议同时添加大小写版本以确保兼容性 +# - 这些节点的输出变量将保存为 conversation.key 格式 +# - 其他节点的输出变量将保存为 conversation.step_id.key 格式 +DIRECT_CONVERSATION_VARIABLE_NODE_TYPES: Set[str] = { + # 开始节点相关 + "Start", + "start", + + # 输入节点相关 + "Input", + "UserInput", + "input", + + # 未来可能的节点类型示例(取消注释即可启用) + # "GlobalConfig", # 全局配置节点 + # "SessionInit", # 会话初始化节点 + # "SystemConfig", # 系统配置节点 +} + +# 可以通过节点名称模式匹配的规则 +# 如果节点名称或step_id(转换为小写后)以这些字符串开头,则使用直接格式 +# +# 说明: +# - 这些模式用于匹配节点名称或step_id的前缀(不区分大小写) +# - 比如:"start"会匹配 "StartProcess"、"start_workflow" 等 +# - 适用于无法提前知道具体节点名称,但可以通过命名规范识别的场景 +DIRECT_CONVERSATION_VARIABLE_NAME_PATTERNS: Set[str] = { + "start", # 匹配所有以start开头的节点 + "init", # 匹配所有以init开头的节点 + "input", # 匹配所有以input开头的节点 + + # 可以根据需要添加更多模式 + # "config", # 匹配配置相关节点 + # "setup", # 匹配设置相关节点 +} + +def should_use_direct_conversation_format(call_id: str, step_name: str, step_id: str) -> bool: + """ + 判断是否应该使用直接的 conversation.key 格式 + + Args: + call_id: 节点的call_id + step_name: 节点名称 + step_id: 节点ID + + Returns: + bool: True表示使用 conversation.key,False表示使用 conversation.step_id.key + """ + # 1. 检查call_id是否在直接写入列表中 + if call_id in DIRECT_CONVERSATION_VARIABLE_NODE_TYPES: + return True + + # 2. 检查节点名称是否在直接写入列表中 + if step_name in DIRECT_CONVERSATION_VARIABLE_NODE_TYPES: + return True + + # 3. 检查节点名称是否匹配模式 + step_name_lower = step_name.lower() + for pattern in DIRECT_CONVERSATION_VARIABLE_NAME_PATTERNS: + if step_name_lower.startswith(pattern): + return True + + # 4. 检查step_id是否匹配模式 + step_id_lower = step_id.lower() + for pattern in DIRECT_CONVERSATION_VARIABLE_NAME_PATTERNS: + if step_id_lower.startswith(pattern): + return True + + return False \ No newline at end of file diff --git a/apps/scheduler/pool/loader/call.py b/apps/scheduler/pool/loader/call.py index 283448786..157342875 100644 --- a/apps/scheduler/pool/loader/call.py +++ b/apps/scheduler/pool/loader/call.py @@ -42,13 +42,12 @@ class CallLoader(metaclass=SingletonMeta): call_metadata.append( CallPool( _id=call_id, - type=CallType.SYSTEM, + type=call_info.type, name=call_info.name, description=call_info.description, path=f"python::apps.scheduler.call::{call_id}", ), ) - return call_metadata async def _load_single_call_dir(self, call_dir_name: str) -> list[CallPool]: @@ -189,6 +188,7 @@ class CallLoader(metaclass=SingletonMeta): NodePool( _id=call.id, name=call.name, + type=call.type, description=call.description, service_id="", call_id=call.id, diff --git a/apps/scheduler/pool/loader/flow.py b/apps/scheduler/pool/loader/flow.py index 57344d40a..2b4fcdc7a 100644 --- a/apps/scheduler/pool/loader/flow.py +++ b/apps/scheduler/pool/loader/flow.py @@ -27,6 +27,10 @@ BASE_PATH = Path(Config().get_config().deploy.data_dir) / "semantics" / "app" class FlowLoader: """工作流加载器""" + # 添加并发控制 + _loading_flows = {} # 改为字典,存储加载任务 + _loading_lock = asyncio.Lock() + async def _load_yaml_file(self, flow_path: Path) -> dict[str, Any]: """从YAML文件加载工作流配置""" try: @@ -100,7 +104,61 @@ class FlowLoader: async def load(self, app_id: str, flow_id: str) -> Flow | None: """从文件系统中加载【单个】工作流""" - logger.info("[FlowLoader] 应用 %s:加载工作流 %s...", flow_id, app_id) + flow_key = f"{app_id}:{flow_id}" + + # 第一次检查:是否已在加载中 + existing_task = None + async with self._loading_lock: + if flow_key in self._loading_flows: + existing_task = self._loading_flows[flow_key] + + # 如果找到现有任务,等待其完成 + if existing_task is not None: + 
logger.info(f"[FlowLoader] 工作流正在加载中,等待完成: {flow_key}") + try: + return await existing_task + except Exception as e: + logger.error(f"[FlowLoader] 等待工作流加载失败: {flow_key}, 错误: {e}") + # 如果等待失败,清理失败的任务并重试 + async with self._loading_lock: + if self._loading_flows.get(flow_key) == existing_task: + self._loading_flows.pop(flow_key, None) + return None + + # 创建新的加载任务 + task = None + async with self._loading_lock: + # 再次检查,防止竞态条件 + if flow_key in self._loading_flows: + existing_task = self._loading_flows[flow_key] + # 如果有新任务出现,等待它完成 + if existing_task is not None: + try: + return await existing_task + except Exception as e: + logger.error(f"[FlowLoader] 等待工作流加载失败: {flow_key}, 错误: {e}") + return None + + # 创建新的加载任务 + task = asyncio.create_task(self._do_load(app_id, flow_id)) + self._loading_flows[flow_key] = task + + # 执行加载任务 + try: + result = await task + return result + except Exception as e: + logger.error(f"[FlowLoader] 工作流加载失败: {flow_key}, 错误: {e}") + return None + finally: + # 确保从加载集合中移除 + async with self._loading_lock: + if self._loading_flows.get(flow_key) == task: + self._loading_flows.pop(flow_key, None) + + async def _do_load(self, app_id: str, flow_id: str) -> Flow | None: + """实际执行加载工作流的方法""" + logger.info("[FlowLoader] 应用 %s:加载工作流 %s...", app_id, flow_id) # 构建工作流文件路径 flow_path = BASE_PATH / app_id / "flow" / f"{flow_id}.yaml" @@ -235,18 +293,29 @@ class FlowLoader: except Exception: logger.exception("[FlowLoader] 更新 MongoDB 失败") - # 删除重复的ID - while True: + # 删除重复的ID,增加重试次数限制 + max_retries = 10 + retry_count = 0 + while retry_count < max_retries: try: table = await LanceDB().get_table("flow") await table.delete(f"id = '{metadata.id}'") break except RuntimeError as e: if "Commit conflict" in str(e): - logger.error("[FlowLoader] LanceDB删除flow冲突,重试中...") # noqa: TRY400 - await asyncio.sleep(0.01) + retry_count += 1 + logger.error(f"[FlowLoader] LanceDB删除flow冲突,重试中... ({retry_count}/{max_retries})") # noqa: TRY400 + # 指数退避,减少冲突概率 + await asyncio.sleep(0.01 * (2 ** min(retry_count, 5))) else: raise + except Exception as e: + logger.error(f"[FlowLoader] LanceDB删除操作异常: {e}") + break + + if retry_count >= max_retries: + logger.warning(f"[FlowLoader] LanceDB删除flow达到最大重试次数,跳过删除: {metadata.id}") + # 不抛出异常,继续执行后续操作 # 进行向量化 service_embedding = await Embedding.get_embedding([metadata.description]) vector_data = [ @@ -256,7 +325,10 @@ class FlowLoader: embedding=service_embedding[0], ), ] - while True: + # 插入向量数据,增加重试次数限制 + max_retries_insert = 10 + retry_count_insert = 0 + while retry_count_insert < max_retries_insert: try: table = await LanceDB().get_table("flow") await table.merge_insert("id").when_matched_update_all().when_not_matched_insert_all().execute( @@ -265,7 +337,16 @@ class FlowLoader: break except RuntimeError as e: if "Commit conflict" in str(e): - logger.error("[FlowLoader] LanceDB插入flow冲突,重试中...") # noqa: TRY400 - await asyncio.sleep(0.01) + retry_count_insert += 1 + logger.error(f"[FlowLoader] LanceDB插入flow冲突,重试中... 
({retry_count_insert}/{max_retries_insert})") # noqa: TRY400 + # 指数退避,减少冲突概率 + await asyncio.sleep(0.01 * (2 ** min(retry_count_insert, 5))) else: raise + except Exception as e: + logger.error(f"[FlowLoader] LanceDB插入操作异常: {e}") + break + + if retry_count_insert >= max_retries_insert: + logger.error(f"[FlowLoader] LanceDB插入flow达到最大重试次数,操作失败: {metadata.id}") + raise RuntimeError(f"LanceDB插入flow失败,达到最大重试次数: {metadata.id}") diff --git a/apps/scheduler/variable/README.md b/apps/scheduler/variable/README.md index 2b6313960..73718ce01 100644 --- a/apps/scheduler/variable/README.md +++ b/apps/scheduler/variable/README.md @@ -1,152 +1,264 @@ -# 工作流变量管理系统 +# 变量池架构文档 -## 概述 +## 架构设计 -工作流变量管理系统为Euler Copilot Framework提供了全面的变量管理功能,支持在工作流执行过程中进行变量的定义、存储、解析和使用。 +基于用户需求,变量系统采用"模板-实例"的两级架构: -## 功能特性 +### 设计理念 -### 1. 多种变量类型支持 -- **基础类型**: String、Number、Boolean、Object、File -- **安全类型**: Secret(加密存储的密钥变量) -- **数组类型**: Array[Any]、Array[String]、Array[Number]、Array[Object]、Array[File]、Array[Boolean]、Array[Secret] +- **Flow级别(父pool)**:管理变量模板定义,用户可以查看和配置变量结构 +- **Conversation级别(子pool)**:管理变量实例,存储实际的运行时数据 -### 2. 四种作用域 -- **系统级变量** (`system`): 只读,包含query、files、app_id等系统信息 -- **用户级变量** (`user`): 跟随用户,如个人API密钥、配置信息 -- **环境级变量** (`env`): 跟随工作流,如流程配置参数 -- **对话级变量** (`conversation`): 单次对话内有效,支持局部作用域 +### 变量分类 -### 3. 变量解析语法 -支持在模板中使用以下语法引用变量: +不同类型的变量有不同的存储和管理方式: +- **系统变量**:模板在Flow级别定义,实例在Conversation级别运行时更新 +- **对话变量**:模板在Flow级别定义,实例在Conversation级别用户可设置 +- **环境变量**:直接在Flow级别存储和使用 +- **用户变量**:在User级别长期存储 + +## 架构实现 + +### 变量池类型 + +#### 1. UserVariablePool(用户变量池) +- **关联ID**: `user_id` +- **权限**: 用户可读写 +- **生命周期**: 随用户创建而创建,长期存在 +- **典型变量**: API密钥、用户偏好、个人配置等 + +#### 2. FlowVariablePool(流程变量池) +- **关联ID**: `flow_id` +- **权限**: 流程可读写 +- **生命周期**: 随 flow 创建而创建 +- **继承**: 支持从父流程继承 +- **存储内容**: + - 环境变量(直接使用) + - 系统变量模板(供对话继承) + - 对话变量模板(供对话继承) + +#### 3. ConversationVariablePool(对话变量池) +- **关联ID**: `conversation_id` +- **权限**: + - 系统变量实例:只读,由系统自动更新 + - 对话变量实例:可读写,用户可设置值 +- **生命周期**: 随对话创建而创建,对话结束后可选择性清理 +- **初始化方式**: 从FlowVariablePool的模板自动继承 +- **包含内容**: + - **系统变量实例**:`query`, `files`, `dialogue_count`等运行时值 + - **对话变量实例**:用户定义的对话上下文数据 + +## 核心设计原则 + +### 1. 统一的对话上下文 +所有对话相关的变量(无论是系统变量还是对话变量)都在同一个对话变量池中管理,确保上下文的一致性。 + +### 2. 权限区分 +通过 `is_system` 标记区分系统变量和对话变量: +- `is_system=True`: 系统变量,只读,由系统自动更新 +- `is_system=False`: 对话变量,可读写,支持人为修改 + +### 3. 自动初始化和持久化 +创建对话变量池时,自动初始化所有必需的系统变量,设置合理的默认值,并立即持久化到数据库,确保系统变量在任何时候都可用。 + +## 使用方式 + +### 1. 创建对话变量池 + +```python +pool_manager = await get_pool_manager() + +# 创建对话变量池(自动包含系统变量) +conv_pool = await pool_manager.create_conversation_pool("conv123", "flow456") ``` -{{sys.query}} # 系统变量 -{{user.api_key}} # 用户变量 -{{env.config.timeout}} # 环境变量(支持嵌套访问) -{{conversation.result}} # 对话变量 + +### 2. 更新系统变量 + +```python +# 系统变量由解析器自动更新 +parser = VariableParser( + user_id="user123", + flow_id="flow456", + conversation_id="conv123" +) + +# 更新系统变量 +await parser.update_system_variables({ + "question": "你好,请帮我分析数据", + "files": [{"name": "data.csv", "size": 1024}], + "dialogue_count": 1, + "user_sub": "user123" +}) ``` -### 4. 
安全保障 -- Secret变量加密存储 -- 访问权限控制 -- 审计日志记录 -- 访问频率限制 -- 密钥强度检查 - -## API 接口 - -### 变量管理 -- `POST /api/variable/create` - 创建变量 -- `PUT /api/variable/update` - 更新变量值 -- `DELETE /api/variable/delete` - 删除变量 -- `GET /api/variable/get` - 获取单个变量 -- `GET /api/variable/list` - 列出变量 - -### 模板解析 -- `POST /api/variable/parse` - 解析模板中的变量引用 -- `POST /api/variable/validate` - 验证模板中的变量引用 - -### 系统信息 -- `GET /api/variable/types` - 获取支持的变量类型 -- `POST /api/variable/clear-conversation` - 清空对话变量 - -## 使用示例 - -### 1. 创建用户级变量 -```json -{ - "name": "openai_api_key", - "var_type": "secret", - "scope": "user", - "value": "sk-xxxxxxxxxxxxxxxx", - "description": "OpenAI API密钥" -} +### 3. 更新对话变量 + +```python +# 添加对话变量 +await conv_pool.add_variable( + name="context_history", + var_type=VariableType.ARRAY_STRING, + value=["用户问候", "系统回应"], + description="对话历史" +) + +# 更新对话变量 +await conv_pool.update_variable("context_history", value=["问候", "回应", "新消息"]) ``` -### 2. 在工作流中使用变量 -```yaml -steps: - llm_call: - node: "llm" - params: - system_prompt: "你是一个助手,用户查询:{{sys.query}}" - api_key: "{{user.openai_api_key}}" - temperature: "{{env.llm_config.temperature}}" +### 4. 变量解析 + +```python +# 系统变量和对话变量使用相同的引用语法 +template = """ +系统变量 - 用户查询: {{sys.query}} +系统变量 - 对话轮数: {{sys.dialogue_count}} +对话变量 - 历史: {{conversation.context_history}} +用户变量 - 偏好: {{user.preferences}} +环境变量 - 数据库: {{env.database_url}} +""" + +parsed = await parser.parse_template(template) ``` -### 3. 在对话中设置临时变量 -工作流执行过程中可以动态设置对话级变量: +## 变量引用语法 + +变量引用保持不变: +- `{{sys.variable_name}}` - 系统变量(对话级别,只读) +- `{{conversation.variable_name}}` - 对话变量(对话级别,可读写) +- `{{user.variable_name}}` - 用户变量 +- `{{env.variable_name}}` - 环境变量 + +## 权限控制详细说明 + +### 系统变量权限 ```python -await VariableIntegration.add_conversation_variable( - name="processing_result", - value={"status": "completed", "data": result}, - conversation_id=conversation_id, - var_type_str="object" -) +# 普通更新会被拒绝 +await conv_pool.update_variable("query", value="new query") # ❌ 抛出 PermissionError + +# 系统内部更新 +await conv_pool.update_system_variable("query", "new query") # ✅ 成功 +# 或者 +await conv_pool.update_variable("query", value="new query", force_system_update=True) # ✅ 成功 +``` + +### 对话变量权限 +```python +# 普通对话变量可以自由更新 +await conv_pool.update_variable("context_history", value=new_history) # ✅ 成功 ``` -## 集成到工作流 +## 数据存储 + +### 元数据增强 +```python +class VariableMetadata(BaseModel): + # ... 其他字段 + is_system: bool = Field(default=False, description="是否为系统变量(只读)") +``` + +### 数据库查询 +系统变量和对话变量存储在同一个集合中,通过 `metadata.is_system` 字段区分。 + +## 迁移影响 + +### 对用户的影响 +- ✅ **变量引用语法完全不变** +- ✅ **API接口完全兼容** +- ✅ **现有功能正常工作** + +### 内部实现变化 +- 去掉了独立的 `SystemVariablePool` +- 系统变量现在在 `ConversationVariablePool` 中管理 +- 通过权限控制区分系统变量和对话变量 -系统自动集成到现有的工作流调度器中: +## 架构优势 -1. **系统变量自动初始化** - 每次工作流启动时自动设置系统变量 -2. **输入参数解析** - Call的输入参数自动解析变量引用 -3. **模板渲染** - LLM提示词等模板自动替换变量 -4. **输出变量提取** - 步骤输出可自动提取为对话级变量 +### 1. 逻辑一致性 +系统变量和对话变量都属于对话上下文,在同一个池中管理更合理。 -## 安全机制 +### 2. 简化管理 +不需要在系统池和对话池之间同步数据,避免了数据一致性问题。 -### Secret变量保护 -- 使用AES加密存储 -- 基于用户ID和变量名生成唯一加密密钥 -- 显示时自动打码 -- 支持密钥轮换 +### 3. 更好的性能 +减少了池之间的数据传递和同步开销。 -### 访问控制 -- 权限验证(用户只能访问自己的变量) -- 访问频率限制(防止暴力破解) -- IP地址验证(可选) -- 失败尝试监控和临时封禁 +### 4. 扩展性 +为未来可能的对话级系统变量扩展提供了更好的基础。 -### 审计日志 -- 完整的访问日志记录 -- 密钥访问审计(记录哈希值而非原始值) -- 自动清理过期日志 +## 总结 -## 开发指南 +修正后的架构更准确地反映了变量的实际使用场景: +- **用户变量**: 用户级别,长期存在 +- **环境变量**: 流程级别,配置相关 +- **系统变量 + 对话变量**: 对话级别,上下文相关 -### 添加新的变量类型 -1. 在`VariableType`枚举中添加新类型 -2. 创建继承自`BaseVariable`的新变量类 -3. 在`VARIABLE_CLASS_MAP`中注册映射关系 +这样的设计更符合实际业务逻辑,也更容易理解和维护。 -### 扩展变量解析 -1. 
修改`VariableParser.VARIABLE_PATTERN`正则表达式 -2. 在`resolve_variable_reference`方法中添加新的解析逻辑 +## 系统变量详细说明 -### 自定义安全策略 -1. 继承`SecretVariableSecurity`类 -2. 重写相关的安全检查方法 -3. 在应用启动时注册自定义安全管理器 +### 预定义系统变量 -## 故障排除 +每个对话变量池创建时,会自动初始化以下系统变量: -### 常见问题 -1. **变量解析失败** - 检查变量名是否存在,作用域是否正确 -2. **Secret变量解密失败** - 可能是加密密钥损坏,需要重新设置 -3. **访问被拒绝** - 检查用户权限和访问频率限制 +| 变量名 | 类型 | 描述 | 初始值 | +|-------|------|------|--------| +| `query` | STRING | 用户查询内容 | "" | +| `files` | ARRAY_FILE | 用户上传的文件列表 | [] | +| `dialogue_count` | NUMBER | 对话轮数 | 0 | +| `app_id` | STRING | 应用ID | "" | +| `flow_id` | STRING | 工作流ID | {flow_id} | +| `user_id` | STRING | 用户ID | "" | +| `session_id` | STRING | 会话ID | "" | +| `conversation_id` | STRING | 对话ID | {conversation_id} | +| `timestamp` | NUMBER | 当前时间戳 | {当前时间} | + +### 系统变量生命周期 + +1. **创建阶段**:对话变量池创建时,所有系统变量被初始化并持久化到数据库 +2. **更新阶段**:通过`VariableParser.update_system_variables()`方法更新系统变量值 +3. **访问阶段**:通过模板解析或直接访问获取系统变量值 +4. **清理阶段**:对话结束时,整个对话变量池被清理 + +### 系统变量更新机制 -### 日志调试 -启用DEBUG日志级别查看详细的变量解析过程: ```python -import logging -logging.getLogger('apps.scheduler.variable').setLevel(logging.DEBUG) +# 创建解析器并确保对话池存在 +parser = VariableParser(user_id=user_id, flow_id=flow_id, conversation_id=conversation_id) +await parser.create_conversation_pool_if_needed() + +# 更新系统变量 +context = { + "question": "用户的问题", + "files": [{"name": "file.txt", "size": 1024}], + "dialogue_count": 1, + "app_id": "app123", + "user_sub": user_id, + "session_id": "session456" +} + +await parser.update_system_variables(context) ``` -## 最佳实践 +### 系统变量的只读保护 + +```python +# ❌ 直接修改系统变量会失败 +await conversation_pool.update_variable("query", value="修改内容") # 抛出PermissionError + +# ✅ 只能通过系统内部接口更新 +await conversation_pool.update_system_variable("query", "新内容") # 成功 +``` + +### 使用系统变量 + +```python +# 在模板中引用系统变量 +template = """ +用户问题:{{sys.query}} +对话轮数:{{sys.dialogue_count}} +工作流ID:{{sys.flow_id}} +""" -1. **变量命名** - 使用描述性的名称,避免特殊字符 -2. **作用域选择** - 根据变量的生命周期选择合适的作用域 -3. **Secret管理** - 定期轮换密钥,使用强密码 -4. **性能优化** - 避免在循环中频繁解析变量 -5. 
**安全审计** - 定期检查访问日志,监控异常行为 \ No newline at end of file +parsed = await parser.parse_template(template) +``` \ No newline at end of file diff --git a/apps/scheduler/variable/__init__.py b/apps/scheduler/variable/__init__.py index 9ee14662d..2c487379c 100644 --- a/apps/scheduler/variable/__init__.py +++ b/apps/scheduler/variable/__init__.py @@ -22,7 +22,7 @@ from .variables import ( create_variable, VARIABLE_CLASS_MAP, ) -from .pool import VariablePool, get_variable_pool +from .pool_manager import VariablePoolManager, get_pool_manager from .parser import VariableParser, VariableReferenceBuilder, VariableContext from .integration import VariableIntegration @@ -46,9 +46,9 @@ __all__ = [ "create_variable", "VARIABLE_CLASS_MAP", - # 变量池 - "VariablePool", - "get_variable_pool", + # 变量池管理器 + "VariablePoolManager", + "get_pool_manager", # 解析器 "VariableParser", diff --git a/apps/scheduler/variable/base.py b/apps/scheduler/variable/base.py index 11e3c246e..94beaaebc 100644 --- a/apps/scheduler/variable/base.py +++ b/apps/scheduler/variable/base.py @@ -19,6 +19,13 @@ class VariableMetadata(BaseModel): # 作用域相关属性 user_sub: Optional[str] = Field(default=None, description="用户级变量的用户ID") flow_id: Optional[str] = Field(default=None, description="环境级/对话级变量的流程ID") + conversation_id: Optional[str] = Field(default=None, description="对话级变量的对话ID") + + # 系统变量标识 + is_system: bool = Field(default=False, description="是否为系统变量(只读)") + + # 模板变量标识 + is_template: bool = Field(default=False, description="是否为模板变量(存储在flow级别)") # 安全相关属性 is_encrypted: bool = Field(default=False, description="是否加密存储") diff --git a/apps/scheduler/variable/integration.py b/apps/scheduler/variable/integration.py index b1c54c9c5..a33385a17 100644 --- a/apps/scheduler/variable/integration.py +++ b/apps/scheduler/variable/integration.py @@ -4,7 +4,7 @@ import logging from typing import Any, Dict, List, Optional, Union from apps.scheduler.variable.parser import VariableParser -from apps.scheduler.variable.pool import get_variable_pool +from apps.scheduler.variable.pool_manager import get_pool_manager from apps.scheduler.variable.type import VariableScope logger = logging.getLogger(__name__) @@ -67,6 +67,112 @@ class VariableIntegration: logger.warning(f"解析Call输入变量失败: {e}") # 如果解析失败,返回原始输入 return input_data + + @staticmethod + async def resolve_variable_reference( + reference: str, + user_sub: str, + flow_id: Optional[str] = None, + conversation_id: Optional[str] = None + ) -> Any: + """解析单个变量引用 + + Args: + reference: 变量引用字符串(如 "{{user.name}}" 或 "user.name") + user_sub: 用户ID + flow_id: 流程ID + conversation_id: 对话ID + + Returns: + Any: 解析后的变量值 + """ + try: + parser = VariableParser( + user_id=user_sub, + flow_id=flow_id, + conversation_id=conversation_id + ) + + # 清理引用字符串(移除花括号) + clean_reference = reference.strip("{}") + + # 使用解析器解析变量引用 + resolved_value = await parser._resolve_variable_reference(clean_reference) + + return resolved_value + + except Exception as e: + logger.error(f"解析变量引用失败: {reference}, 错误: {e}") + raise + + @staticmethod + async def save_conversation_variable( + var_name: str, + value: Any, + var_type: str = "string", + description: str = "", + user_sub: str = "", + flow_id: Optional[str] = None, + conversation_id: Optional[str] = None + ) -> bool: + """保存对话变量 + + Args: + var_name: 变量名(不包含scope前缀) + value: 变量值 + var_type: 变量类型 + description: 变量描述 + user_sub: 用户ID + flow_id: 流程ID + conversation_id: 对话ID + + Returns: + bool: 是否保存成功 + """ + try: + if not conversation_id: + logger.warning("无法保存对话变量:缺少conversation_id") + return False + + # 
直接使用pool_manager,避免解析器的复杂逻辑 + pool_manager = await get_pool_manager() + conversation_pool = await pool_manager.get_conversation_pool(conversation_id) + + if not conversation_pool: + logger.warning(f"无法获取对话变量池: {conversation_id}") + return False + + # 转换变量类型 + from apps.scheduler.variable.type import VariableType + try: + var_type_enum = VariableType(var_type) + except ValueError: + var_type_enum = VariableType.STRING + logger.warning(f"未知的变量类型 {var_type},使用默认类型 string") + + # 尝试更新变量,如果不存在则创建 + try: + await conversation_pool.update_variable(var_name, value=value) + logger.debug(f"对话变量已更新: {var_name} = {value}") + return True + except ValueError as e: + if "不存在" in str(e): + # 变量不存在,创建新变量 + await conversation_pool.add_variable( + name=var_name, + var_type=var_type_enum, + value=value, + description=description, + created_by=user_sub or "system" + ) + logger.debug(f"对话变量已创建: {var_name} = {value}") + return True + else: + raise # 其他错误重新抛出 + + except Exception as e: + logger.error(f"保存对话变量失败: {var_name} - {e}") + return False @staticmethod async def parse_template_string(template: str, @@ -119,14 +225,18 @@ class VariableIntegration: # 转换变量类型 var_type = VariableType(var_type_str) - pool = await get_variable_pool() - # 在内部将conversation_id作为flow_id传递 - await pool.add_variable( + pool_manager = await get_pool_manager() + # 获取对话变量池(如果不存在会抛出异常) + conversation_pool = await pool_manager.get_conversation_pool(conversation_id) + if not conversation_pool: + logger.error(f"对话变量池不存在: {conversation_id}") + return False + + await conversation_pool.add_variable( name=name, var_type=var_type, - scope=VariableScope.CONVERSATION, value=value, - flow_id=conversation_id # 统一使用flow_id + description=f"对话变量: {name}" ) logger.debug(f"已添加对话变量: {name} = {value}") @@ -150,13 +260,16 @@ class VariableIntegration: bool: 是否更新成功 """ try: - pool = await get_variable_pool() - # 在内部将conversation_id作为flow_id传递 - await pool.update_variable( + pool_manager = await get_pool_manager() + # 获取对话变量池 + conversation_pool = await pool_manager.get_conversation_pool(conversation_id) + if not conversation_pool: + logger.error(f"对话变量池不存在: {conversation_id}") + return False + + await conversation_pool.update_variable( name=name, - scope=VariableScope.CONVERSATION, - value=value, - flow_id=conversation_id # 统一使用flow_id + value=value ) logger.debug(f"已更新对话变量: {name} = {value}") @@ -209,9 +322,13 @@ class VariableIntegration: conversation_id: 对话ID """ try: - pool = await get_variable_pool() - await pool.clear_conversation_variables(conversation_id) - logger.info(f"已清理对话 {conversation_id} 的变量") + pool_manager = await get_pool_manager() + # 移除对话变量池 + success = await pool_manager.remove_conversation_pool(conversation_id) + if success: + logger.info(f"已清理对话 {conversation_id} 的变量") + else: + logger.warning(f"对话变量池不存在: {conversation_id}") except Exception as e: logger.error(f"清理对话变量失败: {e}") @@ -244,134 +361,6 @@ class VariableIntegration: return False, [str(e)] -# 为现有调度器提供的扩展方法 -class CoreCallExtension: - """CoreCall的变量支持扩展""" - - @staticmethod - async def enhanced_set_input(core_call_instance, executor) -> None: - """增强版的_set_input方法,支持变量解析 - - Args: - core_call_instance: CoreCall实例 - executor: StepExecutor实例 - """ - # 调用原始的_set_input方法 - original_set_input = getattr(core_call_instance.__class__, '_original_set_input', None) - if original_set_input: - await original_set_input(core_call_instance, executor) - else: - # 如果没有原始方法,执行基本的输入设置 - from apps.scheduler.call.core import CoreCall - core_call_instance._sys_vars = 
CoreCall._assemble_call_vars(executor) - input_data = await core_call_instance._init(core_call_instance._sys_vars) - core_call_instance.input = input_data.model_dump(by_alias=True, exclude_none=True) - - # 解析输入中的变量引用 - if hasattr(core_call_instance, 'input') and core_call_instance.input: - # 首先初始化系统变量 - context = { - "question": executor.question, - "user_sub": executor.task.ids.user_sub, - "flow_id": executor.task.state.flow_id if executor.task.state else None, - "session_id": executor.task.ids.session_id, - "app_id": executor.task.state.app_id if executor.task.state else None, - "conversation_id": executor.task.ids.session_id, # 使用session_id作为conversation_id - } - - await VariableIntegration.initialize_system_variables(context) - - # 解析输入变量 - core_call_instance.input = await VariableIntegration.parse_call_input( - core_call_instance.input, - user_sub=executor.task.ids.user_sub, - flow_id=executor.task.state.flow_id if executor.task.state else None, - conversation_id=executor.task.ids.session_id - ) - - -class LLMCallExtension: - """LLM Call的变量支持扩展""" - - @staticmethod - async def enhanced_prepare_message(llm_instance, call_vars) -> list[dict[str, Any]]: - """增强版的_prepare_message方法,支持变量解析 - - Args: - llm_instance: LLM实例 - call_vars: CallVars实例 - - Returns: - list[dict[str, Any]]: 解析后的消息列表 - """ - # 调用原始的_prepare_message方法 - original_prepare_message = getattr(llm_instance.__class__, '_original_prepare_message', None) - if original_prepare_message: - messages = await original_prepare_message(llm_instance, call_vars) - else: - # 如果没有原始方法,返回基本消息 - messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": call_vars.question} - ] - - # 解析消息中的变量引用 - parsed_messages = [] - for message in messages: - if "content" in message and isinstance(message["content"], str): - parsed_content = await VariableIntegration.parse_template_string( - message["content"], - user_sub=call_vars.ids.user_sub, - flow_id=call_vars.ids.flow_id, - conversation_id=call_vars.ids.session_id - ) - parsed_messages.append({ - **message, - "content": parsed_content - }) - else: - parsed_messages.append(message) - - return parsed_messages - - -def monkey_patch_scheduler(): - """为现有调度器添加变量支持的猴子补丁""" - try: - # 为CoreCall添加变量支持 - from apps.scheduler.call.core import CoreCall - - # 保存原始方法 - original_set_input = getattr(CoreCall, '_set_input', None) - if original_set_input: - setattr(CoreCall, '_original_set_input', original_set_input) - - # 替换为增强版本 - setattr(CoreCall, '_set_input', CoreCallExtension.enhanced_set_input) - # 为LLM Call添加变量支持 - try: - from apps.scheduler.call.llm.llm import LLM - - # 保存原始方法 - original_prepare_message = getattr(LLM, '_prepare_message', None) - if original_prepare_message: - setattr(LLM, '_original_prepare_message', original_prepare_message) - - # 替换为增强版本 - setattr(LLM, '_prepare_message', LLMCallExtension.enhanced_prepare_message) - - except ImportError: - logger.warning("LLM Call 不存在,跳过LLM变量支持") - - logger.info("变量解析功能已集成到调度器中") - - except Exception as e: - logger.error(f"集成变量解析功能失败: {e}") - raise - - -# 在模块导入时自动应用补丁 -try: - monkey_patch_scheduler() -except Exception as e: - logger.error(f"自动应用变量解析补丁失败: {e}") \ No newline at end of file +# 注意:原本的 monkey_patch_scheduler 和相关扩展类已被移除 +# 因为 CoreCall 类现在已经内置了完整的变量解析功能 +# 这些代码是旧版本的遗留,会导致循环导入问题 \ No newline at end of file diff --git a/apps/scheduler/variable/parser.py b/apps/scheduler/variable/parser.py index 79cfad5fb..9eb0dcfaf 100644 --- a/apps/scheduler/variable/parser.py +++ b/apps/scheduler/variable/parser.py 
@@ -4,39 +4,42 @@ from typing import Any, Dict, List, Optional, Tuple, Union import json from datetime import datetime, UTC -from .pool import get_variable_pool +from .pool_manager import get_pool_manager from .type import VariableScope logger = logging.getLogger(__name__) class VariableParser: - """变量解析器 - 解析和替换模板中的变量引用""" + """变量解析器 - 支持新架构的变量解析器(系统变量和对话变量都在对话池中)""" # 变量引用的正则表达式:{{scope.variable_name.nested_path}} VARIABLE_PATTERN = re.compile(r'\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)\s*\}\}') def __init__(self, - user_sub: Optional[str] = None, + user_id: Optional[str] = None, flow_id: Optional[str] = None, - conversation_id: Optional[str] = None): + conversation_id: Optional[str] = None, + user_sub: Optional[str] = None): """初始化变量解析器 Args: - user_sub: 用户ID + user_id: 用户ID (向后兼容) flow_id: 流程ID conversation_id: 对话ID + user_sub: 用户订阅ID (优先使用,用于未来鉴权等需求) """ - self.user_sub = user_sub + # 优先使用 user_sub,如果没有则使用 user_id + self.user_id = user_sub if user_sub is not None else user_id self.flow_id = flow_id self.conversation_id = conversation_id - self._variable_pool = None + self._pool_manager = None - async def _get_pool(self): - """获取变量池实例""" - if self._variable_pool is None: - self._variable_pool = await get_variable_pool() - return self._variable_pool + async def _get_pool_manager(self): + """获取变量池管理器实例""" + if self._pool_manager is None: + self._pool_manager = await get_pool_manager() + return self._pool_manager async def parse_template(self, template: str) -> str: """解析模板字符串,替换其中的变量引用 @@ -50,8 +53,6 @@ class VariableParser: if not template: return template - pool = await self._get_pool() - # 查找所有变量引用 matches = self.VARIABLE_PATTERN.findall(template) @@ -60,12 +61,7 @@ class VariableParser: for match in matches: try: # 解析变量引用 - value = await pool.resolve_variable_reference( - f"{{{{{match}}}}}", - user_sub=self.user_sub, - flow_id=self.flow_id, - conversation_id=self.conversation_id - ) + value = await self._resolve_variable_reference(match) # 转换为字符串 str_value = self._convert_to_string(value) @@ -82,6 +78,90 @@ class VariableParser: return result + async def _resolve_variable_reference(self, reference: str) -> Any: + """解析变量引用 + + Args: + reference: 变量引用字符串(不含花括号) + + Returns: + Any: 变量值 + """ + pool_manager = await self._get_pool_manager() + + # 解析作用域和变量名 + parts = reference.split(".", 1) + if len(parts) != 2: + raise ValueError(f"无效的变量引用格式: {reference}") + + scope_str, var_path = parts + + # 确定作用域 + scope_map = { + "sys": VariableScope.SYSTEM, + "system": VariableScope.SYSTEM, + "user": VariableScope.USER, + "env": VariableScope.ENVIRONMENT, + "environment": VariableScope.ENVIRONMENT, + "conversation": VariableScope.CONVERSATION, + "conv": VariableScope.CONVERSATION, + } + + scope = scope_map.get(scope_str) + if not scope: + raise ValueError(f"无效的变量作用域: {scope_str}") + + # 解析变量路径 + # 对于conversation作用域,支持节点输出变量格式:conversation.node_id.key + if scope == VariableScope.CONVERSATION and "." 
in var_path: + # 检查是否为节点输出变量(格式:node_id.key) + # 先尝试获取完整路径作为变量名 + try: + variable = await pool_manager.get_variable_from_any_pool( + name=var_path, # 使用完整路径作为变量名 + scope=scope, + user_id=self.user_id if scope == VariableScope.USER else None, + flow_id=self.flow_id if scope in [VariableScope.SYSTEM, VariableScope.ENVIRONMENT, VariableScope.CONVERSATION] else None, + conversation_id=self.conversation_id if scope in [VariableScope.SYSTEM, VariableScope.CONVERSATION] else None + ) + if variable: + return variable.value + except: + pass # 如果找不到,继续使用原有逻辑 + + # 原有逻辑:支持嵌套访问如 user.config.api_key + path_parts = var_path.split(".") + var_name = path_parts[0] + + # 根据作用域获取变量 + variable = await pool_manager.get_variable_from_any_pool( + name=var_name, + scope=scope, + user_id=self.user_id, + flow_id=self.flow_id, + conversation_id=self.conversation_id + ) + + if not variable: + raise ValueError(f"变量不存在: {scope_str}.{var_name}") + + # 获取变量值 + value = variable.value + + # 如果有嵌套路径,继续解析 + for path_part in path_parts[1:]: + if isinstance(value, dict): + value = value.get(path_part) + elif isinstance(value, list) and path_part.isdigit(): + try: + value = value[int(path_part)] + except IndexError: + value = None + else: + raise ValueError(f"无法访问路径: {var_path}") + + return value + async def extract_variables(self, template: str) -> List[str]: """提取模板中的所有变量引用 @@ -109,18 +189,12 @@ class VariableParser: if not template: return True, [] - pool = await self._get_pool() matches = self.VARIABLE_PATTERN.findall(template) invalid_refs = [] for match in matches: try: - await pool.resolve_variable_reference( - f"{{{{{match}}}}}", - user_sub=self.user_sub, - flow_id=self.flow_id, - conversation_id=self.conversation_id - ) + await self._resolve_variable_reference(match) except Exception: invalid_refs.append(f"{{{{{match}}}}}") @@ -161,7 +235,14 @@ class VariableParser: Args: context: 系统上下文信息 """ - pool = await self._get_pool() + if not self.conversation_id: + logger.warning("无法更新系统变量:缺少conversation_id") + return + + # 确保对话变量池存在 + await self.create_conversation_pool_if_needed() + + pool_manager = await self._get_pool_manager() # 预定义的系统变量映射 system_var_mappings = { @@ -169,22 +250,85 @@ class VariableParser: "files": context.get("files", []), "dialogue_count": context.get("dialogue_count", 0), "app_id": context.get("app_id", ""), - "flow_id": context.get("flow_id", ""), - "user_id": context.get("user_sub", ""), + "flow_id": context.get("flow_id", self.flow_id or ""), + "user_id": context.get("user_sub", self.user_id or ""), "session_id": context.get("session_id", ""), + "conversation_id": self.conversation_id, "timestamp": datetime.now(UTC).timestamp(), } + # 获取对话变量池 + conversation_pool = await pool_manager.get_conversation_pool(self.conversation_id) + if not conversation_pool: + logger.error(f"对话变量池不存在,无法更新系统变量: {self.conversation_id}") + return + # 更新系统变量 + updated_count = 0 for var_name, var_value in system_var_mappings.items(): try: - variable = await pool.get_variable(var_name, VariableScope.SYSTEM) - if variable: - # 系统变量需要特殊处理,绕过只读限制 - variable._value = var_value + success = await conversation_pool.update_system_variable(var_name, var_value) + if success: + updated_count += 1 logger.debug(f"已更新系统变量: {var_name} = {var_value}") + else: + logger.warning(f"系统变量更新失败: {var_name}") except Exception as e: logger.warning(f"更新系统变量失败: {var_name} - {e}") + + logger.info(f"系统变量更新完成: {updated_count}/{len(system_var_mappings)} 个变量更新成功") + + async def update_conversation_variable(self, var_name: str, value: Any) -> bool: + 
"""更新对话变量的值 + + Args: + var_name: 变量名 + value: 新值 + + Returns: + bool: 是否更新成功 + """ + if not self.conversation_id: + logger.warning("无法更新对话变量:缺少conversation_id") + return False + + pool_manager = await self._get_pool_manager() + conversation_pool = await pool_manager.get_conversation_pool(self.conversation_id) + + if not conversation_pool: + logger.warning(f"无法获取对话变量池: {self.conversation_id}") + return False + + try: + await conversation_pool.update_variable(var_name, value=value) + logger.info(f"已更新对话变量: {var_name} = {value}") + return True + except Exception as e: + logger.error(f"更新对话变量失败: {var_name} - {e}") + return False + + async def create_conversation_pool_if_needed(self) -> bool: + """如果需要,创建对话变量池 + + Returns: + bool: 是否创建成功 + """ + if not self.conversation_id or not self.flow_id: + return False + + pool_manager = await self._get_pool_manager() + existing_pool = await pool_manager.get_conversation_pool(self.conversation_id) + + if existing_pool: + return True + + try: + await pool_manager.create_conversation_pool(self.conversation_id, self.flow_id) + logger.info(f"已创建对话变量池: {self.conversation_id}") + return True + except Exception as e: + logger.error(f"创建对话变量池失败: {self.conversation_id} - {e}") + return False def _convert_to_string(self, value: Any) -> str: """将值转换为字符串 diff --git a/apps/scheduler/variable/pool.py b/apps/scheduler/variable/pool.py deleted file mode 100644 index 1d7ac4a26..000000000 --- a/apps/scheduler/variable/pool.py +++ /dev/null @@ -1,802 +0,0 @@ -import logging -from typing import Any, Dict, List, Optional, Set -from collections import defaultdict -import asyncio -from contextlib import asynccontextmanager -from datetime import datetime, UTC - -from apps.common.mongo import MongoDB -from .base import BaseVariable, VariableMetadata -from .type import VariableType, VariableScope -from .variables import create_variable, VARIABLE_CLASS_MAP - -logger = logging.getLogger(__name__) - - -class VariablePool: - """变量池 - 管理所有作用域的变量""" - - def __init__(self): - """初始化变量池""" - # 内存缓存:scope -> {variable_name: variable} - self._variables: Dict[VariableScope, Dict[str, BaseVariable]] = { - VariableScope.SYSTEM: {}, - VariableScope.USER: {}, - VariableScope.ENVIRONMENT: {}, - VariableScope.CONVERSATION: {}, - } - - # 上下文相关的变量缓存 - self._user_variables: Dict[str, Dict[str, BaseVariable]] = defaultdict(dict) # user_sub -> variables - self._env_variables: Dict[str, Dict[str, BaseVariable]] = defaultdict(dict) # flow_id -> variables - self._conv_variables: Dict[str, Dict[str, BaseVariable]] = defaultdict(dict) # flow_id -> variables (对话级变量) - - # 系统级变量定义 - self._system_variables_initialized = False - self._lock = asyncio.Lock() - - async def initialize(self): - """初始化变量池,加载系统变量""" - async with self._lock: - if not self._system_variables_initialized: - await self._initialize_system_variables() - self._system_variables_initialized = True - - async def _initialize_system_variables(self): - """初始化系统级变量""" - system_vars = [ - ("query", VariableType.STRING, "用户查询内容", ""), - ("files", VariableType.ARRAY_FILE, "用户上传的文件列表", []), - ("dialogue_count", VariableType.NUMBER, "对话轮数", 0), - ("app_id", VariableType.STRING, "应用ID", ""), - ("flow_id", VariableType.STRING, "工作流ID", ""), - ("user_id", VariableType.STRING, "用户ID", ""), - ("session_id", VariableType.STRING, "会话ID", ""), - ("timestamp", VariableType.NUMBER, "当前时间戳", 0), - ] - - for var_name, var_type, description, default_value in system_vars: - metadata = VariableMetadata( - name=var_name, - var_type=var_type, - 
scope=VariableScope.SYSTEM, - description=description, - created_by="system" - ) - variable = create_variable(metadata, default_value) - self._variables[VariableScope.SYSTEM][var_name] = variable - - logger.info(f"已初始化 {len(system_vars)} 个系统级变量") - - async def add_variable(self, - name: str, - var_type: VariableType, - scope: VariableScope, - value: Any = None, - description: Optional[str] = None, - user_sub: Optional[str] = None, - flow_id: Optional[str] = None) -> BaseVariable: - """添加变量 - - Args: - name: 变量名 - var_type: 变量类型 - scope: 作用域 - value: 初始值 - description: 描述 - user_sub: 用户ID(用户级变量必需) - flow_id: 流程ID(环境级变量必需) - - Returns: - BaseVariable: 创建的变量 - """ - await self.initialize() - - # 验证作用域相关参数 - if scope == VariableScope.SYSTEM: - raise ValueError("不能直接添加系统级变量") - elif scope == VariableScope.USER and not user_sub: - raise ValueError("用户级变量必须指定 user_sub") - elif scope == VariableScope.ENVIRONMENT and not flow_id: - raise ValueError("环境级变量必须指定 flow_id") - elif scope == VariableScope.CONVERSATION and not flow_id: - raise ValueError("对话级变量必须指定 flow_id") - - # 检查变量是否已存在 - existing_var = await self.get_variable(name, scope, user_sub, flow_id) - if existing_var: - raise ValueError(f"变量 {name} 在作用域 {scope.value} 中已存在") - - # 创建变量元数据 - metadata = VariableMetadata( - name=name, - var_type=var_type, - scope=scope, - description=description, - user_sub=user_sub, - flow_id=flow_id, - created_by=user_sub or "system" - ) - - # 创建变量 - variable = create_variable(metadata, value) - - # 存储到对应的缓存中 - await self._store_variable(variable) - - # 持久化存储 - 为了更好的用户体验,对话级变量也进行持久化 - await self._persist_variable(variable) - - logger.info(f"已添加变量: {name} ({var_type.value}) 到作用域 {scope.value}") - return variable - - async def update_variable(self, - name: str, - scope: VariableScope, - value: Optional[Any] = None, - var_type: Optional[VariableType] = None, - description: Optional[str] = None, - user_sub: Optional[str] = None, - flow_id: Optional[str] = None) -> BaseVariable: - """更新变量值、类型或描述 - - Args: - name: 变量名 - scope: 作用域 - value: 新值(可选) - var_type: 新变量类型(可选) - description: 新描述(可选) - user_sub: 用户ID - flow_id: 流程ID - - Returns: - BaseVariable: 更新后的变量 - """ - variable = await self.get_variable(name, scope, user_sub, flow_id) - if not variable: - raise ValueError(f"变量 {name} 在作用域 {scope.value} 中不存在") - - # 检查权限 - if user_sub and not variable.can_access(user_sub): - raise PermissionError(f"用户 {user_sub} 没有权限修改变量 {name}") - - # 至少需要更新一个字段 - if value is None and var_type is None and description is None: - raise ValueError("至少需要指定一个要更新的字段(value、var_type 或 description)") - - # 更新字段 - if value is not None: - variable.value = value - if var_type is not None: - variable.metadata.var_type = var_type - if description is not None: - variable.metadata.description = description - - # 更新时间戳 - variable.metadata.updated_at = datetime.now(UTC) - - # 更新缓存 - await self._store_variable(variable) - - # 持久化更新 - 为了更好的用户体验,对话级变量也进行持久化 - await self._persist_variable(variable) - - logger.info(f"已更新变量: {name} 在作用域 {scope.value}") - return variable - - async def delete_variable(self, - name: str, - scope: VariableScope, - user_sub: Optional[str] = None, - flow_id: Optional[str] = None) -> bool: - """删除变量 - - Args: - name: 变量名 - scope: 作用域 - user_sub: 用户ID - flow_id: 流程ID - - Returns: - bool: 是否删除成功 - """ - if scope == VariableScope.SYSTEM: - raise ValueError("不能删除系统级变量") - - variable = await self.get_variable(name, scope, user_sub, flow_id) - if not variable: - return False - - # 检查权限 - if user_sub and not variable.can_access(user_sub): - 
raise PermissionError(f"用户 {user_sub} 没有权限删除变量 {name}") - - # 从缓存中删除 - await self._remove_variable_from_cache(variable) - # 从数据库删除 - await self._delete_variable_from_db(variable) - - logger.info(f"已删除变量: {name} 从作用域 {scope.value}") - return True - - async def get_variable(self, - name: str, - scope: VariableScope, - user_sub: Optional[str] = None, - flow_id: Optional[str] = None) -> Optional[BaseVariable]: - """获取变量 - - Args: - name: 变量名 - scope: 作用域 - user_sub: 用户ID - flow_id: 流程ID - - Returns: - Optional[BaseVariable]: 变量或None - """ - await self.initialize() - - # 根据作用域从对应缓存中查找 - if scope == VariableScope.SYSTEM: - return self._variables[VariableScope.SYSTEM].get(name) - elif scope == VariableScope.USER and user_sub: - # 先从缓存查找 - if name in self._user_variables[user_sub]: - return self._user_variables[user_sub][name] - # 从数据库加载 - return await self._load_user_variable(name, user_sub) - elif scope == VariableScope.ENVIRONMENT and flow_id: - # 先从缓存查找 - if name in self._env_variables[flow_id]: - return self._env_variables[flow_id][name] - # 从数据库加载 - return await self._load_env_variable(name, flow_id) - elif scope == VariableScope.CONVERSATION and flow_id: - # 先从缓存查找 - if name in self._conv_variables[flow_id]: - return self._conv_variables[flow_id][name] - # 从数据库加载 - return await self._load_conv_variable(name, flow_id) - - return None - - async def list_variables(self, - scope: VariableScope, - user_sub: Optional[str] = None, - flow_id: Optional[str] = None) -> List[BaseVariable]: - """列出指定作用域的所有变量 - - Args: - scope: 作用域 - user_sub: 用户ID - flow_id: 流程ID - - Returns: - List[BaseVariable]: 变量列表 - """ - await self.initialize() - - if scope == VariableScope.SYSTEM: - return list(self._variables[VariableScope.SYSTEM].values()) - elif scope == VariableScope.USER and user_sub: - # 加载用户的所有变量 - await self._load_all_user_variables(user_sub) - return list(self._user_variables[user_sub].values()) - elif scope == VariableScope.ENVIRONMENT and flow_id: - # 加载环境的所有变量 - await self._load_all_env_variables(flow_id) - return list(self._env_variables[flow_id].values()) - elif scope == VariableScope.CONVERSATION and flow_id: - # 加载对话级的所有变量 - await self._load_all_conv_variables(flow_id) - return list(self._conv_variables[flow_id].values()) - - return [] - - async def clear_conversation_variables(self, flow_id: str): - """清空工作流的对话级变量""" - if flow_id in self._conv_variables: - del self._conv_variables[flow_id] - logger.info(f"已清空工作流 {flow_id} 的对话级变量") - - async def refresh_conversation_cache(self, flow_id: str): - """刷新对话级变量缓存""" - if flow_id in self._conv_variables: - # 清空当前缓存 - del self._conv_variables[flow_id] - - # 重新加载 - await self._load_all_conv_variables(flow_id) - logger.info(f"已刷新工作流 {flow_id} 的对话级变量缓存") - - async def resolve_variable_reference(self, - reference: str, - user_sub: Optional[str] = None, - flow_id: Optional[str] = None) -> Any: - """解析变量引用,如 {{sys.query}} 或 {{user.token}} - - Args: - reference: 变量引用字符串 - user_sub: 用户ID - flow_id: 流程ID - - Returns: - Any: 变量值 - """ - # 移除 {{ 和 }} - clean_ref = reference.strip("{}").strip() - - # 解析作用域和变量名 - parts = clean_ref.split(".", 1) - if len(parts) != 2: - raise ValueError(f"无效的变量引用格式: {reference}") - - scope_str, var_path = parts - - # 确定作用域 - scope_map = { - "sys": VariableScope.SYSTEM, - "system": VariableScope.SYSTEM, - "user": VariableScope.USER, - "env": VariableScope.ENVIRONMENT, - "environment": VariableScope.ENVIRONMENT, - "conversation": VariableScope.CONVERSATION, - "conv": VariableScope.CONVERSATION, - } - - scope = scope_map.get(scope_str) - 
if not scope: - raise ValueError(f"无效的变量作用域: {scope_str}") - - # 解析变量路径(支持嵌套访问如 user.config.api_key) - path_parts = var_path.split(".") - var_name = path_parts[0] - - # 获取变量 - variable = await self.get_variable(var_name, scope, user_sub, flow_id) - if not variable: - raise ValueError(f"变量不存在: {clean_ref}") - - # 获取变量值 - value = variable.value - - # 如果有嵌套路径,继续解析 - for path_part in path_parts[1:]: - if isinstance(value, dict): - value = value.get(path_part) - elif isinstance(value, list) and path_part.isdigit(): - try: - value = value[int(path_part)] - except IndexError: - value = None - else: - raise ValueError(f"无法访问路径: {var_path}") - - return value - - async def _store_variable(self, variable: BaseVariable): - """存储变量到缓存""" - scope = variable.scope - name = variable.name - - if scope == VariableScope.SYSTEM: - self._variables[scope][name] = variable - elif scope == VariableScope.USER: - user_sub = variable.metadata.user_sub - if user_sub: - self._user_variables[user_sub][name] = variable - elif scope == VariableScope.ENVIRONMENT: - flow_id = variable.metadata.flow_id - if flow_id: - self._env_variables[flow_id][name] = variable - elif scope == VariableScope.CONVERSATION: - flow_id = variable.metadata.flow_id - if flow_id: - self._conv_variables[flow_id][name] = variable - - async def _remove_variable_from_cache(self, variable: BaseVariable): - """从缓存中移除变量""" - scope = variable.scope - name = variable.name - - if scope == VariableScope.USER: - user_sub = variable.metadata.user_sub - if user_sub and name in self._user_variables[user_sub]: - del self._user_variables[user_sub][name] - elif scope == VariableScope.ENVIRONMENT: - flow_id = variable.metadata.flow_id - if flow_id and name in self._env_variables[flow_id]: - del self._env_variables[flow_id][name] - elif scope == VariableScope.CONVERSATION: - flow_id = variable.metadata.flow_id - if flow_id and name in self._conv_variables[flow_id]: - del self._conv_variables[flow_id][name] - - async def _persist_variable(self, variable: BaseVariable): - """持久化变量到数据库""" - try: - collection = MongoDB().get_collection("variables") - data = variable.serialize() - - # 构建查询条件 - query = { - "metadata.name": variable.name, - "metadata.scope": variable.scope.value - } - - if variable.scope == VariableScope.USER and variable.metadata.user_sub: - query["metadata.user_sub"] = variable.metadata.user_sub - elif variable.scope == VariableScope.ENVIRONMENT and variable.metadata.flow_id: - query["metadata.flow_id"] = variable.metadata.flow_id - elif variable.scope == VariableScope.CONVERSATION and variable.metadata.flow_id: - query["metadata.flow_id"] = variable.metadata.flow_id - - # 更新或插入 - 强制等待写入确认 - from pymongo import WriteConcern - result = await collection.with_options( - write_concern=WriteConcern(w="majority", j=True) - ).replace_one(query, data, upsert=True) - - # 确保写入成功 - if not (result.acknowledged and (result.matched_count > 0 or result.upserted_id)): - raise RuntimeError(f"变量持久化失败: {variable.name}") - - except Exception as e: - logger.error(f"持久化变量失败: {e}") - raise - - async def _delete_variable_from_db(self, variable: BaseVariable): - """从数据库删除变量""" - try: - collection = MongoDB().get_collection("variables") - - query = { - "metadata.name": variable.name, - "metadata.scope": variable.scope.value - } - - if variable.scope == VariableScope.USER and variable.metadata.user_sub: - query["metadata.user_sub"] = variable.metadata.user_sub - elif variable.scope == VariableScope.ENVIRONMENT and variable.metadata.flow_id: - query["metadata.flow_id"] = 
variable.metadata.flow_id - elif variable.scope == VariableScope.CONVERSATION and variable.metadata.flow_id: - query["metadata.flow_id"] = variable.metadata.flow_id - - # 强制等待写入确认的删除操作 - from pymongo import WriteConcern - result = await collection.with_options( - write_concern=WriteConcern(w="majority", j=True) - ).delete_one(query) - - # 确保删除成功 - if not result.acknowledged: - raise RuntimeError(f"变量删除失败: {variable.name}") - - except Exception as e: - logger.error(f"删除变量失败: {e}") - raise - - async def _load_user_variable(self, name: str, user_sub: str) -> Optional[BaseVariable]: - """从数据库加载用户变量""" - try: - collection = MongoDB().get_collection("variables") - doc = await collection.find_one({ - "metadata.name": name, - "metadata.scope": VariableScope.USER.value, - "metadata.user_sub": user_sub - }) - - if doc: - variable_class_name = doc.get("class") - if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]: - # 找到对应的变量类 - for var_class in VARIABLE_CLASS_MAP.values(): - if var_class.__name__ == variable_class_name: - try: - variable = var_class.deserialize(doc) - self._user_variables[user_sub][name] = variable - return variable - except Exception as e: - logger.warning(f"用户变量 {name} 数据损坏,将从数据库删除: {e}") - await self._delete_corrupted_variable(doc) - return None - - return None - - except Exception as e: - logger.error(f"加载用户变量失败: {e}") - return None - - async def _load_env_variable(self, name: str, flow_id: str) -> Optional[BaseVariable]: - """从数据库加载环境变量""" - try: - collection = MongoDB().get_collection("variables") - doc = await collection.find_one({ - "metadata.name": name, - "metadata.scope": VariableScope.ENVIRONMENT.value, - "metadata.flow_id": flow_id - }) - - if doc: - variable_class_name = doc.get("class") - if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]: - # 找到对应的变量类 - for var_class in VARIABLE_CLASS_MAP.values(): - if var_class.__name__ == variable_class_name: - try: - variable = var_class.deserialize(doc) - self._env_variables[flow_id][name] = variable - return variable - except Exception as e: - logger.warning(f"环境变量 {name} 数据损坏,将从数据库删除: {e}") - await self._delete_corrupted_variable(doc) - return None - - return None - - except Exception as e: - logger.error(f"加载环境变量失败: {e}") - return None - - async def _load_conv_variable(self, name: str, flow_id: str) -> Optional[BaseVariable]: - """从数据库加载对话级变量""" - try: - collection = MongoDB().get_collection("variables") - doc = await collection.find_one({ - "metadata.name": name, - "metadata.scope": VariableScope.CONVERSATION.value, - "metadata.flow_id": flow_id - }) - - if doc: - variable_class_name = doc.get("class") - if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]: - # 找到对应的变量类 - for var_class in VARIABLE_CLASS_MAP.values(): - if var_class.__name__ == variable_class_name: - try: - variable = var_class.deserialize(doc) - self._conv_variables[flow_id][name] = variable - return variable - except Exception as e: - logger.warning(f"对话级变量 {name} 数据损坏,将从数据库删除: {e}") - await self._delete_corrupted_variable(doc) - return None - - return None - - except Exception as e: - logger.error(f"加载对话级变量失败: {e}") - return None - - async def _load_all_user_variables(self, user_sub: str): - """加载用户的所有变量""" - try: - collection = MongoDB().get_collection("variables") - cursor = collection.find({ - "metadata.scope": VariableScope.USER.value, - "metadata.user_sub": user_sub - }) - - corrupted_count = 0 - loaded_count = 0 - - async for doc in cursor: - try: - 
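                    # Added note (not in the original patch): each stored document carries a
                    # "class" field naming its concrete variable class. The lines below match
                    # that name against VARIABLE_CLASS_MAP, call the class's deserialize() to
                    # rebuild the variable, and cache it; documents that fail to deserialize
                    # are treated as corrupted and passed to _delete_corrupted_variable().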
variable_class_name = doc.get("class") - if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]: - # 找到对应的变量类 - for var_class in VARIABLE_CLASS_MAP.values(): - if var_class.__name__ == variable_class_name: - variable = var_class.deserialize(doc) - self._user_variables[user_sub][variable.name] = variable - loaded_count += 1 - break - except Exception as e: - var_name = doc.get("metadata", {}).get("name", "unknown") - logger.warning(f"用户变量 {var_name} 数据损坏,将从数据库删除: {e}") - await self._delete_corrupted_variable(doc) - corrupted_count += 1 - - if corrupted_count > 0: - logger.info(f"用户 {user_sub} 加载变量完成: 成功 {loaded_count} 个,清理损坏数据 {corrupted_count} 个") - else: - logger.debug(f"用户 {user_sub} 加载变量完成: {loaded_count} 个") - - except Exception as e: - logger.error(f"加载用户所有变量失败: {e}") - - async def _load_all_env_variables(self, flow_id: str): - """加载环境的所有变量""" - try: - collection = MongoDB().get_collection("variables") - cursor = collection.find({ - "metadata.scope": VariableScope.ENVIRONMENT.value, - "metadata.flow_id": flow_id - }) - - corrupted_count = 0 - loaded_count = 0 - - async for doc in cursor: - try: - variable_class_name = doc.get("class") - if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]: - # 找到对应的变量类 - for var_class in VARIABLE_CLASS_MAP.values(): - if var_class.__name__ == variable_class_name: - variable = var_class.deserialize(doc) - self._env_variables[flow_id][variable.name] = variable - loaded_count += 1 - break - except Exception as e: - var_name = doc.get("metadata", {}).get("name", "unknown") - logger.warning(f"环境变量 {var_name} 数据损坏,将从数据库删除: {e}") - await self._delete_corrupted_variable(doc) - corrupted_count += 1 - - if corrupted_count > 0: - logger.info(f"环境 {flow_id} 加载变量完成: 成功 {loaded_count} 个,清理损坏数据 {corrupted_count} 个") - else: - logger.debug(f"环境 {flow_id} 加载变量完成: {loaded_count} 个") - - except Exception as e: - logger.error(f"加载环境所有变量失败: {e}") - - async def _load_all_conv_variables(self, flow_id: str): - """加载对话级的所有变量""" - try: - collection = MongoDB().get_collection("variables") - cursor = collection.find({ - "metadata.scope": VariableScope.CONVERSATION.value, - "metadata.flow_id": flow_id - }) - - corrupted_count = 0 - loaded_count = 0 - - async for doc in cursor: - try: - variable_class_name = doc.get("class") - if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]: - # 找到对应的变量类 - for var_class in VARIABLE_CLASS_MAP.values(): - if var_class.__name__ == variable_class_name: - variable = var_class.deserialize(doc) - self._conv_variables[flow_id][variable.name] = variable - loaded_count += 1 - break - except Exception as e: - var_name = doc.get("metadata", {}).get("name", "unknown") - logger.warning(f"对话级变量 {var_name} 数据损坏,将从数据库删除: {e}") - await self._delete_corrupted_variable(doc) - corrupted_count += 1 - - if corrupted_count > 0: - logger.info(f"对话 {flow_id} 加载变量完成: 成功 {loaded_count} 个,清理损坏数据 {corrupted_count} 个") - else: - logger.debug(f"对话 {flow_id} 加载变量完成: {loaded_count} 个") - - except Exception as e: - logger.error(f"加载对话级所有变量失败: {e}") - - async def _delete_corrupted_variable(self, doc: Dict[str, Any]): - """删除损坏的变量数据 - - Args: - doc: 损坏的变量文档 - """ - try: - collection = MongoDB().get_collection("variables") - - # 记录损坏的数据用于分析 - corrupted_collection = MongoDB().get_collection("corrupted_variables") - backup_doc = doc.copy() - backup_doc["corrupted_at"] = datetime.now(UTC) - backup_doc["reason"] = "deserialization_failed" - - # 备份损坏数据 - try: - await 
corrupted_collection.insert_one(backup_doc) - except Exception as backup_e: - logger.warning(f"备份损坏变量数据失败: {backup_e}") - - # 删除原始损坏数据 - if "_id" in doc: - result = await collection.delete_one({"_id": doc["_id"]}) - if result.deleted_count > 0: - var_name = doc.get("metadata", {}).get("name", "unknown") - logger.debug(f"已删除损坏的变量数据: {var_name}") - else: - logger.warning(f"未能删除损坏的变量数据,可能已被删除") - - except Exception as e: - logger.error(f"删除损坏变量数据失败: {e}") - # 即使删除失败也不抛出异常,避免影响其他变量的加载 - - async def cleanup_corrupted_variables(self) -> Dict[str, int]: - """手动清理所有损坏的变量数据 - - Returns: - Dict[str, int]: 清理统计信息 {"checked": 检查数量, "cleaned": 清理数量} - """ - logger.info("开始检查和清理损坏的变量数据...") - - collection = MongoDB().get_collection("variables") - checked_count = 0 - cleaned_count = 0 - - try: - # 遍历所有变量进行检查 - cursor = collection.find({}) - - async for doc in cursor: - checked_count += 1 - try: - variable_class_name = doc.get("class") - if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]: - # 找到对应的变量类并尝试反序列化 - for var_class in VARIABLE_CLASS_MAP.values(): - if var_class.__name__ == variable_class_name: - var_class.deserialize(doc) # 只是测试反序列化,不保存 - break - else: - # 未知的变量类型也视为损坏 - raise ValueError(f"未知的变量类型: {variable_class_name}") - - except Exception as e: - var_name = doc.get("metadata", {}).get("name", "unknown") - logger.warning(f"发现损坏的变量 {var_name},将删除: {e}") - await self._delete_corrupted_variable(doc) - cleaned_count += 1 - - result = {"checked": checked_count, "cleaned": cleaned_count} - logger.info(f"变量数据清理完成: 检查了 {checked_count} 个变量,清理了 {cleaned_count} 个损坏的变量") - return result - - except Exception as e: - logger.error(f"清理损坏变量失败: {e}") - return {"checked": checked_count, "cleaned": cleaned_count} - - async def get_corrupted_variables_info(self) -> List[Dict[str, Any]]: - """获取已备份的损坏变量信息 - - Returns: - List[Dict[str, Any]]: 损坏变量信息列表 - """ - try: - corrupted_collection = MongoDB().get_collection("corrupted_variables") - corrupted_vars = [] - - cursor = corrupted_collection.find({}).sort("corrupted_at", -1).limit(100) - async for doc in cursor: - var_info = { - "name": doc.get("metadata", {}).get("name", "unknown"), - "scope": doc.get("metadata", {}).get("scope", "unknown"), - "var_type": doc.get("metadata", {}).get("var_type", "unknown"), - "corrupted_at": doc.get("corrupted_at"), - "reason": doc.get("reason", "unknown"), - "user_sub": doc.get("metadata", {}).get("user_sub"), - "flow_id": doc.get("metadata", {}).get("flow_id"), - } - corrupted_vars.append(var_info) - - return corrupted_vars - - except Exception as e: - logger.error(f"获取损坏变量信息失败: {e}") - return [] - - -# 全局变量池实例 -_variable_pool = None - - -async def get_variable_pool() -> VariablePool: - """获取全局变量池实例""" - global _variable_pool - if _variable_pool is None: - _variable_pool = VariablePool() - await _variable_pool.initialize() - return _variable_pool \ No newline at end of file diff --git a/apps/scheduler/variable/pool_base.py b/apps/scheduler/variable/pool_base.py new file mode 100644 index 000000000..83acdffdc --- /dev/null +++ b/apps/scheduler/variable/pool_base.py @@ -0,0 +1,828 @@ +import logging +import asyncio +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Set +from datetime import datetime, UTC + +from apps.common.mongo import MongoDB +from .base import BaseVariable, VariableMetadata +from .type import VariableType, VariableScope +from .variables import create_variable, VARIABLE_CLASS_MAP + +logger = logging.getLogger(__name__) + + +class BaseVariablePool(ABC): 
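    # Usage sketch (added for illustration, hedged; the concrete subclasses appear
    # later in this file):
    #
    #     pool = UserVariablePool("user123")          # "user123" is a placeholder ID
    #     await pool.initialize()
    #     await pool.add_variable("api_key", VariableType.STRING, value="...",
    #                             created_by="user123")
    #     value = await pool.resolve_variable_reference("{{api_key}}")
    #
    # Subclasses only implement _load_variables(), _setup_default_variables(),
    # can_modify() and _add_pool_query_conditions(); CRUD, persistence and reference
    # resolution are shared in this base class.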
+ """变量池基类""" + + def __init__(self, pool_id: str, scope: VariableScope): + """初始化变量池 + + Args: + pool_id: 池标识符(如user_id、flow_id、conversation_id等) + scope: 池作用域 + """ + self.pool_id = pool_id + self.scope = scope + self._variables: Dict[str, BaseVariable] = {} + self._initialized = False + self._lock = asyncio.Lock() + + @property + def is_initialized(self) -> bool: + """检查是否已初始化""" + return self._initialized + + async def initialize(self): + """初始化变量池""" + async with self._lock: + if not self._initialized: + await self._load_variables() + await self._setup_default_variables() + self._initialized = True + logger.info(f"已初始化变量池: {self.__class__.__name__}({self.pool_id})") + + @abstractmethod + async def _load_variables(self): + """从存储加载变量""" + pass + + @abstractmethod + async def _setup_default_variables(self): + """设置默认变量""" + pass + + @abstractmethod + def can_modify(self) -> bool: + """检查是否允许修改变量""" + pass + + async def add_variable(self, + name: str, + var_type: VariableType, + value: Any = None, + description: Optional[str] = None, + created_by: Optional[str] = None, + is_system: bool = False) -> BaseVariable: + """添加变量""" + if not self.can_modify(): + raise PermissionError(f"不允许修改{self.scope.value}级变量") + + if name in self._variables: + raise ValueError(f"变量 {name} 已存在") + + # 创建变量元数据 + metadata = VariableMetadata( + name=name, + var_type=var_type, + scope=self.scope, + description=description, + user_sub=getattr(self, 'user_id', None), + flow_id=getattr(self, 'flow_id', None), + conversation_id=getattr(self, 'conversation_id', None), + created_by=created_by or "system", + is_system=is_system # 标记是否为系统变量 + ) + + # 创建变量 + variable = create_variable(metadata, value) + self._variables[name] = variable + + # 持久化 + await self._persist_variable(variable) + + logger.info(f"已添加{'系统' if is_system else ''}变量: {name} 到池 {self.pool_id}") + return variable + + async def update_variable(self, + name: str, + value: Optional[Any] = None, + var_type: Optional[VariableType] = None, + description: Optional[str] = None, + force_system_update: bool = False) -> BaseVariable: + """更新变量""" + if name not in self._variables: + raise ValueError(f"变量 {name} 不存在") + + variable = self._variables[name] + + # 检查系统变量的修改权限 + if hasattr(variable.metadata, 'is_system') and variable.metadata.is_system and not force_system_update: + raise PermissionError(f"系统变量 {name} 不允许修改") + + if not self.can_modify() and not force_system_update: + raise PermissionError(f"不允许修改{self.scope.value}级变量") + + # 更新字段 + if value is not None: + variable.value = value + if var_type is not None: + variable.metadata.var_type = var_type + if description is not None: + variable.metadata.description = description + + variable.metadata.updated_at = datetime.now(UTC) + + # 持久化 + await self._persist_variable(variable) + + logger.info(f"已更新变量: {name} 在池 {self.pool_id}, 值为{value}") + return variable + + async def delete_variable(self, name: str) -> bool: + """删除变量""" + if not self.can_modify(): + raise PermissionError(f"不允许修改{self.scope.value}级变量") + + if name not in self._variables: + return False + + variable = self._variables[name] + + # 检查是否为系统变量 + if hasattr(variable.metadata, 'is_system') and variable.metadata.is_system: + raise PermissionError(f"系统变量 {name} 不允许删除") + + del self._variables[name] + + # 从数据库删除 + await self._delete_variable_from_db(variable) + + logger.info(f"已删除变量: {name} 从池 {self.pool_id}") + return True + + async def get_variable(self, name: str) -> Optional[BaseVariable]: + """获取变量""" + return self._variables.get(name) + + async 
def list_variables(self, include_system: bool = True) -> List[BaseVariable]: + """列出所有变量""" + if include_system: + return list(self._variables.values()) + else: + # 只返回非系统变量 + return [var for var in self._variables.values() + if not (hasattr(var.metadata, 'is_system') and var.metadata.is_system)] + + async def list_system_variables(self) -> List[BaseVariable]: + """列出系统变量""" + return [var for var in self._variables.values() + if hasattr(var.metadata, 'is_system') and var.metadata.is_system] + + async def has_variable(self, name: str) -> bool: + """检查变量是否存在""" + return name in self._variables + + async def copy_variables(self) -> Dict[str, BaseVariable]: + """拷贝所有变量""" + copied = {} + for name, variable in self._variables.items(): + # 创建新的元数据 + new_metadata = VariableMetadata( + name=variable.metadata.name, + var_type=variable.metadata.var_type, + scope=variable.metadata.scope, + description=variable.metadata.description, + user_sub=variable.metadata.user_sub, + flow_id=variable.metadata.flow_id, + conversation_id=variable.metadata.conversation_id, + created_by=variable.metadata.created_by, + is_system=getattr(variable.metadata, 'is_system', False) + ) + # 创建新的变量实例 + copied[name] = create_variable(new_metadata, variable.value) + return copied + + async def _persist_variable(self, variable: BaseVariable): + """持久化变量""" + try: + collection = MongoDB().get_collection("variables") + data = variable.serialize() + + # 构建查询条件 + query = { + "metadata.name": variable.name, + "metadata.scope": variable.scope.value + } + + # 添加池特定的查询条件 + self._add_pool_query_conditions(query, variable) + + # 更新或插入 + from pymongo import WriteConcern + result = await collection.with_options( + write_concern=WriteConcern(w="majority", j=True) + ).replace_one(query, data, upsert=True) + + if not (result.acknowledged and (result.matched_count > 0 or result.upserted_id)): + raise RuntimeError(f"变量持久化失败: {variable.name}") + + except Exception as e: + logger.error(f"持久化变量失败: {e}") + raise + + async def _delete_variable_from_db(self, variable: BaseVariable): + """从数据库删除变量""" + try: + collection = MongoDB().get_collection("variables") + + query = { + "metadata.name": variable.name, + "metadata.scope": variable.scope.value + } + + # 添加池特定的查询条件 + self._add_pool_query_conditions(query, variable) + + from pymongo import WriteConcern + result = await collection.with_options( + write_concern=WriteConcern(w="majority", j=True) + ).delete_one(query) + + if not result.acknowledged: + raise RuntimeError(f"变量删除失败: {variable.name}") + + except Exception as e: + logger.error(f"删除变量失败: {e}") + raise + + @abstractmethod + def _add_pool_query_conditions(self, query: Dict[str, Any], variable: BaseVariable): + """添加池特定的查询条件""" + pass + + async def resolve_variable_reference(self, reference: str) -> Any: + """解析变量引用""" + # 移除 {{ 和 }} + clean_ref = reference.strip("{}").strip() + + # 解析变量路径 + path_parts = clean_ref.split(".") + var_name = path_parts[0] + + # 获取变量 + variable = await self.get_variable(var_name) + if not variable: + raise ValueError(f"变量不存在: {var_name}") + + # 获取变量值 + value = variable.value + + # 处理嵌套路径 + for path_part in path_parts[1:]: + if isinstance(value, dict): + value = value.get(path_part) + elif isinstance(value, list) and path_part.isdigit(): + try: + value = value[int(path_part)] + except IndexError: + value = None + else: + raise ValueError(f"无法访问路径: {clean_ref}") + + return value + + +class UserVariablePool(BaseVariablePool): + """用户变量池""" + + def __init__(self, user_id: str): + super().__init__(user_id, 
VariableScope.USER) + self.user_id = user_id + + async def _load_variables(self): + """从数据库加载用户变量""" + try: + collection = MongoDB().get_collection("variables") + cursor = collection.find({ + "metadata.scope": VariableScope.USER.value, + "metadata.user_sub": self.user_id + }) + + loaded_count = 0 + async for doc in cursor: + try: + variable_class_name = doc.get("class") + if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]: + for var_class in VARIABLE_CLASS_MAP.values(): + if var_class.__name__ == variable_class_name: + variable = var_class.deserialize(doc) + self._variables[variable.name] = variable + loaded_count += 1 + break + except Exception as e: + var_name = doc.get("metadata", {}).get("name", "unknown") + logger.warning(f"用户变量 {var_name} 数据损坏: {e}") + + logger.debug(f"用户 {self.user_id} 加载变量完成: {loaded_count} 个") + + except Exception as e: + logger.error(f"加载用户变量失败: {e}") + + async def _setup_default_variables(self): + """用户变量池不需要默认变量""" + pass + + def can_modify(self) -> bool: + """用户变量允许修改""" + return True + + def _add_pool_query_conditions(self, query: Dict[str, Any], variable: BaseVariable): + """添加用户变量池的查询条件""" + query["metadata.user_sub"] = self.user_id + + +class FlowVariablePool(BaseVariablePool): + """流程变量池(环境变量 + 系统变量模板 + 对话变量模板)""" + + def __init__(self, flow_id: str, parent_flow_id: Optional[str] = None): + super().__init__(flow_id, VariableScope.ENVIRONMENT) # 保持主要scope为ENVIRONMENT + self.flow_id = flow_id + self.parent_flow_id = parent_flow_id + + # 分别存储不同类型的变量 + # _variables 继续存储环境变量(保持向后兼容) + self._system_templates: Dict[str, BaseVariable] = {} # 系统变量模板 + self._conversation_templates: Dict[str, BaseVariable] = {} # 对话变量模板 + + async def _load_variables(self): + """从数据库加载所有类型的变量(环境变量 + 模板变量)""" + try: + collection = MongoDB().get_collection("variables") + loaded_counts = {"environment": 0, "system_templates": 0, "conversation_templates": 0} + + # 1. 加载环境变量 + env_cursor = collection.find({ + "metadata.scope": VariableScope.ENVIRONMENT.value, + "metadata.flow_id": self.flow_id + }) + + async for doc in env_cursor: + try: + variable_class_name = doc.get("class") + if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]: + for var_class in VARIABLE_CLASS_MAP.values(): + if var_class.__name__ == variable_class_name: + variable = var_class.deserialize(doc) + self._variables[variable.name] = variable + loaded_counts["environment"] += 1 + break + except Exception as e: + var_name = doc.get("metadata", {}).get("name", "unknown") + logger.warning(f"环境变量 {var_name} 数据损坏: {e}") + + # 2. 加载系统变量模板 + system_template_cursor = collection.find({ + "metadata.scope": VariableScope.SYSTEM.value, + "metadata.flow_id": self.flow_id, + "metadata.is_template": True + }) + + async for doc in system_template_cursor: + try: + variable_class_name = doc.get("class") + if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]: + for var_class in VARIABLE_CLASS_MAP.values(): + if var_class.__name__ == variable_class_name: + variable = var_class.deserialize(doc) + self._system_templates[variable.name] = variable + loaded_counts["system_templates"] += 1 + break + except Exception as e: + var_name = doc.get("metadata", {}).get("name", "unknown") + logger.warning(f"系统变量模板 {var_name} 数据损坏: {e}") + + # 3. 
加载对话变量模板 + conv_template_cursor = collection.find({ + "metadata.scope": VariableScope.CONVERSATION.value, + "metadata.flow_id": self.flow_id, + "metadata.is_template": True + }) + + async for doc in conv_template_cursor: + try: + variable_class_name = doc.get("class") + if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]: + for var_class in VARIABLE_CLASS_MAP.values(): + if var_class.__name__ == variable_class_name: + variable = var_class.deserialize(doc) + self._conversation_templates[variable.name] = variable + loaded_counts["conversation_templates"] += 1 + break + except Exception as e: + var_name = doc.get("metadata", {}).get("name", "unknown") + logger.warning(f"对话变量模板 {var_name} 数据损坏: {e}") + + total_loaded = sum(loaded_counts.values()) + logger.debug(f"流程 {self.flow_id} 加载变量完成: 环境变量{loaded_counts['environment']}个, " + f"系统模板{loaded_counts['system_templates']}个, " + f"对话模板{loaded_counts['conversation_templates']}个, 总计{total_loaded}个") + + except Exception as e: + logger.error(f"加载流程变量失败: {e}") + + async def _setup_default_variables(self): + """设置默认的系统变量模板""" + from datetime import datetime, UTC + + # 定义系统变量模板(这些是模板,不是实例) + system_var_templates = [ + ("query", VariableType.STRING, "用户查询内容", ""), + ("files", VariableType.ARRAY_FILE, "用户上传的文件列表", []), + ("dialogue_count", VariableType.NUMBER, "对话轮数", 0), + ("app_id", VariableType.STRING, "应用ID", ""), + ("flow_id", VariableType.STRING, "工作流ID", self.flow_id), + ("user_id", VariableType.STRING, "用户ID", ""), + ("session_id", VariableType.STRING, "会话ID", ""), + ("conversation_id", VariableType.STRING, "对话ID", ""), + ("timestamp", VariableType.NUMBER, "当前时间戳", 0), + ] + + created_count = 0 + for var_name, var_type, description, default_value in system_var_templates: + # 如果系统变量模板不存在,才创建 + if var_name not in self._system_templates: + metadata = VariableMetadata( + name=var_name, + var_type=var_type, + scope=VariableScope.SYSTEM, + description=description, + flow_id=self.flow_id, + created_by="system", + is_system=True, + is_template=True # 标记为模板 + ) + variable = create_variable(metadata, default_value) + self._system_templates[var_name] = variable + + # 持久化模板到数据库 + try: + await self._persist_variable(variable) + created_count += 1 + logger.debug(f"已持久化系统变量模板: {var_name}") + except Exception as e: + logger.error(f"持久化系统变量模板失败: {var_name} - {e}") + + if created_count > 0: + logger.info(f"已为流程 {self.flow_id} 初始化 {created_count} 个系统变量模板") + + def can_modify(self) -> bool: + """环境变量允许修改""" + return True + + # === 系统变量模板相关方法 === + + async def get_system_template(self, name: str) -> Optional[BaseVariable]: + """获取系统变量模板""" + return self._system_templates.get(name) + + async def list_system_templates(self) -> List[BaseVariable]: + """列出所有系统变量模板""" + return list(self._system_templates.values()) + + async def add_system_template(self, name: str, var_type: VariableType, + default_value: Any = None, description: str = None) -> BaseVariable: + """添加系统变量模板""" + if name in self._system_templates: + raise ValueError(f"系统变量模板 {name} 已存在") + + metadata = VariableMetadata( + name=name, + var_type=var_type, + scope=VariableScope.SYSTEM, + description=description, + flow_id=self.flow_id, + created_by="system", + is_system=True, + is_template=True + ) + + variable = create_variable(metadata, default_value) + self._system_templates[name] = variable + + # 持久化到数据库 + await self._persist_variable(variable) + + logger.info(f"已添加系统变量模板: {name} 到流程 {self.flow_id}") + return variable + + # === 对话变量模板相关方法 === + + async def 
get_conversation_template(self, name: str) -> Optional[BaseVariable]: + """获取对话变量模板""" + return self._conversation_templates.get(name) + + async def list_conversation_templates(self) -> List[BaseVariable]: + """列出所有对话变量模板""" + return list(self._conversation_templates.values()) + + async def add_conversation_template(self, name: str, var_type: VariableType, + default_value: Any = None, description: str = None, + created_by: str = None) -> BaseVariable: + """添加对话变量模板""" + if name in self._conversation_templates: + raise ValueError(f"对话变量模板 {name} 已存在") + + metadata = VariableMetadata( + name=name, + var_type=var_type, + scope=VariableScope.CONVERSATION, + description=description, + flow_id=self.flow_id, + created_by=created_by or "user", + is_system=False, + is_template=True + ) + + variable = create_variable(metadata, default_value) + self._conversation_templates[name] = variable + + # 持久化到数据库 + await self._persist_variable(variable) + + logger.info(f"已添加对话变量模板: {name} 到流程 {self.flow_id}") + return variable + + # === 重写基类方法支持多scope查询 === + + async def get_variable_by_scope(self, name: str, scope: VariableScope) -> Optional[BaseVariable]: + """根据作用域获取变量""" + if scope == VariableScope.ENVIRONMENT: + return self._variables.get(name) + elif scope == VariableScope.SYSTEM: + return self._system_templates.get(name) + elif scope == VariableScope.CONVERSATION: + return self._conversation_templates.get(name) + else: + return None + + async def list_variables_by_scope(self, scope: VariableScope) -> List[BaseVariable]: + """根据作用域列出变量""" + if scope == VariableScope.ENVIRONMENT: + return list(self._variables.values()) + elif scope == VariableScope.SYSTEM: + return list(self._system_templates.values()) + elif scope == VariableScope.CONVERSATION: + return list(self._conversation_templates.values()) + else: + return [] + + # === 重写基类方法支持多字典操作 === + + async def update_variable(self, name: str, value: Any = None, + var_type: Optional[VariableType] = None, + description: Optional[str] = None, + force_system_update: bool = False) -> BaseVariable: + """更新变量(支持多字典查找)""" + + # 先在环境变量中查找 + if name in self._variables: + return await super().update_variable(name, value, var_type, description, force_system_update) + + # 在系统变量模板中查找 + elif name in self._system_templates: + variable = self._system_templates[name] + + # 检查权限 + if not force_system_update and getattr(variable.metadata, 'is_system', False): + raise PermissionError(f"系统变量 {name} 不允许直接修改") + + # 更新变量 + if value is not None: + variable.value = value + if var_type is not None: + variable.metadata.var_type = var_type + if description is not None: + variable.metadata.description = description + + # 持久化 + await self._persist_variable(variable) + return variable + + # 在对话变量模板中查找 + elif name in self._conversation_templates: + variable = self._conversation_templates[name] + + # 更新变量 + if value is not None: + variable.value = value + if var_type is not None: + variable.metadata.var_type = var_type + if description is not None: + variable.metadata.description = description + + # 持久化 + await self._persist_variable(variable) + return variable + + else: + raise ValueError(f"变量 {name} 不存在") + + async def delete_variable(self, name: str) -> bool: + """删除变量(支持多字典查找)""" + + # 先在环境变量中查找 + if name in self._variables: + return await super().delete_variable(name) + + # 在系统变量模板中查找 + elif name in self._system_templates: + variable = self._system_templates[name] + + # 检查权限 + if getattr(variable.metadata, 'is_system', False): + raise PermissionError(f"系统变量模板 {name} 不允许删除") + + del 
self._system_templates[name] + await self._delete_variable_from_db(variable) + return True + + # 在对话变量模板中查找 + elif name in self._conversation_templates: + variable = self._conversation_templates[name] + del self._conversation_templates[name] + await self._delete_variable_from_db(variable) + return True + + else: + return False + + async def get_variable(self, name: str) -> Optional[BaseVariable]: + """获取变量(支持多字典查找)""" + + # 先在环境变量中查找 + if name in self._variables: + return self._variables[name] + + # 在系统变量模板中查找 + elif name in self._system_templates: + return self._system_templates[name] + + # 在对话变量模板中查找 + elif name in self._conversation_templates: + return self._conversation_templates[name] + + else: + return None + + def _add_pool_query_conditions(self, query: Dict[str, Any], variable: BaseVariable): + """添加环境变量池的查询条件""" + query["metadata.flow_id"] = self.flow_id + + async def inherit_from_parent(self, parent_pool: "FlowVariablePool"): + """从父流程继承环境变量""" + parent_variables = await parent_pool.copy_variables() + for name, variable in parent_variables.items(): + # 更新元数据中的flow_id + variable.metadata.flow_id = self.flow_id + self._variables[name] = variable + # 持久化继承的变量 + await self._persist_variable(variable) + + logger.info(f"流程 {self.flow_id} 从父流程 {parent_pool.flow_id} 继承了 {len(parent_variables)} 个环境变量") + + +class ConversationVariablePool(BaseVariablePool): + """对话变量池 - 包含系统变量和对话变量""" + + def __init__(self, conversation_id: str, flow_id: str): + super().__init__(conversation_id, VariableScope.CONVERSATION) + self.conversation_id = conversation_id + self.flow_id = flow_id + + async def _load_variables(self): + """从数据库加载对话变量""" + try: + collection = MongoDB().get_collection("variables") + cursor = collection.find({ + "metadata.scope": VariableScope.CONVERSATION.value, + "metadata.conversation_id": self.conversation_id + }) + + loaded_count = 0 + async for doc in cursor: + try: + variable_class_name = doc.get("class") + if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]: + for var_class in VARIABLE_CLASS_MAP.values(): + if var_class.__name__ == variable_class_name: + variable = var_class.deserialize(doc) + self._variables[variable.name] = variable + loaded_count += 1 + break + except Exception as e: + var_name = doc.get("metadata", {}).get("name", "unknown") + logger.warning(f"对话变量 {var_name} 数据损坏: {e}") + + logger.debug(f"对话 {self.conversation_id} 加载变量完成: {loaded_count} 个") + + except Exception as e: + logger.error(f"加载对话变量失败: {e}") + + async def _setup_default_variables(self): + """从flow模板继承系统变量和对话变量""" + from .pool_manager import get_pool_manager + + try: + pool_manager = await get_pool_manager() + flow_pool = await pool_manager.get_flow_pool(self.flow_id) + + if not flow_pool: + logger.warning(f"未找到流程池 {self.flow_id},无法继承变量模板") + return + + created_count = 0 + + # 1. 
从系统变量模板创建系统变量实例 + system_templates = await flow_pool.list_system_templates() + for template in system_templates: + if template.name not in self._variables: + # 创建系统变量实例(不是模板) + metadata = VariableMetadata( + name=template.name, + var_type=template.var_type, + scope=VariableScope.CONVERSATION, # 存储在对话作用域 + description=template.metadata.description, + flow_id=self.flow_id, + conversation_id=self.conversation_id, + created_by="system", + is_system=True, # 标记为系统变量 + is_template=False # 这是实例,不是模板 + ) + + # 使用模板的默认值创建实例 + variable = create_variable(metadata, template.value) + self._variables[template.name] = variable + + # 持久化系统变量实例 + try: + await self._persist_variable(variable) + created_count += 1 + logger.debug(f"已从模板创建系统变量实例: {template.name}") + except Exception as e: + logger.error(f"持久化系统变量实例失败: {template.name} - {e}") + + # 2. 从对话变量模板创建对话变量实例 + conversation_templates = await flow_pool.list_conversation_templates() + for template in conversation_templates: + if template.name not in self._variables: + # 创建对话变量实例 + metadata = VariableMetadata( + name=template.name, + var_type=template.var_type, + scope=VariableScope.CONVERSATION, + description=template.metadata.description, + flow_id=self.flow_id, + conversation_id=self.conversation_id, + created_by=template.metadata.created_by, + is_system=False, # 对话变量 + is_template=False # 这是实例,不是模板 + ) + + # 使用模板的默认值创建实例 + variable = create_variable(metadata, template.value) + self._variables[template.name] = variable + + # 持久化对话变量实例 + try: + await self._persist_variable(variable) + created_count += 1 + logger.debug(f"已从模板创建对话变量实例: {template.name}") + except Exception as e: + logger.error(f"持久化对话变量实例失败: {template.name} - {e}") + + if created_count > 0: + logger.info(f"已为对话 {self.conversation_id} 从流程模板继承 {created_count} 个变量") + + except Exception as e: + logger.error(f"从流程模板继承变量失败: {e}") + + def can_modify(self) -> bool: + """对话变量允许修改""" + return True + + def _add_pool_query_conditions(self, query: Dict[str, Any], variable: BaseVariable): + """添加对话变量池的查询条件""" + query["metadata.conversation_id"] = self.conversation_id + query["metadata.flow_id"] = self.flow_id + + async def update_system_variable(self, name: str, value: Any) -> bool: + """更新系统变量的值(系统内部调用)""" + try: + await self.update_variable(name, value=value, force_system_update=True) + return True + except Exception as e: + logger.error(f"更新系统变量失败: {name} - {e}") + return False + + async def inherit_from_conversation_template(self, template_pool: Optional["ConversationVariablePool"] = None): + """从对话模板池继承变量(如果存在)""" + if template_pool: + template_variables = await template_pool.copy_variables() + for name, variable in template_variables.items(): + # 只继承非系统变量 + if not (hasattr(variable.metadata, 'is_system') and variable.metadata.is_system): + variable.metadata.conversation_id = self.conversation_id + self._variables[name] = variable + + logger.info(f"对话 {self.conversation_id} 从模板继承了 {len(template_variables)} 个变量") \ No newline at end of file diff --git a/apps/scheduler/variable/pool_manager.py b/apps/scheduler/variable/pool_manager.py new file mode 100644 index 000000000..479015add --- /dev/null +++ b/apps/scheduler/variable/pool_manager.py @@ -0,0 +1,353 @@ +import logging +import asyncio +from typing import Dict, List, Optional, Set, Tuple, Any +from contextlib import asynccontextmanager + +from apps.common.mongo import MongoDB +from .pool_base import ( + BaseVariablePool, + UserVariablePool, + FlowVariablePool, + ConversationVariablePool +) +from .type import VariableScope +from .base import 
BaseVariable + +logger = logging.getLogger(__name__) + + +class VariablePoolManager: + """变量池管理器 - 管理所有类型变量池的生命周期""" + + def __init__(self): + """初始化变量池管理器""" + # 用户变量池缓存: user_id -> UserVariablePool + self._user_pools: Dict[str, UserVariablePool] = {} + + # 流程变量池缓存: flow_id -> FlowVariablePool + self._flow_pools: Dict[str, FlowVariablePool] = {} + + # 对话变量池缓存: conversation_id -> ConversationVariablePool + self._conversation_pools: Dict[str, ConversationVariablePool] = {} + + # 流程继承关系缓存: child_flow_id -> parent_flow_id + self._flow_inheritance: Dict[str, str] = {} + + self._initialized = False + self._lock = asyncio.Lock() + + async def initialize(self): + """初始化变量池管理器""" + async with self._lock: + if not self._initialized: + await self._load_existing_entities() + await self._patrol_and_create_missing_pools() + self._initialized = True + logger.info("变量池管理器初始化完成") + + async def _load_existing_entities(self): + """加载现有的用户和流程实体""" + try: + # 这里应该从相应的用户和流程数据库表中加载 + # 目前先从变量表中推断存在的实体 + collection = MongoDB().get_collection("variables") + + # 获取所有唯一的用户ID + user_ids = await collection.distinct("metadata.user_sub", { + "metadata.user_sub": {"$ne": None} + }) + logger.info(f"发现 {len(user_ids)} 个用户需要变量池") + + # 获取所有唯一的流程ID + flow_ids = await collection.distinct("metadata.flow_id", { + "metadata.flow_id": {"$ne": None} + }) + logger.info(f"发现 {len(flow_ids)} 个流程需要变量池") + + # 缓存实体信息用于后续创建池 + self._discovered_users = set(user_ids) + self._discovered_flows = set(flow_ids) + + except Exception as e: + logger.error(f"加载现有实体失败: {e}") + self._discovered_users = set() + self._discovered_flows = set() + + async def _patrol_and_create_missing_pools(self): + """巡检并创建缺失的变量池""" + logger.info("开始巡检并创建缺失的变量池...") + + # 为所有发现的用户创建用户变量池 + created_user_pools = 0 + for user_id in self._discovered_users: + if user_id not in self._user_pools: + await self._create_user_pool(user_id) + created_user_pools += 1 + + # 为所有发现的流程创建流程变量池 + created_flow_pools = 0 + for flow_id in self._discovered_flows: + if flow_id not in self._flow_pools: + await self._create_flow_pool(flow_id) + created_flow_pools += 1 + + logger.info(f"巡检完成: 创建了 {created_user_pools} 个用户池, " + f"{created_flow_pools} 个流程池") + + async def get_user_pool(self, user_id: str, auto_create: bool = True) -> Optional[UserVariablePool]: + """获取用户变量池""" + if user_id in self._user_pools: + return self._user_pools[user_id] + + if auto_create: + return await self._create_user_pool(user_id) + + return None + + async def get_flow_pool(self, flow_id: str, parent_flow_id: Optional[str] = None, + auto_create: bool = True) -> Optional[FlowVariablePool]: + """获取流程变量池""" + if flow_id in self._flow_pools: + return self._flow_pools[flow_id] + + if auto_create: + return await self._create_flow_pool(flow_id, parent_flow_id) + + return None + + async def create_conversation_pool(self, conversation_id: str, flow_id: str) -> ConversationVariablePool: + """创建对话变量池(包含系统变量和对话变量)""" + if conversation_id in self._conversation_pools: + logger.warning(f"对话池 {conversation_id} 已存在,将覆盖") + + # 创建对话变量池 + conversation_pool = ConversationVariablePool(conversation_id, flow_id) + await conversation_pool.initialize() + + # 从对话模板池继承变量(如果存在) + conversation_template_pool = await self._get_conversation_template_pool(flow_id) + await conversation_pool.inherit_from_conversation_template(conversation_template_pool) + + # 缓存池 + self._conversation_pools[conversation_id] = conversation_pool + + logger.info(f"已创建对话变量池: {conversation_id}") + return conversation_pool + + async def get_conversation_pool(self, 
conversation_id: str) -> Optional[ConversationVariablePool]: + """获取对话变量池""" + return self._conversation_pools.get(conversation_id) + + async def remove_conversation_pool(self, conversation_id: str) -> bool: + """移除对话变量池""" + if conversation_id in self._conversation_pools: + del self._conversation_pools[conversation_id] + logger.info(f"已移除对话变量池: {conversation_id}") + return True + return False + + async def get_variable_from_any_pool(self, + name: str, + scope: VariableScope, + user_id: Optional[str] = None, + flow_id: Optional[str] = None, + conversation_id: Optional[str] = None) -> Optional[BaseVariable]: + """从任意池中获取变量""" + if scope == VariableScope.USER and user_id: + pool = await self.get_user_pool(user_id) + return await pool.get_variable(name) if pool else None + + elif scope == VariableScope.ENVIRONMENT and flow_id: + pool = await self.get_flow_pool(flow_id) + return await pool.get_variable(name) if pool else None + + elif scope == VariableScope.CONVERSATION: + if conversation_id: + # 使用conversation_id查询对话变量实例 + pool = await self.get_conversation_pool(conversation_id) + return await pool.get_variable(name) if pool else None + elif flow_id: + # 使用flow_id查询对话变量模板 + flow_pool = await self.get_flow_pool(flow_id) + if flow_pool: + return await flow_pool.get_conversation_template(name) + return None + + # 系统变量处理 + elif scope == VariableScope.SYSTEM: + if conversation_id: + # 优先使用conversation_id查询实际的系统变量实例 + pool = await self.get_conversation_pool(conversation_id) + if pool: + variable = await pool.get_variable(name) + # 检查是否为系统变量 + if variable and hasattr(variable.metadata, 'is_system') and variable.metadata.is_system: + return variable + elif flow_id: + # 使用flow_id查询系统变量模板 + flow_pool = await self.get_flow_pool(flow_id) + if flow_pool: + return await flow_pool.get_system_template(name) + + return None + + async def list_variables_from_any_pool(self, + scope: VariableScope, + user_id: Optional[str] = None, + flow_id: Optional[str] = None, + conversation_id: Optional[str] = None) -> List[BaseVariable]: + """从任意池中列出变量""" + if scope == VariableScope.USER and user_id: + pool = await self.get_user_pool(user_id) + return await pool.list_variables() if pool else [] + + elif scope == VariableScope.ENVIRONMENT and flow_id: + pool = await self.get_flow_pool(flow_id) + return await pool.list_variables() if pool else [] + + elif scope == VariableScope.CONVERSATION: + if conversation_id: + # 使用conversation_id查询对话变量实例 + pool = await self.get_conversation_pool(conversation_id) + if pool: + # 只返回非系统变量 + return await pool.list_variables(include_system=False) + elif flow_id: + # 使用flow_id查询对话变量模板 + flow_pool = await self.get_flow_pool(flow_id) + if flow_pool: + return await flow_pool.list_conversation_templates() + return [] + + # 系统变量处理 + elif scope == VariableScope.SYSTEM: + if conversation_id: + # 优先使用conversation_id查询实际的系统变量实例 + pool = await self.get_conversation_pool(conversation_id) + if pool: + # 只返回系统变量 + return await pool.list_system_variables() + elif flow_id: + # 使用flow_id查询系统变量模板 + flow_pool = await self.get_flow_pool(flow_id) + if flow_pool: + return await flow_pool.list_system_templates() + return [] + + return [] + + async def update_system_variable(self, conversation_id: str, name: str, value: Any) -> bool: + """更新对话中的系统变量""" + conversation_pool = await self.get_conversation_pool(conversation_id) + if conversation_pool: + return await conversation_pool.update_system_variable(name, value) + return False + + async def _create_user_pool(self, user_id: str) -> UserVariablePool: + """创建用户变量池""" 
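        # Caller-side sketch (added commentary, hedged; IDs below are placeholders):
        #     manager = await get_pool_manager()
        #     user_pool = await manager.get_user_pool("user123")   # falls back to this helper
        #     flow_pool = await manager.get_flow_pool("flow456")
        #     conv_pool = await manager.create_conversation_pool("conv789", "flow456")
        # get_user_pool()/get_flow_pool() auto-create and cache pools through the
        # _create_* helpers when the requested pool is not cached yet.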
+ pool = UserVariablePool(user_id) + await pool.initialize() + self._user_pools[user_id] = pool + logger.info(f"已创建用户变量池: {user_id}") + return pool + + async def _create_flow_pool(self, flow_id: str, parent_flow_id: Optional[str] = None) -> FlowVariablePool: + """创建流程变量池""" + pool = FlowVariablePool(flow_id, parent_flow_id) + await pool.initialize() + + # 如果有父流程,从父流程继承变量 + if parent_flow_id and parent_flow_id in self._flow_pools: + parent_pool = self._flow_pools[parent_flow_id] + await pool.inherit_from_parent(parent_pool) + self._flow_inheritance[flow_id] = parent_flow_id + + self._flow_pools[flow_id] = pool + logger.info(f"已创建流程变量池: {flow_id}") + return pool + + async def _get_conversation_template_pool(self, flow_id: str) -> Optional[ConversationVariablePool]: + """获取对话模板池(目前简化处理,返回None)""" + # 这里可以实现从数据库加载对话模板的逻辑 + # 目前简化处理,返回None + return None + + async def clear_conversation_variables(self, flow_id: str): + """清空工作流的所有对话变量池""" + to_remove = [] + for conversation_id, pool in self._conversation_pools.items(): + if pool.flow_id == flow_id: + to_remove.append(conversation_id) + + for conversation_id in to_remove: + del self._conversation_pools[conversation_id] + + logger.info(f"已清空工作流 {flow_id} 的 {len(to_remove)} 个对话变量池") + + async def get_pool_stats(self) -> Dict[str, int]: + """获取变量池统计信息""" + return { + "user_pools": len(self._user_pools), + "flow_pools": len(self._flow_pools), + "conversation_pools": len(self._conversation_pools), + } + + async def cleanup_unused_pools(self, active_conversations: Set[str]): + """清理未使用的对话变量池""" + to_remove = [] + for conversation_id in self._conversation_pools: + if conversation_id not in active_conversations: + to_remove.append(conversation_id) + + for conversation_id in to_remove: + del self._conversation_pools[conversation_id] + + if to_remove: + logger.info(f"清理了 {len(to_remove)} 个未使用的对话变量池") + + @asynccontextmanager + async def get_pool_for_scope(self, + scope: VariableScope, + user_id: Optional[str] = None, + flow_id: Optional[str] = None, + conversation_id: Optional[str] = None): + """上下文管理器,获取指定作用域的变量池""" + pool = None + + try: + if scope == VariableScope.USER and user_id: + pool = await self.get_user_pool(user_id) + elif scope == VariableScope.ENVIRONMENT and flow_id: + pool = await self.get_flow_pool(flow_id) + elif scope in [VariableScope.CONVERSATION, VariableScope.SYSTEM] and conversation_id: + pool = await self.get_conversation_pool(conversation_id) + + if not pool: + raise ValueError(f"无法获取 {scope.value} 级变量池") + + yield pool + + except Exception: + raise + finally: + # 这里可以添加清理逻辑,比如对话池的自动清理等 + pass + + +# 全局变量池管理器实例 +_pool_manager = None + + +async def get_pool_manager() -> VariablePoolManager: + """获取全局变量池管理器实例""" + global _pool_manager + if _pool_manager is None: + _pool_manager = VariablePoolManager() + await _pool_manager.initialize() + return _pool_manager + + +async def initialize_pool_manager(): + """初始化变量池管理器(在应用启动时调用)""" + await get_pool_manager() + logger.info("变量池管理器已启动") \ No newline at end of file diff --git a/apps/scheduler/variable/security.py b/apps/scheduler/variable/security.py index 0289a29da..07fde158c 100644 --- a/apps/scheduler/variable/security.py +++ b/apps/scheduler/variable/security.py @@ -202,13 +202,19 @@ class SecretVariableSecurity: bool: 是否轮换成功 """ try: - from .pool import get_variable_pool + from .pool_manager import get_pool_manager from .type import VariableScope - pool = await get_variable_pool() + pool_manager = await get_pool_manager() + + # 获取用户变量池 + user_pool = await 
pool_manager.get_user_pool(user_sub) + if not user_pool: + logger.error(f"用户变量池不存在: {user_sub}") + return False # 获取密钥变量 - variable = await pool.get_variable(variable_name, VariableScope.USER, user_sub=user_sub) + variable = await user_pool.get_variable(variable_name) if not variable or not variable.var_type.is_secret_type(): return False @@ -220,7 +226,7 @@ class SecretVariableSecurity: variable.value = original_value # 这会触发重新加密 # 更新存储 - await pool._persist_variable(variable) + await user_pool._persist_variable(variable) # 记录轮换操作 await self._log_access( diff --git a/apps/scheduler/variable/system_variables_example.py b/apps/scheduler/variable/system_variables_example.py new file mode 100644 index 000000000..0e48de30a --- /dev/null +++ b/apps/scheduler/variable/system_variables_example.py @@ -0,0 +1,220 @@ +""" +系统变量使用示例 + +演示系统变量的正确初始化、更新和访问流程 +""" + +import asyncio +from datetime import datetime, UTC +from typing import Dict, Any + +from .pool_manager import get_pool_manager +from .parser import VariableParser, VariableReferenceBuilder +from .type import VariableScope + + +async def demonstrate_system_variables(): + """演示系统变量的完整工作流程""" + + # 模拟对话参数 + user_id = "user123" + flow_id = "flow456" + conversation_id = "conv789" + + print("=== 系统变量演示 ===\n") + + # 1. 创建变量解析器(会自动创建对话池并初始化系统变量) + print("1. 创建变量解析器并初始化对话变量池...") + parser = VariableParser( + user_id=user_id, + flow_id=flow_id, + conversation_id=conversation_id + ) + + # 确保对话池存在 + success = await parser.create_conversation_pool_if_needed() + print(f" 对话池创建结果: {'成功' if success else '失败'}") + + # 2. 检查系统变量是否已正确初始化 + print("\n2. 检查初始化的系统变量...") + pool_manager = await get_pool_manager() + conversation_pool = await pool_manager.get_conversation_pool(conversation_id) + + if conversation_pool: + system_vars = await conversation_pool.list_system_variables() + print(f" 已初始化 {len(system_vars)} 个系统变量:") + for var in system_vars: + print(f" - {var.name}: {var.value} ({var.var_type.value})") + + # 3. 更新系统变量(模拟对话开始) + print("\n3. 更新系统变量...") + context = { + "question": "请帮我分析这个数据文件", + "files": [{"name": "data.csv", "size": 1024, "type": "text/csv"}], + "dialogue_count": 1, + "app_id": "app001", + "user_sub": user_id, + "session_id": "session123" + } + + await parser.update_system_variables(context) + print(" 系统变量更新完成") + + # 4. 验证系统变量已正确更新 + print("\n4. 验证系统变量更新结果...") + updated_vars = await conversation_pool.list_system_variables() + for var in updated_vars: + if var.name in ["query", "files", "dialogue_count", "app_id", "user_id", "session_id"]: + print(f" - {var.name}: {var.value}") + + # 5. 使用变量引用解析模板 + print("\n5. 解析包含系统变量的模板...") + template = """ +用户查询: {{sys.query}} +对话轮数: {{sys.dialogue_count}} +流程ID: {{sys.flow_id}} +用户ID: {{sys.user_id}} +文件数量: {{sys.files.length}} +""" + + try: + parsed_result = await parser.parse_template(template) + print(" 模板解析结果:") + print(parsed_result) + except Exception as e: + print(f" 模板解析失败: {e}") + + # 6. 验证系统变量的只读性 + print("\n6. 验证系统变量的只读保护...") + try: + # 尝试直接修改系统变量(应该失败) + await conversation_pool.update_variable("query", value="恶意修改") + print(" ❌ 错误:系统变量被意外修改") + except PermissionError: + print(" ✅ 正确:系统变量只读保护生效") + except Exception as e: + print(f" 🤔 意外错误: {e}") + + # 7. 展示系统变量的强制更新(内部使用) + print("\n7. 演示系统变量的内部更新...") + success = await conversation_pool.update_system_variable("dialogue_count", 2) + if success: + updated_var = await conversation_pool.get_variable("dialogue_count") + print(f" ✅ 系统变量内部更新成功: dialogue_count = {updated_var.value}") + else: + print(" ❌ 系统变量内部更新失败") + + # 8. 清理 + print("\n8. 
清理对话变量池...") + removed = await pool_manager.remove_conversation_pool(conversation_id) + print(f" 清理结果: {'成功' if removed else '失败'}") + + print("\n=== 演示完成 ===") + + +async def demonstrate_variable_references(): + """演示系统变量引用的构建和使用""" + + print("\n=== 变量引用演示 ===\n") + + # 构建各种变量引用 + print("1. 变量引用构建示例:") + + # 系统变量引用 + query_ref = VariableReferenceBuilder.system("query") + files_ref = VariableReferenceBuilder.system("files", "0.name") # 嵌套访问 + + # 用户变量引用 + api_key_ref = VariableReferenceBuilder.user("api_key") + + # 环境变量引用 + db_url_ref = VariableReferenceBuilder.environment("database_url") + + # 对话变量引用 + history_ref = VariableReferenceBuilder.conversation("chat_history") + + print(f" 系统变量 - 用户查询: {query_ref}") + print(f" 系统变量 - 首个文件名: {files_ref}") + print(f" 用户变量 - API密钥: {api_key_ref}") + print(f" 环境变量 - 数据库: {db_url_ref}") + print(f" 对话变量 - 聊天历史: {history_ref}") + + # 构建复杂模板 + print("\n2. 复杂模板示例:") + complex_template = f""" +# 对话上下文 +- 用户: {query_ref} +- 轮次: {VariableReferenceBuilder.system("dialogue_count")} +- 时间: {VariableReferenceBuilder.system("timestamp")} + +# 文件信息 +- 文件列表: {VariableReferenceBuilder.system("files")} +- 文件数量: {VariableReferenceBuilder.system("files", "length")} + +# 会话信息 +- 对话ID: {VariableReferenceBuilder.system("conversation_id")} +- 流程ID: {VariableReferenceBuilder.system("flow_id")} +- 用户ID: {VariableReferenceBuilder.system("user_id")} +""" + + print(complex_template) + + print("=== 引用演示完成 ===") + + +async def validate_system_variable_persistence(): + """验证系统变量的持久化""" + + print("\n=== 持久化验证 ===\n") + + conversation_id = "test_persistence_conv" + flow_id = "test_persistence_flow" + + # 创建对话池 + pool_manager = await get_pool_manager() + conversation_pool = await pool_manager.create_conversation_pool(conversation_id, flow_id) + + print("1. 检查新创建池的系统变量...") + system_vars_before = await conversation_pool.list_system_variables() + print(f" 创建后的系统变量数量: {len(system_vars_before)}") + + # 模拟应用重启 - 重新获取池 + print("\n2. 
模拟重新加载...") + await pool_manager.remove_conversation_pool(conversation_id) + + # 重新创建同一个对话池 + conversation_pool_reloaded = await pool_manager.create_conversation_pool(conversation_id, flow_id) + system_vars_after = await conversation_pool_reloaded.list_system_variables() + + print(f" 重新加载后的系统变量数量: {len(system_vars_after)}") + + # 验证变量是否一致 + vars_before_names = {var.name for var in system_vars_before} + vars_after_names = {var.name for var in system_vars_after} + + if vars_before_names == vars_after_names: + print(" ✅ 系统变量持久化验证成功") + else: + print(" ❌ 系统变量持久化验证失败") + print(f" 之前: {vars_before_names}") + print(f" 之后: {vars_after_names}") + + # 清理 + await pool_manager.remove_conversation_pool(conversation_id) + + print("=== 持久化验证完成 ===") + + +if __name__ == "__main__": + async def main(): + """运行所有演示""" + try: + await demonstrate_system_variables() + await demonstrate_variable_references() + await validate_system_variable_persistence() + except Exception as e: + print(f"演示过程中发生错误: {e}") + import traceback + traceback.print_exc() + + asyncio.run(main()) \ No newline at end of file diff --git a/apps/schemas/config.py b/apps/schemas/config.py index 263215f3f..3205f062b 100644 --- a/apps/schemas/config.py +++ b/apps/schemas/config.py @@ -84,6 +84,20 @@ class MongoDBConfig(BaseModel): database: str = Field(description="MongoDB数据库名") +class RedisConfig(BaseModel): + """Redis配置""" + + host: str = Field(description="Redis主机名", default="redis-db") + port: int = Field(description="Redis端口号", default=6379) + password: str | None = Field(description="Redis密码", default=None) + database: int = Field(description="Redis数据库编号", default=0) + decode_responses: bool = Field(description="是否解码响应", default=True) + socket_timeout: float = Field(description="套接字超时时间(秒)", default=5.0) + socket_connect_timeout: float = Field(description="连接超时时间(秒)", default=5.0) + max_connections: int = Field(description="最大连接数", default=10) + health_check_interval: int = Field(description="健康检查间隔(秒)", default=30) + + class LLMConfig(BaseModel): """LLM配置""" @@ -143,6 +157,7 @@ class ConfigModel(BaseModel): fastapi: FastAPIConfig minio: MinioConfig mongodb: MongoDBConfig + redis: RedisConfig llm: LLMConfig function_call: FunctionCallConfig security: SecurityConfig diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py index 9a6bed204..bcf3c8375 100644 --- a/apps/schemas/enum_var.py +++ b/apps/schemas/enum_var.py @@ -51,8 +51,10 @@ class EventType(str, Enum): class CallType(str, Enum): """Call类型""" - SYSTEM = "system" - PYTHON = "python" + DEFAULT = "default" + LOGIC = "logic" + TRANSFORM = "transform" + TOOL = "tool" class MetadataType(str, Enum): diff --git a/apps/schemas/flow_topology.py b/apps/schemas/flow_topology.py index d0ab666a9..aa5da0d58 100644 --- a/apps/schemas/flow_topology.py +++ b/apps/schemas/flow_topology.py @@ -5,7 +5,7 @@ from typing import Any from pydantic import BaseModel, Field -from apps.schemas.enum_var import EdgeType +from apps.schemas.enum_var import CallType, EdgeType class NodeMetaDataItem(BaseModel): @@ -14,6 +14,7 @@ class NodeMetaDataItem(BaseModel): node_id: str = Field(alias="nodeId") call_id: str = Field(alias="callId") name: str + type: CallType description: str parameters: dict[str, Any] | None editable: bool = Field(default=True) diff --git a/apps/schemas/pool.py b/apps/schemas/pool.py index 27e16b370..009c8206d 100644 --- a/apps/schemas/pool.py +++ b/apps/schemas/pool.py @@ -45,9 +45,6 @@ class CallPool(BaseData): Call信息 collection: call - - “path”的格式如下: - 1. 
Python代码会被导入成包,路径格式为`python::::`,用于查找Call的包路径和类路径 """ type: CallType = Field(description="Call的类型") @@ -77,6 +74,7 @@ class NodePool(BaseData): service_id: str | None = Field(description="Node所属的Service ID", default=None) call_id: str = Field(description="所使用的Call的ID") + type: CallType = Field(description="所使用的Call的类型") known_params: dict[str, Any] | None = Field( description="已知的用于Call部分的参数,独立于输入和输出之外", default=None, diff --git a/apps/schemas/scheduler.py b/apps/schemas/scheduler.py index 38fd94ad4..dd2dbb604 100644 --- a/apps/schemas/scheduler.py +++ b/apps/schemas/scheduler.py @@ -1,18 +1,19 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """插件、工作流、步骤相关数据结构定义""" +from enum import StrEnum from typing import Any from pydantic import BaseModel, Field -from apps.schemas.enum_var import CallOutputType +from apps.schemas.enum_var import CallOutputType, CallType from apps.schemas.task import FlowStepHistory class CallInfo(BaseModel): """Call的名称和描述""" - name: str = Field(description="Call的名称") + type: CallType = Field(description="Call的类别") description: str = Field(description="Call的描述") @@ -22,6 +23,7 @@ class CallIds(BaseModel): task_id: str = Field(description="任务ID") flow_id: str = Field(description="Flow ID") session_id: str = Field(description="当前用户的Session ID") + conversation_id: str = Field(description="当前对话ID") app_id: str = Field(description="当前应用的ID") user_sub: str = Field(description="当前用户的用户ID") diff --git a/apps/services/flow.py b/apps/services/flow.py index ebc5764bc..f84f5b832 100644 --- a/apps/services/flow.py +++ b/apps/services/flow.py @@ -95,6 +95,7 @@ class FlowManager: nodeId=node_pool_record["_id"], callId=node_pool_record["call_id"], name=node_pool_record["name"], + type=node_pool_record["type"], description=node_pool_record["description"], editable=True, createdAt=node_pool_record["created_at"], @@ -153,7 +154,7 @@ class FlowManager: NodeServiceItem( serviceId=record["_id"], name=record["name"], - type="default", + type="default", # TODO record["type"]? nodeMetaDatas=[], createdAt=str(record["created_at"]), ) diff --git a/apps/services/predecessor_cache_service.py b/apps/services/predecessor_cache_service.py new file mode 100644 index 000000000..967077ea9 --- /dev/null +++ b/apps/services/predecessor_cache_service.py @@ -0,0 +1,488 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
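+#
+# Illustrative call sequence for this service (a sketch only; the actual wiring in
+# apps/main.py and apps/routers/flow.py may differ):
+#     await PredecessorCacheService.initialize_redis()
+#     await PredecessorCacheService.trigger_flow_parsing(flow_id)
+#     cached = await PredecessorCacheService.get_predecessor_variables_optimized(
+#         flow_id, step_id, user_sub)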
+"""前置节点变量预解析缓存服务""" + +import asyncio +import hashlib +import json +import logging +from typing import List, Dict, Any, Optional +from datetime import datetime, UTC + +from apps.common.redis_cache import RedisCache +from apps.common.process_handler import ProcessHandler +from apps.services.flow import FlowManager +from apps.scheduler.variable.variables import create_variable +from apps.scheduler.variable.base import VariableMetadata +from apps.scheduler.variable.type import VariableType, VariableScope + +logger = logging.getLogger(__name__) + +# 全局Redis缓存实例 +redis_cache = RedisCache() +predecessor_cache = None + +# 添加任务管理 +_background_tasks: Dict[str, asyncio.Task] = {} +_task_lock = asyncio.Lock() + +def _get_predecessor_cache(): + """获取predecessor_cache实例,确保初始化""" + global predecessor_cache + if predecessor_cache is None: + from apps.common.redis_cache import PredecessorVariableCache + predecessor_cache = PredecessorVariableCache(redis_cache) + return predecessor_cache + +async def init_predecessor_cache(): + """初始化前置节点变量缓存""" + global predecessor_cache + if predecessor_cache is None: + from apps.common.redis_cache import PredecessorVariableCache + predecessor_cache = PredecessorVariableCache(redis_cache) + +async def cleanup_background_tasks(): + """清理后台任务""" + async with _task_lock: + if not _background_tasks: + logger.info("没有后台任务需要清理") + return + + logger.info(f"开始清理 {len(_background_tasks)} 个后台任务") + + # 取消所有未完成的任务 + cancelled_count = 0 + for task_id, task in list(_background_tasks.items()): + if not task.done(): + task.cancel() + cancelled_count += 1 + logger.debug(f"取消后台任务: {task_id}") + + # 等待任务取消完成,设置超时避免永久等待 + if cancelled_count > 0: + logger.info(f"等待 {cancelled_count} 个任务取消完成...") + timeout = 5.0 # 5秒超时 + try: + await asyncio.wait_for( + asyncio.gather(*[task for task in _background_tasks.values()], return_exceptions=True), + timeout=timeout + ) + except asyncio.TimeoutError: + logger.warning(f"等待任务取消超时 ({timeout}s),强制清理") + except Exception as e: + logger.error(f"等待任务取消时出错: {e}") + + # 清理任务字典 + completed_count = len(_background_tasks) + _background_tasks.clear() + logger.info(f"后台任务清理完成,共清理 {completed_count} 个任务") + +async def periodic_cleanup_background_tasks(): + """定期清理已完成的后台任务""" + try: + async with _task_lock: + if not _background_tasks: + return + + completed_tasks = [] + for task_id, task in list(_background_tasks.items()): + if task.done(): + completed_tasks.append(task_id) + try: + # 获取任务结果,记录异常 + await task + logger.debug(f"后台任务已完成: {task_id}") + except Exception as e: + logger.error(f"后台任务执行异常: {task_id}, 错误: {e}") + + # 移除已完成的任务 + for task_id in completed_tasks: + _background_tasks.pop(task_id, None) + + if completed_tasks: + logger.info(f"定期清理了 {len(completed_tasks)} 个已完成的后台任务") + + except Exception as e: + logger.error(f"定期清理后台任务失败: {e}") + + +class PredecessorCacheService: + """前置节点变量预解析缓存服务""" + + @staticmethod + async def initialize_redis(): + """初始化Redis连接""" + try: + # 从配置文件读取Redis配置 + from apps.common.config import Config + + config = Config().get_config() + redis_config = config.redis + + logger.info(f"准备连接Redis: {redis_config.host}:{redis_config.port}") + await redis_cache.init(redis_config=redis_config) + + # 验证连接是否正常 + if redis_cache.is_connected(): + logger.info("前置节点缓存服务Redis初始化成功") + return + else: + raise Exception("Redis连接验证失败") + + except Exception as e: + logger.error(f"使用配置文件连接Redis失败: {e}") + + # 尝试降级连接方案 + try: + logger.info("尝试降级连接方案...") + from apps.common.config import Config + config = Config().get_config() + redis_config = config.redis 
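+                # Resulting URL has the form redis://[:password@]host:port/db;
+                # the auth segment is dropped when no password is configured.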
+ + # 构建简单的Redis URL + password_part = f":{redis_config.password}@" if redis_config.password else "" + redis_url = f"redis://{password_part}{redis_config.host}:{redis_config.port}/{redis_config.database}" + + await redis_cache.init(redis_url=redis_url) + + if redis_cache.is_connected(): + logger.info("降级连接方案成功") + return + else: + raise Exception("降级连接方案也失败") + + except Exception as fallback_error: + logger.error(f"降级连接方案也失败: {fallback_error}") + + # 即使Redis初始化失败,也不要抛出异常,而是继续运行(降级模式) + logger.info("将使用实时解析模式作为降级方案") + + @staticmethod + def calculate_flow_hash(flow_item) -> str: + """计算Flow拓扑结构的哈希值""" + try: + # 提取关键的拓扑信息 + topology_data = { + 'nodes': [ + { + 'step_id': node.step_id, + 'call_id': getattr(node, 'call_id', ''), + 'parameters': getattr(node, 'parameters', {}) + } + for node in flow_item.nodes + ], + 'edges': [ + { + 'source_node': edge.source_node, + 'target_node': edge.target_node + } + for edge in flow_item.edges + ] + } + + # 生成哈希 + topology_json = json.dumps(topology_data, sort_keys=True) + return hashlib.md5(topology_json.encode()).hexdigest() + except Exception as e: + logger.error(f"计算Flow哈希失败: {e}") + return str(datetime.now(UTC).timestamp()) # 降级方案 + + @staticmethod + async def trigger_flow_parsing(flow_id: str, force_refresh: bool = False): + """触发整个Flow的前置节点变量解析""" + try: + # 获取Flow信息 + flow_item = await PredecessorCacheService._get_flow_by_flow_id(flow_id) + if not flow_item: + logger.warning(f"Flow不存在,跳过解析: {flow_id}") + return + + # 计算当前Flow的哈希 + current_hash = PredecessorCacheService.calculate_flow_hash(flow_item) + + # 检查是否需要重新解析 + if not force_refresh: + cached_hash = await _get_predecessor_cache().get_flow_hash(flow_id) + if cached_hash == current_hash: + logger.info(f"Flow拓扑未变化,跳过解析: {flow_id}") + return + + # 更新Flow哈希 + await _get_predecessor_cache().set_flow_hash(flow_id, current_hash) + + # 清除旧缓存 + await _get_predecessor_cache().invalidate_flow_cache(flow_id) + + # 为每个节点启动异步解析任务 + tasks = [] + for node in flow_item.nodes: + step_id = node.step_id + task_id = f"parse_predecessor_{flow_id}_{step_id}" + + # 避免重复任务 + async with _task_lock: + if task_id in _background_tasks and not _background_tasks[task_id].done(): + continue + + # 异步启动解析任务 + task = asyncio.create_task( + PredecessorCacheService._parse_single_node_predecessor( + flow_id, step_id, current_hash + ) + ) + _background_tasks[task_id] = task + tasks.append((task_id, task)) + + if tasks: + logger.info(f"启动Flow前置节点解析任务: {flow_id}, 节点数量: {len(tasks)}") + # 简化处理:直接启动任务,依赖cleanup_background_tasks进行清理 + for task_id, task in tasks: + # 不添加回调,让任务自然完成 + logger.debug(f"启动后台任务: {task_id}") + + except Exception as e: + logger.error(f"触发Flow解析失败: {flow_id}, 错误: {e}") + + @staticmethod + async def _cleanup_task(task_id: str): + """清理完成的任务""" + try: + async with _task_lock: + task = _background_tasks.pop(task_id, None) + if task and task.done(): + # 检查任务是否有异常 + try: + result = await task + logger.debug(f"后台任务完成: {task_id}") + except Exception as e: + logger.error(f"后台任务执行异常: {task_id}, 错误: {e}") + except Exception as e: + logger.error(f"清理任务失败: {task_id}, 错误: {e}") + + @staticmethod + async def _parse_single_node_predecessor(flow_id: str, step_id: str, flow_hash: str): + """解析单个节点的前置节点变量""" + try: + # 检查事件循环是否仍然活跃 + try: + asyncio.get_running_loop() + except RuntimeError: + logger.warning(f"事件循环已关闭,跳过解析: {flow_id}:{step_id}") + return + + # 设置解析状态 + await _get_predecessor_cache().set_parsing_status(flow_id, step_id, "parsing") + + # 获取Flow信息 + flow_item = await 
PredecessorCacheService._get_flow_by_flow_id(flow_id) + if not flow_item: + await _get_predecessor_cache().set_parsing_status(flow_id, step_id, "failed") + return + + # 查找前置节点 + predecessor_nodes = PredecessorCacheService._find_predecessor_nodes(flow_item, step_id) + + # 为每个前置节点创建输出变量 + variables_data = [] + for node in predecessor_nodes: + node_vars = await PredecessorCacheService._create_node_output_variables(node) + variables_data.extend(node_vars) + + # 缓存结果 + await _get_predecessor_cache().set_cached_variables(flow_id, step_id, variables_data, flow_hash) + + # 设置完成状态 + await _get_predecessor_cache().set_parsing_status(flow_id, step_id, "completed") + + logger.info(f"节点前置变量解析完成: {flow_id}:{step_id}, 变量数量: {len(variables_data)}") + + except asyncio.CancelledError: + logger.info(f"节点前置变量解析任务被取消: {flow_id}:{step_id}") + try: + await _get_predecessor_cache().set_parsing_status(flow_id, step_id, "cancelled") + except Exception: + pass # 忽略清理时的错误 + except Exception as e: + logger.error(f"解析节点前置变量失败: {flow_id}:{step_id}, 错误: {e}") + try: + await _get_predecessor_cache().set_parsing_status(flow_id, step_id, "failed") + except Exception: + # 如果连设置状态都失败了,说明可能是事件循环关闭导致的 + logger.warning(f"无法设置解析状态为失败: {flow_id}:{step_id}") + + @staticmethod + async def get_predecessor_variables_optimized( + flow_id: str, + step_id: str, + user_sub: str, + max_wait_time: int = 10 + ) -> List[Dict[str, Any]]: + """优化的前置节点变量获取(优先使用缓存)""" + try: + # 1. 先尝试从缓存获取 + cached_vars = await _get_predecessor_cache().get_cached_variables(flow_id, step_id) + if cached_vars is not None: + logger.info(f"使用缓存的前置节点变量: {flow_id}:{step_id}") + return cached_vars + + # 2. 检查是否正在解析中 + if await _get_predecessor_cache().is_parsing_in_progress(flow_id, step_id): + logger.info(f"等待前置节点变量解析完成: {flow_id}:{step_id}") + # 等待解析完成 + if await _get_predecessor_cache().wait_for_parsing_completion(flow_id, step_id, max_wait_time): + cached_vars = await _get_predecessor_cache().get_cached_variables(flow_id, step_id) + if cached_vars is not None: + return cached_vars + + # 3. 
缓存未命中,启动实时解析 + logger.info(f"缓存未命中,启动实时解析: {flow_id}:{step_id}") + + # 获取Flow信息 + flow_item = await PredecessorCacheService._get_flow_by_flow_id(flow_id) + if not flow_item: + return [] + + # 计算Flow哈希 + flow_hash = PredecessorCacheService.calculate_flow_hash(flow_item) + + # 立即解析并缓存 + await PredecessorCacheService._parse_single_node_predecessor(flow_id, step_id, flow_hash) + + # 再次尝试从缓存获取 + cached_vars = await _get_predecessor_cache().get_cached_variables(flow_id, step_id) + return cached_vars or [] + + except Exception as e: + logger.error(f"获取优化前置节点变量失败: {flow_id}:{step_id}, 错误: {e}") + return [] + + @staticmethod + async def _get_flow_by_flow_id(flow_id: str): + """通过flow_id获取工作流信息""" + try: + from apps.common.mongo import MongoDB + + app_collection = MongoDB().get_collection("app") + + # 查询包含此flow_id的app,同时获取app_id + app_record = await app_collection.find_one( + {"flows.id": flow_id}, + {"_id": 1} + ) + + if not app_record: + logger.warning(f"未找到包含flow_id {flow_id} 的应用") + return None + + app_id = app_record["_id"] + + # 使用现有的FlowManager方法获取flow + flow_item = await FlowManager.get_flow_by_app_and_flow_id(app_id, flow_id) + return flow_item + + except Exception as e: + logger.error(f"通过flow_id获取工作流失败: {e}") + return None + + @staticmethod + def _find_predecessor_nodes(flow_item, current_step_id: str) -> List: + """在工作流中查找前置节点""" + try: + predecessor_nodes = [] + + # 遍历边,找到指向当前节点的边 + for edge in flow_item.edges: + if edge.target_node == current_step_id: + # 找到前置节点 + source_node = next( + (node for node in flow_item.nodes if node.step_id == edge.source_node), + None + ) + if source_node: + predecessor_nodes.append(source_node) + + logger.debug(f"为节点 {current_step_id} 找到 {len(predecessor_nodes)} 个前置节点") + return predecessor_nodes + + except Exception as e: + logger.error(f"查找前置节点失败: {e}") + return [] + + @staticmethod + async def _create_node_output_variables(node) -> List[Dict[str, Any]]: + """根据节点的output_parameters配置创建输出变量数据""" + try: + variables_data = [] + node_id = node.step_id + + # 统一从节点的output_parameters创建变量 + output_params = {} + if hasattr(node, 'parameters') and node.parameters: + if isinstance(node.parameters, dict): + output_params = node.parameters.get('output_parameters', {}) + else: + output_params = getattr(node.parameters, 'output_parameters', {}) + + # 如果没有配置output_parameters,跳过此节点 + if not output_params: + logger.debug(f"节点 {node_id} 没有配置output_parameters,跳过创建输出变量") + return variables_data + + # 遍历output_parameters中的每个key-value对,创建对应的变量数据 + for param_name, param_config in output_params.items(): + # 解析参数配置 + if isinstance(param_config, dict): + param_type = param_config.get('type', 'string') + description = param_config.get('description', '') + else: + # 如果param_config不是字典,可能是简单的类型字符串 + param_type = str(param_config) if param_config else 'string' + description = '' + + # 确定变量类型 + var_type = VariableType.STRING # 默认类型 + if param_type == 'number': + var_type = VariableType.NUMBER + elif param_type == 'boolean': + var_type = VariableType.BOOLEAN + elif param_type == 'object': + var_type = VariableType.OBJECT + elif param_type == 'array' or param_type == 'array[any]': + var_type = VariableType.ARRAY_ANY + elif param_type == 'array[string]': + var_type = VariableType.ARRAY_STRING + elif param_type == 'array[number]': + var_type = VariableType.ARRAY_NUMBER + elif param_type == 'array[object]': + var_type = VariableType.ARRAY_OBJECT + elif param_type == 'array[boolean]': + var_type = VariableType.ARRAY_BOOLEAN + elif param_type == 'array[file]': + var_type = 
VariableType.ARRAY_FILE + elif param_type == 'array[secret]': + var_type = VariableType.ARRAY_SECRET + elif param_type == 'file': + var_type = VariableType.FILE + elif param_type == 'secret': + var_type = VariableType.SECRET + + # 创建变量数据(用于缓存的字典格式) + variable_data = { + 'name': f"{node_id}.{param_name}", + 'var_type': var_type.value, + 'scope': VariableScope.CONVERSATION.value, + 'value': "", # 配置阶段的潜在变量,值为空 + 'description': description or f"来自节点 {node_id} 的输出参数 {param_name}", + 'created_at': datetime.now(UTC).isoformat(), + 'updated_at': datetime.now(UTC).isoformat(), + 'step_name': getattr(node, 'name', node_id), # 节点名称 + 'step_id': node_id # 节点ID + } + + variables_data.append(variable_data) + + logger.debug(f"为节点 {node_id} 创建了 {len(variables_data)} 个输出变量: {[v['name'] for v in variables_data]}") + return variables_data + + except Exception as e: + logger.error(f"创建节点输出变量失败: {e}") + return [] \ No newline at end of file diff --git a/assets/.config.example.toml b/assets/.config.example.toml index 9068706ee..28fb5ef8e 100644 --- a/assets/.config.example.toml +++ b/assets/.config.example.toml @@ -37,6 +37,17 @@ user = 'euler_copilot' password = '' database = 'euler_copilot' +[redis] +host = 'redis-db' +port = 6379 +password = '' +database = 0 +decode_responses = true +socket_timeout = 5.0 +socket_connect_timeout = 5.0 +max_connections = 10 +health_check_interval = 30 + [minio] endpoint = '127.0.0.1:9000' access_key = 'minioadmin' diff --git a/docs/variable_configuration.md b/docs/variable_configuration.md new file mode 100644 index 000000000..b105c4af1 --- /dev/null +++ b/docs/variable_configuration.md @@ -0,0 +1,133 @@ +# 变量存储格式配置说明 + +## 概述 + +节点执行完成后,系统会根据节点的`output_parameters`配置自动将输出数据保存到对话变量池中。变量的存储格式有两种: + +1. **直接格式**: `conversation.key` +2. **带前缀格式**: `conversation.step_id.key` + +## 配置方式 + +### 1. 通过节点类型配置 + +在 `apps/scheduler/executor/step_config.py` 中的 `DIRECT_CONVERSATION_VARIABLE_NODE_TYPES` 集合中添加节点类型: + +```python +DIRECT_CONVERSATION_VARIABLE_NODE_TYPES: Set[str] = { + "Start", # Start节点 + "Input", # 输入节点 + "YourNewNode", # 新增的节点类型 +} +``` + +### 2. 通过名称模式配置 + +在 `DIRECT_CONVERSATION_VARIABLE_NAME_PATTERNS` 集合中添加匹配模式: + +```python +DIRECT_CONVERSATION_VARIABLE_NAME_PATTERNS: Set[str] = { + "start", # 匹配以"start"开头的节点名称 + "init", # 匹配以"init"开头的节点名称 + "config", # 新增:匹配以"config"开头的节点名称 +} +``` + +## 判断逻辑 + +系统会按以下顺序判断是否使用直接格式: + +1. 检查节点的 `call_id` 是否在 `DIRECT_CONVERSATION_VARIABLE_NODE_TYPES` 中 +2. 检查节点的 `step_name` 是否在 `DIRECT_CONVERSATION_VARIABLE_NODE_TYPES` 中 +3. 检查节点名称(小写)是否以 `DIRECT_CONVERSATION_VARIABLE_NAME_PATTERNS` 中的模式开头 +4. 
检查 `step_id`(小写)是否以 `DIRECT_CONVERSATION_VARIABLE_NAME_PATTERNS` 中的模式开头 + +如果任一条件满足,则使用直接格式 `conversation.key`,否则使用带前缀格式 `conversation.step_id.key`。 + +## 使用示例 + +### 示例1:Start节点 + +```json +// 节点配置 +{ + "call_id": "Start", + "step_name": "start", + "step_id": "start_001", + "output_parameters": { + "user_name": {"type": "string"}, + "session_id": {"type": "string"} + } +} + +// 保存的变量格式 +conversation.user_name = "张三" +conversation.session_id = "sess_123" +``` + +### 示例2:普通处理节点 + +```json +// 节点配置 +{ + "call_id": "Code", + "step_name": "数据处理", + "step_id": "process_001", + "output_parameters": { + "result": {"type": "object"}, + "status": {"type": "string"} + } +} + +// 保存的变量格式 +conversation.process_001.result = {...} +conversation.process_001.status = "success" +``` + +### 示例3:配置节点(新增类型) + +```python +# 在step_config.py中添加 +DIRECT_CONVERSATION_VARIABLE_NODE_TYPES.add("GlobalConfig") +``` + +```json +// 节点配置 +{ + "call_id": "GlobalConfig", + "step_name": "全局配置", + "step_id": "config_001", + "output_parameters": { + "api_key": {"type": "secret"}, + "timeout": {"type": "number"} + } +} + +// 保存的变量格式(使用直接格式) +conversation.api_key = "xxx" +conversation.timeout = 30 +``` + +## 变量引用 + +在其他节点中可以通过以下方式引用这些变量: + +```json +{ + "input_parameters": { + "user": { + "reference": "{{conversation.user_name}}" // 直接格式变量 + }, + "data": { + "reference": "{{conversation.process_001.result}}" // 带前缀格式变量 + } + } +} +``` + +## 注意事项 + +1. **一致性**: 建议同时添加大小写版本以确保兼容性 +2. **命名冲突**: 使用直接格式时需要注意变量名冲突问题 +3. **可追溯性**: 带前缀格式便于追踪变量来源,直接格式便于全局访问 +4. **配置变更**: 修改配置后需要重启服务才能生效 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 2f8222898..8a34b17dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,7 @@ description = "EulerCopilot 后端服务" requires-python = "==3.11.6" dependencies = [ "aiofiles==24.1.0", + "redis==5.0.8", "asyncer==0.0.8", "asyncpg==0.30.0", "cryptography==44.0.2", -- Gitee
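
For reference, the naming rule described in `docs/variable_configuration.md` above can be read as the following minimal sketch. The helpers `use_direct_format` and `variable_key` are illustrative names only; the real checks live in `apps/scheduler/executor/step_config.py` and may differ in detail.

```python
from typing import Set

# Assumed contents; the authoritative sets are defined in step_config.py.
DIRECT_CONVERSATION_VARIABLE_NODE_TYPES: Set[str] = {"Start", "Input"}
DIRECT_CONVERSATION_VARIABLE_NAME_PATTERNS: Set[str] = {"start", "init"}


def use_direct_format(call_id: str, step_name: str, step_id: str) -> bool:
    """Apply the four checks from the doc, in order."""
    if call_id in DIRECT_CONVERSATION_VARIABLE_NODE_TYPES:
        return True
    if step_name in DIRECT_CONVERSATION_VARIABLE_NODE_TYPES:
        return True
    patterns = DIRECT_CONVERSATION_VARIABLE_NAME_PATTERNS
    if any(step_name.lower().startswith(p) for p in patterns):
        return True
    return any(step_id.lower().startswith(p) for p in patterns)


def variable_key(call_id: str, step_name: str, step_id: str, param: str) -> str:
    """Return 'conversation.key' or 'conversation.step_id.key'."""
    if use_direct_format(call_id, step_name, step_id):
        return f"conversation.{param}"
    return f"conversation.{step_id}.{param}"


# Matches the examples above: Start nodes store directly, Code nodes keep a prefix.
assert variable_key("Start", "start", "start_001", "user_name") == "conversation.user_name"
assert variable_key("Code", "数据处理", "process_001", "result") == "conversation.process_001.result"
```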