diff --git a/apps/models/mcp.py b/apps/models/mcp.py
index 012afb42743cc65714d47dee3524b0af177b34f5..828cf36d4e4517daecb650064ab7c3989ec1e808 100644
--- a/apps/models/mcp.py
+++ b/apps/models/mcp.py
@@ -1,14 +1,81 @@
 """MCP 相关 数据库表"""
 
+import uuid
+from datetime import UTC, datetime
+from enum import Enum as PyEnum
+from typing import Any
+
 from sqlalchemy import BigInteger, DateTime, Enum, ForeignKey, String
-from sqlalchemy.dialects.postgresql import UUID
+from sqlalchemy.dialects.postgresql import JSONB, UUID
 from sqlalchemy.orm import Mapped, mapped_column
 
 from .base import Base
 
 
-class MCP(Base):
+class MCPInstallStatus(str, PyEnum):
+    """MCP service installation status"""
+
+    INSTALLING = "installing"
+    READY = "ready"
+    FAILED = "failed"
+
+
+class MCPType(str, PyEnum):
+    """MCP transport type"""
+
+    SSE = "sse"
+    STDIO = "stdio"
+    STREAMABLE = "stream"
+
+
+class MCPInfo(Base):
     """MCP"""
 
     __tablename__ = "framework_mcp"
-    
\ No newline at end of file
+    name: Mapped[str] = mapped_column(String(255))
+    """MCP name"""
+    overview: Mapped[str] = mapped_column(String(2000))
+    """MCP overview"""
+    description: Mapped[str] = mapped_column(String(65535))
+    """MCP description"""
+    author: Mapped[str] = mapped_column(ForeignKey("framework_user.userSub"))
+    """MCP creator"""
+    config: Mapped[dict[str, Any]] = mapped_column(JSONB)
+    """MCP configuration"""
+    id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default_factory=uuid.uuid4)
+    """MCP ID"""
+    updatedAt: Mapped[datetime] = mapped_column(  # noqa: N815
+        DateTime(timezone=True),
+        default_factory=lambda: datetime.now(tz=UTC),
+        onupdate=lambda: datetime.now(tz=UTC),
+    )
+    """MCP last-update time"""
+    status: Mapped[MCPInstallStatus] = mapped_column(Enum(MCPInstallStatus), default=MCPInstallStatus.INSTALLING)
+    """MCP installation status"""
+    mcpType: Mapped[MCPType] = mapped_column(Enum(MCPType), default=MCPType.STDIO)  # noqa: N815
+    """MCP type"""
+
+
+class MCPActivated(Base):
+    """Users who activated an MCP"""
+
+    __tablename__ = "framework_mcp_activated"
+    # Indexed but NOT unique: one row per (MCP, user) activation, so an MCP ID repeats across users.
+    mcp_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_mcp.id"), index=True)
+    """MCP ID"""
+    user_sub: Mapped[str] = mapped_column(ForeignKey("framework_user.userSub"))
+    """User ID"""
+    id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False)
+    """Primary key ID"""
+
+
+class MCPTools(Base):
+    """MCP tools"""
+
+    __tablename__ = "framework_mcp_tools"
+    mcp_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_mcp.id"), index=True)
+    """MCP ID"""
+    tools: Mapped[dict[str, Any]] = mapped_column(JSONB)
+    """MCP tools"""
+    id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False)
+    """Primary key ID"""
diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py
index 1161e545448a0b479053abde69d68e2bb1ad7687..7edfeb132caaf2259fc7f9c91f3f7da8dbada913 100644
--- a/apps/scheduler/call/core.py
+++ b/apps/scheduler/call/core.py
@@ -14,7 +14,7 @@ from pydantic.json_schema import SkipJsonSchema
 
 from apps.llm.function import FunctionLLM
 from apps.llm.reasoning import ReasoningLLM
