diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index 44425bc6c9716385a92f4e4ea0cce6a730b067f1..d87932bd4afb866d218717f5dd51f6b98fba9ee0 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -115,9 +115,9 @@ class FlowExecutor(BaseExecutor): return [] if self.current_step.step.type == SpecialCallType.CHOICE.value: # 如果是choice节点,获取分支ID - branch_id = self.task.context[-1]["output_data"]["branch_id"] + branch_id = self.task.context[-1].output_data.get("branch_id", "") if branch_id: - next_steps = self.flow.steps.get(self.task.state.step_id+"."+branch_id) + next_steps = await self._find_next_id(self.task.state.step_id+"."+branch_id) logger.info("[FlowExecutor] 分支ID:%s", branch_id) else: logger.warning("[FlowExecutor] 没有找到分支ID,返回空列表") diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index 4c2c4cf0b046cb7fbe1b4ddca6dc93f55d8aabb0..1e936ab1357272c446c29da5b5e1823f53dce1ea 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -188,7 +188,7 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: feature={}, ), createdAt=current_time, - flow=[i["_id"] for i in task.context], + flow=[context.id for context in task.context], ) # 检查是否存在group_id diff --git a/apps/schemas/task.py b/apps/schemas/task.py index 37fdebbfb04d7eec41766eb8618026514935b321..3a37126053e693226da63a07283832c9cf3c0fae 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -89,7 +89,7 @@ class Task(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") ids: TaskIds = Field(description="任务涉及的各种ID") - context: list[dict[str, Any]] = Field(description="Flow的步骤执行信息", default=[]) + context: list[FlowStepHistory] = Field(description="Flow的步骤执行信息", default=[]) state: ExecutorState | None = Field(description="Flow的状态", default=None) tokens: TaskTokens = Field(description="Token信息") runtime: TaskRuntime = Field(description="任务运行时数据") diff --git a/apps/services/task.py b/apps/services/task.py index 2456d96bffb4291dd4b77ea559251450e42bed50..81924c58df5de55997c7f236d8f3f15dc93a5607 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -13,6 +13,7 @@ from apps.schemas.task import ( TaskIds, TaskRuntime, TaskTokens, + FlowStepHistory ) from apps.services.record import RecordManager @@ -67,7 +68,7 @@ class TaskManager: return Task.model_validate(task) @staticmethod - async def get_context_by_record_id(record_group_id: str, record_id: str) -> list[dict[str, Any]]: + async def get_context_by_record_id(record_group_id: str, record_id: str) -> list[FlowStepHistory]: """根据record_group_id获取flow信息""" record_group_collection = MongoDB().get_collection("record_group") flow_context_collection = MongoDB().get_collection("flow_context") @@ -85,7 +86,7 @@ class TaskManager: for flow_context_id in records[0]["records"]["flow"]: flow_context = await flow_context_collection.find_one({"_id": flow_context_id}) if flow_context: - flow_context_list.append(flow_context) + flow_context_list.append(FlowStepHistory.model_validate(flow_context)) except Exception: logger.exception("[TaskManager] 获取record_id的flow信息失败") return [] @@ -93,7 +94,7 @@ class TaskManager: return flow_context_list @staticmethod - async def get_context_by_task_id(task_id: str, length: int = 0) -> list[dict[str, Any]]: + async def get_context_by_task_id(task_id: str, length: int = 0) -> list[F]: """根据task_id获取flow信息""" flow_context_collection = MongoDB().get_collection("flow_context")