diff --git a/apps/main.py b/apps/main.py index c4ca2bfb116db11624f4e4fbbe4e876a86e5f1eb..0a18fdf5a6f0d1fdeb6eec267b066610faea52bf 100644 --- a/apps/main.py +++ b/apps/main.py @@ -36,6 +36,7 @@ from apps.routers import ( record, service, user, + variable, ) from apps.scheduler.pool.pool import Pool @@ -66,6 +67,7 @@ app.include_router(llm.router) app.include_router(mcp_service.router) app.include_router(flow.router) app.include_router(user.router) +app.include_router(variable.router) # logger配置 LOGGER_FORMAT = "%(funcName)s() - %(message)s" @@ -87,6 +89,10 @@ async def init_resources() -> None: await LanceDB().init() await Pool.init() TokenCalculator() + + # 初始化变量系统 + from apps.scheduler.variable.pool import get_variable_pool + await get_variable_pool() # 运行 if __name__ == "__main__": diff --git a/apps/routers/auth.py b/apps/routers/auth.py index 1cba5ed629d90776b59ae3bf652381318dcabf0e..f77c81ee4d3a8813ec36d4fda71fdd1a4aa60e71 100644 --- a/apps/routers/auth.py +++ b/apps/routers/auth.py @@ -149,7 +149,7 @@ async def oidc_redirect() -> JSONResponse: # TODO(zwt): OIDC主动触发logout @router.post("/logout", dependencies=[Depends(verify_user)], response_model=ResponseData) -async def oidc_logout(token: str) -> JSONResponse: +async def oidc_logout(token: str) -> JSONResponse | None: """OIDC主动触发登出""" diff --git a/apps/routers/variable.py b/apps/routers/variable.py new file mode 100644 index 0000000000000000000000000000000000000000..bf930584f8c20ba828b018cebebe98f693d53c9a --- /dev/null +++ b/apps/routers/variable.py @@ -0,0 +1,477 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""FastAPI 变量管理 API""" + +from typing import Annotated, List, Optional + +from fastapi import APIRouter, Body, Depends, HTTPException, Query, status +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field + +from apps.dependency import get_user +from apps.dependency.user import verify_user +from apps.scheduler.variable.pool import get_variable_pool +from apps.scheduler.variable.type import VariableType, VariableScope +from apps.scheduler.variable.parser import VariableParser +from apps.schemas.response_data import ResponseData + +router = APIRouter( + prefix="/api/variable", + tags=["variable"], + dependencies=[ + Depends(verify_user), + ], +) + + +# 请求和响应模型 +class CreateVariableRequest(BaseModel): + """创建变量请求""" + name: str = Field(description="变量名称") + var_type: VariableType = Field(description="变量类型") + scope: VariableScope = Field(description="变量作用域") + value: Optional[str] = Field(default=None, description="变量值") + description: Optional[str] = Field(default=None, description="变量描述") + flow_id: Optional[str] = Field(default=None, description="流程ID(环境级和对话级变量必需)") + + +class UpdateVariableRequest(BaseModel): + """更新变量请求""" + value: Optional[str] = Field(default=None, description="新的变量值") + var_type: Optional[VariableType] = Field(default=None, description="新的变量类型") + description: Optional[str] = Field(default=None, description="新的变量描述") + + +class VariableResponse(BaseModel): + """变量响应""" + name: str = Field(description="变量名称") + var_type: str = Field(description="变量类型") + scope: str = Field(description="变量作用域") + value: str = Field(description="变量值") + description: Optional[str] = Field(description="变量描述") + created_at: str = Field(description="创建时间") + updated_at: str = Field(description="更新时间") + + +class VariableListResponse(BaseModel): + """变量列表响应""" + variables: List[VariableResponse] = Field(description="变量列表") + total: int = Field(description="总数量") + + +class ParseTemplateRequest(BaseModel): + """解析模板请求""" + template: str = Field(description="包含变量引用的模板") + flow_id: Optional[str] = Field(default=None, description="流程ID") + + +class ParseTemplateResponse(BaseModel): + """解析模板响应""" + parsed_template: str = Field(description="解析后的模板") + variables_used: List[str] = Field(description="使用的变量引用") + + +class ValidateTemplateResponse(BaseModel): + """验证模板响应""" + is_valid: bool = Field(description="是否有效") + invalid_references: List[str] = Field(description="无效的变量引用") + + +@router.post( + "/create", + responses={ + status.HTTP_200_OK: {"model": ResponseData}, + status.HTTP_400_BAD_REQUEST: {"model": ResponseData}, + status.HTTP_403_FORBIDDEN: {"model": ResponseData}, + }, +) +async def create_variable( + user_sub: Annotated[str, Depends(get_user)], + request: CreateVariableRequest = Body(...), +) -> ResponseData: + """创建变量""" + try: + # 验证作用域权限 + if request.scope == VariableScope.SYSTEM: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="不允许创建系统级变量" + ) + + pool = await get_variable_pool() + + # 创建变量 + variable = await pool.add_variable( + name=request.name, + var_type=request.var_type, + scope=request.scope, + value=request.value, + description=request.description, + user_sub=None, # 不再支持用户级变量 + flow_id=request.flow_id if request.scope in [VariableScope.ENVIRONMENT, VariableScope.CONVERSATION] else None, + ) + + # 如果是对话级变量,刷新缓存确保数据一致性 + if request.scope == VariableScope.CONVERSATION and request.flow_id: + await pool.refresh_conversation_cache(request.flow_id) + + return ResponseData( + code=200, + message="变量创建成功", + result={"variable_name": variable.name}, + ) + + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=str(e) + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"创建变量失败: {str(e)}" + ) + + +@router.put( + "/update", + responses={ + status.HTTP_200_OK: {"model": ResponseData}, + status.HTTP_400_BAD_REQUEST: {"model": ResponseData}, + status.HTTP_403_FORBIDDEN: {"model": ResponseData}, + status.HTTP_404_NOT_FOUND: {"model": ResponseData}, + }, +) +async def update_variable( + user_sub: Annotated[str, Depends(get_user)], + name: str = Query(..., description="变量名称"), + scope: VariableScope = Query(..., description="变量作用域"), + flow_id: Optional[str] = Query(default=None, description="流程ID(环境级和对话级变量必需)"), + request: UpdateVariableRequest = Body(...), +) -> ResponseData: + """更新变量值""" + try: + pool = await get_variable_pool() + + # 更新变量 + variable = await pool.update_variable( + name=name, + scope=scope, + value=request.value, + var_type=request.var_type, + description=request.description, + user_sub=None, # 不再支持用户级变量 + flow_id=flow_id, + ) + + # 如果是对话级变量,刷新缓存确保数据一致性 + if scope == VariableScope.CONVERSATION and flow_id: + await pool.refresh_conversation_cache(flow_id) + + return ResponseData( + code=200, + message="变量更新成功", + result={"variable_name": variable.name} + ) + + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(e) + ) + except PermissionError as e: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(e) + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"更新变量失败: {str(e)}" + ) + + +@router.delete( + "/delete", + responses={ + status.HTTP_200_OK: {"model": ResponseData}, + status.HTTP_403_FORBIDDEN: {"model": ResponseData}, + status.HTTP_404_NOT_FOUND: {"model": ResponseData}, + }, +) +async def delete_variable( + user_sub: Annotated[str, Depends(get_user)], + name: str = Query(..., description="变量名称"), + scope: VariableScope = Query(..., description="变量作用域"), + flow_id: Optional[str] = Query(default=None, description="流程ID(环境级和对话级变量必需)"), +) -> ResponseData: + """删除变量""" + try: + pool = await get_variable_pool() + + # 删除变量 + success = await pool.delete_variable( + name=name, + scope=scope, + user_sub=None, # 不再支持用户级变量 + flow_id=flow_id, + ) + + if not success: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="变量不存在" + ) + + # 如果是对话级变量,刷新缓存确保数据一致性 + if scope == VariableScope.CONVERSATION and flow_id: + await pool.refresh_conversation_cache(flow_id) + + return ResponseData( + code=200, + message="变量删除成功", + result={"variable_name": name} + ) + + except ValueError as e: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(e) + ) + except PermissionError as e: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(e) + ) + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"删除变量失败: {str(e)}" + ) + + +@router.get( + "/get", + responses={ + status.HTTP_200_OK: {"model": VariableResponse}, + status.HTTP_404_NOT_FOUND: {"model": ResponseData}, + }, +) +async def get_variable( + user_sub: Annotated[str, Depends(get_user)], + name: str = Query(..., description="变量名称"), + scope: VariableScope = Query(..., description="变量作用域"), + flow_id: Optional[str] = Query(default=None, description="流程ID(环境级和对话级变量必需)"), +) -> VariableResponse: + """获取单个变量""" + try: + pool = await get_variable_pool() + + # 获取变量 + variable = await pool.get_variable( + name=name, + scope=scope, + user_sub=None, # 不再支持用户级变量 + flow_id=flow_id if scope in [VariableScope.ENVIRONMENT, VariableScope.CONVERSATION] else None, + ) + + if not variable: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="变量不存在" + ) + + # 检查权限 + if not variable.can_access(user_sub): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="没有权限访问此变量" + ) + + # 构建响应 + var_dict = variable.to_dict() + return VariableResponse( + name=variable.name, + var_type=variable.var_type.value, + scope=variable.scope.value, + value=str(var_dict["value"]) if var_dict["value"] is not None else "", + description=variable.metadata.description, + created_at=variable.metadata.created_at.isoformat(), + updated_at=variable.metadata.updated_at.isoformat(), + ) + + except HTTPException: + raise + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"获取变量失败: {str(e)}" + ) + + +@router.get( + "/list", + responses={ + status.HTTP_200_OK: {"model": VariableListResponse}, + }, +) +async def list_variables( + user_sub: Annotated[str, Depends(get_user)], + scope: VariableScope = Query(..., description="变量作用域"), + flow_id: Optional[str] = Query(default=None, description="流程ID(环境级和对话级变量必需)"), +) -> VariableListResponse: + """列出指定作用域的变量""" + try: + pool = await get_variable_pool() + + # 获取变量列表 + variables = await pool.list_variables( + scope=scope, + user_sub=None, # 不再支持用户级变量 + flow_id=flow_id if scope in [VariableScope.ENVIRONMENT, VariableScope.CONVERSATION] else None, + ) + + # 过滤权限并构建响应 + filtered_variables = [] + for variable in variables: + if variable.can_access(user_sub): + var_dict = variable.to_dict() + filtered_variables.append(VariableResponse( + name=variable.name, + var_type=variable.var_type.value, + scope=variable.scope.value, + value=str(var_dict["value"]) if var_dict["value"] is not None else "", + description=variable.metadata.description, + created_at=variable.metadata.created_at.isoformat(), + updated_at=variable.metadata.updated_at.isoformat(), + )) + + return VariableListResponse( + variables=filtered_variables, + total=len(filtered_variables) + ) + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"获取变量列表失败: {str(e)}" + ) + + +@router.post( + "/parse", + responses={ + status.HTTP_200_OK: {"model": ParseTemplateResponse}, + status.HTTP_400_BAD_REQUEST: {"model": ResponseData}, + }, +) +async def parse_template( + user_sub: Annotated[str, Depends(get_user)], + request: ParseTemplateRequest = Body(...), +) -> ParseTemplateResponse: + """解析模板中的变量引用""" + try: + # 创建变量解析器 + parser = VariableParser( + user_sub=user_sub, + flow_id=request.flow_id, + conversation_id=None, # 不再使用conversation_id + ) + + # 解析模板 + parsed_template = await parser.parse_template(request.template) + + # 提取使用的变量 + variables_used = await parser.extract_variables(request.template) + + return ParseTemplateResponse( + parsed_template=parsed_template, + variables_used=variables_used + ) + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"解析模板失败: {str(e)}" + ) + + +@router.post( + "/validate", + responses={ + status.HTTP_200_OK: {"model": ValidateTemplateResponse}, + status.HTTP_400_BAD_REQUEST: {"model": ResponseData}, + }, +) +async def validate_template( + user_sub: Annotated[str, Depends(get_user)], + request: ParseTemplateRequest = Body(...), +) -> ValidateTemplateResponse: + """验证模板中的变量引用是否有效""" + try: + # 创建变量解析器 + parser = VariableParser( + user_sub=user_sub, + flow_id=request.flow_id, + conversation_id=None, # 不再使用conversation_id + ) + + # 验证模板 + is_valid, invalid_refs = await parser.validate_template(request.template) + + return ValidateTemplateResponse( + is_valid=is_valid, + invalid_references=invalid_refs + ) + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"验证模板失败: {str(e)}" + ) + + +@router.get( + "/types", + responses={ + status.HTTP_200_OK: {"model": ResponseData}, + }, +) +async def get_variable_types() -> ResponseData: + """获取支持的变量类型列表""" + return ResponseData( + code=200, + message="获取变量类型成功", + result={ + "types": [vtype.value for vtype in VariableType], + "scopes": [scope.value for scope in VariableScope], + } + ) + + +@router.post( + "/clear-conversation", + responses={ + status.HTTP_200_OK: {"model": ResponseData}, + }, +) +async def clear_conversation_variables( + user_sub: Annotated[str, Depends(get_user)], + flow_id: str = Query(..., description="流程ID"), +) -> ResponseData: + """清空指定工作流的对话级变量""" + try: + pool = await get_variable_pool() + # 清空工作流的对话级变量 + await pool.clear_conversation_variables(flow_id) + + return ResponseData( + code=200, + message="工作流对话变量已清空", + result={"flow_id": flow_id} + ) + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"清空对话变量失败: {str(e)}" + ) \ No newline at end of file diff --git a/apps/scheduler/call/__init__.py b/apps/scheduler/call/__init__.py index 2ee8b862885b0d88e519bf00a1678a49863df2bd..2a75d042dae1e2ae38a627f3afdf2e67640c7f30 100644 --- a/apps/scheduler/call/__init__.py +++ b/apps/scheduler/call/__init__.py @@ -2,16 +2,20 @@ """Agent工具部分""" from apps.scheduler.call.api.api import API +from apps.scheduler.call.code.code import Code from apps.scheduler.call.graph.graph import Graph from apps.scheduler.call.llm.llm import LLM from apps.scheduler.call.mcp.mcp import MCP from apps.scheduler.call.rag.rag import RAG +from apps.scheduler.call.reply.direct_reply import DirectReply from apps.scheduler.call.sql.sql import SQL from apps.scheduler.call.suggest.suggest import Suggestion # 只包含需要在编排界面展示的工具 __all__ = [ "API", + "Code", + "DirectReply", "LLM", "MCP", "RAG", diff --git a/apps/scheduler/call/code/__init__.py b/apps/scheduler/call/code/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..44da1c5078f9f6b32fec8cf3e23e6984aa1e8948 --- /dev/null +++ b/apps/scheduler/call/code/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""代码执行工具""" + +from apps.scheduler.call.code.code import Code + +__all__ = ["Code"] diff --git a/apps/scheduler/call/code/code.py b/apps/scheduler/call/code/code.py new file mode 100644 index 0000000000000000000000000000000000000000..ff2bbdc24fcc21322e995f79d5a6a9482e9094eb --- /dev/null +++ b/apps/scheduler/call/code/code.py @@ -0,0 +1,161 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""代码执行工具""" + +import json +import logging +from collections.abc import AsyncGenerator +from typing import Any + +import httpx +from pydantic import Field + +from apps.common.config import Config +from apps.scheduler.call.code.schema import CodeInput, CodeOutput +from apps.scheduler.call.core import CoreCall +from apps.schemas.enum_var import CallOutputType +from apps.schemas.scheduler import ( + CallError, + CallInfo, + CallOutputChunk, + CallVars, +) + +logger = logging.getLogger(__name__) + + +class Code(CoreCall, input_model=CodeInput, output_model=CodeOutput): + """代码执行工具""" + + to_user: bool = Field(default=True) + + # 代码执行参数 + code: str = Field(description="要执行的代码", default="") + code_type: str = Field(description="代码类型,支持python、javascript、bash", default="python") + security_level: str = Field(description="安全等级,low或high", default="low") + timeout_seconds: int = Field(description="超时时间(秒)", default=30, ge=1, le=300) + memory_limit_mb: int = Field(description="内存限制(MB)", default=128, ge=1, le=1024) + cpu_limit: float = Field(description="CPU限制", default=0.5, ge=0.1, le=2.0) + + + @classmethod + def info(cls) -> CallInfo: + """返回Call的名称和描述""" + return CallInfo(name="代码执行", description="在安全的沙箱环境中执行Python、JavaScript、Bash代码。") + + + async def _init(self, call_vars: CallVars) -> CodeInput: + """初始化代码执行工具""" + # 构造用户信息 + user_info = { + "user_id": call_vars.ids.user_sub, + "username": call_vars.ids.user_sub, # 可以从其他地方获取真实用户名 + "permissions": ["execute"] + } + + return CodeInput( + code=self.code, + code_type=self.code_type, + user_info=user_info, + security_level=self.security_level, + timeout_seconds=self.timeout_seconds, + memory_limit_mb=self.memory_limit_mb, + cpu_limit=self.cpu_limit, + ) + + + async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: + """执行代码""" + data = CodeInput(**input_data) + + try: + # 获取sandbox服务地址 + config = Config().get_config() + sandbox_url = config.sandbox.sandbox_service + + # 构造请求数据 + request_data = { + "code": data.code, + "code_type": data.code_type, + "user_info": data.user_info, + "security_level": data.security_level, + "timeout_seconds": data.timeout_seconds, + "memory_limit_mb": data.memory_limit_mb, + "cpu_limit": data.cpu_limit, + } + + # 发送执行请求 + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.post( + f"{sandbox_url.rstrip('/')}/execute", + json=request_data, + headers={"Content-Type": "application/json"} + ) + + if response.status_code != 200: + raise CallError( + message=f"代码执行服务请求失败: {response.status_code}", + data={"status_code": response.status_code, "response": response.text} + ) + + result = response.json() + task_id = result.get("task_id", "") + + # 轮询获取结果 + if task_id: + result = await self._wait_for_result(sandbox_url, task_id) + + # 返回结果 + yield CallOutputChunk( + type=CallOutputType.DATA, + content=CodeOutput( + task_id=task_id, + status=result.get("status", "unknown"), + result=result.get("result", {}), + output=result.get("output", ""), + error=result.get("error", ""), + ).model_dump(by_alias=True, exclude_none=True), + ) + + except httpx.TimeoutException: + raise CallError(message="代码执行超时", data={}) + except httpx.RequestError as e: + raise CallError(message=f"请求sandbox服务失败: {e!s}", data={}) + except Exception as e: + raise CallError(message=f"代码执行失败: {e!s}", data={}) + + + async def _wait_for_result(self, sandbox_url: str, task_id: str, max_attempts: int = 30) -> dict[str, Any]: + """等待任务执行完成""" + import asyncio + + async with httpx.AsyncClient(timeout=10.0) as client: + for _ in range(max_attempts): + try: + # 获取任务状态 + response = await client.get(f"{sandbox_url.rstrip('/')}/task/{task_id}/status") + if response.status_code == 200: + status_result = response.json() + status = status_result.get("status", "") + + # 如果任务完成,获取结果 + if status in ["completed", "failed", "cancelled"]: + result_response = await client.get(f"{sandbox_url.rstrip('/')}/task/{task_id}/result") + if result_response.status_code == 200: + return result_response.json() + else: + return {"status": status, "error": "无法获取结果"} + + # 如果任务仍在运行,继续等待 + if status in ["pending", "running"]: + await asyncio.sleep(1) + continue + + # 其他状态或请求失败,等待后重试 + await asyncio.sleep(1) + + except Exception: + # 请求异常,等待后重试 + await asyncio.sleep(1) + + # 超时返回 + return {"status": "timeout", "error": "等待任务完成超时"} \ No newline at end of file diff --git a/apps/scheduler/call/code/schema.py b/apps/scheduler/call/code/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..71722d84a178be0f44138b59a865c29d96a693c7 --- /dev/null +++ b/apps/scheduler/call/code/schema.py @@ -0,0 +1,30 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""代码执行工具的数据结构""" + +from typing import Any + +from pydantic import BaseModel, Field + +from apps.scheduler.call.core import DataBase + + +class CodeInput(DataBase): + """代码执行工具的输入""" + + code: str = Field(description="要执行的代码") + code_type: str = Field(description="代码类型,支持python、javascript、bash") + user_info: dict[str, Any] = Field(description="用户信息", default={}) + security_level: str = Field(description="安全等级,low或high", default="low") + timeout_seconds: int = Field(description="超时时间(秒)", default=30, ge=1, le=300) + memory_limit_mb: int = Field(description="内存限制(MB)", default=128, ge=1, le=1024) + cpu_limit: float = Field(description="CPU限制", default=0.5, ge=0.1, le=2.0) + + +class CodeOutput(DataBase): + """代码执行工具的输出""" + + task_id: str = Field(description="任务ID") + status: str = Field(description="任务状态") + result: dict[str, Any] = Field(description="执行结果", default={}) + output: str = Field(description="执行输出", default="") + error: str = Field(description="错误信息", default="") \ No newline at end of file diff --git a/apps/scheduler/call/reply/__init__.py b/apps/scheduler/call/reply/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5588bcc298b338a542312a22843494eb8fe1d890 --- /dev/null +++ b/apps/scheduler/call/reply/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""Reply工具模块""" + +from apps.scheduler.call.reply.direct_reply import DirectReply + +__all__ = [ + "DirectReply", +] diff --git a/apps/scheduler/call/reply/direct_reply.py b/apps/scheduler/call/reply/direct_reply.py new file mode 100644 index 0000000000000000000000000000000000000000..d1f67f285a34f542e998be651e2d175a84556eb3 --- /dev/null +++ b/apps/scheduler/call/reply/direct_reply.py @@ -0,0 +1,72 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""直接回复工具""" + +import logging +from collections.abc import AsyncGenerator +from typing import Any + +from pydantic import Field + +from apps.scheduler.call.core import CoreCall +from apps.scheduler.call.reply.schema import DirectReplyInput, DirectReplyOutput +from apps.scheduler.variable.integration import VariableIntegration +from apps.schemas.enum_var import CallOutputType +from apps.schemas.scheduler import ( + CallError, + CallInfo, + CallOutputChunk, + CallVars, +) + +logger = logging.getLogger(__name__) + + +class DirectReply(CoreCall, input_model=DirectReplyInput, output_model=DirectReplyOutput): + """直接回复工具,支持变量引用语法""" + + to_user: bool = Field(default=True) + + @classmethod + def info(cls) -> CallInfo: + """返回Call的名称和描述""" + return CallInfo(name="直接回复", description="直接回复用户输入的内容,支持变量插入") + + async def _init(self, call_vars: CallVars) -> DirectReplyInput: + """初始化DirectReply工具""" + answer = getattr(self, 'answer', '') + return DirectReplyInput(answer=answer) + + async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: + """执行直接回复""" + data = DirectReplyInput(**input_data) + + try: + # 解析answer中的变量引用语法 + parsed_answer = await VariableIntegration.parse_call_input( + data.answer, + user_sub=self._sys_vars.ids.user_sub, + flow_id=self._sys_vars.ids.flow_id, + conversation_id=self._sys_vars.ids.session_id + ) + + # 如果解析结果是字符串,直接使用;否则转换为字符串 + if isinstance(parsed_answer, str): + final_answer = parsed_answer + else: + final_answer = str(parsed_answer) + + logger.info(f"[DirectReply] 原始答案: {data.answer}") + logger.info(f"[DirectReply] 解析后答案: {final_answer}") + + # 直接返回处理后的内容 + yield CallOutputChunk( + type=CallOutputType.TEXT, + content=final_answer + ) + + except Exception as e: + logger.error(f"[DirectReply] 处理回复内容失败: {e}") + raise CallError( + message=f"直接回复处理失败:{e!s}", + data={"original_answer": data.answer} + ) from e diff --git a/apps/scheduler/call/reply/schema.py b/apps/scheduler/call/reply/schema.py new file mode 100644 index 0000000000000000000000000000000000000000..f13dc636284dc6da764aa73fc62e99405674febe --- /dev/null +++ b/apps/scheduler/call/reply/schema.py @@ -0,0 +1,18 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""DirectReply工具的输入输出定义""" + +from pydantic import Field + +from apps.scheduler.call.core import DataBase + + +class DirectReplyInput(DataBase): + """定义DirectReply工具调用的输入""" + + answer: str = Field(description="直接回复的内容,支持变量引用语法") + + +class DirectReplyOutput(DataBase): + """定义DirectReply工具调用的输出""" + + message: str = Field(description="处理后的回复消息") \ No newline at end of file diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 6b3451fa9ccd92b8f9b2d497f3983a64ff3a6981..e04b58e85b6a827ed6cd38f008909ce06539a1d7 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -13,6 +13,7 @@ from pydantic import ConfigDict from apps.scheduler.call.core import CoreCall from apps.scheduler.call.empty import Empty from apps.scheduler.call.facts.facts import FactsCall +from apps.scheduler.call.reply.direct_reply import DirectReply from apps.scheduler.call.slot.schema import SlotOutput from apps.scheduler.call.slot.slot import Slot from apps.scheduler.call.summary.summary import Summary @@ -70,6 +71,8 @@ class StepExecutor(BaseExecutor): return FactsCall if call_id == SpecialCallType.SLOT.value: return Slot + if call_id == SpecialCallType.DIRECT_REPLY.value: + return DirectReply # 从Pool中获取对应的Call call_cls: type[CoreCall] = await Pool().get_call(call_id) diff --git a/apps/scheduler/variable/README.md b/apps/scheduler/variable/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2b6313960e801ad521dd1e8870dde4577a033b3e --- /dev/null +++ b/apps/scheduler/variable/README.md @@ -0,0 +1,152 @@ +# 工作流变量管理系统 + +## 概述 + +工作流变量管理系统为Euler Copilot Framework提供了全面的变量管理功能,支持在工作流执行过程中进行变量的定义、存储、解析和使用。 + +## 功能特性 + +### 1. 多种变量类型支持 +- **基础类型**: String、Number、Boolean、Object、File +- **安全类型**: Secret(加密存储的密钥变量) +- **数组类型**: Array[Any]、Array[String]、Array[Number]、Array[Object]、Array[File]、Array[Boolean]、Array[Secret] + +### 2. 四种作用域 +- **系统级变量** (`system`): 只读,包含query、files、app_id等系统信息 +- **用户级变量** (`user`): 跟随用户,如个人API密钥、配置信息 +- **环境级变量** (`env`): 跟随工作流,如流程配置参数 +- **对话级变量** (`conversation`): 单次对话内有效,支持局部作用域 + +### 3. 变量解析语法 +支持在模板中使用以下语法引用变量: +``` +{{sys.query}} # 系统变量 +{{user.api_key}} # 用户变量 +{{env.config.timeout}} # 环境变量(支持嵌套访问) +{{conversation.result}} # 对话变量 +``` + +### 4. 安全保障 +- Secret变量加密存储 +- 访问权限控制 +- 审计日志记录 +- 访问频率限制 +- 密钥强度检查 + +## API 接口 + +### 变量管理 +- `POST /api/variable/create` - 创建变量 +- `PUT /api/variable/update` - 更新变量值 +- `DELETE /api/variable/delete` - 删除变量 +- `GET /api/variable/get` - 获取单个变量 +- `GET /api/variable/list` - 列出变量 + +### 模板解析 +- `POST /api/variable/parse` - 解析模板中的变量引用 +- `POST /api/variable/validate` - 验证模板中的变量引用 + +### 系统信息 +- `GET /api/variable/types` - 获取支持的变量类型 +- `POST /api/variable/clear-conversation` - 清空对话变量 + +## 使用示例 + +### 1. 创建用户级变量 +```json +{ + "name": "openai_api_key", + "var_type": "secret", + "scope": "user", + "value": "sk-xxxxxxxxxxxxxxxx", + "description": "OpenAI API密钥" +} +``` + +### 2. 在工作流中使用变量 +```yaml +steps: + llm_call: + node: "llm" + params: + system_prompt: "你是一个助手,用户查询:{{sys.query}}" + api_key: "{{user.openai_api_key}}" + temperature: "{{env.llm_config.temperature}}" +``` + +### 3. 在对话中设置临时变量 +工作流执行过程中可以动态设置对话级变量: +```python +await VariableIntegration.add_conversation_variable( + name="processing_result", + value={"status": "completed", "data": result}, + conversation_id=conversation_id, + var_type_str="object" +) +``` + +## 集成到工作流 + +系统自动集成到现有的工作流调度器中: + +1. **系统变量自动初始化** - 每次工作流启动时自动设置系统变量 +2. **输入参数解析** - Call的输入参数自动解析变量引用 +3. **模板渲染** - LLM提示词等模板自动替换变量 +4. **输出变量提取** - 步骤输出可自动提取为对话级变量 + +## 安全机制 + +### Secret变量保护 +- 使用AES加密存储 +- 基于用户ID和变量名生成唯一加密密钥 +- 显示时自动打码 +- 支持密钥轮换 + +### 访问控制 +- 权限验证(用户只能访问自己的变量) +- 访问频率限制(防止暴力破解) +- IP地址验证(可选) +- 失败尝试监控和临时封禁 + +### 审计日志 +- 完整的访问日志记录 +- 密钥访问审计(记录哈希值而非原始值) +- 自动清理过期日志 + +## 开发指南 + +### 添加新的变量类型 +1. 在`VariableType`枚举中添加新类型 +2. 创建继承自`BaseVariable`的新变量类 +3. 在`VARIABLE_CLASS_MAP`中注册映射关系 + +### 扩展变量解析 +1. 修改`VariableParser.VARIABLE_PATTERN`正则表达式 +2. 在`resolve_variable_reference`方法中添加新的解析逻辑 + +### 自定义安全策略 +1. 继承`SecretVariableSecurity`类 +2. 重写相关的安全检查方法 +3. 在应用启动时注册自定义安全管理器 + +## 故障排除 + +### 常见问题 +1. **变量解析失败** - 检查变量名是否存在,作用域是否正确 +2. **Secret变量解密失败** - 可能是加密密钥损坏,需要重新设置 +3. **访问被拒绝** - 检查用户权限和访问频率限制 + +### 日志调试 +启用DEBUG日志级别查看详细的变量解析过程: +```python +import logging +logging.getLogger('apps.scheduler.variable').setLevel(logging.DEBUG) +``` + +## 最佳实践 + +1. **变量命名** - 使用描述性的名称,避免特殊字符 +2. **作用域选择** - 根据变量的生命周期选择合适的作用域 +3. **Secret管理** - 定期轮换密钥,使用强密码 +4. **性能优化** - 避免在循环中频繁解析变量 +5. **安全审计** - 定期检查访问日志,监控异常行为 \ No newline at end of file diff --git a/apps/scheduler/variable/__init__.py b/apps/scheduler/variable/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9ee14662d93577868a521054c932c61542731396 --- /dev/null +++ b/apps/scheduler/variable/__init__.py @@ -0,0 +1,60 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""工作流变量管理模块 + +这个模块提供了完整的变量管理功能,包括: +- 多种变量类型支持(String、Number、Boolean、Object、Secret、File、Array等) +- 四种作用域支持(系统级、用户级、环境级、对话级) +- 变量解析和模板替换 +- 安全的密钥变量处理 +- 与工作流调度器的集成 +""" + +from .type import VariableType, VariableScope +from .base import BaseVariable, VariableMetadata +from .variables import ( + StringVariable, + NumberVariable, + BooleanVariable, + ObjectVariable, + SecretVariable, + FileVariable, + ArrayVariable, + create_variable, + VARIABLE_CLASS_MAP, +) +from .pool import VariablePool, get_variable_pool +from .parser import VariableParser, VariableReferenceBuilder, VariableContext +from .integration import VariableIntegration + +__all__ = [ + # 基础类型和枚举 + "VariableType", + "VariableScope", + "VariableMetadata", + "BaseVariable", + + # 具体变量类型 + "StringVariable", + "NumberVariable", + "BooleanVariable", + "ObjectVariable", + "SecretVariable", + "FileVariable", + "ArrayVariable", + + # 工厂函数和映射 + "create_variable", + "VARIABLE_CLASS_MAP", + + # 变量池 + "VariablePool", + "get_variable_pool", + + # 解析器 + "VariableParser", + "VariableReferenceBuilder", + "VariableContext", + + # 集成功能 + "VariableIntegration", +] diff --git a/apps/scheduler/variable/base.py b/apps/scheduler/variable/base.py new file mode 100644 index 0000000000000000000000000000000000000000..11e3c246e02ee84406b4250d40473a8ba3097173 --- /dev/null +++ b/apps/scheduler/variable/base.py @@ -0,0 +1,183 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional +from pydantic import BaseModel, Field +from datetime import datetime, UTC + +from .type import VariableType, VariableScope + + +class VariableMetadata(BaseModel): + """变量元数据""" + name: str = Field(description="变量名称") + var_type: VariableType = Field(description="变量类型") + scope: VariableScope = Field(description="变量作用域") + description: Optional[str] = Field(default=None, description="变量描述") + created_at: datetime = Field(default_factory=lambda: datetime.now(UTC), description="创建时间") + updated_at: datetime = Field(default_factory=lambda: datetime.now(UTC), description="更新时间") + created_by: Optional[str] = Field(default=None, description="创建者用户ID") + + # 作用域相关属性 + user_sub: Optional[str] = Field(default=None, description="用户级变量的用户ID") + flow_id: Optional[str] = Field(default=None, description="环境级/对话级变量的流程ID") + + # 安全相关属性 + is_encrypted: bool = Field(default=False, description="是否加密存储") + access_permissions: Optional[Dict[str, Any]] = Field(default=None, description="访问权限") + + +class BaseVariable(ABC): + """变量处理基类""" + + def __init__(self, metadata: VariableMetadata, value: Any = None): + """初始化变量 + + Args: + metadata: 变量元数据 + value: 变量值 + """ + self.metadata = metadata + self._value = None # 先设置为None + self._original_value = value + self._initializing = True # 标记正在初始化 + + # 通过setter设置值,触发类型验证 + if value is not None: + self.value = value # 这会触发setter和类型验证 + else: + self._value = value # 如果是None则直接设置 + + self._initializing = False # 初始化完成 + + @property + def name(self) -> str: + """获取变量名称""" + return self.metadata.name + + @property + def var_type(self) -> VariableType: + """获取变量类型""" + return self.metadata.var_type + + @property + def scope(self) -> VariableScope: + """获取变量作用域""" + return self.metadata.scope + + @property + def value(self) -> Any: + """获取变量值""" + return self._value + + @value.setter + def value(self, new_value: Any) -> None: + """设置变量值""" + # 只有在非初始化阶段才检查系统级变量的修改限制 + if self.scope == VariableScope.SYSTEM and not getattr(self, '_initializing', False): + raise ValueError("系统级变量不能修改") + + # 验证类型 + if not self._validate_type(new_value): + raise TypeError(f"变量 {self.name} 的值类型不匹配,期望: {self.var_type}") + + self._value = new_value + self.metadata.updated_at = datetime.now(UTC) + + @abstractmethod + def _validate_type(self, value: Any) -> bool: + """验证值的类型是否正确 + + Args: + value: 要验证的值 + + Returns: + bool: 类型是否正确 + """ + pass + + @abstractmethod + def to_string(self) -> str: + """将变量转换为字符串表示 + + Returns: + str: 字符串表示 + """ + pass + + @abstractmethod + def to_dict(self) -> Dict[str, Any]: + """将变量转换为字典表示 + + Returns: + Dict[str, Any]: 字典表示 + """ + pass + + @abstractmethod + def serialize(self) -> Dict[str, Any]: + """序列化变量用于存储 + + Returns: + Dict[str, Any]: 序列化后的数据 + """ + pass + + @classmethod + @abstractmethod + def deserialize(cls, data: Dict[str, Any]) -> "BaseVariable": + """从序列化数据恢复变量 + + Args: + data: 序列化数据 + + Returns: + BaseVariable: 恢复的变量实例 + """ + pass + + def copy(self) -> "BaseVariable": + """创建变量的副本 + + Returns: + BaseVariable: 变量副本 + """ + # 深拷贝元数据和值 + import copy + new_metadata = copy.deepcopy(self.metadata) + new_value = copy.deepcopy(self._value) + return self.__class__(new_metadata, new_value) + + def reset(self) -> None: + """重置变量到初始值""" + if self.scope == VariableScope.SYSTEM: + raise ValueError("系统级变量不能重置") + + self._value = self._original_value + self.metadata.updated_at = datetime.now(UTC) + + def can_access(self, user_sub: str) -> bool: + """检查用户是否有权限访问此变量 + + Args: + user_sub: 用户ID + + Returns: + bool: 是否有权限访问 + """ + # 系统级变量所有人都可以访问 + if self.scope == VariableScope.SYSTEM: + return True + + # 用户级变量只有创建者可以访问 + if self.scope == VariableScope.USER: + return self.metadata.user_sub == user_sub + + # 环境级和对话级变量根据上下文判断(这里简化处理) + return True + + def __str__(self) -> str: + """字符串表示""" + return f"{self.name}({self.var_type.value})={self.to_string()}" + + def __repr__(self) -> str: + """调试表示""" + return f"<{self.__class__.__name__}(name='{self.name}', type='{self.var_type.value}', scope='{self.scope.value}')>" \ No newline at end of file diff --git a/apps/scheduler/variable/integration.py b/apps/scheduler/variable/integration.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c54c9c596c3acc5389df254651a4414088daed --- /dev/null +++ b/apps/scheduler/variable/integration.py @@ -0,0 +1,377 @@ +"""变量解析与工作流调度器集成""" + +import logging +from typing import Any, Dict, List, Optional, Union + +from apps.scheduler.variable.parser import VariableParser +from apps.scheduler.variable.pool import get_variable_pool +from apps.scheduler.variable.type import VariableScope + +logger = logging.getLogger(__name__) + + +class VariableIntegration: + """变量解析集成类 - 为现有调度器提供变量功能""" + + @staticmethod + async def initialize_system_variables(context: Dict[str, Any]) -> None: + """初始化系统变量 + + Args: + context: 系统上下文信息,包括用户查询、文件等 + """ + try: + parser = VariableParser( + user_sub=context.get("user_sub"), + flow_id=context.get("flow_id"), + conversation_id=context.get("conversation_id") + ) + + # 更新系统变量 + await parser.update_system_variables(context) + + logger.info("系统变量已初始化") + except Exception as e: + logger.error(f"初始化系统变量失败: {e}") + raise + + @staticmethod + async def parse_call_input(input_data: Dict[str, Any], + user_sub: str, + flow_id: Optional[str] = None, + conversation_id: Optional[str] = None) -> Union[str, Dict, List]: + """解析Call输入中的变量引用 + + Args: + input_data: 输入数据 + user_sub: 用户ID + flow_id: 流程ID + conversation_id: 对话ID + + Returns: + Dict[str, Any]: 解析后的输入数据 + """ + try: + parser = VariableParser( + user_sub=user_sub, + flow_id=flow_id, + conversation_id=conversation_id + ) + + # 递归解析JSON模板中的变量引用 + parsed_input = await parser.parse_json_template(input_data) + + return parsed_input + + except Exception as e: + logger.warning(f"解析Call输入变量失败: {e}") + # 如果解析失败,返回原始输入 + return input_data + + @staticmethod + async def parse_template_string(template: str, + user_sub: str, + flow_id: Optional[str] = None, + conversation_id: Optional[str] = None) -> str: + """解析模板字符串中的变量引用 + + Args: + template: 模板字符串 + user_sub: 用户ID + flow_id: 流程ID + conversation_id: 对话ID + + Returns: + str: 解析后的字符串 + """ + try: + parser = VariableParser( + user_sub=user_sub, + flow_id=flow_id, + conversation_id=conversation_id + ) + + return await parser.parse_template(template) + except Exception as e: + logger.warning(f"解析模板字符串失败: {e}") + # 如果解析失败,返回原始模板 + return template + + @staticmethod + async def add_conversation_variable(name: str, + value: Any, + conversation_id: str, + var_type_str: str = "string") -> bool: + """添加对话级变量 + + Args: + name: 变量名 + value: 变量值 + conversation_id: 对话ID(在内部作为flow_id使用) + var_type_str: 变量类型字符串 + + Returns: + bool: 是否添加成功 + """ + try: + from apps.scheduler.variable.type import VariableType + + # 转换变量类型 + var_type = VariableType(var_type_str) + + pool = await get_variable_pool() + # 在内部将conversation_id作为flow_id传递 + await pool.add_variable( + name=name, + var_type=var_type, + scope=VariableScope.CONVERSATION, + value=value, + flow_id=conversation_id # 统一使用flow_id + ) + + logger.debug(f"已添加对话变量: {name} = {value}") + return True + except Exception as e: + logger.error(f"添加对话变量失败: {e}") + return False + + @staticmethod + async def update_conversation_variable(name: str, + value: Any, + conversation_id: str) -> bool: + """更新对话级变量 + + Args: + name: 变量名 + value: 新值 + conversation_id: 对话ID + + Returns: + bool: 是否更新成功 + """ + try: + pool = await get_variable_pool() + # 在内部将conversation_id作为flow_id传递 + await pool.update_variable( + name=name, + scope=VariableScope.CONVERSATION, + value=value, + flow_id=conversation_id # 统一使用flow_id + ) + + logger.debug(f"已更新对话变量: {name} = {value}") + return True + except Exception as e: + logger.error(f"更新对话变量失败: {e}") + return False + + @staticmethod + async def extract_output_variables(output_data: Dict[str, Any], + conversation_id: str, + step_name: str) -> None: + """从步骤输出中提取变量并设置为对话级变量 + + Args: + output_data: 步骤输出数据 + conversation_id: 对话ID + step_name: 步骤名称 + """ + try: + # 将整个输出作为对象变量存储 + await VariableIntegration.add_conversation_variable( + name=f"step_{step_name}_output", + value=output_data, + conversation_id=conversation_id, + var_type_str="object" + ) + + # 如果输出中有特定的变量定义,也可以单独提取 + if isinstance(output_data, dict): + for key, value in output_data.items(): + if key.startswith("var_"): + # 以 var_ 开头的字段被视为变量定义 + var_name = key[4:] # 移除 var_ 前缀 + await VariableIntegration.add_conversation_variable( + name=var_name, + value=value, + conversation_id=conversation_id, + var_type_str="string" + ) + + except Exception as e: + logger.error(f"提取输出变量失败: {e}") + + @staticmethod + async def clear_conversation_context(conversation_id: str) -> None: + """清理对话上下文中的变量 + + Args: + conversation_id: 对话ID + """ + try: + pool = await get_variable_pool() + await pool.clear_conversation_variables(conversation_id) + logger.info(f"已清理对话 {conversation_id} 的变量") + except Exception as e: + logger.error(f"清理对话变量失败: {e}") + + @staticmethod + async def validate_variable_references(template: str, + user_sub: str, + flow_id: Optional[str] = None, + conversation_id: Optional[str] = None) -> tuple[bool, list[str]]: + """验证模板中的变量引用是否有效 + + Args: + template: 模板字符串 + user_sub: 用户ID + flow_id: 流程ID + conversation_id: 对话ID + + Returns: + tuple[bool, list[str]]: (是否全部有效, 无效的变量引用列表) + """ + try: + parser = VariableParser( + user_sub=user_sub, + flow_id=flow_id, + conversation_id=conversation_id + ) + + return await parser.validate_template(template) + except Exception as e: + logger.error(f"验证变量引用失败: {e}") + return False, [str(e)] + + +# 为现有调度器提供的扩展方法 +class CoreCallExtension: + """CoreCall的变量支持扩展""" + + @staticmethod + async def enhanced_set_input(core_call_instance, executor) -> None: + """增强版的_set_input方法,支持变量解析 + + Args: + core_call_instance: CoreCall实例 + executor: StepExecutor实例 + """ + # 调用原始的_set_input方法 + original_set_input = getattr(core_call_instance.__class__, '_original_set_input', None) + if original_set_input: + await original_set_input(core_call_instance, executor) + else: + # 如果没有原始方法,执行基本的输入设置 + from apps.scheduler.call.core import CoreCall + core_call_instance._sys_vars = CoreCall._assemble_call_vars(executor) + input_data = await core_call_instance._init(core_call_instance._sys_vars) + core_call_instance.input = input_data.model_dump(by_alias=True, exclude_none=True) + + # 解析输入中的变量引用 + if hasattr(core_call_instance, 'input') and core_call_instance.input: + # 首先初始化系统变量 + context = { + "question": executor.question, + "user_sub": executor.task.ids.user_sub, + "flow_id": executor.task.state.flow_id if executor.task.state else None, + "session_id": executor.task.ids.session_id, + "app_id": executor.task.state.app_id if executor.task.state else None, + "conversation_id": executor.task.ids.session_id, # 使用session_id作为conversation_id + } + + await VariableIntegration.initialize_system_variables(context) + + # 解析输入变量 + core_call_instance.input = await VariableIntegration.parse_call_input( + core_call_instance.input, + user_sub=executor.task.ids.user_sub, + flow_id=executor.task.state.flow_id if executor.task.state else None, + conversation_id=executor.task.ids.session_id + ) + + +class LLMCallExtension: + """LLM Call的变量支持扩展""" + + @staticmethod + async def enhanced_prepare_message(llm_instance, call_vars) -> list[dict[str, Any]]: + """增强版的_prepare_message方法,支持变量解析 + + Args: + llm_instance: LLM实例 + call_vars: CallVars实例 + + Returns: + list[dict[str, Any]]: 解析后的消息列表 + """ + # 调用原始的_prepare_message方法 + original_prepare_message = getattr(llm_instance.__class__, '_original_prepare_message', None) + if original_prepare_message: + messages = await original_prepare_message(llm_instance, call_vars) + else: + # 如果没有原始方法,返回基本消息 + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": call_vars.question} + ] + + # 解析消息中的变量引用 + parsed_messages = [] + for message in messages: + if "content" in message and isinstance(message["content"], str): + parsed_content = await VariableIntegration.parse_template_string( + message["content"], + user_sub=call_vars.ids.user_sub, + flow_id=call_vars.ids.flow_id, + conversation_id=call_vars.ids.session_id + ) + parsed_messages.append({ + **message, + "content": parsed_content + }) + else: + parsed_messages.append(message) + + return parsed_messages + + +def monkey_patch_scheduler(): + """为现有调度器添加变量支持的猴子补丁""" + try: + # 为CoreCall添加变量支持 + from apps.scheduler.call.core import CoreCall + + # 保存原始方法 + original_set_input = getattr(CoreCall, '_set_input', None) + if original_set_input: + setattr(CoreCall, '_original_set_input', original_set_input) + + # 替换为增强版本 + setattr(CoreCall, '_set_input', CoreCallExtension.enhanced_set_input) + # 为LLM Call添加变量支持 + try: + from apps.scheduler.call.llm.llm import LLM + + # 保存原始方法 + original_prepare_message = getattr(LLM, '_prepare_message', None) + if original_prepare_message: + setattr(LLM, '_original_prepare_message', original_prepare_message) + + # 替换为增强版本 + setattr(LLM, '_prepare_message', LLMCallExtension.enhanced_prepare_message) + + except ImportError: + logger.warning("LLM Call 不存在,跳过LLM变量支持") + + logger.info("变量解析功能已集成到调度器中") + + except Exception as e: + logger.error(f"集成变量解析功能失败: {e}") + raise + + +# 在模块导入时自动应用补丁 +try: + monkey_patch_scheduler() +except Exception as e: + logger.error(f"自动应用变量解析补丁失败: {e}") \ No newline at end of file diff --git a/apps/scheduler/variable/parser.py b/apps/scheduler/variable/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..79cfad5fb6277e6a96c63719b414959624232eb2 --- /dev/null +++ b/apps/scheduler/variable/parser.py @@ -0,0 +1,370 @@ +import re +import logging +from typing import Any, Dict, List, Optional, Tuple, Union +import json +from datetime import datetime, UTC + +from .pool import get_variable_pool +from .type import VariableScope + +logger = logging.getLogger(__name__) + + +class VariableParser: + """变量解析器 - 解析和替换模板中的变量引用""" + + # 变量引用的正则表达式:{{scope.variable_name.nested_path}} + VARIABLE_PATTERN = re.compile(r'\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)\s*\}\}') + + def __init__(self, + user_sub: Optional[str] = None, + flow_id: Optional[str] = None, + conversation_id: Optional[str] = None): + """初始化变量解析器 + + Args: + user_sub: 用户ID + flow_id: 流程ID + conversation_id: 对话ID + """ + self.user_sub = user_sub + self.flow_id = flow_id + self.conversation_id = conversation_id + self._variable_pool = None + + async def _get_pool(self): + """获取变量池实例""" + if self._variable_pool is None: + self._variable_pool = await get_variable_pool() + return self._variable_pool + + async def parse_template(self, template: str) -> str: + """解析模板字符串,替换其中的变量引用 + + Args: + template: 包含变量引用的模板字符串 + + Returns: + str: 替换后的字符串 + """ + if not template: + return template + + pool = await self._get_pool() + + # 查找所有变量引用 + matches = self.VARIABLE_PATTERN.findall(template) + + # 替换每个变量引用 + result = template + for match in matches: + try: + # 解析变量引用 + value = await pool.resolve_variable_reference( + f"{{{{{match}}}}}", + user_sub=self.user_sub, + flow_id=self.flow_id, + conversation_id=self.conversation_id + ) + + # 转换为字符串 + str_value = self._convert_to_string(value) + + # 替换模板中的变量引用 + result = result.replace(f"{{{{{match}}}}}", str_value) + + logger.debug(f"已替换变量: {{{{{match}}}}} -> {str_value}") + + except Exception as e: + logger.warning(f"解析变量引用失败: {{{{{match}}}}} - {e}") + # 保持原始引用不变 + continue + + return result + + async def extract_variables(self, template: str) -> List[str]: + """提取模板中的所有变量引用 + + Args: + template: 模板字符串 + + Returns: + List[str]: 变量引用列表 + """ + if not template: + return [] + + matches = self.VARIABLE_PATTERN.findall(template) + return [f"{{{{{match}}}}}" for match in matches] + + async def validate_template(self, template: str) -> Tuple[bool, List[str]]: + """验证模板中的变量引用是否都存在 + + Args: + template: 模板字符串 + + Returns: + Tuple[bool, List[str]]: (是否全部有效, 无效的变量引用列表) + """ + if not template: + return True, [] + + pool = await self._get_pool() + matches = self.VARIABLE_PATTERN.findall(template) + invalid_refs = [] + + for match in matches: + try: + await pool.resolve_variable_reference( + f"{{{{{match}}}}}", + user_sub=self.user_sub, + flow_id=self.flow_id, + conversation_id=self.conversation_id + ) + except Exception: + invalid_refs.append(f"{{{{{match}}}}}") + + return len(invalid_refs) == 0, invalid_refs + + async def parse_json_template(self, json_template: Union[str, Dict, List]) -> Union[str, Dict, List]: + """解析JSON格式的模板,递归处理所有字符串值中的变量引用 + + Args: + json_template: JSON模板(字符串、字典或列表) + + Returns: + Union[str, Dict, List]: 解析后的JSON + """ + if isinstance(json_template, str): + return await self.parse_template(json_template) + elif isinstance(json_template, dict): + result = {} + for key, value in json_template.items(): + # 键也可能包含变量引用 + parsed_key = await self.parse_template(str(key)) + parsed_value = await self.parse_json_template(value) + result[parsed_key] = parsed_value + return result + elif isinstance(json_template, list): + result = [] + for item in json_template: + parsed_item = await self.parse_json_template(item) + result.append(parsed_item) + return result + else: + # 其他类型直接返回 + return json_template + + async def update_system_variables(self, context: Dict[str, Any]): + """更新系统变量的值 + + Args: + context: 系统上下文信息 + """ + pool = await self._get_pool() + + # 预定义的系统变量映射 + system_var_mappings = { + "query": context.get("question", ""), + "files": context.get("files", []), + "dialogue_count": context.get("dialogue_count", 0), + "app_id": context.get("app_id", ""), + "flow_id": context.get("flow_id", ""), + "user_id": context.get("user_sub", ""), + "session_id": context.get("session_id", ""), + "timestamp": datetime.now(UTC).timestamp(), + } + + # 更新系统变量 + for var_name, var_value in system_var_mappings.items(): + try: + variable = await pool.get_variable(var_name, VariableScope.SYSTEM) + if variable: + # 系统变量需要特殊处理,绕过只读限制 + variable._value = var_value + logger.debug(f"已更新系统变量: {var_name} = {var_value}") + except Exception as e: + logger.warning(f"更新系统变量失败: {var_name} - {e}") + + def _convert_to_string(self, value: Any) -> str: + """将值转换为字符串 + + Args: + value: 要转换的值 + + Returns: + str: 字符串表示 + """ + if value is None: + return "" + elif isinstance(value, str): + return value + elif isinstance(value, bool): + return str(value).lower() + elif isinstance(value, (int, float)): + return str(value) + elif isinstance(value, (dict, list)): + try: + return json.dumps(value, ensure_ascii=False, separators=(',', ':')) + except (TypeError, ValueError): + return str(value) + else: + return str(value) + + @classmethod + def escape_variable_reference(cls, text: str) -> str: + """转义变量引用,防止被解析 + + Args: + text: 包含变量引用的文本 + + Returns: + str: 转义后的文本 + """ + return text.replace("{{", "\\{\\{").replace("}}", "\\}\\}") + + @classmethod + def unescape_variable_reference(cls, text: str) -> str: + """取消转义变量引用 + + Args: + text: 转义的文本 + + Returns: + str: 取消转义后的文本 + """ + return text.replace("\\{\\{", "{{").replace("\\}\\}", "}}") + + +class VariableReferenceBuilder: + """变量引用构建器 - 帮助构建标准的变量引用字符串""" + + @staticmethod + def system(var_name: str, nested_path: Optional[str] = None) -> str: + """构建系统变量引用 + + Args: + var_name: 变量名 + nested_path: 嵌套路径(如 config.api_key) + + Returns: + str: 变量引用字符串 + """ + if nested_path: + return f"{{{{sys.{var_name}.{nested_path}}}}}" + return f"{{{{sys.{var_name}}}}}" + + @staticmethod + def user(var_name: str, nested_path: Optional[str] = None) -> str: + """构建用户变量引用 + + Args: + var_name: 变量名 + nested_path: 嵌套路径 + + Returns: + str: 变量引用字符串 + """ + if nested_path: + return f"{{{{user.{var_name}.{nested_path}}}}}" + return f"{{{{user.{var_name}}}}}" + + @staticmethod + def environment(var_name: str, nested_path: Optional[str] = None) -> str: + """构建环境变量引用 + + Args: + var_name: 变量名 + nested_path: 嵌套路径 + + Returns: + str: 变量引用字符串 + """ + if nested_path: + return f"{{{{env.{var_name}.{nested_path}}}}}" + return f"{{{{env.{var_name}}}}}" + + @staticmethod + def conversation(var_name: str, nested_path: Optional[str] = None) -> str: + """构建对话变量引用 + + Args: + var_name: 变量名 + nested_path: 嵌套路径 + + Returns: + str: 变量引用字符串 + """ + if nested_path: + return f"{{{{conversation.{var_name}.{nested_path}}}}}" + return f"{{{{conversation.{var_name}}}}}" + + +class VariableContext: + """变量上下文管理器 - 管理局部变量作用域""" + + def __init__(self, + parser: VariableParser, + parent_context: Optional["VariableContext"] = None): + """初始化变量上下文 + + Args: + parser: 变量解析器 + parent_context: 父级上下文(用于嵌套作用域) + """ + self.parser = parser + self.parent_context = parent_context + self._local_variables: Dict[str, Any] = {} + + def set_local_variable(self, name: str, value: Any): + """设置局部变量 + + Args: + name: 变量名 + value: 变量值 + """ + self._local_variables[name] = value + + def get_local_variable(self, name: str) -> Any: + """获取局部变量 + + Args: + name: 变量名 + + Returns: + Any: 变量值 + """ + if name in self._local_variables: + return self._local_variables[name] + elif self.parent_context: + return self.parent_context.get_local_variable(name) + return None + + async def parse_with_locals(self, template: str) -> str: + """使用局部变量解析模板 + + Args: + template: 模板字符串 + + Returns: + str: 解析后的字符串 + """ + # 首先用局部变量替换 + result = template + + # 替换局部变量(使用简单的 ${var_name} 语法) + for var_name, var_value in self._local_variables.items(): + pattern = f"${{{var_name}}}" + str_value = self.parser._convert_to_string(var_value) + result = result.replace(pattern, str_value) + + # 然后用全局变量解析器处理剩余的变量引用 + return await self.parser.parse_template(result) + + def create_child_context(self) -> "VariableContext": + """创建子级上下文 + + Returns: + VariableContext: 子级上下文 + """ + return VariableContext(self.parser, self) \ No newline at end of file diff --git a/apps/scheduler/variable/pool.py b/apps/scheduler/variable/pool.py new file mode 100644 index 0000000000000000000000000000000000000000..1d7ac4a260f70e62195d07680c90124dea267618 --- /dev/null +++ b/apps/scheduler/variable/pool.py @@ -0,0 +1,802 @@ +import logging +from typing import Any, Dict, List, Optional, Set +from collections import defaultdict +import asyncio +from contextlib import asynccontextmanager +from datetime import datetime, UTC + +from apps.common.mongo import MongoDB +from .base import BaseVariable, VariableMetadata +from .type import VariableType, VariableScope +from .variables import create_variable, VARIABLE_CLASS_MAP + +logger = logging.getLogger(__name__) + + +class VariablePool: + """变量池 - 管理所有作用域的变量""" + + def __init__(self): + """初始化变量池""" + # 内存缓存:scope -> {variable_name: variable} + self._variables: Dict[VariableScope, Dict[str, BaseVariable]] = { + VariableScope.SYSTEM: {}, + VariableScope.USER: {}, + VariableScope.ENVIRONMENT: {}, + VariableScope.CONVERSATION: {}, + } + + # 上下文相关的变量缓存 + self._user_variables: Dict[str, Dict[str, BaseVariable]] = defaultdict(dict) # user_sub -> variables + self._env_variables: Dict[str, Dict[str, BaseVariable]] = defaultdict(dict) # flow_id -> variables + self._conv_variables: Dict[str, Dict[str, BaseVariable]] = defaultdict(dict) # flow_id -> variables (对话级变量) + + # 系统级变量定义 + self._system_variables_initialized = False + self._lock = asyncio.Lock() + + async def initialize(self): + """初始化变量池,加载系统变量""" + async with self._lock: + if not self._system_variables_initialized: + await self._initialize_system_variables() + self._system_variables_initialized = True + + async def _initialize_system_variables(self): + """初始化系统级变量""" + system_vars = [ + ("query", VariableType.STRING, "用户查询内容", ""), + ("files", VariableType.ARRAY_FILE, "用户上传的文件列表", []), + ("dialogue_count", VariableType.NUMBER, "对话轮数", 0), + ("app_id", VariableType.STRING, "应用ID", ""), + ("flow_id", VariableType.STRING, "工作流ID", ""), + ("user_id", VariableType.STRING, "用户ID", ""), + ("session_id", VariableType.STRING, "会话ID", ""), + ("timestamp", VariableType.NUMBER, "当前时间戳", 0), + ] + + for var_name, var_type, description, default_value in system_vars: + metadata = VariableMetadata( + name=var_name, + var_type=var_type, + scope=VariableScope.SYSTEM, + description=description, + created_by="system" + ) + variable = create_variable(metadata, default_value) + self._variables[VariableScope.SYSTEM][var_name] = variable + + logger.info(f"已初始化 {len(system_vars)} 个系统级变量") + + async def add_variable(self, + name: str, + var_type: VariableType, + scope: VariableScope, + value: Any = None, + description: Optional[str] = None, + user_sub: Optional[str] = None, + flow_id: Optional[str] = None) -> BaseVariable: + """添加变量 + + Args: + name: 变量名 + var_type: 变量类型 + scope: 作用域 + value: 初始值 + description: 描述 + user_sub: 用户ID(用户级变量必需) + flow_id: 流程ID(环境级变量必需) + + Returns: + BaseVariable: 创建的变量 + """ + await self.initialize() + + # 验证作用域相关参数 + if scope == VariableScope.SYSTEM: + raise ValueError("不能直接添加系统级变量") + elif scope == VariableScope.USER and not user_sub: + raise ValueError("用户级变量必须指定 user_sub") + elif scope == VariableScope.ENVIRONMENT and not flow_id: + raise ValueError("环境级变量必须指定 flow_id") + elif scope == VariableScope.CONVERSATION and not flow_id: + raise ValueError("对话级变量必须指定 flow_id") + + # 检查变量是否已存在 + existing_var = await self.get_variable(name, scope, user_sub, flow_id) + if existing_var: + raise ValueError(f"变量 {name} 在作用域 {scope.value} 中已存在") + + # 创建变量元数据 + metadata = VariableMetadata( + name=name, + var_type=var_type, + scope=scope, + description=description, + user_sub=user_sub, + flow_id=flow_id, + created_by=user_sub or "system" + ) + + # 创建变量 + variable = create_variable(metadata, value) + + # 存储到对应的缓存中 + await self._store_variable(variable) + + # 持久化存储 - 为了更好的用户体验,对话级变量也进行持久化 + await self._persist_variable(variable) + + logger.info(f"已添加变量: {name} ({var_type.value}) 到作用域 {scope.value}") + return variable + + async def update_variable(self, + name: str, + scope: VariableScope, + value: Optional[Any] = None, + var_type: Optional[VariableType] = None, + description: Optional[str] = None, + user_sub: Optional[str] = None, + flow_id: Optional[str] = None) -> BaseVariable: + """更新变量值、类型或描述 + + Args: + name: 变量名 + scope: 作用域 + value: 新值(可选) + var_type: 新变量类型(可选) + description: 新描述(可选) + user_sub: 用户ID + flow_id: 流程ID + + Returns: + BaseVariable: 更新后的变量 + """ + variable = await self.get_variable(name, scope, user_sub, flow_id) + if not variable: + raise ValueError(f"变量 {name} 在作用域 {scope.value} 中不存在") + + # 检查权限 + if user_sub and not variable.can_access(user_sub): + raise PermissionError(f"用户 {user_sub} 没有权限修改变量 {name}") + + # 至少需要更新一个字段 + if value is None and var_type is None and description is None: + raise ValueError("至少需要指定一个要更新的字段(value、var_type 或 description)") + + # 更新字段 + if value is not None: + variable.value = value + if var_type is not None: + variable.metadata.var_type = var_type + if description is not None: + variable.metadata.description = description + + # 更新时间戳 + variable.metadata.updated_at = datetime.now(UTC) + + # 更新缓存 + await self._store_variable(variable) + + # 持久化更新 - 为了更好的用户体验,对话级变量也进行持久化 + await self._persist_variable(variable) + + logger.info(f"已更新变量: {name} 在作用域 {scope.value}") + return variable + + async def delete_variable(self, + name: str, + scope: VariableScope, + user_sub: Optional[str] = None, + flow_id: Optional[str] = None) -> bool: + """删除变量 + + Args: + name: 变量名 + scope: 作用域 + user_sub: 用户ID + flow_id: 流程ID + + Returns: + bool: 是否删除成功 + """ + if scope == VariableScope.SYSTEM: + raise ValueError("不能删除系统级变量") + + variable = await self.get_variable(name, scope, user_sub, flow_id) + if not variable: + return False + + # 检查权限 + if user_sub and not variable.can_access(user_sub): + raise PermissionError(f"用户 {user_sub} 没有权限删除变量 {name}") + + # 从缓存中删除 + await self._remove_variable_from_cache(variable) + # 从数据库删除 + await self._delete_variable_from_db(variable) + + logger.info(f"已删除变量: {name} 从作用域 {scope.value}") + return True + + async def get_variable(self, + name: str, + scope: VariableScope, + user_sub: Optional[str] = None, + flow_id: Optional[str] = None) -> Optional[BaseVariable]: + """获取变量 + + Args: + name: 变量名 + scope: 作用域 + user_sub: 用户ID + flow_id: 流程ID + + Returns: + Optional[BaseVariable]: 变量或None + """ + await self.initialize() + + # 根据作用域从对应缓存中查找 + if scope == VariableScope.SYSTEM: + return self._variables[VariableScope.SYSTEM].get(name) + elif scope == VariableScope.USER and user_sub: + # 先从缓存查找 + if name in self._user_variables[user_sub]: + return self._user_variables[user_sub][name] + # 从数据库加载 + return await self._load_user_variable(name, user_sub) + elif scope == VariableScope.ENVIRONMENT and flow_id: + # 先从缓存查找 + if name in self._env_variables[flow_id]: + return self._env_variables[flow_id][name] + # 从数据库加载 + return await self._load_env_variable(name, flow_id) + elif scope == VariableScope.CONVERSATION and flow_id: + # 先从缓存查找 + if name in self._conv_variables[flow_id]: + return self._conv_variables[flow_id][name] + # 从数据库加载 + return await self._load_conv_variable(name, flow_id) + + return None + + async def list_variables(self, + scope: VariableScope, + user_sub: Optional[str] = None, + flow_id: Optional[str] = None) -> List[BaseVariable]: + """列出指定作用域的所有变量 + + Args: + scope: 作用域 + user_sub: 用户ID + flow_id: 流程ID + + Returns: + List[BaseVariable]: 变量列表 + """ + await self.initialize() + + if scope == VariableScope.SYSTEM: + return list(self._variables[VariableScope.SYSTEM].values()) + elif scope == VariableScope.USER and user_sub: + # 加载用户的所有变量 + await self._load_all_user_variables(user_sub) + return list(self._user_variables[user_sub].values()) + elif scope == VariableScope.ENVIRONMENT and flow_id: + # 加载环境的所有变量 + await self._load_all_env_variables(flow_id) + return list(self._env_variables[flow_id].values()) + elif scope == VariableScope.CONVERSATION and flow_id: + # 加载对话级的所有变量 + await self._load_all_conv_variables(flow_id) + return list(self._conv_variables[flow_id].values()) + + return [] + + async def clear_conversation_variables(self, flow_id: str): + """清空工作流的对话级变量""" + if flow_id in self._conv_variables: + del self._conv_variables[flow_id] + logger.info(f"已清空工作流 {flow_id} 的对话级变量") + + async def refresh_conversation_cache(self, flow_id: str): + """刷新对话级变量缓存""" + if flow_id in self._conv_variables: + # 清空当前缓存 + del self._conv_variables[flow_id] + + # 重新加载 + await self._load_all_conv_variables(flow_id) + logger.info(f"已刷新工作流 {flow_id} 的对话级变量缓存") + + async def resolve_variable_reference(self, + reference: str, + user_sub: Optional[str] = None, + flow_id: Optional[str] = None) -> Any: + """解析变量引用,如 {{sys.query}} 或 {{user.token}} + + Args: + reference: 变量引用字符串 + user_sub: 用户ID + flow_id: 流程ID + + Returns: + Any: 变量值 + """ + # 移除 {{ 和 }} + clean_ref = reference.strip("{}").strip() + + # 解析作用域和变量名 + parts = clean_ref.split(".", 1) + if len(parts) != 2: + raise ValueError(f"无效的变量引用格式: {reference}") + + scope_str, var_path = parts + + # 确定作用域 + scope_map = { + "sys": VariableScope.SYSTEM, + "system": VariableScope.SYSTEM, + "user": VariableScope.USER, + "env": VariableScope.ENVIRONMENT, + "environment": VariableScope.ENVIRONMENT, + "conversation": VariableScope.CONVERSATION, + "conv": VariableScope.CONVERSATION, + } + + scope = scope_map.get(scope_str) + if not scope: + raise ValueError(f"无效的变量作用域: {scope_str}") + + # 解析变量路径(支持嵌套访问如 user.config.api_key) + path_parts = var_path.split(".") + var_name = path_parts[0] + + # 获取变量 + variable = await self.get_variable(var_name, scope, user_sub, flow_id) + if not variable: + raise ValueError(f"变量不存在: {clean_ref}") + + # 获取变量值 + value = variable.value + + # 如果有嵌套路径,继续解析 + for path_part in path_parts[1:]: + if isinstance(value, dict): + value = value.get(path_part) + elif isinstance(value, list) and path_part.isdigit(): + try: + value = value[int(path_part)] + except IndexError: + value = None + else: + raise ValueError(f"无法访问路径: {var_path}") + + return value + + async def _store_variable(self, variable: BaseVariable): + """存储变量到缓存""" + scope = variable.scope + name = variable.name + + if scope == VariableScope.SYSTEM: + self._variables[scope][name] = variable + elif scope == VariableScope.USER: + user_sub = variable.metadata.user_sub + if user_sub: + self._user_variables[user_sub][name] = variable + elif scope == VariableScope.ENVIRONMENT: + flow_id = variable.metadata.flow_id + if flow_id: + self._env_variables[flow_id][name] = variable + elif scope == VariableScope.CONVERSATION: + flow_id = variable.metadata.flow_id + if flow_id: + self._conv_variables[flow_id][name] = variable + + async def _remove_variable_from_cache(self, variable: BaseVariable): + """从缓存中移除变量""" + scope = variable.scope + name = variable.name + + if scope == VariableScope.USER: + user_sub = variable.metadata.user_sub + if user_sub and name in self._user_variables[user_sub]: + del self._user_variables[user_sub][name] + elif scope == VariableScope.ENVIRONMENT: + flow_id = variable.metadata.flow_id + if flow_id and name in self._env_variables[flow_id]: + del self._env_variables[flow_id][name] + elif scope == VariableScope.CONVERSATION: + flow_id = variable.metadata.flow_id + if flow_id and name in self._conv_variables[flow_id]: + del self._conv_variables[flow_id][name] + + async def _persist_variable(self, variable: BaseVariable): + """持久化变量到数据库""" + try: + collection = MongoDB().get_collection("variables") + data = variable.serialize() + + # 构建查询条件 + query = { + "metadata.name": variable.name, + "metadata.scope": variable.scope.value + } + + if variable.scope == VariableScope.USER and variable.metadata.user_sub: + query["metadata.user_sub"] = variable.metadata.user_sub + elif variable.scope == VariableScope.ENVIRONMENT and variable.metadata.flow_id: + query["metadata.flow_id"] = variable.metadata.flow_id + elif variable.scope == VariableScope.CONVERSATION and variable.metadata.flow_id: + query["metadata.flow_id"] = variable.metadata.flow_id + + # 更新或插入 - 强制等待写入确认 + from pymongo import WriteConcern + result = await collection.with_options( + write_concern=WriteConcern(w="majority", j=True) + ).replace_one(query, data, upsert=True) + + # 确保写入成功 + if not (result.acknowledged and (result.matched_count > 0 or result.upserted_id)): + raise RuntimeError(f"变量持久化失败: {variable.name}") + + except Exception as e: + logger.error(f"持久化变量失败: {e}") + raise + + async def _delete_variable_from_db(self, variable: BaseVariable): + """从数据库删除变量""" + try: + collection = MongoDB().get_collection("variables") + + query = { + "metadata.name": variable.name, + "metadata.scope": variable.scope.value + } + + if variable.scope == VariableScope.USER and variable.metadata.user_sub: + query["metadata.user_sub"] = variable.metadata.user_sub + elif variable.scope == VariableScope.ENVIRONMENT and variable.metadata.flow_id: + query["metadata.flow_id"] = variable.metadata.flow_id + elif variable.scope == VariableScope.CONVERSATION and variable.metadata.flow_id: + query["metadata.flow_id"] = variable.metadata.flow_id + + # 强制等待写入确认的删除操作 + from pymongo import WriteConcern + result = await collection.with_options( + write_concern=WriteConcern(w="majority", j=True) + ).delete_one(query) + + # 确保删除成功 + if not result.acknowledged: + raise RuntimeError(f"变量删除失败: {variable.name}") + + except Exception as e: + logger.error(f"删除变量失败: {e}") + raise + + async def _load_user_variable(self, name: str, user_sub: str) -> Optional[BaseVariable]: + """从数据库加载用户变量""" + try: + collection = MongoDB().get_collection("variables") + doc = await collection.find_one({ + "metadata.name": name, + "metadata.scope": VariableScope.USER.value, + "metadata.user_sub": user_sub + }) + + if doc: + variable_class_name = doc.get("class") + if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]: + # 找到对应的变量类 + for var_class in VARIABLE_CLASS_MAP.values(): + if var_class.__name__ == variable_class_name: + try: + variable = var_class.deserialize(doc) + self._user_variables[user_sub][name] = variable + return variable + except Exception as e: + logger.warning(f"用户变量 {name} 数据损坏,将从数据库删除: {e}") + await self._delete_corrupted_variable(doc) + return None + + return None + + except Exception as e: + logger.error(f"加载用户变量失败: {e}") + return None + + async def _load_env_variable(self, name: str, flow_id: str) -> Optional[BaseVariable]: + """从数据库加载环境变量""" + try: + collection = MongoDB().get_collection("variables") + doc = await collection.find_one({ + "metadata.name": name, + "metadata.scope": VariableScope.ENVIRONMENT.value, + "metadata.flow_id": flow_id + }) + + if doc: + variable_class_name = doc.get("class") + if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]: + # 找到对应的变量类 + for var_class in VARIABLE_CLASS_MAP.values(): + if var_class.__name__ == variable_class_name: + try: + variable = var_class.deserialize(doc) + self._env_variables[flow_id][name] = variable + return variable + except Exception as e: + logger.warning(f"环境变量 {name} 数据损坏,将从数据库删除: {e}") + await self._delete_corrupted_variable(doc) + return None + + return None + + except Exception as e: + logger.error(f"加载环境变量失败: {e}") + return None + + async def _load_conv_variable(self, name: str, flow_id: str) -> Optional[BaseVariable]: + """从数据库加载对话级变量""" + try: + collection = MongoDB().get_collection("variables") + doc = await collection.find_one({ + "metadata.name": name, + "metadata.scope": VariableScope.CONVERSATION.value, + "metadata.flow_id": flow_id + }) + + if doc: + variable_class_name = doc.get("class") + if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]: + # 找到对应的变量类 + for var_class in VARIABLE_CLASS_MAP.values(): + if var_class.__name__ == variable_class_name: + try: + variable = var_class.deserialize(doc) + self._conv_variables[flow_id][name] = variable + return variable + except Exception as e: + logger.warning(f"对话级变量 {name} 数据损坏,将从数据库删除: {e}") + await self._delete_corrupted_variable(doc) + return None + + return None + + except Exception as e: + logger.error(f"加载对话级变量失败: {e}") + return None + + async def _load_all_user_variables(self, user_sub: str): + """加载用户的所有变量""" + try: + collection = MongoDB().get_collection("variables") + cursor = collection.find({ + "metadata.scope": VariableScope.USER.value, + "metadata.user_sub": user_sub + }) + + corrupted_count = 0 + loaded_count = 0 + + async for doc in cursor: + try: + variable_class_name = doc.get("class") + if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]: + # 找到对应的变量类 + for var_class in VARIABLE_CLASS_MAP.values(): + if var_class.__name__ == variable_class_name: + variable = var_class.deserialize(doc) + self._user_variables[user_sub][variable.name] = variable + loaded_count += 1 + break + except Exception as e: + var_name = doc.get("metadata", {}).get("name", "unknown") + logger.warning(f"用户变量 {var_name} 数据损坏,将从数据库删除: {e}") + await self._delete_corrupted_variable(doc) + corrupted_count += 1 + + if corrupted_count > 0: + logger.info(f"用户 {user_sub} 加载变量完成: 成功 {loaded_count} 个,清理损坏数据 {corrupted_count} 个") + else: + logger.debug(f"用户 {user_sub} 加载变量完成: {loaded_count} 个") + + except Exception as e: + logger.error(f"加载用户所有变量失败: {e}") + + async def _load_all_env_variables(self, flow_id: str): + """加载环境的所有变量""" + try: + collection = MongoDB().get_collection("variables") + cursor = collection.find({ + "metadata.scope": VariableScope.ENVIRONMENT.value, + "metadata.flow_id": flow_id + }) + + corrupted_count = 0 + loaded_count = 0 + + async for doc in cursor: + try: + variable_class_name = doc.get("class") + if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]: + # 找到对应的变量类 + for var_class in VARIABLE_CLASS_MAP.values(): + if var_class.__name__ == variable_class_name: + variable = var_class.deserialize(doc) + self._env_variables[flow_id][variable.name] = variable + loaded_count += 1 + break + except Exception as e: + var_name = doc.get("metadata", {}).get("name", "unknown") + logger.warning(f"环境变量 {var_name} 数据损坏,将从数据库删除: {e}") + await self._delete_corrupted_variable(doc) + corrupted_count += 1 + + if corrupted_count > 0: + logger.info(f"环境 {flow_id} 加载变量完成: 成功 {loaded_count} 个,清理损坏数据 {corrupted_count} 个") + else: + logger.debug(f"环境 {flow_id} 加载变量完成: {loaded_count} 个") + + except Exception as e: + logger.error(f"加载环境所有变量失败: {e}") + + async def _load_all_conv_variables(self, flow_id: str): + """加载对话级的所有变量""" + try: + collection = MongoDB().get_collection("variables") + cursor = collection.find({ + "metadata.scope": VariableScope.CONVERSATION.value, + "metadata.flow_id": flow_id + }) + + corrupted_count = 0 + loaded_count = 0 + + async for doc in cursor: + try: + variable_class_name = doc.get("class") + if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]: + # 找到对应的变量类 + for var_class in VARIABLE_CLASS_MAP.values(): + if var_class.__name__ == variable_class_name: + variable = var_class.deserialize(doc) + self._conv_variables[flow_id][variable.name] = variable + loaded_count += 1 + break + except Exception as e: + var_name = doc.get("metadata", {}).get("name", "unknown") + logger.warning(f"对话级变量 {var_name} 数据损坏,将从数据库删除: {e}") + await self._delete_corrupted_variable(doc) + corrupted_count += 1 + + if corrupted_count > 0: + logger.info(f"对话 {flow_id} 加载变量完成: 成功 {loaded_count} 个,清理损坏数据 {corrupted_count} 个") + else: + logger.debug(f"对话 {flow_id} 加载变量完成: {loaded_count} 个") + + except Exception as e: + logger.error(f"加载对话级所有变量失败: {e}") + + async def _delete_corrupted_variable(self, doc: Dict[str, Any]): + """删除损坏的变量数据 + + Args: + doc: 损坏的变量文档 + """ + try: + collection = MongoDB().get_collection("variables") + + # 记录损坏的数据用于分析 + corrupted_collection = MongoDB().get_collection("corrupted_variables") + backup_doc = doc.copy() + backup_doc["corrupted_at"] = datetime.now(UTC) + backup_doc["reason"] = "deserialization_failed" + + # 备份损坏数据 + try: + await corrupted_collection.insert_one(backup_doc) + except Exception as backup_e: + logger.warning(f"备份损坏变量数据失败: {backup_e}") + + # 删除原始损坏数据 + if "_id" in doc: + result = await collection.delete_one({"_id": doc["_id"]}) + if result.deleted_count > 0: + var_name = doc.get("metadata", {}).get("name", "unknown") + logger.debug(f"已删除损坏的变量数据: {var_name}") + else: + logger.warning(f"未能删除损坏的变量数据,可能已被删除") + + except Exception as e: + logger.error(f"删除损坏变量数据失败: {e}") + # 即使删除失败也不抛出异常,避免影响其他变量的加载 + + async def cleanup_corrupted_variables(self) -> Dict[str, int]: + """手动清理所有损坏的变量数据 + + Returns: + Dict[str, int]: 清理统计信息 {"checked": 检查数量, "cleaned": 清理数量} + """ + logger.info("开始检查和清理损坏的变量数据...") + + collection = MongoDB().get_collection("variables") + checked_count = 0 + cleaned_count = 0 + + try: + # 遍历所有变量进行检查 + cursor = collection.find({}) + + async for doc in cursor: + checked_count += 1 + try: + variable_class_name = doc.get("class") + if variable_class_name in [cls.__name__ for cls in VARIABLE_CLASS_MAP.values()]: + # 找到对应的变量类并尝试反序列化 + for var_class in VARIABLE_CLASS_MAP.values(): + if var_class.__name__ == variable_class_name: + var_class.deserialize(doc) # 只是测试反序列化,不保存 + break + else: + # 未知的变量类型也视为损坏 + raise ValueError(f"未知的变量类型: {variable_class_name}") + + except Exception as e: + var_name = doc.get("metadata", {}).get("name", "unknown") + logger.warning(f"发现损坏的变量 {var_name},将删除: {e}") + await self._delete_corrupted_variable(doc) + cleaned_count += 1 + + result = {"checked": checked_count, "cleaned": cleaned_count} + logger.info(f"变量数据清理完成: 检查了 {checked_count} 个变量,清理了 {cleaned_count} 个损坏的变量") + return result + + except Exception as e: + logger.error(f"清理损坏变量失败: {e}") + return {"checked": checked_count, "cleaned": cleaned_count} + + async def get_corrupted_variables_info(self) -> List[Dict[str, Any]]: + """获取已备份的损坏变量信息 + + Returns: + List[Dict[str, Any]]: 损坏变量信息列表 + """ + try: + corrupted_collection = MongoDB().get_collection("corrupted_variables") + corrupted_vars = [] + + cursor = corrupted_collection.find({}).sort("corrupted_at", -1).limit(100) + async for doc in cursor: + var_info = { + "name": doc.get("metadata", {}).get("name", "unknown"), + "scope": doc.get("metadata", {}).get("scope", "unknown"), + "var_type": doc.get("metadata", {}).get("var_type", "unknown"), + "corrupted_at": doc.get("corrupted_at"), + "reason": doc.get("reason", "unknown"), + "user_sub": doc.get("metadata", {}).get("user_sub"), + "flow_id": doc.get("metadata", {}).get("flow_id"), + } + corrupted_vars.append(var_info) + + return corrupted_vars + + except Exception as e: + logger.error(f"获取损坏变量信息失败: {e}") + return [] + + +# 全局变量池实例 +_variable_pool = None + + +async def get_variable_pool() -> VariablePool: + """获取全局变量池实例""" + global _variable_pool + if _variable_pool is None: + _variable_pool = VariablePool() + await _variable_pool.initialize() + return _variable_pool \ No newline at end of file diff --git a/apps/scheduler/variable/security.py b/apps/scheduler/variable/security.py new file mode 100644 index 0000000000000000000000000000000000000000..0289a29dac62d24273f3fb7fb4b0984f8bd0bba6 --- /dev/null +++ b/apps/scheduler/variable/security.py @@ -0,0 +1,409 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""变量安全管理模块 + +提供密钥变量的额外安全保障,包括: +- 访问审计日志 +- 密钥轮换 +- 安全检查 +- 权限验证 +""" + +import logging +import hashlib +import time +from typing import Any, Dict, List, Optional, Set +from datetime import datetime, UTC, timedelta +from dataclasses import dataclass + +from apps.common.mongo import MongoDB + +logger = logging.getLogger(__name__) + + +@dataclass +class AccessLog: + """访问日志记录""" + variable_name: str + user_sub: str + access_time: datetime + access_type: str # read, write, delete + ip_address: Optional[str] = None + user_agent: Optional[str] = None + success: bool = True + error_message: Optional[str] = None + + +class SecretVariableSecurity: + """密钥变量安全管理器""" + + def __init__(self): + """初始化安全管理器""" + self._access_logs: List[AccessLog] = [] + self._failed_access_attempts: Dict[str, List[datetime]] = {} + self._blocked_users: Set[str] = set() + + # 安全配置 + self.max_failed_attempts = 5 # 最大失败尝试次数 + self.block_duration_minutes = 30 # 封禁时长(分钟) + self.audit_retention_days = 90 # 审计日志保留天数 + + async def verify_access_permission(self, + variable_name: str, + user_sub: str, + access_type: str = "read", + ip_address: Optional[str] = None) -> bool: + """验证访问权限 + + Args: + variable_name: 变量名 + user_sub: 用户ID + access_type: 访问类型(read, write, delete) + ip_address: IP地址 + + Returns: + bool: 是否允许访问 + """ + try: + # 检查用户是否被封禁 + if await self._is_user_blocked(user_sub): + await self._log_access( + variable_name, user_sub, access_type, ip_address, + success=False, error_message="用户被临时封禁" + ) + return False + + # 检查是否超出访问频率限制 + if await self._check_rate_limit(user_sub): + await self._log_access( + variable_name, user_sub, access_type, ip_address, + success=False, error_message="访问频率过高" + ) + await self._record_failed_attempt(user_sub) + return False + + # 验证IP地址(可选) + if ip_address and not await self._is_ip_allowed(ip_address): + await self._log_access( + variable_name, user_sub, access_type, ip_address, + success=False, error_message="IP地址不在允许列表中" + ) + return False + + # 记录成功访问 + await self._log_access(variable_name, user_sub, access_type, ip_address) + await self._clear_failed_attempts(user_sub) + + return True + + except Exception as e: + logger.error(f"验证访问权限失败: {e}") + return False + + async def audit_secret_access(self, + variable_name: str, + user_sub: str, + actual_value: str, + access_type: str = "read") -> None: + """审计密钥访问 + + Args: + variable_name: 变量名 + user_sub: 用户ID + actual_value: 实际访问的值 + access_type: 访问类型 + """ + try: + # 计算值的哈希(用于审计,不存储原始值) + value_hash = hashlib.sha256(actual_value.encode()).hexdigest()[:16] + + # 记录详细的审计信息 + audit_record = { + "variable_name": variable_name, + "user_sub": user_sub, + "access_time": datetime.now(UTC), + "access_type": access_type, + "value_hash": value_hash, + "value_length": len(actual_value), + } + + # 保存到审计日志 + await self._save_audit_record(audit_record) + + logger.info(f"已审计密钥访问: {variable_name} by {user_sub}") + + except Exception as e: + logger.error(f"审计密钥访问失败: {e}") + + async def check_secret_strength(self, secret_value: str) -> Dict[str, Any]: + """检查密钥强度 + + Args: + secret_value: 密钥值 + + Returns: + Dict[str, Any]: 检查结果 + """ + result = { + "is_strong": False, + "score": 0, + "warnings": [], + "recommendations": [] + } + + try: + # 长度检查 + if len(secret_value) < 8: + result["warnings"].append("密钥长度过短(建议至少8位)") + elif len(secret_value) >= 12: + result["score"] += 2 + else: + result["score"] += 1 + + # 复杂性检查 + has_upper = any(c.isupper() for c in secret_value) + has_lower = any(c.islower() for c in secret_value) + has_digit = any(c.isdigit() for c in secret_value) + has_special = any(c in "!@#$%^&*()_+-=[]{}|;:,.<>?" for c in secret_value) + + complexity_score = sum([has_upper, has_lower, has_digit, has_special]) + result["score"] += complexity_score + + if complexity_score < 3: + result["recommendations"].append("建议包含大小写字母、数字和特殊字符") + + # 常见密码检查 + common_passwords = ["password", "123456", "admin", "root", "qwerty"] + if secret_value.lower() in common_passwords: + result["warnings"].append("使用了常见的弱密码") + result["score"] = 0 + + # 重复字符检查 + if len(set(secret_value)) < len(secret_value) * 0.6: + result["warnings"].append("包含过多重复字符") + + # 设置强度等级 + if result["score"] >= 6 and len(result["warnings"]) == 0: + result["is_strong"] = True + + return result + + except Exception as e: + logger.error(f"检查密钥强度失败: {e}") + return result + + async def rotate_secret_key(self, variable_name: str, user_sub: str) -> bool: + """轮换密钥的加密密钥 + + Args: + variable_name: 变量名 + user_sub: 用户ID + + Returns: + bool: 是否轮换成功 + """ + try: + from .pool import get_variable_pool + from .type import VariableScope + + pool = await get_variable_pool() + + # 获取密钥变量 + variable = await pool.get_variable(variable_name, VariableScope.USER, user_sub=user_sub) + if not variable or not variable.var_type.is_secret_type(): + return False + + # 获取原始值 + original_value = variable.value + + # 重新生成加密密钥并重新加密 + variable._encryption_key = variable._generate_encryption_key() + variable.value = original_value # 这会触发重新加密 + + # 更新存储 + await pool._persist_variable(variable) + + # 记录轮换操作 + await self._log_access( + variable_name, user_sub, "key_rotation", + success=True, error_message="密钥轮换成功" + ) + + logger.info(f"已轮换密钥变量的加密密钥: {variable_name}") + return True + + except Exception as e: + logger.error(f"轮换密钥失败: {e}") + return False + + async def get_access_logs(self, + variable_name: Optional[str] = None, + user_sub: Optional[str] = None, + hours: int = 24) -> List[Dict[str, Any]]: + """获取访问日志 + + Args: + variable_name: 变量名(可选) + user_sub: 用户ID(可选) + hours: 查询最近几小时的日志 + + Returns: + List[Dict[str, Any]]: 访问日志列表 + """ + try: + collection = MongoDB().get_collection("variable_access_logs") + + # 构建查询条件 + query = { + "access_time": { + "$gte": datetime.now(UTC) - timedelta(hours=hours) + } + } + + if variable_name: + query["variable_name"] = variable_name + if user_sub: + query["user_sub"] = user_sub + + # 查询日志 + logs = [] + async for doc in collection.find(query).sort("access_time", -1): + logs.append({ + "variable_name": doc.get("variable_name"), + "user_sub": doc.get("user_sub"), + "access_time": doc.get("access_time"), + "access_type": doc.get("access_type"), + "ip_address": doc.get("ip_address"), + "success": doc.get("success", True), + "error_message": doc.get("error_message") + }) + + return logs + + except Exception as e: + logger.error(f"获取访问日志失败: {e}") + return [] + + async def _is_user_blocked(self, user_sub: str) -> bool: + """检查用户是否被封禁""" + if user_sub not in self._failed_access_attempts: + return False + + failed_attempts = self._failed_access_attempts[user_sub] + recent_failures = [ + attempt for attempt in failed_attempts + if attempt > datetime.now(UTC) - timedelta(minutes=self.block_duration_minutes) + ] + + return len(recent_failures) >= self.max_failed_attempts + + async def _check_rate_limit(self, user_sub: str) -> bool: + """检查访问频率限制""" + try: + # 获取最近1分钟的访问记录 + collection = MongoDB().get_collection("variable_access_logs") + count = await collection.count_documents({ + "user_sub": user_sub, + "access_time": { + "$gte": datetime.now(UTC) - timedelta(minutes=1) + } + }) + + # 限制每分钟最多30次访问 + return count >= 30 + + except Exception as e: + logger.error(f"检查访问频率失败: {e}") + return False + + async def _is_ip_allowed(self, ip_address: str) -> bool: + """检查IP地址是否被允许""" + # 这里可以实现IP白名单/黑名单逻辑 + # 暂时返回True,表示允许所有IP + return True + + async def _record_failed_attempt(self, user_sub: str): + """记录失败尝试""" + if user_sub not in self._failed_access_attempts: + self._failed_access_attempts[user_sub] = [] + + self._failed_access_attempts[user_sub].append(datetime.now(UTC)) + + # 清理过期的失败记录 + cutoff_time = datetime.now(UTC) - timedelta(minutes=self.block_duration_minutes) + self._failed_access_attempts[user_sub] = [ + attempt for attempt in self._failed_access_attempts[user_sub] + if attempt > cutoff_time + ] + + async def _clear_failed_attempts(self, user_sub: str): + """清除失败尝试记录""" + if user_sub in self._failed_access_attempts: + del self._failed_access_attempts[user_sub] + + async def _log_access(self, + variable_name: str, + user_sub: str, + access_type: str, + ip_address: Optional[str] = None, + success: bool = True, + error_message: Optional[str] = None): + """记录访问日志""" + try: + log_entry = { + "variable_name": variable_name, + "user_sub": user_sub, + "access_time": datetime.now(UTC), + "access_type": access_type, + "ip_address": ip_address, + "success": success, + "error_message": error_message + } + + # 保存到数据库 + collection = MongoDB().get_collection("variable_access_logs") + await collection.insert_one(log_entry) + + except Exception as e: + logger.error(f"记录访问日志失败: {e}") + + async def _save_audit_record(self, audit_record: Dict[str, Any]): + """保存审计记录""" + try: + collection = MongoDB().get_collection("variable_audit_logs") + await collection.insert_one(audit_record) + except Exception as e: + logger.error(f"保存审计记录失败: {e}") + + async def cleanup_old_logs(self): + """清理过期的日志记录""" + try: + cutoff_time = datetime.now(UTC) - timedelta(days=self.audit_retention_days) + + # 清理访问日志 + access_collection = MongoDB().get_collection("variable_access_logs") + result1 = await access_collection.delete_many({ + "access_time": {"$lt": cutoff_time} + }) + + # 清理审计日志 + audit_collection = MongoDB().get_collection("variable_audit_logs") + result2 = await audit_collection.delete_many({ + "access_time": {"$lt": cutoff_time} + }) + + logger.info(f"已清理过期日志: 访问日志 {result1.deleted_count} 条, 审计日志 {result2.deleted_count} 条") + + except Exception as e: + logger.error(f"清理过期日志失败: {e}") + + +# 全局安全管理器实例 +_security_manager = None + + +def get_security_manager() -> SecretVariableSecurity: + """获取全局安全管理器实例""" + global _security_manager + if _security_manager is None: + _security_manager = SecretVariableSecurity() + return _security_manager \ No newline at end of file diff --git a/apps/scheduler/variable/type.py b/apps/scheduler/variable/type.py new file mode 100644 index 0000000000000000000000000000000000000000..9eb4627d8b5cf3762c24eb9bc0214623d56c5c8b --- /dev/null +++ b/apps/scheduler/variable/type.py @@ -0,0 +1,72 @@ +from enum import StrEnum + + +class VariableType(StrEnum): + """变量类型枚举""" + # 基础类型 + NUMBER = "number" + STRING = "string" + BOOLEAN = "boolean" + OBJECT = "object" + SECRET = "secret" + GROUP = "group" + FILE = "file" + + # 数组类型 + ARRAY_ANY = "array[any]" + ARRAY_STRING = "array[string]" + ARRAY_NUMBER = "array[number]" + ARRAY_OBJECT = "array[object]" + ARRAY_FILE = "array[file]" + ARRAY_BOOLEAN = "array[boolean]" + ARRAY_SECRET = "array[secret]" + + def is_array_type(self) -> bool: + """检查是否为数组类型""" + return self in _ARRAY_TYPES + + def is_secret_type(self) -> bool: + """检查是否为密钥类型""" + return self in _SECRET_TYPES + + def get_array_element_type(self) -> "VariableType | None": + """获取数组元素类型""" + if not self.is_array_type(): + return None + + element_type_map = { + VariableType.ARRAY_ANY: None, # any类型无法确定具体元素类型 + VariableType.ARRAY_STRING: VariableType.STRING, + VariableType.ARRAY_NUMBER: VariableType.NUMBER, + VariableType.ARRAY_OBJECT: VariableType.OBJECT, + VariableType.ARRAY_FILE: VariableType.FILE, + VariableType.ARRAY_BOOLEAN: VariableType.BOOLEAN, + VariableType.ARRAY_SECRET: VariableType.SECRET, + } + return element_type_map.get(self) + + +class VariableScope(StrEnum): + """变量作用域枚举""" + SYSTEM = "system" # 系统级变量(只读) + USER = "user" # 用户级变量(跟随用户) + ENVIRONMENT = "env" # 环境级变量(跟随flow) + CONVERSATION = "conversation" # 对话级变量(每次运行重新初始化) + + +# 数组类型集合 +_ARRAY_TYPES = frozenset([ + VariableType.ARRAY_ANY, + VariableType.ARRAY_STRING, + VariableType.ARRAY_NUMBER, + VariableType.ARRAY_OBJECT, + VariableType.ARRAY_FILE, + VariableType.ARRAY_BOOLEAN, + VariableType.ARRAY_SECRET, +]) + +# 密钥类型集合 +_SECRET_TYPES = frozenset([ + VariableType.SECRET, + VariableType.ARRAY_SECRET, +]) \ No newline at end of file diff --git a/apps/scheduler/variable/variables.py b/apps/scheduler/variable/variables.py new file mode 100644 index 0000000000000000000000000000000000000000..2902dd3adf454ff96aaa98eb2507a533ffccae7c --- /dev/null +++ b/apps/scheduler/variable/variables.py @@ -0,0 +1,479 @@ +import json +import base64 +import hashlib +from typing import Any, Dict, List, Union, Optional +from cryptography.fernet import Fernet +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC + +from .base import BaseVariable, VariableMetadata +from .type import VariableType + + +class StringVariable(BaseVariable): + """字符串变量""" + + def _validate_type(self, value: Any) -> bool: + """验证值是否为字符串类型""" + return isinstance(value, str) + + def to_string(self) -> str: + """转换为字符串""" + return str(self._value) if self._value is not None else "" + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + "name": self.name, + "type": self.var_type.value, + "value": self._value, + "scope": self.scope.value + } + + def serialize(self) -> Dict[str, Any]: + """序列化""" + return { + "metadata": self.metadata.model_dump(), + "value": self._value, + "class": self.__class__.__name__ + } + + @classmethod + def deserialize(cls, data: Dict[str, Any]) -> "StringVariable": + """反序列化""" + metadata = VariableMetadata(**data["metadata"]) + return cls(metadata, data["value"]) + + +class NumberVariable(BaseVariable): + """数字变量""" + + def _validate_type(self, value: Any) -> bool: + """验证值是否为数字类型""" + return isinstance(value, (int, float)) + + def to_string(self) -> str: + """转换为字符串""" + return str(self._value) if self._value is not None else "0" + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + "name": self.name, + "type": self.var_type.value, + "value": self._value, + "scope": self.scope.value + } + + def serialize(self) -> Dict[str, Any]: + """序列化""" + return { + "metadata": self.metadata.model_dump(), + "value": self._value, + "class": self.__class__.__name__ + } + + @classmethod + def deserialize(cls, data: Dict[str, Any]) -> "NumberVariable": + """反序列化""" + metadata = VariableMetadata(**data["metadata"]) + return cls(metadata, data["value"]) + + +class BooleanVariable(BaseVariable): + """布尔变量""" + + def _validate_type(self, value: Any) -> bool: + """验证值是否为布尔类型""" + return isinstance(value, bool) + + def to_string(self) -> str: + """转换为字符串""" + return str(self._value).lower() if self._value is not None else "false" + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + "name": self.name, + "type": self.var_type.value, + "value": self._value, + "scope": self.scope.value + } + + def serialize(self) -> Dict[str, Any]: + """序列化""" + return { + "metadata": self.metadata.model_dump(), + "value": self._value, + "class": self.__class__.__name__ + } + + @classmethod + def deserialize(cls, data: Dict[str, Any]) -> "BooleanVariable": + """反序列化""" + metadata = VariableMetadata(**data["metadata"]) + return cls(metadata, data["value"]) + + +class ObjectVariable(BaseVariable): + """对象变量""" + + def _validate_type(self, value: Any) -> bool: + """验证值是否为对象类型""" + return isinstance(value, dict) + + def to_string(self) -> str: + """转换为字符串""" + if self._value is None: + return "{}" + try: + return json.dumps(self._value, ensure_ascii=False, indent=2) + except (TypeError, ValueError): + return str(self._value) + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + "name": self.name, + "type": self.var_type.value, + "value": self._value, + "scope": self.scope.value + } + + def serialize(self) -> Dict[str, Any]: + """序列化""" + return { + "metadata": self.metadata.model_dump(), + "value": self._value, + "class": self.__class__.__name__ + } + + @classmethod + def deserialize(cls, data: Dict[str, Any]) -> "ObjectVariable": + """反序列化""" + metadata = VariableMetadata(**data["metadata"]) + return cls(metadata, data["value"]) + + +class SecretVariable(BaseVariable): + """密钥变量 - 提供安全存储和访问机制""" + + def __init__(self, metadata: VariableMetadata, value: Any = None, encryption_key: Optional[str] = None): + """初始化密钥变量 + + Args: + metadata: 变量元数据 + value: 变量值 + encryption_key: 加密密钥(可选,如果不提供会自动生成) + """ + # 先设置加密密钥,因为父类初始化时可能会调用value setter + self._encryption_key = encryption_key or self._generate_encryption_key() + super().__init__(metadata, value) + self.metadata.is_encrypted = True + + # 如果提供了值,确保它已被加密(在value setter中已处理) + # 这里不需要再次加密,因为super().__init__已经通过setter处理了 + + def _generate_encryption_key(self) -> str: + """生成加密密钥""" + # 使用用户ID和变量名生成唯一的加密密钥 + user_sub = self.metadata.user_sub or "default" + salt = f"{user_sub}:{self.metadata.name}".encode() + + kdf = PBKDF2HMAC( + algorithm=hashes.SHA256(), + length=32, + salt=salt, + iterations=100000, + ) + key = base64.urlsafe_b64encode(kdf.derive(b"secret_variable_key")) + return key.decode() + + def _encrypt_value(self, value: str) -> str: + """加密值""" + if not isinstance(value, str): + value = str(value) + + f = Fernet(self._encryption_key.encode()) + encrypted_value = f.encrypt(value.encode()) + return base64.urlsafe_b64encode(encrypted_value).decode() + + def _decrypt_value(self, encrypted_value: str) -> str: + """解密值""" + try: + encrypted_bytes = base64.urlsafe_b64decode(encrypted_value.encode()) + f = Fernet(self._encryption_key.encode()) + decrypted_value = f.decrypt(encrypted_bytes) + return decrypted_value.decode() + except Exception: + return "[解密失败]" + + def _validate_type(self, value: Any) -> bool: + """验证值类型""" + return isinstance(value, str) + + @property + def value(self) -> str: + """获取解密后的值""" + if self._value is None: + return "" + return self._decrypt_value(self._value) + + @value.setter + def value(self, new_value: Any) -> None: + """设置新值(会自动加密)""" + if self.scope.value == "system": + raise ValueError("系统级变量不能修改") + + if not self._validate_type(new_value): + raise TypeError(f"变量 {self.name} 的值类型不匹配,期望: {self.var_type}") + + self._value = self._encrypt_value(new_value) + from datetime import datetime, UTC + self.metadata.updated_at = datetime.now(UTC) + + def get_masked_value(self) -> str: + """获取掩码值用于显示""" + actual_value = self.value + if len(actual_value) <= 4: + return "*" * len(actual_value) + return actual_value[:2] + "*" * (len(actual_value) - 4) + actual_value[-2:] + + def to_string(self) -> str: + """转换为字符串(掩码形式)""" + return self.get_masked_value() + + def to_dict(self) -> Dict[str, Any]: + """转换为字典(掩码形式)""" + return { + "name": self.name, + "type": self.var_type.value, + "value": self.get_masked_value(), + "scope": self.scope.value + } + + def to_dict_with_actual_value(self, user_sub: str) -> Dict[str, Any]: + """转换为包含实际值的字典(需要权限检查)""" + if not self.can_access(user_sub): + raise PermissionError(f"用户 {user_sub} 没有权限访问密钥变量 {self.name}") + + return { + "name": self.name, + "type": self.var_type.value, + "value": self.value, # 实际解密值 + "scope": self.scope.value + } + + def serialize(self) -> Dict[str, Any]: + """序列化(保持加密状态)""" + return { + "metadata": self.metadata.model_dump(), + "value": self._value, # 加密后的值 + "encryption_key": self._encryption_key, + "class": self.__class__.__name__ + } + + @classmethod + def deserialize(cls, data: Dict[str, Any]) -> "SecretVariable": + """反序列化""" + metadata = VariableMetadata(**data["metadata"]) + instance = cls(metadata, None, data.get("encryption_key")) + instance._value = data["value"] # 直接设置加密值 + return instance + + +class FileVariable(BaseVariable): + """文件变量""" + + def _validate_type(self, value: Any) -> bool: + """验证值是否为文件路径或文件对象""" + return isinstance(value, (str, dict)) and ( + isinstance(value, str) or + (isinstance(value, dict) and "filename" in value and "content" in value) + ) + + def to_string(self) -> str: + """转换为字符串""" + if isinstance(self._value, str): + return self._value + elif isinstance(self._value, dict): + return self._value.get("filename", "unnamed_file") + return "" + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + "name": self.name, + "type": self.var_type.value, + "value": self._value, + "scope": self.scope.value + } + + def serialize(self) -> Dict[str, Any]: + """序列化""" + return { + "metadata": self.metadata.model_dump(), + "value": self._value, + "class": self.__class__.__name__ + } + + @classmethod + def deserialize(cls, data: Dict[str, Any]) -> "FileVariable": + """反序列化""" + metadata = VariableMetadata(**data["metadata"]) + return cls(metadata, data["value"]) + + +class ArrayVariable(BaseVariable): + """数组变量""" + + def __init__(self, metadata: VariableMetadata, value: Any = None): + """初始化数组变量""" + # 先设置元素类型,因为父类初始化时会调用_validate_type + self._element_type = metadata.var_type.get_array_element_type() + super().__init__(metadata, value or []) + + def _validate_type(self, value: Any) -> bool: + """验证值是否为数组类型,并检查元素类型""" + if not isinstance(value, list): + return False + + # 如果是 array[any],不需要检查元素类型 + if self._element_type is None: + return True + + # 检查所有元素类型 + for item in value: + if not self._validate_element_type(item): + return False + + return True + + def _validate_element_type(self, element: Any) -> bool: + """验证单个元素的类型""" + if self._element_type is None: # array[any] + return True + + type_validators = { + VariableType.STRING: lambda x: isinstance(x, str), + VariableType.NUMBER: lambda x: isinstance(x, (int, float)), + VariableType.BOOLEAN: lambda x: isinstance(x, bool), + VariableType.OBJECT: lambda x: isinstance(x, dict), + VariableType.SECRET: lambda x: isinstance(x, str), + VariableType.FILE: lambda x: isinstance(x, (str, dict)), + } + + validator = type_validators.get(self._element_type) + return validator(element) if validator else False + + def append(self, item: Any) -> None: + """添加元素到数组""" + if not self._validate_element_type(item): + raise TypeError(f"元素类型不匹配,期望: {self._element_type}") + + if self._value is None: + self._value = [] + self._value.append(item) + from datetime import datetime, UTC + self.metadata.updated_at = datetime.now(UTC) + + def remove(self, item: Any) -> None: + """从数组中移除元素""" + if self._value and item in self._value: + self._value.remove(item) + from datetime import datetime, UTC + self.metadata.updated_at = datetime.now(UTC) + + def __len__(self) -> int: + """获取数组长度""" + return len(self._value) if self._value else 0 + + def __getitem__(self, index: int) -> Any: + """获取指定索引的元素""" + if self._value is None: + raise IndexError("数组为空") + return self._value[index] + + def __setitem__(self, index: int, value: Any) -> None: + """设置指定索引的元素""" + if not self._validate_element_type(value): + raise TypeError(f"元素类型不匹配,期望: {self._element_type}") + + if self._value is None: + self._value = [] + + # 扩展数组到指定索引 + while len(self._value) <= index: + self._value.append(None) + + self._value[index] = value + from datetime import datetime, UTC + self.metadata.updated_at = datetime.now(UTC) + + def to_string(self) -> str: + """转换为字符串""" + if self._value is None: + return "[]" + try: + return json.dumps(self._value, ensure_ascii=False, indent=2) + except (TypeError, ValueError): + return str(self._value) + + def to_dict(self) -> Dict[str, Any]: + """转换为字典""" + return { + "name": self.name, + "type": self.var_type.value, + "value": self._value, + "scope": self.scope.value, + "element_type": self._element_type.value if self._element_type else None + } + + def serialize(self) -> Dict[str, Any]: + """序列化""" + return { + "metadata": self.metadata.model_dump(), + "value": self._value, + "class": self.__class__.__name__ + } + + @classmethod + def deserialize(cls, data: Dict[str, Any]) -> "ArrayVariable": + """反序列化""" + metadata = VariableMetadata(**data["metadata"]) + return cls(metadata, data["value"]) + + +# 变量类型映射 +VARIABLE_CLASS_MAP = { + VariableType.STRING: StringVariable, + VariableType.NUMBER: NumberVariable, + VariableType.BOOLEAN: BooleanVariable, + VariableType.OBJECT: ObjectVariable, + VariableType.SECRET: SecretVariable, + VariableType.FILE: FileVariable, + VariableType.ARRAY_ANY: ArrayVariable, + VariableType.ARRAY_STRING: ArrayVariable, + VariableType.ARRAY_NUMBER: ArrayVariable, + VariableType.ARRAY_OBJECT: ArrayVariable, + VariableType.ARRAY_FILE: ArrayVariable, + VariableType.ARRAY_BOOLEAN: ArrayVariable, + VariableType.ARRAY_SECRET: ArrayVariable, +} + + +def create_variable(metadata: VariableMetadata, value: Any = None) -> BaseVariable: + """根据类型创建变量实例 + + Args: + metadata: 变量元数据 + value: 变量值 + + Returns: + BaseVariable: 创建的变量实例 + """ + variable_class = VARIABLE_CLASS_MAP.get(metadata.var_type) + if not variable_class: + raise ValueError(f"不支持的变量类型: {metadata.var_type}") + + return variable_class(metadata, value) \ No newline at end of file diff --git a/apps/schemas/config.py b/apps/schemas/config.py index b88a81f1afb018663713a3bf4c0b9b62436e59d4..7daa27969b90fcdc2767d3983a82dd1e6bcefb5a 100644 --- a/apps/schemas/config.py +++ b/apps/schemas/config.py @@ -114,6 +114,12 @@ class CheckConfig(BaseModel): words_list: str = Field(description="敏感词列表文件路径") +class SandboxConfig(BaseModel): + """代码沙箱配置""" + + sandbox_service: str = Field(description="代码沙箱服务地址") + + class ExtraConfig(BaseModel): """额外配置""" @@ -134,4 +140,5 @@ class ConfigModel(BaseModel): function_call: FunctionCallConfig security: SecurityConfig check: CheckConfig + sandbox: SandboxConfig extra: ExtraConfig diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py index 9a20ba84d4805bcd502d9d4ca0f9ba3c49a7bb4c..5629d546d04656381e89060830564f6ff2c0f1ac 100644 --- a/apps/schemas/enum_var.py +++ b/apps/schemas/enum_var.py @@ -142,6 +142,7 @@ class SpecialCallType(str, Enum): FACTS = "Facts" SLOT = "Slot" LLM = "LLM" + DIRECT_REPLY = "DirectReply" START = "start" END = "end" CHOICE = "choice" diff --git a/apps/services/flow.py b/apps/services/flow.py index 9275fc60c75199e8e17ebcba8f164083190b7211..aae9c9cd436473c019005d623d69fafcb3114091 100644 --- a/apps/services/flow.py +++ b/apps/services/flow.py @@ -257,15 +257,21 @@ class FlowManager: debug=flow_config.debug, ) for node_id, node_config in flow_config.steps.items(): - input_parameters = node_config.params - if node_config.node not in ("Empty"): - _, output_parameters = await NodeManager.get_node_params(node_config.node) + # 对于Code节点,直接使用保存的完整params作为parameters + if node_config.type == "Code": + parameters = node_config.params # 直接使用保存的完整params else: - output_parameters = {} - parameters = { - "input_parameters": input_parameters, - "output_parameters": Slot(output_parameters).extract_type_desc_from_schema(), - } + # 其他节点:使用原有逻辑 + input_parameters = node_config.params + if node_config.node not in ("Empty"): + _, output_parameters = await NodeManager.get_node_params(node_config.node) + else: + output_parameters = {} + parameters = { + "input_parameters": input_parameters, + "output_parameters": Slot(output_parameters).extract_type_desc_from_schema(), + } + node_item = NodeItem( stepId=node_id, nodeId=node_config.node, @@ -275,8 +281,7 @@ class FlowManager: editable=True, callId=node_config.type, parameters=parameters, - position=PositionItem( - x=node_config.pos.x, y=node_config.pos.y), + position=PositionItem(x=node_config.pos.x, y=node_config.pos.y), ) flow_item.nodes.append(node_item) @@ -384,13 +389,19 @@ class FlowManager: debug=flow_item.debug, ) for node_item in flow_item.nodes: + # 对于Code节点,保存完整的parameters;其他节点只保存input_parameters + if node_item.call_id == "Code": + params = node_item.parameters # 保存完整的parameters(包含input_parameters、output_parameters以及code配置) + else: + params = node_item.parameters.get("input_parameters", {}) # 其他节点只保存input_parameters + flow_config.steps[node_item.step_id] = Step( type=node_item.call_id, node=node_item.node_id, name=node_item.name, description=node_item.description, pos=node_item.position, - params=node_item.parameters.get("input_parameters", {}), + params=params, ) for edge_item in flow_item.edges: edge_from = edge_item.source_node diff --git a/assets/.config.example.toml b/assets/.config.example.toml index 61f13e49a893e8a698df3113148226c8a6a7463a..9068706ee0d5743b3212258af283177b281c4cd4 100644 --- a/assets/.config.example.toml +++ b/assets/.config.example.toml @@ -62,5 +62,8 @@ temperature = 0.7 enable = false words_list = '' +[sandbox] +sandbox_service = 'http://127.0.0.1:8000' + [extra] sql_url = 'http://127.0.0.1:9015' diff --git a/deploy/chart/euler_copilot/configs/framework/config.toml b/deploy/chart/euler_copilot/configs/framework/config.toml index 64a1a6a03600b1cfc48c504f32c1e8c9ce4350cb..50b3b7ed7e04bca5d1963ed7e1c7ec2d81e62b6e 100644 --- a/deploy/chart/euler_copilot/configs/framework/config.toml +++ b/deploy/chart/euler_copilot/configs/framework/config.toml @@ -62,5 +62,8 @@ temperature = {{ default 0.7 .Values.models.functionCall.temperature }} enable = false words_list = "" +[sandbox] +sandbox_service = 'http://euler-copilot-sandbox-service.{{ .Release.Namespace }}.svc.cluster.local:8000' + [extra] sql_url = '' diff --git a/euler_copilot_framework.egg-info/PKG-INFO b/euler_copilot_framework.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..768d3f89fce727f40aaee643151eddd45dc3fd9d --- /dev/null +++ b/euler_copilot_framework.egg-info/PKG-INFO @@ -0,0 +1,38 @@ +Metadata-Version: 2.4 +Name: euler-copilot-framework +Version: 0.9.6 +Summary: EulerCopilot 后端服务 +Requires-Python: ==3.11.6 +License-File: LICENSE +Requires-Dist: aiofiles==24.1.0 +Requires-Dist: asyncer==0.0.8 +Requires-Dist: asyncpg==0.30.0 +Requires-Dist: cryptography==44.0.2 +Requires-Dist: fastapi==0.115.12 +Requires-Dist: httpx==0.28.1 +Requires-Dist: httpx-sse==0.4.0 +Requires-Dist: jinja2==3.1.6 +Requires-Dist: jionlp==1.5.20 +Requires-Dist: jsonschema==4.23.0 +Requires-Dist: lancedb==0.21.2 +Requires-Dist: mcp==1.9.4 +Requires-Dist: minio==7.2.15 +Requires-Dist: ollama==0.5.1 +Requires-Dist: openai==1.91.0 +Requires-Dist: pandas==2.2.3 +Requires-Dist: pgvector==0.4.1 +Requires-Dist: pillow==10.3.0 +Requires-Dist: pydantic==2.11.7 +Requires-Dist: pymongo==4.12.1 +Requires-Dist: python-jsonpath==1.3.0 +Requires-Dist: python-magic==0.4.27 +Requires-Dist: python-multipart==0.0.20 +Requires-Dist: pytz==2025.2 +Requires-Dist: pyyaml==6.0.2 +Requires-Dist: rich==13.9.4 +Requires-Dist: sqids==0.5.1 +Requires-Dist: sqlalchemy==2.0.41 +Requires-Dist: tiktoken==0.9.0 +Requires-Dist: toml==0.10.2 +Requires-Dist: uvicorn==0.34.0 +Dynamic: license-file diff --git a/euler_copilot_framework.egg-info/SOURCES.txt b/euler_copilot_framework.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..3beeaae71ae08e9482a4157601718053efd51c51 --- /dev/null +++ b/euler_copilot_framework.egg-info/SOURCES.txt @@ -0,0 +1,12 @@ +LICENSE +README.md +pyproject.toml +apps/__init__.py +apps/constants.py +apps/exceptions.py +apps/main.py +euler_copilot_framework.egg-info/PKG-INFO +euler_copilot_framework.egg-info/SOURCES.txt +euler_copilot_framework.egg-info/dependency_links.txt +euler_copilot_framework.egg-info/requires.txt +euler_copilot_framework.egg-info/top_level.txt \ No newline at end of file diff --git a/euler_copilot_framework.egg-info/dependency_links.txt b/euler_copilot_framework.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/euler_copilot_framework.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/euler_copilot_framework.egg-info/requires.txt b/euler_copilot_framework.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..de92c758f82c631f9fd0ef4277e15568d0e2d6bb --- /dev/null +++ b/euler_copilot_framework.egg-info/requires.txt @@ -0,0 +1,31 @@ +aiofiles==24.1.0 +asyncer==0.0.8 +asyncpg==0.30.0 +cryptography==44.0.2 +fastapi==0.115.12 +httpx==0.28.1 +httpx-sse==0.4.0 +jinja2==3.1.6 +jionlp==1.5.20 +jsonschema==4.23.0 +lancedb==0.21.2 +mcp==1.9.4 +minio==7.2.15 +ollama==0.5.1 +openai==1.91.0 +pandas==2.2.3 +pgvector==0.4.1 +pillow==10.3.0 +pydantic==2.11.7 +pymongo==4.12.1 +python-jsonpath==1.3.0 +python-magic==0.4.27 +python-multipart==0.0.20 +pytz==2025.2 +pyyaml==6.0.2 +rich==13.9.4 +sqids==0.5.1 +sqlalchemy==2.0.41 +tiktoken==0.9.0 +toml==0.10.2 +uvicorn==0.34.0 diff --git a/euler_copilot_framework.egg-info/top_level.txt b/euler_copilot_framework.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..ea5a6165b61e9169f28b337d9757f02c7af34ba1 --- /dev/null +++ b/euler_copilot_framework.egg-info/top_level.txt @@ -0,0 +1 @@ +apps diff --git a/pyproject.toml b/pyproject.toml index 681ea9fb4552650da9148a22dd55d2063dca2410..2f8222898360e76094f66a8961065394ee94902d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,3 +51,10 @@ dev = [ "sphinx==8.2.3", "sphinx-rtd-theme==3.0.2", ] + +[build-system] +requires = ["setuptools>=65"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +packages = ["apps"] # 仅包含apps目录 \ No newline at end of file diff --git a/tests/common/test_config.py b/tests/common/test_config.py index 6c77141ca78e662ab12a8ec3c1ae017f3de6fced..9d90da2efd8b6f81a14136a6787ba38e157ebd16 100644 --- a/tests/common/test_config.py +++ b/tests/common/test_config.py @@ -68,6 +68,7 @@ MOCK_CONFIG_DATA: dict[str, Any] = { "jwt_key": "test_jwt_key", }, "check": {"enable": False, "words_list": ""}, + "sandbox": {"sandbox_service": "http://localhost:8000"}, "extra": {"sql_url": "http://localhost"}, }