diff --git a/apps/models/flow.py b/apps/models/flow.py index 60be2e1efbd36757bcfb061ea7bdc6981b8fd801..43e37992367b45589130986ea575819c55d84b18 100644 --- a/apps/models/flow.py +++ b/apps/models/flow.py @@ -14,6 +14,8 @@ class Flow(Base): """Flow""" __tablename__ = "framework_flow" + id: Mapped[str] = mapped_column(String(255), primary_key=True) + """Flow的ID""" appId: Mapped[uuid.UUID] = mapped_column( # noqa: N815 UUID(as_uuid=True), ForeignKey("framework_app.id"), @@ -31,10 +33,6 @@ class Flow(Base): """是否经过调试""" enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) """是否启用""" - id: Mapped[uuid.UUID] = mapped_column( - UUID(as_uuid=True), primary_key=True, default_factory=uuid.uuid4, - ) - """Flow的ID""" updatedAt: Mapped[datetime] = mapped_column( # noqa: N815 DateTime(timezone=True), default_factory=lambda: datetime.now(tz=UTC), diff --git a/apps/models/llm.py b/apps/models/llm.py index e67a6b04699f1c516ad72719cd7b972e9bb4abde..267f12bbdd324541b65fbf530c319cacf8aa9163 100644 --- a/apps/models/llm.py +++ b/apps/models/llm.py @@ -1,10 +1,8 @@ """大模型信息 数据库表""" -import uuid from datetime import UTC, datetime from sqlalchemy import DateTime, Integer, String -from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import Mapped, mapped_column from apps.common.config import config @@ -17,7 +15,7 @@ class LLMData(Base): """大模型信息""" __tablename__ = "framework_llm" - id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default_factory=uuid.uuid4) + id: Mapped[str] = mapped_column(String(255), primary_key=True) """大模型ID""" icon: Mapped[str] = mapped_column(String(1000), default=llm_provider_dict["ollama"]["icon"], nullable=False) """LLM图标路径""" diff --git a/apps/models/mcp.py b/apps/models/mcp.py index c573585d50d2faa3f3c1b1e7ad72b17dc3071319..84db4bc0f940123425581483906b6ac4a1b8ee3f 100644 --- a/apps/models/mcp.py +++ b/apps/models/mcp.py @@ -43,7 +43,7 @@ class MCPInfo(Base): """MCP 描述""" author: Mapped[str] = 
mapped_column(String(50), ForeignKey("framework_user.userSub"), nullable=False) """MCP 创建者""" - id: Mapped[str] = mapped_column(String(100), primary_key=True) + id: Mapped[str] = mapped_column(String(255), primary_key=True) """MCP ID""" updatedAt: Mapped[datetime] = mapped_column( # noqa: N815 DateTime(timezone=True), @@ -68,7 +68,7 @@ class MCPActivated(Base): __tablename__ = "framework_mcp_activated" mcpId: Mapped[str] = mapped_column( # noqa: N815 - String(100), ForeignKey("framework_mcp.id"), index=True, nullable=False, + String(255), ForeignKey("framework_mcp.id"), index=True, nullable=False, ) """MCP ID""" userSub: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.userSub"), nullable=False) # noqa: N815 @@ -82,7 +82,7 @@ class MCPTools(Base): __tablename__ = "framework_mcp_tools" mcpId: Mapped[str] = mapped_column( # noqa: N815 - String(100), ForeignKey("framework_mcp.id"), index=True, nullable=False, + String(255), ForeignKey("framework_mcp.id"), index=True, nullable=False, ) """MCP ID""" toolName: Mapped[str] = mapped_column(String(255), nullable=False) # noqa: N815 diff --git a/apps/models/node.py b/apps/models/node.py index f89f6b1acca859b299c2cb53910792758b31cdac..f477e4e6f4864476a67df02fcfadeffbb1c16559 100644 --- a/apps/models/node.py +++ b/apps/models/node.py @@ -1,11 +1,10 @@ """节点 数据库表""" -import uuid from datetime import UTC, datetime from typing import Any from sqlalchemy import DateTime, ForeignKey, String -from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Mapped, mapped_column from .base import Base @@ -19,8 +18,8 @@ class NodeInfo(Base): """节点名称""" description: Mapped[str] = mapped_column(String(2000), nullable=False) """节点描述""" - serviceId: Mapped[uuid.UUID | None] = mapped_column( # noqa: N815 - UUID(as_uuid=True), ForeignKey("framework_service.id"), nullable=True, + serviceId: Mapped[str | None] = mapped_column( # noqa: N815 + String(255), 
ForeignKey("framework_service.id"), nullable=True, ) """所属服务ID""" callId: Mapped[str] = mapped_column(String(50), nullable=False) # noqa: N815 @@ -31,7 +30,7 @@ class NodeInfo(Base): """Node的输入Schema;用于描述Call的参数中特定的字段""" overrideOutput: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False) # noqa: N815 """Node的输出Schema;用于描述Call的输出中特定的字段""" - id: Mapped[str] = mapped_column(String(50), primary_key=True, default_factory=lambda: str(uuid.uuid4())) + id: Mapped[str] = mapped_column(String(255), primary_key=True) """节点ID""" updatedAt: Mapped[datetime] = mapped_column( # noqa: N815 DateTime(timezone=True), diff --git a/apps/models/record.py b/apps/models/record.py index 9a53c44b8b40fb333a21a8e94434d298c4a10669..7b3bec8ddd010b0d3c94ac3cdd0d4d4a9b3c0e22 100644 --- a/apps/models/record.py +++ b/apps/models/record.py @@ -44,7 +44,7 @@ class RecordMetadata(Base): """问答对元数据""" __tablename__ = "framework_record_metadata" - record_id: Mapped[uuid.UUID] = mapped_column( + recordId: Mapped[uuid.UUID] = mapped_column( # noqa: N815 UUID(as_uuid=True), ForeignKey("framework_record.id"), primary_key=True, ) """问答对ID""" @@ -72,7 +72,7 @@ class RecordFootNote(Base): __tablename__ = "framework_record_foot_note" id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False) """主键ID""" - record_id: Mapped[uuid.UUID] = mapped_column( + recordId: Mapped[uuid.UUID] = mapped_column( # noqa: N815 UUID(as_uuid=True), ForeignKey("framework_record.id"), nullable=False, ) """问答对ID""" diff --git a/apps/models/service.py b/apps/models/service.py index 390f93553fda02b0a4c977a4cd0b3629886109f9..8fe23f8461dcd43f5a1f7ac9d8a75ada149c2acb 100644 --- a/apps/models/service.py +++ b/apps/models/service.py @@ -16,6 +16,8 @@ class Service(Base): """插件""" __tablename__ = "framework_service" + id: Mapped[str] = mapped_column(String(255), primary_key=True) + """插件ID""" name: Mapped[str] = mapped_column(String(255), nullable=False) """插件名称""" description: Mapped[str] = 
mapped_column(String(2000), nullable=False) @@ -29,8 +31,6 @@ class Service(Base): nullable=False, ) """插件更新时间""" - id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default_factory=uuid.uuid4) - """插件ID""" iconPath: Mapped[str] = mapped_column(String(255), default="", nullable=False) # noqa: N815 """插件图标路径""" permission: Mapped[PermissionType] = mapped_column( @@ -43,7 +43,7 @@ class ServiceACL(Base): """插件权限""" __tablename__ = "framework_service_acl" - serviceId: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("framework_service.id"), nullable=False) # noqa: N815 + serviceId: Mapped[str] = mapped_column(String(255), ForeignKey("framework_service.id"), nullable=False) # noqa: N815 """关联的插件ID""" userSub: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.userSub"), nullable=False) # noqa: N815 """用户名""" @@ -59,7 +59,7 @@ class ServiceHashes(Base): __tablename__ = "framework_service_hashes" id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False) """主键ID""" - serviceId: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("framework_service.id"), nullable=False) # noqa: N815 + serviceId: Mapped[str] = mapped_column(String(255), ForeignKey("framework_service.id"), nullable=False) # noqa: N815 """关联的插件ID""" filePath: Mapped[str] = mapped_column(String(255), nullable=False) # noqa: N815 """文件路径""" diff --git a/apps/models/task.py b/apps/models/task.py index e089fc9bec035502cc7f2f9e8226473b533e3325..ad5c96f6ad75e4774f986e2fe22e9c46053ff455 100644 --- a/apps/models/task.py +++ b/apps/models/task.py @@ -8,7 +8,7 @@ from sqlalchemy import DateTime, Enum, Float, ForeignKey, Integer, String, Text from sqlalchemy.dialects.postgresql import JSONB, UUID from sqlalchemy.orm import Mapped, mapped_column -from apps.schemas.enum_var import FlowStatus, Language, StepStatus +from apps.schemas.enum_var import ExecutorStatus, Language, StepStatus from .base import Base @@ -64,7 
+64,7 @@ class TaskRuntime(Base): """中间推理""" filledSlot: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False, default={}) # noqa: N815 """计划""" - document: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, default={}) + document: Mapped[list[dict[str, Any]]] = mapped_column(JSONB, nullable=False, default=[]) """关联文档""" language: Mapped[Language] = mapped_column(Enum(Language), nullable=False, default=Language.ZH) """语言""" @@ -80,11 +80,11 @@ class ExecutorCheckpoint(Base): """任务ID""" appId: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), nullable=False) # noqa: N815 """应用ID""" - executorId: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), nullable=False) # noqa: N815 + executorId: Mapped[str] = mapped_column(String(255), nullable=False) # noqa: N815 """执行器ID(例如工作流ID)""" executorName: Mapped[str] = mapped_column(String(255), nullable=False) # noqa: N815 """执行器名称(例如工作流名称)""" - executorStatus: Mapped[FlowStatus] = mapped_column(Enum(FlowStatus), nullable=False) # noqa: N815 + executorStatus: Mapped[ExecutorStatus] = mapped_column(Enum(ExecutorStatus), nullable=False) # noqa: N815 """执行器状态""" # 步骤级数据 stepId: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), nullable=False) # noqa: N815 @@ -101,6 +101,8 @@ class ExecutorCheckpoint(Base): """步骤描述""" data: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False, default={}) """步骤额外数据""" + errorMessage: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False, default={}) # noqa: N815 + """错误信息""" class ExecutorHistory(Base): @@ -110,11 +112,11 @@ class ExecutorHistory(Base): taskId: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("framework_task.id"), nullable=False) # noqa: N815 """任务ID""" - executorId: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), nullable=False) # noqa: N815 + executorId: Mapped[str] = mapped_column(String(255), nullable=False) # noqa: N815 """执行器ID(例如工作流ID)""" executorName: Mapped[str] = mapped_column(String(255)) # 
noqa: N815 """执行器名称(例如工作流名称)""" - executorStatus: Mapped[FlowStatus] = mapped_column(Enum(FlowStatus), nullable=False) # noqa: N815 + executorStatus: Mapped[ExecutorStatus] = mapped_column(Enum(ExecutorStatus), nullable=False) # noqa: N815 """执行器状态""" stepId: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), nullable=False) # noqa: N815 """步骤ID""" diff --git a/apps/models/vectors.py b/apps/models/vectors.py index be2f6c53a3cb4a3705374d38925cd646913ffef0..266901dbeeb7e56e207ec72cdd3e9906fb860e58 100644 --- a/apps/models/vectors.py +++ b/apps/models/vectors.py @@ -14,11 +14,11 @@ class FlowPoolVector(Base): """Flow向量数据""" __tablename__ = "framework_flow_vector" - appId: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_app.id"), nullable=False) # noqa: N815 + appId: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("framework_app.id"), nullable=False) # noqa: N815 """所属App的ID""" embedding: Mapped[list[float]] = mapped_column(Vector(1024), nullable=False) """向量数据""" - id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("framework_flow.id"), primary_key=True) + id: Mapped[str] = mapped_column(String(255), ForeignKey("framework_flow.id"), primary_key=True) """Flow的ID""" __table_args__ = ( Index( @@ -37,7 +37,7 @@ class ServicePoolVector(Base): __tablename__ = "framework_service_vector" embedding: Mapped[list[float]] = mapped_column(Vector(1024), nullable=False) """向量数据""" - id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("framework_service.id"), primary_key=True) + id: Mapped[str] = mapped_column(String(255), ForeignKey("framework_service.id"), primary_key=True) """Service的ID""" __table_args__ = ( Index( @@ -56,9 +56,9 @@ class NodePoolVector(Base): __tablename__ = "framework_node_vector" embedding: Mapped[list[float]] = mapped_column(Vector(1024), nullable=False) """向量数据""" - serviceId: Mapped[uuid.UUID | None] = mapped_column(ForeignKey("framework_service.id"), nullable=True) # noqa: N815 + serviceId: 
Mapped[str | None] = mapped_column(String(255), ForeignKey("framework_service.id"), nullable=True) # noqa: N815 """Service的ID""" - id: Mapped[str] = mapped_column(String(36), ForeignKey("framework_node.id"), primary_key=True) + id: Mapped[str] = mapped_column(String(255), ForeignKey("framework_node.id"), primary_key=True) """Node的ID""" __table_args__ = ( Index( diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 2498da814950ae840b493d03a858c96333661377..70491d7281c4339d00699620f400b42ed9078325 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -15,7 +15,7 @@ from apps.common.wordscheck import WordsCheck from apps.dependency import verify_personal_token, verify_session from apps.scheduler.scheduler import Scheduler from apps.scheduler.scheduler.context import save_data -from apps.schemas.enum_var import FlowStatus +from apps.schemas.enum_var import ExecutorStatus from apps.schemas.request_data import RequestData, RequestDataApp from apps.schemas.response_data import ResponseData from apps.schemas.task import Task @@ -100,7 +100,7 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) # 获取最终答案 task = scheduler.task - if task.state.flow_status == FlowStatus.ERROR: + if task.state.flow_status == ExecutorStatus.ERROR: logger.error("[Chat] 生成答案失败") yield "data: [ERROR]\n\n" await Activity.remove_active(user_sub) diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index ed58a88f625db21afe87e11a4aab8c24cfba362f..127498dc7626e631df3e14c47a9f2fd4a8541b2f 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -103,7 +103,7 @@ class CoreCall(BaseModel): history = {} history_order = [] for item in executor.task.context: - item_obj = FlowStepHistory.model_validate(item) + item_obj = ExecutorHistory.model_validate(item) history[item_obj.step_id] = item_obj history_order.append(item_obj.step_id) @@ -141,7 +141,7 @@ class CoreCall(BaseModel): logger.error(err) return None - data = 
history[split_path[0]].output_data + data = history[split_path[0]].outputData for key in split_path[1:]: if key not in data: err = f"[CoreCall] 输出Key {key} 不存在" diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index 663f6827471308954affb28c450e33040ee105cb..e85091c9ce5c7f6030013d9e470b9ae785d68ee7 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -13,7 +13,7 @@ from apps.scheduler.executor.base import BaseExecutor from apps.scheduler.mcp_agent.host import MCPHost from apps.scheduler.mcp_agent.plan import MCPPlanner from apps.scheduler.pool.mcp.pool import MCPPool -from apps.schemas.enum_var import EventType, FlowStatus, StepStatus +from apps.schemas.enum_var import EventType, ExecutorStatus, StepStatus from apps.schemas.mcp import ( MCPCollection, MCPTool, @@ -70,7 +70,7 @@ class MCPAgentExecutor(BaseExecutor): """从数据库中加载FlowExecutor的状态""" logger.info("[FlowExecutor] 加载Executor状态") # 尝试恢复State - if self.task.state and self.task.state.flow_status != FlowStatus.INIT: + if self.task.state and self.task.state.flow_status != ExecutorStatus.INIT: self.task.context = await TaskManager.get_context_by_task_id(self.task.id) async def load_mcp(self) -> None: @@ -137,7 +137,7 @@ class MCPAgentExecutor(BaseExecutor): EventType.STEP_WAITING_FOR_START, confirm_message.model_dump(exclude_none=True, by_alias=True), ) await self.push_message(EventType.FLOW_STOP, {}) - self.task.state.flow_status = FlowStatus.WAITING + self.task.state.flow_status = ExecutorStatus.WAITING self.task.state.step_status = StepStatus.WAITING self.task.context.append( FlowStepHistory( @@ -157,7 +157,7 @@ class MCPAgentExecutor(BaseExecutor): async def run_step(self) -> None: """执行步骤""" - self.task.state.flow_status = FlowStatus.RUNNING + self.task.state.flow_status = ExecutorStatus.RUNNING self.task.state.step_status = StepStatus.RUNNING mcp_tool = self.tools[self.task.state.tool_id] mcp_client = await self.mcp_pool.get(mcp_tool.mcp_id, 
self.task.ids.user_sub) @@ -231,7 +231,7 @@ class MCPAgentExecutor(BaseExecutor): EventType.STEP_WAITING_FOR_PARAM, data={"message": error_message, "params": params_with_null}, ) await self.push_message(EventType.FLOW_STOP, data={}) - self.task.state.flow_status = FlowStatus.WAITING + self.task.state.flow_status = ExecutorStatus.WAITING self.task.state.step_status = StepStatus.PARAM self.task.context.append( FlowStepHistory( @@ -284,7 +284,7 @@ class MCPAgentExecutor(BaseExecutor): async def error_handle_after_step(self) -> None: """步骤执行失败后的错误处理""" self.task.state.step_status = StepStatus.ERROR - self.task.state.flow_status = FlowStatus.ERROR + self.task.state.flow_status = ExecutorStatus.ERROR await self.push_message(EventType.FLOW_FAILED, data={}) if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: del self.task.context[-1] @@ -326,7 +326,7 @@ class MCPAgentExecutor(BaseExecutor): if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: del self.task.context[-1] else: - self.task.state.flow_status = FlowStatus.CANCELLED + self.task.state.flow_status = ExecutorStatus.CANCELLED self.task.state.step_status = StepStatus.CANCELLED await self.push_message(EventType.STEP_CANCEL, data={}) await self.push_message(EventType.FLOW_CANCEL, data={}) @@ -402,7 +402,7 @@ class MCPAgentExecutor(BaseExecutor): # 初始化MCP服务 await self.load_state() await self.load_mcp() - if self.task.state.flow_status == FlowStatus.INIT: + if self.task.state.flow_status == ExecutorStatus.INIT: # 初始化状态 try: self.task.state.flow_id = str(uuid.uuid4()) @@ -411,20 +411,20 @@ class MCPAgentExecutor(BaseExecutor): await self.get_next_step() except Exception as e: logger.exception("[MCPAgentExecutor] 初始化失败") - self.task.state.flow_status = FlowStatus.ERROR + self.task.state.flow_status = ExecutorStatus.ERROR self.task.state.error_message = str(e) await self.push_message(EventType.FLOW_FAILED, data={}) return - self.task.state.flow_status = 
FlowStatus.RUNNING + self.task.state.flow_status = ExecutorStatus.RUNNING await self.push_message(EventType.FLOW_START, data={}) if self.task.state.tool_id == FINAL_TOOL_ID: # 如果已经是最后一步,直接结束 - self.task.state.flow_status = FlowStatus.SUCCESS + self.task.state.flow_status = ExecutorStatus.SUCCESS await self.push_message(EventType.FLOW_SUCCESS, data={}) await self.summarize() return try: - while self.task.state.flow_status == FlowStatus.RUNNING: + while self.task.state.flow_status == ExecutorStatus.RUNNING: if self.task.state.tool_id == FINAL_TOOL_ID: break await self.work() @@ -432,13 +432,13 @@ class MCPAgentExecutor(BaseExecutor): tool_id = self.task.state.tool_id if tool_id == FINAL_TOOL_ID: # 如果已经是最后一步,直接结束 - self.task.state.flow_status = FlowStatus.SUCCESS + self.task.state.flow_status = ExecutorStatus.SUCCESS self.task.state.step_status = StepStatus.SUCCESS await self.push_message(EventType.FLOW_SUCCESS, data={}) await self.summarize() except Exception as e: logger.exception("[MCPAgentExecutor] 执行过程中发生错误") - self.task.state.flow_status = FlowStatus.ERROR + self.task.state.flow_status = ExecutorStatus.ERROR self.task.state.error_message = str(e) self.task.state.step_status = StepStatus.ERROR await self.push_message(EventType.STEP_ERROR, data={}) diff --git a/apps/scheduler/executor/base.py b/apps/scheduler/executor/base.py index 12fd5951f3cc22bdcfaf2822afb76f6bcd3245a7..44751760ad2f3116c9a0ec80b41fd4401fda3b55 100644 --- a/apps/scheduler/executor/base.py +++ b/apps/scheduler/executor/base.py @@ -3,15 +3,17 @@ import logging from abc import ABC, abstractmethod -from typing import Any +from typing import TYPE_CHECKING, Any from pydantic import BaseModel, ConfigDict -from apps.common.queue import MessageQueue -from apps.models.task import Task from apps.schemas.enum_var import EventType from apps.schemas.message import TextAddContent -from apps.schemas.scheduler import ExecutorBackground + +if TYPE_CHECKING: + from apps.common.queue import MessageQueue + from 
apps.models.task import ExecutorCheckpoint, Task, TaskRuntime + from apps.schemas.scheduler import ExecutorBackground logger = logging.getLogger(__name__) @@ -19,9 +21,11 @@ logger = logging.getLogger(__name__) class BaseExecutor(BaseModel, ABC): """Executor基类""" - task: Task - msg_queue: MessageQueue - background: ExecutorBackground + task: "Task" + runtime: "TaskRuntime" + state: "ExecutorCheckpoint" + msg_queue: "MessageQueue" + background: "ExecutorBackground" question: str model_config = ConfigDict( @@ -29,14 +33,6 @@ class BaseExecutor(BaseModel, ABC): extra="allow", ) - @staticmethod - def validate_flow_state(task: Task) -> None: - """验证flow_state是否存在""" - if not task.state: - err = "[Executor] 当前ExecutorState为空" - logger.error(err) - raise ValueError(err) - async def push_message(self, event_type: str, data: dict[str, Any] | str | None = None) -> None: """ 统一的消息推送接口 @@ -55,7 +51,7 @@ class BaseExecutor(BaseModel, ABC): await self.msg_queue.push_output( self.task, event_type=event_type, - data=data, # type: ignore[arg-type] + data=data, ) @abstractmethod diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index 286c4df0167ee31f4a9210b8117b2f3b1e183f40..aadb0f26c5e6629134342c941c62c41641efcc63 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -8,11 +8,12 @@ from datetime import UTC, datetime from pydantic import Field +from apps.models.task import ExecutorHistory, Task from apps.scheduler.call.llm.prompt import LLM_ERROR_PROMPT -from apps.schemas.enum_var import EventType, FlowStatus, SpecialCallType, StepStatus +from apps.schemas.enum_var import EventType, ExecutorStatus, SpecialCallType, StepStatus from apps.schemas.flow import Flow, Step from apps.schemas.request_data import RequestDataApp -from apps.schemas.task import ExecutorState, StepQueueItem +from apps.schemas.task import StepQueueItem from apps.services.task import TaskManager from .base import BaseExecutor @@ -47,21 +48,25 @@ class 
FlowExecutor(BaseExecutor): flow_id: str = Field(description="Flow ID") question: str = Field(description="用户输入") post_body_app: RequestDataApp = Field(description="请求体中的app信息") - current_step: StepQueueItem = Field(description="当前执行的步骤") + current_step: StepQueueItem | None = Field( + description="当前执行的步骤", + exclude=True, + default=None, + ) - async def load_state(self) -> None: - """从数据库中加载FlowExecutor的状态""" + async def init(self) -> None: + """初始化FlowExecutor""" logger.info("[FlowExecutor] 加载Executor状态") # 尝试恢复State - if self.task.state and self.task.state.flow_status != FlowStatus.INIT: + if self.task.state and self.task.state.flow_status != ExecutorStatus.INIT: self.task.context = await TaskManager.get_context_by_task_id(self.task.id) else: # 创建ExecutorState - self.task.state = ExecutorState( - flow_id=str(self.flow_id), - flow_name=self.flow.name, - flow_status=FlowStatus.RUNNING, + self.state = ExecutorHistory( + executorId=str(self.flow_id), + executorName=self.flow.name, + executorStatus=ExecutorStatus.RUNNING, description=str(self.flow.description), step_status=StepStatus.RUNNING, app_id=str(self.post_body_app.app_id), @@ -83,6 +88,8 @@ class FlowExecutor(BaseExecutor): step=self.current_step, background=self.background, question=self.question, + runtime=self.runtime, + state=self.state, ) # 初始化步骤 @@ -106,7 +113,7 @@ class FlowExecutor(BaseExecutor): await self._invoke_runner() - async def _find_next_id(self, step_id: str) -> list[str]: + async def _find_next_id(self, step_id: uuid.UUID) -> list[uuid.UUID]: """查找下一个节点""" next_ids = [] for edge in self.flow.edges: @@ -118,19 +125,19 @@ class FlowExecutor(BaseExecutor): async def _find_flow_next(self) -> list[StepQueueItem]: """在当前步骤执行前,尝试获取下一步""" # 如果当前步骤为结束,则直接返回 - if self.task.state.step_id == "end" or not self.task.state.step_id: # type: ignore[arg-type] + if self.state.stepId == "end" or not self.state.stepId: return [] if self.current_step.step.type == SpecialCallType.CHOICE.value: # 
如果是choice节点,获取分支ID branch_id = self.task.context[-1].output_data["branch_id"] if branch_id: - next_steps = await self._find_next_id(self.task.state.step_id + "." + branch_id) + next_steps = await self._find_next_id(self.state.stepId + "." + branch_id) logger.info("[FlowExecutor] 分支ID:%s", branch_id) else: logger.warning("[FlowExecutor] 没有找到分支ID,返回空列表") return [] else: - next_steps = await self._find_next_id(self.task.state.step_id) # type: ignore[arg-type] + next_steps = await self._find_next_id(self.state.stepId) # 如果step没有任何出边,直接跳到end if not next_steps: return [ @@ -162,14 +169,14 @@ # 获取首个步骤 first_step = StepQueueItem( - step_id=self.task.state.step_id, # type: ignore[arg-type] - step=self.flow.steps[self.task.state.step_id], # type: ignore[arg-type] + step_id=self.state.stepId, + step=self.flow.steps[self.state.stepId], ) # 头插开始前的系统步骤,并执行 for step in FIXED_STEPS_BEFORE_START: self.step_queue.append(StepQueueItem( - step_id=str(uuid.uuid4()), + step_id=uuid.uuid4(), step=step, enable_filling=False, to_user=False, @@ -178,13 +185,13 @@ # 插入首个步骤 self.step_queue.append(first_step) - self.task.state.flow_status = FlowStatus.RUNNING # type: ignore[arg-type] + self.state.executorStatus = ExecutorStatus.RUNNING # 运行Flow(未达终点) is_error = False while not self._reached_end: # 如果当前步骤出错,执行错误处理步骤 - if self.task.state.step_status == StepStatus.ERROR: # type: ignore[arg-type] + if self.state.stepStatus == StepStatus.ERROR: logger.warning("[FlowExecutor] Executor出错,执行错误处理步骤") self.step_queue.clear() self.step_queue.appendleft(StepQueueItem( @@ -197,7 +204,7 @@ params={ "user_prompt": LLM_ERROR_PROMPT.replace( "{{ error_info }}", - self.task.state.error_info["err_msg"], # type: ignore[arg-type] + self.state.errorMessage["err_msg"], ), }, ), @@ -221,20 +228,20 @@ # 更新Task状态 if is_error: - self.task.state.flow_status = FlowStatus.ERROR # type: 
ignore[arg-type] + self.state.executorStatus = ExecutorStatus.ERROR else: - self.task.state.flow_status = FlowStatus.SUCCESS # type: ignore[arg-type] + self.state.executorStatus = ExecutorStatus.SUCCESS # 尾插运行结束后的系统步骤 for step in FIXED_STEPS_AFTER_END: self.step_queue.append(StepQueueItem( - step_id=str(uuid.uuid4()), + step_id=uuid.uuid4(), step=step, )) await self._step_process() # FlowStop需要返回总时间,需要倒推最初的开始时间(当前时间减去当前已用总时间) - self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.full_time + self.runtime.time = round(datetime.now(UTC).timestamp(), 2) - self.runtime.fullTime # 推送Flow停止消息 if is_error: await self.push_message(EventType.FLOW_FAILED.value) diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 233c8999af5aa882b23b3f962a20b6a57a473f11..cda0b7082723979cde273cd01791f569d9d7fff0 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -6,10 +6,11 @@ import logging import uuid from collections.abc import AsyncGenerator from datetime import UTC, datetime -from typing import Any +from typing import TYPE_CHECKING, Any from pydantic import ConfigDict +from apps.models.task import ExecutorHistory from apps.scheduler.call.core import CoreCall from apps.scheduler.call.empty import Empty from apps.scheduler.call.facts.facts import FactsCall @@ -24,29 +25,27 @@ from apps.schemas.enum_var import ( ) from apps.schemas.message import TextAddContent from apps.schemas.scheduler import CallError, CallOutputChunk -from apps.schemas.task import FlowStepHistory, StepQueueItem from apps.services.node import NodeManager +from apps.services.task import TaskManager from .base import BaseExecutor +if TYPE_CHECKING: + from apps.schemas.task import StepQueueItem + logger = logging.getLogger(__name__) class StepExecutor(BaseExecutor): """工作流中步骤相关函数""" - step: StepQueueItem + step: "StepQueueItem" model_config = ConfigDict( arbitrary_types_allowed=True, extra="allow", ) - def __init__(self, 
**kwargs: Any) -> None: - """初始化""" - super().__init__(**kwargs) - self.validate_flow_state(self.task) - @staticmethod async def check_cls(call_cls: Any) -> bool: """检查Call是否符合标准要求""" @@ -87,8 +86,9 @@ class StepExecutor(BaseExecutor): logger.info("[StepExecutor] 初始化步骤 %s", self.step.step.name) # State写入ID和运行状态 - self.task.state.step_id = self.step.step_id # type: ignore[arg-type] - self.task.state.step_name = self.step.step.name # type: ignore[arg-type] + self.state.stepId = self.step.step_id + self.state.stepDescription = self.step.step.description + self.state.stepName = self.step.step.name # 获取并验证Call类 node_id = self.step.step.node @@ -128,14 +128,14 @@ class StepExecutor(BaseExecutor): return # 暂存旧数据 - current_step_id = self.task.state.step_id # type: ignore[arg-type] - current_step_name = self.task.state.step_name # type: ignore[arg-type] + current_step_id = self.state.stepId + current_step_name = self.state.stepName # 更新State - self.task.state.step_id = str(uuid.uuid4()) # type: ignore[arg-type] - self.task.state.step_name = "自动参数填充" # type: ignore[arg-type] - self.task.state.step_status = StepStatus.RUNNING # type: ignore[arg-type] - self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) + self.state.stepId = uuid.uuid4() + self.state.stepName = "自动参数填充" + self.state.stepStatus = StepStatus.RUNNING + self.runtime.time = round(datetime.now(UTC).timestamp(), 2) # 初始化填参 slot_obj = await Slot.instance( @@ -152,24 +152,30 @@ class StepExecutor(BaseExecutor): iterator = slot_obj.exec(self, slot_obj.input) async for chunk in iterator: result: SlotOutput = SlotOutput.model_validate(chunk.content) - self.task.tokens.input_tokens += slot_obj.tokens.input_tokens - self.task.tokens.output_tokens += slot_obj.tokens.output_tokens + await TaskManager.update_task_token( + self.task.id, + input_token=slot_obj.tokens.input_tokens, + output_token=slot_obj.tokens.output_tokens, + ) # 如果没有填全,则状态设置为待填参 if result.remaining_schema: - self.task.state.step_status = 
StepStatus.PARAM # type: ignore[arg-type] + self.state.stepStatus = StepStatus.PARAM else: - self.task.state.step_status = StepStatus.SUCCESS # type: ignore[arg-type] + self.state.stepStatus = StepStatus.SUCCESS await self.push_message(EventType.STEP_OUTPUT.value, result.model_dump(by_alias=True, exclude_none=True)) # 更新输入 self.obj.input.update(result.slot_data) # 恢复State - self.task.state.step_id = current_step_id # type: ignore[arg-type] - self.task.state.step_name = current_step_name # type: ignore[arg-type] - self.task.tokens.input_tokens += self.obj.tokens.input_tokens - self.task.tokens.output_tokens += self.obj.tokens.output_tokens + self.state.stepId = current_step_id + self.state.stepName = current_step_name + await TaskManager.update_task_token( + self.task.id, + input_token=self.obj.tokens.input_tokens, + output_token=self.obj.tokens.output_tokens, + ) async def _process_chunk( @@ -197,7 +203,7 @@ class StepExecutor(BaseExecutor): if to_user: if isinstance(chunk.content, str): await self.push_message(EventType.TEXT_ADD.value, chunk.content) - self.task.runtime.answer += chunk.content + self.runtime.fullAnswer += chunk.content else: await self.push_message(self.step.step.type, chunk.content) @@ -206,15 +212,14 @@ class StepExecutor(BaseExecutor): async def run(self) -> None: """运行单个步骤""" - self.validate_flow_state(self.task) logger.info("[StepExecutor] 运行步骤 %s", self.step.step.name) # 进行自动参数填充 await self._run_slot_filling() # 更新状态 - self.task.state.step_status = StepStatus.RUNNING # type: ignore[arg-type] - self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) + self.state.stepStatus = StepStatus.RUNNING + self.runtime.time = round(datetime.now(UTC).timestamp(), 2) # 推送输入 await self.push_message(EventType.STEP_INPUT.value, self.obj.input) @@ -225,25 +230,27 @@ class StepExecutor(BaseExecutor): content = await self._process_chunk(iterator, to_user=self.obj.to_user) except Exception as e: logger.exception("[StepExecutor] 运行步骤失败,进行异常处理步骤") - 
self.task.state.step_status = StepStatus.ERROR # type: ignore[arg-type] + self.state.stepStatus = StepStatus.ERROR await self.push_message(EventType.STEP_OUTPUT.value, {}) if isinstance(e, CallError): - self.task.state.error_info = { # type: ignore[arg-type] + self.state.errorMessage = { "err_msg": e.message, "data": e.data, } else: - self.task.state.error_info = { # type: ignore[arg-type] - "err_msg": str(e), + self.state.errorMessage = { + "err_msg": str(e), "data": {}, } return # 更新执行状态 - self.task.state.step_status = StepStatus.SUCCESS # type: ignore[arg-type] - self.task.tokens.input_tokens += self.obj.tokens.input_tokens - self.task.tokens.output_tokens += self.obj.tokens.output_tokens - self.task.tokens.full_time += round(datetime.now(UTC).timestamp(), 2) - self.task.tokens.time + self.state.stepStatus = StepStatus.SUCCESS + await TaskManager.update_task_token( + self.task.id, + input_token=self.obj.tokens.input_tokens, + output_token=self.obj.tokens.output_tokens, + ) + self.runtime.fullTime = round(datetime.now(UTC).timestamp(), 2) - self.runtime.time # 更新history if isinstance(content, str): @@ -252,16 +259,17 @@ output_data = content # 更新context - history = FlowStepHistory( - task_id=self.task.id, - flow_id=self.task.state.flow_id, # type: ignore[arg-type] - flow_name=self.task.state.flow_name, # type: ignore[arg-type] - step_id=self.step.step_id, - step_name=self.step.step.name, - step_description=self.step.step.description, - step_status=self.task.state.step_status, - input_data=self.obj.input, - output_data=output_data, + history = ExecutorHistory( + taskId=self.task.id, + executorId=self.state.executorId, + executorName=self.state.executorName, + executorStatus=self.state.executorStatus, + stepId=self.step.step_id, + stepName=self.step.step.name, + stepDescription=self.step.step.description, + stepStatus=self.state.stepStatus, + inputData=self.obj.input, + outputData=output_data, ) self.task.context.append(history) diff --git 
a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py index 3ccc088019bfbba85630ac4de612b3a1f35ee66e..bceeaa9d937461d1e2070d27ad17a588619c9326 100644 --- a/apps/scheduler/mcp/host.py +++ b/apps/scheduler/mcp/host.py @@ -16,7 +16,7 @@ from apps.models.task import ExecutorHistory from apps.scheduler.mcp.prompt import MEMORY_TEMPLATE from apps.scheduler.pool.mcp.client import MCPClient from apps.scheduler.pool.mcp.pool import MCPPool -from apps.schemas.enum_var import FlowStatus, StepStatus +from apps.schemas.enum_var import ExecutorStatus, StepStatus from apps.schemas.mcp import MCPPlanItem from apps.services.mcp_service import MCPServiceManager from apps.services.task import TaskManager @@ -102,7 +102,7 @@ class MCPHost: task_id=self._task_id, flow_id=self._runtime_id, flow_name=self._runtime_name, - flow_status=FlowStatus.RUNNING, + flow_status=ExecutorStatus.RUNNING, step_id=tool.id, step_name=tool.toolName, # description是规划的实际内容 diff --git a/apps/scheduler/pool/check.py b/apps/scheduler/pool/check.py index 2e1d5f0e571076d7732b1888ccbbc4e638dc606e..55f8d8c76783d5b4e9fc87b797071e2f142e3368 100644 --- a/apps/scheduler/pool/check.py +++ b/apps/scheduler/pool/check.py @@ -55,7 +55,7 @@ class FileChecker: return self.hashes[path_diff.as_posix()] != previous_hashes - async def diff(self, check_type: MetadataType) -> tuple[list[uuid.UUID], list[uuid.UUID]]: + async def diff(self, check_type: MetadataType) -> tuple[list[str], list[str]]: """生成更新列表和删除列表""" async with postgres.session() as session: # 判断类型 diff --git a/apps/scheduler/pool/loader/flow.py b/apps/scheduler/pool/loader/flow.py index fbcafa7960303461d7c8d91a4ff8e95b21a1011a..dc83faf1cbb6a645c5af733c40c47e8e5020de3d 100644 --- a/apps/scheduler/pool/loader/flow.py +++ b/apps/scheduler/pool/loader/flow.py @@ -59,7 +59,7 @@ class FlowLoader: @staticmethod - async def _process_edges(flow_yaml: dict[str, Any], flow_id: uuid.UUID, app_id: uuid.UUID) -> dict[str, Any]: + async def _process_edges(flow_yaml: dict[str, 
Any], flow_id: str, app_id: uuid.UUID) -> dict[str, Any]: """处理工作流边的转换""" logger.info("[FlowLoader] 应用 %s:解析工作流 %s 的边", flow_id, app_id) try: @@ -78,7 +78,7 @@ class FlowLoader: @staticmethod - async def _process_steps(flow_yaml: dict[str, Any], flow_id: uuid.UUID, app_id: uuid.UUID) -> dict[str, Any]: + async def _process_steps(flow_yaml: dict[str, Any], flow_id: str, app_id: uuid.UUID) -> dict[str, Any]: """处理工作流步骤的转换""" logger.info("[FlowLoader] 应用 %s:解析工作流 %s 的步骤", flow_id, app_id) for key, step in flow_yaml["steps"].items(): @@ -110,7 +110,7 @@ class FlowLoader: @staticmethod - async def load(app_id: uuid.UUID, flow_id: uuid.UUID) -> Flow: + async def load(app_id: uuid.UUID, flow_id: str) -> Flow: """从文件系统中加载【单个】工作流""" logger.info("[FlowLoader] 应用 %s:加载工作流 %s...", flow_id, app_id) @@ -157,7 +157,7 @@ class FlowLoader: @staticmethod - async def save(app_id: uuid.UUID, flow_id: uuid.UUID, flow: Flow) -> None: + async def save(app_id: uuid.UUID, flow_id: str, flow: Flow) -> None: """保存工作流""" flow_path = BASE_PATH / str(app_id) / "flow" / f"{flow_id}.yaml" if not await flow_path.parent.exists(): @@ -189,7 +189,7 @@ class FlowLoader: @staticmethod - async def delete(app_id: uuid.UUID, flow_id: uuid.UUID) -> None: + async def delete(app_id: uuid.UUID, flow_id: str) -> None: """删除指定工作流文件""" flow_path = BASE_PATH / str(app_id) / "flow" / f"{flow_id}.yaml" # 确保目标为文件且存在 diff --git a/apps/scheduler/pool/loader/metadata.py b/apps/scheduler/pool/loader/metadata.py index 6bf044dde4024dfde97818f646fe9c118ae8d9b3..809bac6dbfcca5b4a51955f13d671a593253f44e 100644 --- a/apps/scheduler/pool/loader/metadata.py +++ b/apps/scheduler/pool/loader/metadata.py @@ -80,7 +80,7 @@ class MetadataLoader: self, metadata_type: MetadataType, metadata: dict[str, Any] | AppMetadata | ServiceMetadata | AgentAppMetadata, - resource_id: uuid.UUID, + resource_id: uuid.UUID | str, ) -> None: """保存单个元数据""" class_dict = { diff --git a/apps/scheduler/pool/loader/openapi.py 
b/apps/scheduler/pool/loader/openapi.py index 09c1b08f3fd265d45c02ef9575c856c58aa23a8c..979775b26893cb52b8d5b64a022e5d6580dee819 100644 --- a/apps/scheduler/pool/loader/openapi.py +++ b/apps/scheduler/pool/loader/openapi.py @@ -2,7 +2,6 @@ """OpenAPI文档载入器""" import logging -import uuid from hashlib import shake_128 from typing import Any @@ -122,7 +121,7 @@ class OpenAPILoader: async def _process_spec( self, - service_id: uuid.UUID, + service_id: str, yaml_filename: str, spec: ReducedOpenAPISpec, server: str, @@ -170,7 +169,7 @@ class OpenAPILoader: """加载字典形式的OpenAPI文档""" spec = reduce_openapi_spec(yaml_dict) try: - await self._process_spec(uuid.UUID("00000000-0000-0000-0000-000000000000"), "temp.yaml", spec, spec.servers) + await self._process_spec("temp", "temp.yaml", spec, spec.servers) except Exception: err = "[OpenAPILoader] 处理OpenAPI文档失败" logger.exception(err) @@ -179,7 +178,7 @@ class OpenAPILoader: return spec - async def load_one(self, service_id: uuid.UUID, yaml_path: Path, server: str) -> list[NodeInfo]: + async def load_one(self, service_id: str, yaml_path: Path, server: str) -> list[NodeInfo]: """加载单个OpenAPI文档,可以直接指定路径""" try: spec = await self._read_yaml(yaml_path) diff --git a/apps/scheduler/pool/loader/service.py b/apps/scheduler/pool/loader/service.py index 802ec35050ae860727dedb2e354c277853330559..01cebd78d650621cc579ef90ce806dbadec549a1 100644 --- a/apps/scheduler/pool/loader/service.py +++ b/apps/scheduler/pool/loader/service.py @@ -3,7 +3,6 @@ import logging import shutil -import uuid from anyio import Path from sqlalchemy import delete @@ -28,7 +27,7 @@ class ServiceLoader: """Service 加载器""" @staticmethod - async def load(service_id: uuid.UUID, hashes: dict[str, str]) -> None: + async def load(service_id: str, hashes: dict[str, str]) -> None: """加载单个Service""" service_path = BASE_PATH / str(service_id) # 载入元数据 @@ -55,7 +54,7 @@ class ServiceLoader: @staticmethod - async def save(service_id: uuid.UUID, metadata: ServiceMetadata, data: dict) -> 
None: + async def save(service_id: str, metadata: ServiceMetadata, data: dict) -> None: """在文件系统上保存Service,并更新数据库""" service_path = BASE_PATH / str(service_id) # 创建文件夹 @@ -75,7 +74,7 @@ class ServiceLoader: @staticmethod - async def delete(service_id: uuid.UUID, *, is_reload: bool = False) -> None: + async def delete(service_id: str, *, is_reload: bool = False) -> None: """删除Service,并更新数据库""" async with postgres.session() as session: await session.execute(delete(Service).where(Service.id == service_id)) diff --git a/apps/scheduler/pool/pool.py b/apps/scheduler/pool/pool.py index 496b18daf6c1b526f7bc9c30adb3cd3920664740..fce4344e159f3d2d6a909df26d834b8c051521f1 100644 --- a/apps/scheduler/pool/pool.py +++ b/apps/scheduler/pool/pool.py @@ -111,19 +111,19 @@ class Pool: # 批量删除App for app in changed_app: - await AppLoader.delete(app, is_reload=True) + await AppLoader.delete(uuid.UUID(app), is_reload=True) for app in deleted_app: - await AppLoader.delete(app) + await AppLoader.delete(uuid.UUID(app)) # 批量加载App for app in changed_app: hash_key = Path("app/" + str(app)).as_posix() if hash_key in checker.hashes: try: - await AppLoader.load(app, checker.hashes[hash_key]) - except Exception as e: - await AppLoader.delete(app, is_reload=True) - logger.warning("[Pool] 加载App %s 失败: %s", app, e) + await AppLoader.load(uuid.UUID(app), checker.hashes[hash_key]) + except Exception as e: # noqa: BLE001 + await AppLoader.delete(uuid.UUID(app), is_reload=True) + logger.warning("[Pool] 加载App %s 失败: %s", app, str(e)) # 载入MCP logger.info("[Pool] 载入MCP") @@ -138,7 +138,7 @@ class Pool: )).all()) - async def get_flow(self, app_id: uuid.UUID, flow_id: uuid.UUID) -> Flow | None: + async def get_flow(self, app_id: uuid.UUID, flow_id: str) -> Flow | None: """从文件系统中获取单个Flow的全部数据""" logger.info("[Pool] 获取工作流 %s", flow_id) return await FlowLoader.load(app_id, flow_id) diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index 
a59819066cbc5041e2347e165f4029f91a0b1592..9d21433bad15700aec196cffa8bba14aa214ad39 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -9,7 +9,7 @@ from apps.common.config import config from apps.common.queue import MessageQueue from apps.models.document import Document from apps.models.task import Task -from apps.schemas.enum_var import EventType, FlowStatus +from apps.schemas.enum_var import EventType, ExecutorStatus from apps.schemas.message import ( DocumentAddContent, InitContent, @@ -73,10 +73,10 @@ async def push_rag_message( full_answer += content_obj.content elif content_obj.event_type == EventType.DOCUMENT_ADD.value: task.runtime.documents.append(content_obj.content) - task.state.flow_status = FlowStatus.SUCCESS + task.state.flow_status = ExecutorStatus.SUCCESS except Exception as e: logger.error(f"[Scheduler] RAG服务发生错误: {e}") - task.state.flow_status = FlowStatus.ERROR + task.state.flow_status = ExecutorStatus.ERROR # 保存答案 task.runtime.answer = full_answer task.tokens.full_time = round(datetime.now(UTC).timestamp(), 2) - task.tokens.time diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index 7a828bb415163945751a133efb361500832352b4..a291f3a76fef6184c3a8e51d316185b6c0542bf2 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -21,7 +21,7 @@ from apps.scheduler.scheduler.message import ( push_rag_message, ) from apps.schemas.config import LLMConfig -from apps.schemas.enum_var import AppType, EventType, FlowStatus +from apps.schemas.enum_var import AppType, EventType, ExecutorStatus from apps.schemas.rag_data import RAGQueryReq from apps.schemas.request_data import RequestData from apps.schemas.scheduler import ExecutorBackground @@ -153,21 +153,21 @@ class Scheduler: # 等待任一任务完成 done, pending = await asyncio.wait( [main_task, monitor], - return_when=asyncio.FIRST_COMPLETED + return_when=asyncio.FIRST_COMPLETED, ) # 如果是监控任务触发,终止主任务 
if kill_event.is_set(): logger.warning("[Scheduler] 用户活动状态检测不活跃,正在终止工作流执行...") main_task.cancel() - need_change_cancel_flow_state = [FlowStatus.RUNNING, FlowStatus.WAITING] + need_change_cancel_flow_state = [ExecutorStatus.RUNNING, ExecutorStatus.WAITING] if self.task.state.flow_status in need_change_cancel_flow_state: - self.task.state.flow_status = FlowStatus.CANCELLED + self.task.state.flow_status = ExecutorStatus.CANCELLED try: await main_task logger.info("[Scheduler] 工作流执行已被终止") - except Exception as e: - logger.error(f"[Scheduler] 终止工作流时发生错误: {e}") + except Exception: + logger.exception("[Scheduler] 终止工作流时发生错误") # 更新Task,发送结束消息 logger.info("[Scheduler] 发送结束消息") @@ -192,19 +192,7 @@ class Scheduler: if not app_metadata: logger.error("[Scheduler] 未找到Agent应用") return - if app_metadata.llm_id == "empty": - llm = LLM( - _id="empty", - user_sub=self.task.ids.user_sub, - openai_base_url=Config().get_config().llm.endpoint, - openai_api_key=Config().get_config().llm.key, - model_name=Config().get_config().llm.model, - max_tokens=Config().get_config().llm.max_tokens, - ) - else: - llm = await LLMManager.get_llm_by_id( - self.task.ids.user_sub, app_metadata.llm_id, - ) + llm = await LLMManager.get_llm(app_metadata.llm_id) if not llm: logger.error("[Scheduler] 获取大模型失败") await self.queue.close() @@ -217,10 +205,12 @@ class Scheduler: max_tokens=llm.max_tokens, ), ) - if background.conversation and self.task.state.flow_status == FlowStatus.INIT: + if background.conversation and self.task.state.flow_status == ExecutorStatus.INIT: try: question_obj = QuestionRewrite() - post_body.question = await question_obj.generate(history=background.conversation, question=post_body.question, llm=reasion_llm) + post_body.question = await question_obj.generate( + history=background.conversation, question=post_body.question, llm=reasion_llm, + ) except Exception: logger.exception("[Scheduler] 问题重写失败") if app_metadata.app_type == AppType.FLOW.value: @@ -266,19 +256,16 @@ class Scheduler: # 
开始运行 logger.info("[Scheduler] 运行Executor") - await flow_exec.load_state() + await flow_exec.init() await flow_exec.run() self.task = flow_exec.task elif app_metadata.app_type == AppType.AGENT.value: - # 获取agent中对应的MCP server信息 - servers_id = app_metadata.mcp_service # 初始化Executor agent_exec = MCPAgentExecutor( task=self.task, msg_queue=queue, question=post_body.question, history_len=app_metadata.history_len, - servers_id=servers_id, background=background, agent_id=app_info.app_id, params=post_body.params, diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py index 8ebc46e3c04240cddb9127ffcc7c3d77e8521183..4477bd59050c831fb1b685e40331bd6040938cfd 100644 --- a/apps/schemas/enum_var.py +++ b/apps/schemas/enum_var.py @@ -25,8 +25,8 @@ class StepStatus(str, Enum): CANCELLED = "cancelled" -class FlowStatus(str, Enum): - """Flow状态""" +class ExecutorStatus(str, Enum): + """执行器状态""" UNKNOWN = "unknown" INIT = "init" diff --git a/apps/schemas/flow.py b/apps/schemas/flow.py index 5cacb2ab3f01217270e0b6bb169e4c403aa290b4..0a1566de3d2e132e350ea92500b4238c3fea68f6 100644 --- a/apps/schemas/flow.py +++ b/apps/schemas/flow.py @@ -19,9 +19,9 @@ from .flow_topology import PositionItem class Edge(BaseModel): """Flow中Edge的数据""" - id: str = Field(description="边的ID") - edge_from: str = Field(description="边的来源节点ID") - edge_to: str = Field(description="边的目标节点ID") + id: uuid.UUID = Field(description="边的ID") + edge_from: uuid.UUID = Field(description="边的来源节点ID") + edge_to: uuid.UUID = Field(description="边的目标节点ID") edge_type: EdgeType | None = Field(description="边的类型", default=EdgeType.NORMAL) @@ -52,7 +52,7 @@ class Flow(BaseModel): focus_point: PositionItem | None = Field(description="当前焦点节点", default=PositionItem(x=0, y=0)) debug: bool = Field(description="是否经过调试", default=False) on_error: FlowError = FlowError(use_llm=True) - steps: dict[str, Step] = Field(description="节点列表", default={}) + steps: dict[uuid.UUID, Step] = Field(description="节点列表", default={}) edges: list[Edge] = 
Field(description="边列表", default=[]) @@ -113,6 +113,7 @@ class ServiceApiConfig(BaseModel): class ServiceMetadata(MetadataBase): """Service的元数据""" + id: str = Field(description="Service的ID") type: MetadataType = MetadataType.SERVICE api: ServiceApiConfig = Field(description="API配置") permission: Permission | None = Field(description="服务权限配置", default=None) diff --git a/apps/schemas/message.py b/apps/schemas/message.py index 1ece99ba27784c482d9be7109ffbef63b416a780..629915b21670ec98ad572f3d4cd2ffcf9b83e9bc 100644 --- a/apps/schemas/message.py +++ b/apps/schemas/message.py @@ -7,7 +7,7 @@ from typing import Any from pydantic import BaseModel, Field -from apps.schemas.enum_var import EventType, FlowStatus, StepStatus +from apps.schemas.enum_var import EventType, ExecutorStatus, StepStatus from .record import RecordMetadata @@ -33,7 +33,7 @@ class MessageFlow(BaseModel): app_id: str = Field(description="插件ID", alias="appId") flow_id: str = Field(description="Flow ID", alias="flowId") flow_name: str = Field(description="Flow名称", alias="flowName") - flow_status: FlowStatus = Field(description="Flow状态", alias="flowStatus", default=FlowStatus.UNKNOWN) + flow_status: ExecutorStatus = Field(description="Flow状态", alias="flowStatus", default=ExecutorStatus.UNKNOWN) step_id: str = Field(description="当前步骤ID", alias="stepId") step_name: str = Field(description="当前步骤名称", alias="stepName") sub_step_id: str | None = Field(description="当前子步骤ID", alias="subStepId", default=None) diff --git a/apps/schemas/record.py b/apps/schemas/record.py index 94b83c09c5d014eaffe2ba3ec603a2a69311d996..d8e237e763591df7ce5c26a01a92570e2ea791c6 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -7,7 +7,7 @@ from typing import Any, Literal from pydantic import BaseModel, Field -from apps.schemas.enum_var import CommentType, FlowStatus, StepStatus +from apps.schemas.enum_var import CommentType, ExecutorStatus, StepStatus class RecordDocument(BaseModel): @@ -111,7 +111,7 @@ class 
FlowHistory(BaseModel): flow_id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") flow_name: str = Field(default="", description="Flow名称") - flow_staus: FlowStatus = Field(default=FlowStatus.SUCCESS, description="Flow执行状态") + flow_staus: ExecutorStatus = Field(default=ExecutorStatus.SUCCESS, description="Flow执行状态") history_ids: list[str] = Field(default=[], description="Flow执行历史ID列表") diff --git a/apps/schemas/task.py b/apps/schemas/task.py index 0a337cfc8a5fa5130abf5604800d8ace6a8ed7ae..ad8244af9746460e7e37995c1bea0a2ca2bd799e 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """Task相关数据结构定义""" +import uuid from typing import Any from pydantic import BaseModel, Field @@ -23,7 +24,7 @@ class TaskExtra(BaseModel): class StepQueueItem(BaseModel): """步骤栈中的元素""" - step_id: str = Field(description="步骤ID") + step_id: uuid.UUID = Field(description="步骤ID") step: Step = Field(description="步骤") enable_filling: bool | None = Field(description="是否启用填充", default=None) to_user: bool | None = Field(description="是否输出给用户", default=None) diff --git a/apps/services/record.py b/apps/services/record.py index 9f81d926b53b9299bff7c58754517f46fe26f390..ae7885183a79d3db8fbd671db35cee122aa8e4a3 100644 --- a/apps/services/record.py +++ b/apps/services/record.py @@ -11,7 +11,7 @@ from sqlalchemy import and_, select from apps.common.postgres import postgres from apps.models.conversation import Conversation from apps.models.record import Record as PgRecord -from apps.schemas.enum_var import FlowStatus +from apps.schemas.enum_var import ExecutorStatus from apps.schemas.record import Record logger = logging.getLogger(__name__) @@ -117,8 +117,8 @@ class RecordManager: async def update_record_flow_status_to_cancelled_by_task_ids(task_ids: list[str]) -> None: """更新Record关联的Flow状态""" try: - {"records.task_id": {"$in": task_ids}, "records.flow.flow_status": {"$nin": 
[FlowStatus.ERROR.value, FlowStatus.SUCCESS.value]}}, - {"$set": {"records.$[elem].flow.flow_status": FlowStatus.CANCELLED}}, + {"records.task_id": {"$in": task_ids}, "records.flow.flow_status": {"$nin": [ExecutorStatus.ERROR.value, ExecutorStatus.SUCCESS.value]}}, + {"$set": {"records.$[elem].flow.flow_status": ExecutorStatus.CANCELLED}}, array_filters=[{"elem.flow.flow_id": {"$in": task_ids}}], ) except Exception: diff --git a/apps/services/task.py b/apps/services/task.py index 832b5743805605e14b0a8a165b9b9d3d37d79a57..bc009c335a30d1dfd565e0854590f752e04e9e47 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -4,9 +4,10 @@ import logging import uuid -from sqlalchemy import delete, select, update +from sqlalchemy import and_, delete, select, update from apps.common.postgres import postgres +from apps.models.conversation import Conversation from apps.models.task import ExecutorCheckpoint, ExecutorHistory, Task, TaskRuntime from apps.schemas.request_data import RequestData @@ -17,19 +18,32 @@ class TaskManager: """从数据库中获取任务信息""" @staticmethod - async def get_task_by_conversation_id(conversation_id: str) -> Task: + async def get_task_by_conversation_id(conversation_id: uuid.UUID, user_sub: str) -> Task: """获取对话ID的最后一个任务""" async with postgres.session() as session: + # 检查user_sub是否匹配Conversation + conversation = (await session.scalars( + select(Conversation.id).where( + and_( + Conversation.id == conversation_id, + Conversation.userSub == user_sub, + ), + ), + )).one_or_none() + if not conversation: + err = f"对话不存在或无权访问: {conversation_id}" + raise RuntimeError(err) + task = (await session.scalars( select(Task).where(Task.conversationId == conversation_id).order_by(Task.updatedAt.desc()).limit(1), )).one_or_none() if not task: # 任务不存在,新建Task - logger.info("[TaskManager] 新建任务", task_id) return Task( conversationId=conversation_id, - + userSub=user_sub, ) + logger.info("[TaskManager] 新建任务 %s", task.id) return task @@ -42,6 +56,24 @@ class TaskManager: 
)).one_or_none() + @staticmethod + async def get_task_runtime_by_task_id(task_id: uuid.UUID) -> TaskRuntime | None: + """根据task_id获取任务运行时""" + async with postgres.session() as session: + return (await session.scalars( + select(TaskRuntime).where(TaskRuntime.taskId == task_id), + )).one_or_none() + + + @staticmethod + async def get_task_state_by_task_id(task_id: uuid.UUID) -> ExecutorCheckpoint | None: + """根据task_id获取任务状态""" + async with postgres.session() as session: + return (await session.scalars( + select(ExecutorCheckpoint).where(ExecutorCheckpoint.taskId == task_id), + )).one_or_none() + + @staticmethod async def get_context_by_task_id(task_id: str, length: int | None = None) -> list[ExecutorHistory]: """根据task_id获取flow信息""" @@ -160,14 +192,25 @@ class TaskManager: @classmethod - async def update_task_token(cls, task_id: uuid.UUID, input_token: int, output_token: int) -> tuple[int, int]: + async def update_task_token( + cls, + task_id: uuid.UUID, + input_token: int, + output_token: int, + *, + override: bool = False, + ) -> tuple[int, int]: """更新任务的Token""" async with postgres.session() as session: sql = select(TaskRuntime.inputToken, TaskRuntime.outputToken).where(TaskRuntime.taskId == task_id) row = (await session.execute(sql)).one() - new_input_token = row.inputToken + input_token - new_output_token = row.outputToken + output_token + if override: + new_input_token = input_token + new_output_token = output_token + else: + new_input_token = row.inputToken + input_token + new_output_token = row.outputToken + output_token sql = update(TaskRuntime).where(TaskRuntime.taskId == task_id).values( inputToken=new_input_token, outputToken=new_output_token,