From d6f2c654a68bf1dfe33bac100037765d17c26009 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Sat, 26 Jul 2025 15:53:48 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E6=9B=B4=E6=8D=A2=E6=97=B6=E5=8C=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/models/app.py | 7 +++---- apps/models/blacklist.py | 5 ++--- apps/models/comment.py | 5 ++--- apps/models/conversation.py | 5 ++--- apps/models/document.py | 5 ++--- apps/models/flow.py | 7 +++---- apps/models/llm.py | 5 ++--- apps/models/mcp.py | 16 ++++++++++------ apps/models/node.py | 7 +++---- apps/models/record.py | 5 ++--- apps/models/service.py | 7 +++---- apps/models/session.py | 5 ++--- apps/models/tag.py | 7 +++---- apps/models/task.py | 0 tests/routers/test_blacklist.py | 4 ++-- 15 files changed, 41 insertions(+), 49 deletions(-) create mode 100644 apps/models/task.py diff --git a/apps/models/app.py b/apps/models/app.py index aa6fa69d7..b774a5ee0 100644 --- a/apps/models/app.py +++ b/apps/models/app.py @@ -1,9 +1,8 @@ """应用 数据库表""" import uuid -from datetime import datetime +from datetime import UTC, datetime -import pytz from sqlalchemy import BigInteger, Boolean, DateTime, Enum, ForeignKey, Index, String from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import Mapped, mapped_column @@ -29,8 +28,8 @@ class App(Base): """应用ID""" updatedAt: Mapped[datetime] = mapped_column( # noqa: N815 DateTime(timezone=True), - default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), - onupdate=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + default_factory=lambda: datetime.now(tz=UTC), + onupdate=lambda: datetime.now(tz=UTC), index=True, ) """应用更新时间""" diff --git a/apps/models/blacklist.py b/apps/models/blacklist.py index 306f28bd0..c5358062e 100644 --- a/apps/models/blacklist.py +++ b/apps/models/blacklist.py @@ -1,9 +1,8 @@ """黑名单 数据库表结构""" import uuid -from datetime import datetime +from datetime import UTC, datetime -import pytz from sqlalchemy import BigInteger, Boolean, DateTime, ForeignKey, String from sqlalchemy.orm import Mapped, mapped_column @@ -30,7 +29,7 @@ class Blacklist(Base): """举报原因""" updatedAt: Mapped[DateTime] = mapped_column( # noqa: N815 DateTime, - default_factory=lambda: datetime.now(pytz.timezone("Asia/Shanghai")), + default_factory=lambda: datetime.now(tz=UTC), init=False, ) """更新时间""" diff --git a/apps/models/comment.py b/apps/models/comment.py index 92d6fcce2..85394ce80 100644 --- a/apps/models/comment.py +++ b/apps/models/comment.py @@ -1,9 +1,8 @@ """评论 数据库表""" import uuid -from datetime import datetime +from datetime import UTC, datetime -import pytz from sqlalchemy import ARRAY, BigInteger, DateTime, Enum, ForeignKey, String from sqlalchemy.orm import Mapped, mapped_column @@ -31,6 +30,6 @@ class Comment(Base): feedbackContent: Mapped[str] = mapped_column(String(1000)) # noqa: N815 """投诉内容""" createdAt: Mapped[datetime] = mapped_column( # noqa: N815 - DateTime(timezone=True), default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + DateTime(timezone=True), default_factory=lambda: datetime.now(tz=UTC), ) """评论创建时间""" diff --git a/apps/models/conversation.py b/apps/models/conversation.py index caa80b5bf..d3522c834 100644 --- a/apps/models/conversation.py +++ b/apps/models/conversation.py @@ -1,10 +1,9 @@ """对话 数据库表""" import uuid -from datetime import datetime +from datetime import UTC, datetime from enum import Enum as PyEnum -import pytz from sqlalchemy import BigInteger, Boolean, DateTime, Enum, ForeignKey, String from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import Mapped, mapped_column @@ -29,7 +28,7 @@ class Conversation(Base): ) """对话ID""" createdAt: Mapped[datetime] = mapped_column( # noqa: N815 - DateTime(timezone=True), default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + DateTime(timezone=True), default_factory=lambda: datetime.now(tz=UTC), ) """对话创建时间""" isTemporary: Mapped[bool] = mapped_column(Boolean, default=False) # noqa: N815 diff --git a/apps/models/document.py b/apps/models/document.py index 3ea18f394..5f1419569 100644 --- a/apps/models/document.py +++ b/apps/models/document.py @@ -1,9 +1,8 @@ """文件 数据库表""" import uuid -from datetime import datetime +from datetime import UTC, datetime -import pytz from sqlalchemy import DateTime, Float, ForeignKey, Index, String from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import Mapped, mapped_column @@ -31,7 +30,7 @@ class Document(Base): """文件的ID""" createdAt: Mapped[datetime] = mapped_column( # noqa: N815 DateTime(timezone=True), - default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + default_factory=lambda: datetime.now(tz=UTC), ) """文件的创建时间""" __table_args__ = ( diff --git a/apps/models/flow.py b/apps/models/flow.py index c5c372138..188ce2c54 100644 --- a/apps/models/flow.py +++ b/apps/models/flow.py @@ -1,9 +1,8 @@ """Flow 数据库表""" import uuid -from datetime import datetime +from datetime import UTC, datetime -import pytz from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import Mapped, mapped_column @@ -33,8 +32,8 @@ class Flow(Base): """Flow的ID""" updatedAt: Mapped[datetime] = mapped_column( # noqa: N815 DateTime(timezone=True), - default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), - onupdate=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + default_factory=lambda: datetime.now(tz=UTC), + onupdate=lambda: datetime.now(tz=UTC), ) """Flow的更新时间""" __table_args__ = ( diff --git a/apps/models/llm.py b/apps/models/llm.py index 02abe04f8..dc0216c6b 100644 --- a/apps/models/llm.py +++ b/apps/models/llm.py @@ -1,8 +1,7 @@ """大模型信息 数据库表""" -from datetime import datetime +from datetime import UTC, datetime -import pytz from sqlalchemy import BigInteger, DateTime, ForeignKey, Integer, String from sqlalchemy.orm import Mapped, mapped_column @@ -32,7 +31,7 @@ class LLMData(Base): """LLM最大Token数量""" createdAt: Mapped[DateTime] = mapped_column( # noqa: N815 DateTime, - default_factory=lambda: datetime.now(pytz.timezone("Asia/Shanghai")), + default_factory=lambda: datetime.now(tz=UTC), init=False, ) """添加LLM的时间""" diff --git a/apps/models/mcp.py b/apps/models/mcp.py index b3454c879..98f2cb167 100644 --- a/apps/models/mcp.py +++ b/apps/models/mcp.py @@ -7,7 +7,7 @@ from hashlib import shake_128 from typing import Any from sqlalchemy import BigInteger, DateTime, Enum, ForeignKey, String -from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.dialects.postgresql import JSONB, UUID from sqlalchemy.orm import Mapped, mapped_column from .base import Base @@ -39,11 +39,11 @@ class MCPInfo(Base): """MCP 概述""" description: Mapped[str] = mapped_column(String(65535)) """MCP 描述""" - author: Mapped[str] = mapped_column(ForeignKey("framework_user.userSub")) + author: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.userSub")) """MCP 创建者""" config: Mapped[dict[str, Any]] = mapped_column(JSONB) """MCP 配置""" - id: Mapped[str] = mapped_column(String(255), primary_key=True) + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default_factory=uuid.uuid4) """MCP ID""" updatedAt: Mapped[datetime] = mapped_column( # noqa: N815 DateTime(timezone=True), @@ -61,9 +61,11 @@ class MCPActivated(Base): """MCP 激活用户""" __tablename__ = "framework_mcp_activated" - mcpId: Mapped[str] = mapped_column(ForeignKey("framework_mcp.id"), index=True, unique=True) # noqa: N815 + mcpId: Mapped[uuid.UUID] = mapped_column( # noqa: N815 + UUID(as_uuid=True), ForeignKey("framework_mcp.id"), index=True, unique=True, + ) """MCP ID""" - userSub: Mapped[str] = mapped_column(ForeignKey("framework_user.userSub")) # noqa: N815 + userSub: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.userSub")) # noqa: N815 """用户ID""" id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False) """主键ID""" @@ -73,7 +75,9 @@ class MCPTools(Base): """MCP 工具""" __tablename__ = "framework_mcp_tools" - mcpId: Mapped[str] = mapped_column(ForeignKey("framework_mcp.id"), index=True) # noqa: N815 + mcpId: Mapped[uuid.UUID] = mapped_column( # noqa: N815 + UUID(as_uuid=True), ForeignKey("framework_mcp.id"), index=True, + ) """MCP ID""" toolName: Mapped[str] = mapped_column(String(255)) # noqa: N815 """MCP 工具名称""" diff --git a/apps/models/node.py b/apps/models/node.py index 9303fd868..fe4e607fd 100644 --- a/apps/models/node.py +++ b/apps/models/node.py @@ -1,10 +1,9 @@ """节点 数据库表""" import uuid -from datetime import datetime +from datetime import UTC, datetime from typing import Any -import pytz from sqlalchemy import DateTime, ForeignKey, String from sqlalchemy.dialects.postgresql import JSONB, UUID from sqlalchemy.orm import Mapped, mapped_column @@ -34,8 +33,8 @@ class NodeInfo(Base): """节点ID""" updatedAt: Mapped[datetime] = mapped_column( # noqa: N815 DateTime(timezone=True), - default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), - onupdate=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + default_factory=lambda: datetime.now(tz=UTC), + onupdate=lambda: datetime.now(tz=UTC), ) """节点更新时间""" diff --git a/apps/models/record.py b/apps/models/record.py index 8c385b1e9..ee2a352a6 100644 --- a/apps/models/record.py +++ b/apps/models/record.py @@ -1,10 +1,9 @@ """问答对 数据库表""" import uuid -from datetime import datetime +from datetime import UTC, datetime from typing import Any -import pytz from sqlalchemy import VARCHAR, BigInteger, DateTime, Float, ForeignKey, Integer, String from sqlalchemy.dialects.postgresql import JSONB, UUID from sqlalchemy.orm import Mapped, mapped_column @@ -31,7 +30,7 @@ class Record(Base): ) """问答对ID""" createdAt: Mapped[datetime] = mapped_column( # noqa: N815 - DateTime(timezone=True), default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + DateTime(timezone=True), default_factory=lambda: datetime.now(tz=UTC), ) """问答对创建时间""" diff --git a/apps/models/service.py b/apps/models/service.py index c23d15145..8f51f8665 100644 --- a/apps/models/service.py +++ b/apps/models/service.py @@ -1,9 +1,8 @@ """插件 数据库表""" import uuid -from datetime import datetime +from datetime import UTC, datetime -import pytz from sqlalchemy import BigInteger, DateTime, Enum, ForeignKey, String from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import Mapped, mapped_column @@ -25,8 +24,8 @@ class Service(Base): """插件作者""" updatedAt: Mapped[datetime] = mapped_column( # noqa: N815 DateTime(timezone=True), - default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), - onupdate=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + default_factory=lambda: datetime.now(tz=UTC), + onupdate=lambda: datetime.now(tz=UTC), ) """插件更新时间""" id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default_factory=uuid.uuid4) diff --git a/apps/models/session.py b/apps/models/session.py index 795d22953..87174a55c 100644 --- a/apps/models/session.py +++ b/apps/models/session.py @@ -1,10 +1,9 @@ """会话相关 数据库表""" import secrets -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta from enum import Enum as PyEnum -import pytz from sqlalchemy import BigInteger, DateTime, Enum, ForeignKey, String from sqlalchemy.orm import Mapped, mapped_column @@ -28,7 +27,7 @@ class Session(Base): """用户名""" validUntil: Mapped[datetime] = mapped_column( # noqa: N815 DateTime(timezone=True), - default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")) + timedelta(days=30), + default_factory=lambda: datetime.now(tz=UTC) + timedelta(days=30), ) """有效期""" id: Mapped[str] = mapped_column(String(255), primary_key=True, default_factory=lambda: secrets.token_hex(16)) diff --git a/apps/models/tag.py b/apps/models/tag.py index 22238fe0a..19ca8f58b 100644 --- a/apps/models/tag.py +++ b/apps/models/tag.py @@ -1,8 +1,7 @@ """用户标签 数据库表""" -from datetime import datetime +from datetime import UTC, datetime -import pytz from sqlalchemy import BigInteger, DateTime, String from sqlalchemy.orm import Mapped, mapped_column @@ -21,7 +20,7 @@ class Tag(Base): """标签定义""" updatedAt: Mapped[datetime] = mapped_column( # noqa: N815 DateTime(timezone=True), - default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), - onupdate=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + default_factory=lambda: datetime.now(tz=UTC), + onupdate=lambda: datetime.now(tz=UTC), ) """标签的更新时间""" diff --git a/apps/models/task.py b/apps/models/task.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/routers/test_blacklist.py b/tests/routers/test_blacklist.py index 3cc6b51b0..a5f1a4546 100644 --- a/tests/routers/test_blacklist.py +++ b/tests/routers/test_blacklist.py @@ -7,7 +7,7 @@ from unittest import TestCase from fastapi.testclient import TestClient from fastapi import status, FastAPI, Request -from apps.routers.blacklist import router +from apps.routers.blacklist import admin_router from apps.dependency import verify_csrf_token, get_current_user from apps.schemas.user import User @@ -24,7 +24,7 @@ class TestBlacklistRouter(TestCase): @classmethod def setUpClass(cls): app = FastAPI() - app.include_router(router) + app.include_router(admin_router) app.dependency_overrides[verify_csrf_token] = mock_csrf_token app.dependency_overrides[get_current_user] = mock_get_user cls.client = TestClient(app) -- Gitee From ce2967bc133fff8ca131189876f9ca548fd5d040 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Sat, 26 Jul 2025 15:56:26 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E6=97=B6=E5=8C=BA?= =?UTF-8?q?=E3=80=81=E4=BF=AE=E6=94=B9MCP=20Manager=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/routers/mcp_service.py | 2 +- apps/scheduler/pool/loader/mcp.py | 27 ++-- apps/scheduler/pool/mcp/default.py | 4 +- apps/scheduler/pool/mcp/pool.py | 2 +- apps/schemas/mcp.py | 2 +- apps/services/activity.py | 5 +- apps/services/appcenter.py | 6 +- apps/services/blacklist.py | 2 +- apps/services/comment.py | 2 +- apps/services/conversation.py | 15 +- apps/services/mcp_service.py | 225 ++++++++++++++++++----------- 11 files changed, 176 insertions(+), 116 deletions(-) diff --git a/apps/routers/mcp_service.py b/apps/routers/mcp_service.py index 397d90444..0581b5b63 100644 --- a/apps/routers/mcp_service.py +++ b/apps/routers/mcp_service.py @@ -173,7 +173,7 @@ async def get_service_detail( indent=4, ensure_ascii=False, ), - mcpType=config.type, + mcpType=config.mcp_type, ) else: # 组装详情所需信息 diff --git a/apps/scheduler/pool/loader/mcp.py b/apps/scheduler/pool/loader/mcp.py index 932d8272c..b74e7be4f 100644 --- a/apps/scheduler/pool/loader/mcp.py +++ b/apps/scheduler/pool/loader/mcp.py @@ -6,6 +6,7 @@ import base64 import json import logging import shutil +import uuid from hashlib import shake_128 import asyncer @@ -120,7 +121,7 @@ class MCPLoader(metaclass=SingletonMeta): @staticmethod - async def init_one_template(mcp_id: str, config: MCPServerConfig) -> None: + async def init_one_template(mcp_id: uuid.UUID, config: MCPServerConfig) -> None: """ 初始化单个MCP模板 @@ -132,7 +133,7 @@ class MCPLoader(metaclass=SingletonMeta): await MCPLoader._insert_template_db(mcp_id, config) # 检查目录 - template_path = MCP_PATH / "template" / mcp_id + template_path = MCP_PATH / "template" / str(mcp_id) await Path.mkdir(template_path, parents=True, exist_ok=True) # 安装MCP模板 @@ -191,12 +192,12 @@ class MCPLoader(metaclass=SingletonMeta): """ # 创建客户端 if ( - (config.type == MCPType.STDIO and isinstance(config.config, MCPServerStdioConfig)) - or (config.type == MCPType.SSE and isinstance(config.config, MCPServerSSEConfig)) + (config.mcp_type == MCPType.STDIO and isinstance(config.config, MCPServerStdioConfig)) + or (config.mcp_type == MCPType.SSE and isinstance(config.config, MCPServerSSEConfig)) ): client = MCPClient() else: - err = f"MCP {mcp_id}:未知的MCP服务类型“{config.type}”" + err = f"MCP {mcp_id}:未知的MCP服务类型“{config.mcp_type}”" logger.error(err) raise ValueError(err) @@ -217,7 +218,7 @@ class MCPLoader(metaclass=SingletonMeta): return tool_list @staticmethod - async def _insert_template_db(mcp_id: str, config: MCPServerConfig) -> None: + async def _insert_template_db(mcp_id: uuid.UUID, config: MCPServerConfig) -> None: """插入单个MCP Server模板信息到数据库""" mcp_collection = MongoDB().get_collection("mcp") await mcp_collection.update_one( @@ -227,7 +228,7 @@ class MCPLoader(metaclass=SingletonMeta): _id=mcp_id, name=config.name, description=config.description, - type=config.type, + type=config.mcp_type, author=config.author, ).model_dump(by_alias=True, exclude_none=True), }, @@ -311,7 +312,7 @@ class MCPLoader(metaclass=SingletonMeta): await LanceDB().create_index("mcp_tool") @staticmethod - async def save_one(mcp_id: str, config: MCPServerConfig) -> None: + async def save_one(mcp_id: uuid.UUID, config: MCPServerConfig) -> None: """ 保存单个MCP模板的配置文件(``config.json``文件和``icon.png``文件) @@ -451,7 +452,15 @@ class MCPLoader(metaclass=SingletonMeta): :param list[str] deleted_mcp_list: 被删除的MCP列表 :return: 无 """ - # 从MongoDB中移除 + # 移除Info + async with postgres.session() as session: + for mcp_id in deleted_mcp_list: + mcp_info = (await session.scalars(select(MCPInfo).where(MCPInfo.id == mcp_id))).one_or_none() + if mcp_info: + await session.delete(mcp_info) + + + mcp_collection = MongoDB().get_collection("mcp") mcp_service_list = await mcp_collection.find( {"_id": {"$in": deleted_mcp_list}}, diff --git a/apps/scheduler/pool/mcp/default.py b/apps/scheduler/pool/mcp/default.py index 4f29ccaee..f16b964bf 100644 --- a/apps/scheduler/pool/mcp/default.py +++ b/apps/scheduler/pool/mcp/default.py @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) DEFAULT_STDIO = MCPServerConfig( name="MCP服务_" + uuid.uuid4().hex[:6], description="MCP服务描述", - type=MCPType.STDIO, + mcp_type=MCPType.STDIO, config=MCPServerStdioConfig( command="uvx", args=[ @@ -31,7 +31,7 @@ DEFAULT_STDIO = MCPServerConfig( DEFAULT_SSE = MCPServerConfig( name="MCP服务_" + uuid.uuid4().hex[:6], description="MCP服务描述", - type=MCPType.SSE, + mcp_type=MCPType.SSE, config=MCPServerSSEConfig( url="http://test.domain/sse", env={ diff --git a/apps/scheduler/pool/mcp/pool.py b/apps/scheduler/pool/mcp/pool.py index 9119b901a..ec711ef96 100644 --- a/apps/scheduler/pool/mcp/pool.py +++ b/apps/scheduler/pool/mcp/pool.py @@ -33,7 +33,7 @@ class MCPPool(metaclass=SingletonMeta): config = MCPServerConfig.model_validate_json(await config_path.read_text()) - if config.type in (MCPType.SSE, MCPType.STDIO): + if config.mcp_type in (MCPType.SSE, MCPType.STDIO): client = MCPClient() else: logger.warning("[MCPPool] 用户 %s 的MCP %s 类型错误", user_sub, mcp_id) diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py index e6b98d25b..bdd9294b4 100644 --- a/apps/schemas/mcp.py +++ b/apps/schemas/mcp.py @@ -44,7 +44,7 @@ class MCPServerConfig(BaseModel): name: str = Field(description="MCP 服务器自然语言名称", default="") overview: str = Field(description="MCP 服务器概述", default="") description: str = Field(description="MCP 服务器自然语言描述", default="") - type: MCPType = Field(description="MCP 服务器类型", default=MCPType.STDIO) + mcp_type: MCPType = Field(description="MCP 服务器类型", default=MCPType.STDIO) author: str = Field(description="MCP 服务器上传者", default="") config: MCPServerStdioConfig | MCPServerSSEConfig = Field(description="MCP 服务器配置") diff --git a/apps/services/activity.py b/apps/services/activity.py index 1f7221f60..e378fcef2 100644 --- a/apps/services/activity.py +++ b/apps/services/activity.py @@ -2,9 +2,8 @@ """用户限流""" import uuid -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta -import pytz from sqlalchemy import delete from apps.common.postgres import postgres @@ -24,7 +23,7 @@ class Activity: :param user_sub: 用户实体ID :return: 判断结果,正在提问则返回True """ - time = datetime.now(pytz.timezone("Asia/Shanghai")) + time = datetime.now(tz=UTC) # 检查窗口内总请求数 count = await MongoDB().get_collection("activity").count_documents( diff --git a/apps/services/appcenter.py b/apps/services/appcenter.py index dda29b49b..da5ba28f0 100644 --- a/apps/services/appcenter.py +++ b/apps/services/appcenter.py @@ -247,7 +247,7 @@ class AppCenterManager: msg = f"[AppCenterManager] 应用不存在: {app_id}" raise ValueError(msg) app_obj.isPublished = published - session.add(app_obj) + await session.merge(app_obj) await session.commit() await AppCenterManager._process_app_and_save( @@ -376,11 +376,11 @@ class AppCenterManager: if app_data.appId == app_id: app_data.lastUsed = datetime.now(UTC) app_data.usageCount += 1 - session.add(app_data) + await session.merge(app_data) break else: app_data = UserAppUsage(userSub=user_sub, appId=app_id, lastUsed=datetime.now(UTC), usageCount=1) - session.add(app_data) + await session.merge(app_data) await session.commit() return diff --git a/apps/services/blacklist.py b/apps/services/blacklist.py index 871d03fda..7ae717715 100644 --- a/apps/services/blacklist.py +++ b/apps/services/blacklist.py @@ -182,7 +182,7 @@ class AbuseManager: reason=reason, ) - session.add(new_blacklist) + await session.merge(new_blacklist) await session.commit() return True diff --git a/apps/services/comment.py b/apps/services/comment.py index 27c4fcca8..25b084c3c 100644 --- a/apps/services/comment.py +++ b/apps/services/comment.py @@ -68,5 +68,5 @@ class CommentManager: feedbackLink=data.feedback_link, feedbackContent=data.feedback_content, ) - session.add(comment_info) + await session.merge(comment_info) await session.commit() diff --git a/apps/services/conversation.py b/apps/services/conversation.py index 257e0a315..ade7fb6bf 100644 --- a/apps/services/conversation.py +++ b/apps/services/conversation.py @@ -3,10 +3,9 @@ import logging import uuid -from datetime import datetime +from datetime import UTC, datetime from typing import Any -import pytz from sqlalchemy import and_, func, select from apps.common.postgres import postgres @@ -78,7 +77,7 @@ class ConversationManager: # 使用PostgreSQL实现新建对话,并根据debug和usage进行更新 try: async with postgres.session() as session: - session.add(conv) + await session.merge(conv) await session.commit() await session.refresh(conv) @@ -96,15 +95,15 @@ class ConversationManager: # 假设App模型有last_used和usage_count字段(如没有请根据实际表结构调整) # 这里只做示例,实际字段名和类型请根据实际情况修改 app_obj.usageCount += 1 - app_obj.lastUsed = datetime.now(pytz.timezone("Asia/Shanghai")) - session.add(app_obj) + app_obj.lastUsed = datetime.now(tz=UTC) + await session.merge(app_obj) await session.commit() else: - session.add(UserAppUsage( + await session.merge(UserAppUsage( userSub=user_sub, appId=app_id, usageCount=1, - lastUsed=datetime.now(pytz.timezone("Asia/Shanghai")), + lastUsed=datetime.now(tz=UTC), )) await session.commit() return conv @@ -131,7 +130,7 @@ class ConversationManager: return False for key, value in data.items(): setattr(conv, key, value) - session.add(conv) + await session.merge(conv) await session.commit() return True diff --git a/apps/services/mcp_service.py b/apps/services/mcp_service.py index 80abf5f48..07a79206e 100644 --- a/apps/services/mcp_service.py +++ b/apps/services/mcp_service.py @@ -1,37 +1,40 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. -"""MCP服务管理器""" +"""MCP服务(插件中心)管理器""" import logging -import random import re +import uuid from typing import Any from fastapi import UploadFile from PIL import Image -from sqids.sqids import Sqids +from sqlalchemy import and_, or_, select +from apps.common.postgres import postgres from apps.constants import ( ALLOWED_ICON_MIME_TYPES, ICON_PATH, SERVICE_PAGE_SIZE, ) +from apps.models.mcp import ( + MCPActivated, + MCPInfo, + MCPInstallStatus, + MCPTools, + MCPType, +) from apps.scheduler.pool.loader.mcp import MCPLoader from apps.scheduler.pool.mcp.pool import MCPPool from apps.schemas.enum_var import SearchType from apps.schemas.mcp import ( - MCPCollection, - MCPInstallStatus, MCPServerConfig, MCPServerSSEConfig, MCPServerStdioConfig, - MCPTool, - MCPType, ) from apps.schemas.request_data import UpdateMCPServiceRequest from apps.schemas.response_data import MCPServiceCardItem logger = logging.getLogger(__name__) -sqids = Sqids(min_length=6) MCP_ICON_PATH = ICON_PATH / "mcp" @@ -47,10 +50,15 @@ class MCPServiceManager: :param str mcp_id: MCP服务ID :return: 是否激活 """ - mongo = MongoDB() - mcp_collection = mongo.get_collection("mcp") - mcp_list = await mcp_collection.find({"_id": mcp_id}, {"activated": True}).to_list(None) - return any(user_sub in db_item.get("activated", []) for db_item in mcp_list) + async with postgres.session() as session: + mcp_info = (await session.scalars(select(MCPActivated).where( + and_( + MCPActivated.mcpId == mcp_id, + MCPActivated.userSub == user_sub, + ), + ))).one_or_none() + return bool(mcp_info) + @staticmethod async def get_service_status(mcp_id: str) -> MCPInstallStatus: @@ -60,16 +68,12 @@ class MCPServiceManager: :param str mcp_id: MCP服务ID :return: MCP服务状态 """ - mongo = MongoDB() - mcp_collection = mongo.get_collection("mcp") - mcp_list = await mcp_collection.find({"_id": mcp_id}, {"status": True}).to_list(None) - for db_item in mcp_list: - status = db_item.get("status") - if MCPInstallStatus.READY.value == status: - return MCPInstallStatus.READY - if MCPInstallStatus.INSTALLING.value == status: - return MCPInstallStatus.INSTALLING - return MCPInstallStatus.FAILED + async with postgres.session() as session: + mcp_info = (await session.scalars(select(MCPInfo).where(MCPInfo.id == mcp_id))).one_or_none() + if mcp_info: + return mcp_info.status + return MCPInstallStatus.FAILED + @staticmethod async def fetch_mcp_services( @@ -87,8 +91,7 @@ class MCPServiceManager: :param page: int: 页码 :return: MCP服务列表 """ - filters = MCPServiceManager._build_filters(search_type, keyword) - mcpservice_pools = await MCPServiceManager._search_mcpservice(filters, page) + mcpservice_pools = await MCPServiceManager._search_mcpservice(search_type, keyword, page) return [ MCPServiceCardItem( mcpserviceId=item.id, @@ -102,21 +105,18 @@ class MCPServiceManager: for item in mcpservice_pools ] + @staticmethod - async def get_mcp_service(mcpservice_id: str) -> MCPCollection: + async def get_mcp_service(mcpservice_id: str) -> MCPInfo | None: """ 获取MCP服务详细信息 :param mcpservice_id: str: MCP服务ID :return: MCP服务详细信息 """ - # 验证用户权限 - mcpservice_collection = MongoDB().get_collection("mcp") - db_service = await mcpservice_collection.find_one({"_id": mcpservice_id}) - if not db_service: - msg = "[MCPServiceManager] MCP服务未找到" - raise RuntimeError(msg) - return MCPCollection.model_validate(db_service) + async with postgres.session() as session: + return (await session.scalars(select(MCPInfo).where(MCPInfo.id == mcpservice_id))).one_or_none() + @staticmethod async def get_mcp_config(mcpservice_id: str) -> tuple[MCPServerConfig, str]: @@ -130,31 +130,27 @@ class MCPServiceManager: config = await MCPLoader.get_config(mcpservice_id) return config, icon + @staticmethod async def get_service_tools( service_id: str, - ) -> list[MCPTool]: + ) -> list[MCPTools]: """ 获取MCP可以用工具 :param service_id: str: MCP服务ID :return: MCP工具详细信息列表 """ - # 获取服务名称 - service_collection = MongoDB().get_collection("mcp") - data = await service_collection.find({"_id": service_id}, {"tools": True}).to_list(None) - result = [] - for item in data: - for tool in item.get("tools", {}): - tool_data = MCPTool.model_validate(tool) - result.append(tool_data) - return result + async with postgres.session() as session: + return list((await session.scalars(select(MCPTools).where(MCPTools.mcpId == service_id))).all()) + @staticmethod async def _search_mcpservice( - search_conditions: dict[str, Any], + search_type: SearchType, + keyword: str | None, page: int, - ) -> list[MCPCollection]: + ) -> list[MCPInfo]: """ 基于输入条件搜索MCP服务 @@ -162,44 +158,53 @@ class MCPServiceManager: :param page: int: 页码 :return: MCP列表 """ - mcpservice_collection = MongoDB().get_collection("mcp") # 分页查询 skip = (page - 1) * SERVICE_PAGE_SIZE - db_mcpservices = await mcpservice_collection.find(search_conditions).skip(skip).limit( - SERVICE_PAGE_SIZE, - ).to_list() + + async with postgres.session() as session: + if not keyword: + result = list( + (await session.scalars(select(MCPInfo).offset(skip).limit(SERVICE_PAGE_SIZE))).all(), + ) + elif search_type == SearchType.ALL: + result = list( + (await session.scalars(select(MCPInfo).where( + or_( + MCPInfo.name.like(f"%{keyword}%"), + MCPInfo.description.like(f"%{keyword}%"), + MCPInfo.author.like(f"%{keyword}%"), + ), + ).offset(skip).limit(SERVICE_PAGE_SIZE))).all(), + ) + elif search_type == SearchType.NAME: + result = list( + (await session.scalars( + select(MCPInfo).where(MCPInfo.name.like(f"%{keyword}%")).offset(skip).limit(SERVICE_PAGE_SIZE), + )).all(), + ) + elif search_type == SearchType.DESCRIPTION: + result = list( + (await session.scalars( + select(MCPInfo).where(MCPInfo.description.like(f"%{keyword}%")).offset(skip).limit(SERVICE_PAGE_SIZE), + )).all(), + ) + elif search_type == SearchType.AUTHOR: + result = list( + (await session.scalars( + select(MCPInfo).where(MCPInfo.author.like(f"%{keyword}%")).offset(skip).limit(SERVICE_PAGE_SIZE), + )).all(), + ) + # 如果未找到,返回空列表 - if not db_mcpservices: - logger.warning("[MCPServiceManager] 没有找到符合条件的MCP服务: %s", search_conditions) + if not result: + logger.warning("[MCPServiceManager] 没有找到符合条件的MCP服务: %s", search_type) return [] # 将数据库中的MCP服务转换为对象 - return [MCPCollection.model_validate(db_mcpservice) for db_mcpservice in db_mcpservices] - - @staticmethod - def _build_filters( - search_type: SearchType, - keyword: str | None, - ) -> dict[str, Any]: - if not keyword: - return {} - - if search_type == SearchType.ALL: - base_filters = {"$or": [ - {"name": {"$regex": keyword, "$options": "i"}}, - {"description": {"$regex": keyword, "$options": "i"}}, - {"author": {"$regex": keyword, "$options": "i"}}, - ]} - elif search_type == SearchType.NAME: - base_filters = {"name": {"$regex": keyword, "$options": "i"}} - elif search_type == SearchType.DESCRIPTION: - base_filters = {"description": {"$regex": keyword, "$options": "i"}} - elif search_type == SearchType.AUTHOR: - base_filters = {"author": {"$regex": keyword, "$options": "i"}} - return base_filters + return result @staticmethod - async def create_mcpservice(data: UpdateMCPServiceRequest, user_sub: str) -> str: + async def create_mcpservice(data: UpdateMCPServiceRequest, user_sub: str) -> uuid.UUID: """ 创建MCP服务 @@ -218,23 +223,23 @@ class MCPServiceManager: overview=data.overview, description=data.description, config=config, - type=data.mcp_type, + mcp_type=data.mcp_type, author=user_sub, ) # 检查是否存在相同服务 - mcp_collection = MongoDB().get_collection("mcp") - db_service = await mcp_collection.find_one({"name": mcp_server.name}) - mcp_id = sqids.encode([random.randint(0, 1000000) for _ in range(5)])[:6] # noqa: S311 - if db_service: - mcp_server.name = f"{mcp_server.name}-{mcp_id}" - logger.warning("[MCPServiceManager] 已存在相同名称和描述的MCP服务") + async with postgres.session() as session: + mcp_info = (await session.scalars(select(MCPInfo).where(MCPInfo.name == mcp_server.name))).one_or_none() + if mcp_info: + mcp_server.name = f"{mcp_server.name}-{uuid.uuid4().hex[:6]}" + logger.warning("[MCPServiceManager] 已存在相同ID或名称的MCP服务") # 保存并载入配置 logger.info("[MCPServiceManager] 创建mcp:%s", mcp_server.name) - await MCPLoader.save_one(mcp_id, mcp_server) - await MCPLoader.init_one_template(mcp_id=mcp_id, config=mcp_server) - return mcp_id + await MCPLoader.save_one(mcp_server.id, mcp_server) + await MCPLoader.init_one_template(mcp_id=mcp_server.id, config=mcp_server) + return mcp_server.id + @staticmethod async def update_mcpservice(data: UpdateMCPServiceRequest, user_sub: str) -> str: @@ -267,18 +272,18 @@ class MCPServiceManager: ) if data.mcp_type == MCPType.STDIO else MCPServerSSEConfig.model_validate_json( data.config, ), - type=data.mcp_type, + mcp_type=data.mcp_type, author=user_sub, )) # 返回服务ID return data.service_id + @staticmethod async def delete_mcpservice(service_id: str) -> None: """ 删除MCP服务 - :param user_sub: str: 用户ID :param service_id: str: MCP服务ID :return: 是否删除成功 """ @@ -292,6 +297,7 @@ class MCPServiceManager: {"$pull": {"mcp_service": service_id}}, ) + @staticmethod async def active_mcpservice( user_sub: str, @@ -370,3 +376,50 @@ class MCPServiceManager: image.save(MCP_ICON_PATH / f"{service_id}.png", format="PNG", optimize=True, compress_level=9) return f"/static/mcp/{service_id}.png" + + + @staticmethod + async def is_user_actived(user_sub: str, mcp_id: str) -> bool: + """ + 判断用户是否激活MCP + + :param user_sub: str: 用户ID + :param mcp_id: str: MCP服务ID + :return: 是否激活 + """ + async with postgres.session() as session: + mcp_info = (await session.scalars(select(MCPActivated).where( + and_( + MCPActivated.mcpId == mcp_id, + MCPActivated.userSub == user_sub, + ), + ))).one_or_none() + return bool(mcp_info) + + + @staticmethod + async def query_mcp_tools(mcp_id: str) -> list[MCPTools]: + """ + 查询MCP工具 + + :param mcp_id: str: MCP服务ID + :return: MCP工具列表 + """ + async with postgres.session() as session: + mcp_tools = list((await session.scalars(select(MCPTools).where(MCPTools.mcpId == mcp_id))).all()) + + + @staticmethod + async def add_mcp_template(mcp_id: str, config: MCPServerConfig, tools: list[MCPTools]) -> None: + # 插入MCP表 + async with postgres.session() as session: + await session.merge(MCPInfo( + id=mcp_id, + name=config.name, + description=config.description, + config=config, + overview=config.overview, + mcpType=config.mcp_type, + author=config.author, + status=MCPInstallStatus.INSTALLING, + )) -- Gitee