From bf6e4e3b42c39616e746024744bb90373260cc0b Mon Sep 17 00:00:00 2001 From: tanfeng <823018000@qq.com> Date: Fri, 22 Nov 2024 16:57:08 +0800 Subject: [PATCH 1/6] =?UTF-8?q?=E9=80=82=E9=85=8D=E6=9C=80=E6=96=B0POC?= =?UTF-8?q?=E7=89=88=E6=9C=AC=E5=B1=95=E7=A4=BA=E9=A1=B5=E9=9D=A2=EF=BC=8C?= =?UTF-8?q?=E5=8A=9F=E8=83=BD=E5=A2=9E=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mxRAG/PocValidation/chat_with_ascend/app.py | 558 ++++++++++++++++---- 1 file changed, 441 insertions(+), 117 deletions(-) diff --git a/mxRAG/PocValidation/chat_with_ascend/app.py b/mxRAG/PocValidation/chat_with_ascend/app.py index 83d1f33d2..d5805bfb8 100644 --- a/mxRAG/PocValidation/chat_with_ascend/app.py +++ b/mxRAG/PocValidation/chat_with_ascend/app.py @@ -9,6 +9,9 @@ from loguru import logger from paddle.base import libpaddle from mx_rag.chain import SingleText2TextChain from mx_rag.embedding.local import TextEmbedding +from mx_rag.embedding.service import TEIEmbedding +from mx_rag.reranker.local import LocalReranker +from mx_rag.reranker.service import TEIReranker from mx_rag.llm import Text2TextLLM, LLMParameterConfig from mx_rag.retrievers import Retriever from mx_rag.storage.document_store import SQLiteDocstore @@ -17,39 +20,225 @@ from mx_rag.storage.vectorstore.vectorstore import SimilarityStrategy from mx_rag.utils import ClientParam from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.document_loaders import TextLoader -from mx_rag.document.loader import DocxLoader, PdfLoader -from mx_rag.knowledge import KnowledgeDB -from mx_rag.knowledge.knowledge import KnowledgeStore +from mx_rag.document.loader import DocxLoader, PdfLoader, ExcelLoader +from mx_rag.knowledge import KnowledgeStore, KnowledgeDB, KnowledgeMgrStore, KnowledgeMgr from mx_rag.knowledge.handler import upload_files, LoaderMng -global text2text_chain -global text_knowledge_db -global loader_mng -global text_emb +class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter): + def _get_default_metavar_for_optional(self, action): + return action.type.__name__ + + def _get_default_metavar_for_positional(self, action): + return action.type.__name__ + + +parse = argparse.ArgumentParser(formatter_class=CustomFormatter) +parse.add_argument("--embedding_path", type=str, default="/home/data/acge_text_embedding", help="embedding模型本地路径") +parse.add_argument("--tei_emb", type=bool, default=False, help="是否使用TEI服务化的embedding模型") +parse.add_argument("--embedding_url", type=str, default="http://127.0.0.1:8080/embed",help="使用TEI服务化的embedding模型url地址") +parse.add_argument("--embedding_dim", type=int, default=1024, help="embedding模型向量维度") +parse.add_argument("--llm_url", type=str, default="http://127.0.0.1:1025/v1/chat/completions", help="大模型url地址") +parse.add_argument("--model_name", type=str, default="Llama3-8B-Chinese-Chat", help="大模型名称") +parse.add_argument("--tei_reranker", type=bool, default=False, help="是否使用TEI服务化的reranker模型") +parse.add_argument("--reranker_path", type=str, default=None, help="reranker模型本地路径") +parse.add_argument("--reranker_url", type=str, default=None, help="使用TEI服务化的embedding模型url地址") +parse.add_argument("--dev", type=int, default=0, help="使用的npu卡,可通过执行npu-smi info获取") +parse.add_argument("--port", type=int, default=8080, help="web后端端口") + +args = parse.parse_args().__dict__ +tei_emb: bool = args.pop('tei_emb') +embedding_path: str = args.pop('embedding_path') +embedding_url: str = args.pop('embedding_url') +embedding_dim: int = 
args.pop('embedding_dim') +llm_url: str = args.pop('llm_url') +model_name: str = args.pop('model_name') +tei_reranker: bool = args.pop('tei_reranker') +reranker_path: str = args.pop('reranker_path') +reranker_url: str = args.pop('reranker_url') +dev: int = args.pop('dev') +port: int = args.pop('port') +# 初始化向量数据库 +vector_store = MindFAISS(x_dim=embedding_dim, + similarity_strategy=SimilarityStrategy.FLAT_L2, + devs=[dev], + load_local_index="./test_faiss.index", + auto_save=True + ) +# 初始化文档chunk关系数据库 +chunk_store = SQLiteDocstore(db_path="./test_sql.db") +# 初始化知识管理关系数据库 +knowledge_store = KnowledgeStore(db_path="./test_sql.db") +# 初始化知识库管理 +text_knowledge_db = KnowledgeDB(knowledge_store=knowledge_store, + chunk_store=chunk_store, + vector_store=vector_store, + knowledge_name="test", + white_paths=["/tmp"] + ) +knowledge_mgr = KnowledgeMgr(KnowledgeMgrStore("./test_sql.db")) + +# 配置text生成text大模型chain,具体ip端口请根据实际情况适配修改 +llm = Text2TextLLM(base_url=llm_url, model_name=model_name, client_param=ClientParam(use_http=True)) + +# 配置embedding模型,请根据模型具体路径适配 +if tei_emb: + text_emb = TEIEmbedding(url=embedding_url, client_param=ClientParam(use_http=True)) +else: + text_emb = TextEmbedding(model_path=embedding_path, dev_id=dev) + +# 配置rerank,请根据模型具体路径适配 +if tei_reranker: + reranker = TEIReranker(url=reranker_url, client_param=ClientParam(use_http=True)) +elif reranker_path is not None: + reranker = LocalReranker(model_path=reranker_path, dev_id=dev) +else: + reranker = None + + +# 历史问题改写 +def generate_question(history, llm, history_n: int = 5): + prompt = """现在你有一个上下文依赖问题补全任务,任务要求:请根据对话历史和用户当前的问句,重写问句。\n + 历史问题依次是:\n + {}\n + 用户当前的问句:\n + {}\n + 注意如果当前问题不依赖历史问题直接返回none即可\n + 请根据上述对话历史重写用户当前的问句,仅输出重写后的问句,不需要附加任何分析。\n + 重写问句: \n + """ + if len(history) <= 2: + return history + cur_query = history[-1][0] + history_qa = history[0:-1] + history_list = [f"第{idx + 1}轮:{q_a[0]}" for idx, q_a in enumerate(history_qa) if q_a[0] is not None] + history_list = history_list[:history_n] + history_str = "\n\n".join(history_list) + re_query = prompt.format(history_str, cur_query) + new_query = llm.chat(query=re_query, llm_config=LLMParameterConfig(max_tokens=512, + temperature=0.5, + top_p=0.95)) + if new_query != "none": + history[-1][0] = "原始问题: " + cur_query + history += [[new_query, None]] + return history + + +# 聊天对话框 def bot_response(history, + history_r, max_tokens: int = 512, temperature: float = 0.5, - top_p: float = 0.95 + top_p: float = 0.95, + history_n: int = 5, + score_threshold: float = 0.5, + top_k: int = 1, + chat_type: str = "RAG检索增强对话", + show_type: str = "不显示", + is_rewrite:str = "否", + knowledge_name: str = 'test', ): - # 将最新的问题传给RAG + chat_type_mapping = {"RAG检索增强对话":1, + "仅大模型对话":0} + show_type_mapping = {"对话结束后显示":1, + "检索框单独显示":2, + "不显示":0} + is_rewrite_mapping = {"是":1, + "否":0} + # 初始化检索器 + retrieve_cls = creat_retriever(knowledge_name, top_k, score_threshold) + # 初始化chain + text2text_chain = SingleText2TextChain(llm=llm, retriever=retrieve_cls, reranker=reranker) + # 历史问题改写 + if is_rewrite_mapping.get(is_rewrite) == 1: + history = generate_question(history, llm, history_n) + history[-1][1] = '推理错误' try: - response = text2text_chain.query(history[-1][0], - LLMParameterConfig(max_tokens=max_tokens, temperature=temperature, top_p=top_p, - stream=True)) - # 返回迭代器 - history[-1][1] = '推理错误' - for res in response: - history[-1][1] = '推理错误' if res['result'] is None else res['result'] - yield history - yield history + # 仅使用大模型 + if chat_type_mapping.get(chat_type) == 0: + response = 
llm.chat_streamly(query=history[-1][0], + llm_config=LLMParameterConfig(max_tokens=max_tokens, + temperature=temperature, + top_p=top_p)) + # 返回迭代器 + for res in response: + history[-1][1] = res + yield history, history_r + # 使用RAG增强回答 + elif chat_type_mapping.get(chat_type) == 1: + response = text2text_chain.query(history[-1][0], + LLMParameterConfig(max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + stream=True)) + # 不展示检索内容 + if show_type_mapping.get(show_type) == 0: + for res in response: + history[-1][1] = '推理错误' if res['result'] is None else res['result'] + yield history, history_r + # 问答结尾展示 + elif show_type_mapping.get(show_type) == 1: + for res in response: + history[-1][1] = '推理错误' if res['result'] is None else res['result'] + q_docs = res['source_documents'] + yield history, history_r + # 有检索到信息 + if len(q_docs) > 0: + history_last = '' + for i ,source in enumerate(q_docs): + sources = "\n====检索信息来源" + str(i + 1) + ":" + source['metadata']['source'] + "====" + "\n" + \ + "--->参考内容:" +source['page_content'] + "\n" + history_last += sources + history += [[None, history_last]] + yield history, history_r + # 检索窗口展示 + else: + for res in response: + history[-1][1] = '推理错误' if res['result'] is None else res['result'] + q_docs = res['source_documents'] + yield history, history_r + # 有检索到信息 + if len(q_docs) > 0: + history_r_last = '' + for i ,source in enumerate(q_docs): + sources = "\n====检索信息来源" + str(i + 1) + ":" + source['metadata']['source'] + "====" + "\n" + \ + "--->参考内容:" +source['page_content'] + "\n" + history_r_last += sources + history_r += [[history[-1][0], history_r_last]] + yield history, history_r except Exception as err: logger.info(f"query failed, find exception: {err}") - history[-1][1] = "推理错误" - yield history + yield history, history_r + + +# 检索对话框 +def re_response(history_r, + score_threshold: float = 0.5, + top_k: int = 1, + knowledge_name: str = 'test' + ): + # 初始化检索器 + retrieve_cls = creat_retriever(knowledge_name, top_k, score_threshold) + q_docs = retrieve_cls.invoke(history_r[-1][0]) + if len(q_docs) > 0: + history_r_last = '' + for i, source in enumerate(q_docs): + sources = "\n====检索信息来源" + str(i + 1) + ":" + source.metadata['source'] + "====" + "\n" + \ + "--->参考内容:" + source.page_content + "\n" + history_r_last += sources + history_r[-1][1] = history_r_last + else: + history_r[-1][1] = "未检索到相关信息" + return history_r +# 检索信息 +def user_retriever(user_message, history_r): + return "",history_r + [[user_message, None]] + + +# 聊天信息 def user_query(user_message, history): return "", history + [[user_message, None]] @@ -58,23 +247,137 @@ def clear_history(history): return [] -SAVE_FILE_PATH = "/tmp/document_files" +# 创建新的知识库 +def creat_knowledge_db(knowledge_name: str, is_regist: bool=True): + knowledge_name = get_knowledge_rename(knowledge_name) + index_name, db_name = get_db_names(knowledge_name) + vector_store_cls = MindFAISS(x_dim=embedding_dim, + similarity_strategy=SimilarityStrategy.FLAT_L2, + devs=[dev], + load_local_index=index_name, + auto_save=True) + # 初始化文档chunk关系数据库 + chunk_store_cls = SQLiteDocstore(db_path=db_name) + # 初始化知识管理关系数据库 + knowledge_store_cls = KnowledgeStore(db_path=db_name) + # 初始化知识库管理 + knowledge_db_cls = KnowledgeDB(knowledge_store=knowledge_store_cls, + chunk_store=chunk_store_cls, + vector_store=vector_store_cls, + knowledge_name=knowledge_name, + white_paths=["/tmp"]) + if is_regist: + knowledge_mgr.register(knowledge_db_cls) + return knowledge_db_cls + + +def creat_retriever(knowledge_name: str, top_k, 
score_threshold): + knowledge_name = get_knowledge_rename(knowledge_name) + index_name, db_name = get_db_names(knowledge_name) + vector_store_ret = MindFAISS(x_dim=embedding_dim, + similarity_strategy=SimilarityStrategy.FLAT_L2, + devs=[dev], + load_local_index=index_name, + auto_save=True) + # 初始化文档chunk关系数据库 + chunk_store_ret = SQLiteDocstore(db_path=db_name) + retriever_cls = Retriever(vector_store=vector_store_ret, + document_store=chunk_store_ret, + embed_func=text_emb.embed_documents, + k=top_k, + score_threshold=score_threshold) + return retriever_cls + + +# 删除知识库 +def delete_knowledge_db(knowledge_name: str): + knowledge_name = get_knowledge_rename(knowledge_name) + knowledge_db_lst = knowledge_mgr.get_all() + if knowledge_name in knowledge_db_lst: + knowledge_db_cls = creat_knowledge_db(knowledge_name, is_regist=False) + clear_document(knowledge_name) + knowledge_mgr.delete(knowledge_db_cls) + return get_knowledge_db() + + +# 获取知识库文档列表 +def get_document(knowledge_name: str): + knowledge_name = get_knowledge_rename(knowledge_name) + knowledge_db_cls = creat_knowledge_db(knowledge_name, is_regist=False) + doc_names = knowledge_db_cls.get_all_documents() + return knowledge_name, doc_names, len(doc_names) + + +# 清空知识库文档列表 +def clear_document(knowledge_name: str): + knowledge_name, doc_names, doc_cnt = get_document(knowledge_name) + if doc_cnt > 0: + knowledge_db_cls = creat_knowledge_db(knowledge_name, is_regist=False) + for doc_name in doc_names: + knowledge_db_cls.delete_file(doc_name) + return get_document(knowledge_name) + else: + return knowledge_name, doc_names, doc_cnt + + +# 获取知识库列表 +def get_knowledge_db(): + knowledge_db_lst = knowledge_mgr.get_all() + return knowledge_db_lst, len(knowledge_db_lst) + + +def get_db_names(knowledge_name: str): + index_name = "./" + knowledge_name + "_faiss.index" + dn_name = "./" + knowledge_name + "_sql.db" + return index_name, dn_name -def file_upload(files): +def get_knowledge_rename(knowledge_name: str): + if knowledge_name is None or len(knowledge_name) == 0 or ' ' in knowledge_name: + return 'test' + else: + return knowledge_name + + +# 上传知识库 +def file_upload(files, + knowledge_db_name: str = 'test', + chunk_size: int = 750, + chunk_overlap: int = 150 + ): + save_file_path = "/tmp/document_files" + knowledge_db_name = get_knowledge_rename(knowledge_db_name) # 指定保存文件的文件夹 - if not os.path.exists(SAVE_FILE_PATH): - os.makedirs(SAVE_FILE_PATH) + if not os.path.exists(save_file_path): + os.makedirs(save_file_path) if files is None or len(files) == 0: print('no file need save') + if knowledge_db_name in knowledge_mgr.get_all(): + knowledge_db_cls = creat_knowledge_db(knowledge_db_name, is_regist=False) + else: + knowledge_db_cls = creat_knowledge_db(knowledge_db_name) + # 注册文档加载器 + loader_mng = LoaderMng() + # 加载文档加载器,可以使用mxrag自有的,也可以使用langchain的 + loader_mng.register_loader(loader_class=TextLoader, file_types=[".txt", ".md"]) + loader_mng.register_loader(loader_class=DocxLoader, file_types=[".docx"]) + loader_mng.register_loader(loader_class=PdfLoader, file_types=[".pdf"]) + loader_mng.register_loader(loader_class=ExcelLoader, file_types=[".xlsx", ".xls"]) + + # 加载文档切分器,使用langchain的 + loader_mng.register_splitter(splitter_class=RecursiveCharacterTextSplitter, + file_types=[".docx", ".txt", ".md", ".pdf", ".xlsx", ".xls"], + splitter_params={"chunk_size": chunk_size, + "chunk_overlap": chunk_overlap, + "keep_separator": False + }) for file in files: try: # 上传领域知识文档 - shutil.copy(file.name, SAVE_FILE_PATH) + shutil.copy(file.name, 
save_file_path) # 知识库:chunk\embedding\add - upload_files(text_knowledge_db, [file.name], loader_mng=loader_mng, embed_func=text_emb.embed_documents, - force=True) - print(f"file {file.name} save to {SAVE_FILE_PATH}.") + upload_files(knowledge_db_cls, [file.name], loader_mng=loader_mng, embed_func=text_emb.embed_documents, force=True) + print(f"file {file.name} save to {save_file_path}.") except Exception as err: logger.error(f"save failed, find exception: {err}") @@ -83,6 +386,7 @@ def file_change(files, upload_btn): print("file changes") +# 构建gradio框架 def build_demo(): with gr.Blocks() as demo: gr.HTML("
检索增强生成(RAG)对话
powered by MindX RAG
") @@ -94,14 +398,51 @@ def build_demo(): files = gr.Files( height=300, file_count="multiple", - file_types=[".docx", ".txt", ".md", ".pdf"], + file_types=[".docx", ".txt", ".md", ".pdf", ".xlsx", ".xls"], interactive=True, label="上传知识库文档" ) with gr.Row(): upload_btn = gr.Button("上传文件") with gr.Row(): - with gr.Accordion(label='知识库情况'): + with gr.TabItem(label='知识库情况'): + set_knowledge_name = gr.Textbox(label='设置当前知识库',placeholder='在此输入知识库名称,默认使用test知识库') + with gr.Row(): + creat_knowledge_btn = gr.Button('创建知识库') + delete_knowledge_btn = gr.Button('删除知识库') + knowledge_name_output = gr.Textbox(label='知识库列表') + knowledge_number_output = gr.Textbox(label='知识库数量') + with gr.Row(): + show_knowledge_btn = gr.Button('显示知识库') + with gr.TabItem(label='文件情况'): + knowledge_name = gr.Textbox('知识库名称') + knowledge_file_output = gr.Textbox('知识库文件列表') + knowledge_file_num_output = gr.Textbox('知识库文件数量') + with gr.Row(): + knowledge_file_out_btn = gr.Button('显示文件情况') + knowledge_clear_btn = gr.Button('清空知识库') + with gr.Row(): + with gr.According(label='文档切分参数设置',open=False): + chunk_size = gr.Slider( + minimum=50, + maximum=5000, + value=750, + step=50, + interactive=True, + label="chunk_size", + info="文本切分长度" + ) + chunk_overlap = gr.Slider( + minimum=10, + maximum=500, + value=150, + step=10, + interactive=True, + label="chunk_overlap", + info="文本切分填充长度" + ) + with gr.Row(): + with gr.According(label='大模型参数设置', open=False): temperature = gr.Slider( minimum=0.01, maximum=2, @@ -129,101 +470,84 @@ def build_demo(): label="最大tokens", info="输入+输出最多的tokens数" ) - with gr.Column(scale=200): - chatbot = gr.Chatbot(height=500) - with gr.Row(): - msg = gr.Textbox(placeholder="在此输入问题...", container=False) + is_rewrite = gr.Radio(['是','否'], value='否', label="是否根据历史提问重写问题?") + history_n = gr.Slider( + minimum=1, + maximum=10, + value=5, + step=1, + interactive=True, + label="历史提问重新轮数", + info="问题重写时所参考的历史提问轮数" + ) with gr.Row(): - send_btn = gr.Button(value="发送", variant="primary") - clean_btn = gr.Button(value="清空历史") - send_btn.click(user_query, [msg, chatbot], [msg, chatbot], queue=False).then(bot_response, - [chatbot, max_tokens, - temperature, top_p], chatbot) + with gr.According(label='检索参数设置', open=False): + score_threshold = gr.Slider( + minimum=0, + maximum=1, + value=0.5, + step=0.01, + interactive=True, + label="score_threshold", + info="相似性检索阈值,值越大表示越相关,低于阈值不会被返回。" + ) + top_k = gr.Slider( + minimum=1, + maximum=5, + value=1, + step=1, + interactive=True, + label="top_k", + info="相似性检索返回条数" + ) + show_type = gr.Radio(['对话结束后显示','检索框单独显示','不显示'], value="对话结束后显示", label="知识库文档匹配结果展示方式选择") + with gr.Column(scale=200): + with gr.Tabs(): + with gr.TabItem("对话窗口"): + chat_type = gr.Radio(['RAG检索增强对话','仅大模型对话'], value="RAG检索增强对话", label="请选择对话模式?") + chatbot = gr.Chatbot(height=550) + with gr.Row(): + msg = gr.Textbox(placeholder="在此输入问题...", container=False) + with gr.Row(): + send_btn = gr.Button(value="发送", variant="primary") + clean_btn = gr.Button(value="清空历史") + with gr.TabItem("检索窗口"): + chatbot_r = gr.Chatbot(height=550) + with gr.Row(): + msg_r = gr.Textbox(placeholder="在此输入问题...", container=False) + with gr.Row(): + send_btn_r = gr.Button(value="文档检索", variant="primary") + clean_btn_r = gr.Button(value="清空历史") + # RAG发送 + send_btn.click(user_query, [msg, chatbot], [msg, chatbot], queue=False + ).then(bot_response, + [chatbot, chatbot_r, max_tokens, temperature, top_p, history_n, score_threshold, top_k, chat_type, show_type, is_rewrite, set_knowledge_name], + [chatbot, chatbot_r]) + # RAG清除历史 
clean_btn.click(clear_history, chatbot, chatbot) + # 上传文件 files.change(file_change, [], []) - upload_btn.click(file_upload, files, []).then(clear_history, files, files) + upload_btn.click(file_upload, [files, set_knowledge_name, chunk_size, chunk_overlap], files) + # 管理所有知识库 + creat_knowledge_btn.click(creat_knowledge_db, [set_knowledge_name],[]).then(get_knowledge_db, [],[knowledge_name_output, knowledge_number_output]) + show_knowledge_btn.click(get_knowledge_db, [], [knowledge_name_output, knowledge_number_output]) + delete_knowledge_btn.click(delete_knowledge_db, [set_knowledge_name], [knowledge_name_output, knowledge_number_output]) + # 管理知识库里文件 + knowledge_file_out_btn.click(get_document, [set_knowledge_name], [knowledge_name, knowledge_file_output, knowledge_file_num_output]) + knowledge_clear_btn.click(clear_document, [set_knowledge_name], [knowledge_name, knowledge_file_output, knowledge_file_num_output]) + # 检索发送 + send_btn_r.click(user_retriever, [msg_r, chatbot_r], [msg_r, chatbot_r], queue=False + ).then(re_response, [chatbot_r, score_threshold, top_k, set_knowledge_name], chatbot_r) + # 检索清除历史 + clean_btn_r.click(clear_history, chatbot_r, chatbot_r) return demo -def create_gradio(port): +def create_gradio(ports): demo = build_demo() demo.queue() - demo.launch(share=True, server_name="0.0.0.0", server_port=port) - - -class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter): - def _get_default_metavar_for_optional(self, action): - return action.type.__name__ - - def _get_default_metavar_for_positional(self, action): - return action.type.__name__ + demo.launch(share=True, server_name="0.0.0.0", server_port=ports) if __name__ == '__main__': - parse = argparse.ArgumentParser(formatter_class=CustomFormatter) - parse.add_argument("--embedding_path", type=str, default="/home/data/acge_text_embedding", - help="embedding模型本地路径") - parse.add_argument("--embedding_dim", type=int, default=1024, help="embedding模型向量维度") - parse.add_argument("--llm_url", type=str, default="http://127.0.0.1:1025/v1/chat/completions", help="大模型url地址") - parse.add_argument("--model_name", type=str, default="Llama3-8B-Chinese-Chat", help="大模型名称") - parse.add_argument("--port", type=str, default=8080, help="web后端端口") - - args = parse.parse_args() - - try: - # 离线构建知识库,首先注册文档处理器 - loader_mng = LoaderMng() - # 加载文档加载器,可以使用mxrag自有的,也可以使用langchain的 - loader_mng.register_loader(loader_class=TextLoader, file_types=[".txt", ".md"]) - loader_mng.register_loader(loader_class=DocxLoader, file_types=[".docx"]) - loader_mng.register_loader(loader_class=PdfLoader, file_types=[".pdf"]) - - # 加载文档切分器,使用langchain的 - loader_mng.register_splitter(splitter_class=RecursiveCharacterTextSplitter, - file_types=[".docx", ".txt", ".md", ".pdf"], - splitter_params={"chunk_size": 750, - "chunk_overlap": 150, - "keep_separator": False - } - ) - - # 设置向量检索使用的npu卡,具体可以用的卡可执行npu-smi info查询获取 - dev = 0 - - # 初始化向量数据库 - vector_store = MindFAISS(x_dim=args.embedding_dim, - similarity_strategy=SimilarityStrategy.FLAT_L2, - devs=[dev], - load_local_index="./faiss.index", - auto_save=True - ) - # 初始化文档chunk关系数据库 - chunk_store = SQLiteDocstore(db_path="./sql.db") - # 初始化知识管理关系数据库 - knowledge_store = KnowledgeStore(db_path="./sql.db") - # 初始化知识库管理 - text_knowledge_db = KnowledgeDB(knowledge_store=knowledge_store, - chunk_store=chunk_store, - vector_store=vector_store, - knowledge_name="test", - white_paths=["/tmp"] - ) - - # 加载embedding模型,请根据模型具体路径适配 - text_emb = TextEmbedding(model_path=args.embedding_path, dev_id=dev) - - # Step2在线问题答复,初始化检索器 - 
text_retriever = Retriever(vector_store=vector_store, - document_store=chunk_store, - embed_func=text_emb.embed_documents, - k=1, - score_threshold=0.4 - ) - # 配置text生成text大模型chain,具体ip端口请根据实际情况适配修改 - llm = Text2TextLLM(base_url=args.llm_url, model_name=args.model_name, client_param=ClientParam(use_http=True)) - text2text_chain = SingleText2TextChain(llm=llm, - retriever=text_retriever, - ) - create_gradio(int(args.port)) - except Exception as e: - logger.info(f"exception happed :{e}") + create_gradio(port) -- Gitee From 01ec67c2a0a8da930d0867ac111d47f4b2ed08f4 Mon Sep 17 00:00:00 2001 From: tanfeng <823018000@qq.com> Date: Mon, 25 Nov 2024 11:07:33 +0800 Subject: [PATCH 2/6] =?UTF-8?q?=E9=80=82=E9=85=8D=E6=9C=80=E6=96=B0POC?= =?UTF-8?q?=E7=89=88=E6=9C=AC=E5=B1=95=E7=A4=BA=E9=A1=B5=E9=9D=A2=EF=BC=8C?= =?UTF-8?q?=E5=8A=9F=E8=83=BD=E5=A2=9E=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mxRAG/PocValidation/chat_with_ascend/app.py | 731 ++++++++++---------- 1 file changed, 359 insertions(+), 372 deletions(-) diff --git a/mxRAG/PocValidation/chat_with_ascend/app.py b/mxRAG/PocValidation/chat_with_ascend/app.py index d5805bfb8..5d6dea138 100644 --- a/mxRAG/PocValidation/chat_with_ascend/app.py +++ b/mxRAG/PocValidation/chat_with_ascend/app.py @@ -24,77 +24,148 @@ from mx_rag.document.loader import DocxLoader, PdfLoader, ExcelLoader from mx_rag.knowledge import KnowledgeStore, KnowledgeDB, KnowledgeMgrStore, KnowledgeMgr from mx_rag.knowledge.handler import upload_files, LoaderMng - -class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter): - def _get_default_metavar_for_optional(self, action): - return action.type.__name__ - - def _get_default_metavar_for_positional(self, action): - return action.type.__name__ - - -parse = argparse.ArgumentParser(formatter_class=CustomFormatter) -parse.add_argument("--embedding_path", type=str, default="/home/data/acge_text_embedding", help="embedding模型本地路径") -parse.add_argument("--tei_emb", type=bool, default=False, help="是否使用TEI服务化的embedding模型") -parse.add_argument("--embedding_url", type=str, default="http://127.0.0.1:8080/embed",help="使用TEI服务化的embedding模型url地址") -parse.add_argument("--embedding_dim", type=int, default=1024, help="embedding模型向量维度") -parse.add_argument("--llm_url", type=str, default="http://127.0.0.1:1025/v1/chat/completions", help="大模型url地址") -parse.add_argument("--model_name", type=str, default="Llama3-8B-Chinese-Chat", help="大模型名称") -parse.add_argument("--tei_reranker", type=bool, default=False, help="是否使用TEI服务化的reranker模型") -parse.add_argument("--reranker_path", type=str, default=None, help="reranker模型本地路径") -parse.add_argument("--reranker_url", type=str, default=None, help="使用TEI服务化的embedding模型url地址") -parse.add_argument("--dev", type=int, default=0, help="使用的npu卡,可通过执行npu-smi info获取") -parse.add_argument("--port", type=int, default=8080, help="web后端端口") - -args = parse.parse_args().__dict__ -tei_emb: bool = args.pop('tei_emb') -embedding_path: str = args.pop('embedding_path') -embedding_url: str = args.pop('embedding_url') -embedding_dim: int = args.pop('embedding_dim') -llm_url: str = args.pop('llm_url') -model_name: str = args.pop('model_name') -tei_reranker: bool = args.pop('tei_reranker') -reranker_path: str = args.pop('reranker_path') -reranker_url: str = args.pop('reranker_url') -dev: int = args.pop('dev') -port: int = args.pop('port') - -# 初始化向量数据库 -vector_store = MindFAISS(x_dim=embedding_dim, - similarity_strategy=SimilarityStrategy.FLAT_L2, - devs=[dev], - 
load_local_index="./test_faiss.index", - auto_save=True - ) -# 初始化文档chunk关系数据库 -chunk_store = SQLiteDocstore(db_path="./test_sql.db") -# 初始化知识管理关系数据库 -knowledge_store = KnowledgeStore(db_path="./test_sql.db") -# 初始化知识库管理 -text_knowledge_db = KnowledgeDB(knowledge_store=knowledge_store, - chunk_store=chunk_store, - vector_store=vector_store, - knowledge_name="test", - white_paths=["/tmp"] - ) +#初始化数据库管理 knowledge_mgr = KnowledgeMgr(KnowledgeMgrStore("./test_sql.db")) -# 配置text生成text大模型chain,具体ip端口请根据实际情况适配修改 -llm = Text2TextLLM(base_url=llm_url, model_name=model_name, client_param=ClientParam(use_http=True)) -# 配置embedding模型,请根据模型具体路径适配 -if tei_emb: - text_emb = TEIEmbedding(url=embedding_url, client_param=ClientParam(use_http=True)) -else: - text_emb = TextEmbedding(model_path=embedding_path, dev_id=dev) +# 创建新的知识库 +def creat_knowledge_db(knowledge_name: str, is_regist: bool = True): + knowledge_name = get_knowledge_rename(knowledge_name) + index_name, db_name = get_db_names(knowledge_name) + vector_store_cls = MindFAISS(x_dim=embedding_dim, + similarity_strategy=SimilarityStrategy.FLAT_L2, + devs=[dev], + load_local_index=index_name, + auto_save=True) + # 初始化文档chunk关系数据库 + chunk_store_cls = SQLiteDocstore(db_path=db_name) + # 初始化知识管理关系数据库 + knowledge_store_cls = KnowledgeStore(db_path=db_name) + # 初始化知识库管理 + knowledge_db_cls = KnowledgeDB(knowledge_store=knowledge_store_cls, + chunk_store=chunk_store_cls, + vector_store=vector_store_cls, + knowledge_name=knowledge_name, + white_paths=["/tmp"]) + if is_regist: + knowledge_mgr.register(knowledge_db_cls) + return knowledge_db_cls -# 配置rerank,请根据模型具体路径适配 -if tei_reranker: - reranker = TEIReranker(url=reranker_url, client_param=ClientParam(use_http=True)) -elif reranker_path is not None: - reranker = LocalReranker(model_path=reranker_path, dev_id=dev) -else: - reranker = None + +# 创建检索器 +def creat_retriever(knowledge_name: str, top_k, score_threshold): + knowledge_name = get_knowledge_rename(knowledge_name) + index_name, db_name = get_db_names(knowledge_name) + vector_store_ret = MindFAISS(x_dim=embedding_dim, + similarity_strategy=SimilarityStrategy.FLAT_L2, + devs=[dev], + load_local_index=index_name, + auto_save=True) + # 初始化文档chunk关系数据库 + chunk_store_ret = SQLiteDocstore(db_path=db_name) + retriever_cls = Retriever(vector_store=vector_store_ret, + document_store=chunk_store_ret, + embed_func=text_emb.embed_documents, + k=top_k, + score_threshold=score_threshold) + return retriever_cls + + +# 删除知识库 +def delete_knowledge_db(knowledge_name: str): + knowledge_name = get_knowledge_rename(knowledge_name) + knowledge_db_lst = knowledge_mgr.get_all() + if knowledge_name in knowledge_db_lst: + knowledge_db_cls = creat_knowledge_db(knowledge_name, is_regist=False) + clear_document(knowledge_name) + knowledge_mgr.delete(knowledge_db_cls) + return get_knowledge_db() + + +# 获取知识库中文档列表 +def get_document(knowledge_name: str): + knowledge_name = get_knowledge_rename(knowledge_name) + knowledge_db_cls = creat_knowledge_db(knowledge_name, is_regist=False) + doc_names = knowledge_db_cls.get_all_documents() + return knowledge_name, doc_names, len(doc_names) + + +# 清空知识库中文档列表 +def clear_document(knowledge_name: str): + knowledge_name, doc_names, doc_cnt = get_document(knowledge_name) + if doc_cnt > 0: + knowledge_db_cls = creat_knowledge_db(knowledge_name, is_regist=False) + for doc_name in doc_names: + knowledge_db_cls.delete_file(doc_name) + return get_document(knowledge_name) + else: + return knowledge_name, doc_names, doc_cnt + + +# 获取知识库列表 +def 
get_knowledge_db(): + knowledge_db_lst = knowledge_mgr.get_all() + return knowledge_db_lst, len(knowledge_db_lst) + + +# 上传知识库 +def file_upload(files, + knowledge_db_name: str = 'test', + chunk_size: int = 750, + chunk_overlap: int = 150 + ): + save_file_path = "/tmp/document_files" + knowledge_db_name = get_knowledge_rename(knowledge_db_name) + # 指定保存文件的文件夹 + if not os.path.exists(save_file_path): + os.makedirs(save_file_path) + if files is None or len(files) == 0: + print('no file need save') + if knowledge_db_name in knowledge_mgr.get_all(): + knowledge_db_cls = creat_knowledge_db(knowledge_db_name, is_regist=False) + else: + knowledge_db_cls = creat_knowledge_db(knowledge_db_name) + # 注册文档加载器 + loader_mng = LoaderMng() + # 加载文档加载器,可以使用mxrag自有的,也可以使用langchain的 + loader_mng.register_loader(loader_class=TextLoader, file_types=[".txt", ".md"]) + loader_mng.register_loader(loader_class=DocxLoader, file_types=[".docx"]) + loader_mng.register_loader(loader_class=PdfLoader, file_types=[".pdf"]) + loader_mng.register_loader(loader_class=ExcelLoader, file_types=[".xlsx", ".xls"]) + + # 加载文档切分器,使用langchain的 + loader_mng.register_splitter(splitter_class=RecursiveCharacterTextSplitter, + file_types=[".docx", ".txt", ".md", ".pdf", ".xlsx", ".xls"], + splitter_params={"chunk_size": chunk_size, + "chunk_overlap": chunk_overlap, + "keep_separator": False + }) + for file in files: + try: + # 上传领域知识文档 + shutil.copy(file.name, save_file_path) + # 知识库:chunk\embedding\add + upload_files(knowledge_db_cls, [file.name], loader_mng=loader_mng, embed_func=text_emb.embed_documents, force=True) + print(f"file {file.name} save to {save_file_path}.") + except Exception as err: + logger.error(f"save failed, find exception: {err}") + + +def file_change(files, upload_btn): + print("file changes") + + +def get_db_names(knowledge_name: str): + index_name = "./" + knowledge_name + "_faiss.index" + db_name = "./" + knowledge_name + "_sql.db" + return index_name, db_name + + +def get_knowledge_rename(knowledge_name: str): + if knowledge_name is None or len(knowledge_name) == 0 or ' ' in knowledge_name: + return 'test' + else: + return knowledge_name # 历史问题改写 @@ -187,8 +258,9 @@ def bot_response(history, if len(q_docs) > 0: history_last = '' for i ,source in enumerate(q_docs): - sources = "\n====检索信息来源" + str(i + 1) + ":" + source['metadata']['source'] + "====" + "\n" + \ - "--->参考内容:" +source['page_content'] + "\n" + sources = "\n====检索信息来源:" + str(i + 1) + ":" + "[数据库名]:" + str(knowledge_name) + \ + "[文件名]:" + source['metadata']['source'] + "====" + "\n" + \ + "参考内容:" +source['page_content'] + "\n" history_last += sources history += [[None, history_last]] yield history, history_r @@ -202,8 +274,9 @@ def bot_response(history, if len(q_docs) > 0: history_r_last = '' for i ,source in enumerate(q_docs): - sources = "\n====检索信息来源" + str(i + 1) + ":" + source['metadata']['source'] + "====" + "\n" + \ - "--->参考内容:" +source['page_content'] + "\n" + sources = "\n====检索信息来源:" + str(i + 1) + ":" + "[数据库名]:" + str(knowledge_name) + \ + "[文件名]:" + source['metadata']['source'] + "====" + "\n" + \ + "参考内容:" +source['page_content'] + "\n" history_r_last += sources history_r += [[history[-1][0], history_r_last]] yield history, history_r @@ -224,8 +297,9 @@ def re_response(history_r, if len(q_docs) > 0: history_r_last = '' for i, source in enumerate(q_docs): - sources = "\n====检索信息来源" + str(i + 1) + ":" + source.metadata['source'] + "====" + "\n" + \ - "--->参考内容:" + source.page_content + "\n" + sources = "\n====检索信息来源:" + str(i + 1) + 
"[数据库名]:" + str(knowledge_name) +\ + "[文件名]:" + source.metadata['source'] + "====" + "\n" + \ + "参考内容:" + source.page_content + "\n" history_r_last += sources history_r[-1][1] = history_r_last else: @@ -247,307 +321,220 @@ def clear_history(history): return [] -# 创建新的知识库 -def creat_knowledge_db(knowledge_name: str, is_regist: bool=True): - knowledge_name = get_knowledge_rename(knowledge_name) - index_name, db_name = get_db_names(knowledge_name) - vector_store_cls = MindFAISS(x_dim=embedding_dim, - similarity_strategy=SimilarityStrategy.FLAT_L2, - devs=[dev], - load_local_index=index_name, - auto_save=True) - # 初始化文档chunk关系数据库 - chunk_store_cls = SQLiteDocstore(db_path=db_name) - # 初始化知识管理关系数据库 - knowledge_store_cls = KnowledgeStore(db_path=db_name) - # 初始化知识库管理 - knowledge_db_cls = KnowledgeDB(knowledge_store=knowledge_store_cls, - chunk_store=chunk_store_cls, - vector_store=vector_store_cls, - knowledge_name=knowledge_name, - white_paths=["/tmp"]) - if is_regist: - knowledge_mgr.register(knowledge_db_cls) - return knowledge_db_cls - - -def creat_retriever(knowledge_name: str, top_k, score_threshold): - knowledge_name = get_knowledge_rename(knowledge_name) - index_name, db_name = get_db_names(knowledge_name) - vector_store_ret = MindFAISS(x_dim=embedding_dim, - similarity_strategy=SimilarityStrategy.FLAT_L2, - devs=[dev], - load_local_index=index_name, - auto_save=True) - # 初始化文档chunk关系数据库 - chunk_store_ret = SQLiteDocstore(db_path=db_name) - retriever_cls = Retriever(vector_store=vector_store_ret, - document_store=chunk_store_ret, - embed_func=text_emb.embed_documents, - k=top_k, - score_threshold=score_threshold) - return retriever_cls - - -# 删除知识库 -def delete_knowledge_db(knowledge_name: str): - knowledge_name = get_knowledge_rename(knowledge_name) - knowledge_db_lst = knowledge_mgr.get_all() - if knowledge_name in knowledge_db_lst: - knowledge_db_cls = creat_knowledge_db(knowledge_name, is_regist=False) - clear_document(knowledge_name) - knowledge_mgr.delete(knowledge_db_cls) - return get_knowledge_db() - - -# 获取知识库文档列表 -def get_document(knowledge_name: str): - knowledge_name = get_knowledge_rename(knowledge_name) - knowledge_db_cls = creat_knowledge_db(knowledge_name, is_regist=False) - doc_names = knowledge_db_cls.get_all_documents() - return knowledge_name, doc_names, len(doc_names) - - -# 清空知识库文档列表 -def clear_document(knowledge_name: str): - knowledge_name, doc_names, doc_cnt = get_document(knowledge_name) - if doc_cnt > 0: - knowledge_db_cls = creat_knowledge_db(knowledge_name, is_regist=False) - for doc_name in doc_names: - knowledge_db_cls.delete_file(doc_name) - return get_document(knowledge_name) - else: - return knowledge_name, doc_names, doc_cnt - - -# 获取知识库列表 -def get_knowledge_db(): - knowledge_db_lst = knowledge_mgr.get_all() - return knowledge_db_lst, len(knowledge_db_lst) - - -def get_db_names(knowledge_name: str): - index_name = "./" + knowledge_name + "_faiss.index" - dn_name = "./" + knowledge_name + "_sql.db" - return index_name, dn_name - - -def get_knowledge_rename(knowledge_name: str): - if knowledge_name is None or len(knowledge_name) == 0 or ' ' in knowledge_name: - return 'test' +if __name__ == '__main__': + class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter): + def _get_default_metavar_for_optional(self, action): + return action.type.__name__ + + def _get_default_metavar_for_positional(self, action): + return action.type.__name__ + + parse = argparse.ArgumentParser(formatter_class=CustomFormatter) + parse.add_argument("--embedding_path", type=str, 
default="/home/data/acge_text_embedding", + help="embedding模型本地路径") + parse.add_argument("--tei_emb", type=bool, default=False, help="是否使用TEI服务化的embedding模型") + parse.add_argument("--embedding_url", type=str, default="http://127.0.0.1:8080/embed", + help="使用TEI服务化的embedding模型url地址") + parse.add_argument("--embedding_dim", type=int, default=1024, help="embedding模型向量维度") + parse.add_argument("--llm_url", type=str, default="http://127.0.0.1:1025/v1/chat/completions", help="大模型url地址") + parse.add_argument("--model_name", type=str, default="Llama3-8B-Chinese-Chat", help="大模型名称") + parse.add_argument("--tei_reranker", type=bool, default=False, help="是否使用TEI服务化的reranker模型") + parse.add_argument("--reranker_path", type=str, default=None, help="reranker模型本地路径") + parse.add_argument("--reranker_url", type=str, default=None, help="使用TEI服务化的embedding模型url地址") + parse.add_argument("--dev", type=int, default=0, help="使用的npu卡,可通过执行npu-smi info获取") + parse.add_argument("--port", type=int, default=8080, help="web后端端口") + + args = parse.parse_args().__dict__ + embedding_path: str = args.pop('embedding_path') + tei_emb: bool = args.pop('tei_emb') + embedding_url: str = args.pop('embedding_url') + embedding_dim: int = args.pop('embedding_dim') + llm_url: str = args.pop('llm_url') + model_name: str = args.pop('model_name') + tei_reranker: bool = args.pop('tei_reranker') + reranker_path: str = args.pop('reranker_path') + reranker_url: str = args.pop('reranker_url') + dev: int = args.pop('dev') + port: int = args.pop('port') + + # 初始化test数据库 + knowledge_db = creat_knowledge_db('test') + # 配置text生成text大模型chain,具体ip端口请根据实际情况适配修改 + llm = Text2TextLLM(base_url=llm_url, model_name=model_name, client_param=ClientParam(use_http=True)) + # 配置embedding模型,请根据模型具体路径适配 + if tei_emb: + text_emb = TEIEmbedding(url=embedding_url, client_param=ClientParam(use_http=True)) else: - return knowledge_name - - -# 上传知识库 -def file_upload(files, - knowledge_db_name: str = 'test', - chunk_size: int = 750, - chunk_overlap: int = 150 - ): - save_file_path = "/tmp/document_files" - knowledge_db_name = get_knowledge_rename(knowledge_db_name) - # 指定保存文件的文件夹 - if not os.path.exists(save_file_path): - os.makedirs(save_file_path) - if files is None or len(files) == 0: - print('no file need save') - if knowledge_db_name in knowledge_mgr.get_all(): - knowledge_db_cls = creat_knowledge_db(knowledge_db_name, is_regist=False) + text_emb = TextEmbedding(model_path=embedding_path, dev_id=dev) + # 配置rerank,请根据模型具体路径适配 + if tei_reranker: + reranker = TEIReranker(url=reranker_url, client_param=ClientParam(use_http=True)) + elif reranker_path is not None: + reranker = LocalReranker(model_path=reranker_path, dev_id=dev) else: - knowledge_db_cls = creat_knowledge_db(knowledge_db_name) - # 注册文档加载器 - loader_mng = LoaderMng() - # 加载文档加载器,可以使用mxrag自有的,也可以使用langchain的 - loader_mng.register_loader(loader_class=TextLoader, file_types=[".txt", ".md"]) - loader_mng.register_loader(loader_class=DocxLoader, file_types=[".docx"]) - loader_mng.register_loader(loader_class=PdfLoader, file_types=[".pdf"]) - loader_mng.register_loader(loader_class=ExcelLoader, file_types=[".xlsx", ".xls"]) - - # 加载文档切分器,使用langchain的 - loader_mng.register_splitter(splitter_class=RecursiveCharacterTextSplitter, - file_types=[".docx", ".txt", ".md", ".pdf", ".xlsx", ".xls"], - splitter_params={"chunk_size": chunk_size, - "chunk_overlap": chunk_overlap, - "keep_separator": False - }) - for file in files: - try: - # 上传领域知识文档 - shutil.copy(file.name, save_file_path) - # 知识库:chunk\embedding\add - 
upload_files(knowledge_db_cls, [file.name], loader_mng=loader_mng, embed_func=text_emb.embed_documents, force=True) - print(f"file {file.name} save to {save_file_path}.") - except Exception as err: - logger.error(f"save failed, find exception: {err}") - - -def file_change(files, upload_btn): - print("file changes") - - -# 构建gradio框架 -def build_demo(): - with gr.Blocks() as demo: - gr.HTML("
检索增强生成(RAG)对话
powered by MindX RAG
") - with gr.Row(): - with gr.Column(scale=100): - with gr.Row(): - model_select = gr.Dropdown(choices=["Default"], value="Default", container=False, interactive=True) - with gr.Row(): - files = gr.Files( - height=300, - file_count="multiple", - file_types=[".docx", ".txt", ".md", ".pdf", ".xlsx", ".xls"], - interactive=True, - label="上传知识库文档" - ) - with gr.Row(): - upload_btn = gr.Button("上传文件") - with gr.Row(): - with gr.TabItem(label='知识库情况'): - set_knowledge_name = gr.Textbox(label='设置当前知识库',placeholder='在此输入知识库名称,默认使用test知识库') - with gr.Row(): - creat_knowledge_btn = gr.Button('创建知识库') - delete_knowledge_btn = gr.Button('删除知识库') - knowledge_name_output = gr.Textbox(label='知识库列表') - knowledge_number_output = gr.Textbox(label='知识库数量') - with gr.Row(): - show_knowledge_btn = gr.Button('显示知识库') - with gr.TabItem(label='文件情况'): - knowledge_name = gr.Textbox('知识库名称') - knowledge_file_output = gr.Textbox('知识库文件列表') - knowledge_file_num_output = gr.Textbox('知识库文件数量') - with gr.Row(): - knowledge_file_out_btn = gr.Button('显示文件情况') - knowledge_clear_btn = gr.Button('清空知识库') - with gr.Row(): - with gr.According(label='文档切分参数设置',open=False): - chunk_size = gr.Slider( - minimum=50, - maximum=5000, - value=750, - step=50, + reranker = None + + # 构建gradio框架 + def build_demo(): + with (gr.Blocks() as demo): + gr.HTML("
检索增强生成(RAG)对话
powered by MindX RAG
") + with gr.Row(): + with gr.Column(scale=100): + with gr.Row(): + model_select = gr.Dropdown(choices=["Default"], value="Default", container=False, interactive=True) + with gr.Row(): + files = gr.Files( + height=300, + file_count="multiple", + file_types=[".docx", ".txt", ".md", ".pdf", ".xlsx", ".xls"], interactive=True, - label="chunk_size", - info="文本切分长度" + label="上传知识库文档" ) - chunk_overlap = gr.Slider( - minimum=10, - maximum=500, - value=150, - step=10, - interactive=True, - label="chunk_overlap", - info="文本切分填充长度" - ) - with gr.Row(): - with gr.According(label='大模型参数设置', open=False): - temperature = gr.Slider( - minimum=0.01, - maximum=2, - value=0.5, - step=0.01, - interactive=True, - label="温度", - info="Token生成的随机性" - ) - top_p = gr.Slider( - minimum=0.01, - maximum=1, - value=0.95, - step=0.05, - interactive=True, - label="Top P", - info="累计概率总和阈值" - ) - max_tokens = gr.Slider( - minimum=100, - maximum=1024, - value=512, - step=1, - interactive=True, - label="最大tokens", - info="输入+输出最多的tokens数" - ) - is_rewrite = gr.Radio(['是','否'], value='否', label="是否根据历史提问重写问题?") - history_n = gr.Slider( - minimum=1, - maximum=10, - value=5, - step=1, - interactive=True, - label="历史提问重新轮数", - info="问题重写时所参考的历史提问轮数" - ) - with gr.Row(): - with gr.According(label='检索参数设置', open=False): - score_threshold = gr.Slider( - minimum=0, - maximum=1, - value=0.5, - step=0.01, - interactive=True, - label="score_threshold", - info="相似性检索阈值,值越大表示越相关,低于阈值不会被返回。" - ) - top_k = gr.Slider( - minimum=1, - maximum=5, - value=1, - step=1, - interactive=True, - label="top_k", - info="相似性检索返回条数" - ) - show_type = gr.Radio(['对话结束后显示','检索框单独显示','不显示'], value="对话结束后显示", label="知识库文档匹配结果展示方式选择") - with gr.Column(scale=200): - with gr.Tabs(): - with gr.TabItem("对话窗口"): - chat_type = gr.Radio(['RAG检索增强对话','仅大模型对话'], value="RAG检索增强对话", label="请选择对话模式?") - chatbot = gr.Chatbot(height=550) - with gr.Row(): - msg = gr.Textbox(placeholder="在此输入问题...", container=False) - with gr.Row(): - send_btn = gr.Button(value="发送", variant="primary") - clean_btn = gr.Button(value="清空历史") - with gr.TabItem("检索窗口"): - chatbot_r = gr.Chatbot(height=550) - with gr.Row(): - msg_r = gr.Textbox(placeholder="在此输入问题...", container=False) - with gr.Row(): - send_btn_r = gr.Button(value="文档检索", variant="primary") - clean_btn_r = gr.Button(value="清空历史") - # RAG发送 - send_btn.click(user_query, [msg, chatbot], [msg, chatbot], queue=False - ).then(bot_response, - [chatbot, chatbot_r, max_tokens, temperature, top_p, history_n, score_threshold, top_k, chat_type, show_type, is_rewrite, set_knowledge_name], - [chatbot, chatbot_r]) - # RAG清除历史 - clean_btn.click(clear_history, chatbot, chatbot) - # 上传文件 - files.change(file_change, [], []) - upload_btn.click(file_upload, [files, set_knowledge_name, chunk_size, chunk_overlap], files) - # 管理所有知识库 - creat_knowledge_btn.click(creat_knowledge_db, [set_knowledge_name],[]).then(get_knowledge_db, [],[knowledge_name_output, knowledge_number_output]) - show_knowledge_btn.click(get_knowledge_db, [], [knowledge_name_output, knowledge_number_output]) - delete_knowledge_btn.click(delete_knowledge_db, [set_knowledge_name], [knowledge_name_output, knowledge_number_output]) - # 管理知识库里文件 - knowledge_file_out_btn.click(get_document, [set_knowledge_name], [knowledge_name, knowledge_file_output, knowledge_file_num_output]) - knowledge_clear_btn.click(clear_document, [set_knowledge_name], [knowledge_name, knowledge_file_output, knowledge_file_num_output]) - # 检索发送 - send_btn_r.click(user_retriever, [msg_r, chatbot_r], [msg_r, 
chatbot_r], queue=False - ).then(re_response, [chatbot_r, score_threshold, top_k, set_knowledge_name], chatbot_r) - # 检索清除历史 - clean_btn_r.click(clear_history, chatbot_r, chatbot_r) - return demo - - -def create_gradio(ports): - demo = build_demo() - demo.queue() - demo.launch(share=True, server_name="0.0.0.0", server_port=ports) - - -if __name__ == '__main__': + with gr.Row(): + upload_btn = gr.Button("上传文件") + with gr.Row(): + with gr.TabItem(label='知识库情况'): + set_knowledge_name = gr.Textbox(label='设置当前知识库',placeholder='在此输入知识库名称,默认使用test知识库') + with gr.Row(): + creat_knowledge_btn = gr.Button('创建知识库') + delete_knowledge_btn = gr.Button('删除知识库') + knowledge_name_output = gr.Textbox(label='知识库列表') + knowledge_number_output = gr.Textbox(label='知识库数量') + with gr.Row(): + show_knowledge_btn = gr.Button('显示知识库') + with gr.TabItem(label='文件情况'): + knowledge_name = gr.Textbox('知识库名称') + knowledge_file_output = gr.Textbox('知识库文件列表') + knowledge_file_num_output = gr.Textbox('知识库文件数量') + with gr.Row(): + knowledge_file_out_btn = gr.Button('显示文件情况') + knowledge_clear_btn = gr.Button('清空知识库') + with gr.Row(): + with gr.According(label='文档切分参数设置',open=False): + chunk_size = gr.Slider( + minimum=50, + maximum=5000, + value=750, + step=50, + interactive=True, + label="chunk_size", + info="文本切分长度" + ) + chunk_overlap = gr.Slider( + minimum=10, + maximum=500, + value=150, + step=10, + interactive=True, + label="chunk_overlap", + info="文本切分填充长度" + ) + with gr.Row(): + with gr.According(label='大模型参数设置', open=False): + temperature = gr.Slider( + minimum=0.01, + maximum=2, + value=0.5, + step=0.01, + interactive=True, + label="温度", + info="Token生成的随机性" + ) + top_p = gr.Slider( + minimum=0.01, + maximum=1, + value=0.95, + step=0.05, + interactive=True, + label="Top P", + info="累计概率总和阈值" + ) + max_tokens = gr.Slider( + minimum=100, + maximum=1024, + value=512, + step=1, + interactive=True, + label="最大tokens", + info="输入+输出最多的tokens数" + ) + is_rewrite = gr.Radio(['是','否'], value='否', label="是否根据历史提问重写问题?") + history_n = gr.Slider( + minimum=1, + maximum=10, + value=5, + step=1, + interactive=True, + label="历史提问重新轮数", + info="问题重写时所参考的历史提问轮数" + ) + with gr.Row(): + with gr.According(label='检索参数设置', open=False): + score_threshold = gr.Slider( + minimum=0, + maximum=1, + value=0.5, + step=0.01, + interactive=True, + label="score_threshold", + info="相似性检索阈值,值越大表示越相关,低于阈值不会被返回。" + ) + top_k = gr.Slider( + minimum=1, + maximum=5, + value=1, + step=1, + interactive=True, + label="top_k", + info="相似性检索返回条数" + ) + show_type = gr.Radio(['对话结束后显示','检索框单独显示','不显示'], value="对话结束后显示", label="知识库文档匹配结果展示方式选择") + with gr.Column(scale=200): + with gr.Tabs(): + with gr.TabItem("对话窗口"): + chat_type = gr.Radio(['RAG检索增强对话','仅大模型对话'], value="RAG检索增强对话", label="请选择对话模式?") + chatbot = gr.Chatbot(height=550) + with gr.Row(): + msg = gr.Textbox(placeholder="在此输入问题...", container=False) + with gr.Row(): + send_btn = gr.Button(value="发送", variant="primary") + clean_btn = gr.Button(value="清空历史") + with gr.TabItem("检索窗口"): + chatbot_r = gr.Chatbot(height=550) + with gr.Row(): + msg_r = gr.Textbox(placeholder="在此输入问题...", container=False) + with gr.Row(): + send_btn_r = gr.Button(value="文档检索", variant="primary") + clean_btn_r = gr.Button(value="清空历史") + # RAG发送 + send_btn.click(user_query, [msg, chatbot], [msg, chatbot], queue=False + ).then(bot_response, + [chatbot, chatbot_r, max_tokens, temperature, top_p, history_n, score_threshold, top_k, chat_type, show_type, is_rewrite, set_knowledge_name], + [chatbot, chatbot_r]) + # RAG清除历史 + 
clean_btn.click(clear_history, chatbot, chatbot) + # 上传文件 + files.change(file_change, [], []) + upload_btn.click(file_upload, [files, set_knowledge_name, chunk_size, chunk_overlap], files) + # 管理所有知识库 + creat_knowledge_btn.click(creat_knowledge_db, [set_knowledge_name],[]).then(get_knowledge_db, [],[knowledge_name_output, knowledge_number_output]) + show_knowledge_btn.click(get_knowledge_db, [], [knowledge_name_output, knowledge_number_output]) + delete_knowledge_btn.click(delete_knowledge_db, [set_knowledge_name], [knowledge_name_output, knowledge_number_output]) + # 管理知识库里文件 + knowledge_file_out_btn.click(get_document, [set_knowledge_name], [knowledge_name, knowledge_file_output, knowledge_file_num_output]) + knowledge_clear_btn.click(clear_document, [set_knowledge_name], [knowledge_name, knowledge_file_output, knowledge_file_num_output]) + # 检索发送 + send_btn_r.click(user_retriever, [msg_r, chatbot_r], [msg_r, chatbot_r], queue=False + ).then(re_response, [chatbot_r, score_threshold, top_k, set_knowledge_name], chatbot_r) + # 检索清除历史 + clean_btn_r.click(clear_history, chatbot_r, chatbot_r) + return demo + + + def create_gradio(ports): + demo = build_demo() + demo.queue() + demo.launch(share=True, server_name="0.0.0.0", server_port=ports) + + # 启动gradio create_gradio(port) -- Gitee From 92a719ec317bf677315665fd781f6e717ab82dc7 Mon Sep 17 00:00:00 2001 From: tanfeng <823018000@qq.com> Date: Mon, 25 Nov 2024 11:52:18 +0800 Subject: [PATCH 3/6] =?UTF-8?q?=E9=80=82=E9=85=8D=E6=9C=80=E6=96=B0POC?= =?UTF-8?q?=E7=89=88=E6=9C=AC=E5=B1=95=E7=A4=BA=E9=A1=B5=E9=9D=A2=EF=BC=8C?= =?UTF-8?q?=E5=8A=9F=E8=83=BD=E5=A2=9E=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mxRAG/PocValidation/chat_with_ascend/app.py | 68 ++++++++++----------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/mxRAG/PocValidation/chat_with_ascend/app.py b/mxRAG/PocValidation/chat_with_ascend/app.py index 5d6dea138..4eb7c03cc 100644 --- a/mxRAG/PocValidation/chat_with_ascend/app.py +++ b/mxRAG/PocValidation/chat_with_ascend/app.py @@ -24,7 +24,7 @@ from mx_rag.document.loader import DocxLoader, PdfLoader, ExcelLoader from mx_rag.knowledge import KnowledgeStore, KnowledgeDB, KnowledgeMgrStore, KnowledgeMgr from mx_rag.knowledge.handler import upload_files, LoaderMng -#初始化数据库管理 +# 初始化数据库管理 knowledge_mgr = KnowledgeMgr(KnowledgeMgrStore("./test_sql.db")) @@ -125,7 +125,7 @@ def file_upload(files, knowledge_db_cls = creat_knowledge_db(knowledge_db_name, is_regist=False) else: knowledge_db_cls = creat_knowledge_db(knowledge_db_name) - # 注册文档加载器 + # 注册文档处理器 loader_mng = LoaderMng() # 加载文档加载器,可以使用mxrag自有的,也可以使用langchain的 loader_mng.register_loader(loader_class=TextLoader, file_types=[".txt", ".md"]) @@ -207,8 +207,8 @@ def bot_response(history, top_k: int = 1, chat_type: str = "RAG检索增强对话", show_type: str = "不显示", - is_rewrite:str = "否", - knowledge_name: str = 'test', + is_rewrite: str = "否", + knowledge_name: str = 'test' ): chat_type_mapping = {"RAG检索增强对话":1, "仅大模型对话":0} @@ -218,15 +218,15 @@ def bot_response(history, is_rewrite_mapping = {"是":1, "否":0} # 初始化检索器 - retrieve_cls = creat_retriever(knowledge_name, top_k, score_threshold) + retriever_cls = creat_retriever(knowledge_name, top_k, score_threshold) # 初始化chain - text2text_chain = SingleText2TextChain(llm=llm, retriever=retrieve_cls, reranker=reranker) + text2text_chain = SingleText2TextChain(llm=llm, retriever=retriever_cls, reranker=reranker) # 历史问题改写 if is_rewrite_mapping.get(is_rewrite) == 1: history = 
generate_question(history, llm, history_n) history[-1][1] = '推理错误' try: - # 仅使用大模型 + # 仅使用大模型回答 if chat_type_mapping.get(chat_type) == 0: response = llm.chat_streamly(query=history[-1][0], llm_config=LLMParameterConfig(max_tokens=max_tokens, @@ -257,10 +257,10 @@ def bot_response(history, # 有检索到信息 if len(q_docs) > 0: history_last = '' - for i ,source in enumerate(q_docs): - sources = "\n====检索信息来源:" + str(i + 1) + ":" + "[数据库名]:" + str(knowledge_name) + \ + for i, source in enumerate(q_docs): + sources = "\n====检索信息来源:" + str(i + 1) + "[数据库名]:" + str(knowledge_name) + \ "[文件名]:" + source['metadata']['source'] + "====" + "\n" + \ - "参考内容:" +source['page_content'] + "\n" + "参考内容:" + source['page_content'] + "\n" history_last += sources history += [[None, history_last]] yield history, history_r @@ -274,9 +274,9 @@ def bot_response(history, if len(q_docs) > 0: history_r_last = '' for i ,source in enumerate(q_docs): - sources = "\n====检索信息来源:" + str(i + 1) + ":" + "[数据库名]:" + str(knowledge_name) + \ + sources = "\n====检索信息来源:" + str(i + 1) + "[数据库名]:" + str(knowledge_name) + \ "[文件名]:" + source['metadata']['source'] + "====" + "\n" + \ - "参考内容:" +source['page_content'] + "\n" + "参考内容:" + source['page_content'] + "\n" history_r_last += sources history_r += [[history[-1][0], history_r_last]] yield history, history_r @@ -292,12 +292,12 @@ def re_response(history_r, knowledge_name: str = 'test' ): # 初始化检索器 - retrieve_cls = creat_retriever(knowledge_name, top_k, score_threshold) - q_docs = retrieve_cls.invoke(history_r[-1][0]) + retriever_cls = creat_retriever(knowledge_name, top_k, score_threshold) + q_docs = retriever_cls.invoke(history_r[-1][0]) if len(q_docs) > 0: history_r_last = '' for i, source in enumerate(q_docs): - sources = "\n====检索信息来源:" + str(i + 1) + "[数据库名]:" + str(knowledge_name) +\ + sources = "\n====检索信息来源:" + str(i + 1) + "[数据库名]:" + str(knowledge_name) + \ "[文件名]:" + source.metadata['source'] + "====" + "\n" + \ "参考内容:" + source.page_content + "\n" history_r_last += sources @@ -309,7 +309,7 @@ def re_response(history_r, # 检索信息 def user_retriever(user_message, history_r): - return "",history_r + [[user_message, None]] + return "", history_r + [[user_message, None]] # 聊天信息 @@ -366,7 +366,7 @@ if __name__ == '__main__': text_emb = TEIEmbedding(url=embedding_url, client_param=ClientParam(use_http=True)) else: text_emb = TextEmbedding(model_path=embedding_path, dev_id=dev) - # 配置rerank,请根据模型具体路径适配 + # 配置reranker,请根据模型具体路径适配 if tei_reranker: reranker = TEIReranker(url=reranker_url, client_param=ClientParam(use_http=True)) elif reranker_path is not None: @@ -381,10 +381,10 @@ if __name__ == '__main__': with gr.Row(): with gr.Column(scale=100): with gr.Row(): - model_select = gr.Dropdown(choices=["Default"], value="Default", container=False, interactive=True) + model_select = gr.Dropdown(choices=[model_name], value="Llama3-8B-Chinese-Chat", container=False, interactive=True) with gr.Row(): - files = gr.Files( - height=300, + files = gr.components.File( + height=100, file_count="multiple", file_types=[".docx", ".txt", ".md", ".pdf", ".xlsx", ".xls"], interactive=True, @@ -393,8 +393,8 @@ if __name__ == '__main__': with gr.Row(): upload_btn = gr.Button("上传文件") with gr.Row(): - with gr.TabItem(label='知识库情况'): - set_knowledge_name = gr.Textbox(label='设置当前知识库',placeholder='在此输入知识库名称,默认使用test知识库') + with gr.TabItem("知识库情况"): + set_knowledge_name = gr.Textbox(label='设置当前知识库', placeholder="在此输入知识库名称,默认使用test知识库") with gr.Row(): creat_knowledge_btn = gr.Button('创建知识库') delete_knowledge_btn = 
gr.Button('删除知识库') @@ -402,15 +402,15 @@ if __name__ == '__main__': knowledge_number_output = gr.Textbox(label='知识库数量') with gr.Row(): show_knowledge_btn = gr.Button('显示知识库') - with gr.TabItem(label='文件情况'): - knowledge_name = gr.Textbox('知识库名称') - knowledge_file_output = gr.Textbox('知识库文件列表') - knowledge_file_num_output = gr.Textbox('知识库文件数量') + with gr.TabItem("文件情况"): + knowledge_name = gr.Textbox(label='知识库名称') + knowledge_file_output = gr.Textbox(label='知识库文件列表') + knowledge_file_num_output = gr.Textbox(label='知识库文件数量') with gr.Row(): knowledge_file_out_btn = gr.Button('显示文件情况') knowledge_clear_btn = gr.Button('清空知识库') with gr.Row(): - with gr.According(label='文档切分参数设置',open=False): + with gr.Accordion(label='文档切分参数设置', open=False): chunk_size = gr.Slider( minimum=50, maximum=5000, @@ -430,7 +430,7 @@ if __name__ == '__main__': info="文本切分填充长度" ) with gr.Row(): - with gr.According(label='大模型参数设置', open=False): + with gr.Accordion(label='大模型参数设置', open=False): temperature = gr.Slider( minimum=0.01, maximum=2, @@ -458,14 +458,14 @@ if __name__ == '__main__': label="最大tokens", info="输入+输出最多的tokens数" ) - is_rewrite = gr.Radio(['是','否'], value='否', label="是否根据历史提问重写问题?") + is_rewrite = gr.Radio(['是', '否'], value="否", label="是否根据历史提问重写问题?") history_n = gr.Slider( minimum=1, maximum=10, value=5, step=1, interactive=True, - label="历史提问重新轮数", + label="历史提问重写轮数", info="问题重写时所参考的历史提问轮数" ) with gr.Row(): @@ -477,7 +477,7 @@ if __name__ == '__main__': step=0.01, interactive=True, label="score_threshold", - info="相似性检索阈值,值越大表示越相关,低于阈值不会被返回。" + info="相似性检索阈值,值越大表示越相关,低于阈值不会被返回。" ) top_k = gr.Slider( minimum=1, @@ -488,11 +488,11 @@ if __name__ == '__main__': label="top_k", info="相似性检索返回条数" ) - show_type = gr.Radio(['对话结束后显示','检索框单独显示','不显示'], value="对话结束后显示", label="知识库文档匹配结果展示方式选择") + show_type = gr.Radio(['对话结束后显示', '检索框单独显示', '不显示'], value="对话结束后显示", label="知识库文档匹配结果展示方式选择") with gr.Column(scale=200): with gr.Tabs(): with gr.TabItem("对话窗口"): - chat_type = gr.Radio(['RAG检索增强对话','仅大模型对话'], value="RAG检索增强对话", label="请选择对话模式?") + chat_type = gr.Radio(['RAG检索增强对话', '仅大模型对话'], value="RAG检索增强对话", label="请选择对话模式?") chatbot = gr.Chatbot(height=550) with gr.Row(): msg = gr.Textbox(placeholder="在此输入问题...", container=False) @@ -517,7 +517,7 @@ if __name__ == '__main__': files.change(file_change, [], []) upload_btn.click(file_upload, [files, set_knowledge_name, chunk_size, chunk_overlap], files) # 管理所有知识库 - creat_knowledge_btn.click(creat_knowledge_db, [set_knowledge_name],[]).then(get_knowledge_db, [],[knowledge_name_output, knowledge_number_output]) + creat_knowledge_btn.click(creat_knowledge_db, [set_knowledge_name], []).then(get_knowledge_db, [], [knowledge_name_output, knowledge_number_output]) show_knowledge_btn.click(get_knowledge_db, [], [knowledge_name_output, knowledge_number_output]) delete_knowledge_btn.click(delete_knowledge_db, [set_knowledge_name], [knowledge_name_output, knowledge_number_output]) # 管理知识库里文件 -- Gitee From 3b002c8d0dd2a8b452425f7dc60d8b0d170e7b10 Mon Sep 17 00:00:00 2001 From: tanfeng <823018000@qq.com> Date: Mon, 25 Nov 2024 12:02:25 +0800 Subject: [PATCH 4/6] =?UTF-8?q?=E9=80=82=E9=85=8D=E6=9C=80=E6=96=B0POC?= =?UTF-8?q?=E7=89=88=E6=9C=AC=E5=B1=95=E7=A4=BA=E9=A1=B5=E9=9D=A2=EF=BC=8C?= =?UTF-8?q?=E5=8A=9F=E8=83=BD=E5=A2=9E=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mxRAG/PocValidation/chat_with_ascend/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mxRAG/PocValidation/chat_with_ascend/app.py 
b/mxRAG/PocValidation/chat_with_ascend/app.py index 4eb7c03cc..53414ec6c 100644 --- a/mxRAG/PocValidation/chat_with_ascend/app.py +++ b/mxRAG/PocValidation/chat_with_ascend/app.py @@ -469,7 +469,7 @@ if __name__ == '__main__': info="问题重写时所参考的历史提问轮数" ) with gr.Row(): - with gr.According(label='检索参数设置', open=False): + with gr.Accordion(label='检索参数设置', open=False): score_threshold = gr.Slider( minimum=0, maximum=1, -- Gitee From d63937f97b13c0bc7f9f88b02afe25b9dddafcad Mon Sep 17 00:00:00 2001 From: tanfeng <823018000@qq.com> Date: Mon, 25 Nov 2024 12:09:27 +0800 Subject: [PATCH 5/6] =?UTF-8?q?=E9=80=82=E9=85=8D=E6=9C=80=E6=96=B0POC?= =?UTF-8?q?=E7=89=88=E6=9C=AC=E5=B1=95=E7=A4=BA=E9=A1=B5=E9=9D=A2=EF=BC=8C?= =?UTF-8?q?=E5=8A=9F=E8=83=BD=E5=A2=9E=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mxRAG/PocValidation/chat_with_ascend/app.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mxRAG/PocValidation/chat_with_ascend/app.py b/mxRAG/PocValidation/chat_with_ascend/app.py index 53414ec6c..cd4d6c720 100644 --- a/mxRAG/PocValidation/chat_with_ascend/app.py +++ b/mxRAG/PocValidation/chat_with_ascend/app.py @@ -218,6 +218,7 @@ def bot_response(history, is_rewrite_mapping = {"是":1, "否":0} # 初始化检索器 + knowledge_name = get_knowledge_rename(knowledge_name) retriever_cls = creat_retriever(knowledge_name, top_k, score_threshold) # 初始化chain text2text_chain = SingleText2TextChain(llm=llm, retriever=retriever_cls, reranker=reranker) -- Gitee From 3f99338c0c819b0154429f8d2cdc22dacba7ddd6 Mon Sep 17 00:00:00 2001 From: tanfeng <823018000@qq.com> Date: Mon, 25 Nov 2024 12:42:28 +0800 Subject: [PATCH 6/6] =?UTF-8?q?=E9=80=82=E9=85=8D=E6=9C=80=E6=96=B0POC?= =?UTF-8?q?=E7=89=88=E6=9C=AC=E5=B1=95=E7=A4=BA=E9=A1=B5=E9=9D=A2=EF=BC=8C?= =?UTF-8?q?=E5=8A=9F=E8=83=BD=E5=A2=9E=E5=8A=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mxRAG/PocValidation/chat_with_ascend/app.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mxRAG/PocValidation/chat_with_ascend/app.py b/mxRAG/PocValidation/chat_with_ascend/app.py index cd4d6c720..1d23eb9a3 100644 --- a/mxRAG/PocValidation/chat_with_ascend/app.py +++ b/mxRAG/PocValidation/chat_with_ascend/app.py @@ -274,7 +274,7 @@ def bot_response(history, # 有检索到信息 if len(q_docs) > 0: history_r_last = '' - for i ,source in enumerate(q_docs): + for i, source in enumerate(q_docs): sources = "\n====检索信息来源:" + str(i + 1) + "[数据库名]:" + str(knowledge_name) + \ "[文件名]:" + source['metadata']['source'] + "====" + "\n" + \ "参考内容:" + source['page_content'] + "\n" @@ -293,6 +293,7 @@ def re_response(history_r, knowledge_name: str = 'test' ): # 初始化检索器 + knowledge_name = get_knowledge_rename(knowledge_name) retriever_cls = creat_retriever(knowledge_name, top_k, score_threshold) q_docs = retriever_cls.invoke(history_r[-1][0]) if len(q_docs) > 0: -- Gitee
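
Taken together, the six patches settle on a per-knowledge-base storage layout (one <name>_faiss.index plus one <name>_sql.db file per knowledge base), with the retriever rebuilt from those files on every request. The following sketch restates that flow outside the Gradio UI, composed only from calls that appear in the patches above. It assumes the mx_rag package is installed and an embedding model exists at the patch's default path; the MindFAISS import path is an assumption (its import line is not visible in the diff), and the knowledge-base name, device id, and query are illustrative.

from mx_rag.embedding.local import TextEmbedding        # as imported in app.py
from mx_rag.retrievers import Retriever
from mx_rag.storage.document_store import SQLiteDocstore
from mx_rag.storage.vectorstore import MindFAISS        # assumed import path
from mx_rag.storage.vectorstore.vectorstore import SimilarityStrategy


def kb_paths(name: str):
    # One FAISS index and one SQLite file per knowledge base,
    # mirroring get_db_names() in the patch.
    return f"./{name}_faiss.index", f"./{name}_sql.db"


def build_retriever(name: str, text_emb, dev: int = 0,
                    embedding_dim: int = 1024, top_k: int = 1,
                    score_threshold: float = 0.5) -> Retriever:
    # Rebuild the vector store and chunk store from the per-knowledge-base
    # files, then wrap them in a Retriever, as creat_retriever() does above.
    index_path, db_path = kb_paths(name)
    vector_store = MindFAISS(x_dim=embedding_dim,
                             similarity_strategy=SimilarityStrategy.FLAT_L2,
                             devs=[dev],
                             load_local_index=index_path,
                             auto_save=True)
    chunk_store = SQLiteDocstore(db_path=db_path)
    return Retriever(vector_store=vector_store,
                     document_store=chunk_store,
                     embed_func=text_emb.embed_documents,
                     k=top_k,
                     score_threshold=score_threshold)


if __name__ == "__main__":
    # Illustrative values: model path and NPU id follow the patch defaults.
    text_emb = TextEmbedding(model_path="/home/data/acge_text_embedding", dev_id=0)
    retriever = build_retriever("test", text_emb)
    for doc in retriever.invoke("What is retrieval-augmented generation?"):
        # Retriever.invoke returns documents carrying the source file in
        # metadata, as used by re_response() above.
        print(doc.metadata["source"], doc.page_content[:80])

Rebuilding MindFAISS and SQLiteDocstore on every request keeps the UI handlers stateless across knowledge bases, at the cost of reopening the index each time; if index load time matters, caching one retriever per knowledge-base name would be the natural refinement.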