diff --git a/apps/common/queue.py b/apps/common/queue.py
index 911485b31176011204bb7b92a4aafd6f5eeaa41a..089d475e69752dc7fac595b3556eead69bd64aac 100644
--- a/apps/common/queue.py
+++ b/apps/common/queue.py
@@ -56,6 +56,8 @@ class MessageQueue:
flow = MessageFlow(
appId=task.state.app_id,
flowId=task.state.flow_id,
+ flowName=task.state.flow_name,
+ flowStatus=task.state.flow_status,
stepId=task.state.step_id,
stepName=task.state.step_name,
stepStatus=task.state.step_status
diff --git a/apps/main.py b/apps/main.py
index 1646f671adbda1867ecd1077be5b3fcfadcd93e2..58e2bd40bad99db5d12cc21957525f7b60fd47c0 100644
--- a/apps/main.py
+++ b/apps/main.py
@@ -97,10 +97,13 @@ async def add_no_auth_user() -> None:
username = os.environ.get('USER') # 适用于 Linux 和 macOS 系统
if not username:
username = "admin"
- await user_collection.insert_one(User(
- _id=username,
- is_admin=True,
- ).model_dump(by_alias=True))
+ try:
+ await user_collection.insert_one(User(
+ _id=username,
+ is_admin=True,
+ ).model_dump(by_alias=True))
+ except Exception as e:
+ logging.warning(f"添加无认证用户失败: {e}")
async def init_resources() -> None:
diff --git a/apps/routers/auth.py b/apps/routers/auth.py
index 4a3f829392135dd9637f32dd9d419a211637f0e2..72416fa35b1c23d2975497427da35a4912695fa4 100644
--- a/apps/routers/auth.py
+++ b/apps/routers/auth.py
@@ -74,7 +74,7 @@ async def oidc_login(request: Request, code: str) -> HTMLResponse:
status_code=status.HTTP_403_FORBIDDEN,
)
- await UserManager.update_userinfo_by_user_sub(user_sub)
+ await UserManager.update_refresh_revision_by_user_sub(user_sub)
current_session = await SessionManager.create_session(user_host, user_sub)
@@ -177,6 +177,7 @@ async def userinfo(
user_sub=user_sub,
revision=user.is_active,
is_admin=user.is_admin,
+ auto_execute=user.auto_execute,
),
).model_dump(exclude_none=True, by_alias=True),
)
@@ -192,7 +193,7 @@ async def userinfo(
)
async def update_revision_number(request: Request, user_sub: Annotated[str, Depends(get_user)]) -> JSONResponse: # noqa: ARG001
"""更新用户协议信息"""
- ret: bool = await UserManager.update_userinfo_by_user_sub(user_sub, refresh_revision=True)
+ ret: bool = await UserManager.update_refresh_revision_by_user_sub(user_sub, refresh_revision=True)
if not ret:
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
diff --git a/apps/routers/chat.py b/apps/routers/chat.py
index 26a874816f436a44e3569752edbea0767971f0d2..06bc2dd785008f0c78c805e5118684c17339ca46 100644
--- a/apps/routers/chat.py
+++ b/apps/routers/chat.py
@@ -22,6 +22,8 @@ from apps.schemas.task import Task
from apps.services.activity import Activity
from apps.services.blacklist import QuestionBlacklistManager, UserBlacklistManager
from apps.services.flow import FlowManager
+from apps.services.conversation import ConversationManager
+from apps.services.record import RecordManager
from apps.services.task import TaskManager
RECOMMEND_TRES = 5
@@ -32,25 +34,33 @@ router = APIRouter(
)
-async def init_task(post_body: RequestData, user_sub: str, session_id: str) -> Task:
+async def init_task(post_body: RequestData, user_sub: str) -> Task:
"""初始化Task"""
# 生成group_id
if not post_body.group_id:
post_body.group_id = str(uuid.uuid4())
- if post_body.new_task:
- # 创建或还原Task
- task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub)
- if task:
- await TaskManager.delete_task_by_task_id(task.id)
- task = await TaskManager.get_task(session_id=session_id, post_body=post_body, user_sub=user_sub)
+
# 更改信息并刷新数据库
if post_body.new_task:
- task.runtime.question = post_body.question
- task.ids.group_id = post_body.group_id
+ conversation = await ConversationManager.get_conversation_by_conversation_id(
+ user_sub=user_sub,
+ conversation_id=post_body.conversation_id,
+ )
+ if not conversation:
+ err = "[Chat] 用户没有权限访问该对话!"
+ raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=err)
+ task_ids = await TaskManager.delete_tasks_by_conversation_id(post_body.conversation_id)
+ await RecordManager.update_record_flow_status_to_cancelled_by_task_ids(task_ids)
+ task = await TaskManager.init_new_task(user_sub=user_sub, conversation_id=post_body.conversation_id, post_body=post_body)
+ else:
+ if not post_body.task_id:
+ err = "[Chat] task_id 不可为空!"
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="task_id cannot be empty")
+ task = await TaskManager.get_task_by_conversation_id(post_body.task_id)
return task
-async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) -> AsyncGenerator[str, None]:
+async def chat_generator(post_body: RequestData, user_sub: str) -> AsyncGenerator[str, None]:
"""进行实际问答,并从MQ中获取消息"""
try:
await Activity.set_active(user_sub)
@@ -62,7 +72,7 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str)
await Activity.remove_active(user_sub)
return
- task = await init_task(post_body, user_sub, session_id)
+ task = await init_task(post_body, user_sub)
# 创建queue;由Scheduler进行关闭
queue = MessageQueue()
@@ -120,7 +130,6 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str)
async def chat(
post_body: RequestData,
user_sub: Annotated[str, Depends(get_user)],
- session_id: Annotated[str, Depends(get_session)],
) -> StreamingResponse:
"""LLM流式对话接口"""
# 问题黑名单检测
@@ -133,7 +142,7 @@ async def chat(
if await Activity.is_active(user_sub):
raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many requests")
- res = chat_generator(post_body, user_sub, session_id)
+ res = chat_generator(post_body, user_sub)
return StreamingResponse(
content=res,
media_type="text/event-stream",
diff --git a/apps/routers/record.py b/apps/routers/record.py
index f73e6c9834147178e3141ed72689e4b42a927c6f..8d0105ff577ae4c425ad47fe3785743628627b4e 100644
--- a/apps/routers/record.py
+++ b/apps/routers/record.py
@@ -65,7 +65,7 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_
tmp_record = RecordData(
id=record.id,
groupId=record_group.id,
- taskId=record_group.task_id,
+ taskId=record.task_id,
conversationId=conversation_id,
content=record_data,
metadata=record.metadata
@@ -87,8 +87,9 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_
tmp_record.flow = RecordFlow(
id=record.flow.flow_id, # TODO: 此处前端应该用name
recordId=record.id,
- flowStatus=record.flow.flow_staus,
flowId=record.flow.flow_id,
+ flowName=record.flow.flow_name,
+ flowStatus=record.flow.flow_staus,
stepNum=len(flow_step_list),
steps=[],
)
diff --git a/apps/routers/user.py b/apps/routers/user.py
index 54e12f444181b56e408f7face5c4ed37be76008b..8c197204ec62861eddf250d9e73f5ebb7e5c71ad 100644
--- a/apps/routers/user.py
+++ b/apps/routers/user.py
@@ -3,10 +3,11 @@
from typing import Annotated
-from fastapi import APIRouter, Depends, status
+from fastapi import APIRouter, Body, Depends, status
from fastapi.responses import JSONResponse
from apps.dependency import get_user
+from apps.schemas.request_data import UserUpdateRequest
from apps.schemas.response_data import UserGetMsp, UserGetRsp
from apps.schemas.user import UserInfo
from apps.services.user import UserManager
@@ -18,7 +19,7 @@ router = APIRouter(
@router.get("")
-async def chat(
+async def get_user_sub(
user_sub: Annotated[str, Depends(get_user)],
) -> JSONResponse:
"""查询所有用户接口"""
@@ -42,3 +43,24 @@ async def chat(
result=UserGetMsp(userInfoList=user_info_list),
).model_dump(exclude_none=True, by_alias=True),
)
+
+
+@router.post("")
+async def update_user_info(
+ user_sub: Annotated[str, Depends(get_user)],
+ *,
+ data: Annotated[UserUpdateRequest, Body(..., description="用户更新信息")],
+) -> JSONResponse:
+ """更新用户信息接口"""
+ # 更新用户信息
+
+ result = await UserManager.update_userinfo_by_user_sub(user_sub, data)
+ if not result:
+ return JSONResponse(
+ status_code=status.HTTP_200_OK,
+ content={"code": status.HTTP_200_OK, "message": "用户信息更新成功"},
+ )
+ return JSONResponse(
+ status_code=status.HTTP_404_NOT_FOUND,
+ content={"code": status.HTTP_404_NOT_FOUND, "message": "用户信息更新失败"},
+ )
diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py
index cb8e183e9ac6036ac1a170d99b0da52b6e2a2d53..603ea65b66238628c60679449925892b85d6b189 100644
--- a/apps/scheduler/executor/agent.py
+++ b/apps/scheduler/executor/agent.py
@@ -5,26 +5,70 @@ import logging
from pydantic import Field
+from apps.llm.reasoning import ReasoningLLM
from apps.scheduler.executor.base import BaseExecutor
-from apps.scheduler.mcp_agent import host, plan, select
+from apps.schemas.enum_var import EventType, SpecialCallType, FlowStatus, StepStatus
+from apps.scheduler.mcp_agent.host import MCPHost
+from apps.scheduler.mcp_agent.plan import MCPPlanner
+from apps.scheduler.pool.mcp.client import MCPClient
+from apps.schemas.mcp import MCPCollection, MCPTool
from apps.schemas.task import ExecutorState, StepQueueItem
+from apps.schemas.message import param
from apps.services.task import TaskManager
-
+from apps.services.appcenter import AppCenterManager
+from apps.services.mcp_service import MCPServiceManager
logger = logging.getLogger(__name__)
class MCPAgentExecutor(BaseExecutor):
"""MCP Agent执行器"""
- question: str = Field(description="用户输入")
max_steps: int = Field(default=20, description="最大步数")
servers_id: list[str] = Field(description="MCP server id")
agent_id: str = Field(default="", description="Agent ID")
agent_description: str = Field(default="", description="Agent描述")
+ mcp_list: list[MCPCollection] = Field(description="MCP服务器列表", default=[])
+ mcp_client: dict[str, MCPClient] = Field(
+ description="MCP客户端列表,key为mcp_id", default={}
+ )
+ tool_list: list[MCPTool] = Field(description="MCP工具列表", default=[])
+ params: param | None = Field(
+ default=None, description="流执行过程中的参数补充", alias="params"
+ )
+ resoning_llm: ReasoningLLM = Field(
+ default=ReasoningLLM(),
+ description="推理大模型",
+ )
async def load_state(self) -> None:
"""从数据库中加载FlowExecutor的状态"""
logger.info("[FlowExecutor] 加载Executor状态")
# 尝试恢复State
- if self.task.state:
+ if self.task.state and self.task.state.flow_status != FlowStatus.INIT:
self.task.context = await TaskManager.get_context_by_task_id(self.task.id)
+
+ async def load_mcp(self) -> None:
+ """加载MCP服务器列表"""
+ logger.info("[MCPAgentExecutor] 加载MCP服务器列表")
+ # 获取MCP服务器列表
+ app = await AppCenterManager.fetch_app_data_by_id(self.agent_id)
+ mcp_ids = app.mcp_service
+ for mcp_id in mcp_ids:
+ mcp_service = await MCPServiceManager.get_mcp_service(mcp_id)
+ if self.task.ids.user_sub not in mcp_service.activated:
+ logger.warning(
+ "[MCPAgentExecutor] 用户 %s 未启用MCP %s",
+ self.task.ids.user_sub,
+ mcp_id,
+ )
+ continue
+
+ self.mcp_list.append(mcp_service)
+ self.mcp_client[mcp_id] = await MCPHost.get_client(self.task.ids.user_sub, mcp_id)
+ self.tool_list.extend(mcp_service.tools)
+
+ async def run(self) -> None:
+ """执行MCP Agent的主逻辑"""
+ # 初始化MCP服务
+ self.load_state()
+ self.load_mcp()
diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py
index d8400fdf32af2e6d073e74f66893c97a0e4c9a56..ebd4da8ecb6e55c5220a9cd4e0f63c72842c9828 100644
--- a/apps/scheduler/executor/flow.py
+++ b/apps/scheduler/executor/flow.py
@@ -55,7 +55,7 @@ class FlowExecutor(BaseExecutor):
"""从数据库中加载FlowExecutor的状态"""
logger.info("[FlowExecutor] 加载Executor状态")
# 尝试恢复State
- if self.task.state:
+ if self.task.state and self.task.state.flow_status != FlowStatus.INIT:
self.task.context = await TaskManager.get_context_by_task_id(self.task.id)
else:
# 创建ExecutorState
diff --git a/apps/scheduler/executor/step.py b/apps/scheduler/executor/step.py
index 377a4c6e5a0c9e344158ba9c398515dd22b18451..5a95e407989f4dfc9c26c2157ccbec5d3ecdad21 100644
--- a/apps/scheduler/executor/step.py
+++ b/apps/scheduler/executor/step.py
@@ -260,7 +260,7 @@ class StepExecutor(BaseExecutor):
input_data=self.obj.input,
output_data=output_data,
)
- self.task.context.append(history.model_dump(exclude_none=True, by_alias=True))
+ self.task.context.append(history)
# 推送输出
await self.push_message(EventType.STEP_OUTPUT.value, output_data)
diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py
index aa19611278f1ffc9ad723bea90165066e999ae44..8e11083900c20c29fcaf71b6535f3b442fc2313a 100644
--- a/apps/scheduler/mcp/host.py
+++ b/apps/scheduler/mcp/host.py
@@ -67,7 +67,7 @@ class MCPHost:
context_list = []
for ctx_id in self._context_list:
- context = next((ctx for ctx in task.context if ctx["_id"] == ctx_id), None)
+ context = next((ctx for ctx in task.context if ctx.id == ctx_id), None)
if not context:
continue
context_list.append(context)
@@ -118,7 +118,7 @@ class MCPHost:
logger.error("任务 %s 不存在", self._task_id)
return {}
self._context_list.append(context.id)
- task.context.append(context.model_dump(by_alias=True, exclude_none=True))
+ task.context.append(context.model_dump(exclude_none=True, by_alias=True))
await TaskManager.save_task(self._task_id, task)
return output_data
diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py
index a8ebec7b4d1182684a02e5b9f9ce68189af6c2e4..3217f539388d428009fc38e787ab0577cb73f51d 100644
--- a/apps/scheduler/mcp_agent/host.py
+++ b/apps/scheduler/mcp_agent/host.py
@@ -14,116 +14,53 @@ from apps.llm.function import JsonGenerator
from apps.scheduler.mcp.prompt import MEMORY_TEMPLATE
from apps.scheduler.pool.mcp.client import MCPClient
from apps.scheduler.pool.mcp.pool import MCPPool
+from apps.scheduler.mcp_agent.prompt import REPAIR_PARAMS
from apps.schemas.enum_var import StepStatus
from apps.schemas.mcp import MCPPlanItem, MCPTool
-from apps.schemas.task import FlowStepHistory
+from apps.schemas.task import Task, FlowStepHistory
from apps.services.task import TaskManager
logger = logging.getLogger(__name__)
+_env = SandboxedEnvironment(
+ loader=BaseLoader,
+ autoescape=True,
+ trim_blocks=True,
+ lstrip_blocks=True,
+)
+
class MCPHost:
"""MCP宿主服务"""
- def __init__(self, user_sub: str, task_id: str, runtime_id: str, runtime_name: str) -> None:
- """初始化MCP宿主"""
- self._user_sub = user_sub
- self._task_id = task_id
- # 注意:runtime在工作流中是flow_id和step_description,在Agent中可为标识Agent的id和description
- self._runtime_id = runtime_id
- self._runtime_name = runtime_name
- self._context_list = []
- self._env = SandboxedEnvironment(
- loader=BaseLoader(),
- autoescape=False,
- trim_blocks=True,
- lstrip_blocks=True,
- )
-
- async def get_client(self, mcp_id: str) -> MCPClient | None:
+ @staticmethod
+ async def get_client(user_sub, mcp_id: str) -> MCPClient | None:
"""获取MCP客户端"""
mongo = MongoDB()
mcp_collection = mongo.get_collection("mcp")
# 检查用户是否启用了这个mcp
- mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": self._user_sub})
+ mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": user_sub})
if not mcp_db_result:
- logger.warning("用户 %s 未启用MCP %s", self._user_sub, mcp_id)
+ logger.warning("用户 %s 未启用MCP %s", user_sub, mcp_id)
return None
# 获取MCP配置
try:
- return await MCPPool().get(mcp_id, self._user_sub)
+ return await MCPPool().get(mcp_id, user_sub)
except KeyError:
- logger.warning("用户 %s 的MCP %s 没有运行中的实例,请检查环境", self._user_sub, mcp_id)
+ logger.warning("用户 %s 的MCP %s 没有运行中的实例,请检查环境", user_sub, mcp_id)
return None
- async def assemble_memory(self) -> str:
+ @staticmethod
+ async def assemble_memory(task: Task) -> str:
"""组装记忆"""
- task = await TaskManager.get_task_by_task_id(self._task_id)
- if not task:
- logger.error("任务 %s 不存在", self._task_id)
- return ""
-
- context_list = []
- for ctx_id in self._context_list:
- context = next((ctx for ctx in task.context if ctx["_id"] == ctx_id), None)
- if not context:
- continue
- context_list.append(context)
- return self._env.from_string(MEMORY_TEMPLATE).render(
- context_list=context_list,
- )
-
- async def _save_memory(
- self,
- tool: MCPTool,
- plan_item: MCPPlanItem,
- input_data: dict[str, Any],
- result: str,
- ) -> dict[str, Any]:
- """保存记忆"""
- try:
- output_data = json.loads(result)
- except Exception: # noqa: BLE001
- logger.warning("[MCPHost] 得到的数据不是dict格式!尝试转换为str")
- output_data = {
- "message": result,
- }
-
- if not isinstance(output_data, dict):
- output_data = {
- "message": result,
- }
-
- # 创建context;注意用法
- context = FlowStepHistory(
- task_id=self._task_id,
- flow_id=self._runtime_id,
- flow_name=self._runtime_name,
- flow_status=StepStatus.SUCCESS,
- step_id=tool.name,
- step_name=tool.name,
- # description是规划的实际内容
- step_description=plan_item.content,
- step_status=StepStatus.SUCCESS,
- input_data=input_data,
- output_data=output_data,
+ return _env.from_string(MEMORY_TEMPLATE).render(
+ context_list=task.context,
)
- # 保存到task
- task = await TaskManager.get_task_by_task_id(self._task_id)
- if not task:
- logger.error("任务 %s 不存在", self._task_id)
- return {}
- self._context_list.append(context.id)
- task.context.append(context.model_dump(by_alias=True, exclude_none=True))
- await TaskManager.save_task(self._task_id, task)
-
- return output_data
-
- async def _fill_params(self, schema: dict[str, Any], query: str) -> dict[str, Any]:
+ async def _get_first_input_params(schema: dict[str, Any], query: str) -> dict[str, Any]:
"""填充工具参数"""
# 更清晰的输入·指令,这样可以调用generate
llm_query = rf"""
@@ -137,23 +74,48 @@ class MCPHost:
llm_query,
[
{"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": await self.assemble_memory()},
+ {"role": "user", "content": await MCPHost.assemble_memory()},
+ ],
+ schema,
+ )
+ return await json_generator.generate()
+
+ async def _fill_params(mcp_tool: MCPTool, schema: dict[str, Any],
+ current_input: dict[str, Any],
+ error_message: str = "", params: dict[str, Any] = {},
+ params_description: str = "") -> dict[str, Any]:
+ llm_query = "请生成修复之后的工具参数"
+ prompt = _env.from_string(REPAIR_PARAMS).render(
+ tool_name=mcp_tool.name,
+ tool_description=mcp_tool.description,
+ input_schema=schema,
+ current_input=current_input,
+ error_message=error_message,
+ params=params,
+ params_description=params_description,
+ )
+
+ json_generator = JsonGenerator(
+ llm_query,
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": prompt},
],
schema,
)
return await json_generator.generate()
- async def call_tool(self, tool: MCPTool, plan_item: MCPPlanItem) -> list[dict[str, Any]]:
+ async def call_tool(user_sub: str, tool: MCPTool, plan_item: MCPPlanItem) -> list[dict[str, Any]]:
"""调用工具"""
# 拿到Client
- client = await MCPPool().get(tool.mcp_id, self._user_sub)
+ client = await MCPPool().get(tool.mcp_id, user_sub)
if client is None:
err = f"[MCPHost] MCP Server不合法: {tool.mcp_id}"
logger.error(err)
raise ValueError(err)
# 填充参数
- params = await self._fill_params(tool, plan_item.instruction)
+ params = await MCPHost._fill_params(tool, plan_item.instruction)
# 调用工具
result = await client.call_tool(tool.name, params)
# 保存记忆
@@ -162,29 +124,12 @@ class MCPHost:
if not isinstance(item, TextContent):
logger.error("MCP结果类型不支持: %s", item)
continue
- processed_result.append(await self._save_memory(tool, plan_item, params, item.text))
-
- return processed_result
-
- async def get_tool_list(self, mcp_id_list: list[str]) -> list[MCPTool]:
- """获取工具列表"""
- mongo = MongoDB()
- mcp_collection = mongo.get_collection("mcp")
-
- # 获取工具列表
- tool_list = []
- for mcp_id in mcp_id_list:
- # 检查用户是否启用了这个mcp
- mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": self._user_sub})
- if not mcp_db_result:
- logger.warning("用户 %s 未启用MCP %s", self._user_sub, mcp_id)
- continue
- # 获取MCP工具配置
+ result = item.text
try:
- for tool in mcp_db_result["tools"]:
- tool_list.extend([MCPTool.model_validate(tool)])
- except KeyError:
- logger.warning("用户 %s 的MCP Tool %s 配置错误", self._user_sub, mcp_id)
+ json_result = json.loads(result)
+ except Exception as e:
+ logger.error("MCP结果解析失败: %s, 错误: %s", result, e)
continue
+ processed_result.append(json_result)
- return tool_list
+ return processed_result
diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py
index a7c1f132d892a979e221632dcdcfd8bd10c1aacc..13e7a98dc334dc55d74ebd174f644c66fe382c09 100644
--- a/apps/scheduler/mcp_agent/plan.py
+++ b/apps/scheduler/mcp_agent/plan.py
@@ -1,6 +1,6 @@
# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
"""MCP 用户目标拆解与规划"""
-from typing import Any
+from typing import Any, AsyncGenerator
from jinja2 import BaseLoader
from jinja2.sandbox import SandboxedEnvironment
@@ -57,8 +57,8 @@ class MCPPlanner:
result += chunk
# 保存token用量
- self.input_tokens = self.resoning_llm.input_tokens
- self.output_tokens = self.resoning_llm.output_tokens
+ self.input_tokens += self.resoning_llm.input_tokens
+ self.output_tokens += self.resoning_llm.output_tokens
return result
async def _parse_result(self, result: str, schema: dict[str, Any]) -> str:
@@ -171,7 +171,8 @@ class MCPPlanner:
"""获取推理大模型的风险评估结果"""
template = self._env.from_string(RISK_EVALUATE)
prompt = template.render(
- tool=tool,
+ tool_name=tool.name,
+ tool_description=tool.description,
input_param=input_param,
additional_info=additional_info,
)
@@ -194,7 +195,8 @@ class MCPPlanner:
schema_with_null = slot.add_null_to_basic_types()
template = self._env.from_string(GET_MISSING_PARAMS)
prompt = template.render(
- tool=tool,
+ tool_name=tool.name,
+ tool_description=tool.description,
input_param=input_param,
schema=schema_with_null,
error_message=error_message,
@@ -204,7 +206,7 @@ class MCPPlanner:
input_param_with_null = await self._parse_result(result, schema_with_null)
return input_param_with_null
- async def generate_answer(self, plan: MCPPlan, memory: str) -> str:
+ async def generate_answer(self, plan: MCPPlan, memory: str) -> AsyncGenerator[str, None]:
"""生成最终回答"""
template = self._env.from_string(FINAL_ANSWER)
prompt = template.render(
@@ -213,16 +215,12 @@ class MCPPlanner:
goal=self.user_goal,
)
- llm = ReasoningLLM()
- result = ""
- async for chunk in llm.call(
+ async for chunk in self.resoning_llm.call(
[{"role": "user", "content": prompt}],
streaming=False,
temperature=0.07,
):
- result += chunk
+ yield chunk
- self.input_tokens = llm.input_tokens
- self.output_tokens = llm.output_tokens
-
- return result
+ self.input_tokens = self.resoning_llm.input_tokens
+ self.output_tokens = self.resoning_llm.output_tokens
diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py
index cf05afc877e9ce08d790799d08e45c1f321d2d21..9cbc2f5bcb8e643bf9a844a527c740862fe9d180 100644
--- a/apps/scheduler/mcp_agent/prompt.py
+++ b/apps/scheduler/mcp_agent/prompt.py
@@ -421,8 +421,8 @@ RISK_EVALUATE = dedent(r"""
```
# 工具
- {{ tool.name }}
- {{ tool.description }}
+ {{ tool_name }}
+ {{ tool_description }}
# 工具入参
{{ input_param }}
@@ -464,7 +464,7 @@ JUDGE_NEXT_STEP = dedent(r"""
{
"content": "任务执行完成,端口扫描结果为Result[2]",
"tool": "Final",
- "instruction": ""
+ "instruction": ""
}
]}
## 当前使用的工具
@@ -489,8 +489,8 @@ JUDGE_NEXT_STEP = dedent(r"""
{{ current_plan }}
# 当前使用的工具
- {{ tool.name }}
- {{ tool.description }}
+ {{ tool_name }}
+ {{ tool_description }}
# 工具入参
{{ input_param }}
@@ -571,8 +571,8 @@ GET_MISSING_PARAMS = dedent(r"""
```
# 工具
< tool >
- < name > {{tool.name}} < /name >
- < description > {{tool.description}} < /description >
+ < name > {{tool_name}} < /name >
+ < description > {{tool_description}} < /description >
< / tool >
# 工具入参
{{input_param}}
@@ -583,32 +583,107 @@ GET_MISSING_PARAMS = dedent(r"""
# 输出
"""
)
+REPAIR_PARAMS = dedent(r"""
+ 你是一个工具参数修复器。
+ 你的任务是根据当前的工具信息、工具入参的schema、工具当前的入参、工具的报错、补充的参数和补充的参数描述,修复当前工具的入参。
+
+ # 样例
+ ## 工具信息
+
+ mysql_analyzer
+ 分析MySQL数据库性能
+
+ ## 工具入参的schema
+ {
+ "type": "object",
+ "properties": {
+ "host": {
+ "type": "string",
+ "description": "MySQL数据库的主机地址"
+ },
+ "port": {
+ "type": "integer",
+ "description": "MySQL数据库的端口号"
+ },
+ "username": {
+ "type": "string",
+ "description": "MySQL数据库的用户名"
+ },
+ "password": {
+ "type": "string",
+ "description": "MySQL数据库的密码"
+ }
+ },
+ "required": ["host", "port", "username", "password"]
+ }
+ ## 工具当前的入参
+ {
+ "host": "192.0.0.1",
+ "port": 3306,
+ "username": "root",
+ "password": "password"
+ }
+ ## 工具的报错
+ 执行端口扫描命令时,出现了错误:`password is not correct`。
+ ## 补充的参数
+ {
+ "username": "admin",
+ "password": "admin123"
+ }
+ ## 补充的参数描述
+ 用户希望使用admin用户和admin123密码来连接MySQL数据库。
+ # 输出
+ ```json
+ {
+ "host": "192.0.0.1",
+ "port": 3306,
+ "username": "admin",
+ "password": "admin123"
+ }
+ ```
+ # 工具
+
+ {{tool_name}}
+ {{tool_description}}
+
+ # 工具入参scheme
+ {{input_schema}}
+ # 工具入参
+ {{input_param}}
+ # 运行报错
+ {{error_message}}
+ # 补充的参数
+ {{params}}
+ # 补充的参数描述
+ {{params_description}}
+ # 输出
+ """
+ )
FINAL_ANSWER = dedent(r"""
综合理解计划执行结果和背景信息,向用户报告目标的完成情况。
# 用户目标
- {{goal}}
+ {{ goal }}
# 计划执行情况
为了完成上述目标,你实施了以下计划:
- {{memory}}
+ {{ memory }}
# 其他背景信息:
- {{status}}
+ {{ status }}
# 现在,请根据以上信息,向用户报告目标的完成情况:
""")
-
MEMORY_TEMPLATE = dedent(r"""
- { % for ctx in context_list % }
+ {% for ctx in context_list % }
- 第{{loop.index}}步:{{ctx.step_description}}
- 调用工具 `{{ctx.step_id}}`,并提供参数 `{{ctx.input_data}}`
- 执行状态:{{ctx.status}}
- 得到数据:`{{ctx.output_data}}`
- { % endfor % }
+ {% endfor % }
""")
diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py
index 72f64a71006e09a053499268dac9d0245d692c2f..3b26f42f2995db9911a61d85e140f4ec2da2d5e1 100644
--- a/apps/scheduler/scheduler/context.py
+++ b/apps/scheduler/scheduler/context.py
@@ -199,21 +199,21 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None:
# 检查是否存在group_id
if not await RecordManager.check_group_id(task.ids.group_id, user_sub):
- record_group = await RecordManager.create_record_group(
- task.ids.group_id, user_sub, post_body.conversation_id, task.id,
+ record_group_id = await RecordManager.create_record_group(
+ task.ids.group_id, user_sub, post_body.conversation_id
)
- if not record_group:
+ if not record_group_id:
logger.error("[Scheduler] 创建问答组失败")
return
else:
- record_group = task.ids.group_id
+ record_group_id = task.ids.group_id
# 修改文件状态
- await DocumentManager.change_doc_status(user_sub, post_body.conversation_id, record_group)
+ await DocumentManager.change_doc_status(user_sub, post_body.conversation_id, record_group_id)
# 保存Record
- await RecordManager.insert_record_data_into_record_group(user_sub, record_group, record)
+ await RecordManager.insert_record_data_into_record_group(user_sub, record_group_id, record)
# 保存与答案关联的文件
- await DocumentManager.save_answer_doc(user_sub, record_group, used_docs)
+ await DocumentManager.save_answer_doc(user_sub, record_group_id, used_docs)
if post_body.app and post_body.app.app_id:
# 更新最近使用的应用
@@ -223,5 +223,4 @@ async def save_data(task: Task, user_sub: str, post_body: RequestData) -> None:
if not task.state or task.state.flow_status == StepStatus.SUCCESS or task.state.flow_status == StepStatus.ERROR or task.state.flow_status == StepStatus.CANCELLED:
await TaskManager.delete_task_by_task_id(task.id)
else:
- # 更新Task
await TaskManager.save_task(task.id, task)
diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py
index 91930f8cd6ae9aa2b872962b668cc382e8b56f2b..f6325369355489cff2577a15d24817b8d54e26dd 100644
--- a/apps/scheduler/scheduler/scheduler.py
+++ b/apps/scheduler/scheduler/scheduler.py
@@ -229,7 +229,8 @@ class Scheduler:
# 初始化Executor
logger.info("[Scheduler] 初始化Executor")
-
+ logger.error(f"{flow_data}")
+ logger.error(f"{self.task}")
flow_exec = FlowExecutor(
flow_id=flow_id,
flow=flow_data,
diff --git a/apps/schemas/collection.py b/apps/schemas/collection.py
index 0ff66c72bbe30b7cbb55517952c8b14744d69cf2..a2991f851a85c8d60afb6d60e76cbb766c9da94e 100644
--- a/apps/schemas/collection.py
+++ b/apps/schemas/collection.py
@@ -61,6 +61,7 @@ class User(BaseModel):
fav_apps: list[str] = []
fav_services: list[str] = []
is_admin: bool = Field(default=False, description="是否为管理员")
+ auto_execute: bool = Field(default=True, description="是否自动执行任务")
class LLM(BaseModel):
diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py
index 5ff3381c74e65baa9665e6b37109cb842e6c88f8..3fb65028799b91652257f994df6887f38acb7631 100644
--- a/apps/schemas/enum_var.py
+++ b/apps/schemas/enum_var.py
@@ -27,6 +27,7 @@ class FlowStatus(str, Enum):
"""Flow状态"""
UNKNOWN = "unknown"
+ INIT = "init"
WAITING = "waiting"
RUNNING = "running"
SUCCESS = "success"
diff --git a/apps/schemas/message.py b/apps/schemas/message.py
index cf70a82b8573d2e3e34564b8f0ec5280195980d9..e7341324266c26447b9d36195141e09addc4520d 100644
--- a/apps/schemas/message.py
+++ b/apps/schemas/message.py
@@ -5,10 +5,16 @@ from typing import Any
from datetime import UTC, datetime
from pydantic import BaseModel, Field
-from apps.schemas.enum_var import EventType, StepStatus
+from apps.schemas.enum_var import EventType, FlowStatus, StepStatus
from apps.schemas.record import RecordMetadata
+class param(BaseModel):
+ """流执行过程中的参数补充"""
+ content: dict[str, Any] | bool = Field(default={}, description="流执行过程中的参数补充内容")
+ description: str = Field(default="", description="流执行过程中的参数补充描述")
+
+
class HeartbeatData(BaseModel):
"""心跳事件的数据结构"""
@@ -22,6 +28,8 @@ class MessageFlow(BaseModel):
app_id: str = Field(description="插件ID", alias="appId")
flow_id: str = Field(description="Flow ID", alias="flowId")
+ flow_name: str = Field(description="Flow名称", alias="flowName")
+ flow_status: FlowStatus = Field(description="Flow状态", alias="flowStatus", default=FlowStatus.UNKNOWN)
step_id: str = Field(description="当前步骤ID", alias="stepId")
step_name: str = Field(description="当前步骤名称", alias="stepName")
sub_step_id: str | None = Field(description="当前子步骤ID", alias="subStepId", default=None)
diff --git a/apps/schemas/record.py b/apps/schemas/record.py
index d322276254347caf251bf96b3b01c93f9801c4ee..6a394375b02fc340119af397739f16c932bc0cbd 100644
--- a/apps/schemas/record.py
+++ b/apps/schemas/record.py
@@ -44,6 +44,7 @@ class RecordFlow(BaseModel):
id: str
record_id: str = Field(alias="recordId")
flow_id: str = Field(alias="flowId")
+ flow_name: str = Field(alias="flowName", default="")
flow_status: StepStatus = Field(alias="flowStatus", default=StepStatus.SUCCESS)
step_num: int = Field(alias="stepNum")
steps: list[RecordFlowStep]
@@ -129,6 +130,7 @@ class Record(RecordData):
user_sub: str
key: dict[str, Any] = {}
+ task_id: str
content: str
comment: RecordComment = Field(default=RecordComment())
flow: FlowHistory = Field(
@@ -149,5 +151,4 @@ class RecordGroup(BaseModel):
records: list[Record] = []
docs: list[RecordGroupDocument] = [] # 问题不变,所用到的文档不变
conversation_id: str
- task_id: str
created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3))
diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py
index 793ff45638c9c4a974d637728f6aa1045ef0ecbd..8719c2e985990b4f7e0f122353fb6097a359ff8a 100644
--- a/apps/schemas/request_data.py
+++ b/apps/schemas/request_data.py
@@ -10,6 +10,7 @@ from apps.schemas.appcenter import AppData
from apps.schemas.enum_var import CommentType
from apps.schemas.flow_topology import FlowItem
from apps.schemas.mcp import MCPType
+from apps.schemas.message import param
class RequestDataApp(BaseModel):
@@ -17,7 +18,7 @@ class RequestDataApp(BaseModel):
app_id: str = Field(description="应用ID", alias="appId")
flow_id: str | None = Field(default=None, description="Flow ID", alias="flowId")
- params: dict[str, Any] | None = Field(default=None, description="插件参数")
+ params: param | None = Field(default=None, description="流执行过程中的参数补充", alias="params")
class MockRequestData(BaseModel):
@@ -46,6 +47,7 @@ class RequestData(BaseModel):
files: list[str] = Field(default=[], description="文件列表")
app: RequestDataApp | None = Field(default=None, description="应用")
debug: bool = Field(default=False, description="是否调试")
+ task_id: str | None = Field(default=None, alias="taskId", description="任务ID")
new_task: bool = Field(default=True, description="是否新建任务")
@@ -185,3 +187,9 @@ class UpdateKbReq(BaseModel):
"""更新知识库请求体"""
kb_ids: list[str] = Field(description="知识库ID列表", alias="kbIds", default=[])
+
+
+class UserUpdateRequest(BaseModel):
+ """更新用户信息请求体"""
+
+ auto_execute: bool = Field(default=False, description="是否自动执行", alias="autoExecute")
diff --git a/apps/schemas/response_data.py b/apps/schemas/response_data.py
index b1dc77b75ba11943e149dc8c471bc3141d684442..7f162326d278b4972cd3d708739213354b5886b0 100644
--- a/apps/schemas/response_data.py
+++ b/apps/schemas/response_data.py
@@ -55,6 +55,7 @@ class AuthUserMsg(BaseModel):
user_sub: str
revision: bool
is_admin: bool
+ auto_execute: bool
class AuthUserRsp(ResponseData):
diff --git a/apps/schemas/task.py b/apps/schemas/task.py
index 9e9f1531707edaad5c306098662a9d4a6fbce64b..eccc95a567be95d1270ffb7797313d07b6467318 100644
--- a/apps/schemas/task.py
+++ b/apps/schemas/task.py
@@ -40,17 +40,18 @@ class ExecutorState(BaseModel):
flow_id: str = Field(description="Flow ID", default="")
flow_name: str = Field(description="Flow名称", default="")
description: str = Field(description="Flow描述", default="")
- flow_status: FlowStatus = Field(description="Flow状态", default=FlowStatus.UNKNOWN)
+ flow_status: FlowStatus = Field(description="Flow状态", default=FlowStatus.INIT)
# 任务级数据
step_id: str = Field(description="当前步骤ID", default="")
step_name: str = Field(description="当前步骤名称", default="")
step_status: StepStatus = Field(description="当前步骤状态", default=StepStatus.UNKNOWN)
step_description: str = Field(description="当前步骤描述", default="")
- retry_times: int = Field(description="当前步骤重试次数", default=0)
- error_message: str = Field(description="当前步骤错误信息", default="")
app_id: str = Field(description="应用ID", default="")
- slot: dict[str, Any] = Field(description="待填充参数的JSON Schema", default={})
- error_info: dict[str, Any] = Field(description="错误信息", default={})
+ current_input: dict[str, Any] = Field(description="当前输入数据", default={})
+ params: dict[str, Any] = Field(description="补充的参数", default={})
+ params_description: str = Field(description="补充的参数描述", default="")
+ error_info: str = Field(description="错误信息", default="")
+ retry_times: int = Field(description="当前步骤重试次数", default=0)
class TaskIds(BaseModel):
diff --git a/apps/schemas/user.py b/apps/schemas/user.py
index 61aa2587b8ae6255dc3885b04a96f12310f70773..debb3446e3b4f6a092ea4b581b66a45ea2a8c970 100644
--- a/apps/schemas/user.py
+++ b/apps/schemas/user.py
@@ -9,3 +9,4 @@ class UserInfo(BaseModel):
user_sub: str = Field(alias="userSub", default="")
user_name: str = Field(alias="userName", default="")
+ auto_execute: bool = Field(alias="autoExecute", default=False)
diff --git a/apps/services/conversation.py b/apps/services/conversation.py
index bac964db132fecc9599bb80943e6fa119048f387..6bacb72720dbe43ff6835d4714c720f2778d85b1 100644
--- a/apps/services/conversation.py
+++ b/apps/services/conversation.py
@@ -140,4 +140,4 @@ class ConversationManager:
await record_group_collection.delete_many({"conversation_id": conversation_id}, session=session)
await session.commit_transaction()
- await TaskManager.delete_tasks_by_conversation_id(conversation_id)
+ await TaskManager.delete_tasks_and_flow_context_by_conversation_id(conversation_id)
diff --git a/apps/services/record.py b/apps/services/record.py
index 5c8a89dfd224166ac822e2eeb47fb79f1fefe7e4..6b61f91ecd841a9c585ada2913807f7f3c7dfe18 100644
--- a/apps/services/record.py
+++ b/apps/services/record.py
@@ -9,7 +9,7 @@ from apps.schemas.record import (
Record,
RecordGroup,
)
-
+from apps.schemas.enum_var import FlowStatus
logger = logging.getLogger(__name__)
@@ -17,7 +17,7 @@ class RecordManager:
"""问答对相关操作"""
@staticmethod
- async def create_record_group(group_id: str, user_sub: str, conversation_id: str, task_id: str) -> str | None:
+ async def create_record_group(group_id: str, user_sub: str, conversation_id: str) -> str | None:
"""创建问答组"""
mongo = MongoDB()
record_group_collection = mongo.get_collection("record_group")
@@ -26,7 +26,6 @@ class RecordManager:
_id=group_id,
user_sub=user_sub,
conversation_id=conversation_id,
- task_id=task_id,
)
try:
@@ -49,6 +48,10 @@ class RecordManager:
mongo = MongoDB()
group_collection = mongo.get_collection("record_group")
try:
+ await group_collection.update_one(
+ {"_id": group_id, "user_sub": user_sub},
+ {"$pull": {"records": {"id": record.id}}}
+ )
await group_collection.update_one(
{"_id": group_id, "user_sub": user_sub},
{"$push": {"records": record.model_dump(by_alias=True)}},
@@ -133,6 +136,19 @@ class RecordManager:
logger.exception("[RecordManager] 查询问答组失败")
return []
+ @staticmethod
+ async def update_record_flow_status_to_cancelled_by_task_ids(task_ids: list[str]) -> None:
+ """更新Record关联的Flow状态"""
+ record_group_collection = MongoDB().get_collection("record_group")
+ try:
+ await record_group_collection.update_many(
+ {"records.flow.flow_id": {"$in": task_ids}, "records.flow.flow_status": {"$nin": [FlowStatus.ERROR.value, FlowStatus.SUCCESS.value]}},
+ {"$set": {"records.$[elem].flow.flow_status": FlowStatus.CANCELLED}},
+ array_filters=[{"elem.flow.flow_id": {"$in": task_ids}}],
+ )
+ except Exception:
+ logger.exception("[RecordManager] 更新Record关联的Flow状态失败")
+
@staticmethod
async def verify_record_in_group(group_id: str, record_id: str, user_sub: str) -> bool:
"""
@@ -151,7 +167,6 @@ class RecordManager:
logger.exception("[RecordManager] 验证记录是否在组中失败")
return False
-
@staticmethod
async def check_group_id(group_id: str, user_sub: str) -> bool:
"""检查group_id是否存在"""
diff --git a/apps/services/task.py b/apps/services/task.py
index 93604a933d78c84b36367d4873ead02b08193f15..2f75a8c3a034cb94e051efd1ee00496f9ab0351d 100644
--- a/apps/services/task.py
+++ b/apps/services/task.py
@@ -94,7 +94,7 @@ class TaskManager:
return flow_context_list
@staticmethod
- async def get_context_by_task_id(task_id: str, length: int = 0) -> list[dict[str, Any]]:
+ async def get_context_by_task_id(task_id: str, length: int = 0) -> list[FlowStepHistory]:
"""根据task_id获取flow信息"""
flow_context_collection = MongoDB().get_collection("flow_context")
@@ -105,7 +105,8 @@ class TaskManager:
).sort(
"created_at", -1,
).limit(length):
- flow_context += [history]
+ for i in range(len(flow_context)):
+ flow_context.append(FlowStepHistory.model_validate(history))
except Exception:
logger.exception("[TaskManager] 获取task_id的flow信息失败")
return []
@@ -113,7 +114,29 @@ class TaskManager:
return flow_context
@staticmethod
- async def save_flow_context(task_id: str, flow_context: list[dict[str, Any]]) -> None:
+ async def init_new_task(
+ cls,
+ user_sub: str,
+ session_id: str | None = None,
+ post_body: RequestData | None = None,
+ ) -> Task:
+ """获取任务块"""
+ return Task(
+ _id=str(uuid.uuid4()),
+ ids=TaskIds(
+ user_sub=user_sub if user_sub else "",
+ session_id=session_id if session_id else "",
+ conversation_id=post_body.conversation_id,
+ group_id=post_body.group_id if post_body.group_id else "",
+ ),
+ question=post_body.question if post_body else "",
+ group_id=post_body.group_id if post_body else "",
+ tokens=TaskTokens(),
+ runtime=TaskRuntime(),
+ )
+
+ @staticmethod
+ async def save_flow_context(task_id: str, flow_context: list[FlowStepHistory]) -> None:
"""保存flow信息到flow_context"""
flow_context_collection = MongoDB().get_collection("flow_context")
try:
@@ -121,15 +144,15 @@ class TaskManager:
# 查找是否存在
current_context = await flow_context_collection.find_one({
"task_id": task_id,
- "_id": history["_id"],
+ "_id": history.id,
})
if current_context:
await flow_context_collection.update_one(
{"_id": current_context["_id"]},
- {"$set": history},
+ {"$set": history.model_dump(exclude_none=True, by_alias=True)},
)
else:
- await flow_context_collection.insert_one(history)
+ await flow_context_collection.insert_one(history.model_dump(exclude_none=True, by_alias=True))
except Exception:
logger.exception("[TaskManager] 保存flow执行记录失败")
@@ -144,7 +167,25 @@ class TaskManager:
await task_collection.delete_one({"_id": task_id})
@staticmethod
- async def delete_tasks_by_conversation_id(conversation_id: str) -> None:
+ async def delete_tasks_by_conversation_id(conversation_id: str) -> list[str]:
+ """通过ConversationID删除Task信息"""
+ mongo = MongoDB()
+ task_collection = mongo.get_collection("task")
+ task_ids = []
+ try:
+ async for task in task_collection.find(
+ {"conversation_id": conversation_id},
+ {"_id": 1},
+ ):
+ task_ids.append(task["_id"])
+ if task_ids:
+ await task_collection.delete_many({"conversation_id": conversation_id})
+ except Exception:
+ logger.exception("[TaskManager] 删除ConversationID的Task信息失败")
+ return []
+
+ @staticmethod
+ async def delete_tasks_and_flow_context_by_conversation_id(conversation_id: str) -> None:
"""通过ConversationID删除Task信息"""
mongo = MongoDB()
task_collection = mongo.get_collection("task")
@@ -161,49 +202,6 @@ class TaskManager:
await task_collection.delete_many({"conversation_id": conversation_id}, session=session)
await flow_context_collection.delete_many({"task_id": {"$in": task_ids}}, session=session)
- @classmethod
- async def get_task(
- cls,
- task_id: str | None = None,
- session_id: str | None = None,
- post_body: RequestData | None = None,
- user_sub: str | None = None,
- ) -> Task:
- """获取任务块"""
- if task_id:
- try:
- task = await cls.get_task_by_task_id(task_id)
- if task:
- return task
- except Exception:
- logger.exception("[TaskManager] 通过task_id获取任务失败")
-
- logger.info("[TaskManager] 未提供task_id,通过session_id获取任务")
- if not session_id or not post_body:
- err = (
- "session_id 和 conversation_id 或 group_id 和 conversation_id 是恢复/创建任务的必要条件。"
- )
- raise ValueError(err)
-
- if post_body.group_id:
- task = await cls.get_task_by_group_id(post_body.group_id, post_body.conversation_id)
- else:
- task = await cls.get_task_by_conversation_id(post_body.conversation_id)
-
- if task:
- return task
- return Task(
- _id=str(uuid.uuid4()),
- ids=TaskIds(
- user_sub=user_sub if user_sub else "",
- session_id=session_id if session_id else "",
- conversation_id=post_body.conversation_id,
- group_id=post_body.group_id if post_body.group_id else "",
- ),
- tokens=TaskTokens(),
- runtime=TaskRuntime(),
- )
-
@classmethod
async def save_task(cls, task_id: str, task: Task) -> None:
"""保存任务块"""
diff --git a/apps/services/user.py b/apps/services/user.py
index 2721d3773bd43e2a8fda5b371d6336fcf0b9b7f3..476f1ef43b80fa5d25c91820f471bf4ff035f475 100644
--- a/apps/services/user.py
+++ b/apps/services/user.py
@@ -4,6 +4,7 @@
import logging
from datetime import UTC, datetime
+from apps.schemas.request_data import UserUpdateRequest
from apps.common.mongo import MongoDB
from apps.schemas.collection import User
from apps.services.conversation import ConversationManager
@@ -52,7 +53,26 @@ class UserManager:
return User(**user_data) if user_data else None
@staticmethod
- async def update_userinfo_by_user_sub(user_sub: str, *, refresh_revision: bool = False) -> bool:
+ async def update_userinfo_by_user_sub(user_sub: str, data: UserUpdateRequest) -> bool:
+ """
+ 根据用户sub更新用户信息
+
+ :param user_sub: 用户sub
+ :param data: 用户更新信息
+ :return: 是否更新成功
+ """
+ mongo = MongoDB()
+ user_collection = mongo.get_collection("user")
+ update_dict = {
+ "$set": {
+ "auto_execute": data.auto_execute,
+ }
+ }
+ result = await user_collection.update_one({"_id": user_sub}, update_dict)
+ return result.modified_count > 0
+
+ @staticmethod
+ async def update_refresh_revision_by_user_sub(user_sub: str, *, refresh_revision: bool = False) -> bool:
"""
根据用户sub更新用户信息
diff --git a/tests/manager/test_user.py b/tests/manager/test_user.py
index d350ca307a9f0aa2d6532cae614b8633b1ae0f81..ab6eeab4c498a2619c4f08c5b548a6096611aeaa 100644
--- a/tests/manager/test_user.py
+++ b/tests/manager/test_user.py
@@ -73,7 +73,7 @@ class TestUserManager(unittest.TestCase):
mock_mysql_db_instance.get_session.return_value = mock_session
# 调用被测方法
- updated_userinfo = UserManager.update_userinfo_by_user_sub(userinfo, refresh_revision=True)
+ updated_userinfo = UserManager.update_refresh_revision_by_user_sub(userinfo, refresh_revision=True)
# 断言返回的用户信息的 revision_number 是否与原始用户信息一致
self.assertEqual(updated_userinfo.revision_number, userinfo.revision_number)