diff --git a/apps/entities/record.py b/apps/entities/record.py index a0747ba22141a21a629f176ce65108c8de46f5ed..68ee72885038ae9f33dcc391794905b2b264ac61 100644 --- a/apps/entities/record.py +++ b/apps/entities/record.py @@ -13,7 +13,7 @@ from pydantic import BaseModel, Field from apps.entities.collection import ( Document, ) -from apps.entities.enum_var import StepStatus, CommentType +from apps.entities.enum_var import CommentType, StepStatus class RecordDocument(Document): @@ -86,7 +86,7 @@ class RecordData(BaseModel): flow: RecordFlow | None = None content: RecordContent metadata: RecordMetadata - comment: RecordComment= Field(default=RecordComment()) + comment: CommentType = Field(default=CommentType.NONE) created_at: float = Field(alias="createdAt") @@ -103,6 +103,7 @@ class Record(RecordData): user_sub: str key: dict[str, Any] = {} content: str + comment: RecordComment= Field(default=RecordComment()) flow: list[str] = Field(default=[]) diff --git a/apps/entities/scheduler.py b/apps/entities/scheduler.py index 215efe5f1ed89071a478d35b95ba210cec570906..b69cf160245ece2397d0c9961f71d997d2dc582a 100644 --- a/apps/entities/scheduler.py +++ b/apps/entities/scheduler.py @@ -35,6 +35,7 @@ class CallVars(BaseModel): summary: str = Field(description="上下文信息") question: str = Field(description="改写后的用户输入") history: dict[str, FlowStepHistory] = Field(description="Executor中历史工具的结构化数据", default={}) + history_order: list[str] = Field(description="Executor中历史工具的顺序", default=[]) ids: CallIds = Field(description="Call的ID") diff --git a/apps/entities/task.py b/apps/entities/task.py index 24a4f3d56e4136c7e0f688baff9d8b602e1cdc04..c68e5216336e3310c6396d062b793d5940707a1e 100644 --- a/apps/entities/task.py +++ b/apps/entities/task.py @@ -85,7 +85,7 @@ class Task(BaseModel): id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") ids: TaskIds = Field(description="任务涉及的各种ID") - context: dict[str, FlowStepHistory] = Field(description="Flow的步骤执行信息", default={}) + context: list[dict[str, Any]] = Field(description="Flow的步骤执行信息", default=[]) state: ExecutorState | None = Field(description="Flow的状态", default=None) tokens: TaskTokens = Field(description="Token信息") runtime: TaskRuntime = Field(description="任务运行时数据") diff --git a/apps/manager/blacklist.py b/apps/manager/blacklist.py index 57c53efc9cc0e6ba9c4f34e7f32dd16cd992d8ff..a1481caf55014403b3fe6f287a8321adbae9528c 100644 --- a/apps/manager/blacklist.py +++ b/apps/manager/blacklist.py @@ -177,7 +177,7 @@ class AbuseManager: [ {"$match": {"user_sub": user_sub}}, {"$unwind": "$records"}, - {"$match": {"records._id": record_id}}, + {"$match": {"records.id": record_id}}, {"$limit": 1}, ], ) diff --git a/apps/manager/comment.py b/apps/manager/comment.py index 63d55458855062c4bb7448420057a61285aec7e5..bd1d79de2acf9e6395f213b24baa1aada6050233 100644 --- a/apps/manager/comment.py +++ b/apps/manager/comment.py @@ -27,9 +27,9 @@ class CommentManager: record_group_collection = MongoDB.get_collection("record_group") result = await record_group_collection.aggregate( [ - {"$match": {"_id": group_id, "records._id": record_id}}, + {"$match": {"_id": group_id, "records.id": record_id}}, {"$unwind": "$records"}, - {"$match": {"records._id": record_id}}, + {"$match": {"records.id": record_id}}, {"$limit": 1}, ], ) @@ -52,7 +52,7 @@ class CommentManager: try: record_group_collection = MongoDB.get_collection("record_group") await record_group_collection.update_one( - {"_id": group_id, "records._id": record_id}, + {"_id": group_id, "records.id": record_id}, {"$set": {"records.$.comment": data.model_dump(by_alias=True)}}, ) except Exception: diff --git a/apps/manager/record.py b/apps/manager/record.py index ae4f7703bb30ced266de8710768375fe4eed9fe7..72d0bd045b383b5bad2fa5a66fd650350de87e06 100644 --- a/apps/manager/record.py +++ b/apps/manager/record.py @@ -5,7 +5,6 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ import logging -import uuid from typing import Literal from apps.entities.record import ( @@ -21,9 +20,8 @@ class RecordManager: """问答对相关操作""" @staticmethod - async def create_record_group(user_sub: str, conversation_id: str, task_id: str) -> str | None: + async def create_record_group(group_id: str, user_sub: str, conversation_id: str, task_id: str) -> str | None: """创建问答组""" - group_id = str(uuid.uuid4()) record_group_collection = MongoDB.get_collection("record_group") conversation_collection = MongoDB.get_collection("conversation") record_group = RecordGroup( @@ -146,13 +144,14 @@ class RecordManager: try: record_group_collection = MongoDB.get_collection("record_group") record_data = await record_group_collection.find_one( - {"_id": group_id, "user_sub": user_sub, "records._id": record_id}, + {"_id": group_id, "user_sub": user_sub, "records.id": record_id}, ) return bool(record_data) except Exception: logger.exception("[RecordManager] 验证记录是否在组中失败") return False + @staticmethod async def check_group_id(group_id: str, user_sub: str) -> bool: """检查group_id是否存在""" diff --git a/apps/manager/task.py b/apps/manager/task.py index ff3257295d1e23a3fdc1d4c9d0f1762b168451e1..6d6e4bb4594ee61bf3346f9eedbf46503e591c74 100644 --- a/apps/manager/task.py +++ b/apps/manager/task.py @@ -6,10 +6,16 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. import logging import uuid +from typing import Any from apps.entities.record import RecordGroup from apps.entities.request_data import RequestData -from apps.entities.task import FlowStepHistory, Task, TaskIds, TaskRuntime, TaskTokens +from apps.entities.task import ( + Task, + TaskIds, + TaskRuntime, + TaskTokens, +) from apps.manager.record import RecordManager from apps.models.mongo import MongoDB @@ -33,6 +39,7 @@ class TaskManager: }, ) + @staticmethod async def get_task_by_conversation_id(conversation_id: str) -> Task | None: """获取对话ID的最后一条问答组关联的任务""" @@ -73,15 +80,19 @@ class TaskManager: logger.exception("[TaskManager] 获取组ID的最后一条问答组关联的任务失败") return None + @staticmethod async def get_task_by_task_id(task_id: str) -> Task | None: """根据task_id获取任务""" task_collection = MongoDB.get_collection("task") task = await task_collection.find_one({"_id": task_id}) + if not task: + return None return Task.model_validate(task) + @staticmethod - async def get_flow_history_by_record_id(record_group_id: str, record_id: str) -> list[FlowStepHistory]: + async def get_context_by_record_id(record_group_id: str, record_id: str) -> list[dict[str, Any]]: """根据record_group_id获取flow信息""" record_group_collection = MongoDB.get_collection("record_group") flow_context_collection = MongoDB.get_collection("flow_context") @@ -99,7 +110,6 @@ class TaskManager: for flow_context_id in records[0]["records"]["flow"]: flow_context = await flow_context_collection.find_one({"_id": flow_context_id}) if flow_context: - flow_context = FlowStepHistory.model_validate(flow_context) flow_context_list.append(flow_context) except Exception: logger.exception("[TaskManager] 获取record_id的flow信息失败") @@ -109,33 +119,43 @@ class TaskManager: @staticmethod - async def get_flow_history_by_task_id(task_id: str, length: int = 0) -> dict[str, FlowStepHistory]: + async def get_context_by_task_id(task_id: str, length: int = 0) -> list[dict[str, Any]]: """根据task_id获取flow信息""" flow_context_collection = MongoDB.get_collection("flow_context") - flow_context = {} + flow_context = [] try: async for history in flow_context_collection.find( {"task_id": task_id}, ).sort( "created_at", -1, ).limit(length): - history_obj = FlowStepHistory.model_validate(history) - flow_context[history_obj.step_id] = history_obj + flow_context.extend(history) except Exception: logger.exception("[TaskManager] 获取task_id的flow信息失败") - return {} + return [] else: return flow_context @staticmethod - async def save_flow_context(flow_context: list[FlowStepHistory]) -> None: + async def save_flow_context(task_id: str, flow_context: list[dict[str, Any]]) -> None: """保存flow信息到flow_context""" flow_context_collection = MongoDB.get_collection("flow_context") try: - flow_context_str = [flow.model_dump(by_alias=True) for flow in flow_context] - await flow_context_collection.insert_many(flow_context_str) + for history in flow_context: + # 查找是否存在 + current_context = await flow_context_collection.find_one({ + "task_id": task_id, + "_id": history["_id"], + }) + if current_context: + await flow_context_collection.update_one( + {"_id": current_context["_id"]}, + {"$set": history}, + ) + else: + await flow_context_collection.insert_one(history) except Exception: logger.exception("[TaskManager] 保存flow执行记录失败") @@ -170,6 +190,7 @@ class TaskManager: except Exception: logger.exception("[TaskManager] 通过ConversationID删除Task信息失败") + @classmethod async def get_task( cls, @@ -199,7 +220,21 @@ class TaskManager: else: task = await cls.get_task_by_conversation_id(post_body.conversation_id) - return await cls.create_or_update_task(post_body, user_sub, session_id, task) + if task: + return task + return Task( + _id=str(uuid.uuid4()), + ids=TaskIds( + user_sub=user_sub if user_sub else "", + session_id=session_id if session_id else "", + conversation_id=post_body.conversation_id, + group_id=post_body.group_id if post_body.group_id else "", + ), + state=None, + tokens=TaskTokens(), + runtime=TaskRuntime(), + ) + @classmethod async def save_task(cls, task_id: str, task: Task) -> None: @@ -210,48 +245,5 @@ class TaskManager: await task_collection.update_one( {"_id": task_id}, {"$set": task.model_dump(by_alias=True, exclude_none=True)}, + upsert=True, ) - - @classmethod - async def create_or_update_task( - cls, - post_body: RequestData, - user_sub: str | None = None, - session_id: str | None = None, - task: Task | None = None, - ) -> Task: - """创建或更新任务块""" - # 创建新的Record - if not task: - # 任务不存在,新建Task - task_id = str(uuid.uuid4()) - - task = Task( - _id=task_id, - ids=TaskIds( - user_sub=user_sub if user_sub else "", - session_id=session_id if session_id else "", - conversation_id=post_body.conversation_id, - group_id=post_body.group_id if post_body.group_id else "", - ), - state=None, - tokens=TaskTokens(), - runtime=TaskRuntime(), - ) - - # 保存到MongoDB - task_collection = MongoDB.get_collection("task") - await task_collection.insert_one(task.model_dump(by_alias=True, exclude_none=True)) - return task - - # 任务存在,更新Task - task_id = task.id - - # 更新MongoDB - task_collection = MongoDB.get_collection("task") - await task_collection.update_one( - {"_id": task_id}, - {"$set": task.model_dump(by_alias=True, exclude_none=True)}, - ) - - return task diff --git a/apps/routers/chat.py b/apps/routers/chat.py index f4af2ba507330504d6d78df549bec883e7e96adf..727c9730d1d60de12cae166b4611c670e84109e3 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -41,13 +41,14 @@ router = APIRouter( async def init_task(post_body: RequestData, user_sub: str, session_id: str) -> str: """初始化Task""" # 生成group_id - group_id = str(uuid.uuid4()) if not post_body.group_id else post_body.group_id - post_body.group_id = group_id + if not post_body.group_id: + post_body.group_id = str(uuid.uuid4()) # 创建或还原Task task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub) task_id = task.id # 更改信息并刷新数据库 task.runtime.question = post_body.question + task.ids.group_id = post_body.group_id await TaskManager.save_task(task_id, task) return task_id diff --git a/apps/routers/comment.py b/apps/routers/comment.py index ccaada8b1409b2b768df8c466218c3ba007e4b4c..d7d368058bfb19118aa5c12aff17ed973eabd102 100644 --- a/apps/routers/comment.py +++ b/apps/routers/comment.py @@ -33,17 +33,17 @@ async def add_comment(post_body: AddCommentData, user_sub: Annotated[str, Depend """给Record添加评论""" if not await RecordManager.verify_record_in_group(post_body.group_id, post_body.record_id, user_sub): logger.error("[Comment] record_id 不存在") - return JSONResponse(status_code=status.HTTP_204_NO_CONTENT, content=ResponseData( - code=status.HTTP_204_NO_CONTENT, + return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content=ResponseData( + code=status.HTTP_400_BAD_REQUEST, message="record_id not found", result={}, ).model_dump(exclude_none=True, by_alias=True)) comment_data = RecordComment( comment=post_body.comment, - feedback_type=post_body.dislike_reason, - feedback_link=post_body.reason_link, - feedback_content=post_body.reason_description, + dislike_reason=post_body.dislike_reason, + reason_link=post_body.reason_link, + reason_description=post_body.reason_description, feedback_time=round(datetime.now(tz=UTC).timestamp(), 3), ) await CommentManager.update_comment(post_body.group_id, post_body.record_id, comment_data) diff --git a/apps/routers/record.py b/apps/routers/record.py index e3554b77fbf13b3aebee5742bca7e91bb8240a54..fe8c767b9fdee7a83ae5ed97e8a6ff75499bc075 100644 --- a/apps/routers/record.py +++ b/apps/routers/record.py @@ -77,7 +77,7 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_ outputTokens=0, timeCost=0, ), - comment=record.comment, + comment=record.comment.comment, createdAt=record.created_at, ) @@ -85,7 +85,7 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_ tmp_record.document = await DocumentManager.get_used_docs_by_record_group(user_sub, record_group.id) # 获得Record关联的flow数据 - flow_list = await TaskManager.get_flow_history_by_record_id(record_group.id, record.id) + flow_list = await TaskManager.get_context_by_record_id(record_group.id, record.id) if flow_list: tmp_record.flow = RecordFlow( id=flow_list[0].id, diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index f7e09cea06595c7d74813dbf446423c53259e345..5276728e1ce6963bd4729f12fd00f8148ded5c6a 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -20,9 +20,12 @@ from apps.entities.scheduler import ( CallOutputChunk, CallVars, ) +from apps.entities.task import FlowStepHistory if TYPE_CHECKING: from apps.scheduler.executor.step import StepExecutor + + logger = logging.getLogger(__name__) @@ -87,6 +90,13 @@ class CoreCall(BaseModel): logger.error(err) raise ValueError(err) + history = {} + history_order = [] + for item in executor.task.context: + item_obj = FlowStepHistory.model_validate(item) + history[item_obj.step_id] = item_obj + history_order.append(item_obj.step_id) + return CallVars( ids=CallIds( task_id=executor.task.id, @@ -96,7 +106,8 @@ class CoreCall(BaseModel): app_id=executor.task.state.app_id, ), question=executor.question, - history=executor.task.context, + history=history, + history_order=history_order, summary=executor.task.runtime.summary, ) diff --git a/apps/scheduler/call/graph/__init__.py b/apps/scheduler/call/graph/__init__.py index e16980413123da6f6e11e165e9273e41a6ccb92d..2a4869835e16717864af645b08f65b807d0ac17d 100644 --- a/apps/scheduler/call/graph/__init__.py +++ b/apps/scheduler/call/graph/__init__.py @@ -1,4 +1,5 @@ -"""Render工具 +""" +Render工具 用于生成ECharts图表。 """ diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py index 8c455d338742717181382f9b33a730408ce5f2b2..aae19569c20f29462f799a9f4aa8429a8428f4c9 100644 --- a/apps/scheduler/call/slot/slot.py +++ b/apps/scheduler/call/slot/slot.py @@ -10,7 +10,6 @@ from apps.entities.enum_var import CallOutputType from apps.entities.pool import NodePool from apps.entities.scheduler import CallInfo, CallOutputChunk, CallVars from apps.llm.patterns.json_gen import Json -from apps.manager.task import TaskManager from apps.scheduler.call.core import CoreCall from apps.scheduler.call.slot.schema import SlotInput, SlotOutput from apps.scheduler.slot.slot import Slot as SlotProcessor @@ -71,7 +70,9 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput): async def _init(self, call_vars: CallVars) -> SlotInput: """初始化""" self._task_id = call_vars.ids.task_id - self._flow_history = await TaskManager.get_flow_history_by_task_id(self._task_id, self.step_num) + self._flow_history = {} + for key in call_vars.history_order[:-self.step_num]: + self._flow_history[key] = call_vars.history[key] if not self.current_schema: return SlotInput( diff --git a/apps/scheduler/executor/base.py b/apps/scheduler/executor/base.py index 7feb09b0f0ae1c03a01dbb305793c6f6280cb34c..3d9c8e56ad620c450aec78e099e7f804b7085cd2 100644 --- a/apps/scheduler/executor/base.py +++ b/apps/scheduler/executor/base.py @@ -10,7 +10,7 @@ from apps.common.queue import MessageQueue from apps.entities.enum_var import EventType from apps.entities.message import FlowStartContent, TextAddContent from apps.entities.scheduler import ExecutorBackground -from apps.entities.task import FlowStepHistory, Task +from apps.entities.task import Task logger = logging.getLogger(__name__) @@ -22,7 +22,6 @@ class BaseExecutor(BaseModel): msg_queue: MessageQueue background: ExecutorBackground question: str - history: FlowStepHistory | None = None model_config = ConfigDict( arbitrary_types_allowed=True, @@ -37,7 +36,7 @@ class BaseExecutor(BaseModel): logger.error(err) raise ValueError(err) - async def push_message(self, event_type: str, data: dict[str, Any] | str | None = None) -> None: # noqa: C901 + async def push_message(self, event_type: str, data: dict[str, Any] | str | None = None) -> None: """ 统一的消息推送接口 @@ -54,24 +53,11 @@ class BaseExecutor(BaseModel): elif event_type == EventType.FLOW_STOP.value: data = {} elif event_type == EventType.STEP_INPUT.value and isinstance(data, dict): - # 更新step_history的输入数据 - if self.history is not None: - self.history.input_data = data - if self.task.state is not None: - self.history.status = self.task.state.status - # 步骤开始,重置时间 - self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) - elif event_type == EventType.STEP_OUTPUT.value and isinstance(data, dict): - # 更新step_history的输出数据 - if self.history is not None: - self.history.output_data = data - if self.task.state is not None: - self.history.status = self.task.state.status - self.task.context[self.task.state.step_id] = self.history + # 步骤开始,重置时间 + self.task.tokens.time = round(datetime.now(UTC).timestamp(), 2) elif event_type == EventType.TEXT_ADD.value and isinstance(data, str): data=TextAddContent(text=data).model_dump(exclude_none=True, by_alias=True) - if data is None: data = {} elif isinstance(data, str): diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index c36b1eb5cc421186140ea723d76122f3cd59de0a..0ba3a93efe48fe3eb7e3526895c7aa537a8a4a20 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -59,6 +59,7 @@ class FlowExecutor(BaseExecutor): question: str = Field(description="用户输入") post_body_app: RequestDataApp = Field(description="请求体中的app信息") + """Pydantic配置""" async def load_state(self) -> None: @@ -66,7 +67,7 @@ class FlowExecutor(BaseExecutor): logger.info("[FlowExecutor] 加载Executor状态") # 尝试恢复State if self.task.state: - self.task.context = await TaskManager.get_flow_history_by_task_id(self.task.id) + self.task.context = await TaskManager.get_context_by_task_id(self.task.id) else: # 创建ExecutorState self.task.state = ExecutorState( @@ -93,7 +94,6 @@ class FlowExecutor(BaseExecutor): step=queue_item, background=self.background, question=self.question, - history=self.history, ) # 初始化步骤 @@ -101,9 +101,8 @@ class FlowExecutor(BaseExecutor): # 运行Step await step_runner.run_step() - # 更新Task + # 更新Task(已存过库) self.task = step_runner.task - await TaskManager.save_task(self.task.id, self.task) async def _step_process(self) -> None: diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py index 63ec8d029654a2453e0d9d641be1e6e1511aa2e5..6b584ddfbc0d4461aa6d90ef73c259ec5706ce58 100644 --- a/apps/scheduler/executor/step.py +++ b/apps/scheduler/executor/step.py @@ -29,7 +29,6 @@ class StepExecutor(BaseExecutor): step: StepQueueItem - model_config = ConfigDict( arbitrary_types_allowed=True, extra="allow", @@ -40,7 +39,7 @@ class StepExecutor(BaseExecutor): super().__init__(**kwargs) self.validate_flow_state(self.task) - self.history = FlowStepHistory( + self._history = FlowStepHistory( task_id=self.task.id, flow_id=self.task.state.flow_id, # type: ignore[arg-type] step_id=self.task.state.step_id, # type: ignore[arg-type] @@ -49,6 +48,7 @@ class StepExecutor(BaseExecutor): output_data={}, ) + async def init(self) -> None: """初始化步骤""" self.validate_flow_state(self.task) @@ -58,7 +58,6 @@ class StepExecutor(BaseExecutor): # State写入ID和运行状态 self.task.state.step_id = self.step.step_id # type: ignore[arg-type] self.task.state.step_name = self.step.step.name # type: ignore[arg-type] - self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] await TaskManager.save_task(self.task.id, self.task) # 获取并验证Call类 @@ -93,7 +92,7 @@ class StepExecutor(BaseExecutor): async def _run_slot_filling(self) -> None: - """运行自动参数填充""" + """运行自动参数填充;相当于特殊Step,但是不存库""" # 判断State是否为空 self.validate_flow_state(self.task) @@ -128,11 +127,20 @@ class StepExecutor(BaseExecutor): async for chunk in iterator: result: SlotOutput = SlotOutput.model_validate(chunk.content) - # 推送填参结果 - self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] + # 如果没有填全 + if result.remaining_schema: + # 状态设置为待填参 + self.task.state.status = StepStatus.PARAM # type: ignore[arg-type] + else: + # 推送填参结果 + self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] await TaskManager.save_task(self.task.id, self.task) await self.push_message(EventType.STEP_OUTPUT.value, result.slot_data) + # 没完整填参,则返回 + if self.task.state.status == StepStatus.PARAM: # type: ignore[arg-type] + return + # 更新输入 self.obj.input.update(result.slot_data) @@ -183,6 +191,13 @@ class StepExecutor(BaseExecutor): # 进行自动参数填充 await self._run_slot_filling() + # 更新history + self._history.input_data = self.obj.input + self._history.output_data = {} + # 更新状态 + self.task.state.status = StepStatus.RUNNING # type: ignore[arg-type] + self._history.status = self.task.state.status # type: ignore[arg-type] + await TaskManager.save_task(self.task.id, self.task) # 推送输入 await self.push_message(EventType.STEP_INPUT.value, self.obj.input) @@ -200,9 +215,17 @@ class StepExecutor(BaseExecutor): # 更新执行状态 self.task.state.status = StepStatus.SUCCESS # type: ignore[arg-type] + self._history.status = self.task.state.status # type: ignore[arg-type] + # 更新history if isinstance(content, str): - self.content = TextAddContent(text=content).model_dump(exclude_none=True, by_alias=True) + self._history.output_data = TextAddContent(text=content).model_dump(exclude_none=True, by_alias=True) else: - self.content = content - await self.push_message(EventType.STEP_OUTPUT.value, self.content) + self._history.output_data = content + + # 更新context + self.task.context.append(self._history.model_dump(exclude_none=True, by_alias=True)) + await TaskManager.save_task(self.task.id, self.task) + + # 推送输出 + await self.push_message(EventType.STEP_OUTPUT.value, self._history.output_data) diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index 1a4de8aff7a5716f9eace196317869a340e7328e..4d52672de11c44125649e38218cfc95126acf387 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -8,6 +8,7 @@ from datetime import UTC, datetime from apps.common.security import Security from apps.entities.collection import Document +from apps.entities.enum_var import StepStatus from apps.entities.record import ( Record, RecordContent, @@ -105,12 +106,7 @@ async def save_data(task_id: str, user_sub: str, post_body: RequestData, used_do # 保存Flow信息 if task.state: # 遍历查找数据,并添加 - for history_id in task.context: - pass - # await TaskManager.save_flow_context(history_data) - - # 修改metadata里面时间为实际运行时间 - task.tokens.time = round(datetime.now(UTC).timestamp() - task.tokens.time, 2) + await TaskManager.save_flow_context(task.id, task.context) # 整理Record数据 current_time = round(datetime.now(UTC).timestamp(), 2) @@ -129,16 +125,19 @@ async def save_data(task_id: str, user_sub: str, post_body: RequestData, used_do feature={}, ), createdAt=current_time, - flow=[i.id for i in task.context.values()], + flow=[i["_id"] for i in task.context], ) - record_group = task.ids.group_id # 检查是否存在group_id - if not await RecordManager.check_group_id(record_group, user_sub): - record_group = await RecordManager.create_record_group(user_sub, post_body.conversation_id, task.id) + if not await RecordManager.check_group_id(task.ids.group_id, user_sub): + record_group = await RecordManager.create_record_group( + task.ids.group_id, user_sub, post_body.conversation_id, task.id, + ) if not record_group: logger.error("[Scheduler] 创建问答组失败") return + else: + record_group = task.ids.group_id # 修改文件状态 await DocumentManager.change_doc_status(user_sub, post_body.conversation_id, record_group) @@ -151,5 +150,9 @@ async def save_data(task_id: str, user_sub: str, post_body: RequestData, used_do # 更新最近使用的应用 await AppCenterManager.update_recent_app(user_sub, post_body.app.app_id) - # 保存Task - await TaskManager.save_task(task_id, task) + # 若状态为成功,删除Task + if task.state and task.state.status == StepStatus.SUCCESS: + await TaskManager.delete_task_by_task_id(task_id) + else: + # 更新Task + await TaskManager.save_task(task_id, task)