From c822055cc5e18ef484c880f6a567a7da078535b4 Mon Sep 17 00:00:00 2001 From: HeYudong <1357644645@qq.com> Date: Wed, 29 May 2024 02:28:50 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9F=BA=E6=9C=AC=E5=AE=8C=E6=88=90openai?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- application/controllers/openai_controller.py | 95 ++++++++++++++++++- application/utils/__init__.py | 7 -- application/utils/general_tool/__init__.py | 33 ++++++- application/utils/moonshot/__init__.py | 2 +- application/utils/openai_api/__init__.py | 2 +- application/utils/openai_api/chat.py | 23 +++++ application/utils/qianfan/__init__.py | 2 +- application/utils/vector/__init__.py | 3 +- .../utils/vector/hang_vector_database.py | 11 +++ 9 files changed, 163 insertions(+), 15 deletions(-) create mode 100644 application/utils/vector/hang_vector_database.py diff --git a/application/controllers/openai_controller.py b/application/controllers/openai_controller.py index 0f89340..55c3e0a 100644 --- a/application/controllers/openai_controller.py +++ b/application/controllers/openai_controller.py @@ -1,4 +1,7 @@ from flask import Blueprint, request, Response, current_app +from application.utils.general_tool import gcfnpdap, filter_docs, test_chat_history_length +from application.utils.vector import initialize_faiss_database +from application.utils.openai_api import gpt_three_stream openai_bp = Blueprint('openai', __name__, url_prefix='/openai') @@ -10,10 +13,12 @@ def openai_chat(): { "query": str, "history": List, - "model_param": { + "model_config": { "temperature": int, # openai 0-2. 
其于模型未知 "max_token": int, #最大值1024 - "system_role": str, #最大值500 + "system_prompt_code": int | None + "system_prompt_content": str | None 最大值1500 + "knowledge": str | None # 允许选定知识库,如果选择知识库为None } } :return: @@ -29,8 +34,92 @@ def openai_chat(): pass else: + current_app.vkb = None + current_app.train_prompt = None + current_app.no_train_prompt = None + current_app.max_chat_length = 8000 + + # 相似性阀值,选用余弦相似度,1表示完全相关,0表示完全不相关 + current_app.similarity = 0.8 + new_question, chat_history = request.json - return Response("hello") + # 相关的语料 + related_docs_content = None + + # 相关的doc + related_docs = None + + if data["model_config"]["knowledge"]: + vf_apath = gcfnpdap(__name__, 2) + "/static/store_vector_knowledge_dictory/" + "1716811846_2" + current_app.vkb = initialize_faiss_database(vf_apath) + + if len(chat_history) > 2: + # 历史相似片段 + history_related_docs = current_app.vkb.similarity_search_with_relevance_scores(chat_history[-3], 1) + + # 最新问题相似片段 + new_question_related_docs = current_app.vkb.similarity_search_with_relevance_scores(new_question, 2) + + related_docs = new_question_related_docs + history_related_docs + + else: + # 最新问题相似片段 + new_question_related_docs = current_app.vkb.similarity_search_with_relevance_scores(new_question, 2) + + related_docs = new_question_related_docs + + # 相关的doc + related_docs = filter_docs(related_docs, current_app.similarity) + + # related_docs 经过过滤后不为空,则查找最相似的doc的scores + most_related_doc_scores = max([i[1] for i in related_docs]) + + related_docs_content = [doc[0].page_content for doc in related_docs] + related_docs_content = list(set(related_docs_content)) + related_docs_content = "\n\n".join(related_docs_content) + + # 最相关预料 + if related_docs: + + # 数据库操作 + # ------ + # ------ + # 这里data中传了代码过来,则用他的,使用数据库 + + trained_prompt_parent_dir = gcfnpdap(__name__, 2) + "/static/trained/" + "file_name.txt" + + with open(trained_prompt_parent_dir, "r", encoding="UTF-8") as f: + current_app.train_prompt = f.read() + + trained_prompt = 
current_app.train_prompt.format(doc=related_docs_content) + + chat_history.insert(0, trained_prompt) + + chat_history = test_chat_history_length(chat_history, current_app.max_chat_length) + + data["chat_history"] = chat_history + + return Response(gpt_three_stream(data), mimetype="text/event-stream") + + else: + + # 数据库操作 + # ------ + # ------ + + no_trained_prompt_parent_dir = gcfnpdap(__name__, 2) + "/static/no_trained/" + "file_name.txt" + + with open(no_trained_prompt_parent_dir, "r", encoding="UTF-8") as f: + current_app.no_train_prompt = f.read() + + chat_history.insert(0, current_app.no_train_prompt) + + chat_history = test_chat_history_length(chat_history, current_app.max_chat_length) + + data["chat_history"] = chat_history + + return Response(gpt_three_stream(data), mimetype="text/event-stream") + diff --git a/application/utils/__init__.py b/application/utils/__init__.py index 880f412..e69de29 100644 --- a/application/utils/__init__.py +++ b/application/utils/__init__.py @@ -1,7 +0,0 @@ -from application.utils.moonshot import all_func as moonshot -from application.utils.openai_api import all_func as openai_api -from application.utils.qianfan import all_func as qianfan -from application.utils.general_tool import all_func as tool -from application.utils.vector import all_func as vector - -__all__ = moonshot + openai_api + qianfan + tool + vector diff --git a/application/utils/general_tool/__init__.py b/application/utils/general_tool/__init__.py index c384539..18a6391 100644 --- a/application/utils/general_tool/__init__.py +++ b/application/utils/general_tool/__init__.py @@ -19,7 +19,38 @@ def gcfnpdap(file_path, n): return dir_absolute_path -all_func = [gcfnpdap] +def filter_docs(list_, value_): + """ + desc: 过滤相似性较低的语料 + :param list_: 向量数据库查询的相似的docs + :param value_: 自定义相似性阀值 + :return: + """ + return [i for i in list_ if i[1] >= value_] + + +def measure_chat_history_length(chat_history): + """判断字符列表中所有元素总长度""" + total_length = 0 + for i in chat_history: + 
total_length = total_length + len(i) + + return total_length + + +def test_chat_history_length(chat_history, max_chat_length): + """检验字符列表总长度,如果长度过长,删去历史记录""" + if measure_chat_history_length(chat_history) <= max_chat_length: + return chat_history + + else: + chat_history.pop(1) + chat_history.pop(1) + + return test_chat_history_length(chat_history, max_chat_length) + + +__all__ = ["gcfnpdap", "filter_docs", "test_chat_history_length"] diff --git a/application/utils/moonshot/__init__.py b/application/utils/moonshot/__init__.py index 1fa0b45..9fcffee 100644 --- a/application/utils/moonshot/__init__.py +++ b/application/utils/moonshot/__init__.py @@ -1,3 +1,3 @@ from application.utils.moonshot.chat import moonshot_one, moonshot_one_stream -all_func = [moonshot_one, moonshot_one_stream] \ No newline at end of file +__all__ = ["moonshot_one", "moonshot_one_stream"] diff --git a/application/utils/openai_api/__init__.py b/application/utils/openai_api/__init__.py index c2affe0..18c6a98 100644 --- a/application/utils/openai_api/__init__.py +++ b/application/utils/openai_api/__init__.py @@ -1,3 +1,3 @@ from application.utils.openai_api.chat import gpt_three, gpt_three_stream -all_func = [gpt_three, gpt_three_stream] +__all__ = ["gpt_three", "gpt_three_stream"] diff --git a/application/utils/openai_api/chat.py b/application/utils/openai_api/chat.py index 04f4b2f..8e966af 100644 --- a/application/utils/openai_api/chat.py +++ b/application/utils/openai_api/chat.py @@ -13,6 +13,17 @@ def gpt_three(data): } """ + # 处理聊天数据 + for i in range(len(data["message"])): + if i == 0: + data["message"][i] = {"role": "system", "content": data["message"][i]} + + elif i % 2 != 0: + data["message"][i] = {"role": "user", "content": data["message"][i]} + + else: + data["message"][i] = {"role": "assistant", "content": data["message"][i]} + client = OpenAI() completion = client.chat.completions.create( model="gpt-4o", @@ -33,6 +44,18 @@ def gpt_three(data): def gpt_three_stream(data): + + # 处理聊天数据 + for i in 
range(len(data["message"])): + if i == 0: + data["message"][i] = {"role": "system", "content": data["message"][i]} + + elif i % 2 != 0: + data["message"][i] = {"role": "user", "content": data["message"][i]} + + else: + data["message"][i] = {"role": "assistant", "content": data["message"][i]} + client = OpenAI() completion = client.chat.completions.create( model='gpt-3.5-turbo', diff --git a/application/utils/qianfan/__init__.py b/application/utils/qianfan/__init__.py index d8a67ff..ab24b7a 100644 --- a/application/utils/qianfan/__init__.py +++ b/application/utils/qianfan/__init__.py @@ -3,5 +3,5 @@ import json from application.utils.qianfan.chat import ernie_three, ernie_three_stream -all_func = [ernie_three, ernie_three_stream] +__all__ = ["ernie_three", "ernie_three_stream"] diff --git a/application/utils/vector/__init__.py b/application/utils/vector/__init__.py index 815c7e3..6764e7f 100644 --- a/application/utils/vector/__init__.py +++ b/application/utils/vector/__init__.py @@ -1,4 +1,5 @@ from application.utils.vector.excel_to_vetor import xlsx_excel_to_vector +from application.utils.vector.hang_vector_database import initialize_faiss_database import os # store openai-api-key to internal storage @@ -8,4 +9,4 @@ os.environ["OPENAI_API_KEY"] = "sk-proj-wOXa1EvcrzPxGvVwjbp3T3BlbkFJSZNHhSdxIjtB os.environ["http_proxy"] = "http://localhost:7890" os.environ["https_proxy"] = "http://localhost:7890" -all_func = [xlsx_excel_to_vector] +__all__ = ["xlsx_excel_to_vector", "initialize_faiss_database"] diff --git a/application/utils/vector/hang_vector_database.py b/application/utils/vector/hang_vector_database.py new file mode 100644 index 0000000..16da623 --- /dev/null +++ b/application/utils/vector/hang_vector_database.py @@ -0,0 +1,11 @@ +from langchain_openai import OpenAIEmbeddings +from langchain_community.vectorstores import FAISS + + +def initialize_faiss_database(vector_data_dir_absolute_path): + # openai embedding + embeddings = OpenAIEmbeddings() + + 
vector_kb_object = FAISS.load_local(vector_data_dir_absolute_path, embeddings, allow_dangerous_deserialization=True) + + return vector_kb_object -- Gitee