diff --git a/apps/models/app.py b/apps/models/app.py index aa6fa69d72945af6f8f480142d0e099610a35e7b..b774a5ee0610317d404c8c530dbe2c01a9e57824 100644 --- a/apps/models/app.py +++ b/apps/models/app.py @@ -1,9 +1,8 @@ """应用 数据库表""" import uuid -from datetime import datetime +from datetime import UTC, datetime -import pytz from sqlalchemy import BigInteger, Boolean, DateTime, Enum, ForeignKey, Index, String from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import Mapped, mapped_column @@ -29,8 +28,8 @@ class App(Base): """应用ID""" updatedAt: Mapped[datetime] = mapped_column( # noqa: N815 DateTime(timezone=True), - default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), - onupdate=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + default_factory=lambda: datetime.now(tz=UTC), + onupdate=lambda: datetime.now(tz=UTC), index=True, ) """应用更新时间""" diff --git a/apps/models/blacklist.py b/apps/models/blacklist.py index 306f28bd0055438bb0a5966ae9b55cffda0cd6fa..c5358062ebec4bdd0b813e74d2e87d71cabbe507 100644 --- a/apps/models/blacklist.py +++ b/apps/models/blacklist.py @@ -1,9 +1,8 @@ """黑名单 数据库表结构""" import uuid -from datetime import datetime +from datetime import UTC, datetime -import pytz from sqlalchemy import BigInteger, Boolean, DateTime, ForeignKey, String from sqlalchemy.orm import Mapped, mapped_column @@ -30,7 +29,7 @@ class Blacklist(Base): """举报原因""" updatedAt: Mapped[DateTime] = mapped_column( # noqa: N815 DateTime, - default_factory=lambda: datetime.now(pytz.timezone("Asia/Shanghai")), + default_factory=lambda: datetime.now(tz=UTC), init=False, ) """更新时间""" diff --git a/apps/models/comment.py b/apps/models/comment.py index 92d6fcce29fbbad4c8205ab2950d73beba29050a..85394ce80778b4e7c1ff687adfaf54369fbea84f 100644 --- a/apps/models/comment.py +++ b/apps/models/comment.py @@ -1,9 +1,8 @@ """评论 数据库表""" import uuid -from datetime import datetime +from datetime import UTC, datetime -import pytz from sqlalchemy import ARRAY, BigInteger, DateTime, Enum, ForeignKey, String from sqlalchemy.orm import Mapped, mapped_column @@ -31,6 +30,6 @@ class Comment(Base): feedbackContent: Mapped[str] = mapped_column(String(1000)) # noqa: N815 """投诉内容""" createdAt: Mapped[datetime] = mapped_column( # noqa: N815 - DateTime(timezone=True), default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + DateTime(timezone=True), default_factory=lambda: datetime.now(tz=UTC), ) """评论创建时间""" diff --git a/apps/models/conversation.py b/apps/models/conversation.py index caa80b5bf0a1b71d01000e562166e95602806592..d3522c8341374c564c38a6c83593eec4104f2119 100644 --- a/apps/models/conversation.py +++ b/apps/models/conversation.py @@ -1,10 +1,9 @@ """对话 数据库表""" import uuid -from datetime import datetime +from datetime import UTC, datetime from enum import Enum as PyEnum -import pytz from sqlalchemy import BigInteger, Boolean, DateTime, Enum, ForeignKey, String from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import Mapped, mapped_column @@ -29,7 +28,7 @@ class Conversation(Base): ) """对话ID""" createdAt: Mapped[datetime] = mapped_column( # noqa: N815 - DateTime(timezone=True), default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + DateTime(timezone=True), default_factory=lambda: datetime.now(tz=UTC), ) """对话创建时间""" isTemporary: Mapped[bool] = mapped_column(Boolean, default=False) # noqa: N815 diff --git a/apps/models/document.py b/apps/models/document.py index 3ea18f394f53a04794a8cd4254afa93d32cf203c..5f1419569ca794fe8563f1fa8e2902c6e66bbc06 100644 --- a/apps/models/document.py +++ b/apps/models/document.py @@ -1,9 +1,8 @@ """文件 数据库表""" import uuid -from datetime import datetime +from datetime import UTC, datetime -import pytz from sqlalchemy import DateTime, Float, ForeignKey, Index, String from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import Mapped, mapped_column @@ -31,7 +30,7 @@ class Document(Base): """文件的ID""" createdAt: Mapped[datetime] = mapped_column( # noqa: N815 DateTime(timezone=True), - default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + default_factory=lambda: datetime.now(tz=UTC), ) """文件的创建时间""" __table_args__ = ( diff --git a/apps/models/flow.py b/apps/models/flow.py index c5c372138224a65fcccce2b9ee493eab62394f4b..188ce2c546da35b7f555170bb371a1387f76dff5 100644 --- a/apps/models/flow.py +++ b/apps/models/flow.py @@ -1,9 +1,8 @@ """Flow 数据库表""" import uuid -from datetime import datetime +from datetime import UTC, datetime -import pytz from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import Mapped, mapped_column @@ -33,8 +32,8 @@ class Flow(Base): """Flow的ID""" updatedAt: Mapped[datetime] = mapped_column( # noqa: N815 DateTime(timezone=True), - default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), - onupdate=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + default_factory=lambda: datetime.now(tz=UTC), + onupdate=lambda: datetime.now(tz=UTC), ) """Flow的更新时间""" __table_args__ = ( diff --git a/apps/models/llm.py b/apps/models/llm.py index 02abe04f8746dffab5398530b06a8f180f6652e6..dc0216c6b952074083f5e31fc8f324e008191374 100644 --- a/apps/models/llm.py +++ b/apps/models/llm.py @@ -1,8 +1,7 @@ """大模型信息 数据库表""" -from datetime import datetime +from datetime import UTC, datetime -import pytz from sqlalchemy import BigInteger, DateTime, ForeignKey, Integer, String from sqlalchemy.orm import Mapped, mapped_column @@ -32,7 +31,7 @@ class LLMData(Base): """LLM最大Token数量""" createdAt: Mapped[DateTime] = mapped_column( # noqa: N815 DateTime, - default_factory=lambda: datetime.now(pytz.timezone("Asia/Shanghai")), + default_factory=lambda: datetime.now(tz=UTC), init=False, ) """添加LLM的时间""" diff --git a/apps/models/mcp.py b/apps/models/mcp.py index b3454c879f8a9c30db4c8fa5fa55c1b754ffd244..98f2cb16728c2cd23b6100cbba06477ac298e45a 100644 --- a/apps/models/mcp.py +++ b/apps/models/mcp.py @@ -7,7 +7,7 @@ from hashlib import shake_128 from typing import Any from sqlalchemy import BigInteger, DateTime, Enum, ForeignKey, String -from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.dialects.postgresql import JSONB, UUID from sqlalchemy.orm import Mapped, mapped_column from .base import Base @@ -39,11 +39,11 @@ class MCPInfo(Base): """MCP 概述""" description: Mapped[str] = mapped_column(String(65535)) """MCP 描述""" - author: Mapped[str] = mapped_column(ForeignKey("framework_user.userSub")) + author: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.userSub")) """MCP 创建者""" config: Mapped[dict[str, Any]] = mapped_column(JSONB) """MCP 配置""" - id: Mapped[str] = mapped_column(String(255), primary_key=True) + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default_factory=uuid.uuid4) """MCP ID""" updatedAt: Mapped[datetime] = mapped_column( # noqa: N815 DateTime(timezone=True), @@ -61,9 +61,11 @@ class MCPActivated(Base): """MCP 激活用户""" __tablename__ = "framework_mcp_activated" - mcpId: Mapped[str] = mapped_column(ForeignKey("framework_mcp.id"), index=True, unique=True) # noqa: N815 + mcpId: Mapped[uuid.UUID] = mapped_column( # noqa: N815 + UUID(as_uuid=True), ForeignKey("framework_mcp.id"), index=True, unique=True, + ) """MCP ID""" - userSub: Mapped[str] = mapped_column(ForeignKey("framework_user.userSub")) # noqa: N815 + userSub: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.userSub")) # noqa: N815 """用户ID""" id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False) """主键ID""" @@ -73,7 +75,9 @@ class MCPTools(Base): """MCP 工具""" __tablename__ = "framework_mcp_tools" - mcpId: Mapped[str] = mapped_column(ForeignKey("framework_mcp.id"), index=True) # noqa: N815 + mcpId: Mapped[uuid.UUID] = mapped_column( # noqa: N815 + UUID(as_uuid=True), ForeignKey("framework_mcp.id"), index=True, + ) """MCP ID""" toolName: Mapped[str] = mapped_column(String(255)) # noqa: N815 """MCP 工具名称""" diff --git a/apps/models/node.py b/apps/models/node.py index 9303fd868f6976502e8406b9b4e54f15b32043f7..fe4e607fdbcc7ee4546ee3058867e9cf11354217 100644 --- a/apps/models/node.py +++ b/apps/models/node.py @@ -1,10 +1,9 @@ """节点 数据库表""" import uuid -from datetime import datetime +from datetime import UTC, datetime from typing import Any -import pytz from sqlalchemy import DateTime, ForeignKey, String from sqlalchemy.dialects.postgresql import JSONB, UUID from sqlalchemy.orm import Mapped, mapped_column @@ -34,8 +33,8 @@ class NodeInfo(Base): """节点ID""" updatedAt: Mapped[datetime] = mapped_column( # noqa: N815 DateTime(timezone=True), - default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), - onupdate=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + default_factory=lambda: datetime.now(tz=UTC), + onupdate=lambda: datetime.now(tz=UTC), ) """节点更新时间""" diff --git a/apps/models/record.py b/apps/models/record.py index 8c385b1e97d0fa423277bb55b56a471c81b4661d..ee2a352a6d5dabec08be7b28d86d9fc5392d30c9 100644 --- a/apps/models/record.py +++ b/apps/models/record.py @@ -1,10 +1,9 @@ """问答对 数据库表""" import uuid -from datetime import datetime +from datetime import UTC, datetime from typing import Any -import pytz from sqlalchemy import VARCHAR, BigInteger, DateTime, Float, ForeignKey, Integer, String from sqlalchemy.dialects.postgresql import JSONB, UUID from sqlalchemy.orm import Mapped, mapped_column @@ -31,7 +30,7 @@ class Record(Base): ) """问答对ID""" createdAt: Mapped[datetime] = mapped_column( # noqa: N815 - DateTime(timezone=True), default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + DateTime(timezone=True), default_factory=lambda: datetime.now(tz=UTC), ) """问答对创建时间""" diff --git a/apps/models/service.py b/apps/models/service.py index c23d15145578eec5730d9ec54a81182f0d7340df..8f51f86659e63b009147a2bd658c162ffd617a3a 100644 --- a/apps/models/service.py +++ b/apps/models/service.py @@ -1,9 +1,8 @@ """插件 数据库表""" import uuid -from datetime import datetime +from datetime import UTC, datetime -import pytz from sqlalchemy import BigInteger, DateTime, Enum, ForeignKey, String from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import Mapped, mapped_column @@ -25,8 +24,8 @@ class Service(Base): """插件作者""" updatedAt: Mapped[datetime] = mapped_column( # noqa: N815 DateTime(timezone=True), - default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), - onupdate=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + default_factory=lambda: datetime.now(tz=UTC), + onupdate=lambda: datetime.now(tz=UTC), ) """插件更新时间""" id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default_factory=uuid.uuid4) diff --git a/apps/models/session.py b/apps/models/session.py index 795d22953c2bb2666f02f261df36ce6ea773947d..87174a55c035625a642c93381699f44cd728511c 100644 --- a/apps/models/session.py +++ b/apps/models/session.py @@ -1,10 +1,9 @@ """会话相关 数据库表""" import secrets -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta from enum import Enum as PyEnum -import pytz from sqlalchemy import BigInteger, DateTime, Enum, ForeignKey, String from sqlalchemy.orm import Mapped, mapped_column @@ -28,7 +27,7 @@ class Session(Base): """用户名""" validUntil: Mapped[datetime] = mapped_column( # noqa: N815 DateTime(timezone=True), - default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")) + timedelta(days=30), + default_factory=lambda: datetime.now(tz=UTC) + timedelta(days=30), ) """有效期""" id: Mapped[str] = mapped_column(String(255), primary_key=True, default_factory=lambda: secrets.token_hex(16)) diff --git a/apps/models/tag.py b/apps/models/tag.py index 22238fe0aaace0ea2597875bf46dfedea699f170..19ca8f58bcffb8fea4947105cf2e3b7534520de7 100644 --- a/apps/models/tag.py +++ b/apps/models/tag.py @@ -1,8 +1,7 @@ """用户标签 数据库表""" -from datetime import datetime +from datetime import UTC, datetime -import pytz from sqlalchemy import BigInteger, DateTime, String from sqlalchemy.orm import Mapped, mapped_column @@ -21,7 +20,7 @@ class Tag(Base): """标签定义""" updatedAt: Mapped[datetime] = mapped_column( # noqa: N815 DateTime(timezone=True), - default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), - onupdate=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + default_factory=lambda: datetime.now(tz=UTC), + onupdate=lambda: datetime.now(tz=UTC), ) """标签的更新时间""" diff --git a/apps/models/task.py b/apps/models/task.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/apps/routers/mcp_service.py b/apps/routers/mcp_service.py index 397d90444b426be788293a6061b3f527d8cee05a..0581b5b633191e06a00810fbdeee9269acbc3d0e 100644 --- a/apps/routers/mcp_service.py +++ b/apps/routers/mcp_service.py @@ -173,7 +173,7 @@ async def get_service_detail( indent=4, ensure_ascii=False, ), - mcpType=config.type, + mcpType=config.mcp_type, ) else: # 组装详情所需信息 diff --git a/apps/scheduler/pool/loader/mcp.py b/apps/scheduler/pool/loader/mcp.py index 932d8272c1c63398cb4fb12601b7e65f2c21dc35..b74e7be4f96d1b04868309f73ea9a08da79dde4d 100644 --- a/apps/scheduler/pool/loader/mcp.py +++ b/apps/scheduler/pool/loader/mcp.py @@ -6,6 +6,7 @@ import base64 import json import logging import shutil +import uuid from hashlib import shake_128 import asyncer @@ -120,7 +121,7 @@ class MCPLoader(metaclass=SingletonMeta): @staticmethod - async def init_one_template(mcp_id: str, config: MCPServerConfig) -> None: + async def init_one_template(mcp_id: uuid.UUID, config: MCPServerConfig) -> None: """ 初始化单个MCP模板 @@ -132,7 +133,7 @@ class MCPLoader(metaclass=SingletonMeta): await MCPLoader._insert_template_db(mcp_id, config) # 检查目录 - template_path = MCP_PATH / "template" / mcp_id + template_path = MCP_PATH / "template" / str(mcp_id) await Path.mkdir(template_path, parents=True, exist_ok=True) # 安装MCP模板 @@ -191,12 +192,12 @@ class MCPLoader(metaclass=SingletonMeta): """ # 创建客户端 if ( - (config.type == MCPType.STDIO and isinstance(config.config, MCPServerStdioConfig)) - or (config.type == MCPType.SSE and isinstance(config.config, MCPServerSSEConfig)) + (config.mcp_type == MCPType.STDIO and isinstance(config.config, MCPServerStdioConfig)) + or (config.mcp_type == MCPType.SSE and isinstance(config.config, MCPServerSSEConfig)) ): client = MCPClient() else: - err = f"MCP {mcp_id}:未知的MCP服务类型“{config.type}”" + err = f"MCP {mcp_id}:未知的MCP服务类型“{config.mcp_type}”" logger.error(err) raise ValueError(err) @@ -217,7 +218,7 @@ class MCPLoader(metaclass=SingletonMeta): return tool_list @staticmethod - async def _insert_template_db(mcp_id: str, config: MCPServerConfig) -> None: + async def _insert_template_db(mcp_id: uuid.UUID, config: MCPServerConfig) -> None: """插入单个MCP Server模板信息到数据库""" mcp_collection = MongoDB().get_collection("mcp") await mcp_collection.update_one( @@ -227,7 +228,7 @@ class MCPLoader(metaclass=SingletonMeta): _id=mcp_id, name=config.name, description=config.description, - type=config.type, + type=config.mcp_type, author=config.author, ).model_dump(by_alias=True, exclude_none=True), }, @@ -311,7 +312,7 @@ class MCPLoader(metaclass=SingletonMeta): await LanceDB().create_index("mcp_tool") @staticmethod - async def save_one(mcp_id: str, config: MCPServerConfig) -> None: + async def save_one(mcp_id: uuid.UUID, config: MCPServerConfig) -> None: """ 保存单个MCP模板的配置文件(``config.json``文件和``icon.png``文件) @@ -451,7 +452,15 @@ class MCPLoader(metaclass=SingletonMeta): :param list[str] deleted_mcp_list: 被删除的MCP列表 :return: 无 """ - # 从MongoDB中移除 + # 移除Info + async with postgres.session() as session: + for mcp_id in deleted_mcp_list: + mcp_info = (await session.scalars(select(MCPInfo).where(MCPInfo.id == mcp_id))).one_or_none() + if mcp_info: + await session.delete(mcp_info) + + + mcp_collection = MongoDB().get_collection("mcp") mcp_service_list = await mcp_collection.find( {"_id": {"$in": deleted_mcp_list}}, diff --git a/apps/scheduler/pool/mcp/default.py b/apps/scheduler/pool/mcp/default.py index 4f29ccaee3bf5c85eeb06e85a1be5abd0ab13463..f16b964bff3722aff73bee08a1f2c067231c670f 100644 --- a/apps/scheduler/pool/mcp/default.py +++ b/apps/scheduler/pool/mcp/default.py @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) DEFAULT_STDIO = MCPServerConfig( name="MCP服务_" + uuid.uuid4().hex[:6], description="MCP服务描述", - type=MCPType.STDIO, + mcp_type=MCPType.STDIO, config=MCPServerStdioConfig( command="uvx", args=[ @@ -31,7 +31,7 @@ DEFAULT_STDIO = MCPServerConfig( DEFAULT_SSE = MCPServerConfig( name="MCP服务_" + uuid.uuid4().hex[:6], description="MCP服务描述", - type=MCPType.SSE, + mcp_type=MCPType.SSE, config=MCPServerSSEConfig( url="http://test.domain/sse", env={ diff --git a/apps/scheduler/pool/mcp/pool.py b/apps/scheduler/pool/mcp/pool.py index 9119b901aa8a7b00dcb786be0da35e11e1223155..ec711ef961de208d5a28faa4972010e550a2da9c 100644 --- a/apps/scheduler/pool/mcp/pool.py +++ b/apps/scheduler/pool/mcp/pool.py @@ -33,7 +33,7 @@ class MCPPool(metaclass=SingletonMeta): config = MCPServerConfig.model_validate_json(await config_path.read_text()) - if config.type in (MCPType.SSE, MCPType.STDIO): + if config.mcp_type in (MCPType.SSE, MCPType.STDIO): client = MCPClient() else: logger.warning("[MCPPool] 用户 %s 的MCP %s 类型错误", user_sub, mcp_id) diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py index e6b98d25bf946e7fde511441039d166c42f06fe5..bdd9294b42f4ef5662083b81f14f1ad41cf4ec8e 100644 --- a/apps/schemas/mcp.py +++ b/apps/schemas/mcp.py @@ -44,7 +44,7 @@ class MCPServerConfig(BaseModel): name: str = Field(description="MCP 服务器自然语言名称", default="") overview: str = Field(description="MCP 服务器概述", default="") description: str = Field(description="MCP 服务器自然语言描述", default="") - type: MCPType = Field(description="MCP 服务器类型", default=MCPType.STDIO) + mcp_type: MCPType = Field(description="MCP 服务器类型", default=MCPType.STDIO) author: str = Field(description="MCP 服务器上传者", default="") config: MCPServerStdioConfig | MCPServerSSEConfig = Field(description="MCP 服务器配置") diff --git a/apps/services/activity.py b/apps/services/activity.py index 1f7221f60fa9b77c8d7158dc9837c36ad90e561a..e378fcef222a8a235e0a77ef1f3ce0933dc7200c 100644 --- a/apps/services/activity.py +++ b/apps/services/activity.py @@ -2,9 +2,8 @@ """用户限流""" import uuid -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta -import pytz from sqlalchemy import delete from apps.common.postgres import postgres @@ -24,7 +23,7 @@ class Activity: :param user_sub: 用户实体ID :return: 判断结果,正在提问则返回True """ - time = datetime.now(pytz.timezone("Asia/Shanghai")) + time = datetime.now(tz=UTC) # 检查窗口内总请求数 count = await MongoDB().get_collection("activity").count_documents( diff --git a/apps/services/appcenter.py b/apps/services/appcenter.py index dda29b49b88544428c3f93d637d790c685fbad38..da5ba28f0ba88898bf1fcf75b25ecf719484383b 100644 --- a/apps/services/appcenter.py +++ b/apps/services/appcenter.py @@ -247,7 +247,7 @@ class AppCenterManager: msg = f"[AppCenterManager] 应用不存在: {app_id}" raise ValueError(msg) app_obj.isPublished = published - session.add(app_obj) + await session.merge(app_obj) await session.commit() await AppCenterManager._process_app_and_save( @@ -376,11 +376,11 @@ class AppCenterManager: if app_data.appId == app_id: app_data.lastUsed = datetime.now(UTC) app_data.usageCount += 1 - session.add(app_data) + await session.merge(app_data) break else: app_data = UserAppUsage(userSub=user_sub, appId=app_id, lastUsed=datetime.now(UTC), usageCount=1) - session.add(app_data) + await session.merge(app_data) await session.commit() return diff --git a/apps/services/blacklist.py b/apps/services/blacklist.py index 871d03fda4ded919abc5d6fdd08e31bdd6f5496e..7ae7177157632e80e65a20abddcd27ebbf512490 100644 --- a/apps/services/blacklist.py +++ b/apps/services/blacklist.py @@ -182,7 +182,7 @@ class AbuseManager: reason=reason, ) - session.add(new_blacklist) + await session.merge(new_blacklist) await session.commit() return True diff --git a/apps/services/comment.py b/apps/services/comment.py index 27c4fcca8e27c160aaafd1a0d83cc07695dedd4c..25b084c3c768a3e46dba05581e47180b3d210e45 100644 --- a/apps/services/comment.py +++ b/apps/services/comment.py @@ -68,5 +68,5 @@ class CommentManager: feedbackLink=data.feedback_link, feedbackContent=data.feedback_content, ) - session.add(comment_info) + await session.merge(comment_info) await session.commit() diff --git a/apps/services/conversation.py b/apps/services/conversation.py index 257e0a3156eb50d11db11aeee6c8ccf71743b940..ade7fb6bf2cc227e148947064153e68f7c3b289b 100644 --- a/apps/services/conversation.py +++ b/apps/services/conversation.py @@ -3,10 +3,9 @@ import logging import uuid -from datetime import datetime +from datetime import UTC, datetime from typing import Any -import pytz from sqlalchemy import and_, func, select from apps.common.postgres import postgres @@ -78,7 +77,7 @@ class ConversationManager: # 使用PostgreSQL实现新建对话,并根据debug和usage进行更新 try: async with postgres.session() as session: - session.add(conv) + await session.merge(conv) await session.commit() await session.refresh(conv) @@ -96,15 +95,15 @@ class ConversationManager: # 假设App模型有last_used和usage_count字段(如没有请根据实际表结构调整) # 这里只做示例,实际字段名和类型请根据实际情况修改 app_obj.usageCount += 1 - app_obj.lastUsed = datetime.now(pytz.timezone("Asia/Shanghai")) - session.add(app_obj) + app_obj.lastUsed = datetime.now(tz=UTC) + await session.merge(app_obj) await session.commit() else: - session.add(UserAppUsage( + await session.merge(UserAppUsage( userSub=user_sub, appId=app_id, usageCount=1, - lastUsed=datetime.now(pytz.timezone("Asia/Shanghai")), + lastUsed=datetime.now(tz=UTC), )) await session.commit() return conv @@ -131,7 +130,7 @@ class ConversationManager: return False for key, value in data.items(): setattr(conv, key, value) - session.add(conv) + await session.merge(conv) await session.commit() return True diff --git a/apps/services/mcp_service.py b/apps/services/mcp_service.py index 80abf5f48876bef780ac11318b11734123ab1fc0..07a79206ec64a5cbf9cb76318acb54773e6a6db0 100644 --- a/apps/services/mcp_service.py +++ b/apps/services/mcp_service.py @@ -1,37 +1,40 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. -"""MCP服务管理器""" +"""MCP服务(插件中心)管理器""" import logging -import random import re +import uuid from typing import Any from fastapi import UploadFile from PIL import Image -from sqids.sqids import Sqids +from sqlalchemy import and_, or_, select +from apps.common.postgres import postgres from apps.constants import ( ALLOWED_ICON_MIME_TYPES, ICON_PATH, SERVICE_PAGE_SIZE, ) +from apps.models.mcp import ( + MCPActivated, + MCPInfo, + MCPInstallStatus, + MCPTools, + MCPType, +) from apps.scheduler.pool.loader.mcp import MCPLoader from apps.scheduler.pool.mcp.pool import MCPPool from apps.schemas.enum_var import SearchType from apps.schemas.mcp import ( - MCPCollection, - MCPInstallStatus, MCPServerConfig, MCPServerSSEConfig, MCPServerStdioConfig, - MCPTool, - MCPType, ) from apps.schemas.request_data import UpdateMCPServiceRequest from apps.schemas.response_data import MCPServiceCardItem logger = logging.getLogger(__name__) -sqids = Sqids(min_length=6) MCP_ICON_PATH = ICON_PATH / "mcp" @@ -47,10 +50,15 @@ class MCPServiceManager: :param str mcp_id: MCP服务ID :return: 是否激活 """ - mongo = MongoDB() - mcp_collection = mongo.get_collection("mcp") - mcp_list = await mcp_collection.find({"_id": mcp_id}, {"activated": True}).to_list(None) - return any(user_sub in db_item.get("activated", []) for db_item in mcp_list) + async with postgres.session() as session: + mcp_info = (await session.scalars(select(MCPActivated).where( + and_( + MCPActivated.mcpId == mcp_id, + MCPActivated.userSub == user_sub, + ), + ))).one_or_none() + return bool(mcp_info) + @staticmethod async def get_service_status(mcp_id: str) -> MCPInstallStatus: @@ -60,16 +68,12 @@ class MCPServiceManager: :param str mcp_id: MCP服务ID :return: MCP服务状态 """ - mongo = MongoDB() - mcp_collection = mongo.get_collection("mcp") - mcp_list = await mcp_collection.find({"_id": mcp_id}, {"status": True}).to_list(None) - for db_item in mcp_list: - status = db_item.get("status") - if MCPInstallStatus.READY.value == status: - return MCPInstallStatus.READY - if MCPInstallStatus.INSTALLING.value == status: - return MCPInstallStatus.INSTALLING - return MCPInstallStatus.FAILED + async with postgres.session() as session: + mcp_info = (await session.scalars(select(MCPInfo).where(MCPInfo.id == mcp_id))).one_or_none() + if mcp_info: + return mcp_info.status + return MCPInstallStatus.FAILED + @staticmethod async def fetch_mcp_services( @@ -87,8 +91,7 @@ class MCPServiceManager: :param page: int: 页码 :return: MCP服务列表 """ - filters = MCPServiceManager._build_filters(search_type, keyword) - mcpservice_pools = await MCPServiceManager._search_mcpservice(filters, page) + mcpservice_pools = await MCPServiceManager._search_mcpservice(search_type, keyword, page) return [ MCPServiceCardItem( mcpserviceId=item.id, @@ -102,21 +105,18 @@ class MCPServiceManager: for item in mcpservice_pools ] + @staticmethod - async def get_mcp_service(mcpservice_id: str) -> MCPCollection: + async def get_mcp_service(mcpservice_id: str) -> MCPInfo | None: """ 获取MCP服务详细信息 :param mcpservice_id: str: MCP服务ID :return: MCP服务详细信息 """ - # 验证用户权限 - mcpservice_collection = MongoDB().get_collection("mcp") - db_service = await mcpservice_collection.find_one({"_id": mcpservice_id}) - if not db_service: - msg = "[MCPServiceManager] MCP服务未找到" - raise RuntimeError(msg) - return MCPCollection.model_validate(db_service) + async with postgres.session() as session: + return (await session.scalars(select(MCPInfo).where(MCPInfo.id == mcpservice_id))).one_or_none() + @staticmethod async def get_mcp_config(mcpservice_id: str) -> tuple[MCPServerConfig, str]: @@ -130,31 +130,27 @@ class MCPServiceManager: config = await MCPLoader.get_config(mcpservice_id) return config, icon + @staticmethod async def get_service_tools( service_id: str, - ) -> list[MCPTool]: + ) -> list[MCPTools]: """ 获取MCP可以用工具 :param service_id: str: MCP服务ID :return: MCP工具详细信息列表 """ - # 获取服务名称 - service_collection = MongoDB().get_collection("mcp") - data = await service_collection.find({"_id": service_id}, {"tools": True}).to_list(None) - result = [] - for item in data: - for tool in item.get("tools", {}): - tool_data = MCPTool.model_validate(tool) - result.append(tool_data) - return result + async with postgres.session() as session: + return list((await session.scalars(select(MCPTools).where(MCPTools.mcpId == service_id))).all()) + @staticmethod async def _search_mcpservice( - search_conditions: dict[str, Any], + search_type: SearchType, + keyword: str | None, page: int, - ) -> list[MCPCollection]: + ) -> list[MCPInfo]: """ 基于输入条件搜索MCP服务 @@ -162,44 +158,53 @@ class MCPServiceManager: :param page: int: 页码 :return: MCP列表 """ - mcpservice_collection = MongoDB().get_collection("mcp") # 分页查询 skip = (page - 1) * SERVICE_PAGE_SIZE - db_mcpservices = await mcpservice_collection.find(search_conditions).skip(skip).limit( - SERVICE_PAGE_SIZE, - ).to_list() + + async with postgres.session() as session: + if not keyword: + result = list( + (await session.scalars(select(MCPInfo).offset(skip).limit(SERVICE_PAGE_SIZE))).all(), + ) + elif search_type == SearchType.ALL: + result = list( + (await session.scalars(select(MCPInfo).where( + or_( + MCPInfo.name.like(f"%{keyword}%"), + MCPInfo.description.like(f"%{keyword}%"), + MCPInfo.author.like(f"%{keyword}%"), + ), + ).offset(skip).limit(SERVICE_PAGE_SIZE))).all(), + ) + elif search_type == SearchType.NAME: + result = list( + (await session.scalars( + select(MCPInfo).where(MCPInfo.name.like(f"%{keyword}%")).offset(skip).limit(SERVICE_PAGE_SIZE), + )).all(), + ) + elif search_type == SearchType.DESCRIPTION: + result = list( + (await session.scalars( + select(MCPInfo).where(MCPInfo.description.like(f"%{keyword}%")).offset(skip).limit(SERVICE_PAGE_SIZE), + )).all(), + ) + elif search_type == SearchType.AUTHOR: + result = list( + (await session.scalars( + select(MCPInfo).where(MCPInfo.author.like(f"%{keyword}%")).offset(skip).limit(SERVICE_PAGE_SIZE), + )).all(), + ) + # 如果未找到,返回空列表 - if not db_mcpservices: - logger.warning("[MCPServiceManager] 没有找到符合条件的MCP服务: %s", search_conditions) + if not result: + logger.warning("[MCPServiceManager] 没有找到符合条件的MCP服务: %s", search_type) return [] # 将数据库中的MCP服务转换为对象 - return [MCPCollection.model_validate(db_mcpservice) for db_mcpservice in db_mcpservices] - - @staticmethod - def _build_filters( - search_type: SearchType, - keyword: str | None, - ) -> dict[str, Any]: - if not keyword: - return {} - - if search_type == SearchType.ALL: - base_filters = {"$or": [ - {"name": {"$regex": keyword, "$options": "i"}}, - {"description": {"$regex": keyword, "$options": "i"}}, - {"author": {"$regex": keyword, "$options": "i"}}, - ]} - elif search_type == SearchType.NAME: - base_filters = {"name": {"$regex": keyword, "$options": "i"}} - elif search_type == SearchType.DESCRIPTION: - base_filters = {"description": {"$regex": keyword, "$options": "i"}} - elif search_type == SearchType.AUTHOR: - base_filters = {"author": {"$regex": keyword, "$options": "i"}} - return base_filters + return result @staticmethod - async def create_mcpservice(data: UpdateMCPServiceRequest, user_sub: str) -> str: + async def create_mcpservice(data: UpdateMCPServiceRequest, user_sub: str) -> uuid.UUID: """ 创建MCP服务 @@ -218,23 +223,23 @@ class MCPServiceManager: overview=data.overview, description=data.description, config=config, - type=data.mcp_type, + mcp_type=data.mcp_type, author=user_sub, ) # 检查是否存在相同服务 - mcp_collection = MongoDB().get_collection("mcp") - db_service = await mcp_collection.find_one({"name": mcp_server.name}) - mcp_id = sqids.encode([random.randint(0, 1000000) for _ in range(5)])[:6] # noqa: S311 - if db_service: - mcp_server.name = f"{mcp_server.name}-{mcp_id}" - logger.warning("[MCPServiceManager] 已存在相同名称和描述的MCP服务") + async with postgres.session() as session: + mcp_info = (await session.scalars(select(MCPInfo).where(MCPInfo.name == mcp_server.name))).one_or_none() + if mcp_info: + mcp_server.name = f"{mcp_server.name}-{uuid.uuid4().hex[:6]}" + logger.warning("[MCPServiceManager] 已存在相同ID或名称的MCP服务") # 保存并载入配置 logger.info("[MCPServiceManager] 创建mcp:%s", mcp_server.name) - await MCPLoader.save_one(mcp_id, mcp_server) - await MCPLoader.init_one_template(mcp_id=mcp_id, config=mcp_server) - return mcp_id + await MCPLoader.save_one(mcp_server.id, mcp_server) + await MCPLoader.init_one_template(mcp_id=mcp_server.id, config=mcp_server) + return mcp_server.id + @staticmethod async def update_mcpservice(data: UpdateMCPServiceRequest, user_sub: str) -> str: @@ -267,18 +272,18 @@ class MCPServiceManager: ) if data.mcp_type == MCPType.STDIO else MCPServerSSEConfig.model_validate_json( data.config, ), - type=data.mcp_type, + mcp_type=data.mcp_type, author=user_sub, )) # 返回服务ID return data.service_id + @staticmethod async def delete_mcpservice(service_id: str) -> None: """ 删除MCP服务 - :param user_sub: str: 用户ID :param service_id: str: MCP服务ID :return: 是否删除成功 """ @@ -292,6 +297,7 @@ class MCPServiceManager: {"$pull": {"mcp_service": service_id}}, ) + @staticmethod async def active_mcpservice( user_sub: str, @@ -370,3 +376,50 @@ class MCPServiceManager: image.save(MCP_ICON_PATH / f"{service_id}.png", format="PNG", optimize=True, compress_level=9) return f"/static/mcp/{service_id}.png" + + + @staticmethod + async def is_user_actived(user_sub: str, mcp_id: str) -> bool: + """ + 判断用户是否激活MCP + + :param user_sub: str: 用户ID + :param mcp_id: str: MCP服务ID + :return: 是否激活 + """ + async with postgres.session() as session: + mcp_info = (await session.scalars(select(MCPActivated).where( + and_( + MCPActivated.mcpId == mcp_id, + MCPActivated.userSub == user_sub, + ), + ))).one_or_none() + return bool(mcp_info) + + + @staticmethod + async def query_mcp_tools(mcp_id: str) -> list[MCPTools]: + """ + 查询MCP工具 + + :param mcp_id: str: MCP服务ID + :return: MCP工具列表 + """ + async with postgres.session() as session: + mcp_tools = list((await session.scalars(select(MCPTools).where(MCPTools.mcpId == mcp_id))).all()) + + + @staticmethod + async def add_mcp_template(mcp_id: str, config: MCPServerConfig, tools: list[MCPTools]) -> None: + # 插入MCP表 + async with postgres.session() as session: + await session.merge(MCPInfo( + id=mcp_id, + name=config.name, + description=config.description, + config=config, + overview=config.overview, + mcpType=config.mcp_type, + author=config.author, + status=MCPInstallStatus.INSTALLING, + )) diff --git a/tests/routers/test_blacklist.py b/tests/routers/test_blacklist.py index 3cc6b51b01b1b88b4816f34d56b71b01088ca87b..a5f1a4546c1265cf4cb520187679f615fd7b6a90 100644 --- a/tests/routers/test_blacklist.py +++ b/tests/routers/test_blacklist.py @@ -7,7 +7,7 @@ from unittest import TestCase from fastapi.testclient import TestClient from fastapi import status, FastAPI, Request -from apps.routers.blacklist import router +from apps.routers.blacklist import admin_router from apps.dependency import verify_csrf_token, get_current_user from apps.schemas.user import User @@ -24,7 +24,7 @@ class TestBlacklistRouter(TestCase): @classmethod def setUpClass(cls): app = FastAPI() - app.include_router(router) + app.include_router(admin_router) app.dependency_overrides[verify_csrf_token] = mock_csrf_token app.dependency_overrides[get_current_user] = mock_get_user cls.client = TestClient(app)