From 7ef62d60b9ddbad8d8561e12a3272b978b717591 Mon Sep 17 00:00:00 2001 From: zxstty Date: Fri, 25 Jul 2025 11:53:24 +0800 Subject: [PATCH 01/12] =?UTF-8?q?=E5=8E=BB=E9=99=A4mcp=E5=AE=89=E8=A3=85?= =?UTF-8?q?=E6=97=B6=E5=80=99=E7=9A=84=E6=97=A0=E7=94=A8=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/pool/loader/mcp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/apps/scheduler/pool/loader/mcp.py b/apps/scheduler/pool/loader/mcp.py index 66a516e77..1463d0a1a 100644 --- a/apps/scheduler/pool/loader/mcp.py +++ b/apps/scheduler/pool/loader/mcp.py @@ -153,7 +153,6 @@ class MCPLoader(metaclass=SingletonMeta): # 检查目录 template_path = MCP_PATH / "template" / mcp_id await Path.mkdir(template_path, parents=True, exist_ok=True) - ProcessHandler.clear_finished_tasks() # 安装MCP模板 if not ProcessHandler.add_task(mcp_id, MCPLoader._install_template_task, mcp_id, config): err = f"安装任务无法执行,请稍后重试: {mcp_id}" -- Gitee From b929ed4ead156f6d7008ec575830375b7beb0218 Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 23 Jul 2025 12:16:10 +0800 Subject: [PATCH 02/12] =?UTF-8?q?=E9=80=82=E9=85=8Drag=E7=9A=84=E4=BD=9C?= =?UTF-8?q?=E8=80=85=E5=92=8C=E5=88=9B=E5=BB=BA=E6=97=B6=E9=97=B4=E7=9A=84?= =?UTF-8?q?=E8=BF=94=E5=9B=9E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/scheduler/context.py | 3 +++ apps/scheduler/scheduler/message.py | 9 +-------- apps/schemas/record.py | 6 +++++- apps/services/document.py | 4 ++++ apps/services/rag.py | 3 +++ 5 files changed, 16 insertions(+), 9 deletions(-) diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index b7088d8da..4c2c4cf0b 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -114,11 +114,14 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: used_docs.append( RecordGroupDocument( _id=docs["id"], + author=docs.get("author", ""), + order=docs.get("order", 0), name=docs["name"], abstract=docs.get("abstract", ""), extension=docs.get("extension", ""), size=docs.get("size", 0), associated="answer", + created_at=docs.get("created_at", round(datetime.now(UTC).timestamp(), 3)), ) ) if docs.get("order") is not None: diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index c89fdd101..5997b48fb 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -71,14 +71,7 @@ async def push_rag_message( # 如果是文本消息,直接拼接到答案中 full_answer += content_obj.content elif content_obj.event_type == EventType.DOCUMENT_ADD.value: - task.runtime.documents.append({ - "id": content_obj.content.get("id", ""), - "order": content_obj.content.get("order", 0), - "name": content_obj.content.get("name", ""), - "abstract": content_obj.content.get("abstract", ""), - "extension": content_obj.content.get("extension", ""), - "size": content_obj.content.get("size", 0), - }) + task.runtime.documents.append(content_obj.content) # 保存答案 task.runtime.answer = full_answer await TaskManager.save_task(task.id, task) diff --git a/apps/schemas/record.py b/apps/schemas/record.py index d7acd3682..e1a995f8e 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -17,8 +17,9 @@ class RecordDocument(Document): """GET /api/record/{conversation_id} Result中的document数据结构""" id: str = Field(alias="_id", default="") + order: int = Field(default=0, description="文档顺序") abstract: str = Field(default="", description="文档摘要") - user_sub: None = None + author: str = Field(default="", description="文档作者") associated: Literal["question", "answer"] class Config: @@ -103,11 +104,14 @@ class RecordGroupDocument(BaseModel): """RecordGroup关联的文件""" id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") + order: int = Field(default=0, description="文档顺序") + author: str = Field(default="", description="文档作者") name: str = Field(default="", description="文档名称") abstract: str = Field(default="", description="文档摘要") extension: str = Field(default="", description="文档扩展名") size: int = Field(default=0, description="文档大小,单位是KB") associated: Literal["question", "answer"] + created_at: float = Field(default=0.0, description="文档创建时间") class Record(RecordData): diff --git a/apps/services/document.py b/apps/services/document.py index 203162da1..451423a9d 100644 --- a/apps/services/document.py +++ b/apps/services/document.py @@ -2,6 +2,7 @@ """文件Manager""" import base64 +from datetime import UTC, datetime import logging import uuid @@ -131,12 +132,15 @@ class DocumentManager: return [ RecordDocument( _id=doc.id, + order=doc.order, + author=doc.author, abstract=doc.abstract, name=doc.name, type=doc.extension, size=doc.size, conversation_id=record_group.get("conversation_id", ""), associated=doc.associated, + created_at=doc.created_at or round(datetime.now(tz=UTC).timestamp(), 3) ) for doc in docs if type is None or doc.associated == type ] diff --git a/apps/services/rag.py b/apps/services/rag.py index 6b6c843dc..efbdfe940 100644 --- a/apps/services/rag.py +++ b/apps/services/rag.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """对接Euler Copilot RAG""" +from datetime import UTC, datetime import json import logging from collections.abc import AsyncGenerator @@ -156,9 +157,11 @@ class RAG: "id": doc_chunk["docId"], "order": doc_cnt, "name": doc_chunk.get("docName", ""), + "author": doc_chunk.get("docAuthor", ""), "extension": doc_chunk.get("docExtension", ""), "abstract": doc_chunk.get("docAbstract", ""), "size": doc_chunk.get("docSize", 0), + "created_at": doc_chunk.get("docCreatedAt", round(datetime.now(UTC).timestamp(), 3)), }) doc_id_map[doc_chunk["docId"]] = doc_cnt doc_index = doc_id_map[doc_chunk["docId"]] -- Gitee From 755d5c315b6d9c0d3833383a790b188b0f53b28d Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 23 Jul 2025 15:40:30 +0800 Subject: [PATCH 03/12] =?UTF-8?q?=E5=AF=B9rag=E7=9A=84team=E5=92=8C?= =?UTF-8?q?=E5=9B=A2=E9=98=9F=E8=8E=B7=E5=8F=96=E5=A2=9E=E5=8A=A0=E5=BC=82?= =?UTF-8?q?=E5=B8=B8=E6=8D=95=E8=8E=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/services/conversation.py | 6 +++++- apps/services/knowledge.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/apps/services/conversation.py b/apps/services/conversation.py index 4bcade45c..bac964db1 100644 --- a/apps/services/conversation.py +++ b/apps/services/conversation.py @@ -59,7 +59,11 @@ class ConversationManager: model_name=llm.model_name, ) kb_item_list = [] - team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None) + try: + team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None) + except: + logger.error("[ConversationManager] 获取团队知识库列表失败") + team_kb_list = [] for team_kb in team_kb_list: for kb in team_kb["kbList"]: if str(kb["kbId"]) in kb_ids: diff --git a/apps/services/knowledge.py b/apps/services/knowledge.py index 9b4077f92..bd8dfc9e8 100644 --- a/apps/services/knowledge.py +++ b/apps/services/knowledge.py @@ -138,7 +138,11 @@ class KnowledgeBaseManager: return [] kb_ids_update_success = [] kb_item_dict_list = [] - team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None) + try: + team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None) + except Exception as e: + logger.error(f"[KnowledgeBaseManager] 获取团队知识库列表失败: {e}") + team_kb_list = [] for team_kb in team_kb_list: for kb in team_kb["kbList"]: if str(kb["kbId"]) in kb_ids: -- Gitee From d4ec596d65a0f0dfe4b4b442454da0a39b8b1325 Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 23 Jul 2025 15:53:07 +0800 Subject: [PATCH 04/12] =?UTF-8?q?=E5=AE=8C=E5=96=84RecordDocument=E7=BB=93?= =?UTF-8?q?=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/schemas/record.py | 1 + 1 file changed, 1 insertion(+) diff --git a/apps/schemas/record.py b/apps/schemas/record.py index e1a995f8e..6e9617a33 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -19,6 +19,7 @@ class RecordDocument(Document): id: str = Field(alias="_id", default="") order: int = Field(default=0, description="文档顺序") abstract: str = Field(default="", description="文档摘要") + user_sub: None | None author: str = Field(default="", description="文档作者") associated: Literal["question", "answer"] -- Gitee From 5fa678be223c4ed7eead9744074fb5402584247d Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 23 Jul 2025 15:56:50 +0800 Subject: [PATCH 05/12] =?UTF-8?q?=E5=AE=8C=E5=96=84RecordDocument=E7=BB=93?= =?UTF-8?q?=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/schemas/record.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/schemas/record.py b/apps/schemas/record.py index 6e9617a33..b5e1b0c55 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -19,7 +19,7 @@ class RecordDocument(Document): id: str = Field(alias="_id", default="") order: int = Field(default=0, description="文档顺序") abstract: str = Field(default="", description="文档摘要") - user_sub: None | None + user_sub: None = None author: str = Field(default="", description="文档作者") associated: Literal["question", "answer"] -- Gitee From b407dc0ce2f530220448ac87544f861f3a8d4513 Mon Sep 17 00:00:00 2001 From: zxstty Date: Wed, 23 Jul 2025 16:21:07 +0800 Subject: [PATCH 06/12] =?UTF-8?q?=E5=AE=8C=E5=96=84DocumentAddContent?= =?UTF-8?q?=E7=9A=84=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/scheduler/message.py | 2 ++ apps/schemas/message.py | 6 +++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index 5997b48fb..a2a45e414 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -108,10 +108,12 @@ async def _push_rag_chunk(task: Task, queue: MessageQueue, content: str) -> tupl data=DocumentAddContent( documentId=content_obj.content.get("id", ""), documentOrder=content_obj.content.get("order", 0), + documentAuthor=content_obj.content.get("author", ""), documentName=content_obj.content.get("name", ""), documentAbstract=content_obj.content.get("abstract", ""), documentType=content_obj.content.get("extension", ""), documentSize=content_obj.content.get("size", 0), + createdAt=round(datetime.now(tz=UTC).timestamp(), 3), ).model_dump(exclude_none=True, by_alias=True), ) except Exception: diff --git a/apps/schemas/message.py b/apps/schemas/message.py index d0661224e..5f0aee8ab 100644 --- a/apps/schemas/message.py +++ b/apps/schemas/message.py @@ -2,7 +2,7 @@ """队列中的消息结构""" from typing import Any - +from datetime import UTC, datetime from pydantic import BaseModel, Field from apps.schemas.enum_var import EventType, StepStatus @@ -60,10 +60,14 @@ class DocumentAddContent(BaseModel): document_id: str = Field(description="文档UUID", alias="documentId") document_order: int = Field(description="文档在对话中的顺序,从1开始", alias="documentOrder") + document_author: str = Field(description="文档作者", alias="documentAuthor", default="") document_name: str = Field(description="文档名称", alias="documentName") document_abstract: str = Field(description="文档摘要", alias="documentAbstract", default="") document_type: str = Field(description="文档MIME类型", alias="documentType", default="") document_size: float = Field(ge=0, description="文档大小,单位是KB,保留两位小数", alias="documentSize", default=0) + created_at: float = Field( + description="文档创建时间,单位是秒", alias="createdAt", default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3) + ) class FlowStartContent(BaseModel): -- Gitee From af1b364f950f18d5c6165438ea9a04fe6697e070 Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 24 Jul 2025 19:22:36 +0800 Subject: [PATCH 07/12] =?UTF-8?q?=E5=AE=8C=E5=96=84=E6=97=A0=E9=89=B4?= =?UTF-8?q?=E6=9D=83=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/dependency/user.py | 8 +++++++ apps/routers/api_key.py | 1 + apps/routers/auth.py | 3 +-- apps/scheduler/executor/step.py | 38 ++++++++++++++++----------------- apps/schemas/config.py | 9 +++++++- apps/schemas/message.py | 2 ++ apps/schemas/request_data.py | 4 ++-- 7 files changed, 41 insertions(+), 24 deletions(-) diff --git a/apps/dependency/user.py b/apps/dependency/user.py index fce67e51e..87cbd2902 100644 --- a/apps/dependency/user.py +++ b/apps/dependency/user.py @@ -5,10 +5,12 @@ import logging from fastapi import Depends from fastapi.security import OAuth2PasswordBearer +import secrets from starlette import status from starlette.exceptions import HTTPException from starlette.requests import HTTPConnection +from apps.common.config import Config from apps.services.api_key import ApiKeyManager from apps.services.session import SessionManager @@ -48,6 +50,9 @@ async def get_session(request: HTTPConnection) -> str: :param request: HTTP请求 :return: Session ID """ + if Config().get_config().no_auth.enable: + # 如果启用了无认证访问,直接返回调试用户 + return secrets.token_hex(16) session_id = await _get_session_id_from_request(request) if not session_id: raise HTTPException( @@ -69,6 +74,9 @@ async def get_user(request: HTTPConnection) -> str: :param request: HTTP请求体 :return: 用户sub """ + if Config().get_config().no_auth.enable: + # 如果启用了无认证访问,直接返回调试用户 + return Config().get_config().no_auth.user_sub session_id = await _get_session_id_from_request(request) if not session_id: raise HTTPException( diff --git a/apps/routers/api_key.py b/apps/routers/api_key.py index 158cfc13a..51366a219 100644 --- a/apps/routers/api_key.py +++ b/apps/routers/api_key.py @@ -6,6 +6,7 @@ from typing import Annotated from fastapi import APIRouter, Depends, status from fastapi.responses import JSONResponse + from apps.dependency.user import get_user, verify_user from apps.schemas.api_key import GetAuthKeyRsp, PostAuthKeyMsg, PostAuthKeyRsp from apps.schemas.response_data import ResponseData diff --git a/apps/routers/auth.py b/apps/routers/auth.py index 1cba5ed62..4a3f82939 100644 --- a/apps/routers/auth.py +++ b/apps/routers/auth.py @@ -9,6 +9,7 @@ from fastapi import APIRouter, Depends, Request, status from fastapi.responses import HTMLResponse, JSONResponse from fastapi.templating import Jinja2Templates +from apps.common.config import Config from apps.common.oidc import oidc_provider from apps.dependency import get_session, get_user, verify_user from apps.schemas.collection import Audit @@ -47,8 +48,6 @@ async def oidc_login(request: Request, code: str) -> HTMLResponse: user_info = await oidc_provider.get_oidc_user(token["access_token"]) user_sub: str | None = user_info.get("user_sub", None) - if user_sub: - await oidc_provider.set_token(user_sub, token["access_token"], token["refresh_token"]) except Exception as e: logger.exception("User login failed") status_code = status.HTTP_400_BAD_REQUEST if "auth error" in str(e) else status.HTTP_403_FORBIDDEN diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 6b3451fa9..506f3bb10 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -86,8 +86,8 @@ class StepExecutor(BaseExecutor): logger.info("[StepExecutor] 初始化步骤 %s", self.step.step.name) # State写入ID和运行状态 - self.task.state.step_id = self.step.step_id # type: ignore[arg-type] - self.task.state.step_name = self.step.step.name # type: ignore[arg-type] + self.task.state.step_id = self.step.step_id # type: ignore[arg-type] + self.task.state.step_name = self.step.step.name # type: ignore[arg-type] # 获取并验证Call类 node_id = self.step.step.node @@ -127,13 +127,13 @@ class StepExecutor(BaseExecutor): return # 暂存旧数据 - current_step_id = self.task.state.step_id # type: ignore[arg-type] - current_step_name = self.task.state.step_name # type: ignore[arg-type] + current_step_id = self.task.state.step_id # type: ignore[arg-type] + current_step_name = self.task.state.step_name # type: ignore[arg-type] # 更新State - self.task.state.step_id = str(uuid.uuid4()) # type: ignore[arg-type] - self.task.state.step_name = "自动参数填充" # type: ignore[arg-type] - self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] + self.task.state.step_id = str(uuid.uuid4()) # type: ignore[arg-type] + self.task.state.step_name = "自动参数填充" # type: ignore[arg-type] + self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) # 初始化填参 @@ -156,17 +156,17 @@ class StepExecutor(BaseExecutor): # 如果没有填全,则状态设置为待填参 if result.remaining_schema: - self.task.state.status = StepStatus.PARAM # type: ignore[arg-type] + self.task.state.status = StepStatus.PARAM # type: ignore[arg-type] else: - self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] + self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] await self.push_message(EventType.STEP_OUTPUT.value, result.model_dump(by_alias=True, exclude_none=True)) # 更新输入 self.obj.input.update(result.slot_data) # 恢复State - self.task.state.step_id = current_step_id # type: ignore[arg-type] - self.task.state.step_name = current_step_name # type: ignore[arg-type] + self.task.state.step_id = current_step_id # type: ignore[arg-type] + self.task.state.step_name = current_step_name # type: ignore[arg-type] self.task.tokens.input_tokens += self.obj.tokens.input_tokens self.task.tokens.output_tokens += self.obj.tokens.output_tokens @@ -212,7 +212,7 @@ class StepExecutor(BaseExecutor): await self._run_slot_filling() # 更新状态 - self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] + self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) # 推送输入 await self.push_message(EventType.STEP_INPUT.value, self.obj.input) @@ -224,22 +224,22 @@ class StepExecutor(BaseExecutor): content = await self._process_chunk(iterator, to_user=self.obj.to_user) except Exception as e: logger.exception("[StepExecutor] 运行步骤失败,进行异常处理步骤") - self.task.state.status = StepStatus.ERROR # type: ignore[arg-type] + self.task.state.status = StepStatus.ERROR # type: ignore[arg-type] await self.push_message(EventType.STEP_OUTPUT.value, {}) if isinstance(e, CallError): - self.task.state.error_info = { # type: ignore[arg-type] + self.task.state.error_info = { # type: ignore[arg-type] "err_msg": e.message, "data": e.data, } else: - self.task.state.error_info = { # type: ignore[arg-type] + self.task.state.error_info = { # type: ignore[arg-type] "err_msg": str(e), "data": {}, } return # 更新执行状态 - self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] + self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] self.task.tokens.input_tokens += self.obj.tokens.input_tokens self.task.tokens.output_tokens += self.obj.tokens.output_tokens self.task.tokens.full_time += round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.time @@ -253,12 +253,12 @@ class StepExecutor(BaseExecutor): # 更新context history = FlowStepHistory( task_id=self.task.id, - flow_id=self.task.state.flow_id, # type: ignore[arg-type] - flow_name=self.task.state.flow_name, # type: ignore[arg-type] + flow_id=self.task.state.flow_id, # type: ignore[arg-type] + flow_name=self.task.state.flow_name, # type: ignore[arg-type] step_id=self.step.step_id, step_name=self.step.step.name, step_description=self.step.step.description, - status=self.task.state.status, # type: ignore[arg-type] + status=self.task.state.status, # type: ignore[arg-type] input_data=self.obj.input, output_data=output_data, ) diff --git a/apps/schemas/config.py b/apps/schemas/config.py index b88a81f1a..99bcccde8 100644 --- a/apps/schemas/config.py +++ b/apps/schemas/config.py @@ -6,6 +6,13 @@ from typing import Literal from pydantic import BaseModel, Field +class NoauthConfig(BaseModel): + """无认证配置""" + + enable: bool = Field(description="是否启用无认证访问", default=False) + user_sub: str = Field(description="调试用户的sub", default="admin") + + class DeployConfig(BaseModel): """部署配置""" @@ -122,7 +129,7 @@ class ExtraConfig(BaseModel): class ConfigModel(BaseModel): """配置文件的校验Class""" - + no_auth: NoauthConfig deploy: DeployConfig login: LoginConfig embedding: EmbeddingConfig diff --git a/apps/schemas/message.py b/apps/schemas/message.py index 5f0aee8ab..cf70a82b8 100644 --- a/apps/schemas/message.py +++ b/apps/schemas/message.py @@ -24,6 +24,8 @@ class MessageFlow(BaseModel): flow_id: str = Field(description="Flow ID", alias="flowId") step_id: str = Field(description="当前步骤ID", alias="stepId") step_name: str = Field(description="当前步骤名称", alias="stepName") + sub_step_id: str | None = Field(description="当前子步骤ID", alias="subStepId", default=None) + sub_step_name: str | None = Field(description="当前子步骤名称", alias="subStepName", default=None) step_status: StepStatus = Field(description="当前步骤状态", alias="stepStatus") diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index 2305dd93d..5d6dc7fcc 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -16,8 +16,8 @@ class RequestDataApp(BaseModel): """模型对话中包含的app信息""" app_id: str = Field(description="应用ID", alias="appId") - flow_id: str = Field(description="Flow ID", alias="flowId") - params: dict[str, Any] = Field(description="插件参数") + flow_id: str | None = Field(default=None, description="Flow ID", alias="flowId") + params: dict[str, Any] | None = Field(default=None, description="插件参数") class MockRequestData(BaseModel): -- Gitee From 9653ef8e6c3d6ba15aba377b9721a1f605f7cff7 Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 24 Jul 2025 21:13:04 +0800 Subject: [PATCH 08/12] =?UTF-8?q?=E5=AE=8C=E5=96=84Agent=E7=9A=84=E5=BC=80?= =?UTF-8?q?=E5=8F=91&=E4=BF=AE=E5=A4=8Dmcp=E6=B3=A8=E5=86=8C=E6=97=B6?= =?UTF-8?q?=E5=80=99=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/routers/chat.py | 11 ++++++++--- apps/scheduler/call/mcp/mcp.py | 7 ------- apps/scheduler/executor/agent.py | 28 +++++++++------------------ apps/scheduler/scheduler/scheduler.py | 2 +- apps/schemas/enum_var.py | 3 +++ apps/schemas/mcp.py | 3 ++- apps/schemas/request_data.py | 1 + apps/schemas/task.py | 3 +++ apps/services/task.py | 9 --------- 9 files changed, 27 insertions(+), 40 deletions(-) diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 7fe5162c0..589000be3 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -36,11 +36,16 @@ async def init_task(post_body: RequestData, user_sub: str, session_id: str) -> T # 生成group_id if not post_body.group_id: post_body.group_id = str(uuid.uuid4()) - # 创建或还原Task + if post_body.new_task: + # 创建或还原Task + task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub) + if task: + await TaskManager.delete_task_by_task_id(task.id) task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub) # 更改信息并刷新数据库 - task.runtime.question = post_body.question - task.ids.group_id = post_body.group_id + if post_body.new_task: + task.runtime.question = post_body.question + task.ids.group_id = post_body.group_id return task diff --git a/apps/scheduler/call/mcp/mcp.py b/apps/scheduler/call/mcp/mcp.py index 661e9ada7..4e6a1bb73 100644 --- a/apps/scheduler/call/mcp/mcp.py +++ b/apps/scheduler/call/mcp/mcp.py @@ -35,7 +35,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): text_output: bool = Field(description="是否将结果以文本形式返回", default=True) to_user: bool = Field(description="是否将结果返回给用户", default=True) - @classmethod def info(cls) -> CallInfo: """ @@ -46,7 +45,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): """ return CallInfo(name="MCP", description="调用MCP Server,执行工具") - async def _init(self, call_vars: CallVars) -> MCPInput: """初始化MCP""" # 获取MCP交互类 @@ -63,7 +61,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): return MCPInput(avaliable_tools=avaliable_tools, max_steps=self.max_steps) - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """执行MCP""" # 生成计划 @@ -80,7 +77,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): async for chunk in self._generate_answer(): yield chunk - async def _generate_plan(self) -> AsyncGenerator[CallOutputChunk, None]: """生成执行计划""" # 开始提示 @@ -103,7 +99,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): data=self._plan.model_dump(), ) - async def _execute_plan_item(self, plan_item: MCPPlanItem) -> AsyncGenerator[CallOutputChunk, None]: """执行单个计划项""" # 判断是否为Final @@ -141,7 +136,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): }, ) - async def _generate_answer(self) -> AsyncGenerator[CallOutputChunk, None]: """生成总结""" # 提示开始总结 @@ -163,7 +157,6 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): ).model_dump(), ) - def _create_output( self, text: str, diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index f6814dd36..2ff4f3d39 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -7,6 +7,8 @@ from pydantic import Field from apps.scheduler.executor.base import BaseExecutor from apps.scheduler.mcp_agent.agent.mcp import MCPAgent +from apps.schemas.task import ExecutorState, StepQueueItem +from apps.services.task import TaskManager logger = logging.getLogger(__name__) @@ -15,26 +17,14 @@ class MCPAgentExecutor(BaseExecutor): """MCP Agent执行器""" question: str = Field(description="用户输入") - max_steps: int = Field(default=10, description="最大步数") + max_steps: int = Field(default=20, description="最大步数") servers_id: list[str] = Field(description="MCP server id") agent_id: str = Field(default="", description="Agent ID") agent_description: str = Field(default="", description="Agent描述") - async def run(self) -> None: - """运行MCP Agent""" - agent = await MCPAgent.create( - servers_id=self.servers_id, - max_steps=self.max_steps, - task=self.task, - msg_queue=self.msg_queue, - question=self.question, - agent_id=self.agent_id, - description=self.agent_description, - ) - - try: - answer = await agent.run(self.question) - self.task = agent.task - self.task.runtime.answer = answer - except Exception as e: - logger.error(f"Error: {str(e)}") + async def load_state(self) -> None: + """从数据库中加载FlowExecutor的状态""" + logger.info("[FlowExecutor] 加载Executor状态") + # 尝试恢复State + if self.task.state: + self.task.context = await TaskManager.get_context_by_task_id(self.task.id) diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index ed73638ce..417f93d28 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -206,7 +206,7 @@ class Scheduler: task=self.task, msg_queue=queue, question=post_body.question, - max_steps=app_metadata.history_len, + history_len=app_metadata.history_len, servers_id=servers_id, background=background, agent_id=app_info.app_id, diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py index 9a20ba84d..a84dc3a35 100644 --- a/apps/schemas/enum_var.py +++ b/apps/schemas/enum_var.py @@ -15,6 +15,7 @@ class SlotType(str, Enum): class StepStatus(str, Enum): """步骤状态""" + WAITING = "waiting" RUNNING = "running" SUCCESS = "success" ERROR = "error" @@ -38,6 +39,8 @@ class EventType(str, Enum): TEXT_ADD = "text.add" GRAPH = "graph" DOCUMENT_ADD = "document.add" + STEP_WAITING_FOR_START = "step.waiting_for_start" + STEP_WAITING_FOR_PARAM = "step.waiting_for_param" FLOW_START = "flow.start" STEP_INPUT = "step.input" STEP_OUTPUT = "step.output" diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py index 44021b0ed..60c8f17b4 100644 --- a/apps/schemas/mcp.py +++ b/apps/schemas/mcp.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """MCP 相关数据结构""" +import uuid from enum import Enum from typing import Any @@ -117,7 +118,7 @@ class MCPToolSelectResult(BaseModel): class MCPPlanItem(BaseModel): """MCP 计划""" - + id: str = Field(default_factory=lambda: str(uuid.uuid4())) content: str = Field(description="计划内容") tool: str = Field(description="工具名称") instruction: str = Field(description="工具指令") diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index 5d6dc7fcc..a3a8848c3 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -46,6 +46,7 @@ class RequestData(BaseModel): files: list[str] = Field(default=[], description="文件列表") app: RequestDataApp | None = Field(default=None, description="应用") debug: bool = Field(default=False, description="是否调试") + new_task: bool = Field(default=True, description="是否新建任务") class QuestionBlacklistRequest(BaseModel): diff --git a/apps/schemas/task.py b/apps/schemas/task.py index 8efcb5991..37fdebbfb 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -9,6 +9,7 @@ from pydantic import BaseModel, Field from apps.schemas.enum_var import StepStatus from apps.schemas.flow import Step +from apps.schemas.mcp import MCPPlan class FlowStepHistory(BaseModel): @@ -42,6 +43,7 @@ class ExecutorState(BaseModel): # 附加信息 step_id: str = Field(description="当前步骤ID") step_name: str = Field(description="当前步骤名称") + step_description: str = Field(description="当前步骤描述", default="") app_id: str = Field(description="应用ID") slot: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={}) error_info: dict[str, Any] = Field(description="错误信息", default={}) @@ -75,6 +77,7 @@ class TaskRuntime(BaseModel): summary: str = Field(description="摘要", default="") filled: dict[str, Any] = Field(description="填充的槽位", default={}) documents: list[dict[str, Any]] = Field(description="文档列表", default=[]) + temporary_plans: MCPPlan | None = Field(description="临时计划列表", default=None) class Task(BaseModel): diff --git a/apps/services/task.py b/apps/services/task.py index 1e672be69..2456d96bf 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -45,7 +45,6 @@ class TaskManager: return Task.model_validate(task) - @staticmethod async def get_task_by_group_id(group_id: str, conversation_id: str) -> Task | None: """获取组ID的最后一条问答组关联的任务""" @@ -58,7 +57,6 @@ class TaskManager: task = await task_collection.find_one({"_id": record_group_obj.task_id}) return Task.model_validate(task) - @staticmethod async def get_task_by_task_id(task_id: str) -> Task | None: """根据task_id获取任务""" @@ -68,7 +66,6 @@ class TaskManager: return None return Task.model_validate(task) - @staticmethod async def get_context_by_record_id(record_group_id: str, record_id: str) -> list[dict[str, Any]]: """根据record_group_id获取flow信息""" @@ -95,7 +92,6 @@ class TaskManager: else: return flow_context_list - @staticmethod async def get_context_by_task_id(task_id: str, length: int = 0) -> list[dict[str, Any]]: """根据task_id获取flow信息""" @@ -115,7 +111,6 @@ class TaskManager: else: return flow_context - @staticmethod async def save_flow_context(task_id: str, flow_context: list[dict[str, Any]]) -> None: """保存flow信息到flow_context""" @@ -137,7 +132,6 @@ class TaskManager: except Exception: logger.exception("[TaskManager] 保存flow执行记录失败") - @staticmethod async def delete_task_by_task_id(task_id: str) -> None: """通过task_id删除Task信息""" @@ -148,7 +142,6 @@ class TaskManager: if task: await task_collection.delete_one({"_id": task_id}) - @staticmethod async def delete_tasks_by_conversation_id(conversation_id: str) -> None: """通过ConversationID删除Task信息""" @@ -167,7 +160,6 @@ class TaskManager: await task_collection.delete_many({"conversation_id": conversation_id}, session=session) await flow_context_collection.delete_many({"task_id": {"$in": task_ids}}, session=session) - @classmethod async def get_task( cls, @@ -212,7 +204,6 @@ class TaskManager: runtime=TaskRuntime(), ) - @classmethod async def save_task(cls, task_id: str, task: Task) -> None: """保存任务块""" -- Gitee From 23fd74f6c25e2dd98bb487b485e42663da1a2369 Mon Sep 17 00:00:00 2001 From: zengxianghuai Date: Fri, 25 Jul 2025 17:34:51 +0800 Subject: [PATCH 09/12] =?UTF-8?q?=E5=A2=9E=E5=8A=A0choice=E5=B7=A5?= =?UTF-8?q?=E5=85=B7=E4=BB=A5=E5=8F=8A=E7=9B=B8=E5=85=B3=E8=8E=B7=E5=8F=96?= =?UTF-8?q?step=5Fid=E5=92=8C=E6=93=8D=E4=BD=9C=E7=AC=A6=E7=9A=84=E8=B7=AF?= =?UTF-8?q?=E7=94=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/routers/parameter.py | 143 ++++++++++++++ apps/scheduler/call/choice/choice.py | 79 +++++++- .../call/choice/condition_handler.py | 184 ++++++++++++++++++ apps/scheduler/call/choice/schema.py | 66 +++++++ apps/scheduler/executor/flow.py | 9 + apps/schemas/response_data.py | 18 ++ apps/services/flow.py | 26 ++- 7 files changed, 516 insertions(+), 9 deletions(-) create mode 100644 apps/routers/parameter.py create mode 100644 apps/scheduler/call/choice/condition_handler.py diff --git a/apps/routers/parameter.py b/apps/routers/parameter.py new file mode 100644 index 000000000..33789dfdc --- /dev/null +++ b/apps/routers/parameter.py @@ -0,0 +1,143 @@ +from typing import Annotated + +from fastapi import APIRouter, Depends, Query, status +from fastapi.responses import JSONResponse + +from apps.dependency import get_user +from apps.dependency.user import verify_user +from apps.scheduler.call.choice.choice import Choice +from apps.schemas.response_data import ( + FlowStructureGetMsg, + FlowStructureGetRsp, + GetParamsMsg, + GetParamsRsp, + ResponseData, +) +from apps.services.application import AppManager +from apps.services.flow import FlowManager + +router = APIRouter( + prefix="/api/parameter", + tags=["parameter"], + dependencies=[ + Depends(verify_user), + ], +) + + +@router.get("", response_model={ + status.HTTP_403_FORBIDDEN: {"model": ResponseData}, + status.HTTP_404_NOT_FOUND: {"model": ResponseData}, + },) +async def get_parameters( + user_sub: Annotated[str, Depends(get_user)], + app_id: Annotated[str, Query(alias="appId")], + flow_id: Annotated[str, Query(alias="flowId")], + step_id: Annotated[str, Query(alias="stepId")], +) -> JSONResponse: + """Get parameters for node choice.""" + if not await AppManager.validate_user_app_access(user_sub, app_id): + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content=FlowStructureGetRsp( + code=status.HTTP_403_FORBIDDEN, + message="用户没有权限访问该流", + result=FlowStructureGetMsg(), + ).model_dump(exclude_none=True, by_alias=True), + ) + flow = await FlowManager.get_flow_by_app_and_flow_id(app_id, flow_id) + if not flow: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content=FlowStructureGetRsp( + code=status.HTTP_404_NOT_FOUND, + message="未找到该流", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ) + result = await FlowManager.get_step_by_flow_and_step_id(flow, step_id) + if not result: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content=FlowStructureGetRsp( + code=status.HTTP_404_NOT_FOUND, + message="未找到该节点", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=GetParamsRsp( + code=status.HTTP_200_OK, + message="获取参数成功", + result=GetParamsMsg(result=result), + ).model_dump(exclude_none=True, by_alias=True), + ) + +async def operate_parameters(operate: str) -> list[str] | None: + """ + 根据操作类型获取对应的操作符参数列表 + + Args: + operate: 操作类型,支持 'int', 'str', 'bool' + + Returns: + 对应的操作符参数列表,若类型不支持则返回None + + """ + string = [ + "equal", + "not_equal", + "great",#长度大于 + "great_equals",#长度大于等于 + "less",#长度小于 + "less_equals",#长度小于等于 + "greater", + "greater_equals", + "smaller", + "smaller_equals", + ] + integer = [ + "equal", + "not_equal", + "great", + "great_equals", + "less", + "less_equals", + ] + boolen = ["equal", "not_equal", "is_empty", "not_empty"] + if operate in string: + return string + if operate in integer: + return integer + if operate in boolen: + return boolen + return None + +@router.get("/operate", response_model={ + status.HTTP_200_OK: {"model": ResponseData}, + status.HTTP_404_NOT_FOUND: {"model": ResponseData}, + },) +async def get_operate_parameters( + user_sub: Annotated[str, Depends(get_user)], + operate: Annotated[str, Query(alias="operate")] +) -> JSONResponse: + """Get parameters for node choice.""" + result = await operate_parameters(operate) + if not result: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content=ResponseData( + code=status.HTTP_404_NOT_FOUND, + message="未找到该符号", + result=[], + ).model_dump(exclude_none=True, by_alias=True), + ) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=ResponseData( + code=status.HTTP_200_OK, + message="获取参数成功", + result=result, + ).model_dump(exclude_none=True, by_alias=True), + ) \ No newline at end of file diff --git a/apps/scheduler/call/choice/choice.py b/apps/scheduler/call/choice/choice.py index a5edf21af..e9cd7d41a 100644 --- a/apps/scheduler/call/choice/choice.py +++ b/apps/scheduler/call/choice/choice.py @@ -1,19 +1,82 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """使用大模型或使用程序做出判断""" -from enum import Enum - -from apps.scheduler.call.choice.schema import ChoiceInput, ChoiceOutput -from apps.scheduler.call.core import CoreCall +import logging +from collections.abc import AsyncGenerator +from typing import Any +from pydantic import Field -class Operator(str, Enum): - """Choice工具支持的运算符""" +from apps.scheduler.call.choice.condition_handler import ConditionHandler +from apps.scheduler.call.choice.schema import ( + ChoiceBranch, + ChoiceInput, + ChoiceOutput, + Logic, +) +from apps.scheduler.call.core import CoreCall +from apps.schemas.enum_var import CallOutputType +from apps.schemas.scheduler import ( + CallError, + CallInfo, + CallOutputChunk, + CallVars, +) - pass +logger = logging.getLogger(__name__) class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): """Choice工具""" - pass + to_user: bool = Field(default=False) + choices: list[ChoiceBranch] = Field(description="分支", default=[]) + + @classmethod + def info(cls) -> CallInfo: + """返回Call的名称和描述""" + return CallInfo(name="Choice", description="使用大模型或使用程序做出判断") + + async def _prepare_message(self, call_vars: CallVars) -> list[dict[str, Any]]: + """替换choices中的系统变量""" + valid_choices = [] + for choice in self.choices: + if choice.logic not in [Logic.AND, Logic.OR]: + logger.warning("分支 %s 的逻辑运算符 %s 无效,已跳过",choice.branch_id, choice.logic) + continue + valid_conditions = [] + for condition in choice.conditions: + if condition.left.step_id: + condition.left.value = self._extract_history_variables( + condition.left.step_id, call_vars.history, + ) + valid_conditions.append(condition) + else: + logger.warning("条件 %s 的左侧变量 %s 无效,已跳过", condition.condition_id, condition.left) + continue + choice.conditions = valid_conditions + valid_choices.append(choice.dict()) + return valid_choices + + async def _init(self, call_vars: CallVars) -> ChoiceInput: + """初始化Choice工具""" + return ChoiceInput( + choices=await self._prepare_message(call_vars), + ) + + async def _exec( + self, input_data: dict[str, Any] + ) -> AsyncGenerator[CallOutputChunk, None]: + """执行Choice工具""" + # 解析输入数据 + data = ChoiceInput(**input_data) + ret: CallOutputChunk = CallOutputChunk( + type=CallOutputType.DATA, + content=None, + ) + condition_handler = ConditionHandler() + try: + ret.content = condition_handler.handler(data.choices) + yield ret + except Exception as e: + raise CallError(message=f"选择工具调用失败:{e!s}", data={}) from e diff --git a/apps/scheduler/call/choice/condition_handler.py b/apps/scheduler/call/choice/condition_handler.py new file mode 100644 index 000000000..feab5c647 --- /dev/null +++ b/apps/scheduler/call/choice/condition_handler.py @@ -0,0 +1,184 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""处理条件分支的工具""" + +import logging + +from pydantic import BaseModel + +from apps.scheduler.call.choice.schema import ChoiceBranch, ChoiceOutput, Condition, Logic, Operator, Type, Value + +logger = logging.getLogger(__name__) + + +class ConditionHandler(BaseModel): + """条件分支处理器""" + + def handler(self, choices: list[ChoiceBranch]) -> ChoiceOutput: + """处理条件""" + default_branch = [c for c in choices if c.is_default] + + for block_judgement in choices: + results = [] + if block_judgement.is_default: + continue + for condition in block_judgement.conditions: + result = self._judge_condition(condition) + results.append(result) + if block_judgement.logic == Logic.AND: + final_result = all(results) + elif block_judgement.logic == Logic.OR: + final_result = any(results) + + if final_result: + return { + "branch_id": block_judgement.branch_id, + "message": f"选择分支:{block_judgement.branch_id}", + } + + # 如果没有匹配的分支,选择默认分支 + if default_branch: + return { + "branch_id": default_branch[0].branch_id, + "message": f"选择默认分支:{default_branch[0].branch_id}", + } + return { + "branch_id": "", + "message": "没有匹配的分支,且没有默认分支", + } + + def _judge_condition(self, condition: Condition) -> bool: + """ + 判断条件是否成立。 + + Args: + condition (Condition): 'left', 'operator', 'right', 'type' + + Returns: + bool + + """ + left = condition.left + operator = condition.operator + right = condition.right + value_type = condition.type + + result = None + if value_type == Type.STRING: + result = self._judge_string_condition(left, operator, right) + elif value_type == Type.INT: + result = self._judge_int_condition(left, operator, right) + elif value_type == Type.BOOL: + result = self._judge_bool_condition(left, operator, right) + else: + logger.error("不支持的数据类型: %s", value_type) + msg = f"不支持的数据类型: {value_type}" + raise ValueError(msg) + return result + + def _judge_string_condition(self, left: Value, operator: Operator, right: Value) -> bool: + """ + 判断字符串类型的条件。 + + Args: + left (Value): 左值,包含 'value' 键。 + operator (Operator): 操作符 + right (Value): 右值,包含 'value' 键。 + + Returns: + bool + + """ + left_value = left.value + if not isinstance(left_value, str): + logger.error("左值不是字符串类型: %s", left_value) + msg = "左值必须是字符串类型" + raise TypeError(msg) + right_value = right.value + result = False + if operator == Operator.EQUAL: + result = left_value == right_value + elif operator == Operator.NEQUAL: + result = left_value != right_value + elif operator == Operator.GREAT: + result = len(left_value) > len(right_value) + elif operator == Operator.GREAT_EQUALS: + result = len(left_value) >= len(right_value) + elif operator == Operator.LESS: + result = len(left_value) < len(right_value) + elif operator == Operator.LESS_EQUALS: + result = len(left_value) <= len(right_value) + elif operator == Operator.GREATER: + result = left_value > right_value + elif operator == Operator.GREATER_EQUALS: + result = left_value >= right_value + elif operator == Operator.SMALLER: + result = left_value < right_value + elif operator == Operator.SMALLER_EQUALS: + result = left_value <= right_value + elif operator == Operator.CONTAINS: + result = right_value in left_value + elif operator == Operator.NOT_CONTAINS: + result = right_value not in left_value + return result + + def _judge_int_condition(self, left: Value, operator: Operator, right: Value) -> bool: # noqa: PLR0911 + """ + 判断整数类型的条件。 + + Args: + left (Value): 左值,包含 'value' 键。 + operator (Operator): 操作符 + right (Value): 右值,包含 'value' 键。 + + Returns: + bool + + """ + left_value = left.value + if not isinstance(left_value, int): + logger.error("左值不是整数类型: %s", left_value) + msg = "左值必须是整数类型" + raise TypeError(msg) + right_value = right.value + if operator == Operator.EQUAL: + return left_value == right_value + if operator == Operator.NEQUAL: + return left_value != right_value + if operator == Operator.GREAT: + return left_value > right_value + if operator == Operator.GREAT_EQUALS: + return left_value >= right_value + if operator == Operator.LESS: + return left_value < right_value + if operator == Operator.LESS_EQUALS: + return left_value <= right_value + return False + + def _judge_bool_condition(self, left: Value, operator: Operator, right: Value) -> bool: + """ + 判断布尔类型的条件。 + + Args: + left (Value): 左值,包含 'value' 键。 + operator (Operator): 操作符 + right (Value): 右值,包含 'value' 键。 + + Returns: + bool + + """ + left_value = left.value + if not isinstance(left_value, bool): + logger.error("左值不是布尔类型: %s", left_value) + msg = "左值必须是布尔类型" + raise TypeError(msg) + right_value = right.value + if operator == Operator.EQUAL: + return left_value == right_value + if operator == Operator.NEQUAL: + return left_value != right_value + if operator == Operator.IS_EMPTY: + return left_value == "" + if operator == Operator.NOT_EMPTY: + return left_value != "" + return False diff --git a/apps/scheduler/call/choice/schema.py b/apps/scheduler/call/choice/schema.py index 60b62d09f..ed1c628cf 100644 --- a/apps/scheduler/call/choice/schema.py +++ b/apps/scheduler/call/choice/schema.py @@ -1,12 +1,78 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """Choice Call的输入和输出""" +from enum import Enum + +from pydantic import BaseModel, Field + from apps.scheduler.call.core import DataBase +class Operator(str, Enum): + """Choice Call支持的运算符""" + + EQUAL = "equal" + NEQUAL = "not_equal" + GREAT = "great" + GREAT_EQUALS = "great_equals" + LESS = "less" + LESS_EQUALS = "less_equals" + # string + CONTAINS = "contains" + NOT_CONTAINS = "not_contains" + GREATER = "greater" + GREATER_EQUALS = "greater_equals" + SMALLER = "smaller" + SMALLER_EQUALS = "smaller_equals" + # bool + IS_EMPTY = "is_empty" + NOT_EMPTY = "not_empty" + + +class Logic(str, Enum): + """Choice 工具支持的逻辑运算符""" + + AND = "and" + OR = "or" + + +class Type(str, Enum): + """Choice 工具支持的类型""" + + STRING = "string" + INT = "int" + BOOL = "bool" + + +class Value(BaseModel): + """值的结构""" + + step_id: str = Field(description="步骤id", default="") + value: str | int | bool = Field(description="值", default=None) + + +class Condition(BaseModel): + """单个条件""" + + type: Type = Field(description="值的类型", default=Type.STRING) + left: Value = Field(description="左值") + right: Value = Field(description="右值") + operator: Operator = Field(description="运算符", default="equal") + id: int = Field(description="条件ID") + + +class ChoiceBranch(BaseModel): + """子分支""" + + branch_id: str = Field(description="分支ID", default="") + logic: Logic = Field(description="逻辑运算符", default=Logic.AND) + conditions: list[Condition] = Field(description="条件列表", default=[]) + is_default: bool = Field(description="是否为默认分支", default=False) class ChoiceInput(DataBase): """Choice Call的输入""" + choices: list[ChoiceBranch] = Field(description="分支", default=[]) + class ChoiceOutput(DataBase): """Choice Call的输出""" diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index a70d0d707..d8d22c46f 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -117,6 +117,15 @@ class FlowExecutor(BaseExecutor): # 如果当前步骤为结束,则直接返回 if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type] return [] + if self.task.state.step_name == "Choice": + # 如果是choice节点,获取分支ID + branch_id = self.task.context[-1]["output_data"].get("branch_id", None) + if branch_id: + self.task.state.step_id = self.task.state.step_id + "." + branch_id + logger.info("[FlowExecutor] 分支ID:%s", branch_id) + else: + logger.warning("[FlowExecutor] 没有找到分支ID,返回空列表") + return [] next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type] # 如果step没有任何出边,直接跳到end diff --git a/apps/schemas/response_data.py b/apps/schemas/response_data.py index b2a187291..20d7ad9b0 100644 --- a/apps/schemas/response_data.py +++ b/apps/schemas/response_data.py @@ -628,3 +628,21 @@ class ListLLMRsp(ResponseData): """GET /api/llm 返回数据结构""" result: list[LLMProviderInfo] = Field(default=[], title="Result") + +class Params(BaseModel): + """参数数据结构""" + + id: str = Field(..., description="StepID") + name: str = Field(..., description="Step名称") + parameters: dict[str, Any] = Field(..., description="参数") + operate: str = Field(..., description="比较符") + +class GetParamsMsg(BaseModel): + """GET /api/params 返回数据结构""" + + result: list[Params] = Field(..., title="Result") + +class GetParamsRsp(ResponseData): + """GET /api/params 返回数据结构""" + + result: GetParamsMsg \ No newline at end of file diff --git a/apps/services/flow.py b/apps/services/flow.py index 9275fc60c..0a253c317 100644 --- a/apps/services/flow.py +++ b/apps/services/flow.py @@ -20,7 +20,7 @@ from apps.schemas.flow_topology import ( PositionItem, ) from apps.services.node import NodeManager - +from apps.schemas.response_data import Params logger = logging.getLogger(__name__) @@ -470,3 +470,27 @@ class FlowManager: return False else: return True + + @staticmethod + async def get_step_by_flow_and_step_id(flow: FlowItem, step_id: str) -> list[Params] | None: + """ + 寻找stepID对应的节点之前的所有节点的参数 + """ + params = [] + try: + for edge in flow.edges: + if edge.target_node == step_id: + id = edge.source_node + if id == "start": + break + params.append(Params( + id = id, + name = flow.nodes[id].name, + parameters = flow.nodes[id].parameters.get("parameters", {}) + )) + step_id = edge.source_node + return params + except Exception: + logger.exception("[FlowManager] 获取节点失败") + return None + \ No newline at end of file -- Gitee From 884000ea5b1287f8d86b9a137d116ed503e4e3c7 Mon Sep 17 00:00:00 2001 From: zengxianghuai Date: Fri, 25 Jul 2025 17:55:41 +0800 Subject: [PATCH 10/12] =?UTF-8?q?=E5=A2=9E=E5=8A=A0choice=E5=BC=82?= =?UTF-8?q?=E5=B8=B8=E6=8D=95=E8=8E=B7=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/call/choice/choice.py | 53 ++++++++++++++++++++-------- 1 file changed, 39 insertions(+), 14 deletions(-) diff --git a/apps/scheduler/call/choice/choice.py b/apps/scheduler/call/choice/choice.py index e9cd7d41a..24df0daea 100644 --- a/apps/scheduler/call/choice/choice.py +++ b/apps/scheduler/call/choice/choice.py @@ -37,25 +37,50 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): """返回Call的名称和描述""" return CallInfo(name="Choice", description="使用大模型或使用程序做出判断") + def _raise_value_error(self, msg: str) -> None: + """统一处理 ValueError 异常抛出""" + logger.warning(msg) + raise ValueError(msg) + async def _prepare_message(self, call_vars: CallVars) -> list[dict[str, Any]]: """替换choices中的系统变量""" valid_choices = [] + for choice in self.choices: - if choice.logic not in [Logic.AND, Logic.OR]: - logger.warning("分支 %s 的逻辑运算符 %s 无效,已跳过",choice.branch_id, choice.logic) + try: + # 验证逻辑运算符 + if choice.logic not in [Logic.AND, Logic.OR]: + msg = f"无效的逻辑运算符: {choice.logic}" + self._raise_value_error(msg) + + valid_conditions = [] + for condition in choice.conditions: + # 处理左值 + if condition.left.step_id: + condition.left.value = self._extract_history_variables(condition.left.step_id, call_vars.history) + # 检查历史变量是否成功提取 + if condition.left.value is None: + msg = f"步骤 {condition.left.step_id} 的历史变量不存在" + self._raise_value_error(msg) + else: + msg = "左侧变量缺少step_id" + self._raise_value_error(msg) + + valid_conditions.append(condition) + + # 如果所有条件都无效,抛出异常 + if not valid_conditions: + msg = "分支没有有效条件" + self._raise_value_error(msg) + + # 更新有效条件 + choice.conditions = valid_conditions + valid_choices.append(choice.dict()) + + except ValueError as e: + logger.warning("分支 %s 处理失败: %s,已跳过", choice.branch_id, str(e)) continue - valid_conditions = [] - for condition in choice.conditions: - if condition.left.step_id: - condition.left.value = self._extract_history_variables( - condition.left.step_id, call_vars.history, - ) - valid_conditions.append(condition) - else: - logger.warning("条件 %s 的左侧变量 %s 无效,已跳过", condition.condition_id, condition.left) - continue - choice.conditions = valid_conditions - valid_choices.append(choice.dict()) + return valid_choices async def _init(self, call_vars: CallVars) -> ChoiceInput: -- Gitee From b51ed5f71018454a29c6dc0bee93593f4c92e177 Mon Sep 17 00:00:00 2001 From: zengxianghuai Date: Fri, 25 Jul 2025 18:45:20 +0800 Subject: [PATCH 11/12] =?UTF-8?q?=E4=BF=AE=E6=94=B9get=5Fparams=5Fby=5Fflo?= =?UTF-8?q?w=5Fand=5Fstep=5Fid=E5=87=BD=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/routers/parameter.py | 4 +-- apps/services/flow.py | 63 ++++++++++++++++++++++++++------------- 2 files changed, 44 insertions(+), 23 deletions(-) diff --git a/apps/routers/parameter.py b/apps/routers/parameter.py index 33789dfdc..538c25568 100644 --- a/apps/routers/parameter.py +++ b/apps/routers/parameter.py @@ -55,7 +55,7 @@ async def get_parameters( result={}, ).model_dump(exclude_none=True, by_alias=True), ) - result = await FlowManager.get_step_by_flow_and_step_id(flow, step_id) + result = await FlowManager.get_params_by_flow_and_step_id(flow, step_id) if not result: return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, @@ -120,7 +120,7 @@ async def operate_parameters(operate: str) -> list[str] | None: },) async def get_operate_parameters( user_sub: Annotated[str, Depends(get_user)], - operate: Annotated[str, Query(alias="operate")] + operate: Annotated[str, Query(alias="operate")], ) -> JSONResponse: """Get parameters for node choice.""" result = await operate_parameters(operate) diff --git a/apps/services/flow.py b/apps/services/flow.py index 0a253c317..cb8ad57fd 100644 --- a/apps/services/flow.py +++ b/apps/services/flow.py @@ -472,25 +472,46 @@ class FlowManager: return True @staticmethod - async def get_step_by_flow_and_step_id(flow: FlowItem, step_id: str) -> list[Params] | None: - """ - 寻找stepID对应的节点之前的所有节点的参数 - """ + async def get_params_by_flow_and_step_id( + flow: FlowItem, step_id: str + ) -> list[Params] | None: + """递归收集指定节点之前所有路径上的节点参数""" params = [] - try: - for edge in flow.edges: - if edge.target_node == step_id: - id = edge.source_node - if id == "start": - break - params.append(Params( - id = id, - name = flow.nodes[id].name, - parameters = flow.nodes[id].parameters.get("parameters", {}) - )) - step_id = edge.source_node - return params - except Exception: - logger.exception("[FlowManager] 获取节点失败") - return None - \ No newline at end of file + collected = set() # 记录已收集参数的节点 + + async def backtrack(current_id: str, visited: set) -> None: + # 避免循环递归 + if current_id in visited: + return + visited.add(current_id) + + # 获取所有指向当前节点的边 + incoming_edges = [ + edge for edge in flow.edges if edge.target_node == current_id + ] + + for edge in incoming_edges: + source_id = edge.source_node + + # 跳过起始节点 + if source_id == "start": + continue + + # 收集当前节点的参数(如果未被收集过) + if source_id not in collected: + node = flow.nodes.get(source_id) + if node: + collected.add(source_id) + params.append( + Params( + id=source_id, + name=node.name, + parameters=node.parameters.get("parameters", {}), + ), + ) + + # 继续回溯,传递当前路径的visited集合副本 + await backtrack(source_id, visited.copy()) + + await backtrack(step_id, set()) + return params -- Gitee From 4a58ac8a819e5cacc6401aeb3caf868b86be36ce Mon Sep 17 00:00:00 2001 From: zxstty Date: Sun, 27 Jul 2025 15:33:20 +0800 Subject: [PATCH 12/12] =?UTF-8?q?=E5=AE=8C=E5=96=84=E9=80=89=E6=8B=A9?= =?UTF-8?q?=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/main.py | 2 + apps/routers/parameter.py | 100 ++----- apps/scheduler/call/__init__.py | 3 +- apps/scheduler/call/choice/choice.py | 95 ++++-- .../call/choice/condition_handler.py | 272 +++++++++++++----- apps/scheduler/call/choice/schema.py | 63 ++-- apps/scheduler/call/core.py | 26 +- apps/scheduler/executor/flow.py | 20 +- apps/scheduler/executor/step.py | 3 - apps/scheduler/slot/slot.py | 43 ++- apps/schemas/config.py | 2 +- apps/schemas/parameters.py | 69 +++++ apps/schemas/response_data.py | 47 ++- apps/services/flow.py | 46 --- apps/services/node.py | 6 +- apps/services/parameter.py | 86 ++++++ 16 files changed, 564 insertions(+), 319 deletions(-) create mode 100644 apps/schemas/parameters.py create mode 100644 apps/services/parameter.py diff --git a/apps/main.py b/apps/main.py index c4ca2bfb1..c26e5e471 100644 --- a/apps/main.py +++ b/apps/main.py @@ -36,6 +36,7 @@ from apps.routers import ( record, service, user, + parameter ) from apps.scheduler.pool.pool import Pool @@ -66,6 +67,7 @@ app.include_router(llm.router) app.include_router(mcp_service.router) app.include_router(flow.router) app.include_router(user.router) +app.include_router(parameter.router) # logger配置 LOGGER_FORMAT = "%(funcName)s() - %(message)s" diff --git a/apps/routers/parameter.py b/apps/routers/parameter.py index 538c25568..6edbe2e14 100644 --- a/apps/routers/parameter.py +++ b/apps/routers/parameter.py @@ -5,13 +5,10 @@ from fastapi.responses import JSONResponse from apps.dependency import get_user from apps.dependency.user import verify_user -from apps.scheduler.call.choice.choice import Choice +from apps.services.parameter import ParameterManager from apps.schemas.response_data import ( - FlowStructureGetMsg, - FlowStructureGetRsp, - GetParamsMsg, - GetParamsRsp, - ResponseData, + GetOperaRsp, + GetParamsRsp ) from apps.services.application import AppManager from apps.services.flow import FlowManager @@ -25,10 +22,7 @@ router = APIRouter( ) -@router.get("", response_model={ - status.HTTP_403_FORBIDDEN: {"model": ResponseData}, - status.HTTP_404_NOT_FOUND: {"model": ResponseData}, - },) +@router.get("", response_model=GetParamsRsp) async def get_parameters( user_sub: Annotated[str, Depends(get_user)], app_id: Annotated[str, Query(alias="appId")], @@ -39,105 +33,45 @@ async def get_parameters( if not await AppManager.validate_user_app_access(user_sub, app_id): return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, - content=FlowStructureGetRsp( + content=GetParamsRsp( code=status.HTTP_403_FORBIDDEN, message="用户没有权限访问该流", - result=FlowStructureGetMsg(), + result=[], ).model_dump(exclude_none=True, by_alias=True), ) flow = await FlowManager.get_flow_by_app_and_flow_id(app_id, flow_id) if not flow: return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, - content=FlowStructureGetRsp( + content=GetParamsRsp( code=status.HTTP_404_NOT_FOUND, message="未找到该流", - result={}, - ).model_dump(exclude_none=True, by_alias=True), - ) - result = await FlowManager.get_params_by_flow_and_step_id(flow, step_id) - if not result: - return JSONResponse( - status_code=status.HTTP_404_NOT_FOUND, - content=FlowStructureGetRsp( - code=status.HTTP_404_NOT_FOUND, - message="未找到该节点", - result={}, + result=[], ).model_dump(exclude_none=True, by_alias=True), ) + result = await ParameterManager.get_pre_params_by_flow_and_step_id(flow, step_id) return JSONResponse( status_code=status.HTTP_200_OK, content=GetParamsRsp( code=status.HTTP_200_OK, message="获取参数成功", - result=GetParamsMsg(result=result), + result=result ).model_dump(exclude_none=True, by_alias=True), ) -async def operate_parameters(operate: str) -> list[str] | None: - """ - 根据操作类型获取对应的操作符参数列表 - - Args: - operate: 操作类型,支持 'int', 'str', 'bool' - - Returns: - 对应的操作符参数列表,若类型不支持则返回None - """ - string = [ - "equal", - "not_equal", - "great",#长度大于 - "great_equals",#长度大于等于 - "less",#长度小于 - "less_equals",#长度小于等于 - "greater", - "greater_equals", - "smaller", - "smaller_equals", - ] - integer = [ - "equal", - "not_equal", - "great", - "great_equals", - "less", - "less_equals", - ] - boolen = ["equal", "not_equal", "is_empty", "not_empty"] - if operate in string: - return string - if operate in integer: - return integer - if operate in boolen: - return boolen - return None - -@router.get("/operate", response_model={ - status.HTTP_200_OK: {"model": ResponseData}, - status.HTTP_404_NOT_FOUND: {"model": ResponseData}, - },) +@router.get("/operate", response_model=GetOperaRsp) async def get_operate_parameters( user_sub: Annotated[str, Depends(get_user)], - operate: Annotated[str, Query(alias="operate")], + param_type: Annotated[str, Query(alias="ParamType")], ) -> JSONResponse: """Get parameters for node choice.""" - result = await operate_parameters(operate) - if not result: - return JSONResponse( - status_code=status.HTTP_404_NOT_FOUND, - content=ResponseData( - code=status.HTTP_404_NOT_FOUND, - message="未找到该符号", - result=[], - ).model_dump(exclude_none=True, by_alias=True), - ) + result = await ParameterManager.get_operate_and_bind_type(param_type) return JSONResponse( status_code=status.HTTP_200_OK, - content=ResponseData( + content=GetOperaRsp( code=status.HTTP_200_OK, - message="获取参数成功", - result=result, + message="获取操作成功", + result=result ).model_dump(exclude_none=True, by_alias=True), - ) \ No newline at end of file + ) diff --git a/apps/scheduler/call/__init__.py b/apps/scheduler/call/__init__.py index 2ee8b8628..c5a6f0549 100644 --- a/apps/scheduler/call/__init__.py +++ b/apps/scheduler/call/__init__.py @@ -8,7 +8,7 @@ from apps.scheduler.call.mcp.mcp import MCP from apps.scheduler.call.rag.rag import RAG from apps.scheduler.call.sql.sql import SQL from apps.scheduler.call.suggest.suggest import Suggestion - +from apps.scheduler.call.choice.choice import Choice # 只包含需要在编排界面展示的工具 __all__ = [ "API", @@ -18,4 +18,5 @@ __all__ = [ "SQL", "Graph", "Suggestion", + "Choice" ] diff --git a/apps/scheduler/call/choice/choice.py b/apps/scheduler/call/choice/choice.py index 24df0daea..8cab82888 100644 --- a/apps/scheduler/call/choice/choice.py +++ b/apps/scheduler/call/choice/choice.py @@ -1,6 +1,8 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """使用大模型或使用程序做出判断""" +import ast +import copy import logging from collections.abc import AsyncGenerator from typing import Any @@ -9,11 +11,13 @@ from pydantic import Field from apps.scheduler.call.choice.condition_handler import ConditionHandler from apps.scheduler.call.choice.schema import ( + Condition, ChoiceBranch, ChoiceInput, ChoiceOutput, Logic, ) +from apps.schemas.parameters import Type from apps.scheduler.call.core import CoreCall from apps.schemas.enum_var import CallOutputType from apps.schemas.scheduler import ( @@ -30,17 +34,13 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): """Choice工具""" to_user: bool = Field(default=False) - choices: list[ChoiceBranch] = Field(description="分支", default=[]) + choices: list[ChoiceBranch] = Field(description="分支", default=[ChoiceBranch(), + ChoiceBranch(conditions=[Condition()], is_default=False)]) @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" - return CallInfo(name="Choice", description="使用大模型或使用程序做出判断") - - def _raise_value_error(self, msg: str) -> None: - """统一处理 ValueError 异常抛出""" - logger.warning(msg) - raise ValueError(msg) + return CallInfo(name="选择器", description="使用大模型或使用程序做出判断") async def _prepare_message(self, call_vars: CallVars) -> list[dict[str, Any]]: """替换choices中的系统变量""" @@ -51,31 +51,76 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): # 验证逻辑运算符 if choice.logic not in [Logic.AND, Logic.OR]: msg = f"无效的逻辑运算符: {choice.logic}" - self._raise_value_error(msg) + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue valid_conditions = [] - for condition in choice.conditions: + for i in range(len(choice.conditions)): + condition = copy.deepcopy(choice.conditions[i]) # 处理左值 - if condition.left.step_id: - condition.left.value = self._extract_history_variables(condition.left.step_id, call_vars.history) + if condition.left.step_id is not None: + condition.left.value = self._extract_history_variables( + condition.left.step_id+'/'+condition.left.value, call_vars.history) # 检查历史变量是否成功提取 - if condition.left.value is None: - msg = f"步骤 {condition.left.step_id} 的历史变量不存在" - self._raise_value_error(msg) + if condition.left.value is None: + msg = f"步骤 {condition.left.step_id} 的历史变量不存在" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + if not ConditionHandler.check_value_type( + condition.left.value, condition.left.type): + msg = f"左值类型不匹配: {condition.left.value} 应为 {condition.left.type.value}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue else: msg = "左侧变量缺少step_id" - self._raise_value_error(msg) - - valid_conditions.append(condition) + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + # 处理右值 + if condition.right.step_id is not None: + condition.right.value = self._extract_history_variables( + condition.right.step_id+'/'+condition.right.value, call_vars.history) + # 检查历史变量是否成功提取 + if condition.right.value is None: + msg = f"步骤 {condition.right.step_id} 的历史变量不存在" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + if not ConditionHandler.check_value_type( + condition.right.value, condition.right.type): + msg = f"右值类型不匹配: {condition.right.value} 应为 {condition.right.type.value}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + else: + # 如果右值没有step_id,尝试从call_vars中获取 + right_value_type = await ConditionHandler.get_value_type_from_operate( + condition.operate) + if right_value_type is None: + msg = f"不支持的运算符: {condition.operate}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + if condition.right.type != right_value_type: + msg = f"右值类型不匹配: {condition.right.value} 应为 {right_value_type.value}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + if right_value_type == Type.STRING: + condition.right.value = str(condition.right.value) + else: + condition.right.value = ast.literal_eval(condition.right.value) + if not ConditionHandler.check_value_type( + condition.right.value, condition.right.type): + msg = f"右值类型不匹配: {condition.right.value} 应为 {condition.right.type.value}" + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue + valid_conditions.append(condition) # 如果所有条件都无效,抛出异常 if not valid_conditions: msg = "分支没有有效条件" - self._raise_value_error(msg) + logger.warning(f"[Choice] 分支 {choice.branch_id} 条件处理失败: {msg}") + continue # 更新有效条件 choice.conditions = valid_conditions - valid_choices.append(choice.dict()) + valid_choices.append(choice) except ValueError as e: logger.warning("分支 %s 处理失败: %s,已跳过", choice.branch_id, str(e)) @@ -95,13 +140,11 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): """执行Choice工具""" # 解析输入数据 data = ChoiceInput(**input_data) - ret: CallOutputChunk = CallOutputChunk( - type=CallOutputType.DATA, - content=None, - ) - condition_handler = ConditionHandler() try: - ret.content = condition_handler.handler(data.choices) - yield ret + branch_id = ConditionHandler.handler(data.choices) + yield CallOutputChunk( + type=CallOutputType.DATA, + content=ChoiceOutput(branch_id=branch_id).model_dump(exclude_none=True, by_alias=True), + ) except Exception as e: raise CallError(message=f"选择工具调用失败:{e!s}", data={}) from e diff --git a/apps/scheduler/call/choice/condition_handler.py b/apps/scheduler/call/choice/condition_handler.py index feab5c647..7542f2944 100644 --- a/apps/scheduler/call/choice/condition_handler.py +++ b/apps/scheduler/call/choice/condition_handler.py @@ -1,19 +1,79 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """处理条件分支的工具""" + import logging from pydantic import BaseModel -from apps.scheduler.call.choice.schema import ChoiceBranch, ChoiceOutput, Condition, Logic, Operator, Type, Value +from apps.schemas.parameters import ( + Type, + NumberOperate, + StringOperate, + ListOperate, + BoolOperate, + DictOperate, +) + +from apps.scheduler.call.choice.schema import ( + ChoiceBranch, + Condition, + Logic, + Value +) logger = logging.getLogger(__name__) class ConditionHandler(BaseModel): """条件分支处理器""" + @staticmethod + async def get_value_type_from_operate(operate: NumberOperate | StringOperate | ListOperate | + BoolOperate | DictOperate) -> Type: + """获取右值的类型""" + if isinstance(operate, NumberOperate): + return Type.NUMBER + if operate in [ + StringOperate.EQUAL, StringOperate.NOT_EQUAL, StringOperate.CONTAINS, StringOperate.NOT_CONTAINS, + StringOperate.STARTS_WITH, StringOperate.ENDS_WITH, StringOperate.REGEX_MATCH]: + return Type.STRING + if operate in [StringOperate.LENGTH_EQUAL, StringOperate.LENGTH_GREATER_THAN, + StringOperate.LENGTH_GREATER_THAN_OR_EQUAL, StringOperate.LENGTH_LESS_THAN, + StringOperate.LENGTH_LESS_THAN_OR_EQUAL]: + return Type.NUMBER + if operate in [ListOperate.EQUAL, ListOperate.NOT_EQUAL]: + return Type.LIST + if operate in [ListOperate.CONTAINS, ListOperate.NOT_CONTAINS]: + return Type.STRING + if operate in [ListOperate.LENGTH_EQUAL, ListOperate.LENGTH_GREATER_THAN, + ListOperate.LENGTH_GREATER_THAN_OR_EQUAL, ListOperate.LENGTH_LESS_THAN, + ListOperate.LENGTH_LESS_THAN_OR_EQUAL]: + return Type.NUMBER + if operate in [BoolOperate.EQUAL, BoolOperate.NOT_EQUAL]: + return Type.BOOL + if operate in [DictOperate.EQUAL, DictOperate.NOT_EQUAL]: + return Type.DICT + if operate in [DictOperate.CONTAINS_KEY, DictOperate.NOT_CONTAINS_KEY]: + return Type.STRING + return None + + @staticmethod + def check_value_type(value: Value, expected_type: Type) -> bool: + """检查值的类型是否符合预期""" + if expected_type == Type.STRING and isinstance(value.value, str): + return True + if expected_type == Type.NUMBER and isinstance(value.value, (int, float)): + return True + if expected_type == Type.LIST and isinstance(value.value, list): + return True + if expected_type == Type.DICT and isinstance(value.value, dict): + return True + if expected_type == Type.BOOL and isinstance(value.value, bool): + return True + return False - def handler(self, choices: list[ChoiceBranch]) -> ChoiceOutput: + @staticmethod + def handler(choices: list[ChoiceBranch]) -> str: """处理条件""" default_branch = [c for c in choices if c.is_default] @@ -22,7 +82,7 @@ class ConditionHandler(BaseModel): if block_judgement.is_default: continue for condition in block_judgement.conditions: - result = self._judge_condition(condition) + result = ConditionHandler._judge_condition(condition) results.append(result) if block_judgement.logic == Logic.AND: final_result = all(results) @@ -30,58 +90,55 @@ class ConditionHandler(BaseModel): final_result = any(results) if final_result: - return { - "branch_id": block_judgement.branch_id, - "message": f"选择分支:{block_judgement.branch_id}", - } + return block_judgement.branch_id # 如果没有匹配的分支,选择默认分支 if default_branch: - return { - "branch_id": default_branch[0].branch_id, - "message": f"选择默认分支:{default_branch[0].branch_id}", - } - return { - "branch_id": "", - "message": "没有匹配的分支,且没有默认分支", - } - - def _judge_condition(self, condition: Condition) -> bool: + return default_branch[0].branch_id + return "" + + @staticmethod + def _judge_condition(condition: Condition) -> bool: """ 判断条件是否成立。 Args: - condition (Condition): 'left', 'operator', 'right', 'type' + condition (Condition): 'left', 'operate', 'right', 'type' Returns: bool """ left = condition.left - operator = condition.operator + operate = condition.operate right = condition.right value_type = condition.type result = None if value_type == Type.STRING: - result = self._judge_string_condition(left, operator, right) - elif value_type == Type.INT: - result = self._judge_int_condition(left, operator, right) + result = ConditionHandler._judge_string_condition(left, operate, right) + elif value_type == Type.NUMBER: + result = ConditionHandler._judge_int_condition(left, operate, right) elif value_type == Type.BOOL: - result = self._judge_bool_condition(left, operator, right) + result = ConditionHandler._judge_bool_condition(left, operate, right) + elif value_type == Type.LIST: + result = ConditionHandler._judge_list_condition(left, operate, right) + elif value_type == Type.DICT: + result = ConditionHandler._judge_dict_condition(left, operate, right) else: logger.error("不支持的数据类型: %s", value_type) msg = f"不支持的数据类型: {value_type}" raise ValueError(msg) return result - def _judge_string_condition(self, left: Value, operator: Operator, right: Value) -> bool: + @staticmethod + def _judge_string_condition(left: Value, operate: StringOperate, right: Value) -> bool: """ 判断字符串类型的条件。 Args: left (Value): 左值,包含 'value' 键。 - operator (Operator): 操作符 + operate (Operate): 操作符 right (Value): 右值,包含 'value' 键。 Returns: @@ -95,39 +152,41 @@ class ConditionHandler(BaseModel): raise TypeError(msg) right_value = right.value result = False - if operator == Operator.EQUAL: - result = left_value == right_value - elif operator == Operator.NEQUAL: - result = left_value != right_value - elif operator == Operator.GREAT: - result = len(left_value) > len(right_value) - elif operator == Operator.GREAT_EQUALS: - result = len(left_value) >= len(right_value) - elif operator == Operator.LESS: - result = len(left_value) < len(right_value) - elif operator == Operator.LESS_EQUALS: - result = len(left_value) <= len(right_value) - elif operator == Operator.GREATER: - result = left_value > right_value - elif operator == Operator.GREATER_EQUALS: - result = left_value >= right_value - elif operator == Operator.SMALLER: - result = left_value < right_value - elif operator == Operator.SMALLER_EQUALS: - result = left_value <= right_value - elif operator == Operator.CONTAINS: - result = right_value in left_value - elif operator == Operator.NOT_CONTAINS: - result = right_value not in left_value - return result + if operate == StringOperate.EQUAL: + return left_value == right_value + elif operate == StringOperate.NOT_EQUAL: + return left_value != right_value + elif operate == StringOperate.CONTAINS: + return right_value in left_value + elif operate == StringOperate.NOT_CONTAINS: + return right_value not in left_value + elif operate == StringOperate.STARTS_WITH: + return left_value.startswith(right_value) + elif operate == StringOperate.ENDS_WITH: + return left_value.endswith(right_value) + elif operate == StringOperate.REGEX_MATCH: + import re + return bool(re.match(right_value, left_value)) + elif operate == StringOperate.LENGTH_EQUAL: + return len(left_value) == right_value + elif operate == StringOperate.LENGTH_GREATER_THAN: + return len(left_value) > right_value + elif operate == StringOperate.LENGTH_GREATER_THAN_OR_EQUAL: + return len(left_value) >= right_value + elif operate == StringOperate.LENGTH_LESS_THAN: + return len(left_value) < right_value + elif operate == StringOperate.LENGTH_LESS_THAN_OR_EQUAL: + return len(left_value) <= right_value + return False - def _judge_int_condition(self, left: Value, operator: Operator, right: Value) -> bool: # noqa: PLR0911 + @staticmethod + def _judge_number_condition(left: Value, operate: NumberOperate, right: Value) -> bool: # noqa: PLR0911 """ - 判断整数类型的条件。 + 判断数字类型的条件。 Args: left (Value): 左值,包含 'value' 键。 - operator (Operator): 操作符 + operate (Operate): 操作符 right (Value): 右值,包含 'value' 键。 Returns: @@ -135,32 +194,33 @@ class ConditionHandler(BaseModel): """ left_value = left.value - if not isinstance(left_value, int): - logger.error("左值不是整数类型: %s", left_value) - msg = "左值必须是整数类型" + if not isinstance(left_value, (int, float)): + logger.error("左值不是数字类型: %s", left_value) + msg = "左值必须是数字类型" raise TypeError(msg) right_value = right.value - if operator == Operator.EQUAL: + if operate == NumberOperate.EQUAL: return left_value == right_value - if operator == Operator.NEQUAL: + elif operate == NumberOperate.NOT_EQUAL: return left_value != right_value - if operator == Operator.GREAT: + elif operate == NumberOperate.GREATER_THAN: return left_value > right_value - if operator == Operator.GREAT_EQUALS: - return left_value >= right_value - if operator == Operator.LESS: + elif operate == NumberOperate.LESS_THAN: # noqa: PLR2004 return left_value < right_value - if operator == Operator.LESS_EQUALS: + elif operate == NumberOperate.GREATER_THAN_OR_EQUAL: + return left_value >= right_value + elif operate == NumberOperate.LESS_THAN_OR_EQUAL: return left_value <= right_value return False - def _judge_bool_condition(self, left: Value, operator: Operator, right: Value) -> bool: + @staticmethod + def _judge_bool_condition(left: Value, operate: BoolOperate, right: Value) -> bool: """ 判断布尔类型的条件。 Args: left (Value): 左值,包含 'value' 键。 - operator (Operator): 操作符 + operate (Operate): 操作符 right (Value): 右值,包含 'value' 键。 Returns: @@ -173,12 +233,82 @@ class ConditionHandler(BaseModel): msg = "左值必须是布尔类型" raise TypeError(msg) right_value = right.value - if operator == Operator.EQUAL: + if operate == BoolOperate.EQUAL: + return left_value == right_value + elif operate == BoolOperate.NOT_EQUAL: + return left_value != right_value + elif operate == BoolOperate.IS_EMPTY: + return not left_value + elif operate == BoolOperate.NOT_EMPTY: + return left_value + return False + + @staticmethod + def _judge_list_condition(left: Value, operate: ListOperate, right: Value): + """ + 判断列表类型的条件。 + + Args: + left (Value): 左值,包含 'value' 键。 + operate (Operate): 操作符 + right (Value): 右值,包含 'value' 键。 + + Returns: + bool + + """ + left_value = left.value + if not isinstance(left_value, list): + logger.error("左值不是列表类型: %s", left_value) + msg = "左值必须是列表类型" + raise TypeError(msg) + right_value = right.value + if operate == ListOperate.EQUAL: + return left_value == right_value + elif operate == ListOperate.NOT_EQUAL: + return left_value != right_value + elif operate == ListOperate.CONTAINS: + return right_value in left_value + elif operate == ListOperate.NOT_CONTAINS: + return right_value not in left_value + elif operate == ListOperate.LENGTH_EQUAL: + return len(left_value) == right_value + elif operate == ListOperate.LENGTH_GREATER_THAN: + return len(left_value) > right_value + elif operate == ListOperate.LENGTH_GREATER_THAN_OR_EQUAL: + return len(left_value) >= right_value + elif operate == ListOperate.LENGTH_LESS_THAN: + return len(left_value) < right_value + elif operate == ListOperate.LENGTH_LESS_THAN_OR_EQUAL: + return len(left_value) <= right_value + return False + + @staticmethod + def _judge_dict_condition(left: Value, operate: DictOperate, right: Value): + """ + 判断字典类型的条件。 + + Args: + left (Value): 左值,包含 'value' 键。 + operate (Operate): 操作符 + right (Value): 右值,包含 'value' 键。 + + Returns: + bool + + """ + left_value = left.value + if not isinstance(left_value, dict): + logger.error("左值不是字典类型: %s", left_value) + msg = "左值必须是字典类型" + raise TypeError(msg) + right_value = right.value + if operate == DictOperate.EQUAL: return left_value == right_value - if operator == Operator.NEQUAL: + elif operate == DictOperate.NOT_EQUAL: return left_value != right_value - if operator == Operator.IS_EMPTY: - return left_value == "" - if operator == Operator.NOT_EMPTY: - return left_value != "" + elif operate == DictOperate.CONTAINS_KEY: + return right_value in left_value + elif operate == DictOperate.NOT_CONTAINS_KEY: + return right_value not in left_value return False diff --git a/apps/scheduler/call/choice/schema.py b/apps/scheduler/call/choice/schema.py index ed1c628cf..b95b16687 100644 --- a/apps/scheduler/call/choice/schema.py +++ b/apps/scheduler/call/choice/schema.py @@ -1,30 +1,20 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """Choice Call的输入和输出""" +import uuid from enum import Enum from pydantic import BaseModel, Field +from apps.schemas.parameters import ( + Type, + NumberOperate, + StringOperate, + ListOperate, + BoolOperate, + DictOperate, +) from apps.scheduler.call.core import DataBase -class Operator(str, Enum): - """Choice Call支持的运算符""" - - EQUAL = "equal" - NEQUAL = "not_equal" - GREAT = "great" - GREAT_EQUALS = "great_equals" - LESS = "less" - LESS_EQUALS = "less_equals" - # string - CONTAINS = "contains" - NOT_CONTAINS = "not_contains" - GREATER = "greater" - GREATER_EQUALS = "greater_equals" - SMALLER = "smaller" - SMALLER_EQUALS = "smaller_equals" - # bool - IS_EMPTY = "is_empty" - NOT_EMPTY = "not_empty" class Logic(str, Enum): @@ -34,38 +24,31 @@ class Logic(str, Enum): OR = "or" -class Type(str, Enum): - """Choice 工具支持的类型""" - - STRING = "string" - INT = "int" - BOOL = "bool" - - -class Value(BaseModel): +class Value(DataBase): """值的结构""" - step_id: str = Field(description="步骤id", default="") - value: str | int | bool = Field(description="值", default=None) + step_id: str | None = Field(description="步骤id", default=None) + type: Type | None = Field(description="值的类型", default=None) + value: str | float | int | bool | list | dict | None = Field(description="值", default=None) -class Condition(BaseModel): +class Condition(DataBase): """单个条件""" - type: Type = Field(description="值的类型", default=Type.STRING) - left: Value = Field(description="左值") - right: Value = Field(description="右值") - operator: Operator = Field(description="运算符", default="equal") - id: int = Field(description="条件ID") + left: Value = Field(description="左值", default=Value()) + right: Value = Field(description="右值", default=Value()) + operate: NumberOperate | StringOperate | ListOperate | BoolOperate | DictOperate | None = Field( + description="运算符", default=None) + id: str = Field(description="条件ID", default_factory=lambda: str(uuid.uuid4())) -class ChoiceBranch(BaseModel): +class ChoiceBranch(DataBase): """子分支""" - branch_id: str = Field(description="分支ID", default="") + branch_id: str = Field(description="分支ID", default_factory=lambda: str(uuid.uuid4())) logic: Logic = Field(description="逻辑运算符", default=Logic.AND) conditions: list[Condition] = Field(description="条件列表", default=[]) - is_default: bool = Field(description="是否为默认分支", default=False) + is_default: bool = Field(description="是否为默认分支", default=True) class ChoiceInput(DataBase): @@ -76,3 +59,5 @@ class ChoiceInput(DataBase): class ChoiceOutput(DataBase): """Choice Call的输出""" + + branch_id: str = Field(description="分支ID", default="") diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index 2b1cbba83..af28c6a34 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -76,21 +76,18 @@ class CoreCall(BaseModel): extra="allow", ) - def __init_subclass__(cls, input_model: type[DataBase], output_model: type[DataBase], **kwargs: Any) -> None: """初始化子类""" super().__init_subclass__(**kwargs) cls.input_model = input_model cls.output_model = output_model - @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" err = "[CoreCall] 必须手动实现info方法" raise NotImplementedError(err) - @staticmethod def _assemble_call_vars(executor: "StepExecutor") -> CallVars: """组装CallVars""" @@ -120,7 +117,6 @@ class CoreCall(BaseModel): summary=executor.task.runtime.summary, ) - @staticmethod def _extract_history_variables(path: str, history: dict[str, FlowStepHistory]) -> Any: """ @@ -131,18 +127,16 @@ class CoreCall(BaseModel): :return: 变量 """ split_path = path.split("/") + if len(split_path) < 2: + err = f"[CoreCall] 路径格式错误: {path}" + logger.error(err) + return None if split_path[0] not in history: err = f"[CoreCall] 步骤{split_path[0]}不存在" logger.error(err) - raise CallError( - message=err, - data={ - "step_id": split_path[0], - }, - ) - + return None data = history[split_path[0]].output_data - for key in split_path[1:]: + for key in split_path[2:]: if key not in data: err = f"[CoreCall] 输出Key {key} 不存在" logger.error(err) @@ -156,7 +150,6 @@ class CoreCall(BaseModel): data = data[key] return data - @classmethod async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self: """实例化Call类""" @@ -170,36 +163,30 @@ class CoreCall(BaseModel): await obj._set_input(executor) return obj - async def _set_input(self, executor: "StepExecutor") -> None: """获取Call的输入""" self._sys_vars = self._assemble_call_vars(executor) input_data = await self._init(self._sys_vars) self.input = input_data.model_dump(by_alias=True, exclude_none=True) - async def _init(self, call_vars: CallVars) -> DataBase: """初始化Call类,并返回Call的输入""" err = "[CoreCall] 初始化方法必须手动实现" raise NotImplementedError(err) - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """Call类实例的流式输出方法""" yield CallOutputChunk(type=CallOutputType.TEXT, content="") - async def _after_exec(self, input_data: dict[str, Any]) -> None: """Call类实例的执行后方法""" - async def exec(self, executor: "StepExecutor", input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """Call类实例的执行方法""" async for chunk in self._exec(input_data): yield chunk await self._after_exec(input_data) - async def _llm(self, messages: list[dict[str, Any]]) -> str: """Call可直接使用的LLM非流式调用""" result = "" @@ -210,7 +197,6 @@ class CoreCall(BaseModel): self.output_tokens = llm.output_tokens return result - async def _json(self, messages: list[dict[str, Any]], schema: type[BaseModel]) -> BaseModel: """Call可直接使用的JSON生成""" json = FunctionLLM() diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index d8d22c46f..a86ec4ac0 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -47,7 +47,6 @@ class FlowExecutor(BaseExecutor): question: str = Field(description="用户输入") post_body_app: RequestDataApp = Field(description="请求体中的app信息") - async def load_state(self) -> None: """从数据库中加载FlowExecutor的状态""" logger.info("[FlowExecutor] 加载Executor状态") @@ -70,7 +69,6 @@ class FlowExecutor(BaseExecutor): self._reached_end: bool = False self.step_queue: deque[StepQueueItem] = deque() - async def _invoke_runner(self, queue_item: StepQueueItem) -> None: """单一Step执行""" # 创建步骤Runner @@ -90,7 +88,6 @@ class FlowExecutor(BaseExecutor): # 更新Task(已存过库) self.task = step_runner.task - async def _step_process(self) -> None: """执行当前queue里面的所有步骤(在用户看来是单一Step)""" while True: @@ -102,7 +99,6 @@ class FlowExecutor(BaseExecutor): # 执行Step await self._invoke_runner(queue_item) - async def _find_next_id(self, step_id: str) -> list[str]: """查找下一个节点""" next_ids = [] @@ -111,15 +107,14 @@ class FlowExecutor(BaseExecutor): next_ids += [edge.edge_to] return next_ids - async def _find_flow_next(self) -> list[StepQueueItem]: """在当前步骤执行前,尝试获取下一步""" # 如果当前步骤为结束,则直接返回 - if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type] + if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type] return [] if self.task.state.step_name == "Choice": # 如果是choice节点,获取分支ID - branch_id = self.task.context[-1]["output_data"].get("branch_id", None) + branch_id = self.task.context[-1]["output_data"]["branch_id"] if branch_id: self.task.state.step_id = self.task.state.step_id + "." + branch_id logger.info("[FlowExecutor] 分支ID:%s", branch_id) @@ -127,7 +122,7 @@ class FlowExecutor(BaseExecutor): logger.warning("[FlowExecutor] 没有找到分支ID,返回空列表") return [] - next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type] + next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type] # 如果step没有任何出边,直接跳到end if not next_steps: return [ @@ -146,7 +141,6 @@ class FlowExecutor(BaseExecutor): for next_step in next_steps ] - async def run(self) -> None: """ 运行流,返回各步骤结果,直到无法继续执行 @@ -159,8 +153,8 @@ class FlowExecutor(BaseExecutor): # 获取首个步骤 first_step = StepQueueItem( - step_id=self.task.state.step_id, # type: ignore[arg-type] - step=self.flow.steps[self.task.state.step_id], # type: ignore[arg-type] + step_id=self.task.state.step_id, # type: ignore[arg-type] + step=self.flow.steps[self.task.state.step_id], # type: ignore[arg-type] ) # 头插开始前的系统步骤,并执行 @@ -179,7 +173,7 @@ class FlowExecutor(BaseExecutor): # 运行Flow(未达终点) while not self._reached_end: # 如果当前步骤出错,执行错误处理步骤 - if self.task.state.status == StepStatus.ERROR: # type: ignore[arg-type] + if self.task.state.status == StepStatus.ERROR: # type: ignore[arg-type] logger.warning("[FlowExecutor] Executor出错,执行错误处理步骤") self.step_queue.clear() self.step_queue.appendleft(StepQueueItem( @@ -192,7 +186,7 @@ class FlowExecutor(BaseExecutor): params={ "user_prompt": LLM_ERROR_PROMPT.replace( "{{ error_info }}", - self.task.state.error_info["err_msg"], # type: ignore[arg-type] + self.task.state.error_info["err_msg"], # type: ignore[arg-type] ), }, ), diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 506f3bb10..f3aeb82c3 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -119,7 +119,6 @@ class StepExecutor(BaseExecutor): logger.exception("[StepExecutor] 初始化Call失败") raise - async def _run_slot_filling(self) -> None: """运行自动参数填充;相当于特殊Step,但是不存库""" # 判断是否需要进行自动参数填充 @@ -170,7 +169,6 @@ class StepExecutor(BaseExecutor): self.task.tokens.input_tokens += self.obj.tokens.input_tokens self.task.tokens.output_tokens += self.obj.tokens.output_tokens - async def _process_chunk( self, iterator: AsyncGenerator[CallOutputChunk, None], @@ -202,7 +200,6 @@ class StepExecutor(BaseExecutor): return content - async def run(self) -> None: """运行单个步骤""" self.validate_flow_state(self.task) diff --git a/apps/scheduler/slot/slot.py b/apps/scheduler/slot/slot.py index 4caab4d2d..89433cade 100644 --- a/apps/scheduler/slot/slot.py +++ b/apps/scheduler/slot/slot.py @@ -12,6 +12,8 @@ from jsonschema.exceptions import ValidationError from jsonschema.protocols import Validator from jsonschema.validators import extend +from apps.schemas.response_data import ParamsNode +from apps.scheduler.call.choice.schema import Type from apps.scheduler.slot.parser import ( SlotConstParser, SlotDateParser, @@ -221,6 +223,45 @@ class Slot: return data return _extract_type_desc(self._schema) + def get_params_node_from_schema(self, root: str = "") -> ParamsNode: + """从JSON Schema中提取ParamsNode""" + def _extract_params_node(schema_node: dict[str, Any], name: str = "", path: str = "") -> ParamsNode: + """递归提取ParamsNode""" + if "type" not in schema_node: + return None + + param_type = schema_node["type"] + if param_type == "object": + param_type = Type.DICT + elif param_type == "array": + param_type = Type.LIST + elif param_type == "string": + param_type = Type.STRING + elif param_type == "number": + param_type = Type.NUMBER + elif param_type == "boolean": + param_type = Type.BOOL + else: + logger.warning(f"[Slot] 不支持的参数类型: {param_type}") + return None + sub_params = [] + + if param_type == "object" and "properties" in schema_node: + for key, value in schema_node["properties"].items(): + sub_params.append(_extract_params_node(value, name=key, path=f"{path}/{key}")) + else: + # 对于非对象类型,直接返回空子参数 + sub_params = None + return ParamsNode(paramName=name, + paramPath=path, + paramType=param_type, + subParams=sub_params) + try: + return _extract_params_node(self._schema, name=root, path=root) + except Exception as e: + logger.error(f"[Slot] 提取ParamsNode失败: {e!s}\n{traceback.format_exc()}") + return None + def _flatten_schema(self, schema: dict[str, Any]) -> tuple[dict[str, Any], list[str]]: """将JSON Schema扁平化""" result = {} @@ -276,7 +317,6 @@ class Slot: logger.exception("[Slot] 错误schema不合法: %s", error.schema) return {}, [] - def _assemble_patch( self, key: str, @@ -329,7 +369,6 @@ class Slot: logger.info("[Slot] 组装patch: %s", patch_list) return patch_list - def convert_json(self, json_data: str | dict[str, Any]) -> dict[str, Any]: """将用户手动填充的参数专为真实JSON""" json_dict = json.loads(json_data) if isinstance(json_data, str) else json_data diff --git a/apps/schemas/config.py b/apps/schemas/config.py index 99bcccde8..e91f5f75b 100644 --- a/apps/schemas/config.py +++ b/apps/schemas/config.py @@ -129,7 +129,7 @@ class ExtraConfig(BaseModel): class ConfigModel(BaseModel): """配置文件的校验Class""" - no_auth: NoauthConfig + no_auth: NoauthConfig = Field(description="无认证配置", default=NoauthConfig()) deploy: DeployConfig login: LoginConfig embedding: EmbeddingConfig diff --git a/apps/schemas/parameters.py b/apps/schemas/parameters.py new file mode 100644 index 000000000..bd908d237 --- /dev/null +++ b/apps/schemas/parameters.py @@ -0,0 +1,69 @@ +from enum import Enum + + +class NumberOperate(str, Enum): + """Choice 工具支持的数字运算符""" + + EQUAL = "number_equal" + NOT_EQUAL = "number_not_equal" + GREATER_THAN = "number_greater_than" + LESS_THAN = "number_less_than" + GREATER_THAN_OR_EQUAL = "number_greater_than_or_equal" + LESS_THAN_OR_EQUAL = "number_less_than_or_equal" + + +class StringOperate(str, Enum): + """Choice 工具支持的字符串运算符""" + + EQUAL = "string_equal" + NOT_EQUAL = "string_not_equal" + CONTAINS = "string_contains" + NOT_CONTAINS = "string_not_contains" + STARTS_WITH = "string_starts_with" + ENDS_WITH = "string_ends_with" + LENGTH_EQUAL = "string_length_equal" + LENGTH_GREATER_THAN = "string_length_greater_than" + LENGTH_GREATER_THAN_OR_EQUAL = "string_length_greater_than_or_equal" + LENGTH_LESS_THAN = "string_length_less_than" + LENGTH_LESS_THAN_OR_EQUAL = "string_length_less_than_or_equal" + REGEX_MATCH = "string_regex_match" + + +class ListOperate(str, Enum): + """Choice 工具支持的列表运算符""" + + EQUAL = "list_equal" + NOT_EQUAL = "list_not_equal" + CONTAINS = "list_contains" + NOT_CONTAINS = "list_not_contains" + LENGTH_EQUAL = "list_length_equal" + LENGTH_GREATER_THAN = "list_length_greater_than" + LENGTH_GREATER_THAN_OR_EQUAL = "list_length_greater_than_or_equal" + LENGTH_LESS_THAN = "list_length_less_than" + LENGTH_LESS_THAN_OR_EQUAL = "list_length_less_than_or_equal" + + +class BoolOperate(str, Enum): + """Choice 工具支持的布尔运算符""" + + EQUAL = "bool_equal" + NOT_EQUAL = "bool_not_equal" + + +class DictOperate(str, Enum): + """Choice 工具支持的字典运算符""" + + EQUAL = "dict_equal" + NOT_EQUAL = "dict_not_equal" + CONTAINS_KEY = "dict_contains_key" + NOT_CONTAINS_KEY = "dict_not_contains_key" + + +class Type(str, Enum): + """Choice 工具支持的类型""" + + STRING = "string" + NUMBER = "number" + LIST = "list" + DICT = "dict" + BOOL = "bool" diff --git a/apps/schemas/response_data.py b/apps/schemas/response_data.py index 20d7ad9b0..b1dc77b75 100644 --- a/apps/schemas/response_data.py +++ b/apps/schemas/response_data.py @@ -14,6 +14,14 @@ from apps.schemas.flow_topology import ( NodeServiceItem, PositionItem, ) +from apps.schemas.parameters import ( + Type, + NumberOperate, + StringOperate, + ListOperate, + BoolOperate, + DictOperate, +) from apps.schemas.mcp import MCPInstallStatus, MCPTool, MCPType from apps.schemas.record import RecordData from apps.schemas.user import UserInfo @@ -629,20 +637,41 @@ class ListLLMRsp(ResponseData): result: list[LLMProviderInfo] = Field(default=[], title="Result") -class Params(BaseModel): + +class ParamsNode(BaseModel): """参数数据结构""" + param_name: str = Field(..., description="参数名称", alias="paramName") + param_path: str = Field(..., description="参数路径", alias="paramPath") + param_type: Type = Field(..., description="参数类型", alias="paramType") + sub_params: list["ParamsNode"] | None = Field( + default=None, description="子参数列表", alias="subParams" + ) - id: str = Field(..., description="StepID") - name: str = Field(..., description="Step名称") - parameters: dict[str, Any] = Field(..., description="参数") - operate: str = Field(..., description="比较符") -class GetParamsMsg(BaseModel): - """GET /api/params 返回数据结构""" +class StepParams(BaseModel): + """参数数据结构""" + step_id: str = Field(..., description="步骤ID", alias="stepId") + name: str = Field(..., description="Step名称") + params_node: ParamsNode | None = Field( + default=None, description="参数节点", alias="paramsNode") - result: list[Params] = Field(..., title="Result") class GetParamsRsp(ResponseData): """GET /api/params 返回数据结构""" - result: GetParamsMsg \ No newline at end of file + result: list[StepParams] = Field( + default=[], description="参数列表", alias="result" + ) + + +class OperateAndBindType(BaseModel): + """操作和绑定类型数据结构""" + + operate: NumberOperate | StringOperate | ListOperate | BoolOperate | DictOperate = Field(description="操作类型") + bind_type: Type = Field(description="绑定类型") + + +class GetOperaRsp(ResponseData): + """GET /api/operate 返回数据结构""" + + result: list[OperateAndBindType] = Field(..., title="Result") diff --git a/apps/services/flow.py b/apps/services/flow.py index cb8ad57fd..c9fd86fd8 100644 --- a/apps/services/flow.py +++ b/apps/services/flow.py @@ -20,7 +20,6 @@ from apps.schemas.flow_topology import ( PositionItem, ) from apps.services.node import NodeManager -from apps.schemas.response_data import Params logger = logging.getLogger(__name__) @@ -470,48 +469,3 @@ class FlowManager: return False else: return True - - @staticmethod - async def get_params_by_flow_and_step_id( - flow: FlowItem, step_id: str - ) -> list[Params] | None: - """递归收集指定节点之前所有路径上的节点参数""" - params = [] - collected = set() # 记录已收集参数的节点 - - async def backtrack(current_id: str, visited: set) -> None: - # 避免循环递归 - if current_id in visited: - return - visited.add(current_id) - - # 获取所有指向当前节点的边 - incoming_edges = [ - edge for edge in flow.edges if edge.target_node == current_id - ] - - for edge in incoming_edges: - source_id = edge.source_node - - # 跳过起始节点 - if source_id == "start": - continue - - # 收集当前节点的参数(如果未被收集过) - if source_id not in collected: - node = flow.nodes.get(source_id) - if node: - collected.add(source_id) - params.append( - Params( - id=source_id, - name=node.name, - parameters=node.parameters.get("parameters", {}), - ), - ) - - # 继续回溯,传递当前路径的visited集合副本 - await backtrack(source_id, visited.copy()) - - await backtrack(step_id, set()) - return params diff --git a/apps/services/node.py b/apps/services/node.py index 6f0d492ed..bf48e71fe 100644 --- a/apps/services/node.py +++ b/apps/services/node.py @@ -16,6 +16,7 @@ NODE_TYPE_MAP = { "API": APINode, } + class NodeManager: """Node管理器""" @@ -29,7 +30,6 @@ class NodeManager: raise ValueError(err) return node["call_id"] - @staticmethod async def get_node(node_id: str) -> NodePool: """获取Node的类型""" @@ -40,7 +40,6 @@ class NodeManager: raise ValueError(err) return NodePool.model_validate(node) - @staticmethod async def get_node_name(node_id: str) -> str: """获取node的名称""" @@ -52,7 +51,6 @@ class NodeManager: return "" return node_doc["name"] - @staticmethod def merge_params_schema(params_schema: dict[str, Any], known_params: dict[str, Any]) -> dict[str, Any]: """递归合并参数Schema,将known_params中的值填充到params_schema的对应位置""" @@ -75,7 +73,6 @@ class NodeManager: return params_schema - @staticmethod async def get_node_params(node_id: str) -> tuple[dict[str, Any], dict[str, Any]]: """获取Node数据""" @@ -100,7 +97,6 @@ class NodeManager: err = f"[NodeManager] Call {call_id} 不存在" logger.error(err) raise ValueError(err) - # 返回参数Schema return ( NodeManager.merge_params_schema(call_class.model_json_schema(), node_data.known_params or {}), diff --git a/apps/services/parameter.py b/apps/services/parameter.py new file mode 100644 index 000000000..ae375e97f --- /dev/null +++ b/apps/services/parameter.py @@ -0,0 +1,86 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""flow Manager""" + +import logging + +from pymongo import ASCENDING + +from apps.services.node import NodeManager +from apps.schemas.flow_topology import FlowItem +from apps.scheduler.slot.slot import Slot +from apps.scheduler.call.choice.condition_handler import ConditionHandler +from apps.scheduler.call.choice.schema import ( + NumberOperate, + StringOperate, + ListOperate, + BoolOperate, + DictOperate, + Type +) +from apps.schemas.response_data import ( + OperateAndBindType, + ParamsNode, + StepParams, +) +from apps.services.node import NodeManager +logger = logging.getLogger(__name__) + + +class ParameterManager: + """Parameter Manager""" + @staticmethod + async def get_operate_and_bind_type(param_type: Type) -> list[OperateAndBindType]: + """Get operate and bind type""" + result = [] + operate = None + if param_type == Type.NUMBER: + operate = NumberOperate + elif param_type == Type.STRING: + operate = StringOperate + elif param_type == Type.LIST: + operate = ListOperate + elif param_type == Type.BOOL: + operate = BoolOperate + elif param_type == Type.DICT: + operate = DictOperate + if operate: + for item in operate: + result.append(OperateAndBindType( + operate=item, + bind_type=ConditionHandler.get_value_type_from_operate(item))) + return result + + @staticmethod + async def get_pre_params_by_flow_and_step_id(flow: FlowItem, step_id: str) -> list[StepParams]: + """Get pre params by flow and step id""" + index = 0 + q = [step_id] + in_edges = {} + step_id_to_node_id = {} + for step in flow.nodes: + step_id_to_node_id[step.step_id] = step.node_id + for edge in flow.edges: + if edge.target_node not in in_edges: + in_edges[edge.target_node] = [] + in_edges[edge.target_node].append(edge.source_node) + while index < len(q): + tmp_step_id = q[index] + index += 1 + for i in range(len(in_edges.get(tmp_step_id, []))): + pre_node_id = in_edges[tmp_step_id][i] + if pre_node_id not in q: + q.append(pre_node_id) + pre_step_params = [] + for step_id in q: + node_id = step_id_to_node_id.get(step_id) + params_schema, output_schema = await NodeManager.get_node_params(node_id) + slot = Slot(output_schema) + params_node = slot.get_params_node_from_schema(root='/output') + pre_step_params.append( + StepParams( + stepId=node_id, + name=params_schema.get("name", ""), + paramsNode=params_node + ) + ) + return pre_step_params -- Gitee