diff --git a/apps/scheduler/call/render/__init__.py b/apps/scheduler/call/graph/__init__.py
similarity index 100%
rename from apps/scheduler/call/render/__init__.py
rename to apps/scheduler/call/graph/__init__.py
diff --git a/apps/scheduler/call/render/render.py b/apps/scheduler/call/graph/graph.py
similarity index 44%
rename from apps/scheduler/call/render/render.py
rename to apps/scheduler/call/graph/graph.py
index 09bc3b8a05771a1317dc03137db38f290df83505..fa07e12b1004e1f8accb3a23c12856edb205f7e0 100644
--- a/apps/scheduler/call/render/render.py
+++ b/apps/scheduler/call/graph/graph.py
@@ -1,123 +1,98 @@
-"""Call: Render,用于将SQL Tool查询出的数据转换为图表
+"""
+Call: Render,用于将SQL Tool查询出的数据转换为图表
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""
+
import json
-from pathlib import Path
-from typing import Any
+from collections.abc import AsyncGenerator
+from typing import Any, ClassVar
-from pydantic import BaseModel, Field
+from anyio import Path
+from pydantic import Field
-from apps.entities.scheduler import CallError, CallVars
+from apps.entities.enum_var import CallOutputType
+from apps.entities.scheduler import CallError, CallOutputChunk, CallVars
from apps.scheduler.call.core import CoreCall
-from apps.scheduler.call.render.style import RenderStyle
-
-
-class _RenderAxis(BaseModel):
- """ECharts图表的轴配置"""
-
- type: str = Field(description="轴的类型")
- axisTick: dict = Field(description="轴刻度配置") # noqa: N815
-
-
-class _RenderFormat(BaseModel):
- """ECharts图表配置"""
-
- tooltip: dict[str, Any] = Field(description="ECharts图表的提示框配置")
- legend: dict[str, Any] = Field(description="ECharts图表的图例配置")
- dataset: dict[str, Any] = Field(description="ECharts图表的数据集配置")
- xAxis: _RenderAxis = Field(description="ECharts图表的X轴配置") # noqa: N815
- yAxis: _RenderAxis = Field(description="ECharts图表的Y轴配置") # noqa: N815
- series: list[dict[str, Any]] = Field(description="ECharts图表的数据列配置")
-
-
-class _RenderOutput(BaseModel):
- """Render工具的输出"""
-
- output: _RenderFormat = Field(description="ECharts图表配置")
- message: str = Field(description="Render工具的输出")
-
+from apps.scheduler.call.graph.schema import RenderFormat, RenderInput, RenderOutput
+from apps.scheduler.call.graph.style import RenderStyle
-class _RenderParam(BaseModel):
- """Render工具的参数"""
-
-
-class Render(metaclass=CoreCall, param_cls=_RenderParam, output_cls=_RenderOutput):
+class Graph(CoreCall, output_type=RenderOutput):
"""Render Call,用于将SQL Tool查询出的数据转换为图表"""
- name: str = "render"
- description: str = "渲染图表工具,可将给定的数据绘制为图表。"
+ name: ClassVar[str] = Field("图表", exclude=True, frozen=True)
+ description: ClassVar[str] = Field("将SQL查询出的数据转换为图表", exclude=True, frozen=True)
+
+ data: list[dict[str, Any]] = Field(description="用于绘制图表的数据", exclude=True, frozen=True)
- def init(self, _syscall_vars: CallVars, **_kwargs) -> None: # noqa: ANN003
+ async def _init(self, syscall_vars: CallVars) -> dict[str, Any]:
"""初始化Render Call,校验参数,读取option模板"""
+ await super()._init(syscall_vars)
+
try:
option_location = Path(__file__).parent / "option.json"
- with Path(option_location).open(encoding="utf-8") as f:
- self._option_template = json.load(f)
+ f = await Path(option_location).open(encoding="utf-8")
+ data = await f.read()
+ self._option_template = json.loads(data)
+ await f.aclose()
except Exception as e:
raise CallError(message=f"图表模板读取失败:{e!s}", data={}) from e
+ return RenderInput(
+ question=syscall_vars.question,
+ task_id=syscall_vars.task_id,
+ data=self.data,
+ ).model_dump(exclude_none=True, by_alias=True)
- async def __call__(self, _slot_data: dict[str, Any]) -> _RenderOutput:
- """运行Render Call"""
- # 获取必要参数
- syscall_vars: CallVars = getattr(self, "_syscall_vars")
-
- # 检测前一个工具是否为SQL
- if "dataset" not in syscall_vars.history[-1].output_data:
- raise CallError(
- message="图表生成失败!Render必须在SQL后调用!",
- data={},
- )
- data = syscall_vars.history[-1].output_data["output"]
- if data["type"] != "sql" or "dataset" not in data:
- raise CallError(
- message="图表生成失败!Render必须在SQL后调用!",
- data={},
- )
- data = json.loads(data["dataset"])
+ async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
+ """运行Render Call"""
+ data = RenderInput(**input_data)
# 判断数据格式是否满足要求
# 样例:[{'openeuler_version': 'openEuler-22.03-LTS-SP2', '软件数量': 10}]
malformed = True
- if isinstance(data, list) and len(data) > 0 and isinstance(data[0], dict):
+ if isinstance(data.data, list) and len(data.data) > 0 and isinstance(data.data[0], dict):
malformed = False
# 将执行SQL工具查询到的数据转换为特定格式
if malformed:
raise CallError(
- message="SQL未查询到数据,或数据格式错误,无法生成图表!",
- data={"data": data},
+ message="数据格式错误,无法生成图表!",
+ data={"data": data.data},
)
# 对少量数据进行处理
- column_num = len(data[0]) - 1
+ processed_data = data.data
+ column_num = len(processed_data[0]) - 1
if column_num == 0:
- data = Render._separate_key_value(data)
+ processed_data = Graph._separate_key_value(processed_data)
column_num = 1
# 该格式满足ECharts DataSource要求,与option模板进行拼接
- self._option_template["dataset"]["source"] = data
+ self._option_template["dataset"]["source"] = processed_data
try:
- llm_output = await RenderStyle().generate(syscall_vars.task_id, question=syscall_vars.question)
+ llm_output = await RenderStyle().generate(data.task_id, question=data.question)
add_style = llm_output.get("additional_style", "")
self._parse_options(column_num, llm_output["chart_type"], add_style, llm_output["scale_type"])
except Exception as e:
raise CallError(message=f"图表生成失败:{e!s}", data={"data": data}) from e
- return _RenderOutput(
- output=_RenderFormat.model_validate(self._option_template),
- message="图表生成成功!图表将使用外置工具进行展示。",
+ yield CallOutputChunk(
+ type=CallOutputType.DATA,
+ content=RenderOutput(
+ output=RenderFormat.model_validate(self._option_template),
+ ).model_dump(exclude_none=True, by_alias=True),
)
@staticmethod
def _separate_key_value(data: list[dict[str, Any]]) -> list[dict[str, Any]]:
- """若数据只有一组(例如:{"aaa": "bbb"}),则分离键值对。
+ """
+ 若数据只有一组(例如:{"aaa": "bbb"}),则分离键值对。
样例:{"type": "aaa", "value": "bbb"}
diff --git a/apps/scheduler/call/render/option.json b/apps/scheduler/call/graph/option.json
similarity index 100%
rename from apps/scheduler/call/render/option.json
rename to apps/scheduler/call/graph/option.json
diff --git a/apps/scheduler/call/graph/schema.py b/apps/scheduler/call/graph/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f26c5224997c5b0da4771a202298cab391f1187
--- /dev/null
+++ b/apps/scheduler/call/graph/schema.py
@@ -0,0 +1,39 @@
+"""图表工具的输入输出"""
+
+from typing import Any
+
+from pydantic import BaseModel, Field
+
+from apps.scheduler.call.core import DataBase
+
+
+class RenderAxis(BaseModel):
+ """ECharts图表的轴配置"""
+
+ type: str = Field(description="轴的类型")
+ axisTick: dict = Field(description="轴刻度配置") # noqa: N815
+
+
+class RenderFormat(BaseModel):
+ """ECharts图表配置"""
+
+ tooltip: dict[str, Any] = Field(description="ECharts图表的提示框配置")
+ legend: dict[str, Any] = Field(description="ECharts图表的图例配置")
+ dataset: dict[str, Any] = Field(description="ECharts图表的数据集配置")
+ xAxis: RenderAxis = Field(description="ECharts图表的X轴配置") # noqa: N815
+ yAxis: RenderAxis = Field(description="ECharts图表的Y轴配置") # noqa: N815
+ series: list[dict[str, Any]] = Field(description="ECharts图表的数据列配置")
+
+
+class RenderInput(DataBase):
+ """图表工具的输入"""
+
+ question: str = Field(description="用户输入")
+ task_id: str = Field(description="任务ID")
+ data: list[dict[str, Any]] = Field(description="图表数据")
+
+
+class RenderOutput(DataBase):
+ """Render工具的输出"""
+
+ output: RenderFormat = Field(description="ECharts图表配置")
diff --git a/apps/scheduler/call/render/style.py b/apps/scheduler/call/graph/style.py
similarity index 92%
rename from apps/scheduler/call/render/style.py
rename to apps/scheduler/call/graph/style.py
index 2fe809a948ee17e8ffb50a7cbaaa96f06273d6cc..5040be19e171215bcd86534dec41293fb23df267 100644
--- a/apps/scheduler/call/render/style.py
+++ b/apps/scheduler/call/graph/style.py
@@ -1,8 +1,10 @@
-"""Render Call: 选择图表样式
+"""
+Render Call: 选择图表样式
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""
-from typing import Any, ClassVar, Optional
+
+from typing import Any, ClassVar
from apps.llm.patterns.core import CorePattern
from apps.llm.patterns.json_gen import Json
@@ -82,7 +84,7 @@ class RenderStyle(CorePattern):
"required": ["chart_type", "scale_type"],
}
- def __init__(self, system_prompt: Optional[str] = None, user_prompt: Optional[str] = None) -> None:
+ def __init__(self, system_prompt: str | None = None, user_prompt: str | None = None) -> None:
"""初始化RenderStyle Prompt"""
super().__init__(system_prompt, user_prompt)
diff --git a/apps/scheduler/call/llm.py b/apps/scheduler/call/llm/llm.py
similarity index 57%
rename from apps/scheduler/call/llm.py
rename to apps/scheduler/call/llm/llm.py
index 290bf699056150fd04e7675479fa3df79a0cd986..ceb163eb6b5e51cfa6b7e757d2a1ab90a628e71d 100644
--- a/apps/scheduler/call/llm.py
+++ b/apps/scheduler/call/llm/llm.py
@@ -1,54 +1,54 @@
-"""工具:调用大模型
+"""
+工具:调用大模型
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""
+
import logging
from collections.abc import AsyncGenerator
from datetime import datetime
-from typing import Any, ClassVar
+from typing import Annotated, Any, ClassVar
import pytz
-from jinja2 import BaseLoader, select_autoescape
+from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment
-from pydantic import BaseModel, Field
+from pydantic import Field
-from apps.constants import LLM_CONTEXT_PROMPT, LLM_DEFAULT_PROMPT
-from apps.entities.scheduler import CallError, CallVars
+from apps.entities.enum_var import CallOutputType
+from apps.entities.scheduler import CallError, CallOutputChunk, CallVars
from apps.llm.reasoning import ReasoningLLM
from apps.scheduler.call.core import CoreCall
+from apps.scheduler.call.llm.schema import LLM_CONTEXT_PROMPT, LLM_DEFAULT_PROMPT, LLMInput, LLMOutput
-logger = logging.getLogger("ray")
-class LLMNodeOutput(BaseModel):
- """定义LLM工具调用的输出"""
-
- message: str = Field(description="大模型输出的文字信息")
+logger = logging.getLogger(__name__)
-class LLM(CoreCall, ret_type=LLMNodeOutput):
+class LLM(CoreCall, input_type=LLMInput, output_type=LLMOutput):
"""大模型调用工具"""
- name: ClassVar[str] = "大模型"
- description: ClassVar[str] = "以指定的提示词和上下文信息调用大模型,并获得输出。"
+ name: ClassVar[Annotated[str, Field(description="工具名称", exclude=True, frozen=True)]] = "大模型"
+ description: ClassVar[Annotated[str, Field(description="工具描述", exclude=True, frozen=True)]] = (
+ "以指定的提示词和上下文信息调用大模型,并获得输出。"
+ )
temperature: float = Field(description="大模型温度(随机化程度)", default=0.7)
enable_context: bool = Field(description="是否启用上下文", default=True)
- step_history_size: int = Field(description="上下文信息中包含的步骤历史数量", default=3)
+ step_history_size: int = Field(description="上下文信息中包含的步骤历史数量", default=3, ge=1, le=10)
system_prompt: str = Field(description="大模型系统提示词", default="")
user_prompt: str = Field(description="大模型用户提示词", default=LLM_DEFAULT_PROMPT)
-
async def _prepare_message(self, syscall_vars: CallVars) -> list[dict[str, Any]]:
"""准备消息"""
# 创建共享的 Environment 实例
env = SandboxedEnvironment(
loader=BaseLoader(),
- autoescape=select_autoescape(),
+ autoescape=False,
trim_blocks=True,
lstrip_blocks=True,
)
# 上下文信息
- step_history = list(syscall_vars.history.values())[-self.step_history_size:]
+ step_history = list(syscall_vars.history.values())[-self.step_history_size :]
if self.enable_context:
context_tmpl = env.from_string(LLM_CONTEXT_PROMPT)
context_prompt = context_tmpl.render(
@@ -82,34 +82,23 @@ class LLM(CoreCall, ret_type=LLMNodeOutput):
{"role": "user", "content": user_input},
]
-
- async def init(self, syscall_vars: CallVars, **_kwargs: Any) -> dict[str, Any]:
+ async def _init(self, syscall_vars: CallVars) -> dict[str, Any]:
"""初始化LLM工具"""
- self._message = await self._prepare_message(syscall_vars)
- self._task_id = syscall_vars.task_id
- return {
- "task_id": self._task_id,
- "message": self._message,
- }
+ await super()._init(syscall_vars)
+ return LLMInput(
+ task_id=syscall_vars.task_id,
+ message=await self._prepare_message(syscall_vars),
+ ).model_dump(exclude_none=True, by_alias=True)
- async def exec(self) -> dict[str, Any]:
+ async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
"""运行LLM Call"""
+ data = LLMInput(**input_data)
try:
- result = ""
- async for chunk in ReasoningLLM().call(task_id=self._task_id, messages=self._message):
- result += chunk
+ # TODO: 临时改成了非流式
+ async for chunk in ReasoningLLM().call(task_id=data.task_id, messages=data.message, streaming=False):
+ if not chunk:
+ continue
+ yield CallOutputChunk(type=CallOutputType.TEXT, content=chunk)
except Exception as e:
raise CallError(message=f"大模型调用失败:{e!s}", data={}) from e
-
- result = result.strip().strip("\n")
- return LLMNodeOutput(message=result).model_dump(exclude_none=True, by_alias=True)
-
-
- async def stream(self) -> AsyncGenerator[str, None]:
- """流式输出"""
- try:
- async for chunk in ReasoningLLM().call(task_id=self._task_id, messages=self._message):
- yield chunk
- except Exception as e:
- raise CallError(message=f"大模型流式输出失败:{e!s}", data={}) from e
diff --git a/apps/scheduler/call/llm/schema.py b/apps/scheduler/call/llm/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4b1fe30cd0c0b2761ca8dccf624e99bc9fbf2d5
--- /dev/null
+++ b/apps/scheduler/call/llm/schema.py
@@ -0,0 +1,107 @@
+"""LLM工具的输入输出定义"""
+
+from textwrap import dedent
+
+from pydantic import Field
+
+from apps.scheduler.call.core import DataBase
+
+LLM_CONTEXT_PROMPT = dedent(
+ r"""
+ 以下是对用户和AI间对话的简短总结:
+ {{ summary }}
+
+ 你作为AI,在回答用户的问题前,需要获取必要信息。为此,你调用了以下工具,并获得了输出:
+
+ {% for tool in history_data %}
+ {{ tool.step_id }}
+
+ {% endfor %}
+
+ """,
+).strip("\n")
+LLM_DEFAULT_PROMPT = dedent(
+ r"""
+ <instructions>
+ 你是一个乐于助人的智能助手。请结合给出的背景信息, 回答用户的提问。
+ 当前时间:{{ time }},可以作为时间参照。
+ 用户的问题将在<question>中给出,上下文背景信息将在<context>中给出。
+ 注意:输出不要包含任何XML标签,不要编造任何信息。若你认为用户提问与背景信息无关,请忽略背景信息直接作答。
+ </instructions>
+
+ <question>
+ {{ question }}
+ </question>
+
+ <context>
+ {{ context }}
+ </context>
+
+ 现在,输出你的回答:
+ """,
+).strip("\n")
+LLM_ERROR_PROMPT = dedent(
+ r"""
+ <instructions>
+ 你是一位智能助手,能够根据用户的问题,使用Python工具获取信息,并作出回答。你在使用工具解决回答用户的问题时,发生了错误。
+ 你的任务是:分析工具(Python程序)的异常信息,分析造成该异常可能的原因,并以通俗易懂的方式,将原因告知用户。
+
+ 当前时间:{{ time }},可以作为时间参照。
+ 发生错误的程序异常信息将在<error_info>中给出,用户的问题将在<question>中给出,上下文背景信息将在<context>中给出。
+ 注意:输出不要包含任何XML标签,不要编造任何信息。若你认为用户提问与背景信息无关,请忽略背景信息。
+ </instructions>
+
+ <error_info>
+ {{ error_info }}
+ </error_info>
+
+ <question>
+ {{ question }}
+ </question>
+
+ <context>
+ {{ context }}
+ </context>
+
+ 现在,输出你的回答:
+ """,
+).strip("\n")
+RAG_ANSWER_PROMPT = dedent(
+ r"""
+ <instructions>
+ 你是由openEuler社区构建的大型语言AI助手。请根据背景信息(包含对话上下文和文档片段),回答用户问题。
+ 用户的问题将在<question>中给出,上下文背景信息将在<context>中给出,文档片段将在<document>中给出。
+
+ 注意事项:
+ 1. 输出不要包含任何XML标签。请确保输出内容的正确性,不要编造任何信息。
+ 2. 如果用户询问你关于你自己的问题,请统一回答:“我叫EulerCopilot,是openEuler社区的智能助手”。
+ 3. 背景信息仅供参考,若背景信息与用户问题无关,请忽略背景信息直接作答。
+ 4. 请在回答中使用Markdown格式,并**不要**将内容放在"```"中。
+ </instructions>
+
+ <question>
+ {{ question }}
+ </question>
+
+ <context>
+ {{ context }}
+ </context>
+
+ <document>
+ {{ document }}
+ </document>
+
+ 现在,请根据上述信息,回答用户的问题:
+ """,
+).strip("\n")
+
+
+class LLMInput(DataBase):
+ """定义LLM工具调用的输入"""
+
+ task_id: str = Field(description="任务ID")
+ message: list[dict[str, str]] = Field(description="输入给大模型的消息列表")
+
+
+class LLMOutput(DataBase):
+ """定义LLM工具调用的输出"""
diff --git a/apps/scheduler/call/suggest.py b/apps/scheduler/call/suggest.py
deleted file mode 100644
index 96181c758984594297ffcc5c953d12c2ce089881..0000000000000000000000000000000000000000
--- a/apps/scheduler/call/suggest.py
+++ /dev/null
@@ -1,83 +0,0 @@
-"""用于问题推荐的工具
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
-"""
-from typing import Any, ClassVar, Optional
-
-import ray
-from pydantic import BaseModel, Field
-
-from apps.entities.scheduler import CallError, CallVars
-from apps.manager.user_domain import UserDomainManager
-from apps.scheduler.call.core import CoreCall
-
-
-class _SingleFlowSuggestionConfig(BaseModel):
- """涉及单个Flow的问题推荐配置"""
-
- flow_id: str
- question: Optional[str] = Field(default=None, description="固定的推荐问题")
-
-
-class _SuggestionOutputItem(BaseModel):
- """问题推荐结果的单个条目"""
-
- question: str
- app_id: str
- flow_id: str
- flow_description: str
-
-
-class SuggestionOutput(BaseModel):
- """问题推荐结果"""
-
- output: list[_SuggestionOutputItem]
-
-
-class Suggestion(CoreCall, ret_type=SuggestionOutput):
- """问题推荐"""
-
- name: ClassVar[str] = "问题推荐"
- description: ClassVar[str] = "在答案下方显示推荐的下一个问题"
-
- configs: list[_SingleFlowSuggestionConfig] = Field(description="问题推荐配置", min_length=1)
- num: int = Field(default=3, ge=1, le=6, description="推荐问题的总数量(必须大于等于configs中涉及的Flow的数量)")
-
-
- async def init(self, syscall_vars: CallVars, **_kwargs: Any) -> dict[str, Any]:
- """初始化"""
- # 普通变量
- self._question = syscall_vars.question
- self._task_id = syscall_vars.task_id
- self._background = syscall_vars.summary
- self._user_sub = syscall_vars.user_sub
-
- return {
- "question": self._question,
- "task_id": self._task_id,
- "background": self._background,
- "user_sub": self._user_sub,
- }
-
-
- async def exec(self) -> dict[str, Any]:
- """运行问题推荐"""
- # 获取当前任务
- task_actor = await ray.get_actor("task")
- task = await task_actor.get_task.remote(self._task_id)
-
- # 获取当前用户的画像
- user_domain = await UserDomainManager.get_user_domain_by_user_sub_and_topk(self._user_sub, 5)
-
- current_record = [
- {
- "role": "user",
- "content": task.record.content.question,
- },
- {
- "role": "assistant",
- "content": task.record.content.answer,
- },
- ]
-
- return {}
diff --git a/apps/scheduler/call/suggest/schema.py b/apps/scheduler/call/suggest/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..bae027d066815e86c5ad91c8999e66539ad9e8be
--- /dev/null
+++ b/apps/scheduler/call/suggest/schema.py
@@ -0,0 +1,35 @@
+"""问题推荐工具的输入输出"""
+
+from pydantic import BaseModel, Field
+
+from apps.scheduler.call.core import DataBase
+
+
+class SingleFlowSuggestionConfig(BaseModel):
+ """涉及单个Flow的问题推荐配置"""
+
+ flow_id: str
+ question: str | None = Field(default=None, description="固定的推荐问题")
+
+
+class SuggestionOutputItem(BaseModel):
+ """问题推荐结果的单个条目"""
+
+ question: str
+ app_id: str
+ flow_id: str
+ flow_description: str
+
+
+class SuggestionInput(DataBase):
+ """问题推荐输入"""
+
+ question: str
+ task_id: str
+ user_sub: str
+
+
+class SuggestionOutput(DataBase):
+ """问题推荐结果"""
+
+ output: list[SuggestionOutputItem]
diff --git a/apps/scheduler/call/suggest/suggest.py b/apps/scheduler/call/suggest/suggest.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf03b10dbd34428329b9936f70fb2c485ca979e7
--- /dev/null
+++ b/apps/scheduler/call/suggest/suggest.py
@@ -0,0 +1,66 @@
+"""
+用于问题推荐的工具
+
+Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
+"""
+
+from collections.abc import AsyncGenerator
+from typing import Annotated, Any, ClassVar
+
+from pydantic import Field
+
+from apps.entities.scheduler import CallError, CallOutputChunk, CallVars
+from apps.manager.task import TaskManager
+from apps.manager.user_domain import UserDomainManager
+from apps.scheduler.call.core import CoreCall
+from apps.scheduler.call.suggest.schema import (
+ SingleFlowSuggestionConfig,
+ SuggestionInput,
+ SuggestionOutput,
+)
+
+
+class Suggestion(CoreCall, input_type=SuggestionInput, output_type=SuggestionOutput):
+ """问题推荐"""
+
+ name: ClassVar[Annotated[str, Field(description="工具名称", exclude=True, frozen=True)]] = "问题推荐"
+ description: ClassVar[Annotated[str, Field(description="工具描述", exclude=True, frozen=True)]] = (
+ "在答案下方显示推荐的下一个问题"
+ )
+
+ configs: list[SingleFlowSuggestionConfig] = Field(description="问题推荐配置", min_length=1)
+ num: int = Field(default=3, ge=1, le=6, description="推荐问题的总数量(必须大于等于configs中涉及的Flow的数量)")
+
+ context: list[dict[str, str]]
+
+ async def _init(self, syscall_vars: CallVars) -> dict[str, Any]:
+ """初始化"""
+ await super()._init(syscall_vars)
+
+ return SuggestionInput(
+ question=syscall_vars.question,
+ task_id=syscall_vars.task_id,
+ user_sub=syscall_vars.user_sub,
+ ).model_dump(by_alias=True, exclude_none=True)
+
+ async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]:
+ """运行问题推荐"""
+ data = SuggestionInput(**input_data)
+ # 获取当前任务
+ task_block = await TaskManager.get_task(task_id=data.task_id)
+
+ # 获取当前用户的画像
+ user_domain = await UserDomainManager.get_user_domain_by_user_sub_and_topk(data.user_sub, 5)
+
+ current_record = [
+ {
+ "role": "user",
+ "content": task_block.record.content.question,
+ },
+ {
+ "role": "assistant",
+ "content": task_block.record.content.answer,
+ },
+ ]
+
+ yield CallOutputChunk(content="")
diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py
index 8e3f4c5ffe0a6c416f02d89aaa7b09982d793b64..941d71369a0b68af2bd42da8fca8c12d76de11aa 100644
--- a/apps/scheduler/executor/flow.py
+++ b/apps/scheduler/executor/flow.py
@@ -1,50 +1,67 @@
-"""Flow执行Executor
+"""
+Flow执行Executor
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""
-import asyncio
-import inspect
+
import logging
-from typing import Any, Optional
-import ray
-from pydantic import BaseModel, ConfigDict, Field
-from ray import actor
+from pydantic import ConfigDict, Field
-from apps.entities.enum_var import StepStatus
-from apps.entities.flow import Flow, FlowError, Step
+from apps.entities.enum_var import EventType, StepStatus
+from apps.entities.flow import Step
from apps.entities.request_data import RequestDataApp
-from apps.entities.scheduler import CallVars, ExecutorBackground
-from apps.entities.task import ExecutorState, TaskBlock
-from apps.llm.patterns.executor import ExecutorSummary
-from apps.manager.node import NodeManager
+from apps.entities.scheduler import CallVars
+from apps.entities.task import ExecutorState
from apps.manager.task import TaskManager
-from apps.scheduler.call.core import CoreCall
-from apps.scheduler.call.empty import Empty
-from apps.scheduler.call.llm import LLM
-from apps.scheduler.executor.message import (
- push_flow_start,
- push_flow_stop,
- push_step_input,
- push_step_output,
- push_text_output,
+from apps.scheduler.call.llm.schema import LLM_ERROR_PROMPT
+from apps.scheduler.executor.base import BaseExecutor
+from apps.scheduler.executor.step import StepExecutor
+
+logger = logging.getLogger(__name__)
+FIXED_STEPS_BEFORE_START = {
+ "_summary": Step(
+ name="理解上下文",
+ description="使用大模型,理解对话上下文",
+ node="Summary",
+ type="Summary",
+ params={},
+ ),
+}
+FIXED_STEPS_AFTER_END = {
+ "_facts": Step(
+ name="记忆存储",
+ description="理解对话答案,并存储到记忆中",
+ node="Facts",
+ type="Facts",
+ params={},
+ ),
+}
+SLOT_FILLING_STEP = Step(
+ name="填充参数",
+ description="根据步骤历史,填充参数",
+ node="Slot",
+ type="Slot",
+ params={},
+)
+ERROR_STEP = Step(
+ name="错误处理",
+ description="错误处理",
+ node="LLM",
+ type="LLM",
+ params={
+ "user_prompt": LLM_ERROR_PROMPT,
+ },
)
-from apps.scheduler.slot.slot import Slot
-
-logger = logging.getLogger("ray")
# 单个流的执行工具
-class Executor(BaseModel):
+class FlowExecutor(BaseExecutor):
"""用于执行工作流的Executor"""
flow_id: str = Field(description="Flow ID")
- flow: Flow = Field(description="工作流数据")
- task: TaskBlock = Field(description="任务信息")
question: str = Field(description="用户输入")
- queue: actor.ActorHandle = Field(description="消息队列")
post_body_app: RequestDataApp = Field(description="请求体中的app信息")
- executor_background: ExecutorBackground = Field(description="Executor的背景信息")
model_config = ConfigDict(
arbitrary_types_allowed=True,
@@ -52,258 +69,150 @@ class Executor(BaseModel):
)
"""Pydantic配置"""
-
async def load_state(self) -> None:
"""从数据库中加载FlowExecutor的状态"""
logger.info("[FlowExecutor] 加载Executor状态")
# 尝试恢复State
- if self.task.flow_state:
- self.flow_state = self.task.flow_state
- self.task.flow_context = await TaskManager.get_flow_history_by_task_id(self.task.record.task_id)
+ if self.task.state:
+ self.task.context = await TaskManager.get_flow_history_by_task_id(self.task.id)
else:
# 创建ExecutorState
- self.flow_state = ExecutorState(
+ self.task.state = ExecutorState(
flow_id=str(self.flow_id),
description=str(self.flow.description),
status=StepStatus.RUNNING,
app_id=str(self.post_body_app.app_id),
step_id="start",
step_name="开始",
- ai_summary="",
- filled_data=self.post_body_app.params,
)
- # 是否结束运行
- self._can_continue = True
- # 向用户输出的内容
- self._final_answer = ""
-
-
- async def _check_cls(self, call_cls: Any) -> bool:
- """检查Call是否符合标准要求"""
- flag = True
- if not hasattr(call_cls, "name") or not isinstance(call_cls.name, str):
- flag = False
- if not hasattr(call_cls, "description") or not isinstance(call_cls.description, str):
- flag = False
- if not hasattr(call_cls, "exec") or not asyncio.iscoroutinefunction(call_cls.exec):
- flag = False
- if not hasattr(call_cls, "stream") or not inspect.isasyncgenfunction(call_cls.stream):
- flag = False
- return flag
-
-
- # TODO: 默认错误处理步骤
- async def _run_error(self, step: FlowError) -> dict[str, Any]:
- """运行错误处理步骤"""
- return {}
-
-
- async def _get_call_cls(self, node_id: str) -> type[CoreCall]:
- """获取并验证Call类"""
- # 检查flow_state是否为空
- if not self.flow_state:
- err = "[FlowExecutor] flow_state为空"
+ # 是否到达Flow结束终点(变量)
+ self._reached_end: bool = False
+
+ async def _invoke_runner(self, step_id: str) -> None:
+ """调用Runner"""
+ if not self.task.state:
+ err = "[FlowExecutor] 当前ExecutorState为空"
+ logger.error(err)
raise ValueError(err)
- # 特判
- if node_id == "Empty":
- return Empty
-
- # 获取对应Node的call_id
- call_id = await NodeManager.get_node_call_id(node_id)
- # 从Pool中获取对应的Call
- pool = ray.get_actor("pool")
- call_cls: type[CoreCall] = await pool.get_call.remote(call_id)
-
- # 检查Call合法性
- if not await self._check_cls(call_cls):
- err = f"[FlowExecutor] 工具 {node_id} 不符合Call标准要求"
- raise ValueError(err)
-
- return call_cls
-
-
- async def _process_slots(self, call_obj: Any) -> tuple[bool, Optional[dict[str, Any]]]:
- """处理slot参数"""
- if not (hasattr(call_obj, "slot_schema") and call_obj.slot_schema):
- return True, None
-
- slot_processor = Slot(call_obj.slot_schema)
- remaining_schema, slot_data = await slot_processor.process(
- self.flow_state.filled_data,
- self.post_body_app.params,
- {
- "task_id": self.task.record.task_id,
- "question": self.question,
- "thought": self.flow_state.ai_summary,
- "previous_output": await self._get_last_output(self.task),
- },
- )
-
- # 保存Schema至State
- self.flow_state.remaining_schema = remaining_schema
- self.flow_state.filled_data.update(slot_data)
-
- # 如果还有未填充的部分,则返回False
- if remaining_schema:
- self._can_continue = False
- self.flow_state.status = StepStatus.RUNNING
- # 推送空输入输出
- self.task = await push_step_input(self.task, self.queue, self.flow_state, {})
- self.flow_state.status = StepStatus.PARAM
- self.task = await push_step_output(self.task, self.queue, self.flow_state, {})
- return False, None
-
- return True, slot_data
-
-
- # TODO
- async def _get_last_output(self, task: TaskBlock) -> Optional[dict[str, Any]]:
- """获取上一步的输出"""
- return None
-
-
- async def _execute_call(self, call_obj: Any, *, is_final_answer: bool) -> dict[str, Any]:
- """执行Call并处理结果"""
- # call_obj一定合法;开始判断是否为最终结果
- if is_final_answer and isinstance(call_obj, LLM):
- # 最后一步 & 是大模型步骤,直接流式输出
- async for chunk in call_obj.stream():
- await push_text_output(self.task, self.queue, chunk)
- self._final_answer += chunk
- self.flow_state.status = StepStatus.SUCCESS
- return {
- "message": self._final_answer,
- }
-
- # 其他情况:先运行步骤,得到结果
- result: dict[str, Any] = await call_obj.exec()
- if is_final_answer:
- if call_obj.name == "Convert":
- # 如果是Convert,直接输出转换之后的结果
- self._final_answer += result["message"]
- await push_text_output(self.task, self.queue, self._final_answer)
- self.flow_state.status = StepStatus.SUCCESS
- return {
- "message": self._final_answer,
- }
- # 其他工具,加一步大模型过程
- # FIXME: 使用单独的Prompt
- self.flow_state.status = StepStatus.SUCCESS
- return {
- "message": self._final_answer,
- }
-
- # 其他情况:返回结果
- self.flow_state.status = StepStatus.SUCCESS
- return result
-
-
- async def _run_step(self, step_id: str, step_data: Step) -> None:
- """运行单个步骤"""
- logger.info("[FlowExecutor] 运行步骤 %s", step_data.name)
-
- # State写入ID和运行状态
- self.flow_state.step_id = step_id
- self.flow_state.step_name = step_data.name
- self.flow_state.status = StepStatus.RUNNING
-
- # 查找下一个步骤;判断是否是end前最后一步
- next_step = await self._get_next_step()
- is_final_answer = next_step == "end"
-
- # 获取并验证Call类
- node_id = step_data.node
- call_cls = await self._get_call_cls(node_id)
+ step = self.flow.steps[step_id]
# 准备系统变量
sys_vars = CallVars(
question=self.question,
- task_id=self.task.record.task_id,
+ task_id=self.task.id,
flow_id=self.post_body_app.flow_id,
- session_id=self.task.session_id,
- history=self.task.flow_context,
- summary=self.flow_state.ai_summary,
- user_sub=self.task.user_sub,
+ session_id=self.task.ids.session_id,
+ history=self.task.context,
+ summary=self.task.runtime.summary,
+ user_sub=self.task.ids.user_sub,
+ service_id=step.params.get("service_id", ""),
+ )
+ step_runner = StepExecutor(
+ queue=self.queue,
+ task=self.task,
+ flow=self.flow,
+ sys_vars=sys_vars,
+ executor_background=self.executor_background,
+ question=self.question,
)
- # 初始化Call
- call_obj = call_cls.model_validate(step_data.params)
- input_data = await call_obj.init(sys_vars)
-
- # TODO: 处理slots
- # can_continue, slot_data = await self._process_slots(call_obj)
- # if not self._can_continue:
- # return
-
- # 执行Call并获取结果
- self.task = await push_step_input(self.task, self.queue, self.flow_state, input_data)
- result_data = await self._execute_call(call_obj, is_final_answer=is_final_answer)
- self.task = await push_step_output(self.task, self.queue, self.flow_state, result_data)
-
- # 更新下一步
- self.flow_state.step_id = next_step
+ # 运行Step
+ call_id, call_obj = await step_runner.init_step(step_id)
+ # 尝试填参
+ input_data = await step_runner.fill_slots(call_obj)
+ # 运行
+ await step_runner.run_step(call_id, call_obj, input_data)
+ # 更新Task
+ self.task = step_runner.task
+ await TaskManager.save_task(self.task.id, self.task)
- async def _get_next_step(self) -> str:
+ async def _find_flow_next(self) -> str:
"""在当前步骤执行前,尝试获取下一步"""
+ if not self.task.state:
+ err = "[StepExecutor] 当前ExecutorState为空"
+ logger.error(err)
+ raise ValueError(err)
+
# 如果当前步骤为结束,则直接返回
- if self.flow_state.step_id == "end" or not self.flow_state.step_id:
+ if self.task.state.step_id == "end" or not self.task.state.step_id:
# 如果是最后一步,设置停止标志
- self._can_continue = False
- return ""
-
- if self.flow.steps[self.flow_state.step_id].node == "Choice":
- # 如果是选择步骤,那么现在还不知道下一步是谁,直接返回
+ self._reached_end = True
return ""
next_steps = []
# 遍历Edges,查找下一个节点
for edge in self.flow.edges:
- if edge.edge_from == self.flow_state.step_id:
+ if edge.edge_from == self.task.state.step_id:
next_steps += [edge.edge_to]
# 如果step没有任何出边,直接跳到end
if not next_steps:
return "end"
- # FIXME: 目前只使用第一个出边
+ # TODO: 目前只使用第一个出边
logger.info("[FlowExecutor] 下一步 %s", next_steps[0])
return next_steps[0]
-
async def run(self) -> None:
- """运行流,返回各步骤结果,直到无法继续执行
+ """
+ 运行流,返回各步骤结果,直到无法继续执行
数据通过向Queue发送消息的方式传输
"""
- task_actor = ray.get_actor("task")
+ if not self.task.state:
+ err = "[FlowExecutor] 当前ExecutorState为空"
+ logger.error(err)
+ raise ValueError(err)
+
logger.info("[FlowExecutor] 运行工作流")
# 推送Flow开始消息
- self.task = await push_flow_start(self.task, self.queue, self.flow_state, self.question)
+ await self.push_message(EventType.FLOW_START)
+
+ # 进行开始前的系统步骤
+ for step_id, step in FIXED_STEPS_BEFORE_START.items():
+ self.flow.steps[step_id] = step
+ await self._invoke_runner(step_id)
+ self.task.state.step_id = "start"
+ self.task.state.step_name = "开始"
# 如果允许继续运行Flow
- while self._can_continue:
+ while not self._reached_end:
# Flow定义中找不到step
- if not self.flow_state.step_id or (self.flow_state.step_id not in self.flow.steps):
- logger.error("[FlowExecutor] 当前步骤 %s 不存在", self.flow_state.step_id)
- self.flow_state.status = StepStatus.ERROR
+ if not self.task.state.step_id or (self.task.state.step_id not in self.flow.steps):
+ logger.error("[FlowExecutor] 当前步骤 %s 不存在", self.task.state.step_id)
+ self.task.state.status = StepStatus.ERROR
- if self.flow_state.status == StepStatus.ERROR:
+ if self.task.state.status == StepStatus.ERROR:
# 执行错误处理步骤
logger.warning("[FlowExecutor] Executor出错,执行错误处理步骤")
- step = self.flow.on_error
- await self._run_error(step)
+ self.flow.steps["_error"] = ERROR_STEP
+ await self._invoke_runner("_error")
else:
# 执行正常步骤
- step = self.flow.steps[self.flow_state.step_id]
- await self._run_step(self.flow_state.step_id, step)
+ step = self.flow.steps[self.task.state.step_id]
+ await self._invoke_runner(self.task.state.step_id)
# 步骤结束,更新全局的Task
- await task_actor.set_task.remote(self.task.record.task_id, self.task)
+ await TaskManager.save_task(self.task.id, self.task)
+
+ # 查找下一个节点
+ next_step_id = await self._find_flow_next()
+ if next_step_id:
+ self.task.state.step_id = next_step_id
+ self.task.state.step_name = self.flow.steps[next_step_id].name
+ else:
+ # 如果没有下一个节点,设置结束标志
+ self._reached_end = True
+
+ # 运行结束后的系统步骤
+ for step_id, step in FIXED_STEPS_AFTER_END.items():
+ self.flow.steps[step_id] = step
+ await self._invoke_runner(step_id)
# 推送Flow停止消息
- self.task = await push_flow_stop(self.task, self.queue, self.flow_state, self.flow, self._final_answer)
+ await self.push_message(EventType.FLOW_STOP)
# 更新全局的Task
- await task_actor.set_task.remote(self.task.record.task_id, self.task)
+ await TaskManager.save_task(self.task.id, self.task)
diff --git a/apps/scheduler/executor/message.py b/apps/scheduler/executor/message.py
deleted file mode 100644
index f387dc5ffd352d69ebaa997941c55534e8de1cd4..0000000000000000000000000000000000000000
--- a/apps/scheduler/executor/message.py
+++ /dev/null
@@ -1,131 +0,0 @@
-"""FlowExecutor的消息推送
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
-"""
-from datetime import datetime, timezone
-from typing import Any
-
-from ray import actor
-
-from apps.entities.enum_var import EventType, FlowOutputType
-from apps.entities.flow import Flow
-from apps.entities.message import (
- FlowStartContent,
- FlowStopContent,
- TextAddContent,
-)
-from apps.entities.task import (
- ExecutorState,
- FlowStepHistory,
- TaskBlock,
-)
-
-
-# FIXME: 临时使用截断规避前端问题
-def truncate_data(data: Any) -> Any:
- """截断数据"""
- if isinstance(data, dict):
- return {k: truncate_data(v) for k, v in data.items()}
- if isinstance(data, list):
- return [truncate_data(v) for v in data]
- if isinstance(data, str):
- if len(data) > 300:
- return data[:300] + "..."
- return data
- return data
-
-
-async def push_step_input(task: TaskBlock, queue: actor.ActorHandle, state: ExecutorState, input_data: dict[str, Any]) -> TaskBlock:
- """推送步骤输入"""
- # 更新State
- task.flow_state = state
- # 更新FlowContext
- task.flow_context[state.step_id] = FlowStepHistory(
- task_id=task.record.task_id,
- flow_id=state.flow_id,
- step_id=state.step_id,
- status=state.status,
- input_data=state.filled_data,
- output_data={},
- )
- # 步骤开始,重置时间
- task.record.metadata.time_cost = round(datetime.now(timezone.utc).timestamp(), 2)
-
- # 推送消息
- await queue.push_output.remote(task, event_type=EventType.STEP_INPUT, data=truncate_data(input_data)) # type: ignore[attr-defined]
- return task
-
-
-async def push_step_output(task: TaskBlock, queue: actor.ActorHandle, state: ExecutorState, output: dict[str, Any]) -> TaskBlock:
- """推送步骤输出"""
- # 更新State
- task.flow_state = state
-
- # 更新FlowContext
- task.flow_context[state.step_id].output_data = output
- task.flow_context[state.step_id].status = state.status
-
- # FlowContext加入Record
- task.new_context.append(task.flow_context[state.step_id].id)
-
- # 推送消息
- await queue.push_output.remote(task, event_type=EventType.STEP_OUTPUT, data=truncate_data(output)) # type: ignore[attr-defined]
- return task
-
-
-async def push_flow_start(task: TaskBlock, queue: actor.ActorHandle, state: ExecutorState, question: str) -> TaskBlock:
- """推送Flow开始"""
- # 设置state
- task.flow_state = state
-
- # 组装消息
- content = FlowStartContent(
- question=question,
- params=state.filled_data,
- )
- # 推送消息
- await queue.push_output.remote(task, event_type=EventType.FLOW_START, data=content.model_dump(exclude_none=True, by_alias=True)) # type: ignore[attr-defined]
- return task
-
-
-async def assemble_flow_stop_content(state: ExecutorState, flow: Flow) -> FlowStopContent:
- """组装Flow结束消息"""
- if state.remaining_schema:
- # 如果当前Flow是填充步骤,则推送Schema
- content = FlowStopContent(
- type=FlowOutputType.SCHEMA,
- data=state.remaining_schema,
- )
- else:
- content = FlowStopContent()
-
- return content
-
- # elif call_type == "render":
- # # 如果当前Flow是图表,则推送Chart
- # chart_option = task.flow_context[state.step_id].output_data["output"]
- # content = FlowStopContent(
- # type=FlowOutputType.CHART,
- # data=chart_option,
- # )
-
-async def push_flow_stop(task: TaskBlock, queue: actor.ActorHandle, state: ExecutorState, flow: Flow, final_answer: str) -> TaskBlock:
- """推送Flow结束"""
- # 设置state
- task.flow_state = state
- # 保存最终输出
- task.record.content.answer = final_answer
- content = await assemble_flow_stop_content(state, flow)
-
- # 推送Stop消息
- await queue.push_output.remote(task, event_type=EventType.FLOW_STOP, data=content.model_dump(exclude_none=True, by_alias=True)) # type: ignore[attr-defined]
- return task
-
-
-async def push_text_output(task: TaskBlock, queue: actor.ActorHandle, text: str) -> None:
- """推送文本输出"""
- content = TextAddContent(
- text=text,
- )
- # 推送消息
- await queue.push_output.remote(task, event_type=EventType.TEXT_ADD, data=content.model_dump(exclude_none=True, by_alias=True)) # type: ignore[attr-defined]
diff --git a/apps/scheduler/executor/node.py b/apps/scheduler/executor/node.py
new file mode 100644
index 0000000000000000000000000000000000000000..be5afa27e269389ac3a6ca168c4a4dad8cf37c7d
--- /dev/null
+++ b/apps/scheduler/executor/node.py
@@ -0,0 +1,50 @@
+"""Node相关的函数,含Node转换为Call"""
+
+import inspect
+from typing import Any
+
+from apps.scheduler.call.core import CoreCall
+from apps.scheduler.call.empty import Empty
+from apps.scheduler.call.facts.facts import FactsCall
+from apps.scheduler.call.slot.slot import Slot
+from apps.scheduler.call.summary.summary import Summary
+from apps.scheduler.pool.pool import Pool
+
+
+class StepNode:
+ """Executor的Node相关逻辑"""
+
+ @staticmethod
+ async def check_cls(call_cls: Any) -> bool:
+ """检查Call是否符合标准要求"""
+ flag = True
+ if not hasattr(call_cls, "name") or not isinstance(call_cls.name, str):
+ flag = False
+ if not hasattr(call_cls, "description") or not isinstance(call_cls.description, str):
+ flag = False
+ if not hasattr(call_cls, "exec") or not inspect.isasyncgenfunction(call_cls.exec):
+ flag = False
+ return flag
+
+ @staticmethod
+ async def get_call_cls(call_id: str) -> type[CoreCall]:
+ """获取并验证Call类"""
+ # 特判,用于处理隐藏节点
+ if call_id == "Empty":
+ return Empty
+ if call_id == "Summary":
+ return Summary
+ if call_id == "Facts":
+ return FactsCall
+ if call_id == "Slot":
+ return Slot
+
+ # 从Pool中获取对应的Call
+ call_cls: type[CoreCall] = await Pool().get_call(call_id)
+
+ # 检查Call合法性
+ if not await StepNode.check_cls(call_cls):
+ err = f"[FlowExecutor] 工具 {call_id} 不符合Call标准要求"
+ raise ValueError(err)
+
+ return call_cls
diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5674b4e01e376167d937ca41489e9a807278245
--- /dev/null
+++ b/apps/scheduler/executor/step.py
@@ -0,0 +1,192 @@
+"""工作流中步骤相关函数"""
+
+import logging
+from typing import Any
+
+from pydantic import ConfigDict
+
+from apps.common.queue import MessageQueue
+from apps.entities.enum_var import EventType, StepStatus
+from apps.entities.scheduler import CallOutputChunk, CallVars
+from apps.entities.task import FlowStepHistory, Task
+from apps.manager.node import NodeManager
+from apps.manager.task import TaskManager
+from apps.scheduler.call.core import CoreCall
+from apps.scheduler.call.slot.schema import SlotInput, SlotOutput
+from apps.scheduler.call.slot.slot import Slot
+from apps.scheduler.executor.base import BaseExecutor
+from apps.scheduler.executor.node import StepNode
+
+logger = logging.getLogger(__name__)
+SPECIAL_EVENT_TYPES = [
+ EventType.GRAPH,
+ EventType.SUGGEST,
+ EventType.TEXT_ADD,
+]
+
+
+class StepExecutor(BaseExecutor):
+ """工作流中步骤相关函数"""
+
+ sys_vars: CallVars
+
+ model_config = ConfigDict(
+ arbitrary_types_allowed=True,
+ extra="allow",
+ )
+
+ def __init__(self, **kwargs: Any) -> None:
+ """初始化"""
+ super().__init__(**kwargs)
+ if not self.task.state:
+ err = "[StepExecutor] 当前ExecutorState为空"
+ logger.error(err)
+ raise ValueError(err)
+
+ self.step_history = FlowStepHistory(
+ task_id=self.sys_vars.task_id,
+ flow_id=self.sys_vars.flow_id,
+ step_id=self.task.state.step_id,
+ status=self.task.state.status,
+ input_data={},
+ output_data={},
+ )
+
+ async def _push_call_output(
+ self,
+ task: Task,
+ queue: MessageQueue,
+ msg_type: EventType,
+ data: dict[str, Any],
+ ) -> None:
+ """推送输出"""
+ # 如果不是特殊类型
+ if msg_type not in SPECIAL_EVENT_TYPES:
+ err = f"[StepExecutor] 不支持的事件类型: {msg_type}"
+ logger.error(err)
+ raise ValueError(err)
+
+ # 推送消息
+ await queue.push_output(
+ task,
+ event_type=msg_type,
+ data=data,
+ )
+
+ async def init_step(self, step_id: str) -> tuple[str, Any]:
+ """初始化步骤"""
+ if not self.task.state:
+ err = "[StepExecutor] 当前ExecutorState为空"
+ logger.error(err)
+ raise ValueError(err)
+
+ logger.info("[FlowExecutor] 运行步骤 %s", self.flow.steps[step_id].name)
+
+ # State写入ID和运行状态
+ self.task.state.step_id = step_id
+ self.task.state.step_name = self.flow.steps[step_id].name
+ self.task.state.status = StepStatus.RUNNING
+ await TaskManager.save_task(self.task.id, self.task)
+
+ # 获取并验证Call类
+ node_id = self.flow.steps[step_id].node
+ # 获取node详情并存储
+ try:
+ self._node_data = await NodeManager.get_node(node_id)
+ except Exception:
+ logger.exception("[StepExecutor] 获取Node失败,为内置Node或ID不存在")
+ self._node_data = None
+
+ if self._node_data:
+ call_cls = await StepNode.get_call_cls(self._node_data.call_id)
+ call_id = self._node_data.call_id
+ else:
+ # 可能是特殊的内置Node
+ call_cls = await StepNode.get_call_cls(node_id)
+ call_id = node_id
+
+ # 初始化Call Class,用户参数会覆盖node的参数
+ params: dict[str, Any] = (
+ self._node_data.known_params if self._node_data and self._node_data.known_params else {}
+ ) # type: ignore[union-attr]
+ if self.flow.steps[step_id].params:
+ params.update(self.flow.steps[step_id].params)
+
+ # 检查初始化Call Class是否需要特殊处理
+ try:
+ # TODO
+ call_obj = call_cls.init(params)
+ except Exception:
+ logger.exception("[StepExecutor] 初始化Call失败")
+ raise
+ else:
+ return call_id, call_obj
+
+ async def fill_slots(self, call_obj: Any) -> dict[str, Any]:
+ """执行Call并处理结果"""
+ if not self.task.state:
+ err = "[StepExecutor] 当前ExecutorState为空"
+ logger.error(err)
+ raise ValueError(err)
+
+ # 尝试初始化call_obj
+ try:
+ input_data = await call_obj.init(self.sys_vars)
+ except Exception:
+ logger.exception("[StepExecutor] 初始化Call失败")
+ raise
+
+ # 合并已经填充的参数
+ input_data.update(self.task.runtime.filled)
+
+ # 检查输入参数
+ input_schema = (
+ call_obj.input_type.model_json_schema(override=self._node_data.override_input) if self._node_data else {}
+ )
+
+ # 检查输入参数是否需要填充
+ slot_obj = Slot(
+ data=input_data,
+ current_schema=input_schema,
+ summary=self.task.runtime.summary,
+ facts=self.task.runtime.facts,
+ )
+ slot_input = SlotInput(**await slot_obj.init(self.sys_vars))
+
+ # 若需要填参,额外执行一步
+ if slot_input.remaining_schema:
+ output = SlotOutput(**await self.run_step("Slot", slot_obj, slot_input.remaining_schema))
+ self.task.runtime.filled = output.slot_data
+
+ # 如果还有剩余参数,则中断
+ if output.remaining_schema:
+ # 更新状态
+ self.task.state.status = StepStatus.PARAM
+ self.task.state.slot = output.remaining_schema
+ err = "[StepExecutor] 需要填参"
+ logger.error(err)
+ raise ValueError(err)
+
+ return output.slot_data
+ return input_data
+
+ # TODO: 需要评估如何进行流式输出
+ async def run_step(self, call_id: str, call_obj: CoreCall, input_data: dict[str, Any]) -> dict[str, Any]:
+ """运行单个步骤"""
+ iterator = call_obj.exec(self, input_data)
+ await self.push_message(EventType.STEP_INPUT, input_data)
+
+ result: CallOutputChunk | None = None
+ async for chunk in iterator:
+ result = chunk
+
+ if result is None:
+ content = {}
+ else:
+ content = result.content if isinstance(result.content, dict) else {"message": result.content}
+
+ # 更新执行状态
+ self.task.state.status = StepStatus.SUCCESS
+ await self.push_message(EventType.STEP_OUTPUT, content)
+
+ return content