代码拉取完成,页面将自动刷新
同步操作将从 衣沾不足惜/gitee-ai-docs-test 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain_openai import OpenAIEmbeddings
from langchain_chroma import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter, RecursiveJsonSplitter
from langchain_text_splitters import Language
from langchain_community.document_loaders.parsers import LanguageParser
from langchain_community.document_loaders.generic import GenericLoader
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain_community.document_loaders import JSONLoader
from langchain.schema import (
HumanMessage, AIMessage, SystemMessage
)
import chainlit as cl
import utils
import torch
from langchain_huggingface import HuggingFaceEmbeddings
import os
import re
import gradio as gr
import requests
import json
import subprocess
import threading
import time
from langchain_community.document_loaders import DirectoryLoader, TextLoader
DEBUG = os.environ.get("DEBUG", "false") in ("true", "1", "yes")
model_name = "hf-models/glm-4-9b-chat" # gitee-docs-lora
api_server_ready = False
GITEE_ACCESS_TOKEN = os.environ.get("GITEE_ACCESS_TOKEN", "")
if (not GITEE_ACCESS_TOKEN):
print("GITEE_ACCESS_TOKEN 环境变量不存在")
base_url = "http://127.0.0.1:8000/v1/"
repo_path = "/repos"
git_clone_ai_doc_command = ["git", "clone",
f"https://oauth2:{GITEE_ACCESS_TOKEN}@gitee.com/gitee-ai/docs.git", "--depth=1", "--single-branch", repo_path+"/ai.gitee.com"]
git_clone_gitee_doc_command = ["git", "clone",
f"https://oauth2:{GITEE_ACCESS_TOKEN}@gitee.com/oschina/gitee-help-center.git", "--depth=1", "--single-branch", repo_path+"/help.gitee.com"]
def clone_doc_repo():
try:
subprocess.run(
git_clone_ai_doc_command, text=True)
subprocess.run(
git_clone_gitee_doc_command, text=True)
# print("stdout:", git_process.stdout)
# print("stderr:", git_process.stderr)
except Exception as e:
print("克隆仓库发生错误:", e)
db = ""
retriever = ""
bge_model_name = "hf-models/bge-m3"
bge_model_kwargs = {'device': 'cuda:1'} # 'torch_dtype': torch.float16
bge_encode_kwargs = {'normalize_embeddings': False}
hfEmbeddings = HuggingFaceEmbeddings(
model_name=bge_model_name,
model_kwargs=bge_model_kwargs,
encode_kwargs=bge_encode_kwargs,
)
# 多卡跑多 模型时, vllm GPU blocks: 0 https://github.com/vllm-project/vllm/issues/2248
torch.cuda.empty_cache()
print('向量模型准备完成')
# 读取 choice-gitee-docs.json 文件
# with open('./data/choice-gitee-docs.json', 'r', encoding='utf-8') as file:
# choice_gitee_docs = json.load(file)
# # 提取所有 output
# choice_gitee_docs_output = [doc['output']
# for doc in choice_gitee_docs if 'output' in doc]
# # 修改文档中的链接。包含 slug 则由 ai 处理.
# choice_gitee_docs_output_docs = RecursiveJsonSplitter(
# max_chunk_size=2000).create_documents(texts=[choice_gitee_docs_output])
# json_loader = JSONLoader(
# file_path='./data/choice-gitee-docs.json',
# jq_schema='.',
# text_content=False)
# choice_gitee_docs_output_docs = json_loader.load()
def update_sources(documents):
utils.load_network_data()
dir_loader = DirectoryLoader(
'./data', glob="**/[!.]*", loader_cls=TextLoader, show_progress=True)
dir_loader_document = dir_loader.load()
print("./data 数据量", len(dir_loader_document))
for doc in documents:
source = doc.metadata.get('source')
# /repos/ai.gitee.com/docs/xxx
# /repos/help.gitee.com/docs/xxx
if source and source.startswith(repo_path+'/'):
# 将 /repos/ 改为 https://
source = source.replace(repo_path+'/', 'https://', 1)
# 去掉末尾的 .md
if source.endswith('.md'):
source = source[:-3]
# 针对 /help.gitee.com 去掉 /docs
if 'help.gitee.com/docs' in source:
source = source.replace('/docs', '', 1)
# 检查 page_content 中是否有 slug
slug_match = re.search(r'slug: (/.+)', doc.page_content)
if slug_match:
slug = slug_match.group(1)
doc.page_content = f'参考文档链接: https://help.gitee.com{slug}\n\n{doc.page_content}'
elif source:
# 拼接 source 到 page_content 最开始
doc.page_content = f'参考文档链接: {source}\n\n{doc.page_content}'
# 从 metadata 中移除 source
if 'source' in doc.metadata:
del doc.metadata['source']
return dir_loader_document + documents
update_doc_db_ing = False
def update_doc_db():
global update_doc_db_ing
global db
global retriever
print('开始生成向量库')
if (update_doc_db_ing):
return
update_doc_db_ing = True
subprocess.run(["rm", "-rf", repo_path])
clone_doc_repo()
loader = GenericLoader.from_filesystem(
repo_path,
glob="**/*",
suffixes=[".md", ".txt"],
exclude=["**/non-utf8-encoding.py", ".git/**"],
parser=LanguageParser(parser_threshold=0),
)
documents = loader.load()
documents = update_sources(documents)
python_splitter = RecursiveCharacterTextSplitter.from_language(
# ValueError: Batch size 51030 exceeds maximum batch size 41666
language=Language.MARKDOWN, chunk_size=1800, chunk_overlap=400
)
docs = python_splitter.split_documents(documents)
print("召回内容", docs[5:10])
print("docs 分块长度", len(docs))
# db = Chroma.from_documents(docs, OpenAIEmbeddings(disallowed_special=(), api_key="EMPTY", base_url=base_url, model="hf-models/bge-m3", timeout=300,
# tiktoken_enabled=False, show_progress_bar=True, chunk_size=8192), persist_directory="./chroma_db") # 非 OpenAI 实现,tiktoken_enabled 必须设置为 False
db = Chroma.from_documents(docs, hfEmbeddings)
retriever = db.as_retriever(
# search_type="mmr", # 值为 'similarity', 'similarity_score_threshold', 'mmr')
search_type="similarity",
search_kwargs={"k": 13}
)
update_doc_db_ing = False
print('向量库已生成')
torch.cuda.empty_cache()
# print(documents[100:500])
update_doc_db()
# chunk_size 块字符长度, 召回的每一项中字符串的长度, 代码文件需要比较大才能召回文件内容完整。300 行代码可能有六千字符
# chunk_overlap (块重叠) 指相邻块之间的重叠部分,确保在块分块后, 块之间的内容不会被忽略
llm = ChatOpenAI(model=model_name, api_key="EMPTY", base_url=base_url,
# glm4 [151329, 151336, 151338] qwen2[151643, 151644, 151645]
callbacks=[StreamingStdOutCallbackHandler()],
streaming=True, temperature=0.2, presence_penalty=1.2, top_p=0.95, extra_body={"stop_token_ids": [151329, 151336, 151338]})
@cl.on_chat_start
async def start_chat():
# cl.user_session.set(
# "message_history",
# [{"role": "system", "content": "xxx"}],
# )
await cl.Message(content="哈喽!我是马建仓,你在使用 Gitee AI 过程中有任何问题都可以找我!\nAI 输出,仅用于文档参考,不代表官方观点。").send()
@cl.on_message
async def main(message: cl.Message):
gen_time_start = time.time()
global api_server_ready
global db
global retriever
if (not api_server_ready):
api_server_ready = utils.is_port_open(base_url)
return await cl.Message(content="API 服务正在启动中,请重试...").send()
if (not message.content):
return await cl.Message(content="请输入你的问题").send()
if len(message.content) > 1000:
return await cl.Message(content="输入长度不能超过 1000 哦!").send()
print("\ninput:", message.content)
if (not db):
update_doc_db()
# raise gr.Info("正在更新向量数据库,请稍后,过程大约持续三分钟")
search_res = str(retriever.invoke(message.content))
if (DEBUG):
print("搜索结果:", search_res)
system_message = {"role": "system", "content": f"""
- 【你十分幽默风趣活泼,名叫马建仓,是 Gitee 吉祥物,现在在给 Gitee 打工,是一名金牌客服】。
- 你的回答参考 markdown 文档内容,紧扣用户问题,使用 markdown 格式回答。
- 注意区分用户问题, 如有必要, Gitee 和 Gitee AI 分开回答, 上下文提供了 Gitee 和 Gitee AI 两种文档。
- 用户提问模糊不清时,如有必要,你可以继续询问用户,细化问题。
- 回答时,确保:
1. 文档中有 https 链接(图片、链接等),如有必要,请提供完整,可以使用 markdown 渲染,相对路径和不相关图片请勿提供。
2. 如有必要,回答末尾可提供前少量参考文档 https:// 链接, 方便用户了解更多, 【注意: 请勿编造链接,请勿总结混合链接,没有就不要提供】。
3. 遵守中国法律,不回复任何敏感、违法、违反道德的问题。
4. 你的知识和上下文文档冲突时,以文档为准。
5. 如果你不知道答案,请勿捏造假内容。
- markdown 文档内容为: {search_res}
"""}
message_history = cl.user_session.get(
"message_history") or []
message_history.append({"role": "user", "content": message.content})
ai_msg = cl.Message(content="")
await ai_msg.send()
# 每次都在最开始插入 system_message, 但不保存到用户会话历史消息中
# message_history.insert(0, system_message) # insert 改变原数组, 返回 None
stream = llm.stream([system_message] + message_history)
if (DEBUG):
print("history", str(message_history))
full_answer = ""
for part in stream:
if content := part.content or "":
full_answer += content
if (len(full_answer) == 1):
gen_time_end = time.time()
print("首字消耗时间:", gen_time_end - gen_time_start)
await ai_msg.stream_token(content)
message_history.append({"role": "assistant", "content": ai_msg.content})
cl.user_session.set("message_history", message_history)
await ai_msg.update()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。