From 78f633cec687fcccf2476cad46f8c0d8d38cfd70 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Fri, 24 Oct 2025 17:26:02 +0800 Subject: [PATCH 1/5] =?UTF-8?q?userSub=E6=94=B9userId=EF=BC=881=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/queue.py | 4 +- apps/dependency/user.py | 32 +++++------ apps/models/app.py | 10 ++-- apps/models/comment.py | 4 +- apps/models/conversation.py | 4 +- apps/models/document.py | 4 +- apps/models/mcp.py | 6 +- apps/models/record.py | 4 +- apps/models/service.py | 8 +-- apps/models/session.py | 8 +-- apps/models/settings.py | 2 +- apps/models/task.py | 2 +- apps/models/user.py | 16 +++--- apps/routers/chat.py | 24 ++++---- apps/routers/comment.py | 2 +- apps/routers/conversation.py | 14 ++--- apps/routers/document.py | 6 +- apps/routers/flow.py | 10 ++-- apps/routers/parameter.py | 2 +- apps/routers/record.py | 4 +- apps/schemas/record.py | 4 +- apps/services/activity.py | 17 +++--- apps/services/comment.py | 4 +- apps/services/conversation.py | 38 ++++++------- apps/services/flow.py | 8 +-- apps/services/llm.py | 1 - apps/services/mcp_service.py | 86 +++++++++++++++-------------- apps/services/personal_token.py | 10 ++-- apps/services/record.py | 18 +++--- apps/services/service.py | 66 +++++++++++----------- apps/services/session.py | 26 ++++----- apps/services/settings.py | 6 +- apps/services/tag.py | 8 +-- apps/services/task.py | 8 +-- apps/services/token.py | 24 ++++---- apps/services/user_tag.py | 10 ++-- tests/common/test_queue.py | 2 +- tests/schemas/test_user.py | 14 ++--- tests/services/test_conversation.py | 12 ++-- 39 files changed, 264 insertions(+), 264 deletions(-) diff --git a/apps/common/queue.py b/apps/common/queue.py index 685511477..87089a92c 100644 --- a/apps/common/queue.py +++ b/apps/common/queue.py @@ -13,7 +13,7 @@ from apps.schemas.enum_var import EventType from apps.schemas.message import ( HeartbeatData, MessageBase, - MessageFlow, + MessageExecutor, MessageMetadata, ) from apps.schemas.task import TaskData @@ -54,7 +54,7 @@ class MessageQueue: if task.state: # 如果使用了Flow - flow = MessageFlow( + flow = MessageExecutor( appId=task.state.appId, executorId=task.state.executorId, executorName=task.state.executorName, diff --git a/apps/dependency/user.py b/apps/dependency/user.py index 545b5a959..4867df3a9 100644 --- a/apps/dependency/user.py +++ b/apps/dependency/user.py @@ -22,7 +22,7 @@ async def verify_session(request: HTTPConnection) -> None: - 如果Authorization头不存在或不以Bearer开头,抛出401 - 如果Bearer token以sk-开头,跳过(由verify_personal_token处理) - 如果Bearer token不以sk-开头,则作为Session ID校验 - - 如果是合法session则设置user_sub + - 如果是合法session则设置user_id :param request: HTTP请求 :return: @@ -48,29 +48,29 @@ async def verify_session(request: HTTPConnection) -> None: # 作为Session ID校验 session_id = token request.state.session_id = session_id - user = await SessionManager.get_user(session_id) - if not user: + user_id = await SessionManager.get_user(session_id) + if not user_id: logger.warning("Session ID鉴权失败:无效的session_id=%s", session_id) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Session ID 鉴权失败", ) - request.state.user_sub = user + request.state.user_id = user_id async def verify_personal_token(request: HTTPConnection) -> None: """ 验证Personal Token是否有效;作为第二层鉴权检查 - - 如果已经通过verify_session设置了user_sub,则跳过 + - 如果已经通过verify_session设置了user_id,则跳过 - 如果Bearer token以sk-开头,则作为Personal Token校验 - - 合法则设置user_sub,不合法则抛出401 + - 合法则设置user_id,不合法则抛出401 :param request: HTTP请求 :return: """ # 如果已经通过Session验证,则跳过 - if hasattr(request.state, "user_sub"): + if hasattr(request.state, "user_id"): return auth_header = request.headers.get("Authorization") @@ -87,9 +87,9 @@ async def verify_personal_token(request: HTTPConnection) -> None: # 检查是否为Personal Token(以sk-开头) if token.startswith("sk-"): # 验证是否为合法的Personal Token - user_sub = await PersonalTokenManager.get_user_by_personal_token(token) - if user_sub is not None: - request.state.user_sub = user_sub + user_id = await PersonalTokenManager.get_user_by_personal_token(token) + if user_id is not None: + request.state.user_id = user_id else: # Personal Token无效,抛出401 logger.warning("Personal Token鉴权失败:无效的token") @@ -100,15 +100,15 @@ async def verify_personal_token(request: HTTPConnection) -> None: async def verify_admin(request: HTTPConnection) -> None: """验证用户是否为管理员""" - if not hasattr(request.state, "user_sub"): + if not hasattr(request.state, "user_id"): logger.warning("管理员鉴权失败:用户未登录") raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户未登录") - user_sub = request.state.user_sub - user = await UserManager.get_user(user_sub) + user_id = request.state.user_id + user = await UserManager.get_user(user_id) request.state.user = user if not user: - logger.warning("管理员鉴权失败:用户不存在,user_sub=%s", user_sub) + logger.warning("管理员鉴权失败:用户不存在,user_id=%s", user_id) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户不存在") - if user.userSub not in config.login.admin_user: - logger.warning("管理员鉴权失败:用户无管理员权限,user_sub=%s", user_sub) + if user.userName not in config.login.admin_user: + logger.warning("管理员鉴权失败:用户无管理员权限,user_id=%s", user_id) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户无权限") diff --git a/apps/models/app.py b/apps/models/app.py index 0f41493d1..a2312c6b7 100644 --- a/apps/models/app.py +++ b/apps/models/app.py @@ -36,8 +36,8 @@ class App(Base): """应用名称""" description: Mapped[str] = mapped_column(String(2000), nullable=False) """应用描述""" - author: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.userSub"), nullable=False) - """应用作者""" + authorId: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.id"), nullable=False) # noqa: N815 + """应用作者ID""" appType: Mapped[AppType] = mapped_column(Enum(AppType), nullable=False) # noqa: N815 """应用类型""" id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default_factory=uuid.uuid4) @@ -60,7 +60,7 @@ class App(Base): """权限类型""" # 索引 idx_published_updated_at: ClassVar[Index] = Index("idx_published_updated_at", "isPublished", "updatedAt") - idx_author_id_name: ClassVar[Index] = Index("idx_author_id_name", "author", "id", "name") + idx_author_id_name: ClassVar[Index] = Index("idx_author_id_name", "authorId", "id", "name") __table_args__: ClassVar[tuple[Any, ...]] = ( idx_published_updated_at, @@ -75,8 +75,8 @@ class AppACL(Base): __tablename__ = "framework_app_acl" appId: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("framework_app.id"), primary_key=True) # noqa: N815 """关联的应用ID""" - userSub: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.userSub"), nullable=False, index=True) # noqa: N815 - """用户名""" + userId: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.id"), nullable=False, index=True) # noqa: N815 + """用户ID""" action: Mapped[str] = mapped_column(String(255), default="", index=True, nullable=False) """操作类型(读/写)""" diff --git a/apps/models/comment.py b/apps/models/comment.py index 0c438c205..fdc627937 100644 --- a/apps/models/comment.py +++ b/apps/models/comment.py @@ -30,8 +30,8 @@ class Comment(Base): UUID(as_uuid=True), ForeignKey("framework_record.id"), nullable=False, index=True, ) """问答对ID""" - userSub: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.userSub"), nullable=False) # noqa: N815 - """用户名""" + userId: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.id"), nullable=False) # noqa: N815 + """用户ID""" commentType: Mapped[CommentType] = mapped_column(Enum(CommentType), nullable=False) # noqa: N815 """点赞点踩""" feedbackType: Mapped[list[str]] = mapped_column(ARRAY(String(100)), nullable=False) # noqa: N815 diff --git a/apps/models/conversation.py b/apps/models/conversation.py index 82645b918..8ba910396 100644 --- a/apps/models/conversation.py +++ b/apps/models/conversation.py @@ -18,8 +18,8 @@ class Conversation(Base): """对话""" __tablename__ = "framework_conversation" - userSub: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.userSub"), nullable=False) # noqa: N815 - """用户名""" + userId: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.id"), nullable=False) # noqa: N815 + """用户ID""" appId: Mapped[uuid.UUID | None] = mapped_column( # noqa: N815 UUID(as_uuid=True), ForeignKey("framework_app.id"), nullable=True, ) diff --git a/apps/models/document.py b/apps/models/document.py index 6e9e02405..6fdea76b1 100644 --- a/apps/models/document.py +++ b/apps/models/document.py @@ -16,7 +16,7 @@ class Document(Base): """文件""" __tablename__ = "framework_document" - userSub: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.userSub"), nullable=False) # noqa: N815 + userId: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.id"), nullable=False) # noqa: N815 """用户ID""" name: Mapped[str] = mapped_column(String(255), nullable=False) """文件名称""" @@ -39,7 +39,7 @@ class Document(Base): ) """文件的创建时间""" # 索引 - idx_user_conversation: ClassVar[Index] = Index("idx_user_sub_conversation_id", "userSub", "conversationId") + idx_user_conversation: ClassVar[Index] = Index("idx_user_id_conversation_id", "userId", "conversationId") __table_args__: ClassVar[tuple[Any, ...]] = ( idx_user_conversation, diff --git a/apps/models/mcp.py b/apps/models/mcp.py index d48c97b4d..0b3a43f16 100644 --- a/apps/models/mcp.py +++ b/apps/models/mcp.py @@ -40,8 +40,8 @@ class MCPInfo(Base): """MCP 概述""" description: Mapped[str] = mapped_column(Text, nullable=False) """MCP 描述""" - author: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.userSub"), nullable=False) - """MCP 创建者""" + authorId: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.id"), nullable=False) # noqa: N815 + """MCP 创建者ID""" id: Mapped[str] = mapped_column(String(255), primary_key=True) """MCP ID""" updatedAt: Mapped[datetime] = mapped_column( # noqa: N815 @@ -70,7 +70,7 @@ class MCPActivated(Base): String(255), ForeignKey("framework_mcp.id"), index=True, nullable=False, ) """MCP ID""" - userSub: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.userSub"), nullable=False) # noqa: N815 + userId: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.id"), nullable=False) # noqa: N815 """用户ID""" id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False) """主键ID""" diff --git a/apps/models/record.py b/apps/models/record.py index c99d8ba64..2e97404cf 100644 --- a/apps/models/record.py +++ b/apps/models/record.py @@ -16,8 +16,8 @@ class Record(Base): """问答对""" __tablename__ = "framework_record" - userSub: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.userSub"), nullable=False) # noqa: N815 - """用户名""" + userId: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.id"), nullable=False) # noqa: N815 + """用户ID""" conversationId: Mapped[uuid.UUID] = mapped_column( # noqa: N815 UUID(as_uuid=True), ForeignKey("framework_conversation.id"), nullable=False, ) diff --git a/apps/models/service.py b/apps/models/service.py index 714fb4638..60d1d711a 100644 --- a/apps/models/service.py +++ b/apps/models/service.py @@ -20,8 +20,8 @@ class Service(Base): """插件名称""" description: Mapped[str] = mapped_column(String(2000), nullable=False) """插件描述""" - author: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.userSub"), nullable=False) - """插件作者""" + authorId: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.id"), nullable=False) # noqa: N815 + """插件作者ID""" id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) """插件ID""" updatedAt: Mapped[datetime] = mapped_column( # noqa: N815 @@ -45,8 +45,8 @@ class ServiceACL(Base): __tablename__ = "framework_service_acl" serviceId: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("framework_service.id"), nullable=False) # noqa: N815 """关联的插件ID""" - userSub: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.userSub"), nullable=False) # noqa: N815 - """用户名""" + userId: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.id"), nullable=False) # noqa: N815 + """用户ID""" action: Mapped[str] = mapped_column(String(255), default="", nullable=False) """操作类型(读/写)""" id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False) diff --git a/apps/models/session.py b/apps/models/session.py index d25a3e33c..90bed160c 100644 --- a/apps/models/session.py +++ b/apps/models/session.py @@ -24,8 +24,8 @@ class Session(Base): """会话""" __tablename__ = "framework_session" - userSub: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.userSub"), nullable=False) # noqa: N815 - """用户名""" + userId: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.id"), nullable=False) # noqa: N815 + """用户ID""" ip: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) """IP地址""" pluginId: Mapped[str | None] = mapped_column(String(255), nullable=True, default=None) # noqa: N815 @@ -51,7 +51,7 @@ class SessionActivity(Base): __tablename__ = "framework_session_activity" id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False) """主键ID""" - userSub: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.userSub"), nullable=False) # noqa: N815 - """用户名""" + userId: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.id"), nullable=False) # noqa: N815 + """用户ID""" timestamp: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) """时间戳""" diff --git a/apps/models/settings.py b/apps/models/settings.py index fa9bb0ad1..8cf38ac51 100644 --- a/apps/models/settings.py +++ b/apps/models/settings.py @@ -36,6 +36,6 @@ class GlobalSettings(Base): """设置更新时间""" lastEditedBy: Mapped[str | None] = mapped_column( # noqa: N815 - String(50), ForeignKey("framework_user.userSub"), nullable=True, default=None, + String(50), ForeignKey("framework_user.id"), nullable=True, default=None, ) """最后一次修改的用户sub""" diff --git a/apps/models/task.py b/apps/models/task.py index 9c7a8c7ba..c799b243f 100644 --- a/apps/models/task.py +++ b/apps/models/task.py @@ -67,7 +67,7 @@ class Task(Base): """任务""" __tablename__ = "framework_task" - userSub: Mapped[str] = mapped_column(String(255), ForeignKey("framework_user.userSub")) # noqa: N815 + userId: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.id")) # noqa: N815 """用户ID""" conversationId: Mapped[uuid.UUID | None] = mapped_column( # noqa: N815 UUID(as_uuid=True), ForeignKey("framework_conversation.id"), nullable=True, diff --git a/apps/models/user.py b/apps/models/user.py index bd720c4f8..7d3555550 100644 --- a/apps/models/user.py +++ b/apps/models/user.py @@ -17,9 +17,9 @@ class User(Base): """用户表""" __tablename__ = "framework_user" - id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False) + id: Mapped[str] = mapped_column(String(50), primary_key=True, index=True, init=False) """用户ID""" - userSub: Mapped[str] = mapped_column(String(50), index=True, unique=True, nullable=False) # noqa: N815 + userName: Mapped[str] = mapped_column(String(50), index=True, unique=True, nullable=False) # noqa: N815 """用户名""" lastLogin: Mapped[datetime] = mapped_column( # noqa: N815 DateTime(timezone=True), default_factory=lambda: datetime.now(tz=UTC), nullable=False, @@ -52,8 +52,8 @@ class UserFavorite(Base): __tablename__ = "framework_user_favorite" id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False) """用户收藏ID""" - userSub: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.userSub"), index=True, nullable=False) # noqa: N815 - """用户名""" + userId: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.id"), index=True, nullable=False) # noqa: N815 + """用户ID""" favouriteType: Mapped[UserFavoriteType] = mapped_column(Enum(UserFavoriteType), nullable=False) # noqa: N815 """收藏类型""" itemId: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), index=True, nullable=False) # noqa: N815 @@ -66,8 +66,8 @@ class UserAppUsage(Base): __tablename__ = "framework_user_app_usage" id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False) """用户应用使用情况ID""" - userSub: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.userSub"), unique=True, nullable=False) # noqa: N815 - """用户名""" + userId: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.id"), unique=True, nullable=False) # noqa: N815 + """用户ID""" appId: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("framework_app.id"), nullable=False) # noqa: N815 """应用ID""" usageCount: Mapped[int] = mapped_column(Integer, default=0, nullable=False) # noqa: N815 @@ -84,8 +84,8 @@ class UserTag(Base): __tablename__ = "framework_user_tag" id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False) """用户标签ID""" - userSub: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.userSub"), unique=True, nullable=False) # noqa: N815 - """用户名""" + userId: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.id"), unique=True, nullable=False) # noqa: N815 + """用户ID""" tag: Mapped[int] = mapped_column(BigInteger, ForeignKey("framework_tag.id"), nullable=False) """标签ID""" count: Mapped[int] = mapped_column(Integer, default=0, nullable=False) diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 33d813cd0..213afc55c 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -29,16 +29,16 @@ router = APIRouter( ], ) -async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) -> AsyncGenerator[str, None]: +async def chat_generator(post_body: RequestData, user_id: str, session_id: str) -> AsyncGenerator[str, None]: """进行实际问答,并从MQ中获取消息""" try: - await Activity.set_active(user_sub) + await Activity.set_active(user_id) # 敏感词检查 if await words_check.check(post_body.question) != 1: yield "data: [SENSITIVE]\n\n" _logger.info("[Chat] 问题包含敏感词!") - await Activity.remove_active(user_sub) + await Activity.remove_active(user_id) return # 创建queue;由Scheduler进行关闭 @@ -47,8 +47,8 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) # 在单独Task中运行Scheduler,拉齐queue.get的时机 scheduler = Scheduler() - await scheduler.init(queue, post_body, user_sub) - _logger.info(f"[Chat] 用户是否活跃: {await Activity.is_active(user_sub)}") + await scheduler.init(queue, post_body, user_id) + _logger.info(f"[Chat] 用户是否活跃: {await Activity.is_active(user_id)}") scheduler_task = asyncio.create_task(scheduler.run()) # 处理每一条消息 @@ -65,14 +65,14 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) if task.state.executorStatus == ExecutorStatus.ERROR: _logger.error("[Chat] 生成答案失败") yield "data: [ERROR]\n\n" - await Activity.remove_active(user_sub) + await Activity.remove_active(user_id) return # 对结果进行敏感词检查 if await words_check.check(task.runtime.fullAnswer) != 1: yield "data: [SENSITIVE]\n\n" _logger.info("[Chat] 答案包含敏感词!") - await Activity.remove_active(user_sub) + await Activity.remove_active(user_id) return if post_body.app and post_body.app.flow_id: @@ -89,23 +89,23 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) yield "data: [ERROR]\n\n" finally: - await Activity.remove_active(user_sub) + await Activity.remove_active(user_id) @router.post("/chat") async def chat(request: Request, post_body: RequestData) -> StreamingResponse: """LLM流式对话接口""" session_id = request.state.session_id - user_sub = request.state.user_sub + user_id = request.state.user_id # 问题黑名单检测 if (post_body.question is not None) and ( not await QuestionBlacklistManager.check_blacklisted_questions(input_question=post_body.question) ): # 用户扣分 - await UserBlacklistManager.change_blacklisted_users(user_sub, -10) + await UserBlacklistManager.change_blacklisted_users(user_id, -10) raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="question is blacklisted") - res = chat_generator(post_body, user_sub, session_id) + res = chat_generator(post_body, user_id, session_id) return StreamingResponse( content=res, media_type="text/event-stream", @@ -118,7 +118,7 @@ async def chat(request: Request, post_body: RequestData) -> StreamingResponse: @router.post("/stop", response_model=ResponseData) async def stop_generation(request: Request) -> JSONResponse: """停止生成""" - await Activity.remove_active(request.state.user_sub) + await Activity.remove_active(request.state.user_id) return JSONResponse( status_code=status.HTTP_200_OK, diff --git a/apps/routers/comment.py b/apps/routers/comment.py index 276035ff7..18caec43b 100644 --- a/apps/routers/comment.py +++ b/apps/routers/comment.py @@ -34,7 +34,7 @@ async def add_comment(request: Request, post_body: AddCommentData) -> JSONRespon reason_description=post_body.reason_description, feedback_time=round(datetime.now(tz=UTC).timestamp(), 3), ) - result = await CommentManager.update_comment(post_body.record_id, comment_data, request.state.user_sub) + result = await CommentManager.update_comment(post_body.record_id, comment_data, request.state.user_id) if not result: return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content=ResponseData( code=status.HTTP_400_BAD_REQUEST, diff --git a/apps/routers/conversation.py b/apps/routers/conversation.py index 0707192b9..1076a0569 100644 --- a/apps/routers/conversation.py +++ b/apps/routers/conversation.py @@ -40,8 +40,8 @@ _logger = logging.getLogger(__name__) ) async def get_conversation_list(request: Request) -> JSONResponse: """GET /conversation: 获取对话列表""" - conversations = await ConversationManager.get_conversation_by_user_sub( - request.state.user_sub, + conversations = await ConversationManager.get_conversation_by_user( + request.state.user_id, ) # 把已有对话转换为列表 result_conversations = [] @@ -74,8 +74,8 @@ async def update_conversation( ) -> JSONResponse: """PUT /conversation: 更新特定Conversation的信息""" # 判断Conversation是否合法 - conv = await ConversationManager.get_conversation_by_conversation_id(request.state.user_sub, conversation_id) - if not conv or conv.userSub != request.state.user_sub: + conv = await ConversationManager.get_conversation_by_conversation_id(request.state.user_id, conversation_id) + if not conv or conv.userId != request.state.user_id: _logger.error("[Conversation] conversation_id 不存在") return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, @@ -89,7 +89,7 @@ async def update_conversation( # 更新Conversation数据 try: await ConversationManager.update_conversation_by_conversation_id( - request.state.user_sub, + request.state.user_id, conversation_id, { "title": post_body.title, @@ -133,12 +133,12 @@ async def delete_conversation( for conversation_id in post_body.conversation_list: # 删除对话 await ConversationManager.delete_conversation_by_conversation_id( - request.state.user_sub, + request.state.user_id, conversation_id, ) # 删除对话对应的文件 await DocumentManager.delete_document_by_conversation_id( - request.state.user_sub, + request.state.user_id, conversation_id, ) deleted_conversation.append(conversation_id) diff --git a/apps/routers/document.py b/apps/routers/document.py index ddf076953..f916dafca 100644 --- a/apps/routers/document.py +++ b/apps/routers/document.py @@ -39,7 +39,7 @@ async def document_upload( documents: list[UploadFile], ) -> JSONResponse: """POST /document/{conversation_id}: 上传文档到指定对话""" - result = await DocumentManager.storage_docs(request.state.user_sub, conversation_id, documents) + result = await DocumentManager.storage_docs(request.state.user_id, conversation_id, documents) await KnowledgeBaseService.send_file_to_rag(request.state.session_id, result) # 返回所有Framework已知的文档 @@ -71,7 +71,7 @@ async def get_document_list( ) -> JSONResponse: """GET /document/{conversation_id}: 获取特定对话的文档列表""" # 判断Conversation有权访问 - if not await ConversationManager.verify_conversation_access(request.state.user_sub, conversation_id): + if not await ConversationManager.verify_conversation_access(request.state.user_id, conversation_id): return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, content=ResponseData( @@ -143,7 +143,7 @@ async def delete_single_document( ) -> JSONResponse: """DELETE /document/{document_id}: 删除单个文件""" # 在Framework侧删除 - result = await DocumentManager.delete_document(request.state.user_sub, [document_id]) + result = await DocumentManager.delete_document(request.state.user_id, [document_id]) if not result: return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, diff --git a/apps/routers/flow.py b/apps/routers/flow.py index 799cd4e53..3e20c2925 100644 --- a/apps/routers/flow.py +++ b/apps/routers/flow.py @@ -40,7 +40,7 @@ router = APIRouter( ) async def get_services(request: Request) -> NodeServiceListRsp: """获取用户可访问的节点元数据所在服务的信息""" - services = await FlowManager.get_service_by_user_id(request.state.user_sub) + services = await FlowManager.get_service_by_user(request.state.user_id) if services is None: return NodeServiceListRsp( code=status.HTTP_404_NOT_FOUND, @@ -62,7 +62,7 @@ async def get_services(request: Request) -> NodeServiceListRsp: ) async def get_flow(request: Request, appId: uuid.UUID, flowId: str) -> JSONResponse: # noqa: N803 """获取流拓扑结构""" - if not await AppCenterManager.validate_user_app_access(request.state.user_sub, appId): + if not await AppCenterManager.validate_user_app_access(request.state.user_id, appId): return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, content=FlowStructureGetRsp( @@ -105,7 +105,7 @@ async def put_flow( put_body: Annotated[PutFlowReq, Body()], ) -> JSONResponse: """修改流拓扑结构""" - if not await AppCenterManager.validate_app_belong_to_user(request.state.user_sub, appId): + if not await AppCenterManager.validate_app_belong_to_user(request.state.user_id, appId): return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, content=FlowStructurePutRsp( @@ -131,7 +131,7 @@ async def put_flow( ) flow = await FlowManager.get_flow_by_app_and_flow_id(appId, flowId) - await AppCenterManager.update_app_publish_status(appId, request.state.user_sub) + await AppCenterManager.update_app_publish_status(appId, request.state.user_id) if flow is None: return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -157,7 +157,7 @@ async def put_flow( ) async def delete_flow(request: Request, appId: uuid.UUID, flowId: str) -> JSONResponse: # noqa: N803 """删除流拓扑结构""" - if not await AppCenterManager.validate_app_belong_to_user(request.state.user_sub, appId): + if not await AppCenterManager.validate_app_belong_to_user(request.state.user_id, appId): return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, content=FlowStructureDeleteRsp( diff --git a/apps/routers/parameter.py b/apps/routers/parameter.py index 3c46d2cd0..781faf546 100644 --- a/apps/routers/parameter.py +++ b/apps/routers/parameter.py @@ -28,7 +28,7 @@ async def get_parameters( request: Request, appId: uuid.UUID, flowId: str, stepId: uuid.UUID, # noqa: N803 ) -> JSONResponse: """Get parameters for node choice.""" - if not await AppCenterManager.validate_user_app_access(request.state.user_sub, appId): + if not await AppCenterManager.validate_user_app_access(request.state.user_id, appId): return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, content=GetParamsRsp( diff --git a/apps/routers/record.py b/apps/routers/record.py index a4cacae15..677c17631 100644 --- a/apps/routers/record.py +++ b/apps/routers/record.py @@ -45,7 +45,7 @@ router = APIRouter( async def get_record(request: Request, conversationId: Annotated[uuid.UUID, Path()]) -> JSONResponse: # noqa: N803 """GET /record/{conversationId}: 获取某个对话的所有问答对""" cur_conv = await ConversationManager.get_conversation_by_conversation_id( - request.state.user_sub, conversationId, + request.state.user_id, conversationId, ) # 判断conversation是否合法 if not cur_conv: @@ -82,7 +82,7 @@ async def get_record(request: Request, conversationId: Annotated[uuid.UUID, Path ) # 获得Record关联的文档 - tmp_record.document = await DocumentManager.get_used_docs_by_record(user_sub, record_group.id) + tmp_record.document = await DocumentManager.get_used_docs_by_record(user_id, record_group.id) # 获得Record关联的flow数据 flow_step_list = await TaskManager.get_context_by_record_id(record_group.id, record.id) if flow_step_list: diff --git a/apps/schemas/record.py b/apps/schemas/record.py index cc8f53957..b1f558034 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -16,7 +16,7 @@ class RecordDocument(BaseModel): id: str = Field(alias="_id", default="") order: int = Field(default=0, description="文档顺序") abstract: str = Field(default="", description="文档摘要") - user_sub: None = None + user_id: None = None author: str = Field(default="", description="文档作者") associated: Literal["question", "answer"] @@ -102,7 +102,7 @@ class FlowHistory(BaseModel): class Record(RecordData): """问答,用于保存在MongoDB中""" - user_sub: str + user_id: str key: dict[str, Any] = {} task_id: uuid.UUID | None = Field(default=None, description="任务ID") content: str = Field(default="", description="Record内容,已加密") diff --git a/apps/services/activity.py b/apps/services/activity.py index ca9166386..c83db3266 100644 --- a/apps/services/activity.py +++ b/apps/services/activity.py @@ -15,20 +15,19 @@ class Activity: """活动控制:全局并发限制,同时最多有 n 个任务在执行(与用户无关)""" @staticmethod - async def is_active(user_sub: str) -> bool: + async def is_active(user_id: str) -> bool: """ 判断系统是否达到全局并发上限 - :param user_sub: 用户实体ID(兼容现有接口签名) + :param user_id: 用户实体ID(兼容现有接口签名) :return: 达到并发上限返回 True,否则 False """ - _ = user_sub time = datetime.now(tz=UTC) async with postgres.session() as session: # 单用户滑动窗口限流:统计该用户在窗口内的请求数 count = (await session.scalars(select(func.count(SessionActivity.id)).where( - SessionActivity.userSub == user_sub, + SessionActivity.userId == user_id, SessionActivity.timestamp >= time - timedelta(seconds=SLIDE_WINDOW_TIME), SessionActivity.timestamp <= time, ))).one() @@ -40,7 +39,7 @@ class Activity: return current_active >= MAX_CONCURRENT_TASKS @staticmethod - async def set_active(user_sub: str) -> None: + async def set_active(user_id: str) -> None: """设置活跃标识:当未超过全局并发上限时登记一个活动任务""" time = datetime.now(UTC) async with postgres.session() as session: @@ -49,19 +48,19 @@ class Activity: if current_active >= MAX_CONCURRENT_TASKS: err = "系统并发已达上限" raise ActivityError(err) - await session.merge(SessionActivity(userSub=user_sub, timestamp=time)) + await session.merge(SessionActivity(userId=user_id, timestamp=time)) await session.commit() @staticmethod - async def remove_active(user_sub: str) -> None: + async def remove_active(user_id: str) -> None: """ 释放一个活动任务名额(按发起者标识清除对应记录) - :param user_sub: 用户实体ID + :param user_id: 用户实体ID """ async with postgres.session() as session: await session.execute( - delete(SessionActivity).where(SessionActivity.userSub == user_sub), + delete(SessionActivity).where(SessionActivity.userId == user_id), ) await session.commit() diff --git a/apps/services/comment.py b/apps/services/comment.py index 99c6f86cb..04b312ed5 100644 --- a/apps/services/comment.py +++ b/apps/services/comment.py @@ -41,7 +41,7 @@ class CommentManager: return None @staticmethod - async def update_comment(record_id: str, data: RecordComment, user_sub: str) -> None: + async def update_comment(record_id: str, data: RecordComment, user_id: str) -> None: """ 更新评论 @@ -62,7 +62,7 @@ class CommentManager: else: comment_info = Comment( recordId=uuid.UUID(record_id), - userSub=user_sub, + userId=user_id, commentType=data.comment, feedbackType=data.feedback_type, feedbackLink=data.feedback_link, diff --git a/apps/services/conversation.py b/apps/services/conversation.py index 403e9e92b..4f0f03577 100644 --- a/apps/services/conversation.py +++ b/apps/services/conversation.py @@ -20,13 +20,13 @@ class ConversationManager: """对话管理器""" @staticmethod - async def get_conversation_by_user_sub(user_sub: str) -> list[Conversation]: + async def get_conversation_by_user(user_id: str) -> list[Conversation]: """根据用户ID获取对话列表,按时间由近到远排序""" async with postgres.session() as session: result = (await session.scalars( select(Conversation).where( and_( - Conversation.userSub == user_sub, + Conversation.userId == user_id, Conversation.isTemporary == False, # noqa: E712 ), ).order_by( @@ -37,28 +37,28 @@ class ConversationManager: @staticmethod - async def get_conversation_by_conversation_id(user_sub: str, conversation_id: uuid.UUID) -> Conversation | None: + async def get_conversation_by_conversation_id(user_id: str, conversation_id: uuid.UUID) -> Conversation | None: """通过ConversationID查询对话信息""" async with postgres.session() as session: return (await session.scalars( select(Conversation).where( and_( Conversation.id == conversation_id, - Conversation.userSub == user_sub, + Conversation.userId == user_id, ), ), )).one_or_none() @staticmethod - async def verify_conversation_access(user_sub: str, conversation_id: uuid.UUID) -> bool: + async def verify_conversation_access(user_id: str, conversation_id: uuid.UUID) -> bool: """验证对话是否属于用户""" async with postgres.session() as session: result = (await session.scalars( func.count(Conversation.id).where( and_( Conversation.id == conversation_id, - Conversation.userSub == user_sub, + Conversation.userId == user_id, ), ), )).one() @@ -66,14 +66,14 @@ class ConversationManager: @staticmethod - async def add_conversation_by_user_sub( - title: str, user_sub: str, app_id: uuid.UUID | None = None, + async def add_conversation_by_user( + title: str, user_id: str, app_id: uuid.UUID | None = None, *, debug: bool = False, ) -> Conversation | None: """通过用户ID新建对话""" conv = Conversation( - userSub=user_sub, + userId=user_id, appId=app_id, isTemporary=debug, title=title, @@ -90,7 +90,7 @@ class ConversationManager: app_obj = (await session.scalars( select(UserAppUsage).where( and_( - UserAppUsage.userSub == user_sub, + UserAppUsage.userId == user_id, UserAppUsage.appId == app_id, ), ), @@ -104,7 +104,7 @@ class ConversationManager: await session.commit() else: await session.merge(UserAppUsage( - userSub=user_sub, + userId=user_id, appId=app_id, usageCount=1, lastUsed=datetime.now(tz=UTC), @@ -118,7 +118,7 @@ class ConversationManager: @staticmethod async def update_conversation_by_conversation_id( - user_sub: str, conversation_id: uuid.UUID, data: dict[str, Any], + user_id: str, conversation_id: uuid.UUID, data: dict[str, Any], ) -> bool: """通过ConversationID更新对话信息""" async with postgres.session() as session: @@ -126,7 +126,7 @@ class ConversationManager: select(Conversation).where( and_( Conversation.id == conversation_id, - Conversation.userSub == user_sub, + Conversation.userId == user_id, ), ), )).one_or_none() @@ -140,14 +140,14 @@ class ConversationManager: @staticmethod - async def delete_conversation_by_conversation_id(user_sub: str, conversation_id: uuid.UUID) -> None: + async def delete_conversation_by_conversation_id(user_id: str, conversation_id: uuid.UUID) -> None: """通过ConversationID删除对话""" async with postgres.session() as session: conv = (await session.scalars( select(Conversation).where( and_( Conversation.id == conversation_id, - Conversation.userSub == user_sub, + Conversation.userId == user_id, ), ), )).one_or_none() @@ -161,11 +161,11 @@ class ConversationManager: @staticmethod - async def delete_conversation_by_user_sub(user_sub: str) -> None: + async def delete_conversation_by_user(user_id: str) -> None: """通过用户ID删除对话""" async with postgres.session() as session: convs = list((await session.scalars( - select(Conversation).where(Conversation.userSub == user_sub), + select(Conversation).where(Conversation.userId == user_id), )).all()) for conv in convs: await session.delete(conv) @@ -174,14 +174,14 @@ class ConversationManager: @staticmethod - async def verify_conversation_id(user_sub: str, conversation_id: uuid.UUID) -> bool: + async def verify_conversation_id(user_id: str, conversation_id: uuid.UUID) -> bool: """验证对话ID是否属于用户""" async with postgres.session() as session: result = (await session.scalars( func.count(Conversation.id).where( and_( Conversation.id == conversation_id, - Conversation.userSub == user_sub, + Conversation.userId == user_id, ), ), )).one() diff --git a/apps/services/flow.py b/apps/services/flow.py index 9a9435d72..a358f86bf 100644 --- a/apps/services/flow.py +++ b/apps/services/flow.py @@ -79,11 +79,11 @@ class FlowManager: @staticmethod - async def get_service_by_user_id(user_sub: str) -> list[NodeServiceItem] | None: + async def get_service_by_user(user_id: str) -> list[NodeServiceItem] | None: """ 通过user_id获取用户自己上传的或收藏的Service,用于插件中心展示 - :user_sub: 用户的唯一标识符 + :param user_id: str: 用户的唯一标识符 :return: service的列表 """ async with postgres.session() as session: @@ -92,7 +92,7 @@ class FlowManager: user_favs = list((await session.scalars( select(UserFavorite.itemId).where( and_( - UserFavorite.userSub == user_sub, + UserFavorite.userId == user_id, UserFavorite.favouriteType == UserFavoriteType.SERVICE, ), ), @@ -101,7 +101,7 @@ class FlowManager: # 用户所上传的Service user_services = list((await session.scalars( select(Service.id).where( - Service.author == user_sub, + Service.authorId == user_id, ), )).all()) service_ids = service_ids + user_services diff --git a/apps/services/llm.py b/apps/services/llm.py index 416544190..6eb3b6532 100644 --- a/apps/services/llm.py +++ b/apps/services/llm.py @@ -24,7 +24,6 @@ class LLMManager: """ 通过ID获取大模型 - :param user_sub: 用户ID :param llm_id: 大模型ID :return: 大模型对象 """ diff --git a/apps/services/mcp_service.py b/apps/services/mcp_service.py index 235492f9f..43f3653ab 100644 --- a/apps/services/mcp_service.py +++ b/apps/services/mcp_service.py @@ -1,5 +1,5 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. -"""MCP服务(插件中心)管理器""" +"""MCP服务(插件中心)管理""" import logging import re @@ -42,22 +42,22 @@ MCP_ICON_PATH = ICON_PATH / "mcp" class MCPServiceManager: - """MCP服务管理器""" + """MCP服务管理""" @staticmethod - async def is_active(user_sub: str, mcp_id: str) -> bool: + async def is_active(user_id: str, mcp_id: str) -> bool: """ 判断用户是否激活MCP - :param str user_sub: 用户ID - :param str mcp_id: MCP服务ID + :param user_id: str: 用户ID + :param mcp_id: str: MCP服务ID :return: 是否激活 """ async with postgres.session() as session: mcp_info = (await session.scalars(select(MCPActivated).where( and_( MCPActivated.mcpId == mcp_id, - MCPActivated.userSub == user_sub, + MCPActivated.userId == user_id, ), ))).one_or_none() return bool(mcp_info) @@ -94,7 +94,7 @@ class MCPServiceManager: @staticmethod async def fetch_mcp_services( # noqa: PLR0913 search_type: SearchType, - user_sub: str, + user_id: str, keyword: str | None, page: int, *, @@ -105,13 +105,15 @@ class MCPServiceManager: 获取所有MCP服务列表 :param search_type: SearchType: str: MCP搜索类型 - :param user_sub: str: 用户ID + :param user_id: str: 用户ID :param keyword: str: MCP搜索关键字 :param page: int: 页码 + :param is_install: bool | None: 是否已安装 + :param is_active: bool | None: 是否已激活 :return: MCP服务列表 """ mcpservice_pools = await MCPServiceManager._search_mcpservice( - search_type, keyword, page, user_sub, is_active=is_active, is_installed=is_install, + search_type, keyword, page, user_id, is_active=is_active, is_installed=is_install, ) return [ MCPServiceCardItem( @@ -119,8 +121,8 @@ class MCPServiceManager: icon=await MCPServiceManager.get_icon_path(item.id), name=item.name, description=item.description, - author=item.author, - isActive=await MCPServiceManager.is_active(user_sub, item.id), + author=item.authorId, + isActive=await MCPServiceManager.is_active(user_id, item.id), status=await MCPServiceManager.get_service_status(item.id), ) for item in mcpservice_pools @@ -155,7 +157,7 @@ class MCPServiceManager: @staticmethod async def get_mcp_tools(mcp_id: str) -> list[MCPTools]: """ - 获取MCP可以用工具 + 获取MCP可用工具 :param mcp_id: str: MCP服务ID :return: MCP工具详细信息列表 @@ -169,7 +171,7 @@ class MCPServiceManager: search_type: SearchType, keyword: str | None, page: int, - user_sub: str, + user_id: str, *, is_active: bool | None = None, is_installed: bool | None = None, @@ -191,7 +193,7 @@ class MCPServiceManager: or_( MCPInfo.name.like(f"%{keyword}%"), MCPInfo.description.like(f"%{keyword}%"), - MCPInfo.author.like(f"%{keyword}%"), + MCPInfo.authorId.like(f"%{keyword}%"), ), ) elif search_type == SearchType.NAME: @@ -199,17 +201,17 @@ class MCPServiceManager: elif search_type == SearchType.DESCRIPTION: sql = sql.where(MCPInfo.description.like(f"%{keyword}%")) elif search_type == SearchType.AUTHOR: - sql = sql.where(MCPInfo.author.like(f"%{keyword}%")) + sql = sql.where(MCPInfo.authorId.like(f"%{keyword}%")) sql = sql.offset(skip).limit(SERVICE_PAGE_SIZE) if is_installed is not None: sql = sql.where(MCPInfo.id.in_( - select(MCPActivated.mcpId).where(MCPActivated.userSub == user_sub), + select(MCPActivated.mcpId).where(MCPActivated.userId == user_id), )) if is_active is not None: sql = sql.where(MCPInfo.id.in_( - select(MCPActivated.mcpId).where(MCPActivated.userSub == user_sub), + select(MCPActivated.mcpId).where(MCPActivated.userId == user_id), )) result = list((await session.scalars(sql)).all()) @@ -218,12 +220,12 @@ class MCPServiceManager: if not result: logger.warning("[MCPServiceManager] 没有找到符合条件的MCP服务: %s", search_type) return [] - # 将数据库中的MCP服务转换为对象 + # 将数据库中的MCP服务转换为对应的对象列表 return result @staticmethod - async def create_mcpservice(data: UpdateMCPServiceRequest, user_sub: str) -> str: + async def create_mcpservice(data: UpdateMCPServiceRequest, user_id: str) -> str: """ 创建MCP服务 @@ -245,7 +247,7 @@ class MCPServiceManager: data.mcp_id: config, }, mcpType=data.mcp_type, - author=user_sub, + author=user_id, ) # 检查是否存在相同服务 @@ -255,8 +257,7 @@ class MCPServiceManager: mcp_server.name = f"{mcp_server.name}-{uuid.uuid4().hex[:6]}" logger.warning("[MCPServiceManager] 已存在相同ID或名称的MCP服务") - # 保存并载入配置 - logger.info("[MCPServiceManager] 创建mcp:%s", mcp_server.name) + # 保存并载入配�? logger.info("[MCPServiceManager] 创建mcp�?s", mcp_server.name) mcp_path = MCP_PATH / "template" / data.mcp_id / "project" index = None if isinstance(config, MCPServerStdioConfig): @@ -280,7 +281,7 @@ class MCPServiceManager: overview=mcp_server.overview, description=mcp_server.description, mcpType=mcp_server.mcpType, - author=mcp_server.author or "", + authorId=mcp_server.author or "", )) await session.commit() await MCPLoader.save_one(data.mcp_id, mcp_server) @@ -289,7 +290,7 @@ class MCPServiceManager: @staticmethod - async def update_mcpservice(data: UpdateMCPServiceRequest, user_sub: str) -> str: + async def update_mcpservice(data: UpdateMCPServiceRequest, user_id: str) -> str: """ 更新MCP服务 @@ -304,7 +305,7 @@ class MCPServiceManager: db_service = (await session.scalars(select(MCPInfo).where( and_( MCPInfo.id == data.mcp_id, - MCPInfo.author == user_sub, + MCPInfo.authorId == user_id, ), ))).one_or_none() if not db_service: @@ -319,7 +320,7 @@ class MCPServiceManager: # 为每个激活的用户取消激活 for activated_user in activated_users: - await MCPServiceManager.deactive_mcpservice(user_sub=activated_user.userSub, mcp_id=data.mcp_id) + await MCPServiceManager.deactive_mcpservice(user_id=activated_user.userId, mcp_id=data.mcp_id) mcp_config = MCPServerConfig( name=data.name, @@ -329,7 +330,7 @@ class MCPServiceManager: data.mcp_id: data.config[data.mcp_id], }, mcpType=data.mcp_type, - author=user_sub, + author=user_id, ) async with postgres.session() as session: await session.merge(MCPInfo( @@ -338,7 +339,7 @@ class MCPServiceManager: overview=mcp_config.overview, description=mcp_config.description, mcpType=mcp_config.mcpType, - author=mcp_config.author or "", + authorId=mcp_config.author or "", )) await session.commit() await MCPLoader.save_one(mcp_id=data.mcp_id, config=mcp_config) @@ -366,15 +367,16 @@ class MCPServiceManager: @staticmethod async def active_mcpservice( - user_sub: str, + user_id: str, mcp_id: str, mcp_env: dict[str, Any] | None = None, ) -> None: """ 激活MCP服务 - :param user_sub: str: 用户ID + :param user_id: str: 用户ID :param mcp_id: str: MCP服务ID + :param mcp_env: dict[str, Any] | None: MCP服务环境变量 :return: 无 """ if mcp_env is None: @@ -390,29 +392,29 @@ class MCPServiceManager: raise RuntimeError(err) await session.merge(MCPActivated( mcpId=mcp_info.id, - userSub=user_sub, + userId=user_id, )) await session.commit() - await MCPLoader.user_active_template(user_sub, mcp_id, mcp_env) + await MCPLoader.user_active_template(user_id, mcp_id, mcp_env) @staticmethod async def deactive_mcpservice( - user_sub: str, + user_id: str, mcp_id: str, ) -> None: """ 取消激活MCP服务 - :param user_sub: str: 用户ID + :param user_id: str: 用户ID :param mcp_id: str: MCP服务ID :return: 无 """ try: - await mcp_pool.stop(mcp_id=mcp_id, user_sub=user_sub) + await mcp_pool.stop(mcp_id=mcp_id, user_id=user_id) except KeyError: logger.warning("[MCPServiceManager] MCP服务无进程") - await MCPLoader.user_deactive_template(user_sub, mcp_id) + await MCPLoader.user_deactive_template(user_id, mcp_id) @staticmethod @@ -455,11 +457,11 @@ class MCPServiceManager: @staticmethod - async def is_user_actived(user_sub: str, mcp_id: str) -> bool: + async def is_user_actived(user_id: str, mcp_id: str) -> bool: """ 判断用户是否激活MCP - :param user_sub: str: 用户ID + :param user_id: str: 用户ID :param mcp_id: str: MCP服务ID :return: 是否激活 """ @@ -467,7 +469,7 @@ class MCPServiceManager: mcp_info = (await session.scalars(select(MCPActivated).where( and_( MCPActivated.mcpId == mcp_id, - MCPActivated.userSub == user_sub, + MCPActivated.userId == user_id, ), ))).one_or_none() return bool(mcp_info) @@ -486,11 +488,11 @@ class MCPServiceManager: @staticmethod - async def install_mcpservice(user_sub: str, service_id: str, *, install: bool) -> None: + async def install_mcpservice(user_id: str, service_id: str, *, install: bool) -> None: """ 安装或卸载MCP服务 - :param user_sub: str: 用户ID + :param user_id: str: 用户ID :param service_id: str: MCP服务ID :param install: bool: 是否安装 :return: 无 @@ -499,7 +501,7 @@ class MCPServiceManager: db_service = (await session.scalars(select(MCPInfo).where( and_( MCPInfo.id == service_id, - MCPInfo.author == user_sub, + MCPInfo.authorId == user_id, ), ))).one_or_none() diff --git a/apps/services/personal_token.py b/apps/services/personal_token.py index d72be9854..c56c66576 100644 --- a/apps/services/personal_token.py +++ b/apps/services/personal_token.py @@ -18,7 +18,7 @@ class PersonalTokenManager: @staticmethod async def get_user_by_personal_token(personal_token: str) -> str | None: """ - 根据Personal Token获取用户信息 + 根据Personal Token获取用户ID :param personal_token: Personal Token :return: 用户ID @@ -27,7 +27,7 @@ class PersonalTokenManager: try: result = ( await session.scalars( - select(User.userSub).where(User.personalToken == personal_token), + select(User.id).where(User.personalToken == personal_token), ) ).one_or_none() except Exception: @@ -38,11 +38,11 @@ class PersonalTokenManager: @staticmethod - async def update_personal_token(user_sub: str) -> str | None: + async def update_personal_token(user_id: str) -> str | None: """ 更新Personal Token - :param user_sub: 用户ID + :param user_id: 用户ID :return: 更新后的Personal Token """ personal_token = uuid.uuid4().hex @@ -50,7 +50,7 @@ class PersonalTokenManager: try: async with postgres.session() as session: await session.execute( - update(User).where(User.id == user_sub).values(personal_token=personal_token_with_prefix), + update(User).where(User.id == user_id).values(personal_token=personal_token_with_prefix), ) await session.commit() except Exception: diff --git a/apps/services/record.py b/apps/services/record.py index dc0f6ddd7..d90160471 100644 --- a/apps/services/record.py +++ b/apps/services/record.py @@ -20,12 +20,12 @@ class RecordManager: """问答对相关操作""" @staticmethod - async def verify_record_in_conversation(record_id: uuid.UUID, user_sub: str, conversation_id: uuid.UUID) -> bool: + async def verify_record_in_conversation(record_id: uuid.UUID, user_id: str, conversation_id: uuid.UUID) -> bool: """ 校验指定record_id是否属于指定用户和会话(PostgreSQL实现) :param record_id: 记录ID - :param user_sub: 用户sub + :param user_id: 用户ID :param conversation_id: 会话ID :return: 是否存在 """ @@ -34,7 +34,7 @@ class RecordManager: select(PgRecord).where( and_( PgRecord.id == record_id, - PgRecord.userSub == user_sub, + PgRecord.userId == user_id, PgRecord.conversationId == conversation_id, ), ))).one_or_none() @@ -42,14 +42,14 @@ class RecordManager: @staticmethod - async def insert_record_data(user_sub: str, conversation_id: uuid.UUID, record: Record) -> uuid.UUID | None: + async def insert_record_data(user_id: str, conversation_id: uuid.UUID, record: Record) -> uuid.UUID | None: """Record插入PostgreSQL""" async with postgres.session() as session: conv = (await session.scalars( select(Conversation).where( and_( Conversation.id == conversation_id, - Conversation.userSub == user_sub, + Conversation.userId == user_id, ), ), )).one_or_none() @@ -61,7 +61,7 @@ class RecordManager: id=record.id, conversationId=conversation_id, taskId=record.task_id, - userSub=user_sub, + userId=user_id, content=record.content, key=record.key, createdAt=datetime.fromtimestamp(record.created_at, tz=UTC), @@ -73,7 +73,7 @@ class RecordManager: @staticmethod async def query_record_by_conversation_id( - user_sub: str, + user_id: str, conversation_id: uuid.UUID, total_pairs: int | None = None, order: Literal["desc", "asc"] = "desc", @@ -83,7 +83,7 @@ class RecordManager: sql = select(PgRecord).where( and_( PgRecord.conversationId == conversation_id, - PgRecord.userSub == user_sub, + PgRecord.userId == user_id, ), ).order_by(PgRecord.createdAt.desc() if order == "desc" else PgRecord.createdAt.asc()) if total_pairs is not None: @@ -97,7 +97,7 @@ class RecordManager: id=pg_record.id, conversationId=pg_record.conversationId, task_id=pg_record.taskId, - user_sub=pg_record.userSub, + user_id=pg_record.userId, content=pg_record.content, key=pg_record.key, createdAt=pg_record.createdAt.timestamp(), diff --git a/apps/services/service.py b/apps/services/service.py index 22751d501..3db7ccff3 100644 --- a/apps/services/service.py +++ b/apps/services/service.py @@ -63,11 +63,11 @@ class ServiceCenterManager: SearchType.ALL: or_( Service.name.like(f"%{keyword}%"), Service.description.like(f"%{keyword}%"), - Service.author.like(f"%{keyword}%"), + Service.authorId.like(f"%{keyword}%"), ), SearchType.NAME: Service.name.like(f"%{keyword}%"), SearchType.DESCRIPTION: Service.description.like(f"%{keyword}%"), - SearchType.AUTHOR: Service.author.like(f"%{keyword}%"), + SearchType.AUTHOR: Service.authorId.like(f"%{keyword}%"), } search_conditions = [search_condition_map.get(search_type)] @@ -85,7 +85,7 @@ class ServiceCenterManager: @staticmethod async def fetch_all_services( - user_sub: str, + user_id: str, search_type: SearchType, keyword: str | None, page: int, @@ -99,7 +99,7 @@ class ServiceCenterManager: service_pools = list((await session.scalars(data_query)).all()) total_count = (await session.scalar(count_query)) or 0 - fav_service_ids = await ServiceCenterManager._get_favorite_service_ids_by_user(user_sub) + fav_service_ids = await ServiceCenterManager._get_favorite_service_ids_by_user(user_id) services = [ ServiceCardItem( serviceId=service_pool.id, @@ -116,7 +116,7 @@ class ServiceCenterManager: @staticmethod async def fetch_user_services( - user_sub: str, + user_id: str, search_type: SearchType, keyword: str | None, page: int, @@ -124,18 +124,18 @@ class ServiceCenterManager: ) -> tuple[list[ServiceCardItem], int]: """获取用户创建的服务""" if search_type == SearchType.AUTHOR: - if keyword is not None and keyword not in user_sub: + if keyword is not None and user_id not in keyword: return [], 0 - keyword = user_sub + keyword = user_id async with postgres.session() as session: data_query, count_query = await ServiceCenterManager._build_service_query( - search_type, keyword, page, page_size, base_conditions=[Service.author == user_sub], + search_type, keyword, page, page_size, base_conditions=[Service.authorId == user_id], ) service_pools = list((await session.scalars(data_query)).all()) total_count = (await session.scalar(count_query)) or 0 - fav_service_ids = await ServiceCenterManager._get_favorite_service_ids_by_user(user_sub) + fav_service_ids = await ServiceCenterManager._get_favorite_service_ids_by_user(user_id) services = [ ServiceCardItem( serviceId=service_pool.id, @@ -152,7 +152,7 @@ class ServiceCenterManager: @staticmethod async def fetch_favorite_services( - user_sub: str, + user_id: str, search_type: SearchType, keyword: str | None, page: int, @@ -165,7 +165,7 @@ class ServiceCenterManager: select(UserFavorite.itemId) .where( and_( - UserFavorite.userSub == user_sub, + UserFavorite.userId == user_id, UserFavorite.favouriteType == UserFavoriteType.SERVICE, ), ) @@ -197,7 +197,7 @@ class ServiceCenterManager: @staticmethod async def create_service( - user_sub: str, + user_id: str, data: dict[str, Any], ) -> uuid.UUID: """创建服务""" @@ -225,7 +225,7 @@ class ServiceCenterManager: id=service_id, name=validated_data.id, description=validated_data.description, - author=user_sub, + author=user_id, api=ServiceApiConfig(server=validated_data.servers), permission=Permission(type=PermissionType.PUBLIC), # 默认公开 ) @@ -237,7 +237,7 @@ class ServiceCenterManager: @staticmethod async def update_service( - user_sub: str, + user_id: str, service_id: uuid.UUID, data: dict[str, Any], ) -> uuid.UUID: @@ -252,7 +252,7 @@ class ServiceCenterManager: if not db_service: msg = "[ServiceCenterManager] Service not found" raise RuntimeError(msg) - if db_service.author != user_sub: + if db_service.authorId != user_id: msg = "[ServiceCenterManager] Permission denied" raise RuntimeError(msg) db_service.updatedAt = datetime.now(tz=UTC) @@ -263,7 +263,7 @@ class ServiceCenterManager: id=service_id, name=validated_data.id, description=validated_data.description, - author=user_sub, + author=user_id, api=ServiceApiConfig(server=validated_data.servers), ) service_loader = ServiceLoader() @@ -309,7 +309,7 @@ class ServiceCenterManager: @staticmethod async def get_service_data( - user_sub: str, + user_id: str, service_id: uuid.UUID, ) -> tuple[str, dict[str, Any]]: """获取服务数据""" @@ -319,7 +319,7 @@ class ServiceCenterManager: select(Service).where( and_( Service.id == service_id, - Service.author == user_sub, + Service.authorId == user_id, ), ), )).one_or_none() @@ -338,17 +338,17 @@ class ServiceCenterManager: @staticmethod async def get_service_metadata( - user_sub: str, + user_id: str, service_id: uuid.UUID, ) -> ServiceMetadata: """获取服务元数据""" async with postgres.session() as session: allowed_user = list((await session.scalars( - select(ServiceACL.userSub).where( + select(ServiceACL.userId).where( ServiceACL.serviceId == service_id, ), )).all()) - if user_sub in allowed_user: + if user_id in allowed_user: db_service = (await session.scalars( select(Service).where( and_( @@ -364,11 +364,11 @@ class ServiceCenterManager: Service.id == service_id, or_( and_( - Service.author == user_sub, + Service.authorId == user_id, Service.permission == PermissionType.PRIVATE, ), Service.permission == PermissionType.PUBLIC, - Service.author == user_sub, + Service.authorId == user_id, ), ), ), @@ -386,14 +386,14 @@ class ServiceCenterManager: @staticmethod - async def delete_service(user_sub: str, service_id: uuid.UUID) -> None: + async def delete_service(user_id: str, service_id: uuid.UUID) -> None: """删除服务""" async with postgres.session() as session: db_service = (await session.scalars( select(Service).where( and_( Service.id == service_id, - Service.author == user_sub, + Service.authorId == user_id, ), ), )).one_or_none() @@ -422,7 +422,7 @@ class ServiceCenterManager: @staticmethod async def modify_favorite_service( - user_sub: str, + user_id: str, service_id: uuid.UUID, *, favorited: bool, @@ -440,12 +440,12 @@ class ServiceCenterManager: raise RuntimeError(msg) user = (await session.scalars( - select(func.count(User.userSub)).where( - User.userSub == user_sub, + select(func.count(User.id)).where( + User.id == user_id, ), )).one() if not user: - msg = f"[ServiceCenterManager] 用户未找到: {user_sub}" + msg = f"[ServiceCenterManager] 用户未找到: {user_id}" logger.warning(msg) raise RuntimeError(msg) @@ -454,7 +454,7 @@ class ServiceCenterManager: select(UserFavorite).where( and_( UserFavorite.itemId == service_id, - UserFavorite.userSub == user_sub, + UserFavorite.userId == user_id, UserFavorite.favouriteType == UserFavoriteType.SERVICE, ), ), @@ -463,7 +463,7 @@ class ServiceCenterManager: # 创建收藏条目 user_favourite = UserFavorite( itemId=service_id, - userSub=user_sub, + userId=user_id, favouriteType=UserFavoriteType.SERVICE, ) session.add(user_favourite) @@ -475,12 +475,12 @@ class ServiceCenterManager: @staticmethod - async def _get_favorite_service_ids_by_user(user_sub: str) -> list[uuid.UUID]: + async def _get_favorite_service_ids_by_user(user_id: str) -> list[uuid.UUID]: """获取用户收藏的服务ID""" async with postgres.session() as session: user_favourite = (await session.scalars( select(UserFavorite).where( - UserFavorite.userSub == user_sub, + UserFavorite.userId == user_id, UserFavorite.favouriteType == UserFavoriteType.SERVICE, ), )).all() diff --git a/apps/services/session.py b/apps/services/session.py index 94d377f94..1bd99452c 100644 --- a/apps/services/session.py +++ b/apps/services/session.py @@ -19,18 +19,18 @@ class SessionManager: """浏览器Session管理""" @staticmethod - async def create_session(user_sub: str, ip: str) -> str: + async def create_session(user_id: str, ip: str) -> str: """创建浏览器Session""" if not ip: err = "用户IP错误!" raise ValueError(err) - if not user_sub: - err = "用户名错误!" + if not user_id: + err = "用户ID错误!" raise ValueError(err) data = Session( - userSub=user_sub, + userId=user_id, ip=ip, validUntil=datetime.now(UTC) + timedelta(minutes=SESSION_TTL), sessionType=SessionType.CODE, @@ -57,26 +57,26 @@ class SessionManager: @staticmethod async def get_user(session_id: str) -> str | None: - """从Session中获取用户""" + """从Session中获取用户ID""" async with postgres.session() as session: - user_sub = ( - await session.scalars(select(Session.userSub).where(Session.id == session_id)) + user_id = ( + await session.scalars(select(Session.userId).where(Session.id == session_id)) ).one_or_none() - if not user_sub: + if not user_id: return None # 查询黑名单 - if user_sub and await UserBlacklistManager.check_blacklisted_users(user_sub): + if user_id and await UserBlacklistManager.check_blacklisted_users(user_id=user_id): logger.error("[SessionManager] 用户在Session黑名单中") await SessionManager.delete_session(session_id) return None - return user_sub + return user_id @staticmethod - async def get_session_by_user_sub(user_sub: str) -> str | None: - """根据用户sub获取Session""" + async def get_session_by_user(user_id: str) -> str | None: + """根据用户ID获取Session""" async with postgres.session() as session: - data = await session.scalars(select(Session.id).where(Session.userSub == user_sub)) + data = await session.scalars(select(Session.id).where(Session.userId == user_id)) return data.one_or_none() diff --git a/apps/services/settings.py b/apps/services/settings.py index 87ddad229..ae856b0d5 100644 --- a/apps/services/settings.py +++ b/apps/services/settings.py @@ -69,13 +69,13 @@ class SettingsManager: @staticmethod async def update_global_llm_settings( - user_sub: str, + user_id: str, req: UpdateSpecialLlmReq, ) -> None: """ 更新全局默认LLM(仅管理员) - :param user_sub: 操作的管理员user_sub + :param user_id: 操作的管理员user_id :param req: 更新请求体 """ # 验证functionLLM是否支持Function Call @@ -116,7 +116,7 @@ class SettingsManager: # 更新全局设置 settings.functionLlmId = req.functionLLM settings.embeddingLlmId = req.embeddingLLM - settings.lastEditedBy = user_sub + settings.lastEditedBy = user_id await session.commit() # 如果embedding模型发生变化,在新协程中触发向量化过程 diff --git a/apps/services/tag.py b/apps/services/tag.py index 57db184bc..77f47f72c 100644 --- a/apps/services/tag.py +++ b/apps/services/tag.py @@ -41,16 +41,16 @@ class TagManager: @staticmethod - async def get_tag_by_user_sub(user_sub: str) -> list[Tag]: + async def get_tag_by_user(user_id: str) -> list[Tag]: """ - 根据用户sub获取用户标签信息 + 根据用户ID获取用户标签信息 - :param user_sub: 用户sub + :param user_id: 用户ID :return: 用户标签信息 """ tags = [] async with postgres.session() as session: - user_tags = (await session.scalars(select(UserTag).where(UserTag.userSub == user_sub))).all() + user_tags = (await session.scalars(select(UserTag).where(UserTag.userId == user_id))).all() for user_tag in user_tags: tag = (await session.scalars(select(Tag).where(Tag.id == user_tag.tag))).one_or_none() if tag: diff --git a/apps/services/task.py b/apps/services/task.py index 6f9560511..8034ea66e 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -23,15 +23,15 @@ class TaskManager: """从数据库中获取任务信息""" @staticmethod - async def get_task_by_conversation_id(conversation_id: uuid.UUID, user_sub: str) -> Task: + async def get_task_by_conversation_id(conversation_id: uuid.UUID, user_id: str) -> Task: """获取对话ID的最后一个任务""" async with postgres.session() as session: - # 检查user_sub是否匹配Conversation + # 检查user_id是否匹配Conversation conversation = (await session.scalars( select(Conversation.id).where( and_( Conversation.id == conversation_id, - Conversation.userSub == user_sub, + Conversation.userId == user_id, ), ), )).one_or_none() @@ -46,7 +46,7 @@ class TaskManager: # 任务不存在,新建Task return Task( conversationId=conversation_id, - userSub=user_sub, + userId=user_id, ) logger.info("[TaskManager] 新建任务 %s", task.id) return task diff --git a/apps/services/token.py b/apps/services/token.py index 221d4acbf..a5d0eadbc 100644 --- a/apps/services/token.py +++ b/apps/services/token.py @@ -32,8 +32,8 @@ class TokenManager: expire_time: int, ) -> str: """获取插件Token""" - user_sub = await SessionManager.get_user(session_id=session_id) - if not user_sub: + user_id = await SessionManager.get_user(session_id=session_id) + if not user_id: err = "用户不存在!" raise ValueError(err) @@ -43,7 +43,7 @@ class TokenManager: select(Session) .where( and_( - Session.userSub == user_sub, + Session.userId == user_id, Session.sessionType == SessionType.ACCESS_TOKEN, Session.pluginId == str(plugin_id), ), @@ -57,7 +57,7 @@ class TokenManager: token = await TokenManager.generate_plugin_token( plugin_id, session_id, - user_sub, + user_id, access_token_url, expire_time, ) @@ -66,7 +66,7 @@ class TokenManager: raise RuntimeError(err) await session.merge(Session( - userSub=user_sub, + userId=user_id, sessionType=SessionType.PLUGIN_TOKEN, pluginId=str(plugin_id), token=token, @@ -89,7 +89,7 @@ class TokenManager: async def generate_plugin_token( plugin_name: uuid.UUID, session_id: str, - user_sub: str, + user_id: str, access_token_url: str, expire_time: int, ) -> str | None: @@ -100,7 +100,7 @@ class TokenManager: select(Session) .where( and_( - Session.userSub == user_sub, + Session.userId == user_id, Session.sessionType == SessionType.ACCESS_TOKEN, ), ), @@ -117,7 +117,7 @@ class TokenManager: select(Session) .where( and_( - Session.userSub == user_sub, + Session.userId == user_id, Session.sessionType == SessionType.REFRESH_TOKEN, ), ), @@ -138,7 +138,7 @@ class TokenManager: # 更新OIDC token async with postgres.session() as session: await session.merge(Session( - userSub=user_sub, + userId=user_id, sessionType=SessionType.ACCESS_TOKEN, token=oidc_access_token, validUntil=datetime.now(UTC) + timedelta(minutes=OIDC_ACCESS_TOKEN_EXPIRE_TIME), @@ -165,7 +165,7 @@ class TokenManager: # 保存插件token async with postgres.session() as session: await session.merge(Session( - userSub=user_sub, + userId=user_id, sessionType=SessionType.ACCESS_TOKEN, pluginId=str(plugin_name), token=ret["access_token"], @@ -177,14 +177,14 @@ class TokenManager: @staticmethod - async def delete_plugin_token(user_sub: str) -> None: + async def delete_plugin_token(user_id: str) -> None: """删除插件token(使用PostgreSQL)""" async with postgres.session() as session: token_data = (await session.scalars( select(Session) .where( and_( - Session.userSub == user_sub, + Session.userId == user_id, Session.sessionType.in_([ SessionType.ACCESS_TOKEN, SessionType.REFRESH_TOKEN, diff --git a/apps/services/user_tag.py b/apps/services/user_tag.py index 4e211bc61..15d514558 100644 --- a/apps/services/user_tag.py +++ b/apps/services/user_tag.py @@ -16,10 +16,10 @@ class UserTagManager: """用户画像管理""" @staticmethod - async def get_user_domain_by_user_sub_and_topk(user_sub: str, topk: int | None = None) -> list[UserTagInfo]: + async def get_user_domain_by_user_and_topk(user_id: str, topk: int | None = None) -> list[UserTagInfo]: """根据用户ID,查询用户最常涉及的n个领域""" async with postgres.session() as session: - query = select(UserTag).where(UserTag.userSub == user_sub).order_by(UserTag.count.desc()) + query = select(UserTag).where(UserTag.userId == user_id).order_by(UserTag.count.desc()) if topk is not None: query = query.limit(topk) @@ -35,7 +35,7 @@ class UserTagManager: @staticmethod - async def update_user_domain_by_user_sub_and_domain_name(user_sub: str, domain_name: str) -> None: + async def update_user_domain_by_user_and_domain_name(user_id: str, domain_name: str) -> None: """增加特定用户特定领域的频次""" async with postgres.session() as session: tag = ( @@ -52,7 +52,7 @@ class UserTagManager: await session.scalars( select(UserTag).where( and_( - UserTag.userSub == user_sub, + UserTag.userId == user_id, UserTag.tag == tag.id, ), ), @@ -60,7 +60,7 @@ class UserTagManager: ).one_or_none() if not user_domain: - user_domain = UserTag(userSub=user_sub, tag=tag.id, count=1) + user_domain = UserTag(userId=user_id, tag=tag.id, count=1) await session.merge(user_domain) else: user_domain.count += 1 diff --git a/tests/common/test_queue.py b/tests/common/test_queue.py index e1517277d..ff0b4780c 100644 --- a/tests/common/test_queue.py +++ b/tests/common/test_queue.py @@ -7,7 +7,7 @@ import pytest from apps.common.queue import MessageQueue from apps.schemas.enum_var import EventType -from apps.schemas.message import HeartbeatData, MessageBase, MessageFlow, MessageMetadata +from apps.schemas.message import HeartbeatData, MessageBase, MessageExecutor, MessageMetadata from apps.schemas.task import Task diff --git a/tests/schemas/test_user.py b/tests/schemas/test_user.py index 160fb2c10..379cf07e8 100644 --- a/tests/schemas/test_user.py +++ b/tests/schemas/test_user.py @@ -1,17 +1,17 @@ """UserInfo模型的单元测试""" -from apps.schemas.user import UserInfo +from apps.schemas.user import UserListItem def test_user_info_creation() -> None: """测试UserInfo对象的创建""" # 测试默认值 - user = UserInfo() + user = UserListItem() assert user.user_sub == "" assert user.user_name == "" # 测试指定值 - user = UserInfo(user_sub="sub123", user_name="test_user") # pyright: ignore[reportCallIssue] + user = UserListItem(user_sub="sub123", user_name="test_user") # pyright: ignore[reportCallIssue] assert user.user_sub == "sub123" assert user.user_name == "test_user" @@ -19,7 +19,7 @@ def test_user_info_creation() -> None: def test_user_info_alias() -> None: """测试UserInfo的别名功能""" # 测试使用别名创建对象 - user = UserInfo(userSub="sub123", userName="test_user") + user = UserListItem(userSub="sub123", userName="test_user") assert user.user_sub == "sub123" assert user.user_name == "test_user" @@ -27,19 +27,19 @@ def test_user_info_alias() -> None: def test_user_info_validation() -> None: """测试UserInfo的数据验证""" # 测试正常情况 - user = UserInfo(userSub="sub123", userName="test_user") + user = UserListItem(userSub="sub123", userName="test_user") assert user.user_sub == "sub123" assert user.user_name == "test_user" # 测试空字符串 - user = UserInfo(userSub="", userName="") + user = UserListItem(userSub="", userName="") assert user.user_sub == "" assert user.user_name == "" def test_user_info_str_representation() -> None: """测试UserInfo的字符串表示""" - user = UserInfo(userSub="sub123", userName="test_user") + user = UserListItem(userSub="sub123", userName="test_user") # 确保对象可以正确转换为字符串 str_repr = str(user) assert "UserInfo" in str_repr diff --git a/tests/services/test_conversation.py b/tests/services/test_conversation.py index 3c9933d8d..f0ab79561 100644 --- a/tests/services/test_conversation.py +++ b/tests/services/test_conversation.py @@ -10,7 +10,7 @@ from apps.services.conversation import ConversationManager class TestUserQaRecordManager(unittest.TestCase): - def test_get_user_qa_record_by_user_sub(self): + def test_get_user_qa_record_by_user(self): user_sub = "test_user_sub" with patch.object(MysqlDB, 'get_session') as mock_get_session: mock_session = MagicMock() @@ -19,7 +19,7 @@ class TestUserQaRecordManager(unittest.TestCase): mock_query.filter().all = mock_all mock_session.query.return_value = mock_query mock_get_session.return_value.__enter__.return_value = mock_session - result = ConversationManager.get_conversation_by_user_sub(user_sub) + result = ConversationManager.get_conversation_by_user(user_sub) self.assertEqual(result, ["result"]) def test_get_user_qa_record_by_session_id(self): @@ -34,7 +34,7 @@ class TestUserQaRecordManager(unittest.TestCase): result = ConversationManager.get_conversation_by_conversation_id(session_id) self.assertEqual(result, "result") - def test_add_user_qa_record_by_user_sub(self): + def test_add_user_qa_record_by_user(self): user_sub = "test_user_sub" with patch.object(MysqlDB, 'get_session') as mock_get_session: mock_session = MagicMock() @@ -43,7 +43,7 @@ class TestUserQaRecordManager(unittest.TestCase): mock_session.add = mock_add mock_session.commit = mock_commit mock_get_session.return_value.__enter__.return_value = mock_session - result = ConversationManager.add_conversation_by_user_sub(user_sub) + result = ConversationManager.add_conversation_by_user(user_sub) self.assertIsInstance(result, str) def test_update_user_qa_record_by_session_id(self): @@ -83,7 +83,7 @@ class TestUserQaRecordManager(unittest.TestCase): ConversationManager.delete_conversation_by_conversation_id(session_id) mock_delete.assert_called_once() - def test_delete_user_qa_record_by_user_sub(self): + def test_delete_user_qa_record_by_user(self): user_sub = "test_user_sub" with patch.object(MysqlDB, 'get_session') as mock_get_session: mock_session = MagicMock() @@ -94,7 +94,7 @@ class TestUserQaRecordManager(unittest.TestCase): mock_query.filter().delete = mock_delete mock_session.commit = mock_commit mock_get_session.return_value.__enter__.return_value = mock_session - ConversationManager.delete_conversation_by_user_sub(user_sub) + ConversationManager.delete_conversation_by_user(user_sub) mock_delete.assert_called_once() -- Gitee From f39acc53c46fc0a194bf1aaf53e6f364fdcd3cea Mon Sep 17 00:00:00 2001 From: z30057876 Date: Fri, 24 Oct 2025 17:28:15 +0800 Subject: [PATCH 2/5] =?UTF-8?q?userSub=E6=94=B9userId=EF=BC=882=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/oidc_provider/authhub.py | 1 + apps/routers/auth.py | 37 ++------------------ apps/routers/user.py | 27 ++++++++++----- apps/services/user.py | 31 ++++++++--------- design/call/core.md | 2 +- design/executor/agent.md | 2 +- design/services/appcenter.md | 20 +++++------ design/services/blacklist.md | 30 ++++++++-------- design/services/comment.md | 22 ++++++------ design/services/conversation.md | 52 ++++++++++++++-------------- design/services/document.md | 20 +++++------ design/services/flow.md | 6 ++-- design/services/mcp_service.md | 28 +++++++-------- design/services/parameter.md | 2 +- design/services/session.md | 14 ++++---- design/services/tag.md | 12 +++---- design/services/user.md | 12 +++---- 17 files changed, 148 insertions(+), 170 deletions(-) diff --git a/apps/common/oidc_provider/authhub.py b/apps/common/oidc_provider/authhub.py index 43f75530c..7b3fbe051 100644 --- a/apps/common/oidc_provider/authhub.py +++ b/apps/common/oidc_provider/authhub.py @@ -92,6 +92,7 @@ class AuthhubOIDCProvider(OIDCProviderBase): return { "user_sub": result["data"], + "user_name": result["data"], } @classmethod diff --git a/apps/routers/auth.py b/apps/routers/auth.py index c6ecc40a1..85c10e1a7 100644 --- a/apps/routers/auth.py +++ b/apps/routers/auth.py @@ -10,14 +10,10 @@ from fastapi import APIRouter, Body, Depends, Request, status from fastapi.responses import HTMLResponse, JSONResponse from fastapi.templating import Jinja2Templates -from apps.common.config import config from apps.common.oidc import oidc_provider from apps.dependency import verify_personal_token, verify_session from apps.schemas.personal_token import PostPersonalTokenMsg, PostPersonalTokenRsp -from apps.schemas.request_data import UserUpdateRequest from apps.schemas.response_data import ( - AuthUserMsg, - AuthUserRsp, OidcRedirectMsg, OidcRedirectRsp, ResponseData, @@ -81,7 +77,9 @@ async def oidc_login(request: Request, code: str) -> HTMLResponse: status_code=status.HTTP_403_FORBIDDEN, ) # 更新用户信息 - await UserManager.update_user(user_sub, UserUpdateRequest(lastLogin=datetime.now(UTC))) + await UserManager.update_user(user_sub, { + "lastLogin": datetime.now(UTC), + }) # 创建会话 current_session = await SessionManager.create_session(user_sub, user_host) return _templates.TemplateResponse( @@ -145,35 +143,6 @@ async def oidc_logout(token: Annotated[str, Body()]) -> JSONResponse: ) -@router.get("/user", response_model=AuthUserRsp) -async def userinfo(request: Request) -> JSONResponse: - """GET /auth/user: 获取当前用户信息""" - user = await UserManager.get_user(user_sub=request.state.user_sub) - if not user: - return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="Get UserInfo failed.", - result={}, - ).model_dump(exclude_none=True, by_alias=True), - ) - is_admin = user.userSub in config.login.admin_user - return JSONResponse( - status_code=status.HTTP_200_OK, - content=AuthUserRsp( - code=status.HTTP_200_OK, - message="success", - result=AuthUserMsg( - user_sub=request.state.user_sub, - revision=user.isActive, - is_admin=is_admin, - auto_execute=user.autoExecute or False, - ), - ).model_dump(exclude_none=True, by_alias=True), - ) - - @router.post("/key", responses={ 400: {"model": ResponseData}, }, response_model=PostPersonalTokenRsp) diff --git a/apps/routers/user.py b/apps/routers/user.py index 2554f9139..df7b9f550 100644 --- a/apps/routers/user.py +++ b/apps/routers/user.py @@ -6,9 +6,9 @@ from fastapi.responses import JSONResponse from apps.dependency import verify_personal_token, verify_session from apps.schemas.request_data import UserUpdateRequest -from apps.schemas.response_data import ResponseData, UserGetMsp, UserGetRsp +from apps.schemas.response_data import ResponseData from apps.schemas.tag import UserTagListResponse -from apps.schemas.user import UserInfo +from apps.schemas.user import UserListItem, UserListMsg, UserListRsp from apps.services.user import UserManager from apps.services.user_tag import UserTagManager @@ -21,9 +21,9 @@ router = APIRouter( @router.post("", response_model=ResponseData) async def update_user_info(request: Request, data: UserUpdateRequest) -> JSONResponse: - """POST /auth/user: 更新当前用户信息""" + """POST /api/user: 更新当前用户信息""" # 更新用户信息 - await UserManager.update_user(request.state.user_sub, data) + await UserManager.update_user(request.state.user_sub, data.model_dump(exclude_unset=True, exclude_none=True)) return JSONResponse( status_code=status.HTTP_200_OK, @@ -31,7 +31,18 @@ async def update_user_info(request: Request, data: UserUpdateRequest) -> JSONRes ) +# TODO @router.get("") +async def get_user_info(request: Request) -> JSONResponse: + """GET /api/user/info: 获取当前用户信息""" + user = await UserManager.get_user(request.state.user_sub) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=ResponseData(code=status.HTTP_200_OK, message="success", result=user), + ) + + +@router.get("/list") async def list_user( request: Request, page_size: int = 10, page_num: int = 1, ) -> JSONResponse: @@ -41,7 +52,7 @@ async def list_user( for user in user_list: if user.userSub == request.state.user_sub: continue - info = UserInfo( + info = UserListItem( userName=user.userSub, userSub=user.userSub, ) @@ -49,10 +60,10 @@ async def list_user( return JSONResponse( status_code=status.HTTP_200_OK, - content=UserGetRsp( + content=UserListRsp( code=status.HTTP_200_OK, message="用户数据详细信息获取成功", - result=UserGetMsp(userInfoList=user_info_list, total=total), + result=UserListMsg(userInfoList=user_info_list, total=total), ).model_dump(exclude_none=True, by_alias=True), ) @@ -66,7 +77,7 @@ async def get_user_tag( ) -> JSONResponse: """GET /user/tag?topk=5: 获取用户标签""" try: - tags = await UserTagManager.get_user_domain_by_user_sub_and_topk(request.state.user_sub, topk) + tags = await UserTagManager.get_user_domain_by_user_and_topk(request.state.user_sub, topk) except ValueError as e: return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, diff --git a/apps/services/user.py b/apps/services/user.py index 5ee6c4113..950425c27 100644 --- a/apps/services/user.py +++ b/apps/services/user.py @@ -2,13 +2,13 @@ """用户 Manager""" import logging +import secrets from datetime import UTC, datetime from sqlalchemy import func, select from apps.common.postgres import postgres from apps.models import User -from apps.schemas.request_data import UserUpdateRequest from .conversation import ConversationManager @@ -34,37 +34,34 @@ class UserManager: @staticmethod - async def get_user(user_sub: str) -> User | None: + async def get_user(user_id: str) -> User | None: """ - 根据用户sub获取用户信息 + 根据用户ID获取用户信息 - :param user_sub: 用户sub + :param user_id: 用户ID :return: 用户信息 """ async with postgres.session() as session: return ( - await session.scalars(select(User).where(User.userSub == user_sub)) + await session.scalars(select(User).where(User.id == user_id)) ).one_or_none() @staticmethod - async def update_user(user_sub: str, data: UserUpdateRequest) -> None: + async def update_user(user_id: str, data: dict) -> None: """ 根据用户sub更新用户信息 - :param user_sub: 用户sub + :param user_id: 用户ID :param data: 更新数据 """ - # 将 Pydantic 模型转换为字典 - update_data = data.model_dump(exclude_unset=True, exclude_none=True) - async with postgres.session() as session: user = ( - await session.scalars(select(User).where(User.userSub == user_sub)) + await session.scalars(select(User).where(User.id == user_id)) ).one_or_none() if not user: user = User( - userSub=user_sub, + userName=data.get("userName", f"用户-{secrets.token_hex(8)}"), isActive=True, isWhitelisted=False, credit=0, @@ -74,7 +71,7 @@ class UserManager: return # 更新指定字段 - for key, value in update_data.items(): + for key, value in data.items(): if hasattr(user, key) and value is not None: setattr(user, key, value) @@ -82,15 +79,15 @@ class UserManager: await session.commit() @staticmethod - async def delete_user(user_sub: str) -> None: + async def delete_user(user_id: str) -> None: """ 根据用户sub删除用户信息 - :param user_sub: 用户sub + :param user_id: 用户ID """ async with postgres.session() as session: user = ( - await session.scalars(select(User).where(User.userSub == user_sub)) + await session.scalars(select(User).where(User.id == user_id)) ).one_or_none() if not user: return @@ -98,4 +95,4 @@ class UserManager: await session.delete(user) await session.commit() - await ConversationManager.delete_conversation_by_user_sub(user_sub) + await ConversationManager.delete_conversation_by_user(user_id) diff --git a/design/call/core.md b/design/call/core.md index e0216dabf..1910f0724 100644 --- a/design/call/core.md +++ b/design/call/core.md @@ -108,7 +108,7 @@ DataBase提供了动态Schema填充能力,允许子类在运行时通过overri - `task_id`: executor.task.metadata.id - `executor_id`: executor.task.state.executorId - `session_id`: executor.task.runtime.sessionId -- `user_sub`: executor.task.metadata.userSub +- `user_id`: executor.task.metadata.userId - `app_id`: executor.task.state.appId - `conversation_id`: executor.task.metadata.conversationId - `question`: executor.question diff --git a/design/executor/agent.md b/design/executor/agent.md index 87b4cd772..f2681535d 100644 --- a/design/executor/agent.md +++ b/design/executor/agent.md @@ -710,7 +710,7 @@ erDiagram Task { UUID id PK "任务ID" - string userSub FK "用户ID" + int userId FK "用户ID" UUID conversationId FK "对话ID" UUID checkpointId FK "检查点ID" datetime updatedAt "更新时间" diff --git a/design/services/appcenter.md b/design/services/appcenter.md index 5b49eab97..8fb04ec3f 100644 --- a/design/services/appcenter.md +++ b/design/services/appcenter.md @@ -335,7 +335,7 @@ sequenceDiagram Client->>Router: POST /api/app (无appId) Router->>Router: 验证用户会话和令牌 - Router->>Manager: create_app(user_sub, request) + Router->>Manager: create_app(user_id, request) Manager->>Manager: 生成新的app_id Manager->>Manager: _process_app_and_save() Manager->>Pool: 从资源池获取app_loader @@ -369,8 +369,8 @@ sequenceDiagram Client->>Router: POST /api/app (含appId) Router->>Router: 验证用户会话和令牌 - Router->>Manager: update_app(user_sub, app_id, request) - Manager->>Manager: _get_app_data(app_id, user_sub) + Router->>Manager: update_app(user_id, app_id, request) + Manager->>Manager: _get_app_data(app_id, user_id) Manager->>DB: 查询应用信息 alt 应用不存在 DB-->>Manager: 返回空 @@ -378,7 +378,7 @@ sequenceDiagram Router-->>Client: 400 BAD_REQUEST else 权限不足 DB-->>Manager: 返回应用信息 - Manager->>Manager: 检查author != user_sub + Manager->>Manager: 检查author != user_id Manager-->>Router: 抛出InstancePermissionError Router-->>Client: 403 FORBIDDEN else 更改应用类型 @@ -456,8 +456,8 @@ sequenceDiagram Client->>Router: POST /api/app/{appId} Router->>Router: 验证用户会话和令牌 - Router->>Manager: update_app_publish_status(app_id, user_sub) - Manager->>Manager: _get_app_data(app_id, user_sub) + Router->>Manager: update_app_publish_status(app_id, user_id) + Manager->>Manager: _get_app_data(app_id, user_id) Manager->>DB: 查询应用并验证权限 alt 权限验证失败 @@ -783,7 +783,7 @@ erDiagram AppACL { uuid appId PK "FK" - string userSub "FK" + int userId "FK" string action } @@ -794,7 +794,7 @@ erDiagram } User { - string userSub PK + int userId PK datetime lastLogin boolean isActive string personalToken @@ -802,14 +802,14 @@ erDiagram UserFavorite { int id PK - string userSub "FK" + int userId "FK" enum favouriteType uuid itemId } UserAppUsage { int id PK - string userSub "FK" + int userId "FK" uuid appId "FK" int usageCount datetime lastUsed diff --git a/design/services/blacklist.md b/design/services/blacklist.md index 804f39c0e..1199925cb 100644 --- a/design/services/blacklist.md +++ b/design/services/blacklist.md @@ -26,7 +26,7 @@ erDiagram Record ||--|| User : "属于用户" Blacklist { - bigint id PK "主键ID" + int id PK "主键ID" uuid recordId FK "问答对ID" text question "黑名单问题" text answer "固定回答" @@ -38,14 +38,14 @@ erDiagram Record { uuid id PK "问答对ID" - string userSub FK "用户标识" + int userId FK "用户标识" text content "加密内容" text key "加密密钥" } User { - bigint id PK "用户ID" - string userSub UK "用户标识" + int id PK "用户ID" + int userId UK "用户标识" int credit "信用分" boolean isWhitelisted "白名单标识" } @@ -112,7 +112,7 @@ sequenceDiagram participant DB as 数据库 Admin->>Router: POST /api/blacklist/user - Router->>Manager: change_blacklisted_users(user_sub, credit_diff) + Router->>Manager: change_blacklisted_users(user_id, credit_diff) Manager->>DB: 查询用户信息 DB-->>Manager: 返回用户数据 @@ -150,10 +150,10 @@ sequenceDiagram 1. **get_blacklisted_users(limit, offset)**: 分页获取所有信用分小于等于0的 黑名单用户标识列表。 -2. **check_blacklisted_users(user_sub)**: 检测指定用户是否被拉黑。 +2. **check_blacklisted_users(user_id)**: 检测指定用户是否被拉黑。 同时检查信用分和白名单状态,白名单用户即使信用分为0也不会被拦截。 -3. **change_blacklisted_users(user_sub, credit_diff, credit_limit)**: +3. **change_blacklisted_users(user_id, credit_diff, credit_limit)**: 修改用户信用分。传入负值用于封禁用户,传入正值用于解禁用户。 系统会自动计算新信用分并确保其在有效范围内。 @@ -260,7 +260,7 @@ sequenceDiagram #### 举报管理主要方法 -1. **change_abuse_report(user_sub, record_id, reason_type, reason)**: +1. **change_abuse_report(user_id, record_id, reason_type, reason)**: 用户提交举报。验证问答对归属关系,解密内容后创建待审核黑名单记录, 记录举报类型和原因。 @@ -300,10 +300,10 @@ X-Personal-Token: "code": 200, "message": "ok", "result": { - "user_subs": [ - "user-sub-123", - "user-sub-456", - "user-sub-789" + "userName": [ + "user-123", + "user-456", + "user-789" ] } } @@ -320,7 +320,7 @@ Authorization: Bearer X-Personal-Token: { - "user_sub": "user-sub-123", + "userName": "user-123", "is_ban": 1 } ``` @@ -329,7 +329,7 @@ X-Personal-Token: ```json { - "user_sub": "user-sub-123", + "userName": "user-123", "is_ban": 0 } ``` @@ -683,7 +683,7 @@ sequenceDiagram | 错误场景 | HTTP状态码 | 错误信息 | 处理建议 | |---------|-----------|---------|---------| -| 用户不存在 | 500 | Change user blacklist error. | 检查user_sub是否正确 | +| 用户不存在 | 500 | Change user blacklist error. | 检查userId是否正确 | | 黑名单记录不存在 | 500 | Modify question blacklist error. | 检查黑名单ID是否有效 | | 举报记录不合法 | 500 | Report abuse complaint error. | 确认问答对ID正确且属于当前用户 | | 待审核问题不存在 | 500 | Audit abuse question error. | 检查record_id是否有效且为待审核状态 | diff --git a/design/services/comment.md b/design/services/comment.md index e85d6c68d..5c553f401 100644 --- a/design/services/comment.md +++ b/design/services/comment.md @@ -30,7 +30,7 @@ erDiagram COMMENT { bigint id "主键ID (Primary Key, Auto Increment)" uuid recordId "问答对ID (Foreign Key → framework_record.id, Indexed)" - string userSub "用户标识 (Foreign Key → framework_user.userSub, max 50)" + int userId "用户标识 (Foreign Key → framework_user.id)" CommentType commentType "评论类型 (Not Null)" string[] feedbackType "投诉类别列表 (Not Null)" string feedbackLink "投诉相关链接 (Not Null, max 1000)" @@ -57,7 +57,7 @@ erDiagram 提供评论的核心业务逻辑操作: - **query_comment(record_id: str)**: 根据问答ID查询评论 -- **update_comment(record_id: str, data: RecordComment, user_sub: str)**: 创建或更新评论 +- **update_comment(record_id: str, data: RecordComment, user_id: str)**: 创建或更新评论 ### 3. 路由层 (routers/comment.py) @@ -151,7 +151,7 @@ flowchart TD CreateDTO --> QueryDB{查询数据库
记录是否存在?} QueryDB -->|存在| Update[更新现有记录
commentType
feedbackType
feedbackLink
feedbackContent] - QueryDB -->|不存在| Create[创建新记录
包含recordId
userSub等字段] + QueryDB -->|不存在| Create[创建新记录
包含recordId
userId等字段] Update --> Commit[提交事务] Create --> Merge[Merge操作] @@ -198,7 +198,7 @@ sequenceDiagram Router->>Router: 解析dislike_reason
分号分隔 → list Router->>Router: 创建RecordComment对象
设置feedback_time - Router->>Manager: update_comment(record_id, data, user_sub) + Router->>Manager: update_comment(record_id, data, user_id) activate Manager Manager->>Session: 创建异步会话 @@ -299,7 +299,7 @@ classDiagram class Comment { +int id +UUID recordId - +str userSub + +int userId +CommentType commentType +list~str~ feedbackType +str feedbackLink @@ -333,7 +333,7 @@ classDiagram class CommentManager { +query_comment(record_id)$ RecordComment|None - +update_comment(record_id, data, user_sub)$ None + +update_comment(record_id, data, user_id)$ None } class Router { @@ -357,7 +357,7 @@ erDiagram FRAMEWORK_RECORD ||--o| FRAMEWORK_COMMENT : has FRAMEWORK_USER { - string userSub PK + int userId PK string userName datetime createdAt } @@ -365,7 +365,7 @@ erDiagram FRAMEWORK_RECORD { uuid id PK uuid conversationId - string userSub FK + int userId FK string content datetime createdAt } @@ -373,7 +373,7 @@ erDiagram FRAMEWORK_COMMENT { bigint id PK uuid recordId FK - string userSub FK + int userId FK enum commentType array feedbackType string feedbackLink @@ -531,7 +531,7 @@ curl -X POST "http://localhost:8000/api/comment" \ ### 2. 数据完整性 - recordId 外键约束 → framework_record.id -- userSub 外键约束 → framework_user.userSub +- userId 外键约束 → framework_user.userId - 索引优化:recordId 字段建立索引提高查询性能 ### 3. 字段别名映射 @@ -574,7 +574,7 @@ curl -X POST "http://localhost:8000/api/comment" \ ### 1. 认证与授权 - **双重认证**:Session + Personal Token -- **用户隔离**:userSub 关联确保数据隔离 +- **用户隔离**:userId 关联确保数据隔离 - **依赖注入**:在路由层统一进行身份验证 ### 2. 输入验证 diff --git a/design/services/conversation.md b/design/services/conversation.md index 9aca8c31a..965319f95 100644 --- a/design/services/conversation.md +++ b/design/services/conversation.md @@ -49,7 +49,7 @@ erDiagram Conversation { uuid id PK - string userSub FK + int userId FK uuid appId FK string title boolean isTemporary @@ -57,7 +57,7 @@ erDiagram } UserAppUsage { - string userSub FK + int userId FK uuid appId FK int usageCount datetime lastUsed @@ -90,8 +90,8 @@ sequenceDiagram C->>R: GET /api/conversation R->>Auth: 验证会话和令牌 Auth-->>R: 验证通过 - R->>CM: get_conversation_by_user_sub(user_sub) - CM->>DB: SELECT * FROM Conversation WHERE userSub=? + R->>CM: get_conversation_by_user_id(user_id) + CM->>DB: SELECT * FROM Conversation WHERE userId=? DB-->>CM: 返回对话列表 CM-->>R: 返回Conversation对象列表 @@ -116,7 +116,7 @@ sequenceDiagram ```mermaid flowchart TD - Start([开始创建对话]) --> Input[接收参数: title, user_sub, app_id, debug] + Start([开始创建对话]) --> Input[接收参数: title, user_id, app_id, debug] Input --> CreateObj[创建Conversation对象] CreateObj --> SaveDB{保存到数据库} SaveDB -->|失败| Error[记录错误日志] @@ -162,15 +162,15 @@ sequenceDiagram C->>R: POST /api/conversation?conversationId=xxx Note over C,R: Body: {title: "新标题"} - R->>CM: get_conversation_by_conversation_id(user_sub, conversation_id) - CM->>DB: SELECT * FROM Conversation WHERE id=? AND userSub=? + R->>CM: get_conversation_by_conversation_id(user_id, conversation_id) + CM->>DB: SELECT * FROM Conversation WHERE id=? AND userId=? DB-->>CM: 返回对话对象 CM-->>R: 返回Conversation alt 对话不存在或权限不足 R-->>C: 400 Bad Request else 对话存在 - R->>CM: update_conversation_by_conversation_id(user_sub, id, data) + R->>CM: update_conversation_by_conversation_id(user_id, id, data) CM->>DB: UPDATE Conversation SET title=? WHERE id=? DB-->>CM: 更新成功 CM-->>R: 返回True @@ -205,8 +205,8 @@ sequenceDiagram Note over C,R: Body: {conversation_list: [id1, id2, ...]} loop 遍历每个conversation_id - R->>CM: delete_conversation_by_conversation_id(user_sub, id) - CM->>DB: SELECT * FROM Conversation WHERE id=? AND userSub=? + R->>CM: delete_conversation_by_conversation_id(user_id, id) + CM->>DB: SELECT * FROM Conversation WHERE id=? AND userId=? DB-->>CM: 返回对话对象 alt 对话存在 @@ -218,7 +218,7 @@ sequenceDiagram CM-->>R: 删除完成 - R->>DM: delete_document_by_conversation_id(user_sub, id) + R->>DM: delete_document_by_conversation_id(user_id, id) DM-->>R: 文档删除完成 end @@ -238,9 +238,9 @@ sequenceDiagram ```mermaid flowchart TD - Start([开始验证]) --> Input[输入: user_sub, conversation_id] + Start([开始验证]) --> Input[输入: user_id, conversation_id] Input --> Query[查询数据库] - Query --> SQL["SELECT COUNT(id) FROM Conversation
WHERE id=? AND userSub=?"] + Query --> SQL["SELECT COUNT(id) FROM Conversation
WHERE id=? AND userId=?"] SQL --> GetResult[获取查询结果] GetResult --> CheckCount{count > 0?} CheckCount -->|是| ReturnTrue[返回 True] @@ -342,25 +342,25 @@ flowchart TD | 方法名 | 功能描述 | 参数 | 返回值 | |-------|---------|------|--------| -| `get_conversation_by_user_sub` | 获取用户的对话列表 | `user_sub: str` | `list[Conversation]` | -| `get_conversation_by_conversation_id` | 通过ID获取对话 | `user_sub: str, conversation_id: UUID` | `Conversation \| None` | -| `verify_conversation_access` | 验证对话访问权限 | `user_sub: str, conversation_id: UUID` | `bool` | -| `add_conversation_by_user_sub` | 创建新对话 | `title: str, user_sub: str, app_id: UUID, debug: bool` | `Conversation \| None` | -| `update_conversation_by_conversation_id` | 更新对话信息 | `user_sub: str, conversation_id: UUID, data: dict` | `bool` | -| `delete_conversation_by_conversation_id` | 删除指定对话 | `user_sub: str, conversation_id: UUID` | `None` | -| `delete_conversation_by_user_sub` | 删除用户所有对话 | `user_sub: str` | `None` | -| `verify_conversation_id` | 验证对话ID有效性 | `user_sub: str, conversation_id: UUID` | `bool` | +| `get_conversation_by_user_id` | 获取用户的对话列表 | `user_id: str` | `list[Conversation]` | +| `get_conversation_by_conversation_id` | 通过ID获取对话 | `user_id: str, conversation_id: UUID` | `Conversation \| None` | +| `verify_conversation_access` | 验证对话访问权限 | `user_id: str, conversation_id: UUID` | `bool` | +| `add_conversation_by_user_id` | 创建新对话 | `title: str, user_id: str, app_id: UUID, debug: bool` | `Conversation \| None` | +| `update_conversation_by_conversation_id` | 更新对话信息 | `user_id: str, conversation_id: UUID, data: dict` | `bool` | +| `delete_conversation_by_conversation_id` | 删除指定对话 | `user_id: str, conversation_id: UUID` | `None` | +| `delete_conversation_by_user_id` | 删除用户所有对话 | `user_id: str` | `None` | +| `verify_conversation_id` | 验证对话ID有效性 | `user_id: str, conversation_id: UUID` | `bool` | ### 5.2 方法调用关系 ```mermaid graph LR - API[API Router] --> Get[get_conversation_by_user_sub] + API[API Router] --> Get[get_conversation_by_user_id] API --> GetById[get_conversation_by_conversation_id] - API --> Add[add_conversation_by_user_sub] + API --> Add[add_conversation_by_user_id] API --> Update[update_conversation_by_conversation_id] API --> Delete[delete_conversation_by_conversation_id] - API --> DeleteUser[delete_conversation_by_user_sub] + API --> DeleteUser[delete_conversation_by_user_id] API --> Verify[verify_conversation_access] API --> VerifyId[verify_conversation_id] @@ -378,12 +378,12 @@ graph LR ### 6.1 权限控制 -所有操作都基于 `user_sub` 进行权限验证,确保用户只能访问自己的对话: +所有操作都基于 `user_id` 进行权限验证,确保用户只能访问自己的对话: ```python and_( Conversation.id == conversation_id, - Conversation.userSub == user_sub, + Conversation.userId == user_id, ) ``` diff --git a/design/services/document.md b/design/services/document.md index f559a7bd8..de54088ea 100644 --- a/design/services/document.md +++ b/design/services/document.md @@ -29,7 +29,7 @@ Document 模块负责处理文档上传、存储、检索和删除等核心功 ```python { "id": "uuid", - "userSub": "string", # 用户标识 + "userId": "string", # 用户标识 "name": "string", # 文件名 "extension": "string", # MIME 类型 "size": "float", # 文件大小 (KB) @@ -508,11 +508,11 @@ sequenceDiagram User->>API: POST /api/document/:conv_id API->>Auth: verify_session() - Auth-->>API: ✓ user_sub + Auth-->>API: ✓ user_id API->>Auth: verify_personal_token() Auth-->>API: ✓ session_id - API->>DM: storage_docs(user_sub, conv_id, files) + API->>DM: storage_docs(user_id, conv_id, files) loop 每个文件 DM->>DM: 检测 MIME 类型 @@ -545,7 +545,7 @@ sequenceDiagram User->>API: GET /api/document/:conv_id?used=true&unused=true - API->>CM: verify_conversation_access(user_sub, conv_id) + API->>CM: verify_conversation_access(user_id, conv_id) CM->>DB: SELECT Conversation DB-->>CM: Conversation 数据 CM-->>API: ✓ 有权限 @@ -600,7 +600,7 @@ sequenceDiagram User->>API: DELETE /api/document/:doc_id - API->>DM: delete_document(user_sub, [doc_id]) + API->>DM: delete_document(user_id, [doc_id]) DM->>DB: SELECT Document (验证所有权) DB-->>DM: Document 数据 @@ -642,7 +642,7 @@ sequenceDiagram User->>Chat: 发送消息引用文档 - Chat->>DM: change_doc_status(user_sub, conv_id) + Chat->>DM: change_doc_status(user_id, conv_id) DM->>DB: SELECT Conversation (验证权限) DB-->>DM: Conversation 数据 @@ -659,7 +659,7 @@ sequenceDiagram Note over Chat: 继续处理对话
生成回答 - Chat->>DM: save_answer_doc(user_sub, record_id, doc_infos) + Chat->>DM: save_answer_doc(user_id, record_id, doc_infos) DM->>DB: SELECT Record (验证) DB-->>DM: Record 数据 @@ -731,7 +731,7 @@ erDiagram Document { uuid id PK - string userSub + int userId string name string extension float size @@ -750,7 +750,7 @@ erDiagram Conversation { uuid id PK - string userSub + int userId string title datetime createdAt } @@ -758,7 +758,7 @@ erDiagram Record { uuid id PK uuid conversationId FK - string userSub + int userId string content datetime createdAt } diff --git a/design/services/flow.md b/design/services/flow.md index e235172f2..a7df9957d 100644 --- a/design/services/flow.md +++ b/design/services/flow.md @@ -112,7 +112,7 @@ erDiagram } UserFavorite { - string userSub + string userId uuid itemId string favouriteType } @@ -194,7 +194,7 @@ sequenceDiagram Note over Client,NM: 获取用户服务列表 Client->>API: GET /api/flow/service - API->>FM: get_service_by_user_id(user_sub) + API->>FM: get_service_by_user_id(user_id) FM->>DB: 查询用户收藏的服务 FM->>DB: 查询用户上传的服务 FM->>DB: 去重并查询服务详情 @@ -443,7 +443,7 @@ Authorization: Bearer |------|------|------|--------| | `get_flows_by_app_id` | 获取应用的所有工作流 | `app_id: UUID` | `list[FlowInfo]` | | `get_node_id_by_service_id` | 获取服务下的节点列表 | `service_id: UUID` | `list[NodeMetaDataBase]` | -| `get_service_by_user_id` | 获取用户的服务列表 | `user_sub: str` | `list[NodeServiceItem]` | +| `get_service_by_user_id` | 获取用户的服务列表 | `user_id: str` | `list[NodeServiceItem]` | | `get_node_by_node_id` | 获取节点详细信息 | `node_id: str` | `NodeMetaDataItem` | | `get_flow_by_app_and_flow_id` | 获取工作流详情 | `app_id: UUID, flow_id: str` | `FlowItem` | | `put_flow_by_app_and_flow_id` | 保存/更新工作流 | `app_id: UUID, flow_id: str, flow_item: FlowItem` | `None` | diff --git a/design/services/mcp_service.md b/design/services/mcp_service.md index 265b6c74e..491269171 100644 --- a/design/services/mcp_service.md +++ b/design/services/mcp_service.md @@ -546,7 +546,7 @@ sequenceDiagram Client->>Router: PUT /api/admin/mcp (无mcp_id) Router->>Router: 验证管理员权限 - Router->>Manager: create_mcpservice(data, user_sub) + Router->>Manager: create_mcpservice(data, user_id) Manager->>Manager: 验证配置类型 alt mcp_type为SSE Manager->>Manager: MCPServerSSEConfig.model_validate() @@ -588,7 +588,7 @@ sequenceDiagram Client->>Router: PUT /api/admin/mcp (含mcp_id) Router->>Router: 验证管理员权限 - Router->>Manager: update_mcpservice(data, user_sub) + Router->>Manager: update_mcpservice(data, user_id) alt mcp_id为空 Manager-->>Router: 抛出ValueError @@ -606,10 +606,10 @@ sequenceDiagram DB-->>Manager: 返回activated_users列表 loop 对每个激活用户 - Manager->>Manager: deactive_mcpservice(user_sub, mcp_id) - Manager->>Pool: stop(mcp_id, user_sub) + Manager->>Manager: deactive_mcpservice(user_id, mcp_id) + Manager->>Pool: stop(mcp_id, user_id) Note over Pool: 停止用户的MCP进程 - Manager->>Loader: user_deactive_template(user_sub, mcp_id) + Manager->>Loader: user_deactive_template(user_id, mcp_id) end Manager->>Manager: 构建新的MCPServerConfig @@ -634,7 +634,7 @@ sequenceDiagram Client->>Router: POST /api/admin/mcp/{serviceId}/install?install=true Router->>Router: 验证管理员权限 - Router->>Manager: install_mcpservice(user_sub, service_id, install=true) + Router->>Manager: install_mcpservice(user_id, service_id, install=true) Manager->>DB: 查询服务并验证作者权限 alt 服务不存在或无权限 @@ -681,7 +681,7 @@ sequenceDiagram Client->>Router: POST /api/mcp/{mcpId} (active=true) Router->>Router: 验证用户身份 - Router->>Manager: active_mcpservice(user_sub, mcp_id, mcp_env) + Router->>Manager: active_mcpservice(user_id, mcp_id, mcp_env) Manager->>DB: 查询MCPInfo记录 alt 服务不存在 @@ -695,8 +695,8 @@ sequenceDiagram else 可以激活 DB-->>Manager: 返回服务信息(status=READY) Manager->>DB: 插入或更新MCPActivated记录 - Note over DB: mcpId + userSub为复合键 - Manager->>Loader: user_active_template(user_sub, mcp_id, mcp_env) + Note over DB: mcpId + userId为复合键 + Manager->>Loader: user_active_template(user_id, mcp_id, mcp_env) Loader->>FileSystem: 创建用户专属配置文件 Note over FileSystem: 合并用户环境变量
生成用户配置 Loader->>Pool: 启动用户的MCP进程 @@ -790,7 +790,7 @@ stateDiagram-v2 判断指定用户是否已激活指定MCP服务。 - **功能描述**: 检查用户对MCP服务的激活状态 -- **查询逻辑**: 在MCPActivated表中查找同时匹配mcpId和userSub的记录 +- **查询逻辑**: 在MCPActivated表中查找同时匹配mcpId和userId的记录 - **返回值**: 存在记录返回True,否则返回False - **使用场景**: 服务列表展示时标注激活状态,权限验证 @@ -875,7 +875,7 @@ stateDiagram-v2 - DESCRIPTION: 仅匹配description字段 - AUTHOR: 仅匹配author字段 - **安装状态过滤**: 通过子查询检查mcpId是否在MCPActivated表中 -- **激活状态过滤**: 同样使用子查询,并额外过滤userSub +- **激活状态过滤**: 同样使用子查询,并额外过滤userId - **返回结果**: MCPInfo对象列表或空列表 ### create_mcpservice @@ -1047,7 +1047,7 @@ erDiagram MCPActivated { string mcpId PK "FK" - string userSub PK "FK" + int userId PK "FK" json env datetime activatedAt } @@ -1068,7 +1068,7 @@ erDiagram } User { - string userSub PK + int userId PK string email boolean isAdmin } @@ -1298,7 +1298,7 @@ Server-Sent Events类型的MCP服务通过HTTP连接实现: │ │ └── src/ # 源代码 │ └── ... └── user/ # 用户配置目录 - ├── {user_sub}/ # 单个用户目录 + ├── {user_name}/ # 单个用户目录 │ ├── {mcp_id}.json # 用户专属配置 │ └── ... └── ... diff --git a/design/services/parameter.md b/design/services/parameter.md index 0f2e75154..a3d789753 100644 --- a/design/services/parameter.md +++ b/design/services/parameter.md @@ -86,7 +86,7 @@ sequenceDiagram R->>Auth: 验证会话和令牌 Auth-->>R: 验证通过 - R->>AM: validate_user_app_access(user_sub, appId) + R->>AM: validate_user_app_access(user_id, appId) AM->>DB: 查询用户应用权限 DB-->>AM: 返回权限结果 AM-->>R: 返回权限验证结果 diff --git a/design/services/session.md b/design/services/session.md index 163768cd6..30bdd9d8d 100644 --- a/design/services/session.md +++ b/design/services/session.md @@ -19,7 +19,7 @@ Session 模块是 openEuler Intelligence 框架中的会话管理系统,负责 - **表名**: `framework_session` - **主键**: `id` (String(255), 默认为随机生成的16字节十六进制字符串) - **字段**: - - `userSub`: 用户标识 (String(50), 外键关联framework_user.userSub) + - `userId`: 用户标识 (String(50), 外键关联framework_user.id) - `ip`: IP地址 (String(255), 可为空) - `pluginId`: 插件ID (String(255), 可为空) - `token`: Token信息 (String(2000), 可为空) @@ -73,7 +73,7 @@ sequenceDiagram Note over Client, DB: 获取用户流程 Client->>SessionMgr: get_user(session_id) - SessionMgr->>DB: 查询Session.userSub + SessionMgr->>DB: 查询Session.userId DB-->>SessionMgr: 返回user_sub或None alt 用户不存在 @@ -126,7 +126,7 @@ erDiagram User ||--o{ SessionActivity : "用户产生活动" User { - string userSub PK "用户标识" + int userId PK "用户标识" boolean isWhitelisted "是否白名单" integer credit "信用分" string personalToken "个人令牌" @@ -134,7 +134,7 @@ erDiagram Session { string id PK "会话ID" - string userSub FK "用户标识" + int userId FK "用户标识" string ip "IP地址" string pluginId "插件ID" string token "Token信息" @@ -144,7 +144,7 @@ erDiagram SessionActivity { BigInteger id PK - string userSub FK "用户标识" + int userId FK "用户标识" datetime timestamp "活动时间戳" } ``` @@ -212,10 +212,10 @@ flowchart TD ## 性能优化 -1. **数据库索引**: 对userSub和id字段建立索引,提高查询效率 +1. **数据库索引**: 对userId和id字段建立索引,提高查询效率 2. **异步操作**: 所有数据库操作使用异步方式,提高并发性能 3. **连接池管理**: 使用数据库连接池管理连接,减少连接开销 -4. **最小化查询**: get_user方法只查询必要的userSub字段,而非整个Session对象 +4. **最小化查询**: get_user方法只查询必要的userId字段,而非整个Session对象 ## 与其他模块的交互 diff --git a/design/services/tag.md b/design/services/tag.md index a6df9c7a6..ff54e2a48 100644 --- a/design/services/tag.md +++ b/design/services/tag.md @@ -20,7 +20,7 @@ erDiagram framework_user { BigInteger id PK "主键" - string userSub UK "用户标识" + int userId UK "用户标识" datetime lastLogin "最后登录时间" boolean isActive "是否活跃" boolean isWhitelisted "是否白名单" @@ -40,7 +40,7 @@ erDiagram framework_user_tag { BigInteger id PK "主键" - string userSub FK "用户标识(外键)" + int userId FK "用户标识(外键)" BigInteger tag FK "标签ID(外键)" integer count "标签使用频次(默认0)" } @@ -50,7 +50,7 @@ erDiagram - **framework_tag**: 标签基础信息表,存储标签的定义和元数据 - **framework_user_tag**: 用户标签关联表,记录用户与标签的多对多关系及使用频次 -- **framework_user**: 用户基础信息表,通过userSub字段与标签系统关联 +- **framework_user**: 用户基础信息表,通过userId字段与标签系统关联 ## API接口 @@ -86,7 +86,7 @@ erDiagram - `get_all_tag()`: 获取所有标签 - `get_tag_by_name(name)`: 根据名称获取标签 -- `get_tag_by_user_sub(user_sub)`: 获取用户的所有标签 +- `get_tag_by_user_id(user_id)`: 获取用户的所有标签 - `add_tag(data)`: 添加新标签(当前路由未调用) - `update_tag_by_name(data)`: 更新标签定义,标签不存在时抛出 `ValueError` - `delete_tag(data)`: 删除标签 @@ -139,8 +139,8 @@ sequenceDiagram Note over Client, DB: 用户标签查询流程 Client->>Router: GET /api/user/tag - Router->>Service: get_tag_by_user_sub(user_sub) - Service->>DB: SELECT * FROM framework_user_tag WHERE userSub=? + Router->>Service: get_tag_by_user_id(user_id) + Service->>DB: SELECT * FROM framework_user_tag WHERE userId=? DB-->>Service: 返回用户标签关联 loop 遍历用户标签 Service->>DB: SELECT * FROM framework_tag WHERE id=? diff --git a/design/services/user.md b/design/services/user.md index 86255d62c..526409fe5 100644 --- a/design/services/user.md +++ b/design/services/user.md @@ -34,7 +34,7 @@ erDiagram User { bigint id PK "用户ID" - string userSub UK "用户标识" + string userName UK "用户标识" datetime lastLogin "最后登录时间" boolean isActive "是否活跃" boolean isWhitelisted "是否白名单" @@ -47,14 +47,14 @@ erDiagram UserFavorite { bigint id PK "收藏ID" - string userSub FK "用户标识" + int userId FK "用户标识" enum favouriteType "收藏类型" uuid itemId "项目ID" } UserAppUsage { bigint id PK "使用记录ID" - string userSub FK "用户标识" + int userId FK "用户标识" uuid appId FK "应用ID" int usageCount "使用次数" datetime lastUsed "最后使用时间" @@ -62,7 +62,7 @@ erDiagram UserTag { bigint id PK "标签ID" - string userSub FK "用户标识" + int userId FK "用户标识" bigint tag FK "标签ID" int count "标签归类次数" } @@ -103,7 +103,7 @@ flowchart TD C1 --> C2[分页查询用户列表] C2 --> C3[返回用户列表和总数] - D --> D1[根据userSub查询用户] + D --> D1[根据userId查询用户] D1 --> D2[返回用户信息或None] E --> E1{用户是否存在} @@ -268,7 +268,7 @@ flowchart TD "total": 100, "userInfoList": [ { - "userSub": "user-123", + "userId": "123", "userName": "张三" } ] -- Gitee From 4ea9d6108e97e62c40ad1284f28eee83a275807b Mon Sep 17 00:00:00 2001 From: z30057876 Date: Fri, 24 Oct 2025 17:29:10 +0800 Subject: [PATCH 3/5] =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E5=AF=BC=E5=85=A5?= =?UTF-8?q?=EF=BC=9BuserSub=E6=94=B9userId=EF=BC=883=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/call/api/api.py | 2 +- apps/scheduler/call/convert/convert.py | 3 +- apps/scheduler/call/core.py | 31 +++++++++++----- apps/scheduler/call/facts/facts.py | 4 +-- apps/scheduler/call/facts/schema.py | 2 +- apps/scheduler/call/mcp/mcp.py | 5 +-- apps/scheduler/call/suggest/schema.py | 2 +- apps/scheduler/call/suggest/suggest.py | 10 +++--- apps/scheduler/executor/agent.py | 14 ++++---- apps/scheduler/executor/base.py | 2 +- apps/scheduler/mcp/host.py | 20 +++++------ apps/scheduler/mcp_agent/base.py | 4 +-- apps/scheduler/pool/loader/app.py | 6 ++-- apps/scheduler/pool/loader/call.py | 2 +- apps/scheduler/pool/loader/mcp.py | 30 ++++++++-------- apps/scheduler/pool/loader/service.py | 18 +++++++++- apps/scheduler/pool/mcp/client.py | 14 ++++---- apps/scheduler/pool/mcp/pool.py | 50 +++++++++++++------------- 18 files changed, 126 insertions(+), 93 deletions(-) diff --git a/apps/scheduler/call/api/api.py b/apps/scheduler/call/api/api.py index 774e10f09..1eed0288f 100644 --- a/apps/scheduler/call/api/api.py +++ b/apps/scheduler/call/api/api.py @@ -88,7 +88,7 @@ class API(CoreCall, input_model=APIInput, output_model=APIOutput): if self.node.serviceId: try: service_metadata = await ServiceCenterManager.get_service_metadata( - call_vars.ids.user_sub, + call_vars.ids.user_id, self.node.serviceId, ) # 获取Service对应的Auth diff --git a/apps/scheduler/call/convert/convert.py b/apps/scheduler/call/convert/convert.py index cd105df9e..f1e6f616f 100644 --- a/apps/scheduler/call/convert/convert.py +++ b/apps/scheduler/call/convert/convert.py @@ -11,9 +11,10 @@ from jinja2.sandbox import SandboxedEnvironment from pydantic import Field from apps.models import LanguageType -from apps.scheduler.call.core import CallError, CoreCall +from apps.scheduler.call.core import CoreCall from apps.schemas.enum_var import CallOutputType from apps.schemas.scheduler import ( + CallError, CallInfo, CallOutputChunk, CallVars, diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index dffe3962e..06939f9b7 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -12,10 +12,10 @@ from typing import TYPE_CHECKING, Any, ClassVar, Self from pydantic import BaseModel, ConfigDict, Field from pydantic.json_schema import SkipJsonSchema +from apps.llm import json_generator from apps.models import ExecutorHistory, LanguageType, NodeInfo from apps.schemas.enum_var import CallOutputType from apps.schemas.scheduler import ( - CallError, CallIds, CallInfo, CallOutputChunk, @@ -102,7 +102,7 @@ class CoreCall(BaseModel): task_id=executor.task.metadata.id, executor_id=executor.task.state.executorId, session_id=executor.task.runtime.sessionId, - user_sub=executor.task.metadata.userSub, + user_id=executor.task.metadata.userId, app_id=executor.task.state.appId, conversation_id=executor.task.metadata.conversationId, ), @@ -203,10 +203,25 @@ class CoreCall(BaseModel): yield chunk.content - async def _json(self, messages: list[dict[str, Any]], schema: dict[str, Any]) -> dict[str, Any]: + async def _json(self, messages: list[dict[str, Any]], function: dict[str, Any]) -> dict[str, Any]: """Call可直接使用的JSON生成""" - if not self._llm_obj.function: - err = "[CoreCall] 未设置函数调用模型!" - logger.error(err) - raise CallError(message=err, data={}) - return await self._llm_obj.function.call(messages=messages, schema=schema) + # 从messages中提取最后一条用户消息作为query,其他作为conversation + query = "" + conversation = [] + + for i, msg in enumerate(messages): + role = msg.get("role") + # 跳过system消息 + if role == "system": + continue + # 找到最后一条user消息作为query + if role == "user" and i == len(messages) - 1: + query = msg.get("content", "") + else: + conversation.append(msg) + + return await json_generator.generate( + query=query, + function=function, + conversation=conversation if conversation else None, + ) diff --git a/apps/scheduler/call/facts/facts.py b/apps/scheduler/call/facts/facts.py index 6cbb3889c..378ba7634 100644 --- a/apps/scheduler/call/facts/facts.py +++ b/apps/scheduler/call/facts/facts.py @@ -69,7 +69,7 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): ] return FactsInput( - user_sub=call_vars.ids.user_sub, + user_id=call_vars.ids.user_id, message=message, ) @@ -105,7 +105,7 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): domain_list = DomainGen.model_validate(domain_list) for domain in domain_list.keywords: - await UserTagManager.update_user_domain_by_user_sub_and_domain_name(data.user_sub, domain) + await UserTagManager.update_user_domain_by_user_and_domain_name(data.user_id, domain) yield CallOutputChunk( type=CallOutputType.DATA, diff --git a/apps/scheduler/call/facts/schema.py b/apps/scheduler/call/facts/schema.py index c16b94ebc..f4afe68da 100644 --- a/apps/scheduler/call/facts/schema.py +++ b/apps/scheduler/call/facts/schema.py @@ -21,7 +21,7 @@ class FactsGen(BaseModel): class FactsInput(DataBase): """提取事实工具的输入""" - user_sub: str = Field(description="用户ID") + user_id: str = Field(description="用户ID") message: list[dict[str, str]] = Field(description="消息") diff --git a/apps/scheduler/call/mcp/mcp.py b/apps/scheduler/call/mcp/mcp.py index 42cfaff04..4149aa0b0 100644 --- a/apps/scheduler/call/mcp/mcp.py +++ b/apps/scheduler/call/mcp/mcp.py @@ -9,11 +9,12 @@ from typing import Any from pydantic import Field from apps.models import LanguageType -from apps.scheduler.call.core import CallError, CoreCall +from apps.scheduler.call.core import CoreCall from apps.scheduler.mcp import MCPHost, MCPPlanner, MCPSelector from apps.schemas.enum_var import CallOutputType from apps.schemas.mcp import MCPPlanItem from apps.schemas.scheduler import ( + CallError, CallInfo, CallOutputChunk, CallVars, @@ -87,7 +88,7 @@ class MCP(CoreCall, input_model=MCPInput, output_model=MCPOutput): """初始化MCP""" # 获取MCP交互类 self._host = MCPHost( - call_vars.ids.user_sub, + call_vars.ids.user_id, call_vars.ids.task_id, self._llm_obj, self._sys_vars.language, diff --git a/apps/scheduler/call/suggest/schema.py b/apps/scheduler/call/suggest/schema.py index a57f57a64..71fbb52b8 100644 --- a/apps/scheduler/call/suggest/schema.py +++ b/apps/scheduler/call/suggest/schema.py @@ -23,7 +23,7 @@ class SuggestionInput(DataBase): """问题推荐输入""" question: str - user_sub: str + user_id: str history_questions: list[str] diff --git a/apps/scheduler/call/suggest/suggest.py b/apps/scheduler/call/suggest/suggest.py index 95569f318..9b1b5624a 100644 --- a/apps/scheduler/call/suggest/suggest.py +++ b/apps/scheduler/call/suggest/suggest.py @@ -81,7 +81,7 @@ class Suggestion(CoreCall, input_model=SuggestionInput, output_model=SuggestionO self._history_questions = [] else: self._history_questions = await self._get_history_questions( - call_vars.ids.user_sub, + call_vars.ids.user_id, self.conversation_id, ) self._app_id = call_vars.ids.app_id @@ -107,15 +107,15 @@ class Suggestion(CoreCall, input_model=SuggestionInput, output_model=SuggestionO return SuggestionInput( question=call_vars.question, - user_sub=call_vars.ids.user_sub, + user_id=call_vars.ids.user_id, history_questions=self._history_questions, ) - async def _get_history_questions(self, user_sub: str, conversation_id: uuid.UUID) -> list[str]: + async def _get_history_questions(self, user_id: str, conversation_id: uuid.UUID) -> list[str]: """获取当前对话的历史问题""" records = await RecordManager.query_record_by_conversation_id( - user_sub, + user_id, conversation_id, 15, ) @@ -132,7 +132,7 @@ class Suggestion(CoreCall, input_model=SuggestionInput, output_model=SuggestionO data = SuggestionInput(**input_data) # 获取当前用户的画像 - user_domain_info = await UserTagManager.get_user_domain_by_user_sub_and_topk(data.user_sub, 5) + user_domain_info = await UserTagManager.get_user_domain_by_user_and_topk(data.user_id, 5) user_domain = [tag.name for tag in user_domain_info] # 初始化Prompt prompt_tpl = self._env.from_string(SUGGEST_PROMPT[self._sys_vars.language]) diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index 7f5bf7243..b82f0f965 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -54,7 +54,7 @@ class MCPAgentExecutor(BaseExecutor): # 初始化MCP Host相关对象 self._planner = MCPPlanner(self.task, self.llm) self._host = MCPHost(self.task, self.llm) - user = await UserManager.get_user(self.task.metadata.userSub) + user = await UserManager.get_user(self.task.metadata.userId) if not user: err = "[MCPAgentExecutor] 用户不存在: %s" _logger.error(err) @@ -75,10 +75,10 @@ class MCPAgentExecutor(BaseExecutor): mcp_ids = app_metadata.mcp_service for mcp_id in mcp_ids: - if not await MCPServiceManager.is_user_actived(self.task.metadata.userSub, mcp_id): + if not await MCPServiceManager.is_user_actived(self.task.metadata.userId, mcp_id): _logger.warning( "[MCPAgentExecutor] 用户 %s 未启用MCP %s", - self.task.metadata.userSub, + self.task.metadata.userId, mcp_id, ) continue @@ -231,7 +231,7 @@ class MCPAgentExecutor(BaseExecutor): self._validate_current_tool() current_tool = cast("MCPTools", self._current_tool) - mcp_client = await mcp_pool.get(current_tool.mcpId, self.task.metadata.userSub) + mcp_client = await mcp_pool.get(current_tool.mcpId, self.task.metadata.userId) if not mcp_client: _logger.exception("[MCPAgentExecutor] MCP客户端不存在: %s", current_tool.mcpId) state.stepStatus = StepStatus.ERROR @@ -247,7 +247,7 @@ class MCPAgentExecutor(BaseExecutor): except anyio.ClosedResourceError as e: _logger.exception("[MCPAgentExecutor] MCP客户端连接已关闭: %s", current_tool.mcpId) # 停止当前用户MCP进程 - await mcp_pool.stop(current_tool.mcpId, self.task.metadata.userSub) + await mcp_pool.stop(current_tool.mcpId, self.task.metadata.userId) state.stepStatus = StepStatus.ERROR state.errorMessage = { "err_msg": str(e), @@ -435,7 +435,7 @@ class MCPAgentExecutor(BaseExecutor): if self._retry_times >= AGENT_MAX_RETRY_TIMES: await self.error_handle_after_step() else: - user_info = await UserManager.get_user(self.task.metadata.userSub) + user_info = await UserManager.get_user(self.task.metadata.userId) if not user_info: err = "[MCPAgentExecutor] 用户不存在" _logger.error(err) @@ -539,6 +539,6 @@ class MCPAgentExecutor(BaseExecutor): finally: for mcp_service in self._mcp_list: try: - await mcp_pool.stop(mcp_service.id, self.task.metadata.userSub) + await mcp_pool.stop(mcp_service.id, self.task.metadata.userId) except Exception: _logger.exception("[MCPAgentExecutor] 停止MCP客户端时发生错误") diff --git a/apps/scheduler/executor/base.py b/apps/scheduler/executor/base.py index 1ff24a31d..43876db36 100644 --- a/apps/scheduler/executor/base.py +++ b/apps/scheduler/executor/base.py @@ -51,7 +51,7 @@ class BaseExecutor(BaseModel, ABC): return # 获取最后n+5条Record records = await RecordManager.query_record_by_conversation_id( - self.task.metadata.userSub, self.task.metadata.conversationId, n + 5, + self.task.metadata.userId, self.task.metadata.conversationId, n + 5, ) # 组装问答 context = [] diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py index 4caf52e69..5d608e054 100644 --- a/apps/scheduler/mcp/host.py +++ b/apps/scheduler/mcp/host.py @@ -24,10 +24,10 @@ logger = logging.getLogger(__name__) class MCPHost: """MCP宿主服务""" - def __init__(self, user_sub: str, task_id: uuid.UUID, llm: LLMConfig, language: LanguageType) -> None: + def __init__(self, user_id: str, task_id: uuid.UUID, llm: LLMConfig, language: LanguageType) -> None: """初始化MCP宿主""" self._task_id = task_id - self._user_sub = user_sub + self._user_id = user_id self._context_list = [] self._language = language self._llm = llm @@ -44,15 +44,15 @@ class MCPHost: async def get_client(self, mcp_id: str) -> MCPClient | None: """获取MCP客户端""" - if not await MCPServiceManager.is_user_actived(self._user_sub, mcp_id): - logger.warning("用户 %s 未启用MCP %s", self._user_sub, mcp_id) + if not await MCPServiceManager.is_user_actived(self._user_id, mcp_id): + logger.warning("用户 %s 未启用MCP %s", self._user_id, mcp_id) return None # 获取MCP配置 try: - return await mcp_pool.get(mcp_id, self._user_sub) + return await mcp_pool.get(mcp_id, self._user_id) except KeyError: - logger.warning("用户 %s 的MCP %s 没有运行中的实例,请检查环境", self._user_sub, mcp_id) + logger.warning("用户 %s 的MCP %s 没有运行中的实例,请检查环境", self._user_id, mcp_id) return None @@ -126,7 +126,7 @@ class MCPHost: async def call_tool(self, tool: MCPTools, plan_item: MCPPlanItem) -> list[dict[str, Any]]: """调用工具""" # 拿到Client - client = await mcp_pool.get(tool.mcpId, self._user_sub) + client = await mcp_pool.get(tool.mcpId, self._user_id) if client is None: err = f"[MCPHost] MCP Server不合法: {tool.mcpId}" logger.error(err) @@ -153,14 +153,14 @@ class MCPHost: tool_list = [] for mcp_id in mcp_id_list: # 检查用户是否启用了这个mcp - if not await MCPServiceManager.is_user_actived(self._user_sub, mcp_id): - logger.warning("用户 %s 未启用MCP %s", self._user_sub, mcp_id) + if not await MCPServiceManager.is_user_actived(self._user_id, mcp_id): + logger.warning("用户 %s 未启用MCP %s", self._user_id, mcp_id) continue # 获取MCP工具配置 try: tool_list.extend(await MCPServiceManager.get_mcp_tools(mcp_id)) except KeyError: - logger.warning("用户 %s 的MCP Tool %s 配置错误", self._user_sub, mcp_id) + logger.warning("用户 %s 的MCP Tool %s 配置错误", self._user_id, mcp_id) continue return tool_list diff --git a/apps/scheduler/mcp_agent/base.py b/apps/scheduler/mcp_agent/base.py index 42524db95..ea3b01a83 100644 --- a/apps/scheduler/mcp_agent/base.py +++ b/apps/scheduler/mcp_agent/base.py @@ -14,14 +14,14 @@ _logger = logging.getLogger(__name__) class MCPBase: """MCP基类""" - _user_sub: str + _user_id: str _llm: LLMConfig _goal: str _language: LanguageType def __init__(self, task: TaskData, llm: LLMConfig) -> None: """初始化MCP基类""" - self._user_sub = task.metadata.userSub + self._user_id = task.metadata.userId self._llm = llm self._goal = task.runtime.userInput self._language = task.runtime.language diff --git a/apps/scheduler/pool/loader/app.py b/apps/scheduler/pool/loader/app.py index 619808514..4142989c8 100644 --- a/apps/scheduler/pool/loader/app.py +++ b/apps/scheduler/pool/loader/app.py @@ -159,7 +159,7 @@ class AppLoader: id=metadata.id, name=metadata.name, description=metadata.description, - author=metadata.author, + authorId=metadata.author, appType=metadata.app_type, isPublished=metadata.published, permission=metadata.permission.type if metadata.permission else PermissionType.PRIVATE, @@ -171,10 +171,10 @@ class AppLoader: and metadata.permission.type == PermissionType.PROTECTED and metadata.permission.users ): - for user_sub in metadata.permission.users: + for user_id in metadata.permission.users: await session.merge(AppACL( appId=metadata.id, - userSub=user_sub, + userId=user_id, action="", )) # 保存AppHashes表 diff --git a/apps/scheduler/pool/loader/call.py b/apps/scheduler/pool/loader/call.py index 8220960a6..c854f079d 100644 --- a/apps/scheduler/pool/loader/call.py +++ b/apps/scheduler/pool/loader/call.py @@ -71,7 +71,7 @@ class CallLoader: # 检查表是否存在 if not await _table_exists(embedding.NodePoolVector.__tablename__): _logger.warning( - "表 %s 不存在,跳过向量数据插入", + "[CallLoader] 表 %s 不存在,跳过向量数据插入", embedding.NodePoolVector.__tablename__, ) return diff --git a/apps/scheduler/pool/loader/mcp.py b/apps/scheduler/pool/loader/mcp.py index e3e219c8d..5f51166d7 100644 --- a/apps/scheduler/pool/loader/mcp.py +++ b/apps/scheduler/pool/loader/mcp.py @@ -249,14 +249,14 @@ class MCPLoader: async def _get_template_tool( mcp_id: str, config: MCPServerConfig, - user_sub: str | None = None, + user_id: str | None = None, ) -> list[MCPTools]: """ 获取MCP模板的工具列表 :param str mcp_id: MCP模板ID :param MCPServerConfig config: MCP配置 - :param str | None user_sub: 用户ID,默认为None + :param str | None user_id: 用户ID,默认为None :return: 工具列表 :rtype: list[str] """ @@ -271,7 +271,7 @@ class MCPLoader: logger.error(err) raise ValueError(err) - await client.init(user_sub, mcp_id, config.mcpServers[mcp_id]) + await client.init(user_id, mcp_id, config.mcpServers[mcp_id]) # 获取工具列表 tool_list = [] @@ -303,7 +303,7 @@ class MCPLoader: overview=config.overview, description=config.description, mcpType=config.mcpType, - author=config.author or "", + authorId=config.author or "", )) await session.commit() @@ -396,19 +396,19 @@ class MCPLoader: @staticmethod - async def user_active_template(user_sub: str, mcp_id: str, mcp_env: dict[str, Any] | None = None) -> None: + async def user_active_template(user_id: str, mcp_id: str, mcp_env: dict[str, Any] | None = None) -> None: """ 用户激活MCP模板 激活MCP模板时,将已安装的环境拷贝一份到用户目录,并更新数据库 - :param str user_sub: 用户ID + :param str user_id: 用户ID :param str mcp_id: MCP模板ID :return: 无 :raises FileExistsError: MCP模板已存在或有同名文件,无法激活 """ template_path = MCP_PATH / "template" / str(mcp_id) - user_path = MCP_PATH / "users" / user_sub / str(mcp_id) + user_path = MCP_PATH / "users" / user_id / str(mcp_id) # 判断是否存在 if await user_path.exists(): @@ -457,23 +457,23 @@ class MCPLoader: async with postgres.session() as session: await session.merge(MCPActivated( mcpId=mcp_id, - userSub=user_sub, + userId=user_id, )) await session.commit() @staticmethod - async def user_deactive_template(user_sub: str, mcp_id: str) -> None: + async def user_deactive_template(user_id: str, mcp_id: str) -> None: """ 取消激活MCP模板 取消激活MCP模板时,删除用户目录下对应的MCP环境文件夹,并更新数据库 - :param str user_sub: 用户ID + :param str user_id: 用户ID :param str mcp_id: MCP模板ID :return: 无 """ # 删除用户目录 - user_path = MCP_PATH / "users" / user_sub / str(mcp_id) + user_path = MCP_PATH / "users" / user_id / str(mcp_id) await asyncer.asyncify(shutil.rmtree)(user_path.as_posix(), ignore_errors=True) # 更新数据库 @@ -481,7 +481,7 @@ class MCPLoader: await session.execute(delete(MCPActivated).where( and_( MCPActivated.mcpId == mcp_id, - MCPActivated.userSub == user_sub, + MCPActivated.userId == user_id, ), )) await session.commit() @@ -551,7 +551,7 @@ class MCPLoader: mcp_activated = (await session.scalars(select(MCPActivated).where(MCPActivated.mcpId == mcp_id))).all() for activated in mcp_activated: - await MCPLoader.user_deactive_template(activated.userSub, mcp_id) + await MCPLoader.user_deactive_template(activated.userId, mcp_id) await session.delete(activated) await session.delete(mcp_info) await session.commit() @@ -643,10 +643,10 @@ class MCPLoader: # 删除所有的激活情况 await session.execute(delete(MCPActivated).where(MCPActivated.mcpId == mcp_id)) # 插入新的激活情况 - for user_sub in user_list: + for user_id in user_list: session.add(MCPActivated( mcpId=mcp_id, - userSub=user_sub, + userId=user_id, )) await session.commit() diff --git a/apps/scheduler/pool/loader/service.py b/apps/scheduler/pool/loader/service.py index 0caa01d68..c86d0b51b 100644 --- a/apps/scheduler/pool/loader/service.py +++ b/apps/scheduler/pool/loader/service.py @@ -6,7 +6,7 @@ import shutil import uuid from anyio import Path -from sqlalchemy import delete, select +from sqlalchemy import delete, inspect, select from apps.common.config import config from apps.common.postgres import postgres @@ -28,6 +28,14 @@ logger = logging.getLogger(__name__) BASE_PATH = Path(config.deploy.data_dir) / "semantics" / "service" +async def _table_exists(table_name: str) -> bool: + """检查表是否存在""" + async with postgres.engine.connect() as conn: + return await conn.run_sync( + lambda sync_conn: inspect(sync_conn).has_table(table_name), + ) + + class ServiceLoader: """Service 加载器""" @@ -179,6 +187,14 @@ class ServiceLoader: service_description: str, ) -> None: """更新向量数据""" + # 检查表是否存在 + if not await _table_exists(embedding.ServicePoolVector.__tablename__): + logger.warning( + "[ServiceLoader] 表 %s 不存在,跳过向量数据更新", + embedding.ServicePoolVector.__tablename__, + ) + return + service_vecs = await embedding.get_embedding([service_description]) node_descriptions = [] for node in nodes: diff --git a/apps/scheduler/pool/mcp/client.py b/apps/scheduler/pool/mcp/client.py index fc26249f2..b40c5212c 100644 --- a/apps/scheduler/pool/mcp/client.py +++ b/apps/scheduler/pool/mcp/client.py @@ -39,7 +39,7 @@ class MCPClient: self.status = MCPStatus.UNINITIALIZED async def _main_loop( - self, user_sub: str | None, mcp_id: str, config: MCPServerSSEConfig | MCPServerStdioConfig, + self, user_id: str | None, mcp_id: str, config: MCPServerSSEConfig | MCPServerStdioConfig, ) -> None: """ 创建MCP Client @@ -47,7 +47,7 @@ class MCPClient: 抽象函数;作用为在初始化的时候使用MCP SDK创建Client 由于目前MCP的实现中Client和Session是1:1的关系,所以直接创建了 :class:`~mcp.ClientSession` - :param str user_sub: 用户ID + :param str user_id: 用户ID :param str mcp_id: MCP ID :param MCPServerSSEConfig | MCPServerStdioConfig config: MCP配置 :return: MCP ClientSession @@ -61,8 +61,8 @@ class MCPClient: headers=env, ) elif isinstance(config, MCPServerStdioConfig): - if user_sub: - cwd = MCP_PATH / "users" / user_sub / mcp_id / "project" + if user_id: + cwd = MCP_PATH / "users" / user_id / mcp_id / "project" else: cwd = MCP_PATH / "template" / mcp_id / "project" await cwd.mkdir(parents=True, exist_ok=True) @@ -108,13 +108,13 @@ class MCPClient: logger.exception("[MCPClient] MCP %s:关闭失败", mcp_id) - async def init(self, user_sub: str | None, mcp_id: str, config: MCPServerSSEConfig | MCPServerStdioConfig) -> None: + async def init(self, user_id: str | None, mcp_id: str, config: MCPServerSSEConfig | MCPServerStdioConfig) -> None: """ 初始化 MCP Client类 初始化MCP Client,并创建MCP Server进程和ClientSession - :param str user_sub: 用户ID + :param str user_id: 用户ID :param str mcp_id: MCP ID :param MCPServerSSEConfig | MCPServerStdioConfig config: MCP配置 :return: None @@ -126,7 +126,7 @@ class MCPClient: self.stop_sign = asyncio.Event() # 创建协程 - self.task = asyncio.create_task(self._main_loop(user_sub, mcp_id, config)) + self.task = asyncio.create_task(self._main_loop(user_id, mcp_id, config)) # 等待初始化完成 done, pending = await asyncio.wait( diff --git a/apps/scheduler/pool/mcp/pool.py b/apps/scheduler/pool/mcp/pool.py index 8b798aa84..de64646a8 100644 --- a/apps/scheduler/pool/mcp/pool.py +++ b/apps/scheduler/pool/mcp/pool.py @@ -24,13 +24,13 @@ class _MCPPool: self.pool: dict[str, dict[str, MCPClient]] = {} - async def init_mcp(self, mcp_id: str, user_sub: str) -> MCPClient | None: + async def init_mcp(self, mcp_id: str, user_id: str) -> MCPClient | None: """初始化MCP池""" - config_path = MCP_USER_PATH / user_sub / mcp_id / "config.json" + config_path = MCP_USER_PATH / user_id / mcp_id / "config.json" flag = (await config_path.exists()) if not flag: - logger.warning("[MCPPool] 用户 %s 的MCP %s 配置文件不存在", user_sub, mcp_id) + logger.warning("[MCPPool] 用户 %s 的MCP %s 配置文件不存在", user_id, mcp_id) return None config = MCPServerConfig.model_validate_json(await config_path.read_text()) @@ -38,67 +38,67 @@ class _MCPPool: if config.mcpType in (MCPType.SSE, MCPType.STDIO): client = MCPClient() else: - logger.warning("[MCPPool] 用户 %s 的MCP %s 类型错误", user_sub, mcp_id) + logger.warning("[MCPPool] 用户 %s 的MCP %s 类型错误", user_id, mcp_id) return None - await client.init(user_sub, mcp_id, config.mcpServers[mcp_id]) - if user_sub not in self.pool: - self.pool[user_sub] = {} - self.pool[user_sub][mcp_id] = client + await client.init(user_id, mcp_id, config.mcpServers[mcp_id]) + if user_id not in self.pool: + self.pool[user_id] = {} + self.pool[user_id][mcp_id] = client return client - async def _get_from_dict(self, mcp_id: str, user_sub: str) -> MCPClient | None: + async def _get_from_dict(self, mcp_id: str, user_id: str) -> MCPClient | None: """从字典中获取MCP客户端""" - if user_sub not in self.pool: + if user_id not in self.pool: return None - if mcp_id not in self.pool[user_sub]: + if mcp_id not in self.pool[user_id]: return None - return self.pool[user_sub][mcp_id] + return self.pool[user_id][mcp_id] - async def _validate_user(self, mcp_id: str, user_sub: str) -> bool: + async def _validate_user(self, mcp_id: str, user_id: str) -> bool: """验证用户是否已激活""" async with postgres.session() as session: result = (await session.scalars( select(MCPActivated).where( and_( MCPActivated.mcpId == mcp_id, - MCPActivated.userSub == user_sub, + MCPActivated.userId == user_id, ), ).limit(1), )).one_or_none() return result is not None - async def get(self, mcp_id: str, user_sub: str) -> MCPClient | None: + async def get(self, mcp_id: str, user_id: str) -> MCPClient | None: """获取MCP客户端""" - item = await self._get_from_dict(mcp_id, user_sub) + item = await self._get_from_dict(mcp_id, user_id) if item is None: # 检查用户是否已激活 - if not await self._validate_user(mcp_id, user_sub): - logger.warning("用户 %s 未激活MCP %s", user_sub, mcp_id) + if not await self._validate_user(mcp_id, user_id): + logger.warning("用户 %s 未激活MCP %s", user_id, mcp_id) return None # 初始化进程 - item = await self.init_mcp(mcp_id, user_sub) + item = await self.init_mcp(mcp_id, user_id) if item is None: return None - if user_sub not in self.pool: - self.pool[user_sub] = {} + if user_id not in self.pool: + self.pool[user_id] = {} - self.pool[user_sub][mcp_id] = item + self.pool[user_id][mcp_id] = item return item - async def stop(self, mcp_id: str, user_sub: str) -> None: + async def stop(self, mcp_id: str, user_id: str) -> None: """停止MCP客户端""" - await self.pool[user_sub][mcp_id].stop() - del self.pool[user_sub][mcp_id] + await self.pool[user_id][mcp_id].stop() + del self.pool[user_id][mcp_id] # 创建单例对象 -- Gitee From 839d40ec89b8b2e67c2fd26642e080321e1f7049 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Fri, 24 Oct 2025 17:30:05 +0800 Subject: [PATCH 4/5] =?UTF-8?q?=E6=B8=85=E9=99=A4=E4=B8=8D=E7=94=A8?= =?UTF-8?q?=E7=9A=84=E6=95=B0=E6=8D=AE=E7=BB=93=E6=9E=84=EF=BC=8C=E4=BF=AE?= =?UTF-8?q?=E6=AD=A3=E6=95=B0=E6=8D=AE=E7=BB=93=E6=9E=84=E7=B1=BB=E5=BD=92?= =?UTF-8?q?=E7=B1=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/schemas/message.py | 20 ++------------------ apps/schemas/response_data.py | 29 ----------------------------- apps/schemas/scheduler.py | 2 +- apps/schemas/user.py | 33 ++++++++++++++++++++++++++++++++- 4 files changed, 35 insertions(+), 49 deletions(-) diff --git a/apps/schemas/message.py b/apps/schemas/message.py index 3b2b1706d..9f6dac514 100644 --- a/apps/schemas/message.py +++ b/apps/schemas/message.py @@ -27,7 +27,7 @@ class HeartbeatData(BaseModel): ) -class MessageFlow(BaseModel): +class MessageExecutor(BaseModel): """消息中有关Flow信息的部分""" app_id: uuid.UUID | None = Field(description="插件ID", alias="appId", default=None) @@ -48,22 +48,6 @@ class MessageMetadata(RecordMetadata): feature: None = None -class InitContentFeature(BaseModel): - """init消息的feature""" - - max_tokens: int = Field(description="最大生成token数", ge=0, alias="maxTokens") - context_num: int = Field(description="上下文消息数量", le=10, ge=0, alias="contextNum") - enable_feedback: bool = Field(description="是否启用反馈", alias="enableFeedback") - enable_regenerate: bool = Field(description="是否启用重新生成", alias="enableRegenerate") - - -class InitContent(BaseModel): - """init消息的content""" - - feature: InitContentFeature = Field(description="问答功能开关") - created_at: float = Field(description="创建时间", alias="createdAt") - - class TextAddContent(BaseModel): """text.add消息的content""" @@ -75,6 +59,6 @@ class MessageBase(HeartbeatData): id: uuid.UUID = Field(min_length=36, max_length=36) conversation_id: uuid.UUID | None = Field(min_length=36, max_length=36, alias="conversationId", default=None) - flow: MessageFlow | None = None + flow: MessageExecutor | None = None content: Any | None = Field(default=None, description="消息内容") metadata: MessageMetadata diff --git a/apps/schemas/response_data.py b/apps/schemas/response_data.py index f0580aa87..f39fdbe6a 100644 --- a/apps/schemas/response_data.py +++ b/apps/schemas/response_data.py @@ -17,7 +17,6 @@ from .parameters import ( Type, ) from .record import RecordData -from .user import UserInfo class ResponseData(BaseModel): @@ -28,21 +27,6 @@ class ResponseData(BaseModel): result: Any -class AuthUserMsg(BaseModel): - """GET /api/auth/user Result数据结构""" - - user_sub: str - revision: bool - is_admin: bool - auto_execute: bool - - -class AuthUserRsp(ResponseData): - """GET /api/auth/user 返回数据结构""" - - result: AuthUserMsg - - class HealthCheckRsp(BaseModel): """GET /health_check 返回数据结构""" @@ -87,19 +71,6 @@ class ListTeamKnowledgeRsp(ResponseData): result: ListTeamKnowledgeMsg -class UserGetMsp(BaseModel): - """GET /api/user result""" - - total: int = Field(default=0) - user_info_list: list[UserInfo] = Field(alias="userInfoList", default=[]) - - -class UserGetRsp(ResponseData): - """GET /api/user 返回数据结构""" - - result: UserGetMsp - - class LLMProviderInfo(BaseModel): """LLM数据结构""" diff --git a/apps/schemas/scheduler.py b/apps/schemas/scheduler.py index 98c34ef68..fa7cd8dd0 100644 --- a/apps/schemas/scheduler.py +++ b/apps/schemas/scheduler.py @@ -25,7 +25,7 @@ class CallIds(BaseModel): executor_id: str = Field(description="Flow ID") session_id: str | None = Field(description="当前用户的Session ID") app_id: uuid.UUID | None = Field(description="当前应用的ID") - user_sub: str = Field(description="当前用户的用户ID") + user_id: str = Field(description="当前用户的用户ID") conversation_id: uuid.UUID | None = Field(description="当前对话的ID") diff --git a/apps/schemas/user.py b/apps/schemas/user.py index 24de0aecb..df88992cb 100644 --- a/apps/schemas/user.py +++ b/apps/schemas/user.py @@ -3,9 +3,40 @@ from pydantic import BaseModel, Field +from .response_data import ResponseData -class UserInfo(BaseModel): + +class UserListItem(BaseModel): """用户信息数据结构""" user_sub: str = Field(alias="userSub", default="") user_name: str = Field(alias="userName", default="") + + +class UserInfoMsg(BaseModel): + """GET /api/user Result数据结构""" + + id: int + user_name: str = Field(alias="userName", default="") + is_admin: bool = Field(alias="isAdmin", default=False) + personal_token: str = Field(alias="personalToken", default="") + auto_execute: bool = Field(alias="autoExecute", default=False) + + +class UserInfoRsp(ResponseData): + """GET /api/user Result数据结构""" + + result: UserInfoMsg + + +class UserListMsg(BaseModel): + """GET /api/user/list Result数据结构""" + + total: int = Field(default=0) + user_info_list: list[UserListItem] = Field(alias="userInfoList", default=[]) + + +class UserListRsp(ResponseData): + """GET /api/user/list 返回数据结构""" + + result: UserListMsg -- Gitee From abd417837f4930daca9a12c4b1da47081e3b3c57 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Fri, 24 Oct 2025 17:30:37 +0800 Subject: [PATCH 5/5] =?UTF-8?q?=E6=8C=89=E5=8A=9F=E8=83=BD=E6=8B=86?= =?UTF-8?q?=E5=88=86Scheduler=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/scheduler/conversation.py | 35 ++ apps/scheduler/scheduler/data.py | 97 ++++ apps/scheduler/scheduler/executor.py | 220 +++++++++ apps/scheduler/scheduler/flow.py | 67 +++ apps/scheduler/scheduler/init.py | 150 ++++++ apps/scheduler/scheduler/message.py | 24 + apps/scheduler/scheduler/scheduler.py | 597 ++--------------------- 7 files changed, 622 insertions(+), 568 deletions(-) create mode 100644 apps/scheduler/scheduler/conversation.py create mode 100644 apps/scheduler/scheduler/data.py create mode 100644 apps/scheduler/scheduler/executor.py create mode 100644 apps/scheduler/scheduler/flow.py create mode 100644 apps/scheduler/scheduler/init.py create mode 100644 apps/scheduler/scheduler/message.py diff --git a/apps/scheduler/scheduler/conversation.py b/apps/scheduler/scheduler/conversation.py new file mode 100644 index 000000000..8885704f7 --- /dev/null +++ b/apps/scheduler/scheduler/conversation.py @@ -0,0 +1,35 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""对话管理相关的Mixin类""" + +import logging +import uuid + +from apps.models import Conversation +from apps.services.appcenter import AppCenterManager +from apps.services.conversation import ConversationManager + +_logger = logging.getLogger(__name__) + + +class ConversationMixin: + """处理对话管理相关的逻辑""" + + async def create_new_conversation( + self, title: str, user_id: str, app_id: uuid.UUID | None = None, + *, + debug: bool = False, + ) -> Conversation: + """判断并创建新对话""" + if app_id and not await AppCenterManager.validate_user_app_access(user_id, app_id): + err = "Invalid app_id." + raise RuntimeError(err) + new_conv = await ConversationManager.add_conversation_by_user( + title=title, + user_id=user_id, + app_id=app_id, + debug=debug, + ) + if not new_conv: + err = "Create new conversation failed." + raise RuntimeError(err) + return new_conv diff --git a/apps/scheduler/scheduler/data.py b/apps/scheduler/scheduler/data.py new file mode 100644 index 000000000..008f5a23d --- /dev/null +++ b/apps/scheduler/scheduler/data.py @@ -0,0 +1,97 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""数据管理相关的Mixin类""" + +import logging +from datetime import UTC, datetime + +from apps.common.security import Security +from apps.models import Record, RecordMetadata, StepStatus +from apps.schemas.record import FlowHistory, RecordContent +from apps.schemas.request_data import RequestData +from apps.schemas.task import TaskData +from apps.services.appcenter import AppCenterManager +from apps.services.document import DocumentManager +from apps.services.record import RecordManager +from apps.services.task import TaskManager + +_logger = logging.getLogger(__name__) + + +class DataMixin: + """处理数据保存和管理相关的逻辑""" + + task: TaskData + post_body: RequestData + + async def _save_data(self) -> None: + """保存当前Executor、Task、Record等的数据""" + task = self.task + user_id = self.task.metadata.userId + post_body = self.post_body + + used_docs = [] + record_group = None # TODO: 需要从适当的地方获取record_group + + if hasattr(task.runtime, "documents") and task.runtime.documents: + for docs in task.runtime.documents: + doc_dict = docs if isinstance(docs, dict) else (docs.model_dump() if hasattr(docs, "model_dump") else docs) + used_docs.append(doc_dict) + + record_content = RecordContent( + question=task.runtime.question if hasattr(task.runtime, "question") else "", + answer=task.runtime.answer if hasattr(task.runtime, "answer") else "", + facts=task.runtime.facts if hasattr(task.runtime, "facts") else [], + data={}, + ) + + try: + encrypt_data, encrypt_config = Security.encrypt(record_content.model_dump_json(by_alias=True)) + except Exception: + _logger.exception("[Scheduler] 问答对加密错误") + return + + if task.state: + await TaskManager.save_flow_context(task.context) + + current_time = round(datetime.now(UTC).timestamp(), 2) + record = Record( + id=task.metadata.id, + conversationId=task.metadata.conversationId, + taskId=task.metadata.id, + userId=user_id, + content=encrypt_data, + key=encrypt_config, + metadata=RecordMetadata( + timeCost=0, # TODO: 需要从task中获取时间成本 + inputTokens=0, # TODO: 需要从task中获取token信息 + outputTokens=0, # TODO: 需要从task中获取token信息 + feature={}, + ), + createdAt=current_time, + flow=FlowHistory( + flow_id=task.state.flow_id if task.state else "", + flow_name=task.state.flow_name if task.state else "", + flow_status=task.state.flow_status if task.state else StepStatus.SUCCESS, + history_ids=[context.id for context in task.context], + ) if task.state else None, + ) + + if record_group and post_body.conversation_id: + await DocumentManager.change_doc_status(user_id, post_body.conversation_id, record_group) + + if post_body.conversation_id: + await RecordManager.insert_record_data(user_id, post_body.conversation_id, record) + + if record_group and used_docs: + await DocumentManager.save_answer_doc(user_id, record_group, used_docs) + + if post_body.app and post_body.app.app_id: + await AppCenterManager.update_recent_app(user_id, post_body.app.app_id) + + if not task.state or task.state.flow_status in [StepStatus.SUCCESS, StepStatus.ERROR, StepStatus.CANCELLED]: + await TaskManager.delete_task_by_task_id(task.metadata.id) + else: + await TaskManager.save_task(task.metadata.id, task.metadata) + await TaskManager.save_task_runtime(task.runtime) + if task.state: + await TaskManager.save_executor_checkpoint(task.state) diff --git a/apps/scheduler/scheduler/executor.py b/apps/scheduler/scheduler/executor.py new file mode 100644 index 000000000..6aa616fbd --- /dev/null +++ b/apps/scheduler/scheduler/executor.py @@ -0,0 +1,220 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""执行器相关的Mixin类""" + +import asyncio +import logging +import uuid + +from apps.common.queue import MessageQueue +from apps.llm import LLMConfig +from apps.models import AppType, ExecutorStatus +from apps.scheduler.executor.agent import MCPAgentExecutor +from apps.scheduler.executor.flow import FlowExecutor +from apps.scheduler.executor.qa import QAExecutor +from apps.scheduler.pool.pool import pool +from apps.schemas.request_data import RequestData +from apps.schemas.task import TaskData +from apps.services.activity import Activity +from apps.services.appcenter import AppCenterManager +from apps.services.conversation import ConversationManager + +_logger = logging.getLogger(__name__) + + +class ExecutorMixin: + """处理执行器相关的逻辑""" + + task: TaskData + post_body: RequestData + queue: MessageQueue + llm: LLMConfig + + async def get_top_flow(self) -> str: + """获取Top1 Flow (由FlowMixin实现)""" + ... + + async def _determine_app_id(self) -> uuid.UUID | None: + """确定最终使用的 app_id""" + conversation = None + + if self.task.metadata.conversationId: + conversation = await ConversationManager.get_conversation_by_conversation_id( + self.task.metadata.userId, + self.task.metadata.conversationId, + ) + + if conversation and conversation.appId: + final_app_id = conversation.appId + _logger.info("[Scheduler] 使用Conversation中的appId: %s", final_app_id) + + if self.post_body.app and self.post_body.app.app_id and self.post_body.app.app_id != conversation.appId: + _logger.warning( + "[Scheduler] post_body中的app_id(%s)与Conversation中的appId(%s)不符,忽略post_body中的app信息", + self.post_body.app.app_id, + conversation.appId, + ) + self.post_body.app.app_id = conversation.appId + elif self.post_body.app and self.post_body.app.app_id: + final_app_id = self.post_body.app.app_id + _logger.info("[Scheduler] Conversation中无appId,使用post_body中的app_id: %s", final_app_id) + else: + final_app_id = None + _logger.info("[Scheduler] Conversation和post_body中均无appId,fallback到智能问答") + + return final_app_id + + async def _create_executor_task(self, final_app_id: uuid.UUID | None) -> asyncio.Task | None: + """ + 根据 app_id 创建对应的执行器任务 + + Args: + final_app_id: 要使用的 app_id,None 表示使用 QA 模式 + + Returns: + 创建的异步任务,如果发生错误则返回 None + + """ + if final_app_id is None: + _logger.info("[Scheduler] 运行智能问答模式") + await self._push_init_message(3, is_flow=False) + return asyncio.create_task(self._run_qa()) + + try: + app_data = await AppCenterManager.fetch_app_metadata_by_id(final_app_id) + except ValueError: + _logger.exception("[Scheduler] App %s 不存在或元数据文件缺失", final_app_id) + await self.queue.close() + return None + + context_num = app_data.history_len + _logger.info("[Scheduler] App上下文窗口: %d", context_num) + + if app_data.app_type == AppType.FLOW: + _logger.info("[Scheduler] 运行Flow应用") + await self._push_init_message(context_num, is_flow=True) + return asyncio.create_task(self._run_flow()) + + _logger.info("[Scheduler] 运行MCP Agent应用") + await self._push_init_message(context_num, is_flow=False) + return asyncio.create_task(self._run_agent()) + + async def _handle_task_cancellation(self, main_task: asyncio.Task) -> None: + """ + 处理任务取消的逻辑 + + Args: + main_task: 需要取消的主任务 + + """ + _logger.warning("[Scheduler] 用户取消执行,正在终止...") + main_task.cancel() + try: + await main_task + except asyncio.CancelledError: + _logger.info("[Scheduler] 主任务已取消") + except Exception: + _logger.exception("[Scheduler] 终止工作流时发生错误") + + if self.task.state and self.task.state.executorStatus in [ + ExecutorStatus.INIT, + ExecutorStatus.RUNNING, + ExecutorStatus.WAITING, + ]: + self.task.state.executorStatus = ExecutorStatus.CANCELLED + _logger.info("[Scheduler] ExecutorStatus已设置为CANCELLED") + elif self.task.state: + _logger.info("[Scheduler] ExecutorStatus为 %s,保持不变", self.task.state.executorStatus) + else: + _logger.warning("[Scheduler] task.state为None,无法更新ExecutorStatus") + + async def _monitor_activity(self, kill_event: asyncio.Event, user_id: str) -> None: + """监控用户活动状态,不活跃时终止工作流""" + try: + check_interval = 0.5 + + while not kill_event.is_set(): + is_active = await Activity.is_active(user_id) + + if not is_active: + _logger.warning("[Scheduler] 用户 %s 不活跃,终止工作流", user_id) + kill_event.set() + break + + await asyncio.sleep(check_interval) + except asyncio.CancelledError: + _logger.info("[Scheduler] 活动监控任务已取消") + except Exception: + _logger.exception("[Scheduler] 活动监控过程中发生错误") + kill_event.set() + + async def _run_qa(self) -> None: + """运行QA执行器""" + qa_executor = QAExecutor( + task=self.task, + msg_queue=self.queue, + question=self.post_body.question, + llm=self.llm, + ) + _logger.info("[Scheduler] 开始智能问答") + await qa_executor.init() + await qa_executor.run() + self.task = qa_executor.task + + async def _run_flow(self) -> None: + """运行Flow执行器""" + if not self.post_body.app or not self.post_body.app.app_id: + _logger.error("[Scheduler] 未选择应用") + return + + _logger.info("[Scheduler] 获取工作流元数据") + flow_info = await pool.get_flow_metadata(self.post_body.app.app_id) + + if not flow_info: + _logger.error("[Scheduler] 未找到工作流元数据") + return + + if not self.post_body.app.flow_id: + _logger.info("[Scheduler] 选择最合适的流") + flow_id = await self.get_top_flow() + else: + flow_id = self.post_body.app.flow_id + _logger.info("[Scheduler] 获取工作流定义") + flow_data = await pool.get_flow(self.post_body.app.app_id, flow_id) + + if not flow_data: + _logger.error("[Scheduler] 未找到工作流定义") + return + + flow_exec = FlowExecutor( + flow_id=flow_id, + flow=flow_data, + task=self.task, + msg_queue=self.queue, + question=self.post_body.question, + post_body_app=self.post_body.app, + llm=self.llm, + ) + + _logger.info("[Scheduler] 运行工作流执行器") + await flow_exec.init() + await flow_exec.run() + self.task = flow_exec.task + + async def _run_agent(self) -> None: + """构造Executor并执行""" + if not self.post_body.app or not self.post_body.app.app_id: + _logger.error("[Scheduler] 未选择MCP应用") + return + + agent_exec = MCPAgentExecutor( + task=self.task, + msg_queue=self.queue, + question=self.post_body.question, + agent_id=self.post_body.app.app_id, + params=self.post_body.app.params, + llm=self.llm, + ) + _logger.info("[Scheduler] 运行MCP执行器") + await agent_exec.init() + await agent_exec.run() + self.task = agent_exec.task diff --git a/apps/scheduler/scheduler/flow.py b/apps/scheduler/scheduler/flow.py new file mode 100644 index 000000000..c42b0ee96 --- /dev/null +++ b/apps/scheduler/scheduler/flow.py @@ -0,0 +1,67 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""Flow相关的Mixin类""" + +import logging + +from jinja2.sandbox import SandboxedEnvironment + +from apps.llm import json_generator +from apps.scheduler.pool.pool import pool +from apps.schemas.request_data import RequestData +from apps.schemas.scheduler import TopFlow +from apps.schemas.task import TaskData + +from .prompt import FLOW_SELECT + +_logger = logging.getLogger(__name__) + + +class FlowMixin: + """处理Flow相关的逻辑""" + + post_body: RequestData + task: TaskData + _env: SandboxedEnvironment + + async def get_top_flow(self) -> str: + """获取Top1 Flow""" + if not self.post_body.app or not self.post_body.app.app_id: + err = "[Scheduler] 未选择应用" + _logger.error(err) + raise RuntimeError(err) + + flow_list = await pool.get_flow_metadata(self.post_body.app.app_id) + if not flow_list: + err = "[Scheduler] 未找到应用中合法的Flow" + _logger.error(err) + raise RuntimeError(err) + + _logger.info("[Scheduler] 选择应用 %s 最合适的Flow", self.post_body.app.app_id) + choices = [{ + "name": flow.id, + "description": f"{flow.name}, {flow.description}", + } for flow in flow_list] + + template = self._env.from_string(FLOW_SELECT[self.task.runtime.language]) + prompt = template.render( + template, + question=self.post_body.question, + choice_list=choices, + ) + schema = TopFlow.model_json_schema() + schema["properties"]["choice"]["enum"] = [choice["name"] for choice in choices] + function = { + "name": "select_flow", + "description": "Select the appropriate flow", + "parameters": schema, + } + result_str = await json_generator.generate( + query=self.post_body.question, + function=function, + conversation=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + ) + result = TopFlow.model_validate(result_str) + return result.choice diff --git a/apps/scheduler/scheduler/init.py b/apps/scheduler/scheduler/init.py new file mode 100644 index 000000000..bdd53a3e9 --- /dev/null +++ b/apps/scheduler/scheduler/init.py @@ -0,0 +1,150 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""初始化相关的Mixin类""" + +import logging +import uuid + +from jinja2 import BaseLoader +from jinja2.sandbox import SandboxedEnvironment + +from apps.common.queue import MessageQueue +from apps.llm import LLM, LLMConfig, embedding +from apps.models import Task, TaskRuntime, User +from apps.schemas.request_data import RequestData +from apps.schemas.task import TaskData +from apps.services.conversation import ConversationManager +from apps.services.llm import LLMManager +from apps.services.task import TaskManager +from apps.services.user import UserManager + +_logger = logging.getLogger(__name__) + + +class InitializationMixin: + """处理Scheduler初始化相关的逻辑""" + + task: TaskData + user: User + post_body: RequestData + queue: MessageQueue + llm: LLMConfig + _env: SandboxedEnvironment + + async def _get_scheduler_llm(self, reasoning_llm_id: str) -> LLMConfig: + """获取RAG大模型""" + reasoning_llm = await LLMManager.get_llm(reasoning_llm_id) + if not reasoning_llm: + err = "[Scheduler] 获取问答用大模型ID失败" + _logger.error(err) + raise ValueError(err) + reasoning_llm = LLM(reasoning_llm) + + function_llm = None + if not self.user.functionLLM: + _logger.error("[Scheduler] 用户 %s 没有设置函数调用大模型,相关功能将被禁用", self.user.id) + else: + function_llm = await LLMManager.get_llm(self.user.functionLLM) + if not function_llm: + _logger.error( + "[Scheduler] 用户 %s 设置的函数调用大模型ID %s 不存在,相关功能将被禁用", + self.user.id, self.user.functionLLM, + ) + else: + function_llm = LLM(function_llm) + + embedding_obj = None + if not self.user.embeddingLLM: + _logger.error("[Scheduler] 用户 %s 没有设置向量模型,相关功能将被禁用", self.user.id) + else: + embedding_llm_config = await LLMManager.get_llm(self.user.embeddingLLM) + if not embedding_llm_config: + _logger.error( + "[Scheduler] 用户 %s 设置的向量模型ID %s 不存在,相关功能将被禁用", + self.user.id, self.user.embeddingLLM, + ) + else: + await embedding.init(embedding_llm_config) + embedding_obj = embedding + + return LLMConfig( + reasoning=reasoning_llm, + function=function_llm, + embedding=embedding_obj, + ) + + def _create_new_task(self, task_id: uuid.UUID, user_id: str, conversation_id: uuid.UUID | None) -> TaskData: + """创建新的TaskData""" + return TaskData( + metadata=Task( + id=task_id, + userId=user_id, + conversationId=conversation_id, + ), + runtime=TaskRuntime( + taskId=task_id, + ), + state=None, + context=[], + ) + + async def _init_task(self, task_id: uuid.UUID, user_id: str) -> None: + """初始化Task""" + conversation_id = self.post_body.conversation_id + + if not conversation_id: + _logger.info("[Scheduler] 无Conversation ID,直接创建新任务") + self.task = self._create_new_task(task_id, user_id, None) + return + + _logger.info("[Scheduler] 尝试从Conversation ID %s 恢复任务", conversation_id) + + restored = False + try: + conversation = await ConversationManager.get_conversation_by_conversation_id( + user_id, + conversation_id, + ) + + if conversation: + last_task = await TaskManager.get_task_by_conversation_id(conversation_id, user_id) + + if last_task and last_task.id: + _logger.info("[Scheduler] 从Conversation恢复任务 %s", last_task.id) + task_data = await TaskManager.get_task_data_by_task_id(last_task.id) + if task_data: + self.task = task_data + self.task.metadata.id = task_id + self.task.runtime.taskId = task_id + if self.task.state: + self.task.state.taskId = task_id + restored = True + else: + _logger.warning( + "[Scheduler] Conversation %s 不存在或无权访问,创建新任务", + conversation_id, + ) + except Exception: + _logger.exception("[Scheduler] 从Conversation恢复任务失败,创建新任务") + + if not restored: + _logger.info("[Scheduler] 无法恢复任务,创建新任务") + self.task = self._create_new_task(task_id, user_id, conversation_id) + + async def _init_user(self, user_id: str) -> None: + """初始化用户""" + user = await UserManager.get_user(user_id) + if not user: + err = f"[Scheduler] 用户 {user_id} 不存在" + _logger.error(err) + raise RuntimeError(err) + self.user = user + + def _init_jinja2_env(self) -> None: + """初始化Jinja2环境""" + self._env = SandboxedEnvironment( + loader=BaseLoader(), + autoescape=False, + trim_blocks=True, + lstrip_blocks=True, + extensions=["jinja2.ext.loopcontrols"], + ) diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py new file mode 100644 index 000000000..8a0f0bd96 --- /dev/null +++ b/apps/scheduler/scheduler/message.py @@ -0,0 +1,24 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""消息推送相关的Mixin类""" + +import logging + +from apps.common.queue import MessageQueue +from apps.llm import LLMConfig +from apps.schemas.enum_var import EventType +from apps.schemas.task import TaskData + +_logger = logging.getLogger(__name__) + + +class MessagingMixin: + """处理消息推送相关的逻辑""" + + queue: MessageQueue + task: TaskData + llm: LLMConfig + + async def _push_done_message(self) -> None: + """推送任务完成消息""" + _logger.info("[Scheduler] 发送结束消息") + await self.queue.push_output(self.task, self.llm, event_type=EventType.DONE.value, data={}) diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index 3a9f4a176..436a17e25 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -4,617 +4,78 @@ import asyncio import logging import uuid -from datetime import UTC, datetime - -from jinja2 import BaseLoader -from jinja2.sandbox import SandboxedEnvironment from apps.common.queue import MessageQueue -from apps.common.security import Security -from apps.llm import LLM, LLMConfig, embedding, json_generator -from apps.models import ( - AppType, - Conversation, - ExecutorStatus, - Record, - RecordMetadata, - StepStatus, - Task, - TaskRuntime, - User, -) -from apps.scheduler.executor.agent import MCPAgentExecutor -from apps.scheduler.executor.flow import FlowExecutor -from apps.scheduler.executor.qa import QAExecutor -from apps.scheduler.pool.pool import pool -from apps.schemas.enum_var import EventType -from apps.schemas.message import ( - InitContent, - InitContentFeature, -) -from apps.schemas.record import FlowHistory, RecordContent from apps.schemas.request_data import RequestData -from apps.schemas.scheduler import TopFlow -from apps.schemas.task import TaskData -from apps.services.activity import Activity -from apps.services.appcenter import AppCenterManager -from apps.services.conversation import ConversationManager -from apps.services.document import DocumentManager -from apps.services.llm import LLMManager -from apps.services.record import RecordManager -from apps.services.task import TaskManager -from apps.services.user import UserManager -from .prompt import FLOW_SELECT +from .conversation import ConversationMixin +from .data import DataMixin +from .executor import ExecutorMixin +from .flow import FlowMixin +from .init import InitializationMixin +from .message import MessagingMixin _logger = logging.getLogger(__name__) -class Scheduler: - """ - “调度器”,是最顶层的、控制Executor执行顺序和状态的逻辑。 - - Scheduler包含一个“SchedulerContext”,作用为多个Executor的“聊天会话” +class Scheduler( + InitializationMixin, + ExecutorMixin, + FlowMixin, + ConversationMixin, + DataMixin, + MessagingMixin, +): """ + "调度器",是最顶层的、控制Executor执行顺序和状态的逻辑。 - task: TaskData - llm: LLMConfig - queue: MessageQueue - post_body: RequestData - user: User + Scheduler包含一个"SchedulerContext",作用为多个Executor的"聊天会话" + 所有属性都继承自各个Mixin类,主要包括: + - task: TaskData (来自InitializationMixin) + - llm: LLMConfig (来自InitializationMixin) + - queue: MessageQueue (来自InitializationMixin) + - post_body: RequestData (来自InitializationMixin) + - user: User (来自InitializationMixin) + - _env: SandboxedEnvironment (来自InitializationMixin) + """ async def init( self, task_id: uuid.UUID, queue: MessageQueue, post_body: RequestData, - user_sub: str, + user_id: str, ) -> None: """初始化""" self.queue = queue self.post_body = post_body - # 获取用户 - user = await UserManager.get_user(user_sub) - if not user: - err = f"[Scheduler] 用户 {user_sub} 不存在" - _logger.error(err) - raise RuntimeError(err) - self.user = user - - # 初始化Task - await self._init_task(task_id, user_sub) - - # Jinja2 - self._env = SandboxedEnvironment( - loader=BaseLoader(), - autoescape=False, - trim_blocks=True, - lstrip_blocks=True, - extensions=["jinja2.ext.loopcontrols"], - ) - # LLM + await self._init_user(user_id) + await self._init_task(task_id, user_id) + self._init_jinja2_env() self.llm = await self._get_scheduler_llm(post_body.llm_id) - async def _push_done_message(self) -> None: - """推送任务完成消息""" - _logger.info("[Scheduler] 发送结束消息") - await self.queue.push_output(self.task, self.llm, event_type=EventType.DONE.value, data={}) - - async def _determine_app_id(self) -> uuid.UUID | None: - """ - 确定最终使用的 app_id - - Returns: - final_app_id: 最终使用的 app_id,如果为 None 则使用 QA 模式 - - """ - conversation = None - - if self.task.metadata.conversationId: - conversation = await ConversationManager.get_conversation_by_conversation_id( - self.task.metadata.userSub, - self.task.metadata.conversationId, - ) - - if conversation and conversation.appId: - # Conversation中有appId,使用它 - final_app_id = conversation.appId - _logger.info("[Scheduler] 使用Conversation中的appId: %s", final_app_id) - - # 如果post_body中也有app_id且与Conversation不符,忽略post_body中的app信息 - if self.post_body.app and self.post_body.app.app_id and self.post_body.app.app_id != conversation.appId: - _logger.warning( - "[Scheduler] post_body中的app_id(%s)与Conversation中的appId(%s)不符,忽略post_body中的app信息", - self.post_body.app.app_id, - conversation.appId, - ) - # 清空post_body中的app信息,以Conversation为准 - self.post_body.app.app_id = conversation.appId - elif self.post_body.app and self.post_body.app.app_id: - # Conversation中appId为None,使用post_body中的app信息 - final_app_id = self.post_body.app.app_id - _logger.info("[Scheduler] Conversation中无appId,使用post_body中的app_id: %s", final_app_id) - else: - # 两者都为None,fallback到QAExecutor - final_app_id = None - _logger.info("[Scheduler] Conversation和post_body中均无appId,fallback到智能问答") - - return final_app_id - - async def _create_executor_task(self, final_app_id: uuid.UUID | None) -> asyncio.Task | None: - """ - 根据 app_id 创建对应的执行器任务 - - Args: - final_app_id: 要使用的 app_id,None 表示使用 QA 模式 - - Returns: - 创建的异步任务,如果发生错误则返回 None - - """ - if final_app_id is None: - # 没有app相关信息,运行QAExecutor - _logger.info("[Scheduler] 运行智能问答模式") - await self._push_init_message(3, is_flow=False) - return asyncio.create_task(self._run_qa()) - - # 有app信息,获取app详情和元数据 - try: - app_data = await AppCenterManager.fetch_app_metadata_by_id(final_app_id) - except ValueError: - _logger.exception("[Scheduler] App %s 不存在或元数据文件缺失", final_app_id) - await self.queue.close() - return None - - # 获取上下文窗口并根据app类型决定执行器 - context_num = app_data.history_len - _logger.info("[Scheduler] App上下文窗口: %d", context_num) - - if app_data.app_type == AppType.FLOW: - _logger.info("[Scheduler] 运行Flow应用") - await self._push_init_message(context_num, is_flow=True) - return asyncio.create_task(self._run_flow()) - - _logger.info("[Scheduler] 运行MCP Agent应用") - await self._push_init_message(context_num, is_flow=False) - return asyncio.create_task(self._run_agent()) - - async def _handle_task_cancellation(self, main_task: asyncio.Task) -> None: - """ - 处理任务取消的逻辑 - - Args: - main_task: 需要取消的主任务 - - """ - _logger.warning("[Scheduler] 用户取消执行,正在终止...") - main_task.cancel() - try: - await main_task - except asyncio.CancelledError: - _logger.info("[Scheduler] 主任务已取消") - except Exception: - _logger.exception("[Scheduler] 终止工作流时发生错误") - - # 检查ExecutorState,若为init、running或waiting,将状态改为cancelled - if self.task.state and self.task.state.executorStatus in [ - ExecutorStatus.INIT, - ExecutorStatus.RUNNING, - ExecutorStatus.WAITING, - ]: - self.task.state.executorStatus = ExecutorStatus.CANCELLED - _logger.info("[Scheduler] ExecutorStatus已设置为CANCELLED") - elif self.task.state: - _logger.info("[Scheduler] ExecutorStatus为 %s,保持不变", self.task.state.executorStatus) - else: - _logger.warning("[Scheduler] task.state为None,无法更新ExecutorStatus") - - async def _monitor_activity(self, kill_event: asyncio.Event, user_sub: str) -> None: - """监控用户活动状态,不活跃时终止工作流""" - try: - check_interval = 0.5 # 每0.5秒检查一次 - - while not kill_event.is_set(): - # 检查用户活动状态 - is_active = await Activity.is_active(user_sub) - - if not is_active: - _logger.warning("[Scheduler] 用户 %s 不活跃,终止工作流", user_sub) - kill_event.set() - break - - # 控制检查频率 - await asyncio.sleep(check_interval) - except asyncio.CancelledError: - _logger.info("[Scheduler] 活动监控任务已取消") - except Exception: - _logger.exception("[Scheduler] 活动监控过程中发生错误") - kill_event.set() - - - async def _get_scheduler_llm(self, reasoning_llm_id: str) -> LLMConfig: - """获取RAG大模型""" - # 获取当前会话使用的大模型 - reasoning_llm = await LLMManager.get_llm(reasoning_llm_id) - if not reasoning_llm: - err = "[Scheduler] 获取问答用大模型ID失败" - _logger.error(err) - raise ValueError(err) - reasoning_llm = LLM(reasoning_llm) - - # 获取功能性的大模型信息 - function_llm = None - if not self.user.functionLLM: - _logger.error("[Scheduler] 用户 %s 没有设置函数调用大模型,相关功能将被禁用", self.user.userSub) - else: - function_llm = await LLMManager.get_llm(self.user.functionLLM) - if not function_llm: - _logger.error( - "[Scheduler] 用户 %s 设置的函数调用大模型ID %s 不存在,相关功能将被禁用", - self.user.userSub, self.user.functionLLM, - ) - else: - function_llm = LLM(function_llm) - - # 获取并设置全局embedding模型 - embedding_obj = None - if not self.user.embeddingLLM: - _logger.error("[Scheduler] 用户 %s 没有设置向量模型,相关功能将被禁用", self.user.userSub) - else: - embedding_llm_config = await LLMManager.get_llm(self.user.embeddingLLM) - if not embedding_llm_config: - _logger.error( - "[Scheduler] 用户 %s 设置的向量模型ID %s 不存在,相关功能将被禁用", - self.user.userSub, self.user.embeddingLLM, - ) - else: - # 设置全局embedding配置 - await embedding.init(embedding_llm_config) - embedding_obj = embedding - - return LLMConfig( - reasoning=reasoning_llm, - function=function_llm, - embedding=embedding_obj, - ) - - - async def get_top_flow(self) -> str: - """获取Top1 Flow""" - # 获取所选应用的所有Flow - if not self.post_body.app or not self.post_body.app.app_id: - err = "[Scheduler] 未选择应用" - _logger.error(err) - raise RuntimeError(err) - - flow_list = await pool.get_flow_metadata(self.post_body.app.app_id) - if not flow_list: - err = "[Scheduler] 未找到应用中合法的Flow" - _logger.error(err) - raise RuntimeError(err) - - _logger.info("[Scheduler] 选择应用 %s 最合适的Flow", self.post_body.app.app_id) - choices = [{ - "name": flow.id, - "description": f"{flow.name}, {flow.description}", - } for flow in flow_list] - - # 根据用户语言选择模板 - template = self._env.from_string(FLOW_SELECT[self.task.runtime.language]) - # 渲染模板 - prompt = template.render( - template, - question=self.post_body.question, - choice_list=choices, - ) - schema = TopFlow.model_json_schema() - schema["properties"]["choice"]["enum"] = [choice["name"] for choice in choices] - function = { - "name": "select_flow", - "description": "Select the appropriate flow", - "parameters": schema, - } - result_str = await json_generator.generate( - query=self.post_body.question, - function=function, - conversation=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt}, - ], - ) - result = TopFlow.model_validate(result_str) - return result.choice - - - async def create_new_conversation( - self, title: str, user_sub: str, app_id: uuid.UUID | None = None, - *, - debug: bool = False, - ) -> Conversation: - """判断并创建新对话""" - # 新建对话 - if app_id and not await AppCenterManager.validate_user_app_access(user_sub, app_id): - err = "Invalid app_id." - raise RuntimeError(err) - new_conv = await ConversationManager.add_conversation_by_user_sub( - title=title, - user_sub=user_sub, - app_id=app_id, - debug=debug, - ) - if not new_conv: - err = "Create new conversation failed." - raise RuntimeError(err) - return new_conv - - - def _create_new_task(self, task_id: uuid.UUID, user_sub: str, conversation_id: uuid.UUID | None) -> TaskData: - """创建新的TaskData""" - return TaskData( - metadata=Task( - id=task_id, - userSub=user_sub, - conversationId=conversation_id, - ), - runtime=TaskRuntime( - taskId=task_id, - ), - state=None, - context=[], - ) - - async def _init_task(self, task_id: uuid.UUID, user_sub: str) -> None: - """初始化Task""" - conversation_id = self.post_body.conversation_id - - # 若没有Conversation ID则直接创建task - if not conversation_id: - _logger.info("[Scheduler] 无Conversation ID,直接创建新任务") - self.task = self._create_new_task(task_id, user_sub, None) - return - - # 有ConversationID则尝试从ConversationID中恢复task - _logger.info("[Scheduler] 尝试从Conversation ID %s 恢复任务", conversation_id) - - # 尝试恢复任务 - restored = False - try: - # 先验证conversation是否存在且属于该用户 - conversation = await ConversationManager.get_conversation_by_conversation_id( - user_sub, - conversation_id, - ) - - if conversation: - # 尝试从Conversation中获取最后一个Task - last_task = await TaskManager.get_task_by_conversation_id(conversation_id, user_sub) - - # 如果能获取到task,则加载完整的TaskData - if last_task and last_task.id: - _logger.info("[Scheduler] 从Conversation恢复任务 %s", last_task.id) - task_data = await TaskManager.get_task_data_by_task_id(last_task.id) - if task_data: - self.task = task_data - # 更新task_id为新的task_id - self.task.metadata.id = task_id - self.task.runtime.taskId = task_id - if self.task.state: - self.task.state.taskId = task_id - restored = True - else: - _logger.warning( - "[Scheduler] Conversation %s 不存在或无权访问,创建新任务", - conversation_id, - ) - except Exception: - _logger.exception("[Scheduler] 从Conversation恢复任务失败,创建新任务") - - # 恢复不成功则新建task - if not restored: - _logger.info("[Scheduler] 无法恢复任务,创建新任务") - self.task = self._create_new_task(task_id, user_sub, conversation_id) - - async def run(self) -> None: """运行调度器""" _logger.info("[Scheduler] 开始执行") - # 创建用于通信的事件和监控任务 kill_event = asyncio.Event() - monitor = asyncio.create_task(self._monitor_activity(kill_event, self.task.metadata.userSub)) + monitor = asyncio.create_task(self._monitor_activity(kill_event, self.task.metadata.userId)) - # 确定最终使用的 app_id final_app_id = await self._determine_app_id() - # 根据 app_id 创建对应的执行器任务 main_task = await self._create_executor_task(final_app_id) if main_task is None: - # 创建任务失败(通常是因为 app 不存在),直接返回 return - # 等待任一任务完成 done, pending = await asyncio.wait( [main_task, monitor], return_when=asyncio.FIRST_COMPLETED, ) - # 如果用户手动终止,则取消主任务 if kill_event.is_set(): await self._handle_task_cancellation(main_task) - # 更新Task,发送结束消息 await self._push_done_message() - # 关闭Queue await self.queue.close() - - - async def _run_qa(self) -> None: - qa_executor = QAExecutor( - task=self.task, - msg_queue=self.queue, - question=self.post_body.question, - llm=self.llm, - ) - _logger.info("[Scheduler] 开始智能问答") - await qa_executor.init() - await qa_executor.run() - self.task = qa_executor.task - - - async def _run_flow(self) -> None: - # 获取应用信息 - if not self.post_body.app or not self.post_body.app.app_id: - _logger.error("[Scheduler] 未选择应用") - return - - _logger.info("[Scheduler] 获取工作流元数据") - flow_info = await pool.get_flow_metadata(self.post_body.app.app_id) - - # 如果flow_info为空,则直接返回 - if not flow_info: - _logger.error("[Scheduler] 未找到工作流元数据") - return - - # 如果用户选了特定的Flow - if not self.post_body.app.flow_id: - _logger.info("[Scheduler] 选择最合适的流") - flow_id = await self.get_top_flow() - else: - # 如果用户没有选特定的Flow,则根据语义选择一个Flow - flow_id = self.post_body.app.flow_id - _logger.info("[Scheduler] 获取工作流定义") - flow_data = await pool.get_flow(self.post_body.app.app_id, flow_id) - - # 如果flow_data为空,则直接返回 - if not flow_data: - _logger.error("[Scheduler] 未找到工作流定义") - return - - # 初始化Executor - flow_exec = FlowExecutor( - flow_id=flow_id, - flow=flow_data, - task=self.task, - msg_queue=self.queue, - question=self.post_body.question, - post_body_app=self.post_body.app, - llm=self.llm, - ) - - # 开始运行 - _logger.info("[Scheduler] 运行工作流执行器") - await flow_exec.init() - await flow_exec.run() - self.task = flow_exec.task - - - async def _run_agent(self) -> None: - """构造Executor并执行""" - # 获取应用信息 - if not self.post_body.app or not self.post_body.app.app_id: - _logger.error("[Scheduler] 未选择MCP应用") - return - - # 初始化Executor - agent_exec = MCPAgentExecutor( - task=self.task, - msg_queue=self.queue, - question=self.post_body.question, - agent_id=self.post_body.app.app_id, - params=self.post_body.app.params, - llm=self.llm, - ) - # 开始运行 - _logger.info("[Scheduler] 运行MCP执行器") - await agent_exec.init() - await agent_exec.run() - self.task = agent_exec.task - - - async def _save_data(self) -> None: - """保存当前Executor、Task、Record等的数据""" - task = self.task - user_sub = self.task.metadata.userSub - post_body = self.post_body - - # 构造文档列表 - used_docs = [] - record_group = None # TODO: 需要从适当的地方获取record_group - - # 处理文档 - if hasattr(task.runtime, "documents") and task.runtime.documents: - for docs in task.runtime.documents: - doc_dict = docs if isinstance(docs, dict) else (docs.model_dump() if hasattr(docs, "model_dump") else docs) - used_docs.append(doc_dict) - - # 组装RecordContent - record_content = RecordContent( - question=task.runtime.question if hasattr(task.runtime, "question") else "", - answer=task.runtime.answer if hasattr(task.runtime, "answer") else "", - facts=task.runtime.facts if hasattr(task.runtime, "facts") else [], - data={}, - ) - - try: - # 加密Record数据 - encrypt_data, encrypt_config = Security.encrypt(record_content.model_dump_json(by_alias=True)) - except Exception: - _logger.exception("[Scheduler] 问答对加密错误") - return - - # 保存Flow信息 - if task.state: - # 遍历查找数据,并添加 - await TaskManager.save_flow_context(task.context) - - # 整理Record数据 - current_time = round(datetime.now(UTC).timestamp(), 2) - record = Record( - id=task.metadata.id, # record_id - conversationId=task.metadata.conversationId, - taskId=task.metadata.id, - userSub=user_sub, - content=encrypt_data, - key=encrypt_config, - metadata=RecordMetadata( - timeCost=0, # TODO: 需要从task中获取时间成本 - inputTokens=0, # TODO: 需要从task中获取token信息 - outputTokens=0, # TODO: 需要从task中获取token信息 - feature={}, - ), - createdAt=current_time, - flow=FlowHistory( - flow_id=task.state.flow_id if task.state else "", - flow_name=task.state.flow_name if task.state else "", - flow_status=task.state.flow_status if task.state else StepStatus.SUCCESS, - history_ids=[context.id for context in task.context], - ) if task.state else None, - ) - - # 修改文件状态 - if record_group and post_body.conversation_id: - await DocumentManager.change_doc_status(user_sub, post_body.conversation_id, record_group) - - # 保存Record - if post_body.conversation_id: - await RecordManager.insert_record_data(user_sub, post_body.conversation_id, record) - - # 保存与答案关联的文件 - if record_group and used_docs: - await DocumentManager.save_answer_doc(user_sub, record_group, used_docs) - - if post_body.app and post_body.app.app_id: - # 更新最近使用的应用 - await AppCenterManager.update_recent_app(user_sub, post_body.app.app_id) - - # 若状态为成功,删除Task - if not task.state or task.state.flow_status in [StepStatus.SUCCESS, StepStatus.ERROR, StepStatus.CANCELLED]: - await TaskManager.delete_task_by_task_id(task.metadata.id) - else: - # 更新Task - await TaskManager.save_task(task.metadata.id, task.metadata) - await TaskManager.save_task_runtime(task.runtime) - if task.state: - await TaskManager.save_executor_checkpoint(task.state) -- Gitee