diff --git a/apps/main.py b/apps/main.py index c4ca2bfb116db11624f4e4fbbe4e876a86e5f1eb..c26e5e471bd45a066c0af2d13b25a4028de02b12 100644 --- a/apps/main.py +++ b/apps/main.py @@ -36,6 +36,7 @@ from apps.routers import ( record, service, user, + parameter ) from apps.scheduler.pool.pool import Pool @@ -66,6 +67,7 @@ app.include_router(llm.router) app.include_router(mcp_service.router) app.include_router(flow.router) app.include_router(user.router) +app.include_router(parameter.router) # logger配置 LOGGER_FORMAT = "%(funcName)s() - %(message)s" diff --git a/apps/routers/parameter.py b/apps/routers/parameter.py index 538c25568d828fb64076fa40de6d1774bbce2af5..6edbe2e142cb6589be8947c28ae2eb4a7287baa1 100644 --- a/apps/routers/parameter.py +++ b/apps/routers/parameter.py @@ -5,13 +5,10 @@ from fastapi.responses import JSONResponse from apps.dependency import get_user from apps.dependency.user import verify_user -from apps.scheduler.call.choice.choice import Choice +from apps.services.parameter import ParameterManager from apps.schemas.response_data import ( - FlowStructureGetMsg, - FlowStructureGetRsp, - GetParamsMsg, - GetParamsRsp, - ResponseData, + GetOperaRsp, + GetParamsRsp ) from apps.services.application import AppManager from apps.services.flow import FlowManager @@ -25,10 +22,7 @@ router = APIRouter( ) -@router.get("", response_model={ - status.HTTP_403_FORBIDDEN: {"model": ResponseData}, - status.HTTP_404_NOT_FOUND: {"model": ResponseData}, - },) +@router.get("", response_model=GetParamsRsp) async def get_parameters( user_sub: Annotated[str, Depends(get_user)], app_id: Annotated[str, Query(alias="appId")], @@ -39,105 +33,45 @@ async def get_parameters( if not await AppManager.validate_user_app_access(user_sub, app_id): return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, - content=FlowStructureGetRsp( + content=GetParamsRsp( code=status.HTTP_403_FORBIDDEN, message="用户没有权限访问该流", - result=FlowStructureGetMsg(), + result=[], ).model_dump(exclude_none=True, by_alias=True), ) flow = await FlowManager.get_flow_by_app_and_flow_id(app_id, flow_id) if not flow: return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, - content=FlowStructureGetRsp( + content=GetParamsRsp( code=status.HTTP_404_NOT_FOUND, message="未找到该流", - result={}, - ).model_dump(exclude_none=True, by_alias=True), - ) - result = await FlowManager.get_params_by_flow_and_step_id(flow, step_id) - if not result: - return JSONResponse( - status_code=status.HTTP_404_NOT_FOUND, - content=FlowStructureGetRsp( - code=status.HTTP_404_NOT_FOUND, - message="未找到该节点", - result={}, + result=[], ).model_dump(exclude_none=True, by_alias=True), ) + result = await ParameterManager.get_pre_params_by_flow_and_step_id(flow, step_id) return JSONResponse( status_code=status.HTTP_200_OK, content=GetParamsRsp( code=status.HTTP_200_OK, message="获取参数成功", - result=GetParamsMsg(result=result), + result=result ).model_dump(exclude_none=True, by_alias=True), ) -async def operate_parameters(operate: str) -> list[str] | None: - """ - 根据操作类型获取对应的操作符参数列表 - - Args: - operate: 操作类型,支持 'int', 'str', 'bool' - - Returns: - 对应的操作符参数列表,若类型不支持则返回None - """ - string = [ - "equal", - "not_equal", - "great",#长度大于 - "great_equals",#长度大于等于 - "less",#长度小于 - "less_equals",#长度小于等于 - "greater", - "greater_equals", - "smaller", - "smaller_equals", - ] - integer = [ - "equal", - "not_equal", - "great", - "great_equals", - "less", - "less_equals", - ] - boolen = ["equal", "not_equal", "is_empty", "not_empty"] - if operate in string: - return string - if operate in integer: - return integer - if operate in boolen: - return boolen - return None - -@router.get("/operate", response_model={ - status.HTTP_200_OK: {"model": ResponseData}, - status.HTTP_404_NOT_FOUND: {"model": ResponseData}, - },) +@router.get("/operate", response_model=GetOperaRsp) async def get_operate_parameters( user_sub: Annotated[str, Depends(get_user)], - operate: Annotated[str, Query(alias="operate")], + param_type: Annotated[str, Query(alias="ParamType")], ) -> JSONResponse: """Get parameters for node choice.""" - result = await operate_parameters(operate) - if not result: - return JSONResponse( - status_code=status.HTTP_404_NOT_FOUND, - content=ResponseData( - code=status.HTTP_404_NOT_FOUND, - message="未找到该符号", - result=[], - ).model_dump(exclude_none=True, by_alias=True), - ) + result = await ParameterManager.get_operate_and_bind_type(param_type) return JSONResponse( status_code=status.HTTP_200_OK, - content=ResponseData( + content=GetOperaRsp( code=status.HTTP_200_OK, - message="获取参数成功", - result=result, + message="获取操作成功", + result=result ).model_dump(exclude_none=True, by_alias=True), - ) \ No newline at end of file + ) diff --git a/apps/scheduler/call/__init__.py b/apps/scheduler/call/__init__.py index 2ee8b862885b0d88e519bf00a1678a49863df2bd..c5a6f0549c6891bcdf74052ddb8152fb550653b2 100644 --- a/apps/scheduler/call/__init__.py +++ b/apps/scheduler/call/__init__.py @@ -8,7 +8,7 @@ from apps.scheduler.call.mcp.mcp import MCP from apps.scheduler.call.rag.rag import RAG from apps.scheduler.call.sql.sql import SQL from apps.scheduler.call.suggest.suggest import Suggestion - +from apps.scheduler.call.choice.choice import Choice # 只包含需要在编排界面展示的工具 __all__ = [ "API", @@ -18,4 +18,5 @@ __all__ = [ "SQL", "Graph", "Suggestion", + "Choice" ] diff --git a/apps/scheduler/call/choice/choice.py b/apps/scheduler/call/choice/choice.py index 24df0daea03a47bad397d17d2986dc97d85390f9..8cab828889db4b0d97ff87318760adb507a9f3b1 100644 --- a/apps/scheduler/call/choice/choice.py +++ b/apps/scheduler/call/choice/choice.py @@ -1,6 +1,8 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """使用大模型或使用程序做出判断""" +import ast +import copy import logging from collections.abc import AsyncGenerator from typing import Any @@ -9,11 +11,13 @@ from pydantic import Field from apps.scheduler.call.choice.condition_handler import ConditionHandler from apps.scheduler.call.choice.schema import ( + Condition, ChoiceBranch, ChoiceInput, ChoiceOutput, Logic, ) +from apps.schemas.parameters import Type from apps.scheduler.call.core import CoreCall from apps.schemas.enum_var import CallOutputType from apps.schemas.scheduler import ( @@ -30,17 +34,13 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): """Choice工具""" to_user: bool = Field(default=False) - choices: list[ChoiceBranch] = Field(description="分支", default=[]) + choices: list[ChoiceBranch] = Field(description="分支", default=[ChoiceBranch(), + ChoiceBranch(conditions=[Condition()], is_default=False)]) @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" - return CallInfo(name="Choice", description="使用大模型或使用程序做出判断") - - def _raise_value_error(self, msg: str) -> None: - """统一处理 ValueError 异常抛出""" - logger.warning(msg) - raise ValueError(msg) + return CallInfo(name="选择器", description="使用大模型或使用程序做出判断") async def _prepare_message(self, call_vars: CallVars) -> list[dict[str, Any]]: """替换choices中的系统变量""" @@ -51,31 +51,76 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): # 验证逻辑运算符 if choice.logic not in [Logic.AND, Logic.OR]: msg = f"无效的逻辑运算符: {choice.logic}" - self._raise_value_error(msg) + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue valid_conditions = [] - for condition in choice.conditions: + for i in range(len(choice.conditions)): + condition = copy.deepcopy(choice.conditions[i]) # 处理左值 - if condition.left.step_id: - condition.left.value = self._extract_history_variables(condition.left.step_id, call_vars.history) + if condition.left.step_id is not None: + condition.left.value = self._extract_history_variables( + condition.left.step_id+'/'+condition.left.value, call_vars.history) # 检查历史变量是否成功提取 - if condition.left.value is None: - msg = f"步骤 {condition.left.step_id} 的历史变量不存在" - self._raise_value_error(msg) + if condition.left.value is None: + msg = f"步骤 {condition.left.step_id} 的历史变量不存在" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + if not ConditionHandler.check_value_type( + condition.left.value, condition.left.type): + msg = f"左值类型不匹配: {condition.left.value} 应为 {condition.left.type.value}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue else: msg = "左侧变量缺少step_id" - self._raise_value_error(msg) - - valid_conditions.append(condition) + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + # 处理右值 + if condition.right.step_id is not None: + condition.right.value = self._extract_history_variables( + condition.right.step_id+'/'+condition.right.value, call_vars.history) + # 检查历史变量是否成功提取 + if condition.right.value is None: + msg = f"步骤 {condition.right.step_id} 的历史变量不存在" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + if not ConditionHandler.check_value_type( + condition.right.value, condition.right.type): + msg = f"右值类型不匹配: {condition.right.value} 应为 {condition.right.type.value}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + else: + # 如果右值没有step_id,尝试从call_vars中获取 + right_value_type = await ConditionHandler.get_value_type_from_operate( + condition.operate) + if right_value_type is None: + msg = f"不支持的运算符: {condition.operate}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + if condition.right.type != right_value_type: + msg = f"右值类型不匹配: {condition.right.value} 应为 {right_value_type.value}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + if right_value_type == Type.STRING: + condition.right.value = str(condition.right.value) + else: + condition.right.value = ast.literal_eval(condition.right.value) + if not ConditionHandler.check_value_type( + condition.right.value, condition.right.type): + msg = f"右值类型不匹配: {condition.right.value} 应为 {condition.right.type.value}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + valid_conditions.append(condition) # 如果所有条件都无效,抛出异常 if not valid_conditions: msg = "分支没有有效条件" - self._raise_value_error(msg) + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue # 更新有效条件 choice.conditions = valid_conditions - valid_choices.append(choice.dict()) + valid_choices.append(choice) except ValueError as e: logger.warning("分支 %s 处理失败: %s,已跳过", choice.branch_id, str(e)) @@ -95,13 +140,11 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): """执行Choice工具""" # 解析输入数据 data = ChoiceInput(**input_data) - ret: CallOutputChunk = CallOutputChunk( - type=CallOutputType.DATA, - content=None, - ) - condition_handler = ConditionHandler() try: - ret.content = condition_handler.handler(data.choices) - yield ret + branch_id = ConditionHandler.handler(data.choices) + yield CallOutputChunk( + type=CallOutputType.DATA, + content=ChoiceOutput(branch_id=branch_id).model_dump(exclude_none=True, by_alias=True), + ) except Exception as e: raise CallError(message=f"选择工具调用失败:{e!s}", data={}) from e diff --git a/apps/scheduler/call/choice/condition_handler.py b/apps/scheduler/call/choice/condition_handler.py index feab5c6470ec2d0868c7a72984a2d684665507c3..7542f2944eedff72b820a79725d64b5ddeba5a11 100644 --- a/apps/scheduler/call/choice/condition_handler.py +++ b/apps/scheduler/call/choice/condition_handler.py @@ -1,19 +1,79 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """处理条件分支的工具""" + import logging from pydantic import BaseModel -from apps.scheduler.call.choice.schema import ChoiceBranch, ChoiceOutput, Condition, Logic, Operator, Type, Value +from apps.schemas.parameters import ( + Type, + NumberOperate, + StringOperate, + ListOperate, + BoolOperate, + DictOperate, +) + +from apps.scheduler.call.choice.schema import ( + ChoiceBranch, + Condition, + Logic, + Value +) logger = logging.getLogger(__name__) class ConditionHandler(BaseModel): """条件分支处理器""" + @staticmethod + async def get_value_type_from_operate(operate: NumberOperate | StringOperate | ListOperate | + BoolOperate | DictOperate) -> Type: + """获取右值的类型""" + if isinstance(operate, NumberOperate): + return Type.NUMBER + if operate in [ + StringOperate.EQUAL, StringOperate.NOT_EQUAL, StringOperate.CONTAINS, StringOperate.NOT_CONTAINS, + StringOperate.STARTS_WITH, StringOperate.ENDS_WITH, StringOperate.REGEX_MATCH]: + return Type.STRING + if operate in [StringOperate.LENGTH_EQUAL, StringOperate.LENGTH_GREATER_THAN, + StringOperate.LENGTH_GREATER_THAN_OR_EQUAL, StringOperate.LENGTH_LESS_THAN, + StringOperate.LENGTH_LESS_THAN_OR_EQUAL]: + return Type.NUMBER + if operate in [ListOperate.EQUAL, ListOperate.NOT_EQUAL]: + return Type.LIST + if operate in [ListOperate.CONTAINS, ListOperate.NOT_CONTAINS]: + return Type.STRING + if operate in [ListOperate.LENGTH_EQUAL, ListOperate.LENGTH_GREATER_THAN, + ListOperate.LENGTH_GREATER_THAN_OR_EQUAL, ListOperate.LENGTH_LESS_THAN, + ListOperate.LENGTH_LESS_THAN_OR_EQUAL]: + return Type.NUMBER + if operate in [BoolOperate.EQUAL, BoolOperate.NOT_EQUAL]: + return Type.BOOL + if operate in [DictOperate.EQUAL, DictOperate.NOT_EQUAL]: + return Type.DICT + if operate in [DictOperate.CONTAINS_KEY, DictOperate.NOT_CONTAINS_KEY]: + return Type.STRING + return None + + @staticmethod + def check_value_type(value: Value, expected_type: Type) -> bool: + """检查值的类型是否符合预期""" + if expected_type == Type.STRING and isinstance(value.value, str): + return True + if expected_type == Type.NUMBER and isinstance(value.value, (int, float)): + return True + if expected_type == Type.LIST and isinstance(value.value, list): + return True + if expected_type == Type.DICT and isinstance(value.value, dict): + return True + if expected_type == Type.BOOL and isinstance(value.value, bool): + return True + return False - def handler(self, choices: list[ChoiceBranch]) -> ChoiceOutput: + @staticmethod + def handler(choices: list[ChoiceBranch]) -> str: """处理条件""" default_branch = [c for c in choices if c.is_default] @@ -22,7 +82,7 @@ class ConditionHandler(BaseModel): if block_judgement.is_default: continue for condition in block_judgement.conditions: - result = self._judge_condition(condition) + result = ConditionHandler._judge_condition(condition) results.append(result) if block_judgement.logic == Logic.AND: final_result = all(results) @@ -30,58 +90,55 @@ class ConditionHandler(BaseModel): final_result = any(results) if final_result: - return { - "branch_id": block_judgement.branch_id, - "message": f"选择分支:{block_judgement.branch_id}", - } + return block_judgement.branch_id # 如果没有匹配的分支,选择默认分支 if default_branch: - return { - "branch_id": default_branch[0].branch_id, - "message": f"选择默认分支:{default_branch[0].branch_id}", - } - return { - "branch_id": "", - "message": "没有匹配的分支,且没有默认分支", - } - - def _judge_condition(self, condition: Condition) -> bool: + return default_branch[0].branch_id + return "" + + @staticmethod + def _judge_condition(condition: Condition) -> bool: """ 判断条件是否成立。 Args: - condition (Condition): 'left', 'operator', 'right', 'type' + condition (Condition): 'left', 'operate', 'right', 'type' Returns: bool """ left = condition.left - operator = condition.operator + operate = condition.operate right = condition.right value_type = condition.type result = None if value_type == Type.STRING: - result = self._judge_string_condition(left, operator, right) - elif value_type == Type.INT: - result = self._judge_int_condition(left, operator, right) + result = ConditionHandler._judge_string_condition(left, operate, right) + elif value_type == Type.NUMBER: + result = ConditionHandler._judge_int_condition(left, operate, right) elif value_type == Type.BOOL: - result = self._judge_bool_condition(left, operator, right) + result = ConditionHandler._judge_bool_condition(left, operate, right) + elif value_type == Type.LIST: + result = ConditionHandler._judge_list_condition(left, operate, right) + elif value_type == Type.DICT: + result = ConditionHandler._judge_dict_condition(left, operate, right) else: logger.error("不支持的数据类型: %s", value_type) msg = f"不支持的数据类型: {value_type}" raise ValueError(msg) return result - def _judge_string_condition(self, left: Value, operator: Operator, right: Value) -> bool: + @staticmethod + def _judge_string_condition(left: Value, operate: StringOperate, right: Value) -> bool: """ 判断字符串类型的条件。 Args: left (Value): 左值,包含 'value' 键。 - operator (Operator): 操作符 + operate (Operate): 操作符 right (Value): 右值,包含 'value' 键。 Returns: @@ -95,39 +152,41 @@ class ConditionHandler(BaseModel): raise TypeError(msg) right_value = right.value result = False - if operator == Operator.EQUAL: - result = left_value == right_value - elif operator == Operator.NEQUAL: - result = left_value != right_value - elif operator == Operator.GREAT: - result = len(left_value) > len(right_value) - elif operator == Operator.GREAT_EQUALS: - result = len(left_value) >= len(right_value) - elif operator == Operator.LESS: - result = len(left_value) < len(right_value) - elif operator == Operator.LESS_EQUALS: - result = len(left_value) <= len(right_value) - elif operator == Operator.GREATER: - result = left_value > right_value - elif operator == Operator.GREATER_EQUALS: - result = left_value >= right_value - elif operator == Operator.SMALLER: - result = left_value < right_value - elif operator == Operator.SMALLER_EQUALS: - result = left_value <= right_value - elif operator == Operator.CONTAINS: - result = right_value in left_value - elif operator == Operator.NOT_CONTAINS: - result = right_value not in left_value - return result + if operate == StringOperate.EQUAL: + return left_value == right_value + elif operate == StringOperate.NOT_EQUAL: + return left_value != right_value + elif operate == StringOperate.CONTAINS: + return right_value in left_value + elif operate == StringOperate.NOT_CONTAINS: + return right_value not in left_value + elif operate == StringOperate.STARTS_WITH: + return left_value.startswith(right_value) + elif operate == StringOperate.ENDS_WITH: + return left_value.endswith(right_value) + elif operate == StringOperate.REGEX_MATCH: + import re + return bool(re.match(right_value, left_value)) + elif operate == StringOperate.LENGTH_EQUAL: + return len(left_value) == right_value + elif operate == StringOperate.LENGTH_GREATER_THAN: + return len(left_value) > right_value + elif operate == StringOperate.LENGTH_GREATER_THAN_OR_EQUAL: + return len(left_value) >= right_value + elif operate == StringOperate.LENGTH_LESS_THAN: + return len(left_value) < right_value + elif operate == StringOperate.LENGTH_LESS_THAN_OR_EQUAL: + return len(left_value) <= right_value + return False - def _judge_int_condition(self, left: Value, operator: Operator, right: Value) -> bool: # noqa: PLR0911 + @staticmethod + def _judge_number_condition(left: Value, operate: NumberOperate, right: Value) -> bool: # noqa: PLR0911 """ - 判断整数类型的条件。 + 判断数字类型的条件。 Args: left (Value): 左值,包含 'value' 键。 - operator (Operator): 操作符 + operate (Operate): 操作符 right (Value): 右值,包含 'value' 键。 Returns: @@ -135,32 +194,33 @@ class ConditionHandler(BaseModel): """ left_value = left.value - if not isinstance(left_value, int): - logger.error("左值不是整数类型: %s", left_value) - msg = "左值必须是整数类型" + if not isinstance(left_value, (int, float)): + logger.error("左值不是数字类型: %s", left_value) + msg = "左值必须是数字类型" raise TypeError(msg) right_value = right.value - if operator == Operator.EQUAL: + if operate == NumberOperate.EQUAL: return left_value == right_value - if operator == Operator.NEQUAL: + elif operate == NumberOperate.NOT_EQUAL: return left_value != right_value - if operator == Operator.GREAT: + elif operate == NumberOperate.GREATER_THAN: return left_value > right_value - if operator == Operator.GREAT_EQUALS: - return left_value >= right_value - if operator == Operator.LESS: + elif operate == NumberOperate.LESS_THAN: # noqa: PLR2004 return left_value < right_value - if operator == Operator.LESS_EQUALS: + elif operate == NumberOperate.GREATER_THAN_OR_EQUAL: + return left_value >= right_value + elif operate == NumberOperate.LESS_THAN_OR_EQUAL: return left_value <= right_value return False - def _judge_bool_condition(self, left: Value, operator: Operator, right: Value) -> bool: + @staticmethod + def _judge_bool_condition(left: Value, operate: BoolOperate, right: Value) -> bool: """ 判断布尔类型的条件。 Args: left (Value): 左值,包含 'value' 键。 - operator (Operator): 操作符 + operate (Operate): 操作符 right (Value): 右值,包含 'value' 键。 Returns: @@ -173,12 +233,82 @@ class ConditionHandler(BaseModel): msg = "左值必须是布尔类型" raise TypeError(msg) right_value = right.value - if operator == Operator.EQUAL: + if operate == BoolOperate.EQUAL: + return left_value == right_value + elif operate == BoolOperate.NOT_EQUAL: + return left_value != right_value + elif operate == BoolOperate.IS_EMPTY: + return not left_value + elif operate == BoolOperate.NOT_EMPTY: + return left_value + return False + + @staticmethod + def _judge_list_condition(left: Value, operate: ListOperate, right: Value): + """ + 判断列表类型的条件。 + + Args: + left (Value): 左值,包含 'value' 键。 + operate (Operate): 操作符 + right (Value): 右值,包含 'value' 键。 + + Returns: + bool + + """ + left_value = left.value + if not isinstance(left_value, list): + logger.error("左值不是列表类型: %s", left_value) + msg = "左值必须是列表类型" + raise TypeError(msg) + right_value = right.value + if operate == ListOperate.EQUAL: + return left_value == right_value + elif operate == ListOperate.NOT_EQUAL: + return left_value != right_value + elif operate == ListOperate.CONTAINS: + return right_value in left_value + elif operate == ListOperate.NOT_CONTAINS: + return right_value not in left_value + elif operate == ListOperate.LENGTH_EQUAL: + return len(left_value) == right_value + elif operate == ListOperate.LENGTH_GREATER_THAN: + return len(left_value) > right_value + elif operate == ListOperate.LENGTH_GREATER_THAN_OR_EQUAL: + return len(left_value) >= right_value + elif operate == ListOperate.LENGTH_LESS_THAN: + return len(left_value) < right_value + elif operate == ListOperate.LENGTH_LESS_THAN_OR_EQUAL: + return len(left_value) <= right_value + return False + + @staticmethod + def _judge_dict_condition(left: Value, operate: DictOperate, right: Value): + """ + 判断字典类型的条件。 + + Args: + left (Value): 左值,包含 'value' 键。 + operate (Operate): 操作符 + right (Value): 右值,包含 'value' 键。 + + Returns: + bool + + """ + left_value = left.value + if not isinstance(left_value, dict): + logger.error("左值不是字典类型: %s", left_value) + msg = "左值必须是字典类型" + raise TypeError(msg) + right_value = right.value + if operate == DictOperate.EQUAL: return left_value == right_value - if operator == Operator.NEQUAL: + elif operate == DictOperate.NOT_EQUAL: return left_value != right_value - if operator == Operator.IS_EMPTY: - return left_value == "" - if operator == Operator.NOT_EMPTY: - return left_value != "" + elif operate == DictOperate.CONTAINS_KEY: + return right_value in left_value + elif operate == DictOperate.NOT_CONTAINS_KEY: + return right_value not in left_value return False diff --git a/apps/scheduler/call/choice/schema.py b/apps/scheduler/call/choice/schema.py index ed1c628cf002112750165dc3c23d4f493eeeafbc..b95b166879608bc431525958e7e120ad93c5e5a3 100644 --- a/apps/scheduler/call/choice/schema.py +++ b/apps/scheduler/call/choice/schema.py @@ -1,30 +1,20 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """Choice Call的输入和输出""" +import uuid from enum import Enum from pydantic import BaseModel, Field +from apps.schemas.parameters import ( + Type, + NumberOperate, + StringOperate, + ListOperate, + BoolOperate, + DictOperate, +) from apps.scheduler.call.core import DataBase -class Operator(str, Enum): - """Choice Call支持的运算符""" - - EQUAL = "equal" - NEQUAL = "not_equal" - GREAT = "great" - GREAT_EQUALS = "great_equals" - LESS = "less" - LESS_EQUALS = "less_equals" - # string - CONTAINS = "contains" - NOT_CONTAINS = "not_contains" - GREATER = "greater" - GREATER_EQUALS = "greater_equals" - SMALLER = "smaller" - SMALLER_EQUALS = "smaller_equals" - # bool - IS_EMPTY = "is_empty" - NOT_EMPTY = "not_empty" class Logic(str, Enum): @@ -34,38 +24,31 @@ class Logic(str, Enum): OR = "or" -class Type(str, Enum): - """Choice 工具支持的类型""" - - STRING = "string" - INT = "int" - BOOL = "bool" - - -class Value(BaseModel): +class Value(DataBase): """值的结构""" - step_id: str = Field(description="步骤id", default="") - value: str | int | bool = Field(description="值", default=None) + step_id: str | None = Field(description="步骤id", default=None) + type: Type | None = Field(description="值的类型", default=None) + value: str | float | int | bool | list | dict | None = Field(description="值", default=None) -class Condition(BaseModel): +class Condition(DataBase): """单个条件""" - type: Type = Field(description="值的类型", default=Type.STRING) - left: Value = Field(description="左值") - right: Value = Field(description="右值") - operator: Operator = Field(description="运算符", default="equal") - id: int = Field(description="条件ID") + left: Value = Field(description="左值", default=Value()) + right: Value = Field(description="右值", default=Value()) + operate: NumberOperate | StringOperate | ListOperate | BoolOperate | DictOperate | None = Field( + description="运算符", default=None) + id: str = Field(description="条件ID", default_factory=lambda: str(uuid.uuid4())) -class ChoiceBranch(BaseModel): +class ChoiceBranch(DataBase): """子分支""" - branch_id: str = Field(description="分支ID", default="") + branch_id: str = Field(description="分支ID", default_factory=lambda: str(uuid.uuid4())) logic: Logic = Field(description="逻辑运算符", default=Logic.AND) conditions: list[Condition] = Field(description="条件列表", default=[]) - is_default: bool = Field(description="是否为默认分支", default=False) + is_default: bool = Field(description="是否为默认分支", default=True) class ChoiceInput(DataBase): @@ -76,3 +59,5 @@ class ChoiceInput(DataBase): class ChoiceOutput(DataBase): """Choice Call的输出""" + + branch_id: str = Field(description="分支ID", default="") diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index 2b1cbba83b9345e8df84e8f23f893137cb46532b..af28c6a34150d26daa12e0ae3fc200fa26a82cda 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -76,21 +76,18 @@ class CoreCall(BaseModel): extra="allow", ) - def __init_subclass__(cls, input_model: type[DataBase], output_model: type[DataBase], **kwargs: Any) -> None: """初始化子类""" super().__init_subclass__(**kwargs) cls.input_model = input_model cls.output_model = output_model - @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" err = "[CoreCall] 必须手动实现info方法" raise NotImplementedError(err) - @staticmethod def _assemble_call_vars(executor: "StepExecutor") -> CallVars: """组装CallVars""" @@ -120,7 +117,6 @@ class CoreCall(BaseModel): summary=executor.task.runtime.summary, ) - @staticmethod def _extract_history_variables(path: str, history: dict[str, FlowStepHistory]) -> Any: """ @@ -131,18 +127,16 @@ class CoreCall(BaseModel): :return: 变量 """ split_path = path.split("/") + if len(split_path) < 2: + err = f"[CoreCall] 路径格式错误: {path}" + logger.error(err) + return None if split_path[0] not in history: err = f"[CoreCall] 步骤{split_path[0]}不存在" logger.error(err) - raise CallError( - message=err, - data={ - "step_id": split_path[0], - }, - ) - + return None data = history[split_path[0]].output_data - for key in split_path[1:]: + for key in split_path[2:]: if key not in data: err = f"[CoreCall] 输出Key {key} 不存在" logger.error(err) @@ -156,7 +150,6 @@ class CoreCall(BaseModel): data = data[key] return data - @classmethod async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self: """实例化Call类""" @@ -170,36 +163,30 @@ class CoreCall(BaseModel): await obj._set_input(executor) return obj - async def _set_input(self, executor: "StepExecutor") -> None: """获取Call的输入""" self._sys_vars = self._assemble_call_vars(executor) input_data = await self._init(self._sys_vars) self.input = input_data.model_dump(by_alias=True, exclude_none=True) - async def _init(self, call_vars: CallVars) -> DataBase: """初始化Call类,并返回Call的输入""" err = "[CoreCall] 初始化方法必须手动实现" raise NotImplementedError(err) - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """Call类实例的流式输出方法""" yield CallOutputChunk(type=CallOutputType.TEXT, content="") - async def _after_exec(self, input_data: dict[str, Any]) -> None: """Call类实例的执行后方法""" - async def exec(self, executor: "StepExecutor", input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """Call类实例的执行方法""" async for chunk in self._exec(input_data): yield chunk await self._after_exec(input_data) - async def _llm(self, messages: list[dict[str, Any]]) -> str: """Call可直接使用的LLM非流式调用""" result = "" @@ -210,7 +197,6 @@ class CoreCall(BaseModel): self.output_tokens = llm.output_tokens return result - async def _json(self, messages: list[dict[str, Any]], schema: type[BaseModel]) -> BaseModel: """Call可直接使用的JSON生成""" json = FunctionLLM() diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index d8d22c46f5ef66dd707f0750ce56ea33bdb20351..a86ec4ac063c3cda5c6444ae17d7af5a6f6e7f46 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -47,7 +47,6 @@ class FlowExecutor(BaseExecutor): question: str = Field(description="用户输入") post_body_app: RequestDataApp = Field(description="请求体中的app信息") - async def load_state(self) -> None: """从数据库中加载FlowExecutor的状态""" logger.info("[FlowExecutor] 加载Executor状态") @@ -70,7 +69,6 @@ class FlowExecutor(BaseExecutor): self._reached_end: bool = False self.step_queue: deque[StepQueueItem] = deque() - async def _invoke_runner(self, queue_item: StepQueueItem) -> None: """单一Step执行""" # 创建步骤Runner @@ -90,7 +88,6 @@ class FlowExecutor(BaseExecutor): # 更新Task(已存过库) self.task = step_runner.task - async def _step_process(self) -> None: """执行当前queue里面的所有步骤(在用户看来是单一Step)""" while True: @@ -102,7 +99,6 @@ class FlowExecutor(BaseExecutor): # 执行Step await self._invoke_runner(queue_item) - async def _find_next_id(self, step_id: str) -> list[str]: """查找下一个节点""" next_ids = [] @@ -111,15 +107,14 @@ class FlowExecutor(BaseExecutor): next_ids += [edge.edge_to] return next_ids - async def _find_flow_next(self) -> list[StepQueueItem]: """在当前步骤执行前,尝试获取下一步""" # 如果当前步骤为结束,则直接返回 - if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type] + if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type] return [] if self.task.state.step_name == "Choice": # 如果是choice节点,获取分支ID - branch_id = self.task.context[-1]["output_data"].get("branch_id", None) + branch_id = self.task.context[-1]["output_data"]["branch_id"] if branch_id: self.task.state.step_id = self.task.state.step_id + "." + branch_id logger.info("[FlowExecutor] 分支ID:%s", branch_id) @@ -127,7 +122,7 @@ class FlowExecutor(BaseExecutor): logger.warning("[FlowExecutor] 没有找到分支ID,返回空列表") return [] - next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type] + next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type] # 如果step没有任何出边,直接跳到end if not next_steps: return [ @@ -146,7 +141,6 @@ class FlowExecutor(BaseExecutor): for next_step in next_steps ] - async def run(self) -> None: """ 运行流,返回各步骤结果,直到无法继续执行 @@ -159,8 +153,8 @@ class FlowExecutor(BaseExecutor): # 获取首个步骤 first_step = StepQueueItem( - step_id=self.task.state.step_id, # type: ignore[arg-type] - step=self.flow.steps[self.task.state.step_id], # type: ignore[arg-type] + step_id=self.task.state.step_id, # type: ignore[arg-type] + step=self.flow.steps[self.task.state.step_id], # type: ignore[arg-type] ) # 头插开始前的系统步骤,并执行 @@ -179,7 +173,7 @@ class FlowExecutor(BaseExecutor): # 运行Flow(未达终点) while not self._reached_end: # 如果当前步骤出错,执行错误处理步骤 - if self.task.state.status == StepStatus.ERROR: # type: ignore[arg-type] + if self.task.state.status == StepStatus.ERROR: # type: ignore[arg-type] logger.warning("[FlowExecutor] Executor出错,执行错误处理步骤") self.step_queue.clear() self.step_queue.appendleft(StepQueueItem( @@ -192,7 +186,7 @@ class FlowExecutor(BaseExecutor): params={ "user_prompt": LLM_ERROR_PROMPT.replace( "{{ error_info }}", - self.task.state.error_info["err_msg"], # type: ignore[arg-type] + self.task.state.error_info["err_msg"], # type: ignore[arg-type] ), }, ), diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 506f3bb1021bdadeb4c894b56594d3a5922f7add..f3aeb82c37b8430785e13daad4fb97ed020603e2 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -119,7 +119,6 @@ class StepExecutor(BaseExecutor): logger.exception("[StepExecutor] 初始化Call失败") raise - async def _run_slot_filling(self) -> None: """运行自动参数填充;相当于特殊Step,但是不存库""" # 判断是否需要进行自动参数填充 @@ -170,7 +169,6 @@ class StepExecutor(BaseExecutor): self.task.tokens.input_tokens += self.obj.tokens.input_tokens self.task.tokens.output_tokens += self.obj.tokens.output_tokens - async def _process_chunk( self, iterator: AsyncGenerator[CallOutputChunk, None], @@ -202,7 +200,6 @@ class StepExecutor(BaseExecutor): return content - async def run(self) -> None: """运行单个步骤""" self.validate_flow_state(self.task) diff --git a/apps/scheduler/slot/slot.py b/apps/scheduler/slot/slot.py index 4caab4d2dc945ec5ac14b7769770f94c39980a9f..89433cade929ef6917664450bb3645500ed2df5a 100644 --- a/apps/scheduler/slot/slot.py +++ b/apps/scheduler/slot/slot.py @@ -12,6 +12,8 @@ from jsonschema.exceptions import ValidationError from jsonschema.protocols import Validator from jsonschema.validators import extend +from apps.schemas.response_data import ParamsNode +from apps.scheduler.call.choice.schema import Type from apps.scheduler.slot.parser import ( SlotConstParser, SlotDateParser, @@ -221,6 +223,45 @@ class Slot: return data return _extract_type_desc(self._schema) + def get_params_node_from_schema(self, root: str = "") -> ParamsNode: + """从JSON Schema中提取ParamsNode""" + def _extract_params_node(schema_node: dict[str, Any], name: str = "", path: str = "") -> ParamsNode: + """递归提取ParamsNode""" + if "type" not in schema_node: + return None + + param_type = schema_node["type"] + if param_type == "object": + param_type = Type.DICT + elif param_type == "array": + param_type = Type.LIST + elif param_type == "string": + param_type = Type.STRING + elif param_type == "number": + param_type = Type.NUMBER + elif param_type == "boolean": + param_type = Type.BOOL + else: + logger.warning(f"[Slot] 不支持的参数类型: {param_type}") + return None + sub_params = [] + + if param_type == "object" and "properties" in schema_node: + for key, value in schema_node["properties"].items(): + sub_params.append(_extract_params_node(value, name=key, path=f"{path}/{key}")) + else: + # 对于非对象类型,直接返回空子参数 + sub_params = None + return ParamsNode(paramName=name, + paramPath=path, + paramType=param_type, + subParams=sub_params) + try: + return _extract_params_node(self._schema, name=root, path=root) + except Exception as e: + logger.error(f"[Slot] 提取ParamsNode失败: {e!s}\n{traceback.format_exc()}") + return None + def _flatten_schema(self, schema: dict[str, Any]) -> tuple[dict[str, Any], list[str]]: """将JSON Schema扁平化""" result = {} @@ -276,7 +317,6 @@ class Slot: logger.exception("[Slot] 错误schema不合法: %s", error.schema) return {}, [] - def _assemble_patch( self, key: str, @@ -329,7 +369,6 @@ class Slot: logger.info("[Slot] 组装patch: %s", patch_list) return patch_list - def convert_json(self, json_data: str | dict[str, Any]) -> dict[str, Any]: """将用户手动填充的参数专为真实JSON""" json_dict = json.loads(json_data) if isinstance(json_data, str) else json_data diff --git a/apps/schemas/config.py b/apps/schemas/config.py index 99bcccde868cbb4d1e61d792fee02a4d8aee08a8..e91f5f75b836e61bcf45680c592f27b520683ff4 100644 --- a/apps/schemas/config.py +++ b/apps/schemas/config.py @@ -129,7 +129,7 @@ class ExtraConfig(BaseModel): class ConfigModel(BaseModel): """配置文件的校验Class""" - no_auth: NoauthConfig + no_auth: NoauthConfig = Field(description="无认证配置", default=NoauthConfig()) deploy: DeployConfig login: LoginConfig embedding: EmbeddingConfig diff --git a/apps/schemas/parameters.py b/apps/schemas/parameters.py new file mode 100644 index 0000000000000000000000000000000000000000..bd908d2375415c798f4804c4eabde797f6a4f7a0 --- /dev/null +++ b/apps/schemas/parameters.py @@ -0,0 +1,69 @@ +from enum import Enum + + +class NumberOperate(str, Enum): + """Choice 工具支持的数字运算符""" + + EQUAL = "number_equal" + NOT_EQUAL = "number_not_equal" + GREATER_THAN = "number_greater_than" + LESS_THAN = "number_less_than" + GREATER_THAN_OR_EQUAL = "number_greater_than_or_equal" + LESS_THAN_OR_EQUAL = "number_less_than_or_equal" + + +class StringOperate(str, Enum): + """Choice 工具支持的字符串运算符""" + + EQUAL = "string_equal" + NOT_EQUAL = "string_not_equal" + CONTAINS = "string_contains" + NOT_CONTAINS = "string_not_contains" + STARTS_WITH = "string_starts_with" + ENDS_WITH = "string_ends_with" + LENGTH_EQUAL = "string_length_equal" + LENGTH_GREATER_THAN = "string_length_greater_than" + LENGTH_GREATER_THAN_OR_EQUAL = "string_length_greater_than_or_equal" + LENGTH_LESS_THAN = "string_length_less_than" + LENGTH_LESS_THAN_OR_EQUAL = "string_length_less_than_or_equal" + REGEX_MATCH = "string_regex_match" + + +class ListOperate(str, Enum): + """Choice 工具支持的列表运算符""" + + EQUAL = "list_equal" + NOT_EQUAL = "list_not_equal" + CONTAINS = "list_contains" + NOT_CONTAINS = "list_not_contains" + LENGTH_EQUAL = "list_length_equal" + LENGTH_GREATER_THAN = "list_length_greater_than" + LENGTH_GREATER_THAN_OR_EQUAL = "list_length_greater_than_or_equal" + LENGTH_LESS_THAN = "list_length_less_than" + LENGTH_LESS_THAN_OR_EQUAL = "list_length_less_than_or_equal" + + +class BoolOperate(str, Enum): + """Choice 工具支持的布尔运算符""" + + EQUAL = "bool_equal" + NOT_EQUAL = "bool_not_equal" + + +class DictOperate(str, Enum): + """Choice 工具支持的字典运算符""" + + EQUAL = "dict_equal" + NOT_EQUAL = "dict_not_equal" + CONTAINS_KEY = "dict_contains_key" + NOT_CONTAINS_KEY = "dict_not_contains_key" + + +class Type(str, Enum): + """Choice 工具支持的类型""" + + STRING = "string" + NUMBER = "number" + LIST = "list" + DICT = "dict" + BOOL = "bool" diff --git a/apps/schemas/response_data.py b/apps/schemas/response_data.py index 20d7ad9b03c3dc58cf083240a88304cafa8ca9b8..b1dc77b75ba11943e149dc8c471bc3141d684442 100644 --- a/apps/schemas/response_data.py +++ b/apps/schemas/response_data.py @@ -14,6 +14,14 @@ from apps.schemas.flow_topology import ( NodeServiceItem, PositionItem, ) +from apps.schemas.parameters import ( + Type, + NumberOperate, + StringOperate, + ListOperate, + BoolOperate, + DictOperate, +) from apps.schemas.mcp import MCPInstallStatus, MCPTool, MCPType from apps.schemas.record import RecordData from apps.schemas.user import UserInfo @@ -629,20 +637,41 @@ class ListLLMRsp(ResponseData): result: list[LLMProviderInfo] = Field(default=[], title="Result") -class Params(BaseModel): + +class ParamsNode(BaseModel): """参数数据结构""" + param_name: str = Field(..., description="参数名称", alias="paramName") + param_path: str = Field(..., description="参数路径", alias="paramPath") + param_type: Type = Field(..., description="参数类型", alias="paramType") + sub_params: list["ParamsNode"] | None = Field( + default=None, description="子参数列表", alias="subParams" + ) - id: str = Field(..., description="StepID") - name: str = Field(..., description="Step名称") - parameters: dict[str, Any] = Field(..., description="参数") - operate: str = Field(..., description="比较符") -class GetParamsMsg(BaseModel): - """GET /api/params 返回数据结构""" +class StepParams(BaseModel): + """参数数据结构""" + step_id: str = Field(..., description="步骤ID", alias="stepId") + name: str = Field(..., description="Step名称") + params_node: ParamsNode | None = Field( + default=None, description="参数节点", alias="paramsNode") - result: list[Params] = Field(..., title="Result") class GetParamsRsp(ResponseData): """GET /api/params 返回数据结构""" - result: GetParamsMsg \ No newline at end of file + result: list[StepParams] = Field( + default=[], description="参数列表", alias="result" + ) + + +class OperateAndBindType(BaseModel): + """操作和绑定类型数据结构""" + + operate: NumberOperate | StringOperate | ListOperate | BoolOperate | DictOperate = Field(description="操作类型") + bind_type: Type = Field(description="绑定类型") + + +class GetOperaRsp(ResponseData): + """GET /api/operate 返回数据结构""" + + result: list[OperateAndBindType] = Field(..., title="Result") diff --git a/apps/services/flow.py b/apps/services/flow.py index cb8ad57fd9b3e2e8535d346f82f9cb0e71d46cdc..c9fd86fd8f789546ba459354fd0d724113a7479b 100644 --- a/apps/services/flow.py +++ b/apps/services/flow.py @@ -20,7 +20,6 @@ from apps.schemas.flow_topology import ( PositionItem, ) from apps.services.node import NodeManager -from apps.schemas.response_data import Params logger = logging.getLogger(__name__) @@ -470,48 +469,3 @@ class FlowManager: return False else: return True - - @staticmethod - async def get_params_by_flow_and_step_id( - flow: FlowItem, step_id: str - ) -> list[Params] | None: - """递归收集指定节点之前所有路径上的节点参数""" - params = [] - collected = set() # 记录已收集参数的节点 - - async def backtrack(current_id: str, visited: set) -> None: - # 避免循环递归 - if current_id in visited: - return - visited.add(current_id) - - # 获取所有指向当前节点的边 - incoming_edges = [ - edge for edge in flow.edges if edge.target_node == current_id - ] - - for edge in incoming_edges: - source_id = edge.source_node - - # 跳过起始节点 - if source_id == "start": - continue - - # 收集当前节点的参数(如果未被收集过) - if source_id not in collected: - node = flow.nodes.get(source_id) - if node: - collected.add(source_id) - params.append( - Params( - id=source_id, - name=node.name, - parameters=node.parameters.get("parameters", {}), - ), - ) - - # 继续回溯,传递当前路径的visited集合副本 - await backtrack(source_id, visited.copy()) - - await backtrack(step_id, set()) - return params diff --git a/apps/services/node.py b/apps/services/node.py index 6f0d492edacdd7611324a38953417e0fa769249b..bf48e71feec4df41dddc3fecfa1228913940571c 100644 --- a/apps/services/node.py +++ b/apps/services/node.py @@ -16,6 +16,7 @@ NODE_TYPE_MAP = { "API": APINode, } + class NodeManager: """Node管理器""" @@ -29,7 +30,6 @@ class NodeManager: raise ValueError(err) return node["call_id"] - @staticmethod async def get_node(node_id: str) -> NodePool: """获取Node的类型""" @@ -40,7 +40,6 @@ class NodeManager: raise ValueError(err) return NodePool.model_validate(node) - @staticmethod async def get_node_name(node_id: str) -> str: """获取node的名称""" @@ -52,7 +51,6 @@ class NodeManager: return "" return node_doc["name"] - @staticmethod def merge_params_schema(params_schema: dict[str, Any], known_params: dict[str, Any]) -> dict[str, Any]: """递归合并参数Schema,将known_params中的值填充到params_schema的对应位置""" @@ -75,7 +73,6 @@ class NodeManager: return params_schema - @staticmethod async def get_node_params(node_id: str) -> tuple[dict[str, Any], dict[str, Any]]: """获取Node数据""" @@ -100,7 +97,6 @@ class NodeManager: err = f"[NodeManager] Call {call_id} 不存在" logger.error(err) raise ValueError(err) - # 返回参数Schema return ( NodeManager.merge_params_schema(call_class.model_json_schema(), node_data.known_params or {}), diff --git a/apps/services/parameter.py b/apps/services/parameter.py new file mode 100644 index 0000000000000000000000000000000000000000..ae375e97fcf92bebfc8dca2b384ffcead4da58bf --- /dev/null +++ b/apps/services/parameter.py @@ -0,0 +1,86 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""flow Manager""" + +import logging + +from pymongo import ASCENDING + +from apps.services.node import NodeManager +from apps.schemas.flow_topology import FlowItem +from apps.scheduler.slot.slot import Slot +from apps.scheduler.call.choice.condition_handler import ConditionHandler +from apps.scheduler.call.choice.schema import ( + NumberOperate, + StringOperate, + ListOperate, + BoolOperate, + DictOperate, + Type +) +from apps.schemas.response_data import ( + OperateAndBindType, + ParamsNode, + StepParams, +) +from apps.services.node import NodeManager +logger = logging.getLogger(__name__) + + +class ParameterManager: + """Parameter Manager""" + @staticmethod + async def get_operate_and_bind_type(param_type: Type) -> list[OperateAndBindType]: + """Get operate and bind type""" + result = [] + operate = None + if param_type == Type.NUMBER: + operate = NumberOperate + elif param_type == Type.STRING: + operate = StringOperate + elif param_type == Type.LIST: + operate = ListOperate + elif param_type == Type.BOOL: + operate = BoolOperate + elif param_type == Type.DICT: + operate = DictOperate + if operate: + for item in operate: + result.append(OperateAndBindType( + operate=item, + bind_type=ConditionHandler.get_value_type_from_operate(item))) + return result + + @staticmethod + async def get_pre_params_by_flow_and_step_id(flow: FlowItem, step_id: str) -> list[StepParams]: + """Get pre params by flow and step id""" + index = 0 + q = [step_id] + in_edges = {} + step_id_to_node_id = {} + for step in flow.nodes: + step_id_to_node_id[step.step_id] = step.node_id + for edge in flow.edges: + if edge.target_node not in in_edges: + in_edges[edge.target_node] = [] + in_edges[edge.target_node].append(edge.source_node) + while index < len(q): + tmp_step_id = q[index] + index += 1 + for i in range(len(in_edges.get(tmp_step_id, []))): + pre_node_id = in_edges[tmp_step_id][i] + if pre_node_id not in q: + q.append(pre_node_id) + pre_step_params = [] + for step_id in q: + node_id = step_id_to_node_id.get(step_id) + params_schema, output_schema = await NodeManager.get_node_params(node_id) + slot = Slot(output_schema) + params_node = slot.get_params_node_from_schema(root='/output') + pre_step_params.append( + StepParams( + stepId=node_id, + name=params_schema.get("name", ""), + paramsNode=params_node + ) + ) + return pre_step_params