-from apps.models.node import Node
+from apps.models.node import NodeInfo
 from apps.schemas.enum_var import CallOutputType
 from apps.schemas.scheduler import (
     CallError,
@@ -51,7 +51,7 @@ class CoreCall(BaseModel):
 
     name: SkipJsonSchema[str] = Field(description="Step的名称", exclude=True)
     description: SkipJsonSchema[str] = Field(description="Step的描述", exclude=True)
-    node: SkipJsonSchema[Node | None] = Field(description="节点信息", exclude=True)
+    node: SkipJsonSchema[NodeInfo | None] = Field(description="节点信息", exclude=True)
     enable_filling: SkipJsonSchema[bool] = Field(description="是否需要进行自动参数填充", default=False, exclude=True)
     tokens: SkipJsonSchema[CallTokens] = Field(
         description="Call的输入输出Tokens信息",
@@ -158,7 +158,7 @@ class CoreCall(BaseModel):
 
 
     @classmethod
-    async def instance(cls, executor: "StepExecutor", node: Node | None, **kwargs: Any) -> Self:
+    async def instance(cls, executor: "StepExecutor", node: NodeInfo | None, **kwargs: Any) -> Self:
         """实例化Call类"""
         obj = cls(
             name=executor.step.step.name,
diff --git a/apps/scheduler/call/facts/facts.py b/apps/scheduler/call/facts/facts.py
index 2f7f4aa5e399fadd072f64d15897dbd49e994639..cfe9276d8ea3625bc8fa30836b91b7b84dc96b8b 100644
--- a/apps/scheduler/call/facts/facts.py
+++ b/apps/scheduler/call/facts/facts.py
@@ -8,7 +8,7 @@ from jinja2 import BaseLoader
 from jinja2.sandbox import SandboxedEnvironment
 from pydantic import Field
 
-from apps.models.node import Node
+from apps.models.node import NodeInfo
 from apps.scheduler.call.core import CoreCall
 from apps.schemas.enum_var import CallOutputType
 from apps.schemas.scheduler import CallInfo, CallOutputChunk, CallVars
@@ -39,7 +39,7 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput):
 
 
     @classmethod
-    async def instance(cls, executor: "StepExecutor", node: Node | None, **kwargs: Any) -> Self:
+    async def instance(cls, executor: "StepExecutor", node: NodeInfo | None, **kwargs: Any) -> Self:
         """初始化工具"""
         obj = cls(
             answer=executor.task.runtime.answer,
diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py
index 4ae07cbd5197a73d13a75462884519fc5efbaae6..deed87b9fb519a3369581c49e521522ca142a30d 100644
--- a/apps/scheduler/call/slot/slot.py
+++ b/apps/scheduler/call/slot/slot.py
@@ -11,7 +11,7 @@ from pydantic import Field
 
 from apps.llm.function import FunctionLLM, JsonGenerator
 from apps.llm.reasoning import ReasoningLLM
-from apps.models.node import Node
+from apps.models.node import NodeInfo
 from apps.scheduler.call.core import CoreCall
 from apps.scheduler.slot.slot import Slot as SlotProcessor
 from apps.schemas.enum_var import CallOutputType
@@ -94,7 +94,7 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput):
         return await json_gen.generate()
 
     @classmethod
-    async def instance(cls, executor: "StepExecutor", node: Node | None, **kwargs: Any) -> Self:
+    async def instance(cls, executor: "StepExecutor", node: NodeInfo | None, **kwargs: Any) -> Self:
         """实例化Call类"""
         obj = cls(
             name=executor.step.step.name,
diff --git a/apps/scheduler/call/suggest/suggest.py b/apps/scheduler/call/suggest/suggest.py
index 40e6934776a836a8421928442c011171d8634fe3..432c64687b22d61fd0b27042a78720b25e9ac2a6 100644
--- a/apps/scheduler/call/suggest/suggest.py
+++ b/apps/scheduler/call/suggest/suggest.py
@@ -13,7 +13,7 @@ from pydantic.json_schema import SkipJsonSchema
 
 from apps.common.security import Security
 from apps.llm.function import FunctionLLM
-from apps.models.node import Node
+from apps.models.node import NodeInfo
 from apps.scheduler.call.core import CoreCall
 from apps.schemas.enum_var import CallOutputType
 from apps.schemas.record import RecordContent
@@ -57,7 +57,7 @@ class Suggestion(CoreCall, input_model=SuggestionInput, output_model=SuggestionO
 
 
     @classmethod
-    async def instance(cls, executor: "StepExecutor", node: Node | None, **kwargs: Any) -> Self:
+    async def instance(cls, executor: "StepExecutor", node: NodeInfo | None, **kwargs: Any) -> Self:
         """初始化"""
         context = [
             {
diff --git a/apps/scheduler/call/summary/summary.py b/apps/scheduler/call/summary/summary.py
index 754dc6e54ba01981605e9e91cba20df00b6b7c23..644b60a408da2ae0b33d82c77c15e977bba8f697 100644
--- a/apps/scheduler/call/summary/summary.py
+++ b/apps/scheduler/call/summary/summary.py
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any, Self
 from pydantic import Field
 
 from apps.llm.patterns.executor import ExecutorSummary
-from apps.models.node import Node
+from apps.models.node import NodeInfo
 from apps.scheduler.call.core import CoreCall, DataBase
 from apps.schemas.enum_var import CallOutputType
 from apps.schemas.scheduler import (
@@ -35,7 +35,7 @@ class Summary(CoreCall, input_model=DataBase, output_model=SummaryOutput):
         return CallInfo(name="理解上下文", description="使用大模型,理解对话上下文")
 
     @classmethod
-    async def instance(cls, executor: "StepExecutor", node: Node | None, **kwargs: Any) -> Self:
+    async def instance(cls, executor: "StepExecutor", node: NodeInfo | None, **kwargs: Any) -> Self:
         """实例化工具"""
         obj = cls(
             context=executor.background,
diff --git a/apps/scheduler/mcp/host.py b/apps/scheduler/mcp/host.py
index 78aa7bc3ee869e8710e1fb02a2d9fb438d04be34..6299f0f5b46f041ccc93ddfdf1ea9e22c034fb1d 100644
--- a/apps/scheduler/mcp/host.py
+++ b/apps/scheduler/mcp/host.py
@@ -9,7 +9,6 @@ from jinja2 import BaseLoader
 from jinja2.sandbox import SandboxedEnvironment
 from mcp.types import TextContent
 
-from apps.common.mongo import MongoDB
 from apps.llm.function import JsonGenerator
 from apps.scheduler.mcp.prompt import MEMORY_TEMPLATE
 from apps.scheduler.pool.mcp.client import MCPClient
@@ -17,6 +16,7 @@ from apps.scheduler.pool.mcp.pool import MCPPool
 from apps.schemas.enum_var import StepStatus
 from apps.schemas.mcp import MCPPlanItem, MCPTool
 from apps.schemas.task import FlowStepHistory
+from apps.services.mcp_service import MCPServiceManager
 from apps.services.task import TaskManager
 
 logger = logging.getLogger(__name__)
@@ -43,12 +43,7 @@ class MCPHost:
 
     async def get_client(self, mcp_id: str) -> MCPClient | None:
         """获取MCP客户端"""
-        mongo = MongoDB()
-        mcp_collection = mongo.get_collection("mcp")
-
-        # 检查用户是否启用了这个mcp
-        mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": self._user_sub})
-        if not mcp_db_result:
+        if not await MCPServiceManager.is_user_actived(self._user_sub, mcp_id):
             logger.warning("用户 %s 未启用MCP %s", self._user_sub, mcp_id)
             return None
 
@@ -173,21 +168,16 @@ class MCPHost:
 
     async def get_tool_list(self, mcp_id_list: list[str]) -> list[MCPTool]:
         """获取工具列表"""
-        mongo = MongoDB()
-        mcp_collection = mongo.get_collection("mcp")
-
         # 获取工具列表
         tool_list = []
         for mcp_id in mcp_id_list:
             # 检查用户是否启用了这个mcp
-            mcp_db_result = await mcp_collection.find_one({"_id": mcp_id, "activated": self._user_sub})
-            if not mcp_db_result:
+            if not await MCPServiceManager.is_user_actived(self._user_sub, mcp_id):
                 logger.warning("用户 %s 未启用MCP %s", self._user_sub, mcp_id)
                 continue
             # 获取MCP工具配置
             try:
-                for tool in mcp_db_result["tools"]:
-                    tool_list.extend([MCPTool.model_validate(tool)])
+                tool_list.extend(await MCPServiceManager.get_service_tools(mcp_id))
             except KeyError:
                 logger.warning("用户 %s 的MCP Tool %s 配置错误", self._user_sub, mcp_id)
                 continue
diff --git a/apps/scheduler/mcp/select.py b/apps/scheduler/mcp/select.py
index 2833f74ebbcfe3dade8e2cb837aa0ef7161f75e6..ffa45bc91708600bf38cf65847c2ea4578751444 100644
--- a/apps/scheduler/mcp/select.py
+++ b/apps/scheduler/mcp/select.py
@@ -3,8 +3,6 @@
 
 import logging
 
-from apps.common.lance import LanceDB
-from apps.common.mongo import MongoDB
 from apps.llm.embedding import Embedding
 from apps.llm.function import FunctionLLM
 from apps.llm.reasoning import ReasoningLLM
@@ -12,6 +10,7 @@ from apps.schemas.mcp import (
     MCPSelectResult,
     MCPTool,
 )
+from apps.services.mcp_service import MCPServiceManager
 
 logger = logging.getLogger(__name__)
 
@@ -80,21 +79,12 @@ class MCPSelector:
         )).where(f"mcp_id IN {MCPSelector._assemble_sql(mcp_list)}").limit(top_n).to_list()
 
         # 拿到工具
-        tool_collection = MongoDB().get_collection("mcp")
         llm_tool_list = []
         for tool_vec in tool_vecs:
-            # 到MongoDB里找对应的工具
             logger.info("[MCPHelper] 查询MCP Tool名称和描述: %s", tool_vec["mcp_id"])
-            tool_data = await tool_collection.aggregate([
-                {"$match": {"_id": tool_vec["mcp_id"]}},
-                {"$unwind": "$tools"},
-                {"$match": {"tools.id": tool_vec["id"]}},
-                {"$project": {"_id": 0, "tools": 1}},
-                {"$replaceRoot": {"newRoot": "$tools"}},
-            ])
-            async for tool in tool_data:
-                tool_obj = MCPTool.model_validate(tool)
-                llm_tool_list.append(tool_obj)
+            tool_data = await MCPServiceManager.get_service_tools(tool_vec["mcp_id"])
+            # Keep only the tool matched by the vector search, not every tool of that MCP.
+            llm_tool_list.extend(tool for tool in tool_data if tool.id == tool_vec["id"])
 
         return llm_tool_list
 
diff --git a/apps/scheduler/pool/loader/call.py b/apps/scheduler/pool/loader/call.py
index a485af28b669e6c07e4ba8888cde2fc68284d06f..4cc7d71defc3b86b294519fffd3a74bbbe5a83ef 100644
--- a/apps/scheduler/pool/loader/call.py
+++ b/apps/scheduler/pool/loader/call.py
@@ -14,7 +14,7 @@ from apps.common.lance import LanceDB
 from apps.common.mongo import MongoDB
 from apps.common.singleton import SingletonMeta
 from apps.llm.embedding import Embedding
-from apps.models.node import Node
+from apps.models.node import NodeInfo
 
 logger = logging.getLogger(__name__)
 BASE_PATH = Path(config.deploy.data_dir) / "semantics" / "call"
diff --git a/apps/scheduler/pool/loader/mcp.py b/apps/scheduler/pool/loader/mcp.py
index d5f21a49e653bb93f0ccb7c7af9c2ce2a1de79c0..ef4a2a41a4f86862c5d3654d9a38b6ed0e77614b 100644
--- a/apps/scheduler/pool/loader/mcp.py
+++ b/apps/scheduler/pool/loader/mcp.py
@@ -12,23 +12,20 @@ import asyncer
 from anyio import Path
 from sqids.sqids import Sqids
 
-from apps.common.lance import LanceDB
-from apps.common.mongo import MongoDB
 from apps.common.process_handler import ProcessHandler
 from apps.common.singleton import SingletonMeta
 from apps.constants import MCP_PATH
 from apps.llm.embedding import Embedding
+from apps.models.mcp import MCPInfo, MCPInstallStatus, MCPTools, MCPType
 from apps.scheduler.pool.mcp.client import MCPClient
 from apps.scheduler.pool.mcp.install import install_npx, install_uvx
 from apps.schemas.mcp import (
     MCPCollection,
-    MCPInstallStatus,
     MCPServerConfig,
     MCPServerSSEConfig,
     MCPServerStdioConfig,
     MCPTool,
     MCPToolVector,
-    MCPType,
     MCPVector,
 )
 
diff --git a/apps/scheduler/pool/loader/openapi.py b/apps/scheduler/pool/loader/openapi.py
index f455ac23250d9a25cc89918f414a6b713a1a4943..53e1e05b546df18c8089d6c07ff562e53aa5ecaf 100644
--- a/apps/scheduler/pool/loader/openapi.py
+++ b/apps/scheduler/pool/loader/openapi.py
@@ -8,7 +8,7 @@ from typing import Any
 import yaml
 from anyio import Path
 
-from apps.models.node import Node
+from apps.models.node import NodeInfo
 from apps.scheduler.openapi import (
     ReducedOpenAPIEndpoint,
     ReducedOpenAPISpec,
@@ -166,7 +166,7 @@ class OpenAPILoader:
 
         return spec
 
-    async def load_one(self, service_id: str, yaml_path: Path, server: str) -> list[Node]:
+    async def load_one(self, service_id: str, yaml_path: Path, server: str) -> list[NodeInfo]:
         """加载单个OpenAPI文档,可以直接指定路径"""
         try:
             spec = await self._read_yaml(yaml_path)
@@ -185,7 +185,7 @@ class OpenAPILoader:
             raise RuntimeError(err) from e
 
         return [
-            Node(
+            NodeInfo(
                 name=node.name,
                 description=node.description,
                 callId=node.call_id,
diff --git a/apps/scheduler/pool/loader/service.py b/apps/scheduler/pool/loader/service.py
index 6b4de770dbb3491668db34a3fb4dbd04b8d2feba..c1dc502eaa554fbd9562070c1fdede6f36072f06 100644
--- a/apps/scheduler/pool/loader/service.py
+++ b/apps/scheduler/pool/loader/service.py
@@ -12,7 +12,7 @@ from apps.common.config import config
 from apps.common.mongo import MongoDB
 from apps.common.postgres import postgres
 from apps.llm.embedding import Embedding
-from apps.models.node import Node
+from apps.models.node import NodeInfo
 from apps.models.service import Service
 from apps.models.vectors import NodePoolVector, ServicePoolVector
 from apps.scheduler.pool.check import FileChecker
diff --git a/apps/scheduler/pool/pool.py b/apps/scheduler/pool/pool.py
index 74fa4e4a59429cac9512de8fe6e9015984c55053..23b45d4b86e6eb6fc654c893cb3550724e46f1cb 100644
--- a/apps/scheduler/pool/pool.py
+++ b/apps/scheduler/pool/pool.py
@@ -3,6 +3,7 @@
 
 import importlib
 import logging
+import uuid
 from typing import Any
 
 from anyio import Path
diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py
index 3e9ffca11b4f8dd2c76d2a8395775739b8f3eb17..929d763922895dd7f8e00c11466a7dbe3a1ac103 100644
--- a/apps/scheduler/scheduler/message.py
+++ b/apps/scheduler/scheduler/message.py
@@ -65,7 +65,7 @@ async def push_rag_message(
     """推送RAG消息"""
     full_answer = ""
 
-    async for chunk in RAG.get_rag_result(user_sub, llm, history, doc_ids, rag_data):
+    async for chunk in RAG.chat_with_llm_base_on_rag(user_sub, llm, history, doc_ids, rag_data):
         task, content_obj = await _push_rag_chunk(task, queue, chunk)
         if content_obj.event_type == EventType.TEXT_ADD.value:
             # 如果是文本消息,直接拼接到答案中
diff --git a/apps/schemas/mcp.py b/apps/schemas/mcp.py
index 5b6e25850e78705a819509b1cde1d47f654f0013..43be897cc794159c03f2b9b1c59ceb39b10a79df 100644
--- a/apps/schemas/mcp.py
+++ b/apps/schemas/mcp.py
@@ -8,14 +8,6 @@ from lancedb.pydantic import LanceModel, Vector
 from pydantic import BaseModel, Field
 
 
-class MCPInstallStatus(str, Enum):
-    """MCP 服务状态"""
-
-    INSTALLING = "installing"
-    READY = "ready"
-    FAILED = "failed"
-
-
 class MCPStatus(str, Enum):
     """MCP 状态"""
 
@@ -24,14 +16,6 @@ class MCPStatus(str, Enum):
     STOPPED = "stopped"
 
 
-class MCPType(str, Enum):
-    """MCP 类型"""
-
-    SSE = "sse"
-    STDIO = "stdio"
-    STREAMABLE = "stream"
-
-
 class MCPBasicConfig(BaseModel):
     """MCP 基本配置"""
 
diff --git a/apps/services/task.py b/apps/services/task.py
index 98926f506614b205b20ab2f281eb0814a67e7508..a0fbb109c824f81d919824e9e4693ea26aae5a54 100644
--- a/apps/services/task.py
+++ b/apps/services/task.py
@@ -5,7 +5,8 @@ import logging
 import uuid
 from typing import Any
 
 from apps.common.mongo import MongoDB
+from apps.common.postgres import postgres
 from apps.schemas.record import RecordGroup
 from apps.schemas.request_data import RequestData
 from apps.schemas.task import (
@@ -151,7 +152,7 @@ class TaskManager:
 
 
     @staticmethod
-    async def delete_tasks_by_conversation_id(conversation_id: str) -> None:
+    async def delete_tasks_by_conversation_id(conversation_id: uuid.UUID) -> None:
         """通过ConversationID删除Task信息"""
         mongo = MongoDB()
         task_collection = mongo.get_collection("task")
diff --git a/apps/services/token.py b/apps/services/token.py
index 359f06cc17df4874697658a47e55982cdfc95480..5cfe46363ef750c87d3816773b754464804c7ba2 100644
--- a/apps/services/token.py
+++ b/apps/services/token.py
@@ -7,7 +7,7 @@ from datetime import UTC, datetime, timedelta
 
 import httpx
 from fastapi import status
-from sqlalchemy import select
+from sqlalchemy import and_, select
 
 from apps.common.config import config
 from apps.common.oidc import oidc_provider
@@ -42,9 +42,11 @@ class TokenManager:
                 await session.scalars(
                     select(Session)
                     .where(
-                        Session.userSub == user_sub,
-                        Session.sessionType == SessionType.ACCESS_TOKEN,
-                        Session.pluginId == str(plugin_id),
+                        and_(
+                            Session.userSub == user_sub,
+                            Session.sessionType == SessionType.ACCESS_TOKEN,
+                            Session.pluginId == str(plugin_id),
+                        ),
                     ),
                 )
             ).one_or_none()
@@ -96,7 +98,12 @@ class TokenManager:
             token_data = (
                 await session.scalars(
                     select(Session)
-                    .where(Session.userSub == user_sub, Session.sessionType == SessionType.ACCESS_TOKEN),
+                    .where(
+                        and_(
+                            Session.userSub == user_sub,
+                            Session.sessionType == SessionType.ACCESS_TOKEN,
+                        ),
+                    ),
                 )
             ).one_or_none()
 
@@ -108,7 +115,12 @@ class TokenManager:
             refresh_token = (
                 await session.scalars(
                     select(Session)
-                    .where(Session.userSub == user_sub, Session.sessionType == SessionType.REFRESH_TOKEN),
+                    .where(
+                        and_(
+                            Session.userSub == user_sub,
+                            Session.sessionType == SessionType.REFRESH_TOKEN,
+                        ),
+                    ),
                 )
             ).one_or_none()
 
@@ -171,12 +183,14 @@ class TokenManager:
             token_data = (await session.scalars(
                 select(Session)
                 .where(
-                    Session.userSub == user_sub,
-                    Session.sessionType.in_([
-                        SessionType.ACCESS_TOKEN,
-                        SessionType.REFRESH_TOKEN,
-                        SessionType.PLUGIN_TOKEN,
-                    ]),
+                    and_(
+                        Session.userSub == user_sub,
+                        Session.sessionType.in_([
+                            SessionType.ACCESS_TOKEN,
+                            SessionType.REFRESH_TOKEN,
+                            SessionType.PLUGIN_TOKEN,
+                        ]),
+                    ),
                 ),
             )).all()
             for token in token_data:
diff --git a/apps/services/user.py b/apps/services/user.py
index 675fa019e76c2dc377ca7ce62593144c430d6f17..72f144fcb16d7b893fa8d9a2d25184482966053e 100644
--- a/apps/services/user.py
+++ b/apps/services/user.py
@@ -71,7 +71,6 @@ class UserManager:
             user.lastLogin = datetime.now(tz=pytz.timezone("Asia/Shanghai"))
             await session.commit()
 
-
     @staticmethod
     async def delete_user(user_sub: str) -> None:
         """
@@ -89,5 +88,4 @@ class UserManager:
             await session.delete(user)
             await session.commit()
 
-        for conv_id in result.conversations:
-            await ConversationManager.delete_conversation_by_conversation_id(user_sub, conv_id)
+        await ConversationManager.delete_conversation_by_user_sub(user_sub)
diff --git a/apps/services/user_tag.py b/apps/services/user_tag.py
index d6eed5aa40469bc74e2f16a7ce9b5ee6f43e01e6..7c84f4a2ca4528a32a76967ef5c97d298ee48133 100644
--- a/apps/services/user_tag.py
+++ b/apps/services/user_tag.py
@@ -3,7 +3,7 @@
 
 import logging
 
-from sqlalchemy import select
+from sqlalchemy import and_, select
 
 from apps.common.postgres import postgres
 from apps.models.tag import Tag
@@ -56,11 +56,14 @@ class UserTagManager:
                 user_domain = (
                     await session.scalars(
                         select(UserTag).where(
-                            UserTag.userSub == user_sub,
-                            UserTag.tag == tag.id,
+                            and_(
+                                UserTag.userSub == user_sub,
+                                UserTag.tag == tag.id,
+                            ),
                         ),
                     )
                 ).one_or_none()
+
                 if not user_domain:
                     user_domain = UserTag(userSub=user_sub, tag=tag.id, count=1)
                     session.add(user_domain)