From bf6e4e3b42c39616e746024744bb90373260cc0b Mon Sep 17 00:00:00 2001
From: tanfeng <823018000@qq.com>
Date: Fri, 22 Nov 2024 16:57:08 +0800
Subject: [PATCH 1/6] Adapt the display page to the latest POC version and add features
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
mxRAG/PocValidation/chat_with_ascend/app.py | 558 ++++++++++++++++----
1 file changed, 441 insertions(+), 117 deletions(-)
diff --git a/mxRAG/PocValidation/chat_with_ascend/app.py b/mxRAG/PocValidation/chat_with_ascend/app.py
index 83d1f33d2..d5805bfb8 100644
--- a/mxRAG/PocValidation/chat_with_ascend/app.py
+++ b/mxRAG/PocValidation/chat_with_ascend/app.py
@@ -9,6 +9,9 @@ from loguru import logger
from paddle.base import libpaddle
from mx_rag.chain import SingleText2TextChain
from mx_rag.embedding.local import TextEmbedding
+from mx_rag.embedding.service import TEIEmbedding
+from mx_rag.reranker.local import LocalReranker
+from mx_rag.reranker.service import TEIReranker
from mx_rag.llm import Text2TextLLM, LLMParameterConfig
from mx_rag.retrievers import Retriever
from mx_rag.storage.document_store import SQLiteDocstore
@@ -17,39 +20,225 @@ from mx_rag.storage.vectorstore.vectorstore import SimilarityStrategy
from mx_rag.utils import ClientParam
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
-from mx_rag.document.loader import DocxLoader, PdfLoader
-from mx_rag.knowledge import KnowledgeDB
-from mx_rag.knowledge.knowledge import KnowledgeStore
+from mx_rag.document.loader import DocxLoader, PdfLoader, ExcelLoader
+from mx_rag.knowledge import KnowledgeStore, KnowledgeDB, KnowledgeMgrStore, KnowledgeMgr
from mx_rag.knowledge.handler import upload_files, LoaderMng
-global text2text_chain
-global text_knowledge_db
-global loader_mng
-global text_emb
+class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter):
+ def _get_default_metavar_for_optional(self, action):
+ return action.type.__name__
+
+ def _get_default_metavar_for_positional(self, action):
+ return action.type.__name__
+
+
+parse = argparse.ArgumentParser(formatter_class=CustomFormatter)
+parse.add_argument("--embedding_path", type=str, default="/home/data/acge_text_embedding", help="embedding模型本地路径")
+parse.add_argument("--tei_emb", type=bool, default=False, help="是否使用TEI服务化的embedding模型")
+parse.add_argument("--embedding_url", type=str, default="http://127.0.0.1:8080/embed",help="使用TEI服务化的embedding模型url地址")
+parse.add_argument("--embedding_dim", type=int, default=1024, help="embedding模型向量维度")
+parse.add_argument("--llm_url", type=str, default="http://127.0.0.1:1025/v1/chat/completions", help="大模型url地址")
+parse.add_argument("--model_name", type=str, default="Llama3-8B-Chinese-Chat", help="大模型名称")
+parse.add_argument("--tei_reranker", type=bool, default=False, help="是否使用TEI服务化的reranker模型")
+parse.add_argument("--reranker_path", type=str, default=None, help="reranker模型本地路径")
+parse.add_argument("--reranker_url", type=str, default=None, help="使用TEI服务化的embedding模型url地址")
+parse.add_argument("--dev", type=int, default=0, help="使用的npu卡,可通过执行npu-smi info获取")
+parse.add_argument("--port", type=int, default=8080, help="web后端端口")
+
+args = parse.parse_args().__dict__
+tei_emb: bool = args.pop('tei_emb')
+embedding_path: str = args.pop('embedding_path')
+embedding_url: str = args.pop('embedding_url')
+embedding_dim: int = args.pop('embedding_dim')
+llm_url: str = args.pop('llm_url')
+model_name: str = args.pop('model_name')
+tei_reranker: bool = args.pop('tei_reranker')
+reranker_path: str = args.pop('reranker_path')
+reranker_url: str = args.pop('reranker_url')
+dev: int = args.pop('dev')
+port: int = args.pop('port')
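Note: the boolean flags above are fragile. argparse's type=bool applies bool() to the raw string, and any non-empty string, including "False", is truthy, so --tei_emb False still enables TEI mode. A minimal converter sketch, assuming the flags stay string-valued (str2bool is an illustrative name, not part of the patch):

    def str2bool(value: str) -> bool:
        # Map common spellings to booleans; reject anything ambiguous.
        if value.lower() in ("yes", "true", "t", "1"):
            return True
        if value.lower() in ("no", "false", "f", "0"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")

    # e.g. parse.add_argument("--tei_emb", type=str2bool, default=False, ...)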
+# 初始化向量数据库
+vector_store = MindFAISS(x_dim=embedding_dim,
+ similarity_strategy=SimilarityStrategy.FLAT_L2,
+ devs=[dev],
+ load_local_index="./test_faiss.index",
+ auto_save=True
+ )
+# 初始化文档chunk关系数据库
+chunk_store = SQLiteDocstore(db_path="./test_sql.db")
+# 初始化知识管理关系数据库
+knowledge_store = KnowledgeStore(db_path="./test_sql.db")
+# 初始化知识库管理
+text_knowledge_db = KnowledgeDB(knowledge_store=knowledge_store,
+ chunk_store=chunk_store,
+ vector_store=vector_store,
+ knowledge_name="test",
+ white_paths=["/tmp"]
+ )
+knowledge_mgr = KnowledgeMgr(KnowledgeMgrStore("./test_sql.db"))
+
+# 配置text生成text大模型chain,具体ip端口请根据实际情况适配修改
+llm = Text2TextLLM(base_url=llm_url, model_name=model_name, client_param=ClientParam(use_http=True))
+
+# 配置embedding模型,请根据模型具体路径适配
+if tei_emb:
+ text_emb = TEIEmbedding(url=embedding_url, client_param=ClientParam(use_http=True))
+else:
+ text_emb = TextEmbedding(model_path=embedding_path, dev_id=dev)
+
+# 配置rerank,请根据模型具体路径适配
+if tei_reranker:
+ reranker = TEIReranker(url=reranker_url, client_param=ClientParam(use_http=True))
+elif reranker_path is not None:
+ reranker = LocalReranker(model_path=reranker_path, dev_id=dev)
+else:
+ reranker = None
+
+
+# 历史问题改写
+def generate_question(history, llm, history_n: int = 5):
+ prompt = """现在你有一个上下文依赖问题补全任务,任务要求:请根据对话历史和用户当前的问句,重写问句。\n
+ 历史问题依次是:\n
+ {}\n
+ 用户当前的问句:\n
+ {}\n
+ 注意如果当前问题不依赖历史问题直接返回none即可\n
+ 请根据上述对话历史重写用户当前的问句,仅输出重写后的问句,不需要附加任何分析。\n
+ 重写问句: \n
+ """
+ if len(history) <= 2:
+ return history
+ cur_query = history[-1][0]
+ history_qa = history[0:-1]
+ history_list = [f"第{idx + 1}轮:{q_a[0]}" for idx, q_a in enumerate(history_qa) if q_a[0] is not None]
+ history_list = history_list[-history_n:]  # keep the most recent rounds
+ history_str = "\n\n".join(history_list)
+ re_query = prompt.format(history_str, cur_query)
+ new_query = llm.chat(query=re_query, llm_config=LLMParameterConfig(max_tokens=512,
+ temperature=0.5,
+ top_p=0.95))
+ if new_query != "none":
+ history[-1][0] = "原始问题: " + cur_query
+ history += [[new_query, None]]
+ return history
+
+
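A usage sketch for generate_question above, using a stubbed LLM (Text2TextLLM.chat is assumed to return the rewritten question as a plain string; the stub and sample turns are illustrative only):

    class StubLLM:
        def chat(self, query, llm_config=None):
            return "和去年同期相比,第二季度营收增长了多少?"  # pretend rewrite

    history = [["第一季度营收是多少?", "10亿元"],
               ["第二季度呢?", "12亿元"],
               ["和去年同期比呢?", None]]
    history = generate_question(history, StubLLM())
    # The current turn is relabeled "原始问题: 和去年同期比呢?" and the
    # rewritten question is appended as a new [question, None] turn.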
+# 聊天对话框
def bot_response(history,
+ history_r,
max_tokens: int = 512,
temperature: float = 0.5,
- top_p: float = 0.95
+ top_p: float = 0.95,
+ history_n: int = 5,
+ score_threshold: float = 0.5,
+ top_k: int = 1,
+ chat_type: str = "RAG检索增强对话",
+ show_type: str = "不显示",
+ is_rewrite:str = "否",
+ knowledge_name: str = 'test',
):
- # 将最新的问题传给RAG
+ chat_type_mapping = {"RAG检索增强对话":1,
+ "仅大模型对话":0}
+ show_type_mapping = {"对话结束后显示":1,
+ "检索框单独显示":2,
+ "不显示":0}
+ is_rewrite_mapping = {"是":1,
+ "否":0}
+ # 初始化检索器
+ retrieve_cls = creat_retriever(knowledge_name, top_k, score_threshold)
+ # 初始化chain
+ text2text_chain = SingleText2TextChain(llm=llm, retriever=retrieve_cls, reranker=reranker)
+ # 历史问题改写
+ if is_rewrite_mapping.get(is_rewrite) == 1:
+ history = generate_question(history, llm, history_n)
+ history[-1][1] = '推理错误'
try:
- response = text2text_chain.query(history[-1][0],
- LLMParameterConfig(max_tokens=max_tokens, temperature=temperature, top_p=top_p,
- stream=True))
- # 返回迭代器
- history[-1][1] = '推理错误'
- for res in response:
- history[-1][1] = '推理错误' if res['result'] is None else res['result']
- yield history
- yield history
+ # 仅使用大模型
+ if chat_type_mapping.get(chat_type) == 0:
+ response = llm.chat_streamly(query=history[-1][0],
+ llm_config=LLMParameterConfig(max_tokens=max_tokens,
+ temperature=temperature,
+ top_p=top_p))
+ # 返回迭代器
+ for res in response:
+ history[-1][1] = res
+ yield history, history_r
+ # 使用RAG增强回答
+ elif chat_type_mapping.get(chat_type) == 1:
+ response = text2text_chain.query(history[-1][0],
+ LLMParameterConfig(max_tokens=max_tokens,
+ temperature=temperature,
+ top_p=top_p,
+ stream=True))
+ # 不展示检索内容
+ if show_type_mapping.get(show_type) == 0:
+ for res in response:
+ history[-1][1] = '推理错误' if res['result'] is None else res['result']
+ yield history, history_r
+ # 问答结尾展示
+ elif show_type_mapping.get(show_type) == 1:
+ for res in response:
+ history[-1][1] = '推理错误' if res['result'] is None else res['result']
+ q_docs = res['source_documents']
+ yield history, history_r
+ # 有检索到信息
+ if len(q_docs) > 0:
+ history_last = ''
+ for i ,source in enumerate(q_docs):
+ sources = "\n====检索信息来源" + str(i + 1) + ":" + source['metadata']['source'] + "====" + "\n" + \
+ "--->参考内容:" +source['page_content'] + "\n"
+ history_last += sources
+ history += [[None, history_last]]
+ yield history, history_r
+ # 检索窗口展示
+ else:
+ for res in response:
+ history[-1][1] = '推理错误' if res['result'] is None else res['result']
+ q_docs = res['source_documents']
+ yield history, history_r
+ # 有检索到信息
+ if len(q_docs) > 0:
+ history_r_last = ''
+ for i ,source in enumerate(q_docs):
+ sources = "\n====检索信息来源" + str(i + 1) + ":" + source['metadata']['source'] + "====" + "\n" + \
+ "--->参考内容:" +source['page_content'] + "\n"
+ history_r_last += sources
+ history_r += [[history[-1][0], history_r_last]]
+ yield history, history_r
except Exception as err:
logger.info(f"query failed, find exception: {err}")
- history[-1][1] = "推理错误"
- yield history
+ yield history, history_r
+
+
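bot_response above is a generator: every yield hands the updated (history, history_r) pair back to Gradio, which re-renders both chatbot components, so partial output appears as it streams in. A minimal, purely illustrative sketch of that contract:

    def stream_demo(history, history_r):
        history[-1][1] = ""
        for token in ["你", "好", "!"]:
            history[-1][1] += token
            yield history, history_r  # each yield triggers a UI refresh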
+# 检索对话框
+def re_response(history_r,
+ score_threshold: float = 0.5,
+ top_k: int = 1,
+ knowledge_name: str = 'test'
+ ):
+ # 初始化检索器
+ retrieve_cls = creat_retriever(knowledge_name, top_k, score_threshold)
+ q_docs = retrieve_cls.invoke(history_r[-1][0])
+ if len(q_docs) > 0:
+ history_r_last = ''
+ for i, source in enumerate(q_docs):
+ sources = "\n====检索信息来源" + str(i + 1) + ":" + source.metadata['source'] + "====" + "\n" + \
+ "--->参考内容:" + source.page_content + "\n"
+ history_r_last += sources
+ history_r[-1][1] = history_r_last
+ else:
+ history_r[-1][1] = "未检索到相关信息"
+ return history_r
+# 检索信息
+def user_retriever(user_message, history_r):
+ return "",history_r + [[user_message, None]]
+
+
+# 聊天信息
def user_query(user_message, history):
return "", history + [[user_message, None]]
@@ -58,23 +247,137 @@ def clear_history(history):
return []
-SAVE_FILE_PATH = "/tmp/document_files"
+# 创建新的知识库
+def creat_knowledge_db(knowledge_name: str, is_regist: bool=True):
+ knowledge_name = get_knowledge_rename(knowledge_name)
+ index_name, db_name = get_db_names(knowledge_name)
+ vector_store_cls = MindFAISS(x_dim=embedding_dim,
+ similarity_strategy=SimilarityStrategy.FLAT_L2,
+ devs=[dev],
+ load_local_index=index_name,
+ auto_save=True)
+ # 初始化文档chunk关系数据库
+ chunk_store_cls = SQLiteDocstore(db_path=db_name)
+ # 初始化知识管理关系数据库
+ knowledge_store_cls = KnowledgeStore(db_path=db_name)
+ # 初始化知识库管理
+ knowledge_db_cls = KnowledgeDB(knowledge_store=knowledge_store_cls,
+ chunk_store=chunk_store_cls,
+ vector_store=vector_store_cls,
+ knowledge_name=knowledge_name,
+ white_paths=["/tmp"])
+ if is_regist:
+ knowledge_mgr.register(knowledge_db_cls)
+ return knowledge_db_cls
+
+
+def creat_retriever(knowledge_name: str, top_k, score_threshold):
+ knowledge_name = get_knowledge_rename(knowledge_name)
+ index_name, db_name = get_db_names(knowledge_name)
+ vector_store_ret = MindFAISS(x_dim=embedding_dim,
+ similarity_strategy=SimilarityStrategy.FLAT_L2,
+ devs=[dev],
+ load_local_index=index_name,
+ auto_save=True)
+ # 初始化文档chunk关系数据库
+ chunk_store_ret = SQLiteDocstore(db_path=db_name)
+ retriever_cls = Retriever(vector_store=vector_store_ret,
+ document_store=chunk_store_ret,
+ embed_func=text_emb.embed_documents,
+ k=top_k,
+ score_threshold=score_threshold)
+ return retriever_cls
+
+
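A usage sketch for creat_retriever above; invoke() and the document fields match what re_response relies on (the query string is illustrative):

    retriever = creat_retriever("test", top_k=3, score_threshold=0.5)
    for doc in retriever.invoke("昇腾NPU支持哪些型号?"):
        # Each hit carries its source file and the matched chunk text.
        print(doc.metadata['source'], doc.page_content[:80])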
+# 删除知识库
+def delete_knowledge_db(knowledge_name: str):
+ knowledge_name = get_knowledge_rename(knowledge_name)
+ knowledge_db_lst = knowledge_mgr.get_all()
+ if knowledge_name in knowledge_db_lst:
+ knowledge_db_cls = creat_knowledge_db(knowledge_name, is_regist=False)
+ clear_document(knowledge_name)
+ knowledge_mgr.delete(knowledge_db_cls)
+ return get_knowledge_db()
+
+
+# 获取知识库文档列表
+def get_document(knowledge_name: str):
+ knowledge_name = get_knowledge_rename(knowledge_name)
+ knowledge_db_cls = creat_knowledge_db(knowledge_name, is_regist=False)
+ doc_names = knowledge_db_cls.get_all_documents()
+ return knowledge_name, doc_names, len(doc_names)
+
+
+# 清空知识库文档列表
+def clear_document(knowledge_name: str):
+ knowledge_name, doc_names, doc_cnt = get_document(knowledge_name)
+ if doc_cnt > 0:
+ knowledge_db_cls = creat_knowledge_db(knowledge_name, is_regist=False)
+ for doc_name in doc_names:
+ knowledge_db_cls.delete_file(doc_name)
+ return get_document(knowledge_name)
+ else:
+ return knowledge_name, doc_names, doc_cnt
+
+
+# 获取知识库列表
+def get_knowledge_db():
+ knowledge_db_lst = knowledge_mgr.get_all()
+ return knowledge_db_lst, len(knowledge_db_lst)
+
+
+def get_db_names(knowledge_name: str):
+ index_name = "./" + knowledge_name + "_faiss.index"
+ dn_name = "./" + knowledge_name + "_sql.db"
+ return index_name, dn_name
-def file_upload(files):
+def get_knowledge_rename(knowledge_name: str):
+ if knowledge_name is None or len(knowledge_name) == 0 or ' ' in knowledge_name:
+ return 'test'
+ else:
+ return knowledge_name
+
+
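For reference, the two naming helpers above map a knowledge base name to its on-disk stores and fall back to the default 'test' store for empty or space-containing names:

    assert get_db_names("demo") == ("./demo_faiss.index", "./demo_sql.db")
    assert get_knowledge_rename("") == 'test'
    assert get_knowledge_rename("my kb") == 'test'  # spaces also fall back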
+# 上传知识库
+def file_upload(files,
+ knowledge_db_name: str = 'test',
+ chunk_size: int = 750,
+ chunk_overlap: int = 150
+ ):
+ save_file_path = "/tmp/document_files"
+ knowledge_db_name = get_knowledge_rename(knowledge_db_name)
# 指定保存文件的文件夹
- if not os.path.exists(SAVE_FILE_PATH):
- os.makedirs(SAVE_FILE_PATH)
+ if not os.path.exists(save_file_path):
+ os.makedirs(save_file_path)
if files is None or len(files) == 0:
print('no file need save')
+ if knowledge_db_name in knowledge_mgr.get_all():
+ knowledge_db_cls = creat_knowledge_db(knowledge_db_name, is_regist=False)
+ else:
+ knowledge_db_cls = creat_knowledge_db(knowledge_db_name)
+ # 注册文档加载器
+ loader_mng = LoaderMng()
+ # 加载文档加载器,可以使用mxrag自有的,也可以使用langchain的
+ loader_mng.register_loader(loader_class=TextLoader, file_types=[".txt", ".md"])
+ loader_mng.register_loader(loader_class=DocxLoader, file_types=[".docx"])
+ loader_mng.register_loader(loader_class=PdfLoader, file_types=[".pdf"])
+ loader_mng.register_loader(loader_class=ExcelLoader, file_types=[".xlsx", ".xls"])
+
+ # 加载文档切分器,使用langchain的
+ loader_mng.register_splitter(splitter_class=RecursiveCharacterTextSplitter,
+ file_types=[".docx", ".txt", ".md", ".pdf", ".xlsx", ".xls"],
+ splitter_params={"chunk_size": chunk_size,
+ "chunk_overlap": chunk_overlap,
+ "keep_separator": False
+ })
for file in files:
try:
# 上传领域知识文档
- shutil.copy(file.name, SAVE_FILE_PATH)
+ shutil.copy(file.name, save_file_path)
# 知识库:chunk\embedding\add
- upload_files(text_knowledge_db, [file.name], loader_mng=loader_mng, embed_func=text_emb.embed_documents,
- force=True)
- print(f"file {file.name} save to {SAVE_FILE_PATH}.")
+ upload_files(knowledge_db_cls, [file.name], loader_mng=loader_mng, embed_func=text_emb.embed_documents, force=True)
+ print(f"file {file.name} save to {save_file_path}.")
except Exception as err:
logger.error(f"save failed, find exception: {err}")
@@ -83,6 +386,7 @@ def file_change(files, upload_btn):
print("file changes")
+# 构建gradio框架
def build_demo():
with gr.Blocks() as demo:
gr.HTML("
检索增强生成(RAG)对话
powered by MindX RAG
")
@@ -94,14 +398,51 @@ def build_demo():
files = gr.Files(
height=300,
file_count="multiple",
- file_types=[".docx", ".txt", ".md", ".pdf"],
+ file_types=[".docx", ".txt", ".md", ".pdf", ".xlsx", ".xls"],
interactive=True,
label="上传知识库文档"
)
with gr.Row():
upload_btn = gr.Button("上传文件")
with gr.Row():
- with gr.Accordion(label='知识库情况'):
+ with gr.TabItem(label='知识库情况'):
+ set_knowledge_name = gr.Textbox(label='设置当前知识库',placeholder='在此输入知识库名称,默认使用test知识库')
+ with gr.Row():
+ creat_knowledge_btn = gr.Button('创建知识库')
+ delete_knowledge_btn = gr.Button('删除知识库')
+ knowledge_name_output = gr.Textbox(label='知识库列表')
+ knowledge_number_output = gr.Textbox(label='知识库数量')
+ with gr.Row():
+ show_knowledge_btn = gr.Button('显示知识库')
+ with gr.TabItem(label='文件情况'):
+ knowledge_name = gr.Textbox('知识库名称')
+ knowledge_file_output = gr.Textbox('知识库文件列表')
+ knowledge_file_num_output = gr.Textbox('知识库文件数量')
+ with gr.Row():
+ knowledge_file_out_btn = gr.Button('显示文件情况')
+ knowledge_clear_btn = gr.Button('清空知识库')
+ with gr.Row():
+ with gr.According(label='文档切分参数设置',open=False):
+ chunk_size = gr.Slider(
+ minimum=50,
+ maximum=5000,
+ value=750,
+ step=50,
+ interactive=True,
+ label="chunk_size",
+ info="文本切分长度"
+ )
+ chunk_overlap = gr.Slider(
+ minimum=10,
+ maximum=500,
+ value=150,
+ step=10,
+ interactive=True,
+ label="chunk_overlap",
+ info="文本切分填充长度"
+ )
+ with gr.Row():
+ with gr.According(label='大模型参数设置', open=False):
temperature = gr.Slider(
minimum=0.01,
maximum=2,
@@ -129,101 +470,84 @@ def build_demo():
label="最大tokens",
info="输入+输出最多的tokens数"
)
- with gr.Column(scale=200):
- chatbot = gr.Chatbot(height=500)
- with gr.Row():
- msg = gr.Textbox(placeholder="在此输入问题...", container=False)
+ is_rewrite = gr.Radio(['是','否'], value='否', label="是否根据历史提问重写问题?")
+ history_n = gr.Slider(
+ minimum=1,
+ maximum=10,
+ value=5,
+ step=1,
+ interactive=True,
+ label="历史提问重新轮数",
+ info="问题重写时所参考的历史提问轮数"
+ )
with gr.Row():
- send_btn = gr.Button(value="发送", variant="primary")
- clean_btn = gr.Button(value="清空历史")
- send_btn.click(user_query, [msg, chatbot], [msg, chatbot], queue=False).then(bot_response,
- [chatbot, max_tokens,
- temperature, top_p], chatbot)
+ with gr.According(label='检索参数设置', open=False):
+ score_threshold = gr.Slider(
+ minimum=0,
+ maximum=1,
+ value=0.5,
+ step=0.01,
+ interactive=True,
+ label="score_threshold",
+ info="相似性检索阈值,值越大表示越相关,低于阈值不会被返回。"
+ )
+ top_k = gr.Slider(
+ minimum=1,
+ maximum=5,
+ value=1,
+ step=1,
+ interactive=True,
+ label="top_k",
+ info="相似性检索返回条数"
+ )
+ show_type = gr.Radio(['对话结束后显示','检索框单独显示','不显示'], value="对话结束后显示", label="知识库文档匹配结果展示方式选择")
+ with gr.Column(scale=200):
+ with gr.Tabs():
+ with gr.TabItem("对话窗口"):
+ chat_type = gr.Radio(['RAG检索增强对话','仅大模型对话'], value="RAG检索增强对话", label="请选择对话模式?")
+ chatbot = gr.Chatbot(height=550)
+ with gr.Row():
+ msg = gr.Textbox(placeholder="在此输入问题...", container=False)
+ with gr.Row():
+ send_btn = gr.Button(value="发送", variant="primary")
+ clean_btn = gr.Button(value="清空历史")
+ with gr.TabItem("检索窗口"):
+ chatbot_r = gr.Chatbot(height=550)
+ with gr.Row():
+ msg_r = gr.Textbox(placeholder="在此输入问题...", container=False)
+ with gr.Row():
+ send_btn_r = gr.Button(value="文档检索", variant="primary")
+ clean_btn_r = gr.Button(value="清空历史")
+ # RAG发送
+ send_btn.click(user_query, [msg, chatbot], [msg, chatbot], queue=False
+ ).then(bot_response,
+ [chatbot, chatbot_r, max_tokens, temperature, top_p, history_n, score_threshold, top_k, chat_type, show_type, is_rewrite, set_knowledge_name],
+ [chatbot, chatbot_r])
+ # RAG清除历史
clean_btn.click(clear_history, chatbot, chatbot)
+ # 上传文件
files.change(file_change, [], [])
- upload_btn.click(file_upload, files, []).then(clear_history, files, files)
+ upload_btn.click(file_upload, [files, set_knowledge_name, chunk_size, chunk_overlap], files)
+ # 管理所有知识库
+ creat_knowledge_btn.click(creat_knowledge_db, [set_knowledge_name],[]).then(get_knowledge_db, [],[knowledge_name_output, knowledge_number_output])
+ show_knowledge_btn.click(get_knowledge_db, [], [knowledge_name_output, knowledge_number_output])
+ delete_knowledge_btn.click(delete_knowledge_db, [set_knowledge_name], [knowledge_name_output, knowledge_number_output])
+ # 管理知识库里文件
+ knowledge_file_out_btn.click(get_document, [set_knowledge_name], [knowledge_name, knowledge_file_output, knowledge_file_num_output])
+ knowledge_clear_btn.click(clear_document, [set_knowledge_name], [knowledge_name, knowledge_file_output, knowledge_file_num_output])
+ # 检索发送
+ send_btn_r.click(user_retriever, [msg_r, chatbot_r], [msg_r, chatbot_r], queue=False
+ ).then(re_response, [chatbot_r, score_threshold, top_k, set_knowledge_name], chatbot_r)
+ # 检索清除历史
+ clean_btn_r.click(clear_history, chatbot_r, chatbot_r)
return demo
-def create_gradio(port):
+def create_gradio(ports):
demo = build_demo()
demo.queue()
- demo.launch(share=True, server_name="0.0.0.0", server_port=port)
-
-
-class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter):
- def _get_default_metavar_for_optional(self, action):
- return action.type.__name__
-
- def _get_default_metavar_for_positional(self, action):
- return action.type.__name__
+ demo.launch(share=True, server_name="0.0.0.0", server_port=ports)
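Deployment note on the launch call above: share=True requests a public gradio.live tunnel and server_name="0.0.0.0" binds all interfaces. For a closed PoC environment, a more conservative sketch would be:

    demo.launch(share=False, server_name="127.0.0.1", server_port=ports)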
if __name__ == '__main__':
- parse = argparse.ArgumentParser(formatter_class=CustomFormatter)
- parse.add_argument("--embedding_path", type=str, default="/home/data/acge_text_embedding",
- help="embedding模型本地路径")
- parse.add_argument("--embedding_dim", type=int, default=1024, help="embedding模型向量维度")
- parse.add_argument("--llm_url", type=str, default="http://127.0.0.1:1025/v1/chat/completions", help="大模型url地址")
- parse.add_argument("--model_name", type=str, default="Llama3-8B-Chinese-Chat", help="大模型名称")
- parse.add_argument("--port", type=str, default=8080, help="web后端端口")
-
- args = parse.parse_args()
-
- try:
- # 离线构建知识库,首先注册文档处理器
- loader_mng = LoaderMng()
- # 加载文档加载器,可以使用mxrag自有的,也可以使用langchain的
- loader_mng.register_loader(loader_class=TextLoader, file_types=[".txt", ".md"])
- loader_mng.register_loader(loader_class=DocxLoader, file_types=[".docx"])
- loader_mng.register_loader(loader_class=PdfLoader, file_types=[".pdf"])
-
- # 加载文档切分器,使用langchain的
- loader_mng.register_splitter(splitter_class=RecursiveCharacterTextSplitter,
- file_types=[".docx", ".txt", ".md", ".pdf"],
- splitter_params={"chunk_size": 750,
- "chunk_overlap": 150,
- "keep_separator": False
- }
- )
-
- # 设置向量检索使用的npu卡,具体可以用的卡可执行npu-smi info查询获取
- dev = 0
-
- # 初始化向量数据库
- vector_store = MindFAISS(x_dim=args.embedding_dim,
- similarity_strategy=SimilarityStrategy.FLAT_L2,
- devs=[dev],
- load_local_index="./faiss.index",
- auto_save=True
- )
- # 初始化文档chunk关系数据库
- chunk_store = SQLiteDocstore(db_path="./sql.db")
- # 初始化知识管理关系数据库
- knowledge_store = KnowledgeStore(db_path="./sql.db")
- # 初始化知识库管理
- text_knowledge_db = KnowledgeDB(knowledge_store=knowledge_store,
- chunk_store=chunk_store,
- vector_store=vector_store,
- knowledge_name="test",
- white_paths=["/tmp"]
- )
-
- # 加载embedding模型,请根据模型具体路径适配
- text_emb = TextEmbedding(model_path=args.embedding_path, dev_id=dev)
-
- # Step2在线问题答复,初始化检索器
- text_retriever = Retriever(vector_store=vector_store,
- document_store=chunk_store,
- embed_func=text_emb.embed_documents,
- k=1,
- score_threshold=0.4
- )
- # 配置text生成text大模型chain,具体ip端口请根据实际情况适配修改
- llm = Text2TextLLM(base_url=args.llm_url, model_name=args.model_name, client_param=ClientParam(use_http=True))
- text2text_chain = SingleText2TextChain(llm=llm,
- retriever=text_retriever,
- )
- create_gradio(int(args.port))
- except Exception as e:
- logger.info(f"exception happed :{e}")
+ create_gradio(port)
--
Gitee
From 01ec67c2a0a8da930d0867ac111d47f4b2ed08f4 Mon Sep 17 00:00:00 2001
From: tanfeng <823018000@qq.com>
Date: Mon, 25 Nov 2024 11:07:33 +0800
Subject: [PATCH 2/6] Adapt the display page to the latest POC version and add features
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
mxRAG/PocValidation/chat_with_ascend/app.py | 731 ++++++++++----------
1 file changed, 359 insertions(+), 372 deletions(-)
diff --git a/mxRAG/PocValidation/chat_with_ascend/app.py b/mxRAG/PocValidation/chat_with_ascend/app.py
index d5805bfb8..5d6dea138 100644
--- a/mxRAG/PocValidation/chat_with_ascend/app.py
+++ b/mxRAG/PocValidation/chat_with_ascend/app.py
@@ -24,77 +24,148 @@ from mx_rag.document.loader import DocxLoader, PdfLoader, ExcelLoader
from mx_rag.knowledge import KnowledgeStore, KnowledgeDB, KnowledgeMgrStore, KnowledgeMgr
from mx_rag.knowledge.handler import upload_files, LoaderMng
-
-class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter):
- def _get_default_metavar_for_optional(self, action):
- return action.type.__name__
-
- def _get_default_metavar_for_positional(self, action):
- return action.type.__name__
-
-
-parse = argparse.ArgumentParser(formatter_class=CustomFormatter)
-parse.add_argument("--embedding_path", type=str, default="/home/data/acge_text_embedding", help="embedding模型本地路径")
-parse.add_argument("--tei_emb", type=bool, default=False, help="是否使用TEI服务化的embedding模型")
-parse.add_argument("--embedding_url", type=str, default="http://127.0.0.1:8080/embed",help="使用TEI服务化的embedding模型url地址")
-parse.add_argument("--embedding_dim", type=int, default=1024, help="embedding模型向量维度")
-parse.add_argument("--llm_url", type=str, default="http://127.0.0.1:1025/v1/chat/completions", help="大模型url地址")
-parse.add_argument("--model_name", type=str, default="Llama3-8B-Chinese-Chat", help="大模型名称")
-parse.add_argument("--tei_reranker", type=bool, default=False, help="是否使用TEI服务化的reranker模型")
-parse.add_argument("--reranker_path", type=str, default=None, help="reranker模型本地路径")
-parse.add_argument("--reranker_url", type=str, default=None, help="使用TEI服务化的embedding模型url地址")
-parse.add_argument("--dev", type=int, default=0, help="使用的npu卡,可通过执行npu-smi info获取")
-parse.add_argument("--port", type=int, default=8080, help="web后端端口")
-
-args = parse.parse_args().__dict__
-tei_emb: bool = args.pop('tei_emb')
-embedding_path: str = args.pop('embedding_path')
-embedding_url: str = args.pop('embedding_url')
-embedding_dim: int = args.pop('embedding_dim')
-llm_url: str = args.pop('llm_url')
-model_name: str = args.pop('model_name')
-tei_reranker: bool = args.pop('tei_reranker')
-reranker_path: str = args.pop('reranker_path')
-reranker_url: str = args.pop('reranker_url')
-dev: int = args.pop('dev')
-port: int = args.pop('port')
-
-# 初始化向量数据库
-vector_store = MindFAISS(x_dim=embedding_dim,
- similarity_strategy=SimilarityStrategy.FLAT_L2,
- devs=[dev],
- load_local_index="./test_faiss.index",
- auto_save=True
- )
-# 初始化文档chunk关系数据库
-chunk_store = SQLiteDocstore(db_path="./test_sql.db")
-# 初始化知识管理关系数据库
-knowledge_store = KnowledgeStore(db_path="./test_sql.db")
-# 初始化知识库管理
-text_knowledge_db = KnowledgeDB(knowledge_store=knowledge_store,
- chunk_store=chunk_store,
- vector_store=vector_store,
- knowledge_name="test",
- white_paths=["/tmp"]
- )
+#初始化数据库管理
knowledge_mgr = KnowledgeMgr(KnowledgeMgrStore("./test_sql.db"))
-# 配置text生成text大模型chain,具体ip端口请根据实际情况适配修改
-llm = Text2TextLLM(base_url=llm_url, model_name=model_name, client_param=ClientParam(use_http=True))
-# 配置embedding模型,请根据模型具体路径适配
-if tei_emb:
- text_emb = TEIEmbedding(url=embedding_url, client_param=ClientParam(use_http=True))
-else:
- text_emb = TextEmbedding(model_path=embedding_path, dev_id=dev)
+# 创建新的知识库
+def creat_knowledge_db(knowledge_name: str, is_regist: bool = True):
+ knowledge_name = get_knowledge_rename(knowledge_name)
+ index_name, db_name = get_db_names(knowledge_name)
+ vector_store_cls = MindFAISS(x_dim=embedding_dim,
+ similarity_strategy=SimilarityStrategy.FLAT_L2,
+ devs=[dev],
+ load_local_index=index_name,
+ auto_save=True)
+ # 初始化文档chunk关系数据库
+ chunk_store_cls = SQLiteDocstore(db_path=db_name)
+ # 初始化知识管理关系数据库
+ knowledge_store_cls = KnowledgeStore(db_path=db_name)
+ # 初始化知识库管理
+ knowledge_db_cls = KnowledgeDB(knowledge_store=knowledge_store_cls,
+ chunk_store=chunk_store_cls,
+ vector_store=vector_store_cls,
+ knowledge_name=knowledge_name,
+ white_paths=["/tmp"])
+ if is_regist:
+ knowledge_mgr.register(knowledge_db_cls)
+ return knowledge_db_cls
-# 配置rerank,请根据模型具体路径适配
-if tei_reranker:
- reranker = TEIReranker(url=reranker_url, client_param=ClientParam(use_http=True))
-elif reranker_path is not None:
- reranker = LocalReranker(model_path=reranker_path, dev_id=dev)
-else:
- reranker = None
+
+# 创建检索器
+def creat_retriever(knowledge_name: str, top_k, score_threshold):
+ knowledge_name = get_knowledge_rename(knowledge_name)
+ index_name, db_name = get_db_names(knowledge_name)
+ vector_store_ret = MindFAISS(x_dim=embedding_dim,
+ similarity_strategy=SimilarityStrategy.FLAT_L2,
+ devs=[dev],
+ load_local_index=index_name,
+ auto_save=True)
+ # 初始化文档chunk关系数据库
+ chunk_store_ret = SQLiteDocstore(db_path=db_name)
+ retriever_cls = Retriever(vector_store=vector_store_ret,
+ document_store=chunk_store_ret,
+ embed_func=text_emb.embed_documents,
+ k=top_k,
+ score_threshold=score_threshold)
+ return retriever_cls
+
+
+# 删除知识库
+def delete_knowledge_db(knowledge_name: str):
+ knowledge_name = get_knowledge_rename(knowledge_name)
+ knowledge_db_lst = knowledge_mgr.get_all()
+ if knowledge_name in knowledge_db_lst:
+ knowledge_db_cls = creat_knowledge_db(knowledge_name, is_regist=False)
+ clear_document(knowledge_name)
+ knowledge_mgr.delete(knowledge_db_cls)
+ return get_knowledge_db()
+
+
+# 获取知识库中文档列表
+def get_document(knowledge_name: str):
+ knowledge_name = get_knowledge_rename(knowledge_name)
+ knowledge_db_cls = creat_knowledge_db(knowledge_name, is_regist=False)
+ doc_names = knowledge_db_cls.get_all_documents()
+ return knowledge_name, doc_names, len(doc_names)
+
+
+# 清空知识库中文档列表
+def clear_document(knowledge_name: str):
+ knowledge_name, doc_names, doc_cnt = get_document(knowledge_name)
+ if doc_cnt > 0:
+ knowledge_db_cls = creat_knowledge_db(knowledge_name, is_regist=False)
+ for doc_name in doc_names:
+ knowledge_db_cls.delete_file(doc_name)
+ return get_document(knowledge_name)
+ else:
+ return knowledge_name, doc_names, doc_cnt
+
+
+# 获取知识库列表
+def get_knowledge_db():
+ knowledge_db_lst = knowledge_mgr.get_all()
+ return knowledge_db_lst, len(knowledge_db_lst)
+
+
+# 上传知识库
+def file_upload(files,
+ knowledge_db_name: str = 'test',
+ chunk_size: int = 750,
+ chunk_overlap: int = 150
+ ):
+ save_file_path = "/tmp/document_files"
+ knowledge_db_name = get_knowledge_rename(knowledge_db_name)
+ # 指定保存文件的文件夹
+ if not os.path.exists(save_file_path):
+ os.makedirs(save_file_path)
+ if files is None or len(files) == 0:
+ print('no file need save')
+ if knowledge_db_name in knowledge_mgr.get_all():
+ knowledge_db_cls = creat_knowledge_db(knowledge_db_name, is_regist=False)
+ else:
+ knowledge_db_cls = creat_knowledge_db(knowledge_db_name)
+ # 注册文档加载器
+ loader_mng = LoaderMng()
+ # 加载文档加载器,可以使用mxrag自有的,也可以使用langchain的
+ loader_mng.register_loader(loader_class=TextLoader, file_types=[".txt", ".md"])
+ loader_mng.register_loader(loader_class=DocxLoader, file_types=[".docx"])
+ loader_mng.register_loader(loader_class=PdfLoader, file_types=[".pdf"])
+ loader_mng.register_loader(loader_class=ExcelLoader, file_types=[".xlsx", ".xls"])
+
+ # 加载文档切分器,使用langchain的
+ loader_mng.register_splitter(splitter_class=RecursiveCharacterTextSplitter,
+ file_types=[".docx", ".txt", ".md", ".pdf", ".xlsx", ".xls"],
+ splitter_params={"chunk_size": chunk_size,
+ "chunk_overlap": chunk_overlap,
+ "keep_separator": False
+ })
+ for file in files:
+ try:
+ # 上传领域知识文档
+ shutil.copy(file.name, save_file_path)
+ # 知识库:chunk\embedding\add
+ upload_files(knowledge_db_cls, [file.name], loader_mng=loader_mng, embed_func=text_emb.embed_documents, force=True)
+ print(f"file {file.name} save to {save_file_path}.")
+ except Exception as err:
+ logger.error(f"save failed, find exception: {err}")
+
+
+def file_change(files, upload_btn):
+ print("file changes")
+
+
+def get_db_names(knowledge_name: str):
+ index_name = "./" + knowledge_name + "_faiss.index"
+ db_name = "./" + knowledge_name + "_sql.db"
+ return index_name, db_name
+
+
+def get_knowledge_rename(knowledge_name: str):
+ if knowledge_name is None or len(knowledge_name) == 0 or ' ' in knowledge_name:
+ return 'test'
+ else:
+ return knowledge_name
# 历史问题改写
@@ -187,8 +258,9 @@ def bot_response(history,
if len(q_docs) > 0:
history_last = ''
for i ,source in enumerate(q_docs):
- sources = "\n====检索信息来源" + str(i + 1) + ":" + source['metadata']['source'] + "====" + "\n" + \
- "--->参考内容:" +source['page_content'] + "\n"
+ sources = "\n====检索信息来源:" + str(i + 1) + ":" + "[数据库名]:" + str(knowledge_name) + \
+ "[文件名]:" + source['metadata']['source'] + "====" + "\n" + \
+ "参考内容:" +source['page_content'] + "\n"
history_last += sources
history += [[None, history_last]]
yield history, history_r
@@ -202,8 +274,9 @@ def bot_response(history,
if len(q_docs) > 0:
history_r_last = ''
for i ,source in enumerate(q_docs):
- sources = "\n====检索信息来源" + str(i + 1) + ":" + source['metadata']['source'] + "====" + "\n" + \
- "--->参考内容:" +source['page_content'] + "\n"
+ sources = "\n====检索信息来源:" + str(i + 1) + ":" + "[数据库名]:" + str(knowledge_name) + \
+ "[文件名]:" + source['metadata']['source'] + "====" + "\n" + \
+ "参考内容:" +source['page_content'] + "\n"
history_r_last += sources
history_r += [[history[-1][0], history_r_last]]
yield history, history_r
@@ -224,8 +297,9 @@ def re_response(history_r,
if len(q_docs) > 0:
history_r_last = ''
for i, source in enumerate(q_docs):
- sources = "\n====检索信息来源" + str(i + 1) + ":" + source.metadata['source'] + "====" + "\n" + \
- "--->参考内容:" + source.page_content + "\n"
+ sources = "\n====检索信息来源:" + str(i + 1) + "[数据库名]:" + str(knowledge_name) +\
+ "[文件名]:" + source.metadata['source'] + "====" + "\n" + \
+ "参考内容:" + source.page_content + "\n"
history_r_last += sources
history_r[-1][1] = history_r_last
else:
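Style note: the source-attribution block is assembled by chained string concatenation in three nearly identical places. An equivalent f-string sketch (same fields as the re_response version above):

    sources = (f"\n====检索信息来源:{i + 1}[数据库名]:{knowledge_name}"
               f"[文件名]:{source.metadata['source']}====\n"
               f"参考内容:{source.page_content}\n")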
@@ -247,307 +321,220 @@ def clear_history(history):
return []
-# 创建新的知识库
-def creat_knowledge_db(knowledge_name: str, is_regist: bool=True):
- knowledge_name = get_knowledge_rename(knowledge_name)
- index_name, db_name = get_db_names(knowledge_name)
- vector_store_cls = MindFAISS(x_dim=embedding_dim,
- similarity_strategy=SimilarityStrategy.FLAT_L2,
- devs=[dev],
- load_local_index=index_name,
- auto_save=True)
- # 初始化文档chunk关系数据库
- chunk_store_cls = SQLiteDocstore(db_path=db_name)
- # 初始化知识管理关系数据库
- knowledge_store_cls = KnowledgeStore(db_path=db_name)
- # 初始化知识库管理
- knowledge_db_cls = KnowledgeDB(knowledge_store=knowledge_store_cls,
- chunk_store=chunk_store_cls,
- vector_store=vector_store_cls,
- knowledge_name=knowledge_name,
- white_paths=["/tmp"])
- if is_regist:
- knowledge_mgr.register(knowledge_db_cls)
- return knowledge_db_cls
-
-
-def creat_retriever(knowledge_name: str, top_k, score_threshold):
- knowledge_name = get_knowledge_rename(knowledge_name)
- index_name, db_name = get_db_names(knowledge_name)
- vector_store_ret = MindFAISS(x_dim=embedding_dim,
- similarity_strategy=SimilarityStrategy.FLAT_L2,
- devs=[dev],
- load_local_index=index_name,
- auto_save=True)
- # 初始化文档chunk关系数据库
- chunk_store_ret = SQLiteDocstore(db_path=db_name)
- retriever_cls = Retriever(vector_store=vector_store_ret,
- document_store=chunk_store_ret,
- embed_func=text_emb.embed_documents,
- k=top_k,
- score_threshold=score_threshold)
- return retriever_cls
-
-
-# 删除知识库
-def delete_knowledge_db(knowledge_name: str):
- knowledge_name = get_knowledge_rename(knowledge_name)
- knowledge_db_lst = knowledge_mgr.get_all()
- if knowledge_name in knowledge_db_lst:
- knowledge_db_cls = creat_knowledge_db(knowledge_name, is_regist=False)
- clear_document(knowledge_name)
- knowledge_mgr.delete(knowledge_db_cls)
- return get_knowledge_db()
-
-
-# 获取知识库文档列表
-def get_document(knowledge_name: str):
- knowledge_name = get_knowledge_rename(knowledge_name)
- knowledge_db_cls = creat_knowledge_db(knowledge_name, is_regist=False)
- doc_names = knowledge_db_cls.get_all_documents()
- return knowledge_name, doc_names, len(doc_names)
-
-
-# 清空知识库文档列表
-def clear_document(knowledge_name: str):
- knowledge_name, doc_names, doc_cnt = get_document(knowledge_name)
- if doc_cnt > 0:
- knowledge_db_cls = creat_knowledge_db(knowledge_name, is_regist=False)
- for doc_name in doc_names:
- knowledge_db_cls.delete_file(doc_name)
- return get_document(knowledge_name)
- else:
- return knowledge_name, doc_names, doc_cnt
-
-
-# 获取知识库列表
-def get_knowledge_db():
- knowledge_db_lst = knowledge_mgr.get_all()
- return knowledge_db_lst, len(knowledge_db_lst)
-
-
-def get_db_names(knowledge_name: str):
- index_name = "./" + knowledge_name + "_faiss.index"
- dn_name = "./" + knowledge_name + "_sql.db"
- return index_name, dn_name
-
-
-def get_knowledge_rename(knowledge_name: str):
- if knowledge_name is None or len(knowledge_name) == 0 or ' ' in knowledge_name:
- return 'test'
+if __name__ == '__main__':
+ class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter):
+ def _get_default_metavar_for_optional(self, action):
+ return action.type.__name__
+
+ def _get_default_metavar_for_positional(self, action):
+ return action.type.__name__
+
+ parse = argparse.ArgumentParser(formatter_class=CustomFormatter)
+ parse.add_argument("--embedding_path", type=str, default="/home/data/acge_text_embedding",
+ help="embedding模型本地路径")
+ parse.add_argument("--tei_emb", type=bool, default=False, help="是否使用TEI服务化的embedding模型")
+ parse.add_argument("--embedding_url", type=str, default="http://127.0.0.1:8080/embed",
+ help="使用TEI服务化的embedding模型url地址")
+ parse.add_argument("--embedding_dim", type=int, default=1024, help="embedding模型向量维度")
+ parse.add_argument("--llm_url", type=str, default="http://127.0.0.1:1025/v1/chat/completions", help="大模型url地址")
+ parse.add_argument("--model_name", type=str, default="Llama3-8B-Chinese-Chat", help="大模型名称")
+ parse.add_argument("--tei_reranker", type=bool, default=False, help="是否使用TEI服务化的reranker模型")
+ parse.add_argument("--reranker_path", type=str, default=None, help="reranker模型本地路径")
+ parse.add_argument("--reranker_url", type=str, default=None, help="使用TEI服务化的embedding模型url地址")
+ parse.add_argument("--dev", type=int, default=0, help="使用的npu卡,可通过执行npu-smi info获取")
+ parse.add_argument("--port", type=int, default=8080, help="web后端端口")
+
+ args = parse.parse_args().__dict__
+ embedding_path: str = args.pop('embedding_path')
+ tei_emb: bool = args.pop('tei_emb')
+ embedding_url: str = args.pop('embedding_url')
+ embedding_dim: int = args.pop('embedding_dim')
+ llm_url: str = args.pop('llm_url')
+ model_name: str = args.pop('model_name')
+ tei_reranker: bool = args.pop('tei_reranker')
+ reranker_path: str = args.pop('reranker_path')
+ reranker_url: str = args.pop('reranker_url')
+ dev: int = args.pop('dev')
+ port: int = args.pop('port')
+
+ # 初始化test数据库
+ knowledge_db = creat_knowledge_db('test')
+ # 配置text生成text大模型chain,具体ip端口请根据实际情况适配修改
+ llm = Text2TextLLM(base_url=llm_url, model_name=model_name, client_param=ClientParam(use_http=True))
+ # 配置embedding模型,请根据模型具体路径适配
+ if tei_emb:
+ text_emb = TEIEmbedding(url=embedding_url, client_param=ClientParam(use_http=True))
else:
- return knowledge_name
-
-
-# 上传知识库
-def file_upload(files,
- knowledge_db_name: str = 'test',
- chunk_size: int = 750,
- chunk_overlap: int = 150
- ):
- save_file_path = "/tmp/document_files"
- knowledge_db_name = get_knowledge_rename(knowledge_db_name)
- # 指定保存文件的文件夹
- if not os.path.exists(save_file_path):
- os.makedirs(save_file_path)
- if files is None or len(files) == 0:
- print('no file need save')
- if knowledge_db_name in knowledge_mgr.get_all():
- knowledge_db_cls = creat_knowledge_db(knowledge_db_name, is_regist=False)
+ text_emb = TextEmbedding(model_path=embedding_path, dev_id=dev)
+ # 配置rerank,请根据模型具体路径适配
+ if tei_reranker:
+ reranker = TEIReranker(url=reranker_url, client_param=ClientParam(use_http=True))
+ elif reranker_path is not None:
+ reranker = LocalReranker(model_path=reranker_path, dev_id=dev)
else:
- knowledge_db_cls = creat_knowledge_db(knowledge_db_name)
- # 注册文档加载器
- loader_mng = LoaderMng()
- # 加载文档加载器,可以使用mxrag自有的,也可以使用langchain的
- loader_mng.register_loader(loader_class=TextLoader, file_types=[".txt", ".md"])
- loader_mng.register_loader(loader_class=DocxLoader, file_types=[".docx"])
- loader_mng.register_loader(loader_class=PdfLoader, file_types=[".pdf"])
- loader_mng.register_loader(loader_class=ExcelLoader, file_types=[".xlsx", ".xls"])
-
- # 加载文档切分器,使用langchain的
- loader_mng.register_splitter(splitter_class=RecursiveCharacterTextSplitter,
- file_types=[".docx", ".txt", ".md", ".pdf", ".xlsx", ".xls"],
- splitter_params={"chunk_size": chunk_size,
- "chunk_overlap": chunk_overlap,
- "keep_separator": False
- })
- for file in files:
- try:
- # 上传领域知识文档
- shutil.copy(file.name, save_file_path)
- # 知识库:chunk\embedding\add
- upload_files(knowledge_db_cls, [file.name], loader_mng=loader_mng, embed_func=text_emb.embed_documents, force=True)
- print(f"file {file.name} save to {save_file_path}.")
- except Exception as err:
- logger.error(f"save failed, find exception: {err}")
-
-
-def file_change(files, upload_btn):
- print("file changes")
-
-
-# 构建gradio框架
-def build_demo():
- with gr.Blocks() as demo:
- gr.HTML("检索增强生成(RAG)对话
powered by MindX RAG
")
- with gr.Row():
- with gr.Column(scale=100):
- with gr.Row():
- model_select = gr.Dropdown(choices=["Default"], value="Default", container=False, interactive=True)
- with gr.Row():
- files = gr.Files(
- height=300,
- file_count="multiple",
- file_types=[".docx", ".txt", ".md", ".pdf", ".xlsx", ".xls"],
- interactive=True,
- label="上传知识库文档"
- )
- with gr.Row():
- upload_btn = gr.Button("上传文件")
- with gr.Row():
- with gr.TabItem(label='知识库情况'):
- set_knowledge_name = gr.Textbox(label='设置当前知识库',placeholder='在此输入知识库名称,默认使用test知识库')
- with gr.Row():
- creat_knowledge_btn = gr.Button('创建知识库')
- delete_knowledge_btn = gr.Button('删除知识库')
- knowledge_name_output = gr.Textbox(label='知识库列表')
- knowledge_number_output = gr.Textbox(label='知识库数量')
- with gr.Row():
- show_knowledge_btn = gr.Button('显示知识库')
- with gr.TabItem(label='文件情况'):
- knowledge_name = gr.Textbox('知识库名称')
- knowledge_file_output = gr.Textbox('知识库文件列表')
- knowledge_file_num_output = gr.Textbox('知识库文件数量')
- with gr.Row():
- knowledge_file_out_btn = gr.Button('显示文件情况')
- knowledge_clear_btn = gr.Button('清空知识库')
- with gr.Row():
- with gr.According(label='文档切分参数设置',open=False):
- chunk_size = gr.Slider(
- minimum=50,
- maximum=5000,
- value=750,
- step=50,
+ reranker = None
+
+ # 构建gradio框架
+ def build_demo():
+ with gr.Blocks() as demo:
+ gr.HTML("检索增强生成(RAG)对话
powered by MindX RAG
")
+ with gr.Row():
+ with gr.Column(scale=100):
+ with gr.Row():
+ model_select = gr.Dropdown(choices=["Default"], value="Default", container=False, interactive=True)
+ with gr.Row():
+ files = gr.Files(
+ height=300,
+ file_count="multiple",
+ file_types=[".docx", ".txt", ".md", ".pdf", ".xlsx", ".xls"],
interactive=True,
- label="chunk_size",
- info="文本切分长度"
+ label="上传知识库文档"
)
- chunk_overlap = gr.Slider(
- minimum=10,
- maximum=500,
- value=150,
- step=10,
- interactive=True,
- label="chunk_overlap",
- info="文本切分填充长度"
- )
- with gr.Row():
- with gr.According(label='大模型参数设置', open=False):
- temperature = gr.Slider(
- minimum=0.01,
- maximum=2,
- value=0.5,
- step=0.01,
- interactive=True,
- label="温度",
- info="Token生成的随机性"
- )
- top_p = gr.Slider(
- minimum=0.01,
- maximum=1,
- value=0.95,
- step=0.05,
- interactive=True,
- label="Top P",
- info="累计概率总和阈值"
- )
- max_tokens = gr.Slider(
- minimum=100,
- maximum=1024,
- value=512,
- step=1,
- interactive=True,
- label="最大tokens",
- info="输入+输出最多的tokens数"
- )
- is_rewrite = gr.Radio(['是','否'], value='否', label="是否根据历史提问重写问题?")
- history_n = gr.Slider(
- minimum=1,
- maximum=10,
- value=5,
- step=1,
- interactive=True,
- label="历史提问重新轮数",
- info="问题重写时所参考的历史提问轮数"
- )
- with gr.Row():
- with gr.According(label='检索参数设置', open=False):
- score_threshold = gr.Slider(
- minimum=0,
- maximum=1,
- value=0.5,
- step=0.01,
- interactive=True,
- label="score_threshold",
- info="相似性检索阈值,值越大表示越相关,低于阈值不会被返回。"
- )
- top_k = gr.Slider(
- minimum=1,
- maximum=5,
- value=1,
- step=1,
- interactive=True,
- label="top_k",
- info="相似性检索返回条数"
- )
- show_type = gr.Radio(['对话结束后显示','检索框单独显示','不显示'], value="对话结束后显示", label="知识库文档匹配结果展示方式选择")
- with gr.Column(scale=200):
- with gr.Tabs():
- with gr.TabItem("对话窗口"):
- chat_type = gr.Radio(['RAG检索增强对话','仅大模型对话'], value="RAG检索增强对话", label="请选择对话模式?")
- chatbot = gr.Chatbot(height=550)
- with gr.Row():
- msg = gr.Textbox(placeholder="在此输入问题...", container=False)
- with gr.Row():
- send_btn = gr.Button(value="发送", variant="primary")
- clean_btn = gr.Button(value="清空历史")
- with gr.TabItem("检索窗口"):
- chatbot_r = gr.Chatbot(height=550)
- with gr.Row():
- msg_r = gr.Textbox(placeholder="在此输入问题...", container=False)
- with gr.Row():
- send_btn_r = gr.Button(value="文档检索", variant="primary")
- clean_btn_r = gr.Button(value="清空历史")
- # RAG发送
- send_btn.click(user_query, [msg, chatbot], [msg, chatbot], queue=False
- ).then(bot_response,
- [chatbot, chatbot_r, max_tokens, temperature, top_p, history_n, score_threshold, top_k, chat_type, show_type, is_rewrite, set_knowledge_name],
- [chatbot, chatbot_r])
- # RAG清除历史
- clean_btn.click(clear_history, chatbot, chatbot)
- # 上传文件
- files.change(file_change, [], [])
- upload_btn.click(file_upload, [files, set_knowledge_name, chunk_size, chunk_overlap], files)
- # 管理所有知识库
- creat_knowledge_btn.click(creat_knowledge_db, [set_knowledge_name],[]).then(get_knowledge_db, [],[knowledge_name_output, knowledge_number_output])
- show_knowledge_btn.click(get_knowledge_db, [], [knowledge_name_output, knowledge_number_output])
- delete_knowledge_btn.click(delete_knowledge_db, [set_knowledge_name], [knowledge_name_output, knowledge_number_output])
- # 管理知识库里文件
- knowledge_file_out_btn.click(get_document, [set_knowledge_name], [knowledge_name, knowledge_file_output, knowledge_file_num_output])
- knowledge_clear_btn.click(clear_document, [set_knowledge_name], [knowledge_name, knowledge_file_output, knowledge_file_num_output])
- # 检索发送
- send_btn_r.click(user_retriever, [msg_r, chatbot_r], [msg_r, chatbot_r], queue=False
- ).then(re_response, [chatbot_r, score_threshold, top_k, set_knowledge_name], chatbot_r)
- # 检索清除历史
- clean_btn_r.click(clear_history, chatbot_r, chatbot_r)
- return demo
-
-
-def create_gradio(ports):
- demo = build_demo()
- demo.queue()
- demo.launch(share=True, server_name="0.0.0.0", server_port=ports)
-
-
-if __name__ == '__main__':
+ with gr.Row():
+ upload_btn = gr.Button("上传文件")
+ with gr.Row():
+ with gr.TabItem(label='知识库情况'):
+ set_knowledge_name = gr.Textbox(label='设置当前知识库',placeholder='在此输入知识库名称,默认使用test知识库')
+ with gr.Row():
+ creat_knowledge_btn = gr.Button('创建知识库')
+ delete_knowledge_btn = gr.Button('删除知识库')
+ knowledge_name_output = gr.Textbox(label='知识库列表')
+ knowledge_number_output = gr.Textbox(label='知识库数量')
+ with gr.Row():
+ show_knowledge_btn = gr.Button('显示知识库')
+ with gr.TabItem(label='文件情况'):
+ knowledge_name = gr.Textbox('知识库名称')
+ knowledge_file_output = gr.Textbox('知识库文件列表')
+ knowledge_file_num_output = gr.Textbox('知识库文件数量')
+ with gr.Row():
+ knowledge_file_out_btn = gr.Button('显示文件情况')
+ knowledge_clear_btn = gr.Button('清空知识库')
+ with gr.Row():
+ with gr.According(label='文档切分参数设置',open=False):
+ chunk_size = gr.Slider(
+ minimum=50,
+ maximum=5000,
+ value=750,
+ step=50,
+ interactive=True,
+ label="chunk_size",
+ info="文本切分长度"
+ )
+ chunk_overlap = gr.Slider(
+ minimum=10,
+ maximum=500,
+ value=150,
+ step=10,
+ interactive=True,
+ label="chunk_overlap",
+ info="文本切分填充长度"
+ )
+ with gr.Row():
+ with gr.According(label='大模型参数设置', open=False):
+ temperature = gr.Slider(
+ minimum=0.01,
+ maximum=2,
+ value=0.5,
+ step=0.01,
+ interactive=True,
+ label="温度",
+ info="Token生成的随机性"
+ )
+ top_p = gr.Slider(
+ minimum=0.01,
+ maximum=1,
+ value=0.95,
+ step=0.05,
+ interactive=True,
+ label="Top P",
+ info="累计概率总和阈值"
+ )
+ max_tokens = gr.Slider(
+ minimum=100,
+ maximum=1024,
+ value=512,
+ step=1,
+ interactive=True,
+ label="最大tokens",
+ info="输入+输出最多的tokens数"
+ )
+ is_rewrite = gr.Radio(['是','否'], value='否', label="是否根据历史提问重写问题?")
+ history_n = gr.Slider(
+ minimum=1,
+ maximum=10,
+ value=5,
+ step=1,
+ interactive=True,
+ label="历史提问重新轮数",
+ info="问题重写时所参考的历史提问轮数"
+ )
+ with gr.Row():
+ with gr.According(label='检索参数设置', open=False):
+ score_threshold = gr.Slider(
+ minimum=0,
+ maximum=1,
+ value=0.5,
+ step=0.01,
+ interactive=True,
+ label="score_threshold",
+ info="相似性检索阈值,值越大表示越相关,低于阈值不会被返回。"
+ )
+ top_k = gr.Slider(
+ minimum=1,
+ maximum=5,
+ value=1,
+ step=1,
+ interactive=True,
+ label="top_k",
+ info="相似性检索返回条数"
+ )
+ show_type = gr.Radio(['对话结束后显示','检索框单独显示','不显示'], value="对话结束后显示", label="知识库文档匹配结果展示方式选择")
+ with gr.Column(scale=200):
+ with gr.Tabs():
+ with gr.TabItem("对话窗口"):
+ chat_type = gr.Radio(['RAG检索增强对话','仅大模型对话'], value="RAG检索增强对话", label="请选择对话模式?")
+ chatbot = gr.Chatbot(height=550)
+ with gr.Row():
+ msg = gr.Textbox(placeholder="在此输入问题...", container=False)
+ with gr.Row():
+ send_btn = gr.Button(value="发送", variant="primary")
+ clean_btn = gr.Button(value="清空历史")
+ with gr.TabItem("检索窗口"):
+ chatbot_r = gr.Chatbot(height=550)
+ with gr.Row():
+ msg_r = gr.Textbox(placeholder="在此输入问题...", container=False)
+ with gr.Row():
+ send_btn_r = gr.Button(value="文档检索", variant="primary")
+ clean_btn_r = gr.Button(value="清空历史")
+ # RAG发送
+ send_btn.click(user_query, [msg, chatbot], [msg, chatbot], queue=False
+ ).then(bot_response,
+ [chatbot, chatbot_r, max_tokens, temperature, top_p, history_n, score_threshold, top_k, chat_type, show_type, is_rewrite, set_knowledge_name],
+ [chatbot, chatbot_r])
+ # RAG清除历史
+ clean_btn.click(clear_history, chatbot, chatbot)
+ # 上传文件
+ files.change(file_change, [], [])
+ upload_btn.click(file_upload, [files, set_knowledge_name, chunk_size, chunk_overlap], files)
+ # 管理所有知识库
+ creat_knowledge_btn.click(creat_knowledge_db, [set_knowledge_name],[]).then(get_knowledge_db, [],[knowledge_name_output, knowledge_number_output])
+ show_knowledge_btn.click(get_knowledge_db, [], [knowledge_name_output, knowledge_number_output])
+ delete_knowledge_btn.click(delete_knowledge_db, [set_knowledge_name], [knowledge_name_output, knowledge_number_output])
+ # 管理知识库里文件
+ knowledge_file_out_btn.click(get_document, [set_knowledge_name], [knowledge_name, knowledge_file_output, knowledge_file_num_output])
+ knowledge_clear_btn.click(clear_document, [set_knowledge_name], [knowledge_name, knowledge_file_output, knowledge_file_num_output])
+ # 检索发送
+ send_btn_r.click(user_retriever, [msg_r, chatbot_r], [msg_r, chatbot_r], queue=False
+ ).then(re_response, [chatbot_r, score_threshold, top_k, set_knowledge_name], chatbot_r)
+ # 检索清除历史
+ clean_btn_r.click(clear_history, chatbot_r, chatbot_r)
+ return demo
+
+
+ def create_gradio(ports):
+ demo = build_demo()
+ demo.queue()
+ demo.launch(share=True, server_name="0.0.0.0", server_port=ports)
+
+ # 启动gradio
create_gradio(port)
--
Gitee
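A structural note on this patch: after the move, the module-level helpers (creat_knowledge_db, creat_retriever, file_upload) read globals such as embedding_dim, dev, text_emb and reranker that are only bound inside the __main__ block, so app.py runs fine as a script but the helpers fail if the module is imported. A hedged alternative is passing configuration explicitly; RagConfig below is an illustrative name, not part of the patch:

    from dataclasses import dataclass

    @dataclass
    class RagConfig:
        embedding_dim: int
        dev: int

    # Helpers would then take the config instead of reading globals, e.g.:
    # def creat_retriever(cfg: RagConfig, knowledge_name: str, top_k, score_threshold):
    #     vector_store = MindFAISS(x_dim=cfg.embedding_dim, devs=[cfg.dev], ...)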
From 92a719ec317bf677315665fd781f6e717ab82dc7 Mon Sep 17 00:00:00 2001
From: tanfeng <823018000@qq.com>
Date: Mon, 25 Nov 2024 11:52:18 +0800
Subject: [PATCH 3/6] Adapt the display page to the latest POC version and add features
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
mxRAG/PocValidation/chat_with_ascend/app.py | 68 ++++++++++-----------
1 file changed, 34 insertions(+), 34 deletions(-)
diff --git a/mxRAG/PocValidation/chat_with_ascend/app.py b/mxRAG/PocValidation/chat_with_ascend/app.py
index 5d6dea138..4eb7c03cc 100644
--- a/mxRAG/PocValidation/chat_with_ascend/app.py
+++ b/mxRAG/PocValidation/chat_with_ascend/app.py
@@ -24,7 +24,7 @@ from mx_rag.document.loader import DocxLoader, PdfLoader, ExcelLoader
from mx_rag.knowledge import KnowledgeStore, KnowledgeDB, KnowledgeMgrStore, KnowledgeMgr
from mx_rag.knowledge.handler import upload_files, LoaderMng
-#初始化数据库管理
+# 初始化数据库管理
knowledge_mgr = KnowledgeMgr(KnowledgeMgrStore("./test_sql.db"))
@@ -125,7 +125,7 @@ def file_upload(files,
knowledge_db_cls = creat_knowledge_db(knowledge_db_name, is_regist=False)
else:
knowledge_db_cls = creat_knowledge_db(knowledge_db_name)
- # 注册文档加载器
+ # 注册文档处理器
loader_mng = LoaderMng()
# 加载文档加载器,可以使用mxrag自有的,也可以使用langchain的
loader_mng.register_loader(loader_class=TextLoader, file_types=[".txt", ".md"])
@@ -207,8 +207,8 @@ def bot_response(history,
top_k: int = 1,
chat_type: str = "RAG检索增强对话",
show_type: str = "不显示",
- is_rewrite:str = "否",
- knowledge_name: str = 'test',
+ is_rewrite: str = "否",
+ knowledge_name: str = 'test'
):
chat_type_mapping = {"RAG检索增强对话":1,
"仅大模型对话":0}
@@ -218,15 +218,15 @@ def bot_response(history,
is_rewrite_mapping = {"是":1,
"否":0}
# 初始化检索器
- retrieve_cls = creat_retriever(knowledge_name, top_k, score_threshold)
+ retriever_cls = creat_retriever(knowledge_name, top_k, score_threshold)
# 初始化chain
- text2text_chain = SingleText2TextChain(llm=llm, retriever=retrieve_cls, reranker=reranker)
+ text2text_chain = SingleText2TextChain(llm=llm, retriever=retriever_cls, reranker=reranker)
# 历史问题改写
if is_rewrite_mapping.get(is_rewrite) == 1:
history = generate_question(history, llm, history_n)
history[-1][1] = '推理错误'
try:
- # 仅使用大模型
+ # 仅使用大模型回答
if chat_type_mapping.get(chat_type) == 0:
response = llm.chat_streamly(query=history[-1][0],
llm_config=LLMParameterConfig(max_tokens=max_tokens,
@@ -257,10 +257,10 @@ def bot_response(history,
# 有检索到信息
if len(q_docs) > 0:
history_last = ''
- for i ,source in enumerate(q_docs):
- sources = "\n====检索信息来源:" + str(i + 1) + ":" + "[数据库名]:" + str(knowledge_name) + \
+ for i, source in enumerate(q_docs):
+ sources = "\n====检索信息来源:" + str(i + 1) + "[数据库名]:" + str(knowledge_name) + \
"[文件名]:" + source['metadata']['source'] + "====" + "\n" + \
- "参考内容:" +source['page_content'] + "\n"
+ "参考内容:" + source['page_content'] + "\n"
history_last += sources
history += [[None, history_last]]
yield history, history_r
@@ -274,9 +274,9 @@ def bot_response(history,
if len(q_docs) > 0:
history_r_last = ''
for i ,source in enumerate(q_docs):
- sources = "\n====检索信息来源:" + str(i + 1) + ":" + "[数据库名]:" + str(knowledge_name) + \
+ sources = "\n====检索信息来源:" + str(i + 1) + "[数据库名]:" + str(knowledge_name) + \
"[文件名]:" + source['metadata']['source'] + "====" + "\n" + \
- "参考内容:" +source['page_content'] + "\n"
+ "参考内容:" + source['page_content'] + "\n"
history_r_last += sources
history_r += [[history[-1][0], history_r_last]]
yield history, history_r
@@ -292,12 +292,12 @@ def re_response(history_r,
knowledge_name: str = 'test'
):
# 初始化检索器
- retrieve_cls = creat_retriever(knowledge_name, top_k, score_threshold)
- q_docs = retrieve_cls.invoke(history_r[-1][0])
+ retriever_cls = creat_retriever(knowledge_name, top_k, score_threshold)
+ q_docs = retriever_cls.invoke(history_r[-1][0])
if len(q_docs) > 0:
history_r_last = ''
for i, source in enumerate(q_docs):
- sources = "\n====检索信息来源:" + str(i + 1) + "[数据库名]:" + str(knowledge_name) +\
+ sources = "\n====检索信息来源:" + str(i + 1) + "[数据库名]:" + str(knowledge_name) + \
"[文件名]:" + source.metadata['source'] + "====" + "\n" + \
"参考内容:" + source.page_content + "\n"
history_r_last += sources
@@ -309,7 +309,7 @@ def re_response(history_r,
# 检索信息
def user_retriever(user_message, history_r):
- return "",history_r + [[user_message, None]]
+ return "", history_r + [[user_message, None]]
# 聊天信息
@@ -366,7 +366,7 @@ if __name__ == '__main__':
text_emb = TEIEmbedding(url=embedding_url, client_param=ClientParam(use_http=True))
else:
text_emb = TextEmbedding(model_path=embedding_path, dev_id=dev)
- # 配置rerank,请根据模型具体路径适配
+ # 配置reranker,请根据模型具体路径适配
if tei_reranker:
reranker = TEIReranker(url=reranker_url, client_param=ClientParam(use_http=True))
elif reranker_path is not None:
@@ -381,10 +381,10 @@ if __name__ == '__main__':
with gr.Row():
with gr.Column(scale=100):
with gr.Row():
- model_select = gr.Dropdown(choices=["Default"], value="Default", container=False, interactive=True)
+ model_select = gr.Dropdown(choices=[model_name], value=model_name, container=False, interactive=True)
with gr.Row():
- files = gr.Files(
- height=300,
+ files = gr.components.File(
+ height=100,
file_count="multiple",
file_types=[".docx", ".txt", ".md", ".pdf", ".xlsx", ".xls"],
interactive=True,
@@ -393,8 +393,8 @@ if __name__ == '__main__':
with gr.Row():
upload_btn = gr.Button("上传文件")
with gr.Row():
- with gr.TabItem(label='知识库情况'):
- set_knowledge_name = gr.Textbox(label='设置当前知识库',placeholder='在此输入知识库名称,默认使用test知识库')
+ with gr.TabItem("知识库情况"):
+ set_knowledge_name = gr.Textbox(label='设置当前知识库', placeholder="在此输入知识库名称,默认使用test知识库")
with gr.Row():
creat_knowledge_btn = gr.Button('创建知识库')
delete_knowledge_btn = gr.Button('删除知识库')
@@ -402,15 +402,15 @@ if __name__ == '__main__':
knowledge_number_output = gr.Textbox(label='知识库数量')
with gr.Row():
show_knowledge_btn = gr.Button('显示知识库')
- with gr.TabItem(label='文件情况'):
- knowledge_name = gr.Textbox('知识库名称')
- knowledge_file_output = gr.Textbox('知识库文件列表')
- knowledge_file_num_output = gr.Textbox('知识库文件数量')
+ with gr.TabItem("文件情况"):
+ knowledge_name = gr.Textbox(label='知识库名称')
+ knowledge_file_output = gr.Textbox(label='知识库文件列表')
+ knowledge_file_num_output = gr.Textbox(label='知识库文件数量')
with gr.Row():
knowledge_file_out_btn = gr.Button('显示文件情况')
knowledge_clear_btn = gr.Button('清空知识库')
with gr.Row():
- with gr.According(label='文档切分参数设置',open=False):
+ with gr.Accordion(label='文档切分参数设置', open=False):
chunk_size = gr.Slider(
minimum=50,
maximum=5000,
@@ -430,7 +430,7 @@ if __name__ == '__main__':
info="文本切分填充长度"
)
with gr.Row():
- with gr.According(label='大模型参数设置', open=False):
+ with gr.Accordion(label='大模型参数设置', open=False):
temperature = gr.Slider(
minimum=0.01,
maximum=2,
@@ -458,14 +458,14 @@ if __name__ == '__main__':
label="最大tokens",
info="输入+输出最多的tokens数"
)
- is_rewrite = gr.Radio(['是','否'], value='否', label="是否根据历史提问重写问题?")
+ is_rewrite = gr.Radio(['是', '否'], value="否", label="是否根据历史提问重写问题?")
history_n = gr.Slider(
minimum=1,
maximum=10,
value=5,
step=1,
interactive=True,
- label="历史提问重新轮数",
+ label="历史提问重写轮数",
info="问题重写时所参考的历史提问轮数"
)
with gr.Row():
@@ -477,7 +477,7 @@ if __name__ == '__main__':
step=0.01,
interactive=True,
label="score_threshold",
- info="相似性检索阈值,值越大表示越相关,低于阈值不会被返回。"
+ info="相似性检索阈值,值越大表示越相关,低于阈值不会被返回。"
)
top_k = gr.Slider(
minimum=1,
@@ -488,11 +488,11 @@ if __name__ == '__main__':
label="top_k",
info="相似性检索返回条数"
)
- show_type = gr.Radio(['对话结束后显示','检索框单独显示','不显示'], value="对话结束后显示", label="知识库文档匹配结果展示方式选择")
+ show_type = gr.Radio(['对话结束后显示', '检索框单独显示', '不显示'], value="对话结束后显示", label="知识库文档匹配结果展示方式选择")
with gr.Column(scale=200):
with gr.Tabs():
with gr.TabItem("对话窗口"):
- chat_type = gr.Radio(['RAG检索增强对话','仅大模型对话'], value="RAG检索增强对话", label="请选择对话模式?")
+ chat_type = gr.Radio(['RAG检索增强对话', '仅大模型对话'], value="RAG检索增强对话", label="请选择对话模式?")
chatbot = gr.Chatbot(height=550)
with gr.Row():
msg = gr.Textbox(placeholder="在此输入问题...", container=False)
@@ -517,7 +517,7 @@ if __name__ == '__main__':
files.change(file_change, [], [])
upload_btn.click(file_upload, [files, set_knowledge_name, chunk_size, chunk_overlap], files)
# 管理所有知识库
- creat_knowledge_btn.click(creat_knowledge_db, [set_knowledge_name],[]).then(get_knowledge_db, [],[knowledge_name_output, knowledge_number_output])
+ creat_knowledge_btn.click(creat_knowledge_db, [set_knowledge_name], []).then(get_knowledge_db, [], [knowledge_name_output, knowledge_number_output])
show_knowledge_btn.click(get_knowledge_db, [], [knowledge_name_output, knowledge_number_output])
delete_knowledge_btn.click(delete_knowledge_db, [set_knowledge_name], [knowledge_name_output, knowledge_number_output])
# 管理知识库里文件
--
Gitee
From 3b002c8d0dd2a8b452425f7dc60d8b0d170e7b10 Mon Sep 17 00:00:00 2001
From: tanfeng <823018000@qq.com>
Date: Mon, 25 Nov 2024 12:02:25 +0800
Subject: [PATCH 4/6] Adapt the display page to the latest POC version and add features
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
mxRAG/PocValidation/chat_with_ascend/app.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mxRAG/PocValidation/chat_with_ascend/app.py b/mxRAG/PocValidation/chat_with_ascend/app.py
index 4eb7c03cc..53414ec6c 100644
--- a/mxRAG/PocValidation/chat_with_ascend/app.py
+++ b/mxRAG/PocValidation/chat_with_ascend/app.py
@@ -469,7 +469,7 @@ if __name__ == '__main__':
info="问题重写时所参考的历史提问轮数"
)
with gr.Row():
- with gr.According(label='检索参数设置', open=False):
+ with gr.Accordion(label='检索参数设置', open=False):
score_threshold = gr.Slider(
minimum=0,
maximum=1,
--
Gitee
From d63937f97b13c0bc7f9f88b02afe25b9dddafcad Mon Sep 17 00:00:00 2001
From: tanfeng <823018000@qq.com>
Date: Mon, 25 Nov 2024 12:09:27 +0800
Subject: [PATCH 5/6] Adapt the display page to the latest POC version and add features
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
mxRAG/PocValidation/chat_with_ascend/app.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/mxRAG/PocValidation/chat_with_ascend/app.py b/mxRAG/PocValidation/chat_with_ascend/app.py
index 53414ec6c..cd4d6c720 100644
--- a/mxRAG/PocValidation/chat_with_ascend/app.py
+++ b/mxRAG/PocValidation/chat_with_ascend/app.py
@@ -218,6 +218,7 @@ def bot_response(history,
is_rewrite_mapping = {"是":1,
"否":0}
# 初始化检索器
+ knowledge_name = get_knowledge_rename(knowledge_name)
retriever_cls = creat_retriever(knowledge_name, top_k, score_threshold)
# 初始化chain
text2text_chain = SingleText2TextChain(llm=llm, retriever=retriever_cls, reranker=reranker)
--
Gitee
From 3f99338c0c819b0154429f8d2cdc22dacba7ddd6 Mon Sep 17 00:00:00 2001
From: tanfeng <823018000@qq.com>
Date: Mon, 25 Nov 2024 12:42:28 +0800
Subject: [PATCH 6/6] Adapt the display page to the latest POC version and add features
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
mxRAG/PocValidation/chat_with_ascend/app.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mxRAG/PocValidation/chat_with_ascend/app.py b/mxRAG/PocValidation/chat_with_ascend/app.py
index cd4d6c720..1d23eb9a3 100644
--- a/mxRAG/PocValidation/chat_with_ascend/app.py
+++ b/mxRAG/PocValidation/chat_with_ascend/app.py
@@ -274,7 +274,7 @@ def bot_response(history,
# 有检索到信息
if len(q_docs) > 0:
history_r_last = ''
- for i ,source in enumerate(q_docs):
+ for i, source in enumerate(q_docs):
sources = "\n====检索信息来源:" + str(i + 1) + "[数据库名]:" + str(knowledge_name) + \
"[文件名]:" + source['metadata']['source'] + "====" + "\n" + \
"参考内容:" + source['page_content'] + "\n"
@@ -293,6 +293,7 @@ def re_response(history_r,
knowledge_name: str = 'test'
):
# 初始化检索器
+ knowledge_name = get_knowledge_rename(knowledge_name)
retriever_cls = creat_retriever(knowledge_name, top_k, score_threshold)
q_docs = retriever_cls.invoke(history_r[-1][0])
if len(q_docs) > 0:
--
Gitee