diff --git a/Dockerfile b/Dockerfile index 7180b6a9e4fc7316a27efb46a431676dbbb31780..51ea82487f3236e0a8d1d432cd49310a3884bd11 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,4 +6,4 @@ ENV TIKTOKEN_CACHE_DIR=/app/assets/tiktoken COPY --chmod=550 ./ /app/ RUN chmod 766 /root -CMD ["uv", "run", "--no-sync", "--no-dev", "apps/main.py"] +CMD ["uv", "run", "--no-sync", "python", "-m", "apps.main"] diff --git a/Dockerfile-base-arm b/Dockerfile-base-arm index 4d2f89d879475c9a8a7555a6572fc90324b415e2..48218cfd903a8bdee0fae1bc510638918c82dd3b 100644 --- a/Dockerfile-base-arm +++ b/Dockerfile-base-arm @@ -21,7 +21,7 @@ WORKDIR /app RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,source=pyproject.toml,target=pyproject.toml \ # 使用正确的参数替代 --no-extra-index-url - uv sync --no-dev --no-install-project --no-cache-dir \ + uv sync --no-install-project --no-cache-dir \ --index-url https://mirrors.aliyun.com/pypi/simple/ \ --config-setting installer.pip.index-urls=https://mirrors.aliyun.com/pypi/simple/ \ --config-setting installer.pip.trusted-hosts=mirrors.aliyun.com diff --git a/Dockerfile-base-x86 b/Dockerfile-base-x86 index f88ebf32891ba3a286f00e01402df4900fd96621..5971a51fde7e05c05ffdf06e7fd4b97c8d5d390b 100644 --- a/Dockerfile-base-x86 +++ b/Dockerfile-base-x86 @@ -16,4 +16,4 @@ RUN sed -i 's|repo.openeuler.org|repo.huaweicloud.com/openeuler|g' /etc/yum.repo WORKDIR /app RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,source=pyproject.toml,target=pyproject.toml \ - uv sync --no-dev --no-install-project --no-cache-dir \ No newline at end of file + uv sync --no-install-project --no-cache-dir \ No newline at end of file diff --git a/apps/dependency/__init__.py b/apps/dependency/__init__.py index 83c26ea2f11c0fdbaf704a9fdcd91b0eacb57825..1d35c5dce4254d47d96d4fd64f0d8efd5a7cc5d1 100644 --- a/apps/dependency/__init__.py +++ b/apps/dependency/__init__.py @@ -4,15 +4,15 @@ from apps.dependency.user import ( get_session, get_user, - get_user_by_api_key, - verify_api_key, + get_personal_token_user, + verify_personal_token, verify_user, ) __all__ = [ "get_session", "get_user", - "get_user_by_api_key", - "verify_api_key", + "get_personal_token_user", + "verify_personal_token", "verify_user", ] diff --git a/apps/dependency/user.py b/apps/dependency/user.py index fce67e51e00dd76b35bec2f7787dfe8039111589..547f32850394f047964c9f72806ee39d6dde8524 100644 --- a/apps/dependency/user.py +++ b/apps/dependency/user.py @@ -9,14 +9,14 @@ from starlette import status from starlette.exceptions import HTTPException from starlette.requests import HTTPConnection -from apps.services.api_key import ApiKeyManager +from apps.services.personal_token import PersonalTokenManager from apps.services.session import SessionManager logger = logging.getLogger(__name__) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") -async def _get_session_id_from_request(request: HTTPConnection) -> str | None: +async def _extract_session_id(request: HTTPConnection) -> str | None: """ 从请求中获取 session_id @@ -48,7 +48,7 @@ async def get_session(request: HTTPConnection) -> str: :param request: HTTP请求 :return: Session ID """ - session_id = await _get_session_id_from_request(request) + session_id = await _extract_session_id(request) if not session_id: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -69,7 +69,7 @@ async def get_user(request: HTTPConnection) -> str: :param request: HTTP请求体 :return: 用户sub """ - session_id = await _get_session_id_from_request(request) + session_id = await 
_extract_session_id(request) if not session_id: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -88,25 +88,25 @@ async def get_user(request: HTTPConnection) -> str: return user_sub -async def verify_api_key(api_key: str = Depends(oauth2_scheme)) -> None: +async def verify_personal_token(personal_token: str = Depends(oauth2_scheme)) -> None: """ - 验证API Key是否有效;无效则抛出HTTP 401;接口级dependence + 验证Personal Token是否有效;无效则抛出HTTP 401;接口级dependence - :param api_key: API Key + :param personal_token: Personal Token :return: """ - if not await ApiKeyManager.verify_api_key(api_key): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key!") + if not await PersonalTokenManager.verify_personal_token(personal_token): + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Personal Token!") -async def get_user_by_api_key(api_key: str = Depends(oauth2_scheme)) -> str: +async def get_personal_token_user(personal_token: str = Depends(oauth2_scheme)) -> str: """ - 验证API Key是否有效;若有效,返回对应的user_sub;若无效,抛出HTTP 401;参数级dependence + 验证Personal Token是否有效;若有效,返回对应的user_sub;若无效,抛出HTTP 401;参数级dependence - :param api_key: API Key + :param personal_token: Personal Token :return: 用户sub """ - user_sub = await ApiKeyManager.get_user_by_api_key(api_key) + user_sub = await PersonalTokenManager.get_user_by_personal_token(personal_token) if user_sub is None: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key!") + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Personal Token!") return user_sub diff --git a/apps/main.py b/apps/main.py index 43be981e94c37e4f605807f11280dd9149ea9e28..c9ce67303c569f042385c7098821d76182cace39 100644 --- a/apps/main.py +++ b/apps/main.py @@ -12,6 +12,7 @@ import logging import uvicorn from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +from fastapi_profiler import Profiler from rich.console import Console from rich.logging import RichHandler @@ -21,7 +22,6 @@ from .common.postgres import postgres from .common.wordscheck import WordsCheck from .llm.token import TokenCalculator from .routers import ( - api_key, appcenter, auth, blacklist, @@ -34,14 +34,17 @@ from .routers import ( knowledge, llm, mcp_service, + personal_token, record, service, + tag, user, ) from .scheduler.pool.pool import Pool # 定义FastAPI app app = FastAPI(redoc_url=None) +Profiler(app) # 定义FastAPI全局中间件 app.add_middleware( CORSMiddleware, @@ -53,7 +56,7 @@ app.add_middleware( # 关联API路由 app.include_router(conversation.router) app.include_router(auth.router) -app.include_router(api_key.router) +app.include_router(personal_token.router) app.include_router(appcenter.router) app.include_router(service.router) app.include_router(comment.router) @@ -67,6 +70,7 @@ app.include_router(llm.router) app.include_router(mcp_service.router) app.include_router(flow.router) app.include_router(user.router) +app.include_router(tag.router) # logger配置 LOGGER_FORMAT = "%(funcName)s() - %(message)s" diff --git a/apps/models/app.py b/apps/models/app.py new file mode 100644 index 0000000000000000000000000000000000000000..9bb0dd0c6048e3f7ce487b528df487ac4d615e67 --- /dev/null +++ b/apps/models/app.py @@ -0,0 +1,72 @@ +"""应用 数据库表""" + +import uuid +from datetime import datetime + +import pytz +from sqlalchemy import BigInteger, Boolean, DateTime, Enum, ForeignKey, Index, String +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column + +from 
apps.schemas.enum_var import AppType, PermissionType + +from .base import Base + + +class App(Base): + """应用""" + + __tablename__ = "framework_app" + name: Mapped[str] = mapped_column(String(255)) + """应用名称""" + description: Mapped[str] = mapped_column(String(2000)) + """应用描述""" + author: Mapped[str] = mapped_column(ForeignKey("framework_user.user_sub")) + """应用作者""" + type: Mapped[AppType] = mapped_column(Enum(AppType)) + """应用类型""" + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default_factory=uuid.uuid4) + """应用ID""" + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + onupdate=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + index=True, + ) + """应用更新时间""" + icon_path: Mapped[str] = mapped_column(String(255), default="") + """应用图标路径""" + is_published: Mapped[bool] = mapped_column(Boolean, default=False) + """是否发布""" + permission: Mapped[PermissionType] = mapped_column(Enum(PermissionType), default=PermissionType.PUBLIC) + """权限类型""" + __table_args__ = ( + Index("idx_published_updated_at", "is_published", "updated_at"), + Index("idx_author_id_name", "author", "id", "name"), + ) + + +class AppACL(Base): + """应用权限""" + + __tablename__ = "framework_app_acl" + app_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_app.id")) + """关联的应用ID""" + user_sub: Mapped[str] = mapped_column(ForeignKey("framework_user.user_sub")) + """用户名""" + action: Mapped[str] = mapped_column(String(255), default="") + """操作类型(读/写)""" + + +class AppHashes(Base): + """应用哈希""" + + __tablename__ = "framework_app_hashes" + id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False) + """主键ID""" + app_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_app.id")) + """关联的应用ID""" + file_path: Mapped[str] = mapped_column(String(255)) + """文件路径""" + hash: Mapped[str] = mapped_column(String(255)) + """哈希值""" diff --git a/apps/models/base.py b/apps/models/base.py index fd90ec0eebc1fc80681621237b842039a5c3b7dd..5f4424f2f4256b18a62398f8a7d3fc3667e17f15 100644 --- a/apps/models/base.py +++ b/apps/models/base.py @@ -1,5 +1,8 @@ """SQLAlchemy模型基类""" -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import DeclarativeBase, MappedAsDataclass + + +class Base(MappedAsDataclass, DeclarativeBase): + """SQLAlchemy模型基类""" -Base = declarative_base() diff --git a/apps/models/comment.py b/apps/models/comment.py new file mode 100644 index 0000000000000000000000000000000000000000..44559e4d3a2e120ef9c7717e12d900371a779651 --- /dev/null +++ b/apps/models/comment.py @@ -0,0 +1,36 @@ +"""评论 数据库表""" + +import uuid +from datetime import datetime + +import pytz +from sqlalchemy import ARRAY, BigInteger, DateTime, Enum, ForeignKey, String +from sqlalchemy.orm import Mapped, mapped_column + +from apps.schemas.enum_var import CommentType + +from .base import Base + + +class Comment(Base): + """评论""" + + __tablename__ = "framework_comment" + id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False) + """主键ID""" + record_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_record.id")) + """问答对ID""" + user_sub: Mapped[str] = mapped_column(ForeignKey("framework_user.user_sub")) + """用户名""" + comment_type: Mapped[CommentType] = mapped_column(Enum(CommentType)) + """点赞点踩""" + feedback_type: Mapped[list[str]] = mapped_column(ARRAY(String(100))) + """投诉类别""" + feedback_link: Mapped[str] = mapped_column(String(1000)) + 
"""投诉链接""" + feedback_content: Mapped[str] = mapped_column(String(1000)) + """投诉内容""" + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + ) + """评论创建时间""" diff --git a/apps/models/conversation.py b/apps/models/conversation.py new file mode 100644 index 0000000000000000000000000000000000000000..a205935895f9402240b74fb00357177582a55e74 --- /dev/null +++ b/apps/models/conversation.py @@ -0,0 +1,49 @@ +"""对话 数据库表""" + +import uuid +from datetime import datetime + +import pytz +from sqlalchemy import BigInteger, Boolean, DateTime, ForeignKey, String +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column + +from apps.constants import NEW_CHAT + +from .base import Base + + +class Conversation(Base): + """对话""" + + __tablename__ = "framework_conversation" + user_sub: Mapped[str] = mapped_column(ForeignKey("framework_user.user_sub")) + """用户名""" + app_id: Mapped[uuid.UUID | None] = mapped_column(ForeignKey("framework_app.id"), nullable=True) + """对话使用的App的ID""" + title: Mapped[str] = mapped_column(String(255), default=NEW_CHAT) + """对话标题""" + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default_factory=lambda: uuid.uuid4(), + ) + """对话ID""" + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + ) + """对话创建时间""" + is_temporary: Mapped[bool] = mapped_column(Boolean, default=False) + """是否为临时对话""" + + +class ConversationDocument(Base): + """对话所用的临时文件""" + + __tablename__ = "framework_conversation_document" + id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False) + """主键ID""" + conversation_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_conversation.id")) + """对话ID""" + document_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_document.id")) + """文件ID""" + is_unused: Mapped[bool] = mapped_column(Boolean, default=True) + """是否未使用""" diff --git a/apps/models/document.py b/apps/models/document.py new file mode 100644 index 0000000000000000000000000000000000000000..b46e0e88f417f41fa127089c258ab52cf49371e6 --- /dev/null +++ b/apps/models/document.py @@ -0,0 +1,39 @@ +"""文件 数据库表""" + +import uuid +from datetime import datetime + +import pytz +from sqlalchemy import DateTime, Float, ForeignKey, Index, String +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column + +from .base import Base + + +class Document(Base): + """文件""" + + __tablename__ = "framework_document" + user_sub: Mapped[str] = mapped_column(ForeignKey("framework_user.user_sub")) + """用户ID""" + name: Mapped[str] = mapped_column(String(255)) + """文件名称""" + type: Mapped[str] = mapped_column(String(100)) + """文件类型""" + size: Mapped[float] = mapped_column(Float) + """文件大小""" + conversation_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_conversation.id")) + """所属对话的ID""" + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default_factory=uuid.uuid4, + ) + """文件的ID""" + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + ) + """文件的创建时间""" + __table_args__ = ( + Index("idx_user_sub_conversation_id", "user_sub", "conversation_id"), + ) diff --git a/apps/models/flow.py b/apps/models/flow.py new file mode 100644 index 
0000000000000000000000000000000000000000..1471ab685388601db49a0bf0552e6f6e45082a01 --- /dev/null +++ b/apps/models/flow.py @@ -0,0 +1,57 @@ +"""Flow 数据库表""" + +import uuid +from datetime import datetime + +import pytz +from sqlalchemy import Boolean, DateTime, ForeignKey, Index, String +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column + +from .base import Base + + +class Flow(Base): + """Flow""" + + __tablename__ = "framework_flow" + app_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_app.id")) + """所属App的ID""" + name: Mapped[str] = mapped_column(String(255)) + """Flow的名称""" + description: Mapped[str] = mapped_column(String(2000)) + """Flow的描述""" + path: Mapped[str] = mapped_column(String(255)) + """Flow的路径""" + debug: Mapped[bool] = mapped_column(Boolean, default=False) + """是否经过调试""" + enabled: Mapped[bool] = mapped_column(Boolean, default=True) + """是否启用""" + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default_factory=uuid.uuid4, + ) + """Flow的ID""" + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + onupdate=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + ) + """Flow的更新时间""" + __table_args__ = ( + Index("idx_app_id_id", "app_id", "id"), + Index("idx_app_id_name", "app_id", "name"), + ) + + +class FlowContext(Base): + """Flow上下文""" + + __tablename__ = "framework_flow_context" + flow_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_flow.id")) + """所属Flow的ID""" + context: Mapped[str] = mapped_column(String(2000)) + """上下文""" + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default_factory=uuid.uuid4, + ) + """Flow上下文的ID""" diff --git a/apps/models/node.py b/apps/models/node.py new file mode 100644 index 0000000000000000000000000000000000000000..2d25a410e26338afd0e3ef3def40bfb2fcd10cdf --- /dev/null +++ b/apps/models/node.py @@ -0,0 +1,40 @@ +"""节点 数据库表""" + +import uuid +from datetime import datetime +from typing import Any + +import pytz +from sqlalchemy import DateTime, Enum, ForeignKey, String +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.orm import Mapped, mapped_column + +from .base import Base + + +class Node(Base): + """节点""" + + __tablename__ = "framework_node" + name: Mapped[str] = mapped_column(String(255)) + """节点名称""" + description: Mapped[str] = mapped_column(String(2000)) + """节点描述""" + service_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_service.id")) + """所属服务ID""" + call_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_call.id")) + """所属CallID""" + known_params: Mapped[dict[str, Any]] = mapped_column(JSONB) + """已知的用于Call部分的参数,独立于输入和输出之外""" + override_input: Mapped[dict[str, Any]] = mapped_column(JSONB) + """Node的输入Schema;用于描述Call的参数中特定的字段""" + override_output: Mapped[dict[str, Any]] = mapped_column(JSONB) + """Node的输出Schema;用于描述Call的输出中特定的字段""" + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default_factory=uuid.uuid4) + """节点ID""" + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + onupdate=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + ) + """节点更新时间""" diff --git a/apps/models/record.py b/apps/models/record.py new file mode 100644 index 0000000000000000000000000000000000000000..5c51a659ebed3838e851a804ea97e51bd7bd9598 --- /dev/null 
+++ b/apps/models/record.py @@ -0,0 +1,43 @@ +"""问答对 数据库表""" + +import uuid +from datetime import datetime + +import pytz +from sqlalchemy import BigInteger, DateTime, ForeignKey +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column + +from .base import Base + + +class Record(Base): + """问答对""" + + __tablename__ = "framework_record" + user_sub: Mapped[str] = mapped_column(ForeignKey("framework_user.user_sub")) + """用户名""" + conversation_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_conversation.id")) + """对话ID""" + task_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_task.id")) + """任务ID""" + id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), primary_key=True, default_factory=lambda: uuid.uuid4(), + ) + """问答对ID""" + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + ) + """问答对创建时间""" + + +class RecordDocument(Base): + """问答对关联的文件""" + + __tablename__ = "framework_record_document" + id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False) + """主键ID""" + record_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_record.id")) + """问答对ID""" + document_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_document.id")) + """文件ID""" diff --git a/apps/models/service.py b/apps/models/service.py new file mode 100644 index 0000000000000000000000000000000000000000..5b33f67abb01dcbfee5221eeabb9433534238511 --- /dev/null +++ b/apps/models/service.py @@ -0,0 +1,63 @@ +"""插件 数据库表""" + +import uuid +from datetime import datetime + +import pytz +from sqlalchemy import BigInteger, DateTime, Enum, ForeignKey, String +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import Mapped, mapped_column + +from apps.schemas.enum_var import PermissionType + +from .base import Base + + +class Service(Base): + """插件""" + + __tablename__ = "framework_service" + name: Mapped[str] = mapped_column(String(255)) + """插件名称""" + description: Mapped[str] = mapped_column(String(2000)) + """插件描述""" + author: Mapped[str] = mapped_column(ForeignKey("framework_user.user_sub")) + """插件作者""" + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + onupdate=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")), + ) + """插件更新时间""" + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default_factory=uuid.uuid4) + """插件ID""" + icon_path: Mapped[str] = mapped_column(String(255), default="") + """插件图标路径""" + permission: Mapped[PermissionType] = mapped_column(Enum(PermissionType), default=PermissionType.PUBLIC) + """权限类型""" + + +class ServiceACL(Base): + """插件权限""" + + __tablename__ = "framework_service_acl" + service_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_service.id")) + """关联的插件ID""" + user_sub: Mapped[str] = mapped_column(ForeignKey("framework_user.user_sub")) + """用户名""" + action: Mapped[str] = mapped_column(String(255), default="") + """操作类型(读/写)""" + + +class ServiceHashes(Base): + """插件哈希""" + + __tablename__ = "framework_service_hashes" + id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False) + """主键ID""" + service_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_service.id")) + """关联的插件ID""" + file_path: Mapped[str] = mapped_column(String(255)) + """文件路径""" + hash: Mapped[str] = 
mapped_column(String(255))
+    """哈希值"""
diff --git a/apps/models/tag.py b/apps/models/tag.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b35df87fdc62455c41425f4d00a0bfef71220b9
--- /dev/null
+++ b/apps/models/tag.py
@@ -0,0 +1,27 @@
+"""用户标签 数据库表"""
+
+from datetime import datetime
+
+import pytz
+from sqlalchemy import BigInteger, DateTime, String
+from sqlalchemy.orm import Mapped, mapped_column
+
+from .base import Base
+
+
+class Tag(Base):
+    """用户标签"""
+
+    __tablename__ = "framework_tag"
+    id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False)
+    """标签ID"""
+    name: Mapped[str] = mapped_column(String(255), index=True, unique=True)
+    """标签名称"""
+    definition: Mapped[str] = mapped_column(String(2000))
+    """标签定义"""
+    updated_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True),
+        default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")),
+        onupdate=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")),
+    )
+    """标签的更新时间"""
diff --git a/apps/models/user.py b/apps/models/user.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fd20204dc3ca36e422bf7da8bd31880ba8d28f3
--- /dev/null
+++ b/apps/models/user.py
@@ -0,0 +1,92 @@
+"""用户表"""
+
+import enum
+import uuid
+from datetime import datetime
+from hashlib import sha256
+
+import pytz
+from sqlalchemy import BigInteger, Boolean, DateTime, Enum, ForeignKey, Integer, String
+from sqlalchemy.dialects.postgresql import UUID
+from sqlalchemy.orm import Mapped, mapped_column
+
+from .base import Base
+
+
+class User(Base):
+    """用户表"""
+
+    __tablename__ = "framework_user"
+    id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False)
+    """用户ID"""
+    user_sub: Mapped[str] = mapped_column(String(50), index=True, unique=True)
+    """用户名"""
+    last_login: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")),
+    )
+    """用户最后一次登录时间"""
+    is_active: Mapped[bool] = mapped_column(Boolean, default=True)
+    """用户是否活跃"""
+    is_whitelisted: Mapped[bool] = mapped_column(Boolean, default=False)
+    """用户是否白名单"""
+    credit: Mapped[int] = mapped_column(Integer, default=100)
+    """用户风控分"""
+    personal_token: Mapped[str] = mapped_column(
+        String(100), default_factory=lambda: sha256(str(uuid.uuid4()).encode()).hexdigest()[:16],
+    )
+    """用户个人令牌"""
+    is_admin: Mapped[bool] = mapped_column(Boolean, default=False)
+    """用户是否管理员"""
+
+
+class UserFavoriteType(str, enum.Enum):
+    """用户收藏类型"""
+
+    app = "app"
+    service = "service"
+
+
+class UserFavorite(Base):
+    """用户收藏"""
+
+    __tablename__ = "framework_user_favorite"
+    id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False)
+    """用户收藏ID"""
+    user_sub: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.user_sub"), index=True)
+    """用户名"""
+    type: Mapped[UserFavoriteType] = mapped_column(Enum(UserFavoriteType))
+    """收藏类型"""
+    item_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), index=True)
+    """收藏项目ID(App/Service ID)"""
+
+
+class UserAppUsage(Base):
+    """用户应用使用情况"""
+
+    __tablename__ = "framework_user_app_usage"
+    id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False)
+    """用户应用使用情况ID"""
+    user_sub: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.user_sub"), index=True)
+    """用户名"""
+    app_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), index=True)
+    """应用ID"""
+    usage_count: Mapped[int] = mapped_column(Integer, default=0)
+    """应用使用次数"""
+    last_used: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), default_factory=lambda: datetime.now(tz=pytz.timezone("Asia/Shanghai")),
+    )
+    """用户最后一次使用时间"""
+
+
+class UserTag(Base):
+    """用户标签"""
+
+    __tablename__ = "framework_user_tag"
+    id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False)
+    """用户标签ID"""
+    user_sub: Mapped[str] = mapped_column(String(50), ForeignKey("framework_user.user_sub"), index=True)
+    """用户名"""
+    tag: Mapped[int] = mapped_column(BigInteger, ForeignKey("framework_tag.id"), index=True)
+    """标签ID"""
+    count: Mapped[int] = mapped_column(Integer, default=0)
+    """标签归类次数"""
diff --git a/apps/models/vector.py b/apps/models/vector.py
deleted file mode 100644
index 1cdc85c9b9fbe5aaa902e1b93238001e5bbe44c5..0000000000000000000000000000000000000000
--- a/apps/models/vector.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved.
-"""向量数据库数据结构;数据将存储在LanceDB中"""
-
-from lancedb.pydantic import LanceModel, Vector
-
-
-class FlowPoolVector(LanceModel):
-    """App向量信息"""
-
-    id: str
-    app_id: str
-    embedding: Vector(dim=1024)  # type: ignore[call-arg]
-
-
-class ServicePoolVector(LanceModel):
-    """Service向量信息"""
-
-    id: str
-    embedding: Vector(dim=1024)  # type: ignore[call-arg]
-
-
-class CallPoolVector(LanceModel):
-    """Call向量信息"""
-
-    id: str
-    embedding: Vector(dim=1024)  # type: ignore[call-arg]
-
-
-class NodePoolVector(LanceModel):
-    """Node向量信息"""
-
-    id: str
-    service_id: str
-    embedding: Vector(dim=1024)  # type: ignore[call-arg]
diff --git a/apps/models/vectors.py b/apps/models/vectors.py
index 6ce20f5515e7baa2c7a80c4eae3f60a05eecf8f4..af30a090a2d4a3cec563a7fae69f9ac4a024463f 100644
--- a/apps/models/vectors.py
+++ b/apps/models/vectors.py
@@ -1,7 +1,9 @@
 """向量表"""
 
+import uuid
+
 from pgvector.sqlalchemy import Vector
-from sqlalchemy import Integer, String
+from sqlalchemy import ForeignKey, Index
 from sqlalchemy.orm import Mapped, mapped_column
 
 from .base import Base
@@ -10,35 +12,59 @@ from .base import Base
 class FlowPoolVector(Base):
     """Flow向量数据"""
 
-    __tablename__ = "flow_vector"
-    id: Mapped[int] = mapped_column(Integer, primary_key=True)
-    flow_id: Mapped[str] = mapped_column(String(32), index=True, unique=True)
-    app_id: Mapped[str] = mapped_column(String(32))
+    __tablename__ = "framework_flow_vector"
+    app_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_app.id"))
+    """所属App的ID"""
     embedding: Mapped[Vector] = mapped_column(Vector(1024))
+    """向量数据"""
+    id: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_flow.id"), primary_key=True)
+    """Flow的ID"""
+    __table_args__ = (
+        Index(
+            "idx_flow_vector_hnsw",
+            "embedding",
+            postgresql_using="hnsw",
+            postgresql_with={"m": 16, "ef_construction": 64},
+            postgresql_ops={"embedding": "vector_cosine_ops"},
+        ),
+    )
 
 
 class ServicePoolVector(Base):
     """Service向量数据"""
 
-    __tablename__ = "service_vector"
-    id: Mapped[int] = mapped_column(Integer, primary_key=True)
-    service_id: Mapped[str] = mapped_column(String(32), index=True, unique=True)
-    embedding: Mapped[Vector] = mapped_column(Vector(1024))
-
-
-class CallPoolVector(Base):
-    """Call向量数据"""
-
-    __tablename__ = "call_vector"
-    id: Mapped[int] = mapped_column(Integer, primary_key=True)
-    call_id: Mapped[str] = mapped_column(String(32), index=True, unique=True)
+    __tablename__ = "framework_service_vector"
     embedding: Mapped[Vector] = mapped_column(Vector(1024))
+    """向量数据"""
+    id: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_service.id"), primary_key=True)
+    """Service的ID"""
+    __table_args__ = (
+        Index(
+            "idx_service_vector_hnsw",
+            "embedding",
+            postgresql_using="hnsw",
+            postgresql_with={"m": 16, "ef_construction": 64},
+            postgresql_ops={"embedding": "vector_cosine_ops"},
+        ),
+    )
 
 
 class NodePoolVector(Base):
     """Node向量数据"""
 
-    __tablename__ = "node_vector"
-    id: Mapped[int] = mapped_column(Integer, primary_key=True)
-    service_id: Mapped[str] = mapped_column(String(32))
+    __tablename__ = "framework_node_vector"
     embedding: Mapped[Vector] = mapped_column(Vector(1024))
+    """向量数据"""
+    service_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_service.id"))
+    """Service的ID"""
+    id: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_node.id"), primary_key=True)
+    """Node的ID"""
+    __table_args__ = (
+        Index(
+            "idx_node_vector_hnsw",
+            "embedding",
+            postgresql_using="hnsw",
+            postgresql_with={"m": 16, "ef_construction": 64},
+            postgresql_ops={"embedding": "vector_cosine_ops"},
+        ),
+    )
diff --git a/apps/routers/auth.py b/apps/routers/auth.py
index 1cba5ed629d90776b59ae3bf652381318dcabf0e..87e8e8f0d7a5206d50123a0c3f6e8018ecb0284e 100644
--- a/apps/routers/auth.py
+++ b/apps/routers/auth.py
@@ -11,7 +11,6 @@ from fastapi.templating import Jinja2Templates
 
 from apps.common.oidc import oidc_provider
 from apps.dependency import get_session, get_user, verify_user
-from apps.schemas.collection import Audit
 from apps.schemas.response_data import (
     AuthUserMsg,
     AuthUserRsp,
@@ -19,7 +18,6 @@ from apps.schemas.response_data import (
     OidcRedirectRsp,
     ResponseData,
 )
-from apps.services.audit_log import AuditLogManager
 from apps.services.session import SessionManager
 from apps.services.token import TokenManager
 from apps.services.user import UserManager
@@ -62,32 +60,15 @@ async def oidc_login(request: Request, code: str) -> HTMLResponse:
 
     if not user_sub:
         logger.error("OIDC no user_sub associated.")
 
-        data = Audit(
-            http_method="get",
-            module="auth",
-            client_ip=user_host,
-            message="/api/auth/login: OIDC no user_sub associated.",
-        )
-        await AuditLogManager.add_audit_log(data)
         return templates.TemplateResponse(
             "login_failed.html.j2",
             {"request": request, "reason": "未能获取用户信息,请关闭本窗口并重试。"},
             status_code=status.HTTP_403_FORBIDDEN,
         )
 
-    await UserManager.update_userinfo_by_user_sub(user_sub)
+    await UserManager.update_user(user_sub)
     current_session = await SessionManager.create_session(user_host, user_sub)
-
-    data = Audit(
-        user_sub=user_sub,
-        http_method="get",
-        module="auth",
-        client_ip=user_host,
-        message="/api/auth/login: User login.",
-    )
-    await AuditLogManager.add_audit_log(data)
-
     return templates.TemplateResponse(
         "login_success.html.j2",
         {"request": request, "current_session": current_session},
@@ -114,15 +95,6 @@ async def logout(
     await TokenManager.delete_plugin_token(user_sub)
     await SessionManager.delete_session(session_id)
 
-    data = Audit(
-        http_method="get",
-        module="auth",
-        client_ip=request.client.host,
-        user_sub=user_sub,
-        message="/api/auth/logout: User logout succeeded.",
-    )
-    await AuditLogManager.add_audit_log(data)
-
     return JSONResponse(
         status_code=status.HTTP_200_OK,
         content=ResponseData(
@@ -151,60 +123,23 @@ async def oidc_redirect() -> JSONResponse:
 @router.post("/logout", dependencies=[Depends(verify_user)], response_model=ResponseData)
 async def oidc_logout(token: str) -> JSONResponse:
     """OIDC主动触发登出"""
-
-
-@router.get("/user", response_model=AuthUserRsp)
-async def userinfo(
-    user_sub: Annotated[str, Depends(get_user)],
-    _: Annotated[None, 
Depends(verify_user)], -) -> JSONResponse: - """获取用户信息""" - user = await UserManager.get_userinfo_by_user_sub(user_sub=user_sub) - if not user: - return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="Get UserInfo failed.", - result={}, - ).model_dump(exclude_none=True, by_alias=True), - ) return JSONResponse( status_code=status.HTTP_200_OK, - content=AuthUserRsp( + content=ResponseData( code=status.HTTP_200_OK, message="success", - result=AuthUserMsg( - user_sub=user_sub, - revision=user.is_active, - is_admin=user.is_admin, - ), + result={}, ).model_dump(exclude_none=True, by_alias=True), ) -@router.post( - "/update_revision_number", - dependencies=[Depends(verify_user)], - response_model=AuthUserRsp, - responses={ - status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": ResponseData}, - }, -) -async def update_revision_number(request: Request, user_sub: Annotated[str, Depends(get_user)]) -> JSONResponse: # noqa: ARG001 - """更新用户协议信息""" - ret: bool = await UserManager.update_userinfo_by_user_sub(user_sub, refresh_revision=True) - if not ret: - return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=ResponseData( - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="update revision failed", - result={}, - ).model_dump(exclude_none=True, by_alias=True), - ) - - user = await UserManager.get_userinfo_by_user_sub(user_sub) +@router.get("/user", response_model=AuthUserRsp) +async def userinfo( + user_sub: Annotated[str, Depends(get_user)], + _: Annotated[None, Depends(verify_user)], +) -> JSONResponse: + """获取用户信息""" + user = await UserManager.get_user(user_sub=user_sub) if not user: return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -214,7 +149,6 @@ async def update_revision_number(request: Request, user_sub: Annotated[str, Depe result={}, ).model_dump(exclude_none=True, by_alias=True), ) - return JSONResponse( status_code=status.HTTP_200_OK, content=AuthUserRsp( diff --git a/apps/routers/comment.py b/apps/routers/comment.py index d096a2708ca2a991ea1dfd096c94c5d6aee43ddf..7a807692de8d85c16099f20c29aa16513340abce 100644 --- a/apps/routers/comment.py +++ b/apps/routers/comment.py @@ -28,7 +28,7 @@ router = APIRouter( @router.post("", response_model=ResponseData) async def add_comment(post_body: AddCommentData, user_sub: Annotated[str, Depends(get_user)]) -> JSONResponse: """给Record添加评论""" - if not await RecordManager.verify_record_in_group(post_body.group_id, post_body.record_id, user_sub): + if not await RecordManager.verify_record_in_group(post_body.record_id, user_sub): logger.error("[Comment] record_id 不存在") return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content=ResponseData( code=status.HTTP_400_BAD_REQUEST, @@ -43,7 +43,7 @@ async def add_comment(post_body: AddCommentData, user_sub: Annotated[str, Depend reason_description=post_body.reason_description, feedback_time=round(datetime.now(tz=UTC).timestamp(), 3), ) - await CommentManager.update_comment(post_body.group_id, post_body.record_id, comment_data) + await CommentManager.update_comment(post_body.record_id, comment_data) return JSONResponse(status_code=status.HTTP_200_OK, content=ResponseData( code=status.HTTP_200_OK, message="success", diff --git a/apps/routers/conversation.py b/apps/routers/conversation.py index 1620ab5421a4018fd316a6d55939d859befd5282..10832ce6a89ecdbbf2b8c5d813d5fce4985e3fd5 100644 --- a/apps/routers/conversation.py +++ b/apps/routers/conversation.py @@ 
-10,7 +10,7 @@ from fastapi import APIRouter, Body, Depends, Query, Request, status from fastapi.responses import JSONResponse from apps.dependency import get_user, verify_user -from apps.schemas.collection import Audit, Conversation +from apps.models.conversation import Conversation from apps.schemas.request_data import ( DeleteConversationData, ModifyConversationData, @@ -29,7 +29,6 @@ from apps.schemas.response_data import ( UpdateConversationRsp, ) from apps.services.application import AppManager -from apps.services.audit_log import AuditLogManager from apps.services.conversation import ConversationManager from apps.services.document import DocumentManager @@ -224,7 +223,6 @@ async def update_conversation( @router.delete("", response_model=ResponseData) async def delete_conversation( - request: Request, post_body: DeleteConversationData, user_sub: Annotated[str, Depends(get_user)], ) -> JSONResponse: @@ -234,21 +232,7 @@ async def delete_conversation( # 删除对话 await ConversationManager.delete_conversation_by_conversation_id(user_sub, conversation_id) # 删除对话对应的文件 - await DocumentManager.delete_document_by_conversation_id(user_sub, conversation_id) - - # 添加审计日志 - request_host = None - if request.client is not None: - request_host = request.client.host - data = Audit( - user_sub=user_sub, - http_method="delete", - module="/conversation", - client_ip=request_host, - message=f"deleted conversation with id: {conversation_id}", - ) - await AuditLogManager.add_audit_log(data) - + await DocumentManager.delete_document_by_conversation_id(conversation_id) deleted_conversation.append(conversation_id) return JSONResponse( diff --git a/apps/routers/mcp_service.py b/apps/routers/mcp_service.py index a845a3761ea2ab179a1b71d4674d0d5fd3b5a760..3db3eb663d727f331aca61cc4b421e71787fa1d3 100644 --- a/apps/routers/mcp_service.py +++ b/apps/routers/mcp_service.py @@ -37,7 +37,7 @@ router = APIRouter( ) async def _check_user_admin(user_sub: str) -> None: - user = await UserManager.get_userinfo_by_user_sub(user_sub) + user = await UserManager.get_user(user_sub) if not user: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户未登录") if not user.is_admin: diff --git a/apps/routers/api_key.py b/apps/routers/personal_token.py similarity index 37% rename from apps/routers/api_key.py rename to apps/routers/personal_token.py index 158cfc13a5be17f16e2d73ec5fa44c4a18e4f392..062e8677d6d5226381c8deff3838206828a7c6ea 100644 --- a/apps/routers/api_key.py +++ b/apps/routers/personal_token.py @@ -7,9 +7,9 @@ from fastapi import APIRouter, Depends, status from fastapi.responses import JSONResponse from apps.dependency.user import get_user, verify_user -from apps.schemas.api_key import GetAuthKeyRsp, PostAuthKeyMsg, PostAuthKeyRsp +from apps.schemas.personal_token import PostPersonalTokenMsg, PostPersonalTokenRsp from apps.schemas.response_data import ResponseData -from apps.services.api_key import ApiKeyManager +from apps.services.personal_token import PersonalTokenManager router = APIRouter( prefix="/api/auth/key", @@ -18,59 +18,26 @@ router = APIRouter( ) -@router.get("", response_model=GetAuthKeyRsp) -async def check_api_key_existence(user_sub: Annotated[str, Depends(get_user)]) -> JSONResponse: - """检查API密钥是否存在""" - exists: bool = await ApiKeyManager.api_key_exists(user_sub) - return JSONResponse(status_code=status.HTTP_200_OK, content=ResponseData( - code=status.HTTP_200_OK, - message="success", - result={ - "api_key_exists": exists, - }, - ).model_dump(exclude_none=True, by_alias=True)) - - 
@router.post("", responses={ 400: {"model": ResponseData}, -}, response_model=PostAuthKeyRsp) -async def manage_api_key(action: str, user_sub: Annotated[str, Depends(get_user)]) -> JSONResponse: +}, response_model=PostPersonalTokenRsp) +async def manage_personal_token(action: str, user_sub: Annotated[str, Depends(get_user)]) -> JSONResponse: """管理用户的API密钥""" action = action.lower() if action == "create": - api_key: str | None = await ApiKeyManager.generate_api_key(user_sub) + api_key: str | None = await PersonalTokenManager.generate_personal_token(user_sub) elif action == "update": - api_key: str | None = await ApiKeyManager.update_api_key(user_sub) - elif action == "delete": - success: bool = await ApiKeyManager.delete_api_key(user_sub) - if success: - return JSONResponse(status_code=status.HTTP_200_OK, content=ResponseData( - code=status.HTTP_200_OK, - message="success", - result={}, - ).model_dump(exclude_none=True, by_alias=True)) - return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content=ResponseData( - code=status.HTTP_400_BAD_REQUEST, - message="failed to revoke api key", - result={}, - ).model_dump(exclude_none=True, by_alias=True)) + api_key: str | None = await PersonalTokenManager.update_personal_token(user_sub) else: return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content=ResponseData( code=status.HTTP_400_BAD_REQUEST, message="invalid request", result={}, ).model_dump(exclude_none=True, by_alias=True)) - - if api_key is None: - return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content=ResponseData( - code=status.HTTP_400_BAD_REQUEST, - message="failed to generate api key", - result={}, - ).model_dump(exclude_none=True, by_alias=True)) - return JSONResponse(status_code=status.HTTP_200_OK, content=PostAuthKeyRsp( + return JSONResponse(status_code=status.HTTP_200_OK, content=PostPersonalTokenRsp( code=status.HTTP_200_OK, message="success", - result=PostAuthKeyMsg( + result=PostPersonalTokenMsg( api_key=api_key, ), ).model_dump(exclude_none=True, by_alias=True)) diff --git a/apps/routers/record.py b/apps/routers/record.py index 7384793b4b2abea5117462cb630c5b13c8f76071..72f277be2fdbc66f77cc568a319671f5d5be7282 100644 --- a/apps/routers/record.py +++ b/apps/routers/record.py @@ -80,7 +80,7 @@ async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_ ) # 获得Record关联的文档 - tmp_record.document = await DocumentManager.get_used_docs_by_record_group(user_sub, record_group.id) + tmp_record.document = await DocumentManager.get_used_docs_by_record(user_sub, record_group.id) # 获得Record关联的flow数据 flow_list = await TaskManager.get_context_by_record_id(record_group.id, record.id) diff --git a/apps/routers/domain.py b/apps/routers/tag.py similarity index 67% rename from apps/routers/domain.py rename to apps/routers/tag.py index 7eefafa6e4fea576d44f810bb6de7b35d617a085..31bec010d9ab41eb02d9ea13477d12a29b9642ee 100644 --- a/apps/routers/domain.py +++ b/apps/routers/tag.py @@ -1,17 +1,17 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
-"""FastAPI 用户画像相关API""" +"""FastAPI 用户标签相关API""" from fastapi import APIRouter, Depends, status from fastapi.responses import JSONResponse from apps.dependency.user import verify_user -from apps.schemas.request_data import PostDomainData +from apps.schemas.request_data import PostTagData from apps.schemas.response_data import ResponseData -from apps.services.domain import DomainManager +from apps.services.tag import TagManager router = APIRouter( - prefix="/api/domain", - tags=["domain"], + prefix="/api/tag", + tags=["tag"], dependencies=[ Depends(verify_user), ], @@ -19,71 +19,71 @@ router = APIRouter( @router.post("", response_model=ResponseData) -async def add_domain(post_body: PostDomainData) -> JSONResponse: - """添加用户领域画像""" - if await DomainManager.get_domain_by_domain_name(post_body.domain_name): +async def add_tags(post_body: PostTagData) -> JSONResponse: + """添加用户标签""" + if await TagManager.get_tag_by_name(post_body.tag): return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="add domain name is exist.", + message="[Tag] Add tag name is exist.", result={}, ).model_dump(exclude_none=True, by_alias=True)) try: - await DomainManager.add_domain(post_body) + await TagManager.add_tag(post_body) except Exception as e: # noqa: BLE001 return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message=f"add domain failed: {e!s}", + message=f"[Tag] Add tag failed: {e!s}", result={}, ).model_dump(exclude_none=True, by_alias=True)) return JSONResponse(status_code=status.HTTP_200_OK, content=ResponseData( code=status.HTTP_200_OK, - message="add domain success.", + message="[Tag] Add tag success.", result={}, ).model_dump(exclude_none=True, by_alias=True)) @router.put("", response_model=ResponseData) -async def update_domain(post_body: PostDomainData) -> JSONResponse: +async def update_tag(post_body: PostTagData) -> JSONResponse: """更新用户领域画像""" - if not await DomainManager.get_domain_by_domain_name(post_body.domain_name): + if not await TagManager.get_tag_by_name(post_body.tag): return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="update domain name is not exist.", + message="[Tag] Update tag name is not exist.", result={}, ).model_dump(exclude_none=True, by_alias=True)) - if not await DomainManager.update_domain_by_domain_name(post_body): + if not await TagManager.update_tag_by_name(post_body): return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="update domain failed", + message="[Tag] Update tag failed", result={}, ).model_dump(exclude_none=True, by_alias=True)) return JSONResponse(status_code=status.HTTP_200_OK, content=ResponseData( code=status.HTTP_200_OK, - message="update domain success.", + message="[Tag] Update tag success.", result={}, ).model_dump(exclude_none=True, by_alias=True)) @router.delete("", response_model=ResponseData) -async def delete_domain(post_body: PostDomainData) -> JSONResponse: +async def delete_tag(post_body: PostTagData) -> JSONResponse: """删除用户领域画像""" - if not await DomainManager.get_domain_by_domain_name(post_body.domain_name): + if not await TagManager.get_tag_by_name(post_body.tag): return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( 
code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message="delete domain name is not exist.", + message="[Tag] Delete tag name is not exist.", result={}, ).model_dump(exclude_none=True, by_alias=True)) try: - await DomainManager.delete_domain_by_domain_name(post_body) + await TagManager.delete_tag(post_body) except Exception as e: # noqa: BLE001 return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message=f"delete domain failed: {e!s}", + message=f"[Tag] Delete tag failed: {e!s}", result={}, ).model_dump(exclude_none=True, by_alias=True)) return JSONResponse(status_code=status.HTTP_200_OK, content=ResponseData( code=status.HTTP_200_OK, - message="delete domain success.", + message="[Tag] Delete tag success.", result={}, ).model_dump(exclude_none=True, by_alias=True)) diff --git a/apps/routers/user.py b/apps/routers/user.py index 54e12f444181b56e408f7face5c4ed37be76008b..b474031fa53f1d797b7fd63b38fd2c8d6dde0859 100644 --- a/apps/routers/user.py +++ b/apps/routers/user.py @@ -18,11 +18,11 @@ router = APIRouter( @router.get("") -async def chat( +async def list_user( user_sub: Annotated[str, Depends(get_user)], ) -> JSONResponse: """查询所有用户接口""" - user_list = await UserManager.get_all_user_sub() + user_list = await UserManager.list_user() user_info_list = [] for user in user_list: # user_info = await UserManager.get_userinfo_by_user_sub(user) 暂时不需要查询user_name diff --git a/apps/scheduler/call/api/api.py b/apps/scheduler/call/api/api.py index 54e06cb5b9b41f45059158c4fdba5bf2132edef8..d2a1e7235e1dd71c14d818acd0817fde7a9a8d66 100644 --- a/apps/scheduler/call/api/api.py +++ b/apps/scheduler/call/api/api.py @@ -67,27 +67,32 @@ class API(CoreCall, input_model=APIInput, output_model=APIOutput): async def _init(self, call_vars: CallVars) -> APIInput: """初始化API调用工具""" - self._service_id = "" + self._service_id = None self._session_id = call_vars.ids.session_id self._auth = None - if self.node: - # 获取对应API的Service Metadata + if not self.node: + raise CallError( + message="API工具调用时,必须指定Node", + data={}, + ) + + # 获取对应API的Service Metadata + if self.node.service_id: try: service_metadata = await ServiceCenterManager.get_service_metadata( call_vars.ids.user_sub, - self.node.service_id or "", + self.node.service_id, ) + # 获取Service对应的Auth + self._auth = service_metadata.api.auth + self._service_id = self.node.service_id except Exception as e: raise CallError( message="API接口的Service Metadata获取失败", data={}, ) from e - # 获取Service对应的Auth - self._auth = service_metadata.api.auth - self._service_id = self.node.service_id or "" - return APIInput( url=self.url, method=self.method, @@ -166,23 +171,28 @@ class API(CoreCall, input_model=APIInput, output_model=APIOutput): req_cookie = {} req_params = {} - if self._auth.header: # type: ignore[attr-defined] # 如果header列表非空 - for item in self._auth.header: # type: ignore[attr-defined] - req_header[item.name] = item.value - elif self._auth.cookie: # type: ignore[attr-defined] # 如果cookie列表非空 - for item in self._auth.cookie: # type: ignore[attr-defined] - req_cookie[item.name] = item.value - elif self._auth.query: # type: ignore[attr-defined] # 如果query列表非空 - for item in self._auth.query: # type: ignore[attr-defined] - req_params[item.name] = item.value - elif self._auth.oidc: # type: ignore[attr-defined] # 如果oidc配置存在 - token = await TokenManager.get_plugin_token( - self._service_id, - self._session_id, - await oidc_provider.get_access_token_url(), - 30, - ) - req_header.update({"access-token": 
token}) + if self._auth: + # 如果header列表非空 + if self._auth.header: + for item in self._auth.header: + req_header[item.name] = item.value + # 如果cookie列表非空 + if self._auth.cookie: + for item in self._auth.cookie: + req_cookie[item.name] = item.value + # 如果query列表非空 + if self._auth.query: + for item in self._auth.query: + req_params[item.name] = item.value + # 如果oidc配置存在 + if self._auth.oidc: + token = await TokenManager.get_plugin_token( + self._service_id, + self._session_id, + await oidc_provider.get_access_token_url(), + 30, + ) + req_header.update({"access-token": token}) return req_header, req_cookie, req_params diff --git a/apps/scheduler/call/core.py b/apps/scheduler/call/core.py index 2b1cbba83b9345e8df84e8f23f893137cb46532b..1161e545448a0b479053abde69d68e2bb1ad7687 100644 --- a/apps/scheduler/call/core.py +++ b/apps/scheduler/call/core.py @@ -14,8 +14,8 @@ from pydantic.json_schema import SkipJsonSchema from apps.llm.function import FunctionLLM from apps.llm.reasoning import ReasoningLLM +from apps.models.node import Node from apps.schemas.enum_var import CallOutputType -from apps.schemas.pool import NodePool from apps.schemas.scheduler import ( CallError, CallIds, @@ -51,7 +51,7 @@ class CoreCall(BaseModel): name: SkipJsonSchema[str] = Field(description="Step的名称", exclude=True) description: SkipJsonSchema[str] = Field(description="Step的描述", exclude=True) - node: SkipJsonSchema[NodePool | None] = Field(description="节点信息", exclude=True) + node: SkipJsonSchema[Node | None] = Field(description="节点信息", exclude=True) enable_filling: SkipJsonSchema[bool] = Field(description="是否需要进行自动参数填充", default=False, exclude=True) tokens: SkipJsonSchema[CallTokens] = Field( description="Call的输入输出Tokens信息", @@ -158,7 +158,7 @@ class CoreCall(BaseModel): @classmethod - async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self: + async def instance(cls, executor: "StepExecutor", node: Node | None, **kwargs: Any) -> Self: """实例化Call类""" obj = cls( name=executor.step.step.name, diff --git a/apps/scheduler/call/facts/facts.py b/apps/scheduler/call/facts/facts.py index 72490cba582255e1f0f05de9de45c1fad46ef6d6..2f7f4aa5e399fadd072f64d15897dbd49e994639 100644 --- a/apps/scheduler/call/facts/facts.py +++ b/apps/scheduler/call/facts/facts.py @@ -8,11 +8,11 @@ from jinja2 import BaseLoader from jinja2.sandbox import SandboxedEnvironment from pydantic import Field +from apps.models.node import Node from apps.scheduler.call.core import CoreCall from apps.schemas.enum_var import CallOutputType -from apps.schemas.pool import NodePool from apps.schemas.scheduler import CallInfo, CallOutputChunk, CallVars -from apps.services.user_domain import UserDomainManager +from apps.services.user_tag import UserTagManager from .prompt import DOMAIN_PROMPT, FACTS_PROMPT from .schema import ( @@ -39,7 +39,7 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): @classmethod - async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self: + async def instance(cls, executor: "StepExecutor", node: Node | None, **kwargs: Any) -> Self: """初始化工具""" obj = cls( answer=executor.task.runtime.answer, @@ -95,7 +95,7 @@ class FactsCall(CoreCall, input_model=FactsInput, output_model=FactsOutput): ], DomainGen) # type: ignore[arg-type] for domain in domain_list.keywords: - await UserDomainManager.update_user_domain_by_user_sub_and_domain_name(data.user_sub, domain) + await UserTagManager.update_user_domain_by_user_sub_and_domain_name(data.user_sub, 
domain) yield CallOutputChunk( type=CallOutputType.DATA, diff --git a/apps/scheduler/call/slot/slot.py b/apps/scheduler/call/slot/slot.py index 5f13a672bd6a94e94ecfc4fe8663629fb05cc684..4ae07cbd5197a73d13a75462884519fc5efbaae6 100644 --- a/apps/scheduler/call/slot/slot.py +++ b/apps/scheduler/call/slot/slot.py @@ -11,10 +11,10 @@ from pydantic import Field from apps.llm.function import FunctionLLM, JsonGenerator from apps.llm.reasoning import ReasoningLLM +from apps.models.node import Node from apps.scheduler.call.core import CoreCall from apps.scheduler.slot.slot import Slot as SlotProcessor from apps.schemas.enum_var import CallOutputType -from apps.schemas.pool import NodePool from apps.schemas.scheduler import CallInfo, CallOutputChunk, CallVars from .prompt import SLOT_GEN_PROMPT @@ -94,7 +94,7 @@ class Slot(CoreCall, input_model=SlotInput, output_model=SlotOutput): return await json_gen.generate() @classmethod - async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self: + async def instance(cls, executor: "StepExecutor", node: Node | None, **kwargs: Any) -> Self: """实例化Call类""" obj = cls( name=executor.step.step.name, diff --git a/apps/scheduler/call/suggest/suggest.py b/apps/scheduler/call/suggest/suggest.py index b7d981df297d5adee29862f90a55fa0ed80d3125..23fddae4ec64b02d377d8f95b82d3d83d157214d 100644 --- a/apps/scheduler/call/suggest/suggest.py +++ b/apps/scheduler/call/suggest/suggest.py @@ -12,9 +12,9 @@ from pydantic.json_schema import SkipJsonSchema from apps.common.security import Security from apps.llm.function import FunctionLLM +from apps.models.node import Node from apps.scheduler.call.core import CoreCall from apps.schemas.enum_var import CallOutputType -from apps.schemas.pool import NodePool from apps.schemas.record import RecordContent from apps.schemas.scheduler import ( CallError, @@ -23,7 +23,7 @@ from apps.schemas.scheduler import ( CallVars, ) from apps.services.record import RecordManager -from apps.services.user_domain import UserDomainManager +from apps.services.user_tag import UserTagManager from .prompt import SUGGEST_PROMPT from .schema import ( @@ -48,6 +48,7 @@ class Suggestion(CoreCall, input_model=SuggestionInput, output_model=SuggestionO context: SkipJsonSchema[list[dict[str, str]]] = Field(description="Executor的上下文", exclude=True) conversation_id: SkipJsonSchema[str] = Field(description="对话ID", exclude=True) + @classmethod def info(cls) -> CallInfo: """返回Call的名称和描述""" @@ -55,7 +56,7 @@ class Suggestion(CoreCall, input_model=SuggestionInput, output_model=SuggestionO @classmethod - async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self: + async def instance(cls, executor: "StepExecutor", node: Node | None, **kwargs: Any) -> Self: """初始化""" context = [ { @@ -138,7 +139,7 @@ class Suggestion(CoreCall, input_model=SuggestionInput, output_model=SuggestionO ) # 获取当前用户的画像 - user_domain = await UserDomainManager.get_user_domain_by_user_sub_and_topk(data.user_sub, 5) + user_domain = await UserTagManager.get_user_domain_by_user_sub_and_topk(data.user_sub, 5) # 已推送问题数量 pushed_questions = 0 # 初始化Prompt diff --git a/apps/scheduler/call/summary/summary.py b/apps/scheduler/call/summary/summary.py index 1dd99b819a89db6081edea327ce85e80756e4f72..754dc6e54ba01981605e9e91cba20df00b6b7c23 100644 --- a/apps/scheduler/call/summary/summary.py +++ b/apps/scheduler/call/summary/summary.py @@ -7,9 +7,9 @@ from typing import TYPE_CHECKING, Any, Self from pydantic import Field from 
apps.llm.patterns.executor import ExecutorSummary +from apps.models.node import Node from apps.scheduler.call.core import CoreCall, DataBase from apps.schemas.enum_var import CallOutputType -from apps.schemas.pool import NodePool from apps.schemas.scheduler import ( CallInfo, CallOutputChunk, @@ -35,7 +35,7 @@ class Summary(CoreCall, input_model=DataBase, output_model=SummaryOutput): return CallInfo(name="理解上下文", description="使用大模型,理解对话上下文") @classmethod - async def instance(cls, executor: "StepExecutor", node: NodePool | None, **kwargs: Any) -> Self: + async def instance(cls, executor: "StepExecutor", node: Node | None, **kwargs: Any) -> Self: """实例化工具""" obj = cls( context=executor.background, diff --git a/apps/scheduler/pool/loader/__init__.py b/apps/scheduler/pool/loader/__init__.py index c78ef6c4cd93ab9915d0d13cf8db8ae16a0bb3f3..5bb7c5804385a99ea30262f793787eb8911372b7 100644 --- a/apps/scheduler/pool/loader/__init__.py +++ b/apps/scheduler/pool/loader/__init__.py @@ -2,14 +2,12 @@ """配置加载器""" from .app import AppLoader -from .call import CallLoader from .flow import FlowLoader from .mcp import MCPLoader from .service import ServiceLoader __all__ = [ "AppLoader", - "CallLoader", "FlowLoader", "MCPLoader", "ServiceLoader", diff --git a/apps/scheduler/pool/loader/app.py b/apps/scheduler/pool/loader/app.py index 50d571e92c79671963470b4f9f5aa6d2ff841b4e..f7403b040bae8ce4831e8474fe0ce0ee4f3a6bcb 100644 --- a/apps/scheduler/pool/loader/app.py +++ b/apps/scheduler/pool/loader/app.py @@ -9,11 +9,11 @@ from fastapi.encoders import jsonable_encoder from apps.common.config import config from apps.common.mongo import MongoDB +from apps.models.app import App from apps.scheduler.pool.check import FileChecker from apps.schemas.agent import AgentAppMetadata from apps.schemas.enum_var import AppType from apps.schemas.flow import AppFlow, AppMetadata, MetadataType, Permission -from apps.schemas.pool import AppPool from .flow import FlowLoader from .metadata import MetadataLoader diff --git a/apps/scheduler/pool/loader/btdl.py b/apps/scheduler/pool/loader/btdl.py index 072303b7fa7e0ebd1ca4ed8ef3258ff64431aed8..c4fac922d3ff0ba0d4d05b37a49afd8fc1a24a0f 100644 --- a/apps/scheduler/pool/loader/btdl.py +++ b/apps/scheduler/pool/loader/btdl.py @@ -45,7 +45,11 @@ class BTDLLoader: for value in argument["properties"].values(): BTDLLoader._check_single_argument(value, strict=False) - def _load_single_subcmd(self, binary_name: str, subcmd_spec: dict[str, Any]) -> dict[str, tuple[str, str, dict[str, Any], dict[str, Any], str]]: + def _load_single_subcmd( + self, + binary_name: str, + subcmd_spec: dict[str, Any], + ) -> dict[str, tuple[str, str, dict[str, Any], dict[str, Any], str]]: """加载单个子命令""" if "name" not in subcmd_spec: err = "subcommand must have a name" diff --git a/apps/scheduler/pool/loader/call.py b/apps/scheduler/pool/loader/call.py index fb1b36af40798d2214e6c6a1ac3b06c45e16ad21..a485af28b669e6c07e4ba8888cde2fc68284d06f 100644 --- a/apps/scheduler/pool/loader/call.py +++ b/apps/scheduler/pool/loader/call.py @@ -14,9 +14,7 @@ from apps.common.lance import LanceDB from apps.common.mongo import MongoDB from apps.common.singleton import SingletonMeta from apps.llm.embedding import Embedding -from apps.models.vector import CallPoolVector -from apps.schemas.enum_var import CallType -from apps.schemas.pool import CallPool, NodePool +from apps.models.node import Node logger = logging.getLogger(__name__) BASE_PATH = Path(config.deploy.data_dir) / "semantics" / "call" @@ -51,98 +49,6 @@ class 
CallLoader(metaclass=SingletonMeta): return call_metadata - async def _load_single_call_dir(self, call_dir_name: str) -> list[CallPool]: - """加载单个Call package""" - call_metadata = [] - - call_dir = BASE_PATH / call_dir_name - if not (call_dir / "__init__.py").exists(): - logger.info("[CallLoader] 模块 %s 不存在__init__.py文件,尝试自动创建。", call_dir) - try: - (Path(call_dir) / "__init__.py").touch() - except Exception as e: - err = f"自动创建模块文件{call_dir}/__init__.py失败:{e}。" - raise RuntimeError(err) from e - - # 载入子包 - try: - call_package = importlib.import_module("call." + call_dir_name) - except Exception as e: - err = f"载入模块call.{call_dir_name}失败:{e}。" - raise RuntimeError(err) from e - - sys.modules["call." + call_dir_name] = call_package - - # 已载入包,处理包中每个工具 - if not hasattr(call_package, "__all__"): - err = f"[CallLoader] 模块call.{call_dir_name}不符合模块要求,无法处理。" - logger.info(err) - raise ValueError(err) - - for call_id in call_package.__all__: - try: - call_cls = getattr(call_package, call_id) - call_info = call_cls.info() - except AttributeError as e: - err = f"[CallLoader] 载入工具call.{call_dir_name}.{call_id}失败:{e};跳过载入。" - logger.info(err) - continue - - cls_path = f"{call_package.service}::call.{call_dir_name}.{call_id}" - cls_hash = shake_128(cls_path.encode()).hexdigest(8) - call_metadata.append( - CallPool( - _id=cls_hash, - type=CallType.PYTHON, - name=call_info.name, - description=call_info.description, - path=f"python::call.{call_dir_name}::{call_id}", - ), - ) - - return call_metadata - - async def _load_all_user_call(self) -> list[CallPool]: - """加载Python Call""" - call_dir = BASE_PATH - call_metadata = [] - - # 载入父包 - try: - sys.path.insert(0, str(call_dir.parent)) - if not (call_dir / "__init__.py").exists(): - logger.info("[CallLoader] 父模块 %s 不存在__init__.py文件,尝试自动创建。", call_dir) - (Path(call_dir) / "__init__.py").touch() - importlib.import_module("call") - except Exception as e: - err = f"[CallLoader] 父模块'call'创建失败:{e};无法载入。" - raise RuntimeError(err) from e - - # 处理每一个子包 - for call_file in Path(call_dir).rglob("*"): - if not call_file.is_dir() or call_file.name[0] == "_": - continue - # 载入包 - try: - call_metadata.extend(await self._load_single_call_dir(call_file.name)) - except RuntimeError as e: - err = f"[CallLoader] 载入模块{call_file}失败:{e},跳过载入。" - logger.info(err) - continue - - return call_metadata - - async def _delete_one(self, call_name: str) -> None: - """删除单个Call""" - # 从数据库中删除 - await self._delete_from_db(call_name) - - # 从Python中卸载模块 - call_dir = BASE_PATH / call_name - if call_dir.exists(): - module_name = f"call.{call_name}" - if module_name in sys.modules: - del sys.modules[module_name] async def _delete_from_db(self, call_name: str) -> None: """从数据库中删除单个Call""" diff --git a/apps/scheduler/pool/loader/flow.py b/apps/scheduler/pool/loader/flow.py index 1f7e13ab51e65d07ed1953aabc17cc6b286806f2..beea79d73ae1fb0bdc3565262498809e717b4ce6 100644 --- a/apps/scheduler/pool/loader/flow.py +++ b/apps/scheduler/pool/loader/flow.py @@ -8,17 +8,18 @@ from typing import Any import aiofiles import yaml from anyio import Path +from sqlalchemy import delete, insert from apps.common.config import config from apps.common.mongo import MongoDB +from apps.common.postgres import postgres from apps.llm.embedding import Embedding +from apps.models.app import App +from apps.models.vectors import FlowPoolVector from apps.scheduler.util import yaml_enum_presenter, yaml_str_presenter from apps.schemas.enum_var import EdgeType from apps.schemas.flow import AppFlow, Flow -from apps.schemas.pool import 
AppPool -from apps.models.vector import FlowPoolVector from apps.services.node import NodeManager -from apps.common.lance import LanceDB logger = logging.getLogger(__name__) BASE_PATH = Path(config.deploy.data_dir) / "semantics" / "app" @@ -181,11 +182,8 @@ class FlowLoader: logger.exception("[FlowLoader] 删除工作流文件失败:%s", flow_path) return False - table = await LanceDB().get_table("flow") - try: - await table.delete(f"id = '{flow_id}'") - except Exception: - logger.exception("[FlowLoader] LanceDB删除flow失败") + async with postgres.session() as session: + await session.execute(delete(FlowPoolVector).where(FlowPoolVector.flow_id == flow_id)) return True logger.warning("[FlowLoader] 工作流文件不存在或不是文件:%s", flow_path) return True @@ -236,36 +234,17 @@ class FlowLoader: logger.exception("[FlowLoader] 更新 MongoDB 失败") # 删除重复的ID - while True: - try: - table = await LanceDB().get_table("flow") - await table.delete(f"id = '{metadata.id}'") - break - except RuntimeError as e: - if "Commit conflict" in str(e): - logger.error("[FlowLoader] LanceDB删除flow冲突,重试中...") # noqa: TRY400 - await asyncio.sleep(0.01) - else: - raise + async with postgres.session() as session: + await session.execute(delete(FlowPoolVector).where(FlowPoolVector.flow_id == metadata.id)) + # 进行向量化 service_embedding = await Embedding.get_embedding([metadata.description]) vector_data = [ FlowPoolVector( - id=metadata.id, + flow_id=metadata.id, app_id=app_id, embedding=service_embedding[0], ), ] - while True: - try: - table = await LanceDB().get_table("flow") - await table.merge_insert("id").when_matched_update_all().when_not_matched_insert_all().execute( - vector_data, - ) - break - except RuntimeError as e: - if "Commit conflict" in str(e): - logger.error("[FlowLoader] LanceDB插入flow冲突,重试中...") # noqa: TRY400 - await asyncio.sleep(0.01) - else: - raise + async with postgres.session() as session: + await session.execute(insert(FlowPoolVector).values(vector_data)) diff --git a/apps/scheduler/pool/loader/openapi.py b/apps/scheduler/pool/loader/openapi.py index 98801c96dfd291928729f152fec1f33f7027a78d..cc99255b558b6c0c3d8ea20ba48a1038971cbaab 100644 --- a/apps/scheduler/pool/loader/openapi.py +++ b/apps/scheduler/pool/loader/openapi.py @@ -8,16 +8,15 @@ from typing import Any import yaml from anyio import Path +from apps.models.node import Node from apps.scheduler.openapi import ( ReducedOpenAPIEndpoint, ReducedOpenAPISpec, reduce_openapi_spec, ) -from apps.scheduler.slot.slot import Slot from apps.scheduler.util import yaml_str_presenter from apps.schemas.enum_var import ContentType, HTTPMethod from apps.schemas.node import APINode, APINodeInput, APINodeOutput -from apps.schemas.pool import NodePool logger = logging.getLogger(__name__) @@ -100,10 +99,6 @@ class OpenAPILoader: logger.warning(err) raise ValueError(err) from e - # 将数据转为已知JSON - known_body = Slot(inp.body).create_empty_slot() if inp.body else {} - known_query = Slot(inp.query).create_empty_slot() if inp.query else {} - try: out = APINodeOutput( result=spec.spec["responses"]["content"]["application/json"]["schema"], @@ -117,12 +112,13 @@ class OpenAPILoader: "url": url, "method": method, "content_type": content_type, - "body": known_body, - "query": known_query, + "body": {}, + "query": {}, } return inp, out, known_params + async def _process_spec( self, service_id: str, @@ -170,7 +166,7 @@ class OpenAPILoader: return spec - async def load_one(self, service_id: str, yaml_path: Path, server: str) -> list[NodePool]: + async def load_one(self, service_id: str, yaml_path: Path, server: str) 
-> list[Node]: """加载单个OpenAPI文档,可以直接指定路径""" try: spec = await self._read_yaml(yaml_path) @@ -189,8 +185,7 @@ class OpenAPILoader: raise RuntimeError(err) from e return [ - NodePool( - _id=node.id, + Node( name=node.name, description=node.description, call_id=node.call_id, diff --git a/apps/scheduler/pool/loader/service.py b/apps/scheduler/pool/loader/service.py index 2e640e50ce040d76e86fbcdabb1d32e3e04f2686..20b1ac1893a962545531dfa0e7fb9adb7a1d4361 100644 --- a/apps/scheduler/pool/loader/service.py +++ b/apps/scheduler/pool/loader/service.py @@ -1,21 +1,22 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """加载配置文件夹的Service部分""" -import asyncio import logging import shutil from anyio import Path from fastapi.encoders import jsonable_encoder +from sqlalchemy import delete, insert from apps.common.config import config -from apps.common.lance import LanceDB from apps.common.mongo import MongoDB +from apps.common.postgres import postgres from apps.llm.embedding import Embedding -from apps.models.vector import NodePoolVector, ServicePoolVector +from apps.models.node import Node +from apps.models.service import Service +from apps.models.vectors import NodePoolVector, ServicePoolVector from apps.scheduler.pool.check import FileChecker from apps.schemas.flow import Permission, ServiceMetadata -from apps.schemas.pool import NodePool, ServicePool from .metadata import MetadataLoader, MetadataType from .openapi import OpenAPILoader @@ -27,7 +28,8 @@ BASE_PATH = Path(config.deploy.data_dir) / "semantics" / "service" class ServiceLoader: """Service 加载器""" - async def load(self, service_id: str, hashes: dict[str, str]) -> None: + @staticmethod + async def load(service_id: str, hashes: dict[str, str]) -> None: """加载单个Service""" service_path = BASE_PATH / service_id # 载入元数据 @@ -47,10 +49,11 @@ class ServiceLoader: logger.exception("[ServiceLoader] 服务 %s 文件损坏", service_id) return # 更新数据库 - await self._update_db(nodes, metadata) + await ServiceLoader._update_db(nodes, metadata) - async def save(self, service_id: str, metadata: ServiceMetadata, data: dict) -> None: + @staticmethod + async def save(service_id: str, metadata: ServiceMetadata, data: dict) -> None: """在文件系统上保存Service,并更新数据库""" service_path = BASE_PATH / service_id # 创建文件夹 @@ -66,10 +69,11 @@ class ServiceLoader: # 重新载入 file_checker = FileChecker() await file_checker.diff_one(service_path) - await self.load(service_id, file_checker.hashes[f"service/{service_id}"]) + await ServiceLoader.load(service_id, file_checker.hashes[f"service/{service_id}"]) - async def delete(self, service_id: str, *, is_reload: bool = False) -> None: + @staticmethod + async def delete(service_id: str, *, is_reload: bool = False) -> None: """删除Service,并更新数据库""" mongo = MongoDB() service_collection = mongo.get_collection("service") @@ -81,13 +85,11 @@ class ServiceLoader: logger.exception("[ServiceLoader] 删除Service失败") try: - # 获取 LanceDB 表 - service_table = await LanceDB().get_table("service") - node_table = await LanceDB().get_table("node") - - # 删除数据 - await service_table.delete(f"id = '{service_id}'") - await node_table.delete(f"id = '{service_id}'") + # 删除postgres中的向量数据 + async with postgres.session() as session: + await session.execute(delete(ServicePoolVector).where(ServicePoolVector.service_id == service_id)) + await session.execute(delete(NodePoolVector).where(NodePoolVector.service_id == service_id)) + await session.commit() except Exception: logger.exception("[ServiceLoader] 删除数据库失败") @@ -97,7 +99,8 @@ class ServiceLoader: 
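# A minimal illustrative sketch, not part of the patch itself: the vector-cleanup
# pattern the new ServiceLoader.delete() uses, assuming postgres.session() yields an
# async SQLAlchemy session and that ServicePoolVector/NodePoolVector carry a
# service_id column as in the hunks above. drop_service_vectors is a hypothetical
# helper name used only for this example.
from sqlalchemy import delete

from apps.common.postgres import postgres
from apps.models.vectors import NodePoolVector, ServicePoolVector


async def drop_service_vectors(service_id: str) -> None:
    """Remove every vector row tied to one service in a single transaction."""
    async with postgres.session() as session:
        await session.execute(
            delete(ServicePoolVector).where(ServicePoolVector.service_id == service_id),
        )
        await session.execute(
            delete(NodePoolVector).where(NodePoolVector.service_id == service_id),
        )
        await session.commit()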
shutil.rmtree(path) - async def _update_db(self, nodes: list[NodePool], metadata: ServiceMetadata) -> None: # noqa: C901, PLR0912, PLR0915 + @staticmethod + async def _update_db(nodes: list[NodePool], metadata: ServiceMetadata) -> None: """更新数据库""" if not metadata.hashes: err = f"[ServiceLoader] 服务 {metadata.id} 的哈希值为空" @@ -134,68 +137,40 @@ class ServiceLoader: logger.exception(err) raise RuntimeError(err) from e - # 向量化所有数据并保存 - while True: - try: - service_table = await LanceDB().get_table("service") - node_table = await LanceDB().get_table("node") - await service_table.delete(f"id = '{metadata.id}'") - await node_table.delete(f"service_id = '{metadata.id}'") - break - except Exception as e: - if "Commit conflict" in str(e): - logger.error("[ServiceLoader] LanceDB删除service冲突,重试中...") # noqa: TRY400 - await asyncio.sleep(0.01) - else: - raise - - # 进行向量化,更新LanceDB + # 删除旧的向量数据 + async with postgres.session() as session: + await session.execute(delete(ServicePoolVector).where(ServicePoolVector.service_id == metadata.id)) + await session.execute(delete(NodePoolVector).where(NodePoolVector.service_id == metadata.id)) + await session.commit() + + # 进行向量化,更新postgres service_vecs = await Embedding.get_embedding([metadata.description]) - service_vector_data = [ - ServicePoolVector( - id=metadata.id, - embedding=service_vecs[0], - ), - ] - while True: - try: - service_table = await LanceDB().get_table("service") - await service_table.merge_insert("id").when_matched_update_all().when_not_matched_insert_all().execute( - service_vector_data, - ) - break - except Exception as e: - if "Commit conflict" in str(e): - logger.error("[ServiceLoader] LanceDB插入service冲突,重试中...") # noqa: TRY400 - await asyncio.sleep(0.01) - else: - raise + async with postgres.session() as session: + await session.execute( + insert(ServicePoolVector), + [ + { + "service_id": metadata.id, + "embedding": service_vecs[0], + }, + ], + ) + await session.commit() node_descriptions = [] for node in nodes: node_descriptions += [node.description] node_vecs = await Embedding.get_embedding(node_descriptions) - node_vector_data = [] - for i, vec in enumerate(node_vecs): - node_vector_data.append( - NodePoolVector( - id=nodes[i].id, - service_id=metadata.id, - embedding=vec, - ), + async with postgres.session() as session: + await session.execute( + insert(NodePoolVector), + [ + { + "service_id": metadata.id, + "embedding": vec, + } + for vec in node_vecs + ], ) - while True: - try: - node_table = await LanceDB().get_table("node") - await node_table.merge_insert("id").when_matched_update_all().when_not_matched_insert_all().execute( - node_vector_data, - ) - break - except Exception as e: - if "Commit conflict" in str(e): - logger.error("[ServiceLoader] LanceDB插入node冲突,重试中...") # noqa: TRY400 - await asyncio.sleep(0.01) - else: - raise - + await session.commit() diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index 525f9f19a05c28769a9107a51afcf9662d7d2e2f..10b9296cbedebd17631bb5f9329d5c243aaf78b3 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -33,7 +33,7 @@ async def get_docs(user_sub: str, post_body: RequestData) -> tuple[list[RecordDo doc_ids = [] - docs = await DocumentManager.get_used_docs_by_record_group(user_sub, post_body.group_id) + docs = await DocumentManager.get_used_docs_by_record(user_sub, post_body.group_id) if not docs: # 是新提问 # 从Conversation中获取刚上传的文档 diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py 
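# A minimal illustrative sketch, not part of the patch: the embed-then-bulk-insert flow
# that the rewritten _update_db() relies on. Embedding.get_embedding() and the
# NodePoolVector columns come from the hunks above; index_node_descriptions is a
# hypothetical helper name used only for this example.
from sqlalchemy import insert

from apps.common.postgres import postgres
from apps.llm.embedding import Embedding
from apps.models.vectors import NodePoolVector


async def index_node_descriptions(service_id: str, descriptions: list[str]) -> None:
    """Vectorise node descriptions and store one row per description for the service."""
    vectors = await Embedding.get_embedding(descriptions)
    async with postgres.session() as session:
        await session.execute(
            insert(NodePoolVector),
            [{"service_id": service_id, "embedding": vec} for vec in vectors],
        )
        await session.commit()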
index 9b4fccadf46e0a488537eace1c9442f5f1075564..68c9998001985338aaf952ef038c8493aa1d8aaa 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -7,7 +7,8 @@ from textwrap import dedent from apps.common.config import config from apps.common.queue import MessageQueue -from apps.schemas.collection import LLM, Document +from apps.models.document import Document +from apps.schemas.collection import LLM from apps.schemas.enum_var import EventType from apps.schemas.message import ( DocumentAddContent, @@ -119,7 +120,7 @@ async def _push_rag_chunk(task: Task, queue: MessageQueue, content: str) -> tupl async def push_document_message(task: Task, queue: MessageQueue, doc: RecordDocument | Document) -> Task: """推送文档消息""" content = DocumentAddContent( - documentId=doc.id, + documentId=str(doc.id), documentName=doc.name, documentType=doc.type, documentSize=round(doc.size, 2), diff --git a/apps/schemas/collection.py b/apps/schemas/collection.py index ae2b452fadcc662e99c0eadd2d9a26724fe906ff..29f76539fa90d3092f06518cd3589497083fcf16 100644 --- a/apps/schemas/collection.py +++ b/apps/schemas/collection.py @@ -7,7 +7,6 @@ from datetime import UTC, datetime from pydantic import BaseModel, Field from apps.common.config import config -from apps.constants import NEW_CHAT from apps.templates.generate_llm_operator_config import llm_provider_dict @@ -27,42 +26,6 @@ class Blacklist(BaseModel): updated_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) -class UserDomainData(BaseModel): - """用户领域数据""" - - name: str - count: int - - -class AppUsageData(BaseModel): - """User表子项:应用使用情况数据""" - - count: int = 0 - last_used: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) - - -class User(BaseModel): - """ - 用户信息 - - Collection: user - 外键:user - conversation - """ - - id: str = Field(alias="_id") - last_login: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) - is_active: bool = False - is_whitelisted: bool = False - credit: int = 100 - api_key: str | None = None - conversations: list[str] = [] - domains: list[UserDomainData] = [] - app_usage: dict[str, AppUsageData] = {} - fav_apps: list[str] = [] - fav_services: list[str] = [] - is_admin: bool = Field(default=False, description="是否为管理员") - - class LLM(BaseModel): """ 大模型信息 @@ -78,84 +41,3 @@ class LLM(BaseModel): model_name: str = Field(default=config.llm.model) max_tokens: int | None = Field(default=config.llm.max_tokens) created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) - - -class LLMItem(BaseModel): - """大模型信息""" - - llm_id: str = Field(default="empty") - model_name: str = Field(default=config.llm.model) - icon: str = Field(default=llm_provider_dict["ollama"]["icon"]) - - -class KnowledgeBaseItem(BaseModel): - """知识库信息""" - - kb_id: str - kb_name: str - - -class Conversation(BaseModel): - """ - 对话信息 - - Collection: conversation - 外键:conversation - task, document, record_group - """ - - id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") - user_sub: str - title: str = NEW_CHAT - created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) - app_id: str | None = Field(default="") - tasks: list[str] = [] - unused_docs: list[str] = [] - record_groups: list[str] = [] - debug: bool = Field(default=False) - llm: LLMItem | None = None - kb_list: list[KnowledgeBaseItem] = Field(default=[]) - - -class Document(BaseModel): - """ - 文件信息 - - 
Collection: document - """ - - id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") - user_sub: str - name: str - type: str - size: float - created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) - conversation_id: str - - -class Audit(BaseModel): - """ - 审计日志 - - Collection: audit - """ - - id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") - user_sub: str | None = None - http_method: str - created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) - module: str - client_ip: str | None = None - message: str - - -class Domain(BaseModel): - """ - 领域信息 - - Collection: domain - """ - - id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") - name: str - definition: str - updated_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) diff --git a/apps/schemas/enum_var.py b/apps/schemas/enum_var.py index 9f171f7bc96d1c8e66e21437bf420d388149de1f..fdd056f9ce11636c5fd9fc3409516fd930b1aa47 100644 --- a/apps/schemas/enum_var.py +++ b/apps/schemas/enum_var.py @@ -45,13 +45,6 @@ class EventType(str, Enum): DONE = "done" -class CallType(str, Enum): - """Call类型""" - - SYSTEM = "system" - PYTHON = "python" - - class MetadataType(str, Enum): """元数据类型""" diff --git a/apps/schemas/flow.py b/apps/schemas/flow.py index 542f7407417c67428208307344f85a7e70e208d9..1c071dee6016bdf2c721de32e4d5779153f7b7eb 100644 --- a/apps/schemas/flow.py +++ b/apps/schemas/flow.py @@ -149,10 +149,3 @@ class ServiceApiSpec(BaseModel): size: int = Field(description="OpenAPI文件大小(单位:KB)") path: str = Field(description="OpenAPI文件路径") hash: str = Field(description="OpenAPI文件的hash值") - - -class FlowConfig(BaseModel): - """Flow的配置信息 用于前期调试使用""" - - flow_id: str - flow_config: Flow diff --git a/apps/schemas/node.py b/apps/schemas/node.py index 2803bc909950c4b418a4f74a750aa782c361e8fd..8391109a82b3a96c949d938f19428ff16d15ce0e 100644 --- a/apps/schemas/node.py +++ b/apps/schemas/node.py @@ -5,8 +5,6 @@ from typing import Any from pydantic import BaseModel, Field -from .pool import NodePool - class APINodeInput(BaseModel): """API节点覆盖输入""" @@ -21,7 +19,7 @@ class APINodeOutput(BaseModel): result: dict[str, Any] | None = Field(description="API节点输出Schema", default=None) -class APINode(NodePool): +class APINode(BaseModel): """API节点""" call_id: str = "API" diff --git a/apps/schemas/api_key.py b/apps/schemas/personal_token.py similarity index 49% rename from apps/schemas/api_key.py rename to apps/schemas/personal_token.py index 9f0083c33267b372078dfbf7503ac3a9ee3aab31..815d20d68c6d32822d557b1eb7f167e6b2c7a127 100644 --- a/apps/schemas/api_key.py +++ b/apps/schemas/personal_token.py @@ -6,25 +6,13 @@ from pydantic import BaseModel from .response_data import ResponseData -class _GetAuthKeyMsg(BaseModel): - """GET /api/auth/key Result数据结构""" - - api_key_exists: bool - - -class GetAuthKeyRsp(ResponseData): - """GET /api/auth/key 返回数据结构""" - - result: _GetAuthKeyMsg - - -class PostAuthKeyMsg(BaseModel): +class PostPersonalTokenMsg(BaseModel): """POST /api/auth/key Result数据结构""" api_key: str -class PostAuthKeyRsp(ResponseData): +class PostPersonalTokenRsp(ResponseData): """POST /api/auth/key 返回数据结构""" - result: PostAuthKeyMsg + result: PostPersonalTokenMsg diff --git a/apps/schemas/pool.py b/apps/schemas/pool.py index 0e9cef22d165e7f0fd7c29cdfd2a401293dc73e8..c3b105c8c18963e78f634440ff928985642add45 100644 --- a/apps/schemas/pool.py +++ b/apps/schemas/pool.py @@ -6,19 +6,6 @@ from typing import 
Any from pydantic import BaseModel, Field -from .appcenter import AppLink -from .enum_var import AppType, CallType, PermissionType -from .flow import AppFlow, Permission - - -class BaseData(BaseModel): - """Pool的基础信息""" - - id: str = Field(alias="_id") - name: str - description: str - created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) - class ServiceApiInfo(BaseModel): """外部服务API信息""" @@ -28,85 +15,14 @@ class ServiceApiInfo(BaseModel): path: str = Field(description="OpenAPI文件路径") -class ServicePool(BaseData): - """ - 外部服务信息 - - collection: service - """ - - author: str = Field(description="作者的用户ID") - permission: Permission = Field(description="服务可见性配置", default=Permission(type=PermissionType.PUBLIC)) - hashes: dict[str, str] = Field(description="服务关联的 OpenAPI YAML 和元数据文件哈希") - - -class CallPool(BaseData): - """ - Call信息 - - collection: call - - “path”的格式如下: - 1. Python代码会被导入成包,路径格式为`python::::`,用于查找Call的包路径和类路径 - """ - - type: CallType = Field(description="Call的类型") - path: str = Field(description="Call的路径") - - -class Node(BaseData): +class Node(BaseModel): """Node合并后的信息(不存库)""" + id: str = Field(alias="_id") + name: str + description: str + created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) service_id: str | None = Field(description="Node所属的Service ID", default=None) call_id: str = Field(description="所使用的Call的ID") params_schema: dict[str, Any] = Field(description="Node的参数schema", default={}) output_schema: dict[str, Any] = Field(description="Node输出的完整Schema", default={}) - - -class NodePool(BaseData): - """ - Node合并前的信息(作为附带信息的指针) - - collection: node - - annotation为Node的路径,指示Node的类型、来源等 - annotation的格式如下: - 1. 无路径(如对应的Call等):为None - 2. 从openapi中获取:`openapi::` - """ - - service_id: str | None = Field(description="Node所属的Service ID", default=None) - call_id: str = Field(description="所使用的Call的ID") - known_params: dict[str, Any] | None = Field( - description="已知的用于Call部分的参数,独立于输入和输出之外", - default=None, - ) - override_input: dict[str, Any] | None = Field( - description="Node的输入Schema;用于描述Call的参数中特定的字段", - default=None, - ) - override_output: dict[str, Any] | None = Field( - description="Node的输出Schema;用于描述Call的输出中特定的字段", - default=None, - ) - - -class AppPool(BaseData): - """ - 应用信息 - - collection: app - """ - - author: str = Field(description="作者的用户ID") - app_type: AppType = Field(description="应用类型", default=AppType.FLOW) - type: str = Field(description="应用类型", default="default") - icon: str = Field(description="应用图标", default="") - published: bool = Field(description="是否发布", default=False) - links: list[AppLink] = Field(description="相关链接", default=[]) - first_questions: list[str] = Field(description="推荐问题", default=[]) - history_len: int = Field(3, ge=1, le=10, description="对话轮次(1~10)") - permission: Permission = Field(description="应用权限配置", default=Permission()) - flows: list[AppFlow] = Field(description="Flow列表", default=[]) - hashes: dict[str, str] = Field(description="关联文件的hash值", default={}) - mcp_service: list[str] = Field(default=[], alias="mcpService", description="MCP服务id列表") diff --git a/apps/schemas/record.py b/apps/schemas/record.py index d6117d8c18997d056404e2f4958cc0e82bccfaf3..6e2050b94c92e692609c701f7f4f9ba43fad0556 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -62,6 +62,7 @@ class RecordMetadata(BaseModel): time_cost: float = Field(default=0, alias="timeCost") feature: dict[str, Any] = {} + class RecordComment(BaseModel): """Record表子项:Record的评论信息""" @@ -102,21 +103,3 
@@ class Record(RecordData): content: str comment: RecordComment = Field(default=RecordComment()) flow: list[str] = Field(default=[]) - - -class RecordGroup(BaseModel): - """ - 问答组 - - 多次重新生成的问答都是一个问答组 - Collection: record_group - 外键:record_group - document - """ - - id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") - user_sub: str - records: list[Record] = [] - docs: list[RecordGroupDocument] = [] # 问题不变,所用到的文档不变 - conversation_id: str - task_id: str - created_at: float = Field(default_factory=lambda: round(datetime.now(tz=UTC).timestamp(), 3)) diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index e5bd09cfe27256357e370e9dd00aaa1ef4795df2..c5083385a75cd6b407e096f5c3a9c04d7c2ad8e2 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -1,6 +1,7 @@ # Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. """FastAPI 请求体""" +import uuid from typing import Any from pydantic import BaseModel, Field @@ -137,25 +138,24 @@ class ModifyConversationData(BaseModel): class DeleteConversationData(BaseModel): """删除会话""" - conversation_list: list[str] = Field(alias="conversationList") + conversation_list: list[uuid.UUID] = Field(alias="conversationList") class AddCommentData(BaseModel): """添加评论""" record_id: str - group_id: str comment: CommentType dislike_reason: str = Field(default="", max_length=200) reason_link: str = Field(default="", max_length=200) reason_description: str = Field(default="", max_length=500) -class PostDomainData(BaseModel): +class PostTagData(BaseModel): """添加领域""" - domain_name: str = Field(..., min_length=1, max_length=100) - domain_description: str = Field(..., max_length=2000) + tag: str = Field(..., min_length=1, max_length=100) + description: str = Field(..., max_length=2000) class PutFlowReq(BaseModel): diff --git a/apps/scripts/delete_user.py b/apps/scripts/delete_user.py index 7cc85a5406cae2b4c32906c9447b1a05c2a628f1..0baef4a0f3b0b2e4ebecbf8a64192a7c21ea75d3 100644 --- a/apps/scripts/delete_user.py +++ b/apps/scripts/delete_user.py @@ -7,8 +7,6 @@ from datetime import UTC, datetime, timedelta import asyncer from apps.common.mongo import MongoDB -from apps.schemas.collection import Audit -from apps.services.audit_log import AuditLogManager from apps.services.knowledge_base import KnowledgeBaseService from apps.services.session import SessionManager from apps.services.user import UserManager @@ -20,7 +18,7 @@ async def _delete_user(timestamp: float) -> None: """异步删除用户""" user_ids = await UserManager.query_userinfo_by_login_time(timestamp) for user_id in user_ids: - await UserManager.delete_userinfo_by_user_sub(user_id) + await UserManager.delete_user(user_id) # 查找用户关联的文件 doc_collection = MongoDB().get_collection("document") docs = [doc["_id"] async for doc in doc_collection.find({"user_sub": user_id})] @@ -31,13 +29,6 @@ async def _delete_user(timestamp: float) -> None: await KnowledgeBaseService.delete_doc_from_rag(session_id, docs) except Exception: logger.exception("[DeleteUserCron] 自动删除用户 %s 文档失败", user_id) - audit_log = Audit( - user_sub=user_id, - http_method="DELETE", - module="user", - message=f"Automatic deleted user: {user_id}, for inactive more than 30 days", - ) - await AuditLogManager.add_audit_log(audit_log) if __name__ == "__main__": diff --git a/apps/scripts/user_exporter.py b/apps/scripts/user_exporter.py deleted file mode 100644 index a1dbf30c80d29c720fed6e8aff6e60375fdd4be5..0000000000000000000000000000000000000000 --- a/apps/scripts/user_exporter.py +++ /dev/null 
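# A minimal illustrative sketch, not part of the patch: how the reworked request models
# above parse incoming JSON. The UUID and tag values are made up; the field names and
# the conversationList alias follow request_data.py as changed in this diff.
import uuid

from apps.schemas.request_data import DeleteConversationData, PostTagData

delete_req = DeleteConversationData.model_validate(
    {"conversationList": ["6f1b5d2e-8a45-4c1a-9e0a-3c9d3c1b2a10"]},
)
assert isinstance(delete_req.conversation_list[0], uuid.UUID)  # strings are coerced to UUID

tag_req = PostTagData(tag="networking", description="Questions about NIC bonding setup")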
@@ -1,276 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. -"""用户导出工具""" - -from __future__ import annotations - -import argparse -import datetime -import re -import secrets -import shutil -import zipfile -from pathlib import Path -from typing import ClassVar - -from openpyxl import Workbook - -from apps.common.security import Security -from apps.schemas.collection import Audit -from apps.services.audit_log import AuditLogManager -from apps.services.conversation import ConversationManager -from apps.services.record import RecordManager -from apps.services.user import UserManager - - -class UserExporter: - """用户导出工具类""" - - start_row_id: int = 1 - chat_xlsx_column: ClassVar[list[str]] = ["question", "answer", "created_time"] - chat_column_map: ClassVar[dict[str, int]] = { - "question_column": 1, - "answer_column": 2, - "created_time_column": 3, - } - user_info_xlsx_column: ClassVar[list[str]] = [ - "user_sub", "organization", - "created_time", "login_time", "revision_number", - ] - user_info_column_map: ClassVar[dict[str, int]] = { - "user_sub_column": 1, - "organization_column": 2, - "created_time_column": 3, - "login_time_column": 4, - "revision_number_column": 5, - } - - @staticmethod - def get_datetime_from_str(date_str: str, date_format: str) -> float: - """将日期字符串转换为时间戳""" - date_time_obj = datetime.datetime.strptime(date_str, date_format).astimezone(datetime.timezone.utc) - date_time_obj = datetime.datetime(date_time_obj.year, date_time_obj.month, date_time_obj.day).astimezone(datetime.timezone.utc) - return date_time_obj.timestamp() - - @staticmethod - def zip_xlsx_folder(tmp_out_dir: Path) -> Path: - """将xlsx文件夹压缩为zip文件""" - dir_name = tmp_out_dir.parent - last_dir_name = tmp_out_dir.name - xlsx_file_name_list = list(tmp_out_dir.glob("*.xlsx")) - zip_file_dir = Path(dir_name) / (last_dir_name + ".zip") - with zipfile.ZipFile(zip_file_dir, "w") as zip_file: - for xlsx_file_name in xlsx_file_name_list: - xlsx_file_path = tmp_out_dir / xlsx_file_name - zip_file.write(xlsx_file_path) - return zip_file_dir - - @staticmethod - def save_chat_to_xlsx(xlsx_dir, chat_list): - """将聊天记录保存到xlsx文件中""" - workbook = Workbook() - sheet = workbook.active - if sheet is None: - err = "Workbook没有active的sheet" - raise ValueError(err) - for i, column in enumerate(UserExporter.chat_xlsx_column): - sheet.cell(row=UserExporter.start_row_id, column=i+1, value=column) - row_id = UserExporter.start_row_id + 1 - for chat in chat_list: - question = chat[0] - answer = chat[1] - created_time = chat[2] - sheet.cell(row=row_id, - column=UserExporter.chat_column_map["question_column"], - value=question) - sheet.cell(row=row_id, - column=UserExporter.chat_column_map["answer_column"], - value=answer) - sheet.cell(row=row_id, - column=UserExporter.chat_column_map["created_time_column"], - value=created_time) - row_id += 1 - workbook.save(xlsx_dir) - - @staticmethod - def save_user_info_to_xlsx(xlsx_dir, user_info): - workbook = Workbook() - sheet = workbook.active - if sheet is None: - err = "Workbook没有active的sheet" - raise ValueError(err) - for i, column in enumerate(UserExporter.user_info_xlsx_column): - sheet.cell(row=UserExporter.start_row_id, column=i+1, value=column) - row_id = UserExporter.start_row_id + 1 - user_sub = user_info.user_sub - organization = user_info.organization - created_time = user_info.created_time - login_time = user_info.login_time - revision_number = user_info.revision_number - sheet.cell(row=row_id, - 
column=UserExporter.user_info_column_map["user_sub_column"], - value=user_sub) - sheet.cell(row=row_id, - column=UserExporter.user_info_column_map["organization_column"], - value=organization) - sheet.cell(row=row_id, - column=UserExporter.user_info_column_map["created_time_column"], - value=created_time) - sheet.cell(row=row_id, - column=UserExporter.user_info_column_map["login_time_column"], - value=login_time) - sheet.cell(row=row_id, - column=UserExporter.user_info_column_map["revision_number_column"], - value=revision_number) - workbook.save(xlsx_dir) - - @staticmethod - def export_user_info_to_xlsx(tmp_out_dir: Path | str, user_sub: str) -> None: - """ - 将用户信息导出到xlsx文件中。 - - Args: - tmp_out_dir (Path | str): 临时输出目录的路径 - user_sub (str): 用户的唯一标识符 - - Returns: - None - - Raises: - ValueError: 当用户信息获取失败时抛出 - - """ - info = UserManager.get_userinfo_by_user_sub(user_sub) - fname = f"user_info_{user_sub}.xlsx" - fname = re.sub(r'[<>:"/\\|?*]', "_", fname) - fname = fname.replace(" ", "_") - fpath = Path(tmp_out_dir) / fname - UserExporter.save_user_info_to_xlsx(fpath, info) - - @staticmethod - def export_chats_to_xlsx(tmp_out_dir, user_sub, start_day, end_day): - user_qa_records = ConversationManager.get_conversation_by_user_sub( - user_sub) - for user_qa_record in user_qa_records: - chat_id = user_qa_record.conversation_id - chat_tile = re.sub(r'[<>:"/\\|?*]', "_", user_qa_record.title) - chat_tile = chat_tile.replace(" ", "_")[:20] - chat_created_time = str(user_qa_record.created_time) - encrypted_qa_records = RecordManager.query_encrypted_data_by_conversation_id( - chat_id) - chat = [] - for record in encrypted_qa_records: - question = Security.decrypt(record.encrypted_question, - record.question_encryption_config) - answer = Security.decrypt(record.encrypted_answer, - record.answer_encryption_config) - qa_record_created_time = record.created_time - if start_day is not None and UserExporter.get_datetime_from_str(record.created_time, "%Y-%m-%d %H:%M:%S") < start_day: - continue - if end_day is not None and UserExporter.get_datetime_from_str(record.created_time, "%Y-%m-%d %H:%M:%S") > end_day: - continue - chat.append([question, answer, qa_record_created_time]) - xlsx_file_name = "chat_"+chat_tile[:20] + "_"+chat_created_time+".xlsx" - xlsx_file_name = xlsx_file_name.replace(" ", "") - xlsx_dir = Path(tmp_out_dir) / xlsx_file_name - UserExporter.save_chat_to_xlsx(xlsx_dir, chat) - - @staticmethod - def export_user_data(users_dir, user_sub, export_preferences=None, start_day=None, end_day=None): - export_preferences = export_preferences or ["user_info", "chat"] - rand_num = secrets.randbits(128) - tmp_out_dir = Path("./") / users_dir / str(rand_num) - if tmp_out_dir.exists(): - shutil.rmtree(tmp_out_dir) - tmp_out_dir.mkdir(parents=True, exist_ok=True) - tmp_out_dir.chmod(0o750) - if "user_info" in export_preferences: - UserExporter.export_user_info_to_xlsx(tmp_out_dir, user_sub) - if "chat" in export_preferences: - UserExporter.export_chats_to_xlsx(tmp_out_dir, user_sub, start_day, end_day) - zip_file_path = UserExporter.zip_xlsx_folder(tmp_out_dir) - shutil.rmtree(tmp_out_dir) - return zip_file_path - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--user_sub", type=str, required=True, - help="""Please provide usr_sub identifier for the export \ - process. This ID ensures that the exported data is \ - accurately associated with your user profile. 
If this \ - field is \"all\", then all user information will be \ - exported""") - parser.add_argument("--export_preferences", type=str, required=True, - help="""Please enter your export preferences by specifying \ - 'chat' and/or 'user_info', separated by a space \ - if including both. Ensure that your input is limited to \ - these options for accurate data export processing.""") - parser.add_argument("--start_day", type=str, required=False, - help="""User record export start date, format reference is \ - as follows: 2024_03_23""") - parser.add_argument("--end_day", type=str, required=False, - help="""User record export end date, format reference is \ - as follows: 2024_03_23""") - args = vars(parser.parse_args()) - arg_user_sub = args["user_sub"] - arg_export_preferences = args["export_preferences"].split(" ") - start_day = args["start_day"] - end_day = args["end_day"] - try: - if start_day is not None: - start_day = UserExporter.get_datetime_from_str(start_day, "%Y_%m_%d") - except Exception as e: - data = Audit( - user_sub=arg_user_sub, - http_method="internal_user_exporter", - module="export_user_data", - client_ip="internal", - message=f"start_day_exchange failed due error: {e}", - ) - AuditLogManager.add_audit_log(data) - try: - if end_day is not None: - end_day = UserExporter.get_datetime_from_str(end_day, "%Y_%m_%d") - except Exception as e: - data = Audit( - user_sub=arg_user_sub, - http_method="internal_user_exporter", - module="export_user_data", - client_ip="internal", - message=f"end_day_exchange failed due error: {e}", - ) - AuditLogManager.add_audit_log(data) - if arg_user_sub == "all": - user_sub_list = UserManager.get_all_user_sub() - else: - user_sub_list = [arg_user_sub] - users_dir = str(secrets.randbits(128)) - if Path(users_dir).exists(): - shutil.rmtree(users_dir) - Path(users_dir).mkdir(parents=True, exist_ok=True) - Path(users_dir).chmod(0o750) - for arg_user_sub in user_sub_list: - arg_user_sub = arg_user_sub[0] - try: - export_path = UserExporter.export_user_data( - users_dir, arg_user_sub, arg_export_preferences, start_day, end_day) - audit_export_preference = f", preference: {arg_export_preferences}" if arg_export_preferences else "" - data = Audit( - user_sub=arg_user_sub, - http_method="internal_user_exporter", - module="export_user_data", - client_ip="internal", - message=f"exported user data of id: {arg_user_sub}{audit_export_preference}, path: {export_path}", - ) - AuditLogManager.add_audit_log(data) - except Exception as e: - data = Audit( - user_sub=arg_user_sub, - http_method="internal_user_exporter", - module="export_user_data", - client_ip="internal", - message=f"用户(id: {arg_user_sub})请求导出数据失败: {e!s}", - ) - AuditLogManager.add_audit_log(data) - zip_file_path = UserExporter.zip_xlsx_folder(Path(users_dir)) - shutil.rmtree(users_dir) diff --git a/apps/services/api_key.py b/apps/services/api_key.py deleted file mode 100644 index 8f9711f81606040019ecd525a17a2ff13d89367a..0000000000000000000000000000000000000000 --- a/apps/services/api_key.py +++ /dev/null @@ -1,138 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
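# A minimal illustrative sketch, not part of the patch: the hash-then-compare idea used
# by the ApiKeyManager deleted below (sha256 of a uuid4 hex, truncated to 16 hex chars).
# PersonalTokenManager replaces it elsewhere in this changeset; its real storage layer
# is not shown here, so the in-memory dict below is purely an assumption for illustration.
import hashlib
import uuid

_token_store: dict[str, str] = {}  # hashed token -> user_sub; stand-in for the real table


def issue_token(user_sub: str) -> str:
    """Return a fresh plaintext token; only its truncated hash is retained."""
    token = uuid.uuid4().hex
    _token_store[hashlib.sha256(token.encode()).hexdigest()[:16]] = user_sub
    return token


def resolve_token(token: str) -> str | None:
    """Map a presented token back to its user, or None if unknown."""
    return _token_store.get(hashlib.sha256(token.encode()).hexdigest()[:16])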
-"""API Key管理""" - -import hashlib -import logging -import uuid - -from apps.common.mongo import MongoDB - -logger = logging.getLogger(__name__) - - -class ApiKeyManager: - """API Key管理""" - - @staticmethod - async def generate_api_key(user_sub: str) -> str | None: - """ - 生成新API Key - - :param user_sub: 用户名 - :return: API Key - """ - mongo = MongoDB() - api_key = str(uuid.uuid4().hex) - api_key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16] - - try: - user_collection = mongo.get_collection("user") - await user_collection.update_one( - {"_id": user_sub}, - {"$set": {"api_key": api_key_hash}}, - ) - except Exception: - logger.exception("[ApiKeyManager] 生成API Key失败") - return None - else: - return api_key - - @staticmethod - async def delete_api_key(user_sub: str) -> bool: - """ - 删除API Key - - :param user_sub: 用户ID - :return: 删除API Key是否成功 - """ - mongo = MongoDB() - if not await ApiKeyManager.api_key_exists(user_sub): - return False - try: - user_collection = mongo.get_collection("user") - await user_collection.update_one( - {"_id": user_sub}, - {"$unset": {"api_key": ""}}, - ) - except Exception: - logger.exception("[ApiKeyManager] 删除API Key失败") - return False - return True - - @staticmethod - async def api_key_exists(user_sub: str) -> bool: - """ - 检查API Key是否存在 - - :param user_sub: 用户ID - :return: API Key是否存在 - """ - mongo = MongoDB() - try: - user_collection = mongo.get_collection("user") - user_data = await user_collection.find_one({"_id": user_sub}, {"_id": 0, "api_key": 1}) - return user_data is not None and ("api_key" in user_data and user_data["api_key"]) - except Exception: - logger.exception("[ApiKeyManager] 检查API Key是否存在失败") - return False - - @staticmethod - async def get_user_by_api_key(api_key: str) -> str | None: - """ - 根据API Key获取用户信息 - - :param api_key: API Key - :return: 用户ID - """ - mongo = MongoDB() - api_key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16] - try: - user_collection = mongo.get_collection("user") - user_data = await user_collection.find_one({"api_key": api_key_hash}, {"_id": 1}) - return user_data["_id"] if user_data else None - except Exception: - logger.exception("[ApiKeyManager] 根据API Key获取用户信息失败") - return None - - @staticmethod - async def verify_api_key(api_key: str) -> bool: - """ - 验证API Key,用于FastAPI dependency - - :param api_key: API Key - :return: 验证API Key是否成功 - """ - mongo = MongoDB() - api_key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16] - try: - user_collection = mongo.get_collection("user") - key_data = await user_collection.find_one({"api_key": api_key_hash}, {"_id": 1}) - except Exception: - logger.exception("[ApiKeyManager] 验证API Key失败") - return False - else: - return key_data is not None - - @staticmethod - async def update_api_key(user_sub: str) -> str | None: - """ - 更新API Key - - :param user_sub: 用户ID - :return: 更新后的API Key - """ - mongo = MongoDB() - if not await ApiKeyManager.api_key_exists(user_sub): - return None - api_key = str(uuid.uuid4().hex) - api_key_hash = hashlib.sha256(api_key.encode()).hexdigest()[:16] - try: - user_collection = mongo.get_collection("user") - await user_collection.update_one( - {"_id": user_sub}, - {"$set": {"api_key": api_key_hash}}, - ) - except Exception: - logger.exception("[ApiKeyManager] 更新API Key失败") - return None - return api_key diff --git a/apps/services/appcenter.py b/apps/services/appcenter.py index 85a5505ff31b1a0ede6501204d12ba3989dd568e..32aa241a5d8dce4d56be7adc07ba58d0e9c2bf01 100644 --- a/apps/services/appcenter.py +++ b/apps/services/appcenter.py @@ 
-9,10 +9,10 @@ from typing import Any from apps.common.mongo import MongoDB from apps.constants import SERVICE_PAGE_SIZE from apps.exceptions import InstancePermissionError +from apps.models.user import User from apps.scheduler.pool.loader.app import AppLoader from apps.schemas.agent import AgentAppMetadata from apps.schemas.appcenter import AppCenterCardItem, AppData, AppPermissionData -from apps.schemas.collection import User from apps.schemas.enum_var import AppFilterType, AppType, PermissionType from apps.schemas.flow import AppMetadata, MetadataType, Permission from apps.schemas.pool import AppPool @@ -20,6 +20,7 @@ from apps.schemas.response_data import RecentAppList, RecentAppListItem from .flow import FlowManager from .mcp_service import MCPServiceManager +from .user import UserManager logger = logging.getLogger(__name__) @@ -218,11 +219,10 @@ class AppCenterManager: if not db_data: msg = "应用不存在" raise ValueError(msg) - db_user = await user_collection.find_one({"_id": user_sub}) - if not db_user: + user_data = await UserManager.get_user(user_sub) + if not user_data: msg = "用户不存在" raise ValueError(msg) - user_data = User.model_validate(db_user) already_favorited = app_id in user_data.fav_apps if favorited == already_favorited: diff --git a/apps/services/audit_log.py b/apps/services/audit_log.py deleted file mode 100644 index f7b4843a4fd8615bc8e6942dade3f48eda88c56d..0000000000000000000000000000000000000000 --- a/apps/services/audit_log.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. -"""审计日志Manager""" - -import logging - -from apps.common.mongo import MongoDB -from apps.schemas.collection import Audit - -logger = logging.getLogger(__name__) - -class AuditLogManager: - """审计日志相关操作""" - - @staticmethod - async def add_audit_log(data: Audit) -> None: - """ - EulerCopilot审计日志 - - :param data: 审计日志数据 - """ - mongo = MongoDB() - collection = mongo.get_collection("audit") - await collection.insert_one(data.model_dump(by_alias=True)) diff --git a/apps/services/blacklist.py b/apps/services/blacklist.py index 7dad46f28c2a1e22de0e7da211828fd3d6275079..06f3450ebc7206e5ccaecd7030fa7e50939aca5f 100644 --- a/apps/services/blacklist.py +++ b/apps/services/blacklist.py @@ -6,10 +6,8 @@ import re from apps.common.mongo import MongoDB from apps.common.security import Security -from apps.schemas.collection import ( - Blacklist, - User, -) +from apps.models.user import User +from apps.schemas.collection import Blacklist from apps.schemas.record import Record, RecordContent logger = logging.getLogger(__name__) diff --git a/apps/services/comment.py b/apps/services/comment.py index 60c13905ed8d991eec341b8368912a7e7298aa04..ce459bc923b22005b2bdb453faa372a2966f6546 100644 --- a/apps/services/comment.py +++ b/apps/services/comment.py @@ -2,8 +2,12 @@ """评论 Manager""" import logging +import uuid -from apps.common.mongo import MongoDB +from sqlalchemy import select + +from apps.common.postgres import postgres +from apps.models.comment import Comment from apps.schemas.record import RecordComment logger = logging.getLogger(__name__) @@ -13,38 +17,56 @@ class CommentManager: """评论相关操作""" @staticmethod - async def query_comment(group_id: str, record_id: str) -> RecordComment | None: + async def query_comment(record_id: str) -> RecordComment | None: """ 根据问答ID查询评论 :param record_id: 问答ID :return: 评论内容 """ - record_group_collection = MongoDB().get_collection("record_group") - result = await record_group_collection.aggregate( - [ - {"$match": {"_id": 
group_id, "records.id": record_id}}, - {"$unwind": "$records"}, - {"$match": {"records.id": record_id}}, - {"$limit": 1}, - ], - ) - result = await result.to_list(length=1) - if result: - return RecordComment.model_validate(result[0]["records"]["comment"]) - return None + async with postgres.session() as session: + result = ( + await session.scalars( + select(Comment).where(Comment.record_id == uuid.UUID(record_id)), + ) + ).one_or_none() + if result: + return RecordComment( + comment=result.comment_type, + dislike_reason=result.feedback_type, + reason_link=result.feedback_link, + reason_description=result.feedback_content, + feedback_time=round(result.created_at.timestamp(), 3), + ) + return None @staticmethod - async def update_comment(group_id: str, record_id: str, data: RecordComment) -> None: + async def update_comment(record_id: str, data: RecordComment, user_sub: str) -> None: """ 更新评论 :param record_id: 问答ID :param data: 评论内容 """ - mongo = MongoDB() - record_group_collection = mongo.get_collection("record_group") - await record_group_collection.update_one( - {"_id": group_id, "records.id": record_id}, - {"$set": {"records.$.comment": data.model_dump(by_alias=True)}}, - ) + async with postgres.session() as session: + result = ( + await session.scalars( + select(Comment).where(Comment.record_id == uuid.UUID(record_id)), + ) + ).one_or_none() + if result: + result.comment_type = data.comment + result.feedback_type = data.feedback_type + result.feedback_link = data.feedback_link + result.feedback_content = data.feedback_content + else: + comment_info = Comment( + record_id=uuid.UUID(record_id), + user_sub=user_sub, + comment_type=data.comment, + feedback_type=data.feedback_type, + feedback_link=data.feedback_link, + feedback_content=data.feedback_content, + ) + session.add(comment_info) + await session.commit() diff --git a/apps/services/conversation.py b/apps/services/conversation.py index 9e4f97454213feca8be1f1bd125a79a9298c762e..1cc6f5ff84f672b8b122dc09e1049b6620ae0969 100644 --- a/apps/services/conversation.py +++ b/apps/services/conversation.py @@ -3,16 +3,16 @@ import logging import uuid -from datetime import UTC, datetime +from datetime import datetime from typing import Any -from apps.common.config import config -from apps.common.mongo import MongoDB -from apps.schemas.collection import Conversation, KnowledgeBaseItem, LLMItem -from apps.templates.generate_llm_operator_config import llm_provider_dict +import pytz +from sqlalchemy import func, select + +from apps.common.postgres import postgres +from apps.models.conversation import Conversation +from apps.models.user import UserAppUsage -from .knowledge import KnowledgeBaseManager -from .llm import LLMManager from .task import TaskManager logger = logging.getLogger(__name__) @@ -24,117 +24,119 @@ class ConversationManager: @staticmethod async def get_conversation_by_user_sub(user_sub: str) -> list[Conversation]: """根据用户ID获取对话列表,按时间由近到远排序""" - conv_collection = MongoDB().get_collection("conversation") - return [ - Conversation(**conv) - async for conv in conv_collection.find({"user_sub": user_sub, "debug": False}).sort({"created_at": 1}) - ] + async with postgres.session() as session: + result = (await session.scalars( + select(Conversation).where( + Conversation.user_sub == user_sub, + Conversation.is_temporary == False, # noqa: E712 + ).order_by( + Conversation.created_at.desc(), + ), + )).all() + return list(result) + @staticmethod - async def get_conversation_by_conversation_id(user_sub: str, conversation_id: str) -> 
Conversation | None: + async def get_conversation_by_conversation_id(user_sub: str, conversation_id: uuid.UUID) -> Conversation | None: """通过ConversationID查询对话信息""" - conv_collection = MongoDB().get_collection("conversation") - result = await conv_collection.find_one({"_id": conversation_id, "user_sub": user_sub}) - if not result: - return None - return Conversation.model_validate(result) + async with postgres.session() as session: + return (await session.scalars( + select(Conversation).where( + Conversation.id == conversation_id, + Conversation.user_sub == user_sub, + ), + )).one_or_none() + @staticmethod - async def add_conversation_by_user_sub( - user_sub: str, app_id: str, llm_id: str, kb_ids: list[str], - *, debug: bool) -> Conversation | None: + async def add_conversation_by_user_sub(user_sub: str, app_id: uuid.UUID, *, debug: bool) -> Conversation | None: """通过用户ID新建对话""" - if llm_id == "empty": - llm_item = LLMItem( - llm_id="empty", - model_name=config.llm.model, - icon=llm_provider_dict["ollama"]["icon"], - ) - else: - llm = await LLMManager.get_llm_by_id(user_sub, llm_id) - if llm is None: - logger.error("[ConversationManager] 获取大模型失败") - return None - llm_item = LLMItem( - llm_id=llm.id, - model_name=llm.model_name, - ) - kb_item_list = [] - team_kb_list = await KnowledgeBaseManager.get_team_kb_list_from_rag(user_sub, None, None) - for team_kb in team_kb_list: - for kb in team_kb["kbList"]: - if str(kb["kbId"]) in kb_ids: - kb_item = KnowledgeBaseItem( - kb_id=kb["kbId"], - kb_name=kb["kbName"], - ) - kb_item_list.append(kb_item) - conversation_id = str(uuid.uuid4()) conv = Conversation( - _id=conversation_id, user_sub=user_sub, app_id=app_id, - llm=llm_item, - kb_list=kb_item_list, - debug=debug if debug else False, + is_temporary=debug, ) - mongo = MongoDB() + # 使用PostgreSQL实现新建对话,并根据debug和usage进行更新 try: - async with mongo.get_session() as session, await session.start_transaction(): - conv_collection = mongo.get_collection("conversation") - await conv_collection.insert_one(conv.model_dump(by_alias=True), session=session) - user_collection = mongo.get_collection("user") - update_data: dict[str, dict[str, Any]] = { - "$push": {"conversations": conversation_id}, - } - if app_id: - # 非调试模式下更新应用使用情况 - if not debug: - update_data["$set"] = { - f"app_usage.{app_id}.last_used": round(datetime.now(UTC).timestamp(), 3), - } - update_data["$inc"] = {f"app_usage.{app_id}.count": 1} - await user_collection.update_one( - {"_id": user_sub}, - update_data, - session=session, - ) - await session.commit_transaction() + async with postgres.session() as session: + session.add(conv) + await session.commit() + await session.refresh(conv) + + # 如果是非调试模式且app_id存在,更新App的使用情况 + if app_id and not debug: + app_obj = (await session.scalars( + select(UserAppUsage).where(UserAppUsage.user_sub == user_sub, UserAppUsage.app_id == app_id), + )).one_or_none() + if app_obj: + # 假设App模型有last_used和usage_count字段(如没有请根据实际表结构调整) + # 这里只做示例,实际字段名和类型请根据实际情况修改 + app_obj.usage_count += 1 + app_obj.last_used = datetime.now(pytz.timezone("Asia/Shanghai")) + session.add(app_obj) + await session.commit() + else: + session.add(UserAppUsage( + user_sub=user_sub, + app_id=app_id, + usage_count=1, + last_used=datetime.now(pytz.timezone("Asia/Shanghai")), + )) + await session.commit() return conv except Exception: logger.exception("[ConversationManager] 新建对话失败") return None + @staticmethod - async def update_conversation_by_conversation_id(user_sub: str, conversation_id: str, data: dict[str, Any]) -> bool: + async def 
update_conversation_by_conversation_id( + user_sub: str, conversation_id: uuid.UUID, data: dict[str, Any], + ) -> bool: """通过ConversationID更新对话信息""" - mongo = MongoDB() - conv_collection = mongo.get_collection("conversation") - result = await conv_collection.update_one( - {"_id": conversation_id, "user_sub": user_sub}, - {"$set": data}, - ) - return result.modified_count > 0 + async with postgres.session() as session: + conv = (await session.scalars( + select(Conversation).where( + Conversation.id == conversation_id, + Conversation.user_sub == user_sub, + ), + )).one_or_none() + if not conv: + return False + for key, value in data.items(): + setattr(conv, key, value) + session.add(conv) + await session.commit() + return True + @staticmethod - async def delete_conversation_by_conversation_id(user_sub: str, conversation_id: str) -> None: + async def delete_conversation_by_conversation_id(user_sub: str, conversation_id: uuid.UUID) -> None: """通过ConversationID删除对话""" - mongo = MongoDB() - user_collection = mongo.get_collection("user") - conv_collection = mongo.get_collection("conversation") - record_group_collection = mongo.get_collection("record_group") - - async with mongo.get_session() as session, await session.start_transaction(): - conversation_data = await conv_collection.find_one_and_delete( - {"_id": conversation_id, "user_sub": user_sub}, session=session, - ) - if not conversation_data: + async with postgres.session() as session: + conv = (await session.scalars( + select(Conversation).where( + Conversation.id == conversation_id, + Conversation.user_sub == user_sub, + ), + )).one_or_none() + if not conv: return - await user_collection.update_one( - {"_id": user_sub}, {"$pull": {"conversations": conversation_id}}, session=session, - ) - await record_group_collection.delete_many({"conversation_id": conversation_id}, session=session) - await session.commit_transaction() + await session.delete(conv) + await session.commit() await TaskManager.delete_tasks_by_conversation_id(conversation_id) + + + @staticmethod + async def verify_conversation_id(user_sub: str, conversation_id: uuid.UUID) -> bool: + """验证对话ID是否属于用户""" + async with postgres.session() as session: + result = (await session.scalars( + func.count(Conversation.id).where( + Conversation.id == conversation_id, + Conversation.user_sub == user_sub, + ), + )).one() + return bool(result) diff --git a/apps/services/document.py b/apps/services/document.py index f4b4733c70939cc8ea46d05d262a90491d37a7df..04299c498c8598a2618faa177359159c92bd8ca5 100644 --- a/apps/services/document.py +++ b/apps/services/document.py @@ -7,14 +7,14 @@ import uuid import asyncer from fastapi import UploadFile +from sqlalchemy import func, select from apps.common.minio import MinioClient -from apps.common.mongo import MongoDB -from apps.schemas.collection import ( - Conversation, - Document, -) -from apps.schemas.record import RecordDocument, RecordGroup, RecordGroupDocument +from apps.common.postgres import postgres +from apps.models.conversation import Conversation, ConversationDocument +from apps.models.document import Document +from apps.models.record import Record, RecordDocument +from apps.schemas.record import RecordGroupDocument from .knowledge_base import KnowledgeBaseService from .session import SessionManager @@ -25,8 +25,8 @@ logger = logging.getLogger(__name__) class DocumentManager: """文件相关操作""" - @classmethod - def _storage_single_doc_minio(cls, file_id: str, document: UploadFile) -> str: + @staticmethod + def _storage_single_doc_minio(file_id: 
uuid.UUID, document: UploadFile) -> str: """存储单个文件到MinIO""" MinioClient.check_bucket("document") file = document.file @@ -38,134 +38,130 @@ class DocumentManager: # 上传到MinIO MinioClient.upload_file( bucket_name="document", - object_name=file_id, + object_name=str(file_id), data=file, content_type=mime, length=-1, part_size=10 * 1024 * 1024, metadata={ - "file_name": base64.b64encode(document.filename.encode("utf-8")).decode("ascii") , # type: ignore[arg-type] + "file_name": base64.b64encode( + document.filename.encode("utf-8") + ).decode("ascii") if document.filename else "", }, ) return mime - @classmethod - async def storage_docs(cls, user_sub: str, conversation_id: str, documents: list[UploadFile]) -> list[Document]: + @staticmethod + async def storage_docs( + user_sub: str, conversation_id: uuid.UUID, documents: list[UploadFile], + ) -> list[Document]: """存储多个文件""" uploaded_files = [] - mongo = MongoDB() - doc_collection = mongo.get_collection("document") - conversation_collection = mongo.get_collection("conversation") for document in documents: if document.filename is None or document.size is None: + logger.error("[DocumentManager] 文件名或大小为空: %s, %s", document.filename, document.size) continue - file_id = str(uuid.uuid4()) + file_id = uuid.uuid4() try: - mime = await asyncer.asyncify(cls._storage_single_doc_minio)(file_id, document) + mime = await asyncer.asyncify(DocumentManager._storage_single_doc_minio)(file_id, document) except Exception: logger.exception("[DocumentManager] 上传文件失败") continue - # 保存到MongoDB + # 保存到数据库 doc_info = Document( - _id=file_id, user_sub=user_sub, name=document.filename, type=mime, size=document.size / 1024.0, conversation_id=conversation_id, ) - await doc_collection.insert_one(doc_info.model_dump(by_alias=True)) - await conversation_collection.update_one( - {"_id": conversation_id}, - { - "$push": {"unused_docs": file_id}, - }, - ) - - # 准备返回值 uploaded_files.append(doc_info) + async with postgres.session() as session: + session.add_all(uploaded_files) + await session.commit() return uploaded_files - @classmethod - async def get_unused_docs(cls, user_sub: str, conversation_id: str) -> list[Document]: - """获取Conversation中未使用的文件""" - mongo = MongoDB() - conv_collection = mongo.get_collection("conversation") - doc_collection = mongo.get_collection("document") - - conv = await conv_collection.find_one({"_id": conversation_id, "user_sub": user_sub}) - if not conv: - logger.error("[DocumentManager] 对话不存在: %s", conversation_id) - return [] - - docs_ids = conv.get("unused_docs", []) - return [Document(**doc) async for doc in doc_collection.find({"_id": {"$in": docs_ids}})] - @classmethod - async def get_used_docs_by_record_group(cls, user_sub: str, record_group_id: str) -> list[RecordDocument]: + @staticmethod + async def get_unused_docs(conversation_id: uuid.UUID) -> list[Document]: + """获取Conversation中未使用的文件""" + async with postgres.session() as session: + conv = (await session.scalars( + select(ConversationDocument).where( + ConversationDocument.conversation_id == conversation_id, + ConversationDocument.is_unused.is_(True), + ), + )).all() + if not conv: + logger.error("[DocumentManager] 对话不存在: %s", conversation_id) + return [] + + docs_ids = [doc.document_id for doc in conv] + docs = (await session.scalars(select(Document).where(Document.id.in_(docs_ids)))).all() + return list(docs) + + + @staticmethod + async def get_used_docs_by_record(record_id: str) -> list[Document]: """获取RecordGroup关联的文件""" - mongo = MongoDB() - record_group_collection = 
mongo.get_collection("record_group") - docs_collection = mongo.get_collection("document") - - record_group = await record_group_collection.find_one({"_id": record_group_id, "user_sub": user_sub}) - if not record_group: - logger.info("[DocumentManager] 记录组不存在: %s", record_group_id) - return [] - - docs = RecordGroup.model_validate(record_group).docs - doc_ids = [doc.id for doc in docs] - doc_infos = [Document.model_validate(doc) async for doc in docs_collection.find({"_id": {"$in": doc_ids}})] - return [ - RecordDocument( - _id=item[0].id, - name=item[1].name, - type=item[1].type, - size=item[1].size, - conversation_id=item[1].conversation_id, - associated=item[0].associated, - ) - for item in zip(docs, doc_infos, strict=True) - ] - - @classmethod - async def get_used_docs(cls, user_sub: str, conversation_id: str, record_num: int | None = 10) -> list[Document]: + async with postgres.session() as session: + record_docs = (await session.scalars( + select(RecordDocument).where(RecordDocument.record_id == record_id), + )).all() + if not list(record_docs): + logger.info("[DocumentManager] 记录组不存在: %s", record_id) + return [] + + doc_infos: list[Document] = [] + for doc in record_docs: + doc_info = (await session.scalars(select(Document).where(Document.id == doc.document_id))).one_or_none() + if doc_info: + doc_infos.append(doc_info) + return doc_infos + + + @staticmethod + async def get_used_docs(conversation_id: str, record_num: int | None = 10) -> list[Document]: """获取最后n次问答所用到的文件""" - mongo = MongoDB() - docs_collection = mongo.get_collection("document") - record_group_collection = mongo.get_collection("record_group") - - if record_num: - record_groups = ( - record_group_collection.find({"conversation_id": conversation_id, "user_sub": user_sub}) - .sort("created_at", -1) - .limit(record_num) - ) - else: - record_groups = record_group_collection.find( - {"conversation_id": conversation_id, "user_sub": user_sub}, - ).sort("created_at", -1) - - docs = [] - async for current_record_group in record_groups: - for doc in RecordGroup.model_validate(current_record_group).docs: - docs += [doc.id] - # 文件ID去重 - docs = list(set(docs)) - # 返回文件详细信息 - return [Document.model_validate(doc) async for doc in docs_collection.find({"_id": {"$in": docs}})] - - @classmethod - def _remove_doc_from_minio(cls, doc_id: str) -> None: + async with postgres.session() as session: + records = (await session.scalars( + select(Record).where( + Record.conversation_id == conversation_id, + ).order_by(Record.created_at.desc()).limit(record_num), + )).all() + + docs = [] + for current_record in records: + record_docs = ( + await session.scalars( + select(RecordDocument).where(RecordDocument.record_id == current_record.id), + ) + ).all() + if list(record_docs): + docs += [doc.document_id for doc in record_docs] + + # 去重 + docs = list(set(docs)) + result = [] + for doc_id in docs: + doc = (await session.scalars(select(Document).where(Document.id == doc_id))).one_or_none() + if doc: + result.append(doc) + return result + + + @staticmethod + def _remove_doc_from_minio(doc_id: str) -> None: """从MinIO中删除文件""" MinioClient.delete_file("document", doc_id) - @classmethod - async def delete_document(cls, user_sub: str, document_list: list[str]) -> bool: + + @staticmethod + async def delete_document(user_sub: str, document_list: list[str]) -> bool: """从未使用文件列表中删除一个文件""" mongo = MongoDB() doc_collection = mongo.get_collection("document") @@ -200,8 +196,9 @@ class DocumentManager: logger.exception("[DocumentManager] 删除文件失败") return False - 
@classmethod - async def delete_document_by_conversation_id(cls, user_sub: str, conversation_id: str) -> list[str]: + + @staticmethod + async def delete_document_by_conversation_id(conversation_id: uuid.UUID) -> list[str]: """通过ConversationID删除文件""" mongo = MongoDB() doc_collection = mongo.get_collection("document") @@ -223,15 +220,20 @@ class DocumentManager: await KnowledgeBaseService.delete_doc_from_rag(session_id, doc_ids) return doc_ids - @classmethod - async def get_doc_count(cls, user_sub: str, conversation_id: str) -> int: + + @staticmethod + async def get_doc_count(conversation_id: str) -> int: """获取对话文件数量""" - mongo = MongoDB() - doc_collection = mongo.get_collection("document") - return await doc_collection.count_documents({"user_sub": user_sub, "conversation_id": conversation_id}) + async with postgres.session() as session: + return (await session.scalars( + select(func.count(ConversationDocument.id)).where( + ConversationDocument.conversation_id == conversation_id, + ), + )).one() + - @classmethod - async def change_doc_status(cls, user_sub: str, conversation_id: str, record_group_id: str) -> None: + @staticmethod + async def change_doc_status(user_sub: str, conversation_id: str, record_group_id: str) -> None: """文件状态由unused改为used""" mongo = MongoDB() record_group_collection = mongo.get_collection("record_group") @@ -255,8 +257,8 @@ class DocumentManager: # 把unused_docs从Conversation中删除 await conversation_collection.update_one({"_id": conversation_id}, {"$set": {"unused_docs": []}}) - @classmethod - async def save_answer_doc(cls, user_sub: str, record_group_id: str, doc_ids: list[str]) -> None: + @staticmethod + async def save_answer_doc(user_sub: str, record_group_id: str, doc_ids: list[str]) -> None: """保存与答案关联的文件""" mongo = MongoDB() record_group_collection = mongo.get_collection("record_group") diff --git a/apps/services/domain.py b/apps/services/domain.py deleted file mode 100644 index 7b512a8343cb06e298b18df3b8db750be41d7cf0..0000000000000000000000000000000000000000 --- a/apps/services/domain.py +++ /dev/null @@ -1,90 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
-"""画像领域管理""" - -import logging -from datetime import UTC, datetime - -from apps.common.mongo import MongoDB -from apps.schemas.collection import Domain -from apps.schemas.request_data import PostDomainData - -logger = logging.getLogger(__name__) - - -class DomainManager: - """用户画像相关操作""" - - @staticmethod - async def get_domain() -> list[Domain]: - """ - 获取所有领域信息 - - :return: 领域信息列表 - """ - mongo = MongoDB() - domain_collection = mongo.get_collection("domain") - return [Domain(**domain) async for domain in domain_collection.find()] - - - @staticmethod - async def get_domain_by_domain_name(domain_name: str) -> Domain | None: - """ - 根据领域名称获取领域信息 - - :param domain_name: 领域名称 - :return: 领域信息 - """ - mongo = MongoDB() - domain_collection = mongo.get_collection("domain") - domain_data = await domain_collection.find_one({"domain_name": domain_name}) - if domain_data: - return Domain(**domain_data) - return None - - - @staticmethod - async def add_domain(domain_data: PostDomainData) -> None: - """ - 添加领域 - - :param domain_data: 领域信息 - """ - mongo = MongoDB() - domain = Domain( - name=domain_data.domain_name, - definition=domain_data.domain_description, - ) - domain_collection = mongo.get_collection("domain") - await domain_collection.insert_one(domain.model_dump(by_alias=True)) - - - @staticmethod - async def update_domain_by_domain_name(domain_data: PostDomainData) -> Domain: - """ - 更新领域 - - :param domain_data: 领域信息 - :return: 更新后的领域信息 - """ - mongo = MongoDB() - update_dict = { - "definition": domain_data.domain_description, - "updated_at": round(datetime.now(tz=UTC).timestamp(), 3), - } - domain_collection = mongo.get_collection("domain") - await domain_collection.update_one( - {"name": domain_data.domain_name}, - {"$set": update_dict}, - ) - return Domain(name=domain_data.domain_name, **update_dict) - - @staticmethod - async def delete_domain_by_domain_name(domain_data: PostDomainData) -> None: - """ - 删除领域 - - :param domain_data: 领域信息 - """ - mongo = MongoDB() - domain_collection = mongo.get_collection("domain") - await domain_collection.delete_one({"name": domain_data.domain_name}) diff --git a/apps/services/flow_validate.py b/apps/services/flow_validate.py index a4185c82938bb99bbf37ec4659f98fad0ca4f86d..855494d2677c386c4c31b7c0ecb9e3e5d56b3956 100644 --- a/apps/services/flow_validate.py +++ b/apps/services/flow_validate.py @@ -14,6 +14,9 @@ if TYPE_CHECKING: from pydantic import BaseModel logger = logging.getLogger(__name__) +BRANCH_ILLEGAL_CHARS = [ + ".", +] class FlowService: @@ -21,7 +24,7 @@ class FlowService: @staticmethod def _validate_branch_id( - node_name: str, branch_id: str, node_branches: set, branch_illegal_chars: str = ".", + node_name: str, branch_id: str, node_branches: set, ) -> None: """验证分支ID的合法性;当分支ID重复或包含非法字符时抛出异常""" if branch_id in node_branches: @@ -29,7 +32,7 @@ class FlowService: logger.error(err) raise FlowBranchValidationError(err) - for illegal_char in branch_illegal_chars: + for illegal_char in BRANCH_ILLEGAL_CHARS: if illegal_char in branch_id: err = f"[FlowService] 节点{node_name}的分支{branch_id}名称中含有非法字符" logger.error(err) @@ -39,7 +42,6 @@ class FlowService: async def remove_excess_structure_from_flow(flow_item: FlowItem) -> FlowItem: """移除流程图中的多余结构""" node_branch_map = {} - branch_illegal_chars = "." 
for node in flow_item.nodes: if node.node_id not in {"start", "end", "Empty"}: try: @@ -61,7 +63,7 @@ class FlowService: err = f"[FlowService] 节点{node.name}的分支{choice['branchId']}重复" logger.error(err) raise Exception(err) - for illegal_char in branch_illegal_chars: + for illegal_char in BRANCH_ILLEGAL_CHARS: if illegal_char in choice["branchId"]: err = f"[FlowService] 节点{node.name}的分支{choice['branchId']}名称中含有非法字符" logger.error(err) diff --git a/apps/services/knowledge.py b/apps/services/knowledge.py index 876321dc70b460c708e641e081dc1db6d9690063..d5db3572440c78a5041ad79f2414df46053e04cf 100644 --- a/apps/services/knowledge.py +++ b/apps/services/knowledge.py @@ -9,7 +9,6 @@ from fastapi import status from apps.common.config import config from apps.common.mongo import MongoDB -from apps.schemas.collection import KnowledgeBaseItem from apps.schemas.response_data import KnowledgeBaseItem as KnowledgeBaseItemResponse from apps.schemas.response_data import TeamKnowledgeBaseItem @@ -21,26 +20,6 @@ logger = logging.getLogger(__name__) class KnowledgeBaseManager: """用户资产库管理""" - @staticmethod - async def get_kb_ids_by_conversation_id(user_sub: str, conversation_id: str) -> list[str]: - """ - 通过对话ID获取知识库ID - - :param user_sub: 用户ID - :param conversation_id: 对话ID - :return: 知识库ID列表 - """ - mongo = MongoDB() - conv_collection = mongo.get_collection("conversation") - result = await conv_collection.find_one({"_id": conversation_id, "user_sub": user_sub}) - if not result: - err_msg = "[KnowledgeBaseManager] 获取知识库ID失败,未找到对话" - logger.error(err_msg) - return [] - kb_config_list = result.get("kb_list", []) - return [kb_config["kb_id"] for kb_config in kb_config_list] - - @staticmethod async def get_team_kb_list_from_rag( user_sub: str, @@ -120,7 +99,7 @@ class KnowledgeBaseManager: kb_ids: list[str], ) -> list[str]: """ - 更新对话的知识库列表 + 更新用户当前选择的知识库 :param user_sub: 用户sub :param conversation_id: 对话ID diff --git a/apps/services/knowledge_base.py b/apps/services/knowledge_base.py index bd99c67fe51099d64ffe1bc65ef3add232cf1e8b..11aab144c26628bd6e42102720023f9b6159e813 100644 --- a/apps/services/knowledge_base.py +++ b/apps/services/knowledge_base.py @@ -7,7 +7,7 @@ import httpx from fastapi import status from apps.common.config import config -from apps.schemas.collection import Document +from apps.models.document import Document from apps.schemas.rag_data import ( RAGFileParseReq, RAGFileParseReqItem, @@ -32,7 +32,7 @@ class KnowledgeBaseService: "Authorization": f"Bearer {session_id}", } rag_docs = [RAGFileParseReqItem( - id=doc.id, + id=str(doc.id), name=doc.name, bucket_name="document", type=doc.type, diff --git a/apps/services/node.py b/apps/services/node.py index 6f0d492edacdd7611324a38953417e0fa769249b..78fd5b8ba91b40d9a8291421a7dc945c85bec007 100644 --- a/apps/services/node.py +++ b/apps/services/node.py @@ -4,53 +4,31 @@ import logging from typing import TYPE_CHECKING, Any -from apps.common.mongo import MongoDB -from apps.schemas.node import APINode -from apps.schemas.pool import NodePool +from sqlalchemy import select + +from apps.common.postgres import postgres +from apps.models.node import Node if TYPE_CHECKING: from pydantic import BaseModel logger = logging.getLogger(__name__) -NODE_TYPE_MAP = { - "API": APINode, -} + class NodeManager: """Node管理器""" @staticmethod - async def get_node_call_id(node_id: str) -> str: - """获取Node的call_id""" - node_collection = MongoDB().get_collection("node") - node = await node_collection.find_one({"_id": node_id}, {"call_id": 1}) - if not node: - err = 
f"[NodeManager] Node call_id {node_id} not found." - raise ValueError(err) - return node["call_id"] - - - @staticmethod - async def get_node(node_id: str) -> NodePool: - """获取Node的类型""" - node_collection = MongoDB().get_collection("node") - node = await node_collection.find_one({"_id": node_id}) - if not node: - err = f"[NodeManager] Node {node_id} not found." - raise ValueError(err) - return NodePool.model_validate(node) - - - @staticmethod - async def get_node_name(node_id: str) -> str: - """获取node的名称""" - node_collection = MongoDB().get_collection("node") - # 查询 Node 集合获取对应的 name - node_doc = await node_collection.find_one({"_id": node_id}, {"name": 1}) - if not node_doc: - logger.error("[NodeManager] Node %s not found", node_id) - return "" - return node_doc["name"] + async def get_node(node_id: str) -> Node: + """获取Node信息""" + async with postgres.session() as session: + node = (await session.scalars( + select(Node).where(Node.id == node_id), + )).one_or_none() + if not node: + err = f"[NodeManager] Node {node_id} not found." + raise ValueError(err) + return node @staticmethod @@ -83,19 +61,12 @@ class NodeManager: # 查找Node信息 logger.info("[NodeManager] 获取节点 %s", node_id) - node_collection = MongoDB().get_collection("node") - node = await node_collection.find_one({"_id": node_id}) - if not node: - err = f"[NodeManager] Node {node_id} not found." - logger.error(err) - raise ValueError(err) - - node_data = NodePool.model_validate(node) + node_data = await NodeManager.get_node(node_id) call_id = node_data.call_id # 查找Call信息 logger.info("[NodeManager] 获取Call %s", call_id) - call_class: type[BaseModel] = await Pool().get_call(call_id) + call_class: type[BaseModel] = await Pool().get_call(str(call_id)) if not call_class: err = f"[NodeManager] Call {call_id} 不存在" logger.error(err) diff --git a/apps/services/personal_token.py b/apps/services/personal_token.py new file mode 100644 index 0000000000000000000000000000000000000000..a9e747633fdbe31af5311afa8c1927e0b0a632e1 --- /dev/null +++ b/apps/services/personal_token.py @@ -0,0 +1,80 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
+"""API Key管理""" + +import hashlib +import logging +import uuid + +from sqlalchemy import func, select, update + +from apps.common.postgres import postgres +from apps.models.user import User + +logger = logging.getLogger(__name__) + + +class PersonalTokenManager: + """Personal Token管理""" + + @staticmethod + async def get_user_by_personal_token(personal_token: str) -> str | None: + """ + 根据Personal Token获取用户信息 + + :param personal_token: Personal Token + :return: 用户ID + """ + async with postgres.session() as session: + try: + result = ( + await session.scalars( + select(User.user_sub).where(User.personal_token == personal_token), + ) + ).one_or_none() + except Exception: + logger.exception("[PersonalTokenManager] 根据Personal Token获取用户信息失败") + return None + else: + return result + + + @staticmethod + async def verify_personal_token(personal_token: str) -> bool: + """ + 验证Personal Token,用于FastAPI dependency + + :param personal_token: Personal Token + :return: 验证Personal Token是否成功 + """ + async with postgres.session() as session: + try: + result = ( + await session.scalars( + select(func.count(User.id)).where(User.personal_token == personal_token), + ) + ).one() + except Exception: + logger.exception("[PersonalTokenManager] 验证Personal Token失败") + return False + else: + return result > 0 + + + @staticmethod + async def update_personal_token(user_sub: str) -> str | None: + """ + 更新Personal Token + + :param user_sub: 用户ID + :return: 更新后的Personal Token + """ + personal_token = hashlib.sha256(str(uuid.uuid4().hex).encode()).hexdigest()[:16] + try: + async with postgres.session() as session: + await session.execute( + update(User).where(User.id == user_sub).values(personal_token=personal_token), + ) + except Exception: + logger.exception("[PersonalTokenManager] 更新Personal Token失败") + return None + return personal_token diff --git a/apps/services/service.py b/apps/services/service.py index deea90f0fd91728e407d519450c63f7af888b8e9..5f8a5a1ab1993baee29a36be768f3a6ac49e7aa0 100644 --- a/apps/services/service.py +++ b/apps/services/service.py @@ -7,14 +7,17 @@ from typing import Any import yaml from anyio import Path +from sqlalchemy import select from apps.common.config import config from apps.common.mongo import MongoDB +from apps.common.postgres import postgres from apps.exceptions import InstancePermissionError, ServiceIDError +from apps.models.node import Node +from apps.models.service import Service from apps.scheduler.openapi import ReducedOpenAPISpec from apps.scheduler.pool.loader.openapi import OpenAPILoader from apps.scheduler.pool.loader.service import ServiceLoader -from apps.schemas.collection import User from apps.schemas.enum_var import SearchType from apps.schemas.flow import ( Permission, @@ -22,7 +25,6 @@ from apps.schemas.flow import ( ServiceApiConfig, ServiceMetadata, ) -from apps.schemas.pool import NodePool, ServicePool from apps.schemas.response_data import ServiceApiData, ServiceCardItem logger = logging.getLogger(__name__) @@ -40,7 +42,6 @@ class ServiceCenterManager: page_size: int, ) -> tuple[list[ServiceCardItem], int]: """获取所有服务列表""" - filters = ServiceCenterManager._build_filters({}, search_type, keyword) if keyword else {} service_pools, total_count = await ServiceCenterManager._search_service(filters, page, page_size) fav_service_ids = await ServiceCenterManager._get_favorite_service_ids_by_user(user_sub) services = [ @@ -112,6 +113,7 @@ class ServiceCenterManager: ] return services, total_count + @staticmethod async def create_service( user_sub: str, @@ -121,17 +123,20 @@ 
class ServiceCenterManager: service_id = str(uuid.uuid4()) # 校验 OpenAPI 规范的 JSON Schema validated_data = await ServiceCenterManager._validate_service_data(data) + # 检查是否存在相同服务 - service_collection = MongoDB().get_collection("service") - db_service = await service_collection.find_one( - { - "name": validated_data.id, - "description": validated_data.description, - }, - ) - if db_service: - msg = "[ServiceCenterManager] 已存在相同名称和描述的服务" - raise ServiceIDError(msg) + async with postgres.session() as session: + db_service = (await session.scalars( + select(Service).where( + Service.name == validated_data.id, + Service.description == validated_data.description, + ), + )).one_or_none() + + if db_service: + msg = "[ServiceCenterManager] 已存在相同名称和描述的服务" + raise ServiceIDError(msg) + # 存入数据库 service_metadata = ServiceMetadata( id=service_id, @@ -147,6 +152,7 @@ class ServiceCenterManager: # 返回服务ID return service_id + @staticmethod async def update_service( user_sub: str, @@ -155,15 +161,18 @@ class ServiceCenterManager: ) -> str: """更新服务""" # 验证用户权限 - service_collection = MongoDB().get_collection("service") - db_service = await service_collection.find_one({"_id": service_id}) - if not db_service: - msg = "Service not found" - raise ServiceIDError(msg) - service_pool_store = ServicePool.model_validate(db_service) - if service_pool_store.author != user_sub: - msg = "Permission denied" - raise InstancePermissionError(msg) + async with postgres.session() as session: + db_service = (await session.scalars( + select(Service).where( + Service.id == service_id, + ), + )).one_or_none() + if not db_service: + msg = "Service not found" + raise ServiceIDError(msg) + if db_service.author != user_sub: + msg = "Permission denied" + raise InstancePermissionError(msg) # 校验 OpenAPI 规范的 JSON Schema validated_data = await ServiceCenterManager._validate_service_data(data) # 存入数据库 @@ -186,12 +195,15 @@ class ServiceCenterManager: ) -> tuple[str, list[ServiceApiData]]: """获取服务API列表""" # 获取服务名称 - service_collection = MongoDB().get_collection("service") - db_service = await service_collection.find_one({"_id": service_id}) - if not db_service: - msg = "Service not found" - raise ServiceIDError(msg) - service_pool_store = ServicePool.model_validate(db_service) + async with postgres.session() as session: + db_service = (await session.scalars( + select(Service).where( + Service.id == service_id, + ), + )).one_or_none() + if not db_service: + msg = "Service not found" + raise ServiceIDError(msg) # 根据 service_id 获取 API 列表 node_collection = MongoDB().get_collection("node") db_nodes = await node_collection.find({"service_id": service_id}).to_list() @@ -246,7 +258,7 @@ class ServiceCenterManager: @staticmethod async def get_service_metadata( user_sub: str, - service_id: str, + service_id: uuid.UUID, ) -> ServiceMetadata: """获取服务元数据""" service_collection = MongoDB().get_collection("service") @@ -267,7 +279,7 @@ class ServiceCenterManager: raise ServiceIDError(msg) metadata_path = ( - Path(config.deploy.data_dir) / "semantics" / "service" / service_id / "metadata.yaml" + Path(config.deploy.data_dir) / "semantics" / "service" / str(service_id) / "metadata.yaml" ) async with await metadata_path.open() as f: metadata_data = yaml.safe_load(await f.read()) @@ -371,24 +383,3 @@ class ServiceCenterManager: raise ValueError(msg) # 校验 OpenAPI 规范的 JSON Schema return await OpenAPILoader().load_dict(data) - - @staticmethod - def _build_filters( - base_filters: dict[str, Any], - search_type: SearchType, - keyword: str, - ) -> dict[str, Any]: - 
search_filters = [ - {"name": {"$regex": keyword, "$options": "i"}}, - {"description": {"$regex": keyword, "$options": "i"}}, - {"author": {"$regex": keyword, "$options": "i"}}, - ] - if search_type == SearchType.ALL: - base_filters["$or"] = search_filters - elif search_type == SearchType.NAME: - base_filters["name"] = {"$regex": keyword, "$options": "i"} - elif search_type == SearchType.DESCRIPTION: - base_filters["description"] = {"$regex": keyword, "$options": "i"} - elif search_type == SearchType.AUTHOR: - base_filters["author"] = {"$regex": keyword, "$options": "i"} - return base_filters diff --git a/apps/services/tag.py b/apps/services/tag.py new file mode 100644 index 0000000000000000000000000000000000000000..708bcaea23eb500a0b644efde2f2ce4e9635c950 --- /dev/null +++ b/apps/services/tag.py @@ -0,0 +1,97 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""用户标签管理""" + +import logging +from datetime import datetime + +import pytz +from sqlalchemy import select + +from apps.common.postgres import postgres +from apps.models.tag import Tag +from apps.schemas.request_data import PostTagData + +logger = logging.getLogger(__name__) + + +class TagManager: + """用户标签相关操作""" + + @staticmethod + async def get_tag() -> list[Tag]: + """ + 获取所有标签信息 + + :return: 标签信息列表 + """ + async with postgres.session() as session: + tags = (await session.scalars(select(Tag))).all() + return list(tags) + + + @staticmethod + async def get_tag_by_name(name: str) -> Tag | None: + """ + 根据领域名称获取领域信息 + + :param domain_name: 领域名称 + :return: 领域信息 + """ + async with postgres.session() as session: + return (await session.scalars(select(Tag).where(Tag.name == name).limit(1))).one_or_none() + + + @staticmethod + async def add_tag(data: PostTagData) -> None: + """ + 添加领域 + + :param domain_data: 领域信息 + """ + async with postgres.session() as session: + tag = Tag( + name=data.tag, + definition=data.description, + ) + session.add(tag) + await session.commit() + + + @staticmethod + async def update_tag_by_name(data: PostTagData) -> Tag: + """ + 更新领域 + + :param domain_data: 领域信息 + :return: 更新后的领域信息 + """ + async with postgres.session() as session: + tag = (await session.scalars(select(Tag).where(Tag.name == data.tag).limit(1))).one_or_none() + if not tag: + error_msg = f"[TagManager] Tag {data.tag} not found" + logger.error(error_msg) + raise ValueError(error_msg) + + tag.definition = data.description + tag.updated_at = datetime.now(tz=pytz.timezone("Asia/Shanghai")) + session.add(tag) + await session.commit() + return tag + + + @staticmethod + async def delete_tag(data: PostTagData) -> None: + """ + 删除领域 + + :param domain_data: 领域信息 + """ + async with postgres.session() as session: + tag = (await session.scalars(select(Tag).where(Tag.name == data.tag).limit(1))).one_or_none() + if not tag: + error_msg = f"[TagManager] Tag {data.tag} not found" + logger.error(error_msg) + raise ValueError(error_msg) + + await session.delete(tag) + await session.commit() diff --git a/apps/services/task.py b/apps/services/task.py index 98926f506614b205b20ab2f281eb0814a67e7508..5175cff8c94f13c30facb96768437b02147ad458 100644 --- a/apps/services/task.py +++ b/apps/services/task.py @@ -151,7 +151,7 @@ class TaskManager: @staticmethod - async def delete_tasks_by_conversation_id(conversation_id: str) -> None: + async def delete_tasks_by_conversation_id(conversation_id: uuid.UUID) -> None: """通过ConversationID删除Task信息""" mongo = MongoDB() task_collection = mongo.get_collection("task") diff --git a/apps/services/token.py 
b/apps/services/token.py index 8a41c0ad4861d4c5b359a28a3c3c80cbcb8203c9..6beb1e079af8765974332cd57142c44e33a5f27f 100644 --- a/apps/services/token.py +++ b/apps/services/token.py @@ -2,6 +2,7 @@ """Token Manager""" import logging +import uuid from datetime import UTC, datetime, timedelta import httpx @@ -22,7 +23,12 @@ class TokenManager: """管理用户Token和插件Token""" @staticmethod - async def get_plugin_token(plugin_name: str, session_id: str, access_token_url: str, expire_time: int) -> str: + async def get_plugin_token( + plugin_id: uuid.UUID | None, + session_id: str, + access_token_url: str, + expire_time: int, + ) -> str: """获取插件Token""" user_sub = await SessionManager.get_user(session_id=session_id) if not user_sub: @@ -32,14 +38,14 @@ class TokenManager: mongo = MongoDB() collection = mongo.get_collection("session") token_data = await collection.find_one({ - "_id": f"{plugin_name}_token_{user_sub}", + "_id": f"{plugin_id}_token_{user_sub}", }) if token_data: return token_data["token"] token = await TokenManager.generate_plugin_token( - plugin_name, + plugin_id, session_id, user_sub, access_token_url, @@ -61,7 +67,7 @@ class TokenManager: @staticmethod async def generate_plugin_token( - plugin_name: str, + plugin_name: uuid.UUID, session_id: str, user_sub: str, access_token_url: str, diff --git a/apps/services/user.py b/apps/services/user.py index 8a893d09ab693246397ac7ac11fada8742fbee89..e9fee1895a7c66161757b4495cbfbc945c13bcaa 100644 --- a/apps/services/user.py +++ b/apps/services/user.py @@ -2,10 +2,13 @@ """用户 Manager""" import logging -from datetime import UTC, datetime +from datetime import datetime -from apps.common.mongo import MongoDB -from apps.schemas.collection import User +import pytz +from sqlalchemy import select + +from apps.common.postgres import postgres +from apps.models.user import User from .conversation import ConversationManager @@ -16,92 +19,75 @@ class UserManager: """用户相关操作""" @staticmethod - async def add_userinfo(user_sub: str) -> None: - """ - 向数据库中添加用户信息 - - :param user_sub: 用户sub + async def list_user(n: int = 10, page: int = 1) -> list[User]: """ - mongo = MongoDB() - user_collection = mongo.get_collection("user") - await user_collection.insert_one(User( - _id=user_sub, - ).model_dump(by_alias=True)) + 获取所有用户 - @staticmethod - async def get_all_user_sub() -> list[str]: + :param n: 每页数量 + :param page: 页码 + :return: 所有用户列表 """ - 获取所有用户的sub + async with postgres.session() as session: + users = (await session.scalars(select(User).offset((page - 1) * n).limit(n))).all() + return list(users) - :return: 所有用户的sub列表 - """ - mongo = MongoDB() - user_collection = mongo.get_collection("user") - return [user["_id"] async for user in user_collection.find({}, {"_id": 1})] @staticmethod - async def get_userinfo_by_user_sub(user_sub: str) -> User | None: + async def get_user(user_sub: str) -> User | None: """ 根据用户sub获取用户信息 :param user_sub: 用户sub :return: 用户信息 """ - mongo = MongoDB() - user_collection = mongo.get_collection("user") - user_data = await user_collection.find_one({"_id": user_sub}) - return User(**user_data) if user_data else None + async with postgres.session() as session: + return ( + await session.scalars(select(User).where(User.user_sub == user_sub)) + ).one_or_none() + @staticmethod - async def update_userinfo_by_user_sub(user_sub: str, *, refresh_revision: bool = False) -> bool: + async def update_user(user_sub: str) -> None: """ 根据用户sub更新用户信息 :param user_sub: 用户sub - :param refresh_revision: 是否刷新revision - :return: 更新后的用户信息 """ - mongo = MongoDB() - 
user_data = await UserManager.get_userinfo_by_user_sub(user_sub) - if not user_data: - await UserManager.add_userinfo(user_sub) - return True - - update_dict = { - "$set": {"login_time": round(datetime.now(UTC).timestamp(), 3)}, - } - - if refresh_revision: - update_dict["$set"]["status"] = "init" # type: ignore[assignment] - user_collection = mongo.get_collection("user") - result = await user_collection.update_one({"_id": user_sub}, update_dict) - return result.modified_count > 0 + async with postgres.session() as session: + user = ( + await session.scalars(select(User).where(User.user_sub == user_sub)) + ).one_or_none() + if not user: + user = User( + user_sub=user_sub, + is_active=True, + is_whitelisted=False, + credit=0, + ) + session.add(user) + await session.commit() + return + + user.last_login = datetime.now(tz=pytz.timezone("Asia/Shanghai")) + await session.commit() - @staticmethod - async def query_userinfo_by_login_time(login_time: float) -> list[str]: - """ - 根据登录时间获取用户sub - - :param login_time: 登录时间 - :return: 用户sub列表 - """ - mongo = MongoDB() - user_collection = mongo.get_collection("user") - return [user["_id"] async for user in user_collection.find({"login_time": {"$lt": login_time}}, {"_id": 1})] @staticmethod - async def delete_userinfo_by_user_sub(user_sub: str) -> None: + async def delete_user(user_sub: str) -> None: """ 根据用户sub删除用户信息 :param user_sub: 用户sub """ - mongo = MongoDB() - user_collection = mongo.get_collection("user") - result = await user_collection.find_one_and_delete({"_id": user_sub}) - if not result: - return - result = User.model_validate(result) - - for conv_id in result.conversations: - await ConversationManager.delete_conversation_by_conversation_id(user_sub, conv_id) + async with postgres.session() as session: + user = ( + await session.scalars(select(User).where(User.user_sub == user_sub)) + ).one_or_none() + if not user: + return + + # 先收集该用户的全部对话ID,再删除用户本身,避免引用已不存在的变量 + from apps.models.conversation import Conversation + conv_ids = (await session.scalars( + select(Conversation.id).where(Conversation.user_sub == user_sub), + )).all() + + await session.delete(user) + await session.commit() + + for conv_id in conv_ids: + await ConversationManager.delete_conversation_by_conversation_id(user_sub, conv_id) diff --git a/apps/services/user_domain.py b/apps/services/user_domain.py deleted file mode 100644 index 8ce47645dc7e72bd3d78fca8bb83e96837483bbe..0000000000000000000000000000000000000000 --- a/apps/services/user_domain.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. 
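# 示例(假设性草图,非补丁内容):UserManager 迁移到 PostgreSQL 后,登录场景下的
# 典型调用顺序;on_login 为虚构的辅助函数。update_user 兼具“不存在则创建、存在则
# 刷新 last_login”的 upsert 语义。
async def on_login(user_sub: str) -> None:
    await UserManager.update_user(user_sub)      # 首次登录时会自动创建 User 记录
    user = await UserManager.get_user(user_sub)  # 创建或刷新后即可查询到该用户
    if user is None or not user.is_active:
        # is_active 为补丁中创建 User 时写入的字段,此处仅作演示性检查
        return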
-"""用户画像管理""" - -import logging - -from apps.common.mongo import MongoDB -from apps.schemas.collection import UserDomainData - -logger = logging.getLogger(__name__) - - -class UserDomainManager: - """用户画像管理""" - - @staticmethod - async def get_user_domain_by_user_sub_and_topk(user_sub: str, topk: int) -> list[str]: - """根据用户ID,查询用户最常涉及的n个领域""" - mongo = MongoDB() - user_collection = mongo.get_collection("user") - domains = await user_collection.aggregate( - [ - {"$project": {"_id": 1, "domains": 1}}, - {"$match": {"_id": user_sub}}, - {"$unwind": "$domains"}, - {"$sort": {"domain_count": -1}}, - {"$limit": topk}, - ], - ) - - return [UserDomainData.model_validate(domain).name async for domain in domains] - - @staticmethod - async def update_user_domain_by_user_sub_and_domain_name(user_sub: str, domain_name: str) -> None: - """增加特定用户特定领域的频次""" - mongo = MongoDB() - domain_collection = mongo.get_collection("domain") - user_collection = mongo.get_collection("user") - - # 检查领域是否存在 - domain = await domain_collection.find_one({"_id": domain_name}) - if not domain: - # 领域不存在,则创建领域 - await domain_collection.insert_one({"_id": domain_name, "domain_description": ""}) - await user_collection.update_one( - {"_id": user_sub, "domains.name": domain_name}, {"$inc": {"domains.$.count": 1}}, - ) diff --git a/apps/services/user_tag.py b/apps/services/user_tag.py new file mode 100644 index 0000000000000000000000000000000000000000..5b5caf57726fcd42744feb65f8909be4db5d9066 --- /dev/null +++ b/apps/services/user_tag.py @@ -0,0 +1,69 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2023-2025. All rights reserved. +"""用户画像管理""" + +import logging + +from sqlalchemy import select + +from apps.common.postgres import postgres +from apps.models.tag import Tag +from apps.models.user import UserTag + +logger = logging.getLogger(__name__) + + +class UserTagManager: + """用户画像管理""" + + @staticmethod + async def get_user_domain_by_user_sub_and_topk(user_sub: str, topk: int) -> list[str]: + """根据用户ID,查询用户最常涉及的n个领域""" + async with postgres.session() as session: + user_domains = ( + await session.scalars( + select(UserTag).where( + UserTag.user_sub == user_sub, + ).order_by(UserTag.count.desc()).limit(topk), + ) + ).all() + + tags = [] + for user_domain in user_domains: + result = ( + await session.scalars( + select(Tag).where(Tag.id == user_domain.tag), + ) + ).one_or_none() + if result: + tags.append(result) + return [tag.name for tag in tags] + + + @staticmethod + async def update_user_domain_by_user_sub_and_domain_name(user_sub: str, domain_name: str) -> None: + """增加特定用户特定领域的频次""" + async with postgres.session() as session: + tag = ( + await session.scalars( + select(Tag).where(Tag.name == domain_name), + ) + ).one_or_none() + if not tag: + err = f"[UserTagManager] Tag {domain_name} not found" + logger.error(err) + raise ValueError(err) + + user_domain = ( + await session.scalars( + select(UserTag).where( + UserTag.user_sub == user_sub, + UserTag.tag == tag.id, + ), + ) + ).one_or_none() + if not user_domain: + user_domain = UserTag(user_sub=user_sub, tag=tag.id, count=1) + session.add(user_domain) + else: + user_domain.count += 1 + await session.commit() diff --git a/deploy/chart/authhub/configs/backend/aops-config.yml b/deploy/chart/authhub/configs/backend/aops-config.yml index 8ea4d5972ea12674dce69f675cb44422577fa570..15ecd0948888cc2b18445507e35cb6e30f06e872 100644 --- a/deploy/chart/authhub/configs/backend/aops-config.yml +++ b/deploy/chart/authhub/configs/backend/aops-config.yml @@ -13,7 +13,7 @@ infrastructure: 
password: ${redis-password} include: "/etc/aops/conf.d" -domain: {{ .Values.domain.authhub }} +domain: {{ default "127.0.0.1" .Values.domain.authhub }} services: log: diff --git a/deploy/chart/authhub/values.yaml b/deploy/chart/authhub/values.yaml index 6421c168e9c493b4ea070cac360c0e06c716d497..59caf253d5725ae69a3d9d88f55c926ffca18063 100644 --- a/deploy/chart/authhub/values.yaml +++ b/deploy/chart/authhub/values.yaml @@ -15,8 +15,8 @@ storage: mysql: domain: - # AuthHub域名,默认为authhub.eulercopilot.local。单机部署时,服务基于Host进行区分,无法使用IP地址 - authhub: + # AuthHub域名,默认为127.0.0.1 + authhub: 127.0.0.1 # 部署AuthHub本地鉴权服务 authhub: @@ -37,13 +37,9 @@ authhub: # Service设置 service: # Service类型,例如NodePort - type: + type: NodePort # 当类型为NodePort时,填写主机的端口号 - nodePort: - # Ingress设置 - ingress: - # Ingress前缀,默认为/ - prefix: + nodePort: 30081 backend: # [必填] 是否部署AuthHub后端服务 diff --git a/deploy/chart/euler_copilot/configs/framework/config.toml b/deploy/chart/euler_copilot/configs/framework/config.toml index 929cc6a564d9b876e98ac2b6ca2a46234bd1c441..02f6d151c3280675fbb9c179d8a8e96d503d69e0 100644 --- a/deploy/chart/euler_copilot/configs/framework/config.toml +++ b/deploy/chart/euler_copilot/configs/framework/config.toml @@ -6,14 +6,14 @@ data_dir = '/app/data' [login] provider = 'authhub' [login.settings] -host = 'https://{{ default "authhub.eulercopilot.local" .Values.domain.authhub }}' +host = 'http://{{ default "127.0.0.1" .Values.login.host }}:{{ default "30081" .Values.login.authhub_web_nodePort }}' host_inner = 'http://authhub-backend-service.{{ .Release.Namespace }}.svc.cluster.local:11120' -login_api = 'https://{{ default "www.eulercopilot.local" .Values.domain.euler_copilot }}/api/auth/login' +login_api = 'http://{{ default "127.0.0.1" .Values.login.host }}:{{ default 30080 .Values.euler_copilot.web.service.nodePort }}/api/auth/login' app_id = '${clientId}' app_secret = '${clientSecret}' [fastapi] -domain = '{{ default "www.eulercopilot.local" .Values.domain.euler_copilot }}' +domain = '{{ default "127.0.0.1" .Values.login.host }}' [security] half_key1 = '${halfKey1}' diff --git a/deploy/chart/euler_copilot/templates/framework/framework-storage.yaml b/deploy/chart/euler_copilot/templates/framework/framework-storage.yaml index b6b97ee83ae623e93af7bdd3c22d46097996df56..25047e1b86d93d180bfc876aa6647ad15b64e675 100644 --- a/deploy/chart/euler_copilot/templates/framework/framework-storage.yaml +++ b/deploy/chart/euler_copilot/templates/framework/framework-storage.yaml @@ -28,4 +28,4 @@ spec: requests: storage: {{ default "5Gi" .Values.storage.frameworkSemantics.size }} volumeName: framework-semantics -{{- end -}} \ No newline at end of file +{{- end -}} diff --git a/deploy/chart/euler_copilot/templates/rag-web/rag-web-config.yaml b/deploy/chart/euler_copilot/templates/rag-web/rag-web-config.yaml index 75cf602658fb0deaebc615334a414ef974d8ddf1..cc7ce1ba799bf6ebcc41faeafe202e2e2f02ce47 100644 --- a/deploy/chart/euler_copilot/templates/rag-web/rag-web-config.yaml +++ b/deploy/chart/euler_copilot/templates/rag-web/rag-web-config.yaml @@ -6,6 +6,5 @@ metadata: namespace: {{ .Release.Namespace }} data: .env: |- - PROD=enabled DATA_CHAIN_BACEND_URL=http://rag-service.{{ .Release.Namespace }}.svc.cluster.local:9988 {{- end -}} diff --git a/deploy/chart/euler_copilot/values.yaml b/deploy/chart/euler_copilot/values.yaml index edf15b99d9ce58c0fc654f23de2a2507731072e5..086fd9136f0451112216f551a78983d1f6e13288 100644 --- a/deploy/chart/euler_copilot/values.yaml +++ b/deploy/chart/euler_copilot/values.yaml @@ -51,6 +51,10 @@ 
models: # 登录设置 login: + # [必填] 服务器的IP,不填则默认为127.0.0.1 + host: + # [必填] authhub的web端nodePort,不填则默认为30081 + authhub_web_nodePort: # 客户端ID设置,仅在type为authhub时有效 client: # [必填] 客户端ID @@ -58,13 +62,6 @@ login: # [必填] 客户端密钥 secret: -# 域名设置 -domain: - # 用于EulerCopilot的域名;默认为www.eulercopilot.local - euler_copilot: - # 部署authhub时使用的域名;默认为authhub.eulercopilot.local - authhub: - # 存储设置 storage: # 语义接口 diff --git a/deploy/scripts/7-install-authhub/install_authhub.sh b/deploy/scripts/7-install-authhub/install_authhub.sh index 9eadcd49a5e2806a11e09339ec12d72e23110e6f..38a87bc877952c07c28274b6caf5f7d96f1e16c5 100755 --- a/deploy/scripts/7-install-authhub/install_authhub.sh +++ b/deploy/scripts/7-install-authhub/install_authhub.sh @@ -91,20 +91,23 @@ uninstall_authhub() { echo -e "${GREEN}资源清理完成${NC}" } -get_user_input() { - echo -e "${BLUE}请输入 Authhub 的域名配置(直接回车使用默认值 authhub.eulercopilot.local):${NC}" - read -p "Authhub 的前端域名: " authhub_domain - - if [[ -z "$authhub_domain" ]]; then - authhub_domain="authhub.eulercopilot.local" - echo -e "${GREEN}使用默认域名:${authhub_domain}${NC}" +get_authhub_address() { + local default_address="127.0.0.1" + + echo -e "${BLUE}请输入 Authhub 的访问地址(IP或域名,直接回车使用默认值 ${default_address}):${NC}" + read -p "Authhub 地址: " authhub_address + + # 处理空输入情况 + if [[ -z "$authhub_address" ]]; then + authhub_address="$default_address" + echo -e "${GREEN}使用默认地址:${authhub_address}${NC}" else - if ! [[ "${authhub_domain}" =~ ^([a-zA-Z0-9-]+\.)*[a-zA-Z0-9-]+\.[a-zA-Z]{2,}$ ]]; then - echo -e "${RED}错误:输入的AuthHub域名格式不正确${NC}" - exit 1 - fi - echo -e "${GREEN}输入域名:${authhub_domain}${NC}" + echo -e "${GREEN}输入地址:${authhub_address}${NC}" fi + + # 返回用户输入的地址 + echo "$authhub_address" + return 0 } helm_install() { @@ -119,7 +122,7 @@ helm_install() { echo -e "${BLUE}正在安装 authhub...${NC}" helm upgrade --install authhub -n euler-copilot ./authhub \ --set globals.arch="$arch" \ - --set domain.authhub="${authhub_domain}" || { + --set domain.authhub="${authhub_address}" || { echo -e "${RED}Helm 安装 authhub 失败!${NC}" return 1 } @@ -169,7 +172,7 @@ main() { arch=$(get_architecture) || exit 1 create_namespace || exit 1 uninstall_authhub || exit 1 - get_user_input || exit 1 + get_authhub_address|| exit 1 helm_install "$arch" || exit 1 check_pods_status || { echo -e "${RED}部署失败:Pod状态检查未通过!${NC}" @@ -179,7 +182,7 @@ main() { echo -e "\n${GREEN}=========================" echo -e "Authhub 部署完成!" 
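# 示例(假设性片段,非脚本/补丁内容):安装完成后,可用下列命令确认 AuthHub
# 已按本补丁的方式以 NodePort 30081 暴露,且各 Pod 正常运行。
# kubectl get svc -n euler-copilot | grep authhub   # 确认 NodePort 30081 已暴露
# kubectl get pod -n euler-copilot                  # 确认各 Pod 均为 Running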
echo -e "查看pod状态:kubectl get pod -n euler-copilot" - echo -e "Authhub登录地址为: https://${authhub_domain}" + echo -e "Authhub登录地址为: https://${authhub_address}:30081" echo -e "默认账号密码: administrator/changeme" echo -e "=========================${NC}" } diff --git a/deploy/scripts/8-install-EulerCopilot/install_eulercopilot.sh b/deploy/scripts/8-install-EulerCopilot/install_eulercopilot.sh index 1ff60f727757ff517e1fb2f3a9fb183b977bcabe..a3d08766b953df532f13944067596eb19f13b4d2 100755 --- a/deploy/scripts/8-install-EulerCopilot/install_eulercopilot.sh +++ b/deploy/scripts/8-install-EulerCopilot/install_eulercopilot.sh @@ -10,7 +10,11 @@ BLUE='\e[34m' NC='\e[0m' # 恢复默认颜色 NAMESPACE="euler-copilot" -PLUGINS_DIR="/home/eulercopilot/semantics" +PLUGINS_DIR="/var/lib/eulercopilot" + +# 全局变量声明 +client_id="" +client_secret="" SCRIPT_PATH="$( cd "$(dirname "${BASH_SOURCE[0]}")" >/dev/null 2>&1 @@ -75,57 +79,65 @@ get_network_ip() { echo "$host" } - -# 处理域名 -get_domain_input() { +# 处理域名/主机输入 +get_host_input() { # 从环境变量读取或使用默认值 - eulercopilot_domain=${EULERCOPILOT_DOMAIN:-"www.eulercopilot.local"} - authhub_domain=${AUTHHUB_DOMAIN:-"authhub.eulercopilot.local"} - - # 非交互模式直接使用默认值 + eulercopilot_host=${EULERCOPILOT_HOST:-"127.0.0.1"} + + # 交互模式允许用户输入 if [ -t 0 ]; then # 仅在交互式终端显示提示 - echo -e "${BLUE}请输入 EulerCopilot 域名(默认:$eulercopilot_domain):${NC}" + echo -e "${BLUE}请输入 EulerCopilot 的访问IP或域名(默认:$eulercopilot_host):${NC}" read -p "> " input_euler - [ -n "$input_euler" ] && eulercopilot_domain=$input_euler - - echo -e "${BLUE}请输入 Authhub 域名(默认:$authhub_domain):${NC}" - read -p "> " input_auth - [ -n "$input_auth" ] && authhub_domain=$input_auth + [ -n "$input_euler" ] && eulercopilot_host=$input_euler fi - # 统一验证域名格式 - local domain_regex='^([a-zA-Z0-9-]{1,63}\.)+[a-zA-Z]{2,}$' - if ! [[ $eulercopilot_domain =~ $domain_regex ]]; then - echo -e "${RED}错误:EulerCopilot域名格式不正确${NC}" >&2 + # 验证输入格式(支持IP和域名) + local ip_regex='^([0-9]{1,3}\.){3}[0-9]{1,3}$' + local host_regex='^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9-]*[A-Za-z0-9])$' + + if ! [[ $eulercopilot_host =~ $ip_regex ]] && ! [[ $eulercopilot_host =~ $host_regex ]]; then + echo -e "${RED}错误:无效的IP地址或域名格式${NC}" >&2 exit 1 fi - if ! [[ $authhub_domain =~ $domain_regex ]]; then - echo -e "${RED}错误:AuthHub域名格式不正确${NC}" >&2 - exit 1 + + # 如果是IP地址,检查每部分是否在0-255范围内 + if [[ $eulercopilot_host =~ $ip_regex ]]; then + local IFS='.' + read -ra ip_parts <<< "$eulercopilot_host" + for part in "${ip_parts[@]}"; do + if (( part < 0 || part > 255 )); then + echo -e "${RED}错误:IP地址各部分必须在0-255之间${NC}" >&2 + exit 1 + fi + done fi echo -e "${GREEN}使用配置:" - echo "EulerCopilot域名: $eulercopilot_domain" - echo "Authhub域名: $authhub_domain" + echo "EulerCopilot的访问地址: $eulercopilot_host" + echo -e "${NC}" } get_client_info_auto() { - get_domain_input + get_host_input + # 创建临时文件 + local temp_file + temp_file=$(mktemp) + # 直接调用Python脚本并传递域名参数 - python3 "${DEPLOY_DIR}/scripts/9-other-script/get_client_id_and_secret.py" "${eulercopilot_domain}" > client_info.tmp 2>&1 + python3 "${DEPLOY_DIR}/scripts/9-other-script/get_client_id_and_secret.py" "${eulercopilot_host}" > "$temp_file" 2>&1 # 检查Python脚本执行结果 if [ $? 
-ne 0 ]; then echo -e "${RED}错误:Python脚本执行失败${NC}" - cat client_info.tmp - rm -f client_info.tmp + cat "$temp_file" + rm -f "$temp_file" return 1 fi # 提取凭证信息 - client_id=$(grep "client_id: " client_info.tmp | awk '{print $2}') - client_secret=$(grep "client_secret: " client_info.tmp | awk '{print $2}') - rm -f client_info.tmp + client_id=$(grep "client_id: " "$temp_file" | awk '{print $2}') + client_secret=$(grep "client_secret: " "$temp_file" | awk '{print $2}') + rm -f "$temp_file" # 验证结果 if [ -z "$client_id" ] || [ -z "$client_secret" ]; then @@ -141,7 +153,6 @@ get_client_info_auto() { } get_client_info_manual() { - # 非交互模式直接使用默认值 if [ -t 0 ]; then # 仅在交互式终端显示提示 echo -e "${BLUE}请输入 Client ID: 域名(端点信息:Client ID): ${NC}" @@ -163,8 +174,7 @@ check_directories() { echo -e "${BLUE}检查语义接口目录是否存在...${NC}" >&2 # 定义父目录和子目录列表 - local PLUGINS_DIR="/home/eulercopilot/semantics" - local REQUIRED_OWNER="1001:1001" + local REQUIRED_OWNER="root:root" # 检查并创建父目录 if [ -d "${PLUGINS_DIR}" ]; then @@ -173,7 +183,7 @@ check_directories() { local current_owner=$(stat -c "%u:%g" "${PLUGINS_DIR}" 2>/dev/null) if [ "$current_owner" != "$REQUIRED_OWNER" ]; then echo -e "${YELLOW}当前目录权限: ${current_owner},正在修改为 ${REQUIRED_OWNER}...${NC}" >&2 - if chown 1001:1001 "${PLUGINS_DIR}"; then + if chown root:root "${PLUGINS_DIR}"; then echo -e "${GREEN}目录权限已成功修改为 ${REQUIRED_OWNER}${NC}" >&2 else echo -e "${RED}错误:无法修改目录权限到 ${REQUIRED_OWNER}${NC}" >&2 @@ -185,7 +195,7 @@ check_directories() { else if mkdir -p "${PLUGINS_DIR}"; then echo -e "${GREEN}目录已创建:${PLUGINS_DIR}${NC}" >&2 - chown 1001:1001 "${PLUGINS_DIR}" # 设置父目录所有者 + chown root:root "${PLUGINS_DIR}" # 设置父目录所有者 else echo -e "${RED}错误:无法创建目录 ${PLUGINS_DIR}${NC}" >&2 exit 1 @@ -246,8 +256,8 @@ modify_yaml() { set_args+=( "--set" "login.client.id=${client_id}" "--set" "login.client.secret=${client_secret}" - "--set" "domain.authhub=${authhub_domain}" - "--set" "domain.euler_copilot=${eulercopilot_domain}" + "--set" "login.host=${eulercopilot_host}" + "--set" "login.authhub_web_nodePort=30081" ) # 如果不需要保留模型配置,则添加模型相关的参数 @@ -256,8 +266,8 @@ modify_yaml() { "--set" "models.answer.endpoint=http://$host:11434/v1" "--set" "models.answer.key=sk-123456" "--set" "models.answer.name=deepseek-llm-7b-chat:latest" - "--set" "models.functionCall.backend=ollama" - "--set" "models.embedding.type=openai" + "--set" "models.functionCall.backend=ollama" + "--set" "models.embedding.type=openai" "--set" "models.embedding.endpoint=http://$host:11434/v1" "--set" "models.embedding.key=sk-123456" "--set" "models.embedding.name=bge-m3:latest" @@ -284,6 +294,18 @@ enter_chart_directory() { } } +pre_install_checks() { + # 检查kubectl和helm是否可用 + command -v kubectl >/dev/null 2>&1 || error_exit "kubectl未安装" + command -v helm >/dev/null 2>&1 || error_exit "helm未安装" + + # 检查Kubernetes集群连接 + kubectl cluster-info >/dev/null 2>&1 || error_exit "无法连接到Kubernetes集群" + + # 检查必要的存储类 + kubectl get storageclasses >/dev/null 2>&1 || error_exit "无法获取存储类信息" +} + # 执行安装 execute_helm_install() { local arch=$1 @@ -337,56 +359,76 @@ check_pods_status() { done } +# 修改main函数 main() { + pre_install_checks + local arch host arch=$(get_architecture) || exit 1 host=$(get_network_ip) || exit 1 + uninstall_eulercopilot + if ! 
get_client_info_auto; then - echo -e "${YELLOW}需要手动登录Authhub域名并创建应用,获取client信息${NC}" get_client_info_manual fi + check_directories - - # 处理是否保留模型配置 - local preserve_models="N" # 非交互模式默认覆盖 - if [ -t 0 ]; then # 仅在交互式终端显示提示 - echo -e "${BLUE}是否保留现有的模型配置?(Y/n) ${NC}" - read -p "> " input_preserve - preserve_models=${input_preserve:-Y} + + # 交互式提示优化 + if [ -t 0 ]; then + echo -e "${YELLOW}是否保留现有的模型配置?${NC}" + echo -e " ${BLUE}Y) 保留现有配置${NC}" + echo -e " ${BLUE}n) 使用默认配置${NC}" + while true; do + read -p "请选择(Y/N): " input_preserve + case "${input_preserve:-Y}" in + [YyNn]) preserve_models=${input_preserve:-Y}; break ;; + *) echo -e "${RED}无效输入,请选择Y或n${NC}" ;; + esac + done + else + preserve_models="N" fi - - modify_yaml "$host" "$preserve_models" + + echo -e "${BLUE}开始修改YAML配置...${NC}" + modify_yaml "$eulercopilot_host" "$preserve_models" + + echo -e "${BLUE}开始Helm安装...${NC}" execute_helm_install "$arch" if check_pods_status; then - echo -e "${GREEN}所有组件已就绪!${NC}" + echo -e "${GREEN}所有组件已就绪!${NC}" + show_success_message "$host" "$arch" else - echo -e "${YELLOW}注意:部分组件尚未就绪,建议进行排查${NC}" >&2 + echo -e "${YELLOW}部分组件尚未就绪,建议进行排查!${NC}" fi +} + +# 添加安装成功信息显示函数 +show_success_message() { + local host=$1 + local arch=$2 + - # 最终部署信息输出 echo -e "\n${GREEN}==================================================${NC}" - echo -e "${GREEN} EulerCopilot 部署完成! ${NC}" + echo -e "${GREEN} EulerCopilot 部署完成! ${NC}" echo -e "${GREEN}==================================================${NC}" - echo -e "${YELLOW}访问地址:" - echo -e "EulerCopilot UI:\thttps://${eulercopilot_domain}" - echo -e "AuthHub 管理界面:\thttps://${authhub_domain}" - echo -e "\n${YELLOW}系统信息:" - echo -e "业务网络IP:\t${host}" - echo -e "系统架构:\t$(uname -m) (识别为: ${arch})" - echo -e "插件目录:\t${PLUGINS_DIR}" - echo -e "Chart目录:\t${DEPLOY_DIR}/chart/${NC}" - echo -e "" - echo -e "${BLUE}操作指南:" + + echo -e "${YELLOW}访问信息:${NC}" + echo -e "EulerCopilot UI: http://${eulercopilot_host}:30080" + echo -e "AuthHub 管理界面: http://${eulercopilot_host}:30081" + + echo -e "\n${YELLOW}系统信息:${NC}" + echo -e "内网IP: ${host}" + echo -e "系统架构: $(uname -m) (识别为: ${arch})" + echo -e "插件目录: ${PLUGINS_DIR}" + echo -e "Chart目录: ${DEPLOY_DIR}/chart/" + + echo -e "${BLUE}操作指南:${NC}" echo -e "1. 查看集群状态: kubectl get all -n $NAMESPACE" echo -e "2. 查看实时日志: kubectl logs -n $NAMESPACE -f deployment/$NAMESPACE" echo -e "3. 查看POD状态:kubectl get pods -n $NAMESPACE" - echo -e "4. 查看数据库并使用base64密码:kubectl edit secret $NAMESPACE-system -n $NAMESPACE" - echo -e "5. 
添加域名解析(示例):" - echo -e " ${host} ${eulercopilot_domain}" - echo -e " ${host} ${authhub_domain}${NC}" } -main - +main "$@" diff --git a/deploy/scripts/9-other-script/get_client_id_and_secret.py b/deploy/scripts/9-other-script/get_client_id_and_secret.py index 9a31022c91cfd9e5e13ea0ca80140902dacdb79e..1ff7da3c35cc54887496b5d5ee15ba7edf996c6b 100755 --- a/deploy/scripts/9-other-script/get_client_id_and_secret.py +++ b/deploy/scripts/9-other-script/get_client_id_and_secret.py @@ -29,8 +29,8 @@ def get_service_cluster_ip(namespace, service_name): service_info = json.loads(result.stdout.decode()) return service_info['spec'].get('clusterIP', 'No Cluster IP found') -def get_user_token(auth_hub_url, username="administrator", password="changeme"): - url = auth_hub_url + "/oauth2/manager-login" +def get_user_token(authhub_web_url, username="administrator", password="changeme"): + url = authhub_web_url + "/oauth2/manager-login" response = requests.post( url, json={"password": password, "username": username}, @@ -41,29 +41,9 @@ def get_user_token(auth_hub_url, username="administrator", password="changeme"): response.raise_for_status() return response.json()["data"]["user_token"] -def register_app(auth_hub_url, user_token, client_name, client_url, redirect_urls): - response = requests.post( - auth_hub_url + "/oauth2/applications/register", - json={ - "client_name": client_name, - "client_uri": client_url, - "redirect_uris": redirect_urls, - "register_callback_uris": [], # 修复参数名中的空格 - "logout_callback_uris": [], # 修复参数名中的空格 - "skip_authorization": True, - "scope": ["email", "phone", "username", "openid", "offline_access"], - "grant_types": ["authorization_code"], - "response_types": ["code"], - "token_endpoint_auth_method": "none" - }, - headers={"Authorization": user_token, "Content-Type": "application/json"}, - verify=False - ) - response.raise_for_status() - -def get_client_secret(auth_hub_url, user_token, target_client_name): +def find_existing_app(authhub_web_url, user_token, client_name): response = requests.get( - auth_hub_url + "/oauth2/applications", + authhub_web_url + "/oauth2/applications", headers={"Authorization": user_token, "Content-Type": "application/json"}, timeout=10 ) @@ -84,18 +64,75 @@ def get_client_secret(auth_hub_url, user_token, target_client_name): app.get("client_info", {}).get("client_name") ] - if any(str(name).lower() == target_client_name.lower() for name in candidate_names if name): - return { - "client_id": app["client_info"]["client_id"], - "client_secret": app["client_info"]["client_secret"] - } - - raise ValueError(f"未找到匹配应用,请检查 client_name 是否准确(尝试使用全小写名称)") + if any(str(name).lower() == client_name.lower() for name in candidate_names if name): + return app["client_info"]["client_id"] + return None + +def register_or_update_app(authhub_web_url, user_token, client_name, client_url, redirect_urls): + client_id = find_existing_app(authhub_web_url, user_token, client_name) + + if client_id: + # 更新现有应用 + print(f"发现已存在应用 [名称: {client_name}], 正在更新...") + url = f"{authhub_web_url}/oauth2/applications/{client_id}" + response = requests.put( + url, + json={ + "client_uri": client_url, + "redirect_uris": redirect_urls, + "register_callback_uris": [], + "logout_callback_uris": [], + "skip_authorization": True, + "scope": ["email", "phone", "username", "openid", "offline_access"], + "grant_types": ["authorization_code"], + "response_types": ["code"], + "token_endpoint_auth_method": "none" + }, + headers={"Authorization": user_token, "Content-Type": "application/json"}, + 
verify=False + ) + response.raise_for_status() + return response.json()["data"] + else: + # 注册新应用 + print(f"未找到已存在应用 [名称: {client_name}], 正在注册新应用...") + response = requests.post( + authhub_web_url + "/oauth2/applications/register", + json={ + "client_name": client_name, + "client_uri": client_url, + "redirect_uris": redirect_urls, + "register_callback_uris": [], + "logout_callback_uris": [], + "skip_authorization": True, + "scope": ["email", "phone", "username", "openid", "offline_access"], + "grant_types": ["authorization_code"], + "response_types": ["code"], + "token_endpoint_auth_method": "none" + }, + headers={"Authorization": user_token, "Content-Type": "application/json"}, + verify=False + ) + response.raise_for_status() + return response.json()["data"] + +def get_client_secret(authhub_web_url, user_token, client_id): + response = requests.get( + f"{authhub_web_url}/oauth2/applications/{client_id}", + headers={"Authorization": user_token, "Content-Type": "application/json"}, + timeout=10 + ) + response.raise_for_status() + app_data = response.json() + return { + "client_id": app_data["data"]["client_info"]["client_id"], + "client_secret": app_data["data"]["client_info"]["client_secret"] + } if __name__ == "__main__": # 解析命令行参数 parser = argparse.ArgumentParser() - parser.add_argument("eulercopilot_domain", help="EulerCopilot域名(例如:example.com)") + parser.add_argument("host_ip", help="主机IP地址(例如:192.168.1.100)") args = parser.parse_args() # 获取服务信息 @@ -103,25 +140,25 @@ if __name__ == "__main__": service_name = "authhub-web-service" print(f"正在查询服务信息: [命名空间: {namespace}] [服务名: {service_name}]") cluster_ip = get_service_cluster_ip(namespace, service_name) - auth_hub_url = f"http://{cluster_ip}:8000" + authhub_web_url = f"http://{cluster_ip}:8000" # 生成固定URL - client_url = f"https://{args.eulercopilot_domain}" - redirect_urls = [f"https://{args.eulercopilot_domain}/api/auth/login"] + client_url = f"http://{args.host_ip}:30080" + redirect_urls = [f"http://{args.host_ip}:30080/api/auth/login"] client_name = "EulerCopilot" # 设置固定默认值 # 认证流程 try: print("\n正在获取用户令牌...") - user_token = get_user_token(auth_hub_url) + user_token = get_user_token(authhub_web_url) print("✓ 用户令牌获取成功") - print(f"\n正在注册应用 [名称: {client_name}]...") - register_app(auth_hub_url, user_token, client_name, client_url, redirect_urls) - print("✓ 应用注册成功") + print(f"\n正在处理应用 [名称: {client_name}]...") + app_info = register_or_update_app(authhub_web_url, user_token, client_name, client_url, redirect_urls) + print("✓ 应用处理成功") - print(f"\n正在查询客户端凭证 [名称: {client_name}]...") - client_info = get_client_secret(auth_hub_url, user_token, client_name) + print(f"\n正在查询客户端凭证 [ID: {app_info['client_info']['client_id']}]...") + client_info = get_client_secret(authhub_web_url, user_token, app_info["client_info"]["client_id"]) print("\n✓ 认证信息获取成功:") print(f"client_id: {client_info['client_id']}") diff --git a/docs/zh/intelligent_foundation/sysHAX/deploy_guide/sysHax-deployment-guide.md b/docs/zh/intelligent_foundation/sysHAX/deploy_guide/sysHax-deployment-guide.md index e43afe39657293b238f6a69216c4c0067774e28c..a8573f7b4368a2d3a0a2e85928507b1d6425e516 100644 --- a/docs/zh/intelligent_foundation/sysHAX/deploy_guide/sysHax-deployment-guide.md +++ b/docs/zh/intelligent_foundation/sysHAX/deploy_guide/sysHax-deployment-guide.md @@ -179,12 +179,14 @@ dnf install sysHAX ```shell syshax init -syshax config gpu.port 8001 -syshax config cpu.port 8002 -syshax config conductor.port 8010 -syshax config model ds-32b +syshax config services.gpu.port 8001 +syshax 
config services.cpu.port 8002 +syshax config services.conductor.port 8010 +syshax config models.default ds-32b ``` +此外,也可以通过 `syshax config --help` 来查看全部配置命令。 + 配置完成后,通过如下命令启动sysHAX服务: ```shell diff --git "a/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/MCP\344\275\277\347\224\250\346\214\207\345\215\227.md" "b/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/MCP\344\275\277\347\224\250\346\214\207\345\215\227.md" new file mode 100644 index 0000000000000000000000000000000000000000..4a2039f6a910a729fec02f7863fed4cdc169185f --- /dev/null +++ "b/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/MCP\344\275\277\347\224\250\346\214\207\345\215\227.md" @@ -0,0 +1,101 @@ +# 工作流中的MCP使用手册 + +# 前置操作 + +若用户发现在插件中心无法创建MCP服务,此时需要去后端手动给该用户添加管理员权限. + +``` +# 获取mongo-password的密码 +kubectl get secret -n euler-copilot euler-copilot-database -oyaml +#将获取到的密码使用base64进行解密 +echo {mongo_password} | base64 -d +# 此时获取到了进入mongodb的密码 +# 然后需要进入mongodb容器 +kubectl -n euler-copilot get pods +kubectl -n euler-copilot exec -it mongo-deploy-xxxxxx-xxxxxxx -- bash +# 然后替换为前面获取的mongodb的密码替换 +mongosh "mongodb://euler_copilot:{your_secret}@localhost:27017/?directConnection=true&replicaSet=rs0" +use euler_copilot; +# 这里的id为下图所展示的id +db["user"].update({_id: "{your_id}"}, { $set: {is_admin: true }}); +``` + +![id图片](./pictures/id图片.png) + + + +## 1 注册MCP + +### 1.1 在插件中心点击创建MCP服务 + +![创建MCP](./pictures/创建MCP.png) + +### 1.2 填写MCP服务的相关信息 + +![填写MCP信息](./pictures/填写MCP信息.png) + +上传图片,填写MCP名称和简介. + +### 1.3 选择MCP类型以及填入MCP.json + +以time-mcp-server为例 + +https://mcp.so/server/time/modelcontextprotocol,这是这个time-mcp的使用网址. + +需要将MCP-server给的server config改成上面图片所给的格式,所以不需要更改,但这里time MCPserver给的server config为 + +``` +{ + "mcp": { + "servers": { + "time": { + "command": "uvx", + "args": [ + "mcp-server-time", + "--local-timezone=America/New_York" + ] + } + } + } +} +``` + +所以需要改造一下变成 + +``` +{ + "command": "uvx", + "args": ["mcp-server-time", + "--local-timezone=America/New_York"], + "env": {}, + "autoApprove": [], + "autoInstall": true, + "disabled": false +} +``` + +## 2 使用MCP + +![复制MCPID信息](./pictures/复制MCPID信息.png) + +完成后会有一行显示ID,需要复制. + +当使用的时候还需要激活该MCP,点击上述激活按钮即可 + +### 2.1 创建工作流 + +![image-20250522173235997](./pictures/创建工作流步骤1.png) + +![创建工作流步骤2](./pictures/创建工作流步骤2.png) + +![创建工作流步骤3](./pictures/创建工作流步骤3.png) + +![创建工作流步骤4](./pictures/创建工作流步骤4.png) + +![填写MCP ID](./pictures/填写MCPID.png) + +![调试MCP](./pictures/调试MCP.png) + +![发布应用](./pictures/发布应用.png) + +到此应用MCP的一个工作流应用就创建完成了. 
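
补充一个供参考的改写示例(假设使用 fetch MCP server,通过 `uvx mcp-server-fetch` 启动;字段含义与上文 time-mcp-server 的示例相同):

```
{
  "command": "uvx",
  "args": ["mcp-server-fetch"],
  "env": {},
  "autoApprove": [],
  "autoInstall": true,
  "disabled": false
}
```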
\ No newline at end of file diff --git "a/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/id\345\233\276\347\211\207.png" "b/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/id\345\233\276\347\211\207.png" new file mode 100644 index 0000000000000000000000000000000000000000..78700814edddfdbfb89b0eaa12d5ea7511cf5c4e Binary files /dev/null and "b/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/id\345\233\276\347\211\207.png" differ diff --git "a/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\210\233\345\273\272MCP.png" "b/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\210\233\345\273\272MCP.png" new file mode 100644 index 0000000000000000000000000000000000000000..7539a0954506ebb51e680c27b77f88c6ce327002 Binary files /dev/null and "b/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\210\233\345\273\272MCP.png" differ diff --git "a/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\210\233\345\273\272\345\267\245\344\275\234\346\265\201\346\255\245\351\252\2441.png" "b/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\210\233\345\273\272\345\267\245\344\275\234\346\265\201\346\255\245\351\252\2441.png" new file mode 100644 index 0000000000000000000000000000000000000000..2aa9901130769737d46c2c971b3e89a356c699eb Binary files /dev/null and "b/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\210\233\345\273\272\345\267\245\344\275\234\346\265\201\346\255\245\351\252\2441.png" differ diff --git "a/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\210\233\345\273\272\345\267\245\344\275\234\346\265\201\346\255\245\351\252\2442.png" "b/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\210\233\345\273\272\345\267\245\344\275\234\346\265\201\346\255\245\351\252\2442.png" new file mode 100644 index 0000000000000000000000000000000000000000..4a60500e45b998833952a15cbff9b7b2e4b43339 Binary files /dev/null and "b/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\210\233\345\273\272\345\267\245\344\275\234\346\265\201\346\255\245\351\252\2442.png" differ diff --git "a/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\210\233\345\273\272\345\267\245\344\275\234\346\265\201\346\255\245\351\252\2443.png" "b/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\210\233\345\273\272\345\267\245\344\275\234\346\265\201\346\255\245\351\252\2443.png" new file mode 100644 index 0000000000000000000000000000000000000000..e3ec97310da9d57d22b1a169f3fe968c1411c0a7 Binary files /dev/null and 
"b/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\210\233\345\273\272\345\267\245\344\275\234\346\265\201\346\255\245\351\252\2443.png" differ diff --git "a/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\210\233\345\273\272\345\267\245\344\275\234\346\265\201\346\255\245\351\252\2444.png" "b/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\210\233\345\273\272\345\267\245\344\275\234\346\265\201\346\255\245\351\252\2444.png" new file mode 100644 index 0000000000000000000000000000000000000000..f267793f589ce038eb1bec84baeec692a909f3c7 Binary files /dev/null and "b/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\210\233\345\273\272\345\267\245\344\275\234\346\265\201\346\255\245\351\252\2444.png" differ diff --git "a/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\217\221\345\270\203\345\272\224\347\224\250.png" "b/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\217\221\345\270\203\345\272\224\347\224\250.png" new file mode 100644 index 0000000000000000000000000000000000000000..e996336215c0389a9385cb45edc497fa1f24c562 Binary files /dev/null and "b/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\217\221\345\270\203\345\272\224\347\224\250.png" differ diff --git "a/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\241\253\345\206\231MCPID.png" "b/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\241\253\345\206\231MCPID.png" new file mode 100644 index 0000000000000000000000000000000000000000..8c8e52a957b58c03dba10fa96238f63360b57c01 Binary files /dev/null and "b/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\241\253\345\206\231MCPID.png" differ diff --git "a/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\241\253\345\206\231MCP\344\277\241\346\201\257.png" "b/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\241\253\345\206\231MCP\344\277\241\346\201\257.png" new file mode 100644 index 0000000000000000000000000000000000000000..3b747315fa977c4e388d85ff9fd5f113da4f3579 Binary files /dev/null and "b/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\241\253\345\206\231MCP\344\277\241\346\201\257.png" differ diff --git "a/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\244\215\345\210\266MCPID\344\277\241\346\201\257.png" "b/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\244\215\345\210\266MCPID\344\277\241\346\201\257.png" new file mode 100644 index 
0000000000000000000000000000000000000000..f94162a4122804abfa49bbaf669a1592cf08af1b Binary files /dev/null and "b/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\345\244\215\345\210\266MCPID\344\277\241\346\201\257.png" differ diff --git "a/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\350\260\203\350\257\225MCP.png" "b/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\350\260\203\350\257\225MCP.png" new file mode 100644 index 0000000000000000000000000000000000000000..40e9030f0fd268cd82a096c57a10c22bf2962145 Binary files /dev/null and "b/documents/user-guide/\344\275\277\347\224\250\346\214\207\345\215\227/\347\272\277\344\270\212\346\234\215\345\212\241/pictures/\350\260\203\350\257\225MCP.png" differ diff --git "a/documents/user-guide/\351\203\250\347\275\262\346\214\207\345\215\227/\346\227\240\347\275\221\347\273\234\347\216\257\345\242\203\344\270\213\351\203\250\347\275\262\346\214\207\345\215\227.md" "b/documents/user-guide/\351\203\250\347\275\262\346\214\207\345\215\227/\346\227\240\347\275\221\347\273\234\347\216\257\345\242\203\344\270\213\351\203\250\347\275\262\346\214\207\345\215\227.md" index 58fda3ef2b0042a11efd19b548dbc5be3fbc5d0b..aa009eaa0b4d471ce169cc2894406b2828e95c08 100644 --- "a/documents/user-guide/\351\203\250\347\275\262\346\214\207\345\215\227/\346\227\240\347\275\221\347\273\234\347\216\257\345\242\203\344\270\213\351\203\250\347\275\262\346\214\207\345\215\227.md" +++ "b/documents/user-guide/\351\203\250\347\275\262\346\214\207\345\215\227/\346\227\240\347\275\221\347\273\234\347\216\257\345\242\203\344\270\213\351\203\250\347\275\262\346\214\207\345\215\227.md" @@ -2,7 +2,7 @@ **版本信息** 当前版本:v0.9.6 -更新日期:2025年7月1日 +更新日期:2025年7月9日 ## 产品概述 @@ -54,21 +54,6 @@ openEuler Intelligence 是一款智能问答工具,使用 openEuler Intelligen ![部署图](./pictures/部署视图.png) -### 域名配置 - -需准备以下服务域名: - -- authhub 认证服务:`authhub.eulercopilot.local` -- openEuler Intelligence web 服务:`www.eulercopilot.local` - -```bash -# 本地 Windows 主机配置 -# 编辑 C:\Windows\System32\drivers\etc\hosts 添加记录 -# 替换 127.0.0.1 为目标服务器 IP -127.0.0.1 authhub.eulercopilot.local -127.0.0.1 www.eulercopilot.local -``` - ## 快速开始 ### 1. 资源获取 @@ -76,7 +61,7 @@ openEuler Intelligence 是一款智能问答工具,使用 openEuler Intelligen #### 获取部署脚本 [openEuler Intelligence 官方 Git 仓库](https://gitee.com/openeuler/euler-copilot-framework) -切换至 dev 分支下载 ZIP 并上传至目标服务器: +切换至 master 分支下载 ZIP 并上传至目标服务器: ```bash unzip euler-copilot-framework.tar -d /home @@ -177,7 +162,7 @@ bash deploy.sh **关键说明**: 1. 部署前需准备好所有资源 -2. 需输入 Authhub 和 openEuler Intelligence 域名(默认使用 `authhub.eulercopilot.local`, `www.eulercopilot.local`) +2. 需输入 openEuler Intelligence 的访问IP(默认使用 127.0.0.l`) #### 重启服务 @@ -253,7 +238,7 @@ k3s ctr image import $image_tar ## 验证安装 -恭喜您,**openEuler Intelligence** 已成功部署!为了开始您的体验,请在浏览器中输入 `https://www.eulercopilot.local` 访问 openEuler Intelligence 的网页: +恭喜您,**openEuler Intelligence** 已成功部署!为了开始您的体验,请在浏览器中输入 `https://$host:30080` 访问 openEuler Intelligence 的网页: 1. 首次访问点击 **立即注册** 创建账号 2. 
完成登录流程 diff --git "a/documents/user-guide/\351\203\250\347\275\262\346\214\207\345\215\227/\347\275\221\347\273\234\347\216\257\345\242\203\344\270\213\351\203\250\347\275\262\346\214\207\345\215\227.md" "b/documents/user-guide/\351\203\250\347\275\262\346\214\207\345\215\227/\347\275\221\347\273\234\347\216\257\345\242\203\344\270\213\351\203\250\347\275\262\346\214\207\345\215\227.md" index 98c4cb586fea9fdf49f368f34e0b956ef53247b4..1441f2041e942446f6539fe44066dbd68e025428 100644 --- "a/documents/user-guide/\351\203\250\347\275\262\346\214\207\345\215\227/\347\275\221\347\273\234\347\216\257\345\242\203\344\270\213\351\203\250\347\275\262\346\214\207\345\215\227.md" +++ "b/documents/user-guide/\351\203\250\347\275\262\346\214\207\345\215\227/\347\275\221\347\273\234\347\216\257\345\242\203\344\270\213\351\203\250\347\275\262\346\214\207\345\215\227.md" @@ -2,7 +2,7 @@ **版本信息** 当前版本:v0.9.6 -更新日期:2025年7月1日 +更新日期:2025年7月9日 ## 产品概述 @@ -53,21 +53,6 @@ openEuler Intelligence 是一款智能问答工具,使用 openEuler Intelligen ![部署图](./pictures/部署视图.png) -## 域名配置 - -需准备以下服务域名: - -- authhub 认证服务:`authhub.eulercopilot.local` -- openEuler Intelligence web 服务:`www.eulercopilot.local` - -```bash -# 本地 Windows 主机配置 -# 编辑 C:\Windows\System32\drivers\etc\hosts 添加记录 -# 替换 127.0.0.1 为目标服务器 IP -127.0.0.1 authhub.eulercopilot.local -127.0.0.1 www.eulercopilot.local -``` - ## 快速开始 ### 1. 获取部署脚本 @@ -76,7 +61,7 @@ openEuler Intelligence 是一款智能问答工具,使用 openEuler Intelligen ```bash cd /home -git clone https://gitee.com/openeuler/euler-copilot-framework.git -b dev +git clone https://gitee.com/openeuler/euler-copilot-framework.git -b master cd euler-copilot-framework/deploy/scripts chmod -R +x ./* ``` @@ -212,7 +197,7 @@ k3s ctr image import $image_tar ## 验证安装 -恭喜您,**openEuler Intelligence** 已成功部署!为了开始您的体验,请在浏览器中输入 `https://www.eulercopilot.local` 访问 openEuler Intelligence 的网页: +恭喜您,**openEuler Intelligence** 已成功部署!为了开始您的体验,请在浏览器中输入 `https://$host:30080` 访问 openEuler Intelligence 的网页: 1. 首次访问点击 **立即注册** 创建账号 2. 完成登录流程 diff --git a/pyproject.toml b/pyproject.toml index 681ea9fb4552650da9148a22dd55d2063dca2410..ac114cec839d232bddd64f76fbee15552652153b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,8 @@ url = "https://mirrors.aliyun.com/pypi/simple/" dev = [ "autodoc-pydantic==2.2.0", "coverage==7.7.1", + "fastapi-cprofile>=0.0.2", + "fastapi-profiler-lite>=0.3.2", "pytest==8.3.5", "pytest-mock==3.14.0", "ruff==0.11.2",