From 61015025a17d017e04fa174978ed2cee412086a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=E5=8D=9A?= <1016318004@qq.com> Date: Thu, 23 Jan 2025 20:48:30 +0800 Subject: [PATCH 01/13] =?UTF-8?q?=E4=BF=AE=E6=AD=A3user=5Fsub=E8=8E=B7?= =?UTF-8?q?=E5=BE=97=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/routers/conversation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/apps/routers/conversation.py b/apps/routers/conversation.py index 80f36d7e..6caca27a 100644 --- a/apps/routers/conversation.py +++ b/apps/routers/conversation.py @@ -45,7 +45,7 @@ router = APIRouter( async def create_new_conversation( - user_sub: str, + user_sub: Annotated[str,Depends(get_user)], conv_list: list[Conversation], app_id: str = "", is_debug: bool = False, @@ -124,11 +124,11 @@ async def get_conversation_list(user_sub: Annotated[str, Depends(get_user)]): # @router.post("", dependencies=[Depends(verify_csrf_token)], response_model=AddConversationRsp) -async def add_conversation( +async def add_conversation( # noqa: ANN201 user_sub: Annotated[str, Depends(get_user)], appId: Optional[str] = None, # noqa: N803 isDebug: Optional[bool] = None, # noqa: N803 -): +): """手动创建新对话""" conversations = await ConversationManager.get_conversation_by_user_sub(user_sub) # 尝试创建新对话 -- Gitee From f2c78ad58fa79a2363b08966d63befe79d8c32d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=E5=8D=9A?= <1016318004@qq.com> Date: Fri, 24 Jan 2025 00:31:10 +0800 Subject: [PATCH 02/13] =?UTF-8?q?=E5=AE=8C=E5=96=84chat=E6=8E=A5=E5=8F=A3?= =?UTF-8?q?=EF=BC=8C=E4=BF=AE=E5=A4=8D=E5=BA=94=E7=94=A8=E4=B8=AD=E5=BF=83?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/manager/appcenter.py | 351 ++++++++++++++++++++++++++ apps/routers/appcenter.py | 286 +++++++++++++++++++++ apps/scheduler/scheduler/flow.py | 65 +---- apps/scheduler/scheduler/scheduler.py | 8 +- apps/service/suggestion.py | 4 +- 5 files changed, 649 insertions(+), 65 deletions(-) create mode 100644 apps/manager/appcenter.py create mode 100644 apps/routers/appcenter.py diff --git a/apps/manager/appcenter.py b/apps/manager/appcenter.py new file mode 100644 index 00000000..44e3d707 --- /dev/null +++ b/apps/manager/appcenter.py @@ -0,0 +1,351 @@ +"""应用中心 Manager + +Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. +""" +import uuid +from enum import Enum +from typing import Any, Optional + +from fastapi.encoders import jsonable_encoder + +from apps.constants import LOGGER +from apps.entities.appcenter import AppCenterCardItem, AppData +from apps.entities.enum_var import SearchType +from apps.entities.flow import Permission +from apps.entities.pool import AppPool +from apps.entities.response_data import RecentAppList, RecentAppListItem +from apps.models.mongo import MongoDB + + +class AppCenterManager: + """应用中心管理器""" + + class ModFavAppFlag(Enum): + """收藏应用标志""" + + SUCCESS = 0 + NOT_FOUND = 1 + BAD_REQUEST = 2 + INTERNAL_ERROR = 3 + + @staticmethod + async def fetch_all_apps( + user_sub: str, + search_type: SearchType, + keyword: Optional[str], + page: int, + page_size: int, + ) -> tuple[list[AppCenterCardItem], int]: + """获取所有应用列表""" + try: + # 搜索条件 + filters: dict[str, Any] = AppCenterManager._build_filters( + {"published": True}, + search_type, + keyword, + ) if keyword and search_type != SearchType.AUTHOR else {} + apps, total_pages = await AppCenterManager._search_apps_by_filter(filters, page, page_size) + return [ + AppCenterCardItem( + appId=app.id, + icon=app.icon, + name=app.name, + description=app.description, + author=app.author, + favorited=(user_sub in app.favorites), + published=app.published, + ) + for app in apps + ], total_pages + except Exception as e: + LOGGER.error(f"[AppCenterManager] Get app list failed: {e}") + return [], -1 + + @staticmethod + async def fetch_user_apps( + user_sub: str, + search_type: SearchType, + keyword: Optional[str], + page: int, + page_size: int, + ) -> tuple[list[AppCenterCardItem], int]: + """获取用户应用列表""" + try: + # 搜索条件 + base_filter = {"author": user_sub} + filters: dict[str, Any] = AppCenterManager._build_filters( + base_filter, + search_type, + keyword, + ) if keyword and search_type != SearchType.AUTHOR else base_filter + apps, total_pages = await AppCenterManager._search_apps_by_filter(filters, page, page_size) + return [ + AppCenterCardItem( + appId=app.id, + icon=app.icon, + name=app.name, + description=app.description, + author=app.author, + favorited=(user_sub in app.favorites), + published=app.published, + ) + for app in apps + ], total_pages + except Exception as e: + LOGGER.info(f"[AppCenterManager] Get app list by user failed: {e}") + return [], -1 + + @staticmethod + async def fetch_favorite_apps( + user_sub: str, + search_type: SearchType, + keyword: Optional[str], + page: int, + page_size: int, + ) -> tuple[list[AppCenterCardItem], int]: + """获取用户收藏的应用列表""" + try: + fav_app = await AppCenterManager._get_favorite_app_ids_by_user(user_sub) + # 搜索条件 + base_filter = { + "_id": {"$in": fav_app}, + "published": True, + } + filters: dict[str, Any] = AppCenterManager._build_filters( + base_filter, + search_type, + keyword, + ) if keyword else base_filter + apps, total_pages = await AppCenterManager._search_apps_by_filter(filters, page, page_size) + return [ + AppCenterCardItem( + appId=app.id, + icon=app.icon, + name=app.name, + description=app.description, + author=app.author, + favorited=True, + published=app.published, + ) + for app in apps + ], total_pages + except Exception as e: + LOGGER.info(f"[AppCenterManager] Get favorite app list failed: {e}") + return [], -1 + + @staticmethod + async def fetch_app_data_by_id(app_id: str) -> Optional[AppPool]: + """根据应用ID获取应用元数据""" + try: + app_collection = MongoDB.get_collection("app") + db_data = await app_collection.find_one({"_id": app_id}) + if not db_data: + return None + return AppPool.model_validate(db_data) + except Exception as e: + LOGGER.info(f"[AppCenterManager] Get app metadata by app_id failed: {e}") + return None + + @staticmethod + async def create_app(user_sub: str, data: AppData) -> Optional[str]: + """创建应用""" + app_id = str(uuid.uuid4()) + app = AppPool( + _id=app_id, + name=data.name, + description=data.description, + author=user_sub, + icon=data.icon, + links=data.links, + first_questions=data.first_questions, + history_len=data.history_len, + permission=Permission( + type=data.permission.type, + users=data.permission.users or [], + ), + ) + try: + app_collection = MongoDB.get_collection("app") + app_dict = jsonable_encoder(app) + await app_collection.insert_one(app_dict) + return app_id + except Exception as e: + LOGGER.error(f"[AppCenterManager] Create app failed: {e}") + return None + + @staticmethod + async def update_app(app_id: str, data: AppData) -> bool: + """更新应用""" + try: + app_collection = MongoDB.get_collection("app") + app_data = AppPool.model_validate(await app_collection.find_one({"_id": app_id})) + if not app_data: + return False + # 如果工作流ID列表不一致,则需要取消发布状态 + published_false_needed = {flow.id for flow in app_data.flows} != set(data.workflows) + update_data = { + "name": data.name, + "description": data.description, + "icon": data.icon, + "links": data.links, + "first_questions": data.first_questions, + "history_len": data.history_len, + "permission": Permission( + type=data.permission.type, + users=data.permission.users or [], + ), + } + if published_false_needed: + update_data["published"] = False + await app_collection.update_one({"_id": app_id}, {"$set": update_data}) + return True + except Exception as e: + LOGGER.error(f"[AppCenterManager] Update app failed: {e}") + return False + + @staticmethod + async def publish_app(app_id: str) -> bool: + """发布应用""" + try: + app_collection = MongoDB.get_collection("app") + await app_collection.update_one( + {"_id": app_id}, + {"$set": {"published": True}}, + ) + return True + except Exception as e: + LOGGER.error(f"[AppCenterManager] Publish app failed: {e}") + return False + + @staticmethod + async def modify_favorite_app(app_id: str, user_sub: str, *, favorited: bool) -> ModFavAppFlag: + """修改收藏状态""" + try: + app_collection = MongoDB.get_collection("app") + db_data = await app_collection.find_one({"_id": app_id}) + if not db_data: + return AppCenterManager.ModFavAppFlag.NOT_FOUND + + app_data = AppPool.model_validate(db_data) + already_favorited = user_sub in app_data.favorites + + # 只能收藏未收藏的 + if favorited and already_favorited: + return AppCenterManager.ModFavAppFlag.BAD_REQUEST + # 只能取消已收藏的 + if not favorited and not already_favorited: + return AppCenterManager.ModFavAppFlag.BAD_REQUEST + + if favorited: + await app_collection.update_one( + {"_id": app_id}, + {"$addToSet": {"favorites": user_sub}}, + upsert=True, + ) + else: + await app_collection.update_one( + {"_id": app_id}, + {"$pull": {"favorites": user_sub}}, + ) + return AppCenterManager.ModFavAppFlag.SUCCESS + except Exception as e: + LOGGER.error(f"[AppCenterManager] Modify favorite app failed: {e}") + return AppCenterManager.ModFavAppFlag.INTERNAL_ERROR + + @staticmethod + async def get_recently_used_apps(count: int, user_sub: str) -> Optional[RecentAppList]: + """获取用户最近使用的应用列表""" + try: + user_collection = MongoDB.get_collection("user") + app_collection = MongoDB.get_collection("app") + user_data = await user_collection.find_one({"_id": user_sub}, {"_id": 0, "app_usage": 1}) + if user_data and "app_usage" in user_data and user_data["app_usage"]: + usage_list = sorted( + user_data["app_usage"].items(), + key=lambda x: x[1]["last_used"], + reverse=True, + )[:count] + app_ids = [t[0] for t in usage_list] + apps = await app_collection.find( + {"_id": {"$in": app_ids}}, {"name": 1}).to_list(len(app_ids)) + app_map = {str(a["_id"]): a.get("name", "") for a in apps} + return RecentAppList(applications=[ + RecentAppListItem(appId=app_id, name=app_map.get(app_id, "")) + for app_id in app_ids + ]) + except Exception as e: + LOGGER.info(f"[AppCenterManager] Get recently used apps failed: {e}") + return None + + @staticmethod + async def delete_app(app_id: str, user_sub: str) -> bool: + """删除应用""" + try: + async with MongoDB.get_session() as session, await session.start_transaction(): + app_collection = MongoDB.get_collection("app") + await app_collection.delete_one({"_id": app_id}, session=session) + user_collection = MongoDB.get_collection("user") + await user_collection.update_one( + {"_id": user_sub}, + {"$unset": {f"app_usage.{app_id}": ""}}, + session=session, + ) + await session.commit_transaction() + return True + except Exception as e: + LOGGER.error(f"[AppCenterManager] Delete app failed: {e}") + return False + + @staticmethod + def _build_filters( + base_filters: dict[str, Any], + search_type: SearchType, + keyword: str, + ) -> dict[str, Any]: + search_filters = [ + {"name": {"$regex": keyword, "$options": "i"}}, + {"description": {"$regex": keyword, "$options": "i"}}, + {"author": {"$regex": keyword, "$options": "i"}}, + ] + if search_type == SearchType.ALL: + base_filters["$or"] = search_filters + elif search_type == SearchType.NAME: + base_filters["name"] = {"$regex": keyword, "$options": "i"} + elif search_type == SearchType.DESCRIPTION: + base_filters["description"] = {"$regex": keyword, "$options": "i"} + elif search_type == SearchType.AUTHOR: + base_filters["author"] = {"$regex": keyword, "$options": "i"} + return base_filters + + @staticmethod + async def _search_apps_by_filter( + search_conditions: dict[str, Any], + page: int, + page_size: int, + ) -> tuple[list[AppPool], int]: + """根据过滤条件搜索应用并计算总页数""" + try: + app_collection = MongoDB.get_collection("app") + total_apps = await app_collection.count_documents(search_conditions) + total_pages = (total_apps + page_size - 1) // page_size + db_data = await app_collection.find(search_conditions) \ + .sort("created_at", -1) \ + .skip((page - 1) * page_size) \ + .limit(page_size) \ + .to_list(length=page_size) + apps = [AppPool.model_validate(doc) for doc in db_data] + return apps, total_pages + except Exception as e: + LOGGER.info(f"[AppCenterManager] Search apps by filter failed: {e}") + return [], -1 + + @staticmethod + async def _get_favorite_app_ids_by_user(user_sub: str) -> list[str]: + """获取用户收藏的应用ID""" + try: + app_collection = MongoDB.get_collection("app") + cursor = app_collection.find({"favorites": user_sub}) + return [AppPool.model_validate(doc).id async for doc in cursor] + except Exception as e: + LOGGER.info(f"[AppCenterManager] Get favorite app ids by user_sub failed: {e}") + return [] diff --git a/apps/routers/appcenter.py b/apps/routers/appcenter.py new file mode 100644 index 00000000..6140dfed --- /dev/null +++ b/apps/routers/appcenter.py @@ -0,0 +1,286 @@ +"""FastAPI 应用中心相关路由 + +Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. +""" +from typing import Annotated, Optional, Union + +from fastapi import APIRouter, Body, Depends, Path, Query, status +from fastapi.responses import JSONResponse + +from apps.dependency.csrf import verify_csrf_token +from apps.dependency.user import get_user, verify_user +from apps.entities.appcenter import AppPermissionData +from apps.entities.enum_var import SearchType +from apps.entities.request_data import CreateAppRequest, ModFavAppRequest +from apps.entities.response_data import ( + BaseAppOperationMsg, + BaseAppOperationRsp, + GetAppListMsg, + GetAppListRsp, + GetAppPropertyMsg, + GetAppPropertyRsp, + GetRecentAppListRsp, + ModFavAppMsg, + ModFavAppRsp, + ResponseData, +) +from apps.manager.appcenter import AppCenterManager +from apps.manager.flow import FlowManager + +router = APIRouter( + prefix="/api/app", + tags=["appcenter"], + dependencies=[Depends(verify_user)], +) + + +@router.get("", response_model=Union[GetAppListRsp, ResponseData]) +async def get_applications( # noqa: ANN201, PLR0913 + user_sub: Annotated[str, Depends(get_user)], + my_app: Annotated[Optional[bool], Query(alias="createdByMe", description="筛选我创建的")] = None, + my_fav: Annotated[Optional[bool], Query(alias="favorited", description="筛选我收藏的")] = None, + search_type: Annotated[SearchType, Query(alias="searchType", description="搜索类型")] = SearchType.ALL, + keyword: Annotated[Optional[str], Query(alias="keyword", description="搜索关键字")] = None, + page: Annotated[int, Query(alias="page", ge=1, description="页码")] = 1, + page_size: Annotated[int, Query(alias="pageSize", ge=1, le=100, description="每页条数")] = 20, +): + """获取应用列表""" + if my_app and my_fav: # 只能同时使用一个过滤条件 + return ResponseData( + code=status.HTTP_400_BAD_REQUEST, + message="createdByMe 和 favorited 不能同时生效", + result={}, + ) + + app_cards, total_pages = [], -1 + if my_app: # 筛选我创建的 + app_cards, total_pages = await AppCenterManager.fetch_user_apps( + user_sub, search_type, keyword, page, page_size) + elif my_fav: # 筛选已收藏的 + app_cards, total_pages = await AppCenterManager.fetch_favorite_apps( + user_sub, search_type, keyword, page, page_size) + else: # 获取所有应用 + app_cards, total_pages = await AppCenterManager.fetch_all_apps( + user_sub, search_type, keyword, page, page_size) + if total_pages == -1: + return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="查询失败", + result={}, + ).model_dump(exclude_none=True, by_alias=True)) + return JSONResponse(status_code=status.HTTP_200_OK, content=GetAppListRsp( + code=status.HTTP_200_OK, + message="查询成功", + result=GetAppListMsg( + currentPage=page, + totalPages=total_pages, + applications=app_cards, + ), + ).model_dump(exclude_none=True, by_alias=True)) + + +@router.post("", dependencies=[Depends(verify_csrf_token)], response_model=Union[BaseAppOperationRsp, ResponseData]) +async def create_or_update_application( # noqa: ANN201 + request: Annotated[CreateAppRequest, Body(...)], + user_sub: Annotated[str, Depends(get_user)], +): + """创建或更新应用""" + app_id = request.app_id + if app_id: + # 更新应用 + confirm = await AppCenterManager.update_app(app_id, request) + if not confirm: + return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="更新失败", + result={}, + ).model_dump(exclude_none=True, by_alias=True)) + return JSONResponse(status_code=status.HTTP_200_OK, content=BaseAppOperationRsp( + code=status.HTTP_200_OK, + message="更新成功", + result=BaseAppOperationMsg(appId=app_id), + ).model_dump(exclude_none=True, by_alias=True)) + # 创建应用 + app_id = await AppCenterManager.create_app(user_sub, request) + if not app_id: + return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="创建失败", + result={}, + ).model_dump(exclude_none=True, by_alias=True)) + return JSONResponse(status_code=status.HTTP_200_OK, content=BaseAppOperationRsp( + code=status.HTTP_200_OK, + message="创建成功", + result=BaseAppOperationMsg(appId=app_id), + ).model_dump(exclude_none=True, by_alias=True)) + + +@router.get("/recent", response_model=Union[GetRecentAppListRsp, ResponseData]) +async def get_recently_used_applications( # noqa: ANN201 + user_sub: Annotated[str, Depends(get_user)], + count: Annotated[int, Query(ge=1, le=10)] = 5, +): + """获取最近使用的应用""" + recent_apps = await AppCenterManager.get_recently_used_apps( + count, user_sub) + if not recent_apps: + return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="查询失败", + result={}, + ).model_dump(exclude_none=True, by_alias=True)) + return JSONResponse(status_code=status.HTTP_200_OK, content=GetRecentAppListRsp( + code=status.HTTP_200_OK, + message="查询成功", + result=recent_apps, + ).model_dump(exclude_none=True, by_alias=True)) + + +@router.get("/{appId}", response_model=Union[GetAppPropertyRsp, ResponseData]) +async def get_application( # noqa: ANN201 + app_id: Annotated[str, Path(..., alias="appId", description="应用ID")], +): + """获取应用详情""" + app_data = await AppCenterManager.fetch_app_data_by_id(app_id) + if not app_data: + return JSONResponse(status_code=status.HTTP_404_NOT_FOUND, content=ResponseData( + code=status.HTTP_404_NOT_FOUND, + message="找不到应用", + result={}, + ).model_dump(exclude_none=True, by_alias=True)) + workflows = [flow.id for flow in app_data.flows] + return JSONResponse(status_code=status.HTTP_200_OK, content=GetAppPropertyRsp( + code=status.HTTP_200_OK, + message="查询成功", + result=GetAppPropertyMsg( + appId=app_data.id, + published=app_data.published, + name=app_data.name, + description=app_data.description, + icon=app_data.icon, + links=app_data.links, + recommendedQuestions=app_data.first_questions, + dialogRounds=app_data.history_len, + permission=AppPermissionData( + visibility=app_data.permission.type, + authorizedUsers=app_data.permission.users, + ), + workflows=workflows, + ), + ).model_dump(exclude_none=True, by_alias=True)) + + +@router.delete( + "/{appId}", + dependencies=[Depends(verify_csrf_token)], + response_model=Union[BaseAppOperationRsp, ResponseData], +) +async def delete_application( # noqa: ANN201 + app_id: Annotated[str, Path(..., alias="appId", description="应用ID")], + user_sub: Annotated[str, Depends(get_user)], +): + """删除应用""" + app_data = await AppCenterManager.fetch_app_data_by_id(app_id) + if not app_data: + return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content=ResponseData( + code=status.HTTP_400_BAD_REQUEST, + message="查询失败", + result={}, + ).model_dump(exclude_none=True, by_alias=True)) + # 校验应用作者是否为当前用户 + if app_data.author != user_sub: + return JSONResponse(status_code=status.HTTP_403_FORBIDDEN, content=ResponseData( + code=status.HTTP_403_FORBIDDEN, + message="无权删除他人创建的应用", + result={}, + ).model_dump(exclude_none=True, by_alias=True)) + # 删除应用相关的工作流 + for flow in app_data.flows: + if not await FlowManager.delete_flow_by_app_and_flow_id(app_id, flow.id): + return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message=f"删除应用下属工作流 {flow.id} 失败", + result={}, + ).model_dump(exclude_none=True, by_alias=True)) + # 删除应用 + if not await AppCenterManager.delete_app(app_id, user_sub): + return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="删除失败", + result={}, + ).model_dump(exclude_none=True, by_alias=True)) + return JSONResponse(status_code=status.HTTP_200_OK, content=BaseAppOperationRsp( + code=status.HTTP_200_OK, + message="删除成功", + result=BaseAppOperationMsg(appId=app_id), + ).model_dump(exclude_none=True, by_alias=True)) + + +@router.post("/{appId}", response_model=BaseAppOperationRsp) +async def publish_application( # noqa: ANN201 + app_id: Annotated[str, Path(..., alias="appId", description="应用ID")], + user_sub: Annotated[str, Depends(get_user)], +): + """发布应用""" + app_data = await AppCenterManager.fetch_app_data_by_id(app_id) + if not app_data: + return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content=ResponseData( + code=status.HTTP_400_BAD_REQUEST, + message="查询失败", + result={}, + ).model_dump(exclude_none=True, by_alias=True)) + # 验证用户权限 + if app_data.author != user_sub: + return JSONResponse(status_code=status.HTTP_403_FORBIDDEN, content=ResponseData( + code=status.HTTP_403_FORBIDDEN, + message="无权发布他人创建的应用", + result={}, + ).model_dump(exclude_none=True, by_alias=True)) + # 发布应用 + if not await AppCenterManager.publish_app(app_id): + return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="发布失败", + result={}, + ).model_dump(exclude_none=True, by_alias=True)) + return JSONResponse(status_code=status.HTTP_200_OK, content=BaseAppOperationRsp( + code=status.HTTP_200_OK, + message="发布成功", + result=BaseAppOperationMsg(appId=app_id), + ).model_dump(exclude_none=True, by_alias=True)) + + +@router.put("/{appId}", dependencies=[Depends(verify_csrf_token)], response_model=Union[ModFavAppRsp, ResponseData]) +async def modify_favorite_application( # noqa: ANN201 + app_id: Annotated[str, Path(..., alias="appId", description="应用ID")], + request: Annotated[ModFavAppRequest, Body(...)], + user_sub: Annotated[str, Depends(get_user)], +): + """更改应用收藏状态""" + flag = await AppCenterManager.modify_favorite_app(app_id, user_sub, favorited=request.favorited) + if flag == AppCenterManager.ModFavAppFlag.NOT_FOUND: + return JSONResponse(status_code=status.HTTP_404_NOT_FOUND, content=ResponseData( + code=status.HTTP_404_NOT_FOUND, + message="找不到应用", + result={}, + ).model_dump(exclude_none=True, by_alias=True)) + if flag == AppCenterManager.ModFavAppFlag.BAD_REQUEST: + return JSONResponse(status_code=status.HTTP_400_BAD_REQUEST, content=ResponseData( + code=status.HTTP_400_BAD_REQUEST, + message="不可重复操作", + result={}, + ).model_dump(exclude_none=True, by_alias=True)) + if flag == AppCenterManager.ModFavAppFlag.INTERNAL_ERROR: + return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + message="操作失败", + result={}, + ).model_dump(exclude_none=True, by_alias=True)) + return JSONResponse(status_code=status.HTTP_200_OK, content=ModFavAppRsp( + code=status.HTTP_200_OK, + message="操作成功", + result=ModFavAppMsg( + appId=app_id, + favorited=request.favorited, + ), + ).model_dump(exclude_none=True, by_alias=True)) diff --git a/apps/scheduler/scheduler/flow.py b/apps/scheduler/scheduler/flow.py index 4804cce6..15520326 100644 --- a/apps/scheduler/scheduler/flow.py +++ b/apps/scheduler/scheduler/flow.py @@ -4,12 +4,12 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ from typing import Optional -from apps.entities.task import RequestDataPlugin +from apps.entities.task import RequestDataApp from apps.llm.patterns import Select from apps.scheduler.pool.pool import Pool -async def choose_flow(task_id: str, question: str, origin_plugin_list: list[RequestDataPlugin]) -> tuple[str, Optional[RequestDataPlugin]]: +async def choose_flow(task_id: str, question: str, origin_app: RequestDataApp) -> tuple[str, Optional[RequestDataApp]]: """依据用户的输入和选择,构造对应的Flow。 - 当用户没有选择任何Plugin时,直接进行智能问答 @@ -17,61 +17,8 @@ async def choose_flow(task_id: str, question: str, origin_plugin_list: list[Requ - 当用户选择Plugin时,在plugin内挑选最适合的flow :param question: 用户输入(用户问题) - :param origin_plugin_list: 用户选择的插件,可以一次选择多个 - :result: 经LLM选择的Plugin ID和Flow ID + :param origin_app: 用户选择的app信息 + :result: 经LLM选择的App ID和Flow ID """ - # 去掉无效的插件选项:plugin_id为空 - plugin_ids = [] - flow_ids = [] - for item in origin_plugin_list: - if not item.plugin_id: - continue - plugin_ids.append(item.plugin_id) - if item.flow_id: - flow_ids.append(item) - - # 用户什么都不选,直接智能问答 - if len(plugin_ids) == 0: - return "", None - - # 用户只选了auto - if len(plugin_ids) == 1 and plugin_ids[0] == "auto": - # 用户要求自动识别 - plugin_top = Pool().get_k_plugins(question) - # 聚合插件的Flow - plugin_ids = [str(plugin.name) for plugin in plugin_top] - - # 用户固定了Flow的ID - if len(flow_ids) > 0: - # 直接使用对应的Flow,不选择 - return plugin_ids[0], flow_ids[0] - - # 用户选了插件 - flows = Pool().get_k_flows(question, plugin_ids) - - # 使用大模型选择Top1 Flow - flow_list = [{ - "name": str(item.plugin) + "/" + str(item.name), - "description": str(item.description), - } for item in flows] - - if len(plugin_ids) == 1 and plugin_ids[0] == "auto": - # 用户选择自动识别时,包含智能问答 - flow_list += [{ - "name": "KnowledgeBase", - "description": "当上述工具无法直接解决用户问题时,使用知识库进行回答。", - }] - - # 返回top1 Flow的ID - selected_id = await Select().generate(task_id=task_id, choices=flow_list, question=question) - if selected_id == "KnowledgeBase": - return "", None - - plugin_id = selected_id.split("/")[0] - flow_id = selected_id.split("/")[1] - return plugin_id, RequestDataPlugin( - plugin_id=plugin_id, - flow_id=flow_id, - params={}, - auth={}, - ) + # TODO: 根据用户选择的App,选一次top_k flow + return "", None \ No newline at end of file diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index 4c078d07..888967fc 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -19,7 +19,7 @@ from apps.entities.plugin import ExecutorBackground, SysExecVars from apps.entities.rag_data import RAGQueryReq from apps.entities.record import RecordDocument from apps.entities.request_data import RequestData -from apps.entities.task import RequestDataPlugin +from apps.entities.task import RequestDataApp from apps.manager import ( DocumentManager, RecordManager, @@ -73,7 +73,7 @@ class Scheduler: # 捕获所有异常:出现问题就输出日志,并停止queue try: # 根据用户的请求,返回插件ID列表,选择Flow - self._plugin_id, user_selected_flow = await choose_flow(self._task_id, post_body.question, post_body.plugins) + self._plugin_id, user_selected_flow = await choose_flow(self._task_id, post_body.question, post_body.apps) # 获取当前问答可供关联的文档 docs, doc_ids = await self._get_docs(user_sub, post_body) # 获取上下文;最多20轮 @@ -122,7 +122,7 @@ class Scheduler: if need_recommend: routine_results = await asyncio.gather( generate_facts(self._task_id, post_body.question), - plan_next_flow(user_sub, self._task_id, self._queue, post_body.plugins), + plan_next_flow(user_sub, self._task_id, self._queue, post_body.apps), ) else: routine_results = await asyncio.gather(generate_facts(self._task_id, post_body.question)) @@ -139,7 +139,7 @@ class Scheduler: await self._queue.close() - async def run_executor(self, session_id: str, post_body: RequestData, background: ExecutorBackground, user_selected_flow: RequestDataPlugin) -> bool: + async def run_executor(self, session_id: str, post_body: RequestData, background: ExecutorBackground, user_selected_flow: RequestDataApp) -> bool: """构造FlowExecutor,并执行所选择的流""" # 获取当前Task task = await TaskManager.get_task(self._task_id) diff --git a/apps/service/suggestion.py b/apps/service/suggestion.py index b23e95e0..f6ff8da7 100644 --- a/apps/service/suggestion.py +++ b/apps/service/suggestion.py @@ -11,7 +11,7 @@ from apps.constants import LOGGER from apps.entities.collection import RecordContent from apps.entities.enum_var import EventType from apps.entities.message import SuggestContent -from apps.entities.task import RequestDataPlugin +from apps.entities.task import RequestDataApp from apps.llm.patterns.recommend import Recommend from apps.manager import ( RecordManager, @@ -28,7 +28,7 @@ USER_TOP_DOMAINS_NUM = 5 HISTORY_QUESTIONS_NUM = 4 -async def plan_next_flow(user_sub: str, task_id: str, queue: MessageQueue, user_selected_plugins: list[RequestDataPlugin]) -> None: # noqa: C901, PLR0912 +async def plan_next_flow(user_sub: str, task_id: str, queue: MessageQueue, user_selected_plugins: list[RequestDataApp]) -> None: # noqa: C901, PLR0912 """生成用户“下一步”Flow的推荐。 - 若Flow的配置文件中已定义`next_flow[]`字段,则直接使用该字段给定的值 -- Gitee From eff32cef05830c8b4c3aa985584385e830f48661 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=E5=8D=9A?= <1016318004@qq.com> Date: Fri, 24 Jan 2025 00:37:00 +0800 Subject: [PATCH 03/13] =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E5=BA=94=E7=94=A8?= =?UTF-8?q?=E4=B8=AD=E5=BF=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/manager/appcenter.py | 7 ------- apps/routers/appcenter.py | 18 ------------------ 2 files changed, 25 deletions(-) diff --git a/apps/manager/appcenter.py b/apps/manager/appcenter.py index 0e086087..44e3d707 100644 --- a/apps/manager/appcenter.py +++ b/apps/manager/appcenter.py @@ -6,11 +6,8 @@ import uuid from enum import Enum from typing import Any, Optional -<<<<<<< HEAD from fastapi.encoders import jsonable_encoder -======= ->>>>>>> bd0b2f5b96d780f9db7be62d24d5b065edd0871f from apps.constants import LOGGER from apps.entities.appcenter import AppCenterCardItem, AppData from apps.entities.enum_var import SearchType @@ -169,12 +166,8 @@ class AppCenterManager: ) try: app_collection = MongoDB.get_collection("app") -<<<<<<< HEAD app_dict = jsonable_encoder(app) await app_collection.insert_one(app_dict) -======= - await app_collection.insert_one(app.model_dump(by_alias=True)) ->>>>>>> bd0b2f5b96d780f9db7be62d24d5b065edd0871f return app_id except Exception as e: LOGGER.error(f"[AppCenterManager] Create app failed: {e}") diff --git a/apps/routers/appcenter.py b/apps/routers/appcenter.py index 70eb7aba..6140dfed 100644 --- a/apps/routers/appcenter.py +++ b/apps/routers/appcenter.py @@ -37,21 +37,12 @@ router = APIRouter( @router.get("", response_model=Union[GetAppListRsp, ResponseData]) async def get_applications( # noqa: ANN201, PLR0913 user_sub: Annotated[str, Depends(get_user)], -<<<<<<< HEAD my_app: Annotated[Optional[bool], Query(alias="createdByMe", description="筛选我创建的")] = None, my_fav: Annotated[Optional[bool], Query(alias="favorited", description="筛选我收藏的")] = None, search_type: Annotated[SearchType, Query(alias="searchType", description="搜索类型")] = SearchType.ALL, keyword: Annotated[Optional[str], Query(alias="keyword", description="搜索关键字")] = None, page: Annotated[int, Query(alias="page", ge=1, description="页码")] = 1, page_size: Annotated[int, Query(alias="pageSize", ge=1, le=100, description="每页条数")] = 20, -======= - my_app: Annotated[Optional[bool], Query(None, alias="createdByMe", description="筛选我创建的")], - my_fav: Annotated[Optional[bool], Query(None, alias="favorited", description="筛选我收藏的")], - search_type: Annotated[SearchType, Query(SearchType.ALL, alias="searchType", description="搜索类型")], - keyword: Annotated[Optional[str], Query(None, alias="keyword", description="搜索关键字")], - page: Annotated[int, Query(1, alias="page", ge=1, description="页码")], - page_size: Annotated[int, Query(20, alias="pageSize", ge=1, le=100, description="每页条数")], ->>>>>>> bd0b2f5b96d780f9db7be62d24d5b065edd0871f ): """获取应用列表""" if my_app and my_fav: # 只能同时使用一个过滤条件 @@ -126,13 +117,8 @@ async def create_or_update_application( # noqa: ANN201 @router.get("/recent", response_model=Union[GetRecentAppListRsp, ResponseData]) async def get_recently_used_applications( # noqa: ANN201 -<<<<<<< HEAD user_sub: Annotated[str, Depends(get_user)], count: Annotated[int, Query(ge=1, le=10)] = 5, -======= - count: Annotated[int, Query(5, ge=1, le=10)], - user_sub: Annotated[str, Depends(get_user)], ->>>>>>> bd0b2f5b96d780f9db7be62d24d5b065edd0871f ): """获取最近使用的应用""" recent_apps = await AppCenterManager.get_recently_used_apps( @@ -230,11 +216,7 @@ async def delete_application( # noqa: ANN201 ).model_dump(exclude_none=True, by_alias=True)) -<<<<<<< HEAD @router.post("/{appId}", response_model=BaseAppOperationRsp) -======= -@router.post("/{appId}", dependencies=[Depends(verify_csrf_token)], response_model=BaseAppOperationRsp) ->>>>>>> bd0b2f5b96d780f9db7be62d24d5b065edd0871f async def publish_application( # noqa: ANN201 app_id: Annotated[str, Path(..., alias="appId", description="应用ID")], user_sub: Annotated[str, Depends(get_user)], -- Gitee From b5f0f23e193b98b7bd73b34a5996cf0b4274293d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=E5=8D=9A?= <1016318004@qq.com> Date: Fri, 24 Jan 2025 14:23:38 +0800 Subject: [PATCH 04/13] =?UTF-8?q?=E5=AE=8C=E5=96=84chat=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/queue.py | 11 +- apps/entities/request_data.py | 4 +- apps/scheduler/scheduler/message.py | 26 ++-- apps/scheduler/scheduler/scheduler.py | 7 +- apps/service/suggestion.py | 181 +++++++++++++------------- 5 files changed, 117 insertions(+), 112 deletions(-) diff --git a/apps/common/queue.py b/apps/common/queue.py index 0aac111d..eab5b7aa 100644 --- a/apps/common/queue.py +++ b/apps/common/queue.py @@ -65,7 +65,7 @@ class MessageQueue: if not history_ids: # 如果new_history为空,则说明是第一次执行,创建一个空值 flow = MessageFlow( - appId=tcb.flow_state.plugin_id, + appId=tcb.flow_state.app_id, flowId=tcb.flow_state.name, stepId="start", stepStatus=StepStatus.RUNNING, @@ -75,7 +75,8 @@ class MessageQueue: history = tcb.flow_context[tcb.flow_state.step_id] flow = MessageFlow( - appId=history.plugin_id, + # TODO:appId 和 flowId 暂时使用flow_id + appId=history.flow_id, flowId=history.flow_id, stepId=history.step_id, stepStatus=history.status, @@ -86,9 +87,9 @@ class MessageQueue: message = MessageBase( event=event_type, id=tcb.record.id, - group_id=tcb.record.group_id, - conversation_id=tcb.record.conversation_id, - task_id=tcb.record.task_id, + groupId=tcb.record.group_id, + conversationId=tcb.record.conversation_id, + taskId=tcb.record.task_id, metadata=metadata, flow=flow, content=data, diff --git a/apps/entities/request_data.py b/apps/entities/request_data.py index 9061c4b6..68c28196 100644 --- a/apps/entities/request_data.py +++ b/apps/entities/request_data.py @@ -22,8 +22,8 @@ class RequestData(BaseModel): """POST /api/chat 请求的总的数据结构""" question: str = Field(max_length=2000, description="用户输入") - conversation_id: str - group_id: str + conversation_id: str = Field(alias= "conversationId", description="会话ID") + group_id: str = Field(alias= "groupId", description="会话分组ID") language: str = Field(default="zh", description="语言") files: list[str] = Field(default=[]) apps: list[RequestDataApp] = Field(default=[]) diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index fb7bd87b..9771b108 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -34,17 +34,17 @@ async def push_init_message(task_id: str, queue: MessageQueue, post_body: Reques # 组装feature if is_flow: feature = InitContentFeature( - max_tokens=post_body.features.max_tokens, - context_num=post_body.features.context_num, - enable_feedback=False, - enable_regenerate=False, + maxTokens=post_body.features.max_tokens, + contextNum=post_body.features.context_num, + enableFeedback=False, + enableRegenerate=False, ) else: feature = InitContentFeature( - max_tokens=post_body.features.max_tokens, - context_num=post_body.features.context_num, - enable_feedback=True, - enable_regenerate=True, + maxTokens=post_body.features.max_tokens, + contextNum=post_body.features.context_num, + enableFeedback=True, + enableRegenerate=True, ) # 保存必要信息到Task @@ -54,7 +54,7 @@ async def push_init_message(task_id: str, queue: MessageQueue, post_body: Reques await TaskManager.set_task(task_id, task) # 推送初始化消息 - await queue.push_output(event_type=EventType.INIT, data=InitContent(feature=feature, created_at=created_at).model_dump(exclude_none=True, by_alias=True)) + await queue.push_output(event_type=EventType.INIT, data=InitContent(feature=feature, createdAt=created_at).model_dump(exclude_none=True, by_alias=True)) async def push_rag_message(task_id: str, queue: MessageQueue, user_sub: str, rag_data: RAGQueryReq) -> None: @@ -108,9 +108,9 @@ async def _push_rag_chunk(task_id: str, queue: MessageQueue, content: str, rag_i async def push_document_message(queue: MessageQueue, doc: Union[RecordDocument, Document]) -> None: """推送文档消息""" content = DocumentAddContent( - document_id=doc.id, - document_name=doc.name, - document_type=doc.type, - document_size=round(doc.size, 2), + documentId=doc.id, + documentName=doc.name, + documentType=doc.type, + documentSize=round(doc.size, 2), ) await queue.push_output(event_type=EventType.DOCUMENT_ADD, data=content.model_dump(exclude_none=True, by_alias=True)) diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index 888967fc..0e7048bc 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -73,7 +73,8 @@ class Scheduler: # 捕获所有异常:出现问题就输出日志,并停止queue try: # 根据用户的请求,返回插件ID列表,选择Flow - self._plugin_id, user_selected_flow = await choose_flow(self._task_id, post_body.question, post_body.apps) + # self._plugin_id, user_selected_flow = await choose_flow(self._task_id, post_body.question, post_body.apps) + user_selected_flow = None # 获取当前问答可供关联的文档 docs, doc_ids = await self._get_docs(user_sub, post_body) # 获取上下文;最多20轮 @@ -89,7 +90,7 @@ class Scheduler: question=post_body.question, language=post_body.language, document_ids=doc_ids, - kb_sn=None if not user_info.kb_id else user_info.kb_id, + kb_sn=None if user_info is None or not user_info.kb_id else user_info.kb_id, history=context, top_k=5, ) @@ -121,7 +122,7 @@ class Scheduler: # 如果需要生成推荐问题,则生成 if need_recommend: routine_results = await asyncio.gather( - generate_facts(self._task_id, post_body.question), + # generate_facts(self._task_id, post_body.question), plan_next_flow(user_sub, self._task_id, self._queue, post_body.apps), ) else: diff --git a/apps/service/suggestion.py b/apps/service/suggestion.py index f6ff8da7..fbe48b9e 100644 --- a/apps/service/suggestion.py +++ b/apps/service/suggestion.py @@ -69,99 +69,102 @@ async def plan_next_flow(user_sub: str, task_id: str, queue: MessageQueue, user_ generated_questions += f"{question}\n" content = SuggestContent( question=question, - plugin_id="", - flow_id="", - flow_description="", + appId="", + flowId="", + flowDescription="", ) await queue.push_output(event_type=EventType.SUGGEST, data=content.model_dump(exclude_none=True, by_alias=True)) return # 当前使用了Flow flow_id = task.flow_state.name - plugin_id = task.flow_state.plugin_id - _, flow_data = Pool().get_flow(flow_id, plugin_id) - if flow_data is None: - err = "Flow数据不存在" - raise ValueError(err) - - if flow_data.next_flow is None: - # 根据用户选择的插件,选一次top_k flow - plugin_ids = [] - for plugin in user_selected_plugins: - if plugin.plugin_id and plugin.plugin_id not in plugin_ids: - plugin_ids.append(plugin.plugin_id) - result = Pool().get_k_flows(task.record.content.question, plugin_ids) - for i, flow in enumerate(result): - if i >= MAX_RECOMMEND: - break - # 改写问题 - rewrite_question = await Recommend().generate( - task_id=task_id, - action_description=flow.description, - history_questions=last_n_questions, - recent_question=current_record, - user_preference=str(user_domain), - shown_questions=generated_questions, - ) - generated_questions += f"{rewrite_question}\n" - - content = SuggestContent( - plugin_id=plugin_id, - flow_id=flow_id, - flow_description=str(flow.description), - question=rewrite_question, - ) - await queue.push_output(event_type=EventType.SUGGEST, data=content.model_dump(exclude_none=True, by_alias=True)) - return - - # 当前有next_flow - for i, next_flow in enumerate(flow_data.next_flow): - # 取前MAX_RECOMMEND个Flow,保持顺序 - if i >= MAX_RECOMMEND: - break - - if next_flow.plugin is not None: - next_flow_plugin_id = next_flow.plugin - else: - next_flow_plugin_id = plugin_id - - flow_metadata, _ = Pool().get_flow( - next_flow.id, - next_flow_plugin_id, - ) - - # flow不合法 - if flow_metadata is None: - LOGGER.error(f"Flow {next_flow.id} in {next_flow_plugin_id} not found") - continue - - # 如果设置了question,直接使用这个question - if next_flow.question is not None: - content = SuggestContent( - plugin_id=next_flow_plugin_id, - flow_id=next_flow.id, - flow_description=str(flow_metadata.description), - question=next_flow.question, - ) - await queue.push_output(event_type=EventType.SUGGEST, data=content.model_dump(exclude_none=True, by_alias=True)) - continue - - # 没有设置question,则需要生成问题 - rewrite_question = await Recommend().generate( - task_id=task_id, - action_description=flow_metadata.description, - history_questions=last_n_questions, - recent_question=current_record, - user_preference=str(user_domain), - shown_questions=generated_questions, - ) - generated_questions += f"{rewrite_question}\n" - content = SuggestContent( - plugin_id=next_flow_plugin_id, - flow_id=next_flow.id, - flow_description=str(flow_metadata.description), - question=rewrite_question, - ) - await queue.push_output(event_type=EventType.SUGGEST, data=content.model_dump(exclude_none=True, by_alias=True)) - continue + app_id = task.flow_state.app_id return + # TODO: 推荐flow待完善 + # _, flow_data = Pool().get_flow(flow_id, app_id) + # if flow_data is None: + # err = "Flow数据不存在" + # raise ValueError(err) + + # if flow_data.next_flow is None: + # # 根据用户选择的插件,选一次top_k flow + # app_ids = [] + # for plugin in user_selected_plugins: + # if plugin.app_id and plugin.app_id not in app_ids: + # app_ids.append(plugin.app_id) + # result = Pool().get_k_flows(task.record.content.question, app_ids) + # for i, flow in enumerate(result): + # if i >= MAX_RECOMMEND: + # break + # # 改写问题 + # rewrite_question = await Recommend().generate( + # task_id=task_id, + # action_description=flow.description, + # history_questions=last_n_questions, + # recent_question=current_record, + # user_preference=str(user_domain), + # shown_questions=generated_questions, + # ) + # generated_questions += f"{rewrite_question}\n" + + # content = SuggestContent( + # app_id=app_id, + # flow_id=flow_id, + # flow_description=str(flow.description), + # question=rewrite_question, + # ) + # await queue.push_output(event_type=EventType.SUGGEST, data=content.model_dump(exclude_none=True, by_alias=True)) + # return + + # # 当前有next_flow + # for i, next_flow in enumerate(flow_data.next_flow): + # # 取前MAX_RECOMMEND个Flow,保持顺序 + # if i >= MAX_RECOMMEND: + # break + + # if next_flow.plugin is not None: + # next_flow_app_id = next_flow.plugin + # else: + # next_flow_app_id = app_id + + # flow_metadata, _ = next_flow.id, next_flow_app_id, + # # flow_metadata, _ = Pool().get_flow( + # # next_flow.id, + # # next_flow_app_id, + # # ) + + # # flow不合法 + # if flow_metadata is None: + # LOGGER.error(f"Flow {next_flow.id} in {next_flow_app_id} not found") + # continue + + # # 如果设置了question,直接使用这个question + # if next_flow.question is not None: + # content = SuggestContent( + # appId=next_flow_app_id, + # flowId=next_flow.id, + # flowDescription=str(flow_metadata.description), + # question=next_flow.question, + # ) + # await queue.push_output(event_type=EventType.SUGGEST, data=content.model_dump(exclude_none=True, by_alias=True)) + # continue + + # # 没有设置question,则需要生成问题 + # rewrite_question = await Recommend().generate( + # task_id=task_id, + # action_description=flow_metadata.description, + # history_questions=last_n_questions, + # recent_question=current_record, + # user_preference=str(user_domain), + # shown_questions=generated_questions, + # ) + # generated_questions += f"{rewrite_question}\n" + # content = SuggestContent( + # appId=next_flow_app_id, + # flowId=next_flow.id, + # flowDescription=str(flow_metadata.description), + # question=rewrite_question, + # ) + # await queue.push_output(event_type=EventType.SUGGEST, data=content.model_dump(exclude_none=True, by_alias=True)) + # continue + # return -- Gitee From ef84ac7bde341fdf86c0f31eb34c3e186fe9fc1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=E5=8D=9A?= <1016318004@qq.com> Date: Fri, 24 Jan 2025 16:10:59 +0800 Subject: [PATCH 05/13] =?UTF-8?q?=E4=BF=AE=E6=AD=A3request?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/entities/request_data.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/apps/entities/request_data.py b/apps/entities/request_data.py index 68c28196..a436ee83 100644 --- a/apps/entities/request_data.py +++ b/apps/entities/request_data.py @@ -22,11 +22,11 @@ class RequestData(BaseModel): """POST /api/chat 请求的总的数据结构""" question: str = Field(max_length=2000, description="用户输入") - conversation_id: str = Field(alias= "conversationId", description="会话ID") - group_id: str = Field(alias= "groupId", description="会话分组ID") + conversation_id: str = Field(default=None, alias="conversationId", description="会话ID") + group_id: Optional[str] = Field(default=None, alias="groupId", description="群组ID") language: str = Field(default="zh", description="语言") - files: list[str] = Field(default=[]) - apps: list[RequestDataApp] = Field(default=[]) + files: list[str] = Field(default=[], description="文件列表") + apps: list[RequestDataApp] = Field(default=[], description="应用列表") features: RequestDataFeatures = Field(description="消息功能设置") -- Gitee From d3defb45567859b1ab0c768967ca47d55e9b53ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=E5=8D=9A?= <1016318004@qq.com> Date: Fri, 24 Jan 2025 16:13:36 +0800 Subject: [PATCH 06/13] =?UTF-8?q?=E4=BF=AE=E6=AD=A3chat=E6=B5=81=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/scheduler/scheduler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index 0e7048bc..79d92dec 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -125,11 +125,11 @@ class Scheduler: # generate_facts(self._task_id, post_body.question), plan_next_flow(user_sub, self._task_id, self._queue, post_body.apps), ) - else: - routine_results = await asyncio.gather(generate_facts(self._task_id, post_body.question)) + # else: + # routine_results = await asyncio.gather(generate_facts(self._task_id, post_body.question)) - # 保存事实信息 - self._facts = routine_results[0] + # # 保存事实信息 + # self._facts = routine_results[0] # 发送结束消息 await self._queue.push_output(event_type=EventType.DONE, data={}) -- Gitee From 2dec92c9092ba00009e0556aff8a8fa8267c3a35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=E5=8D=9A?= <1016318004@qq.com> Date: Fri, 24 Jan 2025 16:13:52 +0800 Subject: [PATCH 07/13] =?UTF-8?q?=E6=B5=8B=E8=AF=95=E4=BD=BF=E7=94=A8?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/run_api.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 tests/run_api.py diff --git a/tests/run_api.py b/tests/run_api.py new file mode 100644 index 00000000..cc13c8fd --- /dev/null +++ b/tests/run_api.py @@ -0,0 +1,19 @@ +from fastapi import FastAPI +from apps.routers.appcenter import router as appcenter_router +from apps.routers.chat import router as chat_router +from apps.routers.conversation import router as conversation_router +from apps.routers.auth import router as auth_router +from apps.routers.record import router as record_router + +app = FastAPI() + +# 注册路由 +app.include_router(chat_router) +app.include_router(conversation_router) +app.include_router(appcenter_router) +app.include_router(record_router) +# app.include_router(auth_router) + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file -- Gitee From 2ca78d15f7d7e52ce48aa79c1707b6f1f0062530 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=E5=8D=9A?= <1016318004@qq.com> Date: Fri, 24 Jan 2025 16:36:15 +0800 Subject: [PATCH 08/13] =?UTF-8?q?scheduler=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/scheduler/scheduler/scheduler.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index 79d92dec..4f32951f 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -202,7 +202,8 @@ class Scheduler: user_sub=user_sub, data=encrypt_data, key=encrypt_config, - facts=self._facts, + # TODO:暂时不保存facts,因为facts是动态的,需要根据答案生成 + facts=[""], metadata=task.record.metadata, created_at=task.record.created_at, flow=task.new_context, -- Gitee From 3420c55caf257baeb7b1d328e4a5c03cf4d91ed9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=E5=8D=9A?= <1016318004@qq.com> Date: Fri, 24 Jan 2025 16:46:40 +0800 Subject: [PATCH 09/13] =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 ++- apps/routers/appcenter.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 11b4659f..72be2dc7 100644 --- a/.gitignore +++ b/.gitignore @@ -10,4 +10,5 @@ apps/utils/init *.bak apps/embedding -logs \ No newline at end of file +logs +test/run_api.py \ No newline at end of file diff --git a/apps/routers/appcenter.py b/apps/routers/appcenter.py index c0acc3ce..8ba2f16c 100644 --- a/apps/routers/appcenter.py +++ b/apps/routers/appcenter.py @@ -217,7 +217,7 @@ async def delete_application( # noqa: ANN201 ).model_dump(exclude_none=True, by_alias=True)) -@router.post("/{appId}", response_model=BaseAppOperationRsp) +@router.post("/{appId}", dependencies=[Depends(verify_csrf_token)], response_model=BaseAppOperationRsp) async def publish_application( # noqa: ANN201 app_id: Annotated[str, Path(..., alias="appId", description="应用ID")], user_sub: Annotated[str, Depends(get_user)], -- Gitee From 9100fcf22906b532ee20d93e0de6f8296e10bfc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=E5=8D=9A?= Date: Fri, 24 Jan 2025 08:47:27 +0000 Subject: [PATCH 10/13] =?UTF-8?q?=E5=88=A0=E9=99=A4=E6=96=87=E4=BB=B6=20te?= =?UTF-8?q?sts/run=5Fapi.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/run_api.py | 19 ------------------- 1 file changed, 19 deletions(-) delete mode 100644 tests/run_api.py diff --git a/tests/run_api.py b/tests/run_api.py deleted file mode 100644 index cc13c8fd..00000000 --- a/tests/run_api.py +++ /dev/null @@ -1,19 +0,0 @@ -from fastapi import FastAPI -from apps.routers.appcenter import router as appcenter_router -from apps.routers.chat import router as chat_router -from apps.routers.conversation import router as conversation_router -from apps.routers.auth import router as auth_router -from apps.routers.record import router as record_router - -app = FastAPI() - -# 注册路由 -app.include_router(chat_router) -app.include_router(conversation_router) -app.include_router(appcenter_router) -app.include_router(record_router) -# app.include_router(auth_router) - -if __name__ == "__main__": - import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file -- Gitee From ca89fe920dab3ec1b1dad4f597aa3fa9cb71c01a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=E5=8D=9A?= <1016318004@qq.com> Date: Sun, 26 Jan 2025 15:53:50 +0800 Subject: [PATCH 11/13] showcase demo --- apps/dependency/user.py | 29 +++++++------- apps/entities/flow.py | 2 +- apps/entities/request_data.py | 2 +- apps/entities/response_data.py | 2 +- apps/manager/appcenter.py | 56 ++++++++++++++++++++++++++- apps/manager/conversation.py | 22 ++++++----- apps/manager/record.py | 1 + apps/models/mongo.py | 3 +- apps/routers/appcenter.py | 34 +++++++++------- apps/routers/chat.py | 37 ++++++++++++------ apps/routers/conversation.py | 24 +++++++++--- apps/routers/record.py | 9 +++-- apps/scheduler/scheduler/scheduler.py | 9 +++-- apps/service/suggestion.py | 2 +- 14 files changed, 163 insertions(+), 69 deletions(-) diff --git a/apps/dependency/user.py b/apps/dependency/user.py index 6c45ce8c..b12b5ba7 100644 --- a/apps/dependency/user.py +++ b/apps/dependency/user.py @@ -20,9 +20,10 @@ async def verify_user(request: HTTPConnection) -> None: :param request: HTTP请求 :return: """ - session_id = request.cookies["ECSESSION"] - if not await SessionManager.verify_user(session_id): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") + pass + # session_id = request.cookies["ECSESSION"] + # if not await SessionManager.verify_user(session_id): + # raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") async def get_session(request: HTTPConnection) -> str: """验证Session是否已鉴权,并返回Session ID;未鉴权则抛出HTTP 401;参数级dependence @@ -35,18 +36,20 @@ async def get_session(request: HTTPConnection) -> str: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") return session_id -async def get_user(request: HTTPConnection) -> str: - """验证Session是否已鉴权;若已鉴权,查询对应的user_sub;若未鉴权,抛出HTTP 401;参数级dependence +# async def get_user(request: HTTPConnection) -> str: +# """验证Session是否已鉴权;若已鉴权,查询对应的user_sub;若未鉴权,抛出HTTP 401;参数级dependence - :param request: HTTP请求体 - :return: 用户sub - """ - session_id = request.cookies["ECSESSION"] - user = await SessionManager.get_user(session_id) - if not user: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") - return user +# :param request: HTTP请求体 +# :return: 用户sub +# """ +# session_id = request.cookies["ECSESSION"] +# user = await SessionManager.get_user(session_id) +# if not user: +# raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication Error.") +# return user +async def get_user(request: HTTPConnection) -> str: + return "test" async def verify_api_key(api_key: str = Depends(oauth2_scheme)) -> None: """验证API Key是否有效;无效则抛出HTTP 401;接口级dependence diff --git a/apps/entities/flow.py b/apps/entities/flow.py index ad6caa1e..c674afb1 100644 --- a/apps/entities/flow.py +++ b/apps/entities/flow.py @@ -122,7 +122,7 @@ class AppLink(BaseModel): """App的相关链接""" title: str = Field(description="链接标题") - url: HttpUrl = Field(..., description="链接地址") + url: str = Field(..., description="链接地址") class Permission(BaseModel): diff --git a/apps/entities/request_data.py b/apps/entities/request_data.py index a436ee83..50ab6a67 100644 --- a/apps/entities/request_data.py +++ b/apps/entities/request_data.py @@ -26,7 +26,7 @@ class RequestData(BaseModel): group_id: Optional[str] = Field(default=None, alias="groupId", description="群组ID") language: str = Field(default="zh", description="语言") files: list[str] = Field(default=[], description="文件列表") - apps: list[RequestDataApp] = Field(default=[], description="应用列表") + app: list[RequestDataApp] = Field(default=[], description="应用列表") features: RequestDataFeatures = Field(description="消息功能设置") diff --git a/apps/entities/response_data.py b/apps/entities/response_data.py index e2c55800..534ba0ef 100644 --- a/apps/entities/response_data.py +++ b/apps/entities/response_data.py @@ -284,7 +284,7 @@ class GetAppListMsg(BaseModel): """GET /api/app Result数据结构""" page_number: int = Field(..., alias="currentPage", description="当前页码") - page_count: int = Field(..., alias="totalPages", description="总页数") + app_count: int = Field(..., alias="total", description="总页数") applications: list[AppCenterCardItem] = Field(..., description="应用列表") diff --git a/apps/manager/appcenter.py b/apps/manager/appcenter.py index f94866bb..ea8c9ced 100644 --- a/apps/manager/appcenter.py +++ b/apps/manager/appcenter.py @@ -2,6 +2,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. """ +from datetime import datetime, timezone import uuid from enum import Enum from typing import Any, Optional @@ -112,12 +113,16 @@ class AppCenterManager: "_id": {"$in": fav_app}, "published": True, } + print(base_filter) filters: dict[str, Any] = AppCenterManager._build_filters( base_filter, search_type, keyword, ) if keyword else base_filter + print(filters) + print(page, page_size) apps, total_pages = await AppCenterManager._search_apps_by_filter(filters, page, page_size) + print(apps) return [ AppCenterCardItem( appId=app.id, @@ -197,7 +202,8 @@ class AppCenterManager: } if published_false_needed: update_data["published"] = False - await app_collection.update_one({"_id": app_id}, {"$set": update_data}) + #TODO: 格式有問題 + await app_collection.update_one({"_id": app_id}, {"$set": jsonable_encoder(update_data)}) return True except Exception as e: LOGGER.error(f"[AppCenterManager] Update app failed: {e}") @@ -329,7 +335,9 @@ class AppCenterManager: try: app_collection = MongoDB.get_collection("app") total_apps = await app_collection.count_documents(search_conditions) - total_pages = (total_apps + page_size - 1) // page_size + # TODO: 暂时修改为 total_apps + total_pages = total_apps + # total_pages = (total_apps + page_size - 1) // page_size db_data = await app_collection.find(search_conditions) \ .sort("created_at", -1) \ .skip((page - 1) * page_size) \ @@ -351,3 +359,47 @@ class AppCenterManager: except Exception as e: LOGGER.info(f"[AppCenterManager] Get favorite app ids by user_sub failed: {e}") return [] + + @staticmethod + async def update_recent_app(user_sub: str, app_id: str) -> bool: + """更新用户的最近使用应用列表 + + :param user_sub: 用户唯一标识 + :param app_id: 应用唯一标识 + :return: 更新是否成功 + """ + try: + # 获取 user 集合 + user_collection = MongoDB.get_collection("user") + + # 获取当前时间戳 + current_time = round(datetime.now(tz=timezone.utc).timestamp(), 3) + + # 更新用户的 app_usage 字段 + result = await user_collection.update_one( + {"_id": user_sub}, # 查询条件 + { + "$set": { + f"app_usage.{app_id}.last_used": current_time # 更新最后使用时间 + }, + "$inc": { + f"app_usage.{app_id}.count": 1 # 增加使用次数 + } + }, + upsert=True # 如果 app_usage 字段或 app_id 不存在,则创建 + ) + + # 检查更新是否成功 + if result.modified_count > 0 or result.upserted_id is not None: + print("YES") + LOGGER.info(f"[AppCenterManager] Updated recent app for user {user_sub}: {app_id}") + return True + else: + print("NO") + LOGGER.warning(f"[AppCenterManager] No changes made for user {user_sub}") + return False + + except Exception as e: + print(e) + LOGGER.error(f"[AppCenterManager] Failed to update recent app: {e}") + return False \ No newline at end of file diff --git a/apps/manager/conversation.py b/apps/manager/conversation.py index f595d0c5..f816ef6b 100644 --- a/apps/manager/conversation.py +++ b/apps/manager/conversation.py @@ -31,6 +31,7 @@ class ConversationManager: try: conv_collection = MongoDB.get_collection("conversation") result = await conv_collection.find_one({"_id": conversation_id, "user_sub": user_sub}) + print(result, conversation_id, user_sub) if not result: return None return Conversation.model_validate(result) @@ -48,6 +49,7 @@ class ConversationManager: app_id=app_id, is_debug=is_debug, ) + print(conv) try: async with MongoDB.get_session() as session, await session.start_transaction(): conv_collection = MongoDB.get_collection("conversation") @@ -56,18 +58,20 @@ class ConversationManager: update_data: dict[str, dict[str, Any]] = { "$push": {"conversations": conversation_id}, } + if app_id: # 非调试模式下更新应用使用情况 - if not is_debug: - update_data["$set"] = {f"app_usage.{app_id}.last_used": round(datetime.now(timezone.utc).timestamp(), 3)} - update_data["$inc"] = {f"app_usage.{app_id}.count": 1} - await user_collection.update_one( - {"_id": user_sub}, - update_data, - session=session, - ) - await session.commit_transaction() + if not is_debug: + update_data["$set"] = {f"app_usage.{app_id}.last_used": round(datetime.now(timezone.utc).timestamp(), 3)} + update_data["$inc"] = {f"app_usage.{app_id}.count": 1} + await user_collection.update_one( + {"_id": user_sub}, + update_data, + session=session, + ) + await session.commit_transaction() return conv except Exception as e: + print(e) LOGGER.info(f"[ConversationManager] Add conversation by user_sub failed: {e}") return None diff --git a/apps/manager/record.py b/apps/manager/record.py index 54731ad9..a6079db3 100644 --- a/apps/manager/record.py +++ b/apps/manager/record.py @@ -37,6 +37,7 @@ class RecordManager: # Conversation里面加一个ID await conversation_collection.update_one({"_id": conversation_id}, {"$push": {"record_groups": group_id}}, session=session) except Exception as e: + print(e) LOGGER.info(f"Create record group failed: {e}") return None diff --git a/apps/models/mongo.py b/apps/models/mongo.py index 7cb25419..27d1ac65 100644 --- a/apps/models/mongo.py +++ b/apps/models/mongo.py @@ -21,8 +21,7 @@ class MongoDB: """MongoDB连接""" _client: AsyncMongoClient = AsyncMongoClient( - f"mongodb://{urllib.parse.quote_plus(config['MONGODB_USER'])}:{urllib.parse.quote_plus(config['MONGODB_PWD'])}@{config['MONGODB_HOST']}:{config['MONGODB_PORT']}/?directConnection=true&replicaSet=rs0", - ) + f"mongodb://{urllib.parse.quote_plus(config['MONGODB_USER'])}:{urllib.parse.quote_plus(config['MONGODB_PWD'])}@{config['MONGODB_HOST']}:{config['MONGODB_PORT']}/?directConnection=true&replicaSet=rs0", ) @classmethod def get_collection(cls, collection_name: str) -> AsyncCollection: diff --git a/apps/routers/appcenter.py b/apps/routers/appcenter.py index 8ba2f16c..45a1ee0f 100644 --- a/apps/routers/appcenter.py +++ b/apps/routers/appcenter.py @@ -5,10 +5,11 @@ Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. from typing import Annotated, Optional, Union from fastapi import APIRouter, Body, Depends, Path, Query, status +from fastapi.requests import HTTPConnection from fastapi.responses import JSONResponse -from apps.dependency.csrf import verify_csrf_token -from apps.dependency.user import get_user, verify_user +# from apps.dependency.csrf import verify_csrf_token +# from apps.dependency.user import get_user, verify_user from apps.entities.appcenter import AppPermissionData from apps.entities.enum_var import SearchType from apps.entities.request_data import CreateAppRequest, ModFavAppRequest @@ -30,9 +31,10 @@ from apps.manager.flow import FlowManager router = APIRouter( prefix="/api/app", tags=["appcenter"], - dependencies=[Depends(verify_user)], + # dependencies=[Depends(verify_user)], ) - +async def get_user(request: HTTPConnection) -> str: + return "test" @router.get("", response_model=Union[GetAppListRsp, ResponseData]) async def get_applications( # noqa: ANN201, PLR0913 @@ -53,34 +55,35 @@ async def get_applications( # noqa: ANN201, PLR0913 result={}, ) - app_cards, total_pages = [], -1 + app_cards, total_apps = [], -1 if my_app: # 筛选我创建的 - app_cards, total_pages = await AppCenterManager.fetch_user_apps( + app_cards, total_apps = await AppCenterManager.fetch_user_apps( user_sub, search_type, keyword, page, page_size) elif my_fav: # 筛选已收藏的 - app_cards, total_pages = await AppCenterManager.fetch_favorite_apps( + app_cards, total_apps = await AppCenterManager.fetch_favorite_apps( user_sub, search_type, keyword, page, page_size) else: # 获取所有应用 - app_cards, total_pages = await AppCenterManager.fetch_all_apps( + app_cards, total_apps = await AppCenterManager.fetch_all_apps( user_sub, search_type, keyword, page, page_size) - if total_pages == -1: + if total_apps == -1: return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( code=status.HTTP_500_INTERNAL_SERVER_ERROR, message="查询失败", result={}, ).model_dump(exclude_none=True, by_alias=True)) + #TODO: 返回总量数 return JSONResponse(status_code=status.HTTP_200_OK, content=GetAppListRsp( code=status.HTTP_200_OK, message="查询成功", result=GetAppListMsg( currentPage=page, - totalPages=total_pages, + total=total_apps, applications=app_cards, ), ).model_dump(exclude_none=True, by_alias=True)) -@router.post("", dependencies=[Depends(verify_csrf_token)], response_model=Union[BaseAppOperationRsp, ResponseData]) +@router.post("", response_model=Union[BaseAppOperationRsp, ResponseData]) async def create_or_update_application( # noqa: ANN201 request: Annotated[CreateAppRequest, Body(...)], user_sub: Annotated[str, Depends(get_user)], @@ -124,6 +127,7 @@ async def get_recently_used_applications( # noqa: ANN201 """获取最近使用的应用""" recent_apps = await AppCenterManager.get_recently_used_apps( count, user_sub) + print(recent_apps) if recent_apps is None: return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=ResponseData( code=status.HTTP_500_INTERNAL_SERVER_ERROR, @@ -173,7 +177,7 @@ async def get_application( # noqa: ANN201 @router.delete( "/{appId}", - dependencies=[Depends(verify_csrf_token)], + # dependencies=[Depends(verify_csrf_token)], response_model=Union[BaseAppOperationRsp, ResponseData], ) async def delete_application( # noqa: ANN201 @@ -217,7 +221,9 @@ async def delete_application( # noqa: ANN201 ).model_dump(exclude_none=True, by_alias=True)) -@router.post("/{appId}", dependencies=[Depends(verify_csrf_token)], response_model=BaseAppOperationRsp) +@router.post("/{appId}", + # dependencies=[Depends(verify_csrf_token)], + response_model=BaseAppOperationRsp) async def publish_application( # noqa: ANN201 app_id: Annotated[str, Path(..., alias="appId", description="应用ID")], user_sub: Annotated[str, Depends(get_user)], @@ -251,7 +257,7 @@ async def publish_application( # noqa: ANN201 ).model_dump(exclude_none=True, by_alias=True)) -@router.put("/{appId}", dependencies=[Depends(verify_csrf_token)], response_model=Union[ModFavAppRsp, ResponseData]) +@router.put("/{appId}", response_model=Union[ModFavAppRsp, ResponseData]) async def modify_favorite_application( # noqa: ANN201 app_id: Annotated[str, Path(..., alias="appId", description="应用ID")], request: Annotated[ModFavAppRequest, Body(...)], diff --git a/apps/routers/chat.py b/apps/routers/chat.py index ded2007c..3dced9e0 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -9,17 +9,19 @@ from collections.abc import AsyncGenerator from typing import Annotated from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.requests import HTTPConnection from fastapi.responses import JSONResponse, StreamingResponse from apps.common.queue import MessageQueue from apps.common.wordscheck import WordsCheck from apps.constants import LOGGER -from apps.dependency import ( - get_session, - get_user, - verify_csrf_token, - verify_user, -) + +# from apps.dependency import ( +# # get_session, +# # get_user, +# # verify_csrf_token, +# # verify_user, +# ) from apps.entities.request_data import RequestData from apps.entities.response_data import ResponseData from apps.manager import ( @@ -27,6 +29,7 @@ from apps.manager import ( TaskManager, UserBlacklistManager, ) +from apps.manager.appcenter import AppCenterManager from apps.scheduler.scheduler import Scheduler from apps.service.activity import Activity @@ -36,7 +39,8 @@ router = APIRouter( prefix="/api", tags=["chat"], ) - +async def get_user(request: HTTPConnection) -> str: + return "test" async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) -> AsyncGenerator[str, None]: """进行实际问答,并从MQ中获取消息""" @@ -112,11 +116,14 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) await Activity.remove_active(user_sub) -@router.post("/chat", dependencies=[Depends(verify_csrf_token), Depends(verify_user)]) +@router.post("/chat", + # dependencies=[Depends(verify_csrf_token), Depends(verify_user)] + ) async def chat( post_body: RequestData, user_sub: Annotated[str, Depends(get_user)], - session_id: Annotated[str, Depends(get_session)], + # session_id: Annotated[str, Depends(get_session)], + session_id : str = "1234567890" ) -> StreamingResponse: """LLM流式对话接口""" # 问题黑名单检测 @@ -126,9 +133,12 @@ async def chat( raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="question is blacklisted") # 限流检查 - if await Activity.is_active(user_sub): - raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many requests") + # if await Activity.is_active(user_sub): + # raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Too many requests") + # print(post_body.app) + if post_body.app and post_body.app[0].app_id: + await AppCenterManager.update_recent_app(user_sub, post_body.app[0].app_id) res = chat_generator(post_body, user_sub, session_id) return StreamingResponse( content=res, @@ -139,7 +149,10 @@ async def chat( ) -@router.post("/stop", response_model=ResponseData, dependencies=[Depends(verify_csrf_token)]) +@router.post("/stop", response_model=ResponseData, + # dependencies=[Depends(verify_csrf_token)] + ) + async def stop_generation(user_sub: Annotated[str, Depends(get_user)]): # noqa: ANN201 """停止生成""" await Activity.remove_active(user_sub) diff --git a/apps/routers/conversation.py b/apps/routers/conversation.py index 6caca27a..4b8d3806 100644 --- a/apps/routers/conversation.py +++ b/apps/routers/conversation.py @@ -5,12 +5,13 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. from datetime import datetime from typing import Annotated, Optional +from fastapi.requests import HTTPConnection import pytz from fastapi import APIRouter, Depends, Query, Request, status from fastapi.responses import JSONResponse from apps.constants import LOGGER -from apps.dependency import get_user, verify_csrf_token, verify_user +# from apps.dependency import get_user, verify_csrf_token, verify_user from apps.entities.collection import Audit, Conversation from apps.entities.request_data import ( DeleteConversationData, @@ -39,9 +40,11 @@ router = APIRouter( prefix="/api/conversation", tags=["conversation"], dependencies=[ - Depends(verify_user), + # Depends(verify_user), ], ) +async def get_user(request: HTTPConnection) -> str: + return "test" async def create_new_conversation( @@ -59,10 +62,11 @@ async def create_new_conversation( conv_records = await RecordManager.query_record_by_conversation_id(user_sub, last_conv.id, 1, "desc") if len(conv_records) > 0: create_new = True + # return last_conv # 新建对话 if create_new: - if not AppManager.validate_user_app_access(user_sub, app_id): + if app_id and not await AppManager.validate_user_app_access(user_sub, app_id): err = "Invalid app_id." raise RuntimeError(err) new_conv = await ConversationManager.add_conversation_by_user_sub(user_sub, @@ -123,7 +127,9 @@ async def get_conversation_list(user_sub: Annotated[str, Depends(get_user)]): # -@router.post("", dependencies=[Depends(verify_csrf_token)], response_model=AddConversationRsp) +@router.post("", + # dependencies=[Depends(verify_csrf_token)], + response_model=AddConversationRsp) async def add_conversation( # noqa: ANN201 user_sub: Annotated[str, Depends(get_user)], appId: Optional[str] = None, # noqa: N803 @@ -134,6 +140,8 @@ async def add_conversation( # noqa: ANN201 # 尝试创建新对话 try: app_id = appId if appId else "" + if appId: + conversations = [] is_debug = isDebug if isDebug is not None else False new_conv = await create_new_conversation(user_sub, conversations, app_id=app_id, is_debug=is_debug) @@ -157,7 +165,9 @@ async def add_conversation( # noqa: ANN201 ).model_dump(exclude_none=True, by_alias=True)) -@router.put("", response_model=UpdateConversationRsp, dependencies=[Depends(verify_csrf_token)]) +@router.put("", response_model=UpdateConversationRsp, + # dependencies=[Depends(verify_csrf_token)] + ) async def update_conversation( # noqa: ANN201 post_body: ModifyConversationData, conversationId: Annotated[str, Query()], # noqa: N803 @@ -206,7 +216,9 @@ async def update_conversation( # noqa: ANN201 ) -@router.delete("", response_model=ResponseData, dependencies=[Depends(verify_csrf_token)]) +@router.delete("", response_model=ResponseData, + # dependencies=[Depends(verify_csrf_token)] + ) async def delete_conversation(request: Request, post_body: DeleteConversationData, user_sub: Annotated[str, Depends(get_user)]): # noqa: ANN201 """删除特定对话""" deleted_conversation = [] diff --git a/apps/routers/record.py b/apps/routers/record.py index 361adc74..7c66e4c2 100644 --- a/apps/routers/record.py +++ b/apps/routers/record.py @@ -6,10 +6,11 @@ import json from typing import Annotated from fastapi import APIRouter, Depends, status +from fastapi.requests import HTTPConnection from fastapi.responses import JSONResponse from apps.common.security import Security -from apps.dependency import get_user, verify_user +# from apps.dependency import get_user, verify_user from apps.entities.collection import ( RecordContent, ) @@ -28,14 +29,16 @@ router = APIRouter( prefix="/api/record", tags=["record"], dependencies=[ - Depends(verify_user), + # Depends(verify_user), ], ) - +async def get_user(request: HTTPConnection) -> str: + return "test" @router.get("/{conversation_id}", response_model=RecordListRsp, responses={status.HTTP_403_FORBIDDEN: {"model": ResponseData}}) async def get_record(conversation_id: str, user_sub: Annotated[str, Depends(get_user)]): # noqa: ANN201 """获取某个对话的所有问答对""" + print(user_sub, conversation_id) cur_conv = await ConversationManager.get_conversation_by_conversation_id(user_sub, conversation_id) # 判断conversation是否合法 if not cur_conv: diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index 4f32951f..5c45f35d 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -123,7 +123,7 @@ class Scheduler: if need_recommend: routine_results = await asyncio.gather( # generate_facts(self._task_id, post_body.question), - plan_next_flow(user_sub, self._task_id, self._queue, post_body.apps), + plan_next_flow(user_sub, self._task_id, self._queue, post_body.app[0]), ) # else: # routine_results = await asyncio.gather(generate_facts(self._task_id, post_body.question)) @@ -202,8 +202,9 @@ class Scheduler: user_sub=user_sub, data=encrypt_data, key=encrypt_config, - # TODO:暂时不保存facts,因为facts是动态的,需要根据答案生成 - facts=[""], + # facts=self._facts, + #TODO:暂停 + facts=[], metadata=task.record.metadata, created_at=task.record.created_at, flow=task.new_context, @@ -216,7 +217,7 @@ class Scheduler: if not record_group: LOGGER.error("[Scheduler] Create record group failed.") return - + print(record_group) # 修改文件状态 await DocumentManager.change_doc_status(user_sub, post_body.conversation_id, record_group) # 保存Record diff --git a/apps/service/suggestion.py b/apps/service/suggestion.py index fbe48b9e..4e4deb9d 100644 --- a/apps/service/suggestion.py +++ b/apps/service/suggestion.py @@ -28,7 +28,7 @@ USER_TOP_DOMAINS_NUM = 5 HISTORY_QUESTIONS_NUM = 4 -async def plan_next_flow(user_sub: str, task_id: str, queue: MessageQueue, user_selected_plugins: list[RequestDataApp]) -> None: # noqa: C901, PLR0912 +async def plan_next_flow(user_sub: str, task_id: str, queue: MessageQueue, user_selected_plugins: RequestDataApp) -> None: # noqa: C901, PLR0912 """生成用户“下一步”Flow的推荐。 - 若Flow的配置文件中已定义`next_flow[]`字段,则直接使用该字段给定的值 -- Gitee From 020a5d017994b9ff674c18fac9c2f5e375b72318 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=E5=8D=9A?= <1016318004@qq.com> Date: Mon, 27 Jan 2025 10:00:07 +0800 Subject: [PATCH 12/13] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/entities/plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/entities/plugin.py b/apps/entities/plugin.py index af308208..52074931 100644 --- a/apps/entities/plugin.py +++ b/apps/entities/plugin.py @@ -7,7 +7,7 @@ from typing import Any from pydantic import BaseModel, Field from apps.common.queue import MessageQueue -from apps.entities.task import FlowHistory, RequestDataPlugin +from apps.entities.task import FlowHistory, RequestDataApp class SysCallVars(BaseModel): @@ -42,7 +42,7 @@ class SysExecVars(BaseModel): question: str = Field(description="当前Agent的目标") task_id: str = Field(description="当前Executor关联的TaskID") session_id: str = Field(description="当前用户的Session ID") - plugin_data: RequestDataPlugin = Field(description="传递给Executor中Call的参数") + App_data: RequestDataApp = Field(description="传递给Executor中Call的参数") background: ExecutorBackground = Field(description="当前Executor的背景信息") class Config: -- Gitee From e7cc2fccfb151a7e3afe4a56fe344e55b56e28b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=E5=8D=9A?= <1016318004@qq.com> Date: Wed, 5 Feb 2025 10:05:42 +0800 Subject: [PATCH 13/13] fix bug --- apps/scheduler/scheduler/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index 5c45f35d..bacd40ee 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -26,7 +26,7 @@ from apps.manager import ( TaskManager, UserManager, ) -from apps.scheduler.executor import Executor +# from apps.scheduler.executor import Executor from apps.scheduler.scheduler.context import generate_facts, get_context from apps.scheduler.scheduler.flow import choose_flow from apps.scheduler.scheduler.message import ( -- Gitee