From 744d18d26147de3a7898d206aaecba150bafa993 Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 22 Jul 2025 19:14:24 +0800 Subject: [PATCH 1/3] =?UTF-8?q?RecordGroupDocument=E8=B5=8B=E4=BB=A5?= =?UTF-8?q?=E9=BB=98=E8=AE=A4=E5=80=BC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/schemas/record.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/apps/schemas/record.py b/apps/schemas/record.py index 56ce08dd05..d7acd36820 100644 --- a/apps/schemas/record.py +++ b/apps/schemas/record.py @@ -103,10 +103,10 @@ class RecordGroupDocument(BaseModel): """RecordGroup关联的文件""" id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id") - name: str = Field(description="文档名称") - abstract: str = Field(description="文档摘要", default="") - extension: str = Field(description="文档扩展名", default="") - size: int = Field(description="文档大小,单位是KB", default=0) + name: str = Field(default="", description="文档名称") + abstract: str = Field(default="", description="文档摘要") + extension: str = Field(default="", description="文档扩展名") + size: int = Field(default=0, description="文档大小,单位是KB") associated: Literal["question", "answer"] -- Gitee From de59af82f917d2be15d171e95bbd0d4cbfb5ea92 Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 22 Jul 2025 19:48:11 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E5=AE=8C=E5=96=84=E9=97=AE=E9=A2=98?= =?UTF-8?q?=E6=94=B9=E5=86=99=E7=9A=84prompt?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/llm/patterns/rewrite.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/apps/llm/patterns/rewrite.py b/apps/llm/patterns/rewrite.py index 3ad7ced379..f1707f8f4b 100644 --- a/apps/llm/patterns/rewrite.py +++ b/apps/llm/patterns/rewrite.py @@ -21,10 +21,7 @@ class QuestionRewriteResult(BaseModel): class QuestionRewrite(CorePattern): """问题补全与重写""" - system_prompt: str = "You are a helpful assistant." - """系统提示词""" - - user_prompt: str = r""" + system_prompt: str = r""" 根据历史对话,推断用户的实际意图并补全用户的提问内容,历史对话被包含在标签中,用户意图被包含在标签中。 @@ -77,6 +74,11 @@ class QuestionRewrite(CorePattern): """ """用户提示词""" + user_prompt: str = """ + + 请输出补全后的问题 + + """ async def generate(self, **kwargs) -> str: # noqa: ANN003 """问题补全与重写""" @@ -88,8 +90,8 @@ class QuestionRewrite(CorePattern): leave_tokens = llm._config.max_tokens leave_tokens -= TokenCalculator().calculate_token_length( messages=[ - {"role": "system", "content": self.system_prompt}, - {"role": "user", "content": self.user_prompt.format(history="", question=question)} + {"role": "system", "content": self.system_prompt.format(history="", question=question)}, + {"role": "user", "content": self.user_prompt} ] ) if leave_tokens <= 0: @@ -111,8 +113,8 @@ class QuestionRewrite(CorePattern): qa = sub_qa + qa index -= 2 messages = [ - {"role": "system", "content": self.system_prompt}, - {"role": "user", "content": self.user_prompt.format(history=qa, question=question)}, + {"role": "system", "content": self.system_prompt.format(history=qa, question=question)}, + {"role": "user", "content": self.user_prompt} ] result = "" async for chunk in llm.call(messages, streaming=False): -- Gitee From 12617796a03fc11c5f030dfa1b2330afc7019e64 Mon Sep 17 00:00:00 2001 From: zxstty Date: Tue, 22 Jul 2025 20:41:44 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E9=97=AE=E7=AD=94?= =?UTF-8?q?=E6=A8=A1=E5=BC=8F=E4=B8=8B=E7=9A=84rag=E6=9C=8D=E5=8A=A1?= =?UTF-8?q?=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/schemas/message.py | 2 +- apps/services/document.py | 1 + apps/services/rag.py | 53 +++++++++++++++++++++------------------ 3 files changed, 30 insertions(+), 26 deletions(-) diff --git a/apps/schemas/message.py b/apps/schemas/message.py index d887b345d1..d0661224e0 100644 --- a/apps/schemas/message.py +++ b/apps/schemas/message.py @@ -58,7 +58,7 @@ class TextAddContent(BaseModel): class DocumentAddContent(BaseModel): """document.add消息的content""" - document_id: str = Field(min_length=36, max_length=36, description="文档UUID", alias="documentId") + document_id: str = Field(description="文档UUID", alias="documentId") document_order: int = Field(description="文档在对话中的顺序,从1开始", alias="documentOrder") document_name: str = Field(description="文档名称", alias="documentName") document_abstract: str = Field(description="文档摘要", alias="documentAbstract", default="") diff --git a/apps/services/document.py b/apps/services/document.py index ec5979da83..203162da11 100644 --- a/apps/services/document.py +++ b/apps/services/document.py @@ -123,6 +123,7 @@ class DocumentManager: for doc in docs: if doc.associated == "question": doc_info = await document_collection.find_one({"_id": doc.id, "user_sub": user_sub}) + doc_info = Document.model_validate(doc_info) if doc_info else None if doc_info: doc.name = doc_info.name doc.extension = doc_info.type diff --git a/apps/services/rag.py b/apps/services/rag.py index 6905df51cd..6b6c843dc6 100644 --- a/apps/services/rag.py +++ b/apps/services/rag.py @@ -39,6 +39,7 @@ class RAG: 2.脚注的格式为[[1]],[[2]],[[3]]等,脚注的内容为提供的文档的id。 3.脚注只出现在回答的句子的末尾,例如句号、问号等标点符号后面。 4.不要对脚注本身进行解释或说明。 + 5.请不要使用中的文档的id作为脚注。 @@ -77,7 +78,7 @@ class RAG: """ @staticmethod - async def get_doc_info_from_rag(user_sub: str, llm: LLM, history: list[dict[str, str]], + async def get_doc_info_from_rag(user_sub: str, max_tokens: int, doc_ids: list[str], data: RAGQueryReq) -> list[dict[str, Any]]: """获取RAG服务的文档信息""" @@ -87,12 +88,6 @@ class RAG: "Content-Type": "application/json", "Authorization": f"Bearer {session_id}", } - if history: - try: - question_obj = QuestionRewrite() - data.query = await question_obj.generate(history=history, question=data.query, llm=llm) - except Exception: - logger.exception("[RAG] 问题重写失败") doc_chunk_list = [] if doc_ids: default_kb_id = "00000000-0000-0000-0000-000000000000" @@ -105,7 +100,7 @@ class RAG: isRelatedSurrounding=data.is_related_surrounding, isClassifyByDoc=data.is_classify_by_doc, isRerank=data.is_rerank, - tokensLimit=llm.max_tokens, + tokensLimit=max_tokens, ) try: async with httpx.AsyncClient(timeout=30) as client: @@ -134,14 +129,14 @@ class RAG: """组装文档信息""" bac_info = "" doc_info_list = [] - doc_cnt = 1 + doc_cnt = 0 doc_id_map = {} leave_tokens = max_tokens token_calculator = TokenCalculator() for doc_chunk in doc_chunk_list: if doc_chunk["docId"] not in doc_id_map: - doc_id_map[doc_chunk["docId"]] = doc_cnt doc_cnt += 1 + doc_id_map[doc_chunk["docId"]] = doc_cnt doc_index = doc_id_map[doc_chunk["docId"]] leave_tokens -= token_calculator.calculate_token_length(messages=[ {"role": "user", "content": f''''''}, @@ -152,10 +147,11 @@ class RAG: {"role": "user", "content": ""}, {"role": "user", "content": ""}, ], pure_text=True) - doc_cnt = 1 + doc_cnt = 0 doc_id_map = {} for doc_chunk in doc_chunk_list: if doc_chunk["docId"] not in doc_id_map: + doc_cnt += 1 doc_info_list.append({ "id": doc_chunk["docId"], "order": doc_cnt, @@ -165,7 +161,6 @@ class RAG: "size": doc_chunk.get("docSize", 0), }) doc_id_map[doc_chunk["docId"]] = doc_cnt - doc_cnt += 1 doc_index = doc_id_map[doc_chunk["docId"]] if bac_info: bac_info += "\n\n" @@ -189,17 +184,31 @@ class RAG: ''' bac_info += "" - return bac_info + return bac_info, doc_info_list @staticmethod async def chat_with_llm_base_on_rag( user_sub: str, llm: LLM, history: list[dict[str, str]], doc_ids: list[str], data: RAGQueryReq ) -> AsyncGenerator[str, None]: """获取RAG服务的结果""" - doc_info_list = await RAG.get_doc_info_from_rag( - user_sub=user_sub, llm=llm, history=history, doc_ids=doc_ids, data=data) - bac_info = await RAG.assemble_doc_info( - doc_chunk_list=doc_info_list, max_tokens=llm.max_tokens) + reasion_llm = ReasoningLLM( + LLMConfig( + endpoint=llm.openai_base_url, + key=llm.openai_api_key, + model=llm.model_name, + max_tokens=llm.max_tokens, + ) + ) + if history: + try: + question_obj = QuestionRewrite() + data.query = await question_obj.generate(history=history, question=data.query, llm=reasion_llm) + except Exception: + logger.exception("[RAG] 问题重写失败") + doc_chunk_list = await RAG.get_doc_info_from_rag( + user_sub=user_sub, max_tokens=llm.max_tokens, doc_ids=doc_ids, data=data) + bac_info, doc_info_list = await RAG.assemble_doc_info( + doc_chunk_list=doc_chunk_list, max_tokens=llm.max_tokens) messages = [ *history, { @@ -216,7 +225,9 @@ class RAG: ] input_tokens = TokenCalculator().calculate_token_length(messages=messages) output_tokens = 0 + doc_cnt = 0 for doc_info in doc_info_list: + doc_cnt = max(doc_cnt, doc_info["order"]) yield ( "data: " + json.dumps( @@ -235,14 +246,6 @@ class RAG: doc_cnt //= 10 max_footnote_length += 1 buffer = "" - reasion_llm = ReasoningLLM( - LLMConfig( - endpoint=llm.openai_base_url, - key=llm.openai_api_key, - model=llm.model_name, - max_tokens=llm.max_tokens, - ) - ) async for chunk in reasion_llm.call( messages, max_tokens=llm.max_tokens, -- Gitee