diff --git a/apps/common/queue.py b/apps/common/queue.py index 50b02ba3a8dcd6c7600546f01562ae015cf30e89..83f655a6efcb0b8d00bf70491d74453fbbfb1780 100644 --- a/apps/common/queue.py +++ b/apps/common/queue.py @@ -5,6 +5,7 @@ from collections.abc import AsyncGenerator from datetime import datetime, timezone from typing import Any +import ray from redis.exceptions import ResponseError from apps.constants import LOGGER @@ -16,7 +17,6 @@ from apps.entities.message import ( MessageMetadata, ) from apps.entities.task import TaskBlock -from apps.manager.task import TaskManager from apps.models.redis import RedisConnectionPool @@ -50,7 +50,8 @@ class MessageQueue: await client.publish(self._stream_name, "[DONE]") return - tcb: TaskBlock = await TaskManager.get_task(self._task_id) + task = ray.get_actor("task") + tcb: TaskBlock = await task.get_task.remote(self._task_id) # 计算创建Task到现在的时间 used_time = round((datetime.now(timezone.utc).timestamp() - tcb.record.metadata.time), 2) diff --git a/apps/common/singleton.py b/apps/common/singleton.py deleted file mode 100644 index 4db37893c354f5c7eadc8786f9007efc83e89915..0000000000000000000000000000000000000000 --- a/apps/common/singleton.py +++ /dev/null @@ -1,20 +0,0 @@ -"""给类开启全局单例模式 - -Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. -""" -from multiprocessing import Lock -from typing import Any, ClassVar - - -class Singleton(type): - """用于实现全局单例的MetaClass""" - - _instances: ClassVar[dict[type, Any]] = {} - _lock = Lock() - - def __call__(cls, *args, **kwargs): # noqa: ANN002, ANN003, ANN204 - """实现单例模式""" - if cls not in cls._instances: - with cls._lock: - cls._instances[cls] = super().__call__(*args, **kwargs) - return cls._instances[cls] diff --git a/apps/common/task.py b/apps/common/task.py new file mode 100644 index 0000000000000000000000000000000000000000..a7c9f5181455fd393ae13e3ca57dfb516062b3b0 --- /dev/null +++ b/apps/common/task.py @@ -0,0 +1,143 @@ +"""任务模块 + +Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +""" +import uuid +from copy import deepcopy +from datetime import datetime, timezone +from typing import Optional + +import ray + +from apps.entities.enum_var import StepStatus +from apps.entities.record import ( + RecordContent, + RecordData, + RecordMetadata, +) +from apps.entities.request_data import RequestData +from apps.entities.task import TaskBlock, TaskData +from apps.manager.task import TaskManager +from apps.models.mongo import MongoDB + + +@ray.remote +class Task: + """保存任务信息;每一个任务关联一组队列(现阶段只有输入和输出)""" + + def __init__(self) -> None: + """初始化TaskManager""" + self._task_map: dict[str, TaskBlock] = {} + + async def update_token_summary(self, task_id: str, input_num: int, output_num: int) -> None: + """更新对应task_id的Token统计数据""" + self._task_map[task_id].record.metadata.input_tokens += input_num + self._task_map[task_id].record.metadata.output_tokens += output_num + + async def get_task(self, task_id: Optional[str] = None, session_id: Optional[str] = None, post_body: Optional[RequestData] = None) -> TaskBlock: + """获取任务块""" + # 如果task_map里面已经有了,则直接返回副本 + if task_id in self._task_map: + return deepcopy(self._task_map[task_id]) + + # 如果task_map里面没有,则尝试从数据库中读取 + if not session_id or not post_body: + err = "session_id and conversation_id or group_id and conversation_id are required to recover/create a task." + raise ValueError(err) + + if post_body.group_id: + task = await TaskManager.get_task_by_group_id(post_body.group_id, post_body.conversation_id) + else: + task = await TaskManager.get_task_by_conversation_id(post_body.conversation_id) + + # 创建新的Record,缺失的数据延迟关联 + new_record = RecordData( + id=str(uuid.uuid4()), + conversationId=post_body.conversation_id, + groupId=str(uuid.uuid4()) if not post_body.group_id else post_body.group_id, + taskId="", + content=RecordContent( + question=post_body.question, + answer="", + ), + metadata=RecordMetadata( + input_tokens=0, + output_tokens=0, + time=0, + feature=post_body.features.model_dump(by_alias=True), + ), + createdAt=round(datetime.now(timezone.utc).timestamp(), 3), + ) + + if not task: + # 任务不存在,新建Task,并放入task_map + task_id = str(uuid.uuid4()) + new_record.task_id = task_id + + self._task_map[task_id] = TaskBlock( + session_id=session_id, + record=new_record, + ) + return deepcopy(self._task_map[task_id]) + + # 任务存在,整理Task,放入task_map + task_id = task.id + new_record.task_id = task_id + self._task_map[task_id] = TaskBlock( + session_id=session_id, + record=new_record, + flow_state=task.state, + ) + + return deepcopy(self._task_map[task_id]) + + async def set_task(self, task_id: str, value: TaskBlock) -> None: + """设置任务块""" + # 检查task_id合法性 + if task_id not in self._task_map: + err = f"Task {task_id} not found" + raise KeyError(err) + + # 替换task_map中的数据 + self._task_map[task_id] = value + + async def save_task(self, task_id: str) -> None: + """保存任务块""" + # 整理任务信息 + origin_task = await TaskManager.get_task_by_conversation_id(self._task_map[task_id].record.conversation_id) + if not origin_task: + # 创建新的Task记录 + task = TaskData( + _id=task_id, + conversation_id=self._task_map[task_id].record.conversation_id, + record_groups=[self._task_map[task_id].record.group_id], + state=self._task_map[task_id].flow_state, + ended=False, + updated_at=round(datetime.now(timezone.utc).timestamp(), 3), + ) + else: + # 更新已有的Task记录 + task = origin_task + task.record_groups.append(self._task_map[task_id].record.group_id) + task.state = self._task_map[task_id].flow_state + task.updated_at = round(datetime.now(timezone.utc).timestamp(), 3) + + # 判断Task是否结束 + if ( + not self._task_map[task_id].flow_state or + self._task_map[task_id].flow_state.status == StepStatus.ERROR or # type: ignore[] + self._task_map[task_id].flow_state.status == StepStatus.SUCCESS # type: ignore[] + ): + task.ended = True + + # 使用MongoDB保存任务块 + task_collection = MongoDB.get_collection("task") + + if task_id not in self._task_map: + err = f"Task {task_id} not found" + raise ValueError(err) + + await task_collection.update_one({"_id": task_id}, {"$set": task.model_dump(by_alias=True)}, upsert=True) + + # 从task_map中删除任务块,释放内存 + del self._task_map[task_id] diff --git a/apps/common/wordscheck.py b/apps/common/wordscheck.py index cc389e938523a5585fb9caafc77aa46676e1d424..7093533695c0d5951ba7bf81d82225017707c87f 100644 --- a/apps/common/wordscheck.py +++ b/apps/common/wordscheck.py @@ -4,75 +4,61 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ import http import re -from pathlib import Path -from typing import Union -import requests +import aiofiles +import aiohttp +import ray from apps.common.config import config from apps.constants import LOGGER -class APICheck: - """使用API接口检查敏感词""" +@ray.remote +class WordsCheck: + """敏感词检查工具""" + + async def _init_file(self) -> None: + """初始化关键词列表""" + async with aiofiles.open(config["WORDS_LIST"], encoding="utf-8") as f: + self.words_list = (await f.read()).splitlines() + + async def _check_wordlist(self, message: str) -> int: + """使用关键词列表检查敏感词""" + if message in self.words_list: + return 1 + return 0 - @classmethod - def check(cls, content: str) -> int: - """检查敏感词""" + async def _check_api(self, message: str) -> int: + """使用API接口检查敏感词""" url = config["WORDS_CHECK"] if url is None: err = "配置文件中未设置WORDS_CHECK" raise ValueError(err) headers = {"Content-Type": "application/json"} - data = {"content": content} + data = {"content": message} try: - response = requests.post(url=url, json=data, headers=headers, timeout=10) - if response.status_code == http.HTTPStatus.OK and re.search("ok", str(response.content)): - return 1 - return 0 + async with aiohttp.ClientSession() as session, session.post(url=url, json=data, headers=headers, timeout=10) as response: + if response.status == http.HTTPStatus.OK and re.search("ok", str(response.content)): + return 1 + return 0 except Exception as e: LOGGER.info("过滤敏感词错误:" + str(e)) return -1 -class KeywordCheck: - """使用关键词列表检查敏感词""" - - words_list: list - - def __init__(self) -> None: - """初始化关键词列表""" - with Path(config["WORDS_LIST"]).open("r", encoding="utf-8") as f: - self.words_list = f.read().splitlines() - - def check(self, message: str) -> int: - """使用关键词列表检查关键词""" - if message in self.words_list: - return 1 - return 0 - - -class WordsCheck: - """敏感词检查工具""" - - tool: Union[APICheck, KeywordCheck, None] = None - - @classmethod - def init(cls) -> None: + async def init(self) -> None: """初始化敏感词检查器""" if config["DETECT_TYPE"] == "keyword": - cls.tool = KeywordCheck() - elif config["DETECT_TYPE"] == "wordscheck": - cls.tool = APICheck() - else: - cls.tool = None + await self._init_file() - @classmethod - async def check(cls, message: str) -> int: + async def check(self, message: str) -> int: """检查消息是否包含关键词 异常-1,拦截0,正常1 """ - if not cls.tool: - return 1 - return cls.tool.check(message) + if config["DETECT_TYPE"] == "keyword": + return await self._check_wordlist(message) + if config["DETECT_TYPE"] == "wordscheck": + return await self._check_api(message) + # 不设置检查类型,默认不拦截 + return 1 diff --git a/apps/cron/delete_user.py b/apps/cron/delete_user.py index 52bc1508767713b2e0e87c107a898584a70ff82a..8cccf8022b7a37a5dfb36c22014c4b22b1161b6e 100644 --- a/apps/cron/delete_user.py +++ b/apps/cron/delete_user.py @@ -8,10 +8,8 @@ import asyncer from apps.constants import LOGGER from apps.entities.collection import Audit -from apps.manager import ( - AuditLogManager, - UserManager, -) +from apps.manager.audit_log import AuditLogManager +from apps.manager.user import UserManager from apps.models.mongo import MongoDB from apps.service.knowledge_base import KnowledgeBaseService diff --git a/apps/entities/flow.py b/apps/entities/flow.py index 2fa3a7537c629df3e3cabe087c12825fa5dbecaf..e40a37e0f06ad1585ef21d467537f94b8b087fa0 100644 --- a/apps/entities/flow.py +++ b/apps/entities/flow.py @@ -67,7 +67,10 @@ class Permission(BaseModel): class MetadataBase(BaseModel): - """Service或App的元数据""" + """Service或App的元数据 + + 注意:hash字段在save和load的时候exclude + """ type: MetadataType = Field(description="元数据类型") id: str = Field(description="元数据ID") diff --git a/apps/entities/flow_topology.py b/apps/entities/flow_topology.py index 619f6950ad24f3c3322f3fd392d165e88b84405f..04fe0ecd11acc0d391e1b1518a086543827c626f 100644 --- a/apps/entities/flow_topology.py +++ b/apps/entities/flow_topology.py @@ -28,6 +28,8 @@ class NodeServiceItem(BaseModel): type: str = Field(..., description="服务类型") node_meta_datas: list[NodeMetaDataItem] = Field(alias="nodeMetaDatas", default=[]) created_at: str = Field(..., alias="createdAt", description="创建时间") + + class PositionItem(BaseModel): """请求/响应中的前端相对位置变量类""" @@ -57,6 +59,7 @@ class NodeItem(BaseModel): position: PositionItem=Field(default=PositionItem()) editable: bool = Field(default=True) + class EdgeItem(BaseModel): """请求/响应中的边变量类""" diff --git a/apps/entities/task.py b/apps/entities/task.py index e51f543ddaea288341f8055bbc32eb499e03900f..26a788ede968eccb78b3f1e796bfb9a377ba6255 100644 --- a/apps/entities/task.py +++ b/apps/entities/task.py @@ -62,7 +62,7 @@ class RequestDataApp(BaseModel): params: dict[str, Any] = Field(description="插件参数") -class Task(BaseModel): +class TaskData(BaseModel): """任务信息 Collection: task diff --git a/apps/entities/user.py b/apps/entities/user.py index c42e57ff9a45828ada971516363526c5e40bc43e..1335ffb403f808e0dac297427dbc9780bc408323 100644 --- a/apps/entities/user.py +++ b/apps/entities/user.py @@ -10,4 +10,4 @@ class UserInfo(BaseModel): """用户信息数据结构""" user_sub: str = Field(alias="userSub", default="") - user_name: str = Field(alias="userName", default="") \ No newline at end of file + user_name: str = Field(alias="userName", default="") diff --git a/apps/llm/patterns/__init__.py b/apps/llm/patterns/__init__.py index 0e468c9801ab1988e8ad502ea0521397d95a4b81..914ad23c0a148cbd3bebb5724fd965d9c62ae0d9 100644 --- a/apps/llm/patterns/__init__.py +++ b/apps/llm/patterns/__init__.py @@ -6,7 +6,7 @@ from apps.llm.patterns.core import CorePattern from apps.llm.patterns.domain import Domain from apps.llm.patterns.executor import ( ExecutorBackground, - ExecutorResult, + FinalThought, ExecutorThought, ) from apps.llm.patterns.json import Json @@ -17,7 +17,7 @@ __all__ = [ "CorePattern", "Domain", "ExecutorBackground", - "ExecutorResult", + "FinalThought", "ExecutorThought", "Json", "Recommend", diff --git a/apps/llm/patterns/domain.py b/apps/llm/patterns/domain.py index 3036bda8a28905f4ce1cd3e4dcb18d68c9c30667..f7c4e88655c7f03e6d5858f8efcb052d77bda104 100644 --- a/apps/llm/patterns/domain.py +++ b/apps/llm/patterns/domain.py @@ -13,21 +13,33 @@ class Domain(CorePattern): """从问答中提取领域信息""" user_prompt: str = r""" - 根据对话上文,提取推荐系统所需的关键词标签,要求: - 1. 实体名词、技术术语、时间范围、地点、产品等关键信息均可作为关键词标签 - 2. 至少一个关键词与对话的话题有关 - 3. 标签需精简,不得重复,不得超过10个字 + + + 根据对话上文,提取推荐系统所需的关键词标签,要求: + 1. 实体名词、技术术语、时间范围、地点、产品等关键信息均可作为关键词标签 + 2. 至少一个关键词与对话的话题有关 + 3. 标签需精简,不得重复,不得超过10个字 + 4. 使用JSON格式输出,不要包含XML标签,不要包含任何解释说明 + - ==示例== - 样例对话: - 用户:北京天气如何? - 助手:北京今天晴。 + + + 北京天气如何? + 北京今天晴。 + - 样例输出: - ["北京", "天气"] - ==结束示例== + + { + "keywords": ["北京", "天气"] + } + + + - 输出结果: + + {conversation} + + """ """用户提示词""" @@ -50,9 +62,10 @@ class Domain(CorePattern): async def generate(self, task_id: str, **kwargs) -> list[str]: # noqa: ANN003 """从问答中提取领域信息""" - messages = [{"role": "system", "content": ""}] - messages += kwargs["conversation"] - messages += [{"role": "user", "content": self.user_prompt}] + messages = [ + {"role": "system", "content": ""}, + {"role": "user", "content": self.user_prompt.format(conversation=kwargs["conversation"])}, + ] result = "" async for chunk in ReasoningLLM().call(task_id, messages, streaming=False): diff --git a/apps/llm/patterns/executor.py b/apps/llm/patterns/executor.py index 2bfed3ceada470e5eb380d855aaa26a1b3f47615..f09fab6d054f61265fd7dc586baea01530ac9c19 100644 --- a/apps/llm/patterns/executor.py +++ b/apps/llm/patterns/executor.py @@ -14,9 +14,6 @@ from apps.llm.reasoning import ReasoningLLM class ExecutorThought(CorePattern): """通过大模型生成Executor的思考内容""" - system_prompt: str = "" - """系统提示词""" - user_prompt: str = r""" 你是一个可以使用工具的智能助手。请简明扼要地总结工具的使用过程,提供你的见解,并给出下一步的行动。 @@ -126,12 +123,9 @@ class ExecutorBackground(CorePattern): return result -class ExecutorResult(CorePattern): +class FinalThought(CorePattern): """使用大模型生成Executor的最终结果""" - system_prompt: str = "" - """系统提示词""" - user_prompt: str = r""" 你是AI智能助手,请回答用户的问题并满足以下要求: 1. 使用中文回答问题,不要使用其他语言。 diff --git a/apps/llm/patterns/rewrite.py b/apps/llm/patterns/rewrite.py new file mode 100644 index 0000000000000000000000000000000000000000..be2066e7b64f4b110c533ead26384927378bffef --- /dev/null +++ b/apps/llm/patterns/rewrite.py @@ -0,0 +1,28 @@ +"""问题改写""" +from apps.llm.patterns.core import CorePattern + + +class QuestionRewrite(CorePattern): + """问题补全与重写""" + + user_prompt: str = r""" + + + 根据上面的对话,推断用户的实际意图并补全用户的提问内容。 + 要求: + 1. 请使用JSON格式输出,参考下面给出的样例;输出不要包含XML标签,不要包含任何解释说明; + 2. 若用户当前提问内容与对话上文不相关,或你认为用户的提问内容已足够完整,请直接输出用户的提问内容。 + 3. 补全内容必须精准、恰当,不要编造任何内容。 + + + + openEuler的特点? + openEuler相较于其他操作系统,其特点是什么? + + + + + {question} + + """ + diff --git a/apps/llm/reasoning.py b/apps/llm/reasoning.py index 39dfbb9b9d3802cb4060ab60375dc28bbb0e49ff..6cb2ca0a5473c2cccec43a25b7abd62fcc9c0124 100644 --- a/apps/llm/reasoning.py +++ b/apps/llm/reasoning.py @@ -5,16 +5,15 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. from collections.abc import AsyncGenerator from typing import Optional +import ray import tiktoken from openai import AsyncOpenAI from apps.common.config import config -from apps.common.singleton import Singleton from apps.constants import LOGGER, REASONING_BEGIN_TOKEN, REASONING_END_TOKEN -from apps.manager.task import TaskManager -class ReasoningLLM(metaclass=Singleton): +class ReasoningLLM: """调用用于问答的大模型""" _encoder = tiktoken.get_encoding("cl100k_base") @@ -158,4 +157,5 @@ class ReasoningLLM(metaclass=Singleton): LOGGER.info(f"推理LLM:{reasoning_content}\n\n{result}") output_tokens = self._calculate_token_length([{"role": "assistant", "content": result}], pure_text=True) - await TaskManager.update_token_summary(task_id, input_tokens, output_tokens) + task = ray.get_actor("task") + await task.update_token_summary.remote(task_id, input_tokens, output_tokens) diff --git a/apps/main.py b/apps/main.py index 5278d1fe42d5828cecdd2d8448f8cc4c3984fa95..6712e530028f29ec36ad82f04a0b23c23ab0081a 100644 --- a/apps/main.py +++ b/apps/main.py @@ -13,6 +13,7 @@ from ray import serve from ray.serve.config import HTTPOptions from apps.common.config import config +from apps.common.task import Task from apps.common.wordscheck import WordsCheck from apps.cron.delete_user import DeleteUserCron from apps.dependency.session import VerifySessionMiddleware @@ -80,11 +81,11 @@ if __name__ == "__main__": ray.init(dashboard_host="0.0.0.0", num_cpus=4) # noqa: S104 # 初始化必要资源 - WordsCheck.init() - - pool_actor = Pool.remote() - ray.get(pool_actor.init.remote()) # type: ignore[] - ray.kill(pool_actor) + words_check = WordsCheck.options(name="words_check").remote() + ray.get(words_check.init.remote()) # type: ignore[attr-type] + task = Task.options(name="task").remote() + pool_actor = Pool.options(name="pool").remote() + ray.get(pool_actor.init.remote()) # type: ignore[attr-type] # 启动FastAPI serve.start(http_options=HTTPOptions(host="0.0.0.0", port=8002)) # noqa: S104 diff --git a/apps/manager/__init__.py b/apps/manager/__init__.py index 1fc68a4db81bcb888a69f170aed97468abea8c13..a54f580667bdfbdd1028b8783b6d48814ca53baf 100644 --- a/apps/manager/__init__.py +++ b/apps/manager/__init__.py @@ -1,32 +1,4 @@ -"""Manager模块, 包含所有与数据库操作相关的逻辑。""" -from apps.manager.api_key import ApiKeyManager -from apps.manager.audit_log import AuditLogManager -from apps.manager.blacklist import ( - AbuseManager, - QuestionBlacklistManager, - UserBlacklistManager, -) -from apps.manager.comment import CommentManager -from apps.manager.conversation import ConversationManager -from apps.manager.document import DocumentManager -from apps.manager.record import RecordManager -from apps.manager.session import SessionManager -from apps.manager.task import TaskManager -from apps.manager.user import UserManager -from apps.manager.user_domain import UserDomainManager +"""Manager模块 -__all__ = [ - "AbuseManager", - "ApiKeyManager", - "AuditLogManager", - "CommentManager", - "ConversationManager", - "DocumentManager", - "QuestionBlacklistManager", - "RecordManager", - "SessionManager", - "TaskManager", - "UserBlacklistManager", - "UserDomainManager", - "UserManager", -] +Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. +""" diff --git a/apps/manager/service.py b/apps/manager/service.py index 59abbe68467616fb62bc041fa4338eda45f3d6e0..142f1c245759bcf56ef49fa8b10f1098a07e2a2d 100644 --- a/apps/manager/service.py +++ b/apps/manager/service.py @@ -2,10 +2,10 @@ Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. """ - import uuid from typing import Any, Optional +import ray import yaml from anyio import Path from jsonschema import ValidationError, validate diff --git a/apps/manager/task.py b/apps/manager/task.py index 51a0d00553fea10e27a2f1aacd492fe5c3eb68d0..6aa4fb4c9e43fb315d84fb098ae30a590f273591 100644 --- a/apps/manager/task.py +++ b/apps/manager/task.py @@ -1,47 +1,21 @@ -"""任务模块 +"""获取和保存Task信息到数据库 Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ -import uuid -from asyncio import Lock -from copy import deepcopy -from datetime import datetime, timezone -from typing import ClassVar, Optional +from typing import Optional from apps.constants import LOGGER -from apps.entities.collection import ( - RecordGroup, -) -from apps.entities.enum_var import StepStatus -from apps.entities.record import ( - RecordContent, - RecordData, - RecordMetadata, -) -from apps.entities.request_data import RequestData -from apps.entities.task import FlowHistory, Task, TaskBlock +from apps.entities.collection import RecordGroup +from apps.entities.task import FlowHistory, TaskData from apps.manager.record import RecordManager from apps.models.mongo import MongoDB -from apps.models.redis import RedisConnectionPool class TaskManager: - """保存任务信息;每一个任务关联一组队列(现阶段只有输入和输出)""" - - _connection = RedisConnectionPool.get_redis_connection() - _task_map: ClassVar[dict[str, TaskBlock]] = {} - _write_lock: Lock = Lock() - - @classmethod - async def update_token_summary(cls, task_id: str, input_num: int, output_num: int) -> None: - """更新对应task_id的Token统计数据""" - async with cls._write_lock: - task = cls._task_map[task_id] - task.record.metadata.input_tokens += input_num - task.record.metadata.output_tokens += output_num + """从数据库中获取任务信息""" @staticmethod - async def get_task_by_conversation_id(conversation_id: str) -> Optional[Task]: + async def get_task_by_conversation_id(conversation_id: str) -> Optional[TaskData]: """获取对话ID的最后一条问答组关联的任务""" # 查询对话ID的最后一条问答组 last_group = await RecordManager.query_record_group_by_conversation_id(conversation_id, 1) @@ -61,16 +35,15 @@ class TaskManager: LOGGER.error(f"Task {task_id} not found.") return None - task = Task.model_validate(task) + task = TaskData.model_validate(task) if task.ended: # Task已结束,新建Task return None return task - @staticmethod - async def get_task_by_group_id(group_id: str, conversation_id: str) -> Optional[Task]: + async def get_task_by_group_id(group_id: str, conversation_id: str) -> Optional[TaskData]: """获取组ID的最后一条问答组关联的任务""" task_collection = MongoDB.get_collection("task") record_group_collection = MongoDB.get_collection("record_group") @@ -80,128 +53,11 @@ class TaskManager: return None record_group_obj = RecordGroup.model_validate(record_group) task = await task_collection.find_one({"_id": record_group_obj.task_id}) - return Task.model_validate(task) + return TaskData.model_validate(task) except Exception as e: LOGGER.error(f"[TaskManager] Get task by group_id failed: {e}") return None - - @classmethod - async def get_task(cls, task_id: Optional[str] = None, session_id: Optional[str] = None, post_body: Optional[RequestData] = None) -> TaskBlock: - """获取任务块""" - # 如果task_map里面已经有了,则直接返回副本 - if task_id in cls._task_map: - return deepcopy(cls._task_map[task_id]) - - # 如果task_map里面没有,则尝试从数据库中读取 - if not session_id or not post_body: - err = "session_id and conversation_id or group_id and conversation_id are required to recover/create a task." - raise ValueError(err) - - if post_body.group_id: - task = await TaskManager.get_task_by_group_id(post_body.group_id, post_body.conversation_id) - else: - task = await TaskManager.get_task_by_conversation_id(post_body.conversation_id) - - # 创建新的Record,缺失的数据延迟关联 - new_record = RecordData( - id=str(uuid.uuid4()), - conversationId=post_body.conversation_id, - groupId=str(uuid.uuid4()) if not post_body.group_id else post_body.group_id, - taskId="", - content=RecordContent( - question=post_body.question, - answer="", - ), - metadata=RecordMetadata( - input_tokens=0, - output_tokens=0, - time=0, - feature=post_body.features.model_dump(by_alias=True), - ), - createdAt=round(datetime.now(timezone.utc).timestamp(), 3), - ) - - if not task: - # 任务不存在,新建Task,并放入task_map - task_id = str(uuid.uuid4()) - new_record.task_id = task_id - - async with cls._write_lock: - cls._task_map[task_id] = TaskBlock( - session_id=session_id, - record=new_record, - ) - return deepcopy(cls._task_map[task_id]) - - # 任务存在,整理Task,放入task_map - task_id = task.id - new_record.task_id = task_id - async with cls._write_lock: - cls._task_map[task_id] = TaskBlock( - session_id=session_id, - record=new_record, - flow_state=task.state, - ) - - return deepcopy(cls._task_map[task_id]) - - @classmethod - async def set_task(cls, task_id: str, value: TaskBlock) -> None: - """设置任务块""" - # 检查task_id合法性 - if task_id not in cls._task_map: - err = f"Task {task_id} not found" - raise KeyError(err) - - # 替换task_map中的数据 - async with cls._write_lock: - cls._task_map[task_id] = value - - @classmethod - async def save_task(cls, task_id: str) -> None: - """保存任务块""" - # 整理任务信息 - origin_task = await cls.get_task_by_conversation_id(cls._task_map[task_id].record.conversation_id) - if not origin_task: - # 创建新的Task记录 - task = Task( - _id=task_id, - conversation_id=cls._task_map[task_id].record.conversation_id, - record_groups=[cls._task_map[task_id].record.group_id], - state=cls._task_map[task_id].flow_state, - ended=False, - updated_at=round(datetime.now(timezone.utc).timestamp(), 3), - ) - else: - # 更新已有的Task记录 - task = origin_task - task.record_groups.append(cls._task_map[task_id].record.group_id) - task.state = cls._task_map[task_id].flow_state - task.updated_at = round(datetime.now(timezone.utc).timestamp(), 3) - - # 判断Task是否结束 - if ( - not cls._task_map[task_id].flow_state or - cls._task_map[task_id].flow_state.status == StepStatus.ERROR or # type: ignore[attr-defined] - cls._task_map[task_id].flow_state.status == StepStatus.SUCCESS # type: ignore[attr-defined] - ): - task.ended = True - - # 使用MongoDB保存任务块 - task_collection = MongoDB.get_collection("task") - - if task_id not in cls._task_map: - err = f"Task {task_id} not found" - raise ValueError(err) - - await task_collection.update_one({"_id": task_id}, {"$set": task.model_dump(by_alias=True)}, upsert=True) - - # 从task_map中删除任务块,释放内存 - async with cls._write_lock: - del cls._task_map[task_id] - - @staticmethod async def get_flow_history_by_record_id(record_group_id: str, record_id: str) -> list[FlowHistory]: """根据record_group_id获取flow信息""" @@ -272,3 +128,4 @@ class TaskManager: await session.commit_transaction() except Exception as e: LOGGER.error(f"[TaskManager] Delete tasks by conversation_id failed: {e}") + diff --git a/apps/models/minio.py b/apps/models/minio.py index 39994f8cc891081a07e3fe1933d0c0e5450233b4..bda2c8b54252f241c6071c318b64752ead99fb5d 100644 --- a/apps/models/minio.py +++ b/apps/models/minio.py @@ -24,7 +24,7 @@ class MinioClient: cls.client.make_bucket(bucket_name) @classmethod - def upload_file(cls, **kwargs: Any) -> None: # noqa: ANN401 + def upload_file(cls, **kwargs: Any) -> None: """上传文件""" cls.client.put_object(**kwargs) diff --git a/apps/routers/chat.py b/apps/routers/chat.py index 1216f639a1303890a813b1c1ffe3cf1cf1866037..58ab7fe8b019e46df14e85f719626915dc4eadf4 100644 --- a/apps/routers/chat.py +++ b/apps/routers/chat.py @@ -8,11 +8,11 @@ import uuid from collections.abc import AsyncGenerator from typing import Annotated +import ray from fastapi import APIRouter, Depends, HTTPException, status from fastapi.responses import JSONResponse, StreamingResponse from apps.common.queue import MessageQueue -from apps.common.wordscheck import WordsCheck from apps.constants import LOGGER from apps.dependency import ( get_session, @@ -22,12 +22,8 @@ from apps.dependency import ( ) from apps.entities.request_data import RequestData from apps.entities.response_data import ResponseData -from apps.manager import ( - QuestionBlacklistManager, - TaskManager, - UserBlacklistManager, -) from apps.manager.appcenter import AppCenterManager +from apps.manager.blacklist import QuestionBlacklistManager, UserBlacklistManager from apps.scheduler.scheduler import Scheduler from apps.service.activity import Activity @@ -45,7 +41,8 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) await Activity.set_active(user_sub) # 敏感词检查 - if await WordsCheck.check(post_body.question) != 1: + word_check = ray.get_actor("words_check") + if await word_check.check.remote(post_body.question) != 1: yield "data: [SENSITIVE]\n\n" LOGGER.info(msg="问题包含敏感词!") await Activity.remove_active(user_sub) @@ -55,12 +52,13 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) group_id = str(uuid.uuid4()) if not post_body.group_id else post_body.group_id # 创建或还原Task - task = await TaskManager.get_task(session_id=session_id, post_body=post_body) + task_pool = ray.get_actor("task") + task = await task_pool.get_task.remote(session_id=session_id, post_body=post_body) task_id = task.record.task_id task.record.group_id = group_id post_body.group_id = group_id - await TaskManager.set_task(task_id, task) + await task_pool.set_task.remote(task_id, task) # 创建queue;由Scheduler进行关闭 queue = MessageQueue() @@ -81,7 +79,7 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) await asyncio.gather(scheduler_task) # 获取最终答案 - task = await TaskManager.get_task(task_id) + task = await task_pool.get_task.remote(task_id) answer_text = task.record.content.answer if not answer_text: LOGGER.error(msg="Answer is empty") @@ -90,7 +88,7 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) return # 对结果进行敏感词检查 - if await WordsCheck.check(answer_text) != 1: + if await word_check.check.remote(answer_text) != 1: yield "data: [SENSITIVE]\n\n" LOGGER.info(msg="答案包含敏感词!") await Activity.remove_active(user_sub) @@ -99,7 +97,7 @@ async def chat_generator(post_body: RequestData, user_sub: str, session_id: str) # 创建新Record,存入数据库 await scheduler.save_state(user_sub, post_body) # 保存Task,从task_map中删除task - await TaskManager.save_task(task_id) + await task_pool.save_task.remote(task_id) yield "data: [DONE]\n\n" diff --git a/apps/routers/conversation.py b/apps/routers/conversation.py index 46066efb814927b6444b19169048717b754760a2..c8c1a0ae9ce69432a746af86cddf1356077c375c 100644 --- a/apps/routers/conversation.py +++ b/apps/routers/conversation.py @@ -27,13 +27,11 @@ from apps.entities.response_data import ( ResponseData, UpdateConversationRsp, ) -from apps.manager import ( - AuditLogManager, - ConversationManager, - DocumentManager, - RecordManager, -) from apps.manager.application import AppManager +from apps.manager.audit_log import AuditLogManager +from apps.manager.conversation import ConversationManager +from apps.manager.document import DocumentManager +from apps.manager.record import RecordManager router = APIRouter( prefix="/api/conversation", diff --git a/apps/routers/flow.py b/apps/routers/flow.py index 7ed4d5b8dd4f26f9b01813f18f46c6b7fd254028..a9786ccd2f8b1070d3bef6779617fc9fbe8ddc40 100644 --- a/apps/routers/flow.py +++ b/apps/routers/flow.py @@ -2,20 +2,32 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ -from fastapi import APIRouter, Depends, status, Query, Body -from fastapi.responses import JSONResponse from typing import Annotated, Optional +from fastapi import APIRouter, Body, Depends, Query, status +from fastapi.responses import JSONResponse + +from apps.dependency import get_user from apps.dependency.csrf import verify_csrf_token from apps.dependency.user import verify_user -from apps.dependency import get_user from apps.entities.flow_topology import NodeItem from apps.entities.request_data import PutFlowReq -from apps.entities.response_data import NodeServiceListRsp, NodeServiceListMsg, NodeMetaDataRsp, FlowStructureGetRsp, \ - FlowStructureGetMsg, FlowStructurePutRsp, FlowStructurePutMsg, FlowStructureDeleteRsp, FlowStructureDeleteMsg, ResponseData -from apps.manager.flow import FlowManager +from apps.entities.response_data import ( + FlowStructureDeleteMsg, + FlowStructureDeleteRsp, + FlowStructureGetMsg, + FlowStructureGetRsp, + FlowStructurePutMsg, + FlowStructurePutRsp, + NodeMetaDataRsp, + NodeServiceListMsg, + NodeServiceListRsp, + ResponseData, +) from apps.manager.application import AppManager +from apps.manager.flow import FlowManager from apps.utils.flow import FlowService + router = APIRouter( prefix="/api/flow", tags=["flow"], @@ -30,7 +42,7 @@ router = APIRouter( status.HTTP_404_NOT_FOUND: {"model": ResponseData}, }) async def get_services( - user_sub: Annotated[str, Depends(get_user)] + user_sub: Annotated[str, Depends(get_user)], ): """获取用户可访问的节点元数据所在服务的信息""" services = await FlowManager.get_service_by_user_id(user_sub) @@ -54,7 +66,7 @@ async def get_services( }) async def get_node_metadatas( user_sub: Annotated[str, Depends(get_user)], - node_metadata_id: int = Query(..., alias="NodeMetadataId") + node_metadata_id: Annotated[str, Query(alias="NodeMetadataId")], ): """获取节点元数据的详细信息""" if not await FlowManager.validate_user_node_meta_data_access(user_sub, node_metadata_id): @@ -139,7 +151,7 @@ async def put_flow( message="应用下流更新失败", result=FlowStructurePutMsg(), ).model_dump(exclude_none=True, by_alias=True)) - flow,_=await FlowManager.get_flow_by_app_and_flow_id(app_id, flow_id) + flow, _ = await FlowManager.get_flow_by_app_and_flow_id(app_id, flow_id) if flow is None: return JSONResponse(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=FlowStructurePutRsp( code=status.HTTP_500_INTERNAL_SERVER_ERROR, diff --git a/apps/routers/mock.py b/apps/routers/mock.py index d1271622138fced91c9ccbf78cce3c05b73a1e85..fa7e709e2c518239eb9b95e439c2c98d5a06b2b3 100644 --- a/apps/routers/mock.py +++ b/apps/routers/mock.py @@ -1,7 +1,9 @@ -from fastapi import APIRouter, Depends, HTTPException, status -from fastapi.responses import StreamingResponse import json + import tiktoken +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.responses import StreamingResponse + from apps.common.config import config from apps.dependency import ( get_session, @@ -108,7 +110,7 @@ def mock_data(question): "inputTokens": 200, "outputTokens": 50, "time": 0.5, - } + }, }, { # 【API】获取任务简介 "event": "step.input", @@ -146,9 +148,9 @@ def mock_data(question): "inputTokens": 200, "outputTokens": 50, "time": 0.5, - } + }, }, - { + { "event": "step.output", "id": "0f9d3e6b-7845-44ab-b247-35c522d38f13", "groupId":"8b9d3e6b-a892-4602-b247-35c522d38f13", diff --git a/apps/routers/user.py b/apps/routers/user.py index 7d1da95a41437a6f79f62e4d83fd31b530f2b1b9..8fe0431628cdebe8d433d7bbcc560034100d57bb 100644 --- a/apps/routers/user.py +++ b/apps/routers/user.py @@ -17,9 +17,9 @@ from apps.entities.request_data import RequestData from apps.entities.response_data import ResponseData, UserGetMsp, UserGetRsp from apps.entities.user import UserInfo from apps.manager.appcenter import AppCenterManager +from apps.manager.user import UserManager from apps.scheduler.scheduler import Scheduler from apps.service.activity import Activity -from apps.manager.user import UserManager router = APIRouter( prefix="/api/user", diff --git a/apps/scheduler/call/__init__.py b/apps/scheduler/call/__init__.py index ca09d472293af22aaf90df4c3dd69670134be887..c962e24c1c9fe75f4b0e2a8abf1ae2bb54f2d088 100644 --- a/apps/scheduler/call/__init__.py +++ b/apps/scheduler/call/__init__.py @@ -3,8 +3,9 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ from apps.scheduler.call.api import API -from apps.scheduler.call.convert import Reformat +from apps.scheduler.call.convert import Convert from apps.scheduler.call.llm import LLM +from apps.scheduler.call.rag import RAG from apps.scheduler.call.render.render import Render from apps.scheduler.call.sql import SQL from apps.scheduler.call.suggest import Suggestion @@ -12,8 +13,9 @@ from apps.scheduler.call.suggest import Suggestion __all__ = [ "API", "LLM", + "RAG", "SQL", - "Reformat", + "Convert", "Render", "Suggestion", ] diff --git a/apps/scheduler/call/api.py b/apps/scheduler/call/api.py index ccaac5e90de5ba70ab109e9722a8979e9dce41bf..8405e7bebbe623c4b7886490773b266b5e811cf6 100644 --- a/apps/scheduler/call/api.py +++ b/apps/scheduler/call/api.py @@ -27,7 +27,7 @@ class APIParams(BaseModel): "application/json", "application/x-www-form-urlencoded", "multipart/form-data", ] = Field(description="API接口的Content-Type") timeout: int = Field(description="工具超时时间", default=300) - known_body: dict[str, Any] = Field(description="已知的部分请求体", default={}) + body: dict[str, Any] = Field(description="已知的部分请求体", default={}) input_schema: dict[str, Any] = Field(description="API请求体的JSON Schema", default={}) auth: dict[str, Any] = Field(description="API鉴权信息", default={}) diff --git a/apps/scheduler/call/convert.py b/apps/scheduler/call/convert.py index d26f0aa5ba4e9506d7308989cbcd3c8c08c5fadb..3a1ecea9b42bbe88a23f3e97cdac2b29edd341de 100644 --- a/apps/scheduler/call/convert.py +++ b/apps/scheduler/call/convert.py @@ -17,35 +17,35 @@ from apps.entities.scheduler import SysCallVars from apps.scheduler.call.core import CoreCall -class _ReformatParam(BaseModel): - """校验Reformat Call需要的额外参数""" +class _ConvertParam(BaseModel): + """校验Convert Call需要的额外参数""" text: Optional[str] = Field(description="对生成的文字信息进行格式化,没有则不改动;jinja2语法", default=None) data: Optional[str] = Field(description="对生成的原始数据(JSON)进行格式化,没有则不改动;jsonnet语法", default=None) -class _ReformatOutput(BaseModel): - """定义Reformat工具的输出""" +class _ConvertOutput(BaseModel): + """定义Convert工具的输出""" message: str = Field(description="格式化后的文字信息") output: dict = Field(description="格式化后的结果") -class Reformat(metaclass=CoreCall, param_cls=_ReformatParam, output_cls=_ReformatOutput): - """Reformat 工具,用于对生成的文字信息和原始数据进行格式化""" +class Convert(metaclass=CoreCall, param_cls=_ConvertParam, output_cls=_ConvertOutput): + """Convert 工具,用于对生成的文字信息和原始数据进行格式化""" - name: str = "reformat" + name: str = "convert" description: str = "从上一步的工具的原始JSON返回结果中,提取特定字段的信息。" - async def __call__(self, _slot_data: dict[str, Any]) -> _ReformatOutput: - """调用Reformat工具 + async def __call__(self, _slot_data: dict[str, Any]) -> _ConvertOutput: + """调用Convert工具 :param _slot_data: 经用户确认后的参数(目前未使用) :return: 提取出的字段 """ # 获取必要参数 - params: _ReformatParam = getattr(self, "_params") + params: _ConvertParam = getattr(self, "_params") syscall_vars: SysCallVars = getattr(self, "_syscall_vars") last_output = syscall_vars.history[-1].output_data # 判断用户是否给了值 @@ -76,7 +76,7 @@ class Reformat(metaclass=CoreCall, param_cls=_ReformatParam, output_cls=_Reforma """) result_data = json.loads(_jsonnet.evaluate_snippet(data_template, params.data), ensure_ascii=False) - return _ReformatOutput( + return _ConvertOutput( message=result_message, output=result_data, ) diff --git a/apps/scheduler/call/rag.py b/apps/scheduler/call/rag.py index fd850a19be703749f395fb76d47b2c15034061ca..3cace0fd256f5bae7ff6afb6fb5e6b93f02bcec4 100644 --- a/apps/scheduler/call/rag.py +++ b/apps/scheduler/call/rag.py @@ -4,29 +4,33 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ from typing import Any, Optional -from pydantic import BaseModel, Field - import aiohttp - -from apps.scheduler.call.core import CoreCall - -from apps.entities.scheduler import CallError, SysCallVars +from fastapi import status +from pydantic import BaseModel, Field from apps.common.config import config +from apps.entities.scheduler import CallError, SysCallVars +from apps.scheduler.call.core import CoreCall class _RAGParams(BaseModel): """RAG工具的参数""" - knowledge_base: str = Field(description="知识库的id") + knowledge_base: str = Field(description="知识库的id", alias="kb_sn") top_k: int = Field(description="返回的答案数量(经过整合以及上下文关联)", default=5) methods: Optional[list[str]] = Field(description="rag检索方法") +class _RAGOutputList(BaseModel): + """RAG工具的输出""" + + corpus: list[str] = Field(description="知识库的语料列表") + + class _RAGOutput(BaseModel): """RAG工具的输出""" - message: list[str] = Field(description="RAG工具的输出") + output: _RAGOutputList = Field(description="RAG工具的输出") class RAG(metaclass=CoreCall, param_cls=_RAGParams, output_cls=_RAGOutput): @@ -35,41 +39,35 @@ class RAG(metaclass=CoreCall, param_cls=_RAGParams, output_cls=_RAGOutput): name: str = "rag" description: str = "RAG工具,用于查询知识库" - async def __call__(self, _slot_data: dict[str, Any]): + async def __call__(self, _slot_data: dict[str, Any]) -> _RAGOutput: """调用RAG工具""" syscall_vars: SysCallVars = getattr(self, "_syscall_vars") params: _RAGParams = getattr(self, "_params") - question = syscall_vars.question - - # 资产库ID - kb_sn = params.knowledge_base - # 获取chunk的数量 - top_k = params.top_k + params_dict = params.model_dump(exclude_none=True, by_alias=True) + params_dict["question"] = syscall_vars.question url = config["RAG_HOST"].rstrip("/") + "/chunk/get" - headers = { "Content-Type": "application/json", } - params = { - "question": question, - "kb_sn": kb_sn, - "top_k": top_k - } - - try: - # 发送 GET 请求 - session = aiohttp.ClientSession() - async with session.get(url, headers=headers, params=params) as response: - # 检查响应状态码 - if response.status_code == 200: - result = await response.json() - chunk_list = result['data'] - else: - text = await response.text() - raise CallError(message=f"rag调用失败:{text}", data={}) - except Exception as e: - raise CallError(message=f"rag调用失败:{e!s}", data={}) from e - return _RAGOutput(message=chunk_list) + # 发送 GET 请求 + session = aiohttp.ClientSession() + async with session.post(url, headers=headers, json=params_dict) as response: + # 检查响应状态码 + if response.status == status.HTTP_200_OK: + result = await response.json() + chunk_list = result["data"] + return _RAGOutput( + output=_RAGOutputList(corpus=chunk_list), + ) + text = await response.text() + raise CallError( + message=f"rag调用失败:{text}", + data={ + "question": syscall_vars.question, + "status": response.status, + "text": text, + }, + ) diff --git a/apps/scheduler/call/suggest.py b/apps/scheduler/call/suggest.py index 03f9ae7bb3c8ebbb1dcfc3eddbb2bb911155b257..61d8537948a1684643373e0a7caa4afdeab7ff1b 100644 --- a/apps/scheduler/call/suggest.py +++ b/apps/scheduler/call/suggest.py @@ -4,13 +4,11 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ from typing import Any, Optional +import ray from pydantic import BaseModel, Field from apps.entities.scheduler import CallError, SysCallVars -from apps.manager import ( - TaskManager, - UserDomainManager, -) +from apps.manager.user_domain import UserDomainManager from apps.scheduler.call.core import CoreCall @@ -56,7 +54,7 @@ class Suggestion(metaclass=CoreCall, param_cls=_SuggestInput, output_cls=_Sugges params: _SuggestInput = getattr(self, "_params") # 获取当前任务 - task = await TaskManager.get_task(sys_vars.task_id) + task = await Task.get_task(sys_vars.task_id) # 获取当前用户的画像 user_domain = await UserDomainManager.get_user_domain_by_user_sub_and_topk(sys_vars.user_sub, 5) diff --git a/apps/scheduler/executor/flow.py b/apps/scheduler/executor/flow.py index 6448010f09eb95239ef34d81213f2955cacd38a0..97922ca8f56ac0cc9a253840742ddba044580075 100644 --- a/apps/scheduler/executor/flow.py +++ b/apps/scheduler/executor/flow.py @@ -3,7 +3,9 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ import traceback -from typing import Optional +from typing import Any, Optional + +import ray from apps.constants import LOGGER, MAX_SCHEDULER_HISTORY_SIZE from apps.entities.enum_var import StepStatus @@ -15,7 +17,7 @@ from apps.entities.scheduler import ( from apps.entities.task import ExecutorState, TaskBlock from apps.llm.patterns import ExecutorThought from apps.llm.patterns.executor import ExecutorBackground -from apps.manager import TaskManager +from apps.manager.task import TaskManager from apps.scheduler.executor.message import ( push_flow_start, push_flow_stop, @@ -27,6 +29,7 @@ from apps.scheduler.slot.slot import Slot # 单个流的执行工具 +@ray.remote class Executor: """用于执行工作流的Executor""" @@ -39,16 +42,19 @@ class Executor: async def load_state(self, sysexec_vars: SysExecVars) -> None: """从JSON中加载FlowExecutor的状态""" # 获取Task - task = await TaskManager.get_task(sysexec_vars.task_id) + task_pool = ray.get_actor("task") + task = await task_pool.get_task.remote(sysexec_vars.task_id) if not task: err = "[Executor] Task error." raise ValueError(err) # 加载Flow信息 - flow, flow_data = Pool().get_flow(sysexec_vars.app_data.flow_id, sysexec_vars.app_data.app_id) + pool = ray.get_actor("pool") + flow, flow_data = await pool.get_flow.remote(sysexec_vars.app_data.flow_id, sysexec_vars.app_data.app_id) + # Flow不合法,拒绝执行 if flow is None or flow_data is None: - err = "Flow不合法!" + err = "Flow不存在!请检查Flow ID是否正确!" raise ValueError(err) # 设置名称和描述 @@ -80,20 +86,21 @@ class Executor: ) # 是否结束运行 self._stop = False - await TaskManager.set_task(self._vars.task_id, task) + await task.set_task(self._vars.task_id, task) - async def _get_last_output(self, task: TaskBlock) -> Optional[CallResult]: + async def _get_last_output(self, task: TaskBlock) -> dict[str, Any]: """获取上一步的输出""" if not task.flow_context: - return None + return {} return CallResult(**task.flow_context[self.flow_state.step_id].output_data) - async def _run_step(self, step_data: Step) -> CallResult: # noqa: PLR0915 + async def _run_step(self, step_data: Step) -> dict[str, Any]: # noqa: PLR0915 """运行单个步骤""" # 获取Task - task = await TaskManager.get_task(self._vars.task_id) + task_pool = ray.get_actor("task") + task = await task_pool.get_task.remote(self._vars.task_id) if not task: err = "[Executor] Task error." raise ValueError(err) @@ -103,28 +110,18 @@ class Executor: self.flow_state.status = StepStatus.RUNNING # Call类型为none,直接错误 - call_type = step_data.call_type - if call_type == "none": + node_id = step_data.node + if node_id == "none": self.flow_state.status = StepStatus.ERROR - return CallResult( - message="", - output={}, - output_schema={}, - extra=None, - ) + return {} # 从Pool中获取对应的Call - call_data, call_cls = Pool().get_call(call_type, self.flow_state.app_id) + call_data, call_cls = Pool().get_call(node_id, self.flow_state.app_id) if call_data is None or call_cls is None: - err = f"[FlowExecutor] 尝试执行工具{call_type}时发生错误:找不到该工具。\n{traceback.format_exc()}" + err = f"[FlowExecutor] 尝试执行工具{node_id}时发生错误:找不到该工具。\n{traceback.format_exc()}" LOGGER.error(err) self.flow_state.status = StepStatus.ERROR - return CallResult( - message=err, - output={}, - output_schema={}, - extra=None, - ) + return {} # 准备history history = list(task.flow_context.values()) @@ -151,15 +148,10 @@ class Executor: # 初始化Call call_obj = call_cls(sys_vars, **params) except Exception as e: - err = f"[FlowExecutor] 初始化工具{call_type}时发生错误:{e!s}\n{traceback.format_exc()}" + err = f"[FlowExecutor] 初始化工具{node_id}时发生错误:{e!s}\n{traceback.format_exc()}" LOGGER.error(err) self.flow_state.status = StepStatus.ERROR - return CallResult( - message=err, - output={}, - output_schema={}, - extra=None, - ) + return {} # 如果call_obj里面有slot_schema,初始化Slot处理器 if hasattr(call_obj, "slot_schema") and call_obj.slot_schema: @@ -191,12 +183,7 @@ class Executor: await push_step_input(self._vars.task_id, self._vars.queue, self.flow_state, self._flow_data) # 推送空输出 self.flow_state.status = StepStatus.PARAM - result = CallResult( - message="当前工具参数不完整!", - output={}, - output_schema={}, - extra=None, - ) + result = {} await push_step_output(self._vars.task_id, self._vars.queue, self.flow_state, self._flow_data, result) return result @@ -205,18 +192,13 @@ class Executor: # 执行Call try: - result: CallResult = await call_obj.call(self.flow_state.slot_data) + result: dict[str, Any] = await call_obj.call(self.flow_state.slot_data) except Exception as e: - err = f"[FlowExecutor] 执行工具{call_type}时发生错误:{e!s}\n{traceback.format_exc()}" + err = f"[FlowExecutor] 执行工具{node_id}时发生错误:{e!s}\n{traceback.format_exc()}" LOGGER.error(err) self.flow_state.status = StepStatus.ERROR # 推送空输出 - result = CallResult( - message=err, - output={}, - output_schema={}, - extra=None, - ) + result = {} await push_step_output(self._vars.task_id, self._vars.queue, self.flow_state, self._flow_data, result) return result @@ -228,7 +210,7 @@ class Executor: return result - async def _handle_next_step(self, result: CallResult) -> None: + async def _handle_next_step(self, result: dict[str, Any]) -> None: """处理下一步""" if self._next_step is None: return @@ -242,13 +224,13 @@ class Executor: self._next_step = self._flow_data.steps[self._next_step].next - async def _update_thought(self, call_name: str, call_description: str, call_result: CallResult) -> None: + async def _update_thought(self, call_name: str, call_description: str, call_result: dict[str, Any]) -> None: """执行步骤后,更新FlowExecutor的思考内容""" # 组装工具信息 tool_info = { "name": call_name, "description": call_description, - "output": call_result.output, + "output": call_result, } # 更新背景 self.flow_state.thought = await ExecutorThought().generate( diff --git a/apps/scheduler/executor/message.py b/apps/scheduler/executor/message.py index 1d63144efb1f4c196fbb15b4047f6dd559ffe7d4..b5c0e6efa9eb79450d716e48ba8f545519a0736f 100644 --- a/apps/scheduler/executor/message.py +++ b/apps/scheduler/executor/message.py @@ -2,6 +2,10 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ +from typing import Any + +import ray + from apps.common.queue import MessageQueue from apps.entities.enum_var import EventType, FlowOutputType, StepStatus from apps.entities.flow import Flow @@ -12,15 +16,15 @@ from apps.entities.message import ( StepOutputContent, TextAddContent, ) -from apps.entities.task import ExecutorState, FlowHistory -from apps.llm.patterns.executor import ExecutorResult -from apps.manager.task import TaskManager +from apps.entities.task import ExecutorState, FlowHistory, TaskBlock +from apps.llm.patterns.executor import FinalThought async def push_step_input(task_id: str, queue: MessageQueue, state: ExecutorState, flow: Flow) -> None: """推送步骤输入""" # 获取Task - task = await TaskManager.get_task(task_id) + task_actor = ray.get_actor("task") + task = await task_actor.get_task.remote(task_id) if not task.flow_state: err = "当前Record不存在Flow信息!" @@ -40,7 +44,7 @@ async def push_step_input(task_id: str, queue: MessageQueue, state: ExecutorStat task.new_context.append(flow_history.id) task.flow_context[state.step_id] = flow_history # 保存Task到TaskMap - await TaskManager.set_task(task_id, task) + await task.set_task.remote(task_id, task) # 组装消息 if state.status == StepStatus.ERROR: @@ -61,10 +65,11 @@ async def push_step_input(task_id: str, queue: MessageQueue, state: ExecutorStat await queue.push_output(event_type=EventType.STEP_INPUT, data=content.model_dump(exclude_none=True, by_alias=True)) -async def push_step_output(task_id: str, queue: MessageQueue, state: ExecutorState, flow: Flow, output: CallResult) -> None: +async def push_step_output(task_id: str, queue: MessageQueue, state: ExecutorState, flow: Flow, output: dict[str, Any]) -> None: """推送步骤输出""" # 获取Task - task = await TaskManager.get_task(task_id) + task_actor = ray.get_actor("task") + task = await task_actor.get_task.remote(task_id) if not task.flow_state: err = "当前Record不存在Flow信息!" @@ -74,16 +79,16 @@ async def push_step_output(task_id: str, queue: MessageQueue, state: ExecutorSta task.flow_state = state # 更新FlowContext - task.flow_context[state.step_id].output_data = output.model_dump(exclude_none=True, by_alias=True) if output else {} + task.flow_context[state.step_id].output_data = output task.flow_context[state.step_id].status = state.status # 保存Task到TaskMap - await TaskManager.set_task(task_id, task) + await task.set_task.remote(task_id, task) # 组装消息;只保留message和output content = StepOutputContent( callType=flow.steps[state.step_id].call_type, - message=output.message if output else "", - output=output.output if output else {}, + message=output["message"] if output and "message" in output else "", + output=output["output"] if output and "output" in output else {}, ) await queue.push_output(event_type=EventType.STEP_OUTPUT, data=content.model_dump(exclude_none=True, by_alias=True)) @@ -91,11 +96,13 @@ async def push_step_output(task_id: str, queue: MessageQueue, state: ExecutorSta async def push_flow_start(task_id: str, queue: MessageQueue, state: ExecutorState, question: str) -> None: """推送Flow开始""" # 获取Task - task = await TaskManager.get_task(task_id) + task_actor = ray.get_actor("task") + task = await task_actor.get_task.remote(task_id) + # 设置state task.flow_state = state # 保存Task到TaskMap - await TaskManager.set_task(task_id, task) + await task.set_task.remote(task_id, task) # 组装消息 content = FlowStartContent( @@ -109,9 +116,11 @@ async def push_flow_start(task_id: str, queue: MessageQueue, state: ExecutorStat async def push_flow_stop(task_id: str, queue: MessageQueue, state: ExecutorState, flow: Flow, question: str) -> None: """推送Flow结束""" # 获取Task - task = await TaskManager.get_task(task_id) + task_actor = ray.get_actor("task") + task = await task_actor.get_task.remote(task_id) + task.flow_state = state - await TaskManager.set_task(task_id, task) + await task.set_task.remote(task_id, task) # 准备必要数据 call_type = flow.steps[state.step_id].call_type @@ -140,7 +149,7 @@ async def push_flow_stop(task_id: str, queue: MessageQueue, state: ExecutorState "final_output": content, } full_text = "" - async for chunk in ExecutorResult().generate(task_id, **params): + async for chunk in FinalThought().generate(task_id, **params): if not chunk: continue await queue.push_output( @@ -155,4 +164,5 @@ async def push_flow_stop(task_id: str, queue: MessageQueue, state: ExecutorState # 更新Thought task.record.content.answer = full_text task.flow_state = state - await TaskManager.set_task(task_id, task) + + await task.set_task.remote(task_id, task) diff --git a/apps/scheduler/pool/check.py b/apps/scheduler/pool/check.py index 5175b2049a9b8036914f9ed4e4f8081ffa5f4d1f..f8553aa4d05e4ed7af6c580488077706872a7cce 100644 --- a/apps/scheduler/pool/check.py +++ b/apps/scheduler/pool/check.py @@ -72,19 +72,22 @@ class FileChecker: # 遍历列表 for list_item in items: # 判断是否存在? - if not (self._dir_path / list_item["_id"]).exists(): + if not Path(self._dir_path / list_item["_id"]).exists(): deleted_list.append(list_item["_id"]) continue # 判断是否发生变化 - if self.diff_one(self._dir_path / list_item["_id"], list_item.get("hashes", None)): + if self.diff_one(Path(self._dir_path / list_item["_id"]), list_item.get("hashes", None)): changed_list.append(list_item["_id"]) # 遍历目录 + item_names = [item["_id"] for item in items] for service_folder in self._dir_path.iterdir(): # 判断是否新增? - if (service_folder.name not in items and + if (service_folder.name not in item_names and service_folder.name not in deleted_list and service_folder.name not in changed_list): changed_list += [service_folder.name] + # 触发一次hash计算 + self.diff_one(service_folder) return changed_list, deleted_list diff --git a/apps/scheduler/pool/loader/call.py b/apps/scheduler/pool/loader/call.py index 520cd3006e8ddee654868369e9acf446a5f7cbfd..c466dbbd7255c137997ce0db9217d490b2bf4846 100644 --- a/apps/scheduler/pool/loader/call.py +++ b/apps/scheduler/pool/loader/call.py @@ -197,6 +197,7 @@ class CallLoader: ) await session.execute(insert_stmt) await session.commit() + await session.aclose() async def load(self) -> None: diff --git a/apps/scheduler/pool/loader/flow.py b/apps/scheduler/pool/loader/flow.py index c0d39df348b22d8da60e249619863eadb358b697..c871d6557a132e51624b387a826d8c68f86c6e74 100644 --- a/apps/scheduler/pool/loader/flow.py +++ b/apps/scheduler/pool/loader/flow.py @@ -2,8 +2,7 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ -import asyncio -from pathlib import Path +from typing import Optional import aiofiles import yaml @@ -17,6 +16,7 @@ from apps.models.mongo import MongoDB async def search_step_type(node_id: str) -> str: + """获取node的基础类型""" node_collection = MongoDB.get_collection("node") # 查询 Node 集合获取对应的 call_id node_doc = await node_collection.find_one({"_id": node_id}) @@ -30,6 +30,7 @@ async def search_step_type(node_id: str) -> str: return call_id async def search_step_name(node_id: str) -> str: + """获取node的名称""" node_collection = MongoDB.get_collection("node") # 查询 Node 集合获取对应的 call_id node_doc = await node_collection.find_one({"_id": node_id}) @@ -42,14 +43,15 @@ async def search_step_name(node_id: str) -> str: return "" return call_id + class FlowLoader: """工作流加载器""" @classmethod - async def load(cls, app_id: str, flow_id: str) -> Flow: + async def load(cls, app_id: str, flow_id: str) -> Optional[Flow]: """从文件系统中加载【单个】工作流""" flow_path = Path(config["SERVICE_DIR"]) / "app" / app_id / "flow" / f"{flow_id}.yaml" - async with aiofiles.open(flow_path, mode='r', encoding="utf-8") as f: + async with aiofiles.open(flow_path, encoding="utf-8") as f: flow_yaml = yaml.safe_load(await f.read()) if "name" not in flow_yaml: @@ -118,13 +120,13 @@ class FlowLoader: "id": edge.id, "from": edge.edge_from, "to": edge.edge_to, - "type": edge.edge_type.value, + "type": edge.edge_type.value if edge.edge_type else None, } for edge in flow.edges - ] + ], } - async with aiofiles.open(flow_path, mode='w', encoding="utf-8") as f: + async with aiofiles.open(flow_path, mode="w", encoding="utf-8") as f: await f.write(yaml.dump(flow_dict, allow_unicode=True, sort_keys=False)) @classmethod diff --git a/apps/scheduler/pool/loader/metadata.py b/apps/scheduler/pool/loader/metadata.py index 180e99c653da7f1065df5efbb5bf25c3f1d0ea21..edf0f02cfefef736070acc3ce71df3c46460e6dd 100644 --- a/apps/scheduler/pool/loader/metadata.py +++ b/apps/scheduler/pool/loader/metadata.py @@ -25,6 +25,10 @@ class MetadataLoader: # 检查yaml格式 try: metadata_dict = yaml.safe_load(await file_path.read_text()) + # 忽略hashes字段,手动指定无效 + if "hashes" in metadata_dict: + metadata_dict.pop("hashes") + # 提取metadata的类型 metadata_type = metadata_dict["type"] except Exception as e: err = f"metadata.yaml读取失败: {e}" diff --git a/apps/scheduler/pool/loader/openapi.py b/apps/scheduler/pool/loader/openapi.py index 7b4c262eca3ac174d08f10fa0aed8529b58842a8..487c6b20cb6a0e2374ce38f48363c5e154214456 100644 --- a/apps/scheduler/pool/loader/openapi.py +++ b/apps/scheduler/pool/loader/openapi.py @@ -37,7 +37,7 @@ class OpenAPILoader: return reduce_openapi_spec(spec) - def parameters_to_spec(self, raw_schema: list[dict[str, Any]]) -> dict[str, Any]: + async def parameters_to_spec(self, raw_schema: list[dict[str, Any]]) -> dict[str, Any]: """将OpenAPI中GET接口List形式的请求体Spec转换为JSON Schema""" schema = { "type": "object", @@ -97,7 +97,7 @@ class OpenAPILoader: } try: - input_schema["properties"]["url_parameters"] = self.parameters_to_spec(spec.spec["parameters"]) + input_schema["properties"]["url_parameters"] = await self.parameters_to_spec(spec.spec["parameters"]) except KeyError: err = f"接口{spec.name}不存在URL参数定义" LOGGER.error(msg=err) @@ -162,6 +162,6 @@ class OpenAPILoader: return await self._process_spec(service_id, yaml_filename, spec, service_metadata) - async def save_one(self, nodes: list[NodePool]) -> None: + async def save_one(self, yaml_dict: dict[str, Any]) -> None: """保存单个OpenAPI文档""" pass diff --git a/apps/scheduler/pool/loader/service.py b/apps/scheduler/pool/loader/service.py index 73946ca629b3e1c8248e942ed5a12ac5f14dc6c1..a4b8d6d913296a8fd7d044f650240efc6e85ffdc 100644 --- a/apps/scheduler/pool/loader/service.py +++ b/apps/scheduler/pool/loader/service.py @@ -5,7 +5,9 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. import asyncio import uuid +import ray from anyio import Path +from sqlalchemy import delete from sqlalchemy.dialects.postgresql import insert from apps.common.config import config @@ -20,12 +22,10 @@ from apps.scheduler.pool.loader.metadata import MetadataLoader from apps.scheduler.pool.loader.openapi import OpenAPILoader +@ray.remote class ServiceLoader: """Service 加载器""" - _collection = MongoDB.get_collection("service") - - async def load(self, service_id: str, hashes: dict[str, str]) -> None: """加载单个Service""" service_path = Path(config["SEMANTICS_DIR"]) / "service" / service_id @@ -46,8 +46,7 @@ class ServiceLoader: await self._update_db(nodes, metadata) - @classmethod - async def _update_db(cls, nodes: list[NodePool], metadata: ServiceMetadata) -> None: + async def _update_db(self, nodes: list[NodePool], metadata: ServiceMetadata) -> None: """更新数据库""" # 更新MongoDB service_collection = MongoDB.get_collection("service") @@ -96,7 +95,7 @@ class ServiceLoader: await session.execute(insert_stmt) await session.commit() - + await session.aclose() async def save(self, service_id: str, metadata: ServiceMetadata, data: dict) -> None: """在文件系统上保存Service,并更新数据库""" @@ -116,4 +115,22 @@ class ServiceLoader: async def delete(self, service_id: str) -> None: """删除Service,并更新数据库""" - pass + service_collection = MongoDB.get_collection("service") + node_collection = MongoDB.get_collection("node") + try: + await service_collection.delete_one({"_id": service_id}) + await node_collection.delete_many({"service_id": service_id}) + except Exception as e: + err = f"删除Service失败:{e}" + LOGGER.error(err) + + session = await PostgreSQL.get_session() + try: + await session.execute(delete(ServicePoolVector).where(ServicePoolVector.id == service_id)) + await session.execute(delete(NodePoolVector).where(NodePoolVector.id == service_id)) + await session.commit() + except Exception as e: + err = f"删除数据库失败:{e}" + LOGGER.error(err) + + await session.aclose() diff --git a/apps/scheduler/pool/pool.py b/apps/scheduler/pool/pool.py index 14f2dc54c35226d0a3df4f1e1e23061f4899a151..44b7a2606388c924b1c98b03d04b20e330a71393 100644 --- a/apps/scheduler/pool/pool.py +++ b/apps/scheduler/pool/pool.py @@ -2,15 +2,17 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ -from pathlib import Path +import asyncio from typing import Optional import ray +from anyio import Path from apps.common.config import config from apps.constants import LOGGER, SERVICE_DIR from apps.entities.enum_var import MetadataType from apps.entities.flow_topology import FlowItem +from apps.models.mongo import MongoDB from apps.scheduler.pool.check import FileChecker from apps.scheduler.pool.loader import CallLoader, ServiceLoader @@ -27,31 +29,45 @@ class Pool: # 检查文件变动 checker = FileChecker() changed_service, deleted_service = await checker.diff(MetadataType.SERVICE) - LOGGER.info(f"11111, checker.hashes: {checker.hashes}, changed_service: {changed_service}, deleted_service: {deleted_service}") # 处理Service - service_loader = ServiceLoader() - for service in changed_service: - # 重载变化的Service - await service_loader.load(service, checker.hashes[Path(SERVICE_DIR + "/" + service).as_posix()]) - for service in deleted_service: - # 删除消失的Service - await service_loader.delete(service) + service_loader = ServiceLoader.remote() + + # 批量删除 + delete_task = [service_loader.delete.remote(service) for service in changed_service] # type: ignore[attr-type] + delete_task += [service_loader.delete.remote(service) for service in deleted_service] # type: ignore[attr-type] + await asyncio.gather(*delete_task) + + # 批量加载 + load_task = [service_loader.load.remote(service, checker.hashes[Path(SERVICE_DIR + "/" + service).as_posix()]) for service in changed_service] # type: ignore[attr-type] + await asyncio.gather(*load_task) + + # 完成Service load + ray.kill(service_loader) # 加载App changed_app, deleted_app = await checker.diff(MetadataType.APP) - def save(self, *, is_deletion: bool = False) -> None: + async def save(self, *, is_deletion: bool = False) -> None: """保存【单个】资源""" pass - def get_flow_metadata(self, app_id: str) -> Optional[FlowItem]: - """从数据库中获取全部Flow的元数据""" - pass + async def get_flow_metadata(self, app_id: str) -> Optional[FlowItem]: + """从数据库中获取特定App的全部Flow的元数据""" + app_collection = MongoDB.get_collection("app") + try: + flow_list = await app_collection.find_one({"_id": app_id}, {"flows": 1}) + except Exception as e: + err = f"获取App{app_id}的Flow列表失败:{e}" + LOGGER.error(err) + raise RuntimeError(err) from e + + if not flow_list: + return None - def get_flow(self, app_id: str, flow_id: str) -> Optional[FlowItem]: + async def get_flow(self, app_id: str, flow_id: str) -> Optional[FlowItem]: """从数据库中获取单个Flow的全部数据""" pass diff --git a/apps/scheduler/scheduler/context.py b/apps/scheduler/scheduler/context.py index 89a0c237c9fe0b832e8a6f680275f4517900a40a..106687ce1da1cc45072f9ca5538b2e9ec02067b7 100644 --- a/apps/scheduler/scheduler/context.py +++ b/apps/scheduler/scheduler/context.py @@ -2,14 +2,16 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ +import ray + from apps.common.security import Security from apps.entities.collection import RecordContent from apps.entities.request_data import RequestData from apps.llm.patterns.facts import Facts -from apps.manager import RecordManager, TaskManager +from apps.manager.record import RecordManager -async def get_context(user_sub: str, post_body: RequestData, n: int) -> tuple[list[dict[str, str]], list[str]]: +async def get_context(user_sub: str, post_body: RequestData, n: int) -> tuple[str, list[str]]: """获取当前问答的上下文信息 注意:这里的n要比用户选择的多,因为要考虑事实信息和历史问题 @@ -23,23 +25,29 @@ async def get_context(user_sub: str, post_body: RequestData, n: int) -> tuple[li facts = [] for record in records: facts.extend(record.facts) + # 组装问答 - messages = [] + messages = "" for record in records: record_data = RecordContent.model_validate_json(Security.decrypt(record.data, record.key)) - messages = [ - {"role": "user", "content": record_data.question}, - {"role": "assistant", "content": record_data.answer}, - *messages, - ] + messages += f""" + + {record_data.question} + + + {record_data.answer} + + """ + messages += "" return messages, facts async def generate_facts(task_id: str, question: str) -> list[str]: """生成Facts""" - task = await TaskManager.get_task(task_id) + task_pool = ray.get_actor("task") + task = await task_pool.get_task.remote(task_id) if not task: err = "Task not found" raise ValueError(err) diff --git a/apps/scheduler/scheduler/flow.py b/apps/scheduler/scheduler/flow.py index 5e59626b877e5f35e258aa199a37f31092844d3f..64a20ef75594f81bb2fe44d1ff939e56328ed5d3 100644 --- a/apps/scheduler/scheduler/flow.py +++ b/apps/scheduler/scheduler/flow.py @@ -2,12 +2,14 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. """ +import asyncio from typing import Optional +import ray + from apps.entities.flow import Flow from apps.entities.task import RequestDataApp from apps.llm.patterns import Select -from apps.scheduler.pool.pool import Pool class PredifinedRAGFlow(Flow): @@ -20,19 +22,30 @@ class PredifinedRAGFlow(Flow): class FlowChooser: """Flow选择器""" - def __init__(self, task_id: str, question: str, user_selected: Optional[RequestDataApp] = None): + def __init__(self, task_id: str, question: str, user_selected: Optional[RequestDataApp] = None) -> None: """初始化Flow选择器""" self._task_id = task_id self._question = question self._user_selected = user_selected - def get_top_flow(self) -> str: + async def get_top_flow(self) -> str: """获取Top1 Flow""" - pass + pool = ray.get_actor("pool") + # 获取所选应用的所有Flow + if not self._user_selected or not self._user_selected.app_id: + return "KnowledgeBase" + + flow_list = await asyncio.gather(pool.get_flow_list.remote(self._user_selected.app_id))[0] + if not flow_list: + return "KnowledgeBase" + + top_flow = await Select.generate() + return top_flow + - def choose_flow(self) -> Optional[RequestDataApp]: + async def choose_flow(self) -> Optional[RequestDataApp]: """依据用户的输入和选择,构造对应的Flow。 - 当用户没有选择任何app时,直接进行智能问答 @@ -44,7 +57,7 @@ class FlowChooser: if self._user_selected.flow_id: return self._user_selected - top_flow = self.get_top_flow() + top_flow = await self.get_top_flow() if top_flow == "KnowledgeBase": return None diff --git a/apps/scheduler/scheduler/message.py b/apps/scheduler/scheduler/message.py index 9771b1081acb384872a46e87d31281a2612cdd85..a38617827ba9cb1cbda594a660e707691159b5f2 100644 --- a/apps/scheduler/scheduler/message.py +++ b/apps/scheduler/scheduler/message.py @@ -6,6 +6,8 @@ from datetime import datetime, timezone from textwrap import dedent from typing import Union +import ray + from apps.common.queue import MessageQueue from apps.constants import LOGGER from apps.entities.collection import Document @@ -19,14 +21,15 @@ from apps.entities.message import ( from apps.entities.rag_data import RAGEventData, RAGQueryReq from apps.entities.record import RecordDocument from apps.entities.request_data import RequestData -from apps.manager import TaskManager +from apps.entities.task import TaskBlock from apps.service import RAG async def push_init_message(task_id: str, queue: MessageQueue, post_body: RequestData, *, is_flow: bool = False) -> None: """推送初始化消息""" # 拿到Task - task = await TaskManager.get_task(task_id) + task_actor = ray.get_actor("task") + task: TaskBlock = await task_actor.get_task.remote(task_id) if not task: err = "[Scheduler] Task not found" raise ValueError(err) @@ -51,7 +54,7 @@ async def push_init_message(task_id: str, queue: MessageQueue, post_body: Reques created_at = round(datetime.now(timezone.utc).timestamp(), 2) task.record.metadata.time = created_at task.record.metadata.feature = feature.model_dump(exclude_none=True, by_alias=True) - await TaskManager.set_task(task_id, task) + await task_actor.set_task.remote(task_id, task) # 推送初始化消息 await queue.push_output(event_type=EventType.INIT, data=InitContent(feature=feature, createdAt=created_at).model_dump(exclude_none=True, by_alias=True)) @@ -59,7 +62,8 @@ async def push_init_message(task_id: str, queue: MessageQueue, post_body: Reques async def push_rag_message(task_id: str, queue: MessageQueue, user_sub: str, rag_data: RAGQueryReq) -> None: """推送RAG消息""" - task = await TaskManager.get_task(task_id) + task_actor = ray.get_actor("task") + task: TaskBlock = await task_actor.get_task.remote(task_id) if not task: err = "Task not found" raise ValueError(err) @@ -74,7 +78,7 @@ async def push_rag_message(task_id: str, queue: MessageQueue, user_sub: str, rag # 保存答案 task.record.content.answer = full_answer - await TaskManager.set_task(task_id, task) + await task_actor.set_task.remote(task_id, task) async def _push_rag_chunk(task_id: str, queue: MessageQueue, content: str, rag_input_tokens: int, rag_output_tokens: int) -> tuple[str, int, int]: @@ -92,7 +96,8 @@ async def _push_rag_chunk(task_id: str, queue: MessageQueue, content: str, rag_i # 计算Token数量 delta_input_tokens = content_obj.input_tokens - rag_input_tokens delta_output_tokens = content_obj.output_tokens - rag_output_tokens - await TaskManager.update_token_summary(task_id, delta_input_tokens, delta_output_tokens) + task_actor = ray.get_actor("task") + await task_actor.update_token_summary.remote(task_id, delta_input_tokens, delta_output_tokens) # 更新Token的值 rag_input_tokens = content_obj.input_tokens rag_output_tokens = content_obj.output_tokens diff --git a/apps/scheduler/scheduler/scheduler.py b/apps/scheduler/scheduler/scheduler.py index 4a4a9999f1ef640713a4f7ae679cc26440d35fe2..82b5e6dc558deac121728b464b261187b7ea8cba 100644 --- a/apps/scheduler/scheduler/scheduler.py +++ b/apps/scheduler/scheduler/scheduler.py @@ -7,6 +7,8 @@ import traceback from datetime import datetime, timezone from typing import Union +import ray + from apps.common.queue import MessageQueue from apps.common.security import Security from apps.constants import LOGGER @@ -20,14 +22,10 @@ from apps.entities.record import RecordDocument from apps.entities.request_data import RequestData from apps.entities.scheduler import ExecutorBackground, SysExecVars from apps.entities.task import RequestDataApp -from apps.manager import ( - DocumentManager, - RecordManager, - TaskManager, - UserManager, -) - -# from apps.scheduler.executor import Executor +from apps.manager.document import DocumentManager +from apps.manager.record import RecordManager +from apps.manager.user import UserManager +from apps.scheduler.executor import Executor from apps.scheduler.scheduler.context import generate_facts, get_context # from apps.scheduler.scheduler.flow import choose_flow @@ -89,7 +87,6 @@ class Scheduler: language=post_body.language, document_ids=doc_ids, kb_sn=None if not user_info.kb_id else user_info.kb_id, - history=context, top_k=5, ) @@ -97,7 +94,7 @@ class Scheduler: need_recommend = True # 如果是智能问答,直接执行 if not post_body.app or post_body.app.app_id == "": - # await push_init_message(self._task_id, self._queue, post_body, is_flow=False) + await push_init_message(self._task_id, self._queue, post_body, is_flow=False) await asyncio.sleep(0.1) for doc in docs: # 保存使用的文件ID @@ -108,13 +105,12 @@ class Scheduler: await push_rag_message(self._task_id, self._queue, user_sub, rag_data) else: # 需要执行Flow - # await push_init_message(self._task_id, self._queue, post_body, is_flow=True) + await push_init_message(self._task_id, self._queue, post_body, is_flow=True) # 组装上下文 background = ExecutorBackground( conversation=context, facts=facts, ) - need_recommend = await self.run_executor(session_id, post_body, background, post_body.app) # 生成推荐问题和事实提取 # 如果需要生成推荐问题,则生成 @@ -136,7 +132,7 @@ class Scheduler: async def run_executor(self, session_id: str, post_body: RequestData, background: ExecutorBackground, user_selected_flow: RequestDataApp) -> bool: """构造FlowExecutor,并执行所选择的流""" # 获取当前Task - task = await TaskManager.get_task(self._task_id) + task = await Task.get_task(self._task_id) if not task: err = "[Scheduler] Task error." raise ValueError(err) @@ -162,7 +158,8 @@ class Scheduler: async def save_state(self, user_sub: str, post_body: RequestData) -> None: """保存当前Executor、Task、Record等的数据""" # 获取当前Task - task = await TaskManager.get_task(self._task_id) + task_pool = ray.get_actor("task") + task = await task_pool.get_task.remote(self._task_id) if not task: err = "Task not found" raise ValueError(err) @@ -184,7 +181,7 @@ class Scheduler: if history.id == history_id: history_data.append(history) break - await TaskManager.create_flows(history_data) + await Task.create_flows(history_data) # 修改metadata里面时间为实际运行时间 task.record.metadata.time = round(datetime.now(timezone.utc).timestamp() - task.record.metadata.time, 2) diff --git a/requirements.txt b/requirements.txt index d01d2084df8499257982fb2bfef9755dd7943c1d..674e325c03b97f3b160f84c2b471922d9bd5155b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,6 +14,7 @@ jsonnet-binary==0.17.0 jsonschema==4.23.0 jieba==0.42.1 httpx==0.27.2 +memray==1.15.0 minio==7.2.11 ollama==0.4.4 openai==1.57.0