From 2b795d4d6581fcef8f7c349da0268b7c4455f3df Mon Sep 17 00:00:00 2001 From: z30057876 Date: Tue, 12 Aug 2025 20:07:46 +0800 Subject: [PATCH] =?UTF-8?q?=E8=BF=81=E7=A7=BB=E6=95=B0=E6=8D=AE=E5=BA=93?= =?UTF-8?q?=E3=80=81=E4=BF=AE=E6=AD=A3Task=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/queue.py | 2 +- apps/constants.py | 2 +- apps/llm/prompt.py | 2 +- apps/models/task.py | 5 +- apps/routers/appcenter.py | 1 - apps/routers/auth.py | 2 +- apps/routers/chat.py | 4 -- apps/routers/parameter.py | 30 ++++---- apps/routers/record.py | 1 + apps/scheduler/call/choice/choice.py | 6 +- apps/scheduler/call/core.py | 10 +-- apps/scheduler/call/llm/llm.py | 1 - apps/scheduler/call/slot/slot.py | 2 - apps/scheduler/call/suggest/suggest.py | 5 +- apps/scheduler/executor/agent.py | 82 +++++++++++++--------- apps/scheduler/executor/base.py | 11 +-- apps/scheduler/mcp_agent/host.py | 2 +- apps/scheduler/pool/loader/app.py | 27 +++++--- apps/scheduler/pool/loader/flow.py | 12 +++- apps/scheduler/scheduler/context.py | 22 ++---- apps/scheduler/scheduler/flow.py | 15 ++-- apps/scheduler/scheduler/message.py | 2 +- apps/scheduler/scheduler/scheduler.py | 80 ++++++++-------------- apps/schemas/message.py | 5 +- apps/schemas/request_data.py | 4 +- apps/services/appcenter.py | 59 +++++++++------- apps/services/conversation.py | 4 +- apps/services/flow.py | 38 +++++++---- apps/services/mcp_service.py | 11 ++- apps/services/parameter.py | 3 +- apps/services/task.py | 95 +++++++++++++------------- 31 files changed, 275 insertions(+), 270 deletions(-) diff --git a/apps/common/queue.py b/apps/common/queue.py index cba88fe5..bcda8aab 100644 --- a/apps/common/queue.py +++ b/apps/common/queue.py @@ -8,6 +8,7 @@ from collections.abc import AsyncGenerator from datetime import UTC, datetime from typing import Any +from apps.models.task import Task from apps.schemas.enum_var import EventType from apps.schemas.message import ( HeartbeatData, @@ -15,7 +16,6 @@ from apps.schemas.message import ( MessageFlow, MessageMetadata, ) -from apps.schemas.task import Task logger = logging.getLogger(__name__) diff --git a/apps/constants.py b/apps/constants.py index 1407f561..1303e3d6 100644 --- a/apps/constants.py +++ b/apps/constants.py @@ -11,7 +11,7 @@ from .common.config import config # 新对话默认标题 NEW_CHAT = "新对话" # 滑动窗口限流 默认窗口期 -SLIDE_WINDOW_TIME = 600 +SLIDE_WINDOW_TIME = 15 # OIDC 访问Token 过期时间(分钟) OIDC_ACCESS_TOKEN_EXPIRE_TIME = 30 # OIDC 刷新Token 过期时间(分钟) diff --git a/apps/llm/prompt.py b/apps/llm/prompt.py index 7cf555be..7702f6d0 100644 --- a/apps/llm/prompt.py +++ b/apps/llm/prompt.py @@ -19,7 +19,7 @@ JSON_GEN_BASIC = dedent(r""" Background information is given in XML tags. - Here are the conversations between you and the user: + Here are the background information between you and the user: {% if conversation|length > 0 %} {% for message in conversation %} diff --git a/apps/models/task.py b/apps/models/task.py index 930a141b..e089fc9b 100644 --- a/apps/models/task.py +++ b/apps/models/task.py @@ -23,7 +23,10 @@ class Task(Base): UUID(as_uuid=True), ForeignKey("framework_conversation.id"), nullable=False, ) """对话ID""" - checkpointId: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("framework_executor_checkpoint.id")) # noqa: N815 + checkpointId: Mapped[uuid.UUID | None] = mapped_column( # noqa: N815 + UUID(as_uuid=True), ForeignKey("framework_executor_checkpoint.id"), + nullable=True, default=None, + ) """检查点ID""" id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default_factory=uuid.uuid4) """任务ID""" diff --git a/apps/routers/appcenter.py b/apps/routers/appcenter.py index e60981c9..90629179 100644 --- a/apps/routers/appcenter.py +++ b/apps/routers/appcenter.py @@ -14,7 +14,6 @@ from apps.schemas.appcenter import AppFlowInfo, AppPermissionData from apps.schemas.enum_var import AppFilterType, AppType from apps.schemas.request_data import ChangeFavouriteAppRequest, CreateAppRequest from apps.schemas.response_data import ( - AppMcpServiceInfo, BaseAppOperationMsg, BaseAppOperationRsp, ChangeFavouriteAppMsg, diff --git a/apps/routers/auth.py b/apps/routers/auth.py index 10b895b9..ef2f5aa5 100644 --- a/apps/routers/auth.py +++ b/apps/routers/auth.py @@ -164,7 +164,7 @@ async def userinfo(request: Request) -> JSONResponse: user_sub=request.state.user_sub, revision=user.isActive, is_admin=is_admin, - auto_execute=user.autoExecute, + auto_execute=user.autoExecute or False, ), ).model_dump(exclude_none=True, by_alias=True), ) diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 31b381dd..033156a9 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -144,10 +144,6 @@ async def chat(request: Request, post_body: RequestData) -> StreamingResponse: await UserBlacklistManager.change_blacklisted_users(user_sub, -10) raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="question is blacklisted") - # 限流检查 - if await Activity.is_active(user_sub): - raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many requests") - res = chat_generator(post_body, user_sub, session_id) return StreamingResponse( content=res, diff --git a/apps/routers/parameter.py b/apps/routers/parameter.py index 833e0cab..a409e12e 100644 --- a/apps/routers/parameter.py +++ b/apps/routers/parameter.py @@ -1,10 +1,11 @@ -from typing import Annotated +"""步骤参数相关路由""" -from fastapi import APIRouter, Depends, Query, status +import uuid + +from fastapi import APIRouter, Depends, Request, status from fastapi.responses import JSONResponse -from apps.dependency import get_user -from apps.dependency.user import verify_user +from apps.dependency.user import verify_personal_token, verify_session from apps.schemas.response_data import GetOperaRsp, GetParamsRsp from apps.services.appcenter import AppCenterManager from apps.services.flow import FlowManager @@ -14,20 +15,18 @@ router = APIRouter( prefix="/api/parameter", tags=["parameter"], dependencies=[ - Depends(verify_user), + Depends(verify_personal_token), + Depends(verify_session), ], ) @router.get("", response_model=GetParamsRsp) async def get_parameters( - user_sub: Annotated[str, Depends(get_user)], - app_id: Annotated[str, Query(alias="appId")], - flow_id: Annotated[str, Query(alias="flowId")], - step_id: Annotated[str, Query(alias="stepId")], + request: Request, appId: uuid.UUID, flowId: uuid.UUID, stepId: uuid.UUID, # noqa: N803 ) -> JSONResponse: """Get parameters for node choice.""" - if not await AppManager.validate_user_app_access(user_sub, app_id): + if not await AppCenterManager.validate_user_app_access(request.state.user_sub, appId): return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, content=GetParamsRsp( @@ -36,7 +35,7 @@ async def get_parameters( result=[], ).model_dump(exclude_none=True, by_alias=True), ) - flow = await FlowManager.get_flow_by_app_and_flow_id(app_id, flow_id) + flow = await FlowManager.get_flow_by_app_and_flow_id(appId, flowId) if not flow: return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, @@ -46,7 +45,7 @@ async def get_parameters( result=[], ).model_dump(exclude_none=True, by_alias=True), ) - result = await ParameterManager.get_pre_params_by_flow_and_step_id(flow, step_id) + result = await ParameterManager.get_pre_params_by_flow_and_step_id(flow, stepId) return JSONResponse( status_code=status.HTTP_200_OK, content=GetParamsRsp( @@ -58,12 +57,9 @@ async def get_parameters( @router.get("/operate", response_model=GetOperaRsp) -async def get_operate_parameters( - user_sub: Annotated[str, Depends(get_user)], - param_type: Annotated[str, Query(alias="ParamType")], -) -> JSONResponse: +async def get_operate_parameters(paramType: str) -> JSONResponse: # noqa: N803 """Get parameters for node choice.""" - result = await ParameterManager.get_operate_and_bind_type(param_type) + result = await ParameterManager.get_operate_and_bind_type(paramType) return JSONResponse( status_code=status.HTTP_200_OK, content=GetOperaRsp( diff --git a/apps/routers/record.py b/apps/routers/record.py index 81fa3003..9277b541 100644 --- a/apps/routers/record.py +++ b/apps/routers/record.py @@ -103,6 +103,7 @@ async def get_record(request: Request, conversationId: Annotated[uuid.UUID, Path stepStatus=flow_step.step_status, input=flow_step.input_data, output=flow_step.output_data, + exData=flow_step.ex_data, ), ) diff --git a/apps/scheduler/call/choice/choice.py b/apps/scheduler/call/choice/choice.py index f294c9cb..928f27ee 100644 --- a/apps/scheduler/call/choice/choice.py +++ b/apps/scheduler/call/choice/choice.py @@ -19,7 +19,6 @@ from apps.scheduler.call.choice.schema import ( ) from apps.scheduler.call.core import CoreCall from apps.schemas.enum_var import CallOutputType -from apps.schemas.parameters import Type from apps.schemas.scheduler import ( CallError, CallInfo, @@ -110,10 +109,7 @@ class Choice(CoreCall, input_model=ChoiceInput, output_model=ChoiceOutput): f"右值类型不匹配:{condition.right.value} 应为 {right_value_type.value}") logger.warning(msg) continue - if right_value_type == Type.STRING: - condition.right.value = str(condition.right.value) - else: - condition.right.value = ast.literal_eval(condition.right.value) + condition.right.value = ast.literal_eval(str(condition.right.value)) if not ConditionHandler.check_value_type( condition.right, condition.right.type, ): diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index ec51c45f..ed58a88f 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -14,7 +14,9 @@ from pydantic.json_schema import SkipJsonSchema from apps.llm.function import FunctionLLM from apps.llm.reasoning import ReasoningLLM +from apps.models.llm import LLMData from apps.models.node import NodeInfo +from apps.models.task import ExecutorHistory from apps.schemas.enum_var import CallOutputType from apps.schemas.scheduler import ( CallError, @@ -24,7 +26,6 @@ from apps.schemas.scheduler import ( CallTokens, CallVars, ) -from apps.schemas.task import FlowStepHistory if TYPE_CHECKING: from apps.scheduler.executor.step import StepExecutor @@ -122,7 +123,7 @@ class CoreCall(BaseModel): @staticmethod - def _extract_history_variables(path: str, history: dict[str, FlowStepHistory]) -> Any: + def _extract_history_variables(path: str, history: dict[str, ExecutorHistory]) -> Any: """ 提取History中的变量 @@ -210,8 +211,7 @@ class CoreCall(BaseModel): return result - async def _json(self, messages: list[dict[str, Any]], schema: type[BaseModel]) -> BaseModel: + async def _json(self, messages: list[dict[str, Any]], schema: dict[str, Any]) -> dict[str, Any]: """Call可直接使用的JSON生成""" json = FunctionLLM() - result = await json.call(messages=messages, schema=schema.model_json_schema()) - return schema.model_validate(result) + return await json.call(messages=messages, schema=schema) diff --git a/apps/scheduler/call/llm/llm.py b/apps/scheduler/call/llm/llm.py index f16cffc6..4a458cb2 100644 --- a/apps/scheduler/call/llm/llm.py +++ b/apps/scheduler/call/llm/llm.py @@ -11,7 +11,6 @@ from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment from pydantic import Field -from apps.llm.reasoning import ReasoningLLM from apps.scheduler.call.core import CoreCall from apps.schemas.enum_var import CallOutputType from apps.schemas.scheduler import ( diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py index deed87b9..06ba1d7b 100644 --- a/apps/scheduler/call/slot/slot.py +++ b/apps/scheduler/call/slot/slot.py @@ -9,8 +9,6 @@ from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment from pydantic import Field -from apps.llm.function import FunctionLLM, JsonGenerator -from apps.llm.reasoning import ReasoningLLM from apps.models.node import NodeInfo from apps.scheduler.call.core import CoreCall from apps.scheduler.slot.slot import Slot as SlotProcessor diff --git a/apps/scheduler/call/suggest/suggest.py b/apps/scheduler/call/suggest/suggest.py index 7b0b2d99..9f0dfa83 100644 --- a/apps/scheduler/call/suggest/suggest.py +++ b/apps/scheduler/call/suggest/suggest.py @@ -12,7 +12,6 @@ from pydantic import Field from pydantic.json_schema import SkipJsonSchema from apps.common.security import Security -from apps.llm.function import FunctionLLM from apps.models.node import NodeInfo from apps.scheduler.call.core import CoreCall from apps.schemas.enum_var import CallOutputType @@ -165,7 +164,7 @@ class Suggestion(CoreCall, input_model=SuggestionInput, output_model=SuggestionO }, preference=user_domain, ) - result = await FunctionLLM().call( + result = await self._json( messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}, @@ -197,7 +196,7 @@ class Suggestion(CoreCall, input_model=SuggestionInput, output_model=SuggestionO }, preference=user_domain, ) - result = await FunctionLLM().call( + result = await self._json( messages=[ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}, diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index e70b229a..6f933b8b 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -2,6 +2,7 @@ """MCP Agent执行器""" import logging +import traceback import uuid import anyio @@ -26,7 +27,7 @@ from apps.schemas.mcp import ( ToolExcutionErrorType, ToolRisk, ) -from apps.schemas.message import param +from apps.schemas.message import FlowParams from apps.schemas.task import ExecutorState, FlowStepHistory, StepQueueItem from apps.services.appcenter import AppCenterManager from apps.services.mcp_service import MCPServiceManager @@ -48,7 +49,7 @@ class MCPAgentExecutor(BaseExecutor): tools: dict[str, MCPTool] = Field( description="MCP工具列表,key为tool_id", default={}, ) - params: param | bool | None = Field( + params: FlowParams | bool | None = Field( default=None, description="流执行过程中的参数补充", alias="params", ) resoning_llm: ReasoningLLM = Field( @@ -108,14 +109,15 @@ class MCPAgentExecutor(BaseExecutor): start_index = start_index.start_index current_plan = MCPPlan(plans=self.task.runtime.temporary_plans.plans[start_index:]) error_message = self.task.state.error_message - temporary_plans = await MCPPlanner.create_plan(self.task.runtime.question, - is_replan=is_replan, - error_message=error_message, - current_plan=current_plan, - tool_list=tools, - max_steps=self.max_steps-start_index-1, - reasoning_llm=self.resoning_llm - ) + temporary_plans = await MCPPlanner.create_plan( + self.task.runtime.question, + is_replan=is_replan, + error_message=error_message, + current_plan=current_plan, + tool_list=tools, + max_steps=self.max_steps-start_index-1, + reasoning_llm=self.resoning_llm, + ) await self.update_tokens() await self.push_message( EventType.STEP_CANCEL, @@ -140,10 +142,12 @@ class MCPAgentExecutor(BaseExecutor): tool_id = self.task.runtime.temporary_plans.plans[self.task.state.step_index].tool step = self.task.runtime.temporary_plans.plans[self.task.state.step_index] mcp_tool = self.tools[tool_id] - self.task.state.current_input = await MCPHost._get_first_input_params(mcp_tool, self.task.runtime.question, self.task) + self.task.state.current_input = await MCPHost.get_first_input_params( + mcp_tool, self.task.runtime.question, step.instruction, self.task + ) else: # 获取后续输入参数 - if isinstance(self.params, param): + if isinstance(self.params, FlowParams): params = self.params.content params_description = self.params.description else: @@ -151,7 +155,8 @@ class MCPAgentExecutor(BaseExecutor): params_description = "" tool_id = self.task.runtime.temporary_plans.plans[self.task.state.step_index].tool mcp_tool = self.tools[tool_id] - self.task.state.current_input = await MCPHost._get_first_input_params(mcp_tool, step.instruction, self.task) + step = self.task.runtime.temporary_plans.plans[self.task.state.step_index] + self.task.state.current_input = await MCPHost._fill_params(mcp_tool, self.task.runtime.question, step.instruction, self.task.state.current_input, self.task.state.error_message, params, params_description) async def reset_step_to_index(self, start_index: int) -> None: """重置步骤到开始""" @@ -273,7 +278,7 @@ class MCPAgentExecutor(BaseExecutor): self.resoning_llm, ) await self.update_tokens() - error_message = MCPPlanner.change_err_message_to_description( + error_message = await MCPPlanner.change_err_message_to_description( error_message=self.task.state.error_message, tool=mcp_tool, input_params=self.task.state.current_input, @@ -334,7 +339,7 @@ class MCPAgentExecutor(BaseExecutor): self.task.state.flow_status = FlowStatus.ERROR await self.push_message( EventType.FLOW_FAILED, - data={} + data={}, ) if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: del self.task.context[-1] @@ -350,7 +355,7 @@ class MCPAgentExecutor(BaseExecutor): flow_status=self.task.state.flow_status, input_data={}, output_data={}, - ) + ), ) async def work(self) -> None: @@ -358,7 +363,7 @@ class MCPAgentExecutor(BaseExecutor): if self.task.state.step_status == StepStatus.INIT: await self.push_message( EventType.STEP_INIT, - data={} + data={}, ) await self.get_tool_input_param(is_first=True) user_info = await UserManager.get_userinfo_by_user_sub(self.task.ids.user_sub) @@ -380,11 +385,11 @@ class MCPAgentExecutor(BaseExecutor): self.task.state.step_status = StepStatus.CANCELLED await self.push_message( EventType.STEP_CANCEL, - data={} + data={}, ) await self.push_message( EventType.FLOW_CANCEL, - data={} + data={}, ) if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: self.task.context[-1].step_status = StepStatus.CANCELLED @@ -415,7 +420,7 @@ class MCPAgentExecutor(BaseExecutor): ) if error_type.type == ErrorType.DECORRECT_PLAN or user_info.auto_execute: await self.plan(is_replan=True) - self.reset_step_to_index(self.task.state.step_index) + await self.reset_step_to_index(self.task.state.step_index) elif error_type.type == ErrorType.MISSING_PARAM: await self.generate_params_with_null() elif self.task.state.step_status == StepStatus.SUCCESS: @@ -455,20 +460,34 @@ class MCPAgentExecutor(BaseExecutor): self.task.state.error_message = str(e) await self.push_message( EventType.FLOW_FAILED, - data={} + data={}, ) return self.task.state.flow_status = FlowStatus.RUNNING + plan = { + "plans": [], + } + for p in self.task.runtime.temporary_plans.plans: + if p.tool == FINAL_TOOL_ID: + continue + mcp_tool = self.tools.get(p.tool, None) + plan["plans"].append( + { + "stepId": p.step_id, + "stepName": mcp_tool.name, + "stepDescription": p.content, + }, + ) await self.push_message( EventType.FLOW_START, - data={} + data=plan, ) if self.task.state.step_id == FINAL_TOOL_ID: # 如果已经是最后一步,直接结束 self.task.state.flow_status = FlowStatus.SUCCESS await self.push_message( EventType.FLOW_SUCCESS, - data={} + data={}, ) await self.summarize() return @@ -488,23 +507,21 @@ class MCPAgentExecutor(BaseExecutor): self.task.state.step_status = StepStatus.SUCCESS await self.push_message( EventType.FLOW_SUCCESS, - data={} + data={}, ) await self.summarize() except Exception as e: - import traceback - logger.error("[MCPAgentExecutor] 执行过程中发生错误: %s", traceback.format_exc()) - logger.error("[MCPAgentExecutor] 执行过程中发生错误: %s", str(e)) + logger.exception("[MCPAgentExecutor] 执行过程中发生错误") self.task.state.flow_status = FlowStatus.ERROR self.task.state.error_message = str(e) self.task.state.step_status = StepStatus.ERROR await self.push_message( EventType.STEP_ERROR, - data={} + data={}, ) await self.push_message( EventType.FLOW_FAILED, - data={} + data={}, ) if len(self.task.context) and self.task.context[-1].step_id == self.task.state.step_id: del self.task.context[-1] @@ -520,12 +537,11 @@ class MCPAgentExecutor(BaseExecutor): flow_status=self.task.state.flow_status, input_data={}, output_data={}, - ) + ), ) finally: for mcp_service in self.mcp_list: try: await self.mcp_pool.stop(mcp_service.id, self.task.ids.user_sub) - except Exception as e: - import traceback - logger.error("[MCPAgentExecutor] 停止MCP客户端时发生错误: %s", traceback.format_exc()) + except Exception: + logger.exception("[MCPAgentExecutor] 停止MCP客户端时发生错误") diff --git a/apps/scheduler/executor/base.py b/apps/scheduler/executor/base.py index 86e72483..12fd5951 100644 --- a/apps/scheduler/executor/base.py +++ b/apps/scheduler/executor/base.py @@ -8,10 +8,10 @@ from typing import Any from pydantic import BaseModel, ConfigDict from apps.common.queue import MessageQueue +from apps.models.task import Task from apps.schemas.enum_var import EventType -from apps.schemas.message import FlowStartContent, TextAddContent +from apps.schemas.message import TextAddContent from apps.schemas.scheduler import ExecutorBackground -from apps.schemas.task import Task logger = logging.getLogger(__name__) @@ -44,12 +44,7 @@ class BaseExecutor(BaseModel, ABC): :param event_type: 事件类型 :param data: 消息数据,如果是FLOW_START事件且data为None,则自动构建FlowStartContent """ - if event_type == EventType.FLOW_START.value and isinstance(data, dict): - data = FlowStartContent( - question=self.question, - params=self.task.runtime.filled, - ).model_dump(exclude_none=True, by_alias=True) - elif event_type == EventType.TEXT_ADD.value and isinstance(data, str): + if event_type == EventType.TEXT_ADD.value and isinstance(data, str): data = TextAddContent(text=data).model_dump(exclude_none=True, by_alias=True) if data is None: diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py index c0f52c7f..f068e9d2 100644 --- a/apps/scheduler/mcp_agent/host.py +++ b/apps/scheduler/mcp_agent/host.py @@ -45,7 +45,7 @@ class MCPHost: context_list=task.context, ) - async def _get_first_input_params(mcp_tool: MCPTool, query: str, task: Task) -> dict[str, Any]: + async def get_first_input_params(mcp_tool: MCPTool, query: str, task: Task) -> dict[str, Any]: """填充工具参数""" # 更清晰的输入·指令,这样可以调用generate llm_query = rf""" diff --git a/apps/scheduler/pool/loader/app.py b/apps/scheduler/pool/loader/app.py index 5d2da583..a0eebe94 100644 --- a/apps/scheduler/pool/loader/app.py +++ b/apps/scheduler/pool/loader/app.py @@ -29,24 +29,15 @@ class AppLoader: """应用加载器""" @staticmethod - async def load(app_id: uuid.UUID, hashes: dict[str, str]) -> None: # noqa: C901 + async def load(app_id: uuid.UUID, hashes: dict[str, str]) -> None: """ 从文件系统中加载应用 :param app_id: 应用 ID """ app_path = BASE_PATH / str(app_id) - metadata_path = app_path / "metadata.yaml" - metadata = await MetadataLoader().load_one(metadata_path) - if not metadata: - err = f"[AppLoader] 元数据不存在: {metadata_path}" - raise ValueError(err) + metadata = await AppLoader.read_metadata(app_id) metadata.hashes = hashes - - if not isinstance(metadata, (AppMetadata, AgentAppMetadata)): - err = f"[AppLoader] 元数据类型错误: {metadata_path}" - raise TypeError(err) - if metadata.app_type == AppType.FLOW and isinstance(metadata, AppMetadata): # 加载工作流 flow_path = app_path / "flow" @@ -90,6 +81,20 @@ class AppLoader: await AppLoader._update_db(metadata) + @staticmethod + async def read_metadata(app_id: uuid.UUID) -> AppMetadata | AgentAppMetadata: + """读取应用元数据""" + metadata_path = BASE_PATH / str(app_id) / "metadata.yaml" + metadata = await MetadataLoader().load_one(metadata_path) + if not metadata: + err = f"[AppLoader] 元数据不存在: {metadata_path}" + raise ValueError(err) + if not isinstance(metadata, (AppMetadata, AgentAppMetadata)): + err = f"[AppLoader] 元数据类型错误: {metadata_path}" + raise TypeError(err) + return metadata + + @staticmethod async def save(metadata: AppMetadata | AgentAppMetadata, app_id: uuid.UUID) -> None: """ diff --git a/apps/scheduler/pool/loader/flow.py b/apps/scheduler/pool/loader/flow.py index 4c9d9671..fbcafa79 100644 --- a/apps/scheduler/pool/loader/flow.py +++ b/apps/scheduler/pool/loader/flow.py @@ -189,7 +189,7 @@ class FlowLoader: @staticmethod - async def delete(app_id: uuid.UUID, flow_id: uuid.UUID) -> bool: + async def delete(app_id: uuid.UUID, flow_id: uuid.UUID) -> None: """删除指定工作流文件""" flow_path = BASE_PATH / str(app_id) / "flow" / f"{flow_id}.yaml" # 确保目标为文件且存在 @@ -198,10 +198,16 @@ class FlowLoader: await flow_path.unlink() async with postgres.session() as session: + await session.execute(delete(FlowInfo).where( + and_( + FlowInfo.appId == app_id, + FlowInfo.id == flow_id, + ), + )) await session.execute(delete(FlowPoolVector).where(FlowPoolVector.id == flow_id)) - return True + await session.commit() + return logger.warning("[FlowLoader] 工作流文件不存在或不是文件:%s", flow_path) - return True @staticmethod diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index 3371cbdc..05c8b3f8 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -17,7 +17,6 @@ from apps.schemas.record import ( RecordGroupDocument, ) from apps.schemas.request_data import RequestData -from apps.schemas.task import Task from apps.services.appcenter import AppCenterManager from apps.services.document import DocumentManager from apps.services.record import RecordManager @@ -28,24 +27,13 @@ logger = logging.getLogger(__name__) async def get_docs(post_body: RequestData) -> tuple[list[RecordDocument] | list[Document], list[str]]: """获取当前问答可供关联的文档""" - if not post_body.group_id: - err = "[Scheduler] 问答组ID不能为空!" - logger.error(err) - raise ValueError(err) - doc_ids = [] - docs = await DocumentManager.get_used_docs_by_record(post_body.group_id, "question") - if not docs: - # 是新提问 - # 从Conversation中获取刚上传的文档 - docs = await DocumentManager.get_unused_docs(post_body.conversation_id) - # 从最近10条Record中获取文档 - docs += await DocumentManager.get_used_docs(post_body.conversation_id, 10, "question") - doc_ids += [doc.id for doc in docs] - else: - # 是重新生成 - doc_ids += [doc.id for doc in docs] + # 从Conversation中获取刚上传的文档 + docs = await DocumentManager.get_unused_docs(post_body.conversation_id) + # 从最近10条Record中获取文档 + docs += await DocumentManager.get_used_docs(post_body.conversation_id, 10, "question") + doc_ids += [doc.id for doc in docs] return docs, doc_ids diff --git a/apps/scheduler/scheduler/flow.py b/apps/scheduler/scheduler/flow.py index 8fc5754b..9d46a3fc 100644 --- a/apps/scheduler/scheduler/flow.py +++ b/apps/scheduler/scheduler/flow.py @@ -2,11 +2,12 @@ """Scheduler中,关于Flow的逻辑""" import logging +import uuid from apps.llm.patterns import Select from apps.scheduler.pool.pool import Pool from apps.schemas.request_data import RequestDataApp -from apps.schemas.task import Task +from apps.services.task import TaskManager logger = logging.getLogger(__name__) @@ -14,9 +15,9 @@ logger = logging.getLogger(__name__) class FlowChooser: """Flow选择器""" - def __init__(self, task: Task, question: str, user_selected: RequestDataApp | None = None) -> None: + def __init__(self, task_id: uuid.UUID, question: str, user_selected: RequestDataApp | None = None) -> None: """初始化Flow选择器""" - self.task = task + self.task_id = task_id self._question = question self._user_selected = user_selected @@ -31,15 +32,15 @@ class FlowChooser: if not flow_list: return "KnowledgeBase" - logger.info("[FlowChooser] 选择任务 %s 最合适的Flow", self.task.id) + logger.info("[FlowChooser] 选择任务 %s 最合适的Flow", self.task_id) choices = [{ "name": flow.id, "description": f"{flow.name}, {flow.description}", } for flow in flow_list] select_obj = Select() top_flow = await select_obj.generate(question=self._question, choices=choices) - self.task.tokens.input_tokens += select_obj.input_tokens - self.task.tokens.output_tokens += select_obj.output_tokens + + await TaskManager.update_task_token(self.task_id, select_obj.input_tokens, select_obj.output_tokens) return top_flow @@ -64,5 +65,5 @@ class FlowChooser: return RequestDataApp( appId=self._user_selected.app_id, flowId=top_flow, - params={}, + params=None, ) diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index 5383c419..a5981906 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -8,6 +8,7 @@ from textwrap import dedent from apps.common.config import config from apps.common.queue import MessageQueue from apps.models.document import Document +from apps.models.task import Task from apps.schemas.enum_var import EventType, FlowStatus from apps.schemas.message import ( DocumentAddContent, @@ -17,7 +18,6 @@ from apps.schemas.message import ( ) from apps.schemas.rag_data import RAGEventData, RAGQueryReq from apps.schemas.record import RecordDocument -from apps.schemas.task import Task from apps.services.rag import RAG from apps.services.task import TaskManager diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index d5e2e313..7a828bb4 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -3,13 +3,14 @@ import asyncio import logging -from datetime import UTC, datetime +import uuid from apps.common.config import config from apps.common.queue import MessageQueue from apps.llm.patterns.rewrite import QuestionRewrite from apps.llm.reasoning import ReasoningLLM -from apps.models.app import App +from apps.models.llm import LLMData +from apps.models.task import Task from apps.scheduler.executor.agent import MCPAgentExecutor from apps.scheduler.executor.flow import FlowExecutor from apps.scheduler.pool.pool import Pool @@ -24,11 +25,11 @@ from apps.schemas.enum_var import AppType, EventType, FlowStatus from apps.schemas.rag_data import RAGQueryReq from apps.schemas.request_data import RequestData from apps.schemas.scheduler import ExecutorBackground -from apps.schemas.task import Task from apps.services.activity import Activity from apps.services.appcenter import AppCenterManager from apps.services.knowledge import KnowledgeBaseManager from apps.services.llm import LLMManager +from apps.services.task import TaskManager logger = logging.getLogger(__name__) @@ -40,16 +41,15 @@ class Scheduler: Scheduler包含一个“SchedulerContext”,作用为多个Executor的“聊天会话” """ - def __init__(self, task: Task, queue: MessageQueue, post_body: RequestData) -> None: + async def init(self, task_id: uuid.UUID, queue: MessageQueue, post_body: RequestData) -> None: """初始化""" self.used_docs = [] - self.task = task - + self.task_id = task_id self.queue = queue self.post_body = post_body - async def _monitor_activity(self, kill_event, user_sub): + async def _monitor_activity(self, kill_event: asyncio.Event, user_sub: str) -> None: """监控用户活动状态,不活跃时终止工作流""" try: check_interval = 0.5 # 每0.5秒检查一次 @@ -67,56 +67,36 @@ class Scheduler: await asyncio.sleep(check_interval) except asyncio.CancelledError: logger.info("[Scheduler] 活动监控任务已取消") - except Exception as e: - logger.error(f"[Scheduler] 活动监控过程中发生错误: {e}") - - - async def get_llm_use_in_chat_with_rag(self) -> LLM: - """获取RAG大模型""" - try: - # 获取当前会话使用的大模型 - llm_id = await LLMManager.get_llm_id_by_conversation_id( - self.task.ids.user_sub, self.task.ids.conversation_id, - ) - if not llm_id: - logger.error("[Scheduler] 获取大模型ID失败") - return None - if llm_id == "empty": - return LLM( - _id="empty", - user_sub=self.task.ids.user_sub, - openai_base_url=Config().get_config().llm.endpoint, - openai_api_key=Config().get_config().llm.key, - model_name=Config().get_config().llm.model, - max_tokens=Config().get_config().llm.max_tokens, - ) - llm = await LLMManager.get_llm_by_id(self.task.ids.user_sub, llm_id) - if not llm: - logger.error("[Scheduler] 获取大模型失败") - return None - return llm except Exception: - logger.exception("[Scheduler] 获取大模型失败") - return None + logger.exception("[Scheduler] 活动监控过程中发生错误") + kill_event.set() - async def get_kb_ids_use_in_chat_with_rag(self) -> list[str]: - """获取知识库ID列表""" - try: - return await KnowledgeBaseManager.get_kb_ids_by_conversation_id( - self.task.ids.user_sub, self.task.ids.conversation_id, - ) - except Exception: - logger.exception("[Scheduler] 获取知识库ID失败") - await self.queue.close() - return [] + async def get_chat_llm(self, llm_id: uuid.UUID) -> LLMData: + """获取RAG大模型""" + # 获取当前会话使用的大模型 + llm = await LLMManager.get_llm(llm_id) + if not llm: + err = "[Scheduler] 获取大模型ID失败" + logger.error(err) + raise ValueError(err) + return llm async def run(self) -> None: """运行调度器""" + task = await TaskManager.get_task_by_conversation_id(self.task_id) + if not task: + task = Task( + id=self.task_id, + ids=TaskIds( + user_sub=self.post_body.user_sub, + conversation_id=self.task_id, + ), + ) try: # 获取当前问答可供关联的文档 - docs, doc_ids = await get_docs(self.task.ids.user_sub, self.post_body) + docs, doc_ids = await get_docs(self.post_body) except Exception: logger.exception("[Scheduler] 获取文档失败") await self.queue.close() @@ -134,8 +114,8 @@ class Scheduler: if self.task.state.app_id: rag_method = False if rag_method: - llm = await self.get_llm_use_in_chat_with_rag() - kb_ids = await self.get_kb_ids_use_in_chat_with_rag() + llm = await self.get_chat_llm() + kb_ids = await KnowledgeBaseManager.get_selected_kb(self.task.ids.user_sub) self.task = await push_init_message(self.task, self.queue, 3, is_flow=False) rag_data = RAGQueryReq( kbIds=kb_ids, diff --git a/apps/schemas/message.py b/apps/schemas/message.py index 8a9b35eb..77b17399 100644 --- a/apps/schemas/message.py +++ b/apps/schemas/message.py @@ -12,7 +12,7 @@ from apps.schemas.enum_var import EventType, FlowStatus, StepStatus from .record import RecordMetadata -class param(BaseModel): +class FlowParams(BaseModel): """流执行过程中的参数补充""" content: dict[str, Any] = Field(default={}, description="流执行过程中的参数补充内容") @@ -80,7 +80,8 @@ class DocumentAddContent(BaseModel): document_type: str = Field(description="文档MIME类型", alias="documentType", default="") document_size: float = Field(ge=0, description="文档大小,单位是KB,保留两位小数", alias="documentSize", default=0) created_at: float = Field( - description="文档创建时间,单位是秒", alias="createdAt", default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3) + description="文档创建时间,单位是秒", alias="createdAt", + default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3), ) diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index 6e79def6..d30c57da 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -10,7 +10,7 @@ from .appcenter import AppData from .enum_var import CommentType from .flow_topology import FlowItem from .mcp import MCPType -from .message import param +from .message import FlowParams class RequestDataApp(BaseModel): @@ -18,7 +18,7 @@ class RequestDataApp(BaseModel): app_id: uuid.UUID = Field(description="应用ID", alias="appId") flow_id: str | None = Field(default=None, description="Flow ID", alias="flowId") - params: param | None = Field(default=None, description="流执行过程中的参数补充", alias="params") + params: FlowParams | None = Field(default=None, description="流执行过程中的参数补充", alias="params") class RequestData(BaseModel): diff --git a/apps/services/appcenter.py b/apps/services/appcenter.py index 7f4a6f86..01106f75 100644 --- a/apps/services/appcenter.py +++ b/apps/services/appcenter.py @@ -15,7 +15,7 @@ from apps.models.app import App, AppACL from apps.models.user import UserAppUsage, UserFavorite, UserFavoriteType from apps.scheduler.pool.loader.app import AppLoader from apps.schemas.agent import AgentAppMetadata -from apps.schemas.appcenter import AppCenterCardItem, AppData, AppPermissionData +from apps.schemas.appcenter import AppCenterCardItem, AppData from apps.schemas.enum_var import AppFilterType, AppType, PermissionType from apps.schemas.flow import AppMetadata, MetadataType, Permission from apps.schemas.response_data import RecentAppList, RecentAppListItem @@ -37,24 +37,30 @@ class AppCenterManager: :param app_id: 应用id :return: 如果用户具有所需权限则返回True,否则返回False """ - mongo = MongoDB() - app_collection = mongo.get_collection("app") - query = { - "_id": app_id, - "$or": [ - {"author": user_sub}, - {"permission.type": PermissionType.PUBLIC.value}, - { - "$and": [ - {"permission.type": PermissionType.PROTECTED.value}, - {"permission.users": user_sub}, - ], - }, - ], - } + async with postgres.session() as session: + sql = select(App.author, App.permission).where(App.id == app_id) + app_info = (await session.execute(sql)).one_or_none() + if not app_info: + msg = f"[AppCenterManager] 应用不存在: {app_id}" + raise ValueError(msg) - result = await app_collection.find_one(query) - return result is not None + # 作者一定可以访问 + if app_info.author == user_sub: + return True + + if app_info.permission == PermissionType.PUBLIC: + return True + if app_info.permission == PermissionType.PRIVATE: + return False + if app_info.permission == PermissionType.PROTECTED: + sql = select(AppACL.appId).where( + AppACL.appId == app_id, + AppACL.userSub == user_sub, + ) + acl_info = (await session.execute(sql)).one_or_none() + if acl_info: + return True + return False @staticmethod @@ -118,12 +124,13 @@ class AppCenterManager: ).cte() # 获取用户所有收藏的应用 - user_favourite_apps = select(UserFavorite.itemId).where( + favapps_sql = select(UserFavorite.itemId).where( and_( UserFavorite.userSub == user_sub, UserFavorite.favouriteType == UserFavoriteType.APP, ), ) + favapps = list((await session.scalars(favapps_sql)).all()) # 根据搜索类型加入搜索条件 if filter_type == AppFilterType.ALL: @@ -133,7 +140,7 @@ class AppCenterManager: elif filter_type == AppFilterType.FAVORITE: filtered_apps = select(app_data).where( and_( - App.id.in_(user_favourite_apps), + App.id.in_(favapps_sql), App.isPublished == True, # noqa: E712 ), ).cte() @@ -174,7 +181,7 @@ class AppCenterManager: name=app.name, description=app.description, author=app.author, - favorited=(app.id in user_favourite_apps), + favorited=(app.id in favapps), published=app.isPublished, ) for app in result @@ -480,7 +487,7 @@ class AppCenterManager: @staticmethod - def _create_agent_metadata( + async def _create_agent_metadata( common_params: dict, user_sub: str, data: AppData | None = None, @@ -494,7 +501,9 @@ class AppCenterManager: # mcp_service 逻辑 if data is not None and hasattr(data, "mcp_service") and data.mcp_service: # 创建应用场景,验证传入的 mcp_service 状态,确保只使用已经激活的 (create_app) - metadata.mcp_service = [svc.id for svc in data.mcp_service if MCPServiceManager.is_active(user_sub, svc.id)] + metadata.mcp_service = [ + mcp_id for mcp_id in data.mcp_service if await MCPServiceManager.is_active(user_sub, mcp_id) + ] elif data is not None and hasattr(data, "mcp_service"): # 更新应用场景,使用 data 中的 mcp_service (update_app) metadata.mcp_service = data.mcp_service if data.mcp_service is not None else [] @@ -559,7 +568,9 @@ class AppCenterManager: # 根据应用类型创建不同的元数据 if app_type == AppType.AGENT: - return AppCenterManager._create_agent_metadata(common_params, user_sub, data, app_data, published=published) + return await AppCenterManager._create_agent_metadata( + common_params, user_sub, data, app_data, published=published, + ) return AppCenterManager._create_flow_metadata(common_params, data, app_data, published=published) diff --git a/apps/services/conversation.py b/apps/services/conversation.py index 17c0dad2..6b60e143 100644 --- a/apps/services/conversation.py +++ b/apps/services/conversation.py @@ -67,7 +67,9 @@ class ConversationManager: @staticmethod - async def add_conversation_by_user_sub(user_sub: str, app_id: uuid.UUID, *, debug: bool) -> Conversation | None: + async def add_conversation_by_user_sub( + user_sub: str, app_id: uuid.UUID | None = None, *, debug: bool = False, + ) -> Conversation | None: """通过用户ID新建对话""" conv = Conversation( userSub=user_sub, diff --git a/apps/services/flow.py b/apps/services/flow.py index 8d371c99..92052a92 100644 --- a/apps/services/flow.py +++ b/apps/services/flow.py @@ -4,18 +4,19 @@ import logging import uuid -from sqlalchemy import and_, select +from sqlalchemy import and_, delete, select from apps.common.postgres import postgres -from apps.models.app import App +from apps.models.app import App, AppHashes from apps.models.flow import Flow as FlowInfo from apps.models.node import NodeInfo from apps.models.service import Service from apps.models.user import UserFavorite, UserFavoriteType +from apps.scheduler.pool.loader.app import AppLoader from apps.scheduler.pool.loader.flow import FlowLoader from apps.scheduler.slot.slot import Slot from apps.schemas.enum_var import EdgeType -from apps.schemas.flow import Edge, Flow, Step +from apps.schemas.flow import AppMetadata, Edge, Flow, Step from apps.schemas.flow_topology import ( EdgeItem, FlowItem, @@ -38,7 +39,7 @@ class FlowManager: @staticmethod async def get_node_id_by_service_id(service_id: uuid.UUID) -> list[NodeMetaDataBase] | None: """ - serviceId获取service的基础信息 + 根据serviceId获取service内节点的基础信息,用于左侧节点列表展示 :param service_id: 服务id :return: 节点基础信息的列表,按创建时间排序 @@ -65,7 +66,7 @@ class FlowManager: @staticmethod async def get_service_by_user_id(user_sub: str) -> list[NodeServiceItem] | None: """ - 通过user_id获取用户自己上传的或收藏的Service + 通过user_id获取用户自己上传的或收藏的Service,用于插件中心展示 :user_sub: 用户的唯一标识符 :return: service的列表 @@ -351,17 +352,26 @@ class FlowManager: :param app_id: 应用的id :param flow_id: 流的id """ - app_collection = MongoDB().get_collection("app") - key = f"flow/{flow_id}.yaml" - await app_collection.update_one({"_id": app_id}, {"$unset": {f"hashes.{key}": ""}}) - await app_collection.update_one({"_id": app_id}, {"$pull": {"flows": {"id": flow_id}}}) + await FlowLoader.delete(app_id, flow_id) - result = await FlowLoader.delete(app_id, flow_id) + async with postgres.session() as session: + key = f"flow/{flow_id}.yaml" + await session.execute( + delete(AppHashes).where( + and_( + AppHashes.appId == app_id, + AppHashes.filePath == key, + ), + ), + ) - if result is None: - error_msg = f"[FlowManager] 删除流 {flow_id} 失败" - logger.error(error_msg) - raise ValueError(error_msg) + metadata = await AppLoader.read_metadata(app_id) + if not isinstance(metadata, AppMetadata): + err = f"[FlowManager] 应用 {app_id} 不是Flow应用" + logger.error(err) + raise TypeError(err) + metadata.flows = [flow for flow in metadata.flows if flow.id != flow_id] + await AppLoader.save(metadata, app_id) @staticmethod diff --git a/apps/services/mcp_service.py b/apps/services/mcp_service.py index 8c9cc8b9..b1fbae0a 100644 --- a/apps/services/mcp_service.py +++ b/apps/services/mcp_service.py @@ -236,12 +236,17 @@ class MCPServiceManager: else: config = MCPServerStdioConfig.model_validate(data.config) + if not data.service_id: + data.service_id = str(uuid.uuid4()) + # 构造Server mcp_server = MCPServerConfig( name=await MCPServiceManager.clean_name(data.name), overview=data.overview, description=data.description, - mcpServers=config, + mcpServers={ + data.service_id: config, + }, mcpType=data.mcp_type, author=user_sub, ) @@ -255,7 +260,7 @@ class MCPServiceManager: # 保存并载入配置 logger.info("[MCPServiceManager] 创建mcp:%s", mcp_server.name) - mcp_path = MCP_PATH / "template" / mcp_id / "project" + mcp_path = MCP_PATH / "template" / data.service_id / "project" index = None if isinstance(config, MCPServerStdioConfig): index = None @@ -268,7 +273,7 @@ class MCPServiceManager: if index >= len(config.args): config.args.append(str(mcp_path)) else: - config.args[index+1] = str(mcp_path) + config.args[index] = str(mcp_path) else: config.args += ["--directory", str(mcp_path)] await MCPLoader._insert_template_db(mcp_id=mcp_id, config=mcp_server) diff --git a/apps/services/parameter.py b/apps/services/parameter.py index 7d22b0c1..ffb39976 100644 --- a/apps/services/parameter.py +++ b/apps/services/parameter.py @@ -2,6 +2,7 @@ """Parameter Manager""" import logging +import uuid from apps.scheduler.call.choice.condition_handler import ConditionHandler from apps.scheduler.call.choice.schema import BoolOperate, DictOperate, ListOperate, NumberOperate, StringOperate, Type @@ -45,7 +46,7 @@ class ParameterManager: return result @staticmethod - async def get_pre_params_by_flow_and_step_id(flow: FlowItem, step_id: str) -> list[StepParams]: + async def get_pre_params_by_flow_and_step_id(flow: FlowItem, step_id: uuid.UUID) -> list[StepParams]: """Get pre params by flow and step id""" index = 0 q = [step_id] diff --git a/apps/services/task.py b/apps/services/task.py index 95afa094..647d1796 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -4,14 +4,12 @@ import logging import uuid -from sqlalchemy import delete, select +from sqlalchemy import delete, select, update from apps.common.postgres import postgres -from apps.models.task import ExecutorCheckpoint, ExecutorHistory, Task +from apps.models.task import ExecutorCheckpoint, ExecutorHistory, Task, TaskRuntime from apps.schemas.request_data import RequestData -from .record import RecordManager - logger = logging.getLogger(__name__) @@ -19,59 +17,40 @@ class TaskManager: """从数据库中获取任务信息""" @staticmethod - async def get_task_by_conversation_id(conversation_id: str) -> Task | None: - """获取对话ID的最后一条问答组关联的任务""" - # 查询对话ID的最后一条问答组 - last_group = await RecordManager.query_record_group_by_conversation_id(conversation_id, 1) - if not last_group or len(last_group) == 0: - logger.error("[TaskManager] 没有找到对话 %s 的问答组", conversation_id) - # 空对话或无效对话,新建Task - return None - - last_group = last_group[0] - task_id = last_group.task_id - - # 查询最后一条问答组关联的任务 - task_collection = MongoDB().get_collection("task") - task = await task_collection.find_one({"_id": task_id}) - if not task: - # 任务不存在,新建Task - logger.error("[TaskManager] 任务 %s 不存在", task_id) - return None + async def get_task_by_conversation_id(conversation_id: str) -> Task: + """获取对话ID的最后一个任务""" + async with postgres.session() as session: + task = (await session.scalars( + select(Task).where(Task.conversationId == conversation_id).order_by(Task.updatedAt.desc()).limit(1), + )).one_or_none() + if not task: + # 任务不存在,新建Task + logger.info("[TaskManager] 新建任务", task_id) + return Task( + conversationId=conversation_id, - return Task.model_validate(task) + ) + return task @staticmethod async def get_task_by_task_id(task_id: uuid.UUID) -> Task | None: """根据task_id获取任务""" - task_collection = MongoDB().get_collection("task") - task = await task_collection.find_one({"_id": task_id}) - if not task: - return None - return Task.model_validate(task) + async with postgres.session() as session: + return (await session.scalars( + select(Task).where(Task.id == task_id), + )).one_or_none() @staticmethod - async def get_context_by_task_id(task_id: str, length: int | None = None) -> list[FlowStepHistory]: + async def get_context_by_task_id(task_id: str, length: int | None = None) -> list[ExecutorHistory]: """根据task_id获取flow信息""" async with postgres.session() as session: - executor_history_collection = session.query(ExecutorHistory) - - flow_context = [] - try: - async for history in flow_context_collection.find( - {"task_id": task_id}, - ).sort( - "created_at", -1, - ).limit(length): - for i in range(len(flow_context)): - flow_context.append(FlowStepHistory.model_validate(history)) - except Exception: - logger.exception("[TaskManager] 获取task_id的flow信息失败") - return [] - else: - return flow_context + return list((await session.scalars( + select(ExecutorHistory).where( + ExecutorHistory.taskId == task_id, + ).order_by(ExecutorHistory.updatedAt.desc()).limit(length), + )).all()) @staticmethod @@ -94,7 +73,7 @@ class TaskManager: ) @staticmethod - async def save_flow_context(task_id: str, flow_context: list[FlowStepHistory]) -> None: + async def save_flow_context(task_id: str, flow_context: list[ExecutorHistory]) -> None: """保存flow信息到flow_context""" flow_context_collection = MongoDB().get_collection("flow_context") try: @@ -116,7 +95,7 @@ class TaskManager: @staticmethod - async def delete_task_by_task_id(task_id: str) -> None: + async def delete_task_by_task_id(task_id: uuid.UUID) -> None: """通过task_id删除Task信息""" async with postgres.session() as session: task = (await session.scalars( @@ -127,7 +106,7 @@ class TaskManager: @staticmethod - async def delete_tasks_by_conversation_id(conversation_id: uuid.UUID) -> list[str]: + async def delete_tasks_by_conversation_id(conversation_id: uuid.UUID) -> list[uuid.UUID]: """通过ConversationID删除Task信息""" async with postgres.session() as session: task_ids = [] @@ -175,3 +154,21 @@ class TaskManager: {"$set": task.model_dump(by_alias=True, exclude_none=True)}, upsert=True, ) + + + @classmethod + async def update_task_token(cls, task_id: uuid.UUID, input_token: int, output_token: int) -> tuple[int, int]: + """更新任务的Token""" + async with postgres.session() as session: + sql = select(TaskRuntime.inputToken, TaskRuntime.outputToken).where(TaskRuntime.taskId == task_id) + row = (await session.execute(sql)).one() + + new_input_token = row.inputToken + input_token + new_output_token = row.outputToken + output_token + + sql = update(TaskRuntime).where(TaskRuntime.taskId == task_id).values( + inputToken=new_input_token, outputToken=new_output_token, + ) + await session.execute(sql) + await session.commit() + return new_input_token, new_output_token -- Gitee