diff --git a/apps/llm/prompt.py b/apps/llm/prompt.py index fa424e731924cb2712470e881fc32ea96d047460..62689ccce929ca9f517ebee3367ea0d2b516e445 100644 --- a/apps/llm/prompt.py +++ b/apps/llm/prompt.py @@ -76,13 +76,15 @@ JSON_GEN: dict[LanguageType, str] = { {{ func.name }} {{ func.description }} - {{ func.parameters | tojson(indent=2) }} + {{ func.parameters }} {% endfor %} {% else %} ## 可用工具 {% for func in functions %} - - **{{ func.name }}**: {{ func.description }} + - **{{ func.name }}**: + - 工具描述:{{ func.description }} + - 工具参数Schema:{{ func.parameters }} {% endfor %} 请根据用户查询选择合适的工具来回答问题。你必须使用上述工具之一来处理查询。 diff --git a/apps/models/user.py b/apps/models/user.py index ccd4a20d506c2d40d6e504117e9222c6c9d5a5a3..6904f791cb2525ffe741ca68f78338424027354b 100644 --- a/apps/models/user.py +++ b/apps/models/user.py @@ -66,7 +66,7 @@ class UserAppUsage(Base): __tablename__ = "framework_user_app_usage" id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False) """用户应用使用情况ID""" - userId: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.id"), unique=True, nullable=False) # noqa: N815 + userId: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.id"), nullable=False) # noqa: N815 """用户ID""" appId: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), ForeignKey("framework_app.id"), nullable=False) # noqa: N815 """应用ID""" diff --git a/apps/routers/auth.py b/apps/routers/auth.py index 5216aae960f193cb9dd7e5cf71494f4c15bcebd7..718d49a98e20da0a4c0619192b51b3e8ae2b3726 100644 --- a/apps/routers/auth.py +++ b/apps/routers/auth.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Annotated from fastapi import APIRouter, Body, Depends, Request, status +from fastapi.encoders import jsonable_encoder from fastapi.responses import HTMLResponse, JSONResponse from fastapi.templating import Jinja2Templates @@ -92,11 +93,13 @@ async def logout(request: Request) -> JSONResponse: if not request.client: return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, - content=ResponseData( - code=status.HTTP_400_BAD_REQUEST, - message="IP error", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_400_BAD_REQUEST, + message="IP error", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) await TokenManager.delete_plugin_token(request.state.user_id) @@ -105,11 +108,13 @@ async def logout(request: Request) -> JSONResponse: return JSONResponse( status_code=status.HTTP_200_OK, - content=ResponseData( - code=status.HTTP_200_OK, - message="success", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_200_OK, + message="success", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) @@ -119,11 +124,13 @@ async def oidc_redirect() -> JSONResponse: redirect_url = await oidc_provider.get_redirect_url() return JSONResponse( status_code=status.HTTP_200_OK, - content=OidcRedirectRsp( - code=status.HTTP_200_OK, - message="success", - result=OidcRedirectMsg(url=redirect_url), - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + OidcRedirectRsp( + code=status.HTTP_200_OK, + message="success", + result=OidcRedirectMsg(url=redirect_url), + ).model_dump(exclude_none=True, by_alias=True), + ), ) @@ -133,11 +140,13 @@ async def oidc_logout(token: Annotated[str, Body()]) -> JSONResponse: """POST /auth/logout: OIDC主动告知后端用户已在其他SSO站点登出""" return JSONResponse( status_code=status.HTTP_200_OK, - content=ResponseData( - code=status.HTTP_200_OK, - message="success", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_200_OK, + message="success", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) @@ -148,16 +157,20 @@ async def change_personal_token(request: Request) -> JSONResponse: """POST /auth/key: 重置用户的API密钥""" new_api_key: str | None = await PersonalTokenManager.update_personal_token(request.state.user_id) if not new_api_key: - return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="failed to update personal token", - result={}, - ).model_dump(exclude_none=True, by_alias=True)) - - return JSONResponse(status_code=status.HTTP_200_OK, content=PostPersonalTokenRsp( - code=status.HTTP_200_OK, - message="success", - result=PostPersonalTokenMsg( - api_key=new_api_key, - ), - ).model_dump(exclude_none=True, by_alias=True)) + return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="failed to update personal token", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + )) + + return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder( + PostPersonalTokenRsp( + code=status.HTTP_200_OK, + message="success", + result=PostPersonalTokenMsg( + api_key=new_api_key, + ), + ).model_dump(exclude_none=True, by_alias=True), + )) diff --git a/apps/routers/blacklist.py b/apps/routers/blacklist.py index 02a6e93fabaaf4d6d5603ff398b21b635228c823..e6e29e54be509a5958b09dd3f7a0e0fb33560645 100644 --- a/apps/routers/blacklist.py +++ b/apps/routers/blacklist.py @@ -2,6 +2,7 @@ """FastAPI 黑名单相关路由""" from fastapi import APIRouter, Depends, Request, status +from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from apps.dependency.user import verify_admin, verify_personal_token, verify_session @@ -51,11 +52,16 @@ async def get_blacklist_user(page: int = 0) -> JSONResponse: page * PAGE_SIZE, ) - return JSONResponse(status_code=status.HTTP_200_OK, content=GetBlacklistUserRsp( - code=status.HTTP_200_OK, - message="ok", - result=GetBlacklistUserMsg(users=user_list), - ).model_dump(exclude_none=True, by_alias=True)) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=jsonable_encoder( + GetBlacklistUserRsp( + code=status.HTTP_200_OK, + message="ok", + result=GetBlacklistUserMsg(users=user_list), + ).model_dump(exclude_none=True, by_alias=True), + ), + ) @admin_router.post("/user", response_model=ResponseData) @@ -75,16 +81,26 @@ async def change_blacklist_user(request: UserBlacklistRequest) -> JSONResponse: ) if not result: - return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="Change user blacklist error.", - result={}, - ).model_dump(exclude_none=True, by_alias=True)) - return JSONResponse(status_code=status.HTTP_200_OK, content=ResponseData( - code=status.HTTP_200_OK, - message="ok", - result={}, - ).model_dump(exclude_none=True, by_alias=True)) + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="Change user blacklist error.", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), + ) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=jsonable_encoder( + ResponseData( + code=status.HTTP_200_OK, + message="ok", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), + ) @admin_router.get("/question", response_model=GetBlacklistQuestionRsp) async def get_blacklist_question(page: int = 0) -> JSONResponse: @@ -101,11 +117,16 @@ async def get_blacklist_question(page: int = 0) -> JSONResponse: ) # 将SQLAlchemy模型转换为Pydantic模型 question_schemas = [BlacklistSchema.model_validate(q) for q in question_list] - return JSONResponse(status_code=status.HTTP_200_OK, content=GetBlacklistQuestionRsp( - code=status.HTTP_200_OK, - message="ok", - result=GetBlacklistQuestionMsg(question_list=question_schemas), - ).model_dump(exclude_none=True, by_alias=True)) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=jsonable_encoder( + GetBlacklistQuestionRsp( + code=status.HTTP_200_OK, + message="ok", + result=GetBlacklistQuestionMsg(question_list=question_schemas), + ).model_dump(exclude_none=True, by_alias=True), + ), + ) @admin_router.post("/question", response_model=ResponseData) async def change_blacklist_question(request: QuestionBlacklistRequest) -> JSONResponse: @@ -128,16 +149,26 @@ async def change_blacklist_question(request: QuestionBlacklistRequest) -> JSONRe ) if not result: - return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="Modify question blacklist error.", - result={}, - ).model_dump(exclude_none=True, by_alias=True)) - return JSONResponse(status_code=status.HTTP_200_OK, content=ResponseData( - code=status.HTTP_200_OK, - message="ok", - result={}, - ).model_dump(exclude_none=True, by_alias=True)) + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="Modify question blacklist error.", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), + ) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=jsonable_encoder( + ResponseData( + code=status.HTTP_200_OK, + message="ok", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), + ) @router.post("/complaint", response_model=ResponseData) @@ -151,16 +182,26 @@ async def abuse_report(raw_request: Request, request: AbuseRequest) -> JSONRespo ) if not result: - return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="Report abuse complaint error.", - result={}, - ).model_dump(exclude_none=True, by_alias=True)) - return JSONResponse(status_code=status.HTTP_200_OK, content=ResponseData( - code=status.HTTP_200_OK, - message="ok", - result={}, - ).model_dump(exclude_none=True, by_alias=True)) + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="Report abuse complaint error.", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), + ) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=jsonable_encoder( + ResponseData( + code=status.HTTP_200_OK, + message="ok", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), + ) @admin_router.get("/abuse", response_model=GetBlacklistQuestionRsp) @@ -174,11 +215,16 @@ async def get_abuse_report(page: int = 0) -> JSONResponse: ) # 将SQLAlchemy模型转换为Pydantic模型 result_schemas = [BlacklistSchema.model_validate(r) for r in result] - return JSONResponse(status_code=status.HTTP_200_OK, content=GetBlacklistQuestionRsp( - code=status.HTTP_200_OK, - message="ok", - result=GetBlacklistQuestionMsg(question_list=result_schemas), - ).model_dump(exclude_none=True, by_alias=True)) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=jsonable_encoder( + GetBlacklistQuestionRsp( + code=status.HTTP_200_OK, + message="ok", + result=GetBlacklistQuestionMsg(question_list=result_schemas), + ).model_dump(exclude_none=True, by_alias=True), + ), + ) @admin_router.post("/abuse", response_model=ResponseData) async def change_abuse_report(request: AbuseProcessRequest) -> JSONResponse: @@ -195,13 +241,23 @@ async def change_abuse_report(request: AbuseProcessRequest) -> JSONResponse: ) if not result: - return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="Audit abuse question error.", - result={}, - ).model_dump(exclude_none=True, by_alias=True)) - return JSONResponse(status_code=status.HTTP_200_OK, content=ResponseData( - code=status.HTTP_200_OK, - message="ok", - result={}, - ).model_dump(exclude_none=True, by_alias=True)) + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="Audit abuse question error.", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), + ) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=jsonable_encoder( + ResponseData( + code=status.HTTP_200_OK, + message="ok", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), + ) diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 48f5ad3c84820cc570ebb464dd3ce6ed91ef9bf4..08fc0a5624615b08a89681b347012e16caa417f3 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -6,6 +6,7 @@ import logging from collections.abc import AsyncGenerator from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse, StreamingResponse from apps.common.queue import MessageQueue @@ -124,9 +125,11 @@ async def stop_generation(request: Request) -> JSONResponse: return JSONResponse( status_code=status.HTTP_200_OK, - content=ResponseData( - code=status.HTTP_200_OK, - message="stop generation success", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_200_OK, + message="stop generation success", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) diff --git a/apps/routers/comment.py b/apps/routers/comment.py index 18caec43b0fa4131571210d4a85021ff49496f47..572915bd5d4a4febb4371b6f27331c14cb752c6e 100644 --- a/apps/routers/comment.py +++ b/apps/routers/comment.py @@ -5,6 +5,7 @@ import logging from datetime import UTC, datetime from fastapi import APIRouter, Depends, Request, status +from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from apps.dependency import verify_personal_token, verify_session @@ -36,13 +37,23 @@ async def add_comment(request: Request, post_body: AddCommentData) -> JSONRespon ) result = await CommentManager.update_comment(post_body.record_id, comment_data, request.state.user_id) if not result: - return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content=ResponseData( - code=status.HTTP_400_BAD_REQUEST, - message="record_id not found", - result={}, - ).model_dump(exclude_none=True, by_alias=True)) - return JSONResponse(status_code=status.HTTP_200_OK, content=ResponseData( - code=status.HTTP_200_OK, - message="success", - result={}, - ).model_dump(exclude_none=True, by_alias=True)) + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=jsonable_encoder( + ResponseData( + code=status.HTTP_400_BAD_REQUEST, + message="record_id not found", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), + ) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=jsonable_encoder( + ResponseData( + code=status.HTTP_200_OK, + message="success", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), + ) diff --git a/apps/routers/conversation.py b/apps/routers/conversation.py index e1340e01ac8b7715ef20539cc7740871f509addf..237a6b5b1e7bbaf842f9cdea8fe2fe9304ca42df 100644 --- a/apps/routers/conversation.py +++ b/apps/routers/conversation.py @@ -6,6 +6,7 @@ import uuid from typing import Annotated from fastapi import APIRouter, Depends, Query, Request, status +from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from apps.dependency import verify_personal_token, verify_session @@ -58,11 +59,13 @@ async def get_conversation_list(request: Request) -> JSONResponse: return JSONResponse( status_code=status.HTTP_200_OK, - content=ConversationListRsp( - code=status.HTTP_200_OK, - message="success", - result=ConversationListMsg(conversations=result_conversations), - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ConversationListRsp( + code=status.HTTP_200_OK, + message="success", + result=ConversationListMsg(conversations=result_conversations), + ).model_dump(exclude_none=True, by_alias=True), + ), ) @@ -79,11 +82,13 @@ async def update_conversation( _logger.error("[Conversation] conversation_id 不存在") return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, - content=ResponseData( - code=status.HTTP_400_BAD_REQUEST, - message="conversation_id not found", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_400_BAD_REQUEST, + message="conversation_id not found", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) # 更新Conversation数据 @@ -99,27 +104,31 @@ async def update_conversation( _logger.exception("[Conversation] 更新对话数据失败") return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="update conversation failed", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="update conversation failed", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) return JSONResponse( status_code=status.HTTP_200_OK, - content=UpdateConversationRsp( - code=status.HTTP_200_OK, - message="success", - result=ConversationListItem( - conversationId=conv.id, - title=conv.title, - docCount=await DocumentManager.get_doc_count(conv.id), - createdTime=conv.createdAt.strftime("%Y-%m-%d %H:%M:%S"), - appId=conv.appId, - debug=conv.isTemporary, - ), - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + UpdateConversationRsp( + code=status.HTTP_200_OK, + message="success", + result=ConversationListItem( + conversationId=conv.id, + title=conv.title, + docCount=await DocumentManager.get_doc_count(conv.id), + createdTime=conv.createdAt.strftime("%Y-%m-%d %H:%M:%S"), + appId=conv.appId, + debug=conv.isTemporary, + ), + ).model_dump(exclude_none=True, by_alias=True), + ), ) @@ -146,9 +155,11 @@ async def delete_conversation( return JSONResponse( status_code=status.HTTP_200_OK, - content=DeleteConversationRsp( - code=status.HTTP_200_OK, - message="success", - result=DeleteConversationMsg(conversationIdList=deleted_conversation), - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + DeleteConversationRsp( + code=status.HTTP_200_OK, + message="success", + result=DeleteConversationMsg(conversationIdList=deleted_conversation), + ).model_dump(exclude_none=True, by_alias=True), + ), ) diff --git a/apps/routers/document.py b/apps/routers/document.py index 3ca496294bd82d3a673447475849a709bba7ef5a..691e3202300e1ff3d977aa584f53c599c2a09baa 100644 --- a/apps/routers/document.py +++ b/apps/routers/document.py @@ -6,6 +6,7 @@ import uuid from typing import Annotated from fastapi import APIRouter, Depends, Path, Request, UploadFile, status +from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from apps.dependency import verify_personal_token, verify_session @@ -60,11 +61,13 @@ async def document_upload( return JSONResponse( status_code=200, - content=UploadDocumentRsp( - code=status.HTTP_200_OK, - message="上传成功", - result=UploadDocumentMsg(documents=succeed_document), - ).model_dump(exclude_none=True, by_alias=False), + content=jsonable_encoder( + UploadDocumentRsp( + code=status.HTTP_200_OK, + message="上传成功", + result=UploadDocumentMsg(documents=succeed_document), + ).model_dump(exclude_none=True, by_alias=False), + ), ) @@ -161,10 +164,12 @@ async def get_document_list( if not await ConversationManager.verify_conversation_access(request.state.user_id, conversation_id): return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, - content=ResponseData( - code=status.HTTP_403_FORBIDDEN, - message="无权限访问", - result={}, + content=jsonable_encoder( + ResponseData( + code=status.HTTP_403_FORBIDDEN, + message="无权限访问", + result={}, + ).model_dump(exclude_none=True, by_alias=False), ), ) @@ -179,11 +184,13 @@ async def get_document_list( # 对外展示的时候用id,不用alias return JSONResponse( status_code=status.HTTP_200_OK, - content=ConversationDocumentRsp( - code=status.HTTP_200_OK, - message="获取成功", - result=ConversationDocumentMsg(documents=result), - ).model_dump(exclude_none=True, by_alias=False), + content=jsonable_encoder( + ConversationDocumentRsp( + code=status.HTTP_200_OK, + message="获取成功", + result=ConversationDocumentMsg(documents=result), + ).model_dump(exclude_none=True, by_alias=False), + ), ) @@ -197,11 +204,13 @@ async def delete_single_document( if not result: return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="删除文件失败", - result={}, - ).model_dump(exclude_none=True, by_alias=False), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="删除文件失败", + result={}, + ).model_dump(exclude_none=True, by_alias=False), + ), ) # 在RAG侧删除 auth_header = getattr(request.session, "session_id", None) or request.state.personal_token @@ -209,18 +218,22 @@ async def delete_single_document( if not result: return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="RAG端删除文件失败", - result={}, - ).model_dump(exclude_none=True, by_alias=False), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="RAG端删除文件失败", + result={}, + ).model_dump(exclude_none=True, by_alias=False), + ), ) return JSONResponse( status_code=status.HTTP_200_OK, - content=ResponseData( - code=status.HTTP_200_OK, - message="删除成功", - result={}, - ).model_dump(exclude_none=True, by_alias=False), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_200_OK, + message="删除成功", + result={}, + ).model_dump(exclude_none=True, by_alias=False), + ), ) diff --git a/apps/routers/flow.py b/apps/routers/flow.py index 3e20c2925106bdcdb55e12759c758f87426263a3..11bd2d359ab5774e69ac76dafc9d1e721ed7e3ff 100644 --- a/apps/routers/flow.py +++ b/apps/routers/flow.py @@ -5,6 +5,7 @@ import uuid from typing import Annotated from fastapi import APIRouter, Body, Depends, Query, Request, status +from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from apps.dependency import verify_personal_token, verify_session @@ -65,29 +66,35 @@ async def get_flow(request: Request, appId: uuid.UUID, flowId: str) -> JSONRespo if not await AppCenterManager.validate_user_app_access(request.state.user_id, appId): return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, - content=FlowStructureGetRsp( - code=status.HTTP_403_FORBIDDEN, - message="用户没有权限访问该Workflow", - result=FlowStructureGetMsg(), - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + FlowStructureGetRsp( + code=status.HTTP_403_FORBIDDEN, + message="用户没有权限访问该Workflow", + result=FlowStructureGetMsg(), + ).model_dump(exclude_none=True, by_alias=True), + ), ) result = await FlowManager.get_flow_by_app_and_flow_id(appId, flowId) if result is None: return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, - content=FlowStructureGetRsp( - code=status.HTTP_404_NOT_FOUND, - message="应用的Workflow获取失败", - result=FlowStructureGetMsg(), - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + FlowStructureGetRsp( + code=status.HTTP_404_NOT_FOUND, + message="应用的Workflow获取失败", + result=FlowStructureGetMsg(), + ).model_dump(exclude_none=True, by_alias=True), + ), ) return JSONResponse( status_code=status.HTTP_200_OK, - content=FlowStructureGetRsp( - code=status.HTTP_200_OK, - message="应用的Workflow获取成功", - result=FlowStructureGetMsg(flow=result), - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + FlowStructureGetRsp( + code=status.HTTP_200_OK, + message="应用的Workflow获取成功", + result=FlowStructureGetMsg(flow=result), + ).model_dump(exclude_none=True, by_alias=True), + ), ) @@ -108,11 +115,13 @@ async def put_flow( if not await AppCenterManager.validate_app_belong_to_user(request.state.user_id, appId): return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, - content=FlowStructurePutRsp( - code=status.HTTP_403_FORBIDDEN, - message="用户没有权限访问该流", - result=FlowStructurePutMsg(), - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + FlowStructurePutRsp( + code=status.HTTP_403_FORBIDDEN, + message="用户没有权限访问该流", + result=FlowStructurePutMsg(), + ).model_dump(exclude_none=True, by_alias=True), + ), ) put_body.flow = await FlowServiceManager.remove_excess_structure_from_flow(put_body.flow) await FlowServiceManager.validate_flow_illegal(put_body.flow) @@ -123,11 +132,13 @@ async def put_flow( except Exception as e: # noqa: BLE001 return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=FlowStructurePutRsp( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message=f"应用下流更新失败: {e!s}", - result=FlowStructurePutMsg(), - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + FlowStructurePutRsp( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message=f"应用下流更新失败: {e!s}", + result=FlowStructurePutMsg(), + ).model_dump(exclude_none=True, by_alias=True), + ), ) flow = await FlowManager.get_flow_by_app_and_flow_id(appId, flowId) @@ -135,19 +146,23 @@ async def put_flow( if flow is None: return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=FlowStructurePutRsp( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="应用下流更新后获取失败", - result=FlowStructurePutMsg(), - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + FlowStructurePutRsp( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="应用下流更新后获取失败", + result=FlowStructurePutMsg(), + ).model_dump(exclude_none=True, by_alias=True), + ), ) return JSONResponse( status_code=status.HTTP_200_OK, - content=FlowStructurePutRsp( - code=status.HTTP_200_OK, - message="应用下流更新成功", - result=FlowStructurePutMsg(flow=flow), - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + FlowStructurePutRsp( + code=status.HTTP_200_OK, + message="应用下流更新成功", + result=FlowStructurePutMsg(flow=flow), + ).model_dump(exclude_none=True, by_alias=True), + ), ) @@ -160,27 +175,33 @@ async def delete_flow(request: Request, appId: uuid.UUID, flowId: str) -> JSONRe if not await AppCenterManager.validate_app_belong_to_user(request.state.user_id, appId): return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, - content=FlowStructureDeleteRsp( - code=status.HTTP_403_FORBIDDEN, - message="用户没有权限访问该流", - result=FlowStructureDeleteMsg(), - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + FlowStructureDeleteRsp( + code=status.HTTP_403_FORBIDDEN, + message="用户没有权限访问该流", + result=FlowStructureDeleteMsg(), + ).model_dump(exclude_none=True, by_alias=True), + ), ) result = await FlowManager.delete_flow_by_app_and_flow_id(appId, flowId) if result is None: return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, - content=FlowStructureDeleteRsp( - code=status.HTTP_404_NOT_FOUND, - message="应用下流程删除失败", - result=FlowStructureDeleteMsg(), - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + FlowStructureDeleteRsp( + code=status.HTTP_404_NOT_FOUND, + message="应用下流程删除失败", + result=FlowStructureDeleteMsg(), + ).model_dump(exclude_none=True, by_alias=True), + ), ) return JSONResponse( status_code=status.HTTP_200_OK, - content=FlowStructureDeleteRsp( - code=status.HTTP_200_OK, - message="应用下流程删除成功", - result=FlowStructureDeleteMsg(flowId=result), - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + FlowStructureDeleteRsp( + code=status.HTTP_200_OK, + message="应用下流程删除成功", + result=FlowStructureDeleteMsg(flowId=result), + ).model_dump(exclude_none=True, by_alias=True), + ), ) diff --git a/apps/routers/health.py b/apps/routers/health.py index 9bfe37e0f76bc00341b431c5107009a771ab5f7a..fe6cfe86d615f38420d523eb4d4afc8e5ca66b15 100644 --- a/apps/routers/health.py +++ b/apps/routers/health.py @@ -2,6 +2,7 @@ """FastAPI 健康检查接口""" from fastapi import APIRouter, status +from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from apps.schemas.response_data import HealthCheckRsp @@ -15,6 +16,11 @@ router = APIRouter( @router.get("", response_model=HealthCheckRsp) def health_check() -> JSONResponse: """GET /health_check: 服务健康检查接口""" - return JSONResponse(status_code=status.HTTP_200_OK, content=HealthCheckRsp( - status="ok", - ).model_dump(exclude_none=True, by_alias=True)) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=jsonable_encoder( + HealthCheckRsp( + status="ok", + ).model_dump(exclude_none=True, by_alias=True), + ), + ) diff --git a/apps/routers/llm.py b/apps/routers/llm.py index 85420fb8a19fb35ee72c2ab6d39cb5f405f8f9c4..a7567112bad8d3289576c461b2e12a840c3097d3 100644 --- a/apps/routers/llm.py +++ b/apps/routers/llm.py @@ -1,7 +1,10 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """FastAPI 大模型相关接口""" +from typing import cast + from fastapi import APIRouter, Depends, status +from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from apps.dependency import verify_admin, verify_personal_token, verify_session @@ -11,6 +14,8 @@ from apps.schemas.request_data import ( from apps.schemas.response_data import ( ListLLMAdminRsp, ListLLMRsp, + LLMAdminInfo, + LLMProviderInfo, ResponseData, ) from apps.services.llm import LLMManager @@ -39,14 +44,32 @@ admin_router = APIRouter( ) async def list_llm(llmId: str | None = None) -> JSONResponse: # noqa: N803 """GET /llm: 获取大模型列表""" - llm_list = await LLMManager.list_llm(llmId, admin_view=False) + llm_list_raw = await LLMManager.list_llm(llmId, admin_view=False) + + # 检查返回类型是否符合预期 + if llm_list_raw and not all(isinstance(item, LLMProviderInfo) for item in llm_list_raw): + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="大模型列表数据类型不符合预期", + result=None, + ).model_dump(exclude_none=True, by_alias=True), + ), + ) + + llm_list = cast("list[LLMProviderInfo]", llm_list_raw) + return JSONResponse( status_code=status.HTTP_200_OK, - content=ListLLMRsp( - code=status.HTTP_200_OK, - message="success", - result=llm_list, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ListLLMRsp( + code=status.HTTP_200_OK, + message="success", + result=llm_list, + ).model_dump(exclude_none=True, by_alias=True), + ), ) @@ -55,14 +78,32 @@ async def list_llm(llmId: str | None = None) -> JSONResponse: # noqa: N803 ) async def admin_list_llm(llmId: str | None = None) -> JSONResponse: # noqa: N803 """GET /llm/config: 获取大模型配置列表(管理员视图)""" - llm_list = await LLMManager.list_llm(llmId, admin_view=True) + llm_list_raw = await LLMManager.list_llm(llmId, admin_view=True) + + # 检查返回类型是否符合预期 + if llm_list_raw and not all(isinstance(item, LLMAdminInfo) for item in llm_list_raw): + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="大模型配置列表数据类型不符合预期", + result=None, + ).model_dump(exclude_none=True, by_alias=True), + ), + ) + + llm_list = cast("list[LLMAdminInfo]", llm_list_raw) + return JSONResponse( status_code=status.HTTP_200_OK, - content=ListLLMAdminRsp( - code=status.HTTP_200_OK, - message="success", - result=llm_list, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ListLLMAdminRsp( + code=status.HTTP_200_OK, + message="success", + result=llm_list, + ).model_dump(exclude_none=True, by_alias=True), + ), ) @@ -79,19 +120,23 @@ async def create_llm( except ValueError as e: return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message=str(e), - result=None, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message=str(e), + result=None, + ).model_dump(exclude_none=True, by_alias=True), + ), ) return JSONResponse( status_code=status.HTTP_200_OK, - content=ResponseData( - code=status.HTTP_200_OK, - message="success", - result=llmId, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_200_OK, + message="success", + result=llmId, + ).model_dump(exclude_none=True, by_alias=True), + ), ) @@ -105,17 +150,21 @@ async def delete_llm(llmId: str) -> JSONResponse: # noqa: N803 except ValueError as e: return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message=str(e), - result=None, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message=str(e), + result=None, + ).model_dump(exclude_none=True, by_alias=True), + ), ) return JSONResponse( status_code=status.HTTP_200_OK, - content=ResponseData( - code=status.HTTP_200_OK, - message="success", - result=llmId, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_200_OK, + message="success", + result=llmId, + ).model_dump(exclude_none=True, by_alias=True), + ), ) diff --git a/apps/routers/mcp_service.py b/apps/routers/mcp_service.py index 6ecb1e940bf25acb7fd72ecc98c32d154a8eb60c..e77910ade5e89a606980ef3e6f65391aae39683e 100644 --- a/apps/routers/mcp_service.py +++ b/apps/routers/mcp_service.py @@ -5,6 +5,7 @@ import logging from typing import Annotated from fastapi import APIRouter, Depends, HTTPException, Path, Query, Request, UploadFile, status +from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from apps.dependency.user import verify_admin, verify_personal_token, verify_session @@ -73,23 +74,27 @@ async def get_mcpservice_list( # noqa: PLR0913 _logger.exception(err) return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="ERROR", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="ERROR", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) return JSONResponse( status_code=status.HTTP_200_OK, - content=GetMCPServiceListRsp( - code=status.HTTP_200_OK, - message="OK", - result=GetMCPServiceListMsg( - currentPage=page, - services=service_cards, - ), - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + GetMCPServiceListRsp( + code=status.HTTP_200_OK, + message="OK", + result=GetMCPServiceListMsg( + currentPage=page, + services=service_cards, + ), + ).model_dump(exclude_none=True, by_alias=True), + ), ) @@ -106,11 +111,13 @@ async def create_or_update_mcpservice( _logger.exception("[MCPServiceCenter] MCP服务创建失败") return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message=f"MCP服务创建失败: {e!s}", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message=f"MCP服务创建失败: {e!s}", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) else: try: @@ -119,20 +126,27 @@ async def create_or_update_mcpservice( _logger.exception("[MCPService] 更新MCP服务失败") return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message=f"更新MCP服务失败: {e!s}", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message=f"更新MCP服务失败: {e!s}", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) - return JSONResponse(status_code=status.HTTP_200_OK, content=UpdateMCPServiceRsp( - code=status.HTTP_200_OK, - message="OK", - result=UpdateMCPServiceMsg( - serviceId=service_id, - name=data.name, + return JSONResponse( + status_code=status.HTTP_200_OK, + content=jsonable_encoder( + UpdateMCPServiceRsp( + code=status.HTTP_200_OK, + message="OK", + result=UpdateMCPServiceMsg( + serviceId=service_id, + name=data.name, + ), + ).model_dump(exclude_none=True, by_alias=True), ), - ).model_dump(exclude_none=True, by_alias=True)) + ) @admin_router.post("/{serviceId}/install") @@ -150,19 +164,23 @@ async def install_mcp_service( _logger.exception(err) return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message=err, - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message=err, + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) return JSONResponse( status_code=status.HTTP_200_OK, - content=ResponseData( - code=status.HTTP_200_OK, - message="OK", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_200_OK, + message="OK", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) @@ -182,21 +200,25 @@ async def get_service_detail( _logger.exception(err) return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="ERROR", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="ERROR", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) if data is None or config is None or icon is None: return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, - content=ResponseData( - code=status.HTTP_404_NOT_FOUND, - message="MCP服务有关信息不存在", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_404_NOT_FOUND, + message="MCP服务有关信息不存在", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) if edit: @@ -226,11 +248,13 @@ async def get_service_detail( return JSONResponse( status_code=status.HTTP_200_OK, - content=GetMCPServiceDetailRsp( - code=status.HTTP_200_OK, - message="OK", - result=detail, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + GetMCPServiceDetailRsp( + code=status.HTTP_200_OK, + message="OK", + result=detail, + ).model_dump(exclude_none=True, by_alias=True), + ), ) @@ -244,19 +268,23 @@ async def delete_service(serviceId: Annotated[str, Path()]) -> JSONResponse: # _logger.exception(err) return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="ERROR", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="ERROR", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) return JSONResponse( status_code=status.HTTP_200_OK, - content=DeleteMCPServiceRsp( - code=status.HTTP_200_OK, - message="OK", - result=BaseMCPServiceOperationMsg(serviceId=serviceId), - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + DeleteMCPServiceRsp( + code=status.HTTP_200_OK, + message="OK", + result=BaseMCPServiceOperationMsg(serviceId=serviceId), + ).model_dump(exclude_none=True, by_alias=True), + ), ) @@ -276,11 +304,13 @@ async def update_mcp_icon( if not icon.size or icon.size == 0 or icon.size > 1024 * 1024 * 1: return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, - content=ResponseData( - code=status.HTTP_400_BAD_REQUEST, - message="图标文件为空或超过1MB", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_400_BAD_REQUEST, + message="图标文件为空或超过1MB", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) try: url = await MCPServiceManager.save_mcp_icon(serviceId, icon) @@ -289,19 +319,23 @@ async def update_mcp_icon( _logger.exception(err) return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message=err, - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message=err, + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) return JSONResponse( status_code=status.HTTP_200_OK, - content=UploadMCPServiceIconRsp( - code=status.HTTP_200_OK, - message="OK", - result=UploadMCPServiceIconMsg(serviceId=serviceId, url=url), - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + UploadMCPServiceIconRsp( + code=status.HTTP_200_OK, + message="OK", + result=UploadMCPServiceIconMsg(serviceId=serviceId, url=url), + ).model_dump(exclude_none=True, by_alias=True), + ), ) @@ -322,17 +356,21 @@ async def active_or_deactivate_mcp_service( _logger.exception(err) return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message=err, - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message=err, + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) return JSONResponse( status_code=status.HTTP_200_OK, - content=ActiveMCPServiceRsp( - code=status.HTTP_200_OK, - message="OK", - result=BaseMCPServiceOperationMsg(serviceId=mcpId), - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ActiveMCPServiceRsp( + code=status.HTTP_200_OK, + message="OK", + result=BaseMCPServiceOperationMsg(serviceId=mcpId), + ).model_dump(exclude_none=True, by_alias=True), + ), ) diff --git a/apps/routers/parameter.py b/apps/routers/parameter.py index 781faf546c92789875eff02747b6d3802de02ec1..ac9d5f2160075e1b1ac8ef382686a5ece4e24980 100644 --- a/apps/routers/parameter.py +++ b/apps/routers/parameter.py @@ -4,6 +4,7 @@ import uuid from fastapi import APIRouter, Depends, Request, status +from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from apps.dependency.user import verify_personal_token, verify_session @@ -31,30 +32,36 @@ async def get_parameters( if not await AppCenterManager.validate_user_app_access(request.state.user_id, appId): return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, - content=GetParamsRsp( - code=status.HTTP_403_FORBIDDEN, - message="用户没有权限访问该流", - result=[], - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + GetParamsRsp( + code=status.HTTP_403_FORBIDDEN, + message="用户没有权限访问该流", + result=[], + ).model_dump(exclude_none=True, by_alias=True), + ), ) flow = await FlowManager.get_flow_by_app_and_flow_id(appId, flowId) if not flow: return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, - content=GetParamsRsp( - code=status.HTTP_404_NOT_FOUND, - message="未找到该流", - result=[], - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + GetParamsRsp( + code=status.HTTP_404_NOT_FOUND, + message="未找到该流", + result=[], + ).model_dump(exclude_none=True, by_alias=True), + ), ) result = await ParameterManager.get_pre_params_by_flow_and_step_id(flow, stepId) return JSONResponse( status_code=status.HTTP_200_OK, - content=GetParamsRsp( - code=status.HTTP_200_OK, - message="获取参数成功", - result=result, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + GetParamsRsp( + code=status.HTTP_200_OK, + message="获取参数成功", + result=result, + ).model_dump(exclude_none=True, by_alias=True), + ), ) @@ -64,9 +71,11 @@ async def get_operate_parameters(paramType: Type) -> JSONResponse: # noqa: N803 result = await ParameterManager.get_operate_and_bind_type(paramType) return JSONResponse( status_code=status.HTTP_200_OK, - content=GetOperaRsp( - code=status.HTTP_200_OK, - message="获取操作成功", - result=result, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + GetOperaRsp( + code=status.HTTP_200_OK, + message="获取操作成功", + result=result, + ).model_dump(exclude_none=True, by_alias=True), + ), ) diff --git a/apps/routers/record.py b/apps/routers/record.py index 677c1763190d76dfc53db90c6531c600d57d578a..2700ca953bd4b4a9121c7d33735666146ff83604 100644 --- a/apps/routers/record.py +++ b/apps/routers/record.py @@ -6,6 +6,7 @@ import uuid from typing import Annotated from fastapi import APIRouter, Depends, Path, Request, status +from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from apps.common.security import Security @@ -51,11 +52,13 @@ async def get_record(request: Request, conversationId: Annotated[uuid.UUID, Path if not cur_conv: return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, - content=ResponseData( - code=status.HTTP_403_FORBIDDEN, - message="Conversation invalid.", - result={}, - ).model_dump(exclude_none=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_403_FORBIDDEN, + message="Conversation invalid.", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) record_group_list = await RecordManager.query_record_group_by_conversation_id(conversationId) @@ -109,9 +112,11 @@ async def get_record(request: Request, conversationId: Annotated[uuid.UUID, Path result.append(tmp_record) return JSONResponse( status_code=status.HTTP_200_OK, - content=RecordListRsp( - code=status.HTTP_200_OK, - message="success", - result=RecordListMsg(records=result), - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + RecordListRsp( + code=status.HTTP_200_OK, + message="success", + result=RecordListMsg(records=result), + ).model_dump(exclude_none=True, by_alias=True), + ), ) diff --git a/apps/routers/service.py b/apps/routers/service.py index d39dcd30fbeb9c2bf3f14d65d48b8560f4d9f0b7..e2fa4ff1419408e54740c944b757c1e795a66a6c 100644 --- a/apps/routers/service.py +++ b/apps/routers/service.py @@ -6,6 +6,7 @@ import uuid from typing import Annotated from fastapi import APIRouter, Depends, Path, Request, status +from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from apps.dependency.user import ( @@ -64,11 +65,13 @@ async def get_service_list( # noqa: PLR0913 if createdByMe and favorited: # 只能同时选择一个筛选条件 return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, - content=ResponseData( - code=status.HTTP_400_BAD_REQUEST, - message="INVALID_PARAMETER", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_400_BAD_REQUEST, + message="INVALID_PARAMETER", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) service_cards, total_count = [], -1 @@ -101,32 +104,38 @@ async def get_service_list( # noqa: PLR0913 _logger.exception("[ServiceCenter] 获取服务列表失败") return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="ERROR", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="ERROR", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) if total_count == -1: return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, - content=ResponseData( - code=status.HTTP_400_BAD_REQUEST, - message="INVALID_PARAMETER", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_400_BAD_REQUEST, + message="INVALID_PARAMETER", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) return JSONResponse( status_code=status.HTTP_200_OK, - content=GetServiceListRsp( - code=status.HTTP_200_OK, - message="OK", - result=GetServiceListMsg( - currentPage=page, - totalCount=total_count, - services=service_cards, - ), - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + GetServiceListRsp( + code=status.HTTP_200_OK, + message="OK", + result=GetServiceListMsg( + currentPage=page, + totalCount=total_count, + services=service_cards, + ), + ).model_dump(exclude_none=True, by_alias=True), + ), ) @@ -140,11 +149,13 @@ async def update_service(request: Request, data: UpdateServiceRequest) -> JSONRe _logger.exception("[ServiceCenter] 创建服务失败") return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message=f"OpenAPI解析错误: {e!s}", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message=f"OpenAPI解析错误: {e!s}", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) else: try: @@ -152,21 +163,25 @@ async def update_service(request: Request, data: UpdateServiceRequest) -> JSONRe except InstancePermissionError: return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, - content=ResponseData( - code=status.HTTP_403_FORBIDDEN, - message="未授权访问", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_403_FORBIDDEN, + message="未授权访问", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) except Exception as e: _logger.exception("[ServiceCenter] 更新服务失败") return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message=f"更新服务失败: {e!s}", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message=f"更新服务失败: {e!s}", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) try: name, apis = await ServiceCenterManager.get_service_apis(service_id) @@ -174,15 +189,19 @@ async def update_service(request: Request, data: UpdateServiceRequest) -> JSONRe _logger.exception("[ServiceCenter] 获取服务API失败") return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message=f"获取服务API失败: {e!s}", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message=f"获取服务API失败: {e!s}", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) msg = UpdateServiceMsg(serviceId=service_id, name=name, apis=apis) rsp = UpdateServiceRsp(code=status.HTTP_200_OK, message="OK", result=msg) - return JSONResponse(status_code=status.HTTP_200_OK, content=rsp.model_dump(exclude_none=True, by_alias=True)) + return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder( + rsp.model_dump(exclude_none=True, by_alias=True), + )) @router.get("/{serviceId}", response_model=GetServiceDetailRsp) @@ -198,21 +217,25 @@ async def get_service_detail( except InstancePermissionError: return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, - content=ResponseData( - code=status.HTTP_403_FORBIDDEN, - message="未授权访问", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_403_FORBIDDEN, + message="未授权访问", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) except Exception: _logger.exception("[ServiceCenter] 获取服务数据失败") return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="ERROR", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="ERROR", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) detail = GetServiceDetailMsg(serviceId=serviceId, name=name, data=data) else: @@ -222,15 +245,19 @@ async def get_service_detail( _logger.exception("[ServiceCenter] 获取服务API失败") return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="ERROR", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="ERROR", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) detail = GetServiceDetailMsg(serviceId=serviceId, name=name, apis=apis) rsp = GetServiceDetailRsp(code=status.HTTP_200_OK, message="OK", result=detail) - return JSONResponse(status_code=status.HTTP_200_OK, content=rsp.model_dump(exclude_none=True, by_alias=True)) + return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder( + rsp.model_dump(exclude_none=True, by_alias=True), + )) @admin_router.delete("/{serviceId}", response_model=DeleteServiceRsp) @@ -241,25 +268,31 @@ async def delete_service(request: Request, serviceId: Annotated[uuid.UUID, Path( except InstancePermissionError: return JSONResponse( status_code=status.HTTP_403_FORBIDDEN, - content=ResponseData( - code=status.HTTP_403_FORBIDDEN, - message="未授权访问", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_403_FORBIDDEN, + message="未授权访问", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) except Exception: _logger.exception("[ServiceCenter] 删除服务失败") return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="ERROR", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="ERROR", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) msg = BaseServiceOperationMsg(serviceId=serviceId) rsp = DeleteServiceRsp(code=status.HTTP_200_OK, message="OK", result=msg) - return JSONResponse(status_code=status.HTTP_200_OK, content=rsp.model_dump(exclude_none=True, by_alias=True)) + return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder( + rsp.model_dump(exclude_none=True, by_alias=True), + )) @router.put("/{serviceId}", response_model=ChangeFavouriteServiceRsp) @@ -278,22 +311,28 @@ async def modify_favorite_service( if not success: return JSONResponse( status_code=status.HTTP_400_BAD_REQUEST, - content=ResponseData( - code=status.HTTP_400_BAD_REQUEST, - message="INVALID_PARAMETER", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_400_BAD_REQUEST, + message="INVALID_PARAMETER", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) except Exception: _logger.exception("[ServiceCenter] 修改服务收藏状态失败") return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="ERROR", - result={}, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="ERROR", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), ) msg = ChangeFavouriteServiceMsg(serviceId=serviceId, favorited=data.favorited) rsp = ChangeFavouriteServiceRsp(code=status.HTTP_200_OK, message="OK", result=msg) - return JSONResponse(status_code=status.HTTP_200_OK, content=rsp.model_dump(exclude_none=True, by_alias=True)) + return JSONResponse(status_code=status.HTTP_200_OK, content=jsonable_encoder( + rsp.model_dump(exclude_none=True, by_alias=True), + )) diff --git a/apps/routers/tag.py b/apps/routers/tag.py index 391a28e381a29d075e0a44e932c41d953fe3cce8..7affdee22cbe063705807f8639d0454da2f090c5 100644 --- a/apps/routers/tag.py +++ b/apps/routers/tag.py @@ -2,6 +2,7 @@ """FastAPI 用户标签相关API""" from fastapi import APIRouter, Depends, status +from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from apps.dependency.user import verify_admin, verify_personal_token, verify_session @@ -23,48 +24,78 @@ admin_router = APIRouter( @admin_router.get("") async def get_user_tag() -> JSONResponse: """GET /tag: 获取所有标签""" - return JSONResponse(status_code=status.HTTP_200_OK, content=ResponseData( - code=status.HTTP_200_OK, - message="[Tag] Get all tag success.", - result=await TagManager.get_all_tag(), - ).model_dump(exclude_none=True, by_alias=True)) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=jsonable_encoder( + ResponseData( + code=status.HTTP_200_OK, + message="[Tag] Get all tag success.", + result=await TagManager.get_all_tag(), + ).model_dump(exclude_none=True, by_alias=True), + ), + ) @admin_router.post("", response_model=ResponseData) async def update_tag(post_body: PostTagData) -> JSONResponse: """添加或改动特定标签定义""" if not await TagManager.update_tag_by_name(post_body): - return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="[Tag] Update tag failed", - result={}, - ).model_dump(exclude_none=True, by_alias=True)) - return JSONResponse(status_code=status.HTTP_200_OK, content=ResponseData( - code=status.HTTP_200_OK, - message="[Tag] Update tag success.", - result={}, - ).model_dump(exclude_none=True, by_alias=True)) + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="[Tag] Update tag failed", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), + ) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=jsonable_encoder( + ResponseData( + code=status.HTTP_200_OK, + message="[Tag] Update tag success.", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), + ) @admin_router.delete("", response_model=ResponseData) async def delete_tag(post_body: PostTagData) -> JSONResponse: """删除某个标签""" if not await TagManager.get_tag_by_name(post_body.tag): - return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="[Tag] Tag does not exist.", - result={}, - ).model_dump(exclude_none=True, by_alias=True)) + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="[Tag] Tag does not exist.", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), + ) try: await TagManager.delete_tag(post_body) except Exception as e: # noqa: BLE001 - return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message=f"[Tag] Delete tag failed: {e!s}", - result={}, - ).model_dump(exclude_none=True, by_alias=True)) - return JSONResponse(status_code=status.HTTP_200_OK, content=ResponseData( - code=status.HTTP_200_OK, - message="[Tag] Delete tag success.", - result={}, - ).model_dump(exclude_none=True, by_alias=True)) + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message=f"[Tag] Delete tag failed: {e!s}", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), + ) + return JSONResponse( + status_code=status.HTTP_200_OK, + content=jsonable_encoder( + ResponseData( + code=status.HTTP_200_OK, + message="[Tag] Delete tag success.", + result={}, + ).model_dump(exclude_none=True, by_alias=True), + ), + ) diff --git a/apps/routers/user.py b/apps/routers/user.py index 4460cc5d1002b34ccfc7a2f16d21e8ba68942493..01fc885c41a3d9aa692b5c240983a3c00c8d7603 100644 --- a/apps/routers/user.py +++ b/apps/routers/user.py @@ -2,6 +2,7 @@ """用户相关接口""" from fastapi import APIRouter, Depends, Request, status +from fastapi.encoders import jsonable_encoder from fastapi.responses import JSONResponse from apps.common.config import config @@ -29,12 +30,12 @@ async def update_user_info(request: Request, data: UserUpdateRequest) -> JSONRes except ValueError as e: return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, - content={"code": status.HTTP_404_NOT_FOUND, "message": str(e)}, + content=jsonable_encoder({"code": status.HTTP_404_NOT_FOUND, "message": str(e)}), ) return JSONResponse( status_code=status.HTTP_200_OK, - content={"code": status.HTTP_200_OK, "message": "用户信息更新成功"}, + content=jsonable_encoder({"code": status.HTTP_200_OK, "message": "用户信息更新成功"}), ) @@ -46,11 +47,13 @@ async def get_user_info(request: Request) -> JSONResponse: if not user: return JSONResponse( status_code=status.HTTP_404_NOT_FOUND, - content=ResponseData( - code=status.HTTP_404_NOT_FOUND, - message="用户不存在", - result=None, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_404_NOT_FOUND, + message="用户不存在", + result=None, + ).model_dump(exclude_none=True, by_alias=True), + ), ) user_info = UserInfoMsg( @@ -63,11 +66,13 @@ async def get_user_info(request: Request) -> JSONResponse: return JSONResponse( status_code=status.HTTP_200_OK, - content=UserInfoRsp( - code=status.HTTP_200_OK, - message="用户信息获取成功", - result=user_info, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + UserInfoRsp( + code=status.HTTP_200_OK, + message="用户信息获取成功", + result=user_info, + ).model_dump(exclude_none=True, by_alias=True), + ), ) @@ -89,11 +94,13 @@ async def list_user( return JSONResponse( status_code=status.HTTP_200_OK, - content=UserListRsp( - code=status.HTTP_200_OK, - message="用户数据详细信息获取成功", - result=UserListMsg(userInfoList=user_info_list, total=total), - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + UserListRsp( + code=status.HTTP_200_OK, + message="用户数据详细信息获取成功", + result=UserListMsg(userInfoList=user_info_list, total=total), + ).model_dump(exclude_none=True, by_alias=True), + ), ) @@ -110,17 +117,21 @@ async def get_user_tag( except ValueError as e: return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message=str(e), - result=None, - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message=str(e), + result=None, + ).model_dump(exclude_none=True, by_alias=True), + ), ) return JSONResponse( status_code=status.HTTP_200_OK, - content=ResponseData( - code=status.HTTP_200_OK, - message="success", - result=UserTagListResponse(tags=tags), - ).model_dump(exclude_none=True, by_alias=True), + content=jsonable_encoder( + ResponseData( + code=status.HTTP_200_OK, + message="success", + result=UserTagListResponse(tags=tags), + ).model_dump(exclude_none=True, by_alias=True), + ), ) diff --git a/apps/scheduler/executor/agent.py b/apps/scheduler/executor/agent.py index 08b150cbbbb113466010cca6e0176dcf992a21ab..2a96f920f479f4dd0663c8f098e5efd269fffb1f 100644 --- a/apps/scheduler/executor/agent.py +++ b/apps/scheduler/executor/agent.py @@ -3,7 +3,7 @@ import logging import uuid -from typing import TYPE_CHECKING, cast +from typing import cast import anyio from mcp.types import TextContent @@ -11,6 +11,7 @@ from pydantic import Field from apps.constants import AGENT_FINAL_STEP_NAME, AGENT_MAX_RETRY_TIMES, AGENT_MAX_STEPS from apps.models import ExecutorHistory, ExecutorStatus, MCPTools, StepStatus +from apps.models.task import ExecutorCheckpoint from apps.scheduler.executor.base import BaseExecutor from apps.scheduler.mcp_agent.host import MCPHost from apps.scheduler.mcp_agent.plan import MCPPlanner @@ -23,9 +24,6 @@ from apps.services.appcenter import AppCenterManager from apps.services.mcp_service import MCPServiceManager from apps.services.user import UserManager -if TYPE_CHECKING: - from apps.models.task import ExecutorCheckpoint - _logger = logging.getLogger(__name__) class MCPAgentExecutor(BaseExecutor): @@ -49,8 +47,8 @@ class MCPAgentExecutor(BaseExecutor): self._current_tool = None self._tool_list = {} # 初始化MCP Host相关对象 - self._planner = MCPPlanner(self.task, self.llm) - self._host = MCPHost(self.task, self.llm) + self._planner = MCPPlanner(self.task) + self._host = MCPHost(self.task) user = await UserManager.get_user(self.task.metadata.userId) if not user: err = "[MCPAgentExecutor] 用户不存在: %s" @@ -60,6 +58,20 @@ class MCPAgentExecutor(BaseExecutor): # 获取历史 await self._load_history() + # 初始化任务状态(如果不存在) + if not self.task.state: + self.task.state = ExecutorCheckpoint( + taskId=self.task.metadata.id, + appId=self.agent_id, + executorId="", + executorName="", + executorStatus=ExecutorStatus.INIT, + stepId=uuid.uuid4(), + stepName="", + stepStatus=StepStatus.INIT, + stepType="", + ) + async def load_mcp(self) -> None: """加载MCP服务器列表""" _logger.info("[MCPAgentExecutor] 加载MCP服务器列表") @@ -85,7 +97,7 @@ class MCPAgentExecutor(BaseExecutor): self._mcp_list.append(mcp_service) for tool in await MCPServiceManager.get_mcp_tools(mcp_id): - self._tool_list[tool.id] = tool + self._tool_list[tool.toolName] = tool self._tool_list[AGENT_FINAL_STEP_NAME] = MCPTools( mcpId="", toolName=AGENT_FINAL_STEP_NAME, description="结束流程的工具", @@ -107,8 +119,10 @@ class MCPAgentExecutor(BaseExecutor): if is_first: # 获取第一个输入参数 self._current_tool = self._tool_list[state.stepName] + # 更新host的task引用以确保使用最新的context + self._host.task = self.task self._current_input = await self._host.get_first_input_params( - self._current_tool, self.task.runtime.userInput, self.task, + self._current_tool, self.task.runtime.userInput, ) else: # 获取后续输入参数 @@ -119,14 +133,13 @@ class MCPAgentExecutor(BaseExecutor): params = {} params_description = "" self._current_tool = self._tool_list[state.stepName] - state.currentInput = await self._host.fill_params( + self._current_input = await self._host.fill_params( self._current_tool, self.task.runtime.userInput, - state.currentInput, + self._current_input, state.errorMessage, params, params_description, - self.task.runtime.language, ) self.task.state = state @@ -379,17 +392,17 @@ class MCPAgentExecutor(BaseExecutor): for _ in range(max_retry): try: step = await self._planner.create_next_step(history, list(self._tool_list.values())) - if step.tool_id in self._tool_list: + if step.tool_name in self._tool_list: break except Exception: _logger.exception("[MCPAgentExecutor] 获取下一步失败,重试中...") - if step is None or step.tool_id not in self._tool_list: + if step is None or step.tool_name not in self._tool_list: step = Step( - tool_id=AGENT_FINAL_STEP_NAME, + tool_name=AGENT_FINAL_STEP_NAME, description=AGENT_FINAL_STEP_NAME, ) state.stepId = uuid.uuid4() - state.stepName = step.tool_id + state.stepName = step.tool_name state.stepStatus = StepStatus.INIT else: # 没有下一步了,结束流程 @@ -474,6 +487,7 @@ class MCPAgentExecutor(BaseExecutor): thinking_started = False async for chunk in self._planner.generate_answer( await self._host.assemble_memory(self.task.runtime, self.task.context), + self.llm, ): if chunk.reasoning_content: if not thinking_started: diff --git a/apps/scheduler/mcp/prompt.py b/apps/scheduler/mcp/prompt.py index 5bf1ed24bf8bacd511b6272fcac94156c7716e4b..afcd1f440f3baeb86472bfea5046ab7c1af15450 100644 --- a/apps/scheduler/mcp/prompt.py +++ b/apps/scheduler/mcp/prompt.py @@ -47,7 +47,7 @@ MCP_SELECT: dict[LanguageType, str] = { {% for tool in tools %} - - {{ tool.id }}{{tool.name}};{{ tool.description }} + - {{ tool.toolName }}{{tool.toolName}};{{ tool.description }} {% endfor %} - Final结束步骤,当执行到这一步时,\ 表示计划执行结束,所得到的结果将作为最终结果。 @@ -178,7 +178,7 @@ CREATE_PLAN: dict[str, str] = { ## 工具列表 {% for tool in tools %} - - **{{ tool.id }}**: {{tool.name}};{{ tool.description }} + - **{{ tool.toolName }}**: {{tool.toolName}};{{ tool.description }} {% endfor %} - **Final**: 结束步骤,标志计划执行完成 @@ -226,7 +226,7 @@ CREATE_PLAN: dict[str, str] = { ## Tool List {% for tool in tools %} - - **{{ tool.id }}**: {{tool.name}}; {{ tool.description }} + - **{{ tool.toolName }}**: {{tool.toolName}}; {{ tool.description }} {% endfor %} - **Final**: End step, marks plan execution complete diff --git a/apps/scheduler/mcp_agent/base.py b/apps/scheduler/mcp_agent/base.py index 432c6a63634d18f5a87445ae6de4dab761f199be..3c4c9da77c76de47ba87764a3ca4567d53edca8c 100644 --- a/apps/scheduler/mcp_agent/base.py +++ b/apps/scheduler/mcp_agent/base.py @@ -3,7 +3,6 @@ import logging -from apps.llm import LLM from apps.models import LanguageType from apps.schemas.task import TaskData @@ -14,29 +13,11 @@ class MCPBase: """MCP基类""" _user_id: str - _llm: LLM - _goal: str + task: TaskData _language: LanguageType - def __init__(self, task: TaskData, llm: LLM) -> None: + def __init__(self, task: TaskData) -> None: """初始化MCP基类""" self._user_id = task.metadata.userId - self._llm = llm - self._goal = task.runtime.userInput + self.task = task self._language = task.runtime.language - - async def get_reasoning_result(self, prompt: str) -> str: - """获取推理结果""" - # 调用推理大模型 - message = [ - {"role": "system", "content": prompt}, - {"role": "user", "content": "Please provide a JSON response based on the above information and schema."}, - ] - result = "" - async for chunk in self._llm.call( - message, - streaming=False, - ): - result += chunk.content or "" - - return result diff --git a/apps/scheduler/mcp_agent/host.py b/apps/scheduler/mcp_agent/host.py index 3fbe09ac1b70c8e1867a4790c7ed897652481c19..e7a33f277c71a61918c0a48ed89013533b739fe7 100644 --- a/apps/scheduler/mcp_agent/host.py +++ b/apps/scheduler/mcp_agent/host.py @@ -46,17 +46,17 @@ class MCPHost(MCPBase): ) async def get_first_input_params( - self, mcp_tool: MCPTools, current_goal: str, runtime: TaskRuntime, context: list[ExecutorHistory], + self, mcp_tool: MCPTools, current_goal: str, ) -> dict[str, Any]: """填充工具参数""" # 更清晰的输入指令,这样可以调用generate - prompt = _env.from_string(GEN_PARAMS[runtime.language]).render( + prompt = _env.from_string(GEN_PARAMS[self.task.runtime.language]).render( tool_name=mcp_tool.toolName, tool_description=mcp_tool.description, - goal=self._goal, + goal=self.task.runtime.userInput, current_goal=current_goal, input_schema=mcp_tool.inputSchema, - background_info=await self.assemble_memory(runtime, context), + background_info=await self.assemble_memory(self.task.runtime, self.task.context), ) _logger.info("[MCPHost] 填充工具参数: %s", prompt) # 使用json_generator解析结果 @@ -79,15 +79,14 @@ class MCPHost(MCPBase): mcp_tool: MCPTools, current_goal: str, current_input: dict[str, Any], - language: LanguageType, - error_message: str = "", + error_message: str | dict = "", params: dict[str, Any] | None = None, params_description: str = "", ) -> dict[str, Any]: - llm_query = _LLM_QUERY_FIX[language] - prompt = _env.from_string(REPAIR_PARAMS[language]).render( + llm_query = _LLM_QUERY_FIX[self._language] + prompt = _env.from_string(REPAIR_PARAMS[self._language]).render( tool_name=mcp_tool.toolName, - goal=self._goal, + goal=self.task.runtime.userInput, current_goal=current_goal, tool_description=mcp_tool.description, input_schema=mcp_tool.inputSchema, @@ -110,5 +109,5 @@ class MCPHost(MCPBase): {"role": "user", "content": prompt}, {"role": "user", "content": llm_query}, ], - language=language, + language=self._language, ) diff --git a/apps/scheduler/mcp_agent/plan.py b/apps/scheduler/mcp_agent/plan.py index bb6e798202b952212a882c685f90371bd8fab4fc..dd1ad17d1d5e46a378fff18858ec432c1a48759a 100644 --- a/apps/scheduler/mcp_agent/plan.py +++ b/apps/scheduler/mcp_agent/plan.py @@ -9,7 +9,7 @@ from typing import Any from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment -from apps.llm import json_generator +from apps.llm import LLM, json_generator from apps.models import MCPTools from apps.scheduler.mcp_agent.base import MCPBase from apps.scheduler.mcp_agent.prompt import ( @@ -49,7 +49,7 @@ class MCPPlanner(MCPBase): async def get_flow_name(self) -> FlowName: """获取当前流程的名称""" template = _env.from_string(GENERATE_FLOW_NAME[self._language]) - prompt = template.render(goal=self._goal) + prompt = template.render(goal=self.task.runtime.userInput) result = await json_generator.generate( function=GET_FLOW_NAME_FUNCTION[self._language], @@ -65,11 +65,11 @@ class MCPPlanner(MCPBase): """创建下一步的执行步骤""" # 构建提示词 template = _env.from_string(GEN_STEP[self._language]) - prompt = template.render(goal=self._goal, history=history, tools=tools) + prompt = template.render(goal=self.task.runtime.userInput, history=history, tools=tools) # 获取函数定义并动态设置tool_id的enum function = deepcopy(CREATE_NEXT_STEP_FUNCTION[self._language]) - function["parameters"]["properties"]["tool_id"]["enum"] = [tool.id for tool in tools] + function["parameters"]["properties"]["tool_name"]["enum"] = [tool.toolName for tool in tools] step = await json_generator.generate( function=function, @@ -122,9 +122,9 @@ class MCPPlanner(MCPBase): """判断错误信息是否是参数错误""" tmplate = _env.from_string(IS_PARAM_ERROR[self._language]) prompt = tmplate.render( - goal=self._goal, + goal=self.task.runtime.userInput, history=history, - step_id=tool.id, + step_id=tool.toolName, step_name=tool.toolName, step_description=step_description, input_params=input_params, @@ -171,14 +171,14 @@ class MCPPlanner(MCPBase): language=self._language, ) - async def generate_answer(self, memory: str) -> AsyncGenerator[LLMChunk, None]: + async def generate_answer(self, memory: str, llm: LLM) -> AsyncGenerator[LLMChunk, None]: """生成最终回答,返回LLMChunk""" template = _env.from_string(FINAL_ANSWER[self._language]) prompt = template.render( memory=memory, - goal=self._goal, + goal=self.task.runtime.userInput, ) - async for chunk in self._llm.call( + async for chunk in llm.call( [{"role": "user", "content": prompt}], streaming=True, ): diff --git a/apps/scheduler/mcp_agent/prompt.py b/apps/scheduler/mcp_agent/prompt.py index c4fb46b71f313b71236ee532927a262dba188cd6..390c4be4db1e9593baa55286818ed57f3c7af757 100644 --- a/apps/scheduler/mcp_agent/prompt.py +++ b/apps/scheduler/mcp_agent/prompt.py @@ -96,9 +96,9 @@ CREATE_NEXT_STEP_FUNCTION: dict[LanguageType, dict] = { "parameters": { "type": "object", "properties": { - "tool_id": { + "tool_name": { "type": "string", - "description": "工具ID", + "description": "工具名称", "enum": [], }, "description": { @@ -106,11 +106,11 @@ CREATE_NEXT_STEP_FUNCTION: dict[LanguageType, dict] = { "description": "步骤描述", }, }, - "required": ["tool_id", "description"], + "required": ["tool_name", "description"], }, "examples": [ { - "tool_id": "mcp_tool_1", + "tool_name": "mcp_tool_1", "description": "扫描ip为192.168.1.1的MySQL数据库,端口为3306,用户名为root,密码为password的数据库性能", }, ], @@ -124,9 +124,9 @@ CREATE_NEXT_STEP_FUNCTION: dict[LanguageType, dict] = { "parameters": { "type": "object", "properties": { - "tool_id": { + "tool_name": { "type": "string", - "description": "Tool ID", + "description": "Tool Name", "enum": [], }, "description": { @@ -134,11 +134,11 @@ CREATE_NEXT_STEP_FUNCTION: dict[LanguageType, dict] = { "description": "Step description", }, }, - "required": ["tool_id", "description"], + "required": ["tool_name", "description"], }, "examples": [ { - "tool_id": "mcp_tool_1", + "tool_name": "mcp_tool_1", "description": "Scan MySQL database performance at 192.168.1.1:3306 with user root", }, ], @@ -176,7 +176,7 @@ GEN_STEP: dict[LanguageType, str] = { **可用工具**: {% for tool in tools %} - - **{{tool.id}}**:{{tool.description}} + - **{{tool.toolName}}**:{{tool.description}} {% endfor %} """, ), @@ -215,7 +215,7 @@ GEN_STEP: dict[LanguageType, str] = { **Available Tools**: {% for tool in tools %} - - **{{tool.id}}**: {{tool.description}} + - **{{tool.toolName}}**: {{tool.description}} {% endfor %} """, ), diff --git a/apps/scheduler/pool/loader/app.py b/apps/scheduler/pool/loader/app.py index 59b747d7fb08c1169808be22dbf89de249daf3e7..dbd4483c7127db1b77572099d45fea7e0085e1c1 100644 --- a/apps/scheduler/pool/loader/app.py +++ b/apps/scheduler/pool/loader/app.py @@ -53,11 +53,12 @@ class AppLoader: flow_ids = [app_flow.id for app_flow in metadata.flows] new_flows: list[AppFlow] = [] + flow_loader = FlowLoader() async for flow_file in flow_path.rglob("*.yaml"): if flow_file.stem not in flow_ids: logger.warning("[AppLoader] 工作流 %s 不在元数据中", flow_file) - flow_loader = FlowLoader() - flow = await flow_loader.load(app_id, flow_file.stem) + # 加载工作流,但不进行向量化(通过内部方法) + flow = await flow_loader._load_flow_without_vector(app_id, flow_file.stem) if not flow: err = f"[AppLoader] 工作流 {flow_file} 加载失败" raise ValueError(err) @@ -73,6 +74,8 @@ class AppLoader: ), ) metadata.flows = new_flows + # 所有工作流加载完成后,统一进行一次向量化 + await flow_loader._update_vector(app_id) try: metadata = FlowAppMetadata.model_validate(metadata) except Exception as e: diff --git a/apps/scheduler/pool/loader/call.py b/apps/scheduler/pool/loader/call.py index 9fee833489f9f2d4b3e057b1a3f71a85c220f5d9..57ec5544d8590feb297413a45c7e31c732b4937e 100644 --- a/apps/scheduler/pool/loader/call.py +++ b/apps/scheduler/pool/loader/call.py @@ -53,7 +53,7 @@ class CallLoader: # 更新数据库 call_descriptions = [] for call_id, call in call_metadata.items(): - await session.merge(NodeInfo( + session.add(NodeInfo( id=call_id, name=call.name, description=call.description, diff --git a/apps/scheduler/pool/loader/flow.py b/apps/scheduler/pool/loader/flow.py index c8fa6b739933783e7bfd80533e98e3a3d149941b..ca7215a7c49d57048aa60626dc9c990a89c5235b 100644 --- a/apps/scheduler/pool/loader/flow.py +++ b/apps/scheduler/pool/loader/flow.py @@ -110,8 +110,8 @@ class FlowLoader: return flow_yaml - async def load(self, app_id: uuid.UUID, flow_id: str) -> Flow: - """从文件系统中加载【单个】工作流""" + async def _load_flow_without_vector(self, app_id: uuid.UUID, flow_id: str) -> Flow: + """从文件系统中加载【单个】工作流,但不进行向量化""" logger.info("[FlowLoader] 应用 %s:加载工作流 %s...", flow_id, app_id) # 构建工作流文件路径 @@ -153,9 +153,14 @@ class FlowLoader: debug=flow_config.checkStatus.debug, ), ) + return Flow.model_validate(flow_yaml) + + async def load(self, app_id: uuid.UUID, flow_id: str) -> Flow: + """从文件系统中加载【单个】工作流""" + flow = await self._load_flow_without_vector(app_id, flow_id) # 重新向量化该App的所有工作流 await self._update_vector(app_id) - return Flow.model_validate(flow_yaml) + return flow async def save(self, app_id: uuid.UUID, flow_id: str, flow: Flow) -> None: diff --git a/apps/scheduler/pool/loader/mcp.py b/apps/scheduler/pool/loader/mcp.py index d30c5005f83f0fe81b0e4741850cc549b568dbc3..0b1259e25a2b145578e6259b4c9fc70d94f4fa98 100644 --- a/apps/scheduler/pool/loader/mcp.py +++ b/apps/scheduler/pool/loader/mcp.py @@ -134,6 +134,36 @@ class MCPLoader: await f.aclose() return updated_config + @staticmethod + async def _install_sse_mcp( + mcp_id: str, + mcp_config: MCPServerSSEConfig, + item: MCPServerConfig, + ) -> MCPServerSSEConfig: + """ + 安装 SSE/StreamableHTTP 类型的 MCP + + :param str mcp_id: MCP模板ID + :param MCPServerSSEConfig mcp_config: MCP配置 + :param MCPServerConfig item: 完整的配置对象 + :return: 安装后的配置 + :rtype: MCPServerSSEConfig + """ + logger.info("[Installer] SSE/StreamableHTTP方式的MCP模板,无需安装: %s", mcp_id) + + # 修改 autoInstall 标志 + mcp_config.autoInstall = False + item.mcpServers[mcp_id] = mcp_config + + # 保存配置到文件 + template_config = MCP_PATH / "template" / mcp_id / "config.json" + f = await template_config.open("w+", encoding="utf-8") + config_data = item.model_dump(by_alias=True, exclude_none=True) + await f.write(json.dumps(config_data, indent=4, ensure_ascii=False)) + await f.aclose() + + return mcp_config + @staticmethod async def _handle_tool_vectorization( mcp_id: str, @@ -235,9 +265,10 @@ class MCPLoader: await MCPLoader.update_template_status(mcp_id, MCPInstallStatus.FAILED, postgres) return mcp_config = updated_config + elif isinstance(mcp_config, MCPServerSSEConfig): + mcp_config = await MCPLoader._install_sse_mcp(mcp_id, mcp_config, item) else: - logger.info("[Installer] SSE/StreamableHTTP方式的MCP模板,无需安装: %s", mcp_id) - item.mcpServers[mcp_id].autoInstall = False + logger.warning("[Installer] 未知的MCP类型,跳过安装: %s", mcp_id) tool_list = await MCPLoader._get_template_tool(mcp_id, item) @@ -530,7 +561,7 @@ class MCPLoader: ) await f.aclose() async with postgres.session() as session: - await session.merge(MCPActivated( + session.add(MCPActivated( mcpId=mcp_id, userId=user_id, )) diff --git a/apps/scheduler/scheduler/init.py b/apps/scheduler/scheduler/init.py index c7020f3648b754562686a0a1a568db3c7ce2b169..da35e5f33f387799abaa5d2074cd6c7820d632d7 100644 --- a/apps/scheduler/scheduler/init.py +++ b/apps/scheduler/scheduler/init.py @@ -3,6 +3,7 @@ import logging import uuid +from datetime import UTC, datetime from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment @@ -48,10 +49,13 @@ class InitMixin: id=task_id, userId=user_id, conversationId=conversation_id, + updatedAt=datetime.now(UTC), ), runtime=TaskRuntime( taskId=task_id, authHeader=auth_header, + userInput=self.post_body.question, + language=self.post_body.language, ), state=None, context=[], diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py index 474fd646bf95ab529755014962e5227f8cb3a73a..2c5be573bc910ede231570662dcf0595d5c9aa6a 100644 --- a/apps/schemas/mcp.py +++ b/apps/schemas/mcp.py @@ -121,7 +121,7 @@ class MCPPlan(BaseModel): class Step(BaseModel): """MCP步骤""" - tool_id: str = Field(description="工具ID") + tool_name: str = Field(description="工具名称") description: str = Field(description="步骤描述") diff --git a/apps/schemas/response_data.py b/apps/schemas/response_data.py index 31b87188e1b3cd2e331f2364640a941abcf3078a..2edc2ddc76c2f3e3eeface0ab2cdf4af4aa26342 100644 --- a/apps/schemas/response_data.py +++ b/apps/schemas/response_data.py @@ -94,7 +94,7 @@ class LLMAdminInfo(BaseModel): llm_description: str = Field(default="", alias="llmDescription", description="LLM描述") llm_type: list[LLMType] = Field(default=[], alias="llmType", description="LLM类型") base_url: str = Field(alias="baseUrl", description="API Base URL") - api_key: str = Field(alias="apiKey", description="API Key") + api_key: str | None = Field(default=None, alias="apiKey", description="API Key") model_name: str | None = Field(default=None, alias="modelName", description="模型名称") max_tokens: int = Field(alias="maxTokens", description="最大token数") ctx_length: int = Field(alias="ctxLength", description="上下文长度") diff --git a/apps/services/activity.py b/apps/services/activity.py index 687e9af75bf606bb780d6713ed2996bfb13dea37..5ef0b80c8ac9f2b1ba064bacff13d03724eae39a 100644 --- a/apps/services/activity.py +++ b/apps/services/activity.py @@ -48,7 +48,7 @@ class Activity: if current_active >= MAX_CONCURRENT_TASKS: err = "系统并发已达上限" raise ActivityError(err) - await session.merge(SessionActivity(userId=user_id, timestamp=time)) + session.add(SessionActivity(userId=user_id, timestamp=time)) await session.commit() diff --git a/apps/services/appcenter.py b/apps/services/appcenter.py index b54e6e55e94e2f9263ca9a31b6c27bea1cefb878..d7089a5c0547265e90b7ba114679b95d52c60814 100644 --- a/apps/services/appcenter.py +++ b/apps/services/appcenter.py @@ -402,28 +402,32 @@ class AppCenterManager: :param app_id: 应用唯一标识 :return: 更新是否成功 """ - if str(app_id) == "00000000-0000-0000-0000-000000000000": - return - async with postgres.session() as session: - app_usages = list((await session.scalars( - select(UserAppUsage).where(UserAppUsage.userId == user_id), - )).all()) - if not app_usages: - msg = f"[AppCenterManager] 用户不存在: {user_id}" - raise ValueError(msg) + app_usage = (await session.scalars( + select(UserAppUsage).where( + and_( + UserAppUsage.userId == user_id, + UserAppUsage.appId == app_id, + ), + ), + )).one_or_none() - for app_data in app_usages: - if app_data.appId == app_id: - app_data.lastUsed = datetime.now(UTC) - app_data.usageCount += 1 - await session.merge(app_data) - break + if app_usage: + # 存在则更新count和lastUsed + app_usage.lastUsed = datetime.now(UTC) + app_usage.usageCount += 1 + await session.merge(app_usage) else: - app_data = UserAppUsage(userId=user_id, appId=app_id, lastUsed=datetime.now(UTC), usageCount=1) - await session.merge(app_data) + # 不存在则创建新条目 + app_usage = UserAppUsage( + userId=user_id, + appId=app_id, + lastUsed=datetime.now(UTC), + usageCount=1, + ) + session.add(app_usage) + await session.commit() - return @staticmethod diff --git a/apps/services/blacklist.py b/apps/services/blacklist.py index a8024b01a32061189200082d1c4a50835bfda494..a0bd62a7604e5f79941607206d5bea4c8d243943 100644 --- a/apps/services/blacklist.py +++ b/apps/services/blacklist.py @@ -190,7 +190,7 @@ class AbuseManager: reason=reason, ) - await session.merge(new_blacklist) + session.add(new_blacklist) await session.commit() return True diff --git a/apps/services/comment.py b/apps/services/comment.py index 04b312ed5b8724d90257f054ee64443f14247e6c..17816083e3e3267505598435ae41bc0a27062c8c 100644 --- a/apps/services/comment.py +++ b/apps/services/comment.py @@ -68,5 +68,5 @@ class CommentManager: feedbackLink=data.feedback_link, feedbackContent=data.feedback_content, ) - await session.merge(comment_info) + session.add(comment_info) await session.commit() diff --git a/apps/services/mcp_service.py b/apps/services/mcp_service.py index 8dfc50878f660ba777f81dbc74e42ec9f0d7fd5a..7a945798d2e2b69e681cba4279b605d5e1b65016 100644 --- a/apps/services/mcp_service.py +++ b/apps/services/mcp_service.py @@ -352,12 +352,7 @@ class MCPServiceManager: if mcp_info.status != MCPInstallStatus.READY: err = "[MCPServiceManager] MCP服务未准备就绪" raise RuntimeError(err) - await session.merge(MCPActivated( - mcpId=mcp_info.id, - userId=user_id, - )) - await session.commit() - await MCPLoader.user_active_template(user_id, mcp_id, mcp_env) + await MCPLoader.user_active_template(user_id, mcp_id, mcp_env) @staticmethod diff --git a/apps/services/record.py b/apps/services/record.py index 95814ff4df2d6c3cb83ef9c87968dc5d03aebf05..5bf11097370d88037f44252fd4498bb0dc1ee1a7 100644 --- a/apps/services/record.py +++ b/apps/services/record.py @@ -57,7 +57,7 @@ class RecordManager: logger.error("[RecordManager] 对话不存在: %s", conversation_id) return None - await session.merge(PgRecord( + session.add(PgRecord( id=record.id, conversationId=conversation_id, taskId=record.task_id, diff --git a/apps/services/tag.py b/apps/services/tag.py index 77f47f72c9346e5058f6583a04a78ea9c1914875..a9916c60b440eb604be99c8ed8fea5433e13c75c 100644 --- a/apps/services/tag.py +++ b/apps/services/tag.py @@ -66,11 +66,19 @@ class TagManager: :param domain_data: 领域信息 """ async with postgres.session() as session: - tag = Tag( - name=data.tag, - definition=data.description, - ) - await session.merge(tag) + existing_tag = (await session.scalars(select(Tag).where(Tag.name == data.tag).limit(1))).one_or_none() + + if existing_tag: + existing_tag.definition = data.description + existing_tag.updatedAt = datetime.now(tz=UTC) + await session.merge(existing_tag) + else: + tag = Tag( + name=data.tag, + definition=data.description, + ) + session.add(tag) + await session.commit() diff --git a/apps/services/user_tag.py b/apps/services/user_tag.py index 15d5145582d451b559d9900218ac43c345f3c5a8..8e21ae428ad8a361a195dd2a94bdde62cb81bf9c 100644 --- a/apps/services/user_tag.py +++ b/apps/services/user_tag.py @@ -60,8 +60,10 @@ class UserTagManager: ).one_or_none() if not user_domain: + # 创建新记录 user_domain = UserTag(userId=user_id, tag=tag.id, count=1) - await session.merge(user_domain) + session.add(user_domain) else: + # 更新已存在记录 user_domain.count += 1 await session.commit()