From 002e574eef4cdbfccd2691531102227d0a3f2fb6 Mon Sep 17 00:00:00 2001 From: z30057876 Date: Wed, 30 Jul 2025 15:21:30 +0800 Subject: [PATCH] =?UTF-8?q?=E8=BF=81=E7=A7=BB=E6=95=B0=E6=8D=AE=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/exceptions.py | 3 - apps/models/user.py | 4 +- apps/scheduler/pool/loader/app.py | 13 +- apps/scheduler/pool/loader/openapi.py | 51 ++- apps/scheduler/pool/loader/service.py | 119 +++---- apps/scheduler/scheduler/scheduler.py | 3 +- apps/schemas/flow.py | 3 +- apps/schemas/node.py | 8 - apps/schemas/request_data.py | 2 +- apps/schemas/response_data.py | 10 +- apps/services/appcenter.py | 44 +-- apps/services/document.py | 2 +- apps/services/service.py | 465 +++++++++++++++++--------- 13 files changed, 422 insertions(+), 305 deletions(-) diff --git a/apps/exceptions.py b/apps/exceptions.py index 32ef76d1..1831246e 100644 --- a/apps/exceptions.py +++ b/apps/exceptions.py @@ -4,9 +4,6 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """ -class ServiceIDError(Exception): - """Service ID错误""" - class InstancePermissionError(Exception): """App/Service实例的权限错误""" diff --git a/apps/models/user.py b/apps/models/user.py index 3ec6c485..ed598c69 100644 --- a/apps/models/user.py +++ b/apps/models/user.py @@ -43,8 +43,8 @@ class User(Base): class UserFavoriteType(str, enum.Enum): """用户收藏类型""" - app = "app" - service = "service" + APP = "app" + SERVICE = "service" class UserFavorite(Base): diff --git a/apps/scheduler/pool/loader/app.py b/apps/scheduler/pool/loader/app.py index bfb539ce..8d5fab55 100644 --- a/apps/scheduler/pool/loader/app.py +++ b/apps/scheduler/pool/loader/app.py @@ -3,6 +3,7 @@ import logging import shutil +import uuid from anyio import Path from fastapi.encoders import jsonable_encoder @@ -24,13 +25,13 @@ BASE_PATH = Path(config.deploy.data_dir) / "semantics" / "app" class AppLoader: """应用加载器""" - async def load(self, app_id: str, hashes: dict[str, str]) -> None: # noqa: C901 + async def load(self, app_id: uuid.UUID, hashes: dict[str, str]) -> None: # noqa: C901 """ 从文件系统中加载应用 :param app_id: 应用 ID """ - app_path = BASE_PATH / app_id + app_path = BASE_PATH / str(app_id) metadata_path = app_path / "metadata.yaml" metadata = await MetadataLoader().load_one(metadata_path) if not metadata: @@ -84,7 +85,7 @@ class AppLoader: raise RuntimeError(err) from e await self._update_db(metadata) - async def save(self, metadata: AppMetadata | AgentAppMetadata, app_id: str) -> None: + async def save(self, metadata: AppMetadata | AgentAppMetadata, app_id: uuid.UUID) -> None: """ 保存应用 @@ -92,7 +93,7 @@ class AppLoader: :param app_id: 应用 ID """ # 创建文件夹 - app_path = BASE_PATH / app_id + app_path = BASE_PATH / str(app_id) if not await app_path.exists(): await app_path.mkdir(parents=True, exist_ok=True) # 保存元数据 @@ -104,7 +105,7 @@ class AppLoader: @staticmethod - async def delete(app_id: str, *, is_reload: bool = False) -> None: + async def delete(app_id: uuid.UUID, *, is_reload: bool = False) -> None: """ 删除App,并更新数据库 @@ -129,7 +130,7 @@ class AppLoader: logger.exception("[AppLoader] MongoDB删除App失败") if not is_reload: - app_path = BASE_PATH / app_id + app_path = BASE_PATH / str(app_id) if await app_path.exists(): shutil.rmtree(str(app_path), ignore_errors=True) diff --git a/apps/scheduler/pool/loader/openapi.py b/apps/scheduler/pool/loader/openapi.py index 53e1e05b..09c1b08f 100644 --- a/apps/scheduler/pool/loader/openapi.py +++ b/apps/scheduler/pool/loader/openapi.py @@ -2,6 +2,7 @@ """OpenAPI文档载入器""" import logging +import uuid from hashlib import shake_128 from typing import Any @@ -16,7 +17,7 @@ from apps.scheduler.openapi import ( ) from apps.scheduler.util import yaml_str_presenter from apps.schemas.enum_var import ContentType, HTTPMethod -from apps.schemas.node import APINode, APINodeInput, APINodeOutput +from apps.schemas.node import APINodeInput, APINodeOutput logger = logging.getLogger(__name__) @@ -121,31 +122,43 @@ class OpenAPILoader: async def _process_spec( self, - service_id: str, + service_id: uuid.UUID, yaml_filename: str, spec: ReducedOpenAPISpec, server: str, - ) -> list[APINode]: + ) -> list[NodeInfo]: """将OpenAPI文档拆解为Node""" nodes = [] for api_endpoint in spec.endpoints: # 通过算法生成唯一的标识符 identifier = shake_128(f"openapi::{yaml_filename}::{api_endpoint.uri}".encode()).hexdigest(16) # 组装新的NodePool item - node = APINode( - _id=identifier, + node = NodeInfo( + id=identifier, name=api_endpoint.name, # 此处固定Call的ID是"API" - call_id="API", + callId="API", description=api_endpoint.description, - service_id=service_id, + serviceId=service_id, + knownParams={}, + overrideInput={}, + overrideOutput={}, ) # 合并参数 - node.override_input, node.override_output, node.known_params = await self._get_api_data( + override_input, override_output, known_params = await self._get_api_data( api_endpoint, server, ) + node.overrideInput = override_input.model_dump( + exclude_none=True, + by_alias=True, + ) + node.overrideOutput = override_output.model_dump( + exclude_none=True, + by_alias=True, + ) + node.knownParams = known_params nodes.append(node) return nodes @@ -157,7 +170,7 @@ class OpenAPILoader: """加载字典形式的OpenAPI文档""" spec = reduce_openapi_spec(yaml_dict) try: - await self._process_spec("temp", "temp.yaml", spec, spec.servers) + await self._process_spec(uuid.UUID("00000000-0000-0000-0000-000000000000"), "temp.yaml", spec, spec.servers) except Exception: err = "[OpenAPILoader] 处理OpenAPI文档失败" logger.exception(err) @@ -166,7 +179,7 @@ class OpenAPILoader: return spec - async def load_one(self, service_id: str, yaml_path: Path, server: str) -> list[NodeInfo]: + async def load_one(self, service_id: uuid.UUID, yaml_path: Path, server: str) -> list[NodeInfo]: """加载单个OpenAPI文档,可以直接指定路径""" try: spec = await self._read_yaml(yaml_path) @@ -184,23 +197,7 @@ class OpenAPILoader: logger.exception(err) raise RuntimeError(err) from e - return [ - NodeInfo( - name=node.name, - description=node.description, - callId=node.call_id, - serviceId=service_id, - overrideInput=node.override_input.model_dump( - exclude_none=True, - by_alias=True, - ) - if node.override_input - else {}, - overrideOutput={}, - knownParams=node.known_params, - ) - for node in api_nodes - ] + return api_nodes async def save_one(self, yaml_path: Path, yaml_dict: dict[str, Any]) -> None: diff --git a/apps/scheduler/pool/loader/service.py b/apps/scheduler/pool/loader/service.py index 2e06208b..ea0df332 100644 --- a/apps/scheduler/pool/loader/service.py +++ b/apps/scheduler/pool/loader/service.py @@ -3,10 +3,10 @@ import logging import shutil +import uuid from anyio import Path -from fastapi.encoders import jsonable_encoder -from sqlalchemy import delete, insert +from sqlalchemy import delete from apps.common.config import config from apps.common.postgres import postgres @@ -15,7 +15,7 @@ from apps.models.node import NodeInfo from apps.models.service import Service from apps.models.vectors import NodePoolVector, ServicePoolVector from apps.scheduler.pool.check import FileChecker -from apps.schemas.flow import Permission, ServiceMetadata +from apps.schemas.flow import PermissionType, ServiceMetadata from .metadata import MetadataLoader, MetadataType from .openapi import OpenAPILoader @@ -28,9 +28,9 @@ class ServiceLoader: """Service 加载器""" @staticmethod - async def load(service_id: str, hashes: dict[str, str]) -> None: + async def load(service_id: uuid.UUID, hashes: dict[str, str]) -> None: """加载单个Service""" - service_path = BASE_PATH / service_id + service_path = BASE_PATH / str(service_id) # 载入元数据 metadata = await MetadataLoader().load_one(service_path / "metadata.yaml") if not isinstance(metadata, ServiceMetadata): @@ -52,9 +52,9 @@ class ServiceLoader: @staticmethod - async def save(service_id: str, metadata: ServiceMetadata, data: dict) -> None: + async def save(service_id: uuid.UUID, metadata: ServiceMetadata, data: dict) -> None: """在文件系统上保存Service,并更新数据库""" - service_path = BASE_PATH / service_id + service_path = BASE_PATH / str(service_id) # 创建文件夹 if not await service_path.exists(): await service_path.mkdir(parents=True, exist_ok=True) @@ -72,28 +72,18 @@ class ServiceLoader: @staticmethod - async def delete(service_id: str, *, is_reload: bool = False) -> None: + async def delete(service_id: uuid.UUID, *, is_reload: bool = False) -> None: """删除Service,并更新数据库""" - mongo = MongoDB() - service_collection = mongo.get_collection("service") - node_collection = mongo.get_collection("node") - try: - await service_collection.delete_one({"_id": service_id}) - await node_collection.delete_many({"service_id": service_id}) - except Exception: - logger.exception("[ServiceLoader] 删除Service失败") + async with postgres.session() as session: + await session.execute(delete(Service).where(Service.id == service_id)) + await session.execute(delete(NodeInfo).where(NodeInfo.serviceId == service_id)) - try: - # 删除postgres中的向量数据 - async with postgres.session() as session: - await session.execute(delete(ServicePoolVector).where(ServicePoolVector.service_id == service_id)) - await session.execute(delete(NodePoolVector).where(NodePoolVector.serviceId == service_id)) - await session.commit() - except Exception: - logger.exception("[ServiceLoader] 删除数据库失败") + await session.execute(delete(ServicePoolVector).where(ServicePoolVector.id == service_id)) + await session.execute(delete(NodePoolVector).where(NodePoolVector.serviceId == service_id)) + await session.commit() if not is_reload: - path = BASE_PATH / service_id + path = BASE_PATH / str(service_id) if await path.exists(): shutil.rmtree(path) @@ -105,71 +95,54 @@ class ServiceLoader: err = f"[ServiceLoader] 服务 {metadata.id} 的哈希值为空" logger.error(err) raise ValueError(err) - # 更新MongoDB - mongo = MongoDB() - service_collection = mongo.get_collection("service") - node_collection = mongo.get_collection("node") - try: - # 先删除旧的节点 - await node_collection.delete_many({"service_id": metadata.id}) - # 插入或更新 Service - await service_collection.update_one( - {"_id": metadata.id}, - { - "$set": jsonable_encoder( - ServicePool( - _id=metadata.id, - name=metadata.name, - description=metadata.description, - author=metadata.author, - permission=metadata.permission if metadata.permission else Permission(), - hashes=metadata.hashes, - ), - ), - }, - upsert=True, + + # 更新数据库 + async with postgres.session() as session: + # 删除旧的数据 + await session.execute(delete(Service).where(Service.id == metadata.id)) + await session.execute(delete(NodeInfo).where(NodeInfo.serviceId == metadata.id)) + + # 插入新的数据 + service_data = Service( + id=metadata.id, + name=metadata.name, + description=metadata.description, + author=metadata.author, + permission=PermissionType.PRIVATE, ) + session.add(service_data) + for node in nodes: - await node_collection.update_one({"_id": node.id}, {"$set": jsonable_encoder(node)}, upsert=True) - except Exception as e: - err = f"[ServiceLoader] 更新 MongoDB 失败:{e}" - logger.exception(err) - raise RuntimeError(err) from e + session.add(node) + await session.commit() # 删除旧的向量数据 async with postgres.session() as session: - await session.execute(delete(ServicePoolVector).where(ServicePoolVector.service_id == metadata.id)) + await session.execute(delete(ServicePoolVector).where(ServicePoolVector.id == metadata.id)) await session.execute(delete(NodePoolVector).where(NodePoolVector.serviceId == metadata.id)) await session.commit() # 进行向量化,更新postgres service_vecs = await Embedding.get_embedding([metadata.description]) async with postgres.session() as session: - await session.execute( - insert(ServicePoolVector), - [ - { - "service_id": metadata.id, - "embedding": service_vecs[0], - }, - ], + pool_data = ServicePoolVector( + id=metadata.id, + embedding=service_vecs[0], ) + session.add(pool_data) await session.commit() node_descriptions = [] for node in nodes: node_descriptions += [node.description] - node_vecs = await Embedding.get_embedding(node_descriptions) + async with postgres.session() as session: - await session.execute( - insert(NodePoolVector), - [ - { - "service_id": metadata.id, - "embedding": vec, - } - for vec in node_vecs - ], - ) + for vec in node_vecs: + node_data = NodePoolVector( + id=node.id, + serviceId=metadata.id, + embedding=vec, + ) + session.add(node_data) await session.commit() diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index dc9dc47a..8e105f2b 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -6,6 +6,7 @@ from datetime import UTC, datetime from apps.common.config import config from apps.common.queue import MessageQueue +from apps.models.app import App from apps.scheduler.executor.agent import MCPAgentExecutor from apps.scheduler.executor.flow import FlowExecutor from apps.scheduler.pool.pool import Pool @@ -15,9 +16,7 @@ from apps.scheduler.scheduler.message import ( push_init_message, push_rag_message, ) -from apps.schemas.collection import LLM from apps.schemas.enum_var import AppType, EventType -from apps.schemas.pool import AppPool from apps.schemas.rag_data import RAGQueryReq from apps.schemas.request_data import RequestData from apps.schemas.scheduler import ExecutorBackground diff --git a/apps/schemas/flow.py b/apps/schemas/flow.py index ed1d89de..840cff01 100644 --- a/apps/schemas/flow.py +++ b/apps/schemas/flow.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """App、Flow和Service等外置配置数据结构""" +import uuid from typing import Any from pydantic import BaseModel, Field @@ -70,7 +71,7 @@ class MetadataBase(BaseModel): """ type: MetadataType = Field(description="元数据类型") - id: str = Field(description="元数据ID") + id: uuid.UUID = Field(description="元数据ID") icon: str = Field(description="图标", default="") name: str = Field(description="元数据名称") description: str = Field(description="元数据描述") diff --git a/apps/schemas/node.py b/apps/schemas/node.py index 8391109a..e6ebec9a 100644 --- a/apps/schemas/node.py +++ b/apps/schemas/node.py @@ -17,11 +17,3 @@ class APINodeOutput(BaseModel): """API节点覆盖输出""" result: dict[str, Any] | None = Field(description="API节点输出Schema", default=None) - - -class APINode(BaseModel): - """API节点""" - - call_id: str = "API" - override_input: APINodeInput | None = Field(description="API节点输入覆盖", default=None) - override_output: APINodeOutput | None = Field(description="API节点输出覆盖", default=None) diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index d15ca384..b765ae13 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -96,7 +96,7 @@ class ActiveMCPServiceRequest(BaseModel): class UpdateServiceRequest(BaseModel): """POST /api/service 请求数据结构""" - service_id: str | None = Field(None, alias="serviceId", description="服务ID(更新时传递)") + service_id: uuid.UUID | None = Field(None, alias="serviceId", description="服务ID(更新时传递)") data: dict[str, Any] = Field(..., description="对应 YAML 内容的数据对象") diff --git a/apps/schemas/response_data.py b/apps/schemas/response_data.py index 4e79f0af..0567adab 100644 --- a/apps/schemas/response_data.py +++ b/apps/schemas/response_data.py @@ -307,7 +307,7 @@ class GetRecentAppListRsp(ResponseData): class ServiceCardItem(BaseModel): """语义接口中心:服务卡片数据结构""" - service_id: str = Field(..., alias="serviceId", description="服务ID") + service_id: uuid.UUID = Field(..., alias="serviceId", description="服务ID") name: str = Field(..., description="服务名称") description: str = Field(..., description="服务简介") icon: str = Field(..., description="服务图标") @@ -326,7 +326,7 @@ class ServiceApiData(BaseModel): class BaseServiceOperationMsg(BaseModel): """语义接口中心:基础服务操作Result数据结构""" - service_id: str = Field(..., alias="serviceId", description="服务ID") + service_id: uuid.UUID = Field(..., alias="serviceId", description="服务ID") class GetServiceListMsg(BaseModel): @@ -346,7 +346,7 @@ class GetServiceListRsp(ResponseData): class UpdateServiceMsg(BaseModel): """语义接口中心:服务属性数据结构""" - service_id: str = Field(..., alias="serviceId", description="服务ID") + service_id: uuid.UUID = Field(..., alias="serviceId", description="服务ID") name: str = Field(..., description="服务名称") apis: list[ServiceApiData] = Field(..., description="解析后的接口列表") @@ -360,7 +360,7 @@ class UpdateServiceRsp(ResponseData): class GetServiceDetailMsg(BaseModel): """GET /api/service/{serviceId} Result数据结构""" - service_id: str = Field(..., alias="serviceId", description="服务ID") + service_id: uuid.UUID = Field(..., alias="serviceId", description="服务ID") name: str = Field(..., description="服务名称") apis: list[ServiceApiData] | None = Field(default=None, description="解析后的接口列表") data: dict[str, Any] | None = Field(default=None, description="YAML 内容数据对象") @@ -381,7 +381,7 @@ class DeleteServiceRsp(ResponseData): class ChangeFavouriteServiceMsg(BaseModel): """PUT /api/service/{serviceId} Result数据结构""" - service_id: str = Field(..., alias="serviceId", description="服务ID") + service_id: uuid.UUID = Field(..., alias="serviceId", description="服务ID") favorited: bool = Field(..., description="是否已收藏") diff --git a/apps/services/appcenter.py b/apps/services/appcenter.py index da5ba28f..02f35674 100644 --- a/apps/services/appcenter.py +++ b/apps/services/appcenter.py @@ -6,7 +6,7 @@ import uuid from datetime import UTC, datetime from typing import Any -from sqlalchemy import select +from sqlalchemy import and_, select from apps.common.postgres import postgres from apps.constants import SERVICE_PAGE_SIZE @@ -57,6 +57,7 @@ class AppCenterManager: result = await app_collection.find_one(query) return result is not None + @staticmethod async def validate_app_belong_to_user(user_sub: str, app_id: str) -> bool: """ @@ -66,15 +67,20 @@ class AppCenterManager: :param app_id: 应用id :return: 如果应用属于用户则返回True,否则返回False """ - mongo = MongoDB() - app_collection = mongo.get_collection("app") # 获取应用集合' - query = { - "_id": app_id, - "author": user_sub, - } + async with postgres.session() as session: + app_obj = (await session.scalars( + select(App).where( + and_( + App.id == app_id, + App.author == user_sub, + ), + ), + )).one_or_none() + if not app_obj: + msg = f"[AppCenterManager] 应用不存在或权限不足: {app_id}" + raise ValueError(msg) + return True - result = await app_collection.find_one(query) - return result is not None @staticmethod async def fetch_apps( @@ -183,7 +189,7 @@ class AppCenterManager: :param data: 应用数据 :return: 应用唯一标识 """ - app_id = str(uuid.uuid4()) + app_id = uuid.uuid4() await AppCenterManager._process_app_and_save( app_type=data.app_type, app_id=app_id, @@ -390,7 +396,7 @@ class AppCenterManager: search_conditions: dict[str, Any], page: int, page_size: int, - ) -> tuple[list[AppPool], int]: + ) -> tuple[list[App], int]: """根据过滤条件搜索应用并计算总页数""" mongo = MongoDB() app_collection = mongo.get_collection("app") @@ -402,12 +408,12 @@ class AppCenterManager: .limit(page_size) .to_list(length=page_size) ) - apps = [AppPool.model_validate(doc) for doc in db_data] + apps = [App.model_validate(doc) for doc in db_data] return apps, total_apps @staticmethod - async def _get_app_data(app_id: str, user_sub: str, *, check_permission: bool = True) -> AppPool: + async def _get_app_data(app_id: str, user_sub: str, *, check_permission: bool = True) -> App: """ 从数据库获取应用数据并验证权限 @@ -490,7 +496,7 @@ class AppCenterManager: common_params: dict, user_sub: str, data: AppData | None = None, - app_data: AppPool | None = None, + app_data: App | None = None, published: bool | None = None, ) -> AgentAppMetadata: """创建 Agent 应用的元数据""" @@ -522,7 +528,7 @@ class AppCenterManager: @staticmethod async def _create_metadata( app_type: AppType, - app_id: str, + app_id: uuid.UUID, user_sub: str, **kwargs: Any, ) -> AppMetadata | AgentAppMetadata: @@ -540,7 +546,7 @@ class AppCenterManager: :raises ValueError: 无效应用类型或缺少必要数据 """ data: AppData | None = kwargs.get("data") - app_data: AppPool | None = kwargs.get("app_data") + app_data: App | None = kwargs.get("app_data") published: bool | None = kwargs.get("published") # 验证必要数据 @@ -554,8 +560,8 @@ class AppCenterManager: msg = f"参数 data 类型应为 AppData,但获取到的是 {type(data).__name__}" raise ValueError(msg) - if app_data is not None and not isinstance(app_data, AppPool): - msg = f"参数 app_data 类型应为 AppPool,但获取到的是 {type(app_data).__name__}" + if app_data is not None and not isinstance(app_data, App): + msg = f"参数 app_data 类型应为 App,但获取到的是 {type(app_data).__name__}" raise ValueError(msg) if published is not None and not isinstance(published, bool): @@ -587,7 +593,7 @@ class AppCenterManager: @staticmethod async def _process_app_and_save( app_type: AppType, - app_id: str, + app_id: uuid.UUID, user_sub: str, **kwargs: Any, ) -> Any: diff --git a/apps/services/document.py b/apps/services/document.py index fde7465d..1ebf3b90 100644 --- a/apps/services/document.py +++ b/apps/services/document.py @@ -6,6 +6,7 @@ import logging import uuid import asyncer +import magic from fastapi import UploadFile from sqlalchemy import and_, func, select @@ -30,7 +31,6 @@ class DocumentManager: MinioClient.check_bucket("document") file = document.file # 获取文件MIME - import magic mime = magic.from_buffer(file.read(), mime=True) file.seek(0) diff --git a/apps/services/service.py b/apps/services/service.py index 652e7b20..d6648309 100644 --- a/apps/services/service.py +++ b/apps/services/service.py @@ -3,17 +3,18 @@ import logging import uuid +from datetime import UTC, datetime from typing import Any import yaml from anyio import Path -from sqlalchemy import and_, or_, select +from sqlalchemy import and_, delete, func, or_, select from apps.common.config import config from apps.common.postgres import postgres -from apps.exceptions import InstancePermissionError, ServiceIDError from apps.models.node import NodeInfo from apps.models.service import Service, ServiceACL +from apps.models.user import User, UserFavorite, UserFavoriteType from apps.scheduler.openapi import ReducedOpenAPISpec from apps.scheduler.pool.loader.openapi import OpenAPILoader from apps.scheduler.pool.loader.service import ServiceLoader @@ -41,7 +42,43 @@ class ServiceCenterManager: page_size: int, ) -> tuple[list[ServiceCardItem], int]: """获取所有服务列表""" - service_pools, total_count = await ServiceCenterManager._search_service(filters, page, page_size) + async with postgres.session() as session: + if not keyword: + service_pools = list( + (await session.scalars(select(Service).offset((page - 1) * page_size).limit(page_size))).all(), + ) + elif search_type == SearchType.ALL: + service_pools = list( + (await session.scalars(select(Service).where( + or_( + Service.name.like(f"%{keyword}%"), + Service.description.like(f"%{keyword}%"), + Service.author.like(f"%{keyword}%"), + ), + ).order_by(Service.updatedAt.desc()).offset((page - 1) * page_size).limit(page_size))).all(), + ) + elif search_type == SearchType.NAME: + service_pools = list( + (await session.scalars(select(Service).where( + Service.name.like(f"%{keyword}%"), + ).order_by(Service.updatedAt.desc()).offset((page - 1) * page_size).limit(page_size))).all(), + ) + elif search_type == SearchType.DESCRIPTION: + service_pools = list( + (await session.scalars(select(Service).where( + Service.description.like(f"%{keyword}%"), + ).order_by(Service.updatedAt.desc()).offset((page - 1) * page_size).limit(page_size))).all(), + ) + elif search_type == SearchType.AUTHOR: + service_pools = list( + (await session.scalars(select(Service).where( + Service.author.like(f"%{keyword}%"), + ).order_by(Service.updatedAt.desc()).offset((page - 1) * page_size).limit(page_size))).all(), + ) + else: + service_pools = [] + total_count = len(service_pools) + fav_service_ids = await ServiceCenterManager._get_favorite_service_ids_by_user(user_sub) services = [ ServiceCardItem( @@ -56,6 +93,7 @@ class ServiceCenterManager: ] return services, total_count + @staticmethod async def fetch_user_services( user_sub: str, @@ -69,9 +107,58 @@ class ServiceCenterManager: if keyword is not None and keyword not in user_sub: return [], 0 keyword = user_sub - base_filter = {"author": user_sub} - filters = ServiceCenterManager._build_filters(base_filter, search_type, keyword) if keyword else base_filter - service_pools, total_count = await ServiceCenterManager._search_service(filters, page, page_size) + + async with postgres.session() as session: + if not keyword: + service_pools = list( + (await session.scalars(select(Service).where( + Service.author == user_sub, + ).order_by(Service.updatedAt.desc()).offset((page - 1) * page_size).limit(page_size))).all(), + ) + elif search_type == SearchType.ALL: + service_pools = list( + (await session.scalars(select(Service).where( + and_( + Service.author == user_sub, + or_( + Service.name.like(f"%{keyword}%"), + Service.description.like(f"%{keyword}%"), + Service.author.like(f"%{keyword}%"), + ), + ), + ).order_by(Service.updatedAt.desc()).offset((page - 1) * page_size).limit(page_size))).all(), + ) + elif search_type == SearchType.NAME: + service_pools = list( + (await session.scalars(select(Service).where( + and_( + Service.author == user_sub, + Service.name.like(f"%{keyword}%"), + ), + ).order_by(Service.updatedAt.desc()).offset((page - 1) * page_size).limit(page_size))).all(), + ) + elif search_type == SearchType.DESCRIPTION: + service_pools = list( + (await session.scalars(select(Service).where( + and_( + Service.author == user_sub, + Service.description.like(f"%{keyword}%"), + ), + ).order_by(Service.updatedAt.desc()).offset((page - 1) * page_size).limit(page_size))).all(), + ) + elif search_type == SearchType.AUTHOR: + service_pools = list( + (await session.scalars(select(Service).where( + and_( + Service.author == user_sub, + Service.author.like(f"%{keyword}%"), + ), + ).order_by(Service.updatedAt.desc()).offset((page - 1) * page_size).limit(page_size))).all(), + ) + else: + service_pools = [] + total_count = len(service_pools) + fav_service_ids = await ServiceCenterManager._get_favorite_service_ids_by_user(user_sub) services = [ ServiceCardItem( @@ -86,6 +173,7 @@ class ServiceCenterManager: ] return services, total_count + @staticmethod async def fetch_favorite_services( user_sub: str, @@ -96,30 +184,79 @@ class ServiceCenterManager: ) -> tuple[list[ServiceCardItem], int]: """获取用户收藏的服务""" fav_service_ids = await ServiceCenterManager._get_favorite_service_ids_by_user(user_sub) - base_filter = {"_id": {"$in": fav_service_ids}} - filters = ServiceCenterManager._build_filters(base_filter, search_type, keyword) if keyword else base_filter - service_pools, total_count = await ServiceCenterManager._search_service(filters, page, page_size) - services = [ - ServiceCardItem( - serviceId=service_pool.id, - icon="", - name=service_pool.name, - description=service_pool.description, - author=service_pool.author, - favorited=True, - ) - for service_pool in service_pools - ] - return services, total_count + + async with postgres.session() as session: + if not keyword: + service_pools = list( + (await session.scalars(select(Service).where( + Service.id.in_(fav_service_ids), + ).order_by(Service.updatedAt.desc()).offset((page - 1) * page_size).limit(page_size))).all(), + ) + elif search_type == SearchType.ALL: + service_pools = list( + (await session.scalars(select(Service).where( + and_( + Service.id.in_(fav_service_ids), + or_( + Service.name.like(f"%{keyword}%"), + Service.description.like(f"%{keyword}%"), + Service.author.like(f"%{keyword}%"), + ), + ), + ).order_by(Service.updatedAt.desc()).offset((page - 1) * page_size).limit(page_size))).all(), + ) + elif search_type == SearchType.NAME: + service_pools = list( + (await session.scalars(select(Service).where( + and_( + Service.id.in_(fav_service_ids), + Service.name.like(f"%{keyword}%"), + ), + ).order_by(Service.updatedAt.desc()).offset((page - 1) * page_size).limit(page_size))).all(), + ) + elif search_type == SearchType.DESCRIPTION: + service_pools = list( + (await session.scalars(select(Service).where( + and_( + Service.id.in_(fav_service_ids), + Service.description.like(f"%{keyword}%"), + ), + ).order_by(Service.updatedAt.desc()).offset((page - 1) * page_size).limit(page_size))).all(), + ) + elif search_type == SearchType.AUTHOR: + service_pools = list( + (await session.scalars(select(Service).where( + and_( + Service.id.in_(fav_service_ids), + Service.author.like(f"%{keyword}%"), + ), + ).order_by(Service.updatedAt.desc()).offset((page - 1) * page_size).limit(page_size))).all(), + ) + else: + service_pools = [] + total_count = len(service_pools) + + services = [ + ServiceCardItem( + serviceId=service_pool.id, + icon="", + name=service_pool.name, + description=service_pool.description, + author=service_pool.author, + favorited=True, + ) + for service_pool in service_pools + ] + return services, total_count @staticmethod async def create_service( user_sub: str, data: dict[str, Any], - ) -> str: + ) -> uuid.UUID: """创建服务""" - service_id = str(uuid.uuid4()) + service_id = uuid.uuid4() # 校验 OpenAPI 规范的 JSON Schema validated_data = await ServiceCenterManager._validate_service_data(data) @@ -136,7 +273,7 @@ class ServiceCenterManager: if db_service: msg = "[ServiceCenterManager] 已存在相同名称和描述的服务" - raise ServiceIDError(msg) + raise RuntimeError(msg) # 存入数据库 service_metadata = ServiceMetadata( @@ -157,9 +294,9 @@ class ServiceCenterManager: @staticmethod async def update_service( user_sub: str, - service_id: str, + service_id: uuid.UUID, data: dict[str, Any], - ) -> str: + ) -> uuid.UUID: """更新服务""" # 验证用户权限 async with postgres.session() as session: @@ -169,11 +306,12 @@ class ServiceCenterManager: ), )).one_or_none() if not db_service: - msg = "Service not found" - raise ServiceIDError(msg) + msg = "[ServiceCenterManager] Service not found" + raise RuntimeError(msg) if db_service.author != user_sub: - msg = "Permission denied" - raise InstancePermissionError(msg) + msg = "[ServiceCenterManager] Permission denied" + raise RuntimeError(msg) + db_service.updatedAt = datetime.now(tz=UTC) # 校验 OpenAPI 规范的 JSON Schema validated_data = await ServiceCenterManager._validate_service_data(data) # 存入数据库 @@ -190,53 +328,83 @@ class ServiceCenterManager: # 返回服务ID return service_id + @staticmethod async def get_service_apis( - service_id: str, + service_id: uuid.UUID, ) -> tuple[str, list[ServiceApiData]]: """获取服务API列表""" # 获取服务名称 async with postgres.session() as session: - db_service = (await session.scalars( - select(Service).where( + service_name = (await session.scalars( + select(Service.name).where( Service.id == service_id, ), )).one_or_none() - if not db_service: - msg = "Service not found" - raise ServiceIDError(msg) - # 根据 service_id 获取 API 列表 - node_collection = MongoDB().get_collection("node") - db_nodes = await node_collection.find({"service_id": service_id}).to_list() - api_list: list[ServiceApiData] = [] - for db_node in db_nodes: - node = NodePool.model_validate(db_node) - api_list.append( + if not service_name: + msg = "[ServiceCenterManager] Service not found" + raise RuntimeError(msg) + + # 根据 service_id 获取 API 列表 + node_data = (await session.scalars( + select(NodeInfo).where( + NodeInfo.serviceId == service_id, + ), + )).all() + api_list = [ ServiceApiData( name=node.name, - path=f"{node.known_params['method'].upper()} {node.known_params['url']}" - if node.known_params and "method" in node.known_params and "url" in node.known_params + path=f"{node.knownParams['method'].upper()} {node.knownParams['url']}" + if node.knownParams and "method" in node.knownParams and "url" in node.knownParams else "", description=node.description, - ), - ) - return service_pool_store.name, api_list + ) + for node in node_data + ] + return service_name, api_list @staticmethod async def get_service_data( user_sub: str, - service_id: str, + service_id: uuid.UUID, ) -> tuple[str, dict[str, Any]]: """获取服务数据""" # 验证用户权限 + async with postgres.session() as session: + db_service = (await session.scalars( + select(Service).where( + and_( + Service.id == service_id, + Service.author == user_sub, + ), + ), + )).one_or_none() + + if not db_service: + msg = "[ServiceCenterManager] Service not found or permission denied" + raise RuntimeError(msg) + + service_path = ( + Path(config.deploy.data_dir) / "semantics" / "service" / str(service_id) / "openapi" / "api.yaml" + ) + async with await service_path.open() as f: + service_data = yaml.safe_load(await f.read()) + return db_service.name, service_data + + + @staticmethod + async def get_service_metadata( + user_sub: str, + service_id: uuid.UUID, + ) -> ServiceMetadata: + """获取服务元数据""" async with postgres.session() as session: allowed_user = list((await session.scalars( select(ServiceACL.userSub).where( ServiceACL.serviceId == service_id, ), )).all()) - if user_sub in allowed_user: db_service = (await session.scalars( select(Service).where( @@ -257,46 +425,15 @@ class ServiceCenterManager: Service.permission == PermissionType.PRIVATE, ), Service.permission == PermissionType.PUBLIC, + Service.author == user_sub, ), ), ), )).one_or_none() - if not db_service: msg = "[ServiceCenterManager] Service not found or permission denied" raise RuntimeError(msg) - service_path = ( - Path(config.deploy.data_dir) / "semantics" / "service" / service_id / "openapi" / "api.yaml" - ) - async with await service_path.open() as f: - service_data = yaml.safe_load(await f.read()) - return db_service.name, service_data - - - @staticmethod - async def get_service_metadata( - user_sub: str, - service_id: uuid.UUID, - ) -> ServiceMetadata: - """获取服务元数据""" - service_collection = MongoDB().get_collection("service") - match_conditions = [ - {"author": user_sub}, - {"permission.type": PermissionType.PUBLIC.value}, - { - "$and": [ - {"permission.type": PermissionType.PROTECTED.value}, - {"permission.users": user_sub}, - ], - }, - ] - query = {"$and": [{"_id": service_id}, {"$or": match_conditions}]} - db_service = await service_collection.find_one(query) - if not db_service: - msg = "Service not found" - raise ServiceIDError(msg) - metadata_path = ( Path(config.deploy.data_dir) / "semantics" / "service" / str(service_id) / "metadata.yaml" ) @@ -304,94 +441,108 @@ class ServiceCenterManager: metadata_data = yaml.safe_load(await f.read()) return ServiceMetadata.model_validate(metadata_data) + @staticmethod - async def delete_service( - user_sub: str, - service_id: str, - ) -> bool: + async def delete_service(user_sub: str, service_id: uuid.UUID) -> None: """删除服务""" - service_collection = MongoDB().get_collection("service") - user_collection = MongoDB().get_collection("user") - db_service = await service_collection.find_one({"_id": service_id}) - if not db_service: - msg = "[ServiceCenterManager] Service未找到" - raise ServiceIDError(msg) - # 验证用户权限 - service_pool_store = ServicePool.model_validate(db_service) - if service_pool_store.author != user_sub: - msg = "Permission denied" - raise InstancePermissionError(msg) - # 删除服务 - service_loader = ServiceLoader() - await service_loader.delete(service_id) - # 删除用户收藏 - await user_collection.update_many( - {"fav_services": {"$in": [service_id]}}, - {"$pull": {"fav_services": service_id}}, - ) - return True + async with postgres.session() as session: + db_service = (await session.scalars( + select(Service).where( + and_( + Service.id == service_id, + Service.author == user_sub, + ), + ), + )).one_or_none() + if not db_service: + msg = "[ServiceCenterManager] Service not found or permission denied" + raise RuntimeError(msg) + + # 删除服务 + service_loader = ServiceLoader() + await service_loader.delete(service_id) + # 删除ACL + await session.execute( + delete(ServiceACL).where( + ServiceACL.serviceId == service_id, + ), + ) + # 删除收藏 + await session.execute( + delete(UserFavorite).where( + UserFavorite.itemId == service_id, + UserFavorite.type == UserFavoriteType.SERVICE, + ), + ) + await session.commit() + @staticmethod async def modify_favorite_service( user_sub: str, - service_id: str, + service_id: uuid.UUID, *, favorited: bool, - ) -> bool: + ) -> None: """修改收藏状态""" - service_collection = MongoDB().get_collection("service") - user_collection = MongoDB().get_collection("user") - db_service = await service_collection.find_one({"_id": service_id}) - if not db_service: - msg = f"[ServiceCenterManager] Service未找到: {service_id}" - logger.warning(msg) - raise ServiceIDError(msg) - db_user = await user_collection.find_one({"_id": user_sub}) - if not db_user: - msg = f"[ServiceCenterManager] 用户未找到: {user_sub}" - logger.warning(msg) - raise ServiceIDError(msg) - user_data = User.model_validate(db_user) - already_favorited = service_id in user_data.fav_services - if already_favorited == favorited: - return False - if favorited: - await user_collection.update_one( - {"_id": user_sub}, - {"$addToSet": {"fav_services": service_id}}, - ) - else: - await user_collection.update_one( - {"_id": user_sub}, - {"$pull": {"fav_services": service_id}}, - ) - return True + async with postgres.session() as session: + db_service = (await session.scalars( + select(func.count(Service.id)).where( + Service.id == service_id, + ), + )).one() + if not db_service: + msg = f"[ServiceCenterManager] Service未找到: {service_id}" + logger.warning(msg) + raise RuntimeError(msg) + + user = (await session.scalars( + select(func.count(User.userSub)).where( + User.userSub == user_sub, + ), + )).one() + if not user: + msg = f"[ServiceCenterManager] 用户未找到: {user_sub}" + logger.warning(msg) + raise RuntimeError(msg) + + # 检查是否已收藏 + user_favourite = (await session.scalars( + select(UserFavorite).where( + and_( + UserFavorite.itemId == service_id, + UserFavorite.userSub == user_sub, + UserFavorite.type == UserFavoriteType.SERVICE, + ), + ), + )).one_or_none() + if not user_favourite and favorited: + # 创建收藏条目 + user_favourite = UserFavorite( + itemId=service_id, + userSub=user_sub, + type=UserFavoriteType.SERVICE, + ) + session.add(user_favourite) + await session.commit() + elif user_favourite and not favorited: + # 删除收藏条目 + await session.delete(user_favourite) + await session.commit() - @staticmethod - async def _search_service( - search_conditions: dict, - page: int, - page_size: int, - ) -> tuple[list[ServicePool], int]: - """基于输入条件获取服务数据""" - service_collection = MongoDB().get_collection("service") - # 获取服务总数 - total = await service_collection.count_documents(search_conditions) - # 分页查询 - skip = (page - 1) * page_size - db_services = await service_collection.find(search_conditions).skip(skip).limit(page_size).to_list() - if not db_services and total > 0: - logger.warning("[ServiceCenterManager] 没有找到符合条件的服务: %s", search_conditions) - return [], -1 - service_pools = [ServicePool.model_validate(db_service) for db_service in db_services] - return service_pools, total @staticmethod - async def _get_favorite_service_ids_by_user(user_sub: str) -> list[str]: + async def _get_favorite_service_ids_by_user(user_sub: str) -> list[uuid.UUID]: """获取用户收藏的服务ID""" - user_collection = MongoDB().get_collection("user") - user_data = User.model_validate(await user_collection.find_one({"_id": user_sub})) - return user_data.fav_services + async with postgres.session() as session: + user_favourite = (await session.scalars( + select(UserFavorite).where( + UserFavorite.userSub == user_sub, + UserFavorite.type == UserFavoriteType.SERVICE, + ), + )).all() + return [user_favourite.itemId for user_favourite in user_favourite] + @staticmethod async def _validate_service_data(data: dict[str, Any]) -> ReducedOpenAPISpec: -- Gitee