diff --git a/apps/models/user.py b/apps/models/user.py index 7d35555508f543a3d1dfa4091517e89ae6ce5021..ccd4a20d506c2d40d6e504117e9222c6c9d5a5a3 100644 --- a/apps/models/user.py +++ b/apps/models/user.py @@ -17,7 +17,7 @@ class User(Base): """用户表""" __tablename__ = "framework_user" - id: Mapped[str] = mapped_column(String(50), primary_key=True, index=True, init=False) + id: Mapped[str] = mapped_column(String(50), primary_key=True, index=True) """用户ID""" userName: Mapped[str] = mapped_column(String(50), index=True, unique=True, nullable=False) # noqa: N815 """用户名""" diff --git a/apps/routers/auth.py b/apps/routers/auth.py index 85c10e1a71433efaa904a99a6095c25ef00f0f6e..b74bf83e5c44ae7d8d6db3e10aea99395feaf9b1 100644 --- a/apps/routers/auth.py +++ b/apps/routers/auth.py @@ -2,7 +2,6 @@ """FastAPI 用户认证相关路由""" import logging -from datetime import UTC, datetime from pathlib import Path from typing import Annotated @@ -76,10 +75,8 @@ async def oidc_login(request: Request, code: str) -> HTMLResponse: {"request": request, "reason": "未能获取用户信息,请关闭本窗口并重试。"}, status_code=status.HTTP_403_FORBIDDEN, ) - # 更新用户信息 - await UserManager.update_user(user_sub, { - "lastLogin": datetime.now(UTC), - }) + # 创建或更新用户登录信息 + await UserManager.create_or_update_on_login(user_sub) # 创建会话 current_session = await SessionManager.create_session(user_sub, user_host) return _templates.TemplateResponse( diff --git a/apps/routers/blacklist.py b/apps/routers/blacklist.py index 351d8e3939141550594834aaa29a7c8cdd278a0b..02a6e93fabaaf4d6d5603ff398b21b635228c823 100644 --- a/apps/routers/blacklist.py +++ b/apps/routers/blacklist.py @@ -54,7 +54,7 @@ async def get_blacklist_user(page: int = 0) -> JSONResponse: return JSONResponse(status_code=status.HTTP_200_OK, content=GetBlacklistUserRsp( code=status.HTTP_200_OK, message="ok", - result=GetBlacklistUserMsg(user_subs=user_list), + result=GetBlacklistUserMsg(users=user_list), ).model_dump(exclude_none=True, by_alias=True)) @@ -64,13 +64,13 @@ async def change_blacklist_user(request: UserBlacklistRequest) -> JSONResponse: # 拉黑用户 if request.is_ban: result = await UserBlacklistManager.change_blacklisted_users( - request.user_sub, + request.user_id, -MAX_CREDIT, ) # 解除拉黑 else: result = await UserBlacklistManager.change_blacklisted_users( - request.user_sub, + request.user_id, MAX_CREDIT, ) @@ -144,7 +144,7 @@ async def change_blacklist_question(request: QuestionBlacklistRequest) -> JSONRe async def abuse_report(raw_request: Request, request: AbuseRequest) -> JSONResponse: """POST /blacklist/complaint: 用户实施举报""" result = await AbuseManager.change_abuse_report( - raw_request.state.user_sub, + raw_request.state.user_id, request.record_id, request.reason_type, request.reason, diff --git a/apps/routers/document.py b/apps/routers/document.py index f916dafca57817050d65604233513406f2e94e4a..aaae234b38eec168ca1d01de63ebe0fb66582f1c 100644 --- a/apps/routers/document.py +++ b/apps/routers/document.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """FastAPI文件上传路由""" +import logging import uuid from typing import Annotated @@ -21,6 +22,9 @@ from apps.schemas.response_data import ResponseData from apps.services.conversation import ConversationManager from apps.services.document import DocumentManager from apps.services.knowledge_service import KnowledgeBaseService +from apps.services.user import UserManager + +_logger = logging.getLogger(__name__) router = APIRouter( prefix="/api/document", @@ -63,6 +67,88 @@ async def document_upload( ) +async def _get_user_name_or_skip(user_id: str, document_id: uuid.UUID) -> str | None: + """获取用户名,如果用户不存在则记录警告并返回None""" + user = await UserManager.get_user(user_id) + if user is None: + _logger.warning( + "User not found for document %s (userId: %s), skipping document", + document_id, + user_id, + ) + return None + return user.userName + + +async def _process_used_documents(conversation_id: uuid.UUID) -> list[ConversationDocumentItem]: + """处理已使用的文档列表""" + result = [] + docs = await DocumentManager.get_used_docs(conversation_id) + + for item in docs: + user_name = await _get_user_name_or_skip(item.userId, item.id) + if user_name is None: + continue + + result.append( + ConversationDocumentItem( + id=item.id, + name=item.name, + type=item.extension, + size=round(item.size, 2), + status=DocumentStatus.USED, + created_at=item.createdAt, + user_name=user_name, + ), + ) + + return result + + +async def _process_unused_documents( + conversation_id: uuid.UUID, + session_id: str, +) -> list[ConversationDocumentItem]: + """处理未使用的文档列表""" + result = [] + unused_docs = await DocumentManager.get_unused_docs(conversation_id) + doc_status = await KnowledgeBaseService.get_doc_status_from_rag( + session_id, [item.id for item in unused_docs], + ) + + for current_doc in unused_docs: + for status_item in doc_status: + if current_doc.id != status_item.id: + continue + + # 确定文档状态 + if status_item.status == "success": + new_status = DocumentStatus.UNUSED + elif status_item.status == "failed": + new_status = DocumentStatus.FAILED + else: + new_status = DocumentStatus.PROCESSING + + # 获取用户名 + user_name = await _get_user_name_or_skip(current_doc.userId, current_doc.id) + if user_name is None: + continue + + result.append( + ConversationDocumentItem( + id=current_doc.id, + name=current_doc.name, + type=current_doc.extension, + size=round(current_doc.size, 2), + status=new_status, + created_at=current_doc.createdAt, + user_name=user_name, + ), + ) + + return result + + @router.get("/{conversation_id}", response_model=ConversationDocumentRsp) async def get_document_list( request: Request, conversation_id: Annotated[uuid.UUID, Path()], @@ -83,48 +169,10 @@ async def get_document_list( result = [] if used: - # 拿到所有已使用的文档 - docs = await DocumentManager.get_used_docs(conversation_id) - result += [ - ConversationDocumentItem( - id=item.id, - name=item.name, - type=item.extension, - size=round(item.size, 2), - status=DocumentStatus.USED, - created_at=item.createdAt, - ) - for item in docs - ] + result.extend(await _process_used_documents(conversation_id)) if unused: - # 拿到所有未使用的文档 - unused_docs = await DocumentManager.get_unused_docs(conversation_id) - doc_status = await KnowledgeBaseService.get_doc_status_from_rag( - request.state.session_id, [item.id for item in unused_docs], - ) - for current_doc in unused_docs: - for status_item in doc_status: - if current_doc.id != status_item.id: - continue - - if status_item.status == "success": - new_status = DocumentStatus.UNUSED - elif status_item.status == "failed": - new_status = DocumentStatus.FAILED - else: - new_status = DocumentStatus.PROCESSING - - result += [ - ConversationDocumentItem( - id=current_doc.id, - name=current_doc.name, - type=current_doc.extension, - size=round(current_doc.size, 2), - status=new_status, - created_at=current_doc.createdAt, - ), - ] + result.extend(await _process_unused_documents(conversation_id, request.state.session_id)) # 对外展示的时候用id,不用alias return JSONResponse( diff --git a/apps/routers/mcp_service.py b/apps/routers/mcp_service.py index b59d6aa006fb1b6604bd213e80158e6f448db2d7..6ecb1e940bf25acb7fd72ecc98c32d154a8eb60c 100644 --- a/apps/routers/mcp_service.py +++ b/apps/routers/mcp_service.py @@ -58,11 +58,11 @@ async def get_mcpservice_list( # noqa: PLR0913 isActive: bool | None = None, # noqa: N803 ) -> JSONResponse: """GET /mcp?searchType=xxx&keyword=xxx&page=xxx&isInstall=xxx&isActive=xxx: 获取服务列表""" - user_sub = request.state.user_sub + user_id = request.state.user_id try: service_cards = await MCPServiceManager.fetch_mcp_services( searchType, - user_sub, + user_id, keyword, page, is_install=isInstall, @@ -101,7 +101,7 @@ async def create_or_update_mcpservice( """PUT /mcp: 新建或更新MCP服务""" if not data.mcp_id: try: - service_id = await MCPServiceManager.create_mcpservice(data, request.state.user_sub) + service_id = await MCPServiceManager.create_mcpservice(data, request.state.user_id) except Exception as e: _logger.exception("[MCPServiceCenter] MCP服务创建失败") return JSONResponse( @@ -114,7 +114,7 @@ async def create_or_update_mcpservice( ) else: try: - service_id = await MCPServiceManager.update_mcpservice(data, request.state.user_sub) + service_id = await MCPServiceManager.update_mcpservice(data, request.state.user_id) except Exception as e: _logger.exception("[MCPService] 更新MCP服务失败") return JSONResponse( @@ -144,7 +144,7 @@ async def install_mcp_service( ) -> JSONResponse: """安装MCP服务""" try: - await MCPServiceManager.install_mcpservice(request.state.user_sub, service_id, install=install) + await MCPServiceManager.install_mcpservice(request.state.user_id, service_id, install=install) except Exception as e: err = f"[MCPService] 安装mcp服务失败: {e!s}" if install else f"[MCPService] 卸载mcp服务失败: {e!s}" _logger.exception(err) @@ -314,9 +314,9 @@ async def active_or_deactivate_mcp_service( """激活/取消激活mcp""" try: if data.active: - await MCPServiceManager.active_mcpservice(request.state.user_sub, mcpId, data.mcp_env) + await MCPServiceManager.active_mcpservice(request.state.user_id, mcpId, data.mcp_env) else: - await MCPServiceManager.deactive_mcpservice(request.state.user_sub, mcpId) + await MCPServiceManager.deactive_mcpservice(request.state.user_id, mcpId) except Exception as e: err = f"[MCPService] 激活mcp服务失败: {e!s}" if data.active else f"[MCPService] 取消激活mcp服务失败: {e!s}" _logger.exception(err) diff --git a/apps/routers/service.py b/apps/routers/service.py index 46e43b154590cc5aeffbce264482ee36a900de46..d39dcd30fbeb9c2bf3f14d65d48b8560f4d9f0b7 100644 --- a/apps/routers/service.py +++ b/apps/routers/service.py @@ -75,7 +75,7 @@ async def get_service_list( # noqa: PLR0913 try: if createdByMe: # 筛选我创建的 service_cards, total_count = await ServiceCenterManager.fetch_user_services( - request.state.user_sub, + request.state.user_id, searchType, keyword, page, @@ -83,7 +83,7 @@ async def get_service_list( # noqa: PLR0913 ) elif favorited: # 筛选我收藏的 service_cards, total_count = await ServiceCenterManager.fetch_favorite_services( - request.state.user_sub, + request.state.user_id, searchType, keyword, page, @@ -91,7 +91,7 @@ async def get_service_list( # noqa: PLR0913 ) else: # 获取所有服务 service_cards, total_count = await ServiceCenterManager.fetch_all_services( - request.state.user_sub, + request.state.user_id, searchType, keyword, page, @@ -135,7 +135,7 @@ async def update_service(request: Request, data: UpdateServiceRequest) -> JSONRe """POST /service: 上传并解析服务""" if not data.service_id: try: - service_id = await ServiceCenterManager.create_service(request.state.user_sub, data.data) + service_id = await ServiceCenterManager.create_service(request.state.user_id, data.data) except Exception as e: _logger.exception("[ServiceCenter] 创建服务失败") return JSONResponse( @@ -148,7 +148,7 @@ async def update_service(request: Request, data: UpdateServiceRequest) -> JSONRe ) else: try: - service_id = await ServiceCenterManager.update_service(request.state.user_sub, data.service_id, data.data) + service_id = await ServiceCenterManager.update_service(request.state.user_id, data.service_id, data.data) except InstancePermissionError: return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, @@ -194,7 +194,7 @@ async def get_service_detail( # 示例:返回指定服务的详情 if edit: try: - name, data = await ServiceCenterManager.get_service_data(request.state.user_sub, serviceId) + name, data = await ServiceCenterManager.get_service_data(request.state.user_id, serviceId) except InstancePermissionError: return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, @@ -237,7 +237,7 @@ async def get_service_detail( async def delete_service(request: Request, serviceId: Annotated[uuid.UUID, Path()]) -> JSONResponse: # noqa: N803 """DELETE /service/{serviceId}: 删除服务""" try: - await ServiceCenterManager.delete_service(request.state.user_sub, serviceId) + await ServiceCenterManager.delete_service(request.state.user_id, serviceId) except InstancePermissionError: return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, @@ -271,7 +271,7 @@ async def modify_favorite_service( """PUT /service/{serviceId}: 修改服务收藏状态""" try: success = await ServiceCenterManager.modify_favorite_service( - request.state.user_sub, + request.state.user_id, serviceId, favorited=data.favorited, ) diff --git a/apps/routers/user.py b/apps/routers/user.py index df7b9f5509a2aa7143bc92a6f4353b0324c1d5b1..4460cc5d1002b34ccfc7a2f16d21e8ba68942493 100644 --- a/apps/routers/user.py +++ b/apps/routers/user.py @@ -4,11 +4,12 @@ from fastapi import APIRouter, Depends, Request, status from fastapi.responses import JSONResponse +from apps.common.config import config from apps.dependency import verify_personal_token, verify_session from apps.schemas.request_data import UserUpdateRequest from apps.schemas.response_data import ResponseData from apps.schemas.tag import UserTagListResponse -from apps.schemas.user import UserListItem, UserListMsg, UserListRsp +from apps.schemas.user import UserInfoMsg, UserInfoRsp, UserListItem, UserListMsg, UserListRsp from apps.services.user import UserManager from apps.services.user_tag import UserTagManager @@ -23,7 +24,13 @@ router = APIRouter( async def update_user_info(request: Request, data: UserUpdateRequest) -> JSONResponse: """POST /api/user: 更新当前用户信息""" # 更新用户信息 - await UserManager.update_user(request.state.user_sub, data.model_dump(exclude_unset=True, exclude_none=True)) + try: + await UserManager.update_user_info(request.state.user_id, data) + except ValueError as e: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={"code": status.HTTP_404_NOT_FOUND, "message": str(e)}, + ) return JSONResponse( status_code=status.HTTP_200_OK, @@ -31,14 +38,36 @@ async def update_user_info(request: Request, data: UserUpdateRequest) -> JSONRes ) -# TODO -@router.get("") +@router.get("", response_model=UserInfoRsp) async def get_user_info(request: Request) -> JSONResponse: - """GET /api/user/info: 获取当前用户信息""" - user = await UserManager.get_user(request.state.user_sub) + """GET /api/user: 获取当前用户信息""" + user = await UserManager.get_user(request.state.user_id) + + if not user: + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content=ResponseData( + code=status.HTTP_404_NOT_FOUND, + message="用户不存在", + result=None, + ).model_dump(exclude_none=True, by_alias=True), + ) + + user_info = UserInfoMsg( + userId=user.id, + userName=user.userName, + isAdmin=user.id in config.login.admin_user, + personalToken=user.personalToken, + autoExecute=user.autoExecute or False, + ) + return JSONResponse( status_code=status.HTTP_200_OK, - content=ResponseData(code=status.HTTP_200_OK, message="success", result=user), + content=UserInfoRsp( + code=status.HTTP_200_OK, + message="用户信息获取成功", + result=user_info, + ).model_dump(exclude_none=True, by_alias=True), ) @@ -50,11 +79,11 @@ async def list_user( user_list, total = await UserManager.list_user(page_size, page_num) user_info_list = [] for user in user_list: - if user.userSub == request.state.user_sub: + if user.id == request.state.user_id: continue info = UserListItem( - userName=user.userSub, - userSub=user.userSub, + userName=user.userName, + userId=user.id, ) user_info_list.append(info) @@ -77,7 +106,7 @@ async def get_user_tag( ) -> JSONResponse: """GET /user/tag?topk=5: 获取用户标签""" try: - tags = await UserTagManager.get_user_domain_by_user_and_topk(request.state.user_sub, topk) + tags = await UserTagManager.get_user_domain_by_user_and_topk(request.state.user_id, topk) except ValueError as e: return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, diff --git a/apps/scheduler/scheduler/executor.py b/apps/scheduler/scheduler/executor.py index 6aa616fbd77d15d1d0ab136f60e65eefd00e574b..bfae212bbb20d4312495334a27f97c566b8ca8ab 100644 --- a/apps/scheduler/scheduler/executor.py +++ b/apps/scheduler/scheduler/executor.py @@ -76,7 +76,6 @@ class ExecutorMixin: """ if final_app_id is None: _logger.info("[Scheduler] 运行智能问答模式") - await self._push_init_message(3, is_flow=False) return asyncio.create_task(self._run_qa()) try: @@ -91,11 +90,9 @@ class ExecutorMixin: if app_data.app_type == AppType.FLOW: _logger.info("[Scheduler] 运行Flow应用") - await self._push_init_message(context_num, is_flow=True) return asyncio.create_task(self._run_flow()) _logger.info("[Scheduler] 运行MCP Agent应用") - await self._push_init_message(context_num, is_flow=False) return asyncio.create_task(self._run_agent()) async def _handle_task_cancellation(self, main_task: asyncio.Task) -> None: diff --git a/apps/schemas/blacklist.py b/apps/schemas/blacklist.py index 5e53393c45a412593e7af3b64ccccbbed44e7d97..093d066ae9d56aaf72d7427747170a9e6889206f 100644 --- a/apps/schemas/blacklist.py +++ b/apps/schemas/blacklist.py @@ -20,7 +20,7 @@ class QuestionBlacklistRequest(BaseModel): class UserBlacklistRequest(BaseModel): """POST /api/blacklist/user 请求数据结构""" - user_sub: str + user_id: str is_ban: int @@ -65,10 +65,17 @@ class BlacklistSchema(BaseModel): from_attributes = True +class BlacklistedUser(BaseModel): + """被封禁用户的数据结构""" + + user_id: str + user_name: str + + class GetBlacklistUserMsg(BaseModel): """GET /api/blacklist/user Result数据结构""" - user_subs: list[str] + users: list[BlacklistedUser] class GetBlacklistUserRsp(ResponseData): diff --git a/apps/schemas/document.py b/apps/schemas/document.py index b3a643d686650ed818d2c7483f324600d9e3ad0e..b92dd8289ac6106252e674c692ec15ad1ae1a721 100644 --- a/apps/schemas/document.py +++ b/apps/schemas/document.py @@ -17,7 +17,7 @@ class BaseDocumentItem(BaseModel): type: str = Field(default="", description="文档类型") size: float = Field(default=0.0, description="文档大小") created_at: datetime | None = Field(default=None, description="创建时间") - user_sub: str | None = None + user_name: str | None = Field(default=None, description="上传者用户名") conversation_id: str | None = None class Config: diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index 6d0657fb57e09d40d155a6caddbf2ed570523140..834afec083204f4b8b9e50f81783a07e853bccea 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -76,5 +76,4 @@ class UpdateUserKnowledgebaseReq(BaseModel): class UserUpdateRequest(BaseModel): """更新用户信息请求体""" - user_name: str | None = Field(default=None, description="用户名", alias="userName") auto_execute: bool = Field(default=False, description="是否自动执行", alias="autoExecute") diff --git a/apps/schemas/user.py b/apps/schemas/user.py index df88992cb6c9689117c8cf913109379aec484a35..4401ccf469201dad274374222c8e751c4bbb5a8c 100644 --- a/apps/schemas/user.py +++ b/apps/schemas/user.py @@ -9,14 +9,14 @@ from .response_data import ResponseData class UserListItem(BaseModel): """用户信息数据结构""" - user_sub: str = Field(alias="userSub", default="") + user_id: str = Field(alias="userId", default="") user_name: str = Field(alias="userName", default="") class UserInfoMsg(BaseModel): """GET /api/user Result数据结构""" - id: int + id: str = Field(alias="userId", default="") user_name: str = Field(alias="userName", default="") is_admin: bool = Field(alias="isAdmin", default=False) personal_token: str = Field(alias="personalToken", default="") diff --git a/apps/services/blacklist.py b/apps/services/blacklist.py index d80ded572029aaad9b8c782f1e13c768043fcd35..a8024b01a32061189200082d1c4a50835bfda494 100644 --- a/apps/services/blacklist.py +++ b/apps/services/blacklist.py @@ -9,6 +9,7 @@ from sqlalchemy import and_, select from apps.common.postgres import postgres from apps.common.security import Security from apps.models import Blacklist, Record, User +from apps.schemas.blacklist import BlacklistedUser from apps.schemas.record import RecordContent logger = logging.getLogger(__name__) @@ -77,23 +78,26 @@ class UserBlacklistManager: """用户黑名单相关操作""" @staticmethod - async def get_blacklisted_users(limit: int, offset: int) -> list[str]: + async def get_blacklisted_users(limit: int, offset: int) -> list[BlacklistedUser]: """获取当前所有黑名单用户""" async with postgres.session() as session: - return list((await session.scalars(select(User.userSub).where( + users = (await session.scalars(select(User).where( User.credit <= 0, - ).order_by(User.userSub).offset(offset).limit(limit))).all()) + ).order_by(User.userName).offset(offset).limit(limit))).all() + return [BlacklistedUser(user_id=user.id, user_name=user.userName) for user in users] @staticmethod - async def check_blacklisted_users(user_sub: str) -> bool: + async def check_blacklisted_users(user_id: str) -> bool: """检测某用户是否已被拉黑""" async with postgres.session() as session: + conditions = [ + User.credit <= 0, + User.isWhitelisted == False, # noqa: E712 + User.id == user_id, + ] + user = (await session.scalars(select(User).where( - and_( - User.userSub == user_sub, - User.credit <= 0, - User.isWhitelisted == False, # noqa: E712 - ), + and_(*conditions), ).limit(1))).one_or_none() if user: @@ -102,11 +106,11 @@ class UserBlacklistManager: return False @staticmethod - async def change_blacklisted_users(user_sub: str, credit_diff: int, credit_limit: int = 100) -> bool: + async def change_blacklisted_users(user_id: str, credit_diff: int, credit_limit: int = 100) -> bool: """修改用户的信用分""" async with postgres.session() as session: # 获取用户当前信用分 - user = (await session.scalars(select(User).where(User.userSub == user_sub))).one_or_none() + user = (await session.scalars(select(User).where(User.id == user_id))).one_or_none() # 用户不存在 if user is None: @@ -143,12 +147,18 @@ class AbuseManager: """用户举报相关操作""" @staticmethod - async def change_abuse_report(user_sub: str, record_id: uuid.UUID, reason_type: str, reason: str) -> bool: + async def change_abuse_report(user_id: str, record_id: uuid.UUID, reason_type: str, reason: str) -> bool: """存储用户举报详情""" async with postgres.session() as session: + # Verify user exists + user = (await session.scalars(select(User).where(User.id == user_id))).one_or_none() + if not user: + logger.info("[AbuseManager] 用户不存在") + return False + record = (await session.scalars(select(Record).where( and_( - Record.userSub == user_sub, + Record.userId == user_id, Record.id == record_id, ), ))).one_or_none() diff --git a/apps/services/document.py b/apps/services/document.py index 70ed9c409e09b6e7d661764b8b0dd99ceb605c7a..94b8818a28197d435329bbed67ede1af2b99ae02 100644 --- a/apps/services/document.py +++ b/apps/services/document.py @@ -56,7 +56,7 @@ class DocumentManager: @staticmethod async def storage_docs( - user_sub: str, conversation_id: uuid.UUID, documents: list[UploadFile], + user_id: str, conversation_id: uuid.UUID, documents: list[UploadFile], ) -> list[Document]: """存储多个文件""" uploaded_files = [] @@ -75,7 +75,7 @@ class DocumentManager: # 保存到数据库 doc_info = Document( - userSub=user_sub, + userId=user_id, name=document.filename, extension=mime, size=document.size / 1024.0, @@ -176,7 +176,7 @@ class DocumentManager: @staticmethod - async def delete_document(user_sub: str, document_list: list[str]) -> None: + async def delete_document(user_id: str, document_list: list[str]) -> None: """从未使用文件列表中删除一个文件""" async with postgres.session() as session: for doc in document_list: @@ -184,7 +184,7 @@ class DocumentManager: select(Document).where( and_( Document.id == doc, - Document.userSub == user_sub, + Document.userId == user_id, ), ), ) @@ -213,7 +213,7 @@ class DocumentManager: @staticmethod - async def delete_document_by_conversation_id(user_sub: str, conversation_id: uuid.UUID) -> list[str]: + async def delete_document_by_conversation_id(user_id: str, conversation_id: uuid.UUID) -> list[str]: """通过ConversationID删除文件""" doc_ids = [] @@ -227,9 +227,9 @@ class DocumentManager: await session.delete(doc) await session.commit() - session_id = await SessionManager.get_session_by_user_sub(user_sub) + session_id = await SessionManager.get_session_by_user(user_id) if not session_id: - logger.error("[DocumentManager] Session不存在: %s", user_sub) + logger.error("[DocumentManager] Session不存在: %s", user_id) return [] await KnowledgeBaseService.delete_doc_from_rag(session_id, doc_ids) return doc_ids @@ -247,14 +247,14 @@ class DocumentManager: @staticmethod - async def change_doc_status(user_sub: str, conversation_id: uuid.UUID) -> None: + async def change_doc_status(user_id: str, conversation_id: uuid.UUID) -> None: """文件状态由unused改为used""" async with postgres.session() as session: conversation = (await session.scalars( select(Conversation).where( and_( Conversation.id == conversation_id, - Conversation.userSub == user_sub, + Conversation.userId == user_id, ), ), )).one_or_none() @@ -280,7 +280,7 @@ class DocumentManager: @staticmethod - async def save_answer_doc(user_sub: str, record_id: uuid.UUID, doc_infos: list[ConversationDocument]) -> None: + async def save_answer_doc(user_id: str, record_id: uuid.UUID, doc_infos: list[ConversationDocument]) -> None: """保存与答案关联的文件(使用PostgreSQL)""" async with postgres.session() as session: # 查询对应的Record @@ -288,7 +288,7 @@ class DocumentManager: select(Record).where( and_( Record.id == record_id, - Record.userSub == user_sub, + Record.userId == user_id, ), ), )).one_or_none() diff --git a/apps/services/mcp_service.py b/apps/services/mcp_service.py index 43f3653ab4c5a20f065a77579def37be7a82f2b3..8dfc50878f660ba777f81dbc74e42ec9f0d7fd5a 100644 --- a/apps/services/mcp_service.py +++ b/apps/services/mcp_service.py @@ -36,6 +36,7 @@ from apps.schemas.mcp import ( UpdateMCPServiceRequest, ) from apps.schemas.mcp_service import MCPServiceCardItem +from apps.services.user import UserManager logger = logging.getLogger(__name__) MCP_ICON_PATH = ICON_PATH / "mcp" @@ -46,13 +47,7 @@ class MCPServiceManager: @staticmethod async def is_active(user_id: str, mcp_id: str) -> bool: - """ - 判断用户是否激活MCP - - :param user_id: str: 用户ID - :param mcp_id: str: MCP服务ID - :return: 是否激活 - """ + """判断用户是否激活MCP""" async with postgres.session() as session: mcp_info = (await session.scalars(select(MCPActivated).where( and_( @@ -65,12 +60,7 @@ class MCPServiceManager: @staticmethod async def get_icon_path(mcp_id: str) -> str: - """ - 获取MCP服务图标路径 - - :param mcp_id: str: MCP服务ID - :return: 图标路径 - """ + """获取MCP服务图标路径""" if (MCP_ICON_PATH / f"{mcp_id}.png").exists(): return f"/static/mcp/{mcp_id}.png" return "" @@ -78,12 +68,7 @@ class MCPServiceManager: @staticmethod async def get_service_status(mcp_id: str) -> MCPInstallStatus: - """ - 获取MCP服务状态 - - :param str mcp_id: MCP服务ID - :return: MCP服务状态 - """ + """获取MCP服务状态""" async with postgres.session() as session: mcp_info = (await session.scalars(select(MCPInfo).where(MCPInfo.id == mcp_id))).one_or_none() if mcp_info: @@ -101,27 +86,21 @@ class MCPServiceManager: is_install: bool | None = None, is_active: bool | None = None, ) -> list[MCPServiceCardItem]: - """ - 获取所有MCP服务列表 - - :param search_type: SearchType: str: MCP搜索类型 - :param user_id: str: 用户ID - :param keyword: str: MCP搜索关键字 - :param page: int: 页码 - :param is_install: bool | None: 是否已安装 - :param is_active: bool | None: 是否已激活 - :return: MCP服务列表 - """ + """获取所有MCP服务列表""" mcpservice_pools = await MCPServiceManager._search_mcpservice( search_type, keyword, page, user_id, is_active=is_active, is_installed=is_install, ) + + author_ids = {item.authorId for item in mcpservice_pools} + author_names = await UserManager.get_usernames_by_ids(author_ids) + return [ MCPServiceCardItem( mcpserviceId=item.id, icon=await MCPServiceManager.get_icon_path(item.id), name=item.name, description=item.description, - author=item.authorId, + author=author_names.get(item.authorId, item.authorId), isActive=await MCPServiceManager.is_active(user_id, item.id), status=await MCPServiceManager.get_service_status(item.id), ) @@ -131,24 +110,14 @@ class MCPServiceManager: @staticmethod async def get_mcp_service(mcp_id: str) -> MCPInfo | None: - """ - 获取MCP服务详细信息 - - :param mcp_id: str: MCP服务ID - :return: MCP服务详细信息 - """ + """获取MCP服务详细信息""" async with postgres.session() as session: return (await session.scalars(select(MCPInfo).where(MCPInfo.id == mcp_id))).one_or_none() @staticmethod async def get_mcp_config(mcp_id: str) -> tuple[MCPServerConfig, str]: - """ - 获取MCP服务配置 - - :param mcp_id: str: MCP服务ID - :return: MCP服务配置 - """ + """获取MCP服务配置""" icon_path = "" config = await MCPLoader.get_config(mcp_id) return config, icon_path @@ -156,16 +125,58 @@ class MCPServiceManager: @staticmethod async def get_mcp_tools(mcp_id: str) -> list[MCPTools]: - """ - 获取MCP可用工具 - - :param mcp_id: str: MCP服务ID - :return: MCP工具详细信息列表 - """ + """获取MCP可用工具""" async with postgres.session() as session: return list((await session.scalars(select(MCPTools).where(MCPTools.mcpId == mcp_id))).all()) + @staticmethod + async def _build_search_conditions( + search_type: SearchType, + keyword: str | None, + author_user_ids: list[str], + ) -> Any | None: + """构建搜索条件""" + if search_type == SearchType.ALL: + conditions = [ + MCPInfo.name.like(f"%{keyword}%"), + MCPInfo.description.like(f"%{keyword}%"), + ] + if author_user_ids: + conditions.append(MCPInfo.authorId.in_(author_user_ids)) + return or_(*conditions) + + if search_type == SearchType.NAME: + return MCPInfo.name.like(f"%{keyword}%") + + if search_type == SearchType.DESCRIPTION: + return MCPInfo.description.like(f"%{keyword}%") + + if author_user_ids: + return MCPInfo.authorId.in_(author_user_ids) + return None + + + @staticmethod + def _apply_user_filters( + sql: Any, + user_id: str, + *, + is_active: bool | None = None, + is_installed: bool | None = None, + ) -> Any: + """应用用户过滤条件""" + if is_installed is not None: + sql = sql.where(MCPInfo.id.in_( + select(MCPActivated.mcpId).where(MCPActivated.userId == user_id), + )) + if is_active is not None: + sql = sql.where(MCPInfo.id.in_( + select(MCPActivated.mcpId).where(MCPActivated.userId == user_id), + )) + return sql + + @staticmethod async def _search_mcpservice( # noqa: PLR0913 search_type: SearchType, @@ -176,69 +187,42 @@ class MCPServiceManager: is_active: bool | None = None, is_installed: bool | None = None, ) -> list[MCPInfo]: - """ - 基于输入条件搜索MCP服务 - - :param search_conditions: dict[str, Any]: 搜索条件 - :param page: int: 页码 - :return: MCP列表 - """ - # 分页查询 + """基于输入条件搜索MCP服务""" skip = (page - 1) * SERVICE_PAGE_SIZE - async with postgres.session() as session: - sql = select(MCPInfo) - - if search_type == SearchType.ALL: - sql = sql.where( - or_( - MCPInfo.name.like(f"%{keyword}%"), - MCPInfo.description.like(f"%{keyword}%"), - MCPInfo.authorId.like(f"%{keyword}%"), - ), - ) - elif search_type == SearchType.NAME: - sql = sql.where(MCPInfo.name.like(f"%{keyword}%")) - elif search_type == SearchType.DESCRIPTION: - sql = sql.where(MCPInfo.description.like(f"%{keyword}%")) - elif search_type == SearchType.AUTHOR: - sql = sql.where(MCPInfo.authorId.like(f"%{keyword}%")) - sql = sql.offset(skip).limit(SERVICE_PAGE_SIZE) + author_user_ids = [] + if keyword and search_type in (SearchType.ALL, SearchType.AUTHOR): + author_user_ids = await UserManager.get_user_ids_by_username_keyword(keyword) - if is_installed is not None: - sql = sql.where(MCPInfo.id.in_( - select(MCPActivated.mcpId).where(MCPActivated.userId == user_id), - )) - if is_active is not None: - sql = sql.where(MCPInfo.id.in_( - select(MCPActivated.mcpId).where(MCPActivated.userId == user_id), - )) + search_conditions = await MCPServiceManager._build_search_conditions( + search_type, keyword, author_user_ids, + ) + if search_conditions is None: + logger.warning("[MCPServiceManager] 没有找到符合条件的MCP服务: %s", search_type) + return [] + async with postgres.session() as session: + sql = select(MCPInfo).where(search_conditions) + sql = MCPServiceManager._apply_user_filters( + sql, user_id, is_active=is_active, is_installed=is_installed, + ) + sql = sql.offset(skip).limit(SERVICE_PAGE_SIZE) result = list((await session.scalars(sql)).all()) - # 如果未找到,返回空列表 if not result: logger.warning("[MCPServiceManager] 没有找到符合条件的MCP服务: %s", search_type) - return [] - # 将数据库中的MCP服务转换为对应的对象列表 + return result @staticmethod async def create_mcpservice(data: UpdateMCPServiceRequest, user_id: str) -> str: - """ - 创建MCP服务 - - :param UpdateMCPServiceRequest data: MCP服务配置 - :return: MCP服务ID - """ - # 检查config + """创建MCP服务""" if data.mcp_type == MCPType.SSE: config = MCPServerSSEConfig.model_validate(data.config) else: config = MCPServerStdioConfig.model_validate(data.config) - # 构造Server mcp_server = MCPServerConfig( name=await MCPServiceManager.clean_name(data.name), overview=data.overview, @@ -250,14 +234,13 @@ class MCPServiceManager: author=user_id, ) - # 检查是否存在相同服务 async with postgres.session() as session: mcp_info = (await session.scalars(select(MCPInfo).where(MCPInfo.name == mcp_server.name))).one_or_none() if mcp_info: mcp_server.name = f"{mcp_server.name}-{uuid.uuid4().hex[:6]}" logger.warning("[MCPServiceManager] 已存在相同ID或名称的MCP服务") - # 保存并载入配�? logger.info("[MCPServiceManager] 创建mcp�?s", mcp_server.name) + logger.info("[MCPServiceManager] 创建mcp: %s", mcp_server.name) mcp_path = MCP_PATH / "template" / data.mcp_id / "project" index = None if isinstance(config, MCPServerStdioConfig): @@ -291,12 +274,7 @@ class MCPServiceManager: @staticmethod async def update_mcpservice(data: UpdateMCPServiceRequest, user_id: str) -> str: - """ - 更新MCP服务 - - :param UpdateMCPServiceRequest data: MCP服务配置 - :return: MCP服务ID - """ + """更新MCP服务""" if not data.mcp_id: msg = "[MCPServiceManager] MCP服务ID为空" raise ValueError(msg) @@ -312,13 +290,11 @@ class MCPServiceManager: msg = "[MCPServiceManager] MCP服务未找到或无权限" raise ValueError(msg) - # 查询所有激活该MCP服务的用户 async with postgres.session() as session: activated_users = (await session.scalars(select(MCPActivated).where( MCPActivated.mcpId == data.mcp_id, ))).all() - # 为每个激活的用户取消激活 for activated_user in activated_users: await MCPServiceManager.deactive_mcpservice(user_id=activated_user.userId, mcp_id=data.mcp_id) @@ -349,16 +325,9 @@ class MCPServiceManager: @staticmethod async def delete_mcpservice(mcp_id: str) -> None: - """ - 删除MCP服务 - - :param mcp_id: str: MCP服务ID - :return: 是否删除成功 - """ - # 删除对应的mcp + """删除MCP服务""" await MCPLoader.delete_mcp(mcp_id) - # 从PostgreSQL中删除所有应用对应该MCP服务的关联记录 async with postgres.session() as session: stmt = delete(AppMCP).where(AppMCP.mcpId == mcp_id) await session.execute(stmt) @@ -371,14 +340,7 @@ class MCPServiceManager: mcp_id: str, mcp_env: dict[str, Any] | None = None, ) -> None: - """ - 激活MCP服务 - - :param user_id: str: 用户ID - :param mcp_id: str: MCP服务ID - :param mcp_env: dict[str, Any] | None: MCP服务环境变量 - :return: 无 - """ + """激活MCP服务""" if mcp_env is None: mcp_env = {} @@ -403,13 +365,7 @@ class MCPServiceManager: user_id: str, mcp_id: str, ) -> None: - """ - 取消激活MCP服务 - - :param user_id: str: 用户ID - :param mcp_id: str: MCP服务ID - :return: 无 - """ + """取消激活MCP服务""" try: await mcp_pool.stop(mcp_id=mcp_id, user_id=user_id) except KeyError: @@ -419,12 +375,7 @@ class MCPServiceManager: @staticmethod async def clean_name(name: str) -> str: - """ - 移除MCP服务名称中的特殊字符 - - :param name: str: MCP服务名称 - :return: 清理后的MCP服务名称 - """ + """移除MCP服务名称中的特殊字符""" invalid_chars = r'[\\\/:*?"<>|]' return re.sub(invalid_chars, "_", name) @@ -435,7 +386,6 @@ class MCPServiceManager: icon: UploadFile, ) -> str: """保存MCP服务图标""" - # 检查MIME mime = magic.from_buffer(icon.file.read(), mime=True) icon.file.seek(0) @@ -443,14 +393,11 @@ class MCPServiceManager: err = "[MCPServiceManager] 不支持的图标格式" raise ValueError(err) - # 保存图标 image = Image.open(icon.file) image = image.convert("RGB") image = image.resize((64, 64), resample=Image.Resampling.LANCZOS) - # 检查文件夹 if not await MCP_ICON_PATH.exists(): await MCP_ICON_PATH.mkdir(parents=True, exist_ok=True) - # 保存 image.save(MCP_ICON_PATH / f"{mcp_id}.png", format="PNG", optimize=True, compress_level=9) return f"/static/mcp/{mcp_id}.png" @@ -458,13 +405,7 @@ class MCPServiceManager: @staticmethod async def is_user_actived(user_id: str, mcp_id: str) -> bool: - """ - 判断用户是否激活MCP - - :param user_id: str: 用户ID - :param mcp_id: str: MCP服务ID - :return: 是否激活 - """ + """判断用户是否激活MCP""" async with postgres.session() as session: mcp_info = (await session.scalars(select(MCPActivated).where( and_( @@ -477,26 +418,14 @@ class MCPServiceManager: @staticmethod async def query_mcp_tools(mcp_id: str) -> list[MCPTools]: - """ - 查询MCP工具 - - :param mcp_id: str: MCP服务ID - :return: MCP工具列表 - """ + """查询MCP工具""" async with postgres.session() as session: return list((await session.scalars(select(MCPTools).where(MCPTools.mcpId == mcp_id))).all()) @staticmethod async def install_mcpservice(user_id: str, service_id: str, *, install: bool) -> None: - """ - 安装或卸载MCP服务 - - :param user_id: str: 用户ID - :param service_id: str: MCP服务ID - :param install: bool: 是否安装 - :return: 无 - """ + """安装或卸载MCP服务""" async with postgres.session() as session: db_service = (await session.scalars(select(MCPInfo).where( and_( diff --git a/apps/services/service.py b/apps/services/service.py index 3db7ccff3a1dd0bc97f0cdf3599cdd6f27d33bb4..6deb8cdf452e39a80a48ab792bb86768b385a190 100644 --- a/apps/services/service.py +++ b/apps/services/service.py @@ -31,6 +31,7 @@ from apps.schemas.flow import ( ServiceMetadata, ) from apps.schemas.service import ServiceApiData, ServiceCardItem +from apps.services.user import UserManager logger = logging.getLogger(__name__) @@ -59,17 +60,26 @@ class ServiceCenterManager: # 构建搜索条件 search_conditions = [] if keyword: - search_condition_map = { - SearchType.ALL: or_( + # 处理按作者搜索的特殊情况:根据用户名关键字查询用户ID + if search_type == SearchType.AUTHOR: + user_ids = await UserManager.get_user_ids_by_username_keyword(keyword) + # 如果没有匹配的用户,添加一个永远为 False 的条件 + search_conditions = [Service.authorId.in_(user_ids)] if user_ids else [Service.authorId.is_(None)] + elif search_type == SearchType.ALL: + # 全局搜索时,也需要根据用户名搜索作者 + user_ids = await UserManager.get_user_ids_by_username_keyword(keyword) + author_condition = Service.authorId.in_(user_ids) if user_ids else Service.authorId.is_(None) + search_conditions = [or_( Service.name.like(f"%{keyword}%"), Service.description.like(f"%{keyword}%"), - Service.authorId.like(f"%{keyword}%"), - ), - SearchType.NAME: Service.name.like(f"%{keyword}%"), - SearchType.DESCRIPTION: Service.description.like(f"%{keyword}%"), - SearchType.AUTHOR: Service.authorId.like(f"%{keyword}%"), - } - search_conditions = [search_condition_map.get(search_type)] + author_condition, + )] + else: + search_condition_map = { + SearchType.NAME: Service.name.like(f"%{keyword}%"), + SearchType.DESCRIPTION: Service.description.like(f"%{keyword}%"), + } + search_conditions = [search_condition_map.get(search_type)] all_conditions = (base_conditions or []) + search_conditions @@ -99,6 +109,10 @@ class ServiceCenterManager: service_pools = list((await session.scalars(data_query)).all()) total_count = (await session.scalar(count_query)) or 0 + # 获取作者的用户名 + author_ids = {service_pool.authorId for service_pool in service_pools} + author_map = await UserManager.get_usernames_by_ids(author_ids) + fav_service_ids = await ServiceCenterManager._get_favorite_service_ids_by_user(user_id) services = [ ServiceCardItem( @@ -106,7 +120,7 @@ class ServiceCenterManager: icon="", name=service_pool.name, description=service_pool.description, - author=service_pool.author, + author=author_map[service_pool.authorId], favorited=(service_pool.id in fav_service_ids), ) for service_pool in service_pools @@ -135,6 +149,10 @@ class ServiceCenterManager: service_pools = list((await session.scalars(data_query)).all()) total_count = (await session.scalar(count_query)) or 0 + # 获取作者的用户名 + author_ids = {service_pool.authorId for service_pool in service_pools} + author_map = await UserManager.get_usernames_by_ids(author_ids) + fav_service_ids = await ServiceCenterManager._get_favorite_service_ids_by_user(user_id) services = [ ServiceCardItem( @@ -142,7 +160,7 @@ class ServiceCenterManager: icon="", name=service_pool.name, description=service_pool.description, - author=service_pool.author, + author=author_map[service_pool.authorId], favorited=(service_pool.id in fav_service_ids), ) for service_pool in service_pools @@ -181,18 +199,22 @@ class ServiceCenterManager: service_pools = list((await session.scalars(data_query)).all()) total_count = (await session.scalar(count_query)) or 0 - services = [ - ServiceCardItem( - serviceId=service_pool.id, - icon="", - name=service_pool.name, - description=service_pool.description, - author=service_pool.author, - favorited=True, - ) - for service_pool in service_pools - ] - return services, total_count + # 获取作者的用户名 + author_ids = {service_pool.authorId for service_pool in service_pools} + author_map = await UserManager.get_usernames_by_ids(author_ids) + + services = [ + ServiceCardItem( + serviceId=service_pool.id, + icon="", + name=service_pool.name, + description=service_pool.description, + author=author_map[service_pool.authorId], + favorited=True, + ) + for service_pool in service_pools + ] + return services, total_count @staticmethod diff --git a/apps/services/session.py b/apps/services/session.py index 1bd99452c18a6610c0f799d4a81f77f9095d5e6d..fb3947ab8a75aa896895efa9d236532947a00f48 100644 --- a/apps/services/session.py +++ b/apps/services/session.py @@ -66,7 +66,7 @@ class SessionManager: return None # 查询黑名单 - if user_id and await UserBlacklistManager.check_blacklisted_users(user_id=user_id): + if user_id and await UserBlacklistManager.check_blacklisted_users(user_id): logger.error("[SessionManager] 用户在Session黑名单中") await SessionManager.delete_session(session_id) return None diff --git a/apps/services/user.py b/apps/services/user.py index 950425c2716f9fd657c1d1602f9b5a49aaa18a0e..4946109c4e130bce0180615ce0caa5502612d135 100644 --- a/apps/services/user.py +++ b/apps/services/user.py @@ -9,6 +9,7 @@ from sqlalchemy import func, select from apps.common.postgres import postgres from apps.models import User +from apps.schemas.request_data import UserUpdateRequest from .conversation import ConversationManager @@ -48,34 +49,93 @@ class UserManager: @staticmethod - async def update_user(user_id: str, data: dict) -> None: + async def get_usernames_by_ids(user_ids: set[str]) -> dict[str, str]: """ - 根据用户sub更新用户信息 + 根据用户ID集合批量获取用户名 + + :param user_ids: 用户ID集合 + :return: 用户ID到用户名的映射字典 + """ + if not user_ids: + return {} + + async with postgres.session() as session: + users = (await session.scalars( + select(User).where(User.id.in_(user_ids)), + )).all() + return {user.id: user.userName for user in users} + + @staticmethod + async def get_user_ids_by_username_keyword(keyword: str) -> list[str]: + """ + 根据用户名关键字查询所有匹配的用户ID + + :param keyword: 用户名关键字 + :return: 匹配的用户ID列表 + """ + if not keyword: + return [] + + async with postgres.session() as session: + users = (await session.scalars( + select(User).where(User.userName.like(f"%{keyword}%")), + )).all() + return [user.id for user in users] + + + @staticmethod + async def update_user_info(user_id: str, data: UserUpdateRequest) -> None: + """ + 更新已有用户的属性信息 :param user_id: 用户ID :param data: 更新数据 + :raises ValueError: 如果用户不存在 + """ + async with postgres.session() as session: + user = ( + await session.scalars(select(User).where(User.id == user_id)) + ).one_or_none() + if not user: + error_msg = f"[UserManager] User {user_id} not found" + logger.error(error_msg) + raise ValueError(error_msg) + + # 更新指定字段 + update_dict = data.model_dump(exclude_unset=True, exclude_none=True, by_alias=True) + for key, value in update_dict.items(): + if hasattr(user, key): + setattr(user, key, value) + + await session.commit() + + @staticmethod + async def create_or_update_on_login(user_id: str, user_name: str | None = None) -> None: + """ + 在登录时创建新用户或更新已有用户的 lastLogin 时间 + + :param user_id: 用户ID + :param user_name: 用户名(仅在创建新用户时使用) """ async with postgres.session() as session: user = ( await session.scalars(select(User).where(User.id == user_id)) ).one_or_none() + if not user: + # 创建新用户 user = User( - userName=data.get("userName", f"用户-{secrets.token_hex(8)}"), + id=user_id, + userName=user_name or f"用户-{secrets.token_hex(8)}", isActive=True, isWhitelisted=False, credit=0, ) - await session.merge(user) - await session.commit() - return - - # 更新指定字段 - for key, value in data.items(): - if hasattr(user, key) and value is not None: - setattr(user, key, value) + session.add(user) + else: + # 更新已有用户的登录时间 + user.lastLogin = datetime.now(tz=UTC) - user.lastLogin = datetime.now(tz=UTC) await session.commit() @staticmethod diff --git a/tests/services/test_user.py b/tests/services/test_user.py index f6704cf2ce15ffdc17b7c2a979b302bc78ba6420..84cbf663d2909f06024eaaf1c5c515772ed8f9b6 100644 --- a/tests/services/test_user.py +++ b/tests/services/test_user.py @@ -88,48 +88,70 @@ class TestUserManager: mock_session.scalars.assert_called_once() @patch("apps.services.user.postgres.session") - async def test_update_user_create_new(self, mock_session_maker, mock_user): - """测试 update_user 方法 - 创建新用户""" + async def test_update_user_info_success(self, mock_session_maker, mock_user): + """测试 update_user_info 方法 - 更新现有用户""" # 配置模拟会话 mock_session = AsyncMock() mock_session_maker.return_value.__aenter__.return_value = mock_session - mock_session.scalars.return_value.one_or_none.return_value = None + mock_session.scalars.return_value.one_or_none.return_value = mock_user # 准备测试数据 - update_data = UserUpdateRequest( - isActive=True, - isWhitelisted=False, - credit=150, - ) + update_data = UserUpdateRequest(autoExecute=True) # 调用被测试方法 - await UserManager.update_user("new_user_sub", update_data) + await UserManager.update_user_info("test_user_sub", update_data) # 验证调用 - mock_session.merge.assert_called_once() + assert mock_user.autoExecute is True mock_session.commit.assert_called_once() @patch("apps.services.user.postgres.session") - async def test_update_user_update_existing(self, mock_session_maker, mock_user): - """测试 update_user 方法 - 更新现有用户""" + async def test_update_user_info_not_found(self, mock_session_maker): + """测试 update_user_info 方法 - 用户不存在""" # 配置模拟会话 mock_session = AsyncMock() mock_session_maker.return_value.__aenter__.return_value = mock_session - mock_session.scalars.return_value.one_or_none.return_value = mock_user + mock_session.scalars.return_value.one_or_none.return_value = None # 准备测试数据 - update_data = UserUpdateRequest( - isActive=False, - credit=200 - ) + update_data = UserUpdateRequest(autoExecute=True) + + # 调用被测试方法并验证抛出异常 + with pytest.raises(ValueError, match="User .* not found"): + await UserManager.update_user_info("nonexistent_user", update_data) + + # 验证未提交 + mock_session.commit.assert_not_called() + + @patch("apps.services.user.postgres.session") + async def test_create_or_update_on_login_create_new(self, mock_session_maker): + """测试 create_or_update_on_login 方法 - 创建新用户""" + # 配置模拟会话 + mock_session = AsyncMock() + mock_session_maker.return_value.__aenter__.return_value = mock_session + mock_session.scalars.return_value.one_or_none.return_value = None + + # 调用被测试方法 + await UserManager.create_or_update_on_login("new_user_id", "TestUser") + + # 验证调用 + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + + @patch("apps.services.user.postgres.session") + async def test_create_or_update_on_login_update_existing(self, mock_session_maker, mock_user): + """测试 create_or_update_on_login 方法 - 更新已有用户""" + # 配置模拟会话 + mock_session = AsyncMock() + mock_session_maker.return_value.__aenter__.return_value = mock_session + mock_session.scalars.return_value.one_or_none.return_value = mock_user # 调用被测试方法 - await UserManager.update_user("test_user_sub", update_data) + await UserManager.create_or_update_on_login("test_user_sub") # 验证调用 - assert mock_user.isActive == False - assert mock_user.credit == 200 assert mock_user.lastLogin is not None + mock_session.add.assert_not_called() mock_session.commit.assert_called_once() @patch("apps.services.user.postgres.session")