From d95f677fdfd555405fba8aafc49861e39d9e24d9 Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 31 Jul 2025 16:22:09 +0800 Subject: [PATCH 1/2] =?UTF-8?q?context=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/executor/flow.py | 2 +- apps/scheduler/scheduler/context.py | 2 +- apps/schemas/task.py | 2 +- apps/services/task.py | 5 +++-- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index 157c9159..d8400fdf 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -119,7 +119,7 @@ class FlowExecutor(BaseExecutor): return [] if self.current_step.step.type == SpecialCallType.CHOICE.value: # 如果是choice节点,获取分支ID - branch_id = self.task.context[-1]["output_data"]["branch_id"] + branch_id = self.task.context[-1].output_data["branch_id"] if branch_id: next_steps = await self._find_next_id(self.task.state.step_id + "." + branch_id) logger.info("[FlowExecutor] 分支ID:%s", branch_id) diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index 109a5456..72f64a71 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -193,7 +193,7 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None: flow_id=task.state.flow_id, flow_name=task.state.flow_name, flow_status=task.state.flow_status, - history_ids=[context["_id"] for context in task.context], + history_ids=[context.id for context in task.context], ) ) diff --git a/apps/schemas/task.py b/apps/schemas/task.py index e1901416..9e9f1531 100644 --- a/apps/schemas/task.py +++ b/apps/schemas/task.py @@ -93,7 +93,7 @@ class Task(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") ids: TaskIds = Field(description="任务涉及的各种ID") - context: list[dict[str, Any]] = Field(description="Flow的步骤执行信息", default=[]) + context: list[FlowStepHistory] = Field(description="Flow的步骤执行信息", default=[]) state: ExecutorState = Field(description="Flow的状态", default=ExecutorState()) tokens: TaskTokens = Field(description="Token信息") runtime: TaskRuntime = Field(description="任务运行时数据") diff --git a/apps/services/task.py b/apps/services/task.py index dceee324..0f237a3f 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -13,6 +13,7 @@ from apps.schemas.task import ( TaskIds, TaskRuntime, TaskTokens, + FlowStepHistory ) from apps.services.record import RecordManager @@ -67,7 +68,7 @@ class TaskManager: return Task.model_validate(task) @staticmethod - async def get_context_by_record_id(record_group_id: str, record_id: str) -> list[dict[str, Any]]: + async def get_context_by_record_id(record_group_id: str, record_id: str) -> FlowStepHistory: """根据record_group_id获取flow信息""" record_group_collection = MongoDB().get_collection("record_group") flow_context_collection = MongoDB().get_collection("flow_context") @@ -85,7 +86,7 @@ class TaskManager: for flow_context_id in records[0]["records"]["flow"]["history_ids"]: flow_context = await flow_context_collection.find_one({"_id": flow_context_id}) if flow_context: - flow_context_list.append(flow_context) + flow_context_list.append(FlowStepHistory.model_validate(flow_context)) except Exception: logger.exception("[TaskManager] 获取record_id的flow信息失败") return [] -- Gitee From ff332ab39f2e264f792bf5b08e959e3b67a2f89f Mon Sep 17 00:00:00 2001 From: zxstty Date: Thu, 31 Jul 2025 16:24:13 +0800 Subject: [PATCH 2/2] =?UTF-8?q?context=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/routers/record.py | 1 - apps/services/task.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/apps/routers/record.py b/apps/routers/record.py index 29ff68e0..f73e6c98 100644 --- a/apps/routers/record.py +++ b/apps/routers/record.py @@ -93,7 +93,6 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_ steps=[], ) for flow_step in flow_step_list: - flow_step = FlowStepHistory.model_validate(flow_step) tmp_record.flow.steps.append( RecordFlowStep( stepId=flow_step.step_name, # TODO: 此处前端应该用name diff --git a/apps/services/task.py b/apps/services/task.py index 0f237a3f..93604a93 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -68,7 +68,7 @@ class TaskManager: return Task.model_validate(task) @staticmethod - async def get_context_by_record_id(record_group_id: str, record_id: str) -> FlowStepHistory: + async def get_context_by_record_id(record_group_id: str, record_id: str) -> list[FlowStepHistory]: """根据record_group_id获取flow信息""" record_group_collection = MongoDB().get_collection("record_group") flow_context_collection = MongoDB().get_collection("flow_context") -- Gitee