From d68e85f243eee587159dce8394fb9ab01582e801 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Tue, 18 Mar 2025 08:03:33 +0800 Subject: [PATCH 1/8] =?UTF-8?q?=E5=90=8C=E6=AD=A5=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E7=BB=93=E6=9E=84=E6=94=B9=E5=8A=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/entities/collection.py | 2 +- apps/entities/enum_var.py | 18 ++++++++---------- apps/entities/message.py | 9 +-------- apps/entities/scheduler.py | 18 ++++++++++++------ apps/entities/task.py | 18 ++++++++++++------ apps/routers/chat.py | 2 +- 6 files changed, 35 insertions(+), 32 deletions(-) diff --git a/apps/entities/collection.py b/apps/entities/collection.py index 01042e481..8410e1e35 100644 --- a/apps/entities/collection.py +++ b/apps/entities/collection.py @@ -124,6 +124,7 @@ class RecordContent(BaseModel): question: str answer: str data: dict[str, Any] = {} + facts: list[str] = Field(description="[运行后修改]与Record关联的事实信息", default=[]) class RecordComment(BaseModel): @@ -144,7 +145,6 @@ class Record(BaseModel): data: str key: dict[str, Any] = {} flow: list[str] = Field(description="[运行后修改]与Record关联的FlowHistory的ID", default=[]) - facts: list[str] = Field(description="[运行后修改]与Record关联的事实信息", default=[]) comment: Optional[RecordComment] = None metadata: Optional[RecordMetadata] = None created_at: float = Field(default_factory=lambda: round(datetime.now(tz=timezone.utc).timestamp(), 3)) diff --git a/apps/entities/enum_var.py b/apps/entities/enum_var.py index 328dd626d..8c30ca234 100644 --- a/apps/entities/enum_var.py +++ b/apps/entities/enum_var.py @@ -32,22 +32,13 @@ class DocumentStatus(str, Enum): FAILED = "failed" -class FlowOutputType(str, Enum): - """Flow输出类型""" - - CODE = "code" - CHART = "chart" - URL = "url" - SCHEMA = "schema" - NONE = "none" - - class EventType(str, Enum): """事件类型""" HEARTBEAT = "heartbeat" INIT = "init" TEXT_ADD = "text.add" + GRAPH = "graph" DOCUMENT_ADD = "document.add" SUGGEST = "suggest" FLOW_START = "flow.start" @@ -134,3 +125,10 @@ class ContentType(str, Enum): JSON = "application/json" FORM_URLENCODED = "application/x-www-form-urlencoded" MULTIPART_FORM_DATA = "multipart/form-data" + + +class CallOutputType(str, Enum): + """Call输出类型""" + + TEXT = "text" + DATA = "data" diff --git a/apps/entities/message.py b/apps/entities/message.py index 75bc490f1..8464c2d2c 100644 --- a/apps/entities/message.py +++ b/apps/entities/message.py @@ -7,7 +7,7 @@ from typing import Any, Optional from pydantic import BaseModel, Field from apps.entities.collection import RecordMetadata -from apps.entities.enum_var import EventType, FlowOutputType, StepStatus +from apps.entities.enum_var import EventType, StepStatus class HeartbeatData(BaseModel): @@ -81,13 +81,6 @@ class FlowStartContent(BaseModel): params: dict[str, Any] = Field(description="预先提供的参数") -class FlowStopContent(BaseModel): - """flow.stop消息的content""" - - type: Optional[FlowOutputType] = Field(description="Flow输出的类型", default=None) - data: Optional[dict[str, Any]] = Field(description="Flow输出的数据", default=None) - - class MessageBase(HeartbeatData): """基础消息事件结构""" diff --git a/apps/entities/scheduler.py b/apps/entities/scheduler.py index 8f0cf69a5..0b0f4de19 100644 --- a/apps/entities/scheduler.py +++ b/apps/entities/scheduler.py @@ -2,18 +2,16 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
""" -from typing import Any +from typing import Any, Union from pydantic import BaseModel, Field +from apps.entities.enum_var import CallOutputType from apps.entities.task import FlowStepHistory class CallVars(BaseModel): - """所有Call都需要接受的参数。包含用户输入、上下文信息、Step的输出记录等 - - 这一部分的参数由Executor填充,用户无法修改 - """ + """由Executor填充的变量,即“系统变量”""" summary: str = Field(description="上下文信息") user_sub: str = Field(description="用户ID") @@ -22,12 +20,13 @@ class CallVars(BaseModel): task_id: str = Field(description="任务ID") flow_id: str = Field(description="Flow ID") session_id: str = Field(description="当前用户的Session ID") + service_id: str = Field(description="语义接口ID") class ExecutorBackground(BaseModel): """Executor的背景信息""" - conversation: str = Field(description="当前Executor的背景信息") + conversation: list[dict[str, str]] = Field(description="对话记录") facts: list[str] = Field(description="当前Executor的背景信息") @@ -38,3 +37,10 @@ class CallError(Exception): """获取Call错误中的数据""" self.message = message self.data = data + + +class CallOutputChunk(BaseModel): + """Call的输出""" + + type: CallOutputType = Field(description="输出类型") + content: Union[str, dict[str, Any]] = Field(description="输出内容") diff --git a/apps/entities/task.py b/apps/entities/task.py index 4d61e32cb..d704c4829 100644 --- a/apps/entities/task.py +++ b/apps/entities/task.py @@ -39,21 +39,27 @@ class ExecutorState(BaseModel): step_id: str = Field(description="当前步骤ID") step_name: str = Field(description="当前步骤名称") app_id: str = Field(description="应用ID") - # 运行时数据 - ai_summary: str = Field(description="大模型的思考内容", default="") - filled_data: dict[str, Any] = Field(description="待使用的参数", default={}) remaining_schema: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={}) +class TaskRuntimeData(BaseModel): + """Task的运行时数据""" + + session_id: str = Field(description="浏览器会话ID", default="") + user_sub: str = Field(description="用户ID", default="") + summary: str = Field(description="对对话记录的总结", default="") + filled_data: dict[str, Any] = Field(description="已经填充的参数", default={}) + new_context: list[str] = Field(description="本次运行中新产生的步骤历史", default=[]) + reached_end: bool = Field(description="是否到达Flow的终点", default=False) + + class TaskBlock(BaseModel): """内存中的Task块,不存储在数据库中""" - session_id: str = Field(description="浏览器会话ID") - user_sub: str = Field(description="用户ID", default="") record: RecordData = Field(description="当前任务执行过程关联的Record") flow_state: Optional[ExecutorState] = Field(description="Flow的状态", default=None) flow_context: dict[str, FlowStepHistory] = Field(description="Flow的执行信息", default={}) - new_context: list[str] = Field(description="Flow的执行信息(增量ID)", default=[]) + runtime_data: TaskRuntimeData = Field(description="Task的运行时数据", default_factory=lambda: TaskRuntimeData()) class RequestDataApp(BaseModel): diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 16391241e..2451bcd91 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -55,7 +55,7 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) task = await task_actor.get_task.remote(session_id=session_id, post_body=post_body) task_id = task.record.task_id - task.user_sub = user_sub + task.runtime_data.user_sub = user_sub task.record.group_id = group_id post_body.group_id = group_id await task_actor.set_task.remote(task_id, task) -- Gitee From 20490eb7a2228f59d52be3c9453fa10612a94375 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Tue, 18 Mar 2025 08:04:52 +0800 Subject: [PATCH 2/8] =?UTF-8?q?=E5=90=8C=E6=AD=A5utils=E5=92=8Cslot?= =?UTF-8?q?=E6=94=B9=E5=8A=A8?= 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/constants.py | 65 ------------- apps/entities/node.py | 6 +- apps/scheduler/slot/parser/const.py | 7 ++ apps/scheduler/slot/parser/default.py | 7 ++ apps/scheduler/slot/slot.py | 57 +----------- apps/utils/flow.py | 129 +++++++++++++++++++------- 6 files changed, 116 insertions(+), 155 deletions(-) diff --git a/apps/constants.py b/apps/constants.py index 6b05e30bb..85ad1725b 100644 --- a/apps/constants.py +++ b/apps/constants.py @@ -4,9 +4,6 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ from __future__ import annotations -import logging -from textwrap import dedent - # 新对话默认标题 NEW_CHAT = "New Chat" # 滑动窗口限流 默认窗口期 @@ -27,8 +24,6 @@ APP_DIR = "app" FLOW_DIR = "flow" # Scheduler进程数 SCHEDULER_REPLICAS = 2 -# 日志记录器 -LOGGER = logging.getLogger("ray") REASONING_BEGIN_TOKEN = [ "", @@ -36,63 +31,3 @@ REASONING_BEGIN_TOKEN = [ REASONING_END_TOKEN = [ "", ] -LLM_CONTEXT_PROMPT = dedent( - r""" - 以下是对用户和AI间对话的简短总结: - {{ summary }} - - 你作为AI,在回答用户的问题前,需要获取必要信息。为此,你调用了以下工具,并获得了输出: - - {% for tool in history_data %} - {{ tool.step_id }} - {{ tool.output_data }} - {% endfor %} - - """, - ).strip("\n") -LLM_DEFAULT_PROMPT = dedent( - r""" - - 你是一个乐于助人的智能助手。请结合给出的背景信息, 回答用户的提问。 - 当前时间:{{ time }},可以作为时间参照。 - 用户的问题将在中给出,上下文背景信息将在中给出。 - 注意:输出不要包含任何XML标签,不要编造任何信息。若你认为用户提问与背景信息无关,请忽略背景信息直接作答。 - - - - {{ question }} - - - - {{ context }} - - """, - ).strip("\n") -RAG_ANSWER_PROMPT = dedent( - r""" - - 你是由openEuler社区构建的大型语言AI助手。请根据背景信息(包含对话上下文和文档片段),回答用户问题。 - 用户的问题将在中给出,上下文背景信息将在中给出,文档片段将在中给出。 - - 注意事项: - 1. 输出不要包含任何XML标签。请确保输出内容的正确性,不要编造任何信息。 - 2. 如果用户询问你关于你自己的问题,请统一回答:“我叫EulerCopilot,是openEuler社区的智能助手”。 - 3. 背景信息仅供参考,若背景信息与用户问题无关,请忽略背景信息直接作答。 - 4. 请在回答中使用Markdown格式,并**不要**将内容放在"```"中。 - - - - {{ question }} - - - - {{ context }} - - - - {{ document }} - - - 现在,请根据上述信息,回答用户的问题: - """, - ).strip("\n") diff --git a/apps/entities/node.py b/apps/entities/node.py index 03aab2e7e..c51b2b53d 100644 --- a/apps/entities/node.py +++ b/apps/entities/node.py @@ -9,14 +9,14 @@ from apps.entities.pool import NodePool class APINodeInput(BaseModel): """API节点覆盖输入""" - param_schema: Optional[dict[str, Any]] = Field(description="API节点输入参数Schema", default=None) - body_schema: Optional[dict[str, Any]] = Field(description="API节点输入请求体Schema", default=None) + query: Optional[dict[str, Any]] = Field(description="API节点输入参数Schema", default=None) + body: Optional[dict[str, Any]] = Field(description="API节点输入请求体Schema", default=None) class APINodeOutput(BaseModel): """API节点覆盖输出""" - resp_schema: Optional[dict[str, Any]] = Field(description="API节点输出Schema", default=None) + result: Optional[dict[str, Any]] = Field(description="API节点输出Schema", default=None) class APINode(NodePool): diff --git a/apps/scheduler/slot/parser/const.py b/apps/scheduler/slot/parser/const.py index 9c0329eb4..a008fc1d9 100644 --- a/apps/scheduler/slot/parser/const.py +++ b/apps/scheduler/slot/parser/const.py @@ -4,6 +4,8 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ from typing import Any +from jsonschema import Validator + from apps.entities.enum_var import SlotType @@ -21,3 +23,8 @@ class SlotConstParser: """ raise NotImplementedError + @classmethod + def keyword_validate(cls, validator: Validator, keyword: str, instance: Any, schema: dict[str, Any]) -> bool: + """生成对应类型的验证器""" + ... 
+ diff --git a/apps/scheduler/slot/parser/default.py b/apps/scheduler/slot/parser/default.py index 998ebba92..b8edb293e 100644 --- a/apps/scheduler/slot/parser/default.py +++ b/apps/scheduler/slot/parser/default.py @@ -4,6 +4,8 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ from typing import Any +from jsonschema import Validator + from apps.entities.enum_var import SlotType @@ -21,3 +23,8 @@ class SlotDefaultParser: """ raise NotImplementedError + @classmethod + def keyword_validate(cls, validator: Validator, keyword: str, instance: Any, schema: dict[str, Any]) -> bool: + """给字段设置默认值""" + ... + diff --git a/apps/scheduler/slot/slot.py b/apps/scheduler/slot/slot.py index fd620d696..0d2bd91ca 100644 --- a/apps/scheduler/slot/slot.py +++ b/apps/scheduler/slot/slot.py @@ -6,15 +6,13 @@ import json import logging import traceback from collections.abc import Mapping -from copy import deepcopy -from typing import Any, Optional, Union +from typing import Any, Union from jsonschema import Draft7Validator from jsonschema.exceptions import ValidationError from jsonschema.protocols import Validator from jsonschema.validators import extend -from apps.llm.patterns.json_gen import Json from apps.scheduler.slot.parser import ( SlotConstParser, SlotDateParser, @@ -32,8 +30,8 @@ _TYPE_CHECKER = [ ] _FORMAT_CHECKER = [] _KEYWORD_CHECKER = { - "const": SlotConstParser, - "default": SlotDefaultParser, + "const": SlotConstParser.keyword_validate, + "default": SlotDefaultParser.keyword_validate, } # 各类转换器 @@ -258,52 +256,3 @@ class Slot: return schema_template return {} - - - @staticmethod - async def llm_param_gen(task_id: str, question: str, thought: str, previous_output: Optional[dict[str, Any]], remaining_schema: dict[str, Any]) -> dict[str, Any]: - """使用LLM生成JSON参数""" - # 组装工具消息 - conversation = [ - {"role": "user", "content": question}, - {"role": "assistant", "content": thought}, - ] - - if previous_output is not None: - tool_str = f"""我使用了一个工具从其他来源获取额外信息。\ - 工具的输出数据是 `{previous_output}`。 - 输出的schema是 `{json.dumps(previous_output["output_schema"], ensure_ascii=False)}`,其中包含了输出的描述信息。 - """ - - conversation.append({"role": "tool", "content": tool_str}) - - return await Json().generate(task_id, conversation=conversation, spec=remaining_schema, strict=False) - - - async def process(self, previous_json: dict[str, Any], new_json: dict[str, Any], llm_params: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: - """对参数槽进行综合处理,返回剩余的JSON Schema和填充后的JSON""" - # 将用户手动填充的参数专为真实JSON - slot_data = self.convert_json(new_json) - # 合并 - result_json = deepcopy(previous_json) - result_json.update(slot_data) - # 检测槽位是否合法、是否填充完成 - remaining_slot = self.check_json(result_json) - # 如果还有未填充的部分,则尝试使用LLM生成 - if remaining_slot: - generated_slot = await Slot.llm_param_gen( - llm_params["task_id"], - llm_params["question"], - llm_params["thought"], - llm_params["previous_output"], - remaining_slot, - ) - # 合并 - generated_slot = self.convert_json(generated_slot) - result_json.update(generated_slot) - # 再次检查槽位 - remaining_slot = self.check_json(result_json) - return remaining_slot, result_json - - return {}, result_json - diff --git a/apps/utils/flow.py b/apps/utils/flow.py index bac16f3f9..107879384 100644 --- a/apps/utils/flow.py +++ b/apps/utils/flow.py @@ -2,6 +2,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
""" +import collections import logging from apps.entities.enum_var import NodeType @@ -13,43 +14,59 @@ logger = logging.getLogger("ray") class FlowService: """flow拓扑相关函数""" + @staticmethod + def _validate_branch_id(node_name: str, branch_id: str, node_branches: set, branch_illegal_chars: str = ".") -> None: + """验证分支ID的合法性;当分支ID重复或包含非法字符时抛出异常""" + if branch_id in node_branches: + err = f"[FlowService] 节点{node_name}的分支{branch_id}重复" + logger.error(err) + raise Exception(err) + + for illegal_char in branch_illegal_chars: + if illegal_char in branch_id: + err = f"[FlowService] 节点{node_name}的分支{branch_id}名称中含有非法字符" + logger.error(err) + raise Exception(err) + + @staticmethod + def _process_choice_node(node: NodeItem, node_to_branches: dict) -> None: + """处理选择节点的分支信息""" + node.parameters = node.parameters["input_parameters"] + if "choices" not in node.parameters: + node.parameters["choices"] = [] + + for branch in node.parameters["choices"]: + FlowService._validate_branch_id(node.name, branch["branchId"], node_to_branches[node.step_id]) + node_to_branches[node.step_id].add(branch["branchId"]) + + @staticmethod + def _filter_valid_edges(flow_item: FlowItem, node_to_branches: dict) -> list[EdgeItem]: + """过滤有效的边""" + return [ + edge for edge in flow_item.edges + if (edge.source_node in node_to_branches and + edge.target_node in node_to_branches and + edge.branch_id in node_to_branches[edge.source_node]) + ] + @staticmethod async def remove_excess_structure_from_flow(flow_item: FlowItem) -> FlowItem: """移除流程图中的多余结构""" node_to_branches = {} branch_illegal_chars = "." + # 处理所有节点的分支信息 for node in flow_item.nodes: - node_to_branches[node.step_id] = set() if node.call_id == NodeType.CHOICE.value: - node.parameters = node.parameters["input_parameters"] - if "choices" not in node.parameters: - node.parameters["choices"] = [] - for branch in node.parameters["choices"]: - if branch["branchId"] in node_to_branches[node.step_id]: - err = f"[FlowService] 节点{node.name}的分支{branch['branchId']}重复" - logger.error(err) - raise Exception(err) - for illegal_char in branch_illegal_chars: - if illegal_char in branch["branchId"]: - err = f"[FlowService] 节点{node.name}的分支{branch['branchId']}名称中含有非法字符" - logger.error(err) - raise Exception(err) - node_to_branches[node.step_id].add(branch["branchId"]) + FlowService._process_choice_node(node, node_to_branches) else: node_to_branches[node.step_id].add("") - new_edges_items = [] - for edge in flow_item.edges: - if edge.source_node not in node_to_branches: - continue - if edge.target_node not in node_to_branches: - continue - if edge.branch_id not in node_to_branches[edge.source_node]: - continue - new_edges_items.append(edge) - flow_item.edges = new_edges_items + + # 过滤有效的边 + flow_item.edges = FlowService._filter_valid_edges(flow_item, node_to_branches) return flow_item + @staticmethod async def _validate_node_ids(nodes: list[NodeItem]) -> tuple[str, str]: """验证节点ID的唯一性并获取起始和终止节点ID,当节点ID重复或起始/终止节点数量不为1时抛出异常""" @@ -84,6 +101,21 @@ class FlowService: return start_id, end_id + @staticmethod + async def validate_flow_illegal(flow_item: FlowItem) -> tuple[str, str]: + """验证流程图是否合法;当流程图不合法时抛出异常""" + # 验证节点ID并获取起始和终止节点 + start_id, end_id = await FlowService._validate_node_ids(flow_item.nodes) + + # 验证边的合法性并获取节点的入度和出度 + in_deg, out_deg = await FlowService._validate_edges(flow_item.edges) + + # 验证起始和终止节点的入度和出度 + await FlowService._validate_node_degrees(start_id, end_id, in_deg, out_deg) + + return start_id, end_id + + @staticmethod async def _validate_edges(edges: list[EdgeItem]) -> 
tuple[dict[str, int], dict[str, int]]: """验证边的合法性并计算节点的入度和出度;当边的ID重复、起始终止节点相同或分支重复时抛出异常""" @@ -117,6 +149,7 @@ class FlowService: return in_deg, out_deg + @staticmethod async def _validate_node_degrees(start_id: str, end_id: str, in_deg: dict[str, int], out_deg: dict[str, int]) -> None: @@ -130,17 +163,46 @@ class FlowService: logger.error(err) raise Exception(err) + @staticmethod - async def validate_flow_illegal(flow_item: FlowItem) -> None: - """验证流程图是否合法;当流程图不合法时抛出异常""" - # 验证节点ID并获取起始和终止节点 - start_id, end_id = await FlowService._validate_node_ids(flow_item.nodes) + def _find_start_node_id(nodes: list[NodeItem]) -> str: + """查找起始节点ID""" + for node in nodes: + if node.call_id == NodeType.START.value: + return node.step_id + return "" - # 验证边的合法性并获取节点的入度和出度 - in_deg, out_deg = await FlowService._validate_edges(flow_item.edges) - # 验证起始和终止节点的入度和出度 - await FlowService._validate_node_degrees(start_id, end_id, in_deg, out_deg) + @staticmethod + def _build_adjacency_list(edges: list[EdgeItem]) -> dict[str, list[str]]: + """构建邻接表""" + adj_list = {} + for edge in edges: + if edge.source_node not in adj_list: + adj_list[edge.source_node] = [] + adj_list[edge.source_node].append(edge.target_node) + return adj_list + + + @staticmethod + def _bfs_traverse(start_id: str, adj_list: dict[str, list[str]]) -> set[str]: + """使用BFS遍历图并返回可达节点集合""" + visited = set() + if not start_id: + return visited + + queue = collections.deque([start_id]) + visited.add(start_id) + + while queue: + current = queue.popleft() + if current in adj_list: + for neighbor in adj_list[current]: + if neighbor not in visited: + visited.add(neighbor) + queue.append(neighbor) + return visited + @staticmethod async def validate_flow_connectivity(flow_item: FlowItem) -> bool: @@ -188,6 +250,7 @@ class FlowService: return end_id in vis + def generate_from_schema(schema: dict) -> dict: """根据 JSON Schema 生成示例 JSON 数据。 -- Gitee From 8c1d1516c56b93e589390cce6259dc3678510241 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Tue, 18 Mar 2025 08:05:19 +0800 Subject: [PATCH 3/8] =?UTF-8?q?=E5=90=8C=E6=AD=A5Scheduler=E6=94=B9?= =?UTF-8?q?=E5=8A=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/scheduler/context.py | 52 ++++++++------------------- apps/scheduler/scheduler/flow.py | 2 +- apps/scheduler/scheduler/message.py | 5 +-- apps/scheduler/scheduler/scheduler.py | 26 +++++++------- 4 files changed, 31 insertions(+), 54 deletions(-) diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index 06fb8b373..241e6ecf0 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -24,6 +24,11 @@ logger = logging.getLogger("ray") async def get_docs(user_sub: str, post_body: RequestData) -> tuple[Union[list[RecordDocument], list[Document]], list[str]]: """获取当前问答可供关联的文档""" + if not post_body.group_id: + err = "[Scheduler] 问答组ID不能为空!" 
+ logger.error(err) + raise ValueError(err) + doc_ids = [] docs = await DocumentManager.get_used_docs_by_record_group(user_sub, post_body.group_id) @@ -41,21 +46,7 @@ async def get_docs(user_sub: str, post_body: RequestData) -> tuple[Union[list[Re return docs, doc_ids -async def get_rag_context(user_sub: str, post_body: RequestData, n: int) -> list[dict[str, str]]: - """获取当前问答的上下文信息""" - # 组装RAG请求数据,备用 - records = await RecordManager.query_record_by_conversation_id(user_sub, post_body.conversation_id, n + 5) - - context = [] - for record in records: - record_data = RecordContent.model_validate_json(Security.decrypt(record.data, record.key)) - context.append({"role": "user", "content": record_data.question}) - context.append({"role": "assistant", "content": record_data.answer}) - - return context - - -async def get_context(user_sub: str, post_body: RequestData, n: int) -> tuple[str, list[str]]: +async def get_context(user_sub: str, post_body: RequestData, n: int) -> tuple[list[dict[str, str]], list[str]]: """获取当前问答的上下文信息 注意:这里的n要比用户选择的多,因为要考虑事实信息和历史问题 @@ -65,27 +56,17 @@ async def get_context(user_sub: str, post_body: RequestData, n: int) -> tuple[st # 获取最后n+5条Record records = await RecordManager.query_record_by_conversation_id(user_sub, post_body.conversation_id, n + 5) - # 获取事实信息 - facts = [] - for record in records: - facts.extend(record.facts) # 组装问答 - messages = "" + context = [] + facts = [] for record in records: record_data = RecordContent.model_validate_json(Security.decrypt(record.data, record.key)) + context.append({"role": "user", "content": record_data.question}) + context.append({"role": "assistant", "content": record_data.answer}) + facts.extend(record_data.facts) - messages += f""" - - {record_data.question} - - - {record_data.answer} - - """ - messages += "" - - return messages, facts + return context, facts async def generate_facts(task: TaskBlock, question: str) -> list[str]: @@ -115,7 +96,7 @@ async def save_data(task_id: str, user_sub: str, post_body: RequestData, used_do # 循环创建FlowHistory history_data = [] # 遍历查找数据,并添加 - for history_id in task.new_context: + for history_id in task.runtime_data.new_context: for history in task.flow_context.values(): if history.id == history_id: history_data.append(history) @@ -125,20 +106,15 @@ async def save_data(task_id: str, user_sub: str, post_body: RequestData, used_do # 修改metadata里面时间为实际运行时间 task.record.metadata.time_cost = round(datetime.now(timezone.utc).timestamp() - task.record.metadata.time_cost, 2) - # 提取facts - # 记忆提取 - facts = await generate_facts(task, post_body.question) - # 整理Record数据 record = Record( record_id=task.record.id, user_sub=user_sub, data=encrypt_data, key=encrypt_config, - facts=facts, metadata=task.record.metadata, created_at=task.record.created_at, - flow=task.new_context, + flow=task.runtime_data.new_context, ) record_group = task.record.group_id diff --git a/apps/scheduler/scheduler/flow.py b/apps/scheduler/scheduler/flow.py index 0abf16875..f87c133a7 100644 --- a/apps/scheduler/scheduler/flow.py +++ b/apps/scheduler/scheduler/flow.py @@ -7,10 +7,10 @@ from typing import Optional import ray -from apps.constants import RAG_ANSWER_PROMPT from apps.entities.flow import Flow, FlowError, Step from apps.entities.task import RequestDataApp from apps.llm.patterns import Select +from apps.scheduler.call.llm.schema import RAG_ANSWER_PROMPT logger = logging.getLogger("ray") diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index b246bf604..2e747118c 100644 --- 
a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -67,7 +67,7 @@ async def push_rag_message(task: TaskBlock, queue: actor.ActorHandle, user_sub: async for chunk in RAG.get_rag_result(user_sub, rag_data): task, chunk_content = await _push_rag_chunk(task, queue, chunk) full_answer += chunk_content - # FIXME: 这里由于后面没有其他消息,所以只能每个trunk更新task + # TODO: 这里由于后面没有其他消息,所以只能每个trunk更新task await task_actor.set_task.remote(task.record.task_id, task) # 保存答案 @@ -97,10 +97,11 @@ async def _push_rag_chunk(task: TaskBlock, queue: actor.ActorHandle, content: st event_type=EventType.TEXT_ADD, data=TextAddContent(text=content_obj.content).model_dump(exclude_none=True, by_alias=True), ) - return task, content_obj.content except Exception: logger.exception("[Scheduler] RAG服务返回错误数据") return task, "" + else: + return task, content_obj.content async def push_document_message(task: TaskBlock, queue: actor.ActorHandle, doc: Union[RecordDocument, Document]) -> TaskBlock: diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index a088618ef..70c4ca0ad 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -16,8 +16,8 @@ from apps.entities.task import SchedulerResult, TaskBlock from apps.manager.appcenter import AppCenterManager from apps.manager.flow import FlowManager from apps.manager.user import UserManager -from apps.scheduler.executor.flow import Executor -from apps.scheduler.scheduler.context import get_context, get_docs, get_rag_context +from apps.scheduler.executor.flow import FlowExecutor +from apps.scheduler.scheduler.context import get_context, get_docs from apps.scheduler.scheduler.flow import FlowChooser from apps.scheduler.scheduler.message import ( push_document_message, @@ -60,15 +60,6 @@ class Scheduler: await queue.close.remote() # type: ignore[attr-defined] return SchedulerResult(used_docs=[]) - # 组装RAG请求数据,备用 - rag_data = RAGQueryReq( - question=post_body.question, - language=post_body.language, - document_ids=doc_ids, - kb_sn=None if not user_info.kb_id else user_info.kb_id, - top_k=5, - history=await get_rag_context(user_sub, post_body, 3), - ) # 已使用文档 used_docs = [] @@ -82,7 +73,16 @@ class Scheduler: task = await push_document_message(task, queue, doc) await asyncio.sleep(0.1) - # 保存有数据的最后一条消息 + # 调用RAG + history, _ = await get_context(user_sub, post_body, 3) + rag_data = RAGQueryReq( + question=post_body.question, + language=post_body.language, + document_ids=doc_ids, + kb_sn=None if not user_info.kb_id else user_info.kb_id, + top_k=5, + history=history, + ) task = await push_rag_message(task, queue, user_sub, rag_data) else: # 查找对应的App元数据 @@ -149,7 +149,7 @@ class Scheduler: # 初始化Executor logger.info("[Scheduler] 初始化Executor") - flow_exec = Executor( + flow_exec = FlowExecutor( flow_id=flow_id, flow=flow_data, task=task, -- Gitee From 22fdde8d06e9c7f7b71b3628cc90b36531d7dd18 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Tue, 18 Mar 2025 08:06:17 +0800 Subject: [PATCH 4/8] =?UTF-8?q?=E5=90=8C=E6=AD=A5Executor=E6=94=B9?= =?UTF-8?q?=E5=8A=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/task.py | 10 +- apps/manager/appcenter.py | 6 +- apps/manager/node.py | 20 +- apps/manager/token.py | 1 - apps/scheduler/executor/base.py | 81 ++++++++ apps/scheduler/executor/flow.py | 323 +++++++++++------------------ apps/scheduler/executor/hook.py | 75 +++++++ apps/scheduler/executor/message.py | 131 ------------ 
apps/scheduler/executor/node.py | 52 +++++ apps/scheduler/executor/step.py | 193 +++++++++++++++++ 10 files changed, 551 insertions(+), 341 deletions(-) create mode 100644 apps/scheduler/executor/base.py create mode 100644 apps/scheduler/executor/hook.py delete mode 100644 apps/scheduler/executor/message.py create mode 100644 apps/scheduler/executor/node.py create mode 100644 apps/scheduler/executor/step.py diff --git a/apps/common/task.py b/apps/common/task.py index 8f4696241..f937fefad 100644 --- a/apps/common/task.py +++ b/apps/common/task.py @@ -16,7 +16,7 @@ from apps.entities.record import ( RecordMetadata, ) from apps.entities.request_data import RequestData -from apps.entities.task import TaskBlock, TaskData +from apps.entities.task import TaskBlock, TaskData, TaskRuntimeData from apps.manager.task import TaskManager from apps.models.mongo import MongoDB @@ -77,8 +77,10 @@ class Task: new_record.task_id = task_id self._task_map[task_id] = TaskBlock( - session_id=session_id, record=new_record, + runtime_data=TaskRuntimeData( + session_id=session_id, + ), ) return deepcopy(self._task_map[task_id]) @@ -86,9 +88,11 @@ class Task: task_id = task.id new_record.task_id = task_id self._task_map[task_id] = TaskBlock( - session_id=session_id, record=new_record, flow_state=task.state, + runtime_data=TaskRuntimeData( + session_id=session_id, + ), ) return deepcopy(self._task_map[task_id]) diff --git a/apps/manager/appcenter.py b/apps/manager/appcenter.py index 29c45f5bf..8ee6e0ef8 100644 --- a/apps/manager/appcenter.py +++ b/apps/manager/appcenter.py @@ -383,6 +383,8 @@ class AppCenterManager: :param app_id: 应用唯一标识 :return: 更新是否成功 """ + if not app_id: + return True try: user_collection = MongoDB.get_collection("user") current_time = round(datetime.now(tz=timezone.utc).timestamp(), 3) @@ -402,10 +404,10 @@ class AppCenterManager: logger.info("[AppCenterManager] 更新用户最近使用应用: %s", user_sub) return True logger.warning("[AppCenterManager] 用户 %s 没有更新", user_sub) - return False except Exception: logger.exception("[AppCenterManager] 更新用户最近使用应用失败") - return False + + return False @staticmethod async def get_default_flow_id(app_id: str) -> Optional[str]: diff --git a/apps/manager/node.py b/apps/manager/node.py index 83faa4eaf..9ba8692fe 100644 --- a/apps/manager/node.py +++ b/apps/manager/node.py @@ -28,6 +28,24 @@ class NodeManager: return node["call_id"] + @staticmethod + async def get_node(node_id: str) -> NodePool: + """获取Node的类型""" + node_collection = MongoDB().get_collection("node") + node = await node_collection.find_one({"_id": node_id}) + if not node: + err = f"[NodeManager] Node type {node_id} not found." 
+ raise ValueError(err) + + try: + node = NodePool.model_validate(node) + except Exception as e: + err = f"[NodeManager] Node {node_id} 验证失败" + logger.exception(err) + raise ValueError(err) from e + return node + + @staticmethod async def get_node_name(node_id: str) -> str: """获取node的名称""" @@ -96,5 +114,5 @@ class NodeManager: # 返回参数Schema return ( NodeManager.merge_params_schema(call_class.model_json_schema(), node_data.known_params or {}), - call_class.ret_type.model_json_schema(), # type: ignore[attr-defined] + call_class.output_type.model_json_schema(), # type: ignore[attr-defined] ) diff --git a/apps/manager/token.py b/apps/manager/token.py index b0bf98853..a54137288 100644 --- a/apps/manager/token.py +++ b/apps/manager/token.py @@ -107,4 +107,3 @@ class TokenManager: pipe.delete(f"{user_sub}_oidc_refresh_token") pipe.delete(f"aops_{user_sub}_token") await pipe.execute() - diff --git a/apps/scheduler/executor/base.py b/apps/scheduler/executor/base.py new file mode 100644 index 000000000..30ad51f63 --- /dev/null +++ b/apps/scheduler/executor/base.py @@ -0,0 +1,81 @@ +"""Executor基类""" +import logging +from typing import Optional, Any +from datetime import datetime, timezone + +from pydantic import BaseModel, ConfigDict +from ray import actor + +from apps.entities.enum_var import EventType, StepStatus +from apps.entities.flow import Flow +from apps.entities.task import TaskBlock +from apps.entities.scheduler import ExecutorBackground +from apps.entities.message import FlowStartContent +from apps.entities.task import FlowStepHistory + +logger = logging.getLogger("ray") + + +class BaseExecutor(BaseModel): + """Executor基类""" + task: TaskBlock + queue: actor.ActorHandle + flow: Flow + executor_background: ExecutorBackground + question: str + step_history: Optional[FlowStepHistory] = None + + + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="allow", + ) + + + def validate_flow_state(self) -> None: + """验证flow_state是否存在""" + if not self.task.flow_state: + err = "[Executor] 当前ExecutorState为空" + logger.error(err) + raise ValueError(err) + + + async def push_message(self, event_type: EventType, data: Optional[dict] = None) -> None: + """统一的消息推送接口 + + Args: + event_type: 事件类型 + data: 消息数据,如果是FLOW_START事件且data为None,则自动构建FlowStartContent + """ + self.validate_flow_state() + + if event_type == EventType.FLOW_START and data is None: + content = FlowStartContent( + question=self.question, + params=self.task.runtime_data.filled_data, + ) + data = content.model_dump(exclude_none=True, by_alias=True) + elif event_type == EventType.FLOW_STOP and data is None: + data = {} + elif event_type == EventType.STEP_INPUT: + # 更新step_history的输入数据 + if self.step_history is not None and data is not None: + self.step_history.input_data = data + if self.task.flow_state is not None: + self.step_history.status = self.task.flow_state.status + # 步骤开始,重置时间 + self.task.record.metadata.time_cost = round(datetime.now(timezone.utc).timestamp(), 2) + elif event_type == EventType.STEP_OUTPUT: + # 更新step_history的输出数据 + if self.step_history is not None and data is not None: + self.step_history.output_data = data + if self.task.flow_state is not None: + self.step_history.status = self.task.flow_state.status + self.task.flow_context[self.task.flow_state.step_id] = self.step_history + self.task.runtime_data.new_context.append(self.step_history.id) + + await self.queue.push_output.remote( + self.task, + event_type=event_type, + data=data + ) # type: ignore[attr-defined] diff --git a/apps/scheduler/executor/flow.py 
b/apps/scheduler/executor/flow.py index 8e3f4c5ff..409bedc12 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -2,49 +2,67 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ -import asyncio -import inspect import logging -from typing import Any, Optional import ray from pydantic import BaseModel, ConfigDict, Field from ray import actor -from apps.entities.enum_var import StepStatus -from apps.entities.flow import Flow, FlowError, Step +from apps.entities.enum_var import EventType, StepStatus +from apps.entities.flow import Flow, Step +from apps.entities.message import FlowStartContent from apps.entities.request_data import RequestDataApp from apps.entities.scheduler import CallVars, ExecutorBackground from apps.entities.task import ExecutorState, TaskBlock -from apps.llm.patterns.executor import ExecutorSummary -from apps.manager.node import NodeManager from apps.manager.task import TaskManager -from apps.scheduler.call.core import CoreCall -from apps.scheduler.call.empty import Empty -from apps.scheduler.call.llm import LLM -from apps.scheduler.executor.message import ( - push_flow_start, - push_flow_stop, - push_step_input, - push_step_output, - push_text_output, -) -from apps.scheduler.slot.slot import Slot +from apps.scheduler.call.llm.schema import LLM_ERROR_PROMPT +from apps.scheduler.executor.base import BaseExecutor +from apps.scheduler.executor.step import StepExecutor logger = logging.getLogger("ray") +FIXED_STEPS_BEFORE_START = { + "_summary": Step( + name="理解上下文", + description="使用大模型,理解对话上下文", + node="Summary", + type="Summary", + params={}, + ), +} +FIXED_STEPS_AFTER_END = { + "_facts": Step( + name="记忆存储", + description="理解对话答案,并存储到记忆中", + node="Facts", + type="Facts", + params={}, + ), +} +SLOT_FILLING_STEP = Step( + name="填充参数", + description="根据步骤历史,填充参数", + node="Slot", + type="Slot", + params={}, +) +ERROR_STEP = Step( + name="错误处理", + description="错误处理", + node="LLM", + type="LLM", + params={ + "user_prompt": LLM_ERROR_PROMPT, + }, +) # 单个流的执行工具 -class Executor(BaseModel): +class FlowExecutor(BaseExecutor): """用于执行工作流的Executor""" flow_id: str = Field(description="Flow ID") - flow: Flow = Field(description="工作流数据") - task: TaskBlock = Field(description="任务信息") question: str = Field(description="用户输入") - queue: actor.ActorHandle = Field(description="消息队列") post_body_app: RequestDataApp = Field(description="请求体中的app信息") - executor_background: ExecutorBackground = Field(description="Executor的背景信息") model_config = ConfigDict( arbitrary_types_allowed=True, @@ -58,216 +76,89 @@ class Executor(BaseModel): logger.info("[FlowExecutor] 加载Executor状态") # 尝试恢复State if self.task.flow_state: - self.flow_state = self.task.flow_state self.task.flow_context = await TaskManager.get_flow_history_by_task_id(self.task.record.task_id) else: # 创建ExecutorState - self.flow_state = ExecutorState( + self.task.flow_state = ExecutorState( flow_id=str(self.flow_id), description=str(self.flow.description), status=StepStatus.RUNNING, app_id=str(self.post_body_app.app_id), step_id="start", step_name="开始", - ai_summary="", - filled_data=self.post_body_app.params, ) - # 是否结束运行 - self._can_continue = True - # 向用户输出的内容 - self._final_answer = "" - - - async def _check_cls(self, call_cls: Any) -> bool: - """检查Call是否符合标准要求""" - flag = True - if not hasattr(call_cls, "name") or not isinstance(call_cls.name, str): - flag = False - if not hasattr(call_cls, "description") or not isinstance(call_cls.description, str): - flag = False - if not 
hasattr(call_cls, "exec") or not asyncio.iscoroutinefunction(call_cls.exec): - flag = False - if not hasattr(call_cls, "stream") or not inspect.isasyncgenfunction(call_cls.stream): - flag = False - return flag - - - # TODO: 默认错误处理步骤 - async def _run_error(self, step: FlowError) -> dict[str, Any]: - """运行错误处理步骤""" - return {} - - - async def _get_call_cls(self, node_id: str) -> type[CoreCall]: - """获取并验证Call类""" - # 检查flow_state是否为空 - if not self.flow_state: - err = "[FlowExecutor] flow_state为空" - raise ValueError(err) + # 是否到达Flow结束终点(变量) + self._reached_end: bool = False - # 特判 - if node_id == "Empty": - return Empty - # 获取对应Node的call_id - call_id = await NodeManager.get_node_call_id(node_id) - # 从Pool中获取对应的Call - pool = ray.get_actor("pool") - call_cls: type[CoreCall] = await pool.get_call.remote(call_id) - - # 检查Call合法性 - if not await self._check_cls(call_cls): - err = f"[FlowExecutor] 工具 {node_id} 不符合Call标准要求" + async def _invoke_runner(self, step_id: str) -> None: + """调用Runner""" + if not self.task.flow_state: + err = "[FlowExecutor] 当前ExecutorState为空" + logger.error(err) raise ValueError(err) - return call_cls - - - async def _process_slots(self, call_obj: Any) -> tuple[bool, Optional[dict[str, Any]]]: - """处理slot参数""" - if not (hasattr(call_obj, "slot_schema") and call_obj.slot_schema): - return True, None + step = self.flow.steps[step_id] - slot_processor = Slot(call_obj.slot_schema) - remaining_schema, slot_data = await slot_processor.process( - self.flow_state.filled_data, - self.post_body_app.params, - { - "task_id": self.task.record.task_id, - "question": self.question, - "thought": self.flow_state.ai_summary, - "previous_output": await self._get_last_output(self.task), - }, - ) - - # 保存Schema至State - self.flow_state.remaining_schema = remaining_schema - self.flow_state.filled_data.update(slot_data) - - # 如果还有未填充的部分,则返回False - if remaining_schema: - self._can_continue = False - self.flow_state.status = StepStatus.RUNNING - # 推送空输入输出 - self.task = await push_step_input(self.task, self.queue, self.flow_state, {}) - self.flow_state.status = StepStatus.PARAM - self.task = await push_step_output(self.task, self.queue, self.flow_state, {}) - return False, None - - return True, slot_data - - - # TODO - async def _get_last_output(self, task: TaskBlock) -> Optional[dict[str, Any]]: - """获取上一步的输出""" - return None - - - async def _execute_call(self, call_obj: Any, *, is_final_answer: bool) -> dict[str, Any]: - """执行Call并处理结果""" - # call_obj一定合法;开始判断是否为最终结果 - if is_final_answer and isinstance(call_obj, LLM): - # 最后一步 & 是大模型步骤,直接流式输出 - async for chunk in call_obj.stream(): - await push_text_output(self.task, self.queue, chunk) - self._final_answer += chunk - self.flow_state.status = StepStatus.SUCCESS - return { - "message": self._final_answer, - } - - # 其他情况:先运行步骤,得到结果 - result: dict[str, Any] = await call_obj.exec() - if is_final_answer: - if call_obj.name == "Convert": - # 如果是Convert,直接输出转换之后的结果 - self._final_answer += result["message"] - await push_text_output(self.task, self.queue, self._final_answer) - self.flow_state.status = StepStatus.SUCCESS - return { - "message": self._final_answer, - } - # 其他工具,加一步大模型过程 - # FIXME: 使用单独的Prompt - self.flow_state.status = StepStatus.SUCCESS - return { - "message": self._final_answer, - } - - # 其他情况:返回结果 - self.flow_state.status = StepStatus.SUCCESS - return result - - - async def _run_step(self, step_id: str, step_data: Step) -> None: - """运行单个步骤""" - logger.info("[FlowExecutor] 运行步骤 %s", step_data.name) - - # State写入ID和运行状态 - 
self.flow_state.step_id = step_id - self.flow_state.step_name = step_data.name - self.flow_state.status = StepStatus.RUNNING - - # 查找下一个步骤;判断是否是end前最后一步 - next_step = await self._get_next_step() - is_final_answer = next_step == "end" - - # 获取并验证Call类 - node_id = step_data.node - call_cls = await self._get_call_cls(node_id) + # 获取task actor + task_actor = ray.get_actor("task") # 准备系统变量 sys_vars = CallVars( question=self.question, task_id=self.task.record.task_id, flow_id=self.post_body_app.flow_id, - session_id=self.task.session_id, + session_id=self.task.runtime_data.session_id, history=self.task.flow_context, - summary=self.flow_state.ai_summary, - user_sub=self.task.user_sub, + summary=self.task.runtime_data.summary, + user_sub=self.task.runtime_data.user_sub, + service_id=step.params.get("service_id", ""), + ) + step_runner = StepExecutor( + queue=self.queue, + task=self.task, + flow=self.flow, + sys_vars=sys_vars, + executor_background=self.executor_background, + question=self.question, ) - # 初始化Call - call_obj = call_cls.model_validate(step_data.params) - input_data = await call_obj.init(sys_vars) - - # TODO: 处理slots - # can_continue, slot_data = await self._process_slots(call_obj) - # if not self._can_continue: - # return - - # 执行Call并获取结果 - self.task = await push_step_input(self.task, self.queue, self.flow_state, input_data) - result_data = await self._execute_call(call_obj, is_final_answer=is_final_answer) - self.task = await push_step_output(self.task, self.queue, self.flow_state, result_data) + # 运行Step + call_id, call_obj = await step_runner.init_step(step_id) + # 尝试填参 + input_data = await step_runner.fill_slots(call_obj) + # 运行 + await step_runner.run_step(call_id, call_obj, input_data) - # 更新下一步 - self.flow_state.step_id = next_step + # 更新Task + self.task = step_runner.task + await task_actor.set_task.remote(self.task.record.task_id, self.task) - async def _get_next_step(self) -> str: + async def _find_flow_next(self) -> str: """在当前步骤执行前,尝试获取下一步""" + if not self.task.flow_state: + err = "[StepExecutor] 当前ExecutorState为空" + logger.error(err) + raise ValueError(err) + # 如果当前步骤为结束,则直接返回 - if self.flow_state.step_id == "end" or not self.flow_state.step_id: + if self.task.flow_state.step_id == "end" or not self.task.flow_state.step_id: # 如果是最后一步,设置停止标志 - self._can_continue = False - return "" - - if self.flow.steps[self.flow_state.step_id].node == "Choice": - # 如果是选择步骤,那么现在还不知道下一步是谁,直接返回 + self._reached_end = True return "" next_steps = [] # 遍历Edges,查找下一个节点 for edge in self.flow.edges: - if edge.edge_from == self.flow_state.step_id: + if edge.edge_from == self.task.flow_state.step_id: next_steps += [edge.edge_to] # 如果step没有任何出边,直接跳到end if not next_steps: return "end" - # FIXME: 目前只使用第一个出边 + # TODO: 目前只使用第一个出边 logger.info("[FlowExecutor] 下一步 %s", next_steps[0]) return next_steps[0] @@ -277,33 +168,59 @@ class Executor(BaseModel): 数据通过向Queue发送消息的方式传输 """ + if not self.task.flow_state: + err = "[FlowExecutor] 当前ExecutorState为空" + logger.error(err) + raise ValueError(err) + task_actor = ray.get_actor("task") logger.info("[FlowExecutor] 运行工作流") # 推送Flow开始消息 - self.task = await push_flow_start(self.task, self.queue, self.flow_state, self.question) + await self.push_message(EventType.FLOW_START) + + # 进行开始前的系统步骤 + for step_id, step in FIXED_STEPS_BEFORE_START.items(): + self.flow.steps[step_id] = step + await self._invoke_runner(step_id) + self.task.flow_state.step_id = "start" + self.task.flow_state.step_name = "开始" # 如果允许继续运行Flow - while self._can_continue: + while not 
self._reached_end: # Flow定义中找不到step - if not self.flow_state.step_id or (self.flow_state.step_id not in self.flow.steps): - logger.error("[FlowExecutor] 当前步骤 %s 不存在", self.flow_state.step_id) - self.flow_state.status = StepStatus.ERROR + if not self.task.flow_state.step_id or (self.task.flow_state.step_id not in self.flow.steps): + logger.error("[FlowExecutor] 当前步骤 %s 不存在", self.task.flow_state.step_id) + self.task.flow_state.status = StepStatus.ERROR - if self.flow_state.status == StepStatus.ERROR: + if self.task.flow_state.status == StepStatus.ERROR: # 执行错误处理步骤 logger.warning("[FlowExecutor] Executor出错,执行错误处理步骤") - step = self.flow.on_error - await self._run_error(step) + self.flow.steps["_error"] = ERROR_STEP + await self._invoke_runner("_error") else: # 执行正常步骤 - step = self.flow.steps[self.flow_state.step_id] - await self._run_step(self.flow_state.step_id, step) + step = self.flow.steps[self.task.flow_state.step_id] + await self._invoke_runner(self.task.flow_state.step_id) # 步骤结束,更新全局的Task await task_actor.set_task.remote(self.task.record.task_id, self.task) + # 查找下一个节点 + next_step_id = await self._find_flow_next() + if next_step_id: + self.task.flow_state.step_id = next_step_id + self.task.flow_state.step_name = self.flow.steps[next_step_id].name + else: + # 如果没有下一个节点,设置结束标志 + self._reached_end = True + + # 运行结束后的系统步骤 + for step_id, step in FIXED_STEPS_AFTER_END.items(): + self.flow.steps[step_id] = step + await self._invoke_runner(step_id) + # 推送Flow停止消息 - self.task = await push_flow_stop(self.task, self.queue, self.flow_state, self.flow, self._final_answer) + await self.push_message(EventType.FLOW_STOP) # 更新全局的Task await task_actor.set_task.remote(self.task.record.task_id, self.task) diff --git a/apps/scheduler/executor/hook.py b/apps/scheduler/executor/hook.py new file mode 100644 index 000000000..a12fbb85b --- /dev/null +++ b/apps/scheduler/executor/hook.py @@ -0,0 +1,75 @@ +"""特殊Call的Runner;obj为当前Executor对象""" +from typing import Any, ClassVar, Protocol, TypeVar + +from pydantic import Field +from ray import actor + +from apps.entities.enum_var import EventType +from apps.entities.scheduler import CallOutputChunk, ExecutorBackground +from apps.entities.task import TaskBlock +from apps.scheduler.call.facts.facts import FactsCall +from apps.scheduler.call.summary.summary import Summary + + +class StepHookProtocol(Protocol): + """Step Hook的Protocol""" + + task: TaskBlock + executor_background: ExecutorBackground = Field(description="Executor的背景信息") + + async def _push_call_output(self, task: TaskBlock, queue: actor.ActorHandle, msg_type: EventType, data: dict[str, Any]) -> None: + """推送输出""" + ... 
+ + +T = TypeVar("T", bound=StepHookProtocol) + + +class StepHookMixins: + """Executor的Step Hook Mixins""" + + _HOOK: ClassVar[dict[str, Any]] = { + "Facts": { + "before_init": "_facts_before_init", + "after_run": "_facts_after_run", + }, + "Summary": { + "before_init": "_summary_before_init", + "after_run": "_summary_after_run", + }, + "Output": { + "after_run": "_output_after_run", + }, + } + + async def _facts_before_init(self: T) -> FactsCall: # type: ignore[override] + """Facts Call的特殊run_step""" + return FactsCall( + answer=self.task.record.content.answer, + ) + + async def _facts_after_run(self: T, output: CallOutputChunk) -> None: # type: ignore[override] + """Facts Call的输出hook""" + self.task.record.content.facts = output.content["facts"] + + + async def _summary_before_init(self: T) -> Summary: # type: ignore[override] + """Summary Call的输入hook""" + return Summary( + context=self.executor_background, + ) + + + async def _summary_after_run(self: T, output: CallOutputChunk) -> None: # type: ignore[override] + """Summary Call的输出hook""" + self.task.runtime_data.summary = output.content + + + async def _output_after_run(self: T, output: CallOutputChunk) -> None: # type: ignore[override] + """Output Call的输出hook""" + ... + + + async def _suggest_after_run(self: T, output: CallOutputChunk) -> None: # type: ignore[override] + """Suggest Call的输出hook""" + ... diff --git a/apps/scheduler/executor/message.py b/apps/scheduler/executor/message.py deleted file mode 100644 index f387dc5ff..000000000 --- a/apps/scheduler/executor/message.py +++ /dev/null @@ -1,131 +0,0 @@ -"""FlowExecutor的消息推送 - -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. -""" -from datetime import datetime, timezone -from typing import Any - -from ray import actor - -from apps.entities.enum_var import EventType, FlowOutputType -from apps.entities.flow import Flow -from apps.entities.message import ( - FlowStartContent, - FlowStopContent, - TextAddContent, -) -from apps.entities.task import ( - ExecutorState, - FlowStepHistory, - TaskBlock, -) - - -# FIXME: 临时使用截断规避前端问题 -def truncate_data(data: Any) -> Any: - """截断数据""" - if isinstance(data, dict): - return {k: truncate_data(v) for k, v in data.items()} - if isinstance(data, list): - return [truncate_data(v) for v in data] - if isinstance(data, str): - if len(data) > 300: - return data[:300] + "..." 
- return data - return data - - -async def push_step_input(task: TaskBlock, queue: actor.ActorHandle, state: ExecutorState, input_data: dict[str, Any]) -> TaskBlock: - """推送步骤输入""" - # 更新State - task.flow_state = state - # 更新FlowContext - task.flow_context[state.step_id] = FlowStepHistory( - task_id=task.record.task_id, - flow_id=state.flow_id, - step_id=state.step_id, - status=state.status, - input_data=state.filled_data, - output_data={}, - ) - # 步骤开始,重置时间 - task.record.metadata.time_cost = round(datetime.now(timezone.utc).timestamp(), 2) - - # 推送消息 - await queue.push_output.remote(task, event_type=EventType.STEP_INPUT, data=truncate_data(input_data)) # type: ignore[attr-defined] - return task - - -async def push_step_output(task: TaskBlock, queue: actor.ActorHandle, state: ExecutorState, output: dict[str, Any]) -> TaskBlock: - """推送步骤输出""" - # 更新State - task.flow_state = state - - # 更新FlowContext - task.flow_context[state.step_id].output_data = output - task.flow_context[state.step_id].status = state.status - - # FlowContext加入Record - task.new_context.append(task.flow_context[state.step_id].id) - - # 推送消息 - await queue.push_output.remote(task, event_type=EventType.STEP_OUTPUT, data=truncate_data(output)) # type: ignore[attr-defined] - return task - - -async def push_flow_start(task: TaskBlock, queue: actor.ActorHandle, state: ExecutorState, question: str) -> TaskBlock: - """推送Flow开始""" - # 设置state - task.flow_state = state - - # 组装消息 - content = FlowStartContent( - question=question, - params=state.filled_data, - ) - # 推送消息 - await queue.push_output.remote(task, event_type=EventType.FLOW_START, data=content.model_dump(exclude_none=True, by_alias=True)) # type: ignore[attr-defined] - return task - - -async def assemble_flow_stop_content(state: ExecutorState, flow: Flow) -> FlowStopContent: - """组装Flow结束消息""" - if state.remaining_schema: - # 如果当前Flow是填充步骤,则推送Schema - content = FlowStopContent( - type=FlowOutputType.SCHEMA, - data=state.remaining_schema, - ) - else: - content = FlowStopContent() - - return content - - # elif call_type == "render": - # # 如果当前Flow是图表,则推送Chart - # chart_option = task.flow_context[state.step_id].output_data["output"] - # content = FlowStopContent( - # type=FlowOutputType.CHART, - # data=chart_option, - # ) - -async def push_flow_stop(task: TaskBlock, queue: actor.ActorHandle, state: ExecutorState, flow: Flow, final_answer: str) -> TaskBlock: - """推送Flow结束""" - # 设置state - task.flow_state = state - # 保存最终输出 - task.record.content.answer = final_answer - content = await assemble_flow_stop_content(state, flow) - - # 推送Stop消息 - await queue.push_output.remote(task, event_type=EventType.FLOW_STOP, data=content.model_dump(exclude_none=True, by_alias=True)) # type: ignore[attr-defined] - return task - - -async def push_text_output(task: TaskBlock, queue: actor.ActorHandle, text: str) -> None: - """推送文本输出""" - content = TextAddContent( - text=text, - ) - # 推送消息 - await queue.push_output.remote(task, event_type=EventType.TEXT_ADD, data=content.model_dump(exclude_none=True, by_alias=True)) # type: ignore[attr-defined] diff --git a/apps/scheduler/executor/node.py b/apps/scheduler/executor/node.py new file mode 100644 index 000000000..64eed95b1 --- /dev/null +++ b/apps/scheduler/executor/node.py @@ -0,0 +1,52 @@ +"""Node相关的函数,含Node转换为Call""" +import inspect +from typing import Any + +import ray + +from apps.scheduler.call.core import CoreCall +from apps.scheduler.call.empty import Empty +from apps.scheduler.call.facts.facts import FactsCall +from 
apps.scheduler.call.slot.slot import Slot +from apps.scheduler.call.summary.summary import Summary + + +class StepNode: + """Executor的Node相关逻辑""" + + @staticmethod + async def check_cls(call_cls: Any) -> bool: + """检查Call是否符合标准要求""" + flag = True + if not hasattr(call_cls, "name") or not isinstance(call_cls.name, str): + flag = False + if not hasattr(call_cls, "description") or not isinstance(call_cls.description, str): + flag = False + if not hasattr(call_cls, "exec") or not inspect.isasyncgenfunction(call_cls.exec): + flag = False + return flag + + + @staticmethod + async def get_call_cls(call_id: str) -> type[CoreCall]: + """获取并验证Call类""" + # 特判,用于处理隐藏节点 + if call_id == "Empty": + return Empty + if call_id == "Summary": + return Summary + if call_id == "Facts": + return FactsCall + if call_id == "Slot": + return Slot + + # 从Pool中获取对应的Call + pool = ray.get_actor("pool") + call_cls: type[CoreCall] = await pool.get_call.remote(call_id) + + # 检查Call合法性 + if not await StepNode.check_cls(call_cls): + err = f"[FlowExecutor] 工具 {call_id} 不符合Call标准要求" + raise ValueError(err) + + return call_cls diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py new file mode 100644 index 000000000..22ca1e780 --- /dev/null +++ b/apps/scheduler/executor/step.py @@ -0,0 +1,193 @@ +"""工作流中步骤相关函数""" + +import logging +from datetime import datetime, timezone +from typing import Any, Optional + +from pydantic import BaseModel, ConfigDict +from ray import actor + +from apps.entities.enum_var import EventType, StepStatus +from apps.entities.flow import Flow, Step +from apps.entities.pool import NodePool +from apps.entities.scheduler import CallOutputChunk, CallVars, ExecutorBackground +from apps.entities.task import FlowStepHistory, TaskBlock +from apps.manager.node import NodeManager +from apps.scheduler.call.slot.schema import SlotInput, SlotOutput +from apps.scheduler.call.slot.slot import Slot +from apps.scheduler.executor.base import BaseExecutor +from apps.scheduler.executor.hook import StepHookMixins +from apps.scheduler.executor.node import StepNode + +logger = logging.getLogger("ray") +SPECIAL_EVENT_TYPES = [ + EventType.GRAPH, + EventType.SUGGEST, + EventType.TEXT_ADD, +] + + +class StepExecutor(BaseExecutor, StepHookMixins): + """工作流中步骤相关函数""" + + sys_vars: CallVars + + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="allow", + ) + + def __init__(self, **kwargs: Any) -> None: + """初始化""" + super().__init__(**kwargs) + if not self.task.flow_state: + err = "[StepExecutor] 当前ExecutorState为空" + logger.error(err) + raise ValueError(err) + + self.step_history = FlowStepHistory( + task_id=self.sys_vars.task_id, + flow_id=self.sys_vars.flow_id, + step_id=self.task.flow_state.step_id, + status=self.task.flow_state.status, + input_data={}, + output_data={}, + ) + + async def _push_call_output(self, task: TaskBlock, queue: actor.ActorHandle, msg_type: EventType, data: dict[str, Any]) -> None: + """推送输出""" + # 如果不是特殊类型 + if msg_type not in SPECIAL_EVENT_TYPES: + err = f"[StepExecutor] 不支持的事件类型: {msg_type}" + logger.error(err) + raise ValueError(err) + + # 推送消息 + await queue.push_output.remote( + task, + event_type=msg_type, + data=data, + ) # type: ignore[attr-defined] + + async def init_step(self, step_id: str) -> tuple[str, Any]: + """初始化步骤""" + if not self.task.flow_state: + err = "[StepExecutor] 当前ExecutorState为空" + logger.error(err) + raise ValueError(err) + + logger.info("[FlowExecutor] 运行步骤 %s", self.flow.steps[step_id].name) + + # State写入ID和运行状态 + 
self.task.flow_state.step_id = step_id + self.task.flow_state.step_name = self.flow.steps[step_id].name + self.task.flow_state.status = StepStatus.RUNNING + + # 获取并验证Call类 + node_id = self.flow.steps[step_id].node + # 获取node详情并存储 + try: + self._node_data = await NodeManager.get_node(node_id) + except Exception: + logger.info("[StepExecutor] 获取Node失败,为内置Node或ID不存在") + self._node_data = None + + if self._node_data: + call_cls = await StepNode.get_call_cls(self._node_data.call_id) + call_id = self._node_data.call_id + else: + # 可能是特殊的内置Node + call_cls = await StepNode.get_call_cls(node_id) + call_id = node_id + + # 初始化Call Class,用户参数会覆盖node的参数 + params: dict[str, Any] = self._node_data.known_params if self._node_data and self._node_data.known_params else {} # type: ignore[union-attr] + if self.flow.steps[step_id].params: + params.update(self.flow.steps[step_id].params) + + # 检查初始化Call Class是否需要特殊处理 + try: + if call_id in self._HOOK and "before_init" in self._HOOK[call_id]: + func = getattr(self, self._HOOK[call_id]["before_init"]) + call_obj = await func() + else: + call_obj = call_cls.model_validate(params) + except Exception: + logger.exception("[StepExecutor] 初始化Call失败") + raise + else: + return call_id, call_obj + + + async def fill_slots(self, call_obj: Any) -> dict[str, Any]: + """执行Call并处理结果""" + if not self.task.flow_state: + err = "[StepExecutor] 当前ExecutorState为空" + logger.error(err) + raise ValueError(err) + + # 尝试初始化call_obj + try: + input_data = await call_obj.init(self.sys_vars) + except Exception: + logger.exception("[StepExecutor] 初始化Call失败") + raise + + # 合并已经填充的参数 + input_data.update(self.task.runtime_data.filled_data) + + # 检查输入参数 + input_schema = call_obj.input_type.model_json_schema(override=self._node_data.override_input) if self._node_data else {} + + # 检查输入参数是否需要填充 + slot_obj = Slot( + data=input_data, + current_schema=input_schema, + summary=self.task.runtime_data.summary, + facts=self.task.record.content.facts, + ) + slot_input = SlotInput(**await slot_obj.init(self.sys_vars)) + + # 若需要填参,额外执行一步 + if slot_input.remaining_schema: + output = SlotOutput(**await self.run_step("Slot", slot_obj, slot_input.remaining_schema)) + self.task.runtime_data.filled_data = output.slot_data + + # 如果还有剩余参数,则中断 + if output.remaining_schema: + # 更新状态 + self.task.flow_state.status = StepStatus.PARAM + self.task.flow_state.remaining_schema = output.remaining_schema + err = "[StepExecutor] 需要填参" + logger.error(err) + raise ValueError(err) + + return output.slot_data + return input_data + + + # TODO: 需要评估如何进行流式输出 + async def run_step(self, call_id: str, call_obj: Any, input_data: dict[str, Any]) -> dict[str, Any]: + """运行单个步骤""" + iterator = call_obj.exec(input_data) + await self.push_message(EventType.STEP_INPUT, input_data) + + result: Optional[CallOutputChunk] = None + async for chunk in iterator: + result = chunk + + if result is None: + content = {} + else: + content = result.content if isinstance(result.content, dict) else {"message": result.content} + + # 检查Call Class的返回值是否需要特殊处理 + if call_id in self._HOOK and "after_run" in self._HOOK[call_id]: + func = getattr(self, self._HOOK[call_id]["after_run"]) + await func(result) + + # 更新执行状态 + self.task.flow_state.status = StepStatus.SUCCESS + await self.push_message(EventType.STEP_OUTPUT, content) + + return content -- Gitee From 9db341aee6b3e6983948e464a060d28b37dabc52 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Tue, 18 Mar 2025 08:07:03 +0800 Subject: [PATCH 5/8] =?UTF-8?q?=E5=90=8C=E6=AD=A5Call=E6=94=B9=E5=8A=A8?= MIME-Version: 1.0 
Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/call/__init__.py | 12 +- apps/scheduler/call/{ => api}/api.py | 176 +++++++++++------- apps/scheduler/call/api/schema.py | 25 +++ apps/scheduler/call/{ => choice}/choice.py | 0 apps/scheduler/call/{ => convert}/convert.py | 54 +++--- apps/scheduler/call/convert/schema.py | 17 ++ apps/scheduler/call/core.py | 45 +++-- apps/scheduler/call/empty.py | 29 +-- apps/scheduler/call/facts.py | 52 ------ apps/scheduler/call/facts/facts.py | 58 ++++++ apps/scheduler/call/facts/schema.py | 27 +++ .../call/{render => graph}/__init__.py | 0 .../call/{render/render.py => graph/graph.py} | 110 ++++------- .../call/{render => graph}/option.json | 0 apps/scheduler/call/graph/schema.py | 39 ++++ .../scheduler/call/{render => graph}/style.py | 0 apps/scheduler/call/{ => llm}/llm.py | 57 +++--- apps/scheduler/call/llm/schema.py | 107 +++++++++++ apps/scheduler/call/output/output.py | 41 ++++ apps/scheduler/call/output/schema.py | 15 ++ apps/scheduler/call/{ => rag}/rag.py | 72 +++---- apps/scheduler/call/rag/schema.py | 31 +++ apps/scheduler/call/search/search.py | 31 +++ apps/scheduler/call/slot/schema.py | 24 +++ apps/scheduler/call/slot/slot.py | 68 +++++++ apps/scheduler/call/sql/schema.py | 19 ++ apps/scheduler/call/{ => sql}/sql.py | 74 ++++---- apps/scheduler/call/suggest/schema.py | 37 ++++ apps/scheduler/call/{ => suggest}/suggest.py | 70 +++---- apps/scheduler/call/summary.py | 43 ----- apps/scheduler/call/summary/schema.py | 17 ++ apps/scheduler/call/summary/summary.py | 39 ++++ 32 files changed, 951 insertions(+), 438 deletions(-) rename apps/scheduler/call/{ => api}/api.py (39%) create mode 100644 apps/scheduler/call/api/schema.py rename apps/scheduler/call/{ => choice}/choice.py (100%) rename apps/scheduler/call/{ => convert}/convert.py (56%) create mode 100644 apps/scheduler/call/convert/schema.py delete mode 100644 apps/scheduler/call/facts.py create mode 100644 apps/scheduler/call/facts/facts.py create mode 100644 apps/scheduler/call/facts/schema.py rename apps/scheduler/call/{render => graph}/__init__.py (100%) rename apps/scheduler/call/{render/render.py => graph/graph.py} (48%) rename apps/scheduler/call/{render => graph}/option.json (100%) create mode 100644 apps/scheduler/call/graph/schema.py rename apps/scheduler/call/{render => graph}/style.py (100%) rename apps/scheduler/call/{ => llm}/llm.py (63%) create mode 100644 apps/scheduler/call/llm/schema.py create mode 100644 apps/scheduler/call/output/output.py create mode 100644 apps/scheduler/call/output/schema.py rename apps/scheduler/call/{ => rag}/rag.py (43%) create mode 100644 apps/scheduler/call/rag/schema.py create mode 100644 apps/scheduler/call/search/search.py create mode 100644 apps/scheduler/call/slot/schema.py create mode 100644 apps/scheduler/call/slot/slot.py create mode 100644 apps/scheduler/call/sql/schema.py rename apps/scheduler/call/{ => sql}/sql.py (52%) create mode 100644 apps/scheduler/call/suggest/schema.py rename apps/scheduler/call/{ => suggest}/suggest.py (38%) delete mode 100644 apps/scheduler/call/summary.py create mode 100644 apps/scheduler/call/summary/schema.py create mode 100644 apps/scheduler/call/summary/summary.py diff --git a/apps/scheduler/call/__init__.py b/apps/scheduler/call/__init__.py index 5d5605b0c..913fc8f78 100644 --- a/apps/scheduler/call/__init__.py +++ b/apps/scheduler/call/__init__.py @@ -2,15 +2,19 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
""" -from apps.scheduler.call.api import API -from apps.scheduler.call.llm import LLM -from apps.scheduler.call.rag import RAG -from apps.scheduler.call.suggest import Suggestion +from apps.scheduler.call.api.api import API +from apps.scheduler.call.convert.convert import Convert +from apps.scheduler.call.llm.llm import LLM +from apps.scheduler.call.output.output import Output +from apps.scheduler.call.rag.rag import RAG +from apps.scheduler.call.suggest.suggest import Suggestion # 只包含需要在编排界面展示的工具 __all__ = [ "API", "LLM", "RAG", + "Convert", + "Output", "Suggestion", ] diff --git a/apps/scheduler/call/api.py b/apps/scheduler/call/api/api.py similarity index 39% rename from apps/scheduler/call/api.py rename to apps/scheduler/call/api/api.py index d6753a59c..4a17070cc 100644 --- a/apps/scheduler/call/api.py +++ b/apps/scheduler/call/api/api.py @@ -3,83 +3,126 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ import json -from typing import Any, ClassVar, Literal, Optional +import logging +from collections.abc import AsyncGenerator +from typing import Annotated, Any, ClassVar, Optional import aiohttp from fastapi import status -from pydantic import BaseModel, Field +from pydantic import Field -from apps.constants import LOGGER -from apps.entities.enum_var import ContentType, HTTPMethod -from apps.entities.scheduler import CallError, CallVars +from apps.common.config import config +from apps.entities.enum_var import CallOutputType, ContentType, HTTPMethod +from apps.entities.flow import ServiceMetadata +from apps.entities.scheduler import CallError, CallOutputChunk, CallVars +from apps.manager.service import ServiceCenterManager from apps.manager.token import TokenManager +from apps.scheduler.call.api.schema import APIInput, APIOutput from apps.scheduler.call.core import CoreCall -from apps.scheduler.slot.slot import Slot - -class APIOutput(BaseModel): - """API调用工具的输出""" - - http_code: int = Field(description="API调用工具的HTTP返回码") - message: str = Field(description="API调用工具的执行结果") - output: dict[str, Any] = Field(description="API调用工具的输出") - - -class API(CoreCall, ret_type=APIOutput): +logger = logging.getLogger("ray") +SUCCESS_HTTP_CODES = [ + status.HTTP_200_OK, + status.HTTP_201_CREATED, + status.HTTP_202_ACCEPTED, + status.HTTP_203_NON_AUTHORITATIVE_INFORMATION, + status.HTTP_204_NO_CONTENT, + status.HTTP_205_RESET_CONTENT, + status.HTTP_206_PARTIAL_CONTENT, + status.HTTP_207_MULTI_STATUS, + status.HTTP_208_ALREADY_REPORTED, + status.HTTP_226_IM_USED, + status.HTTP_301_MOVED_PERMANENTLY, + status.HTTP_302_FOUND, + status.HTTP_303_SEE_OTHER, + status.HTTP_304_NOT_MODIFIED, + status.HTTP_307_TEMPORARY_REDIRECT, + status.HTTP_308_PERMANENT_REDIRECT, +] + + +class API(CoreCall, input_type=APIInput, output_type=APIOutput): """API调用工具""" - name: ClassVar[str] = "HTTP请求" - description: ClassVar[str] = "向某一个API接口发送HTTP请求,获取数据。" + name: ClassVar[Annotated[str, Field(description="工具名称", exclude=True, frozen=True)]] = "HTTP请求" + description: ClassVar[Annotated[str, Field(description="工具描述", exclude=True, frozen=True)]] = "向某一个API接口发送HTTP请求,获取数据。" url: str = Field(description="API接口的完整URL") method: HTTPMethod = Field(description="API接口的HTTP Method") content_type: ContentType = Field(description="API接口的Content-Type") timeout: int = Field(description="工具超时时间", default=300, gt=30) + body: dict[str, Any] = Field(description="已知的部分请求体", default={}) - auth: dict[str, Any] = Field(description="API鉴权信息", default={}) + query: dict[str, Any] = Field(description="已知的部分请求参数", 
default={}) - async def exec(self, syscall_vars: CallVars, **_kwargs: Any) -> APIOutput: + async def init(self, syscall_vars: CallVars) -> dict[str, Any]: + """初始化API调用工具""" + await super().init(syscall_vars) + # 获取对应API的Service Metadata + try: + service_metadata = await ServiceCenterManager.get_service_data(syscall_vars.user_sub, syscall_vars.service_id) + service_metadata = ServiceMetadata.model_validate(service_metadata) + except Exception as e: + raise CallError( + message="API接口的Service Metadata获取失败", + data={}, + ) from e + + # 获取Service对应的Auth + self._auth = service_metadata.api.auth + self._session_id = syscall_vars.session_id + self._service_id = syscall_vars.service_id + + return APIInput( + service_id=self._service_id, + url=self.url, + method=self.method, + query=self.query, + body=self.body, + ).model_dump(exclude_none=True, by_alias=True) + + + async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """调用API,然后返回LLM解析后的数据""" self._session = aiohttp.ClientSession() try: - result = await self._call_api(slot_data) - await self._session.close() - return result - except Exception as e: + result = await self._call_api(input_data) + yield CallOutputChunk( + type=CallOutputType.DATA, + content=result.model_dump(exclude_none=True, by_alias=True), + ) + finally: await self._session.close() - raise RuntimeError from e - async def _make_api_call(self, data: Optional[dict], files: aiohttp.FormData): # noqa: ANN202, C901 + async def _make_api_call(self, data: Optional[dict], files: aiohttp.FormData): # noqa: ANN202 + """组装API请求Session""" # 获取必要参数 - syscall_vars: CallVars = getattr(self, "_syscall_vars") - - if self.content_type != "form": - req_header = { - "Content-Type": ContentType.JSON.value, - } - else: - req_header = {} + req_header = { + "Content-Type": self.content_type.value, + } req_cookie = {} req_params = {} - if data is None: - data = {} - - if self.auth is not None and "type" in self.auth: - if self.auth["type"] == "header": - req_header.update(self.auth["args"]) - elif self.auth["type"] == "cookie": - req_cookie.update(self.auth["args"]) - elif self.auth["type"] == "params": - req_params.update(self.auth["args"]) - elif self.auth["type"] == "oidc": + data = data or {} + + if self._auth: + if self._auth.header: + for item in self._auth.header: + req_header[item.name] = item.value + elif self._auth.cookie: + for item in self._auth.cookie: + req_cookie[item.name] = item.value + elif self._auth.query: + for item in self._auth.query: + req_params[item.name] = item.value + elif self._auth.oidc: token = await TokenManager.get_plugin_token( - self.auth["domain"], - syscall_vars.session_id, - self.auth["access_token_url"], - int(self.auth["token_expire_time"]), + self._service_id, + self._session_id, + config["OIDC_ACCESS_TOKEN_URL"], + int(config["OIDC_ACCESS_TOKEN_EXPIRE_TIME"]), ) req_header.update({"access-token": token}) @@ -102,49 +145,40 @@ class API(CoreCall, ret_type=APIOutput): raise NotImplementedError(err) - async def _call_api(self, slot_data: Optional[dict[str, Any]] = None) -> APIOutput: + async def _call_api(self, final_data: Optional[dict[str, Any]] = None) -> APIOutput: + """实际调用API,并处理返回值""" # 获取必要参数 - params: _APIParams = getattr(self, "_params") - LOGGER.info(f"调用接口{params.url},请求数据为{slot_data}") + logger.info("[API] 调用接口 %s,请求数据为 %s", self.url, final_data) - session_context = await self._make_api_call(slot_data, aiohttp.FormData()) + session_context = await self._make_api_call(final_data, aiohttp.FormData()) async with 
session_context as response: - if response.status >= status.HTTP_400_BAD_REQUEST: + if response.status in SUCCESS_HTTP_CODES: text = f"API发生错误:API返回状态码{response.status}, 原因为{response.reason}。" - LOGGER.error(text) + logger.error(text) raise CallError( message=text, data={"api_response_data": await response.text()}, ) + response_status = response.status response_data = await response.text() - LOGGER.info(f"调用接口{params.url}, 结果为 {response_data}") + logger.info("[API] 调用接口 %s,结果为 %s", self.url, response_data) - # 组装message - message = f"""You called the HTTP API "{params.url}", which is used to "{self._spec[2]['summary']}".""" # 如果没有返回结果 - if response_data is None: + if not response_data: return APIOutput( http_code=response_status, - output={}, - message=message + "But the API returned an empty response.", + result={}, ) # 如果返回值是JSON try: response_dict = json.loads(response_data) - except Exception as e: - err = f"返回值不是JSON:{e!s}" - LOGGER.error(err) - - # openapi里存在有HTTP 200对应的Schema,则进行处理 - if response_schema: - slot = Slot(response_schema) - response_data = json.dumps(slot.process_json(response_dict), ensure_ascii=False) + except Exception: + logger.exception("[API] 返回值不是JSON格式!") return APIOutput( http_code=response_status, - output=json.loads(response_data), - message=message + "The API returned some data, and is shown in the 'output' field below.", + result=response_dict, ) diff --git a/apps/scheduler/call/api/schema.py b/apps/scheduler/call/api/schema.py new file mode 100644 index 000000000..12afcb7e5 --- /dev/null +++ b/apps/scheduler/call/api/schema.py @@ -0,0 +1,25 @@ +"""API调用工具的输入和输出""" + +from typing import Any, Union + +from pydantic import Field + +from apps.scheduler.call.core import DataBase + + +class APIInput(DataBase): + """API调用工具的输入""" + + service_id: str = Field(description="API调用工具的Service ID") + url: str = Field(description="API调用工具的URL") + method: str = Field(description="API调用工具的HTTP方法") + + query: dict[str, Any] = Field(description="API调用工具的请求参数") + body: dict[str, Any] = Field(description="API调用工具的请求体") + + +class APIOutput(DataBase): + """API调用工具的输出""" + + http_code: int = Field(description="API调用工具的HTTP返回码") + result: Union[dict[str, Any], str] = Field(description="API调用工具的输出") diff --git a/apps/scheduler/call/choice.py b/apps/scheduler/call/choice/choice.py similarity index 100% rename from apps/scheduler/call/choice.py rename to apps/scheduler/call/choice/choice.py diff --git a/apps/scheduler/call/convert.py b/apps/scheduler/call/convert/convert.py similarity index 56% rename from apps/scheduler/call/convert.py rename to apps/scheduler/call/convert/convert.py index 0b6ff004f..cf85bf73a 100644 --- a/apps/scheduler/call/convert.py +++ b/apps/scheduler/call/convert/convert.py @@ -3,43 +3,43 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
""" import json +from collections.abc import AsyncGenerator from datetime import datetime from textwrap import dedent -from typing import Any, Optional +from typing import Any, ClassVar, Optional import _jsonnet import pytz from jinja2 import BaseLoader, select_autoescape from jinja2.sandbox import SandboxedEnvironment -from pydantic import BaseModel, Field +from pydantic import Field +from apps.entities.enum_var import CallOutputType from apps.entities.scheduler import CallVars -from apps.scheduler.call.core import CoreCall +from apps.scheduler.call.convert.schema import ConvertInput, ConvertOutput +from apps.scheduler.call.core import CallOutputChunk, CoreCall -class ConvertOutput(BaseModel): - """定义Convert工具的输出""" - - text: str = Field(description="格式化后的文字信息") - data: dict = Field(description="格式化后的结果") - - -class Convert(CoreCall, ret_type=ConvertOutput): +class Convert(CoreCall, input_type=ConvertInput, output_type=ConvertOutput): """Convert 工具,用于对生成的文字信息和原始数据进行格式化""" - name: str = "convert" - description: str = "从上一步的工具的原始JSON返回结果中,提取特定字段的信息。" + name: ClassVar[str] = "模板转换" + description: ClassVar[str] = "使用jinja2语法和jsonnet语法,将自然语言信息和原始数据进行格式化。" text_template: Optional[str] = Field(description="自然语言信息的格式化模板,jinja2语法", default=None) data_template: Optional[str] = Field(description="原始数据的格式化模板,jsonnet语法", default=None) - async def init(self, syscall_vars: CallVars, **_kwargs: Any) -> dict[str, Any]: + async def init(self, syscall_vars: CallVars) -> dict[str, Any]: """初始化工具""" - return {} + await super().init(syscall_vars) + + self._history = syscall_vars.history + self._question = syscall_vars.question + return ConvertInput().model_dump(exclude_none=True, by_alias=True) - async def exec(self) -> ConvertOutput: + async def exec(self) -> AsyncGenerator[CallOutputChunk, None]: """调用Convert工具 :param _slot_data: 经用户确认后的参数(目前未使用) @@ -47,7 +47,7 @@ class Convert(CoreCall, ret_type=ConvertOutput): """ # 判断用户是否给了值 time = datetime.now(tz=pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S") - if params.text is None: + if self.text_template is None: result_message = last_output.get("message", "") else: text_template = SandboxedEnvironment( @@ -55,25 +55,25 @@ class Convert(CoreCall, ret_type=ConvertOutput): autoescape=select_autoescape(), trim_blocks=True, lstrip_blocks=True, - ).from_string(params.text) - result_message = text_template.render(time=time, history=syscall_vars.history, question=syscall_vars.question) + ).from_string(self.text_template) + result_message = text_template.render(time=time, history=self._history, question=self._question) - if params.data is None: + if self.data_template is None: result_data = last_output.get("output", {}) else: extra_str = json.dumps({ "time": time, - "question": syscall_vars.question, + "question": self._question, }, ensure_ascii=False) - history_str = json.dumps([item.output_data["output"] for item in syscall_vars.history if "output" in item.output_data], ensure_ascii=False) + history_str = json.dumps([item.output_data["output"] for item in self._history if "output" in item.output_data], ensure_ascii=False) data_template = dedent(f""" local extra = {extra_str}; local history = {history_str}; - {params.data} + {self.data_template} """) - result_data = json.loads(_jsonnet.evaluate_snippet(data_template, params.data), ensure_ascii=False) + result_data = json.loads(_jsonnet.evaluate_snippet(data_template, self.data_template), ensure_ascii=False) - return ConvertOutput( - text=result_message, - data=result_data, + yield CallOutputChunk( + 
type=CallOutputType.DATA, + content=result_message, ) diff --git a/apps/scheduler/call/convert/schema.py b/apps/scheduler/call/convert/schema.py new file mode 100644 index 000000000..1af354c42 --- /dev/null +++ b/apps/scheduler/call/convert/schema.py @@ -0,0 +1,17 @@ +"""Convert工具的Schema""" + +from pydantic import Field + +from apps.scheduler.call.core import DataBase + + +class ConvertInput(DataBase): + """定义Convert工具的输入""" + + + +class ConvertOutput(DataBase): + """定义Convert工具的输出""" + + text: str = Field(description="格式化后的文字信息") + data: dict = Field(description="格式化后的结果") diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index 675440431..83b6681ee 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -4,42 +4,55 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ from collections.abc import AsyncGenerator -from typing import Any, ClassVar +from typing import Annotated, Any, ClassVar, Optional from pydantic import BaseModel, ConfigDict, Field -from apps.entities.scheduler import CallVars +from apps.entities.scheduler import CallOutputChunk, CallVars + + +class DataBase(BaseModel): + """所有Call的输入基类""" + + @classmethod + def model_json_schema(cls, override: Optional[dict[str, Any]] = None, **kwargs: Any) -> dict[str, Any]: + """通过override参数,动态填充Schema内容""" + schema = super().model_json_schema(**kwargs) + if override: + for key, value in override.items(): + schema["properties"][key] = value + return schema class CoreCall(BaseModel): """所有Call的父类,所有Call必须继承此类。""" - name: ClassVar[str] = Field(description="Call的名称") - description: ClassVar[str] = Field(description="Call的描述") + name: ClassVar[Annotated[str, Field(description="Call的名称", exclude=True)]] = "" + description: ClassVar[Annotated[str, Field(description="Call的描述", exclude=True)]] = "" model_config = ConfigDict( arbitrary_types_allowed=True, extra="allow", ) - ret_type: ClassVar[type[BaseModel]] = Field(description="Call的返回类型") + input_type: ClassVar[type[BaseModel]] = Field(description="Call的输入Pydantic类型", exclude=True, frozen=True) + """需要以消息形式输出的、需要大模型填充参数的输入信息填到这里""" + output_type: ClassVar[type[BaseModel]] = Field(description="Call的输出Pydantic类型", exclude=True, frozen=True) + """需要以消息形式输出的输出信息填到这里""" - def __init_subclass__(cls, ret_type: type[BaseModel], **kwargs: Any) -> None: + def __init_subclass__(cls, input_type: type[BaseModel], output_type: type[BaseModel], **kwargs: Any) -> None: """初始化子类""" super().__init_subclass__(**kwargs) - cls.ret_type = ret_type + cls.input_type = input_type + cls.output_type = output_type - async def init(self, syscall_vars: CallVars, **_kwargs: Any) -> dict[str, Any]: + async def init(self, syscall_vars: CallVars) -> dict[str, Any]: """初始化Call类,并返回Call的输入""" - raise NotImplementedError - - - async def exec(self) -> dict[str, Any]: - """Call类实例的非流式输出方法""" - raise NotImplementedError + ... - async def stream(self) -> AsyncGenerator[str, None]: + async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """Call类实例的流式输出方法""" - yield "" + err = "[CoreCall] 流式输出方法必须手动实现" + raise NotImplementedError(err) diff --git a/apps/scheduler/call/empty.py b/apps/scheduler/call/empty.py index 73741f166..2a712cf4e 100644 --- a/apps/scheduler/call/empty.py +++ b/apps/scheduler/call/empty.py @@ -2,30 +2,37 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
""" +from collections.abc import AsyncGenerator from typing import Any, ClassVar -from pydantic import BaseModel +from pydantic import BaseModel, Field -from apps.entities.scheduler import CallVars -from apps.scheduler.call.core import CoreCall +from apps.entities.enum_var import CallOutputType +from apps.entities.scheduler import CallOutputChunk, CallVars +from apps.scheduler.call.core import CoreCall, DataBase -class EmptyData(BaseModel): - """空数据""" +class EmptyInput(BaseModel): + """空输入""" -class Empty(CoreCall, ret_type=EmptyData): +class EmptyOutput(DataBase): + """空输出""" + + +class Empty(CoreCall, input_type=EmptyInput, output_type=EmptyOutput): """空Call""" - name: ClassVar[str] = "空白" - description: ClassVar[str] = "空白节点,用于占位" + name: ClassVar[str] = Field("空白", exclude=True, frozen=True) + description: ClassVar[str] = Field("空白节点,用于占位", exclude=True, frozen=True) - async def init(self, _syscall_vars: CallVars, **_kwargs: Any) -> dict[str, Any]: + async def init(self, syscall_vars: CallVars) -> dict[str, Any]: """初始化""" return {} - async def exec(self) -> dict[str, Any]: + async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """执行""" - return EmptyData().model_dump() + output = CallOutputChunk(type=CallOutputType.TEXT, content="") + yield output diff --git a/apps/scheduler/call/facts.py b/apps/scheduler/call/facts.py deleted file mode 100644 index bec6f253e..000000000 --- a/apps/scheduler/call/facts.py +++ /dev/null @@ -1,52 +0,0 @@ -"""提取事实工具""" -from typing import Any, ClassVar - -from pydantic import BaseModel, Field - -from apps.entities.scheduler import CallVars -from apps.llm.patterns.domain import Domain -from apps.llm.patterns.facts import Facts -from apps.manager.user_domain import UserDomainManager -from apps.scheduler.call.core import CoreCall - - -class FactsOutput(BaseModel): - """提取事实工具的输出""" - - output: list[str] = Field(description="提取的事实") - - -class FactsCall(CoreCall, ret_type=FactsOutput): - """提取事实工具""" - - name: ClassVar[str] = "提取事实" - description: ClassVar[str] = "从对话上下文和文档片段中提取事实。" - - - async def init(self, syscall_vars: CallVars, **kwargs: Any) -> dict[str, Any]: - """初始化工具""" - # 组装必要变量 - self._message = { - "question": syscall_vars.question, - "answer": kwargs["answer"], - } - self._task_id = syscall_vars.task_id - self._user_sub = syscall_vars.user_sub - - return { - "message": self._message, - "task_id": self._task_id, - "user_sub": self._user_sub, - } - - - async def exec(self) -> dict[str, Any]: - """执行工具""" - # 提取事实信息 - facts = await Facts().generate(self._task_id, message=self._message) - # 更新用户画像 - domain_list = await Domain().generate(self._task_id, message=self._message) - for domain in domain_list: - await UserDomainManager.update_user_domain_by_user_sub_and_domain_name(self._user_sub, domain) - - return FactsOutput(output=facts).model_dump(exclude_none=True, by_alias=True) diff --git a/apps/scheduler/call/facts/facts.py b/apps/scheduler/call/facts/facts.py new file mode 100644 index 000000000..a45216c8c --- /dev/null +++ b/apps/scheduler/call/facts/facts.py @@ -0,0 +1,58 @@ +"""提取事实工具""" +from collections.abc import AsyncGenerator +from typing import Annotated, Any, ClassVar + +from pydantic import Field + +from apps.entities.enum_var import CallOutputType +from apps.entities.scheduler import CallOutputChunk, CallVars +from apps.llm.patterns.domain import Domain +from apps.llm.patterns.facts import Facts +from apps.manager.user_domain import UserDomainManager +from apps.scheduler.call.core import CoreCall 
+from apps.scheduler.call.facts.schema import FactsInput, FactsMessage, FactsOutput + + +class FactsCall(CoreCall, input_type=FactsInput, output_type=FactsOutput): + """提取事实工具""" + + name: ClassVar[Annotated[str, Field(description="工具名称", exclude=True, frozen=True)]] = "提取事实" + description: ClassVar[Annotated[str, Field(description="工具描述", exclude=True, frozen=True)]] = "从对话上下文和文档片段中提取事实。" + + answer: str = Field(description="用户输入") + + + async def init(self, syscall_vars: CallVars) -> dict[str, Any]: + """初始化工具""" + await super().init(syscall_vars) + + # 组装必要变量 + message = FactsMessage( + question=syscall_vars.question, + answer=self.answer, + ) + + return FactsInput( + task_id=syscall_vars.task_id, + user_sub=syscall_vars.user_sub, + message=message, + ).model_dump(exclude_none=True, by_alias=True) + + + async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: + """执行工具""" + data = FactsInput(**input_data) + # 提取事实信息 + facts = await Facts().generate(data.task_id, message=data.message) + # 更新用户画像 + domain_list = await Domain().generate(data.task_id, message=data.message) + for domain in domain_list: + await UserDomainManager.update_user_domain_by_user_sub_and_domain_name(data.user_sub, domain) + + yield CallOutputChunk( + type=CallOutputType.DATA, + content=FactsOutput( + output=facts, + domain=domain_list, + ).model_dump(by_alias=True, exclude_none=True), + ) diff --git a/apps/scheduler/call/facts/schema.py b/apps/scheduler/call/facts/schema.py new file mode 100644 index 000000000..abbfe6655 --- /dev/null +++ b/apps/scheduler/call/facts/schema.py @@ -0,0 +1,27 @@ +"""Facts工具的输入和输出""" + +from pydantic import BaseModel, Field + +from apps.scheduler.call.core import DataBase + + +class FactsMessage(BaseModel): + """提取事实工具的消息""" + + question: str = Field(description="用户输入") + answer: str = Field(description="用户输入") + + +class FactsInput(DataBase): + """提取事实工具的输入""" + + task_id: str = Field(description="任务ID") + user_sub: str = Field(description="用户ID") + message: FactsMessage = Field(description="消息") + + +class FactsOutput(DataBase): + """提取事实工具的输出""" + + output: list[str] = Field(description="提取的事实") + domain: list[str] = Field(description="提取的领域") diff --git a/apps/scheduler/call/render/__init__.py b/apps/scheduler/call/graph/__init__.py similarity index 100% rename from apps/scheduler/call/render/__init__.py rename to apps/scheduler/call/graph/__init__.py diff --git a/apps/scheduler/call/render/render.py b/apps/scheduler/call/graph/graph.py similarity index 48% rename from apps/scheduler/call/render/render.py rename to apps/scheduler/call/graph/graph.py index 09bc3b8a0..e480afa39 100644 --- a/apps/scheduler/call/render/render.py +++ b/apps/scheduler/call/graph/graph.py @@ -3,115 +3,87 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
""" import json -from pathlib import Path -from typing import Any +from collections.abc import AsyncGenerator +from typing import Any, ClassVar -from pydantic import BaseModel, Field +from anyio import Path +from pydantic import Field -from apps.entities.scheduler import CallError, CallVars +from apps.entities.enum_var import CallOutputType +from apps.entities.scheduler import CallError, CallOutputChunk, CallVars from apps.scheduler.call.core import CoreCall -from apps.scheduler.call.render.style import RenderStyle +from apps.scheduler.call.graph.schema import RenderFormat, RenderInput, RenderOutput +from apps.scheduler.call.graph.style import RenderStyle -class _RenderAxis(BaseModel): - """ECharts图表的轴配置""" - - type: str = Field(description="轴的类型") - axisTick: dict = Field(description="轴刻度配置") # noqa: N815 - - -class _RenderFormat(BaseModel): - """ECharts图表配置""" - - tooltip: dict[str, Any] = Field(description="ECharts图表的提示框配置") - legend: dict[str, Any] = Field(description="ECharts图表的图例配置") - dataset: dict[str, Any] = Field(description="ECharts图表的数据集配置") - xAxis: _RenderAxis = Field(description="ECharts图表的X轴配置") # noqa: N815 - yAxis: _RenderAxis = Field(description="ECharts图表的Y轴配置") # noqa: N815 - series: list[dict[str, Any]] = Field(description="ECharts图表的数据列配置") - - -class _RenderOutput(BaseModel): - """Render工具的输出""" - - output: _RenderFormat = Field(description="ECharts图表配置") - message: str = Field(description="Render工具的输出") - - -class _RenderParam(BaseModel): - """Render工具的参数""" - - - -class Render(metaclass=CoreCall, param_cls=_RenderParam, output_cls=_RenderOutput): +class Graph(CoreCall, output_type=RenderOutput): """Render Call,用于将SQL Tool查询出的数据转换为图表""" - name: str = "render" - description: str = "渲染图表工具,可将给定的数据绘制为图表。" + name: ClassVar[str] = Field("图表", exclude=True, frozen=True) + description: ClassVar[str] = Field("将SQL查询出的数据转换为图表", exclude=True, frozen=True) + + data: list[dict[str, Any]] = Field("用于绘制图表的数据", exclude=True, frozen=True) - def init(self, _syscall_vars: CallVars, **_kwargs) -> None: # noqa: ANN003 + async def init(self, syscall_vars: CallVars) -> dict[str, Any]: """初始化Render Call,校验参数,读取option模板""" + await super().init(syscall_vars) + try: option_location = Path(__file__).parent / "option.json" - with Path(option_location).open(encoding="utf-8") as f: - self._option_template = json.load(f) + f = await Path(option_location).open(encoding="utf-8") + data = await f.read() + self._option_template = json.loads(data) + await f.aclose() except Exception as e: raise CallError(message=f"图表模板读取失败:{e!s}", data={}) from e + return RenderInput( + question=syscall_vars.question, + task_id=syscall_vars.task_id, + data=self.data, + ).model_dump(exclude_none=True, by_alias=True) - async def __call__(self, _slot_data: dict[str, Any]) -> _RenderOutput: - """运行Render Call""" - # 获取必要参数 - syscall_vars: CallVars = getattr(self, "_syscall_vars") - - # 检测前一个工具是否为SQL - if "dataset" not in syscall_vars.history[-1].output_data: - raise CallError( - message="图表生成失败!Render必须在SQL后调用!", - data={}, - ) - data = syscall_vars.history[-1].output_data["output"] - if data["type"] != "sql" or "dataset" not in data: - raise CallError( - message="图表生成失败!Render必须在SQL后调用!", - data={}, - ) - data = json.loads(data["dataset"]) + async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: + """运行Render Call""" + data = RenderInput(**input_data) # 判断数据格式是否满足要求 # 样例:[{'openeuler_version': 'openEuler-22.03-LTS-SP2', '软件数量': 10}] malformed = True - if isinstance(data, list) and 
len(data) > 0 and isinstance(data[0], dict): + if isinstance(data.data, list) and len(data.data) > 0 and isinstance(data.data[0], dict): malformed = False # 将执行SQL工具查询到的数据转换为特定格式 if malformed: raise CallError( - message="SQL未查询到数据,或数据格式错误,无法生成图表!", - data={"data": data}, + message="数据格式错误,无法生成图表!", + data={"data": data.data}, ) # 对少量数据进行处理 - column_num = len(data[0]) - 1 + processed_data = data.data + column_num = len(processed_data[0]) - 1 if column_num == 0: - data = Render._separate_key_value(data) + processed_data = Graph._separate_key_value(processed_data) column_num = 1 # 该格式满足ECharts DataSource要求,与option模板进行拼接 - self._option_template["dataset"]["source"] = data + self._option_template["dataset"]["source"] = processed_data try: - llm_output = await RenderStyle().generate(syscall_vars.task_id, question=syscall_vars.question) + llm_output = await RenderStyle().generate(data.task_id, question=data.question) add_style = llm_output.get("additional_style", "") self._parse_options(column_num, llm_output["chart_type"], add_style, llm_output["scale_type"]) except Exception as e: raise CallError(message=f"图表生成失败:{e!s}", data={"data": data}) from e - return _RenderOutput( - output=_RenderFormat.model_validate(self._option_template), - message="图表生成成功!图表将使用外置工具进行展示。", + yield CallOutputChunk( + type=CallOutputType.DATA, + content=RenderOutput( + output=RenderFormat.model_validate(self._option_template), + ).model_dump(exclude_none=True, by_alias=True), ) diff --git a/apps/scheduler/call/render/option.json b/apps/scheduler/call/graph/option.json similarity index 100% rename from apps/scheduler/call/render/option.json rename to apps/scheduler/call/graph/option.json diff --git a/apps/scheduler/call/graph/schema.py b/apps/scheduler/call/graph/schema.py new file mode 100644 index 000000000..6f26c5224 --- /dev/null +++ b/apps/scheduler/call/graph/schema.py @@ -0,0 +1,39 @@ +"""图表工具的输入输出""" + +from typing import Any + +from pydantic import BaseModel, Field + +from apps.scheduler.call.core import DataBase + + +class RenderAxis(BaseModel): + """ECharts图表的轴配置""" + + type: str = Field(description="轴的类型") + axisTick: dict = Field(description="轴刻度配置") # noqa: N815 + + +class RenderFormat(BaseModel): + """ECharts图表配置""" + + tooltip: dict[str, Any] = Field(description="ECharts图表的提示框配置") + legend: dict[str, Any] = Field(description="ECharts图表的图例配置") + dataset: dict[str, Any] = Field(description="ECharts图表的数据集配置") + xAxis: RenderAxis = Field(description="ECharts图表的X轴配置") # noqa: N815 + yAxis: RenderAxis = Field(description="ECharts图表的Y轴配置") # noqa: N815 + series: list[dict[str, Any]] = Field(description="ECharts图表的数据列配置") + + +class RenderInput(DataBase): + """图表工具的输入""" + + question: str = Field(description="用户输入") + task_id: str = Field(description="任务ID") + data: list[dict[str, Any]] = Field(description="图表数据") + + +class RenderOutput(DataBase): + """Render工具的输出""" + + output: RenderFormat = Field(description="ECharts图表配置") diff --git a/apps/scheduler/call/render/style.py b/apps/scheduler/call/graph/style.py similarity index 100% rename from apps/scheduler/call/render/style.py rename to apps/scheduler/call/graph/style.py diff --git a/apps/scheduler/call/llm.py b/apps/scheduler/call/llm/llm.py similarity index 63% rename from apps/scheduler/call/llm.py rename to apps/scheduler/call/llm/llm.py index 290bf6990..7af2ab324 100644 --- a/apps/scheduler/call/llm.py +++ b/apps/scheduler/call/llm/llm.py @@ -5,34 +5,31 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
import logging from collections.abc import AsyncGenerator from datetime import datetime -from typing import Any, ClassVar +from typing import Annotated, Any, ClassVar import pytz from jinja2 import BaseLoader, select_autoescape from jinja2.sandbox import SandboxedEnvironment -from pydantic import BaseModel, Field +from pydantic import Field -from apps.constants import LLM_CONTEXT_PROMPT, LLM_DEFAULT_PROMPT -from apps.entities.scheduler import CallError, CallVars +from apps.entities.enum_var import CallOutputType +from apps.entities.scheduler import CallError, CallOutputChunk, CallVars from apps.llm.reasoning import ReasoningLLM from apps.scheduler.call.core import CoreCall +from apps.scheduler.call.llm.schema import LLM_CONTEXT_PROMPT, LLM_DEFAULT_PROMPT, LLMInput, LLMOutput logger = logging.getLogger("ray") -class LLMNodeOutput(BaseModel): - """定义LLM工具调用的输出""" - message: str = Field(description="大模型输出的文字信息") - -class LLM(CoreCall, ret_type=LLMNodeOutput): +class LLM(CoreCall, input_type=LLMInput, output_type=LLMOutput): """大模型调用工具""" - name: ClassVar[str] = "大模型" - description: ClassVar[str] = "以指定的提示词和上下文信息调用大模型,并获得输出。" + name: ClassVar[Annotated[str, Field(description="工具名称", exclude=True, frozen=True)]] = "大模型" + description: ClassVar[Annotated[str, Field(description="工具描述", exclude=True, frozen=True)]] = "以指定的提示词和上下文信息调用大模型,并获得输出。" temperature: float = Field(description="大模型温度(随机化程度)", default=0.7) enable_context: bool = Field(description="是否启用上下文", default=True) - step_history_size: int = Field(description="上下文信息中包含的步骤历史数量", default=3) + step_history_size: int = Field(description="上下文信息中包含的步骤历史数量", default=3, ge=1, le=10) system_prompt: str = Field(description="大模型系统提示词", default="") user_prompt: str = Field(description="大模型用户提示词", default=LLM_DEFAULT_PROMPT) @@ -83,33 +80,23 @@ class LLM(CoreCall, ret_type=LLMNodeOutput): ] - async def init(self, syscall_vars: CallVars, **_kwargs: Any) -> dict[str, Any]: + async def init(self, syscall_vars: CallVars) -> dict[str, Any]: """初始化LLM工具""" - self._message = await self._prepare_message(syscall_vars) - self._task_id = syscall_vars.task_id - return { - "task_id": self._task_id, - "message": self._message, - } + await super().init(syscall_vars) + return LLMInput( + task_id=syscall_vars.task_id, + message=await self._prepare_message(syscall_vars), + ).model_dump(exclude_none=True, by_alias=True) - async def exec(self) -> dict[str, Any]: + + async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """运行LLM Call""" + data = LLMInput(**input_data) try: - result = "" - async for chunk in ReasoningLLM().call(task_id=self._task_id, messages=self._message): - result += chunk + async for chunk in ReasoningLLM().call(task_id=data.task_id, messages=data.message): + if not chunk: + continue + yield CallOutputChunk(type=CallOutputType.TEXT, content=chunk) except Exception as e: raise CallError(message=f"大模型调用失败:{e!s}", data={}) from e - - result = result.strip().strip("\n") - return LLMNodeOutput(message=result).model_dump(exclude_none=True, by_alias=True) - - - async def stream(self) -> AsyncGenerator[str, None]: - """流式输出""" - try: - async for chunk in ReasoningLLM().call(task_id=self._task_id, messages=self._message): - yield chunk - except Exception as e: - raise CallError(message=f"大模型流式输出失败:{e!s}", data={}) from e diff --git a/apps/scheduler/call/llm/schema.py b/apps/scheduler/call/llm/schema.py new file mode 100644 index 000000000..08bd4fabe --- /dev/null +++ b/apps/scheduler/call/llm/schema.py @@ -0,0 +1,107 @@ 
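+# 本模块集中定义 LLM 工具的提示词模板与输入输出模型:
+# - LLM_CONTEXT_PROMPT、LLM_DEFAULT_PROMPT、LLM_ERROR_PROMPT、RAG_ANSWER_PROMPT 为 Jinja2 风格模板,
+#   其中的 {{ question }}、{{ context }} 等变量在 llm/llm.py 组装消息时渲染填充;
+# - LLMInput / LLMOutput 继承 DataBase,通过 CoreCall 的 input_type / output_type 与 LLM 工具关联。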
+"""LLM工具的输入输出定义""" +from textwrap import dedent + +from pydantic import Field + +from apps.scheduler.call.core import DataBase + +LLM_CONTEXT_PROMPT = dedent( + r""" + 以下是对用户和AI间对话的简短总结: + {{ summary }} + + 你作为AI,在回答用户的问题前,需要获取必要信息。为此,你调用了以下工具,并获得了输出: + + {% for tool in history_data %} + {{ tool.step_id }} + {{ tool.output_data }} + {% endfor %} + + """, + ).strip("\n") +LLM_DEFAULT_PROMPT = dedent( + r""" + + 你是一个乐于助人的智能助手。请结合给出的背景信息, 回答用户的提问。 + 当前时间:{{ time }},可以作为时间参照。 + 用户的问题将在中给出,上下文背景信息将在中给出。 + 注意:输出不要包含任何XML标签,不要编造任何信息。若你认为用户提问与背景信息无关,请忽略背景信息直接作答。 + + + + {{ question }} + + + + {{ context }} + + + 现在,输出你的回答: + """, + ).strip("\n") +LLM_ERROR_PROMPT = dedent( + r""" + + 你是一位智能助手,能够根据用户的问题,使用Python工具获取信息,并作出回答。你在使用工具解决回答用户的问题时,发生了错误。 + 你的任务是:分析工具(Python程序)的异常信息,分析造成该异常可能的原因,并以通俗易懂的方式,将原因告知用户。 + + 当前时间:{{ time }},可以作为时间参照。 + 发生错误的程序异常信息将在中给出,用户的问题将在中给出,上下文背景信息将在中给出。 + 注意:输出不要包含任何XML标签,不要编造任何信息。若你认为用户提问与背景信息无关,请忽略背景信息。 + + + + {{ error_info }} + + + + {{ question }} + + + + {{ context }} + + + 现在,输出你的回答: + """, +).strip("\n") +RAG_ANSWER_PROMPT = dedent( + r""" + + 你是由openEuler社区构建的大型语言AI助手。请根据背景信息(包含对话上下文和文档片段),回答用户问题。 + 用户的问题将在中给出,上下文背景信息将在中给出,文档片段将在中给出。 + + 注意事项: + 1. 输出不要包含任何XML标签。请确保输出内容的正确性,不要编造任何信息。 + 2. 如果用户询问你关于你自己的问题,请统一回答:“我叫EulerCopilot,是openEuler社区的智能助手”。 + 3. 背景信息仅供参考,若背景信息与用户问题无关,请忽略背景信息直接作答。 + 4. 请在回答中使用Markdown格式,并**不要**将内容放在"```"中。 + + + + {{ question }} + + + + {{ context }} + + + + {{ document }} + + + 现在,请根据上述信息,回答用户的问题: + """, + ).strip("\n") + + +class LLMInput(DataBase): + """定义LLM工具调用的输入""" + + task_id: str = Field(description="任务ID") + message: list[dict[str, str]] = Field(description="输入给大模型的消息列表") + + +class LLMOutput(DataBase): + """定义LLM工具调用的输出""" + diff --git a/apps/scheduler/call/output/output.py b/apps/scheduler/call/output/output.py new file mode 100644 index 000000000..23e21fed0 --- /dev/null +++ b/apps/scheduler/call/output/output.py @@ -0,0 +1,41 @@ +"""输出工具:将文字和结构化数据输出至前端""" +from collections.abc import AsyncGenerator +from typing import Annotated, Any, ClassVar + +from pydantic import Field +from ray.actor import ActorHandle + +from apps.entities.enum_var import CallOutputType +from apps.entities.scheduler import CallOutputChunk, CallVars +from apps.scheduler.call.core import CoreCall +from apps.scheduler.call.output.schema import OutputInput, OutputOutput + + +class Output(CoreCall, input_type=OutputInput, output_type=OutputOutput): + """输出工具""" + + name: ClassVar[Annotated[str, Field(description="工具名称", exclude=True, frozen=True)]] = "输出" + description: ClassVar[Annotated[str, Field(description="工具描述", exclude=True, frozen=True)]] = "将文字和结构化数据输出至前端" + + template: str = Field(description="输出模板(只能使用直接变量引用)") + + + async def init(self, syscall_vars: CallVars) -> dict[str, Any]: + """初始化工具""" + self._text: AsyncGenerator[str, None] = kwargs["text"] + self._data: dict[str, Any] = kwargs["data"] + + return { + "text": self._text, + "data": self._data, + } + + + async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: + """执行工具""" + yield CallOutputChunk( + type=CallOutputType.TEXT, + content=OutputOutput( + output=self._text, + ).model_dump(by_alias=True, exclude_none=True), + ) diff --git a/apps/scheduler/call/output/schema.py b/apps/scheduler/call/output/schema.py new file mode 100644 index 000000000..468c07ed4 --- /dev/null +++ b/apps/scheduler/call/output/schema.py @@ -0,0 +1,15 @@ +"""输出工具的数据结构""" + +from pydantic import Field + +from apps.scheduler.call.core import DataBase + + +class OutputInput(DataBase): + 
"""输出工具的输入""" + + +class OutputOutput(DataBase): + """输出工具的输出""" + + output: str = Field(description="输出工具的输出") diff --git a/apps/scheduler/call/rag.py b/apps/scheduler/call/rag/rag.py similarity index 43% rename from apps/scheduler/call/rag.py rename to apps/scheduler/call/rag/rag.py index 4c0e22787..6a18df78f 100644 --- a/apps/scheduler/call/rag.py +++ b/apps/scheduler/call/rag/rag.py @@ -3,70 +3,74 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ import logging -from typing import Any, ClassVar, Literal, Optional +from collections.abc import AsyncGenerator +from typing import Annotated, Any, ClassVar, Literal, Optional import aiohttp from fastapi import status -from pydantic import BaseModel, Field +from pydantic import Field from apps.common.config import config -from apps.entities.scheduler import CallError, CallVars +from apps.entities.enum_var import CallOutputType +from apps.entities.scheduler import CallError, CallOutputChunk, CallVars +from apps.llm.patterns.rewrite import QuestionRewrite from apps.scheduler.call.core import CoreCall +from apps.scheduler.call.rag.schema import RAGInput, RAGOutput, RetrievalMode logger = logging.getLogger("ray") -class RAGOutput(BaseModel): - """RAG工具的输出""" - - corpus: list[str] = Field(description="知识库的语料列表") - - -class RAG(CoreCall, ret_type=RAGOutput): +class RAG(CoreCall, input_type=RAGInput, output_type=RAGOutput): """RAG工具:查询知识库""" - name: ClassVar[str] = "知识库" - description: ClassVar[str] = "查询知识库,从文档中获取必要信息" + name: ClassVar[Annotated[str, Field(description="工具名称", exclude=True, frozen=True)]] = "知识库" + description: ClassVar[Annotated[str, Field(description="工具描述", exclude=True, frozen=True)]] = "查询知识库,从文档中获取必要信息" knowledge_base: Optional[str] = Field(description="知识库的id", alias="kb_sn", default=None) top_k: int = Field(description="返回的答案数量(经过整合以及上下文关联)", default=5) retrieval_mode: Literal["chunk", "full_text"] = Field(description="检索模式", default="chunk") - async def init(self, syscall_vars: CallVars, **_kwargs: Any) -> dict[str, Any]: + async def init(self, syscall_vars: CallVars) -> dict[str, Any]: """初始化RAG工具""" - self._params_dict = { - "kb_sn": self.knowledge_base, - "top_k": self.top_k, - "retrieval_mode": self.retrieval_mode, - "content": syscall_vars.question, - } - - self._url = config["RAG_HOST"].rstrip("/") + "/chunk/get" - self._headers = { - "Content-Type": "application/json", - } + await super().init(syscall_vars) - return self._params_dict + self._task_id = syscall_vars.task_id + return RAGInput( + question=syscall_vars.question, + kb_sn=self.knowledge_base, + top_k=self.top_k, + retrieval_mode=RetrievalMode(self.retrieval_mode), + ).model_dump(by_alias=True, exclude_none=True) - async def exec(self) -> dict[str, Any]: + async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """调用RAG工具""" + data = RAGInput(**input_data) + question = await QuestionRewrite().generate(self._task_id, question=data.question) + data.question = question + + url = config["RAG_HOST"].rstrip("/") + "/chunk/get" + headers = { + "Content-Type": "application/json", + } + # 发送 GET 请求 - async with aiohttp.ClientSession() as session, session.post(self._url, headers=self._headers, json=self._params_dict) as response: + async with aiohttp.ClientSession() as session, session.post(url, headers=headers, json=input_data) as response: # 检查响应状态码 if response.status == status.HTTP_200_OK: result = await response.json() chunk_list = result["data"] - corpus = [] for chunk in chunk_list: clean_chunk = 
chunk.replace("\n", " ") - corpus.append(clean_chunk) - - return RAGOutput( - corpus=corpus, - ).model_dump(exclude_none=True, by_alias=True) + yield CallOutputChunk( + type=CallOutputType.DATA, + content=RAGOutput( + question=data.question, + corpus=clean_chunk, + ).model_dump(exclude_none=True, by_alias=True), + ) text = await response.text() logger.error("[RAG] 调用失败:%s", text) @@ -74,7 +78,7 @@ class RAG(CoreCall, ret_type=RAGOutput): raise CallError( message=f"rag调用失败:{text}", data={ - "question": self._params_dict["content"], + "question": data.question, "status": response.status, "text": text, }, diff --git a/apps/scheduler/call/rag/schema.py b/apps/scheduler/call/rag/schema.py new file mode 100644 index 000000000..81cc84c82 --- /dev/null +++ b/apps/scheduler/call/rag/schema.py @@ -0,0 +1,31 @@ +"""RAG工具的输入和输出""" + +from enum import Enum +from typing import Optional + +from pydantic import Field + +from apps.scheduler.call.core import DataBase + + +class RetrievalMode(str, Enum): + """检索模式""" + + CHUNK = "chunk" + FULL_TEXT = "full_text" + + +class RAGOutput(DataBase): + """RAG工具的输出""" + + question: str = Field(description="用户输入") + corpus: str = Field(description="知识库的语料列表") + + +class RAGInput(DataBase): + """RAG工具的输入""" + + question: str = Field(description="用户输入") + knowledge_base: Optional[str] = Field(description="知识库的id", alias="kb_sn", default=None) + top_k: int = Field(description="返回的答案数量(经过整合以及上下文关联)", default=5) + retrieval_mode: RetrievalMode = Field(description="检索模式", default=RetrievalMode.CHUNK) diff --git a/apps/scheduler/call/search/search.py b/apps/scheduler/call/search/search.py new file mode 100644 index 000000000..e2be72baa --- /dev/null +++ b/apps/scheduler/call/search/search.py @@ -0,0 +1,31 @@ +"""搜索工具""" +from typing import Any, ClassVar + +from pydantic import BaseModel, Field + +from apps.entities.scheduler import CallVars +from apps.scheduler.call.core import CoreCall + + +class SearchRet(BaseModel): + """搜索工具返回值""" + + data: list[dict[str, Any]] = Field(description="搜索结果") + + +class Search(CoreCall, ret_type=SearchRet): + """搜索工具""" + + name: ClassVar[str] = "搜索" + description: ClassVar[str] = "获取搜索引擎的结果" + + async def init(self, syscall_vars: CallVars, **kwargs: Any) -> dict[str, Any]: + """初始化工具""" + self._query: str = kwargs["query"] + return {} + + + async def exec(self) -> dict[str, Any]: + """执行工具""" + return {} + diff --git a/apps/scheduler/call/slot/schema.py b/apps/scheduler/call/slot/schema.py new file mode 100644 index 000000000..7d053aa58 --- /dev/null +++ b/apps/scheduler/call/slot/schema.py @@ -0,0 +1,24 @@ +"""参数填充工具的Schema""" +from typing import Any + +from pydantic import Field + +from apps.scheduler.call.core import DataBase + +SLOT_GEN_PROMPT = [ + {"role": "user", "content": ""}, + {"role": "assistant", "content": ""}, +] + + +class SlotInput(DataBase): + """参数填充工具的输入""" + + remaining_schema: dict[str, Any] = Field(description="剩余的Schema", default={}) + + +class SlotOutput(DataBase): + """参数填充工具的输出""" + + slot_data: dict[str, Any] = Field(description="填充后的数据", default={}) + remaining_schema: dict[str, Any] = Field(description="剩余的Schema", default={}) diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py new file mode 100644 index 000000000..7d4367baf --- /dev/null +++ b/apps/scheduler/call/slot/slot.py @@ -0,0 +1,68 @@ +"""参数填充工具""" +from collections.abc import AsyncGenerator +from typing import Annotated, Any, ClassVar + +import ray +from pydantic import Field + +from apps.entities.enum_var import 
CallOutputType +from apps.entities.scheduler import CallOutputChunk +from apps.scheduler.call.core import CoreCall +from apps.scheduler.call.slot.schema import SlotInput, SlotOutput +from apps.scheduler.executor.step import CallVars +from apps.scheduler.slot.slot import Slot as SlotProcessor + + +class Slot(CoreCall, input_type=SlotInput, output_type=SlotOutput): + """参数填充工具""" + + name: ClassVar[Annotated[str, Field(description="Call的名称", exclude=True)]] = "参数自动填充" + description: ClassVar[Annotated[str, Field(description="Call的描述", exclude=True)]] = "根据步骤历史,自动填充参数" + + data: dict[str, Any] = Field(description="当前输入", default={}) + current_schema: dict[str, Any] = Field(description="当前Schema", default={}) + summary: str = Field(description="背景信息总结", default="") + facts: list[str] = Field(description="事实信息", default=[]) + step_num: int = Field(description="历史步骤数", default=1) + + + async def _llm_slot_fill(self, remaining_schema: dict[str, Any]) -> dict[str, Any]: + """使用LLM填充参数""" + ... + + + async def init(self, syscall_vars: CallVars) -> dict[str, Any]: + """初始化""" + await super().init(syscall_vars) + + self._task_id = syscall_vars.task_id + + if not self.current_schema: + return SlotInput( + remaining_schema={}, + ).model_dump(by_alias=True, exclude_none=True) + + self._processor = SlotProcessor(self.current_schema) + remaining_schema = self._processor.check_json(self.data) + + return SlotInput( + remaining_schema=remaining_schema, + ).model_dump(by_alias=True, exclude_none=True) + + + async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: + """执行参数填充""" + data = SlotInput(**input_data) + + # 使用LLM填充参数 + slot_data = await self._llm_slot_fill(data.remaining_schema) + + # 再次检查 + remaining_schema = self._processor.check_json(slot_data) + yield CallOutputChunk( + type=CallOutputType.DATA, + content=SlotOutput( + slot_data=slot_data, + remaining_schema=remaining_schema, + ).model_dump(by_alias=True, exclude_none=True), + ) diff --git a/apps/scheduler/call/sql/schema.py b/apps/scheduler/call/sql/schema.py new file mode 100644 index 000000000..b1dc669db --- /dev/null +++ b/apps/scheduler/call/sql/schema.py @@ -0,0 +1,19 @@ +"""SQL工具的输入输出""" + +from typing import Any + +from pydantic import Field + +from apps.scheduler.call.core import DataBase + + +class SQLInput(DataBase): + """SQL工具的输入""" + + question: str = Field(description="用户输入") + + +class SQLOutput(DataBase): + """SQL工具的输出""" + + dataset: list[dict[str, Any]] = Field(description="SQL工具的执行结果") diff --git a/apps/scheduler/call/sql.py b/apps/scheduler/call/sql/sql.py similarity index 52% rename from apps/scheduler/call/sql.py rename to apps/scheduler/call/sql/sql.py index 312c5faaf..28c763c0b 100644 --- a/apps/scheduler/call/sql.py +++ b/apps/scheduler/call/sql/sql.py @@ -5,65 +5,68 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
""" import json import logging -from typing import Any, Optional +from collections.abc import AsyncGenerator +from typing import Any, ClassVar, Optional import aiohttp from fastapi import status -from pydantic import BaseModel, Field +from pydantic import Field from sqlalchemy import text from apps.common.config import config -from apps.entities.scheduler import CallError, CallVars +from apps.entities.scheduler import CallError, CallOutputChunk, CallOutputType, CallVars from apps.models.postgres import PostgreSQL from apps.scheduler.call.core import CoreCall +from apps.scheduler.call.sql.schema import SQLInput, SQLOutput logger = logging.getLogger("ray") -class SQLOutput(BaseModel): - """SQL工具的输出""" - - message: str = Field(description="SQL工具的执行结果") - dataset: list[dict[str, Any]] = Field(description="SQL工具的执行结果") - - class SQL(CoreCall): """SQL工具。用于调用外置的Chat2DB工具的API,获得SQL语句;再在PostgreSQL中执行SQL语句,获得数据。""" - name: str = "数据库" - description: str = "使用大模型生成SQL语句,用于查询数据库中的结构化数据" + name: ClassVar[str] = Field("数据库", exclude=True) + description: ClassVar[str] = Field("使用大模型生成SQL语句,用于查询数据库中的结构化数据", exclude=True) + sql: Optional[str] = Field(description="用户输入") - def init(self, _syscall_vars: CallVars, **_kwargs) -> None: # noqa: ANN003 + async def init(self, syscall_vars: CallVars) -> dict[str, Any]: """初始化SQL工具。""" - # 初始化aiohttp的ClientSession - self._session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(300)) + await super().init(syscall_vars) + + return SQLInput( + question=syscall_vars.question, + ).model_dump(by_alias=True, exclude_none=True) - async def exec(self, _slot_data: dict[str, Any]) -> SQLOutput: + async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """运行SQL工具""" # 获取必要参数 - syscall_vars: CallVars = getattr(self, "_syscall_vars") + data = SQLInput(**input_data) # 若手动设置了SQL,则直接使用 - session = await PostgreSQL.get_session() - if params.sql: + pg_session = await PostgreSQL.get_session() + if self.sql: try: - result = (await session.execute(text(params.sql))).all() - await session.close() + result = (await pg_session.execute(text(self.sql))).all() + await pg_session.close() dataset_list = [db_item._asdict() for db_item in result] - return SQLOutput( - message="SQL查询成功!", - dataset=dataset_list, + yield CallOutputChunk( + type=CallOutputType.DATA, + content=SQLOutput( + dataset=dataset_list, + ).model_dump(by_alias=True, exclude_none=True), ) except Exception as e: raise CallError(message=f"SQL查询错误:{e!s}", data={}) from e + else: + return # 若未设置SQL,则调用Chat2DB工具API,获取SQL post_data = { - "question": syscall_vars.question, + "question": data.question, "topk_sql": 5, "use_llm_enhancements": True, } @@ -71,7 +74,9 @@ class SQL(CoreCall): "Content-Type": "application/json", } - async with self._session.post(config["SQL_URL"], ssl=False, json=post_data, headers=headers) as response: + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(300)) as session, session.post( + config["SQL_URL"], ssl=False, json=post_data, headers=headers, + ) as response: if response.status != status.HTTP_200_OK: raise CallError( message=f"SQL查询错误:API返回状态码{response.status}, 详细原因为{response.reason}。", @@ -79,20 +84,25 @@ class SQL(CoreCall): ) result = json.loads(await response.text()) logger.info("SQL工具返回的信息为:%s", result) - await self._session.close() + for item in result["sql_list"]: try: - db_result = (await session.execute(text(item["sql"]))).all() - await session.close() + db_result = (await pg_session.execute(text(item["sql"]))).all() + await 
pg_session.close() dataset_list = [db_item._asdict() for db_item in db_result] - return SQLOutput( - message="数据库查询成功!", - dataset=dataset_list, + yield CallOutputChunk( + type=CallOutputType.DATA, + content=SQLOutput( + dataset=dataset_list, + ).model_dump(by_alias=True, exclude_none=True), ) except Exception: # noqa: PERF203 logger.exception("SQL查询错误,正在换用下一条SQL语句。") + continue + else: + return raise CallError( message="SQL查询错误:SQL语句错误,数据库查询失败!", diff --git a/apps/scheduler/call/suggest/schema.py b/apps/scheduler/call/suggest/schema.py new file mode 100644 index 000000000..91c34bfd8 --- /dev/null +++ b/apps/scheduler/call/suggest/schema.py @@ -0,0 +1,37 @@ +"""问题推荐工具的输入输出""" + +from typing import Optional + +from pydantic import BaseModel, Field + +from apps.scheduler.call.core import DataBase + + +class SingleFlowSuggestionConfig(BaseModel): + """涉及单个Flow的问题推荐配置""" + + flow_id: str + question: Optional[str] = Field(default=None, description="固定的推荐问题") + + +class SuggestionOutputItem(BaseModel): + """问题推荐结果的单个条目""" + + question: str + app_id: str + flow_id: str + flow_description: str + + +class SuggestionInput(DataBase): + """问题推荐输入""" + + question: str + task_id: str + user_sub: str + + +class SuggestionOutput(DataBase): + """问题推荐结果""" + + output: list[SuggestionOutputItem] diff --git a/apps/scheduler/call/suggest.py b/apps/scheduler/call/suggest/suggest.py similarity index 38% rename from apps/scheduler/call/suggest.py rename to apps/scheduler/call/suggest/suggest.py index 96181c758..ba031c3c8 100644 --- a/apps/scheduler/call/suggest.py +++ b/apps/scheduler/call/suggest/suggest.py @@ -2,72 +2,54 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ -from typing import Any, ClassVar, Optional +from collections.abc import AsyncGenerator +from typing import Annotated, Any, ClassVar import ray -from pydantic import BaseModel, Field +from pydantic import Field -from apps.entities.scheduler import CallError, CallVars +from apps.entities.scheduler import CallError, CallOutputChunk, CallVars from apps.manager.user_domain import UserDomainManager from apps.scheduler.call.core import CoreCall +from apps.scheduler.call.suggest.schema import ( + SingleFlowSuggestionConfig, + SuggestionInput, + SuggestionOutput, +) -class _SingleFlowSuggestionConfig(BaseModel): - """涉及单个Flow的问题推荐配置""" - - flow_id: str - question: Optional[str] = Field(default=None, description="固定的推荐问题") - - -class _SuggestionOutputItem(BaseModel): - """问题推荐结果的单个条目""" - - question: str - app_id: str - flow_id: str - flow_description: str - - -class SuggestionOutput(BaseModel): - """问题推荐结果""" - - output: list[_SuggestionOutputItem] - - -class Suggestion(CoreCall, ret_type=SuggestionOutput): +class Suggestion(CoreCall, input_type=SuggestionInput, output_type=SuggestionOutput): """问题推荐""" - name: ClassVar[str] = "问题推荐" - description: ClassVar[str] = "在答案下方显示推荐的下一个问题" + name: ClassVar[Annotated[str, Field(description="工具名称", exclude=True, frozen=True)]] = "问题推荐" + description: ClassVar[Annotated[str, Field(description="工具描述", exclude=True, frozen=True)]] = "在答案下方显示推荐的下一个问题" - configs: list[_SingleFlowSuggestionConfig] = Field(description="问题推荐配置", min_length=1) + configs: list[SingleFlowSuggestionConfig] = Field(description="问题推荐配置", min_length=1) num: int = Field(default=3, ge=1, le=6, description="推荐问题的总数量(必须大于等于configs中涉及的Flow的数量)") + context: list[dict[str, str]] + - async def init(self, syscall_vars: CallVars, **_kwargs: Any) -> dict[str, Any]: + async def init(self, syscall_vars: CallVars) -> 
dict[str, Any]: """初始化""" - # 普通变量 - self._question = syscall_vars.question - self._task_id = syscall_vars.task_id - self._background = syscall_vars.summary - self._user_sub = syscall_vars.user_sub + await super().init(syscall_vars) - return { - "question": self._question, - "task_id": self._task_id, - "background": self._background, - "user_sub": self._user_sub, - } + return SuggestionInput( + question=syscall_vars.question, + task_id=syscall_vars.task_id, + user_sub=syscall_vars.user_sub, + ).model_dump(by_alias=True, exclude_none=True) - async def exec(self) -> dict[str, Any]: + async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """运行问题推荐""" + data = SuggestionInput(**input_data) # 获取当前任务 task_actor = await ray.get_actor("task") - task = await task_actor.get_task.remote(self._task_id) + task = await task_actor.get_task.remote(data.task_id) # 获取当前用户的画像 - user_domain = await UserDomainManager.get_user_domain_by_user_sub_and_topk(self._user_sub, 5) + user_domain = await UserDomainManager.get_user_domain_by_user_sub_and_topk(data.user_sub, 5) current_record = [ { diff --git a/apps/scheduler/call/summary.py b/apps/scheduler/call/summary.py deleted file mode 100644 index 27b18280e..000000000 --- a/apps/scheduler/call/summary.py +++ /dev/null @@ -1,43 +0,0 @@ -"""总结上下文工具 - -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. -""" -from typing import Any, ClassVar - -from pydantic import BaseModel, Field - -from apps.entities.scheduler import CallVars, ExecutorBackground -from apps.llm.patterns.executor import ExecutorSummary -from apps.scheduler.call.core import CoreCall - - -class SummaryOutput(BaseModel): - """总结工具的输出""" - - summary: str = Field(description="对问答上下文的总结内容") - - -class Summary(CoreCall, ret_type=SummaryOutput): - """总结工具""" - - name: ClassVar[str] = "理解上下文" - description: ClassVar[str] = "使用大模型,理解对话上下文" - - - async def init(self, syscall_vars: CallVars, **kwargs: Any) -> dict[str, Any]: - """初始化工具""" - self._context: ExecutorBackground = kwargs["context"] - self._task_id: str = syscall_vars.task_id - if not self._context or not isinstance(self._context, list): - err = "context参数必须是一个列表" - raise ValueError(err) - - return { - "task_id": self._task_id, - } - - - async def exec(self) -> dict[str, Any]: - """执行工具""" - summary = await ExecutorSummary().generate(self._task_id, background=self._context) - return SummaryOutput(summary=summary).model_dump(by_alias=True, exclude_none=True) diff --git a/apps/scheduler/call/summary/schema.py b/apps/scheduler/call/summary/schema.py new file mode 100644 index 000000000..7247fd335 --- /dev/null +++ b/apps/scheduler/call/summary/schema.py @@ -0,0 +1,17 @@ +"""总结工具的输入和输出""" + +from pydantic import Field + +from apps.scheduler.call.core import DataBase + + +class SummaryInput(DataBase): + """总结工具的输入""" + + task_id: str = Field(description="任务ID") + + +class SummaryOutput(DataBase): + """总结工具的输出""" + + summary: str = Field(description="对问答上下文的总结内容") diff --git a/apps/scheduler/call/summary/summary.py b/apps/scheduler/call/summary/summary.py new file mode 100644 index 000000000..52736cbe9 --- /dev/null +++ b/apps/scheduler/call/summary/summary.py @@ -0,0 +1,39 @@ +"""总结上下文工具 + +Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+""" +from collections.abc import AsyncGenerator +from typing import Any, ClassVar + +from pydantic import Field + +from apps.entities.enum_var import CallOutputType +from apps.entities.scheduler import CallOutputChunk, CallVars, ExecutorBackground +from apps.llm.patterns.executor import ExecutorSummary +from apps.scheduler.call.core import CoreCall +from apps.scheduler.call.summary.schema import SummaryInput, SummaryOutput + + +class Summary(CoreCall, input_type=SummaryInput, output_type=SummaryOutput): + """总结工具""" + + name: ClassVar[str] = Field("理解上下文", exclude=True) + description: ClassVar[str] = Field("使用大模型,理解对话上下文", exclude=True) + + context: ExecutorBackground = Field(description="对话上下文") + + + async def init(self, syscall_vars: CallVars) -> dict[str, Any]: + """初始化工具""" + await super().init(syscall_vars) + + return SummaryInput( + task_id=syscall_vars.task_id, + ).model_dump(by_alias=True, exclude_none=True) + + + async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: + """执行工具""" + data = SummaryInput(**input_data) + summary = await ExecutorSummary().generate(data.task_id, background=self.context) + yield CallOutputChunk(type=CallOutputType.TEXT, content=summary) -- Gitee From 2426b35917ea56638f31e0ce0c55f271209b879d Mon Sep 17 00:00:00 2001 From: z30057876 Date: Tue, 18 Mar 2025 08:27:37 +0800 Subject: [PATCH 6/8] =?UTF-8?q?=E4=BF=AE=E6=AD=A3RAG=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/call/rag/rag.py | 26 +++++++++++++++----------- apps/scheduler/call/rag/schema.py | 4 ++-- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/apps/scheduler/call/rag/rag.py b/apps/scheduler/call/rag/rag.py index 6a18df78f..c3eadaab9 100644 --- a/apps/scheduler/call/rag/rag.py +++ b/apps/scheduler/call/rag/rag.py @@ -37,7 +37,7 @@ class RAG(CoreCall, input_type=RAGInput, output_type=RAGOutput): self._task_id = syscall_vars.task_id return RAGInput( - question=syscall_vars.question, + content=syscall_vars.question, kb_sn=self.knowledge_base, top_k=self.top_k, retrieval_mode=RetrievalMode(self.retrieval_mode), @@ -47,8 +47,8 @@ class RAG(CoreCall, input_type=RAGInput, output_type=RAGOutput): async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """调用RAG工具""" data = RAGInput(**input_data) - question = await QuestionRewrite().generate(self._task_id, question=data.question) - data.question = question + question = await QuestionRewrite().generate(self._task_id, question=data.content) + data.content = question url = config["RAG_HOST"].rstrip("/") + "/chunk/get" headers = { @@ -62,15 +62,19 @@ class RAG(CoreCall, input_type=RAGInput, output_type=RAGOutput): result = await response.json() chunk_list = result["data"] + corpus = [] for chunk in chunk_list: clean_chunk = chunk.replace("\n", " ") - yield CallOutputChunk( - type=CallOutputType.DATA, - content=RAGOutput( - question=data.question, - corpus=clean_chunk, - ).model_dump(exclude_none=True, by_alias=True), - ) + corpus.append(clean_chunk) + + yield CallOutputChunk( + type=CallOutputType.DATA, + content=RAGOutput( + question=data.content, + corpus=corpus, + ).model_dump(exclude_none=True, by_alias=True), + ) + return text = await response.text() logger.error("[RAG] 调用失败:%s", text) @@ -78,7 +82,7 @@ class RAG(CoreCall, input_type=RAGInput, output_type=RAGOutput): raise CallError( message=f"rag调用失败:{text}", data={ - "question": data.question, + 
"question": data.content, "status": response.status, "text": text, }, diff --git a/apps/scheduler/call/rag/schema.py b/apps/scheduler/call/rag/schema.py index 81cc84c82..092879221 100644 --- a/apps/scheduler/call/rag/schema.py +++ b/apps/scheduler/call/rag/schema.py @@ -19,13 +19,13 @@ class RAGOutput(DataBase): """RAG工具的输出""" question: str = Field(description="用户输入") - corpus: str = Field(description="知识库的语料列表") + corpus: list[str] = Field(description="知识库的语料列表") class RAGInput(DataBase): """RAG工具的输入""" - question: str = Field(description="用户输入") + content: str = Field(description="用户输入") knowledge_base: Optional[str] = Field(description="知识库的id", alias="kb_sn", default=None) top_k: int = Field(description="返回的答案数量(经过整合以及上下文关联)", default=5) retrieval_mode: RetrievalMode = Field(description="检索模式", default=RetrievalMode.CHUNK) -- Gitee From c4fda6b3938631ecdf859f2f2b298b0e3a804657 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Tue, 18 Mar 2025 08:36:07 +0800 Subject: [PATCH 7/8] =?UTF-8?q?=E5=8A=A0=E5=85=A5=E9=97=AE=E9=A2=98?= =?UTF-8?q?=E6=94=B9=E5=86=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/llm/patterns/rewrite.py | 36 +++++++++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/apps/llm/patterns/rewrite.py b/apps/llm/patterns/rewrite.py index 7322b06b2..90e267585 100644 --- a/apps/llm/patterns/rewrite.py +++ b/apps/llm/patterns/rewrite.py @@ -1,5 +1,9 @@ """问题改写""" + +from typing import Any, ClassVar + from apps.llm.patterns.core import CorePattern +from apps.llm.patterns.json_gen import Json from apps.llm.reasoning import ReasoningLLM @@ -11,14 +15,23 @@ class QuestionRewrite(CorePattern): 根据上面的对话,推断用户的实际意图并补全用户的提问内容。 要求: - 1. 请使用JSON格式输出,参考下面给出的样例;输出不要包含XML标签,不要包含任何解释说明; + 1. 请使用JSON格式输出,参考下面给出的样例;不要包含任何XML标签,不要包含任何解释说明; 2. 若用户当前提问内容与对话上文不相关,或你认为用户的提问内容已足够完整,请直接输出用户的提问内容。 3. 补全内容必须精准、恰当,不要编造任何内容。 + + 输出格式样例: + {{ + "question": "补全后的问题" + }} openEuler的特点? - openEuler相较于其他操作系统,其特点是什么? + + {{ + "question": "openEuler相较于其他操作系统,其特点是什么?" 
+ }} + @@ -27,6 +40,18 @@ class QuestionRewrite(CorePattern): """ """用户提示词""" + slot_schema: ClassVar[dict[str, Any]] = { + "type": "object", + "properties": { + "question": { + "type": "string", + "description": "补全后的问题", + }, + }, + "required": ["question"], + } + """最终输出的JSON Schema""" + async def generate(self, task_id: str, **kwargs) -> str: # noqa: ANN003 """问题补全与重写""" question = kwargs["question"] @@ -40,4 +65,9 @@ class QuestionRewrite(CorePattern): async for chunk in ReasoningLLM().call(task_id, messages, streaming=False): result += chunk - return result.strip().strip("\n").removesuffix("") + messages += [{"role": "assistant", "content": result}] + question_dict = await Json().generate("", conversation=messages, spec=self.slot_schema) + + if not question_dict or "question" not in question_dict or not question_dict["question"]: + return question + return question_dict["question"] -- Gitee From ef053c48c3798cb111acd78c9138ce71281bd684 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Tue, 18 Mar 2025 08:52:51 +0800 Subject: [PATCH 8/8] =?UTF-8?q?=E9=94=99=E8=AF=AF=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/call/convert/convert.py | 4 ++-- apps/scheduler/call/facts/facts.py | 4 ++-- apps/scheduler/call/llm/llm.py | 7 ++++--- apps/scheduler/scheduler/graph.py | 2 ++ 4 files changed, 10 insertions(+), 7 deletions(-) create mode 100644 apps/scheduler/scheduler/graph.py diff --git a/apps/scheduler/call/convert/convert.py b/apps/scheduler/call/convert/convert.py index cf85bf73a..8100755cc 100644 --- a/apps/scheduler/call/convert/convert.py +++ b/apps/scheduler/call/convert/convert.py @@ -10,7 +10,7 @@ from typing import Any, ClassVar, Optional import _jsonnet import pytz -from jinja2 import BaseLoader, select_autoescape +from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment from pydantic import Field @@ -52,7 +52,7 @@ class Convert(CoreCall, input_type=ConvertInput, output_type=ConvertOutput): else: text_template = SandboxedEnvironment( loader=BaseLoader(), - autoescape=select_autoescape(), + autoescape=False, trim_blocks=True, lstrip_blocks=True, ).from_string(self.text_template) diff --git a/apps/scheduler/call/facts/facts.py b/apps/scheduler/call/facts/facts.py index a45216c8c..36e49d4e1 100644 --- a/apps/scheduler/call/facts/facts.py +++ b/apps/scheduler/call/facts/facts.py @@ -43,9 +43,9 @@ class FactsCall(CoreCall, input_type=FactsInput, output_type=FactsOutput): """执行工具""" data = FactsInput(**input_data) # 提取事实信息 - facts = await Facts().generate(data.task_id, message=data.message) + facts = await Facts().generate(data.task_id, conversation=data.message) # 更新用户画像 - domain_list = await Domain().generate(data.task_id, message=data.message) + domain_list = await Domain().generate(data.task_id, conversation=data.message) for domain in domain_list: await UserDomainManager.update_user_domain_by_user_sub_and_domain_name(data.user_sub, domain) diff --git a/apps/scheduler/call/llm/llm.py b/apps/scheduler/call/llm/llm.py index 7af2ab324..10c591301 100644 --- a/apps/scheduler/call/llm/llm.py +++ b/apps/scheduler/call/llm/llm.py @@ -8,7 +8,7 @@ from datetime import datetime from typing import Annotated, Any, ClassVar import pytz -from jinja2 import BaseLoader, select_autoescape +from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment from pydantic import Field @@ -39,7 +39,7 @@ class LLM(CoreCall, input_type=LLMInput, output_type=LLMOutput): # 创建共享的 Environment 实例 env 
= SandboxedEnvironment( loader=BaseLoader(), - autoescape=select_autoescape(), + autoescape=False, trim_blocks=True, lstrip_blocks=True, ) @@ -94,7 +94,8 @@ class LLM(CoreCall, input_type=LLMInput, output_type=LLMOutput): """运行LLM Call""" data = LLMInput(**input_data) try: - async for chunk in ReasoningLLM().call(task_id=data.task_id, messages=data.message): + # TODO: 临时改成了非流式 + async for chunk in ReasoningLLM().call(task_id=data.task_id, messages=data.message, streaming=False): if not chunk: continue yield CallOutputChunk(type=CallOutputType.TEXT, content=chunk) diff --git a/apps/scheduler/scheduler/graph.py b/apps/scheduler/scheduler/graph.py new file mode 100644 index 000000000..9dd562621 --- /dev/null +++ b/apps/scheduler/scheduler/graph.py @@ -0,0 +1,2 @@ +"""工作流DAG相关操作""" + -- Gitee
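
The calls patched in this series converge on one lifecycle: init() receives the executor-filled CallVars and returns the validated input payload, and exec() is an async generator that streams CallOutputChunk items whose type is CallOutputType.TEXT or CallOutputType.DATA. The sketch below is illustrative only and is not part of the patch series; the run_call helper is invented for the example, and it assumes the Call instance has already been constructed with its own config fields.

```python
"""Minimal sketch: drive a Call that follows the init()/exec() contract above."""
from typing import Any

from apps.entities.enum_var import CallOutputType
from apps.entities.scheduler import CallVars
from apps.scheduler.call.core import CoreCall


async def run_call(call: CoreCall, syscall_vars: CallVars) -> tuple[str, list[dict[str, Any]]]:
    """Run a configured Call and split its streamed output into text and data parts."""
    # init() validates the system variables and returns the input dict passed to exec()
    input_data = await call.init(syscall_vars)

    text_parts: list[str] = []
    data_parts: list[dict[str, Any]] = []

    # exec() yields CallOutputChunk items until the Call is done
    async for chunk in call.exec(input_data):
        if chunk.type == CallOutputType.TEXT and isinstance(chunk.content, str):
            text_parts.append(chunk.content)    # e.g. LLM or Summary text
        elif isinstance(chunk.content, dict):
            data_parts.append(chunk.content)    # e.g. SQLOutput / RAGOutput dumps

    return "".join(text_parts), data_parts
```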
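
Patch 8 also switches the shared Jinja2 environment in Convert and LLM from select_autoescape() to autoescape=False, presumably because the rendered output is plain prompt text rather than HTML. A standalone sketch of that environment follows; the template string and data are made up for illustration.

```python
"""Sketch of the sandboxed, non-escaping Jinja2 setup used by Convert and LLM."""
from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment

env = SandboxedEnvironment(
    loader=BaseLoader(),
    autoescape=False,      # output is plain text, so rendered values are not HTML-escaped
    trim_blocks=True,
    lstrip_blocks=True,
)

# Hypothetical template and data, just to show the rendering behaviour
template = env.from_string("查询结果共 {{ dataset | length }} 条记录。")
print(template.render(dataset=[{"id": 1}, {"id": 2}]))
```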