diff --git a/.gitee/ISSUE_TEMPLATE/bug.yaml b/.gitee/ISSUE_TEMPLATE/bug.yaml deleted file mode 100644 index e15737b31be1a6aeb36af4c20464a2951c2d64bd..0000000000000000000000000000000000000000 --- a/.gitee/ISSUE_TEMPLATE/bug.yaml +++ /dev/null @@ -1,85 +0,0 @@ -name: Bug 反馈 -description: 当你在代码中发现了一个 Bug,导致应用崩溃或抛出异常,或者有一个组件存在问题,或者某些地方看起来不对劲。 -title: "[Bug]: " -labels: ["bug"] -body: - - type: markdown - attributes: - value: | - 感谢对项目的支持与关注。在提出问题之前,请确保你已查看相关开发或使用文档: - - https://gitee.com/openeuler/euler-copilot-framework/tree/master/docs - - type: checkboxes - attributes: - label: 这个问题是否已经存在? - options: - - label: 我已经搜索过现有的问题 (https://gitee.com/openeuler/euler-copilot-framework/issues) - required: true - - type: textarea - attributes: - label: 如何复现 - description: 请详细告诉我们如何复现你遇到的问题,如涉及代码,可提供一个最小代码示例,并使用反引号```附上它 - placeholder: | - 1. ... - 2. ... - 3. ... - validations: - required: true - - type: textarea - attributes: - label: 预期结果 - description: 请告诉我们你预期会发生什么。 - validations: - required: true - - type: textarea - attributes: - label: 实际结果 - description: 请告诉我们实际发生了什么。 - validations: - required: true - - type: textarea - attributes: - label: 截图或视频 - description: 如果可以的话,上传任何关于 bug 的截图。 - value: | - [在这里上传图片] - - type: dropdown - id: version - attributes: - label: 版本 - description: 你当前正在使用我们软件的哪个版本/分支? - options: - - "0.4.2" - - "0.9.1" - - "0.9.3" - - "0.9.4 (默认)" - - "0.9.5 (预发布)" - validations: - required: true - - type: dropdown - id: os - attributes: - label: 操作系统 - description: 你使用的操作系统是什么? - options: - - "openEuler 22.03 LTS" - - "openEuler 22.03 LTS SP1" - - "openEuler 22.03 LTS SP2" - - "openEuler 22.03 LTS SP3" - - "openEuler 22.03 LTS SP4 (默认)" - - "openEuler 24.03 LTS" - - "openEuler 24.03 LTS SP1" - - "openEuler 25.03" - - "其他" - validations: - required: true - - type: textarea - attributes: - label: 硬件环境 - description: 请提供有关你的硬件环境的信息,例如处理器类型、内存大小等。 - placeholder: | - 1. 处理器: - 2. 内存大小: - 3. 大模型类型: - 4. 推理框架: - 5. GPU/NPU 型号: - 6. 其他: \ No newline at end of file diff --git a/.gitee/ISSUE_TEMPLATE/config.yaml b/.gitee/ISSUE_TEMPLATE/config.yaml deleted file mode 100644 index ec4bb386bcf8a4946923eff961cc7cdf70c0aedf..0000000000000000000000000000000000000000 --- a/.gitee/ISSUE_TEMPLATE/config.yaml +++ /dev/null @@ -1 +0,0 @@ -blank_issues_enabled: false \ No newline at end of file diff --git a/apps/entities/request_data.py b/apps/entities/request_data.py index ec5a9b03cf7541346951fd9453106c56058dd5c8..89e94ae47e969e869e35f7dc596384efdceb9e10 100644 --- a/apps/entities/request_data.py +++ b/apps/entities/request_data.py @@ -9,8 +9,8 @@ from typing import Any from pydantic import BaseModel, Field from apps.common.config import Config -from apps.entities.enum_var import CommentType from apps.entities.appcenter import AppData +from apps.entities.enum_var import CommentType from apps.entities.flow_topology import FlowItem @@ -130,7 +130,7 @@ class AddCommentData(BaseModel): record_id: str group_id: str comment: CommentType - dislike_reason: list[str] = Field(default=[], max_length=10) + dislike_reason: str = Field(default="", max_length=200) reason_link: str = Field(default="", max_length=200) reason_description: str = Field(default="", max_length=500) diff --git a/apps/entities/task.py b/apps/entities/task.py index 7f2b1b6d290f9575e978c50db2b3fac2b83b951e..e4eedbbb8e6defec047cd81b2cd68fd77cf2b40f 100644 --- a/apps/entities/task.py +++ b/apps/entities/task.py @@ -27,6 +27,7 @@ class FlowStepHistory(BaseModel): flow_name: str = Field(description="Flow名称") step_id: str = Field(description="当前步骤名称") step_name: str = Field(description="当前步骤名称") + step_description: str = Field(description="当前步骤描述") status: StepStatus = Field(description="当前步骤状态") input_data: dict[str, Any] = Field(description="当前Step执行的输入", default={}) output_data: dict[str, Any] = Field(description="当前Step执行后的结果", default={}) diff --git a/apps/manager/task.py b/apps/manager/task.py index 5b11dc21606e67c055bc03e993077cb0bf6bf2d2..e784bf6962529130a6dee7973bbd7268a26ee19d 100644 --- a/apps/manager/task.py +++ b/apps/manager/task.py @@ -115,7 +115,7 @@ class TaskManager: ).sort( "created_at", -1, ).limit(length): - flow_context.extend(history) + flow_context += [history] except Exception: logger.exception("[TaskManager] 获取task_id的flow信息失败") return [] diff --git a/apps/routers/comment.py b/apps/routers/comment.py index d7d368058bfb19118aa5c12aff17ed973eabd102..781674c286ee60ddc30029f583d5d4ec592a4de3 100644 --- a/apps/routers/comment.py +++ b/apps/routers/comment.py @@ -41,7 +41,7 @@ async def add_comment(post_body: AddCommentData, user_sub: Annotated[str, Depend comment_data = RecordComment( comment=post_body.comment, - dislike_reason=post_body.dislike_reason, + dislike_reason=post_body.dislike_reason.split(";")[:-1], reason_link=post_body.reason_link, reason_description=post_body.reason_description, feedback_time=round(datetime.now(tz=UTC).timestamp(), 3), diff --git a/apps/scheduler/call/__init__.py b/apps/scheduler/call/__init__.py index 924c425149cab89c231c60c0d9c449433aa9c09a..5c88cd2d98f36de96cbcd030133cb13cf015ebf9 100644 --- a/apps/scheduler/call/__init__.py +++ b/apps/scheduler/call/__init__.py @@ -5,7 +5,6 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ from apps.scheduler.call.api.api import API -from apps.scheduler.call.convert.convert import Convert from apps.scheduler.call.graph.graph import Graph from apps.scheduler.call.llm.llm import LLM from apps.scheduler.call.rag.rag import RAG @@ -18,7 +17,6 @@ __all__ = [ "LLM", "RAG", "SQL", - "Convert", "Graph", "Suggestion", ] diff --git a/apps/scheduler/call/api/api.py b/apps/scheduler/call/api/api.py index d4db20faf6355bc3cf16ecff29c51dab2c58a672..550f1edf185beca6049cf9e6a5b213344266235f 100644 --- a/apps/scheduler/call/api/api.py +++ b/apps/scheduler/call/api/api.py @@ -10,8 +10,10 @@ from collections.abc import AsyncGenerator from typing import Any import aiohttp +from aiohttp.client import _RequestContextManager from fastapi import status from pydantic import Field +from pydantic.json_schema import SkipJsonSchema from apps.common.oidc import oidc_provider from apps.entities.enum_var import CallOutputType, ContentType, HTTPMethod @@ -50,23 +52,21 @@ SUCCESS_HTTP_CODES = [ class API(CoreCall, input_model=APIInput, output_model=APIOutput): """API调用工具""" - enable_filling: bool = Field(description="是否需要进行自动参数填充", default=True) + enable_filling: SkipJsonSchema[bool] = Field(description="是否需要进行自动参数填充", default=True) url: str = Field(description="API接口的完整URL") method: HTTPMethod = Field(description="API接口的HTTP Method") - content_type: ContentType = Field(description="API接口的Content-Type") + content_type: ContentType | None = Field(description="API接口的Content-Type", default=None) timeout: int = Field(description="工具超时时间", default=300, gt=30) body: dict[str, Any] = Field(description="已知的部分请求体", default={}) query: dict[str, Any] = Field(description="已知的部分请求参数", default={}) - @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" return CallInfo(name="API调用", description="向某一个API接口发送HTTP请求,获取数据。") - async def _init(self, call_vars: CallVars) -> APIInput: """初始化API调用工具""" self._service_id = "" @@ -77,7 +77,8 @@ class API(CoreCall, input_model=APIInput, output_model=APIOutput): # 获取对应API的Service Metadata try: service_metadata = await ServiceCenterManager.get_service_metadata( - call_vars.ids.user_sub, self.node.service_id or "", + call_vars.ids.user_sub, + self.node.service_id or "", ) except Exception as e: raise CallError( @@ -96,7 +97,6 @@ class API(CoreCall, input_model=APIInput, output_model=APIOutput): body=self.body, ) - async def _exec(self, input_data: dict[str, Any]) -> AsyncGenerator[CallOutputChunk, None]: """调用API,然后返回LLM解析后的数据""" self._session = aiohttp.ClientSession() @@ -111,84 +111,125 @@ class API(CoreCall, input_model=APIInput, output_model=APIOutput): finally: await self._session.close() - - async def _make_api_call(self, data: APIInput, files: aiohttp.FormData): # noqa: ANN202, C901 + async def _make_api_call(self, data: APIInput, files: aiohttp.FormData) -> _RequestContextManager: """组装API请求Session""" # 获取必要参数 - req_header = { - "Content-Type": self.content_type.value, - } + if self._auth: + req_header, req_cookie, req_params = await self._apply_auth() + else: + req_header = {} + req_cookie = {} + req_params = {} + + if self.method in [HTTPMethod.GET.value, HTTPMethod.DELETE.value]: + return await self._handle_query_request(data, req_header, req_cookie, req_params) + + if self.method in [HTTPMethod.POST.value, HTTPMethod.PUT.value, HTTPMethod.PATCH.value]: + return await self._handle_body_request(data, files, req_header, req_cookie) + + raise CallError( + message="API接口的HTTP Method不支持", + data={}, + ) + + async def _apply_auth(self) -> tuple[dict[str, str], dict[str, str], dict[str, str]]: + """应用认证信息到请求参数中""" + # self._auth可能是None或ServiceApiAuth类型 + # ServiceApiAuth类型包含header、cookie、query和oidc属性 + req_header = {} req_cookie = {} req_params = {} - if self._auth: - if self._auth.header: - for item in self._auth.header: - req_header[item.name] = item.value - elif self._auth.cookie: - for item in self._auth.cookie: - req_cookie[item.name] = item.value - elif self._auth.query: - for item in self._auth.query: - req_params[item.name] = item.value - elif self._auth.oidc: - token = await TokenManager.get_plugin_token( - self._service_id, - self._session_id, - await oidc_provider.get_access_token_url(), - 30, - ) - req_header.update({"access-token": token}) - - if self.method in [ - HTTPMethod.GET.value, - HTTPMethod.DELETE.value, - ]: - req_params.update(data.query) - return self._session.request( - self.method, - self.url, - params=req_params, - headers=req_header, - cookies=req_cookie, - timeout=self._timeout, + if self._auth.header: # type: ignore[attr-defined] # 如果header列表非空 + for item in self._auth.header: # type: ignore[attr-defined] + req_header[item.name] = item.value + elif self._auth.cookie: # type: ignore[attr-defined] # 如果cookie列表非空 + for item in self._auth.cookie: # type: ignore[attr-defined] + req_cookie[item.name] = item.value + elif self._auth.query: # type: ignore[attr-defined] # 如果query列表非空 + for item in self._auth.query: # type: ignore[attr-defined] + req_params[item.name] = item.value + elif self._auth.oidc: # type: ignore[attr-defined] # 如果oidc配置存在 + token = await TokenManager.get_plugin_token( + self._service_id, + self._session_id, + await oidc_provider.get_access_token_url(), + 30, ) + req_header.update({"access-token": token}) + + return req_header, req_cookie, req_params + + async def _handle_query_request( + self, data: APIInput, req_header: dict[str, str], req_cookie: dict[str, str], req_params: dict[str, str], + ) -> _RequestContextManager: + """处理GET和DELETE请求""" + req_params.update(data.query) + return self._session.request( + self.method, + self.url, + params=req_params, + headers=req_header, + cookies=req_cookie, + timeout=self._timeout, + ) - if self.method in [ - HTTPMethod.POST.value, - HTTPMethod.PUT.value, - HTTPMethod.PATCH.value, - ]: - req_body = data.body - if self.content_type in [ - ContentType.FORM_URLENCODED.value, - ContentType.MULTIPART_FORM_DATA.value, - ]: - form_data = files - for key, val in req_body.items(): - form_data.add_field(key, val) - return self._session.request( - self.method, - self.url, - data=form_data, - headers=req_header, - cookies=req_cookie, - timeout=self._timeout, - ) - return self._session.request( - self.method, - self.url, - json=req_body, - headers=req_header, - cookies=req_cookie, - timeout=self._timeout, + async def _handle_body_request( + self, data: APIInput, files: aiohttp.FormData, req_header: dict[str, str], req_cookie: dict[str, str], + ) -> _RequestContextManager: + """处理POST、PUT和PATCH请求""" + if not self.content_type: + raise CallError( + message="API接口的Content-Type未指定", + data={}, ) + req_body = data.body + + if self.content_type in [ContentType.FORM_URLENCODED.value, ContentType.MULTIPART_FORM_DATA.value]: + return await self._handle_form_request(req_body, files, req_header, req_cookie) + + if self.content_type == ContentType.JSON.value: + return await self._handle_json_request(req_body, req_header, req_cookie) + raise CallError( - message="API接口的HTTP Method不支持", + message="API接口的Content-Type不支持", data={}, ) + async def _handle_form_request( + self, + req_body: dict[str, Any], + form_data: aiohttp.FormData, + req_header: dict[str, str], + req_cookie: dict[str, str], + ) -> _RequestContextManager: + """处理表单类型的请求""" + for key, val in req_body.items(): + form_data.add_field(key, val) + + return self._session.request( + self.method, + self.url, + data=form_data, + headers=req_header, + cookies=req_cookie, + timeout=self._timeout, + ) + + async def _handle_json_request( + self, req_body: dict[str, Any], req_header: dict[str, str], req_cookie: dict[str, str], + ) -> _RequestContextManager: + """处理JSON类型的请求""" + return self._session.request( + self.method, + self.url, + json=req_body, + headers=req_header, + cookies=req_cookie, + timeout=self._timeout, + ) + async def _call_api(self, final_data: APIInput) -> APIOutput: """实际调用API,并处理返回值""" # 获取必要参数 diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index 6c835f9c5bab4f7a6ad57c9d0b26979bc585e7e7..74c51a78f18fdfe9a2b47d9e968e0d9897f82386 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -50,6 +50,8 @@ class CoreCall(BaseModel): name: SkipJsonSchema[str] = Field(description="Step的名称", exclude=True) description: SkipJsonSchema[str] = Field(description="Step的描述", exclude=True) node: SkipJsonSchema[NodePool | None] = Field(description="节点信息", exclude=True) + enable_filling: SkipJsonSchema[bool] = Field(description="是否需要进行自动参数填充", default=False, exclude=True) + tokens: SkipJsonSchema[CallTokens] = Field(description="Call的Tokens", default=CallTokens(), exclude=True) input_model: ClassVar[SkipJsonSchema[type[DataBase]]] = Field( description="Call的输入Pydantic类型;不包含override的模板", exclude=True, @@ -62,9 +64,6 @@ class CoreCall(BaseModel): ) to_user: bool = Field(description="是否需要将输出返回给用户", default=False) - enable_filling: bool = Field(description="是否需要进行自动参数填充", default=False) - tokens: CallTokens = Field(description="Call的Tokens", default=CallTokens()) - model_config = ConfigDict( arbitrary_types_allowed=True, diff --git a/apps/scheduler/call/llm/llm.py b/apps/scheduler/call/llm/llm.py index a6141a6ce5b7201283d8dfc00c24a5407715376f..ef7f618a161d131cc3ce77ee14027fcae0f08868 100644 --- a/apps/scheduler/call/llm/llm.py +++ b/apps/scheduler/call/llm/llm.py @@ -58,7 +58,10 @@ class LLM(CoreCall, input_model=LLMInput, output_model=LLMOutput): ) # 上下文信息 - step_history = list(call_vars.history.values())[-self.step_history_size :] + step_history = [] + for ids in call_vars.history_order[-self.step_history_size:]: + step_history += [call_vars.history[ids]] + if self.enable_context: context_tmpl = env.from_string(LLM_CONTEXT_PROMPT) context_prompt = context_tmpl.render( diff --git a/apps/scheduler/call/llm/schema.py b/apps/scheduler/call/llm/schema.py index 82031c485e30e7b76a752568ed2384e9f64e3d8a..64fe93af6be31953b6e692afb5855f0f0581588d 100644 --- a/apps/scheduler/call/llm/schema.py +++ b/apps/scheduler/call/llm/schema.py @@ -8,14 +8,20 @@ from apps.scheduler.call.core import DataBase LLM_CONTEXT_PROMPT = dedent( r""" - 以下是对用户和AI间对话的简短总结: - {{ summary }} + 以下是对用户和AI间对话的简短总结,在中给出: + + {{ summary }} + - 你作为AI,在回答用户的问题前,需要获取必要信息。为此,你调用了以下工具,并获得了输出: + 你作为AI,在回答用户的问题前,需要获取必要的信息。为此,你调用了一些工具,并获得了它们的输出: + 工具的输出数据将在中给出, 其中为工具的名称,为工具的输出数据。 {% for tool in history_data %} - {{ tool.step_id }} - {{ tool.output_data }} + + {{ tool.step_name }} + {{ tool.step_description }} + {{ tool.output_data }} + {% endfor %} """, diff --git a/apps/scheduler/call/slot/schema.py b/apps/scheduler/call/slot/schema.py index 08373cd962676bdf82bb3ce9cd9a8d0f9dee3f86..31fc9abaa4d43c825984ee420a75c80aaa287c19 100644 --- a/apps/scheduler/call/slot/schema.py +++ b/apps/scheduler/call/slot/schema.py @@ -5,14 +5,91 @@ from pydantic import Field from apps.scheduler.call.core import DataBase -SLOT_GEN_PROMPT = [ - {"role": "user", "content": ""}, - {"role": "assistant", "content": r""" +SLOT_GEN_PROMPT = r""" + + 你是一个可以使用工具的AI助手,正尝试使用工具来完成任务。 + 目前,你正在生成一个JSON参数对象,以作为调用工具的输入。 + 请根据用户输入、背景信息、工具信息和JSON Schema内容,生成符合要求的JSON对象。 + 背景信息将在中给出,工具信息将在中给出,JSON Schema将在中给出,\ + 用户的问题将在中给出。 + 请在中输出生成的JSON对象。 - """, - }, -] + 要求: + 1. 严格按照JSON Schema描述的JSON格式输出,不要编造不存在的字段。 + 2. JSON字段的值优先使用用户输入中的内容。如果用户输入没有,则使用背景信息中的内容。 + 3. 只输出JSON对象,不要输出任何解释说明,不要输出任何其他内容。 + 4. 如果JSON Schema中描述的JSON字段是可选的,则可以不输出该字段。 + 5. example中仅为示例,不要照搬example中的内容,不要将example中的内容作为输出。 + + + + + 用户询问杭州今天的天气情况。AI回复杭州今天晴,温度20℃。用户询问杭州明天的天气情况。 + + + 杭州明天的天气情况如何? + + + 工具名称:check_weather + 工具描述:查询指定城市的天气信息 + + + { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "城市名称" + }, + "date": { + "type": "string", + "description": "查询日期" + }, + "required": ["city", "date"] + } + } + + + { + "city": "杭州", + "date": "明天" + } + + + + + 以下是对用户给你的任务的历史总结,在中给出: + + {{summary}} + + 附加的条目化信息: + {{ facts }} + + + 在本次任务中,你已经调用过一些工具,并获得了它们的输出,在中给出: + + {% for tool in history_data %} + + {{ tool.step_name }} + {{ tool.step_description }} + {{ tool.output_data }} + + {% endfor %} + + + + {{question}} + + + 工具名称:{{current_tool["name"]}} + 工具描述:{{current_tool["description"]}} + + + {{schema}} + + + """ class SlotInput(DataBase): diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py index f5764942ebc91f5b2092b825c6bc531c4825b092..b197205e681109fa12ffd8c7da9097d84ce2c3c2 100644 --- a/apps/scheduler/call/slot/slot.py +++ b/apps/scheduler/call/slot/slot.py @@ -3,15 +3,17 @@ from collections.abc import AsyncGenerator from typing import TYPE_CHECKING, Any, Self -import jinja2 +from jinja2 import BaseLoader +from jinja2.sandbox import SandboxedEnvironment from pydantic import Field from apps.entities.enum_var import CallOutputType from apps.entities.pool import NodePool from apps.entities.scheduler import CallInfo, CallOutputChunk, CallVars from apps.llm.patterns.json_gen import Json +from apps.llm.reasoning import ReasoningLLM from apps.scheduler.call.core import CoreCall -from apps.scheduler.call.slot.schema import SlotInput, SlotOutput +from apps.scheduler.call.slot.schema import SLOT_GEN_PROMPT, SlotInput, SlotOutput from apps.scheduler.slot.slot import Slot as SlotProcessor if TYPE_CHECKING: @@ -36,16 +38,46 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput): async def _llm_slot_fill(self, remaining_schema: dict[str, Any]) -> dict[str, Any]: """使用LLM填充参数""" + env = SandboxedEnvironment( + loader=BaseLoader(), + autoescape=False, + trim_blocks=True, + lstrip_blocks=True, + ) + template = env.from_string(SLOT_GEN_PROMPT) + conversation = [ - {"role": "user", "content": rf""" - 背景信息摘要: {self.summary} + {"role": "system", "content": ""}, + {"role": "user", "content": template.render( + current_tool={ + "name": self.name, + "description": self.description, + }, + schema=remaining_schema, + history_data=self._flow_history, + summary=self.summary, + question=self._question, + facts=self.facts, + )}, + ] - 事实信息: {self.facts} + # 使用大模型进行尝试 + reasoning = ReasoningLLM() + answer = "" + async for chunk in reasoning.call(messages=conversation, streaming=False): + answer += chunk + self.tokens.input_tokens += reasoning.input_tokens + self.tokens.output_tokens += reasoning.output_tokens - 历史输出(JSON): {self._flow_history} - """}, - ] + conversation = [ + {"role": "user", "content": f""" + 用户问题:{self._question} + 参考数据: + {answer} + """, + }, + ] return await Json().generate( conversation=conversation, spec=remaining_schema, @@ -68,9 +100,10 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput): async def _init(self, call_vars: CallVars) -> SlotInput: """初始化""" - self._flow_history = {} + self._flow_history = [] + self._question = call_vars.question for key in call_vars.history_order[:-self.step_num]: - self._flow_history[key] = call_vars.history[key] + self._flow_history += [call_vars.history[key]] if not self.current_schema: return SlotInput( diff --git a/apps/scheduler/call/suggest/suggest.py b/apps/scheduler/call/suggest/suggest.py index 8d0511947d8a3bd745595c5bbe64a5189ff25184..158c33516f94b1a921822af12e4fbfe6e83f7758 100644 --- a/apps/scheduler/call/suggest/suggest.py +++ b/apps/scheduler/call/suggest/suggest.py @@ -38,7 +38,9 @@ if TYPE_CHECKING: class Suggestion(CoreCall, input_model=SuggestionInput, output_model=SuggestionOutput): """问题推荐""" - configs: list[SingleFlowSuggestionConfig] = Field(description="问题推荐配置") + to_user: SkipJsonSchema[bool] = Field(default=True, description="是否将推荐的问题推送给用户") + + configs: list[SingleFlowSuggestionConfig] = Field(description="问题推荐配置", default=[]) num: int = Field(default=3, ge=1, le=6, description="推荐问题的总数量(必须大于等于configs中涉及的Flow的数量)") context: SkipJsonSchema[list[dict[str, str]]] = Field(description="Executor的上下文", exclude=True) diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 84160cb3ff1699277ce9a82c6c28ce428b26e9e3..7deacdb10670795d38db70bd47bf5a70dfcbe8c9 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -87,38 +87,37 @@ class StepExecutor(BaseExecutor): # 暂存旧数据 current_step_id = self.task.state.step_id # type: ignore[arg-type] current_step_name = self.task.state.step_name # type: ignore[arg-type] - current_step_status = self.task.state.status # type: ignore[arg-type] # 更新State self.task.state.step_id = str(uuid.uuid4()) # type: ignore[arg-type] self.task.state.step_name = "自动参数填充" # type: ignore[arg-type] self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) - # 准备参数 - params = { - "data": self.obj.input, - "current_schema": self.obj.input_model.model_json_schema( - override=self.node.override_input if self.node and self.node.override_input else {}, - ), - } # 初始化填参 - slot_obj = await Slot.instance(self, self.node, **params) + slot_obj = await Slot.instance( + self, + self.node, + data=self.obj.input, + current_schema=self.obj.input_model.model_json_schema( + override=self.node.override_input if self.node and self.node.override_input else {}, + ), + ) # 推送填参消息 await self.push_message(EventType.STEP_INPUT.value, slot_obj.input) # 运行填参 iterator = slot_obj.exec(self, slot_obj.input) async for chunk in iterator: result: SlotOutput = SlotOutput.model_validate(chunk.content) + self.task.tokens.input_tokens += slot_obj.tokens.input_tokens + self.task.tokens.output_tokens += slot_obj.tokens.output_tokens - # 如果没有填全 + # 如果没有填全,则状态设置为待填参 if result.remaining_schema: - # 状态设置为待填参 self.task.state.status = StepStatus.PARAM # type: ignore[arg-type] else: - # 推送填参结果 self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] - await self.push_message(EventType.STEP_OUTPUT.value, result.slot_data) + await self.push_message(EventType.STEP_OUTPUT.value, result.model_dump(by_alias=True, exclude_none=True)) # 更新输入 self.obj.input.update(result.slot_data) @@ -126,7 +125,6 @@ class StepExecutor(BaseExecutor): # 恢复State self.task.state.step_id = current_step_id # type: ignore[arg-type] self.task.state.step_name = current_step_name # type: ignore[arg-type] - self.task.state.status = current_step_status # type: ignore[arg-type] self.task.tokens.input_tokens += self.obj.tokens.input_tokens self.task.tokens.output_tokens += self.obj.tokens.output_tokens @@ -207,6 +205,7 @@ class StepExecutor(BaseExecutor): flow_name=self.task.state.flow_name, # type: ignore[arg-type] step_id=self.step.step_id, step_name=self.step.step.name, + step_description=self.step.step.description, status=self.task.state.status, # type: ignore[arg-type] input_data=self.obj.input, output_data=output_data,