From daed8e1fe874c034aa8e9e1b335963b3ff206abc Mon Sep 17 00:00:00 2001 From: z30057876 Date: Fri, 18 Apr 2025 14:58:14 +0800 Subject: [PATCH 1/7] =?UTF-8?q?=E6=9B=B4=E6=96=B0reason=5Ftype=E7=B1=BB?= =?UTF-8?q?=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/entities/request_data.py | 2 +- apps/manager/blacklist.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/entities/request_data.py b/apps/entities/request_data.py index de9ab62a9..fdc8b710f 100644 --- a/apps/entities/request_data.py +++ b/apps/entities/request_data.py @@ -70,7 +70,7 @@ class AbuseRequest(BaseModel): record_id: str reason: str - reason_type: list[str] + reason_type: str class AbuseProcessRequest(BaseModel): diff --git a/apps/manager/blacklist.py b/apps/manager/blacklist.py index b895496cb..57c53efc9 100644 --- a/apps/manager/blacklist.py +++ b/apps/manager/blacklist.py @@ -168,7 +168,7 @@ class AbuseManager: """用户举报相关操作""" @staticmethod - async def change_abuse_report(user_sub: str, record_id: str, reason_type: list[str], reason: str) -> bool: + async def change_abuse_report(user_sub: str, record_id: str, reason_type: str, reason: str) -> bool: """存储用户举报详情""" try: # 判断record_id是否合法 -- Gitee From cc19973e8c3b4b6dc21116b528a967f3c86e7f72 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Fri, 18 Apr 2025 14:59:13 +0800 Subject: [PATCH 2/7] =?UTF-8?q?=E6=9B=B4=E6=96=B0Executor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/executor/flow.py | 54 +++++++++------------------------ apps/scheduler/executor/step.py | 47 +++++++++++++++++++--------- 2 files changed, 46 insertions(+), 55 deletions(-) diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index a588431ab..c36b1eb5c 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -84,10 +84,7 @@ class FlowExecutor(BaseExecutor): async def _invoke_runner(self, queue_item: StepQueueItem) -> None: """单一Step执行""" - if not self.task.state: - err = "[FlowExecutor] 当前ExecutorState为空" - logger.error(err) - raise ValueError(err) + self.validate_flow_state(self.task) # 创建步骤Runner step_runner = StepExecutor( @@ -101,12 +98,8 @@ class FlowExecutor(BaseExecutor): # 初始化步骤 await step_runner.init() - - # 运行Step,并判断是否需要输出 - if await self._is_to_user(queue_item.step_id, queue_item.step): - await step_runner.run_step(to_user=True) - else: - await step_runner.run_step() + # 运行Step + await step_runner.run_step() # 更新Task self.task = step_runner.task @@ -115,10 +108,7 @@ class FlowExecutor(BaseExecutor): async def _step_process(self) -> None: """单一Step执行""" - if not self.task.state: - err = "[FlowExecutor] 当前ExecutorState为空" - logger.error(err) - raise ValueError(err) + self.validate_flow_state(self.task) while True: try: @@ -139,30 +129,15 @@ class FlowExecutor(BaseExecutor): return next_ids - async def _is_to_user(self, step_id: str, step: Step) -> bool: - """判断是否需要输出""" - # 最后一步或者Choice类型,任何情况下都不输出 - if step.type in [ - SpecialCallType.CHOICE.value, - ]: - return False - - next_steps = await self._find_next_id(step_id) - return bool(len(next_steps) == 1 and self.flow.steps[next_steps[0]].type == SpecialCallType.OUTPUT.value) - - async def _find_flow_next(self) -> list[StepQueueItem]: """在当前步骤执行前,尝试获取下一步""" - if not self.task.state: - err = "[StepExecutor] 当前ExecutorState为空" - logger.error(err) - raise ValueError(err) + self.validate_flow_state(self.task) # 如果当前步骤为结束,则直接返回 - if self.task.state.step_id == "end" or not self.task.state.step_id: + if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type] return [] - next_steps = await self._find_next_id(self.task.state.step_id) + next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type] # 如果step没有任何出边,直接跳到end if not next_steps: return [ @@ -188,10 +163,7 @@ class FlowExecutor(BaseExecutor): 数据通过向Queue发送消息的方式传输 """ - if not self.task.state: - err = "[FlowExecutor] 当前ExecutorState为空" - logger.error(err) - raise ValueError(err) + self.validate_flow_state(self.task) logger.info("[FlowExecutor] 运行工作流") # 推送Flow开始消息 @@ -199,8 +171,8 @@ class FlowExecutor(BaseExecutor): # 获取首个步骤 first_step = StepQueueItem( - step_id=self.task.state.step_id, - step=self.flow.steps[self.task.state.step_id], + step_id=self.task.state.step_id, # type: ignore[arg-type] + step=self.flow.steps[self.task.state.step_id], # type: ignore[arg-type] ) # 头插开始前的系统步骤,并执行 @@ -209,6 +181,7 @@ class FlowExecutor(BaseExecutor): step_id=str(uuid.uuid4()), step=step, enable_filling=False, + to_user=False, )) await self._step_process() @@ -218,13 +191,14 @@ class FlowExecutor(BaseExecutor): # 运行Flow(未达终点) while not self._reached_end: # 如果当前步骤出错,执行错误处理步骤 - if self.task.state.status == StepStatus.ERROR: + if self.task.state.status == StepStatus.ERROR: # type: ignore[arg-type] logger.warning("[FlowExecutor] Executor出错,执行错误处理步骤") self.step_queue.clear() - self.step_queue.append(StepQueueItem( + self.step_queue.appendleft(StepQueueItem( step_id=str(uuid.uuid4()), step=ERROR_STEP, enable_filling=False, + to_user=True, )) # 错误处理后结束 self._reached_end = True diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 0ca0a6f47..1614e224b 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -2,6 +2,7 @@ import logging import uuid +from collections.abc import AsyncGenerator from typing import Any from pydantic import ConfigDict @@ -12,7 +13,6 @@ from apps.entities.enum_var import ( StepStatus, ) from apps.entities.message import TextAddContent -from apps.entities.pool import NodePool from apps.entities.scheduler import CallOutputChunk from apps.entities.task import FlowStepHistory, StepQueueItem from apps.manager.node import NodeManager @@ -103,7 +103,6 @@ class StepExecutor(BaseExecutor): SpecialCallType.SUMMARY.value, SpecialCallType.FACTS.value, SpecialCallType.SLOT.value, - SpecialCallType.OUTPUT.value, SpecialCallType.EMPTY.value, SpecialCallType.START.value, SpecialCallType.END.value, @@ -144,24 +143,20 @@ class StepExecutor(BaseExecutor): await TaskManager.save_task(self.task.id, self.task) - async def run_step(self, *, to_user: bool = False) -> None: - """运行单个步骤""" - self.validate_flow_state(self.task) - logger.info("[StepExecutor] 运行步骤 %s", self.step.step.name) - - # 推送输入 - await self.push_message(EventType.STEP_INPUT.value, self.input) - - # 执行步骤 - iterator = self.obj.exec(self, self.input) - + async def _process_chunk( + self, + iterator: AsyncGenerator[CallOutputChunk, None], + *, + to_user: bool = False, + ) -> str | dict[str, Any]: + """处理Chunk""" content: str | dict[str, Any] = "" + async for chunk in iterator: if not isinstance(chunk, CallOutputChunk): err = "[StepExecutor] 返回结果类型错误" logger.error(err) raise TypeError(err) - if isinstance(chunk.content, str): if not isinstance(content, str): content = "" @@ -170,7 +165,6 @@ class StepExecutor(BaseExecutor): if not isinstance(content, dict): content = {} content = chunk.content - if to_user: if isinstance(chunk.content, str): await self.push_message(EventType.TEXT_ADD.value, chunk.content) @@ -178,6 +172,29 @@ class StepExecutor(BaseExecutor): else: await self.push_message(self.step.step.type, chunk.content) + return content + + + async def run_step(self) -> None: + """运行单个步骤""" + self.validate_flow_state(self.task) + logger.info("[StepExecutor] 运行步骤 %s", self.step.step.name) + + # 推送输入 + await self.push_message(EventType.STEP_INPUT.value, self.input) + + # 执行步骤 + iterator = self.obj.exec(self, self.input) + + try: + content = await self._process_chunk(iterator, to_user=self.obj.to_user) + except Exception: + logger.exception("[StepExecutor] 运行步骤失败") + self.task.state.status = StepStatus.ERROR # type: ignore[arg-type] + await self.push_message(EventType.STEP_OUTPUT.value, {}) + await TaskManager.save_task(self.task.id, self.task) + return + # 更新执行状态 self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] -- Gitee From e06aad998acf0f5d82c4834ce5224628a4177258 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Fri, 18 Apr 2025 14:59:34 +0800 Subject: [PATCH 3/7] =?UTF-8?q?=E5=88=A0=E9=99=A4Output=E8=8A=82=E7=82=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/call/output.py | 34 ---------------------------------- 1 file changed, 34 deletions(-) delete mode 100644 apps/scheduler/call/output.py diff --git a/apps/scheduler/call/output.py b/apps/scheduler/call/output.py deleted file mode 100644 index 807c023fe..000000000 --- a/apps/scheduler/call/output.py +++ /dev/null @@ -1,34 +0,0 @@ -""" -输出工具:将文字和结构化数据输出至前端 - -Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. -""" - -from collections.abc import AsyncGenerator -from typing import Any - -from apps.entities.enum_var import CallOutputType -from apps.entities.scheduler import CallInfo, CallOutputChunk, CallVars -from apps.scheduler.call.core import CoreCall, DataBase - - -class Output(CoreCall, input_type=DataBase, output_type=DataBase): - """输出工具""" - - @classmethod - def cls_info(cls) -> CallInfo: - """返回Call的名称和描述""" - return CallInfo(name="输出", description="将前一步骤的数据输出给用户") - - - async def _init(self, call_vars: CallVars) -> dict[str, Any]: - """初始化工具""" - return {} - - - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: - """执行工具""" - yield CallOutputChunk( - type=CallOutputType.DATA, - content={}, - ) -- Gitee From bee02f04e69456bb67ec2f81eb920c383d684649 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Fri, 18 Apr 2025 15:05:23 +0800 Subject: [PATCH 4/7] =?UTF-8?q?=E6=9B=B4=E6=96=B0Call?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/call/__init__.py | 4 +-- apps/scheduler/call/core.py | 43 ++++++++++++++++++++++++++------- apps/scheduler/call/empty.py | 2 +- 3 files changed, 37 insertions(+), 12 deletions(-) diff --git a/apps/scheduler/call/__init__.py b/apps/scheduler/call/__init__.py index 2d5f919bc..924c42514 100644 --- a/apps/scheduler/call/__init__.py +++ b/apps/scheduler/call/__init__.py @@ -6,8 +6,8 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. from apps.scheduler.call.api.api import API from apps.scheduler.call.convert.convert import Convert +from apps.scheduler.call.graph.graph import Graph from apps.scheduler.call.llm.llm import LLM -from apps.scheduler.call.output import Output from apps.scheduler.call.rag.rag import RAG from apps.scheduler.call.sql.sql import SQL from apps.scheduler.call.suggest.suggest import Suggestion @@ -19,6 +19,6 @@ __all__ = [ "RAG", "SQL", "Convert", - "Output", + "Graph", "Suggestion", ] diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index 7aab09e1e..7a90b08d7 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -10,6 +10,7 @@ from collections.abc import AsyncGenerator from typing import TYPE_CHECKING, Any, ClassVar, Self from pydantic import BaseModel, ConfigDict, Field +from pydantic.json_schema import SkipJsonSchema from apps.entities.enum_var import CallOutputType from apps.entities.scheduler import ( @@ -40,24 +41,32 @@ class DataBase(BaseModel): class CoreCall(BaseModel): """所有Call的父类,所有Call必须继承此类。""" - name: str - description: str + name: SkipJsonSchema[str] = Field(description="Step的名称", exclude=True) + description: SkipJsonSchema[str] = Field(description="Step的描述", exclude=True) + input_model: ClassVar[SkipJsonSchema[type[DataBase]]] = Field( + description="Call的输入Pydantic类型;不包含override的模板", + exclude=True, + frozen=True, + ) + output_model: ClassVar[SkipJsonSchema[type[DataBase]]] = Field( + description="Call的输出Pydantic类型;不包含override的模板", + exclude=True, + frozen=True, + ) + + to_user: bool = Field(description="是否需要将输出返回给用户", default=False) model_config = ConfigDict( arbitrary_types_allowed=True, extra="allow", ) - input_type: ClassVar[type[DataBase]] = Field(description="Call的输入Pydantic类型", exclude=True, frozen=True) - """需要以消息形式输出的、需要大模型填充参数的输入信息填到这里""" - output_type: ClassVar[type[DataBase]] = Field(description="Call的输出Pydantic类型", exclude=True, frozen=True) - """需要以消息形式输出的输出信息填到这里""" - def __init_subclass__(cls, input_type: type[DataBase], output_type: type[DataBase], **kwargs: Any) -> None: + def __init_subclass__(cls, input_model: type[DataBase], output_model: type[DataBase], **kwargs: Any) -> None: """初始化子类""" super().__init_subclass__(**kwargs) - cls.input_type = input_type - cls.output_type = output_type + cls.input_model = input_model + cls.output_model = output_model @classmethod @@ -67,6 +76,22 @@ class CoreCall(BaseModel): raise NotImplementedError(err) + @property + def input_type(self) -> type[DataBase]: + """返回输入类型""" + class InputType(self.input_model): + pass + return InputType + + + @property + def output_type(self) -> type[DataBase]: + """返回输出类型""" + class OutputType(self._output_type_template): + pass + return OutputType + + @staticmethod def _assemble_call_vars(executor: "StepExecutor") -> CallVars: """组装CallVars""" diff --git a/apps/scheduler/call/empty.py b/apps/scheduler/call/empty.py index 970a93950..84dc9ae79 100644 --- a/apps/scheduler/call/empty.py +++ b/apps/scheduler/call/empty.py @@ -12,7 +12,7 @@ from apps.entities.scheduler import CallInfo, CallOutputChunk, CallVars from apps.scheduler.call.core import CoreCall, DataBase -class Empty(CoreCall, input_type=DataBase, output_type=DataBase): +class Empty(CoreCall, input_model=DataBase, output_model=DataBase): """空Call""" @classmethod -- Gitee From 89faf50cfb32fc367860b4fdf5a37bf52fff2daa Mon Sep 17 00:00:00 2001 From: z30057876 Date: Fri, 18 Apr 2025 15:06:47 +0800 Subject: [PATCH 5/7] =?UTF-8?q?=E6=9B=B4=E6=96=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/call/convert/convert.py | 2 +- apps/scheduler/call/facts/facts.py | 2 +- apps/scheduler/call/graph/graph.py | 20 ++++++++++++++------ apps/scheduler/call/llm/llm.py | 2 +- apps/scheduler/call/rag/rag.py | 4 +--- 5 files changed, 18 insertions(+), 12 deletions(-) diff --git a/apps/scheduler/call/convert/convert.py b/apps/scheduler/call/convert/convert.py index 95086cb90..f15fd2f88 100644 --- a/apps/scheduler/call/convert/convert.py +++ b/apps/scheduler/call/convert/convert.py @@ -22,7 +22,7 @@ from apps.scheduler.call.convert.schema import ConvertInput, ConvertOutput from apps.scheduler.call.core import CallOutputChunk, CoreCall -class Convert(CoreCall, input_type=ConvertInput, output_type=ConvertOutput): +class Convert(CoreCall, input_model=ConvertInput, output_model=ConvertOutput): """Convert 工具,用于对生成的文字信息和原始数据进行格式化""" text_template: str | None = Field(description="自然语言信息的格式化模板,jinja2语法", default=None) diff --git a/apps/scheduler/call/facts/facts.py b/apps/scheduler/call/facts/facts.py index e982c59d9..0c099729a 100644 --- a/apps/scheduler/call/facts/facts.py +++ b/apps/scheduler/call/facts/facts.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: from apps.scheduler.executor.step import StepExecutor -class FactsCall(CoreCall, input_type=FactsInput, output_type=FactsOutput): +class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): """提取事实工具""" answer: str = Field(description="用户输入") diff --git a/apps/scheduler/call/graph/graph.py b/apps/scheduler/call/graph/graph.py index 38983216c..da0783b4b 100644 --- a/apps/scheduler/call/graph/graph.py +++ b/apps/scheduler/call/graph/graph.py @@ -6,27 +6,35 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. import json from collections.abc import AsyncGenerator -from typing import Any, ClassVar +from typing import Any from anyio import Path from pydantic import Field from apps.entities.enum_var import CallOutputType -from apps.entities.scheduler import CallError, CallOutputChunk, CallVars +from apps.entities.scheduler import ( + CallError, + CallInfo, + CallOutputChunk, + CallVars, +) from apps.scheduler.call.core import CoreCall from apps.scheduler.call.graph.schema import RenderFormat, RenderInput, RenderOutput from apps.scheduler.call.graph.style import RenderStyle -class Graph(CoreCall, output_type=RenderOutput): +class Graph(CoreCall, input_model=RenderInput, output_model=RenderOutput): """Render Call,用于将SQL Tool查询出的数据转换为图表""" - name: ClassVar[str] = Field("图表", exclude=True, frozen=True) - description: ClassVar[str] = Field("将SQL查询出的数据转换为图表", exclude=True, frozen=True) - data: list[dict[str, Any]] = Field(description="用于绘制图表的数据", exclude=True, frozen=True) + @classmethod + def cls_info(cls) -> CallInfo: + """返回Call的名称和描述""" + return CallInfo(name="图表", description="将SQL查询出的数据转换为图表") + + async def _init(self, call_vars: CallVars) -> dict[str, Any]: """初始化Render Call,校验参数,读取option模板""" await super()._init(call_vars) diff --git a/apps/scheduler/call/llm/llm.py b/apps/scheduler/call/llm/llm.py index 1ce2a6b47..a8dfd7909 100644 --- a/apps/scheduler/call/llm/llm.py +++ b/apps/scheduler/call/llm/llm.py @@ -28,7 +28,7 @@ from apps.scheduler.call.llm.schema import LLM_CONTEXT_PROMPT, LLM_DEFAULT_PROMP logger = logging.getLogger(__name__) -class LLM(CoreCall, input_type=LLMInput, output_type=LLMOutput): +class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput): """大模型调用工具""" temperature: float = Field(description="大模型温度(随机化程度)", default=0.7) diff --git a/apps/scheduler/call/rag/rag.py b/apps/scheduler/call/rag/rag.py index 92fff7d30..b73b21f8b 100644 --- a/apps/scheduler/call/rag/rag.py +++ b/apps/scheduler/call/rag/rag.py @@ -27,7 +27,7 @@ from apps.scheduler.call.rag.schema import RAGInput, RAGOutput, RetrievalMode logger = logging.getLogger(__name__) -class RAG(CoreCall, input_type=RAGInput, output_type=RAGOutput): +class RAG(CoreCall, input_model=RAGInput, output_model=RAGOutput): """RAG工具:查询知识库""" knowledge_base: str | None = Field(description="知识库的id", alias="kb_sn", default=None) @@ -41,8 +41,6 @@ class RAG(CoreCall, input_type=RAGInput, output_type=RAGOutput): async def _init(self, call_vars: CallVars) -> dict[str, Any]: """初始化RAG工具""" - await super()._init(call_vars) - self._task_id = call_vars.ids.task_id return RAGInput( content=call_vars.question, -- Gitee From eff6899feeb67280d19dac4b464f0d7131e6502e Mon Sep 17 00:00:00 2001 From: z30057876 Date: Fri, 18 Apr 2025 15:07:53 +0800 Subject: [PATCH 6/7] =?UTF-8?q?=E6=9B=B4=E6=96=B0Call?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/call/sql/sql.py | 97 ++++++++++++++------------ apps/scheduler/call/suggest/suggest.py | 4 +- apps/scheduler/call/summary/summary.py | 2 +- 3 files changed, 57 insertions(+), 46 deletions(-) diff --git a/apps/scheduler/call/sql/sql.py b/apps/scheduler/call/sql/sql.py index e31b605b1..e3ef25c71 100644 --- a/apps/scheduler/call/sql/sql.py +++ b/apps/scheduler/call/sql/sql.py @@ -27,7 +27,7 @@ from apps.scheduler.call.sql.schema import SQLInput, SQLOutput logger = logging.getLogger(__name__) -class SQL(CoreCall, input_type=SQLInput, output_type=SQLOutput): +class SQL(CoreCall, input_model=SQLInput, output_model=SQLOutput): """SQL工具。用于调用外置的Chat2DB工具的API,获得SQL语句;再在PostgreSQL中执行SQL语句,获得数据。""" database_url: str = Field(description="数据库连接地址") @@ -49,10 +49,8 @@ class SQL(CoreCall, input_type=SQLInput, output_type=SQLOutput): ).model_dump(by_alias=True, exclude_none=True) - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: - """运行SQL工具""" - # 获取必要参数 - data = SQLInput(**input_data) + async def _generate_sql(self, data: SQLInput) -> list[dict[str, Any]]: + """生成SQL语句列表""" post_data = { "database_url": self.database_url, "table_name_list": self.table_name_list, @@ -60,76 +58,89 @@ class SQL(CoreCall, input_type=SQLInput, output_type=SQLOutput): "topk": self.top_k, "use_llm_enhancements": self.use_llm_enhancements, } - headers = { - "Content-Type": "application/json", - } + headers = {"Content-Type": "application/json"} - # 生成sql,重试3次直到sql生成数量大于等于top_k sql_list = [] retry = 0 max_retry = 3 - while retry < max_retry: + + while retry < max_retry and len(sql_list) < self.top_k: try: async with aiohttp.ClientSession() as session, session.post( - Config().get_config().extra.sql_url+"/database/sql", + Config().get_config().extra.sql_url + "/database/sql", headers=headers, json=post_data, timeout=aiohttp.ClientTimeout(total=60), ) as response: if response.status == status.HTTP_200_OK: result = await response.json() - sub_sql_list = result["result"]["sql_list"] - sql_list += sub_sql_list - retry += 1 + if result["code"] == status.HTTP_200_OK: + sql_list.extend(result["result"]["sql_list"]) else: text = await response.text() logger.error("[SQL] 生成失败:%s", text) retry += 1 - continue except Exception: logger.exception("[SQL] 生成失败") retry += 1 - continue - if len(sql_list) >= self.top_k: - break - #执行sql,并将执行结果保存在sql_exec_results中 - sql_exec_results: Any = None + return sql_list + + + async def _execute_sql(self, sql_list: list[dict[str, Any]]) -> list[dict[str, Any]] | None: + """执行SQL语句并返回结果""" + headers = {"Content-Type": "application/json"} + for sql_dict in sql_list: - database_id = sql_dict["database_id"] - sql = sql_dict["sql"] - post_data = { - "database_id": database_id, - "sql": sql, - } try: async with aiohttp.ClientSession() as session, session.post( - Config().get_config().extra.sql_url+"/sql/execute", + Config().get_config().extra.sql_url + "/sql/execute", headers=headers, - json=post_data, + json={ + "database_id": sql_dict["database_id"], + "sql": sql_dict["sql"], + }, timeout=aiohttp.ClientTimeout(total=60), ) as response: if response.status == status.HTTP_200_OK: result = await response.json() - sql_exec_results = result["result"] - break + if result["code"] == status.HTTP_200_OK: + return result["result"] else: text = await response.text() logger.error("[SQL] 调用失败:%s", text) - continue except Exception: logger.exception("[SQL] 调用失败") - continue - if sql_exec_results is not None: - data = SQLOutput( - dataset=sql_exec_results, - ).model_dump(exclude_none=True, by_alias=True) - yield CallOutputChunk( - content=data, - type=CallOutputType.DATA, + + return None + + + async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: + """运行SQL工具""" + data = SQLInput(**input_data) + + # 生成SQL语句 + sql_list = await self._generate_sql(data) + if not sql_list: + raise CallError( + message="SQL查询错误:无法生成有效的SQL语句!", + data={}, + ) + + # 执行SQL语句 + sql_exec_results = await self._execute_sql(sql_list) + if sql_exec_results is None: + raise CallError( + message="SQL查询错误:SQL语句执行失败!", + data={}, ) - return - raise CallError( - message="SQL查询错误:SQL语句错误,数据库查询失败!", - data={}, + + # 返回结果 + data = SQLOutput( + dataset=sql_exec_results, + ).model_dump(exclude_none=True, by_alias=True) + + yield CallOutputChunk( + content=data, + type=CallOutputType.DATA, ) diff --git a/apps/scheduler/call/suggest/suggest.py b/apps/scheduler/call/suggest/suggest.py index 903208187..28ea2b438 100644 --- a/apps/scheduler/call/suggest/suggest.py +++ b/apps/scheduler/call/suggest/suggest.py @@ -5,7 +5,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ from collections.abc import AsyncGenerator -from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Self +from typing import TYPE_CHECKING, Any, Self from pydantic import Field @@ -24,7 +24,7 @@ if TYPE_CHECKING: from apps.scheduler.executor.step import StepExecutor -class Suggestion(CoreCall, input_type=SuggestionInput, output_type=SuggestionOutput): +class Suggestion(CoreCall, input_model=SuggestionInput, output_model=SuggestionOutput): """问题推荐""" configs: list[SingleFlowSuggestionConfig] = Field(description="问题推荐配置", min_length=1) diff --git a/apps/scheduler/call/summary/summary.py b/apps/scheduler/call/summary/summary.py index 8b5ba0a76..a4514f7c1 100644 --- a/apps/scheduler/call/summary/summary.py +++ b/apps/scheduler/call/summary/summary.py @@ -25,7 +25,7 @@ if TYPE_CHECKING: -class Summary(CoreCall, input_type=SummaryInput, output_type=SummaryOutput): +class Summary(CoreCall, input_model=SummaryInput, output_model=SummaryOutput): """总结工具""" context: ExecutorBackground = Field(description="对话上下文") -- Gitee From 06cc33ca6469140fa9644bc810342fe7ff1a5baa Mon Sep 17 00:00:00 2001 From: z30057876 Date: Fri, 18 Apr 2025 15:09:10 +0800 Subject: [PATCH 7/7] =?UTF-8?q?=E6=9B=B4=E6=96=B0Search?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/call/api/api.py | 2 +- apps/scheduler/call/search/schema.py | 0 apps/scheduler/call/search/search.py | 31 ++++++++++++++++++++++++++++ apps/scheduler/call/slot/schema.py | 6 +++++- apps/scheduler/call/slot/slot.py | 2 +- 5 files changed, 38 insertions(+), 3 deletions(-) create mode 100644 apps/scheduler/call/search/schema.py create mode 100644 apps/scheduler/call/search/search.py diff --git a/apps/scheduler/call/api/api.py b/apps/scheduler/call/api/api.py index d0573d29d..38d732e91 100644 --- a/apps/scheduler/call/api/api.py +++ b/apps/scheduler/call/api/api.py @@ -52,7 +52,7 @@ SUCCESS_HTTP_CODES = [ ] -class API(CoreCall, input_type=APIInput, output_type=APIOutput): +class API(CoreCall, input_model=APIInput, output_model=APIOutput): """API调用工具""" url: str = Field(description="API接口的完整URL") diff --git a/apps/scheduler/call/search/schema.py b/apps/scheduler/call/search/schema.py new file mode 100644 index 000000000..e69de29bb diff --git a/apps/scheduler/call/search/search.py b/apps/scheduler/call/search/search.py new file mode 100644 index 000000000..e12d38580 --- /dev/null +++ b/apps/scheduler/call/search/search.py @@ -0,0 +1,31 @@ +"""搜索工具""" +from typing import Any, ClassVar + +from pydantic import BaseModel, Field + +from apps.entities.scheduler import CallVars +from apps.scheduler.call.core import CoreCall + + +class SearchRet(BaseModel): + """搜索工具返回值""" + + data: list[dict[str, Any]] = Field(description="搜索结果") + + +class Search(CoreCall, ret_type=SearchRet): + """搜索工具""" + + name: ClassVar[str] = "搜索" + description: ClassVar[str] = "获取搜索引擎的结果" + + async def _init(self, syscall_vars: CallVars, **kwargs: Any) -> dict[str, Any]: + """初始化工具""" + self._query: str = kwargs["query"] + return {} + + + async def _exec(self) -> dict[str, Any]: + """执行工具""" + return {} + diff --git a/apps/scheduler/call/slot/schema.py b/apps/scheduler/call/slot/schema.py index 7d053aa58..08373cd96 100644 --- a/apps/scheduler/call/slot/schema.py +++ b/apps/scheduler/call/slot/schema.py @@ -7,7 +7,11 @@ from apps.scheduler.call.core import DataBase SLOT_GEN_PROMPT = [ {"role": "user", "content": ""}, - {"role": "assistant", "content": ""}, + {"role": "assistant", "content": r""" + + + """, + }, ] diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py index 5d0a29ee6..07640490a 100644 --- a/apps/scheduler/call/slot/slot.py +++ b/apps/scheduler/call/slot/slot.py @@ -18,7 +18,7 @@ if TYPE_CHECKING: from apps.scheduler.executor.step import StepExecutor -class Slot(CoreCall, input_type=SlotInput, output_type=SlotOutput): +class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput): """参数填充工具""" data: dict[str, Any] = Field(description="当前输入", default={}) -- Gitee