diff --git a/apps/models/blacklist.py b/apps/models/blacklist.py index baf42282dcbe15399473be8f24f666fcd6ac62d6..306f28bd0055438bb0a5966ae9b55cffda0cd6fa 100644 --- a/apps/models/blacklist.py +++ b/apps/models/blacklist.py @@ -1,9 +1,10 @@ """黑名单 数据库表结构""" +import uuid from datetime import datetime import pytz -from sqlalchemy import BigInteger, Boolean, DateTime, String +from sqlalchemy import BigInteger, Boolean, DateTime, ForeignKey, String from sqlalchemy.orm import Mapped, mapped_column from .base import Base @@ -15,6 +16,8 @@ class Blacklist(Base): __tablename__ = "framework_blacklist" id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True, init=False) """主键ID""" + recordId: Mapped[uuid.UUID] = mapped_column(ForeignKey("framework_record.id")) # noqa: N815 + """关联的问答对ID""" question: Mapped[str] = mapped_column(String(65535)) """黑名单问题""" answer: Mapped[str | None] = mapped_column(String(65535), default=None) diff --git a/apps/schemas/request_data.py b/apps/schemas/request_data.py index 7bdebf4fd9a2cf2f6224fe106b0e7a63ca025f45..491f2e3be84d5e97be4a0b6312142fac481ffff6 100644 --- a/apps/schemas/request_data.py +++ b/apps/schemas/request_data.py @@ -6,8 +6,6 @@ from typing import Any from pydantic import BaseModel, Field -from apps.common.config import config - from .appcenter import AppData from .enum_var import CommentType from .flow_topology import FlowItem @@ -22,22 +20,6 @@ class RequestDataApp(BaseModel): params: dict[str, Any] = Field(description="插件参数") -class MockRequestData(BaseModel): - """POST /api/mock/chat的请求体""" - - app_id: str = Field(default="", description="应用ID", alias="appId") - flow_id: str = Field(default="", description="流程ID", alias="flowId") - conversation_id: str = Field(..., description="会话ID", alias="conversationId") - question: str = Field(..., description="问题", alias="question") - - -class RequestDataFeatures(BaseModel): - """POST /api/chat的features字段数据""" - - max_tokens: int | None = Field(default=config.llm.max_tokens, description="最大生成token数") - context_num: int = Field(default=5, description="上下文消息数量", le=10, ge=0) - - class RequestData(BaseModel): """POST /api/chat 请求的总的数据结构""" @@ -124,12 +106,6 @@ class ModFavServiceRequest(BaseModel): favorited: bool = Field(..., description="是否收藏") -class ClientSessionData(BaseModel): - """客户端Session信息""" - - session_id: str | None = Field(default=None) - - class ModifyConversationData(BaseModel): """修改会话信息""" diff --git a/apps/services/appcenter.py b/apps/services/appcenter.py index 9ad2d107f2c5494be5ec71154529e56bb78f8e68..9a0daa552518cd93905d57fc1dbff67dbe24dd42 100644 --- a/apps/services/appcenter.py +++ b/apps/services/appcenter.py @@ -6,16 +6,18 @@ import uuid from datetime import UTC, datetime from typing import Any -from apps.common.mongo import MongoDB +from sqlalchemy import select + +from apps.common.postgres import postgres from apps.constants import SERVICE_PAGE_SIZE from apps.exceptions import InstancePermissionError +from apps.models.app import App from apps.models.user import User from apps.scheduler.pool.loader.app import AppLoader from apps.schemas.agent import AgentAppMetadata from apps.schemas.appcenter import AppCenterCardItem, AppData, AppPermissionData from apps.schemas.enum_var import AppFilterType, AppType, PermissionType from apps.schemas.flow import AppMetadata, MetadataType, Permission -from apps.schemas.pool import AppPool from apps.schemas.response_data import RecentAppList, RecentAppListItem from .flow import FlowManager @@ -154,21 +156,24 @@ class AppCenterManager: return app_cards, total_apps + @staticmethod - async def fetch_app_data_by_id(app_id: str) -> AppPool: + async def fetch_app_data_by_id(app_id: str) -> App: """ - 根据应用ID获取应用元数据 + 根据应用ID获取应用元数据(使用PostgreSQL) :param app_id: 应用唯一标识 - :return: 应用元数据 + :return: 应用数据 """ - mongo = MongoDB() - app_collection = mongo.get_collection("app") - db_data = await app_collection.find_one({"_id": app_id}) - if not db_data: - msg = "应用不存在" - raise ValueError(msg) - return AppPool.model_validate(db_data) + async with postgres.session() as session: + app_obj = (await session.scalars( + select(App).where(App.id == app_id), + )).one_or_none() + if not app_obj: + msg = f"[AppCenterManager] 应用不存在: {app_id}" + raise ValueError(msg) + return app_obj + @staticmethod async def create_app(user_sub: str, data: AppData) -> str: @@ -188,6 +193,7 @@ class AppCenterManager: ) return app_id + @staticmethod async def update_app(user_sub: str, app_id: str, data: AppData) -> None: """ @@ -213,6 +219,7 @@ class AppCenterManager: app_data=app_data, ) + @staticmethod async def update_app_publish_status(app_id: str, user_sub: str) -> bool: """ @@ -233,12 +240,16 @@ class AppCenterManager: break # 更新数据库 - mongo = MongoDB() - app_collection = mongo.get_collection("app") - await app_collection.update_one( - {"_id": app_id}, - {"$set": {"published": published}}, - ) + async with postgres.session() as session: + app_obj = (await session.scalars( + select(App).where(App.id == app_id), + )).one_or_none() + if not app_obj: + msg = f"[AppCenterManager] 应用不存在: {app_id}" + raise ValueError(msg) + app_obj.isPublished = published + session.add(app_obj) + await session.commit() await AppCenterManager._process_app_and_save( app_type=app_data.app_type, @@ -250,6 +261,7 @@ class AppCenterManager: return published + @staticmethod async def modify_favorite_app(app_id: str, user_sub: str, *, favorited: bool) -> None: """ diff --git a/apps/services/blacklist.py b/apps/services/blacklist.py index 06f3450ebc7206e5ccaecd7030fa7e50939aca5f..871d03fda4ded919abc5d6fdd08e31bdd6f5496e 100644 --- a/apps/services/blacklist.py +++ b/apps/services/blacklist.py @@ -2,13 +2,16 @@ """黑名单相关操作""" import logging -import re +import uuid -from apps.common.mongo import MongoDB +from sqlalchemy import and_, select + +from apps.common.postgres import postgres from apps.common.security import Security +from apps.models.blacklist import Blacklist +from apps.models.record import Record from apps.models.user import User -from apps.schemas.collection import Blacklist -from apps.schemas.record import Record, RecordContent +from apps.schemas.record import RecordContent logger = logging.getLogger(__name__) @@ -19,21 +22,20 @@ class QuestionBlacklistManager: @staticmethod async def check_blacklisted_questions(input_question: str) -> bool: """给定问题,查找问题是否在黑名单里""" - try: - blacklist_collection = MongoDB().get_collection("blacklist") - result = await blacklist_collection.find_one( - {"question": {"$regex": f"/{re.escape(input_question)}/i"}, "is_audited": True}, {"_id": 1}, - ) - if result: + async with postgres.session() as session: + # 查找是否有审核通过的黑名单问题包含输入的问题 + blacklist_item = (await session.scalars(select(Blacklist).where( + and_( + Blacklist.question.ilike(f"%{input_question}%"), + Blacklist.isAudited == True, # noqa: E712 + ), + ).limit(1))).one_or_none() + + if blacklist_item: # 用户输入的问题中包含黑名单问题的一部分,故拉黑 logger.info("[QuestionBlacklistManager] 问题在黑名单中") return False - except Exception: - # 访问数据库异常 - logger.exception("[QuestionBlacklistManager] 检查问题黑名单失败") - return False - else: - return True + return True @staticmethod async def change_blacklisted_questions( @@ -44,40 +46,33 @@ class QuestionBlacklistManager: is_deletion标识是否为删除操作 """ - try: - blacklist_collection = MongoDB().get_collection("blacklist") + async with postgres.session() as session: + # 根据ID查找黑名单记录 + blacklist_item = await session.get(Blacklist, int(blacklist_id)) + + if not blacklist_item: + logger.info("[QuestionBlacklistManager] 黑名单记录不存在") + return False if is_deletion: - await blacklist_collection.find_one_and_delete({"_id": blacklist_id}) + await session.delete(blacklist_item) logger.info("[QuestionBlacklistManager] 问题从黑名单中删除") - return True + else: + # 修改 + blacklist_item.question = question + blacklist_item.answer = answer + logger.info("[QuestionBlacklistManager] 问题在黑名单中修改") - # 修改 - await blacklist_collection.find_one_and_update( - {"_id": blacklist_id}, {"$set": {"question": question, "answer": answer}}, - ) - logger.info("[QuestionBlacklistManager] 问题在黑名单中修改") - except Exception: - # 数据库操作异常 - logger.exception("[QuestionBlacklistManager] 修改问题黑名单失败") - # 放弃执行后续操作 - return False - else: + await session.commit() return True @staticmethod async def get_blacklisted_questions(limit: int, offset: int, *, is_audited: bool) -> list[Blacklist]: """分页式获取目前所有的问题(待审核或已拉黑)黑名单""" - try: - blacklist_collection = MongoDB().get_collection("blacklist") - return [ - Blacklist.model_validate(item) - async for item in blacklist_collection.find({"is_audited": is_audited}).skip(offset).limit(limit) - ] - except Exception: - logger.exception("[QuestionBlacklistManager] 查询问题黑名单失败") - # 异常 - return [] + async with postgres.session() as session: + return list((await session.scalars(select(Blacklist).where( + Blacklist.isAudited == is_audited, + ).offset(offset).limit(limit))).all()) class UserBlacklistManager: @@ -86,32 +81,24 @@ class UserBlacklistManager: @staticmethod async def get_blacklisted_users(limit: int, offset: int) -> list[str]: """获取当前所有黑名单用户""" - try: - user_collection = MongoDB().get_collection("user") - return [ - user["_id"] - async for user in user_collection.find({"credit": {"$lte": 0}}, {"_id": 1}) - .sort({"_id": 1}) - .skip(offset) - .limit(limit) - ] - except Exception: - logger.exception("[UserBlacklistManager] 查询用户黑名单失败") - return [] + async with postgres.session() as session: + return list((await session.scalars(select(User.userSub).where( + User.credit <= 0, + ).order_by(User.userSub).offset(offset).limit(limit))).all()) @staticmethod async def check_blacklisted_users(user_sub: str) -> bool: """检测某用户是否已被拉黑""" - try: - user_collection = MongoDB().get_collection("user") - result = await user_collection.find_one( - {"user_sub": user_sub, "credit": {"$lte": 0}, "is_whitelisted": False}, {"_id": 1}, - ) - except Exception: - logger.exception("[UserBlacklistManager] 检查用户黑名单失败") - return False - else: - if result is not None: + async with postgres.session() as session: + user = (await session.scalars(select(User).where( + and_( + User.userSub == user_sub, + User.credit <= 0, + User.isWhitelisted == False, # noqa: E712 + ), + ).limit(1))).one_or_none() + + if user: logger.info("[UserBlacklistManager] 用户在黑名单中") return True return False @@ -119,29 +106,28 @@ class UserBlacklistManager: @staticmethod async def change_blacklisted_users(user_sub: str, credit_diff: int, credit_limit: int = 100) -> bool: """修改用户的信用分""" - try: + async with postgres.session() as session: # 获取用户当前信用分 - user_collection = MongoDB().get_collection("user") - result = await user_collection.find_one({"user_sub": user_sub}, {"_id": 0, "credit": 1}) + user = (await session.scalars(select(User).where(User.userSub == user_sub))).one_or_none() + # 用户不存在 - if result is None: + if user is None: logger.info("[UserBlacklistManager] 用户不存在") return False - result = User.model_validate(result) # 用户已被加白,什么都不做 - if result.is_whitelisted: + if user.isWhitelisted: return False - if result.credit > 0 and credit_diff > 0: + if user.credit > 0 and credit_diff > 0: logger.info("[UserBlacklistManager] 用户已解禁") return True - if result.credit <= 0 and credit_diff < 0: + if user.credit <= 0 and credit_diff < 0: logger.info("[UserBlacklistManager] 用户已封禁") return True # 给当前用户的信用分加上偏移量 - new_credit = result.credit + credit_diff + new_credit = user.credit + credit_diff # 不得超过积分上限 if new_credit > credit_limit: new_credit = credit_limit @@ -150,12 +136,8 @@ class UserBlacklistManager: new_credit = 0 # 更新用户信用分 - await user_collection.update_one({"user_sub": user_sub}, {"$set": {"credit": new_credit}}) - except Exception: - # 数据库错误 - logger.exception("[UserBlacklistManager] 修改用户黑名单失败") - return False - else: + user.credit = new_credit + await session.commit() return True @@ -163,68 +145,66 @@ class AbuseManager: """用户举报相关操作""" @staticmethod - async def change_abuse_report(user_sub: str, record_id: str, reason_type: str, reason: str) -> bool: + async def change_abuse_report(user_sub: str, record_id: uuid.UUID, reason_type: str, reason: str) -> bool: """存储用户举报详情""" - try: - # 判断record_id是否合法 - record_group_collection = MongoDB().get_collection("record_group") - record = await record_group_collection.aggregate( - [ - {"$match": {"user_sub": user_sub}}, - {"$unwind": "$records"}, - {"$match": {"records.id": record_id}}, - {"$limit": 1}, - ], - ) + async with postgres.session() as session: + record = (await session.scalars(select(Record).where( + and_( + Record.userSub == user_sub, + Record.id == record_id, + ), + ))).one_or_none() - record = await record.to_list(length=1) if not record: logger.info("[AbuseManager] 举报记录不合法") return False # 获得Record明文内容 - record = Record.model_validate(record[0]["records"]) record_data = Security.decrypt(record.content, record.key) - record_data = RecordContent.model_validate_json(record_data) + record_content = RecordContent.model_validate_json(record_data) # 检查该条目类似内容是否已被举报过 - blacklist_collection = MongoDB().get_collection("question_blacklist") - query = await blacklist_collection.find_one({"_id": record_id}) - if query is not None: + existing_blacklist = (await session.scalars(select(Blacklist).where( + Blacklist.recordId == record_id, + ).limit(1))).one_or_none() + + if existing_blacklist is not None: logger.info("[AbuseManager] 问题已被举报过") return True # 增加新条目 new_blacklist = Blacklist( - _id=record_id, - is_audited=False, - question=record_data.question, - answer=record_data.answer, - reason_type=reason_type, + recordId=record_id, + question=record_content.question, + answer=record_content.answer, + isAudited=False, + reasonType=reason_type, reason=reason, ) - await blacklist_collection.insert_one(new_blacklist.model_dump(by_alias=True)) - except Exception: - logger.exception("[AbuseManager] 修改用户举报失败") - return False - else: + session.add(new_blacklist) + await session.commit() return True @staticmethod - async def audit_abuse_report(question_id: str, *, is_deletion: bool = False) -> bool: + async def audit_abuse_report(record_id: uuid.UUID, *, is_deletion: bool = False) -> bool: """对某一特定的待审问题进行操作,包括批准审核与删除未审问题""" - try: - blacklist_collection = MongoDB().get_collection("blacklist") + async with postgres.session() as session: + blacklist_item = (await session.scalars(select(Blacklist).where( + and_( + Blacklist.recordId == record_id, + Blacklist.isAudited == False, # noqa: E712 + ), + ))).one_or_none() + + if not blacklist_item: + logger.info("[AbuseManager] 待审核问题不存在") + return False + if is_deletion: - await blacklist_collection.delete_one({"_id": question_id, "is_audited": False}) - return True - await blacklist_collection.update_one( - {"_id": question_id, "is_audited": False}, - {"$set": {"is_audited": True}}, - ) - except Exception: - logger.exception("[AbuseManager] 审核用户举报失败") - return False - else: + await session.delete(blacklist_item) + else: + blacklist_item.isAudited = True + + await session.commit() return True diff --git a/apps/services/conversation.py b/apps/services/conversation.py index 9fd1bc29ef77c1895572c5ac3b63e8148d2e41e2..49d2d78db4c66ef00b2363f4a02c722dcc2a4b8e 100644 --- a/apps/services/conversation.py +++ b/apps/services/conversation.py @@ -7,7 +7,7 @@ from datetime import datetime from typing import Any import pytz -from sqlalchemy import func, select +from sqlalchemy import and_, func, select from apps.common.postgres import postgres from apps.models.conversation import Conversation @@ -27,8 +27,10 @@ class ConversationManager: async with postgres.session() as session: result = (await session.scalars( select(Conversation).where( - Conversation.userSub == user_sub, - Conversation.isTemporary == False, # noqa: E712 + and_( + Conversation.userSub == user_sub, + Conversation.isTemporary == False, # noqa: E712 + ), ).order_by( Conversation.createdAt.desc(), ), @@ -42,8 +44,10 @@ class ConversationManager: async with postgres.session() as session: return (await session.scalars( select(Conversation).where( - Conversation.id == conversation_id, - Conversation.userSub == user_sub, + and_( + Conversation.id == conversation_id, + Conversation.userSub == user_sub, + ), ), )).one_or_none() @@ -54,8 +58,10 @@ class ConversationManager: async with postgres.session() as session: result = (await session.scalars( func.count(Conversation.id).where( - Conversation.id == conversation_id, - Conversation.userSub == user_sub, + and_( + Conversation.id == conversation_id, + Conversation.userSub == user_sub, + ), ), )).one() return bool(result) @@ -79,7 +85,12 @@ class ConversationManager: # 如果是非调试模式且app_id存在,更新App的使用情况 if app_id and not debug: app_obj = (await session.scalars( - select(UserAppUsage).where(UserAppUsage.userSub == user_sub, UserAppUsage.appId == app_id), + select(UserAppUsage).where( + and_( + UserAppUsage.userSub == user_sub, + UserAppUsage.appId == app_id, + ), + ), )).one_or_none() if app_obj: # 假设App模型有last_used和usage_count字段(如没有请根据实际表结构调整) @@ -110,8 +121,10 @@ class ConversationManager: async with postgres.session() as session: conv = (await session.scalars( select(Conversation).where( - Conversation.id == conversation_id, - Conversation.userSub == user_sub, + and_( + Conversation.id == conversation_id, + Conversation.userSub == user_sub, + ), ), )).one_or_none() if not conv: @@ -129,8 +142,10 @@ class ConversationManager: async with postgres.session() as session: conv = (await session.scalars( select(Conversation).where( - Conversation.id == conversation_id, - Conversation.userSub == user_sub, + and_( + Conversation.id == conversation_id, + Conversation.userSub == user_sub, + ), ), )).one_or_none() if not conv: @@ -148,8 +163,10 @@ class ConversationManager: async with postgres.session() as session: result = (await session.scalars( func.count(Conversation.id).where( - Conversation.id == conversation_id, - Conversation.userSub == user_sub, + and_( + Conversation.id == conversation_id, + Conversation.userSub == user_sub, + ), ), )).one() return bool(result) diff --git a/apps/services/document.py b/apps/services/document.py index bd6c6f935c9b1609b5116623ed2390dd81ea0068..fde7465d0ef01c9a18399460d6288242acfb8f5c 100644 --- a/apps/services/document.py +++ b/apps/services/document.py @@ -7,7 +7,7 @@ import uuid import asyncer from fastapi import UploadFile -from sqlalchemy import func, select +from sqlalchemy import and_, func, select from apps.common.minio import MinioClient from apps.common.postgres import postgres @@ -91,8 +91,10 @@ class DocumentManager: async with postgres.session() as session: conv = (await session.scalars( select(ConversationDocument).where( - ConversationDocument.conversationId == conversation_id, - ConversationDocument.isUnused.is_(True), + and_( + ConversationDocument.conversationId == conversation_id, + ConversationDocument.isUnused.is_(True), + ), ), )).all() if not conv: @@ -175,7 +177,12 @@ class DocumentManager: async with postgres.session() as session: for doc in document_list: doc_info = await session.scalars( - select(Document).where(Document.id == doc, Document.userSub == user_sub), + select(Document).where( + and_( + Document.id == doc, + Document.userSub == user_sub, + ), + ), ) if not doc_info: logger.error("[DocumentManager] 文件不存在: %s", doc) @@ -183,8 +190,10 @@ class DocumentManager: conv_doc = await session.scalars( select(ConversationDocument).where( - ConversationDocument.documentId == doc, - ConversationDocument.isUnused.is_(True), + and_( + ConversationDocument.documentId == doc, + ConversationDocument.isUnused.is_(True), + ), ), ) if not conv_doc: @@ -238,7 +247,12 @@ class DocumentManager: """文件状态由unused改为used""" async with postgres.session() as session: conversation = (await session.scalars( - select(Conversation).where(Conversation.id == conversation_id, Conversation.userSub == user_sub), + select(Conversation).where( + and_( + Conversation.id == conversation_id, + Conversation.userSub == user_sub, + ), + ), )).one_or_none() if not conversation: logger.error("[DocumentManager] 对话不存在: %s", conversation_id) @@ -247,8 +261,10 @@ class DocumentManager: # 把unused_docs加入RecordGroup中,并与问题关联 docs = (await session.scalars( select(ConversationDocument).where( - ConversationDocument.conversationId == conversation_id, - ConversationDocument.isUnused.is_(True), + and_( + ConversationDocument.conversationId == conversation_id, + ConversationDocument.isUnused.is_(True), + ), ), )).all() if not docs: @@ -265,7 +281,12 @@ class DocumentManager: async with postgres.session() as session: # 查询对应的Record record = (await session.scalars( - select(Record).where(Record.id == record_id, Record.userSub == user_sub), + select(Record).where( + and_( + Record.id == record_id, + Record.userSub == user_sub, + ), + ), )).one_or_none() if not record: logger.error("[DocumentManager] 记录不存在或非当前用户: %s", record_id) @@ -279,6 +300,6 @@ class DocumentManager: if doc: doc.isUnused = False - doc.associated = ConvDocAssociated.answer + doc.associated = ConvDocAssociated.ANSWER doc.recordId = record_id await session.commit() diff --git a/apps/services/llm.py b/apps/services/llm.py index aa50d22a7dddf9dc13e9101bb407bb69afec0a23..7143e5dfff10e5d0ad853a55ffb095305a66c811 100644 --- a/apps/services/llm.py +++ b/apps/services/llm.py @@ -3,8 +3,12 @@ import logging +from sqlalchemy import and_, select + from apps.common.config import config -from apps.schemas.collection import LLM +from apps.common.postgres import postgres +from apps.models.llm import LLMData +from apps.models.user import User from apps.schemas.request_data import ( UpdateLLMReq, ) @@ -35,24 +39,28 @@ class LLMManager: provider_list.append(item) return provider_list + @staticmethod - async def get_llm_id_by_conversation_id(user_sub: str, conversation_id: str) -> str: + async def get_llm_id_by_user_id(user_sub: str) -> int | None: """ - 通过对话ID获取大模型ID + 通过用户ID获取大模型ID :param user_sub: 用户ID - :param conversation_id: 对话ID :return: 大模型ID """ - mongo = MongoDB() - conv_collection = mongo.get_collection("conversation") - result = await conv_collection.find_one({"_id": conversation_id, "user_sub": user_sub}) - if not result: - return "" - return result.get("llm", {}).get("llm_id", "") + async with postgres.session() as session: + user = (await session.scalars( + select(User).where(User.userSub == user_sub), + )).one_or_none() + if not user: + logger.error("[LLMManager] 用户 %s 不存在", user_sub) + return None + + return user.selectedLLM + @staticmethod - async def get_llm_by_id(user_sub: str, llm_id: str) -> LLM: + async def get_llm_by_id(user_sub: str, llm_id: int) -> LLMData | None: """ 通过ID获取大模型 @@ -60,16 +68,23 @@ class LLMManager: :param llm_id: 大模型ID :return: 大模型对象 """ - llm_collection = MongoDB().get_collection("llm") - result = await llm_collection.find_one({"_id": llm_id, "user_sub": user_sub}) - if not result: - err = f"[LLMManager] LLM {llm_id} 不存在" - logger.error(err) - raise ValueError(err) - return LLM.model_validate(result) + async with postgres.session() as session: + llm = (await session.scalars( + select(LLMData).where( + and_( + LLMData.id == llm_id, + LLMData.userSub == user_sub, + ), + ), + )).one_or_none() + if not llm: + logger.error("[LLMManager] LLM %s 不存在", llm_id) + return None + return llm + @staticmethod - async def list_llm(user_sub: str, llm_id: str | None) -> list[LLMProviderInfo]: + async def list_llm(user_sub: str, llm_id: int | None) -> list[LLMProviderInfo]: """ 获取大模型列表 @@ -77,23 +92,26 @@ class LLMManager: :param llm_id: 大模型ID :return: 大模型列表 """ - mongo = MongoDB() - llm_collection = mongo.get_collection("llm") - - query = {"user_sub": user_sub} - if llm_id: - query["llm_id"] = llm_id - result = await llm_collection.find(query).sort({"created_at": 1}).to_list(length=None) - - llm_item = LLMProviderInfo( - llmId="empty", - icon=llm_provider_dict["ollama"]["icon"], - openaiBaseUrl=config.llm.endpoint, - openaiApiKey=config.llm.key, - modelName=config.llm.model, - maxTokens=config.llm.max_tokens, - isEditable=False, - ) + async with postgres.session() as session: + if llm_id: + llm_list = (await session.scalars( + select(LLMData).where( + and_( + LLMData.id == llm_id, + LLMData.userSub == user_sub, + ), + ), + )).all() + else: + llm_list = (await session.scalars( + select(LLMData).where(LLMData.userSub == user_sub), + )).all() + if not llm_list: + logger.error("[LLMManager] 无法找到用户 %s 的大模型", user_sub) + return [] + + # 默认大模型 + llm_item = LLMProviderInfo(llmId="empty") llm_list = [llm_item] for llm in result: llm_item = LLMProviderInfo( @@ -147,45 +165,47 @@ class LLMManager: await llm_collection.insert_one(llm.model_dump(by_alias=True)) return llm.id + @staticmethod - async def delete_llm(user_sub: str, llm_id: str) -> str: + async def delete_llm(user_sub: str, llm_id: int | None) -> None: """ 删除大模型 :param user_sub: 用户ID :param llm_id: 大模型ID - :return: 大模型ID """ - if llm_id == "empty": + if llm_id is None: err = "[LLMManager] 不能删除默认大模型" - logger.error(err) raise ValueError(err) - mongo = MongoDB() - llm_collection = mongo.get_collection("llm") - conv_collection = mongo.get_collection("conversation") - - conv_dict = await conv_collection.find_one({"llm.llm_id": llm_id, "user_sub": user_sub}) - if conv_dict: - await conv_collection.update_many( - {"_id": conv_dict["_id"], "user_sub": user_sub}, - {"$set": {"llm": { - "llm_id": "empty", - "icon": llm_provider_dict["ollama"]["icon"], - "model_name": config.llm.model, - }}}, - ) + async with postgres.session() as session: + llm = (await session.scalars( + select(LLMData).where( + and_( + LLMData.id == llm_id, + LLMData.userSub == user_sub, + ), + ), + )).one_or_none() + if not llm: + err = f"[LLMManager] LLM {llm_id} 不存在" + else: + await session.delete(llm) + await session.commit() + + async with postgres.session() as session: + user = (await session.scalars( + select(User).where(User.userSub == user_sub), + )).one_or_none() + if not user: + err = f"[LLMManager] 用户 {user_sub} 不存在" + raise ValueError(err) + user.selectedLLM = None + await session.commit() - llm_config = await llm_collection.find_one({"_id": llm_id, "user_sub": user_sub}) - if not llm_config: - err = f"[LLMManager] LLM {llm_id} 不存在" - logger.error(err) - raise ValueError(err) - await llm_collection.delete_one({"_id": llm_id, "user_sub": user_sub}) - return llm_id @staticmethod - async def update_conversation_llm( + async def update_user_llm( user_sub: str, conversation_id: str, llm_id: str, @@ -218,11 +238,15 @@ class LLMManager: logger.error(err_msg) raise ValueError(err_msg) - llm_item = LLMItem( - llm_id=llm_id, - model_name=llm_dict["model_name"], + llm_item = LLMData( + userSub=user_sub, icon=llm_dict["icon"], + openaiBaseUrl=llm_dict["openai_base_url"], + openaiAPIKey=llm_dict["openai_api_key"], + modelName=llm_dict["model_name"], + maxToken=llm_dict["max_tokens"], ) + await conv_collection.update_one( {"_id": conversation_id, "user_sub": user_sub}, {"$set": {"llm": llm_item.model_dump(by_alias=True)}}, diff --git a/apps/services/mcp_service.py b/apps/services/mcp_service.py index 7cb880c0f08b1cc7727bde528be39ac748f07ae2..80abf5f48876bef780ac11318b11734123ab1fc0 100644 --- a/apps/services/mcp_service.py +++ b/apps/services/mcp_service.py @@ -10,7 +10,6 @@ from fastapi import UploadFile from PIL import Image from sqids.sqids import Sqids -from apps.common.mongo import MongoDB from apps.constants import ( ALLOWED_ICON_MIME_TYPES, ICON_PATH, diff --git a/apps/services/personal_token.py b/apps/services/personal_token.py index aaf463c01e9a1ca21314df79f57e7ff1d910a9d5..4a7205ae5abbfafd51fcda908c36191e47f6e550 100644 --- a/apps/services/personal_token.py +++ b/apps/services/personal_token.py @@ -5,7 +5,7 @@ import hashlib import logging import uuid -from sqlalchemy import func, select, update +from sqlalchemy import select, update from apps.common.postgres import postgres from apps.models.user import User @@ -38,28 +38,6 @@ class PersonalTokenManager: return result - @staticmethod - async def verify_personal_token(personal_token: str) -> bool: - """ - 验证Personal Token,用于FastAPI dependency - - :param personal_token: Personal Token - :return: 验证Personal Token是否成功 - """ - async with postgres.session() as session: - try: - result = ( - await session.scalars( - select(func.count(User.id)).where(User.personalToken == personal_token), - ) - ).one() - except Exception: - logger.exception("[PersonalTokenManager] 验证Personal Token失败") - return False - else: - return result > 0 - - @staticmethod async def update_personal_token(user_sub: str) -> str | None: """ diff --git a/apps/services/rag.py b/apps/services/rag.py index 875ce42d33becfdb7330100f41da3182ef47c4d1..8e12514564e3cbbd3bab34970f5c3ffa275a9974 100644 --- a/apps/services/rag.py +++ b/apps/services/rag.py @@ -12,7 +12,6 @@ from apps.common.config import config from apps.llm.patterns.rewrite import QuestionRewrite from apps.llm.reasoning import ReasoningLLM from apps.llm.token import TokenCalculator -from apps.schemas.collection import LLM from apps.schemas.config import LLMConfig from apps.schemas.enum_var import EventType from apps.schemas.rag_data import RAGQueryReq diff --git a/apps/services/record.py b/apps/services/record.py index eb6c557f14217424df19c439b23faba20e9dee06..798ba8d821064a2e3b37375986b6b1f2f8d79bc2 100644 --- a/apps/services/record.py +++ b/apps/services/record.py @@ -2,8 +2,15 @@ """问答对Manager""" import logging +import uuid +from datetime import UTC, datetime from typing import Literal +from sqlalchemy import and_, select + +from apps.common.postgres import postgres +from apps.models.conversation import Conversation +from apps.models.record import Record as PgRecord from apps.schemas.record import Record logger = logging.getLogger(__name__) @@ -13,60 +20,65 @@ class RecordManager: """问答对相关操作""" @staticmethod - async def create_record_group(group_id: str, user_sub: str, conversation_id: str, task_id: str) -> str | None: - """创建问答组""" - mongo = MongoDB() - record_group_collection = mongo.get_collection("record_group") - conversation_collection = mongo.get_collection("conversation") - record_group = RecordGroup( - _id=group_id, - user_sub=user_sub, - conversation_id=conversation_id, - task_id=task_id, - ) + async def verify_record_in_conversation(record_id: uuid.UUID, user_sub: str, conversation_id: uuid.UUID) -> bool: + """ + 校验指定record_id是否属于指定用户和会话(PostgreSQL实现) - try: - async with mongo.get_session() as session, await session.start_transaction(): - # RecordGroup里面加一条记录 - await record_group_collection.insert_one(record_group.model_dump(by_alias=True), session=session) - # Conversation里面加一个ID - await conversation_collection.update_one( - {"_id": conversation_id}, {"$push": {"record_groups": group_id}}, session=session, - ) - except Exception: - logger.exception("[RecordManager] 创建问答组失败") - return None + :param record_id: 记录ID + :param user_sub: 用户sub + :param conversation_id: 会话ID + :return: 是否存在 + """ + async with postgres.session() as session: + result = (await session.scalars( + select(PgRecord).where( + and_( + PgRecord.id == record_id, + PgRecord.userSub == user_sub, + PgRecord.conversationId == conversation_id, + ), + ))).one_or_none() + return result is not None - return group_id @staticmethod - async def insert_record_data_into_record_group(user_sub: str, group_id: str, record: Record) -> str | None: - """加密问答对,并插入MongoDB中的特定问答组""" - mongo = MongoDB() - group_collection = mongo.get_collection("record_group") - try: - await group_collection.update_one( - {"_id": group_id, "user_sub": user_sub}, - {"$push": {"records": record.model_dump(by_alias=True)}}, - ) - except Exception: - logger.exception("[RecordManager] 插入加密问答对失败") - return None - else: - return record.id + async def insert_record_data(user_sub: str, conversation_id: uuid.UUID, record: Record) -> uuid.UUID | None: + """Record插入PostgreSQL""" + async with postgres.session() as session: + conv = (await session.scalars( + select(Conversation).where( + and_( + Conversation.id == conversation_id, + Conversation.userSub == user_sub, + ), + ), + )).one_or_none() + if not conv: + logger.error("[RecordManager] 对话不存在: %s", conversation_id) + return None + + session.add(PgRecord( + id=record.id, + conversationId=conversation_id, + taskId=record.task_id, + userSub=user_sub, + content=record.content, + key=record.key, + createdAt=datetime.fromtimestamp(record.created_at, tz=UTC), + )) + await session.commit() + + return record.id + @staticmethod async def query_record_by_conversation_id( user_sub: str, - conversation_id: str, + conversation_id: uuid.UUID, total_pairs: int | None = None, order: Literal["desc", "asc"] = "desc", ) -> list[Record]: - """ - 查询ConversationID的最后n条问答对 - - 每个record_group只取最后一条record - """ + """查询ConversationID的最后n条问答对""" sort_order = -1 if order == "desc" else 1 mongo = MongoDB() @@ -104,27 +116,3 @@ class RecordManager: return [] else: return records - - @staticmethod - async def query_record_group_by_conversation_id( - conversation_id: str, total_pairs: int | None = None, - ) -> list[RecordGroup]: - """ - 查询对话ID的最后n条问答组 - - 包含全部record_group及其关联的record - """ - record_group_collection = MongoDB().get_collection("record_group") - try: - pipeline = [ - {"$match": {"conversation_id": conversation_id}}, - {"$sort": {"created_at": -1}}, - ] - if total_pairs is not None: - pipeline.append({"$limit": total_pairs}) - - records = await record_group_collection.aggregate(pipeline) - return [RecordGroup.model_validate(record) async for record in records] - except Exception: - logger.exception("[RecordManager] 查询问答组失败") - return [] diff --git a/apps/services/service.py b/apps/services/service.py index ab27ce68f8d18d6d2ccc46099235cce21a147bba..ee78963f55a241134e0d0ab8c7710ab213af07c0 100644 --- a/apps/services/service.py +++ b/apps/services/service.py @@ -7,7 +7,7 @@ from typing import Any import yaml from anyio import Path -from sqlalchemy import select +from sqlalchemy import and_, select from apps.common.config import config from apps.common.postgres import postgres @@ -127,8 +127,10 @@ class ServiceCenterManager: async with postgres.session() as session: db_service = (await session.scalars( select(Service).where( - Service.name == validated_data.id, - Service.description == validated_data.description, + and_( + Service.name == validated_data.id, + Service.description == validated_data.description, + ), ), )).one_or_none()