diff --git a/apps/llm/patterns/rewrite.py b/apps/llm/patterns/rewrite.py
index 3ad7ced379d54e6b5d474efaae81df731b5f5e0c..f1707f8f4b2ca5527679f6f4ad7836a1981df95e 100644
--- a/apps/llm/patterns/rewrite.py
+++ b/apps/llm/patterns/rewrite.py
@@ -21,10 +21,7 @@ class QuestionRewriteResult(BaseModel):
class QuestionRewrite(CorePattern):
"""问题补全与重写"""
- system_prompt: str = "You are a helpful assistant."
- """系统提示词"""
-
- user_prompt: str = r"""
+ system_prompt: str = r"""
根据历史对话,推断用户的实际意图并补全用户的提问内容,历史对话被包含在标签中,用户意图被包含在标签中。
@@ -77,6 +74,11 @@ class QuestionRewrite(CorePattern):
"""
"""用户提示词"""
+ user_prompt: str = """
+
+ 请输出补全后的问题
+
+ """
async def generate(self, **kwargs) -> str: # noqa: ANN003
"""问题补全与重写"""
@@ -88,8 +90,8 @@ class QuestionRewrite(CorePattern):
leave_tokens = llm._config.max_tokens
leave_tokens -= TokenCalculator().calculate_token_length(
messages=[
- {"role": "system", "content": self.system_prompt},
- {"role": "user", "content": self.user_prompt.format(history="", question=question)}
+ {"role": "system", "content": self.system_prompt.format(history="", question=question)},
+ {"role": "user", "content": self.user_prompt}
]
)
if leave_tokens <= 0:
@@ -111,8 +113,8 @@ class QuestionRewrite(CorePattern):
qa = sub_qa + qa
index -= 2
messages = [
- {"role": "system", "content": self.system_prompt},
- {"role": "user", "content": self.user_prompt.format(history=qa, question=question)},
+ {"role": "system", "content": self.system_prompt.format(history=qa, question=question)},
+ {"role": "user", "content": self.user_prompt}
]
result = ""
async for chunk in llm.call(messages, streaming=False):
diff --git a/apps/schemas/message.py b/apps/schemas/message.py
index d887b345d17c899c5da702781185dff1a382d1c8..d0661224e0817ff6f25e341d65c88903020d18e8 100644
--- a/apps/schemas/message.py
+++ b/apps/schemas/message.py
@@ -58,7 +58,7 @@ class TextAddContent(BaseModel):
class DocumentAddContent(BaseModel):
"""document.add消息的content"""
- document_id: str = Field(min_length=36, max_length=36, description="文档UUID", alias="documentId")
+ document_id: str = Field(description="文档UUID", alias="documentId")
document_order: int = Field(description="文档在对话中的顺序,从1开始", alias="documentOrder")
document_name: str = Field(description="文档名称", alias="documentName")
document_abstract: str = Field(description="文档摘要", alias="documentAbstract", default="")
diff --git a/apps/schemas/record.py b/apps/schemas/record.py
index 56ce08dd059de10f01d78ff9edff6257dd26be85..d7acd368205d0aa5a2b368edf4af3a31ed4d1ff6 100644
--- a/apps/schemas/record.py
+++ b/apps/schemas/record.py
@@ -103,10 +103,10 @@ class RecordGroupDocument(BaseModel):
"""RecordGroup关联的文件"""
id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id")
- name: str = Field(description="文档名称")
- abstract: str = Field(description="文档摘要", default="")
- extension: str = Field(description="文档扩展名", default="")
- size: int = Field(description="文档大小,单位是KB", default=0)
+ name: str = Field(default="", description="文档名称")
+ abstract: str = Field(default="", description="文档摘要")
+ extension: str = Field(default="", description="文档扩展名")
+ size: int = Field(default=0, description="文档大小,单位是KB")
associated: Literal["question", "answer"]
diff --git a/apps/services/document.py b/apps/services/document.py
index ec5979da83989a554eefd4202dd284cb34a7f1ec..203162da1137cdd69d900ceea65a3029d2f4fd8e 100644
--- a/apps/services/document.py
+++ b/apps/services/document.py
@@ -123,6 +123,7 @@ class DocumentManager:
for doc in docs:
if doc.associated == "question":
doc_info = await document_collection.find_one({"_id": doc.id, "user_sub": user_sub})
+ doc_info = Document.model_validate(doc_info) if doc_info else None
if doc_info:
doc.name = doc_info.name
doc.extension = doc_info.type
diff --git a/apps/services/rag.py b/apps/services/rag.py
index 6905df51cdd807d7efaea04d0b127ef3df3d9de0..6b6c843dc6fc09a14809ef0a11bc44b2da434f29 100644
--- a/apps/services/rag.py
+++ b/apps/services/rag.py
@@ -39,6 +39,7 @@ class RAG:
2.脚注的格式为[[1]],[[2]],[[3]]等,脚注的内容为提供的文档的id。
3.脚注只出现在回答的句子的末尾,例如句号、问号等标点符号后面。
4.不要对脚注本身进行解释或说明。
+ 5.请不要使用中的文档的id作为脚注。
@@ -77,7 +78,7 @@ class RAG:
"""
@staticmethod
- async def get_doc_info_from_rag(user_sub: str, llm: LLM, history: list[dict[str, str]],
+ async def get_doc_info_from_rag(user_sub: str, max_tokens: int,
doc_ids: list[str],
data: RAGQueryReq) -> list[dict[str, Any]]:
"""获取RAG服务的文档信息"""
@@ -87,12 +88,6 @@ class RAG:
"Content-Type": "application/json",
"Authorization": f"Bearer {session_id}",
}
- if history:
- try:
- question_obj = QuestionRewrite()
- data.query = await question_obj.generate(history=history, question=data.query, llm=llm)
- except Exception:
- logger.exception("[RAG] 问题重写失败")
doc_chunk_list = []
if doc_ids:
default_kb_id = "00000000-0000-0000-0000-000000000000"
@@ -105,7 +100,7 @@ class RAG:
isRelatedSurrounding=data.is_related_surrounding,
isClassifyByDoc=data.is_classify_by_doc,
isRerank=data.is_rerank,
- tokensLimit=llm.max_tokens,
+ tokensLimit=max_tokens,
)
try:
async with httpx.AsyncClient(timeout=30) as client:
@@ -134,14 +129,14 @@ class RAG:
"""组装文档信息"""
bac_info = ""
doc_info_list = []
- doc_cnt = 1
+ doc_cnt = 0
doc_id_map = {}
leave_tokens = max_tokens
token_calculator = TokenCalculator()
for doc_chunk in doc_chunk_list:
if doc_chunk["docId"] not in doc_id_map:
- doc_id_map[doc_chunk["docId"]] = doc_cnt
doc_cnt += 1
+ doc_id_map[doc_chunk["docId"]] = doc_cnt
doc_index = doc_id_map[doc_chunk["docId"]]
leave_tokens -= token_calculator.calculate_token_length(messages=[
{"role": "user", "content": f''''''},
@@ -152,10 +147,11 @@ class RAG:
{"role": "user", "content": ""},
{"role": "user", "content": ""},
], pure_text=True)
- doc_cnt = 1
+ doc_cnt = 0
doc_id_map = {}
for doc_chunk in doc_chunk_list:
if doc_chunk["docId"] not in doc_id_map:
+ doc_cnt += 1
doc_info_list.append({
"id": doc_chunk["docId"],
"order": doc_cnt,
@@ -165,7 +161,6 @@ class RAG:
"size": doc_chunk.get("docSize", 0),
})
doc_id_map[doc_chunk["docId"]] = doc_cnt
- doc_cnt += 1
doc_index = doc_id_map[doc_chunk["docId"]]
if bac_info:
bac_info += "\n\n"
@@ -189,17 +184,31 @@ class RAG:
'''
bac_info += ""
- return bac_info
+ return bac_info, doc_info_list
@staticmethod
async def chat_with_llm_base_on_rag(
user_sub: str, llm: LLM, history: list[dict[str, str]], doc_ids: list[str], data: RAGQueryReq
) -> AsyncGenerator[str, None]:
"""获取RAG服务的结果"""
- doc_info_list = await RAG.get_doc_info_from_rag(
- user_sub=user_sub, llm=llm, history=history, doc_ids=doc_ids, data=data)
- bac_info = await RAG.assemble_doc_info(
- doc_chunk_list=doc_info_list, max_tokens=llm.max_tokens)
+ reasion_llm = ReasoningLLM(
+ LLMConfig(
+ endpoint=llm.openai_base_url,
+ key=llm.openai_api_key,
+ model=llm.model_name,
+ max_tokens=llm.max_tokens,
+ )
+ )
+ if history:
+ try:
+ question_obj = QuestionRewrite()
+ data.query = await question_obj.generate(history=history, question=data.query, llm=reasion_llm)
+ except Exception:
+ logger.exception("[RAG] 问题重写失败")
+ doc_chunk_list = await RAG.get_doc_info_from_rag(
+ user_sub=user_sub, max_tokens=llm.max_tokens, doc_ids=doc_ids, data=data)
+ bac_info, doc_info_list = await RAG.assemble_doc_info(
+ doc_chunk_list=doc_chunk_list, max_tokens=llm.max_tokens)
messages = [
*history,
{
@@ -216,7 +225,9 @@ class RAG:
]
input_tokens = TokenCalculator().calculate_token_length(messages=messages)
output_tokens = 0
+ doc_cnt = 0
for doc_info in doc_info_list:
+ doc_cnt = max(doc_cnt, doc_info["order"])
yield (
"data: "
+ json.dumps(
@@ -235,14 +246,6 @@ class RAG:
doc_cnt //= 10
max_footnote_length += 1
buffer = ""
- reasion_llm = ReasoningLLM(
- LLMConfig(
- endpoint=llm.openai_base_url,
- key=llm.openai_api_key,
- model=llm.model_name,
- max_tokens=llm.max_tokens,
- )
- )
async for chunk in reasion_llm.call(
messages,
max_tokens=llm.max_tokens,