diff --git a/apps/common/task.py b/apps/common/task.py
index 8f4696241f7823eb510f986cab15475b553921bc..f937fefad833ab59e30a09e802df9caec91be6da 100644
--- a/apps/common/task.py
+++ b/apps/common/task.py
@@ -16,7 +16,7 @@ from apps.entities.record import (
RecordMetadata,
)
from apps.entities.request_data import RequestData
-from apps.entities.task import TaskBlock, TaskData
+from apps.entities.task import TaskBlock, TaskData, TaskRuntimeData
from apps.manager.task import TaskManager
from apps.models.mongo import MongoDB
@@ -77,8 +77,10 @@ class Task:
new_record.task_id = task_id
self._task_map[task_id] = TaskBlock(
- session_id=session_id,
record=new_record,
+ runtime_data=TaskRuntimeData(
+ session_id=session_id,
+ ),
)
return deepcopy(self._task_map[task_id])
@@ -86,9 +88,11 @@ class Task:
task_id = task.id
new_record.task_id = task_id
self._task_map[task_id] = TaskBlock(
- session_id=session_id,
record=new_record,
flow_state=task.state,
+ runtime_data=TaskRuntimeData(
+ session_id=session_id,
+ ),
)
return deepcopy(self._task_map[task_id])
diff --git a/apps/constants.py b/apps/constants.py
index 6b05e30bba9642410a318a0ff48d1ec3cf528d28..85ad1725bbbd939f15f156053133a194d5167e26 100644
--- a/apps/constants.py
+++ b/apps/constants.py
@@ -4,9 +4,6 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
"""
from __future__ import annotations
-import logging
-from textwrap import dedent
-
# 新对话默认标题
NEW_CHAT = "New Chat"
# 滑动窗口限流 默认窗口期
@@ -27,8 +24,6 @@ APP_DIR = "app"
FLOW_DIR = "flow"
# Scheduler进程数
SCHEDULER_REPLICAS = 2
-# 日志记录器
-LOGGER = logging.getLogger("ray")
REASONING_BEGIN_TOKEN = [
"",
@@ -36,63 +31,3 @@ REASONING_BEGIN_TOKEN = [
REASONING_END_TOKEN = [
"",
]
-LLM_CONTEXT_PROMPT = dedent(
- r"""
- 以下是对用户和AI间对话的简短总结:
- {{ summary }}
-
- 你作为AI,在回答用户的问题前,需要获取必要信息。为此,你调用了以下工具,并获得了输出:
-
- {% for tool in history_data %}
- {{ tool.step_id }}
-
- {% endfor %}
-
- """,
- ).strip("\n")
-LLM_DEFAULT_PROMPT = dedent(
- r"""
-
- 你是一个乐于助人的智能助手。请结合给出的背景信息, 回答用户的提问。
- 当前时间:{{ time }},可以作为时间参照。
- 用户的问题将在中给出,上下文背景信息将在中给出。
- 注意:输出不要包含任何XML标签,不要编造任何信息。若你认为用户提问与背景信息无关,请忽略背景信息直接作答。
-
-
-
- {{ question }}
-
-
-
- {{ context }}
-
- """,
- ).strip("\n")
-RAG_ANSWER_PROMPT = dedent(
- r"""
-
- 你是由openEuler社区构建的大型语言AI助手。请根据背景信息(包含对话上下文和文档片段),回答用户问题。
- 用户的问题将在中给出,上下文背景信息将在中给出,文档片段将在中给出。
-
- 注意事项:
- 1. 输出不要包含任何XML标签。请确保输出内容的正确性,不要编造任何信息。
- 2. 如果用户询问你关于你自己的问题,请统一回答:“我叫EulerCopilot,是openEuler社区的智能助手”。
- 3. 背景信息仅供参考,若背景信息与用户问题无关,请忽略背景信息直接作答。
- 4. 请在回答中使用Markdown格式,并**不要**将内容放在"```"中。
-
-
-
- {{ question }}
-
-
-
- {{ context }}
-
-
-
- {{ document }}
-
-
- 现在,请根据上述信息,回答用户的问题:
- """,
- ).strip("\n")
diff --git a/apps/entities/collection.py b/apps/entities/collection.py
index 01042e4811fe2b4eb013a22dcdc5a7f97ab6d6cb..8410e1e35735169d7fcf818259c233331ff0b9ad 100644
--- a/apps/entities/collection.py
+++ b/apps/entities/collection.py
@@ -124,6 +124,7 @@ class RecordContent(BaseModel):
question: str
answer: str
data: dict[str, Any] = {}
+ facts: list[str] = Field(description="[运行后修改]与Record关联的事实信息", default=[])
class RecordComment(BaseModel):
@@ -144,7 +145,6 @@ class Record(BaseModel):
data: str
key: dict[str, Any] = {}
flow: list[str] = Field(description="[运行后修改]与Record关联的FlowHistory的ID", default=[])
- facts: list[str] = Field(description="[运行后修改]与Record关联的事实信息", default=[])
comment: Optional[RecordComment] = None
metadata: Optional[RecordMetadata] = None
created_at: float = Field(default_factory=lambda: round(datetime.now(tz=timezone.utc).timestamp(), 3))
diff --git a/apps/entities/enum_var.py b/apps/entities/enum_var.py
index 328dd626d6e9cf3512f66cc04a383101a9279f5c..8c30ca234eb5231308a56e7bc3b970de2ad9fee2 100644
--- a/apps/entities/enum_var.py
+++ b/apps/entities/enum_var.py
@@ -32,22 +32,13 @@ class DocumentStatus(str, Enum):
FAILED = "failed"
-class FlowOutputType(str, Enum):
- """Flow输出类型"""
-
- CODE = "code"
- CHART = "chart"
- URL = "url"
- SCHEMA = "schema"
- NONE = "none"
-
-
class EventType(str, Enum):
"""事件类型"""
HEARTBEAT = "heartbeat"
INIT = "init"
TEXT_ADD = "text.add"
+ GRAPH = "graph"
DOCUMENT_ADD = "document.add"
SUGGEST = "suggest"
FLOW_START = "flow.start"
@@ -134,3 +125,10 @@ class ContentType(str, Enum):
JSON = "application/json"
FORM_URLENCODED = "application/x-www-form-urlencoded"
MULTIPART_FORM_DATA = "multipart/form-data"
+
+
+class CallOutputType(str, Enum):
+ """Call输出类型"""
+
+ TEXT = "text"
+ DATA = "data"
diff --git a/apps/entities/message.py b/apps/entities/message.py
index 75bc490f17e943fbc4a8455b72b83b525102e5e9..8464c2d2cb0b11f7b9ae9c04768bda286494adaf 100644
--- a/apps/entities/message.py
+++ b/apps/entities/message.py
@@ -7,7 +7,7 @@ from typing import Any, Optional
from pydantic import BaseModel, Field
from apps.entities.collection import RecordMetadata
-from apps.entities.enum_var import EventType, FlowOutputType, StepStatus
+from apps.entities.enum_var import EventType, StepStatus
class HeartbeatData(BaseModel):
@@ -81,13 +81,6 @@ class FlowStartContent(BaseModel):
params: dict[str, Any] = Field(description="预先提供的参数")
-class FlowStopContent(BaseModel):
- """flow.stop消息的content"""
-
- type: Optional[FlowOutputType] = Field(description="Flow输出的类型", default=None)
- data: Optional[dict[str, Any]] = Field(description="Flow输出的数据", default=None)
-
-
class MessageBase(HeartbeatData):
"""基础消息事件结构"""
diff --git a/apps/entities/node.py b/apps/entities/node.py
index 03aab2e7eb58b4bc0d36a0bd00fbc02f26a6e3fa..c51b2b53d58c0917c6e736fe1bbe74f7b7dee229 100644
--- a/apps/entities/node.py
+++ b/apps/entities/node.py
@@ -9,14 +9,14 @@ from apps.entities.pool import NodePool
class APINodeInput(BaseModel):
"""API节点覆盖输入"""
- param_schema: Optional[dict[str, Any]] = Field(description="API节点输入参数Schema", default=None)
- body_schema: Optional[dict[str, Any]] = Field(description="API节点输入请求体Schema", default=None)
+ query: Optional[dict[str, Any]] = Field(description="API节点输入参数Schema", default=None)
+ body: Optional[dict[str, Any]] = Field(description="API节点输入请求体Schema", default=None)
class APINodeOutput(BaseModel):
"""API节点覆盖输出"""
- resp_schema: Optional[dict[str, Any]] = Field(description="API节点输出Schema", default=None)
+ result: Optional[dict[str, Any]] = Field(description="API节点输出Schema", default=None)
class APINode(NodePool):
diff --git a/apps/entities/scheduler.py b/apps/entities/scheduler.py
index 8f0cf69a50d3b77d3d2bdfe31e26fab3f1d48b9e..0b0f4de197d596d9861b329058f04d32801774f4 100644
--- a/apps/entities/scheduler.py
+++ b/apps/entities/scheduler.py
@@ -2,18 +2,16 @@
Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
"""
-from typing import Any
+from typing import Any, Union
from pydantic import BaseModel, Field
+from apps.entities.enum_var import CallOutputType
from apps.entities.task import FlowStepHistory
class CallVars(BaseModel):
- """所有Call都需要接受的参数。包含用户输入、上下文信息、Step的输出记录等
-
- 这一部分的参数由Executor填充,用户无法修改
- """
+ """由Executor填充的变量,即“系统变量”"""
summary: str = Field(description="上下文信息")
user_sub: str = Field(description="用户ID")
@@ -22,12 +20,13 @@ class CallVars(BaseModel):
task_id: str = Field(description="任务ID")
flow_id: str = Field(description="Flow ID")
session_id: str = Field(description="当前用户的Session ID")
+ service_id: str = Field(description="语义接口ID")
class ExecutorBackground(BaseModel):
"""Executor的背景信息"""
- conversation: str = Field(description="当前Executor的背景信息")
+ conversation: list[dict[str, str]] = Field(description="对话记录")
facts: list[str] = Field(description="当前Executor的背景信息")
@@ -38,3 +37,10 @@ class CallError(Exception):
"""获取Call错误中的数据"""
self.message = message
self.data = data
+
+
+class CallOutputChunk(BaseModel):
+ """Call的输出"""
+
+ type: CallOutputType = Field(description="输出类型")
+ content: Union[str, dict[str, Any]] = Field(description="输出内容")
diff --git a/apps/entities/task.py b/apps/entities/task.py
index 4d61e32cb242fcea3befbc0546f62fd2f9f0e983..d704c4829d5a3bee2b9f38b1d2a1597dcabc2023 100644
--- a/apps/entities/task.py
+++ b/apps/entities/task.py
@@ -39,21 +39,27 @@ class ExecutorState(BaseModel):
step_id: str = Field(description="当前步骤ID")
step_name: str = Field(description="当前步骤名称")
app_id: str = Field(description="应用ID")
- # 运行时数据
- ai_summary: str = Field(description="大模型的思考内容", default="")
- filled_data: dict[str, Any] = Field(description="待使用的参数", default={})
remaining_schema: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={})
+class TaskRuntimeData(BaseModel):
+ """Task的运行时数据"""
+
+ session_id: str = Field(description="浏览器会话ID", default="")
+ user_sub: str = Field(description="用户ID", default="")
+ summary: str = Field(description="对对话记录的总结", default="")
+ filled_data: dict[str, Any] = Field(description="已经填充的参数", default={})
+ new_context: list[str] = Field(description="本次运行中新产生的步骤历史", default=[])
+ reached_end: bool = Field(description="是否到达Flow的终点", default=False)
+
+
class TaskBlock(BaseModel):
"""内存中的Task块,不存储在数据库中"""
- session_id: str = Field(description="浏览器会话ID")
- user_sub: str = Field(description="用户ID", default="")
record: RecordData = Field(description="当前任务执行过程关联的Record")
flow_state: Optional[ExecutorState] = Field(description="Flow的状态", default=None)
flow_context: dict[str, FlowStepHistory] = Field(description="Flow的执行信息", default={})
- new_context: list[str] = Field(description="Flow的执行信息(增量ID)", default=[])
+ runtime_data: TaskRuntimeData = Field(description="Task的运行时数据", default_factory=lambda: TaskRuntimeData())
class RequestDataApp(BaseModel):
diff --git a/apps/llm/patterns/rewrite.py b/apps/llm/patterns/rewrite.py
index 7322b06b2a28438fa59d3e5a3edf3edf97a43a5c..90e2675852eae64a517fbb6f331936407390377e 100644
--- a/apps/llm/patterns/rewrite.py
+++ b/apps/llm/patterns/rewrite.py
@@ -1,5 +1,9 @@
"""问题改写"""
+
+from typing import Any, ClassVar
+
from apps.llm.patterns.core import CorePattern
+from apps.llm.patterns.json_gen import Json
from apps.llm.reasoning import ReasoningLLM
@@ -11,14 +15,23 @@ class QuestionRewrite(CorePattern):
根据上面的对话,推断用户的实际意图并补全用户的提问内容。
要求:
- 1. 请使用JSON格式输出,参考下面给出的样例;输出不要包含XML标签,不要包含任何解释说明;
+ 1. 请使用JSON格式输出,参考下面给出的样例;不要包含任何XML标签,不要包含任何解释说明;
2. 若用户当前提问内容与对话上文不相关,或你认为用户的提问内容已足够完整,请直接输出用户的提问内容。
3. 补全内容必须精准、恰当,不要编造任何内容。
+
+ 输出格式样例:
+ {{
+ "question": "补全后的问题"
+ }}
openEuler的特点?
-
+
@@ -27,6 +40,18 @@ class QuestionRewrite(CorePattern):
"""
"""用户提示词"""
+ slot_schema: ClassVar[dict[str, Any]] = {
+ "type": "object",
+ "properties": {
+ "question": {
+ "type": "string",
+ "description": "补全后的问题",
+ },
+ },
+ "required": ["question"],
+ }
+ """最终输出的JSON Schema"""
+
async def generate(self, task_id: str, **kwargs) -> str: # noqa: ANN003
"""问题补全与重写"""
question = kwargs["question"]
@@ -40,4 +65,9 @@ class QuestionRewrite(CorePattern):
async for chunk in ReasoningLLM().call(task_id, messages, streaming=False):
result += chunk
- return result.strip().strip("\n").removesuffix("")
+ messages += [{"role": "assistant", "content": result}]
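+ # 将推理结果作为assistant消息追加进上下文,再用Json模式按slot_schema抽取结构化结果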
+ question_dict = await Json().generate("", conversation=messages, spec=self.slot_schema)
+
+ if not question_dict or "question" not in question_dict or not question_dict["question"]:
+ return question
+ return question_dict["question"]
diff --git a/apps/manager/appcenter.py b/apps/manager/appcenter.py
index 29c45f5bfac6b7b9665f0f05fc86547956b06431..8ee6e0ef8a24aa0b9388eb0371feb186cff44cf9 100644
--- a/apps/manager/appcenter.py
+++ b/apps/manager/appcenter.py
@@ -383,6 +383,8 @@ class AppCenterManager:
:param app_id: 应用唯一标识
:return: 更新是否成功
"""
+ if not app_id:
+ return True
try:
user_collection = MongoDB.get_collection("user")
current_time = round(datetime.now(tz=timezone.utc).timestamp(), 3)
@@ -402,10 +404,10 @@ class AppCenterManager:
logger.info("[AppCenterManager] 更新用户最近使用应用: %s", user_sub)
return True
logger.warning("[AppCenterManager] 用户 %s 没有更新", user_sub)
- return False
except Exception:
logger.exception("[AppCenterManager] 更新用户最近使用应用失败")
- return False
+
+ return False
@staticmethod
async def get_default_flow_id(app_id: str) -> Optional[str]:
diff --git a/apps/manager/node.py b/apps/manager/node.py
index 83faa4eaf18b6ec0e0fc4aebb1cf4c4a733ff871..9ba8692fe04f13207a91a776b55c1938c0ebf741 100644
--- a/apps/manager/node.py
+++ b/apps/manager/node.py
@@ -28,6 +28,24 @@ class NodeManager:
return node["call_id"]
+ @staticmethod
+ async def get_node(node_id: str) -> NodePool:
+ """获取Node的类型"""
+ node_collection = MongoDB().get_collection("node")
+ node = await node_collection.find_one({"_id": node_id})
+ if not node:
+ err = f"[NodeManager] Node type {node_id} not found."
+ raise ValueError(err)
+
+ try:
+ node = NodePool.model_validate(node)
+ except Exception as e:
+ err = f"[NodeManager] Node {node_id} 验证失败"
+ logger.exception(err)
+ raise ValueError(err) from e
+ return node
+
+
@staticmethod
async def get_node_name(node_id: str) -> str:
"""获取node的名称"""
@@ -96,5 +114,5 @@ class NodeManager:
# 返回参数Schema
return (
NodeManager.merge_params_schema(call_class.model_json_schema(), node_data.known_params or {}),
- call_class.ret_type.model_json_schema(), # type: ignore[attr-defined]
+ call_class.output_type.model_json_schema(), # type: ignore[attr-defined]
)
diff --git a/apps/manager/token.py b/apps/manager/token.py
index b0bf988536f0f93fce4fdb3263496ccca7d0d9db..a5413728835a22dc202fce719b8d26999a48d640 100644
--- a/apps/manager/token.py
+++ b/apps/manager/token.py
@@ -107,4 +107,3 @@ class TokenManager:
pipe.delete(f"{user_sub}_oidc_refresh_token")
pipe.delete(f"aops_{user_sub}_token")
await pipe.execute()
-
diff --git a/apps/routers/chat.py b/apps/routers/chat.py
index 16391241e4991ec52ae26384b165a64f161f2deb..2451bcd91d285ef4b98e540a40e64caaf6070218 100644
--- a/apps/routers/chat.py
+++ b/apps/routers/chat.py
@@ -55,7 +55,7 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str)
task = await task_actor.get_task.remote(session_id=session_id, post_body=post_body)
task_id = task.record.task_id
- task.user_sub = user_sub
+ task.runtime_data.user_sub = user_sub
task.record.group_id = group_id
post_body.group_id = group_id
await task_actor.set_task.remote(task_id, task)
diff --git a/apps/scheduler/call/__init__.py b/apps/scheduler/call/__init__.py
index 5d5605b0ce513d7383efb16456c21ea4c86e4c27..913fc8f78d1020cb7d6d50418fce434790eba665 100644
--- a/apps/scheduler/call/__init__.py
+++ b/apps/scheduler/call/__init__.py
@@ -2,15 +2,19 @@
Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
"""
-from apps.scheduler.call.api import API
-from apps.scheduler.call.llm import LLM
-from apps.scheduler.call.rag import RAG
-from apps.scheduler.call.suggest import Suggestion
+from apps.scheduler.call.api.api import API
+from apps.scheduler.call.convert.convert import Convert
+from apps.scheduler.call.llm.llm import LLM
+from apps.scheduler.call.output.output import Output
+from apps.scheduler.call.rag.rag import RAG
+from apps.scheduler.call.suggest.suggest import Suggestion
# 只包含需要在编排界面展示的工具
__all__ = [
"API",
"LLM",
"RAG",
+ "Convert",
+ "Output",
"Suggestion",
]
diff --git a/apps/scheduler/call/api.py b/apps/scheduler/call/api/api.py
similarity index 39%
rename from apps/scheduler/call/api.py
rename to apps/scheduler/call/api/api.py
index d6753a59c5d88008b0920f6a71ea494aba386308..4a17070ccb2792574e990183f1deaff2a0933de6 100644
--- a/apps/scheduler/call/api.py
+++ b/apps/scheduler/call/api/api.py
@@ -3,83 +3,126 @@
Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
"""
import json
-from typing import Any, ClassVar, Literal, Optional
+import logging
+from collections.abc import AsyncGenerator
+from typing import Annotated, Any, ClassVar, Optional
import aiohttp
from fastapi import status
-from pydantic import BaseModel, Field
+from pydantic import Field
-from apps.constants import LOGGER
-from apps.entities.enum_var import ContentType, HTTPMethod
-from apps.entities.scheduler import CallError, CallVars
+from apps.common.config import config
+from apps.entities.enum_var import CallOutputType, ContentType, HTTPMethod
+from apps.entities.flow import ServiceMetadata
+from apps.entities.scheduler import CallError, CallOutputChunk, CallVars
+from apps.manager.service import ServiceCenterManager
from apps.manager.token import TokenManager
+from apps.scheduler.call.api.schema import APIInput, APIOutput
from apps.scheduler.call.core import CoreCall
-from apps.scheduler.slot.slot import Slot
-
-class APIOutput(BaseModel):
- """API调用工具的输出"""
-
- http_code: int = Field(description="API调用工具的HTTP返回码")
- message: str = Field(description="API调用工具的执行结果")
- output: dict[str, Any] = Field(description="API调用工具的输出")
-
-
-class API(CoreCall, ret_type=APIOutput):
+logger = logging.getLogger("ray")
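+# 视为调用成功的HTTP状态码(2xx与3xx)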
+SUCCESS_HTTP_CODES = [
+ status.HTTP_200_OK,
+ status.HTTP_201_CREATED,
+ status.HTTP_202_ACCEPTED,
+ status.HTTP_203_NON_AUTHORITATIVE_INFORMATION,
+ status.HTTP_204_NO_CONTENT,
+ status.HTTP_205_RESET_CONTENT,
+ status.HTTP_206_PARTIAL_CONTENT,
+ status.HTTP_207_MULTI_STATUS,
+ status.HTTP_208_ALREADY_REPORTED,
+ status.HTTP_226_IM_USED,
+ status.HTTP_301_MOVED_PERMANENTLY,
+ status.HTTP_302_FOUND,
+ status.HTTP_303_SEE_OTHER,
+ status.HTTP_304_NOT_MODIFIED,
+ status.HTTP_307_TEMPORARY_REDIRECT,
+ status.HTTP_308_PERMANENT_REDIRECT,
+]
+
+
+class API(CoreCall, input_type=APIInput, output_type=APIOutput):
"""API调用工具"""
- name: ClassVar[str] = "HTTP请求"
- description: ClassVar[str] = "向某一个API接口发送HTTP请求,获取数据。"
+ name: ClassVar[Annotated[str, Field(description="工具名称", exclude=True, frozen=True)]] = "HTTP请求"
+ description: ClassVar[Annotated[str, Field(description="工具描述", exclude=True, frozen=True)]] = "向某一个API接口发送HTTP请求,获取数据。"
url: str = Field(description="API接口的完整URL")
method: HTTPMethod = Field(description="API接口的HTTP Method")
content_type: ContentType = Field(description="API接口的Content-Type")
timeout: int = Field(description="工具超时时间", default=300, gt=30)
+
body: dict[str, Any] = Field(description="已知的部分请求体", default={})
- auth: dict[str, Any] = Field(description="API鉴权信息", default={})
+ query: dict[str, Any] = Field(description="已知的部分请求参数", default={})
- async def exec(self, syscall_vars: CallVars, **_kwargs: Any) -> APIOutput:
+ async def init(self, syscall_vars: CallVars) -> dict[str, Any]:
+ """初始化API调用工具"""
+ await super().init(syscall_vars)
+ # 获取对应API的Service Metadata
+ try:
+ service_metadata = await ServiceCenterManager.get_service_data(syscall_vars.user_sub, syscall_vars.service_id)
+ service_metadata = ServiceMetadata.model_validate(service_metadata)
+ except Exception as e:
+ raise CallError(
+ message="API接口的Service Metadata获取失败",
+ data={},
+ ) from e
+
+ # 获取Service对应的Auth
+ self._auth = service_metadata.api.auth
+ self._session_id = syscall_vars.session_id
+ self._service_id = syscall_vars.service_id
+
+ return APIInput(
+ service_id=self._service_id,
+ url=self.url,
+ method=self.method,
+ query=self.query,
+ body=self.body,
+ ).model_dump(exclude_none=True, by_alias=True)
+
+
+ async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
"""调用API,然后返回LLM解析后的数据"""
self._session = aiohttp.ClientSession()
try:
- result = await self._call_api(slot_data)
- await self._session.close()
- return result
- except Exception as e:
+ result = await self._call_api(input_data)
+ yield CallOutputChunk(
+ type=CallOutputType.DATA,
+ content=result.model_dump(exclude_none=True, by_alias=True),
+ )
+ finally:
await self._session.close()
- raise RuntimeError from e
- async def _make_api_call(self, data: Optional[dict], files: aiohttp.FormData): # noqa: ANN202, C901
+ async def _make_api_call(self, data: Optional[dict], files: aiohttp.FormData): # noqa: ANN202
+ """组装API请求Session"""
# 获取必要参数
- syscall_vars: CallVars = getattr(self, "_syscall_vars")
-
- if self.content_type != "form":
- req_header = {
- "Content-Type": ContentType.JSON.value,
- }
- else:
- req_header = {}
+ req_header = {
+ "Content-Type": self.content_type.value,
+ }
req_cookie = {}
req_params = {}
- if data is None:
- data = {}
-
- if self.auth is not None and "type" in self.auth:
- if self.auth["type"] == "header":
- req_header.update(self.auth["args"])
- elif self.auth["type"] == "cookie":
- req_cookie.update(self.auth["args"])
- elif self.auth["type"] == "params":
- req_params.update(self.auth["args"])
- elif self.auth["type"] == "oidc":
+ data = data or {}
+
+ if self._auth:
+ if self._auth.header:
+ for item in self._auth.header:
+ req_header[item.name] = item.value
+ elif self._auth.cookie:
+ for item in self._auth.cookie:
+ req_cookie[item.name] = item.value
+ elif self._auth.query:
+ for item in self._auth.query:
+ req_params[item.name] = item.value
+ elif self._auth.oidc:
token = await TokenManager.get_plugin_token(
- self.auth["domain"],
- syscall_vars.session_id,
- self.auth["access_token_url"],
- int(self.auth["token_expire_time"]),
+ self._service_id,
+ self._session_id,
+ config["OIDC_ACCESS_TOKEN_URL"],
+ int(config["OIDC_ACCESS_TOKEN_EXPIRE_TIME"]),
)
req_header.update({"access-token": token})
@@ -102,49 +145,40 @@ class API(CoreCall, ret_type=APIOutput):
raise NotImplementedError(err)
- async def _call_api(self, slot_data: Optional[dict[str, Any]] = None) -> APIOutput:
+ async def _call_api(self, final_data: Optional[dict[str, Any]] = None) -> APIOutput:
+ """实际调用API,并处理返回值"""
# 获取必要参数
- params: _APIParams = getattr(self, "_params")
- LOGGER.info(f"调用接口{params.url},请求数据为{slot_data}")
+ logger.info("[API] 调用接口 %s,请求数据为 %s", self.url, final_data)
- session_context = await self._make_api_call(slot_data, aiohttp.FormData())
+ session_context = await self._make_api_call(final_data, aiohttp.FormData())
async with session_context as response:
- if response.status >= status.HTTP_400_BAD_REQUEST:
+ if response.status not in SUCCESS_HTTP_CODES:
text = f"API发生错误:API返回状态码{response.status}, 原因为{response.reason}。"
- LOGGER.error(text)
+ logger.error(text)
raise CallError(
message=text,
data={"api_response_data": await response.text()},
)
+
response_status = response.status
response_data = await response.text()
- LOGGER.info(f"调用接口{params.url}, 结果为 {response_data}")
+ logger.info("[API] 调用接口 %s,结果为 %s", self.url, response_data)
- # 组装message
- message = f"""You called the HTTP API "{params.url}", which is used to "{self._spec[2]['summary']}"."""
# 如果没有返回结果
- if response_data is None:
+ if not response_data:
return APIOutput(
http_code=response_status,
- output={},
- message=message + "But the API returned an empty response.",
+ result={},
)
# 如果返回值是JSON
try:
response_dict = json.loads(response_data)
- except Exception as e:
- err = f"返回值不是JSON:{e!s}"
- LOGGER.error(err)
-
- # openapi里存在有HTTP 200对应的Schema,则进行处理
- if response_schema:
- slot = Slot(response_schema)
- response_data = json.dumps(slot.process_json(response_dict), ensure_ascii=False)
+ except Exception:
+ logger.exception("[API] 返回值不是JSON格式!")
return APIOutput(
http_code=response_status,
- output=json.loads(response_data),
- message=message + "The API returned some data, and is shown in the 'output' field below.",
+ result=response_dict,
)
diff --git a/apps/scheduler/call/api/schema.py b/apps/scheduler/call/api/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..12afcb7e55f66872c6ae4b6c71f398cabbcab031
--- /dev/null
+++ b/apps/scheduler/call/api/schema.py
@@ -0,0 +1,25 @@
+"""API调用工具的输入和输出"""
+
+from typing import Any, Union
+
+from pydantic import Field
+
+from apps.scheduler.call.core import DataBase
+
+
+class APIInput(DataBase):
+ """API调用工具的输入"""
+
+ service_id: str = Field(description="API调用工具的Service ID")
+ url: str = Field(description="API调用工具的URL")
+ method: str = Field(description="API调用工具的HTTP方法")
+
+ query: dict[str, Any] = Field(description="API调用工具的请求参数")
+ body: dict[str, Any] = Field(description="API调用工具的请求体")
+
+
+class APIOutput(DataBase):
+ """API调用工具的输出"""
+
+ http_code: int = Field(description="API调用工具的HTTP返回码")
+ result: Union[dict[str, Any], str] = Field(description="API调用工具的输出")
diff --git a/apps/scheduler/call/choice.py b/apps/scheduler/call/choice/choice.py
similarity index 100%
rename from apps/scheduler/call/choice.py
rename to apps/scheduler/call/choice/choice.py
diff --git a/apps/scheduler/call/convert.py b/apps/scheduler/call/convert/convert.py
similarity index 54%
rename from apps/scheduler/call/convert.py
rename to apps/scheduler/call/convert/convert.py
index 0b6ff004fad19efdb0817c64e6f311c34adf79b4..8100755cc6f7abe6091ba76800340f732cf8575a 100644
--- a/apps/scheduler/call/convert.py
+++ b/apps/scheduler/call/convert/convert.py
@@ -3,43 +3,43 @@
Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
"""
import json
+from collections.abc import AsyncGenerator
from datetime import datetime
from textwrap import dedent
-from typing import Any, Optional
+from typing import Any, ClassVar, Optional
import _jsonnet
import pytz
-from jinja2 import BaseLoader, select_autoescape
+from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment
-from pydantic import BaseModel, Field
+from pydantic import Field
+from apps.entities.enum_var import CallOutputType
from apps.entities.scheduler import CallVars
-from apps.scheduler.call.core import CoreCall
+from apps.scheduler.call.convert.schema import ConvertInput, ConvertOutput
+from apps.scheduler.call.core import CallOutputChunk, CoreCall
-class ConvertOutput(BaseModel):
- """定义Convert工具的输出"""
-
- text: str = Field(description="格式化后的文字信息")
- data: dict = Field(description="格式化后的结果")
-
-
-class Convert(CoreCall, ret_type=ConvertOutput):
+class Convert(CoreCall, input_type=ConvertInput, output_type=ConvertOutput):
"""Convert 工具,用于对生成的文字信息和原始数据进行格式化"""
- name: str = "convert"
- description: str = "从上一步的工具的原始JSON返回结果中,提取特定字段的信息。"
+ name: ClassVar[str] = "模板转换"
+ description: ClassVar[str] = "使用jinja2语法和jsonnet语法,将自然语言信息和原始数据进行格式化。"
text_template: Optional[str] = Field(description="自然语言信息的格式化模板,jinja2语法", default=None)
data_template: Optional[str] = Field(description="原始数据的格式化模板,jsonnet语法", default=None)
- async def init(self, syscall_vars: CallVars, **_kwargs: Any) -> dict[str, Any]:
+ async def init(self, syscall_vars: CallVars) -> dict[str, Any]:
"""初始化工具"""
- return {}
+ await super().init(syscall_vars)
+
+ self._history = syscall_vars.history
+ self._question = syscall_vars.question
+ return ConvertInput().model_dump(exclude_none=True, by_alias=True)
- async def exec(self) -> ConvertOutput:
+ async def exec(self) -> AsyncGenerator[CallOutputChunk, None]:
"""调用Convert工具
:param _slot_data: 经用户确认后的参数(目前未使用)
@@ -47,33 +47,33 @@ class Convert(CoreCall, ret_type=ConvertOutput):
"""
# 判断用户是否给了值
time = datetime.now(tz=pytz.timezone("Asia/Shanghai")).strftime("%Y-%m-%d %H:%M:%S")
- if params.text is None:
+ if self.text_template is None:
result_message = last_output.get("message", "")
else:
text_template = SandboxedEnvironment(
loader=BaseLoader(),
- autoescape=select_autoescape(),
+ autoescape=False,
trim_blocks=True,
lstrip_blocks=True,
- ).from_string(params.text)
- result_message = text_template.render(time=time, history=syscall_vars.history, question=syscall_vars.question)
+ ).from_string(self.text_template)
+ result_message = text_template.render(time=time, history=self._history, question=self._question)
- if params.data is None:
+ if self.data_template is None:
result_data = last_output.get("output", {})
else:
extra_str = json.dumps({
"time": time,
- "question": syscall_vars.question,
+ "question": self._question,
}, ensure_ascii=False)
- history_str = json.dumps([item.output_data["output"] for item in syscall_vars.history if "output" in item.output_data], ensure_ascii=False)
+ history_str = json.dumps([item.output_data["output"] for item in self._history if "output" in item.output_data], ensure_ascii=False)
data_template = dedent(f"""
local extra = {extra_str};
local history = {history_str};
- {params.data}
+ {self.data_template}
""")
- result_data = json.loads(_jsonnet.evaluate_snippet(data_template, params.data), ensure_ascii=False)
+ result_data = json.loads(_jsonnet.evaluate_snippet("data_template", data_template))
- return ConvertOutput(
- text=result_message,
- data=result_data,
+ yield CallOutputChunk(
+ type=CallOutputType.DATA,
+ content=ConvertOutput(
+ text=result_message,
+ data=result_data,
+ ).model_dump(exclude_none=True, by_alias=True),
)
diff --git a/apps/scheduler/call/convert/schema.py b/apps/scheduler/call/convert/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..1af354c42717c63312b74e793f3173d6d06b13c0
--- /dev/null
+++ b/apps/scheduler/call/convert/schema.py
@@ -0,0 +1,17 @@
+"""Convert工具的Schema"""
+
+from pydantic import Field
+
+from apps.scheduler.call.core import DataBase
+
+
+class ConvertInput(DataBase):
+ """定义Convert工具的输入"""
+
+
+
+class ConvertOutput(DataBase):
+ """定义Convert工具的输出"""
+
+ text: str = Field(description="格式化后的文字信息")
+ data: dict = Field(description="格式化后的结果")
diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py
index 675440431a44054bee4a6af518ee3449530ec6dd..83b6681ee2e10e67e455befa45ba256ea71c286e 100644
--- a/apps/scheduler/call/core.py
+++ b/apps/scheduler/call/core.py
@@ -4,42 +4,55 @@
Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
"""
from collections.abc import AsyncGenerator
-from typing import Any, ClassVar
+from typing import Annotated, Any, ClassVar, Optional
from pydantic import BaseModel, ConfigDict, Field
-from apps.entities.scheduler import CallVars
+from apps.entities.scheduler import CallOutputChunk, CallVars
+
+
+class DataBase(BaseModel):
+ """所有Call的输入基类"""
+
+ @classmethod
+ def model_json_schema(cls, override: Optional[dict[str, Any]] = None, **kwargs: Any) -> dict[str, Any]:
+ """通过override参数,动态填充Schema内容"""
+ schema = super().model_json_schema(**kwargs)
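+ # override中的每一项按字段名覆盖properties里的同名Schema定义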
+ if override:
+ for key, value in override.items():
+ schema["properties"][key] = value
+ return schema
class CoreCall(BaseModel):
"""所有Call的父类,所有Call必须继承此类。"""
- name: ClassVar[str] = Field(description="Call的名称")
- description: ClassVar[str] = Field(description="Call的描述")
+ name: ClassVar[Annotated[str, Field(description="Call的名称", exclude=True)]] = ""
+ description: ClassVar[Annotated[str, Field(description="Call的描述", exclude=True)]] = ""
model_config = ConfigDict(
arbitrary_types_allowed=True,
extra="allow",
)
- ret_type: ClassVar[type[BaseModel]] = Field(description="Call的返回类型")
+ input_type: ClassVar[type[BaseModel]] = Field(description="Call的输入Pydantic类型", exclude=True, frozen=True)
+ """需要以消息形式输出的、需要大模型填充参数的输入信息填到这里"""
+ output_type: ClassVar[type[BaseModel]] = Field(description="Call的输出Pydantic类型", exclude=True, frozen=True)
+ """需要以消息形式输出的输出信息填到这里"""
- def __init_subclass__(cls, ret_type: type[BaseModel], **kwargs: Any) -> None:
+ def __init_subclass__(cls, input_type: type[BaseModel], output_type: type[BaseModel], **kwargs: Any) -> None:
"""初始化子类"""
super().__init_subclass__(**kwargs)
- cls.ret_type = ret_type
+ cls.input_type = input_type
+ cls.output_type = output_type
- async def init(self, syscall_vars: CallVars, **_kwargs: Any) -> dict[str, Any]:
+ async def init(self, syscall_vars: CallVars) -> dict[str, Any]:
"""初始化Call类,并返回Call的输入"""
- raise NotImplementedError
-
-
- async def exec(self) -> dict[str, Any]:
- """Call类实例的非流式输出方法"""
- raise NotImplementedError
+ ...
- async def stream(self) -> AsyncGenerator[str, None]:
+ async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
"""Call类实例的流式输出方法"""
- yield ""
+ err = "[CoreCall] 流式输出方法必须手动实现"
+ raise NotImplementedError(err)
diff --git a/apps/scheduler/call/empty.py b/apps/scheduler/call/empty.py
index 73741f1662ecfeadc37bc04f198dffb5a3cc15c7..2a712cf4ee76e7134701c69ca53ea20fb3d1ba5a 100644
--- a/apps/scheduler/call/empty.py
+++ b/apps/scheduler/call/empty.py
@@ -2,30 +2,37 @@
Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
"""
+from collections.abc import AsyncGenerator
from typing import Any, ClassVar
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
-from apps.entities.scheduler import CallVars
-from apps.scheduler.call.core import CoreCall
+from apps.entities.enum_var import CallOutputType
+from apps.entities.scheduler import CallOutputChunk, CallVars
+from apps.scheduler.call.core import CoreCall, DataBase
-class EmptyData(BaseModel):
- """空数据"""
+class EmptyInput(BaseModel):
+ """空输入"""
-class Empty(CoreCall, ret_type=EmptyData):
+class EmptyOutput(DataBase):
+ """空输出"""
+
+
+class Empty(CoreCall, input_type=EmptyInput, output_type=EmptyOutput):
"""空Call"""
- name: ClassVar[str] = "空白"
- description: ClassVar[str] = "空白节点,用于占位"
+ name: ClassVar[str] = Field("空白", exclude=True, frozen=True)
+ description: ClassVar[str] = Field("空白节点,用于占位", exclude=True, frozen=True)
- async def init(self, _syscall_vars: CallVars, **_kwargs: Any) -> dict[str, Any]:
+ async def init(self, syscall_vars: CallVars) -> dict[str, Any]:
"""初始化"""
return {}
- async def exec(self) -> dict[str, Any]:
+ async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
"""执行"""
- return EmptyData().model_dump()
+ output = CallOutputChunk(type=CallOutputType.TEXT, content="")
+ yield output
diff --git a/apps/scheduler/call/facts.py b/apps/scheduler/call/facts.py
deleted file mode 100644
index bec6f253e93487652eacad9c5ea5e3411cc1c477..0000000000000000000000000000000000000000
--- a/apps/scheduler/call/facts.py
+++ /dev/null
@@ -1,52 +0,0 @@
-"""提取事实工具"""
-from typing import Any, ClassVar
-
-from pydantic import BaseModel, Field
-
-from apps.entities.scheduler import CallVars
-from apps.llm.patterns.domain import Domain
-from apps.llm.patterns.facts import Facts
-from apps.manager.user_domain import UserDomainManager
-from apps.scheduler.call.core import CoreCall
-
-
-class FactsOutput(BaseModel):
- """提取事实工具的输出"""
-
- output: list[str] = Field(description="提取的事实")
-
-
-class FactsCall(CoreCall, ret_type=FactsOutput):
- """提取事实工具"""
-
- name: ClassVar[str] = "提取事实"
- description: ClassVar[str] = "从对话上下文和文档片段中提取事实。"
-
-
- async def init(self, syscall_vars: CallVars, **kwargs: Any) -> dict[str, Any]:
- """初始化工具"""
- # 组装必要变量
- self._message = {
- "question": syscall_vars.question,
- "answer": kwargs["answer"],
- }
- self._task_id = syscall_vars.task_id
- self._user_sub = syscall_vars.user_sub
-
- return {
- "message": self._message,
- "task_id": self._task_id,
- "user_sub": self._user_sub,
- }
-
-
- async def exec(self) -> dict[str, Any]:
- """执行工具"""
- # 提取事实信息
- facts = await Facts().generate(self._task_id, message=self._message)
- # 更新用户画像
- domain_list = await Domain().generate(self._task_id, message=self._message)
- for domain in domain_list:
- await UserDomainManager.update_user_domain_by_user_sub_and_domain_name(self._user_sub, domain)
-
- return FactsOutput(output=facts).model_dump(exclude_none=True, by_alias=True)
diff --git a/apps/scheduler/call/facts/facts.py b/apps/scheduler/call/facts/facts.py
new file mode 100644
index 0000000000000000000000000000000000000000..36e49d4e1921bdd3f418791d38388436e75680ee
--- /dev/null
+++ b/apps/scheduler/call/facts/facts.py
@@ -0,0 +1,58 @@
+"""提取事实工具"""
+from collections.abc import AsyncGenerator
+from typing import Annotated, Any, ClassVar
+
+from pydantic import Field
+
+from apps.entities.enum_var import CallOutputType
+from apps.entities.scheduler import CallOutputChunk, CallVars
+from apps.llm.patterns.domain import Domain
+from apps.llm.patterns.facts import Facts
+from apps.manager.user_domain import UserDomainManager
+from apps.scheduler.call.core import CoreCall
+from apps.scheduler.call.facts.schema import FactsInput, FactsMessage, FactsOutput
+
+
+class FactsCall(CoreCall, input_type=FactsInput, output_type=FactsOutput):
+ """提取事实工具"""
+
+ name: ClassVar[Annotated[str, Field(description="工具名称", exclude=True, frozen=True)]] = "提取事实"
+ description: ClassVar[Annotated[str, Field(description="工具描述", exclude=True, frozen=True)]] = "从对话上下文和文档片段中提取事实。"
+
+ answer: str = Field(description="大模型的回答")
+
+
+ async def init(self, syscall_vars: CallVars) -> dict[str, Any]:
+ """初始化工具"""
+ await super().init(syscall_vars)
+
+ # 组装必要变量
+ message = FactsMessage(
+ question=syscall_vars.question,
+ answer=self.answer,
+ )
+
+ return FactsInput(
+ task_id=syscall_vars.task_id,
+ user_sub=syscall_vars.user_sub,
+ message=message,
+ ).model_dump(exclude_none=True, by_alias=True)
+
+
+ async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
+ """执行工具"""
+ data = FactsInput(**input_data)
+ # 提取事实信息
+ facts = await Facts().generate(data.task_id, conversation=data.message)
+ # 更新用户画像
+ domain_list = await Domain().generate(data.task_id, conversation=data.message)
+ for domain in domain_list:
+ await UserDomainManager.update_user_domain_by_user_sub_and_domain_name(data.user_sub, domain)
+
+ yield CallOutputChunk(
+ type=CallOutputType.DATA,
+ content=FactsOutput(
+ output=facts,
+ domain=domain_list,
+ ).model_dump(by_alias=True, exclude_none=True),
+ )
diff --git a/apps/scheduler/call/facts/schema.py b/apps/scheduler/call/facts/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..abbfe6655fe2d2a291b4d9228dae254c17b66566
--- /dev/null
+++ b/apps/scheduler/call/facts/schema.py
@@ -0,0 +1,27 @@
+"""Facts工具的输入和输出"""
+
+from pydantic import BaseModel, Field
+
+from apps.scheduler.call.core import DataBase
+
+
+class FactsMessage(BaseModel):
+ """提取事实工具的消息"""
+
+ question: str = Field(description="用户输入")
+ answer: str = Field(description="大模型的回答")
+
+
+class FactsInput(DataBase):
+ """提取事实工具的输入"""
+
+ task_id: str = Field(description="任务ID")
+ user_sub: str = Field(description="用户ID")
+ message: FactsMessage = Field(description="消息")
+
+
+class FactsOutput(DataBase):
+ """提取事实工具的输出"""
+
+ output: list[str] = Field(description="提取的事实")
+ domain: list[str] = Field(description="提取的领域")
diff --git a/apps/scheduler/call/render/__init__.py b/apps/scheduler/call/graph/__init__.py
similarity index 100%
rename from apps/scheduler/call/render/__init__.py
rename to apps/scheduler/call/graph/__init__.py
diff --git a/apps/scheduler/call/render/render.py b/apps/scheduler/call/graph/graph.py
similarity index 48%
rename from apps/scheduler/call/render/render.py
rename to apps/scheduler/call/graph/graph.py
index 09bc3b8a05771a1317dc03137db38f290df83505..e480afa39ef18a8b6fc8b547fc863bfe88649916 100644
--- a/apps/scheduler/call/render/render.py
+++ b/apps/scheduler/call/graph/graph.py
@@ -3,115 +3,87 @@
Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
"""
import json
-from pathlib import Path
-from typing import Any
+from collections.abc import AsyncGenerator
+from typing import Any, ClassVar
-from pydantic import BaseModel, Field
+from anyio import Path
+from pydantic import Field
-from apps.entities.scheduler import CallError, CallVars
+from apps.entities.enum_var import CallOutputType
+from apps.entities.scheduler import CallError, CallOutputChunk, CallVars
from apps.scheduler.call.core import CoreCall
-from apps.scheduler.call.render.style import RenderStyle
+from apps.scheduler.call.graph.schema import RenderFormat, RenderInput, RenderOutput
+from apps.scheduler.call.graph.style import RenderStyle
-class _RenderAxis(BaseModel):
- """ECharts图表的轴配置"""
-
- type: str = Field(description="轴的类型")
- axisTick: dict = Field(description="轴刻度配置") # noqa: N815
-
-
-class _RenderFormat(BaseModel):
- """ECharts图表配置"""
-
- tooltip: dict[str, Any] = Field(description="ECharts图表的提示框配置")
- legend: dict[str, Any] = Field(description="ECharts图表的图例配置")
- dataset: dict[str, Any] = Field(description="ECharts图表的数据集配置")
- xAxis: _RenderAxis = Field(description="ECharts图表的X轴配置") # noqa: N815
- yAxis: _RenderAxis = Field(description="ECharts图表的Y轴配置") # noqa: N815
- series: list[dict[str, Any]] = Field(description="ECharts图表的数据列配置")
-
-
-class _RenderOutput(BaseModel):
- """Render工具的输出"""
-
- output: _RenderFormat = Field(description="ECharts图表配置")
- message: str = Field(description="Render工具的输出")
-
-
-class _RenderParam(BaseModel):
- """Render工具的参数"""
-
-
-
-class Render(metaclass=CoreCall, param_cls=_RenderParam, output_cls=_RenderOutput):
+class Graph(CoreCall, input_type=RenderInput, output_type=RenderOutput):
"""Render Call,用于将SQL Tool查询出的数据转换为图表"""
- name: str = "render"
- description: str = "渲染图表工具,可将给定的数据绘制为图表。"
+ name: ClassVar[str] = Field("图表", exclude=True, frozen=True)
+ description: ClassVar[str] = Field("将SQL查询出的数据转换为图表", exclude=True, frozen=True)
+
+ data: list[dict[str, Any]] = Field("用于绘制图表的数据", exclude=True, frozen=True)
- def init(self, _syscall_vars: CallVars, **_kwargs) -> None: # noqa: ANN003
+ async def init(self, syscall_vars: CallVars) -> dict[str, Any]:
"""初始化Render Call,校验参数,读取option模板"""
+ await super().init(syscall_vars)
+
try:
option_location = Path(__file__).parent / "option.json"
- with Path(option_location).open(encoding="utf-8") as f:
- self._option_template = json.load(f)
+ f = await Path(option_location).open(encoding="utf-8")
+ data = await f.read()
+ self._option_template = json.loads(data)
+ await f.aclose()
except Exception as e:
raise CallError(message=f"图表模板读取失败:{e!s}", data={}) from e
+ return RenderInput(
+ question=syscall_vars.question,
+ task_id=syscall_vars.task_id,
+ data=self.data,
+ ).model_dump(exclude_none=True, by_alias=True)
- async def __call__(self, _slot_data: dict[str, Any]) -> _RenderOutput:
- """运行Render Call"""
- # 获取必要参数
- syscall_vars: CallVars = getattr(self, "_syscall_vars")
-
- # 检测前一个工具是否为SQL
- if "dataset" not in syscall_vars.history[-1].output_data:
- raise CallError(
- message="图表生成失败!Render必须在SQL后调用!",
- data={},
- )
- data = syscall_vars.history[-1].output_data["output"]
- if data["type"] != "sql" or "dataset" not in data:
- raise CallError(
- message="图表生成失败!Render必须在SQL后调用!",
- data={},
- )
- data = json.loads(data["dataset"])
+ async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
+ """运行Render Call"""
+ data = RenderInput(**input_data)
# 判断数据格式是否满足要求
# 样例:[{'openeuler_version': 'openEuler-22.03-LTS-SP2', '软件数量': 10}]
malformed = True
- if isinstance(data, list) and len(data) > 0 and isinstance(data[0], dict):
+ if isinstance(data.data, list) and len(data.data) > 0 and isinstance(data.data[0], dict):
malformed = False
# 将执行SQL工具查询到的数据转换为特定格式
if malformed:
raise CallError(
- message="SQL未查询到数据,或数据格式错误,无法生成图表!",
- data={"data": data},
+ message="数据格式错误,无法生成图表!",
+ data={"data": data.data},
)
# 对少量数据进行处理
- column_num = len(data[0]) - 1
+ processed_data = data.data
+ column_num = len(processed_data[0]) - 1
if column_num == 0:
- data = Render._separate_key_value(data)
+ processed_data = Graph._separate_key_value(processed_data)
column_num = 1
# 该格式满足ECharts DataSource要求,与option模板进行拼接
- self._option_template["dataset"]["source"] = data
+ self._option_template["dataset"]["source"] = processed_data
try:
- llm_output = await RenderStyle().generate(syscall_vars.task_id, question=syscall_vars.question)
+ llm_output = await RenderStyle().generate(data.task_id, question=data.question)
add_style = llm_output.get("additional_style", "")
self._parse_options(column_num, llm_output["chart_type"], add_style, llm_output["scale_type"])
except Exception as e:
raise CallError(message=f"图表生成失败:{e!s}", data={"data": data}) from e
- return _RenderOutput(
- output=_RenderFormat.model_validate(self._option_template),
- message="图表生成成功!图表将使用外置工具进行展示。",
+ yield CallOutputChunk(
+ type=CallOutputType.DATA,
+ content=RenderOutput(
+ output=RenderFormat.model_validate(self._option_template),
+ ).model_dump(exclude_none=True, by_alias=True),
)
diff --git a/apps/scheduler/call/render/option.json b/apps/scheduler/call/graph/option.json
similarity index 100%
rename from apps/scheduler/call/render/option.json
rename to apps/scheduler/call/graph/option.json
diff --git a/apps/scheduler/call/graph/schema.py b/apps/scheduler/call/graph/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f26c5224997c5b0da4771a202298cab391f1187
--- /dev/null
+++ b/apps/scheduler/call/graph/schema.py
@@ -0,0 +1,39 @@
+"""图表工具的输入输出"""
+
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+from apps.scheduler.call.core import DataBase
+
+
+class RenderAxis(BaseModel):
+ """ECharts图表的轴配置"""
+
+ type: str = Field(description="轴的类型")
+ axisTick: dict = Field(description="轴刻度配置") # noqa: N815
+
+
+class RenderFormat(BaseModel):
+ """ECharts图表配置"""
+
+ tooltip: dict[str, Any] = Field(description="ECharts图表的提示框配置")
+ legend: dict[str, Any] = Field(description="ECharts图表的图例配置")
+ dataset: dict[str, Any] = Field(description="ECharts图表的数据集配置")
+ xAxis: RenderAxis = Field(description="ECharts图表的X轴配置") # noqa: N815
+ yAxis: RenderAxis = Field(description="ECharts图表的Y轴配置") # noqa: N815
+ series: list[dict[str, Any]] = Field(description="ECharts图表的数据列配置")
+
+
+class RenderInput(DataBase):
+ """图表工具的输入"""
+
+ question: str = Field(description="用户输入")
+ task_id: str = Field(description="任务ID")
+ data: list[dict[str, Any]] = Field(description="图表数据")
+
+
+class RenderOutput(DataBase):
+ """Render工具的输出"""
+
+ output: RenderFormat = Field(description="ECharts图表配置")
diff --git a/apps/scheduler/call/render/style.py b/apps/scheduler/call/graph/style.py
similarity index 100%
rename from apps/scheduler/call/render/style.py
rename to apps/scheduler/call/graph/style.py
diff --git a/apps/scheduler/call/llm.py b/apps/scheduler/call/llm/llm.py
similarity index 61%
rename from apps/scheduler/call/llm.py
rename to apps/scheduler/call/llm/llm.py
index 290bf699056150fd04e7675479fa3df79a0cd986..10c59130174b43257240661b72c8e4e9137e1410 100644
--- a/apps/scheduler/call/llm.py
+++ b/apps/scheduler/call/llm/llm.py
@@ -5,34 +5,31 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
import logging
from collections.abc import AsyncGenerator
from datetime import datetime
-from typing import Any, ClassVar
+from typing import Annotated, Any, ClassVar
import pytz
-from jinja2 import BaseLoader, select_autoescape
+from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment
-from pydantic import BaseModel, Field
+from pydantic import Field
-from apps.constants import LLM_CONTEXT_PROMPT, LLM_DEFAULT_PROMPT
-from apps.entities.scheduler import CallError, CallVars
+from apps.entities.enum_var import CallOutputType
+from apps.entities.scheduler import CallError, CallOutputChunk, CallVars
from apps.llm.reasoning import ReasoningLLM
from apps.scheduler.call.core import CoreCall
+from apps.scheduler.call.llm.schema import LLM_CONTEXT_PROMPT, LLM_DEFAULT_PROMPT, LLMInput, LLMOutput
logger = logging.getLogger("ray")
-class LLMNodeOutput(BaseModel):
- """定义LLM工具调用的输出"""
- message: str = Field(description="大模型输出的文字信息")
-
-class LLM(CoreCall, ret_type=LLMNodeOutput):
+class LLM(CoreCall, input_type=LLMInput, output_type=LLMOutput):
"""大模型调用工具"""
- name: ClassVar[str] = "大模型"
- description: ClassVar[str] = "以指定的提示词和上下文信息调用大模型,并获得输出。"
+ name: ClassVar[Annotated[str, Field(description="工具名称", exclude=True, frozen=True)]] = "大模型"
+ description: ClassVar[Annotated[str, Field(description="工具描述", exclude=True, frozen=True)]] = "以指定的提示词和上下文信息调用大模型,并获得输出。"
temperature: float = Field(description="大模型温度(随机化程度)", default=0.7)
enable_context: bool = Field(description="是否启用上下文", default=True)
- step_history_size: int = Field(description="上下文信息中包含的步骤历史数量", default=3)
+ step_history_size: int = Field(description="上下文信息中包含的步骤历史数量", default=3, ge=1, le=10)
system_prompt: str = Field(description="大模型系统提示词", default="")
user_prompt: str = Field(description="大模型用户提示词", default=LLM_DEFAULT_PROMPT)
@@ -42,7 +39,7 @@ class LLM(CoreCall, ret_type=LLMNodeOutput):
# 创建共享的 Environment 实例
env = SandboxedEnvironment(
loader=BaseLoader(),
- autoescape=select_autoescape(),
+ autoescape=False,
trim_blocks=True,
lstrip_blocks=True,
)
@@ -83,33 +80,24 @@ class LLM(CoreCall, ret_type=LLMNodeOutput):
]
- async def init(self, syscall_vars: CallVars, **_kwargs: Any) -> dict[str, Any]:
+ async def init(self, syscall_vars: CallVars) -> dict[str, Any]:
"""初始化LLM工具"""
- self._message = await self._prepare_message(syscall_vars)
- self._task_id = syscall_vars.task_id
- return {
- "task_id": self._task_id,
- "message": self._message,
- }
+ await super().init(syscall_vars)
+ return LLMInput(
+ task_id=syscall_vars.task_id,
+ message=await self._prepare_message(syscall_vars),
+ ).model_dump(exclude_none=True, by_alias=True)
- async def exec(self) -> dict[str, Any]:
+
+ async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
"""运行LLM Call"""
+ data = LLMInput(**input_data)
try:
- result = ""
- async for chunk in ReasoningLLM().call(task_id=self._task_id, messages=self._message):
- result += chunk
+ # TODO: 临时改成了非流式
+ async for chunk in ReasoningLLM().call(task_id=data.task_id, messages=data.message, streaming=False):
+ if not chunk:
+ continue
+ yield CallOutputChunk(type=CallOutputType.TEXT, content=chunk)
except Exception as e:
raise CallError(message=f"大模型调用失败:{e!s}", data={}) from e
-
- result = result.strip().strip("\n")
- return LLMNodeOutput(message=result).model_dump(exclude_none=True, by_alias=True)
-
-
- async def stream(self) -> AsyncGenerator[str, None]:
- """流式输出"""
- try:
- async for chunk in ReasoningLLM().call(task_id=self._task_id, messages=self._message):
- yield chunk
- except Exception as e:
- raise CallError(message=f"大模型流式输出失败:{e!s}", data={}) from e
diff --git a/apps/scheduler/call/llm/schema.py b/apps/scheduler/call/llm/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..08bd4fabe808e20e3d54d1f7add19cb6b8f62d45
--- /dev/null
+++ b/apps/scheduler/call/llm/schema.py
@@ -0,0 +1,107 @@
+"""LLM工具的输入输出定义"""
+from textwrap import dedent
+
+from pydantic import Field
+
+from apps.scheduler.call.core import DataBase
+
+LLM_CONTEXT_PROMPT = dedent(
+ r"""
+ 以下是对用户和AI间对话的简短总结:
+ {{ summary }}
+
+ 你作为AI,在回答用户的问题前,需要获取必要信息。为此,你调用了以下工具,并获得了输出:
+
+ {% for tool in history_data %}
+ {{ tool.step_id }}
+
+ {% endfor %}
+
+ """,
+ ).strip("\n")
+LLM_DEFAULT_PROMPT = dedent(
+ r"""
+
+ 你是一个乐于助人的智能助手。请结合给出的背景信息, 回答用户的提问。
+ 当前时间:{{ time }},可以作为时间参照。
+ 用户的问题将在中给出,上下文背景信息将在中给出。
+ 注意:输出不要包含任何XML标签,不要编造任何信息。若你认为用户提问与背景信息无关,请忽略背景信息直接作答。
+
+
+
+ {{ question }}
+
+
+
+ {{ context }}
+
+
+ 现在,输出你的回答:
+ """,
+ ).strip("\n")
+LLM_ERROR_PROMPT = dedent(
+ r"""
+
+ 你是一位智能助手,能够根据用户的问题,使用Python工具获取信息,并作出回答。你在使用工具回答用户的问题时,发生了错误。
+ 你的任务是:分析工具(Python程序)的异常信息,分析造成该异常可能的原因,并以通俗易懂的方式,将原因告知用户。
+
+ 当前时间:{{ time }},可以作为时间参照。
+ 发生错误的程序异常信息将在中给出,用户的问题将在中给出,上下文背景信息将在中给出。
+ 注意:输出不要包含任何XML标签,不要编造任何信息。若你认为用户提问与背景信息无关,请忽略背景信息。
+
+
+
+ {{ error_info }}
+
+
+
+ {{ question }}
+
+
+
+ {{ context }}
+
+
+ 现在,输出你的回答:
+ """,
+).strip("\n")
+RAG_ANSWER_PROMPT = dedent(
+ r"""
+
+ 你是由openEuler社区构建的大型语言AI助手。请根据背景信息(包含对话上下文和文档片段),回答用户问题。
+ 用户的问题将在中给出,上下文背景信息将在中给出,文档片段将在中给出。
+
+ 注意事项:
+ 1. 输出不要包含任何XML标签。请确保输出内容的正确性,不要编造任何信息。
+ 2. 如果用户询问你关于你自己的问题,请统一回答:“我叫EulerCopilot,是openEuler社区的智能助手”。
+ 3. 背景信息仅供参考,若背景信息与用户问题无关,请忽略背景信息直接作答。
+ 4. 请在回答中使用Markdown格式,并**不要**将内容放在"```"中。
+
+
+
+ {{ question }}
+
+
+
+ {{ context }}
+
+
+
+ {{ document }}
+
+
+ 现在,请根据上述信息,回答用户的问题:
+ """,
+ ).strip("\n")
+
+
+class LLMInput(DataBase):
+ """定义LLM工具调用的输入"""
+
+ task_id: str = Field(description="任务ID")
+ message: list[dict[str, str]] = Field(description="输入给大模型的消息列表")
+
+
+class LLMOutput(DataBase):
+ """定义LLM工具调用的输出"""
+
diff --git a/apps/scheduler/call/output/output.py b/apps/scheduler/call/output/output.py
new file mode 100644
index 0000000000000000000000000000000000000000..23e21fed0e29ebcab36b2d4fb164a768b1238c6e
--- /dev/null
+++ b/apps/scheduler/call/output/output.py
@@ -0,0 +1,41 @@
+"""输出工具:将文字和结构化数据输出至前端"""
+from collections.abc import AsyncGenerator
+from typing import Annotated, Any, ClassVar
+
+from pydantic import Field
+from ray.actor import ActorHandle
+
+from apps.entities.enum_var import CallOutputType
+from apps.entities.scheduler import CallOutputChunk, CallVars
+from apps.scheduler.call.core import CoreCall
+from apps.scheduler.call.output.schema import OutputInput, OutputOutput
+
+
+class Output(CoreCall, input_type=OutputInput, output_type=OutputOutput):
+ """输出工具"""
+
+ name: ClassVar[Annotated[str, Field(description="工具名称", exclude=True, frozen=True)]] = "输出"
+ description: ClassVar[Annotated[str, Field(description="工具描述", exclude=True, frozen=True)]] = "将文字和结构化数据输出至前端"
+
+ template: str = Field(description="输出模板(只能使用直接变量引用)")
+
+
+ async def init(self, syscall_vars: CallVars) -> dict[str, Any]:
+ """初始化工具"""
+ self._text: AsyncGenerator[str, None] = kwargs["text"]
+ self._data: dict[str, Any] = kwargs["data"]
+
+ return {
+ "text": self._text,
+ "data": self._data,
+ }
+
+
+ async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
+ """执行工具"""
+ yield CallOutputChunk(
+ type=CallOutputType.TEXT,
+ content=OutputOutput(
+ output=self._text,
+ ).model_dump(by_alias=True, exclude_none=True),
+ )
diff --git a/apps/scheduler/call/output/schema.py b/apps/scheduler/call/output/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..468c07ed4c0f4e9e6ebade2b863898fa8d61f8ae
--- /dev/null
+++ b/apps/scheduler/call/output/schema.py
@@ -0,0 +1,15 @@
+"""输出工具的数据结构"""
+
+from pydantic import Field
+
+from apps.scheduler.call.core import DataBase
+
+
+class OutputInput(DataBase):
+ """输出工具的输入"""
+
+
+class OutputOutput(DataBase):
+ """输出工具的输出"""
+
+ output: str = Field(description="输出工具的输出")
diff --git a/apps/scheduler/call/rag.py b/apps/scheduler/call/rag/rag.py
similarity index 46%
rename from apps/scheduler/call/rag.py
rename to apps/scheduler/call/rag/rag.py
index 4c0e2278718ce53f5e95b62cf9e48a8aaaa2ceb2..c3eadaab92ea85e5085be7aa90a7b29dd381084f 100644
--- a/apps/scheduler/call/rag.py
+++ b/apps/scheduler/call/rag/rag.py
@@ -3,57 +3,60 @@
Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
"""
import logging
-from typing import Any, ClassVar, Literal, Optional
+from collections.abc import AsyncGenerator
+from typing import Annotated, Any, ClassVar, Literal, Optional
import aiohttp
from fastapi import status
-from pydantic import BaseModel, Field
+from pydantic import Field
from apps.common.config import config
-from apps.entities.scheduler import CallError, CallVars
+from apps.entities.enum_var import CallOutputType
+from apps.entities.scheduler import CallError, CallOutputChunk, CallVars
+from apps.llm.patterns.rewrite import QuestionRewrite
from apps.scheduler.call.core import CoreCall
+from apps.scheduler.call.rag.schema import RAGInput, RAGOutput, RetrievalMode
logger = logging.getLogger("ray")
-class RAGOutput(BaseModel):
- """RAG工具的输出"""
-
- corpus: list[str] = Field(description="知识库的语料列表")
-
-
-class RAG(CoreCall, ret_type=RAGOutput):
+class RAG(CoreCall, input_type=RAGInput, output_type=RAGOutput):
"""RAG工具:查询知识库"""
- name: ClassVar[str] = "知识库"
- description: ClassVar[str] = "查询知识库,从文档中获取必要信息"
+ name: ClassVar[Annotated[str, Field(description="工具名称", exclude=True, frozen=True)]] = "知识库"
+ description: ClassVar[Annotated[str, Field(description="工具描述", exclude=True, frozen=True)]] = "查询知识库,从文档中获取必要信息"
knowledge_base: Optional[str] = Field(description="知识库的id", alias="kb_sn", default=None)
top_k: int = Field(description="返回的答案数量(经过整合以及上下文关联)", default=5)
retrieval_mode: Literal["chunk", "full_text"] = Field(description="检索模式", default="chunk")
- async def init(self, syscall_vars: CallVars, **_kwargs: Any) -> dict[str, Any]:
+ async def init(self, syscall_vars: CallVars) -> dict[str, Any]:
"""初始化RAG工具"""
- self._params_dict = {
- "kb_sn": self.knowledge_base,
- "top_k": self.top_k,
- "retrieval_mode": self.retrieval_mode,
- "content": syscall_vars.question,
- }
+ await super().init(syscall_vars)
- self._url = config["RAG_HOST"].rstrip("/") + "/chunk/get"
- self._headers = {
- "Content-Type": "application/json",
- }
+ self._task_id = syscall_vars.task_id
+ return RAGInput(
+ content=syscall_vars.question,
+ kb_sn=self.knowledge_base,
+ top_k=self.top_k,
+ retrieval_mode=RetrievalMode(self.retrieval_mode),
+ ).model_dump(by_alias=True, exclude_none=True)
- return self._params_dict
-
- async def exec(self) -> dict[str, Any]:
+ async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
"""调用RAG工具"""
+ data = RAGInput(**input_data)
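+ # 先对用户问题进行改写,再用改写后的问题检索知识库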
+ question = await QuestionRewrite().generate(self._task_id, question=data.content)
+ data.content = question
+
+ url = config["RAG_HOST"].rstrip("/") + "/chunk/get"
+ headers = {
+ "Content-Type": "application/json",
+ }
+
# 发送 GET 请求
- async with aiohttp.ClientSession() as session, session.post(self._url, headers=self._headers, json=self._params_dict) as response:
+ async with aiohttp.ClientSession() as session, session.post(url, headers=headers, json=data.model_dump(by_alias=True, exclude_none=True)) as response:
# 检查响应状态码
if response.status == status.HTTP_200_OK:
result = await response.json()
@@ -64,9 +67,14 @@ class RAG(CoreCall, ret_type=RAGOutput):
clean_chunk = chunk.replace("\n", " ")
corpus.append(clean_chunk)
- return RAGOutput(
- corpus=corpus,
- ).model_dump(exclude_none=True, by_alias=True)
+ yield CallOutputChunk(
+ type=CallOutputType.DATA,
+ content=RAGOutput(
+ question=data.content,
+ corpus=corpus,
+ ).model_dump(exclude_none=True, by_alias=True),
+ )
+ return
text = await response.text()
logger.error("[RAG] 调用失败:%s", text)
@@ -74,7 +82,7 @@ class RAG(CoreCall, ret_type=RAGOutput):
raise CallError(
message=f"rag调用失败:{text}",
data={
- "question": self._params_dict["content"],
+ "question": data.content,
"status": response.status,
"text": text,
},
diff --git a/apps/scheduler/call/rag/schema.py b/apps/scheduler/call/rag/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..0928792212c8e17e8a294858279fd40432b5ca1b
--- /dev/null
+++ b/apps/scheduler/call/rag/schema.py
@@ -0,0 +1,31 @@
+"""RAG工具的输入和输出"""
+
+from enum import Enum
+from typing import Optional
+
+from pydantic import Field
+
+from apps.scheduler.call.core import DataBase
+
+
+class RetrievalMode(str, Enum):
+ """检索模式"""
+
+ CHUNK = "chunk"
+ FULL_TEXT = "full_text"
+
+
+class RAGOutput(DataBase):
+ """RAG工具的输出"""
+
+ question: str = Field(description="用户输入")
+ corpus: list[str] = Field(description="知识库的语料列表")
+
+
+class RAGInput(DataBase):
+ """RAG工具的输入"""
+
+ content: str = Field(description="用户输入")
+ knowledge_base: Optional[str] = Field(description="知识库的id", alias="kb_sn", default=None)
+ top_k: int = Field(description="返回的答案数量(经过整合以及上下文关联)", default=5)
+ retrieval_mode: RetrievalMode = Field(description="检索模式", default=RetrievalMode.CHUNK)
diff --git a/apps/scheduler/call/search/search.py b/apps/scheduler/call/search/search.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2be72baa3c0bdf0ae9bbea508df6eb0bbbe7d5d
--- /dev/null
+++ b/apps/scheduler/call/search/search.py
@@ -0,0 +1,31 @@
+"""搜索工具"""
+from collections.abc import AsyncGenerator
+from typing import Any, ClassVar
+
+from pydantic import Field
+
+from apps.entities.enum_var import CallOutputType
+from apps.entities.scheduler import CallOutputChunk, CallVars
+from apps.scheduler.call.core import CoreCall, DataBase
+
+
+class SearchInput(DataBase):
+    """搜索工具的输入"""
+
+    query: str = Field(description="搜索关键字", default="")
+
+
+class SearchOutput(DataBase):
+    """搜索工具返回值"""
+
+    data: list[dict[str, Any]] = Field(description="搜索结果", default=[])
+
+
+class Search(CoreCall, input_type=SearchInput, output_type=SearchOutput):
+    """搜索工具(占位实现,尚未接入具体搜索引擎)"""
+
+    name: ClassVar[str] = "搜索"
+    description: ClassVar[str] = "获取搜索引擎的结果"
+
+    query: str = Field(description="搜索关键字;为空时使用用户问题", default="")
+
+    async def init(self, syscall_vars: CallVars) -> dict[str, Any]:
+        """初始化工具"""
+        await super().init(syscall_vars)
+        return SearchInput(query=self.query or syscall_vars.question).model_dump(by_alias=True, exclude_none=True)
+
+    async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
+        """执行工具(占位:暂时返回空的搜索结果)"""
+        yield CallOutputChunk(
+            type=CallOutputType.DATA,
+            content=SearchOutput(data=[]).model_dump(by_alias=True, exclude_none=True),
+        )
+
diff --git a/apps/scheduler/call/slot/schema.py b/apps/scheduler/call/slot/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d053aa58237975c0dff3b2104fe73d81f791820
--- /dev/null
+++ b/apps/scheduler/call/slot/schema.py
@@ -0,0 +1,24 @@
+"""参数填充工具的Schema"""
+from typing import Any
+
+from pydantic import Field
+
+from apps.scheduler.call.core import DataBase
+
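+# 参数生成的示例对话(占位,具体内容待补充)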
+SLOT_GEN_PROMPT = [
+ {"role": "user", "content": ""},
+ {"role": "assistant", "content": ""},
+]
+
+
+class SlotInput(DataBase):
+ """参数填充工具的输入"""
+
+ remaining_schema: dict[str, Any] = Field(description="剩余的Schema", default={})
+
+
+class SlotOutput(DataBase):
+ """参数填充工具的输出"""
+
+ slot_data: dict[str, Any] = Field(description="填充后的数据", default={})
+ remaining_schema: dict[str, Any] = Field(description="剩余的Schema", default={})
diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d4367baf18d8956a88027cdcbbef2c8f24d5b12
--- /dev/null
+++ b/apps/scheduler/call/slot/slot.py
@@ -0,0 +1,68 @@
+"""参数填充工具"""
+from collections.abc import AsyncGenerator
+from typing import Annotated, Any, ClassVar
+
+from pydantic import Field
+
+from apps.entities.enum_var import CallOutputType
+from apps.entities.scheduler import CallOutputChunk, CallVars
+from apps.scheduler.call.core import CoreCall
+from apps.scheduler.call.slot.schema import SlotInput, SlotOutput
+from apps.scheduler.slot.slot import Slot as SlotProcessor
+
+
+class Slot(CoreCall, input_type=SlotInput, output_type=SlotOutput):
+ """参数填充工具"""
+
+ name: ClassVar[Annotated[str, Field(description="Call的名称", exclude=True)]] = "参数自动填充"
+ description: ClassVar[Annotated[str, Field(description="Call的描述", exclude=True)]] = "根据步骤历史,自动填充参数"
+
+ data: dict[str, Any] = Field(description="当前输入", default={})
+ current_schema: dict[str, Any] = Field(description="当前Schema", default={})
+ summary: str = Field(description="背景信息总结", default="")
+ facts: list[str] = Field(description="事实信息", default=[])
+ step_num: int = Field(description="历史步骤数", default=1)
+
+
+ async def _llm_slot_fill(self, remaining_schema: dict[str, Any]) -> dict[str, Any]:
+ """使用LLM填充参数"""
+ ...
+
+
+ async def init(self, syscall_vars: CallVars) -> dict[str, Any]:
+ """初始化"""
+ await super().init(syscall_vars)
+
+ self._task_id = syscall_vars.task_id
+
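+        # 未配置Schema时无需填参,直接返回空的剩余Schema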
+ if not self.current_schema:
+ return SlotInput(
+ remaining_schema={},
+ ).model_dump(by_alias=True, exclude_none=True)
+
+ self._processor = SlotProcessor(self.current_schema)
+ remaining_schema = self._processor.check_json(self.data)
+
+ return SlotInput(
+ remaining_schema=remaining_schema,
+ ).model_dump(by_alias=True, exclude_none=True)
+
+
+ async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
+ """执行参数填充"""
+ data = SlotInput(**input_data)
+
+ # 使用LLM填充参数
+ slot_data = await self._llm_slot_fill(data.remaining_schema)
+
+ # 再次检查
+ remaining_schema = self._processor.check_json(slot_data)
+ yield CallOutputChunk(
+ type=CallOutputType.DATA,
+ content=SlotOutput(
+ slot_data=slot_data,
+ remaining_schema=remaining_schema,
+ ).model_dump(by_alias=True, exclude_none=True),
+ )
diff --git a/apps/scheduler/call/sql/schema.py b/apps/scheduler/call/sql/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1dc669db95684eb676247f70961569b22788c72
--- /dev/null
+++ b/apps/scheduler/call/sql/schema.py
@@ -0,0 +1,19 @@
+"""SQL工具的输入输出"""
+
+from typing import Any
+
+from pydantic import Field
+
+from apps.scheduler.call.core import DataBase
+
+
+class SQLInput(DataBase):
+ """SQL工具的输入"""
+
+ question: str = Field(description="用户输入")
+
+
+class SQLOutput(DataBase):
+ """SQL工具的输出"""
+
+ dataset: list[dict[str, Any]] = Field(description="SQL工具的执行结果")
diff --git a/apps/scheduler/call/sql.py b/apps/scheduler/call/sql/sql.py
similarity index 52%
rename from apps/scheduler/call/sql.py
rename to apps/scheduler/call/sql/sql.py
index 312c5faafed695c2eb015d7f1ab22ca2197f03b4..28c763c0b209e5cd0596a4c454d16003c0f1b9dd 100644
--- a/apps/scheduler/call/sql.py
+++ b/apps/scheduler/call/sql/sql.py
@@ -5,65 +5,68 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
"""
import json
import logging
-from typing import Any, Optional
+from collections.abc import AsyncGenerator
+from typing import Annotated, Any, ClassVar, Optional
import aiohttp
from fastapi import status
-from pydantic import BaseModel, Field
+from pydantic import Field
from sqlalchemy import text
from apps.common.config import config
-from apps.entities.scheduler import CallError, CallVars
+from apps.entities.enum_var import CallOutputType
+from apps.entities.scheduler import CallError, CallOutputChunk, CallVars
from apps.models.postgres import PostgreSQL
from apps.scheduler.call.core import CoreCall
+from apps.scheduler.call.sql.schema import SQLInput, SQLOutput
logger = logging.getLogger("ray")
-class SQLOutput(BaseModel):
- """SQL工具的输出"""
-
- message: str = Field(description="SQL工具的执行结果")
- dataset: list[dict[str, Any]] = Field(description="SQL工具的执行结果")
-
-
class SQL(CoreCall):
"""SQL工具。用于调用外置的Chat2DB工具的API,获得SQL语句;再在PostgreSQL中执行SQL语句,获得数据。"""
- name: str = "数据库"
- description: str = "使用大模型生成SQL语句,用于查询数据库中的结构化数据"
+    name: ClassVar[Annotated[str, Field(description="工具名称", exclude=True, frozen=True)]] = "数据库"
+    description: ClassVar[Annotated[str, Field(description="工具描述", exclude=True, frozen=True)]] = "使用大模型生成SQL语句,用于查询数据库中的结构化数据"
+
sql: Optional[str] = Field(description="用户输入")
- def init(self, _syscall_vars: CallVars, **_kwargs) -> None: # noqa: ANN003
+ async def init(self, syscall_vars: CallVars) -> dict[str, Any]:
"""初始化SQL工具。"""
- # 初始化aiohttp的ClientSession
- self._session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(300))
+ await super().init(syscall_vars)
+
+ return SQLInput(
+ question=syscall_vars.question,
+ ).model_dump(by_alias=True, exclude_none=True)
- async def exec(self, _slot_data: dict[str, Any]) -> SQLOutput:
+ async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
"""运行SQL工具"""
# 获取必要参数
- syscall_vars: CallVars = getattr(self, "_syscall_vars")
+ data = SQLInput(**input_data)
# 若手动设置了SQL,则直接使用
- session = await PostgreSQL.get_session()
- if params.sql:
+ pg_session = await PostgreSQL.get_session()
+ if self.sql:
try:
- result = (await session.execute(text(params.sql))).all()
- await session.close()
+ result = (await pg_session.execute(text(self.sql))).all()
+ await pg_session.close()
dataset_list = [db_item._asdict() for db_item in result]
- return SQLOutput(
- message="SQL查询成功!",
- dataset=dataset_list,
+ yield CallOutputChunk(
+ type=CallOutputType.DATA,
+ content=SQLOutput(
+ dataset=dataset_list,
+ ).model_dump(by_alias=True, exclude_none=True),
)
except Exception as e:
raise CallError(message=f"SQL查询错误:{e!s}", data={}) from e
+ else:
+ return
# 若未设置SQL,则调用Chat2DB工具API,获取SQL
post_data = {
- "question": syscall_vars.question,
+ "question": data.question,
"topk_sql": 5,
"use_llm_enhancements": True,
}
@@ -71,7 +74,9 @@ class SQL(CoreCall):
"Content-Type": "application/json",
}
- async with self._session.post(config["SQL_URL"], ssl=False, json=post_data, headers=headers) as response:
+ async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(300)) as session, session.post(
+ config["SQL_URL"], ssl=False, json=post_data, headers=headers,
+ ) as response:
if response.status != status.HTTP_200_OK:
raise CallError(
message=f"SQL查询错误:API返回状态码{response.status}, 详细原因为{response.reason}。",
@@ -79,20 +84,25 @@ class SQL(CoreCall):
)
result = json.loads(await response.text())
logger.info("SQL工具返回的信息为:%s", result)
- await self._session.close()
+
for item in result["sql_list"]:
try:
- db_result = (await session.execute(text(item["sql"]))).all()
- await session.close()
+ db_result = (await pg_session.execute(text(item["sql"]))).all()
+ await pg_session.close()
dataset_list = [db_item._asdict() for db_item in db_result]
- return SQLOutput(
- message="数据库查询成功!",
- dataset=dataset_list,
+ yield CallOutputChunk(
+ type=CallOutputType.DATA,
+ content=SQLOutput(
+ dataset=dataset_list,
+ ).model_dump(by_alias=True, exclude_none=True),
)
except Exception: # noqa: PERF203
logger.exception("SQL查询错误,正在换用下一条SQL语句。")
+ continue
+ else:
+ return
raise CallError(
message="SQL查询错误:SQL语句错误,数据库查询失败!",
diff --git a/apps/scheduler/call/suggest/schema.py b/apps/scheduler/call/suggest/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..91c34bfd8b4e0f2a8b517b600193eb8cb952c5e0
--- /dev/null
+++ b/apps/scheduler/call/suggest/schema.py
@@ -0,0 +1,37 @@
+"""问题推荐工具的输入输出"""
+
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+from apps.scheduler.call.core import DataBase
+
+
+class SingleFlowSuggestionConfig(BaseModel):
+ """涉及单个Flow的问题推荐配置"""
+
+ flow_id: str
+ question: Optional[str] = Field(default=None, description="固定的推荐问题")
+
+
+class SuggestionOutputItem(BaseModel):
+ """问题推荐结果的单个条目"""
+
+ question: str
+ app_id: str
+ flow_id: str
+ flow_description: str
+
+
+class SuggestionInput(DataBase):
+ """问题推荐输入"""
+
+ question: str
+ task_id: str
+ user_sub: str
+
+
+class SuggestionOutput(DataBase):
+ """问题推荐结果"""
+
+ output: list[SuggestionOutputItem]
diff --git a/apps/scheduler/call/suggest.py b/apps/scheduler/call/suggest/suggest.py
similarity index 38%
rename from apps/scheduler/call/suggest.py
rename to apps/scheduler/call/suggest/suggest.py
index 96181c758984594297ffcc5c953d12c2ce089881..ba031c3c81b80b60c95b128ec3521fd10776246a 100644
--- a/apps/scheduler/call/suggest.py
+++ b/apps/scheduler/call/suggest/suggest.py
@@ -2,72 +2,54 @@
Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
"""
-from typing import Any, ClassVar, Optional
+from collections.abc import AsyncGenerator
+from typing import Annotated, Any, ClassVar
import ray
-from pydantic import BaseModel, Field
+from pydantic import Field
-from apps.entities.scheduler import CallError, CallVars
+from apps.entities.scheduler import CallError, CallOutputChunk, CallVars
from apps.manager.user_domain import UserDomainManager
from apps.scheduler.call.core import CoreCall
+from apps.scheduler.call.suggest.schema import (
+ SingleFlowSuggestionConfig,
+ SuggestionInput,
+ SuggestionOutput,
+)
-class _SingleFlowSuggestionConfig(BaseModel):
- """涉及单个Flow的问题推荐配置"""
-
- flow_id: str
- question: Optional[str] = Field(default=None, description="固定的推荐问题")
-
-
-class _SuggestionOutputItem(BaseModel):
- """问题推荐结果的单个条目"""
-
- question: str
- app_id: str
- flow_id: str
- flow_description: str
-
-
-class SuggestionOutput(BaseModel):
- """问题推荐结果"""
-
- output: list[_SuggestionOutputItem]
-
-
-class Suggestion(CoreCall, ret_type=SuggestionOutput):
+class Suggestion(CoreCall, input_type=SuggestionInput, output_type=SuggestionOutput):
"""问题推荐"""
- name: ClassVar[str] = "问题推荐"
- description: ClassVar[str] = "在答案下方显示推荐的下一个问题"
+ name: ClassVar[Annotated[str, Field(description="工具名称", exclude=True, frozen=True)]] = "问题推荐"
+ description: ClassVar[Annotated[str, Field(description="工具描述", exclude=True, frozen=True)]] = "在答案下方显示推荐的下一个问题"
- configs: list[_SingleFlowSuggestionConfig] = Field(description="问题推荐配置", min_length=1)
+ configs: list[SingleFlowSuggestionConfig] = Field(description="问题推荐配置", min_length=1)
num: int = Field(default=3, ge=1, le=6, description="推荐问题的总数量(必须大于等于configs中涉及的Flow的数量)")
+ context: list[dict[str, str]]
+
- async def init(self, syscall_vars: CallVars, **_kwargs: Any) -> dict[str, Any]:
+ async def init(self, syscall_vars: CallVars) -> dict[str, Any]:
"""初始化"""
- # 普通变量
- self._question = syscall_vars.question
- self._task_id = syscall_vars.task_id
- self._background = syscall_vars.summary
- self._user_sub = syscall_vars.user_sub
+ await super().init(syscall_vars)
- return {
- "question": self._question,
- "task_id": self._task_id,
- "background": self._background,
- "user_sub": self._user_sub,
- }
+ return SuggestionInput(
+ question=syscall_vars.question,
+ task_id=syscall_vars.task_id,
+ user_sub=syscall_vars.user_sub,
+ ).model_dump(by_alias=True, exclude_none=True)
- async def exec(self) -> dict[str, Any]:
+ async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
"""运行问题推荐"""
+ data = SuggestionInput(**input_data)
# 获取当前任务
task_actor = await ray.get_actor("task")
- task = await task_actor.get_task.remote(self._task_id)
+ task = await task_actor.get_task.remote(data.task_id)
# 获取当前用户的画像
- user_domain = await UserDomainManager.get_user_domain_by_user_sub_and_topk(self._user_sub, 5)
+ user_domain = await UserDomainManager.get_user_domain_by_user_sub_and_topk(data.user_sub, 5)
current_record = [
{
diff --git a/apps/scheduler/call/summary.py b/apps/scheduler/call/summary.py
deleted file mode 100644
index 27b18280e839449126ab618e088a356b6e53ca90..0000000000000000000000000000000000000000
--- a/apps/scheduler/call/summary.py
+++ /dev/null
@@ -1,43 +0,0 @@
-"""总结上下文工具
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
-"""
-from typing import Any, ClassVar
-
-from pydantic import BaseModel, Field
-
-from apps.entities.scheduler import CallVars, ExecutorBackground
-from apps.llm.patterns.executor import ExecutorSummary
-from apps.scheduler.call.core import CoreCall
-
-
-class SummaryOutput(BaseModel):
- """总结工具的输出"""
-
- summary: str = Field(description="对问答上下文的总结内容")
-
-
-class Summary(CoreCall, ret_type=SummaryOutput):
- """总结工具"""
-
- name: ClassVar[str] = "理解上下文"
- description: ClassVar[str] = "使用大模型,理解对话上下文"
-
-
- async def init(self, syscall_vars: CallVars, **kwargs: Any) -> dict[str, Any]:
- """初始化工具"""
- self._context: ExecutorBackground = kwargs["context"]
- self._task_id: str = syscall_vars.task_id
- if not self._context or not isinstance(self._context, list):
- err = "context参数必须是一个列表"
- raise ValueError(err)
-
- return {
- "task_id": self._task_id,
- }
-
-
- async def exec(self) -> dict[str, Any]:
- """执行工具"""
- summary = await ExecutorSummary().generate(self._task_id, background=self._context)
- return SummaryOutput(summary=summary).model_dump(by_alias=True, exclude_none=True)
diff --git a/apps/scheduler/call/summary/schema.py b/apps/scheduler/call/summary/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..7247fd3353dd6f775cf68a1d238dbef470eb854d
--- /dev/null
+++ b/apps/scheduler/call/summary/schema.py
@@ -0,0 +1,17 @@
+"""总结工具的输入和输出"""
+
+from pydantic import Field
+
+from apps.scheduler.call.core import DataBase
+
+
+class SummaryInput(DataBase):
+ """总结工具的输入"""
+
+ task_id: str = Field(description="任务ID")
+
+
+class SummaryOutput(DataBase):
+ """总结工具的输出"""
+
+ summary: str = Field(description="对问答上下文的总结内容")
diff --git a/apps/scheduler/call/summary/summary.py b/apps/scheduler/call/summary/summary.py
new file mode 100644
index 0000000000000000000000000000000000000000..52736cbe94bb4dabb0344c5b5d397e3c8eb4c21d
--- /dev/null
+++ b/apps/scheduler/call/summary/summary.py
@@ -0,0 +1,39 @@
+"""总结上下文工具
+
+Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+"""
+from collections.abc import AsyncGenerator
+from typing import Annotated, Any, ClassVar
+
+from pydantic import Field
+
+from apps.entities.enum_var import CallOutputType
+from apps.entities.scheduler import CallOutputChunk, CallVars, ExecutorBackground
+from apps.llm.patterns.executor import ExecutorSummary
+from apps.scheduler.call.core import CoreCall
+from apps.scheduler.call.summary.schema import SummaryInput, SummaryOutput
+
+
+class Summary(CoreCall, input_type=SummaryInput, output_type=SummaryOutput):
+ """总结工具"""
+
+    name: ClassVar[Annotated[str, Field(description="工具名称", exclude=True, frozen=True)]] = "理解上下文"
+    description: ClassVar[Annotated[str, Field(description="工具描述", exclude=True, frozen=True)]] = "使用大模型,理解对话上下文"
+
+ context: ExecutorBackground = Field(description="对话上下文")
+
+
+ async def init(self, syscall_vars: CallVars) -> dict[str, Any]:
+ """初始化工具"""
+ await super().init(syscall_vars)
+
+ return SummaryInput(
+ task_id=syscall_vars.task_id,
+ ).model_dump(by_alias=True, exclude_none=True)
+
+
+ async def exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
+ """执行工具"""
+ data = SummaryInput(**input_data)
+ summary = await ExecutorSummary().generate(data.task_id, background=self.context)
+ yield CallOutputChunk(type=CallOutputType.TEXT, content=summary)
diff --git a/apps/scheduler/executor/base.py b/apps/scheduler/executor/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..30ad51f636d592e87d640f23f335d69ce939cb35
--- /dev/null
+++ b/apps/scheduler/executor/base.py
@@ -0,0 +1,81 @@
+"""Executor基类"""
+import logging
+from datetime import datetime, timezone
+from typing import Optional
+
+from pydantic import BaseModel, ConfigDict
+from ray import actor
+
+from apps.entities.enum_var import EventType
+from apps.entities.flow import Flow
+from apps.entities.task import TaskBlock
+from apps.entities.scheduler import ExecutorBackground
+from apps.entities.message import FlowStartContent
+from apps.entities.task import FlowStepHistory
+
+logger = logging.getLogger("ray")
+
+
+class BaseExecutor(BaseModel):
+ """Executor基类"""
+ task: TaskBlock
+ queue: actor.ActorHandle
+ flow: Flow
+ executor_background: ExecutorBackground
+ question: str
+ step_history: Optional[FlowStepHistory] = None
+
+
+ model_config = ConfigDict(
+ arbitrary_types_allowed=True,
+ extra="allow",
+ )
+
+
+ def validate_flow_state(self) -> None:
+ """验证flow_state是否存在"""
+ if not self.task.flow_state:
+ err = "[Executor] 当前ExecutorState为空"
+ logger.error(err)
+ raise ValueError(err)
+
+
+ async def push_message(self, event_type: EventType, data: Optional[dict] = None) -> None:
+ """统一的消息推送接口
+
+ Args:
+ event_type: 事件类型
+ data: 消息数据,如果是FLOW_START事件且data为None,则自动构建FlowStartContent
+ """
+ self.validate_flow_state()
+
+ if event_type == EventType.FLOW_START and data is None:
+ content = FlowStartContent(
+ question=self.question,
+ params=self.task.runtime_data.filled_data,
+ )
+ data = content.model_dump(exclude_none=True, by_alias=True)
+ elif event_type == EventType.FLOW_STOP and data is None:
+ data = {}
+ elif event_type == EventType.STEP_INPUT:
+ # 更新step_history的输入数据
+ if self.step_history is not None and data is not None:
+ self.step_history.input_data = data
+ if self.task.flow_state is not None:
+ self.step_history.status = self.task.flow_state.status
+ # 步骤开始,重置时间
+ self.task.record.metadata.time_cost = round(datetime.now(timezone.utc).timestamp(), 2)
+ elif event_type == EventType.STEP_OUTPUT:
+ # 更新step_history的输出数据
+ if self.step_history is not None and data is not None:
+ self.step_history.output_data = data
+ if self.task.flow_state is not None:
+ self.step_history.status = self.task.flow_state.status
+ self.task.flow_context[self.task.flow_state.step_id] = self.step_history
+ self.task.runtime_data.new_context.append(self.step_history.id)
+
+ await self.queue.push_output.remote(
+ self.task,
+ event_type=event_type,
+ data=data
+ ) # type: ignore[attr-defined]
diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py
index 8e3f4c5ffe0a6c416f02d89aaa7b09982d793b64..409bedc121ffe5fda43b7ad5d2e33aa22da58c06 100644
--- a/apps/scheduler/executor/flow.py
+++ b/apps/scheduler/executor/flow.py
@@ -2,49 +2,67 @@
Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
"""
-import asyncio
-import inspect
import logging
-from typing import Any, Optional
import ray
from pydantic import BaseModel, ConfigDict, Field
from ray import actor
-from apps.entities.enum_var import StepStatus
-from apps.entities.flow import Flow, FlowError, Step
+from apps.entities.enum_var import EventType, StepStatus
+from apps.entities.flow import Flow, Step
from apps.entities.request_data import RequestDataApp
from apps.entities.scheduler import CallVars, ExecutorBackground
from apps.entities.task import ExecutorState, TaskBlock
-from apps.llm.patterns.executor import ExecutorSummary
-from apps.manager.node import NodeManager
from apps.manager.task import TaskManager
-from apps.scheduler.call.core import CoreCall
-from apps.scheduler.call.empty import Empty
-from apps.scheduler.call.llm import LLM
-from apps.scheduler.executor.message import (
- push_flow_start,
- push_flow_stop,
- push_step_input,
- push_step_output,
- push_text_output,
-)
-from apps.scheduler.slot.slot import Slot
+from apps.scheduler.call.llm.schema import LLM_ERROR_PROMPT
+from apps.scheduler.executor.base import BaseExecutor
+from apps.scheduler.executor.step import StepExecutor
logger = logging.getLogger("ray")
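+# 在每个工作流前后自动插入的系统步骤:运行前先总结上下文,运行结束后提取记忆并保存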
+FIXED_STEPS_BEFORE_START = {
+ "_summary": Step(
+ name="理解上下文",
+ description="使用大模型,理解对话上下文",
+ node="Summary",
+ type="Summary",
+ params={},
+ ),
+}
+FIXED_STEPS_AFTER_END = {
+ "_facts": Step(
+ name="记忆存储",
+ description="理解对话答案,并存储到记忆中",
+ node="Facts",
+ type="Facts",
+ params={},
+ ),
+}
+SLOT_FILLING_STEP = Step(
+ name="填充参数",
+ description="根据步骤历史,填充参数",
+ node="Slot",
+ type="Slot",
+ params={},
+)
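+# 步骤执行出错时使用的兜底步骤:调用大模型生成错误说明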
+ERROR_STEP = Step(
+ name="错误处理",
+ description="错误处理",
+ node="LLM",
+ type="LLM",
+ params={
+ "user_prompt": LLM_ERROR_PROMPT,
+ },
+)
# 单个流的执行工具
-class Executor(BaseModel):
+class FlowExecutor(BaseExecutor):
"""用于执行工作流的Executor"""
flow_id: str = Field(description="Flow ID")
- flow: Flow = Field(description="工作流数据")
- task: TaskBlock = Field(description="任务信息")
question: str = Field(description="用户输入")
- queue: actor.ActorHandle = Field(description="消息队列")
post_body_app: RequestDataApp = Field(description="请求体中的app信息")
- executor_background: ExecutorBackground = Field(description="Executor的背景信息")
model_config = ConfigDict(
arbitrary_types_allowed=True,
@@ -58,216 +76,89 @@ class Executor(BaseModel):
logger.info("[FlowExecutor] 加载Executor状态")
# 尝试恢复State
if self.task.flow_state:
- self.flow_state = self.task.flow_state
self.task.flow_context = await TaskManager.get_flow_history_by_task_id(self.task.record.task_id)
else:
# 创建ExecutorState
- self.flow_state = ExecutorState(
+ self.task.flow_state = ExecutorState(
flow_id=str(self.flow_id),
description=str(self.flow.description),
status=StepStatus.RUNNING,
app_id=str(self.post_body_app.app_id),
step_id="start",
step_name="开始",
- ai_summary="",
- filled_data=self.post_body_app.params,
)
- # 是否结束运行
- self._can_continue = True
- # 向用户输出的内容
- self._final_answer = ""
-
-
- async def _check_cls(self, call_cls: Any) -> bool:
- """检查Call是否符合标准要求"""
- flag = True
- if not hasattr(call_cls, "name") or not isinstance(call_cls.name, str):
- flag = False
- if not hasattr(call_cls, "description") or not isinstance(call_cls.description, str):
- flag = False
- if not hasattr(call_cls, "exec") or not asyncio.iscoroutinefunction(call_cls.exec):
- flag = False
- if not hasattr(call_cls, "stream") or not inspect.isasyncgenfunction(call_cls.stream):
- flag = False
- return flag
-
-
- # TODO: 默认错误处理步骤
- async def _run_error(self, step: FlowError) -> dict[str, Any]:
- """运行错误处理步骤"""
- return {}
-
-
- async def _get_call_cls(self, node_id: str) -> type[CoreCall]:
- """获取并验证Call类"""
- # 检查flow_state是否为空
- if not self.flow_state:
- err = "[FlowExecutor] flow_state为空"
- raise ValueError(err)
+ # 是否到达Flow结束终点(变量)
+ self._reached_end: bool = False
- # 特判
- if node_id == "Empty":
- return Empty
- # 获取对应Node的call_id
- call_id = await NodeManager.get_node_call_id(node_id)
- # 从Pool中获取对应的Call
- pool = ray.get_actor("pool")
- call_cls: type[CoreCall] = await pool.get_call.remote(call_id)
-
- # 检查Call合法性
- if not await self._check_cls(call_cls):
- err = f"[FlowExecutor] 工具 {node_id} 不符合Call标准要求"
+ async def _invoke_runner(self, step_id: str) -> None:
+ """调用Runner"""
+ if not self.task.flow_state:
+ err = "[FlowExecutor] 当前ExecutorState为空"
+ logger.error(err)
raise ValueError(err)
- return call_cls
-
-
- async def _process_slots(self, call_obj: Any) -> tuple[bool, Optional[dict[str, Any]]]:
- """处理slot参数"""
- if not (hasattr(call_obj, "slot_schema") and call_obj.slot_schema):
- return True, None
+ step = self.flow.steps[step_id]
- slot_processor = Slot(call_obj.slot_schema)
- remaining_schema, slot_data = await slot_processor.process(
- self.flow_state.filled_data,
- self.post_body_app.params,
- {
- "task_id": self.task.record.task_id,
- "question": self.question,
- "thought": self.flow_state.ai_summary,
- "previous_output": await self._get_last_output(self.task),
- },
- )
-
- # 保存Schema至State
- self.flow_state.remaining_schema = remaining_schema
- self.flow_state.filled_data.update(slot_data)
-
- # 如果还有未填充的部分,则返回False
- if remaining_schema:
- self._can_continue = False
- self.flow_state.status = StepStatus.RUNNING
- # 推送空输入输出
- self.task = await push_step_input(self.task, self.queue, self.flow_state, {})
- self.flow_state.status = StepStatus.PARAM
- self.task = await push_step_output(self.task, self.queue, self.flow_state, {})
- return False, None
-
- return True, slot_data
-
-
- # TODO
- async def _get_last_output(self, task: TaskBlock) -> Optional[dict[str, Any]]:
- """获取上一步的输出"""
- return None
-
-
- async def _execute_call(self, call_obj: Any, *, is_final_answer: bool) -> dict[str, Any]:
- """执行Call并处理结果"""
- # call_obj一定合法;开始判断是否为最终结果
- if is_final_answer and isinstance(call_obj, LLM):
- # 最后一步 & 是大模型步骤,直接流式输出
- async for chunk in call_obj.stream():
- await push_text_output(self.task, self.queue, chunk)
- self._final_answer += chunk
- self.flow_state.status = StepStatus.SUCCESS
- return {
- "message": self._final_answer,
- }
-
- # 其他情况:先运行步骤,得到结果
- result: dict[str, Any] = await call_obj.exec()
- if is_final_answer:
- if call_obj.name == "Convert":
- # 如果是Convert,直接输出转换之后的结果
- self._final_answer += result["message"]
- await push_text_output(self.task, self.queue, self._final_answer)
- self.flow_state.status = StepStatus.SUCCESS
- return {
- "message": self._final_answer,
- }
- # 其他工具,加一步大模型过程
- # FIXME: 使用单独的Prompt
- self.flow_state.status = StepStatus.SUCCESS
- return {
- "message": self._final_answer,
- }
-
- # 其他情况:返回结果
- self.flow_state.status = StepStatus.SUCCESS
- return result
-
-
- async def _run_step(self, step_id: str, step_data: Step) -> None:
- """运行单个步骤"""
- logger.info("[FlowExecutor] 运行步骤 %s", step_data.name)
-
- # State写入ID和运行状态
- self.flow_state.step_id = step_id
- self.flow_state.step_name = step_data.name
- self.flow_state.status = StepStatus.RUNNING
-
- # 查找下一个步骤;判断是否是end前最后一步
- next_step = await self._get_next_step()
- is_final_answer = next_step == "end"
-
- # 获取并验证Call类
- node_id = step_data.node
- call_cls = await self._get_call_cls(node_id)
+ # 获取task actor
+ task_actor = ray.get_actor("task")
# 准备系统变量
sys_vars = CallVars(
question=self.question,
task_id=self.task.record.task_id,
flow_id=self.post_body_app.flow_id,
- session_id=self.task.session_id,
+ session_id=self.task.runtime_data.session_id,
history=self.task.flow_context,
- summary=self.flow_state.ai_summary,
- user_sub=self.task.user_sub,
+ summary=self.task.runtime_data.summary,
+ user_sub=self.task.runtime_data.user_sub,
+ service_id=step.params.get("service_id", ""),
+ )
+ step_runner = StepExecutor(
+ queue=self.queue,
+ task=self.task,
+ flow=self.flow,
+ sys_vars=sys_vars,
+ executor_background=self.executor_background,
+ question=self.question,
)
- # 初始化Call
- call_obj = call_cls.model_validate(step_data.params)
- input_data = await call_obj.init(sys_vars)
-
- # TODO: 处理slots
- # can_continue, slot_data = await self._process_slots(call_obj)
- # if not self._can_continue:
- # return
-
- # 执行Call并获取结果
- self.task = await push_step_input(self.task, self.queue, self.flow_state, input_data)
- result_data = await self._execute_call(call_obj, is_final_answer=is_final_answer)
- self.task = await push_step_output(self.task, self.queue, self.flow_state, result_data)
+ # 运行Step
+ call_id, call_obj = await step_runner.init_step(step_id)
+ # 尝试填参
+ input_data = await step_runner.fill_slots(call_obj)
+ # 运行
+ await step_runner.run_step(call_id, call_obj, input_data)
- # 更新下一步
- self.flow_state.step_id = next_step
+ # 更新Task
+ self.task = step_runner.task
+ await task_actor.set_task.remote(self.task.record.task_id, self.task)
- async def _get_next_step(self) -> str:
+ async def _find_flow_next(self) -> str:
"""在当前步骤执行前,尝试获取下一步"""
+ if not self.task.flow_state:
+ err = "[StepExecutor] 当前ExecutorState为空"
+ logger.error(err)
+ raise ValueError(err)
+
# 如果当前步骤为结束,则直接返回
- if self.flow_state.step_id == "end" or not self.flow_state.step_id:
+ if self.task.flow_state.step_id == "end" or not self.task.flow_state.step_id:
# 如果是最后一步,设置停止标志
- self._can_continue = False
- return ""
-
- if self.flow.steps[self.flow_state.step_id].node == "Choice":
- # 如果是选择步骤,那么现在还不知道下一步是谁,直接返回
+ self._reached_end = True
return ""
next_steps = []
# 遍历Edges,查找下一个节点
for edge in self.flow.edges:
- if edge.edge_from == self.flow_state.step_id:
+ if edge.edge_from == self.task.flow_state.step_id:
next_steps += [edge.edge_to]
# 如果step没有任何出边,直接跳到end
if not next_steps:
return "end"
- # FIXME: 目前只使用第一个出边
+ # TODO: 目前只使用第一个出边
logger.info("[FlowExecutor] 下一步 %s", next_steps[0])
return next_steps[0]
@@ -277,33 +168,59 @@ class Executor(BaseModel):
数据通过向Queue发送消息的方式传输
"""
+ if not self.task.flow_state:
+ err = "[FlowExecutor] 当前ExecutorState为空"
+ logger.error(err)
+ raise ValueError(err)
+
task_actor = ray.get_actor("task")
logger.info("[FlowExecutor] 运行工作流")
# 推送Flow开始消息
- self.task = await push_flow_start(self.task, self.queue, self.flow_state, self.question)
+ await self.push_message(EventType.FLOW_START)
+
+ # 进行开始前的系统步骤
+ for step_id, step in FIXED_STEPS_BEFORE_START.items():
+ self.flow.steps[step_id] = step
+ await self._invoke_runner(step_id)
+ self.task.flow_state.step_id = "start"
+ self.task.flow_state.step_name = "开始"
# 如果允许继续运行Flow
- while self._can_continue:
+ while not self._reached_end:
# Flow定义中找不到step
- if not self.flow_state.step_id or (self.flow_state.step_id not in self.flow.steps):
- logger.error("[FlowExecutor] 当前步骤 %s 不存在", self.flow_state.step_id)
- self.flow_state.status = StepStatus.ERROR
+ if not self.task.flow_state.step_id or (self.task.flow_state.step_id not in self.flow.steps):
+ logger.error("[FlowExecutor] 当前步骤 %s 不存在", self.task.flow_state.step_id)
+ self.task.flow_state.status = StepStatus.ERROR
- if self.flow_state.status == StepStatus.ERROR:
+ if self.task.flow_state.status == StepStatus.ERROR:
# 执行错误处理步骤
logger.warning("[FlowExecutor] Executor出错,执行错误处理步骤")
- step = self.flow.on_error
- await self._run_error(step)
+ self.flow.steps["_error"] = ERROR_STEP
+ await self._invoke_runner("_error")
else:
# 执行正常步骤
- step = self.flow.steps[self.flow_state.step_id]
- await self._run_step(self.flow_state.step_id, step)
+                await self._invoke_runner(self.task.flow_state.step_id)
# 步骤结束,更新全局的Task
await task_actor.set_task.remote(self.task.record.task_id, self.task)
+ # 查找下一个节点
+ next_step_id = await self._find_flow_next()
+ if next_step_id:
+ self.task.flow_state.step_id = next_step_id
+ self.task.flow_state.step_name = self.flow.steps[next_step_id].name
+ else:
+ # 如果没有下一个节点,设置结束标志
+ self._reached_end = True
+
+ # 运行结束后的系统步骤
+ for step_id, step in FIXED_STEPS_AFTER_END.items():
+ self.flow.steps[step_id] = step
+ await self._invoke_runner(step_id)
+
# 推送Flow停止消息
- self.task = await push_flow_stop(self.task, self.queue, self.flow_state, self.flow, self._final_answer)
+ await self.push_message(EventType.FLOW_STOP)
# 更新全局的Task
await task_actor.set_task.remote(self.task.record.task_id, self.task)
diff --git a/apps/scheduler/executor/hook.py b/apps/scheduler/executor/hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..a12fbb85bf740c9c4c3843197f3065fc5a4cb4ce
--- /dev/null
+++ b/apps/scheduler/executor/hook.py
@@ -0,0 +1,75 @@
+"""特殊Call的Runner;obj为当前Executor对象"""
+from typing import Any, ClassVar, Protocol, TypeVar
+
+from pydantic import Field
+from ray import actor
+
+from apps.entities.enum_var import EventType
+from apps.entities.scheduler import CallOutputChunk, ExecutorBackground
+from apps.entities.task import TaskBlock
+from apps.scheduler.call.facts.facts import FactsCall
+from apps.scheduler.call.summary.summary import Summary
+
+
+class StepHookProtocol(Protocol):
+ """Step Hook的Protocol"""
+
+ task: TaskBlock
+ executor_background: ExecutorBackground = Field(description="Executor的背景信息")
+
+ async def _push_call_output(self, task: TaskBlock, queue: actor.ActorHandle, msg_type: EventType, data: dict[str, Any]) -> None:
+ """推送输出"""
+ ...
+
+
+T = TypeVar("T", bound=StepHookProtocol)
+
+
+class StepHookMixins:
+ """Executor的Step Hook Mixins"""
+
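+    # Call ID -> 钩子方法名的映射:before_init 在实例化Call前构造特殊对象,after_run 在Call执行后处理其输出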
+ _HOOK: ClassVar[dict[str, Any]] = {
+ "Facts": {
+ "before_init": "_facts_before_init",
+ "after_run": "_facts_after_run",
+ },
+ "Summary": {
+ "before_init": "_summary_before_init",
+ "after_run": "_summary_after_run",
+ },
+ "Output": {
+ "after_run": "_output_after_run",
+ },
+ }
+
+ async def _facts_before_init(self: T) -> FactsCall: # type: ignore[override]
+ """Facts Call的特殊run_step"""
+ return FactsCall(
+ answer=self.task.record.content.answer,
+ )
+
+ async def _facts_after_run(self: T, output: CallOutputChunk) -> None: # type: ignore[override]
+ """Facts Call的输出hook"""
+ self.task.record.content.facts = output.content["facts"]
+
+
+ async def _summary_before_init(self: T) -> Summary: # type: ignore[override]
+ """Summary Call的输入hook"""
+ return Summary(
+ context=self.executor_background,
+ )
+
+
+ async def _summary_after_run(self: T, output: CallOutputChunk) -> None: # type: ignore[override]
+ """Summary Call的输出hook"""
+ self.task.runtime_data.summary = output.content
+
+
+ async def _output_after_run(self: T, output: CallOutputChunk) -> None: # type: ignore[override]
+ """Output Call的输出hook"""
+ ...
+
+
+ async def _suggest_after_run(self: T, output: CallOutputChunk) -> None: # type: ignore[override]
+ """Suggest Call的输出hook"""
+ ...
diff --git a/apps/scheduler/executor/message.py b/apps/scheduler/executor/message.py
deleted file mode 100644
index f387dc5ffd352d69ebaa997941c55534e8de1cd4..0000000000000000000000000000000000000000
--- a/apps/scheduler/executor/message.py
+++ /dev/null
@@ -1,131 +0,0 @@
-"""FlowExecutor的消息推送
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
-"""
-from datetime import datetime, timezone
-from typing import Any
-
-from ray import actor
-
-from apps.entities.enum_var import EventType, FlowOutputType
-from apps.entities.flow import Flow
-from apps.entities.message import (
- FlowStartContent,
- FlowStopContent,
- TextAddContent,
-)
-from apps.entities.task import (
- ExecutorState,
- FlowStepHistory,
- TaskBlock,
-)
-
-
-# FIXME: 临时使用截断规避前端问题
-def truncate_data(data: Any) -> Any:
- """截断数据"""
- if isinstance(data, dict):
- return {k: truncate_data(v) for k, v in data.items()}
- if isinstance(data, list):
- return [truncate_data(v) for v in data]
- if isinstance(data, str):
- if len(data) > 300:
- return data[:300] + "..."
- return data
- return data
-
-
-async def push_step_input(task: TaskBlock, queue: actor.ActorHandle, state: ExecutorState, input_data: dict[str, Any]) -> TaskBlock:
- """推送步骤输入"""
- # 更新State
- task.flow_state = state
- # 更新FlowContext
- task.flow_context[state.step_id] = FlowStepHistory(
- task_id=task.record.task_id,
- flow_id=state.flow_id,
- step_id=state.step_id,
- status=state.status,
- input_data=state.filled_data,
- output_data={},
- )
- # 步骤开始,重置时间
- task.record.metadata.time_cost = round(datetime.now(timezone.utc).timestamp(), 2)
-
- # 推送消息
- await queue.push_output.remote(task, event_type=EventType.STEP_INPUT, data=truncate_data(input_data)) # type: ignore[attr-defined]
- return task
-
-
-async def push_step_output(task: TaskBlock, queue: actor.ActorHandle, state: ExecutorState, output: dict[str, Any]) -> TaskBlock:
- """推送步骤输出"""
- # 更新State
- task.flow_state = state
-
- # 更新FlowContext
- task.flow_context[state.step_id].output_data = output
- task.flow_context[state.step_id].status = state.status
-
- # FlowContext加入Record
- task.new_context.append(task.flow_context[state.step_id].id)
-
- # 推送消息
- await queue.push_output.remote(task, event_type=EventType.STEP_OUTPUT, data=truncate_data(output)) # type: ignore[attr-defined]
- return task
-
-
-async def push_flow_start(task: TaskBlock, queue: actor.ActorHandle, state: ExecutorState, question: str) -> TaskBlock:
- """推送Flow开始"""
- # 设置state
- task.flow_state = state
-
- # 组装消息
- content = FlowStartContent(
- question=question,
- params=state.filled_data,
- )
- # 推送消息
- await queue.push_output.remote(task, event_type=EventType.FLOW_START, data=content.model_dump(exclude_none=True, by_alias=True)) # type: ignore[attr-defined]
- return task
-
-
-async def assemble_flow_stop_content(state: ExecutorState, flow: Flow) -> FlowStopContent:
- """组装Flow结束消息"""
- if state.remaining_schema:
- # 如果当前Flow是填充步骤,则推送Schema
- content = FlowStopContent(
- type=FlowOutputType.SCHEMA,
- data=state.remaining_schema,
- )
- else:
- content = FlowStopContent()
-
- return content
-
- # elif call_type == "render":
- # # 如果当前Flow是图表,则推送Chart
- # chart_option = task.flow_context[state.step_id].output_data["output"]
- # content = FlowStopContent(
- # type=FlowOutputType.CHART,
- # data=chart_option,
- # )
-
-async def push_flow_stop(task: TaskBlock, queue: actor.ActorHandle, state: ExecutorState, flow: Flow, final_answer: str) -> TaskBlock:
- """推送Flow结束"""
- # 设置state
- task.flow_state = state
- # 保存最终输出
- task.record.content.answer = final_answer
- content = await assemble_flow_stop_content(state, flow)
-
- # 推送Stop消息
- await queue.push_output.remote(task, event_type=EventType.FLOW_STOP, data=content.model_dump(exclude_none=True, by_alias=True)) # type: ignore[attr-defined]
- return task
-
-
-async def push_text_output(task: TaskBlock, queue: actor.ActorHandle, text: str) -> None:
- """推送文本输出"""
- content = TextAddContent(
- text=text,
- )
- # 推送消息
- await queue.push_output.remote(task, event_type=EventType.TEXT_ADD, data=content.model_dump(exclude_none=True, by_alias=True)) # type: ignore[attr-defined]
diff --git a/apps/scheduler/executor/node.py b/apps/scheduler/executor/node.py
new file mode 100644
index 0000000000000000000000000000000000000000..64eed95b1727a05e9bcb2aa6b48c0b13981624b3
--- /dev/null
+++ b/apps/scheduler/executor/node.py
@@ -0,0 +1,52 @@
+"""Node相关的函数,含Node转换为Call"""
+import inspect
+from typing import Any
+
+import ray
+
+from apps.scheduler.call.core import CoreCall
+from apps.scheduler.call.empty import Empty
+from apps.scheduler.call.facts.facts import FactsCall
+from apps.scheduler.call.slot.slot import Slot
+from apps.scheduler.call.summary.summary import Summary
+
+
+class StepNode:
+ """Executor的Node相关逻辑"""
+
+ @staticmethod
+ async def check_cls(call_cls: Any) -> bool:
+ """检查Call是否符合标准要求"""
+ flag = True
+ if not hasattr(call_cls, "name") or not isinstance(call_cls.name, str):
+ flag = False
+ if not hasattr(call_cls, "description") or not isinstance(call_cls.description, str):
+ flag = False
+ if not hasattr(call_cls, "exec") or not inspect.isasyncgenfunction(call_cls.exec):
+ flag = False
+ return flag
+
+
+ @staticmethod
+ async def get_call_cls(call_id: str) -> type[CoreCall]:
+ """获取并验证Call类"""
+ # 特判,用于处理隐藏节点
+ if call_id == "Empty":
+ return Empty
+ if call_id == "Summary":
+ return Summary
+ if call_id == "Facts":
+ return FactsCall
+ if call_id == "Slot":
+ return Slot
+
+ # 从Pool中获取对应的Call
+ pool = ray.get_actor("pool")
+ call_cls: type[CoreCall] = await pool.get_call.remote(call_id)
+
+ # 检查Call合法性
+ if not await StepNode.check_cls(call_cls):
+ err = f"[FlowExecutor] 工具 {call_id} 不符合Call标准要求"
+ raise ValueError(err)
+
+ return call_cls
diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py
new file mode 100644
index 0000000000000000000000000000000000000000..22ca1e780b69dbfb8393f7f91620ddec9bb6ec96
--- /dev/null
+++ b/apps/scheduler/executor/step.py
@@ -0,0 +1,193 @@
+"""工作流中步骤相关函数"""
+
+import logging
+from datetime import datetime, timezone
+from typing import Any, Optional
+
+from pydantic import BaseModel, ConfigDict
+from ray import actor
+
+from apps.entities.enum_var import EventType, StepStatus
+from apps.entities.flow import Flow, Step
+from apps.entities.pool import NodePool
+from apps.entities.scheduler import CallOutputChunk, CallVars, ExecutorBackground
+from apps.entities.task import FlowStepHistory, TaskBlock
+from apps.manager.node import NodeManager
+from apps.scheduler.call.slot.schema import SlotInput, SlotOutput
+from apps.scheduler.call.slot.slot import Slot
+from apps.scheduler.executor.base import BaseExecutor
+from apps.scheduler.executor.hook import StepHookMixins
+from apps.scheduler.executor.node import StepNode
+
+logger = logging.getLogger("ray")
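+# 允许Call在执行过程中通过 _push_call_output 直接推送的特殊事件类型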
+SPECIAL_EVENT_TYPES = [
+ EventType.GRAPH,
+ EventType.SUGGEST,
+ EventType.TEXT_ADD,
+]
+
+
+class StepExecutor(BaseExecutor, StepHookMixins):
+ """工作流中步骤相关函数"""
+
+ sys_vars: CallVars
+
+ model_config = ConfigDict(
+ arbitrary_types_allowed=True,
+ extra="allow",
+ )
+
+ def __init__(self, **kwargs: Any) -> None:
+ """初始化"""
+ super().__init__(**kwargs)
+ if not self.task.flow_state:
+ err = "[StepExecutor] 当前ExecutorState为空"
+ logger.error(err)
+ raise ValueError(err)
+
+ self.step_history = FlowStepHistory(
+ task_id=self.sys_vars.task_id,
+ flow_id=self.sys_vars.flow_id,
+ step_id=self.task.flow_state.step_id,
+ status=self.task.flow_state.status,
+ input_data={},
+ output_data={},
+ )
+
+ async def _push_call_output(self, task: TaskBlock, queue: actor.ActorHandle, msg_type: EventType, data: dict[str, Any]) -> None:
+ """推送输出"""
+ # 如果不是特殊类型
+ if msg_type not in SPECIAL_EVENT_TYPES:
+ err = f"[StepExecutor] 不支持的事件类型: {msg_type}"
+ logger.error(err)
+ raise ValueError(err)
+
+ # 推送消息
+ await queue.push_output.remote(
+ task,
+ event_type=msg_type,
+ data=data,
+ ) # type: ignore[attr-defined]
+
+ async def init_step(self, step_id: str) -> tuple[str, Any]:
+ """初始化步骤"""
+ if not self.task.flow_state:
+ err = "[StepExecutor] 当前ExecutorState为空"
+ logger.error(err)
+ raise ValueError(err)
+
+ logger.info("[FlowExecutor] 运行步骤 %s", self.flow.steps[step_id].name)
+
+ # State写入ID和运行状态
+ self.task.flow_state.step_id = step_id
+ self.task.flow_state.step_name = self.flow.steps[step_id].name
+ self.task.flow_state.status = StepStatus.RUNNING
+
+ # 获取并验证Call类
+ node_id = self.flow.steps[step_id].node
+ # 获取node详情并存储
+ try:
+ self._node_data = await NodeManager.get_node(node_id)
+ except Exception:
+ logger.info("[StepExecutor] 获取Node失败,为内置Node或ID不存在")
+ self._node_data = None
+
+ if self._node_data:
+ call_cls = await StepNode.get_call_cls(self._node_data.call_id)
+ call_id = self._node_data.call_id
+ else:
+ # 可能是特殊的内置Node
+ call_cls = await StepNode.get_call_cls(node_id)
+ call_id = node_id
+
+ # 初始化Call Class,用户参数会覆盖node的参数
+ params: dict[str, Any] = self._node_data.known_params if self._node_data and self._node_data.known_params else {} # type: ignore[union-attr]
+ if self.flow.steps[step_id].params:
+ params.update(self.flow.steps[step_id].params)
+
+ # 检查初始化Call Class是否需要特殊处理
+ try:
+ if call_id in self._HOOK and "before_init" in self._HOOK[call_id]:
+ func = getattr(self, self._HOOK[call_id]["before_init"])
+ call_obj = await func()
+ else:
+ call_obj = call_cls.model_validate(params)
+ except Exception:
+ logger.exception("[StepExecutor] 初始化Call失败")
+ raise
+ else:
+ return call_id, call_obj
+
+
+ async def fill_slots(self, call_obj: Any) -> dict[str, Any]:
+ """执行Call并处理结果"""
+ if not self.task.flow_state:
+ err = "[StepExecutor] 当前ExecutorState为空"
+ logger.error(err)
+ raise ValueError(err)
+
+ # 尝试初始化call_obj
+ try:
+ input_data = await call_obj.init(self.sys_vars)
+ except Exception:
+ logger.exception("[StepExecutor] 初始化Call失败")
+ raise
+
+ # 合并已经填充的参数
+ input_data.update(self.task.runtime_data.filled_data)
+
+ # 检查输入参数
+ input_schema = call_obj.input_type.model_json_schema(override=self._node_data.override_input) if self._node_data else {}
+
+ # 检查输入参数是否需要填充
+ slot_obj = Slot(
+ data=input_data,
+ current_schema=input_schema,
+ summary=self.task.runtime_data.summary,
+ facts=self.task.record.content.facts,
+ )
+ slot_input = SlotInput(**await slot_obj.init(self.sys_vars))
+
+ # 若需要填参,额外执行一步
+ if slot_input.remaining_schema:
+            output = SlotOutput(**await self.run_step("Slot", slot_obj, slot_input.model_dump(by_alias=True, exclude_none=True)))
+ self.task.runtime_data.filled_data = output.slot_data
+
+ # 如果还有剩余参数,则中断
+ if output.remaining_schema:
+ # 更新状态
+ self.task.flow_state.status = StepStatus.PARAM
+ self.task.flow_state.remaining_schema = output.remaining_schema
+ err = "[StepExecutor] 需要填参"
+ logger.error(err)
+ raise ValueError(err)
+
+ return output.slot_data
+ return input_data
+
+
+ # TODO: 需要评估如何进行流式输出
+ async def run_step(self, call_id: str, call_obj: Any, input_data: dict[str, Any]) -> dict[str, Any]:
+ """运行单个步骤"""
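+        # 先推送步骤输入,再迭代Call的输出;当前仅保留最后一个Chunk作为步骤输出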
+ iterator = call_obj.exec(input_data)
+ await self.push_message(EventType.STEP_INPUT, input_data)
+
+ result: Optional[CallOutputChunk] = None
+ async for chunk in iterator:
+ result = chunk
+
+ if result is None:
+ content = {}
+ else:
+ content = result.content if isinstance(result.content, dict) else {"message": result.content}
+
+ # 检查Call Class的返回值是否需要特殊处理
+ if call_id in self._HOOK and "after_run" in self._HOOK[call_id]:
+ func = getattr(self, self._HOOK[call_id]["after_run"])
+ await func(result)
+
+ # 更新执行状态
+ self.task.flow_state.status = StepStatus.SUCCESS
+ await self.push_message(EventType.STEP_OUTPUT, content)
+
+ return content
diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py
index 06fb8b37351df87cf087a99f2f433b617ce52260..241e6ecf06a73dea325fd1e67faeeac4cf184af1 100644
--- a/apps/scheduler/scheduler/context.py
+++ b/apps/scheduler/scheduler/context.py
@@ -24,6 +24,11 @@ logger = logging.getLogger("ray")
async def get_docs(user_sub: str, post_body: RequestData) -> tuple[Union[list[RecordDocument], list[Document]], list[str]]:
"""获取当前问答可供关联的文档"""
+ if not post_body.group_id:
+ err = "[Scheduler] 问答组ID不能为空!"
+ logger.error(err)
+ raise ValueError(err)
+
doc_ids = []
docs = await DocumentManager.get_used_docs_by_record_group(user_sub, post_body.group_id)
@@ -41,21 +46,7 @@ async def get_docs(user_sub: str, post_body: RequestData) -> tuple[Union[list[Re
return docs, doc_ids
-async def get_rag_context(user_sub: str, post_body: RequestData, n: int) -> list[dict[str, str]]:
- """获取当前问答的上下文信息"""
- # 组装RAG请求数据,备用
- records = await RecordManager.query_record_by_conversation_id(user_sub, post_body.conversation_id, n + 5)
-
- context = []
- for record in records:
- record_data = RecordContent.model_validate_json(Security.decrypt(record.data, record.key))
- context.append({"role": "user", "content": record_data.question})
- context.append({"role": "assistant", "content": record_data.answer})
-
- return context
-
-
-async def get_context(user_sub: str, post_body: RequestData, n: int) -> tuple[str, list[str]]:
+async def get_context(user_sub: str, post_body: RequestData, n: int) -> tuple[list[dict[str, str]], list[str]]:
"""获取当前问答的上下文信息
注意:这里的n要比用户选择的多,因为要考虑事实信息和历史问题
@@ -65,27 +56,17 @@ async def get_context(user_sub: str, post_body: RequestData, n: int) -> tuple[st
# 获取最后n+5条Record
records = await RecordManager.query_record_by_conversation_id(user_sub, post_body.conversation_id, n + 5)
- # 获取事实信息
- facts = []
- for record in records:
- facts.extend(record.facts)
# 组装问答
- messages = ""
+ context = []
+ facts = []
for record in records:
record_data = RecordContent.model_validate_json(Security.decrypt(record.data, record.key))
+ context.append({"role": "user", "content": record_data.question})
+ context.append({"role": "assistant", "content": record_data.answer})
+ facts.extend(record_data.facts)
- messages += f"""
-
- {record_data.question}
-
-
- {record_data.answer}
-
- """
- messages += ""
-
- return messages, facts
+ return context, facts
async def generate_facts(task: TaskBlock, question: str) -> list[str]:
@@ -115,7 +96,7 @@ async def save_data(task_id: str, user_sub: str, post_body: RequestData, used_do
# 循环创建FlowHistory
history_data = []
# 遍历查找数据,并添加
- for history_id in task.new_context:
+ for history_id in task.runtime_data.new_context:
for history in task.flow_context.values():
if history.id == history_id:
history_data.append(history)
@@ -125,20 +106,15 @@ async def save_data(task_id: str, user_sub: str, post_body: RequestData, used_do
# 修改metadata里面时间为实际运行时间
task.record.metadata.time_cost = round(datetime.now(timezone.utc).timestamp() - task.record.metadata.time_cost, 2)
- # 提取facts
- # 记忆提取
- facts = await generate_facts(task, post_body.question)
-
# 整理Record数据
record = Record(
record_id=task.record.id,
user_sub=user_sub,
data=encrypt_data,
key=encrypt_config,
- facts=facts,
metadata=task.record.metadata,
created_at=task.record.created_at,
- flow=task.new_context,
+ flow=task.runtime_data.new_context,
)
record_group = task.record.group_id
diff --git a/apps/scheduler/scheduler/flow.py b/apps/scheduler/scheduler/flow.py
index 0abf16875c0795f4f1ed1d6f6a4077267eea155a..f87c133a70e0d9f2611972a2a249619c38eadb6f 100644
--- a/apps/scheduler/scheduler/flow.py
+++ b/apps/scheduler/scheduler/flow.py
@@ -7,10 +7,10 @@ from typing import Optional
import ray
-from apps.constants import RAG_ANSWER_PROMPT
from apps.entities.flow import Flow, FlowError, Step
from apps.entities.task import RequestDataApp
from apps.llm.patterns import Select
+from apps.scheduler.call.llm.schema import RAG_ANSWER_PROMPT
logger = logging.getLogger("ray")
diff --git a/apps/scheduler/scheduler/graph.py b/apps/scheduler/scheduler/graph.py
new file mode 100644
index 0000000000000000000000000000000000000000..9dd562621a16b26f29fb2b4a28b7d0001d42d5a6
--- /dev/null
+++ b/apps/scheduler/scheduler/graph.py
@@ -0,0 +1,2 @@
+"""工作流DAG相关操作"""
+
diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py
index b246bf6046d58b8e3406a7917ccf636355b713ae..2e747118c34330fa81de96283cf9f359b949042e 100644
--- a/apps/scheduler/scheduler/message.py
+++ b/apps/scheduler/scheduler/message.py
@@ -67,7 +67,7 @@ async def push_rag_message(task: TaskBlock, queue: actor.ActorHandle, user_sub:
async for chunk in RAG.get_rag_result(user_sub, rag_data):
task, chunk_content = await _push_rag_chunk(task, queue, chunk)
full_answer += chunk_content
- # FIXME: 这里由于后面没有其他消息,所以只能每个trunk更新task
+ # TODO: 这里由于后面没有其他消息,所以只能每个trunk更新task
await task_actor.set_task.remote(task.record.task_id, task)
# 保存答案
@@ -97,10 +97,11 @@ async def _push_rag_chunk(task: TaskBlock, queue: actor.ActorHandle, content: st
event_type=EventType.TEXT_ADD,
data=TextAddContent(text=content_obj.content).model_dump(exclude_none=True, by_alias=True),
)
- return task, content_obj.content
except Exception:
logger.exception("[Scheduler] RAG服务返回错误数据")
return task, ""
+ else:
+ return task, content_obj.content
async def push_document_message(task: TaskBlock, queue: actor.ActorHandle, doc: Union[RecordDocument, Document]) -> TaskBlock:
diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py
index a088618ef1e5c82da6a40f90e179073bea9171d2..70c4ca0ad3f3a7799ce36c3c910afc2a28fd6517 100644
--- a/apps/scheduler/scheduler/scheduler.py
+++ b/apps/scheduler/scheduler/scheduler.py
@@ -16,8 +16,8 @@ from apps.entities.task import SchedulerResult, TaskBlock
from apps.manager.appcenter import AppCenterManager
from apps.manager.flow import FlowManager
from apps.manager.user import UserManager
-from apps.scheduler.executor.flow import Executor
-from apps.scheduler.scheduler.context import get_context, get_docs, get_rag_context
+from apps.scheduler.executor.flow import FlowExecutor
+from apps.scheduler.scheduler.context import get_context, get_docs
from apps.scheduler.scheduler.flow import FlowChooser
from apps.scheduler.scheduler.message import (
push_document_message,
@@ -60,15 +60,6 @@ class Scheduler:
await queue.close.remote() # type: ignore[attr-defined]
return SchedulerResult(used_docs=[])
- # 组装RAG请求数据,备用
- rag_data = RAGQueryReq(
- question=post_body.question,
- language=post_body.language,
- document_ids=doc_ids,
- kb_sn=None if not user_info.kb_id else user_info.kb_id,
- top_k=5,
- history=await get_rag_context(user_sub, post_body, 3),
- )
# 已使用文档
used_docs = []
@@ -82,7 +73,16 @@ class Scheduler:
task = await push_document_message(task, queue, doc)
await asyncio.sleep(0.1)
- # 保存有数据的最后一条消息
+ # 调用RAG
+ history, _ = await get_context(user_sub, post_body, 3)
+ rag_data = RAGQueryReq(
+ question=post_body.question,
+ language=post_body.language,
+ document_ids=doc_ids,
+ kb_sn=None if not user_info.kb_id else user_info.kb_id,
+ top_k=5,
+ history=history,
+ )
task = await push_rag_message(task, queue, user_sub, rag_data)
else:
# 查找对应的App元数据
@@ -149,7 +149,7 @@ class Scheduler:
# 初始化Executor
logger.info("[Scheduler] 初始化Executor")
- flow_exec = Executor(
+ flow_exec = FlowExecutor(
flow_id=flow_id,
flow=flow_data,
task=task,
diff --git a/apps/scheduler/slot/parser/const.py b/apps/scheduler/slot/parser/const.py
index 9c0329eb47f40a3fb0d0f2fc17f875c30419a0f6..a008fc1d9de6c59e868b9c30bcbefd6f7db5dab6 100644
--- a/apps/scheduler/slot/parser/const.py
+++ b/apps/scheduler/slot/parser/const.py
@@ -4,6 +4,8 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
"""
from typing import Any
+from jsonschema.protocols import Validator
+
from apps.entities.enum_var import SlotType
@@ -21,3 +23,8 @@ class SlotConstParser:
"""
raise NotImplementedError
+ @classmethod
+ def keyword_validate(cls, validator: Validator, keyword: str, instance: Any, schema: dict[str, Any]) -> bool:
+ """生成对应类型的验证器"""
+ ...
+
diff --git a/apps/scheduler/slot/parser/default.py b/apps/scheduler/slot/parser/default.py
index 998ebba92852f686e06eacc0d3324217484fb11f..b8edb293ef91a68fcac65a9abfd308c98f48966c 100644
--- a/apps/scheduler/slot/parser/default.py
+++ b/apps/scheduler/slot/parser/default.py
@@ -4,6 +4,8 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
"""
from typing import Any
+from jsonschema.protocols import Validator
+
from apps.entities.enum_var import SlotType
@@ -21,3 +23,8 @@ class SlotDefaultParser:
"""
raise NotImplementedError
+ @classmethod
+ def keyword_validate(cls, validator: Validator, keyword: str, instance: Any, schema: dict[str, Any]) -> bool:
+ """给字段设置默认值"""
+ ...
+
diff --git a/apps/scheduler/slot/slot.py b/apps/scheduler/slot/slot.py
index fd620d6966f147044aee80aedc82554c12871a8d..0d2bd91ca5399ad71e4fafe7250d1f967e0533ed 100644
--- a/apps/scheduler/slot/slot.py
+++ b/apps/scheduler/slot/slot.py
@@ -6,15 +6,13 @@ import json
import logging
import traceback
from collections.abc import Mapping
-from copy import deepcopy
-from typing import Any, Optional, Union
+from typing import Any, Union
from jsonschema import Draft7Validator
from jsonschema.exceptions import ValidationError
from jsonschema.protocols import Validator
from jsonschema.validators import extend
-from apps.llm.patterns.json_gen import Json
from apps.scheduler.slot.parser import (
SlotConstParser,
SlotDateParser,
@@ -32,8 +30,8 @@ _TYPE_CHECKER = [
]
_FORMAT_CHECKER = []
_KEYWORD_CHECKER = {
- "const": SlotConstParser,
- "default": SlotDefaultParser,
+ "const": SlotConstParser.keyword_validate,
+ "default": SlotDefaultParser.keyword_validate,
}
# 各类转换器
@@ -258,52 +256,3 @@ class Slot:
return schema_template
return {}
-
-
- @staticmethod
- async def llm_param_gen(task_id: str, question: str, thought: str, previous_output: Optional[dict[str, Any]], remaining_schema: dict[str, Any]) -> dict[str, Any]:
- """使用LLM生成JSON参数"""
- # 组装工具消息
- conversation = [
- {"role": "user", "content": question},
- {"role": "assistant", "content": thought},
- ]
-
- if previous_output is not None:
- tool_str = f"""我使用了一个工具从其他来源获取额外信息。\
- 工具的输出数据是 `{previous_output}`。
- 输出的schema是 `{json.dumps(previous_output["output_schema"], ensure_ascii=False)}`,其中包含了输出的描述信息。
- """
-
- conversation.append({"role": "tool", "content": tool_str})
-
- return await Json().generate(task_id, conversation=conversation, spec=remaining_schema, strict=False)
-
-
- async def process(self, previous_json: dict[str, Any], new_json: dict[str, Any], llm_params: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
- """对参数槽进行综合处理,返回剩余的JSON Schema和填充后的JSON"""
- # 将用户手动填充的参数专为真实JSON
- slot_data = self.convert_json(new_json)
- # 合并
- result_json = deepcopy(previous_json)
- result_json.update(slot_data)
- # 检测槽位是否合法、是否填充完成
- remaining_slot = self.check_json(result_json)
- # 如果还有未填充的部分,则尝试使用LLM生成
- if remaining_slot:
- generated_slot = await Slot.llm_param_gen(
- llm_params["task_id"],
- llm_params["question"],
- llm_params["thought"],
- llm_params["previous_output"],
- remaining_slot,
- )
- # 合并
- generated_slot = self.convert_json(generated_slot)
- result_json.update(generated_slot)
- # 再次检查槽位
- remaining_slot = self.check_json(result_json)
- return remaining_slot, result_json
-
- return {}, result_json
-
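
Side note (a sketch, not part of the patch): `_KEYWORD_CHECKER` now maps JSON Schema keywords to the `keyword_validate` classmethods added in the parser files above, and slot.py imports `extend` from `jsonschema.validators`. The sketch below shows how such a table is typically wired into a custom validator. Note that jsonschema invokes keyword callables as `fn(validator, value, instance, schema)` and expects them to yield `ValidationError`s rather than return a bool, so the thin adapter shown here is an assumption, not something present in the patch.

```python
from jsonschema import Draft7Validator
from jsonschema.exceptions import ValidationError
from jsonschema.validators import extend

from apps.scheduler.slot.parser import SlotConstParser, SlotDefaultParser


def _adapt(keyword_validate):
    """Wrap a bool-returning keyword_validate classmethod into a jsonschema-style checker."""
    def _check(validator, value, instance, schema):
        # jsonschema expects keyword callables to yield ValidationError instances
        if not keyword_validate(validator, value, instance, schema):
            yield ValidationError(f"{instance!r} failed keyword check against {value!r}")
    return _check


CustomValidator = extend(
    Draft7Validator,
    validators={
        "const": _adapt(SlotConstParser.keyword_validate),
        "default": _adapt(SlotDefaultParser.keyword_validate),
    },
)

# Usage sketch: CustomValidator(schema).validate(instance)
```
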
diff --git a/apps/utils/flow.py b/apps/utils/flow.py
index bac16f3f969f3c1d1ebccdb30c3c084233960393..107879384d46261aebadf3be5da8a9e32a77d1a2 100644
--- a/apps/utils/flow.py
+++ b/apps/utils/flow.py
@@ -2,6 +2,7 @@
Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
"""
+import collections
import logging
from apps.entities.enum_var import NodeType
@@ -13,43 +14,59 @@ logger = logging.getLogger("ray")
class FlowService:
"""flow拓扑相关函数"""
+ @staticmethod
+ def _validate_branch_id(node_name: str, branch_id: str, node_branches: set, branch_illegal_chars: str = ".") -> None:
+ """验证分支ID的合法性;当分支ID重复或包含非法字符时抛出异常"""
+ if branch_id in node_branches:
+ err = f"[FlowService] 节点{node_name}的分支{branch_id}重复"
+ logger.error(err)
+ raise Exception(err)
+
+ for illegal_char in branch_illegal_chars:
+ if illegal_char in branch_id:
+ err = f"[FlowService] 节点{node_name}的分支{branch_id}名称中含有非法字符"
+ logger.error(err)
+ raise Exception(err)
+
+ @staticmethod
+ def _process_choice_node(node: NodeItem, node_to_branches: dict) -> None:
+ """处理选择节点的分支信息"""
+ node.parameters = node.parameters["input_parameters"]
+ if "choices" not in node.parameters:
+ node.parameters["choices"] = []
+
+ for branch in node.parameters["choices"]:
+ FlowService._validate_branch_id(node.name, branch["branchId"], node_to_branches[node.step_id])
+ node_to_branches[node.step_id].add(branch["branchId"])
+
+ @staticmethod
+ def _filter_valid_edges(flow_item: FlowItem, node_to_branches: dict) -> list[EdgeItem]:
+ """过滤有效的边"""
+ return [
+ edge for edge in flow_item.edges
+ if (edge.source_node in node_to_branches and
+ edge.target_node in node_to_branches and
+ edge.branch_id in node_to_branches[edge.source_node])
+ ]
+
@staticmethod
async def remove_excess_structure_from_flow(flow_item: FlowItem) -> FlowItem:
"""移除流程图中的多余结构"""
node_to_branches = {}
branch_illegal_chars = "."
+ # 处理所有节点的分支信息
for node in flow_item.nodes:
             node_to_branches[node.step_id] = set()
if node.call_id == NodeType.CHOICE.value:
- node.parameters = node.parameters["input_parameters"]
- if "choices" not in node.parameters:
- node.parameters["choices"] = []
- for branch in node.parameters["choices"]:
- if branch["branchId"] in node_to_branches[node.step_id]:
- err = f"[FlowService] 节点{node.name}的分支{branch['branchId']}重复"
- logger.error(err)
- raise Exception(err)
- for illegal_char in branch_illegal_chars:
- if illegal_char in branch["branchId"]:
- err = f"[FlowService] 节点{node.name}的分支{branch['branchId']}名称中含有非法字符"
- logger.error(err)
- raise Exception(err)
- node_to_branches[node.step_id].add(branch["branchId"])
+ FlowService._process_choice_node(node, node_to_branches)
else:
node_to_branches[node.step_id].add("")
- new_edges_items = []
- for edge in flow_item.edges:
- if edge.source_node not in node_to_branches:
- continue
- if edge.target_node not in node_to_branches:
- continue
- if edge.branch_id not in node_to_branches[edge.source_node]:
- continue
- new_edges_items.append(edge)
- flow_item.edges = new_edges_items
+
+ # 过滤有效的边
+ flow_item.edges = FlowService._filter_valid_edges(flow_item, node_to_branches)
return flow_item
+
@staticmethod
async def _validate_node_ids(nodes: list[NodeItem]) -> tuple[str, str]:
"""验证节点ID的唯一性并获取起始和终止节点ID,当节点ID重复或起始/终止节点数量不为1时抛出异常"""
@@ -84,6 +101,21 @@ class FlowService:
return start_id, end_id
+ @staticmethod
+ async def validate_flow_illegal(flow_item: FlowItem) -> tuple[str, str]:
+ """验证流程图是否合法;当流程图不合法时抛出异常"""
+ # 验证节点ID并获取起始和终止节点
+ start_id, end_id = await FlowService._validate_node_ids(flow_item.nodes)
+
+ # 验证边的合法性并获取节点的入度和出度
+ in_deg, out_deg = await FlowService._validate_edges(flow_item.edges)
+
+ # 验证起始和终止节点的入度和出度
+ await FlowService._validate_node_degrees(start_id, end_id, in_deg, out_deg)
+
+ return start_id, end_id
+
+
@staticmethod
async def _validate_edges(edges: list[EdgeItem]) -> tuple[dict[str, int], dict[str, int]]:
"""验证边的合法性并计算节点的入度和出度;当边的ID重复、起始终止节点相同或分支重复时抛出异常"""
@@ -117,6 +149,7 @@ class FlowService:
return in_deg, out_deg
+
@staticmethod
async def _validate_node_degrees(start_id: str, end_id: str,
in_deg: dict[str, int], out_deg: dict[str, int]) -> None:
@@ -130,17 +163,46 @@ class FlowService:
logger.error(err)
raise Exception(err)
+
@staticmethod
- async def validate_flow_illegal(flow_item: FlowItem) -> None:
- """验证流程图是否合法;当流程图不合法时抛出异常"""
- # 验证节点ID并获取起始和终止节点
- start_id, end_id = await FlowService._validate_node_ids(flow_item.nodes)
+ def _find_start_node_id(nodes: list[NodeItem]) -> str:
+ """查找起始节点ID"""
+ for node in nodes:
+ if node.call_id == NodeType.START.value:
+ return node.step_id
+ return ""
- # 验证边的合法性并获取节点的入度和出度
- in_deg, out_deg = await FlowService._validate_edges(flow_item.edges)
- # 验证起始和终止节点的入度和出度
- await FlowService._validate_node_degrees(start_id, end_id, in_deg, out_deg)
+ @staticmethod
+ def _build_adjacency_list(edges: list[EdgeItem]) -> dict[str, list[str]]:
+ """构建邻接表"""
+ adj_list = {}
+ for edge in edges:
+ if edge.source_node not in adj_list:
+ adj_list[edge.source_node] = []
+ adj_list[edge.source_node].append(edge.target_node)
+ return adj_list
+
+
+ @staticmethod
+ def _bfs_traverse(start_id: str, adj_list: dict[str, list[str]]) -> set[str]:
+ """使用BFS遍历图并返回可达节点集合"""
+ visited = set()
+ if not start_id:
+ return visited
+
+ queue = collections.deque([start_id])
+ visited.add(start_id)
+
+ while queue:
+ current = queue.popleft()
+ if current in adj_list:
+ for neighbor in adj_list[current]:
+ if neighbor not in visited:
+ visited.add(neighbor)
+ queue.append(neighbor)
+ return visited
+
@staticmethod
async def validate_flow_connectivity(flow_item: FlowItem) -> bool:
@@ -188,6 +250,7 @@ class FlowService:
return end_id in vis
+
def generate_from_schema(schema: dict) -> dict:
"""根据 JSON Schema 生成示例 JSON 数据。