diff --git a/mxRAG/PocValidation/chat_with_ascend/app.py b/mxRAG/PocValidation/chat_with_ascend/app.py
index 83d1f33d2530cfb511657d9f004acc375060fa52..1d23eb9a3ee2ec8143d2752386633189a6474f95 100644
--- a/mxRAG/PocValidation/chat_with_ascend/app.py
+++ b/mxRAG/PocValidation/chat_with_ascend/app.py
@@ -9,6 +9,9 @@ from loguru import logger
 from paddle.base import libpaddle
 from mx_rag.chain import SingleText2TextChain
 from mx_rag.embedding.local import TextEmbedding
+from mx_rag.embedding.service import TEIEmbedding
+from mx_rag.reranker.local import LocalReranker
+from mx_rag.reranker.service import TEIReranker
 from mx_rag.llm import Text2TextLLM, LLMParameterConfig
 from mx_rag.retrievers import Retriever
 from mx_rag.storage.document_store import SQLiteDocstore
@@ -17,64 +20,133 @@ from mx_rag.storage.vectorstore.vectorstore import SimilarityStrategy
 from mx_rag.utils import ClientParam
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.document_loaders import TextLoader
-from mx_rag.document.loader import DocxLoader, PdfLoader
-from mx_rag.knowledge import KnowledgeDB
-from mx_rag.knowledge.knowledge import KnowledgeStore
+from mx_rag.document.loader import DocxLoader, PdfLoader, ExcelLoader
+from mx_rag.knowledge import KnowledgeStore, KnowledgeDB, KnowledgeMgrStore, KnowledgeMgr
 from mx_rag.knowledge.handler import upload_files, LoaderMng
 
-global text2text_chain
-global text_knowledge_db
-global loader_mng
-global text_emb
+# Initialize the knowledge-base manager (backed by its own SQLite store)
+knowledge_mgr = KnowledgeMgr(KnowledgeMgrStore("./test_sql.db"))
 
 
-def bot_response(history,
-                 max_tokens: int = 512,
-                 temperature: float = 0.5,
-                 top_p: float = 0.95
-                 ):
-    # pass the latest question to the RAG chain
-    try:
-        response = text2text_chain.query(history[-1][0],
-                                         LLMParameterConfig(max_tokens=max_tokens, temperature=temperature,
-                                                            top_p=top_p, stream=True))
-        # iterate over the streamed results
-        history[-1][1] = '推理错误'
-        for res in response:
-            history[-1][1] = '推理错误' if res['result'] is None else res['result']
-            yield history
-        yield history
-    except Exception as err:
-        logger.info(f"query failed, find exception: {err}")
-        history[-1][1] = "推理错误"
-        yield history
+# Create a new knowledge base
+def creat_knowledge_db(knowledge_name: str, is_regist: bool = True):
+    knowledge_name = get_knowledge_rename(knowledge_name)
+    index_name, db_name = get_db_names(knowledge_name)
+    vector_store_cls = MindFAISS(x_dim=embedding_dim,
+                                 similarity_strategy=SimilarityStrategy.FLAT_L2,
+                                 devs=[dev],
+                                 load_local_index=index_name,
+                                 auto_save=True)
+    # Initialize the document-chunk relational store
+    chunk_store_cls = SQLiteDocstore(db_path=db_name)
+    # Initialize the knowledge-management relational store
+    knowledge_store_cls = KnowledgeStore(db_path=db_name)
+    # Assemble the knowledge base
+    knowledge_db_cls = KnowledgeDB(knowledge_store=knowledge_store_cls,
+                                   chunk_store=chunk_store_cls,
+                                   vector_store=vector_store_cls,
+                                   knowledge_name=knowledge_name,
+                                   white_paths=["/tmp"])
+    if is_regist:
+        knowledge_mgr.register(knowledge_db_cls)
+    return knowledge_db_cls
 
 
-def user_query(user_message, history):
-    return "", history + [[user_message, None]]
+# Create a retriever
+def creat_retriever(knowledge_name: str, top_k, score_threshold):
+    knowledge_name = get_knowledge_rename(knowledge_name)
+    index_name, db_name = get_db_names(knowledge_name)
+    vector_store_ret = MindFAISS(x_dim=embedding_dim,
+                                 similarity_strategy=SimilarityStrategy.FLAT_L2,
+                                 devs=[dev],
+                                 load_local_index=index_name,
+                                 auto_save=True)
+    # Initialize the document-chunk relational store
+    chunk_store_ret = SQLiteDocstore(db_path=db_name)
+    retriever_cls = Retriever(vector_store=vector_store_ret,
+                              document_store=chunk_store_ret,
+                              embed_func=text_emb.embed_documents,
+                              k=top_k,
+                              score_threshold=score_threshold)
+    return retriever_cls
 
 
-def clear_history(history):
-    return []
+# Delete a knowledge base
+def delete_knowledge_db(knowledge_name: str):
+    knowledge_name = get_knowledge_rename(knowledge_name)
+    knowledge_db_lst = knowledge_mgr.get_all()
+    if knowledge_name in knowledge_db_lst:
+        knowledge_db_cls = creat_knowledge_db(knowledge_name, is_regist=False)
+        clear_document(knowledge_name)
+        knowledge_mgr.delete(knowledge_db_cls)
+    return get_knowledge_db()
 
 
-SAVE_FILE_PATH = "/tmp/document_files"
+# List the documents in a knowledge base
+def get_document(knowledge_name: str):
+    knowledge_name = get_knowledge_rename(knowledge_name)
+    knowledge_db_cls = creat_knowledge_db(knowledge_name, is_regist=False)
+    doc_names = knowledge_db_cls.get_all_documents()
+    return knowledge_name, doc_names, len(doc_names)
 
 
-def file_upload(files):
+# Remove every document from a knowledge base
+def clear_document(knowledge_name: str):
+    knowledge_name, doc_names, doc_cnt = get_document(knowledge_name)
+    if doc_cnt > 0:
+        knowledge_db_cls = creat_knowledge_db(knowledge_name, is_regist=False)
+        for doc_name in doc_names:
+            knowledge_db_cls.delete_file(doc_name)
+        return get_document(knowledge_name)
+    else:
+        return knowledge_name, doc_names, doc_cnt
+
+
+# List all knowledge bases
+def get_knowledge_db():
+    knowledge_db_lst = knowledge_mgr.get_all()
+    return knowledge_db_lst, len(knowledge_db_lst)
+
+
+# Upload documents into a knowledge base
+def file_upload(files,
+                knowledge_db_name: str = 'test',
+                chunk_size: int = 750,
+                chunk_overlap: int = 150
+                ):
+    save_file_path = "/tmp/document_files"
+    knowledge_db_name = get_knowledge_rename(knowledge_db_name)
     # directory where uploaded files are saved
-    if not os.path.exists(SAVE_FILE_PATH):
-        os.makedirs(SAVE_FILE_PATH)
+    if not os.path.exists(save_file_path):
+        os.makedirs(save_file_path)
     if files is None or len(files) == 0:
         print('no file need save')
+        return
+    if knowledge_db_name in knowledge_mgr.get_all():
+        knowledge_db_cls = creat_knowledge_db(knowledge_db_name, is_regist=False)
+    else:
+        knowledge_db_cls = creat_knowledge_db(knowledge_db_name)
+    # Register document handlers
+    loader_mng = LoaderMng()
+    # Document loaders: mxRAG's own and langchain's both work
+    loader_mng.register_loader(loader_class=TextLoader, file_types=[".txt", ".md"])
+    loader_mng.register_loader(loader_class=DocxLoader, file_types=[".docx"])
+    loader_mng.register_loader(loader_class=PdfLoader, file_types=[".pdf"])
+    loader_mng.register_loader(loader_class=ExcelLoader, file_types=[".xlsx", ".xls"])
+
+    # Document splitter: langchain's RecursiveCharacterTextSplitter
+    loader_mng.register_splitter(splitter_class=RecursiveCharacterTextSplitter,
+                                 file_types=[".docx", ".txt", ".md", ".pdf", ".xlsx", ".xls"],
+                                 splitter_params={"chunk_size": chunk_size,
+                                                  "chunk_overlap": chunk_overlap,
+                                                  "keep_separator": False
+                                                  })
     for file in files:
         try:
             # copy the domain document into the save directory
-            shutil.copy(file.name, SAVE_FILE_PATH)
+            shutil.copy(file.name, save_file_path)
             # knowledge base: chunk -> embedding -> add
-            upload_files(text_knowledge_db, [file.name], loader_mng=loader_mng, embed_func=text_emb.embed_documents,
-                         force=True)
-            print(f"file {file.name} save to {SAVE_FILE_PATH}.")
+            upload_files(knowledge_db_cls, [file.name], loader_mng=loader_mng,
+                         embed_func=text_emb.embed_documents, force=True)
+            print(f"file {file.name} save to {save_file_path}.")
         except Exception as err:
             logger.error(f"save failed, find exception: {err}")
@@ -83,147 +155,388 @@ def file_change(files, upload_btn):
     print("file changes")
 
 
-def build_demo():
-    with gr.Blocks() as demo:
-        gr.HTML("
-        检索增强生成(RAG)对话
-        powered by MindX RAG
-        ")
-        with gr.Row():
-            with gr.Column(scale=100):
-                with gr.Row():
-                    model_select = gr.Dropdown(choices=["Default"], value="Default", container=False, interactive=True)
-                with gr.Row():
-                    files = gr.Files(
-                        height=300,
-                        file_count="multiple",
-                        file_types=[".docx", ".txt", ".md", ".pdf"],
-                        interactive=True,
-                        label="上传知识库文档"
-                    )
-                with gr.Row():
-                    upload_btn = gr.Button("上传文件")
-                with gr.Row():
-                    with gr.Accordion(label='知识库情况'):
-                        temperature = gr.Slider(
-                            minimum=0.01,
-                            maximum=2,
-                            value=0.5,
-                            step=0.01,
-                            interactive=True,
-                            label="温度",
-                            info="Token生成的随机性"
-                        )
-                        top_p = gr.Slider(
-                            minimum=0.01,
-                            maximum=1,
-                            value=0.95,
-                            step=0.05,
-                            interactive=True,
-                            label="Top P",
-                            info="累计概率总和阈值"
-                        )
-                        max_tokens = gr.Slider(
-                            minimum=100,
-                            maximum=1024,
-                            value=512,
-                            step=1,
-                            interactive=True,
-                            label="最大tokens",
-                            info="输入+输出最多的tokens数"
-                        )
-            with gr.Column(scale=200):
-                chatbot = gr.Chatbot(height=500)
-                with gr.Row():
-                    msg = gr.Textbox(placeholder="在此输入问题...", container=False)
-                with gr.Row():
-                    send_btn = gr.Button(value="发送", variant="primary")
-                    clean_btn = gr.Button(value="清空历史")
-        send_btn.click(user_query, [msg, chatbot], [msg, chatbot], queue=False).then(bot_response,
-                                                                                     [chatbot, max_tokens,
-                                                                                      temperature, top_p], chatbot)
-        clean_btn.click(clear_history, chatbot, chatbot)
-        files.change(file_change, [], [])
-        upload_btn.click(file_upload, files, []).then(clear_history, files, files)
-        return demo
 
-
-def create_gradio(port):
-    demo = build_demo()
-    demo.queue()
-    demo.launch(share=True, server_name="0.0.0.0", server_port=port)
+def get_db_names(knowledge_name: str):
+    index_name = "./" + knowledge_name + "_faiss.index"
+    db_name = "./" + knowledge_name + "_sql.db"
+    return index_name, db_name
 
 
-class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter):
-    def _get_default_metavar_for_optional(self, action):
-        return action.type.__name__
+def get_knowledge_rename(knowledge_name: str):
+    if knowledge_name is None or len(knowledge_name) == 0 or ' ' in knowledge_name:
+        return 'test'
+    else:
+        return knowledge_name
 
-    def _get_default_metavar_for_positional(self, action):
-        return action.type.__name__
 
+# Rewrite the current question against the conversation history
+def generate_question(history, llm, history_n: int = 5):
+    prompt = """现在你有一个上下文依赖问题补全任务,任务要求:请根据对话历史和用户当前的问句,重写问句。\n
+    历史问题依次是:\n
+    {}\n
+    用户当前的问句:\n
+    {}\n
+    注意如果当前问题不依赖历史问题直接返回none即可\n
+    请根据上述对话历史重写用户当前的问句,仅输出重写后的问句,不需要附加任何分析。\n
+    重写问句: \n
+    """
+    if len(history) <= 2:
+        return history
+    cur_query = history[-1][0]
+    history_qa = history[0:-1]
+    history_list = [f"第{idx + 1}轮:{q_a[0]}" for idx, q_a in enumerate(history_qa) if q_a[0] is not None]
+    # keep the most recent history_n rounds
+    history_list = history_list[-history_n:]
+    history_str = "\n\n".join(history_list)
+    re_query = prompt.format(history_str, cur_query)
+    new_query = llm.chat(query=re_query, llm_config=LLMParameterConfig(max_tokens=512,
+                                                                       temperature=0.5,
+                                                                       top_p=0.95))
+    if new_query != "none":
+        history[-1][0] = "原始问题: " + cur_query
+        history += [[new_query, None]]
+    return history
 
+
+# Chat handler
+def bot_response(history,
+                 history_r,
+                 max_tokens: int = 512,
+                 temperature: float = 0.5,
+                 top_p: float = 0.95,
+                 history_n: int = 5,
+                 score_threshold: float = 0.5,
+                 top_k: int = 1,
+                 chat_type: str = "RAG检索增强对话",
+                 show_type: str = "不显示",
+                 is_rewrite: str = "否",
+                 knowledge_name: str = 'test'
+                 ):
+    chat_type_mapping = {"RAG检索增强对话": 1,
+                         "仅大模型对话": 0}
+    show_type_mapping = {"对话结束后显示": 1,
+                         "检索框单独显示": 2,
+                         "不显示": 0}
+    is_rewrite_mapping = {"是": 1,
+                          "否": 0}
+    # Initialize the retriever
+    knowledge_name = get_knowledge_rename(knowledge_name)
+    retriever_cls = creat_retriever(knowledge_name, top_k, score_threshold)
+    # Initialize the chain
+    text2text_chain = SingleText2TextChain(llm=llm, retriever=retriever_cls, reranker=reranker)
+    # Optionally rewrite the question from history
+    if is_rewrite_mapping.get(is_rewrite) == 1:
+        history = generate_question(history, llm, history_n)
+    history[-1][1] = '推理错误'
+    try:
+        # LLM-only answer
+        if chat_type_mapping.get(chat_type) == 0:
+            response = llm.chat_streamly(query=history[-1][0],
+                                         llm_config=LLMParameterConfig(max_tokens=max_tokens,
+                                                                       temperature=temperature,
+                                                                       top_p=top_p))
+            # iterate over the streamed results
+            for res in response:
+                history[-1][1] = res
+                yield history, history_r
+        # RAG-augmented answer
+        elif chat_type_mapping.get(chat_type) == 1:
+            response = text2text_chain.query(history[-1][0],
+                                             LLMParameterConfig(max_tokens=max_tokens,
+                                                                temperature=temperature,
+                                                                top_p=top_p,
+                                                                stream=True))
+            # retrieved chunks are not shown
+            if show_type_mapping.get(show_type) == 0:
+                for res in response:
+                    history[-1][1] = '推理错误' if res['result'] is None else res['result']
+                    yield history, history_r
+            # retrieved chunks appended after the answer
+            elif show_type_mapping.get(show_type) == 1:
+                for res in response:
+                    history[-1][1] = '推理错误' if res['result'] is None else res['result']
+                    q_docs = res['source_documents']
+                    yield history, history_r
+                # something was retrieved
+                if len(q_docs) > 0:
+                    history_last = ''
+                    for i, source in enumerate(q_docs):
+                        sources = "\n====检索信息来源:" + str(i + 1) + "[数据库名]:" + str(knowledge_name) + \
+                                  "[文件名]:" + source['metadata']['source'] + "====" + "\n" + \
+                                  "参考内容:" + source['page_content'] + "\n"
+                        history_last += sources
+                    history += [[None, history_last]]
+                    yield history, history_r
+            # retrieved chunks shown in the retrieval pane
+            else:
+                for res in response:
+                    history[-1][1] = '推理错误' if res['result'] is None else res['result']
+                    q_docs = res['source_documents']
+                    yield history, history_r
+                # something was retrieved
+                if len(q_docs) > 0:
+                    history_r_last = ''
+                    for i, source in enumerate(q_docs):
+                        sources = "\n====检索信息来源:" + str(i + 1) + "[数据库名]:" + str(knowledge_name) + \
+                                  "[文件名]:" + source['metadata']['source'] + "====" + "\n" + \
+                                  "参考内容:" + source['page_content'] + "\n"
+                        history_r_last += sources
+                    history_r += [[history[-1][0], history_r_last]]
+                    yield history, history_r
+    except Exception as err:
+        logger.error(f"query failed, find exception: {err}")
+        yield history, history_r
 
 
+# Retrieval-only handler
+def re_response(history_r,
+                score_threshold: float = 0.5,
+                top_k: int = 1,
+                knowledge_name: str = 'test'
+                ):
+    # Initialize the retriever
+    knowledge_name = get_knowledge_rename(knowledge_name)
+    retriever_cls = creat_retriever(knowledge_name, top_k, score_threshold)
+    q_docs = retriever_cls.invoke(history_r[-1][0])
+    if len(q_docs) > 0:
+        history_r_last = ''
+        for i, source in enumerate(q_docs):
+            sources = "\n====检索信息来源:" + str(i + 1) + "[数据库名]:" + str(knowledge_name) + \
+                      "[文件名]:" + source.metadata['source'] + "====" + "\n" + \
+                      "参考内容:" + source.page_content + "\n"
+            history_r_last += sources
+        history_r[-1][1] = history_r_last
+    else:
+        history_r[-1][1] = "未检索到相关信息"
+    return history_r
 
 
+# Append a retrieval query to the retrieval pane
+def user_retriever(user_message, history_r):
+    return "", history_r + [[user_message, None]]
 
 
+# Append a chat question to the chat pane
+def user_query(user_message, history):
+    return "", history + [[user_message, None]]
 
 
+def clear_history(history):
+    return []
 
 
 if __name__ == '__main__':
+    class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter):
+        def _get_default_metavar_for_optional(self, action):
+            return action.type.__name__
+
+        def _get_default_metavar_for_positional(self, action):
+            return action.type.__name__
+
     parse = argparse.ArgumentParser(formatter_class=CustomFormatter)
     parse.add_argument("--embedding_path",
type=str, default="/home/data/acge_text_embedding", help="embedding模型本地路径") + parse.add_argument("--tei_emb", type=bool, default=False, help="是否使用TEI服务化的embedding模型") + parse.add_argument("--embedding_url", type=str, default="http://127.0.0.1:8080/embed", + help="使用TEI服务化的embedding模型url地址") parse.add_argument("--embedding_dim", type=int, default=1024, help="embedding模型向量维度") parse.add_argument("--llm_url", type=str, default="http://127.0.0.1:1025/v1/chat/completions", help="大模型url地址") parse.add_argument("--model_name", type=str, default="Llama3-8B-Chinese-Chat", help="大模型名称") - parse.add_argument("--port", type=str, default=8080, help="web后端端口") + parse.add_argument("--tei_reranker", type=bool, default=False, help="是否使用TEI服务化的reranker模型") + parse.add_argument("--reranker_path", type=str, default=None, help="reranker模型本地路径") + parse.add_argument("--reranker_url", type=str, default=None, help="使用TEI服务化的embedding模型url地址") + parse.add_argument("--dev", type=int, default=0, help="使用的npu卡,可通过执行npu-smi info获取") + parse.add_argument("--port", type=int, default=8080, help="web后端端口") - args = parse.parse_args() + args = parse.parse_args().__dict__ + embedding_path: str = args.pop('embedding_path') + tei_emb: bool = args.pop('tei_emb') + embedding_url: str = args.pop('embedding_url') + embedding_dim: int = args.pop('embedding_dim') + llm_url: str = args.pop('llm_url') + model_name: str = args.pop('model_name') + tei_reranker: bool = args.pop('tei_reranker') + reranker_path: str = args.pop('reranker_path') + reranker_url: str = args.pop('reranker_url') + dev: int = args.pop('dev') + port: int = args.pop('port') - try: - # 离线构建知识库,首先注册文档处理器 - loader_mng = LoaderMng() - # 加载文档加载器,可以使用mxrag自有的,也可以使用langchain的 - loader_mng.register_loader(loader_class=TextLoader, file_types=[".txt", ".md"]) - loader_mng.register_loader(loader_class=DocxLoader, file_types=[".docx"]) - loader_mng.register_loader(loader_class=PdfLoader, file_types=[".pdf"]) - - # 加载文档切分器,使用langchain的 - loader_mng.register_splitter(splitter_class=RecursiveCharacterTextSplitter, - file_types=[".docx", ".txt", ".md", ".pdf"], - splitter_params={"chunk_size": 750, - "chunk_overlap": 150, - "keep_separator": False - } - ) - - # 设置向量检索使用的npu卡,具体可以用的卡可执行npu-smi info查询获取 - dev = 0 - - # 初始化向量数据库 - vector_store = MindFAISS(x_dim=args.embedding_dim, - similarity_strategy=SimilarityStrategy.FLAT_L2, - devs=[dev], - load_local_index="./faiss.index", - auto_save=True - ) - # 初始化文档chunk关系数据库 - chunk_store = SQLiteDocstore(db_path="./sql.db") - # 初始化知识管理关系数据库 - knowledge_store = KnowledgeStore(db_path="./sql.db") - # 初始化知识库管理 - text_knowledge_db = KnowledgeDB(knowledge_store=knowledge_store, - chunk_store=chunk_store, - vector_store=vector_store, - knowledge_name="test", - white_paths=["/tmp"] - ) - - # 加载embedding模型,请根据模型具体路径适配 - text_emb = TextEmbedding(model_path=args.embedding_path, dev_id=dev) - - # Step2在线问题答复,初始化检索器 - text_retriever = Retriever(vector_store=vector_store, - document_store=chunk_store, - embed_func=text_emb.embed_documents, - k=1, - score_threshold=0.4 - ) - # 配置text生成text大模型chain,具体ip端口请根据实际情况适配修改 - llm = Text2TextLLM(base_url=args.llm_url, model_name=args.model_name, client_param=ClientParam(use_http=True)) - text2text_chain = SingleText2TextChain(llm=llm, - retriever=text_retriever, - ) - create_gradio(int(args.port)) - except Exception as e: - logger.info(f"exception happed :{e}") + # 初始化test数据库 + knowledge_db = creat_knowledge_db('test') + # 配置text生成text大模型chain,具体ip端口请根据实际情况适配修改 + llm = Text2TextLLM(base_url=llm_url, 
+    # Configure the embedding model; adapt the path to the actual model location
+    if tei_emb:
+        text_emb = TEIEmbedding(url=embedding_url, client_param=ClientParam(use_http=True))
+    else:
+        text_emb = TextEmbedding(model_path=embedding_path, dev_id=dev)
+    # Configure the reranker: TEI service first, then a local model path, else disabled
+    if tei_reranker:
+        reranker = TEIReranker(url=reranker_url, client_param=ClientParam(use_http=True))
+    elif reranker_path is not None:
+        reranker = LocalReranker(model_path=reranker_path, dev_id=dev)
+    else:
+        reranker = None
+
+    # Build the gradio UI
+    def build_demo():
+        with gr.Blocks() as demo:
+            gr.HTML("
+            检索增强生成(RAG)对话
+            powered by MindX RAG
+            ")
+            with gr.Row():
+                with gr.Column(scale=100):
+                    with gr.Row():
+                        model_select = gr.Dropdown(choices=[model_name], value=model_name, container=False,
+                                                   interactive=True)
+                    with gr.Row():
+                        files = gr.components.File(
+                            height=100,
+                            file_count="multiple",
+                            file_types=[".docx", ".txt", ".md", ".pdf", ".xlsx", ".xls"],
+                            interactive=True,
+                            label="上传知识库文档"
+                        )
+                    with gr.Row():
+                        upload_btn = gr.Button("上传文件")
+                    with gr.Row():
+                        with gr.TabItem("知识库情况"):
+                            set_knowledge_name = gr.Textbox(label='设置当前知识库',
+                                                            placeholder="在此输入知识库名称,默认使用test知识库")
+                            with gr.Row():
+                                creat_knowledge_btn = gr.Button('创建知识库')
+                                delete_knowledge_btn = gr.Button('删除知识库')
+                            knowledge_name_output = gr.Textbox(label='知识库列表')
+                            knowledge_number_output = gr.Textbox(label='知识库数量')
+                            with gr.Row():
+                                show_knowledge_btn = gr.Button('显示知识库')
+                        with gr.TabItem("文件情况"):
+                            knowledge_name = gr.Textbox(label='知识库名称')
+                            knowledge_file_output = gr.Textbox(label='知识库文件列表')
+                            knowledge_file_num_output = gr.Textbox(label='知识库文件数量')
+                            with gr.Row():
+                                knowledge_file_out_btn = gr.Button('显示文件情况')
+                                knowledge_clear_btn = gr.Button('清空知识库')
+                    with gr.Row():
+                        with gr.Accordion(label='文档切分参数设置', open=False):
+                            chunk_size = gr.Slider(
+                                minimum=50,
+                                maximum=5000,
+                                value=750,
+                                step=50,
+                                interactive=True,
+                                label="chunk_size",
+                                info="文本切分长度"
+                            )
+                            chunk_overlap = gr.Slider(
+                                minimum=10,
+                                maximum=500,
+                                value=150,
+                                step=10,
+                                interactive=True,
+                                label="chunk_overlap",
+                                info="文本切分填充长度"
+                            )
+                    with gr.Row():
+                        with gr.Accordion(label='大模型参数设置', open=False):
+                            temperature = gr.Slider(
+                                minimum=0.01,
+                                maximum=2,
+                                value=0.5,
+                                step=0.01,
+                                interactive=True,
+                                label="温度",
+                                info="Token生成的随机性"
+                            )
+                            top_p = gr.Slider(
+                                minimum=0.01,
+                                maximum=1,
+                                value=0.95,
+                                step=0.05,
+                                interactive=True,
+                                label="Top P",
+                                info="累计概率总和阈值"
+                            )
+                            max_tokens = gr.Slider(
+                                minimum=100,
+                                maximum=1024,
+                                value=512,
+                                step=1,
+                                interactive=True,
+                                label="最大tokens",
+                                info="输入+输出最多的tokens数"
+                            )
+                            is_rewrite = gr.Radio(['是', '否'], value="否", label="是否根据历史提问重写问题?")
+                            history_n = gr.Slider(
+                                minimum=1,
+                                maximum=10,
+                                value=5,
+                                step=1,
+                                interactive=True,
+                                label="历史提问重写轮数",
+                                info="问题重写时所参考的历史提问轮数"
+                            )
+                    with gr.Row():
+                        with gr.Accordion(label='检索参数设置', open=False):
+                            score_threshold = gr.Slider(
+                                minimum=0,
+                                maximum=1,
+                                value=0.5,
+                                step=0.01,
+                                interactive=True,
+                                label="score_threshold",
+                                info="相似性检索阈值,值越大表示越相关,低于阈值不会被返回。"
+                            )
+                            top_k = gr.Slider(
+                                minimum=1,
+                                maximum=5,
+                                value=1,
+                                step=1,
+                                interactive=True,
+                                label="top_k",
+                                info="相似性检索返回条数"
+                            )
+                            show_type = gr.Radio(['对话结束后显示', '检索框单独显示', '不显示'],
+                                                 value="对话结束后显示", label="知识库文档匹配结果展示方式选择")
+                with gr.Column(scale=200):
+                    with gr.Tabs():
+                        with gr.TabItem("对话窗口"):
+                            chat_type = gr.Radio(['RAG检索增强对话', '仅大模型对话'], value="RAG检索增强对话",
+                                                 label="请选择对话模式?")
+                            chatbot = gr.Chatbot(height=550)
+                            with gr.Row():
+                                msg = gr.Textbox(placeholder="在此输入问题...", container=False)
+                            with gr.Row():
+                                send_btn = gr.Button(value="发送", variant="primary")
+                                clean_btn = gr.Button(value="清空历史")
+                        with gr.TabItem("检索窗口"):
+                            chatbot_r = gr.Chatbot(height=550)
+                            with gr.Row():
+                                msg_r = gr.Textbox(placeholder="在此输入问题...", container=False)
+                            with gr.Row():
+                                send_btn_r = gr.Button(value="文档检索", variant="primary")
+                                clean_btn_r = gr.Button(value="清空历史")
+            # RAG send
+            send_btn.click(user_query, [msg, chatbot], [msg, chatbot], queue=False
+                           ).then(bot_response,
+                                  [chatbot, chatbot_r, max_tokens, temperature, top_p, history_n, score_threshold,
+                                   top_k, chat_type, show_type, is_rewrite, set_knowledge_name],
+                                  [chatbot, chatbot_r])
+            # RAG clear history
+            clean_btn.click(clear_history, chatbot, chatbot)
+            # file upload
+            files.change(file_change, [], [])
+            upload_btn.click(file_upload, [files, set_knowledge_name, chunk_size, chunk_overlap], files)
+            # manage all knowledge bases
+            creat_knowledge_btn.click(creat_knowledge_db, [set_knowledge_name], []
+                                      ).then(get_knowledge_db, [],
+                                             [knowledge_name_output, knowledge_number_output])
+            show_knowledge_btn.click(get_knowledge_db, [], [knowledge_name_output, knowledge_number_output])
+            delete_knowledge_btn.click(delete_knowledge_db, [set_knowledge_name],
+                                       [knowledge_name_output, knowledge_number_output])
+            # manage the files inside a knowledge base
+            knowledge_file_out_btn.click(get_document, [set_knowledge_name],
+                                         [knowledge_name, knowledge_file_output, knowledge_file_num_output])
+            knowledge_clear_btn.click(clear_document, [set_knowledge_name],
+                                      [knowledge_name, knowledge_file_output, knowledge_file_num_output])
+            # retrieval send
+            send_btn_r.click(user_retriever, [msg_r, chatbot_r], [msg_r, chatbot_r], queue=False
+                             ).then(re_response, [chatbot_r, score_threshold, top_k, set_knowledge_name], chatbot_r)
+            # retrieval clear history
+            clean_btn_r.click(clear_history, chatbot_r, chatbot_r)
+        return demo
+
+
+    def create_gradio(ports):
+        demo = build_demo()
+        demo.queue()
+        demo.launch(share=True, server_name="0.0.0.0", server_port=ports)
+
+    # Launch gradio
+    create_gradio(port)
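
Note on scripted reuse: the new helpers (creat_knowledge_db, creat_retriever) read module-level names (embedding_dim, dev, text_emb, reranker) that are only bound inside the __main__ block, so they cannot be imported and called as-is. A minimal ingestion/retrieval sketch, assuming this file is importable as `app` and using illustrative paths/values (the "demo" knowledge base and the embedding path are placeholders):

import app
from mx_rag.embedding.local import TextEmbedding

# bind the globals that __main__ would normally set (illustrative values)
app.embedding_dim = 1024
app.dev = 0
app.text_emb = TextEmbedding(model_path="/home/data/acge_text_embedding", dev_id=app.dev)

db = app.creat_knowledge_db("demo")        # creates ./demo_faiss.index and ./demo_sql.db
retriever = app.creat_retriever("demo", top_k=3, score_threshold=0.5)
docs = retriever.invoke("示例问题")         # empty until documents are uploaded via file_upload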