From e0d473a28a90fd6cfcee4aac2fbbada34e5495a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=E5=8D=9A?= <1016318004@qq.com> Date: Wed, 26 Feb 2025 02:56:09 +0800 Subject: [PATCH] for test --- apps/routers/chat.py | 121 ++++++++++++++++++++++--------------------- apps/routers/mock.py | 14 ++--- 2 files changed, 65 insertions(+), 70 deletions(-) diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 2b9162e6..adf5a0f0 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -24,7 +24,7 @@ from apps.entities.request_data import RequestData from apps.entities.response_data import ResponseData from apps.manager.appcenter import AppCenterManager from apps.manager.blacklist import QuestionBlacklistManager, UserBlacklistManager -# from apps.scheduler.scheduler import Scheduler +from apps.scheduler.scheduler import Scheduler from apps.service.activity import Activity RECOMMEND_TRES = 5 @@ -35,80 +35,80 @@ router = APIRouter( ) -# async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) -> AsyncGenerator[str, None]: -# """进行实际问答,并从MQ中获取消息""" -# try: -# await Activity.set_active(user_sub) +async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) -> AsyncGenerator[str, None]: + """进行实际问答,并从MQ中获取消息""" + try: + await Activity.set_active(user_sub) -# # 敏感词检查 -# word_check = ray.get_actor("words_check") -# if await word_check.check.remote(post_body.question) != 1: -# yield "data: [SENSITIVE]\n\n" -# LOGGER.info(msg="问题包含敏感词!") -# await Activity.remove_active(user_sub) -# return + # 敏感词检查 + word_check = ray.get_actor("words_check") + if await word_check.check.remote(post_body.question) != 1: + yield "data: [SENSITIVE]\n\n" + LOGGER.info(msg="问题包含敏感词!") + await Activity.remove_active(user_sub) + return -# # 生成group_id -# group_id = str(uuid.uuid4()) if not post_body.group_id else post_body.group_id + # 生成group_id + group_id = str(uuid.uuid4()) if not post_body.group_id else post_body.group_id -# # 创建或还原Task(获取task_id) -# task_pool = ray.get_actor("task") -# task = await task_pool.get_task.remote(session_id=session_id, post_body=post_body) -# task_id = task.record.task_id + # 创建或还原Task(获取task_id) + task_pool = ray.get_actor("task") + task = await task_pool.get_task.remote(session_id=session_id, post_body=post_body) + task_id = task.record.task_id -# task.record.group_id = group_id -# post_body.group_id = group_id -# await task_pool.set_task.remote(task_id, task) + task.record.group_id = group_id + post_body.group_id = group_id + await task_pool.set_task.remote(task_id, task) -# # 创建queue;由Scheduler进行关闭 -# queue = MessageQueue() -# await queue.init(task_id, enable_heartbeat=True) + # 创建queue;由Scheduler进行关闭 + queue = MessageQueue() + await queue.init(task_id, enable_heartbeat=True) -# # 在单独Task中运行Scheduler,拉齐queue.get的时机 -# scheduler = Scheduler(task_id, queue) -# scheduler_task = asyncio.create_task(scheduler.run(user_sub, session_id, post_body)) + # 在单独Task中运行Scheduler,拉齐queue.get的时机 + scheduler = Scheduler(task_id, queue) + scheduler_task = asyncio.create_task(scheduler.run(user_sub, session_id, post_body)) -# # 处理每一条消息 -# async for event in queue.get(): -# if event[:6] == "[DONE]": -# break + # 处理每一条消息 + async for event in queue.get(): + if event[:6] == "[DONE]": + break -# yield "data: " + event + "\n\n" + yield "data: " + event + "\n\n" -# # 等待Scheduler运行完毕 -# await asyncio.gather(scheduler_task) + # 等待Scheduler运行完毕 + await asyncio.gather(scheduler_task) -# # 获取最终答案 -# task = await task_pool.get_task.remote(task_id) -# answer_text = task.record.content.answer -# if not answer_text: -# LOGGER.error(msg="Answer is empty") -# yield "data: [ERROR]\n\n" -# await Activity.remove_active(user_sub) -# return + # 获取最终答案 + task = await task_pool.get_task.remote(task_id) + answer_text = task.record.content.answer + if not answer_text: + LOGGER.error(msg="Answer is empty") + yield "data: [ERROR]\n\n" + await Activity.remove_active(user_sub) + return -# # 对结果进行敏感词检查 -# if await word_check.check.remote(answer_text) != 1: -# yield "data: [SENSITIVE]\n\n" -# LOGGER.info(msg="答案包含敏感词!") -# await Activity.remove_active(user_sub) -# return + # 对结果进行敏感词检查 + if await word_check.check.remote(answer_text) != 1: + yield "data: [SENSITIVE]\n\n" + LOGGER.info(msg="答案包含敏感词!") + await Activity.remove_active(user_sub) + return -# # 创建新Record,存入数据库 -# await scheduler.save_state(user_sub, post_body) -# # 保存Task,从task_map中删除task -# await task_pool.save_task.remote(task_id) + # 创建新Record,存入数据库 + await scheduler.save_state(user_sub, post_body) + # 保存Task,从task_map中删除task + await task_pool.save_task.remote(task_id) -# yield "data: [DONE]\n\n" + yield "data: [DONE]\n\n" -# except Exception as e: -# LOGGER.error(msg=f"生成答案失败:{e!s}\n{traceback.format_exc()}") -# yield "data: [ERROR]\n\n" + except Exception as e: + LOGGER.error(msg=f"生成答案失败:{e!s}\n{traceback.format_exc()}") + yield "data: [ERROR]\n\n" -# finally: -# if scheduler_task: -# scheduler_task.cancel() -# await Activity.remove_active(user_sub) + finally: + if scheduler_task: + scheduler_task.cancel() + await Activity.remove_active(user_sub) @router.post("/chat", dependencies=[Depends(verify_csrf_token), Depends(verify_user)]) @@ -145,6 +145,7 @@ async def chat( ) + @router.post("/stop", response_model=ResponseData, dependencies=[Depends(verify_csrf_token)]) async def stop_generation(user_sub: Annotated[str, Depends(get_user)]): # noqa: ANN201 """停止生成""" diff --git a/apps/routers/mock.py b/apps/routers/mock.py index f2b7017c..0f00ab57 100644 --- a/apps/routers/mock.py +++ b/apps/routers/mock.py @@ -1,5 +1,3 @@ -import asyncio -import copy import json import random import time @@ -8,7 +6,7 @@ from typing import Any, AsyncGenerator, Dict, Optional import aiohttp from pydantic import BaseModel, Field import tiktoken -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, status from fastapi.responses import StreamingResponse from apps.common.config import config @@ -19,20 +17,16 @@ from apps.dependency import ( verify_user, ) from apps.entities.request_data import MockRequestData, RequestData -from apps.entities.scheduler import CallError, SysCallVars +from apps.entities.scheduler import CallError from apps.manager.flow import FlowManager from apps.scheduler.pool.loader.flow import FlowLoader from datetime import datetime from textwrap import dedent from typing import Any -import pytz -from jinja2 import BaseLoader, select_autoescape -from jinja2.sandbox import SandboxedEnvironment from pydantic import BaseModel, Field -from apps.entities.scheduler import CallError, SysCallVars -from apps.scheduler.call.core import CoreCall +from apps.entities.scheduler import CallError """问答大模型调用 @@ -323,7 +317,7 @@ async def mock_data(appId="68dd3d90-6a97-4da0-aa62-d38a81c7d2f5", flowId="966c79 time.sleep(t) yield "data: " + json.dumps(message,ensure_ascii=False) + "\n\n" mid_message = [] - flow = await FlowLoader.load(appId, flowId) + flow = await FlowLoader().load(appId, flowId) now_flow_item = "start" start_time = time.time() last_item = "" -- Gitee