From 2459e9e51293c1e17a9fa0ac7b1da518dfcd8541 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Fri, 28 Feb 2025 16:54:53 +0800 Subject: [PATCH] =?UTF-8?q?=E5=90=8C=E6=AD=A5Executor=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E6=9B=B4=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 +- apps/common/queue.py | 9 ++-- apps/common/task.py | 2 +- apps/entities/collection.py | 6 +-- apps/entities/flow_topology.py | 4 +- apps/entities/message.py | 1 + apps/entities/task.py | 3 +- apps/llm/patterns/select.py | 7 +-- apps/llm/reasoning.py | 2 +- apps/manager/flow.py | 4 +- apps/manager/node.py | 22 ++------ apps/routers/record.py | 2 +- apps/routers/service.py | 29 ++++++---- apps/scheduler/call/__init__.py | 2 + apps/scheduler/call/api.py | 68 +++++++++++------------ apps/scheduler/call/empty.py | 23 ++++++++ apps/scheduler/call/llm.py | 1 + apps/scheduler/call/rag.py | 8 ++- apps/scheduler/executor/flow.py | 75 +++++++++++++------------- apps/scheduler/pool/check.py | 2 + apps/scheduler/pool/loader/app.py | 12 ++--- apps/scheduler/pool/loader/call.py | 18 +++---- apps/scheduler/pool/loader/flow.py | 13 +++-- apps/scheduler/pool/loader/metadata.py | 10 +++- apps/scheduler/pool/loader/service.py | 9 +++- apps/scheduler/pool/pool.py | 26 ++++++--- apps/scheduler/scheduler/context.py | 2 +- apps/scheduler/scheduler/flow.py | 12 +++-- apps/scheduler/scheduler/message.py | 2 +- 29 files changed, 218 insertions(+), 159 deletions(-) create mode 100644 apps/scheduler/call/empty.py diff --git a/.gitignore b/.gitignore index 11b4659f0..d9827ef91 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ apps/utils/init *.bak apps/embedding -logs \ No newline at end of file +logs +.git-credentials \ No newline at end of file diff --git a/apps/common/queue.py b/apps/common/queue.py index 14c76a5d6..0d77ae15b 100644 --- a/apps/common/queue.py +++ b/apps/common/queue.py @@ -52,11 +52,11 @@ class MessageQueue: return # 计算当前Step时间 - step_time = round((datetime.now(timezone.utc).timestamp() - task.record.metadata.time), 3) + step_time = round((datetime.now(timezone.utc).timestamp() - task.record.metadata.time_cost), 3) metadata = MessageMetadata( - time=step_time, - input_tokens=task.record.metadata.input_tokens, - output_tokens=task.record.metadata.output_tokens, + timeCost=step_time, + inputTokens=task.record.metadata.input_tokens, + outputTokens=task.record.metadata.output_tokens, ) if task.flow_state: @@ -65,6 +65,7 @@ class MessageQueue: appId=task.flow_state.app_id, flowId=task.flow_state.name, stepId=task.flow_state.step_id, + stepName=task.flow_state.step_name, stepStatus=task.flow_state.status, ) else: diff --git a/apps/common/task.py b/apps/common/task.py index e46ed338a..2ae27d0a8 100644 --- a/apps/common/task.py +++ b/apps/common/task.py @@ -65,7 +65,7 @@ class Task: metadata=RecordMetadata( input_tokens=0, output_tokens=0, - time=0, + time_cost=0, feature={}, ), createdAt=round(datetime.now(timezone.utc).timestamp(), 3), diff --git a/apps/entities/collection.py b/apps/entities/collection.py index 953fae85d..01042e481 100644 --- a/apps/entities/collection.py +++ b/apps/entities/collection.py @@ -112,9 +112,9 @@ class Audit(BaseModel): class RecordMetadata(BaseModel): """Record表子项:Record的元信息""" - input_tokens: int - output_tokens: int - time: float + input_tokens: int = Field(default=0, alias="inputTokens") + output_tokens: int = Field(default=0, alias="outputTokens") + time_cost: float = Field(default=0, alias="timeCost") feature: dict[str, Any] = {} diff --git a/apps/entities/flow_topology.py b/apps/entities/flow_topology.py index 469ee3f87..f223c2c47 100644 --- a/apps/entities/flow_topology.py +++ b/apps/entities/flow_topology.py @@ -6,7 +6,7 @@ from typing import Any, Optional from pydantic import BaseModel, Field -from apps.entities.enum_var import EdgeType, NodeType +from apps.entities.enum_var import EdgeType class NodeMetaDataItem(BaseModel): @@ -52,7 +52,7 @@ class NodeItem(BaseModel): service_id: str = Field(alias="serviceId", default="") node_meta_data_id: str = Field(alias="nodeMetaDataId", default="") name: str = Field(default="") - type: str = Field(default=NodeType.NORMAL.value) + type: str = Field(default="Empty") description: str = Field(default="") enable: bool = Field(default=True) parameters: dict[str, Any] = Field(default={}) diff --git a/apps/entities/message.py b/apps/entities/message.py index 60f79865d..75bc490f1 100644 --- a/apps/entities/message.py +++ b/apps/entities/message.py @@ -24,6 +24,7 @@ class MessageFlow(BaseModel): app_id: str = Field(description="插件ID", alias="appId") flow_id: str = Field(description="Flow ID", alias="flowId") step_id: str = Field(description="当前步骤ID", alias="stepId") + step_name: str = Field(description="当前步骤名称", alias="stepName") step_status: StepStatus = Field(description="当前步骤状态", alias="stepStatus") diff --git a/apps/entities/task.py b/apps/entities/task.py index 23fd1db40..319a001d7 100644 --- a/apps/entities/task.py +++ b/apps/entities/task.py @@ -36,7 +36,8 @@ class ExecutorState(BaseModel): description: str = Field(description="执行器描述") status: StepStatus = Field(description="执行器状态") # 附加信息 - step_id: str = Field(description="当前步骤名称") + step_id: str = Field(description="当前步骤ID") + step_name: str = Field(description="当前步骤名称") app_id: str = Field(description="应用ID") # 运行时数据 ai_summary: str = Field(description="大模型的思考内容", default="") diff --git a/apps/llm/patterns/select.py b/apps/llm/patterns/select.py index 9b4551ea4..c90db86b2 100644 --- a/apps/llm/patterns/select.py +++ b/apps/llm/patterns/select.py @@ -23,7 +23,7 @@ class Select(CorePattern): 根据历史对话(包括工具调用结果)和用户问题,从给出的选项列表中,选出最符合要求的那一项。 在输出之前,请先思考,并使用“”标签给出思考过程。 - 结果需要使用JSON格式输出,输出格式为:{ "choice": "选项名称" } + 结果需要使用JSON格式输出,输出格式为:{{ "choice": "选项名称" }} @@ -31,7 +31,7 @@ class Select(CorePattern): 使用天气API,查询明天杭州的天气信息 - [API] 请求特定API,获得返回的JSON数据 + [API] HTTP请求,获得返回的JSON数据 [SQL] 查询数据库,获得数据库表中的数据 @@ -43,7 +43,7 @@ class Select(CorePattern): - { "choice": "API" } + {{ "choice": "API" }} @@ -123,6 +123,7 @@ class Select(CorePattern): data_str = json.dumps(kwargs.get("data", {}), ensure_ascii=False) choice_prompt, choices_list = self._choices_to_prompt(kwargs["choices"]) + logger.info("[Select] 选项列表: %s", choice_prompt) user_input = self.user_prompt.format( question=kwargs["question"], background=background, diff --git a/apps/llm/reasoning.py b/apps/llm/reasoning.py index e5a33eb87..fb74818ff 100644 --- a/apps/llm/reasoning.py +++ b/apps/llm/reasoning.py @@ -161,4 +161,4 @@ class ReasoningLLM: output_tokens = self._calculate_token_length([{"role": "assistant", "content": result}], pure_text=True) task = ray.get_actor("task") - await task.update_token_summary.remote(task_id, input_tokens, output_tokens) \ No newline at end of file + await task.update_token_summary.remote(task_id, input_tokens, output_tokens) diff --git a/apps/manager/flow.py b/apps/manager/flow.py index f3d2ba559..b7dfe49e3 100644 --- a/apps/manager/flow.py +++ b/apps/manager/flow.py @@ -71,7 +71,7 @@ class FlowManager: return False @staticmethod - async def get_node_meta_datas_by_service_id(service_id: str) -> Optional[list[NodeMetaDataItem]]: + async def get_node_id_by_service_id(service_id: str) -> Optional[list[NodeMetaDataItem]]: """serviceId获取service的接口数据,并将接口转换为节点元数据 :param service_id: 服务id @@ -157,7 +157,7 @@ class FlowManager: for record in service_records ] for service_item in service_items: - node_meta_datas = await FlowManager.get_node_meta_datas_by_service_id(service_item.service_id) + node_meta_datas = await FlowManager.get_node_id_by_service_id(service_item.service_id) if node_meta_datas is None: node_meta_datas = [] service_item.node_meta_datas = node_meta_datas diff --git a/apps/manager/node.py b/apps/manager/node.py index cfcc9fcab..83faa4eaf 100644 --- a/apps/manager/node.py +++ b/apps/manager/node.py @@ -6,7 +6,7 @@ import ray from pydantic import BaseModel from apps.entities.node import APINode -from apps.entities.pool import CallPool, NodePool +from apps.entities.pool import NodePool from apps.models.mongo import MongoDB logger = logging.getLogger("ray") @@ -83,27 +83,13 @@ class NodeManager: raise ValueError(err) from e call_id = node_data.call_id - # 查找Node对应的Call信息 - call_collection = MongoDB().get_collection("call") - call = await call_collection.find_one({"_id": call_id}) - if not call: - err = f"[NodeManager] Call {call_id} not found." - logger.error(err) - raise ValueError(err) - - try: - call_data = CallPool.model_validate(call) - except Exception as e: - err = "[NodeManager] 获取Call数据失败" - logger.exception(err) - raise ValueError(err) from e # 查找Call信息 - logger.info("[NodeManager] 获取Call %s", call_data.path) + logger.info("[NodeManager] 获取Call %s", call_id) pool = ray.get_actor("pool") - call_class: type[BaseModel] = await pool.get_call.remote(call_data.path) + call_class: type[BaseModel] = await pool.get_call.remote(call_id) if not call_class: - err = f"[NodeManager] Call {call_data.path} not found" + err = f"[NodeManager] Call {call_id} 不存在" logger.error(err) raise ValueError(err) diff --git a/apps/routers/record.py b/apps/routers/record.py index 361adc749..b7bcabd54 100644 --- a/apps/routers/record.py +++ b/apps/routers/record.py @@ -63,7 +63,7 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_ metadata=record.metadata if record.metadata else RecordMetadata( input_tokens=0, output_tokens=0, - time=0, + time_cost=0, ), createdAt=record.created_at, ) diff --git a/apps/routers/service.py b/apps/routers/service.py index 128af89f0..b2c352380 100644 --- a/apps/routers/service.py +++ b/apps/routers/service.py @@ -166,12 +166,12 @@ async def update_service( # noqa: PLR0911 result={}, ).model_dump(exclude_none=True, by_alias=True), ) - except PermissionError as e: + except PermissionError: return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, content=ResponseData( code=status.HTTP_403_FORBIDDEN, - message=str(e), + message="UNAUTHORIZED", result={}, ).model_dump(exclude_none=True, by_alias=True), ) @@ -187,6 +187,15 @@ async def update_service( # noqa: PLR0911 ) try: name, apis = await ServiceCenterManager.get_service_apis(service_id) + except ValueError: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=ResponseData( + code=status.HTTP_400_BAD_REQUEST, + message="INVALID_SERVICE_ID", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ) except Exception: logger.exception("[ServiceCenter] 获取服务API失败") return JSONResponse( @@ -214,21 +223,21 @@ async def get_service_detail( if edit: try: name, data = await ServiceCenterManager.get_service_data(user_sub, service_id) - except ValueError as e: + except ValueError: return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, content=ResponseData( code=status.HTTP_400_BAD_REQUEST, - message=str(e), + message="INVALID_SERVICE_ID", result={}, ).model_dump(exclude_none=True, by_alias=True), ) - except PermissionError as e: + except PermissionError: return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, content=ResponseData( code=status.HTTP_403_FORBIDDEN, - message=str(e), + message="UNAUTHORIZED", result={}, ).model_dump(exclude_none=True, by_alias=True), ) @@ -278,21 +287,21 @@ async def delete_service( """删除服务""" try: await ServiceCenterManager.delete_service(user_sub, service_id) - except ValueError as e: + except ValueError: return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, content=ResponseData( code=status.HTTP_400_BAD_REQUEST, - message=str(e), + message="INVALID_SERVICE_ID", result={}, ).model_dump(exclude_none=True, by_alias=True), ) - except PermissionError as e: + except PermissionError: return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, content=ResponseData( code=status.HTTP_403_FORBIDDEN, - message=str(e), + message="UNAUTHORIZED", result={}, ).model_dump(exclude_none=True, by_alias=True), ) diff --git a/apps/scheduler/call/__init__.py b/apps/scheduler/call/__init__.py index 96281e1a3..c85be3758 100644 --- a/apps/scheduler/call/__init__.py +++ b/apps/scheduler/call/__init__.py @@ -3,6 +3,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ from apps.scheduler.call.api import API +from apps.scheduler.call.empty import Empty from apps.scheduler.call.llm import LLM from apps.scheduler.call.rag import RAG @@ -10,4 +11,5 @@ __all__ = [ "API", "LLM", "RAG", + "Empty", ] diff --git a/apps/scheduler/call/api.py b/apps/scheduler/call/api.py index 41e9f34c8..c267531ea 100644 --- a/apps/scheduler/call/api.py +++ b/apps/scheduler/call/api.py @@ -10,27 +10,13 @@ from fastapi import status from pydantic import BaseModel, Field from apps.constants import LOGGER +from apps.entities.enum_var import ContentType, HTTPMethod from apps.entities.scheduler import CallError, CallVars from apps.manager.token import TokenManager from apps.scheduler.call.core import CoreCall from apps.scheduler.slot.slot import Slot -class _APIParams(BaseModel): - """API调用工具的参数""" - - url: str = Field(description="API接口的完整URL") - method: Literal[ - "get", "post", "put", "delete", "patch", - ] = Field(description="API接口的HTTP Method") - content_type: Literal[ - "application/json", "application/x-www-form-urlencoded", "multipart/form-data", - ] = Field(description="API接口的Content-Type") - timeout: int = Field(description="工具超时时间", default=300) - body: dict[str, Any] = Field(description="已知的部分请求体", default={}) - auth: dict[str, Any] = Field(description="API鉴权信息", default={}) - - class APIOutput(BaseModel): """API调用工具的输出""" @@ -45,6 +31,13 @@ class API(CoreCall, ret_type=APIOutput): name: ClassVar[str] = "HTTP请求" description: ClassVar[str] = "向某一个API接口发送HTTP请求,获取数据。" + url: str = Field(description="API接口的完整URL") + method: HTTPMethod = Field(description="API接口的HTTP Method") + content_type: ContentType = Field(description="API接口的Content-Type") + timeout: int = Field(description="工具超时时间", default=300, gt=30) + body: dict[str, Any] = Field(description="已知的部分请求体", default={}) + auth: dict[str, Any] = Field(description="API鉴权信息", default={}) + async def __call__(self, syscall_vars: CallVars, **_kwargs: Any) -> APIOutput: """调用API,然后返回LLM解析后的数据""" @@ -60,12 +53,11 @@ class API(CoreCall, ret_type=APIOutput): async def _make_api_call(self, data: Optional[dict], files: aiohttp.FormData): # noqa: ANN202, C901 # 获取必要参数 - params: _APIParams = getattr(self, "_params") syscall_vars: CallVars = getattr(self, "_syscall_vars") - if params.content_type != "form": + if self.content_type != "form": req_header = { - "Content-Type": "application/json", + "Content-Type": ContentType.JSON.value, } else: req_header = {} @@ -75,36 +67,36 @@ class API(CoreCall, ret_type=APIOutput): if data is None: data = {} - if params.auth is not None and "type" in params.auth: - if params.auth["type"] == "header": - req_header.update(params.auth["args"]) - elif params.auth["type"] == "cookie": - req_cookie.update(params.auth["args"]) - elif params.auth["type"] == "params": - req_params.update(params.auth["args"]) - elif params.auth["type"] == "oidc": + if self.auth is not None and "type" in self.auth: + if self.auth["type"] == "header": + req_header.update(self.auth["args"]) + elif self.auth["type"] == "cookie": + req_cookie.update(self.auth["args"]) + elif self.auth["type"] == "params": + req_params.update(self.auth["args"]) + elif self.auth["type"] == "oidc": token = await TokenManager.get_plugin_token( - params.auth["domain"], + self.auth["domain"], syscall_vars.session_id, - params.auth["access_token_url"], - int(params.auth["token_expire_time"]), + self.auth["access_token_url"], + int(self.auth["token_expire_time"]), ) req_header.update({"access-token": token}) - if params.method in ["get", "delete"]: + if self.method in ["get", "delete"]: req_params.update(data) - return self._session.request(params.method, params.url, params=req_params, headers=req_header, cookies=req_cookie, - timeout=params.timeout) + return self._session.request(self.method, self.url, params=req_params, headers=req_header, cookies=req_cookie, + timeout=self.timeout) - if params.method in ["post", "put", "patch"]: - if self._data_type == "form": + if self.method in ["post", "put", "patch"]: + if self.content_type == "form": form_data = files for key, val in data.items(): form_data.add_field(key, val) - return self._session.request(params.method, params.url, data=form_data, headers=req_header, cookies=req_cookie, - timeout=params.timeout) - return self._session.request(params.method, params.url, json=data, headers=req_header, cookies=req_cookie, - timeout=params.timeout) + return self._session.request(self.method, self.url, data=form_data, headers=req_header, cookies=req_cookie, + timeout=self.timeout) + return self._session.request(self.method, self.url, json=data, headers=req_header, cookies=req_cookie, + timeout=self.timeout) err = "Method not implemented." raise NotImplementedError(err) diff --git a/apps/scheduler/call/empty.py b/apps/scheduler/call/empty.py new file mode 100644 index 000000000..158cd0f9b --- /dev/null +++ b/apps/scheduler/call/empty.py @@ -0,0 +1,23 @@ +"""空Call""" +from typing import Any, ClassVar + +from pydantic import BaseModel + +from apps.entities.scheduler import CallVars +from apps.scheduler.call.core import CoreCall + + +class EmptyData(BaseModel): + """空数据""" + + +class Empty(CoreCall, ret_type=EmptyData): + """空Call""" + + name: ClassVar[str] = "空白" + description: ClassVar[str] = "空白节点,用于占位" + + + async def __call__(self, _syscall_vars: CallVars, **_kwargs: Any) -> EmptyData: + """空Call""" + return EmptyData() diff --git a/apps/scheduler/call/llm.py b/apps/scheduler/call/llm.py index c3e8ea666..bc67c6f9c 100644 --- a/apps/scheduler/call/llm.py +++ b/apps/scheduler/call/llm.py @@ -95,4 +95,5 @@ class LLM(CoreCall, ret_type=LLMNodeOutput): except Exception as e: raise CallError(message=f"大模型调用失败:{e!s}", data={}) from e + result = result.strip().strip("\n") return LLMNodeOutput(message=result) diff --git a/apps/scheduler/call/rag.py b/apps/scheduler/call/rag.py index 776d7c6b1..5ac642928 100644 --- a/apps/scheduler/call/rag.py +++ b/apps/scheduler/call/rag.py @@ -2,6 +2,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ +import logging from typing import Any, ClassVar, Literal import aiohttp @@ -12,6 +13,8 @@ from apps.common.config import config from apps.entities.scheduler import CallError, CallVars from apps.scheduler.call.core import CoreCall +logger = logging.getLogger("ray") + class RAGOutput(BaseModel): """RAG工具的输出""" @@ -36,7 +39,7 @@ class RAG(CoreCall, ret_type=RAGOutput): "kb_sn": self.knowledge_base, "top_k": self.top_k, "retrieval_mode": self.retrieval_mode, - "question": syscall_vars.question, + "content": syscall_vars.question, } url = config["RAG_HOST"].rstrip("/") + "/chunk/get" @@ -60,7 +63,10 @@ class RAG(CoreCall, ret_type=RAGOutput): return RAGOutput( corpus=corpus, ) + text = await response.text() + logger.error("[RAG] 调用失败:%s", text) + raise CallError( message=f"rag调用失败:{text}", data={ diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index f43eebce6..4bc57791c 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -2,6 +2,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ +import asyncio import logging from typing import Any, Optional @@ -65,11 +66,12 @@ class Executor(BaseModel): status=StepStatus.RUNNING, app_id=str(self.post_body_app.app_id), step_id="start", + step_name="开始", ai_summary="", filled_data=self.post_body_app.params, ) # 是否结束运行 - self._stop = False + self._can_continue = True async def _check_cls(self, call_cls: Any) -> bool: @@ -79,14 +81,15 @@ class Executor(BaseModel): flag = False if not hasattr(call_cls, "description") or not isinstance(call_cls.description, str): flag = False - if not hasattr(call_cls, "exec") or not callable(call_cls.exec): + if not callable(call_cls) or not asyncio.iscoroutinefunction(call_cls.__call__): flag = False return flag + # TODO async def _run_error(self, step: FlowError) -> dict[str, Any]: """运行错误处理步骤""" - pass + return {} async def _get_call_cls(self, node_id: str) -> Optional[type[CoreCall]]: @@ -107,15 +110,15 @@ class Executor(BaseModel): # 从Pool中获取对应的Call pool = ray.get_actor("pool") try: - call_cls: type[CoreCall] = await pool.get_call.remote(call_id, self.flow_state.app_id) + call_cls: type[CoreCall] = await pool.get_call.remote(call_id) except Exception: logger.exception("[FlowExecutor] 载入工具%s时发生错误", node_id) self.flow_state.status = StepStatus.ERROR return None # 检查Call合法性 - if not self._check_cls(call_cls): - logger.error("[FlowExecutor] 工具%s不符合Call标准要求", node_id) + if not await self._check_cls(call_cls): + logger.error("[FlowExecutor] 工具 %s 不符合Call标准要求", node_id) self.flow_state.status = StepStatus.ERROR return None @@ -145,7 +148,7 @@ class Executor(BaseModel): # 如果还有未填充的部分,则返回False if remaining_schema: - self._stop = True + self._can_continue = False self.flow_state.status = StepStatus.RUNNING # 推送空输入输出 await push_step_input(self.task.record.task_id, self.queue, self.flow_state, self.flow) @@ -156,6 +159,7 @@ class Executor(BaseModel): return True, slot_data + # TODO async def _get_last_output(self, task: TaskBlock) -> Optional[dict[str, Any]]: """获取上一步的输出""" return None @@ -185,22 +189,19 @@ class Executor(BaseModel): return result_data - async def _run_step(self, step_data: Step) -> None: + async def _run_step(self, step_id: str, step_data: Step) -> None: """运行单个步骤""" logger.info("[FlowExecutor] 运行步骤 %s", step_data.name) # 更新State - self.flow_state.step_id = step_data.name + self.flow_state.step_id = step_id self.flow_state.status = StepStatus.RUNNING - # Call类型为none,跳过执行 + # 特殊类型的Node,跳过执行 node_id = step_data.node - if node_id in ("none", "start", "end"): - return - # 获取并验证Call类 call_cls = await self._get_call_cls(node_id) if call_cls is None: - logger.error("[FlowExecutor] Node%s对应的工具不存在", node_id) + logger.error("[FlowExecutor] Node %s 对应的工具不存在", node_id) return # 准备系统变量 @@ -224,7 +225,7 @@ class Executor(BaseModel): # TODO: 处理slots # can_continue, slot_data = await self._process_slots(call_obj) - # if not can_continue: + # if not self._can_continue: # return # 推送步骤输入 @@ -238,8 +239,19 @@ class Executor(BaseModel): return + async def _handle_last_step(self) -> None: + """处理最后一步""" + # 如果当前步骤为结束,则直接返回 + if self.flow_state.step_id == "end": + return + + async def _handle_next_step(self) -> None: """处理下一步""" + # 如果当前步骤为结束,则直接返回 + if self.flow_state.step_id == "end": + return + next_nodes = [] # 遍历Edges,查找下一个节点 for edge in self.flow.edges: @@ -255,10 +267,16 @@ class Executor(BaseModel): # 处理下一步 if not next_nodes: self.flow_state.step_id = "end" + self.flow_state.step_name = "结束" else: self.flow_state.step_id = next_nodes[0] + self.flow_state.step_name = next_nodes[0] + # self.flow_state.step_name = self.flow.steps[next_nodes[0]].name logger.info("[FlowExecutor] 下一步 %s", self.flow_state.step_id) + # 如果是最后一步,设置停止标志 + if self.flow_state.step_id == "end": + self._can_continue = False async def run(self) -> None: @@ -270,36 +288,21 @@ class Executor(BaseModel): # 推送Flow开始 await push_flow_start(self.task.record.task_id, self.queue, self.flow_state, self.question) - while not self._stop: + while self._can_continue: # 当前步骤不存在 if self.flow_state.step_id not in self.flow.steps: - logger.error("[FlowExecutor] 当前步骤%s不存在", self.flow_state.step_id) - break + logger.error("[FlowExecutor] 当前步骤 %s 不存在", self.flow_state.step_id) + self.flow_state.status = StepStatus.ERROR if self.flow_state.status == StepStatus.ERROR: # 当前步骤为错误处理步骤 logger.warning("[FlowExecutor] Executor出错,执行错误处理步骤") step = self.flow.on_error + await self._run_error(step) else: + # 当前步骤为正常步骤 step = self.flow.steps[self.flow_state.step_id] - - # 当前步骤空白 - if not step: - break - - # 判断当前是否为最后一步 - if self.flow_state.step_id == "end": - break - - # 运行步骤 - if isinstance(step, FlowError): - result = await self._run_error(step) - else: - result = await self._run_step(step) - - # 如果停止,则结束执行 - if self._stop: - break + await self._run_step(self.flow_state.step_id, step) # 处理下一步 await self._handle_next_step() diff --git a/apps/scheduler/pool/check.py b/apps/scheduler/pool/check.py index 2a474bd74..29183deaa 100644 --- a/apps/scheduler/pool/check.py +++ b/apps/scheduler/pool/check.py @@ -41,6 +41,7 @@ class FileChecker: return hashes + async def diff_one(self, path: Path, previous_hashes: Optional[dict[str, str]] = None) -> bool: """检查文件是否发生变化""" self._resource_path = path @@ -48,6 +49,7 @@ class FileChecker: self.hashes[path_diff.as_posix()] = await self.check_one(path) return self.hashes[path_diff.as_posix()] != previous_hashes + async def diff(self, check_type: MetadataType) -> tuple[list[str], list[str]]: """生成更新列表和删除列表""" if check_type == MetadataType.APP: diff --git a/apps/scheduler/pool/loader/app.py b/apps/scheduler/pool/loader/app.py index b9a47d260..53fdd963f 100644 --- a/apps/scheduler/pool/loader/app.py +++ b/apps/scheduler/pool/loader/app.py @@ -38,13 +38,11 @@ class AppLoader: metadata = await MetadataLoader().load_one(metadata_path) if not metadata: err = f"[AppLoader] 元数据不存在: {metadata_path}" - logger.error(err) raise ValueError(err) metadata.hashes = hashes if not isinstance(metadata, AppMetadata): err = f"[AppLoader] 元数据类型错误: {metadata_path}" - logger.error(err) raise TypeError(err) # 加载工作流 @@ -59,7 +57,6 @@ class AppLoader: flow = await flow_loader.load(app_id, flow_file.stem) if not flow: err = f"[AppLoader] 工作流 {flow_file} 加载失败" - logger.error(err) raise ValueError(err) if not flow.debug: metadata.published = False @@ -100,7 +97,7 @@ class AppLoader: await self.load(app_id, file_checker.hashes[f"{APP_DIR}/{app_id}"]) - async def delete(self, app_id: str) -> None: + async def delete(self, app_id: str, *, is_reload: bool = False) -> None: """删除App,并更新数据库 :param app_id: 应用 ID @@ -131,9 +128,10 @@ class AppLoader: await session.aclose() - app_path = Path(config["SEMANTICS_DIR"]) / APP_DIR / app_id - if await app_path.exists(): - shutil.rmtree(str(app_path), ignore_errors=True) + if not is_reload: + app_path = Path(config["SEMANTICS_DIR"]) / APP_DIR / app_id + if await app_path.exists(): + shutil.rmtree(str(app_path), ignore_errors=True) async def _update_db(self, metadata: AppMetadata) -> None: diff --git a/apps/scheduler/pool/loader/call.py b/apps/scheduler/pool/loader/call.py index f3d6d079e..36ee09248 100644 --- a/apps/scheduler/pool/loader/call.py +++ b/apps/scheduler/pool/loader/call.py @@ -47,11 +47,11 @@ class CallLoader: return call_metadata - async def _load_single_call_dir(self, call_name: str) -> list[CallPool]: + async def _load_single_call_dir(self, call_dir_name: str) -> list[CallPool]: """加载单个Call package""" call_metadata = [] - call_dir = Path(config["SEMANTICS_DIR"]) / CALL_DIR / call_name + call_dir = Path(config["SEMANTICS_DIR"]) / CALL_DIR / call_dir_name if not (call_dir / "__init__.py").exists(): LOGGER.info(msg=f"模块{call_dir}不存在__init__.py文件,尝试自动创建。") try: @@ -62,16 +62,16 @@ class CallLoader: # 载入子包 try: - call_package = importlib.import_module("call." + call_name) + call_package = importlib.import_module("call." + call_dir_name) except Exception as e: - err = f"载入模块call.{call_name}失败:{e}。" + err = f"载入模块call.{call_dir_name}失败:{e}。" raise RuntimeError(err) from e - sys.modules["call." + call_name] = call_package + sys.modules["call." + call_dir_name] = call_package # 已载入包,处理包中每个工具 if not hasattr(call_package, "__all__"): - err = f"模块call.{call_name}不符合模块要求,无法处理。" + err = f"模块call.{call_dir_name}不符合模块要求,无法处理。" LOGGER.info(msg=err) raise ValueError(err) @@ -79,11 +79,11 @@ class CallLoader: try: call_cls = getattr(call_package, call_id) except Exception as e: - err = f"载入工具call.{call_name}.{call_id}失败:{e};跳过载入。" + err = f"载入工具call.{call_dir_name}.{call_id}失败:{e};跳过载入。" LOGGER.info(msg=err) continue - cls_path = f"{call_package.service}::call.{call_name}.{call_id}" + cls_path = f"{call_package.service}::call.{call_dir_name}.{call_id}" cls_hash = shake_128(cls_path.encode()).hexdigest(8) call_metadata.append( CallPool( @@ -91,7 +91,7 @@ class CallLoader: type=CallType.PYTHON, name=call_cls.name, description=call_cls.description, - path=f"python::call.{call_name}::{call_id}", + path=f"python::call.{call_dir_name}::{call_id}", ), ) diff --git a/apps/scheduler/pool/loader/flow.py b/apps/scheduler/pool/loader/flow.py index fa7e2bdc8..479c01311 100644 --- a/apps/scheduler/pool/loader/flow.py +++ b/apps/scheduler/pool/loader/flow.py @@ -25,6 +25,9 @@ class FlowLoader: """从文件系统中加载【单个】工作流""" logger.info("[FlowLoader] 加载工作流 %s 应用 %s...", flow_id, app_id) flow_path = Path(config["SEMANTICS_DIR"]) / "app" / app_id / "flow" / f"{flow_id}.yaml" + if not await flow_path.exists(): + logger.error(f"[FlowLoader] Flow file {flow_path} does not exist.") + return None async with aiofiles.open(flow_path, encoding="utf-8") as f: flow_yaml = yaml.safe_load(await f.read()) @@ -41,7 +44,7 @@ class FlowLoader: if "start" not in flow_yaml["steps"] or "end" not in flow_yaml["steps"]: err = f"工作流必须包含开始和结束节点:{flow_path!s}" logger.error(err) - raise ValueError(err) + return None logger.info("[FlowLoader] 解析工作流 %s 应用 %s 的边...", flow_id, app_id) for edge in flow_yaml["edges"]: @@ -57,7 +60,7 @@ class FlowLoader: except KeyError as e: err = f"Invalid edge type {edge['type']}" logger.exception(err) - raise ValueError(err) from e + return None logger.info("[FlowLoader] 解析工作流 %s 应用 %s 的步骤...", flow_id, app_id) for key, step in flow_yaml["steps"].items(): @@ -121,7 +124,11 @@ class FlowLoader: } async with aiofiles.open(flow_path, mode="w", encoding="utf-8") as f: - await f.write(yaml.dump(flow_dict, allow_unicode=True, sort_keys=False)) + await f.write(yaml.dump( + flow_dict, + allow_unicode=True, + sort_keys=False, + )) async def delete(self, app_id: str, flow_id: str) -> bool: """删除指定工作流文件""" diff --git a/apps/scheduler/pool/loader/metadata.py b/apps/scheduler/pool/loader/metadata.py index ffe786908..6f1665bae 100644 --- a/apps/scheduler/pool/loader/metadata.py +++ b/apps/scheduler/pool/loader/metadata.py @@ -61,6 +61,7 @@ class MetadataLoader: return metadata + async def save_one( self, metadata_type: MetadataType, @@ -96,5 +97,10 @@ class MetadataLoader: else: data = metadata - yaml_data = yaml.safe_dump(jsonable_encoder(data)) - await resource_path.write_text(yaml_data) + # 使用UTF-8保存YAML,忽略部分乱码 + yaml_dict = yaml.dump( + jsonable_encoder(data, exclude={"hashes"}), + allow_unicode=True, + sort_keys=False, + ) + await resource_path.write_text(yaml_dict) diff --git a/apps/scheduler/pool/loader/service.py b/apps/scheduler/pool/loader/service.py index cf8714c24..ef28db770 100644 --- a/apps/scheduler/pool/loader/service.py +++ b/apps/scheduler/pool/loader/service.py @@ -4,6 +4,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ import asyncio +import shutil import ray from anyio import Path @@ -77,7 +78,8 @@ class ServiceLoader: await file_checker.diff_one(service_path) await self.load(service_id, file_checker.hashes[f"{SERVICE_DIR}/{service_id}"]) - async def delete(self, service_id: str) -> None: + + async def delete(self, service_id: str, *, is_reload: bool = False) -> None: """删除Service,并更新数据库""" service_collection = MongoDB.get_collection("service") node_collection = MongoDB.get_collection("node") @@ -97,6 +99,11 @@ class ServiceLoader: err = f"[ServiceLoader] 删除数据库失败:{e}" LOGGER.error(err) + if not is_reload: + path = Path(config["SEMANTICS_DIR"]) / SERVICE_DIR / service_id + if await path.exists(): + shutil.rmtree(path) + await session.aclose() async def _update_db(self, nodes: list[NodePool], metadata: ServiceMetadata) -> None: diff --git a/apps/scheduler/pool/pool.py b/apps/scheduler/pool/pool.py index b9b76e2fa..45d91c7bf 100644 --- a/apps/scheduler/pool/pool.py +++ b/apps/scheduler/pool/pool.py @@ -12,7 +12,7 @@ from anyio import Path from apps.constants import APP_DIR, SERVICE_DIR from apps.entities.enum_var import MetadataType from apps.entities.flow import Flow -from apps.entities.pool import AppFlow +from apps.entities.pool import AppFlow, CallPool from apps.models.mongo import MongoDB from apps.scheduler.pool.check import FileChecker from apps.scheduler.pool.loader import ( @@ -46,7 +46,7 @@ class Pool: # 批量删除 for service in changed_service: - await service_loader.delete(service) + await service_loader.delete(service, is_reload=True) for service in deleted_service: await service_loader.delete(service) @@ -63,7 +63,7 @@ class Pool: # 批量删除App for app in changed_app: - await app_loader.delete(app) + await app_loader.delete(app, is_reload=True) for app in deleted_app: await app_loader.delete(app) @@ -74,6 +74,7 @@ class Pool: await app_loader.load(app, checker.hashes[hash_key]) + # TODO async def save(self, *, is_deletion: bool = False) -> None: """保存【单个】资源""" pass @@ -102,11 +103,20 @@ class Pool: return await flow_loader.load(app_id, flow_id) - async def get_call(self, call_path: str) -> Any: - """拿到Call的信息""" - call_path_split = call_path.split("::") + async def get_call(self, call_id: str) -> Any: + """[Exception] 拿到Call的信息""" + # 从MongoDB里拿到数据 + call_collection = MongoDB.get_collection("call") + call_db_data = await call_collection.find_one({"_id": call_id}) + if not call_db_data: + err = f"[Pool] Call{call_id}不存在" + logger.error(err) + raise ValueError(err) + + call_metadata = CallPool.model_validate(call_db_data) + call_path_split = call_metadata.path.split("::") if not call_path_split: - err = f"[Pool] Call路径{call_path}不合法" + err = f"[Pool] Call路径{call_metadata.path}不合法" logger.error(err) raise ValueError(err) @@ -116,7 +126,7 @@ class Pool: call_module = importlib.import_module(call_path_split[1]) return getattr(call_module, call_path_split[2]) except Exception as e: - err = f"[Pool] 获取Call{call_path}类失败" + err = f"[Pool] 获取Call{call_metadata.path}类失败" logger.exception(err) raise RuntimeError(err) from e return None diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index f57e107ed..94aa8c3c1 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -107,7 +107,7 @@ async def save_data(task_id: str, user_sub: str, post_body: RequestData, used_do await TaskManager.create_flows(history_data) # 修改metadata里面时间为实际运行时间 - task.record.metadata.time = round(datetime.now(timezone.utc).timestamp() - task.record.metadata.time, 2) + task.record.metadata.time_cost = round(datetime.now(timezone.utc).timestamp() - task.record.metadata.time_cost, 2) # 提取facts # 记忆提取 diff --git a/apps/scheduler/scheduler/flow.py b/apps/scheduler/scheduler/flow.py index 357525e8c..c76ce728d 100644 --- a/apps/scheduler/scheduler/flow.py +++ b/apps/scheduler/scheduler/flow.py @@ -2,7 +2,6 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ -import asyncio import logging from typing import Optional @@ -34,18 +33,21 @@ class FlowChooser: async def get_top_flow(self) -> str: """获取Top1 Flow""" - logger.info("[FlowChooser] 判断任务 %s 的Top1 Flow...", self._task_id) pool = ray.get_actor("pool") # 获取所选应用的所有Flow if not self._user_selected or not self._user_selected.app_id: return "KnowledgeBase" - flow_list = await asyncio.gather(pool.get_flow_list.remote(self._user_selected.app_id))[0] + flow_list = await pool.get_flow_metadata.remote(self._user_selected.app_id) if not flow_list: return "KnowledgeBase" - logger.info("[FlowChooser] 选择任务 %s 的Top1 Flow...", self._task_id) - return await Select().generate(self._task_id, question=self._question, choices=flow_list) + logger.info("[FlowChooser] 选择任务 %s 最合适的Flow", self._task_id) + choices = [{ + "name": flow.id, + "description": f"{flow.name}, {flow.description}", + } for flow in flow_list] + return await Select().generate(self._task_id, question=self._question, choices=choices) async def choose_flow(self) -> Optional[RequestDataApp]: diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index 38184ec29..d8fbc4c61 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -47,7 +47,7 @@ async def push_init_message(task_id: str, queue: actor.ActorHandle, context_num: # 保存必要信息到Task created_at = round(datetime.now(timezone.utc).timestamp(), 2) - task.record.metadata.time = created_at + task.record.metadata.time_cost = created_at task.record.metadata.feature = feature.model_dump(exclude_none=True, by_alias=True) # 推送初始化消息 -- Gitee