diff --git a/apps/common/queue.py b/apps/common/queue.py
index 50b02ba3a8dcd6c7600546f01562ae015cf30e89..83f655a6efcb0b8d00bf70491d74453fbbfb1780 100644
--- a/apps/common/queue.py
+++ b/apps/common/queue.py
@@ -5,6 +5,7 @@ from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from typing import Any
+import ray
from redis.exceptions import ResponseError
from apps.constants import LOGGER
@@ -16,7 +17,6 @@ from apps.entities.message import (
MessageMetadata,
)
from apps.entities.task import TaskBlock
-from apps.manager.task import TaskManager
from apps.models.redis import RedisConnectionPool
@@ -50,7 +50,8 @@ class MessageQueue:
await client.publish(self._stream_name, "[DONE]")
return
- tcb: TaskBlock = await TaskManager.get_task(self._task_id)
+ task = ray.get_actor("task")
+ tcb: TaskBlock = await task.get_task.remote(self._task_id)
# 计算创建Task到现在的时间
used_time = round((datetime.now(timezone.utc).timestamp() - tcb.record.metadata.time), 2)
diff --git a/apps/common/singleton.py b/apps/common/singleton.py
deleted file mode 100644
index 4db37893c354f5c7eadc8786f9007efc83e89915..0000000000000000000000000000000000000000
--- a/apps/common/singleton.py
+++ /dev/null
@@ -1,20 +0,0 @@
-"""给类开启全局单例模式
-
-Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
-"""
-from multiprocessing import Lock
-from typing import Any, ClassVar
-
-
-class Singleton(type):
- """用于实现全局单例的MetaClass"""
-
- _instances: ClassVar[dict[type, Any]] = {}
- _lock = Lock()
-
- def __call__(cls, *args, **kwargs): # noqa: ANN002, ANN003, ANN204
- """实现单例模式"""
- if cls not in cls._instances:
- with cls._lock:
- cls._instances[cls] = super().__call__(*args, **kwargs)
- return cls._instances[cls]
diff --git a/apps/common/task.py b/apps/common/task.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7c9f5181455fd393ae13e3ca57dfb516062b3b0
--- /dev/null
+++ b/apps/common/task.py
@@ -0,0 +1,143 @@
+"""任务模块
+
+Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
+"""
+import uuid
+from copy import deepcopy
+from datetime import datetime, timezone
+from typing import Optional
+
+import ray
+
+from apps.entities.enum_var import StepStatus
+from apps.entities.record import (
+ RecordContent,
+ RecordData,
+ RecordMetadata,
+)
+from apps.entities.request_data import RequestData
+from apps.entities.task import TaskBlock, TaskData
+from apps.manager.task import TaskManager
+from apps.models.mongo import MongoDB
+
+
+@ray.remote
+class Task:
+ """保存任务信息;每一个任务关联一组队列(现阶段只有输入和输出)"""
+
+ def __init__(self) -> None:
+ """初始化TaskManager"""
+ self._task_map: dict[str, TaskBlock] = {}
+
+ async def update_token_summary(self, task_id: str, input_num: int, output_num: int) -> None:
+ """更新对应task_id的Token统计数据"""
+ self._task_map[task_id].record.metadata.input_tokens += input_num
+ self._task_map[task_id].record.metadata.output_tokens += output_num
+
+ async def get_task(self, task_id: Optional[str] = None, session_id: Optional[str] = None, post_body: Optional[RequestData] = None) -> TaskBlock:
+ """获取任务块"""
+ # 如果task_map里面已经有了,则直接返回副本
+ if task_id in self._task_map:
+ return deepcopy(self._task_map[task_id])
+
+ # 如果task_map里面没有,则尝试从数据库中读取
+ if not session_id or not post_body:
+ err = "session_id and conversation_id or group_id and conversation_id are required to recover/create a task."
+ raise ValueError(err)
+
+ if post_body.group_id:
+ task = await TaskManager.get_task_by_group_id(post_body.group_id, post_body.conversation_id)
+ else:
+ task = await TaskManager.get_task_by_conversation_id(post_body.conversation_id)
+
+ # 创建新的Record,缺失的数据延迟关联
+ new_record = RecordData(
+ id=str(uuid.uuid4()),
+ conversationId=post_body.conversation_id,
+ groupId=str(uuid.uuid4()) if not post_body.group_id else post_body.group_id,
+ taskId="",
+ content=RecordContent(
+ question=post_body.question,
+ answer="",
+ ),
+ metadata=RecordMetadata(
+ input_tokens=0,
+ output_tokens=0,
+ time=0,
+ feature=post_body.features.model_dump(by_alias=True),
+ ),
+ createdAt=round(datetime.now(timezone.utc).timestamp(), 3),
+ )
+
+ if not task:
+ # 任务不存在,新建Task,并放入task_map
+ task_id = str(uuid.uuid4())
+ new_record.task_id = task_id
+
+ self._task_map[task_id] = TaskBlock(
+ session_id=session_id,
+ record=new_record,
+ )
+ return deepcopy(self._task_map[task_id])
+
+ # 任务存在,整理Task,放入task_map
+ task_id = task.id
+ new_record.task_id = task_id
+ self._task_map[task_id] = TaskBlock(
+ session_id=session_id,
+ record=new_record,
+ flow_state=task.state,
+ )
+
+ return deepcopy(self._task_map[task_id])
+
+ async def set_task(self, task_id: str, value: TaskBlock) -> None:
+ """设置任务块"""
+ # 检查task_id合法性
+ if task_id not in self._task_map:
+ err = f"Task {task_id} not found"
+ raise KeyError(err)
+
+ # 替换task_map中的数据
+ self._task_map[task_id] = value
+
+ async def save_task(self, task_id: str) -> None:
+ """保存任务块"""
+ # 整理任务信息
+ origin_task = await TaskManager.get_task_by_conversation_id(self._task_map[task_id].record.conversation_id)
+ if not origin_task:
+ # 创建新的Task记录
+ task = TaskData(
+ _id=task_id,
+ conversation_id=self._task_map[task_id].record.conversation_id,
+ record_groups=[self._task_map[task_id].record.group_id],
+ state=self._task_map[task_id].flow_state,
+ ended=False,
+ updated_at=round(datetime.now(timezone.utc).timestamp(), 3),
+ )
+ else:
+ # 更新已有的Task记录
+ task = origin_task
+ task.record_groups.append(self._task_map[task_id].record.group_id)
+ task.state = self._task_map[task_id].flow_state
+ task.updated_at = round(datetime.now(timezone.utc).timestamp(), 3)
+
+ # 判断Task是否结束
+ if (
+ not self._task_map[task_id].flow_state or
+ self._task_map[task_id].flow_state.status == StepStatus.ERROR or # type: ignore[]
+ self._task_map[task_id].flow_state.status == StepStatus.SUCCESS # type: ignore[]
+ ):
+ task.ended = True
+
+ # 使用MongoDB保存任务块
+ task_collection = MongoDB.get_collection("task")
+
+ if task_id not in self._task_map:
+ err = f"Task {task_id} not found"
+ raise ValueError(err)
+
+ await task_collection.update_one({"_id": task_id}, {"$set": task.model_dump(by_alias=True)}, upsert=True)
+
+ # 从task_map中删除任务块,释放内存
+ del self._task_map[task_id]
diff --git a/apps/common/wordscheck.py b/apps/common/wordscheck.py
index cc389e938523a5585fb9caafc77aa46676e1d424..7093533695c0d5951ba7bf81d82225017707c87f 100644
--- a/apps/common/wordscheck.py
+++ b/apps/common/wordscheck.py
@@ -4,75 +4,61 @@ Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved.
"""
import http
import re
-from pathlib import Path
-from typing import Union
-import requests
+import aiofiles
+import aiohttp
+import ray
from apps.common.config import config
from apps.constants import LOGGER
-class APICheck:
- """使用API接口检查敏感词"""
+@ray.remote
+class WordsCheck:
+ """敏感词检查工具"""
+
+ async def _init_file(self) -> None:
+ """初始化关键词列表"""
+ async with aiofiles.open(config["WORDS_LIST"], encoding="utf-8") as f:
+ self.words_list = (await f.read()).splitlines()
+
+ async def _check_wordlist(self, message: str) -> int:
+ """使用关键词列表检查敏感词"""
+ if message in self.words_list:
+ return 1
+ return 0
- @classmethod
- def check(cls, content: str) -> int:
- """检查敏感词"""
+ async def _check_api(self, message: str) -> int:
+ """使用API接口检查敏感词"""
url = config["WORDS_CHECK"]
if url is None:
err = "配置文件中未设置WORDS_CHECK"
raise ValueError(err)
headers = {"Content-Type": "application/json"}
- data = {"content": content}
+ data = {"content": message}
try:
- response = requests.post(url=url, json=data, headers=headers, timeout=10)
- if response.status_code == http.HTTPStatus.OK and re.search("ok", str(response.content)):
- return 1
- return 0
+ async with aiohttp.ClientSession() as session, session.post(url=url, json=data, headers=headers, timeout=10) as response:
+ if response.status == http.HTTPStatus.OK and re.search("ok", str(response.content)):
+ return 1
+ return 0
except Exception as e:
LOGGER.info("过滤敏感词错误:" + str(e))
return -1
-class KeywordCheck:
- """使用关键词列表检查敏感词"""
-
- words_list: list
-
- def __init__(self) -> None:
- """初始化关键词列表"""
- with Path(config["WORDS_LIST"]).open("r", encoding="utf-8") as f:
- self.words_list = f.read().splitlines()
-
- def check(self, message: str) -> int:
- """使用关键词列表检查关键词"""
- if message in self.words_list:
- return 1
- return 0
-
-
-class WordsCheck:
- """敏感词检查工具"""
-
- tool: Union[APICheck, KeywordCheck, None] = None
-
- @classmethod
- def init(cls) -> None:
+ async def init(self) -> None:
"""初始化敏感词检查器"""
if config["DETECT_TYPE"] == "keyword":
- cls.tool = KeywordCheck()
- elif config["DETECT_TYPE"] == "wordscheck":
- cls.tool = APICheck()
- else:
- cls.tool = None
+ await self._init_file()
- @classmethod
- async def check(cls, message: str) -> int:
+ async def check(self, message: str) -> int:
"""检查消息是否包含关键词
异常-1,拦截0,正常1
"""
- if not cls.tool:
- return 1
- return cls.tool.check(message)
+ if config["DETECT_TYPE"] == "keyword":
+ return await self._check_wordlist(message)
+ if config["DETECT_TYPE"] == "wordscheck":
+ return await self._check_api(message)
+ # 不设置检查类型,默认不拦截
+ return 1
diff --git a/apps/cron/delete_user.py b/apps/cron/delete_user.py
index 52bc1508767713b2e0e87c107a898584a70ff82a..8cccf8022b7a37a5dfb36c22014c4b22b1161b6e 100644
--- a/apps/cron/delete_user.py
+++ b/apps/cron/delete_user.py
@@ -8,10 +8,8 @@ import asyncer
from apps.constants import LOGGER
from apps.entities.collection import Audit
-from apps.manager import (
- AuditLogManager,
- UserManager,
-)
+from apps.manager.audit_log import AuditLogManager
+from apps.manager.user import UserManager
from apps.models.mongo import MongoDB
from apps.service.knowledge_base import KnowledgeBaseService
diff --git a/apps/entities/flow.py b/apps/entities/flow.py
index 2fa3a7537c629df3e3cabe087c12825fa5dbecaf..e40a37e0f06ad1585ef21d467537f94b8b087fa0 100644
--- a/apps/entities/flow.py
+++ b/apps/entities/flow.py
@@ -67,7 +67,10 @@ class Permission(BaseModel):
class MetadataBase(BaseModel):
- """Service或App的元数据"""
+ """Service或App的元数据
+
+ 注意:hash字段在save和load的时候exclude
+ """
type: MetadataType = Field(description="元数据类型")
id: str = Field(description="元数据ID")
diff --git a/apps/entities/flow_topology.py b/apps/entities/flow_topology.py
index 619f6950ad24f3c3322f3fd392d165e88b84405f..04fe0ecd11acc0d391e1b1518a086543827c626f 100644
--- a/apps/entities/flow_topology.py
+++ b/apps/entities/flow_topology.py
@@ -28,6 +28,8 @@ class NodeServiceItem(BaseModel):
type: str = Field(..., description="服务类型")
node_meta_datas: list[NodeMetaDataItem] = Field(alias="nodeMetaDatas", default=[])
created_at: str = Field(..., alias="createdAt", description="创建时间")
+
+
class PositionItem(BaseModel):
"""请求/响应中的前端相对位置变量类"""
@@ -57,6 +59,7 @@ class NodeItem(BaseModel):
position: PositionItem=Field(default=PositionItem())
editable: bool = Field(default=True)
+
class EdgeItem(BaseModel):
"""请求/响应中的边变量类"""
diff --git a/apps/entities/task.py b/apps/entities/task.py
index e51f543ddaea288341f8055bbc32eb499e03900f..26a788ede968eccb78b3f1e796bfb9a377ba6255 100644
--- a/apps/entities/task.py
+++ b/apps/entities/task.py
@@ -62,7 +62,7 @@ class RequestDataApp(BaseModel):
params: dict[str, Any] = Field(description="插件参数")
-class Task(BaseModel):
+class TaskData(BaseModel):
"""任务信息
Collection: task
diff --git a/apps/entities/user.py b/apps/entities/user.py
index c42e57ff9a45828ada971516363526c5e40bc43e..1335ffb403f808e0dac297427dbc9780bc408323 100644
--- a/apps/entities/user.py
+++ b/apps/entities/user.py
@@ -10,4 +10,4 @@ class UserInfo(BaseModel):
"""用户信息数据结构"""
user_sub: str = Field(alias="userSub", default="")
- user_name: str = Field(alias="userName", default="")
\ No newline at end of file
+ user_name: str = Field(alias="userName", default="")
diff --git a/apps/llm/patterns/__init__.py b/apps/llm/patterns/__init__.py
index 0e468c9801ab1988e8ad502ea0521397d95a4b81..914ad23c0a148cbd3bebb5724fd965d9c62ae0d9 100644
--- a/apps/llm/patterns/__init__.py
+++ b/apps/llm/patterns/__init__.py
@@ -6,7 +6,7 @@ from apps.llm.patterns.core import CorePattern
from apps.llm.patterns.domain import Domain
from apps.llm.patterns.executor import (
ExecutorBackground,
- ExecutorResult,
+ FinalThought,
ExecutorThought,
)
from apps.llm.patterns.json import Json
@@ -17,7 +17,7 @@ __all__ = [
"CorePattern",
"Domain",
"ExecutorBackground",
- "ExecutorResult",
+ "FinalThought",
"ExecutorThought",
"Json",
"Recommend",
diff --git a/apps/llm/patterns/domain.py b/apps/llm/patterns/domain.py
index 3036bda8a28905f4ce1cd3e4dcb18d68c9c30667..f7c4e88655c7f03e6d5858f8efcb052d77bda104 100644
--- a/apps/llm/patterns/domain.py
+++ b/apps/llm/patterns/domain.py
@@ -13,21 +13,33 @@ class Domain(CorePattern):
"""从问答中提取领域信息"""
user_prompt: str = r"""
- 根据对话上文,提取推荐系统所需的关键词标签,要求:
- 1. 实体名词、技术术语、时间范围、地点、产品等关键信息均可作为关键词标签
- 2. 至少一个关键词与对话的话题有关
- 3. 标签需精简,不得重复,不得超过10个字
+
+
+ 根据对话上文,提取推荐系统所需的关键词标签,要求:
+ 1. 实体名词、技术术语、时间范围、地点、产品等关键信息均可作为关键词标签
+ 2. 至少一个关键词与对话的话题有关
+ 3. 标签需精简,不得重复,不得超过10个字
+ 4. 使用JSON格式输出,不要包含XML标签,不要包含任何解释说明
+
- ==示例==
- 样例对话:
- 用户:北京天气如何?
- 助手:北京今天晴。
+
+
+ 北京天气如何?
+ 北京今天晴。
+
- 样例输出:
- ["北京", "天气"]
- ==结束示例==
+
+
+
- 输出结果:
+
+ {conversation}
+
+